lta / LTA_openwebtext_dualt /scripts /eval_train8_decode_acc.py
JinghuiLuAstronaut's picture
Add files using upload-large-folder tool
36dad47 verified
Raw
History Blame Contribute Delete
10.4 kB
from __future__ import annotations
import argparse
import json
import re
import sys
from pathlib import Path
import numpy as np
import torch
# Make the repository root and its scripts/ directory importable so the imports
# that follow resolve when this file is run directly.  Each directory is
# prepended only if it is not already on sys.path; scripts/ is processed last,
# so when both get inserted it ends up ahead of the repository root.
REPO_ROOT = Path(__file__).resolve().parents[1]
SCRIPTS_DIR = REPO_ROOT / "scripts"
for _import_dir in (REPO_ROOT, SCRIPTS_DIR):
    if str(_import_dir) not in sys.path:
        sys.path.insert(0, str(_import_dir))
del _import_dir
from infer_context_compare_from_c128 import build_model, decode # noqa: E402
from flowtext_lab.tokenization import BpeTextTokenizer # noqa: E402
def ckpt_step(path: Path) -> int:
    """Return the training step encoded in a checkpoint filename.

    Only ``path.name`` is inspected.  A name ending in ``step_<digits>.pt``
    yields those digits as an int; any other name yields -1.
    """
    found = re.search(r"step_(\d+)\.pt$", path.name)
    return int(found.group(1)) if found is not None else -1
def select_ckpts(run_dir: Path, *, latest_only: bool, step_stride: int) -> list[Path]:
    """Pick which ``step_*.pt`` checkpoints in ``run_dir`` to evaluate.

    Checkpoints are ordered by ``ckpt_step`` (names without a numeric step map
    to -1 there and so sort first).

    Args:
        run_dir: Directory searched (non-recursively) for ``step_*.pt`` files.
        latest_only: Return only the highest-step checkpoint.
        step_stride: When positive, keep only checkpoints whose step is a
            multiple of it, always adding the highest-step checkpoint.  When
            zero or negative, every checkpoint is returned.

    Returns:
        The selected checkpoint paths in ascending step order; empty if none
        exist.
    """
    ordered = sorted(run_dir.glob("step_*.pt"), key=ckpt_step)
    if not ordered:
        return []
    newest = ordered[-1]
    if latest_only:
        return [newest]
    if step_stride <= 0:
        return ordered
    chosen = {p for p in ordered if ckpt_step(p) % step_stride == 0}
    chosen.add(newest)  # the final checkpoint is always evaluated
    return sorted(chosen, key=ckpt_step)
def load_refs(data_dir: Path, max_len: int) -> np.ndarray:
    """Load the reference token chunks stored under ``data_dir``.

    ``meta.json`` supplies the chunk count (``num_chunks``, else ``n_chunks``)
    and ``chunks.i32.bin`` holds the chunks back to back as a flat int32
    stream.  When the metadata gives a positive count, the stream is split into
    that many equal rows.  Otherwise, it is split into consecutive
    ``max_len``-token rows and any trailing partial row is dropped.  Each row
    is then truncated to its first ``max_len`` tokens.

    Args:
        data_dir: Directory containing ``meta.json`` and ``chunks.i32.bin``.
        max_len: Number of leading tokens kept per row.  It is also the row
            length used when the metadata has no chunk count.

    Returns:
        A freshly copied int32 array with one reference per row that does not
        keep the memory-mapped file alive.

    Raises:
        ValueError: if the stream cannot be split into the declared number of
            equal rows, or holds fewer than ``max_len`` tokens when no count is
            declared.
    """
    meta = json.loads((data_dir / "meta.json").read_text())
    n = int(meta.get("num_chunks", meta.get("n_chunks", 0)))
    bin_path = data_dir / "chunks.i32.bin"
    tokens = np.asarray(np.memmap(bin_path, dtype=np.int32, mode="r"))
    if n > 0:
        if tokens.size % n:
            raise ValueError(
                f"{bin_path} holds {tokens.size} tokens, which cannot be split "
                f"into {n} equal chunks"
            )
        rows = tokens.reshape(n, -1)
    else:
        # No chunk count in the metadata: fall back to fixed max_len windows.
        # Drop the trailing remainder first; reshape(n, -1) on the full stream
        # would silently widen every row (misaligning refs) or fail outright.
        n = tokens.size // max_len
        if n <= 0:
            raise ValueError(
                f"{bin_path} holds {tokens.size} tokens, fewer than max_len={max_len}"
            )
        rows = tokens[: n * max_len].reshape(n, max_len)
    return rows[:, :max_len].copy()
