#!/usr/bin/env python3
"""
inference.py - Code Debug Environment Baseline Agent

Required env vars: API_BASE_URL, MODEL_NAME, HF_TOKEN

Usage:
    python inference.py
    python inference.py --url https://Souravdanyal-code-debug-env.hf.space
    python inference.py --difficulty easy

STDOUT FORMAT (required by evaluator):
    [START] task=<task_id> env=<env> model=<model>
    [STEP] step=<n> action=<action> reward=<0.00> done=<true|false> error=<msg|null>
    [END] success=<true|false> steps=<n> rewards=<r1,r2,...>
"""

import argparse
import json
import os
import sys
import time
from typing import List, Optional

import requests
from openai import OpenAI

# ── Config ────────────────────────────────────────────────────────────────────
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
BENCHMARK = "code-debug-env"
MAX_STEPS = 5

# Fall back to a placeholder key when HF_TOKEN is unset/empty so the client
# object can still be constructed at import time.
client = OpenAI(api_key=HF_TOKEN or "dummy", base_url=API_BASE_URL)


# ── Logging ───────────────────────────────────────────────────────────────────
def log_start(task_id, env, model) -> None:
    """Print the ``[START]`` line announcing a new episode.

    Args:
        task_id: Identifier printed after ``task=``.
        env: Name printed after ``env=``.
        model: Name printed after ``model=``.
    """
    print(f"[START] task={task_id} env={env} model={model}", flush=True)


def log_step(step, action, reward, done, error) -> None:
    """Print one ``[STEP]`` line describing a single agent step.

    Args:
        step: Step counter printed after ``step=``.
        action: Action description printed verbatim after ``action=``.
        reward: Numeric reward, rendered with two decimals.
        done: Episode-finished flag, rendered as lowercase ``true``/``false``.
        error: Error text; any falsy value is rendered as ``null``.
    """
    done_text = str(done).lower()
    error_text = error or "null"
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={done_text} error={error_text}",
        flush=True,
    )


def log_end(success, steps, rewards) -> None:
    """Print the ``[END]`` line summarising a finished episode.

    Args:
        success: Outcome flag, rendered as lowercase ``true``/``false``.
        steps: Number of steps taken.
        rewards: Iterable of numeric rewards; rendered comma-separated with
            two decimals each (an empty iterable yields an empty field).
    """
    rewards_text = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} rewards={rewards_text}",
        flush=True,
    )


# ── Env client ────────────────────────────────────────────────────────────────
def env_reset(url, difficulty, *, timeout: float = 30):
    """POST to ``{url}/reset`` and return the decoded JSON response.

    Args:
        url: Base URL of the environment server (no trailing slash expected,
            since ``/reset`` is appended directly).
        difficulty: Value sent as the ``difficulty`` field of the JSON body.
        timeout: Request timeout in seconds (default 30).

    Returns:
        The response body decoded from JSON.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.post(
        f"{url}/reset", json={"difficulty": difficulty}, timeout=timeout
    )
    response.raise_for_status()
    return response.json()


def env_step(url, fixed_code, explanation: Optional[str] = None, *, timeout: float = 30):
    """POST a candidate fix to ``{url}/step`` and return the decoded JSON response.

    Args:
        url: Base URL of the environment server.
        fixed_code: Source text sent as the ``fixed_code`` field.
        explanation: Optional text sent as ``explanation``; the field is
            omitted entirely when this is ``None`` or empty.
        timeout: Request timeout in seconds (default 30).

    Returns:
        The response body decoded from JSON.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    payload = {"fixed_code": fixed_code}
    if explanation:
        payload["explanation"] = explanation
    response = requests.post(f"{url}/step", json=payload, timeout=timeout)
    response.raise_for_status()
    return response.json()


