| """ |
| Batch Cascade Validation Script |
| |
| Runs the cascade agent on MULTIPLE SWE-bench instances and |
| verifies each patch via conda environment + pytest. |
| |
| This is the script that proves the cascade causally, not just |
| correlationally from trace simulation. |
| |
| Usage: |
| INSTANCE_TARGET=cascade-only MAX_INSTANCES=3 python batch_validate.py |
| INSTANCE_TARGET=all MAX_INSTANCES=5 python batch_validate.py |
| |
| Requirements: hf_jobs with a10g-largex2, 8h timeout |
| """ |
|
|
| import json, os, re, subprocess, sys, tempfile, time, traceback |
| from datetime import datetime |
| from pathlib import Path |
|
|
| |
| |
| |
# Instance IDs selected by main() when the run target is "cascade-only".
CASCADE_ONLY = [
    "astropy__astropy-14365",
    "astropy__astropy-14995",
    "django__django-11815",
    "django__django-13089",
    "django__django-13807",
    "django__django-14315",
    "matplotlib__matplotlib-25224",
    "matplotlib__matplotlib-25311",
    "sympy__sympy-19487",
    "sympy__sympy-20590",
]
|
|
# Instance IDs selected by main() when the run target is "frontier-only".
FRONTIER_ONLY = [
    "django__django-12453",
    "django__django-14030",
    "django__django-14349",
    "django__django-14855",
    "django__django-15098",
    "django__django-16235",
    "matplotlib__matplotlib-26020",
    "psf__requests-6028",
    "pylint-dev__pylint-7080",
    "scikit-learn__scikit-learn-13439",
    "scikit-learn__scikit-learn-14087",
    "sphinx-doc__sphinx-10323",
    "sphinx-doc__sphinx-10466",
    "sphinx-doc__sphinx-10614",
]
|
|
def sh(cmd, cwd=None, timeout=120):
    """Run *cmd* through the shell and return ``(returncode, stdout, stderr)``.

    Args:
        cmd: Shell command line. It is executed with ``shell=True`` because
            callers pass pipelines and ``&&`` chains.
        cwd: Working directory for the command, or ``None`` to inherit.
        timeout: Seconds to wait before the command is killed.

    Returns:
        A ``(int, str, str)`` tuple. A command that exceeds *timeout* is
        reported with return code 124 (the coreutils ``timeout`` convention),
        whatever output was captured so far, and a ``[TIMEOUT ...]`` note
        appended to stderr, so one slow command cannot abort the whole batch
        with an uncaught ``subprocess.TimeoutExpired``.
    """
    try:
        proc = subprocess.run(
            cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout, shell=True
        )
    except subprocess.TimeoutExpired as exc:

        def as_text(stream):
            # On timeout the partial output may be None or raw bytes even
            # though text=True was requested.
            if stream is None:
                return ""
            if isinstance(stream, bytes):
                return stream.decode(errors="replace")
            return stream

        note = f"\n[TIMEOUT after {timeout}s]"
        return 124, as_text(exc.stdout), as_text(exc.stderr) + note
    return proc.returncode, proc.stdout, proc.stderr
|
|
def ensure_conda():
    """Return the path of a conda executable, installing Miniconda if none is found.

    ``~/miniconda3/bin/conda`` and ``/opt/conda/bin/conda`` are checked first.
    If neither exists, the Miniconda installer is downloaded with ``wget`` and
    installed into ``$HOME/miniconda3``; conda is then configured
    (``always_yes``, ``changeps1``), acceptance of the main channel's terms of
    service is attempted, and the new ``bin`` directory is prepended to
    ``PATH``.

    Returns:
        The conda path as a string, or ``None`` when the installation did not
        produce ``~/miniconda3/bin/conda``.
    """
    home_conda = os.path.expanduser("~/miniconda3/bin/conda")
    for candidate in (home_conda, "/opt/conda/bin/conda"):
        if os.path.exists(candidate):
            return candidate

    print("📦 Installing Miniconda...")
    rc, out, err = sh(
        "wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && bash /tmp/miniconda.sh -b -p $HOME/miniconda3",
        timeout=300,
    )
    # Report a failed download/install instead of handing back a path that
    # does not exist; the caller tests the return value for falsiness.
    if not os.path.exists(home_conda):
        print(f"Miniconda install failed (exit {rc}): {(out + err)[-300:]}")
        return None

    # Best-effort configuration; the trailing "; true" keeps a failing
    # "tos accept" from being treated as an error.
    sh(
        f"{home_conda} config --set always_yes yes --set changeps1 no && {home_conda} tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main 2>/dev/null; true",
        timeout=30,
    )
    os.environ["PATH"] = (
        os.path.dirname(home_conda) + os.pathsep + os.environ.get("PATH", "")
    )
    return home_conda
