"""
Utility Functions for the Adaptive Alert Triage Environment

Provides deterministic, seed-controlled helpers for:
- Alert generation (individual and correlated chains)
- Severity / noise / false-positive logic
- System-load calculation
- Alert-queue arrival modelling
- Action-correctness evaluation (used by graders)

Randomness comes from numpy's global generator and, for correlated-chain
selection, the stdlib ``random`` module. A single set_seed() call at episode
start seeds both and guarantees full reproducibility.
"""
import random
from typing import List, Dict, Tuple, Optional

import numpy as np

from adaptive_alert_triage.models import Alert, AlertType
# ---------------------------------------------------------------------------
# Alert-type configuration
# ---------------------------------------------------------------------------
# Each entry defines the baseline true severity and the false-positive rate
# for that alert class. These values were chosen to reflect realistic SOC
# distributions (SECURITY is rare but almost never a false positive;
# APPLICATION is the noisiest signal).
ALERT_TYPE_CONFIG: Dict[str, Dict[str, float]] = {
    "CPU": {"base_severity": 0.60, "false_positive_rate": 0.15},
    "MEMORY": {"base_severity": 0.70, "false_positive_rate": 0.20},
    "DISK": {"base_severity": 0.50, "false_positive_rate": 0.25},
    "NETWORK": {"base_severity": 0.65, "false_positive_rate": 0.10},
    "APPLICATION": {"base_severity": 0.75, "false_positive_rate": 0.30},
    "SECURITY": {"base_severity": 0.90, "false_positive_rate": 0.05},
}
# Cascade chains: each sub-list is a typical multi-alert failure sequence,
# listed in causal order (trigger first). The environment uses these when
# generating correlated alert groups.
CORRELATION_CHAINS: List[List[str]] = [
    ["CPU", "MEMORY", "APPLICATION"],
    ["NETWORK", "APPLICATION", "APPLICATION"],
    ["DISK", "MEMORY", "APPLICATION"],
    ["SECURITY", "NETWORK", "APPLICATION"],
    ["MEMORY", "CPU", "APPLICATION"],
]
# Thresholds used across the environment and graders
CRITICAL_SEVERITY_THRESHOLD: float = 0.75  # true_severity >= this → critical
CRITICAL_AGE_THRESHOLD: int = 5  # age >= this AND critical → failure
# ---------------------------------------------------------------------------
# Seed management
# ---------------------------------------------------------------------------
def set_seed(seed: int) -> None:
    """
    Seed both numpy's global generator and the stdlib ``random`` module.

    Must be called before any alert-generation function to guarantee
    reproducible episodes; both generators are seeded with the same value.

    Args:
        seed: Non-negative integer seed value.
    """
    random.seed(seed)
    np.random.seed(seed)
# ---------------------------------------------------------------------------
# ID helpers
# ---------------------------------------------------------------------------
def generate_alert_id(step: int, alert_index: int) -> str:
    """
    Build a deterministic, human-readable alert identifier.

    Format: ``alert_<step:04d>_<index:02d>``. The widths are minimums:
    larger values are not truncated, they simply widen the field.

    Args:
        step: Episode step at which the alert was generated.
        alert_index: Position of this alert within the batch generated
            at that step.

    Returns:
        Alert ID string, e.g. ``"alert_0007_02"``. Unique as long as the
        (step, alert_index) pair is unique.
    """
    return f"alert_{step:04d}_{alert_index:02d}"
# ---------------------------------------------------------------------------
# Alert-type sampling
# ---------------------------------------------------------------------------
def sample_alert_type() -> AlertType:
    """
    Sample a random alert type using fixed class weights.

    APPLICATION alerts are most common (25 %); SECURITY alerts are rarest
    (5 %). Consumes one draw from numpy's global generator.

    Returns:
        One of the six alert-type strings.
    """
    alert_types: List[str] = [
        "CPU", "MEMORY", "DISK", "NETWORK", "APPLICATION", "SECURITY",
    ]
    # Weights are positionally aligned with alert_types and sum to 1.0.
    weights: List[float] = [0.20, 0.20, 0.15, 0.15, 0.25, 0.05]
    # Sample an index rather than the string so the result stays a plain str
    # instead of a numpy string scalar.
    idx: int = int(np.random.choice(len(alert_types), p=weights))
    return alert_types[idx]  # type: ignore[return-value]
