"""Batch Cascade Validation Script.

Runs the cascade agent on MULTIPLE SWE-bench instances and verifies each patch
via a conda environment + pytest. This is the script that proves the cascade
causally, not just correlationally from trace simulation.

Usage:
    python batch_validate.py --instances 3 --target cascade-only
    python batch_validate.py --instances 5 --target all

The INSTANCE_TARGET and MAX_INSTANCES environment variables provide the
defaults for --target and --instances.

Requirements: hf_jobs with a10g-largex2, 8h timeout
"""
import argparse
import json
import os
import re
import shlex
import signal
import subprocess
import sys
import tempfile
import time
import traceback
from datetime import datetime
from pathlib import Path

# ============================================================
# INSTANCE LISTS
# ============================================================
# Curated SWE-bench_Verified instance ids, selected with --target.
CASCADE_ONLY = [
    "astropy__astropy-14365", "astropy__astropy-14995",
    "django__django-11815", "django__django-13089", "django__django-13807",
    "django__django-14315", "matplotlib__matplotlib-25224",
    "matplotlib__matplotlib-25311", "sympy__sympy-19487", "sympy__sympy-20590",
]
FRONTIER_ONLY = [
    "django__django-12453", "django__django-14030", "django__django-14349",
    "django__django-14855", "django__django-15098", "django__django-16235",
    "matplotlib__matplotlib-26020", "psf__requests-6028",
    "pylint-dev__pylint-7080", "scikit-learn__scikit-learn-13439",
    "scikit-learn__scikit-learn-14087", "sphinx-doc__sphinx-10323",
    "sphinx-doc__sphinx-10466", "sphinx-doc__sphinx-10614",
]

# Tiers are tried in order; a later tier runs only if the earlier one did not
# produce a patch that passes `git apply --check`.
TIERS = (
    ("T1", "meta-llama/Llama-3.1-8B-Instruct"),
    ("T2", "meta-llama/Llama-3.3-70B-Instruct"),
)
MAX_TURNS_PER_TIER = 30
MAX_TESTS_PER_RUN = 10  # F2P / P2P ids handed to each pytest run
MAX_API_ERRORS = 3  # consecutive model-call failures before escalating a tier
RESULTS_FILE = "batch_results.jsonl"
ENV_YML_CANDIDATES = (
    "environment.yml", "dev/environment.yml",
    ".github/environment.yml", "ci/environment.yml",
)

# Agent protocol delimiters (the system prompt in run_cascade uses the same tags).
_PATCH_TAG_RE = re.compile(r"<patch>\s*\n?(.*?)</patch>", re.DOTALL)
_DIFF_FENCE_RE = re.compile(r"```diff\s*\n(.*?)```", re.DOTALL)
_RAW_DIFF_END_RE = re.compile(r"\n(?:```|</?patch>|<done>|<bash>)")
_BASH_TAG_RE = re.compile(r"<bash>(.*?)</bash>", re.DOTALL)
# `pytest` / `python -m pytest` in command position only (line start or after ; & | or '(').
_PYTEST_CMD_RE = re.compile(r"(?m)(^\s*|[;&|(]\s*)(?:python3?\s+-m\s+)?pytest(?![\w.-])")


def sh(cmd, cwd=None, timeout=120):
    """Run *cmd* through the shell and return (returncode, stdout, stderr).

    Never raises on timeout. The command runs in its own session, so on timeout
    the whole process group (including e.g. a pytest started by the shell) is
    killed. The call then returns 124 with a note appended to stderr, so one
    slow command cannot abort the batch. stdin is closed so interactive
    commands fail fast instead of hanging. Undecodable output bytes are replaced.
    """
    proc = subprocess.Popen(
        cmd, cwd=cwd, shell=True, text=True, errors="replace",
        stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        start_new_session=True,
    )
    try:
        out, err = proc.communicate(timeout=timeout)
        return proc.returncode, out, err
    except subprocess.TimeoutExpired:
        try:
            os.killpg(proc.pid, signal.SIGKILL)
        except ProcessLookupError:
            pass
        out, err = proc.communicate()
        return 124, out or "", (err or "") + f"\n[TIMEOUT after {timeout}s]"


