Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| inference.py - Code Debug Environment Baseline Agent | |
| Required env vars: API_BASE_URL, MODEL_NAME, HF_TOKEN | |
| Usage: | |
| python inference.py | |
| python inference.py --url https://Souravdanyal-code-debug-env.hf.space | |
| python inference.py --difficulty easy | |
| STDOUT FORMAT (strictly required by evaluator - plaintext): | |
| [START] task=<id> env=<benchmark> model=<model> | |
| [STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null> | |
| [END] success=<true|false> steps=<n> score=<0.00> rewards=<r1,r2,...> | |
| """ | |
| import os, sys, json, time, argparse, requests, re | |
| from openai import OpenAI | |
| from typing import List, Optional | |
# Pull variables from a local .env file when python-dotenv is available.
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass  # python-dotenv not installed; only the process environment is used

# ── Config ──────────────────────────────────────────────────────────────────
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")


def _resolve_api_key() -> tuple:
    """Return ``(key, source_name)`` for the first non-empty key variable.

    Variables are tried in the order HF_TOKEN, API_KEY, hf_token. When none
    holds a non-empty value, the result pairs whatever ``hf_token`` held
    (normally None) with the name "hf_token".
    """
    for name in ("HF_TOKEN", "API_KEY", "hf_token"):
        key = os.getenv(name)
        if key:
            break
    return key, name


HF_TOKEN, HF_TOKEN_SOURCE = _resolve_api_key()
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
ENV_URL = os.getenv("ENV_URL")
BENCHMARK = "code-debug-env"
MAX_STEPS = 5
SUCCESS_SCORE_THRESHOLD = 0.5

# A placeholder key keeps client construction from failing without credentials;
# main() exits before any request is made when HF_TOKEN is unset.
client = OpenAI(api_key=HF_TOKEN or "dummy", base_url=API_BASE_URL)
| # ββ Logging β STRICT PLAINTEXT FORMAT ββββββββββββββββββββββββββββββββββββββββ | |
| def _format_bool(value: bool) -> str: | |
| return "true" if value else "false" | |
| def _normalize_token(value: str) -> str: | |
| return re.sub(r"\s+", " ", str(value)).strip() | |
| def _format_error(error: Optional[str]) -> str: | |
| if error is None: | |
| return "null" | |
| text = str(error).replace("\r", "\\r").replace("\n", "\\n") | |
| return text if text else "null" | |
| def _format_rewards(rewards: List[float]) -> str: | |
| return ",".join(f"{round(r, 2):.2f}" for r in rewards) | |
def log_start(task_id: str, env: str, model: str) -> None:
    """Print the ``[START]`` line for an episode, whitespace-normalising every field."""
    task_id, env, model = (_normalize_token(field) for field in (task_id, env, model))
    message = f"[START] task={task_id} env={env} model={model}"
    print(message, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line; the reward is rendered with exactly two decimals."""
    action_text = _normalize_token(action)
    reward_text = f"{round(reward, 2):.2f}"
    done_text = _format_bool(done)
    error_text = _format_error(error)
    print(
        f"[STEP] step={step} action={action_text} reward={reward_text} "
        f"done={done_text} error={error_text}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` line with the final score and every step reward."""
    success_text = _format_bool(success)
    score_text = f"{round(score, 2):.2f}"
    rewards_text = _format_rewards(rewards)
    line = (
        f"[END] success={success_text} steps={steps} score={score_text} "
        f"rewards={rewards_text}"
    )
    print(line, flush=True)
# ── Env client ──────────────────────────────────────────────────────────────
def env_reset(url: str, difficulty: str) -> dict:
    """Start a new episode via POST ``/reset`` and return the decoded JSON body.

    Raises ``requests.HTTPError`` for a non-2xx status; the request times out
    after 30 seconds.
    """
    endpoint = f"{url}/reset"
    body = {"difficulty": difficulty}
    response = requests.post(endpoint, json=body, timeout=30)
    response.raise_for_status()
    return response.json()
def env_step(url: str, fixed_code: str, explanation: Optional[str] = None) -> dict:
    """Submit a candidate fix via POST ``/step`` and return the decoded JSON reply.

    ``explanation`` is sent only when it is truthy. Raises ``requests.HTTPError``
    for a non-2xx status; the request times out after 30 seconds.
    """
    optional_fields = {"explanation": explanation} if explanation else {}
    body = {"fixed_code": fixed_code, **optional_fields}
    response = requests.post(f"{url}/step", json=body, timeout=30)
    response.raise_for_status()
    return response.json()
# ── LLM ─────────────────────────────────────────────────────────────────────
# System prompt sent with every call_llm request. It pins the JSON reply shape
# that _parse_llm_response extracts and lists bug patterns seen in the tasks.
# The dashes below were previously mis-encoded as "β" and are restored here.
SYSTEM_PROMPT = """You are an expert Python debugging agent.
RESPONSE FORMAT — JSON only, no markdown fences, no extra text:
{"fixed_code": "<complete Python function with all imports>", "explanation": "<for hard tasks only>"}
RULES:
- Return the COMPLETE function including all imports (e.g. from collections import deque)
- fixed_code must be valid, executable Python
- For hard tasks: explanation MUST mention the algorithmic concepts from the instructions
COMMON BUG PATTERNS — memorize these:
- RIGHT rotate list by k: lst[-k:] + lst[:-k] (NOT lst[k:] + lst[:k] which is LEFT rotate)
- LEFT rotate list by k: lst[k:] + lst[:k]
- BFS/graph missing visited: add visited=set([start]) before queue, check before appending
- 0/1 Knapsack: iterate BACKWARD range(capacity, weight-1, -1) NOT forward
- Binary search boundary: often return high not low, or initial high=n//2 not n
- Wrong operator: target-n not target+n for complement
- Off-by-one: lst[1] for second element not lst[2]
IMPORTANT: If feedback shows TimeoutError, you have infinite loop. Add visited set.
IMPORTANT: If Expected shows right-rotated list, use lst[-k:] + lst[:-k].
"""
| def _parse_llm_response(raw: str, buggy_code: str) -> dict: | |
| """Robustly parse LLM response handling control chars and malformed JSON.""" | |
| # Remove markdown fences | |
| if "```json" in raw: | |
| raw = raw.split("```json")[1].split("```")[0].strip() | |
| elif "```" in raw: | |
| parts = raw.split("```") | |
| if len(parts) >= 2: | |
| raw = parts[1].strip() | |
| if raw.startswith("json"): | |
| raw = raw[4:].strip() | |
| # Find JSON boundaries | |
| start = raw.find("{") | |
| end = raw.rfind("}") + 1 | |
| if start >= 0 and end > start: | |
| raw = raw[start:end] | |
| # Try direct parse | |
| try: | |
| parsed = json.loads(raw) | |
| return { | |
| "fixed_code": parsed.get("fixed_code", ""), | |
| "explanation": parsed.get("explanation"), | |
| } | |
| except json.JSONDecodeError: | |
| pass | |
| # Fix literal control characters inside JSON strings | |
| try: | |
| fixed = re.sub(r'(?<!\\)\n', r'\\n', raw) | |
| fixed = re.sub(r'(?<!\\)\t', r'\\t', fixed) | |
| fixed = re.sub(r'(?<!\\)\r', r'\\r', fixed) | |
| parsed = json.loads(fixed) | |
| code = parsed.get("fixed_code", "") | |
| if "\\n" in code: | |
| code = code.replace("\\n", "\n").replace("\\t", "\t") | |
| return {"fixed_code": code, "explanation": parsed.get("explanation")} | |
| except json.JSONDecodeError: | |
| pass | |
| # Last resort: regex extraction | |
| code_match = re.search(r'"fixed_code"\s*:\s*"((?:[^"\\]|\\.)*)"', raw, re.DOTALL) | |
| exp_match = re.search(r'"explanation"\s*:\s*"((?:[^"\\]|\\.)*)"', raw, re.DOTALL) | |
| if code_match: | |
| code = code_match.group(1).replace("\\n", "\n").replace("\\t", "\t") | |
| exp = exp_match.group(1).replace("\\n", "\n") if exp_match else None | |
| return {"fixed_code": code, "explanation": exp} | |
| # Complete fallback | |
| return {"fixed_code": buggy_code, "explanation": None} | |
def _build_user_prompt(
    buggy_code: str,
    instructions: str,
    difficulty: str,
    feedback: Optional[str],
    attempt: int,
    prev_code: Optional[str],
) -> str:
    """Assemble the user message for one repair attempt.

    Always states the difficulty, the instructions and the buggy code. From the
    second attempt on, and only when feedback is non-empty, it appends that
    feedback plus the previously submitted code. For ``hard`` tasks it appends
    the concepts listed after "Mention" in the instructions (when present) and a
    reminder that the explanation is graded.
    """
    content = (
        f"Difficulty: {difficulty}\n"
        f"Instructions: {instructions}\n\n"
        f"Buggy code:\n```python\n{buggy_code}\n```\n"
    )
    if feedback and attempt > 1:
        content += (
            f"\nPREVIOUS FIX FAILED. Feedback:\n{feedback}\n\n"
            f"Your previous code:\n```python\n{prev_code or ''}\n```\n"
            "ANALYZE THE FEEDBACK CAREFULLY:\n"
            "- Look at Input/Expected/Got for each failing test\n"
            "- If Got shows wrong rotation direction: use lst[-k:] + lst[:-k] for RIGHT rotate\n"
            "- If TimeoutError: add visited=set([start]) before queue in graph code\n"
            "- Try a COMPLETELY DIFFERENT fix.\n"
        )
    if difficulty == "hard":
        # instructions may be null in a malformed observation; re.search needs a str.
        hint_match = re.search(r'[Mm]ention[:\s]+([^.]+?)(?:\.|$)', instructions or "")
        if hint_match:
            hints = hint_match.group(1).strip()
            content += f"\nFor explanation, you MUST mention these concepts: {hints}\n"
        # NOTE(review): the scraped source lost indentation; this reminder is
        # assumed to apply to every hard task, not only when hints were found.
        # The dash was previously mis-encoded as "β".
        content += "Explanation counts for 30% of reward — make it detailed and specific.\n"
    return content


def call_llm(
    buggy_code: str,
    instructions: str,
    difficulty: str,
    feedback: Optional[str] = None,
    attempt: int = 1,
    prev_code: Optional[str] = None,
) -> dict:
    """Ask the model for a fix and return ``{"fixed_code", "explanation"}``.

    Sampling temperature is 0.1 on the first attempt and 0.4 on retries. Any
    exception raised while calling the API or parsing its reply is reported on
    stderr. In that case the result falls back to resubmitting *buggy_code*
    with no explanation, so one failed call never aborts the episode.
    """
    content = _build_user_prompt(buggy_code, instructions, difficulty, feedback, attempt, prev_code)
    try:
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": content},
            ],
            max_tokens=1500,
            temperature=0.1 if attempt == 1 else 0.4,
        )
        # The SDK types message.content as Optional[str]. Treat None as an empty
        # reply so parsing falls back to buggy_code instead of failing on .strip().
        raw = (resp.choices[0].message.content or "").strip()
        return _parse_llm_response(raw, buggy_code)
    except Exception as e:
        print(f"# LLM error: {e}", file=sys.stderr)
        return {"fixed_code": buggy_code, "explanation": None}
