""" longbench_eval.py — LongBench v1 evaluation for AttnVQ and baselines. Stages: cheap proxy KV metrics on task contexts (fast, default) generate end-to-end task scoring via model.generate (slow) Tasks: qasper, 2wikimqa, hotpotqa, passage_retrieval_en, repobench-p Usage: python longbench_eval.py --stage cheap --n_eval 50 python longbench_eval.py --stage generate --tasks hotpotqa qasper """ from __future__ import annotations import argparse import collections import json import os import re import string from difflib import SequenceMatcher import torch from tqdm import tqdm from turbo_benchmark import TurboQuantMSE, QJLResidualIP, turbo_configs # noqa: F401 — unpickle from vqkv.quantizers import KIVIScalarKV # noqa: F401 from vqkv.metrics import (key_cosine, cache_mse, inner_product_distortion, attention_output, attn_output_cosine, attn_output_error) MODEL_ID = os.environ.get("LAGUNA_ID", "poolside/Laguna-XS.2") ARTIFACT_DIR = os.environ.get("VQKV_ARTIFACTS", "./artifacts") ATTN_WIN = 512 CHEAP_COLS = ("key_cos", "val_cos", "key_mse", "val_mse", "attn_cos", "attn_output_error", "ip_rel", "ip_bias") # Scorers (referenced from TASK_CONFIGS) def _normalize(s: str) -> str: s = s.lower() s = s.translate(str.maketrans("", "", string.punctuation)) return " ".join(s.split()) def qa_f1_score(prediction: str, ground_truths: list[str], **_) -> float: """Max token-level F1 over all reference answers (QA tasks).""" pred_toks = _normalize(prediction).split() best = 0.0 for ref in ground_truths: ref_toks = _normalize(ref).split() if not pred_toks or not ref_toks: best = max(best, float(pred_toks == ref_toks)) continue common = collections.Counter(pred_toks) & collections.Counter(ref_toks) n_common = sum(common.values()) if n_common == 0: continue p = n_common / len(pred_toks) r = n_common / len(ref_toks) best = max(best, 2 * p * r / (p + r)) return best def retrieval_score(prediction: str, ground_truths: list[str], **_) -> float: """Passage-retrieval accuracy: extract digit from 
prediction, match gold. Ground truths are strings like 'Paragraph 3'; we extract the number and check whether it appears among the first 10 digits in the prediction. Mirrors the LongBench retrieval_score implementation. """ gt_ids = set() for ref in ground_truths: m = re.findall(r"\d+", ref) gt_ids.add(m[0] if m else _normalize(ref)) pred_nums = re.findall(r"\d+", prediction) if not pred_nums: return 0.0 right = sum(1 for n in pred_nums[:10] if n in gt_ids) return right / len(pred_nums[:10]) def classification_score(prediction: str, ground_truths: list[str], all_classes: list[str] | None = None, **_) -> float: """Classification accuracy (TREC). If all_classes is provided (from the dataset example), restrict matches to valid class labels to avoid spurious substring hits — matches the LongBench classification_score behaviour. """ if all_classes: matched = [c for c in all_classes if c.lower() in prediction.lower()] return float(any(ref.lower() in [m.lower() for m in matched] for ref in ground_truths)) # Fallback: normalised substring match pred_norm = _normalize(prediction) return float(any(_normalize(ref) in pred_norm for ref in ground_truths)) def edit_similarity_score(prediction: str, ground_truths: list[str], **_) -> float: """Character-level edit similarity for code completion (RepoBench-P). Uses difflib.SequenceMatcher.ratio() — equivalent to fuzz.ratio() (not partial_ratio) and zero-dep. """ pred = prediction.strip() best = 0.0 for ref in ground_truths: ref_s = ref.strip() if not pred and not ref_s: best = 1.0 elif pred and ref_s: best = max(best, SequenceMatcher(None, pred, ref_s).ratio()) return best # ============================================================================ # Task configs # ============================================================================ # QA tasks: increasing context length (~3.6K → ~18K tokens) directly probes # the compounding-error-vs-length axis from RUNBOOK Step 6. 
# Two extra tasks cover retrieval accuracy and code completion.
# ============================================================================

# Multi-document QA prompt shared verbatim by qasper / 2wikimqa / hotpotqa.
_QA_PROMPT = (
    "Answer the question based on the given passages. "
    "Only give me the answer and do not output any other words.\n\n"
    "The following are given passages.\n{context}\n\n"
    "Answer the question based on the given passages. "
    "Only give me the answer and do not output any other words.\n\n"
    "Question: {input}\nAnswer:"
)

# Per task: scorer, generation budget, prompt truncation length (tokens) and
# the template filled with the example's ``context`` / ``input`` fields.
TASK_CONFIGS: dict[str, dict] = {
    "qasper": dict(
        score_fn=qa_f1_score,
        max_new_tokens=20,
        max_len=16384,
        prompt_template=_QA_PROMPT,
    ),
    "2wikimqa": dict(
        score_fn=qa_f1_score,
        max_new_tokens=10,
        max_len=16384,
        prompt_template=_QA_PROMPT,
    ),
    "hotpotqa": dict(
        score_fn=qa_f1_score,
        max_new_tokens=32,
        max_len=32768,
        prompt_template=_QA_PROMPT,
    ),
    "passage_retrieval_en": dict(
        score_fn=retrieval_score,
        max_new_tokens=10,
        max_len=32768,
        prompt_template=(
            "Below is a record of a series of paragraphs, each from different "
            "documents. Tell me which paragraph the given passage is from.\n\n"
            "{context}\n\n"
            "Based on the above paragraphs, which paragraph does the following "
            "passage come from? Only output the paragraph number. "
            "Do not output any other characters.\n\n"
            "{input}\nAnswer:"
        ),
    ),
    "repobench-p": dict(
        score_fn=edit_similarity_score,
        max_new_tokens=64,
        max_len=16384,
        prompt_template=(
            "Please complete the code given below.\n{context}\n{input}\n"
        ),
    ),
}


# ============================================================================
# Small shared helpers
# ============================================================================

def _resolve_device(model):
    """Return the device to place inputs/quantizers on.

    ``model.device`` is used as-is unless it prints as "meta", in which case
    cuda:0 is chosen when CUDA is available and the CPU otherwise.
    """
    dev = model.device
    if str(dev) == "meta":
        dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return dev


def _is_per_channel(q) -> bool:
    """True when quantizer *q* must be applied one head at a time."""
    return bool(isinstance(q, KIVIScalarKV)
                or getattr(q, "per_channel_key", False))


def _turbo_bpe(bits, use_qjl, head_dim) -> float:
    """Bits/element for a TurboQuant config: code bits, +1 with QJL, +16/head_dim."""
    return round(bits + (1 if use_qjl else 0) + 16.0 / head_dim, 4)


# ============================================================================
# Model loading
# ============================================================================

def load_model_and_meta():
    """Load tokenizer + bf16 model for MODEL_ID on CUDA and derive attention meta.

    Returns:
        (model, tok, meta) where meta holds ``full_layers`` (indices whose
        ``config.layer_types`` entry is "full_attention"), ``n_kv_heads``,
        ``n_q_heads`` and ``head_dim`` read from the model config.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True,
                                        fix_mistral_regex=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, device_map="cuda",
        trust_remote_code=True)
    model.eval()

    cfg = model.config
    full_layers = [i for i, t in enumerate(cfg.layer_types)
                   if t == "full_attention"]
    meta = {
        "full_layers": full_layers,
        "n_kv_heads": cfg.num_key_value_heads,
        "n_q_heads": cfg.num_attention_heads,
        "head_dim": cfg.head_dim,
    }
    print(f"[meta] full-attention layers ({len(full_layers)}): {full_layers}")
    print(f"[meta] kv_heads={meta['n_kv_heads']} head_dim={meta['head_dim']}")
    return model, tok, meta


