#!/usr/bin/env python3
"""
Baseline inference script.
Runs an LLM agent on all 3 tasks using OpenAI API.
Usage: python baseline/run_baseline.py [--output json]
Requires: OPENAI_API_KEY environment variable.
"""
import asyncio
import sys
import json
import os
from pathlib import Path
# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from code_debug_env.client import CodeDebugEnv
from code_debug_env.models import Action
try:
from openai import AsyncOpenAI
except ImportError:
print("Please install openai: pip install openai", file=sys.stderr)
sys.exit(1)
# Base URL of the running CodeDebugEnv HTTP server.
BASE_URL = os.getenv("OPENENV_URL", "http://127.0.0.1:8000")
# OpenAI-compatible API endpoint; override to target a proxy or local server.
API_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
# Model name passed to the chat completions API.
MODEL_NAME = os.getenv("OPENENV_MODEL", "gpt-4o-mini")
# Lazily created AsyncOpenAI client singleton; see get_openai_client().
_client = None
def get_openai_client():
    """Return the process-wide AsyncOpenAI client, creating it on first use.

    Returns None when OPENAI_API_KEY is not set, leaving the singleton
    unassigned so a later call can retry after the key appears.
    """
    global _client
    if _client is not None:
        return _client
    key = os.getenv("OPENAI_API_KEY")
    if not key:
        return None
    _client = AsyncOpenAI(api_key=key, base_url=API_BASE_URL)
    return _client
async def openai_agent(observation) -> Action:
    """Ask the LLM to suggest a code fix for the current observation.

    Builds a debugging prompt from the task description, the buggy code and
    the latest test results, then parses the model's JSON reply into an
    Action. On any failure (missing API key, API error, malformed JSON,
    missing "patch" key) it degrades to a no-op patch echoing the original
    buggy code so the evaluation loop never crashes.
    """
    prompt = f"""You are an expert Python debugger. Your task is to fix the buggy code below.
Task Description: {observation.task_description}
Buggy Code:
```python
{observation.buggy_code}
```
Test Results so far:
{[[t.name, t.passed, t.error] for t in observation.test_results]}
Passed {observation.passed} out of {observation.total} tests.
Provide ONLY a valid JSON object matching this schema:
{{
"patch": "The FULL python function as a string, with the bugs fixed",
"task_id": "{observation.task_id}",
"think": "Your chain-of-thought reasoning before patching (important!)"
}}
"""
    client = get_openai_client()
    if not client:
        return Action(
            patch=observation.buggy_code,
            task_id=observation.task_id,
            think="Skipping LLM call: OPENAI_API_KEY not set.",
        )
    # Build kwargs conditionally: the SDK/API rejects an explicit
    # response_format=None, so the key must be omitted entirely for models
    # that don't support JSON mode (the original passed None and would fail).
    request_kwargs = {
        "model": MODEL_NAME,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.2,
    }
    if "gpt-4" in MODEL_NAME or "gpt-oss" in MODEL_NAME:
        request_kwargs["response_format"] = {"type": "json_object"}
    try:
        response = await client.chat.completions.create(**request_kwargs)
        content = response.choices[0].message.content
        data = json.loads(content)
        return Action(
            patch=data["patch"],
            task_id=observation.task_id,
            think=data.get("think", "Applied fix based on test errors."),
        )
    except Exception as e:
        print(f"LLM Error: {e}", file=sys.stderr)
        # fallback to returning original code to avoid crashing the loop
        return Action(
            patch=observation.buggy_code,
            task_id=observation.task_id,
            think="Failed to generate patch.",
        )
async def evaluate_task(env, task_id: str) -> dict:
    """Run the agent on one task for up to 10 steps.

    Returns a summary dict with the task id, the best score observed over
    the episode (rounded to 4 places) and the number of steps taken.
    """
    reset_result = await env.reset(task_id=task_id)
    obs = reset_result.observation
    best = 0.0
    steps_taken = 0
    for steps_taken in range(1, 11):
        step_result = await env.step(await openai_agent(obs))
        obs = step_result.observation
        if obs.score > best:
            best = obs.score
        if obs.done:
            break
    return {"task_id": task_id, "best_score": round(best, 4), "steps": steps_taken}
async def main(output_format: str = "table"):
    """Evaluate all three tasks and report results as JSON or a table.

    JSON goes to stdout (machine-readable); the table and the API-key
    warning go to stderr so they never pollute piped JSON output.
    """
    if not os.getenv("OPENAI_API_KEY"):
        print("Warning: OPENAI_API_KEY not set. LLM calls will fail.", file=sys.stderr)
    async with CodeDebugEnv(base_url=BASE_URL) as env:
        results = [
            await evaluate_task(env, tid)
            for tid in ("task_easy", "task_medium", "task_hard")
        ]
    if output_format == "json":
        print(json.dumps({"baseline_results": results, "agent": "openai_api"}))
        return
    print("\n=== Baseline Results ===", file=sys.stderr)
    for r in results:
        print(f" {r['task_id']:15s} score={r['best_score']:.3f} steps={r['steps']}", file=sys.stderr)
    avg = sum(r["best_score"] for r in results) / len(results)
    print(f"\n avg score: {avg:.3f}", file=sys.stderr)
if __name__ == "__main__":
    # Emit JSON when "json" appears anywhere in the argv, else a table.
    asyncio.run(main("json" if "json" in sys.argv else "table"))
|