skyruh's picture
fix: total hardening of scoring range and evidence discovery logic with stateful episode rewards
3135026
import random
import time
import copy
from typing import List, Dict, Any, Optional
from .models import (
Observation, DashboardMetrics, StreamSample, EventSnippet,
SQLModel, RootCauseEnum, Action, ReadDashboardAction,
SampleStreamAction, InspectLineageAction, SimulateConfigChangeAction
)
class CausalStreamEngine:
def __init__(self, seed: int = 42):
self.seed = seed
self.random = random.Random(seed)
self.current_tick = 0
self.max_ticks = 100
self.aggregation_window = 300 # 5 minutes
self.events_buffer: List[EventSnippet] = []
self.sql_models: Dict[str, SQLModel] = {
"aggregator": SQLModel(
model_id="aggregator",
sql="SELECT SUM(value) FROM events WHERE arrival_time - event_time < 300",
dependencies=["raw_stream"]
)
}
self.active_incident: Optional[RootCauseEnum] = None
self._initialize_buffer()
def _initialize_buffer(self):
"""Pre-fills the buffer with stochastic base events."""
for i in range(100):
self._generate_event()
def _generate_event(self):
"""Generates a single event with stochastic latency and incident evidence."""
base_latency = 0.5
jitter = self.random.uniform(0, 1.0)
status = "success"
evidence_tokens = []
is_phantom = self.random.random() < 0.05
# Incident logic
if self.active_incident == RootCauseEnum.LATENCY_SPIKE and not is_phantom:
base_latency += 5.0
if self.random.random() < 0.2:
evidence_tokens.append("STRIPE_WEBHOOK_DELAY")
evidence_tokens.append("P99_LATENCY_3000MS")
elif self.active_incident == RootCauseEnum.JOIN_FAILURE and not is_phantom:
if self.random.random() < 0.3:
status = "error"
evidence_tokens.append("NULL_KEY_ERR")
evidence_tokens.append("JOIN_MISMATCH_404")
elif self.active_incident == RootCauseEnum.OUT_OF_ORDER and not is_phantom:
jitter += 350.0 # Force it past the 300s window
if self.random.random() < 0.2:
evidence_tokens.append("ARRIVAL_GT_EVENT_TIME")
evidence_tokens.append("WINDOW_TIMEOUT")
# Use a true deterministic clock
base_epoch = 1700000000.0
event_time = base_epoch + self.current_tick - (100 - len(self.events_buffer))
arrival_time = event_time + base_latency + jitter
actual_latency_ms = (arrival_time - event_time) * 1000.0
sla_ms = 1000.0
event = EventSnippet(
event_id=f"evt_{self.random.getrandbits(32)}",
event_time=event_time,
arrival_time=arrival_time,
provider="Stripe-Sim" if not is_phantom else "corrupted_buffer",
status=status,
sla_p99_latency_ms=sla_ms,
actual_p99_latency_ms=actual_latency_ms,
sla_breach=actual_latency_ms > sla_ms,
evidence_tokens=evidence_tokens
)
self.events_buffer.append(event)
# Duplicate occasionally
if self.random.random() < 0.02:
dup_event = copy.deepcopy(event)
dup_event.arrival_time += self.random.uniform(0, 0.5)
self.events_buffer.append(dup_event)
if len(self.events_buffer) > 500:
self.events_buffer = self.events_buffer[-500:]
def tick(self, count: int = 1):
"""Increments the world clock and updates the stream."""
for _ in range(count):
self.current_tick += 1
self._generate_event()
def get_observation(self) -> Observation:
# Calculate mock metrics based on buffer and active incident
revenue = 1000.0
if self.active_incident:
revenue *= 0.9 # 10% drop
metrics = DashboardMetrics(
revenue=revenue,
error_rate=0.02,
avg_latency=1.2,
active_users=5000
)
return Observation(
current_tick=self.current_tick,
dashboard=metrics,
alert_feed=["Critical: 10% Revenue Drop Detected" if self.active_incident else "System Nominal"]
)
def step(self, action: Action) -> Observation:
"""Executes an action and returns the new observation."""
obs = self.get_observation()
if action.type == "read_dashboard":
pass
elif action.type == "sample_stream":
self.tick(1)
obs = self.get_observation()
obs.last_sample = StreamSample(
events=self.events_buffer[-action.sample_size:],
tick=self.current_tick
)
elif action.type == "inspect_lineage":
self.tick(1)
obs = self.get_observation()
if action.model_id in self.sql_models:
obs.inspected_lineage = self.sql_models[action.model_id]
elif action.type == "simulate_config_change":
self.tick(2) # Simulations are expensive
obs = self.get_observation()
obs.alert_feed.append(f"Simulation: Changed {action.config_param} to {action.value}. Revenue would be {obs.dashboard.revenue * 1.05:.2f}")
elif action.type == "query_system_logs":
self.tick(1)
obs = self.get_observation()
if self.active_incident == RootCauseEnum.EXPECTED_MAINTENANCE:
obs.alert_feed.append(f"System Logs [{action.log_name}]: MAINT_WINDOW_0800_1000 matched. SYSTEM_EVENTS_METADATA trace confirmed.")
else:
obs.alert_feed.append(f"System Logs [{action.log_name}]: Normal. No maintenance logs found.")
elif action.type == "query_provider_contract":
self.tick(1)
obs = self.get_observation()
if self.active_incident == RootCauseEnum.LATENCY_SPIKE:
obs.alert_feed.append(f"Contract Check [{action.provider_id}]: SLA_BREACH_STRIPE true, STRIPE_WEBHOOK_DELAY verified, P99_LATENCY_3000MS exceeded.")
else:
obs.alert_feed.append(f"Contract Check [{action.provider_id}]: SLA met. P99 is nominal.")
return obs
def set_incident(self, incident: RootCauseEnum):
self.active_incident = incident