ncylich committed on
Commit
7a8dc75
·
verified ·
1 Parent(s): 3b336e9

Upload rung6_moe_g4.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. rung6_moe_g4.py +1372 -0
rung6_moe_g4.py ADDED
@@ -0,0 +1,1372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ rung6_moe_g4.py — Gemma-4 E2B port of rung6_moe.py.
4
+
5
+ Same MECE MoE approach, adapted for Gemma-4's heterogeneous MLP widths:
6
+ - Layers 0-14: D_FFN=6144 (INTERMEDIATE)
7
+ - Layers 15-34: D_FFN=12288 (INTERMEDIATE_WIDE)
8
+ Per-layer A logits have different row counts; PRUNE_K is per-layer.
9
+
10
+ Architecture:
11
+ - Frozen base weights: W_gate, W_up, W_down (per-layer, variable D_FFN)
12
+ - Trainable per-layer:
13
+ * Assignment logits A ∈ R^{D_FFN_i, K}
14
+ * Router W_r ∈ R^{D_MODEL, K_spec} (K_spec = K - K_const)
15
+ - Expert k's soft mask: m_k[j] = softmax(A[j,:] / tau)[k]
16
+ - τ anneals 1.0 → 0.01
17
+ - Per-token forward:
18
+ 1. Apply K_const always-on experts' combined soft mask to h
19
+ 2. Route top-K_active specialist experts via W_r (+ noise)
20
+ 3. Add selected specialist masks to combined mask (softmax-weighted within top-K)
21
+ 4. h = gelu(gate) * up * combined_mask; y = W_down @ h
22
+ - Aux losses: Switch balance (α_b=0.01) + router z-loss (α_z=0.001)
23
+
24
+ Usage:
25
+ # fix_both-style Gemma-4 launch
26
+ python rung6_moe_g4.py --phase g4_fixboth \
27
+ --K 8 --K_const 2 --K_active_spec 2 \
28
+ --init taylor --loss ce \
29
+ --int4_qat --int4_group_size 32 \
30
+ --calib_path 3BASiL/calibration_data/gemma4_e2b_it_bulk_50k.jsonl \
31
+ --eval_calib_path 3BASiL/calibration_data/gemma4_e2b_it_final_50k.jsonl \
32
+ --diverse_calib_path 3BASiL/calibration_data/diverse_wikitext.jsonl \
33
+ --kl_base_lambda 0.5 --kl_base_temp 8.0 \
34
+ --w_drift_lambda 1e-6 \
35
+ --max_steps 2000 --save_checkpoint ckpts/g4_fixboth.pt
36
+
37
+ Output:
38
+ logs/rung6_moe_<phase>_results.json
39
+ """
40
+
41
+ import argparse
42
+ import json
43
+ import math
44
+ import os
45
+ import time
46
+ import torch
47
+ import torch.nn as nn
48
+ import torch.nn.functional as F
49
+ from torch.optim import AdamW
50
+ try:
51
+ import bitsandbytes as bnb
52
+ _HAS_BNB = True
53
+ except ImportError:
54
+ _HAS_BNB = False
55
+ from torch.optim.lr_scheduler import CosineAnnealingLR
56
+ from gemma4_hf import (
57
+ load_gemma4 as load_model,
58
+ N_LAYERS,
59
+ HIDDEN_SIZE as D_MODEL,
60
+ DEVICE,
61
+ DTYPE,
62
+ INTERMEDIATE,
63
+ INTERMEDIATE_WIDE,
64
+ DOUBLE_WIDE_START,
65
+ )
66
+ from moe_recovery import (
67
+ recover_modules_via_generic_pipeline,
68
+ finetune_moe_per_layer,
69
+ )
70
+
71
# Default calibration corpus; override via --calib_path.
CALIB_DATA_PATH = "3BASiL/calibration_data/gemma4_e2b_it_final_50k.jsonl"

# Gemma-4 baselines are still TBD — both are 0 so a reported diff prints as "+ppl".
BASELINE_PPL = 0.0
CLEAN_PPL = 0.0

# Sequence layout
# ---------------
# MAX_SEQ_LEN is the per-record padded length. Each record becomes exactly one
# sequence, so every sequence starts with BOS + the chat-template scaffold and
# no mid-document chunk loses that context. 2048 covers ~70% of `final.jsonl`
# records in full; longer records are truncated to the prompt plus whatever
# response prefix fits. This removes the eval unfairness where mid-document
# chunks lacked BOS: the base model lost context while the trained student had
# memorized the chunked positions.
MAX_SEQ_LEN = 2048
SEQ_LEN = MAX_SEQ_LEN  # alias kept for back-compat (eval/train loops use SEQ_LEN)

# Optimisation schedule
# ---------------------
BATCH = 1          # Gemma-4 E2B (4.65B) is ~17× larger than Gemma-3 (270M)
GRAD_ACCUM = 16    # 1 × 16 = 16 effective — keeps optimizer-step cadence similar
EVAL_BATCHES = 0   # 0 = no cap; eval scans every sequence in the eval split
LR = 1e-4

# Corruption / sparsity targets
# -----------------------------
NOISE_SCALE = 0.020264
PRUNE_P = 0.40     # 40% kept (same per-token sparsity target as Gemma-3)
88
+
89
+
90
def _d_ffn_at(layer_idx: int) -> int:
    """Look up the MLP intermediate width of layer ``layer_idx``.

    Layers below ``DOUBLE_WIDE_START`` use ``INTERMEDIATE``; layers at or above
    it use ``INTERMEDIATE_WIDE``.
    """
    if layer_idx < DOUBLE_WIDE_START:
        return INTERMEDIATE
    return INTERMEDIATE_WIDE
93
+
94
+
95
def _prune_k_at(layer_idx: int) -> int:
    """Number of neurons kept active in layer ``layer_idx``.

    This is the layer's FFN width scaled by ``PRUNE_P`` and truncated toward
    zero, so wider layers get a proportionally larger budget.
    """
    layer_width = _d_ffn_at(layer_idx)
    return int(layer_width * PRUNE_P)
98
+
99
+
100
+ # ─────────────────────────── Int4 QAT (Phase I) ────────────────────────
101
+
102
def quantize_int4_groupwise_ste(w, group_size=32):
    """Fake-quantize a 2-D weight to symmetric int4, groupwise, with an STE gradient.

    Each row of ``w`` is cut into consecutive groups of ``group_size`` entries
    along the last dimension. Every group gets one scale, ``max|w| / 7``, and
    its values are rounded onto the integer grid [-7, 7] (-8 is skipped so the
    grid stays symmetric). When the row length is not a multiple of
    ``group_size`` the row is zero-padded for the computation and the padding
    is sliced off again afterwards.

    Args:
        w: tensor of shape [out_dim, in_dim] (the ``nn.Linear.weight`` layout).
        group_size: number of input features sharing one scale.

    Returns:
        A tensor with the shape and dtype of ``w`` whose forward value is the
        dequantized weight and whose gradient w.r.t. ``w`` is the identity
        (straight-through estimator).
    """
    rows, cols = w.shape
    source_dtype = w.dtype

    # Scale/round arithmetic is done in fp32 so bf16 inputs do not lose precision.
    w32 = w.float()

    remainder = cols % group_size
    pad = group_size - remainder if remainder else 0
    padded = F.pad(w32, (0, pad)) if pad else w32

    grouped = padded.view(rows, (cols + pad) // group_size, group_size)
    # One scale per group, shape [rows, n_groups, 1]; the floor avoids 0/0 on all-zero groups.
    scale = grouped.abs().amax(dim=-1, keepdim=True).clamp_min(1e-6) / 7.0
    int_grid = torch.round(grouped / scale).clamp(-7, 7)
    dequant = (int_grid * scale).view(rows, -1)
    if pad:
        dequant = dequant[:, :cols]
    dequant = dequant.to(source_dtype)

    # Straight-through: the detached residual carries the forward value, `w` carries the gradient.
    return w + (dequant - w).detach()
135
+
136
+
137
class Int4QuantLinear(nn.Linear):
    """``nn.Linear`` whose weight is fake-quantized to int4 on every forward pass.

    It subclasses ``nn.Linear`` without adding parameters, so its state-dict
    keys (``weight``, ``bias``) match a plain ``nn.Linear`` and checkpoints are
    interchangeable. The stored floating-point weight is never overwritten;
    quantization is applied only to the tensor used in ``forward``.
    """

    # Group size handed to the quantizer: GGUF Q4_0 / Q4_K block size, chosen
    # to match deploy-time inference kernels.
    _group_size = 32

    def forward(self, x):
        """Apply the linear map using the int4 fake-quantized view of the weight."""
        fake_quant_weight = quantize_int4_groupwise_ste(self.weight, self._group_size)
        return F.linear(x, fake_quant_weight, self.bias)
149
+
150
+
151
def apply_int4_inplace(model, group_size=32,
                       target_substrings=("gate_proj", "up_proj", "down_proj",
                                          "q_proj", "k_proj", "v_proj", "o_proj")):
    """Snap the stored weights of matching ``nn.Linear`` modules onto the int4 grid.

    This is a one-shot deployment simulation: after the call the weights
    themselves hold dequantized int4 values, so no per-forward quantization is
    needed. ``Int4QuantLinear`` modules are skipped.

    Args:
        model: module tree to modify in place.
        group_size: group size passed to ``quantize_int4_groupwise_ste``.
        target_substrings: a module is processed when its dotted name contains
            any of these substrings.

    Returns:
        Number of weights that were overwritten.
    """
    n_snapped = 0
    with torch.no_grad():
        for name, mod in model.named_modules():
            plain_linear = isinstance(mod, nn.Linear) and not isinstance(mod, Int4QuantLinear)
            if not plain_linear:
                continue
            if any(tag in name for tag in target_substrings):
                snapped = quantize_int4_groupwise_ste(mod.weight, group_size).detach()
                mod.weight.data.copy_(snapped)
                n_snapped += 1
    return n_snapped
174
+
175
+
176
def apply_gaussian_noise_inplace(model, noise_scale,
                                 target_substrings=("gate_proj", "up_proj", "down_proj",
                                                    "q_proj", "k_proj", "v_proj", "o_proj"),
                                 seed=0):
    """Add N(0, noise_scale × p.std()) to target Linear weights IN-PLACE.

    Gaussian proxy for quantization noise. For int4 group=32, the analytically
    equivalent noise_scale ≈ 0.129 (from σ_q/σ_w ≈ √((max_abs/7)²/12)/σ_w with
    max_abs ≈ σ_w·√(2·ln group_size)).

    Args:
        model: module tree to modify in place.
        noise_scale: noise standard deviation as a fraction of each weight's own std.
        target_substrings: a module is processed when its dotted name contains
            any of these substrings.
        seed: seed for the noise generator(s).

    Returns:
        Number of weights that received noise.
    """
    # One generator per device a weight actually lives on, each seeded with
    # `seed`. A generator can only drive sampling on its own device, so a single
    # generator tied to a global device fails as soon as a weight sits elsewhere.
    generators = {}
    count = 0
    with torch.no_grad():
        for name, mod in model.named_modules():
            if not isinstance(mod, nn.Linear):
                continue
            if isinstance(mod, Int4QuantLinear):
                # Skip to avoid compounding noise with fake-quant in forward (ambiguous semantics).
                continue
            if not any(t in name for t in target_substrings):
                continue
            w = mod.weight.data
            gen = generators.get(w.device)
            if gen is None:
                gen = torch.Generator(device=w.device)
                gen.manual_seed(seed)
                generators[w.device] = gen
            std_w = w.float().std()
            # Sample in fp32, then cast to the weight dtype for the in-place add.
            noise = torch.randn(w.shape, generator=gen, device=w.device,
                                dtype=torch.float32) * std_w * noise_scale
            w.add_(noise.to(w.dtype))
            count += 1
    return count
204
+
205
+
206
class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank (LoRA) correction.

    forward(x) = base(x) + (alpha / rank) * lora_b(lora_a(x))

    ``lora_a`` is Kaiming-uniform initialised and ``lora_b`` starts at zero, so
    a freshly built wrapper reproduces the base layer's output exactly. The
    base may be any ``nn.Linear``, including ``Int4QuantLinear``.
    """

    def __init__(self, base_linear: nn.Linear, rank: int, alpha: float):
        super().__init__()
        self.base = base_linear
        # Freeze every parameter of the wrapped layer; only the adapters train.
        self.base.requires_grad_(False)

        # Adapters are created on the base weight's device and in its dtype.
        placement = {
            "device": base_linear.weight.device,
            "dtype": base_linear.weight.dtype,
        }
        self.lora_a = nn.Linear(base_linear.in_features, rank, bias=False, **placement)
        self.lora_b = nn.Linear(rank, base_linear.out_features, bias=False, **placement)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_b.weight)

        self.scale = alpha / rank

    def forward(self, x):
        """Return the frozen base output plus the scaled low-rank update."""
        frozen_out = self.base(x)
        low_rank_delta = self.lora_b(self.lora_a(x))
        return frozen_out + low_rank_delta * self.scale
