Spaces:
Running
Running
| import random | |
| import uuid | |
| from incidents import TICKETS | |
| from graders import GRADERS | |
| from models import ( | |
| IncidentAction, | |
| IncidentObservation, | |
| IncidentReward, | |
| IncidentState, | |
| RecommendedAction, | |
| RootCauseCategory, | |
| SeverityLevel, | |
| StepResult, | |
| TaskType, | |
| ) | |
| TASK_SPECS = { | |
| TaskType.TASK1: { | |
| "name": "Severity Classification", | |
| "difficulty": "easy", | |
| "expected_field": "severity", | |
| "allowed_values": [item.value for item in SeverityLevel], | |
| "description": "Classify the severity of the incident using blast radius, user impact, and business risk.", | |
| }, | |
| TaskType.TASK2: { | |
| "name": "Root Cause Classification", | |
| "difficulty": "medium", | |
| "expected_field": "root_cause", | |
| "allowed_values": [item.value for item in RootCauseCategory], | |
| "description": "Identify the most likely failure domain from the incident evidence.", | |
| }, | |
| TaskType.TASK3: { | |
| "name": "Recommended Action", | |
| "difficulty": "hard", | |
| "expected_field": "action", | |
| "allowed_values": [item.value for item in RecommendedAction], | |
| "description": "Choose the best immediate operational response for stabilizing the incident.", | |
| }, | |
| } | |
| DEFAULT_RESET_SEED = 42 | |
| INITIAL_REWARD = 0.01 | |
| def validate_ticket_dataset(tickets: list[dict]) -> None: | |
| for ticket in tickets: | |
| incident_id = ticket.get("incident_id", "<unknown>") | |
| task_type_raw = ticket.get("task_type") | |
| try: | |
| task_type = TaskType(task_type_raw) | |
| except ValueError as exc: | |
| raise RuntimeError( | |
| f"Ticket '{incident_id}' has unsupported task_type '{task_type_raw}'." | |
| ) from exc | |
| expected_field = TASK_SPECS[task_type]["expected_field"] | |
| ground_truth = ticket.get("ground_truth") | |
| if not isinstance(ground_truth, dict) or not ground_truth: | |
| raise RuntimeError(f"Ticket '{incident_id}' has empty ground_truth.") | |
| if set(ground_truth.keys()) != {expected_field}: | |
| raise RuntimeError( | |
| f"Ticket '{incident_id}' must define only '{expected_field}' in ground_truth." | |
| ) | |
| if ground_truth.get(expected_field) in (None, ""): | |
| raise RuntimeError( | |
| f"Ticket '{incident_id}' has missing value for ground_truth['{expected_field}']." | |
| ) | |
| validate_ticket_dataset(TICKETS) | |
| TICKETS_BY_ID = {ticket["incident_id"]: ticket for ticket in TICKETS} | |
| class IncidentEnv: | |
| def __init__(self): | |
| self.current_ticket = None | |
| self.episode_id = "" | |
| self.step_count = 0 | |
| self.max_steps = 1 | |
| self.total_reward = INITIAL_REWARD | |
| self.done = False | |
| self.last_reward = INITIAL_REWARD | |
| self.last_action_summary = None | |
| def reset( | |
| self, | |
| task_type: TaskType | str | None = None, | |
| ticket_id: str | None = None, | |
| seed: int | None = None, | |
| ) -> StepResult: | |
| normalized_task = TaskType(task_type) if task_type else None | |
| self.current_ticket = self._select_ticket(normalized_task, ticket_id, seed) | |
| self.episode_id = str(uuid.uuid4()) | |
| self.step_count = 0 | |
| self.total_reward = INITIAL_REWARD | |
| self.done = False | |
| self.last_reward = INITIAL_REWARD | |
| self.last_action_summary = None | |
| return StepResult( | |
| observation=self._build_observation(), | |
| reward=IncidentReward(value=INITIAL_REWARD, reason="Episode initialized and awaiting first action."), | |
| done=False, | |
| info={ | |
| "episode_id": self.episode_id, | |
| "task_name": self._task_spec()["name"], | |
| "difficulty": self._task_spec()["difficulty"], | |
| "max_steps": self.max_steps, | |
| }, | |
| ) | |
| def step(self, action: IncidentAction) -> StepResult: | |
| if self.current_ticket is None: | |
| raise RuntimeError("Call reset() before step()") | |
| if self.done: | |
| raise RuntimeError("Episode already completed. Call reset() to start a new one.") | |
| if action.incident_id != self.current_ticket["incident_id"]: | |
| raise ValueError( | |
| f"Action incident_id '{action.incident_id}' does not match " | |
| f"current ticket '{self.current_ticket['incident_id']}'" | |
| ) | |
| if action.task_type != TaskType(self.current_ticket["task_type"]): | |
| raise ValueError( | |
| f"Action task_type '{action.task_type.value}' does not match " | |
| f"current ticket task_type '{self.current_ticket['task_type']}'" | |
| ) | |
| self._validate_action(action) | |
| task_type = self.current_ticket["task_type"] | |
| ground_truth, ground_truth_value = self._validated_ground_truth() | |
| grader_fn = GRADERS[task_type] | |
| reward_value, reward_reason = grader_fn(action, ground_truth) | |
| agent_answer = action.selected_value() or "NONE" | |
| selected_field = action.selected_field() or "NONE" | |
| self.step_count += 1 | |
| self.last_reward = reward_value | |
| self.total_reward = reward_value | |
| self.done = self.step_count >= self.max_steps | |
| self.last_action_summary = f"Submitted {selected_field}={agent_answer}" | |
| return StepResult( | |
| observation=self._build_observation(), | |
| reward=IncidentReward(value=reward_value, reason=reward_reason), | |
| done=self.done, | |
| info={ | |
| "episode_id": self.episode_id, | |
| "task_name": self._task_spec()["name"], | |
| "difficulty": self._task_spec()["difficulty"], | |
| "correct": agent_answer == ground_truth_value, | |
| "ground_truth": ground_truth_value, | |
| "agent_answer": agent_answer, | |
| "selected_field": selected_field, | |
| "max_steps": self.max_steps, | |
| }, | |
| ) | |
| def state(self, session_id: str | None = None) -> IncidentState: | |
| if self.current_ticket is None: | |
| raise RuntimeError("No active episode. Call reset() first.") | |
| return IncidentState( | |
| episode_id=self.episode_id, | |
| session_id=session_id, | |
| step_count=self.step_count, | |
| max_steps=self.max_steps, | |
| total_reward=self.total_reward, | |
| done=self.done, | |
| incident_id=self.current_ticket["incident_id"], | |
| task_type=TaskType(self.current_ticket["task_type"]), | |
| difficulty=self._task_spec()["difficulty"], | |
| status="completed" if self.done else "awaiting_action", | |
| last_reward=self.last_reward, | |
| ) | |
| def _select_ticket( | |
| self, | |
| task_type: TaskType | None = None, | |
| ticket_id: str | None = None, | |
| seed: int | None = None, | |
| ) -> dict: | |
| if ticket_id: | |
| ticket = TICKETS_BY_ID.get(ticket_id) | |
| if ticket is None: | |
| raise ValueError(f"No ticket found for ticket_id: {ticket_id}") | |
| if task_type and ticket["task_type"] != task_type.value: | |
| raise ValueError( | |
| f"Ticket '{ticket_id}' belongs to task_type '{ticket['task_type']}', " | |
| f"not '{task_type.value}'" | |
| ) | |
| return ticket | |
| pool = TICKETS | |
| if task_type: | |
| pool = [ticket for ticket in TICKETS if ticket["task_type"] == task_type.value] | |
| if not pool: | |
| raise ValueError(f"No tickets found for task_type: {task_type}") | |
| effective_seed = seed if seed is not None else DEFAULT_RESET_SEED | |
| chooser = random.Random(effective_seed) | |
| return chooser.choice(pool) | |
| def _task_spec(self) -> dict: | |
| if self.current_ticket is None: | |
| raise RuntimeError("No active episode. Call reset() first.") | |
| return TASK_SPECS[TaskType(self.current_ticket["task_type"])] | |
| def _build_observation(self) -> IncidentObservation: | |
| spec = self._task_spec() | |
| return IncidentObservation( | |
| incident_id=self.current_ticket["incident_id"], | |
| task_type=TaskType(self.current_ticket["task_type"]), | |
| difficulty=spec["difficulty"], | |
| task_description=spec["description"], | |
| alert_text=self.current_ticket["alert_text"], | |
| context=self.current_ticket["context"], | |
| expected_field=spec["expected_field"], | |
| allowed_values=spec["allowed_values"], | |
| step_count=self.step_count, | |
| max_steps=self.max_steps, | |
| last_action_summary=self.last_action_summary, | |
| last_reward=self.last_reward, | |
| episode_status="completed" if self.done else "awaiting_action", | |
| ) | |
| def _validate_action(self, action: IncidentAction) -> None: | |
| populated = action.populated_fields() | |
| if len(populated) != 1: | |
| raise ValueError("Action must populate exactly one of severity, root_cause, or action.") | |
| expected_field = self._task_spec()["expected_field"] | |
| if expected_field not in populated: | |
| raise ValueError( | |
| f"Task '{self.current_ticket['task_type']}' expects field '{expected_field}', " | |
| f"but got '{next(iter(populated))}'." | |
| ) | |
| def _validated_ground_truth(self) -> tuple[dict, str]: | |
| if self.current_ticket is None: | |
| raise RuntimeError("No active episode. Call reset() first.") | |
| incident_id = self.current_ticket["incident_id"] | |
| expected_field = self._task_spec()["expected_field"] | |
| ground_truth = self.current_ticket.get("ground_truth") | |
| if not isinstance(ground_truth, dict) or not ground_truth: | |
| raise RuntimeError( | |
| f"Ticket '{incident_id}' has empty ground_truth. This is a dataset integrity error." | |
| ) | |
| if expected_field not in ground_truth or ground_truth[expected_field] in (None, ""): | |
| raise RuntimeError( | |
| f"Ticket '{incident_id}' is missing ground_truth['{expected_field}']." | |
| ) | |
| return ground_truth, str(ground_truth[expected_field]) | |