# ── LLM
────────────────────────────────────────────────────────────────────── SYSTEM_PROMPT = """You are an expert Python debugging agent. Fix bugs in Python functions. RESPONSE FORMAT — strictly JSON only, no markdown: { "fixed_code": "", "explanation": "" } RULES: - Return COMPLETE function with all imports (e.g. from collections import deque) - fixed_code must be valid Python - For hard tasks explanation MUST mention the algorithmic concept listed in instructions COMMON BUG PATTERNS: - List rotation RIGHT by k: correct is lst[-k:] + lst[:-k] NOT lst[k:] + lst[:k] - List rotation LEFT by k: correct is lst[k:] + lst[:k] - Graph/BFS missing visited set → infinite loop → add visited=set() - 0/1 Knapsack: must iterate BACKWARD: range(capacity, weight-1, -1) not forward - Binary search wrong boundary: return high not low, or high=n//2 - Off-by-one: lst[2] should be lst[1] for second element - Wrong operator: complement = target - n NOT target + n FOR HARD TASKS — explanation MUST include words from the instructions hint. Example: if instructions say "mention: iteration order" then write about iteration order. Example: if instructions say "mention: visited" then write about visited set. """ def call_llm(buggy_code, instructions, difficulty, feedback=None, attempt=1, prev_code=None): content = f"Difficulty: {difficulty}\nInstructions: {instructions}\n\nBuggy code:\n```python\n{buggy_code}\n```\n" if feedback and attempt > 1: content += f"\nPREVIOUS FIX FAILED. Feedback:\n{feedback}\n\nYour previous code:\n```python\n{prev_code or ''}\n```\n" content += "IMPORTANT: Your fix did not work. Look at the Expected vs Got values carefully.\n" content += "- If Got is a LEFT rotation but Expected is RIGHT: use lst[-k:] + lst[:-k]\n" content += "- If you see TimeoutError: add visited=set() for graph traversal\n" content += "- Try a COMPLETELY DIFFERENT approach.\n" if difficulty == "hard": # Extract keyword hints from instructions (e.g. 
"mention: visited, queue") import re hint_match = re.search(r'[Mm]ention[:\s]+([^.]+)', instructions) if hint_match: hints = hint_match.group(1).strip() content += f"\nFor your explanation, you MUST mention these concepts: {hints}\n" content += "Include a detailed explanation field — it counts for 30% of your reward.\n" try: resp = client.chat.completions.create( model=MODEL_NAME, messages=[{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": content}], max_tokens=1500, temperature=0.1 if attempt == 1 else 0.4, ) raw = resp.choices[0].message.content.strip() # Remove markdown fences if "```json" in raw: raw = raw.split("```json")[1].split("```")[0].strip() elif "```" in raw: raw = raw.split("```")[1].split("```")[0].strip() if raw.startswith("json"): raw = raw[4:].strip() # Find JSON object boundaries start = raw.find("{") end = raw.rfind("}") + 1 if start >= 0 and end > start: raw = raw[start:end] # Try direct parse first try: parsed = json.loads(raw) except json.JSONDecodeError: # Fix control characters by replacing literal newlines inside strings import re # Replace actual newlines within JSON string values with \n escape raw = re.sub(r'(?= 1.0: success = True if done: break log_end(success, steps_taken, rewards) return success, steps_taken, rewards # ── Main ────────────────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser() parser.add_argument("--url", default=ENV_URL) parser.add_argument("--difficulty", default=None, choices=["easy","medium","hard","all"]) args = parser.parse_args() url = args.url.rstrip("/") try: requests.get(f"{url}/health", timeout=10).raise_for_status() print(f"# Environment healthy at {url}", flush=True) except Exception as e: print(f"# Health check failed: {e}", file=sys.stderr) sys.exit(1) diffs = ["easy","medium","hard"] if args.difficulty in (None,"all") else [args.difficulty] all_rewards, successes = [], [] for d in diffs: ok, _, rewards = run_episode(url, d) 
all_rewards.extend(rewards) successes.append(ok) time.sleep(0.5) avg = round(sum(all_rewards)/len(all_rewards), 3) if all_rewards else 0.0 print(f"# SUMMARY: {sum(successes)}/{len(diffs)} tasks solved | avg_reward={avg}", flush=True) if __name__ == "__main__": main()