abhilash88 committed on
Commit
409fd5e
·
verified ·
1 Parent(s): e83d515

Upload example_usage.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. example_usage.py +69 -0
example_usage.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Example usage of TinyStories SLM model
4
+ """
5
+
6
+ import torch
7
+ import tiktoken
8
+ from model import GPT, GPTConfig
9
+
10
def load_model():
    """Load the trained TinyStories GPT model and the GPT-2 BPE tokenizer.

    Returns:
        tuple: ``(model, enc)`` where ``model`` is a ``GPT`` instance in
        eval mode with weights loaded from ``pytorch_model.bin``, and
        ``enc`` is the tiktoken ``"gpt2"`` encoding used to encode prompts
        and decode generated token ids.
    """
    # GPT-2 byte-pair tokenizer; matches vocab_size=50257 in the config below.
    enc = tiktoken.get_encoding("gpt2")

    # Model configuration — must match the architecture the checkpoint
    # was trained with, or load_state_dict will fail.
    config = GPTConfig(
        vocab_size=50257,
        block_size=128,
        n_layer=6,
        n_head=6,
        n_embd=384,
        dropout=0.0,  # Set to 0 for inference
        bias=True,
    )

    # Load model weights onto CPU. weights_only=True restricts torch.load
    # (which uses pickle under the hood) to plain tensors/containers, so a
    # tampered checkpoint cannot execute arbitrary code on load.
    model = GPT(config)
    state_dict = torch.load(
        'pytorch_model.bin', map_location='cpu', weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()  # disable training-mode behavior (e.g. dropout)

    return model, enc
32
+
33
def generate_story(model, enc, prompt, max_tokens=200, temperature=1.0, top_k=None):
    """Continue *prompt* with up to *max_tokens* sampled tokens; return the full text."""
    # Encode the prompt and add a leading batch dimension -> shape (1, T).
    prompt_ids = enc.encode_ordinary(prompt)
    context = torch.tensor(prompt_ids).unsqueeze(0)

    # Sampling needs no gradients, so keep autograd bookkeeping off.
    with torch.no_grad():
        output_ids = model.generate(
            context,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
        )

    # Drop the batch dimension and decode the token ids back to a string.
    return enc.decode(output_ids.squeeze().tolist())
46
+
47
+ if __name__ == "__main__":
48
+ # Load model
49
+ model, enc = load_model()
50
+
51
+ # Example prompts
52
+ prompts = [
53
+ "Once upon a time there was a pumpkin.",
54
+ "A little girl went to the woods",
55
+ "Once upon a time in India",
56
+ "The magic cat could",
57
+ "In a small village"
58
+ ]
59
+
60
+ print("TinyStories SLM - Story Generation Examples")
61
+ print("=" * 50)
62
+
63
+ for i, prompt in enumerate(prompts, 1):
64
+ print(f"\nExample {i}:")
65
+ print(f"Prompt: {prompt}")
66
+ print("-" * 30)
67
+ story = generate_story(model, enc, prompt, max_tokens=150, temperature=0.8)
68
+ print(story)
69
+ print("=" * 50)