Spaces:
Sleeping
Sleeping
| import random | |
| from typing import Optional, Dict, Any, List | |
| from server.models.action import IncidentAction | |
| from server.models.observation import IncidentObservation, ServiceMetrics, DependencyEdge, Alert, LogEntry | |
| from server.simulation.scenarios.base_scenario import BaseScenario | |
| from server.simulation.topology import ENTERPRISE_TOPOLOGY | |
| from server.safety.guard import SafetyGuard, ValidationResult | |
class StateEngine:
    """Step-based simulator that evolves an incident scenario one tick at a time.

    Each call to :meth:`tick` optionally validates and applies an agent action,
    then propagates failures along the dependency topology, adds metric noise,
    recomputes the blast radius, and returns an ``IncidentObservation``.
    """

    # Simulated seconds that elapse per tick.
    STEP_SECONDS = 60.0
    # Action budget reported to the agent as ``rate_limit_remaining``.
    RATE_LIMIT = 30
    # Action types that mutate nothing here. They are still validated and
    # written to the safety guard's audit log in ``tick``.
    NO_EFFECT_ACTIONS = frozenset({
        "inspect_logs",
        "pull_metrics",
        "trace_dependency_chain",
        "query_runbook",
        "check_alert_history",
        "analyze_error_budget",
        "publish_status_page",
    })

    def __init__(self, scenario: BaseScenario, seed: int = 42):
        """Create an engine for ``scenario``.

        Args:
            scenario: Provides the initial state via ``get_initial_state()``.
            seed: Seed for the engine-local RNG, so metric noise is reproducible.
        """
        self.rng = random.Random(seed)
        self.scenario = scenario
        self.step_count = 0
        self.simulated_time = 0.0
        # NOTE(review): not appended to anywhere in this class.
        self.action_history: List[Dict] = []
        self.safety_guard = SafetyGuard()
        # Filled from the scenario on the first tick.
        self.current_state: Dict[str, Any] = {}

    def tick(self, action: Optional[IncidentAction] = None) -> IncidentObservation:
        """Advance the simulation by one step and return the new observation.

        On the first tick (``step_count == 0``) the state is loaded from the
        scenario. If ``action`` is given, the safety guard validates it first:
        - An allowed action is applied and counted in ``actions_taken``.
        - A blocked action is counted in ``safety_violations``.
        Both outcomes are recorded in the guard's audit log.

        Args:
            action: The agent's action for this step, or ``None`` to let time pass.

        Returns:
            The observation after this step's dynamics have run.
        """
        if self.step_count == 0:
            self.current_state = self.scenario.get_initial_state()
        state = self.current_state
        state["step"] = self.step_count
        state["elapsed_seconds"] = self.simulated_time

        action_allowed = True
        if action is not None:
            validation = self.safety_guard.validate(action, state)
            action_allowed = bool(validation.allowed)
            if action_allowed:
                self._apply_action_effects(action, state)
                state["actions_taken"] = state.get("actions_taken", 0) + 1
                self.safety_guard.audit_log_entry(action, validation, state)
            else:
                # The blocked action is audited before the violation is counted.
                self.safety_guard.audit_log_entry(action, validation, state)
                state["safety_violations"] = state.get("safety_violations", 0) + 1

        self._propagate_failures(state)
        self._evolve_metrics(state)
        self._generate_alerts_and_logs(state)
        self._recalculate_blast_radius(state)

        self.step_count += 1
        self.simulated_time += self.STEP_SECONDS
        return self._build_observation(state, action, action_allowed=action_allowed)

    def _get_current_state(self) -> Dict[str, Any]:
        """Return the persisted state, or a fresh initial state if none exists yet."""
        return self.current_state if self.current_state else self.scenario.get_initial_state()

    @staticmethod
    def _find_service_metrics(state: Dict[str, Any], service: Any) -> Any:
        """Return the first entry in ``state["service_metrics"]`` for ``service``, or None."""
        return next((sm for sm in state["service_metrics"] if sm.service == service), None)

    def _apply_action_effects(self, action: IncidentAction, state: Dict[str, Any]) -> None:
        """Mutate ``state`` according to ``action.action_type``.

        - Diagnostic and communication actions (``NO_EFFECT_ACTIONS``) change nothing.
        - A few actions set episode-level flags in ``state``.
        - Remediation actions adjust the metrics of ``action.target_service``.
        Unknown action types are ignored, as are remediation actions whose
        target matches no service.
        """
        kind = action.action_type
        target = action.target_service

        if kind in self.NO_EFFECT_ACTIONS:
            return

        # Actions that act on episode-level state rather than one service's metrics.
        if kind == "check_deploy_history":
            state.setdefault("deploy_history_checked", {})[target] = True
            return
        if kind == "acknowledge_alert":
            for alert in state.get("active_alerts", []):
                if alert.service == target:
                    alert.is_acknowledged = True
            return
        if kind == "declare_incident_resolved":
            # Terminal action: the flag is read elsewhere (presumably for scoring).
            state["incident_declared_resolved"] = True
            return
        if kind == "request_human_escalation":
            # NOTE(review): the original comment says an escalation penalty is
            # applied while the episode continues; that is not implemented here.
            state["escalation_requested"] = True
            return

        sm = self._find_service_metrics(state, target)
        if sm is None:
            return
        params = action.parameters or {}

        if kind == "restart_service":
            sm.status = "recovering"
            sm.error_rate = max(0.0, sm.error_rate - 0.5)
        elif kind == "scale_up_replicas":
            # More replicas spread the load: lower per-replica CPU and tail latency.
            sm.replica_count = max(1, sm.replica_count + params.get("replica_count", 1))
            sm.cpu_percent = max(10, sm.cpu_percent * 0.7)
            sm.latency_p99_ms *= 0.8
        elif kind == "rollback_deployment":
            sm.error_rate = max(0.0, sm.error_rate - 0.4)
            sm.status = "recovering"
        elif kind == "toggle_feature_flag":
            # The specific feature_name parameter is not modelled; any toggle helps.
            sm.error_rate = max(0.0, sm.error_rate - 0.3)
        elif kind == "flush_cache":
            # A flush causes a temporary cache-miss latency spike.
            sm.status = "degraded"
            sm.latency_p99_ms *= 2
            state.setdefault("cache_flushed_services", set()).add(target)
        elif kind == "increase_conn_pool":
            # The pool_size parameter is not modelled; utilization drops by a fixed 0.2.
            # Only a *missing* reading falls back to 0.5 (a real 0.0 is kept).
            current = sm.db_conn_pool_utilization
            if current is None:
                current = 0.5
            sm.db_conn_pool_utilization = max(0.0, current - 0.2)
        elif kind == "drain_traffic":
            # The load balancer itself can never be drained.
            if sm.service != "haproxy-lb":
                sm.status = "drained"
                sm.rps = 0
        elif kind == "restore_traffic":
            sm.status = "healthy" if sm.error_rate < 0.1 else "degraded"
            sm.rps = 100  # Restored to baseline
        elif kind == "trigger_failover":
            # Brief error spike while the failover happens.
            sm.status = "recovering"
            sm.error_rate = 0.2

    def _propagate_failures(self, state: Dict[str, Any]) -> None:
        """Add 0.3 error rate (capped at 1.0) to each service with a failing direct dependency.

        A dependency counts as failing when its status is "critical" or "down".
        Dependencies come from ``ENTERPRISE_TOPOLOGY``; a service missing from
        the topology is treated as having none. The penalty is applied on every
        tick, so it compounds while the dependency stays failing.
        """
        metrics = state["service_metrics"]
        failing = {sm.service for sm in metrics if sm.status in ("critical", "down")}
        if not failing:
            return
        for sm in metrics:
            deps = ENTERPRISE_TOPOLOGY.get(sm.service, {}).get("depends_on", ())
            if any(dep in failing for dep in deps):
                sm.error_rate = min(1.0, sm.error_rate + 0.3)

    def _evolve_metrics(self, state: Dict[str, Any]) -> None:
        """Add seeded random noise to each service's error rate and p99 latency.

        The error rate stays clamped to [0, 1] and latency to >= 0.
        """
        for sm in state["service_metrics"]:
            noisy_error = sm.error_rate + self.rng.uniform(-0.01, 0.01)
            sm.error_rate = max(0.0, min(1.0, noisy_error))
            sm.latency_p99_ms = max(0.0, sm.latency_p99_ms + self.rng.uniform(-5, 5))

    def _generate_alerts_and_logs(self, state: Dict[str, Any]) -> None:
        """Placeholder: alert and log generation is not simulated yet."""
        pass

    def _recalculate_blast_radius(self, state: Dict[str, Any]) -> float:
        """Store and return the fraction of services whose status is not "healthy".

        Returns 0.0 when there are no service metrics.
        """
        metrics = state["service_metrics"]
        if not metrics:
            radius = 0.0
        else:
            unhealthy = sum(1 for sm in metrics if sm.status != "healthy")
            radius = min(1.0, unhealthy / len(metrics))
        state["blast_radius"] = radius
        return radius

    def _build_observation(
        self,
        state: Dict[str, Any],
        action: Optional[IncidentAction] = None,
        action_allowed: bool = True,
    ) -> IncidentObservation:
        """Build an ``IncidentObservation`` from ``state``.

        Args:
            state: Current simulation state. Must contain ``episode_id``,
                ``step``, ``elapsed_seconds``, ``max_steps``, ``blast_radius``,
                ``error_budget_burn_rate``, ``active_alerts``,
                ``service_metrics`` and ``dependency_graph``.
            action: The action submitted this step, if any.
            action_allowed: Whether the safety guard allowed ``action``.
                Ignored when ``action`` is None.
        """
        blast_radius = state["blast_radius"]
        actions_taken = state.get("actions_taken", 0)

        if action is None:
            result, success = None, None
        elif action_allowed:
            result, success = "Action applied", True
        else:
            result, success = "Action blocked by safety guard", False

        return IncidentObservation(
            episode_id=state["episode_id"],
            # Use the previously hard-coded id when the state carries no task_id.
            task_id=state.get("task_id", "task1_oom_crash"),
            step=state["step"],
            elapsed_seconds=state["elapsed_seconds"],
            max_steps=state["max_steps"],
            blast_radius=blast_radius,
            error_budget_burn_rate=state["error_budget_burn_rate"],
            estimated_revenue_loss_usd=blast_radius * 100000,  # Simplified
            active_user_sessions_impacted=int(blast_radius * 10000),
            active_alerts=state["active_alerts"],
            new_alerts_this_step=[],
            service_metrics=state["service_metrics"],
            recent_logs=[],
            logs_this_step=[],
            dependency_graph=state["dependency_graph"],
            last_action_result=result,
            last_action_success=success,
            actions_taken=actions_taken,
            available_actions=["inspect_logs", "restart_service"],  # Simplified
            safety_violations_this_episode=state.get("safety_violations", 0),
            # Clamp so the reported budget never goes negative.
            rate_limit_remaining=max(0, self.RATE_LIMIT - actions_taken),
        )