"""Tier-2 driver: minimal RRT recovery run on real Laguna (or --tiny plumbing). Flow (see scratch_rrt.md): 1. load model + tokenizer (untied = teacher) 2. build a narrow corpus (default: a Python code slice) -> fixed-length blocks 3. precompute teacher top-k logits over the corpus (cache; teacher then freed from the loop -- only the tied student trains) 4. eval baseline perplexity (B) 5. tie a few adjacent mid-stack MoE pairs; eval tied-at-init (T, degraded) 6. param-efficient KD (LM + top-k forward-KL) on the LoRA adapters 7. eval final (R); print the recovery curve B -> T -> R Real run (GPU box): uv run python scripts/rrt_run.py --model poolside/Laguna-XS.2 --device cuda \ --dtype bfloat16 --tie-layers 18,19,20,21 --rank 16 --tokens 50_000_000 Local plumbing check (CPU, tiny random model, synthetic data, no network): uv run python scripts/rrt_run.py --tiny --tokens 20000 --steps 30 The --tiny path runs the entire code path; metrics are meaningless but it proves the GPU run is turn-key. Parts marked TODO(scale) are fine for a small run but should stream to disk for the full 50M-token run. """ from __future__ import annotations import argparse import json import os import sys import time from pathlib import Path # Reduce CUDA fragmentation (the backward OOM left 2.5GB reserved-but-unallocated). # Must be set before torch initializes CUDA. 
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# Make the repo root importable (this script lives one directory below it) so the
# sibling modules imported below resolve.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch  # noqa: E402  -- must come after the allocator env var above
import torch.nn.functional as F  # noqa: E402

from looped_laguna import build_tiny_model, load_model_and_tokenizer  # noqa: E402
from rrt_laguna import (  # noqa: E402
    TieConfig,
    add_lora_adapters,
    adjacent_pairs,
    parameter_report,
    set_param_efficient,
    tie_model,
    trainable_parameters,
    untie,
)

DEFAULT_TOKENIZER = str(Path(__file__).resolve().parent.parent / "laguna_src")

# Text columns tried, in order, when a dataset's text field is not known up front.
_TEXT_FIELDS = ("content", "code", "text")

# Canonical language names for codeparrot/github-code's `languages=[...]` filter.
# Several are not plain str.capitalize() forms (JavaScript, C#, PHP, GO, ...), so
# user input is looked up case-insensitively instead of being re-capitalized.
_GITHUB_CODE_LANGUAGES = {
    name.lower(): name
    for name in (
        "Assembly", "Batchfile", "C", "C#", "C++", "CMake", "CSS", "Dockerfile",
        "FORTRAN", "GO", "HTML", "Haskell", "Java", "JavaScript", "Julia", "Lua",
        "Makefile", "Markdown", "PHP", "Perl", "PowerShell", "Python", "Ruby",
        "Rust", "SQL", "Scala", "Shell", "TeX", "TypeScript", "Visual Basic",
    )
}


def save_json(path: Path, obj: dict) -> None:
    """Atomically write ``obj`` as indented JSON to ``path``.

    Writes a sibling ``.tmp`` file, fsyncs it, then ``os.replace``s it over
    ``path``, so a crash never leaves a half-written file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    with open(tmp, "w") as f:
        json.dump(obj, f, indent=2)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)


def save_lora(path: Path, model) -> None:
    """Save just the trainable (LoRA / unfrozen) params as a small checkpoint.

    Uses the same tmp + fsync + replace sequence as ``save_json`` so an
    interrupted save never clobbers the previous checkpoint.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    state = {n: p.detach().cpu() for n, p in model.named_parameters() if p.requires_grad}
    tmp = path.with_suffix(path.suffix + ".tmp")
    with open(tmp, "wb") as f:
        torch.save(state, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)


# --------------------------------------------------------------------------- #
#  Data                                                                       #
# --------------------------------------------------------------------------- #


