| """ |
| benchmark.py — fit AttnVQ codebooks and evaluate on real Laguna-XS.2 caches. |
| |
| Stages: |
| dump capture post-RoPE K/V from full-attention layers |
| fit per-layer LBG codebooks → artifacts/codebooks.pt |
| cheap proxy metrics (key cosine, attn-output error, ip-bias) |
| swebench optional resolve-rate delta on SWE-bench Verified (needs Docker) |
| |
| Usage: |
| python benchmark.py --stage fit |
| python benchmark.py --stage cheap --n_eval 64 |
| python benchmark.py --stage dump |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import time |
| from concurrent.futures import ThreadPoolExecutor |
| from dataclasses import dataclass, asdict |
|
|
| import torch |
|
|
| from vqkv.quantizers import (ScalarKV, KIVIScalarKV, ProductVQKV, RoPESplitVQKV, |
| SignScalarKV, TernaryScalarKV) |
| from vqkv.metrics import (key_cosine, cache_mse, inner_product_distortion, |
| attention_output, attn_output_cosine, attn_output_error, |
| calibration_sample_weights) |
|
|
# Checkpoint to benchmark; override with the LAGUNA_ID environment variable.
MODEL_ID = os.getenv("LAGUNA_ID", "poolside/Laguna-XS.2")

# Directory that every stage reads from / writes to. It is created at import
# time so the stages can assume it exists.
ARTIFACT_DIR = os.getenv("VQKV_ARTIFACTS", "./artifacts")
os.makedirs(ARTIFACT_DIR, exist_ok=True)

# Calibration corpus selection; each value can be overridden via the
# environment variable of the same name.
CALIB_DATASET = os.getenv("CALIB_DATASET", "SWE-bench/SWE-smith-trajectories")
CALIB_SPLIT = os.getenv("CALIB_SPLIT", "tool")
CALIB_SOURCE = os.getenv("CALIB_SOURCE", "swesmith")

# Prompt template filled in by flatten_longbench(); the instruction sentence
# appears both before the passages and again before the question.
_HOTPOTQA_PROMPT = "".join([
    "Answer the question based on the given passages. "
    "Only give me the answer and do not output any other words.\n\n",
    "The following are given passages.\n{context}\n\n",
    "Answer the question based on the given passages. "
    "Only give me the answer and do not output any other words.\n\n",
    "Question: {input}\nAnswer:",
])
|
|
| |
| |
| |
| |
def cache_configs():
    """Return the benchmark's cache configurations as ``(name, factory)`` pairs.

    ``factory`` is a zero-argument callable that builds a fresh quantizer, or
    ``None`` for the unquantized ``"fp16 (baseline)"`` entry. The stages
    iterate this list in order, so the order here is the reporting order.
    """
    configs = [("fp16 (baseline)", None)]

    # Scalar quantizers.
    configs.extend([
        ("scalar-int4", lambda: ScalarKV(nbits=4)),
        ("scalar-int2", lambda: ScalarKV(nbits=2)),
        ("kivi-int2", lambda: KIVIScalarKV(nbits=2)),
    ])

    # Product VQ with a decreasing number of sub-vectors.
    configs.extend([
        ("productvq-64x256-4b", lambda: ProductVQKV(n_sub=64, n_codes=256, iters=15)),
        ("productvq-32x256-2b", lambda: ProductVQKV(n_sub=32, n_codes=256, iters=15)),
        ("productvq-16x256-1b", lambda: ProductVQKV(n_sub=16, n_codes=256, iters=15)),
        ("productvq-8x256-0.5b", lambda: ProductVQKV(n_sub=8, n_codes=256, iters=15)),
    ])

    # RoPE-split VQ and the sign / ternary scalar variants.
    configs.extend([
        ("ropesplit-1b", lambda: RoPESplitVQKV(n_sub_half=8, n_codes=256, iters=15)),
        ("sign-1bit", lambda: SignScalarKV(per_channel_key=True)),
        ("ternary-bitnet", lambda: TernaryScalarKV(alpha=0.7, per_channel_key=True)),
    ])
    return configs
|
|
|
|
| |
| |
| |
def load_model_and_meta():
    """Load the tokenizer and model named by ``MODEL_ID`` and collect cache metadata.

    The model is loaded in bfloat16 with ``device_map="cuda"`` and put in eval
    mode.

    Returns:
        ``(model, tok, meta)``. ``meta`` holds ``full_layers`` (indices whose
        ``config.layer_types`` entry equals ``"full_attention"``) plus
        ``n_kv_heads``, ``n_q_heads``, ``head_dim`` and ``n_layers`` copied
        from the model config.

    Raises:
        ValueError: if the config lists no full-attention layer.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, device_map="cuda",
        trust_remote_code=True)
    model.eval()

    cfg = model.config
    full_layers = [i for i, t in enumerate(cfg.layer_types) if t == "full_attention"]
    meta = {
        "full_layers": full_layers,
        "n_kv_heads": cfg.num_key_value_heads,
        "n_q_heads": cfg.num_attention_heads,
        "head_dim": cfg.head_dim,
        "n_layers": cfg.num_hidden_layers,
    }
    print(f"[meta] full-attention layers ({len(full_layers)}): {full_layers}")
    print(f"[meta] kv_heads={meta['n_kv_heads']} head_dim={meta['head_dim']}")

    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would let every later stage run with nothing to quantize.
    if not full_layers:
        raise ValueError("no full-attention layers found; check config")
    return model, tok, meta
