"""
benchmark.py — fit AttnVQ codebooks and evaluate them on real Laguna-XS.2 KV caches.
Stages (typical order: dump → fit → cheap / swebench):
  dump      capture post-RoPE K/V from full-attention layers → artifacts/calib_caches.pt
  fit       per-layer LBG codebooks → artifacts/codebooks.pt
  cheap     proxy metrics (key cosine, attn-output error, ip-bias) → artifacts/cheap_metrics.json
  swebench  optional resolve-rate delta vs fp16 on SWE-bench Verified (needs Docker)
Usage:
  python benchmark.py --stage fit
  python benchmark.py --stage cheap --n_eval 64
  python benchmark.py --stage dump
  python benchmark.py --stage swebench --n_tasks 50
"""
from __future__ import annotations
import argparse
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, asdict
import torch
from vqkv.quantizers import (ScalarKV, KIVIScalarKV, ProductVQKV, RoPESplitVQKV,
SignScalarKV, TernaryScalarKV)
from vqkv.metrics import (key_cosine, cache_mse, inner_product_distortion,
attention_output, attn_output_cosine, attn_output_error,
calibration_sample_weights)
MODEL_ID = os.environ.get("LAGUNA_ID", "poolside/Laguna-XS.2")
ARTIFACT_DIR = os.environ.get("VQKV_ARTIFACTS", "./artifacts")
os.makedirs(ARTIFACT_DIR, exist_ok=True)
CALIB_DATASET = os.environ.get("CALIB_DATASET", "SWE-bench/SWE-smith-trajectories")
CALIB_SPLIT = os.environ.get("CALIB_SPLIT", "tool")
CALIB_SOURCE = os.environ.get("CALIB_SOURCE", "swesmith") # swesmith | longbench-hotpotqa
_HOTPOTQA_PROMPT = (
"Answer the question based on the given passages. "
"Only give me the answer and do not output any other words.\n\n"
"The following are given passages.\n{context}\n\n"
"Answer the question based on the given passages. "
"Only give me the answer and do not output any other words.\n\n"
"Question: {input}\nAnswer:"
)
# ============================================================================
# Cache configurations under test. Each is a (name, factory) where factory()
# returns an unfitted quantizer. `None` means the fp16 baseline (no quant).
# ============================================================================
def cache_configs():
    """List every cache configuration under test, fp16 baseline first.

    Returns a list of ``(name, factory)`` pairs. Calling ``factory()`` builds a
    fresh, unfitted quantizer; the baseline's factory is ``None`` (no quant).
    """
    def _product_vq(n_sub):
        # Bind n_sub now; a bare lambda in the comprehension would late-bind it.
        return lambda: ProductVQKV(n_sub=n_sub, n_codes=256, iters=15)

    product_vq = [
        (f"productvq-{n_sub}x256-{bits}b", _product_vq(n_sub))
        for n_sub, bits in ((64, "4"), (32, "2"), (16, "1"), (8, "0.5"))
    ]
    return [
        ("fp16 (baseline)", None),
        ("scalar-int4", lambda: ScalarKV(nbits=4)),
        ("scalar-int2", lambda: ScalarKV(nbits=2)),
        ("kivi-int2", lambda: KIVIScalarKV(nbits=2)),
        *product_vq,
        ("ropesplit-1b", lambda: RoPESplitVQKV(n_sub_half=8, n_codes=256, iters=15)),
        ("sign-1bit", lambda: SignScalarKV(per_channel_key=True)),
        ("ternary-bitnet", lambda: TernaryScalarKV(alpha=0.7, per_channel_key=True)),
    ]