# ── Episode ─────────────────────────────────────────────────────────────────
def _coerce_reward(value: object) -> float:
    """Convert a server-supplied reward to float; null or non-numeric becomes 0.0."""
    if value is None:
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def run_episode(env_url: str, difficulty: str) -> tuple:
    """Run one full episode and print its [START], [STEP] and [END] lines.

    Makes up to MAX_STEPS repair attempts and stops early when the environment
    reports ``done``. The [END] score is the best step reward clamped to
    [0, 1]. Success means some step reached a reward of 1.0, or the score
    reached SUCCESS_SCORE_THRESHOLD.

    Returns:
        ``(success, steps_taken, rewards)``.

    Raises:
        Whatever ``env_reset`` raises; this happens before any log line is printed.
    """
    data = env_reset(env_url, difficulty)
    obs = data["observation"]
    task_id = obs["task_id"]
    buggy_code = obs["buggy_code"]
    instructions = obs["instructions"]
    log_start(task_id, BENCHMARK, MODEL_NAME)

    rewards: List[float] = []
    steps_taken = 0
    success = False
    last_feedback = None
    last_code = None
    try:
        for attempt in range(1, MAX_STEPS + 1):
            steps_taken = attempt
            action = call_llm(buggy_code, instructions, difficulty, last_feedback, attempt, last_code)
            code = action.get("fixed_code") or ""
            last_code = code
            reward: float = 0.0
            done: bool = False
            step_error: Optional[str] = None
            try:
                result = env_step(env_url, code, action.get("explanation"))
                # The server may send "reward": null. Coerce so round() in the
                # log line and max() in scoring can never crash the run.
                reward = _coerce_reward(result.get("reward"))
                done = bool(result.get("done", False))
                obs_r = result.get("observation", {})
                if isinstance(obs_r, dict):
                    last_feedback = obs_r.get("feedback", "")
                    step_error = obs_r.get("last_action_error") or obs_r.get("error")
            except Exception as e:
                # A failed step is logged and retried, not fatal.
                step_error = str(e)
            log_step(attempt, f"fix_{difficulty}_attempt{attempt}", reward, done, step_error)
            rewards.append(reward)
            if reward >= 1.0:
                success = True
            if done:
                break
    finally:
        # Always emit [END] once [START] has been printed, even on an exception.
        score = max(rewards) if rewards else 0.0
        score = min(max(score, 0.0), 1.0)
        success = success or (score >= SUCCESS_SCORE_THRESHOLD)
        log_end(success, steps_taken, score, rewards)
    return success, steps_taken, rewards