|
|
|
|
| |
def flatten_trace(example, tok) -> str:
    """Flatten one trajectory to a string via the model's chat template.

    The message list is taken from the first truthy key among ``messages``,
    ``trajectory`` and ``conversations`` and may be a list or a JSON-encoded
    string. Each message is normalised to ``{"role", "content"}``:

    * role comes from ``role`` / ``from`` (default ``"user"``); ``human`` and
      ``tool`` map to ``user``, ``gpt`` maps to ``assistant``;
    * content comes from ``content`` / ``value``; list content is joined with
      newlines, any other non-string content is JSON-encoded;
    * ``tool_calls`` are appended to the content as JSON;
    * messages with blank content are dropped and consecutive messages with
      the same role are merged.

    Args:
        example: one dataset row (a mapping).
        tok: tokenizer exposing ``apply_chat_template``.

    Returns:
        The chat-templated text. If templating raises, a plain
        ``"role: content"`` rendering is returned instead. Rows without a
        usable message list fall back to (at most 200,000 characters of) the
        raw string or the JSON dump of the row.
    """
    max_chars = 200_000
    role_alias = {"human": "user", "gpt": "assistant", "tool": "user"}

    raw = (example.get("messages") or example.get("trajectory")
           or example.get("conversations"))
    if raw is None:
        return json.dumps(example)[:max_chars]

    if isinstance(raw, str):
        try:
            decoded = json.loads(raw)
        except json.JSONDecodeError:
            return raw[:max_chars]
        if not isinstance(decoded, (list, tuple)):
            # Valid JSON that is not a message list (bare number, object, ...):
            # treat it like undecodable text instead of crashing in the loop.
            return raw[:max_chars]
        raw = decoded
    elif isinstance(raw, dict):
        # A mapping is not a message sequence; iterating it would yield keys.
        return json.dumps(example)[:max_chars]

    norm = []
    for m in raw:
        if not isinstance(m, dict):
            # Bare entries carry no role; keep their payload as a user turn.
            m = {"content": m}
        role = m.get("role") or m.get("from") or "user"
        content = m.get("content") or m.get("value") or ""

        if isinstance(content, list):
            content = "\n".join(
                str(item.get("text", item)) if isinstance(item, dict) else str(item)
                for item in content
            )
        elif not isinstance(content, str):
            # Structured content (dict, number, ...) would break `.strip()` below.
            content = json.dumps(content, ensure_ascii=False, default=str)

        tool_calls = m.get("tool_calls")
        if tool_calls:
            tc_text = json.dumps(tool_calls, ensure_ascii=False)
            content = (content + "\n" + tc_text).strip() if content else tc_text

        role = role_alias.get(role, role)

        if not content.strip():
            continue
        norm.append({"role": role, "content": content})

    # Merge consecutive same-role messages (e.g. a tool result remapped to
    # "user" directly after a user turn) into one message.
    merged: list[dict] = []
    for m in norm:
        if merged and merged[-1]["role"] == m["role"]:
            merged[-1]["content"] += "\n\n" + m["content"]
        else:
            merged.append(dict(m))

    try:
        return tok.apply_chat_template(merged, tokenize=False,
                                       add_generation_prompt=False)
    except Exception:
        # Deliberate best-effort: any templating failure degrades to plain text.
        return "\n\n".join(f"{m['role']}: {m['content']}" for m in merged)
|
|
|
|
def flatten_longbench(example) -> str:
    """Render one LongBench hotpotqa row with the module's hotpotqa prompt.

    Reads ``example["context"]`` and ``example["input"]`` (a missing key
    raises ``KeyError``) and substitutes them into ``_HOTPOTQA_PROMPT``.
    """
    passages = example["context"]
    question = example["input"]
    return _HOTPOTQA_PROMPT.format(context=passages, input=question)
|
|
|
|
def _load_longbench_hotpotqa():
    """Load THUDM/LongBench hotpotqa, bypassing the deprecated dataset script.

    Tries the raw JSONL files on the Hub first (``hotpotqa_e.jsonl``, then
    ``hotpotqa.jsonl``) through the generic ``json`` loader. If neither loads,
    falls back to the scripted ``THUDM/LongBench`` dataset; an exception from
    that last attempt propagates to the caller.
    """
    from datasets import load_dataset as _ld

    for fname in ("hotpotqa_e.jsonl", "hotpotqa.jsonl"):
        try:
            return _ld(
                "json",
                data_files=f"hf://datasets/THUDM/LongBench/data/{fname}",
                split="train",
            )
        except Exception as exc:
            # Still best-effort, but say why we moved on so a network or auth
            # problem is not mistaken for a missing file.
            print(f"[dump] LongBench {fname} unavailable "
                  f"({type(exc).__name__}: {exc}); trying next source")
    return _ld("THUDM/LongBench", name="hotpotqa", split="test")
|
|
|
|
def _load_calib_dump_dataset(n_calib: int, calib_source: str, tok):
    """Return (dataset, text_fn, label) for stage_dump.

    Args:
        n_calib: number of rows wanted. Only the LongBench branch uses it (it
            keeps the last ``n_calib`` rows); the default branch returns the
            whole split.
        calib_source: ``"longbench-hotpotqa"`` selects LongBench hotpotqa; any
            other value loads ``CALIB_DATASET`` / ``CALIB_SPLIT``.
        tok: tokenizer passed on to ``flatten_trace`` for chat templating.

    Returns:
        ``(dataset, text_fn, label)`` where ``text_fn(example)`` produces the
        text to tokenize and ``label`` describes the source for logging.
    """
    from datasets import load_dataset

    if calib_source == "longbench-hotpotqa":
        ds = _load_longbench_hotpotqa()
        n_total = len(ds)
        # Keep the tail of the dataset: rows [n_total - n_calib, n_total).
        start = max(0, n_total - n_calib)
        ds = ds.select(range(start, n_total))
        label = f"LongBench hotpotqa (rows {start}–{n_total - 1})"
        return ds, flatten_longbench, label

    if calib_source != "swesmith":
        # The CLI restricts --calib_source, but the CALIB_SOURCE environment
        # variable is free-form; make a typo visible instead of silently
        # calibrating on a different corpus than the user asked for.
        print(f"[dump] WARNING: unknown calib_source={calib_source!r}; "
              f"falling back to {CALIB_DATASET}")

    ds = load_dataset(CALIB_DATASET, split=CALIB_SPLIT)
    label = f"{CALIB_DATASET} split={CALIB_SPLIT}"
    return ds, lambda ex: flatten_trace(ex, tok), label
