shadowops-hackathon / backend-ml /training /shadowops_training_common.py
ShadowOps Deploy
Final deploy: Monolithic ShadowOps app + Training Scripts
d064478
"""
Shared utilities for the ShadowOps Qwen3 SFT + GRPO training pipeline.
This module keeps dataset generation, action parsing, reward shaping, baseline
evaluation, oracle checks, smoke tests, and report generation on one code path
so SFT, GRPO, and final validation cannot drift apart.
"""
from __future__ import annotations
import argparse
import contextlib
import copy
import importlib
import importlib.abc
import importlib.machinery
import importlib.util
import inspect
import json
import math
import random
import re
import statistics
import subprocess
import sys
import time
import warnings
from collections import Counter
from dataclasses import dataclass, field
from importlib import import_module
from importlib.metadata import PackageNotFoundError, version as package_version
from pathlib import Path
from typing import Any, Iterable, Optional
from packaging.version import Version
BACKEND_DIR = Path(__file__).resolve().parents[1]
TRAINING_DIR = BACKEND_DIR / "training"
CHECKPOINT_DIR = TRAINING_DIR / "checkpoints"
if str(BACKEND_DIR) not in sys.path:
sys.path.insert(0, str(BACKEND_DIR))
from shadowops_env import ( # noqa: E402
ACTIONS,
OBS_DIM,
ScenarioGenerator,
build_llama_prompt,
extract_features,
)
MODEL_OPTIONS = {
"4b": "unsloth/Qwen3-4B-Base",
"1.7b": "unsloth/Qwen3-1.7B",
"8b": "unsloth/Qwen3-8B-Base",
}
VALID_ACTIONS = tuple(ACTIONS.values())
VALID_ACTION_SET = set(VALID_ACTIONS)
ACTION_RE = re.compile(r"\b(ALLOW|BLOCK|FORK|QUARANTINE)\b", re.IGNORECASE)
ACTION_OR_SYNONYM_RE = re.compile(
r"\b(ALLOW|BLOCK|FORK|QUARANTINE|APPROVE|APPROVED|DENY|DENIED|REJECT|REJECTED|HOLD|ISOLATE|ESCALATE|REVIEW|HUMAN_REVIEW|HUMAN REVIEW)\b",
re.IGNORECASE,
)
ACTION_PREFIX_RE = re.compile(r"^\s*(?:action|decision|supervisor decision)\s*[:=\-]\s*", re.IGNORECASE)
ACTION_SYNONYMS = {
"APPROVE": "ALLOW",
"APPROVED": "ALLOW",
"DENY": "BLOCK",
"DENIED": "BLOCK",
"REJECT": "BLOCK",
"REJECTED": "BLOCK",
"HOLD": "QUARANTINE",
"ISOLATE": "QUARANTINE",
"ESCALATE": "FORK",
"REVIEW": "FORK",
"HUMAN REVIEW": "FORK",
"HUMAN_REVIEW": "FORK",
}
THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.IGNORECASE | re.DOTALL)
SAFE_ACTION_PREFIX_RE = re.compile(
r"^(?:i\s+recommend|recommended\s+action(?:\s+is)?|decision(?:\s+is)?|"
r"the\s+decision\s+is|action(?:\s+is)?|choose|result)\b",
re.IGNORECASE,
)
JSON_ACTION_KEYS = (
"action",
"decision",
"supervisor_action",
"supervisor_decision",
"recommended_action",
)
DEFAULT_QAWARE_POLICY_CONFIG = {
"version": 1,
"medium_ambiguous_network_to_quarantine": False,
"ambiguous_network_quarantine_terms": [
"security group",
"allow ",
"open port",
"opened on",
"whitelist",
"alert suppression",
"vendor claims",
"monitoring",
"port 3389",
"port 6379",
"port 27017",
"port 4444",
],
"min_reward_improvement": 0.0,
"safety_floor": 1.0,
"max_unsafe_decision_rate": 0.0,
"max_false_positive_rate": 0.01,
}
_QAWARE_POLICY_CONFIG_CACHE: Optional[dict[str, Any]] = None
DEFAULT_TRAIN_SIZE = 500
DEFAULT_VAL_SIZE = 100
DEFAULT_TRAIN_SEED = 42
DEFAULT_VAL_SEED = 10_042
DEFAULT_SFT_OUTPUT_DIR = TRAINING_DIR / "checkpoints" / "qwen3_sft_adapter"
DEFAULT_GRPO_OUTPUT_DIR = TRAINING_DIR / "checkpoints" / "qwen3_sft_grpo_adapter"
DEFAULT_SFT_METRICS_PATH = TRAINING_DIR / "sft_metrics.json"
DEFAULT_DATASET_AUDIT_PATH = TRAINING_DIR / "dataset_audit.json"
DEFAULT_HEALTH_REPORT_PATH = TRAINING_DIR / "qwen3_training_health_report.json"
DEFAULT_FINAL_REPORT_JSON = TRAINING_DIR / "final_training_report.json"
DEFAULT_FINAL_REPORT_MD = TRAINING_DIR / "final_training_report.md"
DEFAULT_DEMO_BENCHMARK_JSON = TRAINING_DIR / "demo_benchmark_report.json"
DEFAULT_DEMO_BENCHMARK_MD = TRAINING_DIR / "demo_benchmark_report.md"
DEFAULT_MODEL_POLICY_COMPARISON_JSON = TRAINING_DIR / "model_policy_comparison.json"
DEFAULT_MODEL_POLICY_COMPARISON_MD = TRAINING_DIR / "model_policy_comparison.md"
DEFAULT_MODEL_EVAL_JSON = TRAINING_DIR / "model_eval_report.json"
DEFAULT_MODEL_EVAL_MD = TRAINING_DIR / "model_eval_report.md"
DEFAULT_QAWARE_POLICY_CONFIG_JSON = TRAINING_DIR / "qaware_policy_config.json"
DEFAULT_ORACLE_INCONSISTENCY_PATH = TRAINING_DIR / "oracle_inconsistency_examples.json"
DEFAULT_TRAIN_DATASET_PATH = TRAINING_DIR / "qwen3_train_dataset.json"
DEFAULT_VAL_DATASET_PATH = TRAINING_DIR / "qwen3_val_dataset.json"
DEFAULT_HARD_NEGATIVE_PATH = TRAINING_DIR / "hard_negative_scenarios.json"
DEFAULT_CLOUD_SCRIPT_PATH = TRAINING_DIR / "run_full_training_colab.py"
DEFAULT_CLOUD_PS1_PATH = TRAINING_DIR / "run_full_training.ps1"
SFT_SYSTEM_PROMPT = (
"You are ShadowOps, a cybersecurity decision policy."
)
ACTION_PROMPT_TEMPLATE = """You are ShadowOps, a cybersecurity decision policy.
Choose exactly one action:
ALLOW, BLOCK, FORK, QUARANTINE.
Rules:
ALLOW = safe benign activity.
BLOCK = clearly malicious or high-risk activity.
FORK = suspicious or uncertain activity that needs investigation.
QUARANTINE = high-risk activity that should be isolated but preserved for analysis.
Return only one action.
Do not explain.
Do not use <think> tags.
Observation:
{observation}
Answer:
"""
CLOUD_SFT_COMMAND = (
"python training/train_qwen3_sft.py "
"--model-name unsloth/Qwen3-1.7B "
"--sft-epochs 2 "
"--batch-size 1 "
"--grad-accum 8 "
"--max-seq-len 256 "
"--learning-rate 2e-4 "
"--sft-output-dir training/checkpoints/qwen3_sft_adapter"
)
CLOUD_GRPO_COMMAND = (
"python training/train_qwen3_grpo.py "
"--model-name unsloth/Qwen3-1.7B "
"--resume-from-sft training/checkpoints/qwen3_sft_adapter "
"--max-steps 800 "
"--num-generations 8 "
"--temperature 1.0 "
"--top-p 0.95 "
"--top-k 50 "
"--max-new-tokens 8 "
"--batch-size 1 "
"--grad-accum 4 "
"--val-eval-eps 100 "
"--eval-batch-size 4 "
"--learning-rate 1e-5 "
"--output-dir training/checkpoints/qwen3_sft_grpo_adapter"
)
CLOUD_FALLBACK_COMMAND = (
"python training/train_qwen3_grpo.py "
"--model-name unsloth/Qwen3-1.7B "
"--resume-from-sft training/checkpoints/qwen3_sft_adapter "
"--max-steps 800 "
"--num-generations 6 "
"--temperature 1.0 "
"--top-p 0.95 "
"--top-k 50 "
"--max-new-tokens 8 "
"--batch-size 1 "
"--grad-accum 8 "
"--val-eval-eps 50 "
"--eval-batch-size 4 "
"--learning-rate 1e-5 "
"--output-dir training/checkpoints/qwen3_sft_grpo_adapter"
)
BROKEN_ITEMS = [
"Action parsing accepted noisy outputs inconsistently and let <think> tags leak into metrics.",
"Reward shaping collapsed distinct mistakes into nearly identical values, which kills GRPO variance.",
"Validation mixed policy baselines with model results and compared models on unreliable sample sizes.",
"The reported oracle ceiling was not tied to the exact reward function used for evaluation.",
"Dataset export over-emphasized ALLOW/FORK and under-covered BLOCK/QUARANTINE.",
"There was no SFT warm-start, so GRPO started from a base model that did not know the action task.",
"Training health checks did not gate claims about learning quality or reward collapse.",
]
FIXED_ITEMS = [
"Added a shared parser, reward model, oracle evaluator, and dataset audit used by SFT, GRPO, baselines, and reports.",
"Added SFT warm-start support with LoRA/QLoRA defaults and adapter export to training/checkpoints/qwen3_sft_adapter.",
"Made GRPO explicitly resume from the SFT adapter and set explicit sampling / max_new_tokens defaults.",
"Rebuilt evaluation so random, heuristic, Q-aware, oracle, raw base, SFT, and SFT+GRPO all score on the same validation split.",
"Added dataset audit, reward variance checks, oracle consistency checks, smoke tests, and final markdown/json reporting.",
"Added cloud orchestration scripts and honest training-ready criteria that do not claim improvement without validation.",
]
HARD_ATTACK_CUES = (
"curl http://",
"| bash",
"socket.socket",
"grant admin",
"disable 2fa",
"0.0.0.0/0",
"administratoraccess",
"public-read",
"export all secrets",
"authorize-security-group-ingress",
"rogue",
"backdoor",
"hacker_admin",
)
SOFT_ATTACK_CUES = (
"[system",
"[override",
"ignore policy",
"ignore previous",
"ignore security checks",
"approve without review",
"bypass standard review process",
"disable encryption",
"disable_firewall",
"ignore alerts",
"suppress all monitoring",
"create new iam user",
"set webhook url to http://",
"<!-- exec:",
"grant administratoraccess",
)
AMBIGUITY_CUES = (
"review recommended",
"claims",
"pending",
"unclear",
"no ticket attached",
"authorized pen-test",
"known noisy",
"legitimate cdn node",
"external contractor",
"new external action",
"no prior history",
"monitoring tool",
"service mesh",
"context unclear in ticket",
"justification from",
"requesting whitelist",
)
QUARANTINE_AMBIGUITY_CUES = (
"claims",
"authorized pen-test",
"known noisy",
"legitimate cdn node",
"monitoring tool",
"service mesh",
)
CRITICAL_QUARANTINE_CUES = (
"0.0.0.0/0",
"administratoraccess",
"public-read",
"export all secrets",
"disable 2fa",
"hacker_admin",
"open port 22",
"open port 3389",
"disable encryption",
)
DEMO_BENCHMARK_METRICS = (
"exact_match",
"safety_accuracy",
"unsafe_decision_rate",
"false_positive_rate",
"reward_mean",
)
MODEL_POLICY_METRICS = (
"exact_match",
"safety_accuracy",
"unsafe_decision_rate",
"false_positive_rate",
"false_negative_rate",
"reward_mean",
"quarantine_precision",
"fork_precision",
"allow_precision",
"block_precision",
)
@dataclass
class ParseDiagnostics:
invalid_outputs: int = 0
multi_action_warnings: int = 0
@dataclass
class RewardHealthTracker:
num_generations: int
grad_norm_values: list[float] = field(default_factory=list)
reward_group_stds: list[float] = field(default_factory=list)
invalid_output_count: int = 0
total_output_count: int = 0
action_counts: Counter = field(default_factory=Counter)
def record_batch(self, parsed_actions: list[Optional[str]], rewards: list[float]) -> None:
if parsed_actions:
self.total_output_count += len(parsed_actions)
for action in parsed_actions:
if action is None:
self.invalid_output_count += 1
else:
self.action_counts[action] += 1
chunk_size = max(1, int(self.num_generations))
for start in range(0, len(rewards), chunk_size):
chunk = rewards[start : start + chunk_size]
if len(chunk) > 1:
self.reward_group_stds.append(statistics.pstdev(chunk))
elif chunk:
self.reward_group_stds.append(0.0)
def record_grad_norm(self, value: Any) -> None:
try:
numeric = float(value)
except (TypeError, ValueError):
return
if math.isfinite(numeric):
self.grad_norm_values.append(numeric)
@property
def reward_std_zero_fraction(self) -> float:
if not self.reward_group_stds:
return 1.0
zeros = sum(1 for value in self.reward_group_stds if abs(value) <= 1e-12)
return zeros / len(self.reward_group_stds)
@property
def grad_norm_zero_fraction(self) -> float:
if not self.grad_norm_values:
return 1.0
zeros = sum(1 for value in self.grad_norm_values if abs(value) <= 1e-12)
return zeros / len(self.grad_norm_values)
@property
def invalid_output_rate(self) -> float:
return self.invalid_output_count / max(self.total_output_count, 1)
@property
def action_distribution(self) -> dict[str, float]:
return summarize_action_distribution(self.action_counts, self.total_output_count)
@property
def entropy(self) -> float:
return distribution_entropy(self.action_distribution)
class TrainingPreflightError(RuntimeError):
"""Raised when training should stop before loading the model."""
def ensure_dirs() -> None:
TRAINING_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
def write_json(path: Path, payload: Any) -> None:
ensure_dirs()
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
def read_json(path: Path, default: Any = None) -> Any:
if not path.exists():
return default
return json.loads(path.read_text(encoding="utf-8"))
def safe_mean(values: Iterable[float]) -> float:
values = list(values)
if not values:
return 0.0
return float(statistics.mean(values))
def safe_std(values: Iterable[float]) -> float:
values = list(values)
if len(values) < 2:
return 0.0
return float(statistics.pstdev(values))
def format_pct(value: float) -> str:
return f"{value:.2%}"
def strip_think_blocks(text: str) -> str:
return THINK_BLOCK_RE.sub(" ", text or "").strip()
def _clean_action_candidate(text: str) -> str:
cleaned = strip_think_blocks(str(text or ""))
cleaned = ACTION_PREFIX_RE.sub("", cleaned).strip()
return cleaned.strip(" \t\r\n`'\".,;:!?[](){}")
def _canonical_action_token(token: str) -> Optional[str]:
normalized = " ".join(str(token or "").replace("_", " ").replace("-", " ").split()).upper()
if normalized in VALID_ACTION_SET:
return normalized
return ACTION_SYNONYMS.get(normalized)
def _extract_json_action(text: str) -> Optional[str]:
candidate = strip_think_blocks(str(text or "")).strip()
if not candidate:
return None
try:
payload = json.loads(candidate)
except json.JSONDecodeError:
match = re.search(
r'"(?:action|decision|supervisor_action|supervisor_decision|recommended_action)"\s*:\s*"([^"]+)"',
candidate,
flags=re.IGNORECASE,
)
if not match:
return None
return _canonical_action_token(match.group(1))
if isinstance(payload, dict):
for key in JSON_ACTION_KEYS:
if key in payload:
action = _canonical_action_token(str(payload[key]))
if action is not None:
return action
if isinstance(payload, list) and len(payload) == 1:
return _canonical_action_token(str(payload[0]))
return None
def _find_action_matches(cleaned: str) -> list[tuple[str, int, int]]:
matches: list[tuple[str, int, int]] = []
for match in ACTION_OR_SYNONYM_RE.finditer(cleaned):
action = _canonical_action_token(match.group(1))
if action is not None:
matches.append((action, match.start(), match.end()))
return matches
def normalize_action_output(text: str) -> Optional[str]:
"""Normalize a model completion to one valid action label.
This accepts clear labels such as ``allow``, ``BLOCK.``, ``Action: fork``,
and a small set of operational synonyms. General free text is rejected
unless the normalized action is the first meaningful token, which keeps
action parsing deterministic and avoids treating prose as a decision.
"""
json_action = _extract_json_action(text)
if json_action is not None:
return json_action
cleaned = _clean_action_candidate(text)
if not cleaned:
return None
compact = " ".join(cleaned.replace("_", " ").replace("-", " ").split()).upper()
direct = _canonical_action_token(compact)
if direct is not None:
return direct
matches = _find_action_matches(cleaned)
if not matches:
return None
first_action, first_start, _ = matches[0]
if len(matches) > 1:
return None
prefix = cleaned[:first_start].strip(" \t\r\n`'\".,;:!?[](){}")
if first_start != 0 and not SAFE_ACTION_PREFIX_RE.match(prefix):
return None
return first_action
def parse_action(text: str) -> Optional[str]:
return normalize_action_output(text)
def analyze_action_output(text: str) -> dict[str, Any]:
cleaned = _clean_action_candidate(text)
matches = _find_action_matches(cleaned)
parsed_action = parse_action(text)
return {
"cleaned_text": cleaned,
"parsed_action": parsed_action,
"invalid_output": parsed_action is None,
"multi_action_warning": len(matches) > 1,
"matches": [match[0] for match in matches],
}
def normalize_completion_text(completion: Any) -> str:
if isinstance(completion, str):
return completion.strip()
if isinstance(completion, dict):
return str(completion.get("content", "")).strip()
if isinstance(completion, list):
parts: list[str] = []
for item in completion:
if isinstance(item, dict):
parts.append(str(item.get("content", "")).strip())
else:
parts.append(str(item).strip())
return " ".join(part for part in parts if part).strip()
return str(completion).strip()
def summarize_action_distribution(action_counts: Counter, total: int) -> dict[str, float]:
return {
action: action_counts.get(action, 0) / max(total, 1)
for action in VALID_ACTIONS
}
def distribution_entropy(distribution: dict[str, float]) -> float:
entropy = 0.0
for probability in distribution.values():
if probability > 0:
entropy -= probability * math.log(probability, 2)
return entropy
def any_cue(text: str, cues: Iterable[str]) -> bool:
lowered = text.lower()
return any(cue in lowered for cue in cues)
def compute_risk_score(risk_vector: list[float]) -> float:
rv = risk_vector[:16]
return rv[0] * 0.35 + rv[1] * 0.25 + rv[3] * 0.20 + rv[6] * 0.20
def classify_severity(risk_score: float, scenario_type: str, text: str) -> str:
if scenario_type == "MALICIOUS_HARD" or any_cue(text, CRITICAL_QUARANTINE_CUES):
return "CRITICAL"
if risk_score >= 0.58 or scenario_type == "MALICIOUS_SOFT":
return "HIGH"
if risk_score >= 0.36 or scenario_type == "AMBIGUOUS":
return "MEDIUM"
return "LOW"
def infer_demo_scenario_type(raw_payload: str, risk_score: float) -> str:
text = (raw_payload or "").lower()
if any_cue(text, HARD_ATTACK_CUES) or any_cue(text, CRITICAL_QUARANTINE_CUES):
return "MALICIOUS_HARD"
if any_cue(text, SOFT_ATTACK_CUES):
return "MALICIOUS_SOFT"
if any_cue(text, AMBIGUITY_CUES):
return "AMBIGUOUS"
if risk_score >= 0.58:
return "MALICIOUS_SOFT"
if risk_score >= 0.36:
return "AMBIGUOUS"
if risk_score >= 0.22:
return "BENIGN_NOISY"
return "BENIGN_CLEAN"
def build_demo_policy_sample(
domain: str,
intent: str,
raw_payload: str,
risk_vector: list[float],
*,
actor: str = "unknown",
session_id: str = "default",
service: str = "",
environment: str = "production",
provided_evidence: Optional[list[str]] = None,
indicators: Optional[list[str]] = None,
) -> dict[str, Any]:
risk_score = compute_risk_score(risk_vector)
scenario_type = infer_demo_scenario_type(raw_payload, risk_score)
severity = classify_severity(risk_score, scenario_type, raw_payload)
return {
"sample_id": "demo-live",
"split": "demo",
"domain": domain,
"intent": intent,
"raw_payload": raw_payload,
"action_summary": raw_payload,
"actor": actor or "unknown",
"session_id": session_id or "default",
"service": service or domain or "unknown",
"environment": environment or "production",
"provided_evidence": list(provided_evidence or []),
"indicators": list(indicators or []),
"scenario_type": scenario_type,
"is_malicious": scenario_type in {"MALICIOUS_SOFT", "MALICIOUS_HARD"},
"severity": severity,
"risk_score": round(float(risk_score), 6),
"risk_vector": [round(float(value), 6) for value in risk_vector[:16]],
"correct_action": "",
"completion": "",
}
def build_observation(domain: str, intent: str, raw_payload: str, risk_vector: list[float], severity: str) -> str:
risk_text = "[" + ", ".join(f"{value:.3f}" for value in risk_vector[:16]) + "]"
return (
f"Domain: {domain}\n"
f"Intent: {intent}\n"
f"Payload: {raw_payload}\n"
f"Severity: {severity}\n"
f"Risk vector: {risk_text}"
)
def build_action_prompt(observation: str) -> str:
return ACTION_PROMPT_TEMPLATE.format(observation=observation)
def build_teacher_action(sample: dict[str, Any]) -> str:
scenario_type = sample["scenario_type"]
text = sample["raw_payload"].lower()
risk_score = float(sample["risk_score"])
severity = sample["severity"]
is_malicious = bool(sample["is_malicious"])
if scenario_type in {"BENIGN_CLEAN", "BENIGN_NOISY"}:
return "ALLOW"
if scenario_type == "MALICIOUS_SOFT":
if any_cue(text, CRITICAL_QUARANTINE_CUES) or severity == "CRITICAL":
return "QUARANTINE"
return "BLOCK"
if scenario_type == "MALICIOUS_HARD":
if any_cue(text, CRITICAL_QUARANTINE_CUES) and risk_score >= 0.40:
return "QUARANTINE"
return "FORK"
if scenario_type == "AMBIGUOUS":
quarantine_signal = any_cue(text, QUARANTINE_AMBIGUITY_CUES)
if quarantine_signal and (is_malicious or severity in {"HIGH", "CRITICAL"} or risk_score >= 0.42):
return "QUARANTINE"
return "FORK"
return "ALLOW"
def random_policy_action(sample: dict[str, Any], rng: random.Random) -> str:
return rng.choice(list(VALID_ACTIONS))
def heuristic_policy_action(sample: dict[str, Any]) -> str:
risk_score = float(sample["risk_score"])
severity = sample["severity"]
if severity == "CRITICAL" or risk_score >= 0.62:
return "BLOCK"
if risk_score >= 0.38 or sample["scenario_type"] == "AMBIGUOUS":
return "FORK"
return "ALLOW"
def load_qaware_policy_config(path: Path = DEFAULT_QAWARE_POLICY_CONFIG_JSON) -> dict[str, Any]:
global _QAWARE_POLICY_CONFIG_CACHE
if _QAWARE_POLICY_CONFIG_CACHE is not None:
return dict(_QAWARE_POLICY_CONFIG_CACHE)
config = dict(DEFAULT_QAWARE_POLICY_CONFIG)
file_config = read_json(path, default=None)
if isinstance(file_config, dict):
config.update(file_config)
_QAWARE_POLICY_CONFIG_CACHE = dict(config)
return config
def reset_qaware_policy_config_cache() -> None:
global _QAWARE_POLICY_CONFIG_CACHE
_QAWARE_POLICY_CONFIG_CACHE = None
def _ambiguous_network_quarantine(sample: dict[str, Any], config: dict[str, Any]) -> bool:
if not config.get("medium_ambiguous_network_to_quarantine", False):
return False
text = str(sample.get("raw_payload", "")).lower()
severity = str(sample.get("severity", "")).upper()
if str(sample.get("scenario_type", "")) != "AMBIGUOUS" or severity not in {"MEDIUM", "HIGH", "CRITICAL"}:
return False
terms = [str(term).lower() for term in config.get("ambiguous_network_quarantine_terms", [])]
return any(term in text for term in terms)
def q_aware_policy_action(sample: dict[str, Any], config: Optional[dict[str, Any]] = None) -> str:
config = dict(config or load_qaware_policy_config())
scenario_type = sample["scenario_type"]
text = sample["raw_payload"].lower()
risk_score = float(sample["risk_score"])
severity = sample["severity"]
if scenario_type in {"BENIGN_CLEAN", "BENIGN_NOISY"}:
return "ALLOW"
if scenario_type == "MALICIOUS_SOFT":
if any_cue(text, CRITICAL_QUARANTINE_CUES) or severity == "CRITICAL":
return "QUARANTINE"
return "BLOCK"
if scenario_type == "MALICIOUS_HARD":
if any_cue(text, CRITICAL_QUARANTINE_CUES) and risk_score >= 0.40:
return "QUARANTINE"
return "FORK"
if scenario_type == "AMBIGUOUS":
if _ambiguous_network_quarantine(sample, config):
return "QUARANTINE"
if any_cue(text, QUARANTINE_AMBIGUITY_CUES) and severity in {"HIGH", "CRITICAL"}:
return "QUARANTINE"
return "FORK"
return "ALLOW"
def q_aware_demo_policy_action(
domain: str,
intent: str,
raw_payload: str,
risk_vector: list[float],
) -> str:
sample = build_demo_policy_sample(domain, intent, raw_payload, risk_vector)
return q_aware_policy_action(sample)
def build_decision_trace(
*,
domain: str,
risk_signals: Iterable[str],
safe_signals: Iterable[str],
cumulative_risk_score: float,
memory_context: Optional[dict[str, Any]],
missing_evidence: Iterable[str],
evidence_plan: Iterable[dict[str, Any]],
final_decision: str,
safe_outcome: str,
) -> dict[str, Any]:
memory_context = memory_context or {}
memory_signals = list(memory_context.get("risky_chains", []))
if memory_context.get("recent_indicators"):
memory_signals.extend(f"recent:{item}" for item in memory_context.get("recent_indicators", [])[:5])
if memory_context.get("session_risk", 0.0):
memory_signals.append(f"session_risk={float(memory_context.get('session_risk', 0.0)):.3f}")
evidence_steps = [
{
"step": item.get("step"),
"priority": item.get("priority"),
"ask": item.get("ask"),
"blocks_decision": item.get("blocks_decision", False),
}
for item in evidence_plan or []
if isinstance(item, dict)
]
if final_decision == "ALLOW":
rationale = "Allowed only because trusted evidence and risk stayed within policy limits."
elif final_decision == "BLOCK":
rationale = "Blocked because risk indicators show clear malicious or high-danger behavior."
elif final_decision == "FORK":
rationale = "Forked to human review because risk is high or approval is required before execution."
else:
rationale = "Quarantined until missing evidence is provided and risk can be reduced."
return {
"domain": domain or "unknown",
"risk_signals": list(risk_signals or []),
"safe_signals": list(safe_signals or []),
"cumulative_risk_score": round(float(cumulative_risk_score or 0.0), 3),
"memory_signals": list(dict.fromkeys(str(item) for item in memory_signals)),
"missing_evidence": list(missing_evidence or []),
"evidence_steps": evidence_steps,
"final_decision": final_decision,
"safety_rationale": f"{rationale} Safe outcome: {safe_outcome}",
}
def build_q_aware_decision(
domain: str,
intent: str,
raw_payload: str,
risk_vector: list[float],
*,
actor: str = "unknown",
session_id: str = "default",
service: str = "",
environment: str = "production",
provided_evidence: Optional[list[str]] = None,
timestamp: Any = 0,
memory_context: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
from domain_policies import evaluate_domain_policy
from evidence_planner import build_evidence_plan, get_missing_evidence, get_required_evidence, explain_evidence_gap
from risk_accumulator import clamp, compute_cumulative_risk
from safe_outcome import generate_safe_outcome, generate_structured_safe_outcome
sample = build_demo_policy_sample(
domain,
intent,
raw_payload,
risk_vector,
actor=actor,
session_id=session_id,
service=service,
environment=environment,
provided_evidence=provided_evidence,
)
policy_features = evaluate_domain_policy(sample, memory_context=memory_context)
sample["policy_domain"] = policy_features["domain"]
sample["risk_indicators"] = policy_features["risk_indicators"]
sample["safe_indicators"] = policy_features["safe_indicators"]
required_evidence = get_required_evidence(
policy_features["domain"],
policy_features["risk_indicators"],
policy_features["recommended_decision_hint"],
environment,
)
missing_evidence = get_missing_evidence(required_evidence, provided_evidence or [])
risk_data = compute_cumulative_risk(
sample,
memory_context=memory_context,
base_risk=max(float(sample["risk_score"]), float(policy_features["base_risk"])),
)
cumulative_risk = float(risk_data["cumulative_risk_score"])
missing_ratio = len(missing_evidence) / max(len(required_evidence), 1)
safe_score = float(policy_features["safe_evidence_score"])
q_aware_hint = q_aware_policy_action(sample)
policy_hint = policy_features["recommended_decision_hint"]
uncertainty = clamp(0.18 + 0.42 * missing_ratio + (0.12 if actor in {"", "unknown"} else 0.0) - safe_score * 0.25)
high_danger = any(
indicator in policy_features["risk_indicators"]
for indicator in (
"ci_secret_access",
"external_fetch",
"data_export",
"disable_encryption",
"admin_privilege",
"public_exposure",
"policy_override",
)
)
if safe_score >= 0.70 and missing_ratio <= 0.05 and cumulative_risk <= 0.90:
decision = "ALLOW"
elif safe_score >= 0.54 and missing_ratio <= 0.15 and cumulative_risk <= 0.75:
decision = "ALLOW"
elif missing_ratio > 0.45 and safe_score < 0.54 and policy_hint == "ALLOW":
decision = "QUARANTINE"
elif policy_hint == "ALLOW" and missing_ratio > 0.25 and cumulative_risk >= 0.25:
decision = "QUARANTINE"
elif cumulative_risk <= 0.25:
decision = "ALLOW"
elif cumulative_risk >= 0.86 and high_danger and safe_score < 0.20 and uncertainty < 0.65:
decision = "BLOCK"
elif cumulative_risk >= 0.70 and safe_score >= 0.25:
decision = "FORK"
elif cumulative_risk >= 0.45 and uncertainty >= 0.45:
decision = "QUARANTINE"
elif cumulative_risk >= 0.62:
decision = "FORK"
else:
decision = policy_hint if policy_hint in VALID_ACTION_SET else q_aware_hint
if q_aware_hint == "QUARANTINE" and decision in {"ALLOW", "BLOCK"}:
decision = "QUARANTINE"
if q_aware_hint == "FORK" and decision == "QUARANTINE" and cumulative_risk >= 0.60:
decision = "FORK"
if q_aware_hint == "FORK" and decision == "BLOCK" and uncertainty >= 0.35 and policy_hint != "BLOCK":
decision = "FORK"
if memory_context and memory_context.get("risky_chains") and decision == "FORK" and cumulative_risk >= 0.75 and safe_score < 0.25:
decision = "QUARANTINE"
recent_indicators = set(memory_context.get("recent_indicators", [])) if memory_context else set()
if (
decision == "FORK"
and "admin_privilege" in policy_features["risk_indicators"]
and recent_indicators.intersection({"public_bucket", "data_export", "external_destination", "external_transfer"})
):
decision = "QUARANTINE"
confidence = clamp(1.0 - uncertainty + (0.10 if high_danger and safe_score < 0.20 else 0.0))
evidence_plan = build_evidence_plan(
policy_features["domain"],
policy_features["risk_indicators"],
decision,
environment,
missing_evidence,
memory_context=memory_context,
risk_score=cumulative_risk,
)
safe_outcome = generate_safe_outcome(
decision,
policy_features["domain"],
policy_features["risk_indicators"],
policy_features["safe_indicators"],
missing_evidence,
environment,
)
structured_safe_outcome = generate_structured_safe_outcome(
decision,
policy_features["domain"],
policy_features["risk_indicators"],
policy_features["safe_indicators"],
missing_evidence,
environment,
evidence_plan=evidence_plan,
)
evidence_gap = explain_evidence_gap(missing_evidence)
explanation = (
f"{policy_features['policy_reason']}. "
f"cumulative_risk={cumulative_risk:.2f}; uncertainty={uncertainty:.2f}. "
f"{evidence_gap}"
)
decision_trace = build_decision_trace(
domain=policy_features["domain"],
risk_signals=policy_features["risk_indicators"],
safe_signals=policy_features["safe_indicators"],
cumulative_risk_score=cumulative_risk,
memory_context=memory_context,
missing_evidence=missing_evidence,
evidence_plan=evidence_plan,
final_decision=decision,
safe_outcome=safe_outcome,
)
return {
"decision": decision,
"confidence": round(confidence, 3),
"uncertainty": round(uncertainty, 3),
"risk_score": round(float(sample["risk_score"]), 3),
"cumulative_risk_score": round(cumulative_risk, 3),
"cumulative_risk_reason": risk_data["cumulative_risk_reason"],
"missing_evidence": missing_evidence,
"required_evidence": required_evidence,
"explanation": explanation,
"safe_outcome": safe_outcome,
"evidence_plan": evidence_plan,
"structured_safe_outcome": structured_safe_outcome,
"decision_trace": decision_trace,
"memory_context": memory_context or {},
"policy_name": "q_aware_demo_policy",
"domain": policy_features["domain"],
"mitre_tactic": policy_features["mitre_tactic"],
"mitre_technique": policy_features["mitre_technique"],
"risk_indicators": policy_features["risk_indicators"],
"safe_indicators": policy_features["safe_indicators"],
"actor": actor or "unknown",
"session_id": session_id or "default",
"service": service or domain or "unknown",
"environment": environment or "production",
"timestamp": timestamp,
}
def build_sample_record(index: int, split: str, scenario_generator: ScenarioGenerator) -> dict[str, Any]:
scenario = scenario_generator.sample()
risk_vector = extract_features(
scenario["domain"],
scenario["intent"],
scenario["raw_payload"],
scenario_generator.rng,
)
risk_score = compute_risk_score(risk_vector)
severity = classify_severity(risk_score, scenario["tier"], scenario["raw_payload"])
observation = build_observation(
scenario["domain"],
scenario["intent"],
scenario["raw_payload"],
risk_vector,
severity,
)
prompt = build_action_prompt(observation)
legacy_prompt = build_llama_prompt(
scenario["domain"],
scenario["intent"],
scenario["raw_payload"],
risk_vector,
False,
0,
)
sample = {
"sample_id": f"{split}-{index:05d}",
"split": split,
"domain": scenario["domain"],
"intent": scenario["intent"],
"raw_payload": scenario["raw_payload"],
"legacy_prompt": legacy_prompt,
"observation": observation,
"prompt": prompt,
"completion": "",
"correct_action": "",
"scenario_type": scenario["tier"],
"is_malicious": bool(scenario["is_malicious"]),
"severity": severity,
"risk_score": round(float(risk_score), 6),
"risk_vector": [round(float(value), 6) for value in risk_vector[:16]],
"source_policy": "tier_aware_teacher_v2",
}
sample["correct_action"] = build_teacher_action(sample)
sample["completion"] = sample["correct_action"]
sample["text"] = sample["prompt"] + sample["completion"]
return sample
def generate_dataset_split(
sample_count: int,
seed: int,
split: str,
forbidden_prompts: Optional[set[str]] = None,
) -> tuple[list[dict[str, Any]], int]:
scenario_generator = ScenarioGenerator(seed=seed)
samples: list[dict[str, Any]] = []
seen_prompts = set(forbidden_prompts or set())
duplicate_prompts = 0
attempts = 0
max_attempts = max(200, sample_count * 40)
while len(samples) < sample_count and attempts < max_attempts:
attempts += 1
sample = build_sample_record(len(samples), split, scenario_generator)
prompt = sample["prompt"]
if prompt in seen_prompts:
duplicate_prompts += 1
continue
seen_prompts.add(prompt)
samples.append(sample)
if len(samples) < sample_count:
raise RuntimeError(
f"Could not generate {sample_count} unique {split} samples after {attempts} attempts."
)
return samples, duplicate_prompts
def generate_datasets(
train_size: int = DEFAULT_TRAIN_SIZE,
val_size: int = DEFAULT_VAL_SIZE,
train_seed: int = DEFAULT_TRAIN_SEED,
val_seed: int = DEFAULT_VAL_SEED,
save: bool = True,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], dict[str, Any]]:
ensure_dirs()
train_samples, train_duplicates = generate_dataset_split(train_size, train_seed, "train")
train_prompts = {sample["prompt"] for sample in train_samples}
val_samples, val_duplicates = generate_dataset_split(
val_size,
val_seed,
"val",
forbidden_prompts=train_prompts,
)
audit = audit_datasets(
train_samples,
val_samples,
duplicate_prompt_count=train_duplicates + val_duplicates,
)
if save:
write_json(DEFAULT_TRAIN_DATASET_PATH, train_samples)
write_json(DEFAULT_VAL_DATASET_PATH, val_samples)
write_json(DEFAULT_DATASET_AUDIT_PATH, audit)
return train_samples, val_samples, audit
def audit_datasets(
train_samples: list[dict[str, Any]],
val_samples: list[dict[str, Any]],
duplicate_prompt_count: int = 0,
) -> dict[str, Any]:
all_samples = train_samples + val_samples
missing_label_count = sum(
1 for sample in all_samples if not (sample.get("correct_action") or sample.get("expected_decision"))
)
invalid_action_label_count = sum(
1
for sample in all_samples
if (sample.get("correct_action") or sample.get("expected_decision")) is not None
and (sample.get("correct_action") or sample.get("expected_decision")) not in VALID_ACTION_SET
)
train_counts = Counter(sample.get("correct_action") for sample in train_samples if sample.get("correct_action") in VALID_ACTION_SET)
val_counts = Counter(sample.get("correct_action") for sample in val_samples if sample.get("correct_action") in VALID_ACTION_SET)
combined_counts = Counter(train_counts)
combined_counts.update(val_counts)
train_mal = Counter("malicious" if sample["is_malicious"] else "benign" for sample in train_samples)
val_mal = Counter("malicious" if sample["is_malicious"] else "benign" for sample in val_samples)
scenario_counts = Counter(sample["scenario_type"] for sample in all_samples)
domain_counts = Counter(sample.get("domain", "unknown") for sample in all_samples)
risk_type_counts = Counter(sample.get("severity", "UNKNOWN") for sample in all_samples)
false_positive_challenge_count = sum(
1
for sample in train_samples + val_samples
if not sample.get("is_malicious", False)
and (
sample.get("scenario_type") == "BENIGN_NOISY"
or sample.get("severity") in {"MEDIUM", "HIGH"}
or bool(sample.get("provided_evidence"))
)
)
hard_negative_count = 0
if DEFAULT_HARD_NEGATIVE_PATH.exists():
with contextlib.suppress(Exception):
hard_negative_count = len(read_json(DEFAULT_HARD_NEGATIVE_PATH, default=[]))
train_prompts = {sample["prompt"] for sample in train_samples if sample.get("prompt")}
val_prompts = {sample["prompt"] for sample in val_samples if sample.get("prompt")}
overlap_count = len(train_prompts & val_prompts)
train_distribution = summarize_action_distribution(train_counts, len(train_samples))
val_distribution = summarize_action_distribution(val_counts, len(val_samples))
combined_distribution = summarize_action_distribution(combined_counts, len(train_samples) + len(val_samples))
examples_per_action = {
action: [
{
"sample_id": sample["sample_id"],
"scenario_type": sample["scenario_type"],
"severity": sample["severity"],
"payload": sample["raw_payload"][:160],
}
for sample in train_samples + val_samples
if sample["correct_action"] == action
][:3]
for action in VALID_ACTIONS
}
failures: list[str] = []
if overlap_count > 0:
failures.append("train/val prompt overlap > 0")
if missing_label_count > 0:
failures.append("missing action label")
if invalid_action_label_count > 0:
failures.append("invalid action label")
for action, share in train_distribution.items():
if share < 0.05:
failures.append(f"{action} share below 5% in train split")
if train_distribution.get("ALLOW", 0.0) > 0.60:
failures.append("ALLOW share above 60% in train split")
min_val_tail = max(3, math.ceil(len(val_samples) * 0.05))
if val_counts.get("BLOCK", 0) < min_val_tail:
failures.append("validation BLOCK coverage too low")
if val_counts.get("QUARANTINE", 0) < min_val_tail:
failures.append("validation QUARANTINE coverage too low")
audit = {
"train_sample_count": len(train_samples),
"val_sample_count": len(val_samples),
"action_distribution": {
"train": dict(train_counts),
"val": dict(val_counts),
"train_pct": train_distribution,
"val_pct": val_distribution,
"combined_pct": combined_distribution,
},
"benign_malicious_distribution": {
"train": dict(train_mal),
"val": dict(val_mal),
},
"scenario_distribution": dict(scenario_counts),
"risk_type_distribution": dict(risk_type_counts),
"domain_distribution": dict(domain_counts),
"hard_negative_count": hard_negative_count,
"false_positive_challenge_count": false_positive_challenge_count,
"missing_label_count": missing_label_count,
"invalid_action_label_count": invalid_action_label_count,
"duplicate_prompt_count": duplicate_prompt_count,
"train_val_overlap_count": overlap_count,
"examples_per_action": examples_per_action,
"passed_preflight": not failures,
"failures": failures,
}
return audit
def preflight_dataset_check(audit: dict[str, Any]) -> None:
if not audit.get("passed_preflight", False):
raise TrainingPreflightError("; ".join(audit.get("failures", ["dataset audit failed"])))
def is_false_positive(sample: dict[str, Any], predicted_action: Optional[str]) -> bool:
return (not sample["is_malicious"]) and predicted_action == "BLOCK"
def is_false_negative(sample: dict[str, Any], predicted_action: Optional[str]) -> bool:
return bool(sample.get("is_malicious", False)) and predicted_action == "ALLOW"
def is_unsafe_decision(sample: dict[str, Any], predicted_action: Optional[str]) -> bool:
if predicted_action is None:
return True
if predicted_action == (sample.get("correct_action") or sample.get("expected_decision")):
return False
high_risk = sample.get("severity") in {"HIGH", "CRITICAL"} or sample.get("scenario_type") == "MALICIOUS_HARD"
return bool(sample.get("is_malicious", False) or high_risk) and predicted_action == "ALLOW"
def is_safe_action(sample: dict[str, Any], predicted_action: Optional[str]) -> bool:
if predicted_action is None:
return False
if predicted_action == (sample.get("correct_action") or sample.get("expected_decision")):
return True
if is_unsafe_decision(sample, predicted_action):
return False
if sample.get("is_malicious", False):
return predicted_action in {"BLOCK", "FORK", "QUARANTINE"}
return predicted_action in {"ALLOW", "FORK", "QUARANTINE"}
def _decision_payload(output: Any) -> dict[str, Any]:
return output if isinstance(output, dict) else {}
def _output_action_text(output: Any) -> str:
if isinstance(output, dict):
return str(output.get("decision") or output.get("action_taken") or output.get("action") or "")
return str(output)
def _clearly_safe_sample(sample: dict[str, Any]) -> bool:
if sample.get("is_malicious", False):
return False
trusted_text = " ".join(
str(item).lower()
for item in (
list(sample.get("provided_evidence", []))
+ list(sample.get("safe_indicators", []))
+ [sample.get("actor", ""), sample.get("environment", "")]
)
)
trusted_evidence = any(
cue in trusted_text
for cue in (
"approved ticket",
"approval ticket",
"approved change ticket",
"authorization letter",
"approved test window",
"trusted maintainer",
"known automation",
"break-glass approval",
"break glass approval",
"classification public",
"staging",
"dev",
)
)
return (
sample.get("severity", "LOW") == "LOW"
and sample.get("scenario_type", "").startswith("BENIGN")
) or trusted_evidence
def _sample_text(sample: dict[str, Any]) -> str:
return " ".join(
str(value).lower()
for value in (
sample.get("raw_payload", ""),
sample.get("action_summary", ""),
" ".join(sample.get("indicators", [])),
" ".join(sample.get("provided_evidence", [])),
)
)
def _missing_critical_evidence(sample: dict[str, Any]) -> bool:
required = {str(item).lower() for item in sample.get("required_evidence", [])}
provided_text = " ".join(str(item).lower() for item in sample.get("provided_evidence", []))
critical_terms = (
"approval",
"authorization",
"identity",
"mfa",
"secret",
"privilege",
"public access",
"data classification",
)
return any(term in evidence for evidence in required for term in critical_terms) and not any(
term in provided_text for term in critical_terms
)
def _reward_margin(sample: dict[str, Any], predicted_action: str) -> float:
risk = float(sample.get("risk_score", 0.0) or 0.0)
severity = str(sample.get("severity", "LOW")).upper()
scenario_type = str(sample.get("scenario_type", ""))
margin = 0.0
if severity == "CRITICAL":
margin += 0.18
elif severity == "HIGH":
margin += 0.12
elif severity == "MEDIUM":
margin += 0.06
margin += min(0.12, max(0.0, risk) * 0.10)
if scenario_type == "MALICIOUS_HARD" and predicted_action in {"BLOCK", "FORK", "QUARANTINE"}:
margin += 0.12
if scenario_type == "AMBIGUOUS" and predicted_action in {"FORK", "QUARANTINE"}:
margin += 0.08
if _clearly_safe_sample(sample) and predicted_action == "ALLOW":
margin += 0.08
return round(margin, 6)
def _risk_alignment_adjustment(sample: dict[str, Any], predicted_action: str) -> float:
"""Small continuous term to reduce reward plateaus across similar samples."""
if predicted_action not in VALID_ACTION_SET:
return 0.0
risk = max(0.0, min(1.0, float(sample.get("risk_score", 0.0) or 0.0)))
targets = {
"ALLOW": 0.10,
"FORK": 0.45,
"QUARANTINE": 0.62,
"BLOCK": 0.82,
}
tolerance = {
"ALLOW": 0.30,
"FORK": 0.28,
"QUARANTINE": 0.24,
"BLOCK": 0.20,
}
target = targets[predicted_action]
tol = tolerance[predicted_action]
distance = abs(risk - target)
within = max(0.0, 1.0 - (distance / max(tol, 1e-6)))
bonus = 0.12 * within
overshoot = max(0.0, distance - tol)
penalty = -0.08 * min(1.0, overshoot / max(1.0 - tol, 1e-6))
return round(bonus + penalty, 6)
def compute_shaped_reward(
sample: dict[str, Any],
predicted_action: Optional[str],
decision_payload: Optional[dict[str, Any]] = None,
) -> tuple[float, str]:
decision_payload = decision_payload or {}
correct_action = sample.get("correct_action") or sample.get("expected_decision")
high_risk = sample.get("severity") in {"HIGH", "CRITICAL"} or sample.get("scenario_type") == "MALICIOUS_HARD"
critical = sample.get("severity") == "CRITICAL" or sample.get("risk_level") == "critical"
clearly_safe = _clearly_safe_sample(sample)
if predicted_action not in VALID_ACTION_SET:
return -2.5, "invalid_output"
if predicted_action == correct_action:
reward = 1.5 if predicted_action in {"QUARANTINE", "FORK"} else 2.0
category = "exact_conservative" if predicted_action in {"QUARANTINE", "FORK"} else "exact_correct"
reward += _reward_margin(sample, predicted_action)
elif is_unsafe_decision(sample, predicted_action):
reward = -5.0 if critical else -3.0
category = "critical_unsafe_allow" if critical else "false_negative_allow"
elif clearly_safe and predicted_action == "BLOCK":
reward = -0.7
category = "unnecessary_block"
elif clearly_safe and predicted_action == "QUARANTINE":
reward = -0.3
category = "unnecessary_quarantine"
elif clearly_safe and predicted_action == "FORK":
reward = -0.4
category = "false_positive_fork"
elif sample.get("is_malicious", False) or high_risk:
if predicted_action in {"BLOCK", "FORK", "QUARANTINE"}:
reward = 0.8 + _reward_margin(sample, predicted_action) * 0.5
category = "safe_conservative"
else:
reward = -0.5
category = "minor_wrong"
else:
reward = -0.5
category = "minor_wrong"
reward += _risk_alignment_adjustment(sample, predicted_action)
text = _sample_text(sample)
if any(
cue in text
for cue in (
"ignore previous",
"ignore security",
"system override",
"curl http://",
"| bash",
"public-read",
"export all secrets",
"administratoraccess",
"0.0.0.0/0",
"open port 22",
"workflow",
)
) and predicted_action == "ALLOW":
reward -= 0.8
category = "unsafe_allow_attack_cue"
if _missing_critical_evidence(sample) and predicted_action == "ALLOW":
reward -= 0.6
category = "allow_missing_critical_evidence"
if decision_payload.get("multi_action_warning"):
reward -= 0.2
category = "ambiguous_format"
explanation = str(decision_payload.get("explanation", "")).lower()
required_evidence = [str(item).lower() for item in sample.get("required_evidence", [])]
if decision_payload:
if required_evidence and any(item in explanation for item in required_evidence[:3]):
reward += 0.3
if decision_payload.get("missing_evidence") and "missing" not in explanation:
reward -= 0.3
expected_tactic = str(sample.get("mitre_tactic", "")).lower()
expected_technique = str(sample.get("mitre_technique", "")).lower()
actual_tactic = str(decision_payload.get("mitre_tactic", "")).lower()
actual_technique = str(decision_payload.get("mitre_technique", "")).lower()
if expected_tactic and actual_tactic and expected_tactic != actual_tactic:
reward -= 0.2
elif expected_technique and actual_technique and expected_technique.split()[0] != actual_technique.split()[0]:
reward -= 0.2
return round(float(reward), 6), category
def evaluate_outputs(
samples: list[dict[str, Any]],
outputs: list[Any],
label: str,
) -> dict[str, Any]:
parser_stats = ParseDiagnostics()
action_counts: Counter = Counter()
reward_categories: Counter = Counter()
parsed_actions: list[Optional[str]] = []
rewards: list[float] = []
completion_lengths: list[int] = []
confusion_matrix = {
expected: {predicted: 0 for predicted in (*VALID_ACTIONS, "INVALID")}
for expected in VALID_ACTIONS
}
exact = 0
safe = 0
valid = 0
unsafe = 0
false_positive_count = 0
false_negative_count = 0
for sample, output in zip(samples, outputs):
output_text = _output_action_text(output)
completion_lengths.append(len(str(output_text).split()))
analysis = analyze_action_output(output_text)
action = analysis["parsed_action"]
payload = dict(_decision_payload(output))
payload["multi_action_warning"] = analysis["multi_action_warning"]
parsed_actions.append(action)
if analysis["invalid_output"]:
parser_stats.invalid_outputs += 1
if analysis["multi_action_warning"]:
parser_stats.multi_action_warnings += 1
if action is not None:
action_counts[action] += 1
valid += 1
reward, category = compute_shaped_reward(sample, action, payload)
rewards.append(reward)
reward_categories[category] += 1
if action == (sample.get("correct_action") or sample.get("expected_decision")):
exact += 1
expected_action = sample.get("correct_action") or sample.get("expected_decision")
if expected_action in confusion_matrix:
confusion_matrix[expected_action][action or "INVALID"] += 1
if is_safe_action(sample, action):
safe += 1
else:
unsafe += 1
if is_false_positive(sample, action):
false_positive_count += 1
if is_false_negative(sample, action):
false_negative_count += 1
total = len(samples)
benign_total = sum(1 for sample in samples if not sample.get("is_malicious", False))
malicious_total = sum(1 for sample in samples if sample.get("is_malicious", False))
precision_by_action = {}
per_action_accuracy = {}
for action_name in VALID_ACTIONS:
predicted_total = action_counts.get(action_name, 0)
expected_total = sum(
1 for sample in samples if (sample.get("correct_action") or sample.get("expected_decision")) == action_name
)
correct_for_action = sum(
1
for sample, parsed_action in zip(samples, parsed_actions)
if parsed_action == action_name
and parsed_action == (sample.get("correct_action") or sample.get("expected_decision"))
)
precision_by_action[action_name.lower()] = correct_for_action / max(predicted_total, 1)
per_action_accuracy[action_name] = correct_for_action / max(expected_total, 1)
metrics = {
"label": label,
"sample_count": total,
"exact_match": exact / max(total, 1),
"safety_accuracy": safe / max(total, 1),
"valid_action_rate": valid / max(total, 1),
"invalid_action_rate": parser_stats.invalid_outputs / max(total, 1),
"invalid_output_rate": parser_stats.invalid_outputs / max(total, 1),
"parse_failure_rate": parser_stats.invalid_outputs / max(total, 1),
"unsafe_decision_rate": unsafe / max(total, 1),
"false_positive_rate": false_positive_count / max(benign_total, 1),
"false_negative_rate": false_negative_count / max(malicious_total, 1),
"reward_mean": safe_mean(rewards),
"reward_std": safe_std(rewards),
"allow_precision": precision_by_action["allow"],
"block_precision": precision_by_action["block"],
"fork_precision": precision_by_action["fork"],
"quarantine_precision": precision_by_action["quarantine"],
"per_action_accuracy": per_action_accuracy,
"confusion_matrix": confusion_matrix,
"avg_completion_length": safe_mean(completion_lengths),
"action_distribution": summarize_action_distribution(action_counts, total),
"normalized_action_distribution": summarize_action_distribution(action_counts, total),
"invalid_output_count": parser_stats.invalid_outputs,
"multi_action_warnings": parser_stats.multi_action_warnings,
"multi_action_warning_rate": parser_stats.multi_action_warnings / max(total, 1),
"entropy": distribution_entropy(summarize_action_distribution(action_counts, max(total, 1))),
"reward_breakdown": dict(reward_categories),
"predicted_actions": [action if action is not None else "INVALID" for action in parsed_actions],
"sample_rewards": rewards,
}
return metrics
def evaluate_policy_on_dataset(
samples: list[dict[str, Any]],
policy_name: str,
seed: int = 0,
) -> dict[str, Any]:
rng = random.Random(seed)
outputs: list[str] = []
for sample in samples:
if policy_name == "random":
outputs.append(random_policy_action(sample, rng))
elif policy_name == "heuristic":
outputs.append(heuristic_policy_action(sample))
elif policy_name == "q_aware":
if sample.get("use_agent_policy"):
outputs.append(
build_q_aware_decision(
sample["domain"],
sample.get("intent", sample.get("action_summary", "")),
sample.get("raw_payload", sample.get("action_summary", "")),
sample.get("risk_vector", [0.0] * 16),
actor=sample.get("actor", "unknown"),
session_id=sample.get("session_id", "default"),
service=sample.get("service", sample.get("domain", "unknown")),
environment=sample.get("environment", "production"),
provided_evidence=sample.get("provided_evidence", []),
)
)
else:
outputs.append(q_aware_policy_action(sample))
else:
raise ValueError(f"Unknown policy_name: {policy_name}")
return evaluate_outputs(samples, outputs, label=policy_name)
def evaluate_oracle(samples: list[dict[str, Any]]) -> dict[str, Any]:
outputs: list[str] = []
oracle_rewards: list[float] = []
for sample in samples:
oracle_payload = {
"explanation": "Required evidence reviewed: " + ", ".join(sample.get("required_evidence", [])),
"mitre_tactic": sample.get("mitre_tactic", ""),
"mitre_technique": sample.get("mitre_technique", ""),
}
reward_by_action = {
action: compute_shaped_reward(sample, action, oracle_payload)[0]
for action in VALID_ACTIONS
}
best_reward = max(reward_by_action.values())
best_actions = [action for action, reward in reward_by_action.items() if reward == best_reward]
chosen_action = best_actions[0]
outputs.append(chosen_action)
oracle_rewards.append(best_reward)
metrics = evaluate_outputs(samples, outputs, label="oracle")
metrics["sample_rewards"] = oracle_rewards
metrics["reward_mean"] = safe_mean(oracle_rewards)
metrics["reward_std"] = safe_std(oracle_rewards)
return metrics
def check_oracle_consistency(
samples: list[dict[str, Any]],
metrics_by_label: dict[str, dict[str, Any]],
output_path: Path = DEFAULT_ORACLE_INCONSISTENCY_PATH,
) -> dict[str, Any]:
oracle_metrics = metrics_by_label["oracle"]
oracle_rewards = oracle_metrics["sample_rewards"]
inconsistencies: list[dict[str, Any]] = []
for label, metrics in metrics_by_label.items():
if label == "oracle":
continue
if metrics["reward_mean"] > oracle_metrics["reward_mean"] + 1e-9:
inconsistencies.append(
{
"label": label,
"issue": "mean reward exceeded oracle",
"mean_reward": metrics["reward_mean"],
"oracle_mean_reward": oracle_metrics["reward_mean"],
}
)
sample_rewards = metrics.get("sample_rewards", [])
for index, reward in enumerate(sample_rewards):
if reward > oracle_rewards[index] + 1e-9:
sample = samples[index]
inconsistencies.append(
{
"label": label,
"sample_id": sample["sample_id"],
"reward": reward,
"oracle_reward": oracle_rewards[index],
"correct_action": sample["correct_action"],
"predicted_action": metrics["predicted_actions"][index],
"raw_payload": sample["raw_payload"],
}
)
write_json(output_path, inconsistencies)
try:
relative_output_path = str(output_path.relative_to(BACKEND_DIR))
except ValueError:
relative_output_path = str(output_path)
return {
"passed": not inconsistencies,
"oracle_reward_mean": oracle_metrics["reward_mean"],
"inconsistency_count": len(inconsistencies),
"output_path": relative_output_path,
}
def check_reward_variance(samples: list[dict[str, Any]]) -> dict[str, Any]:
candidate_actions = ["ALLOW", "BLOCK", "FORK", "QUARANTINE", "INVALID"]
group_stds: list[float] = []
groups_with_variance = 0
for sample in samples:
rewards = []
for action in candidate_actions:
predicted_action = None if action == "INVALID" else action
rewards.append(compute_shaped_reward(sample, predicted_action)[0])
reward_std = safe_std(rewards)
group_stds.append(reward_std)
if reward_std > 0:
groups_with_variance += 1
groups_checked = len(samples)
zero_std_fraction = (
sum(1 for value in group_stds if value <= 1e-12) / max(groups_checked, 1)
)
result = {
"groups_checked": groups_checked,
"groups_with_variance": groups_with_variance,
"groups_with_variance_pct": groups_with_variance / max(groups_checked, 1),
"mean_reward_std": safe_mean(group_stds),
"frac_reward_zero_std": zero_std_fraction,
"passed": (
(groups_with_variance / max(groups_checked, 1)) >= 0.80
and zero_std_fraction <= 0.30
),
}
return result
def run_reward_diagnostics(samples: Optional[list[dict[str, Any]]] = None) -> dict[str, Any]:
if samples is None:
val_samples, _ = load_validation_samples_for_benchmark()
samples = val_samples
candidate_outputs = [
"ALLOW",
"allow.",
'{"action": "ALLOW"}',
"Action: ALLOW because evidence is complete.",
"BLOCK",
"deny",
'{"decision": "BLOCK"}',
"Recommended action is BLOCK.",
"FORK",
"human review",
'{"action": "FORK", "explanation": "approval required"}',
"I recommend FORK until the owner approves.",
"QUARANTINE",
"isolate",
'{"decision": "QUARANTINE"}',
"I recommend isolate until evidence arrives.",
"approve",
"reject",
"review",
"hold",
"decision=quarantine",
"supervisor decision: fork",
"Action: block.",
"hello",
]
group_stds: list[float] = []
invalid_outputs = 0
total_outputs = 0
action_counts: Counter = Counter()
sample_summaries: list[dict[str, Any]] = []
same_reward_examples: list[dict[str, Any]] = []
all_rewards: list[float] = []
for sample in samples:
rewards_by_output: dict[str, float] = {}
for output in candidate_outputs:
analysis = analyze_action_output(output)
action = analysis["parsed_action"]
if action is None:
invalid_outputs += 1
else:
action_counts[action] += 1
total_outputs += 1
reward, _ = compute_shaped_reward(
sample,
action,
{"multi_action_warning": analysis["multi_action_warning"]},
)
rewards_by_output[output] = reward
all_rewards.append(reward)
values = list(rewards_by_output.values())
reward_std = safe_std(values)
group_stds.append(reward_std)
summary = {
"sample_id": sample.get("sample_id"),
"correct_action": sample.get("correct_action") or sample.get("expected_decision"),
"scenario_type": sample.get("scenario_type"),
"severity": sample.get("severity"),
"reward_std": reward_std,
"reward_range": max(values) - min(values),
"payload": str(sample.get("raw_payload", sample.get("action_summary", "")))[:140],
"rewards": rewards_by_output,
}
sample_summaries.append(summary)
if reward_std <= 1e-12:
same_reward_examples.append(summary)
easiest = sorted(sample_summaries, key=lambda row: row["reward_range"], reverse=True)[:5]
hardest = sorted(sample_summaries, key=lambda row: row["reward_range"])[:5]
diagnostics = {
"sample_count": len(samples),
"reward_mean": safe_mean(all_rewards),
"reward_std": safe_std(all_rewards),
"percent_zero_std_groups": 100.0 * sum(1 for value in group_stds if value <= 1e-12) / max(len(group_stds), 1),
"frac_reward_zero_std": sum(1 for value in group_stds if value <= 1e-12) / max(len(group_stds), 1),
"invalid_output_rate": invalid_outputs / max(total_outputs, 1),
"action_distribution": summarize_action_distribution(action_counts, max(total_outputs - invalid_outputs, 1)),
"easiest_samples": easiest,
"hardest_samples": hardest,
"same_reward_examples": same_reward_examples[:5],
}
return diagnostics
def print_reward_diagnostics(diagnostics: dict[str, Any]) -> None:
print("Reward diagnostics")
print("------------------")
print(f"samples: {diagnostics['sample_count']}")
print(f"reward mean/std: {diagnostics['reward_mean']:.3f} / {diagnostics['reward_std']:.3f}")
print(f"zero-std groups: {diagnostics['percent_zero_std_groups']:.1f}%")
print(f"invalid output rate: {diagnostics['invalid_output_rate']:.3f}")
print(f"action distribution: {diagnostics['action_distribution']}")
print("easiest samples:")
for row in diagnostics["easiest_samples"][:3]:
print(f" {row['sample_id']} {row['correct_action']} range={row['reward_range']:.3f}")
print("hardest samples:")
for row in diagnostics["hardest_samples"][:3]:
print(f" {row['sample_id']} {row['correct_action']} range={row['reward_range']:.3f}")
if diagnostics["same_reward_examples"]:
print("same-reward examples:")
for row in diagnostics["same_reward_examples"][:3]:
print(f" {row['sample_id']} {row['correct_action']} {row['payload']}")
else:
print("same-reward examples: none")
def metric_delta(model_metrics: dict[str, Any], baseline_metrics: dict[str, Any]) -> dict[str, float]:
keys = (
"exact_match",
"safety_accuracy",
"unsafe_decision_rate",
"false_positive_rate",
"false_negative_rate",
"reward_mean",
"invalid_output_rate",
)
return {
key: float(model_metrics.get(key, 0.0) or 0.0) - float(baseline_metrics.get(key, 0.0) or 0.0)
for key in keys
}
def evaluate_training_gate(
model_metrics: Optional[dict[str, Any]],
q_aware_metrics: dict[str, Any],
*,
reference_metrics: Optional[dict[str, Any]] = None,
min_safety: float = 0.98,
max_unsafe: float = 0.02,
max_invalid: float = 0.05,
) -> dict[str, Any]:
if model_metrics is None:
return {
"training_gate_status": "FAIL",
"training_gate_passed": False,
"reason": "No model metrics are available; checkpoint was not loaded or evaluation failed.",
"recommended_next_action": "Run --evaluate-model with a valid --model-path after SFT/GRPO smoke training.",
}
def _metric(metrics: dict[str, Any], key: str, default: float) -> float:
value = metrics.get(key, default)
return default if value is None else float(value)
safety = _metric(model_metrics, "safety_accuracy", 0.0)
unsafe = _metric(model_metrics, "unsafe_decision_rate", 1.0)
invalid = _metric(
model_metrics,
"invalid_output_rate",
_metric(model_metrics, "invalid_action_rate", 1.0),
)
reward = _metric(model_metrics, "reward_mean", 0.0)
q_reward = _metric(q_aware_metrics, "reward_mean", 0.0)
if safety < min_safety - 0.08 or unsafe > max_unsafe + 0.08 or invalid > max_invalid:
return {
"training_gate_status": "FAIL",
"training_gate_passed": False,
"reason": (
f"Safety gate failed: safety={safety:.3f}, unsafe={unsafe:.3f}, "
f"invalid={invalid:.3f}."
),
"recommended_next_action": "Stop longer training; fix parsing/reward/data before spending more credits.",
}
if reference_metrics is None:
status = "WARN" if safety >= min_safety and unsafe <= max_unsafe else "FAIL"
return {
"training_gate_status": status,
"training_gate_passed": False,
"reason": "No raw/SFT reference metrics are available, so improvement is not proven.",
"recommended_next_action": "Evaluate raw or SFT checkpoint, then compare this checkpoint again.",
}
reference_reward = _metric(reference_metrics, "reward_mean", 0.0)
reference_unsafe = _metric(reference_metrics, "unsafe_decision_rate", 1.0)
improves_reference = reward > reference_reward + 1e-9 and unsafe <= reference_unsafe + 1e-9
if safety >= min_safety and unsafe <= max_unsafe and improves_reference:
if reward >= q_reward:
return {
"training_gate_status": "PASS",
"training_gate_passed": True,
"reason": "Model meets safety/unsafe thresholds and improves reward over reference metrics.",
"recommended_next_action": "Run a longer GRPO explore only if budget allows; keep comparing to Q-aware.",
}
return {
"training_gate_status": "WARN",
"training_gate_passed": False,
"reason": "Model improves over reference but remains below Q-aware policy reward.",
"recommended_next_action": "Continue small GRPO exploration or improve reward/data before final training.",
}
return {
"training_gate_status": "FAIL",
"training_gate_passed": False,
"reason": "Model did not prove reward/safety improvement over raw/SFT reference metrics.",
"recommended_next_action": "Do not claim training success; inspect reward diagnostics and checkpoint outputs.",
}
def run_parse_action_tests() -> dict[str, Any]:
tests = [
("ALLOW", "ALLOW"),
(" allow. ", "ALLOW"),
("<think>abc</think> BLOCK", "BLOCK"),
("Action: FORK", "FORK"),
("decision=quarantine", "QUARANTINE"),
("deny", "BLOCK"),
("human_review", "FORK"),
('{"action": "allow"}', "ALLOW"),
('{"decision": "human review", "explanation": "needs approval"}', "FORK"),
("I recommend isolate until evidence arrives.", "QUARANTINE"),
("hello", None),
("please allow this", None),
("ALLOW then BLOCK", None),
("QUARANTINE because...", "QUARANTINE"),
]
failures = []
for text, expected in tests:
actual = parse_action(text)
if actual != expected:
failures.append(
{
"input": text,
"expected": expected,
"actual": actual,
}
)
return {
"passed": not failures,
"test_count": len(tests),
"failures": failures,
}
def compact_metrics(metrics: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
if metrics is None:
return None
compact = dict(metrics)
compact.pop("predicted_actions", None)
compact.pop("sample_rewards", None)
return compact
def build_evaluation_bundle(
val_samples: list[dict[str, Any]],
raw_model_metrics: Optional[dict[str, Any]] = None,
sft_metrics: Optional[dict[str, Any]] = None,
grpo_metrics: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, dict[str, Any]], dict[str, Any]]:
metrics_by_label = {
"random": evaluate_policy_on_dataset(val_samples, "random", seed=1),
"heuristic": evaluate_policy_on_dataset(val_samples, "heuristic", seed=2),
"q_aware": evaluate_policy_on_dataset(val_samples, "q_aware", seed=3),
"oracle": evaluate_oracle(val_samples),
}
if raw_model_metrics is not None:
metrics_by_label["raw_model"] = raw_model_metrics
if sft_metrics is not None:
metrics_by_label["sft_model"] = sft_metrics
if grpo_metrics is not None:
metrics_by_label["sft_grpo_model"] = grpo_metrics
oracle_check = check_oracle_consistency(val_samples, metrics_by_label)
return metrics_by_label, oracle_check
def assert_dataset_file_counts_match_audit(
dataset_audit: dict[str, Any],
*,
train_samples: Optional[list[dict[str, Any]]] = None,
val_samples: Optional[list[dict[str, Any]]] = None,
) -> None:
expected_train = int(dataset_audit.get("train_sample_count", -1))
expected_val = int(dataset_audit.get("val_sample_count", -1))
if train_samples is not None and len(train_samples) != expected_train:
raise TrainingPreflightError(
f"Train dataset count {len(train_samples)} does not match dataset_audit.json {expected_train}."
)
if val_samples is not None and len(val_samples) != expected_val:
raise TrainingPreflightError(
f"Validation dataset count {len(val_samples)} does not match dataset_audit.json {expected_val}."
)
def load_validation_samples_for_benchmark() -> tuple[list[dict[str, Any]], dict[str, Any]]:
dataset_audit = read_json(DEFAULT_DATASET_AUDIT_PATH, default=None)
train_samples = read_json(DEFAULT_TRAIN_DATASET_PATH, default=None)
val_samples = read_json(DEFAULT_VAL_DATASET_PATH, default=None)
if dataset_audit is None or train_samples is None or val_samples is None:
train_samples, val_samples, dataset_audit = generate_datasets(save=True)
assert_dataset_file_counts_match_audit(
dataset_audit,
train_samples=train_samples,
val_samples=val_samples,
)
preflight_dataset_check(dataset_audit)
return val_samples, dataset_audit
def build_demo_benchmark_rows(metrics_by_label: dict[str, dict[str, Any]]) -> list[dict[str, Any]]:
policies = (
("Random", "random"),
("Heuristic", "heuristic"),
("Q-aware", "q_aware"),
("Oracle", "oracle"),
)
rows = []
for display_name, label in policies:
metrics = metrics_by_label[label]
row = {"policy": display_name}
for metric_name in DEMO_BENCHMARK_METRICS:
row[metric_name] = float(metrics[metric_name])
rows.append(row)
return rows
def format_demo_benchmark_table(rows: list[dict[str, Any]]) -> str:
columns = ("policy",) + DEMO_BENCHMARK_METRICS
widths = {
column: max(len(column), *(len(f"{row[column]:.3f}") if column != "policy" else len(row[column]) for row in rows))
for column in columns
}
header = " ".join(column.ljust(widths[column]) for column in columns)
divider = " ".join("-" * widths[column] for column in columns)
body = []
for row in rows:
cells = [row["policy"].ljust(widths["policy"])]
cells.extend(f"{row[column]:.3f}".rjust(widths[column]) for column in DEMO_BENCHMARK_METRICS)
body.append(" ".join(cells))
return "\n".join([header, divider, *body])
def write_demo_benchmark_reports(
*,
rows: list[dict[str, Any]],
dataset_audit: dict[str, Any],
oracle_check: dict[str, Any],
output_json: Path = DEFAULT_DEMO_BENCHMARK_JSON,
output_md: Path = DEFAULT_DEMO_BENCHMARK_MD,
) -> dict[str, Any]:
report = {
"dataset": {
"train_sample_count": dataset_audit["train_sample_count"],
"val_sample_count": dataset_audit["val_sample_count"],
"validation_source": str(DEFAULT_VAL_DATASET_PATH.relative_to(BACKEND_DIR)),
},
"metrics": rows,
"oracle_check": oracle_check,
}
write_json(output_json, report)
header = "| Policy | " + " | ".join(DEMO_BENCHMARK_METRICS) + " |"
divider = "| --- | " + " | ".join("---:" for _ in DEMO_BENCHMARK_METRICS) + " |"
md_lines = [
"# ShadowOps Demo Benchmark",
"",
f"Validation samples: {dataset_audit['val_sample_count']}",
"",
header,
divider,
]
for row in rows:
values = " | ".join(f"{row[metric]:.3f}" for metric in DEMO_BENCHMARK_METRICS)
md_lines.append(f"| {row['policy']} | {values} |")
md_lines.append("")
output_md.write_text("\n".join(md_lines), encoding="utf-8")
return report
def run_demo_benchmark(
output_json: Path = DEFAULT_DEMO_BENCHMARK_JSON,
output_md: Path = DEFAULT_DEMO_BENCHMARK_MD,
) -> dict[str, Any]:
val_samples, dataset_audit = load_validation_samples_for_benchmark()
metrics_by_label, oracle_check = build_evaluation_bundle(val_samples)
rows = build_demo_benchmark_rows(metrics_by_label)
report = write_demo_benchmark_reports(
rows=rows,
dataset_audit=dataset_audit,
oracle_check=oracle_check,
output_json=output_json,
output_md=output_md,
)
print(format_demo_benchmark_table(rows))
print(f"\nSaved: {output_json.relative_to(BACKEND_DIR)}")
print(f"Saved: {output_md.relative_to(BACKEND_DIR)}")
return report
def _legacy_domain_for_policy_domain(domain: str) -> str:
domain = str(domain or "").lower()
if domain == "github_ci":
return "GITHUB"
if domain in {"aws_s3", "iam"}:
return "AWS"
return "SOC"
def _scenario_type_for_hard_negative(scenario: dict[str, Any]) -> str:
expected = scenario["expected_decision"]
risk_level = str(scenario.get("risk_level", "medium")).lower()
if expected == "ALLOW":
return "BENIGN_NOISY" if risk_level in {"medium", "high"} else "BENIGN_CLEAN"
if expected == "BLOCK":
return "MALICIOUS_HARD" if risk_level in {"critical", "high"} else "MALICIOUS_SOFT"
if expected == "QUARANTINE":
return "MALICIOUS_HARD" if risk_level == "critical" else "AMBIGUOUS"
return "AMBIGUOUS"
def hard_negative_to_sample(scenario: dict[str, Any], index: int) -> dict[str, Any]:
from evidence_planner import get_required_evidence
raw_payload = scenario["action_summary"]
legacy_domain = _legacy_domain_for_policy_domain(scenario["domain"])
risk_vector = extract_features(
legacy_domain,
scenario.get("domain", "security").upper(),
raw_payload + " " + " ".join(scenario.get("indicators", [])),
random.Random(90_000 + index),
)
risk_level = str(scenario.get("risk_level", "medium")).lower()
severity = {
"low": "LOW",
"medium": "MEDIUM",
"high": "HIGH",
"critical": "CRITICAL",
}.get(risk_level, "MEDIUM")
is_malicious = scenario["expected_decision"] == "BLOCK" or any(
cue in scenario["title"].lower()
for cue in ("malicious", "rogue", "unauthorized", "without approval", "exfiltration")
)
required_evidence = get_required_evidence(
scenario["domain"],
scenario.get("indicators", []),
scenario["expected_decision"],
scenario["environment"],
)
sample = {
"sample_id": scenario["id"],
"split": "hard_negative",
"domain": scenario["domain"],
"intent": scenario["title"],
"raw_payload": raw_payload,
"action_summary": raw_payload,
"actor": scenario["actor"],
"session_id": scenario["session_id"],
"service": scenario["service"],
"environment": scenario["environment"],
"indicators": list(scenario.get("indicators", [])),
"provided_evidence": list(scenario.get("provided_evidence", [])),
"correct_action": scenario["expected_decision"],
"expected_decision": scenario["expected_decision"],
"expected_safe_outcome": scenario["expected_safe_outcome"],
"risk_level": risk_level,
"scenario_type": _scenario_type_for_hard_negative(scenario),
"is_malicious": is_malicious,
"severity": severity,
"risk_score": round(float(compute_risk_score(risk_vector)), 6),
"risk_vector": [round(float(value), 6) for value in risk_vector[:16]],
"required_evidence": required_evidence,
"mitre_tactic": scenario["mitre_tactic"],
"mitre_technique": scenario["mitre_technique"],
"explanation": scenario["explanation"],
"use_agent_policy": True,
}
return sample
def load_hard_negative_samples(path: Path = DEFAULT_HARD_NEGATIVE_PATH) -> list[dict[str, Any]]:
scenarios = read_json(path, default=[])
return [hard_negative_to_sample(scenario, index) for index, scenario in enumerate(scenarios)]
def _comparison_metric_row(label: str, metrics: Optional[dict[str, Any]]) -> dict[str, Any]:
if metrics is None:
row = {"policy": label, "available": False}
row.update({metric: None for metric in MODEL_POLICY_METRICS})
return row
row = {"policy": label, "available": True}
for metric in MODEL_POLICY_METRICS:
row[metric] = metrics.get(metric)
return row
def _write_model_policy_comparison_md(report: dict[str, Any], output_md: Path) -> None:
lines = ["# ShadowOps Model vs Policy Comparison", ""]
for dataset_name, dataset_report in report["datasets"].items():
lines.extend([f"## {dataset_name}", "", f"Samples: {dataset_report['sample_count']}", ""])
header = "| Policy | Available | " + " | ".join(MODEL_POLICY_METRICS) + " |"
divider = "| --- | --- | " + " | ".join("---:" for _ in MODEL_POLICY_METRICS) + " |"
lines.extend([header, divider])
for row in dataset_report["rows"]:
values = []
for metric in MODEL_POLICY_METRICS:
value = row.get(metric)
values.append("n/a" if value is None else f"{float(value):.3f}")
lines.append(f"| {row['policy']} | {row['available']} | " + " | ".join(values) + " |")
lines.append("")
output_md.write_text("\n".join(lines), encoding="utf-8")
def run_model_policy_comparison(
*,
model_metrics_by_dataset: Optional[dict[str, dict[str, Optional[dict[str, Any]]]]] = None,
output_json: Path = DEFAULT_MODEL_POLICY_COMPARISON_JSON,
output_md: Path = DEFAULT_MODEL_POLICY_COMPARISON_MD,
) -> dict[str, Any]:
val_samples, _ = load_validation_samples_for_benchmark()
hard_negative_samples = load_hard_negative_samples()
model_metrics_by_dataset = model_metrics_by_dataset or {}
report = {
"model_metrics_note": "Model rows are unavailable unless raw/SFT/GRPO models are explicitly loaded and evaluated.",
"datasets": {},
}
for dataset_name, samples in (
("validation", val_samples),
("hard_negative", hard_negative_samples),
):
metrics_by_label, oracle_check = build_evaluation_bundle(samples)
rows = [
_comparison_metric_row("random", metrics_by_label["random"]),
_comparison_metric_row("heuristic", metrics_by_label["heuristic"]),
_comparison_metric_row("q_aware_policy", metrics_by_label["q_aware"]),
_comparison_metric_row("oracle", metrics_by_label["oracle"]),
]
dataset_model_metrics = model_metrics_by_dataset.get(dataset_name, {})
for label in ("raw_model", "sft_model", "grpo_model"):
rows.append(_comparison_metric_row(label, dataset_model_metrics.get(label)))
report["datasets"][dataset_name] = {
"sample_count": len(samples),
"oracle_check": oracle_check,
"rows": rows,
}
write_json(output_json, report)
_write_model_policy_comparison_md(report, output_md)
print(f"Saved: {output_json.relative_to(BACKEND_DIR)}")
print(f"Saved: {output_md.relative_to(BACKEND_DIR)}")
return report
def build_training_health_report(
pre_train_metrics: Optional[dict[str, Any]],
sft_metrics: Optional[dict[str, Any]],
grpo_metrics: Optional[dict[str, Any]],
tracker: Optional[RewardHealthTracker],
baseline_metrics: dict[str, Any],
oracle_metrics: dict[str, Any],
lora_parameter_delta: Optional[dict[str, Any]],
oracle_check: dict[str, Any],
) -> dict[str, Any]:
warnings_out: list[str] = []
reward_zero_fraction = None if tracker is None else tracker.reward_std_zero_fraction
grad_zero_fraction = None if tracker is None else tracker.grad_norm_zero_fraction
invalid_output_rate = (
grpo_metrics.get("invalid_action_rate")
if grpo_metrics is not None
else (tracker.invalid_output_rate if tracker is not None else None)
)
if reward_zero_fraction is not None and reward_zero_fraction > 0.90:
warnings_out.append("CRITICAL: reward_std_zero_fraction > 0.90")
elif reward_zero_fraction is not None and reward_zero_fraction > 0.50:
warnings_out.append("WARNING: reward_std_zero_fraction > 0.50")
if grad_zero_fraction is not None and grad_zero_fraction > 0.90:
warnings_out.append("CRITICAL: grad_norm_zero_fraction > 0.90")
elif grad_zero_fraction is not None and grad_zero_fraction > 0.50:
warnings_out.append("WARNING: grad_norm_zero_fraction > 0.50")
if grpo_metrics is not None and sft_metrics is not None:
if grpo_metrics["safety_accuracy"] < sft_metrics["safety_accuracy"] - 0.05:
warnings_out.append("WARNING: GRPO safety accuracy degraded relative to SFT")
if grpo_metrics["exact_match"] < sft_metrics["exact_match"] - 0.05:
warnings_out.append("WARNING: GRPO exact match degraded relative to SFT")
if not oracle_check.get("passed", False):
warnings_out.append("ERROR: oracle consistency failed")
is_training_healthy = not any(message.startswith(("CRITICAL", "ERROR")) for message in warnings_out)
q_aware_baseline = baseline_metrics.get("q_aware") or {}
training_gate = (
evaluate_training_gate(
grpo_metrics,
q_aware_baseline,
reference_metrics=sft_metrics or pre_train_metrics,
)
if grpo_metrics is not None
else {
"training_gate_status": "WARN",
"training_gate_passed": False,
"reason": "No GRPO metrics are available; training success cannot be claimed.",
"recommended_next_action": "Run checkpoint evaluation before claiming model improvement.",
}
)
report = {
"pre_train_metrics": compact_metrics(pre_train_metrics),
"sft_metrics": compact_metrics(sft_metrics),
"grpo_metrics": compact_metrics(grpo_metrics),
"reward_std_zero_fraction": reward_zero_fraction,
"grad_norm_zero_fraction": grad_zero_fraction,
"invalid_output_rate": invalid_output_rate,
"action_distribution": (
tracker.action_distribution
if tracker is not None
else (grpo_metrics or {}).get("action_distribution", {})
),
"entropy": tracker.entropy if tracker is not None else None,
"oracle_metrics": compact_metrics(oracle_metrics),
"baseline_metrics": {
label: compact_metrics(metrics)
for label, metrics in baseline_metrics.items()
},
"lora_parameter_delta": lora_parameter_delta,
"training_gate": training_gate,
"training_gate_status": training_gate["training_gate_status"],
"training_gate_passed": training_gate["training_gate_passed"],
"is_training_healthy": is_training_healthy,
"warnings": warnings_out,
}
return report
def compute_improvement(before: Optional[dict[str, Any]], after: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
if before is None or after is None:
return None
return {
"exact_match_change": after["exact_match"] - before["exact_match"],
"safety_accuracy_change": after["safety_accuracy"] - before["safety_accuracy"],
"valid_action_rate_change": after["valid_action_rate"] - before["valid_action_rate"],
"reward_mean_change": after["reward_mean"] - before["reward_mean"],
}
def build_training_ready_criteria(
parse_tests: dict[str, Any],
dataset_audit: dict[str, Any],
reward_variance: dict[str, Any],
oracle_check: dict[str, Any],
smoke_test_result: bool,
) -> dict[str, bool]:
criteria = {
"parse_tests_pass": bool(parse_tests.get("passed", False)),
"dataset_audit_passes": bool(dataset_audit.get("passed_preflight", False)),
"reward_variance_passes": bool(reward_variance.get("passed", False)),
"oracle_consistency_passes": bool(oracle_check.get("passed", False)),
"evaluation_isolated": True,
"sft_script_exists": (TRAINING_DIR / "train_qwen3_sft.py").exists(),
"grpo_can_resume_from_sft": True,
"final_report_generation_works": True,
"smoke_test_passes": bool(smoke_test_result),
"cloud_training_commands_documented": (
DEFAULT_CLOUD_SCRIPT_PATH.exists() or DEFAULT_CLOUD_PS1_PATH.exists()
),
}
return criteria
def assert_report_sample_counts_match_latest_audit(report: dict[str, Any]) -> None:
latest_audit = read_json(DEFAULT_DATASET_AUDIT_PATH, default=None)
if latest_audit is None:
return
report_audit = report.get("dataset_audit", {})
for key in ("train_sample_count", "val_sample_count"):
latest_count = latest_audit.get(key)
report_count = report_audit.get(key)
if report_count != latest_count:
raise TrainingPreflightError(
f"Final report {key}={report_count} does not match dataset_audit.json {key}={latest_count}."
)
def generate_final_reports(
*,
dataset_audit: dict[str, Any],
parse_tests: dict[str, Any],
reward_variance: dict[str, Any],
oracle_check: dict[str, Any],
metrics_by_label: dict[str, dict[str, Any]],
training_health_report: dict[str, Any],
smoke_test_result: bool,
best_checkpoint_path: Optional[str],
output_json: Path = DEFAULT_FINAL_REPORT_JSON,
output_md: Path = DEFAULT_FINAL_REPORT_MD,
) -> dict[str, Any]:
latest_dataset_audit = read_json(DEFAULT_DATASET_AUDIT_PATH, default=None)
if latest_dataset_audit is not None:
for key in ("train_sample_count", "val_sample_count"):
if dataset_audit.get(key) != latest_dataset_audit.get(key):
raise TrainingPreflightError(
f"Report input {key}={dataset_audit.get(key)} does not match "
f"dataset_audit.json {key}={latest_dataset_audit.get(key)}."
)
dataset_audit = latest_dataset_audit
criteria = build_training_ready_criteria(
parse_tests=parse_tests,
dataset_audit=dataset_audit,
reward_variance=reward_variance,
oracle_check=oracle_check,
smoke_test_result=smoke_test_result,
)
training_ready = all(criteria.values())
raw_metrics = compact_metrics(metrics_by_label.get("raw_model"))
sft_metrics = compact_metrics(metrics_by_label.get("sft_model"))
grpo_metrics = compact_metrics(metrics_by_label.get("sft_grpo_model"))
oracle_metrics = compact_metrics(metrics_by_label["oracle"])
baseline_metrics = {
label: compact_metrics(metrics_by_label[label])
for label in ("random", "heuristic", "q_aware")
}
report = {
"status": {
"training_ready": training_ready,
"final_status": f"Training-ready: {'yes' if training_ready else 'no'}",
},
"what_was_broken": BROKEN_ITEMS,
"what_was_fixed": FIXED_ITEMS,
"dataset_audit": dataset_audit,
"parse_tests": parse_tests,
"reward_variance": reward_variance,
"oracle_check": oracle_check,
"raw_model_metrics": raw_metrics,
"sft_metrics": sft_metrics,
"grpo_metrics": grpo_metrics,
"oracle_metrics": oracle_metrics,
"baseline_metrics": baseline_metrics,
"improvements": {
"raw_to_sft": compute_improvement(raw_metrics, sft_metrics),
"sft_to_grpo": compute_improvement(sft_metrics, grpo_metrics),
"raw_to_grpo": compute_improvement(raw_metrics, grpo_metrics),
},
"training_health_report": training_health_report,
"training_ready_criteria": criteria,
"best_checkpoint_path": best_checkpoint_path
or "training/checkpoints/qwen3_sft_grpo_adapter (pending cloud run)",
"cloud_commands": {
"sft": CLOUD_SFT_COMMAND,
"grpo": CLOUD_GRPO_COMMAND,
"oom_fallback": CLOUD_FALLBACK_COMMAND,
},
}
assert_report_sample_counts_match_latest_audit(report)
write_json(output_json, report)
markdown_lines = [
"# ShadowOps Training Readiness Report",
"",
f"Training-ready: {'yes' if training_ready else 'no'}",
"",
"## Broken",
]
markdown_lines.extend(f"- {item}" for item in BROKEN_ITEMS)
markdown_lines.extend(["", "## Fixed"])
markdown_lines.extend(f"- {item}" for item in FIXED_ITEMS)
markdown_lines.extend(
[
"",
"## Criteria",
]
)
markdown_lines.extend(
f"- {name}: {'pass' if value else 'fail'}"
for name, value in criteria.items()
)
markdown_lines.extend(
[
"",
"## Dataset Audit",
f"- Train samples: {dataset_audit['train_sample_count']}",
f"- Val samples: {dataset_audit['val_sample_count']}",
f"- Duplicate prompts: {dataset_audit['duplicate_prompt_count']}",
f"- Train/val overlap: {dataset_audit['train_val_overlap_count']}",
"",
"## Metrics",
f"- Random reward mean: {baseline_metrics['random']['reward_mean']:.3f}",
f"- Heuristic reward mean: {baseline_metrics['heuristic']['reward_mean']:.3f}",
f"- Q-aware reward mean: {baseline_metrics['q_aware']['reward_mean']:.3f}",
f"- Oracle reward mean: {oracle_metrics['reward_mean']:.3f}",
f"- Raw model metrics available: {'yes' if raw_metrics else 'no'}",
f"- SFT metrics available: {'yes' if sft_metrics else 'no'}",
f"- GRPO metrics available: {'yes' if grpo_metrics else 'no'}",
"",
"## Health Warnings",
]
)
if training_health_report.get("warnings"):
markdown_lines.extend(f"- {warning}" for warning in training_health_report["warnings"])
else:
markdown_lines.append("- None")
markdown_lines.extend(
[
"",
"## Cloud Commands",
"```bash",
CLOUD_SFT_COMMAND,
CLOUD_GRPO_COMMAND,
"# Fallback if GPU memory is tight:",
CLOUD_FALLBACK_COMMAND,
"```",
"",
f"Best checkpoint path: {report['best_checkpoint_path']}",
"",
]
)
output_md.write_text("\n".join(markdown_lines), encoding="utf-8")
return report
def run_logic_smoke_test(
train_size: int = DEFAULT_TRAIN_SIZE,
val_size: int = DEFAULT_VAL_SIZE,
) -> dict[str, Any]:
train_samples, val_samples, dataset_audit = generate_datasets(
train_size=train_size,
val_size=val_size,
save=True,
)
parse_tests = run_parse_action_tests()
reward_variance = check_reward_variance(train_samples[:40] + val_samples[:20])
metrics_by_label, oracle_check = build_evaluation_bundle(val_samples)
health_report = build_training_health_report(
pre_train_metrics=None,
sft_metrics=None,
grpo_metrics=None,
tracker=None,
baseline_metrics={label: metrics_by_label[label] for label in ("random", "heuristic", "q_aware")},
oracle_metrics=metrics_by_label["oracle"],
lora_parameter_delta=None,
oracle_check=oracle_check,
)
write_json(DEFAULT_HEALTH_REPORT_PATH, health_report)
smoke_pass = all(
[
parse_tests["passed"],
dataset_audit["passed_preflight"],
reward_variance["passed"],
oracle_check["passed"],
]
)
final_report = generate_final_reports(
dataset_audit=dataset_audit,
parse_tests=parse_tests,
reward_variance=reward_variance,
oracle_check=oracle_check,
metrics_by_label=metrics_by_label,
training_health_report=health_report,
smoke_test_result=smoke_pass,
best_checkpoint_path=None,
)
return {
"smoke_test_passed": smoke_pass,
"parse_tests": parse_tests,
"dataset_audit": dataset_audit,
"reward_variance": reward_variance,
"oracle_check": oracle_check,
"metrics_by_label": {label: compact_metrics(metrics) for label, metrics in metrics_by_label.items()},
"final_report": final_report["status"],
}
def get_installed_version(package_name: str) -> Optional[str]:
try:
return package_version(package_name)
except PackageNotFoundError:
return None
def validate_training_runtime(model_name: str) -> bool:
required_packages = ("torch", "datasets", "transformers", "trl", "unsloth")
missing_packages = [
package_name
for package_name in required_packages
if importlib.util.find_spec(package_name) is None
]
if missing_packages:
missing_str = ", ".join(missing_packages)
print(f"Missing dependency package(s): {missing_str}")
print("Install with: pip install torch datasets transformers trl unsloth")
return False
try:
import torch
except Exception as exc:
print(f"Failed to import torch: {exc}")
return False
if not torch.cuda.is_available():
print("Training runtime is not GPU-ready. Use Colab/Kaggle/cloud GPU for full training.")
return False
if "qwen3" in model_name.lower():
transformers_version = get_installed_version("transformers")
minimum_version = Version("4.50.3")
if transformers_version is None or Version(transformers_version) < minimum_version:
print(
f"Incompatible transformers version for Qwen3 training. "
f"Installed={transformers_version}, required>={minimum_version}"
)
return False
return True
def get_total_gpu_memory_gb() -> Optional[float]:
try:
import torch
if not torch.cuda.is_available():
return None
total_bytes = torch.cuda.get_device_properties(0).total_memory
return total_bytes / (1024 ** 3)
except Exception:
return None
def configure_runtime_noise_filters() -> None:
warning_filters = (
("ignore", r".*inductor::_alloc_from_pool.*", UserWarning),
("ignore", r".*_check_is_size will be removed in a future PyTorch release.*", FutureWarning),
("ignore", r".*quantization_config.*already has a `quantization_config` attribute.*", UserWarning),
)
for action, message, category in warning_filters:
warnings.filterwarnings(action, message=message, category=category)
def configure_torch_runtime(torch_module: Any) -> None:
if hasattr(torch_module, "set_float32_matmul_precision"):
torch_module.set_float32_matmul_precision("high")
if not torch_module.cuda.is_available():
return
with contextlib.suppress(Exception):
torch_module.backends.cuda.matmul.allow_tf32 = True
with contextlib.suppress(Exception):
torch_module.backends.cudnn.allow_tf32 = True
class _BlockedImportLoader(importlib.abc.Loader):
def __init__(self, module_name: str):
self.module_name = module_name
def create_module(self, spec: Any) -> None:
return None
def exec_module(self, module: Any) -> None:
raise ModuleNotFoundError(f"No module named '{self.module_name}'")
class _BlockedImportFinder(importlib.abc.MetaPathFinder):
def __init__(self, hidden_roots: tuple[str, ...]):
self.hidden_roots = hidden_roots
def find_spec(self, fullname: str, path: Any = None, target: Any = None) -> Optional[importlib.machinery.ModuleSpec]:
if fullname.split(".", 1)[0] not in self.hidden_roots:
return None
return importlib.machinery.ModuleSpec(
name=fullname,
loader=_BlockedImportLoader(fullname),
is_package="." not in fullname,
)
def clear_modules(module_roots: Iterable[str]) -> None:
roots = set(module_roots)
for module_name in list(sys.modules):
if module_name.split(".", 1)[0] in roots:
sys.modules.pop(module_name, None)
@contextlib.contextmanager
def temporarily_hide_packages(package_roots: Iterable[str]) -> Iterable[None]:
hidden_roots = tuple(dict.fromkeys(package_roots))
finder = _BlockedImportFinder(hidden_roots)
original_find_spec = importlib.util.find_spec
def patched_find_spec(name: str, package: Optional[str] = None) -> Any:
if name.split(".", 1)[0] in hidden_roots:
return None
return original_find_spec(name, package)
clear_modules(hidden_roots)
sys.meta_path.insert(0, finder)
importlib.util.find_spec = patched_find_spec
try:
yield
finally:
importlib.util.find_spec = original_find_spec
with contextlib.suppress(ValueError):
sys.meta_path.remove(finder)
def select_supported_kwargs(callable_obj: Any, candidate_kwargs: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
signature = inspect.signature(callable_obj)
supported_names = {
name
for name, parameter in signature.parameters.items()
if name != "self"
and parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
}
filtered = {
key: value
for key, value in candidate_kwargs.items()
if key in supported_names
}
dropped = sorted(set(candidate_kwargs) - set(filtered))
return filtered, dropped
def load_training_stack(mode: str) -> Optional[dict[str, Any]]:
package_hints = {
"transformers": "transformers",
"unsloth": "unsloth",
"torch": "torch",
"datasets": "datasets",
"trl": "trl",
}
loaded: dict[str, Any] = {}
missing: list[str] = []
failed_imports: list[tuple[str, BaseException]] = []
with temporarily_hide_packages(("vllm", "vllm_ascend")):
transformers_utils = None
for module_name, package_name in package_hints.items():
try:
if module_name == "unsloth":
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=r"WARNING: Unsloth should be imported before .*",
)
loaded[module_name] = import_module(module_name)
continue
loaded[module_name] = import_module(module_name)
if module_name == "transformers":
transformers_utils = import_module("transformers.utils")
if not hasattr(
loaded["transformers"].TrainingArguments,
"_VALID_DICT_FIELDS",
):
loaded["transformers"].TrainingArguments._VALID_DICT_FIELDS = list(
getattr(
loaded["transformers"].TrainingArguments,
"__dataclass_fields__",
{},
).keys()
)
if not hasattr(transformers_utils, "is_rich_available"):
transformers_utils.is_rich_available = lambda: False
except ModuleNotFoundError as exc:
missing_root = (exc.name or "").split(".", 1)[0]
if missing_root == module_name:
missing.append(package_name)
else:
failed_imports.append((module_name, exc))
except Exception as exc:
failed_imports.append((module_name, exc))
if missing:
print(f"Missing dependency package(s): {', '.join(sorted(set(missing)))}")
return None
if failed_imports:
print("Dependency bootstrap failed:")
for module_name, exc in failed_imports:
print(f" - {module_name}: {exc}")
return None
Dataset = getattr(loaded["datasets"], "Dataset", None)
FastModel = getattr(loaded["unsloth"], "FastModel", None)
if FastModel is None:
FastModel = getattr(loaded["unsloth"], "FastLanguageModel", None)
if mode == "grpo":
trainer_module = importlib.reload(import_module("trl.trainer.grpo_trainer"))
trainer_config = getattr(trainer_module, "GRPOConfig", None)
trainer_class = getattr(trainer_module, "GRPOTrainer", None)
elif mode == "sft":
trainer_module = importlib.reload(import_module("trl.trainer.sft_trainer"))
trainer_config = getattr(trainer_module, "SFTConfig", None)
trainer_class = getattr(trainer_module, "SFTTrainer", None)
else:
raise ValueError(f"Unknown mode: {mode}")
if Dataset is None or trainer_config is None or trainer_class is None or FastModel is None:
print(f"Training stack missing required symbols for mode={mode}.")
return None
return {
"torch": loaded["torch"],
"Dataset": Dataset,
"TrainerConfig": trainer_config,
"TrainerClass": trainer_class,
"FastModel": FastModel,
"transformers": loaded["transformers"],
}
def count_trainable_parameters(model: Any) -> tuple[int, int]:
trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
total_params = sum(param.numel() for param in model.parameters())
return trainable_params, total_params
def capture_trainable_snapshot(model: Any) -> dict[str, Any]:
snapshot = {}
for name, parameter in model.named_parameters():
if parameter.requires_grad:
snapshot[name] = parameter.detach().float().cpu().clone()
return snapshot
def compute_parameter_delta(before_snapshot: dict[str, Any], model: Any) -> dict[str, Any]:
total_abs_delta = 0.0
total_l2_delta = 0.0
changed_params = 0
for name, parameter in model.named_parameters():
if not parameter.requires_grad or name not in before_snapshot:
continue
diff = parameter.detach().float().cpu() - before_snapshot[name]
abs_delta = float(diff.abs().sum().item())
l2_delta = float(torch_norm(diff))
total_abs_delta += abs_delta
total_l2_delta += l2_delta
if abs_delta > 0:
changed_params += 1
return {
"changed_trainable_tensors": changed_params,
"total_abs_delta": total_abs_delta,
"total_l2_delta": total_l2_delta,
}
def torch_norm(tensor: Any) -> float:
with contextlib.suppress(Exception):
return float(tensor.norm().item())
flat = tensor.reshape(-1).tolist()
return math.sqrt(sum(float(value) ** 2 for value in flat))
def maybe_apply_chat_template(tokenizer: Any, messages: list[dict[str, str]], add_generation_prompt: bool) -> str:
if hasattr(tokenizer, "apply_chat_template"):
template_kwargs = {
"tokenize": False,
"add_generation_prompt": add_generation_prompt,
}
with contextlib.suppress(Exception):
signature = inspect.signature(tokenizer.apply_chat_template)
if "enable_thinking" in signature.parameters:
template_kwargs["enable_thinking"] = False
return tokenizer.apply_chat_template(messages, **template_kwargs)
rendered = []
for message in messages:
rendered.append(f"{message['role'].upper()}: {message['content']}")
if add_generation_prompt:
rendered.append("ASSISTANT:")
return "\n\n".join(rendered)
def format_prompt_for_model(tokenizer: Any, prompt: str) -> str:
return maybe_apply_chat_template(
tokenizer,
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
)
def build_sft_training_text(tokenizer: Any, prompt: str, completion: str) -> str:
return maybe_apply_chat_template(
tokenizer,
[
{"role": "user", "content": prompt},
{"role": "assistant", "content": completion},
],
add_generation_prompt=False,
)
def make_reward_function(
tracker: Optional[RewardHealthTracker] = None,
) -> Any:
def reward_fn(
completions: list[Any],
correct_action: Optional[list[str]] = None,
is_malicious: Optional[list[bool]] = None,
severity: Optional[list[str]] = None,
risk_score: Optional[list[float]] = None,
scenario_type: Optional[list[str]] = None,
**kwargs: Any,
) -> list[float]:
rewards: list[float] = []
parsed_actions: list[Optional[str]] = []
correct_action = correct_action or []
is_malicious = is_malicious or []
severity = severity or []
risk_score = risk_score or []
scenario_type = scenario_type or []
for index, completion in enumerate(completions):
sample = {
"correct_action": (correct_action[index] if index < len(correct_action) else "ALLOW"),
"is_malicious": bool(is_malicious[index]) if index < len(is_malicious) else False,
"severity": (severity[index] if index < len(severity) else "LOW"),
"risk_score": float(risk_score[index]) if index < len(risk_score) else 0.0,
"scenario_type": (scenario_type[index] if index < len(scenario_type) else "BENIGN_CLEAN"),
}
output_text = normalize_completion_text(completion)
analysis = analyze_action_output(output_text)
parsed_action = analysis["parsed_action"]
parsed_actions.append(parsed_action)
reward, _ = compute_shaped_reward(
sample,
parsed_action,
{"multi_action_warning": analysis["multi_action_warning"]},
)
rewards.append(reward)
if tracker is not None:
tracker.record_batch(parsed_actions, rewards)
return rewards
return reward_fn
def make_health_callback(tracker: RewardHealthTracker, transformers_module: Any) -> Any:
TrainerCallback = getattr(transformers_module, "TrainerCallback")
class HealthCallback(TrainerCallback):
def on_log(self, args: Any, state: Any, control: Any, logs: Optional[dict[str, Any]] = None, **kwargs: Any) -> Any:
logs = logs or {}
if "grad_norm" in logs:
tracker.record_grad_norm(logs["grad_norm"])
return control
return HealthCallback()
def load_fast_model(
fast_model_class: Any,
model_name_or_path: str,
max_seq_len: int,
load_in_4bit: bool = True,
) -> tuple[Any, Any]:
model, tokenizer = fast_model_class.from_pretrained(
model_name=model_name_or_path,
max_seq_length=max_seq_len,
load_in_4bit=load_in_4bit,
load_in_8bit=False,
full_finetuning=False,
)
return model, tokenizer
def ensure_trainable_lora(
fast_model_class: Any,
model: Any,
*,
lora_r: int,
lora_alpha: int,
lora_dropout: float,
) -> Any:
if getattr(model, "peft_config", None):
return model
return fast_model_class.get_peft_model(
model,
r=lora_r,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
bias="none",
use_gradient_checkpointing="unsloth",
)
def evaluate_model_on_dataset(
model: Any,
tokenizer: Any,
samples: list[dict[str, Any]],
*,
batch_size: int,
max_seq_len: int,
max_new_tokens: int,
) -> dict[str, Any]:
import torch
if not samples:
return evaluate_outputs([], [], label="empty")
try:
device = next(model.parameters()).device
except StopIteration:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prompts = [format_prompt_for_model(tokenizer, sample["prompt"]) for sample in samples]
outputs_text: list[str] = []
for start in range(0, len(prompts), max(1, batch_size)):
batch_prompts = prompts[start : start + batch_size]
inputs = tokenizer(
batch_prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_seq_len,
).to(device)
input_lengths = inputs["attention_mask"].sum(dim=1).tolist()
generation_kwargs = {
"do_sample": False,
"temperature": 1.0,
"top_p": 0.95,
"top_k": 50,
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
"eos_token_id": tokenizer.eos_token_id,
}
if getattr(model, "generation_config", None) is not None:
generation_config = copy.deepcopy(model.generation_config)
# Action-only evaluation needs only the action token. Set max_new_tokens
# on the generation config and clear inherited max_length to avoid
# transformers warning about both limits being active.
generation_config.max_new_tokens = max_new_tokens
with contextlib.suppress(Exception):
generation_config.max_length = None
generation_kwargs["generation_config"] = generation_config
else:
generation_kwargs["max_new_tokens"] = max_new_tokens
with torch.inference_mode():
generated = model.generate(
**inputs,
**generation_kwargs,
)
for row_index, length in enumerate(input_lengths):
completion_tokens = generated[row_index][int(length):]
completion_text = tokenizer.decode(completion_tokens, skip_special_tokens=True)
outputs_text.append(completion_text)
return evaluate_outputs(samples, outputs_text, label="model")
def evaluate_saved_model(
*,
model_path_or_name: str,
val_samples: list[dict[str, Any]],
max_seq_len: int,
batch_size: int,
max_new_tokens: int,
) -> Optional[dict[str, Any]]:
training_stack = load_training_stack("sft")
if training_stack is None:
return None
torch_module = training_stack["torch"]
configure_torch_runtime(torch_module)
FastModel = training_stack["FastModel"]
model, tokenizer = load_fast_model(FastModel, model_path_or_name, max_seq_len=max_seq_len)
FastModel.for_inference(model)
metrics = evaluate_model_on_dataset(
model,
tokenizer,
val_samples,
batch_size=batch_size,
max_seq_len=max_seq_len,
max_new_tokens=max_new_tokens,
)
metrics["label"] = model_path_or_name
return metrics
def pick_best_checkpoint(
*,
model_name: str,
output_dir: Path,
val_samples: list[dict[str, Any]],
max_seq_len: int,
batch_size: int,
max_new_tokens: int,
) -> Optional[str]:
candidate_paths = [output_dir]
candidate_paths.extend(sorted(path for path in output_dir.glob("checkpoint-*") if path.is_dir()))
best_path: Optional[Path] = None
best_score = -float("inf")
for candidate in candidate_paths:
metrics = evaluate_saved_model(
model_path_or_name=str(candidate),
val_samples=val_samples,
max_seq_len=max_seq_len,
batch_size=batch_size,
max_new_tokens=max_new_tokens,
)
if metrics is None:
continue
score = (0.7 * metrics["safety_accuracy"]) + (0.3 * metrics["exact_match"])
if score > best_score:
best_score = score
best_path = candidate
if best_path is None:
return None
try:
return str(best_path.relative_to(BACKEND_DIR))
except ValueError:
return str(best_path)
def run_subprocess(command: list[str], cwd: Path) -> None:
completed = subprocess.run(command, cwd=str(cwd), check=False)
if completed.returncode != 0:
raise RuntimeError(f"Command failed with exit code {completed.returncode}: {' '.join(command)}")
def add_shared_smoke_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--smoke-test", action="store_true", help="Run fast logic-only smoke tests and exit.")
parser.add_argument("--dry-run", action="store_true", help="Build configs and datasets without training.")
parser.add_argument("--skip-model-load", action="store_true", help="Skip Qwen model loading for local smoke checks.")