# Phase 2: Gemma-4-26B-A4B → ~7.6B hypernetwork student (staged capture→train)

**Box:** pop (RTX 5070 Ti, 16 GB, sm_120 Blackwell; 31 GB RAM), torch 2.11+cu128,
transformers 5.12. Date: 2026-06-18.

**Goal:** compress the 26B MoE teacher's text decoder into a ~7.6B **hypernetwork
student** that runs on the 16 GB card. Each of the 30 decoder layers' feed-forward/MoE
BLOCK is replaced with a per-token GENERATED low-rank expert, trained by
**layer-local feature distillation** against the teacher's cached block I/O.

This file documents the architecture, the staged loop, the small-validation numbers,
the full-run config, expected runtime, and how to read progress.

## What the student is (per layer)

`GemmaHyperExpert` (gemma_hyper.py) is the Gemma adaptation of the StarCoder testbed's
validated `HyperExpert`, which uses the winning shape: a large hypernetwork with a small
generated rank.

```
z = gelu(Enc(x))                  Enc: Linear(d=2816 -> c=9856)
U = (G_U z).view(r=4, d)          G_U: Linear(c -> r*d)      # generated per token
V = (G_V z).view(d, r=4)          G_V: Linear(c -> d*r)      # generated per token
base(x) = b_proj(gelu(b_fc(x)))   widths d -> b=640 -> d     # shared low-width FFN
Yhat = x + base(x) + V @ gelu(U @ x)                         # RESIDUAL block output
```
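
Below is a minimal pure-Python sketch of the per-token forward pass above. It uses toy sizes instead of d=2816 / c=9856 / r=4 / b=640, so the shapes and the exact identity warm start can be checked by hand. It is not the code in gemma_hyper.py. The helper names (`linear`, `view`, `init_params`, `hyper_expert_forward`), the row-major `view` convention, and the Gaussian init scale are all assumptions made for illustration.

```python
import math
import random


def gelu(v):
    # exact (erf-based) GELU, applied elementwise to a list
    return [0.5 * t * (1.0 + math.erf(t / math.sqrt(2.0))) for t in v]


def linear(W, b, x):
    # W is (out, in) as a list of rows; returns W @ x + b
    return [sum(w * xi for w, xi in zip(row, x)) + bo for row, bo in zip(W, b)]


def matvec(M, v):
    return [sum(m * t for m, t in zip(row, v)) for row in M]


def view(flat, rows, cols):
    # row-major reshape of a flat list, like tensor.view(rows, cols)
    return [flat[i * cols:(i + 1) * cols] for i in range(rows)]


def zeros(n):
    return [0.0] * n


def rand_mat(rows, cols, rng, scale=0.1):
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


def init_params(d, c, r, b, seed=0):
    rng = random.Random(seed)
    return {
        "enc_w": rand_mat(c, d, rng), "enc_b": zeros(c),
        "g_u_w": rand_mat(r * d, c, rng), "g_u_b": zeros(r * d),
        # zero-init G_V and b_proj: V = 0 and base(x) = 0, so Yhat == x at step 0
        "g_v_w": [zeros(c) for _ in range(d * r)], "g_v_b": zeros(d * r),
        "b_fc_w": rand_mat(b, d, rng), "b_fc_b": zeros(b),
        "b_proj_w": [zeros(b) for _ in range(d)], "b_proj_b": zeros(d),
    }


def hyper_expert_forward(p, x, d, r):
    z = gelu(linear(p["enc_w"], p["enc_b"], x))               # (c,)
    U = view(linear(p["g_u_w"], p["g_u_b"], z), r, d)         # (r, d), generated for this token
    V = view(linear(p["g_v_w"], p["g_v_b"], z), d, r)         # (d, r), generated for this token
    base = linear(p["b_proj_w"], p["b_proj_b"],
                  gelu(linear(p["b_fc_w"], p["b_fc_b"], x)))  # shared low-width FFN
    low_rank = matvec(V, gelu(matvec(U, x)))                  # V @ gelu(U @ x), shape (d,)
    return [xi + bi + li for xi, bi, li in zip(x, base, low_rank)]  # residual output


d, c, r, b = 6, 12, 2, 4
params = init_params(d, c, r, b)
x = [0.5, -1.0, 2.0, 0.0, 3.0, -0.25]
print(hyper_expert_forward(params, x, d, r) == x)  # True: identity warm start
```

With `g_v` and `b_proj` at zero, both added terms are exactly zero, so the output equals `x` bit-for-bit. Once `G_V` is trained, the low-rank update becomes non-zero and input-dependent.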

- **Residual parameterization:** the teacher block computes `Y = X + D`, so the student
  models the delta on top of the carried residual. It is trained to minimize
  `relMSE(Yhat, Y) = mean((Yhat - Y)^2) / mean(Y^2)`.
- **Exact identity warm start:** `b_proj` and `g_v` (the `G_V` generator) are
  zero-initialized. At step 0 this gives `V = 0` and `base(x) = 0`, so `Yhat == X`
  exactly (verified). The initial relMSE is then `||D||²/||Y||²`, a finite, meaningful
  starting point that training drives down.
- **Size:** c=9856, r=4, b=640 → **253.4 M params/layer × 30 = 7.60 B** (≈ teacher/3.5).
- **fp32 params at train time.** The testbed kept all 30 experts resident and so used
  bf16. Here we train ONE layer at a time, which makes fp32 both affordable AND
  necessary. bf16 diverged on outlier-heavy layers (e.g. layer 29: Y absmax ≈ 444 vs
  rms ≈ 3), while fp32 + lr 2e-4 converges on every layer. Peak VRAM for a full-size
  layer: **4.0 GiB**.
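
The size figure and the warm-start relMSE above can be checked with plain arithmetic. This sketch assumes every component (Enc, G_U, G_V, b_fc, b_proj) is a `Linear` with bias. Without biases, the count drops by 35,840 params and still rounds to 253.4 M. The fp32 line covers weights only; the source does not say how the rest of the 4.0 GiB peak splits between gradients, Adafactor state, and activations.

```python
import math


def hyper_expert_params(d, c, r, b, bias=True):
    """Parameter count of one expert: Enc, G_U, G_V, b_fc, b_proj, each a Linear."""
    def lin(n_in, n_out):
        return n_in * n_out + (n_out if bias else 0)
    return lin(d, c) + lin(c, r * d) + lin(c, d * r) + lin(d, b) + lin(b, d)


def rel_mse(yhat, y):
    """relMSE(Yhat, Y) = mean((Yhat - Y)^2) / mean(Y^2) over flattened values."""
    num = sum((a - t) ** 2 for a, t in zip(yhat, y)) / len(y)
    den = sum(t * t for t in y) / len(y)
    return num / den


per_layer = hyper_expert_params(d=2816, c=9856, r=4, b=640)
print(f"{per_layer / 1e6:.1f} M/layer x 30 = {30 * per_layer / 1e9:.2f} B")  # 253.4 M/layer x 30 = 7.60 B
print(f"fp32 weights, one layer: {per_layer * 4 / 2**30:.2f} GiB")           # 0.94 GiB

# Identity warm start: Yhat = X and Y = X + D, so relMSE = mean(D^2) / mean(Y^2)
X = [1.0, 2.0, -0.5]
D = [0.5, -1.0, 0.25]
Y = [xi + di for xi, di in zip(X, D)]
init_rel = rel_mse(X, Y)
print(math.isclose(init_rel, sum(v * v for v in D) / sum(v * v for v in Y)))  # True
```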

## The staged loop (staged_run.py)

Over a large, diverse, NON-REPEATING corpus, repeat the following steps, keeping only
ONE chunk on disk at a time:

1. **CAPTURE** a chunk of FRESH tokens. Stream the bf16 teacher forward-only on the
   16 GB card (the Phase-1 method: accelerate NVMe layer-streaming), hooking each
   layer's FF/MoE block input `X` (= the pre_feedforward_layernorm input) and output
   `Y` (= X + D). Cache all 30 layers' (X, Y) pairs in bf16 to
   `/mnt/data/cache/gemma_cap/chunkNNN/`.
2. **TRAIN** (train_layers.py). For each of the 30 layers, load its cached X→Y pairs and
   train its `GemmaHyperExpert` for `--passes` passes. The loss is relMSE and the
   optimizer is Adafactor, because bitsandbytes optimizers segfault on this Blackwell
   GPU. Per-layer **(model + optimizer) checkpoints persist** in `student/layerNN.pt`
   and are reloaded for each chunk, giving **continual** distillation.
3. **DELETE** the chunk cache, advance the per-source doc cursors, and capture the next
   chunk.

A fixed **held-out eval chunk** (`/mnt/data/cache/gemma_cap_eval`) tracks per-layer
relMSE over the whole run. It is captured ONCE from a reserved doc range and never
trained on.
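
The three steps can be sketched as a small resumable driver. This is not staged_run.py. `run_staged` and the `capture` / `train_layer` / `evaluate` callables it takes are hypothetical stand-ins, and the `chunkNNN` directory naming simply mirrors the cache path above.

```python
import json
import shutil
from pathlib import Path


def run_staged(cache_root, student_dir, n_chunks, n_layers,
               capture, train_layer, evaluate):
    """Hypothetical capture -> train -> delete driver, resumable via state.json.

    capture(chunk_dir, state): writes one chunk of cached (X, Y) and advances
        state["cursors"] (stand-in for capture_gemma.py).
    train_layer(layer, chunk_dir): continues that layer's persisted checkpoint
        on the chunk (stand-in for train_layers.py).
    evaluate(layer): returns held-out relMSE on the fixed eval chunk.
    """
    state_path = Path(student_dir) / "state.json"
    state_path.parent.mkdir(parents=True, exist_ok=True)
    if state_path.exists():
        state = json.loads(state_path.read_text())     # resume after kill/relaunch
    else:
        state = {"next_chunk": 0, "cursors": {}}
    while state["next_chunk"] < n_chunks:
        idx = state["next_chunk"]
        chunk_dir = Path(cache_root) / f"chunk{idx:03d}"
        capture(chunk_dir, state)                       # 1. CAPTURE fresh tokens
        for layer in range(n_layers):                   # 2. TRAIN one layer at a time
            train_layer(layer, chunk_dir)
            eval_rel = evaluate(layer)
            print(f"chunk {idx} layer {layer}: eval_rel {eval_rel:.4f}")
        shutil.rmtree(chunk_dir)                        # 3. DELETE: one chunk on disk
        state["next_chunk"] = idx + 1
        tmp = state_path.with_name(state_path.name + ".tmp")
        tmp.write_text(json.dumps(state))               # write, then rename over the old
        tmp.replace(state_path)                         # file, so it is never half-written
    return state
```

This sketch commits `state.json` only after a chunk is fully trained and deleted, so a kill mid-chunk re-captures that same chunk on relaunch. The source does not say how staged_run.py handles that case.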

**Corpus** (gemma_corpus.py): codeparrot-clean (code) + fineweb-edu (web), streamed
with per-source doc cursors that only advance, so every training token is fresh and
non-repeating. Reserved docs [0:3000) are the eval holdout. The corpus falls back to
local wikitext-103 if HF streaming comes up short, so a capture never stalls on the
network. Cursors and the next-chunk index persist in `student/state.json`, so
**kill/relaunch resumes** both the non-repeating stream and the continual student.
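
The cursor discipline can be sketched as follows, assuming each source can be opened at a document index. `open_at(start)`, `take_training_docs`, and `eval_docs` are hypothetical names; this is not the HF `datasets` API and not gemma_corpus.py. Fallback documents are not cursor-tracked in this sketch.

```python
import itertools

EVAL_DOCS = 3000  # docs [0:3000) of every source are reserved for the eval holdout


def take_training_docs(sources, cursors, n_docs, fallback=None):
    """Round-robin over sources, reading forward from each cursor; cursors only advance.

    sources: name -> open_at(start), a hypothetical callable returning an iterable
        of documents beginning at index `start` (stand-in for a streamed dataset).
    fallback: optional iterator of local documents, used if the streams come up short.
    Returns (docs, new_cursors); the input `cursors` dict is not modified.
    """
    # never start inside the holdout, even on a fresh run with no saved cursors
    cursors = {name: max(cursors.get(name, EVAL_DOCS), EVAL_DOCS) for name in sources}
    streams = {name: iter(open_at(cursors[name])) for name, open_at in sources.items()}
    active = list(sources)
    docs = []
    while len(docs) < n_docs and active:
        for name in list(active):
            if len(docs) >= n_docs:
                break
            doc = next(streams[name], None)
            if doc is None:              # this source ran dry
                active.remove(name)
                continue
            docs.append(doc)
            cursors[name] += 1           # advance past every doc handed out
    if len(docs) < n_docs and fallback is not None:
        docs.extend(itertools.islice(fallback, n_docs - len(docs)))
    return docs, cursors


def eval_docs(sources, per_source=EVAL_DOCS):
    """The fixed holdout: the first `per_source` docs of each source."""
    return [doc for open_at in sources.values()
            for doc in itertools.islice(open_at(0), per_source)]
```

Because the cursor is clamped to at least 3000, training can never read a holdout document, and persisting the returned cursors is enough to resume the stream without repeats.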

## Why this fits 16 GB (the Phase-1 storage blocker is gone)

Phase 1's blocker was materializing activations for all 30 layers × 50M tokens
(6–15 TB). The staged loop keeps only ONE chunk on disk: a 1M-token chunk is ~330 GB on
the 6.3 TB USB cache drive. Capture writes it, training reads it once, and then it is
deleted. Training never loads the teacher (Y is cached) and holds only one layer's
expert and data at a time.
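
Assuming the cache stores both X and Y for every layer at d=2816 in bf16 (2 bytes/value), the chunk sizes quoted in this document follow directly: 338 GB (315 GiB) per chunk, consistent with "~330 GB", and ~11 GB per layer. This is back-of-envelope arithmetic. It is not capture_gemma.py's actual file format, and any per-file headers or metadata are ignored.

```python
def cache_bytes(tokens, layers=30, d=2816, tensors=2, bytes_per_value=2):
    """Disk bytes for caching block input X and output Y (bf16) of `layers` layers."""
    return tokens * layers * tensors * d * bytes_per_value


chunk = cache_bytes(1_000_000)
one_layer = cache_bytes(1_000_000, layers=1)
print(f"1M-token chunk: {chunk / 1e9:.0f} GB ({chunk / 2**30:.0f} GiB)")  # 338 GB (315 GiB)
print(f"one layer's X,Y: {one_layer / 1e9:.1f} GB")                     # 11.3 GB
print(f"share of the 6.3 TB drive: {chunk / 6.3e12:.1%}")                # 5.4%
```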

## Full-run config (LAUNCHED)

```
tokens-per-chunk  1,000,000   passes 2   mb 8192   lr 2e-4   param-dtype fp32
student           c=9856 r=4 b=640 (7.60 B)    capture-batch 64   ctx 256
n-layers          30 (capture all)             layers all (train all 30)
cache-root        /mnt/data/cache/gemma_cap    eval-dir /mnt/data/cache/gemma_cap_eval
student-dir       compression/gemma/student    (checkpoints, progress.jsonl, state.json)
```
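
The config implies a few derived per-chunk quantities. This sketch assumes that `mb` is the number of tokens per optimizer step, that one capture forward is `capture-batch` sequences of `ctx` tokens, and that a final partial minibatch or batch is kept rather than dropped. The document does not say which.

```python
import math

tokens_per_chunk = 1_000_000
passes, mb = 2, 8192             # assumption: mb = tokens per optimizer step
capture_batch, ctx = 64, 256     # assumption: one capture forward = 64 sequences x 256 tokens
n_layers = 30

steps_per_pass = math.ceil(tokens_per_chunk / mb)        # last partial minibatch kept
steps_per_layer = passes * steps_per_pass
capture_forwards = math.ceil(tokens_per_chunk / (capture_batch * ctx))

print(steps_per_layer, n_layers * steps_per_layer, capture_forwards)  # 246 7380 62
```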

## Expected runtime

Per 1M-token chunk. The USB cache drive is the bottleneck; loading the teacher takes
~7 s, which is negligible.

- corpus build (network): ~1–3 min
- capture: write-bound at ~320 tok/s on the USB drive → **~50 min**. The forward itself
  runs at ~558 tok/s @ batch 64; writing 330 GB is what caps it.
- train 30 layers: ~17k tok/s compute + a one-time per-layer read of ~11 GB from USB
  → **~90 min**
- total: **~2.3–2.4 h/chunk ≈ 10 chunks/day ≈ 10M fresh tokens/day.**

Over ~5 days that is ~50 chunks ≈ 50M fresh tokens, and each layer sees ~50 chunks × 2
passes. The run is unbounded (`--n-chunks` very large), so stop it when per-layer eval
relMSE plateaus. It runs detached under nohup and survives logout.
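
The per-chunk budget can be re-derived from the quoted rates. The implied USB read rate is an inference (whatever remains of the 90 min training budget after compute), not a measured number. The 11.264 GB per-layer read assumes X and Y are both cached in bf16 at d=2816.

```python
tokens = 1_000_000
corpus_min = 2                                        # midpoint of the quoted 1-3 min
capture_min = tokens / 320 / 60                       # write-bound capture at ~320 tok/s
train_min = 90                                        # quoted training total, USB reads included
train_compute_min = 30 * 2 * tokens / 17_000 / 60     # 30 layers x 2 passes at ~17k tok/s
layer_read_bytes = 11.264e9                           # X and Y for one layer, bf16, d=2816

chunk_h = (corpus_min + capture_min + train_min) / 60
read_s = (train_min - train_compute_min) * 60         # time left over for the 30 per-layer reads
implied_read_mb_s = 30 * layer_read_bytes / read_s / 1e6

print(f"capture {capture_min:.0f} min, train compute {train_compute_min:.0f} min")
print(f"{chunk_h:.1f} h/chunk, {24 / chunk_h:.1f} chunks/day")
print(f"implied USB read rate ~{implied_read_mb_s:.0f} MB/s")
```

The sum comes to roughly 2.4 h per chunk, or about 10 chunks/day, and about a third of the training time would be spent reading from the USB drive.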

## How to read progress

- **Run log:** `compression/gemma/student/run.log` (`tail -f`). It shows per-chunk
  banners, capture throughput, and per-layer `init_rel -> train_rel | EVAL_rel`.
- **Metrics:** `compression/gemma/student/progress.jsonl` has one JSON row per
  (chunk_idx, layer) with `init_train_rel`, `train_rel`, `eval_rel`, `tokens_seen`.
  Track `eval_rel` per layer falling over chunks (held-out fidelity).
- **State:** `student/state.json` holds `next_chunk` and the per-source doc cursors
  (the resume point).
- **Checkpoints:** `student/layerNN.pt` = `{model, opt, tokens_seen, cfg}`.
- **Blocker signal:** the loop writes `/tmp/NEEDS_MIKEY.txt` and stops only on a hard
  failure (capture failed twice, training failed, or the cache drive is low on space).
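
A small reader for `progress.jsonl` makes "stop when per-layer eval relMSE plateaus" concrete. The field names `chunk_idx`, `layer`, and `eval_rel` come from the list above. The plateau rule (window and threshold) is an assumption, since the document does not define one.

```python
import json
from collections import defaultdict


def eval_curves(lines):
    """layer -> [(chunk_idx, eval_rel), ...] sorted by chunk, from progress.jsonl lines."""
    curves = defaultdict(list)
    for line in lines:
        if line.strip():
            row = json.loads(line)
            curves[row["layer"]].append((row["chunk_idx"], row["eval_rel"]))
    return {layer: sorted(points) for layer, points in curves.items()}


def plateaued(points, window=3, min_rel_gain=0.01):
    """True if eval_rel fell by less than `min_rel_gain` (relative) over the last `window` chunks.

    window and min_rel_gain are illustrative defaults, not values used by the run.
    """
    if len(points) <= window:
        return False
    old, new = points[-1 - window][1], points[-1][1]
    return (old - new) / old < min_rel_gain


# Example use against the live file:
# with open("compression/gemma/student/progress.jsonl") as f:
#     curves = eval_curves(f)
# still_improving = [layer for layer, pts in curves.items() if not plateaued(pts)]
```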

## Small-validation results (must-pass gate, before launch)

We trained the 3 representative layers (0 / 15 / 29) on real captured `gemma_cap_val`
activations (24.6k train tokens, 8.2k held out) with fp32 / lr 2e-4. **Held-out relMSE:**

| layer | init (Yhat = X) | trained (held-out) | note                                     |
|------:|----------------:|-------------------:|------------------------------------------|
| 0     | 0.787           | **0.262**          |                                          |
| 15    | 0.347           | **0.123**          |                                          |
| 29    | 0.834           | **0.108**          | bf16 diverged to 2.6 here; fp32 fixed it |

Eval ≈ train (healthy generalization). Full-size (253 M/layer) layer: peak VRAM 4.0 GiB.
The end-to-end loop (fresh capture → train → eval → delete → state) was validated on a
tiny 3-layer / 40k-token chunk before the full launch.
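
The init column is relMSE at `Yhat = X`, which equals `mean(D^2) / mean(Y^2)`. Dividing a trained value by its init value therefore cancels the shared `mean(Y^2)`. The result is the student's error relative to the block's own delta, `mean((Yhat - Y)^2) / mean(D^2)`. This is a useful second view, because relMSE against Y also counts the carried residual X, which the student reproduces exactly. The sketch assumes both columns are measured on the same held-out tokens.

```python
# held-out relMSE from the table above: layer -> (init with Yhat = X, trained)
results = {0: (0.787, 0.262), 15: (0.347, 0.123), 29: (0.834, 0.108)}

for layer, (init, trained) in results.items():
    # trained / init = mean((Yhat - Y)^2) / mean(D^2): error relative to the delta D
    ratio = trained / init
    print(f"layer {layer:2d}: delta-relative error {ratio:.2f} "
          f"({1 - ratio:.0%} of the delta's energy explained)")
```

By this measure, layer 29 is the best fit (about 0.13), while layers 0 and 15 sit near a third of their delta energy unexplained.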

## Final assembly (LATER, not this run)

This run trains the 30 experts. Assembling the deployable student is a separate step,
done once per-layer eval relMSE has converged. It means inheriting Gemma's attention +
embeddings, 4-bit-quantizing them, swapping each layer's FF/MoE block for its trained
`GemmaHyperExpert`, and measuring end-to-end ppl/generation. Cached activations are
stored flattened to per-token rows, so the chunked/sentence-boundary DECODE-SPEED
variants are an assembly-time concern.

## Files

- `gemma_hyper.py`: `GemmaHyperExpert` (per-token generated low-rank, residual, warm start).
- `gemma_corpus.py`: diverse non-repeating streamed corpus + eval holdout + cursors.
- `capture_gemma.py`: Phase-1 capture (now accepts `--token-file` for staged chunks).
- `train_layers.py`: layer-local relMSE distillation, continual per-layer checkpoints.
- `staged_run.py`: the capture→train→delete orchestrator (resumable, detached).