Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Any, Iterable | |
| from openai import OpenAI | |
| from inference.prompts import REQUIRED_ACTIONS, SYSTEM_PROMPT, build_user_prompt, heuristic_action, sanitize_action_text | |
@dataclass
class ModelWrapper:
    """Chat-model wrapper with an offline heuristic fallback.

    Wraps an OpenAI chat-completions client for generating agent actions.
    Every call first computes a heuristic fallback action; the API is only
    consulted when a client is configured and offline mode is off, and any
    API failure silently falls back to the heuristic result.

    NOTE(review): the original text declared the five annotated attributes
    with no ``__init__`` while importing ``dataclass`` unused — the
    ``@dataclass`` decorator was almost certainly lost in extraction and is
    restored here; without it the attributes are never initialized and
    ``generate_action`` would raise ``AttributeError``.
    """

    # OpenAI client, or None to force the heuristic fallback.
    client: OpenAI | None
    # Model identifier passed to the chat-completions API.
    model_name: str
    # Sampling temperature; clamped to >= 0.0 at call time.
    temperature: float
    # Completion token budget; floored at 16 at call time.
    max_tokens: int
    # When True, never call the API; always return the heuristic action.
    offline: bool

    def generate_action(
        self,
        step: int,
        config_text: str,
        error_message: str,
        history: list[str],
        available_actions: Iterable[str] | None = None,
    ) -> str:
        """Return a single action string for the current step.

        Args:
            step: Current step index, forwarded into the user prompt.
            config_text: Current configuration file contents.
            error_message: Latest error/observation text.
            history: Prior action strings for context.
            available_actions: Optional whitelist of permitted actions.

        Returns:
            A sanitized model-generated action, or the heuristic fallback
            when offline, unconfigured, or on any API error.
        """
        fallback = heuristic_action(config_text, error_message, available_actions, history)
        if self.offline or self.client is None:
            return fallback
        user_prompt = build_user_prompt(
            step=step,
            config_text=config_text,
            error_message=error_message,
            history=history,
            available_actions=available_actions,
        )
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                # Guard against negative/garbage config values.
                temperature=max(float(self.temperature), 0.0),
                max_tokens=max(16, int(self.max_tokens)),
                stream=False,
            )
            generated = str(completion.choices[0].message.content or "")
            return sanitize_action_text(generated, fallback=fallback)
        except Exception:
            # Deliberate best-effort: any API/parse failure degrades to the
            # heuristic action rather than crashing the agent loop.
            return fallback

    def generate_candidates(
        self,
        step: int,
        config_text: str,
        error_message: str,
        history: list[str],
        count: int,
        available_actions: Iterable[str] | None = None,
    ) -> list[str]:
        """Return a deduplicated list of candidate actions.

        The heuristic action is always included first; at least one model
        sample is drawn even when ``count`` <= 0. Each sample gets a
        distinct ``candidate=<idx>`` marker appended to the history so the
        prompts (and thus samples) can differ.

        Returns:
            Stripped, order-preserving, duplicate-free candidate strings.
        """
        candidates = [heuristic_action(config_text, error_message, available_actions, history)]
        for idx in range(max(1, count)):
            action = self.generate_action(
                step=step,
                config_text=config_text,
                error_message=error_message,
                history=history + [f"candidate={idx}"],
                available_actions=available_actions,
            )
            if action:
                candidates.append(action)
        # Dedupe on stripped text while preserving first-seen order.
        deduped: list[str] = []
        seen: set[str] = set()
        for candidate in candidates:
            normalized = candidate.strip()
            if not normalized or normalized in seen:
                continue
            seen.add(normalized)
            deduped.append(normalized)
        return deduped
def score_action_candidate(observation_text: str, action_text: str, hacking_predicate: Any) -> float:
    """Heuristically score how promising an action is for an observation.

    Args:
        observation_text: Environment observation (config/error text); may be None.
        action_text: Candidate action string; may be None.
        hacking_predicate: Callable flagging reward-hacking actions; a hit
            applies a heavy penalty.

    Returns:
        A float score; higher means more promising. Combines a tool-validity
        term, keyword-match bonuses, the hacking penalty, and a small
        length bonus capped at 0.2.
    """
    obs_lower = (observation_text or "").lower()
    act_lower = (action_text or "").lower()

    # Reward recognized tools (text before the first ':'), penalize unknown ones.
    tool_name = act_lower.split(":", 1)[0].strip()
    total = 0.5 if tool_name in REQUIRED_ACTIONS else -0.5

    # Keyword-driven bonuses: (condition, weight), applied in order.
    # NOTE: "npm tset" is the misspelling expected in the observation; the
    # rule also requires the corrected "npm test" in the edit action.
    bonus_rules = (
        ("npm tset" in obs_lower and "edit_config" in act_lower and "npm test" in act_lower, 2.0),
        ("yaml" in obs_lower and "edit_config" in act_lower, 1.2),
        ("missing" in obs_lower and "analyze_error" in act_lower, 0.8),
        ("validate_fix" in act_lower or "submit_solution" in act_lower, 0.5),
    )
    for matched, weight in bonus_rules:
        if matched:
            total += weight

    # Heavy penalty for actions flagged as reward hacking (raw text, not lowered).
    if hacking_predicate(action_text):
        total -= 3.0

    # Mild preference for longer, more specific actions, capped at 0.2.
    total += min(len(action_text) / 240.0, 0.2)
    return total