| """ | |
| Task 3: Cascading Failure Prevention (Hard) | |
| ============================================ | |
| Objective | |
| --------- | |
| Detect correlated alert chains and stop them *before* they cascade into | |
| system failures. This is the defining challenge of the hard task: each | |
| chain's trigger alert arrives first; if the agent fails to handle it | |
| correctly, the *next* alert in the chain is spawned by the environment in | |
| a future step β and so on until the chain terminates in a system failure. | |
| How cascading chains work in this environment | |
| --------------------------------------------- | |
| Unlike the easy and medium tasks where alerts are mostly independent, the | |
| hard task environment (correlation_probability = 0.40) frequently spawns | |
| *correlated chains*. The full chain is NOT delivered all at once. Instead: | |
| Step N β trigger alert arrives (chain position 0) | |
| Step N+k β if trigger was IGNORED/DELAYED, child alert arrives (position 1) | |
| Step N+m β if child was missed, grandchild arrives (position 2) β¦etc. | |
| This means the agent must: | |
| 1. Recognise a trigger alert before seeing any siblings (the siblings | |
| haven't spawned yet). | |
| 2. INVESTIGATE or ESCALATE the trigger, which *stops* the chain from | |
| propagating further. | |
| 3. If the trigger is missed, handle each subsequent child aggressively. | |
| The grader tracks chain-level outcomes (was the chain stopped at the | |
| trigger? at the first child? did it run all the way to a system failure?) | |
| and awards bonus/penalty accordingly. | |
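
Illustrative timeline (hypothetical alert IDs, mirroring the self-test at
the bottom of this file; the exact spawn delays depend on the environment):

    chain: ["cpu_01", "mem_01", "app_01"]
    Step 4: "cpu_01" arrives, agent INVESTIGATEs → chain stopped at
            position 0; "mem_01" and "app_01" never spawn.
    Step 4: "cpu_01" arrives, agent IGNOREs      → "mem_01" spawns a few
            steps later; if it is also missed, "app_01" follows and then
            a system failure is recorded.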

Grading formula (all inputs are deterministic given the same seed)
-------------------------------------------------------------------

    component_score = Σ over chains [ (base_reward − timing_penalty) × severity_weight ]
                      + isolation_bonus

where:

    chain_outcome_score:
        Each chain contributes a base reward depending on where it was stopped:
            stopped at trigger (position 0)   → 1.00
            stopped at first child (pos. 1)   → 0.70
            stopped at position 2             → 0.40
            stopped at position 3             → 0.15
            ran to system failure             → 0.00
        weighted by the chain's max true_severity (floored at 0.5) so that
        dangerous chains matter more.

    timing_penalty (applied inside each chain's outcome score):
        0.05 × min(stop_position, 3)
        Penalises the agent for letting the chain run several steps before
        stopping it.

    isolation_bonus:
        +0.10 per independent (non-correlated) alert correctly handled.
        Cap: 0.20 total, so non-correlated work adds at most 0.20 to the
        normalised score.

    stability_score (multiplier based on total system failures):
        0 failures → 1.00
        1          → 0.80
        2          → 0.60
        3          → 0.30
        4+         → 0.00

Final normalisation:
    raw   = component_score / max(sum of per-chain maxima, 1.0), capped at 1.0
    final = raw × stability_score, mapped linearly into [0.01, 0.99]

Success threshold: ≥ 0.50 (intentionally hard to challenge frontier models)
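
Worked example (one chain with max true_severity 0.85, stopped at the first
child, no independent alerts, no failures; this mirrors Scenario 2 in the
self-test at the bottom of this file):

    chain_outcome = (0.70 − 0.05 × 1) × 0.85 = 0.5525
    max_possible  = 1.00 × 0.85 = 0.85, denominator floored to 1.0
    raw           = 0.5525 / 1.0 = 0.5525
    final         ≈ 0.55   (0.5525 × 1.00 stability, then mapped into [0.01, 0.99])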

Why this is genuinely hard
--------------------------
- The agent sees only one alert at a time. It cannot know whether an
  alert is a chain trigger without correlating its type, timing, and
  visible_severity pattern.
- The chain siblings are hidden: they haven't spawned yet when the
  trigger arrives, so `is_correlated` can only be known from ground truth.
- Delayed action is penalised: DELAYing a trigger causes a child to spawn.
- The stability multiplier creates a non-linear cliff: even a single
  missed chain that reaches failure drops the score significantly.

Integration with env.py
-----------------------
    obs, reward, done, info = env.step(action)

    # 1. Update correlation state (called EVERY step; chains grow dynamically)
    grader.update_correlation_state(info.get("correlation_groups", []))

    # 2. Process all actions this step
    for alert_data in info.get("processed_alerts", []):
        grader.process_step(alert_data, info)

    # 3. Record any system failures
    grader.record_failures(info.get("failures_this_step", 0))
"""
from __future__ import annotations

from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np

from adaptive_alert_triage.models import Action, Alert, Reward

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

_CRITICAL_THRESHOLD: float = 0.75
_MEDIUM_THRESHOLD: float = 0.50
_FALSE_POSITIVE_THRESHOLD: float = 0.30

# Chain-stop rewards by position index (0 = trigger, 1 = first child, ...)
_CHAIN_STOP_REWARDS: Dict[int, float] = {
    0: 1.00,  # stopped at trigger (best outcome)
    1: 0.70,  # stopped at first child
    2: 0.40,  # stopped two steps in
    3: 0.15,  # barely caught it
}
_CHAIN_FAILURE_REWARD: float = 0.00     # ran all the way to failure
_TIMING_PENALTY_PER_STEP: float = 0.05  # extra penalty per position beyond 0
_MAX_TIMING_PENALTY_STEPS: int = 3

# Isolation bonus for correctly handling non-correlated alerts
_ISOLATION_BONUS_PER_ALERT: float = 0.10
_ISOLATION_BONUS_CAP: float = 0.20

# Stability score by failure count (step-function approximation)
_STABILITY_BY_FAILURES: List[Tuple[int, float]] = [
    (0, 1.00),
    (1, 0.80),
    (2, 0.60),
    (3, 0.30),
]
_STABILITY_FLOOR: float = 0.00
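# Illustrative lookups against the table above (evaluated by _stability_score
# further down: the first threshold the failure count does not exceed wins):
#     0 failures -> 1.00      2 failures -> 0.60      5 failures -> 0.00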
SUCCESS_THRESHOLD: float = 0.50


# ---------------------------------------------------------------------------
# Internal data class for one chain
# ---------------------------------------------------------------------------

class _ChainRecord:
    """Bookkeeping for a single correlated alert chain."""

    __slots__ = (
        "chain_id", "alert_ids", "max_severity",
        "stop_position", "completed", "hit_failure",
    )

    def __init__(self, chain_id: int, alert_ids: List[str]) -> None:
        self.chain_id: int = chain_id
        self.alert_ids: List[str] = list(alert_ids)
        self.max_severity: float = 0.0             # updated as alerts arrive
        self.stop_position: Optional[int] = None   # index where chain was halted
        self.completed: bool = False
        self.hit_failure: bool = False

    def position_of(self, alert_id: str) -> Optional[int]:
        try:
            return self.alert_ids.index(alert_id)
        except ValueError:
            return None

    def mark_stopped(self, position: int, severity: float) -> None:
        if not self.completed:
            self.stop_position = position
            self.completed = True
            self.max_severity = max(self.max_severity, severity)

    def mark_failure(self) -> None:
        self.hit_failure = True
        self.completed = True

    def outcome_score(self) -> float:
        """Score for this chain's outcome, weighted by severity."""
        if self.hit_failure:
            return _CHAIN_FAILURE_REWARD
        if self.stop_position is None:
            return _CHAIN_FAILURE_REWARD  # chain never handled
        base = _CHAIN_STOP_REWARDS.get(self.stop_position, _CHAIN_FAILURE_REWARD)
        timing = _TIMING_PENALTY_PER_STEP * min(self.stop_position,
                                                _MAX_TIMING_PENALTY_STEPS)
        return max(0.0, base - timing) * max(self.max_severity, 0.5)

    def max_possible(self) -> float:
        """Maximum possible score for this chain."""
        return _CHAIN_STOP_REWARDS[0] * max(self.max_severity, 0.5)
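
# Example (illustrative): bookkeeping for a 3-alert chain that the agent
# stops at its first child (position 1) with true severity 0.85:
#
#     rec = _ChainRecord(0, ["cpu_01", "mem_01", "app_01"])
#     rec.mark_stopped(position=1, severity=0.85)
#     rec.outcome_score()   # (0.70 - 0.05 * 1) * 0.85 = 0.5525
#     rec.max_possible()    # 1.00 * 0.85 = 0.85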

# ---------------------------------------------------------------------------
# Grader
# ---------------------------------------------------------------------------

class HardTaskGrader:
    """
    Grader for Task 3: Cascading Failure Prevention.

    Lifecycle (one episode)
    -----------------------
    1. Instantiate once per episode.
    2. After every env.step(action):
       a. grader.update_correlation_state(info["correlation_groups"])
       b. for alert_data in info["processed_alerts"]:
              grader.process_step(alert_data, info)
       c. grader.record_failures(info["failures_this_step"])
    3. At episode end: grader.get_episode_score() -> float in [0.0, 1.0].
    4. grader.get_metrics() for a full breakdown.
    5. grader.reset() to reuse for a new episode.
    """

    def __init__(
        self,
        correlation_chains: Optional[List[List[str]]] = None,
    ) -> None:
        # Map from chain_id -> _ChainRecord (built up over the episode)
        self._chains: Dict[int, _ChainRecord] = {}
        # Map from alert_id -> chain_id for O(1) lookup
        self._alert_to_chain: Dict[str, int] = {}
        # Independent-alert counters
        self._isolation_correct: int = 0
        self._isolation_total: int = 0
        # System failures
        self._system_failures: int = 0
        # Per-action history (all alerts) for metrics and debugging
        self._action_history: List[Dict[str, Any]] = []
        self._total_actions: int = 0
        # Seed with any chains already known at episode start
        if correlation_chains:
            self.update_correlation_state(correlation_chains)

    # ------------------------------------------------------------------
    # State update (called every step)
    # ------------------------------------------------------------------

    def update_correlation_state(self, chains: List[List[str]]) -> None:
        """
        Sync the grader's chain knowledge with the environment's live state.

        MUST be called EVERY step with info["correlation_groups"], because
        new chains spawn mid-episode as the agent misses trigger alerts.

        Args:
            chains: Current list-of-lists of correlated alert IDs from
                info["correlation_groups"].
        """
        for chain_id, alert_ids in enumerate(chains):
            if chain_id not in self._chains:
                # New chain discovered this step
                record = _ChainRecord(chain_id, alert_ids)
                self._chains[chain_id] = record
                for aid in alert_ids:
                    self._alert_to_chain[aid] = chain_id
            else:
                # Existing chain may have grown (a new child spawned)
                existing = self._chains[chain_id]
                for aid in alert_ids:
                    if aid not in existing.alert_ids:
                        existing.alert_ids.append(aid)
                        self._alert_to_chain[aid] = chain_id
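
    # Example (illustrative IDs): the same chain reported across two steps.
    #     step 3: update_correlation_state([["cpu_01"]])
    #                 -> chain 0 created with ["cpu_01"]
    #     step 6: update_correlation_state([["cpu_01", "mem_01"]])
    #                 -> chain 0 grows; "mem_01" now maps to chain 0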

    # ------------------------------------------------------------------
    # Primary interface
    # ------------------------------------------------------------------

    def process_step(
        self,
        alert_data: Dict[str, Any],
        info: Dict[str, Any],  # noqa: ARG002
    ) -> float:
        """
        Evaluate one action using ground-truth data from env.step().

        For correlated alerts: updates the chain record (stopped / missed).
        For independent alerts: accumulates the isolation bonus.

        Args:
            alert_data: One entry from info["processed_alerts"].
            info: Full info dict (unused here, kept for API symmetry).

        Returns:
            Immediate contribution to the score (for logging only; the true
            episode score is finalised in get_episode_score()).
        """
        self._total_actions += 1

        alert_id: str = str(alert_data.get("alert_id", ""))
        true_sev: float = float(alert_data.get("true_severity", 0.0))
        action_type: str = str(alert_data.get("action_taken", ""))
        is_corr: bool = bool(alert_data.get("is_correlated", False))
        chain_idx: Optional[int] = alert_data.get("correlation_group_index")

        # Normalise: use our own chain map if the env didn't set the index
        if chain_idx is None:
            chain_idx = self._alert_to_chain.get(alert_id)

        contribution = 0.0
        if is_corr and chain_idx is not None and chain_idx in self._chains:
            contribution = self._handle_correlated(
                alert_id, true_sev, action_type, chain_idx
            )
        else:
            # Independent alert
            contribution = self._handle_independent(true_sev, action_type)

        self._action_history.append({
            "alert_id": alert_id,
            "action": action_type,
            "true_severity": true_sev,
            "is_correlated": is_corr,
            "chain_idx": chain_idx,
            "contribution": contribution,
        })
        return contribution
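
    # The fields read above are the only ones this grader relies on; a
    # minimal alert_data entry looks like (values illustrative):
    #     {"alert_id": "cpu_01", "true_severity": 0.8,
    #      "action_taken": "INVESTIGATE", "is_correlated": True,
    #      "correlation_group_index": 0}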

    def record_failures(self, count: int) -> None:
        """
        Record system failures detected this step.

        Call with info["failures_this_step"] after every env.step().

        Args:
            count: Number of failures this step (0 is fine).
        """
        self._system_failures += max(0, int(count))
        # Mark all incomplete chains as hit_failure (conservative)
        for record in self._chains.values():
            if not record.completed and count > 0:
                record.mark_failure()

    # ------------------------------------------------------------------
    # Legacy API
    # ------------------------------------------------------------------

    def grade_action(
        self,
        action: Action,
        alert: Alert,
        reward: Reward,
        current_alerts: Optional[List[Alert]] = None,
    ) -> float:
        """Grade a single action-alert pair (legacy / unit-test API)."""
        alert_data = {
            "alert_id": alert.id,
            "true_severity": alert.true_severity,
            "visible_severity": alert.visible_severity,
            "confidence": alert.confidence,
            "alert_type": alert.alert_type,
            "age": alert.age,
            "action_taken": action.action_type,
            "is_correlated": alert.is_correlated,
            "is_false_positive": alert.true_severity < _FALSE_POSITIVE_THRESHOLD,
            "correlation_group_index": self._alert_to_chain.get(alert.id),
        }
        return self.process_step(alert_data, {})

    def record_system_failure(self, alert_id: Optional[str] = None) -> None:
        """Legacy single-failure recorder (kept for backward compatibility)."""
        self.record_failures(1)

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    def get_episode_score(self) -> float:
        """
        Return the final normalised score, strictly inside (0, 1): the
        linear map at the end ensures it is never exactly 0.0 or 1.0.
        """
        chain_score = sum(c.outcome_score() for c in self._chains.values())
        max_chain = sum(c.max_possible() for c in self._chains.values())
        isolation = min(
            self._isolation_correct * _ISOLATION_BONUS_PER_ALERT,
            _ISOLATION_BONUS_CAP,
        )
        denominator = max(max_chain, 1.0)
        # Cap raw at 1.0: the isolation bonus is added to the numerator only,
        # so without the cap it could push raw slightly above 1.
        raw = min(1.0, (chain_score + isolation) / denominator)
        stability = self._stability_score(self._system_failures)
        # raw * stability is in [0, 1]; map it linearly to [0.01, 0.99].
        mapped = (raw * stability * 0.98) + 0.01
        return float(round(mapped, 4))

    def passed(self) -> bool:
        """Return True if the agent meets the hard-task success threshold."""
        return self.get_episode_score() >= SUCCESS_THRESHOLD

    def calculate_correlation_detection_rate(self) -> float:
        """
        Fraction of chains that were successfully stopped (at any position).

        Returns 1.0 when no chains exist (nothing to detect).
        """
        if not self._chains:
            return 1.0
        stopped = sum(
            1 for c in self._chains.values()
            if c.completed and not c.hit_failure
        )
        return stopped / len(self._chains)

    def calculate_stability_score(self) -> float:
        """Return the stability multiplier for the current failure count."""
        return self._stability_score(self._system_failures)

    # ------------------------------------------------------------------
    # Metrics
    # ------------------------------------------------------------------

    def get_metrics(self) -> Dict[str, Any]:
        """Return a full breakdown of episode performance."""
        score = self.get_episode_score()
        corr_rate = self.calculate_correlation_detection_rate()
        stability = self.calculate_stability_score()

        chain_details = []
        for cid, rec in sorted(self._chains.items()):
            chain_details.append({
                "chain_id": cid,
                "length": len(rec.alert_ids),
                "max_severity": rec.max_severity,
                "stop_position": rec.stop_position,
                "hit_failure": rec.hit_failure,
                "outcome_score": rec.outcome_score(),
            })

        breakdown: Dict[str, int] = {
            "INVESTIGATE": 0, "IGNORE": 0, "ESCALATE": 0, "DELAY": 0,
        }
        for h in self._action_history:
            breakdown[h["action"]] = breakdown.get(h["action"], 0) + 1

        return {
            "overall_score": score,
            "passed": self.passed(),
            "success_threshold": SUCCESS_THRESHOLD,
            "chain_score": sum(c.outcome_score() for c in self._chains.values()),
            "max_chain_score": sum(c.max_possible() for c in self._chains.values()),
            "correlation_detection_rate": corr_rate,
            "total_chains": len(self._chains),
            "chains_stopped": sum(1 for c in self._chains.values()
                                  if c.completed and not c.hit_failure),
            "chains_at_trigger": sum(1 for c in self._chains.values()
                                     if c.stop_position == 0),
            "chains_hit_failure": sum(1 for c in self._chains.values()
                                      if c.hit_failure),
            "chain_details": chain_details,
            "isolation_correct": self._isolation_correct,
            "isolation_total": self._isolation_total,
            "system_failures": self._system_failures,
            "stability_score": stability,
            "total_actions": self._total_actions,
            "action_breakdown": breakdown,
        }

    # ------------------------------------------------------------------
    # Housekeeping
    # ------------------------------------------------------------------

    def reset(self) -> None:
        """Reset all state for a new episode."""
        self._chains = {}
        self._alert_to_chain = {}
        self._isolation_correct = 0
        self._isolation_total = 0
        self._system_failures = 0
        self._action_history = []
        self._total_actions = 0

    def __repr__(self) -> str:
        score = self.get_episode_score()
        corr_r = self.calculate_correlation_detection_rate()
        return (
            f"HardTaskGrader(score={score:.3f}, "
            f"failures={self._system_failures}, "
            f"chains={len(self._chains)}, "
            f"detection_rate={corr_r:.3f}, "
            f"passed={self.passed()})"
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _handle_correlated(
        self,
        alert_id: str,
        true_sev: float,
        action_type: str,
        chain_idx: int,
    ) -> float:
        """
        Update the chain record based on the action and return the
        immediate contribution.

        INVESTIGATE or ESCALATE -> the chain is stopped at this alert.
        IGNORE or DELAY         -> the chain continues (a child will spawn
                                   on a later step).
        """
        record = self._chains[chain_idx]
        position = record.position_of(alert_id)
        if position is None:
            position = len(record.alert_ids)  # fallback: treat as end of chain
        record.max_severity = max(record.max_severity, true_sev)

        proactive = action_type in ("INVESTIGATE", "ESCALATE")
        if proactive:
            # Agent stopped the chain at this position
            record.mark_stopped(position, true_sev)
            base = _CHAIN_STOP_REWARDS.get(position, _CHAIN_FAILURE_REWARD)
            timing = _TIMING_PENALTY_PER_STEP * min(position, _MAX_TIMING_PENALTY_STEPS)
            contrib = max(0.0, base - timing) * max(true_sev, 0.5)
        else:
            # Agent missed this alert, so the chain propagates.
            # Give a small negative signal so the agent learns.
            if true_sev >= _CRITICAL_THRESHOLD:
                contrib = -0.30 * true_sev
            else:
                contrib = -0.10
        return contrib
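
    # Illustrative contributions from the rule above (true severity 0.90):
    #     INVESTIGATE at position 0 -> 1.00 * 0.90           = +0.90
    #     ESCALATE    at position 1 -> (0.70 - 0.05) * 0.90  = +0.585
    #     IGNORE (critical alert)   -> -0.30 * 0.90          = -0.27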

    def _handle_independent(self, true_sev: float, action_type: str) -> float:
        """
        Score a non-correlated alert action and accumulate the isolation bonus.

        Returns a small contribution that feeds into the isolation bonus pool.
        """
        self._isolation_total += 1
        is_correct = self._independent_correct(action_type, true_sev)
        if is_correct:
            self._isolation_correct += 1
            return _ISOLATION_BONUS_PER_ALERT
        return 0.0

    @staticmethod
    def _independent_correct(action_type: str, true_sev: float) -> bool:
        """Correctness rule for non-correlated alerts (same as the easy task)."""
        if true_sev >= _CRITICAL_THRESHOLD:
            return action_type in ("INVESTIGATE", "ESCALATE")
        if true_sev < _FALSE_POSITIVE_THRESHOLD:
            return action_type == "IGNORE"
        # Medium severity
        return action_type in ("INVESTIGATE", "ESCALATE")

    @staticmethod
    def _stability_score(failures: int) -> float:
        """Step-function stability multiplier."""
        for threshold, score in _STABILITY_BY_FAILURES:
            if failures <= threshold:
                return score
        return _STABILITY_FLOOR
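
    # Illustrative outcomes of the independent-correctness rule above:
    #     true_severity 0.90 + ESCALATE -> correct   (critical)
    #     true_severity 0.20 + IGNORE   -> correct   (likely false positive)
    #     true_severity 0.55 + DELAY    -> incorrect (medium needs action)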


# ---------------------------------------------------------------------------
# Evaluation helper
# ---------------------------------------------------------------------------

def run_episode_evaluation(
    agent,
    env,
    num_episodes: int = 10,
    seed_offset: int = 0,
    verbose: bool = False,
) -> Dict[str, Any]:
    """
    Run multiple episodes and return aggregated grading results.

    Args:
        agent: Agent with an .act(observation) -> Action method.
        env: AdaptiveAlertTriageEnv(task_id="hard") instance.
        num_episodes: Number of episodes to run.
        seed_offset: Added to the episode index for the reset seed.
        verbose: Print a per-episode summary when True.

    Returns:
        Dict with keys: mean_score, std_score, min_score, max_score,
        success_rate, mean_failures, mean_detection_rate,
        episode_scores, episode_metrics.
    """
    episode_scores: List[float] = []
    episode_metrics: List[Dict[str, Any]] = []

    for ep in range(num_episodes):
        grader = HardTaskGrader()
        obs = env.reset(seed=seed_offset + ep)
        done = False

        while not done:
            if not obs.alerts:
                break
            action = agent.act(obs)
            obs, _reward, done, info = env.step(action)

            # 1. Update chain knowledge (MUST happen before process_step)
            grader.update_correlation_state(info.get("correlation_groups", []))
            # 2. Grade actions
            for alert_data in info.get("processed_alerts", []):
                grader.process_step(alert_data, info)
            # 3. Record failures
            grader.record_failures(info.get("failures_this_step", 0))

        score = grader.get_episode_score()
        metrics = grader.get_metrics()
        episode_scores.append(score)
        episode_metrics.append(metrics)

        if verbose:
            print(
                f"  ep {ep + 1:02d}  score={score:.3f}  "
                f"failures={metrics['system_failures']}  "
                f"chains={metrics['chains_stopped']}/{metrics['total_chains']}  "
                f"passed={metrics['passed']}"
            )

    scores_arr = np.array(episode_scores)
    fail_arr = np.array([m["system_failures"] for m in episode_metrics])
    det_arr = np.array([m["correlation_detection_rate"] for m in episode_metrics])

    return {
        "mean_score": float(scores_arr.mean()),
        "std_score": float(scores_arr.std()),
        "min_score": float(scores_arr.min()),
        "max_score": float(scores_arr.max()),
        "success_rate": float((scores_arr >= SUCCESS_THRESHOLD).mean()),
        "mean_failures": float(fail_arr.mean()),
        "mean_detection_rate": float(det_arr.mean()),
        "episode_scores": episode_scores,
        "episode_metrics": episode_metrics,
    }
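
# Typical usage (a sketch; assumes an `agent` object and the hard-task env
# from adaptive_alert_triage are available in scope):
#
#     env = AdaptiveAlertTriageEnv(task_id="hard")
#     results = run_episode_evaluation(agent, env, num_episodes=10, verbose=True)
#     print(results["mean_score"], results["success_rate"])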


# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    print("HardTaskGrader self-test\n" + "=" * 60)

    from adaptive_alert_triage.models import Alert, Action, Reward

    def _alert(aid: str, true_sev: float, correlated: bool = False) -> Alert:
        return Alert(
            id=aid, visible_severity=true_sev * 0.95, confidence=0.85,
            alert_type="CPU", age=1, true_severity=true_sev,
            is_correlated=correlated,
        )

    # -- Scenario 1: Agent stops chain at trigger ---------------------------
    print("\n[Scenario 1] Agent catches trigger alert -> chain stops immediately")
    grader = HardTaskGrader()
    grader.update_correlation_state([["cpu_01", "mem_01", "app_01"]])
    a = _alert("cpu_01", 0.80, correlated=True)
    contrib = grader.grade_action(Action(alert_id="cpu_01", action_type="INVESTIGATE"), a, Reward(value=0))
    print(f"  Trigger INVESTIGATE contrib={contrib:+.4f}")
    grader.record_failures(0)
    score = grader.get_episode_score()
    m = grader.get_metrics()
    print(f"  Episode score            : {score:.4f} (expected >= 0.5)")
    print(f"  Chains stopped at trigger: {m['chains_at_trigger']}")
    assert score >= 0.50, f"Expected >= 0.50, got {score}"
    assert m["chains_at_trigger"] == 1

    # -- Scenario 2: Agent misses trigger, child spawns, agent catches child --
    print("\n[Scenario 2] Agent misses trigger, catches first child")
    grader2 = HardTaskGrader()
    grader2.update_correlation_state([["cpu_02", "mem_02", "app_02"]])
    # Miss the trigger
    a1 = _alert("cpu_02", 0.78, correlated=True)
    c1 = grader2.grade_action(Action(alert_id="cpu_02", action_type="IGNORE"), a1, Reward(value=0))
    print(f"  Trigger IGNORE    contrib={c1:+.4f} (penalty expected)")
    # Child alert spawned, agent catches it
    a2 = _alert("mem_02", 0.85, correlated=True)
    c2 = grader2.grade_action(Action(alert_id="mem_02", action_type="INVESTIGATE"), a2, Reward(value=0))
    print(f"  Child INVESTIGATE contrib={c2:+.4f}")
    grader2.record_failures(0)
    score2 = grader2.get_episode_score()
    print(f"  Episode score            : {score2:.4f} (expected < scenario 1 score)")
    assert score2 < score, "Catching child should score less than catching trigger"

    # -- Scenario 3: Full failure: agent misses entire chain ----------------
    print("\n[Scenario 3] Agent misses all chain alerts -> system failure")
    grader3 = HardTaskGrader()
    grader3.update_correlation_state([["cpu_03", "mem_03", "app_03"]])
    for aid, sev in [("cpu_03", 0.80), ("mem_03", 0.88), ("app_03", 0.92)]:
        a = _alert(aid, sev, correlated=True)
        grader3.grade_action(Action(alert_id=aid, action_type="IGNORE"), a, Reward(value=0))
    grader3.record_failures(1)  # system failure registered
    score3 = grader3.get_episode_score()
    print(f"  Episode score            : {score3:.4f} (expected close to 0)")
    assert score3 < 0.3, f"Missed entire chain + failure should score < 0.3, got {score3}"

    # -- Determinism check ---------------------------------------------------
    print("\n[Determinism] Same inputs -> same score")

    def _run_fixed() -> float:
        g = HardTaskGrader()
        g.update_correlation_state([["x1", "x2"]])
        a = _alert("x1", 0.82, correlated=True)
        g.grade_action(Action(alert_id="x1", action_type="INVESTIGATE"), a, Reward(value=0))
        g.record_failures(0)
        return g.get_episode_score()

    s_a, s_b = _run_fixed(), _run_fixed()
    assert s_a == s_b, f"Non-deterministic: {s_a} != {s_b}"
    print(f"  Score both runs: {s_a:.6f}")

    print("\n" + "=" * 60)
    print("All self-tests passed!")