attnvq / longbench_eval.py
adirik's picture
AttnVQ submission
5a2d2ad
Raw
History Blame Contribute Delete
31.4 kB
"""
longbench_eval.py — LongBench v1 evaluation for AttnVQ and baselines.
Stages:
  cheap     — proxy KV metrics on task contexts (fast, default)
  generate  — end-to-end task scoring via model.generate (slow)
Tasks: qasper, 2wikimqa, hotpotqa, passage_retrieval_en, repobench-p
Usage:
python longbench_eval.py --stage cheap --n_eval 50
python longbench_eval.py --stage generate --tasks hotpotqa qasper
"""
from __future__ import annotations
import argparse
import collections
import json
import os
import re
import string
from difflib import SequenceMatcher
import torch
from tqdm import tqdm
from turbo_benchmark import TurboQuantMSE, QJLResidualIP, turbo_configs # noqa: F401 — unpickle
from vqkv.quantizers import KIVIScalarKV # noqa: F401
from vqkv.metrics import (key_cosine, cache_mse, inner_product_distortion,
attention_output, attn_output_cosine, attn_output_error)
# Model identifier passed to from_pretrained; the LAGUNA_ID environment variable overrides it.
MODEL_ID = os.getenv("LAGUNA_ID", "poolside/Laguna-XS.2")
# Directory that holds the fitted codebooks and receives the result JSON files;
# the VQKV_ARTIFACTS environment variable overrides it.
ARTIFACT_DIR = os.getenv("VQKV_ARTIFACTS", "./artifacts")
# Upper bound on the number of trailing cache positions used by the attention metrics.
ATTN_WIN = 512
# Metric columns of the cheap stage, in aggregation/report order.
CHEAP_COLS = (
    "key_cos",
    "val_cos",
    "key_mse",
    "val_mse",
    "attn_cos",
    "attn_output_error",
    "ip_rel",
    "ip_bias",
)
# Scorers (referenced from TASK_CONFIGS)
def _normalize(s: str) -> str:
    """Lower-case *s*, remove ASCII punctuation, and collapse whitespace runs."""
    strip_punctuation = str.maketrans("", "", string.punctuation)
    cleaned = s.lower().translate(strip_punctuation)
    # split() with no argument also trims leading/trailing whitespace.
    return " ".join(cleaned.split())
def qa_f1_score(prediction: str, ground_truths: list[str], **_) -> float:
    """Return the best token-level F1 of the prediction against any reference.

    Prediction and references are normalised with ``_normalize`` and split on
    whitespace. When either side has no tokens, the pair scores 1.0 only if
    both are empty. Extra keyword arguments are accepted and ignored so all
    scorers can be called the same way.
    """
    predicted = _normalize(prediction).split()
    predicted_counts = collections.Counter(predicted)
    top_f1 = 0.0
    for reference in ground_truths:
        gold = _normalize(reference).split()
        if predicted and gold:
            # Multiset intersection counts each shared token at most as often
            # as it occurs on both sides.
            overlap = sum((predicted_counts & collections.Counter(gold)).values())
            if overlap:
                precision = overlap / len(predicted)
                recall = overlap / len(gold)
                top_f1 = max(top_f1, 2 * precision * recall / (precision + recall))
        else:
            top_f1 = max(top_f1, float(predicted == gold))
    return top_f1
def retrieval_score(prediction: str, ground_truths: list[str], **_) -> float:
    """Score a passage-retrieval answer by the numbers it mentions.

    Each reference (e.g. 'Paragraph 3') contributes its first run of digits as
    a gold id; a reference without digits contributes its normalised text
    instead. The score is the fraction of the first ten numbers found in the
    prediction that equal a gold id, or 0.0 when the prediction contains no
    number. Extra keyword arguments are ignored.
    """
    gold_ids = set()
    for reference in ground_truths:
        first_number = re.search(r"\d+", reference)
        gold_ids.add(first_number.group(0) if first_number else _normalize(reference))
    considered = re.findall(r"\d+", prediction)[:10]
    if not considered:
        return 0.0
    hits = sum(number in gold_ids for number in considered)
    return hits / len(considered)
def classification_score(prediction: str, ground_truths: list[str],
                         all_classes: list[str] | None = None, **_) -> float:
    """Return 1.0 when the prediction names a reference label, else 0.0.

    With ``all_classes`` given, only class labels that occur (case-insensitive)
    as a substring of the prediction count as matches, and the score is 1.0 if
    any reference equals one of those labels ignoring case. Without
    ``all_classes`` the check falls back to a normalised substring test of each
    reference against the prediction. Extra keyword arguments are ignored.
    """
    if not all_classes:
        # Fallback: normalised substring match
        haystack = _normalize(prediction)
        return float(any(_normalize(ref) in haystack for ref in ground_truths))
    lowered_prediction = prediction.lower()
    hit_labels = {
        label.lower() for label in all_classes
        if label.lower() in lowered_prediction
    }
    return float(any(ref.lower() in hit_labels for ref in ground_truths))
