ataeff commited on
Commit
9ee0dae
·
verified ·
1 Parent(s): b3a256b

Upload janus/janus_gpt_v4_lowrank.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. janus/janus_gpt_v4_lowrank.py +654 -0
janus/janus_gpt_v4_lowrank.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Janus 285M GPT — nanochat fork with 3-way hybrid attention.
3
+
4
+ Architecture delta from nanochat's gpt.py:
5
+ 1. MLP: ReLU^2 -> SwiGLU (w_gate, w_up, w_down)
6
+ 2. Attention: CausalSelfAttention -> JanusHybridAttention
7
+ - Standard QKV (FA3/SDPA, RoPE, QK-norm)
8
+ - RRPRAM: positional resonance via Wr[H, E, T_r], linear, non-quadratic
9
+ - Janus echo: Wj^T * Wj self-resonance
10
+ - Learned per-head 3-way gate: softmax([3]) blends the three pathways
11
+ 3. No value_embeds / ve_gate (nanochat feature not used in Janus)
12
+
13
+ Everything else from nanochat is preserved:
14
+ - resid_lambdas, x0_lambdas (per-layer residual scaling)
15
+ - smear_gate, smear_lambda (bigram token mixing)
16
+ - backout_lambda (mid-layer subtraction)
17
+ - RoPE, QK-norm, softcap=15, non-parametric RMSNorm
18
+ - Sliding window attention support
19
+
20
+ Confirmed against checkpoint keys from janus_285m_base_final.pt:
21
+ V=32000, E=640, H=10, D=64, B=20, M=1664, T=1024, ~285M params
22
+ """
23
+
24
+ from functools import partial
25
+ from dataclasses import dataclass
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+
31
+ from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
32
+ from nanochat.optim import MuonAdamW, DistMuonAdamW
33
+ from nanochat.flash_attention import flash_attn
34
+
35
+
36
+ @dataclass
37
+ class JanusConfig:
38
+ sequence_len: int = 1024
39
+ vocab_size: int = 32000
40
+ n_layer: int = 20
41
+ n_head: int = 10 # number of query heads (H)
42
+ n_kv_head: int = 10 # same as n_head for Janus (no GQA)
43
+ n_embd: int = 640 # embedding dim (E)
44
+ mlp_hidden: int = 1664 # SwiGLU intermediate dim (M) — NOT 4*n_embd
45
+ rrpram_T: int = 1024 # RRPRAM positional dimension (T_r, same as sequence_len)
46
+ rrpram_rank: int = 64 # low-rank factorization rank (0 = full rank for backward compat)
47
+ # Sliding window attention pattern string, tiled across layers.
48
+ window_pattern: str = "L" # Janus used full context (no sliding window)
49
+
50
+
51
+ def norm(x):
52
+ if hasattr(F, 'rms_norm'):
53
+ return F.rms_norm(x, (x.size(-1),))
54
+ # Fallback for older PyTorch versions
55
+ variance = x.float().pow(2).mean(-1, keepdim=True)
56
+ return (x * torch.rsqrt(variance + 1e-6)).to(x.dtype)
57
+
58
+
59
+ class Linear(nn.Linear):
60
+ """nn.Linear that casts weights to match input dtype in forward."""
61
+ def forward(self, x):
62
+ return F.linear(x, self.weight.to(dtype=x.dtype))
63
+
64
+
65
+ def apply_rotary_emb(x, cos, sin):
66
+ assert x.ndim == 4 # (B, T, H, D)
67
+ d = x.shape[3] // 2
68
+ x1, x2 = x[..., :d], x[..., d:]
69
+ y1 = x1 * cos + x2 * sin
70
+ y2 = x1 * (-sin) + x2 * cos
71
+ return torch.cat([y1, y2], 3)
72
+
73
+
74
+ class JanusHybridAttention(nn.Module):
75
+ """
76
+ 3-way hybrid attention: QKV + RRPRAM + Janus echo, blended by learned per-head gate.
77
+
78
+ Pathway 1 - QKV (standard):
79
+ Standard scaled dot-product attention with RoPE and QK-norm.
80
+ Uses FA3 on Hopper, SDPA fallback elsewhere.
81
+
82
+ Pathway 2 - RRPRAM (positional resonance):
83
+ Wr: nn.Parameter [H, E, T_r] — positional pattern per head
84
+ score[t] = sum_e(x[t,e] * Wr[h,e,t]) — linear in T, non-quadratic
85
+ Attention: broadcast score to all query positions (with causal mask)
86
+ Values: separate Wvr projection
87
+
88
+ Pathway 3 - Janus echo (self-resonance):
89
+ echo = Wj(x) — project through Wj
90
+ echo_back = echo @ Wj.T — project back through transpose (W^T * W)
91
+ score[t] = dot(x[t], echo_back[t]) / sqrt(E)
92
+ Attention: score[i] * score[j] (with causal mask)
93
+ Values: echo itself (Wj(x))
94
+
95
+ Gate: nn.Parameter [H, 3], softmax per head, blends three pathway outputs.
96
+ """
97
+
98
+ def __init__(self, config, layer_idx):
99
+ super().__init__()
100
+ self.layer_idx = layer_idx
101
+ self.n_head = config.n_head
102
+ self.n_kv_head = config.n_kv_head
103
+ self.n_embd = config.n_embd
104
+ self.head_dim = self.n_embd // self.n_head
105
+ assert self.n_embd % self.n_head == 0
106
+
107
+ # Pathway 1: Standard QKV
108
+ self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
109
+ self.c_k = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
110
+ self.c_v = Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
111
+
112
+ # Pathway 2: RRPRAM
113
+ self.rrpram_rank = config.rrpram_rank
114
+ if config.rrpram_rank > 0:
115
+ # Low-rank factorization: Wr ≈ wr_a @ wr_b
116
+ # Was [H, E, T_r] = 6.5M per layer. Now wr_a[H,E,R]+wr_b[H,R,T] ≈ 1.1M
117
+ self.wr_a = nn.Parameter(torch.zeros(config.n_head, config.n_embd, config.rrpram_rank))
118
+ self.wr_b = nn.Parameter(torch.zeros(config.n_head, config.rrpram_rank, config.rrpram_T))
119
+ else:
120
+ # Full rank (backward compat with v3 checkpoints)
121
+ self.wr = nn.Parameter(torch.zeros(config.n_head, config.n_embd, config.rrpram_T))
122
+ # Separate value projection for RRPRAM
123
+ self.wvr = Linear(self.n_embd, self.n_embd, bias=False)
124
+
125
+ # Pathway 3: Janus echo (W^T * W self-resonance)
126
+ self.wj = Linear(self.n_embd, self.n_embd, bias=False)
127
+
128
+ # Per-head 3-way gate: [H, 3]
129
+ # Pad gate to multiple of 8 for DDP reduce_scatter compatibility
130
+
131
+ self.gate = nn.Parameter(torch.zeros(config.n_head, 3))
132
+
133
+ # Output projection
134
+ self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
135
+
136
+ def _rrpram_attention(self, x, vr, B, T, H, D):
137
+ """
138
+ RRPRAM pathway: positional resonance, linear in T.
139
+
140
+ x: (B, T, E) — input (after norm)
141
+ vr: (B, T, H, D) — RRPRAM values
142
+
143
+ score[t] = sum_e x[b,t,e] * wr[h,e,t]
144
+ This is einsum('bte,het->bht') with causal broadcast.
145
+ """
146
+ E = self.n_embd
147
+ sc = (D ** -0.5)
148
+
149
+ # Compute per-position scores: (B, H, T)
150
+ if self.rrpram_rank > 0:
151
+ # Low-rank: x → wr_a → intermediate (B,H,R) → wr_b → scores (B,H,T)
152
+ wr_a = self.wr_a.to(x.dtype) # (H, E, R)
153
+ wr_b_slice = self.wr_b[:, :, :T].to(x.dtype) # (H, R, T)
154
+ intermediate = torch.einsum('bte,her->bhr', x, wr_a) # (B, H, R)
155
+ scores = torch.einsum('bhr,hrt->bht', intermediate, wr_b_slice) * sc # (B, H, T)
156
+ else:
157
+ # Full rank (backward compat)
158
+ wr_slice = self.wr[:, :, :T].to(x.dtype) # (H, E, T)
159
+ scores = torch.einsum('bte,het->bht', x, wr_slice) * sc # (B, H, T)
160
+
161
+ # Build causal attention from broadcast scores:
162
+ # attn[i, j] = score[j] for j <= i, -inf for j > i
163
+ # Efficient: expand scores to (B, H, 1, T) and apply causal mask
164
+ causal_mask = torch.triu(
165
+ torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
166
+ ) # True where j > i
167
+ attn = scores.unsqueeze(2).expand(B, H, T, T) # (B, H, T, T)
168
+ attn = attn.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
169
+ attn = F.softmax(attn.float(), dim=-1).to(x.dtype)
170
+
171
+ # Apply to values: (B, H, T, T) @ (B, H, T, D) -> (B, H, T, D)
172
+ # vr is (B, T, H, D), transpose to (B, H, T, D)
173
+ vr_t = vr.transpose(1, 2)
174
+ out = torch.matmul(attn, vr_t) # (B, H, T, D)
175
+ return out # (B, H, T, D)
176
+
177
+ def _janus_echo_attention(self, x, B, T, H, D):
178
+ """
179
+ Janus echo pathway: W^T * W self-resonance.
180
+
181
+ echo = Wj(x) — (B, T, E)
182
+ echo_back = echo @ Wj.weight — (B, T, E), i.e. F.linear(echo, Wj.T)
183
+ score[t] = dot(x[t], echo_back[t]) / sqrt(E)
184
+ attn[i,j] = score[i] * score[j] (causal)
185
+ values = echo reshaped to (B, T, H, D)
186
+ """
187
+ E = self.n_embd
188
+
189
+ # echo = F.linear(x, wj) = x @ wj.T
190
+ echo = self.wj(x) # (B, T, E)
191
+
192
+ # echo_back = echo @ wj.weight (standard mm, NOT transposed)
193
+ # wj.weight is [E, E] (PyTorch stores [out, in])
194
+ # F.linear(echo, wj.T) = echo @ wj = echo @ wj.weight.T.T = echo @ wj.weight
195
+ echo_back = torch.matmul(echo, self.wj.weight.to(echo.dtype)) # (B, T, E)
196
+
197
+ # Self-resonance scores (capped to prevent bf16 overflow in outer product)
198
+ scores = (x * echo_back).sum(dim=-1) / (E ** 0.5) # (B, T)
199
+ scores = 15.0 * torch.tanh(scores / 15.0) # softcap like logits
200
+
201
+ # Build attention: attn[i,j] = score[i] * score[j] (with causal mask)
202
+ causal_mask = torch.triu(
203
+ torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
204
+ )
205
+ attn = scores.unsqueeze(-1) * scores.unsqueeze(-2) # (B, T, T)
206
+ attn = attn.masked_fill(causal_mask.unsqueeze(0), float('-inf'))
207
+ attn = F.softmax(attn.float(), dim=-1).to(x.dtype)
208
+
209
+ # Values: echo reshaped to (B, T, H, D) -> (B, H, T, D)
210
+ jv = echo.view(B, T, H, D).transpose(1, 2) # (B, H, T, D)
211
+ # Attention is (B, T, T), need (B, H, T, T) for per-head application
212
+ attn = attn.unsqueeze(1).expand(B, H, T, T) # (B, H, T, T)
213
+ out = torch.matmul(attn, jv) # (B, H, T, D)
214
+ return out # (B, H, T, D)
215
+
216
+ def forward(self, x, cos_sin, window_size, kv_cache):
217
+ B, T, C = x.size()
218
+ H = self.n_head
219
+ D = self.head_dim
220
+
221
+ # === Pathway 1: Standard QKV attention ===
222
+ q = self.c_q(x).view(B, T, H, D)
223
+ k = self.c_k(x).view(B, T, self.n_kv_head, D)
224
+ v = self.c_v(x).view(B, T, self.n_kv_head, D)
225
+
226
+ # RoPE
227
+ cos, sin = cos_sin
228
+ q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
229
+ # QK norm (from nanochat)
230
+ q, k = norm(q), norm(k)
231
+ q = q * 1.2
232
+ k = k * 1.2
233
+
234
+ # Flash Attention
235
+ if kv_cache is None:
236
+ qkv_out = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
237
+ else:
238
+ k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
239
+ qkv_out = flash_attn.flash_attn_with_kvcache(
240
+ q, k_cache, v_cache, k=k, v=v,
241
+ cache_seqlens=kv_cache.cache_seqlens,
242
+ causal=True, window_size=window_size,
243
+ )
244
+ if self.layer_idx == kv_cache.n_layers - 1:
245
+ kv_cache.advance(T)
246
+ # qkv_out: (B, T, H, D) -> (B, H, T, D)
247
+ qkv_out = qkv_out.transpose(1, 2)
248
+
249
+ # === Pathway 2: RRPRAM ===
250
+ vr = self.wvr(x).view(B, T, H, D)
251
+ rrpram_out = self._rrpram_attention(x, vr, B, T, H, D) # (B, H, T, D)
252
+
253
+ # === Pathway 3: Janus echo ===
254
+ janus_out = self._janus_echo_attention(x, B, T, H, D) # (B, H, T, D)
255
+
256
+ # === 3-way gate blending ===
257
+ # gate: [H, 3] -> softmax -> [H, 3]
258
+ g = F.softmax(self.gate.float(), dim=-1).to(x.dtype) # (H, 3)
259
+ # g[:, 0] = QKV weight, g[:, 1] = RRPRAM weight, g[:, 2] = Janus weight
260
+ # Reshape for broadcasting: (1, H, 1, 1) per pathway
261
+ g0 = g[:, 0].view(1, H, 1, 1)
262
+ g1 = g[:, 1].view(1, H, 1, 1)
263
+ g2 = g[:, 2].view(1, H, 1, 1)
264
+
265
+ blended = g0 * qkv_out + g1 * rrpram_out + g2 * janus_out # (B, H, T, D)
266
+
267
+ # (B, H, T, D) -> (B, T, H, D) -> (B, T, E)
268
+ y = blended.transpose(1, 2).contiguous().view(B, T, -1)
269
+ y = self.c_proj(y)
270
+ return y
271
+
272
+
273
+ class SwiGLU_MLP(nn.Module):
274
+ """SwiGLU MLP: gate(x) = SiLU(w_gate(x)) * w_up(x); out = w_down(gate(x))"""
275
+ def __init__(self, config):
276
+ super().__init__()
277
+ self.w_gate = Linear(config.n_embd, config.mlp_hidden, bias=False)
278
+ self.w_up = Linear(config.n_embd, config.mlp_hidden, bias=False)
279
+ self.w_down = Linear(config.mlp_hidden, config.n_embd, bias=False)
280
+
281
+ def forward(self, x):
282
+ return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
283
+
284
+
285
+ class Block(nn.Module):
286
+ def __init__(self, config, layer_idx):
287
+ super().__init__()
288
+ self.attn = JanusHybridAttention(config, layer_idx)
289
+ self.mlp = SwiGLU_MLP(config)
290
+
291
+ def forward(self, x, cos_sin, window_size, kv_cache):
292
+ x = x + self.attn(norm(x), cos_sin, window_size, kv_cache)
293
+ x = x + self.mlp(norm(x))
294
+ return x
295
+
296
+
297
+ class JanusGPT(nn.Module):
298
+ """
299
+ Janus 285M: nanochat GPT with 3-way hybrid attention and SwiGLU MLP.
300
+
301
+ Preserves all nanochat mechanisms:
302
+ - resid_lambdas, x0_lambdas (per-layer residual scaling)
303
+ - smear_gate, smear_lambda (bigram token mixing)
304
+ - backout_lambda (mid-layer subtraction)
305
+ - Softcap=15 on logits
306
+ - Non-parametric RMSNorm
307
+
308
+ Removed from nanochat:
309
+ - value_embeds / ve_gate (not used in Janus)
310
+ """
311
+
312
+ def __init__(self, config, pad_vocab_size_to=64):
313
+ super().__init__()
314
+ self.config = config
315
+ self.window_sizes = self._compute_window_sizes(config)
316
+ padded_vocab_size = ((config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
317
+ if padded_vocab_size != config.vocab_size:
318
+ print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")
319
+ self.transformer = nn.ModuleDict({
320
+ "wte": nn.Embedding(padded_vocab_size, config.n_embd),
321
+ "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
322
+ })
323
+ self.lm_head = Linear(config.n_embd, padded_vocab_size, bias=False)
324
+
325
+ # Per-layer learnable scalars (from nanochat)
326
+ self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
327
+ self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
328
+
329
+ # Smear: mix previous token's embedding into current token
330
+ self.smear_gate = Linear(24, 1, bias=False)
331
+ self.smear_lambda = nn.Parameter(torch.zeros(1))
332
+
333
+ # Backout: subtract cached mid-layer residual
334
+ self.backout_lambda = nn.Parameter(0.2 * torch.ones(1))
335
+
336
+ # Rotary embeddings
337
+ self.rotary_seq_len = config.sequence_len * 10
338
+ head_dim = config.n_embd // config.n_head
339
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
340
+ self.register_buffer("cos", cos, persistent=False)
341
+ self.register_buffer("sin", sin, persistent=False)
342
+
343
+ @torch.no_grad()
344
+ def init_weights(self):
345
+ """Initialize all weights. Matches nanochat conventions + Janus-specific init."""
346
+
347
+ # Embedding and unembedding
348
+ torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
349
+ torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
350
+
351
+ # Transformer blocks
352
+ n_embd = self.config.n_embd
353
+ s = 3**0.5 * n_embd**-0.5
354
+
355
+ for block in self.transformer.h:
356
+ # QKV projections
357
+ torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
358
+ torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
359
+ torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
360
+ torch.nn.init.zeros_(block.attn.c_proj.weight)
361
+
362
+ # RRPRAM: Wr init with small values (positional patterns need to learn from data)
363
+ if hasattr(block.attn, 'wr_a'):
364
+ torch.nn.init.normal_(block.attn.wr_a, mean=0.0, std=0.01)
365
+ torch.nn.init.normal_(block.attn.wr_b, mean=0.0, std=0.01)
366
+ else:
367
+ torch.nn.init.normal_(block.attn.wr, mean=0.0, std=0.01)
368
+ # RRPRAM value projection
369
+ torch.nn.init.uniform_(block.attn.wvr.weight, -s, s)
370
+
371
+ # Janus echo projection
372
+ torch.nn.init.uniform_(block.attn.wj.weight, -s, s)
373
+
374
+ # Gate: init biased toward QKV (standard attention gets most weight early on)
375
+ # [H, 3]: column 0 = QKV (larger), columns 1,2 = RRPRAM, Janus (smaller)
376
+ block.attn.gate.data[:, 0] = 1.0 # QKV dominant
377
+ block.attn.gate.data[:, 1] = -0.5 # RRPRAM starts lower
378
+ block.attn.gate.data[:, 2] = -0.5 # Janus starts lower
379
+ # After softmax: ~0.58 QKV, ~0.21 RRPRAM, ~0.21 Janus
380
+
381
+ # SwiGLU MLP
382
+ torch.nn.init.uniform_(block.mlp.w_gate.weight, -s * 0.4, s * 0.4)
383
+ torch.nn.init.uniform_(block.mlp.w_up.weight, -s * 0.4, s * 0.4)
384
+ torch.nn.init.zeros_(block.mlp.w_down.weight)
385
+
386
+ # Per-layer scalars (from nanochat)
387
+ n_layer = self.config.n_layer
388
+ for i in range(n_layer):
389
+ self.resid_lambdas.data[i] = 1.15 - (0.10 * i / max(n_layer - 1, 1))
390
+ for i in range(n_layer):
391
+ self.x0_lambdas.data[i] = 0.20 - (0.15 * i / max(n_layer - 1, 1))
392
+
393
+ # Rotary embeddings
394
+ head_dim = self.config.n_embd // self.config.n_head
395
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
396
+ self.cos, self.sin = cos, sin
397
+
398
+ # Cast embeddings to COMPUTE_DTYPE
399
+ if COMPUTE_DTYPE != torch.float16:
400
+ self.transformer.wte.to(dtype=COMPUTE_DTYPE)
401
+
402
+ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000, device=None):
403
+ if device is None:
404
+ device = self.transformer.wte.weight.device
405
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
406
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
407
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
408
+ freqs = torch.outer(t, inv_freq)
409
+ cos, sin = freqs.cos(), freqs.sin()
410
+ cos, sin = cos.to(COMPUTE_DTYPE), sin.to(COMPUTE_DTYPE)
411
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :]
412
+ return cos, sin
413
+
414
+ def _compute_window_sizes(self, config):
415
+ pattern = config.window_pattern.upper()
416
+ assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}"
417
+ long_window = config.sequence_len
418
+ short_window = -(-long_window // 4 // 128) * 128
419
+ char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
420
+ window_sizes = []
421
+ for layer_idx in range(config.n_layer):
422
+ char = pattern[layer_idx % len(pattern)]
423
+ window_sizes.append(char_to_window[char])
424
+ window_sizes[-1] = (long_window, 0)
425
+ return window_sizes
426
+
427
+ def get_device(self):
428
+ return self.transformer.wte.weight.device
429
+
430
+ def estimate_flops(self):
431
+ """Estimated FLOPs per token (forward + backward)."""
432
+ nparams = sum(p.numel() for p in self.parameters())
433
+ nparams_exclude = (self.transformer.wte.weight.numel() +
434
+ self.resid_lambdas.numel() + self.x0_lambdas.numel() +
435
+ self.smear_gate.weight.numel() + self.smear_lambda.numel() +
436
+ self.backout_lambda.numel())
437
+ h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
438
+ attn_flops = 0
439
+ for window_size in self.window_sizes:
440
+ window = window_size[0]
441
+ effective_seq = t if window < 0 else min(window, t)
442
+ # QKV attention FLOPs
443
+ attn_flops += 12 * h * q * effective_seq
444
+ # RRPRAM FLOPs (roughly linear, much cheaper than QKV)
445
+ attn_flops += 4 * h * q * effective_seq
446
+ # Janus echo FLOPs
447
+ attn_flops += 4 * h * q * effective_seq
448
+ num_flops_per_token = 6 * (nparams - nparams_exclude) + attn_flops
449
+ return num_flops_per_token
450
+
451
+ def num_scaling_params(self):
452
+ """Parameter counts for scaling law analysis."""
453
+ wte = sum(p.numel() for p in self.transformer.wte.parameters())
454
+ lm_head = sum(p.numel() for p in self.lm_head.parameters())
455
+ transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
456
+ scalars = (self.resid_lambdas.numel() + self.x0_lambdas.numel() +
457
+ self.smear_gate.weight.numel() + self.smear_lambda.numel() +
458
+ self.backout_lambda.numel())
459
+ total = wte + lm_head + transformer_matrices + scalars
460
+ assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
461
+ return {
462
+ 'wte': wte,
463
+ 'lm_head': lm_head,
464
+ 'transformer_matrices': transformer_matrices,
465
+ 'scalars': scalars,
466
+ 'total': total,
467
+ }
468
+
469
+ def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
470
+ weight_decay=0.0, scalar_lr=0.5,
471
+ rrpram_lr_scale=0.5, janus_lr_scale=0.5, gate_lr=0.1):
472
+ """
473
+ Setup MuonAdamW optimizer with Janus-specific parameter groups.
474
+
475
+ Extra groups vs nanochat:
476
+ - wr (RRPRAM): AdamW with reduced LR (3D tensor, not suitable for Muon)
477
+ - wj (Janus echo): Muon with slightly reduced LR
478
+ - gate: AdamW with small LR and no weight decay (learned blending, keep stable)
479
+ """
480
+ model_dim = self.config.n_embd
481
+ ddp, rank, local_rank, world_size = get_dist_info()
482
+
483
+ # Collect all parameters by role
484
+ # Standard matrix params (QKV, projections, MLP): go to Muon
485
+ standard_matrix_params = []
486
+ # Janus-specific params: separate groups
487
+ wr_params = [] # RRPRAM positional patterns (3D, AdamW)
488
+ wj_params = [] # Janus echo projection (2D, Muon with separate LR)
489
+ wvr_params = [] # RRPRAM value projection (2D, Muon)
490
+ gate_params = [] # Per-head 3-way gate (2D small, AdamW)
491
+
492
+ for block in self.transformer.h:
493
+ # Standard attention matrices -> Muon
494
+ standard_matrix_params.extend([
495
+ block.attn.c_q.weight,
496
+ block.attn.c_k.weight,
497
+ block.attn.c_v.weight,
498
+ block.attn.c_proj.weight,
499
+ ])
500
+ # SwiGLU MLP -> Muon
501
+ standard_matrix_params.extend([
502
+ block.mlp.w_gate.weight,
503
+ block.mlp.w_up.weight,
504
+ block.mlp.w_down.weight,
505
+ ])
506
+ # RRPRAM Wr -> AdamW (3D tensor, Muon needs 2D)
507
+ if hasattr(block.attn, 'wr_a'):
508
+ wr_params.append(block.attn.wr_a)
509
+ wr_params.append(block.attn.wr_b)
510
+ else:
511
+ wr_params.append(block.attn.wr)
512
+ # RRPRAM Wvr -> Muon (standard 2D matrix)
513
+ wvr_params.append(block.attn.wvr.weight)
514
+ # Janus echo Wj -> Muon with separate LR
515
+ wj_params.append(block.attn.wj.weight)
516
+ # Gate -> AdamW (small 2D [H, 3])
517
+ gate_params.append(block.attn.gate)
518
+
519
+ embedding_params = list(self.transformer.wte.parameters())
520
+ lm_head_params = list(self.lm_head.parameters())
521
+ resid_params = [self.resid_lambdas]
522
+ x0_params = [self.x0_lambdas]
523
+ smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
524
+
525
+ # Verify all params are accounted for
526
+ all_params_list = (standard_matrix_params + wr_params + wj_params + wvr_params +
527
+ gate_params + embedding_params + lm_head_params +
528
+ resid_params + x0_params + smear_params)
529
+ model_params = list(self.parameters())
530
+ assert len(model_params) == len(all_params_list), \
531
+ f"Parameter count mismatch: model has {len(model_params)}, grouped {len(all_params_list)}"
532
+
533
+ # Scale LR proportional to 1/sqrt(dmodel/768)
534
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
535
+ print0(f"Scaling AdamW LR by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
536
+
537
+ param_groups = [
538
+ # AdamW groups
539
+ dict(kind='adamw', params=lm_head_params,
540
+ lr=unembedding_lr * dmodel_lr_scale, betas=(0.8, 0.96), eps=1e-10, weight_decay=0.01),
541
+ dict(kind='adamw', params=embedding_params,
542
+ lr=embedding_lr * dmodel_lr_scale, betas=(0.8, 0.995), eps=1e-10, weight_decay=0.001),
543
+ dict(kind='adamw', params=resid_params,
544
+ lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
545
+ dict(kind='adamw', params=x0_params,
546
+ lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0),
547
+ dict(kind='adamw', params=smear_params,
548
+ lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
549
+ # Janus-specific AdamW groups
550
+ dict(kind='adamw', params=wr_params,
551
+ lr=embedding_lr * dmodel_lr_scale * rrpram_lr_scale,
552
+ betas=(0.9, 0.999), eps=1e-10, weight_decay=0.01),
553
+ dict(kind='adamw', params=gate_params,
554
+ lr=gate_lr, betas=(0.9, 0.99), eps=1e-10, weight_decay=0.0),
555
+ ]
556
+
557
+ # Muon groups: group by shape for stacking
558
+ all_muon_2d = standard_matrix_params + wvr_params + wj_params
559
+ for shape in sorted({p.shape for p in all_muon_2d}):
560
+ group_params = [p for p in all_muon_2d if p.shape == shape]
561
+ param_groups.append(dict(
562
+ kind='muon', params=group_params, lr=matrix_lr,
563
+ momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
564
+ ))
565
+
566
+ Factory = DistMuonAdamW if ddp else MuonAdamW
567
+ optimizer = Factory(param_groups)
568
+ for group in optimizer.param_groups:
569
+ group["initial_lr"] = group["lr"]
570
+ return optimizer
571
+
572
+ def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
573
+ B, T = idx.size()
574
+
575
+ # Rotary embeddings
576
+ assert T <= self.cos.size(1)
577
+ assert idx.device == self.cos.device
578
+ T0 = 0 if kv_cache is None else kv_cache.get_pos()
579
+ cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
580
+
581
+ # Embed
582
+ x = self.transformer.wte(idx)
583
+ x = x.to(COMPUTE_DTYPE)
584
+ x = norm(x)
585
+
586
+ # Smear (from nanochat)
587
+ if kv_cache is None:
588
+ assert T > 1
589
+ gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
590
+ x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
591
+ else:
592
+ x_pre_smear = kv_cache.prev_embedding
593
+ kv_cache.prev_embedding = x[:, -1:, :]
594
+ if T > 1:
595
+ gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, 1:, :24]))
596
+ x = torch.cat([x[:, :1], x[:, 1:] + gate * x[:, :-1]], dim=1)
597
+ elif x_pre_smear is not None:
598
+ gate = self.smear_lambda.to(x.dtype) * torch.sigmoid(self.smear_gate(x[:, :, :24]))
599
+ x = x + gate * x_pre_smear
600
+
601
+ # Forward the transformer
602
+ x0 = x
603
+ n_layer = self.config.n_layer
604
+ backout_layer = n_layer // 2
605
+ x_backout = None
606
+ for i, block in enumerate(self.transformer.h):
607
+ x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
608
+ x = block(x, cos_sin, self.window_sizes[i], kv_cache)
609
+ if i == backout_layer:
610
+ x_backout = x
611
+
612
+ # Backout subtraction
613
+ if x_backout is not None:
614
+ x = x - self.backout_lambda.to(x.dtype) * x_backout
615
+ x = norm(x)
616
+
617
+ # Logits with softcap
618
+ softcap = 15
619
+ logits = self.lm_head(x)
620
+ logits = logits[..., :self.config.vocab_size]
621
+ logits = logits.float()
622
+ logits = softcap * torch.tanh(logits / softcap)
623
+
624
+ if targets is not None:
625
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
626
+ ignore_index=-1, reduction=loss_reduction)
627
+ return loss
628
+ else:
629
+ return logits
630
+
631
+ @torch.inference_mode()
632
+ def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
633
+ assert isinstance(tokens, list)
634
+ device = self.get_device()
635
+ rng = None
636
+ if temperature > 0:
637
+ rng = torch.Generator(device=device)
638
+ rng.manual_seed(seed)
639
+ ids = torch.tensor([tokens], dtype=torch.long, device=device)
640
+ for _ in range(max_tokens):
641
+ logits = self.forward(ids)
642
+ logits = logits[:, -1, :]
643
+ if top_k is not None and top_k > 0:
644
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
645
+ logits[logits < v[:, [-1]]] = -float('Inf')
646
+ if temperature > 0:
647
+ logits = logits / temperature
648
+ probs = F.softmax(logits, dim=-1)
649
+ next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
650
+ else:
651
+ next_ids = torch.argmax(logits, dim=-1, keepdim=True)
652
+ ids = torch.cat((ids, next_ids), dim=1)
653
+ token = next_ids.item()
654
+ yield token