File size: 11,519 Bytes
06b4790
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77eea12
 
06b4790
77eea12
06b4790
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
849b14a
 
 
 
 
 
06b4790
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
from __future__ import annotations
import uuid
from typing import Dict, Any, List
from models import Action, ActionType
from tasks.base import BaseTask, InternalState, StepOutput, semantic_match

# Stored as ``incident_start_time`` on every episode's InternalState.
INCIDENT_TIME = "2026-03-30T10:14:47Z"

# Candidate incidents for the easy task; EasyTask.initialize picks one per episode.
#   failing_service    - the service marked "down"; restarting it resolves the incident
#   root_cause / fix   - ground-truth labels copied onto the episode state
#   alert_msg          - prefix of the critical alert A001 text
#   language           - selects the crash-log template in _make_logs ("java",
#                        "python"; any other value uses the Node.js template)
#   diagnosis_keywords - terms checked with semantic_match(threshold=1) when
#                        grading a DIAGNOSE action
SCENARIOS = [
    {
        "failing_service": "payment-service",
        "root_cause": "memory_leak_payment_service",
        "fix": "restart payment-service",
        "alert_msg": "payment-service pod restarting (OOMKilled)",
        "language": "java",
        "diagnosis_keywords": ["memory", "oom", "heap", "leak", "outofmemory", "kill"],
    },
    {
        "failing_service": "order-service",
        "root_cause": "memory_leak_order_service",
        "fix": "restart order-service",
        "alert_msg": "order-service pod restarting (OOMKilled)",
        "language": "python",
        "diagnosis_keywords": ["memory", "oom", "heap", "leak", "segfault", "kill", "allocat"],
    },
    {
        "failing_service": "user-service",
        "root_cause": "memory_leak_user_service",
        "fix": "restart user-service",
        "alert_msg": "user-service pod restarting (OOMKilled)",
        "language": "node",
        "diagnosis_keywords": ["memory", "heap", "oom", "leak", "javascript", "kill"],
    },
]

# Every service in the simulated topology, including whichever one is failing.
ALL_SERVICES = ["payment-service", "order-service", "user-service", "api-gateway"]
# Reported as ``current_version`` for each service.
VERSIONS = {
    "payment-service": "v4.2.1", "order-service": "v1.8.2",
    "user-service": "v3.0.5", "api-gateway": "v2.1.0",
}
# Exposed as ``service_dependencies``: api-gateway calls the three backends,
# which make no downstream calls of their own.
DEPENDENCIES = [
    {"service": "api-gateway", "calls": ["payment-service", "order-service", "user-service"], "called_by": []},
    {"service": "payment-service", "calls": [], "called_by": ["api-gateway"]},
    {"service": "order-service", "calls": [], "called_by": ["api-gateway"]},
    {"service": "user-service", "calls": [], "called_by": ["api-gateway"]},
]

