| """Generate the SFT warm-start dataset for OpenSOC defender training. |
| |
| Run:: |
| |
| python -m train.make_sft_dataset --n 600 --out data/sft_train.jsonl |
| |
| Each output line is a JSON object with:: |
| |
| { "messages": [ {"role": "system", ...}, {"role": "user", ...}, |
| {"role": "assistant", ...} ], |
| "ground_truth": "<action>", |
| "stage": "<stage_id>", |
| "seed": <int> } |
| |
| Format-compliance is enforced by construction: targets are produced via |
| `prompt_format.render_defender_target` and stay bit-exact with what the |
| GRPO rollout parser expects. |
| |
| Distribution |
| ------------ |
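Examples are split evenly across the curriculum stages registered in
``STAGE_REGISTRY``. Labels are *not* rebalanced: each stage contributes its
natural ground-truth mix, so rare classes such as ESCALATE (only ~10-15% of
stage 1) stay rare. The script prints the resulting label distribution so the
imbalance can be checked before training.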
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| from collections import Counter |
| from typing import Dict |
|
|
# Directory that holds this script. Its parent (the repository root) is pushed
# to the front of sys.path so the project's top-level modules imported below
# (generator, schema, tasks, train, verifier) resolve regardless of the
# current working directory.
_HERE = os.path.split(os.path.abspath(__file__))[0]
sys.path.insert(0, os.path.split(_HERE)[0])
|
|
| from generator import generate_incident, make_alert |
| from schema import TriageAction |
| from tasks.registry import STAGE_REGISTRY |
| from train.prompt_format import ( |
| SYSTEM_PROMPT, |
| render_defender_prompt, |
| render_defender_target, |
| ) |
| from verifier import compute_ground_truth |
|
|
|
|
# Fixed one-line justification used as the assistant's rationale for each
# ground-truth action when rendering SFT targets (one entry per TriageAction
# member that make_example may look up).
RATIONALES: Dict[TriageAction, str] = dict(
    [
        (TriageAction.DISMISS,
         "Benign noise; no compromise indicators present."),
        (TriageAction.MONITOR,
         "Weak signal; keep watching for follow-up activity."),
        (TriageAction.QUARANTINE_HOST,
         "Endpoint compromise indicator; isolate the host."),
        (TriageAction.BLOCK_IP,
         "External malicious network indicator; block the destination."),
        (TriageAction.ESCALATE,
         "Multi-stage compromise; page the on-call responder."),
    ]
)
|
|
|
|
def make_example(stage_id: str, seed: int) -> dict:
    """Return one chat-formatted SFT record for the ``(stage_id, seed)`` pair.

    The incident comes from ``generate_incident``, the expected action from the
    verifier's ``compute_ground_truth``, and the assistant turn is rendered with
    ``render_defender_target`` using the canned rationale for that action. The
    cited log is the verifier's triggering log when one is reported, otherwise
    the first event of the incident.

    Returns:
        A dict with keys ``messages`` (system/user/assistant turns),
        ``ground_truth`` (the action's ``.value``), ``stage`` and ``seed``.
    """
    incident = generate_incident(stage_id, seed)
    # Alert id encodes only the last character of the stage id plus the seed.
    alert = make_alert(incident, alert_id=f"A-SFT-{stage_id[-1]}-{seed:05d}")
    action, signal = compute_ground_truth(incident)
    prompt_text = render_defender_prompt(alert, incident.events)

    # Fall back to the first event when the verifier names no triggering log.
    cited = signal.triggering_log_id or incident.events[0].log_id
    answer_text = render_defender_target(
        action=action,
        cited_log_id=cited,
        rationale=RATIONALES[action],
    )

    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt_text},
        {"role": "assistant", "content": answer_text},
    ]
    # Key order is kept stable because it determines the serialized JSON line.
    return {
        "messages": conversation,
        "ground_truth": action.value,
        "stage": stage_id,
        "seed": seed,
    }
|
|
|
|
def _stage_quotas(total: int, n_stages: int) -> list[int]:
    """Split ``total`` examples across ``n_stages`` stages as evenly as possible.

    The first ``total % n_stages`` stages receive one extra example, so the
    quotas always sum to ``total`` (plain floor division would silently drop
    the remainder). ``n_stages`` must be positive.
    """
    base, extra = divmod(total, n_stages)
    return [base + (1 if i < extra else 0) for i in range(n_stages)]


def main() -> None:
    """Parse CLI arguments, generate SFT examples and write them as JSONL.

    ``--n`` examples are split across every stage in ``STAGE_REGISTRY``;
    within each stage the examples use seeds ``seed_base, seed_base + 1, ...``
    (the same seed range is reused for every stage). ``--out`` is resolved
    relative to the repository root (the parent of this file's directory)
    unless it is absolute; its parent directory is created if missing.
    A per-label histogram is printed at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n", type=int, default=600,
                        help="Total number of examples (split across stages).")
    parser.add_argument("--out", type=str, default="data/sft_train.jsonl")
    parser.add_argument("--seed-base", type=int, default=10_000,
                        help="First seed; stage example i uses seed_base + i.")
    args = parser.parse_args()

    if args.n < 0:
        parser.error(f"--n must be non-negative, got {args.n}")

    stages = list(STAGE_REGISTRY.keys())
    if not stages:
        # Previously this surfaced as a ZeroDivisionError from args.n // 0.
        sys.exit("STAGE_REGISTRY is empty; no stages to generate examples for.")
    quotas = _stage_quotas(args.n, len(stages))

    counts: Counter = Counter()
    out_path = os.path.join(os.path.dirname(_HERE), args.out)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    written = 0
    with open(out_path, "w", encoding="utf-8") as f:
        for stage_id, quota in zip(stages, quotas):
            for i in range(quota):
                ex = make_example(stage_id, seed=args.seed_base + i)
                f.write(json.dumps(ex) + "\n")
                counts[ex["ground_truth"]] += 1
                written += 1

    print(f"Wrote {written} examples to {out_path}")
    print("Label distribution:")
    # When written == 0 the counter is empty, so the division never runs.
    for k, v in sorted(counts.items()):
        print(f" {k:<18} {v:4d} ({100 * v / written:5.1f}%)")


if __name__ == "__main__":
    main()
|
|