"""
Simple story generation script for TinyStories 24.5M model.
Usage:
python generate_simple.py
Or with custom prompt:
python generate_simple.py --prompt "Once upon a time there was"
"""
import argparse
import sys
from pathlib import Path

import torch

# Add src to path
sys.path.insert(0, str(Path(__file__).parent))

from src.model.transformer_block import WikiMiniModel
from src.data.tokenizer import load_tokenizer


def load_model(checkpoint_path, tokenizer_path, device='cuda'):
    """Load model and tokenizer."""
    # Load tokenizer
    print(f"Loading tokenizer from {tokenizer_path}...")
    tokenizer = load_tokenizer(tokenizer_path)
    print(f"✓ Tokenizer loaded (vocab size: {tokenizer.vocab_size:,})")

    # Load checkpoint
    print(f"\nLoading model from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Get config
    if 'config' in checkpoint:
        config = checkpoint['config']['model']
    else:
        raise ValueError("Config not found in checkpoint")

    # Ensure vocab size matches tokenizer
    config['vocab_size'] = tokenizer.vocab_size

    # Create model
    model = WikiMiniModel(config)

    # Load weights
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model = model.to(device)
    model.eval()

    params = sum(p.numel() for p in model.parameters())
    print(f"✓ Model loaded ({params/1e6:.1f}M parameters)\n")

    return model, tokenizer


def generate_story(model, tokenizer, prompt, max_length=200, temperature=0.8,
                   top_k=50, top_p=0.95, device='cuda'):
    """Generate a story from a prompt."""
    # Encode prompt
    input_ids = tokenizer.encode(prompt)
    input_ids = torch.tensor([input_ids]).to(device)

    print(f"Prompt: {prompt}")
    print(f"Generating (max {max_length} tokens)...\n")

    generated_ids = input_ids[0].tolist()
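    # Each step below re-runs the full forward pass over the growing sequence
    # (no KV cache); assumes prompt + max_length fits the model's context window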
    with torch.no_grad():
        for _ in range(max_length):
            # Get predictions for the last position
            outputs = model(input_ids)
            logits = outputs['logits'][0, -1, :]

            # Apply temperature
            logits = logits / temperature

            # Top-k filtering: keep the k highest logits, set the rest to -inf
            # so softmax assigns them zero probability
            if top_k > 0:
                top_k_logits, top_k_indices = torch.topk(logits, top_k)
                logits = torch.full_like(logits, float('-inf'))
                logits.scatter_(0, top_k_indices, top_k_logits)
            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=0), dim=0)

                # Remove tokens with cumulative prob > top_p; shifting the mask
                # right keeps the first token that crosses the threshold, and
                # index 0 stays so at least one token always survives
                remove_indices = cumulative_probs > top_p
                remove_indices[1:] = remove_indices[:-1].clone()
                remove_indices[0] = False

                sorted_logits[remove_indices] = float('-inf')
                logits.scatter_(0, sorted_indices, sorted_logits)
            # Sample next token
            probs = torch.softmax(logits, dim=0)
            next_token = torch.multinomial(probs, 1)

            # Add to sequence
            generated_ids.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

            # Stop at EOS
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Decode
    story = tokenizer.decode(generated_ids)
    return story


def main():
    parser = argparse.ArgumentParser(description='Generate TinyStories')
    parser.add_argument('--checkpoint', type=str,
                        default='pytorch_model.bin',
                        help='Path to model checkpoint')
    parser.add_argument('--tokenizer', type=str,
                        default='./tokenizer',
                        help='Path to tokenizer directory')
    parser.add_argument('--prompt', type=str,
                        default='Once upon a time there was',
                        help='Story prompt')
    parser.add_argument('--max-length', type=int, default=200,
                        help='Maximum tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.8,
                        help='Sampling temperature (0.7-0.9 recommended)')
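    # Optional sampling flags; defaults match the generate_story() signature
    parser.add_argument('--top-k', type=int, default=50,
                        help='Keep only the k most likely tokens (0 disables)')
    parser.add_argument('--top-p', type=float, default=0.95,
                        help='Nucleus sampling threshold (1.0 disables)')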
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device: cuda or cpu')
    args = parser.parse_args()

    # Auto-detect device
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        args.device = 'cpu'

    # Load model
    model, tokenizer = load_model(args.checkpoint, args.tokenizer, args.device)

    # Generate
    story = generate_story(
        model, tokenizer, args.prompt,
        max_length=args.max_length,
        temperature=args.temperature,
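        top_k=args.top_k,
        top_p=args.top_p,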
        device=args.device
    )

    # Display
    print("=" * 70)
    print("GENERATED STORY")
    print("=" * 70)
    print(story)
    print("=" * 70)


if __name__ == '__main__':
    main()