"""
Simple story generation script for TinyStories 24.5M model.

Usage:
    python generate_simple.py

Or with a custom prompt:
    python generate_simple.py --prompt "Once upon a time there was"
"""

import argparse
import sys
from pathlib import Path

import torch

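# Make the local `src` package importable when this script is run directly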
sys.path.insert(0, str(Path(__file__).parent))

from src.model.transformer_block import WikiMiniModel
from src.data.tokenizer import load_tokenizer


def load_model(checkpoint_path, tokenizer_path, device='cuda'):
    """Load model and tokenizer."""
    print(f"Loading tokenizer from {tokenizer_path}...")
    tokenizer = load_tokenizer(tokenizer_path)
    print(f"✓ Tokenizer loaded (vocab size: {tokenizer.vocab_size:,})")

    print(f"\nLoading model from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

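    # Model hyperparameters are stored in the checkpoint under checkpoint['config']['model']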
    if 'config' in checkpoint:
        config = checkpoint['config']['model']
    else:
        raise ValueError("Config not found in checkpoint")

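    # Keep the embedding size in sync with the actual tokenizer vocabulary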
    config['vocab_size'] = tokenizer.vocab_size

    model = WikiMiniModel(config)

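    # Checkpoints may store the weights directly or wrapped in 'model_state_dict'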
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model = model.to(device)
    model.eval()

    params = sum(p.numel() for p in model.parameters())
    print(f"✓ Model loaded ({params/1e6:.1f}M parameters)\n")

    return model, tokenizer


def generate_story(model, tokenizer, prompt, max_length=200, temperature=0.8,
                   top_k=50, top_p=0.95, device='cuda'):
    """Generate a story from a prompt."""
    input_ids = tokenizer.encode(prompt)
    input_ids = torch.tensor([input_ids]).to(device)

    print(f"Prompt: {prompt}")
    print(f"Generating (max {max_length} tokens)...\n")

    generated_ids = input_ids[0].tolist()

    with torch.no_grad():
        for _ in range(max_length):
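            # Forward pass; keep only the logits for the last position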
            outputs = model(input_ids)
            logits = outputs['logits'][0, -1, :]

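            # Temperature scaling: values below 1.0 sharpen the distribution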
            logits = logits / temperature

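            # Top-k filtering: keep only the k highest-scoring tokens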
            if top_k > 0:
                top_k_logits, top_k_indices = torch.topk(logits, top_k)
                logits = torch.full_like(logits, float('-inf'))
                logits.scatter_(0, top_k_indices, top_k_logits)

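            # Top-p (nucleus) filtering: drop tokens outside the smallest set
            # whose cumulative probability exceeds top_p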
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=0), dim=0)

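                # Shift the mask right so the first token past the threshold is kept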
                remove_indices = cumulative_probs > top_p
                remove_indices[1:] = remove_indices[:-1].clone()
                remove_indices[0] = False

                sorted_logits[remove_indices] = float('-inf')
                logits.scatter_(0, sorted_indices, sorted_logits)

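            # Sample the next token from the filtered distribution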
            probs = torch.softmax(logits, dim=0)
            next_token = torch.multinomial(probs, 1)

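            # Append the sampled token and feed the extended sequence back to the model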
            generated_ids.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

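            # Stop once the end-of-sequence token is produced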
            if next_token.item() == tokenizer.eos_token_id:
                break

    story = tokenizer.decode(generated_ids)
    return story


def main():
    parser = argparse.ArgumentParser(description='Generate TinyStories')
    parser.add_argument('--checkpoint', type=str,
                        default='pytorch_model.bin',
                        help='Path to model checkpoint')
    parser.add_argument('--tokenizer', type=str,
                        default='./tokenizer',
                        help='Path to tokenizer directory')
    parser.add_argument('--prompt', type=str,
                        default='Once upon a time there was',
                        help='Story prompt')
    parser.add_argument('--max-length', type=int, default=200,
                        help='Maximum tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.8,
                        help='Sampling temperature (0.7-0.9 recommended)')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device: cuda or cpu')

    args = parser.parse_args()

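    # Fall back to CPU if CUDA was requested but is not available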
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        args.device = 'cpu'

    model, tokenizer = load_model(args.checkpoint, args.tokenizer, args.device)

    story = generate_story(
        model, tokenizer, args.prompt,
        max_length=args.max_length,
        temperature=args.temperature,
        device=args.device
    )

    print("=" * 70)
    print("GENERATED STORY")
    print("=" * 70)
    print(story)
    print("=" * 70)


if __name__ == '__main__':
    main()