# ---------------------------------------------------------------------------
# Severity helpers
# ---------------------------------------------------------------------------
def calculate_true_severity(
    alert_type: AlertType,
    is_correlated: bool = False,
) -> float:
    """
    Sample ground-truth severity for a *non*-false-positive alert.

    Adds Gaussian noise (σ=0.10) around the type's baseline severity from
    ALERT_TYPE_CONFIG and clips to [0.0, 1.0]. Correlated alerts then receive
    a 1.3× boost (capped at 1.0) to model the increased risk of cascading
    failures.

    Args:
        alert_type: Category of the alert; must be a key of ALERT_TYPE_CONFIG.
        is_correlated: Whether the alert belongs to a correlated chain.

    Returns:
        True severity in [0.0, 1.0].

    Raises:
        KeyError: If ``alert_type`` is not present in ALERT_TYPE_CONFIG.
    """
    base: float = ALERT_TYPE_CONFIG[alert_type]["base_severity"]
    noise: float = float(np.random.normal(0.0, 0.10))
    severity: float = float(np.clip(base + noise, 0.0, 1.0))
    if is_correlated:
        # Boost is applied after the first clip, so it needs its own cap.
        severity = float(min(severity * 1.3, 1.0))
    return severity
def add_observation_noise(true_severity: float, confidence: float) -> float:
    """
    Simulate partial observability by adding confidence-weighted noise.

    Lower confidence → higher observation noise, making it harder for the
    agent to distinguish true positives from false alarms. The noise standard
    deviation is ``0.15 * (1 - confidence)``, so full confidence yields no
    noise at all.

    Args:
        true_severity: Ground-truth severity value.
        confidence: Sensor/detector confidence level, expected in [0.0, 1.0].
            Values outside that range are clamped to it.

    Returns:
        Noisy visible severity in [0.0, 1.0].
    """
    # Clamp first: a confidence above 1.0 would produce a negative standard
    # deviation, which np.random.normal rejects with ValueError.
    bounded_confidence: float = min(max(float(confidence), 0.0), 1.0)
    noise_std: float = 0.15 * (1.0 - bounded_confidence)
    noise: float = float(np.random.normal(0.0, noise_std))
    return float(np.clip(true_severity + noise, 0.0, 1.0))
# ---------------------------------------------------------------------------
# False-positive determination
# ---------------------------------------------------------------------------
def is_false_positive(alert_type: AlertType) -> bool:
    """
    Stochastically decide whether an alert is a false positive.

    Draws one uniform sample from numpy's global generator and compares it
    against the per-type false-positive rate from ALERT_TYPE_CONFIG.

    Args:
        alert_type: Category of the alert; must be a key of ALERT_TYPE_CONFIG.

    Returns:
        True if the alert should be treated as a false positive.

    Raises:
        KeyError: If ``alert_type`` is not present in ALERT_TYPE_CONFIG.
    """
    fp_rate: float = ALERT_TYPE_CONFIG[alert_type]["false_positive_rate"]
    return bool(np.random.random() < fp_rate)
# ---------------------------------------------------------------------------
# Single-alert generation
# ---------------------------------------------------------------------------
def generate_alert(
    step: int,
    alert_index: int,
    is_correlated: bool = False,
    force_critical: bool = False,
) -> Alert:
    """
    Generate a single synthetic alert with both visible and hidden attributes.

    Workflow:
        1. Sample alert type.
        2. Determine if false positive (unless force_critical=True).
        3. Set true_severity: low for FPs, high for forced-critical, otherwise
           sampled via calculate_true_severity().
        4. Sample confidence (type-dependent baseline + noise).
        5. Generate noisy visible_severity via add_observation_noise().
        6. With 2 % probability, apply the outlier override: an alert with
           true severity >= 0.8 is shown with a very low visible severity,
           and one with true severity <= 0.2 with a very high one.

    Args:
        step: Current episode step (used for ID generation).
        alert_index: Index within this step's batch.
        is_correlated: Mark the alert as part of a correlated failure chain.
        force_critical: Override FP logic and set severity in [0.8, 1.0].

    Returns:
        Fully populated Alert object. ``metadata["is_outlier"]`` is True only
        when the outlier override actually replaced the visible severity.
    """
    alert_id: str = generate_alert_id(step, alert_index)
    alert_type: AlertType = sample_alert_type()

    # False-positive logic. is_false_positive() is the left operand, so its
    # random draw is consumed even when force_critical overrides the result.
    is_fp: bool = is_false_positive(alert_type) and not force_critical

    # True severity
    if is_fp:
        true_severity = float(np.random.uniform(0.0, 0.30))
    elif force_critical:
        true_severity = float(np.random.uniform(0.80, 1.0))
    else:
        true_severity = calculate_true_severity(alert_type, is_correlated)

    # Confidence — inversely related to FP rate, with Gaussian jitter
    base_confidence: float = 1.0 - ALERT_TYPE_CONFIG[alert_type]["false_positive_rate"]
    confidence: float = float(
        np.clip(base_confidence + np.random.normal(0.0, 0.10), 0.0, 1.0)
    )

    # Observable severity (noisy)
    visible_severity: float = add_observation_noise(true_severity, confidence)

    # --- Extreme Outlier Logic (stochastic noise for score variance) ---
    # Adds a 2% chance of a "rogue" alert that contradicts its indicators,
    # ensuring that even perfect agents have some score variance < 1.0.
    # Mid-range severities (0.2 < s < 0.8) are never overridden, so the flag
    # is set only inside the branches that actually rewrite the value.
    is_outlier: bool = False
    if np.random.random() < 0.02:
        if true_severity >= 0.8:
            visible_severity = float(np.random.uniform(0.0, 0.2))  # "Hidden Critical"
            is_outlier = True
        elif true_severity <= 0.2:
            visible_severity = float(np.random.uniform(0.8, 1.0))  # "Phantom Critical"
            is_outlier = True

    return Alert(
        id=alert_id,
        visible_severity=visible_severity,
        confidence=confidence,
        alert_type=alert_type,
        age=0,
        true_severity=true_severity,
        is_correlated=is_correlated,
        metadata={
            "false_positive": is_fp,
            "generated_at_step": step,
            "is_outlier": is_outlier,  # mark for audit
        },
    )
