Spaces:
Sleeping
Sleeping
"""
schema.py — OpenSOC incident & action schema.

This module is the single source of truth for what the attacker is allowed to
emit and what the defender is allowed to respond with. The verifier and
rubric both depend on the constraints here; they should never be relaxed
without updating the corresponding tests in `tests/test_schema.py` and
`tests/test_verifier.py`.

Design principles
-----------------
1. The *attacker* controls structured parameters (event types, field values),
   never the ground-truth label directly. The label is derived deterministically
   from the params by `verifier.compute_ground_truth` so that the reward can
   never be hacked by attacker text.
2. Every event has a stable `log_id` of the form `L<turn>-<n>` so that the
   defender can cite a triggering event and earn a small bonus. This is
   regex-verifiable.
3. The defender's action set is a fixed enum of five SOC responses ranked by
   "cost" (dismiss = cheapest, escalate = most expensive). This lets the
   rubric grade over- vs under-reaction.
"""
| from __future__ import annotations | |
| import enum | |
| import ipaddress | |
| import re | |
| from typing import Any, Dict, List, Optional | |
| from pydantic import BaseModel, Field, field_validator, model_validator | |
# ---------------------------------------------------------------------------
# Defender action set (fixed enum, ordered by escalation cost)
# ---------------------------------------------------------------------------


class TriageAction(str, enum.Enum):
    """Closed set of responses the defender may submit for an incident.

    Members are declared from the least to the most disruptive response.
    Because the enum mixes in ``str``, every member compares equal to its
    value (for example ``TriageAction.DISMISS == "dismiss"``).
    """

    DISMISS = "dismiss"
    MONITOR = "monitor"
    QUARANTINE_HOST = "quarantine_host"
    BLOCK_IP = "block_ip"
    ESCALATE = "escalate"
# Cost ordering — used by rubric to compute over-/under-reaction penalties.
# Higher number = more disruptive / expensive action. Note that the two
# containment actions share the same cost tier (2).
ACTION_COST: Dict[TriageAction, int] = {
    TriageAction.DISMISS: 0,
    TriageAction.MONITOR: 1,
    TriageAction.QUARANTINE_HOST: 2,
    TriageAction.BLOCK_IP: 2,
    TriageAction.ESCALATE: 3,
}

# Actions that are considered "containment" — used by rubric to detect
# over-reaction on benign incidents.
CONTAINMENT_ACTIONS = {TriageAction.QUARANTINE_HOST, TriageAction.BLOCK_IP}
# ---------------------------------------------------------------------------
# Event taxonomy
# ---------------------------------------------------------------------------


class EventType(str, enum.Enum):
    """Closed vocabulary of log event kinds an incident may contain.

    Values are dotted ``<family>.<name>`` strings; the family prefix groups
    events by telemetry area (auth, proc, net, file, email, cloud, edr).
    """

    # Auth / identity
    AUTH_LOGIN_SUCCESS = "auth.login_success"
    AUTH_LOGIN_FAILURE = "auth.login_failure"
    AUTH_PASSWORD_RESET = "auth.password_reset"
    AUTH_MFA_FAILURE = "auth.mfa_failure"
    AUTH_PRIVILEGE_GRANT = "auth.privilege_grant"

    # Process
    PROC_START = "proc.start"
    PROC_PARENT_MISMATCH = "proc.parent_mismatch"
    PROC_LOLBIN = "proc.lolbin_use"

    # Network
    NET_OUTBOUND = "net.outbound_connection"
    NET_DNS_QUERY = "net.dns_query"
    NET_BEACON = "net.beacon"
    NET_PORT_SCAN_HIT = "net.port_scan_hit"

    # File / object
    FILE_WRITE = "file.write"
    FILE_DELETE = "file.delete"
    FILE_RENAME_DOUBLE_EXT = "file.rename_double_ext"

    # Email / phishing
    EMAIL_RECEIVED = "email.received"
    EMAIL_LINK_CLICKED = "email.link_clicked"
    EMAIL_ATTACHMENT_OPENED = "email.attachment_opened"

    # Cloud / SaaS
    CLOUD_API_CALL = "cloud.api_call"
    CLOUD_KEY_CREATED = "cloud.key_created"

    # Endpoint detection signals
    EDR_BEHAVIOR_MATCH = "edr.behavior_match"
