from __future__ import annotations from dataclasses import dataclass from typing import Any, Iterable from openai import OpenAI from inference.prompts import REQUIRED_ACTIONS, SYSTEM_PROMPT, build_user_prompt, heuristic_action, sanitize_action_text @dataclass class ModelWrapper: client: OpenAI | None model_name: str temperature: float max_tokens: int offline: bool def generate_action( self, step: int, config_text: str, error_message: str, history: list[str], available_actions: Iterable[str] | None = None, ) -> str: fallback = heuristic_action(config_text, error_message, available_actions, history) if self.offline or self.client is None: return fallback user_prompt = build_user_prompt( step=step, config_text=config_text, error_message=error_message, history=history, available_actions=available_actions, ) try: completion = self.client.chat.completions.create( model=self.model_name, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ], temperature=max(float(self.temperature), 0.0), max_tokens=max(16, int(self.max_tokens)), stream=False, ) generated = str(completion.choices[0].message.content or "") return sanitize_action_text(generated, fallback=fallback) except Exception: return fallback def generate_candidates( self, step: int, config_text: str, error_message: str, history: list[str], count: int, available_actions: Iterable[str] | None = None, ) -> list[str]: candidates = [heuristic_action(config_text, error_message, available_actions, history)] for idx in range(max(1, count)): action = self.generate_action( step=step, config_text=config_text, error_message=error_message, history=history + [f"candidate={idx}"], available_actions=available_actions, ) if action: candidates.append(action) deduped: list[str] = [] seen: set[str] = set() for candidate in candidates: normalized = candidate.strip() if not normalized or normalized in seen: continue seen.add(normalized) deduped.append(normalized) return deduped def score_action_candidate(observation_text: str, action_text: str, hacking_predicate: Any) -> float: obs = (observation_text or "").lower() action = (action_text or "").lower() score = 0.0 selected_tool = action.split(":", 1)[0].strip() if selected_tool in REQUIRED_ACTIONS: score += 0.5 else: score -= 0.5 if "npm tset" in obs and "edit_config" in action and "npm test" in action: score += 2.0 if "yaml" in obs and "edit_config" in action: score += 1.2 if "missing" in obs and "analyze_error" in action: score += 0.8 if "validate_fix" in action or "submit_solution" in action: score += 0.5 if hacking_predicate(action_text): score -= 3.0 score += min(len(action_text) / 240.0, 0.2) return score