File size: 3,785 Bytes
bb6a031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Generate the SFT warm-start dataset for OpenSOC defender training.

Run::

    python -m train.make_sft_dataset --n 600 --out data/sft_train.jsonl

Each output line is a JSON object with::

    { "messages": [ {"role": "system", ...}, {"role": "user", ...},
                    {"role": "assistant", ...} ],
      "ground_truth": "<action>",
      "stage": "<stage_id>",
      "seed": <int> }

Format-compliance is enforced by construction: targets are produced via
`prompt_format.render_defender_target` and stay bit-exact with what the
GRPO rollout parser expects.

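Distribution
------------
The ``--n`` examples are split across the 4 curriculum stages in
``STAGE_REGISTRY``, with each stage getting an (approximately) equal share.
Labels are *not* rebalanced: every example keeps the ground truth computed
for its incident, so rare classes such as ESCALATE (only ~10-15% of stage 1)
stay rare. Check the label distribution printed at the end of a run.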
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from collections import Counter
from typing import Dict

# Directory containing this file; also used by main() to resolve --out.
_HERE = os.path.dirname(os.path.abspath(__file__))
# Prepend the parent of that directory to sys.path so the project modules
# imported below (generator, schema, tasks, train, verifier) are looked up
# there first.
sys.path.insert(0, os.path.dirname(_HERE))

from generator import generate_incident, make_alert  # noqa: E402
from schema import TriageAction  # noqa: E402
from tasks.registry import STAGE_REGISTRY  # noqa: E402
from train.prompt_format import (  # noqa: E402
    SYSTEM_PROMPT,
    render_defender_prompt,
    render_defender_target,
)
from verifier import compute_ground_truth  # noqa: E402


# Canned one-line rationale per ground-truth action. make_example() passes
# RATIONALES[action] as the ``rationale`` of the assistant target, so every
# action compute_ground_truth() can return needs an entry here; a missing
# one makes make_example() raise KeyError.
RATIONALES: Dict[TriageAction, str] = {
    TriageAction.DISMISS: "Benign noise; no compromise indicators present.",
    TriageAction.MONITOR: "Weak signal; keep watching for follow-up activity.",
    TriageAction.QUARANTINE_HOST: "Endpoint compromise indicator; isolate the host.",
    TriageAction.BLOCK_IP: "External malicious network indicator; block the destination.",
    TriageAction.ESCALATE: "Multi-stage compromise; page the on-call responder.",
}


def make_example(stage_id: str, seed: int) -> dict:
    """Return one chat-style SFT record for the ``(stage_id, seed)`` pair.

    The record holds a system/user/assistant message triple whose assistant
    turn is the rendered ground-truth answer, followed by the label value,
    the stage id and the seed so the example can be traced to its incident.
    """
    incident = generate_incident(stage_id, seed)
    # NOTE(review): only the last character of stage_id goes into the alert
    # id, so stages whose ids end in the same character share an id prefix.
    alert = make_alert(incident, alert_id=f"A-SFT-{stage_id[-1]}-{seed:05d}")
    action, signal = compute_ground_truth(incident)

    prompt = render_defender_prompt(alert, incident.events)
    # Cite the signal's triggering log when it has one (truthy); otherwise
    # fall back to the incident's first event.
    evidence_id = signal.triggering_log_id or incident.events[0].log_id
    answer = render_defender_target(
        action=action,
        cited_log_id=evidence_id,
        rationale=RATIONALES[action],
    )

    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": answer},
    ]
    return {
        "messages": conversation,
        "ground_truth": action.value,
        "stage": stage_id,
        "seed": seed,
    }


def main() -> None:
    """Parse CLI flags, write the SFT dataset as JSONL and print label stats.

    ``--n`` examples are divided across every stage in ``STAGE_REGISTRY`` as
    evenly as possible. Each stage gets ``n // len(stages)`` examples, and
    the first ``n % len(stages)`` stages get one more, so exactly ``--n``
    lines are written. Within each stage the seeds are ``seed_base``,
    ``seed_base + 1``, ..., so different stages reuse the same seed values.

    A relative ``--out`` is resolved against the parent of this file's
    directory; an absolute ``--out`` is used as-is (``os.path.join``
    semantics). Missing parent directories are created.

    Exits with an error message when ``--n`` is negative or when
    ``STAGE_REGISTRY`` has no stages.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n", type=int, default=600,
                        help="Total number of examples (split across stages).")
    parser.add_argument("--out", type=str, default="data/sft_train.jsonl")
    parser.add_argument("--seed-base", type=int, default=10_000)
    args = parser.parse_args()

    if args.n < 0:
        parser.error("--n must be non-negative")

    stages = list(STAGE_REGISTRY.keys())
    if not stages:
        # Previously this surfaced as a ZeroDivisionError on the split below.
        sys.exit("STAGE_REGISTRY is empty; no stages to sample from.")

    # divmod rather than plain floor division so the remainder is not
    # silently dropped (e.g. --n 601 over 4 stages used to write only 600).
    per_stage, remainder = divmod(args.n, len(stages))

    counts: Counter = Counter()
    out_path = os.path.join(os.path.dirname(_HERE), args.out)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    written = 0
    with open(out_path, "w", encoding="utf-8") as f:
        for stage_idx, stage_id in enumerate(stages):
            # The first `remainder` stages absorb one leftover example each.
            quota = per_stage + (1 if stage_idx < remainder else 0)
            for i in range(quota):
                ex = make_example(stage_id, seed=args.seed_base + i)
                f.write(json.dumps(ex) + "\n")
                counts[ex["ground_truth"]] += 1
                written += 1

    print(f"Wrote {written} examples to {out_path}")
    print("Label distribution:")
    # counts is empty when nothing was written, so the division is safe.
    for k, v in sorted(counts.items()):
        print(f"  {k:<18} {v:4d} ({100 * v / written:5.1f}%)")


# Run the CLI only when executed as a script / via ``python -m``, not on import.
if __name__ == "__main__":
    main()