code-debug-env / inference.py
Souravdanyal's picture
Update inference.py
40ac3c8 verified
#!/usr/bin/env python3
"""
inference.py - Code Debug Environment Baseline Agent
Required env vars: API_BASE_URL, MODEL_NAME, HF_TOKEN
Usage:
python inference.py
python inference.py --url https://Souravdanyal-code-debug-env.hf.space
python inference.py --difficulty easy
STDOUT FORMAT (strictly required by evaluator - plaintext):
[START] task=<id> env=<benchmark> model=<model>
[STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...>
"""
import os, sys, json, time, argparse, requests, re
from openai import OpenAI
from typing import List, Optional
# Optionally load variables from a .env file; python-dotenv is not required.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass  # dotenv not installed, will use system env vars

# ── Config ────────────────────────────────────────────────────────────────────
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")

# Resolve the API key from the first non-empty variable, in priority order.
# HF_TOKEN_SOURCE ends up naming the variable that supplied the key, or the
# last one consulted ("hf_token") when none of them is set.
for HF_TOKEN_SOURCE in ("HF_TOKEN", "API_KEY", "hf_token"):
    HF_TOKEN = os.getenv(HF_TOKEN_SOURCE)
    if HF_TOKEN:
        break

LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
ENV_URL = os.getenv("ENV_URL")

BENCHMARK = "code-debug-env"
MAX_STEPS = 5
SUCCESS_SCORE_THRESHOLD = 0.5

