"""Step-based state engine that advances an incident-simulation scenario."""

import random
from typing import Optional, Dict, Any, List

from server.models.action import IncidentAction
from server.models.observation import (
    IncidentObservation,
    ServiceMetrics,
    DependencyEdge,
    Alert,
    LogEntry,
)
from server.simulation.scenarios.base_scenario import BaseScenario
from server.simulation.topology import ENTERPRISE_TOPOLOGY
from server.safety.guard import SafetyGuard, ValidationResult


class StateEngine:
    """Advance a scenario's world state one step per call to :meth:`tick`.

    Each tick optionally validates and applies an agent action, then runs
    failure propagation, metric noise, alert/log generation and blast-radius
    recalculation, and finally builds an :class:`IncidentObservation`.
    """

    # Simulated time that elapses per tick.
    SECONDS_PER_STEP = 60.0
    # Action budget used to compute ``rate_limit_remaining`` in observations.
    RATE_LIMIT = 30
    # Diagnostic / communication actions that have no direct effect on state.
    _NO_OP_ACTIONS = frozenset(
        {
            "inspect_logs",
            "pull_metrics",
            "trace_dependency_chain",
            "query_runbook",
            "check_alert_history",
            "analyze_error_budget",
            "publish_status_page",
        }
    )

    def __init__(self, scenario: BaseScenario, seed: int = 42):
        """Create an engine for ``scenario`` with a seeded RNG for metric noise."""
        self.rng = random.Random(seed)
        self.scenario = scenario
        self.step_count = 0
        self.simulated_time = 0.0
        self.action_history: List[Dict] = []
        self.safety_guard = SafetyGuard()
        self.current_state: Dict[str, Any] = {}
        # Safety-guard verdict for the action passed to the latest tick();
        # None when that tick had no action.
        self._last_action_allowed: Optional[bool] = None

    def tick(self, action: Optional[IncidentAction] = None) -> IncidentObservation:
        """Advance the simulation by one step and return the resulting observation.

        Args:
            action: Optional agent action. It is checked by the safety guard;
                allowed actions mutate the state and increment
                ``actions_taken``, blocked actions only increment
                ``safety_violations``. Both outcomes are audit-logged.

        Returns:
            The observation built from the state after this step.
        """
        if self.step_count == 0:
            # Materialize the scenario's starting state on the first tick.
            self.current_state = self.scenario.get_initial_state()

        state = self.current_state
        state["step"] = self.step_count
        state["elapsed_seconds"] = self.simulated_time

        self._last_action_allowed = None
        if action is not None:
            validation = self.safety_guard.validate(action, state)
            self._last_action_allowed = bool(validation.allowed)
            if validation.allowed:
                self._apply_action_effects(action, state)
                state["actions_taken"] = state.get("actions_taken", 0) + 1
                self.safety_guard.audit_log_entry(action, validation, state)
            else:
                # Log the blocked action before counting the violation.
                self.safety_guard.audit_log_entry(action, validation, state)
                state["safety_violations"] = state.get("safety_violations", 0) + 1

        self._propagate_failures(state)
        self._evolve_metrics(state)
        self._generate_alerts_and_logs(state)
        self._recalculate_blast_radius(state)

        self.step_count += 1
        self.simulated_time += self.SECONDS_PER_STEP
        return self._build_observation(state, action)

    def _get_current_state(self) -> Dict[str, Any]:
        """Return the persisted state, or a fresh (unstored) initial state if none exists yet."""
        return self.current_state if self.current_state else self.scenario.get_initial_state()

    @staticmethod
    def _find_service(state: Dict[str, Any], service: Optional[str]) -> Any:
        """Return the first ``state["service_metrics"]`` entry for ``service``, or None."""
        for sm in state["service_metrics"]:
            if sm.service == service:
                return sm
        return None

    def _apply_action_effects(self, action: IncidentAction, state: Dict[str, Any]) -> None:
        """Apply action effects to the state.

        Maps each ``action_type`` to in-place mutations of ``state``, mostly
        of the first ``service_metrics`` entry whose ``service`` equals
        ``action.target_service``. Diagnostic/communication actions and
        unrecognized action types leave the state untouched.
        """
        kind = action.action_type
        target = action.target_service

        if kind in self._NO_OP_ACTIONS:
            # Nothing to mutate; tick() still audit-logs the action.
            return

        if kind == "restart_service":
            sm = self._find_service(state, target)
            if sm is not None:
                sm.status = "recovering"
                sm.error_rate = max(0, sm.error_rate - 0.5)

        elif kind == "check_deploy_history":
            # Remember which services' deployment history has been inspected.
            state.setdefault("deploy_history_checked", {})[target] = True

        elif kind == "acknowledge_alert":
            # Acknowledge every active alert raised by the target service.
            for alert in state.get("active_alerts", []):
                if alert.service == target:
                    alert.is_acknowledged = True

        elif kind == "scale_up_replicas":
            # More replicas -> lower per-replica CPU and tail latency.
            replica_delta = action.parameters.get("replica_count", 1)
            sm = self._find_service(state, target)
            if sm is not None:
                sm.replica_count = max(1, sm.replica_count + replica_delta)
                sm.cpu_percent = max(10, sm.cpu_percent * 0.7)
                sm.latency_p99_ms *= 0.8

        elif kind == "rollback_deployment":
            # Reverting to the previous version reduces the error rate.
            sm = self._find_service(state, target)
            if sm is not None:
                sm.error_rate = max(0, sm.error_rate - 0.4)
                sm.status = "recovering"

        elif kind == "toggle_feature_flag":
            # Disabling a problematic feature reduces errors. The
            # ``feature_name`` parameter is not modelled separately.
            sm = self._find_service(state, target)
            if sm is not None:
                sm.error_rate = max(0, sm.error_rate - 0.3)

        elif kind == "flush_cache":
            # A cache flush causes a temporary miss spike before recovery.
            sm = self._find_service(state, target)
            if sm is not None:
                sm.status = "degraded"
                sm.latency_p99_ms *= 2
                state.setdefault("cache_flushed_services", set()).add(target)

        elif kind == "increase_conn_pool":
            # Lower pool utilization by 0.2; the ``pool_size`` parameter is
            # not modelled. An unknown (None) utilization is assumed to be
            # 0.5, but a genuine 0.0 is kept as 0.0.
            sm = self._find_service(state, target)
            if sm is not None:
                current = sm.db_conn_pool_utilization
                if current is None:
                    current = 0.5
                sm.db_conn_pool_utilization = max(0, current - 0.2)

        elif kind == "drain_traffic":
            # Remove the service from rotation; the load balancer itself
            # ("haproxy-lb") can never be drained.
            sm = self._find_service(state, target)
            if sm is not None and sm.service != "haproxy-lb":
                sm.status = "drained"
                sm.rps = 0

        elif kind == "restore_traffic":
            sm = self._find_service(state, target)
            if sm is not None:
                sm.status = "healthy" if sm.error_rate < 0.1 else "degraded"
                sm.rps = 100  # Restored to baseline

        elif kind == "trigger_failover":
            # Failover: brief unavailability modelled as a fixed error spike.
            sm = self._find_service(state, target)
            if sm is not None:
                sm.status = "recovering"
                sm.error_rate = 0.2

        elif kind == "declare_incident_resolved":
            # Terminal action - marks the episode for scoring.
            state["incident_declared_resolved"] = True

        elif kind == "request_human_escalation":
            # Escalation flag; the episode itself continues.
            state["escalation_requested"] = True

    def _propagate_failures(self, state: Dict[str, Any]) -> None:
        """Simplified cascade: penalize services that depend on a failing service.

        Every service with at least one direct ``depends_on`` entry whose
        status is ``"critical"`` or ``"down"`` gets +0.3 error rate, capped
        at 1.0. Services absent from ``ENTERPRISE_TOPOLOGY`` are treated as
        having no dependencies instead of raising ``KeyError``.
        """
        failing = {
            sm.service
            for sm in state["service_metrics"]
            if sm.status in ("critical", "down")
        }
        if not failing:
            return
        for sm in state["service_metrics"]:
            try:
                deps = ENTERPRISE_TOPOLOGY[sm.service]["depends_on"]
            except KeyError:
                continue
            if any(dep in failing for dep in deps):
                sm.error_rate = min(1.0, sm.error_rate + 0.3)

    def _evolve_metrics(self, state: Dict[str, Any]) -> None:
        """Add seeded random noise to each service's error rate and p99 latency.

        Error rate stays within [0, 1]; latency is floored at 0 ms so noise
        can never produce a negative latency.
        """
        for sm in state["service_metrics"]:
            sm.error_rate = max(0, min(1, sm.error_rate + self.rng.uniform(-0.01, 0.01)))
            sm.latency_p99_ms = max(0.0, sm.latency_p99_ms + self.rng.uniform(-5, 5))

    def _generate_alerts_and_logs(self, state: Dict[str, Any]) -> None:
        """Placeholder for alert/log generation (currently a no-op)."""
        pass

    def _recalculate_blast_radius(self, state: Dict[str, Any]) -> float:
        """Store and return the fraction of services whose status is not ``"healthy"``.

        Returns 0.0 (and stores it) when there are no services.
        """
        metrics = state["service_metrics"]
        if not metrics:
            radius = 0.0
        else:
            unhealthy = sum(1 for sm in metrics if sm.status != "healthy")
            radius = min(1.0, unhealthy / len(metrics))
        state["blast_radius"] = radius
        return radius

    def _build_observation(
        self, state: Dict[str, Any], action: Optional[IncidentAction] = None
    ) -> IncidentObservation:
        """Build the agent-facing observation from ``state``.

        ``last_action_result``/``last_action_success`` reflect whether the
        most recent action passed the safety guard. ``rate_limit_remaining``
        never drops below zero.
        """
        actions_taken = state.get("actions_taken", 0)

        if action is None:
            last_result, last_success = None, None
        elif getattr(self, "_last_action_allowed", None) is False:
            last_result, last_success = "Action blocked by safety guard", False
        else:
            last_result, last_success = "Action applied", True

        blast_radius = state["blast_radius"]
        return IncidentObservation(
            episode_id=state["episode_id"],
            # Fall back to the previously hardcoded task when state lacks one.
            task_id=state.get("task_id", "task1_oom_crash"),
            step=state["step"],
            elapsed_seconds=state["elapsed_seconds"],
            max_steps=state["max_steps"],
            blast_radius=blast_radius,
            error_budget_burn_rate=state["error_budget_burn_rate"],
            estimated_revenue_loss_usd=blast_radius * 100000,  # Simplified
            active_user_sessions_impacted=int(blast_radius * 10000),
            active_alerts=state["active_alerts"],
            new_alerts_this_step=[],
            service_metrics=state["service_metrics"],
            recent_logs=[],
            logs_this_step=[],
            dependency_graph=state["dependency_graph"],
            last_action_result=last_result,
            last_action_success=last_success,
            actions_taken=actions_taken,
            available_actions=["inspect_logs", "restart_service"],  # Simplified
            safety_violations_this_episode=state.get("safety_violations", 0),
            rate_limit_remaining=max(0, self.RATE_LIMIT - actions_taken),
        )