# ============================================================================
# Generic VQCache
# ============================================================================

def make_cache_class(per_layer_fns: dict, target_layers: list):
    """Build a DynamicCache subclass that round-trips K/V through the given fns.

    per_layer_fns[layer_idx] = {
        "k_fn": callable (N, d) -> (N, d),
        "v_fn": callable (N, d) -> (N, d),
        "per_channel": bool,  # True for KIVI/Sign/Ternary (key dim=0 reduction)
    }

    For non-per-channel quantizers, K and V are flattened to (s*h, d) before
    calling k_fn/v_fn.  For per-channel, each head's (s, d) block is passed
    separately so the quantizer's token-axis statistics are per-head (not
    cross-head), matching the KIVI / benchmark.py convention.

    Only layers present in both ``target_layers`` and ``per_layer_fns`` are
    quantized; every other layer is stored unchanged.

    Raises (from ``update``):
        ValueError: if the incoming states have batch size != 1; the
            round-trip only handles a single sequence and would otherwise
            silently discard the remaining batch items.
    """
    from transformers.cache_utils import DynamicCache

    _target = set(target_layers)

    def _roundtrip(states, fn, per_channel):
        """Quantize-dequantize one (1, h, s, d) tensor, keeping dtype/device."""
        _b, h, s, d = states.shape
        # (h, s, d) -> (s, h, d); quantizers run in float32.
        tokens_first = states[0].permute(1, 0, 2).float()
        if per_channel:
            hat = torch.stack([fn(tokens_first[:, hh, :]) for hh in range(h)],
                              dim=1)
        else:
            hat = fn(tokens_first.reshape(-1, d)).reshape(s, h, d)
        # (s, h, d) -> (h, s, d) -> (1, h, s, d)
        hat = hat.permute(1, 0, 2).unsqueeze(0)
        return hat.to(states.dtype).to(states.device)

    class VQCache(DynamicCache):
        def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
            if layer_idx in _target and layer_idx in per_layer_fns:
                if key_states.shape[0] != 1:
                    raise ValueError(
                        "VQCache supports batch size 1 only, got "
                        f"{key_states.shape[0]}")
                entry = per_layer_fns[layer_idx]
                per_channel = entry.get("per_channel", False)
                key_states = _roundtrip(key_states, entry["k_fn"], per_channel)
                value_states = _roundtrip(value_states, entry["v_fn"],
                                          per_channel)
            return super().update(key_states, value_states, layer_idx,
                                  cache_kwargs)

    return VQCache


