Spaces:
Running
Running
| """DELIVERABLE 2 -- train a HyperExpert per layer to mimic StarCoder2-3b's MLPs. | |
| Reuses train_compare.py's corpus buffer builder, FIXED held-out eval set, held-out | |
| perplexity eval, and 8-bit Adam, so results are directly comparable to the bank | |
| baseline (E=2048 bank feat-arm: held-out ppl 52.4 @5M tok, 26.2 @30M tok; original | |
| 6.16). Objective is identical to the bank's winning "feat" arm: | |
| loss = NTP + feat_w * mean_layer relMSE(expert_out, orig_mlp(in)) | |
| with tuned layernorms. | |
| Usage (single-config validation, ~2M tokens): | |
| python train_hyper.py --c 128 --r 16 --b 2048 --budget 2e6 | |
| """ | |
| import argparse, importlib.util, math, time, json | |
| import torch, torch.nn.functional as F | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Reuse train_compare.py's corpus buffer builder + collect (and, transitively,
# the Bank module) by loading that script directly from its file path under
# the module name "tcmp".
spec = importlib.util.spec_from_file_location("tcmp", "/tmp/train_compare.py")
tcmp = importlib.util.module_from_spec(spec)
spec.loader.exec_module(tcmp)  # executes train_compare.py's top level
build_buffer = tcmp.build_buffer

from hyper_expert import HyperExpert, boundary_token_mask, chunk_ids_from_tokens
from bigcorpus import build_bigcorpus

DEV = 0  # CUDA device index used for the model, the experts and every batch
MODEL = "bigcode/starcoder2-3b"  # model id handed to the from_pretrained() calls
SEED = 1234  # torch RNG seed, applied right before the experts are constructed
def make_optimizer(name, params, lr):
    """Robust, memory-efficient optimizer for the ~800M-param hypernetwork.

    Default is Adafactor (pure PyTorch, no bitsandbytes): it factors the second
    moment and skips the first moment (beta1=None), so its optimizer state is
    negligible -- a plain fp32 AdamW would need ~6.4 GB of moments and OOM the
    16 GB GPU. Adafactor is also robust on Blackwell/sm_120, where bitsandbytes
    CUDA kernels are fragile. `adamw8bit`/`adamw` are kept for bisection only.

    Args:
        name: one of "adafactor", "adamw8bit" or "adamw".
        params: the parameters to optimize (passed straight to the optimizer).
        lr: base learning rate.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: if `name` is not one of the three supported optimizers.
            (Previously any unknown name silently selected fp32 AdamW, i.e. the
            most memory-hungry option.)
    """
    if name == "adafactor":
        # Imported lazily so the other optimizers work without this symbol loaded.
        from transformers.optimization import Adafactor
        opt = Adafactor(params, lr=lr, beta1=None, weight_decay=0.0,
                        scale_parameter=False, relative_step=False, warmup_init=False)
        print("optimizer: Adafactor (factored 2nd moment, no 1st moment)", flush=True)
        return opt
    if name == "adamw8bit":
        # Imported lazily: bitsandbytes is only needed for this branch.
        import bitsandbytes as bnb
        opt = bnb.optim.AdamW8bit(params, lr=lr, betas=(0.9, 0.95))
        print("optimizer: AdamW8bit", flush=True)
        return opt
    if name == "adamw":
        opt = torch.optim.AdamW(params, lr=lr, betas=(0.9, 0.95), foreach=False)
        print("optimizer: AdamW-fp32", flush=True)
        return opt
    raise ValueError(
        f"unknown optimizer {name!r}; expected 'adafactor', 'adamw8bit' or 'adamw'")
