| """ |
| env.py — `OpenSOCEnv`, the two-role gym-style environment. |
| |
| Lifecycle |
| --------- |
| An OpenSOC episode has *exactly two turns*: |
| |
| Turn 1 (attacker): observation has role="attacker" with `attacker_brief`. |
| The agent submits `craft_incident` with structured |
| params. The env validates the params, runs the |
| plausibility checker, and computes ground truth. |
| |
| Turn 2 (defender): observation has role="defender" with the materialized |
| `alert` and `log_window`. The agent submits |
| `submit_triage`. The env scores both sides and |
| terminates the episode. |
| |
| In `defender_only` mode, the env auto-generates the incident with |
| `generator.generate_incident` and skips straight to turn 2 — useful for |
| SFT, eval, and smoke tests. |
| |
| Mode selection happens via `OpenSOCEnv(mode=...)` or the `?mode=` query |
| param on `/reset`. |
| |
| Anti-hack invariants |
| -------------------- |
| 1. The ground-truth label that drives defender reward is computed by |
| `verifier.compute_ground_truth(params)`, never read from `narrative` |
| or `target_label`. |
| 2. The attacker's reward is gated on `verifier.check_plausibility(params)`. |
3. Schema validation (pydantic) errors — or a missing `craft_incident` —
   set schema_violation=True → attacker reward floor of -0.5, *no* defender
   turn (the episode ends immediately with a defender reward of 0).
| """ |
|
|
| from __future__ import annotations |
|
|
| import time |
| import uuid |
| from typing import Any, Dict, List, Literal, Optional |
|
|
| from pydantic import BaseModel, Field, ValidationError |
|
|
| from generator import generate_incident, make_alert |
| from rubric import score_attacker, score_defender |
| from schema import ( |
| Action, |
| Alert, |
| CraftIncident, |
| Event, |
| IncidentParams, |
| SubmitTriage, |
| TriageAction, |
| ) |
| from tasks.registry import STAGE_REGISTRY |
| from verifier import check_plausibility, compute_ground_truth |
|
|
|
|
# Which side of the two-turn game is acting on the current turn.
Role = Literal["attacker", "defender"]
# "self_play": the attacker crafts the incident, then the defender triages it.
# "defender_only": the env generates the incident and only the defender acts.
Mode = Literal["self_play", "defender_only"]
|
|
|
|
| |
| |
| |
|
|
class AttackerBrief(BaseModel):
    """Instructions handed to the attacker on its (first) turn.

    Attributes:
        target_label: Triage action the crafted incident is asked to resolve to.
        difficulty: Difficulty tag taken from the stage's registry entry.
        category_hint: Suggested incident category; defaults to ``"any"``.
    """

    target_label: TriageAction
    difficulty: str
    category_hint: str = Field(default="any")
|
|
|
|
class Observation(BaseModel):
    """What the acting agent sees at the start of (or after) a turn.

    Attacker observations carry ``attacker_brief``; defender observations
    carry the materialized ``alert`` and its ``log_window``. ``done`` is True
    on the terminal observation of an episode.
    """

    role: Role
    alert: Optional[Alert] = Field(default=None)
    log_window: List[Event] = Field(default_factory=list)
    attacker_brief: Optional[AttackerBrief] = Field(default=None)
    step: int = Field(default=0)
    max_steps: int = Field(default=2)
    # Human-readable message about the previous action / current task.
    last_action_feedback: str = Field(default="")
    done: bool = Field(default=False)
|
|
|
|
class EpisodeState(BaseModel):
    """Complete internal record of one episode, as exposed by ``/state``.

    Holds episode bookkeeping, the attacker's brief, the materialized
    incident together with the verifier's verdicts, and per-role rewards.
    """

    # --- bookkeeping ----------------------------------------------------
    task_id: str
    mode: Mode
    step: int = Field(default=0)
    max_steps: int = Field(default=2)
    done: bool = Field(default=False)
    role: Role
    # --- attacker input -------------------------------------------------
    attacker_brief: Optional[AttackerBrief] = Field(default=None)
    # --- materialized incident ------------------------------------------
    incident_alert: Optional[Alert] = Field(default=None)
    incident_log_window: List[Event] = Field(default_factory=list)
    # --- verifier verdicts ----------------------------------------------
    triggering_log_id: Optional[str] = Field(default=None)
    plausible: Optional[bool] = Field(default=None)
    plausibility_reason: str = Field(default="")
    schema_violation: bool = Field(default=False)
    ground_truth: Optional[TriageAction] = Field(default=None)
    # --- defender outcome -----------------------------------------------
    defender_action: Optional[SubmitTriage] = Field(default=None)
    defender_reward: Optional[float] = Field(default=None)
    defender_breakdown: Dict[str, float] = Field(default_factory=dict)
    # --- attacker outcome -----------------------------------------------
    attacker_reward: Optional[float] = Field(default=None)
    attacker_breakdown: Dict[str, float] = Field(default_factory=dict)
    # --- totals ---------------------------------------------------------
    cumulative_reward: float = Field(default=0.0)
    # Wall-clock creation time (seconds since the epoch, via time.time).
    started_at: float = Field(default_factory=time.time)
