# looped-laguna / scripts/tf_run_eval.py
# Author: e-p
# Commit: cd009b8 — "refactor: tf prefix"
"""Tier-1 evaluation driver: baseline vs looped Laguna on knowledge-MC tasks.
Runs a matrix of loop configs (baseline / layer / block / naive) over one or more
multiple-choice tasks and reports acc, acc_norm, and the delta vs baseline (pp).
Real run (GPU):
uv run python scripts/tf_run_eval.py --model poolside/Laguna-XS.2 \
--tasks arc_challenge,openbookqa --dtype bfloat16 --device cuda
Local plumbing check (CPU, tiny random model + REAL tokenizer + REAL dataset):
uv run python scripts/tf_run_eval.py --tiny --tasks arc_challenge --limit 20
The --tiny path exercises the entire real code path (tokenizer, dataset loading,
scoring, config matrix) against a tiny random-weight model; metrics are meaningless
but it proves the GPU run is turn-key.
"""
from __future__ import annotations
import argparse
import json
import sys
import time
from datetime import datetime
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
from looped_laguna import LoopConfig, load_model_and_tokenizer
from looped_laguna.eval import DATASET_LOADERS, format_results, run_matrix
# The tokenizer is vendored as `laguna_src` in the directory two levels above this file
# (i.e. the repo root when this script lives in scripts/).
DEFAULT_TOKENIZER = str(Path(__file__).resolve().parent.parent.joinpath("laguna_src"))
def build_configs(
    num_layers: int,
    ks: list[int],
    width: int,
    center: float,
    controls: bool = True,
) -> dict[str, LoopConfig | None]:
    """Build the config matrix, named `{mode}-{strategy}-K{k}` (mode x strategy is a clean 2x2).

    strategy: rk    = damped K-stage Runge-Kutta (the proposed method)
              naive = undamped x<-g(x) repeated K times (control, expected to collapse)
    mode:     layer = iterate each window layer in place (required for MoE)
              block = iterate the whole window as a unit (control, routing thrash)

    We sweep K for the proposed config (layer-rk). With `controls=True` we also add the
    other three corners of the 2x2 at K_max (block-rk isolates mode; layer-naive isolates
    strategy; block-naive is the canonical "naive looped transformer"). Pass controls=False
    for a clean K-sweep (baseline + layer-rk only) that doesn't duplicate the controls a
    breadth run already covers.

    Args:
        num_layers: number of hidden layers; together with `center` and `width` it is passed
            to `LoopConfig.from_depth_fraction` to choose the loop window.
        ks: loop counts K to sweep for layer-rk. Must be non-empty. Duplicate values collapse
            into a single entry because configs are keyed by name.
        width: loop window width in layers.
        center: loop window center as a depth fraction.
        controls: also add block-rk, layer-naive and block-naive at K_max = max(ks).

    Returns:
        Insertion-ordered mapping of config name -> LoopConfig. It starts with
        "baseline" -> None (the un-looped model). Every LoopConfig shares the same window.

    Raises:
        ValueError: if `ks` is empty. There would be no looped config to compare against
            baseline, and `max(ks)` for the controls would fail with an opaque error.
    """
    # Fail early with a clear message. Without this guard, controls=True dies in max([]),
    # and controls=False returns a baseline-only matrix that has no loop window.
    if not ks:
        raise ValueError("build_configs: ks must contain at least one loop count K")
    window = LoopConfig.from_depth_fraction(num_layers, center_frac=center, width=width).window
    configs: dict[str, LoopConfig | None] = {"baseline": None}
    for k in ks:
        configs[f"layer-rk-K{k}"] = LoopConfig(window=window, K=k, mode="layer")
    if controls:
        # The controls only need the strongest loop count to show the contrast.
        k_max = max(ks)
        configs[f"block-rk-K{k_max}"] = LoopConfig(window=window, K=k_max, mode="block")
        configs[f"layer-naive-K{k_max}"] = LoopConfig(window=window, K=k_max, mode="layer", naive=True)
        configs[f"block-naive-K{k_max}"] = LoopConfig(window=window, K=k_max, mode="block", naive=True)
    return configs
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; flag names, defaults and help text form the script's interface."""
    p = argparse.ArgumentParser(description="Tier-1 eval: baseline vs looped Laguna on knowledge-MC tasks.")
    p.add_argument("--model", default="poolside/Laguna-XS.2",
                   help="HF model id or local path of the real model (ignored with --tiny)")
    p.add_argument("--tokenizer", default=DEFAULT_TOKENIZER,
                   help="tokenizer path/id (default: the vendored laguna_src)")
    p.add_argument("--tasks", default="arc_challenge",
                   help="comma-separated task names (see DATASET_LOADERS in eval.py)")
    p.add_argument("--limit", type=int, default=None, help="max examples per task (default: all)")
    p.add_argument("--ks", default="2,3", help="comma-separated loop counts K to sweep for layer-rk")
    p.add_argument("--no-controls", dest="controls", action="store_false",
                   help="only run baseline + layer-rk per K (skip block/naive controls), e.g. for a K-sweep")
    p.add_argument("--width", type=int, default=4, help="loop window width in layers")
    p.add_argument("--center", type=float, default=0.5, help="loop window center as a depth fraction (0-1)")
    p.add_argument("--dtype", default="bfloat16", help="model dtype: bfloat16/float16/float32")
    p.add_argument("--device", default="cuda", help="device or device_map: cuda/cpu/auto")
    p.add_argument("--batch-size", type=int, default=16,
                   help="padded-batch size for scoring (bump to 32-64 on big GPUs; lower for long tasks)")
    p.add_argument("--tiny", action="store_true",
                   help="use a tiny random-weight model instead of real weights (plumbing check)")
    p.add_argument("--tiny-layers", type=int, default=8, help="number of layers for the --tiny model")
    p.add_argument("--output", default=None,
                   help="results JSON path (default: auto-named results_eval_<timestamp>.json)")
    p.add_argument("--no-save", action="store_true", help="do not write results to disk")
    p.add_argument("--no-peritem", dest="peritem", action="store_false",
                   help="don't save per-item raw data (per-choice LLs/gold/subject) for re-aggregation")
    return p


