Spaces:
Sleeping
Sleeping
| """Reference defender policies + a minimal in-process episode runner. | |
| Two policies are shipped: | |
| * :class:`RandomPolicy` - uniform random over legal (action, target) pairs. | |
| * :class:`HeuristicPolicy` - rule-based defender driven by alert severity, | |
| confirmed-compromise lists, and a small set of thresholds. Encodes the | |
| behavior we expect a competent SOC analyst to exhibit; serves as the | |
| primary baseline a trained LLM defender must beat. | |
| The episode runner :func:`run_episode` drives a single :class:`CybersecEnvironment` | |
| instance with a chosen policy and returns a structured :class:`EpisodeResult`. | |
| This is intentionally separate from the OpenEnv HTTP/WebSocket server so it | |
| can be used inside training notebooks without a network hop. | |
| """ | |
| from __future__ import annotations | |
| import random | |
| from dataclasses import dataclass, field | |
| from typing import Callable, Dict, List, Optional, Protocol | |
| from .server.cybersec_environment import CybersecEnvironment | |
| from .models import ( | |
| ActionType, | |
| AlertEvent, | |
| AlertSignal, | |
| AttackerPersonality, | |
| CybersecAction, | |
| CybersecObservation, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Policy protocol | |
| # --------------------------------------------------------------------------- | |
class DefenderPolicy(Protocol):
    """Structural interface for a defender policy.

    A policy may be stateless or keep per-episode state. ``reset`` is
    declared here for type-checking purposes, but the episode runner only
    calls it when the attribute exists, so duck-typed policies may omit it
    at runtime.
    """

    # Short identifier copied into ``EpisodeResult.policy_name``.
    name: str

    def act(self, observation: CybersecObservation) -> CybersecAction:
        """Return the action to take for ``observation``."""
        ...

    def reset(self) -> None:
        """Clear per-episode state; concrete classes may leave it a no-op."""
        ...
| # --------------------------------------------------------------------------- | |
| # Random policy | |
| # --------------------------------------------------------------------------- | |
class RandomPolicy:
    """Uniform random policy constrained to currently-valid (action, target).

    An action type is drawn uniformly from ``observation.available_actions``
    and, for targeted actions, a target is drawn uniformly from the matching
    pool in ``observation.valid_targets``. Whenever no legal choice exists
    the policy falls back to ``MONITOR``.
    """

    name = "random"

    def __init__(self, seed: int = 0):
        """Create the policy with its own RNG seeded by ``seed``."""
        # A private generator keeps the action stream reproducible per seed
        # and independent of the global ``random`` module state.
        self._rng = random.Random(seed)

    def reset(self) -> None:  # pragma: no cover - intentionally trivial
        """Do nothing: there is no per-episode state to clear."""

    def act(self, observation: CybersecObservation) -> CybersecAction:
        """Sample a random legal action for ``observation``.

        Returns:
            ``MONITOR`` when it is drawn, when ``available_actions`` is
            empty, or when the drawn action has no candidate target;
            otherwise the drawn action with a uniformly chosen target.
        """
        rng = self._rng
        candidates = observation.available_actions
        if not candidates:
            # ``Random.choice`` raises IndexError on an empty sequence;
            # treat "nothing offered" like "nothing to target".
            return CybersecAction(action_type=ActionType.MONITOR)

        atype = rng.choice(candidates)
        if atype is ActionType.MONITOR:
            return CybersecAction(action_type=atype)

        assets = observation.valid_targets.get("assets", [])
        identities = observation.valid_targets.get("identities", [])
        asset_only = (
            ActionType.ISOLATE_ASSET,
            ActionType.BLOCK_EGRESS,
            ActionType.PATCH_ASSET,
        )
        if atype is ActionType.INVESTIGATE:
            # Investigation may point at either kind of target.
            pool = [*assets, *identities]
        elif atype in asset_only:
            pool = assets
        else:  # REVOKE_IDENTITY
            pool = identities

        if not pool:
            return CybersecAction(action_type=ActionType.MONITOR)
        return CybersecAction(action_type=atype, target=rng.choice(pool))
| # --------------------------------------------------------------------------- | |
| # Heuristic policy | |
| # --------------------------------------------------------------------------- | |
@dataclass
class HeuristicConfig:
    """Tunable thresholds consumed by :class:`HeuristicPolicy`.

    Declared as a dataclass so individual thresholds can be overridden per
    instance, e.g. ``HeuristicConfig(isolate_severity_threshold=0.6)``.
    """

    # Lower thresholds react sooner to rising alert severity (helps aggressive
    # attacker tempo where waiting for 0.55+ lets random luck outperform).
    # Minimum severity of the strongest alert before its target is investigated.
    investigate_severity_threshold: float = 0.26
    # Minimum severity before an alerted asset is isolated / egress-blocked.
    isolate_severity_threshold: float = 0.48
    # Minimum severity before an alerted identity is revoked.
    revoke_severity_threshold: float = 0.52
    # NOTE(review): the two fields below are not read by the policy code
    # visible in this module - confirm whether other callers rely on them.
    confirmed_isolate_min_confidence: float = 0.65
    monitor_when_quiet: bool = True
class HeuristicPolicy:
    """Decision tree over alerts + forensics.

    Priority order each tick:

    1. Containment of any *confirmed-compromised* asset/identity that is
       not yet isolated/revoked.
    2. Escalation on the highest-severity non-noise alert: BLOCK_EGRESS or
       ISOLATE for an asset at/above the isolation threshold, REVOKE for an
       identity at/above the revoke threshold.
    3. INVESTIGATE the strongest alert's target, at most once per target
       between calls to :meth:`reset`.
    4. PATCH the first unpatched asset that is the target of any alert in
       the current observation.
    5. MONITOR.
    """

    name = "heuristic"

    def __init__(self, config: Optional[HeuristicConfig] = None):
        """Create the policy; ``config`` defaults to ``HeuristicConfig()``."""
        self.cfg = config if config is not None else HeuristicConfig()
        # Targets already investigated, so step 3 never repeats itself.
        self._investigated: set[str] = set()

    def reset(self) -> None:
        """Forget which targets were investigated in the previous episode."""
        self._investigated = set()

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _alert_target(alert: AlertEvent) -> Optional[str]:
        """Return the alert's asset if truthy, otherwise its identity."""
        return alert.asset or alert.identity

    @staticmethod
    def _strongest_alert(observation: CybersecObservation) -> Optional[AlertEvent]:
        """Return the highest-severity alert that is not background noise.

        Ties keep the earliest alert; ``None`` when no alert qualifies.
        """
        return max(
            (
                alert
                for alert in observation.alerts
                if alert.signal is not AlertSignal.BACKGROUND_NOISE
            ),
            key=lambda alert: alert.severity,
            default=None,
        )

    def _is_asset(self, target: str, obs: CybersecObservation) -> bool:
        """Tell whether ``target`` is listed under ``valid_targets['assets']``."""
        return target in obs.valid_targets.get("assets", [])

    def _is_identity(self, target: str, obs: CybersecObservation) -> bool:
        """Tell whether ``target`` is listed under ``valid_targets['identities']``."""
        return target in obs.valid_targets.get("identities", [])

    # ------------------------------------------------------------------ act
    def act(self, observation: CybersecObservation) -> CybersecAction:
        """Pick the first applicable action in the class's priority order."""
        cfg = self.cfg

        # 1. Containment of confirmed-compromised targets we haven't sealed.
        for target in observation.confirmed_compromised:
            if self._is_asset(target, observation):
                if target not in observation.isolated_assets:
                    return CybersecAction(
                        action_type=ActionType.ISOLATE_ASSET, target=target
                    )
            elif self._is_identity(target, observation):
                if target not in observation.revoked_identities:
                    return CybersecAction(
                        action_type=ActionType.REVOKE_IDENTITY, target=target
                    )

        strongest = self._strongest_alert(observation)
        alert_target = (
            self._alert_target(strongest) if strongest is not None else None
        )

        # 2. Escalate on a high-severity alert with a clear target.
        if alert_target is not None:
            if (
                strongest.severity >= cfg.isolate_severity_threshold
                and self._is_asset(alert_target, observation)
                and alert_target not in observation.isolated_assets
            ):
                # An egress anomaly gets the narrower BLOCK_EGRESS first;
                # once egress is already blocked the asset is isolated.
                if (
                    strongest.signal is AlertSignal.EGRESS_ANOMALY
                    and alert_target not in observation.blocked_egress_assets
                ):
                    return CybersecAction(
                        action_type=ActionType.BLOCK_EGRESS, target=alert_target
                    )
                return CybersecAction(
                    action_type=ActionType.ISOLATE_ASSET, target=alert_target
                )
            if (
                strongest.severity >= cfg.revoke_severity_threshold
                and self._is_identity(alert_target, observation)
                and alert_target not in observation.revoked_identities
            ):
                return CybersecAction(
                    action_type=ActionType.REVOKE_IDENTITY, target=alert_target
                )

        # 3. Investigate a moderately-suspicious untouched target.
        if (
            alert_target is not None
            and strongest.severity >= cfg.investigate_severity_threshold
            and alert_target not in self._investigated
        ):
            self._investigated.add(alert_target)
            return CybersecAction(
                action_type=ActionType.INVESTIGATE, target=alert_target
            )

        # 4. Proactively patch the first never-patched asset that already had
        #    *any* alert (cheap proxy for "the policy thought about it").
        #    Background-noise alerts count here as well.
        alerted = {self._alert_target(alert) for alert in observation.alerts}
        alerted.discard(None)
        for asset in observation.valid_targets.get("assets", []):
            if asset in alerted and asset not in observation.patched_assets:
                return CybersecAction(
                    action_type=ActionType.PATCH_ASSET, target=asset
                )

        # 5. Monitor.
        return CybersecAction(action_type=ActionType.MONITOR)
| # --------------------------------------------------------------------------- | |
| # Episode runner + result | |
| # --------------------------------------------------------------------------- | |
@dataclass
class EpisodeResult:
    """Structured diagnostics for one finished episode.

    Built by ``run_episode`` with keyword arguments, hence the dataclass
    decorator: it generates the ``__init__`` and turns
    ``field(default_factory=list)`` into a fresh list per instance.
    """

    policy_name: str
    scenario_id: str
    attacker_personality: str
    seed: int
    steps: int
    cumulative_reward: float
    succeeded_stage_count: int
    total_stage_count: int
    exfil_completed: bool
    terminal_reason: Optional[str]
    confirmed_compromised: List[str]
    invalid_action_count: int
    false_positive_count: int
    # ``invalid_action_count`` counts steps the *server* rejected as
    # invalid (target missing from valid_targets, etc). LLM policies that
    # silently fall back to ``MONITOR`` on parse failure trivially score
    # 0 on that metric while still being garbage. ``monitor_fallback_count``
    # instead counts steps on which the policy itself reported a fallback
    # through its ``last_act_was_fallback`` attribute.
    monitor_fallback_count: int = 0
    # Per-step rewards in step order; a new list for every instance.
    reward_curve: List[float] = field(default_factory=list)
def run_episode(
    env: CybersecEnvironment,
    policy: DefenderPolicy,
    seed: int,
    scenario_id: Optional[str] = None,
    attacker_personality: Optional[AttackerPersonality] = None,
) -> EpisodeResult:
    """Drive ``env`` with ``policy`` for a full episode, return diagnostics.

    Args:
        env: Environment to run; it is reset here with ``seed``.
        policy: Policy queried once per step. Its ``reset`` is called first
            when it has one.
        seed: Seed forwarded to ``env.reset`` and recorded in the result.
        scenario_id: Forwarded to ``env.reset`` only when not ``None``.
        attacker_personality: Forwarded to ``env.reset`` only when not
            ``None``.

    Returns:
        An :class:`EpisodeResult` built from the final observation, its
        ``info["terminal"]`` section and the counters collected per step.
    """
    if hasattr(policy, "reset"):
        policy.reset()

    # Pass only the overrides the caller actually supplied, so the
    # environment's own defaults apply otherwise.
    reset_kwargs: Dict[str, object] = {}
    if scenario_id is not None:
        reset_kwargs["scenario_id"] = scenario_id
    if attacker_personality is not None:
        reset_kwargs["attacker_personality"] = attacker_personality
    obs = env.reset(seed=seed, **reset_kwargs)

    invalid = 0
    fp = 0
    fallback = 0
    reward_curve: List[float] = []
    while not obs.done:
        action = policy.act(obs)
        # Policies that emit text and parse it (LLM policies in particular)
        # set this attribute *inside* `act()` so we can distinguish a
        # genuine MONITOR decision from a parse-failure fallback.
        if getattr(policy, "last_act_was_fallback", False):
            fallback += 1
        obs = env.step(action)
        # ``info`` is only trusted when it is a dict; anything else is
        # treated as "no breakdown available".
        breakdown = (
            obs.info.get("reward_breakdown", {})
            if isinstance(obs.info, dict)
            else {}
        )
        # Any non-zero penalty marks the step as invalid / false positive.
        if breakdown.get("invalid_action_penalty", 0.0):
            invalid += 1
        if breakdown.get("false_positive_penalty", 0.0):
            fp += 1
        reward_curve.append(float(obs.reward or 0.0))

    terminal = obs.info.get("terminal", {}) if isinstance(obs.info, dict) else {}
    return EpisodeResult(
        policy_name=policy.name,
        scenario_id=obs.scenario_id,
        attacker_personality=obs.attacker_personality.value,
        seed=seed,
        steps=obs.tick,
        # Prefer the environment's own total; otherwise sum what we saw.
        cumulative_reward=float(
            terminal.get("cumulative_reward", sum(reward_curve))
        ),
        succeeded_stage_count=int(terminal.get("stages_succeeded", 0)),
        total_stage_count=int(terminal.get("stages_total", 0)),
        exfil_completed=bool(terminal.get("exfil_completed", False)),
        terminal_reason=terminal.get("terminal_reason"),
        confirmed_compromised=list(obs.confirmed_compromised),
        invalid_action_count=invalid,
        false_positive_count=fp,
        monitor_fallback_count=fallback,
        reward_curve=reward_curve,
    )
def aggregate_results(results: List[EpisodeResult]) -> Dict[str, float]:
    """Mean over a batch of episode results, suitable for paper-style tables.

    Args:
        results: Episode results to summarise.

    Returns:
        ``{}`` for an empty batch. Otherwise a dict with the episode count
        and rounded per-episode means. ``monitor_fallback_rate`` is
        fallbacks per *step* across the whole batch, and
        ``containment_rate`` is the fraction of episodes in which at least
        ``max(1, total_stage_count // 2)`` stages did not succeed (episodes
        with ``total_stage_count == 0`` never count as contained).
    """
    if not results:
        return {}

    n = len(results)
    # ``or 1`` keeps the fallback rate defined when no episode took a step.
    total_steps = sum(r.steps for r in results) or 1

    def per_episode(values, digits: int = 3) -> float:
        """Sum ``values`` and divide by the episode count, rounded."""
        return round(sum(values) / n, digits)

    return {
        "n_episodes": n,
        "mean_return": per_episode(r.cumulative_reward for r in results),
        "mean_steps": per_episode((r.steps for r in results), 2),
        "exfil_rate": per_episode(1 for r in results if r.exfil_completed),
        "mean_stages_succeeded": per_episode(
            r.succeeded_stage_count for r in results
        ),
        "mean_invalid_actions": per_episode(
            r.invalid_action_count for r in results
        ),
        "mean_false_positives": per_episode(
            r.false_positive_count for r in results
        ),
        "monitor_fallback_rate": round(
            sum(r.monitor_fallback_count for r in results) / total_steps, 3
        ),
        "containment_rate": per_episode(
            1
            for r in results
            if r.total_stage_count
            and (r.total_stage_count - r.succeeded_stage_count)
            >= max(1, r.total_stage_count // 2)
        ),
    }
# Public API of this module; everything else is an implementation detail.
__all__ = [
    "DefenderPolicy",
    "RandomPolicy",
    "HeuristicConfig",
    "HeuristicPolicy",
    "EpisodeResult",
    "run_episode",
    "aggregate_results",
]