"""
Evaluation functions for model testing and visualization.
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix


def evaluate_model(model, test_loader, device, class_names=['Not Smoking', 'Smoking']):
    """
    Evaluate model on the test set and return predictions and labels.

    Args:
        model: PyTorch model
        test_loader: DataLoader for test data
        device: Device to evaluate on (cuda/cpu)
        class_names: List of class names for reporting

    Returns:
        tuple: (all_predictions, all_labels, test_accuracy)
    """
    model.eval()
    all_preds = []
    all_labels = []

    print("🧪 Evaluating on Test Set...")
    print(f"   Test batches: {len(test_loader)}\n")

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())

    # Calculate overall accuracy (percent)
    test_acc = 100. * sum(p == l for p, l in zip(all_preds, all_labels)) / len(all_labels)

    return all_preds, all_labels, test_acc

def print_classification_report(predictions, labels, class_names=['Not Smoking', 'Smoking']):
    """
    Print detailed classification metrics.

    Args:
        predictions: List of predicted labels
        labels: List of true labels
        class_names: List of class names
    """
    print(f"\n{'='*60}")
    print("📊 TEST SET RESULTS")
    print(f"{'='*60}")

    # Overall accuracy
    test_acc = 100. * sum(p == l for p, l in zip(predictions, labels)) / len(labels)
    print(f"\n   Overall Accuracy: {test_acc:.2f}%\n")

    # Per-class precision, recall, and F1
    print("\nDetailed Classification Report:")
    print(classification_report(labels, predictions, target_names=class_names, digits=4))
    print(f"{'='*60}")

def plot_confusion_matrix(predictions, labels, class_names=['Not Smoking', 'Smoking'],
                          save_path=None):
    """
    Plot the confusion matrix.

    Args:
        predictions: List of predicted labels
        labels: List of true labels
        class_names: List of class names
        save_path: Optional path to save the figure

    Returns:
        matplotlib figure
    """
    cm = confusion_matrix(labels, predictions)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Count'}, ax=ax)
    ax.set_title('Confusion Matrix - Test Set', fontsize=14, fontweight='bold')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Confusion matrix saved to {save_path}")

    return fig

def plot_training_history(history, save_path=None):
    """
    Plot training and validation loss/accuracy curves.

    Args:
        history: Dictionary with keys 'train_loss', 'val_loss', 'train_acc', 'val_acc'
        save_path: Optional path to save the figure

    Returns:
        matplotlib figure
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Loss curves
    ax1.plot(history['train_loss'], label='Train Loss', marker='o', linewidth=2)
    ax1.plot(history['val_loss'], label='Val Loss', marker='s', linewidth=2)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)

    # Accuracy curves
    ax2.plot(history['train_acc'], label='Train Accuracy', marker='o', linewidth=2)
    ax2.plot(history['val_acc'], label='Val Accuracy', marker='s', linewidth=2)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Accuracy (%)', fontsize=12)
    ax2.set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=11)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Training history saved to {save_path}")

    return fig

def get_predictions_with_confidence(model, dataloader, device):
    """
    Get predictions along with confidence scores.

    Args:
        model: PyTorch model
        dataloader: DataLoader for data
        device: Device to run inference on

    Returns:
        tuple: (predictions, confidences, labels) as numpy arrays
    """
    model.eval()
    all_preds = []
    all_confidences = []
    all_labels = []

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            outputs = model(images)

            # Softmax probabilities; the max probability is the confidence
            probs = torch.softmax(outputs, dim=1)
            confidences, predicted = probs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_confidences.extend(confidences.cpu().numpy())
            all_labels.extend(labels.numpy())

    return np.array(all_preds), np.array(all_confidences), np.array(all_labels)

def analyze_errors(model, dataloader, device, dataset, num_samples=10):
    """
    Analyze misclassified samples.

    Args:
        model: PyTorch model
        dataloader: DataLoader for data
        device: Device to run inference on
        dataset: Original dataset used to look up image paths
        num_samples: Number of error samples to return

    Returns:
        List of dictionaries with error information
    """
    predictions, confidences, labels = get_predictions_with_confidence(model, dataloader, device)

    # Find misclassified samples.
    # Note: indices only line up with dataset.image_paths if the dataloader
    # iterates the dataset in order (shuffle=False).
    errors = []
    for idx, (pred, conf, label) in enumerate(zip(predictions, confidences, labels)):
        if pred != label:
            errors.append({
                'index': idx,
                'true_label': label,
                'predicted_label': pred,
                'confidence': conf,
                'image_path': dataset.image_paths[idx]
            })

    print("\n🔍 Error Analysis:")
    print(f"   Total errors: {len(errors)}")
    print(f"   Error rate: {100 * len(errors) / len(labels):.2f}%")

    # Sort by confidence (highest-confidence errors are the most interesting)
    errors.sort(key=lambda x: x['confidence'], reverse=True)

    return errors[:num_samples]
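

if __name__ == "__main__":
    # Minimal smoke-test sketch of how these utilities fit together. The dummy
    # model, random tensors, and placeholder history values below are stand-ins
    # only; the real project supplies its own trained model and test DataLoader.
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Dummy binary classifier over 3x64x64 images (hypothetical input shape).
    dummy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 2)).to(device)

    # Random stand-in data: 32 images with random 0/1 labels, served in order
    # (shuffle=False) so indices stay aligned with the underlying dataset.
    images = torch.randn(32, 3, 64, 64)
    labels = torch.randint(0, 2, (32,))
    loader = DataLoader(TensorDataset(images, labels), batch_size=8, shuffle=False)

    preds, true_labels, acc = evaluate_model(dummy_model, loader, device)
    print_classification_report(preds, true_labels)
    plot_confusion_matrix(preds, true_labels)

    # Placeholder history dict (dummy numbers, only to exercise the plot); a real
    # training loop is expected to produce the same keys with per-epoch values.
    history = {
        'train_loss': [1.0, 0.7, 0.5], 'val_loss': [1.1, 0.8, 0.6],
        'train_acc': [55.0, 70.0, 80.0], 'val_acc': [52.0, 65.0, 75.0],
    }
    plot_training_history(history)
    plt.show()

    # analyze_errors() is omitted here because it expects a dataset exposing
    # `image_paths`, which TensorDataset does not provide.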