| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import re |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
# Repository root = parent of the directory containing this file; its
# ``scripts`` subdirectory holds sibling helper scripts.
REPO_ROOT = Path(__file__).resolve().parents[1]
SCRIPTS_DIR = REPO_ROOT / "scripts"
# Make both directories importable for the project imports that follow.
# Each missing entry is pushed to the front of sys.path in turn, so
# SCRIPTS_DIR ends up ahead of REPO_ROOT when both were absent.
for _import_root in (REPO_ROOT, SCRIPTS_DIR):
    if str(_import_root) not in sys.path:
        sys.path.insert(0, str(_import_root))
del _import_root
|
|
| from infer_context_compare_from_c128 import build_model, decode |
| from flowtext_lab.tokenization import BpeTextTokenizer |
|
|
|
|
def ckpt_step(path: Path) -> int:
    """Return the training step encoded in a checkpoint file name.

    Only ``path.name`` is inspected: a name ending in ``step_<digits>.pt``
    yields ``<digits>`` as an int; any other name yields -1.
    """
    found = re.search(r"step_(\d+)\.pt$", path.name)
    return -1 if found is None else int(found.group(1))
|
|
|
|
def select_ckpts(run_dir: Path, *, latest_only: bool, step_stride: int) -> list[Path]:
    """Choose which ``step_*.pt`` checkpoints in ``run_dir`` to evaluate.

    Checkpoints are ordered by the step parsed by ``ckpt_step`` (files whose
    step cannot be parsed get -1 and therefore sort first).

    Returns:
        * ``[]`` when the directory has no matching checkpoints;
        * only the last checkpoint when ``latest_only`` is set;
        * when ``step_stride > 0``, those whose step is a multiple of the
          stride, plus the last checkpoint, sorted by step;
        * otherwise every checkpoint.
    """
    ordered = sorted(run_dir.glob("step_*.pt"), key=ckpt_step)
    if not ordered:
        return []
    newest = ordered[-1]
    if latest_only:
        return [newest]
    if step_stride <= 0:
        return ordered
    chosen = {p for p in ordered if ckpt_step(p) % step_stride == 0}
    # The final checkpoint is always evaluated, even off-stride.
    chosen.add(newest)
    return sorted(chosen, key=ckpt_step)
|
|
|
|
def load_refs(data_dir: Path, max_len: int) -> np.ndarray:
    """Load the reference token chunks, truncated to ``max_len`` tokens each.

    Reads ``data_dir/meta.json`` for the chunk count (``num_chunks``, falling
    back to ``n_chunks``) and ``data_dir/chunks.i32.bin`` as a flat int32
    token stream.

    * With a positive chunk count, the stream is split into that many equal
      rows, and each row is cut to its first ``max_len`` tokens.
    * Without one, the chunks are assumed to be exactly ``max_len`` tokens
      long. The stream is split at that stride and a trailing partial chunk
      is dropped.

    Returns:
        An in-memory (not memory-mapped) int32 array of shape
        ``(n_chunks, <= max_len)``.

    Raises:
        ValueError: if the chunk count does not evenly divide the token
            stream, or the file holds fewer than ``max_len`` tokens when no
            count is given.
    """
    meta = json.loads((data_dir / "meta.json").read_text(encoding="utf-8"))
    # ``or 0`` also treats an explicit ``null`` count as "unknown".
    n_chunks = int(meta.get("num_chunks", meta.get("n_chunks", 0)) or 0)
    flat = np.asarray(np.memmap(data_dir / "chunks.i32.bin", dtype=np.int32, mode="r"))
    n_tokens = int(flat.shape[0])
    if n_chunks > 0:
        if n_tokens % n_chunks:
            raise ValueError(
                f"{n_tokens} tokens in {data_dir / 'chunks.i32.bin'} cannot be split into {n_chunks} equal chunks"
            )
        chunks = flat.reshape(n_chunks, -1)
    else:
        # No count in meta.json: use an exact max_len stride. A free
        # ``reshape(n, -1)`` could silently pick a different row length and
        # misalign every chunk.
        n_chunks = n_tokens // max_len
        if n_chunks <= 0:
            raise ValueError(
                f"{data_dir / 'chunks.i32.bin'} holds {n_tokens} tokens, fewer than max_len={max_len}"
            )
        chunks = flat[: n_chunks * max_len].reshape(n_chunks, max_len)
    # Copy so the result owns its memory instead of viewing the mapped file.
    return chunks[:, :max_len].copy()
