narcolepticchicken commited on
Commit
f191fbd
·
verified ·
1 Parent(s): 3b0aa48

Upload smoke_test_cascade.py

Browse files
Files changed (1) hide show
  1. smoke_test_cascade.py +407 -0
smoke_test_cascade.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Single-instance smoke test for cascade validation.
3
+
4
+ This tests the ENTIRE pipeline end-to-end:
5
+ 1. Clone repo
6
+ 2. Set up conda environment
7
+ 3. Run cascade agent (T1 Llama-3.1-8B → T2 Llama-3.3-70B via HF Inference)
8
+ 4. Apply patch + test_patch
9
+ 5. Run FAIL_TO_PASS tests
10
+
11
+ This is the minimal test to prove the cascade works causally,
12
+ not just correlatively from trace simulation.
13
+
14
+ Submits to trackio for monitoring.
15
+
16
+ Usage (hf_jobs):
17
+ operation: run
18
+ script: "smoke_test_cascade.py"
19
+ dependencies: ["huggingface_hub", "datasets", "trackio"]
20
+ hardware: a10g-largex2 (need GPU for inference + CPU for conda + memory for clone)
21
+ timeout: 4h
22
+ """
23
+
24
+ import json
25
+ import os
26
+ import re
27
+ import subprocess
28
+ import sys
29
+ import tempfile
30
+ import time
31
+ import traceback
32
+ from datetime import datetime
33
+ from pathlib import Path
34
+
35
+ import datasets
36
+ import huggingface_hub
37
+ import trackio
38
+
39
+ # ============================================================
40
+ # Trackio setup
41
+ # ============================================================
42
+ trackio.init(
43
+ project=os.environ.get("TRACKIO_PROJECT", "aco-smoke-test"),
44
+ )
45
+
46
+ # ============================================================
47
+ # CONFIG
48
+ # ============================================================
49
+ # The easiest SWE-bench instance: django bug with clear fix
50
+ INSTANCE_ID = os.environ.get("INSTANCE_ID", "django__django-14315")
51
+
52
+ T1_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
53
+ T2_MODEL = "meta-llama/Llama-3.3-70B-Instruct"
54
+
55
+ trackio.log_params({
56
+ "instance_id": INSTANCE_ID,
57
+ "t1_model": T1_MODEL,
58
+ "t2_model": T2_MODEL,
59
+ "strategy": "cascade",
60
+ })
61
+
62
+
63
+ def run_shell(cmd, cwd=None, timeout=120):
64
+ result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout, shell=True)
65
+ return result.returncode, result.stdout, result.stderr
66
+
67
+
68
+ def call_model(client, messages, max_tokens=4096):
69
+ try:
70
+ completion = client.chat.completions.create(
71
+ model=client.model,
72
+ messages=messages,
73
+ max_tokens=max_tokens,
74
+ temperature=0.2,
75
+ )
76
+ text = completion.choices[0].message.content
77
+ itok = completion.usage.prompt_tokens if hasattr(completion, 'usage') and completion.usage else 0
78
+ otok = completion.usage.completion_tokens if hasattr(completion, 'usage') and completion.usage else len(text) // 4
79
+ return text, itok, otok
80
+ except Exception as e:
81
+ return f"[ERROR: {e}]", 0, 0
82
+
83
+
84
+ def extract_patch(text):
85
+ for tag in ['patch', 'diff']:
86
+ m = re.search(rf'<{tag}>(.*?)</{tag}>', text, re.DOTALL)
87
+ if m:
88
+ return m.group(1).strip()
89
+ for block in ['diff', 'patch']:
90
+ m = re.search(rf'```{block}\s*\n(.*?)```', text, re.DOTALL)
91
+ if m:
92
+ return m.group(1).strip()
93
+ diff_match = re.search(r'(diff --git a/.*?(?:\n(?:@@|\+\+\+|diff --git|```|</).*)*)', text, re.DOTALL)
94
+ if diff_match:
95
+ return diff_match.group(1).strip()
96
+ return None
97
+
98
+
99
+ def run_cascade_agent(instance, repo_dir):
100
+ """Run T1 then T2. Maximum 30 turns each."""
101
+
102
+ problem = instance.get("problem_statement", "")
103
+
104
+ system = f"""You are fixing a bug in {instance['repo']}. Repository is at {repo_dir}.
105
+
106
+ Output format:
107
+ - Bash commands: <bash>command</bash>
108
+ - Final patch: <patch>your diff here</patch>
109
+ - Done: <submit>Done</submit>
110
+
111
+ First explore the codebase to understand the issue, then make a minimal fix and verify it."""
112
+
113
+ messages = [
114
+ {"role": "system", "content": system},
115
+ {"role": "user", "content": f"PROBLEM:\n{problem}\n\nStart by exploring the repository."}
116
+ ]
117
+
118
+ tiers = [
119
+ ("T1", T1_MODEL, 30),
120
+ ("T2", T2_MODEL, 30),
121
+ ]
122
+
123
+ for tier_name, model_id, max_turns in tiers:
124
+ print(f"\n{'='*50}")
125
+ print(f"[{tier_name}] {model_id}")
126
+ print(f"{'='*50}")
127
+
128
+ trackio.log({"event": "tier_start", "tier": tier_name, "model": model_id})
129
+
130
+ client = huggingface_hub.InferenceClient(model_id)
131
+ total_itok = 0
132
+ total_otok = 0
133
+
134
+ for turn in range(max_turns):
135
+ text, itok, otok = call_model(client, messages, max_tokens=4096)
136
+ total_itok += itok
137
+ total_otok += otok
138
+ messages.append({"role": "assistant", "content": text})
139
+
140
+ print(f" Turn {turn+1}: {itok}+{otok} tokens, {len(text)} chars")
141
+
142
+ patch = extract_patch(text)
143
+ if patch:
144
+ print(f" ✅ PATCH FOUND ({len(patch)} chars)")
145
+ trackio.log({
146
+ "event": "patch_found",
147
+ "tier": tier_name,
148
+ "turn": turn + 1,
149
+ "input_tokens": total_itok,
150
+ "output_tokens": total_otok,
151
+ })
152
+ return {
153
+ "patch": patch,
154
+ "tier": tier_name,
155
+ "turns": turn + 1,
156
+ "input_tokens": total_itok,
157
+ "output_tokens": total_otok,
158
+ }
159
+
160
+ cmds = re.findall(r'<bash>(.*?)</bash>', text, re.DOTALL)
161
+ for cmd in cmds:
162
+ cmd = cmd.strip()
163
+ print(f" $ {cmd[:120]}")
164
+ rc, stdout, stderr = run_shell(cmd, cwd=str(repo_dir), timeout=30)
165
+ output = (stdout + stderr)[:1500]
166
+ if rc != 0:
167
+ output += f"\n[EXIT:{rc}]"
168
+ messages.append({"role": "user", "content": f"<output>\n{output}\n</output>"})
169
+
170
+ if "<submit>" in text:
171
+ print(f" [Submit without patch]")
172
+ break
173
+
174
+ trackio.log({"event": "tier_exhausted", "tier": tier_name, "turns": max_turns})
175
+
176
+ return {"patch": None, "tier": None, "turns": 0, "input_tokens": 0, "output_tokens": 0}
177
+
178
+
179
+ def verify_patch(instance, model_patch, repo_dir, env_name=None):
180
+ """Apply patches, run FAIL_TO_PASS tests."""
181
+ base_commit = instance.get("base_commit", "")
182
+ test_patch = instance.get("test_patch", "")
183
+ f2p = instance.get("FAIL_TO_PASS", [])
184
+
185
+ # Reset repo
186
+ run_shell(f"cd {repo_dir} && git checkout -f {base_commit}", timeout=30)
187
+ run_shell(f"cd {repo_dir} && git clean -fd", timeout=30)
188
+
189
+ # Apply model patch
190
+ patch_file = Path(repo_dir) / "_aco.patch"
191
+ patch_file.write_text(model_patch)
192
+ rc, out, err = run_shell(f"cd {repo_dir} && git apply --check _aco.patch", timeout=10)
193
+ if rc != 0:
194
+ return {"resolved": False, "error": f"model patch --check: {err[:300]}"}
195
+ rc, out, err = run_shell(f"cd {repo_dir} && git apply _aco.patch", timeout=10)
196
+ if rc != 0:
197
+ return {"resolved": False, "error": f"model patch apply: {err[:300]}"}
198
+
199
+ # Apply test_patch
200
+ test_file = Path(repo_dir) / "_aco_test.patch"
201
+ test_file.write_text(test_patch)
202
+ rc, out, err = run_shell(f"cd {repo_dir} && git apply --check _aco_test.patch", timeout=10)
203
+ if rc == 0:
204
+ run_shell(f"cd {repo_dir} && git apply _aco_test.patch", timeout=10)
205
+ else:
206
+ run_shell(f"cd {repo_dir} && git apply --reject _aco_test.patch 2>/dev/null; true", timeout=10)
207
+
208
+ # Run F2P tests
209
+ cmd_prefix = f"conda run -n {env_name} " if env_name else ""
210
+ f2p_cmd = f"{cmd_prefix}python -m pytest -v --tb=short -x {' '.join(f2p[:10])}"
211
+ print(f" F2P: {f2p_cmd[:150]}...")
212
+ rc, out, err = run_shell(f"cd {repo_dir} && {f2p_cmd}", timeout=300)
213
+
214
+ if rc == 0:
215
+ # Run P2P regression tests
216
+ p2p = instance.get("PASS_TO_PASS", [])
217
+ if p2p:
218
+ p2p_cmd = f"{cmd_prefix}python -m pytest -v --tb=short -x {' '.join(p2p[:10])}"
219
+ rc2, out2, err2 = run_shell(f"cd {repo_dir} && {p2p_cmd}", timeout=300)
220
+ if rc2 != 0:
221
+ return {"resolved": False, "error": f"P2P regression: {(out2+err2)[:300]}"}
222
+ return {"resolved": True, "test_output": (out + err)[:500]}
223
+
224
+ return {"resolved": False, "error": f"F2P: {len(re.findall(r'FAILED', out+err))} failures", "test_output": (out + err)[:500]}
225
+
226
+
227
+ def main():
228
+ print(f"🚀 CASCADE SMOKE TEST")
229
+ print(f" Instance: {INSTANCE_ID}")
230
+ print(f" T1: {T1_MODEL}")
231
+ print(f" T2: {T2_MODEL}")
232
+ print(f" Time: {datetime.now().isoformat()}")
233
+
234
+ # Load instance
235
+ print("\n[1/5] Loading SWE-bench_Verified...")
236
+ ds = datasets.load_dataset("princeton-nlp/SWE-bench_Verified", split="test")
237
+ instance = None
238
+ for row in ds:
239
+ if row["instance_id"] == INSTANCE_ID:
240
+ instance = dict(row)
241
+ break
242
+
243
+ if not instance:
244
+ print(f"❌ Instance {INSTANCE_ID} not found!")
245
+ trackio.alert("Instance not found", f"{INSTANCE_ID} not in SWE-bench_Verified", level="ERROR")
246
+ sys.exit(1)
247
+
248
+ print(f" Repo: {instance['repo']}")
249
+ print(f" Base: {instance['base_commit'][:12]}")
250
+ print(f" F2P tests: {len(instance.get('FAIL_TO_PASS', []))}")
251
+ print(f" P2P tests: {len(instance.get('PASS_TO_PASS', []))}")
252
+
253
+ trackio.log({
254
+ "event": "instance_loaded",
255
+ "repo": instance["repo"],
256
+ "f2p_count": len(instance.get("FAIL_TO_PASS", [])),
257
+ "p2p_count": len(instance.get("PASS_TO_PASS", [])),
258
+ })
259
+
260
+ with tempfile.TemporaryDirectory(prefix="aco_smoke_") as tmpdir:
261
+ repo_dir = Path(tmpdir) / "repo"
262
+ env_name = f"aco_{INSTANCE_ID.replace('__','_').replace('-','_')[:30]}"
263
+
264
+ # Clone
265
+ print(f"\n[2/5] Cloning repo...")
266
+ t0 = time.time()
267
+ repo = instance["repo"]
268
+ url = f"https://github.com/{repo}.git"
269
+ rc, out, err = run_shell(f"git clone --depth 50 {url} {repo_dir}", timeout=300)
270
+ if rc != 0:
271
+ rc, out, err = run_shell(f"git clone {url} {repo_dir}", timeout=600)
272
+ clone_time = time.time() - t0
273
+
274
+ if rc != 0:
275
+ print(f"❌ Clone failed: {err[:300]}")
276
+ trackio.alert("Clone failed", err[:200], level="ERROR")
277
+ sys.exit(1)
278
+ print(f" Done ({clone_time:.0f}s)")
279
+
280
+ # Set up conda env
281
+ print(f"\n[3/5] Setting up conda environment...")
282
+ t0 = time.time()
283
+
284
+ env_commit = instance.get("environment_setup_commit", "")
285
+ if env_commit:
286
+ run_shell(f"cd {repo_dir} && git fetch origin {env_commit}", timeout=60)
287
+ run_shell(f"cd {repo_dir} && git checkout {env_commit}", timeout=30)
288
+
289
+ env_yml = None
290
+ for c in ["environment.yml", "dev/environment.yml", ".github/environment.yml",
291
+ "ci/environment.yml"]:
292
+ if (repo_dir / c).exists():
293
+ env_yml = c
294
+ break
295
+
296
+ if env_yml:
297
+ print(f" Found: {env_yml}")
298
+ rc, out, err = run_shell(f"cd {repo_dir} && conda env create -f {env_yml} -n {env_name} --quiet", timeout=600)
299
+ else:
300
+ print(f" No environment.yml, creating basic env")
301
+ rc, out, err = run_shell(f"conda create -n {env_name} python=3.10 pip -y --quiet", timeout=300)
302
+
303
+ if rc != 0:
304
+ print(f"❌ Conda env creation failed: {err[:300]}")
305
+ trackio.alert("Conda env failed", err[:300], level="ERROR")
306
+ sys.exit(1)
307
+
308
+ # Install repo at base_commit
309
+ base_commit = instance["base_commit"]
310
+ run_shell(f"cd {repo_dir} && git fetch origin {base_commit}", timeout=60)
311
+ run_shell(f"cd {repo_dir} && git checkout {base_commit}", timeout=30)
312
+ rc, out, err = run_shell(f"cd {repo_dir} && conda run -n {env_name} pip install -e . --quiet", timeout=300)
313
+ if rc != 0:
314
+ print(f"⚠️ pip install had issues: {err[:200]}")
315
+
316
+ env_time = time.time() - t0
317
+ print(f" Done ({env_time:.0f}s)")
318
+
319
+ # Run cascade agent
320
+ print(f"\n[4/5] Running cascade agent (T1→T2)...")
321
+ t0 = time.time()
322
+ agent_result = run_cascade_agent(instance, repo_dir)
323
+ agent_time = time.time() - t0
324
+
325
+ if not agent_result["patch"]:
326
+ print(f"❌ No patch produced")
327
+ trackio.alert("No patch", "Cascade produced no patch", level="ERROR")
328
+ sys.exit(1)
329
+
330
+ print(f"\n ✅ Patch: {len(agent_result['patch'])} chars")
331
+ print(f" Tier: {agent_result['tier']}")
332
+ print(f" Turns: {agent_result['turns']}")
333
+ print(f" Tokens: {agent_result['input_tokens']} in + {agent_result['output_tokens']} out")
334
+ print(f" Time: {agent_time:.0f}s")
335
+
336
+ trackio.log({
337
+ "event": "agent_done",
338
+ "tier": agent_result["tier"],
339
+ "turns": agent_result["turns"],
340
+ "input_tokens": agent_result["input_tokens"],
341
+ "output_tokens": agent_result["output_tokens"],
342
+ "agent_time": agent_time,
343
+ })
344
+
345
+ # Verify
346
+ print(f"\n[5/5] Verifying patch...")
347
+ t0 = time.time()
348
+ verify_result = verify_patch(instance, agent_result["patch"], repo_dir, env_name)
349
+ verify_time = time.time() - t0
350
+
351
+ trackio.log({
352
+ "event": "verification_done",
353
+ "resolved": verify_result["resolved"],
354
+ "verify_time": verify_time,
355
+ })
356
+
357
+ # Final result
358
+ print(f"\n{'='*60}")
359
+ print(f"{'✅ RESOLVED' if verify_result['resolved'] else '❌ FAILED'}")
360
+ print(f"{'='*60}")
361
+ print(f" Instance: {INSTANCE_ID}")
362
+ print(f" Tier: {agent_result['tier']}")
363
+ print(f" Turns: {agent_result['turns']}")
364
+ print(f" Time: clone={clone_time:.0f}s env={env_time:.0f}s agent={agent_time:.0f}s verify={verify_time:.0f}s")
365
+
366
+ if not verify_result["resolved"]:
367
+ print(f" Error: {verify_result.get('error', 'unknown')[:300]}")
368
+ trackio.alert("Not resolved", verify_result.get('error', 'unknown')[:300], level="WARN")
369
+ else:
370
+ trackio.alert("✅ Resolved!", f"Instance {INSTANCE_ID} resolved via {agent_result['tier']}", level="INFO")
371
+
372
+ # Save final result
373
+ final = {
374
+ "instance_id": INSTANCE_ID,
375
+ "repo": instance["repo"],
376
+ "timestamp": datetime.now().isoformat(),
377
+ "resolved": verify_result["resolved"],
378
+ "tier": agent_result["tier"],
379
+ "turns": agent_result["turns"],
380
+ "input_tokens": agent_result["input_tokens"],
381
+ "output_tokens": agent_result["output_tokens"],
382
+ "clone_time": clone_time,
383
+ "env_time": env_time,
384
+ "agent_time": agent_time,
385
+ "verify_time": verify_time,
386
+ "patch_preview": agent_result["patch"][:500],
387
+ "error": verify_result.get("error"),
388
+ }
389
+
390
+ with open("smoke_result.json", "w") as f:
391
+ json.dump(final, f, indent=2)
392
+ print(f"\nSaved: smoke_result.json")
393
+
394
+ # Cleanup
395
+ run_shell(f"conda env remove -n {env_name} -y --quiet", timeout=30)
396
+
397
+ return 0 if verify_result["resolved"] else 1
398
+
399
+
400
+ if __name__ == "__main__":
401
+ try:
402
+ sys.exit(main())
403
+ except Exception as e:
404
+ print(f"💥 CRASH: {e}")
405
+ traceback.print_exc()
406
+ trackio.alert("Crash", str(e)[:300], level="ERROR")
407
+ sys.exit(1)