XcodeAddy committed on
Commit
eefb7bf
·
1 Parent(s): b6d1ff0

Harden inference transport and output handling

Browse files
Files changed (3) hide show
  1. inference.py +113 -32
  2. requirements.txt +1 -0
  3. tests/test_inference.py +11 -1
inference.py CHANGED
@@ -2,6 +2,7 @@ import json
2
  import os
3
  import re
4
  import sys
 
5
  from pathlib import Path
6
  from typing import Any, Dict, List, Optional
7
 
@@ -93,7 +94,7 @@ class HttpEnvironmentTransport(EnvironmentTransport):
93
  json={"task_type": task_type, "ticket_id": ticket_id},
94
  timeout=30,
95
  )
96
- response.raise_for_status()
97
  return response.json()
98
 
99
  def step(self, session_id: str, action: Dict[str, Any]) -> Dict[str, Any]:
@@ -103,18 +104,42 @@ class HttpEnvironmentTransport(EnvironmentTransport):
103
  json=action,
104
  timeout=30,
105
  )
106
- response.raise_for_status()
107
  return response.json()
108
 
109
  def close(self) -> None:
110
  self.session.close()
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  class LocalEnvironmentTransport(EnvironmentTransport):
114
  def __init__(self):
115
- from fastapi.testclient import TestClient
 
 
 
 
 
 
116
 
117
- import app as app_module
 
 
 
 
 
118
 
119
  self.session = TestClient(app_module.app)
120
 
@@ -142,15 +167,25 @@ class LocalEnvironmentTransport(EnvironmentTransport):
142
  def build_transport() -> EnvironmentTransport:
143
  http_transport = HttpEnvironmentTransport(ENV_URL)
144
  if http_transport.probe():
 
145
  return http_transport
146
  http_transport.close()
 
 
 
 
147
  return LocalEnvironmentTransport()
148
 
149
 
150
  def create_model_client() -> Optional[OpenAI]:
151
  if not (API_BASE_URL and API_KEY and MODEL_NAME):
152
  return None
153
- return OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
 
 
 
 
 
154
 
155
 
156
  def build_user_prompt(observation: Dict[str, Any]) -> str:
@@ -184,12 +219,18 @@ def extract_json(raw: str) -> Dict[str, Any]:
184
 
185
  def normalize_action(raw_action: Dict[str, Any], observation: Dict[str, Any]) -> Dict[str, Any]:
186
  task_type = observation["task_type"]
 
 
 
 
 
 
187
  return {
188
  "incident_id": observation["incident_id"],
189
  "task_type": task_type,
190
- "severity": raw_action.get("severity") if task_type == "task1" else None,
191
- "root_cause": raw_action.get("root_cause") if task_type == "task2" else None,
192
- "action": raw_action.get("action") if task_type == "task3" else None,
193
  }
194
 
195
 
@@ -209,7 +250,12 @@ def predict_severity(alert_text: str, context: Dict[str, Any]) -> str:
209
  or _number(context.get("failure_rate"))
210
  or _number(context.get("affected_users_pct"))
211
  )
212
- revenue_impact = context.get("revenue_impact") is True or context.get("revenue_dependency") == "high"
 
 
 
 
 
213
 
214
  if (
215
  "CRITICAL" in alert_text
@@ -266,7 +312,7 @@ def predict_action(alert_text: str, context_text: str) -> str:
266
  def heuristic_action(observation: Dict[str, Any]) -> Dict[str, Any]:
267
  task_type = observation["task_type"]
268
  alert_text = observation["alert_text"].upper()
269
- context_text = json.dumps(observation["context"]).upper()
270
 
271
  if task_type == "task1":
272
  return normalize_action({"severity": predict_severity(alert_text, observation["context"])}, observation)
@@ -279,7 +325,7 @@ def get_action(model_client: Optional[OpenAI], observation: Dict[str, Any]) -> D
279
  if model_client is None:
280
  return heuristic_action(observation)
281
 
282
- for _ in range(2):
283
  try:
284
  completion = model_client.chat.completions.create(
285
  model=MODEL_NAME,
@@ -289,12 +335,21 @@ def get_action(model_client: Optional[OpenAI], observation: Dict[str, Any]) -> D
289
  ],
290
  temperature=TEMPERATURE,
291
  max_tokens=MAX_TOKENS,
 
292
  )
