# opensoc-env / env.py
# Repository page header (not code): shivam2k3's picture · "OpenSOC v1" · bb6a031
"""
env.py — `OpenSOCEnv`, the two-role gym-style environment.
Lifecycle
---------
An OpenSOC episode has *exactly two turns*:
Turn 1 (attacker): observation has role="attacker" with `attacker_brief`.
The agent submits `craft_incident` with structured
params. The env validates the params, runs the
plausibility checker, and computes ground truth.
Turn 2 (defender): observation has role="defender" with the materialized
`alert` and `log_window`. The agent submits
`submit_triage`. The env scores both sides and
terminates the episode.
In `defender_only` mode, the env auto-generates the incident with
`generator.generate_incident` and skips straight to turn 2 — useful for
SFT, eval, and smoke tests.
Mode selection happens via `OpenSOCEnv(mode=...)` or the `?mode=` query
param on `/reset`.
Anti-hack invariants
--------------------
1. The ground-truth label that drives defender reward is computed by
`verifier.compute_ground_truth(params)`, never read from `narrative`
or `target_label`.
2. The attacker's reward is gated on `verifier.check_plausibility(params)`.
3. Schema validation (pydantic) errors → schema_violation=True →
attacker reward floor of -0.5, *no* defender turn (env auto-dismisses).
"""
from __future__ import annotations
import time
import uuid
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field, ValidationError
from generator import generate_incident, make_alert
from rubric import score_attacker, score_defender
from schema import (
Action,
Alert,
CraftIncident,
Event,
IncidentParams,
SubmitTriage,
TriageAction,
)
from tasks.registry import STAGE_REGISTRY
from verifier import check_plausibility, compute_ground_truth
# Which agent a turn/observation belongs to.
Role = Literal["attacker", "defender"]
# Episode flow: "self_play" starts with an attacker turn followed by the
# defender turn; "defender_only" auto-generates the incident in reset() and
# starts directly at the defender turn.
Mode = Literal["self_play", "defender_only"]
# ---------------------------------------------------------------------------
# Public observation / state types
# ---------------------------------------------------------------------------
class AttackerBrief(BaseModel):
    """What the env tells the attacker to produce.

    Built by `OpenSOCEnv.reset()` in self-play mode and attached both to the
    first (attacker) observation and to the episode state.
    """

    # Suggested ground-truth label for the crafted incident. reset()'s
    # feedback text tells the attacker it may ignore this hint.
    target_label: TriageAction
    # Difficulty string copied from the stage's STAGE_REGISTRY entry.
    difficulty: str
    # Category hint; reset() currently always passes "any".
    category_hint: str = "any"
class Observation(BaseModel):
    """Per-turn observation visible to the agent.

    Attacker observations carry `attacker_brief`; defender observations carry
    the materialized `alert` and `log_window`.
    """

    # Role the observation is addressed to.
    role: Role
    # Alert produced by `make_alert`; left as None on attacker observations
    # and on aborted-episode observations.
    alert: Optional[Alert] = None
    # Events of the incident shown to the defender (empty for the attacker).
    log_window: List[Event] = Field(default_factory=list)
    # Brief for the attacker; only set on the first self-play observation.
    attacker_brief: Optional[AttackerBrief] = None
    # Number of actions applied so far in this episode.
    step: int = 0
    # Episode length; the env fills this from OpenSOCEnv.MAX_STEPS.
    max_steps: int = 2
    # Human-readable message describing the prompt or the last action's outcome.
    last_action_feedback: str = ""
    # True once the episode has terminated.
    done: bool = False