|
|
def is_valid_patch(text):
    """Return True when *text* looks like a git unified diff.

    The check is purely textual: *text* must be non-empty, longer than ten
    characters, and contain both a ``diff --git`` header and an ``@@`` marker.
    """
    if not text or len(text) <= 10:
        return False
    return all(marker in text for marker in ('diff --git', '@@'))
|
|
def extract_patch(text):
    """Pull a unified diff out of a model reply.

    Three locations are tried in order: the content of a
    ``<patch>...</patch>`` element, the content of a fenced `` ```diff ``
    block, and finally everything from the first ``diff --git`` onwards.
    A candidate is accepted only if ``is_valid_patch`` approves it.

    Args:
        text: The model reply; ``None`` or an empty string yields ``None``.

    Returns:
        The patch text ending in exactly one newline, or ``None`` when no
        acceptable patch is found.
    """
    if not text:
        return None

    def finalize(candidate):
        # git apply rejects a hunk whose last line has no terminating newline
        # ("corrupt patch"), so always end with one. Only newline characters
        # are trimmed at the end: a trailing line holding a single space can
        # be a legitimate blank context line.
        return candidate.lstrip().rstrip("\r\n") + "\n"

    for pattern in (r'<patch>\s*\n?(.*?)</patch>', r'```diff\s*\n(.*?)```'):
        m = re.search(pattern, text, re.DOTALL)
        if m and is_valid_patch(m.group(1)):
            return finalize(m.group(1))

    start = text.find('diff --git')
    if start < 0:
        return None
    tail = text[start:]
    # Stop at the first closing fence/tag that begins a line. Diff lines start
    # with ' ', '+', '-', '@' or a header keyword, so such a line cannot belong
    # to the patch. This replaces a fixed 2000-character cut that truncated
    # larger diffs in the middle of a hunk.
    stop = re.search(r'^(?:```|</patch>|<submit>|<bash>)', tail, re.MULTILINE)
    if stop:
        tail = tail[:stop.start()]
    if is_valid_patch(tail):
        return finalize(tail)
    return None
|
|
def call_model(client, messages, max_tokens=4096):
    """Request one chat completion and return ``(text, input_tokens, output_tokens)``.

    Args:
        client: Object exposing ``client.chat.completions.create(...)`` and a
            ``client.model`` attribute naming the model to call.
        messages: Chat history passed through unchanged.
        max_tokens: Completion length limit forwarded to the API.

    Returns:
        ``(text, input_tokens, output_tokens)``. *text* is always a string:
        empty content (``None``) becomes ``""``. When the response carries no
        usage data the input count is 0 and the output count is estimated as
        ``len(text) // 4``. Any exception is reported in-band as
        ``("[ERROR: ...]", 0, 0)`` so the caller's turn loop keeps running.
    """
    try:
        completion = client.chat.completions.create(
            model=client.model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=0.2,
        )
        # The content may be None; normalise it so callers can run regexes
        # on the reply and so the length-based estimate below cannot raise.
        text = completion.choices[0].message.content or ""
        usage = getattr(completion, "usage", None)
        if usage:
            return text, usage.prompt_tokens or 0, usage.completion_tokens or 0
        return text, 0, len(text) // 4
    except Exception as e:  # deliberate best-effort boundary around the API call
        return f"[ERROR: {e}]", 0, 0
