Spaces:
Sleeping
Sleeping
| """ | |
| Adaptive Alert Triage Environment — an OpenEnv-compliant RL environment. | |
| Implements a partially observable RL environment that simulates a real-world | |
| DevOps / SOC alert-triage workflow. An agent must process a continuous stream | |
| of system alerts under time and resource constraints, deciding for each alert: | |
| INVESTIGATE — allocate resources to diagnose (costly) | |
| IGNORE — dismiss as noise (efficient for false positives) | |
| ESCALATE — route to specialist team | |
| DELAY — defer to the next time-step | |
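| The environment supports three difficulty tasks (default episode lengths; | |
| override them with the ``max_steps`` constructor argument): | |
| easy (10 steps, no resource constraint, 10 % correlation probability, | |
| failure threshold 2) | |
| medium (15 steps, K=3 investigations/step, 20 % correlation probability, | |
| failure threshold 3) | |
| hard (20 steps, K=3 investigations/step, 40 % correlation probability, | |
| failure threshold 2 — stricter than medium) | |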
| OpenEnv interface | |
| ----------------- | |
| reset(seed?, options?) -> Observation | |
| step(action) -> (Observation, Reward, done, info) | |
| state() -> EpisodeState | |
| Info dict keys (required by graders) | |
| ------------------------------------- | |
| processed_alerts : list[dict] — ground-truth data for the action taken | |
| this step (alert_id, true_severity, | |
| is_false_positive, action_taken, etc.); | |
| empty when the action was rejected | |
| correlation_groups: list[list] — every correlated-chain group (alert IDs) | |
| generated so far this episode | |
| failures_this_step: int — failures triggered this step | |
| system_failure : bool — True if a failure occurred this step or the | |
| failure threshold has been reached | |
| step : int — current step index | |
| cumulative_reward : float — total reward so far | |
| failures_count : int — total failures so far | |
| action_correct : bool — whether the most recent action was optimal | |
| """ | |
| from __future__ import annotations | |
| from collections import deque | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| import openenv_shim as gym | |
| from openenv_shim import spaces | |
| from adaptive_alert_triage.models import ( | |
| Action, | |
| Alert, | |
| EpisodeState, | |
| Observation, | |
| Reward, | |
| ) | |
| from adaptive_alert_triage import utils | |
| # Import reward calculation with graceful fallback for development mode | |
| import os as _os | |
| import sys as _sys | |
| try: | |
| from rewards.reward import calculate_reward | |
| except ImportError: | |
| _project_root = _os.path.dirname( | |
| _os.path.dirname(_os.path.dirname(_os.path.abspath(__file__))) | |
| ) | |
| if _project_root not in _sys.path: | |
| _sys.path.insert(0, _project_root) | |
| from rewards.reward import calculate_reward # type: ignore[no-redef] | |
| # --------------------------------------------------------------------------- | |
| # Task configurations | |
| # --------------------------------------------------------------------------- | |
# Per-task defaults, keyed by task_id. The environment copies the selected
# entry, so mutating an instance's config never leaks back into this table.
_TASK_CONFIGS: Dict[str, Dict[str, Any]] = dict(
    easy=dict(
        max_steps=10,
        failure_threshold=2,
        max_investigations=None,  # no per-step investigation cap
        correlation_probability=0.10,
        description="Basic alert prioritisation — no resource constraint.",
    ),
    medium=dict(
        max_steps=15,
        failure_threshold=3,
        max_investigations=3,  # K = 3 investigations per step
        correlation_probability=0.20,
        description="Resource-constrained triage — K=3 investigations/step.",
    ),
    hard=dict(
        max_steps=20,
        failure_threshold=2,  # stricter than medium (3); same as easy
        max_investigations=3,
        correlation_probability=0.40,
        description=(
            "Cascading-failure prevention — correlated alerts, delayed failures, "
            "hidden severity, strict failure threshold."
        ),
    ),
)
| # --------------------------------------------------------------------------- | |
| # Main environment class | |
| # --------------------------------------------------------------------------- | |
class AdaptiveAlertTriageEnv(gym.Env):
    """
    OpenEnv environment for adaptive alert triage.

    Every call to :meth:`step` handles exactly one alert and then advances
    the simulation by one time-step (ageing, failure checks, new arrivals).

    Parameters
    ----------
    task_id : str
        Difficulty level: ``"easy"``, ``"medium"``, or ``"hard"``.
    max_steps : int, optional
        Override the task-default episode length; must be >= 1 when given.
    seed : int, optional
        Fixed random seed for full reproducibility (applied through
        ``utils.set_seed`` here and again on every :meth:`reset`).

    Raises
    ------
    ValueError
        If ``task_id`` is unknown or ``max_steps`` is smaller than 1.
    """

    metadata = {"render_modes": ["human", "ansi"]}

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def __init__(
        self,
        task_id: str = "easy",
        max_steps: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> None:
        super().__init__()

        if task_id not in _TASK_CONFIGS:
            raise ValueError(
                f"Unknown task_id '{task_id}'. "
                f"Valid options: {sorted(_TASK_CONFIGS.keys())}"
            )
        if max_steps is not None and max_steps < 1:
            raise ValueError(
                f"max_steps must be a positive integer, got {max_steps!r}"
            )

        self.task_id: str = task_id
        # Copy so per-instance tweaks never mutate the shared task table.
        self.config: Dict[str, Any] = dict(_TASK_CONFIGS[task_id])

        # ``is not None`` (not ``or``) so an explicit override is never
        # silently replaced by the task default.
        self.max_steps: int = (
            max_steps if max_steps is not None else self.config["max_steps"]
        )
        self.failure_threshold: int = self.config["failure_threshold"]
        self.max_investigations_per_step: Optional[int] = self.config["max_investigations"]

        # Episode state — initialised properly in reset()
        self.current_step: int = 0
        self.alerts: List[Alert] = []
        self.failures_count: int = 0
        self.cumulative_reward: float = 0.0
        self.investigations_used: int = 0

        # Hidden state: one list of alert IDs per generated correlated chain
        self.correlation_groups: List[List[str]] = []

        # Action history (for state() and checkpointing)
        self._action_history: List[Action] = []

        # Real-alert ingestion queue (Datadog / Kafka webhook mode); bounded
        # so the oldest raw alerts are dropped once 50 are pending.
        self.real_alerts_queue: deque = deque(maxlen=50)

        # Per-step grading data — populated in step(), consumed by graders
        self._processed_alerts_this_step: List[Dict[str, Any]] = []
        self._failures_this_step: int = 0

        # Seed
        self._seed: Optional[int] = seed
        if seed is not None:
            utils.set_seed(seed)

        # OpenEnv spaces (abstract; real actions are Action Pydantic objects)
        self.action_space = spaces.Discrete(4)  # 4 ActionType values
        self.observation_space = spaces.Dict(
            {
                "system_load": spaces.Box(0.0, 1.0, shape=(1,), dtype=np.float32),
                "queue_length": spaces.Box(0, 100, shape=(1,), dtype=np.int32),
                "time_remaining": spaces.Box(
                    0, self.max_steps, shape=(1,), dtype=np.int32
                ),
            }
        )

    # ------------------------------------------------------------------
    # OpenEnv interface — reset
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Observation:
        """
        Reset the environment to a clean initial state.

        Real alerts still waiting in ``real_alerts_queue`` are kept and are
        drained first into the new episode's initial batch.

        Args:
            seed: Override seed for this episode; it also becomes the seed
                used by later resets that pass no seed.
            options: Reserved for future use (ignored).

        Returns:
            Initial Observation with all hidden alert fields masked.
        """
        if seed is not None:
            self._seed = seed
        if self._seed is not None:
            utils.set_seed(self._seed)

        # Reset all episode counters
        self.current_step = 0
        self.failures_count = 0
        self.cumulative_reward = 0.0
        self.investigations_used = 0
        self.correlation_groups = []
        self._action_history = []
        self._processed_alerts_this_step = []
        self._failures_this_step = 0

        # Generate the initial alert batch
        self.alerts = self._generate_initial_alerts()

        return self._create_observation()

    # ------------------------------------------------------------------
    # OpenEnv interface — step
    # ------------------------------------------------------------------

    def step(
        self, action: Action
    ) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """
        Execute one environment step.

        The agent submits one Action per call; the environment:
            1. Validates the alert ID and resource budget.
            2. Records ground-truth data for the graders.
            3. Calculates the dense reward.
            4. Applies the action (removes / keeps alert).
            5. Ages remaining alerts.
            6. Checks for delayed failures.
            7. Generates new alerts (Poisson arrivals + possible correlation chain).
            8. Increments step counter and resets per-step budget.
            9. Returns (Observation, Reward, done, info).

        Rejected actions short-circuit before step 2:
            * unknown alert ID → reward -5.0 and ``done=True``;
            * investigation budget exhausted → reward -3.0, ``done=False``,
              and time does not advance.
        Penalty rewards are added to ``cumulative_reward`` like any other
        reward, so it always equals the sum of rewards returned so far.

        The ``info`` dict always contains:
            processed_alerts    — list of ground-truth dicts, one per action
            correlation_groups  — all correlation chains generated this episode
            failures_this_step  — failures triggered this step
            system_failure      — failure this step or threshold reached
            step                — current step index
            cumulative_reward   — total reward this episode
            failures_count      — total failures this episode
            action_correct      — whether the action matched the optimal policy

        Args:
            action: Agent's Action targeting one alert by ID.

        Returns:
            (next_observation, reward, done, info)
        """
        # --- Reset per-step tracking ---
        self._processed_alerts_this_step = []
        self._failures_this_step = 0

        # --- Validate alert ID ---
        alert = self._get_alert_by_id(action.alert_id)
        if alert is None:
            reward = Reward(
                value=-5.0,
                components={"invalid_action": -5.0},
                info={"error": f"Alert ID '{action.alert_id}' not found in queue"},
            )
            # Penalties count towards the episode total like any other reward.
            self.cumulative_reward += reward.value
            obs = self._create_observation()
            return obs, reward, True, self._build_info(
                action_correct=False,
                extra={"error": "Invalid alert ID"},
            )

        # --- Resource-budget enforcement ---
        # NOTE(review): ``investigations_used`` is reset to 0 at the end of
        # every successful step() call, and each call handles a single
        # action, so with a budget K >= 1 this rejection branch is
        # unreachable and the observation's ``resource_budget`` always reads
        # K. If the budget is meant to span several actions (a window of
        # steps, or the whole episode), that reset needs rethinking —
        # confirm the intended semantics.
        if (
            self.max_investigations_per_step is not None
            and action.action_type == "INVESTIGATE"
        ):
            if self.investigations_used >= self.max_investigations_per_step:
                reward = Reward(
                    value=-3.0,
                    components={"resource_budget_exceeded": -3.0},
                    info={
                        "error": "Investigation budget exhausted for this step",
                        "budget": self.max_investigations_per_step,
                        "used": self.investigations_used,
                    },
                )
                # Penalties count towards the episode total like any other reward.
                self.cumulative_reward += reward.value
                obs = self._create_observation()
                return obs, reward, False, self._build_info(
                    action_correct=False,
                    extra={"resource_constraint_violated": True},
                )
            self.investigations_used += 1

        # --- Record ground-truth for graders BEFORE removing the alert ---
        processed: Dict[str, Any] = {
            "alert_id": alert.id,
            "true_severity": alert.true_severity,
            "visible_severity": alert.visible_severity,
            "confidence": alert.confidence,
            "alert_type": alert.alert_type,
            "age": alert.age,
            "is_correlated": alert.is_correlated,
            "is_false_positive": bool(alert.metadata.get("false_positive", False)),
            "action_taken": action.action_type,
            "correlation_group_index": self._find_correlation_group(alert.id),
        }
        self._processed_alerts_this_step.append(processed)

        # --- Track action history ---
        self._action_history.append(action)

        # --- Calculate dense reward ---
        reward = calculate_reward(action, alert, self.config)
        self.cumulative_reward += reward.value

        # --- Apply action to alert queue ---
        self._process_action(action, alert)

        # --- Age all remaining unresolved alerts ---
        self._age_alerts()

        # --- Check for failures triggered by aged critical alerts ---
        self._failures_this_step = self._check_for_failures()
        self.failures_count += self._failures_this_step

        # --- Generate new alerts (Poisson arrivals + possible chain) ---
        if utils.should_generate_new_alerts(self.current_step, len(self.alerts)):
            self.alerts.extend(self._generate_new_alerts())

        # --- Advance step and reset per-step investigation budget ---
        self.current_step += 1
        self.investigations_used = 0

        # --- Termination check ---
        done: bool = self._is_terminal()

        # --- Build next observation (hidden fields masked) ---
        obs = self._create_observation()

        # ``system_failure`` is already derived by _build_info from the
        # counters updated above, so it is not recomputed here.
        info = self._build_info(
            action_correct=bool(reward.info.get("action_correct", False)),
            extra={"alert_handled": alert.id},
        )

        return obs, reward, done, info

    # ------------------------------------------------------------------
    # OpenEnv interface — state
    # ------------------------------------------------------------------

    def state(self) -> EpisodeState:
        """
        Return the complete internal episode state (visible + hidden).

        Used by evaluation scripts, replay tools, and the hard-task grader
        for root-cause analysis. NOT intended to be passed to the agent.

        Returns:
            EpisodeState with full ground-truth information.
        """
        hidden: Dict[str, Any] = {
            "true_severities": {a.id: a.true_severity for a in self.alerts},
            "correlation_groups": [list(g) for g in self.correlation_groups],
            "false_positives": [
                a.id for a in self.alerts
                if a.metadata.get("false_positive", False)
            ],
            # Every critical alert that has not failed yet, mapped to the
            # number of further ageing steps before it triggers a failure.
            "pending_failures": {
                a.id: utils.CRITICAL_AGE_THRESHOLD - a.age
                for a in self.alerts
                if utils.is_critical_alert(a)
                and a.age < utils.CRITICAL_AGE_THRESHOLD
            },
        }

        return EpisodeState(
            observation=self._create_observation(),
            hidden_state=hidden,
            cumulative_reward=self.cumulative_reward,
            failures_count=self.failures_count,
            actions_taken=[a.model_dump() for a in self._action_history],
            seed=self._seed,
        )

    # ------------------------------------------------------------------
    # Internal helpers — observation construction
    # ------------------------------------------------------------------

    def _create_observation(self) -> Observation:
        """
        Build the agent-facing Observation by masking all hidden fields.

        Each alert is copied into a fresh Alert with true_severity set to 0.0,
        is_correlated set to False, and empty metadata, so the agent must
        infer hidden information from visible_severity, confidence,
        alert_type, and age alone.
        """
        system_load: float = utils.calculate_system_load(len(self.alerts))

        visible_alerts: List[Alert] = [
            Alert(
                id=a.id,
                visible_severity=a.visible_severity,
                confidence=a.confidence,
                alert_type=a.alert_type,
                age=a.age,
                # Hidden fields zeroed out
                true_severity=0.0,
                is_correlated=False,
                metadata={},
            )
            for a in self.alerts
        ]

        # Remaining investigations this step; None when the task is unconstrained.
        resource_budget: Optional[int] = None
        if self.max_investigations_per_step is not None:
            resource_budget = self.max_investigations_per_step - self.investigations_used

        return Observation(
            alerts=visible_alerts,
            system_load=system_load,
            queue_length=len(self.alerts),
            time_remaining=max(0, self.max_steps - self.current_step),
            episode_step=self.current_step,
            resource_budget=resource_budget,
        )

    # ------------------------------------------------------------------
    # Internal helpers — alert generation
    # ------------------------------------------------------------------

    def _generate_initial_alerts(self) -> List[Alert]:
        """
        Generate the starting alert batch for a fresh episode.

        The batch size is drawn uniformly from 3..6. Real alerts from the
        ingestion queue are taken first; any remaining slots are filled with
        synthetic alerts.
        """
        num_initial: int = int(np.random.randint(3, 7))
        alerts: List[Alert] = []

        # Drain real alerts first
        while self.real_alerts_queue and len(alerts) < num_initial:
            raw = self.real_alerts_queue.popleft()
            alerts.append(self._ingest_real_alert(raw))

        # Fill with synthetic
        for i in range(len(alerts), num_initial):
            alerts.append(
                utils.generate_alert(step=0, alert_index=i)
            )

        return alerts

    def _generate_new_alerts(self) -> List[Alert]:
        """
        Generate alerts to append to the queue this step.

        If real alerts are queued, exactly one is ingested and no synthetic
        alerts are generated that step. Otherwise, with the task-configured
        correlation probability a 3-alert correlated chain is generated and
        recorded in ``correlation_groups``; if not, a batch of independent
        alerts (size from ``utils.sample_num_new_alerts``) is generated.
        """
        # Priority: real ingest queue
        if self.real_alerts_queue:
            raw = self.real_alerts_queue.popleft()
            return [self._ingest_real_alert(raw)]

        # Correlated chain vs independent batch
        if np.random.random() < self.config["correlation_probability"]:
            chain_alerts = utils.generate_correlated_alerts(
                self.current_step, num_alerts=3
            )
            self.correlation_groups.append([a.id for a in chain_alerts])
            return chain_alerts

        # NOTE(review): during the first step() call current_step is still 0,
        # the same ``step`` the initial batch used with alert_index 0..n-1.
        # If utils.generate_alert derives IDs only from (step, alert_index),
        # IDs can collide with the initial batch — confirm against utils.
        num_new: int = utils.sample_num_new_alerts()
        return [
            utils.generate_alert(
                step=self.current_step,
                alert_index=i,
            )
            for i in range(num_new)
        ]

    @staticmethod
    def _ingest_real_alert(raw: Dict[str, Any]) -> Alert:
        """
        Convert a raw real-alert dict into an Alert with synthesised ground truth.

        Ground truth is estimated by adding Gaussian noise (sigma 0.10) to
        visible_severity and clipping to [0, 1], reflecting that real
        monitoring tools provide imperfect severity scores.

        Args:
            raw: Dict with at least the keys ``id``, ``visible_severity``,
                ``confidence`` and ``type``; a missing key raises KeyError.

        Returns:
            A fresh (age 0), uncorrelated Alert whose metadata keeps the raw
            payload under ``"raw"``.
        """
        # NOTE(review): raw comes from an external webhook and its values are
        # not range-checked here — presumably the Alert model validates them.
        true_severity: float = float(
            np.clip(
                float(raw["visible_severity"]) + np.random.normal(0.0, 0.10),
                0.0,
                1.0,
            )
        )
        return Alert(
            id=raw["id"],
            visible_severity=float(raw["visible_severity"]),
            confidence=float(raw["confidence"]),
            alert_type=raw["type"],
            age=0,
            true_severity=true_severity,
            is_correlated=False,
            metadata={"source": "real_ingest", "raw": raw},
        )

    # ------------------------------------------------------------------
    # Internal helpers — action processing
    # ------------------------------------------------------------------

    def _process_action(self, action: Action, alert: Alert) -> None:
        """
        Apply the agent's action to the alert queue.

        INVESTIGATE, ESCALATE, and IGNORE all resolve the alert (remove it
        from the queue). DELAY keeps the alert in the queue; its age will
        be incremented by _age_alerts().
        """
        if action.action_type in ("INVESTIGATE", "ESCALATE", "IGNORE"):
            self.alerts = [a for a in self.alerts if a.id != alert.id]
        # DELAY: no-op — alert remains; age increment handled in _age_alerts()

    def _age_alerts(self) -> None:
        """Increment the age of every unresolved alert by one step (in place)."""
        for alert in self.alerts:
            alert.age += 1

    def _check_for_failures(self) -> int:
        """
        Detect and remove alerts that have caused system failures.

        A failure occurs when an alert judged critical by
        ``utils.is_critical_alert`` (presumably true_severity >= 0.75 —
        confirm in utils) has an age of at least
        ``utils.CRITICAL_AGE_THRESHOLD`` without being resolved. Each such
        alert contributes one failure event and is removed from the queue.

        Returns:
            Number of new failure events detected this step.
        """
        failed = [
            alert
            for alert in self.alerts
            if utils.is_critical_alert(alert)
            and alert.age >= utils.CRITICAL_AGE_THRESHOLD
        ]

        # Remove failed alerts (they've escalated out of the triage queue);
        # a set keeps the membership test O(1) per alert.
        if failed:
            failed_ids = {alert.id for alert in failed}
            self.alerts = [a for a in self.alerts if a.id not in failed_ids]

        return len(failed)

    # ------------------------------------------------------------------
    # Internal helpers — utilities
    # ------------------------------------------------------------------

    def _get_alert_by_id(self, alert_id: str) -> Optional[Alert]:
        """Return the first queued Alert with the given ID, or None if absent."""
        return next((alert for alert in self.alerts if alert.id == alert_id), None)

    def _find_correlation_group(self, alert_id: str) -> Optional[int]:
        """
        Return the index of the correlation group that contains alert_id, or None.

        Used to populate the ``correlation_group_index`` field in processed_alerts
        so the hard-task grader can score root-cause identification.
        """
        for idx, group in enumerate(self.correlation_groups):
            if alert_id in group:
                return idx
        return None

    def _is_terminal(self) -> bool:
        """Return True once the step limit or the failure threshold is reached."""
        return (
            self.current_step >= self.max_steps
            or self.failures_count >= self.failure_threshold
        )

    def _build_info(
        self,
        action_correct: bool,
        extra: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Assemble the standard info dict returned from step().

        Always includes the keys required by all three task graders.
        Additional keys can be merged in via ``extra`` (they override the
        standard keys on collision).
        """
        info: Dict[str, Any] = {
            # Core grading keys (required)
            "processed_alerts": list(self._processed_alerts_this_step),
            "correlation_groups": [list(g) for g in self.correlation_groups],
            "failures_this_step": self._failures_this_step,
            "system_failure": self.failures_count >= self.failure_threshold
            or self._failures_this_step > 0,
            # Convenience telemetry
            "step": self.current_step,
            "cumulative_reward": self.cumulative_reward,
            "failures_count": self.failures_count,
            "action_correct": action_correct,
        }
        if extra:
            info.update(extra)
        return info

    # ------------------------------------------------------------------
    # OpenEnv render
    # ------------------------------------------------------------------

    def render(self, mode: str = "human") -> Optional[str]:
        """
        Render a text summary of the current environment state.

        Args:
            mode: ``"human"`` (prints to stdout) or ``"ansi"`` (returns string).
                Any value other than ``"human"`` returns the string.

        Returns:
            String if mode is not ``"human"``, otherwise None.
        """
        lines = [
            f"\n=== Step {self.current_step}/{self.max_steps}"
            f" [{self.task_id}] ===",
            f" Failures : {self.failures_count}/{self.failure_threshold}",
            f" Cum. reward : {self.cumulative_reward:+.1f}",
            f" Active alerts: {len(self.alerts)}",
        ]

        if self.max_investigations_per_step is not None:
            lines.append(
                f" Inv. budget : "
                f"{self.max_investigations_per_step - self.investigations_used}"
                f"/{self.max_investigations_per_step} remaining"
            )

        if self.alerts:
            lines.append("\n Alerts (first 5):")
            for a in self.alerts[:5]:
                lines.append(
                    f" {a.id} sev={a.visible_severity:.2f}"
                    f" conf={a.confidence:.2f}"
                    f" type={a.alert_type:<12}"
                    f" age={a.age}"
                )
            if len(self.alerts) > 5:
                lines.append(f" … and {len(self.alerts) - 5} more")

        output = "\n".join(lines) + "\n"

        if mode == "human":
            print(output)
            return None
        return output
| # --------------------------------------------------------------------------- | |
| # Quick demo | |
| # --------------------------------------------------------------------------- | |
def main() -> None:
    """Play a short ``easy`` episode (seed 42) with a greedy severity heuristic.

    The policy always targets the alert with the highest visible severity,
    investigating it when that severity is at least 0.7 and ignoring it
    otherwise. The demo stops after five actions, when the episode ends, or
    as soon as the observed queue is empty.
    """
    print("Adaptive Alert Triage Environment — Demo\n")

    env = AdaptiveAlertTriageEnv(task_id="easy", seed=42)
    obs: Observation = env.reset()
    initial_msg = (
        f"Initial observation: {len(obs.alerts)} alerts "
        f"(system_load={obs.system_load:.2f})\n"
    )
    print(initial_msg)

    demo_limit = 5
    steps_taken = 0
    done = False
    while not done and steps_taken < demo_limit:
        env.render()

        if not obs.alerts:
            print("No alerts in queue — nothing to handle.")
            break

        # Greedy target: the most severe-looking alert in the queue.
        target = max(obs.alerts, key=lambda candidate: candidate.visible_severity)
        if target.visible_severity >= 0.7:
            chosen_type = "INVESTIGATE"
        else:
            chosen_type = "IGNORE"
        action = Action(alert_id=target.id, action_type=chosen_type)

        obs, reward, done, info = env.step(action)
        verdict = info.get('action_correct', '?')
        print(
            f" Action: {action.action_type} → {target.id}"
            f" Reward: {reward.value:+.1f}"
            f" Correct: {verdict}"
        )
        steps_taken += 1

    print(f"\nDemo finished after {steps_taken} steps.")
    print(f"Final cumulative reward : {env.cumulative_reward:+.1f}")
    print(f"Total system failures : {env.failures_count}")


if __name__ == "__main__":
    main()