Spaces:
Sleeping
Sleeping
| """ | |
| Model evaluation utilities for emotion recognition. | |
| """ | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| from sklearn.metrics import ( | |
| classification_report, confusion_matrix, | |
| accuracy_score, precision_recall_fscore_support, | |
| roc_curve, auc | |
| ) | |
| from tensorflow.keras.models import Model | |
| import sys | |
| sys.path.append(str(Path(__file__).parent.parent.parent)) | |
| from src.config import EMOTION_CLASSES, NUM_CLASSES, MODELS_DIR | |
def evaluate_model(
    model: Model,
    test_generator,
    class_names: List[str] = EMOTION_CLASSES
) -> Dict:
    """
    Run a trained Keras model over a test generator and collect metrics.

    Args:
        model: Trained Keras model
        test_generator: Test data generator (must expose ``reset()`` and ``classes``)
        class_names: Ordered list of class names

    Returns:
        Dictionary with overall metrics, per-class metrics, and raw predictions
    """
    # Rewind the generator so prediction order lines up with .classes
    test_generator.reset()

    # Forward pass: class probabilities, then argmax to hard labels
    probs = model.predict(test_generator, verbose=1)
    predicted = np.argmax(probs, axis=1)
    actual = test_generator.classes

    # average=None keeps one precision/recall/F1/support entry per class
    prec, rec, f1_scores, counts = precision_recall_fscore_support(
        actual, predicted, average=None
    )
    per_class_metrics = {
        name: {
            "precision": float(prec[idx]),
            "recall": float(rec[idx]),
            "f1_score": float(f1_scores[idx]),
            "support": int(counts[idx])
        }
        for idx, name in enumerate(class_names)
    }

    # Aggregate summaries under both macro and support-weighted averaging
    p_macro, r_macro, f_macro, _ = precision_recall_fscore_support(
        actual, predicted, average='macro'
    )
    p_weighted, r_weighted, f_weighted, _ = precision_recall_fscore_support(
        actual, predicted, average='weighted'
    )

    return {
        "accuracy": float(accuracy_score(actual, predicted)),
        "macro_precision": float(p_macro),
        "macro_recall": float(r_macro),
        "macro_f1": float(f_macro),
        "weighted_precision": float(p_weighted),
        "weighted_recall": float(r_weighted),
        "weighted_f1": float(f_weighted),
        "per_class": per_class_metrics,
        "predictions": predicted.tolist(),
        "true_labels": actual.tolist(),
        "probabilities": probs.tolist()
    }
def generate_classification_report(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    class_names: List[str] = EMOTION_CLASSES,
    output_dict: bool = True
) -> Dict:
    """
    Build a scikit-learn classification report for the given predictions.

    Args:
        y_true: Ground-truth labels
        y_pred: Predicted labels
        class_names: Display names for each class index
        output_dict: Return a dict when True; otherwise print the text report

    Returns:
        Classification report (dict or formatted string)
    """
    report = classification_report(
        y_true, y_pred,
        target_names=class_names,
        output_dict=output_dict
    )
    # Text form is echoed to stdout for quick inspection
    if not output_dict:
        print(report)
    return report
def compute_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    normalize: bool = True
) -> np.ndarray:
    """
    Compute a confusion matrix, optionally row-normalized.

    Args:
        y_true: Ground-truth labels
        y_pred: Predicted labels
        normalize: When True, divide each row by its total so entries are
            per-class recall fractions

    Returns:
        Confusion matrix as an ndarray (float when normalized, int otherwise)
    """
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        # Guard against classes absent from y_true: a zero row total would
        # produce NaN via 0/0. Such rows are left as all-zeros instead.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0
        )
    return cm