|
|
def _route_pytest_to_env(cmd, conda, env_name):
    """Rewrite command-position ``pytest`` calls in *cmd* to run inside the conda env."""
    runner = f"{conda} run -n {env_name} python -m pytest"
    # Match pytest only where a command starts (line start or after ; & | ( ),
    # optionally already prefixed with "python -m". A plain str.replace would
    # also mangle "python -m pytest", "pip install pytest" and paths such as
    # src/_pytest/...
    return re.sub(
        r'(^\s*|[;&|(]\s*)(?:python3?\s+-m\s+)?pytest(?=\s|$)',
        lambda m: m.group(1) + runner,
        cmd,
        flags=re.MULTILINE,
    )


def run_cascade(instance, repo_dir, conda, env_name):
    """Drive the two-tier model cascade until one tier emits a patch that applies.

    Each tier gets up to 30 turns. On every turn the model reply is searched
    for a patch; a patch that passes ``git apply --check`` in *repo_dir* ends
    the run. Otherwise every ``<bash>`` command in the reply is executed in
    *repo_dir* and its output is fed back as a user message. A reply
    containing ``<submit>`` without an applicable patch ends the current tier.
    The conversation is shared, so the second tier sees the first tier's
    history.

    Args:
        instance: Mapping with ``repo`` and (optionally) ``problem_statement``.
        repo_dir: Checked-out repository the agent works in.
        conda: Path to the conda executable used to run pytest.
        env_name: Name of the conda environment for pytest runs.

    Returns:
        A dict with keys ``patch``, ``tier``, ``turns``, ``input_tokens`` and
        ``output_tokens``. The turn and token counts cover only the tier that
        produced the patch. When no tier succeeds, ``patch`` and ``tier`` are
        ``None`` and the counters are 0.
    """
    from huggingface_hub import InferenceClient

    tiers = [
        ("T1", "meta-llama/Llama-3.1-8B-Instruct", 30),
        ("T2", "meta-llama/Llama-3.3-70B-Instruct", 30),
    ]

    system = f"Fix bug in {instance['repo']}. Repo: {repo_dir}.\n<bash>cmd</bash>\n<patch>\ndiff --git a/path b/path\n--- a/path\n+++ b/path\n@@ -N,M +N,M @@\nchanges\n</patch>\n<submit>Done</submit>"
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": f"PROBLEM:\n{instance.get('problem_statement','')}\n\nExplore and fix."},
    ]
    patch_file = Path(repo_dir) / "_c.patch"

    for tier, model_id, max_turns in tiers:
        print(f"\n[{tier}] {model_id}")
        client = InferenceClient(model_id)
        tokens_in = tokens_out = 0
        for turn in range(max_turns):
            text, it, ot = call_model(client, messages, 4096)
            tokens_in += it
            tokens_out += ot
            messages.append({"role": "assistant", "content": text})
            print(f" T{turn+1}: {it}+{ot} tok")

            patch = extract_patch(text)
            if patch:
                # git apply needs the last patch line to be newline-terminated.
                if not patch.endswith("\n"):
                    patch += "\n"
                patch_file.write_text(patch)
                rc, out, err = sh(f"cd {repo_dir} && git apply --check _c.patch 2>&1", timeout=10)
                if rc == 0:
                    print(f" ✅ VALID ({len(patch)}ch)")
                    return {
                        "patch": patch,
                        "tier": tier,
                        "turns": turn + 1,
                        "input_tokens": tokens_in,
                        "output_tokens": tokens_out,
                    }
                # The command redirects stderr into stdout, so the diagnostic
                # arrives in `out`; reading only `err` gave the model an empty
                # error message.
                detail = (out + err).strip()
                print(f" ❌ Invalid: {detail[:100]}")
                messages.append({"role": "user", "content": f"Patch check failed: {detail[:200]}\nUse git diff for valid unified diff."})
                continue

            for cmd in re.findall(r'<bash>(.*?)</bash>', text, re.DOTALL):
                cmd = _route_pytest_to_env(cmd.strip(), conda, env_name)
                try:
                    rc, out, err = sh(cmd, cwd=str(repo_dir), timeout=60)
                except subprocess.TimeoutExpired:
                    # A hanging model-issued command must not end the run.
                    rc, out, err = 124, "", "[TIMEOUT after 60s]"
                o = f"<output>\n{(out+err)[:1500]}\n</output>"
                if rc:
                    # Re-open the closing tag to append the exit status.
                    o = o[:-9] + f" [EXIT:{rc}]\n</output>"
                messages.append({"role": "user", "content": o})

            if "<submit>" in text:
                break
    return {"patch": None, "tier": None, "turns": 0, "input_tokens": 0, "output_tokens": 0}