|
|
|
|
def token_match_metrics(ids: list[list[int]], refs: np.ndarray) -> dict[str, object]:
    """Score generated token sequences by their closest reference sequence.

    For each generated row, the per-reference token accuracy is the fraction
    of positions whose tokens are equal. The best reference is the one with
    the highest accuracy; ties go to the lowest reference index. If the
    generated and reference widths differ, both are cut to the shorter
    width first.

    Args:
        ids: Generated token ids, one equal-length list per sample.
        refs: 2D array of reference token ids, one row per reference.

    Returns:
        A dict with:
        * ``n_gen`` / ``n_refs``: row counts.
        * ``token_acc_mean`` / ``_min`` / ``_max``: statistics of the best
          accuracies.
        * ``exact_acc`` / ``exact_count``: share and count of samples whose
          best accuracy is 1.0.
        * ``exact_ref_hits``: sorted distinct reference indices matched
          exactly; ``exact_ref_count`` is their number and
          ``exact_ref_coverage`` that number divided by ``n_refs``.
        * ``best_ref_idx`` / ``best_token_acc``: per-sample best reference
          index and accuracy.

    Raises:
        ValueError: if ``ids`` is not 2D, or ``refs`` is not a non-empty 2D
            array.
    """
    gen = np.asarray(ids, dtype=np.int32)
    if gen.ndim != 2:
        raise ValueError(f"expected 2D generated ids, got {gen.shape}")
    refs = np.asarray(refs)
    if refs.ndim != 2 or refs.shape[0] == 0:
        raise ValueError(f"expected non-empty 2D reference ids, got {refs.shape}")
    width = min(gen.shape[1], refs.shape[1])
    gen = gen[:, :width]
    refs = refs[:, :width]

    n_gen = gen.shape[0]
    best_idx = np.zeros(n_gen, dtype=np.int64)
    best_acc = np.zeros(n_gen, dtype=np.float64)
    # Compare one generated row at a time. A fully broadcast comparison would
    # allocate an (n_gen, n_refs, width) boolean array, which grows with the
    # size of the whole reference set.
    for i, row in enumerate(gen):
        acc = (refs == row).mean(axis=1)
        best_idx[i] = acc.argmax()
        best_acc[i] = acc[best_idx[i]]

    exact = best_acc >= 1.0
    exact_ref_hits = sorted(set(best_idx[exact].astype(int).tolist()))
    return {
        "n_gen": int(n_gen),
        "n_refs": int(refs.shape[0]),
        "token_acc_mean": float(best_acc.mean()),
        "token_acc_min": float(best_acc.min()),
        "token_acc_max": float(best_acc.max()),
        "exact_acc": float(exact.mean()),
        "exact_count": int(exact.sum()),
        "exact_ref_coverage": float(len(exact_ref_hits) / max(refs.shape[0], 1)),
        "exact_ref_count": int(len(exact_ref_hits)),
        "exact_ref_hits": exact_ref_hits,
        "best_ref_idx": best_idx.astype(int).tolist(),
        "best_token_acc": best_acc.astype(float).tolist(),
    }