class EpisodeState(BaseModel):
    """Full internal state returned by /state.

    Unlike `Observation`, this includes hidden information such as the
    ground-truth label and both reward breakdowns.
    """

    # Stage identifier and mode the episode was created with.
    task_id: str
    mode: Mode
    # Number of actions applied so far / episode length.
    step: int = 0
    max_steps: int = 2
    # True once the episode has terminated (defender scored or attacker aborted).
    done: bool = False
    # Role expected to act next; set to "defender" after the attacker turn,
    # and left at "defender" once the episode is done.
    role: Role
    # Brief handed to the attacker in self-play; None in defender_only mode.
    attacker_brief: Optional[AttackerBrief] = None
    # Materialized incident shown to the defender.
    incident_alert: Optional[Alert] = None
    incident_log_window: List[Event] = Field(default_factory=list)
    # Third value returned by `check_plausibility`; passed to `score_defender`
    # to compare against the defender's cited log id.
    triggering_log_id: Optional[str] = None
    # Plausibility verdict and reason from `check_plausibility` (or the abort
    # reason when the attacker turn was aborted).
    plausible: Optional[bool] = None
    plausibility_reason: str = ""
    # Set when the attacker turn was aborted (missing craft_incident or a
    # pydantic ValidationError while building IncidentParams).
    schema_violation: bool = False
    # Label computed by `compute_ground_truth`; drives the defender reward.
    ground_truth: Optional[TriageAction] = None
    # Defender's submitted (or defaulted) triage and its score.
    defender_action: Optional[SubmitTriage] = None
    defender_reward: Optional[float] = None
    defender_breakdown: Dict[str, float] = Field(default_factory=dict)
    # Attacker score; stays 0.0 / {} in defender_only mode.
    attacker_reward: Optional[float] = None
    attacker_breakdown: Dict[str, float] = Field(default_factory=dict)
    # Sum of the rewards assigned when the episode ended.
    cumulative_reward: float = 0.0
    # Wall-clock creation time of the state (time.time()).
    started_at: float = Field(default_factory=time.time)
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
class OpenSOCEnv:
    """Two-role SOC triage environment with deterministic verifier rewards.

    An episode has at most two turns: an attacker turn (`craft_incident`)
    followed by a defender turn (`submit_triage`). In `defender_only` mode the
    incident is generated by `generate_incident` and the episode starts at the
    defender turn.

    Typical use::

        env = OpenSOCEnv(task_id="stage1_basic", mode="self_play")
        obs = env.reset()
        obs, reward, done, info = env.step(action)
    """

    MAX_STEPS = 2
    # (lo, hi) bounds of the defender reward used by grade() for
    # normalization: -1.0 worst case, 1.1 for a correct action plus bonus.
    DEFENDER_REWARD_RANGE = (-1.0, 1.1)

    def __init__(
        self,
        task_id: str = "stage1_basic",
        mode: Mode = "self_play",
        seed: int = 0,
    ):
        """Create an environment bound to one stage and one mode.

        Args:
            task_id: Key into `STAGE_REGISTRY`.
            mode: "self_play" or "defender_only".
            seed: Base seed; combined with the episode index and the stage's
                seed offset on every reset().

        Raises:
            ValueError: If `task_id` or `mode` is not recognised.
        """
        if task_id not in STAGE_REGISTRY:
            raise ValueError(
                f"Unknown task '{task_id}'. Choose from: {list(STAGE_REGISTRY)}"
            )
        if mode not in ("self_play", "defender_only"):
            raise ValueError(f"Unknown mode {mode!r}")
        self.task_id = task_id
        self.mode: Mode = mode
        self.seed = seed
        self._state: Optional[EpisodeState] = None
        self._episode_idx = 0

    # ------------------------------------------------------------------
    # Gym-style API: reset / step / state / grade
    # ------------------------------------------------------------------
    def reset(self) -> Observation:
        """Start a fresh episode and return the first observation.

        Returns:
            A defender observation in `defender_only` mode, otherwise an
            attacker observation carrying the `AttackerBrief`.
        """
        self._episode_idx += 1
        episode_seed = (
            self.seed * 100_000
            + self._episode_idx
            + STAGE_REGISTRY[self.task_id]["seed_offset"]
        )
        if self.mode == "defender_only":
            params = generate_incident(self.task_id, seed=episode_seed)
            return self._materialize_for_defender(params, started_role="defender")

        # self_play: the next /step must be the attacker's craft_incident.
        # We seed the brief with a target label that's representative of the
        # stage's distribution, but the attacker is free to ignore it.
        target_label = self._sample_target_label_for_brief(episode_seed)
        brief = AttackerBrief(
            target_label=target_label,
            difficulty=STAGE_REGISTRY[self.task_id]["difficulty"],
            category_hint="any",
        )
        self._state = EpisodeState(
            task_id=self.task_id,
            mode=self.mode,
            role="attacker",
            attacker_brief=brief,
            max_steps=self.MAX_STEPS,
        )
        return Observation(
            role="attacker",
            attacker_brief=brief,
            step=0,
            max_steps=self.MAX_STEPS,
            last_action_feedback=(
                f"[stage={self.task_id}] Craft an incident whose ground truth "
                f"is action={target_label.value}. Ignore the target_label hint "
                f"if you can fool the defender harder with a different one."
            ),
        )

    def step(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Apply one agent action; return (obs, reward, done, info).

        The action is routed to the attacker or defender handler according to
        the role recorded in the episode state.

        Raises:
            RuntimeError: If reset() has not been called, or the episode has
                already terminated.
        """
        if self._state is None:
            raise RuntimeError("Call reset() before step()")
        if self._state.done:
            raise RuntimeError("Episode is done. Call reset() to start a new one.")
        s = self._state
        s.step += 1
        if s.role == "attacker":
            return self._step_attacker(action)
        return self._step_defender(action)

    def state(self) -> Dict[str, Any]:
        """Return the full internal state as a JSON-compatible dict.

        Returns an empty dict when no episode has been started yet.
        """
        if self._state is None:
            return {}
        return self._state.model_dump(mode="json")

    def grade(self) -> float:
        """Return a normalized [0, 1] score for the just-finished episode.

        The defender reward is clamped to `DEFENDER_REWARD_RANGE` and mapped
        linearly onto [0, 1]. Returns 0.0 when there is no episode, the
        episode is still running, or no defender reward was recorded.
        """
        s = self._state
        if s is None or not s.done:
            return 0.0
        if s.defender_reward is None:
            return 0.0
        lo, hi = self.DEFENDER_REWARD_RANGE
        clamped = max(lo, min(hi, s.defender_reward))
        return float((clamped - lo) / (hi - lo))

    # ------------------------------------------------------------------
    # Attacker turn
    # ------------------------------------------------------------------
    def _step_attacker(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Validate the attacker's incident and hand the turn to the defender.

        A missing `craft_incident` or a pydantic `ValidationError` while
        building `IncidentParams` aborts the episode via
        `_abort_attacker_turn`. Otherwise the attacker's reward for this step
        is 0.0; it is settled after the defender acts.
        """
        s = self._state
        ci: Optional[CraftIncident] = action.craft_incident
        if ci is None:
            # Treated as a schema violation: the episode ends immediately
            # because we have nothing to show the defender.
            return self._abort_attacker_turn(
                "Attacker turn requires craft_incident; got something else."
            )
        try:
            params = IncidentParams(
                target_label=ci.target_label,
                category=ci.category,
                events=ci.events,
                narrative=ci.narrative,
            )
        except ValidationError as exc:
            return self._abort_attacker_turn(f"Schema violation: {exc}")

        plausible, reason, triggering_log_id, gt_label, alert = (
            self._evaluate_incident(params)
        )
        s.role = "defender"
        s.plausible = plausible
        s.plausibility_reason = reason
        s.ground_truth = gt_label
        s.triggering_log_id = triggering_log_id
        s.incident_alert = alert
        s.incident_log_window = list(params.events)

        feedback = (
            f"Attacker turn complete. plausible={plausible} ({reason}). "
            "Defender will now triage."
        )
        obs = Observation(
            role="defender",
            alert=alert,
            log_window=list(params.events),
            step=s.step,
            max_steps=self.MAX_STEPS,
            last_action_feedback=feedback,
            done=False,
        )
        info = {
            "role_just_acted": "attacker",
            "plausible": plausible,
            "plausibility_reason": reason,
            "ground_truth_hidden_from_defender": gt_label.value,
            "triggering_log_id": triggering_log_id,
        }
        return obs, 0.0, False, info

    def _abort_attacker_turn(self, reason: str) -> tuple[Observation, float, bool, dict]:
        """Terminate the episode after an invalid attacker action.

        Marks the state as a schema violation, scores the attacker with
        `score_attacker(schema_violation=True, ...)`, records a defender
        reward of 0.0 (the defender never acts) and returns a terminal
        observation.
        """
        s = self._state
        s.schema_violation = True
        s.plausible = False
        s.plausibility_reason = reason
        attacker_reward, attacker_bd = score_attacker(
            plausible=False,
            schema_violation=True,
            defender_correct=False,
            novelty=0.0,
        )
        s.attacker_reward = attacker_reward
        s.attacker_breakdown = attacker_bd
        s.defender_reward = 0.0
        s.cumulative_reward = attacker_reward
        s.done = True
        s.role = "defender"
        return (
            Observation(
                role="defender",
                step=s.step,
                max_steps=self.MAX_STEPS,
                last_action_feedback=f"Episode aborted: {reason}",
                done=True,
            ),
            attacker_reward,
            True,
            {
                "role_just_acted": "attacker",
                "schema_violation": True,
                "attacker_reward": attacker_reward,
                "attacker_breakdown": attacker_bd,
            },
        )

    # ------------------------------------------------------------------
    # Defender turn
    # ------------------------------------------------------------------
    def _step_defender(self, action: Action) -> tuple[Observation, float, bool, dict]:
        """Score the defender's triage, settle both rewards and end the episode.

        The returned step reward is the defender's reward; the attacker's
        reward (self-play only) is reported through `info` and the state.
        """
        s = self._state
        st = action.submit_triage
        if st is None:
            # Treat as a missed-malicious-equivalent: act as if the defender
            # chose 'dismiss'.
            # NOTE(review): the placeholder citation "L1-0" equals the
            # fallback triggering id used below when none was recorded, and a
            # defaulted dismiss on a benign incident is scored like a real
            # dismiss — confirm against the rubric that this is intended.
            st = SubmitTriage(action=TriageAction.DISMISS, cited_log_id="L1-0")

        # Resolve the ground truth once so the reward, the correctness flag
        # and the reported label can never disagree.
        ground_truth = s.ground_truth or TriageAction.DISMISS
        defender_reward, defender_bd = score_defender(
            action=st.action,
            ground_truth=ground_truth,
            triggering_log_id=s.triggering_log_id or "L1-0",
            cited_log_id=st.cited_log_id,
        )
        defender_correct = st.action == ground_truth

        attacker_reward, attacker_bd = 0.0, {}
        if s.mode == "self_play":
            attacker_reward, attacker_bd = score_attacker(
                plausible=bool(s.plausible),
                schema_violation=False,
                defender_correct=defender_correct,
                novelty=0.0,  # filled in by the trainer if it tracks batches
            )

        s.defender_action = st
        s.defender_reward = defender_reward
        s.defender_breakdown = defender_bd
        s.attacker_reward = attacker_reward
        s.attacker_breakdown = attacker_bd
        s.cumulative_reward = defender_reward + attacker_reward
        s.done = True
        s.role = "defender"

        feedback = (
            f"Defender chose {st.action.value}; ground truth was "
            f"{ground_truth.value}. "
            f"Reward={defender_reward:+.2f}."
        )
        obs = Observation(
            role="defender",
            alert=s.incident_alert,
            log_window=list(s.incident_log_window),
            step=s.step,
            max_steps=self.MAX_STEPS,
            last_action_feedback=feedback,
            done=True,
        )
        info = {
            "role_just_acted": "defender",
            "ground_truth": ground_truth.value,
            "defender_correct": defender_correct,
            "defender_breakdown": defender_bd,
            "attacker_reward": attacker_reward,
            "attacker_breakdown": attacker_bd,
            "triggering_log_id": s.triggering_log_id,
        }
        return obs, defender_reward, True, info

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _evaluate_incident(
        self, params: IncidentParams
    ) -> tuple[bool, str, Optional[str], TriageAction, Alert]:
        """Run the verifier on `params` and materialize its alert.

        Returns:
            (plausible, reason, triggering_log_id, ground_truth_label, alert).
            The label comes from `compute_ground_truth(params)`, not from the
            attacker-supplied target label or narrative.
        """
        plausible, reason, triggering_log_id = check_plausibility(params)
        gt_label, _ = compute_ground_truth(params)
        alert = make_alert(params, alert_id=f"A-{uuid.uuid4().hex[:8]}")
        return plausible, reason, triggering_log_id, gt_label, alert

    def _materialize_for_defender(
        self, params: IncidentParams, *, started_role: Role
    ) -> Observation:
        """Set up state for a defender_only episode (skip attacker turn).

        Args:
            params: Incident to show to the defender.
            started_role: Currently unused; the state and observation are
                always created with role "defender". Kept so existing callers
                keep working.
        """
        plausible, reason, triggering_log_id, gt_label, alert = (
            self._evaluate_incident(params)
        )
        self._state = EpisodeState(
            task_id=self.task_id,
            mode=self.mode,
            role="defender",
            incident_alert=alert,
            incident_log_window=list(params.events),
            triggering_log_id=triggering_log_id,
            plausible=plausible,
            plausibility_reason=reason,
            ground_truth=gt_label,
            max_steps=self.MAX_STEPS,
        )
        return Observation(
            role="defender",
            alert=alert,
            log_window=list(params.events),
            step=0,
            max_steps=self.MAX_STEPS,
            last_action_feedback=(
                f"[stage={self.task_id}, defender_only] Triage this alert."
            ),
        )

    def _sample_target_label_for_brief(self, seed: int) -> TriageAction:
        """Pick a brief target label from the stage's label distribution.

        Sampling uses a private `random.Random(seed)`, so the same seed always
        yields the same label and the global RNG is left untouched.
        """
        # Reuse the generator's stage config so brief and defender-only
        # generation are coherent.
        from generator import STAGE_CONFIGS  # local import avoids cycle
        import random as _random

        distribution = STAGE_CONFIGS[self.task_id]["label_distribution"]
        labels = list(distribution)
        weights = [distribution[label] for label in labels]
        return _random.Random(seed).choices(labels, weights=weights, k=1)[0]
# Public API of this module. `Action` is defined in `schema` and re-exported
# here alongside the observation/state models and the environment class.
__all__ = [
    "AttackerBrief",
    "Action",
    "Observation",
    "EpisodeState",
    "OpenSOCEnv",
]