# ---------------------------------------------------------------------------
# Correlated-alert chain generation
# ---------------------------------------------------------------------------
def generate_correlated_alerts(
    step: int,
    num_alerts: int = 3,
    *,
    start_index: int = 0,
) -> List[Alert]:
    """
    Generate a sequence of alerts that share a hidden root cause.

    Simulates cascading failures (e.g. high CPU → memory pressure →
    application crash). Severity escalates along the chain (baseline
    0.60 + 0.15 per position) so that later members are more dangerous than
    the trigger. Chain members are never false positives.

    The IDs of all alerts in the chain should be tracked by the caller in
    ``AdaptiveAlertTriageEnv.correlation_groups`` so the hard-task grader
    can reward root-cause identification.

    Args:
        step: Current episode step (used for ID generation).
        num_alerts: Maximum number of alerts to produce. The result is
            truncated to the length of the selected chain; values <= 0
            yield an empty list.
        start_index: Batch index assigned to the first alert of the chain.
            IDs are built with generate_alert_id(step, start_index + i), the
            same scheme generate_alert() uses, so callers that also emit
            independent alerts in the same step can pass an offset to avoid
            duplicate IDs. Defaults to 0.

    Returns:
        List of correlated Alert objects in causal order.
    """
    # Clamp before slicing: a negative bound would slice from the end of the
    # chain and silently return alerts instead of none.
    count: int = max(num_alerts, 0)
    # Chain selection uses the stdlib generator; set_seed() seeds it too.
    chain: List[str] = random.choice(CORRELATION_CHAINS)[:count]

    alerts: List[Alert] = []
    for i, alert_type in enumerate(chain):
        alert_id = generate_alert_id(step, start_index + i)

        # Severity increases along the chain
        base_sev: float = 0.60 + i * 0.15
        true_severity: float = float(
            np.clip(base_sev + np.random.normal(0.0, 0.05), 0.0, 1.0)
        )
        confidence: float = float(
            np.clip(0.80 + np.random.normal(0.0, 0.10), 0.0, 1.0)
        )
        visible_severity: float = add_observation_noise(true_severity, confidence)

        alert = Alert(
            id=alert_id,
            visible_severity=visible_severity,
            confidence=confidence,
            alert_type=alert_type,  # type: ignore[arg-type]
            age=0,
            true_severity=true_severity,
            is_correlated=True,
            metadata={
                "false_positive": False,
                # Per-alert copy so editing one alert's metadata cannot
                # change its siblings'.
                "correlation_chain": list(chain),
                "chain_position": i,
                "generated_at_step": step,
            },
        )
        alerts.append(alert)
    return alerts
# ---------------------------------------------------------------------------
# System-load calculation
# ---------------------------------------------------------------------------
def calculate_system_load(num_active_alerts: int, base_load: float = 0.30) -> float:
    """
    Estimate current system resource utilisation.

    Each unresolved alert contributes 0.05 to load, plus a small Gaussian
    jitter (σ=0.02) to model background variability. The sum is clipped, so
    the load saturates at 1.0 for long queues.

    Args:
        num_active_alerts: Number of alerts currently in the queue.
        base_load: Steady-state load with no active alerts.

    Returns:
        System load in [0.0, 1.0].
    """
    alert_contribution: float = num_active_alerts * 0.05
    jitter: float = float(np.random.normal(0.0, 0.02))
    return float(np.clip(base_load + alert_contribution + jitter, 0.0, 1.0))