228
+
229
+
230
def wrap_lora(model, rank: int, alpha: float,
              target_substrings=("gate_proj", "up_proj", "down_proj",
                                 "q_proj", "k_proj", "v_proj", "o_proj")):
    """Replace target Linear modules with LoRALinear. Base is frozen; only LoRA A/B train.

    Run AFTER wrap_int4 so the base inside LoRALinear is the int4-quantized Linear.

    The function is idempotent: Linear modules that already sit inside a
    LoRALinear (its ``base``, ``lora_a`` and ``lora_b``) are left alone. Their
    dotted names still contain the target substring of the wrapper, so without
    this check a second call would wrap them again.

    Args:
        model: module tree to modify in place.
        rank: LoRA rank.
        alpha: LoRA alpha; the update is scaled by ``alpha / rank``.
        target_substrings: a module is wrapped when its dotted name contains
            any of these substrings.

    Returns:
        ``(count, n_params)``: number of wrapped modules and total number of
        LoRA parameters added.
    """
    count = 0
    n_params = 0
    # Snapshot the module list: the tree is mutated while we walk it.
    for name, mod in list(model.named_modules()):
        if not isinstance(mod, nn.Linear):
            continue
        if not any(t in name for t in target_substrings):
            continue
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        if isinstance(parent, LoRALinear):
            # Internal layer of an existing wrapper — already handled.
            continue
        new_mod = LoRALinear(mod, rank=rank, alpha=alpha)
        setattr(parent, attr, new_mod)
        n_params += sum(p.numel() for p in new_mod.lora_a.parameters())
        n_params += sum(p.numel() for p in new_mod.lora_b.parameters())
        count += 1
    return count, n_params
255
+
256
+
257
def wrap_int4(model, target_substrings=("gate_proj", "up_proj", "down_proj",
                                        "q_proj", "k_proj", "v_proj", "o_proj")):
    """Swap matching ``nn.Linear`` modules for ``Int4QuantLinear`` in place.

    The replacement reuses the original ``weight`` (and ``bias``, if present)
    Parameter objects rather than copying them, so state-dict keys are
    unchanged and anything holding a reference to those Parameters stays valid.
    Modules that are already ``Int4QuantLinear`` are skipped.

    Args:
        model: module tree to modify in place.
        target_substrings: a module is replaced when its dotted name contains
            any of these substrings.

    Returns:
        Number of modules replaced.
    """
    n_wrapped = 0
    # Snapshot the module list: the tree is mutated while we walk it.
    for name, mod in list(model.named_modules()):
        if not isinstance(mod, nn.Linear) or isinstance(mod, Int4QuantLinear):
            continue
        if not any(tag in name for tag in target_substrings):
            continue

        has_bias = mod.bias is not None
        replacement = Int4QuantLinear(
            mod.in_features,
            mod.out_features,
            bias=has_bias,
            device=mod.weight.device,
            dtype=mod.weight.dtype,
        )
        # Share the underlying Parameters (no copy).
        replacement.weight = mod.weight
        if has_bias:
            replacement.bias = mod.bias

        owner_path, _, leaf = name.rpartition(".")
        owner = model.get_submodule(owner_path) if owner_path else model
        setattr(owner, leaf, replacement)
        n_wrapped += 1
    return n_wrapped
