# opensoc-env / eval / make_holdout.py
# shivam2k3's picture
# OpenSOC v1
# bb6a031
"""Build the frozen 200-incident hold-out evaluation set.
Run::
python -m eval.make_holdout --out data/holdout.jsonl
This file is committed to the repo so reviewers can verify reported
numbers byte-for-byte without rerunning the generator. The seeds used
here are *outside* the SFT and GRPO seed bands declared in
`tasks/registry.py` (seed_offset 1k-4k for training, 90k-94k here) so
there is zero overlap between train and eval.
Each record::
{ "alert": {...}, "events": [...], "ground_truth": "<action>",
"triggering_log_id": "<id>", "stage": "<stage>", "seed": <int> }
`eval/eval.py` consumes this format directly.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from collections import Counter
# Absolute path of the directory that contains this file.
_HERE = os.path.dirname(os.path.abspath(__file__))
# Put the parent of that directory at the front of the module search path so
# the `generator` and `verifier` imports below can be resolved from there
# (presumably those modules live in that parent directory — confirm).
sys.path.insert(0, os.path.dirname(_HERE))
from generator import generate_incident, make_alert # noqa: E402
from verifier import compute_ground_truth # noqa: E402
# Hold-out seed bands: the first seed used for each stage. Bases start at
# 90_000 and are spaced 1_000 apart in stage order, which keeps them distinct
# from the training seed bands.
HOLDOUT_SEED_BAND = {
    stage: 90_000 + 1_000 * position
    for position, stage in enumerate(
        ("stage1_basic", "stage2_multi", "stage3_mixed", "stage4_adversarial")
    )
}
def main() -> None:
    """Generate the hold-out incidents and write them to a JSONL file.

    Command-line options:
        --n-per-stage  Number of incidents generated for each stage listed in
                       ``HOLDOUT_SEED_BAND`` (default 50). Must be >= 1.
        --out          Output path, joined onto the parent directory of this
                       file's directory (default ``data/holdout.jsonl``). An
                       absolute path is used as given.

    For every stage, incident ``i`` uses seed ``base + i`` where ``base`` is
    the stage's entry in ``HOLDOUT_SEED_BAND``. One JSON object is written per
    line, then a label-distribution summary is printed to stdout.

    Raises:
        SystemExit: raised by ``argparse`` (exit status 2) when the arguments
            are invalid, including a non-positive ``--n-per-stage``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-per-stage", type=int, default=50,
                        help="Number of incidents per stage (default 50 → 200 total).")
    parser.add_argument("--out", default="data/holdout.jsonl")
    args = parser.parse_args()

    # Validate before touching the output: mode "w" truncates, so a zero or
    # negative count would silently replace an existing hold-out file with an
    # empty one.
    if args.n_per_stage < 1:
        parser.error("--n-per-stage must be a positive integer")

    out_path = os.path.join(os.path.dirname(_HERE), args.out)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    counts: Counter = Counter()  # ground-truth label -> number of records
    written = 0
    # newline="\n" disables platform newline translation so the file has the
    # same bytes on every OS (text mode would otherwise emit "\r\n" on
    # Windows), which the byte-for-byte verification of this file relies on.
    with open(out_path, "w", encoding="utf-8", newline="\n") as f:
        for stage_id, base in HOLDOUT_SEED_BAND.items():
            for i in range(args.n_per_stage):
                seed = base + i
                params = generate_incident(stage_id, seed)
                # NOTE(review): stage_id[-1] is the last character of the
                # stage name (e.g. "c" for "stage1_basic"), not the stage
                # number. Left unchanged so regenerated alert ids match
                # previously generated files — confirm whether intended.
                alert = make_alert(params, alert_id=f"A-EVAL-{stage_id[-1]}-{seed}")
                gt, sig = compute_ground_truth(params)
                rec = {
                    "alert": alert.model_dump(mode="json"),
                    "events": [e.model_dump(mode="json") for e in params.events],
                    "ground_truth": gt.value,
                    # Fall back to the first event's log id when the signal
                    # carries no (or a falsy) triggering log id.
                    "triggering_log_id": sig.triggering_log_id or params.events[0].log_id,
                    "stage": stage_id,
                    "seed": seed,
                }
                f.write(json.dumps(rec) + "\n")
                counts[gt.value] += 1
                written += 1

    print(f"Wrote {written} hold-out incidents to {out_path}")
    print("Label distribution:")
    for k, v in sorted(counts.items()):
        print(f" {k:<18} {v:4d} ({100 * v / written:5.1f}%)")


# Run the generator only when this file is executed as a script or via -m.
if __name__ == "__main__":
    main()