Spaces:
Sleeping
Sleeping
| """ | |
| env.py — `OpenSOCEnv`, the two-role gym-style environment. | |
| Lifecycle | |
| --------- | |
| An OpenSOC episode has *exactly two turns*: | |
| Turn 1 (attacker): observation has role="attacker" with `attacker_brief`. | |
| The agent submits `craft_incident` with structured | |
| params. The env validates the params, runs the | |
| plausibility checker, and computes ground truth. | |
| Turn 2 (defender): observation has role="defender" with the materialized | |
| `alert` and `log_window`. The agent submits | |
| `submit_triage`. The env scores both sides and | |
| terminates the episode. | |
| In `defender_only` mode, the env auto-generates the incident with | |
| `generator.generate_incident` and skips straight to turn 2 — useful for | |
| SFT, eval, and smoke tests. | |
| Mode selection happens via `OpenSOCEnv(mode=...)` or the `?mode=` query | |
| param on `/reset`. | |
| Anti-hack invariants | |
| -------------------- | |
| 1. The ground-truth label that drives defender reward is computed by | |
| `verifier.compute_ground_truth(params)`, never read from `narrative` | |
| or `target_label`. | |
| 2. The attacker's reward is gated on `verifier.check_plausibility(params)`. | |
| 3. Schema validation (pydantic) errors → schema_violation=True → | |
| attacker reward floor of -0.5, *no* defender turn (the episode ends | |
| immediately with defender_reward=0.0 and no defender action recorded). | |
| """ | |
| from __future__ import annotations | |
| import time | |
| import uuid | |
| from typing import Any, Dict, List, Literal, Optional | |
| from pydantic import BaseModel, Field, ValidationError | |
| from generator import generate_incident, make_alert | |
| from rubric import score_attacker, score_defender | |
| from schema import ( | |
| Action, | |
| Alert, | |
| CraftIncident, | |
| Event, | |
| IncidentParams, | |
| SubmitTriage, | |
| TriageAction, | |
| ) | |
| from tasks.registry import STAGE_REGISTRY | |
| from verifier import check_plausibility, compute_ground_truth | |
# Which side is expected to act next / which side an observation is addressed to.
Role = Literal[
    "attacker",
    "defender",
]

# Episode flavour: a full attacker-then-defender episode, or a defender turn
# against an auto-generated incident.
Mode = Literal[
    "self_play",
    "defender_only",
]
| # --------------------------------------------------------------------------- | |
| # Public observation / state types | |
| # --------------------------------------------------------------------------- | |
class AttackerBrief(BaseModel):
    """Instructions handed to the attacker at the start of a self-play episode.

    The brief is a hint only: the environment tells the attacker it may ignore
    ``target_label`` (see the reset feedback message in ``OpenSOCEnv.reset``).
    """

    # Triage label the crafted incident is suggested to have as ground truth.
    target_label: TriageAction

    # Difficulty string copied from the stage's registry entry.
    difficulty: str

    # Free-form category hint; "any" leaves the choice to the attacker.
    category_hint: str = "any"
class Observation(BaseModel):
    """What the acting agent gets to see on a given turn.

    Attacker observations carry ``attacker_brief``; defender observations
    carry ``alert`` and ``log_window``. Fields that do not apply to the
    current role keep their empty defaults.
    """

    # Role the observation is addressed to.
    role: Role

    # --- defender-facing payload -------------------------------------------
    alert: Optional[Alert] = None
    log_window: List[Event] = Field(default_factory=list)

    # --- attacker-facing payload -------------------------------------------
    attacker_brief: Optional[AttackerBrief] = None

    # --- episode progress ---------------------------------------------------
    step: int = 0
    max_steps: int = 2

    # Human-readable message describing the outcome of the previous action
    # (or the task prompt on the first observation).
    last_action_feedback: str = ""

    # True once the episode has terminated.
    done: bool = False
class EpisodeState(BaseModel):
    """Complete server-side record of one episode, as exposed by /state.

    Unlike ``Observation`` this includes hidden information (ground truth,
    plausibility verdict, both reward breakdowns).
    """

    # --- episode identity and progress ---------------------------------------
    task_id: str
    mode: Mode
    step: int = 0
    max_steps: int = 2
    done: bool = False
    # Role expected to act next (left at "defender" once the episode is over).
    role: Role

    # --- attacker turn: input and materialized incident ----------------------
    attacker_brief: Optional[AttackerBrief] = None
    incident_alert: Optional[Alert] = None
    incident_log_window: List[Event] = Field(default_factory=list)
    triggering_log_id: Optional[str] = None

    # --- verifier verdicts -----------------------------------------------------
    plausible: Optional[bool] = None
    plausibility_reason: str = ""
    schema_violation: bool = False
    ground_truth: Optional[TriageAction] = None

    # --- defender turn and scoring ---------------------------------------------
    defender_action: Optional[SubmitTriage] = None
    defender_reward: Optional[float] = None
    defender_breakdown: Dict[str, float] = Field(default_factory=dict)
    attacker_reward: Optional[float] = None
    attacker_breakdown: Dict[str, float] = Field(default_factory=dict)
    cumulative_reward: float = 0.0

    # Wall-clock timestamp (``time.time()``) taken when the state is created.
    started_at: float = Field(default_factory=time.time)
