""" Evaluation Module for CIFAR-10 Image Classification. Provides functions to: - Evaluate model accuracy on test data - Generate classification reports (precision, recall, F1) - Plot and save confusion matrices - Plot training history curves """ import os import numpy as np import matplotlib matplotlib.use('Agg') # Non-interactive backend for saving plots import matplotlib.pyplot as plt import seaborn as sns from sklearn.metrics import ( classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score ) from .data_loader import CLASS_NAMES def evaluate_model(model, x_test, y_test): """ Evaluate a trained model on the test set. Args: model: Trained Keras model. x_test: Test images (normalized). y_test: Test labels (one-hot encoded). Returns: dict: Dictionary with loss, accuracy, precision, recall, f1, y_true (int labels), y_pred (int predictions), and the full classification_report string. """ # Get test loss and accuracy loss, accuracy = model.evaluate(x_test, y_test, verbose=0) # Get predictions y_pred_proba = model.predict(x_test, verbose=0) y_pred = np.argmax(y_pred_proba, axis=1) y_true = np.argmax(y_test, axis=1) # Compute metrics precision = precision_score(y_true, y_pred, average='weighted') recall = recall_score(y_true, y_pred, average='weighted') f1 = f1_score(y_true, y_pred, average='weighted') report = classification_report(y_true, y_pred, target_names=CLASS_NAMES) return { 'loss': loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1, 'y_true': y_true, 'y_pred': y_pred, 'classification_report': report } def plot_confusion_matrix(y_true, y_pred, class_names=CLASS_NAMES, save_path=None, title='Confusion Matrix'): """ Plot and optionally save a confusion matrix heatmap. Args: y_true: True integer labels. y_pred: Predicted integer labels. class_names: List of class name strings. save_path: Optional file path to save the plot. title: Plot title. """ cm = confusion_matrix(y_true, y_pred) fig, ax = plt.subplots(figsize=(10, 8)) sns.heatmap( cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=ax ) ax.set_xlabel('Predicted Label', fontsize=12) ax.set_ylabel('True Label', fontsize=12) ax.set_title(title, fontsize=14) plt.tight_layout() if save_path: os.makedirs(os.path.dirname(save_path), exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"Confusion matrix saved to: {save_path}") plt.close(fig) def plot_training_history(history, save_path=None, title_prefix=''): """ Plot training and validation accuracy/loss curves. Args: history: Keras History object from model.fit(). save_path: Optional file path to save the plot. title_prefix: Optional prefix for plot titles (e.g., 'Custom CNN'). """ fig, axes = plt.subplots(1, 2, figsize=(14, 5)) # Accuracy plot axes[0].plot(history.history['accuracy'], label='Train Accuracy', linewidth=2) axes[0].plot(history.history['val_accuracy'], label='Validation Accuracy', linewidth=2) axes[0].set_title(f'{title_prefix} Accuracy', fontsize=14) axes[0].set_xlabel('Epoch') axes[0].set_ylabel('Accuracy') axes[0].legend() axes[0].grid(True, alpha=0.3) # Loss plot axes[1].plot(history.history['loss'], label='Train Loss', linewidth=2) axes[1].plot(history.history['val_loss'], label='Validation Loss', linewidth=2) axes[1].set_title(f'{title_prefix} Loss', fontsize=14) axes[1].set_xlabel('Epoch') axes[1].set_ylabel('Loss') axes[1].legend() axes[1].grid(True, alpha=0.3) plt.tight_layout() if save_path: os.makedirs(os.path.dirname(save_path), exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"Training history saved to: {save_path}") plt.close(fig) def print_evaluation_summary(metrics, model_name='Model'): """ Print a formatted evaluation summary. Args: metrics: Dictionary returned by evaluate_model(). model_name: Name of the model for display. """ print(f"\n{'='*60}") print(f" {model_name} — Evaluation Results") print(f"{'='*60}") print(f" Test Loss: {metrics['loss']:.4f}") print(f" Test Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.2f}%)") print(f" Precision: {metrics['precision']:.4f}") print(f" Recall: {metrics['recall']:.4f}") print(f" F1-Score: {metrics['f1_score']:.4f}") print(f"{'='*60}") print(f"\nClassification Report:\n") print(metrics['classification_report'])