yat343 commited on
Commit
6a1cd42
·
verified ·
1 Parent(s): 3229f14

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +255 -0
model.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ nano GPT: A tiny GPT model built from scratch in pure PyTorch.
3
+
4
+ This is a step-by-step tutorial implementation following Andrej Karpathy's
5
+ build-nanogpt approach. Every piece is explicit and commented.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ from dataclasses import dataclass
12
+
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Step 1: Configuration
16
+ # ---------------------------------------------------------------------------
17
+ # We define all hyperparameters in a single dataclass so they are easy to
18
+ # tweak without hunting through the code.
19
+
20
+ @dataclass
21
+ class GPTConfig:
22
+ block_size: int = 256 # maximum sequence length (context length)
23
+ vocab_size: int = 65 # number of unique characters in our dataset
24
+ n_layer: int = 4 # number of transformer blocks
25
+ n_head: int = 4 # number of attention heads per block
26
+ n_embd: int = 256 # embedding dimension (hidden size)
27
+ dropout: float = 0.0 # dropout probability (0 for small overfit-prone runs)
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Step 2: Causal Self-Attention
32
+ # ---------------------------------------------------------------------------
33
+ # This is the heart of the transformer. For each token we compute three
34
+ # vectors: Query, Key, and Value.
35
+ #
36
+ # Query: "What am I looking for?"
37
+ # Key: "What do I contain?"
38
+ # Value: "What information do I have?"
39
+ #
40
+ # We then compute attention scores = Q @ K.T, mask future tokens so the
41
+ # model cannot "cheat" by looking ahead, and take a weighted sum of Values.
42
+
43
+ class CausalSelfAttention(nn.Module):
44
+ def __init__(self, config: GPTConfig):
45
+ super().__init__()
46
+ assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
47
+
48
+ # One linear layer projects input into Q, K, V concatenated together.
49
+ # Output shape: (B, T, 3 * n_embd)
50
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
51
+
52
+ # Output projection back to n_embd
53
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
54
+
55
+ self.n_head = config.n_head
56
+ self.n_embd = config.n_embd
57
+ self.dropout = config.dropout
58
+
59
+ # Register a causal mask (lower-triangular) so we never attend to future tokens.
60
+ # We do this once at init instead of recomputing every forward pass.
61
+ self.register_buffer(
62
+ "bias",
63
+ torch.tril(torch.ones(config.block_size, config.block_size))
64
+ .view(1, 1, config.block_size, config.block_size)
65
+ )
66
+
67
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
68
+ B, T, C = x.size() # batch, sequence length, embedding dim
69
+
70
+ # 1. Compute Q, K, V
71
+ qkv = self.c_attn(x) # (B, T, 3*C)
72
+ q, k, v = qkv.split(self.n_embd, dim=2) # each (B, T, C)
73
+
74
+ # 2. Reshape into (B, n_head, T, head_size) for multi-head attention
75
+ head_size = C // self.n_head
76
+ q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
77
+ k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
78
+ v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
79
+
80
+ # 3. Compute attention scores: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T)
81
+ # We scale by 1/sqrt(head_size) to keep gradients stable.
82
+ att = (q @ k.transpose(-2, -1)) * (1.0 / (head_size ** 0.5))
83
+
84
+ # 4. Apply causal mask: set future positions to -inf so softmax gives 0
85
+ att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
86
+
87
+ # 5. Softmax to get probability distribution over past tokens
88
+ att = F.softmax(att, dim=-1)
89
+
90
+ # 6. Weighted sum of values: (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
91
+ y = att @ v
92
+
93
+ # 7. Concatenate heads back together: (B, nh, T, hs) -> (B, T, nh*hs) = (B, T, C)
94
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
95
+
96
+ # 8. Final output projection
97
+ y = self.c_proj(y)
98
+ return y
99
+
100
+
101
+ # ---------------------------------------------------------------------------
102
+ # Step 3: Feed-Forward Network (MLP)
103
+ # ---------------------------------------------------------------------------
104
+ # After attention, each token gets its own private "thinking" step through
105
+ # a simple two-layer MLP with a GELU non-linearity.
106
+
107
+ class MLP(nn.Module):
108
+ def __init__(self, config: GPTConfig):
109
+ super().__init__()
110
+ # Expand by 4x (common in transformers) then project back down
111
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
112
+ self.gelu = nn.GELU()
113
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
114
+ self.dropout = nn.Dropout(config.dropout)
115
+
116
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
117
+ x = self.c_fc(x)
118
+ x = self.gelu(x)
119
+ x = self.c_proj(x)
120
+ x = self.dropout(x)
121
+ return x
122
+
123
+
124
+ # ---------------------------------------------------------------------------
125
+ # Step 4: Transformer Block
126
+ # ---------------------------------------------------------------------------
127
+ # A block = Attention -> Add & Norm -> MLP -> Add & Norm
128
+ # We use **pre-norm**: normalize BEFORE applying attention/MLP.
129
+ # This is what modern models (GPT-2, GPT-3, Llama, etc.) use.
130
+
131
+ class Block(nn.Module):
132
+ def __init__(self, config: GPTConfig):
133
+ super().__init__()
134
+ self.ln_1 = nn.LayerNorm(config.n_embd)
135
+ self.attn = CausalSelfAttention(config)
136
+ self.ln_2 = nn.LayerNorm(config.n_embd)
137
+ self.mlp = MLP(config)
138
+
139
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
140
+ # Pre-norm residual connections
141
+ x = x + self.attn(self.ln_1(x)) # attention branch
142
+ x = x + self.mlp(self.ln_2(x)) # MLP branch
143
+ return x
144
+
145
+
146
+ # ---------------------------------------------------------------------------
147
+ # Step 5: Full GPT Model
148
+ # ---------------------------------------------------------------------------
149
+ # Putting it all together:
150
+ # 1. Token embedding table (wte): maps character index -> vector
151
+ # 2. Position embedding table (wpe): maps position index -> vector
152
+ # 3. Stack of N transformer blocks
153
+ # 4. Final layer norm
154
+ # 5. Language model head: projects back to vocab_size logits
155
+
156
+ class GPT(nn.Module):
157
+ def __init__(self, config: GPTConfig):
158
+ super().__init__()
159
+ self.config = config
160
+
161
+ self.transformer = nn.ModuleDict({
162
+ "wte": nn.Embedding(config.vocab_size, config.n_embd), # token embeddings
163
+ "wpe": nn.Embedding(config.block_size, config.n_embd), # position embeddings
164
+ "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
165
+ "ln_f": nn.LayerNorm(config.n_embd),
166
+ })
167
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
168
+
169
+ # Weight tying: share the token embedding weights with the output projection.
170
+ # This saves parameters and often improves training.
171
+ self.transformer.wte.weight = self.lm_head.weight
172
+
173
+ # Initialize weights
174
+ self.apply(self._init_weights)
175
+
176
+ def _init_weights(self, module):
177
+ if isinstance(module, nn.Linear):
178
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
179
+ if module.bias is not None:
180
+ torch.nn.init.zeros_(module.bias)
181
+ elif isinstance(module, nn.Embedding):
182
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
183
+
184
+ def forward(
185
+ self,
186
+ idx: torch.Tensor,
187
+ targets: torch.Tensor | None = None,
188
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
189
+ """
190
+ idx: (B, T) integer token indices
191
+ targets:(B, T) integer targets for next-token prediction (optional)
192
+ returns: logits (B, T, vocab_size), loss (scalar or None)
193
+ """
194
+ B, T = idx.size()
195
+ assert T <= self.config.block_size, f"Sequence length {T} exceeds block_size {self.config.block_size}"
196
+
197
+ # 1. Token + position embeddings
198
+ pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # (T,)
199
+ tok_emb = self.transformer.wte(idx) # (B, T, C)
200
+ pos_emb = self.transformer.wpe(pos) # (T, C)
201
+ x = tok_emb + pos_emb # (B, T, C)
202
+
203
+ # 2. Pass through transformer blocks
204
+ for block in self.transformer.h:
205
+ x = block(x)
206
+
207
+ # 3. Final layer norm
208
+ x = self.transformer.ln_f(x)
209
+
210
+ # 4. Project to vocabulary logits
211
+ logits = self.lm_head(x) # (B, T, vocab_size)
212
+
213
+ # 5. Compute cross-entropy loss if targets are provided
214
+ loss = None
215
+ if targets is not None:
216
+ loss = F.cross_entropy(
217
+ logits.view(-1, logits.size(-1)),
218
+ targets.view(-1),
219
+ ignore_index=-1,
220
+ )
221
+
222
+ return logits, loss
223
+
224
+ def generate(
225
+ self,
226
+ idx: torch.Tensor,
227
+ max_new_tokens: int,
228
+ temperature: float = 1.0,
229
+ top_k: int | None = None,
230
+ ) -> torch.Tensor:
231
+ """
232
+ Generate new tokens autoregressively.
233
+ idx: (B, T) starting token indices
234
+ """
235
+ for _ in range(max_new_tokens):
236
+ # Crop to block_size so we never exceed context length
237
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
238
+
239
+ # Forward pass
240
+ logits, _ = self(idx_cond)
241
+ logits = logits[:, -1, :] # take logits for the last token only: (B, vocab_size)
242
+
243
+ # Optional top-k sampling
244
+ if top_k is not None:
245
+ v, _ = torch.topk(logits, top_k, dim=-1)
246
+ logits[logits < v[:, [-1]]] = float("-inf")
247
+
248
+ # Apply temperature and softmax
249
+ probs = F.softmax(logits / temperature, dim=-1)
250
+
251
+ # Sample from the distribution
252
+ idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
253
+ idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
254
+
255
+ return idx