|
|
|
|
@torch.inference_mode()
def eval_one(
    ckpt_path: Path,
    tokenizer: BpeTextTokenizer,
    refs: np.ndarray,
    args: argparse.Namespace,
    endpoint_softening: str,
    device: torch.device,
) -> dict[str, object]:
    """Decode samples from one checkpoint and score them against ``refs``.

    Steps:
    1. Load the checkpoint onto CPU (``mmap=True``).
    2. Build the model with ``build_model``.
    3. Run ``decode`` with the CLI decode settings from ``args``.
    4. Score the generated ids with ``token_match_metrics``.

    The decode seed is ``args.seed + <checkpoint step> + args.max_len``. It
    differs between checkpoints but is shared by every endpoint-softening
    variant of the same checkpoint.

    Returns:
        A flat record: run name, checkpoint path and step, softening mode,
        the main decode settings, then every metric key.
    """
    # CLI options passed to ``decode`` unchanged, under the same keyword name.
    forwarded_options = (
        "max_len",
        "n_samples",
        "batch_size",
        "steps",
        "decode_rule",
        "support_power",
        "semantic_power",
        "early_temp",
        "late_temp",
        "temp_end",
        "temp_power",
        "hybrid_switch",
        "tail_temp",
        "c_min",
        "c_max",
        "model_t_mode",
        "time_schedule",
        "time_logit_mean",
        "time_logit_std",
        "time_power",
        "input_noise_scale",
        "input_noise_until",
        "input_noise_dirichlet_concentration",
        "endpoint_soft_power",
        "endpoint_soft_min_conf",
        "endpoint_soft_max_conf",
        "final_from",
        "final_decode",
        "final_sample_temp",
        "final_top_k",
        "final_top_p",
    )
    # NOTE(review): weights_only=False unpickles arbitrary objects; only point
    # this at trusted, locally produced checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False, mmap=True)
    model = build_model(checkpoint, tokenizer, args.max_len, device, args.pos_extend)
    step = ckpt_step(ckpt_path)
    decode_kwargs = {name: getattr(args, name) for name in forwarded_options}
    generated, _texts, _traces = decode(
        model,
        tokenizer,
        seed=args.seed + step + args.max_len,
        device=device,
        endpoint_softening=endpoint_softening,
        eps=1e-8,
        fixed_first_token_id=None,
        fixed_first_initial_argmax=False,
        **decode_kwargs,
    )
    metrics = token_match_metrics(generated, refs)

    # Drop the model reference before releasing cached CUDA blocks.
    del model
    if device.type == "cuda":
        torch.cuda.empty_cache()

    record: dict[str, object] = {
        "run": ckpt_path.parent.name,
        "checkpoint": str(ckpt_path),
        "ckpt_step": step,
        "endpoint_softening": endpoint_softening,
        "decode_rule": args.decode_rule,
        "steps": args.steps,
        "time_schedule": args.time_schedule,
        "model_t_mode": args.model_t_mode,
        "final_from": args.final_from,
    }
    record.update(metrics)
    return record
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the checkpoint decode-accuracy sweep."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--runs_glob", default="runs/train8_n1024_*")
    ap.add_argument("--data_dir", required=True)
    ap.add_argument("--tokenizer_path", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--max_len", type=int, default=1024)
    ap.add_argument("--n_samples", type=int, default=64)
    ap.add_argument("--batch_size", type=int, default=2)
    ap.add_argument("--latest_only", action="store_true")
    ap.add_argument("--step_stride", type=int, default=100)
    ap.add_argument("--endpoint_softenings", default="none")
    ap.add_argument("--steps", type=int, default=128)
    ap.add_argument("--decode_rule", default="flowmap")
    ap.add_argument("--support_power", type=float, default=1.0)
    ap.add_argument("--semantic_power", type=float, default=1.0)
    ap.add_argument("--early_temp", type=float, default=1.0)
    ap.add_argument("--late_temp", type=float, default=1.0)
    ap.add_argument("--temp_end", type=float, default=1.0)
    ap.add_argument("--temp_power", type=float, default=1.0)
    ap.add_argument("--hybrid_switch", type=float, default=0.5)
    ap.add_argument("--tail_temp", type=float, default=-1.0)
    ap.add_argument("--c_min", type=float, default=1.0)
    ap.add_argument("--c_max", type=float, default=512.0)
    ap.add_argument("--model_t_mode", default="post")
    ap.add_argument("--time_schedule", default="logit_normal")
    ap.add_argument("--time_logit_mean", type=float, default=-1.5)
    ap.add_argument("--time_logit_std", type=float, default=0.8)
    ap.add_argument("--time_power", type=float, default=2.0)
    ap.add_argument("--input_noise_scale", type=float, default=0.0)
    ap.add_argument("--input_noise_until", type=float, default=1.0)
    ap.add_argument("--input_noise_dirichlet_concentration", type=float, default=1.0)
    ap.add_argument("--endpoint_soft_power", type=float, default=1.0)
    ap.add_argument("--endpoint_soft_min_conf", type=float, default=0.0)
    ap.add_argument("--endpoint_soft_max_conf", type=float, default=1.0)
    ap.add_argument("--final_from", default="state")
    ap.add_argument("--final_decode", default="argmax")
    ap.add_argument("--final_sample_temp", type=float, default=1.0)
    ap.add_argument("--final_top_k", type=int, default=0)
    ap.add_argument("--final_top_p", type=float, default=1.0)
    ap.add_argument("--pos_extend", default="repeat")
    ap.add_argument("--seed", type=int, default=20260517)
    return ap


def _write_tsv(path: Path, rows: list[dict[str, object]]) -> None:
    """Write the headline metric columns of ``rows`` as a tab-separated table."""
    fields = [
        "run",
        "ckpt_step",
        "endpoint_softening",
        "token_acc_mean",
        "token_acc_min",
        "token_acc_max",
        "exact_acc",
        "exact_count",
        "exact_ref_coverage",
        "exact_ref_count",
    ]
    with path.open("w", encoding="utf-8") as f:
        f.write("\t".join(fields) + "\n")
        for r in rows:
            f.write("\t".join(str(r[k]) for k in fields) + "\n")


def _summarize(rows: list[dict[str, object]]) -> dict[str, object]:
    """Summarize rows per ``run::endpoint_softening`` key.

    Returns a dict with three entries:
    * ``num_rows``: the number of rows.
    * ``best_by_run``: per key, the row with the highest
      ``token_acc_mean``; the earliest row wins ties.
    * ``first_exact_by_run``: per key, the first row (in ``rows`` order)
      with ``exact_acc > 0``, or ``None`` if the key never had one.
    """
    best_by_run: dict[str, dict[str, object]] = {}
    first_exact_by_run: dict[str, dict[str, object] | None] = {}
    for r in rows:
        key = f"{r['run']}::{r['endpoint_softening']}"
        best = best_by_run.get(key)
        if best is None or float(r["token_acc_mean"]) > float(best["token_acc_mean"]):
            best_by_run[key] = r
        # Register every key, so runs with no exact match show up as null
        # instead of being silently absent from the summary.
        first_exact_by_run.setdefault(key, None)
        if first_exact_by_run[key] is None and float(r["exact_acc"]) > 0:
            first_exact_by_run[key] = r
    return {
        "num_rows": len(rows),
        "best_by_run": best_by_run,
        "first_exact_by_run": first_exact_by_run,
    }


def main() -> None:
    """Evaluate decode token accuracy for every selected checkpoint of every matching run.

    Outputs written to ``--out_dir``:
    * ``decode_token_acc.jsonl``: one record per (checkpoint, softening),
      appended as each evaluation finishes.
    * ``decode_token_acc.tsv``: headline metric columns for all records.
    * ``decode_token_acc_summary.json``: best and first-exact rows per
      ``run::softening``.

    Any of these files left by a previous invocation are deleted first.
    """
    args = _build_arg_parser().parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    refs = load_refs(Path(args.data_dir), args.max_len)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    jsonl_path = out_dir / "decode_token_acc.jsonl"
    for stale in ["decode_token_acc.jsonl", "decode_token_acc.tsv", "decode_token_acc_summary.json"]:
        # The JSONL is appended to below, so leftovers must be removed first.
        (out_dir / stale).unlink(missing_ok=True)

    # Note: the glob is resolved relative to the current working directory.
    run_dirs = sorted(Path(".").glob(args.runs_glob))
    if not run_dirs:
        print(f"[eval-decode-acc] warning: no paths match {args.runs_glob!r}", file=sys.stderr, flush=True)
    endpoint_softenings = [x.strip() for x in args.endpoint_softenings.split(",") if x.strip()]
    rows: list[dict[str, object]] = []
    for run_dir in run_dirs:
        for ckpt_path in select_ckpts(run_dir, latest_only=args.latest_only, step_stride=args.step_stride):
            for soft in endpoint_softenings:
                print(f"[eval-decode-acc] {run_dir.name} step={ckpt_step(ckpt_path)} soft={soft}", flush=True)
                rec = eval_one(ckpt_path, tokenizer, refs, args, soft, device)
                rows.append(rec)
                # Append immediately so completed results survive a later crash.
                with jsonl_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(rec, ensure_ascii=False) + "\n")

    _write_tsv(out_dir / "decode_token_acc.tsv", rows)
    summary = _summarize(rows)
    (out_dir / "decode_token_acc_summary.json").write_text(
        json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    print(json.dumps(summary, ensure_ascii=False, indent=2), flush=True)
|
|
|
|
# Run the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|