# ============================================================================
# Build unified config list from both codebook files
# ============================================================================

def build_all_configs(meta: dict, device, only: list[str] | None = None):
    """Return list of (name, bpe, cache_cls_or_None).

    cache_cls_or_None: class (not instance) to call as cls() each generation,
    or None for the fp16 baseline (DynamicCache allocated internally by model).

    Args:
        meta: needs ``head_dim`` and ``full_layers``.
        device: device the fitted quantizers are moved to.
        only: optional config names to keep; unknown names produce a warning.
    """
    hd = meta["head_dim"]
    layers = meta["full_layers"]
    configs: list[tuple[str, float, type | None]] = []

    configs.append(("fp16", 16.0, None))

    # -- Regular codebooks (ProductVQ, RoPESplit, Scalar, KIVI, Sign, Ternary) --
    cb_path = os.path.join(ARTIFACT_DIR, "codebooks.pt")
    if os.path.exists(cb_path):
        # NOTE: weights_only=False unpickles arbitrary objects — only load
        # artifacts produced by a trusted run.
        blob = torch.load(cb_path, weights_only=False)
        for name, per_layer in blob["fitted"].items():
            for q in per_layer.values():
                if hasattr(q, "to"):
                    q.to(device)
            # bpe and per-channel mode are read from the first layer's quantizer.
            q0 = next(iter(per_layer.values()))
            bpe = round(q0.bits_per_element(hd), 4)
            per_channel = _is_per_channel(q0)
            per_layer_fns = {
                i: {"k_fn": per_layer[i].roundtrip_k,
                    "v_fn": per_layer[i].roundtrip_v,
                    "per_channel": per_channel}
                for i in layers if i in per_layer
            }
            configs.append((name, bpe, make_cache_class(per_layer_fns, layers)))
    else:
        print(f"[warn] {cb_path} not found — skipping ProductVQ / scalar configs")

    # -- TurboQuant codebooks --
    # load_turbo_fitted prints the "not found" warning and returns empty
    # dicts when the artifact is missing, so the loop is simply skipped.
    turbo_fitted, cfg_lookup = load_turbo_fitted(device)
    for name, per_layer in turbo_fitted.items():
        b, use_qjl = cfg_lookup[name]
        per_layer_fns = {
            i: {"k_fn": per_layer[i]["kq"].roundtrip,
                "v_fn": per_layer[i]["vq"].roundtrip,
                "per_channel": False}
            for i in layers if i in per_layer
        }
        configs.append((name, _turbo_bpe(b, use_qjl, hd),
                        make_cache_class(per_layer_fns, layers)))

    if only is not None:
        only_set = set(only)
        configs = [(n, b, c) for n, b, c in configs if n in only_set]
        unknown = only_set - {n for n, _, _ in configs}
        if unknown:
            print(f"[warn] unknown --configs names: {sorted(unknown)}")

    print(f"[build] {len(configs)} configs loaded")
    for n, b, _ in configs:
        print(f" {n:<32} {b:.3f} bpe")
    return configs


