algorythmtechnologies committed on
Commit
f89b6c1
·
verified ·
1 Parent(s): c866f18

Update supernova/model.py

Browse files
Files changed (1) hide show
  1. supernova/model.py +580 -134
supernova/model.py CHANGED
@@ -1,134 +1,580 @@
1
- import math
2
- from dataclasses import dataclass
3
- from typing import Optional, Tuple
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
-
9
- from .config import ModelConfig
10
-
11
-
12
- class MultiHeadSelfAttention(nn.Module):
13
- def __init__(self, d_model: int, n_heads: int, dropout: float):
14
- super().__init__()
15
- assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
16
- self.d_model = d_model
17
- self.n_heads = n_heads
18
- self.d_head = d_model // n_heads
19
- self.qkv = nn.Linear(d_model, 3 * d_model, bias=True)
20
- self.out_proj = nn.Linear(d_model, d_model, bias=True)
21
- self.attn_dropout = nn.Dropout(dropout)
22
- self.resid_dropout = nn.Dropout(dropout)
23
-
24
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
25
- B, T, C = x.size()
26
- qkv = self.qkv(x) # (B, T, 3*C)
27
- q, k, v = qkv.split(self.d_model, dim=-1)
28
- # reshape to (B, n_heads, T, d_head)
29
- q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
30
- k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
31
- v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
32
-
33
- # scaled dot-product attention with causal mask
34
- att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
35
- causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
36
- att = att.masked_fill(~causal, float("-inf"))
37
- if attn_mask is not None:
38
- # attn_mask: (B, 1, 1, T) with 0 for keep, -inf for mask
39
- att = att + attn_mask
40
- att = F.softmax(att, dim=-1)
41
- att = self.attn_dropout(att)
42
- y = att @ v # (B, n_heads, T, d_head)
43
- y = y.transpose(1, 2).contiguous().view(B, T, C)
44
- y = self.out_proj(y)
45
- y = self.resid_dropout(y)
46
- return y
47
-
48
-
49
- class TransformerBlock(nn.Module):
50
- def __init__(self, d_model: int, n_heads: int, mlp_ratio: int, dropout: float):
51
- super().__init__()
52
- self.ln1 = nn.LayerNorm(d_model)
53
- self.attn = MultiHeadSelfAttention(d_model, n_heads, dropout)
54
- self.ln2 = nn.LayerNorm(d_model)
55
- self.mlp = nn.Sequential(
56
- nn.Linear(d_model, mlp_ratio * d_model, bias=True),
57
- nn.GELU(),
58
- nn.Linear(mlp_ratio * d_model, d_model, bias=True),
59
- nn.Dropout(dropout),
60
- )
61
-
62
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
63
- x = x + self.attn(self.ln1(x), attn_mask)
64
- x = x + self.mlp(self.ln2(x))
65
- return x
66
-
67
-
68
- class SupernovaModel(nn.Module):
69
- def __init__(self, cfg: ModelConfig):
70
- super().__init__()
71
- self.cfg = cfg
72
- d = cfg.d_model
73
- V = cfg.vocab_size
74
- P = cfg.n_positions if cfg.use_positional_embedding else 0
75
-
76
- self.tok_emb = nn.Embedding(V, d)
77
- self.pos_emb = nn.Embedding(P, d) if cfg.use_positional_embedding else None
78
- self.drop = nn.Dropout(cfg.dropout)
79
- self.blocks = nn.ModuleList([
80
- TransformerBlock(d, cfg.n_heads, cfg.mlp_ratio, cfg.dropout) for _ in range(cfg.n_layers)
81
- ])
82
- self.ln_f = nn.LayerNorm(d) if cfg.final_layer_norm else nn.Identity()
83
- # No separate LM head weight; logits computed via tied embedding matrix
84
- # No LM head bias to preserve exact parameter count formula
85
-
86
- self.apply(self._init_weights)
87
-
88
- def _init_weights(self, module):
89
- if isinstance(module, nn.Linear):
90
- nn.init.normal_(module.weight, mean=0.0, std=0.02)
91
- if module.bias is not None:
92
- nn.init.zeros_(module.bias)
93
- elif isinstance(module, nn.Embedding):
94
- nn.init.normal_(module.weight, mean=0.0, std=0.02)
95
- elif isinstance(module, nn.LayerNorm):
96
- nn.init.ones_(module.weight)
97
- nn.init.zeros_(module.bias)
98
-
99
- def forward(self, input_ids: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
100
- B, T = input_ids.shape
101
- device = input_ids.device
102
- if self.pos_emb is not None:
103
- assert T <= self.cfg.n_positions, f"Sequence length {T} exceeds n_positions {self.cfg.n_positions}"
104
- tok = self.tok_emb(input_ids) # (B, T, d)
105
- if self.pos_emb is not None:
106
- pos = torch.arange(0, T, device=device)
107
- pos = self.pos_emb(pos)[None, :, :] # (1, T, d)
108
- x = tok + pos
109
- else:
110
- x = tok
111
- x = self.drop(x)
112
-
113
- attn_mask = None # causal mask applied inside attention; no padding by default
114
- for block in self.blocks:
115
- x = block(x, attn_mask)
116
- x = self.ln_f(x)
117
-
118
- # Tied output: logits = x @ W_emb^T
119
- logits = x @ self.tok_emb.weight.T # (B, T, V)
120
-
121
- loss = None
122
- if targets is not None:
123
- # shift for next-token prediction
124
- logits_ = logits[:, :-1, :].contiguous()
125
- targets_ = targets[:, 1:].contiguous()
126
- loss = F.cross_entropy(
127
- logits_.view(-1, logits_.size(-1)),
128
- targets_.view(-1),
129
- ignore_index=-100,
130
- )
131
- return logits, loss
132
-
133
- def num_parameters(self) -> int:
134
- return sum(p.numel() for p in self.parameters())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Optional, Tuple, List
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .config import ModelConfig
10
+
11
+
12
+ class RotaryEmbedding(nn.Module):
13
+ """Rotary Position Embedding (RoPE) - used in LLaMA, GPT-NeoX"""
14
+ def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0):
15
+ super().__init__()
16
+ self.dim = dim
17
+ self.max_seq_len = max_seq_len
18
+ self.base = base
19
+
20
+ # Precompute frequencies
21
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
22
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
23
+
24
+ # Build cache for efficiency
25
+ self._build_cache(max_seq_len)
26
+
27
+ def _build_cache(self, seq_len: int):
28
+ """Precompute cos/sin for given sequence length"""
29
+ t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
30
+ freqs = torch.outer(t, self.inv_freq)
31
+ emb = torch.cat((freqs, freqs), dim=-1)
32
+ self.register_buffer("cos_cached", emb.cos(), persistent=False)
33
+ self.register_buffer("sin_cached", emb.sin(), persistent=False)
34
+ self.cached_seq_len = seq_len
35
+
36
+ def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
37
+ """Return cos and sin for position embeddings"""
38
+ if seq_len > self.cached_seq_len:
39
+ self._build_cache(seq_len)
40
+ return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
41
+
42
+
43
+ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
44
+ """
45
+ Apply rotary position embedding to queries and keys.
46
+
47
+ Args:
48
+ q: (B, n_heads, T, d_head)
49
+ k: (B, n_heads, T, d_head)
50
+ cos: (T, d_head)
51
+ sin: (T, d_head)
52
+ """
53
+ # Reshape for broadcasting
54
+ cos = cos.unsqueeze(0).unsqueeze(0) # (1, 1, T, d_head)
55
+ sin = sin.unsqueeze(0).unsqueeze(0)
56
+
57
+ # Split into first and second half
58
+ q_half1, q_half2 = q.chunk(2, dim=-1)
59
+ k_half1, k_half2 = k.chunk(2, dim=-1)
60
+
61
+ # Apply rotation
62
+ q_rot = torch.cat([
63
+ q_half1 * cos - q_half2 * sin,
64
+ q_half2 * cos + q_half1 * sin
65
+ ], dim=-1)
66
+
67
+ k_rot = torch.cat([
68
+ k_half1 * cos - k_half2 * sin,
69
+ k_half2 * cos + k_half1 * sin
70
+ ], dim=-1)
71
+
72
+ return q_rot, k_rot
73
+
74
+
75
+ class MultiHeadSelfAttention(nn.Module):
76
+ def __init__(
77
+ self,
78
+ d_model: int,
79
+ n_heads: int,
80
+ dropout: float,
81
+ max_seq_len: int = 8192,
82
+ use_rope: bool = True,
83
+ use_flash: bool = True
84
+ ):
85
+ super().__init__()
86
+ assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
87
+
88
+ self.d_model = d_model
89
+ self.n_heads = n_heads
90
+ self.d_head = d_model // n_heads
91
+ self.use_rope = use_rope
92
+ self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')
93
+
94
+ # QKV projection
95
+ self.qkv = nn.Linear(d_model, 3 * d_model, bias=True)
96
+ self.out_proj = nn.Linear(d_model, d_model, bias=True)
97
+
98
+ # Dropout
99
+ self.attn_dropout = nn.Dropout(dropout)
100
+ self.resid_dropout = nn.Dropout(dropout)
101
+
102
+ # Rotary embeddings
103
+ if use_rope:
104
+ self.rotary_emb = RotaryEmbedding(self.d_head, max_seq_len)
105
+
106
+ # Causal mask (fallback for non-flash attention)
107
+ if not self.use_flash:
108
+ self.register_buffer(
109
+ "causal_mask",
110
+ torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)),
111
+ persistent=False
112
+ )
113
+
114
+ def forward(
115
+ self,
116
+ x: torch.Tensor,
117
+ attn_mask: Optional[torch.Tensor] = None,
118
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
119
+ use_cache: bool = False
120
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
121
+ B, T, C = x.size()
122
+
123
+ # Compute QKV
124
+ qkv = self.qkv(x) # (B, T, 3*C)
125
+ q, k, v = qkv.split(self.d_model, dim=-1)
126
+
127
+ # Reshape to (B, n_heads, T, d_head)
128
+ q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
129
+ k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
130
+ v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
131
+
132
+ # Apply rotary embeddings
133
+ if self.use_rope:
134
+ cos, sin = self.rotary_emb(T)
135
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
136
+
137
+ # KV cache for inference
138
+ if past_kv is not None:
139
+ past_k, past_v = past_kv
140
+ k = torch.cat([past_k, k], dim=2)
141
+ v = torch.cat([past_v, v], dim=2)
142
+
143
+ present_kv = (k, v) if use_cache else None
144
+
145
+ # Compute attention
146
+ if self.use_flash:
147
+ # Use PyTorch's optimized Flash Attention
148
+ y = F.scaled_dot_product_attention(
149
+ q, k, v,
150
+ attn_mask=None,
151
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
152
+ is_causal=True
153
+ )
154
+ else:
155
+ # Fallback: manual attention computation
156
+ att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
157
+
158
+ # Apply causal mask
159
+ T_q, T_k = q.size(2), k.size(2)
160
+ causal = self.causal_mask[:T_q, :T_k]
161
+ att = att.masked_fill(~causal, float("-inf"))
162
+
163
+ # Apply additional mask if provided
164
+ if attn_mask is not None:
165
+ att = att + attn_mask
166
+
167
+ att = F.softmax(att, dim=-1)
168
+ att = self.attn_dropout(att)
169
+ y = att @ v # (B, n_heads, T, d_head)
170
+
171
+ # Reshape and project output
172
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
173
+ y = self.out_proj(y)
174
+ y = self.resid_dropout(y)
175
+
176
+ return y, present_kv
177
+
178
+
179
+ class TransformerBlock(nn.Module):
180
+ def __init__(
181
+ self,
182
+ d_model: int,
183
+ n_heads: int,
184
+ mlp_ratio: int,
185
+ dropout: float,
186
+ max_seq_len: int = 8192,
187
+ use_rope: bool = True,
188
+ use_flash: bool = True
189
+ ):
190
+ super().__init__()
191
+ self.ln1 = nn.LayerNorm(d_model)
192
+ self.attn = MultiHeadSelfAttention(
193
+ d_model, n_heads, dropout, max_seq_len, use_rope, use_flash
194
+ )
195
+ self.ln2 = nn.LayerNorm(d_model)
196
+
197
+ # MLP with GELU activation (SwiGLU would be even better)
198
+ self.mlp = nn.Sequential(
199
+ nn.Linear(d_model, mlp_ratio * d_model, bias=True),
200
+ nn.GELU(),
201
+ nn.Linear(mlp_ratio * d_model, d_model, bias=True),
202
+ nn.Dropout(dropout),
203
+ )
204
+
205
+ def forward(
206
+ self,
207
+ x: torch.Tensor,
208
+ attn_mask: Optional[torch.Tensor] = None,
209
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
210
+ use_cache: bool = False
211
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
212
+ # Pre-LayerNorm architecture
213
+ attn_out, present_kv = self.attn(self.ln1(x), attn_mask, past_kv, use_cache)
214
+ x = x + attn_out
215
+ x = x + self.mlp(self.ln2(x))
216
+ return x, present_kv
217
+
218
+
219
+ class SupernovaModel(nn.Module):
220
+ """
221
+ Optimized Transformer Language Model with:
222
+ - Flash Attention support
223
+ - Rotary Position Embeddings (RoPE)
224
+ - KV caching for efficient generation
225
+ - Gradient checkpointing support
226
+ - Mixed precision training compatibility
227
+ """
228
+
229
+ def __init__(self, cfg: ModelConfig):
230
+ super().__init__()
231
+ self.cfg = cfg
232
+ d = cfg.d_model
233
+ V = cfg.vocab_size
234
+
235
+ # Token embeddings
236
+ self.tok_emb = nn.Embedding(V, d)
237
+
238
+ # Optional learned positional embeddings (if not using RoPE)
239
+ use_rope = getattr(cfg, 'use_rope', True)
240
+ if not use_rope and cfg.use_positional_embedding:
241
+ self.pos_emb = nn.Embedding(cfg.n_positions, d)
242
+ else:
243
+ self.pos_emb = None
244
+
245
+ # Dropout
246
+ self.drop = nn.Dropout(cfg.dropout)
247
+
248
+ # Transformer blocks
249
+ self.blocks = nn.ModuleList([
250
+ TransformerBlock(
251
+ d,
252
+ cfg.n_heads,
253
+ cfg.mlp_ratio,
254
+ cfg.dropout,
255
+ max_seq_len=getattr(cfg, 'n_positions', 8192),
256
+ use_rope=use_rope,
257
+ use_flash=getattr(cfg, 'use_flash', True)
258
+ )
259
+ for _ in range(cfg.n_layers)
260
+ ])
261
+
262
+ # Final layer norm
263
+ self.ln_f = nn.LayerNorm(d) if cfg.final_layer_norm else nn.Identity()
264
+
265
+ # Gradient checkpointing flag (set during training)
266
+ self.gradient_checkpointing = False
267
+
268
+ # Initialize weights
269
+ self.apply(self._init_weights)
270
+
271
+ def _init_weights(self, module):
272
+ """Initialize weights following GPT-2/3 initialization scheme"""
273
+ if isinstance(module, nn.Linear):
274
+ # Use normal distribution with std=0.02
275
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
276
+ if module.bias is not None:
277
+ nn.init.zeros_(module.bias)
278
+ elif isinstance(module, nn.Embedding):
279
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
280
+ elif isinstance(module, nn.LayerNorm):
281
+ nn.init.ones_(module.weight)
282
+ nn.init.zeros_(module.bias)
283
+
284
+ def forward(
285
+ self,
286
+ input_ids: torch.Tensor,
287
+ targets: Optional[torch.Tensor] = None,
288
+ past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
289
+ use_cache: bool = False
290
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
291
+ """
292
+ Forward pass with optional KV caching for efficient generation.
293
+
294
+ Args:
295
+ input_ids: (B, T) input token indices
296
+ targets: (B, T) target token indices for loss computation
297
+ past_key_values: List of (k, v) tuples for each layer (for caching)
298
+ use_cache: Whether to return present key values
299
+
300
+ Returns:
301
+ logits: (B, T, V) output logits
302
+ loss: Optional loss value
303
+ present_key_values: Optional list of present (k, v) for caching
304
+ """
305
+ B, T = input_ids.shape
306
+ device = input_ids.device
307
+
308
+ # Compute embeddings
309
+ tok = self.tok_emb(input_ids) # (B, T, d)
310
+
311
+ # Add positional embeddings if using learned positions (not RoPE)
312
+ if self.pos_emb is not None:
313
+ if past_key_values is not None:
314
+ # During generation with cache, only process new position
315
+ pos_offset = past_key_values[0][0].size(2)
316
+ pos = torch.arange(pos_offset, pos_offset + T, device=device)
317
+ else:
318
+ pos = torch.arange(0, T, device=device)
319
+
320
+ assert pos.max() < self.cfg.n_positions, f"Position {pos.max()} exceeds n_positions {self.cfg.n_positions}"
321
+ pos_emb = self.pos_emb(pos)[None, :, :] # (1, T, d)
322
+ x = tok + pos_emb
323
+ else:
324
+ x = tok
325
+
326
+ x = self.drop(x)
327
+
328
+ # Pass through transformer blocks
329
+ present_key_values = [] if use_cache else None
330
+ for i, block in enumerate(self.blocks):
331
+ past_kv = past_key_values[i] if past_key_values is not None else None
332
+
333
+ if self.gradient_checkpointing and self.training:
334
+ # Use gradient checkpointing to save memory
335
+ def create_custom_forward(module):
336
+ def custom_forward(*inputs):
337
+ return module(*inputs, use_cache=False)
338
+ return custom_forward
339
+
340
+ x, _ = torch.utils.checkpoint.checkpoint(
341
+ create_custom_forward(block),
342
+ x,
343
+ None, # attn_mask
344
+ past_kv,
345
+ use_reentrant=False
346
+ )
347
+ if use_cache:
348
+ present_key_values.append(None) # Placeholder
349
+ else:
350
+ x, present_kv = block(x, attn_mask=None, past_kv=past_kv, use_cache=use_cache)
351
+ if use_cache:
352
+ present_key_values.append(present_kv)
353
+
354
+ x = self.ln_f(x)
355
+
356
+ # Compute logits via tied embeddings
357
+ logits = x @ self.tok_emb.weight.T # (B, T, V)
358
+
359
+ # Compute loss if targets provided
360
+ loss = None
361
+ if targets is not None:
362
+ # Shift for next-token prediction
363
+ logits_ = logits[:, :-1, :].contiguous()
364
+ targets_ = targets[:, 1:].contiguous()
365
+ loss = F.cross_entropy(
366
+ logits_.view(-1, logits_.size(-1)),
367
+ targets_.view(-1),
368
+ ignore_index=-100,
369
+ )
370
+
371
+ return logits, loss, present_key_values
372
+
373
+ @torch.no_grad()
374
+ def generate(
375
+ self,
376
+ idx: torch.Tensor,
377
+ max_new_tokens: int,
378
+ temperature: float = 1.0,
379
+ top_k: Optional[int] = None,
380
+ top_p: Optional[float] = None,
381
+ repetition_penalty: float = 1.0,
382
+ use_cache: bool = True
383
+ ) -> torch.Tensor:
384
+ """
385
+ Generate text autoregressively with various sampling strategies.
386
+
387
+ Args:
388
+ idx: (B, T) input token indices
389
+ max_new_tokens: Number of tokens to generate
390
+ temperature: Sampling temperature (higher = more random)
391
+ top_k: Keep only top k logits (None = disabled)
392
+ top_p: Nucleus sampling threshold (None = disabled)
393
+ repetition_penalty: Penalty for repeated tokens (1.0 = no penalty)
394
+ use_cache: Use KV caching for faster generation
395
+
396
+ Returns:
397
+ (B, T + max_new_tokens) generated token indices
398
+ """
399
+ past_key_values = None
400
+
401
+ for _ in range(max_new_tokens):
402
+ # Crop context if needed (only when not using cache)
403
+ if not use_cache or past_key_values is None:
404
+ max_len = getattr(self.cfg, 'n_positions', 8192)
405
+ idx_cond = idx if idx.size(1) <= max_len else idx[:, -max_len:]
406
+ else:
407
+ # With cache, only process the last token
408
+ idx_cond = idx[:, -1:]
409
+
410
+ # Forward pass
411
+ logits, _, past_key_values = self(
412
+ idx_cond,
413
+ use_cache=use_cache
414
+ )
415
+ logits = logits[:, -1, :] # (B, V)
416
+
417
+ # Apply repetition penalty
418
+ if repetition_penalty != 1.0:
419
+ for i in range(idx.size(0)):
420
+ for token_id in set(idx[i].tolist()):
421
+ logits[i, token_id] /= repetition_penalty
422
+
423
+ # Apply temperature
424
+ logits = logits / temperature
425
+
426
+ # Top-k filtering
427
+ if top_k is not None:
428
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
429
+ logits[logits < v[:, [-1]]] = float('-inf')
430
+
431
+ # Nucleus (top-p) sampling
432
+ if top_p is not None:
433
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
434
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
435
+
436
+ # Remove tokens with cumulative probability above threshold
437
+ sorted_indices_to_remove = cumulative_probs > top_p
438
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
439
+ sorted_indices_to_remove[:, 0] = 0
440
+
441
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
442
+ logits[indices_to_remove] = float('-inf')
443
+
444
+ # Sample next token
445
+ probs = F.softmax(logits, dim=-1)
446
+ idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
447
+
448
+ # Append to sequence
449
+ idx = torch.cat([idx, idx_next], dim=1)
450
+
451
+ return idx
452
+
453
+ def num_parameters(self, only_trainable: bool = True) -> int:
454
+ """
455
+ Count model parameters.
456
+
457
+ Args:
458
+ only_trainable: If True, count only trainable parameters
459
+
460
+ Returns:
461
+ Total number of parameters
462
+ """
463
+ if only_trainable:
464
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
465
+ return sum(p.numel() for p in self.parameters())
466
+
467
+ def parameter_breakdown(self) -> dict:
468
+ """
469
+ Get detailed parameter count by component.
470
+
471
+ Returns:
472
+ Dictionary with parameter counts for each component
473
+ """
474
+ breakdown = {
475
+ "token_embeddings": sum(p.numel() for p in self.tok_emb.parameters()),
476
+ "positional_embeddings": sum(p.numel() for p in self.pos_emb.parameters()) if self.pos_emb else 0,
477
+ "attention": sum(
478
+ p.numel()
479
+ for block in self.blocks
480
+ for p in block.attn.parameters()
481
+ ),
482
+ "mlp": sum(
483
+ p.numel()
484
+ for block in self.blocks
485
+ for p in block.mlp.parameters()
486
+ ),
487
+ "layer_norm": sum(
488
+ p.numel()
489
+ for block in self.blocks
490
+ for p in [block.ln1, block.ln2]
491
+ ) + (sum(p.numel() for p in self.ln_f.parameters()) if self.cfg.final_layer_norm else 0),
492
+ }
493
+ breakdown["total"] = sum(breakdown.values())
494
+ breakdown["total_trainable"] = self.num_parameters(only_trainable=True)
495
+
496
+ return breakdown
497
+
498
+ def estimate_mfu(self, fwdbwd_per_iter: int, dt: float) -> float:
499
+ """
500
+ Estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS.
501
+
502
+ Args:
503
+ fwdbwd_per_iter: Number of forward-backward passes per iteration
504
+ dt: Time taken for iteration (seconds)
505
+
506
+ Returns:
507
+ MFU as a percentage (0-100)
508
+ """
509
+ N = self.num_parameters()
510
+ cfg = self.cfg
511
+ L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.d_model // cfg.n_heads, cfg.n_positions
512
+
513
+ # Estimate FLOPs per token (forward pass only)
514
+ # Approximation: 6N + 12LHQ*T (attention dominates)
515
+ flops_per_token = 6 * N + 12 * L * H * Q * T
516
+ flops_per_fwdbwd = flops_per_token * T * fwdbwd_per_iter * 3 # 3x for backward pass
517
+ flops_per_iter = flops_per_fwdbwd
518
+
519
+ # A100 bfloat16 peak FLOPS
520
+ flops_achieved = flops_per_iter / dt
521
+ flops_promised = 312e12 # A100 GPU bfloat16 peak
522
+
523
+ mfu = flops_achieved / flops_promised * 100
524
+ return mfu
525
+
526
+ def configure_optimizers(
527
+ self,
528
+ weight_decay: float,
529
+ learning_rate: float,
530
+ betas: Tuple[float, float],
531
+ device_type: str
532
+ ):
533
+ """
534
+ Configure optimizer with weight decay only on specific parameters.
535
+
536
+ Args:
537
+ weight_decay: L2 regularization coefficient
538
+ learning_rate: Learning rate
539
+ betas: Adam beta parameters
540
+ device_type: 'cuda' or 'cpu'
541
+
542
+ Returns:
543
+ Configured AdamW optimizer
544
+ """
545
+ # Separate parameters that should and shouldn't have weight decay
546
+ decay = set()
547
+ no_decay = set()
548
+
549
+ whitelist_weight_modules = (nn.Linear,)
550
+ blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)
551
+
552
+ for mn, m in self.named_modules():
553
+ for pn, p in m.named_parameters():
554
+ fpn = f'{mn}.{pn}' if mn else pn
555
+
556
+ if pn.endswith('bias'):
557
+ no_decay.add(fpn)
558
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
559
+ decay.add(fpn)
560
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
561
+ no_decay.add(fpn)
562
+
563
+ # Validate that we've covered all parameters
564
+ param_dict = {pn: p for pn, p in self.named_parameters()}
565
+ inter_params = decay & no_decay
566
+ union_params = decay | no_decay
567
+ assert len(inter_params) == 0, f"Parameters in both decay/no_decay: {inter_params}"
568
+ assert len(param_dict.keys() - union_params) == 0, f"Missing parameters: {param_dict.keys() - union_params}"
569
+
570
+ # Create optimizer groups
571
+ optim_groups = [
572
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
573
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
574
+ ]
575
+
576
+ # Use fused AdamW if on CUDA for better performance
577
+ use_fused = device_type == 'cuda'
578
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused)
579
+
580
+ return optimizer