|
|
|
|
| |
def stage_dump(n_calib=16, max_len=32768, calib_source: str | None = None,
               min_len=2048):
    """Record K/V states of the full-attention layers on calibration text.

    Runs the base model over calibration traces with a recording cache, keeps
    float32 CPU copies of each tracked layer's key/value states, concatenates
    them across traces and saves ``{"calib", "meta"}`` to
    ``ARTIFACT_DIR/calib_caches.pt``.

    Args:
        n_calib: number of traces to record.
        max_len: tokenizer truncation length.
        calib_source: corpus selector; ``CALIB_SOURCE`` is used when falsy.
        min_len: traces that tokenize to fewer tokens than this are skipped.

    Raises:
        RuntimeError: if no trace reached ``min_len`` tokens.
    """
    from transformers.cache_utils import DynamicCache

    calib_source = calib_source or CALIB_SOURCE
    model, tok, meta = load_model_and_meta()
    layer_order = list(meta["full_layers"])  # config order, for stable iteration
    full = set(layer_order)                  # O(1) membership inside update()

    input_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    class DumpingCache(DynamicCache):
        """DynamicCache that also keeps CPU float32 copies of tracked layers' K/V."""

        def __init__(self, *a, **k):
            super().__init__(*a, **k)
            self.dump = {i: {"k": [], "v": []} for i in full}

        def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
            if layer_idx in full:
                # Batch element 0 only; permute(1, 0, 2) swaps its first two
                # axes (presumably (heads, seq, dim) -> (seq, heads, dim) —
                # TODO confirm against the model's cache layout).
                self.dump[layer_idx]["k"].append(
                    key_states.detach()[0].permute(1, 0, 2).float().cpu())
                self.dump[layer_idx]["v"].append(
                    value_states.detach()[0].permute(1, 0, 2).float().cpu())
            return super().update(key_states, value_states, layer_idx, cache_kwargs)

    ds, get_text, source_label = _load_calib_dump_dataset(n_calib, calib_source, tok)
    schema = list(ds[0].keys()) if len(ds) else []  # ds[0] raises on an empty set
    print(f"[dump] source={calib_source} {source_label} rows={len(ds)} "
          f"schema={schema}")

    agg = {i: {"k": [], "v": []} for i in layer_order}
    used = 0
    for ex in ds:
        if used >= n_calib:
            break
        text = get_text(ex)
        ids = tok(text, return_tensors="pt", truncation=True,
                  max_length=max_len).to(input_device)
        if ids["input_ids"].shape[1] < min_len:
            continue
        cache = DumpingCache(config=model.config)
        with torch.no_grad():
            model.model(**ids, past_key_values=cache, use_cache=True)
        for i in layer_order:
            agg[i]["k"].append(torch.cat(cache.dump[i]["k"]))
            agg[i]["v"].append(torch.cat(cache.dump[i]["v"]))
        used += 1
        print(f"[dump] trace {used}/{n_calib} len={ids['input_ids'].shape[1]}")

    if used == 0:
        # Without this, torch.cat([]) below fails with an error that says
        # nothing about the real cause.
        raise RuntimeError(
            f"[dump] no calibration trace reached min_len={min_len} tokens "
            f"(source={calib_source}, rows={len(ds)}, n_calib={n_calib})")
    if used < n_calib:
        print(f"[dump] WARNING: only {used}/{n_calib} traces met "
              f"min_len={min_len}")

    calib = {i: {"k": torch.cat(agg[i]["k"]), "v": torch.cat(agg[i]["v"])}
             for i in layer_order}

    # Quick statistics for the first three tracked layers, in config order.
    for i in layer_order[:3]:
        k = calib[i]["k"].reshape(-1, meta["head_dim"])
        cut = meta["head_dim"] // 2
        print(f"[dump] layer {i}: {tuple(calib[i]['k'].shape)} | "
              f"rope-half std {k[:, :cut].std():.3f} "
              f"pass-half std {k[:, cut:].std():.3f} "
              f"max|ch| {k.abs().amax(0).max():.2f}")

    path = os.path.join(ARTIFACT_DIR, "calib_caches.pt")
    torch.save({"calib": calib, "meta": meta}, path)
    print(f"[dump] saved -> {path}")