# ============================================================================
# Dataset / artifact loading + cheap proxy metrics (--stage cheap)
# ============================================================================

def _load_longbench(task_name: str):
    """Load a LongBench task, compatible with datasets >= 3.0.

    datasets >= 3.0 dropped custom dataset-script support and raises
    'Dataset scripts are no longer supported' for THUDM/LongBench.  We load
    the underlying JSONL files from the HF Hub directly instead, trying the
    ``_e``-suffixed file first and then the plain name, each with the task
    name as given and with hyphens replaced by underscores.  If every
    candidate fails, the script-based loader is tried last (older datasets
    versions); should that fail too, the per-candidate errors are printed
    and the final exception propagates.

    NOTE(review): in upstream LongBench the ``_e`` suffix appears to denote
    the LongBench-E subset rather than an English split — confirm that
    preferring it is intended.
    """
    from datasets import load_dataset as _ld

    slug = task_name.replace("-", "_")
    # dict.fromkeys de-duplicates while preserving priority order.
    candidates = dict.fromkeys((f"{task_name}_e.jsonl", f"{task_name}.jsonl",
                                f"{slug}_e.jsonl", f"{slug}.jsonl"))
    failures = []
    for fname in candidates:
        try:
            return _ld(
                "json",
                data_files=f"hf://datasets/THUDM/LongBench/data/{fname}",
                split="train",
            )
        except Exception as err:  # best-effort probe; reported below if needed
            failures.append((fname, err))

    # Fallback for older datasets versions that still support scripts
    try:
        return _ld("THUDM/LongBench", name=task_name, split="test")
    except Exception:
        for fname, err in failures:
            print(f"[warn] could not load {fname}: {type(err).__name__}: {err}")
        raise


def prompt_text(task_cfg: dict, example: dict) -> str:
    """Fill the task's prompt template with the example's context and input."""
    return task_cfg["prompt_template"].format(
        context=example["context"], input=example["input"])


def load_vqkv_fitted(device):
    """Load ``codebooks.pt`` and move every quantizer that has ``.to`` to *device*.

    Returns:
        (fitted, meta) as stored under the blob's "fitted" / "meta" keys.

    Raises:
        FileNotFoundError: if the artifact does not exist.
    """
    path = os.path.join(ARTIFACT_DIR, "codebooks.pt")
    if not os.path.exists(path):
        raise FileNotFoundError(f"{path} not found; run benchmark.py --stage fit")
    # NOTE: weights_only=False unpickles arbitrary objects — trusted files only.
    blob = torch.load(path, weights_only=False)
    fitted, meta = blob["fitted"], blob["meta"]
    for per_layer in fitted.values():
        for q in per_layer.values():
            if hasattr(q, "to"):
                q.to(device)
    return fitted, meta


def load_turbo_fitted(device):
    """Load ``turbo_codebooks.pt`` and move its quantizers to *device*.

    Returns:
        (fitted, cfg_lookup) where cfg_lookup maps config name to
        (bits, use_qjl) as produced by ``turbo_configs(blob["bits"])``.
        Both are empty dicts (after a warning) when the artifact is missing.
    """
    path = os.path.join(ARTIFACT_DIR, "turbo_codebooks.pt")
    if not os.path.exists(path):
        print(f"[warn] {path} not found — skipping TurboQuant configs")
        return {}, {}
    # NOTE: weights_only=False unpickles arbitrary objects — trusted files only.
    blob = torch.load(path, weights_only=False)
    fitted = blob["fitted"]
    cfg_lookup = {n: (b, use_qjl)
                  for n, b, use_qjl in turbo_configs(blob["bits"])}
    for per_layer in fitted.values():
        for entry in per_layer.values():
            entry["kq"].to(device)
            entry["vq"].to(device)
            if entry["qjl"] is not None:
                entry["qjl"].to(device)
    return fitted, cfg_lookup


