narcolepticchicken committed on
Commit
cffce7e
·
verified ·
1 Parent(s): f191fbd

Upload smoke_test.py

Browse files
Files changed (1) hide show
  1. smoke_test.py +371 -0
smoke_test.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cascade Smoke Test — Self-Contained HF Job Script
3
+
4
+ Tests the cascade agent end-to-end on one SWE-bench instance.
5
+ Installs miniconda if needed, sets up environment, runs agent,
6
+ verifies patch.
7
+
8
+ Usage via hf_jobs:
9
+ operation: run
10
+ script: "https://huggingface.co/narcolepticchicken/agent-cost-optimizer/resolve/main/smoke_test.py"
11
+ dependencies: ["huggingface_hub", "datasets", "trackio"]
12
+ hardware: a10g-largex2
13
+ timeout: 4h
14
+ env:
15
+ INSTANCE_ID: "django__django-14315"
16
+ """
17
+
18
+ import json
19
+ import os
20
+ import re
21
+ import subprocess
22
+ import sys
23
+ import tempfile
24
+ import time
25
+ import traceback
26
+ from datetime import datetime
27
+ from pathlib import Path
28
+
29
+ # ============================================================
30
+ # Bootstrap: ensure conda is available
31
+ # ============================================================
32
+
33
def ensure_conda():
    """Locate a conda executable, installing Miniconda if none is found.

    Lookup order: ``~/miniconda3/bin/conda``, ``/opt/conda/bin/conda``, then
    any ``conda`` on ``PATH``. When none exists, the Miniconda installer is
    downloaded with ``wget`` and run in batch mode into ``~/miniconda3``, and
    that ``bin`` directory is prepended to ``PATH`` for this process.

    Returns:
        Path to the conda executable, or ``None`` when the installation
        fails or times out.
    """
    import shutil

    home_conda = os.path.expanduser("~/miniconda3/bin/conda")
    for path in (home_conda, "/opt/conda/bin/conda"):
        if os.path.exists(path):
            return path

    # A conda that is already on PATH makes a fresh install unnecessary.
    on_path = shutil.which("conda")
    if on_path:
        return on_path

    print("📦 Installing Miniconda...")
    try:
        rc, out, err = subprocess_run(
            "wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && "
            "bash /tmp/miniconda.sh -b -p $HOME/miniconda3",
            timeout=120
        )
    except subprocess.TimeoutExpired:
        # Keep the "None means failure" contract instead of leaking the exception.
        print("Conda install failed: timed out after 120s")
        return None
    if rc != 0:
        print(f"Conda install failed: {err}")
        return None

    # Add to PATH for this session
    os.environ["PATH"] = os.path.expanduser("~/miniconda3/bin:") + os.environ.get("PATH", "")
    return home_conda
53
+
54
+
55
def subprocess_run(cmd, cwd=None, timeout=120):
    """Run *cmd* through the shell and capture its output as text.

    Args:
        cmd: Shell command line. It is executed with ``shell=True``, so the
            caller is responsible for quoting anything interpolated into it.
        cwd: Working directory for the command, or ``None`` for the current one.
        timeout: Seconds to wait before the command is killed.

    Returns:
        ``(returncode, stdout, stderr)``. When the command times out, the
        return code is 124, stdout/stderr hold whatever was captured so far,
        and a ``[TIMEOUT after Ns]`` marker is appended to stderr.
    """
    try:
        result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout, shell=True)
    except subprocess.TimeoutExpired as exc:
        # Report the timeout as an ordinary failure so callers' rc checks
        # handle it instead of the whole script aborting.
        captured = []
        for stream in (exc.stdout, exc.stderr):
            # Partial output may be bytes even when text=True was requested.
            if isinstance(stream, bytes):
                stream = stream.decode("utf-8", errors="replace")
            captured.append(stream or "")
        out, err = captured
        return 124, out, err + f"[TIMEOUT after {timeout}s]"
    return result.returncode, result.stdout, result.stderr
58
+
59
+
60
def conda_run(conda, subcmd, timeout=300, cwd=None):
    """Run a conda sub-command through ``subprocess_run``.

    Args:
        conda: Path to the conda executable.
        subcmd: Everything after ``conda`` on the command line.
        timeout: Seconds to wait before the command is killed.
        cwd: Optional working directory. Callers in this file pass it for
            ``pip install`` inside the repository checkout.

    Returns:
        The ``(returncode, stdout, stderr)`` tuple from ``subprocess_run``.
    """
    return subprocess_run(f"{conda} {subcmd}", cwd=cwd, timeout=timeout)
63
+
64
+
65
+ # ============================================================
66
+ # Agent
67
+ # ============================================================
68
+
69
def call_model(client, messages, max_tokens=4096):
    """Send *messages* to the chat model behind *client* and return its reply.

    Args:
        client: Object exposing ``chat.completions.create(...)`` and a
            ``model`` attribute naming the model to call.
        messages: Chat history as a list of ``{"role", "content"}`` dicts.
        max_tokens: Completion token limit passed to the API.

    Returns:
        ``(text, input_tokens, output_tokens)``. When the response carries no
        usage data, input tokens are reported as 0 and output tokens are
        estimated as ``len(text) // 4``. Any exception is reported as
        ``("[ERROR: ...]", 0, 0)`` rather than raised.
    """
    try:
        completion = client.chat.completions.create(
            model=client.model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=0.2,
        )
        # The API may return no content; callers expect a string.
        text = completion.choices[0].message.content or ""
        usage = getattr(completion, 'usage', None)
        if usage:
            return text, usage.prompt_tokens, usage.completion_tokens
        return text, 0, len(text) // 4
    except Exception as e:
        # Deliberate catch-all: the agent loop treats failures as model output.
        return f"[ERROR: {e}]", 0, 0
83
+
84
+
85
def extract_patch(text):
    """Pull a unified diff out of a model response.

    Tried in order:
      1. ``<patch>...</patch>`` or ``<diff>...</diff>`` tags,
      2. a fenced code block labelled ``diff`` or ``patch``,
      3. a bare diff starting at a line that begins with ``diff --git a/``,
         extended over the following lines that look like diff content.

    Returns:
        The stripped patch text, or ``None`` when nothing is found.
    """
    if not text:
        return None

    for tag in ('patch', 'diff'):
        m = re.search(rf'<{tag}>(.*?)</{tag}>', text, re.DOTALL)
        if m:
            return m.group(1).strip()
    for block in ('diff', 'patch'):
        m = re.search(rf'```{block}\s*\n(.*?)```', text, re.DOTALL)
        if m:
            return m.group(1).strip()

    start = re.search(r'^diff --git a/', text, re.MULTILINE)
    if not start:
        return None

    # Header and hunk line prefixes; '+' and '-' also cover '+++' and '---'.
    diff_prefixes = (
        'diff --git ', 'index ', '@@', '+', '-', ' ', '\\',
        'new file mode', 'deleted file mode', 'old mode', 'new mode',
        'similarity index', 'dissimilarity index',
        'rename from', 'rename to', 'copy from', 'copy to',
        'Binary files', 'GIT binary patch',
    )
    kept = []
    for line in text[start.start():].split('\n'):
        # Blank lines are kept (models often strip the space from empty
        # context lines); the first prose, fence or closing-tag line ends it.
        if line and not line.startswith(diff_prefixes):
            break
        kept.append(line)
    return '\n'.join(kept).strip() or None
98
+
99
+
100
def _route_pytest_through_env(cmd, conda, env_name):
    """Rewrite pytest invocations in *cmd* so they run inside the conda env."""
    runner = f"{conda} run -n {env_name} python -m pytest"
    # Match a standalone `pytest` word, optionally already prefixed with
    # `python -m`, but not names such as `pytest.ini` or `pytest-cov`.
    pattern = r'(?<![\w./-])(?:python[\d.]*\s+-m\s+)?pytest(?![\w./-])'
    return re.sub(pattern, lambda _m: runner, cmd)


def run_cascade(instance, repo_dir, conda, env_name):
    """Run the T1 → T2 cascade agent and return the first patch it produces.

    Each tier gets up to 30 turns. On every turn the model reply is appended
    to the shared conversation; if it contains a patch the function returns
    immediately. Otherwise every ``<bash>...</bash>`` command is executed in
    *repo_dir* (pytest calls are routed through the conda env) and its output,
    truncated to 1500 characters, is fed back as a user message. A reply
    containing ``<submit>`` ends the current tier. The conversation carries
    over from T1 to T2.

    Args:
        instance: SWE-bench row; ``repo`` and ``problem_statement`` are read.
        repo_dir: Path of the repository checkout.
        conda: Path to the conda executable.
        env_name: Name of the conda environment used for pytest.

    Returns:
        Dict with keys ``patch``, ``tier``, ``turns``, ``input_tokens`` and
        ``output_tokens``. Token counts cover only the tier that produced the
        patch. When no tier yields a patch, ``patch`` and ``tier`` are
        ``None`` and the counters are 0.
    """
    from huggingface_hub import InferenceClient

    T1 = "meta-llama/Llama-3.1-8B-Instruct"
    T2 = "meta-llama/Llama-3.3-70B-Instruct"

    problem = instance.get("problem_statement", "")

    system = f"""You are fixing a bug in {instance['repo']}. Repository is at {repo_dir}.

