|
|
"""
|
|
|
Utility functions for the CIFAR-10 CNN project
|
|
|
"""
|
|
|
import os
|
|
|
import torch
|
|
|
import matplotlib
|
|
|
matplotlib.use('Agg')
|
|
|
import matplotlib.pyplot as plt
|
|
|
import numpy as np
|
|
|
from sklearn.metrics import confusion_matrix, classification_report
|
|
|
|
|
|
import config
|
|
|
|
|
|
|
|
|
def save_checkpoint(model, optimizer, epoch, accuracy, filepath):
    """
    Serialize a training checkpoint with ``torch.save``.

    The saved object is a dict with the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'accuracy', in that order.

    Args:
        model: PyTorch model whose ``state_dict()`` is stored
        optimizer: Optimizer whose ``state_dict()`` is stored
        epoch: Current epoch
        accuracy: Current accuracy
        filepath: Destination path handed to ``torch.save``
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        accuracy=accuracy,
    )
    torch.save(state, filepath)
|
|
|
|
|
|
|
|
|
def load_checkpoint(model, optimizer, filepath, map_location=None):
    """
    Load model checkpoint and restore model (and optionally optimizer) state in place.

    Expects the dict layout produced by ``save_checkpoint``: keys
    'model_state_dict', 'optimizer_state_dict', 'epoch' and 'accuracy'.

    Args:
        model: PyTorch model; its state is overwritten via ``load_state_dict``
        optimizer: Optimizer to restore, or ``None`` to skip optimizer state
            (e.g. when loading only for evaluation)
        filepath: Path to checkpoint file
        map_location: Passed to ``torch.load``; defaults to ``config.DEVICE``

    Returns:
        tuple: (epoch, accuracy)

    Raises:
        KeyError: If the checkpoint lacks one of the expected keys.
    """
    if map_location is None:
        map_location = config.DEVICE
    checkpoint = torch.load(filepath, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    # Explicit None check: only skip the optimizer when the caller passed None.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return checkpoint['epoch'], checkpoint['accuracy']
|
|
|
|
|
|
|
|
|
def plot_training_history(history, save_dir):
    """
    Plot loss and accuracy curves and save them to ``save_dir/training_history.png``.

    The left panel draws ``history['train_loss']`` and ``history['val_loss']``;
    the right panel draws ``history['train_acc']`` and ``history['val_acc']``.
    Entry ``i`` of each series is plotted at x = ``i + 1`` so the axis shows
    1-based epoch numbers.

    Args:
        history: Dictionary containing training history; presumably one value
            per epoch under 'train_loss', 'val_loss', 'train_acc', 'val_acc'
        save_dir: Directory to save plots (created if it does not exist)
    """

    def _draw_panel(ax, series, ylabel, title):
        # series: sequence of (history key, legend label) pairs.
        for key, label in series:
            values = history[key]
            ax.plot(range(1, len(values) + 1), values, label=label, linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    # An empty save_dir means "current directory", which needs no creation.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    try:
        _draw_panel(
            ax1,
            (('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')),
            'Loss',
            'Training and Validation Loss',
        )
        _draw_panel(
            ax2,
            (('train_acc', 'Train Accuracy'), ('val_acc', 'Validation Accuracy')),
            'Accuracy (%)',
            'Training and Validation Accuracy',
        )
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, 'training_history.png'), dpi=300, bbox_inches='tight')
    finally:
        # Close this specific figure even if plotting/saving fails.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
def plot_confusion_matrix(y_true, y_pred, save_path):
    """
    Plot an annotated confusion-matrix heatmap and save it to ``save_path``.

    Rows are true labels and columns are predicted labels, both ticked with
    ``config.CLASS_NAMES``.

    Args:
        y_true: True labels; assumed to be integer class indices into
            ``config.CLASS_NAMES``
        y_pred: Predicted labels (same encoding as ``y_true``)
        save_path: Path to save the plot (parent directory is created if needed)
    """
    import seaborn as sns

    class_names = config.CLASS_NAMES
    # Pin the label set so the matrix is always len(class_names) square even
    # when some class never occurs in y_true or y_pred; otherwise the matrix
    # shrinks and the tick labels no longer line up with its rows/columns.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        sns.heatmap(
            cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
            cbar_kws={'label': 'Count'},
            ax=ax,
        )
        ax.set_xlabel('Predicted Label', fontsize=12)
        ax.set_ylabel('True Label', fontsize=12)
        ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')
        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
|
|
|
|
|
|
|
|
|
def print_classification_report(y_true, y_pred):
    """
    Print a per-class precision/recall/F1 report framed by separator lines.

    Args:
        y_true: True labels; assumed to be integer class indices into
            ``config.CLASS_NAMES``
        y_pred: Predicted labels (same encoding as ``y_true``)

    Returns:
        str: The report text that was printed.
    """
    # Pass the full label range explicitly: without it sklearn infers labels
    # from the data and raises ValueError whenever a class is absent from both
    # y_true and y_pred (label count != len(target_names)).
    report = classification_report(
        y_true, y_pred,
        labels=list(range(len(config.CLASS_NAMES))),
        target_names=config.CLASS_NAMES,
        digits=4
    )
    print("\nClassification Report:")
    print("=" * 80)
    print(report)
    print("=" * 80)
    return report
|
|
|
|
|
|
|
|
|
def visualize_predictions(model, test_loader, device, num_images=16):
    """
    Visualize model predictions on the first test batch.

    Puts ``model`` in eval mode and leaves it there. Predicts the first
    ``num_images`` samples of the first batch from ``test_loader`` and draws
    them in a near-square grid. Each cell is titled with the true and
    predicted class names, green when they match and red otherwise. Saves the
    figure to ``config.PLOTS_DIR/predictions.png``.

    Args:
        model: PyTorch model
        test_loader: Test data loader yielding ``(images, labels)`` batches
        device: Device to run on
        num_images: Maximum number of images to visualize. Fewer are shown if
            the first batch is smaller. The default of 16 gives a 4x4 grid.

    Raises:
        ValueError: If ``num_images`` is less than 1 or the first batch is empty.
    """
    import math

    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    model.eval()

    images, labels = next(iter(test_loader))
    images, labels = images[:num_images], labels[:num_images]
    # The first batch may hold fewer samples than requested.
    count = len(images)
    if count == 0:
        raise ValueError("test_loader yielded an empty batch")

    with torch.no_grad():
        outputs = model(images.to(device))
        _, predicted = outputs.max(1)
    # Bring predictions back to the CPU once rather than per image.
    predicted = predicted.cpu()

    # Size the grid from the number of images actually shown; 16 images give
    # a 4x4 layout at 3 inches per cell (a 12x12 figure).
    ncols = math.ceil(math.sqrt(count))
    nrows = math.ceil(count / ncols)

    # Per-channel mean/std used to undo input normalization for display.
    # NOTE(review): assumed to match the normalization in the data transforms.
    mean = np.array([0.4914, 0.4822, 0.4465])
    std = np.array([0.2470, 0.2435, 0.2616])

    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
    try:
        axes = axes.ravel()
        for idx in range(count):
            ax = axes[idx]
            # CHW -> HWC for imshow, then de-normalize and clip into [0, 1].
            img = images[idx].cpu().numpy().transpose(1, 2, 0)
            img = np.clip(img * std + mean, 0, 1)
            ax.imshow(img)
            ax.axis('off')

            true_idx = int(labels[idx])
            pred_idx = int(predicted[idx])
            true_label = config.CLASS_NAMES[true_idx]
            pred_label = config.CLASS_NAMES[pred_idx]
            color = 'green' if true_idx == pred_idx else 'red'
            ax.set_title(
                f'True: {true_label}\nPred: {pred_label}',
                color=color, fontsize=10
            )

        # Hide unused cells in the last row instead of leaving empty frames.
        for ax in axes[count:]:
            ax.axis('off')

        fig.tight_layout()
        if config.PLOTS_DIR:
            os.makedirs(config.PLOTS_DIR, exist_ok=True)
        fig.savefig(os.path.join(config.PLOTS_DIR, 'predictions.png'), dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    print(f"Predictions visualization saved to {config.PLOTS_DIR}/predictions.png")
|
|
|
|