# Asim02's picture
# Upload 3 files
# b82820a verified
import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import numpy as np
import csv # Add this line to import the csv module
from PIL import Image # Added for binarization
# Binarization function
def binarize(img, threshold=128):
    """Return a two-level (black/white) version of ``img``.

    The image is converted to mode ``'L'``; every pixel strictly greater
    than ``threshold`` is mapped to 255 and every other pixel to 0.  The
    thresholded image is converted to mode ``'L'`` once more before it is
    returned.

    Args:
        img: Image object exposing ``convert`` and ``point``.
        threshold: Cut-off value; pixels above it become white.

    Returns:
        The thresholded image in mode ``'L'``.
    """
    def to_black_or_white(pixel):
        if pixel > threshold:
            return 255
        return 0

    grayscale = img.convert('L')
    thresholded = grayscale.point(to_black_or_white)
    return thresholded.convert('L')
# Constants
# Function to get class names and counts
def get_class_info(directory):
    """Return the class names found under ``directory`` and their sizes.

    Every immediate sub-directory of ``directory`` is treated as one class.
    Plain files sitting next to the class folders (for example ``.DS_Store``
    or a README) are skipped; treating them as classes would make the
    per-class ``os.listdir`` call fail with ``NotADirectoryError``.

    Args:
        directory: Path of the dataset root that holds one folder per class.

    Returns:
        A ``(classes, class_lengths)`` tuple: ``classes`` is the sorted list
        of class folder names and ``class_lengths`` maps each class name to
        the number of entries inside its folder.
    """
    classes = sorted(
        entry
        for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )
    class_lengths = {
        class_name: len(os.listdir(os.path.join(directory, class_name)))
        for class_name in classes
    }
    return classes, class_lengths
class ImageFolderWithPaths(datasets.ImageFolder):
    """``ImageFolder`` variant whose items also carry the image file path."""

    def __getitem__(self, index):
        """Return the parent's item tuple with the file path appended."""
        # Whatever the parent class yields for this index.
        sample = super().__getitem__(index)
        # First element of the ``imgs`` entry is the path of the image file.
        file_path = self.imgs[index][0]
        return (*sample, file_path)
