#!/usr/bin/env python3
"""
Example usage of TinyStories SLM model
"""
import torch
import tiktoken
from model import GPT, GPTConfig


def load_model():
    """Load the model and tokenizer"""
    # Load tokenizer
    enc = tiktoken.get_encoding("gpt2")

    # Model configuration
    config = GPTConfig(
        vocab_size=50257,
        block_size=128,
        n_layer=6,
        n_head=6,
        n_embd=384,
        dropout=0.0,  # Set to 0 for inference
        bias=True
    )

    # Load model
    model = GPT(config)
    model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))
    model.eval()
    return model, enc
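

# Optional: fetch the checkpoint from the Hugging Face Hub when it isn't on
# disk. A minimal sketch, assuming the repo id matches this model page;
# huggingface_hub is only needed for this path.
def download_checkpoint(repo_id="abhilash88/tinystories-slm-gpt",
                        filename="pytorch_model.bin"):
    """Download the model weights and return the local file path."""
    from huggingface_hub import hf_hub_download
    return hf_hub_download(repo_id=repo_id, filename=filename)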


def generate_story(model, enc, prompt, max_tokens=200, temperature=1.0, top_k=None):
    """Generate a story from a prompt"""
    # Encode the prompt and add a batch dimension
    context = torch.tensor(enc.encode_ordinary(prompt)).unsqueeze(0)
    with torch.no_grad():
        generated = model.generate(
            context,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k
        )
    return enc.decode(generated.squeeze().tolist())
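

# Optional: run generation on a GPU when one is available. A hedged sketch,
# assuming the model's generate() works on any device as long as the context
# tensor lives on the same device (as in nanoGPT-style implementations).
def generate_story_on_device(model, enc, prompt, device=None, **gen_kwargs):
    """Move the model to `device` (auto-detected if None) and generate."""
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    context = torch.tensor(enc.encode_ordinary(prompt), device=device).unsqueeze(0)
    with torch.no_grad():
        generated = model.generate(context, **gen_kwargs)
    return enc.decode(generated.squeeze().tolist())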


if __name__ == "__main__":
    # Load model
    model, enc = load_model()

    # Example prompts
    prompts = [
        "Once upon a time there was a pumpkin.",
        "A little girl went to the woods",
        "Once upon a time in India",
        "The magic cat could",
        "In a small village"
    ]

    print("TinyStories SLM - Story Generation Examples")
    print("=" * 50)

    for i, prompt in enumerate(prompts, 1):
        print(f"\nExample {i}:")
        print(f"Prompt: {prompt}")
        print("-" * 30)
        story = generate_story(model, enc, prompt, max_tokens=150, temperature=0.8)
        print(story)
        print("=" * 50)