"""OpenAI-compatible action planner backed by the Hugging Face router."""
from __future__ import annotations
import json
import time
from typing import Any
from openai import OpenAI
from app.models.inference import AgentDecision, InferenceConfig
from app.utils.runtime import compact_text, observation_attr, suppress_output
# Action names the controller may emit. "edit_code" is listed but is reserved
# for the deterministic repair policy: _parse_action remaps it to "run_tests"
# and the system prompt never offers it to the model.
ALLOWED_ACTIONS = {"analyze_code", "edit_code", "run_tests", "submit_solution"}
class OpenAIActionPlanner:
    """Ask an OpenAI-compatible model for the next safe environment action.

    Wraps an OpenAI-compatible chat endpoint (the Hugging Face router) and
    degrades to a deterministic ``run_tests`` fallback whenever the client
    is unconfigured, the API keeps failing, or the reply cannot be parsed.
    """

    def __init__(self, config: InferenceConfig) -> None:
        """Store *config* and build a client only when a token is present.

        Without ``hf_token`` the client stays ``None`` and every call to
        :meth:`propose_action` returns the deterministic fallback instead
        of raising at construction time.
        """
        self.config = config
        self.client = OpenAI(base_url=config.api_base_url, api_key=config.hf_token) if config.hf_token else None

    def propose_action(self, observation: Any) -> AgentDecision:
        """Return the model's proposed next action for *observation*.

        Retries the API call up to ``config.max_retries`` extra times with
        linear backoff; on exhaustion or any unrecoverable failure it
        returns a fallback ``run_tests`` decision carrying the error text.
        """
        if self.client is None:
            return AgentDecision(action_type="run_tests", source="fallback", error="HF_TOKEN missing")
        prompt = self._build_prompt(observation)
        for attempt in range(self.config.max_retries + 1):
            try:
                # suppress_output hides SDK chatter so the controller's own
                # stdout/stderr stays deterministic.
                with suppress_output():
                    response = self.client.chat.completions.create(
                        model=self.config.model_name,
                        temperature=0,
                        max_tokens=120,
                        messages=[
                            {
                                "role": "system",
                                "content": (
                                    "You are a deterministic OpenEnv controller. "
                                    "Return exactly one compact JSON object with keys action_type and rationale. "
                                    "Allowed action_type values: analyze_code, run_tests, submit_solution. "
                                    "Never emit markdown."
                                ),
                            },
                            {"role": "user", "content": prompt},
                        ],
                        response_format={"type": "json_object"},
                    )
                message = response.choices[0].message.content or ""
                return self._parse_action(message)
            except Exception as exc:
                if attempt >= self.config.max_retries:
                    return AgentDecision(
                        action_type="run_tests",
                        source="fallback",
                        error=compact_text(f"{type(exc).__name__}: {exc}", default="LLM failure"),
                    )
                # Linear backoff: 0.2s, 0.4s, 0.6s, ...
                time.sleep(0.2 * (attempt + 1))
        return AgentDecision(action_type="run_tests", source="fallback", error="LLM retries exhausted")

    def _build_prompt(self, observation: Any) -> str:
        """Render the observation into a compact, size-bounded user prompt."""
        return (
            f"Task ID: {compact_text(observation_attr(observation, 'task_id', ''), default='unknown')}\n"
            f"Description: {compact_text(observation_attr(observation, 'task_description', ''), default='none', limit=400)}\n"
            f"Current score: {float(observation_attr(observation, 'score', 0.01) or 0.01):.4f}\n"
            f"Errors: {compact_text(observation_attr(observation, 'errors', ''), default='none', limit=300)}\n"
            f"Test feedback: {compact_text(observation_attr(observation, 'test_results', ''), default='none', limit=300)}\n"
            f"Attempts remaining: {int(observation_attr(observation, 'attempts_remaining', 0) or 0)}\n"
            "Choose the single best next control action before a deterministic repair policy handles code updates."
        )

    def _parse_action(self, content: str) -> AgentDecision:
        """Parse the model reply into a safe :class:`AgentDecision`.

        Any malformed, non-object, or out-of-policy reply collapses to the
        deterministic ``run_tests`` action.
        """
        try:
            payload = json.loads(content)
        except Exception:
            return AgentDecision(action_type="run_tests", source="fallback", error="invalid LLM payload")
        # BUG FIX: json.loads accepts any top-level JSON value (list, string,
        # number, null). Calling .get on a non-dict raised AttributeError
        # outside the try block above and escaped the fallback path entirely.
        if not isinstance(payload, dict):
            return AgentDecision(action_type="run_tests", source="fallback", error="invalid LLM payload")
        action_type = compact_text(payload.get("action_type"), default="run_tests")
        # edit_code is in ALLOWED_ACTIONS but reserved for the deterministic
        # repair policy, so it (and anything unknown) maps to run_tests.
        if action_type not in ALLOWED_ACTIONS or action_type == "edit_code":
            action_type = "run_tests"
        return AgentDecision(action_type=action_type, source="llm")