|
|
|
|
|
"""Example usage of the TinyStories SLM model.

Loads a small GPT checkpoint plus the GPT-2 tokenizer and prints
sample story completions for a handful of prompts.
"""
|
|
|
|
|
import torch |
|
|
import tiktoken |
|
|
from model import GPT, GPTConfig |
|
|
|
|
|
def load_model(checkpoint_path='pytorch_model.bin', device='cpu'):
    """Load the TinyStories GPT model and the GPT-2 BPE tokenizer.

    Args:
        checkpoint_path: Path to the saved state dict. Defaults to the
            previously hard-coded ``'pytorch_model.bin'`` so existing
            callers are unaffected.
        device: Device string passed to ``torch.load`` as ``map_location``
            (default ``'cpu'``, matching the original behavior).

    Returns:
        tuple: ``(model, enc)`` — the GPT model switched to eval mode and
        the tiktoken GPT-2 encoder.
    """
    enc = tiktoken.get_encoding("gpt2")

    # Architecture must match the configuration the checkpoint was
    # trained with; these values mirror the saved weights' shapes.
    config = GPTConfig(
        vocab_size=50257,  # GPT-2 BPE vocabulary size
        block_size=128,
        n_layer=6,
        n_head=6,
        n_embd=384,
        dropout=0.0,  # no dropout at inference time
        bias=True,
    )

    model = GPT(config)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()  # disable training-mode behavior (e.g. dropout)

    return model, enc
|
|
|
|
|
def generate_story(model, enc, prompt, max_tokens=200, temperature=1.0, top_k=None):
    """Generate a story continuation for *prompt* and return it as text.

    Args:
        model: GPT model exposing a ``generate`` method.
        enc: tiktoken encoder used to tokenize the prompt and decode output.
        prompt: Seed text for the story.
        max_tokens: Number of new tokens to sample.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Optional top-k filter forwarded to ``model.generate``.

    Returns:
        str: The decoded generation (prompt included).
    """
    # Tokenize the prompt and add a leading batch dimension of size 1.
    prompt_ids = enc.encode_ordinary(prompt)
    idx = torch.tensor(prompt_ids).unsqueeze(0)

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        output = model.generate(
            idx,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
        )

    # Drop the batch dimension and decode back to text.
    token_list = output.squeeze().tolist()
    return enc.decode(token_list)
|
|
|
|
|
if __name__ == "__main__":
    model, enc = load_model()

    # A handful of story openings to showcase the model.
    prompts = [
        "Once upon a time there was a pumpkin.",
        "A little girl went to the woods",
        "Once upon a time in India",
        "The magic cat could",
        "In a small village",
    ]

    print("TinyStories SLM - Story Generation Examples")
    print("=" * 50)

    # Generate and print one sample per prompt, separated by rulers.
    for example_no, prompt in enumerate(prompts, start=1):
        print(f"\nExample {example_no}:")
        print(f"Prompt: {prompt}")
        print("-" * 30)
        print(generate_story(model, enc, prompt, max_tokens=150, temperature=0.8))
        print("=" * 50)
|
|