"""DELIVERABLE 2 -- train a HyperExpert per layer to mimic StarCoder2-3b's MLPs.

Reuses train_compare.py's corpus buffer builder, FIXED held-out eval set and held-out
perplexity eval, so results are directly comparable to the bank baseline
(E=2048 bank feat-arm: held-out ppl 52.4 @5M tok, 26.2 @30M tok; original 6.16).
8-bit Adam is still selectable with `--opt adamw8bit`; the default optimizer in this
script is Adafactor (see `make_optimizer` for the reason).

Objective is identical to the bank's winning "feat" arm:

    loss = NTP + feat_w * mean_layer relMSE(expert_out, orig_mlp(in))

with tuned layernorms.

Usage (single-config validation, ~2M tokens):

    python train_hyper.py --c 128 --r 16 --b 2048 --budget 2e6
"""
import argparse
import importlib.util
import json
import math
import time

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# reuse the corpus buffer builder + collect (and, transitively, the Bank module)
spec = importlib.util.spec_from_file_location("tcmp", "/tmp/train_compare.py")
tcmp = importlib.util.module_from_spec(spec)
spec.loader.exec_module(tcmp)
build_buffer = tcmp.build_buffer

from hyper_expert import HyperExpert, boundary_token_mask, chunk_ids_from_tokens
from bigcorpus import build_bigcorpus

DEV = 0
MODEL = "bigcode/starcoder2-3b"
SEED = 1234


def make_optimizer(name: str, params, lr: float):
    """Build the optimizer for the ~800M-param hypernetwork.

    Default is Adafactor (pure PyTorch, no bitsandbytes): it factors the second moment
    and skips the first moment (beta1=None), so its optimizer state is negligible -- a
    plain fp32 AdamW would need ~6.4 GB of moments and OOM the 16 GB GPU. Adafactor is
    also robust on Blackwell/sm_120, where bitsandbytes CUDA kernels are fragile.
    `adamw8bit`/`adamw` are kept for bisection only.

    Args:
        name: one of "adafactor", "adamw8bit", "adamw".
        params: iterable of parameters to optimize.
        lr: learning rate handed to the optimizer.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if `name` is not a known optimizer. Previously any unknown name
            silently fell through to fp32 AdamW -- the memory-hungry variant.
    """
    if name == "adafactor":
        from transformers.optimization import Adafactor
        opt = Adafactor(params, lr=lr, beta1=None, weight_decay=0.0,
                        scale_parameter=False, relative_step=False, warmup_init=False)
        print("optimizer: Adafactor (factored 2nd moment, no 1st moment)", flush=True)
        return opt
    if name == "adamw8bit":
        import bitsandbytes as bnb
        opt = bnb.optim.AdamW8bit(params, lr=lr, betas=(0.9, 0.95))
        print("optimizer: AdamW8bit", flush=True)
        return opt
    if name == "adamw":
        opt = torch.optim.AdamW(params, lr=lr, betas=(0.9, 0.95), foreach=False)
        print("optimizer: AdamW-fp32", flush=True)
        return opt
    raise ValueError(
        f"unknown optimizer {name!r}; expected 'adafactor', 'adamw8bit' or 'adamw'")


