"""
generator.py — Deterministic, seeded incident generator.

Used by:
  * `OpenSOCEnv` in `defender_only` mode, when the env needs to materialize
    a self-contained incident for the defender (SFT warm-start, eval, smoke
    tests, curriculum starter prompts).
  * `train/sft_warmstart.py` to produce ~600 (incident, triage) pairs for
    bootstrapping defender format learning.
  * `train/make_holdout.py` to build the frozen 200-incident eval set.

The generator emits `IncidentParams` instances; the env then materializes
them into `Incident` objects with a SIEM-style `Alert` summary. The attacker
is *not* required to use this generator — its only job is to give the env a
deterministic starting distribution for stages 1-4.

Seeding contract
----------------
``generate_incident(stage_id, seed=N)`` is referentially transparent: calling
it with the same arguments anywhere in the codebase returns the exact same
incident. This is what makes the held-out eval set reproducible across
machines.

Maintenance note: every value is drawn from a single ``random.Random(seed)``
stream, so adding, removing, or reordering ``rng`` calls changes the incident
produced for every existing seed.
"""

from __future__ import annotations

import random
from datetime import datetime, timedelta, timezone
from typing import Callable, Dict, List, Tuple

from schema import (
    Alert,
    EventType,
    IncidentCategory,
    IncidentParams,
    TriageAction,
    make_event,
)


# ---------------------------------------------------------------------------
# Time helpers
# ---------------------------------------------------------------------------

def _ts_iter(start: datetime, n: int, step_s: int = 5) -> List[str]:
    """Return `n` monotonic ISO-8601 UTC timestamps starting at `start`.

    Args:
        start: First timestamp. A naive datetime is interpreted as UTC; an
            aware datetime is converted to UTC.
        n: Number of timestamps to produce (``n <= 0`` yields an empty list).
        step_s: Spacing between consecutive timestamps, in seconds.

    Returns:
        A list of strings formatted as ``YYYY-MM-DDTHH:MM:SSZ``.
    """
    # The trailing "Z" promises UTC, so aware inputs must be *converted*;
    # merely re-labelling them with tzinfo=UTC would shift the instant.
    if start.tzinfo is None:
        start = start.replace(tzinfo=timezone.utc)
    else:
        start = start.astimezone(timezone.utc)
    return [
        (start + timedelta(seconds=step_s * i)).strftime("%Y-%m-%dT%H:%M:%SZ")
        for i in range(n)
    ]


def _start_time(rng: random.Random) -> datetime:
    """Pick a start time on the fixed anchor day 2026-04-25.

    The anchor date is hard-coded (not "today") so that the same seed yields
    the same timestamps on every machine and on every day. The result is a
    naive datetime between 00:00 and 23:00 inclusive, at minute resolution;
    `_ts_iter` interprets it as UTC.

    Consumes exactly one draw from `rng`.
    """
    base = datetime(2026, 4, 25, 0, 0, 0)
    minutes = rng.randint(0, 60 * 23)
    return base + timedelta(minutes=minutes)


# ---------------------------------------------------------------------------
# Single-event template builders, keyed by *true* ground-truth label
# ---------------------------------------------------------------------------
# Each builder takes the shared RNG and returns (category, core_events).
# Keyword arguments after `source=` are forwarded to `make_event` as the
# event's fields.

