File size: 2,721 Bytes
70a9d5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# debug_baseline.py
# drop in ml_debug_env/server/, run: python debug_baseline.py

import os
import sys
import json
import subprocess
import tempfile

# Prepend the directory one level above this file's folder to the import
# path, so modules living there can be imported when the script is run
# directly (NOTE(review): presumably needed for bug_generator — confirm
# where that module actually lives).
sys.path.insert(
    0,
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))
    ),
)

from openai import OpenAI
from bug_generator import get_scenario, TASK_SHAPE_MISMATCH

GROQ_BASE_URL = "https://api.groq.com/openai/v1"
MODEL = "llama-3.3-70b-versatile"

# Fail fast, before any network call, when the credential is missing.
# Strip so a whitespace-only value (or a key pasted with a trailing newline)
# is not sent as-is. sys.exit(str) writes the message to stderr and exits
# with status 1, matching the previous exit code.
api_key = os.environ.get("GROQ_API_KEY", "").strip()
if not api_key:
    sys.exit("Set GROQ_API_KEY first")

# The stock OpenAI client is pointed at Groq's OpenAI-compatible base URL.
client = OpenAI(api_key=api_key, base_url=GROQ_BASE_URL)

# NOTE(review): seed=42 presumably makes get_scenario return the same buggy
# script on every run — confirm against bug_generator.
scenario = get_scenario(TASK_SHAPE_MISMATCH, seed=42)

# System instructions for the model. The reply is required to be a single
# JSON object with the keys "bug_type", "diagnosis" and "fixed_code"; the
# parsing step later in this script reads exactly those keys.
# Kept as one entry per line so the exact prompt text is easy to diff.
SYSTEM_PROMPT = "\n".join((
    "You are an expert ML engineer specializing in debugging PyTorch training code.",
    "",
    "You will be given a broken Python training script and a description of how it fails.",
    "Your job is to:",
    "1. Identify the exact bug type",
    "2. Explain the root cause clearly",
    "3. Return a complete corrected script that fixes the issue",
    "",
    "You must respond with valid JSON in exactly this format:",
    "{",
    '  "bug_type": "<one of: shape_mismatch, training_collapse, data_leakage, other>",',
    '  "diagnosis": "<clear explanation of the root cause>",',
    '  "fixed_code": "<complete corrected Python script, runnable as-is>"',
    "}",
    "",
    "Rules:",
    "- fixed_code must be the COMPLETE script, not a diff or partial fix",
    "- fixed_code must include all imports",
    "- Do not add markdown code fences inside the JSON string",
    "- Do not add any text outside the JSON object",
))

# User turn: the task statement, the buggy script inside a ```python fence,
# then the failure output the scenario recorded.
user_prompt = "\n".join([
    f"Task: {scenario.task_description}",
    "",
    "Broken script:",
    "```python",
    f"{scenario.buggy_code}",
    "```",
    "",
    "Failure observed:",
    f"{scenario.error_output}",
    "",
    "Respond with JSON only.",
])

# One chat-completion request. temperature=0.0 minimises sampling variation,
# and response_format asks the endpoint for a JSON object so the reply can be
# handed straight to json.loads. max_tokens caps the reply length.
response = client.chat.completions.create(
    model=MODEL,
    temperature=0.0,
    max_tokens=2048,
    response_format=dict(type="json_object"),
    messages=[
        dict(role="system", content=SYSTEM_PROMPT),
        dict(role="user", content=user_prompt),
    ],
)

# --- Inspect and parse the model reply -------------------------------------
choice = response.choices[0]
# message.content can be None (e.g. no text returned); normalise to "" so the
# diagnostics below still run instead of crashing on .strip().
raw = (choice.message.content or "").strip()
print("=== RAW RESPONSE (first 500 chars) ===")
print(raw[:500])
print("\n=== PARSING ===")

# finish_reason == "length" means max_tokens cut the reply off, which leaves
# the JSON unterminated; report that instead of only a bare decode error.
if choice.finish_reason == "length":
    print("WARNING: reply hit max_tokens and is probably truncated")

try:
    parsed = json.loads(raw)
except json.JSONDecodeError as exc:
    sys.exit(f"Model reply is not valid JSON: {exc}")
if not isinstance(parsed, dict):
    sys.exit(f"Expected a JSON object, got {type(parsed).__name__}")

print(f"bug_type: {parsed.get('bug_type')}")
# `or ''` also covers an explicit JSON null, which .get's default does not.
print(f"diagnosis: {str(parsed.get('diagnosis') or '')[:100]}")

fixed_code = parsed.get("fixed_code") or ""
if not isinstance(fixed_code, str):
    sys.exit(f"fixed_code must be a string, got {type(fixed_code).__name__}")
print(f"\nfixed_code length: {len(fixed_code)} chars")
print(f"fixed_code first 300 chars:\n{fixed_code[:300]}")
if not fixed_code.strip():
    # An empty script exits with code 0, which would masquerade as a pass.
    sys.exit("Model returned no fixed_code; nothing to run")

# --- Execute the candidate fix in a fresh interpreter ----------------------
print("\n=== RUNNING FIXED CODE ===")
# delete=False keeps the file on disk after the `with` closes it so the child
# interpreter can read it; we therefore remove it ourselves in `finally`.
with tempfile.NamedTemporaryFile(
    mode="w", suffix=".py", delete=False, encoding="utf-8"
) as f:
    f.write(fixed_code)
    tmp = f.name

try:
    result = subprocess.run(
        [sys.executable, tmp],
        capture_output=True, text=True, timeout=30
    )
except subprocess.TimeoutExpired:
    sys.exit("Fixed code did not finish within 30 s")
finally:
    os.unlink(tmp)

print(f"Return code: {result.returncode}")
print(f"STDOUT:\n{result.stdout[:300]}")
# A Python traceback ends with the exception line, so the tail of stderr is
# the informative part.
print(f"STDERR (last 300 chars):\n{result.stderr[-300:]}")