File size: 6,466 Bytes
690ea3c
 
 
 
 
 
 
2f7b539
690ea3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6257920
690ea3c
 
6257920
 
 
690ea3c
6257920
226b5f7
6257920
 
 
 
 
 
 
 
 
 
690ea3c
6257920
 
 
 
 
 
 
7e33b86
 
690ea3c
 
226b5f7
 
 
690ea3c
 
 
 
226b5f7
6257920
226b5f7
 
6257920
 
690ea3c
 
226b5f7
 
 
 
 
 
 
690ea3c
226b5f7
690ea3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e33b86
 
690ea3c
7e33b86
 
3013549
 
 
3135026
3013549
7e33b86
 
3013549
 
 
3135026
3013549
7e33b86
690ea3c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import random
import time
import copy
from typing import List, Dict, Any, Optional
from .models import (
    Observation, DashboardMetrics, StreamSample, EventSnippet, 
    SQLModel, RootCauseEnum, Action, ReadDashboardAction,
    SampleStreamAction, InspectLineageAction, SimulateConfigChangeAction
)

class CausalStreamEngine:
    def __init__(self, seed: int = 42):
        self.seed = seed
        self.random = random.Random(seed)
        self.current_tick = 0
        self.max_ticks = 100
        self.aggregation_window = 300  # 5 minutes
        self.events_buffer: List[EventSnippet] = []
        self.sql_models: Dict[str, SQLModel] = {
            "aggregator": SQLModel(
                model_id="aggregator",
                sql="SELECT SUM(value) FROM events WHERE arrival_time - event_time < 300",
                dependencies=["raw_stream"]
            )
        }
        self.active_incident: Optional[RootCauseEnum] = None
        self._initialize_buffer()

    def _initialize_buffer(self):
        """Pre-fills the buffer with stochastic base events."""
        for i in range(100):
            self._generate_event()

    def _generate_event(self):
        """Generates a single event with stochastic latency and incident evidence."""
        base_latency = 0.5
        jitter = self.random.uniform(0, 1.0)
        status = "success"
        evidence_tokens = []
        is_phantom = self.random.random() < 0.05
        
        # Incident logic
        if self.active_incident == RootCauseEnum.LATENCY_SPIKE and not is_phantom:
            base_latency += 5.0
            if self.random.random() < 0.2:
                evidence_tokens.append("STRIPE_WEBHOOK_DELAY")
                evidence_tokens.append("P99_LATENCY_3000MS")
        
        elif self.active_incident == RootCauseEnum.JOIN_FAILURE and not is_phantom:
            if self.random.random() < 0.3:
                status = "error"
                evidence_tokens.append("NULL_KEY_ERR")
                evidence_tokens.append("JOIN_MISMATCH_404")
        
        elif self.active_incident == RootCauseEnum.OUT_OF_ORDER and not is_phantom:
            jitter += 350.0  # Force it past the 300s window
            if self.random.random() < 0.2:
                evidence_tokens.append("ARRIVAL_GT_EVENT_TIME")
                evidence_tokens.append("WINDOW_TIMEOUT")

        # Use a true deterministic clock
        base_epoch = 1700000000.0
        event_time = base_epoch + self.current_tick - (100 - len(self.events_buffer))
        arrival_time = event_time + base_latency + jitter
        
        actual_latency_ms = (arrival_time - event_time) * 1000.0
        sla_ms = 1000.0
        
        event = EventSnippet(
            event_id=f"evt_{self.random.getrandbits(32)}",
            event_time=event_time,
            arrival_time=arrival_time,
            provider="Stripe-Sim" if not is_phantom else "corrupted_buffer",
            status=status,
            sla_p99_latency_ms=sla_ms,
            actual_p99_latency_ms=actual_latency_ms,
            sla_breach=actual_latency_ms > sla_ms,
            evidence_tokens=evidence_tokens
        )
        self.events_buffer.append(event)
        
        # Duplicate occasionally
        if self.random.random() < 0.02:
            dup_event = copy.deepcopy(event)
            dup_event.arrival_time += self.random.uniform(0, 0.5)
            self.events_buffer.append(dup_event)
            
        if len(self.events_buffer) > 500:
            self.events_buffer = self.events_buffer[-500:]

    def tick(self, count: int = 1):
        """Increments the world clock and updates the stream."""
        for _ in range(count):
            self.current_tick += 1
            self._generate_event()

    def get_observation(self) -> Observation:
        # Calculate mock metrics based on buffer and active incident
        revenue = 1000.0
        if self.active_incident:
            revenue *= 0.9  # 10% drop
            
        metrics = DashboardMetrics(
            revenue=revenue,
            error_rate=0.02,
            avg_latency=1.2,
            active_users=5000
        )
        
        return Observation(
            current_tick=self.current_tick,
            dashboard=metrics,
            alert_feed=["Critical: 10% Revenue Drop Detected" if self.active_incident else "System Nominal"]
        )

    def step(self, action: Action) -> Observation:
        """Executes an action and returns the new observation."""
        obs = self.get_observation()
        
        if action.type == "read_dashboard":
            pass
        elif action.type == "sample_stream":
            self.tick(1)
            obs = self.get_observation()
            obs.last_sample = StreamSample(
                events=self.events_buffer[-action.sample_size:],
                tick=self.current_tick
            )
        elif action.type == "inspect_lineage":
            self.tick(1)
            obs = self.get_observation()
            if action.model_id in self.sql_models:
                obs.inspected_lineage = self.sql_models[action.model_id]
        elif action.type == "simulate_config_change":
            self.tick(2) # Simulations are expensive
            obs = self.get_observation()
            obs.alert_feed.append(f"Simulation: Changed {action.config_param} to {action.value}. Revenue would be {obs.dashboard.revenue * 1.05:.2f}")
        elif action.type == "query_system_logs":
            self.tick(1)
            obs = self.get_observation()
            if self.active_incident == RootCauseEnum.EXPECTED_MAINTENANCE:
                obs.alert_feed.append(f"System Logs [{action.log_name}]: MAINT_WINDOW_0800_1000 matched. SYSTEM_EVENTS_METADATA trace confirmed.")
            else:
                obs.alert_feed.append(f"System Logs [{action.log_name}]: Normal. No maintenance logs found.")
        elif action.type == "query_provider_contract":
            self.tick(1)
            obs = self.get_observation()
            if self.active_incident == RootCauseEnum.LATENCY_SPIKE:
                obs.alert_feed.append(f"Contract Check [{action.provider_id}]: SLA_BREACH_STRIPE true, STRIPE_WEBHOOK_DELAY verified, P99_LATENCY_3000MS exceeded.")
            else:
                obs.alert_feed.append(f"Contract Check [{action.provider_id}]: SLA met. P99 is nominal.")
            
        return obs

    def set_incident(self, incident: RootCauseEnum):
        self.active_incident = incident