| from __future__ import annotations |
| import random |
| import uuid |
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass, field |
| from typing import List, Dict, Any, Optional, Set |
| from models import ( |
| Action, ActionType, Observation, State, StepResult, |
| ServiceStatus, Alert, ServiceDependency, EvidenceEntry, |
| ) |
|
|
|
|
# Runbook file names that are accepted by the read_runbook action and
# advertised to the agent in every observation.
AVAILABLE_RUNBOOKS = [
    "high_cpu.md",
    "memory_leak.md",
    "db_connection.md",
    "deployment_rollback.md",
    "cascade_failure.md",
    "data_corruption.md",
]
|
|
# Briefing text shown to the agent, keyed by task id. Adjacent string
# literals are concatenated, so every fragment except the last of each entry
# carries its own trailing space.
TASK_DESCRIPTIONS = {
    # One crash-looping service, one fix.
    "easy": (
        "PRODUCTION INCIDENT — One service is crash-looping. "
        "Read its logs and metrics to find the root cause, diagnose precisely, "
        "then apply the correct single-service fix. "
        "Avoid restarting healthy services — collateral damage is penalised."
    ),
    # Several degraded services; the origin must be traced via dependencies.
    "medium": (
        "PRODUCTION INCIDENT — Multiple services are degraded. "
        "Use the service dependency map to trace the failure to its origin. "
        "A recent deployment is likely involved. One alert is a red herring. "
        "Fix the root service only — downstream victims will self-heal."
    ),
    # Everything looks healthy; the problem is silent data corruption.
    "hard": (
        "PRODUCTION INCIDENT — All services show green health. No error-rate alerts. "
        "Look for anomalies in business-logic metrics and WARN-level logs. "
        "Correlate signals across services to find silent data corruption. "
        "Two actions are required for full credit: rollback AND alert_oncall."
    ),
    # Two unrelated failures at once.
    "bonus": (
        "PRODUCTION INCIDENT — Two independent failures are active simultaneously. "
        "They are unrelated — fixing one will NOT fix the other. "
        "Identify both root causes and remediate each independently. "
        "Full credit requires resolving both."
    ),
}
|
|
|
|
@dataclass
class InternalState:
    """Mutable, simulator-side state of one incident episode.

    The raw episode data is kept as plain dicts and lists; the methods below
    project it into the ``State`` / ``Observation`` models and advance the
    simulated degradation of unhealthy services.

    One step counts as two simulated minutes throughout this class.
    """

    episode_id: str
    task_id: str  # looked up in TASK_DESCRIPTIONS when building observations
    step: int
    max_steps: int
    services: Dict[str, dict]  # service name -> metrics dict ("status", "error_rate", ...)
    alerts: list  # dicts, each expanded as Alert(**alert)
    logs: Dict[str, List[str]]  # service name -> log lines; the last two count as "recent"
    action_history: List[Dict[str, Any]]
    total_reward: float
    incident_resolved: bool
    ground_truth_root_cause: str
    ground_truth_fix: str
    incident_start_time: str
    last_action_result: Optional[str] = field(default=None)
    last_action_error: Optional[str] = field(default=None)
    rewards_given: Set[str] = field(default_factory=set)  # reward keys already paid out
    healthy_services: List[str] = field(default_factory=list)
    evidence_log: List[dict] = field(default_factory=list)  # dicts, expanded as EvidenceEntry(**e)
    service_dependencies: List[dict] = field(default_factory=list)  # dicts, expanded as ServiceDependency(**d)
    _scenario: Any = field(default=None, repr=False)
    _ml_version: Any = field(default=None, repr=False)

    def to_state_snapshot(self) -> State:
        """Return a ``State`` model describing the episode as it stands now.

        The embedded observation is built without a new action outcome, so it
        carries whatever result/error was recorded last.
        """
        obs = self._build_observation()
        return State(
            episode_id=self.episode_id,
            task_id=self.task_id,
            step=self.step,
            current_observation=obs,
            action_history=self.action_history,
            total_reward=round(self.total_reward, 4),
            incident_resolved=self.incident_resolved,
            ground_truth_root_cause=self.ground_truth_root_cause,
            ground_truth_fix=self.ground_truth_fix,
            info={
                "rewards_unlocked": sorted(self.rewards_given),
                "evidence_gathered": len(self.evidence_log),
            },
        )

    def _build_sla_status(self) -> Dict[str, str]:
        """Classify every service as ``"ok"``, ``"warning"`` or ``"breached"``.

        The classification depends on the service status and on the elapsed
        episode time (``step * 2`` minutes). Services that are neither
        ``"down"`` nor ``"degraded"`` are always ``"ok"``.
        """
        elapsed_minutes = self.step * 2
        # status -> (warning threshold, breach threshold), in minutes.
        thresholds = {"down": (5, 10), "degraded": (10, 20)}

        report: Dict[str, str] = {}
        for name, svc in self.services.items():
            limits = thresholds.get(svc["status"])
            if limits is None:
                report[name] = "ok"
                continue
            warn_at, breach_at = limits
            if elapsed_minutes >= breach_at:
                report[name] = "breached"
            elif elapsed_minutes >= warn_at:
                report[name] = "warning"
            else:
                report[name] = "ok"
        return report

    def _apply_sla_degradation(self) -> None:
        """Services get progressively worse if not fixed — adds urgency.

        Does nothing once the incident is resolved. Mutates the service dicts
        in place:

        * ``down``: +2 ``minutes_degraded``; error rate grows 5%, capped at 50.
        * ``degraded``: +2 ``minutes_degraded``; p99 latency grows 3%, capped
          at 60000 ms; past 30000 ms a sub-1.0 error rate is bumped by 0.5.
        """
        if self.incident_resolved:
            return
        for svc in self.services.values():
            status = svc["status"]
            if status == "down":
                svc["minutes_degraded"] = svc.get("minutes_degraded", 0) + 2
                svc["error_rate"] = min(svc["error_rate"] * 1.05, 50.0)
            elif status == "degraded":
                svc["minutes_degraded"] = svc.get("minutes_degraded", 0) + 2
                svc["latency_p99_ms"] = min(svc["latency_p99_ms"] * 1.03, 60000.0)
                # Extreme latency starts to surface as errors.
                if svc["latency_p99_ms"] > 30000 and svc["error_rate"] < 1.0:
                    svc["error_rate"] = round(svc["error_rate"] + 0.5, 2)

    def _record_action_outcome(
        self,
        result: Optional[str],
        error: Optional[str],
    ) -> None:
        """Remember the outcome of the most recent action.

        Result and error are replaced *together* whenever either is given, so
        an error from an earlier step cannot linger next to a newer successful
        result (or the other way round). When both are ``None`` the previously
        recorded outcome is kept.
        """
        if result is None and error is None:
            return
        self.last_action_result = result
        self.last_action_error = error

    def _build_observation(
        self,
        last_action_result: Optional[str] = None,
        last_action_error: Optional[str] = None,
    ) -> Observation:
        """Build the agent-facing ``Observation`` for the current step.

        Args:
            last_action_result: Output of the action just executed, if any.
            last_action_error: Error produced by the action just executed,
                if any.

        Side effect: a supplied result/error is stored on the state (see
        ``_record_action_outcome``) and is what the observation reports.
        """
        self._record_action_outcome(last_action_result, last_action_error)

        services = [
            ServiceStatus(
                name=s["name"],
                status=s["status"],
                cpu_percent=s["cpu_percent"],
                memory_percent=s["memory_percent"],
                error_rate=round(s["error_rate"], 3),
                latency_p99_ms=round(s["latency_p99_ms"], 0),
                replicas_running=s["replicas_running"],
                replicas_desired=s["replicas_desired"],
                current_version=s["current_version"],
                last_deployed=s["last_deployed"],
                sla_breach=s.get("sla_breach", False),
                minutes_degraded=s.get("minutes_degraded", 0),
            )
            for s in self.services.values()
        ]

        alerts = [Alert(**a) for a in self.alerts]
        deps = [ServiceDependency(**d) for d in self.service_dependencies]
        evidence = [EvidenceEntry(**e) for e in self.evidence_log]
        sla = self._build_sla_status()

        # Only the last two lines per service are shown; anything older is
        # summarised by a marker that points the agent at read_logs.
        recent_logs = {}
        for svc, lines in self.logs.items():
            tail = lines[-2:]
            if len(lines) > 2:
                tail = tail + [
                    f"[... {len(lines)-2} more lines — use read_logs to see full history]"
                ]
            recent_logs[svc] = tail

        return Observation(
            step=self.step,
            max_steps=self.max_steps,
            task_id=self.task_id,
            task_description=TASK_DESCRIPTIONS.get(self.task_id, ""),
            services=services,
            active_alerts=alerts,
            recent_logs=recent_logs,
            available_runbooks=AVAILABLE_RUNBOOKS,
            service_dependencies=deps,
            evidence_log=evidence,
            sla_status=sla,
            last_action_result=self.last_action_result,
            last_action_error=self.last_action_error,
            incident_start_time=self.incident_start_time,
            elapsed_minutes=self.step * 2,
        )
