| """ |
| schema.py — OpenSOC incident & action schema. |
| |
| This module is the single source of truth for what the attacker is allowed to |
| emit and what the defender is allowed to respond with. The verifier and |
| rubric both depend on the constraints here; they should never be relaxed |
| without updating the corresponding tests in `tests/test_schema.py` and |
| `tests/test_verifier.py`. |
| |
| Design principles |
| ----------------- |
| 1. The *attacker* controls structured parameters (event types, field values), |
| never the ground-truth label directly. The label is derived deterministically |
| from the params by `verifier.compute_ground_truth` so that the reward can |
| never be hacked by attacker text. |
| 2. Every event has a stable `log_id` of the form `L<turn>-<n>` so that the |
| defender can cite a triggering event and earn a small bonus. This is |
| regex-verifiable. |
| 3. The defender's action set is a fixed enum of five SOC responses ranked by |
| "cost" (dismiss = cheapest, escalate = most expensive). This lets the |
| rubric grade over- vs under-reaction. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import enum |
| import ipaddress |
| import re |
| from typing import Any, Dict, List, Optional |
|
|
| from pydantic import BaseModel, Field, field_validator, model_validator |
|
|
|
|
| |
| |
| |
|
|
class TriageAction(str, enum.Enum):
    # Closed set of responses the defender may submit. Members also subclass
    # ``str``, so each one compares equal to its wire value. Iteration follows
    # the declaration order below; per-action cost is kept in ACTION_COST.
    DISMISS = "dismiss"
    MONITOR = "monitor"
    # The next two members are the ones grouped in CONTAINMENT_ACTIONS.
    QUARANTINE_HOST = "quarantine_host"
    BLOCK_IP = "block_ip"
    ESCALATE = "escalate"
|
|
|
|
| |
| |
# Relative cost of every defender response, used to grade over- vs
# under-reaction. Each inner tuple is one cost tier and the tier's position
# is the cost, so actions listed together share the same cost.
ACTION_COST: Dict[TriageAction, int] = {
    action: cost
    for cost, tier in enumerate(
        (
            (TriageAction.DISMISS,),
            (TriageAction.MONITOR,),
            (TriageAction.QUARANTINE_HOST, TriageAction.BLOCK_IP),
            (TriageAction.ESCALATE,),
        )
    )
    for action in tier
}


# The two responses that share the cost-2 tier above.
CONTAINMENT_ACTIONS = {
    TriageAction.QUARANTINE_HOST,
    TriageAction.BLOCK_IP,
}
|
|
|
|
| |
| |
| |
|
|
class EventType(str, enum.Enum):
    # Closed vocabulary of event types an Event may carry. Values are
    # namespaced as "<family>.<name>"; members also subclass ``str`` so they
    # compare equal to their wire value.

    # --- auth.* ---------------------------------------------------------
    AUTH_LOGIN_SUCCESS = "auth.login_success"
    AUTH_LOGIN_FAILURE = "auth.login_failure"
    AUTH_PASSWORD_RESET = "auth.password_reset"
    AUTH_MFA_FAILURE = "auth.mfa_failure"
    AUTH_PRIVILEGE_GRANT = "auth.privilege_grant"

    # --- proc.* ---------------------------------------------------------
    PROC_START = "proc.start"
    PROC_PARENT_MISMATCH = "proc.parent_mismatch"
    PROC_LOLBIN = "proc.lolbin_use"

    # --- net.* ----------------------------------------------------------
    NET_OUTBOUND = "net.outbound_connection"
    NET_DNS_QUERY = "net.dns_query"
    NET_BEACON = "net.beacon"
    NET_PORT_SCAN_HIT = "net.port_scan_hit"

    # --- file.* ---------------------------------------------------------
    FILE_WRITE = "file.write"
    FILE_DELETE = "file.delete"
    FILE_RENAME_DOUBLE_EXT = "file.rename_double_ext"

    # --- email.* --------------------------------------------------------
    EMAIL_RECEIVED = "email.received"
    EMAIL_LINK_CLICKED = "email.link_clicked"
    EMAIL_ATTACHMENT_OPENED = "email.attachment_opened"

    # --- cloud.* --------------------------------------------------------
    CLOUD_API_CALL = "cloud.api_call"
    CLOUD_KEY_CREATED = "cloud.key_created"

    # --- edr.* ----------------------------------------------------------
    EDR_BEHAVIOR_MATCH = "edr.behavior_match"
|
|
|
|
| |
| |
class IncidentCategory(str, enum.Enum):
    # Closed set of incident categories used by Alert and by the attacker's
    # incident parameters. Members also subclass ``str`` and compare equal to
    # their wire value.

    # The single non-malicious category.
    BENIGN_NOISE = "benign_noise"

    # Attack categories.
    BRUTE_FORCE = "brute_force"
    PHISHING = "phishing"
    LATERAL_MOVEMENT = "lateral_movement"
    PRIVILEGE_ESCALATION = "privilege_escalation"
    DATA_EXFILTRATION = "data_exfiltration"
    MALWARE_EXECUTION = "malware_execution"
    C2_BEACON = "c2_beacon"
    INSIDER_DATA_ACCESS = "insider_data_access"
|
|
|
|
| |
| |
| |
|
|
# Severity vocabulary accepted on alerts.
SEVERITIES = ("info", "low", "medium", "high", "critical")

# Both patterns are anchored with ``\Z`` rather than ``$``: ``$`` also matches
# just before a trailing newline, which would let "L1-2\n" through validation.
# ``re.ASCII`` restricts ``\d`` to 0-9 so non-ASCII Unicode digits are rejected.

# Log ids: the letter L, a turn number, a dash, and an index (e.g. "L3-12").
LOG_ID_PATTERN = re.compile(r"^L\d+-\d+\Z", re.ASCII)

# Timestamps: "YYYY-MM-DDTHH:MM:SS" with optional fractional seconds and a
# mandatory trailing "Z". Only the shape is checked; calendar and clock ranges
# (e.g. month 13) are not validated by this pattern.
ISO_TS_PATTERN = re.compile(
    r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?Z\Z",
    re.ASCII,
)
|
|
| |
| |
# RFC 1918 private IPv4 ranges; `is_internal_ip` tests membership in these.
INTERNAL_NETS = [
    ipaddress.ip_network("10.0.0.0/8"),
    ipaddress.ip_network("172.16.0.0/12"),
    ipaddress.ip_network("192.168.0.0/16"),
]


# Lowercase executable names treated as living-off-the-land binaries.
# NOTE(review): presumably matched against process names by the verifier —
# confirm whether callers lowercase before lookup.
KNOWN_LOLBINS = {
    "powershell.exe", "pwsh.exe", "cmd.exe", "wmic.exe", "rundll32.exe",
    "regsvr32.exe", "mshta.exe", "certutil.exe", "bitsadmin.exe",
    "schtasks.exe", "wscript.exe", "cscript.exe",
}


# Lowercase parent-process names (office apps, mail client, browsers).
# NOTE(review): presumably a LOLBin spawned by one of these is considered
# suspicious — confirm against the verifier.
SUSPICIOUS_LOLBIN_PARENTS = {
    "winword.exe", "excel.exe", "powerpnt.exe", "outlook.exe",
    "chrome.exe", "msedge.exe", "firefox.exe",
}


def is_internal_ip(ip_str: str) -> bool:
    """Return True if `ip_str` is in any RFC1918 range.

    An IPv4-mapped IPv6 address (``::ffff:a.b.c.d``) is judged by the IPv4
    address it embeds, so a private address cannot be passed off as external
    just by writing it in mapped form.

    Args:
        ip_str: Textual IP address to classify.

    Returns:
        True when the address falls inside one of ``INTERNAL_NETS``; False
        otherwise, including when ``ip_str`` is not a parseable IP address.
    """
    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        # Unparseable input is treated as "not internal" rather than an error.
        return False

    # Only IPv6 addresses expose `ipv4_mapped`; it is None unless the address
    # is in ::ffff:0:0/96.
    mapped = getattr(ip, "ipv4_mapped", None)
    if mapped is not None:
        ip = mapped

    # Membership is False (no exception) when address and network versions
    # differ, so native IPv6 addresses simply fall through to False.
    return any(ip in net for net in INTERNAL_NETS)
|
|
|
|
| |
| |
| |
|
|
class Event(BaseModel):
    """A single structured log event."""

    # `log_id` and `timestamp` are required and format-checked by the
    # validators below. `source` defaults to "endpoint"; its description
    # lists six buckets but no validator restricts the value. `fields` is a
    # free-form payload whose keys and values are not validated here.
    log_id: str = Field(..., description="Stable id of the form 'L<turn>-<n>'.")
    timestamp: str = Field(..., description="ISO-8601 UTC timestamp.")
    source: str = Field(
        "endpoint",
        description="Logical source bucket: endpoint | network | identity | email | cloud | edr",
    )
    event_type: EventType
    fields: Dict[str, Any] = Field(default_factory=dict)

    @field_validator("log_id")
    @classmethod
    def _check_log_id(cls, value: str) -> str:
        """Accept `value` unchanged if it matches LOG_ID_PATTERN.

        Raises:
            ValueError: if the id does not have the ``L<turn>-<n>`` shape.
        """
        if LOG_ID_PATTERN.match(value):
            return value
        raise ValueError(f"log_id must match L<turn>-<n>, got {value!r}")

    @field_validator("timestamp")
    @classmethod
    def _check_ts(cls, value: str) -> str:
        """Accept `value` unchanged if it matches ISO_TS_PATTERN.

        Only the textual shape is checked; the string is not parsed into a
        datetime here.

        Raises:
            ValueError: if the timestamp does not match the pattern.
        """
        if ISO_TS_PATTERN.match(value):
            return value
        raise ValueError(f"timestamp must be ISO-8601 UTC, got {value!r}")
|
|
|
|
class Alert(BaseModel):
    """SIEM-style alert summary the defender sees first."""

    # `alert_id`, `category` and `summary` are required. `severity` defaults
    # to "medium" and must be a member of SEVERITIES; `host` and `user` fall
    # back to placeholder names when omitted.
    alert_id: str
    category: IncidentCategory
    severity: str = Field("medium", description="One of: info, low, medium, high, critical.")
    summary: str
    host: str = "host-001"
    user: str = "user-001"

    @field_validator("severity")
    @classmethod
    def _check_severity(cls, value: str) -> str:
        """Accept `value` unchanged if it is one of SEVERITIES.

        Raises:
            ValueError: if the severity is not in the allowed tuple.
        """
        if value in SEVERITIES:
            return value
        raise ValueError(f"severity must be one of {SEVERITIES}, got {value!r}")
|
|
|
|
class IncidentParams(BaseModel):
    """Parameters the attacker chooses; the env materializes these into an Incident.

    The triage label that ends up in the defender's reward is derived
    *deterministically* from the events here by `verifier.compute_ground_truth`.
    `target_label` is purely a shaping hint: if the attacker's events imply a
    different label than `target_label`, the schema validator rejects the
    incident (so the attacker cannot lie about its own intent).
    """

    # NOTE(review): the validators in this class check only list size, id
    # uniqueness and timestamp order. The target_label-vs-events consistency
    # check described in the docstring is not implemented here — confirm where
    # it is enforced.
    target_label: TriageAction
    category: IncidentCategory
    events: List[Event]
    narrative: str = Field("", description="Free-text scratchpad; ignored by the verifier.")

    @field_validator("events")
    @classmethod
    def _events_nonempty(cls, v: List[Event]) -> List[Event]:
        """Require between 1 and 32 events (inclusive).

        Raises:
            ValueError: if the list is empty or has more than 32 entries.
        """
        if not v:
            raise ValueError("events must contain at least one Event")
        if len(v) > 32:
            raise ValueError("events list capped at 32 entries")
        return v

    @model_validator(mode="after")
    def _events_have_unique_ids(self) -> "IncidentParams":
        """Reject incidents in which two events share a `log_id`.

        Raises:
            ValueError: if any `log_id` appears more than once.
        """
        ids = [e.log_id for e in self.events]
        if len(set(ids)) != len(ids):
            raise ValueError("event log_ids must be unique")
        return self

    @model_validator(mode="after")
    def _timestamps_monotonic(self) -> "IncidentParams":
        """Require event timestamps to be chronologically non-decreasing.

        Raw string comparison is not chronological once fractional seconds
        are involved: "." sorts before "Z", so "...:00.5Z" would compare as
        earlier than "...:00Z", and "...:00.50Z" as different from
        "...:00.5Z". Each timestamp is therefore reduced to a key of
        (whole-second prefix, fractional digits without trailing zeros).
        The prefix is fixed-width and a fraction's digit string orders like
        its numeric value, so comparing keys is chronological.

        Raises:
            ValueError: if any event is earlier than the one before it.
        """
        keys = []
        for event in self.events:
            stamp = event.timestamp.rstrip()
            # stamp[:19] is "YYYY-MM-DDTHH:MM:SS"; stamp[20:] is whatever
            # follows an optional "." — fraction digits plus the final "Z".
            fraction = stamp[20:].rstrip("Z").rstrip("0")
            keys.append((stamp[:19], fraction))

        if any(later < earlier for earlier, later in zip(keys, keys[1:])):
            raise ValueError("event timestamps must be non-decreasing")
        return self
|
|
|
|
class Incident(BaseModel):
    """Materialized incident the env shows to the defender."""

    # All three fields are required. `triggering_log_id` only has its format
    # checked below; this model does not verify that it refers to an event
    # present in `log_window`.
    alert: Alert
    log_window: List[Event]
    triggering_log_id: str = Field(
        ..., description="The log_id the verifier deemed most diagnostic."
    )

    @field_validator("triggering_log_id")
    @classmethod
    def _check_trigger_id(cls, value: str) -> str:
        """Accept `value` unchanged if it matches LOG_ID_PATTERN.

        Raises:
            ValueError: if the id does not have the ``L<turn>-<n>`` shape.
        """
        if LOG_ID_PATTERN.match(value):
            return value
        raise ValueError(f"triggering_log_id must match L<turn>-<n>, got {value!r}")
|
|
|
|
class CraftIncident(BaseModel):
    """Attacker-facing action wrapper."""

    # Same four fields as IncidentParams, but this wrapper declares no
    # validators of its own: the event-count, unique-id and timestamp-order
    # checks are not applied at this level.
    target_label: TriageAction
    category: IncidentCategory
    events: List[Event]
    narrative: str = Field(default="")
|
|
|
|
class SubmitTriage(BaseModel):
    """Defender-facing action wrapper."""

    # `action` and `cited_log_id` are required; `rationale` is optional free
    # text. Only the format of `cited_log_id` is validated here.
    action: TriageAction
    cited_log_id: str = Field(..., description="Log id that drove the decision.")
    rationale: str = ""

    @field_validator("cited_log_id")
    @classmethod
    def _check_cited_log_id(cls, value: str) -> str:
        """Accept `value` unchanged if it matches LOG_ID_PATTERN.

        Raises:
            ValueError: if the id does not have the ``L<turn>-<n>`` shape.
        """
        if LOG_ID_PATTERN.match(value):
            return value
        raise ValueError(f"cited_log_id must match L<turn>-<n>, got {value!r}")
|
|
|
|
class Action(BaseModel):
    """OpenEnv-style action union: exactly one field non-null per /step."""

    # NOTE(review): both fields default to None and no validator here
    # enforces the "exactly one" rule from the docstring; presumably the
    # caller checks it — confirm.
    craft_incident: Optional[CraftIncident] = Field(default=None)
    submit_triage: Optional[SubmitTriage] = Field(default=None)
|
|
|
|
| |
| |
| |
|
|
def make_log_id(turn: int, n: int) -> str:
    """Build a log id of the form ``L<turn>-<n>``.

    Args:
        turn: Turn number placed after the leading ``L``.
        n: Index placed after the dash.

    Returns:
        The formatted id, e.g. ``make_log_id(3, 12) == "L3-12"``. The
        arguments are interpolated as-is; no range checking is done here.
    """
    return "L{}-{}".format(turn, n)
|
|
|
|
def make_event(
    turn: int,
    n: int,
    event_type: EventType,
    timestamp: str,
    *,
    source: str = "endpoint",
    **fields: Any,
) -> Event:
    """Construct an `Event` from positional basics plus free-form fields.

    Args:
        turn: Turn number used to build the event's log id.
        n: Index within the turn used to build the event's log id.
        event_type: Event type to record.
        timestamp: Timestamp string, passed through to `Event` unchanged.
        source: Source bucket; keyword-only, defaults to "endpoint".
        **fields: Any further keyword arguments become the event's `fields`
            payload.

    Returns:
        The constructed `Event`; construction runs `Event`'s own validators.
    """
    event_id = make_log_id(turn, n)
    # Copy the kwargs mapping so the event holds its own dict.
    payload = {**fields}
    return Event(
        log_id=event_id,
        timestamp=timestamp,
        source=source,
        event_type=event_type,
        fields=payload,
    )
|
|