Spaces:
Sleeping
Sleeping
| """ | |
| Core environment logic for OnCallEnv. | |
| Manages state transitions, command parsing, and reward computation. | |
| """ | |
| from __future__ import annotations | |
| import re | |
| from datetime import datetime | |
| from typing import Any | |
| from models import ( | |
| Action, Alert, EnvState, LogEntry, Observation, | |
| Reward, ServiceMetrics, StepResponse, | |
| ) | |
| from scenarios import ALL_SCENARIOS, Scenario, ServiceDef | |
| class OnCallEnvironment: | |
| """Simulates a production incident response scenario.""" | |
| def __init__(self): | |
| self._scenario: Scenario | None = None | |
| self._step: int = 0 | |
| self._done: bool = False | |
| self._actions_taken: list[str] = [] | |
| self._investigation_log: list[str] = [] # services investigated | |
| self._root_cause_identified: bool = False | |
| self._remediation_applied: bool = False | |
| self._remediation_correct: bool = False | |
| self._resolved: bool = False | |
| self._wrong_actions: int = 0 # destructive actions on wrong services | |
| self._remediation_step: int | None = None # step when correct remediation was applied | |
| self._multi_fixes_applied: list[dict] = [] # for multi-root-cause scenarios | |
| # ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def reset(self, task_id: str | None = None) -> Observation: | |
| """Reset environment to initial state for the given task.""" | |
| if task_id is None: | |
| task_id = "easy_memory_leak" | |
| if task_id not in ALL_SCENARIOS: | |
| raise ValueError(f"Unknown task: {task_id}. Available: {list(ALL_SCENARIOS)}") | |
| self._scenario = ALL_SCENARIOS[task_id] | |
| self._step = 0 | |
| self._done = False | |
| self._actions_taken = [] | |
| self._investigation_log = [] | |
| self._root_cause_identified = False | |
| self._remediation_applied = False | |
| self._remediation_correct = False | |
| self._resolved = False | |
| self._wrong_actions = 0 | |
| self._remediation_step = None | |
| self._multi_fixes_applied = [] | |
| return self._build_observation( | |
| last_action=None, | |
| last_action_result="Environment reset. You are the on-call engineer. Read the alerts and begin investigating.", | |
| last_action_error=False, | |
| ) | |
| def step(self, action: Action) -> StepResponse: | |
| """Execute one action and return the new observation + reward.""" | |
| if self._scenario is None: | |
| raise RuntimeError("Call reset() before step()") | |
| if self._done: | |
| return StepResponse( | |
| observation=self._build_observation( | |
| last_action=action.command, | |
| last_action_result="Episode already finished.", | |
| last_action_error=True, | |
| ), | |
| reward=self._compute_reward(), | |
| done=True, | |
| info={"reason": "already_done"}, | |
| ) | |
| self._step += 1 | |
| self._actions_taken.append(action.command) | |
| result, error = self._execute_command(action.command) | |
| # Check termination | |
| # After correct remediation, give agent 2 more steps to call mark_resolved | |
| if self._remediation_correct and self._remediation_step is None: | |
| self._remediation_step = self._step | |
| if self._resolved: | |
| self._done = True | |
| elif self._remediation_step and self._step >= self._remediation_step + 2: | |
| # Agent had 2 steps after remediation but didn't mark_resolved | |
| self._done = True | |
| elif self._step >= self._scenario.max_steps: | |
| self._done = True | |
| reward = self._compute_reward() | |
| obs = self._build_observation( | |
| last_action=action.command, | |
| last_action_result=result, | |
| last_action_error=error, | |
| ) | |
| info: dict[str, Any] = {} | |
| if self._done: | |
| info["reason"] = "resolved" if self._resolved else "max_steps" | |
| info["final_score"] = reward.total | |
| return StepResponse( | |
| observation=obs, | |
| reward=reward, | |
| done=self._done, | |
| info=info, | |
| ) | |
| def state(self) -> EnvState: | |
| """Return full internal state for debugging / grading.""" | |
| return EnvState( | |
| task_id=self._scenario.task_id if self._scenario else "", | |
| step=self._step, | |
| done=self._done, | |
| actions_taken=list(self._actions_taken), | |
| investigation_log=list(self._investigation_log), | |
| root_cause_identified=self._root_cause_identified, | |
| remediation_applied=self._remediation_applied, | |
| score=self._compute_reward().total, | |
| reward_breakdown=self._compute_reward().breakdown, | |
| ) | |
| def list_tasks(self) -> list[dict]: | |
| """Return metadata for all available tasks.""" | |
| return [ | |
| { | |
| "task_id": s.task_id, | |
| "task_name": s.task_name, | |
| "difficulty": s.difficulty, | |
| "description": s.description, | |
| "max_steps": s.max_steps, | |
| } | |
| for s in ALL_SCENARIOS.values() | |
| ] | |
| # ββ Command execution βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _execute_command(self, raw: str) -> tuple[str, bool]: | |
| """Parse and execute a command. Returns (result_text, is_error).""" | |
| raw = raw.strip() | |
| parts = raw.split(None, 2) | |
| if not parts: | |
| return "Empty command. Use one of: check_metrics, check_logs, check_config, check_dependencies, check_deploy_history, restart_service, rollback_deploy, scale_service, update_config, mark_resolved", True | |
| cmd = parts[0].lower() | |
| args = parts[1:] if len(parts) > 1 else [] | |
| dispatch = { | |
| "check_metrics": self._cmd_check_metrics, | |
| "check_logs": self._cmd_check_logs, | |
| "check_config": self._cmd_check_config, | |
| "check_dependencies": self._cmd_check_dependencies, | |
| "check_deploy_history": self._cmd_check_deploy_history, | |
| "restart_service": self._cmd_restart_service, | |
| "rollback_deploy": self._cmd_rollback_deploy, | |
| "scale_service": self._cmd_scale_service, | |
| "update_config": self._cmd_update_config, | |
| "mark_resolved": self._cmd_mark_resolved, | |
| } | |
| handler = dispatch.get(cmd) | |
| if handler is None: | |
| return ( | |
| f"Unknown command: '{cmd}'. Available commands: {', '.join(dispatch.keys())}", | |
| True, | |
| ) | |
| return handler(args) | |
| def _get_service(self, args: list[str]) -> tuple[ServiceDef | None, str]: | |
| if not args: | |
| return None, "Missing service name." | |
| svc_name = args[0] | |
| svc = self._scenario.services.get(svc_name) | |
| if svc is None: | |
| available = ", ".join(self._scenario.services.keys()) | |
| return None, f"Unknown service: '{svc_name}'. Available: {available}" | |
| return svc, "" | |
| def _record_investigation(self, service_name: str): | |
| if service_name not in self._investigation_log: | |
| self._investigation_log.append(service_name) | |
| def _check_multi_remediation(self, cmd: str, service: str, | |
| key: str | None = None) -> bool: | |
| """For multi-root-cause scenarios, track and check if this fix is one of the required ones.""" | |
| sc = self._scenario | |
| if sc.correct_remediation != "multi" or not sc.correct_remediations: | |
| return False | |
| for req in sc.correct_remediations: | |
| if req["cmd"] == cmd and req["service"] == service: | |
| if "key" in req and key != req["key"]: | |
| continue | |
| fix = {"cmd": cmd, "service": service} | |
| if fix not in self._multi_fixes_applied: | |
| self._multi_fixes_applied.append(fix) | |
| # Check if ALL required fixes are now applied | |
| if len(self._multi_fixes_applied) >= len(sc.correct_remediations): | |
| self._remediation_correct = True | |
| return True | |
| return False | |
| def _cmd_check_metrics(self, args: list[str]) -> tuple[str, bool]: | |
| svc, err = self._get_service(args) | |
| if svc is None: | |
| return err, True | |
| self._record_investigation(svc.name) | |
| # Dynamic metrics: degrade over time if unresolved, recover after fix | |
| cpu = svc.cpu | |
| memory = svc.memory | |
| latency_p50 = svc.latency_p50 | |
| latency_p99 = svc.latency_p99 | |
| error_rate = svc.error_rate | |
| rps = svc.rps | |
| extra = dict(svc.extra_metrics) | |
| if self._remediation_correct: | |
| # Post-fix: show healthy metrics | |
| if not svc.healthy or svc.name == self._scenario.root_cause_service: | |
| cpu = min(cpu, 30.0) | |
| memory = min(memory, 45.0) | |
| latency_p50 = min(latency_p50, 20.0) | |
| latency_p99 = min(latency_p99, 80.0) | |
| error_rate = 0.1 | |
| elif not svc.healthy and self._step > 1: | |
| # Pre-fix degradation: metrics worsen each step for unhealthy services | |
| degrade = min(self._step * 0.03, 0.15) # up to 15% worse | |
| cpu = min(99.0, cpu * (1 + degrade)) | |
| memory = min(99.0, memory * (1 + degrade * 0.5)) | |
| latency_p99 = latency_p99 * (1 + degrade) | |
| error_rate = min(99.0, error_rate * (1 + degrade)) | |
| # Update extra metrics that are counts | |
| for k in extra: | |
| if isinstance(extra[k], (int, float)) and any( | |
| w in k for w in ["fail", "error", "pending", "miss", "stale", "lag"] | |
| ): | |
| extra[k] = round(extra[k] * (1 + degrade), 1) if isinstance(extra[k], float) else int(extra[k] * (1 + degrade)) | |
| lines = [ | |
| f"=== Metrics for {svc.name} (step {self._step}) ===", | |
| f" CPU usage: {cpu:.1f}%", | |
| f" Memory usage: {memory:.1f}%", | |
| f" Latency p50: {latency_p50:.0f}ms", | |
| f" Latency p99: {latency_p99:.0f}ms", | |
| f" Error rate: {error_rate:.1f}%", | |
| f" Requests/sec: {rps:.0f}", | |
| ] | |
| for k, v in extra.items(): | |
| lines.append(f" {k}: {v}") | |
| return "\n".join(lines), False | |
| def _cmd_check_logs(self, args: list[str]) -> tuple[str, bool]: | |
| svc, err = self._get_service(args) | |
| if svc is None: | |
| return err, True | |
| self._record_investigation(svc.name) | |
| if not svc.logs: | |
| return f"No recent logs for {svc.name}.", False | |
| header = f"=== Recent logs for {svc.name} ===" | |
| return header + "\n" + "\n".join(svc.logs), False | |
| def _cmd_check_config(self, args: list[str]) -> tuple[str, bool]: | |
| svc, err = self._get_service(args) | |
| if svc is None: | |
| return err, True | |
| self._record_investigation(svc.name) | |
| lines = [f"=== Configuration for {svc.name} ==="] | |
| for k, v in svc.config.items(): | |
| lines.append(f" {k}: {v}") | |
| return "\n".join(lines), False | |
| def _cmd_check_dependencies(self, args: list[str]) -> tuple[str, bool]: | |
| svc, err = self._get_service(args) | |
| if svc is None: | |
| return err, True | |
| self._record_investigation(svc.name) | |
| deps = svc.dependencies or ["(none)"] | |
| return f"=== Dependencies for {svc.name} ===\n " + "\n ".join(deps), False | |
| def _cmd_check_deploy_history(self, args: list[str]) -> tuple[str, bool]: | |
| svc, err = self._get_service(args) | |
| if svc is None: | |
| return err, True | |
| self._record_investigation(svc.name) | |
| if not svc.deploy_history: | |
| return f"No deployment history for {svc.name}.", False | |
| lines = [f"=== Deploy history for {svc.name} ==="] | |
| for d in svc.deploy_history: | |
| line = f" [{d['date']}] {d['version']} β {d.get('status', 'unknown')}" | |
| if d.get("notes"): | |
| line += f"\n Notes: {d['notes']}" | |
| lines.append(line) | |
| return "\n".join(lines), False | |
| def _cmd_restart_service(self, args: list[str]) -> tuple[str, bool]: | |
| svc, err = self._get_service(args) | |
| if svc is None: | |
| return err, True | |
| self._remediation_applied = True | |
| # Check if this is the correct remediation | |
| sc = self._scenario | |
| if self._check_multi_remediation("restart_service", svc.name): | |
| fixes_done = len(self._multi_fixes_applied) | |
| fixes_total = len(sc.correct_remediations) | |
| if fixes_done < fixes_total: | |
| return ( | |
| f"Service {svc.name} restarted successfully. Service recovering. " | |
| f"[{fixes_done}/{fixes_total} issues fixed] β other issues remain. Continue investigating." | |
| ), False | |
| return ( | |
| f"Service {svc.name} restarted successfully. All issues resolved! " | |
| f"Now call mark_resolved with a root cause description to close the incident." | |
| ), False | |
| if (sc.correct_remediation == "restart_service" | |
| and svc.name == sc.root_cause_service): | |
| self._remediation_correct = True | |
| return ( | |
| f"Service {svc.name} restarted successfully. Service is now healthy. " | |
| f"Memory usage dropping to normal levels. Error rate returning to baseline. " | |
| f"Now call mark_resolved with a root cause description to close the incident." | |
| ), False | |
| if svc.name in sc.penalty_services: | |
| self._wrong_actions += 1 | |
| return f"Service {svc.name} restarted, but this does not address the underlying issue. Service was already healthy.", False | |
| return f"Service {svc.name} restarted. Monitoring for changes...", False | |
| def _cmd_rollback_deploy(self, args: list[str]) -> tuple[str, bool]: | |
| svc, err = self._get_service(args) | |
| if svc is None: | |
| return err, True | |
| self._remediation_applied = True | |
| sc = self._scenario | |
| if self._check_multi_remediation("rollback_deploy", svc.name): | |
| prev = [d for d in svc.deploy_history if d.get("status") == "previous"] | |
| prev_ver = prev[0]["version"] if prev else "previous" | |
| fixes_done = len(self._multi_fixes_applied) | |
| fixes_total = len(sc.correct_remediations) | |
| if fixes_done < fixes_total: | |
| return ( | |
| f"Rolled back {svc.name} to {prev_ver}. Service recovering. " | |
| f"[{fixes_done}/{fixes_total} issues fixed] β other issues remain. Continue investigating." | |
| ), False | |
| return ( | |
| f"Rolled back {svc.name} to {prev_ver}. All issues resolved! " | |
| f"Now call mark_resolved with a root cause description to close the incident." | |
| ), False | |
| if (sc.correct_remediation == "rollback_deploy" | |
| and svc.name == sc.root_cause_service): | |
| self._remediation_correct = True | |
| prev = [d for d in svc.deploy_history if d.get("status") == "previous"] | |
| prev_ver = prev[0]["version"] if prev else "previous" | |
| return ( | |
| f"Rolled back {svc.name} to {prev_ver}. " | |
| f"Service restarting with previous version... Service healthy. Metrics normalizing. " | |
| f"Now call mark_resolved with a root cause description to close the incident." | |
| ), False | |
| if svc.name in sc.penalty_services: | |
| self._wrong_actions += 1 | |
| return f"Rolled back {svc.name} to previous version. No improvement observed.", False | |
| def _cmd_scale_service(self, args: list[str]) -> tuple[str, bool]: | |
| if len(args) < 2: | |
| return "Usage: scale_service <service> <replicas>", True | |
| svc, err = self._get_service(args[:1]) | |
| if svc is None: | |
| return err, True | |
| try: | |
| replicas = int(args[1]) | |
| except ValueError: | |
| return f"Invalid replica count: {args[1]}", True | |
| if svc.name in self._scenario.penalty_services: | |
| self._wrong_actions += 1 | |
| return f"Scaled {svc.name} to {replicas} replicas. This may help with load but does not address root cause.", False | |
| def _cmd_update_config(self, args: list[str]) -> tuple[str, bool]: | |
| # Expected: update_config <service> <key> <value> | |
| if len(args) < 2: | |
| return "Usage: update_config <service> <key> <value>", True | |
| svc_name = args[0] | |
| svc = self._scenario.services.get(svc_name) | |
| if svc is None: | |
| available = ", ".join(self._scenario.services.keys()) | |
| return f"Unknown service: '{svc_name}'. Available: {available}", True | |
| # Parse key and value from remaining args | |
| # args[1:] could be ["db_pool_size", "50"] or ["db_pool_size 50"] | |
| sub_parts = " ".join(args[1:]).split(None, 1) | |
| if len(sub_parts) < 2: | |
| return "Usage: update_config <service> <key> <value>", True | |
| key, value = sub_parts[0], sub_parts[1] | |
| self._remediation_applied = True | |
| sc = self._scenario | |
| if self._check_multi_remediation("update_config", svc_name, key=key): | |
| fixes_done = len(self._multi_fixes_applied) | |
| fixes_total = len(sc.correct_remediations) | |
| if fixes_done < fixes_total: | |
| return ( | |
| f"Configuration updated: {svc_name}.{key} = {value}. Service recovering. " | |
| f"[{fixes_done}/{fixes_total} issues fixed] β other issues remain. Continue investigating." | |
| ), False | |
| return ( | |
| f"Configuration updated: {svc_name}.{key} = {value}. All issues resolved! " | |
| f"Now call mark_resolved with a root cause description to close the incident." | |
| ), False | |
| if (sc.correct_remediation == "update_config" | |
| and svc_name == sc.root_cause_service | |
| and key == sc.correct_remediation_args.get("key")): | |
| self._remediation_correct = True | |
| return ( | |
| f"Configuration updated: {svc_name}.{key} = {value}. " | |
| f"Service reloading config... Metrics stabilizing. Issue resolved. " | |
| f"Now call mark_resolved with a root cause description to close the incident." | |
| ), False | |
| if svc_name in sc.penalty_services: | |
| self._wrong_actions += 1 | |
| return f"Configuration updated: {svc_name}.{key} = {value}. Monitoring for effect...", False | |
| def _cmd_mark_resolved(self, args: list[str]) -> tuple[str, bool]: | |
| if not args: | |
| return "Usage: mark_resolved <root_cause_description>", True | |
| description = " ".join(args).lower() | |
| sc = self._scenario | |
| # Check if root cause description matches keywords | |
| matched = sum(1 for kw in sc.root_cause_keywords if kw in description) | |
| has_service = sc.root_cause_service.lower() in description | |
| if matched >= 1 and has_service: | |
| self._root_cause_identified = True | |
| if self._remediation_correct: | |
| self._resolved = True | |
| return "Incident marked as resolved. Root cause correctly identified. Well done!", False | |
| return ( | |
| "Root cause description acknowledged, but remediation has not been applied yet. " | |
| "Please fix the issue before marking as resolved." | |
| ), False | |
| elif matched >= 1 or has_service: | |
| self._root_cause_identified = True # partial | |
| return "Partial root cause identified. Consider being more specific about the service and issue.", False | |
| else: | |
| return "Root cause description does not match the actual issue. Continue investigating.", False | |
| # ββ Observation builder βββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _build_observation(self, last_action: str | None, | |
| last_action_result: str | None, | |
| last_action_error: bool) -> Observation: | |
| sc = self._scenario | |
| # Build alerts with escalation if incident is unresolved | |
| alerts_data = list(sc.alerts) | |
| if not self._remediation_correct and self._step >= sc.max_steps * 0.5: | |
| # Escalate: add urgency alerts after 50% of steps | |
| alerts_data = alerts_data + [ | |
| {"alert_id": f"ALT-ESC-{self._step}", "severity": "critical", | |
| "service": sc.root_cause_service, | |
| "message": f"ESCALATION: Incident unresolved after {self._step} steps. Impact increasing.", | |
| "timestamp": f"2025-04-01T{10 + self._step}:{(self._step * 3) % 60:02d}:00Z"}, | |
| ] | |
| elif self._remediation_correct: | |
| # Post-fix: show resolution in alerts | |
| alerts_data = alerts_data + [ | |
| {"alert_id": "ALT-RESOLVED", "severity": "info", | |
| "service": sc.root_cause_service, | |
| "message": "Remediation applied. Metrics stabilizing. Call mark_resolved to close incident.", | |
| "timestamp": f"2025-04-01T{10 + self._step}:{(self._step * 3) % 60:02d}:00Z"}, | |
| ] | |
| # More realistic time progression | |
| hour = 10 + (self._step * 2) // 60 | |
| minute = (self._step * 2) % 60 | |
| return Observation( | |
| task_id=sc.task_id, | |
| goal=sc.goal, | |
| step=self._step, | |
| max_steps=sc.max_steps, | |
| current_time=f"2025-04-01T{hour:02d}:{minute:02d}:00Z", | |
| alerts=[Alert(**a) for a in alerts_data], | |
| last_action=last_action, | |
| last_action_result=last_action_result, | |
| last_action_error=last_action_error, | |
| available_commands=[ | |
| "check_metrics <service>", | |
| "check_logs <service>", | |
| "check_config <service>", | |
| "check_dependencies <service>", | |
| "check_deploy_history <service>", | |
| "restart_service <service>", | |
| "rollback_deploy <service>", | |
| "scale_service <service> <replicas>", | |
| "update_config <service> <key> <value>", | |
| "mark_resolved <root_cause_description>", | |
| ], | |
| services=list(sc.services.keys()), | |
| ) | |
| # ββ Reward computation ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _compute_reward(self) -> Reward: | |
| """ | |
| Reward breakdown (0.0 β 1.0): | |
| investigation: 0.30 β checking relevant services | |
| root_cause: 0.25 β correctly identifying root cause | |
| remediation: 0.30 β applying correct fix | |
| efficiency: 0.15 β fewer steps = better | |
| penalties: -0.05 each for wrong destructive actions | |
| """ | |
| sc = self._scenario | |
| breakdown: dict[str, float] = {} | |
| # 1) Investigation quality (0.30) | |
| # Per-step incremental: each relevant service gives proportional credit | |
| relevant_checked = [s for s in self._investigation_log | |
| if s in sc.investigation_targets] | |
| if sc.investigation_targets: | |
| investigation_ratio = len(relevant_checked) / len(sc.investigation_targets) | |
| else: | |
| investigation_ratio = 0.0 | |
| # Bonus for checking root cause service | |
| if sc.root_cause_service in self._investigation_log: | |
| investigation_ratio = min(1.0, investigation_ratio + 0.2) | |
| # Bonus for deep investigation (config or deploy history of root cause) | |
| deep_cmds = [a for a in self._actions_taken | |
| if (a.startswith("check_config " + sc.root_cause_service) | |
| or a.startswith("check_deploy_history " + sc.root_cause_service))] | |
| if deep_cmds: | |
| investigation_ratio = min(1.0, investigation_ratio + 0.1) | |
| breakdown["investigation"] = round(0.30 * min(investigation_ratio, 1.0), 3) | |
| # 2) Root cause identification (0.25) | |
| if self._root_cause_identified: | |
| breakdown["root_cause"] = 0.25 | |
| else: | |
| # Partial credit if they at least investigated the right service | |
| if sc.root_cause_service in self._investigation_log: | |
| breakdown["root_cause"] = 0.08 | |
| else: | |
| breakdown["root_cause"] = 0.0 | |
| # 3) Remediation (0.30) | |
| if self._remediation_correct: | |
| breakdown["remediation"] = 0.30 | |
| elif self._remediation_applied: | |
| breakdown["remediation"] = 0.05 # tried but wrong | |
| else: | |
| breakdown["remediation"] = 0.0 | |
| # 4) Efficiency (0.15) | |
| if self._done and self._resolved: | |
| steps_used = self._step | |
| max_s = sc.max_steps | |
| # Full efficiency bonus if resolved in <= 40% of max steps | |
| if steps_used <= max_s * 0.4: | |
| breakdown["efficiency"] = 0.15 | |
| elif steps_used <= max_s * 0.7: | |
| breakdown["efficiency"] = 0.10 | |
| else: | |
| breakdown["efficiency"] = 0.05 | |
| elif self._resolved: | |
| breakdown["efficiency"] = 0.05 | |
| else: | |
| breakdown["efficiency"] = 0.0 | |
| # 5) Penalties | |
| penalty = self._wrong_actions * 0.05 | |
| breakdown["penalty"] = round(-min(penalty, 0.15), 3) | |
| total = sum(breakdown.values()) | |
| total = round(max(0.01, min(0.99, total)), 3) | |
| return Reward(total=total, breakdown=breakdown) | |