"""Shared prompt formatter for the defender role. The same formatter is used by: * SFT dataset generation (`make_sft_dataset.py`) * GRPO rollouts (`train_grpo.ipynb`) * Evaluation harness (`eval/eval.py`) Keeping the prompt template centralized prevents the classic train-eval mismatch where a model regresses because the inference prompt differs slightly from the training prompt. The defender response format is intentionally machine-parseable: Action: CitedLog: Rationale: The trainer rewards format compliance via `parse_defender_response`'s `format_ok` flag (used as an auxiliary GRPO reward in the notebook). """ from __future__ import annotations import json import re from dataclasses import dataclass from typing import Dict, Optional from schema import Alert, Event, TriageAction SYSTEM_PROMPT = ( "You are an autonomous SOC tier-1 triage analyst. " "You will receive a single SIEM alert and the small log window that triggered it. " "Your job is to choose exactly ONE triage action and cite the log_id that drove your decision.\n\n" "Allowed actions (highest cost last): dismiss, monitor, quarantine_host, block_ip, escalate.\n\n" "Always respond in this exact format:\n" "Action: \n" "CitedLog: \n" "Rationale: " ) def _event_to_line(e: Event) -> str: """Render a single event as a SIEM-style log line.""" fields = {k: v for k, v in (e.fields or {}).items() if v not in (None, "")} fields_str = " ".join( f"{k}={json.dumps(v) if isinstance(v, str) else v}" for k, v in fields.items() ) return ( f"[{e.timestamp}] {e.log_id} src={e.source} type={e.event_type.value} {fields_str}" ).rstrip() def render_defender_prompt(alert: Alert, events: list[Event]) -> str: """Render the user message for a defender turn.""" log_lines = "\n".join(_event_to_line(e) for e in events) return ( f"Alert {alert.alert_id} | severity={alert.severity} | category={alert.category.value}\n" f"Host: {alert.host} | User: {alert.user}\n" f"Summary: {alert.summary}\n\n" f"Log window:\n{log_lines}\n\n" f"Triage this alert." ) def render_defender_target(action: TriageAction, cited_log_id: str, rationale: str) -> str: """Render the gold response for SFT.""" return ( f"Action: {action.value}\n" f"CitedLog: {cited_log_id}\n" f"Rationale: {rationale}" ) @dataclass class ParsedDefenderResponse: action: Optional[TriageAction] cited_log_id: Optional[str] rationale: str format_ok: bool _ACTION_RE = re.compile(r"action\s*:\s*([a-z_]+)", re.IGNORECASE) _CITE_RE = re.compile(r"cited\s*log\s*:\s*([A-Za-z0-9\-_]+)", re.IGNORECASE) _RATIONALE_RE = re.compile(r"rationale\s*:\s*(.+)", re.IGNORECASE | re.DOTALL) def parse_defender_response(text: str) -> ParsedDefenderResponse: """Best-effort parse of a defender model output. Returns ``format_ok=True`` only if all three fields parse and the action is a recognized `TriageAction`. GRPO rollouts can use this as a small +0.05 format-compliance bonus. """ action_match = _ACTION_RE.search(text) cite_match = _CITE_RE.search(text) rat_match = _RATIONALE_RE.search(text) action: Optional[TriageAction] = None if action_match: try: action = TriageAction(action_match.group(1).lower()) except ValueError: action = None cited = cite_match.group(1) if cite_match else None rationale = rat_match.group(1).strip() if rat_match else "" format_ok = bool(action and cited and rationale) return ParsedDefenderResponse( action=action, cited_log_id=cited, rationale=rationale, format_ok=format_ok, ) def to_chat_messages(alert: Alert, events: list[Event]) -> list[Dict[str, str]]: return [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": render_defender_prompt(alert, events)}, ] __all__ = [ "SYSTEM_PROMPT", "render_defender_prompt", "render_defender_target", "parse_defender_response", "to_chat_messages", "ParsedDefenderResponse", ]