Spaces:
Sleeping
Sleeping
Commit Β·
ff292ff
1
Parent(s): 20f2ce3
fix: use pydantic .dict() instead of dataclasses asdict, remove raise exc
Browse files- inference.py +17 -5
inference.py
CHANGED
|
@@ -15,7 +15,6 @@ import time
|
|
| 15 |
import textwrap
|
| 16 |
import requests
|
| 17 |
from typing import List, Optional
|
| 18 |
-
from dataclasses import asdict
|
| 19 |
|
| 20 |
from openai import OpenAI
|
| 21 |
from model import TriageAction, TriageObservation, BugReport
|
|
@@ -56,7 +55,10 @@ print(f"[CONFIG] API_KEY={'set' if API_KEY else 'MISSING'}", flush=True)
|
|
| 56 |
# ββ inlined client (avoids openenv-core import conflict) ββββββββββββββββββ
|
| 57 |
|
| 58 |
def _parse_observation(data: dict) -> TriageObservation:
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
| 60 |
return TriageObservation(
|
| 61 |
bug_report=bug,
|
| 62 |
task_id=data.get("task_id", "easy"),
|
|
@@ -94,9 +96,14 @@ class BugTriageClient:
|
|
| 94 |
|
| 95 |
def step(self, action: TriageAction) -> StepResult:
|
| 96 |
print("[ENV] Sending step action...", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
response = self.session.post(
|
| 98 |
f"{self.base_url}/step",
|
| 99 |
-
json={"action":
|
| 100 |
timeout=30,
|
| 101 |
)
|
| 102 |
response.raise_for_status()
|
|
@@ -214,13 +221,18 @@ def call_model(client: OpenAI, bug_text: str) -> TriageAction:
|
|
| 214 |
raw = (completion.choices[0].message.content or "").strip()
|
| 215 |
print(f"[LLM] Raw response: {raw[:200]}", flush=True)
|
| 216 |
|
|
|
|
| 217 |
if raw.startswith("```"):
|
| 218 |
parts = raw.split("```")
|
| 219 |
raw = parts[1] if len(parts) > 1 else raw
|
| 220 |
if raw.startswith("json"):
|
| 221 |
raw = raw[4:].strip()
|
| 222 |
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
action = TriageAction(
|
| 226 |
priority=data.get("priority", "P2"),
|
|
@@ -278,10 +290,10 @@ def main() -> None:
|
|
| 278 |
success = score >= SUCCESS_SCORE_THRESHOLD
|
| 279 |
|
| 280 |
except Exception as exc:
|
|
|
|
| 281 |
print(f"[ERROR] Exception during run: {type(exc).__name__}: {exc}", flush=True)
|
| 282 |
score = sum(rewards) / len(TASK_IDS) if rewards else 0.0
|
| 283 |
success = False
|
| 284 |
-
raise exc
|
| 285 |
|
| 286 |
finally:
|
| 287 |
log_end(success, len(rewards), score, rewards)
|
|
|
|
| 15 |
import textwrap
|
| 16 |
import requests
|
| 17 |
from typing import List, Optional
|
|
|
|
| 18 |
|
| 19 |
from openai import OpenAI
|
| 20 |
from model import TriageAction, TriageObservation, BugReport
|
|
|
|
| 55 |
# ββ inlined client (avoids openenv-core import conflict) ββββββββββββββββββ
|
| 56 |
|
| 57 |
def _parse_observation(data: dict) -> TriageObservation:
|
| 58 |
+
try:
|
| 59 |
+
bug = BugReport.model_validate(data["bug_report"])
|
| 60 |
+
except Exception:
|
| 61 |
+
bug = BugReport(**data["bug_report"])
|
| 62 |
return TriageObservation(
|
| 63 |
bug_report=bug,
|
| 64 |
task_id=data.get("task_id", "easy"),
|
|
|
|
| 96 |
|
| 97 |
def step(self, action: TriageAction) -> StepResult:
|
| 98 |
print("[ENV] Sending step action...", flush=True)
|
| 99 |
+
# FIX: TriageAction is a Pydantic model β use .dict() not asdict()
|
| 100 |
+
try:
|
| 101 |
+
action_dict = action.model_dump() # Pydantic v2
|
| 102 |
+
except AttributeError:
|
| 103 |
+
action_dict = action.dict() # Pydantic v1 fallback
|
| 104 |
response = self.session.post(
|
| 105 |
f"{self.base_url}/step",
|
| 106 |
+
json={"action": action_dict},
|
| 107 |
timeout=30,
|
| 108 |
)
|
| 109 |
response.raise_for_status()
|
|
|
|
| 221 |
raw = (completion.choices[0].message.content or "").strip()
|
| 222 |
print(f"[LLM] Raw response: {raw[:200]}", flush=True)
|
| 223 |
|
| 224 |
+
# Strip markdown code fences if present
|
| 225 |
if raw.startswith("```"):
|
| 226 |
parts = raw.split("```")
|
| 227 |
raw = parts[1] if len(parts) > 1 else raw
|
| 228 |
if raw.startswith("json"):
|
| 229 |
raw = raw[4:].strip()
|
| 230 |
|
| 231 |
+
try:
|
| 232 |
+
data = json.loads(raw)
|
| 233 |
+
except json.JSONDecodeError as e:
|
| 234 |
+
print(f"[LLM] JSON parse failed: {e}. Using safe defaults.", flush=True)
|
| 235 |
+
data = {}
|
| 236 |
|
| 237 |
action = TriageAction(
|
| 238 |
priority=data.get("priority", "P2"),
|
|
|
|
| 290 |
success = score >= SUCCESS_SCORE_THRESHOLD
|
| 291 |
|
| 292 |
except Exception as exc:
|
| 293 |
+
# FIX: do NOT re-raise β validator treats non-zero exit as failure
|
| 294 |
print(f"[ERROR] Exception during run: {type(exc).__name__}: {exc}", flush=True)
|
| 295 |
score = sum(rewards) / len(TASK_IDS) if rewards else 0.0
|
| 296 |
success = False
|
|
|
|
| 297 |
|
| 298 |
finally:
|
| 299 |
log_end(success, len(rewards), score, rewards)
|