def _benign_login(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build one successful login from a 10.0.0.x source address."""
    ts = _ts_iter(_start_time(rng), 1)
    user = rng.choice(["alice", "bob", "carol", "dave", "erin"])
    return IncidentCategory.BENIGN_NOISE, [
        make_event(
            1, 0, EventType.AUTH_LOGIN_SUCCESS, ts[0],
            source="identity",
            user=user,
            src_ip=f"10.0.0.{rng.randint(2, 200)}",
        )
    ]


def _benign_internal_beacon(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build one periodic beacon whose destination is a 10.0.x.x address."""
    ts = _ts_iter(_start_time(rng), 1)
    return IncidentCategory.BENIGN_NOISE, [
        make_event(
            1, 0, EventType.NET_BEACON, ts[0],
            source="network",
            dst_ip=f"10.0.{rng.randint(0, 255)}.{rng.randint(2, 250)}",
            interval_s=rng.choice([30, 60, 90]),
        )
    ]


def _monitor_brute_force(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build 5-9 login failures for one user, spaced 5 seconds apart.

    The source address is re-drawn for every failure, so the attempts come
    from varying 203.0.113.x addresses.
    """
    n_failures = rng.randint(5, 9)
    ts = _ts_iter(_start_time(rng), n_failures)
    user = rng.choice(["root", "admin", "service_acct", "alice"])
    return IncidentCategory.BRUTE_FORCE, [
        make_event(
            1, i, EventType.AUTH_LOGIN_FAILURE, ts[i],
            source="identity",
            user=user,
            src_ip=f"203.0.113.{rng.randint(2, 250)}",
        )
        for i in range(n_failures)
    ]


def _monitor_port_scan(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build one port-scan hit from a 203.0.113.x address.

    NOTE(review): the category returned is BENIGN_NOISE even though this
    builder is registered under the MONITOR label in `TEMPLATES` — confirm
    this pairing is intentional.
    """
    ts = _ts_iter(_start_time(rng), 1)
    return IncidentCategory.BENIGN_NOISE, [
        make_event(
            1, 0, EventType.NET_PORT_SCAN_HIT, ts[0],
            source="network",
            src_ip=f"203.0.113.{rng.randint(2, 250)}",
            port=rng.choice([22, 23, 445, 3389, 5985]),
        ),
    ]


def _quarantine_lolbin(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build an encoded PowerShell launch followed by an .exe file write."""
    ts = _ts_iter(_start_time(rng), 2)
    parent = rng.choice(["outlook.exe", "winword.exe", "excel.exe", "chrome.exe"])
    return IncidentCategory.MALWARE_EXECUTION, [
        make_event(
            1, 0, EventType.PROC_LOLBIN, ts[0],
            source="endpoint",
            process="powershell.exe",
            parent_process=parent,
            # Placeholder payload: "-enc " plus 40-80 repeated "A" characters.
            cmd="-enc " + "A" * rng.randint(40, 80),
        ),
        make_event(
            1, 1, EventType.FILE_WRITE, ts[1],
            source="endpoint",
            path=f"C:\\Users\\Public\\{rng.choice(['payload', 'svc', 'tmp'])}.exe",
        ),
    ]


def _quarantine_edr_high(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build one high-severity EDR behavior match on a numbered host."""
    ts = _ts_iter(_start_time(rng), 1)
    return IncidentCategory.MALWARE_EXECUTION, [
        make_event(
            1, 0, EventType.EDR_BEHAVIOR_MATCH, ts[0],
            source="edr",
            severity="high",
            rule_id=f"EDR-{rng.randint(1000, 9999)}",
            host=f"host-{rng.randint(1, 50):03d}",
        )
    ]


def _block_ip_external_beacon(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build one periodic beacon to one of three fixed external addresses."""
    ts = _ts_iter(_start_time(rng), 1)
    return IncidentCategory.C2_BEACON, [
        make_event(
            1, 0, EventType.NET_BEACON, ts[0],
            source="network",
            dst_ip=rng.choice(["203.0.113.5", "198.51.100.42", "185.220.101.7"]),
            interval_s=rng.choice([30, 60, 90]),
        )
    ]


def _block_ip_phish_url(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build one clicked e-mail link pointing at a ``login-update`` domain."""
    ts = _ts_iter(_start_time(rng), 1)
    tld = rng.choice([".ru", ".cn", ".top", ".xyz", ".click"])
    return IncidentCategory.PHISHING, [
        make_event(
            1, 0, EventType.EMAIL_LINK_CLICKED, ts[0],
            source="email",
            url=f"https://login-update{tld}/secure",
            user=rng.choice(["alice", "bob", "carol"]),
        )
    ]


def _escalate_combined(rng: random.Random) -> Tuple[IncidentCategory, list]:
    """Build a three-step chain: LOLBin launch, beacon, privilege grant.

    Events are spaced 15 seconds apart. Only the start time and the user
    receiving the admin role are randomized; everything else is fixed.
    """
    ts = _ts_iter(_start_time(rng), 3, step_s=15)
    return IncidentCategory.MALWARE_EXECUTION, [
        make_event(
            1, 0, EventType.PROC_LOLBIN, ts[0],
            source="endpoint",
            process="powershell.exe",
            parent_process="outlook.exe",
            cmd="-enc " + "B" * 60,
        ),
        make_event(
            1, 1, EventType.NET_BEACON, ts[1],
            source="network",
            dst_ip="203.0.113.5",
            interval_s=30,
        ),
        make_event(
            1, 2, EventType.AUTH_PRIVILEGE_GRANT, ts[2],
            source="identity",
            user=rng.choice(["alice", "bob"]),
            role="admin",
        ),
    ]


# Mapping: ground-truth label → list of template builders
TEMPLATES: Dict[TriageAction, List[Callable[[random.Random], Tuple[IncidentCategory, list]]]] = {
    TriageAction.DISMISS: [_benign_login, _benign_internal_beacon],
    TriageAction.MONITOR: [_monitor_brute_force, _monitor_port_scan],
    TriageAction.QUARANTINE_HOST: [_quarantine_lolbin, _quarantine_edr_high],
    TriageAction.BLOCK_IP: [_block_ip_external_beacon, _block_ip_phish_url],
    TriageAction.ESCALATE: [_escalate_combined],
}


# ---------------------------------------------------------------------------
# Stage configs
# ---------------------------------------------------------------------------
# Each stage has:
#   - label_distribution: probability mass over ground-truth labels (must sum to 1)
#   - decoys: number of *additional* benign-looking events to splice in
#   - jitter: how much we perturb fields (0.0 = none, 1.0 = max)
#     NOTE: nothing in this file reads `jitter`; `generate_incident` only
#     consumes `label_distribution` and `decoys`.

STAGE_CONFIGS: Dict[str, dict] = {
    "stage1_basic": {
        "label_distribution": {
            TriageAction.DISMISS: 0.30,
            TriageAction.MONITOR: 0.20,
            TriageAction.QUARANTINE_HOST: 0.20,
            TriageAction.BLOCK_IP: 0.20,
            TriageAction.ESCALATE: 0.10,
        },
        "decoys": 0,
        "jitter": 0.0,
    },
    "stage2_multi": {
        "label_distribution": {
            TriageAction.DISMISS: 0.20,
            TriageAction.MONITOR: 0.20,
            TriageAction.QUARANTINE_HOST: 0.25,
            TriageAction.BLOCK_IP: 0.20,
            TriageAction.ESCALATE: 0.15,
        },
        "decoys": 1,
        "jitter": 0.2,
    },
    "stage3_mixed": {
        "label_distribution": {
            TriageAction.DISMISS: 0.25,
            TriageAction.MONITOR: 0.25,
            TriageAction.QUARANTINE_HOST: 0.20,
            TriageAction.BLOCK_IP: 0.15,
            TriageAction.ESCALATE: 0.15,
        },
        "decoys": 2,
        "jitter": 0.4,
    },
    "stage4_adversarial": {
        "label_distribution": {
            TriageAction.DISMISS: 0.30,
            TriageAction.MONITOR: 0.25,
            TriageAction.QUARANTINE_HOST: 0.15,
            TriageAction.BLOCK_IP: 0.15,
            TriageAction.ESCALATE: 0.15,
        },
        "decoys": 3,
        "jitter": 0.7,
    },
}


def _sample_label(rng: random.Random, dist: Dict[TriageAction, float]) -> TriageAction:
    """Draw one ground-truth label from `dist` using its values as weights."""
    labels = list(dist.keys())
    weights = [dist[lab] for lab in labels]
    return rng.choices(labels, weights=weights, k=1)[0]


def _make_decoy_events(rng: random.Random, n_decoys: int, start_idx: int) -> list:
    """Generate `n_decoys` benign decoy events that don't change the label.

    Each decoy is one of: a successful login from a 10.0.0.x address, a DNS
    query for one of a few fixed domains (two well-known public ones and
    ``internal.corp``), or an outbound flow to a 10.0.x.x address. The pool
    is meant to be treated as benign by the verifier.

    Args:
        rng: Shared random stream for the incident.
        n_decoys: Number of decoy events to create.
        start_idx: Index given to the first decoy when calling `make_event`;
            subsequent decoys count up from it.

    Returns:
        The decoy events, 2 seconds apart, in generation order.
    """
    # Decoys get their own independently sampled start time, so they only
    # interleave with the core events once the caller sorts by timestamp.
    # This draw happens even when n_decoys == 0; keep it unconditional so
    # the RNG stream (and therefore every seeded incident) stays unchanged.
    ts = _ts_iter(_start_time(rng), n_decoys, step_s=2)
    decoys = []
    for i in range(n_decoys):
        choice = rng.randint(0, 2)
        n = start_idx + i
        if choice == 0:
            decoys.append(make_event(
                1, n, EventType.AUTH_LOGIN_SUCCESS, ts[i],
                source="identity",
                user=rng.choice(["alice", "bob", "carol", "dave"]),
                src_ip=f"10.0.0.{rng.randint(2, 250)}",
            ))
        elif choice == 1:
            decoys.append(make_event(
                1, n, EventType.NET_DNS_QUERY, ts[i],
                source="network",
                domain=rng.choice(["github.com", "google.com", "internal.corp"]),
            ))
        else:
            decoys.append(make_event(
                1, n, EventType.NET_OUTBOUND, ts[i],
                source="network",
                dst_ip=f"10.0.{rng.randint(0, 255)}.{rng.randint(2, 250)}",
                bytes_out=rng.randint(1_000, 100_000),
            ))
    return decoys


def _renumber_and_resort(events: list) -> list:
    """Rewrite log_ids to L1-0..L1-N-1 and sort by timestamp.

    Sorting compares the timestamp strings directly; events with equal
    timestamps keep their input order because `sorted` is stable. Each event
    is rebuilt through its own class, and the `fields` object is reused
    rather than copied.
    """
    events = sorted(events, key=lambda e: e.timestamp)
    fixed = []
    for i, e in enumerate(events):
        fixed.append(
            type(e)(
                log_id=f"L1-{i}",
                timestamp=e.timestamp,
                source=e.source,
                event_type=e.event_type,
                fields=e.fields,
            )
        )
    return fixed


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def generate_incident(stage_id: str, seed: int) -> IncidentParams:
    """Deterministically generate an `IncidentParams` for the given stage.

    Args:
        stage_id: A key of `STAGE_CONFIGS`.
        seed: Seed for the private `random.Random` that drives every choice.

    Returns:
        An `IncidentParams` whose events are the chosen template's core
        events plus the stage's decoys, sorted by timestamp and renumbered.
        `narrative` is left empty.

    Raises:
        ValueError: If `stage_id` is not a key of `STAGE_CONFIGS`.
    """
    if stage_id not in STAGE_CONFIGS:
        raise ValueError(
            f"Unknown stage_id {stage_id!r}; choose from {list(STAGE_CONFIGS)}"
        )
    cfg = STAGE_CONFIGS[stage_id]
    rng = random.Random(seed)

    # Draw order (label → template → core events → decoys) is part of the
    # seeding contract; do not reorder.
    label = _sample_label(rng, cfg["label_distribution"])
    template = rng.choice(TEMPLATES[label])
    category, core_events = template(rng)

    decoy_events = _make_decoy_events(rng, cfg["decoys"], start_idx=len(core_events))
    events = _renumber_and_resort(core_events + decoy_events)

    return IncidentParams(
        target_label=label,
        category=category,
        events=events,
        narrative="",
    )


def make_alert(params: IncidentParams, alert_id: str) -> Alert:
    """Synthesize a SIEM alert summary from an incident's events.

    Severity is derived from `params.target_label` ("medium" for labels not
    in the table). `user` and `host` are read from the first event's fields,
    falling back to ``"user-001"`` / ``"host-001"`` when absent.

    An incident with no events does not raise: the fallbacks are used and
    the summary reports ``first=none``.

    Args:
        params: The incident to summarize.
        alert_id: Identifier to store on the returned alert.

    Returns:
        An `Alert` carrying the incident's category, the derived severity,
        a one-line summary, and the host/user described above.
    """
    sev_for_label = {
        TriageAction.DISMISS: "low",
        TriageAction.MONITOR: "medium",
        TriageAction.QUARANTINE_HOST: "high",
        TriageAction.BLOCK_IP: "high",
        TriageAction.ESCALATE: "critical",
    }
    severity = sev_for_label.get(params.target_label, "medium")

    # Incidents need not come from `generate_incident`, so tolerate an empty
    # (or missing) event list instead of failing with a bare IndexError.
    events = params.events or []
    if events:
        first_event = events[0]
        first_fields = first_event.fields
        first_type = first_event.event_type.value
    else:
        first_fields = {}
        first_type = "none"

    user = str(first_fields.get("user", "user-001"))
    host = str(first_fields.get("host", "host-001"))
    summary = f"{params.category.value}: {len(events)} event(s); first={first_type}"
    return Alert(
        alert_id=alert_id,
        category=params.category,
        severity=severity,
        summary=summary,
        host=host,
        user=user,
    )


__all__ = ["generate_incident", "make_alert", "STAGE_CONFIGS", "TEMPLATES"]