# looped-laguna / scripts / rrt_run.py
# e-p's picture
# rrt fair comparison and sweep
# 9316acb
"""Tier-2 driver: minimal RRT recovery run on real Laguna (or --tiny plumbing).
Flow (see scratch_rrt.md):
1. load model + tokenizer (untied = teacher)
2. build a narrow corpus (default: a Python code slice) -> fixed-length blocks
3. precompute teacher top-k logits over the corpus (cache; teacher then freed from
the loop -- only the tied student trains)
4. eval baseline perplexity (B)
5. tie a few adjacent mid-stack MoE pairs; eval tied-at-init (T, degraded)
6. param-efficient KD (LM + top-k forward-KL) on the LoRA adapters
7. eval final (R); print the recovery curve B -> T -> R
Real run (GPU box):
uv run python scripts/rrt_run.py --model poolside/Laguna-XS.2 --device cuda \
--dtype bfloat16 --tie-layers 18,19,20,21 --rank 16 --tokens 50_000_000
Local plumbing check (CPU, tiny random model, synthetic data, no network):
uv run python scripts/rrt_run.py --tiny --tokens 20000 --steps 30
The --tiny path runs the entire code path; metrics are meaningless but it proves the
GPU run is turn-key. Parts marked TODO(scale) are fine for a small run but should
stream to disk for the full 50M-token run.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from pathlib import Path
# Reduce CUDA fragmentation (the backward OOM left 2.5GB reserved-but-unallocated).
# Must be set before torch initializes CUDA.
# setdefault() keeps any PYTORCH_CUDA_ALLOC_CONF value already present in the environment.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# Put the directory two levels above this file first on sys.path -- presumably so the
# project-local modules imported further down resolve when this is run as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
import torch.nn.functional as F
def save_json(path: Path, obj: dict) -> None:
    """Persist ``obj`` as 2-space-indented JSON at ``path`` without ever exposing a partial file.

    The payload is written to a sibling ``<name>.tmp`` file, flushed and fsync'ed,
    and only then moved over ``path`` with ``os.replace``. Missing parent
    directories are created.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    scratch = path.with_suffix(f"{path.suffix}.tmp")
    with scratch.open("w") as handle:
        json.dump(obj, handle, indent=2)
        # Push Python's buffer to the OS, then the OS buffer to disk, before the rename.
        handle.flush()
        os.fsync(handle.fileno())
    os.replace(scratch, path)
