Siteshcodes commited on
Commit
ff292ff
Β·
1 Parent(s): 20f2ce3

fix: use pydantic .dict() instead of dataclasses asdict, remove raise exc

Browse files
Files changed (1) hide show
  1. inference.py +17 -5
inference.py CHANGED
@@ -15,7 +15,6 @@ import time
15
  import textwrap
16
  import requests
17
  from typing import List, Optional
18
- from dataclasses import asdict
19
 
20
  from openai import OpenAI
21
  from model import TriageAction, TriageObservation, BugReport
@@ -56,7 +55,10 @@ print(f"[CONFIG] API_KEY={'set' if API_KEY else 'MISSING'}", flush=True)
56
  # ── inlined client (avoids openenv-core import conflict) ──────────────────
57
 
58
  def _parse_observation(data: dict) -> TriageObservation:
59
- bug = BugReport(**data["bug_report"])
 
 
 
60
  return TriageObservation(
61
  bug_report=bug,
62
  task_id=data.get("task_id", "easy"),
@@ -94,9 +96,14 @@ class BugTriageClient:
94
 
95
  def step(self, action: TriageAction) -> StepResult:
96
  print("[ENV] Sending step action...", flush=True)
 
 
 
 
 
97
  response = self.session.post(
98
  f"{self.base_url}/step",
99
- json={"action": asdict(action)},
100
  timeout=30,
101
  )
102
  response.raise_for_status()
@@ -214,13 +221,18 @@ def call_model(client: OpenAI, bug_text: str) -> TriageAction:
214
  raw = (completion.choices[0].message.content or "").strip()
215
  print(f"[LLM] Raw response: {raw[:200]}", flush=True)
216
 
 
217
  if raw.startswith("```"):
218
  parts = raw.split("```")
219
  raw = parts[1] if len(parts) > 1 else raw
220
  if raw.startswith("json"):
221
  raw = raw[4:].strip()
222
 
223
- data = json.loads(raw)
 
 
 
 
224
 
225
  action = TriageAction(
226
  priority=data.get("priority", "P2"),
@@ -278,10 +290,10 @@ def main() -> None:
278
  success = score >= SUCCESS_SCORE_THRESHOLD
279
 
280
  except Exception as exc:
 
281
  print(f"[ERROR] Exception during run: {type(exc).__name__}: {exc}", flush=True)
282
  score = sum(rewards) / len(TASK_IDS) if rewards else 0.0
283
  success = False
284
- raise exc
285
 
286
  finally:
287
  log_end(success, len(rewards), score, rewards)
 
15
  import textwrap
16
  import requests
17
  from typing import List, Optional
 
18
 
19
  from openai import OpenAI
20
  from model import TriageAction, TriageObservation, BugReport
 
55
  # ── inlined client (avoids openenv-core import conflict) ──────────────────
56
 
57
  def _parse_observation(data: dict) -> TriageObservation:
58
+ try:
59
+ bug = BugReport.model_validate(data["bug_report"])
60
+ except Exception:
61
+ bug = BugReport(**data["bug_report"])
62
  return TriageObservation(
63
  bug_report=bug,
64
  task_id=data.get("task_id", "easy"),
 
96
 
97
  def step(self, action: TriageAction) -> StepResult:
98
  print("[ENV] Sending step action...", flush=True)
99
+ # FIX: TriageAction is a Pydantic model β€” use .dict() not asdict()
100
+ try:
101
+ action_dict = action.model_dump() # Pydantic v2
102
+ except AttributeError:
103
+ action_dict = action.dict() # Pydantic v1 fallback
104
  response = self.session.post(
105
  f"{self.base_url}/step",
106
+ json={"action": action_dict},
107
  timeout=30,
108
  )
109
  response.raise_for_status()
 
221
  raw = (completion.choices[0].message.content or "").strip()
222
  print(f"[LLM] Raw response: {raw[:200]}", flush=True)
223
 
224
+ # Strip markdown code fences if present
225
  if raw.startswith("```"):
226
  parts = raw.split("```")
227
  raw = parts[1] if len(parts) > 1 else raw
228
  if raw.startswith("json"):
229
  raw = raw[4:].strip()
230
 
231
+ try:
232
+ data = json.loads(raw)
233
+ except json.JSONDecodeError as e:
234
+ print(f"[LLM] JSON parse failed: {e}. Using safe defaults.", flush=True)
235
+ data = {}
236
 
237
  action = TriageAction(
238
  priority=data.get("priority", "P2"),
 
290
  success = score >= SUCCESS_SCORE_THRESHOLD
291
 
292
  except Exception as exc:
293
+ # FIX: do NOT re-raise β€” validator treats non-zero exit as failure
294
  print(f"[ERROR] Exception during run: {type(exc).__name__}: {exc}", flush=True)
295
  score = sum(rewards) / len(TASK_IDS) if rewards else 0.0
296
  success = False
 
297
 
298
  finally:
299
  log_end(success, len(rewards), score, rewards)