import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import numpy as np
import csv
from PIL import Image


def binarize(img, threshold=128):
    """Convert *img* to grayscale and threshold it to pure black/white.

    Pixels strictly above *threshold* become 255; all others become 0.
    Returns a PIL image in mode 'L'.
    """
    img = img.convert('L')
    # .point() on an 'L'-mode image already returns an 'L'-mode image,
    # so no trailing .convert('L') is needed.
    return img.point(lambda p: 255 if p > threshold else 0)


def get_class_info(directory):
    """Return (sorted class names, {class_name: image count}) for *directory*.

    Only subdirectories count as classes, so stray files (e.g. .DS_Store)
    are not reported as bogus classes.
    """
    classes = sorted(
        entry for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )
    class_lengths = {
        name: len(os.listdir(os.path.join(directory, name)))
        for name in classes
    }
    return classes, class_lengths


class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder variant whose items are (image, label, file_path)."""

    def __getitem__(self, index):
        # Original (image, label) tuple from ImageFolder.
        original_tuple = super().__getitem__(index)
        # File path of this sample.
        path = self.imgs[index][0]
        # Extend the tuple with the path.
        return original_tuple + (path,)


# 1. Data Loading and Transformation
def load_data(TEST, batch_size=10):
    """Build the test DataLoader for the dataset rooted at *TEST*.

    Images are binarized, resized to 224x224, expanded to RGB, and
    normalized to [-1, 1].

    Returns:
        (test_loader, class names, number of test samples)
    """
    transform_test = transforms.Compose([
        transforms.Lambda(lambda img: binarize(img, threshold=128)),
        transforms.Resize((224, 224)),
        transforms.Lambda(lambda img: img.convert('RGB')),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    test_dataset = ImageFolderWithPaths(root=TEST, transform=transform_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Use the classes ImageFolder already discovered rather than re-listing
    # the directory: os.listdir could pick up non-directory entries.
    return test_loader, test_dataset.classes, len(test_dataset)
# 2. Model Loading and Configuration
def load_model(model_weights_path, num_classes, device):
    """Create a ResNet-18 with *num_classes* outputs and load its weights.

    The state dict is mapped onto *device* so checkpoints saved on a CUDA
    machine also load on CPU-only hosts (torch.load without map_location
    would raise in that case).
    """
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(
        torch.load(model_weights_path, map_location=device)
    )
    model = model.to(device)
    model.eval()
    return model


# 3. Evaluation for Testing
def evaluate(model, test_iterator, criterion, device):
    """Run one evaluation pass over *test_iterator*.

    Returns:
        (average loss per batch, accuracy in percent)
    """
    model.eval()
    epoch_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels, _ in test_iterator:
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)
            loss = criterion(outputs, labels)
            epoch_loss += loss.item()
            # argmax over class logits.
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    accuracy = 100. * correct / total
    avg_loss = epoch_loss / len(test_iterator)
    print(f'\nTesting completed. Average Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
    return avg_loss, accuracy


# 4. Prediction and Label Gathering
def get_all_predictions(model, iterator, device):
    """Collect predicted labels, true labels, and file paths over *iterator*.

    Returns three parallel lists: predictions, labels, filenames.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_filenames = []  # file paths, third element of each dataset item
    with torch.no_grad():
        for data, labels, paths in iterator:
            data = data.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_filenames.extend(paths)
    return all_preds, all_labels, all_filenames
# 5. Confusion Matrix Creation and Visualization
def plot_and_save_confusion_matrix(true_labels, predictions, classes, save_path):
    """Plot a confusion-matrix heat map and save it to *save_path*."""
    # Ensure the output directory exists.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    conf_mat = confusion_matrix(true_labels, predictions)
    plt.figure(figsize=(10, 10))
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.savefig(save_path)
    plt.close()  # Close the plot to free memory
    print(f"Confusion matrix saved to {save_path}")


# Main Execution
def main():
    TEST_DIR = '/home/pola/01_printed_document/printed_13class/data/test'
    MODEL_WEIGHTS_PATH = "/home/pola/01_printed_document/printed_13class/model_weights/650556_95909_35_weights.pt"
    RESULTS_DIR = '/home/pola/01_printed_document/printed_13class/results/'
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Print a per-class image count table for the test set.
    # get_class_info already returns sorted, unique class names, so no
    # extra sorted(set(...)) pass is needed.
    test_classes, test_class_lengths = get_class_info(TEST_DIR)
    max_class_name_length = max(len(name) for name in test_classes)
    print(f"{'Class Name'.ljust(max_class_name_length)} | {'Test Images'.ljust(16)} ")
    print('-' * (max_class_name_length + 36))
    for class_name in test_classes:
        count = test_class_lengths.get(class_name, 0)
        print(f"{class_name.ljust(max_class_name_length)} | {str(count).ljust(16)} ")
    print(f"\nTotal images in Training Dataset: {sum(test_class_lengths.values())}")

    # Initialize
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = nn.CrossEntropyLoss()

    # Load data once; reuse the class list the dataset discovered instead of
    # re-listing TEST_DIR a second and third time.
    test_loader, classes, num_test_samples = load_data(TEST_DIR)

    # Load Model
    model = load_model(MODEL_WEIGHTS_PATH, len(classes), device)
    print("Loaded model with classes:", classes)

    # Evaluate Model
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    all_preds, all_labels, all_filenames = get_all_predictions(model, test_loader, device)

    # Dynamic filename construction keyed on the weights-file name
    # (previously a dead os.path.split result and a literal "(unknown)"
    # placeholder).
    weights_name = os.path.splitext(os.path.basename(MODEL_WEIGHTS_PATH))[0]
    filename_details = f"{weights_name}_{num_test_samples}_smpls_{test_acc:.2f}_pct_acc_{len(classes)}_classes"
    csv_filename = os.path.join(RESULTS_DIR, f"pred_{filename_details}.csv")
    confusion_matrix_filename = os.path.join(RESULTS_DIR, f"conf_{filename_details}_.png")

    # Save all per-image prediction details to a CSV file.
    with open(csv_filename, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(['Filename', 'True Label', 'Predicted Label'])
        for path, true_label, predicted_label in zip(all_filenames, all_labels, all_preds):
            csvwriter.writerow([path, classes[true_label], classes[predicted_label]])
    print(f"All prediction details saved in CSV: {csv_filename}")

    plot_and_save_confusion_matrix(all_labels, all_preds, classes, confusion_matrix_filename)


if __name__ == "__main__":
    main()