| # --------------------------------------------------------------------------- | |
| # Environment | |
| # --------------------------------------------------------------------------- | |
class OpenSOCEnv:
    """Two-role SOC triage environment with deterministic verifier rewards.

    Usage follows the gym convention: call ``reset()`` to obtain the first
    observation, then ``step(action)`` until ``done`` is True.

    * ``self_play``: turn 1 is the attacker's ``craft_incident``, turn 2 is
      the defender's ``submit_triage``.
    * ``defender_only``: ``reset()`` generates the incident itself and the
      single ``step()`` is the defender's ``submit_triage``.

    The label used for scoring always comes from
    ``verifier.compute_ground_truth``; the attacker-supplied ``target_label``
    and ``narrative`` are never consulted for it.
    """

    MAX_STEPS = 2

    def __init__(
        self,
        task_id: str = "stage1_basic",
        mode: Mode = "self_play",
        seed: int = 0,
    ):
        """Create an environment bound to one stage and one mode.

        Args:
            task_id: Key into ``STAGE_REGISTRY`` selecting the stage.
            mode: ``"self_play"`` or ``"defender_only"``.
            seed: Base seed; combined with the episode counter and the
                stage's ``seed_offset`` on every ``reset()``.

        Raises:
            ValueError: If ``task_id`` or ``mode`` is not recognised.
        """
        if task_id not in STAGE_REGISTRY:
            raise ValueError(
                f"Unknown task '{task_id}'. Choose from: {list(STAGE_REGISTRY)}"
            )
        if mode not in ("self_play", "defender_only"):
            raise ValueError(f"Unknown mode {mode!r}")
        self.task_id = task_id
        self.mode: Mode = mode
        self.seed = seed
        self._state: Optional[EpisodeState] = None
        self._episode_idx = 0

    # ------------------------------------------------------------------
    # Gym-style API: reset / step / state / grade
    # ------------------------------------------------------------------
    def reset(self) -> Observation:
        """Start a fresh episode and return the first observation.

        Returns:
            An attacker observation carrying the brief in ``self_play`` mode,
            or a defender observation carrying the generated alert and log
            window in ``defender_only`` mode.
        """
        self._episode_idx += 1
        stage = STAGE_REGISTRY[self.task_id]
        # A distinct seed per (base seed, episode index, stage offset).
        episode_seed = self.seed * 100_000 + self._episode_idx + stage["seed_offset"]

        if self.mode == "defender_only":
            params = generate_incident(self.task_id, seed=episode_seed)
            return self._materialize_for_defender(params, started_role="defender")

        # self_play: the next /step must be the attacker's craft_incident.
        # We seed the brief with a target label that's representative of the
        # stage's distribution, but the attacker is free to ignore it.
        target_label = self._sample_target_label_for_brief(episode_seed)
        brief = AttackerBrief(
            target_label=target_label,
            difficulty=stage["difficulty"],
            category_hint="any",
        )
        self._state = EpisodeState(
            task_id=self.task_id,
            mode=self.mode,
            role="attacker",
            attacker_brief=brief,
            max_steps=self.MAX_STEPS,
        )
        return Observation(
            role="attacker",
            attacker_brief=brief,
            step=0,
            max_steps=self.MAX_STEPS,
            last_action_feedback=(
                f"[stage={self.task_id}] Craft an incident whose ground truth "
                f"is action={target_label.value}. Ignore the target_label hint "
                f"if you can fool the defender harder with a different one."
            ),
        )

    def step(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Apply one agent action; return (obs, reward, done, info).

        The action is routed by the role currently expected to act. The
        returned reward is 0.0 after a valid attacker turn, the attacker's
        reward if the attacker turn is aborted, and the defender's reward
        after the defender turn.

        Raises:
            RuntimeError: If called before ``reset()`` or after the episode
                has finished.
        """
        if self._state is None:
            raise RuntimeError("Call reset() before step()")
        if self._state.done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")
        s = self._state
        s.step += 1
        if s.role == "attacker":
            return self._step_attacker(action)
        return self._step_defender(action)

    def state(self) -> Dict[str, Any]:
        """Return the full internal state as a JSON-compatible dict.

        Returns an empty dict when no episode has been started yet.
        """
        if self._state is None:
            return {}
        return self._state.model_dump(mode="json")

    def grade(self) -> float:
        """Return a normalized [0, 1] score for the just-finished episode.

        Returns 0.0 when there is no episode, the episode is still running,
        or no defender reward was recorded. Note that an aborted attacker turn
        records ``defender_reward=0.0``, which maps to 1.0 / 2.1 here.
        """
        s = self._state
        if s is None or not s.done:
            return 0.0
        if s.defender_reward is None:
            return 0.0
        # Normalize defender reward to [0, 1] using the manifest range.
        # NOTE(review): [-1.0, 1.1] (max correct + bonus) is taken from the
        # original comment; confirm against rubric.score_defender.
        lo, hi = -1.0, 1.1
        clamped = max(lo, min(hi, s.defender_reward))
        return float((clamped - lo) / (hi - lo))

    # ------------------------------------------------------------------
    # Attacker turn
    # ------------------------------------------------------------------
    def _step_attacker(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Handle the attacker's ``craft_incident`` and hand over to the defender.

        A missing ``craft_incident`` or a pydantic ``ValidationError`` while
        building ``IncidentParams`` aborts the episode via
        ``_abort_attacker_turn``. Otherwise the incident is evaluated, stored
        on the state, and a defender observation is returned with reward 0.0.
        """
        s = self._state
        ci: Optional[CraftIncident] = action.craft_incident
        if ci is None:
            # Treated as a schema violation: the episode ends immediately
            # because we have nothing to show the defender.
            return self._abort_attacker_turn(
                "Attacker turn requires craft_incident; got something else."
            )
        try:
            params = IncidentParams(
                target_label=ci.target_label,
                category=ci.category,
                events=ci.events,
                narrative=ci.narrative,
            )
        except ValidationError as exc:
            return self._abort_attacker_turn(f"Schema violation: {exc}")

        # Evaluate everything first so a failure in the verifier or alert
        # construction cannot leave the state half-switched to the defender.
        plausible, reason, triggering_log_id, gt_label, alert = (
            self._evaluate_incident(params)
        )

        s.role = "defender"
        s.plausible = plausible
        s.plausibility_reason = reason
        s.ground_truth = gt_label
        s.triggering_log_id = triggering_log_id
        s.incident_alert = alert
        # State and observation each get their own list copy.
        s.incident_log_window = list(params.events)

        feedback = (
            f"Attacker turn complete. plausible={plausible} ({reason}). "
            "Defender will now triage."
        )
        obs = Observation(
            role="defender",
            alert=alert,
            log_window=list(params.events),
            step=s.step,
            max_steps=self.MAX_STEPS,
            last_action_feedback=feedback,
            done=False,
        )
        info = {
            "role_just_acted": "attacker",
            "plausible": plausible,
            "plausibility_reason": reason,
            "ground_truth_hidden_from_defender": gt_label.value,
            "triggering_log_id": triggering_log_id,
        }
        return obs, 0.0, False, info

    def _abort_attacker_turn(self, reason: str) -> tuple[Observation, float, bool, dict]:
        """End the episode after an invalid attacker turn.

        Marks the state as a schema violation, scores the attacker with
        ``plausible=False, schema_violation=True``, records a defender reward
        of 0.0 (no defender action is stored), and returns a terminal
        observation whose step reward is the attacker's reward.
        """
        s = self._state
        s.schema_violation = True
        s.plausible = False
        s.plausibility_reason = reason
        attacker_reward, attacker_bd = score_attacker(
            plausible=False,
            schema_violation=True,
            defender_correct=False,
            novelty=0.0,
        )
        s.attacker_reward = attacker_reward
        s.attacker_breakdown = attacker_bd
        s.defender_reward = 0.0
        s.cumulative_reward = attacker_reward
        s.done = True
        s.role = "defender"
        obs = Observation(
            role="defender",
            step=s.step,
            max_steps=self.MAX_STEPS,
            last_action_feedback=f"Episode aborted: {reason}",
            done=True,
        )
        info = {
            "role_just_acted": "attacker",
            "schema_violation": True,
            "attacker_reward": attacker_reward,
            "attacker_breakdown": attacker_bd,
        }
        return obs, attacker_reward, True, info

    # ------------------------------------------------------------------
    # Defender turn
    # ------------------------------------------------------------------
    def _step_defender(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Score the defender's ``submit_triage`` and terminate the episode.

        In ``self_play`` mode the attacker is scored as well, based on the
        stored plausibility verdict and whether the defender was correct. The
        step reward returned is the defender's reward.
        """
        s = self._state
        st = action.submit_triage
        if st is None:
            # A missing submission is scored as a 'dismiss'.
            # NOTE(review): this cites "L1-0", which is also the fallback
            # triggering id below, so a missing submission may still earn
            # citation credit. Confirm whether SubmitTriage allows an empty
            # citation and use that instead if so.
            st = SubmitTriage(action=TriageAction.DISMISS, cited_log_id="L1-0")

        # Resolve the label once so the reward, the correctness flag, the
        # feedback text and the info dict all agree on it.
        ground_truth = s.ground_truth or TriageAction.DISMISS

        defender_reward, defender_bd = score_defender(
            action=st.action,
            ground_truth=ground_truth,
            triggering_log_id=s.triggering_log_id or "L1-0",
            cited_log_id=st.cited_log_id,
        )
        defender_correct = st.action == ground_truth

        attacker_reward, attacker_bd = 0.0, {}
        if s.mode == "self_play":
            attacker_reward, attacker_bd = score_attacker(
                plausible=bool(s.plausible),
                schema_violation=False,
                defender_correct=defender_correct,
                novelty=0.0,  # filled in by the trainer if it tracks batches
            )

        s.defender_action = st
        s.defender_reward = defender_reward
        s.defender_breakdown = defender_bd
        s.attacker_reward = attacker_reward
        s.attacker_breakdown = attacker_bd
        s.cumulative_reward = defender_reward + attacker_reward
        s.done = True
        s.role = "defender"

        feedback = (
            f"Defender chose {st.action.value}; ground truth was "
            f"{ground_truth.value}. "
            f"Reward={defender_reward:+.2f}."
        )
        obs = Observation(
            role="defender",
            alert=s.incident_alert,
            log_window=list(s.incident_log_window),
            step=s.step,
            max_steps=self.MAX_STEPS,
            last_action_feedback=feedback,
            done=True,
        )
        info = {
            "role_just_acted": "defender",
            "ground_truth": ground_truth.value,
            "defender_correct": defender_correct,
            "defender_breakdown": defender_bd,
            "attacker_reward": attacker_reward,
            "attacker_breakdown": attacker_bd,
            "triggering_log_id": s.triggering_log_id,
        }
        return obs, defender_reward, True, info

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _evaluate_incident(
        params: IncidentParams,
    ) -> tuple[bool, str, Optional[str], TriageAction, Alert]:
        """Run the verifier on ``params`` and build the defender-facing alert.

        Shared by the attacker turn and defender-only reset so both paths
        derive plausibility, ground truth and the alert the same way.

        Returns:
            ``(plausible, reason, triggering_log_id, ground_truth, alert)``.
        """
        plausible, reason, triggering_log_id = check_plausibility(params)
        gt_label, _ = compute_ground_truth(params)
        # Alert ids are random per call; they are not derived from the seed.
        alert = make_alert(params, alert_id=f"A-{uuid.uuid4().hex[:8]}")
        return plausible, reason, triggering_log_id, gt_label, alert

    def _materialize_for_defender(
        self, params: IncidentParams, *, started_role: Role
    ) -> Observation:
        """Set up state for a defender_only episode (skip attacker turn).

        Args:
            params: The incident to present to the defender.
            started_role: Currently unused; the state and observation are
                always created with role ``"defender"``. Kept so existing
                callers keep working.

        Returns:
            The defender's first observation (step 0).
        """
        plausible, reason, triggering_log_id, gt_label, alert = (
            self._evaluate_incident(params)
        )
        self._state = EpisodeState(
            task_id=self.task_id,
            mode=self.mode,
            role="defender",
            incident_alert=alert,
            incident_log_window=list(params.events),
            triggering_log_id=triggering_log_id,
            plausible=plausible,
            plausibility_reason=reason,
            ground_truth=gt_label,
            max_steps=self.MAX_STEPS,
        )
        return Observation(
            role="defender",
            alert=alert,
            log_window=list(params.events),
            step=0,
            max_steps=self.MAX_STEPS,
            last_action_feedback=(
                f"[stage={self.task_id}, defender_only] Triage this alert."
            ),
        )

    def _sample_target_label_for_brief(self, seed: int) -> TriageAction:
        """Pick a brief target label from the stage's label distribution.

        Sampling uses a private ``random.Random(seed)``, so the same seed
        yields the same label and the global RNG is left untouched.
        """
        # Reuse the generator's stage config so brief and defender-only
        # generation are coherent.
        from generator import STAGE_CONFIGS  # local import avoids cycle
        import random as _random

        distribution = STAGE_CONFIGS[self.task_id]["label_distribution"]
        labels = list(distribution)
        weights = [distribution[label] for label in labels]
        return _random.Random(seed).choices(labels, weights=weights, k=1)[0]
| __all__ = [ | |
| "AttackerBrief", | |
| "Action", | |
| "Observation", | |
| "EpisodeState", | |
| "OpenSOCEnv", | |
| ] | |