from __future__ import annotations import argparse import json import re import sys from pathlib import Path import numpy as np import torch REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) SCRIPTS_DIR = REPO_ROOT / "scripts" if str(SCRIPTS_DIR) not in sys.path: sys.path.insert(0, str(SCRIPTS_DIR)) from infer_context_compare_from_c128 import build_model, decode # noqa: E402 from flowtext_lab.tokenization import BpeTextTokenizer # noqa: E402 def ckpt_step(path: Path) -> int: m = re.search(r"step_(\d+)\.pt$", path.name) if not m: return -1 return int(m.group(1)) def select_ckpts(run_dir: Path, *, latest_only: bool, step_stride: int) -> list[Path]: ckpts = sorted(run_dir.glob("step_*.pt"), key=ckpt_step) if not ckpts: return [] if latest_only: return [ckpts[-1]] if step_stride > 0: picked = [p for p in ckpts if ckpt_step(p) % step_stride == 0] if ckpts[-1] not in picked: picked.append(ckpts[-1]) return sorted(set(picked), key=ckpt_step) return ckpts def load_refs(data_dir: Path, max_len: int) -> np.ndarray: meta = json.loads((data_dir / "meta.json").read_text()) n = int(meta.get("num_chunks", meta.get("n_chunks", 0))) if n <= 0: size = (data_dir / "chunks.i32.bin").stat().st_size // np.dtype(np.int32).itemsize n = size // max_len arr = np.memmap(data_dir / "chunks.i32.bin", dtype=np.int32, mode="r") arr = np.asarray(arr).reshape(n, -1) return arr[:, :max_len].copy() def token_match_metrics(ids: list[list[int]], refs: np.ndarray) -> dict[str, object]: gen = np.asarray(ids, dtype=np.int32) if gen.ndim != 2: raise ValueError(f"expected 2D generated ids, got {gen.shape}") if gen.shape[1] != refs.shape[1]: n = min(gen.shape[1], refs.shape[1]) gen = gen[:, :n] refs = refs[:, :n] matches = (gen[:, None, :] == refs[None, :, :]).mean(axis=2) best_idx = matches.argmax(axis=1) best_acc = matches[np.arange(matches.shape[0]), best_idx] exact = best_acc >= 1.0 exact_ref_hits = sorted(set(best_idx[exact].astype(int).tolist())) return { "n_gen": int(gen.shape[0]), "n_refs": int(refs.shape[0]), "token_acc_mean": float(best_acc.mean()), "token_acc_min": float(best_acc.min()), "token_acc_max": float(best_acc.max()), "exact_acc": float(exact.mean()), "exact_count": int(exact.sum()), "exact_ref_coverage": float(len(exact_ref_hits) / max(refs.shape[0], 1)), "exact_ref_count": int(len(exact_ref_hits)), "exact_ref_hits": exact_ref_hits, "best_ref_idx": best_idx.astype(int).tolist(), "best_token_acc": best_acc.astype(float).tolist(), } @torch.inference_mode() def eval_one( ckpt_path: Path, tokenizer: BpeTextTokenizer, refs: np.ndarray, args: argparse.Namespace, endpoint_softening: str, device: torch.device, ) -> dict[str, object]: ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False, mmap=True) model = build_model(ckpt, tokenizer, args.max_len, device, args.pos_extend) ids, _texts, _traces = decode( model, tokenizer, max_len=args.max_len, n_samples=args.n_samples, batch_size=args.batch_size, steps=args.steps, seed=args.seed + ckpt_step(ckpt_path) + args.max_len, device=device, decode_rule=args.decode_rule, support_power=args.support_power, semantic_power=args.semantic_power, early_temp=args.early_temp, late_temp=args.late_temp, temp_end=args.temp_end, temp_power=args.temp_power, hybrid_switch=args.hybrid_switch, tail_temp=args.tail_temp, c_min=args.c_min, c_max=args.c_max, model_t_mode=args.model_t_mode, time_schedule=args.time_schedule, time_logit_mean=args.time_logit_mean, time_logit_std=args.time_logit_std, time_power=args.time_power, input_noise_scale=args.input_noise_scale, input_noise_until=args.input_noise_until, input_noise_dirichlet_concentration=args.input_noise_dirichlet_concentration, endpoint_softening=endpoint_softening, endpoint_soft_power=args.endpoint_soft_power, endpoint_soft_min_conf=args.endpoint_soft_min_conf, endpoint_soft_max_conf=args.endpoint_soft_max_conf, final_from=args.final_from, final_decode=args.final_decode, final_sample_temp=args.final_sample_temp, final_top_k=args.final_top_k, final_top_p=args.final_top_p, eps=1e-8, fixed_first_token_id=None, fixed_first_initial_argmax=False, ) metrics = token_match_metrics(ids, refs) del model if device.type == "cuda": torch.cuda.empty_cache() return { "run": ckpt_path.parent.name, "checkpoint": str(ckpt_path), "ckpt_step": ckpt_step(ckpt_path), "endpoint_softening": endpoint_softening, "decode_rule": args.decode_rule, "steps": args.steps, "time_schedule": args.time_schedule, "model_t_mode": args.model_t_mode, "final_from": args.final_from, **metrics, } def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--runs_glob", default="runs/train8_n1024_*") ap.add_argument("--data_dir", required=True) ap.add_argument("--tokenizer_path", required=True) ap.add_argument("--out_dir", required=True) ap.add_argument("--max_len", type=int, default=1024) ap.add_argument("--n_samples", type=int, default=64) ap.add_argument("--batch_size", type=int, default=2) ap.add_argument("--latest_only", action="store_true") ap.add_argument("--step_stride", type=int, default=100) ap.add_argument("--endpoint_softenings", default="none") ap.add_argument("--steps", type=int, default=128) ap.add_argument("--decode_rule", default="flowmap") ap.add_argument("--support_power", type=float, default=1.0) ap.add_argument("--semantic_power", type=float, default=1.0) ap.add_argument("--early_temp", type=float, default=1.0) ap.add_argument("--late_temp", type=float, default=1.0) ap.add_argument("--temp_end", type=float, default=1.0) ap.add_argument("--temp_power", type=float, default=1.0) ap.add_argument("--hybrid_switch", type=float, default=0.5) ap.add_argument("--tail_temp", type=float, default=-1.0) ap.add_argument("--c_min", type=float, default=1.0) ap.add_argument("--c_max", type=float, default=512.0) ap.add_argument("--model_t_mode", default="post") ap.add_argument("--time_schedule", default="logit_normal") ap.add_argument("--time_logit_mean", type=float, default=-1.5) ap.add_argument("--time_logit_std", type=float, default=0.8) ap.add_argument("--time_power", type=float, default=2.0) ap.add_argument("--input_noise_scale", type=float, default=0.0) ap.add_argument("--input_noise_until", type=float, default=1.0) ap.add_argument("--input_noise_dirichlet_concentration", type=float, default=1.0) ap.add_argument("--endpoint_soft_power", type=float, default=1.0) ap.add_argument("--endpoint_soft_min_conf", type=float, default=0.0) ap.add_argument("--endpoint_soft_max_conf", type=float, default=1.0) ap.add_argument("--final_from", default="state") ap.add_argument("--final_decode", default="argmax") ap.add_argument("--final_sample_temp", type=float, default=1.0) ap.add_argument("--final_top_k", type=int, default=0) ap.add_argument("--final_top_p", type=float, default=1.0) ap.add_argument("--pos_extend", default="repeat") ap.add_argument("--seed", type=int, default=20260517) args = ap.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path) refs = load_refs(Path(args.data_dir), args.max_len) out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) for stale in ["decode_token_acc.jsonl", "decode_token_acc.tsv", "decode_token_acc_summary.json"]: path = out_dir / stale if path.exists(): path.unlink() run_dirs = sorted(Path(".").glob(args.runs_glob)) endpoint_softenings = [x.strip() for x in args.endpoint_softenings.split(",") if x.strip()] rows: list[dict[str, object]] = [] for run_dir in run_dirs: ckpts = select_ckpts(run_dir, latest_only=args.latest_only, step_stride=args.step_stride) for ckpt_path in ckpts: for soft in endpoint_softenings: print(f"[eval-decode-acc] {run_dir.name} step={ckpt_step(ckpt_path)} soft={soft}", flush=True) rec = eval_one(ckpt_path, tokenizer, refs, args, soft, device) rows.append(rec) with (out_dir / "decode_token_acc.jsonl").open("a", encoding="utf-8") as f: f.write(json.dumps(rec, ensure_ascii=False) + "\n") fields = [ "run", "ckpt_step", "endpoint_softening", "token_acc_mean", "token_acc_min", "token_acc_max", "exact_acc", "exact_count", "exact_ref_coverage", "exact_ref_count", ] with (out_dir / "decode_token_acc.tsv").open("w", encoding="utf-8") as f: f.write("\t".join(fields) + "\n") for r in rows: f.write("\t".join(str(r[k]) for k in fields) + "\n") best_by_run: dict[str, dict[str, object]] = {} first_exact_by_run: dict[str, dict[str, object] | None] = {} for r in rows: key = f"{r['run']}::{r['endpoint_softening']}" if key not in best_by_run or float(r["token_acc_mean"]) > float(best_by_run[key]["token_acc_mean"]): best_by_run[key] = r if float(r["exact_acc"]) > 0 and key not in first_exact_by_run: first_exact_by_run[key] = r summary = { "num_rows": len(rows), "best_by_run": best_by_run, "first_exact_by_run": first_exact_by_run, } (out_dir / "decode_token_acc_summary.json").write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") print(json.dumps(summary, ensure_ascii=False, indent=2), flush=True) if __name__ == "__main__": main()