|
|
|
|
| |
def stage_fit(only: list[str] | None = None):
    """Fit quantizers and write artifacts/codebooks.pt.

    By default fits every config in cache_configs(). Pass ``only=["sign-1bit", ...]``
    to fit just those names and merge into an existing codebooks.pt (skips the rest).
    Tuning-free quantizers (Sign, Ternary, Scalar, KIVI) are fitted on a zero
    placeholder when no calibration rows exist for a layer; ProductVQ /
    RoPE-split configs always require calib_caches.pt.

    Args:
        only: config names to (re)fit, or ``None`` for all of them.

    Raises:
        ValueError: ``only`` names an unknown config, or the calibration file
            lacks a layer that a ProductVQ / RoPE-split config needs.
        FileNotFoundError: calib_caches.pt is needed but does not exist.
    """
    codebooks_path = os.path.join(ARTIFACT_DIR, "codebooks.pt")
    calib_path = os.path.join(ARTIFACT_DIR, "calib_caches.pt")

    # Resolve the work list first so a typo in --only fails before any large
    # file is read.
    configs = [(n, f) for n, f in cache_configs() if f is not None]
    if only is not None:
        only_set = set(only)
        configs = [(n, f) for n, f in configs if n in only_set]
        unknown = only_set - {n for n, _ in configs}
        if unknown:
            raise ValueError(f"unknown --only config(s): {sorted(unknown)}")

    # Read the calibration file once; it supplies both the data and, for a
    # fresh fit, the metadata.
    calib_blob = torch.load(calib_path) if os.path.exists(calib_path) else None
    calib = None if calib_blob is None else calib_blob["calib"]

    if only and os.path.exists(codebooks_path):
        # NOTE: weights_only=False unpickles arbitrary objects (the fitted
        # quantizers); only load codebook files produced by this script.
        existing = torch.load(codebooks_path, weights_only=False)
        fitted = existing["fitted"]
        meta = existing["meta"]
        print(f"[fit] merging into existing {codebooks_path} ({len(fitted)} configs)")
    else:
        if calib_blob is None:
            raise FileNotFoundError(
                f"{calib_path} not found; run --stage dump first, or use "
                f"--only with an existing codebooks.pt for tuning-free configs")
        meta = calib_blob["meta"]
        fitted = {}

    hd = meta["head_dim"]
    layer_ids = meta["full_layers"]

    # Codebook quantizers learn from data. Fitting them on the zero
    # placeholder would silently store useless codebooks, so refuse up front.
    vq_types = (ProductVQKV, RoPESplitVQKV)
    vq_names = [n for n, f in configs if isinstance(f(), vq_types)]
    if vq_names:
        if calib is None:
            raise FileNotFoundError(
                f"{calib_path} not found; {vq_names} need calibration data "
                f"-- run --stage dump first")
        uncovered = [i for i in layer_ids if i not in calib]
        if uncovered:
            raise ValueError(
                f"{calib_path} has no calibration data for layers {uncovered}, "
                f"required by {vq_names}")

    fit_device = "cuda" if torch.cuda.is_available() else "cpu"
    if fit_device == "cuda":
        print(f"[fit] GPU available ({torch.cuda.get_device_name()}) -- fitting LBG on CUDA")
    # One worker on GPU; on CPU one thread per layer up to the core count
    # (never zero, which ThreadPoolExecutor rejects).
    n_workers = (1 if fit_device == "cuda"
                 else max(1, min(len(layer_ids), os.cpu_count() or 1)))

    if vq_names:
        print("[fit] ProductVQ / RoPESplit: attention-weighted LBG "
              "(centroids weighted by key attention mass)")

    n_q = meta.get("n_q_heads", 48)
    for name, factory in configs:
        t0 = time.time()

        def _fit_layer(i, _factory=factory, _device=fit_device):
            """Fit one quantizer for layer ``i`` and return ``(i, quantizer)``."""
            if calib is not None and i in calib:
                k_struct = calib[i]["k"]
                v_struct = calib[i]["v"]
                # NOTE(review): only the last 512 entries along dim 0 are
                # kept, and kf/vf below are taken from that same window —
                # confirm this is the intended calibration size.
                if k_struct.shape[0] > 512:
                    k_struct = k_struct[-512:]
                    v_struct = v_struct[-512:]
                w = calibration_sample_weights(k_struct, n_q)
                kf = k_struct.reshape(-1, hd)[:200_000].to(_device)
                vf = v_struct.reshape(-1, hd)[:200_000].to(_device)
                if w is not None:
                    w = w[:kf.shape[0]].to(_device)
            else:
                # Placeholder input; only reached for non-codebook quantizers.
                kf = torch.zeros(1, hd, device=_device)
                vf = torch.zeros(1, hd, device=_device)
                w = k_struct = None

            q = _factory()
            if isinstance(q, vq_types):
                q.fit(kf, vf, sample_weights=w, n_q_heads=n_q,
                      k_struct=k_struct.to(_device) if k_struct is not None else None)
            else:
                q.fit(kf, vf)

            # Move the fitted quantizer back to CPU before it is saved.
            if _device != "cpu" and hasattr(q, "to"):
                q.to("cpu")
            return i, q

        with ThreadPoolExecutor(max_workers=n_workers) as ex:
            per_layer = dict(ex.map(_fit_layer, layer_ids))

        fitted[name] = per_layer
        print(f"[fit] {name}: {len(per_layer)} layer-codebooks in {time.time()-t0:.1f}s")

    torch.save({"fitted": fitted, "meta": meta}, codebooks_path)
    print(f"[fit] saved -> {codebooks_path} ({len(fitted)} configs total)")
