| """ |
| longbench_eval.py — LongBench v1 evaluation for AttnVQ and baselines. |
| |
| Stages: |
| cheap proxy KV metrics on task contexts (fast, default) |
| generate end-to-end task scoring via model.generate (slow) |
| |
| Tasks: qasper, 2wikimqa, hotpotqa, passage_retrieval_en, repobench-p |
| |
| Usage: |
| python longbench_eval.py --stage cheap --n_eval 50 |
| python longbench_eval.py --stage generate --tasks hotpotqa qasper |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import collections |
| import json |
| import os |
| import re |
| import string |
| from difflib import SequenceMatcher |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| from turbo_benchmark import TurboQuantMSE, QJLResidualIP, turbo_configs |
| from vqkv.quantizers import KIVIScalarKV |
| from vqkv.metrics import (key_cosine, cache_mse, inner_product_distortion, |
| attention_output, attn_output_cosine, attn_output_error) |
|
|
# Hugging Face model id to evaluate; override with the LAGUNA_ID env var.
MODEL_ID = os.getenv("LAGUNA_ID", "poolside/Laguna-XS.2")
# Directory that holds the fitted codebooks and receives the result JSON
# files; override with the VQKV_ARTIFACTS env var.
ARTIFACT_DIR = os.getenv("VQKV_ARTIFACTS", "./artifacts")

# Upper bound on the number of trailing cache positions used by the
# synthetic-query attention metrics of the cheap stage.
ATTN_WIN = 512
# Per-trace metric columns of the cheap stage, in reporting order.
CHEAP_COLS = (
    "key_cos",
    "val_cos",
    "key_mse",
    "val_mse",
    "attn_cos",
    "attn_output_error",
    "ip_rel",
    "ip_bias",
)
|
|
| |
| def _normalize(s: str) -> str: |
| s = s.lower() |
| s = s.translate(str.maketrans("", "", string.punctuation)) |
| return " ".join(s.split()) |
|
|
|
|
def qa_f1_score(prediction: str, ground_truths: list[str], **_) -> float:
    """Return the best token-level F1 of *prediction* against any reference.

    Both sides are normalized with ``_normalize`` and split on whitespace.
    When either side has no tokens the pair scores 1.0 only if both are
    empty. Extra keyword arguments are accepted and ignored so every scorer
    can be called the same way. Returns 0.0 for an empty reference list.
    """
    pred_tokens = _normalize(prediction).split()
    pred_counts = collections.Counter(pred_tokens)
    top = 0.0
    for answer in ground_truths:
        answer_tokens = _normalize(answer).split()
        if pred_tokens and answer_tokens:
            # Multiset intersection counts each shared token at most as
            # often as it occurs on both sides.
            shared = pred_counts & collections.Counter(answer_tokens)
            overlap = sum(shared.values())
            if overlap:
                precision = overlap / len(pred_tokens)
                recall = overlap / len(answer_tokens)
                top = max(top, 2 * precision * recall / (precision + recall))
        else:
            # Degenerate pair: exact match only when both token lists are empty.
            top = max(top, 1.0 if pred_tokens == answer_tokens else 0.0)
    return top
|
|
|
|
def retrieval_score(prediction: str, ground_truths: list[str], **_) -> float:
    """Score a passage-retrieval answer by the numbers it mentions.

    The first run of digits in each reference (the ``3`` of
    ``'Paragraph 3'``) is taken as its gold id; a reference without digits
    contributes its normalized text instead. The score is the fraction of
    the first ten digit runs found in *prediction* that equal a gold id,
    and 0.0 when the prediction contains no digits at all.
    """
    gold = set()
    for ref in ground_truths:
        hit = re.search(r"\d+", ref)
        gold.add(hit.group(0) if hit else _normalize(ref))

    candidates = re.findall(r"\d+", prediction)[:10]
    if not candidates:
        return 0.0
    return sum(num in gold for num in candidates) / len(candidates)
|
|
|
|
def classification_score(prediction: str, ground_truths: list[str],
                         all_classes: list[str] | None = None, **_) -> float:
    """Classification score for label-prediction tasks.

    With ``all_classes`` (the task's valid labels, taken from the dataset
    example) the score follows the LongBench rule, applied
    case-insensitively:

    * collect every class label that occurs in the prediction;
    * drop labels that are strict substrings of the reference, so a short
      label contained in the gold label is not counted as a second answer;
    * score ``1 / number_of_remaining_labels`` if the reference is among
      them, else 0. A prediction that hedges by naming several classes is
      therefore penalised instead of receiving full credit.

    The best score over all references is returned.

    Without ``all_classes`` the function falls back to a normalized
    substring test: 1.0 if any normalized reference occurs in the
    normalized prediction, else 0.0.
    """
    if all_classes:
        pred_lower = prediction.lower()
        # A set so that a label listed twice cannot inflate the denominator.
        matched = {c.lower() for c in all_classes if c.lower() in pred_lower}
        best = 0.0
        for ref in ground_truths:
            ref_lower = ref.lower()
            kept = {m for m in matched if m == ref_lower or m not in ref_lower}
            if ref_lower in kept:
                best = max(best, 1.0 / len(kept))
        return best

    pred_norm = _normalize(prediction)
    return float(any(_normalize(ref) in pred_norm for ref in ground_truths))
|
|
|
|
def edit_similarity_score(prediction: str, ground_truths: list[str], **_) -> float:
    """Best character-level similarity between *prediction* and a reference.

    Both strings are stripped first. Similarity is
    ``difflib.SequenceMatcher(...).ratio()`` (a value in [0, 1]), so no
    extra dependency is required. Two empty strings count as a perfect
    match; an empty string against a non-empty one contributes nothing.
    """
    candidate = prediction.strip()
    top = 0.0
    for ref in ground_truths:
        target = ref.strip()
        if candidate and target:
            top = max(top, SequenceMatcher(None, candidate, target).ratio())
        elif not candidate and not target:
            top = 1.0
    return top
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
# Prompt shared verbatim by the three question-answering tasks.
_QA_PROMPT_TEMPLATE = (
    "Answer the question based on the given passages. "
    "Only give me the answer and do not output any other words.\n\n"
    "The following are given passages.\n{context}\n\n"
    "Answer the question based on the given passages. "
    "Only give me the answer and do not output any other words.\n\n"
    "Question: {input}\nAnswer:"
)

# Per-task settings:
#   score_fn        scorer called as score_fn(pred, refs, all_classes=...)
#   max_new_tokens  generation budget for the generate stage
#   max_len         maximum prompt length in tokens
#   prompt_template str.format template with {context} and {input} fields
TASK_CONFIGS: dict[str, dict] = {
    "qasper": {
        "score_fn": qa_f1_score,
        "max_new_tokens": 20,
        "max_len": 16384,
        "prompt_template": _QA_PROMPT_TEMPLATE,
    },
    "2wikimqa": {
        "score_fn": qa_f1_score,
        "max_new_tokens": 10,
        "max_len": 16384,
        "prompt_template": _QA_PROMPT_TEMPLATE,
    },
    "hotpotqa": {
        "score_fn": qa_f1_score,
        "max_new_tokens": 32,
        "max_len": 32768,
        "prompt_template": _QA_PROMPT_TEMPLATE,
    },
    "passage_retrieval_en": {
        "score_fn": retrieval_score,
        "max_new_tokens": 10,
        "max_len": 32768,
        "prompt_template": (
            "Below is a record of a series of paragraphs, each from different "
            "documents. Tell me which paragraph the given passage is from.\n\n"
            "{context}\n\n"
            "Based on the above paragraphs, which paragraph does the following "
            "passage come from? Only output the paragraph number. "
            "Do not output any other characters.\n\n"
            "{input}\nAnswer:"
        ),
    },
    "repobench-p": {
        "score_fn": edit_similarity_score,
        "max_new_tokens": 64,
        "max_len": 16384,
        "prompt_template": (
            "Please complete the code given below.\n{context}\n{input}\n"
        ),
    },
}
|
|
|
|
| |
| |
| |
def load_model_and_meta():
    """Load the tokenizer and model for ``MODEL_ID`` and describe its attention.

    The model is loaded in bfloat16 with ``device_map="cuda"`` and switched
    to eval mode.

    Returns:
        ``(model, tok, meta)`` where ``meta`` holds ``full_layers`` (indices
        whose ``config.layer_types`` entry is ``"full_attention"``),
        ``n_kv_heads``, ``n_q_heads`` and ``head_dim`` read from the model
        config.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True, fix_mistral_regex=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.eval()

    cfg = model.config
    full_layers = []
    for idx, kind in enumerate(cfg.layer_types):
        if kind == "full_attention":
            full_layers.append(idx)

    meta = dict(
        full_layers=full_layers,
        n_kv_heads=cfg.num_key_value_heads,
        n_q_heads=cfg.num_attention_heads,
        head_dim=cfg.head_dim,
    )
    print(f"[meta] full-attention layers ({len(full_layers)}): {full_layers}")
    print(f"[meta] kv_heads={meta['n_kv_heads']} head_dim={meta['head_dim']}")
    return model, tok, meta
