nishantup committed on
Commit
d8a8efb
·
verified ·
1 Parent(s): f0017e2

Upload nanogpt_slm_tinystories_instruct_inference.py with huggingface_hub

Browse files
nanogpt_slm_tinystories_instruct_inference.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prepared by: Dr. Nishant Upadhyay
3
+
4
+ nanoGPT SLM TinyStories Instruct -- Standalone Inference
5
+ ==========================================================
6
+ 124M parameter instruction-tuned Small Language Model.
7
+ Pretrained on TinyStories (2.1M children's stories) -> SFT on 300K multi-source instructions.
8
+
9
+ Dataset: 300K instruction dataset (Alpaca + Dolly + UltraChat + OpenAssistant + FLAN)
10
+ Format: Unified Task / Question / Answer prompt format
11
+
12
+ Install: pip install torch tiktoken huggingface_hub
13
+ Run: python nanogpt_slm_tinystories_instruct_inference.py
14
+ Import: from nanogpt_slm_tinystories_instruct_inference import ask
15
+ """
16
+
17
+ import torch, torch.nn as nn, torch.nn.functional as F, math, tiktoken
18
+ from dataclasses import dataclass
19
+ from huggingface_hub import hf_hub_download
20
+
21
+ # ==============================================================
22
+ # ARCHITECTURE
23
+ # ==============================================================
24
+
25
class LayerNorm(nn.Module):
    """Layer normalization with an optional bias term.

    nn.LayerNorm ties weight and bias together; this variant allows a
    learnable scale while dropping the bias entirely when bias=False.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, x):
        # Normalize over the trailing `ndim` features with a fixed eps of 1e-5.
        normalized_shape = self.weight.shape
        return F.layer_norm(x, normalized_shape, self.weight, self.bias, 1e-5)
32
+
33
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention.

    Uses the fused F.scaled_dot_product_attention kernel when the installed
    PyTorch provides it (2.0+); otherwise falls back to an explicit
    mask-and-softmax implementation with a precomputed triangular mask.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Single fused projection producing query, key and value in one matmul.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection back into the residual stream.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head, self.n_embd = config.n_head, config.n_embd
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            # Lower-triangular causal mask for the manual path; a buffer so it
            # follows the module across devices but is not a trainable parameter.
            # NOTE: named "bias" for GPT-2 checkpoint compatibility, not an
            # actual bias term.
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                 .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch, sequence length, embedding width
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # Reshape each to (B, n_head, T, head_dim) for per-head attention.
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        if self.flash:
            # Fused kernel applies causal masking and attention dropout internally.
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True)
        else:
            # Manual path: scaled dot product, causal mask, softmax, dropout.
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1); att = self.attn_dropout(att); y = att @ v
        # Re-merge heads back to (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