def save_lora(path: Path, model) -> None:
    """Checkpoint only the parameters of ``model`` that currently require grad.

    Each selected parameter is detached and copied to CPU, the resulting
    ``{name: tensor}`` dict is written with ``torch.save`` to a sibling
    ``<name>.tmp`` file, and that file is then moved over ``path`` with
    ``os.replace``. Missing parent directories are created.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    trainable = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        trainable[name] = param.detach().cpu()
    scratch = path.with_suffix(f"{path.suffix}.tmp")
    torch.save(trainable, scratch)
    os.replace(scratch, path)
from looped_laguna import build_tiny_model, load_model_and_tokenizer
from rrt_laguna import (
TieConfig,
add_lora_adapters,
adjacent_pairs,
parameter_report,
set_param_efficient,
tie_model,
trainable_parameters,
untie,
)
# Default tokenizer location: the "laguna_src" directory two levels above this file.
DEFAULT_TOKENIZER = str(Path(__file__).resolve().parent.parent / "laguna_src")
# --------------------------------------------------------------------------- #
# Data #
# --------------------------------------------------------------------------- #
def _open_code_stream(dataset: str, lang: str):
    """Open ``dataset`` as a streaming ``train`` split and say which field holds the text.

    Returns ``(stream, text_field)``. ``text_field`` is ``None`` for datasets this
    function has no recipe for; callers then have to probe each row themselves.

    Per-dataset recipes (see scratch_rrt.md):
      - bigcode/the-stack-smol : data_dir=f"data/{lang}" (lowercase), field "content".
        GATED (HF login + accept terms); smallest (~10k rows).
      - codeparrot/github-code[-clean] : languages=[Lang] (Capitalized), field "code".
        No gating; needs trust_remote_code; large (stream+cap).
      - bigcode/starcoderdata : data_dir=lang (lowercase), field "content". GATED, large.

    NOTE(review): ``str.capitalize()`` lowercases everything after the first
    character, so a language whose dataset name has inner capitals would not
    match -- confirm before using github-code with anything but Python-like names.
    """
    from datasets import load_dataset

    lowered = lang.lower()
    common = {"split": "train", "streaming": True}

    if dataset == "bigcode/the-stack-smol":
        text_field = "content"
        stream = load_dataset(dataset, data_dir=f"data/{lowered}", **common)
    elif dataset in ("codeparrot/github-code", "codeparrot/github-code-clean"):
        text_field = "code"
        stream = load_dataset(dataset, languages=[lang.capitalize()], trust_remote_code=True, **common)
    elif dataset == "bigcode/starcoderdata":
        text_field = "content"
        stream = load_dataset(dataset, data_dir=lowered, **common)
    else:
        # Unknown dataset: stream split="train" and sniff a text field per row.
        text_field = None
        stream = load_dataset(dataset, **common)
    return stream, text_field
def build_blocks(tok, *, tiny: bool, vocab: int, seq_len: int, n_tokens: int, dataset: str, lang: str):
    """Return a ``[N, seq_len]`` LongTensor of token blocks, with ``N <= max(1, n_tokens // seq_len)``.

    Args:
        tok: callable tokenizer; ``tok(text).input_ids`` must be a list of ints.
            Unused (may be ``None``) when ``tiny`` is true.
        tiny: if true, skip all I/O and return seeded random ids in ``[3, vocab)``.
        vocab: exclusive upper bound for the synthetic ids (tiny path only).
        seq_len: tokens per block.
        n_tokens: total token budget; rounded down to whole blocks (minimum one).
        dataset, lang: forwarded to ``_open_code_stream`` on the real-data path.

    Returns:
        ``torch.long`` tensor of shape ``[N, seq_len]``. Documents are tokenized
        and concatenated back to back, then cut into fixed-length blocks.

    Raises:
        ValueError: if the stream yields fewer than ``seq_len`` tokens, so not
            even one block can be formed.
    """
    n_blocks = max(1, n_tokens // seq_len)
    if tiny:
        gen = torch.Generator().manual_seed(0)
        return torch.randint(3, vocab, (n_blocks, seq_len), generator=gen)

    target = n_blocks * seq_len
    stream, field = _open_code_stream(dataset, lang)
    ids: list[int] = []
    for row in stream:
        # Unknown datasets have no declared field: try the common text column names.
        text = row.get(field) if field else (row.get("content") or row.get("code") or row.get("text"))
        if not text:
            continue
        ids.extend(tok(text).input_ids)
        if len(ids) >= target:
            break
    if len(ids) < target:
        print(f"WARNING: only {len(ids):,} tokens available (< requested {target:,}); "
              f"dataset '{dataset}' may be smaller than --tokens.")

    # The last document usually overshoots the budget; cap at the requested number
    # of blocks (matching the tiny path) and drop any trailing partial block.
    usable = (min(len(ids), target) // seq_len) * seq_len
    if usable == 0:
        raise ValueError(
            f"dataset '{dataset}' yielded only {len(ids):,} tokens, fewer than one block "
            f"of seq_len={seq_len}; cannot build any blocks."
        )
    return torch.tensor(ids[:usable], dtype=torch.long).view(-1, seq_len)
def dry_run_data(tok, *, dataset: str, lang: str, n_rows: int = 5) -> None:
    """Smoke-test the data loader: print a per-row report for the first few rows.

    Opens the configured stream via ``_open_code_stream`` and, for up to
    ``n_rows`` rows, prints the text field used, the token count, the first six
    row keys and an 80-character snippet. No model is loaded. Use it to validate
    the loader (gating/config/field) before kicking off the full teacher precompute.

    Args:
        tok: callable tokenizer; ``tok(text).input_ids`` must support ``len()``.
        dataset, lang: forwarded to ``_open_code_stream``.
        n_rows: maximum number of rows to inspect.
    """
    print(f"DRY RUN: dataset={dataset!r} lang={lang!r}")
    stream, field = _open_code_stream(dataset, lang)
    total_tok = 0
    rows_seen = 0  # the stream may hold fewer than n_rows rows
    for i, row in enumerate(stream):
        if i >= n_rows:
            break
        rows_seen += 1
        text = row.get(field) if field else (row.get("content") or row.get("code") or row.get("text"))
        used_field = field or next((f for f in ("content", "code", "text") if row.get(f)), "?")
        n = len(tok(text).input_ids) if text else 0
        total_tok += n
        snippet = (text[:80].replace("\n", "\\n") if text else "<empty>")
        print(f" row {i}: field={used_field!r} tokens={n:<6} keys={list(row.keys())[:6]} | {snippet}")
    if total_tok:
        # Report and average over the rows actually read, not the requested count.
        print(f"OK: {rows_seen} rows ~ {total_tok:,} tokens ({total_tok // rows_seen:,}/row avg). "
              f"Loader works — safe to run the full job.")
    else:
        print("ERROR: no text found in the first rows — check dataset/lang/field.")
# --------------------------------------------------------------------------- #
# Teacher targets + eval #
# --------------------------------------------------------------------------- #
@torch.no_grad()
def precompute_teacher_topk(model, blocks, *, k: int, batch: int, device):
    """Run ``model`` over ``blocks`` and keep the ``k`` largest logits per position.

    Blocks are processed ``batch`` at a time on ``device``; logits are cast to
    float before ``topk`` and the results are moved back to CPU. Progress is
    printed roughly twenty times over the run.

    Returns:
        ``(values, indices)``, each the concatenation of the per-batch top-k
        results along the first dimension.

    TODO(scale): stream to disk for 50M+ tokens instead of holding in RAM.
    """
    import time

    total_blocks = len(blocks)
    started = time.time()
    heartbeat = max(1, (total_blocks // batch) // 20)  # ~20 heartbeats
    top_values, top_indices = [], []
    for batch_no, start in enumerate(range(0, total_blocks, batch)):
        chunk = blocks[start : start + batch].to(device)
        teacher_logits = model(input_ids=chunk, use_cache=False).logits.float()
        values, indices = torch.topk(teacher_logits, k, dim=-1)
        top_values.append(values.cpu())
        top_indices.append(indices.cpu())
        if batch_no % heartbeat != 0:
            continue
        done = min(start + batch, total_blocks)
        elapsed = time.time() - started
        remaining = elapsed / max(done, 1) * (total_blocks - done)
        print(f" [precompute] {done}/{total_blocks} blocks ({elapsed:.0f}s elapsed, ~{remaining:.0f}s left)", flush=True)
    print(f" [precompute] done {total_blocks} blocks in {time.time() - started:.0f}s", flush=True)
    return torch.cat(top_values), torch.cat(top_indices)
@torch.no_grad()
def top1_agreement(model, blocks, teacher_top1, *, batch: int, device) -> float:
    """Share of positions where ``model``'s argmax token equals ``teacher_top1``.

    ``blocks`` and ``teacher_top1`` are sliced in lockstep, ``batch`` rows at a
    time, and compared elementwise on ``device``. The result is
    matches / positions, so it lies in [0, 1]. Used as the fallback recovery
    diagnostic: how often the (tied) model behaves like the untied teacher.
    """
    hits = 0
    positions = 0
    for start in range(0, len(blocks), batch):
        stop = start + batch
        tokens = blocks[start:stop].to(device)
        student_top1 = model(input_ids=tokens, use_cache=False).logits.argmax(-1)
        reference = teacher_top1[start:stop].to(device)
        hits += (student_top1 == reference).sum().item()
        positions += student_top1.numel()
    return hits / positions
@torch.no_grad()
def perplexity(model, blocks, *, batch: int, device) -> float:
    """Next-token perplexity of ``model`` over ``blocks``.

    For each batch, the logits at positions ``[:-1]`` are scored against the
    tokens at positions ``[1:]`` with a summed float cross-entropy. The summed
    NLL is divided by the number of scored tokens and exponentiated.
    """
    nll_sum = 0.0
    n_targets = 0
    for start in range(0, len(blocks), batch):
        tokens = blocks[start : start + batch].to(device)
        logits = model(input_ids=tokens, use_cache=False).logits
        targets = tokens[:, 1:]
        vocab_size = logits.shape[-1]
        batch_nll = F.cross_entropy(
            logits[:, :-1].reshape(-1, vocab_size).float(),
            targets.reshape(-1),
            reduction="sum",
        )
        nll_sum += batch_nll.item()
        n_targets += targets.numel()
    mean_nll = nll_sum / n_targets
    return float(torch.exp(torch.tensor(mean_nll)))
def topk_kd_loss(student_logits, teacher_vals, teacher_idx, input_ids, kd_weight: float):
    """LM cross-entropy plus ``kd_weight`` times a top-k forward-KL term.

    Args:
        student_logits: student logits; the last dimension is the vocabulary and
            dimension 1 is the sequence position (it is sliced with ``[:, :-1]``).
        teacher_vals: teacher top-k logits, aligned position-for-position with
            ``student_logits``; softmaxed over the last dimension, i.e. teacher
            probabilities are renormalized over the top-k support.
        teacher_idx: vocabulary indices of ``teacher_vals`` (same shape); used to
            gather the student's log-probs at the teacher's top-k tokens.
        input_ids: the token ids the student was run on; ``input_ids[:, 1:]`` are
            the next-token targets.
        kd_weight: multiplier applied to the KL term.

    Returns:
        ``(loss, kl, ce)`` -- ``loss`` is a differentiable tensor, ``kl`` and
        ``ce`` are Python floats for logging.
    """
    # One float32 log-softmax over the vocabulary serves both terms below.
    logp = torch.log_softmax(student_logits.float(), dim=-1)

    student_topk = torch.gather(logp, -1, teacher_idx)  # [B,T,k]
    teacher_p = torch.softmax(teacher_vals, dim=-1)  # renormalized over k
    cross_term = -(teacher_p * student_topk).sum(-1).mean()
    neg_teacher_entropy = (teacher_p * teacher_p.clamp_min(1e-9).log()).sum(-1).mean()
    kl = cross_term + neg_teacher_entropy

    # Mean next-token NLL read straight out of `logp`. This equals the default
    # F.cross_entropy on the raw logits, but avoids a second full-vocab
    # log-softmax and a reshaped float copy of the logits.
    next_token_logp = torch.gather(logp[:, :-1], -1, input_ids[:, 1:].unsqueeze(-1))
    ce = -next_token_logp.mean()
    return ce + kd_weight * kl, kl.item(), ce.item()
# --------------------------------------------------------------------------- #
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line interface for the RRT recovery driver."""
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="poolside/Laguna-XS.2")
    p.add_argument("--tokenizer", default=DEFAULT_TOKENIZER)
    p.add_argument("--tiny", action="store_true", help="CPU plumbing run on the tiny random model")
    p.add_argument("--device", default="cuda")
    p.add_argument("--dtype", default="bfloat16")
    p.add_argument("--dataset", default="bigcode/the-stack-smol",
                   help="HF code dataset. Known: bigcode/the-stack-smol (gated, smallest, ~10k Py rows), "
                        "codeparrot/github-code (no gating, needs trust_remote_code, large), "
                        "bigcode/starcoderdata (gated, large).")
    p.add_argument("--lang", default="python", help="language to segment to (default python)")
    p.add_argument("--tie-layers", default="18,19,20,21",
                   help="comma-separated sparse layers; adjacent-paired (a,b),(c,d),...")
    p.add_argument("--rank", type=int, default=16)
    p.add_argument("--router-rank", type=int, default=None, help="LoRA rank on the router (default = --rank)")
    p.add_argument("--init", default="lower", choices=["lower", "average"])
    p.add_argument("--seq-len", type=int, default=1024)
    p.add_argument("--tokens", type=int, default=50_000_000)
    p.add_argument("--steps", type=int, default=2000)
    p.add_argument("--batch", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--topk", type=int, default=64)
    p.add_argument("--kd-weight", type=float, default=1.0)
    p.add_argument("--unfreeze-shared", action="store_true",
                   help="fallback: also train the shared expert/router base")
    p.add_argument("--grad-checkpoint", action=argparse.BooleanOptionalAction, default=True,
                   help="gradient checkpointing (recompute activations in backward) — needed to fit training")
    p.add_argument("--reference", action="store_true",
                   help="matched reference: LoRA the SAME layers without tying, train CE-only (no KD). "
                        "Its final ppl is the ceiling the tied run is compared against.")
    p.add_argument("--dry-run", action="store_true",
                   help="pull the first few dataset rows, print field/token counts, and exit (no model load)")
    p.add_argument("--outdir", default="results_rrt", help="directory for results JSON + checkpoints")
    p.add_argument("--run-name", default=None, help="run name (defaults to a config-derived slug)")
    p.add_argument("--eval-every", type=int, default=250, help="eval held-out ppl every N steps (0=off)")
    p.add_argument("--save-lora", action="store_true", help="save LoRA checkpoint at each eval + at end")
    return p


def main() -> None:
    """Run one recovery experiment end to end and persist its results as JSON.

    Stages: parse args -> (optional dry run) -> load model -> build blocks and
    split off the first ~10% as held-out -> baseline perplexity -> teacher
    top-k precompute (tied mode only) -> tie layers or add reference LoRA ->
    perplexity at init -> training loop with periodic eval -> final eval and
    summary. The results dict is rewritten atomically after every milestone so
    a crash keeps whatever had been reached.

    Raises:
        SystemExit: if the corpus is too small to leave any training block.
    """
    args = _build_arg_parser().parse_args()
    device = "cpu" if args.tiny else args.device

    # Data-loader validation: tokenizer only, no model.
    if args.dry_run:
        if args.tiny:
            print("--dry-run is for real datasets; --tiny uses synthetic data.")
            return
        from transformers import AutoTokenizer
        tok = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True, fix_mistral_regex=True)
        dry_run_data(tok, dataset=args.dataset, lang=args.lang)
        return

    if args.tiny:
        tok = None  # synthetic data: build_blocks never touches the tokenizer
        model = build_tiny_model(num_layers=8)
    else:
        model, tok = load_model_and_tokenizer(args.model, args.tokenizer, args.dtype, device)
    vocab = model.config.vocab_size
    # Baseline, teacher targets and init metrics must come from eval mode so any
    # train-only behaviour (e.g. dropout) cannot perturb them.
    model.eval()

    blocks = build_blocks(tok, tiny=args.tiny, vocab=vocab, seq_len=args.seq_len,
                          n_tokens=args.tokens, dataset=args.dataset, lang=args.lang)
    # simple train/held-out split
    n_eval = max(1, len(blocks) // 10)
    eval_blocks, train_blocks = blocks[:n_eval], blocks[n_eval:]
    if len(train_blocks) == 0:
        # With nothing to train on, the `while step < args.steps` loop below would
        # never advance `step` and would spin forever.
        raise SystemExit(
            f"error: corpus has only {len(blocks)} block(s) of {args.seq_len} tokens, all of which "
            "went to the held-out split; raise --tokens or lower --seq-len."
        )
    print(f"corpus: {len(train_blocks)} train / {len(eval_blocks)} eval blocks of {args.seq_len}")

    # Result record, persisted after every milestone (atomic write) so a crash keeps
    # whatever we'd reached.
    mode = "reference" if args.reference else "tied"
    layers = [int(x) for x in args.tie_layers.split(",")]
    slug = "-".join(map(str, layers))
    run_name = args.run_name or (
        ("ref_" if args.reference else "tied_") + f"L{slug}_r{args.rank}"
        + ("_unfz" if (args.unfreeze_shared and not args.reference) else "")
    )
    outdir = Path(args.outdir)
    json_path = outdir / f"{run_name}.json"
    results = {
        "run_name": run_name, "mode": mode, "config": vars(args), "layers": layers,
        "corpus": {"train_blocks": len(train_blocks), "eval_blocks": len(eval_blocks), "seq_len": args.seq_len},
        "status": "running",
        "baseline_ppl": None, "init_ppl": None, "final_ppl": None,
        "agreement_init": None, "agreement_final": None,  # tied only
        "params": None,
        "curve": [],  # [{step, train_loss, kl, ce, eval_ppl, agreement}]
    }
    save_json(json_path, results)
    print(f"[{mode}] saving results to {json_path}", flush=True)

    # Baseline = the untied model's held-out ppl, before any modification.
    print(f"[1/3] baseline perplexity over {len(eval_blocks)} eval blocks...", flush=True)
    base_ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
    print(f" baseline ppl = {base_ppl:.3f}", flush=True)
    results["baseline_ppl"] = base_ppl
    save_json(json_path, results)

    # Teacher targets (tied mode only): top-k on train (for KD) + top-1 on eval (agreement).
    t_vals = t_idx = teacher_top1_eval = None
    if mode == "tied":
        print(f"[2/3] precomputing teacher top-{args.topk} logits over {len(train_blocks)} train blocks...", flush=True)
        t_vals, t_idx = precompute_teacher_topk(model, train_blocks, k=args.topk, batch=args.batch, device=device)
        te_vals, te_idx = precompute_teacher_topk(model, eval_blocks, k=1, batch=args.batch, device=device)
        teacher_top1_eval = te_idx[..., 0]
    else:
        print("[2/3] reference run (CE-only, no teacher) — skipping precompute.", flush=True)

    # Modify the model: tie (CE+KD) or add LoRA to own banks (reference, CE-only).
    before = parameter_report(model)["total_unique"]
    if mode == "tied":
        # NOTE(review): --router-rank is only forwarded in reference mode; confirm
        # whether TieConfig should receive it too.
        cfg = TieConfig(pairs=adjacent_pairs(layers), rank=args.rank, init=args.init, lora_init="svd")
        tie_model(model, cfg, keep_for_untie=False)  # never unties -> free originals
        desc = f"tied {cfg.pairs} init={cfg.init}"
    else:
        add_lora_adapters(model, layers, rank=args.rank, router_rank=args.router_rank, lora_init="zero")
        desc = f"reference (LoRA on {layers}, no tie)"
    rep = parameter_report(model)
    set_param_efficient(model, unfreeze_shared=(args.unfreeze_shared and mode == "tied"))
    if not args.tiny and args.grad_checkpoint:
        model.config.use_cache = False
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    if not args.tiny:
        torch.cuda.empty_cache()

    init_ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
    results["init_ppl"] = init_ppl
    if mode == "tied":
        results["agreement_init"] = top1_agreement(model, eval_blocks, teacher_top1_eval, batch=args.batch, device=device)
    results["params"] = {"baseline_unique": before, "after_unique": rep["total_unique"],
                         "pct_smaller": 100 * (1 - rep["total_unique"] / before),
                         "trainable": sum(p.numel() for p in trainable_parameters(model))}
    save_json(json_path, results)
    print(f"{desc} rank={args.rank}: {before:,} -> {rep['total_unique']:,} params "
          f"({results['params']['pct_smaller']:.1f}% smaller), trainable {results['params']['trainable']:,}", flush=True)
    if mode == "tied":
        print(f" init: ppl {init_ppl:.3f} top1-agreement {results['agreement_init']:.1%}", flush=True)

    def record(step, loss, kl, ce):
        """Evaluate on the held-out blocks, append a curve point, and persist."""
        model.eval()
        ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
        agr = (top1_agreement(model, eval_blocks, teacher_top1_eval, batch=args.batch, device=device)
               if mode == "tied" else None)
        model.train()
        results["curve"].append({"step": step, "train_loss": loss, "kl": kl, "ce": ce,
                                 "eval_ppl": ppl, "agreement": agr})
        # Keep the "final" fields current so a crash still leaves the latest numbers.
        results["final_ppl"] = ppl
        results["agreement_final"] = agr
        save_json(json_path, results)
        if args.save_lora:
            save_lora(outdir / f"{run_name}_lora.pt", model)
        extra = f" top1 {agr:.1%}" if agr is not None else ""
        print(f" [eval] step {step:>5} eval_ppl {ppl:.3f}{extra} (saved)", flush=True)

    # Param-efficient training: tied -> CE+KD, reference -> CE only.
    opt = torch.optim.AdamW(trainable_parameters(model), lr=args.lr)
    model.train()
    step = 0
    t_train, tokens_seen = time.time(), 0
    # Outer loop = epochs over train_blocks; the inner loop stops once --steps is hit.
    while step < args.steps:
        for i in range(0, len(train_blocks), args.batch):
            if step >= args.steps:
                break
            sl = slice(i, i + args.batch)
            b = train_blocks[sl].to(device)
            logits = model(input_ids=b, use_cache=False).logits
            if mode == "tied":
                # Teacher targets were precomputed in train_blocks order, so the same slice aligns them.
                loss, kl, ce = topk_kd_loss(logits, t_vals[sl].to(device), t_idx[sl].to(device), b, args.kd_weight)
            else:
                ce_t = F.cross_entropy(logits[:, :-1].reshape(-1, logits.shape[-1]).float(), b[:, 1:].reshape(-1))
                loss, kl, ce = ce_t, 0.0, ce_t.item()
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable_parameters(model), 1.0)
            opt.step()
            tokens_seen += b.numel()
            if step == 0:
                print(f"[3/3] {'CE+KD' if mode == 'tied' else 'CE-only'} training for {args.steps} steps...", flush=True)
            # Print early (for the batch-size/fit probe) and then every 50 steps.
            if step < 5 or step % 50 == 0:
                tps = tokens_seen / max(time.time() - t_train, 1e-6)
                print(f"step {step:>5} loss {loss.item():.4f} KL {float(kl):.4f} CE {ce:.4f} ({tps:,.0f} tok/s)", flush=True)
            if args.eval_every and step > 0 and step % args.eval_every == 0:
                record(step, loss.item(), float(kl), ce)
            step += 1

    model.eval()
    final_ppl = perplexity(model, eval_blocks, batch=args.batch, device=device)
    results["final_ppl"] = final_ppl
    if mode == "tied":
        results["agreement_final"] = top1_agreement(model, eval_blocks, teacher_top1_eval, batch=args.batch, device=device)
    results["status"] = "done"
    save_json(json_path, results)
    if args.save_lora:
        save_lora(outdir / f"{run_name}_lora.pt", model)

    print(f"\n=== {mode} result (held-out Python) ===")
    print(f"baseline (untied) : {base_ppl:.3f}")
    print(f"after-mod @ init : {init_ppl:.3f}")
    print(f"after training : {final_ppl:.3f} <-- compare tied vs reference (the recovery gap)")
    if mode == "tied":
        print(f"top1 agreement : init {results['agreement_init']:.1%} -> final {results['agreement_final']:.1%}")
    print(f"results -> {json_path}")


if __name__ == "__main__":
    main()