Output format:
- Bash commands: <bash>command</bash>
- Final patch: <patch>your diff here</patch>
- Done: <submit>Done</submit>

First explore the codebase to understand the issue, then make a minimal fix.
Use pytest to verify. Be thorough but efficient."""

    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": f"PROBLEM:\n{problem}\n\nStart by exploring the repository structure."}
    ]

    tiers = [("T1", T1, 30), ("T2", T2, 30)]

    for tier_name, model_id, max_turns in tiers:
        print(f"\n[{tier_name}] {model_id}")
        client = InferenceClient(model_id)
        total_itok = 0
        total_otok = 0

        for turn in range(max_turns):
            text, itok, otok = call_model(client, messages, max_tokens=4096)
            text = text or ""  # guard against a reply with no content
            total_itok += itok
            total_otok += otok
            messages.append({"role": "assistant", "content": text})
            print(f" Turn {turn+1}: {itok}+{otok} tok, {len(text)} ch")

            patch = extract_patch(text)
            if patch:
                print(f" ✅ PATCH ({len(patch)} ch)")
                return {"patch": patch, "tier": tier_name, "turns": turn + 1, "input_tokens": total_itok, "output_tokens": total_otok}

            cmds = re.findall(r'<bash>(.*?)</bash>', text, re.DOTALL)
            for cmd in cmds:
                cmd = _route_pytest_through_env(cmd.strip(), conda, env_name)
                print(f" $ {cmd[:120]}")
                try:
                    rc, stdout, stderr = subprocess_run(cmd, cwd=str(repo_dir), timeout=60)
                except subprocess.TimeoutExpired:
                    # A slow model-issued command must not abort the agent;
                    # report the timeout back to the model instead.
                    rc, stdout, stderr = 124, "", "[TIMEOUT after 60s]"
                output = (stdout + stderr)[:1500]
                if rc != 0:
                    output += f" [EXIT:{rc}]"
                messages.append({"role": "user", "content": f"<output>\n{output}\n</output>"})

            if "<submit>" in text:
                break

    return {"patch": None, "tier": None, "turns": 0, "input_tokens": 0, "output_tokens": 0}
159
+
160
+
161
def _as_test_list(value):
    """Normalize a FAIL_TO_PASS / PASS_TO_PASS field to a list of test ids."""
    if isinstance(value, str):
        # SWE-bench stores these fields as JSON-encoded strings.
        try:
            value = json.loads(value)
        except json.JSONDecodeError:
            value = value.split()
    if isinstance(value, str):
        return [value]
    return list(value or [])


def verify_patch(instance, model_patch, repo_dir, conda, env_name):
    """Apply the model patch and the test patch, then run the target tests.

    The repository is reset to ``base_commit``, *model_patch* is applied with
    ``git apply``, the instance's ``test_patch`` is applied on a best-effort
    basis, and up to 10 ``FAIL_TO_PASS`` tests are run with pytest inside the
    conda env. If they pass, up to 10 ``PASS_TO_PASS`` tests are run as well.

    Args:
        instance: SWE-bench row (``base_commit``, ``test_patch``,
            ``FAIL_TO_PASS`` and ``PASS_TO_PASS`` are read).
        model_patch: Unified diff produced by the agent.
        repo_dir: Path of the repository checkout.
        conda: Path to the conda executable.
        env_name: Conda environment in which pytest is run.

    Returns:
        Dict with ``resolved`` (bool) plus ``error`` and/or ``test_output``.
    """
    import shlex

    repo = str(repo_dir)
    base_commit = instance.get("base_commit", "")
    test_patch = instance.get("test_patch", "")
    f2p = _as_test_list(instance.get("FAIL_TO_PASS", []))

    # Reset repo
    subprocess_run(f"git checkout -f {base_commit}", cwd=repo, timeout=30)
    subprocess_run("git clean -fd", cwd=repo, timeout=30)

    # Apply model patch; the extracted text is stripped, so restore the
    # final newline before handing it to git.
    patch_text = model_patch if model_patch.endswith("\n") else model_patch + "\n"
    (Path(repo_dir) / "_aco.patch").write_text(patch_text)
    rc, out, err = subprocess_run("git apply --check _aco.patch", cwd=repo, timeout=10)
    if rc != 0:
        return {"resolved": False, "error": f"patch --check: {err[:200]}"}
    rc, out, err = subprocess_run("git apply _aco.patch", cwd=repo, timeout=10)
    if rc != 0:
        return {"resolved": False, "error": f"patch apply: {err[:200]}"}

    # Apply test_patch (best effort, but no longer silent on failure)
    if test_patch:
        (Path(repo_dir) / "_aco_test.patch").write_text(test_patch)
        rc, out, err = subprocess_run(
            "git apply --check _aco_test.patch && git apply _aco_test.patch || git apply --reject _aco_test.patch",
            cwd=repo, timeout=10)
        if rc != 0:
            print(f" ⚠️ test_patch did not apply cleanly: {err[:200]}")

    # Without target tests pytest would run the whole suite, which proves nothing.
    if not f2p:
        return {"resolved": False, "error": "no FAIL_TO_PASS tests to run"}

    # Run F2P tests. Ids are shell-quoted because they can contain spaces
    # and parentheses.
    # NOTE(review): ids are passed to pytest verbatim; repos whose ids are
    # not pytest node ids may need a different runner — confirm.
    f2p_str = ' '.join(shlex.quote(t) for t in f2p[:10])
    cmd = f"{conda} run -n {env_name} python -m pytest -v --tb=short -x {f2p_str}"
    print(f" F2P: pytest {' '.join(f2p[:3])}...")
    rc, out, err = subprocess_run(cmd, cwd=repo, timeout=300)

    if rc == 0:
        p2p = _as_test_list(instance.get("PASS_TO_PASS", []))
        if p2p:
            p2p_str = ' '.join(shlex.quote(t) for t in p2p[:10])
            cmd2 = f"{conda} run -n {env_name} python -m pytest -v --tb=short -x {p2p_str}"
            rc2, out2, err2 = subprocess_run(cmd2, cwd=repo, timeout=300)
            if rc2 != 0:
                return {"resolved": False, "error": f"P2P: {(out2+err2)[:200]}"}
        return {"resolved": True, "test_output": (out + err)[:500]}

    failures = len(re.findall(r'FAILED', out + err))
    return {"resolved": False, "error": f"{failures} F2P failures", "test_output": (out + err)[:500]}
202
+
203
+
204
def setup_environment(conda, instance, repo_dir, env_name):
    """Create the conda environment for the repo and install the repo into it.

    If the instance names an ``environment_setup_commit`` it is checked out
    first. The env is created from the first ``environment.yml`` found among
    a few known locations, otherwise as a plain ``python=3.10`` env; if that
    fails, a plain env is attempted once more. The repo is then checked out
    at ``base_commit`` and pip-installed (editable first, then regular).

    Args:
        conda: Path to the conda executable.
        instance: SWE-bench row; ``base_commit`` is required.
        repo_dir: Path of the repository checkout.
        env_name: Name of the conda environment to create.

    Returns:
        ``(ok, error_message)``. ``ok`` is False only when no conda env could
        be created; a failed pip install is reported as a warning.
    """
    repo = str(repo_dir)
    env_commit = instance.get("environment_setup_commit", "")

    if env_commit:
        subprocess_run(f"git fetch origin {env_commit}", cwd=repo, timeout=60)
        subprocess_run(f"git checkout {env_commit}", cwd=repo, timeout=30)

    # Find environment.yml
    candidates = ("environment.yml", "dev/environment.yml", ".github/environment.yml",
                  "ci/environment.yml")
    env_yml = next((c for c in candidates if (Path(repo_dir) / c).exists()), None)

    basic_create = f"create -n {env_name} python=3.10 pip -y --quiet"
    if env_yml:
        print(f" Using {env_yml}")
        rc, out, err = conda_run(conda, f"env create -f {repo_dir}/{env_yml} -n {env_name} --quiet", timeout=600)
    else:
        print(f" Creating basic python=3.10 env")
        rc, out, err = conda_run(conda, basic_create, timeout=300)

    if rc != 0:
        print(f" ⚠️ env creation failed: {err[:200]}")
        # Try with just pip
        rc2, out2, err2 = conda_run(conda, basic_create, timeout=300)
        if rc2 != 0:
            # Report the error of the attempt that decided the outcome.
            return False, f"conda env: {(err2 or err)[:200]}"

    # Install repo
    base_commit = instance["base_commit"]
    subprocess_run(f"git fetch origin {base_commit}", cwd=repo, timeout=60)
    subprocess_run(f"git checkout {base_commit}", cwd=repo, timeout=30)
    # Call subprocess_run directly: pip must run from inside the checkout,
    # which needs a working directory.
    rc, out, err = subprocess_run(f"{conda} run -n {env_name} pip install -e . --quiet",
                                  cwd=repo, timeout=300)
    if rc != 0:
        print(f" ⚠️ pip install: {err[:200]}")
        # Try without -e
        rc2, out2, err2 = subprocess_run(f"{conda} run -n {env_name} pip install . --quiet",
                                         cwd=repo, timeout=300)
        if rc2 != 0:
            # Best effort: some repos can be tested from the checkout itself.
            print(f" ⚠️ pip install (non-editable) also failed: {err2[:200]}")

    return True, ""
247
+
248
+
249
def main():
    """Run the smoke test end to end for one SWE-bench instance.

    Steps: ensure conda, load the instance named by the ``INSTANCE_ID``
    environment variable from ``princeton-nlp/SWE-bench_Verified``, clone the
    repo into a temporary directory, build the conda env, run the cascade
    agent, verify the patch, and write ``smoke_result.json`` to the current
    directory. The conda env is removed on every exit path once its setup
    has started.

    Returns:
        0 when the instance is resolved, 1 otherwise. Setup failures and a
        missing patch terminate the process via ``sys.exit(1)``.
    """
    INSTANCE_ID = os.environ.get("INSTANCE_ID", "django__django-14315")
    # Shown in the banner only; run_cascade defines the models actually called.
    T1_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
    T2_MODEL = "meta-llama/Llama-3.3-70B-Instruct"

    print("="*60)
    print(f"🚀 CASCADE SMOKE TEST — {datetime.now().isoformat()}")
    print("="*60)
    print(f" Instance: {INSTANCE_ID}")
    print(f" T1: {T1_MODEL}")
    print(f" T2: {T2_MODEL}")

    # Ensure conda
    conda = ensure_conda()
    if not conda:
        print("❌ Cannot install conda")
        sys.exit(1)
    print(f"\n✅ Conda: {conda}")

    # Load instance
    print(f"\n[1/5] Loading instance...")
    import datasets
    ds = datasets.load_dataset("princeton-nlp/SWE-bench_Verified", split="test")
    instance = next((dict(row) for row in ds if row["instance_id"] == INSTANCE_ID), None)

    if not instance:
        print(f"❌ {INSTANCE_ID} not found!")
        sys.exit(1)

    print(f" Repo: {instance['repo']}")
    print(f" Base: {instance['base_commit'][:12]}")
    print(f" F2P: {len(instance.get('FAIL_TO_PASS', []))} tests")
    print(f" P2P: {len(instance.get('PASS_TO_PASS', []))} tests")

    with tempfile.TemporaryDirectory(prefix="aco_smoke_") as tmpdir:
        repo_dir = Path(tmpdir) / "repo"
        env_name = f"aco_{INSTANCE_ID.replace('__','_').replace('-','_')[:30]}"

        # Clone (shallow first, full clone as fallback)
        print(f"\n[2/5] Cloning {instance['repo']}...")
        t0 = time.time()
        url = f"https://github.com/{instance['repo']}.git"
        rc, out, err = subprocess_run(f"git clone --depth 50 {url} {repo_dir}", timeout=300)
        if rc != 0:
            rc, out, err = subprocess_run(f"git clone {url} {repo_dir}", timeout=600)
        clone_t = time.time() - t0
        if rc != 0:
            print(f"❌ Clone: {err[:200]}")
            sys.exit(1)
        print(f" Done ({clone_t:.0f}s)")

        try:
            # Environment
            print(f"\n[3/5] Setting up conda env '{env_name}'...")
            t0 = time.time()
            ok, err = setup_environment(conda, instance, repo_dir, env_name)
            env_t = time.time() - t0
            if not ok:
                print(f"❌ Env setup: {err}")
                sys.exit(1)
            print(f" Done ({env_t:.0f}s)")

            # Agent
            print(f"\n[4/5] Running cascade agent...")
            t0 = time.time()
            agent = run_cascade(instance, repo_dir, conda, env_name)
            agent_t = time.time() - t0

            if not agent["patch"]:
                print(f"❌ No patch!")
                sys.exit(1)

            print(f" ✅ {agent['tier']}, {agent['turns']} turns, {agent['input_tokens']}+{agent['output_tokens']} tokens")

            # Verify
            print(f"\n[5/5] Verifying...")
            t0 = time.time()
            verify = verify_patch(instance, agent["patch"], repo_dir, conda, env_name)
            verify_t = time.time() - t0

            # Result
            status = "✅ RESOLVED" if verify["resolved"] else "❌ FAILED"
            print(f"\n{'='*60}")
            print(f"{status}")
            print(f"{'='*60}")
            print(f" Times: clone={clone_t:.0f}s env={env_t:.0f}s agent={agent_t:.0f}s verify={verify_t:.0f}s")
            if not verify["resolved"]:
                print(f" Error: {verify.get('error', 'unknown')[:300]}")

            # Save
            result = {
                "instance_id": INSTANCE_ID,
                "repo": instance["repo"],
                "resolved": verify["resolved"],
                "tier": agent["tier"],
                "turns": agent["turns"],
                "input_tokens": agent["input_tokens"],
                "output_tokens": agent["output_tokens"],
                "times": {"clone": clone_t, "env": env_t, "agent": agent_t, "verify": verify_t},
                "error": verify.get("error"),
                "patch_preview": agent["patch"][:500],
            }

            with open("smoke_result.json", "w") as f:
                json.dump(result, f, indent=2)
            print(f"\nSaved: smoke_result.json")

            return 0 if verify["resolved"] else 1
        finally:
            # Cleanup runs on every exit path, including sys.exit above. A slow
            # removal must not turn a finished run into a failure.
            try:
                conda_run(conda, f"env remove -n {env_name} -y --quiet", timeout=30)
            except subprocess.TimeoutExpired:
                print(f" ⚠️ env cleanup timed out")
363
+
364
+
365
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the exit status;
    # any uncaught exception is printed with its traceback and mapped to 1.
    try:
        exit_code = main()
    except Exception as exc:
        print(f"💥 {exc}")
        traceback.print_exc()
        exit_code = 1
    sys.exit(exit_code)