opensoc-env / train /prompt_format.py
shivam2k3's picture
OpenSOC v1
bb6a031
"""Shared prompt formatter for the defender role.

The same formatter is used by:

* SFT dataset generation (`make_sft_dataset.py`)
* GRPO rollouts (`train_grpo.ipynb`)
* Evaluation harness (`eval/eval.py`)

Keeping the prompt template centralized prevents the classic train-eval
mismatch where a model regresses because the inference prompt differs
slightly from the training prompt.

The defender response format is intentionally machine-parseable:

    Action: <one of dismiss|monitor|quarantine_host|block_ip|escalate>
    CitedLog: <log_id>
    Rationale: <one short sentence>

The trainer rewards format compliance via the `format_ok` flag returned by
`parse_defender_response` (used as an auxiliary GRPO reward in the notebook).
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass
from typing import Dict, Optional
from schema import Alert, Event, TriageAction
# System message shared by training and evaluation. The fragments are joined
# verbatim (no separator), so every space and newline below is part of the
# prompt the model sees; change them only together with the training data.
SYSTEM_PROMPT = "".join(
    [
        # Role and task description (one paragraph).
        "You are an autonomous SOC tier-1 triage analyst. ",
        "You will receive a single SIEM alert and the small log window that triggered it. ",
        "Your job is to choose exactly ONE triage action and cite the log_id that drove your decision.\n\n",
        # Closed set of actions the model may answer with.
        "Allowed actions (highest cost last): dismiss, monitor, quarantine_host, block_ip, escalate.\n\n",
        # Required three-line response template.
        "Always respond in this exact format:\n",
        "Action: <action>\n",
        "CitedLog: <log_id>\n",
        "Rationale: <one short sentence>",
    ]
)
def _event_to_line(e: Event) -> str:
    """Format one event as a single SIEM-style log line.

    The line has the shape
    ``[timestamp] log_id src=<source> type=<event_type> k1=v1 k2=v2 ...``.
    Fields whose value is ``None`` or ``""`` are left out, string values are
    quoted with ``json.dumps`` and every other value is interpolated as-is.
    Trailing whitespace is stripped, so an event without usable fields does
    not end in a dangling space.
    """
    tokens = []
    for key, value in (e.fields or {}).items():
        if value in (None, ""):
            # Empty values carry no signal for the model; skip them.
            continue
        shown = json.dumps(value) if isinstance(value, str) else value
        tokens.append(f"{key}={shown}")
    rendered_fields = " ".join(tokens)

    line = f"[{e.timestamp}] {e.log_id} src={e.source} type={e.event_type.value} {rendered_fields}"
    return line.rstrip()
def render_defender_prompt(alert: Alert, events: list[Event]) -> str:
    """Build the user-turn text for one defender triage decision.

    The message consists of an alert header (id, severity, category, host,
    user, summary), the log window with one rendered line per event in the
    order given, and a closing instruction.
    """
    window = "\n".join(map(_event_to_line, events))
    sections = [
        f"Alert {alert.alert_id} | severity={alert.severity} | category={alert.category.value}\n",
        f"Host: {alert.host} | User: {alert.user}\n",
        f"Summary: {alert.summary}\n\n",
        f"Log window:\n{window}\n\n",
        "Triage this alert.",
    ]
    # Each section already carries its own trailing newline(s).
    return "".join(sections)
def render_defender_target(action: TriageAction, cited_log_id: str, rationale: str) -> str:
    """Build the gold three-line response used as the SFT target.

    The output uses the ``Action`` / ``CitedLog`` / ``Rationale`` layout, one
    field per line, with no trailing newline.
    """
    action_line = f"Action: {action.value}\n"
    cite_line = f"CitedLog: {cited_log_id}\n"
    rationale_line = f"Rationale: {rationale}"
    return action_line + cite_line + rationale_line
@dataclass
class ParsedDefenderResponse:
    """Structured result of parsing one defender model output.

    Instances are built by `parse_defender_response`.
    """

    # Recognized triage action; None when no action was found or the value
    # is not a valid `TriageAction`.
    action: Optional[TriageAction]
    # Log id captured from the CitedLog field; None when it was not found.
    cited_log_id: Optional[str]
    # Rationale text with surrounding whitespace stripped; "" when absent.
    rationale: str
    # True only when action, cited_log_id and rationale were all obtained.
    format_ok: bool
# Field extractors for the three-line defender response format.
#
# * ``(?<![A-Za-z])`` keeps a label from matching inside a longer word, so
#   "Reaction:" or "Transaction:" is not mistaken for the "Action:" label.
# * For Action and CitedLog only spaces/tabs may sit between the colon and
#   the value. With ``\s*`` an empty "CitedLog:" line would swallow the line
#   break and capture the next label ("Rationale") as the log id.
# * Rationale is the last field, so its value may start on a following line
#   and runs to the end of the text (DOTALL).
_ACTION_RE = re.compile(r"(?<![A-Za-z])action\s*:[ \t]*([a-z_]+)", re.IGNORECASE)
_CITE_RE = re.compile(
    r"(?<![A-Za-z])cited\s*log\s*:[ \t]*([A-Za-z0-9\-_]+)", re.IGNORECASE
)
_RATIONALE_RE = re.compile(
    r"(?<![A-Za-z])rationale\s*:\s*(.+)", re.IGNORECASE | re.DOTALL
)


def parse_defender_response(text: str) -> ParsedDefenderResponse:
    """Best-effort parse of a defender model output.

    Each field is located independently (first occurrence, case-insensitive),
    so extra text around the three fields is tolerated.

    Args:
        text: Raw model completion. A non-string value (e.g. ``None``) is
            treated as an empty completion instead of raising.

    Returns:
        A `ParsedDefenderResponse`. ``action`` is ``None`` when the field is
        missing or is not a recognized `TriageAction`; ``cited_log_id`` is
        ``None`` when the field is missing or empty; ``rationale`` is ``""``
        when absent. ``format_ok`` is True only if all three fields parse and
        the action is recognized. GRPO rollouts can use this as a small
        +0.05 format-compliance bonus.
    """
    if not isinstance(text, str):
        # Best-effort contract: a missing completion is a format failure,
        # not a crash in the reward function.
        text = ""

    action_match = _ACTION_RE.search(text)
    cite_match = _CITE_RE.search(text)
    rat_match = _RATIONALE_RE.search(text)

    action: Optional[TriageAction] = None
    if action_match:
        try:
            action = TriageAction(action_match.group(1).lower())
        except ValueError:
            # Syntactically fine but not one of the allowed actions.
            action = None

    cited = cite_match.group(1) if cite_match else None
    rationale = rat_match.group(1).strip() if rat_match else ""

    # Test the action with ``is not None`` so the result does not depend on
    # how the enum defines truthiness.
    format_ok = action is not None and bool(cited) and bool(rationale)
    return ParsedDefenderResponse(
        action=action,
        cited_log_id=cited,
        rationale=rationale,
        format_ok=format_ok,
    )
def to_chat_messages(alert: Alert, events: list[Event]) -> list[Dict[str, str]]:
    """Return the two-message chat (system, then user) for a defender turn.

    The system message carries `SYSTEM_PROMPT`; the user message is the text
    rendered by `render_defender_prompt` for the given alert and events.
    """
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": render_defender_prompt(alert, events)}
    return [system_message, user_message]
# Names exported by a star-import of this module; underscore-prefixed
# helpers and the compiled regexes are intentionally not listed.
__all__ = [
    "SYSTEM_PROMPT",
    "render_defender_prompt",
    "render_defender_target",
    "parse_defender_response",
    "to_chat_messages",
    "ParsedDefenderResponse",
]