# Categories the attacker can target (informational for shaping; final
# triage label is computed from parameter content, not from `category`).
class IncidentCategory(str, enum.Enum):
    """Scenario family an incident is declared to belong to."""

    BENIGN_NOISE = "benign_noise"
    BRUTE_FORCE = "brute_force"
    PHISHING = "phishing"
    LATERAL_MOVEMENT = "lateral_movement"
    PRIVILEGE_ESCALATION = "privilege_escalation"
    DATA_EXFILTRATION = "data_exfiltration"
    MALWARE_EXECUTION = "malware_execution"
    C2_BEACON = "c2_beacon"
    INSIDER_DATA_ACCESS = "insider_data_access"
# ---------------------------------------------------------------------------
# Constants used for plausibility checks
# ---------------------------------------------------------------------------

# Allowed alert severities, listed from least to most severe.
SEVERITIES = ("info", "low", "medium", "high", "critical")

# Both patterns are anchored with `\Z` rather than `$`: `$` also matches just
# before a trailing newline, so "L1-2\n" would otherwise be accepted by
# `.match()`. `re.ASCII` restricts `\d` to 0-9 so non-ASCII digit characters
# cannot slip through.
LOG_ID_PATTERN = re.compile(r"^L\d+-\d+\Z", re.ASCII)
ISO_TS_PATTERN = re.compile(
    r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?Z\Z",
    re.ASCII,
)

# Internal RFC1918 ranges — a connection from internal->internal does NOT
# look like exfil even if bytes are large. The verifier uses this.
INTERNAL_NETS = [
    ipaddress.ip_network("10.0.0.0/8"),
    ipaddress.ip_network("172.16.0.0/12"),
    ipaddress.ip_network("192.168.0.0/16"),
]

# Living-off-the-land binaries the attacker can mark as `lolbin_use` events.
KNOWN_LOLBINS = {
    "powershell.exe", "pwsh.exe", "cmd.exe", "wmic.exe", "rundll32.exe",
    "regsvr32.exe", "mshta.exe", "certutil.exe", "bitsadmin.exe",
    "schtasks.exe", "wscript.exe", "cscript.exe",
}

# Suspicious LOLBin parents — a lolbin spawned by Office or a browser is
# strongly indicative of malicious code execution.
SUSPICIOUS_LOLBIN_PARENTS = {
    "winword.exe", "excel.exe", "powerpnt.exe", "outlook.exe",
    "chrome.exe", "msedge.exe", "firefox.exe",
}
def is_internal_ip(ip_str: str) -> bool:
    """Return True if `ip_str` is in any RFC1918 range.

    Args:
        ip_str: Textual IP address to classify.

    Returns:
        True when the address parses and falls inside one of the networks in
        `INTERNAL_NETS`; False otherwise, including when `ip_str` is not a
        valid IP address at all.
    """
    try:
        address = ipaddress.ip_address(ip_str)
    except ValueError:
        # Unparseable input is treated as "not internal" rather than an error.
        return False
    return any(address in network for network in INTERNAL_NETS)
# ---------------------------------------------------------------------------
# Pydantic models
# ---------------------------------------------------------------------------


class Event(BaseModel):
    """A single structured log event.

    Attributes:
        log_id: Identifier of the form ``L<turn>-<n>``; validated against
            `LOG_ID_PATTERN`.
        timestamp: ISO-8601 UTC timestamp ending in ``Z``; validated against
            `ISO_TS_PATTERN`.
        source: Logical source bucket (defaults to ``"endpoint"``). The value
            is free-form here; no validator restricts it to the listed buckets.
        event_type: Kind of event, drawn from `EventType`.
        fields: Arbitrary event-specific payload.
    """

    log_id: str = Field(..., description="Stable id of the form 'L<turn>-<n>'.")
    timestamp: str = Field(..., description="ISO-8601 UTC timestamp.")
    source: str = Field(
        "endpoint",
        description="Logical source bucket: endpoint | network | identity | email | cloud | edr",
    )
    event_type: EventType
    fields: Dict[str, Any] = Field(default_factory=dict)

    # The validators below must be registered with `field_validator`;
    # without the decorator pydantic never calls them.
    @field_validator("log_id")
    @classmethod
    def _check_log_id(cls, v: str) -> str:
        """Reject log ids that are not exactly ``L<turn>-<n>``."""
        # fullmatch: the whole value must match, so a trailing newline fails.
        if not LOG_ID_PATTERN.fullmatch(v):
            raise ValueError(f"log_id must match L<turn>-<n>, got {v!r}")
        return v

    @field_validator("timestamp")
    @classmethod
    def _check_ts(cls, v: str) -> str:
        """Reject timestamps that are not ISO-8601 UTC with a ``Z`` suffix."""
        if not ISO_TS_PATTERN.fullmatch(v):
            raise ValueError(f"timestamp must be ISO-8601 UTC, got {v!r}")
        return v
