code-debug-env / inference.py
Souravdanyal's picture
error fixed
f931480
raw
history blame
10.3 kB
#!/usr/bin/env python3
"""
inference.py - Code Debug Environment Baseline Agent
Required env vars: API_BASE_URL, MODEL_NAME, HF_TOKEN (optional: ENV_URL, default http://localhost:7860)
Usage:
python inference.py
python inference.py --url https://Souravdanyal-code-debug-env.hf.space
python inference.py --difficulty easy
STDOUT FORMAT (required by evaluator):
[START] task=<id> env=<benchmark> model=<model>
[STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> rewards=<r1,r2,...>
"""
import os, sys, json, time, argparse, requests
from openai import OpenAI
from typing import List, Optional
# ── Config ────────────────────────────────────────────────────────────────────
# Every setting below can be overridden through an environment variable of
# the same name; the second argument is the fallback used when it is unset.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.getenv("HF_TOKEN", "")
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")

# Benchmark name printed in [START] records and the per-episode attempt cap.
BENCHMARK = "code-debug-env"
MAX_STEPS = 5

# Shared chat-completions client; a placeholder key is passed when HF_TOKEN
# is empty.
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "dummy")
# ── Logging ───────────────────────────────────────────────────────────────────
def log_start(task_id, env, model):
    """Print the ``[START]`` record that opens an episode on stdout."""
    # print() joins its arguments with single spaces, which yields the same
    # line as one pre-formatted string.
    print("[START]", f"task={task_id}", f"env={env}", f"model={model}", flush=True)
def log_step(step, action, reward, done, error):
    """Print one ``[STEP]`` record on stdout.

    Args:
        step: Attempt number within the episode.
        action: Short label describing what was submitted.
        reward: Numeric reward, rendered with two decimals.
        done: Episode-finished flag, rendered as ``true``/``false``.
        error: Error message, or a falsy value which is rendered as ``null``.

    The record must occupy exactly one line, so any line breaks inside
    ``action`` or ``error`` (for example a multi-line exception message) are
    replaced by single spaces. Single-line values are printed unchanged.
    """
    action_text = " ".join(str(action).splitlines())
    error_text = " ".join(str(error or "null").splitlines())
    print(
        f"[STEP] step={step} action={action_text} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_text}",
        flush=True,
    )
def log_end(success, steps, rewards):
    """Print the ``[END]`` record that closes an episode on stdout.

    ``rewards`` is rendered as a comma-separated list with two decimals per
    value; an empty sequence yields an empty ``rewards=`` field.
    """
    outcome = str(success).lower()
    reward_csv = ",".join(format(value, ".2f") for value in rewards)
    print(f"[END] success={outcome} steps={steps} rewards={reward_csv}", flush=True)
# ── Env client ────────────────────────────────────────────────────────────────
def env_reset(url, difficulty):
    """POST ``{"difficulty": ...}`` to ``<url>/reset`` and return the decoded JSON body.

    Raises whatever ``requests`` raises on connection problems, on the
    30-second timeout, or (via ``raise_for_status``) on an HTTP error status.
    """
    endpoint = f"{url}/reset"
    response = requests.post(endpoint, json={"difficulty": difficulty}, timeout=30)
    response.raise_for_status()
    return response.json()
def env_step(url, fixed_code, explanation=None):
    """POST a candidate fix to ``<url>/step`` and return the decoded JSON body.

    The ``explanation`` key is only included in the request when a truthy
    explanation is supplied. Raises whatever ``requests`` raises on connection
    problems, on the 30-second timeout, or on an HTTP error status.
    """
    body = dict(fixed_code=fixed_code)
    if explanation:
        body.update(explanation=explanation)
    response = requests.post(f"{url}/step", json=body, timeout=30)
    response.raise_for_status()
    return response.json()
