File size: 10,307 Bytes
309bdb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import random
from typing import Optional, Dict, Any, List
from server.models.action import IncidentAction
from server.models.observation import IncidentObservation, ServiceMetrics, DependencyEdge, Alert, LogEntry
from server.simulation.scenarios.base_scenario import BaseScenario
from server.simulation.topology import ENTERPRISE_TOPOLOGY
from server.safety.guard import SafetyGuard, ValidationResult

class StateEngine:
    """Step-based simulation engine for a single incident scenario.

    Each call to :meth:`tick` optionally validates and applies one agent
    action, cascades failures to dependants listed in ``ENTERPRISE_TOPOLOGY``,
    adds random noise to the metrics, recomputes the blast radius and returns
    an ``IncidentObservation`` built from the resulting state.

    The state is the dict returned by ``scenario.get_initial_state()``; the
    engine mutates it in place and keeps a reference in ``current_state``.
    """

    # Task id reported when the scenario state carries no "task_id" key.
    _DEFAULT_TASK_ID = "task1_oom_crash"
    # Allowance that ``rate_limit_remaining`` is counted down from.
    _ACTION_RATE_LIMIT = 30
    # Simulated seconds that elapse on every tick.
    _SECONDS_PER_STEP = 60.0
    # ``drain_traffic`` is ignored for this service.
    _UNDRAINABLE_SERVICE = "haproxy-lb"
    # Actions that are accepted but leave the simulated state untouched.
    _NO_EFFECT_ACTIONS = frozenset({
        "inspect_logs",
        "pull_metrics",
        "trace_dependency_chain",
        "query_runbook",
        "check_alert_history",
        "analyze_error_budget",
        "publish_status_page",
    })

    def __init__(self, scenario: BaseScenario, seed: int = 42):
        """Create an engine for *scenario* with a deterministic RNG.

        Args:
            scenario: Source of the initial state (``get_initial_state()``).
            seed: Seed for the private ``random.Random`` used for metric noise.
        """
        self.rng = random.Random(seed)
        self.scenario = scenario
        self.step_count = 0
        self.simulated_time = 0.0
        # NOTE(review): nothing in this class appends to action_history;
        # presumably an external caller maintains it - confirm.
        self.action_history: List[Dict] = []
        self.safety_guard = SafetyGuard()
        self.current_state: Dict[str, Any] = {}
        # Safety-guard verdict for the action of the current tick
        # (None when the tick carried no action).
        self._last_action_allowed: Optional[bool] = None

    def tick(self, action: Optional[IncidentAction] = None) -> IncidentObservation:
        """Advance the simulation by one step and return the new observation.

        Args:
            action: Optional agent action. It is validated by the safety
                guard; only allowed actions mutate the state. Blocked actions
                increment ``state["safety_violations"]`` instead.

        Returns:
            The observation built from the state after this step.
        """
        if self.step_count == 0:
            self.current_state = self.scenario.get_initial_state()

        state = self.current_state
        state["step"] = self.step_count
        state["elapsed_seconds"] = self.simulated_time

        self._last_action_allowed = None
        if action is not None:
            validation = self.safety_guard.validate(action, state)
            self._last_action_allowed = bool(validation.allowed)
            # The audit call keeps its original position relative to the
            # counter updates in each branch, in case the guard reads them.
            if validation.allowed:
                self._apply_action_effects(action, state)
                state["actions_taken"] = state.get("actions_taken", 0) + 1
                self.safety_guard.audit_log_entry(action, validation, state)
            else:
                self.safety_guard.audit_log_entry(action, validation, state)
                state["safety_violations"] = state.get("safety_violations", 0) + 1

        self._propagate_failures(state)
        self._evolve_metrics(state)
        self._generate_alerts_and_logs(state)
        self._recalculate_blast_radius(state)

        self.step_count += 1
        self.simulated_time += self._SECONDS_PER_STEP

        return self._build_observation(state, action)

    def _get_current_state(self) -> Dict[str, Any]:
        """Return the persisted state, or a fresh initial state if none exists yet.

        The freshly generated state is returned but not stored on the engine.
        """
        if self.current_state:
            return self.current_state
        return self.scenario.get_initial_state()

    @staticmethod
    def _find_service_metrics(state: Dict[str, Any], service: Any) -> Any:
        """Return the first entry of ``state["service_metrics"]`` for *service*, or None."""
        return next(
            (sm for sm in state["service_metrics"] if sm.service == service),
            None,
        )

    def _apply_action_effects(self, action: IncidentAction, state: Dict[str, Any]) -> None:
        """Mutate *state* according to ``action.action_type``.

        Diagnostic and communication actions are no-ops. Unknown action types
        and actions whose target service is not present in
        ``state["service_metrics"]`` leave the state unchanged.
        """
        kind = action.action_type

        if kind in self._NO_EFFECT_ACTIONS:
            return

        if kind == "check_deploy_history":
            # Remember which services had their deploy history inspected.
            state.setdefault("deploy_history_checked", {})[action.target_service] = True
            return

        if kind == "acknowledge_alert":
            # Every active alert of the target service is acknowledged.
            for alert in state.get("active_alerts", []):
                if alert.service == action.target_service:
                    alert.is_acknowledged = True
            return

        if kind == "declare_incident_resolved":
            # Terminal action - marks episode for scoring.
            state["incident_declared_resolved"] = True
            return

        if kind == "request_human_escalation":
            # Escalation is recorded; the episode itself continues.
            state["escalation_requested"] = True
            return

        if kind == "drain_traffic" and action.target_service == self._UNDRAINABLE_SERVICE:
            # The load balancer itself is never drained.
            return

        if kind not in (
            "restart_service",
            "scale_up_replicas",
            "rollback_deployment",
            "toggle_feature_flag",
            "flush_cache",
            "increase_conn_pool",
            "drain_traffic",
            "restore_traffic",
            "trigger_failover",
        ):
            return

        sm = self._find_service_metrics(state, action.target_service)
        if sm is None:
            return

        if kind == "restart_service":
            sm.status = "recovering"
            sm.error_rate = max(0.0, sm.error_rate - 0.5)

        elif kind == "scale_up_replicas":
            # More replicas, less CPU and latency per replica.
            # ``parameters`` may be None, so fall back to an empty mapping.
            replica_delta = (action.parameters or {}).get("replica_count", 1)
            sm.replica_count = max(1, sm.replica_count + replica_delta)
            sm.cpu_percent = max(10, sm.cpu_percent * 0.7)
            sm.latency_p99_ms *= 0.8

        elif kind == "rollback_deployment":
            # Reverting to the stable version lowers the error rate.
            sm.error_rate = max(0.0, sm.error_rate - 0.4)
            sm.status = "recovering"

        elif kind == "toggle_feature_flag":
            # Disabling the problematic feature lowers the error rate.
            sm.error_rate = max(0.0, sm.error_rate - 0.3)

        elif kind == "flush_cache":
            # A flush causes a temporary latency spike.
            sm.status = "degraded"
            sm.latency_p99_ms *= 2
            state.setdefault("cache_flushed_services", set()).add(action.target_service)

        elif kind == "increase_conn_pool":
            # Only a missing reading falls back to 0.5; a genuine 0.0
            # utilization must not be mistaken for "unknown".
            utilization = sm.db_conn_pool_utilization
            if utilization is None:
                utilization = 0.5
            sm.db_conn_pool_utilization = max(0.0, utilization - 0.2)

        elif kind == "drain_traffic":
            sm.status = "drained"
            sm.rps = 0

        elif kind == "restore_traffic":
            sm.status = "healthy" if sm.error_rate < 0.1 else "degraded"
            sm.rps = 100  # Restored to baseline

        elif kind == "trigger_failover":
            sm.status = "recovering"
            sm.error_rate = 0.2  # Temporary spike during failover

    def _propagate_failures(self, state: Dict[str, Any]) -> None:
        """Raise the error rate of services that depend on a failing service.

        A service counts as failing when its status is "critical" or "down".
        Each dependant gains 0.3 error rate per tick, capped at 1.0.
        """
        metrics = state["service_metrics"]
        failing = {sm.service for sm in metrics if sm.status in ("critical", "down")}
        if not failing:
            return

        for sm in metrics:
            # A service missing from the topology has no known dependencies,
            # so it is skipped instead of raising KeyError.
            depends_on = ENTERPRISE_TOPOLOGY.get(sm.service, {}).get("depends_on", ())
            if any(dep in failing for dep in depends_on):
                sm.error_rate = min(1.0, sm.error_rate + 0.3)

    def _evolve_metrics(self, state: Dict[str, Any]) -> None:
        """Add random noise to every service's error rate and p99 latency.

        The error rate stays within [0, 1]; latency never drops below zero.
        """
        for sm in state["service_metrics"]:
            # Draw order (error rate, then latency) is kept so that seeded
            # runs stay reproducible.
            noisy_rate = sm.error_rate + self.rng.uniform(-0.01, 0.01)
            sm.error_rate = max(0.0, min(1.0, noisy_rate))
            sm.latency_p99_ms = max(0.0, sm.latency_p99_ms + self.rng.uniform(-5, 5))

    def _generate_alerts_and_logs(self, state: Dict[str, Any]) -> None:
        """Placeholder: alert and log generation is not implemented yet."""
        pass

    def _recalculate_blast_radius(self, state: Dict[str, Any]) -> float:
        """Store and return the fraction of services whose status is not "healthy".

        Returns:
            The value written to ``state["blast_radius"]``; 0.0 when there
            are no services at all.
        """
        metrics = state["service_metrics"]
        if not metrics:
            # No services means nothing is impacted (and avoids dividing by zero).
            state["blast_radius"] = 0.0
            return 0.0

        unhealthy = sum(1 for sm in metrics if sm.status != "healthy")
        blast_radius = min(1.0, unhealthy / len(metrics))
        state["blast_radius"] = blast_radius
        return blast_radius

    def _build_observation(self, state: Dict[str, Any], action: Optional[IncidentAction] = None) -> IncidentObservation:
        """Assemble the ``IncidentObservation`` for the current step.

        Args:
            state: Simulation state after this step's updates.
            action: The action submitted this step, if any. Its outcome is
                reported as failed when the safety guard blocked it in the
                most recent :meth:`tick`.
        """
        if action is None:
            action_result = None
            action_success = None
        elif getattr(self, "_last_action_allowed", None) is False:
            # The guard rejected the action, so it must not be reported as applied.
            action_result = "Action blocked by safety guard"
            action_success = False
        else:
            action_result = "Action applied"
            action_success = True

        blast_radius = state["blast_radius"]
        actions_taken = state.get("actions_taken", 0)

        return IncidentObservation(
            episode_id=state["episode_id"],
            # NOTE(review): uses the scenario's task id when the state
            # provides one - confirm the key name against the scenarios.
            task_id=state.get("task_id", self._DEFAULT_TASK_ID),
            step=state["step"],
            elapsed_seconds=state["elapsed_seconds"],
            max_steps=state["max_steps"],
            blast_radius=blast_radius,
            error_budget_burn_rate=state["error_budget_burn_rate"],
            estimated_revenue_loss_usd=blast_radius * 100000,  # Simplified
            active_user_sessions_impacted=int(blast_radius * 10000),
            active_alerts=state["active_alerts"],
            new_alerts_this_step=[],
            service_metrics=state["service_metrics"],
            recent_logs=[],
            logs_this_step=[],
            dependency_graph=state["dependency_graph"],
            last_action_result=action_result,
            last_action_success=action_success,
            actions_taken=actions_taken,
            available_actions=["inspect_logs", "restart_service"],  # Simplified
            safety_violations_this_episode=state.get("safety_violations", 0),
            # Never report a negative remaining allowance.
            rate_limit_remaining=max(0, self._ACTION_RATE_LIMIT - actions_taken),
        )