class Alert(BaseModel):
    """SIEM-style alert summary the defender sees first.

    Attributes:
        alert_id: Identifier of the alert.
        category: Declared incident category.
        severity: One of `SEVERITIES`; defaults to ``"medium"``.
        summary: Human-readable one-line description.
        host: Affected host name (defaults to ``"host-001"``).
        user: Affected user name (defaults to ``"user-001"``).
    """

    alert_id: str
    category: IncidentCategory
    severity: str = Field("medium", description="One of: info, low, medium, high, critical.")
    summary: str
    host: str = "host-001"
    user: str = "user-001"

    # Must be registered with `field_validator`, otherwise it never runs.
    @field_validator("severity")
    @classmethod
    def _check_severity(cls, v: str) -> str:
        """Reject severities outside the fixed `SEVERITIES` vocabulary."""
        if v not in SEVERITIES:
            raise ValueError(f"severity must be one of {SEVERITIES}, got {v!r}")
        return v
# Upper bound on the number of events a single incident may carry.
_MAX_EVENTS_PER_INCIDENT = 32


def _timestamp_sort_key(ts: str) -> tuple:
    """Return a key that orders ``YYYY-MM-DDTHH:MM:SS[.fff]Z`` strings by time.

    Raw string comparison is wrong when fractional seconds are optional:
    ``"...:00.5Z"`` sorts before ``"...:00Z"`` because ``"."`` < ``"Z"``, and
    ``".5Z"`` differs from ``".50Z"`` although both denote the same instant.
    Splitting off the fraction and dropping its trailing zeros removes both
    problems; the fixed-width head and the digit-only fraction then compare
    correctly as plain strings.
    """
    body = ts[:-1] if ts.endswith("Z") else ts
    whole_seconds, _, fraction = body.partition(".")
    return whole_seconds, fraction.rstrip("0")


class IncidentParams(BaseModel):
    """Parameters the attacker chooses; the env materializes these into an Incident.

    The triage label that ends up in the defender's reward is derived
    *deterministically* from the events here by `verifier.compute_ground_truth`.
    `target_label` is purely a shaping hint: if the attacker's events imply a
    different label than `target_label`, the schema validator rejects the
    incident (so the attacker cannot lie about its own intent).

    NOTE(review): no validator in this class compares `target_label` with the
    events; presumably that rejection happens elsewhere — confirm.
    """

    target_label: TriageAction
    category: IncidentCategory
    events: List[Event]
    narrative: str = Field("", description="Free-text scratchpad; ignored by the verifier.")

    # Validators must be registered via the pydantic decorators; undecorated
    # methods are never invoked during validation.
    @field_validator("events")
    @classmethod
    def _events_nonempty(cls, v: List[Event]) -> List[Event]:
        """Require between 1 and 32 events."""
        if not v:
            raise ValueError("events must contain at least one Event")
        if len(v) > _MAX_EVENTS_PER_INCIDENT:
            raise ValueError(f"events list capped at {_MAX_EVENTS_PER_INCIDENT} entries")
        return v

    @model_validator(mode="after")
    def _events_have_unique_ids(self) -> "IncidentParams":
        """Require every event to carry a distinct `log_id`."""
        ids = [e.log_id for e in self.events]
        if len(set(ids)) != len(ids):
            raise ValueError("event log_ids must be unique")
        return self

    @model_validator(mode="after")
    def _timestamps_monotonic(self) -> "IncidentParams":
        """Require event timestamps to be in non-decreasing chronological order."""
        # Compare normalised keys, not raw strings (see `_timestamp_sort_key`).
        # Assumes each timestamp has the fixed-width ISO shape described on
        # `Event.timestamp`.
        keys = [_timestamp_sort_key(e.timestamp) for e in self.events]
        if any(later < earlier for earlier, later in zip(keys, keys[1:])):
            raise ValueError("event timestamps must be non-decreasing")
        return self
