""" Simple story generation script for TinyStories 24.5M model. Usage: python generate_simple.py Or with custom prompt: python generate_simple.py --prompt "Once upon a time there was" """ import torch import argparse from pathlib import Path import sys # Add src to path sys.path.insert(0, str(Path(__file__).parent)) from src.model.transformer_block import WikiMiniModel from src.data.tokenizer import load_tokenizer def load_model(checkpoint_path, tokenizer_path, device='cuda'): """Load model and tokenizer.""" # Load tokenizer print(f"Loading tokenizer from {tokenizer_path}...") tokenizer = load_tokenizer(tokenizer_path) print(f"✓ Tokenizer loaded (vocab size: {tokenizer.vocab_size:,})") # Load checkpoint print(f"\nLoading model from {checkpoint_path}...") checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) # Get config if 'config' in checkpoint: config = checkpoint['config']['model'] else: raise ValueError("Config not found in checkpoint") # Ensure vocab size matches tokenizer config['vocab_size'] = tokenizer.vocab_size # Create model model = WikiMiniModel(config) # Load weights if 'model_state_dict' in checkpoint: model.load_state_dict(checkpoint['model_state_dict']) else: model.load_state_dict(checkpoint) model = model.to(device) model.eval() params = sum(p.numel() for p in model.parameters()) print(f"✓ Model loaded ({params/1e6:.1f}M parameters)\n") return model, tokenizer def generate_story(model, tokenizer, prompt, max_length=200, temperature=0.8, top_k=50, top_p=0.95, device='cuda'): """Generate a story from a prompt.""" # Encode prompt input_ids = tokenizer.encode(prompt) input_ids = torch.tensor([input_ids]).to(device) print(f"Prompt: {prompt}") print(f"Generating (max {max_length} tokens)...\n") generated_ids = input_ids[0].tolist() with torch.no_grad(): for _ in range(max_length): # Get predictions outputs = model(input_ids) logits = outputs['logits'][0, -1, :] # Apply temperature logits = logits / temperature # Top-k filtering if top_k > 0: top_k_logits, top_k_indices = torch.topk(logits, top_k) logits = torch.full_like(logits, float('-inf')) logits.scatter_(0, top_k_indices, top_k_logits) # Top-p filtering if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=0), dim=0) # Remove tokens with cumulative prob > top_p remove_indices = cumulative_probs > top_p remove_indices[1:] = remove_indices[:-1].clone() remove_indices[0] = False sorted_logits[remove_indices] = float('-inf') logits.scatter_(0, sorted_indices, sorted_logits) # Sample next token probs = torch.softmax(logits, dim=0) next_token = torch.multinomial(probs, 1) # Add to sequence generated_ids.append(next_token.item()) input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1) # Stop at EOS if next_token.item() == tokenizer.eos_token_id: break # Decode story = tokenizer.decode(generated_ids) return story def main(): parser = argparse.ArgumentParser(description='Generate TinyStories') parser.add_argument('--checkpoint', type=str, default='pytorch_model.bin', help='Path to model checkpoint') parser.add_argument('--tokenizer', type=str, default='./tokenizer', help='Path to tokenizer directory') parser.add_argument('--prompt', type=str, default='Once upon a time there was', help='Story prompt') parser.add_argument('--max-length', type=int, default=200, help='Maximum tokens to generate') parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature (0.7-0.9 
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device: cuda or cpu')

    args = parser.parse_args()

    # Fall back to CPU if CUDA was requested but is unavailable
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        args.device = 'cpu'

    # Load model and tokenizer
    model, tokenizer = load_model(args.checkpoint, args.tokenizer, args.device)

    # Generate
    story = generate_story(
        model, tokenizer, args.prompt,
        max_length=args.max_length,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        device=args.device
    )

    # Display
    print("=" * 70)
    print("GENERATED STORY")
    print("=" * 70)
    print(story)
    print("=" * 70)


if __name__ == '__main__':
    main()
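# Example invocation (assumes a trained checkpoint and tokenizer in the default
# locations; the prompt is illustrative):
#   python generate_simple.py --prompt "One day, a little dog" \
#       --temperature 0.7 --top-k 40 --top-p 0.9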