Spaces:
Running
Running
| from contextlib import asynccontextmanager | |
| import uuid | |
| from collections import Counter | |
| from pathlib import Path | |
| import sys | |
| from threading import RLock | |
| from typing import Any | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from environment import IncidentEnv, TASK_SPECS | |
| from incidents import TICKETS | |
| from models import ( | |
| IncidentAction, | |
| IncidentObservation, | |
| IncidentReward, | |
| IncidentState, | |
| ResetRequest, | |
| StepResult, | |
| TaskType, | |
| ) | |
# In-memory session bookkeeping shared by the request handlers below.
# Capacity passed to evict_oldest() for each of the two session maps.
MAX_SESSIONS = 500
# session_id -> IncidentEnv for episodes that have not finished yet.
sessions: dict[str, IncidentEnv] = dict()
# session_id -> final IncidentState snapshot for episodes that have finished.
completed_states: dict[str, IncidentState] = dict()
# Re-entrant lock guarding both `sessions` and `completed_states`.
session_lock = RLock()
# Number of tickets per task-type value, computed once at import time.
task_counts = Counter([ticket["task_type"] for ticket in TICKETS])
def emit_lifecycle_event(event: str, **fields: Any) -> None:
    """Print an application lifecycle event as a single line on stderr.

    The line reads ``[EVENT] key1=value1 key2=value2`` with the keyword
    arguments rendered in call order. stderr is flushed immediately so the
    message is not held back by buffering.
    """
    pairs = (f"{name}={val}" for name, val in fields.items())
    line = f"[{event}] " + " ".join(pairs)
    print(line, file=sys.stderr, flush=True)