|
|
|
|
| |
def make_vq_cache_class(per_layer_quantizers, target_layers, model_config, device=None):
    """Build a VQCache class with codebooks pre-moved to `device`.

    Pass device="cuda" (or the model's device) so the roundtrip runs entirely
    on GPU with no CPU<->GPU transfers.

    Args:
        per_layer_quantizers: mapping ``layer_idx -> quantizer`` exposing
            ``roundtrip_k`` / ``roundtrip_v``. Quantizers are moved in place
            when ``device`` is given.
        target_layers: layer indices whose K/V are quantized on update.
        model_config: accepted for interface compatibility; not used here.
        device: optional device passed to each quantizer's ``to``.

    Returns:
        A ``DynamicCache`` subclass whose ``update`` replaces the key/value
        states of targeted layers with their quantize->dequantize
        reconstruction before storing them.
    """
    from transformers.cache_utils import DynamicCache

    if device is not None:
        for q in per_layer_quantizers.values():
            if hasattr(q, "to"):
                q.to(device)

    def _roundtrip(q, keys, values):
        """Reconstruct one batch element; inputs and outputs are (h, s, d)."""
        h, s, d = keys.shape
        per_channel_key = (
            isinstance(q, KIVIScalarKV)
            or getattr(q, "per_channel_key", False)
        )
        if per_channel_key:
            # These quantizers get one (s, d) slice per head.
            kk = keys.permute(1, 0, 2).float()
            k_hat = torch.stack([q.roundtrip_k(kk[:, hh, :]) for hh in range(h)], 1)
            vv = values.permute(1, 0, 2).float()
            v_hat = torch.stack([q.roundtrip_v(vv[:, hh, :]) for hh in range(h)], 1)
            return k_hat.permute(1, 0, 2), v_hat.permute(1, 0, 2)
        # Everything else sees the flattened (s * h, d) rows; flattening only
        # on this path avoids an unused float copy for per-channel quantizers.
        kf = keys.transpose(0, 1).reshape(-1, d).float()
        vf = values.transpose(0, 1).reshape(-1, d).float()
        k_hat = q.roundtrip_k(kf).reshape(s, h, d).permute(1, 0, 2)
        v_hat = q.roundtrip_v(vf).reshape(s, h, d).permute(1, 0, 2)
        return k_hat, v_hat

    class VQCache(DynamicCache):
        """DynamicCache that stores quantization-roundtripped K/V for target layers."""

        def __init__(self, *a, **k):
            super().__init__(*a, **k)
            self.q = per_layer_quantizers
            self.target = set(target_layers)

        def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
            if layer_idx in self.target and layer_idx in self.q:
                q = self.q[layer_idx]
                # key_states is unpacked as (b, h, s, d). Every batch element
                # is reconstructed independently so batch > 1 keeps its shape
                # (previously only element 0 was kept).
                pairs = [_roundtrip(q, key_states[bi], value_states[bi])
                         for bi in range(key_states.shape[0])]
                k_hat = torch.stack([p[0] for p in pairs])
                v_hat = torch.stack([p[1] for p in pairs])
                key_states = k_hat.to(key_states.dtype).to(key_states.device)
                value_states = v_hat.to(value_states.dtype).to(value_states.device)
            return super().update(key_states, value_states, layer_idx, cache_kwargs)

    return VQCache
|
|
|
|
| |
def stage_cheap(n_eval=64, max_len=16384):
    """Compute proxy metrics for every fitted config and save a summary.

    For each evaluation trace the base model is run once to capture the K/V
    states of the full-attention layers. Every fitted quantizer then
    reconstructs them and is scored on key/value cosine and MSE,
    attention-output cosine/error and inner-product distortion (the attention
    metrics use random unit-norm queries over the last ``ATTN_WIN``
    positions). Per-trace scores are averaged over layers, then over traces,
    printed as a table and written to ``ARTIFACT_DIR/cheap_metrics.json``.

    Args:
        n_eval: number of traces, taken from rows ``[500, 500 + n_eval)`` of
            ``CALIB_DATASET`` / ``CALIB_SPLIT``.
        max_len: tokenizer truncation length.
    """
    import collections
    from datasets import load_dataset

    # NOTE: weights_only=False unpickles arbitrary objects (the fitted
    # quantizers); only load codebook files produced by this script.
    blob = torch.load(os.path.join(ARTIFACT_DIR, "codebooks.pt"), weights_only=False)
    fitted, meta = blob["fitted"], blob["meta"]
    hd, full = meta["head_dim"], meta["full_layers"]
    full_set = set(full)  # built once; update() runs for every layer call
    n_q = meta.get("n_q_heads", 48)

    # Number of trailing positions used for the attention-based metrics.
    ATTN_WIN = 512
    COLS = ("key_cos", "val_cos", "key_mse", "val_mse",
            "attn_cos", "attn_output_error", "ip_rel", "ip_bias")

    model, tok, _ = load_model_and_meta()
    from tqdm import tqdm
    from transformers.cache_utils import DynamicCache

    # Move all quantizers to the model's device once, up front.
    for per_layer in fitted.values():
        for q in per_layer.values():
            if hasattr(q, "to"):
                q.to(model.device)

    # A partial fit (stage_fit with `only`) may leave configs without
    # codebooks; skip those instead of failing with KeyError mid-run.
    eval_names = []
    for name, _ in cache_configs():
        if name == "fp16 (baseline)":
            continue
        if name not in fitted:
            print(f"[cheap] WARNING: no fitted quantizers for {name!r}; skipping")
            continue
        eval_names.append(name)

    class EvalDump(DynamicCache):
        """DynamicCache that keeps the latest float32 K/V of each tracked layer."""

        def __init__(self):
            super().__init__()
            self.d = {i: {} for i in full}

        def update(self, ks, vs, li, ck=None):
            if li in full_set:
                # Batch element 0, first two axes swapped, kept on-device.
                self.d[li]["k"] = ks.detach()[0].permute(1, 0, 2).float()
                self.d[li]["v"] = vs.detach()[0].permute(1, 0, 2).float()
            return super().update(ks, vs, li, ck)

    ds = load_dataset(CALIB_DATASET, split=CALIB_SPLIT)

    trace_rows = []
    for ex in tqdm(ds.select(range(500, 500 + n_eval))):
        text = flatten_trace(ex, tok)
        ids = tok(text, return_tensors="pt", truncation=True,
                  max_length=max_len).to(model.device)
        cache = EvalDump()
        with torch.no_grad():
            model.model(**ids, past_key_values=cache, use_cache=True)

        # One set of random unit-norm queries per layer, shared by all
        # configs so they are scored against identical queries on this trace.
        synth_q = {}
        for i in full:
            s = cache.d[i]["k"].shape[0]
            win = min(s, ATTN_WIN)
            q_rand = torch.randn(win, n_q, hd, device=cache.d[i]["k"].device)
            synth_q[i] = q_rand / q_rand.norm(dim=-1, keepdim=True).clamp_min(1e-8)

        for name in eval_names:
            per_layer = fitted[name]
            acc = collections.defaultdict(float)
            nL = 0
            for i in full:
                k = cache.d[i]["k"]
                v = cache.d[i]["v"]
                q = per_layer[i]
                s, h, d = k.shape

                per_channel_key = (
                    isinstance(q, KIVIScalarKV)
                    or getattr(q, "per_channel_key", False)
                )
                if per_channel_key:
                    # One (s, d) slice per head, as in the VQCache path.
                    k_hat = torch.stack([q.roundtrip_k(k[:, hh, :]) for hh in range(h)], 1)
                    v_hat = torch.stack([q.roundtrip_v(v[:, hh, :]) for hh in range(h)], 1)
                else:
                    k_hat = q.roundtrip_k(k.reshape(-1, d)).reshape(s, h, d)
                    v_hat = q.roundtrip_v(v.reshape(-1, d)).reshape(s, h, d)

                acc["key_cos"] += key_cosine(k, k_hat)
                acc["val_cos"] += key_cosine(v, v_hat)
                acc["key_mse"] += cache_mse(k, k_hat)
                acc["val_mse"] += cache_mse(v, v_hat)

                # Attention metrics on the trailing window only.
                win = min(s, ATTN_WIN)
                kw, kw_hat = k[-win:], k_hat[-win:]
                vw, vw_hat = v[-win:], v_hat[-win:]
                q_syn = synth_q[i]

                out_ref, _ = attention_output(q_syn, kw, vw, n_q)
                out_hat, _ = attention_output(q_syn, kw_hat, vw_hat, n_q)
                acc["attn_cos"] += attn_output_cosine(out_ref, out_hat)
                acc["attn_output_error"] += attn_output_error(out_ref, out_hat)

                ip = inner_product_distortion(q_syn, kw, kw_hat)
                acc["ip_rel"] += ip["ip_rel_err"]
                acc["ip_bias"] += ip["ip_bias"]

                nL += 1

            row = {"trace_len": ids["input_ids"].shape[1], "config": name}
            row.update({col: acc[col] / nL for col in COLS})
            trace_rows.append(row)

    # Group per-trace rows by config, then average each column.
    agg = collections.defaultdict(lambda: collections.defaultdict(list))
    for r in trace_rows:
        for col in COLS:
            agg[r["config"]][col].append(r[col])

    summary = []
    for name, _ in cache_configs():
        if name not in agg:
            continue
        cols = agg[name]
        n = len(cols["key_cos"])
        q0 = next(iter(fitted[name].values()))
        row = {"config": name,
               "bits_per_elt": round(q0.bits_per_element(hd), 4),
               "n_traces": n}
        for col in COLS:
            row[col] = round(sum(cols[col]) / n, 5)
        summary.append(row)

    print(f"\n[cheap] mean metrics over {n_eval} held-out traces:")
    print(f"  {'config':24s} {'bpe':>5} {'key_cos':>8} {'val_cos':>8} "
          f"{'key_mse':>9} {'val_mse':>9} {'attn_cos':>9} {'attn_err':>9} "
          f"{'ip_rel':>8} {'ip_bias':>9}")
    for row in summary:
        print(f"  {row['config']:24s} {row['bits_per_elt']:5.2f} "
              f"{row['key_cos']:8.4f} {row['val_cos']:8.4f} "
              f"{row['key_mse']:9.5f} {row['val_mse']:9.5f} "
              f"{row['attn_cos']:9.4f} {row['attn_output_error']:9.4f} "
              f"{row['ip_rel']:8.5f} {row['ip_bias']:9.6f}")

    out_path = os.path.join(ARTIFACT_DIR, "cheap_metrics.json")
    with open(out_path, "w") as f:  # context manager closes/flushes the file
        json.dump(summary, f, indent=2)
    print(f"[cheap] saved -> {out_path}")