def _aggregate_cheap_rows(trace_rows: list[dict], bpe_lookup: dict) -> list[dict]:
    """Average every CHEAP_COLS metric per config over its trace rows.

    Output rows follow the iteration order of *bpe_lookup*; configs without
    any trace row are omitted.  Means are rounded to 5 decimals.
    """
    agg = collections.defaultdict(lambda: collections.defaultdict(list))
    for r in trace_rows:
        for col in CHEAP_COLS:
            agg[r["config"]][col].append(r[col])

    summary = []
    for name in bpe_lookup:
        if name not in agg:
            continue
        cols = agg[name]
        n = len(cols["key_cos"])
        row = {"config": name, "bits_per_elt": bpe_lookup[name], "n_traces": n}
        for col in CHEAP_COLS:
            row[col] = round(sum(cols[col]) / n, 5)
        summary.append(row)
    return summary


def _print_cheap_table(task_name: str, summary: list[dict]):
    """Print one aligned line per summary row for *task_name*."""
    print(f"\n[cheap/{task_name}] mean metrics:")
    print(f" {'config':32s} {'bpe':>5} {'key_cos':>8} {'val_cos':>8} "
          f"{'key_mse':>9} {'val_mse':>9} {'attn_cos':>9} {'attn_err':>9} "
          f"{'ip_rel':>8} {'ip_bias':>9}")
    for row in summary:
        print(f" {row['config']:32s} {row['bits_per_elt']:5.2f} "
              f"{row['key_cos']:8.4f} {row['val_cos']:8.4f} "
              f"{row['key_mse']:9.5f} {row['val_mse']:9.5f} "
              f"{row['attn_cos']:9.4f} {row['attn_output_error']:9.4f} "
              f"{row['ip_rel']:8.5f} {row['ip_bias']:9.6f}")


def _accumulate_kv_fidelity(acc, k, v, k_hat, v_hat, q_syn, n_q):
    """Add reconstruction + synthetic-attention metrics for one layer to *acc*.

    Cosine/MSE use the whole cache; the attention metrics use only the last
    ``min(len, ATTN_WIN)`` positions.  Returns that window of the reference
    and reconstructed keys for the caller's inner-product metrics.
    """
    acc["key_cos"] += key_cosine(k, k_hat)
    acc["val_cos"] += key_cosine(v, v_hat)
    acc["key_mse"] += cache_mse(k, k_hat)
    acc["val_mse"] += cache_mse(v, v_hat)

    win = min(k.shape[0], ATTN_WIN)
    kw, kw_hat = k[-win:], k_hat[-win:]
    vw, vw_hat = v[-win:], v_hat[-win:]
    out_ref, _ = attention_output(q_syn, kw, vw, n_q)
    out_hat, _ = attention_output(q_syn, kw_hat, vw_hat, n_q)
    acc["attn_cos"] += attn_output_cosine(out_ref, out_hat)
    acc["attn_output_error"] += attn_output_error(out_ref, out_hat)
    return kw, kw_hat


def _trace_row(task_name, trace_len, name, acc, n_layers) -> dict:
    """Build one per-trace result row holding layer-averaged metrics."""
    return {
        "task": task_name,
        "trace_len": trace_len,
        "config": name,
        **{col: acc[col] / n_layers for col in CHEAP_COLS},
    }