|
|
def verify_patch(instance, patch, repo_dir, conda, env_name):
    """Apply *patch* plus the instance's test patch and run the target tests.

    The repository is reset to ``base_commit``, the candidate patch and then
    the test patch are applied, and up to the first ten ``FAIL_TO_PASS`` tests
    are run with ``pytest -x`` in the conda environment. If they pass, up to
    the first ten ``PASS_TO_PASS`` tests are run as a regression check.

    Args:
        instance: Mapping with ``base_commit``, ``test_patch``,
            ``FAIL_TO_PASS`` and ``PASS_TO_PASS``. The two test lists may be
            real lists or JSON-encoded strings.
        patch: Candidate patch text.
        repo_dir: Repository checkout to verify in.
        conda: Path to the conda executable.
        env_name: Conda environment holding the project install.

    Returns:
        ``{"resolved": True, "test_output": ...}`` on success, otherwise
        ``{"resolved": False, "error": ...}`` (plus ``test_output`` when the
        ``FAIL_TO_PASS`` run itself failed).
    """
    import shlex

    def test_ids(key):
        # The test lists may arrive as JSON strings; slicing such a string
        # directly would pass single characters to pytest.
        raw = instance.get(key) or []
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except ValueError:
                raw = raw.split()
            if isinstance(raw, str):
                raw = [raw]
        return [str(t) for t in raw]

    def run_tests(ids):
        # Test ids can contain brackets, spaces and other shell metacharacters.
        quoted = " ".join(shlex.quote(t) for t in ids[:10])
        return sh(
            f"cd {repo_dir} && {conda} run -n {env_name} python -m pytest -v --tb=short -x {quoted}",
            timeout=300,
        )

    base = instance.get("base_commit", "")
    tp = instance.get("test_patch") or ""
    f2p = test_ids("FAIL_TO_PASS")
    if not f2p:
        # With no ids pytest would collect the whole suite, which says nothing
        # about whether the issue is fixed.
        return {"resolved": False, "error": "no FAIL_TO_PASS tests listed"}

    rc, out, err = sh(f"cd {repo_dir} && git checkout -f {base} && git clean -fd", timeout=30)
    if rc:
        return {"resolved": False, "error": f"checkout: {err[:150]}"}

    patch_text = patch if patch.endswith("\n") else patch + "\n"
    (Path(repo_dir) / "_aco.patch").write_text(patch_text)
    rc, out, err = sh(f"cd {repo_dir} && git apply --check _aco.patch", timeout=10)
    if rc:
        return {"resolved": False, "error": f"patch check: {err[:150]}"}
    rc, out, err = sh(f"cd {repo_dir} && git apply _aco.patch", timeout=10)
    if rc:
        return {"resolved": False, "error": f"patch apply: {err[:150]}"}

    # Best effort: fall back to --reject so the hunks that do apply still land.
    (Path(repo_dir) / "_t.patch").write_text(tp)
    sh(f"cd {repo_dir} && (git apply _t.patch) || git apply --reject _t.patch 2>/dev/null; true", timeout=10)

    rc, out, err = run_tests(f2p)
    if rc:
        return {
            "resolved": False,
            "error": f"{len(re.findall(r'FAILED', out+err))} F2P failures",
            "test_output": (out + err)[:500],
        }

    p2p = test_ids("PASS_TO_PASS")
    if p2p:
        rc2, out2, err2 = run_tests(p2p)
        if rc2:
            return {"resolved": False, "error": f"P2P: {(out2+err2)[:200]}"}
    return {"resolved": True, "test_output": (out + err)[:500]}