|
|
|
|
| |
| |
| |
def make_cache_class(per_layer_fns: dict, target_layers: list):
    """Build a DynamicCache subclass that round-trips K/V through the given fns.

    per_layer_fns[layer_idx] = {
        "k_fn": callable (N, d) -> (N, d),
        "v_fn": callable (N, d) -> (N, d),
        "per_channel": bool,   # optional, defaults to False
    }

    Only layers that appear in both ``target_layers`` and ``per_layer_fns``
    are modified; every other layer is stored unchanged.

    The incoming states are unpacked as (batch, heads, seq, dim). When
    ``per_channel`` is false, all rows are flattened to (batch*seq*heads, dim)
    and passed to the function in a single call. When it is true, each
    (batch, head) block of shape (seq, dim) is passed separately, so a
    quantizer that computes statistics along the token axis sees one head of
    one sequence at a time.

    Every batch element is processed, so the cache keeps the batch dimension
    the model passed in. The round trip runs in float32 and the result is
    cast back to the incoming dtype and device.

    Returns:
        The cache class (not an instance); call it to get a fresh cache.
    """
    from transformers.cache_utils import DynamicCache

    _target = set(target_layers)

    def _roundtrip(states, fn, per_channel):
        # Apply `fn` to a (b, h, s, d) tensor and return the same layout.
        b, h, s, d = states.shape
        rows = states.permute(0, 2, 1, 3).float()  # (b, s, h, d)
        if per_channel:
            hat = torch.stack(
                [
                    torch.stack([fn(rows[bi, :, hh, :]) for hh in range(h)], dim=1)
                    for bi in range(b)
                ],
                dim=0,
            )  # (b, s, h, d)
        else:
            hat = fn(rows.reshape(-1, d)).reshape(b, s, h, d)
        return hat.permute(0, 2, 1, 3).to(states.dtype).to(states.device)

    class VQCache(DynamicCache):
        def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
            if layer_idx in _target and layer_idx in per_layer_fns:
                entry = per_layer_fns[layer_idx]
                per_channel = entry.get("per_channel", False)
                key_states = _roundtrip(key_states, entry["k_fn"], per_channel)
                value_states = _roundtrip(value_states, entry["v_fn"], per_channel)
            return super().update(key_states, value_states, layer_idx, cache_kwargs)

    return VQCache
