# agent-cost-optimizer / batch_validate.py
# narcolepticchicken's picture
# Upload batch_validate.py
# 29c4a80 verified
"""
Batch Cascade Validation Script
Runs the cascade agent on MULTIPLE SWE-bench instances and
verifies each patch via conda environment + pytest.
This is the script that proves the cascade causally, not just
correlationally from trace simulation.
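Usage (configured via environment variables):
MAX_INSTANCES=3 INSTANCE_TARGET=cascade-only python batch_validate.py
MAX_INSTANCES=5 INSTANCE_TARGET=all python batch_validate.py
INSTANCE_TARGET accepts "cascade-only" or "frontier-only"; any other value
(e.g. "all") takes the first MAX_INSTANCES instances of the dataset.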
Requirements: hf_jobs with a10g-largex2, 8h timeout
"""
import json, os, re, subprocess, sys, tempfile, time, traceback
from datetime import datetime
from pathlib import Path
# ============================================================
# INSTANCE LISTS
# ============================================================
# SWE-bench instance ids selected when the run target is "cascade-only".
# NOTE(review): presumably instances only the cascade resolved in earlier
# trace simulation -- confirm against the experiment notes.
CASCADE_ONLY = [
    "astropy__astropy-14365",
    "astropy__astropy-14995",
    "django__django-11815",
    "django__django-13089",
    "django__django-13807",
    "django__django-14315",
    "matplotlib__matplotlib-25224",
    "matplotlib__matplotlib-25311",
    "sympy__sympy-19487",
    "sympy__sympy-20590",
]

# SWE-bench instance ids selected when the run target is "frontier-only".
# NOTE(review): presumably instances only the frontier model resolved --
# confirm against the experiment notes.
FRONTIER_ONLY = [
    "django__django-12453",
    "django__django-14030",
    "django__django-14349",
    "django__django-14855",
    "django__django-15098",
    "django__django-16235",
    "matplotlib__matplotlib-26020",
    "psf__requests-6028",
    "pylint-dev__pylint-7080",
    "scikit-learn__scikit-learn-13439",
    "scikit-learn__scikit-learn-14087",
    "sphinx-doc__sphinx-10323",
    "sphinx-doc__sphinx-10466",
    "sphinx-doc__sphinx-10614",
]
def sh(cmd, cwd=None, timeout=120):
    """Run ``cmd`` through the shell and return ``(returncode, stdout, stderr)``.

    Args:
        cmd: Shell command line. It is executed with ``shell=True`` so pipes,
            ``&&`` and redirections work; anything interpolated into it must
            already be safe for the shell.
        cwd: Working directory for the command, or ``None`` for the current one.
        timeout: Seconds to wait before the command is killed.

    Returns:
        ``(returncode, stdout, stderr)`` with both streams as text.  If the
        command exceeds ``timeout`` the return code is ``124`` (the exit
        status coreutils ``timeout`` uses), the partial output captured so
        far is returned and a ``[TIMEOUT ...]`` marker is appended to stderr.
        Callers therefore never have to handle ``subprocess.TimeoutExpired``,
        and one hung command cannot abort a whole batch run.
    """
    def as_text(stream):
        # TimeoutExpired may carry None or raw bytes even when text=True.
        if stream is None:
            return ""
        if isinstance(stream, bytes):
            return stream.decode("utf-8", errors="replace")
        return stream

    try:
        r = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout, shell=True)
    except subprocess.TimeoutExpired as exc:
        return 124, as_text(exc.stdout), as_text(exc.stderr) + f"\n[TIMEOUT after {timeout}s]"
    return r.returncode, r.stdout, r.stderr
def ensure_conda():
    """Locate a conda executable, installing Miniconda into ``~/miniconda3`` if needed.

    Looks for ``~/miniconda3/bin/conda`` and ``/opt/conda/bin/conda`` first.
    When neither exists, Miniconda is downloaded with ``wget`` and installed
    in batch mode, conda is configured for non-interactive use, and the new
    ``bin`` directory is prepended to ``PATH`` for this process.

    Returns:
        The path to the conda executable, or ``None`` when the installation
        failed (download error, timeout, or the binary is missing afterwards),
        so callers can test the result for truthiness.
    """
    home_conda = os.path.expanduser("~/miniconda3/bin/conda")
    for p in [home_conda, "/opt/conda/bin/conda"]:
        if os.path.exists(p):
            return p

    print("πŸ“¦ Installing Miniconda...")
    try:
        rc, _out, err = sh("wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && bash /tmp/miniconda.sh -b -p $HOME/miniconda3", timeout=300)
    except subprocess.TimeoutExpired as exc:
        rc, err = 124, str(exc)
    # Never hand back a path that does not exist: the caller's "no conda"
    # check relies on a falsy return value.
    if rc or not os.path.exists(home_conda):
        print(f"Miniconda install failed (rc={rc}): {err[-200:]}")
        return None

    p = home_conda
    try:
        # Best-effort configuration; the command itself ends in "; true".
        sh(f"{p} config --set always_yes yes --set changeps1 no && {p} tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main 2>/dev/null; true", timeout=30)
    except subprocess.TimeoutExpired:
        pass
    os.environ["PATH"] = os.path.expanduser("~/miniconda3/bin:") + os.environ.get("PATH", "")
    return p
def is_valid_patch(text):
    """Return True when ``text`` looks like a git-style unified diff.

    This is a cheap textual check only: the text must be longer than 10
    characters and contain both a ``diff --git`` header and an ``@@`` hunk
    marker.  It does not prove the patch applies.
    """
    if not text or len(text) <= 10:
        return False
    return 'diff --git' in text and '@@' in text


def _normalize_patch(raw):
    """Trim noise around a diff while keeping its last hunk line intact.

    Only leading whitespace and trailing newlines are removed.  A trailing
    whitespace-only line is kept because inside a hunk a single space is a
    blank context line; removing it makes the hunk shorter than its header
    declares.  The result always ends with exactly one newline.
    """
    return raw.lstrip().rstrip("\n") + "\n"


def extract_patch(text, max_fallback_chars=2000):
    """Pull a unified diff out of a model response.

    Sources are tried in order: a ``<patch>...</patch>`` block, a fenced
    ``diff`` code block, and finally everything from the first
    ``diff --git`` onwards.

    Args:
        text: The raw model response; ``None`` or empty yields ``None``.
        max_fallback_chars: Cap on the length of the untagged fallback
            candidate (it may include trailing prose).  ``None`` disables
            the cap.

    Returns:
        The patch text ending in a newline, or ``None`` when no candidate
        passes ``is_valid_patch``.
    """
    if not text:
        return None

    for pattern in (r'<patch>\s*\n?(.*?)</patch>', r'```diff\s*\n(.*?)```'):
        m = re.search(pattern, text, re.DOTALL)
        if m and is_valid_patch(m.group(1)):
            return _normalize_patch(m.group(1))

    start = text.find('diff --git')
    if start >= 0:
        end = None if max_fallback_chars is None else start + max_fallback_chars
        candidate = text[start:end]
        if is_valid_patch(candidate):
            return _normalize_patch(candidate)
    return None
def call_model(client, messages, max_tokens=4096):
    """Send one chat-completion request and return ``(text, input_tokens, output_tokens)``.

    Args:
        client: Object exposing ``client.model`` and
            ``client.chat.completions.create(...)``.
        messages: Chat history as a list of ``{"role", "content"}`` dicts.
        max_tokens: Completion length limit passed to the API.

    Returns:
        ``(text, input_tokens, output_tokens)``.  ``text`` is always a string:
        a missing completion becomes ``""``.  When the response carries no
        usage information, input tokens are reported as 0 and output tokens
        are estimated as ``len(text) // 4``.  Any exception raised by the
        request is reported as ``("[ERROR: ...]", 0, 0)`` instead of
        propagating, so a failing endpoint does not abort the caller's loop.
    """
    try:
        c = client.chat.completions.create(model=client.model, messages=messages, max_tokens=max_tokens, temperature=0.2)
        # The API may return None content; callers run regexes on the text.
        t = c.choices[0].message.content or ""
        usage = getattr(c, 'usage', None)
        if usage:
            it = usage.prompt_tokens or 0
            ot = usage.completion_tokens or 0
        else:
            it, ot = 0, len(t) // 4
        return t, it, ot
    except Exception as e:
        # Deliberate catch-all at the network boundary.
        return f"[ERROR: {e}]", 0, 0
def _route_pytest_to_env(cmd, conda, env_name):
    """Rewrite pytest invocations in ``cmd`` so they run inside the conda env.

    Matches a standalone ``pytest`` word, optionally preceded by
    ``python -m`` / ``python3 -m``.  Occurrences inside paths or other
    identifiers (``src/_pytest/x.py``, ``pytest.ini``, ``pytest-cov``) are
    left untouched.
    """
    runner = f"{conda} run -n {env_name} python -m pytest"
    pattern = r'(?<![\w./-])(?:python3?\s+-m\s+)?pytest(?![\w./-])'
    # A callable replacement keeps backslashes in ``runner`` literal.
    return re.sub(pattern, lambda _m: runner, cmd)


def run_cascade(instance, repo_dir, conda, env_name):
    """Run the two-tier agent loop on one instance until a patch applies.

    Tier T1 (Llama-3.1-8B-Instruct) is tried first for up to 30 turns, then
    tier T2 (Llama-3.3-70B-Instruct) for up to 30 turns.  Both tiers share
    one message history, so T2 sees everything T1 said and ran.  Each turn
    the reply is scanned for a patch; a patch that passes
    ``git apply --check`` ends the run.  Otherwise every ``<bash>`` command
    in the reply is executed in ``repo_dir`` and its output is fed back.
    A reply containing ``<submit>`` without a valid patch ends the tier.

    Args:
        instance: Mapping with ``repo`` and optionally ``problem_statement``.
        repo_dir: Path of the checked-out repository.
        conda: Path to the conda executable.
        env_name: Name of the conda environment used for pytest commands.

    Returns:
        On success a dict with ``patch``, ``tier``, ``turns`` and the
        ``input_tokens`` / ``output_tokens`` spent in the successful tier
        only.  When neither tier produces an applicable patch, ``patch`` and
        ``tier`` are ``None`` and the counters are 0.
    """
    from huggingface_hub import InferenceClient

    def run(cmd, **kwargs):
        # A hung command (e.g. a model-issued infinite loop) must not abort the batch.
        try:
            return sh(cmd, **kwargs)
        except subprocess.TimeoutExpired:
            return 124, "", f"[TIMEOUT after {kwargs.get('timeout')}s]"

    T1, T2 = "meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.3-70B-Instruct"
    system = f"Fix bug in {instance['repo']}. Repo: {repo_dir}.\n<bash>cmd</bash>\n<patch>\ndiff --git a/path b/path\n--- a/path\n+++ b/path\n@@ -N,M +N,M @@\nchanges\n</patch>\n<submit>Done</submit>"
    messages = [{"role":"system","content":system},{"role":"user","content":f"PROBLEM:\n{instance.get('problem_statement','')}\n\nExplore and fix."}]

    for tier, mid, mt in [("T1",T1,30),("T2",T2,30)]:
        print(f"\n[{tier}] {mid}")
        client = InferenceClient(mid)
        ti = to = 0
        for turn in range(mt):
            text, it, ot = call_model(client, messages, 4096)
            text = text or ""  # guard against a None completion
            ti += it; to += ot
            messages.append({"role":"assistant","content":text})
            print(f" T{turn+1}: {it}+{ot} tok")

            patch = extract_patch(text)
            if patch:
                # Write and return the patch with a final newline so it is a
                # well-formed patch file.
                patch_text = patch if patch.endswith("\n") else patch + "\n"
                (Path(repo_dir)/"_c.patch").write_text(patch_text)
                rc, out, err = run(f"cd {repo_dir} && git apply --check _c.patch 2>&1", timeout=10)
                if rc == 0:
                    print(f" βœ… VALID ({len(patch)}ch)")
                    return {"patch":patch_text,"tier":tier,"turns":turn+1,"input_tokens":ti,"output_tokens":to}
                # "2>&1" sends git's message to stdout, so report both streams.
                detail = (out + err).strip()
                print(f" ❌ Invalid: {detail[:100]}")
                messages.append({"role":"user","content":f"Patch check failed: {detail[:200]}\nUse git diff for valid unified diff."})
                continue

            for cmd in re.findall(r'<bash>(.*?)</bash>', text, re.DOTALL):
                cmd = _route_pytest_to_env(cmd.strip(), conda, env_name)
                rc, out, err = run(cmd, cwd=str(repo_dir), timeout=60)
                o = f"<output>\n{(out+err)[:1500]}\n</output>"
                # Splice the exit code in front of the closing tag (9 chars).
                if rc: o = o[:-9] + f" [EXIT:{rc}]\n</output>"
                messages.append({"role":"user","content":o})

            if "<submit>" in text: break

    return {"patch":None,"tier":None,"turns":0,"input_tokens":0,"output_tokens":0}
def verify_patch(instance, patch, repo_dir, conda, env_name):
    """Apply ``patch`` plus the instance's test patch and run its tests.

    The repository is reset to ``base_commit``, the candidate patch is
    applied, the test patch is applied best-effort, and up to the first 10
    ``FAIL_TO_PASS`` tests are run with pytest inside the conda env.  If
    they pass, up to the first 10 ``PASS_TO_PASS`` tests are run as well.

    Args:
        instance: Mapping with ``base_commit``, ``test_patch``,
            ``FAIL_TO_PASS`` and ``PASS_TO_PASS``.  The test lists may be
            real lists or JSON-encoded strings; both are accepted.
        patch: Unified diff text to verify.
        repo_dir: Path of the checked-out repository.
        conda: Path to the conda executable.
        env_name: Name of the conda environment to run pytest in.

    Returns:
        ``{"resolved": True, "test_output": ...}`` on success, otherwise
        ``{"resolved": False, "error": ...}`` (with ``test_output`` when the
        fail-to-pass run itself failed).
    """
    import shlex

    def test_ids(key):
        # SWE-bench stores these fields as JSON strings; slicing the raw
        # string would yield single characters instead of test ids.
        raw = instance.get(key) or []
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except ValueError:
                raw = [raw]
        if not isinstance(raw, (list, tuple)):
            raw = [raw]
        return [str(t) for t in raw]

    def pytest_cmd(ids):
        # Quote ids: parametrized node ids contain brackets and spaces.
        # NOTE(review): ids of some repos may not be pytest node ids -- confirm.
        args = ' '.join(shlex.quote(t) for t in ids[:10])
        return f"cd {repo_dir} && {conda} run -n {env_name} python -m pytest -v --tb=short -x {args}"

    base = instance.get("base_commit","")
    tp = instance.get("test_patch","")
    f2p = test_ids("FAIL_TO_PASS")
    if not f2p:
        # Without target tests pytest would run the whole suite and the
        # outcome would say nothing about this patch.
        return {"resolved":False,"error":"no FAIL_TO_PASS tests"}

    try:
        sh(f"cd {repo_dir} && git checkout -f {base} && git clean -fd", timeout=30)
        (Path(repo_dir)/"_aco.patch").write_text(patch if patch.endswith("\n") else patch + "\n")
        rc, out, err = sh(f"cd {repo_dir} && git apply --check _aco.patch", timeout=10)
        if rc: return {"resolved":False,"error":f"patch check: {err[:150]}"}
        rc, out, err = sh(f"cd {repo_dir} && git apply _aco.patch", timeout=10)
        if rc: return {"resolved":False,"error":f"patch apply: {err[:150]}"}

        if tp:
            (Path(repo_dir)/"_t.patch").write_text(tp)
            # Best effort: fall back to --reject and never fail the run here.
            sh(f"cd {repo_dir} && (git apply _t.patch) || git apply --reject _t.patch 2>/dev/null; true", timeout=10)

        rc, out, err = sh(pytest_cmd(f2p), timeout=300)
        if rc == 0:
            p2p = test_ids("PASS_TO_PASS")
            if p2p:
                rc2, out2, err2 = sh(pytest_cmd(p2p), timeout=300)
                if rc2: return {"resolved":False,"error":f"P2P: {(out2+err2)[:200]}"}
            return {"resolved":True,"test_output":(out+err)[:500]}
        return {"resolved":False,"error":f"{len(re.findall(r'FAILED', out+err))} F2P failures","test_output":(out+err)[:500]}
    except subprocess.TimeoutExpired as exc:
        # A hung test run counts as unresolved rather than crashing the batch.
        return {"resolved":False,"error":f"timeout: {exc}"}
def setup_env(conda, instance, repo_dir, env_name):
    """Create the conda environment for one instance and install the repo into it.

    Steps:
      1. If the instance names an ``environment_setup_commit``, check it out
         so its ``environment.yml`` (if any) is the one used.
      2. Create the env from the first ``environment.yml`` found in a few
         known locations, otherwise a plain ``python=3.10`` env; a failed
         attempt is followed by one plain-env attempt.
      3. Check out ``base_commit`` and install the repository, preferring an
         editable install so later patches to the working tree are what the
         tests import.  A regular install is used only if the editable one
         fails.
      4. Install pytest best-effort, since tests are run with
         ``python -m pytest`` inside this env.

    Args:
        conda: Path to the conda executable.
        instance: Mapping with ``base_commit`` and optionally
            ``environment_setup_commit``.
        repo_dir: Path of the cloned repository.
        env_name: Name for the conda environment.

    Returns:
        ``(True, "")`` on success, or ``(False, message)`` describing the
        failing step (conda, git, pip or a timeout).
    """
    try:
        ec = instance.get("environment_setup_commit","")
        if ec:
            sh(f"cd {repo_dir} && git fetch origin {ec} && git checkout {ec}", timeout=60)

        eyml = None
        for c in ["environment.yml","dev/environment.yml",".github/environment.yml","ci/environment.yml"]:
            if (Path(repo_dir)/c).exists():
                eyml = c
                break

        if eyml:
            rc, out, err = sh(f"{conda} env create -f {repo_dir}/{eyml} -n {env_name} --quiet", timeout=600)
        else:
            rc, out, err = sh(f"{conda} create -n {env_name} python=3.10 pip -y", timeout=300)
        if rc:
            rc, out, err = sh(f"{conda} create -n {env_name} python=3.10 pip -y", timeout=300)
        if rc: return False, f"conda: {err[:200]}"

        base = instance["base_commit"]
        rc, out, err = sh(f"cd {repo_dir} && git fetch origin {base} && git checkout {base}", timeout=60)
        # Testing against the wrong commit would silently invalidate the result.
        if rc: return False, f"git: {err[:200]}"

        pip = f"cd {repo_dir} && {conda} run -n {env_name} pip install"
        # No "| tail" here: a pipeline would report tail's exit status and
        # hide a failed install.
        rc, out, err = sh(f"{pip} -e . 2>&1", timeout=300)
        if rc:
            rc, out, err = sh(f"{pip} . 2>&1", timeout=300)
        if rc: return False, f"pip: {(out+err)[-200:]}"

        # Best effort; if this fails the test run reports the missing module.
        sh(f"{pip} pytest 2>&1", timeout=300)
        return True, ""
    except subprocess.TimeoutExpired as exc:
        return False, f"timeout: {exc}"
def _save_results(results, path="batch_results.jsonl"):
    """Rewrite ``path`` with one JSON object per result record."""
    with open(path,"w") as f:
        for r in results: f.write(json.dumps(r)+"\n")


def _run_instance(index, iid, instance, conda):
    """Clone, build the env, run the cascade and verify one instance.

    Returns the result record for the instance.  The conda environment is
    removed on every path once environment setup has been attempted.
    """
    with tempfile.TemporaryDirectory(prefix=f"aco_{index}_") as tmpdir:
        repo_dir = Path(tmpdir) / "repo"
        env_name = f"aco_{iid.replace('__','_').replace('-','_')[:30]}"

        print(f"Clone...")
        url = f"https://github.com/{instance['repo']}.git"
        rc, out, err = sh(f"git clone --depth 100 {url} {repo_dir}", timeout=600)
        if rc:
            return {"instance_id":iid,"resolved":False,"error":f"Clone: {err[:200]}"}

        try:
            print(f"Env...")
            ok, err = setup_env(conda, instance, repo_dir, env_name)
            if not ok:
                return {"instance_id":iid,"resolved":False,"error":f"Env: {err}"}

            print(f"Cascade...")
            agent = run_cascade(instance, repo_dir, conda, env_name)
            if not agent["patch"]:
                return {"instance_id":iid,"resolved":False,"tier":None,"error":"No valid patch"}

            print(f"Verify...")
            verify = verify_patch(instance, agent["patch"], repo_dir, conda, env_name)
            r = {
                "instance_id":iid, "repo":instance["repo"],
                "resolved":verify["resolved"], "tier":agent["tier"],
                "turns":agent["turns"], "input_tokens":agent["input_tokens"],
                "output_tokens":agent["output_tokens"],
                "error":verify.get("error"),
                "timestamp":datetime.now().isoformat()
            }
            status = "βœ…" if verify["resolved"] else "❌"
            print(f" {status} {agent['tier']} {agent['turns']}t")
            return r
        finally:
            # Remove the env on every path, including setup failures.
            try:
                sh(f"{conda} env remove -n {env_name} -y --quiet 2>/dev/null; true", timeout=30)
            except subprocess.TimeoutExpired:
                print(f"Env cleanup timed out: {env_name}")


def main():
    """Run the cascade on a batch of SWE-bench instances and record the results.

    Configuration comes from ``--target`` / ``--instances`` on the command
    line, defaulting to the ``INSTANCE_TARGET`` and ``MAX_INSTANCES``
    environment variables and then to ``cascade-only`` and ``3``.  Targets
    ``cascade-only`` and ``frontier-only`` select from the module-level id
    lists; any other value takes the first N instances of the dataset.

    Results are rewritten to ``batch_results.jsonl`` after every instance,
    so an interrupted run keeps everything finished so far.  A failure in one
    instance is recorded and the batch continues.  Exits with status 1 when
    no conda executable is available.
    """
    import argparse
    import datasets

    parser = argparse.ArgumentParser(description="Batch cascade validation")
    parser.add_argument("--target", default=os.environ.get("INSTANCE_TARGET", "cascade-only"))
    parser.add_argument("--instances", type=int, default=int(os.environ.get("MAX_INSTANCES", "3")))
    # Tolerate unrelated arguments injected by a job launcher.
    args, _unknown = parser.parse_known_args()
    target, max_instances = args.target, args.instances

    print(f"πŸš€ BATCH CASCADE VALIDATION β€” target={target} max={max_instances}")
    print(f" {datetime.now().isoformat()}")
    conda = ensure_conda()
    if not conda: print("❌ No conda"); sys.exit(1)

    ds = datasets.load_dataset("princeton-nlp/SWE-bench_Verified", split="test")
    if target == "cascade-only":
        iids = CASCADE_ONLY[:max_instances]
    elif target == "frontier-only":
        iids = FRONTIER_ONLY[:max_instances]
    else:
        iids = [r["instance_id"] for r in ds][:max_instances]

    wanted = set(iids)
    instances = {row["instance_id"]: dict(row) for row in ds if row["instance_id"] in wanted}
    print(f"Instances: {iids}\n")

    results = []
    for i, iid in enumerate(iids):
        instance = instances.get(iid)
        if not instance:
            # Record the miss so the summary denominator stays honest.
            print(f"Skipping {iid}: not found in dataset")
            record = {"instance_id":iid,"resolved":False,"tier":None,"error":"Not in dataset"}
        else:
            print(f"\n{'='*60}\n[{i+1}/{len(iids)}] {iid}\n{'='*60}")
            try:
                record = _run_instance(i, iid, instance, conda)
            except Exception as e:
                # One broken instance must not discard the rest of the batch.
                traceback.print_exc()
                record = {"instance_id":iid,"resolved":False,"tier":None,"error":f"Crash: {e}"}
        results.append(record)
        # Incremental save, reached on every path.
        _save_results(results)

    # Ensure the file exists even when no instance was processed.
    _save_results(results)

    resolved = [r for r in results if r["resolved"]]
    t1 = [r for r in resolved if r.get("tier")=="T1"]
    t2 = [r for r in resolved if r.get("tier")=="T2"]
    print(f"\n{'='*60}\nRESULTS: {len(resolved)}/{len(results)} resolved")
    print(f" T1: {len(t1)} T2: {len(t2)}")
    for r in results:
        s = "βœ…" if r["resolved"] else "❌"
        print(f" {s} {r['instance_id']} [{r.get('tier','')}]")
    print(f"Saved: batch_results.jsonl")
if __name__ == "__main__":
    # Script entry point: propagate main()'s return value as the exit status.
    # Any ordinary exception is printed with its traceback and becomes exit
    # status 1; SystemExit raised inside main() passes through unchanged.
    try:
        exit_status = main()
        sys.exit(exit_status)
    except Exception as e:
        print(f"πŸ’₯ {e}")
        traceback.print_exc()
        sys.exit(1)