# debug_baseline.py
# Drop in ml_debug_env/server/, run: python debug_baseline.py
"""Ask a Groq-hosted LLM to fix one buggy training script, then run the fix.

The script builds a prompt from a single scenario returned by
``get_scenario(TASK_SHAPE_MISMATCH, seed=42)``, requests a JSON answer from
the chat-completions endpoint, prints what came back, writes the returned
``fixed_code`` to a temporary file and executes it with the current Python
interpreter, printing the return code and truncated stdout/stderr.
"""
import os
import sys
import json
import subprocess
import tempfile

# Make the parent of this file's directory importable before the
# third-party / project imports below are resolved.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from openai import OpenAI
from bug_generator import get_scenario, TASK_SHAPE_MISMATCH

GROQ_BASE_URL = "https://api.groq.com/openai/v1"
MODEL = "llama-3.3-70b-versatile"
# Upper bound (seconds) on how long the model-generated script may run.
RUN_TIMEOUT_SECONDS = 30

api_key = os.environ.get("GROQ_API_KEY", "")
if not api_key:
    print("Set GROQ_API_KEY first")
    sys.exit(1)

client = OpenAI(api_key=api_key, base_url=GROQ_BASE_URL)
scenario = get_scenario(TASK_SHAPE_MISMATCH, seed=42)

# NOTE(review): the three JSON placeholder values are empty strings in the
# visible source — confirm no placeholder text was lost upstream.
SYSTEM_PROMPT = """You are an expert ML engineer specializing in debugging PyTorch training code.

You will be given a broken Python training script and a description of how it fails.
Your job is to:
1. Identify the exact bug type
2. Explain the root cause clearly
3. Return a complete corrected script that fixes the issue

You must respond with valid JSON in exactly this format:
{
  "bug_type": "",
  "diagnosis": "",
  "fixed_code": ""
}

Rules:
- fixed_code must be the COMPLETE script, not a diff or partial fix
- fixed_code must include all imports
- Do not add markdown code fences inside the JSON string
- Do not add any text outside the JSON object"""

user_prompt = f"""Task: {scenario.task_description}

Broken script:
```python
{scenario.buggy_code}
```

Failure observed:
{scenario.error_output}

Respond with JSON only."""

response = client.chat.completions.create(
    model=MODEL,
    messages=[
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ],
    temperature=0.0,
    max_tokens=2048,
    response_format={"type": "json_object"},
)

# ``content`` may be None (e.g. refusal / empty completion); treat that as
# an empty reply instead of crashing on ``None.strip()``.
raw = (response.choices[0].message.content or "").strip()
print("=== RAW RESPONSE (first 500 chars) ===")
print(raw[:500])

print("\n=== PARSING ===")
try:
    parsed = json.loads(raw)
except json.JSONDecodeError as exc:
    print(f"Model reply is not valid JSON: {exc}")
    sys.exit(1)
if not isinstance(parsed, dict):
    print(f"Expected a JSON object, got {type(parsed).__name__}")
    sys.exit(1)

print(f"bug_type: {parsed.get('bug_type')}")
# Coerce to str so a null / non-string diagnosis cannot break the slice.
print(f"diagnosis: {str(parsed.get('diagnosis') or '')[:100]}")

fixed_code = parsed.get("fixed_code", "")
if not isinstance(fixed_code, str) or not fixed_code.strip():
    # An empty script would exit 0 and look like a successful fix.
    print("fixed_code is missing or empty; nothing to run")
    sys.exit(1)

print(f"\nfixed_code length: {len(fixed_code)} chars")
print(f"fixed_code first 300 chars:\n{fixed_code[:300]}")

print("\n=== RUNNING FIXED CODE ===")
# SECURITY: this executes model-generated code with the current user's
# privileges. Only run it in a disposable / sandboxed environment.
with tempfile.NamedTemporaryFile(
    mode="w", suffix=".py", delete=False, encoding="utf-8"
) as f:
    f.write(fixed_code)
    tmp = f.name

try:
    result = subprocess.run(
        [sys.executable, tmp],
        capture_output=True,
        text=True,
        timeout=RUN_TIMEOUT_SECONDS,
    )
except subprocess.TimeoutExpired:
    print(f"Fixed code did not finish within {RUN_TIMEOUT_SECONDS}s (timed out)")
else:
    print(f"Return code: {result.returncode}")
    print(f"STDOUT:\n{result.stdout[:300]}")
    print(f"STDERR:\n{result.stderr[:300]}")
finally:
    # delete=False means the file outlives the ``with`` block; remove it
    # ourselves so repeated runs do not litter the temp directory.
    try:
        os.unlink(tmp)
    except OSError as exc:
        print(f"Could not remove temp file {tmp}: {exc}")