"""
ResNet50 Evaluation & Confusion Matrix Script
Dataset: Animals-10
"""

import os

import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import kagglehub

# Checkpoint file holding the fine-tuned ResNet50 state dict. The path is
# relative, so it is resolved against the current working directory.
MODEL_PATH = "best_resnet50_animals.pt"
# Number of images per inference batch.
BATCH_SIZE = 64
# Number of worker processes used by the test DataLoader.
NUM_WORKERS = 2


def get_data_path():
    """Locate the dataset locally or download it via KaggleHub.

    Looks for ``animals10/raw-img`` under the current working directory
    first. When that directory is absent, the dataset is downloaded with
    ``kagglehub.dataset_download`` and the ``raw-img`` folder inside the
    downloaded copy is returned instead.

    Returns:
        str: Path to the ``raw-img`` directory.
    """
    local_path = os.path.join(os.getcwd(), "animals10", "raw-img")

    # Require a directory, not just an existing path: a stray file with this
    # name cannot serve as an image-folder root.
    if os.path.isdir(local_path):
        print(f"Dataset found locally at: {local_path}")
        return local_path

    print("Dataset not found locally. Downloading via KaggleHub...")
    path = kagglehub.dataset_download("alessiocorrado99/animals10")
    return os.path.join(path, "raw-img")


def evaluate_model():
    """Evaluate the saved ResNet50 checkpoint on the Animals-10 test split.

    Steps:
      1. Select CUDA when available, otherwise CPU.
      2. Return early with an error message if ``MODEL_PATH`` is missing.
      3. Load the dataset with the evaluation transform and keep the third
         part of a seeded 80/10/10 ``random_split`` as the test set.
      4. Load the checkpoint into a ResNet50 whose final layer has 10 outputs.
      5. Run inference, print the accuracy and a per-class report, and save
         the confusion-matrix heatmap to ``model_performance_matrix.png``.

    Returns:
        None. Results are printed and the heatmap is written to the current
        working directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        print(f"Device: CUDA ({torch.cuda.get_device_name(0)})")
    else:
        print("Device: CPU")

    # Fail fast: without a checkpoint there is nothing to evaluate, so do not
    # pay for locating/downloading and indexing the dataset first.
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file '{MODEL_PATH}' not found in the directory.")
        return

    data_path = get_data_path()

    # Deterministic evaluation preprocessing. The normalization constants are
    # the ImageNet channel mean/std commonly used with torchvision ResNets.
    test_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    print("Loading dataset...")
    dataset = datasets.ImageFolder(data_path, transform=test_transform)
    classes = dataset.classes
    print(f"Total samples: {len(dataset)} | Classes: {len(classes)}")

    # 80/10/10 split; the test part absorbs the rounding remainder.
    total_len = len(dataset)
    train_len = int(0.8 * total_len)
    val_len = int(0.1 * total_len)
    test_len = total_len - train_len - val_len

    # NOTE(review): the seed and split sizes presumably mirror the training
    # script so that this is the held-out split - confirm against it.
    generator = torch.Generator().manual_seed(42)
    _, _, test_set = random_split(
        dataset, [train_len, val_len, test_len], generator=generator
    )

    test_loader = DataLoader(
        test_set,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=device.type == "cuda",
    )

    print(f"Loading model weights from: {MODEL_PATH}")
    model = models.resnet50(weights=None)
    # The checkpoint is expected to carry a 10-way classification head.
    model.fc = nn.Linear(model.fc.in_features, 10)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (or pass weights_only=True where the installed torch
    # version supports it).
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model = model.to(device)
    model.eval()
    print("Model loaded successfully.")

    all_preds = []
    all_labels = []
    num_batches = len(test_loader)

    print(f"Starting inference on {len(test_set)} test samples...")

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader, start=1):
            outputs = model(inputs.to(device))
            preds = outputs.argmax(dim=1)

            # Collect plain Python ints so the metrics below work on lists.
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.tolist())

            if batch_idx % 10 == 0:
                print(f"Processed batch: {batch_idx}/{num_batches}")

    correct_preds = sum(
        pred == label for pred, label in zip(all_preds, all_labels)
    )
    accuracy = 100 * correct_preds / len(all_preds)
    print(f"\nTest Accuracy: {accuracy:.2f}%")

    # Pass the full label set explicitly: otherwise a class that is absent
    # from the test split makes target_names mismatch (ValueError) and
    # shrinks the confusion matrix so it no longer lines up with the ticks.
    label_ids = list(range(len(classes)))

    print("\nClassification Report:")
    print(classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=classes,
        digits=3,
    ))

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    fig = plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        xticklabels=classes,
        yticklabels=classes,
        cmap='Blues',
        cbar_kws={'label': 'Count'},
    )
    plt.xlabel('Predicted Class', fontsize=12, fontweight='bold')
    plt.ylabel('True Class', fontsize=12, fontweight='bold')
    plt.title(
        f'Confusion Matrix - Accuracy: {accuracy:.2f}%',
        fontsize=14,
        fontweight='bold',
    )
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    output_file = 'model_performance_matrix.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    # Release the figure so it does not stay registered with pyplot.
    plt.close(fig)
    print(f"\nConfusion Matrix saved as: {output_file}")

    if device.type == "cuda":
        torch.cuda.empty_cache()


# Run the evaluation only when the file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    evaluate_model()