def edit_similarity_score(prediction: str, ground_truths: list[str], **_) -> float:
    """Return the best character-level similarity to any reference, in [0, 1].

    Both sides are stripped of surrounding whitespace and compared with
    ``difflib.SequenceMatcher.ratio()`` (standard library, no extra
    dependency). An empty prediction scores 1.0 against an empty reference and
    contributes nothing against a non-empty one. Extra keyword arguments are
    ignored.
    """
    candidate = prediction.strip()
    top = 0.0
    for reference in ground_truths:
        target = reference.strip()
        if candidate and target:
            top = max(top, SequenceMatcher(None, candidate, target).ratio())
        elif not candidate and not target:
            top = 1.0
    return top
# ============================================================================
# Task configs
# ============================================================================
# QA tasks: increasing context length (~3.6K → ~18K tokens) directly probes
# the compounding-error-vs-length axis from RUNBOOK Step 6.
# Two extra tasks cover retrieval accuracy and code completion (no
# classification task such as TREC is configured here).
# ============================================================================
# Prompt shared verbatim by the three QA tasks (qasper, 2wikimqa, hotpotqa).
_QA_PROMPT_TEMPLATE = (
    "Answer the question based on the given passages. "
    "Only give me the answer and do not output any other words.\n\n"
    "The following are given passages.\n{context}\n\n"
    "Answer the question based on the given passages. "
    "Only give me the answer and do not output any other words.\n\n"
    "Question: {input}\nAnswer:"
)