|
|
def setup_env(conda, instance, repo_dir, env_name):
    """Create the conda environment for *instance* and install the project into it.

    If the instance names an ``environment_setup_commit`` it is checked out
    first, and an ``environment.yml`` found at one of the known locations is
    used to create the environment. Without one, or if that creation fails, a
    plain ``python=3.10`` environment is created. The repository is then moved
    to ``base_commit`` and installed in editable mode, with a regular install
    as fallback, and pytest is installed for the test runs.

    Args:
        conda: Path to the conda executable.
        instance: Mapping with ``base_commit`` and optionally
            ``environment_setup_commit``.
        repo_dir: Repository checkout.
        env_name: Name of the environment to create.

    Returns:
        ``(True, "")`` on success, or ``(False, message)`` when the
        environment cannot be created or ``base_commit`` cannot be checked
        out. A failed pip install is reported on stdout but is not fatal.
    """
    ec = instance.get("environment_setup_commit", "")
    if ec:
        sh(f"cd {repo_dir} && git fetch origin {ec} && git checkout {ec}", timeout=60)

    candidates = (
        "environment.yml",
        "dev/environment.yml",
        ".github/environment.yml",
        "ci/environment.yml",
    )
    eyml = next((c for c in candidates if (Path(repo_dir) / c).exists()), None)

    plain_create = f"{conda} create -n {env_name} python=3.10 pip -y"
    if eyml:
        rc, out, err = sh(f"{conda} env create -f {repo_dir}/{eyml} -n {env_name} --quiet", timeout=600)
    else:
        rc, out, err = sh(plain_create, timeout=300)
    if rc:
        # Second attempt with the plain environment.
        rc, out, err = sh(plain_create, timeout=300)
        if rc:
            return False, f"conda: {err[:200]}"

    base = instance["base_commit"]
    rc, out, err = sh(f"cd {repo_dir} && git fetch origin {base} && git checkout {base}", timeout=60)
    if rc:
        # Running the agent on a different commit would invalidate the result.
        return False, f"checkout: {err[:200]}"

    pip = f"cd {repo_dir} && {conda} run -n {env_name} pip install"
    # Prefer an editable install so patches applied to the checkout later are
    # what the tests import. The regular install is only a fallback: running
    # it unconditionally afterwards would replace the editable install with a
    # static copy. Output is not piped through `tail` so pip's exit status
    # stays visible.
    rc, out, err = sh(f"{pip} -e . 2>&1", timeout=300)
    if rc:
        print(f" editable install failed, trying regular install: {(out + err)[-200:]}")
        rc, out, err = sh(f"{pip} . 2>&1", timeout=300)
        if rc:
            print(f" project install failed: {(out + err)[-200:]}")

    # The plain environment contains only python and pip, and every test run
    # goes through `python -m pytest`.
    rc, out, err = sh(f"{pip} pytest 2>&1", timeout=300)
    if rc:
        print(f" pytest install failed: {(out + err)[-200:]}")
    return True, ""
|
|
def _run_instance(iid, instance, conda, tmp_prefix):
    """Clone, set up, run the cascade on and verify one instance; return its result row."""
    env_name = f"aco_{iid.replace('__','_').replace('-','_')[:30]}"
    with tempfile.TemporaryDirectory(prefix=tmp_prefix) as tmpdir:
        repo_dir = Path(tmpdir) / "repo"

        print("Clone...")
        url = f"https://github.com/{instance['repo']}.git"
        rc, out, err = sh(f"git clone --depth 100 {url} {repo_dir}", timeout=600)
        if rc:
            return {"instance_id": iid, "resolved": False, "error": f"Clone: {err[:200]}"}

        try:
            print("Env...")
            ok, err = setup_env(conda, instance, repo_dir, env_name)
            if not ok:
                return {"instance_id": iid, "resolved": False, "error": f"Env: {err}"}

            print("Cascade...")
            agent = run_cascade(instance, repo_dir, conda, env_name)
            if not agent["patch"]:
                return {"instance_id": iid, "resolved": False, "tier": None, "error": "No valid patch"}

            print("Verify...")
            verify = verify_patch(instance, agent["patch"], repo_dir, conda, env_name)

            status = "✅" if verify["resolved"] else "❌"
            print(f" {status} {agent['tier']} {agent['turns']}t")
            return {
                "instance_id": iid,
                "repo": instance["repo"],
                "resolved": verify["resolved"],
                "tier": agent["tier"],
                "turns": agent["turns"],
                "input_tokens": agent["input_tokens"],
                "output_tokens": agent["output_tokens"],
                "error": verify.get("error"),
                "timestamp": datetime.now().isoformat(),
            }
        finally:
            # Remove the environment on every path once setup was attempted,
            # including failures, so partial environments do not pile up.
            try:
                sh(f"{conda} env remove -n {env_name} -y --quiet 2>/dev/null; true", timeout=120)
            except subprocess.TimeoutExpired:
                print(f" env cleanup timed out: {env_name}")