293
  content = (completion.choices[0].message.content or "").strip()
294
  return normalize_action(extract_json(content), observation)
295
- except Exception:
 
 
 
 
296
  continue
297
 
 
 
 
 
298
  return heuristic_action(observation)
299
 
300
 
@@ -382,34 +437,60 @@ def write_results(
382
  results: List[Dict[str, Any]],
383
  output_path: Path = OUTPUT_PATH,
384
  ) -> None:
385
- grouped: Dict[str, List[float]] = {}
386
- for result in results:
387
- grouped.setdefault(result["task_type"], []).append(result.get("score", 0.0))
388
-
389
- summary = {
390
- "benchmark": BENCHMARK,
391
- "model": MODEL_NAME,
392
- "episodes": len(results),
393
- "average_score": (sum(result.get("score", 0.0) for result in results) / len(results)) if results else 0.0,
394
- "by_task": {
395
- task_type: {
396
- "episodes": len(scores),
397
- "average_score": (sum(scores) / len(scores)) if scores else 0.0,
398
- }
399
- for task_type, scores in grouped.items()
400
- },
401
- "results": results,
402
- }
 
403
 
404
  try:
405
  output_path.parent.mkdir(parents=True, exist_ok=True)
406
- output_path.write_text(json.dumps(summary, indent=2))
 
407
  except (PermissionError, OSError) as exc:
408
  print(
409
- f"[WARN] Could not write results file to {output_path}: {exc}. Scores were still emitted to stdout.",
410
  file=sys.stderr,
411
  flush=True,
412
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
 
414
 
415
  def main() -> None:
 
2
  import os
3
  import re
4
  import sys
5
+ import tempfile
6
  from pathlib import Path
7
  from typing import Any, Dict, List, Optional
8
 
 
94
  json={"task_type": task_type, "ticket_id": ticket_id},
95
  timeout=30,
96
  )
97
+ self._raise_for_status_with_body(response)
98
  return response.json()
99
 
100
  def step(self, session_id: str, action: Dict[str, Any]) -> Dict[str, Any]:
 
104
  json=action,
105
  timeout=30,
106
  )
107
+ self._raise_for_status_with_body(response)
108
  return response.json()
109
 
110
  def close(self) -> None:
111
  self.session.close()
112
 
113
+ @staticmethod
114
+ def _raise_for_status_with_body(response: requests.Response) -> None:
115
+ if response.ok:
116
+ return
117
+ try:
118
+ error_body = response.json()
119
+ except ValueError:
120
+ error_body = response.text[:500]
121
+ raise requests.HTTPError(
122
+ f"{response.status_code} {response.reason} — Body: {error_body}",
123
+ response=response,
124
+ )
125
+
126
 
127
  class LocalEnvironmentTransport(EnvironmentTransport):
128
  def __init__(self):
129
+ try:
130
+ from fastapi.testclient import TestClient
131
+ except ImportError as exc:
132
+ raise RuntimeError(
133
+ "LocalEnvironmentTransport requires FastAPI test-client dependencies "
134
+ "(including httpx). Install them with: pip install fastapi httpx"
135
+ ) from exc
136
 
137
+ try:
138
+ import app as app_module
139
+ except ImportError as exc:
140
+ raise RuntimeError(
141
+ "Could not import the local app module. Run inference.py from the project root."
142
+ ) from exc
143
 
144
  self.session = TestClient(app_module.app)