# ── Main ────────────────────────────────────────────────────────────────────
def main():
    """CLI entry point: check config and env health, run episodes, print a summary.

    Exits with status 1 in three cases: no API key is configured, the /health
    check fails, or any episode raised an exception. An exception in one
    episode no longer prevents the remaining difficulties from running.
    """
    parser = argparse.ArgumentParser(description="Code Debug Environment Baseline Agent")
    parser.add_argument("--url", default=ENV_URL or "http://localhost:7860")
    parser.add_argument(
        "--difficulty", default=None,
        choices=["easy", "medium", "hard", "all"],
    )
    args = parser.parse_args()
    url = args.url.rstrip("/")

    if not HF_TOKEN:
        print(
            "# Missing API key. Set HF_TOKEN (or API_KEY / lowercase hf_token).",
            file=sys.stderr, flush=True,
        )
        sys.exit(1)
    print(f"# Using API key from {HF_TOKEN_SOURCE}", file=sys.stderr, flush=True)

    # Health check
    try:
        requests.get(f"{url}/health", timeout=10).raise_for_status()
    except Exception as e:
        print(f"# Health check failed: {e}", file=sys.stderr, flush=True)
        sys.exit(1)
    else:
        print(f"# Environment healthy at {url}", file=sys.stderr, flush=True)

    diffs = ["easy", "medium", "hard"] if args.difficulty in (None, "all") else [args.difficulty]
    all_rewards: List[float] = []
    successes: List[bool] = []
    any_episode_failed = False
    for d in diffs:
        try:
            ok, _, rewards = run_episode(url, d)
        except Exception as e:
            # e.g. /reset returned an HTTP error: record the failure and keep
            # going so the other difficulties still get evaluated.
            print(f"# Episode '{d}' failed: {e}", file=sys.stderr, flush=True)
            any_episode_failed = True
            ok, rewards = False, []
        all_rewards.extend(rewards)
        successes.append(ok)
        time.sleep(0.5)

    avg = round(sum(all_rewards) / len(all_rewards), 3) if all_rewards else 0.0
    print(
        f"# SUMMARY: {sum(successes)}/{len(diffs)} tasks solved | avg_reward={avg}",
        file=sys.stderr, flush=True,
    )
    if any_episode_failed:
        # Preserve the old nonzero exit status that an unhandled exception produced.
        sys.exit(1)


if __name__ == "__main__":
    main()