Spaces:
Sleeping
Sleeping
| import random | |
| import time | |
| import copy | |
| from typing import List, Dict, Any, Optional | |
| from .models import ( | |
| Observation, DashboardMetrics, StreamSample, EventSnippet, | |
| SQLModel, RootCauseEnum, Action, ReadDashboardAction, | |
| SampleStreamAction, InspectLineageAction, SimulateConfigChangeAction | |
| ) | |
class CausalStreamEngine:
    """Seeded simulator of an event stream used for root-cause investigation.

    The engine keeps a rolling buffer of ``EventSnippet`` objects, advances a
    tick-based clock, and exposes ``step`` so a caller can spend ticks on
    investigative actions. An incident installed with ``set_incident`` changes
    the events generated afterwards as well as the dashboard and alert feed.

    All randomness is drawn from a private ``random.Random(seed)``, so two
    engines created with the same seed and driven by the same calls draw the
    same random sequence.
    """

    # Origin of the simulated clock, in epoch seconds.
    BASE_EPOCH = 1700000000.0
    # Number of events generated by the constructor before any tick happens.
    PREFILL_EVENTS = 100
    # The buffer is trimmed to its most recent MAX_BUFFER_EVENTS entries.
    MAX_BUFFER_EVENTS = 500
    # Latency SLA stamped on every generated event, in milliseconds.
    SLA_P99_LATENCY_MS = 1000.0

    def __init__(self, seed: int = 42):
        """Create the engine and pre-fill its event buffer.

        Args:
            seed: Seed for the engine's private random number generator.
        """
        self.seed = seed
        self.random = random.Random(seed)
        self.current_tick = 0
        self.max_ticks = 100
        self.aggregation_window = 300  # seconds (5 minutes)
        self.events_buffer: List[EventSnippet] = []
        self.sql_models: Dict[str, SQLModel] = {
            "aggregator": SQLModel(
                model_id="aggregator",
                sql="SELECT SUM(value) FROM events WHERE arrival_time - event_time < 300",
                dependencies=["raw_stream"],
            )
        }
        self.active_incident: Optional[RootCauseEnum] = None
        self._initialize_buffer()

    def _initialize_buffer(self) -> None:
        """Pre-fill the buffer with stochastic baseline events."""
        for _ in range(self.PREFILL_EVENTS):
            self._generate_event()

    @staticmethod
    def _latest_events(events: List[Any], count: int) -> List[Any]:
        """Return a new list holding the last ``count`` items of ``events``.

        A plain ``events[-count:]`` returns the *entire* list when ``count``
        is 0 (because ``-0 == 0``) and an arbitrary suffix when ``count`` is
        negative; both cases yield an empty list here instead.
        """
        if count <= 0:
            return []
        return events[-count:]

    def _next_event_time(self) -> float:
        """Return the event timestamp for the event about to be generated.

        While the buffer holds fewer than ``PREFILL_EVENTS`` entries the
        timestamp is back-dated by the number of missing entries, so the
        pre-filled history precedes tick 0. The back-dating is clamped at
        zero: once the buffer is full the timestamp is exactly
        ``BASE_EPOCH + current_tick`` and no longer depends on buffer length.
        """
        backfill = max(0, self.PREFILL_EVENTS - len(self.events_buffer))
        return self.BASE_EPOCH + self.current_tick - backfill

    def _generate_event(self) -> None:
        """Generate one event (occasionally plus a duplicate) into the buffer.

        The order of the random draws below is significant: it determines the
        sequence produced for a given seed, so do not reorder them.
        """
        rng = self.random
        latency = 0.5
        jitter = rng.uniform(0, 1.0)
        status = "success"
        evidence_tokens: List[str] = []
        is_phantom = rng.random() < 0.05

        # Incident effects never apply to phantom events.
        if not is_phantom:
            incident = self.active_incident
            if incident == RootCauseEnum.LATENCY_SPIKE:
                latency += 5.0
                if rng.random() < 0.2:
                    evidence_tokens.append("STRIPE_WEBHOOK_DELAY")
                    evidence_tokens.append("P99_LATENCY_3000MS")
            elif incident == RootCauseEnum.JOIN_FAILURE:
                if rng.random() < 0.3:
                    status = "error"
                    evidence_tokens.append("NULL_KEY_ERR")
                    evidence_tokens.append("JOIN_MISMATCH_404")
            elif incident == RootCauseEnum.OUT_OF_ORDER:
                # Push the arrival past the aggregation window
                # (350.0 s with the default 300 s window).
                jitter += self.aggregation_window + 50.0
                if rng.random() < 0.2:
                    evidence_tokens.append("ARRIVAL_GT_EVENT_TIME")
                    evidence_tokens.append("WINDOW_TIMEOUT")

        event_time = self._next_event_time()
        arrival_time = event_time + latency + jitter
        actual_latency_ms = (arrival_time - event_time) * 1000.0
        sla_ms = self.SLA_P99_LATENCY_MS

        event = EventSnippet(
            event_id=f"evt_{rng.getrandbits(32)}",
            event_time=event_time,
            arrival_time=arrival_time,
            provider="corrupted_buffer" if is_phantom else "Stripe-Sim",
            status=status,
            sla_p99_latency_ms=sla_ms,
            actual_p99_latency_ms=actual_latency_ms,
            sla_breach=actual_latency_ms > sla_ms,
            evidence_tokens=evidence_tokens,
        )
        self.events_buffer.append(event)

        # Occasionally re-deliver the same event slightly later.
        if rng.random() < 0.02:
            duplicate = copy.deepcopy(event)
            duplicate.arrival_time += rng.uniform(0, 0.5)
            self.events_buffer.append(duplicate)

        # Keep only the most recent entries.
        if len(self.events_buffer) > self.MAX_BUFFER_EVENTS:
            self.events_buffer = self.events_buffer[-self.MAX_BUFFER_EVENTS:]

    def tick(self, count: int = 1) -> None:
        """Advance the clock by ``count`` ticks, generating one event per tick."""
        for _ in range(count):
            self.current_tick += 1
            self._generate_event()

    def get_observation(self) -> "Observation":
        """Build a fresh observation of the dashboard and alert feed.

        Revenue is 1000.0, reduced by 10% while an incident is active; the
        other dashboard figures are fixed. The alert feed holds one message
        reflecting whether an incident is active.
        """
        incident_active = bool(self.active_incident)

        revenue = 1000.0
        if incident_active:
            revenue *= 0.9  # 10% drop

        if incident_active:
            headline = "Critical: 10% Revenue Drop Detected"
        else:
            headline = "System Nominal"

        dashboard = DashboardMetrics(
            revenue=revenue,
            error_rate=0.02,
            avg_latency=1.2,
            active_users=5000,
        )
        return Observation(
            current_tick=self.current_tick,
            dashboard=dashboard,
            alert_feed=[headline],
        )

    def step(self, action: "Action") -> "Observation":
        """Execute ``action`` and return the resulting observation.

        ``read_dashboard`` (and any unrecognised action type) costs no ticks.
        ``simulate_config_change`` costs two ticks; every other recognised
        action costs one. Ticks are spent before the observation is built.

        Args:
            action: Action whose ``type`` selects the branch; the remaining
                attributes read depend on that type.

        Returns:
            The observation after the action, with action-specific results
            attached (``last_sample``, ``inspected_lineage``) or appended to
            ``alert_feed``.
        """
        kind = action.type

        if kind == "sample_stream":
            self.tick(1)
            obs = self.get_observation()
            obs.last_sample = StreamSample(
                # Clamped tail: a sample size of 0 yields no events rather
                # than the whole buffer.
                events=self._latest_events(self.events_buffer, action.sample_size),
                tick=self.current_tick,
            )
        elif kind == "inspect_lineage":
            self.tick(1)
            obs = self.get_observation()
            if action.model_id in self.sql_models:
                obs.inspected_lineage = self.sql_models[action.model_id]
        elif kind == "simulate_config_change":
            self.tick(2)  # Simulations are expensive
            obs = self.get_observation()
            obs.alert_feed.append(
                f"Simulation: Changed {action.config_param} to {action.value}. "
                f"Revenue would be {obs.dashboard.revenue * 1.05:.2f}"
            )
        elif kind == "query_system_logs":
            self.tick(1)
            obs = self.get_observation()
            if self.active_incident == RootCauseEnum.EXPECTED_MAINTENANCE:
                obs.alert_feed.append(
                    f"System Logs [{action.log_name}]: MAINT_WINDOW_0800_1000 matched. "
                    "SYSTEM_EVENTS_METADATA trace confirmed."
                )
            else:
                obs.alert_feed.append(
                    f"System Logs [{action.log_name}]: Normal. No maintenance logs found."
                )
        elif kind == "query_provider_contract":
            self.tick(1)
            obs = self.get_observation()
            if self.active_incident == RootCauseEnum.LATENCY_SPIKE:
                obs.alert_feed.append(
                    f"Contract Check [{action.provider_id}]: SLA_BREACH_STRIPE true, "
                    "STRIPE_WEBHOOK_DELAY verified, P99_LATENCY_3000MS exceeded."
                )
            else:
                obs.alert_feed.append(
                    f"Contract Check [{action.provider_id}]: SLA met. P99 is nominal."
                )
        else:
            # "read_dashboard" and unrecognised action types: no tick spent.
            obs = self.get_observation()

        return obs

    def set_incident(self, incident: Optional["RootCauseEnum"]) -> None:
        """Set the active incident; it affects events generated from now on.

        Passing ``None`` clears the incident.
        """
        self.active_incident = incident