ncylich committed on
Commit
659ffa1
·
verified ·
1 Parent(s): 58d34c8

Upload rung7_swiglu_g4.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. rung7_swiglu_g4.py +364 -0
rung7_swiglu_g4.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """rung7_swiglu_g4.py — per-token top-K gate mask for Gemma-4 (no expert structure).
3
+
4
+ Each neuron is its own "expert". Per-token mask = top-K by |gate_act| magnitude,
5
+ relaxed via sigmoid((|gate| - kth_threshold) / τ) for differentiability.
6
+ At τ→0 it converges to hard top-K. No router, no MECE partition, no A matrix.
7
+
8
+ Mirrors rung6_moe_g4.py CLI/training loop but installs GateMaskedMLP instead of
9
+ MoEMLP. Reuses load_seqs / eval_ppl / wrap_int4 / get_tau / kl_loss / ce_loss.
10
+
11
+ Strong prior: Gemma-3 Design 6 (this exact mechanism) hit PPL 7.26 vs base 7.89
12
+ = 0.92× base. Best result on Gemma-3. Never tried on Gemma-4.
13
+ """
14
+ import argparse, json, math, os, time
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.optim import AdamW
19
+ from torch.optim.lr_scheduler import CosineAnnealingLR
20
+
21
+ from gemma4_hf import load_gemma4, DEVICE, N_LAYERS
22
+ from rung6_moe_g4 import (
23
+ Int4QuantLinear, wrap_int4, apply_int4_inplace,
24
+ LoRALinear, wrap_lora,
25
+ load_seqs, eval_ppl, kl_loss, ce_loss, get_tau,
26
+ _d_ffn_at,
27
+ MAX_SEQ_LEN, BATCH, LR, BASELINE_PPL, CLEAN_PPL,
28
+ )
29
+
30
+
31
class GateMaskedMLP(nn.Module):
    """Gated MLP with a per-token soft top-K mask on |gate activation|.

    Forward:
        gate_act     = gelu(gate_proj(x))          (tanh approximation)
        threshold[t] = kth-largest |gate_act[t]|   (k = k_keep)
        mask[t, j]   = sigmoid((|gate_act[t, j]| - threshold[t]) / tau)
        h            = gate_act * up_proj(x) * mask
        out          = down_proj(h)

    The three projections are taken from ``base_mlp`` by reference (not
    copied). ``tau`` is a plain attribute that the training loop is expected
    to overwrite each step; smaller values sharpen the sigmoid.

    Args:
        base_mlp: object exposing ``gate_proj``, ``up_proj`` and ``down_proj``
            modules.
        k_keep: number of neurons per token that define the top-K threshold.
            Converted with ``int()``; must be >= 1.
        freeze_base: when True, ``requires_grad`` is switched off for every
            parameter of the three projections.

    Raises:
        ValueError: if ``int(k_keep) < 1``.
    """

    def __init__(self, base_mlp, k_keep, freeze_base=False):
        super().__init__()
        self.gate_proj = base_mlp.gate_proj
        self.up_proj = base_mlp.up_proj
        self.down_proj = base_mlp.down_proj
        if freeze_base:
            for proj in (self.gate_proj, self.up_proj, self.down_proj):
                for p in proj.parameters():
                    p.requires_grad_(False)
        self.k_keep = int(k_keep)
        if self.k_keep < 1:
            # topk(0) would yield an empty threshold and an opaque broadcast
            # error in forward(); fail at construction time instead.
            raise ValueError(f"k_keep must be >= 1, got {k_keep!r}")
        self.tau = 1.0  # set externally each step

    def forward(self, x):
        """Apply the masked gated MLP to ``x`` and return ``down_proj(h)``."""
        gate_act = F.gelu(self.gate_proj(x), approximate="tanh")
        up_act = self.up_proj(x)
        # Magnitudes are compared in float32 regardless of the activation dtype.
        gate_abs = gate_act.abs().to(torch.float32)
        # topk raises if k exceeds the size of the last dimension, so clamp
        # k_keep to the actual feature width.
        k = min(self.k_keep, gate_abs.shape[-1])
        # Per-token kth-largest threshold (non-differentiable wrt selection,
        # but mask values around the threshold ARE differentiable via sigmoid).
        threshold = gate_abs.topk(k, dim=-1).values[..., -1:]
        # tau is floored at 1e-3 to avoid dividing by ~0.
        mask = torch.sigmoid((gate_abs - threshold) / max(self.tau, 1e-3))
        h = gate_act * up_act * mask.to(gate_act.dtype)
        return self.down_proj(h)