def ensure_conda():
    """Return the path to a conda executable, installing Miniconda if needed.

    Checks ~/miniconda3 and /opt/conda first. Otherwise installs Miniconda into
    ~/miniconda3, configures it, and prepends its bin dir to PATH. Returns None
    if the install did not produce ~/miniconda3/bin/conda.
    """
    home_conda = os.path.expanduser("~/miniconda3/bin/conda")
    for path in (home_conda, "/opt/conda/bin/conda"):
        if os.path.exists(path):
            return path
    print("📦 Installing Miniconda...")
    rc, out, err = sh("wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && bash /tmp/miniconda.sh -b -p $HOME/miniconda3", timeout=300)
    if not os.path.exists(home_conda):
        print(f"❌ Miniconda install failed (exit {rc}): {(out + err)[-300:]}")
        return None
    q = shlex.quote(home_conda)
    # Non-interactive config; `tos accept` is best effort (older conda lacks it).
    sh(f"{q} config --set always_yes yes --set changeps1 no && {q} tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main 2>/dev/null; true", timeout=30)
    os.environ["PATH"] = os.path.expanduser("~/miniconda3/bin:") + os.environ.get("PATH", "")
    return home_conda


def is_valid_patch(text):
    """Cheap shape check: non-trivial text with a git diff header and a hunk."""
    return bool(text) and len(text) > 10 and "diff --git" in text and "@@" in text


def _normalize_patch(patch):
    """Trim surrounding blank lines and end with exactly one newline.

    `git apply` treats a final line without a newline as a corrupt patch.
    Only line breaks are stripped at the end, so a trailing context line
    consisting of a single space survives.
    """
    return patch.lstrip().rstrip("\r\n") + "\n"


def extract_patch(text):
    """Return the unified diff in a model reply, or None if there is none.

    Tried in order: a <patch>...</patch> block, a ```diff fenced block, then a
    raw diff starting at the first "diff --git" and ending before the next
    fence or protocol tag. The diff is never truncated, because a cut diff is
    a corrupt patch.
    """
    if not text:
        return None
    for pattern in (_PATCH_TAG_RE, _DIFF_FENCE_RE):
        m = pattern.search(text)
        if m and is_valid_patch(m.group(1)):
            return _normalize_patch(m.group(1))
    start = text.find("diff --git")
    if start >= 0:
        candidate = _RAW_DIFF_END_RE.split(text[start:], maxsplit=1)[0]
        if is_valid_patch(candidate):
            return _normalize_patch(candidate)
    return None


def call_model(client, messages, max_tokens=4096):
    """Send one chat completion request and return (text, input_tokens, output_tokens).

    Any API failure is returned as text starting with "[ERROR:" with zero
    token counts rather than raised. Missing usage data falls back to 0 input
    tokens and a len(text)//4 output estimate.
    """
    try:
        completion = client.chat.completions.create(
            model=client.model, messages=messages,
            max_tokens=max_tokens, temperature=0.2)
        text = completion.choices[0].message.content or ""
        usage = getattr(completion, "usage", None)
    except Exception as exc:  # network/API boundary: report, don't crash the batch
        return f"[ERROR: {exc}]", 0, 0
    n_in = (usage.prompt_tokens or 0) if usage else 0
    n_out = (usage.completion_tokens or 0) if usage else len(text) // 4
    return text, n_in, n_out


def rewrite_pytest_cmd(cmd, runner):
    """Route pytest invocations in *cmd* through *runner* (the instance env).

    Only `pytest` / `python -m pytest` in command position is replaced, so
    arguments such as `pytest.ini` or `grep pytest` are left untouched.
    """
    return _PYTEST_CMD_RE.sub(lambda m: m.group(1) + runner, cmd)