# ── LLM ──────────────────────────────────────────────────────────────────────
# System message sent with every chat-completion request. It pins the reply
# format (a bare JSON object with "fixed_code" and "explanation") and lists
# bug patterns for the model to check.
# The em-dashes and arrows below were previously stored as mis-decoded UTF-8
# ("β€”", "β†’"); they are restored so the model receives clean text.
SYSTEM_PROMPT = """You are an expert Python debugging agent. Fix bugs in Python functions.
RESPONSE FORMAT — strictly JSON only, no markdown:
{
"fixed_code": "<complete corrected Python function including imports>",
"explanation": "<for hard tasks: explain the bug, root cause, and fix>"
}
RULES:
- Return COMPLETE function with all imports (e.g. from collections import deque)
- fixed_code must be valid Python
- For hard tasks explanation MUST mention the algorithmic concept listed in instructions
COMMON BUG PATTERNS:
- List rotation RIGHT by k: correct is lst[-k:] + lst[:-k] NOT lst[k:] + lst[:k]
- List rotation LEFT by k: correct is lst[k:] + lst[:k]
- Graph/BFS missing visited set → infinite loop → add visited=set()
- 0/1 Knapsack: must iterate BACKWARD: range(capacity, weight-1, -1) not forward
- Binary search wrong boundary: return high not low, or high=n//2
- Off-by-one: lst[2] should be lst[1] for second element
- Wrong operator: complement = target - n NOT target + n
FOR HARD TASKS — explanation MUST include words from the instructions hint.
Example: if instructions say "mention: iteration order" then write about iteration order.
Example: if instructions say "mention: visited" then write about visited set.
"""
def _build_user_prompt(buggy_code, instructions, difficulty, feedback, attempt, prev_code):
    """Assemble the user message: task statement, retry feedback, hard-task hints."""
    import re

    content = f"Difficulty: {difficulty}\nInstructions: {instructions}\n\nBuggy code:\n```python\n{buggy_code}\n```\n"
    if feedback and attempt > 1:
        content += f"\nPREVIOUS FIX FAILED. Feedback:\n{feedback}\n\nYour previous code:\n```python\n{prev_code or ''}\n```\n"
        content += "IMPORTANT: Your fix did not work. Look at the Expected vs Got values carefully.\n"
        content += "- If Got is a LEFT rotation but Expected is RIGHT: use lst[-k:] + lst[:-k]\n"
        content += "- If you see TimeoutError: add visited=set() for graph traversal\n"
        content += "- Try a COMPLETELY DIFFERENT approach.\n"
    if difficulty == "hard":
        # Pull keyword hints out of the instructions (e.g. "mention: visited, queue").
        hint_match = re.search(r'[Mm]ention[:\s]+([^.]+)', instructions or "")
        if hint_match:
            hints = hint_match.group(1).strip()
            content += f"\nFor your explanation, you MUST mention these concepts: {hints}\n"
        content += "Include a detailed explanation field — it counts for 30% of your reward.\n"
    return content


def _find_string_field(text, key):
    """Best-effort extraction of one string-valued field from malformed JSON text.

    Returns the decoded value, or None when no terminated string is found for
    ``key``.
    """
    import re

    # Body: escape pairs, ordinary characters, or a bare quote that is not
    # followed by "," / "}" / end of text. Escaped quotes therefore never end
    # the value, and the closing quote must sit where a JSON value may end.
    pattern = (
        r'"' + re.escape(key) + r'"\s*:\s*'
        r'"((?:\\.|[^\\"]|"(?!\s*(?:[,}]|\Z)))*)"(?=\s*(?:[,}]|\Z))'
    )
    match = re.search(pattern, text, re.DOTALL)
    if match is None:
        return None
    body = match.group(1)
    try:
        # Let the JSON decoder undo the escapes; unlike a unicode_escape
        # round-trip this keeps non-ASCII characters intact.
        return json.loads(f'"{body}"', strict=False)
    except json.JSONDecodeError:
        # The body itself is not a valid JSON string (stray quote or unknown
        # escape): undo only the common escapes and keep everything else.
        simple = {"n": "\n", "t": "\t", "r": "\r", '"': '"', "\\": "\\", "/": "/"}
        return re.sub(
            r"\\(.)",
            lambda m: simple.get(m.group(1), m.group(0)),
            body,
            flags=re.DOTALL,
        )


def _parse_llm_reply(raw):
    """Turn the raw model reply into ``{"fixed_code": ..., "explanation": ...}``.

    Raises ValueError when no ``fixed_code`` value can be recovered.
    """
    text = (raw or "").strip()
    # Remove markdown fences.
    if "```json" in text:
        text = text.split("```json")[1].split("```")[0].strip()
    elif "```" in text:
        text = text.split("```")[1].split("```")[0].strip()
    if text.startswith("json"):
        text = text[4:].strip()

    # Narrow to the outermost {...} span for the structured parse.
    candidate = text
    start = text.find("{")
    end = text.rfind("}") + 1
    if start >= 0 and end > start:
        candidate = text[start:end]

    try:
        # strict=False accepts raw newlines/tabs inside string values, which
        # models often emit in "fixed_code". Rewriting every newline in the
        # text would also damage the ones between JSON tokens.
        parsed = json.loads(candidate, strict=False)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return {"fixed_code": parsed.get("fixed_code", ""), "explanation": parsed.get("explanation")}

    # Last resort: pull the fields out of the un-narrowed text, so a reply cut
    # off before its closing brace can still yield the code.
    code = _find_string_field(text, "fixed_code")
    if code is None:
        raise ValueError("model reply contains no parseable fixed_code field")
    return {"fixed_code": code, "explanation": _find_string_field(text, "explanation")}