|
|
|
|
| |
| |
| |
def build_all_configs(meta: dict, device, only: list[str] | None = None):
    """Assemble the list of cache configurations for the generate stage.

    Args:
        meta: needs ``head_dim`` and ``full_layers``.
        device: device the fitted quantizers are moved to.
        only: optional config names to keep; names that match nothing are
            reported with a warning.

    Returns:
        A list of ``(name, bits_per_element, cache_cls_or_None)``. The cache
        entry is a class to instantiate per generation, or ``None`` for the
        ``fp16`` baseline, which is always built first. Configs are read from
        ``codebooks.pt`` and ``turbo_codebooks.pt`` in ``ARTIFACT_DIR``; a
        missing file only produces a warning.
    """
    head_dim = meta["head_dim"]
    full_layers = meta["full_layers"]
    configs: list[tuple[str, float, type | None]] = [("fp16", 16.0, None)]

    # --- fitted quantizers stored in codebooks.pt -------------------------
    codebook_path = os.path.join(ARTIFACT_DIR, "codebooks.pt")
    if not os.path.exists(codebook_path):
        print(f"[warn] {codebook_path} not found — skipping ProductVQ / scalar configs")
    else:
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # artifact files you produced yourself.
        fitted = torch.load(codebook_path, weights_only=False)["fitted"]
        for name, per_layer in fitted.items():
            for quantizer in per_layer.values():
                if hasattr(quantizer, "to"):
                    quantizer.to(device)
            # The first layer's quantizer stands in for the whole config.
            first = next(iter(per_layer.values()))
            bpe = round(first.bits_per_element(head_dim), 4)
            per_channel = (isinstance(first, KIVIScalarKV)
                           or getattr(first, "per_channel_key", False))
            fns = {}
            for idx in full_layers:
                if idx in per_layer:
                    fns[idx] = {
                        "k_fn": per_layer[idx].roundtrip_k,
                        "v_fn": per_layer[idx].roundtrip_v,
                        "per_channel": per_channel,
                    }
            configs.append((name, bpe, make_cache_class(fns, full_layers)))

    # --- TurboQuant quantizers stored in turbo_codebooks.pt ---------------
    turbo_path = os.path.join(ARTIFACT_DIR, "turbo_codebooks.pt")
    if not os.path.exists(turbo_path):
        print(f"[warn] {turbo_path} not found — skipping TurboQuant configs")
    else:
        blob = torch.load(turbo_path, weights_only=False)
        fitted = blob["fitted"]
        bits_list = blob["bits"]
        cfg_lookup = {}
        for cfg_name, bits, use_qjl in turbo_configs(bits_list):
            cfg_lookup[cfg_name] = (bits, use_qjl)
        for name, per_layer in fitted.items():
            bits, use_qjl = cfg_lookup[name]
            # bits per element, +1 when QJL is used, plus a 16/head_dim term
            # (presumably a per-vector 16-bit scalar — confirm).
            bpe = round(bits + (1 if use_qjl else 0) + 16.0 / head_dim, 4)
            for entry in per_layer.values():
                entry["kq"].to(device)
                entry["vq"].to(device)
                if entry["qjl"] is not None:
                    entry["qjl"].to(device)
            fns = {}
            for idx in full_layers:
                if idx in per_layer:
                    fns[idx] = {
                        "k_fn": per_layer[idx]["kq"].roundtrip,
                        "v_fn": per_layer[idx]["vq"].roundtrip,
                        "per_channel": False,
                    }
            configs.append((name, bpe, make_cache_class(fns, full_layers)))

    if only is not None:
        wanted = set(only)
        configs = [cfg for cfg in configs if cfg[0] in wanted]
        unknown = wanted - {cfg[0] for cfg in configs}
        if unknown:
            print(f"[warn] unknown --configs names: {sorted(unknown)}")

    print(f"[build] {len(configs)} configs loaded")
    for name, bpe, _cls in configs:
        print(f" {name:<32} {bpe:.3f} bpe")
    return configs
