# server/baseline_inference.py
"""
Baseline inference script.

Uses API_BASE_URL and API_KEY environment variables injected by the validator.
Falls back to Groq if those are not set (local dev).

Model: llama-3.3-70b-versatile
"""

import os
import sys
import json

# Make sibling modules (bug_generator, grader) importable when this file is
# executed directly as a script rather than as part of a package.
_server_dir = os.path.dirname(os.path.abspath(__file__))
if _server_dir not in sys.path:
    sys.path.insert(0, _server_dir)

from openai import OpenAI

from bug_generator import (
    get_scenario,
    TASK_SHAPE_MISMATCH,
    TASK_TRAINING_COLLAPSE,
    TASK_DATA_LEAKAGE,
)
from grader import grade

GROQ_BASE_URL = "https://api.groq.com/openai/v1"
MODEL = "llama-3.3-70b-versatile"
SEED = 42

# NOTE(review): the three values in the JSON template below are empty strings;
# they look like placeholders whose descriptive text was lost -- confirm the
# intended field descriptions (e.g. the allowed bug_type values).
SYSTEM_PROMPT = """You are an expert ML engineer specializing in debugging PyTorch training code.
You will be given a broken Python training script and a description of how it fails.
Your job is to:
1. Identify the exact bug type
2. Explain the root cause clearly
3. Return a complete corrected script that fixes the issue

You must respond with valid JSON in exactly this format:
{
  "bug_type": "",
  "diagnosis": "",
  "fixed_code": ""
}

Rules:
- fixed_code must be the COMPLETE script, not a diff or partial fix
- fixed_code must include all imports
- Do not add markdown code fences inside the JSON string
- Do not add any text outside the JSON object"""


def build_user_prompt(task_description: str, buggy_code: str, error_output: str) -> str:
    """Return the user message sent to the model for one scenario.

    The message contains the task description, the broken script inside a
    fenced ``python`` block, and the observed failure output.
    """
    return f"""Task: {task_description}

Broken script:
```python
{buggy_code}
```

Failure observed:
{error_output}

Respond with JSON only."""


def call_llm(client: OpenAI, user_prompt: str) -> dict:
    """Send ``user_prompt`` to the model and return its reply parsed as a dict.

    Raises:
        json.JSONDecodeError: if the reply has no choices, has no text
            content, is not valid JSON, or is valid JSON but not an object.
            Every "unusable model output" case is reported through this one
            exception type so callers need a single handler.
    """
    response = client.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.0,
        max_tokens=2048,
        response_format={"type": "json_object"},
    )

    # The API may return an empty choice list or a message whose content is
    # None (e.g. refusals / filtered output); calling .strip() on that would
    # raise AttributeError and abort the whole run.
    if not response.choices:
        raise json.JSONDecodeError("model returned no choices", "", 0)
    content = response.choices[0].message.content
    if content is None:
        raise json.JSONDecodeError("model returned empty content", "", 0)

    content = content.strip()
    parsed = json.loads(content)
    # Backends that ignore response_format can return a JSON array or scalar;
    # honour the declared `dict` return type instead of failing later in .get.
    if not isinstance(parsed, dict):
        raise json.JSONDecodeError("expected a JSON object", content, 0)
    return parsed


def _field(parsed: dict, key: str, default: str):
    """Return ``parsed[key]``, treating a missing key or JSON null as ``default``."""
    value = parsed.get(key)
    return default if value is None else value


def run_single_task(client: OpenAI, task_id: str) -> dict:
    """Run the model on one scenario and grade its answer.

    Returns a dict with the keys ``task_id``, ``score``, ``feedback``,
    ``bug_type_submitted`` and ``execution_output``. If the model output
    cannot be parsed, the score is 0.0 and ``bug_type_submitted`` is
    ``"parse_error"``.
    """
    scenario = get_scenario(task_id, seed=SEED)
    user_prompt = build_user_prompt(
        scenario.task_description,
        scenario.buggy_code,
        scenario.error_output,
    )

    try:
        parsed = call_llm(client, user_prompt)
    except (json.JSONDecodeError, KeyError) as e:
        return {
            "task_id": task_id,
            "score": 0.0,
            "feedback": f"Model returned invalid JSON: {e}",
            "bug_type_submitted": "parse_error",
            "execution_output": "",
        }

    # A JSON null is handled like a missing key so the grader is not handed
    # None where it is given a string default otherwise.
    bug_type = _field(parsed, "bug_type", "other")
    diagnosis = _field(parsed, "diagnosis", "")
    fixed_code = _field(parsed, "fixed_code", "")

    result = grade(
        action_bug_type=bug_type,
        action_diagnosis=diagnosis,
        fixed_code=fixed_code,
        scenario=scenario,
    )
    return {
        "task_id": task_id,
        "score": result.score,
        "feedback": result.feedback,
        "bug_type_submitted": bug_type,
        "execution_output": result.execution_output,
    }


def run_baseline_on_all_tasks(api_key: str, base_url: str) -> list:
    """Run the baseline on the three tasks, in order, and return their results.

    Progress is printed as each task starts and finishes. Each element of the
    returned list is the dict produced by ``run_single_task``.
    """
    client = OpenAI(
        api_key=api_key,
        base_url=base_url,
    )
    results = []
    for task_id in (TASK_SHAPE_MISMATCH, TASK_TRAINING_COLLAPSE, TASK_DATA_LEAKAGE):
        print(f"  Running baseline on: {task_id} ...", flush=True)
        result = run_single_task(client, task_id)
        results.append(result)
        print(f"  Score: {result['score']} | Bug type: {result['bug_type_submitted']}", flush=True)
    return results


def main() -> None:
    """Script entry point: read credentials, run all tasks, print a summary."""
    # Use injected proxy creds if available, fall back to Groq for local dev
    api_key = (os.environ.get("API_KEY") or os.environ.get("GROQ_API_KEY", "")).strip()
    base_url = os.environ.get("API_BASE_URL") or GROQ_BASE_URL

    if not api_key:
        print("ERROR: API_KEY (or GROQ_API_KEY) environment variable not set.")
        sys.exit(1)

    print(f"Running baseline with model: {MODEL}")
    print(f"Base URL: {base_url}")
    print(f"Seed: {SEED}\n")

    results = run_baseline_on_all_tasks(api_key, base_url)

    print("\n=== BASELINE RESULTS ===")
    total = 0.0
    for r in results:
        print(f"\nTask: {r['task_id']}")
        print(f"  Score:    {r['score']}")
        print(f"  Bug type: {r['bug_type_submitted']}")
        print(f"  Feedback: {r['feedback'][:120]}...")
        total += r["score"]

    avg = total / len(results)
    print(f"\nAverage score: {avg:.4f}")
    print("========================")


if __name__ == "__main__":
    main()