def _make_logs(scenario, heap1, heap2, restart_count):
    """Return the simulated log lines for every service, keyed by service name.

    The failing service (``scenario["failing_service"]``) gets a crash trace
    whose wording depends on ``scenario["language"]``: "java" and "python"
    have dedicated templates; any other value falls back to the Node.js one.
    ``heap1`` and ``heap2`` are interpolated as the two rising memory
    readings, and ``restart_count`` as the restart number.

    Every other entry of ``ALL_SERVICES`` gets a short log. api-gateway
    reports 503s and an open circuit breaker for the failing service; the
    rest report healthy. The failing service is the first key of the result,
    followed by the remaining services in ``ALL_SERVICES`` order.
    """
    culprit = scenario["failing_service"]
    runtime = scenario["language"]

    def jvm_trace():
        return [
            "[10:13:55] INFO  Request processed 200 38ms",
            f"[10:14:35] WARN  Heap usage at {heap1}% - approaching threshold",
            f"[10:14:41] WARN  Heap usage at {heap2}%",
            "[10:14:45] WARN  GC overhead limit exceeded - major GC running",
            "[10:14:47] ERROR java.lang.OutOfMemoryError: Java heap space",
            "[10:14:47] ERROR   at com.payments.ChargeProcessor.process(ChargeProcessor.java:142)",
            f"[10:14:48] FATAL Service entering crash loop - pod restart #{restart_count}",
        ]

    def cpython_trace():
        return [
            "[10:13:55] INFO  POST /orders 200 55ms",
            f"[10:14:35] WARN  RSS memory {heap1}% of pod limit",
            f"[10:14:41] WARN  RSS memory {heap2}% of pod limit - approaching OOM",
            "[10:14:46] ERROR Memory allocator: no more pages available",
            "[10:14:47] ERROR Fatal Python error: Segmentation fault (memory allocator exhausted)",
            f"[10:14:48] FATAL Pod killed by OOM killer - restart #{restart_count}",
        ]

    def node_trace():
        # heap1 is a percentage of a 200MB heap, hence heap1 * 2 megabytes.
        return [
            "[10:13:55] INFO  GET /users/profile 200 9ms",
            f"[10:14:35] WARN  Heap used: {heap1}% ({heap1 * 2}MB / 200MB)",
            f"[10:14:41] WARN  Heap used: {heap2}% - GC pressure increasing",
            "[10:14:47] ERROR FATAL ERROR: Reached heap limit - JavaScript heap out of memory",
            f"[10:14:48] FATAL Container OOMKilled - restart #{restart_count}",
        ]

    def bystander_lines(peer):
        # api-gateway reports the upstream failure; every other service looks healthy.
        if peer != "api-gateway":
            return ["[10:14:30] INFO  Service healthy - 0 errors"]
        return [
            "[10:14:30] INFO  GET /api/v1/health 200 3ms",
            f"[10:14:48] WARN  Upstream {culprit} returned 503",
            f"[10:14:49] WARN  Circuit breaker OPEN for {culprit}",
        ]

    crash_trace = {"java": jvm_trace, "python": cpython_trace}.get(runtime, node_trace)

    logs = {culprit: crash_trace()}
    logs.update((peer, bystander_lines(peer)) for peer in ALL_SERVICES if peer != culprit)
    return logs