64
+
65
+
66
def install_gate_mask(model, density, freeze_base=False):
    """Replace each layer's ``mlp`` with a ``GateMaskedMLP`` and return them.

    Iterates ``range(N_LAYERS)``, wraps ``model.layers[i].mlp`` in place, and
    collects the new modules so the caller can update their ``tau``.

    Args:
        model: object exposing ``layers[i].mlp`` for every layer index.
        density: fraction of each layer's FFN width to use as ``k_keep``.
        freeze_base: forwarded to ``GateMaskedMLP``.

    Returns:
        List of the installed ``GateMaskedMLP`` modules, in layer order.
    """
    mlp_modules = []
    for i in range(N_LAYERS):
        d_ffn = _d_ffn_at(i)  # presumably the FFN width of layer i — confirm in rung6_moe_g4
        # Clamp to [1, d_ffn]: a density above 1.0 would otherwise request
        # more neurons than exist and make topk fail on the first forward.
        k_keep = min(d_ffn, max(1, int(round(d_ffn * density))))
        new_mlp = GateMaskedMLP(model.layers[i].mlp, k_keep=k_keep, freeze_base=freeze_base)
        model.layers[i].mlp = new_mlp
        mlp_modules.append(new_mlp)
    return mlp_modules
75
+
76
+
77
def _intermediate_checkpoint_path(save_path):
    """Return the sibling path used for the rolling mid-training checkpoint."""
    # splitext only touches the final extension, so the result always differs
    # from save_path (even when it has no ".pt") and directory names that
    # happen to contain ".pt" are left alone.
    root, ext = os.path.splitext(save_path)
    return f"{root}_intermediate{ext}"