# A placeholder key keeps the client constructible; main() exits early when
# no real key was found.
client = OpenAI(api_key=HF_TOKEN or "dummy", base_url=API_BASE_URL)
# ── Logging — STRICT PLAINTEXT FORMAT ────────────────────────────────────────
def _format_bool(value: bool) -> str:
return "true" if value else "false"
def _normalize_token(value: str) -> str:
return re.sub(r"\s+", " ", str(value)).strip()
def _format_error(error: Optional[str]) -> str:
if error is None:
return "null"
text = str(error).replace("\r", "\\r").replace("\n", "\\n")
return text if text else "null"
def _format_rewards(rewards: List[float]) -> str:
return ",".join(f"{round(r, 2):.2f}" for r in rewards)
def log_start(task_id: str, env: str, model: str) -> None:
    """Print the ``[START]`` line for an episode to stdout, flushing immediately.

    Each field is whitespace-normalized before being written.
    """
    line = f"[START] task={_normalize_token(task_id)} env={_normalize_token(env)} model={_normalize_token(model)}"
    print(line, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line to stdout, flushing immediately.

    The reward is written with two decimals, ``done`` as ``true``/``false``,
    and ``error`` as ``null`` when absent.
    """
    line = (
        f"[STEP] step={step} action={_normalize_token(action)} reward={round(reward, 2):.2f} "
        f"done={_format_bool(done)} error={_format_error(error)}"
    )
    print(line, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` line for an episode to stdout, flushing immediately.

    The score is written with two decimals and the rewards as a
    comma-separated list.
    """
    line = (
        f"[END] success={_format_bool(success)} steps={steps} score={round(score, 2):.2f} "
        f"rewards={_format_rewards(rewards)}"
    )
    print(line, flush=True)
# ── Env client ────────────────────────────────────────────────────────────────
def env_reset(url: str, difficulty: str) -> dict:
    """POST to ``{url}/reset`` with the requested difficulty and return the decoded JSON body.

    Raises whatever ``requests`` raises on connection problems, and an HTTP
    error (via ``raise_for_status``) on a non-success status code.
    """
    endpoint = f"{url}/reset"
    body = {"difficulty": difficulty}
    response = requests.post(endpoint, json=body, timeout=30)
    response.raise_for_status()
    return response.json()
def env_step(url: str, fixed_code: str, explanation: Optional[str] = None) -> dict:
    """POST a proposed fix to ``{url}/step`` and return the decoded JSON body.

    ``explanation`` is included in the request only when it is non-empty.
    Raises whatever ``requests`` raises on connection problems, and an HTTP
    error (via ``raise_for_status``) on a non-success status code.
    """
    endpoint = f"{url}/step"
    body = {"fixed_code": fixed_code}
    if explanation:
        body["explanation"] = explanation
    response = requests.post(endpoint, json=body, timeout=30)
    response.raise_for_status()
    return response.json()
# ── LLM ──────────────────────────────────────────────────────────────────────
# System prompt sent with every LLM request: it fixes the JSON response
# contract and lists the bug patterns the agent should look for.
# The em dashes below replace mis-encoded "β€”" sequences.
SYSTEM_PROMPT = """You are an expert Python debugging agent.
RESPONSE FORMAT — JSON only, no markdown fences, no extra text:
{"fixed_code": "<complete Python function with all imports>", "explanation": "<for hard tasks only>"}
RULES:
- Return the COMPLETE function including all imports (e.g. from collections import deque)
- fixed_code must be valid, executable Python
- For hard tasks: explanation MUST mention the algorithmic concepts from the instructions
COMMON BUG PATTERNS — memorize these:
- RIGHT rotate list by k: lst[-k:] + lst[:-k] (NOT lst[k:] + lst[:k] which is LEFT rotate)
- LEFT rotate list by k: lst[k:] + lst[:k]
- BFS/graph missing visited: add visited=set([start]) before queue, check before appending
- 0/1 Knapsack: iterate BACKWARD range(capacity, weight-1, -1) NOT forward
- Binary search boundary: often return high not low, or initial high=n//2 not n
- Wrong operator: target-n not target+n for complement
- Off-by-one: lst[1] for second element not lst[2]
IMPORTANT: If feedback shows TimeoutError, you have infinite loop. Add visited set.
IMPORTANT: If Expected shows right-rotated list, use lst[-k:] + lst[:-k].
"""
def _parse_llm_response(raw: str, buggy_code: str) -> dict:
"""Robustly parse LLM response handling control chars and malformed JSON."""
# Remove markdown fences
if "```json" in raw:
raw = raw.split("```json")[1].split("```")[0].strip()
elif "```" in raw:
parts = raw.split("```")
if len(parts) >= 2:
raw = parts[1].strip()
if raw.startswith("json"):
raw = raw[4:].strip()
# Find JSON boundaries
start = raw.find("{")
end = raw.rfind("}") + 1
if start >= 0 and end > start:
raw = raw[start:end]
# Try direct parse
try:
parsed = json.loads(raw)
return {
"fixed_code": parsed.get("fixed_code", ""),
"explanation": parsed.get("explanation"),
}
except json.JSONDecodeError:
pass
# Fix literal control characters inside JSON strings
try:
fixed = re.sub(r'(?<!\\)\n', r'\\n', raw)
fixed = re.sub(r'(?<!\\)\t', r'\\t', fixed)
fixed = re.sub(r'(?<!\\)\r', r'\\r', fixed)
parsed = json.loads(fixed)
code = parsed.get("fixed_code", "")
if "\\n" in code:
code = code.replace("\\n", "\n").replace("\\t", "\t")
return {"fixed_code": code, "explanation": parsed.get("explanation")}
except json.JSONDecodeError:
pass
# Last resort: regex extraction
code_match = re.search(r'"fixed_code"\s*:\s*"((?:[^"\\]|\\.)*)"', raw, re.DOTALL)
exp_match = re.search(r'"explanation"\s*:\s*"((?:[^"\\]|\\.)*)"', raw, re.DOTALL)
if code_match:
code = code_match.group(1).replace("\\n", "\n").replace("\\t", "\t")
exp = exp_match.group(1).replace("\\n", "\n") if exp_match else None
return {"fixed_code": code, "explanation": exp}
# Complete fallback
return {"fixed_code": buggy_code, "explanation": None}
def call_llm(
    buggy_code: str,
    instructions: str,
    difficulty: str,
    feedback: Optional[str] = None,
    attempt: int = 1,
    prev_code: Optional[str] = None,
) -> dict:
    """Ask the model for a fix and return the parsed action.

    Args:
        buggy_code: The code to repair.
        instructions: Task instructions from the environment.
        difficulty: Task difficulty; ``"hard"`` adds explanation guidance.
        feedback: Feedback from the previous attempt. Only included in the
            prompt when ``attempt > 1``.
        attempt: 1-based attempt number. Retries use a higher temperature.
        prev_code: The code submitted on the previous attempt, if any.

    Returns:
        A dict with ``fixed_code`` and ``explanation``. On any LLM or parsing
        failure the error is reported on stderr and ``buggy_code`` is
        returned unchanged, so this function does not raise.
    """
    content = (
        f"Difficulty: {difficulty}\n"
        f"Instructions: {instructions}\n\n"
        f"Buggy code:\n```python\n{buggy_code}\n```\n"
    )
    if feedback and attempt > 1:
        content += (
            f"\nPREVIOUS FIX FAILED. Feedback:\n{feedback}\n\n"
            f"Your previous code:\n```python\n{prev_code or ''}\n```\n"
            "ANALYZE THE FEEDBACK CAREFULLY:\n"
            "- Look at Input/Expected/Got for each failing test\n"
            "- If Got shows wrong rotation direction: use lst[-k:] + lst[:-k] for RIGHT rotate\n"
            "- If TimeoutError: add visited=set([start]) before queue in graph code\n"
            "- Try a COMPLETELY DIFFERENT fix.\n"
        )
    if difficulty == "hard":
        # Pull the concept list that follows "mention" in the instructions.
        hint_match = re.search(r'[Mm]ention[:\s]+([^.]+?)(?:\.|$)', instructions)
        if hint_match:
            hints = hint_match.group(1).strip()
            content += f"\nFor explanation, you MUST mention these concepts: {hints}\n"
        # The em dash replaces a mis-encoded "β€”" sequence in the prompt text.
        content += "Explanation counts for 30% of reward — make it detailed and specific.\n"
    try:
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": content},
            ],
            max_tokens=1500,
            # Low temperature for the first try; more variety on retries.
            temperature=0.1 if attempt == 1 else 0.4,
        )
        # The message content can be None; treat that as an empty reply so
        # the parser falls back to the original code.
        raw = (resp.choices[0].message.content or "").strip()
        return _parse_llm_response(raw, buggy_code)
    except Exception as e:
        # Best-effort by design: a failed call must not abort the episode.
        print(f"# LLM error: {e}", file=sys.stderr)
        return {"fixed_code": buggy_code, "explanation": None}
