Bachstelze
readd keras models
73f28de
"""Evaluation helpers (confusion matrix, metrics tables, plots)."""
from __future__ import annotations
from typing import Iterable
import numpy as np
import tensorflow as tf
from sklearn.metrics import (
accuracy_score,
confusion_matrix,
precision_score,
recall_score,
roc_auc_score,
)
# Canonical metric names, in display order: the keys produced by
# ``metrics_from_predictions`` and the column order used by ``metrics_table``.
METRIC_KEYS = "tp fp tn fn accuracy precision recall auc".split()
def predict_proba(model: tf.keras.Model, X: np.ndarray) -> np.ndarray:
    """Run ``model`` on ``X`` without console output and return a flat 1-D array.

    ``model.predict`` is called with ``verbose=0`` and its result is flattened
    with ``reshape(-1)``.

    NOTE(review): this presumably assumes one output unit per sample (e.g. an
    ``(n, 1)`` sigmoid head) — confirm. A model with ``k > 1`` outputs per
    sample would yield ``n * k`` values here, not one score per sample.
    """
    raw_scores = model.predict(X, verbose=0)
    flat_scores = raw_scores.reshape(-1)
    return flat_scores
def metrics_from_predictions(
    y_true: np.ndarray, y_proba: np.ndarray, threshold: float = 0.5
) -> dict[str, float]:
    """Compute binary-classification metrics for scores cut at ``threshold``.

    A sample is predicted positive (``1``) when its score is ``>= threshold``.

    Args:
        y_true: Ground-truth labels, evaluated against the label set ``[0, 1]``.
            Any array-like (list, column vector, ...) is accepted and flattened.
        y_proba: Positive-class scores/probabilities, one per sample. Any
            array-like is accepted and flattened, so a raw ``(n, 1)`` model
            output can be passed directly.
        threshold: Decision threshold applied to ``y_proba``.

    Returns:
        Dict with keys ``tp``, ``fp``, ``tn``, ``fn`` (``int`` counts) and
        ``accuracy``, ``precision``, ``recall``, ``auc`` (Python ``float``).
        ``precision``/``recall`` are ``0.0`` when their denominator is zero.
        ``auc`` is ``nan`` when ``y_true`` contains only one class, because
        ROC AUC is undefined in that case.
    """
    # Normalise inputs: a plain Python list would raise TypeError on ``>=``,
    # and flattening also accepts column-vector inputs.
    y_true = np.asarray(y_true).reshape(-1)
    y_proba = np.asarray(y_proba, dtype=float).reshape(-1)
    y_pred = (y_proba >= threshold).astype(int)

    # ``labels=[0, 1]`` pins the 2x2 layout [[tn, fp], [fn, tp]] even when a
    # class is absent from both y_true and y_pred.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    # AUC is threshold-free: it is computed from the raw scores.
    auc = float("nan")
    if np.unique(y_true).size > 1:
        auc = float(roc_auc_score(y_true, y_proba))

    return {
        "tp": int(tp), "fp": int(fp), "tn": int(tn), "fn": int(fn),
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "precision": float(precision_score(y_true, y_pred, zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, zero_division=0)),
        "auc": auc,
    }
def confusion(y_true: np.ndarray, y_proba: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Return the 2x2 confusion matrix for scores thresholded at ``threshold``.

    Rows are true labels and columns are predicted labels, both ordered
    ``[0, 1]``, so the layout is ``[[tn, fp], [fn, tp]]``. A sample is
    predicted ``1`` when its score is ``>= threshold``. Array-like inputs
    (lists, column vectors) are accepted and flattened.
    """
    # A plain list would raise TypeError on ``>=``; normalise first.
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = (np.asarray(y_proba, dtype=float).reshape(-1) >= threshold).astype(int)
    return confusion_matrix(y_true, y_pred, labels=[0, 1])
def plot_confusion(cm: np.ndarray, title: str, ax=None, labels: Iterable[str] | None = None):
    """Draw a square confusion matrix as an annotated heat map.

    Args:
        cm: Square matrix with true classes on rows and predicted classes on
            columns. Integer counts and normalised (fractional) matrices are
            both supported.
        title: Axes title.
        ax: Matplotlib axes to draw into; a new figure of size ``(3, 3)`` is
            created when omitted.
        labels: Tick labels, one per class. Defaults to ``["bad", "good"]``
            for a 2x2 matrix and to the class indices otherwise.

    Returns:
        The axes that were drawn on.
    """
    import matplotlib.pyplot as plt  # imported lazily so the module loads without matplotlib

    cm = np.asarray(cm)
    n_classes = cm.shape[0]
    if labels is None:
        labels = ["bad", "good"] if n_classes == 2 else [str(i) for i in range(n_classes)]
    labels = list(labels)  # used for both axes, so an iterator must be materialised
    if ax is None:
        _, ax = plt.subplots(figsize=(3, 3))

    ticks = list(range(n_classes))
    ax.imshow(cm, cmap="Blues")
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)
    ax.set_xlabel("predicted")
    ax.set_ylabel("true")
    ax.set_title(title)

    # Cells above half the maximum get white text so they stay readable on the
    # darker end of the colour map; computed once instead of per cell.
    contrast_cutoff = cm.max() / 2
    for i in range(n_classes):
        for j in range(n_classes):
            value = cm[i, j]
            # Whole numbers print as ints (as before). Fractions from a
            # normalised matrix would otherwise all truncate to 0.
            text = str(int(value)) if float(value).is_integer() else f"{value:.2f}"
            ax.text(j, i, text, ha="center", va="center",
                    color="white" if value > contrast_cutoff else "black")
    return ax
def metrics_table(rows: Iterable[dict], index: Iterable[str]):
    """Tabulate metric dicts into a DataFrame, one row per ``index`` entry.

    Only the columns named in ``METRIC_KEYS`` are kept, in that order; any
    other keys present in ``rows`` are dropped. Values are rounded to 4
    decimal places.
    """
    import pandas as pd

    frame = pd.DataFrame(data=list(rows), index=list(index))
    ordered = [key for key in METRIC_KEYS if key in frame.columns]
    return frame.loc[:, ordered].round(4)