class EasyTask(BaseTask):
    """Easy incident: one backend service is crash-looping after OOM kills.

    A scenario from ``SCENARIOS`` is chosen per episode. The agent is
    rewarded for investigating the failing service and for diagnosing a
    memory problem in that service. Restarting the failing service resolves
    the incident and ends the episode.
    """

    def initialize(self) -> InternalState:
        """Build a fresh, randomized episode state.

        Picks a scenario and draws the telemetry with ``self.rng``: heap
        readings and restart count for the logs, plus per-service CPU, memory,
        error-rate and latency figures.
        - The failing service is "down" with 0 running replicas.
        - api-gateway is "degraded".
        - Every other service is "healthy".
        - Two alerts are raised: A001 (critical, failing service) and A002
          (warning, api-gateway).

        The chosen scenario dict is stored on ``state._scenario`` so that
        :meth:`step` can look up the failing service and diagnosis keywords.
        """
        # NOTE(review): randint(0, len - 1) is inclusive on both ends for
        # random.Random; if self.rng were a numpy RandomState the upper bound
        # would be exclusive and the last scenario could never be chosen — confirm.
        scenario = SCENARIOS[self.rng.randint(0, len(SCENARIOS) - 1)]
        failing = scenario["failing_service"]
        heap1 = self.rng.randint(74, 83)
        heap2 = heap1 + self.rng.randint(5, 10)
        restart_count = self.rng.randint(2, 6)

        # The order of rng draws below determines every per-service value,
        # so it must stay in ALL_SERVICES order.
        services: Dict[str, dict] = {}
        for name in ALL_SERVICES:
            if name == failing:
                services[name] = {
                    "name": name, "status": "down",
                    "cpu_percent": round(self.rng.uniform(5, 20), 1),
                    "memory_percent": round(self.rng.uniform(93, 99), 1),
                    "error_rate": round(self.rng.uniform(8.0, 15.0), 2),
                    "latency_p99_ms": round(self.rng.uniform(5000, 9000), 0),
                    "replicas_running": 0, "replicas_desired": 3,
                    "current_version": VERSIONS[name],
                    "last_deployed": "2026-03-28T14:00:00Z",
                    "minutes_degraded": 0, "sla_breach": False,
                }
            elif name == "api-gateway":
                services[name] = {
                    "name": name, "status": "degraded",
                    "cpu_percent": round(self.rng.uniform(35, 55), 1),
                    "memory_percent": round(self.rng.uniform(40, 55), 1),
                    "error_rate": round(self.rng.uniform(2.0, 5.0), 2),
                    "latency_p99_ms": round(self.rng.uniform(800, 1500), 0),
                    "replicas_running": 2, "replicas_desired": 2,
                    "current_version": VERSIONS[name],
                    "last_deployed": "2026-03-25T09:00:00Z",
                    "minutes_degraded": 0, "sla_breach": False,
                }
            else:
                services[name] = {
                    "name": name, "status": "healthy",
                    "cpu_percent": round(self.rng.uniform(20, 40), 1),
                    "memory_percent": round(self.rng.uniform(30, 48), 1),
                    "error_rate": 0.0,
                    "latency_p99_ms": round(self.rng.uniform(8, 30), 0),
                    "replicas_running": 2, "replicas_desired": 2,
                    "current_version": VERSIONS[name],
                    "last_deployed": "2026-03-20T11:00:00Z",
                    "minutes_degraded": 0, "sla_breach": False,
                }

        alerts = [
            {
                "id": "A001", "severity": "critical", "service": failing,
                "message": f"{scenario['alert_msg']} - {restart_count} times in 5 minutes",
                "timestamp": "2026-03-30T10:14:48Z", "acknowledged": False,
            },
            {
                "id": "A002", "severity": "warning", "service": "api-gateway",
                "message": f"Upstream {failing} returning 503 - circuit breaker open",
                "timestamp": "2026-03-30T10:14:52Z", "acknowledged": False,
            },
        ]

        state = InternalState(
            episode_id=str(uuid.uuid4()), task_id="easy", step=0, max_steps=15,
            services=services, alerts=alerts,
            logs=_make_logs(scenario, heap1, heap2, restart_count),
            action_history=[], total_reward=0.0, incident_resolved=False,
            ground_truth_root_cause=scenario["root_cause"],
            ground_truth_fix=scenario["fix"],
            incident_start_time=INCIDENT_TIME,
            # Everything except the failing service, api-gateway included even
            # though it is reported as "degraded".
            healthy_services=[s for s in ALL_SERVICES if s != failing],
            service_dependencies=DEPENDENCIES,
        )
        state._scenario = scenario
        return state

    def step(self, state: InternalState, action: Action) -> StepOutput:
        """Advance the episode by one agent action and score it.

        Each investigative bonus is paid at most once per episode; payments
        are tracked in ``state.rewards_given``.

        Rewards:
        - read/search logs of the failing service: +0.15
        - read metrics of the failing service: +0.10
        - read the runbook: +0.05
        - diagnose naming both a memory-type keyword and the failing service:
          +0.30 in total. Naming only the keyword pays +0.15 ("partial"); a
          later correct diagnosis adds the remaining +0.15.
        - restart the failing service: +0.40, plus whatever
          ``_penalty_blind_remediation`` returns. This resolves the incident
          and ends the episode.

        Penalties:
        - restart a service in ``state.healthy_services``: -0.10, plus the
          ``_penalty_blind_remediation`` value.
        - NOOP after step 3: -0.04
        - block_ip_range / create_index / failover (not applicable here): -0.10

        The running total is passed through ``self._clamp``. The per-step
        reward is returned rounded to 4 decimals. The episode also ends once
        ``state.max_steps`` is reached.
        """
        state.step += 1
        state._apply_sla_degradation()
        at = action.action_type
        svc = action.service or ""
        scenario = state._scenario
        failing = scenario["failing_service"]
        keywords = scenario["diagnosis_keywords"]
        reward = 0.0
        done = False
        info: Dict[str, Any] = {}

        def award_once(key: str, amount: float) -> float:
            """Return ``amount`` the first time ``key`` is awarded this episode, else 0.0."""
            if key in state.rewards_given:
                return 0.0
            state.rewards_given.add(key)
            return amount

        result_text, error_text = self._apply_action_to_logs(state, action)

        if at in (ActionType.READ_LOGS, ActionType.SEARCH_LOGS) and svc == failing:
            reward += award_once("logs_investigated", 0.15)

        if at == ActionType.READ_METRICS and svc == failing:
            reward += award_once("read_metrics", 0.10)

        if at == ActionType.READ_RUNBOOK:
            reward += award_once("runbook", 0.05)

        if at == ActionType.DIAGNOSE:
            rc = action.root_cause or ""
            correct_type = semantic_match(rc, keywords, threshold=1)
            # Accept either the full service name or its short prefix ("payment").
            correct_svc = semantic_match(rc, [failing, failing.split("-")[0]])
            result_text = f"Diagnosis recorded: {rc}"
            if correct_type and correct_svc:
                if "diagnose_correct" not in state.rewards_given:
                    # If a partial diagnosis already paid 0.15, top it up with the
                    # remaining 0.15 so the total diagnosis reward is 0.30 either way.
                    bonus = 0.30 if "diagnose_partial" not in state.rewards_given else 0.15
                    reward += bonus
                    state.rewards_given.add("diagnose_correct")
            elif correct_type:
                if "diagnose_partial" not in state.rewards_given and "diagnose_correct" not in state.rewards_given:
                    reward += 0.15
                    state.rewards_given.add("diagnose_partial")

        if at == ActionType.RESTART_SERVICE:
            reward += self._penalty_blind_remediation(state, action, "restarted")
            if svc == failing:
                reward += 0.40
                recovered = state.services[svc]
                recovered["status"] = "healthy"
                recovered["memory_percent"] = round(self.rng.uniform(38, 48), 1)
                recovered["error_rate"] = 0.0
                recovered["latency_p99_ms"] = round(self.rng.uniform(20, 60), 0)
                recovered["replicas_running"] = recovered["replicas_desired"]
                # NOTE(review): only A001 is cleared; api-gateway stays "degraded" and
                # alert A002 remains — confirm this is intended after resolution.
                state.alerts = [a for a in state.alerts if a["id"] != "A001"]
                state.incident_resolved = True
                result_text = f"{svc} restarted. Memory cleared. All pods healthy."
                done = True
                info["resolution"] = "incident_resolved"
            elif svc in state.healthy_services:
                reward -= 0.10
                error_text = f"Collateral damage: {svc} was healthy. Unnecessary restart."

        if at == ActionType.NOOP and state.step > 3:
            reward -= 0.04

        # Remediations that belong to other incident types.
        # NOTE(review): the str(at) comparison looks like a fallback for a plain-string
        # action_type rather than an ActionType member — confirm it is still needed.
        if at in (ActionType.BLOCK_IP_RANGE, ActionType.CREATE_INDEX, ActionType.FAILOVER) or str(at) in ("block_ip_range", "create_index", "failover"):
            reward -= 0.10
            error_text = f"Action {at} is not applicable to this incident."

        state.total_reward = self._clamp(state.total_reward + reward)
        if state.step >= state.max_steps and not done:
            done = True
            info["reason"] = "max_steps_reached"

        # The return value is not used here (StepOutput carries the state itself).
        # The call is kept in case it records the action result on ``state``.
        # NOTE(review): drop it if _build_observation turns out to be pure.
        state._build_observation(last_action_result=result_text, last_action_error=error_text)
        step_reward = round(reward, 4)
        state.action_history.append({"step": state.step, "action": action.model_dump(), "reward": step_reward})
        return StepOutput(next_state=state, reward=step_reward, done=done, info=info)