def call_llm(buggy_code, instructions, difficulty, feedback=None, attempt=1, prev_code=None):
    """Ask the model for a fix and return ``{"fixed_code": ..., "explanation": ...}``.

    Args:
        buggy_code: Source of the function to repair.
        instructions: Task instructions; for hard tasks a "mention: ..." hint
            in them is echoed back into the prompt.
        difficulty: Difficulty label; "hard" adds the explanation reminders.
        feedback: Feedback from the previous attempt, used when attempt > 1.
        attempt: 1-based attempt number; retries use a higher temperature.
        prev_code: Code submitted on the previous attempt.

    Returns:
        A dict with "fixed_code" and "explanation". If the API call fails or
        the reply cannot be parsed, the error is written to stderr and the
        original ``buggy_code`` is returned with ``explanation`` set to None.
    """
    content = _build_user_prompt(buggy_code, instructions, difficulty, feedback, attempt, prev_code)
    try:
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": content}],
            max_tokens=1500,
            temperature=0.1 if attempt == 1 else 0.4,
        )
        return _parse_llm_reply(resp.choices[0].message.content)
    except Exception as e:
        # Deliberate best-effort boundary: any failure falls back to
        # resubmitting the unmodified code instead of aborting the episode.
        print(f"# LLM error: {e}", file=sys.stderr)
        return {"fixed_code": buggy_code, "explanation": None}
# ── Episode ───────────────────────────────────────────────────────────────────
def run_episode(env_url, difficulty):
    """Run one debugging episode and return ``(success, steps_taken, rewards)``.

    The environment is reset for ``difficulty``, then up to MAX_STEPS fixes
    are requested from the model and submitted. Each attempt produces one
    [STEP] record; the episode is bracketed by [START] and [END] records.

    Returns:
        success: True if any attempt earned a reward of at least 1.0.
        steps_taken: Number of attempts started.
        rewards: One reward per logged attempt (0.0 for empty or failed
            submissions).

    Raises:
        Whatever ``env_reset`` raises, or KeyError if the reset payload lacks
        the expected observation fields; both happen before [START] is logged.
    """
    data = env_reset(env_url, difficulty)
    obs = data["observation"]
    task_id = obs["task_id"]
    buggy_code = obs["buggy_code"]
    instructions = obs["instructions"]
    log_start(task_id, BENCHMARK, MODEL_NAME)

    rewards, steps_taken, success = [], 0, False
    last_feedback, last_code = None, None
    try:
        for attempt in range(1, MAX_STEPS + 1):
            steps_taken = attempt
            action = call_llm(buggy_code, instructions, difficulty, last_feedback, attempt, last_code)
            code = action["fixed_code"]
            last_code = code
            if not code or not code.strip():
                log_step(attempt, "empty_submission", 0.0, False, "empty_code")
                rewards.append(0.0)
                continue
            try:
                result = env_step(env_url, code, action.get("explanation"))
            except Exception as e:
                # A failed submission costs the attempt but not the episode.
                log_step(attempt, "step_failed", 0.0, False, str(e)[:60])
                rewards.append(0.0)
                continue

            # The JSON payload may carry explicit nulls; treat them like
            # missing keys so formatting and the feedback lookup cannot crash.
            reward = result.get("reward")
            if reward is None:
                reward = 0.0
            done = result.get("done", False)
            last_feedback = (result.get("observation") or {}).get("feedback", "")

            log_step(attempt, f"fix_{difficulty}_attempt{attempt}", reward, done, None)
            rewards.append(reward)
            if reward >= 1.0:
                success = True
            if done:
                break
    finally:
        # Every [START] gets a matching [END], even when an unexpected error
        # is propagating out of the loop.
        log_end(success, steps_taken, rewards)
    return success, steps_taken, rewards
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: health-check the environment, run the episodes, print a summary.

    ``--url`` defaults to ENV_URL. ``--difficulty`` selects a single level;
    omitting it or passing ``all`` runs easy, medium and hard in that order.
    The process exits with status 1 when the ``/health`` request fails.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--url", default=ENV_URL)
    arg_parser.add_argument("--difficulty", default=None, choices=["easy","medium","hard","all"])
    options = arg_parser.parse_args()
    base_url = options.url.rstrip("/")

    try:
        health = requests.get(f"{base_url}/health", timeout=10)
        health.raise_for_status()
        print(f"# Environment healthy at {base_url}", flush=True)
    except Exception as exc:
        print(f"# Health check failed: {exc}", file=sys.stderr)
        sys.exit(1)

    if options.difficulty in (None, "all"):
        levels = ["easy", "medium", "hard"]
    else:
        levels = [options.difficulty]

    collected_rewards = []
    outcomes = []
    for level in levels:
        solved, _, episode_rewards = run_episode(base_url, level)
        collected_rewards.extend(episode_rewards)
        outcomes.append(solved)
        time.sleep(0.5)

    if collected_rewards:
        avg = round(sum(collected_rewards) / len(collected_rewards), 3)
    else:
        avg = 0.0
    print(f"# SUMMARY: {sum(outcomes)}/{len(levels)} tasks solved | avg_reward={avg}", flush=True)
# Run the agent only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()