# FastAPI's `lifespan=` argument expects an async context manager; Starlette
# only tolerates bare async-generator lifespans with a deprecation warning.
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Startup/shutdown hook for the application.

    Emits a STARTUP event, then yields while the app serves requests. On
    shutdown (the ``finally`` also runs if the app exits with an error) it
    discards every live and completed session. It then emits a SHUTDOWN event
    with the number of sessions that were dropped.
    """
    emit_lifecycle_event("STARTUP", status="ready")
    try:
        yield
    finally:
        with session_lock:
            active_count = len(sessions)
            completed_count = len(completed_states)
            sessions.clear()
            completed_states.clear()
        # Counts were captured under the lock; logging does not need to hold it.
        emit_lifecycle_event(
            "SHUTDOWN",
            active_sessions=active_count,
            completed_sessions=completed_count,
            status="cleared",
        )
app = FastAPI(lifespan=lifespan, title="Incident Triage Environment")

# The browser UI lives in a `ui/` directory next to this module.
UI_DIR = Path(__file__).parent.joinpath("ui")
ASSETS_DIR = UI_DIR.joinpath("assets")
# Serve everything under ui/assets at the /assets URL prefix.
app.mount("/assets", StaticFiles(directory=ASSETS_DIR), name="assets")
def log_event(event: str, **fields: Any) -> None:
    """Print a request-handling event as one ``[EVENT] key=value ...`` line on stderr.

    Delegates to ``emit_lifecycle_event`` so request events and lifecycle
    events share a single formatter and cannot drift apart.

    Args:
        event: Tag printed inside the square brackets.
        **fields: Rendered as ``key=value`` pairs in call order.
    """
    emit_lifecycle_event(event, **fields)
def evict_oldest(mapping: dict[str, Any], max_size: int) -> None:
    """Drop the earliest-inserted entries until ``mapping`` has room for one more.

    Entries are removed in insertion order (dicts preserve it) until
    ``len(mapping) < max_size``. A single insert afterwards therefore keeps the
    size at or below ``max_size``. A non-positive ``max_size`` empties the
    mapping.

    Args:
        mapping: Dict that is mutated in place.
        max_size: Capacity the mapping must respect after one more insert.
    """
    # Test emptiness directly instead of using ``None`` as a "no key" sentinel.
    # The sentinel stopped eviction early whenever a key was itself ``None``.
    while mapping and len(mapping) >= max_size:
        del mapping[next(iter(mapping))]
def enrich_step_result(result: StepResult, session_id: str, state: IncidentState) -> StepResult:
    """Return a copy of ``result`` whose ``info`` also carries session context.

    All existing ``info`` entries are kept. ``session_id`` and
    ``state.model_dump()`` are added under the keys ``"session_id"`` and
    ``"state"``, overwriting same-named entries. The copy is produced with
    ``model_copy`` and a freshly built ``info`` dict.
    """
    info = dict(result.info)
    info["session_id"] = session_id
    info["state"] = state.model_dump()
    return result.model_copy(update={"info": info})
def home_page():
    """Serve the UI landing page, ``ui/index.html``."""
    return FileResponse(path=UI_DIR.joinpath("index.html"))
def status_page():
    """Serve the status page, ``ui/status.html``."""
    return FileResponse(path=UI_DIR.joinpath("status.html"))
def playground_page():
    """Serve the interactive playground page, ``ui/playground.html``."""
    return FileResponse(path=UI_DIR.joinpath("playground.html"))
def api_page():
    """Serve the API documentation page, ``ui/api.html``."""
    return FileResponse(path=UI_DIR.joinpath("api.html"))
def health():
    """Return the static liveness payload; it always reports healthy."""
    return dict(status="healthy")
def metadata():
    """Describe the environment: name, description, per-task specs and ticket total.

    The per-task catalogue is taken from ``get_tasks()`` instead of repeating
    the same comprehension. The two payloads therefore always agree.

    Returns:
        dict with keys ``name``, ``description``, ``tasks`` (task-type value ->
        spec summary with ticket count) and ``total_tickets``.
    """
    return {
        "name": "incident-triage-env",
        "description": "Production incident triage environment for severity, root-cause, and remediation decisions.",
        "tasks": get_tasks()["tasks"],
        "total_tickets": len(TICKETS),
    }
def schema():
    """Return the JSON Schema of each public model, keyed by its role in the API."""
    models_by_role = {
        "action": IncidentAction,
        "observation": IncidentObservation,
        "reward": IncidentReward,
        "state": IncidentState,
        "step_result": StepResult,
    }
    return {role: model.model_json_schema() for role, model in models_by_role.items()}
def get_tasks():
    """List every task type with its spec fields and the number of tickets it has.

    Returns:
        ``{"tasks": {task_type.value: {...}}}`` built from ``TASK_SPECS`` and
        the precomputed ``task_counts``.
    """
    catalogue = {}
    for task_type, spec in TASK_SPECS.items():
        key = task_type.value
        catalogue[key] = {
            "name": spec["name"],
            "difficulty": spec["difficulty"],
            "expected_field": spec["expected_field"],
            "allowed_values": spec["allowed_values"],
            "ticket_count": task_counts[key],
        }
    return {"tasks": catalogue}
def get_tickets():
    """Summarise every ticket in ``TICKETS`` together with the total count.

    Each summary carries the ticket's id and task type, a few fields of its
    task spec, and ``alert_text[:120]`` as a short preview.
    """

    def summarize(ticket):
        # Look up the spec via the TaskType enum, as the spec table is keyed by it.
        spec = TASK_SPECS[TaskType(ticket["task_type"])]
        return {
            "incident_id": ticket["incident_id"],
            "task_type": ticket["task_type"],
            "difficulty": spec["difficulty"],
            "task_name": spec["name"],
            "expected_field": spec["expected_field"],
            "alert_preview": ticket["alert_text"][:120],
        }

    summaries = [summarize(ticket) for ticket in TICKETS]
    return {"tickets": summaries, "count": len(summaries)}
def reset(reset_request: ResetRequest | None = None):
    """Start a new episode in a fresh session and return its initial StepResult.

    A new ``IncidentEnv`` is reset with the request's ``task_type``,
    ``ticket_id`` and ``seed``; ``ResetRequest()`` defaults apply when no
    request is given. On success the env is stored under a new UUID4 session
    id. Before that, the oldest live and completed sessions are evicted so
    neither map exceeds ``MAX_SESSIONS``. The returned result's ``info``
    carries the session id and state.

    Raises:
        HTTPException: 400 when ``IncidentEnv.reset`` raises ``ValueError``.
    """
    request = reset_request if reset_request is not None else ResetRequest()
    session_id = str(uuid.uuid4())
    env = IncidentEnv()
    try:
        result = env.reset(
            task_type=request.task_type,
            ticket_id=request.ticket_id,
            seed=request.seed,
        )
    except ValueError as e:
        log_event(
            "RESET_ERROR",
            task_type=request.task_type.value if request.task_type else "any",
            ticket_id=request.ticket_id or "random",
            error=str(e),
        )
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=400, detail=str(e)) from e

    with session_lock:
        evict_oldest(sessions, MAX_SESSIONS)
        evict_oldest(completed_states, MAX_SESSIONS)
        sessions[session_id] = env
        result = enrich_step_result(result, session_id=session_id, state=env.state(session_id=session_id))

    log_event(
        "RESET",
        session_id=session_id,
        incident_id=result.observation.incident_id,
        task_type=result.observation.task_type.value,
        expected_field=result.observation.expected_field,
    )
    return result
def step(action: IncidentAction, session_id: str):
    """Apply ``action`` to the session's episode and return the enriched StepResult.

    The returned result's ``info`` carries the session id and current state.
    When the step finishes the episode, the final state is moved into
    ``completed_states`` and the live env is dropped from ``sessions``.

    Raises:
        HTTPException: 400 if the episode already completed, or if
            ``env.step`` raises ``RuntimeError``/``ValueError``;
            404 if ``session_id`` is unknown.
    """
    with session_lock:
        env = sessions.get(session_id)
        # Explicit None check: truthiness of an arbitrary env object is not a
        # reliable "missing" test.
        if env is None:
            if session_id in completed_states:
                log_event("STEP_ERROR", session_id=session_id, error="episode_already_completed")
                raise HTTPException(status_code=400, detail="Episode already completed. Call reset() to start a new one.")
            log_event("STEP_ERROR", session_id=session_id, error="session_not_found")
            raise HTTPException(status_code=404, detail="Session not found. Call /reset first.")
        try:
            result = env.step(action)
        except (RuntimeError, ValueError) as e:
            log_event("STEP_ERROR", session_id=session_id, incident_id=action.incident_id, error=str(e))
            raise HTTPException(status_code=400, detail=str(e)) from e
        current_state = env.state(session_id=session_id)
        result = enrich_step_result(result, session_id=session_id, state=current_state)
        if result.done:
            # Bound completed_states here too. Otherwise completions between
            # resets can grow it well past MAX_SESSIONS.
            evict_oldest(completed_states, MAX_SESSIONS)
            completed_states[session_id] = current_state
            sessions.pop(session_id, None)

    log_event(
        "STEP",
        session_id=session_id,
        incident_id=action.incident_id,
        task_type=action.task_type.value,
        answer=action.selected_value() or "NONE",
        reward=result.reward.value,
        done=str(result.done).lower(),
    )
    return result
def state(session_id: str):
    """Return the state of ``session_id``, whether its episode is live or completed.

    Completed episodes are served from the ``completed_states`` snapshot. Live
    ones are queried via ``env.state``.

    Raises:
        HTTPException: 404 if the session is unknown, or if ``env.state``
            raises ``RuntimeError``.
    """
    with session_lock:
        env = sessions.get(session_id)
        if env is None:
            completed_state = completed_states.get(session_id)
            if completed_state is None:
                log_event("STATE_ERROR", session_id=session_id, error="no_active_session")
                raise HTTPException(status_code=404, detail="No active session.")
            log_event("STATE", session_id=session_id, incident_id=completed_state.incident_id, done=str(completed_state.done).lower())
            return completed_state
        # Keep the try narrow: only env.state() is expected to raise RuntimeError.
        try:
            current_state = env.state(session_id=session_id)
        except RuntimeError as e:
            log_event("STATE_ERROR", session_id=session_id, error=str(e))
            raise HTTPException(status_code=404, detail=str(e)) from e

    log_event("STATE", session_id=session_id, incident_id=current_state.incident_id, done=str(current_state.done).lower())
    return current_state
def get_grader_info():
    """Describe the grading scheme: overall scoring rules, per-task reward values and notes."""
    reward_table = {
        "task1": "exact=0.99, adjacent=0.5, far=0.01",
        "task2": "exact=0.99, related-domain=0.5, unknown=0.25, wrong=0.01",
        "task3": "exact=0.99, investigate fallback=0.4, related response=0.25, wrong=0.01",
    }
    # Rationale for which task2 root-cause categories earn partial credit.
    root_cause_notes = [
        "DATABASE and APPLICATION are treated as related because application faults often surface as database pressure and vice versa.",
        "NETWORK, INFRASTRUCTURE, and THIRD_PARTY share limited partial-credit bridges to reflect correlated outage signatures.",
        "APPLICATION and THIRD_PARTY are intentionally not treated as related because they imply different remediation ownership.",
    ]
    scoring_summary = (
        "task1: adjacent-severity partial credit; "
        "task2/task3: exact match plus conservative near-miss partial credit; "
        "all rewards remain strictly within (0, 1)"
    )
    return {
        "grading": "deterministic",
        "scoring": scoring_summary,
        "tasks": reward_table,
        "notes": {"task2": root_cause_notes},
    }
| def mcp(payload: dict[str, Any] | None = None): | |
| request = payload or {} | |
| method = request.get("method") | |
| rpc_id = request.get("id") | |
| if method == "ping": | |
| result: dict[str, Any] = {"status": "ok"} | |
| elif method == "tools/list": | |
| result = {"tools": []} | |
| else: | |
| result = { | |
| "status": "ok", | |
| "message": "Incident triage environment does not expose MCP tools.", | |
| } | |
| return {"jsonrpc": "2.0", "id": rpc_id, "result": result} | |