# ============================================================================
# Model + layer-structure loading
# ============================================================================
def load_model_and_meta():
    """Load MODEL_ID (bf16, on CUDA) plus its tokenizer and layer metadata.

    Returns:
        ``(model, tok, meta)`` where ``meta`` has keys ``full_layers`` (indices
        whose ``config.layer_types`` entry is ``"full_attention"``),
        ``n_kv_heads``, ``n_q_heads``, ``head_dim`` and ``n_layers``.

    Raises:
        RuntimeError: if the config lists no full-attention layer. This is an
            explicit raise rather than ``assert`` so it survives ``python -O``.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, device_map="cuda",
        trust_remote_code=True)
    model.eval()
    cfg = model.config
    full_layers = [i for i, t in enumerate(cfg.layer_types) if t == "full_attention"]
    # Some configs omit head_dim; derive it the conventional way in that case.
    head_dim = (getattr(cfg, "head_dim", None)
                or cfg.hidden_size // cfg.num_attention_heads)
    meta = {
        "full_layers": full_layers,
        "n_kv_heads": cfg.num_key_value_heads,
        "n_q_heads": cfg.num_attention_heads,
        "head_dim": head_dim,
        "n_layers": cfg.num_hidden_layers,
    }
    print(f"[meta] full-attention layers ({len(full_layers)}): {full_layers}")
    print(f"[meta] kv_heads={meta['n_kv_heads']} head_dim={meta['head_dim']}")
    if not full_layers:
        raise RuntimeError("no full-attention layers found; check config")
    return model, tok, meta
# Trace flattening (SWE-smith messages JSON, nebius trajectory/tool roles).
def flatten_trace(example, tok) -> str:
    """Render one agent trajectory as a single string via the chat template.

    Accepts the SWE-smith layout (``messages`` as a JSON string, content
    blocks as lists of ``{"type": "text", "text": ...}``), nebius-style turns
    (``tool`` role, ``tool_calls`` field) and ShareGPT-style
    ``conversations`` (``from``/``value``). Rows with no message list are
    dumped as JSON (capped at 200k chars); undecodable JSON strings are
    returned raw (same cap). If the tokenizer's chat template rejects the
    turns, falls back to plain ``"role: content"`` paragraphs.
    """
    raw = (example.get("messages") or example.get("trajectory")
           or example.get("conversations"))
    if raw is None:
        return json.dumps(example)[:200_000]
    if isinstance(raw, str):
        # SWE-smith serialises the message list as a JSON string.
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return raw[:200_000]

    # 'tool' observations have no slot in most chat templates -> speak as user.
    role_alias = {"human": "user", "gpt": "assistant", "tool": "user"}
    turns: list[dict] = []
    for msg in raw:
        speaker = msg.get("role") or msg.get("from") or "user"
        body = msg.get("content") or msg.get("value") or ""
        if isinstance(body, list):
            body = "\n".join(
                part.get("text", str(part)) if isinstance(part, dict) else str(part)
                for part in body
            )
        calls = msg.get("tool_calls")
        if calls:
            calls_json = json.dumps(calls, ensure_ascii=False)
            body = (body + "\n" + calls_json).strip() if body else calls_json
        speaker = role_alias.get(speaker, speaker)
        if not body.strip():
            continue
        # Fold consecutive same-speaker turns (e.g. user + collapsed tool output).
        if turns and turns[-1]["role"] == speaker:
            turns[-1]["content"] += "\n\n" + body
        else:
            turns.append({"role": speaker, "content": body})

    try:
        return tok.apply_chat_template(turns, tokenize=False,
                                       add_generation_prompt=False)
    except Exception:
        return "\n\n".join(f"{t['role']}: {t['content']}" for t in turns)
def flatten_longbench(example) -> str:
    """Render one LongBench hotpotqa row through the _HOTPOTQA_PROMPT template.

    Reads the row's ``context`` (passages) and ``input`` (question) fields.
    """
    fields = {"context": example["context"], "input": example["input"]}
    return _HOTPOTQA_PROMPT.format_map(fields)
def _load_longbench_hotpotqa():
    """Load THUDM/LongBench hotpotqa, bypassing the deprecated dataset script.

    Tries the raw JSONL files on the Hub with the generic ``"json"`` builder
    (``hotpotqa_e.jsonl`` first, then ``hotpotqa.jsonl``). Only if both fail
    does it fall back to the script-based ``THUDM/LongBench`` loader. Each
    failed attempt is reported rather than silently swallowed, so a broken
    Hub path is visible instead of being masked by the fallback.

    NOTE(review): ``hotpotqa_e.jsonl`` is presumably the LongBench-E variant
    and is preferred when present. Confirm this is intended and matches what
    longbench_eval.py loads.
    """
    from datasets import load_dataset as _ld
    for fname in ("hotpotqa_e.jsonl", "hotpotqa.jsonl"):
        try:
            return _ld(
                "json",
                data_files=f"hf://datasets/THUDM/LongBench/data/{fname}",
                split="train",
            )
        except Exception as exc:  # best-effort: any loader failure -> next source
            print(f"[longbench] could not load {fname} "
                  f"({type(exc).__name__}: {exc}); trying next source")
    return _ld("THUDM/LongBench", name="hotpotqa", split="test")
def _load_calib_dump_dataset(n_calib: int, calib_source: str, tok):
    """Pick the calibration corpus for stage_dump.

    Returns a ``(dataset, text_fn, label)`` triple:
    - ``text_fn`` turns one row into the string to tokenize.
    - ``label`` is a human-readable description used for logging.

    ``"longbench-hotpotqa"`` selects the tail of LongBench hotpotqa; any other
    value uses CALIB_DATASET / CALIB_SPLIT flattened with flatten_trace.
    """
    from datasets import load_dataset
    if calib_source != "longbench-hotpotqa":
        swe_ds = load_dataset(CALIB_DATASET, split=CALIB_SPLIT)
        return (swe_ds,
                lambda ex: flatten_trace(ex, tok),
                f"{CALIB_DATASET} split={CALIB_SPLIT}")

    hotpot = _load_longbench_hotpotqa()
    end = len(hotpot)
    # Use the LAST n_calib rows: longbench_eval.py scores range(0, n_eval), so
    # the two stay disjoint as long as the split has >= n_eval + n_calib rows.
    first = max(0, end - n_calib)
    tail = hotpot.select(range(first, end))
    return tail, flatten_longbench, f"LongBench hotpotqa (rows {first}–{end - 1})"
# STAGE: dump -- run real traces, capture post-RoPE K/V from full-attn layers
def stage_dump(n_calib=16, max_len=32768, calib_source: str | None = None,
               min_len=2048):
    """Capture post-RoPE K/V from every full-attention layer on calibration traces.

    Iterates rows of the calibration source (see _load_calib_dump_dataset):
    - Each row is tokenized with truncation to ``max_len``.
    - Traces shorter than ``min_len`` tokens are skipped.
    - The decoder stack (no lm_head) runs on up to ``n_calib`` traces with a
      cache that copies each full-attention layer's K/V to CPU as float32.

    Per-layer tensors are concatenated across traces along the token axis and
    saved with ``meta`` to ARTIFACT_DIR/calib_caches.pt as
    ``{"calib": {layer: {"k": ..., "v": ...}}, "meta": meta}``.

    Args:
        n_calib: number of traces to capture.
        max_len: tokenizer truncation length.
        calib_source: "swesmith" or "longbench-hotpotqa"; defaults to CALIB_SOURCE.
        min_len: minimum token length for a trace to be used.

    Raises:
        RuntimeError: if the source is empty or no trace reaches ``min_len``
            (previously this surfaced as an opaque ``torch.cat`` error).
    """
    from transformers.cache_utils import DynamicCache
    calib_source = calib_source or CALIB_SOURCE
    model, tok, meta = load_model_and_meta()
    full = set(meta["full_layers"])
    # Feed inputs to cuda:0 explicitly instead of model.device, which can
    # report 'meta' under some device_map settings (e.g. "auto" with offload).
    input_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    class DumpingCache(DynamicCache):
        """DynamicCache that also records full-attention K/V on the CPU."""

        def __init__(self, *a, **k):
            super().__init__(*a, **k)
            self.dump = {i: {"k": [], "v": []} for i in full}

        def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
            if layer_idx in full:
                # Assumes (batch, heads, seq, head_dim); batch 0 -> (seq, heads, head_dim).
                self.dump[layer_idx]["k"].append(
                    key_states.detach()[0].permute(1, 0, 2).float().cpu())
                self.dump[layer_idx]["v"].append(
                    value_states.detach()[0].permute(1, 0, 2).float().cpu())
            return super().update(key_states, value_states, layer_idx, cache_kwargs)

    ds, get_text, source_label = _load_calib_dump_dataset(n_calib, calib_source, tok)
    if len(ds) == 0:
        raise RuntimeError(f"[dump] calibration source is empty: {source_label}")
    print(f"[dump] source={calib_source} {source_label} rows={len(ds)} "
          f"schema={list(ds[0].keys())}")

    agg = {i: {"k": [], "v": []} for i in full}
    used = 0
    scanned = 0  # rows read, including short ones that were skipped
    for ex in ds:
        if used >= n_calib:
            break
        scanned += 1
        text = get_text(ex)
        ids = tok(text, return_tensors="pt", truncation=True,
                  max_length=max_len).to(input_device)
        if ids["input_ids"].shape[1] < min_len:
            continue
        cache = DumpingCache(config=model.config)  # fresh per trace
        with torch.no_grad():
            model.model(**ids, past_key_values=cache, use_cache=True)  # skip lm_head
        for i in full:
            agg[i]["k"].append(torch.cat(cache.dump[i]["k"]))
            agg[i]["v"].append(torch.cat(cache.dump[i]["v"]))
        used += 1
        print(f"[dump] trace {used}/{n_calib} len={ids['input_ids'].shape[1]}")

    if used == 0:
        raise RuntimeError(
            f"[dump] no trace reached min_len={min_len} tokens "
            f"in {scanned} rows of {source_label}")
    if used < n_calib:
        print(f"[dump] WARNING: only {used}/{n_calib} traces reached "
              f"min_len={min_len} tokens")
    # Rows 0..scanned-1 are now calibration data; held-out eval must start later.
    print(f"[dump] scanned {scanned} dataset rows")

    calib = {i: {"k": torch.cat(agg[i]["k"]), "v": torch.cat(agg[i]["v"])}
             for i in full}
    # Iterate the ordered layer list (not the set) so the diagnostics are stable.
    for i in meta["full_layers"][:3]:
        k = calib[i]["k"].reshape(-1, meta["head_dim"])
        cut = meta["head_dim"] // 2
        print(f"[dump] layer {i}: {tuple(calib[i]['k'].shape)} | "
              f"rope-half std {k[:, :cut].std():.3f} "
              f"pass-half std {k[:, cut:].std():.3f} "
              f"max|ch| {k.abs().amax(0).max():.2f}")
    path = os.path.join(ARTIFACT_DIR, "calib_caches.pt")
    torch.save({"calib": calib, "meta": meta}, path)
    print(f"[dump] saved -> {path}")
# STAGE: fit -- per-layer codebooks for every (data-dependent) config
def stage_fit(only: list[str] | None = None):
    """Fit quantizers and write artifacts/codebooks.pt.

    By default fits every config in cache_configs(). Pass
    ``only=["sign-1bit", ...]`` to fit just those names and merge them into an
    existing codebooks.pt; configs already stored there are kept as-is.

    Tuning-free quantizers (Sign, Ternary, Scalar, KIVI) finish in seconds; only
    ProductVQ / RoPE-split need calib_caches.pt.

    Raises:
        FileNotFoundError: calib_caches.pt is missing and either there is no
            codebooks.pt to merge into, or a ProductVQ / RoPE-split config was
            requested. Those configs would otherwise be silently "fit" on a
            single zero vector.
        KeyError: a ProductVQ / RoPE-split config needs a layer absent from
            calib_caches.pt.
        ValueError: ``only`` names a config that is not in cache_configs().
    """
    codebooks_path = os.path.join(ARTIFACT_DIR, "codebooks.pt")
    calib_path = os.path.join(ARTIFACT_DIR, "calib_caches.pt")
    calib = None
    if only and os.path.exists(codebooks_path):
        existing = torch.load(codebooks_path, weights_only=False)
        fitted = existing["fitted"]
        meta = existing["meta"]
        print(f"[fit] merging into existing {codebooks_path} ({len(fitted)} configs)")
        if os.path.exists(calib_path):
            calib = torch.load(calib_path)["calib"]
    else:
        if not os.path.exists(calib_path):
            raise FileNotFoundError(
                f"{calib_path} not found; run --stage dump first, or use "
                f"--only with an existing codebooks.pt for tuning-free configs")
        # Load once: calib and meta come from the same (potentially large) file.
        blob = torch.load(calib_path)
        calib, meta = blob["calib"], blob["meta"]
        fitted = {}

    hd = meta["head_dim"]
    layer_ids = meta["full_layers"]
    configs = [(n, f) for n, f in cache_configs() if f is not None]
    if only is not None:
        only_set = set(only)
        configs = [(n, f) for n, f in configs if n in only_set]
        unknown = only_set - {n for n, _ in configs}
        if unknown:
            raise ValueError(f"unknown --only config(s): {sorted(unknown)}")

    vq_types = (ProductVQKV, RoPESplitVQKV)
    data_dependent = [n for n, f in configs if isinstance(f(), vq_types)]
    if data_dependent and calib is None:
        raise FileNotFoundError(
            f"{calib_path} not found but {data_dependent} need calibration "
            f"data; run --stage dump first")

    # GPU fitting: LBG is pure torch; moving calib to CUDA makes bmm/argmin ~20x
    # faster. Serial over layers on GPU (CUDA is already async; threading adds no
    # benefit). Parallel over layers on CPU (BLAS releases GIL; real concurrency).
    fit_device = "cuda" if torch.cuda.is_available() else "cpu"
    if fit_device == "cuda":
        print(f"[fit] GPU available ({torch.cuda.get_device_name()}) -- fitting LBG on CUDA")
    n_workers = (1 if fit_device == "cuda"
                 else max(1, min(len(layer_ids), os.cpu_count() or 1)))
    if data_dependent:
        print("[fit] ProductVQ / RoPESplit: attention-weighted LBG "
              "(centroids weighted by key attention mass)")
    n_q = meta.get("n_q_heads", 48)

    for name, factory in configs:
        t0 = time.time()

        def _fit_layer(i, _factory=factory, _device=fit_device):
            """Fit a fresh quantizer for layer ``i``; return ``(i, quantizer on CPU)``."""
            q = _factory()
            needs_calib = isinstance(q, vq_types)
            if calib is not None and i in calib:
                # NOTE(review): only the last 512 calibration tokens reach fit()
                # (the tail of the last dumped trace), so the 200_000-row cap
                # below is unlikely to bind. Presumably the window exists for
                # calibration_sample_weights; confirm the fit set is intended.
                k_struct = calib[i]["k"][-512:]
                v_struct = calib[i]["v"][-512:]
                kf = k_struct.reshape(-1, hd)[:200_000].to(_device)
                vf = v_struct.reshape(-1, hd)[:200_000].to(_device)
                w = None
                if needs_calib:
                    # Attention-mass weights are only consumed by the VQ fits.
                    w = calibration_sample_weights(k_struct, n_q)
                    if w is not None:
                        w = w[:kf.shape[0]].to(_device)
            elif needs_calib:
                raise KeyError(f"layer {i} missing from {calib_path}; re-run --stage dump")
            else:
                # No calibration data for this layer: hand the tuning-free
                # quantizer a single dummy row, as before.
                kf = torch.zeros(1, hd, device=_device)
                vf = torch.zeros(1, hd, device=_device)
                w = k_struct = None
            if needs_calib:
                q.fit(kf, vf, sample_weights=w, n_q_heads=n_q,
                      k_struct=k_struct.to(_device))
            else:
                q.fit(kf, vf)
            # Codebooks are saved to disk as CPU tensors; move back before returning.
            if _device != "cpu" and hasattr(q, "to"):
                q.to("cpu")
            return i, q

        with ThreadPoolExecutor(max_workers=n_workers) as ex:
            per_layer = dict(ex.map(_fit_layer, layer_ids))
        fitted[name] = per_layer
        print(f"[fit] {name}: {len(per_layer)} layer-codebooks in {time.time()-t0:.1f}s")

    torch.save({"fitted": fitted, "meta": meta}, codebooks_path)
    print(f"[fit] saved -> {codebooks_path} ({len(fitted)} configs total)")
# A drop-in cache that applies a per-layer quantizer to the target layers only.
def make_vq_cache_class(per_layer_quantizers, target_layers, model_config, device=None):
"""Build a VQCache class with codebooks pre-moved to `device`.
Pass device="cuda" (or the model's device) so the roundtrip runs entirely
on GPU with no CPU<->GPU transfers. Without this, every update() call
implicitly transfers key_states to CPU and back, making decode very slow.
"""
from transformers.cache_utils import DynamicCache
if device is not None:
for q in per_layer_quantizers.values():
if hasattr(q, "to"):
q.to(device)
class VQCache(DynamicCache):
def __init__(self, *a, **k):
super().__init__(*a, **k)
self.q = per_layer_quantizers
self.target = set(target_layers)
def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
if layer_idx in self.target and layer_idx in self.q:
q = self.q[layer_idx]
b, h, s, d = key_states.shape
kf = key_states[0].transpose(0, 1).reshape(-1, d).float()
vf = value_states[0].transpose(0, 1).reshape(-1, d).float()
# See stage_cheap: per-channel-key quantizers reduce the key
# along the token axis and need a per-head block, not a flatten.
per_channel_key = (
isinstance(q, KIVIScalarKV)
or getattr(q, "per_channel_key", False)
)
if per_channel_key:
kk = key_states[0].permute(1, 0, 2).float() # (s, h, d)
k_hat = torch.stack([q.roundtrip_k(kk[:, hh, :]) for hh in range(h)], 1)
vv = value_states[0].permute(1, 0, 2).float()
v_hat = torch.stack([q.roundtrip_v(vv[:, hh, :]) for hh in range(h)], 1)
k_hat = k_hat.permute(1, 0, 2)[None]
v_hat = v_hat.permute(1, 0, 2)[None]
else:
k_hat = q.roundtrip_k(kf).reshape(s, h, d).permute(1, 0, 2)[None]
v_hat = q.roundtrip_v(vf).reshape(s, h, d).permute(1, 0, 2)[None]
key_states = k_hat.to(key_states.dtype).to(key_states.device)
value_states = v_hat.to(value_states.dtype).to(value_states.device)
return super().update(key_states, value_states, layer_idx, cache_kwargs)
return VQCache
# STAGE: cheap -- tier-1/2 metrics on held-out trace windows (no test suites)
def stage_cheap(n_eval=64, max_len=16384, seed=0, holdout_start=500):
    """Tier-1/2 proxy metrics for every fitted config on held-out traces.

    For each row in ``[holdout_start, holdout_start + n_eval)`` of
    CALIB_DATASET / CALIB_SPLIT (clamped to the split length):
    1. Run one forward pass (truncated to ``max_len`` tokens).
    2. Capture every full-attention layer's K/V.
    3. Round-trip the K/V through each fitted quantizer.
    4. Accumulate key/value cosine and MSE, plus attention-output and
       inner-product metrics on the last ATTN_WIN tokens against synthetic
       unit-norm queries.

    Per-trace means over layers are averaged across traces and written to
    ARTIFACT_DIR/cheap_metrics.json. Configs missing from codebooks.pt are
    skipped with a message.

    Args:
        n_eval: number of held-out traces.
        max_len: tokenizer truncation length per trace.
        seed: base seed for the synthetic queries, so reruns are reproducible.
        holdout_start: first dataset row of the held-out slice; it must lie
            beyond every row the dump stage consumed.

    Raises:
        ValueError: if ``n_eval > 0`` but the split has no row at or beyond
            ``holdout_start``.
    """
    import collections
    from datasets import load_dataset
    from tqdm import tqdm
    from transformers.cache_utils import DynamicCache

    blob = torch.load(os.path.join(ARTIFACT_DIR, "codebooks.pt"), weights_only=False)
    fitted, meta = blob["fitted"], blob["meta"]
    hd, full = meta["head_dim"], meta["full_layers"]
    full_set = set(full)
    n_q = meta.get("n_q_heads", 48)  # fallback for codebooks.pt written before this field
    # Window for O(T^2) attention metrics. 512 tokens keeps peak mem <200MB on GPU.
    ATTN_WIN = 512
    COLS = ("key_cos", "val_cos", "key_mse", "val_mse",
            "attn_cos", "attn_output_error", "ip_rel", "ip_bias")

    # Score every quantized config that has codebooks. A partial codebooks.pt
    # (e.g. `--stage fit --only ...` on a fresh dir) must not abort the run.
    config_names = []
    for name, _ in cache_configs():
        if name == "fp16 (baseline)":
            continue
        if name not in fitted:
            print(f"[cheap] {name}: not in codebooks.pt, skipping (run --stage fit)")
            continue
        config_names.append(name)

    model, tok, _ = load_model_and_meta()

    # Move every fitted codebook to the model device ONCE. The roundtrip then
    # runs entirely on-GPU against the on-GPU eval caches (see EvalDump below).
    # Quantizers without .to() follow their input.
    for per_layer in fitted.values():
        for q in per_layer.values():
            if hasattr(q, "to"):
                q.to(model.device)

    class EvalDump(DynamicCache):
        """DynamicCache that also keeps each full-attention layer's K/V on-device."""

        def __init__(self):
            super().__init__()
            self.d = {i: {} for i in full}

        def update(self, ks, vs, li, ck=None):
            if li in full_set:
                # Keep on-GPU: the quantizer roundtrip is a nearest-neighbour
                # search -- a GPU job. Moving to CPU here dominated runtime.
                self.d[li]["k"] = ks.detach()[0].permute(1, 0, 2).float()
                self.d[li]["v"] = vs.detach()[0].permute(1, 0, 2).float()
            return super().update(ks, vs, li, ck)

    def _roundtrip_kv(q, k, v):
        """Round-trip (s, h, d) K/V through quantizer ``q``; return (k_hat, v_hat)."""
        s, h, d = k.shape
        # Per-channel-key quantizers (KIVI, Sign, Ternary with
        # per_channel_key=True) reduce the KEY along the TOKEN axis (dim=0).
        # They must see one (s, d) block PER HEAD. Flattening (s,h,d)->(s*h,d)
        # would mix tokens across heads into one scale and corrupt the key
        # metric. Per-token quantizers reduce along dim=-1 and are safe to flatten.
        per_channel_key = (isinstance(q, KIVIScalarKV)
                           or getattr(q, "per_channel_key", False))
        if per_channel_key:
            k_hat = torch.stack([q.roundtrip_k(k[:, hh, :]) for hh in range(h)], 1)
            v_hat = torch.stack([q.roundtrip_v(v[:, hh, :]) for hh in range(h)], 1)
        else:
            k_hat = q.roundtrip_k(k.reshape(-1, d)).reshape(s, h, d)
            v_hat = q.roundtrip_v(v.reshape(-1, d)).reshape(s, h, d)
        return k_hat, v_hat

    ds = load_dataset(CALIB_DATASET, split=CALIB_SPLIT)
    # Held-out slice. The default start of 500 clears rows the dump stage
    # consumed (dump skips short traces, so rows used >> n_calib=16).
    # Clamp the end to the split length instead of indexing past it.
    stop = min(len(ds), holdout_start + n_eval)
    if n_eval > 0 and stop <= holdout_start:
        raise ValueError(
            f"{CALIB_DATASET} split={CALIB_SPLIT} has {len(ds)} rows; "
            f"none at or beyond holdout_start={holdout_start}")
    held_out = ds.select(range(holdout_start, stop))
    if len(held_out) < n_eval:
        print(f"[cheap] WARNING: only {len(held_out)} held-out rows available "
              f"(requested {n_eval})")

    trace_rows = []  # per-(trace, config) rows, aggregated below
    for t_idx, ex in enumerate(tqdm(held_out)):
        text = flatten_trace(ex, tok)
        ids = tok(text, return_tensors="pt", truncation=True,
                  max_length=max_len).to(model.device)
        cache = EvalDump()
        with torch.no_grad():
            # model.model() skips lm_head: avoids a ~3 GB allocation per trace
            # (max_len * vocab_size * 2 bytes) that is not needed for KV metrics.
            model.model(**ids, past_key_values=cache, use_cache=True)

        # Synthetic Q for attention/IP metrics: generated once per (trace, layer)
        # and reused across all configs so the comparison is apples-to-apples.
        # Seeded per (trace, layer) so reruns reproduce. Unit-normalised so
        # inner-product scale doesn't swamp the bias signal.
        synth_q = {}
        for i in full:
            k_i = cache.d[i]["k"]
            win = min(k_i.shape[0], ATTN_WIN)
            gen = torch.Generator(device=k_i.device).manual_seed(seed + 1000 * t_idx + i)
            q_rand = torch.randn(win, n_q, hd, device=k_i.device, generator=gen)
            synth_q[i] = q_rand / q_rand.norm(dim=-1, keepdim=True).clamp_min(1e-8)

        for name in config_names:
            per_layer = fitted[name]
            acc = collections.defaultdict(float)
            for i in full:
                k = cache.d[i]["k"]  # (s, h, d)
                v = cache.d[i]["v"]
                k_hat, v_hat = _roundtrip_kv(per_layer[i], k, v)
                acc["key_cos"] += key_cosine(k, k_hat)
                acc["val_cos"] += key_cosine(v, v_hat)
                acc["key_mse"] += cache_mse(k, k_hat)
                acc["val_mse"] += cache_mse(v, v_hat)
                # Windowed attention/IP metrics on the last ATTN_WIN tokens.
                win = min(k.shape[0], ATTN_WIN)
                kw, kw_hat = k[-win:], k_hat[-win:]
                vw, vw_hat = v[-win:], v_hat[-win:]
                q_syn = synth_q[i]  # (win, n_q, d)
                out_ref, _ = attention_output(q_syn, kw, vw, n_q)
                out_hat, _ = attention_output(q_syn, kw_hat, vw_hat, n_q)
                acc["attn_cos"] += attn_output_cosine(out_ref, out_hat)
                acc["attn_output_error"] += attn_output_error(out_ref, out_hat)
                ip = inner_product_distortion(q_syn, kw, kw_hat)
                acc["ip_rel"] += ip["ip_rel_err"]
                acc["ip_bias"] += ip["ip_bias"]
            n_layers = len(full)
            row = {"trace_len": ids["input_ids"].shape[1], "config": name}
            row.update({col: acc[col] / n_layers for col in COLS})
            trace_rows.append(row)

    # Aggregate across traces: one summary row per config (what gets saved).
    agg = collections.defaultdict(lambda: collections.defaultdict(list))
    for r in trace_rows:
        for col in COLS:
            agg[r["config"]][col].append(r[col])
    summary = []
    for name in config_names:
        if name not in agg:
            continue
        cols = agg[name]
        n = len(cols["key_cos"])
        q0 = next(iter(fitted[name].values()))
        row = {"config": name, "bits_per_elt": round(q0.bits_per_element(hd), 4), "n_traces": n}
        for col in COLS:
            row[col] = round(sum(cols[col]) / n, 5)
        summary.append(row)

    print(f"\n[cheap] mean metrics over {len(held_out)} held-out traces:")
    print(f"  {'config':24s} {'bpe':>5} {'key_cos':>8} {'val_cos':>8} "
          f"{'key_mse':>9} {'val_mse':>9} {'attn_cos':>9} {'attn_err':>9} "
          f"{'ip_rel':>8} {'ip_bias':>9}")
    for row in summary:
        print(f"  {row['config']:24s} {row['bits_per_elt']:5.2f} "
              f"{row['key_cos']:8.4f} {row['val_cos']:8.4f} "
              f"{row['key_mse']:9.5f} {row['val_mse']:9.5f} "
              f"{row['attn_cos']:9.4f} {row['attn_output_error']:9.4f} "
              f"{row['ip_rel']:8.5f} {row['ip_bias']:9.6f}")
    out_path = os.path.join(ARTIFACT_DIR, "cheap_metrics.json")
    with open(out_path, "w") as f:
        json.dump(summary, f, indent=2)
    print(f"[cheap] saved -> {out_path}")
# Optional SWE-bench Verified eval (requires Docker + swebench).
_AGENT_SYSTEM = (
"You are an expert software engineer fixing a GitHub issue. "
"You have a bash shell inside the repository checked out at the failing commit. "
"Use command tags to run shell commands. "
"Explore the code, implement the fix, then output when done."
)
def _agent_loop(model, tok, task: dict, cache_factory, max_turns: int,
max_new: int = 1024) -> str:
"""Run a minimal ReAct-bash loop on one SWE-bench task.
Clones the repo at base_commit into a temp dir, runs the model in a
generate→bash→observe loop, and returns the final `git diff HEAD` patch.
Each generate call rebuilds the full context from scratch so the VQCache
sees the compounding long-context pressure that the project targets.
"""
import re
import shutil
import subprocess
import tempfile
repo_dir = tempfile.mkdtemp(prefix="sweagent_")
try:
subprocess.run(
["git", "clone", f"https://github.com/{task['repo']}.git", repo_dir],
check=True, capture_output=True, timeout=120,
)
subprocess.run(
["git", "checkout", task["base_commit"]],
check=True, capture_output=True, cwd=repo_dir, timeout=30,
)
except Exception as exc:
print(f" [agent] clone/checkout failed for {task['instance_id']}: {exc}")
shutil.rmtree(repo_dir, ignore_errors=True)
return ""
try:
messages = [
{"role": "system", "content": _AGENT_SYSTEM},
{"role": "user", "content": (
f"Repository: {task['repo']}\n\n"
f"Issue:\n{task['problem_statement']}"
)},
]
for _ in range(max_turns):
prompt = tok.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
ids = tok(prompt, return_tensors="pt", truncation=True,
max_length=32768).to(model.device)
cache = cache_factory()
with torch.no_grad():
out = model.generate(
**ids, max_new_tokens=max_new, do_sample=False,
past_key_values=cache, use_cache=True,
)
gen = tok.decode(out[0, ids["input_ids"].shape[1]:],
skip_special_tokens=True)
messages.append({"role": "assistant", "content": gen})
if re.search(r"", gen, re.I):
break
cmds = re.findall(r"(.*?)", gen, re.DOTALL)
if not cmds:
break # model stopped issuing commands; take whatever diff we have
obs_parts = []
for cmd in cmds:
try:
r = subprocess.run(
cmd, shell=True, capture_output=True, text=True,
timeout=30, cwd=repo_dir,
)
obs_parts.append(
f"$ {cmd.strip()}\n{(r.stdout + r.stderr)[:2000]}")
except subprocess.TimeoutExpired:
obs_parts.append(f"$ {cmd.strip()}\n[timeout after 30s]")
messages.append({"role": "user", "content": "\n\n".join(obs_parts)})
diff = subprocess.run(
["git", "diff", "HEAD"],
capture_output=True, text=True, cwd=repo_dir,
)
return diff.stdout
finally:
shutil.rmtree(repo_dir, ignore_errors=True)
def run_swebench_subset(model, tok, cache_factory_per_layer, target_layers,
task_ids, max_turns=100):
"""Run the official SWE-bench Verified harness over `task_ids`.
Builds per-config VQCache (or None for fp16), runs the mini-bash agent on
each task, writes predictions.jsonl, evaluates with the swebench harness,
and returns resolve rate in [0, 1]. Keep `task_ids` FIXED across configs.
"""
import glob
from datasets import load_dataset
# Build cache factory: VQCache for quantized configs, None for fp16.
if cache_factory_per_layer is not None:
VQCls = make_vq_cache_class(cache_factory_per_layer, target_layers,
model.config, device=model.device)
def cache_factory(): return VQCls()
else:
def cache_factory(): return None # model allocates DynamicCache internally
# Load task metadata indexed by instance_id.
verified = load_dataset("princeton-nlp/SWE-bench_Verified", split="test")
tasks = {r["instance_id"]: r for r in verified
if r["instance_id"] in set(task_ids)}
# Generate patches with the agent.
run_id = f"vqkv_{int(time.time())}"
predictions = []
for task_id in task_ids:
task = tasks.get(task_id)
if task is None:
print(f"[swebench] {task_id}: not in Verified dataset, skipping")
continue
print(f"[swebench] {task_id} ({len(predictions)+1}/{len(task_ids)})")
patch = _agent_loop(model, tok, task, cache_factory, max_turns=max_turns)
predictions.append({
"instance_id": task_id,
"model_patch": patch,
"model_name_or_path": "laguna-vqkv",
})
preds_path = os.path.join(ARTIFACT_DIR, f"{run_id}.jsonl")
with open(preds_path, "w") as f:
for p in predictions:
f.write(json.dumps(p) + "\n")
print(f"[swebench] wrote {len(predictions)} predictions -> {preds_path}")
# Run the official harness (needs Docker daemon).
# pip install swebench
from swebench.harness.run_evaluation import main as run_evaluation
run_evaluation(
dataset_name_or_path="princeton-nlp/SWE-bench_Verified",
split="test",
instance_ids=task_ids,
predictions_path=preds_path,
max_workers=4,
force_rebuild=False,
cache_level="env",
clean=False,
open_file_limit=4096,
run_id=run_id,
timeout=1800,
)
# Parse results. swebench writes a JSON summary; location varies by version.
result_files = (
glob.glob(os.path.join(ARTIFACT_DIR, f"{run_id}*.json"))
+ glob.glob(f"{run_id}*.json") # also check cwd
)
if not result_files:
print(f"[swebench] WARNING: no results file found for run_id={run_id}")
return 0.0
with open(result_files[0]) as f:
results = json.load(f)
if isinstance(results, list):
n_resolved = sum(1 for r in results if r.get("resolved", False))
elif isinstance(results, dict):
# some harness versions use {instance_id: {resolved: bool, ...}}
n_resolved = sum(1 for v in results.values()
if (v.get("resolved") if isinstance(v, dict) else v))
else:
n_resolved = 0
return n_resolved / len(task_ids)
def stage_swebench(n_tasks=50, seed=0):
    """Compare SWE-bench Verified resolve rates across cache configs.

    Steps:
    1. Sample a fixed subset of ``n_tasks`` Verified instances with
       ``random.Random(seed)``, clamped to the dataset size, and save it to
       ARTIFACT_DIR/task_subset.json.
    2. Run run_swebench_subset once per config in cache_configs() on that
       same subset. Configs absent from codebooks.pt are skipped with a
       message.
    3. Print each rate and its delta vs the fp16 baseline, and write
       ARTIFACT_DIR/swebench_results.json.
    """
    import random
    from datasets import load_dataset

    # codebooks.pt pickles quantizer objects, which torch.load's weights-only
    # mode (the default since PyTorch 2.6) rejects; stage_cheap loads it the
    # same way.
    blob = torch.load(os.path.join(ARTIFACT_DIR, "codebooks.pt"), weights_only=False)
    fitted, meta = blob["fitted"], blob["meta"]
    model, tok, _ = load_model_and_meta()
    verified = load_dataset("princeton-nlp/SWE-bench_Verified", split="test")
    if n_tasks > len(verified):
        print(f"[swebench] n_tasks={n_tasks} exceeds {len(verified)} Verified tasks; using all")
        n_tasks = len(verified)
    rng = random.Random(seed)
    task_ids = [verified[i]["instance_id"]
                for i in rng.sample(range(len(verified)), n_tasks)]
    with open(os.path.join(ARTIFACT_DIR, "task_subset.json"), "w") as f:
        json.dump(task_ids, f)
    print(f"[swebench] fixed subset of {n_tasks} tasks (seed={seed}) saved.")

    results = []
    for name, _ in cache_configs():
        if name == "fp16 (baseline)":
            quantizers = None
        elif name in fitted:
            quantizers = fitted[name]
        else:
            print(f"[swebench] {name}: not in codebooks.pt, skipping (run --stage fit)")
            continue
        try:
            rate = run_swebench_subset(model, tok, quantizers,
                                       meta["full_layers"], task_ids)
        except NotImplementedError as e:
            print(f"[swebench] {name}: STUB -- {e}")
            rate = None
        results.append({"config": name, "resolve_rate": rate})
        print(f"[swebench] {name}: resolve_rate={rate}")

    # report DELTAS vs fp16 (robust to absolute-score contamination)
    base = next((r["resolve_rate"] for r in results
                 if r["config"] == "fp16 (baseline)"), None)
    print("\n[swebench] resolve rate on fixed subset (delta vs fp16):")
    for r in results:
        d = (None if (r["resolve_rate"] is None or base is None)
             else round(r["resolve_rate"] - base, 4))
        print(f"  {r['config']:24s} {r['resolve_rate']}  (Δ {d})")
    with open(os.path.join(ARTIFACT_DIR, "swebench_results.json"), "w") as f:
        json.dump(results, f, indent=2)
def main():
    """Command-line entry point: parse flags and dispatch to a single stage."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--stage", required=True,
                        choices=["dump", "fit", "cheap", "swebench"])
    parser.add_argument("--n_calib", type=int, default=16)
    parser.add_argument(
        "--calib_source", type=str, default=None,
        choices=["swesmith", "longbench-hotpotqa"],
        help="dump stage: calibration corpus (default: CALIB_SOURCE env or swesmith)")
    parser.add_argument("--n_eval", type=int, default=64)
    parser.add_argument("--n_tasks", type=int, default=50)
    parser.add_argument(
        "--only", type=str, default=None,
        help="fit stage: comma-separated config names to fit/merge (e.g. "
             "sign-1bit,ternary-bitnet). Skips refitting other configs.")
    opts = parser.parse_args()

    def _run_fit():
        # "--only a, b" -> ["a", "b"]; absent/empty -> fit everything.
        names = None
        if opts.only:
            names = [part.strip() for part in opts.only.split(",")]
        stage_fit(only=names)

    dispatch = {
        "dump": lambda: stage_dump(n_calib=opts.n_calib,
                                   calib_source=opts.calib_source),
        "fit": _run_fit,
        "cheap": lambda: stage_cheap(n_eval=opts.n_eval),
        "swebench": lambda: stage_swebench(n_tasks=opts.n_tasks),
    }
    # argparse `choices` guarantees the key exists.
    dispatch[opts.stage]()


if __name__ == "__main__":
    main()