|
|
|
|
@dataclass
class StepOutput:
    """Bundle returned by ``BaseTask.step`` for a single action."""

    next_state: InternalState  # episode state to continue from
    reward: float
    done: bool
    info: Dict[str, Any]
|
|
|
|
def semantic_match(candidate: str, keywords: List[str], threshold: int = 1) -> bool:
    """
    Returns True if candidate contains at least `threshold` keywords.

    Matching is a case-insensitive substring test. Hyphens and underscores
    are treated as spaces on BOTH sides, so "memory_leak", "memory-leak" and
    "Memory Leak" all match each other. An empty or None candidate never
    matches.
    """
    if not candidate:
        return False

    def normalise(text: str) -> str:
        # The same folding must be applied to the candidate and to every
        # keyword; otherwise a keyword containing "_" could never be found
        # in a candidate whose "_" were already turned into spaces.
        return text.lower().replace("-", " ").replace("_", " ")

    haystack = normalise(candidate)
    hits = sum(1 for kw in keywords if normalise(kw) in haystack)
    return hits >= threshold
|
|
|
|
class BaseTask(ABC):
    """Abstract base for incident tasks.

    Subclasses implement ``initialize`` and ``step``. This base supplies the
    behaviour shared by all tasks: handling of the investigation actions
    (reading/searching logs, metrics, runbooks, acknowledging alerts),
    runbook loading, reward clamping and the blind-remediation penalty.
    """

    def __init__(self, rng: random.Random):
        self.rng = rng

    @abstractmethod
    def initialize(self) -> InternalState:
        """Create and return the starting state of a new episode."""

    @abstractmethod
    def step(self, state: InternalState, action: Action) -> StepOutput:
        """Apply ``action`` to ``state`` and return the resulting StepOutput."""

    def _apply_action_to_logs(
        self, state: InternalState, action: Action
    ) -> tuple[Optional[str], Optional[str]]:
        """Execute an investigation-type action against ``state``.

        Handles ``read_logs``, ``search_logs``, ``read_metrics``,
        ``read_runbook``, ``acknowledge`` and ``noop``. Successful reads and
        searches append an entry to ``state.evidence_log``; ``acknowledge``
        flags the matching alert dict.

        Returns:
            ``(result, error)``. Exactly one is set for a handled action;
            ``(None, None)`` means the action type is not handled here.
        """
        kind = action.action_type.value

        if kind == "noop":
            return "No action taken.", None

        handlers = {
            "read_logs": self._handle_read_logs,
            "search_logs": self._handle_search_logs,
            "read_metrics": self._handle_read_metrics,
            "read_runbook": self._handle_read_runbook,
            "acknowledge": self._handle_acknowledge,
        }
        handler = handlers.get(kind)
        if handler is None:
            return None, None
        return handler(state, action)

    def _handle_read_logs(
        self, state: InternalState, action: Action
    ) -> tuple[Optional[str], Optional[str]]:
        """Return the full log of ``action.service`` and record it as evidence."""
        svc = action.service
        if not svc or svc not in state.logs:
            return None, f"No logs found for service '{svc}'"

        lines = state.logs[svc]
        result = "\n".join(lines)
        state.evidence_log.append({
            "step": state.step,
            "source": f"logs:{svc}",
            "summary": f"Read {len(lines)} log lines from {svc}",
            "raw": result,
        })
        return result, None

    def _handle_search_logs(
        self, state: InternalState, action: Action
    ) -> tuple[Optional[str], Optional[str]]:
        """Case-insensitive substring search over one service's log lines."""
        svc = action.service
        query = (action.query or "").lower()
        if not svc or svc not in state.logs:
            return None, f"Unknown service '{svc}'"
        if not query:
            return None, "search_logs requires a query parameter"

        matches = [line for line in state.logs[svc] if query in line.lower()]
        if matches:
            result = f"Found {len(matches)} lines matching '{query}':\n" + "\n".join(matches)
        else:
            result = f"No lines matching '{query}' in {svc} logs."
        # A search that finds nothing is still recorded as evidence.
        state.evidence_log.append({
            "step": state.step,
            "source": f"search:{svc}",
            "summary": f"Searched {svc} for '{query}': {len(matches)} matches",
            "raw": result,
        })
        return result, None

    def _handle_read_metrics(
        self, state: InternalState, action: Action
    ) -> tuple[Optional[str], Optional[str]]:
        """Render the metrics of ``action.service`` and record them as evidence."""
        svc = action.service
        if not svc or svc not in state.services:
            return None, f"Unknown service '{svc}'"

        s = state.services[svc]
        result = (
            f"=== Metrics: {svc} ===\n"
            f"Status: {s['status'].upper()}\n"
            f"CPU: {s['cpu_percent']:.1f}%\n"
            f"Memory: {s['memory_percent']:.1f}%\n"
            f"Error rate: {s['error_rate']:.3f}/s\n"
            f"P99 latency: {s['latency_p99_ms']:.0f}ms\n"
            f"Replicas: {s['replicas_running']}/{s['replicas_desired']}\n"
            f"Version: {s['current_version']}\n"
            f"Last deploy: {s['last_deployed']}\n"
            f"Degraded for: {s.get('minutes_degraded', 0)} minutes"
        )
        state.evidence_log.append({
            "step": state.step,
            "source": f"metrics:{svc}",
            "summary": (
                f"{svc}: {s['status']}, cpu={s['cpu_percent']:.0f}%, "
                f"mem={s['memory_percent']:.0f}%, err={s['error_rate']:.2f}/s, "
                f"ver={s['current_version']}"
            ),
            "raw": result,
        })
        return result, None

    def _handle_read_runbook(
        self, state: InternalState, action: Action
    ) -> tuple[Optional[str], Optional[str]]:
        """Return a whitelisted runbook's text; evidence keeps the first 200 chars."""
        rb = action.runbook
        # Only names from the whitelist are ever turned into a file path.
        if rb not in AVAILABLE_RUNBOOKS:
            return None, f"Runbook '{rb}' not found. Available: {AVAILABLE_RUNBOOKS}"

        content = self._load_runbook(rb)
        state.evidence_log.append({
            "step": state.step,
            "source": f"runbook:{rb}",
            "summary": f"Read runbook: {rb}",
            "raw": content[:200],
        })
        return content, None

    def _handle_acknowledge(
        self, state: InternalState, action: Action
    ) -> tuple[Optional[str], Optional[str]]:
        """Mark the alert whose id equals ``action.service`` as acknowledged."""
        # The alert id travels in the action's ``service`` field.
        alert_id = action.service
        for alert in state.alerts:
            if alert["id"] == alert_id:
                alert["acknowledged"] = True
                return f"Alert {alert_id} acknowledged.", None
        return None, f"Alert '{alert_id}' not found."

    def _load_runbook(self, name: str) -> str:
        """Read ``data/runbooks/<name>`` located one directory above this file's.

        Returns the file's text, or a bracketed placeholder string when the
        file does not exist.
        """
        import os

        path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "data", "runbooks", name
        )
        try:
            # Explicit UTF-8: the platform default encoding is locale-dependent
            # and can fail on, or garble, non-ASCII characters in the markdown.
            with open(path, encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError:
            return f"[Runbook '{name}' not found]"

    def _clamp(self, value: float) -> float:
        """Limit ``value`` to the closed interval [0.0, 1.0]."""
        return max(0.0, min(1.0, value))

    def _penalty_blind_remediation(
        self, state: InternalState, action: Action, fix_key: str
    ) -> float:
        """
        Small penalty if agent remediates without any prior diagnosis.
        Encourages evidence-gathering before action.

        Returns -0.05 when neither "diagnose_correct" nor "diagnose_partial"
        has been rewarded yet, and 0.0 otherwise. No penalty applies once
        ``fix_key`` itself is already in ``state.rewards_given``.
        """
        if fix_key in state.rewards_given:
            return 0.0
        diagnosed = any(
            key in state.rewards_given
            for key in ("diagnose_correct", "diagnose_partial")
        )
        return 0.0 if diagnosed else -0.05
|
|