def _open_code_stream(dataset: str, lang: str):
    """Open a streamed, language-filtered code dataset. Returns (iterable, text_field).

    Verified loader calls + language-filter recipes per dataset (see scratch_rrt.md):
      - bigcode/the-stack-smol : data_dir=f"data/{lang}" (lowercase), field "content".
                                 GATED (HF login + accept terms); smallest (~10k rows).
      - codeparrot/github-code : languages=[Lang] (canonical casing, e.g. "Python",
                                 "JavaScript"), field "code". No gating; needs
                                 trust_remote_code; large (stream+cap).
      - bigcode/starcoderdata  : data_dir=lang (lowercase), field "content". GATED, large.

    For any other dataset the text field is returned as ``None`` and callers
    sniff it per row.
    """
    from datasets import load_dataset

    if dataset == "bigcode/the-stack-smol":
        ds = load_dataset(dataset, data_dir=f"data/{lang.lower()}", split="train", streaming=True)
        return ds, "content"
    if dataset in ("codeparrot/github-code", "codeparrot/github-code-clean"):
        # Names outside the table keep the historical capitalize() behaviour.
        gh_lang = _GITHUB_CODE_LANGUAGES.get(lang.lower(), lang.capitalize())
        ds = load_dataset(dataset, split="train", streaming=True,
                          languages=[gh_lang], trust_remote_code=True)
        return ds, "code"
    if dataset == "bigcode/starcoderdata":
        ds = load_dataset(dataset, data_dir=lang.lower(), split="train", streaming=True)
        return ds, "content"
    # Unknown dataset: stream split="train" and sniff a text field per row.
    return load_dataset(dataset, split="train", streaming=True), None


def _row_text(row: dict, field: str | None) -> tuple[str | None, str]:
    """Return ``(text, field_name)`` for one dataset row.

    With a known ``field`` it is read directly. Otherwise the first non-empty
    column among ``_TEXT_FIELDS`` is used; ``(None, "?")`` if none has text.
    """
    if field:
        return row.get(field), field
    for name in _TEXT_FIELDS:
        if row.get(name):
            return row[name], name
    return None, "?"


