#!/usr/bin/env python3
"""
Baseline inference script.

Runs an LLM agent on all 3 tasks using the OpenAI API.

Usage: python baseline/run_baseline.py [--output json]
Requires: OPENAI_API_KEY environment variable.
"""

import asyncio
import json
import os
import sys
from pathlib import Path

# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from code_debug_env.client import CodeDebugEnv
from code_debug_env.models import Action

try:
    from openai import AsyncOpenAI
except ImportError:
    print("Please install openai: pip install openai", file=sys.stderr)
    sys.exit(1)

BASE_URL = os.getenv("OPENENV_URL", "http://127.0.0.1:8000")
API_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("OPENENV_MODEL", "gpt-4o-mini")

_client = None


def get_openai_client():
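    """Lazily create one shared AsyncOpenAI client; return None when
    OPENAI_API_KEY is unset so callers can fall back gracefully."""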
    global _client
    if _client is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            return None
        _client = AsyncOpenAI(
            api_key=api_key,
            base_url=API_BASE_URL,
        )
    return _client


async def openai_agent(observation) -> Action:
    """Uses the LLM to suggest a code fix."""
| prompt = f"""You are an expert Python debugger. Your task is to fix the buggy code below. | |
| Task Description: {observation.task_description} | |
| Buggy Code: | |
| ```python | |
| {observation.buggy_code} | |
| ``` | |
| Test Results so far: | |
| {[[t.name, t.passed, t.error] for t in observation.test_results]} | |
| Passed {observation.passed} out of {observation.total} tests. | |
| Provide ONLY a valid JSON object matching this schema: | |
| {{ | |
| "patch": "The FULL python function as a string, with the bugs fixed", | |
| "task_id": "{observation.task_id}", | |
| "think": "Your chain-of-thought reasoning before patching (important!)" | |
| }} | |
| """ | |
    client = get_openai_client()
    if not client:
        return Action(
            patch=observation.buggy_code,
            task_id=observation.task_id,
            think="Skipping LLM call: OPENAI_API_KEY not set.",
        )
    try:
        # Only request JSON mode for models expected to support it; passing
        # response_format=None (instead of omitting the argument entirely)
        # can be rejected by the API.
        extra = {}
        if "gpt-4" in MODEL_NAME or "gpt-oss" in MODEL_NAME:
            extra["response_format"] = {"type": "json_object"}
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            **extra,
        )
        content = response.choices[0].message.content
        data = json.loads(content)
        return Action(
            patch=data["patch"],
            task_id=observation.task_id,
            think=data.get("think", "Applied fix based on test errors."),
        )
    except Exception as e:
        print(f"LLM Error: {e}", file=sys.stderr)
        # Fall back to returning the original code to avoid crashing the loop.
        return Action(
            patch=observation.buggy_code,
            task_id=observation.task_id,
            think="Failed to generate patch.",
        )


async def evaluate_task(env, task_id: str) -> dict:
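    """Run the agent on one task for up to 10 steps and report the best score."""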
    result = await env.reset(task_id=task_id)
    obs = result.observation
    best_score = 0.0
    for step in range(10):
        action = await openai_agent(obs)
        result = await env.step(action)
        best_score = max(best_score, result.observation.score)
        obs = result.observation
        if obs.done:
            break
    return {"task_id": task_id, "best_score": round(best_score, 4), "steps": step + 1}


async def main(output_format: str = "table"):
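    """Evaluate every task, then emit results as JSON or a stderr table."""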
    if not os.getenv("OPENAI_API_KEY"):
        print("Warning: OPENAI_API_KEY not set. LLM calls will fail.", file=sys.stderr)
    results = []
    async with CodeDebugEnv(base_url=BASE_URL) as env:
        for task_id in ["task_easy", "task_medium", "task_hard"]:
            res = await evaluate_task(env, task_id)
            results.append(res)
    if output_format == "json":
        print(json.dumps({"baseline_results": results, "agent": "openai_api"}))
    else:
        print("\n=== Baseline Results ===", file=sys.stderr)
        for r in results:
            print(f" {r['task_id']:15s} score={r['best_score']:.3f} steps={r['steps']}", file=sys.stderr)
        print(f"\n avg score: {sum(r['best_score'] for r in results) / len(results):.3f}", file=sys.stderr)


if __name__ == "__main__":
    # Crude flag handling: any bare "json" argument (as in "--output json")
    # selects JSON output; anything else falls back to the table view.
    output = "json" if "json" in sys.argv else "table"
    asyncio.run(main(output))
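
# Illustrative table output (scores hypothetical), written to stderr:
#
#   === Baseline Results ===
#    task_easy       score=1.000 steps=2
#    task_medium     score=0.750 steps=10
#    task_hard       score=0.500 steps=10
#
#    avg score: 0.750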