""" Evaluation for MambaShield dual-head (safety binary + category multi-label). Returns: safety_acc : binary accuracy (safe/unsafe) safety_f1 : binary F1 category_macro_f1 : macro-averaged F1 across 11 categories category_micro_f1 : micro-averaged F1 per_class_f1 : dict {label_name: f1} per_class_conf : dict {label_name: mean_confidence} Usage (standalone): python evaluate.py --ckpt checkpoints/best_model.pt --plot """ import argparse from typing import Dict import numpy as np import torch import torch.nn.functional as F from sklearn.metrics import ( accuracy_score, classification_report, f1_score, confusion_matrix, ) from torch.utils.data import DataLoader from tqdm import tqdm from config import ID2LABEL, MambaShieldConfig def evaluate( model, loader: DataLoader, device: torch.device, cfg: MambaShieldConfig, print_report: bool = False, ) -> Dict: model.eval() all_safety_pred, all_safety_true = [], [] all_cat_pred, all_cat_true = [], [] # binary vectors (multi-label) all_cat_probs = [] # raw sigmoid scores (B, 11) with torch.no_grad(): for batch in tqdm(loader, desc="Evaluating", leave=False): ids = batch["input_ids"].to(device) mask = batch["attention_mask"].to(device) s_true = batch["safety_label"] # (B,) float c_true = batch["category_vec"] # (B,11) float safety_logit, cat_logits = model(ids, mask) s_pred = (safety_logit.squeeze(-1) > 0).float().cpu() c_prob = torch.sigmoid(cat_logits).float().cpu() c_pred = (c_prob > cfg.category_threshold).float() all_safety_pred.append(s_pred) all_safety_true.append(s_true) all_cat_pred.append(c_pred) all_cat_true.append(c_true) all_cat_probs.append(c_prob) s_pred = torch.cat(all_safety_pred).numpy() s_true = torch.cat(all_safety_true).numpy() c_pred = torch.cat(all_cat_pred).numpy() c_true = torch.cat(all_cat_true).numpy() c_prob = torch.cat(all_cat_probs).numpy() # ── Safety metrics ──────────────────────────────────────────────────── safety_acc = accuracy_score(s_true, s_pred) safety_f1 = f1_score(s_true, s_pred, average="binary", zero_division=0) # ── Category multi-label metrics ────────────────────────────────────── cat_macro_f1 = f1_score(c_true, c_pred, average="macro", zero_division=0) cat_micro_f1 = f1_score(c_true, c_pred, average="micro", zero_division=0) per_label_f1 = f1_score(c_true, c_pred, average=None, zero_division=0) mean_conf = c_prob.mean(axis=0) # mean confidence per class per_class_f1 = {ID2LABEL[i]: round(float(per_label_f1[i]), 4) for i in range(len(ID2LABEL))} per_class_conf = {ID2LABEL[i]: round(float(mean_conf[i]), 4) for i in range(len(ID2LABEL))} if print_report: label_names = [ID2LABEL[i] for i in range(len(ID2LABEL))] print("\n── Safety Head ──") print(f" accuracy={safety_acc:.4f} f1={safety_f1:.4f}") print("\n── Category Head (multi-label, threshold=%.2f) ──" % cfg.category_threshold) print(classification_report(c_true, c_pred, target_names=label_names, zero_division=0)) print("\n── Mean Confidence Per Category ──") for name, conf in sorted(per_class_conf.items(), key=lambda x: -x[1]): bar = "█" * int(conf * 30) print(f" {name:<45} {conf:.4f} {bar}") return { "safety_acc": safety_acc, "safety_f1": safety_f1, "category_macro_f1": cat_macro_f1, "category_micro_f1": cat_micro_f1, "per_class_f1": per_class_f1, "per_class_conf": per_class_conf, # raw arrays for confusion matrix / further analysis "safety_pred": s_pred, "safety_true": s_true, "cat_pred": c_pred, "cat_true": c_true, } def plot_confusion_matrix(preds, labels, save_path="confusion_matrix.png"): import matplotlib.pyplot as plt import seaborn as sns label_names = [ID2LABEL[i] for i in range(len(ID2LABEL))] # For multi-label: take argmax for single-class approximation if preds.ndim == 2: preds = preds.argmax(axis=1) labels = labels.argmax(axis=1) cm = confusion_matrix(labels, preds) fig, ax = plt.subplots(figsize=(12, 10)) sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=label_names, yticklabels=label_names, ax=ax) ax.set_xlabel("Predicted"); ax.set_ylabel("True") ax.set_title("MambaShield — Category Confusion Matrix") plt.xticks(rotation=45, ha="right"); plt.tight_layout() plt.savefig(save_path, dpi=150) print(f"Confusion matrix saved → {save_path}") # ── Standalone ──────────────────────────────────────────────────────────────── if __name__ == "__main__": import os, sys sys.path.insert(0, os.path.dirname(__file__)) from config import MambaShieldConfig from dataset import build_dataloaders from model import MambaShield p = argparse.ArgumentParser() p.add_argument("--ckpt", required=True) p.add_argument("--plot", action="store_true") args = p.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ckpt = torch.load(args.ckpt, map_location=device) cfg: MambaShieldConfig = ckpt["cfg"] model = MambaShield(ckpt["vocab_size"], cfg).to(device) model.load_state_dict(ckpt["model_state"]) _, _, test_loader, _ = build_dataloaders(cfg) metrics = evaluate(model, test_loader, device, cfg, print_report=True) print(f"\nSafety accuracy : {metrics['safety_acc']:.4f}") print(f"Category macro F1: {metrics['category_macro_f1']:.4f}") if args.plot: plot_confusion_matrix(metrics["cat_pred"], metrics["cat_true"])