def run_cheap_task(model, tok, meta, task_name: str, task_cfg: dict,
                   vqkv_fitted: dict, turbo_fitted: dict,
                   turbo_cfg_lookup: dict, bpe_lookup: dict,
                   config_filter: set[str] | None,
                   n_eval: int, min_len: int = 2048) -> list[dict]:
    """Dump fp16 caches on n_eval LongBench prompts; score all quantizer configs.

    For each of the first ``n_eval`` examples whose tokenized prompt has at
    least ``min_len`` tokens (after truncation to the task's ``max_len``),
    one forward pass captures K/V of every full-attention layer; each config
    then round-trips those tensors and the CHEAP_COLS metrics are averaged
    over layers.  Layers a config has no quantizer for are skipped, and a
    config that covers no layer yields no row.

    ``turbo_cfg_lookup`` and ``bpe_lookup`` are accepted for interface
    compatibility and are not used here.

    Returns:
        One row per (used trace, config) with keys ``task``, ``trace_len``,
        ``config`` and the CHEAP_COLS metrics.
    """
    from transformers.cache_utils import DynamicCache

    ds = _load_longbench(task_name)
    subset = ds.select(range(min(n_eval, len(ds))))
    max_len = task_cfg["max_len"]
    full = meta["full_layers"]
    full_set = set(full)  # built once; update() runs for every layer
    hd = meta["head_dim"]
    n_q = meta.get("n_q_heads", 48)
    dev = _resolve_device(model)

    print(f"\n[cheap/{task_name}] {len(subset)} examples, max_len={max_len}")

    class EvalDump(DynamicCache):
        """DynamicCache that also keeps float copies of K/V for full layers."""

        def __init__(self):
            super().__init__()
            self.d = {i: {} for i in full}

        def update(self, ks, vs, li, ck=None):
            if li in full_set:
                # batch item 0 only: (h, s, d) -> (s, h, d)
                self.d[li]["k"] = ks.detach()[0].permute(1, 0, 2).float()
                self.d[li]["v"] = vs.detach()[0].permute(1, 0, 2).float()
            return super().update(ks, vs, li, ck)

    def _selected(name: str) -> bool:
        return not config_filter or name in config_filter

    trace_rows: list[dict] = []
    n_used = 0

    for ex in tqdm(subset, desc=f"cheap/{task_name}"):
        text = prompt_text(task_cfg, ex)
        ids = tok(text, return_tensors="pt", truncation=True,
                  max_length=max_len).to(dev)
        trace_len = ids["input_ids"].shape[1]
        if trace_len < min_len:
            continue

        cache = EvalDump()
        with torch.no_grad():
            model.model(**ids, past_key_values=cache, use_cache=True)
        n_used += 1

        # One set of random unit-norm queries per layer, shared by all
        # configs so their attention metrics are comparable within a trace.
        synth_q = {}
        for i in full:
            s = cache.d[i]["k"].shape[0]
            win = min(s, ATTN_WIN)
            q_rand = torch.randn(win, n_q, hd, device=dev)
            synth_q[i] = q_rand / q_rand.norm(dim=-1, keepdim=True).clamp_min(1e-8)

        # -- vqkv configs (ProductVQ, scalar, KIVI, …) --
        for name, per_layer in vqkv_fitted.items():
            if not _selected(name):
                continue
            acc = collections.defaultdict(float)
            nL = 0
            for i in full:
                if i not in per_layer:
                    continue
                k = cache.d[i]["k"]
                v = cache.d[i]["v"]
                q = per_layer[i]
                s, h, d = k.shape
                if _is_per_channel(q):
                    k_hat = torch.stack(
                        [q.roundtrip_k(k[:, hh, :]) for hh in range(h)], 1)
                    v_hat = torch.stack(
                        [q.roundtrip_v(v[:, hh, :]) for hh in range(h)], 1)
                else:
                    k_hat = q.roundtrip_k(k.reshape(-1, d)).reshape(s, h, d)
                    v_hat = q.roundtrip_v(v.reshape(-1, d)).reshape(s, h, d)

                q_syn = synth_q[i]
                kw, kw_hat = _accumulate_kv_fidelity(
                    acc, k, v, k_hat, v_hat, q_syn, n_q)
                ip = inner_product_distortion(q_syn, kw, kw_hat)
                acc["ip_rel"] += ip["ip_rel_err"]
                acc["ip_bias"] += ip["ip_bias"]
                nL += 1
            if nL:
                trace_rows.append(
                    _trace_row(task_name, trace_len, name, acc, nL))

        # -- TurboQuant configs --
        for name, per_layer in turbo_fitted.items():
            if not _selected(name):
                continue
            acc = collections.defaultdict(float)
            nL = 0
            for i in full:
                if i not in per_layer:
                    continue
                k = cache.d[i]["k"]
                v = cache.d[i]["v"]
                e = per_layer[i]
                s, h, d = k.shape
                k_hat = e["kq"].roundtrip(k.reshape(-1, d)).reshape(s, h, d)
                v_hat = e["vq"].roundtrip(v.reshape(-1, d)).reshape(s, h, d)

                q_syn = synth_q[i]
                kw, kw_hat = _accumulate_kv_fidelity(
                    acc, k, v, k_hat, v_hat, q_syn, n_q)

                # Inner-product distortion on 4096 random (query, key) pairs
                # drawn from head 0 of the attention window.
                win = kw.shape[0]
                q0 = q_syn[:, 0, :]
                kr0 = kw[:, 0, :]
                kh0 = kw_hat[:, 0, :]
                qi = torch.randint(0, win, (4096,), device=dev)
                ki = torch.randint(0, win, (4096,), device=dev)
                ip_ref = (q0[qi] * kr0[ki]).sum(-1)
                if e["qjl"] is not None:
                    # NOTE(review): estimate_ip receives the unquantized
                    # keys; presumably it quantizes internally — confirm.
                    ip_hat = e["qjl"].estimate_ip(q0[qi], kr0[ki])
                else:
                    ip_hat = (q0[qi] * kh0[ki]).sum(-1)
                acc["ip_bias"] += (ip_hat - ip_ref).mean().item()
                acc["ip_rel"] += ((ip_hat - ip_ref).abs()
                                  / ip_ref.abs().clamp_min(1e-6)).mean().item()
                nL += 1
            if nL:
                trace_rows.append(
                    _trace_row(task_name, trace_len, name, acc, nL))

    print(f"[cheap/{task_name}] used {n_used}/{len(subset)} traces (min_len={min_len})")
    return trace_rows