|
|
|
|
| |
| |
| |
def _load_longbench(task_name: str):
    """Load a LongBench task, compatible with datasets >= 3.0.

    datasets >= 3.0 dropped custom dataset-script support and raises
    'Dataset scripts are no longer supported' for THUDM/LongBench, so the
    JSONL files are read from the HF Hub directly. Candidate file names are
    tried in this order: ``{task}_e.jsonl``, ``{task}.jsonl``, then the same
    two with ``-`` replaced by ``_``; the first one that loads is returned.

    NOTE(review): the ``_e`` suffix looks like the LongBench-E naming rather
    than an "English" marker — confirm that preferring it is intended.

    If no candidate loads, the reasons are printed and the script-based
    loader is used as a last resort; any error it raises propagates.
    """
    from datasets import load_dataset as _ld

    slug = task_name.replace("-", "_")
    # dict.fromkeys removes duplicates (task names without '-') while
    # keeping the preference order.
    candidates = list(dict.fromkeys([
        f"{task_name}_e.jsonl", f"{task_name}.jsonl",
        f"{slug}_e.jsonl", f"{slug}.jsonl",
    ]))

    failures: list[str] = []
    for fname in candidates:
        try:
            return _ld(
                "json",
                data_files=f"hf://datasets/THUDM/LongBench/data/{fname}",
                split="train",
            )
        except Exception as err:
            # Best-effort probing: any failure means "try the next name",
            # but keep the reason so a total failure can be diagnosed.
            failures.append(f"{fname}: {type(err).__name__}: {err}")

    print(f"[warn] no LongBench JSONL file could be loaded for {task_name!r}; "
          "falling back to the dataset script")
    for reason in failures:
        print(f"[warn]   {reason}")
    return _ld("THUDM/LongBench", name=task_name, split="test")
|
|
|
|
def prompt_text(task_cfg: dict, example: dict) -> str:
    """Fill the task's prompt template with the example's context and input."""
    template = task_cfg["prompt_template"]
    fields = {"context": example["context"], "input": example["input"]}
    return template.format(**fields)
|
|
|
|
def load_vqkv_fitted(device):
    """Load ``codebooks.pt`` from ``ARTIFACT_DIR`` and move quantizers to *device*.

    Returns:
        ``(fitted, meta)`` — the ``"fitted"`` and ``"meta"`` entries of the
        saved blob. ``fitted`` maps config name -> {layer: quantizer}.

    Raises:
        FileNotFoundError: if the codebook file does not exist.
    """
    path = os.path.join(ARTIFACT_DIR, "codebooks.pt")
    if not os.path.exists(path):
        raise FileNotFoundError(f"{path} not found; run benchmark.py --stage fit")

    # NOTE: weights_only=False unpickles arbitrary objects; load trusted
    # artifacts only.
    blob = torch.load(path, weights_only=False)
    fitted = blob["fitted"]
    meta = blob["meta"]

    movable = (
        quantizer
        for per_layer in fitted.values()
        for quantizer in per_layer.values()
        if hasattr(quantizer, "to")
    )
    for quantizer in movable:
        quantizer.to(device)
    return fitted, meta
|
|
|
|
def load_turbo_fitted(device):
    """Load ``turbo_codebooks.pt`` from ``ARTIFACT_DIR`` and move it to *device*.

    Returns:
        ``(fitted, cfg_lookup)`` where ``fitted`` maps config name ->
        {layer: {"kq", "vq", "qjl"}} and ``cfg_lookup`` maps config name ->
        ``(bits, use_qjl)`` as produced by ``turbo_configs``. When the file
        is missing a warning is printed and two empty dicts are returned.
    """
    path = os.path.join(ARTIFACT_DIR, "turbo_codebooks.pt")
    if not os.path.exists(path):
        print(f"[warn] {path} not found — skipping TurboQuant configs")
        return {}, {}

    # NOTE: weights_only=False unpickles arbitrary objects; load trusted
    # artifacts only.
    blob = torch.load(path, weights_only=False)
    fitted = blob["fitted"]

    cfg_lookup = {}
    for name, bits, use_qjl in turbo_configs(blob["bits"]):
        cfg_lookup[name] = (bits, use_qjl)

    for per_layer in fitted.values():
        for entry in per_layer.values():
            entry["kq"].to(device)
            entry["vq"].to(device)
            qjl = entry["qjl"]
            if qjl is not None:
                qjl.to(device)
    return fitted, cfg_lookup