def run_cascade(instance, repo_dir, conda, env_name, max_turns=MAX_TURNS_PER_TIER):
    """Run the model tiers in TIERS order until one yields an applicable patch.

    Protocol (stated in the system prompt): the model runs shell commands in
    <bash>...</bash> and receives their output in <output> messages. It submits
    a diff in <patch>...</patch> and ends a tier with <done>. The conversation
    is shared, so a later tier starts from the earlier tier's transcript.

    Returns a dict with keys patch (str or None), tier (winning tier or None),
    turns, input_tokens and output_tokens. The counts are cumulative over every
    tier attempted, so failures also report their spend.
    """
    from huggingface_hub import InferenceClient

    repo_dir = Path(repo_dir)
    pytest_runner = f"{shlex.quote(conda)} run -n {env_name} python -m pytest"
    system = (
        f"Fix bug in {instance['repo']}. Repo: {repo_dir}.\n"
        "Run shell commands as:\n<bash>cmd</bash>\n"
        "Submit the fix as a unified diff:\n"
        "<patch>\ndiff --git a/path b/path\n--- a/path\n+++ b/path\n@@ -N,M +N,M @@\nchanges\n</patch>\n"
        "<done>Done</done>"
    )
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": f"PROBLEM:\n{instance.get('problem_statement', '')}\n\nExplore and fix."},
    ]
    patch_path = repo_dir / "_c.patch"
    total_in = total_out = total_turns = 0

    for tier_index, (tier, model_id) in enumerate(TIERS):
        print(f"\n[{tier}] {model_id}")
        if tier_index:
            messages.append({"role": "user", "content": "The previous attempt did not produce a valid patch. Continue and fix the bug."})
        client = InferenceClient(model_id)
        api_errors = 0
        for turn in range(max_turns):
            text, n_in, n_out = call_model(client, messages, 4096)
            total_turns += 1
            if text.startswith("[ERROR:"):
                # Don't feed the error string back as if the model had said it.
                api_errors += 1
                print(f"  T{turn+1}: {text[:200]}")
                if api_errors >= MAX_API_ERRORS:
                    print(f"  ⚠️ {api_errors} consecutive API errors; escalating")
                    break
                continue
            api_errors = 0
            total_in += n_in
            total_out += n_out
            messages.append({"role": "assistant", "content": text})
            print(f"  T{turn+1}: {n_in}+{n_out} tok")

            patch = extract_patch(text)
            if patch:
                patch_path.write_text(patch, encoding="utf-8")
                # --cached checks against the index (the base commit), so a
                # `git diff` of edits the model already made via <bash> still
                # validates; verify_patch applies it to a clean checkout too.
                rc, out, err = sh(f"git apply --check --cached {patch_path.name}", cwd=repo_dir, timeout=10)
                if rc == 0:
                    print(f"  ✅ VALID ({len(patch)}ch)")
                    return {"patch": patch, "tier": tier, "turns": total_turns,
                            "input_tokens": total_in, "output_tokens": total_out}
                detail = (out + err).strip()  # git's message may be on either stream
                print(f"  ❌ Invalid: {detail[:100]}")
                messages.append({"role": "user", "content": f"Patch check failed: {detail[:200]}\nUse git diff for valid unified diff."})
                continue

            commands = [c.strip() for c in _BASH_TAG_RE.findall(text) if c.strip()]
            for cmd in commands:
                rc, out, err = sh(rewrite_pytest_cmd(cmd, pytest_runner), cwd=str(repo_dir), timeout=60)
                body = (out + err)[:1500]
                if rc:
                    body += f"\n[EXIT:{rc}]"
                messages.append({"role": "user", "content": f"<output>\n{body}\n</output>"})
            if "<done>" in text:
                break
            if not commands:
                messages.append({"role": "user", "content": "No <bash> command or <patch> found. Reply with <bash>cmd</bash>, <patch>...</patch>, or <done>Done</done>."})

    return {"patch": None, "tier": None, "turns": total_turns,
            "input_tokens": total_in, "output_tokens": total_out}