def token_match_metrics(ids: list[list[int]], refs: np.ndarray) -> dict[str, object]:
    """Score each generated sequence against its best-matching reference.

    For every generated row, the fraction of positions whose token equals the
    reference token is computed against every reference row.  The reference
    with the highest agreement is that row's match; on ties, the first one
    wins.  If generated and reference lengths differ, both are truncated to
    the shorter length first.

    Args:
        ids: Generated token ids, one equal-length sequence per sample.
        refs: 2D integer array of reference sequences, one per row.

    Returns:
        A JSON-serialisable dict with:
        - per-sample best matches (``best_ref_idx``, ``best_token_acc``) and
          their mean/min/max;
        - the fraction and count of samples that reproduce some reference
          exactly;
        - the sorted distinct references reproduced exactly
          (``exact_ref_hits``), with their count and the fraction of
          references covered.

    Raises:
        ValueError: if ``ids`` is not 2D or ``refs`` has no rows.
    """
    gen = np.asarray(ids, dtype=np.int32)
    if gen.ndim != 2:
        raise ValueError(f"expected 2D generated ids, got {gen.shape}")
    if refs.shape[0] == 0:
        raise ValueError("need at least one reference sequence to match against")
    if gen.shape[1] != refs.shape[1]:
        n = min(gen.shape[1], refs.shape[1])
        gen = gen[:, :n]
        refs = refs[:, :n]
    # matches[i, j] = fraction of positions where sample i equals reference j.
    # Filled one generated row at a time so the temporary boolean comparison
    # is (n_refs, L) instead of a fully broadcast (n_gen, n_refs, L) array.
    matches = np.empty((gen.shape[0], refs.shape[0]), dtype=np.float64)
    for i, row in enumerate(gen):
        matches[i] = (refs == row).mean(axis=1)
    best_idx = matches.argmax(axis=1)
    best_acc = matches[np.arange(matches.shape[0]), best_idx]
    exact = best_acc >= 1.0
    exact_ref_hits = sorted(set(best_idx[exact].astype(int).tolist()))
    return {
        "n_gen": int(gen.shape[0]),
        "n_refs": int(refs.shape[0]),
        "token_acc_mean": float(best_acc.mean()),
        "token_acc_min": float(best_acc.min()),
        "token_acc_max": float(best_acc.max()),
        "exact_acc": float(exact.mean()),
        "exact_count": int(exact.sum()),
        "exact_ref_coverage": float(len(exact_ref_hits) / max(refs.shape[0], 1)),
        "exact_ref_count": int(len(exact_ref_hits)),
        "exact_ref_hits": exact_ref_hits,
        "best_ref_idx": best_idx.astype(int).tolist(),
        "best_token_acc": best_acc.astype(float).tolist(),
    }
