"""
env.py — `OpenSOCEnv`, the two-role gym-style environment.

Lifecycle
---------
An OpenSOC episode has *exactly two turns*:

  Turn 1 (attacker): the observation has role="attacker" and carries an
      `attacker_brief`. The agent submits `craft_incident` with structured
      params. The env validates the params, runs the plausibility checker,
      and computes the ground truth.
  Turn 2 (defender): the observation has role="defender" and carries the
      materialized `alert` and `log_window`. The agent submits
      `submit_triage`. The env scores both sides and terminates the episode.

In `defender_only` mode, the env auto-generates the incident with
`generator.generate_incident` and skips straight to turn 2 — useful for
SFT, eval, and smoke tests.

Mode selection happens via `OpenSOCEnv(mode=...)` or the `?mode=` query
param on `/reset`.

Anti-hack invariants
--------------------
1. The ground-truth label that drives defender reward is computed by
   `verifier.compute_ground_truth(params)`, never read from `narrative`
   or `target_label`.
2. The attacker's reward is gated on `verifier.check_plausibility(params)`.
3. Schema validation (pydantic) errors — or an attacker turn that carries
   no `craft_incident` — set schema_violation=True, give the attacker the
   schema-violation reward floor (-0.5), and end the episode immediately:
   there is *no* defender turn and the defender reward is recorded as 0.0.
4. A defender turn without `submit_triage` is scored as `dismiss` with
   *no* citation, so an empty action can never earn citation credit.
"""
from __future__ import annotations

import time
import uuid
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field, ValidationError

from generator import generate_incident, make_alert
from rubric import score_attacker, score_defender
from schema import (
    Action,
    Alert,
    CraftIncident,
    Event,
    IncidentParams,
    SubmitTriage,
    TriageAction,
)
from tasks.registry import STAGE_REGISTRY
from verifier import check_plausibility, compute_ground_truth

Role = Literal["attacker", "defender"]
Mode = Literal["self_play", "defender_only"]

# Citation id handed to the scorer when the defender cited nothing. It must
# never equal a triggering log id — neither a real one (assumed to be a
# non-empty string; TODO confirm against schema) nor the "L1-0" fallback used
# when the verifier reports none — so a missing action cannot earn credit.
_NO_CITATION = ""


# ---------------------------------------------------------------------------
# Public observation / state types
# ---------------------------------------------------------------------------


class AttackerBrief(BaseModel):
    """What the env tells the attacker to produce."""

    target_label: TriageAction
    difficulty: str
    category_hint: str = "any"


class Observation(BaseModel):
    """Per-turn observation visible to the agent."""

    role: Role
    alert: Optional[Alert] = None
    log_window: List[Event] = Field(default_factory=list)
    attacker_brief: Optional[AttackerBrief] = None
    step: int = 0
    max_steps: int = 2
    last_action_feedback: str = ""
    done: bool = False


class EpisodeState(BaseModel):
    """Full internal state returned by /state."""

    task_id: str
    mode: Mode
    step: int = 0
    max_steps: int = 2
    done: bool = False
    role: Role  # whose turn is next (or last, once done)
    attacker_brief: Optional[AttackerBrief] = None
    incident_alert: Optional[Alert] = None
    incident_log_window: List[Event] = Field(default_factory=list)
    # Verifier outputs; hidden from the defender's observation.
    triggering_log_id: Optional[str] = None
    plausible: Optional[bool] = None
    plausibility_reason: str = ""
    schema_violation: bool = False
    ground_truth: Optional[TriageAction] = None
    # Scoring results, filled in when the episode terminates.
    defender_action: Optional[SubmitTriage] = None
    defender_reward: Optional[float] = None
    defender_breakdown: Dict[str, float] = Field(default_factory=dict)
    attacker_reward: Optional[float] = None
    attacker_breakdown: Dict[str, float] = Field(default_factory=dict)
    cumulative_reward: float = 0.0
    started_at: float = Field(default_factory=time.time)


# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------


class OpenSOCEnv:
    """Two-role SOC triage environment with deterministic verifier rewards."""

    MAX_STEPS = 2

    # Range used by grade() to normalize the defender reward into [0, 1]:
    # [-1.0, 1.1] (max correct + bonus).
    DEFENDER_REWARD_RANGE = (-1.0, 1.1)

    def __init__(
        self,
        task_id: str = "stage1_basic",
        mode: Mode = "self_play",
        seed: int = 0,
    ):
        """Create an environment for one stage.

        Args:
            task_id: Key into ``STAGE_REGISTRY``.
            mode: ``"self_play"`` (attacker then defender) or
                ``"defender_only"`` (generated incident, defender only).
            seed: Base seed; each episode derives its own seed from it.

        Raises:
            ValueError: If ``task_id`` or ``mode`` is unknown.
        """
        if task_id not in STAGE_REGISTRY:
            raise ValueError(
                f"Unknown task '{task_id}'. Choose from: {list(STAGE_REGISTRY)}"
            )
        if mode not in ("self_play", "defender_only"):
            raise ValueError(f"Unknown mode {mode!r}")
        self.task_id = task_id
        self.mode: Mode = mode
        self.seed = seed
        self._state: Optional[EpisodeState] = None
        self._episode_idx = 0

    # ------------------------------------------------------------------
    # Gym-style API: reset / step / state / grade
    # ------------------------------------------------------------------

    def reset(self) -> Observation:
        """Start a fresh episode and return the first observation.

        In ``defender_only`` mode the first observation is already the
        defender's; in ``self_play`` it is the attacker's brief.
        """
        self._episode_idx += 1
        episode_seed = (
            self.seed * 100_000
            + self._episode_idx
            + STAGE_REGISTRY[self.task_id]["seed_offset"]
        )

        if self.mode == "defender_only":
            params = generate_incident(self.task_id, seed=episode_seed)
            return self._materialize_for_defender(params, started_role="defender")

        # self_play: the next /step must be the attacker's craft_incident.
        # We seed the brief with a target label that's representative of the
        # stage's distribution, but the attacker is free to ignore it.
        target_label = self._sample_target_label_for_brief(episode_seed)
        brief = AttackerBrief(
            target_label=target_label,
            difficulty=STAGE_REGISTRY[self.task_id]["difficulty"],
            category_hint="any",
        )
        self._state = EpisodeState(
            task_id=self.task_id,
            mode=self.mode,
            role="attacker",
            attacker_brief=brief,
            max_steps=self.MAX_STEPS,
        )
        return Observation(
            role="attacker",
            attacker_brief=brief,
            step=0,
            max_steps=self.MAX_STEPS,
            last_action_feedback=(
                f"[stage={self.task_id}] Craft an incident whose ground truth "
                f"is action={target_label.value}. Ignore the target_label hint "
                f"if you can fool the defender harder with a different one."
            ),
        )

    def step(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Apply one agent action; return (obs, reward, done, info).

        Raises:
            RuntimeError: If called before ``reset()`` or after the episode
                has terminated.
        """
        if self._state is None:
            raise RuntimeError("Call reset() before step()")
        if self._state.done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")

        s = self._state
        s.step += 1

        if s.role == "attacker":
            return self._step_attacker(action)
        return self._step_defender(action)

    def state(self) -> Dict[str, Any]:
        """Return the full internal state as JSON-compatible data ({} before reset)."""
        if self._state is None:
            return {}
        return self._state.model_dump(mode="json")

    def grade(self) -> float:
        """Return a normalized [0, 1] score for the just-finished episode.

        Returns 0.0 if no episode has finished or no defender reward exists.
        """
        s = self._state
        if s is None or not s.done:
            return 0.0
        if s.defender_reward is None:
            return 0.0
        lo, hi = self.DEFENDER_REWARD_RANGE
        clamped = max(lo, min(hi, s.defender_reward))
        return float((clamped - lo) / (hi - lo))

    # ------------------------------------------------------------------
    # Attacker turn
    # ------------------------------------------------------------------

    def _step_attacker(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Validate the crafted incident and hand the materialized alert to the defender."""
        s = self._state
        ci: Optional[CraftIncident] = action.craft_incident
        if ci is None:
            # Treated as a schema violation: -0.5 attacker reward, episode
            # ends immediately because we have nothing to show the defender.
            return self._abort_attacker_turn(
                "Attacker turn requires craft_incident; got something else."
            )

        try:
            params = IncidentParams(
                target_label=ci.target_label,
                category=ci.category,
                events=ci.events,
                narrative=ci.narrative,
            )
        except ValidationError as exc:
            return self._abort_attacker_turn(f"Schema violation: {exc}")

        # Ground truth comes from the verifier only, never from the
        # attacker's declared target_label or narrative.
        plausible, reason, triggering_log_id = check_plausibility(params)
        gt_label, _ = compute_ground_truth(params)

        s.role = "defender"
        s.plausible = plausible
        s.plausibility_reason = reason
        s.ground_truth = gt_label
        s.triggering_log_id = triggering_log_id

        alert = make_alert(params, alert_id=f"A-{uuid.uuid4().hex[:8]}")
        s.incident_alert = alert
        s.incident_log_window = list(params.events)

        feedback = (
            f"Attacker turn complete. plausible={plausible} ({reason}). "
            "Defender will now triage."
        )
        obs = Observation(
            role="defender",
            alert=alert,
            log_window=list(params.events),
            step=s.step,
            max_steps=self.MAX_STEPS,
            last_action_feedback=feedback,
            done=False,
        )
        info = {
            "role_just_acted": "attacker",
            "plausible": plausible,
            "plausibility_reason": reason,
            "ground_truth_hidden_from_defender": gt_label.value,
            "triggering_log_id": triggering_log_id,
        }
        # The attacker's reward depends on the defender's outcome, so it is
        # only assigned at the end of the defender turn.
        return obs, 0.0, False, info

    def _abort_attacker_turn(self, reason: str) -> tuple[Observation, float, bool, dict]:
        """End the episode after an invalid attacker action (schema violation)."""
        s = self._state
        s.schema_violation = True
        s.plausible = False
        s.plausibility_reason = reason
        attacker_reward, attacker_bd = score_attacker(
            plausible=False,
            schema_violation=True,
            defender_correct=False,
            novelty=0.0,
        )
        s.attacker_reward = attacker_reward
        s.attacker_breakdown = attacker_bd
        s.defender_reward = 0.0
        s.cumulative_reward = attacker_reward
        s.done = True
        s.role = "defender"
        return (
            Observation(
                role="defender",
                step=s.step,
                max_steps=self.MAX_STEPS,
                last_action_feedback=f"Episode aborted: {reason}",
                done=True,
            ),
            attacker_reward,
            True,
            {
                "role_just_acted": "attacker",
                "schema_violation": True,
                "attacker_reward": attacker_reward,
                "attacker_breakdown": attacker_bd,
            },
        )

    # ------------------------------------------------------------------
    # Defender turn
    # ------------------------------------------------------------------

    def _step_defender(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Score the defender's triage (and, in self_play, the attacker) and end the episode."""
        s = self._state
        st = action.submit_triage
        if st is None:
            # Treat as a missed-malicious-equivalent: penalize by acting as
            # if the defender chose 'dismiss' with no citation. The recorded
            # SubmitTriage keeps the "L1-0" placeholder (presumably needed for
            # a schema-valid object — TODO confirm), but the scorer receives
            # _NO_CITATION instead: "L1-0" equals the triggering-log fallback
            # below and would otherwise earn citation credit for doing nothing.
            st = SubmitTriage(action=TriageAction.DISMISS, cited_log_id="L1-0")
            cited_for_scoring = _NO_CITATION
        else:
            cited_for_scoring = st.cited_log_id

        # Resolve the ground truth once so reward, correctness flag, feedback
        # and info all agree with each other.
        ground_truth = s.ground_truth or TriageAction.DISMISS

        defender_reward, defender_bd = score_defender(
            action=st.action,
            ground_truth=ground_truth,
            triggering_log_id=s.triggering_log_id or "L1-0",
            cited_log_id=cited_for_scoring,
        )
        defender_correct = st.action is ground_truth

        attacker_reward, attacker_bd = 0.0, {}
        if s.mode == "self_play":
            attacker_reward, attacker_bd = score_attacker(
                plausible=bool(s.plausible),
                schema_violation=False,
                defender_correct=defender_correct,
                novelty=0.0,  # filled in by the trainer if it tracks batches
            )

        s.defender_action = st
        s.defender_reward = defender_reward
        s.defender_breakdown = defender_bd
        s.attacker_reward = attacker_reward
        s.attacker_breakdown = attacker_bd
        s.cumulative_reward = defender_reward + attacker_reward
        s.done = True
        s.role = "defender"

        feedback = (
            f"Defender chose {st.action.value}; ground truth was "
            f"{ground_truth.value}. "
            f"Reward={defender_reward:+.2f}."
        )
        obs = Observation(
            role="defender",
            alert=s.incident_alert,
            log_window=list(s.incident_log_window),
            step=s.step,
            max_steps=self.MAX_STEPS,
            last_action_feedback=feedback,
            done=True,
        )
        info = {
            "role_just_acted": "defender",
            "ground_truth": ground_truth.value,
            "defender_correct": defender_correct,
            "defender_breakdown": defender_bd,
            "attacker_reward": attacker_reward,
            "attacker_breakdown": attacker_bd,
            "triggering_log_id": s.triggering_log_id,
        }
        return obs, defender_reward, True, info

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _materialize_for_defender(
        self, params: IncidentParams, *, started_role: Role
    ) -> Observation:
        """Set up state for a defender_only episode (skip attacker turn).

        ``started_role`` is accepted for interface compatibility but is not
        used by this implementation.
        """
        plausible, reason, triggering_log_id = check_plausibility(params)
        gt_label, _ = compute_ground_truth(params)
        alert = make_alert(params, alert_id=f"A-{uuid.uuid4().hex[:8]}")
        self._state = EpisodeState(
            task_id=self.task_id,
            mode=self.mode,
            role="defender",
            incident_alert=alert,
            incident_log_window=list(params.events),
            triggering_log_id=triggering_log_id,
            plausible=plausible,
            plausibility_reason=reason,
            ground_truth=gt_label,
            max_steps=self.MAX_STEPS,
        )
        return Observation(
            role="defender",
            alert=alert,
            log_window=list(params.events),
            step=0,
            max_steps=self.MAX_STEPS,
            last_action_feedback=(
                f"[stage={self.task_id}, defender_only] Triage this alert."
            ),
        )

    def _sample_target_label_for_brief(self, seed: int) -> TriageAction:
        """Pick a brief target label from the stage's label distribution.

        Deterministic for a given ``seed``: a private ``random.Random(seed)``
        draws one label weighted by ``label_distribution``.
        """
        # Reuse the generator's stage config so brief and defender-only
        # generation are coherent.
        from generator import STAGE_CONFIGS  # local import avoids cycle
        import random as _random

        cfg = STAGE_CONFIGS[self.task_id]
        rng = _random.Random(seed)
        distribution = cfg["label_distribution"]
        labels = list(distribution.keys())
        weights = [distribution[lab] for lab in labels]
        return rng.choices(labels, weights=weights, k=1)[0]


__all__ = [
    "AttackerBrief",
    "Action",
    "Observation",
    "EpisodeState",
    "OpenSOCEnv",
]