def parse_test_list(value):
    """Normalize a FAIL_TO_PASS / PASS_TO_PASS field to a list of test ids.

    SWE-bench on the Hub stores these columns as JSON-encoded strings; slicing
    such a string yields characters, not tests. Lists are accepted as-is. A
    string that is not valid JSON is treated as a single id.
    """
    if not value:
        return []
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except json.JSONDecodeError:
            return [value]
        if not isinstance(value, list):
            return [str(value)] if value else []
    return [str(t) for t in value]


def verify_patch(instance, patch, repo_dir, conda, env_name, max_tests=MAX_TESTS_PER_RUN):
    """Apply *patch* and the instance's test_patch at base_commit, then run tests.

    Resolved means pytest exits 0 on the first *max_tests* FAIL_TO_PASS ids.
    When PASS_TO_PASS ids exist, pytest must also exit 0 on the first
    *max_tests* of them.

    NOTE(review): ids are passed to pytest verbatim (shell-quoted). Some repos'
    SWE-bench ids (presumably django's "test_x (module.Class)" and sympy's bare
    function names) are not pytest node ids, so those instances will report
    unresolved here. TODO confirm and add per-repo test commands.

    Returns {"resolved": bool, "error": str (on failure), "test_output": str}.
    """
    f2p = parse_test_list(instance.get("FAIL_TO_PASS"))
    p2p = parse_test_list(instance.get("PASS_TO_PASS"))
    if not f2p:
        # pytest with no ids would run the whole suite and prove nothing.
        return {"resolved": False, "error": "no FAIL_TO_PASS tests in instance"}

    repo_dir = Path(repo_dir)
    base = instance.get("base_commit", "")
    rc, out, err = sh(f"git checkout -f {base} && git clean -fd", cwd=repo_dir, timeout=60)
    if rc:
        return {"resolved": False, "error": f"reset: {(out + err)[:150]}"}

    (repo_dir / "_aco.patch").write_text(_normalize_patch(patch), encoding="utf-8")
    rc, out, err = sh("git apply --check _aco.patch", cwd=repo_dir, timeout=10)
    if rc:
        return {"resolved": False, "error": f"patch check: {(out + err)[:150]}"}
    rc, out, err = sh("git apply _aco.patch", cwd=repo_dir, timeout=10)
    if rc:
        return {"resolved": False, "error": f"patch apply: {(out + err)[:150]}"}

    test_patch = instance.get("test_patch", "")
    if test_patch:
        (repo_dir / "_t.patch").write_text(test_patch, encoding="utf-8")
        # Best effort: fall back to --reject to keep whichever hunks apply.
        sh("git apply _t.patch || git apply --reject _t.patch 2>/dev/null; true", cwd=repo_dir, timeout=30)

    pytest = f"{shlex.quote(conda)} run -n {env_name} python -m pytest -v --tb=short -x"
    f2p_args = " ".join(shlex.quote(t) for t in f2p[:max_tests])
    rc, out, err = sh(f"{pytest} {f2p_args}", cwd=repo_dir, timeout=300)
    log = out + err
    if rc:
        failures = len(re.findall(r"FAILED", log))
        return {"resolved": False, "error": f"{failures} F2P failures (pytest exit {rc})", "test_output": log[-500:]}

    if p2p:
        p2p_args = " ".join(shlex.quote(t) for t in p2p[:max_tests])
        rc2, out2, err2 = sh(f"{pytest} {p2p_args}", cwd=repo_dir, timeout=300)
        if rc2:
            return {"resolved": False, "error": f"P2P: {(out2 + err2)[-200:]}"}
    return {"resolved": True, "test_output": log[-500:]}