145
 
 
167
def build_transport() -> EnvironmentTransport:
    """Choose a transport: the HTTP server when reachable, else in-process.

    Probes the configured environment server once; on failure the HTTP
    session is closed before falling back to the local transport.
    """
    remote = HttpEnvironmentTransport(ENV_URL)
    if remote.probe():
        print(f"[TRANSPORT] Using HTTP transport at {ENV_URL}", flush=True)
        return remote

    remote.close()
    print(
        f"[TRANSPORT] HTTP server at {ENV_URL} is unavailable. Falling back to local in-process transport.",
        flush=True,
    )
    return LocalEnvironmentTransport()
178
 
179
 
180
def create_model_client() -> Optional[OpenAI]:
    """Build an OpenAI-compatible client, or return None when config is incomplete.

    All three of API_BASE_URL, API_KEY and MODEL_NAME must be set; otherwise
    the caller is expected to fall back to the heuristic policy. Retries are
    disabled here because get_action manages its own retry loop.
    """
    configured = API_BASE_URL and API_KEY and MODEL_NAME
    if not configured:
        return None
    return OpenAI(base_url=API_BASE_URL, api_key=API_KEY, timeout=20.0, max_retries=0)
189
 
190
 
191
  def build_user_prompt(observation: Dict[str, Any]) -> str:
 
219
 
220
def normalize_action(raw_action: Dict[str, Any], observation: Dict[str, Any]) -> Dict[str, Any]:
    """Shape a raw model/heuristic action into the canonical submission schema.

    Exactly one of severity / root_cause / action is populated, matching the
    observation's task_type; the other two are explicitly None. Populated
    values are normalized to trimmed uppercase strings.
    """
    task_type = observation["task_type"]

    def _canonical(value: Any) -> Optional[str]:
        # None passes through; anything else becomes a trimmed uppercase string.
        return None if value is None else str(value).upper().strip()

    normalized: Dict[str, Any] = {
        "incident_id": observation["incident_id"],
        "task_type": task_type,
    }
    # Field name -> the single task type that owns it.
    for field_name, owning_task in (
        ("severity", "task1"),
        ("root_cause", "task2"),
        ("action", "task3"),
    ):
        if task_type == owning_task:
            normalized[field_name] = _canonical(raw_action.get(field_name))
        else:
            normalized[field_name] = None
    return normalized
235
 
236
 
 
250
  or _number(context.get("failure_rate"))
251
  or _number(context.get("affected_users_pct"))
252
  )
253
+ revenue_impact = (
254
+ context.get("revenue_impact") is True
255
+ or context.get("revenue_dependency") == "high"
256
+ or "REVENUE IMPACT" in alert_text
257
+ or "REVENUE_IMPACT" in alert_text.replace(" ", "_")
258
+ )
259
 
260
  if (
261
  "CRITICAL" in alert_text
 
312
  def heuristic_action(observation: Dict[str, Any]) -> Dict[str, Any]:
313
  task_type = observation["task_type"]
314
  alert_text = observation["alert_text"].upper()
315
+ context_text = json.dumps(observation["context"]).upper().replace("_", " ")
316
 
317
  if task_type == "task1":
318
  return normalize_action({"severity": predict_severity(alert_text, observation["context"])}, observation)
 
325
  if model_client is None:
326
  return heuristic_action(observation)
327
 
328
+ for attempt in range(2):
329
  try:
330
  completion = model_client.chat.completions.create(
331
  model=MODEL_NAME,
 
335
  ],
336
  temperature=TEMPERATURE,
337
  max_tokens=MAX_TOKENS,
338
+ timeout=15.0,
339
  )
340
  content = (completion.choices[0].message.content or "").strip()
341
  return normalize_action(extract_json(content), observation)
342
+ except Exception as exc:
343
+ print(
344
+ f"[WARN] LLM error on attempt {attempt + 1} for {observation['incident_id']}: {exc}",
345
+ flush=True,
346
+ )
347
  continue
348
 
349
+ print(
350
+ f"[FALLBACK] Using heuristic for {observation['incident_id']} after LLM failures.",
351
+ flush=True,
352
+ )
353
  return heuristic_action(observation)
354
 
355
 
 
437
  results: List[Dict[str, Any]],
438
  output_path: Path = OUTPUT_PATH,
