"""Live API probe for the DebugZero proposer/solver self-play environment.

Runs a few episodes against a running DebugZero environment server, using an
OpenAI-compatible chat API to generate proposer (bug-injecting) and solver
(bug-fixing) actions, then prints a summary of the results.
"""

import asyncio
import inspect
import json
import os
import sys
import textwrap
from typing import Any, Optional

# Make the parent directory importable so the local client/models modules resolve.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from dotenv import load_dotenv
from openai import OpenAI

from client import DebugzeroEnv
from models import DebugzeroAction

load_dotenv()

# Configuration, all overridable via environment variables.
API_BASE_URL = os.getenv("API_BASE_URL") or os.getenv("OPENAI_BASE_URL", "https://openrouter.ai/api/v1")
MODEL_NAME = os.getenv("MODEL_NAME") or os.getenv("OPENAI_MODEL", "meta-llama/llama-3.1-8b-instruct")
API_KEY = os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN")
ENV_URL = os.getenv("DEBUGZERO_ENV_URL", "http://localhost:8000")
NUM_EPISODES = int(os.getenv("NUM_EPISODES", "6"))
MAX_STEPS = int(os.getenv("MAX_STEPS", "8"))
PROPOSER_TEMPERATURE = float(os.getenv("PROPOSER_TEMPERATURE", "0.7"))
SOLVER_TEMPERATURE = float(os.getenv("SOLVER_TEMPERATURE", "0.2"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "1024"))
BUG_FOCUS = os.getenv("DEBUGZERO_BUG_FOCUS")


def extract_python_code(text: str) -> str:
    """Strip a Markdown code fence from a model response, if present."""
    content = (text or "").strip()
    if content.startswith("```"):
        content = content.split("\n", 1)[-1]
    if content.endswith("```"):
        content = content.rsplit("\n", 1)[0]
    return content.strip()


def summarize_error(text: str, max_chars: int = 220) -> str:
    """Collapse whitespace and truncate an error message for one-line logging."""
    cleaned = " ".join(text.strip().split())
    if not cleaned:
        return "null"
    if len(cleaned) <= max_chars:
        return cleaned
    return cleaned[: max_chars - 3].rstrip() + "..."


def extract_env_error(result: Any) -> Optional[str]:
    """Pull a human-readable error out of a step result or its observation."""
    for attr in ("last_action_error", "error", "message"):
        if hasattr(result, attr):
            value = getattr(result, attr)
            if value:
                return str(value)
    obs = getattr(result, "observation", None)
    if obs is None:
        return None
    for attr in ("last_action_error", "error"):
        if hasattr(obs, attr):
            value = getattr(obs, attr)
            if value:
                return str(value)
    execution_result = getattr(obs, "execution_result", "")
    if isinstance(execution_result, str) and execution_result:
        if getattr(obs, "syntax_error", False):
            return summarize_error(execution_result)
        if execution_result.startswith("Unsafe import detected."):
            return execution_result
        if not getattr(obs, "tests_passed", False):
            return summarize_error(execution_result)
    return None


def compact_action_string(role: str, code: str) -> str:
    """Render an action as compact JSON for the prompt history."""
    return json.dumps({"role": role, "code": code}, separators=(",", ":"), ensure_ascii=False)