def main():
    """Run the cascade over the selected instances and write ``batch_results.jsonl``.

    Configuration comes from environment variables:

    * ``INSTANCE_TARGET`` - ``cascade-only`` (default) uses ``CASCADE_ONLY``,
      ``frontier-only`` uses ``FRONTIER_ONLY``, and any other value takes
      instances from the dataset in order.
    * ``MAX_INSTANCES`` - number of instances to run (default 3).

    Each instance is processed in isolation: a failure is recorded as an
    unresolved result and the batch continues. The results file is written
    even if the loop is interrupted. Exits with status 1 when conda is
    unavailable.
    """
    import datasets

    target = os.environ.get("INSTANCE_TARGET", "cascade-only")
    max_instances = int(os.environ.get("MAX_INSTANCES", "3"))

    print(f"🚀 BATCH CASCADE VALIDATION — target={target} max={max_instances}")
    print(f" {datetime.now().isoformat()}")

    conda = ensure_conda()
    if not conda:
        print("❌ No conda")
        sys.exit(1)

    ds = datasets.load_dataset("princeton-nlp/SWE-bench_Verified", split="test")

    if target == "cascade-only":
        iids = CASCADE_ONLY[:max_instances]
    elif target == "frontier-only":
        iids = FRONTIER_ONLY[:max_instances]
    else:
        iids = [r["instance_id"] for r in ds][:max_instances]

    wanted = set(iids)  # O(1) membership while scanning the dataset
    instances = {
        row["instance_id"]: dict(row) for row in ds if row["instance_id"] in wanted
    }

    print(f"Instances: {iids}\n")

    results = []
    try:
        for i, iid in enumerate(iids):
            instance = instances.get(iid)
            if not instance:
                print(f"Skipping {iid}: not found in dataset")
                continue

            print(f"\n{'='*60}\n[{i+1}/{len(iids)}] {iid}\n{'='*60}")
            try:
                result = _run_instance(iid, instance, conda, f"aco_{i}_")
            except Exception as exc:
                # Per-instance boundary: one broken instance must not discard
                # the results already collected.
                traceback.print_exc()
                result = {"instance_id": iid, "resolved": False, "error": f"Crash: {exc}"}
            results.append(result)
    finally:
        with open("batch_results.jsonl", "w") as f:
            for r in results:
                f.write(json.dumps(r) + "\n")

    resolved = [r for r in results if r["resolved"]]
    t1 = [r for r in resolved if r.get("tier") == "T1"]
    t2 = [r for r in resolved if r.get("tier") == "T2"]

    print(f"\n{'='*60}\nRESULTS: {len(resolved)}/{len(results)} resolved")
    print(f" T1: {len(t1)} T2: {len(t2)}")
    for r in results:
        s = "✅" if r["resolved"] else "❌"
        print(f" {s} {r['instance_id']} [{r.get('tier') or ''}]")
    print("Saved: batch_results.jsonl")
|
|
if __name__ == "__main__":
    # Script entry point: any uncaught exception is printed with its traceback
    # and turned into exit status 1. SystemExit raised by sys.exit() is not an
    # Exception subclass and passes through untouched.
    try:
        sys.exit(main())
    except Exception as e:
        print(f"💥 {e}")
        traceback.print_exc()
        sys.exit(1)
|
|