opensoc-env / train /make_sft_dataset.py
shivam2k3's picture
OpenSOC v1
bb6a031
"""Generate the SFT warm-start dataset for OpenSOC defender training.
Run::
python -m train.make_sft_dataset --n 600 --out data/sft_train.jsonl
Each output line is a JSON object with::
{ "messages": [ {"role": "system", ...}, {"role": "user", ...},
{"role": "assistant", ...} ],
"ground_truth": "<action>",
"stage": "<stage_id>",
"seed": <int> }
Format-compliance is enforced by construction: targets are produced via
`prompt_format.render_defender_target` and stay bit-exact with what the
GRPO rollout parser expects.
Distribution
------------
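We split the examples evenly across the 4 curriculum stages. The 5
ground-truth labels are not rebalanced by this script: each example keeps
the label the verifier computes, and the resulting label distribution is
printed at the end so rare classes can be checked (ESCALATE is only
~10-15% of stage 1).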
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from collections import Counter
from typing import Dict
# Absolute path of the directory that contains this file.
_HERE = os.path.dirname(os.path.abspath(__file__))
# Put the parent of that directory at the front of sys.path before the
# project-local imports below run (hence their `noqa: E402` markers).
# NOTE(review): presumably this is what lets `generator`, `schema`, `tasks`,
# `train` and `verifier` resolve when the file is executed directly - confirm.
sys.path.insert(0, os.path.dirname(_HERE))
from generator import generate_incident, make_alert # noqa: E402
from schema import TriageAction # noqa: E402
from tasks.registry import STAGE_REGISTRY # noqa: E402
from train.prompt_format import ( # noqa: E402
SYSTEM_PROMPT,
render_defender_prompt,
render_defender_target,
)
from verifier import compute_ground_truth # noqa: E402
# Fixed one-sentence rationale for each triage action. `make_example` looks up
# the entry for the computed ground-truth action and passes it as the
# `rationale` argument when rendering the assistant target.
RATIONALES: Dict[TriageAction, str] = {
TriageAction.DISMISS: "Benign noise; no compromise indicators present.",
TriageAction.MONITOR: "Weak signal; keep watching for follow-up activity.",
TriageAction.QUARANTINE_HOST: "Endpoint compromise indicator; isolate the host.",
TriageAction.BLOCK_IP: "External malicious network indicator; block the destination.",
TriageAction.ESCALATE: "Multi-stage compromise; page the on-call responder.",
}
def make_example(stage_id: str, seed: int) -> dict:
    """Return one chat-formatted SFT record for ``(stage_id, seed)``.

    An incident is generated from the stage and seed, an alert is built for
    it, and the ground-truth action computed for the incident becomes the
    assistant target. The returned dict holds the ``messages`` triple
    (system, user, assistant) plus ``ground_truth``, ``stage`` and ``seed``.
    """
    incident = generate_incident(stage_id, seed)
    alert_id = f"A-SFT-{stage_id[-1]}-{seed:05d}"
    alert = make_alert(incident, alert_id=alert_id)
    truth, signal = compute_ground_truth(incident)
    prompt_text = render_defender_prompt(alert, incident.events)

    # Cite the triggering log id when one is set; otherwise fall back to the
    # first event of the incident.
    cited = signal.triggering_log_id
    if not cited:
        cited = incident.events[0].log_id

    answer_text = render_defender_target(
        action=truth,
        cited_log_id=cited,
        rationale=RATIONALES[truth],
    )

    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": prompt_text}
    assistant_turn = {"role": "assistant", "content": answer_text}

    record = {"messages": [system_turn, user_turn, assistant_turn]}
    record["ground_truth"] = truth.value
    record["stage"] = stage_id
    record["seed"] = seed
    return record
def main() -> None:
    """Generate the SFT dataset and write it as JSON Lines.

    Command-line options:
        --n          Total number of examples, split across the stages in
                     ``STAGE_REGISTRY``. When ``n`` is not a multiple of the
                     stage count, the leftover examples go to the earliest
                     stages so that exactly ``n`` lines are written.
        --out        Output path, joined onto the parent of this file's
                     directory.
        --seed-base  First seed; example ``i`` of every stage uses
                     ``seed_base + i``.

    Prints the number of examples written and the ground-truth label
    distribution.

    Raises:
        SystemExit: if ``--n`` is negative or ``STAGE_REGISTRY`` is empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n", type=int, default=600,
                        help="Total number of examples (split across stages).")
    parser.add_argument("--out", type=str, default="data/sft_train.jsonl")
    parser.add_argument("--seed-base", type=int, default=10_000)
    args = parser.parse_args()
    if args.n < 0:
        parser.error("--n must be non-negative")

    stages = list(STAGE_REGISTRY.keys())
    if not stages:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise SystemExit("STAGE_REGISTRY is empty; nothing to generate.")

    # Floor division alone would silently drop `n % len(stages)` examples;
    # hand the leftovers out one each to the earliest stages instead.
    per_stage, leftover = divmod(args.n, len(stages))

    counts: Counter = Counter()
    out_path = os.path.join(os.path.dirname(_HERE), args.out)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    written = 0
    with open(out_path, "w", encoding="utf-8") as f:
        for stage_index, stage_id in enumerate(stages):
            quota = per_stage + (1 if stage_index < leftover else 0)
            for i in range(quota):
                ex = make_example(stage_id, seed=args.seed_base + i)
                f.write(json.dumps(ex) + "\n")
                counts[ex["ground_truth"]] += 1
                written += 1

    print(f"Wrote {written} examples to {out_path}")
    print("Label distribution:")
    for k, v in sorted(counts.items()):
        print(f" {k:<18} {v:4d} ({100 * v / written:5.1f}%)")


# Run the generator only when executed as a script / via ``python -m``.
if __name__ == "__main__":
    main()