|
|
"""
|
|
|
Evaluation script for CIFAR-10 CNN
|
|
|
"""
|
|
|
import os
|
|
|
import torch
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
import config
|
|
|
from model import get_model
|
|
|
from data_loader import get_data_loaders
|
|
|
from utils import (
|
|
|
load_checkpoint, plot_confusion_matrix,
|
|
|
print_classification_report, visualize_predictions
|
|
|
)
|
|
|
|
|
|
|
|
|
def evaluate():
|
|
|
"""
|
|
|
Evaluate the trained model
|
|
|
"""
|
|
|
|
|
|
os.makedirs(config.PLOTS_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
print("Loading CIFAR-10 dataset...")
|
|
|
_, test_loader = get_data_loaders()
|
|
|
print(f"Test samples: {len(test_loader.dataset)}")
|
|
|
|
|
|
|
|
|
print(f"\nLoading model from {config.BEST_MODEL_PATH}")
|
|
|
model = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)
|
|
|
|
|
|
|
|
|
if not os.path.exists(config.BEST_MODEL_PATH):
|
|
|
print(f"Error: Model checkpoint not found at {config.BEST_MODEL_PATH}")
|
|
|
print("Please train the model first using train.py")
|
|
|
return
|
|
|
|
|
|
epoch, accuracy = load_checkpoint(model, None, config.BEST_MODEL_PATH)
|
|
|
print(f"Loaded model from epoch {epoch + 1} with accuracy: {accuracy:.2f}%")
|
|
|
|
|
|
|
|
|
model.eval()
|
|
|
correct = 0
|
|
|
total = 0
|
|
|
all_predictions = []
|
|
|
all_labels = []
|
|
|
|
|
|
print("\nEvaluating model...")
|
|
|
with torch.no_grad():
|
|
|
pbar = tqdm(test_loader, desc='Evaluating')
|
|
|
for inputs, labels in pbar:
|
|
|
inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)
|
|
|
|
|
|
|
|
|
outputs = model(inputs)
|
|
|
_, predicted = outputs.max(1)
|
|
|
|
|
|
|
|
|
total += labels.size(0)
|
|
|
correct += predicted.eq(labels).sum().item()
|
|
|
|
|
|
|
|
|
all_predictions.extend(predicted.cpu().numpy())
|
|
|
all_labels.extend(labels.cpu().numpy())
|
|
|
|
|
|
|
|
|
pbar.set_postfix({'acc': f'{100. * correct / total:.2f}%'})
|
|
|
|
|
|
|
|
|
final_accuracy = 100. * correct / total
|
|
|
|
|
|
|
|
|
print("\n" + "=" * 80)
|
|
|
print(f"Final Test Accuracy: {final_accuracy:.2f}%")
|
|
|
print(f"Correct predictions: {correct}/{total}")
|
|
|
print("=" * 80)
|
|
|
|
|
|
|
|
|
print_classification_report(all_labels, all_predictions)
|
|
|
|
|
|
|
|
|
print("\nGenerating confusion matrix...")
|
|
|
cm_path = os.path.join(config.PLOTS_DIR, 'confusion_matrix.png')
|
|
|
plot_confusion_matrix(all_labels, all_predictions, cm_path)
|
|
|
print(f"Confusion matrix saved to {cm_path}")
|
|
|
|
|
|
|
|
|
print("\nGenerating prediction visualizations...")
|
|
|
visualize_predictions(model, test_loader, config.DEVICE, num_images=16)
|
|
|
|
|
|
print("\nEvaluation completed!")
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
evaluate()
|
|
|
|