"""
Evaluation Module for CIFAR-10 Image Classification.

Provides functions to:
- Evaluate model loss, accuracy, precision, recall and F1 on test data
- Generate classification reports (per-class precision, recall, F1)
- Plot and save confusion matrices
- Plot training history curves
- Print a formatted evaluation summary
"""
import os
import numpy as np
import matplotlib
matplotlib.use('Agg') # Non-interactive backend for saving plots
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
classification_report, confusion_matrix,
accuracy_score, precision_score, recall_score, f1_score
)
from .data_loader import CLASS_NAMES
def evaluate_model(model, x_test, y_test, class_names=None):
    """
    Evaluate a trained model on the test set.

    Args:
        model: Trained Keras model. ``model.evaluate`` must return exactly
            two values (loss, accuracy), i.e. the model is assumed to be
            compiled with a single accuracy metric.
        x_test: Test images (normalized).
        y_test: Test labels, either one-hot encoded with shape ``(N, C)`` or
            integer class indices with shape ``(N,)`` / ``(N, 1)``. They are
            passed to ``model.evaluate`` unchanged, so their encoding must
            match the model's loss function.
        class_names: Optional sequence of class names for the classification
            report. Defaults to ``CLASS_NAMES``. Class labels are assumed to
            be the integers ``0 .. len(class_names) - 1``.

    Returns:
        dict: Dictionary with loss, accuracy, precision, recall, f1_score
        (precision/recall/F1 are support-weighted averages over classes),
        y_true (int labels), y_pred (int predictions), and the full
        classification_report string.
    """
    if class_names is None:
        class_names = CLASS_NAMES

    # Keras-reported test loss and accuracy.
    loss, accuracy = model.evaluate(x_test, y_test, verbose=0)

    # Predicted class = index of the highest output score per sample.
    y_pred_proba = model.predict(x_test, verbose=0)
    y_pred = np.argmax(y_pred_proba, axis=1)

    labels_arr = np.asarray(y_test)
    if labels_arr.ndim >= 2 and labels_arr.shape[1] > 1:
        # One-hot encoded: recover the class index of each sample.
        y_true = np.argmax(labels_arr, axis=1)
    else:
        # Already integer labels, (N,) or (N, 1). argmax over a size-1 axis
        # would silently return all zeros, so flatten instead.
        y_true = labels_arr.reshape(-1).astype(int)

    # Support-weighted averages across classes.
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')

    # Pass the full label set so the report still lines up with class_names
    # when some classes are absent from both y_true and y_pred (e.g. a small
    # evaluation subset). Without it sklearn raises a ValueError about the
    # number of classes not matching target_names.
    report = classification_report(
        y_true,
        y_pred,
        labels=np.arange(len(class_names)),
        target_names=class_names,
    )

    return {
        'loss': loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'y_true': y_true,
        'y_pred': y_pred,
        'classification_report': report,
    }
def plot_confusion_matrix(y_true, y_pred, class_names=CLASS_NAMES,
                          save_path=None, title='Confusion Matrix'):
    """
    Plot and optionally save a confusion matrix heatmap.

    Rows are true labels and columns are predicted labels. The matrix is
    always ``len(class_names) x len(class_names)``, even when some classes
    never occur in ``y_true`` or ``y_pred``.

    Args:
        y_true: True integer labels in ``0 .. len(class_names) - 1``.
        y_pred: Predicted integer labels in the same range.
        class_names: List of class name strings, indexed by label.
        save_path: Optional file path to save the plot. Missing parent
            directories are created; a bare filename saves to the current
            working directory.
        title: Plot title.
    """
    # Fix the label set explicitly. Otherwise sklearn only uses the labels
    # that actually occur, and the matrix shrinks while the tick labels do
    # not, which silently mislabels rows and columns.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(
            cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            ax=ax,
        )
        ax.set_xlabel('Predicted Label', fontsize=12)
        ax.set_ylabel('True Label', fontsize=12)
        ax.set_title(title, fontsize=14)
        fig.tight_layout()

        if save_path:
            save_dir = os.path.dirname(save_path)
            # os.makedirs('') raises FileNotFoundError for bare filenames.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Confusion matrix saved to: {save_path}")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
def plot_training_history(history, save_path=None, title_prefix=''):
    """
    Plot training and validation accuracy/loss curves side by side.

    Args:
        history: Keras History object from model.fit(), or its plain
            ``.history`` dict. It must contain 'accuracy' and 'loss'. The
            'val_accuracy' / 'val_loss' curves are drawn only when present
            (they are absent when fit() ran without validation data).
        save_path: Optional file path to save the plot. Missing parent
            directories are created; a bare filename saves to the current
            working directory.
        title_prefix: Optional prefix for plot titles (e.g., 'Custom CNN').
    """
    # Accept either a History object or the metrics dict it wraps.
    hist = getattr(history, 'history', history)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    try:
        # Left panel: accuracy, right panel: loss.
        panels = ((axes[0], 'accuracy', 'Accuracy'), (axes[1], 'loss', 'Loss'))
        for ax, metric, label in panels:
            ax.plot(hist[metric], label=f'Train {label}', linewidth=2)
            val_key = f'val_{metric}'
            if val_key in hist:
                ax.plot(hist[val_key], label=f'Validation {label}', linewidth=2)
            # strip() avoids a leading space when no prefix is given.
            ax.set_title(f'{title_prefix} {label}'.strip(), fontsize=14)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(label)
            ax.legend()
            ax.grid(True, alpha=0.3)
        fig.tight_layout()

        if save_path:
            save_dir = os.path.dirname(save_path)
            # os.makedirs('') raises FileNotFoundError for bare filenames.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"Training history saved to: {save_path}")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
def print_evaluation_summary(metrics, model_name='Model'):
    """
    Print a human-readable summary of evaluation metrics to stdout.

    Args:
        metrics: Mapping with the keys 'loss', 'accuracy', 'precision',
            'recall', 'f1_score' (numeric) and 'classification_report'
            (string), as produced by evaluate_model().
        model_name: Label shown in the summary header.
    """
    rule = '=' * 60

    print('\n' + rule)
    print(f" {model_name} — Evaluation Results")
    print(rule)

    print(f" Test Loss: {metrics['loss']:.4f}")
    acc = metrics['accuracy']
    print(f" Test Accuracy: {acc:.4f} ({acc * 100:.2f}%)")

    # The remaining scores share one display format.
    for caption, key in (('Precision', 'precision'),
                         ('Recall', 'recall'),
                         ('F1-Score', 'f1_score')):
        print(f" {caption}: {metrics[key]:.4f}")

    print(rule)
    print('\nClassification Report:\n')
    print(metrics['classification_report'])
|