File size: 16,637 Bytes
b14c6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8061f1b
b14c6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1341fa9
 
 
 
 
 
 
 
 
b14c6e3
 
 
 
 
 
 
 
 
 
 
1341fa9
b14c6e3
 
 
 
1341fa9
b14c6e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
"""
Utility Functions for the Adaptive Alert Triage Environment

Provides deterministic, seed-controlled helpers for:
  - Alert generation (individual and correlated chains)
  - Severity / noise / false-positive logic
  - System-load calculation
  - Alert-queue arrival modelling
  - Action-correctness evaluation (used by graders)

Randomness comes from numpy's global generator, plus the stdlib ``random``
module for correlation-chain selection; a single set_seed() call at episode
start seeds both and guarantees full reproducibility.
"""

import random
from typing import List, Dict, Tuple, Optional

import numpy as np

from adaptive_alert_triage.models import Alert, AlertType


# ---------------------------------------------------------------------------
# Alert-type configuration
# ---------------------------------------------------------------------------

# Per-class generation parameters, keyed by alert-type name.  Every entry
# carries two values:
#   base_severity        - centre of the true-severity distribution for
#                          genuine (non-false-positive) alerts of that class
#   false_positive_rate  - probability that an alert of that class is a
#                          false positive
# The figures are meant to mirror realistic SOC distributions: SECURITY has
# the lowest false-positive rate and the highest baseline severity, while
# APPLICATION is the noisiest signal.
ALERT_TYPE_CONFIG: Dict[str, Dict[str, float]] = {
    type_name: {"base_severity": severity, "false_positive_rate": fp_rate}
    for type_name, severity, fp_rate in (
        # (alert type, base severity, false-positive rate)
        ("CPU", 0.60, 0.15),
        ("MEMORY", 0.70, 0.20),
        ("DISK", 0.50, 0.25),
        ("NETWORK", 0.65, 0.10),
        ("APPLICATION", 0.75, 0.30),
        ("SECURITY", 0.90, 0.05),
    )
}

# Typical cascading-failure sequences, listed in causal order (trigger
# first).  Correlated alert groups are generated by picking one of these
# chains and emitting one alert per entry.
CORRELATION_CHAINS: List[List[str]] = [
    list(cascade)
    for cascade in (
        ("CPU", "MEMORY", "APPLICATION"),
        ("NETWORK", "APPLICATION", "APPLICATION"),
        ("DISK", "MEMORY", "APPLICATION"),
        ("SECURITY", "NETWORK", "APPLICATION"),
        ("MEMORY", "CPU", "APPLICATION"),
    )
]

# Thresholds used across the environment and graders
CRITICAL_SEVERITY_THRESHOLD: float = 0.75   # true_severity >= this → critical
CRITICAL_AGE_THRESHOLD: int = 5             # age >= this AND critical → failure


# ---------------------------------------------------------------------------
# Seed management
# ---------------------------------------------------------------------------

