"""Tests for `train.prompt_format` and the SFT dataset round-trip. Run with: pytest tests/test_prompt_format.py -v """ from __future__ import annotations import json import os import sys sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) from generator import generate_incident, make_alert # noqa: E402 from schema import TriageAction # noqa: E402 from train.prompt_format import ( # noqa: E402 parse_defender_response, render_defender_prompt, render_defender_target, ) from verifier import compute_ground_truth # noqa: E402 class TestPromptFormat: def test_parse_round_trip(self): rendered = render_defender_target( action=TriageAction.QUARANTINE_HOST, cited_log_id="L1-7", rationale="encoded powershell from outlook is malware", ) parsed = parse_defender_response(rendered) assert parsed.action is TriageAction.QUARANTINE_HOST assert parsed.cited_log_id == "L1-7" assert parsed.format_ok def test_parse_handles_extra_whitespace(self): text = "Action: block_ip\nCitedLog: L1-2\nRationale: external beacon" p = parse_defender_response(text) assert p.action is TriageAction.BLOCK_IP assert p.cited_log_id == "L1-2" assert p.format_ok def test_parse_rejects_unknown_action(self): text = "Action: yolo\nCitedLog: L1-0\nRationale: nope" p = parse_defender_response(text) assert p.action is None assert not p.format_ok def test_parse_returns_format_ok_false_on_garbage(self): text = "Sure! I think we should block the IP and call IT." p = parse_defender_response(text) assert not p.format_ok def test_render_prompt_contains_all_log_ids(self): params = generate_incident("stage2_multi", seed=99) alert = make_alert(params, "A-TEST") prompt = render_defender_prompt(alert, params.events) for e in params.events: assert e.log_id in prompt assert alert.alert_id in prompt assert alert.summary in prompt class TestSftDataset: DATASET = os.path.join( os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data", "sft_train.jsonl", ) def test_dataset_exists_and_targets_are_well_formed(self): assert os.path.exists(self.DATASET), "Run `python -m train.make_sft_dataset` first." n = 0 with open(self.DATASET) as f: for line in f: ex = json.loads(line) assert ex["messages"][0]["role"] == "system" assert ex["messages"][1]["role"] == "user" assert ex["messages"][2]["role"] == "assistant" parsed = parse_defender_response(ex["messages"][2]["content"]) assert parsed.format_ok, ex["messages"][2]["content"] assert parsed.action.value == ex["ground_truth"] n += 1 assert n >= 100 # we asked for 600 def test_dataset_targets_match_verifier(self): # Cross-check: re-run the verifier and confirm SFT targets agree. with open(self.DATASET) as f: for i, line in enumerate(f): if i >= 50: break # spot-check; full check is expensive ex = json.loads(line) params = generate_incident(ex["stage"], ex["seed"]) gt, _ = compute_ground_truth(params) assert gt.value == ex["ground_truth"], ( f"verifier disagrees with SFT target at line {i}: " f"{gt.value} != {ex['ground_truth']}" )