def main():
    """Parse CLI args, train one HyperExpert per layer, and write a JSON result file.

    Pipeline: build the token buffer (eval slice first, training slice after), measure the
    ORIGINAL model's held-out perplexity, replace every layer's MLP with a HyperExpert,
    train experts + layernorms on NTP (+ feature-matching, + optional teacher KL), and
    dump metrics to `hyper_<tag>_c<c>_r<r>_b<b>.json`.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--c", type=int, default=128)      # encoder latent
    ap.add_argument("--r", type=int, default=16)       # generated rank
    ap.add_argument("--b", type=int, default=2048)     # base FFN width
    ap.add_argument("--chunk", type=int, default=1,    # tokens per generated expert
                    help="generate ONE low-rank expert per N-token chunk and apply it "
                         "across the chunk (base FFN still per token). 1 = per-token (orig).")
    ap.add_argument("--chunk-mode", default="fixed", choices=["fixed", "sentence"],
                    help="fixed = arange//chunk (uses --chunk); sentence = chunk at "
                         "sentence/line boundaries ( . ! ? ; newline), capped at --chunk-cap "
                         "tokens (a chunk never crosses a boundary, base FFN still per token).")
    ap.add_argument("--chunk-cap", type=int, default=20,
                    help="sentence mode: max tokens/chunk N (long segments split into <=N).")
    ap.add_argument("--budget", type=float, default=2e6)
    ap.add_argument("--ctx", type=int, default=256)
    ap.add_argument("--batch", type=int, default=2)
    ap.add_argument("--eval-tokens", type=int, default=40000)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--warmup", type=int, default=150)
    ap.add_argument("--feat-w", type=float, default=1.0)
    ap.add_argument("--teacher-ce-w", type=float, default=0.0,
                    help="weight on tempered KL to the FROZEN ORIGINAL model logits "
                         "(Mikey's teacher-CE; additive, on top of NTP+feat). 0 = off.")
    ap.add_argument("--kl-temp", type=float, default=2.0)
    ap.add_argument("--base-init", default="random", choices=["topnorm", "random"],
                    help="base FFN subset: random (default, == baseline Bank, INIT ~12k) "
                         "or topnorm (INIT ~42k -- worse under 30-layer stacking; see "
                         "diag_warmstart.py). g_v=0 makes the hyper start == the chosen bank.")
    ap.add_argument("--bigcorpus", action="store_true",
                    help="capacity-vs-data control: train on a LARGE non-repeating stream "
                         "(disjoint from eval, never wraps) while the eval stays BYTE-IDENTICAL "
                         "to the prior 25.94 runs. See bigcorpus.py.")
    ap.add_argument("--train-target-tokens", type=float, default=0.0,
                    help="bigcorpus: unique training tokens to materialize. 0 => 1.5*budget "
                         "(comfortably exceeds the budget so the train loop never wraps).")
    ap.add_argument("--skip-docs", type=int, default=9000,
                    help="bigcorpus: training draws docs [skip:] of each source; the eval pool "
                         "is docs [0:9000], so skip>=9000 keeps train DISJOINT from eval.")
    ap.add_argument("--eval-every", type=int, default=500)
    ap.add_argument("--opt", default="adafactor", choices=["adafactor", "adamw8bit", "adamw"])
    ap.add_argument("--param-dtype", default="bf16", choices=["bf16", "fp32"])
    ap.add_argument("--tag", default="hyper")
    args = ap.parse_args()

    pdtype = torch.bfloat16 if args.param_dtype == "bf16" else torch.float32
    per = args.batch * args.ctx          # tokens consumed per optimizer step
    steps = int(args.budget / per)
    cm = f"sentence(cap={args.chunk_cap})" if args.chunk_mode == "sentence" else f"fixed({args.chunk})"
    print(f"=== HYPER c={args.c} r={args.r} b={args.b} chunk_mode={cm} ctx={args.ctx} "
          f"batch={args.batch} steps={steps} budget={args.budget/1e6:.1f}M feat_w={args.feat_w} "
          f"base_init={args.base_init} teacher_ce_w={args.teacher_ce_w} (T={args.kl_temp}) ===",
          flush=True)

    ne = args.eval_tokens // args.ctx    # number of held-out eval sequences
    if ne < 1:
        # Fail before the expensive tokenizer/model load: an empty eval set would otherwise
        # surface as a ZeroDivisionError inside eval_ppl.
        raise ValueError(f"--eval-tokens ({args.eval_tokens}) must be >= --ctx ({args.ctx}): "
                         f"the held-out eval set would be empty")

    torch.cuda.set_device(DEV)
    torch.cuda.init()
    tok = AutoTokenizer.from_pretrained(MODEL)

    # sentence mode: precompute the boundary-token mask once (tokenizer-side).
    bmask = boundary_token_mask(tok) if args.chunk_mode == "sentence" else None
    if bmask is not None:
        print(f"[sentence] {int(bmask.sum())}/{len(tok)} vocab tokens are boundaries "
              f"( . ! ? ; newline ), cap={args.chunk_cap}", flush=True)

    t_buf = time.time()
    train_target = None                  # only meaningful in --bigcorpus mode
    if args.bigcorpus:
        # CAPACITY-vs-DATA control: eval byte-identical, training big + non-repeating.
        train_target = int(args.train_target_tokens or (1.5 * args.budget))
        buf = build_bigcorpus(tok, ne * args.ctx, train_target, skip=args.skip_docs)
        print(f"buffer tokens: {len(buf)} (eval {ne*args.ctx} + train_target {train_target}) "
              f"in {time.time()-t_buf:.0f}s", flush=True)
    else:
        need = ne * args.ctx + steps * per + per * 4
        buf = build_buffer(tok, need)
        print(f"buffer tokens: {len(buf)} (needed {need}) in {time.time()-t_buf:.0f}s",
              flush=True)
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if len(buf) < ne * args.ctx + per:
        raise ValueError("corpus too small")

    # The first ne*ctx tokens are the FIXED eval set; everything after is training data.
    eval_ids = buf[:ne * args.ctx].view(ne, args.ctx)
    eb = [eval_ids[i:i + args.batch].to(DEV) for i in range(0, ne, args.batch)]
    train_buf = buf[ne * args.ctx:]
    span = (len(train_buf) // per) * per     # usable training tokens (whole batches only)

    if args.bigcorpus:
        # Confirm in code/logs that the train loop NEVER wraps: the largest token index it
        # reads is (steps-1)*per + per = steps*per; it must stay within `span`.
        max_idx = steps * per
        if span < max_idx + per:
            raise ValueError(
                f"train WOULD WRAP: span={span} < max_idx+per={max_idx+per} "
                f"(train_buf {len(train_buf)} tok, need >= {max_idx+per})")
        # max(1, ...) only guards the zero-step case (budget < one batch).
        print(f"[bigcorpus] NO-WRAP confirmed: train_buf {len(train_buf)} tok, usable span "
              f"{span}, max train index {max_idx} (<{span}) over {steps} steps => every "
              f"training token unique, loop never wraps ({span/max(1, max_idx):.2f}x headroom)",
              flush=True)

    def train_batch(step):
        """Return the (batch, ctx) training ids for `step`, wrapping modulo the usable span."""
        s = (step * per) % max(per, span - per)
        return train_buf[s:s + per].view(args.batch, args.ctx).to(DEV)

    m = AutoModelForCausalLM.from_pretrained(MODEL, dtype=torch.bfloat16, device_map={"": DEV})
    m.config.use_cache = False
    for p in m.parameters():
        p.requires_grad_(False)
    layers = m.model.layers
    orig_mlps = [l.mlp for l in layers]

    def install(ex):
        """Plug one module from `ex` into each decoder layer's `mlp` slot."""
        for l, e in zip(layers, ex):
            l.mlp = e

    def uninstall():
        """Put the original MLPs back into every decoder layer."""
        for l, om in zip(layers, orig_mlps):
            l.mlp = om

    experts = []  # filled below; predeclared so set_cids/eval_ppl can close over it

    def set_cids(ids):
        """sentence mode: compute per-token chunk ids for this batch and push to every
        installed expert before the forward (constants -> survive grad-checkpoint re-run).
        No-op when experts aren't installed (e.g. the uninstalled ORIGINAL ppl ref)."""
        if bmask is None or not experts:
            return
        cids = chunk_ids_from_tokens(ids, bmask, args.chunk_cap)
        for e in experts:
            e.set_chunk_ids(cids)

    @torch.no_grad()
    def eval_ppl():
        """Held-out perplexity over the fixed eval batches; restores the train/eval mode."""
        was = m.training
        m.eval()
        tot = 0.0
        n = 0
        for ids in eb:
            set_cids(ids)
            # weight each batch's mean loss by its row count (the last batch may be short)
            tot += m(ids, labels=ids).loss.item() * ids.shape[0]
            n += ids.shape[0]
        if was:
            m.train()
        return math.exp(tot / n)

    uninstall()
    op = eval_ppl()
    print(f"[ref] ORIGINAL held-out ppl {op:.3f}", flush=True)

    torch.manual_seed(SEED)
    experts[:] = [HyperExpert(om, args.c, args.r, args.b, dtype=pdtype, init=args.base_init,
                              chunk=args.chunk, chunk_mode=args.chunk_mode,
                              chunk_cap=args.chunk_cap).to(DEV)
                  for om in orig_mlps]
    fp_layer = experts[0].footprint()
    print(f"[footprint] {fp_layer/1e6:.2f}M params/layer ({fp_layer*len(layers)/1e6:.1f}M total experts)",
          flush=True)
    install(experts)

    # Trainable set: every expert parameter plus both layernorms of every decoder layer.
    params = [p for e in experts for p in e.parameters()]
    ln_params = []
    for l in layers:
        for mod in (l.input_layernorm, l.post_attention_layernorm):
            for p in mod.parameters():
                p.requires_grad_(True)
                ln_params.append(p)
    params.extend(ln_params)

    # The layernorms are tuned, so swapping the MLPs back is NOT enough to recover the
    # original model. Keep a pristine copy of the layernorm weights (taken before any
    # optimizer step) so the teacher forward really is the FROZEN ORIGINAL model.
    ln_frozen = [p.detach().clone() for p in ln_params] if args.teacher_ce_w > 0 else []

    @torch.no_grad()
    def teacher_logits(ids):
        """Logits of the frozen original model: original MLPs + original layernorm weights.

        Temporarily swaps the experts and the tuned layernorm tensors out, and always swaps
        them back (even if the forward raises) so training state is never left corrupted.
        """
        was = m.training
        uninstall()
        m.eval()
        tuned = [p.data for p in ln_params]
        for p, frozen in zip(ln_params, ln_frozen):
            p.data = frozen
        try:
            return m(ids).logits
        finally:
            for p, t in zip(ln_params, tuned):
                p.data = t
            install(experts)
            if was:
                m.train()

    m.gradient_checkpointing_enable()
    opt = make_optimizer(args.opt, params, args.lr)

    def lr_scale(s):
        """Linear warmup to 1.0 over `--warmup` steps, then cosine decay over the rest."""
        if s < args.warmup:
            return s / args.warmup
        frac = min(1.0, (s - args.warmup) / max(1, steps - args.warmup))
        return 0.5 * (1 + math.cos(math.pi * frac))

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_scale)

    ip = eval_ppl()
    print(f"[{args.tag}] INIT held-out ppl {ip:.1f}", flush=True)

    m.train()
    t0 = time.time()
    ema = None      # EMA of the NTP cross-entropy only (not the full loss)
    traj = []
    torch.cuda.reset_peak_memory_stats(DEV)
    for step in range(steps):
        ids = train_batch(step)
        set_cids(ids)  # sentence mode: per-token chunk ids for this batch

        # tempered teacher logits from the FROZEN ORIGINAL model (Mikey's idea):
        # swap experts out for the original MLPs, grab logits with no grad, swap back.
        tl = teacher_logits(ids) if args.teacher_ce_w > 0 else None

        out = m(ids, labels=ids)
        loss = out.loss
        ce = out.loss.item()

        if args.feat_w > 0:
            # NOTE(review): relies on HyperExpert stashing `last_in`/`last_out` from the
            # forward above with autograd still attached under gradient checkpointing --
            # confirm in hyper_expert.py.
            fl = 0.0
            for om, e in zip(orig_mlps, experts):
                with torch.no_grad():
                    tgt = om(e.last_in)
                # relative MSE: squared error normalised by the target's mean square
                fl = fl + ((e.last_out - tgt).float().pow(2).mean()
                           / tgt.float().pow(2).mean().clamp_min(1e-6))
            loss = loss + args.feat_w * fl / len(experts)

        if args.teacher_ce_w > 0:
            T = args.kl_temp
            # batchmean divides by the batch dim; the extra / seq_len makes it per-token.
            kl = F.kl_div(F.log_softmax(out.logits / T, -1), F.softmax(tl / T, -1),
                          reduction="batchmean") * (T * T) / out.logits.shape[1]
            loss = loss + args.teacher_ce_w * kl
            del tl

        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 1.0)
        opt.step()
        sched.step()
        ema = ce if ema is None else 0.98 * ema + 0.02 * ce

        # --eval-every <= 0 disables periodic evals (it used to divide by zero).
        if args.eval_every > 0 and (step + 1) % args.eval_every == 0:
            ep = eval_ppl()  # restores train mode itself
            tps = (step + 1) * per / (time.time() - t0)
            peak = torch.cuda.max_memory_allocated(DEV) / 1e9
            train_ppl = math.exp(ema)  # exp of the NTP ce ema = train perplexity
            gap = ep - train_ppl
            traj.append([step + 1, round(ep, 3), round(train_ppl, 3)])
            print(f"[{args.tag}] step {step+1}/{steps} ce_ema {ema:.3f} train_ppl {train_ppl:.1f} "
                  f"heldout_ppl {ep:.1f} gap {gap:+.1f} (orig {op:.1f}) {tps/1000:.1f}k tok/s "
                  f"peakVRAM {peak:.2f}GB", flush=True)

    fpppl = eval_ppl()
    final_train_ppl = math.exp(ema) if ema is not None else float("nan")
    peak = torch.cuda.max_memory_allocated(DEV) / 1e9
    print(f"\n[result {args.tag}] ORIG {op:.3f} | INIT {ip:.1f} | FINAL heldout {fpppl:.3f} ppl "
          f"| FINAL train {final_train_ppl:.3f} ppl | gap {fpppl-final_train_ppl:+.3f} "
          f"(c={args.c} r={args.r} b={args.b}, {steps*per/1e6:.1f}M tokens) "
          f"peakVRAM {peak:.2f}GB opt={args.opt} dtype={args.param_dtype}", flush=True)

    res = {
        "c": args.c, "r": args.r, "b": args.b, "chunk": args.chunk,
        "chunk_mode": args.chunk_mode, "chunk_cap": args.chunk_cap,
        "budget": args.budget, "footprint_per_layer": fp_layer,
        "original_ppl": op, "init_ppl": ip, "final_ppl": fpppl,
        "final_train_ppl": final_train_ppl,
        "bigcorpus": args.bigcorpus,
        # the value actually handed to build_bigcorpus (None outside --bigcorpus mode)
        "train_target_tokens": train_target,
        "peak_vram_gb": round(peak, 2), "opt": args.opt, "param_dtype": args.param_dtype,
        "base_init": args.base_init,
        "teacher_ce_w": args.teacher_ce_w, "kl_temp": args.kl_temp, "traj": traj,
    }
    # `with` so the result file is flushed and closed deterministically.
    with open(f"hyper_{args.tag}_c{args.c}_r{args.r}_b{args.b}.json", "w",
              encoding="utf-8") as fh:
        json.dump(res, fh, indent=2)
    print("HYPER DONE", flush=True)


if __name__ == "__main__":
    main()