"""Pure-python evaluation metrics for OpenSOC.

Exposes:

* `confusion_matrix(predictions, truths)` — 5x5 dict-of-dicts, indexed
  ``cm[truth][prediction]``.
* `per_class_f1(cm)` — macro F1 plus per-class precision/recall/F1.
* `dismiss_on_malicious_rate(predictions, truths)` — the cardinal SOC metric.
  This is what we publish in the headline plot.
* `over_react_rate(predictions, truths)` — how often the model quarantines or
  blocks on a benign incident.
* `accuracy(predictions, truths)` — plain exact-match accuracy.

We deliberately don't pull in scikit-learn — keeping eval dependency-free makes
it easy to run inside the OpenEnv container and from a Hugging Face Space build
log.

Every function pairs ``predictions`` with ``truths`` via ``zip``, so if the two
iterables differ in length the surplus items of the longer one are ignored.
"""

from __future__ import annotations

from typing import Dict, Iterable, List, Tuple

from schema import CONTAINMENT_ACTIONS, TriageAction

# String values of every triage action, in enum declaration order. This fixes
# the row/column order of the confusion matrix and the classes averaged by
# `per_class_f1`.
ALL_ACTIONS: List[str] = [a.value for a in TriageAction]


def confusion_matrix(predictions: Iterable[str], truths: Iterable[str]) -> Dict[str, Dict[str, int]]:
    """Count (ground truth, prediction) pairs.

    Args:
        predictions: Predicted action strings.
        truths: Ground-truth action strings, aligned with ``predictions``.

    Returns:
        ``cm[truth][prediction] -> count``. Every action in ``ALL_ACTIONS`` is
        present as a row and as a column, zero-filled. Labels outside
        ``ALL_ACTIONS`` are kept rather than dropped. An unknown truth adds a
        new zero-filled row. An unknown prediction adds a key to the row it
        occurred in.
    """
    cm: Dict[str, Dict[str, int]] = {gt: {p: 0 for p in ALL_ACTIONS} for gt in ALL_ACTIONS}
    for p, gt in zip(predictions, truths):
        row = cm.get(gt)
        if row is None:
            row = cm[gt] = {a: 0 for a in ALL_ACTIONS}
        row[p] = row.get(p, 0) + 1
    return cm


def per_class_f1(cm: Dict[str, Dict[str, int]]) -> Tuple[float, Dict[str, Dict[str, float]]]:
    """Compute per-class precision/recall/F1 and their unweighted (macro) mean.

    Args:
        cm: Confusion matrix shaped like the output of `confusion_matrix`,
            i.e. ``cm[truth][prediction] -> count``.

    Returns:
        ``(macro_f1, per_class)``. ``per_class[action]`` holds ``precision``,
        ``recall``, ``f1`` and ``support`` (number of incidents whose ground
        truth is ``action``). Only classes in ``ALL_ACTIONS`` are scored and
        averaged. Undefined ratios (zero denominator) are reported as 0.0.
    """
    per_class: Dict[str, Dict[str, float]] = {}
    for cls in ALL_ACTIONS:
        row = cm.get(cls, {})
        tp = row.get(cls, 0)
        # Use the full row and all rows, not only labels in ALL_ACTIONS. An
        # out-of-vocabulary prediction on a `cls` incident is still a miss
        # (FN). A `cls` prediction on an unknown-truth row is still a false
        # alarm (FP). Restricting the sums would silently inflate recall.
        fn = sum(row.values()) - tp
        fp = sum(r.get(cls, 0) for gt, r in cm.items() if gt != cls)
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
        per_class[cls] = {"precision": precision, "recall": recall, "f1": f1, "support": tp + fn}
    macro_f1 = sum(m["f1"] for m in per_class.values()) / len(per_class) if per_class else 0.0
    return macro_f1, per_class


def dismiss_on_malicious_rate(predictions: Iterable[str], truths: Iterable[str]) -> float:
    """Fraction of malicious incidents the model wrongly dismissed.

    An incident counts as malicious when its ground truth is anything other
    than ``TriageAction.DISMISS``.

    Returns:
        ``dismissed / malicious_total``, or 0.0 when no malicious incidents
        are present.
    """
    # NOTE(review): MONITOR ground truths count as "malicious" here, while
    # `over_react_rate` treats them as benign-or-low. Confirm the asymmetry
    # is intended.
    dismiss = TriageAction.DISMISS.value
    malicious_total = 0
    dismissed = 0
    for p, gt in zip(predictions, truths):
        if gt == dismiss:
            continue
        malicious_total += 1
        if p == dismiss:
            dismissed += 1
    return dismissed / malicious_total if malicious_total else 0.0


def over_react_rate(predictions: Iterable[str], truths: Iterable[str]) -> float:
    """Fraction of benign-or-low incidents the model over-reacted on.

    Benign-or-low means the ground truth is ``DISMISS`` or ``MONITOR``.
    Over-reacting means predicting any action in ``CONTAINMENT_ACTIONS``.

    Returns:
        ``over_reacted / benign_total``, or 0.0 when no benign-or-low
        incidents are present.
    """
    benign = (TriageAction.DISMISS.value, TriageAction.MONITOR.value)
    containment = {a.value for a in CONTAINMENT_ACTIONS}
    benign_total = 0
    over_reacted = 0
    for p, gt in zip(predictions, truths):
        if gt not in benign:
            continue
        benign_total += 1
        if p in containment:
            over_reacted += 1
    return over_reacted / benign_total if benign_total else 0.0


def accuracy(predictions: Iterable[str], truths: Iterable[str]) -> float:
    """Fraction of predictions exactly equal to their ground truth.

    Returns 0.0 for empty input.
    """
    correct = 0
    n = 0
    for p, gt in zip(predictions, truths):
        n += 1
        if p == gt:
            correct += 1
    return correct / n if n else 0.0


__all__ = [
    "ALL_ACTIONS",
    "confusion_matrix",
    "per_class_f1",
    "dismiss_on_malicious_rate",
    "over_react_rate",
    "accuracy",
]