# NOTE(review): the three lines below were stray git commit metadata (author,
# commit message, short hash) pasted into the source; kept as a comment so the
# module parses.
# karthick — "Upload TinyStories 24.5M model - article generation success" (6dbdf6d)
"""
Simple story generation script for TinyStories 24.5M model.
Usage:
python generate_simple.py
Or with custom prompt:
python generate_simple.py --prompt "Once upon a time there was"
"""
import torch
import argparse
from pathlib import Path
import sys
# Add src to path
sys.path.insert(0, str(Path(__file__).parent))
from src.model.transformer_block import WikiMiniModel
from src.data.tokenizer import load_tokenizer
def load_model(checkpoint_path, tokenizer_path, device='cuda'):
    """Load the model weights and tokenizer from disk.

    Args:
        checkpoint_path: Path to a torch checkpoint. May be either a dict
            containing ``model_state_dict`` (plus a ``config`` entry, which is
            required) or a bare state dict.
        tokenizer_path: Directory understood by ``load_tokenizer``.
        device: Device string the model and weights are mapped to.

    Returns:
        Tuple of ``(model, tokenizer)`` with the model in eval mode on *device*.

    Raises:
        ValueError: If the checkpoint has no ``config`` entry.
    """
    print(f"Loading tokenizer from {tokenizer_path}...")
    tokenizer = load_tokenizer(tokenizer_path)
    print(f"✓ Tokenizer loaded (vocab size: {tokenizer.vocab_size:,})")

    print(f"\nLoading model from {checkpoint_path}...")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # The model hyper-parameters must travel with the checkpoint.
    if 'config' not in checkpoint:
        raise ValueError("Config not found in checkpoint")
    config = checkpoint['config']['model']

    # Keep the embedding table in sync with the tokenizer actually loaded.
    config['vocab_size'] = tokenizer.vocab_size

    model = WikiMiniModel(config)

    # A full training checkpoint nests the weights; a bare state dict is the
    # whole checkpoint itself.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()

    n_params = sum(p.numel() for p in model.parameters())
    print(f"✓ Model loaded ({n_params/1e6:.1f}M parameters)\n")
    return model, tokenizer
def generate_story(model, tokenizer, prompt, max_length=200, temperature=0.8,
                   top_k=50, top_p=0.95, device='cuda'):
    """Autoregressively sample a story continuation from *prompt*.

    Args:
        model: Callable returning ``{'logits': Tensor}`` of shape
            (batch, seq, vocab) when given a (batch, seq) id tensor.
        tokenizer: Object providing ``encode``/``decode`` and ``eos_token_id``.
        prompt: Seed text; its tokens are included in the returned story.
        max_length: Maximum number of NEW tokens to sample.
        temperature: Logit divisor; lower values sharpen the distribution.
        top_k: Keep only the k most likely tokens (0 disables the filter).
        top_p: Nucleus threshold; keep the smallest set of tokens whose
            cumulative probability exceeds this (1.0 disables the filter).
        device: Device the input ids are placed on.

    Returns:
        The decoded story (prompt + generated tokens) as a string.
    """
    # Encode prompt into a (1, seq) batch.
    input_ids = tokenizer.encode(prompt)
    input_ids = torch.tensor([input_ids]).to(device)

    print(f"Prompt: {prompt}")
    print(f"Generating (max {max_length} tokens)...\n")

    generated_ids = input_ids[0].tolist()

    with torch.no_grad():
        for _ in range(max_length):
            # NOTE(review): the full sequence is re-run every step (no KV
            # cache) and grows unboundedly — presumably fine for short
            # stories, but verify against the model's context window.
            outputs = model(input_ids)
            logits = outputs['logits'][0, -1, :]

            # Temperature scaling.
            logits = logits / temperature

            # Top-k filtering. FIX: clamp k to the vocab size — torch.topk
            # raises a RuntimeError when k exceeds the dimension size, so the
            # default top_k=50 crashed on any vocab smaller than 50.
            if top_k > 0:
                k = min(top_k, logits.size(-1))
                top_k_logits, top_k_indices = torch.topk(logits, k)
                logits = torch.full_like(logits, float('-inf'))
                logits.scatter_(0, top_k_indices, top_k_logits)

            # Top-p (nucleus) filtering.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=0), dim=0)
                # Mask tokens once cumulative prob exceeds top_p; shift right
                # so the first token crossing the threshold is kept.
                remove_indices = cumulative_probs > top_p
                remove_indices[1:] = remove_indices[:-1].clone()
                remove_indices[0] = False
                sorted_logits[remove_indices] = float('-inf')
                logits.scatter_(0, sorted_indices, sorted_logits)

            # Sample one token from the filtered distribution.
            probs = torch.softmax(logits, dim=0)
            next_token = torch.multinomial(probs, 1)

            # Append to both the running id list and the model input.
            generated_ids.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            # Stop at end-of-sequence.
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Decode prompt + generated tokens back to text.
    story = tokenizer.decode(generated_ids)
    return story
def main():
    """CLI entry point: parse arguments, load the model, print one story."""
    parser = argparse.ArgumentParser(description='Generate TinyStories')
    parser.add_argument('--checkpoint', type=str,
                        default='pytorch_model.bin',
                        help='Path to model checkpoint')
    parser.add_argument('--tokenizer', type=str,
                        default='./tokenizer',
                        help='Path to tokenizer directory')
    parser.add_argument('--prompt', type=str,
                        default='Once upon a time there was',
                        help='Story prompt')
    parser.add_argument('--max-length', type=int, default=200,
                        help='Maximum tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.8,
                        help='Sampling temperature (0.7-0.9 recommended)')
    # FIX: generate_story already supports top-k/top-p, but the CLI never
    # exposed them; add flags with the same defaults (backward compatible).
    parser.add_argument('--top-k', type=int, default=50,
                        help='Top-k sampling cutoff (0 disables)')
    parser.add_argument('--top-p', type=float, default=0.95,
                        help='Nucleus (top-p) sampling threshold (1.0 disables)')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device: cuda or cpu')
    args = parser.parse_args()

    # Auto-detect device: fall back to CPU when CUDA is unavailable.
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        args.device = 'cpu'

    # Load model and tokenizer.
    model, tokenizer = load_model(args.checkpoint, args.tokenizer, args.device)

    # Generate one story from the prompt.
    story = generate_story(
        model, tokenizer, args.prompt,
        max_length=args.max_length,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        device=args.device
    )

    # Display the result.
    print("="*70)
    print("GENERATED STORY")
    print("="*70)
    print(story)
    print("="*70)


if __name__ == '__main__':
    main()