
Phase 2 — Gemma-4-26B-A4B → ~7.6B hypernetwork student (staged capture→train)

Box: pop — RTX 5070 Ti (16 GB, sm_120 Blackwell), 31 GB RAM, torch 2.11+cu128, transformers 5.12. Date: 2026-06-18.

Goal: compress the 26B MoE teacher's text decoder into a ~7.6B hypernetwork student that runs on the 16 GB card, by replacing each of the 30 decoder layers' feed-forward/MoE BLOCK with a per-token GENERATED low-rank expert, trained by layer-local feature distillation against the teacher's cached block I/O.

This file documents the architecture, the staged loop, the validated-small numbers, the full-run config, expected runtime, and how to read progress.

What the student is (per layer)

GemmaHyperExpert (gemma_hyper.py), the Gemma adaptation of the StarCoder testbed's validated HyperExpert (large-hypernetwork / small-generated-rank, the winning shape):

z   = gelu(Enc(x))              Enc: Linear(d=2816 -> c=9856)
U   = (G_U z).view(r=4, d)      G_U: Linear(c -> r*d)     # generated per token
V   = (G_V z).view(d, r=4)      G_V: Linear(c -> d*r)
base(x) = b_proj(gelu(b_fc(x))) widths d->b=640->d        # shared low-width FFN
Yhat = x + base(x) + V @ gelu(U @ x)                       # RESIDUAL block output
  • Residual parameterization: the teacher block is Y = X + D, so the student models the delta on top of the carried residual and is trained to minimize relMSE(Yhat, Y) = mean((Yhat-Y)^2)/mean(Y^2).
  • Exact identity warm start: b_proj and g_v are zero-init, so at step 0 Yhat == X exactly (verified). Init relMSE is then ||D||²/||Y||² — a finite, meaningful start that training drives down.
  • Size: c=9856, r=4, b=640 → 253.4 M params/layer × 30 = 7.60 B (teacher/3.5).
  • fp32 params at train time. Unlike the testbed (all 30 experts resident → bf16), here we train ONE layer at a time, so fp32 is affordable AND necessary: bf16 diverged on outlier-heavy layers (e.g. layer 29, Y absmax≈444 vs rms≈3); fp32 + lr 2e-4 converges every layer. Peak VRAM for a full-size layer: 4.0 GiB.
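The per-layer equations and the warm-start claim above can be sketched in PyTorch as follows. This is a minimal illustration, not the contents of gemma_hyper.py: the class name `HyperExpertSketch`, the attribute names, the presence of bias terms, and the toy sizes are all assumptions made for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HyperExpertSketch(nn.Module):
    """Per-token generated low-rank expert with a residual output (illustrative)."""

    def __init__(self, d: int, c: int, r: int, b: int):
        super().__init__()
        self.d, self.r = d, r
        self.enc = nn.Linear(d, c)        # Enc
        self.g_u = nn.Linear(c, r * d)    # G_U: generates U per token
        self.g_v = nn.Linear(c, d * r)    # G_V: generates V per token
        self.b_fc = nn.Linear(d, b)       # shared low-width FFN, up
        self.b_proj = nn.Linear(b, d)     # shared low-width FFN, down
        # Identity warm start: zero both output-side maps so the delta is 0.
        for lin in (self.g_v, self.b_proj):
            nn.init.zeros_(lin.weight)
            nn.init.zeros_(lin.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n_tokens, d)
        z = F.gelu(self.enc(x))
        U = self.g_u(z).view(-1, self.r, self.d)        # (n, r, d)
        V = self.g_v(z).view(-1, self.d, self.r)        # (n, d, r)
        h = F.gelu(torch.einsum("nrd,nd->nr", U, x))    # gelu(U @ x) per token
        generated = torch.einsum("ndr,nr->nd", V, h)    # V @ h per token
        base = self.b_proj(F.gelu(self.b_fc(x)))
        return x + base + generated                     # residual block output


def rel_mse(y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """relMSE(Yhat, Y) = mean((Yhat - Y)^2) / mean(Y^2)."""
    return ((y_hat - y) ** 2).mean() / (y ** 2).mean()


torch.manual_seed(0)
layer = HyperExpertSketch(d=32, c=64, r=4, b=16)   # toy sizes, not the real config
x = torch.randn(10, 32)
delta = 0.5 * torch.randn(10, 32)                  # synthetic stand-in for the teacher delta D
y = x + delta
y_hat = layer(x)
print(torch.equal(y_hat, x))                       # warm start: Yhat == X at step 0
# Init relMSE and ||D||^2 / ||Y||^2 agree up to floating-point rounding.
print(float(rel_mse(y_hat, y)), float((delta ** 2).mean() / (y ** 2).mean()))
```

Zeroing only the output-side maps (rather than every weight) keeps the encoder and G_U at their random init, so gradients flow into G_V and b_proj from the first step while the block still starts as an exact identity.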

The staged loop (staged_run.py)

Over a large, diverse, NON-REPEATING corpus, repeat (only ONE chunk on disk at once):

  1. CAPTURE a chunk of FRESH tokens: stream the bf16 teacher forward-only on the 16 GB card (Phase-1 method, accelerate NVMe layer-streaming), hooking each layer's FF/MoE block input X (= pre_feedforward_layernorm input) and output Y (= X+D). Cache all 30 layers' (X,Y) bf16 to /mnt/data/cache/gemma_cap/chunkNNN/.
  2. TRAIN (train_layers.py): for each of the 30 layers, load its cached X→Y and train its GemmaHyperExpert for --passes passes with the relMSE loss and Adafactor (chosen because bitsandbytes SEGFAULTs on this Blackwell GPU). Per-layer (model+optimizer) checkpoints persist in student/layerNN.pt and are reloaded each chunk → continual distillation.
  3. DELETE the chunk cache; advance per-source doc cursors; capture the next chunk.
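The three steps above can be sketched as a small orchestrator. This is a simplified outline rather than the actual staged_run.py: `run_staged`, its arguments, and the stub capture/train callables are hypothetical names introduced for the example, and the real capture and training steps are replaced by placeholders.

```python
import json
import shutil
import tempfile
from pathlib import Path
from typing import Callable


def run_staged(root: Path, n_chunks: int, n_layers: int,
               capture: Callable[[Path, int], None],
               train_layer: Callable[[Path, int], None]) -> dict:
    """Capture -> train every layer -> delete, keeping one chunk on disk."""
    state_path = root / "state.json"
    state = json.loads(state_path.read_text()) if state_path.exists() else {"next_chunk": 0}
    while state["next_chunk"] < n_chunks:
        idx = state["next_chunk"]
        chunk_dir = root / f"chunk{idx:03d}"
        chunk_dir.mkdir(parents=True, exist_ok=True)
        capture(chunk_dir, idx)                 # 1. write fresh (X, Y) for all layers
        for layer in range(n_layers):           # 2. layer-local training, one layer at a time
            train_layer(chunk_dir, layer)
        shutil.rmtree(chunk_dir)                # 3. drop the cache before the next capture
        state["next_chunk"] = idx + 1
        tmp = state_path.with_suffix(".tmp")
        tmp.write_text(json.dumps(state))
        tmp.replace(state_path)                 # atomic rename: a kill leaves a valid state file
    return state


# Stub steps so the loop can be exercised without a GPU or a teacher model.
trained = []

def fake_capture(chunk_dir: Path, idx: int) -> None:
    (chunk_dir / "layer00.bin").write_bytes(b"\0" * 8)

def fake_train(chunk_dir: Path, layer: int) -> None:
    trained.append((chunk_dir.name, layer))

with tempfile.TemporaryDirectory() as tmp_root:
    final_state = run_staged(Path(tmp_root), n_chunks=2, n_layers=3,
                             capture=fake_capture, train_layer=fake_train)
    leftover = sorted(p.name for p in Path(tmp_root).iterdir())

print(final_state, leftover)   # only the state file remains; no chunk directories
```

Writing the state only after the chunk is deleted means a relaunch repeats at most the interrupted chunk instead of skipping it.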

A fixed held-out eval chunk (/mnt/data/cache/gemma_cap_eval, captured ONCE from a reserved doc range, never trained) tracks per-layer relMSE over the whole run.

Corpus (gemma_corpus.py): codeparrot-clean (code) + fineweb-edu (web), streamed with per-source doc cursors that only advance → every training token is fresh and non-repeating; reserved docs [0:3000) are the eval holdout. Falls back to local wikitext-103 if HF streaming comes up short, so a capture never stalls on the network. Cursors + next-chunk index persist in student/state.json, so a kill/relaunch resumes the non-repeating stream and the continual student.
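A minimal sketch of the forward-only cursor idea, using in-memory lists in place of the streamed datasets. The function `next_docs`, the JSON layout of the saved state, and the choice to apply the [0:3000) holdout to each source separately are assumptions for illustration; gemma_corpus.py may organize this differently.

```python
import json
from typing import Dict, List, Sequence

HOLDOUT = 3000  # docs [0:3000) are reserved for eval (assumed here to be per source)


def next_docs(sources: Dict[str, Sequence[str]], cursors: Dict[str, int],
              per_source: int) -> List[str]:
    """Take `per_source` unseen docs from each source and advance its cursor."""
    batch: List[str] = []
    for name, docs in sources.items():
        start = max(cursors.get(name, HOLDOUT), HOLDOUT)   # never enter the holdout range
        stop = min(start + per_source, len(docs))
        batch.extend(docs[start:stop])
        cursors[name] = stop                                # cursors only move forward
    return batch


# Toy sources standing in for the code and web streams.
sources = {
    "code": [f"code-{i}" for i in range(3010)],
    "web": [f"web-{i}" for i in range(3010)],
}
cursors: Dict[str, int] = {}
first = next_docs(sources, cursors, per_source=2)

saved = json.dumps({"next_chunk": 1, "cursors": cursors})   # hypothetical state layout
restored = json.loads(saved)["cursors"]                      # simulate kill/relaunch
second = next_docs(sources, restored, per_source=2)

print(first)
print(second)
```

Because the restored cursors pick up exactly where the saved ones stopped, the two batches are disjoint and neither touches the holdout range.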

Why this fits 16 GB (the Phase-1 storage blocker is gone)

Phase 1's blocker was materializing all 30 layers × 50M tokens (6–15 TB). The staged loop keeps only ONE chunk on disk: a 1M-token chunk is ~330 GB on the 6.3 TB USB cache drive (capture writes it, training reads it once, then it's deleted). Training never loads the teacher (Y is cached) and holds one layer's expert + data at a time.
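The chunk size can be checked with back-of-envelope arithmetic, assuming each cached token stores X and Y as bf16 vectors of width d=2816 for every layer and ignoring file-format overhead:

```python
D_MODEL = 2816        # hidden size d from the per-layer spec
N_LAYERS = 30
BYTES_BF16 = 2
TENSORS = 2           # X and Y are both cached


def cache_bytes(tokens: int, layers: int = N_LAYERS) -> int:
    """Raw bytes needed to cache (X, Y) for `tokens` tokens across `layers` layers."""
    return tokens * layers * TENSORS * D_MODEL * BYTES_BF16


chunk = cache_bytes(1_000_000)
per_layer = cache_bytes(1_000_000, layers=1)
print(f"chunk: {chunk / 1e9:.0f} GB ({chunk / 2**30:.0f} GiB), "
      f"per layer: {per_layer / 1e9:.1f} GB")
```

Under these assumptions a 1M-token chunk comes to about 338 GB (315 GiB) and a single layer's slice to about 11.3 GB, in line with the ~330 GB and ~11 GB figures quoted in this plan.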

Full-run config (LAUNCHED)

tokens-per-chunk 1,000,000   passes 2     mb 8192    lr 2e-4   param-dtype fp32
student c=9856 r=4 b=640 (7.60 B)         capture-batch 64     ctx 256
n-layers 30 (capture all)    layers all (train all 30)
cache-root /mnt/data/cache/gemma_cap   eval-dir /mnt/data/cache/gemma_cap_eval
student-dir compression/gemma/student   (checkpoints, progress.jsonl, state.json)
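For reference, the settings above can be collected into one structure and the quoted student size re-derived from c, r, and b. The `RunConfig` class is a hypothetical container, not part of the repository. The parameter count assumes every Linear layer carries a bias, and the tokens-per-capture-batch figure assumes capture-batch counts sequences of ctx tokens.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunConfig:
    d: int = 2816
    c: int = 9856
    r: int = 4
    b: int = 640
    n_layers: int = 30
    tokens_per_chunk: int = 1_000_000
    passes: int = 2
    mb: int = 8192
    lr: float = 2e-4
    param_dtype: str = "fp32"
    capture_batch: int = 64
    ctx: int = 256

    def params_per_layer(self) -> int:
        enc = self.d * self.c + self.c                              # Enc
        gen = 2 * (self.c * self.r * self.d + self.r * self.d)      # G_U and G_V
        base = (self.d * self.b + self.b) + (self.b * self.d + self.d)  # b_fc and b_proj
        return enc + gen + base

    def tokens_per_capture_batch(self) -> int:
        # Assumes capture-batch is a number of sequences, each ctx tokens long.
        return self.capture_batch * self.ctx


cfg = RunConfig()
per_layer = cfg.params_per_layer()
total = cfg.n_layers * per_layer
print(f"{per_layer / 1e6:.1f} M per layer, {total / 1e9:.2f} B total")
print(cfg.tokens_per_capture_batch(), "tokens per capture batch")
```

The totals come out to 253.4 M per layer and 7.60 B overall; dropping the bias terms changes the per-layer count by under 0.02%, so the quoted figures hold either way.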

Expected runtime

Per 1M-token chunk (USB cache drive is the bottleneck; teacher load is ~7 s, negligible):

  • corpus build (network): ~1–3 min
  • capture: write-bound at 320 tok/s on the USB drive → ~50 min (the forward pass itself runs at ~558 tok/s @ batch 64; writing 330 GB caps it)
  • train 30 layers: 17k tok/s compute + one-time per-layer read of ~11 GB from USB → ~90 min
  • ~2.3 h/chunk ≈ 10 chunks/day ≈ 10M fresh tokens/day.

Over ~5 days: ~50 chunks ≈ 50M fresh tokens; each layer sees ~50 chunks × 2 passes. The run is unbounded (--n-chunks very large) — stop it when per-layer eval relMSE plateaus. It is detached under nohup and survives logout.
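The runtime estimates follow from the quoted rates. The sketch below redoes the arithmetic, assuming the 17k tok/s training rate applies to every token of every pass of every layer. The gap between the compute-only figure and the ~90 min training estimate would then be the USB reads, which is an inference rather than a measured number.

```python
TOKENS = 1_000_000
LAYERS, PASSES = 30, 2

capture_min = TOKENS / 320 / 60                           # write-bound capture rate
train_compute_min = TOKENS * PASSES * LAYERS / 17_000 / 60
chunk_hours = 2.3                                         # per-chunk figure quoted above
chunks_per_day = 24 / chunk_hours

print(f"capture {capture_min:.0f} min, train compute {train_compute_min:.0f} min")
print(f"{chunks_per_day:.1f} chunks/day, {5 * chunks_per_day:.0f} chunks in 5 days")
```

This gives roughly 52 min of capture, 59 min of training compute, 10.4 chunks per day, and about 52 chunks over five days, matching the rounded figures in this section.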

How to read progress

  • Run log: compression/gemma/student/run.log (tail -f). Per-chunk banners, capture throughput, and per-layer init_rel -> train_rel | EVAL_rel.
  • Metrics: compression/gemma/student/progress.jsonl — one JSON row per (chunk_idx, layer): init_train_rel, train_rel, eval_rel, tokens_seen. Track eval_rel per layer falling over chunks (held-out fidelity).
  • State: student/state.json holds next_chunk + per-source doc cursors (resume point).
  • Checkpoints: student/layerNN.pt = {model, opt, tokens_seen, cfg}.
  • Blocker signal: /tmp/NEEDS_MIKEY.txt is written and the loop stops only on a hard failure (capture failed twice, training failed, or cache drive low on space).
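One way to turn progress.jsonl into a stop decision is to group eval_rel by layer and test for a plateau. The sketch below assumes the JSON keys are spelled exactly as listed above (`chunk_idx`, `layer`, `eval_rel`). The plateau rule (less than 1% relative improvement over the last three chunks) is a hypothetical criterion, and the sample rows are synthetic, not results from the run.

```python
import json
from collections import defaultdict
from typing import Dict, Iterable, List


def eval_curves(lines: Iterable[str]) -> Dict[int, List[float]]:
    """Group eval_rel by layer, ordered by chunk index."""
    rows = [json.loads(line) for line in lines if line.strip()]
    rows.sort(key=lambda row: row["chunk_idx"])
    curves: Dict[int, List[float]] = defaultdict(list)
    for row in rows:
        curves[row["layer"]].append(row["eval_rel"])
    return dict(curves)


def plateaued(curve: List[float], window: int = 3, tol: float = 0.01) -> bool:
    """True when eval_rel improved by less than `tol` (relative) over the last `window` chunks."""
    if len(curve) <= window:
        return False
    old, new = curve[-window - 1], curve[-1]
    return (old - new) / old < tol


# Synthetic rows for two layers over five chunks (illustrative values only).
synthetic = {0: [0.50, 0.40, 0.399, 0.398, 0.3975],
             1: [0.60, 0.50, 0.42, 0.36, 0.31]}
lines = [json.dumps({"chunk_idx": k, "layer": layer, "eval_rel": v})
         for layer, vals in synthetic.items() for k, v in enumerate(vals)]

curves = eval_curves(lines)
status = {layer: plateaued(curve) for layer, curve in curves.items()}
print(status)
```

Against the real run, the same functions could be fed an open handle on compression/gemma/student/progress.jsonl, since iterating a text file yields one line per row.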

Small-validation results (must-pass gate, before launch)

Trained the 3 representative layers (0 / 15 / 29) on real captured gemma_cap_val activations (24.6k train tokens, held-out 8.2k), fp32 / lr 2e-4 — held-out relMSE:

layer   init (Yhat=X)   trained (held-out)
0       0.787           0.262
15      0.347           0.123
29      0.834           0.108

Eval ≈ train (healthy generalization). Full-size (253 M/layer) layer: peak VRAM 4.0 GiB. End-to-end loop (fresh capture → train → eval → delete → state) validated on a tiny 3-layer / 40k-token chunk before the full launch.

Final assembly (LATER, not this run)

This run trains the 30 experts. Assembling the deployable student — inherit Gemma's attention + embeddings, 4-bit-quantize them, swap each layer's FF/MoE block for its trained GemmaHyperExpert, and measure end-to-end ppl/generation — is a separate step once per-layer eval relMSE has converged. (Cached activations are stored flattened, so the chunked/sentence-boundary DECODE-SPEED variants are an assembly-time concern.)

Files

  • gemma_hyper.py: GemmaHyperExpert (per-token generated low-rank, residual, warm start).
  • gemma_corpus.py — diverse non-repeating streamed corpus + eval holdout + cursors.
  • capture_gemma.py — Phase-1 capture (now accepts --token-file for staged chunks).
  • train_layers.py — layer-local relMSE distillation, continual per-layer checkpoints.
  • staged_run.py — the capture→train→delete orchestrator (resumable, detached).