|
|
|
|
| |
# System prompt for every agent episode. It introduces the <bash>…</bash> and
# <submit> markers that _agent_loop searches for in the model's output.
_AGENT_SYSTEM = " ".join((
    "You are an expert software engineer fixing a GitHub issue.",
    "You have a bash shell inside the repository checked out at the failing commit.",
    "Use <bash>command</bash> tags to run shell commands.",
    "Explore the code, implement the fix, then output <submit> when done.",
))
|
|
|
|
def _agent_loop(model, tok, task: dict, cache_factory, max_turns: int,
                max_new: int = 1024) -> str:
    """Run a minimal ReAct-bash loop on one SWE-bench task.

    Clones the repo at base_commit into a temp dir, runs the model in a
    generate→bash→observe loop, and returns the resulting patch. Each generate
    call re-tokenizes the whole conversation and starts from a fresh cache
    from ``cache_factory()``.

    Args:
        model, tok: generation model and its tokenizer.
        task: row with ``repo``, ``base_commit``, ``problem_statement`` and
            ``instance_id``.
        cache_factory: zero-argument callable returning the ``past_key_values``
            object for one generate call (may return ``None``).
        max_turns: maximum number of generate calls.
        max_new: ``max_new_tokens`` per generate call.

    Returns:
        The diff of the working tree against ``base_commit``, including files
        the agent created or committed; ``""`` if clone/checkout failed.

    NOTE (security): model-written commands run through the host shell with
    this process's privileges. Only run this inside a disposable sandbox.
    """
    import re
    import shutil
    import subprocess
    import tempfile

    submit_re = re.compile(r"<submit\s*/?>", re.I)
    bash_re = re.compile(r"<bash>(.*?)</bash>", re.DOTALL)

    repo_dir = tempfile.mkdtemp(prefix="sweagent_")
    try:
        subprocess.run(
            ["git", "clone", f"https://github.com/{task['repo']}.git", repo_dir],
            check=True, capture_output=True, timeout=120,
        )
        subprocess.run(
            ["git", "checkout", task["base_commit"]],
            check=True, capture_output=True, cwd=repo_dir, timeout=30,
        )
    except Exception as exc:
        print(f"  [agent] clone/checkout failed for {task['instance_id']}: {exc}")
        shutil.rmtree(repo_dir, ignore_errors=True)
        return ""

    try:
        messages = [
            {"role": "system", "content": _AGENT_SYSTEM},
            {"role": "user", "content": (
                f"Repository: {task['repo']}\n\n"
                f"Issue:\n{task['problem_statement']}"
            )},
        ]

        for _ in range(max_turns):
            prompt = tok.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True)
            # NOTE(review): truncation follows the tokenizer's configured
            # truncation side; if that is the right side, over-long prompts
            # lose their most recent turns — confirm.
            ids = tok(prompt, return_tensors="pt", truncation=True,
                      max_length=32768).to(model.device)

            cache = cache_factory()
            with torch.no_grad():
                out = model.generate(
                    **ids, max_new_tokens=max_new, do_sample=False,
                    past_key_values=cache, use_cache=True,
                )
            gen = tok.decode(out[0, ids["input_ids"].shape[1]:],
                             skip_special_tokens=True)
            messages.append({"role": "assistant", "content": gen})

            if submit_re.search(gen):
                break

            cmds = bash_re.findall(gen)
            if not cmds:
                break

            obs_parts = []
            for cmd in cmds:
                try:
                    # errors="replace": binary or mis-encoded output must not
                    # raise UnicodeDecodeError and abort the whole run.
                    r = subprocess.run(
                        cmd, shell=True, capture_output=True, text=True,
                        errors="replace", timeout=30, cwd=repo_dir,
                    )
                    obs_parts.append(
                        f"$ {cmd.strip()}\n{(r.stdout + r.stderr)[:2000]}")
                except subprocess.TimeoutExpired:
                    obs_parts.append(f"$ {cmd.strip()}\n[timeout after 30s]")
            messages.append({"role": "user", "content": "\n\n".join(obs_parts)})

        # Stage everything so newly created (untracked) files are in the
        # patch, and diff against base_commit so changes the agent committed
        # are included too. This also picks up stray untracked files that are
        # not gitignored.
        staged = subprocess.run(
            ["git", "add", "-A"], capture_output=True, cwd=repo_dir)
        if staged.returncode == 0:
            diff_cmd = ["git", "diff", "--cached", task["base_commit"]]
        else:
            # Staging failed: fall back to the tracked-files-only diff.
            diff_cmd = ["git", "diff", "HEAD"]
        diff = subprocess.run(
            diff_cmd, capture_output=True, text=True, errors="replace",
            cwd=repo_dir,
        )
        return diff.stdout
    finally:
        shutil.rmtree(repo_dir, ignore_errors=True)
