import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    roc_curve,
    precision_recall_curve,
    auc,
    average_precision_score,
    cohen_kappa_score,
)
from sklearn.preprocessing import label_binarize


class ClassificationEvaluator:
    """
    A class to evaluate and visualize classification model performance.

    This class provides methods to compute various classification metrics
    and generate visualizations for model evaluation.
    """

    def __init__(self, class_names: list):
        """
        Initialize the evaluator with class names.

        Parameters:
        - class_names: list of class names
        """
        self.class_names = class_names
        self.num_classes = len(class_names)

    def _ensure_numpy(self, data):
        """Convert a tensor to a NumPy array if needed."""
        if torch.is_tensor(data):
            return data.cpu().numpy()
        return np.array(data)

    def evaluate_model(self, model: nn.Module, test_loader: DataLoader) -> dict:
        """
        Evaluate a trained model on the test dataset.

        Parameters:
        - model: PyTorch model to evaluate
        - test_loader: DataLoader containing test data

        Returns:
        - results: Dictionary containing evaluation metrics
        """
        model.eval()
        device = next(model.parameters()).device

        all_labels = []
        all_preds = []
        all_scores = []

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())
                all_scores.append(
                    torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()
                )

        all_scores = np.vstack(all_scores)

        # Compute metrics
        results = self.compute_metrics(all_labels, all_preds, all_scores)
        return results

    def compute_metrics(
        self, y_true, y_pred, y_scores, model_name: str = ""
    ) -> dict:
        """
        Compute comprehensive classification metrics.

        Parameters:
        - y_true: true labels
        - y_pred: predicted labels
        - y_scores: predicted probability scores
        - model_name: name of the model (optional)

        Returns:
        - Dictionary containing all metrics
        """
        # Ensure numpy arrays
        y_true = self._ensure_numpy(y_true)
        y_pred = self._ensure_numpy(y_pred)
        y_scores = self._ensure_numpy(y_scores)

        # Print a header when a model name is supplied
        if model_name:
            print(f"Evaluating: {model_name}")

        # Calculate accuracy
        accuracy = accuracy_score(y_true, y_pred)
        print(f"Overall Accuracy: {accuracy:.4f}")

        # Calculate and display Cohen's Kappa
        kappa = cohen_kappa_score(y_true, y_pred)
        print(f"Cohen's Kappa Score: {kappa:.4f}")

        # Generate classification report
        report = classification_report(
            y_true, y_pred, target_names=self.class_names, output_dict=True
        )

        # Print formatted classification report
        print("\nClassification Report:")
        print(classification_report(y_true, y_pred, target_names=self.class_names))

        # Calculate ROC curves and AUC for each class
        print("\nCalculating ROC curves...")
        roc_auc_dict = self.plot_roc_curves(y_true, y_scores)

        # Calculate PR curves and AUC for each class
        print("\nCalculating PR curves...")
        pr_auc_dict = self.plot_pr_curves(y_true, y_scores)

        # Plot confusion matrix
        print("\nGenerating confusion matrix...")
        self.plot_confusion_matrix(y_true, y_pred)

        # Plot per-class accuracy
        print("\nCalculating per-class accuracy...")
        self.plot_per_class_accuracy(y_true, y_pred)

        # Return metrics dictionary
        return {
            "accuracy": accuracy,
            "report": report,
            "roc_auc": roc_auc_dict,
            "pr_auc": pr_auc_dict,
            "kappa": kappa,
        }

    def plot_roc_curves(self, y_true, y_scores) -> dict:
        """
        Plot ROC curves for multi-class classification.
        Parameters:
        - y_true: true labels
        - y_scores: predicted probability scores

        Returns:
        - Dictionary containing AUC values for each class
        """
        y_true = self._ensure_numpy(y_true)
        y_scores = self._ensure_numpy(y_scores)

        # Binarize the labels for one-vs-rest ROC calculation
        y_true_bin = label_binarize(y_true, classes=range(self.num_classes))

        # label_binarize returns a single column for binary problems; expand
        # it to two columns so the one-vs-rest indexing below works there too
        if y_true_bin.shape[1] == 1:
            y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])

        # Compute ROC curve and ROC area for each class
        fpr = {}
        tpr = {}
        roc_auc = {}

        plt.figure(figsize=(12, 8))

        for i in range(self.num_classes):
            fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_scores[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])
            plt.plot(
                fpr[i],
                tpr[i],
                lw=2,
                label=f"{self.class_names[i]} (area = {roc_auc[i]:.2f})",
            )

        # Plot the diagonal (random classifier)
        plt.plot([0, 1], [0, 1], "k--", lw=2)

        # Calculate and plot micro-average ROC curve
        fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), y_scores.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
        plt.plot(
            fpr["micro"],
            tpr["micro"],
            label=f'Micro-average (area = {roc_auc["micro"]:.2f})',
            lw=2,
            linestyle=":",
            color="deeppink",
        )

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("ROC Curves")
        plt.legend(loc="lower right")
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

        return roc_auc

    def plot_pr_curves(self, y_true, y_scores) -> dict:
        """
        Plot Precision-Recall curves for multi-class classification.

        Parameters:
        - y_true: true labels
        - y_scores: predicted probability scores

        Returns:
        - Dictionary containing average precision values for each class
        """
        y_true = self._ensure_numpy(y_true)
        y_scores = self._ensure_numpy(y_scores)

        # Binarize the labels
        y_true_bin = label_binarize(y_true, classes=range(self.num_classes))

        # Same binary-problem expansion as in plot_roc_curves
        if y_true_bin.shape[1] == 1:
            y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])

        # Compute PR curve and average precision for each class
        precision = {}
        recall = {}
        avg_precision = {}

        plt.figure(figsize=(12, 8))

        for i in range(self.num_classes):
            precision[i], recall[i], _ = precision_recall_curve(
                y_true_bin[:, i], y_scores[:, i]
            )
            avg_precision[i] = average_precision_score(
                y_true_bin[:, i], y_scores[:, i]
            )
            plt.plot(
                recall[i],
                precision[i],
                lw=2,
                label=f"{self.class_names[i]} (AP = {avg_precision[i]:.2f})",
            )

        # Calculate and plot micro-average PR curve
        precision["micro"], recall["micro"], _ = precision_recall_curve(
            y_true_bin.ravel(), y_scores.ravel()
        )
        avg_precision["micro"] = average_precision_score(
            y_true_bin.ravel(), y_scores.ravel()
        )
        plt.plot(
            recall["micro"],
            precision["micro"],
            label=f'Micro-average (AP = {avg_precision["micro"]:.2f})',
            lw=2,
            linestyle=":",
            color="deeppink",
        )

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.title("Precision-Recall Curves")
        plt.legend(loc="best")
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

        return avg_precision

    def plot_confusion_matrix(self, y_true, y_pred) -> None:
        """
        Plot confusion matrix.
        Parameters:
        - y_true: true labels
        - y_pred: predicted labels
        """
        y_true = self._ensure_numpy(y_true)
        y_pred = self._ensure_numpy(y_pred)

        # Get unique values in both arrays
        unique_values = np.unique(np.concatenate([y_true, y_pred]))
        print(f"Unique values in confusion matrix data: {unique_values}")

        # Create the confusion matrix with explicit labels
        cm = confusion_matrix(y_true, y_pred, labels=range(self.num_classes))

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=self.class_names,
            yticklabels=self.class_names,
        )
        plt.title("Confusion Matrix")
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.tight_layout()
        plt.show()

    def plot_per_class_accuracy(self, y_true, y_pred) -> np.ndarray:
        """
        Plot per-class accuracy.

        Parameters:
        - y_true: true labels
        - y_pred: predicted labels

        Returns:
        - Array of per-class accuracy values
        """
        y_true = self._ensure_numpy(y_true)
        y_pred = self._ensure_numpy(y_pred)

        # Create the confusion matrix with explicit labels
        cm = confusion_matrix(y_true, y_pred, labels=range(self.num_classes))

        # Calculate per-class accuracy: diagonal count divided by the row
        # total for each true class (i.e., per-class recall)
        per_class_accuracy = np.zeros(self.num_classes)
        for i in range(self.num_classes):
            if i < cm.shape[0] and np.sum(cm[i, :]) > 0:
                per_class_accuracy[i] = cm[i, i] / np.sum(cm[i, :])

        # Create the bar plot
        plt.figure(figsize=(14, 7))
        plt.bar(range(self.num_classes), per_class_accuracy, color="skyblue")
        plt.xticks(range(self.num_classes), self.class_names, rotation=45, ha="right")
        plt.xlabel("Classes")
        plt.ylabel("Accuracy")
        plt.title("Per-Class Accuracy")
        plt.tight_layout()
        plt.show()

        return per_class_accuracy

    def plot_training_history(
        self, train_losses, val_losses, train_accs, val_accs
    ) -> None:
        """
        Plot accuracy and loss curves from training history.

        Parameters:
        - train_losses: list of training losses
        - val_losses: list of validation losses
        - train_accs: list of training accuracies
        - val_accs: list of validation accuracies
        """
        plt.figure(figsize=(12, 5))

        # Accuracy curve
        plt.subplot(1, 2, 1)
        plt.plot(train_accs, label="Train Accuracy")
        plt.plot(val_accs, label="Validation Accuracy")
        plt.xlabel("Epochs")
        plt.ylabel("Accuracy")
        plt.title("Accuracy Curve")
        plt.legend()
        plt.grid(True)

        # Loss curve
        plt.subplot(1, 2, 2)
        plt.plot(train_losses, label="Train Loss")
        plt.plot(val_losses, label="Validation Loss")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.title("Loss Curve")
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.show()
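

# Minimal usage sketch. Everything below is a hypothetical stand-in, not part
# of the evaluator: any nn.Module that outputs logits of shape
# (batch, num_classes) and any DataLoader yielding (inputs, labels) batches
# will work the same way.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    torch.manual_seed(0)

    # Tiny synthetic 3-class dataset and an untrained linear classifier,
    # just enough to exercise every metric and plot defined above
    X = torch.randn(120, 8)
    y = torch.randint(0, 3, (120,))
    loader = DataLoader(TensorDataset(X, y), batch_size=16)
    model = nn.Linear(8, 3)

    evaluator = ClassificationEvaluator(class_names=["cat", "dog", "bird"])
    results = evaluator.evaluate_model(model, loader)
    print(f"Per-class ROC AUC: {results['roc_auc']}")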