|
|
|
|
def _aggregate_cheap_rows(trace_rows: list[dict], bpe_lookup: dict) -> list[dict]:
    """Average the per-trace metrics of each config.

    Args:
        trace_rows: rows with a ``"config"`` key and one value per column in
            ``CHEAP_COLS``.
        bpe_lookup: config name -> bits per element; its order fixes the
            order of the output, and configs without any row are skipped.

    Returns:
        One dict per config with ``config``, ``bits_per_elt``, ``n_traces``
        and the mean of every ``CHEAP_COLS`` column rounded to 5 decimals.
    """
    by_config: dict[str, dict[str, list]] = {}
    for row in trace_rows:
        name = row["config"]
        if name not in by_config:
            by_config[name] = {col: [] for col in CHEAP_COLS}
        for col in CHEAP_COLS:
            by_config[name][col].append(row[col])

    summary = []
    for name, bpe in bpe_lookup.items():
        columns = by_config.get(name)
        if columns is None:
            continue
        count = len(columns["key_cos"])
        entry = {"config": name, "bits_per_elt": bpe, "n_traces": count}
        entry.update(
            (col, round(sum(columns[col]) / count, 5)) for col in CHEAP_COLS
        )
        summary.append(entry)
    return summary
|
|
|
|
| def _print_cheap_table(task_name: str, summary: list[dict]): |
| print(f"\n[cheap/{task_name}] mean metrics:") |
| print(f" {'config':32s} {'bpe':>5} {'key_cos':>8} {'val_cos':>8} " |
| f"{'key_mse':>9} {'val_mse':>9} {'attn_cos':>9} {'attn_err':>9} " |
| f"{'ip_rel':>8} {'ip_bias':>9}") |
| for row in summary: |
| print(f" {row['config']:32s} {row['bits_per_elt']:5.2f} " |
| f"{row['key_cos']:8.4f} {row['val_cos']:8.4f} " |
| f"{row['key_mse']:9.5f} {row['val_mse']:9.5f} " |
| f"{row['attn_cos']:9.4f} {row['attn_output_error']:9.4f} " |
| f"{row['ip_rel']:8.5f} {row['ip_bias']:9.6f}") |
|
|
|
|
def run_cheap_task(model, tok, meta, task_name: str, task_cfg: dict,
                   vqkv_fitted: dict, turbo_fitted: dict,
                   turbo_cfg_lookup: dict, bpe_lookup: dict,
                   config_filter: set[str] | None, n_eval: int,
                   min_len: int = 2048) -> list[dict]:
    """Dump fp16 caches on n_eval LongBench prompts; score all quantizer configs.

    For each of the first ``n_eval`` examples whose tokenized prompt has at
    least ``min_len`` tokens, one forward pass of ``model.model`` records
    the K/V states of every layer in ``meta["full_layers"]``. Each fitted
    config then round-trips those states and is scored, per layer, on
    key/value cosine and MSE, attention-output cosine and error against
    random unit-norm queries over the last ``ATTN_WIN`` positions, and
    inner-product distortion. Layer scores are averaged per trace.

    Args:
        model, tok: model (its ``.model`` sub-module is called) and tokenizer.
        meta: needs ``full_layers`` and ``head_dim``; ``n_q_heads`` falls
            back to 48 when absent.
        task_name, task_cfg: LongBench task and its ``TASK_CONFIGS`` entry.
        vqkv_fitted: config name -> {layer: quantizer with roundtrip_k/v}.
        turbo_fitted: config name -> {layer: {"kq", "vq", "qjl"}}.
        turbo_cfg_lookup, bpe_lookup: accepted for call compatibility; not
            read here.
        config_filter: if non-empty, only these config names are scored.
        n_eval: number of dataset examples to consider.
        min_len: prompts shorter than this many tokens are skipped.

    Returns:
        One row per (used trace, config) with ``task``, ``trace_len``,
        ``config`` and every ``CHEAP_COLS`` metric. Empty when ``meta`` lists
        no full-attention layers.
    """
    from transformers.cache_utils import DynamicCache

    ds = _load_longbench(task_name)
    subset = ds.select(range(min(n_eval, len(ds))))
    max_len = task_cfg["max_len"]
    full = meta["full_layers"]
    full_set = set(full)  # built once; membership is tested on every cache update
    hd = meta["head_dim"]
    n_q = meta.get("n_q_heads", 48)
    dev = model.device
    if str(dev) == "meta":
        dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(f"\n[cheap/{task_name}] {len(subset)} examples, max_len={max_len}")

    trace_rows: list[dict] = []
    if not full:
        # Without layers there is nothing to average (and the per-layer mean
        # below would divide by zero).
        print(f"[cheap/{task_name}] no full-attention layers — nothing to score")
        return trace_rows

    class EvalDump(DynamicCache):
        """DynamicCache that also keeps float32 (seq, heads, dim) copies."""

        def __init__(self):
            super().__init__()
            self.d = {i: {} for i in full}

        def update(self, ks, vs, li, ck=None):
            if li in full_set:
                # Batch element 0 only; (h, s, d) -> (s, h, d).
                self.d[li]["k"] = ks.detach()[0].permute(1, 0, 2).float()
                self.d[li]["v"] = vs.detach()[0].permute(1, 0, 2).float()
            return super().update(ks, vs, li, ck)

    def _selected(name: str) -> bool:
        # An empty or missing filter selects every config.
        return not config_filter or name in config_filter

    def _accumulate_shared(acc, k, v, k_hat, v_hat, q_syn):
        # Add the reconstruction and attention-output metrics of one layer to
        # `acc`; return the key windows so callers can add their IP metrics.
        acc["key_cos"] += key_cosine(k, k_hat)
        acc["val_cos"] += key_cosine(v, v_hat)
        acc["key_mse"] += cache_mse(k, k_hat)
        acc["val_mse"] += cache_mse(v, v_hat)

        win = min(k.shape[0], ATTN_WIN)
        kw, kw_hat = k[-win:], k_hat[-win:]
        vw, vw_hat = v[-win:], v_hat[-win:]

        out_ref, _ = attention_output(q_syn, kw, vw, n_q)
        out_hat, _ = attention_output(q_syn, kw_hat, vw_hat, n_q)
        acc["attn_cos"] += attn_output_cosine(out_ref, out_hat)
        acc["attn_output_error"] += attn_output_error(out_ref, out_hat)
        return kw, kw_hat, win

    def _emit_row(name, acc, n_layers, trace_len):
        # Store the per-layer mean of every metric for one (trace, config).
        trace_rows.append({
            "task": task_name,
            "trace_len": trace_len,
            "config": name,
            **{col: acc[col] / n_layers for col in CHEAP_COLS},
        })

    n_used = 0
    for ex in tqdm(subset, desc=f"cheap/{task_name}"):
        text = prompt_text(task_cfg, ex)
        ids = tok(text, return_tensors="pt", truncation=True,
                  max_length=max_len).to(dev)
        trace_len = ids["input_ids"].shape[1]
        if trace_len < min_len:
            continue

        cache = EvalDump()
        with torch.no_grad():
            model.model(**ids, past_key_values=cache, use_cache=True)
        n_used += 1

        # One set of random unit-norm queries per layer, shared by all
        # configs so their scores are comparable within a trace.
        synth_q = {}
        for i in full:
            s = cache.d[i]["k"].shape[0]
            win = min(s, ATTN_WIN)
            q_rand = torch.randn(win, n_q, hd, device=dev)
            synth_q[i] = q_rand / q_rand.norm(dim=-1, keepdim=True).clamp_min(1e-8)

        # --- configs from codebooks.pt -----------------------------------
        for name, per_layer in vqkv_fitted.items():
            if not _selected(name):
                continue
            acc = collections.defaultdict(float)
            n_layers = 0
            for i in full:
                k = cache.d[i]["k"]
                v = cache.d[i]["v"]
                q = per_layer[i]
                s, h, d = k.shape
                per_channel = (
                    isinstance(q, KIVIScalarKV)
                    or getattr(q, "per_channel_key", False)
                )
                if per_channel:
                    # One head at a time: each call sees an (s, d) block.
                    k_hat = torch.stack([q.roundtrip_k(k[:, hh, :]) for hh in range(h)], 1)
                    v_hat = torch.stack([q.roundtrip_v(v[:, hh, :]) for hh in range(h)], 1)
                else:
                    k_hat = q.roundtrip_k(k.reshape(-1, d)).reshape(s, h, d)
                    v_hat = q.roundtrip_v(v.reshape(-1, d)).reshape(s, h, d)

                q_syn = synth_q[i]
                kw, kw_hat, _win = _accumulate_shared(acc, k, v, k_hat, v_hat, q_syn)

                ip = inner_product_distortion(q_syn, kw, kw_hat)
                acc["ip_rel"] += ip["ip_rel_err"]
                acc["ip_bias"] += ip["ip_bias"]
                n_layers += 1

            _emit_row(name, acc, n_layers, trace_len)

        # --- TurboQuant configs --------------------------------------------
        for name, per_layer in turbo_fitted.items():
            if not _selected(name):
                continue
            acc = collections.defaultdict(float)
            n_layers = 0
            for i in full:
                k = cache.d[i]["k"]
                v = cache.d[i]["v"]
                e = per_layer[i]
                s, h, d = k.shape
                k_hat = e["kq"].roundtrip(k.reshape(-1, d)).reshape(s, h, d)
                v_hat = e["vq"].roundtrip(v.reshape(-1, d)).reshape(s, h, d)

                q_syn = synth_q[i]
                kw, kw_hat, win = _accumulate_shared(acc, k, v, k_hat, v_hat, q_syn)

                # Inner-product distortion on 4096 random (query, key) pairs
                # drawn from head 0 of the window.
                q0 = q_syn[:, 0, :]
                kr0 = kw[:, 0, :]
                kh0 = kw_hat[:, 0, :]
                qi = torch.randint(0, win, (4096,), device=dev)
                ki = torch.randint(0, win, (4096,), device=dev)
                ip_ref = (q0[qi] * kr0[ki]).sum(-1)
                if e["qjl"] is not None:
                    # The QJL estimator is handed the reference keys.
                    ip_hat = e["qjl"].estimate_ip(q0[qi], kr0[ki])
                else:
                    ip_hat = (q0[qi] * kh0[ki]).sum(-1)
                acc["ip_bias"] += (ip_hat - ip_ref).mean().item()
                acc["ip_rel"] += ((ip_hat - ip_ref).abs() /
                                  ip_ref.abs().clamp_min(1e-6)).mean().item()
                n_layers += 1

            _emit_row(name, acc, n_layers, trace_len)

    print(f"[cheap/{task_name}] used {n_used}/{len(subset)} traces (min_len={min_len})")
    return trace_rows
