import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report,
)
from tqdm import tqdm
import pandas as pd
import random
from collections import defaultdict

# MPS (Metal Performance Shaders) check - Apple GPU
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print(f"Using Metal GPU: {DEVICE}")
else:
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Metal GPU not found, using device: {DEVICE}")

# Constants
IMG_SIZE = 224
BATCH_SIZE = 64             # Batch size increased for GPU
NUM_WORKERS = 6             # Number of data-loading worker processes
MAX_SAMPLES_PER_CLASS = 30  # Maximum number of samples per class (for quick testing)

# File suffixes accepted as images (skips stray files such as .DS_Store)
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}

# Transformation for the test dataset
test_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE + 32),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


class ArtDataset(Dataset):
    def __init__(self, samples, transform=None, class_to_idx=None):
        self.samples = samples
        self.transform = transform
        if class_to_idx is None:
            # Extract the class names from the samples' parent directories
            classes = {Path(s[0]).parent.name for s in samples}
            self.classes = sorted(classes)
            self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        else:
            self.class_to_idx = class_to_idx
            self.classes = sorted(class_to_idx.keys(), key=lambda x: class_to_idx[x])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, class_name = self.samples[idx]
        label = self.class_to_idx[class_name]
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label


def create_test_set(data_dir, test_ratio=0.2, max_per_class=None):
    """Create a test set by taking a fixed fraction of the samples in each class."""
    class_samples = defaultdict(list)

    # Collect all samples, grouped by class (one subdirectory per art movement)
    for class_dir in Path(data_dir).iterdir():
        if class_dir.is_dir():
            class_name = class_dir.name
            for img_path in class_dir.glob('*'):
                # Skip non-image files; Image.open would fail on them later
                if img_path.suffix.lower() in IMAGE_EXTENSIONS:
                    class_samples[class_name].append((img_path, class_name))

    # From each class, take test_ratio of the samples, capped at max_per_class
    test_samples = []
    for class_name, samples in class_samples.items():
        random.shuffle(samples)
        n_test = max(1, int(len(samples) * test_ratio))
        if max_per_class and n_test > max_per_class:
            n_test = max_per_class
        test_samples.extend(samples[:n_test])

    print(f"Total of {len(test_samples)} test samples selected from "
          f"{len(class_samples)} different art movements.")

    # Create the class-to-index mapping
    classes = sorted(class_samples.keys())
    class_to_idx = {cls: i for i, cls in enumerate(classes)}

    return test_samples, class_to_idx


def load_model(model_path, num_classes):
    """Load a trained checkpoint into a ResNet34."""
    print(f"Loading model: {model_path}")

    # Create the ResNet34 backbone; weights come from the checkpoint, not torchvision
    model = models.resnet34(weights=None)

    # Replace the final fully connected layer (512 input features on ResNet34)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # map_location remaps checkpoints saved on another device (e.g. CUDA -> MPS)
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)
    model.eval()
    return model


def evaluate_model(model, test_loader, classes):
    """Evaluate the model and return a dictionary of metrics."""
    all_preds = []
    all_labels = []
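    # torch.no_grad() disables autograd tracking for the whole inference pass,
    # which saves memory and time; predictions and labels are moved back to
    # the CPU because the sklearn metrics below expect NumPy arrays.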
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Evaluation"):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            # Run directly on the MPS device (without using autocast)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            # Move results to the CPU
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Overall metrics (weighted averages account for class imbalance)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')

    # Per-class accuracy from the confusion matrix diagonal; pinning the label
    # order keeps rows/columns aligned with `classes` even if some class is
    # never predicted
    class_accuracy = {}
    conf_matrix = confusion_matrix(all_labels, all_preds,
                                   labels=list(range(len(classes))))
    for i, class_name in enumerate(classes):
        class_samples = np.sum(np.array(all_labels) == i)
        class_correct = conf_matrix[i, i]
        if class_samples > 0:
            class_accuracy[class_name] = class_correct / class_samples

    results = {
        'accuracy': accuracy,
        'f1_score': f1,
        'precision': precision,
        'recall': recall,
        'class_accuracy': class_accuracy,
        'confusion_matrix': conf_matrix,
        'predictions': all_preds,
        'ground_truth': all_labels,
    }
    return results


def plot_confusion_matrix(conf_matrix, classes, model_name, save_dir):
    """Plot the confusion matrix as a heatmap."""
    plt.figure(figsize=(12, 10))
    # annot=False keeps large matrices readable
    sns.heatmap(conf_matrix, annot=False, cmap='Blues',
                xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted Class')
    plt.ylabel('True Class')
    plt.title(f'Confusion Matrix - {model_name}')
    plt.tight_layout()

    # Save the figure
    save_path = Path(save_dir) / f"conf_matrix_{Path(model_name).stem}.png"
    plt.savefig(save_path, dpi=300)
    plt.close()


def plot_class_accuracy(class_acc, model_name, save_dir):
    """Plot per-class accuracy as a bar chart."""
    plt.figure(figsize=(14, 8))

    # Sort classes by accuracy, highest first
    sorted_items = sorted(class_acc.items(), key=lambda x: x[1], reverse=True)
    classes = [item[0] for item in sorted_items]
    accuracies = [item[1] for item in sorted_items]

    bars = plt.bar(classes, accuracies)
    plt.xlabel('Art Movement')
    plt.ylabel('Accuracy')
    plt.title(f'Class-Based Accuracy - {model_name}')
    plt.xticks(rotation=90)
    plt.ylim(0, 1.0)

    # Add the accuracy value above each bar
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:.2f}', ha='center', va='bottom', rotation=0)

    plt.tight_layout()

    # Save the figure
    save_path = Path(save_dir) / f"class_accuracy_{Path(model_name).stem}.png"
    plt.savefig(save_path, dpi=300)
    plt.close()


def plot_model_comparison(all_results, save_dir):
    """Plot a grouped bar chart comparing the models on each metric."""
    model_names = list(all_results.keys())
    metrics = ['accuracy', 'f1_score', 'precision', 'recall']

    # Collect the metric values per model
    metric_data = {metric: [all_results[model][metric] for model in model_names]
                   for metric in metrics}

    plt.figure(figsize=(12, 7))
    x = np.arange(len(model_names))
    width = 0.2
    multiplier = 0

    for metric, values in metric_data.items():
        offset = width * multiplier
        bars = plt.bar(x + offset, values, width, label=metric)

        # Add the score above each bar
        for bar in bars:
            height = bar.get_height()
            plt.annotate(f'{height:.3f}',
                         xy=(bar.get_x() + bar.get_width() / 2, height),
                         xytext=(0, 3),  # 3 points vertical offset
                         textcoords="offset points",
                         ha='center', va='bottom')
        multiplier += 1

    plt.xlabel('Model')
    plt.ylabel('Score')
    plt.title('Model Performance Comparison')
    # Center the tick labels under each group of bars
    plt.xticks(x + width * (len(metrics) - 1) / 2, model_names)
    plt.legend(loc='lower right')
    plt.ylim(0, 1.0)
    plt.tight_layout()

    # Save the figure
    save_path = Path(save_dir) / "model_comparison.png"
    plt.savefig(save_path, dpi=300)
    plt.close()

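# ---------------------------------------------------------------------------
# Optional usage sketch (hypothetical, not called by main): how a loaded
# checkpoint could classify a single image with the same preprocessing as the
# test set. The helper name `predict_image` is illustrative and not part of
# the original pipeline.
# ---------------------------------------------------------------------------
def predict_image(model, img_path, classes):
    """Return the predicted art movement and its softmax confidence."""
    img = Image.open(img_path).convert('RGB')
    batch = test_transform(img).unsqueeze(0).to(DEVICE)  # shape (1, 3, 224, 224)
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)[0]
    conf, idx = torch.max(probs, dim=0)
    return classes[idx.item()], conf.item()
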
def main():
    # Data, model, and output directories
    art_dataset_dir = 'Art Dataset'
    models_dir = 'models'
    results_dir = 'evaluation_results'

    # Create the results directory
    os.makedirs(results_dir, exist_ok=True)

    # Build the test set, capping the number of samples taken from each class
    test_samples, class_to_idx = create_test_set(
        art_dataset_dir, test_ratio=0.2, max_per_class=MAX_SAMPLES_PER_CLASS)

    test_dataset = ArtDataset(test_samples, transform=test_transform,
                              class_to_idx=class_to_idx)
    # pin_memory primarily speeds up host-to-CUDA transfers; harmless elsewhere
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS, pin_memory=True)

    classes = test_dataset.classes
    num_classes = len(classes)
    print(f"Art classes: {num_classes}")

    # Find model files (excluding hidden files such as .DS_Store)
    model_paths = [os.path.join(models_dir, f) for f in os.listdir(models_dir)
                   if f.endswith('.pth') and not f.startswith('.')]

    # Dictionary to store the results per model
    all_results = {}

    # Evaluate each model
    for model_path in model_paths:
        model_name = Path(model_path).name
        print(f"\nEvaluating {model_name}...")

        # Load the model
        model = load_model(model_path, num_classes)

        # Evaluate it
        results = evaluate_model(model, test_loader, classes)
        all_results[model_name] = results

        print(f"Accuracy: {results['accuracy']:.4f}")
        print(f"F1 Score: {results['f1_score']:.4f}")
        print(f"Precision: {results['precision']:.4f}")
        print(f"Recall: {results['recall']:.4f}")

        # Plot the confusion matrix
        plot_confusion_matrix(results['confusion_matrix'], classes,
                              model_name, results_dir)

        # Plot per-class accuracy
        plot_class_accuracy(results['class_accuracy'], model_name, results_dir)

        # Save a detailed per-class report
        report = classification_report(results['ground_truth'],
                                       results['predictions'],
                                       target_names=classes, output_dict=True)
        report_df = pd.DataFrame(report).transpose()
        report_df.to_csv(f"{results_dir}/classification_report_{Path(model_name).stem}.csv")

    # Compare the models against each other
    if len(all_results) > 1:
        plot_model_comparison(all_results, results_dir)

    # Save a summary of the results to a CSV file
    results_summary = []
    for model_name, results in all_results.items():
        row = {
            'model': model_name,
            'accuracy': results['accuracy'],
            'f1_score': results['f1_score'],
            'precision': results['precision'],
            'recall': results['recall'],
        }
        results_summary.append(row)

    summary_df = pd.DataFrame(results_summary)
    summary_df.to_csv(f"{results_dir}/model_comparison_summary.csv", index=False)

    print(f"\nEvaluation completed. Results are in the '{results_dir}' directory.")


if __name__ == "__main__":
    # Seed for reproducibility (affects the random test split)
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    main()
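
# Assumed on-disk layout, inferred from the paths used above:
#   Art Dataset/<art movement>/<images>   one subdirectory per class
#   models/*.pth                          ResNet34 state_dicts to evaluate
# Figures and CSV reports are written to evaluation_results/.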