"""Evaluate baseline and trained defender on the frozen hold-out set. Two models are compared by default: * **Baseline**: vanilla Qwen2.5-3B-Instruct, no SFT, no GRPO. * **Trained**: Qwen2.5-3B-Instruct + SFT warm-start + GRPO curriculum. Both are scored on `data/holdout.jsonl` using the verifier's ground-truth labels. Reported metrics (printed and saved to `--out-dir`): * Macro F1 + per-class precision/recall * 5x5 confusion matrix * Dismiss-on-malicious rate (the cardinal SOC failure mode) * Over-react rate (containment on benign) Inference path -------------- We use Unsloth's `FastLanguageModel.from_pretrained(... load_in_4bit=True)` with `model.fast_generate` to keep eval under 10 minutes on a T4. When GPU deps aren't available (e.g. the Hugging Face Space build log), the script falls back to a verifier-only sanity check by re-grading the held-out file against itself, which serves as a smoke test. """ from __future__ import annotations import argparse import json import os import sys from typing import List, Tuple _HERE = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.dirname(_HERE)) from eval.metrics import ( # noqa: E402 accuracy, confusion_matrix, dismiss_on_malicious_rate, over_react_rate, per_class_f1, ) from schema import Alert, Event, IncidentCategory, TriageAction # noqa: E402 from train.prompt_format import ( # noqa: E402 SYSTEM_PROMPT, parse_defender_response, render_defender_prompt, ) def _load_holdout(path: str): items = [] with open(path, "r", encoding="utf-8") as f: for line in f: items.append(json.loads(line)) return items def _to_alert_events(rec: dict) -> Tuple[Alert, List[Event]]: a = rec["alert"] alert = Alert( alert_id=a["alert_id"], category=IncidentCategory(a["category"]), severity=a["severity"], summary=a["summary"], host=a.get("host", ""), user=a.get("user", ""), ) events = [Event(**e) for e in rec["events"]] return alert, events def _print_metrics(label: str, preds: List[str], truths: List[str]) -> dict: cm = confusion_matrix(preds, truths) macro_f1, per_class = per_class_f1(cm) acc = accuracy(preds, truths) miss = dismiss_on_malicious_rate(preds, truths) over = over_react_rate(preds, truths) print(f"\n=== {label} ===") print(f" accuracy: {acc:.3f}") print(f" macro F1: {macro_f1:.3f}") print(f" dismiss-on-malicious: {miss:.3f}") print(f" over-react on benign: {over:.3f}") print(" per-class:") for cls, m in per_class.items(): print(f" {cls:<18} P={m['precision']:.2f} R={m['recall']:.2f} F1={m['f1']:.2f} (n={int(m['support'])})") return { "label": label, "accuracy": acc, "macro_f1": macro_f1, "dismiss_on_malicious": miss, "over_react_rate": over, "per_class": per_class, "confusion_matrix": cm, } # --------------------------------------------------------------------------- # Inference adapters # --------------------------------------------------------------------------- class _VerifierOracle: """A 'model' that always returns the verifier's correct answer. Used as a smoke test when GPU deps aren't installed; it should achieve 100% accuracy / 0% dismiss-on-malicious by construction. """ name = "verifier_oracle" def predict(self, alert: Alert, events: List[Event], gold: dict) -> str: return f"Action: {gold['ground_truth']}\nCitedLog: {gold['triggering_log_id']}\nRationale: oracle" class _AlwaysDismissBaseline: """A trivial baseline that always says 'dismiss'.""" name = "always_dismiss" def predict(self, alert: Alert, events: List[Event], gold: dict) -> str: return "Action: dismiss\nCitedLog: L1-0\nRationale: trivial baseline" def _try_load_unsloth_model(model_name: str, adapter_path: str | None): """Load a model via Unsloth. Returns None if GPU deps aren't installed.""" try: from unsloth import FastLanguageModel except ImportError: return None model, tokenizer = FastLanguageModel.from_pretrained( model_name=model_name, max_seq_length=2048, dtype=None, load_in_4bit=True, ) if adapter_path and os.path.exists(adapter_path): model.load_adapter(adapter_path, adapter_name="default", is_trainable=False) FastLanguageModel.for_inference(model) return model, tokenizer def _generate(model_pair, alert, events) -> str: model, tokenizer = model_pair messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": render_defender_prompt(alert, events)}, ] prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = tokenizer(prompt, return_tensors="pt").to(model.device) out = model.generate(**inputs, max_new_tokens=128, do_sample=False, temperature=0.0) text = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) return text # --------------------------------------------------------------------------- # Main eval # --------------------------------------------------------------------------- def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--baseline", default="unsloth/Qwen2.5-3B-Instruct") parser.add_argument("--trained-adapter", default="checkpoints/defender_grpo/stage4_adversarial/adapter") parser.add_argument("--holdout", default="data/holdout.jsonl") parser.add_argument("--out-dir", default="eval/results") parser.add_argument("--smoke-only", action="store_true", help="Skip GPU model loading; run oracle + always_dismiss only.") args = parser.parse_args() holdout_path = os.path.join(os.path.dirname(_HERE), args.holdout) out_dir = os.path.join(os.path.dirname(_HERE), args.out_dir) os.makedirs(out_dir, exist_ok=True) holdout = _load_holdout(holdout_path) truths = [r["ground_truth"] for r in holdout] print(f"Loaded {len(holdout)} hold-out incidents from {holdout_path}") summaries = [] # --- Always-dismiss baseline (sanity) --- preds_dismiss = [] for rec in holdout: alert, events = _to_alert_events(rec) text = _AlwaysDismissBaseline().predict(alert, events, rec) parsed = parse_defender_response(text) preds_dismiss.append(parsed.action.value if parsed.action else "dismiss") summaries.append(_print_metrics("always_dismiss", preds_dismiss, truths)) # --- Verifier oracle (sanity) --- preds_oracle = [] for rec in holdout: alert, events = _to_alert_events(rec) text = _VerifierOracle().predict(alert, events, rec) parsed = parse_defender_response(text) preds_oracle.append(parsed.action.value if parsed.action else "dismiss") summaries.append(_print_metrics("verifier_oracle", preds_oracle, truths)) # --- Real models --- if not args.smoke_only: baseline_pair = _try_load_unsloth_model(args.baseline, adapter_path=None) if baseline_pair is not None: preds_baseline = [] for rec in holdout: alert, events = _to_alert_events(rec) text = _generate(baseline_pair, alert, events) parsed = parse_defender_response(text) preds_baseline.append(parsed.action.value if parsed.action else "dismiss") summaries.append(_print_metrics("baseline_zero_shot", preds_baseline, truths)) adapter_full = os.path.join(os.path.dirname(_HERE), args.trained_adapter) if os.path.exists(adapter_full): trained_pair = _try_load_unsloth_model(args.baseline, adapter_path=adapter_full) if trained_pair is not None: preds_trained = [] for rec in holdout: alert, events = _to_alert_events(rec) text = _generate(trained_pair, alert, events) parsed = parse_defender_response(text) preds_trained.append(parsed.action.value if parsed.action else "dismiss") summaries.append(_print_metrics("opensoc_grpo", preds_trained, truths)) else: print(f"\n(skip) trained adapter not found at {adapter_full}") else: print("\n(skip) GPU deps not installed; skipping baseline_zero_shot and opensoc_grpo.") out_json = os.path.join(out_dir, "summary.json") with open(out_json, "w") as f: json.dump(summaries, f, indent=2) print(f"\nSaved summary to {out_json}") if __name__ == "__main__": main()