# Per-task settings:
#   score_fn        scorer called as score_fn(prediction, references, ...)
#   max_new_tokens  generation budget for the answer
#   max_len         prompt length limit in tokens
#   prompt_template str.format template with {context} and {input} fields
TASK_CONFIGS: dict[str, dict] = {
    "qasper": {
        "score_fn": qa_f1_score,
        "max_new_tokens": 20,
        "max_len": 16384,
        "prompt_template": _QA_PROMPT_TEMPLATE,
    },
    "2wikimqa": {
        "score_fn": qa_f1_score,
        "max_new_tokens": 10,
        "max_len": 16384,
        "prompt_template": _QA_PROMPT_TEMPLATE,
    },
    "hotpotqa": {
        "score_fn": qa_f1_score,
        "max_new_tokens": 32,
        "max_len": 32768,
        "prompt_template": _QA_PROMPT_TEMPLATE,
    },
    "passage_retrieval_en": {
        "score_fn": retrieval_score,
        "max_new_tokens": 10,
        "max_len": 32768,
        "prompt_template": (
            "Below is a record of a series of paragraphs, each from different "
            "documents. Tell me which paragraph the given passage is from.\n\n"
            "{context}\n\n"
            "Based on the above paragraphs, which paragraph does the following "
            "passage come from? Only output the paragraph number. "
            "Do not output any other characters.\n\n"
            "{input}\nAnswer:"
        ),
    },
    "repobench-p": {
        "score_fn": edit_similarity_score,
        "max_new_tokens": 64,
        "max_len": 16384,
        "prompt_template": (
            "Please complete the code given below.\n{context}\n{input}\n"
        ),
    },
}
# ============================================================================
# Model loading
# ============================================================================
def load_model_and_meta():
    """Load the tokenizer and model for ``MODEL_ID`` and summarise its attention layout.

    Returns:
        (model, tok, meta) where ``meta`` holds the indices of layers whose
        configured type is "full_attention" plus the KV-head count, query-head
        count and head dimension read from the model config.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True, fix_mistral_regex=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.eval()

    config = model.config
    full_layers = []
    for layer_idx, layer_type in enumerate(config.layer_types):
        if layer_type == "full_attention":
            full_layers.append(layer_idx)

    meta = dict(
        full_layers=full_layers,
        n_kv_heads=config.num_key_value_heads,
        n_q_heads=config.num_attention_heads,
        head_dim=config.head_dim,
    )
    print(f"[meta] full-attention layers ({len(full_layers)}): {full_layers}")
    print(f"[meta] kv_heads={meta['n_kv_heads']} head_dim={meta['head_dim']}")
    return model, tok, meta
# ============================================================================
# Generic VQCache
# ============================================================================
def make_cache_class(per_layer_fns: dict, target_layers: list):
    """Build a DynamicCache subclass that round-trips K/V through the given fns.

    per_layer_fns[layer_idx] = {
        "k_fn": callable (N, d) -> (N, d),
        "v_fn": callable (N, d) -> (N, d),
        "per_channel": bool,  # True for KIVI/Sign/Ternary (key dim=0 reduction)
    }

    For non-per-channel quantizers, K and V are flattened to (s*h, d) before
    calling k_fn/v_fn. For per-channel, each head's (s, d) block is passed
    separately so the quantizer's token-axis statistics are per-head (not
    cross-head), matching the KIVI / benchmark.py convention.

    Every batch element is round-tripped on its own, exactly as a batch of one
    would be, so the stored tensors keep the incoming (b, h, s, d) shape.
    Layers that are not in ``target_layers`` or have no entry in
    ``per_layer_fns`` are cached unmodified.

    Returns:
        The cache class (not an instance); instantiate it once per generation.
    """
    from transformers.cache_utils import DynamicCache

    _target = set(target_layers)

    def _roundtrip_sample(fn, sample, per_channel):
        """Round-trip one sample laid out as a float (s, h, d) tensor."""
        s, h, d = sample.shape
        if per_channel:
            # One (s, d) block per head keeps token-axis statistics per-head.
            return torch.stack([fn(sample[:, hh, :]) for hh in range(h)], dim=1)
        return fn(sample.reshape(-1, d)).reshape(s, h, d)

    def _roundtrip(fn, states, per_channel):
        """Round-trip a (b, h, s, d) tensor, restoring its dtype and device."""
        # (b, h, s, d) -> (b, s, h, d); quantizers run in float32.
        tokens_first = states.permute(0, 2, 1, 3).float()
        restored = torch.stack(
            [_roundtrip_sample(fn, tokens_first[bi], per_channel)
             for bi in range(tokens_first.shape[0])],
            dim=0,
        )
        # (b, s, h, d) -> (b, h, s, d)
        restored = restored.permute(0, 2, 1, 3)
        return restored.to(dtype=states.dtype, device=states.device)

    class VQCache(DynamicCache):
        def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
            if layer_idx in _target and layer_idx in per_layer_fns:
                entry = per_layer_fns[layer_idx]
                per_channel = entry.get("per_channel", False)
                # Only the newly arriving states are quantized; entries already
                # stored in the cache are left as they are.
                key_states = _roundtrip(entry["k_fn"], key_states, per_channel)
                value_states = _roundtrip(entry["v_fn"], value_states, per_channel)
            return super().update(key_states, value_states, layer_idx, cache_kwargs)

    return VQCache
# ============================================================================
# Build unified config list from both codebook files
# ============================================================================
def build_all_configs(meta: dict, device, only: list[str] | None = None):
    """Return list of (name, bpe, cache_cls_or_None) for every available config.

    The list always starts with the unquantized "fp16" entry, followed by the
    configs found in ``codebooks.pt`` and then ``turbo_codebooks.pt`` under
    ``ARTIFACT_DIR``; a missing file only produces a warning. All quantizer
    objects are moved to ``device``.

    cache_cls_or_None: class (not instance) to call as cls() each generation,
    or None for the fp16 baseline (DynamicCache allocated internally by model).

    Args:
        meta: needs "head_dim" and "full_layers".
        device: target device for the loaded quantizers.
        only: optional config names to keep; names that match nothing are
            reported with a warning.
    """
    head_dim = meta["head_dim"]
    full_layers = meta["full_layers"]
    configs: list[tuple[str, float, type | None]] = [("fp16", 16.0, None)]

    # -- Regular codebooks (ProductVQ, RoPESplit, Scalar, KIVI, Sign, Ternary) --
    cb_path = os.path.join(ARTIFACT_DIR, "codebooks.pt")
    if not os.path.exists(cb_path):
        print(f"[warn] {cb_path} not found — skipping ProductVQ / scalar configs")
    else:
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load artifact files from a trusted source.
        fitted = torch.load(cb_path, weights_only=False)["fitted"]
        for name, per_layer in fitted.items():
            for quantizer in per_layer.values():
                if hasattr(quantizer, "to"):
                    quantizer.to(device)
            # The first layer's quantizer stands in for the whole config when
            # reading the bit rate and the per-channel flag.
            first = next(iter(per_layer.values()))
            bpe = round(first.bits_per_element(head_dim), 4)
            per_channel = (isinstance(first, KIVIScalarKV)
                           or getattr(first, "per_channel_key", False))
            per_layer_fns = {}
            for layer in full_layers:
                if layer in per_layer:
                    per_layer_fns[layer] = {
                        "k_fn": per_layer[layer].roundtrip_k,
                        "v_fn": per_layer[layer].roundtrip_v,
                        "per_channel": per_channel,
                    }
            configs.append((name, bpe, make_cache_class(per_layer_fns, full_layers)))

    # -- TurboQuant codebooks --
    tc_path = os.path.join(ARTIFACT_DIR, "turbo_codebooks.pt")
    if not os.path.exists(tc_path):
        print(f"[warn] {tc_path} not found — skipping TurboQuant configs")
    else:
        # NOTE(review): same trusted-source caveat as above.
        blob = torch.load(tc_path, weights_only=False)
        fitted = blob["fitted"]
        cfg_lookup = {}
        for cfg_name, bits, with_qjl in turbo_configs(blob["bits"]):
            cfg_lookup[cfg_name] = (bits, with_qjl)
        for name, per_layer in fitted.items():
            bits, with_qjl = cfg_lookup[name]
            bpe = round(bits + (1 if with_qjl else 0) + 16.0 / head_dim, 4)
            for entry in per_layer.values():
                entry["kq"].to(device)
                entry["vq"].to(device)
                if entry["qjl"] is not None:
                    entry["qjl"].to(device)
            per_layer_fns = {}
            for layer in full_layers:
                if layer in per_layer:
                    per_layer_fns[layer] = {
                        "k_fn": per_layer[layer]["kq"].roundtrip,
                        "v_fn": per_layer[layer]["vq"].roundtrip,
                        "per_channel": False,
                    }
            configs.append((name, bpe, make_cache_class(per_layer_fns, full_layers)))

    if only is not None:
        wanted = set(only)
        configs = [cfg for cfg in configs if cfg[0] in wanted]
        missing = wanted.difference(cfg[0] for cfg in configs)
        if missing:
            print(f"[warn] unknown --configs names: {sorted(missing)}")

    print(f"[build] {len(configs)} configs loaded")
    for cfg_name, cfg_bpe, _ in configs:
        print(f" {cfg_name:<32} {cfg_bpe:.3f} bpe")
    return configs
# ============================================================================
# Per-task generation + scoring
# ============================================================================
def _load_longbench(task_name: str):
    """Load a LongBench task, compatible with datasets >= 3.0.

    datasets >= 3.0 dropped custom dataset-script support and raises
    'Dataset scripts are no longer supported' for THUDM/LongBench.
    We load the underlying JSONL files from the HF Hub directly instead,
    probing these file names in order and returning the first that loads:

        {task}_e.jsonl, {task}.jsonl,
        {task_underscored}_e.jsonl, {task_underscored}.jsonl

    where ``task_underscored`` replaces hyphens (repobench-p -> repobench_p).
    If none loads, the failures are printed and the script-based loader is
    tried as a last resort (works only on older datasets versions); its
    exception propagates to the caller.

    NOTE(review): the ``_e`` files are preferred here as "English-suffixed";
    upstream LongBench appears to use the ``_e`` suffix for its LongBench-E
    subset — confirm which split is intended.
    """
    from datasets import load_dataset as _ld

    slug = task_name.replace("-", "_")
    # dict.fromkeys removes duplicates (task names without a hyphen) while
    # keeping the priority order.
    candidates = dict.fromkeys([
        f"{task_name}_e.jsonl", f"{task_name}.jsonl",
        f"{slug}_e.jsonl", f"{slug}.jsonl",
    ])
    failures: list[tuple[str, Exception]] = []
    for fname in candidates:
        try:
            return _ld(
                "json",
                data_files=f"hf://datasets/THUDM/LongBench/data/{fname}",
                split="train",
            )
        except Exception as err:  # best-effort probing: remember and move on
            failures.append((fname, err))

    # Surface why every candidate failed (missing file vs. network/auth
    # problems) instead of leaving only the fallback's error visible.
    for fname, err in failures:
        print(f"[warn] could not load {fname}: {type(err).__name__}: {err}")
    # Fallback for older datasets versions that still support scripts
    return _ld("THUDM/LongBench", name=task_name, split="test")
def prompt_text(task_cfg: dict, example: dict) -> str:
    """Fill the task's prompt template with the example's context and input."""
    template = task_cfg["prompt_template"]
    fields = {"context": example["context"], "input": example["input"]}
    return template.format(**fields)
def load_vqkv_fitted(device):
    """Load ``codebooks.pt`` from ARTIFACT_DIR and move its quantizers to *device*.

    Returns:
        (fitted, meta) as stored under the "fitted" and "meta" keys of the file.

    Raises:
        FileNotFoundError: if the codebook file does not exist.
    """
    codebook_path = os.path.join(ARTIFACT_DIR, "codebooks.pt")
    if not os.path.exists(codebook_path):
        raise FileNotFoundError(f"{codebook_path} not found; run benchmark.py --stage fit")
    # NOTE(review): weights_only=False unpickles arbitrary objects — trusted files only.
    blob = torch.load(codebook_path, weights_only=False)
    fitted = blob["fitted"]
    meta = blob["meta"]
    movable = (
        quantizer
        for per_layer in fitted.values()
        for quantizer in per_layer.values()
        if hasattr(quantizer, "to")
    )
    for quantizer in movable:
        quantizer.to(device)
    return fitted, meta
def load_turbo_fitted(device):
    """Load ``turbo_codebooks.pt`` from ARTIFACT_DIR and move its quantizers to *device*.

    Returns:
        (fitted, cfg_lookup) where ``cfg_lookup`` maps each name produced by
        ``turbo_configs`` for the stored bit widths to ``(bits, use_qjl)``.
        When the file is missing, a warning is printed and two empty dicts
        are returned.
    """
    turbo_path = os.path.join(ARTIFACT_DIR, "turbo_codebooks.pt")
    if not os.path.exists(turbo_path):
        print(f"[warn] {turbo_path} not found — skipping TurboQuant configs")
        return {}, {}
    # NOTE(review): weights_only=False unpickles arbitrary objects — trusted files only.
    blob = torch.load(turbo_path, weights_only=False)
    fitted = blob["fitted"]
    cfg_lookup = {}
    for cfg_name, bits, use_qjl in turbo_configs(blob["bits"]):
        cfg_lookup[cfg_name] = (bits, use_qjl)
    for per_layer in fitted.values():
        for entry in per_layer.values():
            entry["kq"].to(device)
            entry["vq"].to(device)
            qjl = entry["qjl"]
            if qjl is not None:
                qjl.to(device)
    return fitted, cfg_lookup
def _aggregate_cheap_rows(trace_rows: list[dict], bpe_lookup: dict) -> list[dict]:
    """Average the per-trace metric rows into one summary row per config.

    Summary rows follow the iteration order of ``bpe_lookup``; configs without
    any trace row are left out. Each mean is rounded to 5 decimals.
    """
    per_config: dict = {}
    for trace in trace_rows:
        for col in CHEAP_COLS:
            config_name = trace["config"]
            if config_name not in per_config:
                per_config[config_name] = {}
            per_config[config_name].setdefault(col, []).append(trace[col])

    summary = []
    for name, bpe in bpe_lookup.items():
        columns = per_config.get(name)
        if columns is None:
            continue
        count = len(columns.get("key_cos", []))
        row = {"config": name, "bits_per_elt": bpe, "n_traces": count}
        for col in CHEAP_COLS:
            row[col] = round(sum(columns.get(col, [])) / count, 5)
        summary.append(row)
    return summary
def _print_cheap_table(task_name: str, summary: list[dict]):
    """Print a fixed-width table with one line per summary row for *task_name*."""
    lines = [f"\n[cheap/{task_name}] mean metrics:"]
    lines.append(
        f" {'config':32s} {'bpe':>5} {'key_cos':>8} {'val_cos':>8} "
        f"{'key_mse':>9} {'val_mse':>9} {'attn_cos':>9} {'attn_err':>9} "
        f"{'ip_rel':>8} {'ip_bias':>9}"
    )
    for entry in summary:
        lines.append(
            f" {entry['config']:32s} {entry['bits_per_elt']:5.2f} "
            f"{entry['key_cos']:8.4f} {entry['val_cos']:8.4f} "
            f"{entry['key_mse']:9.5f} {entry['val_mse']:9.5f} "
            f"{entry['attn_cos']:9.4f} {entry['attn_output_error']:9.4f} "
            f"{entry['ip_rel']:8.5f} {entry['ip_bias']:9.6f}"
        )
    for line in lines:
        print(line)
def run_cheap_task(model, tok, meta, task_name: str, task_cfg: dict,
                   vqkv_fitted: dict, turbo_fitted: dict,
                   turbo_cfg_lookup: dict, bpe_lookup: dict,
                   config_filter: set[str] | None, n_eval: int,
                   min_len: int = 2048) -> list[dict]:
    """Dump fp16 caches on n_eval LongBench prompts; score all quantizer configs.

    For each of the first ``n_eval`` examples whose tokenized prompt has at
    least ``min_len`` tokens, one forward pass records the K/V states of the
    full-attention layers. Every selected config then round-trips those states
    and is scored per layer; the layer means form one trace row.

    Args:
        model, tok: model (its ``.model`` submodule is run) and tokenizer.
        meta: needs "full_layers" and "head_dim"; "n_q_heads" defaults to 48.
        task_name, task_cfg: task key and its config ("max_len", prompt template).
        vqkv_fitted: {config name: {layer: quantizer with roundtrip_k/roundtrip_v}}.
        turbo_fitted: {config name: {layer: {"kq", "vq", "qjl"}}}.
        turbo_cfg_lookup, bpe_lookup: accepted for call compatibility; not
            read by this function.
        config_filter: when non-empty, only these config names are scored.
        n_eval: number of leading dataset examples to consider.
        min_len: prompts with fewer tokens are skipped.

    Returns:
        One dict per (used trace, config) with "task", "trace_len", "config"
        and every column of CHEAP_COLS.
    """
    from transformers.cache_utils import DynamicCache

    ds = _load_longbench(task_name)
    subset = ds.select(range(min(n_eval, len(ds))))
    max_len = task_cfg["max_len"]
    full = meta["full_layers"]
    # Built once: the membership test runs on every layer update.
    full_set = set(full)
    hd = meta["head_dim"]
    n_q = meta.get("n_q_heads", 48)
    dev = model.device
    if str(dev) == "meta":
        dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"\n[cheap/{task_name}] {len(subset)} examples, max_len={max_len}")

    class EvalDump(DynamicCache):
        """Cache that also keeps float copies of K/V for the full-attention layers."""

        def __init__(self):
            super().__init__()
            self.d = {i: {} for i in full}

        def update(self, ks, vs, li, ck=None):
            if li in full_set:
                # First batch element only: (h, s, d) -> (s, h, d)
                self.d[li]["k"] = ks.detach()[0].permute(1, 0, 2).float()
                self.d[li]["v"] = vs.detach()[0].permute(1, 0, 2).float()
            return super().update(ks, vs, li, ck)

    def _selected(name):
        """True when no filter is active or *name* is in it."""
        return not config_filter or name in config_filter

    def _add_shared_metrics(acc, k, v, k_hat, v_hat, q_syn):
        """Accumulate reconstruction and attention-output metrics for one layer.

        Returns the key window (original, reconstructed) used for the
        attention metrics so the caller can add its inner-product metrics.
        """
        acc["key_cos"] += key_cosine(k, k_hat)
        acc["val_cos"] += key_cosine(v, v_hat)
        acc["key_mse"] += cache_mse(k, k_hat)
        acc["val_mse"] += cache_mse(v, v_hat)
        # Attention metrics only look at the trailing ATTN_WIN positions.
        win = min(k.shape[0], ATTN_WIN)
        kw, kw_hat = k[-win:], k_hat[-win:]
        vw, vw_hat = v[-win:], v_hat[-win:]
        out_ref, _ = attention_output(q_syn, kw, vw, n_q)
        out_hat, _ = attention_output(q_syn, kw_hat, vw_hat, n_q)
        acc["attn_cos"] += attn_output_cosine(out_ref, out_hat)
        acc["attn_output_error"] += attn_output_error(out_ref, out_hat)
        return kw, kw_hat

    def _trace_row(name, trace_len, acc, n_layers):
        """Build the per-trace row holding the layer-averaged metrics."""
        return {
            "task": task_name,
            "trace_len": trace_len,
            "config": name,
            **{col: acc[col] / n_layers for col in CHEAP_COLS},
        }

    trace_rows: list[dict] = []
    n_used = 0
    for ex in tqdm(subset, desc=f"cheap/{task_name}"):
        text = prompt_text(task_cfg, ex)
        ids = tok(text, return_tensors="pt", truncation=True,
                  max_length=max_len).to(dev)
        trace_len = ids["input_ids"].shape[1]
        if trace_len < min_len:
            continue
        cache = EvalDump()
        with torch.no_grad():
            model.model(**ids, past_key_values=cache, use_cache=True)
        n_used += 1

        # One set of random unit-norm queries per layer, shared by all configs
        # so they are compared on identical inputs.
        synth_q = {}
        for i in full:
            s = cache.d[i]["k"].shape[0]
            win = min(s, ATTN_WIN)
            q_rand = torch.randn(win, n_q, hd, device=dev)
            synth_q[i] = q_rand / q_rand.norm(dim=-1, keepdim=True).clamp_min(1e-8)

        # -- vqkv configs (ProductVQ, scalar, KIVI, …) --
        for name, per_layer in vqkv_fitted.items():
            if not _selected(name):
                continue
            acc = collections.defaultdict(float)
            nL = 0
            for i in full:
                k = cache.d[i]["k"]
                v = cache.d[i]["v"]
                q = per_layer[i]
                s, h, d = k.shape
                per_channel = (
                    isinstance(q, KIVIScalarKV)
                    or getattr(q, "per_channel_key", False)
                )
                if per_channel:
                    # Per-head (s, d) blocks, as in the generation cache.
                    k_hat = torch.stack([q.roundtrip_k(k[:, hh, :]) for hh in range(h)], 1)
                    v_hat = torch.stack([q.roundtrip_v(v[:, hh, :]) for hh in range(h)], 1)
                else:
                    k_hat = q.roundtrip_k(k.reshape(-1, d)).reshape(s, h, d)
                    v_hat = q.roundtrip_v(v.reshape(-1, d)).reshape(s, h, d)
                q_syn = synth_q[i]
                kw, kw_hat = _add_shared_metrics(acc, k, v, k_hat, v_hat, q_syn)
                ip = inner_product_distortion(q_syn, kw, kw_hat)
                acc["ip_rel"] += ip["ip_rel_err"]
                acc["ip_bias"] += ip["ip_bias"]
                nL += 1
            trace_rows.append(_trace_row(name, trace_len, acc, nL))

        # -- TurboQuant configs --
        for name, per_layer in turbo_fitted.items():
            if not _selected(name):
                continue
            acc = collections.defaultdict(float)
            nL = 0
            for i in full:
                k = cache.d[i]["k"]
                v = cache.d[i]["v"]
                e = per_layer[i]
                s, h, d = k.shape
                k_hat = e["kq"].roundtrip(k.reshape(-1, d)).reshape(s, h, d)
                v_hat = e["vq"].roundtrip(v.reshape(-1, d)).reshape(s, h, d)
                q_syn = synth_q[i]
                kw, kw_hat = _add_shared_metrics(acc, k, v, k_hat, v_hat, q_syn)
                # Inner-product metrics on 4096 random (query, key) pairs
                # taken from head index 0 of the window.
                win = kw.shape[0]
                q0 = q_syn[:, 0, :]
                kr0 = kw[:, 0, :]
                kh0 = kw_hat[:, 0, :]
                qi = torch.randint(0, win, (4096,), device=dev)
                ki = torch.randint(0, win, (4096,), device=dev)
                ip_ref = (q0[qi] * kr0[ki]).sum(-1)
                if e["qjl"] is not None:
                    # The QJL estimator is handed the original keys.
                    ip_hat = e["qjl"].estimate_ip(q0[qi], kr0[ki])
                else:
                    ip_hat = (q0[qi] * kh0[ki]).sum(-1)
                acc["ip_bias"] += (ip_hat - ip_ref).mean().item()
                acc["ip_rel"] += ((ip_hat - ip_ref).abs() /
                                  ip_ref.abs().clamp_min(1e-6)).mean().item()
                nL += 1
            trace_rows.append(_trace_row(name, trace_len, acc, nL))

    print(f"[cheap/{task_name}] used {n_used}/{len(subset)} traces (min_len={min_len})")
    return trace_rows
def stage_cheap(tasks: list[str], n_eval: int, config_filter: set[str] | None):
    """Run the proxy-metric stage on every task and save the per-config means.

    Loads the model and both codebook files, scores each task with
    ``run_cheap_task``, prints one table per task, and writes all summary rows
    to ``longbench_cheap_metrics.json`` in ARTIFACT_DIR.
    """
    model, tok, model_meta = load_model_and_meta()
    dev = model.device
    if str(dev) == "meta":
        dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    vqkv_fitted, cb_meta = load_vqkv_fitted(dev)
    turbo_fitted, turbo_cfg_lookup = load_turbo_fitted(dev)
    # Metadata stored with the codebooks wins over the live model's when present.
    meta = cb_meta or model_meta
    hd = meta["head_dim"]

    bpe_lookup: dict[str, float] = {}
    for name, per_layer in vqkv_fitted.items():
        first = next(iter(per_layer.values()))
        bpe_lookup[name] = round(first.bits_per_element(hd), 4)
    for name in turbo_cfg_lookup:
        if name not in turbo_fitted:
            continue
        b, use_qjl = turbo_cfg_lookup[name]
        bpe_lookup[name] = round(b + (1 if use_qjl else 0) + 16.0 / hd, 4)
    if config_filter:
        bpe_lookup = {
            name: bpe for name, bpe in bpe_lookup.items() if name in config_filter
        }
        print(f"[cheap] config filter: {len(bpe_lookup)} configs")

    all_trace_rows: list[dict] = []
    all_summary: list[dict] = []
    for task_name in tasks:
        # trec few-shot prompts can be shorter than agentic traces
        min_len = 2048
        if task_name == "trec":
            min_len = 512
        rows = run_cheap_task(
            model, tok, meta, task_name, TASK_CONFIGS[task_name],
            vqkv_fitted, turbo_fitted, turbo_cfg_lookup, bpe_lookup,
            config_filter, n_eval, min_len=min_len,
        )
        all_trace_rows.extend(rows)
        summary = _aggregate_cheap_rows(rows, bpe_lookup)
        for row in summary:
            row["task"] = task_name
        all_summary.extend(summary)
        _print_cheap_table(task_name, summary)

    out_path = os.path.join(ARTIFACT_DIR, "longbench_cheap_metrics.json")
    with open(out_path, "w") as fh:
        json.dump(all_summary, fh, indent=2)
    print(f"\n[cheap] saved -> {out_path} ({len(all_summary)} task×config rows)")
# ============================================================================
# Per-task generation + scoring (--stage generate)
# ============================================================================
def _truncate_middle(enc, max_len: int) -> dict:
    """Cut the middle out of each (1, n) tensor in *enc* when n exceeds *max_len*.

    The first ``max_len // 2`` and the last ``max_len - max_len // 2`` positions
    are kept, so both the start of the prompt and its trailing question survive.
    Encodings that already fit are returned unchanged (as a plain dict).
    """
    n_tok = enc["input_ids"].shape[1]
    if n_tok <= max_len:
        return dict(enc)
    head = max_len // 2
    tail = max_len - head
    return {
        key: torch.cat([val[:, :head], val[:, n_tok - tail:]], dim=1)
        for key, val in enc.items()
    }


def run_task(model, tok, task_name: str, task_cfg: dict,
             configs: list, n_eval: int) -> dict[str, float]:
    """Run every config on n_eval examples. Returns {config_name: mean score}.

    Each prompt is generated greedily once per config, using a fresh cache
    instance from the config's cache class (or the model's default cache for
    the fp16 baseline), and the decoded continuation is scored with the task's
    ``score_fn``.

    Prompts longer than the task's ``max_len`` are truncated in the middle
    rather than at the end: every prompt template finishes with the question /
    answer cue, which end-truncation would remove.

    Configs that produced no score (empty dataset slice) are left out of the
    returned mapping.
    """
    ds = _load_longbench(task_name)
    subset = ds.select(range(min(n_eval, len(ds))))
    print(f"\n[{task_name}] {len(subset)} examples, "
          f"max_len={task_cfg['max_len']}, max_new={task_cfg['max_new_tokens']}")
    score_fn = task_cfg.get("score_fn", qa_f1_score)
    max_new = task_cfg["max_new_tokens"]
    max_len = task_cfg["max_len"]
    scores: dict[str, list[float]] = {name: [] for name, _, _ in configs}

    for ex in tqdm(subset, desc=task_name):
        prompt = prompt_text(task_cfg, ex)
        refs = (ex["answers"] if isinstance(ex["answers"], list)
                else [ex["answers"]])
        # all_classes is present in TREC examples; ignored by other scorers
        all_classes = ex.get("all_classes") or None
        encoded = _truncate_middle(tok(prompt, return_tensors="pt"), max_len)
        ids = {key: val.to(model.device) for key, val in encoded.items()}
        prompt_len = ids["input_ids"].shape[1]
        for name, _bpe, cache_cls in configs:
            # A fresh cache per generation; None lets generate() allocate its own.
            cache = cache_cls() if cache_cls is not None else None
            with torch.no_grad():
                out = model.generate(
                    **ids,
                    max_new_tokens=max_new,
                    do_sample=False,
                    past_key_values=cache,
                    use_cache=True,
                )
            # Score only the newly generated tokens.
            pred = tok.decode(out[0, prompt_len:],
                              skip_special_tokens=True).strip()
            scores[name].append(score_fn(pred, refs, all_classes=all_classes))

    return {name: sum(vs) / len(vs) for name, vs in scores.items() if vs}
# ============================================================================
# Main
# ============================================================================
def stage_generate(tasks: list[str], n_eval: int, config_filter: list[str] | None):
    """Run end-to-end generation scoring and save one record per task and config.

    Loads the model, builds every available config (optionally restricted to
    ``config_filter``), scores each task with ``run_task``, prints a table with
    the score and its difference to the fp16 baseline, and writes the records
    to ``longbench_results.json`` in ARTIFACT_DIR. A config without a score for
    a task is shown as "—" and stored with ``None`` score and delta.
    """
    model, tok, meta = load_model_and_meta()
    device = model.device
    if str(device) == "meta":
        # Same fallback as the cheap stage: only pick CUDA when it exists.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    configs = build_all_configs(meta, device, only=config_filter)
    print(f"\n[generate] {len(configs)} configs × {len(tasks)} tasks × "
          f"{n_eval} examples each")

    all_results: dict[str, dict[str, float]] = {}
    for task_name in tasks:
        all_results[task_name] = run_task(
            model, tok, task_name, TASK_CONFIGS[task_name], configs, n_eval)

    records = []
    print(f"\n{'task':<22} {'config':<32} {'bpe':>5} {'score':>6} {'Δ':>8}")
    print("-" * 78)
    for task_name in tasks:
        task_scores = all_results[task_name]
        fp16_score = task_scores.get("fp16")
        for name, bpe, _ in configs:
            score = task_scores.get(name)
            delta = (None if (score is None or fp16_score is None)
                     else round(score - fp16_score, 4))
            delta_str = f"{delta:+.4f}" if delta is not None else " —"
            # A missing score must not reach the float format spec (TypeError).
            score_str = f"{score:6.4f}" if score is not None else f"{'—':>6}"
            print(f"{task_name:<22} {name:<32} {bpe:5.2f} "
                  f"{score_str} {delta_str}")
            records.append({
                "task": task_name,
                "config": name,
                "bpe": bpe,
                "score": round(score, 4) if score is not None else None,
                "delta": delta,
            })
        print()

    out_path = os.path.join(ARTIFACT_DIR, "longbench_results.json")
    with open(out_path, "w") as fh:
        json.dump(records, fh, indent=2)
    print(f"[generate] saved -> {out_path}")
def main():
    """Parse the command line and dispatch to the selected stage.

    ``--configs`` is handed to the cheap stage as a set and to the generate
    stage as the raw list (or None when the option is absent).
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--stage", choices=["cheap", "generate"], default="cheap",
                    help="cheap: KV proxy metrics (default); generate: task scores")
    ap.add_argument("--tasks", nargs="+", default=list(TASK_CONFIGS),
                    choices=list(TASK_CONFIGS),
                    # Count derived from the table so the help text cannot go stale.
                    help=f"LongBench tasks to run (default: all {len(TASK_CONFIGS)})")
    ap.add_argument("--n_eval", type=int, default=50,
                    help="Examples per task (default: 50)")
    ap.add_argument("--configs", nargs="+", default=None,
                    help="Config names to include (default: all). "
                         "E.g. --configs productvq-32x256-2b turbo-mse-2b")
    args = ap.parse_args()

    config_filter = set(args.configs) if args.configs else None
    if args.stage == "cheap":
        stage_cheap(args.tasks, args.n_eval, config_filter)
    else:
        stage_generate(args.tasks, args.n_eval, args.configs)


if __name__ == "__main__":
    main()