# ── Episode ───────────────────────────────────────────────────────────────────
def run_episode(env_url: str, difficulty: str) -> tuple:
    """Run one full episode. Returns (success, steps_taken, rewards).

    Resets the environment, then makes up to ``MAX_STEPS`` attempts: each
    attempt asks the LLM for a fix, submits it, and logs a ``[STEP]`` line.
    The ``[END]`` line is always printed once the episode has started, even
    if an attempt raises.

    Args:
        env_url: Base URL of the environment server (no trailing slash).
        difficulty: Difficulty passed to the environment's reset endpoint.

    Returns:
        ``(success, steps_taken, rewards)``. ``success`` is true when any
        reward reached 1.0 or the best reward (clamped to [0, 1]) reached
        ``SUCCESS_SCORE_THRESHOLD``.

    Raises:
        Whatever ``env_reset`` raises, and ``KeyError`` if the reset payload
        lacks the expected observation fields.
    """
    data = env_reset(env_url, difficulty)
    obs = data["observation"]
    task_id = obs["task_id"]
    buggy_code = obs["buggy_code"]
    instructions = obs["instructions"]
    log_start(task_id, BENCHMARK, MODEL_NAME)

    rewards: List[float] = []
    steps_taken = 0
    success = False
    last_feedback = None
    last_code = None
    try:
        for attempt in range(1, MAX_STEPS + 1):
            steps_taken = attempt
            action = call_llm(buggy_code, instructions, difficulty, last_feedback, attempt, last_code)
            code = action.get("fixed_code") or ""
            last_code = code

            reward: float = 0.0
            done: bool = False
            step_error: Optional[str] = None
            try:
                result = env_step(env_url, code, action.get("explanation"))
                # Read "done" first so a bad reward cannot hide termination.
                done = bool(result.get("done", False))
                # A JSON null reward arrives as None; count it as 0.0 rather
                # than letting round() fail while formatting the log line.
                raw_reward = result.get("reward")
                reward = 0.0 if raw_reward is None else float(raw_reward)
                obs_r = result.get("observation", {})
                if isinstance(obs_r, dict):
                    last_feedback = obs_r.get("feedback", "")
                    step_error = obs_r.get("last_action_error") or obs_r.get("error")
            except Exception as e:
                # A failed step is logged with zero reward; the loop goes on.
                step_error = str(e)

            log_step(attempt, f"fix_{difficulty}_attempt{attempt}", reward, done, step_error)
            rewards.append(reward)
            if reward >= 1.0:
                success = True
            if done:
                break
    finally:
        # The episode score is the best reward seen, clamped to [0, 1].
        score = max(rewards) if rewards else 0.0
        score = min(max(score, 0.0), 1.0)
        success = success or (score >= SUCCESS_SCORE_THRESHOLD)
        log_end(success, steps_taken, score, rewards)
    return success, steps_taken, rewards
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point.

    Parses ``--url`` and ``--difficulty``, exits with status 1 when no API
    key is configured or the environment's ``/health`` endpoint fails, then
    runs one episode per selected difficulty and prints a summary to stderr.
    """
    cli = argparse.ArgumentParser(description="Code Debug Environment Baseline Agent")
    cli.add_argument("--url", default=ENV_URL or "http://localhost:7860")
    cli.add_argument(
        "--difficulty", default=None,
        choices=["easy", "medium", "hard", "all"],
    )
    options = cli.parse_args()
    base_url = options.url.rstrip("/")

    if not HF_TOKEN:
        print(
            "# Missing API key. Set HF_TOKEN (or API_KEY / lowercase hf_token).",
            file=sys.stderr, flush=True,
        )
        sys.exit(1)
    print(f"# Using API key from {HF_TOKEN_SOURCE}", file=sys.stderr, flush=True)

    # Health check
    try:
        health = requests.get(f"{base_url}/health", timeout=10)
        health.raise_for_status()
        print(f"# Environment healthy at {base_url}", file=sys.stderr, flush=True)
    except Exception as exc:
        print(f"# Health check failed: {exc}", file=sys.stderr)
        sys.exit(1)

    # No flag, or "all", means every difficulty in order.
    if options.difficulty in (None, "all"):
        selected = ["easy", "medium", "hard"]
    else:
        selected = [options.difficulty]

    collected_rewards: List[float] = []
    outcomes: List[bool] = []
    for level in selected:
        solved, _, episode_rewards = run_episode(base_url, level)
        collected_rewards.extend(episode_rewards)
        outcomes.append(solved)
        time.sleep(0.5)

    if collected_rewards:
        avg = round(sum(collected_rewards) / len(collected_rewards), 3)
    else:
        avg = 0.0
    print(
        f"# SUMMARY: {sum(outcomes)}/{len(selected)} tasks solved | avg_reward={avg}",
        file=sys.stderr, flush=True,
    )


if __name__ == "__main__":
    main()