#!/usr/bin/env python3
"""
Baseline inference script.
Runs an LLM agent on all three debugging tasks via the OpenAI API.
Usage: python baseline/run_baseline.py [--output json]
Requires: the OPENAI_API_KEY environment variable. OPENENV_URL, OPENAI_BASE_URL,
and OPENENV_MODEL optionally override the environment server URL, the API
endpoint, and the model name.
"""
import asyncio
import sys
import json
import os
from pathlib import Path
# Add the repo root to sys.path so the code_debug_env package is importable
# when running from a source checkout.
sys.path.insert(0, str(Path(__file__).parent.parent))
from code_debug_env.client import CodeDebugEnv
from code_debug_env.models import Action
try:
    from openai import AsyncOpenAI
except ImportError:
    print("Please install openai: pip install openai", file=sys.stderr)
    sys.exit(1)
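# Configuration, overridable via environment variables: the CodeDebugEnv server
# URL, the OpenAI-compatible API endpoint, and the model to query.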
BASE_URL = os.getenv("OPENENV_URL", "http://127.0.0.1:8000")
API_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("OPENENV_MODEL", "gpt-4o-mini")
_client = None

def get_openai_client():
    """Lazily create a shared AsyncOpenAI client; returns None if no API key is set."""
    global _client
    if _client is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            return None
        _client = AsyncOpenAI(
            api_key=api_key,
            base_url=API_BASE_URL,
        )
    return _client

async def openai_agent(observation) -> Action:
    """Ask the LLM to suggest a code fix for the current observation."""
    prompt = f"""You are an expert Python debugger. Your task is to fix the buggy code below.

Task Description: {observation.task_description}

Buggy Code:
```python
{observation.buggy_code}
```

Test Results so far:
{[[t.name, t.passed, t.error] for t in observation.test_results]}
Passed {observation.passed} out of {observation.total} tests.

Provide ONLY a valid JSON object matching this schema:
{{
    "patch": "The FULL python function as a string, with the bugs fixed",
    "task_id": "{observation.task_id}",
    "think": "Your chain-of-thought reasoning before patching (important!)"
}}
"""
    client = get_openai_client()
    if not client:
        return Action(
            patch=observation.buggy_code,
            task_id=observation.task_id,
            think="Skipping LLM call: OPENAI_API_KEY not set.",
        )
    try:
        # Only request JSON mode from models expected to support it.
        kwargs = {}
        if "gpt-4" in MODEL_NAME or "gpt-oss" in MODEL_NAME:
            kwargs["response_format"] = {"type": "json_object"}
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            **kwargs,
        )
        content = response.choices[0].message.content
        data = json.loads(content)
        return Action(
            patch=data["patch"],
            task_id=observation.task_id,
            think=data.get("think", "Applied fix based on test errors."),
        )
    except Exception as e:
        print(f"LLM Error: {e}", file=sys.stderr)
        # Fall back to returning the original code so the loop keeps running.
        return Action(
            patch=observation.buggy_code,
            task_id=observation.task_id,
            think="Failed to generate patch.",
        )

async def evaluate_task(env, task_id: str) -> dict:
    """Run up to 10 agent steps on one task and report the best score reached."""
    result = await env.reset(task_id=task_id)
    obs = result.observation
    best_score = 0.0
    for step in range(10):
        action = await openai_agent(obs)
        result = await env.step(action)
        best_score = max(best_score, result.observation.score)
        obs = result.observation
        if obs.done:
            break
    return {"task_id": task_id, "best_score": round(best_score, 4), "steps": step + 1}

async def main(output_format: str = "table"):
    if not os.getenv("OPENAI_API_KEY"):
        print("Warning: OPENAI_API_KEY not set. LLM calls will fail.", file=sys.stderr)
    results = []
    async with CodeDebugEnv(base_url=BASE_URL) as env:
        for task_id in ["task_easy", "task_medium", "task_hard"]:
            res = await evaluate_task(env, task_id)
            results.append(res)
    if output_format == "json":
        print(json.dumps({"baseline_results": results, "agent": "openai_api"}))
    else:
        print("\n=== Baseline Results ===", file=sys.stderr)
        for r in results:
            print(f" {r['task_id']:15s} score={r['best_score']:.3f} steps={r['steps']}", file=sys.stderr)
        print(f"\n avg score: {sum(r['best_score'] for r in results) / len(results):.3f}", file=sys.stderr)

if __name__ == "__main__":
    output = "json" if "json" in sys.argv else "table"
    asyncio.run(main(output))