# 1. Data Loading and Transformation
def load_data(TEST, batch_size=10):
    """Build the test ``DataLoader`` for the image folder at ``TEST``.

    Each image is binarized at threshold 128, resized to 224x224, converted
    to RGB, turned into a tensor and normalized with mean 0.5 and std 0.5 on
    each of the three channels.

    Args:
        TEST: Root directory with one sub-folder per class.
        batch_size: Number of images per batch.

    Returns:
        A ``(test_loader, classes, num_samples)`` tuple.  ``test_loader``
        yields ``(images, labels, paths)`` batches in a fixed (unshuffled)
        order, ``classes`` is the list of class names and ``num_samples`` is
        the number of images in the dataset.
    """
    transform_test = transforms.Compose([
        transforms.Lambda(lambda img: binarize(img, threshold=128)),
        transforms.Resize((224, 224)),
        transforms.Lambda(lambda img: img.convert('RGB')),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    test_dataset = ImageFolderWithPaths(root=TEST, transform=transform_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Take the class names from the dataset itself so that label index ``i``
    # always corresponds to ``classes[i]``.  Re-listing the directory could
    # pick up stray files that the dataset does not treat as classes and
    # shift every name.
    classes = list(test_dataset.classes)
    return test_loader, classes, len(test_dataset)
# 2. Model Loading and Configuration
def load_model(model_weights_path, num_classes, device):
    """Create a ResNet-18 classifier and load trained weights into it.

    Args:
        model_weights_path: Path of the saved ``state_dict`` file.
        num_classes: Size of the final fully connected layer.
        device: Device the model is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    model = models.resnet18(weights=None)
    # Replace the classification head so its output size matches the dataset.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # ``map_location`` remaps the stored tensors onto ``device``; without it
    # a checkpoint saved on a GPU cannot be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider ``weights_only=True`` where the
    # installed torch version supports it).
    state_dict = torch.load(model_weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model
# 3. Evaluation for Testing
def evaluate(model, test_iterator, criterion, device):
    """Compute the average batch loss and the accuracy of ``model``.

    Args:
        model: Network to evaluate; it is put into evaluation mode.
        test_iterator: Iterable of ``(data, labels, extra)`` batches.  The
            third element is ignored.  Any iterable works, including
            generators, because batches are counted while iterating.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch loss values and
        the percentage of correctly classified samples.  If the iterable
        yields no samples, ``(0.0, 0.0)`` is returned instead of dividing by
        zero.
    """
    model.eval()
    loss_sum = 0.0
    num_batches = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels, _ in test_iterator:
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)
            loss_sum += criterion(outputs, labels).item()
            num_batches += 1
            # Index of the largest logit is the predicted class.
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        # Nothing was evaluated; avoid a ZeroDivisionError below.
        print('\nTesting skipped: the iterator produced no samples.\n')
        return 0.0, 0.0

    # Calculate overall accuracy
    accuracy = 100. * correct / total
    avg_loss = loss_sum / num_batches
    print(f'\nTesting completed. Average Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
    return avg_loss, accuracy
# 4. Prediction and Label Gathering
def get_all_predictions(model, iterator, device):
    """Run ``model`` over ``iterator`` and gather predictions per sample.

    Args:
        model: Network used for inference; it is put into evaluation mode.
        iterator: Iterable of ``(data, labels, paths)`` batches.
        device: Device the input batches are moved to.

    Returns:
        ``(all_preds, all_labels, all_filenames)``: three flat lists holding,
        in iteration order, the predicted class index, the true label and
        the file path of every sample.
    """
    predictions, targets, file_paths = [], [], []
    model.eval()
    with torch.no_grad():
        for batch, batch_labels, batch_paths in iterator:
            logits = model(batch.to(device))
            # Class with the highest score along dimension 1.
            predicted = logits.max(dim=1).indices
            predictions.extend(predicted.cpu().numpy())
            targets.extend(batch_labels.numpy())
            # Keep the file paths aligned with the predictions.
            file_paths.extend(batch_paths)
    return predictions, targets, file_paths
# 5. Confusion Matrix Creation and Visualization
def plot_and_save_confusion_matrix(true_labels, predictions, classes, save_path):
    """Draw the confusion matrix as a heatmap and write it to ``save_path``.

    Args:
        true_labels: Ground-truth labels, one per sample.
        predictions: Predicted labels, aligned with ``true_labels``.
        classes: Names used as tick labels on both axes.
        save_path: Destination image file.  Its parent directory is created
            when missing; a bare file name is written to the current
            directory.
    """
    # ``os.path.dirname`` returns '' for a bare file name and
    # ``os.makedirs('')`` raises, so only create a directory if one is named.
    output_dir = os.path.dirname(save_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): the matrix only has rows/columns for labels that occur in
    # the inputs; the tick labels assume every entry of ``classes`` occurs.
    conf_mat = confusion_matrix(true_labels, predictions)
    plt.figure(figsize=(10, 10))
    try:
        sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
        plt.xlabel('Predicted Labels')
        plt.ylabel('True Labels')
        plt.title('Confusion Matrix')
        plt.savefig(save_path)
    finally:
        # Close the figure even when plotting or saving fails, to free memory.
        plt.close()
    print(f"Confusion matrix saved to {save_path}")
# Main Execution
def main(test_dir=None, model_weights_path=None, results_dir=None):
    """Evaluate a trained classifier on the test set and save the results.

    Prints the number of test images per class, loads the model, reports
    loss and accuracy, then writes a CSV with one row per image (file path,
    true label, predicted label) and a confusion-matrix image.

    Args:
        test_dir: Root directory of the test images, one folder per class.
            Defaults to the project's test directory.
        model_weights_path: Path of the saved model weights.  Defaults to the
            project's checkpoint.
        results_dir: Directory that receives the CSV and the image.  Defaults
            to the project's results directory.
    """
    if test_dir is None:
        test_dir = '/home/pola/01_printed_document/printed_13class/data/test'
    if model_weights_path is None:
        model_weights_path = "/home/pola/01_printed_document/printed_13class/model_weights/650556_95909_35_weights.pt"
    if results_dir is None:
        results_dir = '/home/pola/01_printed_document/printed_13class/results/'

    # The weights file name is embedded in the names of the output files.
    weights_filename = os.path.basename(model_weights_path)
    os.makedirs(results_dir, exist_ok=True)

    # Per-class overview of the test set.
    test_classes, test_class_lengths = get_class_info(test_dir)
    max_class_name_length = max((len(name) for name in test_classes), default=0)
    print(f"{'Class Name'.ljust(max_class_name_length)} | {'Test Images'.ljust(16)} ")
    print('-' * (max_class_name_length + 36))
    for class_name in test_classes:
        test_count = test_class_lengths.get(class_name, 0)
        print(f"{class_name.ljust(max_class_name_length)} | {str(test_count).ljust(16)} ")
    print(f"\nTotal images in Test Dataset: {sum(test_class_lengths.values())}")

    # Initialize
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = nn.CrossEntropyLoss()

    # Load Data
    test_loader, classes, num_test_samples = load_data(test_dir)

    # Load Model
    model = load_model(model_weights_path, len(classes), device)
    print("Loaded model with classes:", classes)

    # Evaluate Model
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    all_preds, all_labels, all_filenames = get_all_predictions(model, test_loader, device)

    # Dynamic filename construction; both outputs go into ``results_dir``.
    filename_details = f"{weights_filename}_{num_test_samples}_smpls_{test_acc:.2f}_pct_acc_{len(classes)}_classes"
    csv_filename = os.path.join(results_dir, f"pred_{filename_details}.csv")
    confusion_matrix_filename = os.path.join(results_dir, f"conf_{filename_details}_.png")

    # Save all details to a CSV file
    with open(csv_filename, 'w', newline='', encoding='utf-8') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(['Filename', 'True Label', 'Predicted Label'])
        for image_path, true_label, predicted_label in zip(all_filenames, all_labels, all_preds):
            csvwriter.writerow([image_path, classes[true_label], classes[predicted_label]])
    print(f"All prediction details saved in CSV: {csv_filename}")

    plot_and_save_confusion_matrix(all_labels, all_preds, classes, confusion_matrix_filename)


# Run the evaluation only when the file is executed as a script.
if __name__ == "__main__":
    main()