""" Core environment engine — implements reset/step/state for the SRE Incident Response env. """ import uuid from typing import Any, Dict, Optional, Set, Tuple from models import ( Action, ActionType, GraderResult, INVESTIGATION_ACTIONS, Observation, REMEDIATION_ACTIONS, RootCauseCategory, ServiceState, ServiceStatus, State, ) from env.scenario import IncidentScenario, RequiredFix from tasks import SCENARIOS from env.services import ( format_config_diff, format_deploy_history, format_dependencies, format_logs, format_metrics, format_runbook, format_traces, generate_alerts, ping_service, recompute_health, ) class Session: """Tracks the state of a single episode.""" def __init__(self, scenario: IncidentScenario, session_id: str): self.session_id = session_id self.scenario = scenario self.step_count = 0 self.done = False self.cumulative_reward = 0.0 # Mutable service state: {name: {status, version, replicas}} self.services: Dict[str, Dict[str, Any]] = {} for name, cfg in scenario.services.items(): self.services[name] = { "status": cfg.status, "version": cfg.version, "replicas": cfg.replicas, } # Track which root-cause services have been fixed self.fixed_services: Set[str] = set() # Build root-cause map: service_name -> fault_type self.root_cause_map: Dict[str, str] = {} for name, cfg in scenario.services.items(): if cfg.is_root_cause and cfg.fault_type: self.root_cause_map[name] = cfg.fault_type # Action history for grading self.actions: list[Action] = [] self.services_investigated: Set[str] = set() self.remediations_applied: list[Dict[str, Any]] = [] self.diagnosis: Optional[Action] = None class IncidentResponseEnv: """The SRE Incident Response OpenEnv environment.""" def __init__(self): self.sessions: Dict[str, Session] = {} def get_task_ids(self) -> list[str]: return list(SCENARIOS.keys()) def reset(self, task_id: str, seed: int = 0) -> Tuple[Observation, str]: """Start a new episode for the given task.""" if task_id not in SCENARIOS: raise ValueError(f"Unknown task_id: {task_id}. Available: {list(SCENARIOS.keys())}") scenario = SCENARIOS[task_id] session_id = str(uuid.uuid4())[:8] session = Session(scenario, session_id) self.sessions[session_id] = session # Build initial observation obs = self._build_observation(session, action_result=None) return obs, session_id def step(self, session_id: str, action: Action) -> Tuple[Observation, float, bool, Dict]: """Execute an action and return (observation, reward, done, info).""" session = self._get_session(session_id) if session.done: obs = self._build_observation(session, action_result="Episode already finished.") return obs, 0.0, True, {"error": "Episode already finished."} session.step_count += 1 session.actions.append(action) reward = 0.0 action_result = "" info: Dict[str, Any] = {} service_name = action.service scenario = session.scenario # Validate service name for actions that require it if action.action_type != ActionType.SUBMIT_DIAGNOSIS: if service_name and service_name not in scenario.services: action_result = f"Unknown service: '{service_name}'. Available: {list(scenario.services.keys())}" obs = self._build_observation(session, action_result=action_result) return obs, 0.0, False, {"error": action_result} if not service_name and action.action_type != ActionType.SUBMIT_DIAGNOSIS: action_result = "Action requires a 'service' parameter." obs = self._build_observation(session, action_result=action_result) return obs, 0.0, False, {"error": action_result} # ── Investigation actions ── if action.action_type in INVESTIGATION_ACTIONS: session.services_investigated.add(service_name) action_result = self._handle_investigation(session, action) # Small reward for investigating root cause services if service_name in scenario.root_cause_services: reward = 0.01 else: reward = 0.0 # ── Remediation actions ── elif action.action_type in REMEDIATION_ACTIONS: action_result, reward = self._handle_remediation(session, action) session.remediations_applied.append({ "action": action.action_type.value, "service": service_name, "target_version": action.target_version, "replicas": action.replicas, }) # ── Submit diagnosis ── elif action.action_type == ActionType.SUBMIT_DIAGNOSIS: session.diagnosis = action session.done = True grader_result = self._grade(session) reward = grader_result.score action_result = f"Diagnosis submitted. Score: {grader_result.score:.2f}" info["grader_result"] = grader_result.model_dump() session.cumulative_reward += reward # Check max steps if session.step_count >= scenario.max_steps and not session.done: session.done = True if session.diagnosis is None: # Auto-grade with whatever we have grader_result = self._grade(session) reward = grader_result.score info["grader_result"] = grader_result.model_dump() action_result += f"\n[MAX STEPS REACHED] Episode ended. Score: {grader_result.score:.2f}" obs = self._build_observation(session, action_result=action_result, reward=reward) obs.done = session.done if "grader_result" in info: obs.score = info["grader_result"]["score"] return obs, reward, session.done, info def state(self, session_id: str) -> State: """Return current episode state.""" session = self._get_session(session_id) return State( session_id=session.session_id, task_id=session.scenario.task_id, step_count=session.step_count, max_steps=session.scenario.max_steps, done=session.done, actions_taken=[a.action_type.value for a in session.actions], services_investigated=list(session.services_investigated), remediations_applied=[f"{r['action']}({r['service']})" for r in session.remediations_applied], cumulative_reward=round(session.cumulative_reward, 4), ) # ── Internal helpers ─────────────────────────────────────────────── def _get_session(self, session_id: str) -> Session: if session_id not in self.sessions: raise ValueError(f"Unknown session: {session_id}") return self.sessions[session_id] def _build_observation( self, session: Session, action_result: Optional[str], reward: float = 0.0, ) -> Observation: scenario = session.scenario svc_states = {} for name, data in session.services.items(): svc_states[name] = ServiceState( status=data["status"], version=data["version"], replicas=data["replicas"], ) alerts = generate_alerts( session.services, scenario.initial_alerts, session.fixed_services, ) return Observation( step_number=session.step_count, timestamp=f"2026-04-06T04:{session.step_count:02d}:00Z", services=svc_states, active_alerts=alerts, incident_summary=scenario.incident_summary if session.step_count == 0 else "", action_result=action_result, reward=round(reward, 4), done=session.done, ) def _handle_investigation(self, session: Session, action: Action) -> str: scenario = session.scenario svc = action.service if action.action_type == ActionType.READ_LOGS: logs = scenario.logs.get(svc, []) return format_logs(logs) elif action.action_type == ActionType.CHECK_METRICS: metrics = scenario.metrics.get(svc, []) return format_metrics(metrics) elif action.action_type == ActionType.PING_SERVICE: status = session.services[svc]["status"] return ping_service(status, svc) elif action.action_type == ActionType.CHECK_DEPENDENCIES: deps = scenario.dependencies.get(svc, []) dep_info = format_dependencies(deps) # Also show current health of dependencies dep_health = [] for d in deps: if d in session.services: dep_health.append(f" {d}: {session.services[d]['status'].value}") if dep_health: dep_info += "\n\nDependency health:\n" + "\n".join(dep_health) return dep_info elif action.action_type == ActionType.INSPECT_DEPLOY: deploys = scenario.deploy_history.get(svc, []) return format_deploy_history(deploys) elif action.action_type == ActionType.QUERY_TRACES: traces = scenario.traces.get(svc, []) return format_traces(traces) elif action.action_type == ActionType.CHECK_RUNBOOK: runbook = scenario.runbooks.get(svc, "") return format_runbook(runbook) elif action.action_type == ActionType.DIFF_CONFIG: configs = scenario.configs.get(svc, {}) return format_config_diff(configs) return f"No data available for {action.action_type.value} on {svc}." def _handle_remediation(self, session: Session, action: Action) -> Tuple[str, float]: scenario = session.scenario svc = action.service reward = 0.0 result = "" # Check if this remediation matches any required fix fix_matched = False for req_fix in scenario.required_fixes: if self._fix_matches(action, req_fix): fix_matched = True session.fixed_services.add(svc) reward = 0.05 break if action.action_type == ActionType.RESTART_SERVICE: if fix_matched: session.services[svc]["status"] = ServiceStatus.HEALTHY result = f"Service '{svc}' restarted successfully. Status: HEALTHY" else: # Restarting a non-root-cause service: no effect on the underlying issue current = session.services[svc]["status"] if current == ServiceStatus.DOWN and svc in session.root_cause_map: result = f"Service '{svc}' restarted but crashed again — underlying issue persists." elif current == ServiceStatus.HEALTHY: result = f"Service '{svc}' restarted. It was already healthy — no change." else: result = f"Service '{svc}' restarted. Status unchanged — issue is caused by an upstream dependency." reward = -0.05 elif action.action_type == ActionType.ROLLBACK_DEPLOY: if fix_matched: session.services[svc]["version"] = action.target_version or "" session.services[svc]["status"] = ServiceStatus.HEALTHY result = ( f"Service '{svc}' rolled back to {action.target_version}. " f"Pods restarting with previous version... Status: HEALTHY" ) else: current_version = session.services[svc]["version"] result = ( f"Rolled back '{svc}' to {action.target_version}, but this didn't resolve the issue. " f"Previous version was {current_version}." ) reward = -0.05 elif action.action_type == ActionType.SCALE_UP: replicas = action.replicas or 3 if fix_matched or (svc in scenario.root_cause_services): session.services[svc]["replicas"] = replicas session.fixed_services.add(svc) session.services[svc]["status"] = ServiceStatus.HEALTHY result = f"Service '{svc}' scaled to {replicas} replicas. Memory pressure alleviated. Status: HEALTHY" reward = 0.05 else: session.services[svc]["replicas"] = replicas result = f"Service '{svc}' scaled to {replicas} replicas. No effect on the underlying issue." reward = -0.05 elif action.action_type == ActionType.DRAIN_TRAFFIC: result = f"Traffic drained from '{svc}'. Service is no longer receiving requests." if svc not in scenario.root_cause_services: reward = -0.05 # Recompute health after remediation session.services = recompute_health( session.services, scenario.dependencies, session.fixed_services, session.root_cause_map, ) # Add post-remediation status summary still_broken = [ name for name, data in session.services.items() if data["status"] != ServiceStatus.HEALTHY ] if still_broken: result += f"\n\n[POST-REMEDIATION CHECK] Services still unhealthy: {', '.join(still_broken)}" else: result += "\n\n[POST-REMEDIATION CHECK] All services are now HEALTHY." return result, reward def _fix_matches(self, action: Action, req_fix: RequiredFix) -> bool: """Check if an action matches a required fix.""" if action.action_type.value != req_fix.action: return False if action.service != req_fix.service: return False if req_fix.target_version and action.target_version != req_fix.target_version: return False return True def _grade(self, session: Session) -> GraderResult: """Deterministic grading of the episode.""" from graders.grader import grade_episode return grade_episode(session)