def main():
    """Train one HyperExpert per transformer layer against the frozen original MLPs.

    Pipeline, in order:
      1. Parse CLI flags; derive tokens-per-step (``batch * ctx``) and the step count.
      2. Build the token buffer (``build_bigcorpus`` with ``--bigcorpus``, otherwise
         ``build_buffer``). The first ``ne * ctx`` tokens form the held-out eval set,
         the remainder is the training stream.
      3. Load the frozen base model and measure the ORIGINAL held-out perplexity.
      4. Replace every layer's MLP with a HyperExpert, unfreeze both layernorms of
         every layer, and train with
             loss = NTP + feat_w * mean_layer relMSE(expert_out, orig_mlp(in))
         plus ``teacher_ce_w`` * tempered KL to the original model's logits when enabled.
      5. Log held-out perplexity every ``--eval-every`` steps and finally write a JSON
         summary to ``hyper_{tag}_c{c}_r{r}_b{b}.json`` in the working directory.

    Takes no arguments (flags are read by argparse) and returns ``None``.
    Uses the CUDA device with index ``DEV``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--c", type=int, default=128)  # encoder latent
    ap.add_argument("--r", type=int, default=16)  # generated rank
    ap.add_argument("--b", type=int, default=2048)  # base FFN width
    ap.add_argument("--chunk", type=int, default=1,  # tokens per generated expert
                    help="generate ONE low-rank expert per N-token chunk and apply it "
                         "across the chunk (base FFN still per token). 1 = per-token (orig).")
    ap.add_argument("--chunk-mode", default="fixed", choices=["fixed", "sentence"],
                    help="fixed = arange//chunk (uses --chunk); sentence = chunk at "
                         "sentence/line boundaries ( . ! ? ; newline), capped at --chunk-cap "
                         "tokens (a chunk never crosses a boundary, base FFN still per token).")
    ap.add_argument("--chunk-cap", type=int, default=20,
                    help="sentence mode: max tokens/chunk N (long segments split into <=N).")
    ap.add_argument("--budget", type=float, default=2e6)
    ap.add_argument("--ctx", type=int, default=256)
    ap.add_argument("--batch", type=int, default=2)
    ap.add_argument("--eval-tokens", type=int, default=40000)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--warmup", type=int, default=150)
    ap.add_argument("--feat-w", type=float, default=1.0)
    ap.add_argument("--teacher-ce-w", type=float, default=0.0,
                    help="weight on tempered KL to the FROZEN ORIGINAL model logits "
                         "(Mikey's teacher-CE; additive, on top of NTP+feat). 0 = off.")
    ap.add_argument("--kl-temp", type=float, default=2.0)
    ap.add_argument("--base-init", default="random", choices=["topnorm", "random"],
                    help="base FFN subset: random (default, == baseline Bank, INIT ~12k) "
                         "or topnorm (INIT ~42k -- worse under 30-layer stacking; see "
                         "diag_warmstart.py). g_v=0 makes the hyper start == the chosen bank.")
    ap.add_argument("--bigcorpus", action="store_true",
                    help="capacity-vs-data control: train on a LARGE non-repeating stream "
                         "(disjoint from eval, never wraps) while the eval stays BYTE-IDENTICAL "
                         "to the prior 25.94 runs. See bigcorpus.py.")
    ap.add_argument("--train-target-tokens", type=float, default=0.0,
                    help="bigcorpus: unique training tokens to materialize. 0 => 1.5*budget "
                         "(comfortably exceeds the budget so the train loop never wraps).")
    ap.add_argument("--skip-docs", type=int, default=9000,
                    help="bigcorpus: training draws docs [skip:] of each source; the eval pool "
                         "is docs [0:9000], so skip>=9000 keeps train DISJOINT from eval.")
    ap.add_argument("--eval-every", type=int, default=500)
    ap.add_argument("--opt", default="adafactor",
                    choices=["adafactor", "adamw8bit", "adamw"])
    ap.add_argument("--param-dtype", default="bf16", choices=["bf16", "fp32"])
    ap.add_argument("--tag", default="hyper")
    args = ap.parse_args()

    pdtype = torch.bfloat16 if args.param_dtype == "bf16" else torch.float32
    per = args.batch * args.ctx  # tokens consumed per optimizer step
    steps = int(args.budget / per)
    cm = f"sentence(cap={args.chunk_cap})" if args.chunk_mode == "sentence" else f"fixed({args.chunk})"
    print(f"=== HYPER c={args.c} r={args.r} b={args.b} chunk_mode={cm} ctx={args.ctx} "
          f"batch={args.batch} steps={steps} budget={args.budget/1e6:.1f}M feat_w={args.feat_w} "
          f"base_init={args.base_init} teacher_ce_w={args.teacher_ce_w} (T={args.kl_temp}) ===",
          flush=True)
    torch.cuda.set_device(DEV)
    torch.cuda.init()
    tok = AutoTokenizer.from_pretrained(MODEL)
    # sentence mode: precompute the boundary-token mask once (tokenizer-side).
    bmask = boundary_token_mask(tok) if args.chunk_mode == "sentence" else None
    if bmask is not None:
        print(f"[sentence] {int(bmask.sum())}/{len(tok)} vocab tokens are boundaries "
              f"( . ! ? ; newline ), cap={args.chunk_cap}", flush=True)

    # ---- token buffer: [ eval (ne * ctx tokens) | training stream ] ----
    ne = args.eval_tokens // args.ctx  # number of held-out eval sequences
    t_buf = time.time()
    if args.bigcorpus:
        # CAPACITY-vs-DATA control: eval byte-identical, training big + non-repeating.
        train_target = args.train_target_tokens or (1.5 * args.budget)
        train_target = int(train_target)
        buf = build_bigcorpus(tok, ne * args.ctx, train_target,
                              skip=args.skip_docs)
        print(f"buffer tokens: {len(buf)} (eval {ne*args.ctx} + train_target {train_target}) "
              f"in {time.time()-t_buf:.0f}s", flush=True)
    else:
        need = ne * args.ctx + steps * per + per * 4
        buf = build_buffer(tok, need)
        print(f"buffer tokens: {len(buf)} (needed {need}) in {time.time()-t_buf:.0f}s", flush=True)
    assert len(buf) >= ne * args.ctx + per, "corpus too small"
    eval_ids = buf[:ne * args.ctx].view(ne, args.ctx)
    eb = [eval_ids[i:i + args.batch].to(DEV) for i in range(0, ne, args.batch)]
    train_buf = buf[ne * args.ctx:]
    span = (len(train_buf) // per) * per  # training tokens rounded down to whole steps
    if args.bigcorpus:
        # Confirm in code/logs that the train loop NEVER wraps: the largest token index it
        # reads is (steps-1)*per + per = steps*per; it must stay within `span`.
        max_idx = steps * per
        assert span >= max_idx + per, (
            f"train WOULD WRAP: span={span} < max_idx+per={max_idx+per} "
            f"(train_buf {len(train_buf)} tok, need >= {max_idx+per})")
        # max(1, ...) only matters for a zero-step run (budget < batch*ctx), where
        # max_idx is 0 and the headroom ratio would otherwise divide by zero.
        print(f"[bigcorpus] NO-WRAP confirmed: train_buf {len(train_buf)} tok, usable span "
              f"{span}, max train index {max_idx} (<{span}) over {steps} steps => every "
              f"training token unique, loop never wraps ({span/max(1, max_idx):.2f}x headroom)", flush=True)

    def train_batch(step):
        """Return the (batch, ctx) slice of the training stream for `step`, on DEV.

        The start offset wraps modulo ``max(per, span - per)`` so a slice never runs
        past the usable span.
        """
        s = (step * per) % max(per, span - per)
        return train_buf[s:s + per].view(args.batch, args.ctx).to(DEV)

    # ---- frozen base model ----
    m = AutoModelForCausalLM.from_pretrained(MODEL, dtype=torch.bfloat16, device_map={"": DEV})
    m.config.use_cache = False
    for p in m.parameters():
        p.requires_grad_(False)
    layers = m.model.layers
    orig_mlps = [l.mlp for l in layers]  # kept so the original MLPs can be swapped back in

    def install(ex):
        """Replace each layer's MLP with the matching module from `ex`."""
        for l, e in zip(layers, ex):
            l.mlp = e

    def uninstall():
        """Restore every layer's original MLP."""
        for l, om in zip(layers, orig_mlps):
            l.mlp = om

    experts = []  # filled below; predeclared so set_cids/eval_ppl can close over it

    def set_cids(ids):
        """sentence mode: compute per-token chunk ids for this batch and push to every
        installed expert before the forward (constants -> survive grad-checkpoint re-run).
        No-op when experts aren't installed (e.g. the uninstalled ORIGINAL ppl ref)."""
        if bmask is None or not experts:
            return
        cids = chunk_ids_from_tokens(ids, bmask, args.chunk_cap)
        for e in experts:
            e.set_chunk_ids(cids)

    def eval_ppl():
        """Return held-out perplexity over the fixed eval batches `eb`.

        Per-batch mean losses are weighted by batch size, so a smaller final batch
        is accounted for. The model's train/eval mode is restored on exit.
        """
        was = m.training
        m.eval()
        tot = 0.0
        n = 0
        # no_grad: once the layernorms/experts are trainable, an eval forward would
        # otherwise record an autograd graph that is never used, wasting memory and
        # inflating the peak-VRAM figure reported below. Loss values are unaffected.
        with torch.no_grad():
            for ids in eb:
                set_cids(ids)
                tot += m(ids, labels=ids).loss.item() * ids.shape[0]
                n += ids.shape[0]
        if was:
            m.train()
        return math.exp(tot / n)

    # Reference perplexity of the untouched model (experts list is still empty here).
    uninstall()
    op = eval_ppl()
    print(f"[ref] ORIGINAL held-out ppl {op:.3f}", flush=True)

    # ---- build + install experts ----
    torch.manual_seed(SEED)
    experts[:] = [HyperExpert(om, args.c, args.r, args.b, dtype=pdtype, init=args.base_init,
                              chunk=args.chunk, chunk_mode=args.chunk_mode,
                              chunk_cap=args.chunk_cap).to(DEV)
                  for om in orig_mlps]
    fp_layer = experts[0].footprint()
    print(f"[footprint] {fp_layer/1e6:.2f}M params/layer ({fp_layer*len(layers)/1e6:.1f}M total experts)", flush=True)
    install(experts)

    # Trainable set: every expert parameter plus both layernorms of every layer.
    params = [p for e in experts for p in e.parameters()]
    for l in layers:
        for mod in (l.input_layernorm, l.post_attention_layernorm):
            for p in mod.parameters():
                p.requires_grad_(True)
                params.append(p)
    m.gradient_checkpointing_enable()
    opt = make_optimizer(args.opt, params, args.lr)

    def lr_scale(s):
        """LR multiplier: linear warmup over `--warmup` steps, then cosine decay to 0."""
        if s < args.warmup:
            return s / args.warmup
        frac = min(1.0, (s - args.warmup) / max(1, steps - args.warmup))
        return 0.5 * (1 + math.cos(math.pi * frac))

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_scale)
    ip = eval_ppl()
    print(f"[{args.tag}] INIT held-out ppl {ip:.1f}", flush=True)

    # ---- training loop ----
    m.train()
    t0 = time.time()
    ema = None  # exponential moving average of the NTP cross-entropy
    traj = []   # [step, held-out ppl, train ppl] rows for the JSON summary
    torch.cuda.reset_peak_memory_stats(DEV)
    for step in range(steps):
        ids = train_batch(step)
        set_cids(ids)  # sentence mode: per-token chunk ids for this batch
        # tempered teacher logits from the FROZEN ORIGINAL model (Mikey's idea):
        # swap experts out for the original MLPs, grab logits with no grad, swap back.
        if args.teacher_ce_w > 0:
            uninstall()
            was = m.training
            m.eval()
            with torch.no_grad():
                tl = m(ids).logits
            install(experts)
            if was:
                m.train()
        out = m(ids, labels=ids)
        loss = out.loss
        ce = out.loss.item()
        if args.feat_w > 0:
            # Feature-matching term: per-layer relative MSE between the expert's output
            # and the original MLP applied to the same input, averaged over layers.
            # NOTE(review): relies on each expert exposing `last_in` / `last_out` from
            # its most recent forward -- confirm against hyper_expert.HyperExpert.
            fl = 0.0
            for om, e in zip(orig_mlps, experts):
                with torch.no_grad():
                    tgt = om(e.last_in)
                fl = fl + ((e.last_out - tgt).float().pow(2).mean()
                           / tgt.float().pow(2).mean().clamp_min(1e-6))
            loss = loss + args.feat_w * fl / len(experts)
        if args.teacher_ce_w > 0:
            T = args.kl_temp
            # "batchmean" divides by the leading (batch) dim; dividing by shape[1] as
            # well normalizes by the second (sequence) dim. T*T is the usual
            # temperature rescaling of the distillation gradient.
            kl = F.kl_div(F.log_softmax(out.logits / T, -1), F.softmax(tl / T, -1),
                          reduction="batchmean") * (T * T) / out.logits.shape[1]
            loss = loss + args.teacher_ce_w * kl
            del tl
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 1.0)
        opt.step()
        sched.step()
        ema = ce if ema is None else 0.98 * ema + 0.02 * ce
        if (step + 1) % args.eval_every == 0:
            ep = eval_ppl()
            m.train()
            tps = (step + 1) * per / (time.time() - t0)
            peak = torch.cuda.max_memory_allocated(DEV) / 1e9
            train_ppl = math.exp(ema)  # exp of the NTP ce ema = train perplexity
            gap = ep - train_ppl
            traj.append([step + 1, round(ep, 3), round(train_ppl, 3)])
            print(f"[{args.tag}] step {step+1}/{steps} ce_ema {ema:.3f} train_ppl {train_ppl:.1f} "
                  f"heldout_ppl {ep:.1f} gap {gap:+.1f} (orig {op:.1f}) {tps/1000:.1f}k tok/s "
                  f"peakVRAM {peak:.2f}GB", flush=True)

    # ---- final report ----
    fpppl = eval_ppl()
    # ema stays None when the run had zero steps.
    final_train_ppl = math.exp(ema) if ema is not None else float("nan")
    peak = torch.cuda.max_memory_allocated(DEV) / 1e9
    print(f"\n[result {args.tag}] ORIG {op:.3f} | INIT {ip:.1f} | FINAL heldout {fpppl:.3f} ppl "
          f"| FINAL train {final_train_ppl:.3f} ppl | gap {fpppl-final_train_ppl:+.3f} "
          f"(c={args.c} r={args.r} b={args.b}, {steps*per/1e6:.1f}M tokens) "
          f"peakVRAM {peak:.2f}GB opt={args.opt} dtype={args.param_dtype}", flush=True)
    res = {"c": args.c, "r": args.r, "b": args.b, "chunk": args.chunk,
           "chunk_mode": args.chunk_mode, "chunk_cap": args.chunk_cap, "budget": args.budget,
           "footprint_per_layer": fp_layer, "original_ppl": op,
           "init_ppl": ip, "final_ppl": fpppl, "final_train_ppl": final_train_ppl,
           "bigcorpus": args.bigcorpus,
           "train_target_tokens": (args.train_target_tokens or (1.5 * args.budget)) if args.bigcorpus else None,
           "peak_vram_gb": round(peak, 2),
           "opt": args.opt, "param_dtype": args.param_dtype,
           "base_init": args.base_init, "teacher_ce_w": args.teacher_ce_w,
           "kl_temp": args.kl_temp, "traj": traj}
    # Context manager so the results file is flushed and closed deterministically.
    out_path = f"hyper_{args.tag}_c{args.c}_r{args.r}_b{args.b}.json"
    with open(out_path, "w") as fh:
        json.dump(res, fh, indent=2)
    print("HYPER DONE", flush=True)
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()