class Incident(BaseModel):
    """Materialized incident the env shows to the defender.

    Attributes:
        alert: Summary alert for the incident.
        log_window: Events visible to the defender.
        triggering_log_id: Log id the verifier deemed most diagnostic; must
            have the ``L<turn>-<n>`` shape. This class does not check that
            the id actually occurs in `log_window`.
    """

    alert: Alert
    log_window: List[Event]
    triggering_log_id: str = Field(
        ..., description="The log_id the verifier deemed most diagnostic."
    )

    # Must be registered with `field_validator`, otherwise it never runs.
    @field_validator("triggering_log_id")
    @classmethod
    def _check_trigger_id(cls, v: str) -> str:
        """Reject trigger ids that are not exactly ``L<turn>-<n>``."""
        # fullmatch: the whole value must match, so a trailing newline fails.
        if not LOG_ID_PATTERN.fullmatch(v):
            raise ValueError(f"triggering_log_id must match L<turn>-<n>, got {v!r}")
        return v
class CraftIncident(BaseModel):
    """Attacker-facing action wrapper.

    Carries the same four fields as `IncidentParams`, but declares no
    validators of its own beyond the per-field type checks.
    """

    target_label: TriageAction
    category: IncidentCategory
    events: List[Event]
    narrative: str = ""
class SubmitTriage(BaseModel):
    """Defender-facing action wrapper.

    Attributes:
        action: The triage response chosen by the defender.
        cited_log_id: Log id that drove the decision (required); must have
            the ``L<turn>-<n>`` shape.
        rationale: Optional free-text explanation.
    """

    action: TriageAction
    cited_log_id: str = Field(..., description="Log id that drove the decision.")
    rationale: str = ""

    # Must be registered with `field_validator`, otherwise it never runs.
    @field_validator("cited_log_id")
    @classmethod
    def _check_cited_log_id(cls, v: str) -> str:
        """Reject cited ids that are not exactly ``L<turn>-<n>``."""
        # fullmatch: the whole value must match, so a trailing newline fails.
        if not LOG_ID_PATTERN.fullmatch(v):
            raise ValueError(f"cited_log_id must match L<turn>-<n>, got {v!r}")
        return v
class Action(BaseModel):
    """OpenEnv-style action union: exactly one field non-null per /step.

    NOTE(review): both fields default to None and nothing in this class
    enforces the "exactly one" rule; presumably the caller checks it — confirm.
    """

    craft_incident: Optional[CraftIncident] = None
    submit_triage: Optional[SubmitTriage] = None
# ---------------------------------------------------------------------------
# Convenience builders for tests / generators
# ---------------------------------------------------------------------------


def make_log_id(turn: int, n: int) -> str:
    """Return a canonically-formatted log id.

    Args:
        turn: Turn number, rendered after the leading ``L``.
        n: Index of the event, rendered after the dash.

    Returns:
        The string ``L<turn>-<n>``. The inputs are not range-checked here.
    """
    return f"L{turn}-{n}"
def make_event(
    turn: int,
    n: int,
    event_type: EventType,
    timestamp: str,
    *,
    source: str = "endpoint",
    **fields: Any,
) -> Event:
    """Compact helper used by `generator.py` and tests.

    Args:
        turn: Turn number used to build the event's log id.
        n: Event index used to build the event's log id.
        event_type: Kind of event to create.
        timestamp: Timestamp string passed through to the `Event`.
        source: Logical source bucket (keyword-only).
        **fields: Any further keyword arguments become the event's `fields`
            payload.

    Returns:
        A new `Event` whose `log_id` is `make_log_id(turn, n)`.
    """
    return Event(
        log_id=make_log_id(turn, n),
        timestamp=timestamp,
        source=source,
        event_type=event_type,
        # Copy so the event does not share the kwargs dict.
        fields=dict(fields),
    )