def setup_env(conda, instance, repo_dir, env_name):
    """Create conda env *env_name* for *instance* and install the repo into it.

    An environment.yml is looked up at environment_setup_commit, falling back
    to a bare python=3.10 env. The repo is then checked out at base_commit,
    installed editable (regular install only as a fallback), and pytest is
    installed. Returns (ok, error_message); install failures are warnings,
    matching the original best-effort behaviour.
    """
    repo_dir = Path(repo_dir)
    conda_q = shlex.quote(conda)
    ec = instance.get("environment_setup_commit", "")
    if ec:
        sh(f"git fetch origin {ec}; git checkout {ec}", cwd=repo_dir, timeout=300)
    eyml = next((c for c in ENV_YML_CANDIDATES if (repo_dir / c).exists()), None)

    # A stale env (e.g. from a crashed run) would make `create` fail.
    remove_env = f"{conda_q} env remove -n {env_name} -y --quiet 2>/dev/null; true"
    sh(remove_env, timeout=120)
    rc, out, err = 1, "", "no environment.yml"
    if eyml:
        rc, out, err = sh(f"{conda_q} env create -f {shlex.quote(str(repo_dir / eyml))} -n {env_name} --quiet", timeout=600)
        if rc:
            sh(remove_env, timeout=120)
    if rc:
        rc, out, err = sh(f"{conda_q} create -n {env_name} python=3.10 pip -y", timeout=300)
    if rc:
        return False, f"conda: {(out + err)[:200]}"

    base = instance["base_commit"]
    # `;` not `&&`: the commit may already be present even if the fetch fails.
    rc, out, err = sh(f"git fetch origin {base}; git checkout {base}", cwd=repo_dir, timeout=300)
    if rc:
        return False, f"checkout base: {(out + err)[:200]}"

    pip = f"{conda_q} run -n {env_name} pip install"
    # Editable install so the patched source is what the tests import; a
    # regular install afterwards would shadow it with a site-packages copy.
    rc, out, err = sh(f"{pip} -e .", cwd=repo_dir, timeout=900)
    if rc:
        print(f"  ⚠️ editable install failed, trying regular install: {(out + err)[-200:]}")
        rc, out, err = sh(f"{pip} .", cwd=repo_dir, timeout=900)
        if rc:
            print(f"  ⚠️ pip install failed: {(out + err)[-200:]}")
    # The python=3.10 fallback env has no pytest; verification needs it.
    sh(f"{pip} pytest", cwd=repo_dir, timeout=300)
    return True, ""


def _env_name(iid):
    """Conda env name for *iid*; keeps the tail so the numeric id survives truncation."""
    slug = iid.replace("__", "_").replace("-", "_")
    return f"aco_{slug[-30:]}"


def _process_instance(index, iid, instance, conda):
    """Clone, build the env, run the cascade and verify one instance.

    Returns the result row. The instance's conda env is removed whatever the
    outcome, including exceptions.
    """
    env_name = _env_name(iid)
    remove_env = f"{shlex.quote(conda)} env remove -n {env_name} -y --quiet 2>/dev/null; true"
    with tempfile.TemporaryDirectory(prefix=f"aco_{index}_") as tmpdir:
        repo_dir = Path(tmpdir) / "repo"
        try:
            print("Clone...")
            url = f"https://github.com/{instance['repo']}.git"
            rc, out, err = sh(f"git clone --depth 100 {url} {shlex.quote(str(repo_dir))}", timeout=600)
            if rc:
                return {"instance_id": iid, "resolved": False, "error": f"Clone: {err[:200]}"}

            print("Env...")
            ok, err = setup_env(conda, instance, repo_dir, env_name)
            if not ok:
                return {"instance_id": iid, "resolved": False, "error": f"Env: {err}"}

            print("Cascade...")
            agent = run_cascade(instance, repo_dir, conda, env_name)
            if not agent["patch"]:
                return {"instance_id": iid, "resolved": False, "tier": None,
                        "turns": agent["turns"], "input_tokens": agent["input_tokens"],
                        "output_tokens": agent["output_tokens"], "error": "No valid patch"}

            print("Verify...")
            verify = verify_patch(instance, agent["patch"], repo_dir, conda, env_name)
            status = "✅" if verify["resolved"] else "❌"
            print(f"  {status} {agent['tier']} {agent['turns']}t")
            return {
                "instance_id": iid, "repo": instance["repo"],
                "resolved": verify["resolved"], "tier": agent["tier"],
                "turns": agent["turns"], "input_tokens": agent["input_tokens"],
                "output_tokens": agent["output_tokens"], "error": verify.get("error"),
                "timestamp": datetime.now().isoformat(),
            }
        finally:
            sh(remove_env, timeout=120)


