File size: 14,692 Bytes
5fe9036
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
"""
Core environment engine β€” implements reset/step/state for the SRE Incident Response env.
"""

import uuid
from typing import Any, Dict, Optional, Set, Tuple

from models import (
    Action,
    ActionType,
    GraderResult,
    INVESTIGATION_ACTIONS,
    Observation,
    REMEDIATION_ACTIONS,
    RootCauseCategory,
    ServiceState,
    ServiceStatus,
    State,
)
from env.scenario import IncidentScenario, RequiredFix
from tasks import SCENARIOS
from env.services import (
    format_config_diff,
    format_deploy_history,
    format_dependencies,
    format_logs,
    format_metrics,
    format_runbook,
    format_traces,
    generate_alerts,
    ping_service,
    recompute_health,
)


class Session:
    """Mutable bookkeeping for one episode played against a scenario."""

    def __init__(self, scenario: IncidentScenario, session_id: str):
        # Identity and progress counters.
        self.session_id = session_id
        self.scenario = scenario
        self.step_count = 0
        self.done = False
        self.cumulative_reward = 0.0

        # Per-episode service state, seeded from the scenario's service
        # configs: {name: {status, version, replicas}}.
        self.services: Dict[str, Dict[str, Any]] = {
            svc_name: {
                "status": svc_cfg.status,
                "version": svc_cfg.version,
                "replicas": svc_cfg.replicas,
            }
            for svc_name, svc_cfg in scenario.services.items()
        }

        # Names of services whose fix has been applied so far.
        self.fixed_services: Set[str] = set()

        # service_name -> fault_type, only for services flagged as a root
        # cause that also carry a non-empty fault type.
        self.root_cause_map: Dict[str, str] = {
            svc_name: svc_cfg.fault_type
            for svc_name, svc_cfg in scenario.services.items()
            if svc_cfg.is_root_cause and svc_cfg.fault_type
        }

        # History consumed by the grader.
        self.actions: list[Action] = []
        self.services_investigated: Set[str] = set()
        self.remediations_applied: list[Dict[str, Any]] = []
        self.diagnosis: Optional[Action] = None