61
+
62
class MLP(nn.Module):
    """Position-wise feed-forward block: expand 4x, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = self.c_fc(x)
        hidden = self.gelu(hidden)
        projected = self.c_proj(hidden)
        return self.dropout(projected)
71
+
72
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each behind a residual add."""

    def __init__(self, config):
        super().__init__()
        # Submodules created in the same order as before so seeded
        # initialization and state_dict keys are unchanged.
        self.ln1 = LayerNorm(config.n_embd, config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
80
+
81
@dataclass
class GPTConfig:
    """Model hyperparameters (GPT-2 small shape: 12 layers, 12 heads, 768-wide)."""
    block_size: int = 512    # maximum context length in tokens
    vocab_size: int = 50257  # GPT-2 BPE vocabulary size
    n_layer: int = 12        # transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # embedding / residual stream width
    dropout: float = 0.0     # disabled for inference
    bias: bool = True        # use bias terms in Linear/LayerNorm
86
+
87
class GPT(nn.Module):
    """GPT-2-style decoder-only transformer with tied input/output embeddings."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe=nn.Embedding(config.block_size, config.n_embd),  # learned positions
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the token embedding matrix doubles as the output projection.
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, idx, targets=None):
        _, seq_len = idx.size()
        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(positions)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is None:
            # Inference: project only the final position to save compute.
            return self.lm_head(x[:, [-1], :]), None
        logits = self.lm_head(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                               targets.view(-1), ignore_index=-1)
        return logits, loss
113
+
114
+ # ==============================================================
115
+ # GENERATION + PROMPT FORMATTING
116
+ # ==============================================================
117
+
118
def generate(model, idx, max_new_tokens, context_size, temperature=0.7, top_k=40, eos_id=None):
    """Autoregressively sample up to max_new_tokens tokens.

    Args:
        model: callable returning (logits, loss); only logits are used.
        idx: (1, T) LongTensor of prompt token ids (batch size 1 assumed,
             matching the eos check below).
        max_new_tokens: generation budget.
        context_size: model context window; the prompt is cropped to its tail.
        temperature: >0 samples from softmax(logits/temperature); 0 is greedy.
        top_k: if not None, restrict sampling to the k highest logits.
        eos_id: if not None, stop (without appending) when this id is sampled.

    Returns:
        (1, T') tensor of prompt plus generated tokens.
    """
    # Disable autograd for the whole loop, not just the forward pass.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -context_size:]  # crop to context window
            logits, _ = model(idx_cond)
            logits = logits[:, -1, :]  # next-token distribution only
            if top_k is not None:
                # Guard against top_k exceeding the vocabulary size.
                kth = min(top_k, logits.size(-1))
                topv, _ = torch.topk(logits, kth)
                # Mask everything below the k-th largest logit (no per-step
                # tensor allocation for -inf, unlike torch.where with a scalar tensor).
                logits = logits.masked_fill(logits < topv[:, [-1]], float('-inf'))
            if temperature > 0.0:
                probs = torch.softmax(logits / temperature, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            # Explicit eos check: the original `idx_next == eos_id` compared a
            # tensor to a (possibly None) int, relying on implicit truthiness.
            if eos_id is not None and idx_next.item() == eos_id:
                break
            idx = torch.cat((idx, idx_next), dim=1)
    return idx
136
+
137
def format_input(entry):
    """Render an instruction record as the model's Task/Question prompt.

    The optional 'input' field is appended as a Question section only when it
    contains non-whitespace text.
    """
    prompt = f"Task: {entry['instruction']}"
    extra = entry.get('input', '')
    if extra.strip():
        prompt += f"\n\nQuestion:\n{extra}"
    return prompt
142
+
143
def ask(instruction, input_text="", max_tokens=256, temperature=0.7, top_k=40):
    """Ask the instruction-tuned model and get a response.

    Relies on the module-level `tokenizer`, `model`, `device` and `config`
    initialized in the LOAD MODEL section below.
    """
    entry = {"instruction": instruction, "input": input_text}
    prompt = format_input(entry)
    token_ids = tokenizer.encode(prompt, allowed_special={'<|endoftext|>'})
    idx = torch.tensor(token_ids).unsqueeze(0).to(device)
    out = generate(model, idx, max_tokens, config.block_size, temperature, top_k, eos_id=50256)
    decoded = tokenizer.decode(out.squeeze(0).tolist())
    # Drop the echoed prompt, then strip the "Answer:" marker from the reply.
    completion = decoded[len(prompt):]
    return completion.replace("Answer:", "").strip()
150
+
151
# ==============================================================
# LOAD MODEL
# ==============================================================

# Prefer GPU when available; checkpoint tensors are mapped onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = GPTConfig()
tokenizer = tiktoken.get_encoding("gpt2")

# Fetch the checkpoint from the Hugging Face Hub (cached after first download).
weights_path = hf_hub_download(repo_id="nishantup/nanogpt-slm-tinystories-instruct",
                               filename="nanogpt_slm_tinystories_instruct.pth")
model = GPT(config)
# weights_only=True restricts torch.load to tensor data, preventing arbitrary
# pickle code execution from the downloaded file (a state_dict needs no more).
model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
model.to(device)
model.eval()

print(f"nanoGPT SLM TinyStories Instruct loaded: {sum(p.numel() for p in model.parameters()):,} params on {device}")
print(f"Config: {config.n_layer}L / {config.n_head}H / {config.n_embd}D / ctx={config.block_size}")
print(f"Format: Task / Question / Answer\n")
169
+
170
+ # ==============================================================
171
+ # EXAMPLES
172
+ # ==============================================================
173
+
174
if __name__ == "__main__":
    # Smoke-test prompts: open questions, summarization, listing and generation.
    demo_prompts = [
        ("What is the capital of France?", ""),
        ("Explain gravity in simple terms.", ""),
        ("Summarize the following text.",
         "Machine learning enables systems to learn from data rather than being explicitly programmed."),
        ("List three benefits of reading books.", ""),
        ("Write a short poem about the stars.", ""),
    ]

    for task, context in demo_prompts:
        answer = ask(task, context)
        print(f"Instruction: {task}")
        if context:
            print(f"Input: {context[:80]}...")
        print(f"Response: {answer}")
        print(f"{'-' * 60}\n")