""" Model evaluation utilities for emotion recognition. """ import numpy as np import matplotlib.pyplot as plt import seaborn as sns from pathlib import Path from typing import Dict, List, Optional, Tuple from sklearn.metrics import ( classification_report, confusion_matrix, accuracy_score, precision_recall_fscore_support, roc_curve, auc ) from tensorflow.keras.models import Model import sys sys.path.append(str(Path(__file__).parent.parent.parent)) from src.config import EMOTION_CLASSES, NUM_CLASSES, MODELS_DIR def evaluate_model( model: Model, test_generator, class_names: List[str] = EMOTION_CLASSES ) -> Dict: """ Evaluate a trained model on test data. Args: model: Trained Keras model test_generator: Test data generator class_names: List of class names Returns: Dictionary with evaluation metrics """ # Reset generator to start test_generator.reset() # Get predictions predictions = model.predict(test_generator, verbose=1) y_pred = np.argmax(predictions, axis=1) y_true = test_generator.classes # Calculate metrics accuracy = accuracy_score(y_true, y_pred) precision, recall, f1, support = precision_recall_fscore_support( y_true, y_pred, average=None ) # Per-class metrics per_class_metrics = {} for i, class_name in enumerate(class_names): per_class_metrics[class_name] = { "precision": float(precision[i]), "recall": float(recall[i]), "f1_score": float(f1[i]), "support": int(support[i]) } # Overall metrics precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support( y_true, y_pred, average='macro' ) precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support( y_true, y_pred, average='weighted' ) results = { "accuracy": float(accuracy), "macro_precision": float(precision_macro), "macro_recall": float(recall_macro), "macro_f1": float(f1_macro), "weighted_precision": float(precision_weighted), "weighted_recall": float(recall_weighted), "weighted_f1": float(f1_weighted), "per_class": per_class_metrics, "predictions": y_pred.tolist(), "true_labels": y_true.tolist(), "probabilities": predictions.tolist() } return results def generate_classification_report( y_true: np.ndarray, y_pred: np.ndarray, class_names: List[str] = EMOTION_CLASSES, output_dict: bool = True ) -> Dict: """ Generate a classification report. Args: y_true: True labels y_pred: Predicted labels class_names: List of class names output_dict: Whether to return as dictionary Returns: Classification report """ report = classification_report( y_true, y_pred, target_names=class_names, output_dict=output_dict ) if not output_dict: print(report) return report def compute_confusion_matrix( y_true: np.ndarray, y_pred: np.ndarray, normalize: bool = True ) -> np.ndarray: """ Compute confusion matrix. Args: y_true: True labels y_pred: Predicted labels normalize: Whether to normalize the matrix Returns: Confusion matrix """ cm = confusion_matrix(y_true, y_pred) if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] return cm def plot_confusion_matrix( y_true: np.ndarray, y_pred: np.ndarray, class_names: List[str] = EMOTION_CLASSES, normalize: bool = True, figsize: Tuple[int, int] = (12, 10), cmap: str = 'Blues', save_path: Optional[Path] = None, title: str = "Confusion Matrix" ) -> plt.Figure: """ Plot confusion matrix as a heatmap. Args: y_true: True labels y_pred: Predicted labels class_names: List of class names normalize: Whether to normalize figsize: Figure size cmap: Colormap save_path: Optional path to save the figure title: Plot title Returns: Matplotlib figure """ cm = compute_confusion_matrix(y_true, y_pred, normalize=normalize) fig, ax = plt.subplots(figsize=figsize) sns.heatmap( cm, annot=True, fmt='.2f' if normalize else 'd', cmap=cmap, ax=ax, xticklabels=class_names, yticklabels=class_names, square=True, cbar_kws={'shrink': 0.8} ) ax.set_xlabel('Predicted Label', fontsize=12) ax.set_ylabel('True Label', fontsize=12) ax.set_title(title, fontsize=14) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"Confusion matrix saved to: {save_path}") return fig def plot_training_history( history: Dict, metrics: List[str] = ['accuracy', 'loss'], figsize: Tuple[int, int] = (14, 5), save_path: Optional[Path] = None ) -> plt.Figure: """ Plot training history curves. Args: history: Training history dictionary metrics: Metrics to plot figsize: Figure size save_path: Optional path to save the figure Returns: Matplotlib figure """ num_metrics = len(metrics) fig, axes = plt.subplots(1, num_metrics, figsize=figsize) if num_metrics == 1: axes = [axes] for ax, metric in zip(axes, metrics): if metric in history: epochs = range(1, len(history[metric]) + 1) ax.plot(epochs, history[metric], 'b-', label=f'Training {metric.capitalize()}') val_metric = f'val_{metric}' if val_metric in history: ax.plot(epochs, history[val_metric], 'r-', label=f'Validation {metric.capitalize()}') ax.set_xlabel('Epoch') ax.set_ylabel(metric.capitalize()) ax.set_title(f'{metric.capitalize()} over Epochs') ax.legend() ax.grid(True, alpha=0.3) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"Training history plot saved to: {save_path}") return fig def plot_per_class_metrics( results: Dict, figsize: Tuple[int, int] = (14, 6), save_path: Optional[Path] = None ) -> plt.Figure: """ Plot per-class precision, recall, and F1 scores. Args: results: Evaluation results dictionary figsize: Figure size save_path: Optional path to save Returns: Matplotlib figure """ per_class = results['per_class'] classes = list(per_class.keys()) precision = [per_class[c]['precision'] for c in classes] recall = [per_class[c]['recall'] for c in classes] f1 = [per_class[c]['f1_score'] for c in classes] x = np.arange(len(classes)) width = 0.25 fig, ax = plt.subplots(figsize=figsize) bars1 = ax.bar(x - width, precision, width, label='Precision', color='#3498db') bars2 = ax.bar(x, recall, width, label='Recall', color='#2ecc71') bars3 = ax.bar(x + width, f1, width, label='F1-Score', color='#e74c3c') ax.set_xlabel('Emotion Class') ax.set_ylabel('Score') ax.set_title('Per-Class Performance Metrics') ax.set_xticks(x) ax.set_xticklabels(classes, rotation=45, ha='right') ax.legend() ax.set_ylim(0, 1.0) ax.grid(True, alpha=0.3, axis='y') # Add value labels for bars in [bars1, bars2, bars3]: for bar in bars: height = bar.get_height() ax.annotate(f'{height:.2f}', xy=(bar.get_x() + bar.get_width() / 2, height), xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=8) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"Per-class metrics plot saved to: {save_path}") return fig def compute_roc_curves( y_true: np.ndarray, y_proba: np.ndarray, class_names: List[str] = EMOTION_CLASSES ) -> Dict: """ Compute ROC curves for each class. Args: y_true: True labels (one-hot encoded) y_proba: Prediction probabilities class_names: List of class names Returns: Dictionary with ROC curve data """ # Convert to one-hot if needed if len(y_true.shape) == 1: y_true_onehot = np.zeros((len(y_true), len(class_names))) y_true_onehot[np.arange(len(y_true)), y_true] = 1 y_true = y_true_onehot roc_data = {} for i, class_name in enumerate(class_names): fpr, tpr, thresholds = roc_curve(y_true[:, i], y_proba[:, i]) roc_auc = auc(fpr, tpr) roc_data[class_name] = { 'fpr': fpr.tolist(), 'tpr': tpr.tolist(), 'auc': float(roc_auc) } return roc_data def plot_roc_curves( roc_data: Dict, figsize: Tuple[int, int] = (10, 8), save_path: Optional[Path] = None ) -> plt.Figure: """ Plot ROC curves for all classes. Args: roc_data: ROC curve data from compute_roc_curves figsize: Figure size save_path: Optional save path Returns: Matplotlib figure """ fig, ax = plt.subplots(figsize=figsize) colors = plt.cm.Set2(np.linspace(0, 1, len(roc_data))) for (class_name, data), color in zip(roc_data.items(), colors): ax.plot( data['fpr'], data['tpr'], color=color, lw=2, label=f"{class_name} (AUC = {data['auc']:.2f})" ) ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random') ax.set_xlim([0.0, 1.0]) ax.set_ylim([0.0, 1.05]) ax.set_xlabel('False Positive Rate') ax.set_ylabel('True Positive Rate') ax.set_title('ROC Curves by Emotion Class') ax.legend(loc='lower right') ax.grid(True, alpha=0.3) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"ROC curves saved to: {save_path}") return fig def compare_models( model_results: Dict[str, Dict], save_path: Optional[Path] = None ) -> plt.Figure: """ Compare multiple models. Args: model_results: Dictionary of model_name -> evaluation results save_path: Optional save path Returns: Matplotlib figure """ models = list(model_results.keys()) metrics = ['accuracy', 'macro_precision', 'macro_recall', 'macro_f1'] fig, ax = plt.subplots(figsize=(12, 6)) x = np.arange(len(models)) width = 0.2 for i, metric in enumerate(metrics): values = [model_results[m].get(metric, 0) for m in models] offset = (i - len(metrics)/2 + 0.5) * width bars = ax.bar(x + offset, values, width, label=metric.replace('_', ' ').title()) ax.set_xlabel('Model') ax.set_ylabel('Score') ax.set_title('Model Comparison') ax.set_xticks(x) ax.set_xticklabels(models) ax.legend() ax.set_ylim(0, 1.0) ax.grid(True, alpha=0.3, axis='y') plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"Model comparison saved to: {save_path}") return fig if __name__ == "__main__": # Example usage print("Evaluation module loaded successfully.") print(f"Emotion classes: {EMOTION_CLASSES}")