# opensoc-env / generator.py
# shivam2k3's picture
# OpenSOC v1
# bb6a031
"""
generator.py — Deterministic, seeded incident generator.
Used by:
* `OpenSOCEnv` in `defender_only` mode, when the env needs to materialize a
self-contained incident for the defender (SFT warm-start, eval, smoke
tests, curriculum starter prompts).
* `train/sft_warmstart.py` to produce ~600 (incident, triage) pairs for
bootstrapping defender format learning.
* `train/make_holdout.py` to build the frozen 200-incident eval set.
The generator emits `IncidentParams` instances; the env then materializes
them into `Incident` objects with a SIEM-style `Alert` summary. The
attacker is *not* required to use this generator — its only job is to give
the env a deterministic starting distribution for stages 1-4.
Seeding contract
----------------
``generate_incident(stage_id, seed=N)`` is referentially transparent:
calling it with the same arguments anywhere in the codebase returns the
exact same incident. This is what makes the held-out eval set
reproducible across machines.
"""
from __future__ import annotations
import random
from datetime import datetime, timedelta, timezone
from typing import Callable, Dict, List, Tuple
from schema import (
Alert,
EventType,
IncidentCategory,
IncidentParams,
TriageAction,
make_event,
)
# ---------------------------------------------------------------------------
# Time helpers
# ---------------------------------------------------------------------------
def _ts_iter(start: datetime, n: int, step_s: int = 5) -> List[str]:
"""Return n monotonic ISO-8601 UTC timestamps starting at `start`."""
return [
(start + timedelta(seconds=step_s * i))
.replace(tzinfo=timezone.utc)
.strftime("%Y-%m-%dT%H:%M:%SZ")
for i in range(n)
]
def _start_time(rng: random.Random) -> datetime:
"""Pick a recent UTC start time anchored on the current calendar day."""
base = datetime(2026, 4, 25, 0, 0, 0)
minutes = rng.randint(0, 60 * 23)
return base + timedelta(minutes=minutes)
# ---------------------------------------------------------------------------
# Incident template builders, grouped by *true* ground-truth label
# (most emit a single event; some emit a short multi-event sequence)
# ---------------------------------------------------------------------------
def _benign_login(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build a one-event benign incident: a successful login from 10.0.0.x.

    RNG draws happen in a fixed order (start time, user, last IP octet) so
    a given `rng` state always yields the same event.

    Returns:
        `(IncidentCategory.BENIGN_NOISE, [event])`.
    """
    (when,) = _ts_iter(_start_time(rng), 1)
    account = rng.choice(["alice", "bob", "carol", "dave", "erin"])
    last_octet = rng.randint(2, 200)
    login = make_event(
        1, 0, EventType.AUTH_LOGIN_SUCCESS, when,
        source="identity", user=account, src_ip=f"10.0.0.{last_octet}",
    )
    return IncidentCategory.BENIGN_NOISE, [login]
def _benign_internal_beacon(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build a one-event benign incident: a beacon to a 10.0.x.y address.

    RNG draw order: start time, third octet, fourth octet, beacon interval.

    Returns:
        `(IncidentCategory.BENIGN_NOISE, [event])`.
    """
    (when,) = _ts_iter(_start_time(rng), 1)
    third_octet = rng.randint(0, 255)
    fourth_octet = rng.randint(2, 250)
    period_s = rng.choice([30, 60, 90])
    beacon = make_event(
        1, 0, EventType.NET_BEACON, when,
        source="network",
        dst_ip=f"10.0.{third_octet}.{fourth_octet}",
        interval_s=period_s,
    )
    return IncidentCategory.BENIGN_NOISE, [beacon]
def _monitor_brute_force(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build a burst of 5-9 failed logins against a single account.

    All failures target the same user, but each one draws its own
    203.0.113.x source address. RNG draw order: failure count, start time,
    user, then one source-IP octet per event.

    Returns:
        `(IncidentCategory.BRUTE_FORCE, events)`.
    """
    attempts = rng.randint(5, 9)
    stamps = _ts_iter(_start_time(rng), attempts)
    target = rng.choice(["root", "admin", "service_acct", "alice"])
    failures = []
    for seq, stamp in enumerate(stamps):
        failures.append(
            make_event(
                1, seq, EventType.AUTH_LOGIN_FAILURE, stamp,
                source="identity", user=target,
                src_ip=f"203.0.113.{rng.randint(2, 250)}",
            )
        )
    return IncidentCategory.BRUTE_FORCE, failures
def _monitor_port_scan(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build a one-event incident: a port-scan hit from a 203.0.113.x source.

    RNG draw order: start time, source-IP octet, port.

    NOTE(review): the category returned is BENIGN_NOISE even though this
    builder is registered under MONITOR in `TEMPLATES` — confirm that this
    pairing is intended.

    Returns:
        `(IncidentCategory.BENIGN_NOISE, [event])`.
    """
    (when,) = _ts_iter(_start_time(rng), 1)
    scanner_octet = rng.randint(2, 250)
    probed_port = rng.choice([22, 23, 445, 3389, 5985])
    hit = make_event(
        1, 0, EventType.NET_PORT_SCAN_HIT, when,
        source="network",
        src_ip=f"203.0.113.{scanner_octet}",
        port=probed_port,
    )
    return IncidentCategory.BENIGN_NOISE, [hit]
def _quarantine_lolbin(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build a two-event incident: encoded PowerShell launch, then a file drop.

    The first event is `powershell.exe` spawned by an office/browser parent
    with an ``-enc`` command line padded with 40-80 "A" characters. The
    second is an ``.exe`` written under ``C:\\Users\\Public``. RNG draw
    order: start time, parent process, padding length, dropped file name.

    Returns:
        `(IncidentCategory.MALWARE_EXECUTION, [launch_event, write_event])`.
    """
    launch_ts, write_ts = _ts_iter(_start_time(rng), 2)
    launcher = rng.choice(["outlook.exe", "winword.exe", "excel.exe", "chrome.exe"])
    padding = "A" * rng.randint(40, 80)
    launch = make_event(
        1, 0, EventType.PROC_LOLBIN, launch_ts,
        source="endpoint",
        process="powershell.exe",
        parent_process=launcher,
        cmd="-enc " + padding,
    )
    # Drawn only after the first event is built, matching the original order.
    dropped_name = rng.choice(['payload', 'svc', 'tmp'])
    write = make_event(
        1, 1, EventType.FILE_WRITE, write_ts,
        source="endpoint",
        path=f"C:\\Users\\Public\\{dropped_name}.exe",
    )
    return IncidentCategory.MALWARE_EXECUTION, [launch, write]
def _quarantine_edr_high(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build a one-event incident: a high-severity EDR behaviour match.

    RNG draw order: start time, rule number (1000-9999), host number (1-50,
    rendered zero-padded to three digits).

    Returns:
        `(IncidentCategory.MALWARE_EXECUTION, [event])`.
    """
    (when,) = _ts_iter(_start_time(rng), 1)
    rule_number = rng.randint(1000, 9999)
    host_number = rng.randint(1, 50)
    match = make_event(
        1, 0, EventType.EDR_BEHAVIOR_MATCH, when,
        source="edr", severity="high", rule_id=f"EDR-{rule_number}",
        host=f"host-{host_number:03d}",
    )
    return IncidentCategory.MALWARE_EXECUTION, [match]
def _block_ip_external_beacon(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build a one-event incident: a beacon to one of three external IPs.

    RNG draw order: start time, destination address, beacon interval.

    Returns:
        `(IncidentCategory.C2_BEACON, [event])`.
    """
    (when,) = _ts_iter(_start_time(rng), 1)
    remote = rng.choice(["203.0.113.5", "198.51.100.42", "185.220.101.7"])
    period_s = rng.choice([30, 60, 90])
    beacon = make_event(
        1, 0, EventType.NET_BEACON, when,
        source="network",
        dst_ip=remote,
        interval_s=period_s,
    )
    return IncidentCategory.C2_BEACON, [beacon]
def _block_ip_phish_url(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build a one-event incident: a user clicking a ``login-update`` link.

    RNG draw order: start time, TLD of the link, user who clicked.

    Returns:
        `(IncidentCategory.PHISHING, [event])`.
    """
    (when,) = _ts_iter(_start_time(rng), 1)
    suffix = rng.choice([".ru", ".cn", ".top", ".xyz", ".click"])
    link = f"https://login-update{suffix}/secure"
    clicker = rng.choice(["alice", "bob", "carol"])
    click = make_event(
        1, 0, EventType.EMAIL_LINK_CLICKED, when,
        source="email",
        url=link,
        user=clicker,
    )
    return IncidentCategory.PHISHING, [click]
def _escalate_combined(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build a three-event chain spaced 15 seconds apart.

    The chain is: encoded PowerShell launched from Outlook, a beacon to
    203.0.113.5, and an admin role grant. Only the start time and the user
    receiving the grant are random; everything else is fixed.

    Returns:
        `(IncidentCategory.MALWARE_EXECUTION, [launch, beacon, grant])`.
    """
    launch_ts, beacon_ts, grant_ts = _ts_iter(_start_time(rng), 3, step_s=15)
    launch = make_event(
        1, 0, EventType.PROC_LOLBIN, launch_ts,
        source="endpoint",
        process="powershell.exe", parent_process="outlook.exe",
        cmd="-enc " + "B" * 60,
    )
    beacon = make_event(
        1, 1, EventType.NET_BEACON, beacon_ts,
        source="network", dst_ip="203.0.113.5", interval_s=30,
    )
    grantee = rng.choice(["alice", "bob"])
    grant = make_event(
        1, 2, EventType.AUTH_PRIVILEGE_GRANT, grant_ts,
        source="identity", user=grantee, role="admin",
    )
    return IncidentCategory.MALWARE_EXECUTION, [launch, beacon, grant]
# Mapping: ground-truth label → list of template builders.
# `generate_incident` selects one builder from the sampled label's list via
# `rng.choice`, so the order of each list affects which incident a given seed
# produces; reordering or inserting entries changes generated output.
TEMPLATES: Dict[TriageAction, List[Callable[[random.Random], Tuple[IncidentCategory, list]]]] = {
    TriageAction.DISMISS: [_benign_login, _benign_internal_beacon],
    TriageAction.MONITOR: [_monitor_brute_force, _monitor_port_scan],
    TriageAction.QUARANTINE_HOST: [_quarantine_lolbin, _quarantine_edr_high],
    TriageAction.BLOCK_IP: [_block_ip_external_beacon, _block_ip_phish_url],
    # ESCALATE has a single builder, so the choice is trivial for this label.
    TriageAction.ESCALATE: [_escalate_combined],
}
# ---------------------------------------------------------------------------
# Stage configs
# ---------------------------------------------------------------------------
# Each stage has:
# - label_distribution: probability mass over ground-truth labels (must sum to 1)
# - decoys: number of *additional* benign-looking events to splice in
# - jitter: how much we perturb fields (0.0 = none, 1.0 = max)
#
# The key order inside each `label_distribution` is the order in which labels
# and weights are handed to `rng.choices` by `_sample_label`, so reordering
# keys changes which label a given seed produces.
# NOTE(review): `jitter` is not read by `generate_incident` in this module;
# only `label_distribution` and `decoys` are consumed here — confirm whether
# another module uses it.
STAGE_CONFIGS: Dict[str, dict] = {
    # No decoys: every event in the incident comes from the template.
    "stage1_basic": {
        "label_distribution": {
            TriageAction.DISMISS: 0.30,
            TriageAction.MONITOR: 0.20,
            TriageAction.QUARANTINE_HOST: 0.20,
            TriageAction.BLOCK_IP: 0.20,
            TriageAction.ESCALATE: 0.10,
        },
        "decoys": 0,
        "jitter": 0.0,
    },
    # One decoy event per incident.
    "stage2_multi": {
        "label_distribution": {
            TriageAction.DISMISS: 0.20,
            TriageAction.MONITOR: 0.20,
            TriageAction.QUARANTINE_HOST: 0.25,
            TriageAction.BLOCK_IP: 0.20,
            TriageAction.ESCALATE: 0.15,
        },
        "decoys": 1,
        "jitter": 0.2,
    },
    # Two decoy events per incident.
    "stage3_mixed": {
        "label_distribution": {
            TriageAction.DISMISS: 0.25,
            TriageAction.MONITOR: 0.25,
            TriageAction.QUARANTINE_HOST: 0.20,
            TriageAction.BLOCK_IP: 0.15,
            TriageAction.ESCALATE: 0.15,
        },
        "decoys": 2,
        "jitter": 0.4,
    },
    # Three decoy events per incident.
    "stage4_adversarial": {
        "label_distribution": {
            TriageAction.DISMISS: 0.30,
            TriageAction.MONITOR: 0.25,
            TriageAction.QUARANTINE_HOST: 0.15,
            TriageAction.BLOCK_IP: 0.15,
            TriageAction.ESCALATE: 0.15,
        },
        "decoys": 3,
        "jitter": 0.7,
    },
}
def _sample_label(rng: random.Random, dist: Dict[TriageAction, float]) -> TriageAction:
labels = list(dist.keys())
weights = [dist[lab] for lab in labels]
return rng.choices(labels, weights=weights, k=1)[0]
def _make_decoy_events(rng: random.Random, n_decoys: int, start_idx: int) -> list:
    """Generate `n_decoys` benign-looking filler events.

    Each decoy is one of three kinds, chosen uniformly: a successful login
    from 10.0.0.x, a DNS query for one of a few fixed domains, or an
    outbound flow to a 10.0.x.y address. Decoys are spaced 2 seconds apart
    from their own `_start_time` draw, which is independent of the start
    time of any other events in the incident.

    Args:
        rng: Seeded generator; one draw is consumed for the start time even
            when `n_decoys` is 0.
        n_decoys: Number of decoy events to build.
        start_idx: Sequence number given to the first decoy; later decoys
            count up from it.

    Returns:
        The decoy events, in generation order.
    """
    stamps = _ts_iter(_start_time(rng), n_decoys, step_s=2)
    filler = []
    for offset, stamp in enumerate(stamps):
        # The kind is drawn before the sequence number is computed and
        # before any kind-specific fields are drawn.
        kind = rng.randint(0, 2)
        seq = start_idx + offset
        if kind == 0:
            event = make_event(
                1, seq, EventType.AUTH_LOGIN_SUCCESS, stamp,
                source="identity",
                user=rng.choice(["alice", "bob", "carol", "dave"]),
                src_ip=f"10.0.0.{rng.randint(2, 250)}",
            )
        elif kind == 1:
            event = make_event(
                1, seq, EventType.NET_DNS_QUERY, stamp,
                source="network",
                domain=rng.choice(["github.com", "google.com", "internal.corp"]),
            )
        else:
            event = make_event(
                1, seq, EventType.NET_OUTBOUND, stamp,
                source="network",
                dst_ip=f"10.0.{rng.randint(0, 255)}.{rng.randint(2, 250)}",
                bytes_out=rng.randint(1_000, 100_000),
            )
        filler.append(event)
    return filler
def _renumber_and_resort(events: list) -> list:
"""Rewrite log_ids to L1-0..L1-N-1 and sort by timestamp."""
events = sorted(events, key=lambda e: e.timestamp)
fixed = []
for i, e in enumerate(events):
fixed.append(
type(e)(
log_id=f"L1-{i}",
timestamp=e.timestamp,
source=e.source,
event_type=e.event_type,
fields=e.fields,
)
)
return fixed
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def generate_incident(stage_id: str, seed: int) -> IncidentParams:
    """Deterministically generate an `IncidentParams` for the given stage.

    A fresh `random.Random(seed)` drives every choice, in this order: the
    ground-truth label, the template builder for that label, the template's
    own fields, and finally the stage's decoy events. Core and decoy events
    are then merged, sorted by timestamp and renumbered.

    Args:
        stage_id: A key of `STAGE_CONFIGS`.
        seed: Seed for the private random generator.

    Returns:
        An `IncidentParams` with the sampled label, the template's category,
        the merged events, and an empty narrative.

    Raises:
        ValueError: If `stage_id` is not a key of `STAGE_CONFIGS`.
    """
    stage = STAGE_CONFIGS.get(stage_id)
    if stage is None:
        raise ValueError(
            f"Unknown stage_id {stage_id!r}; choose from {list(STAGE_CONFIGS)}"
        )

    rng = random.Random(seed)
    target = _sample_label(rng, stage["label_distribution"])
    builder = rng.choice(TEMPLATES[target])
    category, core = builder(rng)
    # Decoys are numbered after the core events; ids are rewritten below.
    decoys = _make_decoy_events(rng, stage["decoys"], start_idx=len(core))
    merged = _renumber_and_resort(core + decoys)

    return IncidentParams(
        target_label=target,
        category=category,
        events=merged,
        narrative="",
    )
def make_alert(params: IncidentParams, alert_id: str) -> Alert:
    """Synthesize a SIEM alert summary from an incident's events.

    Severity is derived from `params.target_label` (falling back to
    "medium" for a label not in the table). Host and user are read from the
    first event's `fields`, defaulting to "host-001" / "user-001" when
    absent. For incidents built by `generate_incident`, the first event is
    the earliest by timestamp, which may be a decoy.

    Raises:
        IndexError: If `params.events` is empty.
    """
    severity_by_action = {
        TriageAction.DISMISS: "low",
        TriageAction.MONITOR: "medium",
        TriageAction.QUARANTINE_HOST: "high",
        TriageAction.BLOCK_IP: "high",
        TriageAction.ESCALATE: "critical",
    }
    level = severity_by_action.get(params.target_label, "medium")

    lead = params.events[0]
    lead_fields = lead.fields
    user_name = str(lead_fields.get("user", "user-001"))
    host_name = str(lead_fields.get("host", "host-001"))

    event_count = len(params.events)
    headline = f"{params.category.value}: {event_count} event(s); first={lead.event_type.value}"

    return Alert(
        alert_id=alert_id,
        category=params.category,
        severity=level,
        summary=headline,
        host=host_name,
        user=user_name,
    )
# Public surface of this module; the underscore-prefixed builders and helpers
# above are intentionally left out.
__all__ = ["generate_incident", "make_alert", "STAGE_CONFIGS", "TEMPLATES"]