def plot_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    class_names: List[str] = EMOTION_CLASSES,
    normalize: bool = True,
    figsize: Tuple[int, int] = (12, 10),
    cmap: str = 'Blues',
    save_path: Optional[Path] = None,
    title: str = "Confusion Matrix"
) -> plt.Figure:
    """
    Render the confusion matrix as an annotated seaborn heatmap.

    Args:
        y_true: Ground-truth labels
        y_pred: Predicted labels
        class_names: Axis tick labels for each class
        normalize: Row-normalize the matrix before plotting
        figsize: Figure size in inches
        cmap: Matplotlib colormap name
        save_path: If given, the figure is also written to this path
        title: Plot title

    Returns:
        The matplotlib figure containing the heatmap
    """
    matrix = compute_confusion_matrix(y_true, y_pred, normalize=normalize)

    fig, ax = plt.subplots(figsize=figsize)
    # Two-decimal cells when normalized, integer counts otherwise
    cell_fmt = '.2f' if normalize else 'd'
    sns.heatmap(
        matrix,
        annot=True,
        fmt=cell_fmt,
        cmap=cmap,
        ax=ax,
        xticklabels=class_names,
        yticklabels=class_names,
        square=True,
        cbar_kws={'shrink': 0.8}
    )

    ax.set_xlabel('Predicted Label', fontsize=12)
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_title(title, fontsize=14)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Confusion matrix saved to: {save_path}")

    return fig