@torch.inference_mode()
def eval_one(
    ckpt_path: Path,
    tokenizer: BpeTextTokenizer,
    refs: np.ndarray,
    args: argparse.Namespace,
    endpoint_softening: str,
    device: torch.device,
) -> dict[str, object]:
    """Decode samples from one checkpoint and score them against ``refs``.

    Loads ``ckpt_path`` and builds the model via ``build_model``, then runs
    ``decode`` with the CLI settings in ``args`` plus the given
    ``endpoint_softening`` mode.  The decoded ids are scored with
    ``token_match_metrics``.

    Returns:
        A record identifying the run, checkpoint, step and main decode
        settings, followed by every metric from ``token_match_metrics``.
    """
    # CLI options forwarded to ``decode`` unchanged under the same keyword name.
    passthrough = (
        "max_len",
        "n_samples",
        "batch_size",
        "steps",
        "decode_rule",
        "support_power",
        "semantic_power",
        "early_temp",
        "late_temp",
        "temp_end",
        "temp_power",
        "hybrid_switch",
        "tail_temp",
        "c_min",
        "c_max",
        "model_t_mode",
        "time_schedule",
        "time_logit_mean",
        "time_logit_std",
        "time_power",
        "input_noise_scale",
        "input_noise_until",
        "input_noise_dirichlet_concentration",
        "endpoint_soft_power",
        "endpoint_soft_min_conf",
        "endpoint_soft_max_conf",
        "final_from",
        "final_decode",
        "final_sample_temp",
        "final_top_k",
        "final_top_p",
    )
    step = ckpt_step(ckpt_path)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only point
    # this script at checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False, mmap=True)
    model = build_model(checkpoint, tokenizer, args.max_len, device, args.pos_extend)
    decode_kwargs = {name: getattr(args, name) for name in passthrough}
    ids, _texts, _traces = decode(
        model,
        tokenizer,
        **decode_kwargs,
        # The seed depends on the checkpoint step and max_len but not on the
        # softening mode, so every softening variant of one checkpoint is
        # decoded with the same seed.
        seed=args.seed + step + args.max_len,
        device=device,
        endpoint_softening=endpoint_softening,
        eps=1e-8,
        fixed_first_token_id=None,
        fixed_first_initial_argmax=False,
    )
    metrics = token_match_metrics(ids, refs)
    del model
    if device.type == "cuda":
        torch.cuda.empty_cache()
    header = {
        "run": ckpt_path.parent.name,
        "checkpoint": str(ckpt_path),
        "ckpt_step": step,
        "endpoint_softening": endpoint_softening,
        "decode_rule": args.decode_rule,
        "steps": args.steps,
        "time_schedule": args.time_schedule,
        "model_t_mode": args.model_t_mode,
        "final_from": args.final_from,
    }
    return {**header, **metrics}
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser; flag names and defaults are this script's interface."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--runs_glob", default="runs/train8_n1024_*")
    ap.add_argument("--data_dir", required=True)
    ap.add_argument("--tokenizer_path", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--max_len", type=int, default=1024)
    ap.add_argument("--n_samples", type=int, default=64)
    ap.add_argument("--batch_size", type=int, default=2)
    ap.add_argument("--latest_only", action="store_true")
    ap.add_argument("--step_stride", type=int, default=100)
    ap.add_argument("--endpoint_softenings", default="none")
    ap.add_argument("--steps", type=int, default=128)
    ap.add_argument("--decode_rule", default="flowmap")
    ap.add_argument("--support_power", type=float, default=1.0)
    ap.add_argument("--semantic_power", type=float, default=1.0)
    ap.add_argument("--early_temp", type=float, default=1.0)
    ap.add_argument("--late_temp", type=float, default=1.0)
    ap.add_argument("--temp_end", type=float, default=1.0)
    ap.add_argument("--temp_power", type=float, default=1.0)
    ap.add_argument("--hybrid_switch", type=float, default=0.5)
    ap.add_argument("--tail_temp", type=float, default=-1.0)
    ap.add_argument("--c_min", type=float, default=1.0)
    ap.add_argument("--c_max", type=float, default=512.0)
    ap.add_argument("--model_t_mode", default="post")
    ap.add_argument("--time_schedule", default="logit_normal")
    ap.add_argument("--time_logit_mean", type=float, default=-1.5)
    ap.add_argument("--time_logit_std", type=float, default=0.8)
    ap.add_argument("--time_power", type=float, default=2.0)
    ap.add_argument("--input_noise_scale", type=float, default=0.0)
    ap.add_argument("--input_noise_until", type=float, default=1.0)
    ap.add_argument("--input_noise_dirichlet_concentration", type=float, default=1.0)
    ap.add_argument("--endpoint_soft_power", type=float, default=1.0)
    ap.add_argument("--endpoint_soft_min_conf", type=float, default=0.0)
    ap.add_argument("--endpoint_soft_max_conf", type=float, default=1.0)
    ap.add_argument("--final_from", default="state")
    ap.add_argument("--final_decode", default="argmax")
    ap.add_argument("--final_sample_temp", type=float, default=1.0)
    ap.add_argument("--final_top_k", type=int, default=0)
    ap.add_argument("--final_top_p", type=float, default=1.0)
    ap.add_argument("--pos_extend", default="repeat")
    ap.add_argument("--seed", type=int, default=20260517)
    return ap


def _write_tsv(rows: list[dict[str, object]], path: Path) -> None:
    """Write the headline metric columns of ``rows`` to ``path`` as a TSV table."""
    fields = [
        "run",
        "ckpt_step",
        "endpoint_softening",
        "token_acc_mean",
        "token_acc_min",
        "token_acc_max",
        "exact_acc",
        "exact_count",
        "exact_ref_coverage",
        "exact_ref_count",
    ]
    with path.open("w", encoding="utf-8") as f:
        f.write("\t".join(fields) + "\n")
        for r in rows:
            f.write("\t".join(str(r[k]) for k in fields) + "\n")


def _summarize(rows: list[dict[str, object]]) -> dict[str, object]:
    """Summarise ``rows`` per ``run::endpoint_softening`` key.

    ``best_by_run`` keeps the row with the highest ``token_acc_mean``; ties
    keep the earlier row.  ``first_exact_by_run`` keeps the first row, in
    ``rows`` order, whose ``exact_acc`` is positive.  Keys with no such row are
    absent.
    """
    best_by_run: dict[str, dict[str, object]] = {}
    first_exact_by_run: dict[str, dict[str, object]] = {}
    for r in rows:
        key = f"{r['run']}::{r['endpoint_softening']}"
        if key not in best_by_run or float(r["token_acc_mean"]) > float(best_by_run[key]["token_acc_mean"]):
            best_by_run[key] = r
        if float(r["exact_acc"]) > 0 and key not in first_exact_by_run:
            first_exact_by_run[key] = r
    return {
        "num_rows": len(rows),
        "best_by_run": best_by_run,
        "first_exact_by_run": first_exact_by_run,
    }


def main() -> None:
    """Evaluate decode token accuracy for every selected checkpoint of every matching run.

    Run directories are found by globbing ``--runs_glob`` relative to the
    current working directory.  Each (checkpoint, endpoint-softening) record is
    appended to ``decode_token_acc.jsonl`` as soon as it finishes.  The TSV
    table and the JSON summary are written once all evaluations are done.  The
    summary is also printed.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()
    endpoint_softenings = [x.strip() for x in args.endpoint_softenings.split(",") if x.strip()]
    if not endpoint_softenings:
        # Otherwise nothing would be evaluated and empty outputs would be written silently.
        ap.error("--endpoint_softenings must name at least one mode")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BpeTextTokenizer.from_file(args.tokenizer_path)
    refs = load_refs(Path(args.data_dir), args.max_len)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Remove outputs from a previous invocation; the JSONL file is appended to below.
    for stale in ["decode_token_acc.jsonl", "decode_token_acc.tsv", "decode_token_acc_summary.json"]:
        path = out_dir / stale
        if path.exists():
            path.unlink()

    run_dirs = sorted(Path(".").glob(args.runs_glob))
    if not run_dirs:
        print(
            f"[eval-decode-acc] warning: no paths match --runs_glob {args.runs_glob!r}",
            file=sys.stderr,
            flush=True,
        )
    rows: list[dict[str, object]] = []
    for run_dir in run_dirs:
        ckpts = select_ckpts(run_dir, latest_only=args.latest_only, step_stride=args.step_stride)
        for ckpt_path in ckpts:
            for soft in endpoint_softenings:
                print(f"[eval-decode-acc] {run_dir.name} step={ckpt_step(ckpt_path)} soft={soft}", flush=True)
                rec = eval_one(ckpt_path, tokenizer, refs, args, soft, device)
                rows.append(rec)
                # Append each record immediately so a crash mid-sweep keeps finished results.
                with (out_dir / "decode_token_acc.jsonl").open("a", encoding="utf-8") as f:
                    f.write(json.dumps(rec, ensure_ascii=False) + "\n")

    _write_tsv(rows, out_dir / "decode_token_acc.tsv")
    summary = _summarize(rows)
    (out_dir / "decode_token_acc_summary.json").write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(summary, ensure_ascii=False, indent=2), flush=True)


if __name__ == "__main__":
    main()