Spaces:
Running
Running
| """Evaluate baseline and trained defender on the frozen hold-out set. | |
| Two models are compared by default: | |
| * **Baseline**: vanilla Qwen2.5-3B-Instruct, no SFT, no GRPO. | |
| * **Trained**: Qwen2.5-3B-Instruct + SFT warm-start + GRPO curriculum. | |
| Both are scored on `data/holdout.jsonl` using the verifier's ground-truth | |
| labels. Reported metrics (printed and saved to `--out-dir`): | |
| * Macro F1 + per-class precision/recall | |
| * 5x5 confusion matrix | |
| * Dismiss-on-malicious rate (the cardinal SOC failure mode) | |
| * Over-react rate (containment on benign) | |
| Inference path | |
| -------------- | |
| Models are loaded 4-bit with Unsloth's `FastLanguageModel.from_pretrained(... load_in_4bit=True)` | |
| and decoded greedily with `model.generate`, to keep eval under 10 minutes on a T4. | |
| Two sanity baselines always run first: an always-dismiss model and a verifier | |
| oracle that replays the hold-out labels (a smoke test of parsing + metrics). | |
| When GPU deps aren't available (e.g. a CPU-only Hugging Face Space build) or | |
| `--smoke-only` is passed, only those baselines run and the real models are skipped. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| from typing import List, Tuple | |
| _HERE = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.insert(0, os.path.dirname(_HERE)) | |
| from eval.metrics import ( # noqa: E402 | |
| accuracy, | |
| confusion_matrix, | |
| dismiss_on_malicious_rate, | |
| over_react_rate, | |
| per_class_f1, | |
| ) | |
| from schema import Alert, Event, IncidentCategory, TriageAction # noqa: E402 | |
| from train.prompt_format import ( # noqa: E402 | |
| SYSTEM_PROMPT, | |
| parse_defender_response, | |
| render_defender_prompt, | |
| ) | |
def _load_holdout(path: str) -> List[dict]:
    """Read a JSON-Lines hold-out file into a list of incident records.

    Args:
        path: Path to a UTF-8 ``.jsonl`` file holding one JSON object per line.

    Returns:
        The decoded records, in file order.

    Blank or whitespace-only lines (e.g. a trailing empty line left by an
    editor) are skipped; passing them to ``json.loads`` would raise
    ``JSONDecodeError`` and abort the whole evaluation.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def _to_alert_events(rec: dict) -> Tuple[Alert, List[Event]]:
    """Rebuild the schema objects for a single hold-out record.

    Args:
        rec: A hold-out record containing an ``"alert"`` mapping and an
            ``"events"`` list whose items are keyword arguments for ``Event``.

    Returns:
        ``(alert, events)``. Missing ``host`` / ``user`` fields on the alert
        default to the empty string.
    """
    raw = rec["alert"]
    alert_fields = {
        "alert_id": raw["alert_id"],
        "category": IncidentCategory(raw["category"]),
        "severity": raw["severity"],
        "summary": raw["summary"],
        "host": raw.get("host", ""),
        "user": raw.get("user", ""),
    }
    alert = Alert(**alert_fields)
    events = [Event(**event_fields) for event_fields in rec["events"]]
    return alert, events
def _print_metrics(label: str, preds: List[str], truths: List[str]) -> dict:
    """Compute, print, and return the hold-out metrics for one model.

    Args:
        label: Display name of the model; echoed in the header and result.
        preds: Predicted action strings, aligned index-by-index with ``truths``.
        truths: Ground-truth action strings from the hold-out set.

    Returns:
        A summary dict with keys ``label``, ``accuracy``, ``macro_f1``,
        ``dismiss_on_malicious``, ``over_react_rate``, ``per_class`` and
        ``confusion_matrix`` (in that order).
    """
    cm = confusion_matrix(preds, truths)
    macro_f1, per_class = per_class_f1(cm)
    result = {
        "label": label,
        "accuracy": accuracy(preds, truths),
        "macro_f1": macro_f1,
        "dismiss_on_malicious": dismiss_on_malicious_rate(preds, truths),
        "over_react_rate": over_react_rate(preds, truths),
        "per_class": per_class,
        "confusion_matrix": cm,
    }

    print(f"\n=== {label} ===")
    print(f" accuracy: {result['accuracy']:.3f}")
    print(f" macro F1: {macro_f1:.3f}")
    print(f" dismiss-on-malicious: {result['dismiss_on_malicious']:.3f}")
    print(f" over-react on benign: {result['over_react_rate']:.3f}")
    print(" per-class:")
    for cls_name, stats in per_class.items():
        print(
            f" {cls_name:<18} P={stats['precision']:.2f} R={stats['recall']:.2f}"
            f" F1={stats['f1']:.2f} (n={int(stats['support'])})"
        )
    return result
