"""Generate the SFT warm-start dataset for OpenSOC defender training. Run:: python -m train.make_sft_dataset --n 600 --out data/sft_train.jsonl Each output line is a JSON object with:: { "messages": [ {"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...} ], "ground_truth": "", "stage": "", "seed": } Format-compliance is enforced by construction: targets are produced via `prompt_format.render_defender_target` and stay bit-exact with what the GRPO rollout parser expects. Distribution ------------ We balance across the 4 curriculum stages and across the 5 ground-truth labels so the model sees enough rare-class examples (ESCALATE is only ~10-15% of stage 1). """ from __future__ import annotations import argparse import json import os import sys from collections import Counter from typing import Dict _HERE = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.dirname(_HERE)) from generator import generate_incident, make_alert # noqa: E402 from schema import TriageAction # noqa: E402 from tasks.registry import STAGE_REGISTRY # noqa: E402 from train.prompt_format import ( # noqa: E402 SYSTEM_PROMPT, render_defender_prompt, render_defender_target, ) from verifier import compute_ground_truth # noqa: E402 RATIONALES: Dict[TriageAction, str] = { TriageAction.DISMISS: "Benign noise; no compromise indicators present.", TriageAction.MONITOR: "Weak signal; keep watching for follow-up activity.", TriageAction.QUARANTINE_HOST: "Endpoint compromise indicator; isolate the host.", TriageAction.BLOCK_IP: "External malicious network indicator; block the destination.", TriageAction.ESCALATE: "Multi-stage compromise; page the on-call responder.", } def make_example(stage_id: str, seed: int) -> dict: """Build a single SFT example for the given (stage, seed).""" params = generate_incident(stage_id, seed) alert = make_alert(params, alert_id=f"A-SFT-{stage_id[-1]}-{seed:05d}") gt, sig = compute_ground_truth(params) user_msg = render_defender_prompt(alert, params.events) target = render_defender_target( action=gt, cited_log_id=sig.triggering_log_id or params.events[0].log_id, rationale=RATIONALES[gt], ) return { "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_msg}, {"role": "assistant", "content": target}, ], "ground_truth": gt.value, "stage": stage_id, "seed": seed, } def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--n", type=int, default=600, help="Total number of examples (split across stages).") parser.add_argument("--out", type=str, default="data/sft_train.jsonl") parser.add_argument("--seed-base", type=int, default=10_000) args = parser.parse_args() stages = list(STAGE_REGISTRY.keys()) per_stage = args.n // len(stages) counts: Counter = Counter() out_path = os.path.join(os.path.dirname(_HERE), args.out) os.makedirs(os.path.dirname(out_path), exist_ok=True) written = 0 with open(out_path, "w", encoding="utf-8") as f: for stage_id in stages: for i in range(per_stage): ex = make_example(stage_id, seed=args.seed_base + i) f.write(json.dumps(ex) + "\n") counts[ex["ground_truth"]] += 1 written += 1 print(f"Wrote {written} examples to {out_path}") print("Label distribution:") for k, v in sorted(counts.items()): print(f" {k:<18} {v:4d} ({100 * v / written:5.1f}%)") if __name__ == "__main__": main()