|
|
|
|
def stage_cheap(tasks: list[str], n_eval: int, config_filter: set[str] | None):
    """Run the cheap proxy-metric stage and save per-task summaries.

    Loads the model and the fitted quantizers, runs ``run_cheap_task`` for
    every task, prints one table per task and writes all summary rows to
    ``longbench_cheap_metrics.json`` in ``ARTIFACT_DIR``.

    Args:
        tasks: keys of ``TASK_CONFIGS`` to evaluate.
        n_eval: examples considered per task.
        config_filter: if non-empty, only these config names are evaluated.

    Raises:
        FileNotFoundError: from ``load_vqkv_fitted`` when ``codebooks.pt`` is
            missing.
    """
    model, tok, model_meta = load_model_and_meta()
    dev = model.device
    if str(dev) == "meta":
        dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    vqkv_fitted, cb_meta = load_vqkv_fitted(dev)
    turbo_fitted, turbo_cfg_lookup = load_turbo_fitted(dev)
    # Codebook metadata takes precedence, but keys it lacks (for example
    # n_q_heads) come from the loaded model instead of a hard-coded default.
    meta = {**model_meta, **(cb_meta or {})}
    hd = meta["head_dim"]

    bpe_lookup: dict[str, float] = {}
    for name, per_layer in vqkv_fitted.items():
        q0 = next(iter(per_layer.values()))
        bpe_lookup[name] = round(q0.bits_per_element(hd), 4)
    for name in turbo_cfg_lookup:
        if name in turbo_fitted:
            b, use_qjl = turbo_cfg_lookup[name]
            bpe_lookup[name] = round(b + (1 if use_qjl else 0) + 16.0 / hd, 4)

    if config_filter:
        bpe_lookup = {k: v for k, v in bpe_lookup.items() if k in config_filter}
        print(f"[cheap] config filter: {len(bpe_lookup)} configs")

    all_summary: list[dict] = []
    for task_name in tasks:
        task_cfg = TASK_CONFIGS[task_name]
        # Shorter minimum length for "trec"; every other task needs 2048 tokens.
        min_len = 512 if task_name == "trec" else 2048
        rows = run_cheap_task(
            model, tok, meta, task_name, task_cfg,
            vqkv_fitted, turbo_fitted, turbo_cfg_lookup, bpe_lookup,
            config_filter, n_eval, min_len=min_len,
        )
        summary = _aggregate_cheap_rows(rows, bpe_lookup)
        for row in summary:
            row["task"] = task_name
        all_summary.extend(summary)
        _print_cheap_table(task_name, summary)

    os.makedirs(ARTIFACT_DIR, exist_ok=True)
    out_path = os.path.join(ARTIFACT_DIR, "longbench_cheap_metrics.json")
    with open(out_path, "w", encoding="utf-8") as fh:
        json.dump(all_summary, fh, indent=2)
    print(f"\n[cheap] saved -> {out_path} ({len(all_summary)} task×config rows)")
