# opensoc-env/tests/test_prompt_format.py
# shivam2k3 -- "OpenSOC v1" (commit bb6a031)
"""Tests for `train.prompt_format` and the SFT dataset round-trip.
Run with: pytest tests/test_prompt_format.py -v
"""
from __future__ import annotations
import json
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from generator import generate_incident, make_alert # noqa: E402
from schema import TriageAction # noqa: E402
from train.prompt_format import ( # noqa: E402
parse_defender_response,
render_defender_prompt,
render_defender_target,
)
from verifier import compute_ground_truth # noqa: E402
class TestPromptFormat:
    """Checks for the defender prompt renderer and the response parser."""

    def test_parse_round_trip(self):
        """A rendered target parses back to the same action and cited log id."""
        expected_action = TriageAction.QUARANTINE_HOST
        expected_log_id = "L1-7"
        target_text = render_defender_target(
            action=expected_action,
            cited_log_id=expected_log_id,
            rationale="encoded powershell from outlook is malware",
        )

        result = parse_defender_response(target_text)

        assert result.action is expected_action
        assert result.cited_log_id == expected_log_id
        assert result.format_ok

    def test_parse_handles_extra_whitespace(self):
        """A plain three-line response parses into action and cited log id."""
        # NOTE(review): despite the test name, this input contains no extra
        # whitespace as written -- confirm whether padding was lost and
        # should be restored.
        response = "Action: block_ip\nCitedLog: L1-2\nRationale: external beacon"

        result = parse_defender_response(response)

        assert result.action is TriageAction.BLOCK_IP
        assert result.cited_log_id == "L1-2"
        assert result.format_ok

    def test_parse_rejects_unknown_action(self):
        """An action outside the known set yields no action and a bad format."""
        response = "Action: yolo\nCitedLog: L1-0\nRationale: nope"

        result = parse_defender_response(response)

        assert result.action is None
        assert not result.format_ok

    def test_parse_returns_format_ok_false_on_garbage(self):
        """Free-form prose without the expected fields is flagged as malformed."""
        response = "Sure! I think we should block the IP and call IT."

        result = parse_defender_response(response)

        assert not result.format_ok

    def test_render_prompt_contains_all_log_ids(self):
        """The rendered prompt mentions every event log id and the alert itself."""
        incident = generate_incident("stage2_multi", seed=99)
        alert = make_alert(incident, "A-TEST")

        rendered = render_defender_prompt(alert, incident.events)

        # Collect every id that failed to appear so a failure shows all of them.
        absent_ids = [
            event.log_id
            for event in incident.events
            if event.log_id not in rendered
        ]
        assert not absent_ids
        assert alert.alert_id in rendered
        assert alert.summary in rendered
class TestSftDataset:
    """Sanity checks for the generated SFT training set (``data/sft_train.jsonl``).

    Each JSONL record is read as a dict with a ``messages`` list whose first
    three entries are the system, user and assistant turns, plus
    ``ground_truth``, ``stage`` and ``seed`` keys.
    """

    # Located under ``data/`` one directory above the one containing this file.
    DATASET = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "data", "sft_train.jsonl",
    )
    # Number of records re-verified by the spot check; the full check is expensive.
    SPOT_CHECK_LIMIT = 50

    def _load_examples(self, limit=None):
        """Return ``(lineno, example)`` pairs parsed from the dataset file.

        Args:
            limit: Maximum number of records to load, or ``None`` for all.

        Returns:
            A list of ``(lineno, example)`` tuples, where ``lineno`` is the
            1-based line number in the file and ``example`` the decoded JSON.

        Raises:
            AssertionError: If the dataset file is missing or a line is not
                valid JSON.
        """
        assert os.path.exists(self.DATASET), "Run `python -m train.make_sft_dataset` first."
        examples = []
        # Explicit encoding so decoding does not depend on the platform locale.
        with open(self.DATASET, encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if limit is not None and len(examples) >= limit:
                    break
                if not line.strip():
                    # Tolerate blank / trailing lines instead of crashing in json.
                    continue
                try:
                    example = json.loads(line)
                except json.JSONDecodeError as err:
                    raise AssertionError(
                        f"invalid JSON at line {lineno} of {self.DATASET}: {err}"
                    ) from err
                examples.append((lineno, example))
        return examples

    def test_dataset_exists_and_targets_are_well_formed(self):
        """Every record has the expected roles and a parseable, matching target."""
        examples = self._load_examples()
        for lineno, ex in examples:
            # Slice rather than index so a short message list fails the
            # assertion with context instead of raising IndexError.
            roles = [message["role"] for message in ex["messages"][:3]]
            assert roles == ["system", "user", "assistant"], (
                f"line {lineno}: unexpected roles {roles}"
            )
            target = ex["messages"][2]["content"]
            parsed = parse_defender_response(target)
            assert parsed.format_ok, target
            assert parsed.action is not None, target
            assert parsed.action.value == ex["ground_truth"]
        assert len(examples) >= 100  # we asked for 600

    def test_dataset_targets_match_verifier(self):
        """Re-run the verifier on a sample and confirm SFT targets agree."""
        # Spot-check only the first records; the full check is expensive.
        for lineno, ex in self._load_examples(limit=self.SPOT_CHECK_LIMIT):
            params = generate_incident(ex["stage"], seed=ex["seed"])
            gt, _ = compute_ground_truth(params)
            assert gt.value == ex["ground_truth"], (
                f"verifier disagrees with SFT target at line {lineno}: "
                f"{gt.value} != {ex['ground_truth']}"
            )