# ---------------------------------------------------------------------------
# Alert-arrival modelling
# ---------------------------------------------------------------------------
def should_generate_new_alerts(step: int, current_queue: int) -> bool:
    """
    Decide whether the environment should produce new alerts this step.

    Uses a Bernoulli arrival draw with back-pressure: a growing queue
    reduces arrival probability, preventing runaway queue growth and forcing
    the agent to drain alerts before new ones overwhelm the system. The
    probability ranges from 0.70 (empty queue) down to 0.30 (8+ queued).

    Args:
        step: Current episode step (unused but available for
            future step-dependent patterns).
        current_queue: Number of alerts already in the queue.

    Returns:
        True if new alerts should be generated.
    """
    base_prob: float = 0.70
    # Back-pressure: every queued alert reduces arrival probability by 0.05,
    # capped at a maximum reduction of 0.40.
    queue_penalty: float = min(current_queue * 0.05, 0.40)
    arrival_prob: float = base_prob - queue_penalty
    return bool(np.random.random() < arrival_prob)
def sample_num_new_alerts() -> int:
    """
    Sample the number of alerts to generate this step (Poisson, λ=2).

    Capped at 5 to prevent single-step queue explosions.

    Returns:
        Integer in [0, 5].
    """
    # int() converts the numpy integer so callers receive a builtin int.
    return min(int(np.random.poisson(2)), 5)
# ---------------------------------------------------------------------------
# Alert criticality
# ---------------------------------------------------------------------------
def is_critical_alert(alert: Alert, threshold: float = CRITICAL_SEVERITY_THRESHOLD) -> bool:
    """
    Determine whether an alert is critical based on its *true* severity.

    Note: the agent cannot observe true_severity directly; this function is
    used internally by the reward calculator and failure checker.

    Args:
        alert: The alert to evaluate.
        threshold: Minimum true_severity for criticality (default 0.75).

    Returns:
        True if the alert's true severity meets or exceeds the threshold.
    """
    # bool() keeps the declared return type even if true_severity is a numpy
    # scalar, matching the other predicates in this module.
    return bool(alert.true_severity >= threshold)
# ---------------------------------------------------------------------------
# Action-correctness evaluation (used by task graders)
# ---------------------------------------------------------------------------
def calculate_action_correctness(
    action_type: str,
    alert: Alert,
    resource_constrained: bool = False,
) -> Tuple[bool, str]:
    """
    Evaluate whether an action matches the ground-truth optimal policy.

    Decision logic, checked in this order:
        - Critical alert → INVESTIGATE or ESCALATE is correct; anything else
          is a miss.
        - False positive → IGNORE is correct; anything else wastes resources.
        - Medium severity → every action is currently reported as correct;
          only the reason string differs (INVESTIGATE, DELAY under resource
          constraints, ESCALATE, or the generic fallthrough).

    Criticality is tested before the false-positive flag, so an alert that is
    both is judged as critical.

    Args:
        action_type: The action taken ("INVESTIGATE", "IGNORE", etc.).
        alert: Alert being evaluated (with true hidden fields).
        resource_constrained: Whether the task enforces a per-step action
            budget. Only affects the reason string for DELAY on a medium
            alert.

    Returns:
        Tuple of (is_correct: bool, reason: str).
    """
    is_critical: bool = is_critical_alert(alert)
    is_fp: bool = bool(alert.metadata.get("false_positive", False))

    if is_critical:
        if action_type in ("INVESTIGATE", "ESCALATE"):
            return True, "Correctly handled critical alert"
        return False, "Missed critical alert — should INVESTIGATE or ESCALATE"

    if is_fp:
        if action_type == "IGNORE":
            return True, "Correctly ignored false positive"
        return False, "Wasted resources on false positive"

    # Medium-severity alert
    if action_type == "INVESTIGATE":
        return True, "Investigated medium-severity alert"
    if action_type == "DELAY" and resource_constrained:
        return True, "Delayed medium alert under resource constraints (acceptable)"
    if action_type == "ESCALATE":
        return True, "Escalated medium alert (acceptable)"
    # NOTE(review): this fallthrough also accepts IGNORE, and DELAY without
    # resource constraints, so no action on a medium alert is ever marked
    # incorrect. Confirm with the grader owners whether False was intended
    # here; behaviour is left unchanged because scores depend on it.
    return True, "Acceptable action for medium-severity alert"