Hacktrix-121's picture
grader fixes
c18a9d1
"""
Adaptive Alert Triage Environment — OpenEnv-compliant OpenEnv Environment
Implements a partially observable RL environment that simulates a real-world
DevOps / SOC alert-triage workflow. An agent must process a continuous stream
of system alerts under time and resource constraints, deciding for each alert:
INVESTIGATE — allocate resources to diagnose (costly)
IGNORE — dismiss as noise (efficient for false positives)
ESCALATE — route to specialist team
DELAY — defer to the next time-step
The environment supports three difficulty tasks:
easy (30 steps, no resource constraint, 10 % correlation probability)
medium (40 steps, K=3 investigations/step, 20 % correlation probability)
hard (50 steps, K=3 investigations/step, 40 % correlation probability,
stricter failure threshold)
OpenEnv interface
-----------------
reset(seed?, options?) -> Observation
step(action) -> (Observation, Reward, done, info)
state() -> EpisodeState
Info dict keys (required by graders)
-------------------------------------
processed_alerts : list[dict] — ground-truth data for every action taken
this step (alert_id, true_severity,
is_false_positive, action_taken, etc.)
correlation_groups: list[list] — current correlated-chain groups (alert IDs)
failures_this_step: int — failures triggered this step
system_failure : bool — True if the episode is in a failure state
step : int — current step index
cumulative_reward : float — total reward so far
failures_count : int — total failures so far
action_correct : bool — whether the most recent action was optimal
"""
from __future__ import annotations
from collections import deque
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import openenv_shim as gym
from openenv_shim import spaces
from adaptive_alert_triage.models import (
Action,
Alert,
EpisodeState,
Observation,
Reward,
)
from adaptive_alert_triage import utils
# Import reward calculation with graceful fallback for development mode
import os as _os
import sys as _sys
try:
from rewards.reward import calculate_reward
except ImportError:
_project_root = _os.path.dirname(
_os.path.dirname(_os.path.dirname(_os.path.abspath(__file__)))
)
if _project_root not in _sys.path:
_sys.path.insert(0, _project_root)
from rewards.reward import calculate_reward # type: ignore[no-redef]
# ---------------------------------------------------------------------------
# Task configurations
# ---------------------------------------------------------------------------
_TASK_CONFIGS: Dict[str, Dict[str, Any]] = {
"easy": {
"max_steps": 10,
"failure_threshold": 2,
"max_investigations": None, # unconstrained
"correlation_probability": 0.10,
"description": "Basic alert prioritisation — no resource constraint.",
},
"medium": {
"max_steps": 15,
"failure_threshold": 3,
"max_investigations": 3, # K = 3 per step
"correlation_probability": 0.20,
"description": "Resource-constrained triage — K=3 investigations/step.",
},
"hard": {
"max_steps": 20,
"failure_threshold": 2, # stricter
"max_investigations": 3,
"correlation_probability": 0.40,
"description": (
"Cascading-failure prevention — correlated alerts, delayed failures, "
"hidden severity, strict failure threshold."
),
},
}
# ---------------------------------------------------------------------------
# Main environment class
# ---------------------------------------------------------------------------
class AdaptiveAlertTriageEnv(gym.Env):
"""
OpenEnv environment for adaptive alert triage.
Parameters
----------
task_id : str
Difficulty level: ``"easy"``, ``"medium"``, or ``"hard"``.
max_steps : int, optional
Override the task-default episode length.
seed : int, optional
Fixed random seed for full reproducibility.
"""
metadata = {"render_modes": ["human", "ansi"]}
# ------------------------------------------------------------------
# Construction
# ------------------------------------------------------------------
def __init__(
self,
task_id: str = "easy",
max_steps: Optional[int] = None,
seed: Optional[int] = None,
) -> None:
super().__init__()
if task_id not in _TASK_CONFIGS:
raise ValueError(
f"Unknown task_id '{task_id}'. "
f"Valid options: {sorted(_TASK_CONFIGS.keys())}"
)
self.task_id: str = task_id
self.config: Dict[str, Any] = dict(_TASK_CONFIGS[task_id])
self.max_steps: int = max_steps or self.config["max_steps"]
self.failure_threshold: int = self.config["failure_threshold"]
self.max_investigations_per_step: Optional[int] = self.config["max_investigations"]
# Episode state — initialised properly in reset()
self.current_step: int = 0
self.alerts: List[Alert] = []
self.failures_count: int = 0
self.cumulative_reward: float = 0.0
self.investigations_used: int = 0
# Hidden state
self.correlation_groups: List[List[str]] = []
# Action history (for state() and checkpointing)
self._action_history: List[Action] = []
# Real-alert ingestion queue (Datadog / Kafka webhook mode)
self.real_alerts_queue: deque = deque(maxlen=50)
# Per-step grading data — populated in step(), consumed by graders
self._processed_alerts_this_step: List[Dict[str, Any]] = []
self._failures_this_step: int = 0
# Seed
self._seed: Optional[int] = seed
if seed is not None:
utils.set_seed(seed)
# OpenEnv spaces (abstract; real actions are Action Pydantic objects)
self.action_space = spaces.Discrete(4) # 4 ActionType values
self.observation_space = spaces.Dict(
{
"system_load": spaces.Box(0.0, 1.0, shape=(1,), dtype=np.float32),
"queue_length": spaces.Box(0, 100, shape=(1,), dtype=np.int32),
"time_remaining": spaces.Box(
0, self.max_steps, shape=(1,), dtype=np.int32
),
}
)
# ------------------------------------------------------------------
# OpenEnv interface — reset
# ------------------------------------------------------------------
def reset(
self,
seed: Optional[int] = None,
options: Optional[Dict[str, Any]] = None,
) -> Observation:
"""
Reset the environment to a clean initial state.
Args:
seed: Override seed for this episode.
options: Reserved for future use (ignored).
Returns:
Initial Observation with no agent-visible hidden fields.
"""
if seed is not None:
self._seed = seed
if self._seed is not None:
utils.set_seed(self._seed)
# Reset all episode counters
self.current_step = 0
self.failures_count = 0
self.cumulative_reward = 0.0
self.investigations_used = 0
self.correlation_groups = []
self._action_history = []
self._processed_alerts_this_step = []
self._failures_this_step = 0
# Generate the initial alert batch
self.alerts = self._generate_initial_alerts()
return self._create_observation()
# ------------------------------------------------------------------
# OpenEnv interface — step
# ------------------------------------------------------------------
def step(
self, action: Action
) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
"""
Execute one environment step.
The agent submits one Action per call; the environment:
1. Validates the alert ID and resource budget.
2. Records ground-truth data for the graders.
3. Calculates the dense reward.
4. Applies the action (removes / keeps alert).
5. Ages remaining alerts.
6. Checks for delayed failures.
7. Generates new alerts (Poisson arrivals + possible correlation chain).
8. Increments step counter and resets per-step budget.
9. Returns (Observation, Reward, done, info).
The ``info`` dict always contains:
processed_alerts — list of ground-truth dicts, one per action
correlation_groups — current correlation chains
failures_this_step — failures triggered this step
system_failure — whether the system is in a failure state
step — current step index
cumulative_reward — total reward this episode
failures_count — total failures this episode
action_correct — whether the action matched the optimal policy
Args:
action: Agent's Action targeting one alert by ID.
Returns:
(next_observation, reward, done, info)
"""
# --- Reset per-step tracking ---
self._processed_alerts_this_step = []
self._failures_this_step = 0
# --- Validate alert ID ---
alert = self._get_alert_by_id(action.alert_id)
if alert is None:
reward = Reward(
value=-5.0,
components={"invalid_action": -5.0},
info={"error": f"Alert ID '{action.alert_id}' not found in queue"},
)
obs = self._create_observation()
return obs, reward, True, self._build_info(
action_correct=False,
extra={"error": "Invalid alert ID"},
)
# --- Resource-budget enforcement ---
if (
self.max_investigations_per_step is not None
and action.action_type == "INVESTIGATE"
):
if self.investigations_used >= self.max_investigations_per_step:
reward = Reward(
value=-3.0,
components={"resource_budget_exceeded": -3.0},
info={
"error": "Investigation budget exhausted for this step",
"budget": self.max_investigations_per_step,
"used": self.investigations_used,
},
)
obs = self._create_observation()
return obs, reward, False, self._build_info(
action_correct=False,
extra={"resource_constraint_violated": True},
)
self.investigations_used += 1
# --- Record ground-truth for graders BEFORE removing the alert ---
processed: Dict[str, Any] = {
"alert_id": alert.id,
"true_severity": alert.true_severity,
"visible_severity": alert.visible_severity,
"confidence": alert.confidence,
"alert_type": alert.alert_type,
"age": alert.age,
"is_correlated": alert.is_correlated,
"is_false_positive": bool(alert.metadata.get("false_positive", False)),
"action_taken": action.action_type,
"correlation_group_index": self._find_correlation_group(alert.id),
}
self._processed_alerts_this_step.append(processed)
# --- Track action history ---
self._action_history.append(action)
# --- Calculate dense reward ---
reward = calculate_reward(action, alert, self.config)
self.cumulative_reward += reward.value
# --- Apply action to alert queue ---
self._process_action(action, alert)
# --- Age all remaining unresolved alerts ---
self._age_alerts()
# --- Check for failures triggered by aged critical alerts ---
self._failures_this_step = self._check_for_failures()
self.failures_count += self._failures_this_step
# --- Generate new alerts (Poisson arrivals + possible chain) ---
if utils.should_generate_new_alerts(self.current_step, len(self.alerts)):
new_alerts = self._generate_new_alerts()
self.alerts.extend(new_alerts)
# --- Advance step and reset per-step investigation budget ---
self.current_step += 1
self.investigations_used = 0
# --- Termination check ---
done: bool = self._is_terminal()
# --- Build next observation (hidden fields masked) ---
obs = self._create_observation()
# --- Determine overall failure state ---
system_in_failure: bool = (
self.failures_count >= self.failure_threshold
or self._failures_this_step > 0
)
info = self._build_info(
action_correct=bool(reward.info.get("action_correct", False)),
extra={
"system_failure": system_in_failure,
"alert_handled": alert.id,
},
)
return obs, reward, done, info
# ------------------------------------------------------------------
# OpenEnv interface — state
# ------------------------------------------------------------------
def state(self) -> EpisodeState:
"""
Return the complete internal episode state (visible + hidden).
Used by evaluation scripts, replay tools, and the hard-task grader
for root-cause analysis. NOT intended to be passed to the agent.
Returns:
EpisodeState with full ground-truth information.
"""
hidden: Dict[str, Any] = {
"true_severities": {a.id: a.true_severity for a in self.alerts},
"correlation_groups": [list(g) for g in self.correlation_groups],
"false_positives": [
a.id for a in self.alerts
if a.metadata.get("false_positive", False)
],
# Pending failures: alerts that are critical AND close to the age threshold
"pending_failures": {
a.id: utils.CRITICAL_AGE_THRESHOLD - a.age
for a in self.alerts
if utils.is_critical_alert(a)
and a.age < utils.CRITICAL_AGE_THRESHOLD
},
}
return EpisodeState(
observation=self._create_observation(),
hidden_state=hidden,
cumulative_reward=self.cumulative_reward,
failures_count=self.failures_count,
actions_taken=[a.model_dump() for a in self._action_history],
seed=self._seed,
)
# ------------------------------------------------------------------
# Internal helpers — observation construction
# ------------------------------------------------------------------
def _create_observation(self) -> Observation:
"""
Build the agent-facing Observation by masking all hidden fields.
true_severity and is_correlated are zeroed-out; metadata is stripped.
The agent must infer hidden information from visible_severity,
confidence, alert_type, and age alone.
"""
system_load: float = utils.calculate_system_load(len(self.alerts))
visible_alerts: List[Alert] = []
for a in self.alerts:
visible_alerts.append(
Alert(
id=a.id,
visible_severity=a.visible_severity,
confidence=a.confidence,
alert_type=a.alert_type,
age=a.age,
# Hidden fields zeroed out
true_severity=0.0,
is_correlated=False,
metadata={},
)
)
resource_budget: Optional[int] = None
if self.max_investigations_per_step is not None:
resource_budget = self.max_investigations_per_step - self.investigations_used
return Observation(
alerts=visible_alerts,
system_load=system_load,
queue_length=len(self.alerts),
time_remaining=max(0, self.max_steps - self.current_step),
episode_step=self.current_step,
resource_budget=resource_budget,
)
# ------------------------------------------------------------------
# Internal helpers — alert generation
# ------------------------------------------------------------------
def _generate_initial_alerts(self) -> List[Alert]:
"""
Generate the starting alert batch for a fresh episode.
Real alerts from the ingestion queue are prioritised; any remaining
slots are filled with synthetic alerts.
"""
num_initial: int = int(np.random.randint(3, 7))
alerts: List[Alert] = []
# Drain real alerts first
while self.real_alerts_queue and len(alerts) < num_initial:
raw = self.real_alerts_queue.popleft()
alerts.append(self._ingest_real_alert(raw))
# Fill with synthetic
for i in range(len(alerts), num_initial):
alerts.append(
utils.generate_alert(step=0, alert_index=i)
)
return alerts
def _generate_new_alerts(self) -> List[Alert]:
"""
Generate alerts to append to the queue this step.
If real alerts are queued they are processed first (no synthetic
alerts generated that step). Otherwise, a Poisson-sampled batch of
independent alerts is generated, with a task-configured probability
that a correlated chain replaces the batch entirely.
"""
# Priority: real ingest queue
if self.real_alerts_queue:
raw = self.real_alerts_queue.popleft()
return [self._ingest_real_alert(raw)]
# Correlated chain vs independent batch
if np.random.random() < self.config["correlation_probability"]:
chain_alerts = utils.generate_correlated_alerts(
self.current_step, num_alerts=3
)
self.correlation_groups.append([a.id for a in chain_alerts])
return chain_alerts
num_new: int = utils.sample_num_new_alerts()
return [
utils.generate_alert(
step=self.current_step,
alert_index=i,
)
for i in range(num_new)
]
@staticmethod
def _ingest_real_alert(raw: Dict[str, Any]) -> Alert:
"""
Convert a raw real-alert dict into an Alert with synthesised ground truth.
Ground truth is estimated by adding Gaussian noise to visible_severity,
reflecting that real monitoring tools provide imperfect severity scores.
"""
true_severity: float = float(
np.clip(
float(raw["visible_severity"]) + np.random.normal(0.0, 0.10),
0.0,
1.0,
)
)
return Alert(
id=raw["id"],
visible_severity=float(raw["visible_severity"]),
confidence=float(raw["confidence"]),
alert_type=raw["type"],
age=0,
true_severity=true_severity,
is_correlated=False,
metadata={"source": "real_ingest", "raw": raw},
)
# ------------------------------------------------------------------
# Internal helpers — action processing
# ------------------------------------------------------------------
def _process_action(self, action: Action, alert: Alert) -> None:
"""
Apply the agent's action to the alert queue.
INVESTIGATE, ESCALATE, and IGNORE all resolve the alert (remove it
from the queue). DELAY keeps the alert in the queue; its age will
be incremented by _age_alerts().
"""
if action.action_type in ("INVESTIGATE", "ESCALATE", "IGNORE"):
self.alerts = [a for a in self.alerts if a.id != alert.id]
# DELAY: no-op — alert remains; age increment handled in _age_alerts()
def _age_alerts(self) -> None:
"""Increment the age of every unresolved alert by one step."""
for alert in self.alerts:
alert.age += 1
def _check_for_failures(self) -> int:
"""
Detect and remove alerts that have caused system failures.
A failure occurs when a critical alert (true_severity ≥ 0.75) has
been in the queue for CRITICAL_AGE_THRESHOLD or more steps without
being resolved. Each such alert contributes one failure event.
Returns:
Number of new failure events detected this step.
"""
failures: int = 0
failed_ids: List[str] = []
for alert in self.alerts:
if (
utils.is_critical_alert(alert)
and alert.age >= utils.CRITICAL_AGE_THRESHOLD
):
failures += 1
failed_ids.append(alert.id)
# Remove failed alerts (they've escalated out of the triage queue)
if failed_ids:
self.alerts = [a for a in self.alerts if a.id not in failed_ids]
return failures
# ------------------------------------------------------------------
# Internal helpers — utilities
# ------------------------------------------------------------------
def _get_alert_by_id(self, alert_id: str) -> Optional[Alert]:
"""Return the Alert with the given ID, or None if not found."""
for alert in self.alerts:
if alert.id == alert_id:
return alert
return None
def _find_correlation_group(self, alert_id: str) -> Optional[int]:
"""
Return the index of the correlation group that contains alert_id, or None.
Used to populate the ``correlation_group_index`` field in processed_alerts
so the hard-task grader can score root-cause identification.
"""
for idx, group in enumerate(self.correlation_groups):
if alert_id in group:
return idx
return None
def _is_terminal(self) -> bool:
"""Return True if the episode should end."""
return (
self.current_step >= self.max_steps
or self.failures_count >= self.failure_threshold
)
def _build_info(
self,
action_correct: bool,
extra: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
Assemble the standard info dict returned from step().
Always includes the keys required by all three task graders.
Additional keys can be merged in via ``extra``.
"""
info: Dict[str, Any] = {
# Core grading keys (required)
"processed_alerts": list(self._processed_alerts_this_step),
"correlation_groups": [list(g) for g in self.correlation_groups],
"failures_this_step": self._failures_this_step,
"system_failure": self.failures_count >= self.failure_threshold
or self._failures_this_step > 0,
# Convenience telemetry
"step": self.current_step,
"cumulative_reward": self.cumulative_reward,
"failures_count": self.failures_count,
"action_correct": action_correct,
}
if extra:
info.update(extra)
return info
# ------------------------------------------------------------------
# OpenEnv render
# ------------------------------------------------------------------
def render(self, mode: str = "human") -> Optional[str]:
"""
Render a text summary of the current environment state.
Args:
mode: ``"human"`` (prints to stdout) or ``"ansi"`` (returns string).
Returns:
String if mode is ``"ansi"``, otherwise None.
"""
lines = [
f"\n=== Step {self.current_step}/{self.max_steps}"
f" [{self.task_id}] ===",
f" Failures : {self.failures_count}/{self.failure_threshold}",
f" Cum. reward : {self.cumulative_reward:+.1f}",
f" Active alerts: {len(self.alerts)}",
]
if self.max_investigations_per_step is not None:
lines.append(
f" Inv. budget : "
f"{self.max_investigations_per_step - self.investigations_used}"
f"/{self.max_investigations_per_step} remaining"
)
if self.alerts:
lines.append("\n Alerts (first 5):")
for a in self.alerts[:5]:
lines.append(
f" {a.id} sev={a.visible_severity:.2f}"
f" conf={a.confidence:.2f}"
f" type={a.alert_type:<12}"
f" age={a.age}"
)
if len(self.alerts) > 5:
lines.append(f" … and {len(self.alerts) - 5} more")
output = "\n".join(lines) + "\n"
if mode == "human":
print(output)
return None
return output
# ---------------------------------------------------------------------------
# Quick demo
# ---------------------------------------------------------------------------
def main() -> None:
"""Run a short demo episode with a simple heuristic policy."""
print("Adaptive Alert Triage Environment — Demo\n")
env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
obs: Observation = env.reset()
print(f"Initial observation: {len(obs.alerts)} alerts "
f"(system_load={obs.system_load:.2f})\n")
done = False
step_count = 0
while not done and step_count < 5:
env.render()
if not obs.alerts:
print("No alerts in queue — nothing to handle.")
break
# Heuristic: pick the alert with the highest visible_severity
best_alert = max(obs.alerts, key=lambda a: a.visible_severity)
action = Action(
alert_id=best_alert.id,
action_type=(
"INVESTIGATE" if best_alert.visible_severity >= 0.7 else "IGNORE"
),
)
obs, reward, done, info = env.step(action)
print(
f" Action: {action.action_type}{best_alert.id}"
f" Reward: {reward.value:+.1f}"
f" Correct: {info.get('action_correct', '?')}"
)
step_count += 1
print(f"\nDemo finished after {step_count} steps.")
print(f"Final cumulative reward : {env.cumulative_reward:+.1f}")
print(f"Total system failures : {env.failures_count}")
if __name__ == "__main__":
main()