def plot_training_history(
    history: Dict,
    metrics: Tuple[str, ...] = ('accuracy', 'loss'),
    figsize: Tuple[int, int] = (14, 5),
    save_path: Optional[Path] = None
) -> plt.Figure:
    """
    Plot training (and, when present, validation) curves per metric.

    Args:
        history: Training history dictionary (e.g. ``History.history``),
            mapping metric names to per-epoch value lists
        metrics: Metric names to plot, one subplot each. Default is an
            immutable tuple (a list default would be a shared mutable
            default argument); callers may still pass a list.
        figsize: Figure size in inches
        save_path: Optional path to save the figure

    Returns:
        The matplotlib figure with one axis per requested metric
    """
    num_metrics = len(metrics)
    fig, axes = plt.subplots(1, num_metrics, figsize=figsize)
    # plt.subplots returns a bare Axes (not an array) when there is one column
    if num_metrics == 1:
        axes = [axes]

    for ax, metric in zip(axes, metrics):
        # Metrics missing from history are skipped, leaving an empty axis
        if metric in history:
            epochs = range(1, len(history[metric]) + 1)
            ax.plot(epochs, history[metric], 'b-', label=f'Training {metric.capitalize()}')
            val_metric = f'val_{metric}'
            if val_metric in history:
                ax.plot(epochs, history[val_metric], 'r-', label=f'Validation {metric.capitalize()}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(metric.capitalize())
            ax.set_title(f'{metric.capitalize()} over Epochs')
            ax.legend()
            ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Training history plot saved to: {save_path}")

    return fig
def plot_per_class_metrics(
    results: Dict,
    figsize: Tuple[int, int] = (14, 6),
    save_path: Optional[Path] = None
) -> plt.Figure:
    """
    Draw grouped bars of precision, recall, and F1 for each class.

    Args:
        results: Evaluation results dictionary containing a ``per_class`` map
        figsize: Figure size in inches
        save_path: Optional path to save the figure

    Returns:
        The matplotlib figure with the grouped bar chart
    """
    per_class = results['per_class']
    classes = list(per_class.keys())

    # Pull each metric out into a parallel list, ordered like `classes`
    score_lists = {
        key: [per_class[c][key] for c in classes]
        for key in ('precision', 'recall', 'f1_score')
    }

    positions = np.arange(len(classes))
    width = 0.25

    fig, ax = plt.subplots(figsize=figsize)
    bar_groups = [
        ax.bar(positions - width, score_lists['precision'], width,
               label='Precision', color='#3498db'),
        ax.bar(positions, score_lists['recall'], width,
               label='Recall', color='#2ecc71'),
        ax.bar(positions + width, score_lists['f1_score'], width,
               label='F1-Score', color='#e74c3c'),
    ]

    ax.set_xlabel('Emotion Class')
    ax.set_ylabel('Score')
    ax.set_title('Per-Class Performance Metrics')
    ax.set_xticks(positions)
    ax.set_xticklabels(classes, rotation=45, ha='right')
    ax.legend()
    ax.set_ylim(0, 1.0)
    ax.grid(True, alpha=0.3, axis='y')

    # Annotate every bar with its value, just above the bar top
    for group in bar_groups:
        for bar in group:
            height = bar.get_height()
            ax.annotate(f'{height:.2f}',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=8)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Per-class metrics plot saved to: {save_path}")

    return fig
def compute_roc_curves(
    y_true: np.ndarray,
    y_proba: np.ndarray,
    class_names: List[str] = EMOTION_CLASSES
) -> Dict:
    """
    Compute one-vs-rest ROC curve data for every class.

    Args:
        y_true: True labels, either integer-encoded (1-D) or one-hot (2-D)
        y_proba: Per-class prediction probabilities, shape (n_samples, n_classes)
        class_names: Class names keying the result

    Returns:
        Dictionary mapping class name to its fpr/tpr lists and AUC value
    """
    # Integer labels are expanded to one-hot rows via an identity matrix
    if len(y_true.shape) == 1:
        y_true = np.eye(len(class_names))[y_true]

    roc_data = {}
    for col, class_name in enumerate(class_names):
        fpr, tpr, _ = roc_curve(y_true[:, col], y_proba[:, col])
        roc_data[class_name] = {
            'fpr': fpr.tolist(),
            'tpr': tpr.tolist(),
            'auc': float(auc(fpr, tpr))
        }
    return roc_data
def plot_roc_curves(
    roc_data: Dict,
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[Path] = None
) -> plt.Figure:
    """
    Plot per-class ROC curves plus the chance diagonal on one axis.

    Args:
        roc_data: Output of ``compute_roc_curves``
        figsize: Figure size in inches
        save_path: Optional path to save the figure

    Returns:
        The matplotlib figure containing the ROC plot
    """
    fig, ax = plt.subplots(figsize=figsize)

    # One evenly spaced Set2 color per class
    palette = plt.cm.Set2(np.linspace(0, 1, len(roc_data)))

    for (class_name, curve), line_color in zip(roc_data.items(), palette):
        label = f"{class_name} (AUC = {curve['auc']:.2f})"
        ax.plot(curve['fpr'], curve['tpr'], color=line_color, lw=2, label=label)

    # Diagonal = performance of a random classifier
    ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curves by Emotion Class')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"ROC curves saved to: {save_path}")

    return fig
def compare_models(
    model_results: Dict[str, Dict],
    save_path: Optional[Path] = None
) -> plt.Figure:
    """
    Draw a grouped bar chart comparing summary metrics across models.

    Args:
        model_results: Mapping of model name -> evaluation results dict
        save_path: Optional path to save the figure

    Returns:
        The matplotlib figure with the comparison chart
    """
    models = list(model_results.keys())
    metrics = ['accuracy', 'macro_precision', 'macro_recall', 'macro_f1']

    fig, ax = plt.subplots(figsize=(12, 6))
    positions = np.arange(len(models))
    width = 0.2

    for idx, metric in enumerate(metrics):
        # Missing metrics default to 0 so a partial results dict still plots
        heights = [model_results[name].get(metric, 0) for name in models]
        # Center the group of bars around each model's tick position
        shift = (idx - len(metrics) / 2 + 0.5) * width
        ax.bar(positions + shift, heights, width,
               label=metric.replace('_', ' ').title())

    ax.set_xlabel('Model')
    ax.set_ylabel('Score')
    ax.set_title('Model Comparison')
    ax.set_xticks(positions)
    ax.set_xticklabels(models)
    ax.legend()
    ax.set_ylim(0, 1.0)
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Model comparison saved to: {save_path}")

    return fig
if __name__ == "__main__":
    # Smoke check when run as a script: confirm the module imports cleanly
    # and show the configured emotion classes from src.config.
    print("Evaluation module loaded successfully.")
    print(f"Emotion classes: {EMOTION_CLASSES}")