File size: 11,768 Bytes
29c4a80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""
Batch Cascade Validation Script

Runs the cascade agent on MULTIPLE SWE-bench instances and
verifies each patch via conda environment + pytest.

This is the script that proves the cascade causally, not just
correlationally from trace simulation.

Usage (configured through environment variables, not command-line flags):
    MAX_INSTANCES=3 INSTANCE_TARGET=cascade-only python batch_validate.py
    MAX_INSTANCES=5 INSTANCE_TARGET=all python batch_validate.py

Requirements: hf_jobs with a10g-largex2, 8h timeout
"""

import json, os, re, subprocess, sys, tempfile, time, traceback
from datetime import datetime
from pathlib import Path

# ============================================================
# INSTANCE LISTS
# ============================================================
# Instance IDs used when the run target is "cascade-only"; main() takes the
# first MAX_INSTANCES entries in this order.
# NOTE(review): presumably instances only the cascade resolved in earlier
# trace simulation -- confirm against the source of these lists.
CASCADE_ONLY = [
    "astropy__astropy-14365",
    "astropy__astropy-14995",
    "django__django-11815",
    "django__django-13089",
    "django__django-13807",
    "django__django-14315",
    "matplotlib__matplotlib-25224",
    "matplotlib__matplotlib-25311",
    "sympy__sympy-19487",
    "sympy__sympy-20590",
]

# Instance IDs used when the run target is "frontier-only"; same slicing rule.
FRONTIER_ONLY = [
    "django__django-12453",
    "django__django-14030",
    "django__django-14349",
    "django__django-14855",
    "django__django-15098",
    "django__django-16235",
    "matplotlib__matplotlib-26020",
    "psf__requests-6028",
    "pylint-dev__pylint-7080",
    "scikit-learn__scikit-learn-13439",
    "scikit-learn__scikit-learn-14087",
    "sphinx-doc__sphinx-10323",
    "sphinx-doc__sphinx-10466",
    "sphinx-doc__sphinx-10614",
]

def sh(cmd, cwd=None, timeout=120):
    """Run ``cmd`` through the shell and return ``(returncode, stdout, stderr)``.

    Args:
        cmd: Shell command line; executed with ``shell=True``.
        cwd: Working directory for the command, or ``None`` for the current one.
        timeout: Seconds to wait before the command is killed.

    Returns:
        Tuple ``(returncode, stdout, stderr)`` with text output. If the timeout
        expires, the return code is 124 (the ``timeout(1)`` convention) and
        stderr ends with a ``[TIMEOUT after Ns]`` marker, so a single slow
        command is reported as a failure instead of raising and aborting the
        whole batch.
    """
    try:
        r = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout, shell=True)
    except subprocess.TimeoutExpired as exc:
        # Partial output on the exception can be bytes or None even in text
        # mode, so normalise both streams to str.
        captured = []
        for stream in (exc.stdout, exc.stderr):
            if isinstance(stream, bytes):
                stream = stream.decode(errors="replace")
            captured.append(stream or "")
        out, err = captured
        return 124, out, f"{err}[TIMEOUT after {timeout}s]\n"
    return r.returncode, r.stdout, r.stderr

def ensure_conda():
    """Return the path to a conda executable, installing Miniconda if needed.

    Looks for an existing install under ``~/miniconda3`` and ``/opt/conda``.
    If neither exists, downloads the Miniconda installer, installs it into
    ``~/miniconda3``, applies non-interactive settings and prepends its
    ``bin`` directory to ``PATH``.

    Returns:
        Path to the conda executable, or ``None`` when the install did not
        produce one, so that the caller's falsy check can actually abort.
    """
    home_conda = os.path.expanduser("~/miniconda3/bin/conda")
    for p in [home_conda, "/opt/conda/bin/conda"]:
        if os.path.exists(p):
            return p
    print("πŸ“¦ Installing Miniconda...")
    rc, out, err = sh("wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && bash /tmp/miniconda.sh -b -p $HOME/miniconda3", timeout=300)
    if not os.path.exists(home_conda):
        # Download or installer failed: report it instead of returning a path
        # that does not exist.
        print(f"Miniconda install failed (exit {rc}): {(out + err)[:200]}")
        return None
    p = home_conda
    # Best-effort configuration; the trailing "; true" keeps a missing "tos"
    # subcommand from counting as a failure.
    sh(f"{p} config --set always_yes yes --set changeps1 no && {p} tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main 2>/dev/null; true", timeout=30)
    os.environ["PATH"] = os.path.expanduser("~/miniconda3/bin:") + os.environ.get("PATH", "")
    return p

def is_valid_patch(text):
    """Return True when ``text`` looks like a git-style unified diff.

    The check is purely textual: the text must be non-empty, longer than 10
    characters, and contain both a ``diff --git`` header and an ``@@`` hunk
    marker. It does not verify that the patch applies.
    """
    if not text:
        return False
    if len(text) <= 10:
        return False
    has_header = 'diff --git' in text
    has_hunk = '@@' in text
    return has_header and has_hunk

def extract_patch(text, max_fallback_chars=2000):
    """Pull a unified diff out of a model response.

    Tries, in order: the body of a ``<patch>...</patch>`` element, the body of
    a fenced diff code block, and finally everything from the first
    ``diff --git`` onwards. Each candidate must pass ``is_valid_patch``.

    Args:
        text: Raw model response. ``None`` or an empty string yields ``None``
            instead of raising inside ``re.search``.
        max_fallback_chars: Cap applied to the untagged fallback candidate
            only. A patch longer than this is cut off and will normally be
            rejected later by ``git apply --check``.

    Returns:
        The stripped patch text, or ``None`` if no valid patch was found.
    """
    if not text:
        return None
    for pattern in (r'<patch>\s*\n?(.*?)</patch>', r'```diff\s*\n(.*?)```'):
        m = re.search(pattern, text, re.DOTALL)
        if m and is_valid_patch(m.group(1)):
            return m.group(1).strip()
    di = text.find('diff --git')
    if di >= 0:
        patch = text[di:].strip()[:max_fallback_chars]
        if is_valid_patch(patch):
            return patch
    return None

def call_model(client, messages, max_tokens=4096):
    """Send ``messages`` to the chat model and return the reply with token counts.

    Args:
        client: Object exposing ``model`` and ``chat.completions.create(...)``.
        messages: Chat history as a list of ``{"role", "content"}`` dicts.
        max_tokens: Completion token limit passed to the API.

    Returns:
        Tuple ``(text, input_tokens, output_tokens)``. ``text`` is always a
        string: a ``None`` content field is returned as ``""`` so callers can
        safely run string operations on it. If the response carries no usage
        data, input tokens are 0 and output tokens are estimated as
        ``len(text) // 4``. Any exception is reported as
        ``("[ERROR: ...]", 0, 0)`` rather than raised (deliberate best-effort).
    """
    try:
        c = client.chat.completions.create(model=client.model, messages=messages, max_tokens=max_tokens, temperature=0.2)
        t = c.choices[0].message.content or ""
        usage = getattr(c, "usage", None)
        if usage:
            return t, usage.prompt_tokens, usage.completion_tokens
        return t, 0, len(t) // 4
    except Exception as e:
        return f"[ERROR: {e}]", 0, 0

def _route_pytest(cmd, conda, env_name):
    """Rewrite stand-alone pytest invocations in ``cmd`` to run inside the conda env.

    Both ``pytest`` and ``python -m pytest`` become
    ``<conda> run -n <env> python -m pytest``. Names that merely contain the
    word (``pytest.ini``, ``_pytest``, ``pytest-cov``) are left alone.
    """
    runner = f"{conda} run -n {env_name} python -m pytest"
    # A callable replacement avoids backslash interpretation of the path.
    return re.sub(r'(?<![\w./-])(?:python3?\s+-m\s+)?pytest(?![\w.-])', lambda _m: runner, cmd)


def run_cascade(instance, repo_dir, conda, env_name):
    """Run the two-tier model cascade until one tier yields a patch that applies.

    Each tier gets up to 30 turns on a shared conversation. On every turn the
    reply is scanned for a patch; a patch is accepted as soon as
    ``git apply --check`` succeeds in ``repo_dir``. Otherwise any
    ``<bash>...</bash>`` commands are run in the repo and their output is fed
    back. A ``<submit>`` tag ends the current tier.

    Args:
        instance: Dataset row; ``repo`` and ``problem_statement`` are read.
        repo_dir: Path of the checked-out repository.
        conda: Path to the conda executable.
        env_name: Conda environment used for pytest invocations.

    Returns:
        Dict with ``patch``, ``tier`` ("T1"/"T2"), ``turns``, ``input_tokens``
        and ``output_tokens``. Token counts cover only the tier that produced
        the patch. If no tier succeeds, ``patch`` and ``tier`` are ``None`` and
        the counters are 0.
    """
    from huggingface_hub import InferenceClient
    T1, T2 = "meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.3-70B-Instruct"

    system = f"Fix bug in {instance['repo']}. Repo: {repo_dir}.\n<bash>cmd</bash>\n<patch>\ndiff --git a/path b/path\n--- a/path\n+++ b/path\n@@ -N,M +N,M @@\nchanges\n</patch>\n<submit>Done</submit>"
    messages = [{"role":"system","content":system},{"role":"user","content":f"PROBLEM:\n{instance.get('problem_statement','')}\n\nExplore and fix."}]

    for tier, mid, mt in [("T1",T1,30),("T2",T2,30)]:
        print(f"\n[{tier}] {mid}")
        client = InferenceClient(mid)
        ti = to = 0
        for turn in range(mt):
            text, it, ot = call_model(client, messages, 4096)
            # Guard against a None reply so the string handling below is safe.
            text = text or ""
            ti += it; to += ot
            messages.append({"role":"assistant","content":text})
            print(f"  T{turn+1}: {it}+{ot} tok")

            patch = extract_patch(text)
            if patch:
                (Path(repo_dir)/"_c.patch").write_text(patch)
                rc, out, err = sh(f"cd {repo_dir} && git apply --check _c.patch 2>&1", timeout=10)
                if rc == 0:
                    print(f"  βœ… VALID ({len(patch)}ch)")
                    return {"patch":patch,"tier":tier,"turns":turn+1,"input_tokens":ti,"output_tokens":to}
                # "2>&1" routes git's diagnostic to stdout, so report both
                # streams; stderr alone is always empty here.
                detail = (out + err).strip()
                print(f"  ❌ Invalid: {detail[:100]}")
                messages.append({"role":"user","content":f"Patch check failed: {detail[:200]}\nUse git diff for valid unified diff."})
                continue

            for cmd in re.findall(r'<bash>(.*?)</bash>', text, re.DOTALL):
                cmd = _route_pytest(cmd.strip(), conda, env_name)
                rc, out, err = sh(cmd, cwd=str(repo_dir), timeout=60)
                o = f"<output>\n{(out+err)[:1500]}\n</output>"
                # Splice the exit status in just before the closing tag.
                if rc: o = o[:-9] + f" [EXIT:{rc}]\n</output>"
                messages.append({"role":"user","content":o})

            if "<submit>" in text: break
    return {"patch":None,"tier":None,"turns":0,"input_tokens":0,"output_tokens":0}

def _test_ids(value):
    """Normalise a FAIL_TO_PASS / PASS_TO_PASS field to a list of test-ID strings.

    The SWE-bench datasets store these fields as JSON-encoded strings; an
    already-decoded list, a single plain ID, or ``None`` are accepted too.
    """
    if isinstance(value, str):
        raw = value.strip()
        if not raw:
            return []
        try:
            value = json.loads(raw)
        except ValueError:
            return [raw]  # not JSON: treat as one plain test ID
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return [str(t) for t in value]
    return [str(value)]


def verify_patch(instance, patch, repo_dir, conda, env_name):
    """Apply ``patch`` at the base commit and run the instance's tests.

    Steps: hard-reset the repo to ``base_commit``, apply the candidate patch,
    apply the dataset's ``test_patch`` (best-effort, with ``--reject`` as a
    fallback), run up to 10 FAIL_TO_PASS tests, and, if they pass, up to 10
    PASS_TO_PASS tests.

    Args:
        instance: Dataset row; reads ``base_commit``, ``test_patch``,
            ``FAIL_TO_PASS`` and ``PASS_TO_PASS``.
        patch: Unified diff text to verify.
        repo_dir: Path of the cloned repository.
        conda: Path to the conda executable.
        env_name: Conda environment in which pytest is run.

    Returns:
        ``{"resolved": True, "test_output": ...}`` on success, otherwise
        ``{"resolved": False, "error": ...}`` (plus ``test_output`` when the
        FAIL_TO_PASS run itself failed).
    """
    import shlex

    base = instance.get("base_commit","")
    tp = instance.get("test_patch","")
    f2p = _test_ids(instance.get("FAIL_TO_PASS"))
    if not f2p:
        # Without target tests pytest would run the whole suite and the
        # outcome would say nothing about this instance.
        return {"resolved":False,"error":"no FAIL_TO_PASS tests"}

    sh(f"cd {repo_dir} && git checkout -f {base} && git clean -fd", timeout=30)

    (Path(repo_dir)/"_aco.patch").write_text(patch)
    rc, out, err = sh(f"cd {repo_dir} && git apply --check _aco.patch", timeout=10)
    if rc: return {"resolved":False,"error":f"patch check: {err[:150]}"}
    rc, out, err = sh(f"cd {repo_dir} && git apply _aco.patch", timeout=10)
    if rc: return {"resolved":False,"error":f"patch apply: {err[:150]}"}

    (Path(repo_dir)/"_t.patch").write_text(tp)
    sh(f"cd {repo_dir} && (git apply _t.patch) || git apply --reject _t.patch 2>/dev/null; true", timeout=10)

    # Quote every ID: parametrised test names can contain brackets or spaces.
    f2p_args = ' '.join(shlex.quote(t) for t in f2p[:10])
    cmd = f"cd {repo_dir} && {conda} run -n {env_name} python -m pytest -v --tb=short -x {f2p_args}"
    rc, out, err = sh(cmd, timeout=300)

    if rc == 0:
        p2p = _test_ids(instance.get("PASS_TO_PASS"))
        if p2p:
            p2p_args = ' '.join(shlex.quote(t) for t in p2p[:10])
            cmd2 = f"cd {repo_dir} && {conda} run -n {env_name} python -m pytest -v --tb=short -x {p2p_args}"
            rc2, out2, err2 = sh(cmd2, timeout=300)
            if rc2: return {"resolved":False,"error":f"P2P: {(out2+err2)[:200]}"}
        return {"resolved":True,"test_output":(out+err)[:500]}
    return {"resolved":False,"error":f"{len(re.findall(r'FAILED', out+err))} F2P failures","test_output":(out+err)[:500]}

def setup_env(conda, instance, repo_dir, env_name):
    """Create the conda env for an instance and install the repo into it.

    If the instance names an ``environment_setup_commit``, that commit is
    checked out first so a matching ``environment.yml`` can be picked up. The
    env is created from the first environment file found, or as a plain
    Python 3.10 env otherwise, with one plain-env retry on failure. The repo
    is then moved to ``base_commit`` and installed.

    Args:
        conda: Path to the conda executable.
        instance: Dataset row; reads ``environment_setup_commit`` (optional)
            and ``base_commit`` (required).
        repo_dir: Path of the cloned repository.
        env_name: Name of the conda environment to create.

    Returns:
        ``(True, "")`` once the env exists, or ``(False, message)`` if conda
        could not create it. Package installation stays best-effort and does
        not affect the result.
    """
    ec = instance.get("environment_setup_commit","")
    if ec:
        sh(f"cd {repo_dir} && git fetch origin {ec} && git checkout {ec}", timeout=60)

    candidates = ["environment.yml","dev/environment.yml",".github/environment.yml","ci/environment.yml"]
    eyml = next((c for c in candidates if (Path(repo_dir)/c).exists()), None)

    plain_create = f"{conda} create -n {env_name} python=3.10 pip -y"
    if eyml:
        rc, out, err = sh(f"{conda} env create -f {repo_dir}/{eyml} -n {env_name} --quiet", timeout=600)
    else:
        rc, out, err = sh(plain_create, timeout=300)

    if rc:
        rc, out, err = sh(plain_create, timeout=300)
        if rc: return False, f"conda: {err[:200]}"

    base = instance["base_commit"]
    sh(f"cd {repo_dir} && git fetch origin {base} && git checkout {base}", timeout=60)

    pip = f"cd {repo_dir} && {conda} run -n {env_name} pip install"
    # Prefer an editable install so patches applied to the checkout later are
    # what the tests import. Fall back to a regular install only when the
    # editable one fails; running both would overwrite the editable install.
    rc, out, err = sh(f"{pip} -e . 2>&1", timeout=300)
    if rc:
        print(f"  editable install failed, trying regular install: {(out+err)[-200:]}")
        rc, out, err = sh(f"{pip} . 2>&1", timeout=300)
        if rc:
            print(f"  package install failed: {(out+err)[-200:]}")
    # A plain "python=3.10 pip" env has no pytest, and the cascade and the
    # verifier both run "python -m pytest" inside this env.
    sh(f"{pip} pytest 2>&1", timeout=300)
    return True, ""

def _save_results(results, path="batch_results.jsonl"):
    """Rewrite ``path`` with one JSON object per line for every result so far."""
    with open(path, "w") as f:
        for r in results:
            f.write(json.dumps(r)+"\n")


def _process_instance(i, total, iid, instance, conda):
    """Clone, set up, run the cascade on and verify one instance; return its result dict."""
    print(f"\n{'='*60}\n[{i+1}/{total}] {iid}\n{'='*60}")

    with tempfile.TemporaryDirectory(prefix=f"aco_{i}_") as tmpdir:
        repo_dir = Path(tmpdir) / "repo"
        env_name = f"aco_{iid.replace('__','_').replace('-','_')[:30]}"

        print(f"Clone...")
        url = f"https://github.com/{instance['repo']}.git"
        rc, out, err = sh(f"git clone --depth 100 {url} {repo_dir}", timeout=600)
        if rc:
            return {"instance_id":iid,"resolved":False,"error":f"Clone: {err[:200]}"}

        try:
            print(f"Env...")
            ok, err = setup_env(conda, instance, repo_dir, env_name)
            if not ok:
                return {"instance_id":iid,"resolved":False,"error":f"Env: {err}"}

            print(f"Cascade...")
            agent = run_cascade(instance, repo_dir, conda, env_name)
            if not agent["patch"]:
                return {"instance_id":iid,"resolved":False,"tier":None,"error":"No valid patch"}

            print(f"Verify...")
            verify = verify_patch(instance, agent["patch"], repo_dir, conda, env_name)

            status = "βœ…" if verify["resolved"] else "❌"
            print(f"  {status} {agent['tier']} {agent['turns']}t")

            return {
                "instance_id":iid, "repo":instance["repo"],
                "resolved":verify["resolved"], "tier":agent["tier"],
                "turns":agent["turns"], "input_tokens":agent["input_tokens"],
                "output_tokens":agent["output_tokens"],
                "error":verify.get("error"),
                "timestamp":datetime.now().isoformat()
            }
        finally:
            # Always drop the per-instance env, including on errors; a slow
            # removal must not discard the result computed above.
            try:
                sh(f"{conda} env remove -n {env_name} -y --quiet 2>/dev/null; true", timeout=30)
            except subprocess.TimeoutExpired:
                print(f"  env cleanup timed out: {env_name}")


def main():
    """Run the cascade on the selected SWE-bench instances and report results.

    Configuration comes from environment variables: ``INSTANCE_TARGET``
    ("cascade-only", "frontier-only", anything else = first rows of the
    dataset; default "cascade-only") and ``MAX_INSTANCES`` (default 3).

    Results are rewritten to ``batch_results.jsonl`` after every processed
    instance, including ones that failed early, and a summary is printed at
    the end. A crash in one instance is recorded as that instance's error and
    the batch continues. Exits with status 1 if conda is unavailable.
    """
    import datasets
    target = os.environ.get("INSTANCE_TARGET", "cascade-only")
    max_instances = int(os.environ.get("MAX_INSTANCES", "3"))

    print(f"πŸš€ BATCH CASCADE VALIDATION β€” target={target} max={max_instances}")
    print(f"   {datetime.now().isoformat()}")

    conda = ensure_conda()
    if not conda: print("❌ No conda"); sys.exit(1)

    ds = datasets.load_dataset("princeton-nlp/SWE-bench_Verified", split="test")

    if target == "cascade-only":
        iids = CASCADE_ONLY[:max_instances]
    elif target == "frontier-only":
        iids = FRONTIER_ONLY[:max_instances]
    else:
        iids = [r["instance_id"] for r in ds][:max_instances]

    wanted = set(iids)
    instances = {row["instance_id"]: dict(row) for row in ds if row["instance_id"] in wanted}

    print(f"Instances: {iids}\n")

    results = []
    for i, iid in enumerate(iids):
        instance = instances.get(iid)
        if not instance:
            print(f"Skipping {iid}: not found in dataset")
            continue

        try:
            result = _process_instance(i, len(iids), iid, instance, conda)
        except Exception as e:
            # One broken instance must not take down the rest of the batch.
            traceback.print_exc()
            result = {"instance_id":iid,"resolved":False,"tier":None,"error":f"Crash: {e}"}
        results.append(result)

        # Incremental save, reached on every outcome.
        _save_results(results)

    # Make sure the file exists even when every instance was skipped.
    _save_results(results)

    resolved = [r for r in results if r["resolved"]]
    t1 = [r for r in resolved if r.get("tier")=="T1"]
    t2 = [r for r in resolved if r.get("tier")=="T2"]

    print(f"\n{'='*60}\nRESULTS: {len(resolved)}/{len(results)} resolved")
    print(f"  T1: {len(t1)}  T2: {len(t2)}")
    for r in results:
        s = "βœ…" if r["resolved"] else "❌"
        print(f"  {s} {r['instance_id']} [{r.get('tier','')}]")
    print(f"Saved: batch_results.jsonl")

if __name__ == "__main__":
    # Script entry point: exit with main()'s return value. Any uncaught
    # exception is printed with its traceback and turned into exit status 1.
    try:
        exit_code = main()
        sys.exit(exit_code)
    except Exception as e:
        print(f"πŸ’₯ {e}")
        traceback.print_exc()
        sys.exit(1)