def set_seed(seed: int) -> None:
    """
    Seed every random source this module draws from.

    The stdlib ``random`` generator is seeded first and numpy's global
    generator second, both with the same value.  Call this before any
    alert-generation helper so that an episode can be replayed exactly.

    Args:
        seed: Seed value handed unchanged to both generators.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


# ---------------------------------------------------------------------------
# ID helpers
# ---------------------------------------------------------------------------

def generate_alert_id(step: int, alert_index: int) -> str:
    """
    Compose the deterministic identifier for one alert.

    The ID has the shape ``alert_<step>_<index>``, where the step is
    zero-padded to at least four digits and the index to at least two.

    Args:
        step:        Episode step at which the alert was generated.
        alert_index: Position of the alert within that step's batch.

    Returns:
        The identifier string, e.g. ``"alert_0007_02"`` for step 7, index 2.
    """
    step_part = format(step, "04d")
    index_part = format(alert_index, "02d")
    return "_".join(("alert", step_part, index_part))


# ---------------------------------------------------------------------------
# Alert-type sampling
# ---------------------------------------------------------------------------

def sample_alert_type() -> AlertType:
    """
    Draw one alert type from a fixed, weighted categorical distribution.

    APPLICATION is the most likely outcome (25 %) and SECURITY the least
    likely (5 %).  A single draw is taken from numpy's global generator.

    Returns:
        One of "CPU", "MEMORY", "DISK", "NETWORK", "APPLICATION", "SECURITY".
    """
    # (type name, sampling probability) pairs; the probabilities sum to 1.0.
    weighted_types = (
        ("CPU", 0.20),
        ("MEMORY", 0.20),
        ("DISK", 0.15),
        ("NETWORK", 0.15),
        ("APPLICATION", 0.25),
        ("SECURITY", 0.05),
    )
    names = [name for name, _ in weighted_types]
    probabilities = [weight for _, weight in weighted_types]

    # Sample an index rather than a name so a plain ``str`` is returned.
    chosen = int(np.random.choice(len(names), p=probabilities))
    return names[chosen]  # type: ignore[return-value]


# ---------------------------------------------------------------------------
# Severity helpers
# ---------------------------------------------------------------------------

def calculate_true_severity(
    alert_type: AlertType,
    is_correlated: bool = False,
) -> float:
    """
    Draw the hidden severity of a genuine (non-false-positive) alert.

    The type's baseline severity from ALERT_TYPE_CONFIG is perturbed with
    Gaussian noise (σ = 0.10) and clipped to [0.0, 1.0].  For correlated
    alerts the clipped value is then multiplied by 1.3 and capped at 1.0,
    modelling the extra risk of cascading failures.

    Args:
        alert_type:    Category of the alert; must be a key of
                       ALERT_TYPE_CONFIG.
        is_correlated: Whether the alert belongs to a correlated chain.

    Returns:
        True severity in [0.0, 1.0].
    """
    # The baseline is looked up before the noise draw, so an unknown type
    # raises KeyError without consuming a random number.
    jittered = ALERT_TYPE_CONFIG[alert_type]["base_severity"] + float(
        np.random.normal(0.0, 0.10)
    )
    clipped = float(np.clip(jittered, 0.0, 1.0))
    if not is_correlated:
        return clipped
    return float(min(clipped * 1.3, 1.0))


def add_observation_noise(true_severity: float, confidence: float) -> float:
    """
    Turn a ground-truth severity into the noisy value shown to the agent.

    Zero-mean Gaussian noise is added with a standard deviation of
    ``0.15 * (1 - confidence)``, so a less confident detector produces a
    blurrier reading.  The result is clipped to [0.0, 1.0].

    Args:
        true_severity: Ground-truth severity value.
        confidence:    Sensor/detector confidence level.

    Returns:
        Noisy visible severity in [0.0, 1.0].
    """
    spread = 0.15 * (1.0 - confidence)
    observed = true_severity + float(np.random.normal(0.0, spread))
    return float(np.clip(observed, 0.0, 1.0))


# ---------------------------------------------------------------------------
# False-positive determination
# ---------------------------------------------------------------------------

def is_false_positive(alert_type: AlertType) -> bool:
    """
    Randomly flag an alert of the given type as a false positive.

    One uniform draw from numpy's global generator is compared against the
    type's ``false_positive_rate`` in ALERT_TYPE_CONFIG.

    Args:
        alert_type: Category of the alert; must be a key of
                    ALERT_TYPE_CONFIG.

    Returns:
        True if the draw falls below the type's false-positive rate.
    """
    rate = ALERT_TYPE_CONFIG[alert_type]["false_positive_rate"]
    draw = np.random.random()
    return bool(draw < rate)


# ---------------------------------------------------------------------------
# Single-alert generation
# ---------------------------------------------------------------------------

def generate_alert(
    step: int,
    alert_index: int,
    is_correlated: bool = False,
    force_critical: bool = False,
) -> Alert:
    """
    Generate a single synthetic alert with both visible and hidden attributes.

    Workflow:
      1. Sample alert type.
      2. Determine if false positive (never when force_critical=True).
      3. Set true_severity: uniform in [0.0, 0.3] for false positives,
         uniform in [0.8, 1.0] for forced-critical alerts, otherwise
         sampled via calculate_true_severity().
      4. Sample confidence: (1 - false-positive rate of the type) plus
         Gaussian jitter (σ = 0.10), clipped to [0.0, 1.0].
      5. Generate noisy visible_severity via add_observation_noise().
      6. With 2 % probability, and only when true_severity is extreme
         (>= 0.8 or <= 0.2), replace visible_severity with a value from
         the opposite end of the scale ("rogue" alert).

    The alert's metadata records ``false_positive``, ``generated_at_step``
    and ``is_outlier``.  ``is_outlier`` is True only when step 6 actually
    overrode the visible severity.

    Args:
        step:           Current episode step (used for ID generation).
        alert_index:    Index within this step's batch.
        is_correlated:  Mark the alert as part of a correlated failure chain.
        force_critical: Override FP logic and set severity in [0.8, 1.0].

    Returns:
        Fully populated Alert object.
    """
    alert_id: str = generate_alert_id(step, alert_index)
    alert_type: AlertType = sample_alert_type()

    # False-positive logic.  The FP draw is taken before force_critical is
    # consulted, so the random stream advances identically in both modes.
    is_fp: bool = is_false_positive(alert_type) and not force_critical

    # True severity
    if is_fp:
        true_severity = float(np.random.uniform(0.0, 0.30))
    elif force_critical:
        true_severity = float(np.random.uniform(0.80, 1.0))
    else:
        true_severity = calculate_true_severity(alert_type, is_correlated)

    # Confidence — inversely related to FP rate, with Gaussian jitter
    base_confidence: float = 1.0 - ALERT_TYPE_CONFIG[alert_type]["false_positive_rate"]
    confidence: float = float(
        np.clip(base_confidence + np.random.normal(0.0, 0.10), 0.0, 1.0)
    )

    # Observable severity (noisy)
    visible_severity: float = add_observation_noise(true_severity, confidence)

    # --- Extreme Outlier Logic (stochastic noise for score variance) ---
    # A 2% chance of a "rogue" alert whose visible severity contradicts its
    # true severity, intended to keep even perfect agents from a guaranteed
    # flawless score.  Mid-range severities are left untouched even when
    # the 2% draw succeeds.
    is_outlier: bool = False
    if np.random.random() < 0.02:
        if true_severity >= 0.8:
            visible_severity = float(np.random.uniform(0.0, 0.2))  # "Hidden Critical"
            is_outlier = True
        elif true_severity <= 0.2:
            visible_severity = float(np.random.uniform(0.8, 1.0))  # "Phantom Critical"
            is_outlier = True

    return Alert(
        id=alert_id,
        visible_severity=visible_severity,
        confidence=confidence,
        alert_type=alert_type,
        age=0,
        true_severity=true_severity,
        is_correlated=is_correlated,
        metadata={
            "false_positive": is_fp,
            "generated_at_step": step,
            # Audit flag: set only when the outlier override fired.
            "is_outlier": is_outlier,
        },
    )



# ---------------------------------------------------------------------------
# Correlated-alert chain generation
# ---------------------------------------------------------------------------

def generate_correlated_alerts(
    step: int,
    num_alerts: int = 3,
    start_index: int = 0,
) -> List[Alert]:
    """
    Generate a sequence of alerts that share a hidden root cause.

    Simulates cascading failures (e.g. high CPU → memory pressure →
    application crash).  One chain is picked from CORRELATION_CHAINS and
    truncated to ``num_alerts`` entries.  Severity escalates along the
    chain (baseline 0.60 + 0.15 per position, plus small Gaussian noise),
    so later members are more dangerous than the trigger.  Chain members
    are never false positives.

    The IDs of all alerts in the chain should be tracked in
    ``AdaptiveAlertTriageEnv.correlation_groups`` so the hard-task grader
    can reward root-cause identification.

    Args:
        step:        Current episode step (used for ID generation).
        num_alerts:  Number of alerts to produce.  Values larger than the
                     chain length yield the whole chain; 0 yields an empty
                     list.
        start_index: Batch index given to the first alert of the chain;
                     subsequent alerts count up from it.  Pass the number of
                     alerts already generated at this step to keep IDs
                     unique within the step.  Defaults to 0.

    Returns:
        List of correlated Alert objects in causal order.

    Raises:
        ValueError: If ``num_alerts`` is negative (a negative slice bound
                    would otherwise silently drop entries from the end of
                    the chain instead of limiting its length).
    """
    if num_alerts < 0:
        raise ValueError(f"num_alerts must be non-negative, got {num_alerts}")

    # Chain selection uses the stdlib generator (seeded by set_seed) rather
    # than numpy; kept as-is so seeded episodes replay identically.
    chain: List[str] = random.choice(CORRELATION_CHAINS)[:num_alerts]
    alerts: List[Alert] = []

    for i, alert_type in enumerate(chain):
        alert_id = generate_alert_id(step, start_index + i)

        # Severity increases along the chain
        base_sev: float = 0.60 + i * 0.15
        true_severity: float = float(
            np.clip(base_sev + np.random.normal(0.0, 0.05), 0.0, 1.0)
        )
        confidence: float = float(
            np.clip(0.80 + np.random.normal(0.0, 0.10), 0.0, 1.0)
        )
        visible_severity: float = add_observation_noise(true_severity, confidence)

        alert = Alert(
            id=alert_id,
            visible_severity=visible_severity,
            confidence=confidence,
            alert_type=alert_type,  # type: ignore[arg-type]
            age=0,
            true_severity=true_severity,
            is_correlated=True,
            metadata={
                "false_positive": False,
                # Each alert gets its own copy so that mutating one alert's
                # metadata cannot affect its siblings.
                "correlation_chain": list(chain),
                "chain_position": i,
                "generated_at_step": step,
            },
        )
        alerts.append(alert)

    return alerts


# ---------------------------------------------------------------------------
# System-load calculation
# ---------------------------------------------------------------------------

def calculate_system_load(num_active_alerts: int, base_load: float = 0.30) -> float:
    """
    Estimate the current system load from the size of the alert queue.

    The load is the steady-state ``base_load`` plus 0.05 per active alert
    plus Gaussian jitter (σ = 0.02), clipped to [0.0, 1.0].

    Args:
        num_active_alerts: Number of alerts currently in the queue.
        base_load:         Steady-state load with no active alerts.

    Returns:
        System load in [0.0, 1.0].
    """
    queue_load = 0.05 * num_active_alerts
    background = float(np.random.normal(0.0, 0.02))
    raw_load = base_load + queue_load + background
    return float(np.clip(raw_load, 0.0, 1.0))


# ---------------------------------------------------------------------------
# Alert-arrival modelling
# ---------------------------------------------------------------------------

def should_generate_new_alerts(step: int, current_queue: int) -> bool:
    """
    Decide, by one random draw, whether new alerts arrive this step.

    The arrival probability starts at 0.70 and is lowered by 0.05 for each
    alert already queued, with the reduction capped at 0.40.  A long queue
    therefore throttles arrivals (back-pressure), but the probability never
    drops below 0.30 for a non-negative queue length.

    Args:
        step:          Current episode step.  Not used by the current
                       arrival model; accepted so step-dependent patterns
                       can be added without changing callers.
        current_queue: Number of alerts already in the queue.

    Returns:
        True if new alerts should be generated.
    """
    back_pressure = min(0.05 * current_queue, 0.40)
    return bool(np.random.random() < 0.70 - back_pressure)


def sample_num_new_alerts() -> int:
    """
    Draw how many alerts arrive this step.

    The count is sampled from a Poisson distribution with λ = 2 and then
    capped at 5 so a single step cannot flood the queue.

    Returns:
        Integer in [0, 5].
    """
    drawn = int(np.random.poisson(2))
    return drawn if drawn < 5 else 5


# ---------------------------------------------------------------------------
# Alert criticality
# ---------------------------------------------------------------------------

def is_critical_alert(alert: Alert, threshold: float = CRITICAL_SEVERITY_THRESHOLD) -> bool:
    """
    Report whether an alert counts as critical.

    The decision is based on the alert's hidden ``true_severity``, which
    the agent cannot observe directly; this helper is meant for internal
    use by the reward calculator and failure checker.

    Args:
        alert:     The alert to evaluate.
        threshold: Minimum true_severity for criticality (defaults to
                   CRITICAL_SEVERITY_THRESHOLD).

    Returns:
        True if the alert's true severity meets or exceeds the threshold.
    """
    hidden_severity = alert.true_severity
    return hidden_severity >= threshold


# ---------------------------------------------------------------------------
# Action-correctness evaluation  (used by task graders)
# ---------------------------------------------------------------------------

def calculate_action_correctness(
    action_type: str,
    alert: Alert,
    resource_constrained: bool = False,
) -> Tuple[bool, str]:
    """
    Evaluate whether an action matches the ground-truth optimal policy.

    Decision logic, checked in this order:
      - Critical alert  → INVESTIGATE or ESCALATE is correct; anything else
                          is a miss.
      - False positive  → IGNORE is correct; anything else wastes resources.
      - Any other alert → every action is reported as correct; only the
                          reason text differs (INVESTIGATE, DELAY under
                          resource constraints, ESCALATE, or a generic
                          "acceptable" message).

    This is intentionally strict for critical alerts (the agent should never
    ignore or indefinitely delay them) and lenient for medium-severity
    alerts.

    Args:
        action_type:          The action taken ("INVESTIGATE", "IGNORE", etc.).
        alert:                Alert being evaluated (with true hidden fields).
        resource_constrained: Whether the task enforces a per-step action
                              budget.  It only selects the reason text for
                              DELAY on a medium-severity alert.

    Returns:
        Tuple of (is_correct: bool, reason: str).
    """
    is_critical: bool = is_critical_alert(alert)
    is_fp: bool = bool(alert.metadata.get("false_positive", False))

    if is_critical:
        if action_type in ("INVESTIGATE", "ESCALATE"):
            return True, "Correctly handled critical alert"
        return False, "Missed critical alert — should INVESTIGATE or ESCALATE"

    if is_fp:
        if action_type == "IGNORE":
            return True, "Correctly ignored false positive"
        return False, "Wasted resources on false positive"

    # Medium-severity alert.
    # NOTE(review): every branch below returns True, so IGNORE and an
    # unconstrained DELAY are also graded as correct.  Confirm this
    # leniency is intended before tightening it.
    if action_type == "INVESTIGATE":
        return True, "Investigated medium-severity alert"
    if action_type == "DELAY" and resource_constrained:
        return True, "Delayed medium alert under resource constraints (acceptable)"
    if action_type == "ESCALATE":
        return True, "Escalated medium alert (acceptable)"
    return True, "Acceptable action for medium-severity alert"