|
|
|
|
def _count_resolved(results) -> int:
    """Count resolved instances in a parsed evaluation report.

    Handles a list of per-instance records, a summary dict with
    ``resolved_ids`` / ``resolved_instances``, and a dict mapping instance id
    to a record or a boolean.
    """
    if isinstance(results, list):
        return sum(1 for r in results if r.get("resolved", False))
    if isinstance(results, dict):
        # Check summary fields first: counting truthy values of a summary
        # report would tally its bookkeeping fields, not resolved tasks.
        resolved_ids = results.get("resolved_ids")
        if isinstance(resolved_ids, list):
            return len(resolved_ids)
        resolved_total = results.get("resolved_instances")
        if isinstance(resolved_total, int) and not isinstance(resolved_total, bool):
            return resolved_total
        return sum(1 for v in results.values()
                   if (v.get("resolved") if isinstance(v, dict) else v))
    return 0


def run_swebench_subset(model, tok, cache_factory_per_layer, target_layers,
                        task_ids, max_turns=100):
    """Run the official SWE-bench Verified harness over `task_ids`.

    Builds per-config VQCache (or None for fp16), runs the mini-bash agent on
    each task, writes predictions to ``ARTIFACT_DIR/<run_id>.jsonl``, evaluates
    with the swebench harness, and returns resolve rate in [0, 1]. Keep
    `task_ids` FIXED across configs.

    Args:
        model, tok: generation model and tokenizer.
        cache_factory_per_layer: mapping ``layer_idx -> quantizer``, or ``None``
            for the unquantized baseline.
        target_layers: layer indices to quantize.
        task_ids: SWE-bench instance ids; ids missing from the dataset are
            skipped but still count in the denominator.
        max_turns: per-task turn limit passed to the agent loop.

    Returns:
        ``resolved / len(task_ids)``; ``0.0`` when ``task_ids`` is empty or no
        report file can be located.
    """
    import glob
    from datasets import load_dataset

    if cache_factory_per_layer is not None:
        VQCls = make_vq_cache_class(cache_factory_per_layer, target_layers,
                                    model.config, device=model.device)

        def cache_factory():
            return VQCls()
    else:
        def cache_factory():
            return None

    verified = load_dataset("princeton-nlp/SWE-bench_Verified", split="test")
    wanted = set(task_ids)  # built once, not once per dataset row
    tasks = {r["instance_id"]: r for r in verified
             if r["instance_id"] in wanted}

    run_id = f"vqkv_{int(time.time())}"
    predictions = []
    for task_id in task_ids:
        task = tasks.get(task_id)
        if task is None:
            print(f"[swebench] {task_id}: not in Verified dataset, skipping")
            continue
        print(f"[swebench] {task_id} ({len(predictions)+1}/{len(task_ids)})")
        patch = _agent_loop(model, tok, task, cache_factory, max_turns=max_turns)
        predictions.append({
            "instance_id": task_id,
            "model_patch": patch,
            "model_name_or_path": "laguna-vqkv",
        })

    preds_path = os.path.join(ARTIFACT_DIR, f"{run_id}.jsonl")
    with open(preds_path, "w") as f:
        for p in predictions:
            f.write(json.dumps(p) + "\n")
    print(f"[swebench] wrote {len(predictions)} predictions -> {preds_path}")

    from swebench.harness.run_evaluation import main as run_evaluation
    report = run_evaluation(
        dataset_name_or_path="princeton-nlp/SWE-bench_Verified",
        split="test",
        instance_ids=task_ids,
        predictions_path=preds_path,
        max_workers=4,
        force_rebuild=False,
        cache_level="env",
        clean=False,
        open_file_limit=4096,
        run_id=run_id,
        timeout=1800,
    )

    # Locate the report. Some harness versions return its path; otherwise
    # search by name. The `*.{run_id}.json` patterns cover reports named
    # `<model_name>.<run_id>.json`, which the run_id-prefix patterns miss
    # (presumably the harness default — confirm for the installed version).
    result_files = []
    if isinstance(report, (str, os.PathLike)) and os.path.isfile(report):
        result_files.append(os.fspath(report))
    for pattern in (
        os.path.join(ARTIFACT_DIR, f"{run_id}*.json"),
        f"{run_id}*.json",
        os.path.join(ARTIFACT_DIR, f"*.{run_id}.json"),
        f"*.{run_id}.json",
    ):
        result_files.extend(glob.glob(pattern))

    if not result_files:
        print(f"[swebench] WARNING: no results file found for run_id={run_id}")
        return 0.0

    with open(result_files[0]) as f:
        results = json.load(f)

    n_resolved = _count_resolved(results)
    # Guard the empty-subset case instead of dividing by zero.
    return n_resolved / len(task_ids) if task_ids else 0.0
