bmeyer2025 committed on
Commit
94d17bb
·
verified ·
1 Parent(s): adbcc81

Upload src/modernize.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/modernize.py +276 -0
src/modernize.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 3: Modern architecture components.
3
+
4
+ Four swaps over the vanilla transformer:
5
+ 1. RMSNorm — replaces LayerNorm (simpler, faster)
6
+ 2. SwiGLU — replaces ReLU FFN (better gradient flow, used in LLaMA/Qwen)
7
+ 3. RoPE — replaces learned positional embeddings (better length generalization)
8
+ 4. KV Cache — enables fast autoregressive inference
9
+
10
+ These are the components that make a "modern" LLM. After swapping all four,
11
+ the architecture is structurally similar to LLaMA / Qwen at tiny scale.
12
+ """
13
+
14
+ import math
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+
20
# ── Swap 1: RMSNorm ────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization.

    A stripped-down LayerNorm: no mean subtraction and no bias. Activations
    are divided by their root-mean-square over the last dimension, then
    scaled by a learned per-channel weight:

        y = x / sqrt(mean(x^2) + eps) * weight

    Compare LayerNorm: y = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias.

    Used in: LLaMA, Qwen, Mistral, Gemma.
    Paper: "Root Mean Square Layer Normalization" (Zhang & Sennrich, 2019)
    """

    def __init__(self, n_embd: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # numerical-stability floor inside the sqrt
        self.weight = nn.Parameter(torch.ones(n_embd))  # learnable per-channel scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, C) — normalize over the channel dimension only.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x / torch.sqrt(mean_sq + self.eps) * self.weight
43
+
44
+
45
# ── Swap 2: SwiGLU Feed-Forward ───────────────────────────────────────────────
class SwiGLU(nn.Module):
    """Gated feed-forward network replacing the Linear -> ReLU -> Linear FFN.

    The gated form (matching the code below):

        out = down( silu(gate(x)) * up(x) )        silu(z) = z * sigmoid(z)

    Three projection matrices instead of two; to keep the parameter count
    close to a standard 4x FFN, the hidden width is 2/3 of the 4x expansion,
    rounded up to a multiple of 64 (hardware-friendly).

    Used in: LLaMA, Qwen, Mistral, PaLM.
    Paper: "GLU Variants Improve Transformer" (Shazeer, 2020)
    """

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        # 2/3 of a 4x expansion, rounded up to the next multiple of 64.
        target = int(2 / 3 * 4 * n_embd)
        hidden = ((target + 63) // 64) * 64

        self.gate = nn.Linear(n_embd, hidden, bias=False)
        self.up = nn.Linear(n_embd, hidden, bias=False)
        self.down = nn.Linear(hidden, n_embd, bias=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu is applied to the gate branch; the up branch passes through.
        gated = F.silu(self.gate(x)) * self.up(x)
        return self.drop(self.down(gated))
77
+
78
+
79
# ── Swap 3: RoPE (Rotary Position Embeddings) ─────────────────────────────────
def precompute_rope_freqs(head_size: int, seq_len: int, device: torch.device, theta: float = 10000.0):
    """Build the RoPE cos/sin lookup tables.

    Each pair of head dimensions (2i, 2i+1) rotates with frequency

        freq_i = 1 / theta^(2i / head_size)

    so the first pair spins fastest and the last barely moves, giving every
    position a unique multi-scale phase signature.

    Returns:
        (cos, sin) tables, each of shape (seq_len, head_size//2).
    """
    pair_idx = torch.arange(0, head_size, 2, device=device).float()  # (head_size//2,)
    inv_freq = 1.0 / (theta ** (pair_idx / head_size))               # (head_size//2,)
    positions = torch.arange(seq_len, device=device).float()         # (seq_len,)
    # Broadcasted outer product: angles[p, i] = p * inv_freq[i]
    angles = positions[:, None] * inv_freq[None, :]                  # (seq_len, head_size//2)
    return torch.cos(angles), torch.sin(angles)
94
+
95
+
96
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate a query or key tensor by its position-dependent RoPE angles.

    Args:
        x:   (B, n_heads, T, head_size)
        cos: (T, head_size//2)
        sin: (T, head_size//2)

    Each consecutive dimension pair (x1, x2) is rotated:

        x1' = x1*cos - x2*sin
        x2' = x1*sin + x2*cos

    which bakes relative position into the Q·K dot product without any
    additive positional embedding on the tokens.
    """
    seq_len = x.size(-2)
    even = x[..., 0::2]  # (B, H, T, C//2)
    odd = x[..., 1::2]   # (B, H, T, C//2)

    # Broadcast the tables over batch and head dims: (1, 1, T, C//2).
    c = cos[:seq_len].view(1, 1, seq_len, -1)
    s = sin[:seq_len].view(1, 1, seq_len, -1)

    # Stack the rotated pair on a trailing axis, then re-interleave.
    rotated = torch.stack(
        (even * c - odd * s, even * s + odd * c),
        dim=-1,
    )  # (B, H, T, C//2, 2)
    return rotated.flatten(-2)  # (B, H, T, C)
123
+
124
+
125
# ── Swap 4: Attention with RoPE + KV Cache ────────────────────────────────────
class ModernHead(nn.Module):
    """Single attention head with RoPE and optional KV cache.

    The KV cache stores past (key, value) tensors so that during generation
    we only project the new token(s), not the entire sequence. The cache is
    disabled during training (full sequences, causal mask).
    """

    def __init__(self, head_size: int, n_embd: int, block_size: int, dropout: float):
        super().__init__()
        self.head_size = head_size
        self.block_size = block_size

        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.drop = nn.Dropout(dropout)

        # Lower-triangular causal mask template, (block_size, block_size).
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

        # KV cache (None = disabled; populated during cached inference).
        self._kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None

    def clear_cache(self):
        """Drop cached K/V — call before starting a fresh generation."""
        self._kv_cache = None

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Attend over x (plus any cached prefix). Returns (B, T, head_size)."""
        B, T, C = x.shape

        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # apply_rope expects (B, H, T, head_size); use a singleton head dim.
        # RoPE rotates Q and K only — position must affect the attention
        # pattern, not the values.
        k = apply_rope(k.unsqueeze(1), cos, sin).squeeze(1)  # (B, T, head_size)
        q = apply_rope(q.unsqueeze(1), cos, sin).squeeze(1)
        # NOTE(review): apply_rope uses cos[:T]/sin[:T], i.e. positions
        # 0..T-1. With a non-empty cache the incoming tokens sit at later
        # absolute positions, so the caller must pass cos/sin tables already
        # sliced to the right offset — confirm against the generation loop.

        # KV cache: prepend the cached prefix, then store the extended K/V.
        if use_cache:
            if self._kv_cache is not None:
                k_cache, v_cache = self._kv_cache
                k = torch.cat([k_cache, k], dim=1)
                v = torch.cat([v_cache, v], dim=1)
            self._kv_cache = (k, v)

        T_k = k.shape[1]  # key length (>= T when a cache prefix exists)

        # Scaled dot-product attention: (B, T, T_k).
        scores = q @ k.transpose(-2, -1) * (self.head_size ** -0.5)

        # Causal mask. BUG FIX: the original skipped masking whenever
        # use_cache=True, so a cached *prefill* (T > 1) let tokens attend to
        # future positions. Mask whenever the query span covers more than one
        # position, offsetting mask rows by the cached prefix length so that
        # query row i corresponds to absolute position T_k - T + i. For
        # T == 1 (cached decode) the new token may attend to everything, so
        # no mask is needed. Requires T_k <= block_size.
        if T > 1:
            scores = scores.masked_fill(
                self.tril[T_k - T : T_k, :T_k] == 0, float("-inf")
            )

        weights = F.softmax(scores, dim=-1)
        weights = self.drop(weights)
        return weights @ v  # (B, T, head_size)
193
+
194
+
195
class ModernMultiHeadAttention(nn.Module):
    """Runs several ModernHead instances in parallel and mixes their outputs
    through a shared output projection (RoPE + KV cache come from the heads).
    """

    def __init__(self, n_heads: int, head_size: int, n_embd: int, block_size: int, dropout: float):
        super().__init__()
        heads = [ModernHead(head_size, n_embd, block_size, dropout) for _ in range(n_heads)]
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(n_heads * head_size, n_embd, bias=False)
        self.drop = nn.Dropout(dropout)

    def clear_cache(self):
        # Reset every head's KV cache before a fresh generation.
        for head in self.heads:
            head.clear_cache()

    def forward(self, x, cos, sin, use_cache=False):
        # Concatenate per-head outputs along channels, then project back.
        per_head = [head(x, cos, sin, use_cache) for head in self.heads]
        mixed = self.proj(torch.cat(per_head, dim=-1))
        return self.drop(mixed)
214
+
215
+
216
# ── Modern Transformer Block ───────────────────────────────────────────────────
class ModernBlock(nn.Module):
    """Pre-norm transformer block assembled from the four modern swaps:
    RMSNorm + ModernMultiHeadAttention (RoPE + KV cache) + SwiGLU.
    """

    def __init__(self, n_embd: int, n_heads: int, block_size: int, dropout: float):
        super().__init__()
        self.attn = ModernMultiHeadAttention(
            n_heads, n_embd // n_heads, n_embd, block_size, dropout
        )
        self.ffn = SwiGLU(n_embd, dropout)
        self.rn1 = RMSNorm(n_embd)
        self.rn2 = RMSNorm(n_embd)

    def clear_cache(self):
        # Forward the reset to the attention heads' KV caches.
        self.attn.clear_cache()

    def forward(self, x, cos, sin, use_cache=False):
        # Pre-norm residual layout: x + sublayer(norm(x)).
        attn_out = self.attn(self.rn1(x), cos, sin, use_cache)
        x = x + attn_out
        return x + self.ffn(self.rn2(x))
237
+
238
+
239
# ── Quick sanity check ────────────────────────────────────────────────────────
if __name__ == "__main__":
    from tokenizer import DEVICE, BLOCK_SIZE

    def _n_params(module: nn.Module) -> int:
        # Total number of learnable parameter elements in a module.
        return sum(p.numel() for p in module.parameters())

    n_embd = 384
    n_heads = 6
    dropout = 0.1
    B, T = 2, 64
    head_size = n_embd // n_heads

    # Test RMSNorm
    norm = RMSNorm(n_embd).to(DEVICE)
    x = torch.randn(B, T, n_embd, device=DEVICE)
    print(f"RMSNorm output shape : {norm(x).shape}")

    # Test SwiGLU
    ffn = SwiGLU(n_embd, dropout).to(DEVICE)
    print(f"SwiGLU output shape : {ffn(x).shape}")
    swiglu_params = _n_params(ffn)
    relu_params = 2 * n_embd * (4 * n_embd)  # two 4x matrices, for comparison
    print(f"SwiGLU params : {swiglu_params:,} (vs ReLU FFN ~{relu_params:,})")

    # Test RoPE
    cos, sin = precompute_rope_freqs(head_size, BLOCK_SIZE, DEVICE)
    print(f"RoPE cos/sin shape : {cos.shape}")

    # Test ModernBlock
    block = ModernBlock(n_embd, n_heads, BLOCK_SIZE, dropout).to(DEVICE)
    x = torch.randn(B, T, n_embd, device=DEVICE)
    cos_t, sin_t = precompute_rope_freqs(head_size, T, DEVICE)
    out = block(x, cos_t, sin_t)
    print(f"ModernBlock output : {out.shape} (expected [{B}, {T}, {n_embd}])")

    block_params = _n_params(block)
    print(f"ModernBlock params : {block_params:,}")

    print("\nAll modernize.py components OK.")