|
|
|
|
| |
| |
| |
|
|
class OpenSOCEnv:
    """Two-role SOC triage environment with deterministic verifier rewards.

    Gym-style interface: ``reset()`` starts an episode and returns the first
    ``Observation``; ``step(action)`` returns ``(obs, reward, done, info)``;
    ``state()`` dumps the internal ``EpisodeState`` as JSON-ready data; and
    ``grade()`` maps the finished episode's defender reward onto [0, 1].
    """

    # At most one attacker turn followed by one defender turn.
    MAX_STEPS = 2

    def __init__(
        self,
        task_id: str = "stage1_basic",
        mode: Mode = "self_play",
        seed: int = 0,
    ):
        """Create an environment for one curriculum stage.

        Args:
            task_id: Key into ``STAGE_REGISTRY`` selecting the stage.
            mode: ``"self_play"`` (attacker then defender) or
                ``"defender_only"`` (env-generated incident, defender only).
            seed: Base seed, combined with the episode counter and the
                stage's ``seed_offset`` to derive each episode's seed.

        Raises:
            ValueError: If ``task_id`` or ``mode`` is unknown.
        """
        if task_id not in STAGE_REGISTRY:
            raise ValueError(
                f"Unknown task '{task_id}'. Choose from: {list(STAGE_REGISTRY)}"
            )
        if mode not in ("self_play", "defender_only"):
            raise ValueError(f"Unknown mode {mode!r}")
        self.task_id = task_id
        self.mode: Mode = mode
        self.seed = seed
        self._state: Optional[EpisodeState] = None
        # Bumped on every reset() so consecutive episodes use different seeds.
        self._episode_idx = 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def reset(self) -> Observation:
        """Start a fresh episode and return the first observation.

        In ``defender_only`` mode the incident is generated immediately and
        the returned observation is already the defender's. In ``self_play``
        mode the observation carries an ``AttackerBrief``.
        """
        self._episode_idx += 1
        stage = STAGE_REGISTRY[self.task_id]
        episode_seed = self.seed * 100_000 + self._episode_idx + stage["seed_offset"]

        if self.mode == "defender_only":
            params = generate_incident(self.task_id, seed=episode_seed)
            return self._materialize_for_defender(params, started_role="defender")

        brief = AttackerBrief(
            target_label=self._sample_target_label_for_brief(episode_seed),
            difficulty=stage["difficulty"],
            category_hint="any",
        )
        self._state = EpisodeState(
            task_id=self.task_id,
            mode=self.mode,
            role="attacker",
            attacker_brief=brief,
            max_steps=self.MAX_STEPS,
        )
        # Read the label from the validated brief: pydantic has coerced it to a
        # TriageAction, so `.value` works even if the stage config stores the
        # labels as plain strings.
        target_value = brief.target_label.value
        return Observation(
            role="attacker",
            attacker_brief=brief,
            step=0,
            max_steps=self.MAX_STEPS,
            last_action_feedback=(
                f"[stage={self.task_id}] Craft an incident whose ground truth "
                f"is action={target_value}. Ignore the target_label hint "
                f"if you can fool the defender harder with a different one."
            ),
        )

    def step(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Apply one agent action; return ``(obs, reward, done, info)``.

        Dispatches to the attacker or defender handler depending on whose turn
        it currently is.

        Raises:
            RuntimeError: If called before ``reset()`` or after the episode ended.
        """
        if self._state is None:
            raise RuntimeError("Call reset() before step()")
        if self._state.done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")

        s = self._state
        s.step += 1

        if s.role == "attacker":
            return self._step_attacker(action)
        return self._step_defender(action)

    def state(self) -> Dict[str, Any]:
        """Return the full internal state as JSON-compatible data ({} before reset)."""
        if self._state is None:
            return {}
        return self._state.model_dump(mode="json")

    def grade(self) -> float:
        """Return a normalized [0, 1] score for the just-finished episode.

        Episodes that are not finished, or that have no defender reward,
        score 0.0.
        """
        s = self._state
        if s is None or not s.done:
            return 0.0
        if s.defender_reward is None:
            return 0.0
        # NOTE(review): presumably the output range of score_defender; the 1.1
        # upper bound is not explained here — confirm against the rubric.
        lo, hi = -1.0, 1.1
        clamped = max(lo, min(hi, s.defender_reward))
        return float((clamped - lo) / (hi - lo))

    # ------------------------------------------------------------------
    # Attacker turn
    # ------------------------------------------------------------------

    def _step_attacker(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Validate the attacker's crafted incident and hand over to the defender.

        A missing or schema-invalid ``craft_incident`` aborts the episode via
        ``_abort_attacker_turn``. Otherwise the verifier computes plausibility
        and ground truth, the alert is materialized, and the defender's
        observation is returned with reward 0.0. The attacker is scored only
        after the defender acts.
        """
        s = self._state
        ci: Optional[CraftIncident] = action.craft_incident
        if ci is None:
            return self._abort_attacker_turn(
                "Attacker turn requires craft_incident; got something else."
            )

        try:
            params = IncidentParams(
                target_label=ci.target_label,
                category=ci.category,
                events=ci.events,
                narrative=ci.narrative,
            )
        except ValidationError as exc:
            return self._abort_attacker_turn(f"Schema violation: {exc}")

        # Ground truth comes from the verifier, never from the attacker's
        # declared target_label or narrative.
        plausible, reason, triggering_log_id = check_plausibility(params)
        gt_label, _ = compute_ground_truth(params)
        # Build the alert before mutating state so a failure here cannot leave
        # a half-updated episode behind.
        alert = make_alert(params, alert_id=f"A-{uuid.uuid4().hex[:8]}")

        s.role = "defender"
        s.plausible = plausible
        s.plausibility_reason = reason
        s.ground_truth = gt_label
        s.triggering_log_id = triggering_log_id
        s.incident_alert = alert
        s.incident_log_window = list(params.events)

        feedback = (
            f"Attacker turn complete. plausible={plausible} ({reason}). "
            "Defender will now triage."
        )

        obs = Observation(
            role="defender",
            alert=alert,
            log_window=list(params.events),
            step=s.step,
            max_steps=self.MAX_STEPS,
            last_action_feedback=feedback,
            done=False,
        )
        info = {
            "role_just_acted": "attacker",
            "plausible": plausible,
            "plausibility_reason": reason,
            "ground_truth_hidden_from_defender": gt_label.value,
            "triggering_log_id": triggering_log_id,
        }
        return obs, 0.0, False, info

    def _abort_attacker_turn(self, reason: str) -> tuple[Observation, float, bool, dict]:
        """Terminate the episode after an invalid attacker action.

        Marks the incident implausible with ``schema_violation=True``, scores
        the attacker accordingly, sets the defender reward to 0.0, and returns
        a terminal defender-role observation whose reward is the attacker's.
        """
        s = self._state
        s.schema_violation = True
        s.plausible = False
        s.plausibility_reason = reason
        attacker_reward, attacker_bd = score_attacker(
            plausible=False, schema_violation=True,
            defender_correct=False, novelty=0.0,
        )
        s.attacker_reward = attacker_reward
        s.attacker_breakdown = attacker_bd
        s.defender_reward = 0.0
        s.cumulative_reward = attacker_reward
        s.done = True
        s.role = "defender"
        return (
            Observation(
                role="defender",
                step=s.step,
                max_steps=self.MAX_STEPS,
                last_action_feedback=f"Episode aborted: {reason}",
                done=True,
            ),
            attacker_reward,
            True,
            {
                "role_just_acted": "attacker",
                "schema_violation": True,
                "attacker_reward": attacker_reward,
                "attacker_breakdown": attacker_bd,
            },
        )

    # ------------------------------------------------------------------
    # Defender turn
    # ------------------------------------------------------------------

    def _step_defender(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Score the defender's triage (plus the attacker in self-play) and end the episode.

        Returns the terminal observation with the defender's reward as the
        step reward; the attacker's reward is reported in ``info``.
        """
        s = self._state
        # Resolve ground truth once so the rubric, the correctness flag, the
        # feedback text and the info dict all use the same label.
        ground_truth = (
            s.ground_truth if s.ground_truth is not None else TriageAction.DISMISS
        )

        st = action.submit_triage
        if st is None:
            # No triage submitted: treat it as a dismissal with a placeholder citation.
            st = SubmitTriage(action=TriageAction.DISMISS, cited_log_id="L1-0")

        defender_reward, defender_bd = score_defender(
            action=st.action,
            ground_truth=ground_truth,
            triggering_log_id=s.triggering_log_id or "L1-0",
            cited_log_id=st.cited_log_id,
        )
        defender_correct = st.action == ground_truth

        attacker_reward, attacker_bd = 0.0, {}
        if s.mode == "self_play":
            attacker_reward, attacker_bd = score_attacker(
                plausible=bool(s.plausible),
                schema_violation=False,
                defender_correct=defender_correct,
                novelty=0.0,
            )

        s.defender_action = st
        s.defender_reward = defender_reward
        s.defender_breakdown = defender_bd
        s.attacker_reward = attacker_reward
        s.attacker_breakdown = attacker_bd
        s.cumulative_reward = defender_reward + attacker_reward
        s.done = True
        s.role = "defender"

        feedback = (
            f"Defender chose {st.action.value}; ground truth was "
            f"{ground_truth.value}. "
            f"Reward={defender_reward:+.2f}."
        )
        obs = Observation(
            role="defender",
            alert=s.incident_alert,
            log_window=list(s.incident_log_window),
            step=s.step,
            max_steps=self.MAX_STEPS,
            last_action_feedback=feedback,
            done=True,
        )
        info = {
            "role_just_acted": "defender",
            "ground_truth": ground_truth.value,
            "defender_correct": defender_correct,
            "defender_breakdown": defender_bd,
            "attacker_reward": attacker_reward,
            "attacker_breakdown": attacker_bd,
            "triggering_log_id": s.triggering_log_id,
        }
        return obs, defender_reward, True, info

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _materialize_for_defender(
        self, params: IncidentParams, *, started_role: Role
    ) -> Observation:
        """Set up state for a defender_only episode (skip attacker turn).

        Runs the verifier on *params*, builds the alert, stores a fresh
        ``EpisodeState`` in defender role, and returns the defender's first
        observation. ``started_role`` is currently unused by this method.
        """
        plausible, reason, triggering_log_id = check_plausibility(params)
        gt_label, _ = compute_ground_truth(params)
        alert = make_alert(params, alert_id=f"A-{uuid.uuid4().hex[:8]}")

        self._state = EpisodeState(
            task_id=self.task_id,
            mode=self.mode,
            role="defender",
            incident_alert=alert,
            incident_log_window=list(params.events),
            triggering_log_id=triggering_log_id,
            plausible=plausible,
            plausibility_reason=reason,
            ground_truth=gt_label,
            max_steps=self.MAX_STEPS,
        )

        return Observation(
            role="defender",
            alert=alert,
            log_window=list(params.events),
            step=0,
            max_steps=self.MAX_STEPS,
            last_action_feedback=(
                f"[stage={self.task_id}, defender_only] Triage this alert."
            ),
        )

    def _sample_target_label_for_brief(self, seed: int) -> TriageAction:
        """Pick a brief target label from the stage's label distribution.

        Draws one key of ``STAGE_CONFIGS[task_id]["label_distribution"]``,
        weighted by its value, from a ``random.Random`` seeded with *seed*,
        so the draw is reproducible for a given seed.
        """
        from generator import STAGE_CONFIGS
        import random as _random
        cfg = STAGE_CONFIGS[self.task_id]
        rng = _random.Random(seed)
        labels = list(cfg["label_distribution"].keys())
        weights = [cfg["label_distribution"][lab] for lab in labels]
        return rng.choices(labels, weights=weights, k=1)[0]
|
|
|
|
# Public names of this module; ``Action`` is re-exported from ``schema``.
__all__ = ["AttackerBrief", "Action", "Observation", "EpisodeState", "OpenSOCEnv"]
|
|