def build_blocks(tok, *, tiny: bool, vocab: int, seq_len: int, n_tokens: int, dataset: str, lang: str):
    """Return a [N, seq_len] LongTensor of token blocks (language-filtered code).

    In ``tiny`` mode the blocks are seeded random ids in ``[3, vocab)``.
    Otherwise documents are tokenized and concatenated until at least
    ``n_tokens // seq_len`` blocks are available (a warning is printed if the
    stream runs dry first), and any trailing partial block is dropped. The
    result may have 0 rows if the stream yields fewer than ``seq_len`` tokens.
    """
    n_blocks = max(1, n_tokens // seq_len)
    if tiny:
        gen = torch.Generator().manual_seed(0)
        return torch.randint(3, vocab, (n_blocks, seq_len), generator=gen)

    wanted = n_blocks * seq_len
    stream, field = _open_code_stream(dataset, lang)
    ids: list[int] = []
    for row in stream:
        text, _ = _row_text(row, field)
        if not text:
            continue
        ids.extend(tok(text).input_ids)
        if len(ids) >= wanted:
            break
    if len(ids) < wanted:
        print(f"WARNING: only {len(ids):,} tokens available (< requested {wanted:,}); "
              f"dataset '{dataset}' may be smaller than --tokens.")
    ids = ids[: (len(ids) // seq_len) * seq_len]
    return torch.tensor(ids, dtype=torch.long).view(-1, seq_len)


def dry_run_data(tok, *, dataset: str, lang: str, n_rows: int = 5) -> None:
    """Pull the first few rows from the configured stream and print the detected
    text field, token counts, and a snippet. No model load.

    Use to validate the loader (gating/config/field) before kicking off the full
    teacher precompute.
    """
    from itertools import islice

    print(f"DRY RUN: dataset={dataset!r} lang={lang!r}")
    stream, field = _open_code_stream(dataset, lang)
    total_tok = 0
    seen = 0
    # islice stops after n_rows without fetching an extra row from the stream.
    for i, row in enumerate(islice(stream, n_rows)):
        seen += 1
        text, used_field = _row_text(row, field)
        n = len(tok(text).input_ids) if text else 0
        total_tok += n
        snippet = text[:80].replace("\n", "\\n") if text else ""
        print(f" row {i}: field={used_field!r} tokens={n:<6} keys={list(row.keys())[:6]} | {snippet}")
    if total_tok:
        # Average over rows actually seen (the stream may be shorter than n_rows).
        print(f"OK: {seen} rows ~ {total_tok:,} tokens ({total_tok // seen:,}/row avg). "
              f"Loader works — safe to run the full job.")
    else:
        print("ERROR: no text found in the first rows — check dataset/lang/field.")


# --------------------------------------------------------------------------- #
#  Teacher targets + eval                                                     #
# --------------------------------------------------------------------------- #


@torch.no_grad()
def precompute_teacher_topk(model, blocks, *, k: int, batch: int, device):
    """Top-k teacher logits + indices per token, stacked over all ``blocks``.

    Returns ``(vals, idxs)`` on CPU: float32 values and int64 vocab indices,
    one row per input block position.
    Raises ValueError if ``blocks`` is empty.
    TODO(scale): stream to disk for 50M+ tokens instead of holding in RAM.
    """
    n = len(blocks)
    if n == 0:
        raise ValueError("precompute_teacher_topk: no blocks to score")
    vals, idxs = [], []
    t0 = time.time()
    every = max(1, (n // batch) // 20)  # ~20 heartbeats
    for bi, i in enumerate(range(0, n, batch)):
        b = blocks[i : i + batch].to(device)
        logits = model(input_ids=b, use_cache=False).logits
        # Select top-k in the model's native dtype and upcast only the k
        # survivors, instead of materializing a float32 copy of the full
        # [B, T, vocab] logits (upcasting bf16 -> fp32 is exact).
        tv, ti = torch.topk(logits, k, dim=-1)
        vals.append(tv.float().cpu())
        idxs.append(ti.cpu())
        if bi % every == 0:
            done = min(i + batch, n)
            el = time.time() - t0
            eta = el / max(done, 1) * (n - done)
            print(f" [precompute] {done}/{n} blocks ({el:.0f}s elapsed, ~{eta:.0f}s left)", flush=True)
    print(f" [precompute] done {n} blocks in {time.time() - t0:.0f}s", flush=True)
    return torch.cat(vals), torch.cat(idxs)


@torch.no_grad()
def top1_agreement(model, blocks, teacher_top1, *, batch: int, device) -> float:
    """Fraction of held-out positions where the (tied) model's argmax matches the
    untied teacher's argmax. The fallback recovery diagnostic: 'behaves like the
    full model'. Domain-agnostic and bounded [0,1] (can't overshoot).

    Raises ValueError if ``blocks`` contains no positions.
    """
    match, total = 0, 0
    for i in range(0, len(blocks), batch):
        b = blocks[i : i + batch].to(device)
        pred = model(input_ids=b, use_cache=False).logits.argmax(-1)
        tt = teacher_top1[i : i + batch].to(device)
        match += (pred == tt).sum().item()
        total += pred.numel()
    if total == 0:
        raise ValueError("top1_agreement: no positions to compare (empty blocks)")
    return match / total


@torch.no_grad()
def perplexity(model, blocks, *, batch: int, device) -> float:
    """Token-level next-token perplexity of ``model`` over ``blocks``.

    Logits at position t are scored against token t+1; NLL is summed over all
    targets and averaged once at the end. Raises ValueError if there are no
    target tokens (empty ``blocks`` or blocks of length < 2).
    """
    total_nll, total_tok = 0.0, 0
    for i in range(0, len(blocks), batch):
        b = blocks[i : i + batch].to(device)
        logits = model(input_ids=b, use_cache=False).logits
        nll = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.shape[-1]).float(),
            b[:, 1:].reshape(-1),
            reduction="sum",
        )
        total_nll += nll.item()
        total_tok += b[:, 1:].numel()
    if total_tok == 0:
        raise ValueError("perplexity: no target tokens (empty blocks or seq_len < 2)")
    # float64 so a badly degraded model reports a large finite ppl rather than
    # overflowing float32's exp range.
    return float(torch.tensor(total_nll / total_tok, dtype=torch.float64).exp())


def topk_kd_loss(student_logits, teacher_vals, teacher_idx, input_ids, kd_weight: float):
    """LM cross-entropy + top-k forward-KL.

    Teacher probs are renormalized over the top-k support; student log-probs come
    from the full-vocab softmax and are gathered at the teacher's top-k indices.
    Teacher and student tensors are aligned position-for-position (both
    unshifted), as produced by ``precompute_teacher_topk``.

    Returns ``(loss, kl, ce)``: the differentiable ``ce + kd_weight * kl`` plus
    the two components as Python floats for logging.
    """
    # One float32 log-softmax over the vocab, shared by the KD and CE terms, so a
    # second full-vocab float copy is never materialized for the CE.
    logp = torch.log_softmax(student_logits.float(), dim=-1)
    student_topk = torch.gather(logp, -1, teacher_idx)  # [B,T,k]
    teacher_logp = torch.log_softmax(teacher_vals.float(), dim=-1)  # renormalized over k
    teacher_p = teacher_logp.exp()
    # KL(p_teacher || p_student) restricted to the teacher's top-k support. The
    # teacher-entropy part carries no gradient; it only makes the logged value a
    # true KL.
    kl = (teacher_p * (teacher_logp - student_topk)).sum(-1).mean()
    # Next-token CE (logits at t predict token t+1), read straight out of logp.
    targets = input_ids[:, 1:].unsqueeze(-1)
    ce = -torch.gather(logp[:, :-1], -1, targets).mean()
    return ce + kd_weight * kl, kl.item(), ce.item()


# --------------------------------------------------------------------------- #


def _build_arg_parser() -> argparse.ArgumentParser:
    """Command-line interface for ``main``."""
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="poolside/Laguna-XS.2")
    p.add_argument("--tokenizer", default=DEFAULT_TOKENIZER)
    p.add_argument("--tiny", action="store_true", help="CPU plumbing run on the tiny random model")
    p.add_argument("--device", default="cuda")
    p.add_argument("--dtype", default="bfloat16")
    p.add_argument("--dataset", default="bigcode/the-stack-smol",
                   help="HF code dataset. Known: bigcode/the-stack-smol (gated, smallest, ~10k Py rows), "
                        "codeparrot/github-code (no gating, needs trust_remote_code, large), "
                        "bigcode/starcoderdata (gated, large).")
    p.add_argument("--lang", default="python", help="language to segment to (default python)")
    p.add_argument("--tie-layers", default="18,19,20,21",
                   help="comma-separated sparse layers; adjacent-paired (a,b),(c,d),...")
    p.add_argument("--rank", type=int, default=16)
    p.add_argument("--router-rank", type=int, default=None, help="LoRA rank on the router (default = --rank)")
    p.add_argument("--init", default="lower", choices=["lower", "average"])
    p.add_argument("--seq-len", type=int, default=1024)
    p.add_argument("--tokens", type=int, default=50_000_000)
    p.add_argument("--steps", type=int, default=2000)
    p.add_argument("--batch", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--topk", type=int, default=64)
    p.add_argument("--kd-weight", type=float, default=1.0)
    p.add_argument("--unfreeze-shared", action="store_true",
                   help="fallback: also train the shared expert/router base")
    p.add_argument("--grad-checkpoint", action=argparse.BooleanOptionalAction, default=True,
                   help="gradient checkpointing (recompute activations in backward) — needed to fit training")
    p.add_argument("--reference", action="store_true",
                   help="matched reference: LoRA the SAME layers without tying, train CE-only (no KD). "
                        "Its final ppl is the ceiling the tied run is compared against.")
    p.add_argument("--dry-run", action="store_true",
                   help="pull the first few dataset rows, print field/token counts, and exit (no model load)")
    p.add_argument("--outdir", default="results_rrt", help="directory for results JSON + checkpoints")
    p.add_argument("--run-name", default=None, help="run name (defaults to a config-derived slug)")
    p.add_argument("--eval-every", type=int, default=250, help="eval held-out ppl every N steps (0=off)")
    p.add_argument("--save-lora", action="store_true", help="save LoRA checkpoint at each eval + at end")
    return p


def main() -> None:
    """Run the pipeline described in the module docstring.

    Stages: baseline eval, then teacher precompute (tied mode only), then tie or
    LoRA the model, then param-efficient training, then final eval. Results are
    persisted to ``<outdir>/<run_name>.json`` after each milestone.
    """
    p = _build_arg_parser()
    args = p.parse_args()
    device = "cpu" if args.tiny else args.device

    # Data-loader validation: tokenizer only, no model.
    if args.dry_run:
        if args.tiny:
            print("--dry-run is for real datasets; --tiny uses synthetic data.")
            return
        from transformers import AutoTokenizer
        tok = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True, fix_mistral_regex=True)
        dry_run_data(tok, dataset=args.dataset, lang=args.lang)
        return

    # Validate cheap arguments before the (expensive) model load.
    if args.seq_len < 2:
        p.error("--seq-len must be >= 2 (next-token targets need at least two positions)")
    try:
        layers = [int(x) for x in args.tie_layers.split(",") if x.strip()]
    except ValueError:
        p.error(f"--tie-layers must be comma-separated integers, got {args.tie_layers!r}")
    if not layers:
        p.error("--tie-layers must list at least one layer index")

    if args.tiny:
        tok = None  # synthetic data: no tokenizer needed
        model = build_tiny_model(num_layers=8)
    else:
        model, tok = load_model_and_tokenizer(args.model, args.tokenizer, args.dtype, device)
    vocab = model.config.vocab_size

    blocks = build_blocks(tok, tiny=args.tiny, vocab=vocab, seq_len=args.seq_len, n_tokens=args.tokens,
                          dataset=args.dataset, lang=args.lang)
    # Simple train/held-out split: the first ~10% of blocks are held out.
    n_eval = max(1, len(blocks) // 10)
    eval_blocks, train_blocks = blocks[:n_eval], blocks[n_eval:]
    if len(train_blocks) == 0:
        # An empty train split would make the training loop below spin forever
        # (its inner `for` never advances `step`).
        sys.exit(f"ERROR: got {len(blocks)} block(s) of {args.seq_len} tokens; need >= 2 for a "
                 f"train/eval split. Raise --tokens or lower --seq-len.")
    print(f"corpus: {len(train_blocks)} train / {len(eval_blocks)} eval blocks of {args.seq_len}")

    # Result record, persisted after every milestone (atomic write) so a crash keeps
    # whatever we'd reached.
    mode = "reference" if args.reference else "tied"
    slug = "-".join(map(str, layers))
    run_name = args.run_name or (
        ("ref_" if args.reference else "tied_")
        + f"L{slug}_r{args.rank}"
        + ("_unfz" if (args.unfreeze_shared and not args.reference) else "")
    )
    outdir = Path(args.outdir)
    json_path = outdir / f"{run_name}.json"
    results = {
        "run_name": run_name,
        "mode": mode,
        "config": vars(args),
        "layers": layers,
        "corpus": {"train_blocks": len(train_blocks), "eval_blocks": len(eval_blocks), "seq_len": args.seq_len},
        "status": "running",
        "baseline_ppl": None,
        "init_ppl": None,
        "final_ppl": None,
        "agreement_init": None,
        "agreement_final": None,  # tied only
        "params": None,
        "curve": [],  # [{step, train_loss, kl, ce, eval_ppl, agreement}]
    }
    save_json(json_path, results)
    print(f"[{mode}] saving results to {json_path}", flush=True)

    # Baseline = the untied model's held-out ppl, before any modification.
    model.eval()
    print(f"[1/3] baseline perplexity over {len(eval_blocks)} eval blocks...", flush=True)
    base_ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
    print(f" baseline ppl = {base_ppl:.3f}", flush=True)
    results["baseline_ppl"] = base_ppl
    save_json(json_path, results)

    # Teacher targets (tied mode only): top-k on train (for KD) + top-1 on eval (agreement).
    t_vals = t_idx = teacher_top1_eval = None
    if mode == "tied":
        print(f"[2/3] precomputing teacher top-{args.topk} logits over {len(train_blocks)} train blocks...",
              flush=True)
        t_vals, t_idx = precompute_teacher_topk(model, train_blocks, k=args.topk, batch=args.batch, device=device)
        _, te_idx = precompute_teacher_topk(model, eval_blocks, k=1, batch=args.batch, device=device)
        teacher_top1_eval = te_idx[..., 0]
    else:
        print("[2/3] reference run (CE-only, no teacher) — skipping precompute.", flush=True)

    # Modify the model: tie (CE+KD) or add LoRA to own banks (reference, CE-only).
    before = parameter_report(model)["total_unique"]
    if mode == "tied":
        # NOTE(review): --router-rank is only forwarded in reference mode; confirm
        # whether TieConfig is meant to receive it too.
        cfg = TieConfig(pairs=adjacent_pairs(layers), rank=args.rank, init=args.init, lora_init="svd")
        tie_model(model, cfg, keep_for_untie=False)  # never unties -> free originals
        desc = f"tied {cfg.pairs} init={cfg.init}"
    else:
        add_lora_adapters(model, layers, rank=args.rank, router_rank=args.router_rank, lora_init="zero")
        desc = f"reference (LoRA on {layers}, no tie)"
    rep = parameter_report(model)
    set_param_efficient(model, unfreeze_shared=(args.unfreeze_shared and mode == "tied"))
    if not args.tiny and args.grad_checkpoint:
        model.config.use_cache = False
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    if not args.tiny:
        torch.cuda.empty_cache()

    model.eval()
    init_ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
    results["init_ppl"] = init_ppl
    if mode == "tied":
        results["agreement_init"] = top1_agreement(model, eval_blocks, teacher_top1_eval,
                                                   batch=args.batch, device=device)
    results["params"] = {
        "baseline_unique": before,
        "after_unique": rep["total_unique"],
        "pct_smaller": 100 * (1 - rep["total_unique"] / before),
        "trainable": sum(p.numel() for p in trainable_parameters(model)),
    }
    save_json(json_path, results)
    print(f"{desc} rank={args.rank}: {before:,} -> {rep['total_unique']:,} params "
          f"({results['params']['pct_smaller']:.1f}% smaller), trainable {results['params']['trainable']:,}",
          flush=True)
    if mode == "tied":
        print(f" init: ppl {init_ppl:.3f} top1-agreement {results['agreement_init']:.1%}", flush=True)

    def record(step, loss, kl, ce):
        """Eval mid-training, append a curve point, persist JSON (+ LoRA if asked)."""
        model.eval()
        ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
        agr = (top1_agreement(model, eval_blocks, teacher_top1_eval, batch=args.batch, device=device)
               if mode == "tied" else None)
        model.train()
        results["curve"].append({"step": step, "train_loss": loss, "kl": kl, "ce": ce,
                                 "eval_ppl": ppl, "agreement": agr})
        results["final_ppl"] = ppl
        results["agreement_final"] = agr
        save_json(json_path, results)
        if args.save_lora:
            save_lora(outdir / f"{run_name}_lora.pt", model)
        extra = f" top1 {agr:.1%}" if agr is not None else ""
        print(f" [eval] step {step:>5} eval_ppl {ppl:.3f}{extra} (saved)", flush=True)

    # Param-efficient training: tied -> CE+KD, reference -> CE only.
    opt = torch.optim.AdamW(trainable_parameters(model), lr=args.lr)
    print(f"[3/3] {'CE+KD' if mode == 'tied' else 'CE-only'} training for {args.steps} steps...", flush=True)
    model.train()
    step = 0
    t_train, tokens_seen = time.time(), 0
    while step < args.steps:  # repeat epochs over the train split until --steps is reached
        for i in range(0, len(train_blocks), args.batch):
            if step >= args.steps:
                break
            sl = slice(i, i + args.batch)
            b = train_blocks[sl].to(device)
            logits = model(input_ids=b, use_cache=False).logits
            if mode == "tied":
                loss, kl, ce = topk_kd_loss(logits, t_vals[sl].to(device), t_idx[sl].to(device), b, args.kd_weight)
            else:
                ce_t = F.cross_entropy(logits[:, :-1].reshape(-1, logits.shape[-1]).float(), b[:, 1:].reshape(-1))
                loss, kl, ce = ce_t, 0.0, ce_t.item()
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable_parameters(model), 1.0)
            opt.step()
            tokens_seen += b.numel()

            # Print early (for the batch-size/fit probe) and then every 50 steps.
            if step < 5 or step % 50 == 0:
                tps = tokens_seen / max(time.time() - t_train, 1e-6)
                print(f"step {step:>5} loss {loss.item():.4f} KL {float(kl):.4f} CE {ce:.4f} ({tps:,.0f} tok/s)",
                      flush=True)
            if args.eval_every and step > 0 and step % args.eval_every == 0:
                record(step, loss.item(), float(kl), ce)
            step += 1

    model.eval()
    final_ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
    results["final_ppl"] = final_ppl
    if mode == "tied":
        results["agreement_final"] = top1_agreement(model, eval_blocks, teacher_top1_eval,
                                                    batch=args.batch, device=device)
    results["status"] = "done"
    save_json(json_path, results)
    if args.save_lora:
        save_lora(outdir / f"{run_name}_lora.pt", model)

    # Label the held-out set by what was actually evaluated, not a fixed language.
    corpus_label = "synthetic" if args.tiny else args.lang
    print(f"\n=== {mode} result (held-out {corpus_label}) ===")
    print(f"baseline (untied) : {base_ppl:.3f}")
    print(f"after-mod @ init : {init_ppl:.3f}")
    print(f"after training : {final_ppl:.3f} <-- compare tied vs reference (the recovery gap)")
    if mode == "tied":
        print(f"top1 agreement : init {results['agreement_init']:.1%} -> final {results['agreement_final']:.1%}")
    print(f"results -> {json_path}")


if __name__ == "__main__":
    main()