| """Tier-2 driver: minimal RRT recovery run on real Laguna (or --tiny plumbing). |
| |
| Flow (see scratch_rrt.md): |
| 1. load model + tokenizer (untied = teacher) |
| 2. build a narrow corpus (default: a Python code slice) -> fixed-length blocks |
3. precompute teacher top-k logits over the corpus (held in RAM; the same model is then
tied in place, so the teacher never runs inside the training loop -- only the tied
student trains; skipped with --reference)
| 4. eval baseline perplexity (B) |
| 5. tie a few adjacent mid-stack MoE pairs; eval tied-at-init (T, degraded) |
| 6. param-efficient KD (LM + top-k forward-KL) on the LoRA adapters |
| 7. eval final (R); print the recovery curve B -> T -> R |
| |
| Real run (GPU box): |
| uv run python scripts/rrt_run.py --model poolside/Laguna-XS.2 --device cuda \ |
| --dtype bfloat16 --tie-layers 18,19,20,21 --rank 16 --tokens 50_000_000 |
| |
| Local plumbing check (CPU, tiny random model, synthetic data, no network): |
| uv run python scripts/rrt_run.py --tiny --tokens 20000 --steps 30 |
| |
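The --tiny path runs the whole train/eval pipeline on synthetic data (it skips the HF
dataset loader, the tokenizer and gradient checkpointing -- use --dry-run to check the
loader); metrics are meaningless but it proves the GPU run is turn-key. Parts marked
TODO(scale) are fine for a small run but should stream to disk for the full 50M-token run.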
| """ |
|
|
| from __future__ import annotations |
|
|
import argparse
import json
import math
import os
import sys
import time
from pathlib import Path
|
|
| |
| |
# Set before `import torch` (below) so the CUDA caching-allocator config is in
# place before torch is loaded; setdefault keeps any value the caller exported.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")


# Put the directory above this script's folder (the repo root, given the
# scripts/rrt_run.py layout) first on sys.path so the project-local modules
# imported below (looped_laguna, rrt_laguna) resolve when run as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
def save_json(path: Path, obj: dict) -> None:
    """Atomically write ``obj`` as indented JSON to ``path``.

    The JSON goes to a sibling ``<name>.tmp`` file, which is flushed and
    fsynced and then moved over ``path`` with ``os.replace``. Readers therefore
    see either the previous file or the complete new one, never a half-file.
    Parent directories are created as needed.

    If serialization or the write fails (e.g. ``obj`` holds a value ``json``
    cannot encode), the temporary file is removed and the original exception
    propagates; ``path`` itself is left untouched.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(obj, f, indent=2)
            f.flush()
            # Durably persist the bytes before the rename makes them visible.
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Don't leave a stale, partially written .tmp next to the results file.
        tmp.unlink(missing_ok=True)
        raise
|
|
|
|
def save_lora(path: Path, model) -> None:
    """Save only the trainable parameters of ``model`` as a small checkpoint.

    Every parameter with ``requires_grad=True`` (the LoRA adapters plus
    anything else left unfrozen) is detached, moved to CPU and stored as a
    ``{name: tensor}`` dict via ``torch.save``.

    The write is crash-safe: data goes to ``<name>.tmp``, is flushed and
    fsynced, and is then moved over ``path`` with ``os.replace``. On failure
    the temporary file is removed and the exception propagates.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    state = {n: p.detach().cpu() for n, p in model.named_parameters() if p.requires_grad}
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, "wb") as f:
            torch.save(state, f)
            f.flush()
            # Without fsync a crash right after the rename can leave an empty
            # or truncated checkpoint under the final name.
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
|
|
| from looped_laguna import build_tiny_model, load_model_and_tokenizer |
| from rrt_laguna import ( |
| TieConfig, |
| add_lora_adapters, |
| adjacent_pairs, |
| parameter_report, |
| set_param_efficient, |
| tie_model, |
| trainable_parameters, |
| untie, |
| ) |
|
|
# Tokenizer directory shipped next to the scripts folder: <repo root>/laguna_src.
DEFAULT_TOKENIZER = os.fspath(Path(__file__).resolve().parent.parent.joinpath("laguna_src"))
|
|
|
|
| |
| |
| |
def _open_code_stream(dataset: str, lang: str):
    """Open a streamed, language-filtered code dataset.

    Returns ``(iterable, text_field)``. ``text_field`` is the row key holding
    the source text. It is None for datasets not listed below; those are
    opened unfiltered (``split="train"``, streaming), and callers probe
    "content" / "code" / "text".

    Verified loader calls + per-dataset language-filter recipes (see scratch_rrt.md):
      - bigcode/the-stack-smol : data_dir=f"data/{lang}" (lowercase), field "content".
                                 GATED (HF login + accept terms); smallest (~10k rows).
      - codeparrot/github-code (and -clean) : languages=[Lang] in the dataset's own
                                 spelling ("Python", "JavaScript", "C++", ...), field
                                 "code". No gating; needs trust_remote_code; large
                                 (stream + cap).
      - bigcode/starcoderdata  : data_dir=lang (lowercase), field "content". GATED, large.
    """
    from datasets import load_dataset

    if dataset == "bigcode/the-stack-smol":
        ds = load_dataset(dataset, data_dir=f"data/{lang.lower()}", split="train", streaming=True)
        return ds, "content"
    if dataset in ("codeparrot/github-code", "codeparrot/github-code-clean"):
        # The github-code filter matches exact language names. Several are not
        # plain capitalizations ("JavaScript", "PHP", "GO", "Visual Basic"), so
        # str.capitalize() alone produces names the filter does not recognise.
        # Resolve the spelling case-insensitively; unknown names keep the
        # previous capitalize() behaviour.
        known = (
            "Assembly", "Batchfile", "C", "C#", "C++", "CMake", "CSS", "Dockerfile",
            "FORTRAN", "GO", "HTML", "Haskell", "Java", "JavaScript", "Julia", "Lua",
            "Makefile", "Markdown", "PHP", "Perl", "PowerShell", "Python", "Ruby",
            "Rust", "SQL", "Scala", "Shell", "TeX", "TypeScript", "Visual Basic",
        )
        canonical = {name.lower(): name for name in known}
        gh_lang = canonical.get(lang.lower(), lang.capitalize())
        ds = load_dataset(dataset, split="train", streaming=True,
                          languages=[gh_lang], trust_remote_code=True)
        return ds, "code"
    if dataset == "bigcode/starcoderdata":
        ds = load_dataset(dataset, data_dir=lang.lower(), split="train", streaming=True)
        return ds, "content"
    # Unknown dataset: no language filter and no known text field.
    return load_dataset(dataset, split="train", streaming=True), None
|
|
|
|
def build_blocks(tok, *, tiny: bool, vocab: int, seq_len: int, n_tokens: int, dataset: str, lang: str):
    """Return an ``[N, seq_len]`` LongTensor of token blocks (language-filtered code).

    Args:
        tok: tokenizer called as ``tok(text).input_ids``; unused when ``tiny``.
        tiny: if True, return exactly ``max(1, n_tokens // seq_len)`` blocks of
            random ids in ``[3, vocab)`` from a fixed seed (deterministic).
        vocab: vocabulary size for the synthetic ids.
        seq_len: tokens per block.
        n_tokens: token budget. ``N`` is at most ``max(1, n_tokens // seq_len)``.
        dataset, lang: forwarded to ``_open_code_stream``.

    Real data: rows are tokenized one at a time and concatenated. No separator
    is inserted here; whether ``tok`` adds special tokens per call is up to
    the tokenizer. The result is cut into full blocks: anything past the
    budget and the partial tail block are dropped.

    Raises:
        ValueError: if the stream yields fewer than ``seq_len`` tokens, i.e. not
            even one full block.
    """
    n_blocks = max(1, n_tokens // seq_len)
    if tiny:
        gen = torch.Generator().manual_seed(0)
        return torch.randint(3, vocab, (n_blocks, seq_len), generator=gen)

    budget = n_blocks * seq_len
    stream, field = _open_code_stream(dataset, lang)
    ids: list[int] = []
    for row in stream:
        text = row.get(field) if field else (row.get("content") or row.get("code") or row.get("text"))
        if not text:
            continue
        ids.extend(tok(text).input_ids)
        if len(ids) >= budget:
            break
    if len(ids) < budget:
        print(f"WARNING: only {len(ids):,} tokens available (< requested {budget:,}); "
              f"dataset '{dataset}' may be smaller than --tokens.")
    # The last document usually overshoots the budget. Cap at the budget so the
    # corpus matches --tokens (as the tiny path does), then keep whole blocks only.
    usable = (min(len(ids), budget) // seq_len) * seq_len
    if usable == 0:
        raise ValueError(f"dataset '{dataset}' (lang={lang!r}) yielded {len(ids):,} tokens, "
                         f"fewer than one block of {seq_len}.")
    return torch.tensor(ids[:usable], dtype=torch.long).view(-1, seq_len)
|
|
|
|
def dry_run_data(tok, *, dataset: str, lang: str, n_rows: int = 5) -> None:
    """Sanity-check the dataset loader without loading the model.

    Pulls up to ``n_rows`` rows from the configured stream. For each row it
    prints the text field used, its token count, the row's first keys and a
    short snippet, then prints a one-line summary. Use it to validate
    gating / config / field names before starting the full teacher precompute.
    """
    print(f"DRY RUN: dataset={dataset!r} lang={lang!r}")
    stream, field = _open_code_stream(dataset, lang)
    total_tok = 0
    rows_seen = 0  # can be < n_rows when the stream is short
    for i, row in enumerate(stream):
        if i >= n_rows:
            break
        rows_seen += 1
        text = row.get(field) if field else (row.get("content") or row.get("code") or row.get("text"))
        used_field = field or next((f for f in ("content", "code", "text") if row.get(f)), "?")
        n = len(tok(text).input_ids) if text else 0
        total_tok += n
        snippet = text[:80].replace("\n", "\\n") if text else "<empty>"
        print(f" row {i}: field={used_field!r} tokens={n:<6} keys={list(row.keys())[:6]} | {snippet}")
    if rows_seen == 0:
        print("ERROR: the stream yielded no rows — check dataset/lang.")
    elif total_tok:
        # Report/average over the rows actually read, not the requested n_rows.
        print(f"OK: {rows_seen} rows ~ {total_tok:,} tokens ({total_tok // rows_seen:,}/row avg). "
              f"Loader works — safe to run the full job.")
    else:
        print("ERROR: no text found in the first rows — check dataset/lang/field.")
|
|
|
|
| |
| |
| |
@torch.no_grad()
def precompute_teacher_topk(model, blocks, *, k: int, batch: int, device):
    """Run ``model`` over ``blocks`` and keep only its top-k logits per position.

    Args:
        model: called as ``model(input_ids=..., use_cache=False)``. Its output's
            ``.logits`` is assumed to be ``[B, T, vocab]`` (TODO confirm for
            non-causal-LM models).
        blocks: ``[N, T]`` LongTensor of token ids.
        k: logits kept per position.
        batch: blocks per forward pass.
        device: device the forward passes run on.

    Returns:
        ``(vals, idx)`` on CPU. ``vals`` holds the float32 top-k logit values
        and ``idx`` the int64 vocab indices, each shaped ``[N, T, k]``. For
        empty ``blocks`` both have leading dim 0.

    TODO(scale): results are held in host RAM; stream to disk for 50M+ tokens.
    """
    n = len(blocks)
    if n == 0:
        # Nothing to run; return correctly shaped empties instead of crashing.
        empty_shape = (0, *blocks.shape[1:], k)
        return torch.empty(empty_shape, dtype=torch.float32), torch.empty(empty_shape, dtype=torch.long)

    vals = idxs = None
    t0 = time.time()
    every = max(1, (n // batch) // 20)
    for bi, start in enumerate(range(0, n, batch)):
        stop = min(start + batch, n)
        logits = model(input_ids=blocks[start:stop].to(device), use_cache=False).logits
        # topk in the model's own dtype, then upcast only the [B, T, k] result.
        # This avoids materialising a full-vocab float32 copy of the logits
        # (the upcast is exact, so the selected values are unchanged).
        tv, ti = torch.topk(logits, k, dim=-1)
        if vals is None:
            # Preallocate once; filling slices avoids a final torch.cat, which
            # would briefly hold two full copies of the cache in RAM.
            vals = torch.empty((n, *tv.shape[1:]), dtype=torch.float32)
            idxs = torch.empty((n, *ti.shape[1:]), dtype=ti.dtype)
        vals[start:stop] = tv.float().cpu()
        idxs[start:stop] = ti.cpu()
        if bi % every == 0:
            el = time.time() - t0
            eta = el / stop * (n - stop)
            print(f" [precompute] {stop}/{n} blocks ({el:.0f}s elapsed, ~{eta:.0f}s left)", flush=True)
    print(f" [precompute] done {n} blocks in {time.time() - t0:.0f}s", flush=True)
    return vals, idxs
|
|
|
|
@torch.no_grad()
def top1_agreement(model, blocks, teacher_top1, *, batch: int, device) -> float:
    """Share of held-out positions where ``model``'s argmax equals the teacher's.

    ``teacher_top1`` holds the untied teacher's argmax ids and is sliced along
    dim 0 in lock-step with ``blocks``. This is the fallback recovery
    diagnostic ("behaves like the full model"): domain-agnostic and bounded in
    [0, 1], so it cannot overshoot.
    """
    hits = 0
    seen = 0
    for lo in range(0, len(blocks), batch):
        hi = lo + batch
        out = model(input_ids=blocks[lo:hi].to(device), use_cache=False)
        student_top1 = out.logits.argmax(-1)
        reference = teacher_top1[lo:hi].to(device)
        hits += (student_top1 == reference).sum().item()
        seen += student_top1.numel()
    return hits / seen
|
|
|
|
@torch.no_grad()
def perplexity(model, blocks, *, batch: int, device) -> float:
    """Token-level perplexity of ``model`` on ``blocks``.

    Each block is scored as next-token prediction: position t predicts t+1.
    The negative log-likelihood is summed over every predicted position of
    every block, and ``exp(total_nll / n_predicted)`` is returned.

    Raises:
        ValueError: if there is nothing to score (no blocks, or blocks of
            length <= 1), instead of dividing by zero.
    """
    total_nll, total_tok = 0.0, 0
    for i in range(0, len(blocks), batch):
        b = blocks[i : i + batch].to(device)
        logits = model(input_ids=b, use_cache=False).logits
        nll = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.shape[-1]).float(),
            b[:, 1:].reshape(-1), reduction="sum",
        )
        total_nll += nll.item()
        total_tok += b[:, 1:].numel()
    if total_tok == 0:
        raise ValueError("perplexity: no tokens to score (empty eval set or seq_len <= 1)")
    mean_nll = total_nll / total_tok
    # Exponentiate in float64. A torch.tensor() round-trip would round through
    # float32. A mean NLL past ~709 nats is reported as inf rather than crashing.
    try:
        return math.exp(mean_nll)
    except OverflowError:
        return float("inf")
|
|
|
|
def topk_kd_loss(student_logits, teacher_vals, teacher_idx, input_ids, kd_weight: float):
    """LM cross-entropy + top-k forward-KL distillation loss.

    Teacher probabilities are ``softmax(teacher_vals)``, i.e. renormalized over
    the top-k support. Student log-probs are the full-vocab log-softmax
    gathered at ``teacher_idx`` and are NOT renormalized over the top-k. The
    KD term is ``sum_k p_k (log p_k - log q_k)`` averaged over all positions.

    Args:
        student_logits: ``[B, T, V]`` student logits.
        teacher_vals: ``[B, T, k]`` teacher top-k logits.
        teacher_idx: ``[B, T, k]`` teacher top-k vocab ids (int64, as
            ``torch.gather`` requires).
        input_ids: ``[B, T]`` tokens; the LM term uses position t to predict t+1.
        kd_weight: multiplier on the KD term.

    Returns:
        ``(loss, kl, ce)``: the differentiable total ``ce + kd_weight * kl``,
        plus both components as Python floats for logging.
    """
    logp = torch.log_softmax(student_logits.float(), dim=-1)
    student_topk = torch.gather(logp, -1, teacher_idx)
    teacher_p = torch.softmax(teacher_vals, dim=-1)
    kl = -(teacher_p * student_topk).sum(-1).mean() + (teacher_p * teacher_p.clamp_min(1e-9).log()).sum(-1).mean()
    # Reuse the float32 log-softmax for the LM term. F.cross_entropy would
    # upcast a second full-vocab copy of the logits and redo log_softmax;
    # nll_loss(log_softmax(x)) is exactly what cross_entropy computes.
    vocab = logp.shape[-1]
    ce = F.nll_loss(logp[:, :-1].reshape(-1, vocab), input_ids[:, 1:].reshape(-1))
    return ce + kd_weight * kl, kl.item(), ce.item()
|
|
|
|
| |
def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line flags for ``main``."""
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="poolside/Laguna-XS.2")
    p.add_argument("--tokenizer", default=DEFAULT_TOKENIZER)
    p.add_argument("--tiny", action="store_true", help="CPU plumbing run on the tiny random model")
    p.add_argument("--device", default="cuda")
    p.add_argument("--dtype", default="bfloat16")
    p.add_argument("--dataset", default="bigcode/the-stack-smol",
                   help="HF code dataset. Known: bigcode/the-stack-smol (gated, smallest, ~10k Py rows), "
                        "codeparrot/github-code (no gating, needs trust_remote_code, large), "
                        "bigcode/starcoderdata (gated, large).")
    p.add_argument("--lang", default="python", help="language to segment to (default python)")
    p.add_argument("--tie-layers", default="18,19,20,21",
                   help="comma-separated sparse layers; adjacent-paired (a,b),(c,d),...")
    p.add_argument("--rank", type=int, default=16)
    p.add_argument("--router-rank", type=int, default=None, help="LoRA rank on the router (default = --rank)")
    p.add_argument("--init", default="lower", choices=["lower", "average"])
    p.add_argument("--seq-len", type=int, default=1024)
    p.add_argument("--tokens", type=int, default=50_000_000)
    p.add_argument("--steps", type=int, default=2000)
    p.add_argument("--batch", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--topk", type=int, default=64)
    p.add_argument("--kd-weight", type=float, default=1.0)
    p.add_argument("--unfreeze-shared", action="store_true",
                   help="fallback: also train the shared expert/router base")
    p.add_argument("--grad-checkpoint", action=argparse.BooleanOptionalAction, default=True,
                   help="gradient checkpointing (recompute activations in backward) — needed to fit training")
    p.add_argument("--reference", action="store_true",
                   help="matched reference: LoRA the SAME layers without tying, train CE-only (no KD). "
                        "Its final ppl is the ceiling the tied run is compared against.")
    p.add_argument("--dry-run", action="store_true",
                   help="pull the first few dataset rows, print field/token counts, and exit (no model load)")
    p.add_argument("--outdir", default="results_rrt", help="directory for results JSON + checkpoints")
    p.add_argument("--run-name", default=None, help="run name (defaults to a config-derived slug)")
    p.add_argument("--eval-every", type=int, default=250, help="eval held-out ppl every N steps (0=off)")
    p.add_argument("--save-lora", action="store_true", help="save LoRA checkpoint at each eval + at end")
    return p.parse_args()


def main() -> None:
    """CLI entry point: one tied (CE+KD) or reference (CE-only) recovery run.

    Stages:
      1. baseline ppl;
      2. teacher top-k precompute (tied mode only);
      3. tie the layer pairs, or add reference LoRAs;
      4. init ppl;
      5. training with optional periodic eval;
      6. final ppl.

    After each stage the results dict is rewritten to
    ``<outdir>/<run_name>.json``; ``status`` becomes ``"done"`` at the end.
    """
    args = _parse_args()
    device = "cpu" if args.tiny else args.device

    # --dry-run: validate the dataset loader only (no model load).
    if args.dry_run:
        if args.tiny:
            print("--dry-run is for real datasets; --tiny uses synthetic data.")
            return
        from transformers import AutoTokenizer

        tok = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True, fix_mistral_regex=True)
        dry_run_data(tok, dataset=args.dataset, lang=args.lang)
        return

    # Model (+ tokenizer). In tied mode the untouched model is also the KD teacher.
    if args.tiny:
        tok = None  # --tiny uses synthetic token blocks, so no tokenizer is needed
        model = build_tiny_model(num_layers=8)
    else:
        model, tok = load_model_and_tokenizer(args.model, args.tokenizer, args.dtype, device)
    vocab = model.config.vocab_size

    # Corpus -> fixed-length blocks; the first ~10% are held out for eval.
    blocks = build_blocks(tok, tiny=args.tiny, vocab=vocab, seq_len=args.seq_len,
                          n_tokens=args.tokens, dataset=args.dataset, lang=args.lang)
    n_eval = max(1, len(blocks) // 10)
    eval_blocks, train_blocks = blocks[:n_eval], blocks[n_eval:]
    print(f"corpus: {len(train_blocks)} train / {len(eval_blocks)} eval blocks of {args.seq_len}")
    if len(train_blocks) == 0:
        # With <= 1 block everything lands in eval. The training loop below
        # would then never advance ``step`` and spin forever.
        raise SystemExit(f"ERROR: need at least 2 blocks of {args.seq_len} tokens for a train/eval split "
                         f"(got {len(blocks)}); raise --tokens or lower --seq-len.")

    # Run identity + incrementally saved results file.
    mode = "reference" if args.reference else "tied"
    layers = [int(x) for x in args.tie_layers.split(",")]
    slug = "-".join(map(str, layers))
    run_name = args.run_name or (
        ("ref_" if args.reference else "tied_") + f"L{slug}_r{args.rank}"
        + ("_unfz" if (args.unfreeze_shared and not args.reference) else "")
    )
    outdir = Path(args.outdir)
    json_path = outdir / f"{run_name}.json"
    results = {
        "run_name": run_name, "mode": mode, "config": vars(args), "layers": layers,
        "corpus": {"train_blocks": len(train_blocks), "eval_blocks": len(eval_blocks), "seq_len": args.seq_len},
        "status": "running",
        "baseline_ppl": None, "init_ppl": None, "final_ppl": None,
        "agreement_init": None, "agreement_final": None,
        "params": None,
        "curve": [],
    }
    save_json(json_path, results)
    print(f"[{mode}] saving results to {json_path}", flush=True)

    # B: baseline perplexity of the untouched model.
    print(f"[1/3] baseline perplexity over {len(eval_blocks)} eval blocks...", flush=True)
    base_ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
    print(f" baseline ppl = {base_ppl:.3f}", flush=True)
    results["baseline_ppl"] = base_ppl
    save_json(json_path, results)

    # Teacher targets are taken from ``model`` before tie_model is applied to it below.
    t_vals = t_idx = teacher_top1_eval = None
    if mode == "tied":
        print(f"[2/3] precomputing teacher top-{args.topk} logits over {len(train_blocks)} train blocks...", flush=True)
        t_vals, t_idx = precompute_teacher_topk(model, train_blocks, k=args.topk, batch=args.batch, device=device)
        _, te_idx = precompute_teacher_topk(model, eval_blocks, k=1, batch=args.batch, device=device)
        teacher_top1_eval = te_idx[..., 0]
    else:
        print("[2/3] reference run (CE-only, no teacher) — skipping precompute.", flush=True)

    # Structural change: tie layer pairs (tied) or add matching LoRAs (reference).
    before = parameter_report(model)["total_unique"]
    if mode == "tied":
        # NOTE(review): --router-rank is only forwarded in reference mode; confirm
        # whether TieConfig should receive it too, so the comparison stays matched.
        cfg = TieConfig(pairs=adjacent_pairs(layers), rank=args.rank, init=args.init, lora_init="svd")
        tie_model(model, cfg, keep_for_untie=False)
        desc = f"tied {cfg.pairs} init={cfg.init}"
    else:
        add_lora_adapters(model, layers, rank=args.rank, router_rank=args.router_rank, lora_init="zero")
        desc = f"reference (LoRA on {layers}, no tie)"
    rep = parameter_report(model)
    set_param_efficient(model, unfreeze_shared=(args.unfreeze_shared and mode == "tied"))
    if not args.tiny and args.grad_checkpoint:
        model.config.use_cache = False
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    if not args.tiny:
        torch.cuda.empty_cache()

    # T: quality right after the modification, before any training.
    init_ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
    results["init_ppl"] = init_ppl
    if mode == "tied":
        results["agreement_init"] = top1_agreement(model, eval_blocks, teacher_top1_eval, batch=args.batch, device=device)
    results["params"] = {"baseline_unique": before, "after_unique": rep["total_unique"],
                         "pct_smaller": 100 * (1 - rep["total_unique"] / before),
                         "trainable": sum(p.numel() for p in trainable_parameters(model))}
    save_json(json_path, results)
    print(f"{desc} rank={args.rank}: {before:,} -> {rep['total_unique']:,} params "
          f"({results['params']['pct_smaller']:.1f}% smaller), trainable {results['params']['trainable']:,}", flush=True)
    if mode == "tied":
        print(f" init: ppl {init_ppl:.3f} top1-agreement {results['agreement_init']:.1%}", flush=True)

    def record(step, loss, kl, ce):
        """Periodic held-out eval: append a curve point, persist JSON (+ LoRA)."""
        model.eval()
        ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
        agr = (top1_agreement(model, eval_blocks, teacher_top1_eval, batch=args.batch, device=device)
               if mode == "tied" else None)
        model.train()
        results["curve"].append({"step": step, "train_loss": loss, "kl": kl, "ce": ce,
                                 "eval_ppl": ppl, "agreement": agr})
        results["final_ppl"] = ppl
        results["agreement_final"] = agr
        save_json(json_path, results)
        if args.save_lora:
            save_lora(outdir / f"{run_name}_lora.pt", model)
        extra = f" top1 {agr:.1%}" if agr is not None else ""
        print(f" [eval] step {step:>5} eval_ppl {ppl:.3f}{extra} (saved)", flush=True)

    # Training: cycle over train_blocks (multiple epochs if needed) until --steps.
    opt = torch.optim.AdamW(trainable_parameters(model), lr=args.lr)
    model.train()
    step = 0
    t_train, tokens_seen = time.time(), 0
    print(f"[3/3] {'CE+KD' if mode == 'tied' else 'CE-only'} training for {args.steps} steps...", flush=True)
    while step < args.steps:
        for i in range(0, len(train_blocks), args.batch):
            if step >= args.steps:
                break
            # The same slice indexes the blocks and their cached teacher top-k.
            sl = slice(i, i + args.batch)
            b = train_blocks[sl].to(device)
            logits = model(input_ids=b, use_cache=False).logits
            if mode == "tied":
                loss, kl, ce = topk_kd_loss(logits, t_vals[sl].to(device), t_idx[sl].to(device), b, args.kd_weight)
            else:
                ce_t = F.cross_entropy(logits[:, :-1].reshape(-1, logits.shape[-1]).float(), b[:, 1:].reshape(-1))
                loss, kl, ce = ce_t, 0.0, ce_t.item()
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable_parameters(model), 1.0)
            opt.step()
            tokens_seen += b.numel()

            if step < 5 or step % 50 == 0:
                tps = tokens_seen / max(time.time() - t_train, 1e-6)
                print(f"step {step:>5} loss {loss.item():.4f} KL {float(kl):.4f} CE {ce:.4f} ({tps:,.0f} tok/s)", flush=True)
            if args.eval_every and step > 0 and step % args.eval_every == 0:
                record(step, loss.item(), float(kl), ce)
            step += 1
    model.eval()

    # R: final quality after training.
    final_ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
    results["final_ppl"] = final_ppl
    if mode == "tied":
        results["agreement_final"] = top1_agreement(model, eval_blocks, teacher_top1_eval, batch=args.batch, device=device)
    results["status"] = "done"
    save_json(json_path, results)
    if args.save_lora:
        save_lora(outdir / f"{run_name}_lora.pt", model)
    # Label with the data actually evaluated (--lang), not a hard-coded language.
    label = "synthetic" if args.tiny else args.lang
    print(f"\n=== {mode} result (held-out {label}) ===")
    print(f"baseline (untied) : {base_ppl:.3f}")
    print(f"after-mod @ init : {init_ppl:.3f}")
    print(f"after training : {final_ppl:.3f} <-- compare tied vs reference (the recovery gap)")
    if mode == "tied":
        print(f"top1 agreement : init {results['agreement_init']:.1%} -> final {results['agreement_final']:.1%}")
    print(f"results -> {json_path}")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|