File size: 4,062 Bytes
cd5c208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""OpenAI-compatible action planner backed by the Hugging Face router."""

from __future__ import annotations

import json
import time
from typing import Any

from openai import OpenAI

from app.models.inference import AgentDecision, InferenceConfig
from app.utils.runtime import compact_text, observation_attr, suppress_output


# Every action_type value this module recognises. Membership is necessary but
# not sufficient: the planner's parser still rewrites "edit_code" to
# "run_tests", because code updates are delegated to a deterministic repair
# policy rather than chosen by the model.
ALLOWED_ACTIONS = {
    "analyze_code",
    "edit_code",
    "run_tests",
    "submit_solution",
}


class OpenAIActionPlanner:
    """Ask an OpenAI-compatible model for the next safe environment action.

    The planner only picks a *control* action (``analyze_code``, ``run_tests``
    or ``submit_solution``). ``edit_code`` is never returned, because code
    updates are left to a separate deterministic repair policy. API and
    parsing failures degrade to a ``run_tests`` decision tagged
    ``source="fallback"`` instead of raising.
    """

    # Action used whenever the model cannot be consulted or its reply is unusable.
    _FALLBACK_ACTION = "run_tests"

    # System instructions sent with every chat-completion request.
    _SYSTEM_PROMPT = (
        "You are a deterministic OpenEnv controller. "
        "Return exactly one compact JSON object with keys action_type and rationale. "
        "Allowed action_type values: analyze_code, run_tests, submit_solution. "
        "Never emit markdown."
    )

    def __init__(self, config: InferenceConfig) -> None:
        """Store ``config`` and create the API client.

        The client is only created when ``config.hf_token`` is truthy. Otherwise
        ``self.client`` is ``None`` and every call to ``propose_action`` returns
        the "HF_TOKEN missing" fallback.
        """
        self.config = config
        self.client = OpenAI(base_url=config.api_base_url, api_key=config.hf_token) if config.hf_token else None

    def propose_action(self, observation: Any) -> AgentDecision:
        """Ask the model for the next control action, retrying failed requests.

        Args:
            observation: Environment observation. Its fields are read through
                ``observation_attr`` (see ``_build_prompt``).

        Returns:
            An ``AgentDecision``. It has ``source="llm"`` when the model replied
            with a JSON object. Otherwise it is a ``run_tests`` fallback whose
            ``error`` explains why: missing token, invalid payload, or the last
            request exception after ``config.max_retries`` retries.

        Raises:
            Whatever ``_build_prompt`` raises. Prompt construction happens before
            the retry loop and is not guarded.
        """
        if self.client is None:
            return AgentDecision(action_type=self._FALLBACK_ACTION, source="fallback", error="HF_TOKEN missing")

        # Built once; the same message list is reused for every attempt.
        messages = [
            {"role": "system", "content": self._SYSTEM_PROMPT},
            {"role": "user", "content": self._build_prompt(observation)},
        ]
        for attempt in range(self.config.max_retries + 1):
            try:
                with suppress_output():
                    response = self.client.chat.completions.create(
                        model=self.config.model_name,
                        temperature=0,
                        max_tokens=120,
                        messages=messages,
                        response_format={"type": "json_object"},
                    )
                message = response.choices[0].message.content or ""
                return self._parse_action(message)
            except Exception as exc:  # deliberate boundary: the planner must never crash the agent loop
                if attempt >= self.config.max_retries:
                    return AgentDecision(
                        action_type=self._FALLBACK_ACTION,
                        source="fallback",
                        error=compact_text(f"{type(exc).__name__}: {exc}", default="LLM failure"),
                    )
                # Linear back-off between attempts: 0.2 s, 0.4 s, 0.6 s, ...
                time.sleep(0.2 * (attempt + 1))

        # Only reachable when config.max_retries is negative, so the loop never runs.
        return AgentDecision(action_type=self._FALLBACK_ACTION, source="fallback", error="LLM retries exhausted")

    def _build_prompt(self, observation: Any) -> str:
        """Render the observation fields into the user message sent to the model.

        Reads ``task_id``, ``task_description``, ``score``, ``errors``,
        ``test_results`` and ``attempts_remaining`` via ``observation_attr``.
        Text fields go through ``compact_text`` with a per-field ``limit``
        (presumably a truncation length; confirm against ``app.utils.runtime``).
        """
        return (
            f"Task ID: {compact_text(observation_attr(observation, 'task_id', ''), default='unknown')}\n"
            f"Description: {compact_text(observation_attr(observation, 'task_description', ''), default='none', limit=400)}\n"
            # NOTE(review): `or 0.01` also turns a real score of 0.0 into 0.0100. This is
            # presumably intentional if scores are kept strictly above 0; confirm.
            # float()/int() here raise on non-numeric values, outside any try block.
            f"Current score: {float(observation_attr(observation, 'score', 0.01) or 0.01):.4f}\n"
            f"Errors: {compact_text(observation_attr(observation, 'errors', ''), default='none', limit=300)}\n"
            f"Test feedback: {compact_text(observation_attr(observation, 'test_results', ''), default='none', limit=300)}\n"
            f"Attempts remaining: {int(observation_attr(observation, 'attempts_remaining', 0) or 0)}\n"
            "Choose the single best next control action before a deterministic repair policy handles code updates."
        )

    def _parse_action(self, content: str) -> AgentDecision:
        """Convert the model's raw reply into an ``AgentDecision``.

        A reply that is not a JSON *object* yields a ``run_tests`` fallback with
        ``error="invalid LLM payload"``. Unknown action types and ``edit_code``
        are coerced to ``run_tests`` but still tagged ``source="llm"``. The
        model's ``rationale`` key is not used.
        """
        try:
            payload = json.loads(content)
        except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
            return AgentDecision(action_type=self._FALLBACK_ACTION, source="fallback", error="invalid LLM payload")

        if not isinstance(payload, dict):
            # Valid JSON that is not an object (e.g. a list or bare string). Without this
            # check payload.get raises AttributeError, which propose_action would treat
            # as a request failure and retry (with a sleep) instead of falling back.
            return AgentDecision(action_type=self._FALLBACK_ACTION, source="fallback", error="invalid LLM payload")

        action_type = compact_text(payload.get("action_type"), default=self._FALLBACK_ACTION)
        # edit_code is in ALLOWED_ACTIONS, but code edits are not the model's call here.
        if action_type not in ALLOWED_ACTIONS or action_type == "edit_code":
            action_type = self._FALLBACK_ACTION
        return AgentDecision(action_type=action_type, source="llm")