| # --------------------------------------------------------------------------- | |
| # Inference adapters | |
| # --------------------------------------------------------------------------- | |
class _VerifierOracle:
    """A 'model' that simply echoes the hold-out record's own label.

    It is a smoke test for the parsing + metrics pipeline when no GPU model
    can be loaded. Every prediction equals ``gold["ground_truth"]``, so it
    should score 100% accuracy and 0% dismiss-on-malicious by construction.
    """

    name = "verifier_oracle"

    def predict(self, alert: Alert, events: List[Event], gold: dict) -> str:
        """Return a well-formed defender response built from the gold record.

        ``alert`` and ``events`` are ignored; only ``gold["ground_truth"]``
        and ``gold["triggering_log_id"]`` are read.
        """
        verdict = gold["ground_truth"]
        log_id = gold["triggering_log_id"]
        return f"Action: {verdict}\nCitedLog: {log_id}\nRationale: oracle"
class _AlwaysDismissBaseline:
    """Degenerate baseline whose answer is 'dismiss' for every incident."""

    name = "always_dismiss"

    def predict(self, alert: Alert, events: List[Event], gold: dict) -> str:
        """Ignore all inputs and return the same canned 'dismiss' response.

        The cited log id is a fixed placeholder, not derived from ``events``.
        """
        del alert, events, gold  # the canned answer never depends on the input
        return "Action: dismiss\nCitedLog: L1-0\nRationale: trivial baseline"
def _try_load_unsloth_model(model_name: str, adapter_path: str | None):
    """Load a 4-bit model, plus an optional adapter, through Unsloth.

    Args:
        model_name: Model id or local path passed to
            ``FastLanguageModel.from_pretrained``.
        adapter_path: Optional adapter directory to attach to the base model.
            If it is given but does not exist, a warning is printed to stderr
            and only the base model is returned.

    Returns:
        ``(model, tokenizer)`` switched to inference mode, or ``None`` when
        Unsloth cannot be imported.
    """
    try:
        from unsloth import FastLanguageModel
    except (ImportError, NotImplementedError):
        # ImportError: Unsloth / GPU deps are not installed.
        # NotImplementedError: some Unsloth releases raise this at import time
        # when no supported GPU is present; treat it the same as missing deps.
        # NOTE(review): confirm against the Unsloth version pinned for training.
        return None

    # Check before the (slow) model load. Evaluating the bare base model under
    # a "trained" label would silently corrupt the comparison, so say so loudly.
    load_adapter = bool(adapter_path) and os.path.exists(adapter_path)
    if adapter_path and not load_adapter:
        print(
            f"WARNING: adapter not found at {adapter_path}; using the base model only.",
            file=sys.stderr,
        )

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=2048,
        dtype=None,  # let Unsloth pick the dtype
        load_in_4bit=True,
    )
    if load_adapter:
        model.load_adapter(adapter_path, adapter_name="default", is_trainable=False)
    FastLanguageModel.for_inference(model)
    return model, tokenizer
def _generate(model_pair, alert, events, max_new_tokens: int = 128) -> str:
    """Greedily decode one defender response for an alert and its events.

    Args:
        model_pair: ``(model, tokenizer)`` as returned by
            ``_try_load_unsloth_model``.
        alert: The alert to triage; rendered into the user prompt.
        events: The log events rendered alongside the alert.
        max_new_tokens: Generation budget for the response (default 128).

    Returns:
        Only the newly generated text (prompt tokens are sliced off), decoded
        with special tokens removed.
    """
    model, tokenizer = model_pair
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": render_defender_prompt(alert, events)},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # do_sample=False already selects greedy decoding. transformers ignores
    # `temperature` in that mode and, in recent versions, warns about it on
    # every call, so it is not passed.
    out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    prompt_len = inputs.input_ids.shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