def _save_results(results, path):
    """Rewrite *path* as JSONL (one object per result), atomically via rename."""
    tmp = f"{path}.tmp"
    with open(tmp, "w", encoding="utf-8") as f:
        for row in results:
            f.write(json.dumps(row) + "\n")
    os.replace(tmp, path)


def _parse_args(argv=None):
    """Parse CLI flags; INSTANCE_TARGET / MAX_INSTANCES supply the defaults."""
    parser = argparse.ArgumentParser(description="Batch-validate the model cascade on SWE-bench instances.")
    parser.add_argument("--target", default=os.environ.get("INSTANCE_TARGET", "cascade-only"),
                        help="cascade-only, frontier-only, or all (any other value means all)")
    parser.add_argument("--instances", type=int, default=int(os.environ.get("MAX_INSTANCES", "3")),
                        help="maximum number of instances to run")
    parser.add_argument("--output", default=RESULTS_FILE, help="JSONL results path")
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        print(f"⚠️ ignoring unrecognized arguments: {unknown}")
    return args


def main(argv=None):
    """Run the batch and write one JSONL row per selected instance.

    Every selected instance gets a row, including ones missing from the
    dataset or that crashed, so the resolved rate's denominator is honest.
    Returns the process exit code.
    """
    import datasets

    args = _parse_args(argv)
    target, max_instances = args.target, args.instances
    print(f"🚀 BATCH CASCADE VALIDATION — target={target} max={max_instances}")
    print(f"   {datetime.now().isoformat()}")

    conda = ensure_conda()
    if not conda:
        print("❌ No conda")
        return 1

    ds = datasets.load_dataset("princeton-nlp/SWE-bench_Verified", split="test")
    if target == "cascade-only":
        iids = CASCADE_ONLY[:max_instances]
    elif target == "frontier-only":
        iids = FRONTIER_ONLY[:max_instances]
    else:
        iids = [row["instance_id"] for row in ds][:max_instances]
    wanted = set(iids)
    instances = {row["instance_id"]: dict(row) for row in ds if row["instance_id"] in wanted}
    print(f"Instances: {iids}\n")

    results = []
    for i, iid in enumerate(iids):
        print(f"\n{'='*60}\n[{i+1}/{len(iids)}] {iid}\n{'='*60}")
        instance = instances.get(iid)
        if instance is None:
            print("  ⚠️ not found in dataset")
            result = {"instance_id": iid, "resolved": False, "tier": None,
                      "error": "Instance not found in dataset"}
        else:
            try:
                result = _process_instance(i, iid, instance, conda)
            except Exception as exc:  # one broken instance must not abort the batch
                traceback.print_exc()
                result = {"instance_id": iid, "resolved": False, "tier": None,
                          "error": f"Crash: {exc!r}"[:300]}
        results.append(result)
        _save_results(results, args.output)  # incremental save
    _save_results(results, args.output)

    resolved = [r for r in results if r["resolved"]]
    t1 = [r for r in resolved if r.get("tier") == "T1"]
    t2 = [r for r in resolved if r.get("tier") == "T2"]
    print(f"\n{'='*60}\nRESULTS: {len(resolved)}/{len(results)} resolved")
    print(f"  T1: {len(t1)}  T2: {len(t2)}")
    for r in results:
        mark = "✅" if r["resolved"] else "❌"
        print(f"  {mark} {r['instance_id']} [{r.get('tier') or ''}]")
    print(f"Saved: {args.output}")
    return 0


if __name__ == "__main__":
    try:
        sys.exit(main())
    except Exception as e:
        print(f"💥 {e}")
        traceback.print_exc()
        sys.exit(1)