439
  ) -> None:
440
+ try:
441
+ summary = {
442
+ "benchmark": BENCHMARK,
443
+ "model": MODEL_NAME,
444
+ "episodes": len(results),
445
+ "average_score": (sum(result.get("score", 0.0) for result in results) / len(results)) if results else 0.0,
446
+ "by_task": _group_by_task(results),
447
+ "results": results,
448
+ }
449
+ serialized = json.dumps(summary, indent=2)
450
+ except (TypeError, ValueError) as exc:
451
+ print(
452
+ f"[ERROR] Results serialization failed: {exc}. Raw episode results follow.",
453
+ file=sys.stderr,
454
+ flush=True,
455
+ )
456
+ for result in results:
457
+ print(f"[RESULT] {json.dumps(result, default=str)}", flush=True)
458
+ return
459
 
460
  try:
461
  output_path.parent.mkdir(parents=True, exist_ok=True)
462
+ output_path.write_text(serialized)
463
+ print(f"[RESULTS] Written to {output_path}", flush=True)
464
  except (PermissionError, OSError) as exc:
465
  print(
466
+ f"[WARN] Could not write results file to {output_path}: {exc}",
467
  file=sys.stderr,
468
  flush=True,
469
  )
470
+ fallback_path = Path(tempfile.gettempdir()) / "incident-triage-env-baseline-scores.json"
471
+ try:
472
+ fallback_path.write_text(serialized)
473
+ print(f"[RESULTS] Fallback written to {fallback_path}", flush=True)
474
+ except OSError as fallback_exc:
475
+ print(
476
+ f"[WARN] Fallback results write failed: {fallback_exc}. Emitting JSON summary to stdout.",
477
+ file=sys.stderr,
478
+ flush=True,
479
+ )
480
+ print(f"[RESULTS_JSON] {serialized}", flush=True)
481
+
482
+
483
+ def _group_by_task(results: List[Dict[str, Any]]) -> Dict[str, Dict[str, float]]:
484
+ grouped: Dict[str, List[float]] = {}
485
+ for result in results:
486
+ grouped.setdefault(result["task_type"], []).append(result.get("score", 0.0))
487
+ return {
488
+ task_type: {
489
+ "episodes": len(scores),
490
+ "average_score": (sum(scores) / len(scores)) if scores else 0.0,
491
+ }
492
+ for task_type, scores in grouped.items()
493
+ }
494
 
495
 
496
  def main() -> None:
requirements.txt CHANGED
@@ -6,3 +6,4 @@ requests
6
  python-dotenv
7
  setuptools
8
  wheel
 
 
6
  python-dotenv
7
  setuptools
8
  wheel
9
+ httpx>=0.27.0
tests/test_inference.py CHANGED
@@ -3,10 +3,20 @@ import tempfile
3
  import unittest
4
  from pathlib import Path
5
 
6
- from inference import write_results
7
 
8
 
9
  class InferenceOutputTests(unittest.TestCase):
 
 
 
 
 
 
 
 
 
 
10
  def test_write_results_writes_summary_to_configured_path(self) -> None:
11
  results = [
12
  {"incident_id": "INC-001", "task_type": "task1", "score": 1.0, "success": True},
 
3
  import unittest
4
  from pathlib import Path
5
 
6
+ from inference import normalize_action, write_results
7
 
8
 
9
  class InferenceOutputTests(unittest.TestCase):
10
+ def test_normalize_action_uppercases_model_outputs(self) -> None:
11
+ normalized = normalize_action(
12
+ {"action": "failover"},
13
+ {"incident_id": "INC-014", "task_type": "task3"},
14
+ )
15
+
16
+ self.assertEqual(normalized["action"], "FAILOVER")
17
+ self.assertIsNone(normalized["severity"])
18
+ self.assertIsNone(normalized["root_cause"])
19
+
20
  def test_write_results_writes_summary_to_configured_path(self) -> None:
21
  results = [
22
  {"incident_id": "INC-001", "task_type": "task1", "score": 1.0, "success": True},