281
+
282
+
283
+ # ────────────────────────────── utilities ──────────────────────────────
284
+
285
def corrupt_model(model, noise_scale=NOISE_SCALE, seed=42):
    """Perturb every parameter of ``model`` in place with scaled Gaussian noise.

    Each parameter ``p`` receives ``randn_like(p) * p.std() * noise_scale``.
    Noise is sampled on the CPU from a generator seeded with ``seed`` and then
    moved to the parameter's device.

    Parameters that cannot be perturbed meaningfully are left untouched:
      * non-floating-point parameters (``torch.randn`` cannot sample them);
      * parameters with fewer than two elements, whose ``std()`` is NaN and
        would otherwise turn the parameter itself into NaN.

    Args:
        model: module whose parameters are modified in place.
        noise_scale: noise standard deviation as a fraction of each parameter's std.
        seed: seed for the CPU noise generator.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)
    with torch.no_grad():
        for p in model.parameters():
            if not p.is_floating_point():
                continue
            # Always draw, even if the sample is discarded below, so the noise
            # applied to all other parameters is the same as before this guard existed.
            noise = torch.randn(p.shape, generator=rng, dtype=p.dtype).to(p.device)
            if p.numel() < 2:
                continue
            p.add_(noise * p.std() * noise_scale)
    print(f" Corrupted model with noise_scale={noise_scale}")
292
+
293
+
294
def load_seqs(tokenizer, split="train", calib_path=None, raw_text=False):
    """Load tokenized, fixed-length sequences from a JSONL calibration file.

    Splitting: the first 80% of records form the ``"train"`` split and the rest
    the eval split. ``split="all"`` returns every record (useful when the train
    and eval files differ, so nothing needs to be withheld).

    Format: ONE sequence per record, padded with the tokenizer's pad id (0 if
    unset) to exactly ``MAX_SEQ_LEN`` tokens. Long records are truncated, never
    chunked: a mid-document chunk would lack the BOS + chat scaffold, which
    made earlier evals unfair to the base model. Records shorter than 32 tokens
    are dropped.

    ``labels`` are already shifted: ``labels[t]`` is the target for the logits
    at position ``t`` (i.e. ``ids[t + 1]``), and ``-100`` marks unscored
    positions (prompt, final token, padding).

    Chat mode (default) expects ``prompt`` and ``response`` fields. Only
    response tokens are scored. Records whose prompt alone fills the window, or
    that end up with no scored position, are skipped.

    Raw-text mode (``raw_text=True``) expects a ``text`` (or ``content``) field
    and applies no chat template; every non-pad position except the last is
    scored.

    Args:
        tokenizer: object exposing ``pad_token_id``, ``encode`` and (chat mode)
            ``apply_chat_template``.
        split: ``"train"``, ``"all"``, or anything else for the eval split.
        calib_path: JSONL path; defaults to ``CALIB_DATA_PATH``.
        raw_text: select raw-text mode.

    Returns:
        List of dicts with ``input_ids`` and ``labels`` long tensors of length
        ``MAX_SEQ_LEN``.
    """
    path = calib_path or CALIB_DATA_PATH
    # JSON is UTF-8 by specification; do not depend on the platform default.
    # Blank lines (e.g. a trailing newline) are not records and are skipped.
    with open(path, encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    if split != "all":
        n_train = int(len(records) * 0.8)
        records = records[:n_train] if split == "train" else records[n_train:]

    pad_id = tokenizer.pad_token_id or 0
    seqs = []

    if raw_text:
        # Pretraining-style: no chat template and no prompt mask.
        for r in records:
            text = r.get("text") or r.get("content") or ""
            if not text:
                continue
            ids = tokenizer.encode(text, add_special_tokens=True)
            if len(ids) < 32:
                continue
            ids = ids[:MAX_SEQ_LEN]
            n = len(ids)
            pad_len = MAX_SEQ_LEN - n
            # labels[t] = ids[t+1] for t in [0, n-2]; labels[n-1] = -100 (boundary);
            # labels[n:] = -100 (pad). Total length = MAX_SEQ_LEN.
            labels_list = ids[1:n] + [-100] * (pad_len + 1)
            assert len(labels_list) == MAX_SEQ_LEN
            seqs.append({
                "input_ids": torch.tensor(ids + [pad_id] * pad_len, dtype=torch.long),
                "labels": torch.tensor(labels_list, dtype=torch.long),
            })
        return seqs

    # Chat-template format: prompt tokens are masked with -100 so only the
    # assistant response is scored (CE training and PPL eval).
    for r in records:
        msgs = [{"role": "user", "content": r["prompt"]},
                {"role": "model", "content": r["response"]}]
        try:
            text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
            user_text = tokenizer.apply_chat_template([msgs[0]], tokenize=False, add_generation_prompt=True)
        except Exception:
            # Deliberate best-effort fallback for tokenizers without a usable
            # chat template: plain "prompt\nresponse" concatenation.
            text = f"{r['prompt']}\n{r['response']}"
            user_text = f"{r['prompt']}\n"
        # NOTE(review): if the template already emits a BOS token, encoding with
        # add_special_tokens=True may prepend a second one — confirm against the
        # tokenizer in use. Both encodes below are affected equally, so n_user
        # stays consistent with ids either way.
        ids = tokenizer.encode(text, add_special_tokens=True)
        # n_user = number of leading tokens of `ids` that are prompt + scaffold;
        # positions [n_user, len(ids)) are the response. This presumes the
        # prompt-only encoding is a token-level prefix of the full encoding.
        n_user = len(tokenizer.encode(user_text, add_special_tokens=True))
        if n_user >= MAX_SEQ_LEN:
            continue  # prompt alone fills the window: nothing left to score
        ids = ids[:MAX_SEQ_LEN]
        n = len(ids)
        if n < 32:
            continue
        if n_user < 1 or n_user > n:
            # Degenerate tokenization (empty prompt encoding, or a prompt
            # encoding longer than the full text): the label layout below would
            # not add up to MAX_SEQ_LEN, and there is no response to score.
            continue
        # Layout: labels[0 .. n_user-2]            = -100 (targets are prompt tokens)
        #         labels[n_user-1 .. n-2]          = ids[n_user .. n-1] (response targets)
        #         labels[n-1 .. MAX_SEQ_LEN-1]     = -100 (boundary + pad)
        n_mask = n_user - 1
        pad_len = MAX_SEQ_LEN - n
        labels_list = [-100] * n_mask + ids[n_user:n] + [-100] * (pad_len + 1)
        assert len(labels_list) == MAX_SEQ_LEN, \
            f"label len {len(labels_list)} != MAX_SEQ_LEN {MAX_SEQ_LEN}"
        # Sanity: at least one scored position (otherwise drop).
        if not any(l != -100 for l in labels_list):
            continue
        seqs.append({
            "input_ids": torch.tensor(ids + [pad_id] * pad_len, dtype=torch.long),
            "labels": torch.tensor(labels_list, dtype=torch.long),
        })
    return seqs
399
+
400
+
401
def kl_loss(s_logits, t_logits, temp=1.0, mask=None):
    """Temperature-scaled distillation loss KL(teacher || student).

    Both logit tensors are softened by ``temp``. The student enters
    ``F.kl_div`` as log-probabilities (input) and the teacher as probabilities
    (target), so the quantity is Σ p_teacher · (log p_teacher − log p_student),
    i.e. the forward KL with the teacher as reference distribution. The result
    is multiplied by ``temp ** 2``.

    Normalisation: the vocab-summed KL is summed over all scored positions and
    divided by the size of dim 0 (``reduction="batchmean"`` semantics) — it is
    NOT averaged over sequence positions.

    Args:
        s_logits: student logits, ``[B, T, V]``.
        t_logits: teacher logits, same shape.
        temp: softmax temperature.
        mask: optional ``[B, T]`` mask of positions to score. Any dtype is
            accepted and interpreted as truthy/falsy; positions outside the
            mask contribute 0, and the divisor stays ``B``.

    Returns:
        Scalar loss tensor.
    """
    s_log = F.log_softmax(s_logits / temp, dim=-1)
    t_prob = F.softmax(t_logits / temp, dim=-1)
    if mask is None:
        return F.kl_div(s_log, t_prob, reduction="batchmean") * (temp ** 2)
    # Per-(B, T) vocab-summed KL, then sum over the masked subset only.
    elem = F.kl_div(s_log, t_prob, reduction="none").sum(dim=-1)  # [B, T]
    # Coerce to bool: an integer 0/1 mask would otherwise be treated as row
    # indices rather than as a selection mask.
    keep = mask.to(torch.bool)
    return elem[keep].sum() / s_logits.shape[0] * (temp ** 2)
415
+
416
+
417
def ce_loss(s_logits, labels):
    """Mean token-level cross-entropy between logits and integer labels.

    All leading dimensions are flattened so each position is scored
    independently; positions labelled ``-100`` are ignored and do not count
    toward the mean.
    """
    vocab_size = s_logits.size(-1)
    flat_logits = s_logits.reshape(-1, vocab_size)
    flat_targets = labels.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=-100)
421
+
422
+
423
@torch.no_grad()
def eval_ppl(model, tokenizer, calib_path=None, max_seqs=None):
    """Perplexity of ``model`` over the eval split of a calibration file.

    Sequences come from ``load_seqs(tokenizer, "eval", ...)`` and are scored
    one at a time. Only positions whose label is not ``-100`` contribute.

    Args:
        model: callable mapping an ``input_ids`` tensor to logits.
        tokenizer: forwarded to ``load_seqs``.
        calib_path: forwarded to ``load_seqs``.
        max_seqs: if a positive number, only the first ``max_seqs`` sequences
            (in load order, hence deterministic) are scored; ``None`` or a
            non-positive value means no cap.

    Returns:
        ``exp(total NLL / scored tokens)``, or ``inf`` when nothing was scored.

    Note:
        The model is switched to eval mode and is NOT switched back; callers
        that continue training must restore train mode themselves.
    """
    seqs = load_seqs(tokenizer, "eval", calib_path=calib_path)
    if max_seqs is not None and max_seqs > 0:
        seqs = seqs[:max_seqs]

    model.eval()
    nll_sum = 0.0
    n_scored = 0
    for step, batch in enumerate(torch.utils.data.DataLoader(seqs, batch_size=1)):
        # EVAL_BATCHES == 0 disables the cap entirely.
        if EVAL_BATCHES and step >= EVAL_BATCHES:
            break
        ids = batch["input_ids"].to(DEVICE)
        # Labels are already next-token aligned; the final position is dropped
        # from both labels and logits.
        targets = batch["labels"][:, :-1].to(DEVICE)
        logits = model(ids)
        vocab_size = logits.size(-1)
        batch_nll = F.cross_entropy(
            logits[:, :-1].reshape(-1, vocab_size),
            targets.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        nll_sum += batch_nll.item()
        n_scored += (targets != -100).sum().item()

    if n_scored > 0:
        return math.exp(nll_sum / n_scored)
    return float("inf")
445
+
446
+
447
+ # ──────────────────────── initialization helpers ────────────────────────
448
+
449
def compute_taylor_saliency(model, tokenizer, n_batches=8, calib_path=None):
    """First-order Taylor saliency of every FFN neuron.

    For each layer, the output ``h`` of ``mlp.gate_proj`` is captured with a
    forward hook and ``|h * dL/dh|`` is averaged over batch and sequence
    positions, then over the processed batches. ``L`` is the mean
    cross-entropy on the train split.

    Gradients are temporarily enabled on all model parameters. On exit —
    including when an exception is raised — the forward hooks are removed, the
    previous ``requires_grad`` flags are restored and gradients are cleared.
    The model is put in eval mode and left there.

    Args:
        model: model exposing ``layers[i].mlp.gate_proj`` and returning logits
            when called on ``input_ids``.
        tokenizer: forwarded to ``load_seqs``.
        n_batches: maximum number of batches (of size ``BATCH``) to accumulate.
        calib_path: forwarded to ``load_seqs``.

    Returns:
        List of ``N_LAYERS`` tensors; entry ``i`` has one score per neuron of layer ``i``.
    """
    model.eval()
    # Snapshot freeze state; temporarily unfreeze for the Taylor computation.
    prev_grad = [p.requires_grad for p in model.parameters()]
    for p in model.parameters():
        p.requires_grad_(True)
    hooks = []
    try:
        scores = [torch.zeros(_d_ffn_at(i), device=DEVICE) for i in range(N_LAYERS)]
        seqs = load_seqs(tokenizer, "train", calib_path=calib_path)[:n_batches * BATCH]
        loader = torch.utils.data.DataLoader(seqs, batch_size=BATCH)
        caches = [None] * N_LAYERS

        def make_hook(i):
            def hook(mod, inp, out):
                caches[i] = out
                # `out` is a non-leaf tensor; ask autograd to keep its gradient.
                out.retain_grad()
            return hook

        for i, layer in enumerate(model.layers):
            hooks.append(layer.mlp.gate_proj.register_forward_hook(make_hook(i)))

        n_seen = 0
        for batch in loader:
            ids = batch["input_ids"].to(DEVICE)
            labels = batch["labels"][:, :-1].to(DEVICE)
            logits = model(ids)
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels.reshape(-1), ignore_index=-100)
            loss.backward()
            for i in range(N_LAYERS):
                if caches[i] is not None and caches[i].grad is not None:
                    s = (caches[i].detach() * caches[i].grad.detach()).abs().mean(dim=(0, 1))
                    scores[i] += s
            model.zero_grad(set_to_none=True)
            n_seen += 1
            if n_seen >= n_batches:
                break

        return [s.detach() / max(n_seen, 1) for s in scores]
    finally:
        # Remove hooks here so a failure mid-loop cannot leave them attached to
        # the model, where they would keep capturing activations on every later forward.
        for h in hooks:
            h.remove()
        # Restore original freeze state.
        for p, g in zip(model.parameters(), prev_grad):
            p.requires_grad_(g)
        model.zero_grad(set_to_none=True)
495
+
496
+
497
def init_assignment_logits(init_mode, K, K_const, taylor_scores=None, core_frac=0.5):
    """Build the initial neuron→expert assignment logits for every layer.

    Modes:
      - ``"random"``: ``A ~ N(0, 0.5²)``. std=0.5 gives softmax(A/1.0) a mild
        bias but softmax(A/0.01) is near one-hot, leaving room for A to grow
        during the τ-anneal.
      - ``"taylor"``: neurons are ranked by saliency (highest first). All
        logits start at -2.0. The top ``prune_k`` neurons of the layer each get
        +2.0 in exactly one expert column; the remaining neurons get 0.0 in one
        column, round-robin over all K experts. The softer ±2.0 init leaves
        dynamic range for the τ-anneal 1.0→0.01.
          * With ``K_const > 0`` the top ``core_frac`` share of those
            ``prune_k`` neurons is dealt round-robin to the always-on experts
            and the rest round-robin to the specialists. If there are no
            specialists (``K == K_const``) the rest also goes to the always-on
            experts.
          * With ``K_const == 0`` all ``prune_k`` neurons are dealt round-robin
            over the K experts.
      - ``"em"`` / ``"kmeans"``: not implemented.

    Args:
        init_mode: one of the modes above.
        K: total number of experts.
        K_const: number of always-on experts (``0 <= K_const <= K``).
        taylor_scores: per-layer saliency tensors; required for ``"taylor"``.
        core_frac: share of the ``prune_k`` active neurons assigned to the
            always-on experts.

    Returns:
        List of ``N_LAYERS`` float tensors; entry ``i`` has shape ``[D_FFN_i, K]``.

    Raises:
        ValueError: for an unknown ``init_mode`` or an invalid ``K_const``.
        NotImplementedError: for ``"em"`` / ``"kmeans"``.
    """
    if not 0 <= K_const <= K:
        raise ValueError(f"K_const must satisfy 0 <= K_const <= K, got K={K}, K_const={K_const}")
    K_spec = K - K_const

    def round_robin(n, modulus):
        # Column index for the r-th neuron of a group: r % modulus.
        return torch.arange(n) % modulus

    As = []
    for i in range(N_LAYERS):
        d_ffn_i = _d_ffn_at(i)
        prune_k_i = _prune_k_at(i)
        if init_mode == "random":
            A = torch.randn(d_ffn_i, K) * 0.5
        elif init_mode == "taylor":
            assert taylor_scores is not None, "taylor init requires scores"
            scores = taylor_scores[i].cpu()           # [D_FFN_i]
            order = scores.argsort(descending=True)   # high-saliency first
            A = torch.full((d_ffn_i, K), -2.0)
            # `order` is a permutation, so every indexed assignment below
            # touches each row at most once and is order-independent.
            if K_const > 0:
                n_core = int(prune_k_i * core_frac)
                core = order[:n_core]
                spec = order[n_core:prune_k_i]
                A[core, round_robin(len(core), K_const)] = 2.0
                if K_spec > 0:
                    A[spec, K_const + round_robin(len(spec), K_spec)] = 2.0
                else:
                    # No specialists exist: a modulo by K_spec would divide by
                    # zero, so the remaining active neurons join the always-on experts.
                    A[spec, round_robin(len(spec), K_const)] = 2.0
            else:
                active = order[:prune_k_i]
                A[active, round_robin(len(active), K)] = 2.0
            # Low-saliency neurons: uniform mild bias, τ-anneal drives assignment.
            rest = order[prune_k_i:]
            A[rest, round_robin(len(rest), K)] = 0.0
        elif init_mode == "em" or init_mode == "kmeans":
            raise NotImplementedError(
                f"init={init_mode} requires a precomputed file; use --init random|taylor for now")
        else:
            raise ValueError(f"Unknown init_mode: {init_mode}")
        As.append(A)
    return As
535
+
536
+
537
def init_router_weights(init_mode, model, init_As, K, K_const, scale_multiplier=1.0):
    """Build per-layer initial router matrices W_r of shape [D_MODEL, K_spec].

    ``K_spec = K - K_const`` is the number of routed (specialist) experts. When
    it is zero there is no router and a list of ``None`` is returned without
    looking at ``init_mode``.

    Modes:
      - ``"random"``: ``None`` for every layer, i.e. the consumer falls back to
        its own default random init.
      - ``"zero"``: all-zero W_r (uniform routing at init).
      - ``"centroid"``: column k is the mean of the ``gate_proj`` weight rows
        of the neurons whose argmax assignment in ``init_As`` is specialist k;
        each column is then L2-normalised and scaled to norm 0.02.
      - ``"scaled_centroid"``: the same centroids, NOT normalised, multiplied
        by ``scale_multiplier`` so the router inherits the base weight scale.

    A specialist with no assigned neuron falls back to the mean over all
    ``gate_proj`` rows.

    Raises:
        ValueError: for an unknown ``init_mode`` (only when ``K_spec != 0``).
    """
    K_spec = K - K_const
    if K_spec == 0:
        return [None] * N_LAYERS

    def specialist_centroids(layer_idx):
        """Per-specialist mean gate_proj row for one layer, as [D_MODEL, K_spec]."""
        gate_w = model.layers[layer_idx].mlp.gate_proj.weight.detach().float().cpu()  # [D_FFN, D_MODEL]
        owner = init_As[layer_idx].cpu().argmax(dim=-1)  # [D_FFN], values in [0, K)
        centroids = torch.zeros(D_MODEL, K_spec)
        for k in range(K_spec):
            # Specialist k sits at index K_const + k in the full expert space.
            members = owner == (K_const + k)
            source = gate_w[members] if members.any() else gate_w
            centroids[:, k] = source.mean(dim=0)
        return centroids

    W_rs = []
    for i in range(N_LAYERS):
        if init_mode == "random":
            layer_init = None
        elif init_mode == "zero":
            layer_init = torch.zeros(D_MODEL, K_spec)
        elif init_mode == "centroid":
            # Unit direction per column, small magnitude comparable to the default init.
            layer_init = F.normalize(specialist_centroids(i), dim=0) * 0.02
        elif init_mode == "scaled_centroid":
            # Magnitude inherited from the base weights; scale_multiplier is a free knob.
            layer_init = specialist_centroids(i) * scale_multiplier
        else:
            raise ValueError(f"Unknown W_r init_mode: {init_mode}")
        W_rs.append(layer_init)
    return W_rs
595
+
596
+
597
+ # ───────────────────────── MoE MLP module ─────────────────────────────
598
+
599
+ class MoEMLP(nn.Module):
600
+ """
601
+ K experts via per-neuron softmax assignment (MECE mode) OR
602
+ independent sigmoid masks with orthogonality loss (sigmoid_ortho mode).
603
+
604
+ K_const always-on experts always apply; K_spec routed experts selected
605
+ via top-K_active_spec routing.
606
+ """
607
+ def __init__(self, base_mlp, K, K_const, K_active_spec, mece_mode, init_A,
608
+ noise_std=1.0, freeze_base=True, init_W_r=None):
609
+ super().__init__()
610
+ self.gate_proj = base_mlp.gate_proj
611
+ self.up_proj = base_mlp.up_proj
612
+ self.down_proj = base_mlp.down_proj
613
+ if freeze_base:
614
+ for p in self.gate_proj.parameters(): p.requires_grad_(False)
615
+ for p in self.up_proj.parameters(): p.requires_grad_(False)
616
+ for p in self.down_proj.parameters(): p.requires_grad_(False)
617
+
618
+ self.K = K
619
+ self.K_const = K_const
620
+ self.K_spec = K - K_const
621
+ self.K_active_spec = K_active_spec # # of specialist experts fired per token
622
+ self.mece_mode = mece_mode # "softmax" | "sigmoid_ortho"
623
+ self.noise_std = noise_std
624
+ self.tau = 1.0
625
+
626
+ # Assignment logits
627
+ self.A = nn.Parameter(init_A.to(DEVICE).float())
628
+
629
+ if self.K_spec > 0:
630
+ if init_W_r is not None:
631
+ self.W_r = nn.Parameter(init_W_r.to(DEVICE).float())
632
+ else:
633
+ self.W_r = nn.Parameter(torch.zeros(D_MODEL, self.K_spec, device=DEVICE, dtype=torch.float32))
634
+ nn.init.normal_(self.W_r, std=0.02)
635
+ else:
636
+ self.register_parameter("W_r", None)
637
+
638
+ # Diagnostics cache (populated during training forward)
639
+ self._last_logits = None
640
+ self._last_top_idx = None
641
+
642
+ def _expert_masks(self):
643
+ """Return [K, D_FFN] — each expert's soft mask."""
644
+ if self.mece_mode == "softmax":
645
+ probs = F.softmax(self.A / max(self.tau, 1e-3), dim=-1) # [D_FFN, K]
646
+ return probs.T.contiguous() # [K, D_FFN]
647
+ elif self.mece_mode == "sigmoid_ortho":
648
+ return torch.sigmoid(self.A / max(self.tau, 1e-3)).T.contiguous()
649
+ else:
650
+ raise ValueError(self.mece_mode)
651
+
652
def forward(self, x):
    """Gated MLP forward pass with a per-token soft mask over FFN neurons.

    The hidden activation ``gelu(gate_proj(x)) * up_proj(x)`` is multiplied
    elementwise by a mask built from two parts before ``down_proj``:

    * the sum of the first ``K_const`` expert masks (always applied), and
    * when ``K_spec > 0``, a softmax-weighted mix of the top-k specialist
      masks chosen per token by the router ``x @ W_r``.

    Side effects: when specialists exist, the router logits (including any
    training noise) and the selected indices are cached on
    ``_last_logits`` / ``_last_top_idx`` for ``aux_loss``.
    """
    hidden = F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)  # [B, T, D_FFN]
    expert_masks = self._expert_masks()                                       # [K, D_FFN]

    # Always-on experts: one shared mask, broadcast over batch and time.
    if self.K_const > 0:
        always_on = expert_masks[:self.K_const].sum(dim=0)                    # [D_FFN]
    else:
        # A's first dimension is this layer's D_FFN.
        always_on = torch.zeros(self.A.shape[0], device=x.device, dtype=torch.float32)
    always_on = always_on.view(1, 1, -1)

    if self.K_spec <= 0:
        # No specialists: the constant mask is the whole mask.
        token_mask = always_on.expand_as(hidden)
    else:
        router_logits = x.to(torch.float32) @ self.W_r                        # [B, T, K_spec]
        if self.training and self.noise_std > 0:
            # Exploration noise, scaled down by sqrt(K_spec).
            jitter_scale = self.noise_std / (self.K_spec ** 0.5)
            router_logits = router_logits + torch.randn_like(router_logits) * jitter_scale
        self._last_logits = router_logits

        n_select = min(self.K_active_spec, self.K_spec)
        picked_vals, picked_idx = router_logits.topk(n_select, dim=-1)        # [B, T, n_select]
        self._last_top_idx = picked_idx
        mix_weights = F.softmax(picked_vals, dim=-1)                          # renormalized over the picks

        specialist_masks = expert_masks[self.K_const:]                        # [K_spec, D_FFN]
        picked_masks = specialist_masks[picked_idx]                           # [B, T, n_select, D_FFN]
        routed = (picked_masks * mix_weights.unsqueeze(-1)).sum(dim=-2)       # [B, T, D_FFN]
        token_mask = always_on + routed

    return self.down_proj(hidden * token_mask.to(x.dtype))