def _split_csv(raw: str) -> list[str]:
    """Split a comma-separated CLI value, stripping whitespace and dropping empty items."""
    return [item.strip() for item in raw.split(",") if item.strip()]


def _write_json_atomic(path: Path, payload: dict) -> None:
    """Write `payload` as indented JSON to `path` via a sibling temp file plus rename.

    `Path.replace` is an atomic rename on POSIX within a single filesystem. A crash or
    Ctrl-C during the write therefore leaves the previous complete checkpoint in place
    instead of a truncated file.
    """
    tmp = path.with_name(path.name + ".tmp")
    tmp.write_text(json.dumps(payload, indent=2))
    tmp.replace(path)


def main(argv: list[str] | None = None) -> None:
    """Run the Tier-1 eval matrix from the command line.

    Parses flags from `argv` (or from sys.argv when None) and loads the model. It then
    builds the config matrix with `build_configs`, scores every config on every task with
    `run_matrix`, and prints a per-task summary.

    Unless --no-save is given, results are checkpointed to JSON after every finished config.
    A per-item JSONL file is also written unless --no-peritem is given.

    Raises:
        SystemExit: on invalid/empty --ks or --tasks values, or an unknown task name.
    """
    p = _build_parser()
    args = p.parse_args(argv)

    # Validate the list flags before the (slow) model load. Blank items from stray or
    # trailing commas are ignored, instead of crashing in int("") or "unknown task ''".
    try:
        ks = [int(x) for x in _split_csv(args.ks)]
    except ValueError:
        p.error(f"--ks must be comma-separated integers, got {args.ks!r}")
    if not ks:
        p.error("--ks must list at least one loop count")
    # Dedupe while keeping order: a repeated task would otherwise be re-run and its
    # per-item lines appended twice.
    tasks = list(dict.fromkeys(_split_csv(args.tasks)))
    if not tasks:
        p.error("--tasks must list at least one task")
    for t in tasks:
        if t not in DATASET_LOADERS:
            raise SystemExit(f"unknown task {t!r}; known: {sorted(DATASET_LOADERS)}")

    # Always persist (incrementally) unless --no-save, so a crash never loses finished work.
    save_path = None if args.no_save else (
        args.output or f"results_eval_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
    )
    if save_path:
        # Create the output directory now. Otherwise a missing directory would only
        # surface once the first config has finished and its checkpoint is written.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    # Per-item raw data (per-choice LLs + gold/subject/lengths) for later re-aggregation.
    peritem_path = str(Path(save_path).with_suffix("")) + "_peritem.jsonl" if (save_path and args.peritem) else None

    model, tokenizer = load_model_and_tokenizer(
        args.model, args.tokenizer, dtype=args.dtype, device=args.device,
        tiny=args.tiny, tiny_layers=args.tiny_layers,
    )
    num_layers = model.config.num_hidden_layers
    configs = build_configs(num_layers, ks, args.width, args.center, controls=args.controls)
    # All looped configs share one window. ks is non-empty, so at least one layer-rk config exists.
    window = next(c.window for c in configs.values() if c is not None)
    print(f"model: {'tiny' if args.tiny else args.model} | layers={num_layers} | loop window={window}")
    print(f"configs: {list(configs)}")

    meta = {"model": "tiny" if args.tiny else args.model, "num_layers": num_layers,
            "window": list(window), "ks": ks, "tasks": tasks, "limit": args.limit}
    all_results: dict = {}

    def save() -> None:
        # Checkpoint everything gathered so far; atomic, so a prior good file is never clobbered.
        if save_path:
            _write_json_atomic(Path(save_path), {"meta": meta, "results": all_results})

    def record_result(task: str, name: str, result) -> None:
        # run_matrix callback: store one finished config's result and checkpoint right away.
        all_results[task][name] = result
        save()

    if save_path:
        print(f"saving results to {save_path} (updated after every config)")
    if peritem_path:
        Path(peritem_path).write_text("")  # truncate; run_matrix appends one line per task
        print(f"saving per-item data to {peritem_path}")

    for task in tasks:
        examples = DATASET_LOADERS[task](limit=args.limit)
        print(f"\n=== {task} ({len(examples)} items) ===")
        all_results[task] = {}  # partial configs land here as they finish
        t0 = time.perf_counter()
        results = run_matrix(
            model, tokenizer, examples, configs, batch_size=args.batch_size,
            peritem_path=peritem_path, task_name=task,
            on_result=lambda name, r, t=task: record_result(t, name, r),
        )
        all_results[task] = results  # full results, now with deltas + significance
        save()
        for line in format_results(results):
            print(line)
        print(f" ({time.perf_counter() - t0:.1f}s)")
    if save_path:
        print(f"\nresults saved to {save_path}")
if __name__ == "__main__":
    # Only start an eval run when executed as a script; main() returns None, so this exits with status 0.
    raise SystemExit(main())