Spaces:
Running
Running
Fix evaluation reliability and lifecycle issues
Browse files
- app.py +49 -9
- environment.py +49 -3
- graders.py +9 -0
- inference.py +21 -8
- openenv.yaml +2 -0
- tests/test_env.py +76 -0
- tests/test_inference.py +41 -0
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import uuid
|
| 2 |
from collections import Counter
|
| 3 |
from pathlib import Path
|
|
@@ -21,10 +22,6 @@ from models import (
|
|
| 21 |
TaskType,
|
| 22 |
)
|
| 23 |
|
| 24 |
-
app = FastAPI(title="Incident Triage Environment")
|
| 25 |
-
UI_DIR = Path(__file__).parent / "ui"
|
| 26 |
-
ASSETS_DIR = UI_DIR / "assets"
|
| 27 |
-
|
| 28 |
# Session store: session_id -> IncidentEnv instance
|
| 29 |
MAX_SESSIONS = 500
|
| 30 |
sessions: dict[str, IncidentEnv] = {}
|
|
@@ -32,6 +29,35 @@ completed_states: dict[str, IncidentState] = {}
|
|
| 32 |
session_lock = RLock()
|
| 33 |
task_counts = Counter(ticket["task_type"] for ticket in TICKETS)
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
app.mount("/assets", StaticFiles(directory=ASSETS_DIR), name="assets")
|
| 36 |
|
| 37 |
|
|
@@ -48,6 +74,15 @@ def evict_oldest(mapping: dict[str, Any], max_size: int) -> None:
|
|
| 48 |
mapping.pop(oldest_key, None)
|
| 49 |
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
@app.get("/", include_in_schema=False)
|
| 52 |
def home_page():
|
| 53 |
return FileResponse(UI_DIR / "index.html")
|
|
@@ -161,8 +196,7 @@ def reset(reset_request: ResetRequest | None = None):
|
|
| 161 |
evict_oldest(sessions, MAX_SESSIONS)
|
| 162 |
evict_oldest(completed_states, MAX_SESSIONS)
|
| 163 |
sessions[session_id] = env
|
| 164 |
-
result
|
| 165 |
-
result.info["state"] = env.state(session_id=session_id).model_dump()
|
| 166 |
log_event(
|
| 167 |
"RESET",
|
| 168 |
session_id=session_id,
|
|
@@ -188,9 +222,8 @@ def step(action: IncidentAction, session_id: str):
|
|
| 188 |
except (RuntimeError, ValueError) as e:
|
| 189 |
log_event("STEP_ERROR", session_id=session_id, incident_id=action.incident_id, error=str(e))
|
| 190 |
raise HTTPException(status_code=400, detail=str(e))
|
| 191 |
-
result.info["session_id"] = session_id
|
| 192 |
current_state = env.state(session_id=session_id)
|
| 193 |
-
result
|
| 194 |
if result.done:
|
| 195 |
completed_states[session_id] = current_state
|
| 196 |
sessions.pop(session_id, None)
|
|
@@ -235,7 +268,14 @@ def get_grader_info():
|
|
| 235 |
"task1": "exact=1.0, adjacent=0.5, far=0.0",
|
| 236 |
"task2": "exact=1.0, related-domain=0.5, unknown=0.25, wrong=0.0",
|
| 237 |
"task3": "exact=1.0, investigate fallback=0.4, related response=0.25, wrong=0.0",
|
| 238 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
}
|
| 240 |
|
| 241 |
|
|
|
|
| 1 |
+
from contextlib import asynccontextmanager
|
| 2 |
import uuid
|
| 3 |
from collections import Counter
|
| 4 |
from pathlib import Path
|
|
|
|
| 22 |
TaskType,
|
| 23 |
)
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Session store: session_id -> IncidentEnv instance
|
| 26 |
MAX_SESSIONS = 500
|
| 27 |
sessions: dict[str, IncidentEnv] = {}
|
|
|
|
| 29 |
session_lock = RLock()
|
| 30 |
task_counts = Counter(ticket["task_type"] for ticket in TICKETS)
|
| 31 |
|
| 32 |
+
|
| 33 |
+
def emit_lifecycle_event(event: str, **fields: Any) -> None:
    """Write a single lifecycle log line to standard error.

    The line is ``[EVENT] key=value key=value ...`` with the keyword
    arguments rendered in the order they were passed. When no fields are
    given only ``[EVENT]`` is written (no trailing space).

    Args:
        event: Event label placed between the square brackets.
        **fields: Arbitrary key/value details; values are formatted with
            their default string conversion.
    """
    # Function-scope import so this helper works even when the module does
    # not import ``sys`` at the top level.
    import sys

    details = " ".join(f"{key}={value}" for key, value in fields.items())
    line = f"[{event}] {details}" if details else f"[{event}]"
    # flush=True so the line is written out immediately rather than buffered.
    print(line, file=sys.stderr, flush=True)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan hook: log startup, then drain session stores on exit.

    On entry a ``STARTUP`` lifecycle event is emitted. When the context is
    left (normally or via an exception) both in-memory stores are emptied
    under ``session_lock`` and a ``SHUTDOWN`` event reports how many entries
    each store held just before it was cleared.
    """
    emit_lifecycle_event("STARTUP", status="ready")
    try:
        yield
    finally:
        with session_lock:
            # Capture the sizes before clearing so the log reflects what was dropped.
            drained = {
                "active_sessions": len(sessions),
                "completed_sessions": len(completed_states),
            }
            for store in (sessions, completed_states):
                store.clear()
        emit_lifecycle_event("SHUTDOWN", **drained, status="cleared")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
app = FastAPI(title="Incident Triage Environment", lifespan=lifespan)
|
| 58 |
+
UI_DIR = Path(__file__).parent / "ui"
|
| 59 |
+
ASSETS_DIR = UI_DIR / "assets"
|
| 60 |
+
|
| 61 |
app.mount("/assets", StaticFiles(directory=ASSETS_DIR), name="assets")
|
| 62 |
|
| 63 |
|
|
|
|
| 74 |
mapping.pop(oldest_key, None)
|
| 75 |
|
| 76 |
|
| 77 |
+
def enrich_step_result(result: StepResult, session_id: str, state: IncidentState) -> StepResult:
    """Return a copy of *result* whose ``info`` also carries session metadata.

    The existing ``info`` entries are kept; ``"session_id"`` and ``"state"``
    (the dumped *state* model) are added, replacing any entries already
    stored under those keys. The input *result* is not modified.

    Args:
        result: Step result to copy.
        session_id: Identifier stored under ``info["session_id"]``.
        state: State model whose ``model_dump()`` output is stored under
            ``info["state"]``.

    Returns:
        A new ``StepResult`` produced by ``model_copy`` with the enriched info.
    """
    info = dict(result.info)
    info["session_id"] = session_id
    info["state"] = state.model_dump()
    return result.model_copy(update={"info": info})
|
| 84 |
+
|
| 85 |
+
|
| 86 |
@app.get("/", include_in_schema=False)
|
| 87 |
def home_page():
|
| 88 |
return FileResponse(UI_DIR / "index.html")
|
|
|
|
| 196 |
evict_oldest(sessions, MAX_SESSIONS)
|
| 197 |
evict_oldest(completed_states, MAX_SESSIONS)
|
| 198 |
sessions[session_id] = env
|
| 199 |
+
result = enrich_step_result(result, session_id=session_id, state=env.state(session_id=session_id))
|
|
|
|
| 200 |
log_event(
|
| 201 |
"RESET",
|
| 202 |
session_id=session_id,
|
|
|
|
| 222 |
except (RuntimeError, ValueError) as e:
|
| 223 |
log_event("STEP_ERROR", session_id=session_id, incident_id=action.incident_id, error=str(e))
|
| 224 |
raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
| 225 |
current_state = env.state(session_id=session_id)
|
| 226 |
+
result = enrich_step_result(result, session_id=session_id, state=current_state)
|
| 227 |
if result.done:
|
| 228 |
completed_states[session_id] = current_state
|
| 229 |
sessions.pop(session_id, None)
|
|
|
|
| 268 |
"task1": "exact=1.0, adjacent=0.5, far=0.0",
|
| 269 |
"task2": "exact=1.0, related-domain=0.5, unknown=0.25, wrong=0.0",
|
| 270 |
"task3": "exact=1.0, investigate fallback=0.4, related response=0.25, wrong=0.0",
|
| 271 |
+
},
|
| 272 |
+
"notes": {
|
| 273 |
+
"task2": [
|
| 274 |
+
"DATABASE and APPLICATION are treated as related because application faults often surface as database pressure and vice versa.",
|
| 275 |
+
"NETWORK, INFRASTRUCTURE, and THIRD_PARTY share limited partial-credit bridges to reflect correlated outage signatures.",
|
| 276 |
+
"APPLICATION and THIRD_PARTY are intentionally not treated as related because they imply different remediation ownership.",
|
| 277 |
+
]
|
| 278 |
+
},
|
| 279 |
}
|
| 280 |
|
| 281 |
|
environment.py
CHANGED
|
@@ -38,6 +38,35 @@ TASK_SPECS = {
|
|
| 38 |
"description": "Choose the best immediate operational response for stabilizing the incident.",
|
| 39 |
},
|
| 40 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
TICKETS_BY_ID = {ticket["incident_id"]: ticket for ticket in TICKETS}
|
| 42 |
|
| 43 |
|
|
@@ -100,13 +129,12 @@ class IncidentEnv:
|
|
| 100 |
self._validate_action(action)
|
| 101 |
|
| 102 |
task_type = self.current_ticket["task_type"]
|
| 103 |
-
ground_truth = self.
|
| 104 |
grader_fn = GRADERS[task_type]
|
| 105 |
reward_value, reward_reason = grader_fn(action, ground_truth)
|
| 106 |
|
| 107 |
agent_answer = action.selected_value() or "NONE"
|
| 108 |
selected_field = action.selected_field() or "NONE"
|
| 109 |
-
ground_truth_value = list(ground_truth.values())[0]
|
| 110 |
|
| 111 |
self.step_count += 1
|
| 112 |
self.last_reward = reward_value
|
|
@@ -171,7 +199,8 @@ class IncidentEnv:
|
|
| 171 |
if not pool:
|
| 172 |
raise ValueError(f"No tickets found for task_type: {task_type}")
|
| 173 |
|
| 174 |
-
|
|
|
|
| 175 |
return chooser.choice(pool)
|
| 176 |
|
| 177 |
def _task_spec(self) -> dict:
|
|
@@ -208,3 +237,20 @@ class IncidentEnv:
|
|
| 208 |
f"Task '{self.current_ticket['task_type']}' expects field '{expected_field}', "
|
| 209 |
f"but got '{next(iter(populated))}'."
|
| 210 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
"description": "Choose the best immediate operational response for stabilizing the incident.",
|
| 39 |
},
|
| 40 |
}
|
| 41 |
+
DEFAULT_RESET_SEED = 42
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def validate_ticket_dataset(tickets: list[dict]) -> None:
    """Verify that each ticket has a supported task type and a usable ground truth.

    For every ticket the ``ground_truth`` must be a non-empty dict containing
    exactly the field named by ``TASK_SPECS[task_type]["expected_field"]``,
    and that field's value must be neither ``None`` nor the empty string.

    Args:
        tickets: Ticket dictionaries to check.

    Raises:
        RuntimeError: On the first ticket that has an unsupported
            ``task_type``, an empty or non-dict ``ground_truth``, keys other
            than the single expected field, or a missing value for it.
    """
    for entry in tickets:
        ticket_label = entry.get("incident_id", "<unknown>")
        raw_task = entry.get("task_type")
        try:
            task = TaskType(raw_task)
        except ValueError as exc:
            raise RuntimeError(
                f"Ticket '{ticket_label}' has unsupported task_type '{raw_task}'."
            ) from exc

        field_name = TASK_SPECS[task]["expected_field"]
        truth = entry.get("ground_truth")

        if not (isinstance(truth, dict) and truth):
            raise RuntimeError(f"Ticket '{ticket_label}' has empty ground_truth.")

        # Exactly one key is allowed, and it must be the task's expected field.
        if set(truth) != {field_name}:
            raise RuntimeError(
                f"Ticket '{ticket_label}' must define only '{field_name}' in ground_truth."
            )

        if truth[field_name] in (None, ""):
            raise RuntimeError(
                f"Ticket '{ticket_label}' has missing value for ground_truth['{field_name}']."
            )
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
validate_ticket_dataset(TICKETS)
|
| 70 |
TICKETS_BY_ID = {ticket["incident_id"]: ticket for ticket in TICKETS}
|
| 71 |
|
| 72 |
|
|
|
|
| 129 |
self._validate_action(action)
|
| 130 |
|
| 131 |
task_type = self.current_ticket["task_type"]
|
| 132 |
+
ground_truth, ground_truth_value = self._validated_ground_truth()
|
| 133 |
grader_fn = GRADERS[task_type]
|
| 134 |
reward_value, reward_reason = grader_fn(action, ground_truth)
|
| 135 |
|
| 136 |
agent_answer = action.selected_value() or "NONE"
|
| 137 |
selected_field = action.selected_field() or "NONE"
|
|
|
|
| 138 |
|
| 139 |
self.step_count += 1
|
| 140 |
self.last_reward = reward_value
|
|
|
|
| 199 |
if not pool:
|
| 200 |
raise ValueError(f"No tickets found for task_type: {task_type}")
|
| 201 |
|
| 202 |
+
effective_seed = seed if seed is not None else DEFAULT_RESET_SEED
|
| 203 |
+
chooser = random.Random(effective_seed)
|
| 204 |
return chooser.choice(pool)
|
| 205 |
|
| 206 |
def _task_spec(self) -> dict:
|
|
|
|
| 237 |
f"Task '{self.current_ticket['task_type']}' expects field '{expected_field}', "
|
| 238 |
f"but got '{next(iter(populated))}'."
|
| 239 |
)
|
| 240 |
+
|
| 241 |
+
def _validated_ground_truth(self) -> tuple[dict, str]:
|
| 242 |
+
if self.current_ticket is None:
|
| 243 |
+
raise RuntimeError("No active episode. Call reset() first.")
|
| 244 |
+
|
| 245 |
+
incident_id = self.current_ticket["incident_id"]
|
| 246 |
+
expected_field = self._task_spec()["expected_field"]
|
| 247 |
+
ground_truth = self.current_ticket.get("ground_truth")
|
| 248 |
+
if not isinstance(ground_truth, dict) or not ground_truth:
|
| 249 |
+
raise RuntimeError(
|
| 250 |
+
f"Ticket '{incident_id}' has empty ground_truth. This is a dataset integrity error."
|
| 251 |
+
)
|
| 252 |
+
if expected_field not in ground_truth or ground_truth[expected_field] in (None, ""):
|
| 253 |
+
raise RuntimeError(
|
| 254 |
+
f"Ticket '{incident_id}' is missing ground_truth['{expected_field}']."
|
| 255 |
+
)
|
| 256 |
+
return ground_truth, str(ground_truth[expected_field])
|
graders.py
CHANGED
|
@@ -1,6 +1,15 @@
|
|
| 1 |
from models import IncidentAction
|
| 2 |
|
| 3 |
_SEV_ORDER = {"SEV1": 0, "SEV2": 1, "SEV3": 2}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
_TASK2_RELATED_GROUPS = [
|
| 5 |
{"DATABASE", "APPLICATION"},
|
| 6 |
{"NETWORK", "INFRASTRUCTURE"},
|
|
|
|
| 1 |
from models import IncidentAction
|
| 2 |
|
| 3 |
_SEV_ORDER = {"SEV1": 0, "SEV2": 1, "SEV3": 2}
|
| 4 |
+
# Related-domain partial credit is intentionally conservative.
|
| 5 |
+
# DATABASE <-> APPLICATION captures incidents where app bugs manifest as
|
| 6 |
+
# database saturation and vice versa.
|
| 7 |
+
# NETWORK <-> INFRASTRUCTURE captures physical or platform-layer correlation.
|
| 8 |
+
# NETWORK <-> THIRD_PARTY captures dependency outages that resemble network loss.
|
| 9 |
+
# INFRASTRUCTURE <-> THIRD_PARTY captures external services failing through shared
|
| 10 |
+
# platform primitives.
|
| 11 |
+
# APPLICATION <-> THIRD_PARTY is intentionally not included because we treat
|
| 12 |
+
# product-code failures and vendor degradation as materially different diagnoses.
|
| 13 |
_TASK2_RELATED_GROUPS = [
|
| 14 |
{"DATABASE", "APPLICATION"},
|
| 15 |
{"NETWORK", "INFRASTRUCTURE"},
|
inference.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
import re
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Any, Dict, List, Optional
|
| 6 |
|
|
@@ -24,7 +25,7 @@ ENV_URL = os.environ.get("ENV_URL") or "http://localhost:7860"
|
|
| 24 |
BENCHMARK = "incident-triage-env"
|
| 25 |
MAX_TOKENS = 300
|
| 26 |
TEMPERATURE = 0.0
|
| 27 |
-
OUTPUT_PATH = Path("outputs/baseline_scores.json")
|
| 28 |
|
| 29 |
SYSTEM_PROMPT = """You are an expert SRE triaging production incidents.
|
| 30 |
You will receive an incident alert, structured context, and the expected output field.
|
|
@@ -377,7 +378,10 @@ def run_episode(
|
|
| 377 |
return episode_result
|
| 378 |
|
| 379 |
|
| 380 |
-
def write_results(
|
|
|
|
|
|
|
|
|
|
| 381 |
grouped: Dict[str, List[float]] = {}
|
| 382 |
for result in results:
|
| 383 |
grouped.setdefault(result["task_type"], []).append(result.get("score", 0.0))
|
|
@@ -397,16 +401,25 @@ def write_results(results: List[Dict[str, Any]]) -> None:
|
|
| 397 |
"results": results,
|
| 398 |
}
|
| 399 |
|
| 400 |
-
|
| 401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
|
| 403 |
|
| 404 |
def main() -> None:
|
| 405 |
transport = build_transport()
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
|
|
|
|
|
|
| 410 |
|
| 411 |
|
| 412 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
+
import sys
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Any, Dict, List, Optional
|
| 7 |
|
|
|
|
| 25 |
BENCHMARK = "incident-triage-env"
|
| 26 |
MAX_TOKENS = 300
|
| 27 |
TEMPERATURE = 0.0
|
| 28 |
+
OUTPUT_PATH = Path(os.environ.get("OUTPUT_PATH") or "/tmp/outputs/baseline_scores.json")
|
| 29 |
|
| 30 |
SYSTEM_PROMPT = """You are an expert SRE triaging production incidents.
|
| 31 |
You will receive an incident alert, structured context, and the expected output field.
|
|
|
|
| 378 |
return episode_result
|
| 379 |
|
| 380 |
|
| 381 |
+
def write_results(
|
| 382 |
+
results: List[Dict[str, Any]],
|
| 383 |
+
output_path: Path = OUTPUT_PATH,
|
| 384 |
+
) -> None:
|
| 385 |
grouped: Dict[str, List[float]] = {}
|
| 386 |
for result in results:
|
| 387 |
grouped.setdefault(result["task_type"], []).append(result.get("score", 0.0))
|
|
|
|
| 401 |
"results": results,
|
| 402 |
}
|
| 403 |
|
| 404 |
+
try:
|
| 405 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 406 |
+
output_path.write_text(json.dumps(summary, indent=2))
|
| 407 |
+
except (PermissionError, OSError) as exc:
|
| 408 |
+
print(
|
| 409 |
+
f"[WARN] Could not write results file to {output_path}: {exc}. Scores were still emitted to stdout.",
|
| 410 |
+
file=sys.stderr,
|
| 411 |
+
flush=True,
|
| 412 |
+
)
|
| 413 |
|
| 414 |
|
| 415 |
def main() -> None:
    """Run one episode per ticket in ``TICKETS`` and write the collected results.

    The transport returned by ``build_transport()`` is always closed, even if
    creating the model client, running an episode, or writing results raises.
    """
    transport = build_transport()
    try:
        model_client = create_model_client()
        results = []
        for ticket in TICKETS:
            results.append(run_episode(transport, model_client, ticket))
        write_results(results)
    finally:
        transport.close()
|
| 423 |
|
| 424 |
|
| 425 |
if __name__ == "__main__":
|
openenv.yaml
CHANGED
|
@@ -101,3 +101,5 @@ reproducibility:
|
|
| 101 |
max_steps_per_episode: 1
|
| 102 |
dataset_order: fixed TICKETS list order in incidents.py
|
| 103 |
baseline_selection: deterministic ticket_id-driven evaluation across all tickets
|
|
|
|
|
|
|
|
|
| 101 |
max_steps_per_episode: 1
|
| 102 |
dataset_order: fixed TICKETS list order in incidents.py
|
| 103 |
baseline_selection: deterministic ticket_id-driven evaluation across all tickets
|
| 104 |
+
default_reset_seed: 42
|
| 105 |
+
reset_without_ticket_id: deterministic fixed-seed selection within the requested task pool
|
tests/test_env.py
CHANGED
|
@@ -3,6 +3,8 @@ import unittest
|
|
| 3 |
from fastapi.testclient import TestClient
|
| 4 |
|
| 5 |
from app import app, completed_states, sessions
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
class IncidentEnvApiTests(unittest.TestCase):
|
|
@@ -33,6 +35,12 @@ class IncidentEnvApiTests(unittest.TestCase):
|
|
| 33 |
self.assertEqual(mcp_body["jsonrpc"], "2.0")
|
| 34 |
self.assertEqual(mcp_body["id"], 1)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def test_tickets_endpoint_returns_safe_ticket_inventory(self) -> None:
|
| 37 |
response = self.client.get("/tickets")
|
| 38 |
self.assertEqual(response.status_code, 200)
|
|
@@ -78,6 +86,17 @@ class IncidentEnvApiTests(unittest.TestCase):
|
|
| 78 |
self.assertIn("session_id", body["info"])
|
| 79 |
self.assertEqual(body["info"]["state"]["status"], "awaiting_action")
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
def test_step_completes_episode_and_state_endpoint_reflects_completion(self) -> None:
|
| 82 |
reset_response = self.client.post(
|
| 83 |
"/reset",
|
|
@@ -140,6 +159,63 @@ class IncidentEnvApiTests(unittest.TestCase):
|
|
| 140 |
self.assertEqual(step_response.status_code, 400)
|
| 141 |
self.assertIn("does not match", step_response.json()["detail"])
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
if __name__ == "__main__":
|
| 145 |
unittest.main()
|
|
|
|
| 3 |
from fastapi.testclient import TestClient
|
| 4 |
|
| 5 |
from app import app, completed_states, sessions
|
| 6 |
+
from environment import IncidentEnv, validate_ticket_dataset
|
| 7 |
+
from models import IncidentAction, IncidentState, TaskType
|
| 8 |
|
| 9 |
|
| 10 |
class IncidentEnvApiTests(unittest.TestCase):
|
|
|
|
| 35 |
self.assertEqual(mcp_body["jsonrpc"], "2.0")
|
| 36 |
self.assertEqual(mcp_body["id"], 1)
|
| 37 |
|
| 38 |
+
grader_response = self.client.get("/grader")
|
| 39 |
+
self.assertEqual(grader_response.status_code, 200)
|
| 40 |
+
grader_body = grader_response.json()
|
| 41 |
+
self.assertIn("notes", grader_body)
|
| 42 |
+
self.assertIn("task2", grader_body["notes"])
|
| 43 |
+
|
| 44 |
def test_tickets_endpoint_returns_safe_ticket_inventory(self) -> None:
|
| 45 |
response = self.client.get("/tickets")
|
| 46 |
self.assertEqual(response.status_code, 200)
|
|
|
|
| 86 |
self.assertIn("session_id", body["info"])
|
| 87 |
self.assertEqual(body["info"]["state"]["status"], "awaiting_action")
|
| 88 |
|
| 89 |
+
def test_reset_without_seed_is_deterministic_for_same_task(self) -> None:
    """Two seedless ``/reset`` calls for the same task must pick the same incident."""
    payload = {"task_type": "task2"}
    # Issue both requests before asserting anything, as separate resets.
    responses = [self.client.post("/reset", json=payload) for _ in range(2)]

    for response in responses:
        self.assertEqual(response.status_code, 200)

    first_id, second_id = (
        response.json()["observation"]["incident_id"] for response in responses
    )
    self.assertEqual(first_id, second_id)
|
| 99 |
+
|
| 100 |
def test_step_completes_episode_and_state_endpoint_reflects_completion(self) -> None:
|
| 101 |
reset_response = self.client.post(
|
| 102 |
"/reset",
|
|
|
|
| 159 |
self.assertEqual(step_response.status_code, 400)
|
| 160 |
self.assertIn("does not match", step_response.json()["detail"])
|
| 161 |
|
| 162 |
+
def test_dataset_validation_rejects_empty_ground_truth(self) -> None:
    """``validate_ticket_dataset`` must reject a ticket whose ground truth is ``{}``."""
    broken_ticket = dict(
        incident_id="INC-BAD",
        task_type="task1",
        alert_text="Broken test ticket",
        context={},
        ground_truth={},
    )

    with self.assertRaisesRegex(RuntimeError, "empty ground_truth"):
        validate_ticket_dataset([broken_ticket])
|
| 175 |
+
|
| 176 |
+
def test_step_raises_clear_dataset_error_for_invalid_ground_truth(self) -> None:
    """``step`` on a ticket with empty ground truth must raise the integrity error."""
    env = IncidentEnv()
    # Inject a malformed ticket directly instead of going through reset().
    env.current_ticket = dict(
        incident_id="INC-BAD",
        task_type="task1",
        alert_text="Broken test ticket",
        context={},
        ground_truth={},
    )
    env.episode_id = "episode-bad"

    action_fields = {
        "incident_id": "INC-BAD",
        "task_type": "task1",
        "severity": "SEV1",
    }
    with self.assertRaisesRegex(RuntimeError, "dataset integrity error"):
        env.step(IncidentAction(**action_fields))
|
| 195 |
+
|
| 196 |
+
def test_lifespan_shutdown_clears_session_stores(self) -> None:
    """Leaving the ``TestClient`` context must empty both session stores."""
    sessions["active-session"] = IncidentEnv()
    finished_state = IncidentState(
        episode_id="episode-1",
        session_id="done-session",
        step_count=1,
        max_steps=1,
        total_reward=1.0,
        done=True,
        incident_id="INC-001",
        task_type=TaskType.TASK1,
        difficulty="easy",
        status="completed",
        last_reward=1.0,
    )
    completed_states["done-session"] = finished_state

    # Using the client as a context manager runs the app's lifespan, so the
    # shutdown half executes when the block exits.
    with TestClient(app) as client:
        health = client.get("/health")
        self.assertEqual(health.status_code, 200)

    for store in (sessions, completed_states):
        self.assertEqual(store, {})
|
| 218 |
+
|
| 219 |
|
| 220 |
if __name__ == "__main__":
|
| 221 |
unittest.main()
|
tests/test_inference.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import tempfile
|
| 3 |
+
import unittest
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from inference import write_results
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class InferenceOutputTests(unittest.TestCase):
    """Tests for how ``write_results`` persists the score summary."""

    def test_write_results_writes_summary_to_configured_path(self) -> None:
        """The summary JSON is written to the given path, creating parent dirs."""
        results = [
            {"incident_id": "INC-001", "task_type": "task1", "score": 1.0, "success": True},
            {"incident_id": "INC-002", "task_type": "task2", "score": 0.5, "success": False},
        ]

        with tempfile.TemporaryDirectory() as temp_dir:
            # "nested" does not exist yet; write_results is expected to create it.
            output_path = Path(temp_dir) / "nested" / "baseline_scores.json"
            write_results(results, output_path=output_path)

            self.assertTrue(output_path.exists())
            payload = json.loads(output_path.read_text())
            self.assertEqual(payload["episodes"], 2)
            self.assertAlmostEqual(payload["average_score"], 0.75)
            self.assertEqual(payload["by_task"]["task1"]["average_score"], 1.0)
            self.assertEqual(payload["by_task"]["task2"]["average_score"], 0.5)

    def test_write_results_tolerates_unwritable_path(self) -> None:
        """An unwritable target must produce a warning on stderr, not an exception."""
        import contextlib
        import io

        results = [
            {"incident_id": "INC-001", "task_type": "task1", "score": 1.0, "success": True},
        ]

        with tempfile.TemporaryDirectory() as temp_dir:
            # A regular file where the parent directory should be makes the
            # directory creation inside write_results fail with an OSError.
            blocked_parent = Path(temp_dir) / "blocked"
            blocked_parent.write_text("not-a-directory")
            blocked_path = blocked_parent / "baseline_scores.json"

            captured = io.StringIO()
            with contextlib.redirect_stderr(captured):
                write_results(results, output_path=blocked_path)

            # The blocking file is untouched, nothing was written, and the
            # failure was reported instead of being silently dropped.
            self.assertEqual(blocked_parent.read_text(), "not-a-directory")
            self.assertFalse(blocked_path.exists())
            self.assertIn("[WARN]", captured.getvalue())
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
unittest.main()
|