| """Tests for `train.prompt_format` and the SFT dataset round-trip. |
| |
| Run with: pytest tests/test_prompt_format.py -v |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import sys |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) |
|
|
| from generator import generate_incident, make_alert |
| from schema import TriageAction |
| from train.prompt_format import ( |
| parse_defender_response, |
| render_defender_prompt, |
| render_defender_target, |
| ) |
| from verifier import compute_ground_truth |
|
|
|
|
class TestPromptFormat:
    """Unit tests for defender prompt/target rendering and response parsing."""

    def test_parse_round_trip(self):
        """A target produced by `render_defender_target` parses back losslessly."""
        rendered = render_defender_target(
            action=TriageAction.QUARANTINE_HOST,
            cited_log_id="L1-7",
            rationale="encoded powershell from outlook is malware",
        )
        parsed = parse_defender_response(rendered)
        assert parsed.action is TriageAction.QUARANTINE_HOST
        assert parsed.cited_log_id == "L1-7"
        assert parsed.format_ok

    def test_parse_handles_extra_whitespace(self):
        """Extra spaces around field values do not break parsing."""
        # Padding after each label and at the end of each line; without it
        # this test would not exercise the behaviour its name describes.
        text = (
            "Action:   block_ip  \n"
            "CitedLog:   L1-2  \n"
            "Rationale:   external beacon  "
        )
        p = parse_defender_response(text)
        assert p.action is TriageAction.BLOCK_IP
        assert p.cited_log_id == "L1-2"
        assert p.format_ok

    def test_parse_rejects_unknown_action(self):
        """An action string outside `TriageAction` yields no action and a format failure."""
        text = "Action: yolo\nCitedLog: L1-0\nRationale: nope"
        p = parse_defender_response(text)
        assert p.action is None
        assert not p.format_ok

    def test_parse_returns_format_ok_false_on_garbage(self):
        """Free-form chat text without the expected fields is not format-ok."""
        text = "Sure! I think we should block the IP and call IT."
        p = parse_defender_response(text)
        assert not p.format_ok

    def test_render_prompt_contains_all_log_ids(self):
        """The rendered prompt mentions the alert and every event's log id."""
        params = generate_incident("stage2_multi", seed=99)
        alert = make_alert(params, "A-TEST")
        prompt = render_defender_prompt(alert, params.events)
        # Guard against a vacuous pass: with no events the loop below checks nothing.
        assert params.events, "generated incident has no events to render"
        for e in params.events:
            assert e.log_id in prompt
        assert alert.alert_id in prompt
        assert alert.summary in prompt
|
|
|
|
class TestSftDataset:
    """Consistency checks on the generated SFT dataset (``data/sft_train.jsonl``).

    The dataset is produced by ``python -m train.make_sft_dataset``; every test
    fails with that hint if the file is missing.
    """

    DATASET = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "data", "sft_train.jsonl",
    )

    # How many leading examples are regenerated and re-scored by the verifier.
    # NOTE(review): presumably capped to bound test runtime — confirm.
    VERIFIER_SAMPLE = 50

    def _iter_examples(self):
        """Yield ``(line_index, example)`` for each non-blank line of the JSONL file.

        ``line_index`` is the 0-based physical line number, used in failure
        messages. Blank lines (e.g. an extra empty line at end of file) are
        skipped rather than fed to ``json.loads``.
        """
        assert os.path.exists(self.DATASET), "Run `python -m train.make_sft_dataset` first."
        # Explicit encoding: JSONL is UTF-8, and the platform default may differ.
        with open(self.DATASET, encoding="utf-8") as f:
            for i, line in enumerate(f):
                if not line.strip():
                    continue
                yield i, json.loads(line)

    def test_dataset_exists_and_targets_are_well_formed(self):
        """Every example has system/user/assistant turns and a parseable, correct target."""
        n = 0
        for i, ex in self._iter_examples():
            messages = ex["messages"]
            roles = [m["role"] for m in messages[:3]]
            assert roles == ["system", "user", "assistant"], (
                f"unexpected message roles at line {i}: {roles}"
            )
            target = messages[2]["content"]
            parsed = parse_defender_response(target)
            assert parsed.format_ok, target
            assert parsed.action.value == ex["ground_truth"], (
                f"target action disagrees with ground_truth at line {i}: "
                f"{parsed.action.value} != {ex['ground_truth']}"
            )
            n += 1
        assert n >= 100, f"expected at least 100 SFT examples, found {n}"

    def test_dataset_targets_match_verifier(self):
        """Re-deriving ground truth from (stage, seed) reproduces the SFT label."""
        checked = 0
        for i, ex in self._iter_examples():
            if checked >= self.VERIFIER_SAMPLE:
                break
            params = generate_incident(ex["stage"], seed=ex["seed"])
            gt, _ = compute_ground_truth(params)
            assert gt.value == ex["ground_truth"], (
                f"verifier disagrees with SFT target at line {i}: "
                f"{gt.value} != {ex['ground_truth']}"
            )
            checked += 1
        # Without this, an empty dataset would make the test pass vacuously.
        assert checked > 0, "SFT dataset contains no examples to check against the verifier."
|
|