|
|
|
|
| |
| |
| |
| def run_task(model, tok, task_name: str, task_cfg: dict, |
| configs: list, n_eval: int) -> dict[str, float]: |
| """Run every config on n_eval examples. Returns {config_name: mean score}.""" |
| ds = _load_longbench(task_name) |
| subset = ds.select(range(min(n_eval, len(ds)))) |
| print(f"\n[{task_name}] {len(subset)} examples, " |
| f"max_len={task_cfg['max_len']}, max_new={task_cfg['max_new_tokens']}") |
|
|
| score_fn = task_cfg.get("score_fn", qa_f1_score) |
| max_new = task_cfg["max_new_tokens"] |
| max_len = task_cfg["max_len"] |
| scores: dict[str, list[float]] = {name: [] for name, _, _ in configs} |
|
|
| for ex in tqdm(subset, desc=task_name): |
| prompt = prompt_text(task_cfg, ex) |
| refs = (ex["answers"] if isinstance(ex["answers"], list) |
| else [ex["answers"]]) |
| |
| all_classes = ex.get("all_classes") or None |
| ids = tok(prompt, return_tensors="pt", truncation=True, |
| max_length=max_len).to(model.device) |
|
|
| for name, _bpe, cache_cls in configs: |
| cache = cache_cls() if cache_cls is not None else None |
| with torch.no_grad(): |
| out = model.generate( |
| **ids, |
| max_new_tokens=max_new, |
| do_sample=False, |
| past_key_values=cache, |
| use_cache=True, |
| ) |
| pred = tok.decode(out[0, ids["input_ids"].shape[1]:], |
| skip_special_tokens=True).strip() |
| scores[name].append(score_fn(pred, refs, all_classes=all_classes)) |
|
|
| return {name: sum(vs) / len(vs) for name, vs in scores.items() if vs} |
|
|
|
|
| |
| |
| |
def stage_generate(tasks: list[str], n_eval: int, config_filter: list[str] | None):
    """Run end-to-end generation scoring and save one record per task/config.

    Builds all cache configs, scores every task with ``run_task``, prints a
    table with the difference to the ``fp16`` baseline and writes the records
    to ``longbench_results.json`` in ``ARTIFACT_DIR`` (created if missing, so
    a long run cannot fail at the final write).

    Args:
        tasks: keys of ``TASK_CONFIGS`` to evaluate.
        n_eval: examples per task.
        config_filter: config names to keep, or ``None`` for all.
    """
    model, tok, meta = load_model_and_meta()
    device = model.device
    if str(device) == "meta":
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    configs = build_all_configs(meta, device, only=config_filter)
    print(f"\n[generate] {len(configs)} configs × {len(tasks)} tasks × "
          f"{n_eval} examples each")

    all_results: dict[str, dict[str, float]] = {}
    for task_name in tasks:
        all_results[task_name] = run_task(
            model, tok, task_name, TASK_CONFIGS[task_name], configs, n_eval)

    records = []
    print(f"\n{'task':<22} {'config':<32} {'bpe':>5} {'score':>6} {'Δ':>8}")
    print("-" * 78)
    for task_name in tasks:
        task_scores = all_results[task_name]
        fp16_score = task_scores.get("fp16")
        for name, bpe, _ in configs:
            # A config has no score when no example was evaluated for it.
            score = task_scores.get(name)
            delta = (None if (score is None or fp16_score is None)
                     else round(score - fp16_score, 4))
            score_str = f"{score:6.4f}" if score is not None else f"{'—':>6}"
            delta_str = f"{delta:+.4f}" if delta is not None else " —"
            print(f"{task_name:<22} {name:<32} {bpe:5.2f} "
                  f"{score_str} {delta_str}")
            records.append({
                "task": task_name,
                "config": name,
                "bpe": bpe,
                "score": round(score, 4) if score is not None else None,
                "delta": delta,
            })
        print()

    os.makedirs(ARTIFACT_DIR, exist_ok=True)
    out_path = os.path.join(ARTIFACT_DIR, "longbench_results.json")
    with open(out_path, "w", encoding="utf-8") as fh:
        json.dump(records, fh, indent=2)
    print(f"[generate] saved -> {out_path}")