|
|
|
|
def stage_swebench(n_tasks=50, seed=0):
    """Measure the SWE-bench Verified resolve rate of every cache config.

    Samples a fixed, seeded subset of tasks, saves it to
    ``ARTIFACT_DIR/task_subset.json``, runs ``run_swebench_subset`` once per
    config on that same subset, prints each rate with its delta against the
    fp16 baseline, and writes ``ARTIFACT_DIR/swebench_results.json``.

    Args:
        n_tasks: subset size (capped at the dataset size).
        seed: seed for the subset sampler.
    """
    import random
    from datasets import load_dataset

    # weights_only=False is required because the file holds pickled quantizer
    # objects (the other stages load it the same way). NOTE: this unpickles
    # arbitrary objects; only load codebook files produced by this script.
    blob = torch.load(os.path.join(ARTIFACT_DIR, "codebooks.pt"), weights_only=False)
    fitted, meta = blob["fitted"], blob["meta"]
    model, tok, _ = load_model_and_meta()

    verified = load_dataset("princeton-nlp/SWE-bench_Verified", split="test")
    rng = random.Random(seed)
    # random.sample raises if asked for more items than exist; cap instead.
    n_pick = min(n_tasks, len(verified))
    task_ids = [verified[i]["instance_id"]
                for i in rng.sample(range(len(verified)), n_pick)]
    with open(os.path.join(ARTIFACT_DIR, "task_subset.json"), "w") as f:
        json.dump(task_ids, f)
    print(f"[swebench] fixed subset of {len(task_ids)} tasks (seed={seed}) saved.")

    results = []
    for name, _ in cache_configs():
        if name == "fp16 (baseline)":
            quantizers = None
        elif name in fitted:
            quantizers = fitted[name]
        else:
            # Partial fits (stage_fit with `only`) leave some configs without
            # codebooks; record them as missing instead of aborting the sweep.
            print(f"[swebench] {name}: no fitted quantizers, skipping")
            results.append({"config": name, "resolve_rate": None})
            continue
        try:
            rate = run_swebench_subset(model, tok, quantizers,
                                       meta["full_layers"], task_ids)
        except NotImplementedError as e:
            print(f"[swebench] {name}: STUB -- {e}")
            rate = None
        results.append({"config": name, "resolve_rate": rate})
        print(f"[swebench] {name}: resolve_rate={rate}")

    # Report each config's rate with its delta against the fp16 baseline.
    base = next((r["resolve_rate"] for r in results
                 if r["config"] == "fp16 (baseline)"), None)
    print("\n[swebench] resolve rate on fixed subset (delta vs fp16):")
    for r in results:
        d = (None if (r["resolve_rate"] is None or base is None)
             else round(r["resolve_rate"] - base, 4))
        print(f"  {r['config']:24s} {r['resolve_rate']}  (Δ {d})")
    with open(os.path.join(ARTIFACT_DIR, "swebench_results.json"), "w") as f:
        json.dump(results, f, indent=2)
|
|
|
|
def main():
    """Parse the command line and run the selected stage.

    ``--stage`` picks one of ``dump``, ``fit``, ``cheap`` or ``swebench``; the
    other options are forwarded only to the stage that uses them.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--stage", required=True,
                        choices=["dump", "fit", "cheap", "swebench"])
    parser.add_argument("--n_calib", type=int, default=16)
    parser.add_argument(
        "--calib_source", type=str, default=None,
        choices=["swesmith", "longbench-hotpotqa"],
        help="dump stage: calibration corpus (default: CALIB_SOURCE env or swesmith)")
    parser.add_argument("--n_eval", type=int, default=64)
    parser.add_argument("--n_tasks", type=int, default=50)
    parser.add_argument(
        "--only", type=str, default=None,
        help="fit stage: comma-separated config names to fit/merge (e.g. "
             "sign-1bit,ternary-bitnet). Skips refitting other configs.")
    args = parser.parse_args()

    def _run_fit():
        # An absent or empty --only means "fit everything" (None).
        names = None
        if args.only:
            names = [part.strip() for part in args.only.split(",")]
        stage_fit(only=names)

    # argparse `choices` guarantees args.stage is one of these keys.
    handlers = {
        "dump": lambda: stage_dump(n_calib=args.n_calib,
                                   calib_source=args.calib_source),
        "fit": _run_fit,
        "cheap": lambda: stage_cheap(n_eval=args.n_eval),
        "swebench": lambda: stage_swebench(n_tasks=args.n_tasks),
    }
    handlers[args.stage]()


if __name__ == "__main__":
    main()