| from __future__ import annotations |
| import uuid |
| from typing import Dict, Any, List |
| from models import Action, ActionType |
| from tasks.base import BaseTask, InternalState, StepOutput, semantic_match |
|
|
# ISO-8601 UTC timestamp of the simulated incident; EasyTask.initialize hands
# it to InternalState as ``incident_start_time``.
INCIDENT_TIME: str = "2026-03-30T10:14:47Z"
|
|
# Catalogue of incident variants. Each entry describes one service that is
# crash-looping from memory exhaustion:
#   failing_service     - name of the service that is down
#   root_cause          - ground-truth root-cause identifier
#   fix                 - ground-truth remediation string
#   alert_msg           - prefix of the critical alert message
#   language            - selects which flavour of log lines _make_logs emits
#   diagnosis_keywords  - terms matched against a submitted diagnosis
SCENARIOS: List[Dict[str, Any]] = [
    {
        "failing_service": "payment-service",
        "root_cause": "memory_leak_payment_service",
        "fix": "restart payment-service",
        "alert_msg": "payment-service pod restarting (OOMKilled)",
        "language": "java",
        "diagnosis_keywords": [
            "memory",
            "oom",
            "heap",
            "leak",
            "outofmemory",
            "kill",
        ],
    },
    {
        "failing_service": "order-service",
        "root_cause": "memory_leak_order_service",
        "fix": "restart order-service",
        "alert_msg": "order-service pod restarting (OOMKilled)",
        "language": "python",
        "diagnosis_keywords": [
            "memory",
            "oom",
            "heap",
            "leak",
            "segfault",
            "kill",
            "allocat",
        ],
    },
    {
        "failing_service": "user-service",
        "root_cause": "memory_leak_user_service",
        "fix": "restart user-service",
        "alert_msg": "user-service pod restarting (OOMKilled)",
        "language": "node",
        "diagnosis_keywords": [
            "memory",
            "heap",
            "oom",
            "leak",
            "javascript",
            "kill",
        ],
    },
]
|
|
# Every service present in the simulated environment, in iteration order.
ALL_SERVICES: List[str] = [
    "payment-service",
    "order-service",
    "user-service",
    "api-gateway",
]

# Version string reported as ``current_version`` for each service.
VERSIONS: Dict[str, str] = {
    "payment-service": "v4.2.1",
    "order-service": "v1.8.2",
    "user-service": "v3.0.5",
    "api-gateway": "v2.1.0",
}

# Call graph: the gateway calls the three backend services, which call nothing.
DEPENDENCIES: List[Dict[str, Any]] = [
    {
        "service": "api-gateway",
        "calls": ["payment-service", "order-service", "user-service"],
        "called_by": [],
    },
    {
        "service": "payment-service",
        "calls": [],
        "called_by": ["api-gateway"],
    },
    {
        "service": "order-service",
        "calls": [],
        "called_by": ["api-gateway"],
    },
    {
        "service": "user-service",
        "calls": [],
        "called_by": ["api-gateway"],
    },
]
|
|
def _make_logs(scenario, heap1, heap2, restart_count):
    """Build the log lines shown for every service in one incident.

    Args:
        scenario: Scenario mapping; its ``failing_service`` and ``language``
            entries are read.
        heap1: First memory-usage percentage interpolated into the warnings.
        heap2: Second memory-usage percentage interpolated into the warnings.
        restart_count: Restart number interpolated into the final FATAL line.

    Returns:
        A dict mapping service name to a list of log lines. The failing
        service is inserted first; the other entries of ``ALL_SERVICES``
        follow in their listed order.
    """
    target = scenario["failing_service"]
    runtime = scenario["language"]

    if runtime == "java":
        crash_lines = [
            "[10:13:55] INFO Request processed 200 38ms",
            f"[10:14:35] WARN Heap usage at {heap1}% - approaching threshold",
            f"[10:14:41] WARN Heap usage at {heap2}%",
            "[10:14:45] WARN GC overhead limit exceeded - major GC running",
            "[10:14:47] ERROR java.lang.OutOfMemoryError: Java heap space",
            "[10:14:47] ERROR at com.payments.ChargeProcessor.process(ChargeProcessor.java:142)",
            f"[10:14:48] FATAL Service entering crash loop - pod restart #{restart_count}",
        ]
    elif runtime == "python":
        crash_lines = [
            "[10:13:55] INFO POST /orders 200 55ms",
            f"[10:14:35] WARN RSS memory {heap1}% of pod limit",
            f"[10:14:41] WARN RSS memory {heap2}% of pod limit - approaching OOM",
            "[10:14:46] ERROR Memory allocator: no more pages available",
            "[10:14:47] ERROR Fatal Python error: Segmentation fault (memory allocator exhausted)",
            f"[10:14:48] FATAL Pod killed by OOM killer - restart #{restart_count}",
        ]
    else:
        # Any other language value falls through to the JavaScript-heap log.
        # ``heap1 * 2`` is heap1 percent of the 200MB limit, expressed in MB.
        crash_lines = [
            "[10:13:55] INFO GET /users/profile 200 9ms",
            f"[10:14:35] WARN Heap used: {heap1}% ({heap1 * 2}MB / 200MB)",
            f"[10:14:41] WARN Heap used: {heap2}% - GC pressure increasing",
            "[10:14:47] ERROR FATAL ERROR: Reached heap limit - JavaScript heap out of memory",
            f"[10:14:48] FATAL Container OOMKilled - restart #{restart_count}",
        ]

    logs = {target: crash_lines}

    # Every other service gets either the gateway's view of the outage or a
    # single "healthy" line; a fresh list is built for each service.
    for peer in (name for name in ALL_SERVICES if name != target):
        if peer == "api-gateway":
            peer_lines = [
                "[10:14:30] INFO GET /api/v1/health 200 3ms",
                f"[10:14:48] WARN Upstream {target} returned 503",
                f"[10:14:49] WARN Circuit breaker OPEN for {target}",
            ]
        else:
            peer_lines = ["[10:14:30] INFO Service healthy - 0 errors"]
        logs[peer] = peer_lines

    return logs
