| """ |
| Evaluation for MambaShield dual-head (safety binary + category multi-label). |
| |
evaluate() returns a dict with:
    safety_acc        : binary accuracy (safe/unsafe)
    safety_f1         : binary F1
    category_macro_f1 : macro-averaged F1 across the 11 categories
    category_micro_f1 : micro-averaged F1
    per_class_f1      : dict {label_name: f1}
    per_class_conf    : dict {label_name: mean predicted probability}
plus the raw numpy arrays safety_pred, safety_true, cat_pred and cat_true.
| |
| Usage (standalone): |
| python evaluate.py --ckpt checkpoints/best_model.pt --plot |
| """ |
|
|
| import argparse |
| from typing import Dict |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from sklearn.metrics import ( |
| accuracy_score, classification_report, |
| f1_score, confusion_matrix, |
| ) |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| from config import ID2LABEL, MambaShieldConfig |
|
|
|
|
def _collect_outputs(model, loader, device, threshold):
    """Run *model* over *loader* and gather per-sample outputs as CPU numpy arrays.

    Returns a tuple ``(safety_pred, safety_true, cat_pred, cat_true, cat_prob)``,
    each concatenated along the sample axis.

    Raises:
        ValueError: if *loader* yields no batches.
    """
    safety_pred, safety_true = [], []
    cat_pred, cat_true, cat_prob = [], [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", leave=False):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)

            safety_logit, cat_logits = model(ids, mask)

            # reshape(-1) instead of squeeze(-1): a (B,)-shaped output with B == 1
            # would otherwise collapse to a 0-d tensor that torch.cat rejects.
            # A positive logit is the same decision as sigmoid(logit) > 0.5.
            safety_pred.append((safety_logit.reshape(-1) > 0).float().cpu())
            safety_true.append(batch["safety_label"].reshape(-1).cpu())

            probs = torch.sigmoid(cat_logits).float().cpu()
            cat_prob.append(probs)
            cat_pred.append((probs > threshold).float())
            cat_true.append(batch["category_vec"].cpu())

    if not safety_pred:
        # torch.cat([]) would fail with an opaque RuntimeError.
        raise ValueError("evaluate() received a loader that yielded no batches")

    return tuple(
        torch.cat(parts).numpy()
        for parts in (safety_pred, safety_true, cat_pred, cat_true, cat_prob)
    )


def _print_report(safety_acc, safety_f1, c_true, c_pred, threshold, label_names, per_class_conf):
    """Print a human-readable summary of both heads to stdout."""
    print("\nββ Safety Head ββ")
    print(f" accuracy={safety_acc:.4f} f1={safety_f1:.4f}")
    print("\nββ Category Head (multi-label, threshold=%.2f) ββ" % threshold)
    print(classification_report(c_true, c_pred, target_names=label_names, zero_division=0))
    print("\nββ Mean Confidence Per Category ββ")
    # Most confident categories first, with a simple text bar (30 chars == 1.0).
    for name, conf in sorted(per_class_conf.items(), key=lambda x: -x[1]):
        bar = "β" * int(conf * 30)
        print(f" {name:<45} {conf:.4f} {bar}")


def evaluate(
    model,
    loader: DataLoader,
    device: torch.device,
    cfg: MambaShieldConfig,
    print_report: bool = False,
) -> Dict:
    """Evaluate the safety (binary) and category (multi-label) heads of *model*.

    The model is switched to eval mode and run without gradients. A safety
    prediction is positive when its logit is > 0. A category is predicted when
    its sigmoid probability exceeds ``cfg.category_threshold``.

    Args:
        model: callable as ``model(input_ids, attention_mask)`` and returning
            ``(safety_logit, category_logits)``.
        loader: yields dict batches with ``input_ids``, ``attention_mask``,
            ``safety_label`` and ``category_vec``.
        device: device that inputs are moved to before the forward pass.
        cfg: config providing ``category_threshold``.
        print_report: if True, print per-head reports to stdout.

    Returns:
        dict with ``safety_acc``, ``safety_f1``, ``category_macro_f1``,
        ``category_micro_f1`` (floats), ``per_class_f1`` and ``per_class_conf``
        (``{label_name: value}`` rounded to 4 places), and the raw numpy arrays
        ``safety_pred``, ``safety_true``, ``cat_pred`` and ``cat_true``.

    Raises:
        ValueError: if *loader* is empty, or if the category head's width does
            not match ``len(ID2LABEL)``.
    """
    model.eval()

    s_pred, s_true, c_pred, c_true, c_prob = _collect_outputs(
        model, loader, device, cfg.category_threshold
    )

    label_names = [ID2LABEL[i] for i in range(len(ID2LABEL))]
    if c_prob.ndim != 2 or c_prob.shape[1] != len(label_names):
        raise ValueError(
            f"category head produced outputs of shape {tuple(c_prob.shape)}; "
            f"expected (n_samples, {len(label_names)}) to match ID2LABEL"
        )

    # ---- Safety head (binary) ----
    safety_acc = float(accuracy_score(s_true, s_pred))
    safety_f1 = float(f1_score(s_true, s_pred, average="binary", zero_division=0))

    # ---- Category head (multi-label indicator matrices) ----
    cat_macro_f1 = float(f1_score(c_true, c_pred, average="macro", zero_division=0))
    cat_micro_f1 = float(f1_score(c_true, c_pred, average="micro", zero_division=0))
    per_label_f1 = f1_score(c_true, c_pred, average=None, zero_division=0)
    mean_conf = c_prob.mean(axis=0)

    per_class_f1 = {name: round(float(f1), 4) for name, f1 in zip(label_names, per_label_f1)}
    per_class_conf = {name: round(float(conf), 4) for name, conf in zip(label_names, mean_conf)}

    if print_report:
        _print_report(
            safety_acc, safety_f1, c_true, c_pred,
            cfg.category_threshold, label_names, per_class_conf,
        )

    return {
        "safety_acc": safety_acc,
        "safety_f1": safety_f1,
        "category_macro_f1": cat_macro_f1,
        "category_micro_f1": cat_micro_f1,
        "per_class_f1": per_class_f1,
        "per_class_conf": per_class_conf,
        # Raw arrays, e.g. for plot_confusion_matrix.
        "safety_pred": s_pred, "safety_true": s_true,
        "cat_pred": c_pred, "cat_true": c_true,
    }
|
|
|
|
def plot_confusion_matrix(preds, labels, save_path="confusion_matrix.png"):
    """Save a confusion-matrix heatmap of category predictions to *save_path*.

    Args:
        preds: predicted categories, either 1-D class indices or a 2-D
            (n_samples, n_classes) indicator/score array.
        labels: true categories, 1-D or 2-D like *preds*.
        save_path: path of the PNG written by matplotlib.

    A 2-D input is reduced to one class per sample with ``argmax``. A
    multi-label row therefore keeps only its first highest-valued column, and
    a row with no positive entry is counted as class 0.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    label_names = [ID2LABEL[i] for i in range(len(ID2LABEL))]

    preds = np.asarray(preds)
    labels = np.asarray(labels)
    # Reduce each input on its own so mixed 1-D / 2-D arguments also work.
    if preds.ndim == 2:
        preds = preds.argmax(axis=1)
    if labels.ndim == 2:
        labels = labels.argmax(axis=1)

    # Pin the label set so the matrix is always n_classes x n_classes and lines
    # up with the tick labels, even when some classes never occur in the data.
    cm = confusion_matrix(labels, preds, labels=list(range(len(label_names))))

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=label_names, yticklabels=label_names, ax=ax)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title("MambaShield β Category Confusion Matrix")
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        fig.tight_layout()
        fig.savefig(save_path, dpi=150)
    finally:
        # Release the figure; pyplot otherwise keeps it alive for the process.
        plt.close(fig)
    print(f"Confusion matrix saved β {save_path}")
|
|
|
|
| |
|
|
def main():
    """Command-line entry point: evaluate a saved checkpoint on the test split."""
    import os
    import sys

    # Make sibling modules (dataset, model) importable even when the script is
    # launched from a different working directory.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from dataset import build_dataloaders
    from model import MambaShield

    p = argparse.ArgumentParser(description="Evaluate a MambaShield checkpoint on the test split.")
    p.add_argument("--ckpt", required=True, help="path to a checkpoint saved by training")
    p.add_argument("--plot", action="store_true", help="also save a category confusion matrix")
    p.add_argument("--plot-path", default="confusion_matrix.png",
                   help="output path for the confusion matrix (with --plot)")
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # ckpt["cfg"] is expected to hold a pickled MambaShieldConfig object. The
    # weights_only=True default of torch.load in recent PyTorch (>= 2.6) refuses
    # to unpickle it. SECURITY: weights_only=False can execute arbitrary pickle
    # code, so only load checkpoints you trust.
    ckpt = torch.load(args.ckpt, map_location=device, weights_only=False)
    cfg: MambaShieldConfig = ckpt["cfg"]

    model = MambaShield(ckpt["vocab_size"], cfg).to(device)
    model.load_state_dict(ckpt["model_state"])

    _, _, test_loader, _ = build_dataloaders(cfg)
    metrics = evaluate(model, test_loader, device, cfg, print_report=True)
    print(f"\nSafety accuracy : {metrics['safety_acc']:.4f}")
    print(f"Category macro F1: {metrics['category_macro_f1']:.4f}")

    if args.plot:
        plot_confusion_matrix(metrics["cat_pred"], metrics["cat_true"], save_path=args.plot_path)


if __name__ == "__main__":
    main()
|
|