| # --------------------------------------------------------------------------- | |
| # Main eval | |
| # --------------------------------------------------------------------------- | |
def _predict_actions(records, inputs, respond) -> List[str]:
    """Run ``respond`` on every hold-out incident and parse out its action.

    Args:
        records: Raw hold-out records; each is passed to ``respond`` as ``gold``.
        inputs: ``(alert, events)`` pairs aligned index-by-index with ``records``.
        respond: Callable ``(alert, events, gold) -> str`` that returns a raw
            defender response.

    Returns:
        One action string per incident.
    """
    actions = []
    for rec, (alert, events) in zip(records, inputs):
        parsed = parse_defender_response(respond(alert, events, rec))
        # A response with no parseable action is scored as "dismiss". This
        # presumably penalises malformed output as the cardinal SOC failure mode.
        actions.append(parsed.action.value if parsed.action else "dismiss")
    return actions


def main() -> None:
    """CLI entry point: score the sanity baselines and, optionally, real models.

    The always-dismiss and verifier-oracle baselines always run. Unless
    ``--smoke-only`` is given, the zero-shot base model and the GRPO-trained
    adapter are also tried via Unsloth. They are skipped when Unsloth cannot
    be loaded, and the trained model is skipped when its adapter path is
    missing.

    All summaries are written to ``<out-dir>/summary.json``. Relative
    ``--holdout``, ``--out-dir`` and ``--trained-adapter`` paths are resolved
    against the parent of this script's directory, not the current directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--baseline", default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--trained-adapter", default="checkpoints/defender_grpo/stage4_adversarial/adapter")
    parser.add_argument("--holdout", default="data/holdout.jsonl")
    parser.add_argument("--out-dir", default="eval/results")
    parser.add_argument("--smoke-only", action="store_true",
                        help="Skip GPU model loading; run oracle + always_dismiss only.")
    args = parser.parse_args()

    root = os.path.dirname(_HERE)
    holdout_path = os.path.join(root, args.holdout)
    out_dir = os.path.join(root, args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    holdout = _load_holdout(holdout_path)
    truths = [r["ground_truth"] for r in holdout]
    print(f"Loaded {len(holdout)} hold-out incidents from {holdout_path}")
    # Decode each record into schema objects once; every model reuses them.
    inputs = [_to_alert_events(rec) for rec in holdout]

    summaries = []

    # --- Sanity baselines (always run, no GPU needed) ---
    for sanity_model in (_AlwaysDismissBaseline(), _VerifierOracle()):
        preds = _predict_actions(holdout, inputs, sanity_model.predict)
        summaries.append(_print_metrics(sanity_model.name, preds, truths))

    # --- Real models ---
    if not args.smoke_only:
        baseline_pair = _try_load_unsloth_model(args.baseline, adapter_path=None)
        if baseline_pair is None:
            print("\n(skip) GPU deps not installed; skipping baseline_zero_shot and opensoc_grpo.")
        else:
            preds = _predict_actions(
                holdout,
                inputs,
                lambda alert, events, _gold: _generate(baseline_pair, alert, events),
            )
            summaries.append(_print_metrics("baseline_zero_shot", preds, truths))

            adapter_full = os.path.join(root, args.trained_adapter)
            if not os.path.exists(adapter_full):
                print(f"\n(skip) trained adapter not found at {adapter_full}")
            else:
                trained_pair = _try_load_unsloth_model(args.baseline, adapter_path=adapter_full)
                if trained_pair is not None:
                    preds = _predict_actions(
                        holdout,
                        inputs,
                        lambda alert, events, _gold: _generate(trained_pair, alert, events),
                    )
                    summaries.append(_print_metrics("opensoc_grpo", preds, truths))

    out_json = os.path.join(out_dir, "summary.json")
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(summaries, f, indent=2)
    print(f"\nSaved summary to {out_json}")


if __name__ == "__main__":
    main()