ncylich commited on
Commit
58d34c8
·
verified ·
1 Parent(s): b460126

Upload rung8_grouped_g4.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. rung8_grouped_g4.py +276 -0
rung8_grouped_g4.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """rung8_grouped_g4.py — coarse-grained grouped expert routing on Gemma-4.
3
+
4
+ Each MLP's D_FFN neurons are partitioned into K groups (cluster assignments
5
+ loaded from analyze_activation_groups.py output). Per token, top-K_active
6
+ groups are selected; all neurons within a selected group are activated.
7
+
8
+ Vs rung7's per-neuron mask:
9
+ - Coarser-grained → potentially compute/memory-bandwidth efficient
10
+ (skip whole groups, not individual neurons)
11
+ - Same per-token density target but expressed via group selection
12
+
13
+ Usage mirrors rung7 + adds:
14
+ --K_groups 64 --group_assignments_dir logs/groups --group_tag s25
15
+ """
16
+ import argparse, json, math, os, time
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.optim import AdamW
21
+ from torch.optim.lr_scheduler import CosineAnnealingLR
22
+
23
+ from gemma4_hf import load_gemma4, DEVICE, N_LAYERS
24
+ from rung6_moe_g4 import (
25
+ Int4QuantLinear, wrap_int4, apply_int4_inplace,
26
+ LoRALinear, wrap_lora,
27
+ load_seqs, eval_ppl, kl_loss, ce_loss, get_tau,
28
+ _d_ffn_at,
29
+ MAX_SEQ_LEN, BATCH, LR, BASELINE_PPL, CLEAN_PPL,
30
+ )
31
+
32
+
33
+ class GroupedMaskedMLP(nn.Module):
34
+ """Top-K_active group routing. Group score = max |gate_act| within group.
35
+
36
+ Forward:
37
+ gate_act = gelu(gate_proj(x)) # [B, T, D_FFN]
38
+ for each token, group_g_score = max over j in group g of |gate_act[j]|
39
+ select top K_active groups → mask all neurons in selected groups
40
+ h = gate_act * up_proj(x) * mask
41
+ out = down_proj(h)
42
+ """
43
+ def __init__(self, base_mlp, K_groups, K_active, group_assignments, freeze_base=False):
44
+ super().__init__()
45
+ self.gate_proj = base_mlp.gate_proj
46
+ self.up_proj = base_mlp.up_proj
47
+ self.down_proj = base_mlp.down_proj
48
+ if freeze_base:
49
+ for p in self.gate_proj.parameters(): p.requires_grad_(False)
50
+ for p in self.up_proj.parameters(): p.requires_grad_(False)
51
+ for p in self.down_proj.parameters(): p.requires_grad_(False)
52
+ self.K_groups = int(K_groups)
53
+ self.K_active = int(K_active)
54
+ # group_assignments: [D_FFN] long, in [0, K_groups)
55
+ self.register_buffer("group_assignments", group_assignments.long())
56
+ # Build group → neuron map (one-hot) for vectorized scatter
57
+ # neuron_in_group[d, g] = 1 if neuron d is in group g, else 0 shape [D_FFN, K_groups]
58
+ D = group_assignments.shape[0]
59
+ nig = torch.zeros(D, K_groups)
60
+ nig.scatter_(1, group_assignments.long().unsqueeze(1), 1.0)
61
+ self.register_buffer("neuron_in_group", nig)
62
+ self.tau = 0.01 # used only for sigmoid relaxation; defaults hard
63
+
64
+ def forward(self, x):
65
+ gate_act = F.gelu(self.gate_proj(x), approximate="tanh") # [B, T, D_FFN]
66
+ up_act = self.up_proj(x)
67
+ gate_abs = gate_act.abs().to(torch.float32)
68
+ B, T, D = gate_abs.shape
69
+ BT = B * T
70
+ flat = gate_abs.view(BT, D)
71
+ # Group score = max within group (vectorized via scatter_reduce)
72
+ group_score = torch.full((BT, self.K_groups), -float("inf"),
73
+ device=gate_act.device, dtype=torch.float32)
74
+ group_score.scatter_reduce_(1, self.group_assignments.unsqueeze(0).expand(BT, -1),
75
+ flat, reduce="amax", include_self=False)
76
+ # Top-K_active groups per token
77
+ top_vals, top_idx = group_score.topk(self.K_active, dim=-1) # [BT, K_active]
78
+ # Sigmoid relaxation around the K_active-th largest group-score:
79
+ # neuron_score[d] = group_score[group_of_d]
80
+ # mask = sigmoid((neuron_score - kth_thr) / tau)
81
+ kth_thr = top_vals[..., -1:] # [BT, 1]
82
+ neuron_score = group_score.gather(1, self.group_assignments.unsqueeze(0).expand(BT, -1).long())
83
+ mask_flat = torch.sigmoid((neuron_score - kth_thr) / max(self.tau, 1e-3))
84
+ mask = mask_flat.view(B, T, D)
85
+ h = gate_act * up_act * mask.to(gate_act.dtype)
86
+ return self.down_proj(h)
87
+
88
+
89
+ def install_grouped(model, K_groups, K_active, group_assignments_dir, group_tag, freeze_base=False):
90
+ mlp_modules = []
91
+ for i in range(N_LAYERS):
92
+ d_ffn = _d_ffn_at(i)
93
+ path = f"{group_assignments_dir}/{group_tag}_K{K_groups}_layer{i}.pt"
94
+ if not os.path.exists(path):
95
+ raise FileNotFoundError(f"Missing group assignments: {path}")
96
+ assignments = torch.load(path, map_location="cpu", weights_only=False)
97
+ if assignments.numel() != d_ffn:
98
+ raise ValueError(f"Layer {i}: assignments size {assignments.numel()} != D_FFN {d_ffn}")
99
+ new_mlp = GroupedMaskedMLP(model.layers[i].mlp,
100
+ K_groups=K_groups, K_active=K_active,
101
+ group_assignments=assignments,
102
+ freeze_base=freeze_base)
103
+ new_mlp = new_mlp.to(DEVICE)
104
+ model.layers[i].mlp = new_mlp
105
+ mlp_modules.append(new_mlp)
106
+ return mlp_modules
107
+
108
+
109
+ def main():
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument("--phase", type=str, default="G1")
112
+ parser.add_argument("--K_groups", type=int, required=True)
113
+ parser.add_argument("--density", type=float, default=0.25,
114
+ help="Target per-token density: K_active = density * K_groups (rounded)")
115
+ parser.add_argument("--K_active", type=int, default=0,
116
+ help="Override K_active explicitly (else computed from density)")
117
+ parser.add_argument("--group_assignments_dir", default="logs/groups")
118
+ parser.add_argument("--group_tag", required=True)
119
+ parser.add_argument("--loss", choices=["kl", "ce"], default="ce")
120
+ parser.add_argument("--int4_qat", action="store_true")
121
+ parser.add_argument("--int4_group_size", type=int, default=32)
122
+ parser.add_argument("--unfreeze_base", action="store_true")
123
+ parser.add_argument("--freeze_embeddings", action="store_true")
124
+ parser.add_argument("--use_lora", action="store_true")
125
+ parser.add_argument("--lora_targets", type=str, default="")
126
+ parser.add_argument("--lora_rank", type=int, default=16)
127
+ parser.add_argument("--lora_alpha", type=float, default=16.0)
128
+ parser.add_argument("--tau", type=float, default=0.01)
129
+ parser.add_argument("--max_steps", type=int, default=500)
130
+ parser.add_argument("--lr", type=float, default=1e-5)
131
+ parser.add_argument("--shuffle_seed", type=int, default=42)
132
+ parser.add_argument("--save_every", type=int, default=200)
133
+ parser.add_argument("--eval_every", type=int, default=100)
134
+ parser.add_argument("--eval_max_seqs", type=int, default=200)
135
+ parser.add_argument("--calib_path", required=True)
136
+ parser.add_argument("--eval_calib_path", required=True)
137
+ parser.add_argument("--load_checkpoint", type=str, default="")
138
+ parser.add_argument("--save_checkpoint", type=str, default="")
139
+ args = parser.parse_args()
140
+
141
+ K_active = args.K_active if args.K_active > 0 else max(1, round(args.K_groups * args.density))
142
+ print(f"=== Rung 8 Grouped — phase={args.phase} ===")
143
+ print(f" K_groups={args.K_groups} K_active={K_active} effective_density={K_active/args.K_groups:.3f}")
144
+ print(f" loss={args.loss} unfreeze_base={args.unfreeze_base} use_lora={args.use_lora}")
145
+ print(f" group_assignments_dir={args.group_assignments_dir} group_tag={args.group_tag}")
146
+ if args.load_checkpoint:
147
+ print(f" load_checkpoint={args.load_checkpoint}")
148
+
149
+ # Teacher only needed for KL; for CE we skip
150
+ teacher_needed = (args.loss == "kl")
151
+ if teacher_needed:
152
+ print("Loading teacher & student...")
153
+ teacher, tokenizer = load_gemma4()
154
+ teacher.eval()
155
+ for p in teacher.parameters(): p.requires_grad_(False)
156
+ student, _ = load_gemma4()
157
+ else:
158
+ print("Loading student only (CE loss; teacher skipped)...")
159
+ teacher = None
160
+ student, tokenizer = load_gemma4()
161
+
162
+ if args.freeze_embeddings:
163
+ for n, p in student.named_parameters():
164
+ if "embed_tokens" in n or "lm_head" in n:
165
+ p.requires_grad_(False)
166
+
167
+ freeze_base_in_mlp = not args.unfreeze_base
168
+ mlp_modules = install_grouped(student,
169
+ K_groups=args.K_groups, K_active=K_active,
170
+ group_assignments_dir=args.group_assignments_dir,
171
+ group_tag=args.group_tag,
172
+ freeze_base=freeze_base_in_mlp)
173
+ print(f" Installed GroupedMaskedMLP on {N_LAYERS} layers")
174
+
175
+ if args.load_checkpoint:
176
+ print(f" Loading checkpoint from {args.load_checkpoint}...")
177
+ ckpt = torch.load(args.load_checkpoint, map_location=DEVICE, weights_only=False)
178
+ # The loaded ckpt has GateMaskedMLP state (no group_assignments, no neuron_in_group).
179
+ # Load with strict=False — only base proj weights match, group buffers stay as we set them.
180
+ missing, unexpected = student.load_state_dict(ckpt["student_state"], strict=False)
181
+ print(f" missing={len(missing)} unexpected={len(unexpected)}")
182
+
183
+ if args.int4_qat:
184
+ Int4QuantLinear._group_size = args.int4_group_size
185
+ n_wrap = wrap_int4(student)
186
+ print(f" Int4 QAT: wrapped {n_wrap} Linear modules")
187
+
188
+ if args.use_lora:
189
+ if args.lora_targets:
190
+ targets = tuple(t.strip() for t in args.lora_targets.split(",") if t.strip())
191
+ n_lora, n_lora_p = wrap_lora(student, rank=args.lora_rank, alpha=args.lora_alpha,
192
+ target_substrings=targets)
193
+ else:
194
+ n_lora, n_lora_p = wrap_lora(student, rank=args.lora_rank, alpha=args.lora_alpha)
195
+ print(f" LoRA: rank={args.lora_rank} alpha={args.lora_alpha} ({n_lora} modules, {n_lora_p/1e6:.2f}M)")
196
+
197
+ if args.load_checkpoint:
198
+ missing2, unexp2 = student.load_state_dict(ckpt["student_state"], strict=False)
199
+ print(f" re-loaded after wrappers: missing={len(missing2)} unexpected={len(unexp2)}")
200
+
201
+ for m in mlp_modules: m.tau = args.tau
202
+
203
+ n_train = sum(p.numel() for p in student.parameters() if p.requires_grad)
204
+ print(f" Trainable params: {n_train/1e6:.3f}M")
205
+
206
+ optimizer = AdamW([p for p in student.parameters() if p.requires_grad],
207
+ lr=args.lr, weight_decay=0.01)
208
+ scheduler = CosineAnnealingLR(optimizer, T_max=args.max_steps, eta_min=args.lr * 0.1)
209
+
210
+ print(f" Train: {args.calib_path}\n Eval: {args.eval_calib_path}")
211
+ train_split = "all" if args.calib_path != args.eval_calib_path else "train"
212
+ seqs = load_seqs(tokenizer, train_split, calib_path=args.calib_path)
213
+ print(f" Loaded {len(seqs)} sequences")
214
+ g = torch.Generator(); g.manual_seed(args.shuffle_seed)
215
+ loader = torch.utils.data.DataLoader(seqs, BATCH, shuffle=True, generator=g)
216
+ loader_iter = iter(loader)
217
+
218
+ step = 0
219
+ t0 = time.time()
220
+ curve = []
221
+ optimizer.zero_grad()
222
+
223
+ while step < args.max_steps:
224
+ try: batch = next(loader_iter)
225
+ except StopIteration:
226
+ loader_iter = iter(loader); batch = next(loader_iter)
227
+ input_ids = batch["input_ids"].to(DEVICE)
228
+ labels = batch["labels"].to(DEVICE)
229
+
230
+ if teacher is not None and args.loss == "kl":
231
+ with torch.no_grad():
232
+ t_logits = teacher(input_ids)
233
+ s_logits = student(input_ids)
234
+
235
+ if args.loss == "kl":
236
+ mask = (labels != -100)
237
+ loss = kl_loss(s_logits, t_logits, temp=2.0, mask=mask)
238
+ else:
239
+ loss = ce_loss(s_logits, labels)
240
+
241
+ loss.backward()
242
+ torch.nn.utils.clip_grad_norm_([p for p in student.parameters() if p.requires_grad], 1.0)
243
+ optimizer.step()
244
+ optimizer.zero_grad()
245
+ scheduler.step()
246
+
247
+ step += 1
248
+ if step % args.eval_every == 0 or step == args.max_steps:
249
+ ppl = eval_ppl(student, tokenizer, calib_path=args.eval_calib_path,
250
+ max_seqs=args.eval_max_seqs if args.eval_max_seqs > 0 else None)
251
+ elapsed = time.time() - t0
252
+ print(f" step={step:5d} loss={loss.item():.4f} ppl={ppl:.4f} t={elapsed:.0f}s", flush=True)
253
+ curve.append({"step": step, "loss": float(loss.item()), "ppl": float(ppl)})
254
+ if args.save_checkpoint and step % args.save_every == 0 and step < args.max_steps:
255
+ interim = args.save_checkpoint.replace(".pt", "_intermediate.pt")
256
+ torch.save({"student_state": student.state_dict(),
257
+ "config": vars(args), "step": step, "ppl": ppl}, interim)
258
+
259
+ final_ppl = eval_ppl(student, tokenizer, calib_path=args.eval_calib_path,
260
+ max_seqs=args.eval_max_seqs if args.eval_max_seqs > 0 else None)
261
+ print(f"\n=== Final PPL: {final_ppl:.4f} ===")
262
+
263
+ out = {"phase": args.phase, "config": vars(args), "final_ppl": final_ppl, "ppl_curve": curve}
264
+ os.makedirs("logs", exist_ok=True)
265
+ with open(f"logs/rung8_{args.phase}_results.json", "w") as f: json.dump(out, f, indent=2)
266
+
267
+ if args.save_checkpoint:
268
+ torch.save({"student_state": student.state_dict(),
269
+ "config": vars(args), "final_ppl": final_ppl}, args.save_checkpoint)
270
+ print(f"Saved {args.save_checkpoint}")
271
+ interim = args.save_checkpoint.replace(".pt", "_intermediate.pt")
272
+ if os.path.exists(interim): os.remove(interim)
273
+
274
+
275
+ if __name__ == "__main__":
276
+ main()