|
|
|
|
def main():
    """Parse the command line and run the selected stage.

    ``--stage cheap`` (default) computes KV proxy metrics, ``--stage
    generate`` runs end-to-end task scoring. ``--tasks`` is restricted to the
    keys of ``TASK_CONFIGS``; ``--configs`` optionally limits the evaluated
    config names.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--stage", choices=["cheap", "generate"], default="cheap",
                    help="cheap: KV proxy metrics (default); generate: task scores")
    ap.add_argument("--tasks", nargs="+", default=list(TASK_CONFIGS),
                    choices=list(TASK_CONFIGS),
                    # Derived from TASK_CONFIGS so the count cannot go stale.
                    help=f"LongBench tasks to run (default: all {len(TASK_CONFIGS)})")
    ap.add_argument("--n_eval", type=int, default=50,
                    help="Examples per task (default: 50)")
    ap.add_argument("--configs", nargs="+", default=None,
                    help="Config names to include (default: all). "
                         "E.g. --configs productvq-32x256-2b turbo-mse-2b")
    args = ap.parse_args()

    # The cheap stage takes a set; the generate stage takes the raw list.
    config_filter = set(args.configs) if args.configs else None

    if args.stage == "cheap":
        stage_cheap(args.tasks, args.n_eval, config_filter)
    else:
        stage_generate(args.tasks, args.n_eval, args.configs)


if __name__ == "__main__":
    main()
|
|