def stage_cheap(tasks: list[str], n_eval: int, config_filter: set[str] | None):
    """Run the cheap proxy-metric stage on *tasks* and save the summary JSON.

    Writes ``longbench_cheap_metrics.json`` (one row per task × config) into
    ARTIFACT_DIR.
    """
    model, tok, model_meta = load_model_and_meta()
    dev = _resolve_device(model)

    vqkv_fitted, cb_meta = load_vqkv_fitted(dev)
    turbo_fitted, turbo_cfg_lookup = load_turbo_fitted(dev)
    # Prefer the meta stored next to the codebooks; fall back to the model's.
    meta = cb_meta if cb_meta else model_meta
    hd = meta["head_dim"]

    bpe_lookup: dict[str, float] = {}
    for name, per_layer in vqkv_fitted.items():
        q0 = next(iter(per_layer.values()))
        bpe_lookup[name] = round(q0.bits_per_element(hd), 4)
    for name in turbo_cfg_lookup:
        if name in turbo_fitted:
            b, use_qjl = turbo_cfg_lookup[name]
            bpe_lookup[name] = _turbo_bpe(b, use_qjl, hd)

    if config_filter:
        bpe_lookup = {k: v for k, v in bpe_lookup.items() if k in config_filter}
        print(f"[cheap] config filter: {len(bpe_lookup)} configs")

    all_trace_rows: list[dict] = []
    all_summary: list[dict] = []

    for task_name in tasks:
        task_cfg = TASK_CONFIGS[task_name]
        # trec few-shot prompts can be shorter than agentic traces
        min_len = 512 if task_name == "trec" else 2048
        rows = run_cheap_task(
            model, tok, meta, task_name, task_cfg,
            vqkv_fitted, turbo_fitted, turbo_cfg_lookup,
            bpe_lookup, config_filter, n_eval, min_len=min_len,
        )
        all_trace_rows.extend(rows)
        summary = _aggregate_cheap_rows(rows, bpe_lookup)
        for row in summary:
            row["task"] = task_name
        all_summary.extend(summary)
        _print_cheap_table(task_name, summary)

    os.makedirs(ARTIFACT_DIR, exist_ok=True)
    out_path = os.path.join(ARTIFACT_DIR, "longbench_cheap_metrics.json")
    with open(out_path, "w") as fh:
        json.dump(all_summary, fh, indent=2)
    print(f"\n[cheap] saved -> {out_path} ({len(all_summary)} task×config rows)")


# ============================================================================
# Per-task generation + scoring (--stage generate)
# ============================================================================

