File size: 7,096 Bytes
4887b5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import random
import uuid
from datetime import datetime, timezone

from tasks.base import BaseTask, InternalState, StepOutput
from models import Action, StepResult, ServiceStatus, Alert, ActionType

class GeneratedTask(BaseTask):
    """Incident-response task whose scenario comes from an incident dict.

    The incident dict must provide ``affected_service``, ``failure_mode``,
    ``description``, ``ground_truth_root_cause`` and ``ground_truth_fix``.
    ``noise_alerts`` (an iterable of alert messages) and ``seed`` are
    optional.
    """

    # Baseline service topology of every generated episode.
    _SERVICE_NAMES = (
        "payment-service",
        "order-service",
        "user-service",
        "inventory-service",
        "api-gateway",
        "notification-service",
        "data-pipeline",
        "ml-inference-service",
    )

    # Metrics reported for a service that is not part of the incident.
    _HEALTHY_PROFILE = {
        "status": "healthy",
        "cpu_percent": 25.0,
        "memory_percent": 40.0,
        "error_rate": 0.01,
        "latency_p99_ms": 120.0,
        "replicas_running": 3,
        "replicas_desired": 3,
        "current_version": "v1.2.3",
    }

    # Metrics reported for the service hit by the incident.
    _DEGRADED_PROFILE = {
        "status": "degraded",
        "cpu_percent": 85.0,
        "memory_percent": 92.0,
        "error_rate": 0.35,
        "latency_p99_ms": 2800.0,
        "replicas_running": 2,
        "replicas_desired": 3,
        "current_version": "v1.2.4",
    }

    # Log lines shown for the affected service, keyed by failure mode.
    # Stored as tuples and copied per episode so state mutation cannot
    # leak back into this class-level table.
    _FAILURE_MODE_LOGS = {
        "oom": ("ERROR OutOfMemoryError: Java heap space",
                "WARN Memory usage at 98%, GC overhead limit exceeded"),
        "cascade": ("ERROR Connection pool exhausted: timeout after 30s",
                    "ERROR Failed to acquire connection from pool"),
        "corruption": ("WARN Price mismatch detected: expected 29.99 got 299.9",
                       "WARN Data validation failed for 847 records"),
        "security": ("WARN 1847 failed login attempts in 60s",
                     "WARN Rate limit exceeded from 185.220.101.x"),
        "database": ("WARN Slow query: seq_scan on orders (847ms)",
                     "WARN Query planner chose sequential scan, missing index"),
        "network_partition": ("ERROR Connection timeout to us-east-1",
                              "ERROR Health check failed: unreachable"),
    }

    def __init__(self, incident_dict: dict) -> None:
        """Store the incident and seed the base task's RNG.

        Args:
            incident_dict: Scenario description; ``seed`` defaults to 42
                when absent.
        """
        self.incident = incident_dict
        self.task_id = "generated"
        self.max_steps = 20
        # For compatibility with BaseTask which expects rng in __init__
        super().__init__(random.Random(incident_dict.get("seed", 42)))

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Uses ``datetime.now(timezone.utc)`` because ``datetime.utcnow()`` is
        deprecated; the tzinfo is stripped so the string format is unchanged.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _service_record(name: str, profile: dict, deployed_at: str) -> dict:
        """Build one service entry from a metrics profile."""
        return {
            "name": name,
            **profile,
            "last_deployed": deployed_at,
            "minutes_degraded": 0,
            "sla_breach": False,
        }

    @staticmethod
    def _alert_record(service: str, severity: str, message: str,
                      timestamp: str) -> dict:
        """Build one unacknowledged alert entry with a short random id."""
        return {
            "id": str(uuid.uuid4())[:8],
            "service": service,
            "severity": severity,
            "message": message,
            "timestamp": timestamp,
            "acknowledged": False,
        }

    def initialize(self) -> InternalState:
        """Create the starting state for a new episode.

        The affected service starts degraded and every other service starts
        healthy. One critical alert carries the incident description; each
        noise alert becomes a warning attributed to ``notification-service``.

        Returns:
            A fresh ``InternalState`` with the incident dict attached as
            ``_scenario``.

        Raises:
            KeyError: If a required incident key is missing.
        """
        affected = self.incident["affected_service"]
        failure_mode = self.incident["failure_mode"]
        now_iso = self._utc_now_iso()

        # Make sure the affected service exists even when the incident names
        # one outside the baseline topology; otherwise nothing would be
        # degraded and resolving the incident in step() would hit a KeyError.
        names = list(self._SERVICE_NAMES)
        if affected not in names:
            names.append(affected)

        services_dict = {
            name: self._service_record(
                name,
                self._DEGRADED_PROFILE if name == affected else self._HEALTHY_PROFILE,
                now_iso,
            )
            for name in names
        }

        # One CRITICAL alert for the real incident, then the noise alerts.
        active_alerts = [
            self._alert_record(affected, "critical",
                               self.incident["description"], now_iso)
        ]
        # A missing or None "noise_alerts" entry means no noise.
        active_alerts.extend(
            self._alert_record("notification-service", "warning", noise, now_iso)
            for noise in (self.incident.get("noise_alerts") or [])
        )

        # Unknown failure modes fall back to a single "normal" log line.
        failure_lines = self._FAILURE_MODE_LOGS.get(
            failure_mode, ("INFO Service running normally",)
        )
        logs = {
            name: (
                list(failure_lines)
                if name == affected
                else ["INFO Service running normally", "INFO Health check passed"]
            )
            for name in names
        }

        state = InternalState(
            episode_id=str(uuid.uuid4()),
            task_id="generated",
            step=0,
            max_steps=self.max_steps,
            services=services_dict,
            alerts=active_alerts,
            logs=logs,
            action_history=[],
            total_reward=0.0,
            incident_resolved=False,
            ground_truth_root_cause=self.incident["ground_truth_root_cause"],
            ground_truth_fix=self.incident["ground_truth_fix"],
            incident_start_time=now_iso,
            rewards_given=set(),
        )
        state._scenario = self.incident
        return state

    def step(self, state: InternalState, action: Action) -> StepOutput:
        """Apply one agent action, award rewards, and advance the episode.

        Each reward is granted at most once per episode:

        * 0.10 for ``read_logs`` on the affected service,
        * 0.30 for a ``diagnose`` whose text contains the ground-truth root
          cause (case-insensitive substring match),
        * 0.45 for applying the ground-truth fix action to the affected
          service, which also resolves the incident.

        Args:
            state: Episode state; mutated in place.
            action: The action chosen by the agent.

        Returns:
            A ``StepOutput`` with the mutated state, this step's reward, the
            done flag and an info dict.
        """
        reward = 0.0
        # Return value is intentionally unused here.
        # NOTE(review): helper is defined outside this class; presumably it
        # is called for its effect on ``state`` — confirm.
        self._apply_action_to_logs(state, action)

        # ActionType can be enum or string
        at = action.action_type
        at_val = at.value if hasattr(at, "value") else str(at)

        affected = self.incident["affected_service"]
        granted = state.rewards_given

        if (at_val == "read_logs"
                and action.service == affected
                and "read_logs" not in granted):
            reward += 0.10
            granted.add("read_logs")

        if at_val == "diagnose":
            diagnosis = action.diagnosis or action.root_cause or ""
            truth = (state.ground_truth_root_cause or "").lower()
            # A blank ground truth is a substring of every diagnosis, so it
            # must never count as a match.
            if (truth.strip()
                    and truth in diagnosis.lower()
                    and "diagnose" not in granted):
                reward += 0.30
                granted.add("diagnose")

        if (at_val == self.incident["ground_truth_fix"]
                and action.service == affected
                and "fix" not in granted):
            reward += 0.45
            granted.add("fix")
            state.incident_resolved = True
            service = state.services[affected]
            healthy = self._HEALTHY_PROFILE
            for key in ("status", "cpu_percent", "memory_percent",
                        "error_rate", "latency_p99_ms"):
                service[key] = healthy[key]
            # A healthy service runs all of its desired replicas.
            service["replicas_running"] = service["replicas_desired"]

        state.step += 1
        state.total_reward = self._clamp(state.total_reward + reward)

        out_of_steps = state.step >= self.max_steps
        done = state.incident_resolved or out_of_steps
        info = {}
        if state.incident_resolved:
            info["resolution"] = "incident_resolved"
        if out_of_steps:
            info["reason"] = "max_steps_reached"

        step_reward = round(reward, 4)
        state.action_history.append({
            "step": state.step,
            "action": action.model_dump(),
            "reward": step_reward,
        })

        return StepOutput(next_state=state, reward=step_reward, done=done, info=info)