mamba-shield / evaluate.py
satyamsaf3ai's picture
Fix BFloat16 numpy conversion error in evaluate
4cdbe43
"""
Evaluation for MambaShield dual-head (safety binary + category multi-label).
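evaluate() returns a dict with:
safety_acc : binary accuracy (safe/unsafe)
safety_f1 : binary F1
category_macro_f1 : macro-averaged F1 across the 11 categories
category_micro_f1 : micro-averaged F1
per_class_f1 : dict {label_name: f1}
per_class_conf : dict {label_name: mean sigmoid score over all samples}
plus the raw prediction/label arrays (safety_pred, safety_true, cat_pred, cat_true).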
Usage (standalone):
python evaluate.py --ckpt checkpoints/best_model.pt --plot
"""
import argparse
from typing import Dict
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import (
accuracy_score, classification_report,
f1_score, confusion_matrix,
)
from torch.utils.data import DataLoader
from tqdm import tqdm
from config import ID2LABEL, MambaShieldConfig
def evaluate(
    model,
    loader: DataLoader,
    device: torch.device,
    cfg: MambaShieldConfig,
    print_report: bool = False,
) -> Dict:
    """Run ``model`` over ``loader`` and compute safety and category metrics.

    The model is called as ``model(input_ids, attention_mask)`` and must return
    ``(safety_logit, cat_logits)``. A sample is predicted unsafe when its
    safety logit is > 0. A category is predicted when its sigmoid score exceeds
    ``cfg.category_threshold``.

    Args:
        model: dual-head model. Its train/eval mode is restored on return.
        loader: yields dict batches with ``input_ids``, ``attention_mask``,
            ``safety_label`` and ``category_vec``.
        device: device the model inputs are moved to.
        cfg: config providing ``category_threshold``.
        print_report: if True, print a human-readable report to stdout.

    Returns:
        dict with ``safety_acc``, ``safety_f1``, ``category_macro_f1``,
        ``category_micro_f1``, ``per_class_f1`` ({label: f1}),
        ``per_class_conf`` ({label: mean sigmoid score}), and the raw numpy
        arrays ``safety_pred``, ``safety_true``, ``cat_pred``, ``cat_true``,
        ``cat_probs``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    label_names = [ID2LABEL[i] for i in range(len(ID2LABEL))]

    all_safety_pred, all_safety_true = [], []
    all_cat_pred, all_cat_true = [], []  # binary indicator vectors (multi-label)
    all_cat_probs = []  # raw sigmoid scores, one row per sample

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(loader, desc="Evaluating", leave=False):
                ids = batch["input_ids"].to(device)
                mask = batch["attention_mask"].to(device)
                safety_logit, cat_logits = model(ids, mask)

                # reshape(-1) rather than squeeze(-1): squeezing a (1,) logit
                # from a final batch of size 1 gives a 0-d tensor, which
                # torch.cat below refuses to concatenate.
                s_pred = (safety_logit.reshape(-1) > 0).float().cpu()
                c_prob = torch.sigmoid(cat_logits).float().cpu()
                c_pred = (c_prob > cfg.category_threshold).float()

                # Labels are also cast to float32 on the CPU: NumPy has no
                # bfloat16, so .numpy() would fail on reduced-precision labels.
                all_safety_pred.append(s_pred)
                all_safety_true.append(batch["safety_label"].reshape(-1).float().cpu())
                all_cat_pred.append(c_pred)
                all_cat_true.append(batch["category_vec"].float().cpu())
                all_cat_probs.append(c_prob)
    finally:
        # Restore the caller's mode so evaluating mid-training does not leave
        # the model stuck in eval mode.
        model.train(was_training)

    if not all_safety_pred:
        raise ValueError("evaluate(): loader yielded no batches")

    s_pred = torch.cat(all_safety_pred).numpy()
    s_true = torch.cat(all_safety_true).numpy()
    c_pred = torch.cat(all_cat_pred).numpy()
    c_true = torch.cat(all_cat_true).numpy()
    c_prob = torch.cat(all_cat_probs).numpy()

    # ── Safety metrics ────────────────────────────────────────────────────
    safety_acc = accuracy_score(s_true, s_pred)
    safety_f1 = f1_score(s_true, s_pred, average="binary", zero_division=0)

    # ── Category multi-label metrics ──────────────────────────────────────
    cat_macro_f1 = f1_score(c_true, c_pred, average="macro", zero_division=0)
    cat_micro_f1 = f1_score(c_true, c_pred, average="micro", zero_division=0)
    per_label_f1 = f1_score(c_true, c_pred, average=None, zero_division=0)
    mean_conf = c_prob.mean(axis=0)  # mean sigmoid score per class

    per_class_f1 = {name: round(float(per_label_f1[i]), 4) for i, name in enumerate(label_names)}
    per_class_conf = {name: round(float(mean_conf[i]), 4) for i, name in enumerate(label_names)}

    if print_report:
        print("\n── Safety Head ──")
        print(f" accuracy={safety_acc:.4f} f1={safety_f1:.4f}")
        print(f"\n── Category Head (multi-label, threshold={cfg.category_threshold:.2f}) ──")
        print(classification_report(c_true, c_pred, target_names=label_names, zero_division=0))
        print("\n── Mean Confidence Per Category ──")
        for name, conf in sorted(per_class_conf.items(), key=lambda x: -x[1]):
            bar = "█" * int(conf * 30)
            print(f" {name:<45} {conf:.4f} {bar}")

    return {
        "safety_acc": safety_acc,
        "safety_f1": safety_f1,
        "category_macro_f1": cat_macro_f1,
        "category_micro_f1": cat_micro_f1,
        "per_class_f1": per_class_f1,
        "per_class_conf": per_class_conf,
        # raw arrays for confusion matrix / threshold tuning / further analysis
        "safety_pred": s_pred, "safety_true": s_true,
        "cat_pred": c_pred, "cat_true": c_true,
        "cat_probs": c_prob,
    }
def plot_confusion_matrix(preds, labels, save_path="confusion_matrix.png"):
    """Save a category confusion-matrix heatmap to ``save_path``.

    Args:
        preds: predicted categories, either class indices of shape (N,) or a
            2-D (N, C) multi-label matrix.
        labels: ground truth, in the same format as ``preds``.
        save_path: output image path.

    When ``preds`` is 2-D, both arrays are collapsed to one class per sample
    with ``argmax``. This is only a single-label approximation of the
    multi-label task: a row with no positive entry collapses to class 0.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    label_names = [ID2LABEL[i] for i in range(len(ID2LABEL))]
    if preds.ndim == 2:
        preds = preds.argmax(axis=1)
        labels = labels.argmax(axis=1)

    # Pin the class set so the matrix is always C x C and lines up with the
    # tick labels, even when some classes never occur in this split.
    cm = confusion_matrix(labels, preds, labels=list(range(len(label_names))))

    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=label_names, yticklabels=label_names, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("MambaShield — Category Confusion Matrix")
    plt.xticks(rotation=45, ha="right")
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)
    print(f"Confusion matrix saved → {save_path}")
# ── Standalone ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import os
    import sys

    # Make sibling modules importable regardless of the current directory.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from config import MambaShieldConfig
    from dataset import build_dataloaders
    from model import MambaShield

    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", required=True, help="path to the model checkpoint")
    p.add_argument("--plot", action="store_true", help="also save a category confusion matrix")
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The checkpoint holds a pickled MambaShieldConfig under "cfg". Since
    # PyTorch 2.6, torch.load defaults to weights_only=True and rejects such
    # objects, so opt out explicitly (the parameter exists since PyTorch 1.13).
    # SECURITY: weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(args.ckpt, map_location=device, weights_only=False)
    cfg: MambaShieldConfig = ckpt["cfg"]
    model = MambaShield(ckpt["vocab_size"], cfg).to(device)
    model.load_state_dict(ckpt["model_state"])

    _, _, test_loader, _ = build_dataloaders(cfg)
    metrics = evaluate(model, test_loader, device, cfg, print_report=True)
    print(f"\nSafety accuracy : {metrics['safety_acc']:.4f}")
    print(f"Category macro F1: {metrics['category_macro_f1']:.4f}")
    if args.plot:
        plot_confusion_matrix(metrics["cat_pred"], metrics["cat_true"])