def run_task(model, tok, task_name: str, task_cfg: dict,
             configs: list, n_eval: int) -> dict[str, float]:
    """Run every config on n_eval examples. Returns {config_name: mean score}.

    Each example is tokenized once (truncated to the task's ``max_len``) and
    greedily decoded per config with a fresh cache instance (or the model's
    default cache for the baseline).  Configs without any scored example are
    absent from the result.
    """
    ds = _load_longbench(task_name)
    subset = ds.select(range(min(n_eval, len(ds))))
    print(f"\n[{task_name}] {len(subset)} examples, "
          f"max_len={task_cfg['max_len']}, max_new={task_cfg['max_new_tokens']}")

    score_fn = task_cfg.get("score_fn", qa_f1_score)
    max_new = task_cfg["max_new_tokens"]
    max_len = task_cfg["max_len"]
    dev = _resolve_device(model)
    scores: dict[str, list[float]] = {name: [] for name, _, _ in configs}

    for ex in tqdm(subset, desc=task_name):
        prompt = prompt_text(task_cfg, ex)
        refs = (ex["answers"] if isinstance(ex["answers"], list)
                else [ex["answers"]])
        # all_classes is present in TREC examples; ignored by other scorers
        all_classes = ex.get("all_classes") or None
        ids = tok(prompt, return_tensors="pt", truncation=True,
                  max_length=max_len).to(dev)
        prompt_len = ids["input_ids"].shape[1]

        for name, _bpe, cache_cls in configs:
            cache = cache_cls() if cache_cls is not None else None
            with torch.no_grad():
                out = model.generate(
                    **ids,
                    max_new_tokens=max_new,
                    do_sample=False,
                    past_key_values=cache,
                    use_cache=True,
                )
            # Decode only the newly generated tokens.
            pred = tok.decode(out[0, prompt_len:],
                              skip_special_tokens=True).strip()
            scores[name].append(score_fn(pred, refs, all_classes=all_classes))

    return {name: sum(vs) / len(vs) for name, vs in scores.items() if vs}


# ============================================================================
# Main
# ============================================================================

def stage_generate(tasks: list[str], n_eval: int, config_filter: list[str] | None):
    """Run end-to-end generation scoring and save ``longbench_results.json``.

    Prints one table line per (task, config) with the score and its delta to
    the fp16 baseline; a missing score or baseline is shown as "—" and stored
    as ``None`` in the JSON records.
    """
    model, tok, meta = load_model_and_meta()
    device = _resolve_device(model)

    configs = build_all_configs(meta, device, only=config_filter)
    print(f"\n[generate] {len(configs)} configs × {len(tasks)} tasks × "
          f"{n_eval} examples each")

    all_results: dict[str, dict[str, float]] = {}
    for task_name in tasks:
        all_results[task_name] = run_task(
            model, tok, task_name, TASK_CONFIGS[task_name], configs, n_eval)

    records = []
    print(f"\n{'task':<22} {'config':<32} {'bpe':>5} {'score':>6} {'Δ':>8}")
    print("-" * 78)
    for task_name in tasks:
        task_scores = all_results[task_name]
        fp16_score = task_scores.get("fp16")
        for name, bpe, _ in configs:
            score = task_scores.get(name)
            delta = (None if (score is None or fp16_score is None)
                     else round(score - fp16_score, 4))
            # run_task omits configs with no scored example, so score may be
            # None; formatting None with a float spec would raise TypeError.
            score_str = f"{score:6.4f}" if score is not None else f"{'—':>6}"
            delta_str = f"{delta:+.4f}" if delta is not None else " —"
            print(f"{task_name:<22} {name:<32} {bpe:5.2f} "
                  f"{score_str} {delta_str}")
            records.append({
                "task": task_name,
                "config": name,
                "bpe": bpe,
                "score": round(score, 4) if score is not None else None,
                "delta": delta,
            })
        print()

    # The directory may not exist when only the fp16 baseline was run.
    os.makedirs(ARTIFACT_DIR, exist_ok=True)
    out_path = os.path.join(ARTIFACT_DIR, "longbench_results.json")
    with open(out_path, "w") as fh:
        json.dump(records, fh, indent=2)
    print(f"[generate] saved -> {out_path}")


def main():
    """Parse CLI arguments and dispatch to the selected stage."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--stage", choices=["cheap", "generate"], default="cheap",
                    help="cheap: KV proxy metrics (default); generate: task scores")
    ap.add_argument("--tasks", nargs="+", default=list(TASK_CONFIGS),
                    choices=list(TASK_CONFIGS),
                    help=f"LongBench tasks to run (default: all {len(TASK_CONFIGS)})")
    ap.add_argument("--n_eval", type=int, default=50,
                    help="Examples per task (default: 50)")
    ap.add_argument("--configs", nargs="+", default=None,
                    help="Config names to include (default: all). "
                         "E.g. --configs productvq-32x256-2b turbo-mse-2b")
    args = ap.parse_args()

    # The cheap stage takes a set, the generate stage the raw list (or None).
    config_filter = set(args.configs) if args.configs else None

    if args.stage == "cheap":
        stage_cheap(args.tasks, args.n_eval, config_filter)
    else:
        stage_generate(args.tasks, args.n_eval, args.configs)


if __name__ == "__main__":
    main()