|
|
|
|
class EasyTask(BaseTask):
    """Incident task with one crash-looping service that a restart resolves.

    One entry of ``SCENARIOS`` is picked per episode. The chosen service is
    reported as down with high memory usage, the api-gateway as degraded, and
    the remaining services as healthy. Restarting the failing service ends
    the episode.
    """

    def initialize(self) -> InternalState:
        """Create the starting state for a randomly selected scenario.

        All random values come from ``self.rng``: the scenario index, the two
        heap percentages, the restart count, then the per-service metrics.

        Returns:
            A new ``InternalState`` with ``task_id="easy"`` and
            ``max_steps=15``. The chosen scenario is attached to it as
            ``_scenario``.
        """
        picked = SCENARIOS[self.rng.randint(0, len(SCENARIOS) - 1)]
        broken = picked["failing_service"]
        first_heap = self.rng.randint(74, 83)
        second_heap = first_heap + self.rng.randint(5, 10)
        restarts = self.rng.randint(2, 6)

        services: Dict[str, dict] = {}
        for name in ALL_SERVICES:
            # Pick the metric profile for this service's role.
            if name == broken:
                status, running, desired = "down", 0, 3
                deployed = "2026-03-28T14:00:00Z"
                cpu_span, mem_span = (5, 20), (93, 99)
                err_span, lat_span = (8.0, 15.0), (5000, 9000)
            elif name == "api-gateway":
                status, running, desired = "degraded", 2, 2
                deployed = "2026-03-25T09:00:00Z"
                cpu_span, mem_span = (35, 55), (40, 55)
                err_span, lat_span = (2.0, 5.0), (800, 1500)
            else:
                status, running, desired = "healthy", 2, 2
                deployed = "2026-03-20T11:00:00Z"
                cpu_span, mem_span = (20, 40), (30, 48)
                err_span, lat_span = None, (8, 30)

            # Draws happen in a fixed order: cpu, memory, error rate (no draw
            # for healthy services, which report 0.0), then latency.
            cpu = round(self.rng.uniform(*cpu_span), 1)
            mem = round(self.rng.uniform(*mem_span), 1)
            if err_span is None:
                err = 0.0
            else:
                err = round(self.rng.uniform(*err_span), 2)
            lat = round(self.rng.uniform(*lat_span), 0)

            services[name] = {
                "name": name,
                "status": status,
                "cpu_percent": cpu,
                "memory_percent": mem,
                "error_rate": err,
                "latency_p99_ms": lat,
                "replicas_running": running,
                "replicas_desired": desired,
                "current_version": VERSIONS[name],
                "last_deployed": deployed,
                "minutes_degraded": 0,
                "sla_breach": False,
            }

        critical_alert = {
            "id": "A001",
            "severity": "critical",
            "service": broken,
            "message": f"{picked['alert_msg']} - {restarts} times in 5 minutes",
            "timestamp": "2026-03-30T10:14:48Z",
            "acknowledged": False,
        }
        gateway_alert = {
            "id": "A002",
            "severity": "warning",
            "service": "api-gateway",
            "message": f"Upstream {broken} returning 503 - circuit breaker open",
            "timestamp": "2026-03-30T10:14:52Z",
            "acknowledged": False,
        }

        state = InternalState(
            episode_id=str(uuid.uuid4()),
            task_id="easy",
            step=0,
            max_steps=15,
            services=services,
            alerts=[critical_alert, gateway_alert],
            logs=_make_logs(picked, first_heap, second_heap, restarts),
            action_history=[],
            total_reward=0.0,
            incident_resolved=False,
            ground_truth_root_cause=picked["root_cause"],
            ground_truth_fix=picked["fix"],
            incident_start_time=INCIDENT_TIME,
            # Everything except the failing service, so the degraded
            # api-gateway is included here too.
            healthy_services=[s for s in ALL_SERVICES if s != broken],
            service_dependencies=DEPENDENCIES,
        )
        # step() reads the scenario back from the state.
        state._scenario = picked
        return state

    def step(self, state: InternalState, action: Action) -> StepOutput:
        """Apply one agent action to ``state`` and score it.

        Reward components (each "once" item is tracked in
        ``state.rewards_given``):
            * +0.15 once for reading/searching the failing service's logs.
            * +0.10 once for reading the failing service's metrics.
            * +0.05 once for reading the runbook.
            * Diagnosis: +0.30 once when both cause type and service match
              (+0.15 if partial credit was already given); +0.15 once for a
              matching cause type alone.
            * Restart: the blind-remediation penalty is always added; +0.40
              and episode end for the failing service, -0.10 for a service
              listed in ``state.healthy_services``.
            * -0.04 for a NOOP after step 3.
            * -0.10 for block-IP-range, create-index or failover actions.

        Args:
            state: Episode state; mutated in place.
            action: The action to apply.

        Returns:
            ``StepOutput`` carrying the same state object, this step's reward
            rounded to 4 decimals, the done flag and an info dict.
        """
        state.step += 1
        state._apply_sla_degradation()

        kind = action.action_type
        target = action.service or ""
        scenario = state._scenario
        broken = scenario["failing_service"]
        keywords = scenario["diagnosis_keywords"]
        hits_broken = target == broken

        reward = 0.0
        done = False
        info: Dict[str, Any] = {}

        result_text, error_text = self._apply_action_to_logs(state, action)

        def claim(tag: str, amount: float) -> float:
            """Return ``amount`` the first time ``tag`` is claimed, else 0.0."""
            if tag in state.rewards_given:
                return 0.0
            state.rewards_given.add(tag)
            return amount

        # --- Investigation rewards (granted once each) ---
        if kind in (ActionType.READ_LOGS, ActionType.SEARCH_LOGS) and hits_broken:
            reward += claim("logs_investigated", 0.15)

        if kind == ActionType.READ_METRICS and hits_broken:
            reward += claim("read_metrics", 0.10)

        if kind == ActionType.READ_RUNBOOK:
            reward += claim("runbook", 0.05)

        # --- Diagnosis ---
        if kind == ActionType.DIAGNOSE:
            claimed_cause = action.root_cause or ""
            type_ok = semantic_match(claimed_cause, keywords, threshold=1)
            service_ok = semantic_match(
                claimed_cause, [broken, broken.split("-")[0]]
            )
            result_text = f"Diagnosis recorded: {claimed_cause}"

            had_partial = "diagnose_partial" in state.rewards_given
            had_full = "diagnose_correct" in state.rewards_given
            if type_ok and service_ok:
                if not had_full:
                    # Earlier partial credit is topped up to the full 0.30.
                    reward += 0.15 if had_partial else 0.30
                    state.rewards_given.add("diagnose_correct")
            elif type_ok and not had_partial and not had_full:
                reward += 0.15
                state.rewards_given.add("diagnose_partial")

        # --- Remediation ---
        if kind == ActionType.RESTART_SERVICE:
            reward += self._penalty_blind_remediation(state, action, "restarted")
            if hits_broken:
                reward += 0.40
                entry = state.services[target]
                entry["status"] = "healthy"
                entry["memory_percent"] = round(self.rng.uniform(38, 48), 1)
                entry["error_rate"] = 0.0
                entry["latency_p99_ms"] = round(self.rng.uniform(20, 60), 0)
                entry["replicas_running"] = entry["replicas_desired"]
                # Only the critical alert (A001) is cleared.
                state.alerts = [a for a in state.alerts if a["id"] != "A001"]
                state.incident_resolved = True
                result_text = f"{target} restarted. Memory cleared. All pods healthy."
                done = True
                info["resolution"] = "incident_resolved"
            elif target in state.healthy_services:
                reward -= 0.10
                error_text = f"Collateral damage: {target} was healthy. Unnecessary restart."

        # --- Idling and inapplicable actions ---
        if kind == ActionType.NOOP and state.step > 3:
            reward -= 0.04

        inapplicable = (
            ActionType.BLOCK_IP_RANGE,
            ActionType.CREATE_INDEX,
            ActionType.FAILOVER,
        )
        # NOTE(review): the str() comparison presumably covers action types
        # that arrive as plain strings - confirm it is still needed.
        if kind in inapplicable or str(kind) in ("block_ip_range", "create_index", "failover"):
            reward -= 0.10
            error_text = f"Action {kind} is not applicable to this incident."

        # --- Bookkeeping ---
        state.total_reward = self._clamp(state.total_reward + reward)
        if not done and state.step >= state.max_steps:
            done = True
            info["reason"] = "max_steps_reached"

        # NOTE(review): the value returned by _build_observation is not used
        # here - confirm whether StepOutput is meant to carry it.
        state._build_observation(
            last_action_result=result_text, last_action_error=error_text
        )

        step_reward = round(reward, 4)
        state.action_history.append(
            {"step": state.step, "action": action.model_dump(), "reward": step_reward}
        )
        return StepOutput(next_state=state, reward=step_reward, done=done, info=info)
|
|