def main():
    """CLI entry point: train a gate-masked student and report perplexity.

    Parses the command line, loads the student (and a frozen teacher only
    when a KL loss, diverse-corpus regularizer or W-drift penalty needs it),
    installs ``GateMaskedMLP`` on every layer, optionally applies int4 QAT /
    LoRA wrappers and a checkpoint, then runs the training loop with an
    annealed ``tau``. Periodically evaluates PPL, optionally writes a rolling
    intermediate checkpoint, and finally writes
    ``logs/rung7_swiglu_<phase>_results.json`` plus the final checkpoint.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--phase", type=str, default="S1")
    parser.add_argument("--density", type=float, default=0.75,
                        help="Fraction of MLP neurons to keep per token (e.g. 0.75 ≈ Aconst4 density)")
    parser.add_argument("--loss", choices=["kl", "ce"], default="ce")
    parser.add_argument("--int4_qat", action="store_true")
    parser.add_argument("--int4_group_size", type=int, default=32)
    parser.add_argument("--unfreeze_base", action="store_true",
                        help="Train base weights (gate/up/down + attn). Default freezes them.")
    parser.add_argument("--freeze_embeddings", action="store_true")
    parser.add_argument("--gate_only_train", action="store_true",
                        help="Override: freeze entire model, only gate_proj across all layers trains. "
                             "Tests whether the gate alone can route + adapt.")
    parser.add_argument("--gate_lora_train", action="store_true",
                        help="Override: freeze entire model, train gate_proj + LoRA adapters on "
                             "up_proj/down_proj. Tests whether LoRA on the masked weights "
                             "compensates for aggressive masking at low density.")
    parser.add_argument("--lora_targets", type=str, default="",
                        help="Comma-separated substrings of Linear names to wrap with LoRA. "
                             "Default (empty) uses the rung6 wrap_lora default. For gate_lora_train, "
                             "set to 'up_proj,down_proj' to skip gate_proj.")
    parser.add_argument("--use_lora", action="store_true")
    parser.add_argument("--lora_rank", type=int, default=16)
    parser.add_argument("--lora_alpha", type=float, default=16.0)
    parser.add_argument("--tau_start", type=float, default=1.0)
    parser.add_argument("--tau_end", type=float, default=0.01)
    parser.add_argument("--tau_hold_frac", type=float, default=0.2)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--lr", type=float, default=LR)
    parser.add_argument("--main_kl_temp", type=float, default=2.0)
    parser.add_argument("--shuffle_seed", type=int, default=42)
    parser.add_argument("--data_skip", type=int, default=0)
    parser.add_argument("--save_every", type=int, default=2500)
    parser.add_argument("--eval_every", type=int, default=2500)
    parser.add_argument("--eval_max_seqs", type=int, default=0,
                        help="Cap eval to first N sequences (0 = no cap, current behavior). "
                             "Set e.g. 200 to keep mid-training evals fast; the final "
                             "post-training eval line still runs full unless capped here.")
    parser.add_argument("--calib_path", type=str, required=True)
    parser.add_argument("--eval_calib_path", type=str, required=True)
    parser.add_argument("--load_checkpoint", type=str, default="")
    parser.add_argument("--save_checkpoint", type=str, default="")
    parser.add_argument("--diverse_calib_path", type=str, default="")
    parser.add_argument("--diverse_every_n", type=int, default=4)
    parser.add_argument("--kl_base_lambda", type=float, default=0.5)
    parser.add_argument("--kl_base_temp", type=float, default=2.0)
    parser.add_argument("--w_drift_lambda", type=float, default=0.0)
    args = parser.parse_args()

    # These values are used as modulo divisors inside the training loop;
    # reject bad cadences now instead of failing with ZeroDivisionError
    # after the models have been loaded.
    if args.eval_every < 1:
        parser.error("--eval_every must be >= 1")
    if args.save_checkpoint and args.save_every < 1:
        parser.error("--save_every must be >= 1 when --save_checkpoint is set")
    if args.diverse_calib_path and args.diverse_every_n < 1:
        parser.error("--diverse_every_n must be >= 1 when --diverse_calib_path is set")

    print(f"=== Rung 7 SWIGLU gate-mask — phase={args.phase} ===")
    print(f" density={args.density:.2f} loss={args.loss}")
    print(f" tau: {args.tau_start} → {args.tau_end} over {args.max_steps} steps "
          f"(hold last {args.tau_hold_frac*100:.0f}%)")
    print(f" unfreeze_base={args.unfreeze_base} freeze_embeddings={args.freeze_embeddings}")
    print(f" int4_qat={args.int4_qat} use_lora={args.use_lora}")
    if args.load_checkpoint:
        print(f" load_checkpoint={args.load_checkpoint}")
    if args.save_checkpoint:
        print(f" save_checkpoint={args.save_checkpoint}")

    # Teacher is only needed if main loss is KL, a diverse-corpus KL-to-base
    # regularizer is configured, or the W-drift penalty needs the teacher's
    # snapshot. With --loss ce and no diverse / drift, the teacher forward is
    # dead compute and ~9GB of dead weight; skip loading it.
    teacher_ever_needed = (
        args.loss == "kl"
        or bool(args.diverse_calib_path)
        or args.w_drift_lambda > 0
    )

    if teacher_ever_needed:
        print("Loading teacher & student on cuda...")
        teacher, tokenizer = load_gemma4()
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad_(False)
    else:
        print("Loading student only on cuda (teacher not needed: --loss ce, no diverse calib)...")
        # No teacher: the tokenizer returned by the student load is used instead.
        teacher = None

    student, tokenizer_s = load_gemma4()
    if teacher is None:
        tokenizer = tokenizer_s
    if args.freeze_embeddings:
        for n, p in student.named_parameters():
            if "embed_tokens" in n or "lm_head" in n:
                p.requires_grad_(False)
        n_frozen = sum(p.numel() for n, p in student.named_parameters()
                       if ("embed_tokens" in n or "lm_head" in n))
        print(f" Froze embeddings: {n_frozen/1e9:.2f}B params")

    freeze_base_in_mlp = not args.unfreeze_base
    mlp_modules = install_gate_mask(student, density=args.density,
                                    freeze_base=freeze_base_in_mlp)
    print(f" Installed GateMaskedMLP on {N_LAYERS} layers; "
          f"k_keep range = [{min(m.k_keep for m in mlp_modules)}, {max(m.k_keep for m in mlp_modules)}]")

    if args.load_checkpoint:
        print(f" Loading checkpoint from {args.load_checkpoint}...")
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from trusted sources.
        ckpt = torch.load(args.load_checkpoint, map_location=DEVICE, weights_only=False)
        missing, unexpected = student.load_state_dict(ckpt["student_state"], strict=False)
        print(f" missing={len(missing)} unexpected={len(unexpected)}")

    if args.int4_qat:
        Int4QuantLinear._group_size = args.int4_group_size
        n_wrap = wrap_int4(student)
        print(f" Int4 QAT: wrapped {n_wrap} nn.Linear modules (group_size={args.int4_group_size})")

    if args.use_lora or args.gate_lora_train:
        if args.lora_targets:
            targets = tuple(t.strip() for t in args.lora_targets.split(",") if t.strip())
            n_lora, n_lora_p = wrap_lora(student, rank=args.lora_rank,
                                         alpha=args.lora_alpha, target_substrings=targets)
        else:
            n_lora, n_lora_p = wrap_lora(student, rank=args.lora_rank, alpha=args.lora_alpha)
        print(f" LoRA: rank={args.lora_rank} alpha={args.lora_alpha} "
              f"({n_lora} modules, {n_lora_p/1e6:.2f}M params)")

    if args.load_checkpoint:
        # Re-load after wrappers (LoRA / int4 add new keys)
        missing2, unexp2 = student.load_state_dict(ckpt["student_state"], strict=False)
        print(f" re-loaded after wrappers: missing={len(missing2)} unexpected={len(unexp2)}")

    if args.gate_only_train:
        for p in student.parameters():
            p.requires_grad_(False)
        for n, p in student.named_parameters():
            if "gate_proj" in n:
                p.requires_grad_(True)
        n_gate = sum(p.numel() for n, p in student.named_parameters() if p.requires_grad)
        print(f" --gate_only_train override: only gate_proj trains ({n_gate/1e6:.2f}M params)")

    if args.gate_lora_train:
        for p in student.parameters():
            p.requires_grad_(False)
        for n, p in student.named_parameters():
            # Gate projection trains directly (router specialization).
            # LoRA adapters on up/down_proj train (compensate for aggressive masking).
            # NOTE: a Linear named "..mlp.gate_proj" wrapped by LoRA becomes "..mlp.gate_proj.base"
            # — to avoid ambiguity we use --lora_targets up_proj,down_proj so gate isn't wrapped.
            if "gate_proj" in n or "lora_a" in n or "lora_b" in n:
                p.requires_grad_(True)
        n_train = sum(p.numel() for n, p in student.named_parameters() if p.requires_grad)
        n_gate_p = sum(p.numel() for n, p in student.named_parameters()
                       if p.requires_grad and "gate_proj" in n)
        n_lora_p = sum(p.numel() for n, p in student.named_parameters()
                       if p.requires_grad and ("lora_a" in n or "lora_b" in n))
        print(f" --gate_lora_train override: gate_proj + LoRA adapters train "
              f"({n_train/1e6:.2f}M total — gate {n_gate_p/1e6:.2f}M + LoRA {n_lora_p/1e6:.2f}M)")

    n_train = sum(p.numel() for p in student.parameters() if p.requires_grad)
    print(f" Trainable params: {n_train/1e6:.3f}M (no router; mask is non-parametric)")

    optimizer = AdamW([p for p in student.parameters() if p.requires_grad],
                      lr=args.lr, weight_decay=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.max_steps, eta_min=args.lr * 0.1)

    print(f" Train data: {args.calib_path}")
    print(f" Eval data: {args.eval_calib_path}")
    train_split = "all" if args.calib_path != args.eval_calib_path else "train"
    seqs = load_seqs(tokenizer, train_split, calib_path=args.calib_path)
    print(f" Loaded {len(seqs)} train sequences of {MAX_SEQ_LEN} tokens "
          f"= {len(seqs)*MAX_SEQ_LEN/1e6:.2f}M tokens (split={train_split})")
    g = torch.Generator()
    g.manual_seed(args.shuffle_seed)
    loader = torch.utils.data.DataLoader(seqs, BATCH, shuffle=True, generator=g)
    loader_iter = iter(loader)
    if args.data_skip > 0:
        # Each skip consumes one item from the loader; wrap around on exhaustion.
        for _ in range(args.data_skip):
            try:
                next(loader_iter)
            except StopIteration:
                loader_iter = iter(loader)
                next(loader_iter)
        print(f" Skipped first {args.data_skip} samples")

    diverse_loader_iter = None
    if args.diverse_calib_path:
        print(f" Diverse corpus (KL-to-base): {args.diverse_calib_path}")
        diverse_seqs = load_seqs(tokenizer, "all", calib_path=args.diverse_calib_path, raw_text=True)
        print(f" {len(diverse_seqs)} sequences, every {args.diverse_every_n} steps, "
              f"λ={args.kl_base_lambda}, T={args.kl_base_temp}")
        diverse_loader = torch.utils.data.DataLoader(diverse_seqs, BATCH, shuffle=True)
        diverse_loader_iter = iter(diverse_loader)

    teacher_param_map = None
    if args.w_drift_lambda > 0:
        print(f" W-drift penalty active: λ={args.w_drift_lambda}")
        teacher_param_map = {n: p.detach() for n, p in teacher.named_parameters()}

    # Rolling mid-training checkpoint path; empty when no checkpoint is requested.
    interim = _intermediate_checkpoint_path(args.save_checkpoint) if args.save_checkpoint else ""

    step = 0
    t0 = time.time()
    curve = []
    # Most recent eval PPL. Stays None until the first eval so an intermediate
    # save that fires before any eval does not hit an unbound local.
    ppl = None
    optimizer.zero_grad()

    while step < args.max_steps:
        tau = get_tau(step, args.max_steps, args.tau_start, args.tau_end,
                      hold_frac=args.tau_hold_frac)
        for m in mlp_modules:
            m.tau = tau

        try:
            batch = next(loader_iter)
        except StopIteration:
            loader_iter = iter(loader)
            batch = next(loader_iter)
        input_ids = batch["input_ids"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        # The teacher forward on the main batch is needed only for the KL
        # loss; with --loss ce its logits would be computed then discarded.
        # The diverse KL-to-base regularizer runs its own teacher forward
        # below, on the steps where it is active.
        diverse_active_this_step = (
            diverse_loader_iter is not None and step % args.diverse_every_n == 0
        )

        if args.loss == "kl":
            with torch.no_grad():
                t_logits = teacher(input_ids)
        s_logits = student(input_ids)

        if args.loss == "kl":
            mask = (labels != -100)
            loss = kl_loss(s_logits, t_logits, temp=args.main_kl_temp, mask=mask)
        else:
            loss = ce_loss(s_logits, labels)

        if diverse_active_this_step:
            try:
                dbatch = next(diverse_loader_iter)
            except StopIteration:
                diverse_loader_iter = iter(diverse_loader)
                dbatch = next(diverse_loader_iter)
            d_ids = dbatch["input_ids"].to(DEVICE)
            with torch.no_grad():
                t_d_logits = teacher(d_ids)
            s_d_logits = student(d_ids)
            d_kl = kl_loss(s_d_logits, t_d_logits, temp=args.kl_base_temp)
            loss = loss + args.kl_base_lambda * d_kl

        if teacher_param_map is not None:
            # Squared-L2 drift of trainable student params from the teacher's
            # params of the same name; names absent from the teacher are skipped.
            drift = 0.0
            for n, p in student.named_parameters():
                if not p.requires_grad:
                    continue
                if n in teacher_param_map:
                    drift = drift + (p - teacher_param_map[n]).pow(2).sum()
            loss = loss + args.w_drift_lambda * drift

        loss.backward()
        torch.nn.utils.clip_grad_norm_([p for p in student.parameters() if p.requires_grad], 1.0)
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        step += 1
        if step % args.eval_every == 0 or step == args.max_steps:
            ppl = eval_ppl(student, tokenizer, calib_path=args.eval_calib_path,
                           max_seqs=(args.eval_max_seqs or None))
            elapsed = time.time() - t0
            print(f" step={step:5d} tau={tau:.4f} loss={loss.item():.4f} "
                  f"ppl={ppl:.4f} t={elapsed:.0f}s")
            curve.append({"step": step, "tau": tau, "loss": float(loss.item()), "ppl": float(ppl)})
        # NOTE(review): treated as a loop-level branch (own --save_every cadence);
        # "ppl" is the latest eval value, or None if no eval has run yet.
        if args.save_checkpoint and step % args.save_every == 0 and step < args.max_steps:
            torch.save({"student_state": student.state_dict(),
                        "config": vars(args), "step": step, "ppl": ppl}, interim)
            print(f" [intermediate] overwrote {interim} (step {step})")

    # float() mirrors the float(ppl) used for curve entries so the value is
    # JSON-serializable whatever numeric type eval_ppl returns.
    final_ppl = float(eval_ppl(student, tokenizer, calib_path=args.eval_calib_path,
                               max_seqs=(args.eval_max_seqs or None)))
    print(f"\n=== Final PPL (tau={args.tau_end}): {final_ppl:.4f} ===")

    out = {"phase": args.phase, "config": vars(args), "final_ppl": final_ppl,
           "ppl_curve": curve}
    os.makedirs("logs", exist_ok=True)
    out_path = f"logs/rung7_swiglu_{args.phase}_results.json"
    with open(out_path, "w") as f:
        json.dump(out, f, indent=2)
    print(f"Saved to {out_path}")

    if args.save_checkpoint:
        torch.save({"student_state": student.state_dict(),
                    "config": vars(args), "final_ppl": final_ppl}, args.save_checkpoint)
        print(f"Saved checkpoint to {args.save_checkpoint}")
        # interim can never equal args.save_checkpoint, so this cleanup cannot
        # delete the checkpoint that was just written.
        if os.path.exists(interim):
            os.remove(interim)
            print(f"Removed {interim}")
362
+
363
# Script entry point: run the training CLI only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()