688
+
689
def aux_loss(self, alpha_b=0.01, alpha_z=0.001):
    """Router auxiliary loss: load-balance term plus router z-loss.

    Operates on the router logits and top-k indices cached by the most
    recent ``forward`` call. Only the routed specialists participate; the
    always-on experts are not part of the router.

    Balance (Switch/GShard top-k generalization):
        f_k = fraction of routing slots assigned to specialist k
            = (tokens_selecting_k / total_tokens) / k_act
        p_k = mean softmax probability for specialist k
        L_b = alpha_b * K_spec * Σ_k f_k p_k
    Under perfectly uniform routing (f_k = p_k = 1/K_spec) this evaluates
    to alpha_b.

    z-loss:
        L_z = alpha_z * mean(logsumexp(logits) ** 2)

    Args:
        alpha_b: weight of the balance term.
        alpha_z: weight of the z-loss term.

    Returns:
        Scalar tensor ``L_b + L_z``; a zero scalar if there are no
        specialists or ``forward`` has not cached any logits yet.
    """
    if self.K_spec == 0 or self._last_logits is None:
        return torch.tensor(0.0, device=DEVICE)

    logits = self._last_logits                                    # [B, T, K_spec]
    top_idx = self._last_top_idx                                  # [B, T, k_act]
    probs = F.softmax(logits, dim=-1)

    # forward() selects min(K_active_spec, K_spec) experts per token, so the
    # number of routing slots is the cached index width, not K_active_spec.
    # Normalizing by the real width keeps Σ_k f_k = 1 even when
    # K_active_spec exceeds K_spec.
    k_act = max(top_idx.shape[-1], 1)
    hot = F.one_hot(top_idx, self.K_spec).float().sum(dim=-2)     # [B, T, K_spec]
    f_k = hot.mean(dim=(0, 1)) / k_act                            # [K_spec]
    p_k = probs.mean(dim=(0, 1))                                  # [K_spec]
    balance = alpha_b * self.K_spec * (f_k * p_k).sum()

    lse = torch.logsumexp(logits, dim=-1)                         # [B, T]
    z_loss = alpha_z * (lse ** 2).mean()
    return balance + z_loss
712
+
713
def orth_loss(self):
    """Overlap penalty between expert masks (``sigmoid_ortho`` mode only).

    Each expert's mask is L2-normalized, the pairwise cosine-similarity
    (Gram) matrix is formed, the identity is subtracted to remove the
    self-similarity diagonal, and the squared residual is averaged over the
    K * (K - 1) ordered expert pairs. Returns a zero scalar in any other
    mode.
    """
    if self.mece_mode != "sigmoid_ortho":
        return torch.tensor(0.0, device=DEVICE)

    unit_rows = F.normalize(self._expert_masks(), dim=-1)    # [K, D_FFN], unit-norm rows
    similarity = unit_rows @ unit_rows.T                     # [K, K]
    n_experts = similarity.size(0)
    off_diag = similarity - torch.eye(n_experts, device=similarity.device)
    n_pairs = n_experts * (n_experts - 1)
    # The epsilon keeps the K == 1 case (zero pairs) from dividing by zero.
    return (off_diag ** 2).sum() / (n_pairs + 1e-8)
723
+
724
+
725
+ # ───────────────────────────── training ─────────────────────────────
726
+
727
def get_tau(step, max_steps, tau_start, tau_end, hold_frac=0.2):
    """Temperature schedule: linear anneal, then hold.

    Tau moves linearly from ``tau_start`` to ``tau_end`` over the first
    ``(1 - hold_frac)`` share of ``max_steps`` and stays at ``tau_end`` for
    the remainder. The hold phase avoids tau-anneal shock by giving the
    model time to adapt to the hard masks.

    Args:
        step: current optimizer step (0-based).
        max_steps: total number of training steps.
        tau_start: temperature at step 0.
        tau_end: temperature reached on the last anneal step and held after.
        hold_frac: fraction of ``max_steps`` spent holding at ``tau_end``.

    Returns:
        The temperature to use at ``step``.
    """
    # Always anneal over at least one step so the division below is safe.
    n_anneal = max(1, int(max_steps * (1 - hold_frac)))
    if step < n_anneal:
        # Progress hits exactly 1.0 on the final anneal step (n_anneal - 1).
        progress = step / max(1, n_anneal - 1)
        return tau_start + progress * (tau_end - tau_start)
    return tau_end
