import random
import uuid
from datetime import datetime

from tasks.base import BaseTask, InternalState, StepOutput
from models import Action, StepResult, ServiceStatus, Alert, ActionType


class GeneratedTask(BaseTask):
    """Task whose scenario is described entirely by an incident dictionary.

    The incident dictionary is read for the following keys:

    * ``affected_service`` - name of the service that starts degraded.
    * ``failure_mode`` - selects the log lines shown for that service.
    * ``description`` - message of the single critical alert.
    * ``noise_alerts`` - optional iterable of warning messages (may be
      missing or ``None``).
    * ``ground_truth_root_cause`` / ``ground_truth_fix`` - used for scoring.
    * ``seed`` - optional seed for the RNG handed to ``BaseTask`` (default 42).
    """

    # Services that always exist in a freshly initialised episode.
    SERVICES = (
        "payment-service",
        "order-service",
        "user-service",
        "inventory-service",
        "api-gateway",
        "notification-service",
        "data-pipeline",
        "ml-inference-service",
    )

    # Starting metrics of the affected service.
    _DEGRADED_METRICS = {
        "status": "degraded",
        "cpu_percent": 85.0,
        "memory_percent": 92.0,
        "error_rate": 0.35,
        "latency_p99_ms": 2800.0,
        "replicas_running": 2,
        "replicas_desired": 3,
        "current_version": "v1.2.4",
    }

    # Starting metrics of every other service.
    _HEALTHY_METRICS = {
        "status": "healthy",
        "cpu_percent": 25.0,
        "memory_percent": 40.0,
        "error_rate": 0.01,
        "latency_p99_ms": 120.0,
        "replicas_running": 3,
        "replicas_desired": 3,
        "current_version": "v1.2.3",
    }

    # Fields written onto the affected service once the correct fix is applied.
    # NOTE(review): replicas_running is not restored here, so a resolved
    # service stays at 2/3 replicas - confirm whether that is intended.
    _RECOVERED_METRICS = {
        "status": "healthy",
        "cpu_percent": 25.0,
        "memory_percent": 40.0,
        "error_rate": 0.01,
        "latency_p99_ms": 120.0,
    }

    # Log lines shown for the affected service, keyed by failure mode.
    # Stored as tuples so a per-episode copy is always made before use.
    _FAILURE_LOGS = {
        "oom": (
            "ERROR OutOfMemoryError: Java heap space",
            "WARN Memory usage at 98%, GC overhead limit exceeded",
        ),
        "cascade": (
            "ERROR Connection pool exhausted: timeout after 30s",
            "ERROR Failed to acquire connection from pool",
        ),
        "corruption": (
            "WARN Price mismatch detected: expected 29.99 got 299.9",
            "WARN Data validation failed for 847 records",
        ),
        "security": (
            "WARN 1847 failed login attempts in 60s",
            "WARN Rate limit exceeded from 185.220.101.x",
        ),
        "database": (
            "WARN Slow query: seq_scan on orders (847ms)",
            "WARN Query planner chose sequential scan, missing index",
        ),
        "network_partition": (
            "ERROR Connection timeout to us-east-1",
            "ERROR Health check failed: unreachable",
        ),
    }

    # Fallback for the affected service when the failure mode is unknown.
    _UNKNOWN_FAILURE_LOGS = ("INFO Service running normally",)

    # Log lines of every unaffected service.
    _HEALTHY_LOGS = ("INFO Service running normally", "INFO Health check passed")

    def __init__(self, incident_dict: dict):
        """Store the incident and initialise the base task with a seeded RNG.

        Args:
            incident_dict: Scenario description; see the class docstring for
                the keys that are read.
        """
        self.incident = incident_dict
        self.task_id = "generated"
        self.max_steps = 20
        # For compatibility with BaseTask which expects rng in __init__
        super().__init__(random.Random(incident_dict.get("seed", 42)))

    @staticmethod
    def _make_alert(service: str, severity: str, message, timestamp: str) -> dict:
        """Build one unacknowledged alert record with a short random id."""
        return {
            "id": str(uuid.uuid4())[:8],
            "service": service,
            "severity": severity,
            "message": message,
            "timestamp": timestamp,
            "acknowledged": False,
        }

    def initialize(self) -> InternalState:
        """Build the starting state of an episode from the incident.

        Returns:
            An ``InternalState`` in which the affected service is degraded,
            all other services are healthy, one critical alert plus the
            configured noise alerts are active, and the logs of the affected
            service reflect the incident's failure mode.

        Raises:
            KeyError: If the incident lacks ``affected_service``,
                ``failure_mode``, ``description``, ``ground_truth_root_cause``
                or ``ground_truth_fix``.
        """
        affected = self.incident["affected_service"]
        failure_mode = self.incident["failure_mode"]

        # Always include the affected service, even when it is not one of the
        # well-known names; otherwise the critical alert would point at a
        # service missing from the state and step() would fail on the fix.
        service_names = list(self.SERVICES)
        if affected not in service_names:
            service_names.append(affected)

        # One timestamp per initialisation keeps every record of the episode
        # consistent.
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12; it
        # is kept so the timestamps remain naive ISO strings as before.
        now_iso = datetime.utcnow().isoformat()

        services_dict = {}
        logs = {}
        for svc in service_names:
            is_affected = svc == affected
            metrics = self._DEGRADED_METRICS if is_affected else self._HEALTHY_METRICS
            services_dict[svc] = {
                "name": svc,
                **metrics,
                "last_deployed": now_iso,
                "minutes_degraded": 0,
                "sla_breach": False,
            }
            if is_affected:
                lines = self._FAILURE_LOGS.get(failure_mode, self._UNKNOWN_FAILURE_LOGS)
            else:
                lines = self._HEALTHY_LOGS
            # Fresh list per service so later mutation of the state's logs
            # never leaks into the class-level tables or between episodes.
            logs[svc] = list(lines)

        # One CRITICAL alert
        active_alerts = [
            self._make_alert(affected, "critical", self.incident["description"], now_iso)
        ]
        # Noise alerts (optional; a missing or None entry means no noise)
        for noise in self.incident.get("noise_alerts") or []:
            active_alerts.append(
                self._make_alert("notification-service", "warning", noise, now_iso)
            )

        state = InternalState(
            episode_id=str(uuid.uuid4()),
            task_id="generated",
            step=0,
            max_steps=self.max_steps,
            services=services_dict,
            alerts=active_alerts,
            logs=logs,
            action_history=[],
            total_reward=0.0,
            incident_resolved=False,
            ground_truth_root_cause=self.incident["ground_truth_root_cause"],
            ground_truth_fix=self.incident["ground_truth_fix"],
            incident_start_time=now_iso,
            rewards_given=set(),
        )
        state._scenario = self.incident
        return state

    def step(self, state: InternalState, action: Action) -> StepOutput:
        """Apply one action, score it, and advance the episode by one step.

        Each of the three rewards is granted at most once per episode:

        * 0.10 for reading the logs of the affected service,
        * 0.30 for a diagnosis that contains the ground-truth root cause
          (case-insensitive substring match),
        * 0.45 for applying the ground-truth fix to the affected service,
          which also marks the incident resolved and resets that service's
          status, CPU, memory, error rate and latency.

        Args:
            state: Current episode state; mutated in place.
            action: The action taken by the agent.

        Returns:
            A ``StepOutput`` carrying the same state object, the reward for
            this step rounded to four decimals, the done flag and an info
            dictionary.
        """
        reward = 0.0
        # The returned texts are not used here; the call is kept because the
        # helper is defined outside this class and may update the state.
        _result_text, _error_text = self._apply_action_to_logs(state, action)

        affected = self.incident["affected_service"]

        # ActionType can be enum or string
        at = action.action_type
        at_val = at.value if hasattr(at, "value") else str(at)

        if (
            at_val == "read_logs"
            and action.service == affected
            and "read_logs" not in state.rewards_given
        ):
            reward += 0.10
            state.rewards_given.add("read_logs")

        if at_val == "diagnose":
            diagnosis = action.diagnosis or action.root_cause or ""
            truth = (state.ground_truth_root_cause or "").lower()
            # An empty ground truth is a substring of every string, so it must
            # not count as a match.
            if (
                truth
                and truth in diagnosis.lower()
                and "diagnose" not in state.rewards_given
            ):
                reward += 0.30
                state.rewards_given.add("diagnose")

        # Deliberately not an elif chain: the configured fix is compared
        # independently of the two action types handled above.
        if (
            at_val == self.incident["ground_truth_fix"]
            and action.service == affected
            and "fix" not in state.rewards_given
        ):
            reward += 0.45
            state.rewards_given.add("fix")
            state.incident_resolved = True
            service_state = state.services[affected]
            for field_name, value in self._RECOVERED_METRICS.items():
                service_state[field_name] = value

        state.step += 1
        state.total_reward = self._clamp(state.total_reward + reward)

        out_of_steps = state.step >= self.max_steps
        done = state.incident_resolved or out_of_steps

        info = {}
        if state.incident_resolved:
            info["resolution"] = "incident_resolved"
        if out_of_steps:
            info["reason"] = "max_steps_reached"

        step_reward = round(reward, 4)
        state.action_history.append(
            {
                "step": state.step,
                "action": action.model_dump(),
                "reward": step_reward,
            }
        )

        return StepOutput(next_state=state, reward=step_reward, done=done, info=info)