def build_prompt(obs_dict: dict[str, Any], history: list[str]) -> str:
    """Build the proposer or solver prompt from the current observation."""
    role = str(obs_dict.get("role_next", "proposer"))
    current_code = str(obs_dict.get("current_code", ""))
    execution_result = str(obs_dict.get("execution_result", ""))
    tests_passed = bool(obs_dict.get("tests_passed", False))
    syntax_error = bool(obs_dict.get("syntax_error", False))
    metadata = obs_dict.get("metadata", {}) or {}
    seed_id = metadata.get("seed_id", "unknown")
    history_block = "\n".join(history[-4:]) if history else "None"
    if role == "proposer":
        focus_line = ""
        if BUG_FOCUS:
            focus_line = f"- Focus specifically on the `{BUG_FOCUS}` mutation family.\n"
        instructions = textwrap.dedent(
            f"""
            You are the Proposer in a debugging self-play environment.
            Return a full Python function with exactly one small logical bug injected.

            Rules:
            - Keep the code valid Python.
            - Keep the same function signature.
            - Preserve the overall structure and formatting as much as possible.
            - Make exactly one small local behavioral change.
            - Avoid comments, explanations, markdown outside the code block, and broad rewrites.
            {focus_line}- Your goal is to make tests fail without creating a syntax error.
""" ).strip() else: instructions = textwrap.dedent( """ You are the Solver in a debugging self-play environment. Return the full fixed Python function. Rules: - Keep the code valid Python. - Keep the same function signature. - Make the smallest correct local fix you can. - Use the failure output to guide the repair. - Avoid comments, explanations, markdown outside the code block, and unrelated refactors. """ ).strip() return textwrap.dedent( f""" {instructions} Current environment state: - seed_id: {seed_id} - role_next: {role} - tests_passed: {tests_passed} - syntax_error: {syntax_error} Current code: ```python {current_code} ``` Execution result: {execution_result if execution_result else "None"} Previous actions: {history_block} Return only the full Python code inside triple backticks. """ ).strip() def get_model_code(client: OpenAI, obs_dict: dict[str, Any], history: list[str]) -> str: role = str(obs_dict.get("role_next", "proposer")) prompt = build_prompt(obs_dict, history) temperature = PROPOSER_TEMPERATURE if role == "proposer" else SOLVER_TEMPERATURE response = client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": "You are an expert Python coder."}, {"role": "user", "content": prompt}, ], temperature=temperature, max_tokens=MAX_TOKENS, ) return extract_python_code(response.choices[0].message.content or "") async def maybe_await(value: Any) -> Any: if inspect.isawaitable(value): return await value return value async def call_env_method(obj: Any, method_name: str, *args: Any) -> Any: method = getattr(obj, method_name) result = method(*args) return await maybe_await(result) async def make_env() -> Any: max_retries = 30 for attempt in range(max_retries): try: return DebugzeroEnv(base_url=ENV_URL) except Exception as exc: print( f"[SYSTEM ERROR] Env connection to {ENV_URL} failed (attempt {attempt + 1}/{max_retries}): {exc}", file=sys.stderr, flush=True, ) if attempt < max_retries - 1: await asyncio.sleep(5.0) else: raise def print_live_summary(metrics: dict[str, Any]) -> None: episodes = max(1, int(metrics["episodes"])) proposer_attempts = max(1, int(metrics["proposer_attempts"])) solver_attempts = max(1, int(metrics["solver_attempts"])) rewards = metrics["rewards"] average_reward = (sum(rewards) / len(rewards)) if rewards else 0.0 print("\n" + "=" * 80) print("Live API summary") print("=" * 80) print(f"Episode success rate: {metrics['episode_successes'] / episodes:.2%}") print(f"Proposer syntax rate: {metrics['proposer_syntax_errors'] / proposer_attempts:.2%}") print(f"Solver syntax rate: {metrics['solver_syntax_errors'] / solver_attempts:.2%}") print(f"Average step reward: {average_reward:.2f}") print(f"Average steps/episode: {metrics['total_steps'] / episodes:.2f}") print(f"Representative success: {metrics['representative_success']}") print(f"Representative failure: {metrics['representative_failure']}") async def run_live_api_probe() -> dict[str, Any] | None: if not API_KEY: print("Skipping live API probe: OPENAI_API_KEY/API_KEY is not set.") return None if not MODEL_NAME: print("Skipping live API probe: OPENAI_MODEL/MODEL_NAME is not set.") return None client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL) env = await make_env() metrics = { "episodes": NUM_EPISODES, "episode_successes": 0, "proposer_attempts": 0, "solver_attempts": 0, "proposer_syntax_errors": 0, "solver_syntax_errors": 0, "rewards": [], "total_steps": 0, "representative_success": None, "representative_failure": None, } print("=" * 80) print("Live API probe") print("=" * 80) 
print(f"API base URL: {API_BASE_URL}") print(f"Model: {MODEL_NAME}") print(f"Env URL: {ENV_URL}") try: for episode in range(1, NUM_EPISODES + 1): result = await call_env_method(env, "reset") obs = getattr(result, "observation", None) done = bool(getattr(result, "done", False)) history: list[str] = [] success = False seed_id = "unknown" if obs is not None: metadata = getattr(obs, "metadata", {}) or {} seed_id = metadata.get("seed_id", "unknown") print(f"\nEpisode {episode}/{NUM_EPISODES} | seed={seed_id}") for step in range(1, MAX_STEPS + 1): if done or obs is None: break obs_dict = obs.model_dump() if hasattr(obs, "model_dump") else obs.dict() role = str(obs_dict.get("role_next", "proposer")) if role == "proposer": metrics["proposer_attempts"] += 1 else: metrics["solver_attempts"] += 1 try: code = await asyncio.to_thread(get_model_code, client, obs_dict, history) except Exception as exc: print(f"[SYSTEM ERROR] Model generation failed: {exc}", file=sys.stderr, flush=True) code = str(obs_dict.get("current_code", "")) action = DebugzeroAction(role=role, code=code) action_str = compact_action_string(role, code) result = await call_env_method(env, "step", action) obs = getattr(result, "observation", None) done = bool(getattr(result, "done", False)) reward = float(getattr(result, "reward", 0.0) or 0.0) error = extract_env_error(result) metrics["rewards"].append(reward) metrics["total_steps"] += 1 if obs is not None and getattr(obs, "syntax_error", False): if role == "proposer": metrics["proposer_syntax_errors"] += 1 else: metrics["solver_syntax_errors"] += 1 print( f" step={step} role={role} reward={reward:.2f} done={str(done).lower()} error={error or 'null'}", flush=True, ) history.append(f"Step {step}: {action_str} -> reward {reward:.2f}") if done and obs is not None: success = bool(getattr(obs, "tests_passed", False)) and not bool( getattr(obs, "syntax_error", False) ) if success: metrics["episode_successes"] += 1 if metrics["representative_success"] is None: metrics["representative_success"] = { "seed_id": getattr(obs, "metadata", {}).get("seed_id", "unknown"), "steps": step, "reward": reward, } elif metrics["representative_failure"] is None: metrics["representative_failure"] = { "seed_id": getattr(obs, "metadata", {}).get("seed_id", "unknown"), "steps": step, "execution_result": getattr(obs, "execution_result", ""), } break if not success and metrics["representative_failure"] is None: failure_seed = seed_id failure_output = getattr(obs, "execution_result", "") if obs is not None else "" metrics["representative_failure"] = { "seed_id": failure_seed, "steps": min(MAX_STEPS, len(history)), "execution_result": failure_output, } return metrics finally: await call_env_method(env, "close") async def main() -> None: metrics = await run_live_api_probe() if metrics is not None: print_live_summary(metrics) if __name__ == "__main__": asyncio.run(main())