735
+
736
+
737
def install_moe(model, K, K_const, K_active_spec, mece_mode, init_As, noise_std,
                freeze_base=True, init_W_rs=None):
    """Swap every layer's ``mlp`` for a ``MoEMLP`` wrapper, in place.

    For each of the ``N_LAYERS`` entries in ``model.layers`` the existing
    ``mlp`` is handed to ``MoEMLP`` as ``base_mlp`` and the resulting module
    is assigned back onto the layer.

    Args:
        model: object exposing an indexable ``layers`` whose items have an
            ``mlp`` attribute.
        K, K_const, K_active_spec, mece_mode, noise_std, freeze_base:
            forwarded unchanged to every ``MoEMLP``.
        init_As: per-layer initial assignment logits, indexed by layer.
        init_W_rs: optional per-layer initial router weights; when omitted,
            ``None`` is passed for every layer.

    Returns:
        The list of newly installed ``MoEMLP`` modules, in layer order.
    """
    router_inits = [None] * N_LAYERS if init_W_rs is None else init_W_rs

    installed = []
    for layer_idx in range(N_LAYERS):
        layer = model.layers[layer_idx]
        moe = MoEMLP(
            base_mlp=layer.mlp,
            K=K,
            K_const=K_const,
            K_active_spec=K_active_spec,
            mece_mode=mece_mode,
            init_A=init_As[layer_idx],
            noise_std=noise_std,
            freeze_base=freeze_base,
            init_W_r=router_inits[layer_idx],
        )
        layer.mlp = moe
        installed.append(moe)
    return installed