class IncidentResponseEnv:
    """The SRE Incident Response OpenEnv environment."""

    def __init__(self):
        self.sessions: Dict[str, Session] = {}

    def get_task_ids(self) -> list[str]:
        """Return the ids of every scenario registered in ``SCENARIOS``."""
        task_ids = list(SCENARIOS)
        return task_ids

    def reset(self, task_id: str, seed: int = 0) -> Tuple[Observation, str]:
        """Start a new episode for the given task.

        Args:
            task_id: Key into ``SCENARIOS`` selecting the scenario to play.
            seed: Accepted for interface compatibility; not used by this method.

        Returns:
            ``(observation, session_id)`` — the initial observation and the
            id to pass to ``step`` / ``state``.

        Raises:
            ValueError: if ``task_id`` is not a registered scenario.
        """
        if task_id not in SCENARIOS:
            raise ValueError(f"Unknown task_id: {task_id}. Available: {list(SCENARIOS.keys())}")

        scenario = SCENARIOS[task_id]

        # Only the first 8 hex chars (32 bits) of the UUID are kept, so two
        # episodes can draw the same id. Re-draw until the id is unused so a
        # new episode never silently replaces a live one in the registry.
        session_id = str(uuid.uuid4())[:8]
        while session_id in self.sessions:
            session_id = str(uuid.uuid4())[:8]

        session = Session(scenario, session_id)
        self.sessions[session_id] = session

        # Build initial observation
        obs = self._build_observation(session, action_result=None)
        return obs, session_id

    def step(self, session_id: str, action: Action) -> Tuple[Observation, float, bool, Dict]:
        """Execute an action and return (observation, reward, done, info).

        Every call on a live episode consumes one step and is recorded in the
        session's action history, including calls rejected because the
        ``service`` is missing or unknown. Rejected calls earn no reward and
        put the message under ``info["error"]``.

        The episode ends when a diagnosis is submitted or when the step budget
        (``scenario.max_steps``) is exhausted. In both cases the episode is
        graded, the grader score is the reward returned for that step, and the
        grader output is placed under ``info["grader_result"]``.

        Args:
            session_id: Id returned by ``reset``.
            action: The action to apply.

        Returns:
            ``(observation, reward, done, info)``.

        Raises:
            ValueError: if ``session_id`` is not a known session.
        """
        session = self._get_session(session_id)
        if session.done:
            obs = self._build_observation(session, action_result="Episode already finished.")
            return obs, 0.0, True, {"error": "Episode already finished."}

        session.step_count += 1
        session.actions.append(action)

        reward = 0.0
        action_result = ""
        info: Dict[str, Any] = {}

        service_name = action.service
        scenario = session.scenario
        is_diagnosis = action.action_type == ActionType.SUBMIT_DIAGNOSIS

        # Every action except the final diagnosis must target a known service.
        error: Optional[str] = None
        if not is_diagnosis:
            if not service_name:
                error = "Action requires a 'service' parameter."
            elif service_name not in scenario.services:
                error = f"Unknown service: '{service_name}'. Available: {list(scenario.services.keys())}"

        if error is not None:
            # Rejected action: no handler runs, but we deliberately fall
            # through to the step-budget check below so that a stream of
            # invalid actions cannot keep the episode alive forever.
            action_result = error
            info["error"] = error

        # ── Investigation actions ──
        elif action.action_type in INVESTIGATION_ACTIONS:
            session.services_investigated.add(service_name)
            action_result = self._handle_investigation(session, action)

            # Small reward for investigating root cause services
            if service_name in scenario.root_cause_services:
                reward = 0.01
            else:
                reward = 0.0

        # ── Remediation actions ──
        elif action.action_type in REMEDIATION_ACTIONS:
            action_result, reward = self._handle_remediation(session, action)
            session.remediations_applied.append({
                "action": action.action_type.value,
                "service": service_name,
                "target_version": action.target_version,
                "replicas": action.replicas,
            })

        # ── Submit diagnosis ──
        elif is_diagnosis:
            session.diagnosis = action
            session.done = True
            grader_result = self._grade(session)
            reward = grader_result.score
            action_result = f"Diagnosis submitted. Score: {grader_result.score:.2f}"
            info["grader_result"] = grader_result.model_dump()

        # Check max steps
        if session.step_count >= scenario.max_steps and not session.done:
            session.done = True
            if session.diagnosis is None:
                # Auto-grade with whatever we have; the grader score replaces
                # this step's own reward as the value returned to the caller.
                grader_result = self._grade(session)
                reward = grader_result.score
                info["grader_result"] = grader_result.model_dump()
                action_result += f"\n[MAX STEPS REACHED] Episode ended. Score: {grader_result.score:.2f}"

        # Accumulate the reward that is actually returned, so the running
        # total always equals the sum of rewards handed back by step() — also
        # when the episode is auto-graded on the final step.
        session.cumulative_reward += reward

        obs = self._build_observation(session, action_result=action_result, reward=reward)
        obs.done = session.done
        if "grader_result" in info:
            obs.score = info["grader_result"]["score"]

        return obs, reward, session.done, info

    def state(self, session_id: str) -> State:
        """Summarise the episode's progress as a ``State`` snapshot.

        Raises:
            ValueError: if ``session_id`` is not a known session.
        """
        session = self._get_session(session_id)
        scenario = session.scenario

        # Flatten the recorded history into plain strings for the snapshot.
        action_names = [taken.action_type.value for taken in session.actions]
        remediation_labels = [
            f"{entry['action']}({entry['service']})"
            for entry in session.remediations_applied
        ]

        return State(
            session_id=session.session_id,
            task_id=scenario.task_id,
            step_count=session.step_count,
            max_steps=scenario.max_steps,
            done=session.done,
            actions_taken=action_names,
            services_investigated=list(session.services_investigated),
            remediations_applied=remediation_labels,
            cumulative_reward=round(session.cumulative_reward, 4),
        )

    # ── Internal helpers ───────────────────────────────────────────────

    def _get_session(self, session_id: str) -> Session:
        if session_id not in self.sessions:
            raise ValueError(f"Unknown session: {session_id}")
        return self.sessions[session_id]

    def _build_observation(
        self, session: Session, action_result: Optional[str], reward: float = 0.0,
    ) -> Observation:
        """Assemble the observation describing the session's current state.

        Args:
            session: Episode to describe.
            action_result: Text outcome of the last action, or ``None`` for
                the initial observation.
            reward: Reward to report on the observation; rounded to 4 places.

        Returns:
            An ``Observation`` carrying per-service state, active alerts and a
            simulated timestamp. The incident summary is only included while
            the step counter is still 0.
        """
        from datetime import datetime, timedelta, timezone

        scenario = session.scenario
        svc_states = {
            name: ServiceState(
                status=data["status"],
                version=data["version"],
                replicas=data["replicas"],
            )
            for name, data in session.services.items()
        }

        alerts = generate_alerts(
            session.services, scenario.initial_alerts, session.fixed_services,
        )

        # Simulated clock: one minute per step, starting at 04:00 UTC. Using
        # datetime arithmetic keeps the timestamp valid once the step counter
        # reaches 60 (formatting the counter directly as the minute field
        # would yield e.g. "04:60:00Z"); steps 0-59 render exactly as before.
        clock = datetime(2026, 4, 6, 4, 0, tzinfo=timezone.utc) + timedelta(
            minutes=session.step_count
        )

        return Observation(
            step_number=session.step_count,
            timestamp=clock.strftime("%Y-%m-%dT%H:%M:%SZ"),
            services=svc_states,
            active_alerts=alerts,
            incident_summary=scenario.incident_summary if session.step_count == 0 else "",
            action_result=action_result,
            reward=round(reward, 4),
            done=session.done,
        )

    def _handle_investigation(self, session: Session, action: Action) -> str:
        """Render the read-only data requested by an investigation action.

        Looks up the scenario data for ``action.service`` that corresponds to
        the action type and returns it formatted as text. Action types with no
        branch here fall back to a generic "No data available" message.
        """
        scenario = session.scenario
        target = action.service
        kind = action.action_type

        if kind == ActionType.READ_LOGS:
            return format_logs(scenario.logs.get(target, []))

        if kind == ActionType.CHECK_METRICS:
            return format_metrics(scenario.metrics.get(target, []))

        if kind == ActionType.PING_SERVICE:
            # Ping reflects the live session status, not the scenario default.
            return ping_service(session.services[target]["status"], target)

        if kind == ActionType.CHECK_DEPENDENCIES:
            upstream = scenario.dependencies.get(target, [])
            report = format_dependencies(upstream)
            # Append the current health of each dependency the session tracks.
            health_lines = [
                f"  {dep}: {session.services[dep]['status'].value}"
                for dep in upstream
                if dep in session.services
            ]
            if health_lines:
                report += "\n\nDependency health:\n" + "\n".join(health_lines)
            return report

        if kind == ActionType.INSPECT_DEPLOY:
            return format_deploy_history(scenario.deploy_history.get(target, []))

        if kind == ActionType.QUERY_TRACES:
            return format_traces(scenario.traces.get(target, []))

        if kind == ActionType.CHECK_RUNBOOK:
            return format_runbook(scenario.runbooks.get(target, ""))

        if kind == ActionType.DIFF_CONFIG:
            return format_config_diff(scenario.configs.get(target, {}))

        return f"No data available for {kind.value} on {target}."

    def _handle_remediation(self, session: Session, action: Action) -> Tuple[str, float]:
        """Apply a remediation action to the session and score it.

        An action that matches one of the scenario's required fixes marks the
        service as fixed and earns +0.05. Restart, rollback and scale-up
        actions that do not help are penalised with -0.05, as is draining
        traffic from a service that is not a root cause. After the action,
        service health is recomputed and a post-remediation summary of the
        still-unhealthy services is appended to the result text.

        Args:
            session: Episode whose service state is mutated in place.
            action: The remediation to apply; ``action.service`` is the target.

        Returns:
            ``(result_text, reward)``.
        """
        scenario = session.scenario
        svc = action.service
        reward = 0.0
        result = ""

        # Check if this remediation matches any required fix
        fix_matched = any(
            self._fix_matches(action, req_fix) for req_fix in scenario.required_fixes
        )
        if fix_matched:
            session.fixed_services.add(svc)
            reward = 0.05

        if action.action_type == ActionType.RESTART_SERVICE:
            if fix_matched:
                session.services[svc]["status"] = ServiceStatus.HEALTHY
                result = f"Service '{svc}' restarted successfully. Status: HEALTHY"
            else:
                # Restart was not the required fix: no effect on the underlying issue
                current = session.services[svc]["status"]
                if current == ServiceStatus.DOWN and svc in session.root_cause_map:
                    result = f"Service '{svc}' restarted but crashed again — underlying issue persists."
                elif current == ServiceStatus.HEALTHY:
                    result = f"Service '{svc}' restarted. It was already healthy — no change."
                else:
                    result = f"Service '{svc}' restarted. Status unchanged — issue is caused by an upstream dependency."
                reward = -0.05

        elif action.action_type == ActionType.ROLLBACK_DEPLOY:
            if fix_matched:
                session.services[svc]["version"] = action.target_version or ""
                session.services[svc]["status"] = ServiceStatus.HEALTHY
                result = (
                    f"Service '{svc}' rolled back to {action.target_version}. "
                    f"Pods restarting with previous version... Status: HEALTHY"
                )
            else:
                # The recorded version is left untouched for a rollback that
                # did not match a required fix.
                current_version = session.services[svc]["version"]
                result = (
                    f"Rolled back '{svc}' to {action.target_version}, but this didn't resolve the issue. "
                    f"Previous version was {current_version}."
                )
                reward = -0.05

        elif action.action_type == ActionType.SCALE_UP:
            # Default to 3 replicas when the action does not specify a count.
            replicas = action.replicas or 3
            session.services[svc]["replicas"] = replicas
            # NOTE(review): scaling up any root-cause service counts as a fix
            # here even when the scenario's required fix is a different action
            # — confirm this leniency is intended.
            if fix_matched or (svc in scenario.root_cause_services):
                session.fixed_services.add(svc)
                session.services[svc]["status"] = ServiceStatus.HEALTHY
                result = f"Service '{svc}' scaled to {replicas} replicas. Memory pressure alleviated. Status: HEALTHY"
                reward = 0.05
            else:
                result = f"Service '{svc}' scaled to {replicas} replicas. No effect on the underlying issue."
                reward = -0.05

        elif action.action_type == ActionType.DRAIN_TRAFFIC:
            result = f"Traffic drained from '{svc}'. Service is no longer receiving requests."
            if svc not in scenario.root_cause_services:
                reward = -0.05

        # Recompute health after remediation
        session.services = recompute_health(
            session.services,
            scenario.dependencies,
            session.fixed_services,
            session.root_cause_map,
        )

        # Add post-remediation status summary
        still_broken = [
            name for name, data in session.services.items()
            if data["status"] != ServiceStatus.HEALTHY
        ]
        if still_broken:
            result += f"\n\n[POST-REMEDIATION CHECK] Services still unhealthy: {', '.join(still_broken)}"
        else:
            result += "\n\n[POST-REMEDIATION CHECK] All services are now HEALTHY."

        return result, reward

    def _fix_matches(self, action: Action, req_fix: RequiredFix) -> bool:
        """Check if an action matches a required fix."""
        if action.action_type.value != req_fix.action:
            return False
        if action.service != req_fix.service:
            return False
        if req_fix.target_version and action.target_version != req_fix.target_version:
            return False
        return True

    def _grade(self, session: Session) -> GraderResult:
        """Score the episode by delegating to ``graders.grader.grade_episode``."""
        # Imported at call time rather than at module load; presumably this
        # avoids a circular import with the grader module — TODO confirm.
        from graders.grader import grade_episode

        verdict = grade_episode(session)
        return verdict