751
+
752
+
753
+ def main():
754
+ parser = argparse.ArgumentParser()
755
+ parser.add_argument("--phase", type=str, default="A1")
756
+ parser.add_argument("--K", type=int, default=4)
757
+ parser.add_argument("--K_const", type=int, default=0)
758
+ parser.add_argument("--K_active_spec", type=int, default=-1,
759
+ help="# specialists fired per token. Default = round(K_spec * 0.40 / (1 - K_const/K * 0.40)); falls back to max(1, round(K_spec*0.5))")
760
+ parser.add_argument("--loss", choices=["kl", "ce"], default="kl")
761
+ parser.add_argument("--init", choices=["random", "taylor", "em", "kmeans"], default="random")
762
+ parser.add_argument("--core_frac", type=float, default=0.5,
763
+ help="Fraction of PRUNE_K active neurons to concentrate in K_const core (Taylor init only)")
764
+ parser.add_argument("--mece_mode", choices=["softmax", "sigmoid_ortho"], default="softmax")
765
+ parser.add_argument("--tau_start", type=float, default=1.0)
766
+ parser.add_argument("--tau_end", type=float, default=0.01)
767
+ parser.add_argument("--tau_hold_frac", type=float, default=0.2,
768
+ help="Fraction of max_steps to HOLD at tau_end after annealing. Default 0.2 = "
769
+ "anneal over first 80%, hold last 20%. For long continuation runs, set "
770
+ "to e.g. 0.857 to give just 5k anneal steps and 30k hard-tau steps "
771
+ "(out of 35k total).")
772
+ parser.add_argument("--max_steps", type=int, default=2000)
773
+ parser.add_argument("--lr", type=float, default=LR)
774
+ parser.add_argument("--alpha_b", type=float, default=0.01)
775
+ parser.add_argument("--alpha_z", type=float, default=0.001)
776
+ parser.add_argument("--alpha_orth", type=float, default=0.01)
777
+ parser.add_argument("--noise_std", type=float, default=1.0)
778
+ parser.add_argument("--eval_every", type=int, default=200)
779
+ parser.add_argument("--optimizer", choices=["adamw", "adamw8bit"], default="adamw",
780
+ help="adamw8bit uses bitsandbytes 8-bit optimizer — saves ~28GB "
781
+ "optimizer state on 4.65B model, required to --unfreeze_base on H100 80GB")
782
+ parser.add_argument("--freeze_embeddings", action="store_true",
783
+ help="Freeze embed_tokens (+tied lm_head) and embed_tokens_per_layer. "
784
+ "For Gemma-4 E2B these are 2.75B of 5.1B params and embed_tokens_per_layer "
785
+ "is a single 2.35B-element tensor that exceeds bnb 8bit kernel limits. "
786
+ "Freezing them makes --unfreeze_base feasible with plain fp32 AdamW on "
787
+ "the remaining ~2.35B params (~19GB state, fits 80GB).")
788
+ parser.add_argument("--use_lora", action="store_true",
789
+ help="Wrap target Linears with LoRALinear (frozen base + trainable rank-r delta). "
790
+ "Use INSTEAD of full base fine-tuning. Combines naturally with --int4_qat: "
791
+ "LoRA wraps the int4-quantized Linear. Trains only ~10-30M LoRA params + MoE.")
792
+ parser.add_argument("--lora_rank", type=int, default=16,
793
+ help="LoRA rank (low-dim adapter dim). Typical: 8 (less capacity, less overfit), "
794
+ "16 (default), 32 (more capacity).")
795
+ parser.add_argument("--lora_alpha", type=float, default=16.0,
796
+ help="LoRA scaling factor; effective scale = alpha/rank. Default 16/16 = 1.0.")
797
+ parser.add_argument("--W_r_init", choices=["random", "zero", "centroid", "scaled_centroid"], default="random",
798
+ help="Router W_r init: random (default), zero (uniform routing), "
799
+ "centroid (mean W_gate row per Taylor-assigned expert, L2-normalized to 0.02 mag), "
800
+ "scaled_centroid (mean W_gate row per expert, NOT normalized, scaled by --W_r_scale).")
801
+ parser.add_argument("--W_r_scale", type=float, default=1.0,
802
+ help="Multiplier for scaled_centroid init. W_r = scale × mean(W_gate per expert). "
803
+ "Values ~0.1–10 control how 'loud' the router is relative to base weight scale.")
804
+ parser.add_argument("--W_r_lr_mult", type=float, default=1.0,
805
+ help="Learning rate multiplier for router W_r params (and A logits). "
806
+ "E.g., 5.0 trains the router 5× faster than base weights. The router "
807
+ "is ~0.03% of total params and has a specific job — higher LR can "
808
+ "help it converge quickly without destabilizing base-weight training.")
809
+ parser.add_argument("--freeze_A", action="store_true",
810
+ help="Freeze assignment logits A (only router + optionally base train)")
811
+ parser.add_argument("--unfreeze_base", action="store_true",
812
+ help="Train base weights (W_gate/W_up/W_down, attn, norms). Default freezes them.")
813
+ parser.add_argument("--save_checkpoint", type=str, default="",
814
+ help="Save final student state_dict to this path (.pt)")
815
+ parser.add_argument("--save_every", type=int, default=0,
816
+ help="If >0 and --save_checkpoint set, also save an intermediate ckpt every "
817
+ "N max_steps. Filename: <save_checkpoint stem>_step<N>.pt. Use for long "
818
+ "runs where you may want to early-stop without losing progress.")
819
+ parser.add_argument("--shuffle_seed", type=int, default=0,
820
+ help="Seed for the dataloader shuffle. Same seed → same record order. Use a "
821
+ "different seed in continuation runs to expose the model to a new ordering "
822
+ "of the dataset.")
823
+ parser.add_argument("--data_skip", type=int, default=0,
824
+ help="Discard first N samples of the (shuffled) dataloader stream before "
825
+ "training. Combine with same --shuffle_seed as a previous run to start "
826
+ "where it left off — model sees fresh records first.")
827
+ parser.add_argument("--load_checkpoint", type=str, default="",
828
+ help="Load student state_dict from this path BEFORE training (warm-start). "
829
+ "Must be from a prior rung6_moe.py run with matching architecture.")
830
+ parser.add_argument("--calib_path", type=str, default=CALIB_DATA_PATH,
831
+ help="Path to JSONL calibration data for TRAINING. Default: final.jsonl (640 records). "
832
+ "Use bulk.jsonl (~12k records) or trajectories_25k.jsonl (25k) for more data.")
833
+ parser.add_argument("--eval_calib_path", type=str, default="",
834
+ help="Path to JSONL calibration data for EVAL. Default: same as --calib_path. "
835
+ "Set to final.jsonl for consistent eval across curriculum phases.")
836
+ parser.add_argument("--int4_qat", action="store_true",
837
+ help="Enable int4 QAT: wrap target Linears (MLP + attention) with Int4QuantLinear "
838
+ "so forward uses fake-quantized weights (groupwise STE, group_size=128).")
839
+ parser.add_argument("--int4_group_size", type=int, default=32,
840
+ help="Groupwise int4 quant group size. Default 32 matches GGUF Q4_0/Q4_K deploy block size. "
841
+ "128 is another common choice (AWQ-style) with less storage overhead but larger quant error.")
842
+ parser.add_argument("--eval_only", action="store_true",
843
+ help="Skip training; just eval after setup (init + optional checkpoint load + optional "
844
+ "int4 wrap). Useful for measuring untrained-int4 baseline or a specific checkpoint's "
845
+ "eval PPL at tau_end without further optimization.")
846
+ # Knowledge preservation fixes
847
+ parser.add_argument("--diverse_calib_path", type=str, default="",
848
+ help="Path to JSONL (raw 'text' field) for periodic KL-to-base preservation batches. "
849
+ "Usually wikitext or similar pretraining-distribution text.")
850
+ parser.add_argument("--diverse_every_n", type=int, default=4,
851
+ help="Every N optimizer steps, replace the normal CE batch with a KL-to-teacher pass "
852
+ "on diverse data. Default 4 = ~25%% of batches.")
853
+ parser.add_argument("--main_kl_temp", type=float, default=1.0,
854
+ help="Softmax temperature for the MAIN loss when --loss kl. "
855
+ "T>1 softens teacher's argmax commitment. Useful for knowledge "
856
+ "retention but too high (>5) can destabilize Gemma-4 training "
857
+ "due to low teacher entropy.")
858
+ parser.add_argument("--kl_base_lambda", type=float, default=0.5,
859
+ help="Scalar on the diverse-batch KL-to-teacher loss.")
860
+ parser.add_argument("--kl_base_temp", type=float, default=2.0,
861
+ help="Softmax temperature for KL-to-teacher. >1 softens distributions, recovering "
862
+ "tail mass — important when teacher entropy is low (e.g., Gemma-4 E2B). "
863
+ "Try 2-5 for Gemma-3, 5-10 for Gemma-4.")
864
+ parser.add_argument("--w_drift_lambda", type=float, default=0.0,
865
+ help="L2-to-base weight-drift penalty: λ × Σ ‖W_student − W_teacher‖² over trainable "
866
+ "base weights (excluding MoE .A and .W_r). Prevents catastrophic forgetting by "
867
+ "anchoring weights to base. Typical: 1e-6 to 1e-4.")
868
+ parser.add_argument("--real_int4_inplace", action="store_true",
869
+ help="After load_checkpoint, snap target Linear weights to int4 grid in-place (no STE, "
870
+ "no runtime overhead). Simulates deployment — forward uses plain nn.Linear with "
871
+ "already-quantized weights. Combine with --eval_only for the real-int4 benchmark.")
872
+ parser.add_argument("--gaussian_noise_scale", type=float, default=0.0,
873
+ help="Add N(0, scale × p.std()) Gaussian noise to target Linear weights in-place. "
874
+ "Default 0.0 = disabled. 0.129 is the analytical int4 group=32 equivalent.")
875
+ # ── Activation-MSE recovery (mechanism A: generic per-module) ──
876
+ parser.add_argument("--recovery_steps", type=int, default=0,
877
+ help="If >0: run module_recovery.recover_modules_sequentially on every per-layer "
878
+ "MLP after install_moe + wrap_int4 (+ wrap_lora) and BEFORE main training. "
879
+ "Default 0 = disabled.")
880
+ parser.add_argument("--recovery_lr", type=float, default=1e-4,
881
+ help="LR for the generic recovery AdamW (only A and W_r receive grad — base "
882
+ "and LoRA params are not in the trainable set during recovery).")
883
+ parser.add_argument("--recovery_n_batches", type=int, default=8,
884
+ help="# calibration batches sampled from --calib_path for generic recovery.")
885
+ # ── Activation-MSE recovery (mechanism B: specialized MoE per-layer) ──
886
+ parser.add_argument("--moe_recovery_seconds_per_layer", type=float, default=0.0,
887
+ help="If >0: run finetune_moe_per_layer for this many wall-clock seconds per "
888
+ "MLP layer. Pre-caches teacher (X, Y), optimizes A and W_r only. "
889
+ "Default 0 = disabled.")
890
+ parser.add_argument("--moe_recovery_lr", type=float, default=1e-3,
891
+ help="LR for the specialized per-layer recovery (A and W_r are tiny — 1e-3 is fine).")
892
+ parser.add_argument("--moe_recovery_n_calib_records", type=int, default=32,
893
+ help="# calibration records (single-sequence, len MAX_SEQ_LEN) cached for the "
894
+ "specialized recovery. Memory ≈ 2 × N × MAX_SEQ_LEN × hidden × 2 bytes.")
895
+ parser.add_argument("--moe_recovery_use_student_inputs", type=lambda s: s.lower() in ("1", "true", "yes"),
896
+ default=True,
897
+ help="If True (default), refresh student X between layers so each layer sees "
898
+ "error-corrected upstream activations. If False, use teacher X throughout "
899
+ "(matches Sunday's original pipeline).")
900
+ parser.add_argument("--moe_recovery_optimizer", choices=["adam", "muon"], default="adam",
901
+ help="Specialized recovery optimizer. 'muon' uses muon.MuonWithAdam (matrix-aware "
902
+ "Newton-Schulz). A and W_r are both 2D so Muon-eligible.")
903
+ parser.add_argument("--moe_recovery_noise_std", type=float, default=-1.0,
904
+ help="Override MoEMLP router noise during recovery. -1.0 = keep current "
905
+ "MoEMLP setting (default 1.0 from MoE training convention). 0.0 = "
906
+ "deterministic routing for clean per-step loss + meaningful best-state "
907
+ "tracking + train/deploy match. Higher = more router exploration.")
908
+ args = parser.parse_args()
909
+ if not args.eval_calib_path:
910
+ args.eval_calib_path = args.calib_path
911
+
912
+ K_spec = args.K - args.K_const
913
+ assert K_spec >= 0 and args.K_const >= 0 and args.K >= 1
914
+ if args.K_active_spec < 0:
915
+ # Target per-token sparsity = 40% of D_FFN = PRUNE_K neurons.
916
+ # Each expert covers ~D_FFN/K neurons at MECE. K_const always fires (D_FFN/K * K_const).
917
+ # Need K_active_spec such that (K_const + K_active_spec) * D_FFN/K ≈ PRUNE_K
918
+ # → K_active_spec = round(K * PRUNE_P - K_const)
919
+ k_act = max(1, round(args.K * PRUNE_P) - args.K_const) if K_spec > 0 else 0
920
+ args.K_active_spec = k_act
921
+ assert args.K_active_spec <= K_spec
922
+
923
+ os.makedirs("logs", exist_ok=True)
924
+ print(f"=== Rung 6 MoE — phase={args.phase} ===")
925
+ print(f" K={args.K} K_const={args.K_const} K_spec={K_spec} K_active_spec={args.K_active_spec}")
926
+ print(f" mece_mode={args.mece_mode} init={args.init} loss={args.loss}")
927
+ print(f" tau: {args.tau_start} → {args.tau_end} over {args.max_steps} steps")
928
+ # Gemma-4 has two MLP widths (6144 / 12288). Report both layer types' active budgets.
929
+ ratio = (args.K_const + args.K_active_spec) / args.K
930
+ for width_name, d in (("narrow (layers 0-14)", INTERMEDIATE),
931
+ ("wide (layers 15+)", INTERMEDIATE_WIDE)):
932
+ eff_active = ratio * d
933
+ prune_k = int(d * PRUNE_P)
934
+ print(f" {width_name}: active ~{eff_active:.0f}/{d} "
935
+ f"(40% target = {prune_k}; diff = {eff_active - prune_k:+.0f})")
936
+
937
+ print(f" freeze_A={args.freeze_A} unfreeze_base={args.unfreeze_base} W_r_init={args.W_r_init}")
938
+ if args.load_checkpoint: print(f" load_checkpoint={args.load_checkpoint}")
939
+ if args.save_checkpoint: print(f" save_checkpoint={args.save_checkpoint}")
940
+
941
+ print(f"Loading teacher & student on {DEVICE}...")
942
+ teacher, tokenizer = load_model()
943
+ teacher.eval()
944
+ for p in teacher.parameters(): p.requires_grad_(False)
945
+
946
+ student, _ = load_model()
947
+ # Note: NO corruption — rung 6 uses the CLEAN IT model.
948
+ freeze_base = not args.unfreeze_base
949
+ if freeze_base:
950
+ for p in student.parameters(): p.requires_grad_(False) # freeze base first
951
+ # If unfreeze_base: leave requires_grad=True on all params (default)
952
+
953
+ # Embedding freeze for Gemma-4 (selectively keep embed_tokens and embed_tokens_per_layer
954
+ # frozen even when the rest of the base is training). Required for Gemma-4 4.65B on 80GB:
955
+ # embed_tokens_per_layer alone is a single 2.35B tensor that breaks bnb 8bit kernels, and
956
+ # embeddings rarely need to move for MoE-preservation work anyway.
957
+ if args.freeze_embeddings:
958
+ n_frozen = 0
959
+ for name, p in student.named_parameters():
960
+ if "embed_tokens" in name: # catches embed_tokens and embed_tokens_per_layer (and tied lm_head)
961
+ p.requires_grad_(False)
962
+ n_frozen += p.numel()
963
+ print(f" Froze embeddings: {n_frozen/1e9:.2f}B params (embed_tokens, embed_tokens_per_layer, tied lm_head)")
964
+
965
+ # Initialization
966
+ taylor_scores = None
967
+ if args.init == "taylor" and not args.load_checkpoint:
968
+ print("Computing Taylor saliency for init...")
969
+ taylor_scores = compute_taylor_saliency(student, tokenizer, n_batches=8, calib_path=args.calib_path)
970
+ init_As = init_assignment_logits(args.init if not args.load_checkpoint else "random",
971
+ args.K, args.K_const, taylor_scores, core_frac=args.core_frac)
972
+ init_W_rs = init_router_weights(args.W_r_init, student, init_As, args.K, args.K_const,
973
+ scale_multiplier=args.W_r_scale)
974
+
975
+ mlp_modules = install_moe(
976
+ student, K=args.K, K_const=args.K_const,
977
+ K_active_spec=args.K_active_spec, mece_mode=args.mece_mode,
978
+ init_As=init_As, noise_std=args.noise_std,
979
+ freeze_base=freeze_base, init_W_rs=init_W_rs)
980
+
981
+ # Optionally freeze A (only router trains) — done AFTER install_moe
982
+ if args.freeze_A:
983
+ for m in mlp_modules:
984
+ m.A.requires_grad_(False)
985
+ print(" A frozen — only router W_r (and base if --unfreeze_base) trains")
986
+
987
+ # Load warm-start checkpoint BEFORE computing trainable params
988
+ if args.load_checkpoint:
989
+ print(f" Loading checkpoint from {args.load_checkpoint}...")
990
+ ckpt = torch.load(args.load_checkpoint, map_location=DEVICE)
991
+ state = ckpt.get('student_state', ckpt) if isinstance(ckpt, dict) else ckpt
992
+ missing, unexpected = student.load_state_dict(state, strict=False)
993
+ print(f" missing={len(missing)} unexpected={len(unexpected)}")
994
+
995
+ # Int4 QAT: wrap target Linears AFTER state_dict load (keys unchanged — subclass of nn.Linear).
996
+ # Must happen BEFORE optimizer creation so parameter references are stable.
997
+ if args.int4_qat:
998
+ Int4QuantLinear._group_size = args.int4_group_size
999
+ n_wrapped = wrap_int4(student)
1000
+ print(f" Int4 QAT: wrapped {n_wrapped} nn.Linear modules (group_size={args.int4_group_size}, "
1001
+ f"range [-7, 7]). Forward uses fake-quant; backward is STE through fp weight.")
1002
+
1003
+ # LoRA: wrap target Linears (incl. Int4QuantLinear) with LoRALinear so base is frozen
1004
+ # and only LoRA A/B + MoE A logits/W_r train. Apply AFTER int4 so the base inside LoRA
1005
+ # is the int4-quantized Linear (deploy-realistic).
1006
+ if args.use_lora:
1007
+ # Pure-LoRA semantics: freeze ALL base params (including attention, norms, scalars
1008
+ # not in LoRA target list). MoE A/W_r and the LoRA adapters added by wrap_lora are
1009
+ # the only trainable things. Overrides --unfreeze_base.
1010
+ for name, p in student.named_parameters():
1011
+ if not (name.endswith(".A") or name.endswith(".W_r")):
1012
+ p.requires_grad_(False)
1013
+ n_wrapped, n_lora_params = wrap_lora(student, rank=args.lora_rank, alpha=args.lora_alpha)
1014
+ print(f" LoRA: wrapped {n_wrapped} Linears with rank={args.lora_rank} alpha={args.lora_alpha} "
1015
+ f"(trainable LoRA params: {n_lora_params/1e6:.2f}M)")
1016
+
1017
+ # Real int4 quantization in-place (deploy simulation — no runtime quant overhead).
1018
+ if args.real_int4_inplace:
1019
+ n_q = apply_int4_inplace(student, group_size=args.int4_group_size)
1020
+ print(f" Real int4 inplace: quantized {n_q} Linear weights to int4 grid "
1021
+ f"(group_size={args.int4_group_size}); weights now on-grid, regular nn.Linear forward.")
1022
+
1023
+ # Gaussian-proxy noise benchmark.
1024
+ if args.gaussian_noise_scale > 0:
1025
+ n_g = apply_gaussian_noise_inplace(student, noise_scale=args.gaussian_noise_scale)
1026
+ print(f" Gaussian noise inplace: added N(0, {args.gaussian_noise_scale} × p.std()) "
1027
+ f"to {n_g} Linear weights.")
1028
+
1029
+ # ────────── Activation-MSE recovery (mechanism A: generic) ──────────
1030
+ # Runs AFTER install_moe + wrap_int4 (+ wrap_lora) so the recovered student
1031
+ # is the deployed one (int4 fake-quant in the loop, MoE routing engaged at
1032
+ # tau_end). Trainable params during recovery: same as training (i.e., A,
1033
+ # W_r — base is frozen unless --unfreeze_base, in which case it'd also move,
1034
+ # but we explicitly want only A/W_r so we do NOT alter requires_grad here).
1035
+ if args.recovery_steps > 0:
1036
+ # Hard routing during recovery — match deploy-time temperature.
1037
+ for m in mlp_modules: m.tau = args.tau_end
1038
+ # Optionally override router noise during recovery (default -1 = leave as-is).
1039
+ prev_noise = [getattr(m, "noise_std", None) for m in mlp_modules]
1040
+ if args.moe_recovery_noise_std >= 0:
1041
+ for m in mlp_modules:
1042
+ if hasattr(m, "noise_std"): m.noise_std = args.moe_recovery_noise_std
1043
+ print(f"\n [recovery A] generic recover_modules_sequentially "
1044
+ f"steps={args.recovery_steps} lr={args.recovery_lr} "
1045
+ f"n_batches={args.recovery_n_batches} tau={args.tau_end} "
1046
+ f"noise={args.moe_recovery_noise_std if args.moe_recovery_noise_std >= 0 else 'unchanged'}")
1047
+ # Restrict trainable set to MoE params (A, W_r) for the duration of
1048
+ # recovery. Snapshot prior requires_grad so we can restore it for main
1049
+ # training (e.g., LoRA adapters that should keep training afterwards).
1050
+ prev_requires_grad = {n: p.requires_grad for n, p in student.named_parameters()}
1051
+ # Restrict to A/W_r — but RESPECT --freeze_A: don't enable A if it was
1052
+ # frozen pre-recovery. Same for W_r (in case caller froze it).
1053
+ for n, p in student.named_parameters():
1054
+ is_moe = n.endswith(".A") or n.endswith(".W_r")
1055
+ p.requires_grad_(is_moe and prev_requires_grad[n])
1056
+ # Pull `recovery_n_batches` calibration batches (input_ids only).
1057
+ rec_seqs = load_seqs(tokenizer, "train", calib_path=args.calib_path)
1058
+ rec_seqs = rec_seqs[:args.recovery_n_batches * BATCH]
1059
+ rec_loader = torch.utils.data.DataLoader(rec_seqs, batch_size=BATCH)
1060
+ rec_input_ids = [batch["input_ids"] for batch in rec_loader][:args.recovery_n_batches]
1061
+ if not rec_input_ids:
1062
+ print(" [recovery A] no calibration data — skipping")
1063
+ else:
1064
+ n_train_per_mlp = sum(
1065
+ p.numel() for n, p in mlp_modules[0].named_parameters(recurse=False)
1066
+ if p.requires_grad and n in ("A", "W_r")
1067
+ )
1068
+ print(f" [recovery A] per-layer MoE trainable params: {n_train_per_mlp}")
1069
+ rec_results = recover_modules_via_generic_pipeline(
1070
+ student=student, teacher=teacher,
1071
+ calibration_input_ids=rec_input_ids,
1072
+ n_layers=N_LAYERS,
1073
+ steps=args.recovery_steps,
1074
+ lr=args.recovery_lr,
1075
+ device=DEVICE,
1076
+ )
1077
+ for r in rec_results:
1078
+ print(f" {r['name']} in_mse={r['input_mse']:.4e} "
1079
+ f"out_pre={r['output_mse_before']:.4e} out_post={r['output_mse_after']:.4e}")
1080
+ # Restore prior requires_grad state, tau, and noise.
1081
+ for n, p in student.named_parameters():
1082
+ p.requires_grad_(prev_requires_grad[n])
1083
+ for m in mlp_modules: m.tau = args.tau_start
1084
+ if args.moe_recovery_noise_std >= 0:
1085
+ for m, n in zip(mlp_modules, prev_noise):
1086
+ if hasattr(m, "noise_std") and n is not None: m.noise_std = n
1087
+
1088
+ # ────────── Activation-MSE recovery (mechanism B: specialized MoE) ──────────
1089
+ # Pre-cache (X, Y) per layer once via teacher forward, then per-layer
1090
+ # time-budgeted optimization of A and W_r only with student-input
1091
+ # propagation between layers.
1092
+ if args.moe_recovery_seconds_per_layer > 0:
1093
+ # Hard routing during recovery — match deploy-time temperature.
1094
+ for m in mlp_modules: m.tau = args.tau_end
1095
+ print(f"\n [recovery B] finetune_moe_per_layer "
1096
+ f"sec/layer={args.moe_recovery_seconds_per_layer} "
1097
+ f"lr={args.moe_recovery_lr} n_calib={args.moe_recovery_n_calib_records} "
1098
+ f"use_student_inputs={args.moe_recovery_use_student_inputs} "
1099
+ f"opt={args.moe_recovery_optimizer} tau={args.tau_end}")
1100
+ moe_rec_seqs = load_seqs(tokenizer, "train", calib_path=args.calib_path)
1101
+ moe_rec_seqs = moe_rec_seqs[:args.moe_recovery_n_calib_records * BATCH]
1102
+ moe_rec_loader = torch.utils.data.DataLoader(moe_rec_seqs, batch_size=BATCH)
1103
+ moe_rec_input_ids = [b["input_ids"] for b in moe_rec_loader][:args.moe_recovery_n_calib_records]
1104
+ if not moe_rec_input_ids:
1105
+ print(" [recovery B] no calibration data — skipping")
1106
+ else:
1107
+ n_train_per_mlp = sum(
1108
+ p.numel() for n, p in mlp_modules[0].named_parameters(recurse=False)
1109
+ if p.requires_grad and n in ("A", "W_r")
1110
+ )
1111
+ print(f" [recovery B] per-layer MoE trainable params: {n_train_per_mlp}")
1112
+ moe_rec_results = finetune_moe_per_layer(
1113
+ student=student, teacher=teacher,
1114
+ calibration_input_ids=moe_rec_input_ids,
1115
+ n_layers=N_LAYERS,
1116
+ seconds_per_layer=args.moe_recovery_seconds_per_layer,
1117
+ lr=args.moe_recovery_lr,
1118
+ optimizer=args.moe_recovery_optimizer,
1119
+ use_student_inputs=args.moe_recovery_use_student_inputs,
1120
+ device=DEVICE,
1121
+ tau_end=args.tau_end,
1122
+ noise_std=(None if args.moe_recovery_noise_std < 0 else args.moe_recovery_noise_std),
1123
+ )
1124
+ # Restore tau to start for main training.
1125
+ for m in mlp_modules: m.tau = args.tau_start
1126
+
1127
+ trainable_params = [p for p in student.parameters() if p.requires_grad]
1128
+ n_train = sum(p.numel() for p in trainable_params)
1129
+ moe_params_max = sum(_d_ffn_at(i) * args.K for i in range(N_LAYERS)) \
1130
+ + N_LAYERS * D_MODEL * max(K_spec, 0)
1131
+ trainable_base = sum(p.numel() for n, p in student.named_parameters()
1132
+ if p.requires_grad and not (n.endswith(".A") or n.endswith(".W_r")))
1133
+ trainable_moe = sum(p.numel() for n, p in student.named_parameters()
1134
+ if p.requires_grad and (n.endswith(".A") or n.endswith(".W_r")))
1135
+ print(f" Trainable params: {n_train/1e6:.3f}M "
1136
+ f"(MoE: {trainable_moe/1e6:.3f}M / max {moe_params_max/1e6:.3f}M, "
1137
+ f"base trainable: {trainable_base/1e6:.2f}M)")
1138
+ if freeze_base and not args.freeze_A:
1139
+ assert trainable_base == 0, f"freeze_base=True but {trainable_base} base params are trainable"
1140
+ assert trainable_moe <= moe_params_max * 1.01, "Too many MoE params trainable"
1141
+ if args.freeze_A:
1142
+ assert trainable_moe <= N_LAYERS * D_MODEL * max(K_spec, 0) * 1.01, \
1143
+ "freeze_A=True but A appears to be trainable"
1144
+
1145
+ # Eval-only mode: skip training entirely, jump to final eval at tau_end.
1146
+ if args.eval_only:
1147
+ print(f" Eval-only mode — skipping training, evaluating at tau={args.tau_end}")
1148
+ print(f" Eval data: {args.eval_calib_path}")
1149
+ for m in mlp_modules: m.tau = args.tau_end
1150
+ final_ppl = eval_ppl(student, tokenizer, calib_path=args.eval_calib_path)
1151
+ print(f"\n=== Eval-only PPL (tau={args.tau_end}): {final_ppl:.4f} "
1152
+ f"baseline(bottom60 CE)={BASELINE_PPL:.4f} clean={CLEAN_PPL:.4f} ===")
1153
+ out = {
1154
+ "phase": args.phase, "config": vars(args),
1155
+ "final_ppl": final_ppl,
1156
+ "baseline_ppl": BASELINE_PPL, "clean_ppl": CLEAN_PPL,
1157
+ "ppl_curve": [], "eval_only": True,
1158
+ }
1159
+ os.makedirs("logs", exist_ok=True)
1160
+ out_path = f"logs/rung6_moe_{args.phase}_results.json"
1161
+ with open(out_path, "w") as f:
1162
+ json.dump(out, f, indent=2)
1163
+ print(f"Saved to {out_path}")
1164
+ return
1165
+
1166
+ # Split params into MoE (A + W_r) vs base for per-group LR.
1167
+ # --W_r_lr_mult multiplies the MoE group's LR relative to base_params' args.lr.
1168
+ moe_group_params = [p for n, p in student.named_parameters()
1169
+ if p.requires_grad and (n.endswith(".A") or n.endswith(".W_r"))]
1170
+ base_group_params = [p for n, p in student.named_parameters()
1171
+ if p.requires_grad and not (n.endswith(".A") or n.endswith(".W_r"))]
1172
+ param_groups = [
1173
+ {"params": base_group_params, "lr": args.lr},
1174
+ {"params": moe_group_params, "lr": args.lr * args.W_r_lr_mult},
1175
+ ]
1176
+ print(f" LR: base={args.lr:.2e} MoE(A+W_r)={args.lr * args.W_r_lr_mult:.2e} "
1177
+ f"(multiplier={args.W_r_lr_mult})")
1178
+ if args.optimizer == "adamw8bit":
1179
+ if not _HAS_BNB:
1180
+ raise RuntimeError("bitsandbytes not installed — pip install bitsandbytes")
1181
+ # Paged variant handles huge tensors (Gemma-4's embed_tokens_per_layer is 2.35B params,
1182
+ # exceeds non-paged bnb kernel grid limits → "invalid configuration argument").
1183
+ optimizer = bnb.optim.PagedAdamW8bit(param_groups, weight_decay=0.01)
1184
+ print(f" Using bnb.optim.PagedAdamW8bit (~28GB optimizer-state savings, "
1185
+ f"paged to handle Gemma-4's 2.35B embed_tokens_per_layer)")
1186
+ else:
1187
+ optimizer = AdamW(param_groups, weight_decay=0.01)
1188
+ scheduler = CosineAnnealingLR(optimizer, T_max=args.max_steps, eta_min=args.lr * 0.1)
1189
+
1190
+ print(f" Train data: {args.calib_path}")
1191
+ print(f" Eval data: {args.eval_calib_path}")
1192
+ # When train and eval paths differ, use ALL records of train file (no need to withhold 20%
1193
+ # since eval comes from a separate file).
1194
+ train_split = "all" if args.calib_path != args.eval_calib_path else "train"
1195
+ seqs = load_seqs(tokenizer, train_split, calib_path=args.calib_path)
1196
+ print(f" Loaded {len(seqs)} train sequences of {MAX_SEQ_LEN} tokens = {len(seqs)*MAX_SEQ_LEN/1e6:.2f}M tokens"
1197
+ f" (split={train_split})")
1198
+ # Deterministic shuffle: same --shuffle_seed reproduces the same record order.
1199
+ # Use a different seed in continuation runs to expose model to NEW orderings of
1200
+ # the dataset (avoids replaying the same trajectory the prior run already trained on).
1201
+ g = torch.Generator(); g.manual_seed(args.shuffle_seed)
1202
+ loader = torch.utils.data.DataLoader(seqs, BATCH, shuffle=True, generator=g)
1203
+ loader_iter = iter(loader)
1204
+ # Optional skip: discard first N samples of the shuffled stream before training begins.
1205
+ # Useful when a previous run with the same shuffle_seed consumed N samples.
1206
+ if args.data_skip > 0:
1207
+ skipped = 0
1208
+ for _ in range(args.data_skip):
1209
+ try:
1210
+ next(loader_iter); skipped += 1
1211
+ except StopIteration:
1212
+ loader_iter = iter(loader)
1213
+ next(loader_iter); skipped += 1
1214
+ print(f" Skipped first {skipped} samples (data_skip={args.data_skip})")
1215
+
1216
+ # Optional knowledge-preservation: load diverse corpus + cache teacher base params.
1217
+ diverse_loader_iter = None
1218
+ diverse_dataset_obj = None
1219
+ if args.diverse_calib_path:
1220
+ print(f" Diverse corpus (KL-to-base): {args.diverse_calib_path}")
1221
+ diverse_seqs = load_seqs(tokenizer, "all", calib_path=args.diverse_calib_path, raw_text=True)
1222
+ print(f" {len(diverse_seqs)} sequences, every {args.diverse_every_n} steps, "
1223
+ f"λ={args.kl_base_lambda}, T={args.kl_base_temp}")
1224
+ diverse_dataset_obj = torch.utils.data.DataLoader(diverse_seqs, BATCH, shuffle=True)
1225
+ diverse_loader_iter = iter(diverse_dataset_obj)
1226
+
1227
+ teacher_param_map = None
1228
+ if args.w_drift_lambda > 0:
1229
+ print(f" W-drift penalty active: λ={args.w_drift_lambda} on trainable base params")
1230
+ teacher_param_map = {n: p.detach() for n, p in teacher.named_parameters()}
1231
+
1232
+ step, accum_loss = 0, 0.0
1233
+ optimizer.zero_grad()
1234
+ t0 = time.time()
1235
+ curve = []
1236
+
1237
+ while step < args.max_steps:
1238
+ tau = get_tau(step, args.max_steps, args.tau_start, args.tau_end, hold_frac=args.tau_hold_frac)
1239
+ for m in mlp_modules: m.tau = tau
1240
+
1241
+ student.train()
1242
+ use_diverse = (diverse_loader_iter is not None and step > 0 and (step % args.diverse_every_n == 0))
1243
+
1244
+ if use_diverse:
1245
+ # Pretraining-distribution preservation batch: KL-to-teacher at temperature T.
1246
+ try: batch = next(diverse_loader_iter)
1247
+ except StopIteration:
1248
+ diverse_loader_iter = iter(diverse_dataset_obj); batch = next(diverse_loader_iter)
1249
+ ids = batch["input_ids"].to(DEVICE)
1250
+ with torch.no_grad():
1251
+ t_logits = teacher(ids)
1252
+ s_logits = student(ids)
1253
+ # High-temperature KL: softens sharp teacher distributions to carry tail signal.
1254
+ main_loss = args.kl_base_lambda * kl_loss(s_logits[:, :-1], t_logits[:, :-1], temp=args.kl_base_temp)
1255
+ else:
1256
+ # Normal CE/KL batch on IT trajectories.
1257
+ try: batch = next(loader_iter)
1258
+ except StopIteration:
1259
+ loader_iter = iter(loader); batch = next(loader_iter)
1260
+
1261
+ ids = batch["input_ids"].to(DEVICE)
1262
+ labels = batch["labels"][:, :-1].to(DEVICE)
1263
+ with torch.no_grad():
1264
+ t_logits = teacher(ids)
1265
+ s_logits = student(ids)
1266
+
1267
+ if args.loss == "kl":
1268
+ # Mask = positions where labels != -100 (i.e., assistant response only).
1269
+ # Same masking we apply to CE — keeps "don't train on prompt tokens" consistent.
1270
+ kl_mask = (labels != -100)
1271
+ main_loss = kl_loss(s_logits[:, :-1], t_logits[:, :-1],
1272
+ temp=args.main_kl_temp, mask=kl_mask)
1273
+ else:
1274
+ main_loss = ce_loss(s_logits[:, :-1], labels)
1275
+
1276
+ # Aux losses apply on every batch — functions of module state, not batch content.
1277
+ aux = sum(m.aux_loss(args.alpha_b, args.alpha_z) for m in mlp_modules)
1278
+ orth = sum(m.orth_loss() for m in mlp_modules) * args.alpha_orth
1279
+
1280
+ # Optional: weight-drift penalty on trainable base params (EWC-lite).
1281
+ drift = torch.tensor(0.0, device=DEVICE)
1282
+ if args.w_drift_lambda > 0:
1283
+ for n, p in student.named_parameters():
1284
+ if not p.requires_grad: continue
1285
+ if n.endswith(".A") or n.endswith(".W_r"): continue
1286
+ t = teacher_param_map.get(n) if teacher_param_map is not None else None
1287
+ if t is not None and t.shape == p.shape:
1288
+ drift = drift + ((p - t) ** 2).sum()
1289
+ drift = drift * args.w_drift_lambda
1290
+
1291
+ loss = (main_loss + aux + orth + drift) / GRAD_ACCUM
1292
+ loss.backward()
1293
+ accum_loss += loss.item()
1294
+
1295
+ if (step + 1) % GRAD_ACCUM == 0:
1296
+ torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
1297
+ optimizer.step(); scheduler.step(); optimizer.zero_grad()
1298
+
1299
+ if (step + 1) % args.eval_every == 0:
1300
+ # Diagnostic metrics: assignment entropy from softmax(A/τ) at the current τ; Jaccard from argmax hard assignment (independent of τ)
1301
+ with torch.no_grad():
1302
+ avg_entropy = 0.0; avg_jaccard = 0.0
1303
+ for m in mlp_modules:
1304
+ probs = F.softmax(m.A / max(tau, 1e-3), dim=-1) # [D_FFN, K]
1305
+ ent = -(probs * (probs.clamp_min(1e-8)).log()).sum(-1).mean().item()
1306
+ avg_entropy += ent
1307
+ # Hard assignment: each neuron → argmax expert
1308
+ hard = F.one_hot(probs.argmax(dim=-1), args.K).float().T # [K, D_FFN]
1309
+ inter = hard @ hard.T # [K, K]
1310
+ sz = hard.sum(dim=-1, keepdim=True) # [K, 1]
1311
+ union = sz + sz.T - inter
1312
+ jac_off = (inter / union.clamp_min(1.0))
1313
+ jac_off = jac_off - torch.diag(torch.diag(jac_off)) # zero diagonal
1314
+ avg_jaccard += jac_off.sum().item() / (args.K * (args.K - 1) + 1e-8)
1315
+ avg_entropy /= len(mlp_modules)
1316
+ avg_jaccard /= len(mlp_modules)
1317
+ ppl = eval_ppl(student, tokenizer, calib_path=args.eval_calib_path)
1318
+ curve.append({
1319
+ "step": step + 1, "ppl": ppl, "tau": tau,
1320
+ "assign_entropy": avg_entropy, "jaccard": avg_jaccard,
1321
+ })
1322
+ print(f" step={step+1:4d} tau={tau:.4f} loss={accum_loss*GRAD_ACCUM:.4f} "
1323
+ f"ppl={ppl:.4f} H(A)={avg_entropy:.3f} Jac={avg_jaccard:.4f} "
1324
+ f"t={time.time()-t0:.0f}s")
1325
+ accum_loss = 0.0
1326
+
1327
+ # Intermediate ckpt save (--save_every) — single rolling file, OVERWRITES previous.
1328
+ # Filename: <save_checkpoint stem>_intermediate.pt — only one extra ckpt on disk
1329
+ # at any time. Read 'step' field of the saved dict to know which step it was at.
1330
+ if args.save_every and args.save_checkpoint and (step + 1) % args.save_every == 0:
1331
+ stem, ext = os.path.splitext(args.save_checkpoint)
1332
+ inter_path = f"{stem}_intermediate{ext}"
1333
+ os.makedirs(os.path.dirname(inter_path) or ".", exist_ok=True)
1334
+ torch.save({
1335
+ 'student_state': student.state_dict(),
1336
+ 'config': vars(args),
1337
+ 'step': step + 1,
1338
+ }, inter_path)
1339
+ print(f" [intermediate] overwrote {inter_path} (step {step+1})")
1340
+
1341
+ step += 1
1342
+
1343
+ # Final eval at tau_end
1344
+ for m in mlp_modules: m.tau = args.tau_end
1345
+ final_ppl = eval_ppl(student, tokenizer, calib_path=args.eval_calib_path)
1346
+ print(f"\n=== Final PPL (tau={args.tau_end}): {final_ppl:.4f} "
1347
+ f"baseline(bottom60 CE)={BASELINE_PPL:.4f} clean={CLEAN_PPL:.4f} ===")
1348
+
1349
+ out = {
1350
+ "phase": args.phase, "config": vars(args),
1351
+ "final_ppl": final_ppl,
1352
+ "baseline_ppl": BASELINE_PPL, "clean_ppl": CLEAN_PPL,
1353
+ "ppl_curve": curve,
1354
+ }
1355
+ os.makedirs("logs", exist_ok=True)
1356
+ out_path = f"logs/rung6_moe_{args.phase}_results.json"
1357
+ with open(out_path, "w") as f:
1358
+ json.dump(out, f, indent=2)
1359
+ print(f"Saved to {out_path}")
1360
+
1361
+ if args.save_checkpoint:
1362
+ os.makedirs(os.path.dirname(args.save_checkpoint) or ".", exist_ok=True)
1363
+ torch.save({
1364
+ 'student_state': student.state_dict(),
1365
+ 'config': vars(args),
1366
+ 'final_ppl': final_ppl,
1367
+ }, args.save_checkpoint)
1368
+ print(f"Saved checkpoint to {args.save_checkpoint}")
1369
+
1370
+
1371
+ # Script entry point: call main() only when this file is executed directly, not when it is imported.
+ if __name__ == "__main__":
1372
+ main()