File size: 6,817 Bytes
b82820a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import numpy as np
import csv  # Add this line to import the csv module
from PIL import Image  # Added for binarization

# Binarization function
def binarize(img, threshold=128):
    """Convert *img* to a black-and-white image.

    The image is first converted to 8-bit grayscale ('L'); every pixel
    strictly above *threshold* becomes 255 (white) and everything else
    becomes 0 (black). The result is returned in 'L' mode.
    """
    grayscale = img.convert('L')
    thresholded = grayscale.point(lambda value: 255 if value > threshold else 0)
    return thresholded.convert('L')

# Constants

# Function to get class names and counts
def get_class_info(directory):
    """List the class (sub-directory) names under *directory* and count
    the number of entries inside each class folder.

    Returns:
        (classes, class_lengths): the sorted class names, and a dict
        mapping each class name to its entry count.
    """
    class_names = sorted(os.listdir(directory))
    counts = {}
    for name in class_names:
        counts[name] = len(os.listdir(os.path.join(directory, name)))
    return class_names, counts

class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder variant whose items also carry the image's file path.

    Each item is ``(image, label, path)`` instead of ``(image, label)``,
    so downstream code can trace every prediction back to its file.
    """

    def __getitem__(self, index):
        # Fetch the standard (image, label) pair from ImageFolder.
        sample = super().__getitem__(index)
        # self.imgs holds (path, class_index) entries; take the path.
        file_path = self.imgs[index][0]
        # Append the path to the original tuple.
        return sample + (file_path,)

# 1. Data Loading and Transformation
def load_data(TEST, batch_size=10):  # 25
    """Build the test-time DataLoader for the image folder at *TEST*.

    Images are binarized, resized to 224x224, replicated to 3 channels
    (the ResNet expects RGB input) and normalized to [-1, 1].

    Args:
        TEST: root directory laid out as one sub-directory per class.
        batch_size: samples per batch (default 10).

    Returns:
        (test_loader, classes, num_samples) where *classes* is the
        dataset's own sorted class-folder list.
    """
    transform_test = transforms.Compose([
        transforms.Lambda(lambda img: binarize(img, threshold=128)),
        transforms.Resize((224, 224)),
        transforms.Lambda(lambda img: img.convert('RGB')),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    test_dataset = ImageFolderWithPaths(root=TEST, transform=transform_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Use the dataset's own class list rather than os.listdir(): unlike a raw
    # directory listing it contains only class directories (no stray files)
    # and is guaranteed to match the label indices ImageFolder produces.
    classes = test_dataset.classes
    return test_loader, classes, len(test_dataset)

# 2. Model Loading and Configuration
def load_model(model_weights_path, num_classes, device):
    """Create a ResNet-18 with *num_classes* outputs and load saved weights.

    Args:
        model_weights_path: path to a state_dict saved with torch.save().
        num_classes: size of the final fully-connected layer.
        device: torch.device the model should be moved to.

    Returns:
        The model on *device*, in eval mode.
    """
    model = models.resnet18(weights=None)  # Changed to resnet18
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host
    # (and places the tensors directly on the target device).
    state_dict = torch.load(model_weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model

# 3. Evaluation for Testing
def evaluate(model, test_iterator, criterion, device):
    """Run one no-grad pass over *test_iterator*; report loss and accuracy.

    Each batch from the iterator is ``(data, labels, path)``; the path
    component is ignored here.

    Returns:
        (average_loss, accuracy_percent) over the whole iterator.
    """
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for inputs, targets, _ in test_iterator:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            total_loss += criterion(logits, targets).item()
            predictions = logits.argmax(dim=1)
            num_seen += targets.size(0)
            num_correct += (predictions == targets).sum().item()

    # Accuracy over samples; loss averaged per batch.
    accuracy = 100. * num_correct / num_seen
    avg_loss = total_loss / len(test_iterator)
    print(f'\nTesting completed. Average Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
    return avg_loss, accuracy

# 4. Prediction and Label Gathering
def get_all_predictions(model, iterator, device):
    """Collect predictions, true labels and file paths for every sample.

    Returns:
        (all_preds, all_labels, all_filenames) as flat Python lists, in
        iteration order.
    """
    model.eval()
    predictions = []
    ground_truth = []
    file_paths = []

    with torch.no_grad():
        for batch, labels, paths in iterator:
            logits = model(batch.to(device))
            _, batch_preds = torch.max(logits, 1)
            predictions.extend(batch_preds.cpu().numpy())
            # Labels stay on CPU in the loader, so no device round-trip needed.
            ground_truth.extend(labels.numpy())
            file_paths.extend(paths)
    return predictions, ground_truth, file_paths

# 5. Confusion Matrix Creation and Visualization
def plot_and_save_confusion_matrix(true_labels, predictions, classes, save_path):
    """Render a confusion-matrix heatmap and write it to *save_path*.

    Args:
        true_labels: ground-truth label indices.
        predictions: predicted label indices, same length/order.
        classes: class names used for both axes' tick labels.
        save_path: output image path; parent directories are created.
    """
    # os.path.dirname() returns '' for a bare filename, and makedirs('')
    # raises FileNotFoundError — only create the directory when there is one.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    conf_mat = confusion_matrix(true_labels, predictions)
    plt.figure(figsize=(10, 10))
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.savefig(save_path)
    plt.close()  # Close the plot to free memory
    print(f"Confusion matrix saved to {save_path}")

# Main Execution
def main():
    """Evaluate the saved ResNet-18 checkpoint on the test set, then save
    per-image predictions (CSV) and a confusion-matrix plot, both tagged
    with the checkpoint's filename stem."""
    TEST_DIR = '/home/pola/01_printed_document/printed_13class/data/test'
    MODEL_WEIGHTS_PATH = "/home/pola/01_printed_document/printed_13class/model_weights/650556_95909_35_weights.pt"
    RESULTS_DIR = '/home/pola/01_printed_document/printed_13class/results/'
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Checkpoint filename (without extension) tags the output files.
    weights_filename = os.path.basename(MODEL_WEIGHTS_PATH)
    weights_stem = os.path.splitext(weights_filename)[0]

    # Per-class image counts for the test set (display only).
    test_classes, test_class_lengths = get_class_info(TEST_DIR)
    name_width = max(len(name) for name in test_classes)

    print(f"{'Class Name'.ljust(name_width)} | {'Test Images'.ljust(16)} ")
    print('-' * (name_width + 36))
    for class_name in test_classes:
        count = test_class_lengths.get(class_name, 0)
        print(f"{class_name.ljust(name_width)} | {str(count).ljust(16)} ")
    # These counts describe the TEST directory, so label them as such.
    print(f"\nTotal images in Test Dataset: {sum(test_class_lengths.values())}")

    # Initialize
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = nn.CrossEntropyLoss()

    # Load Data (the authoritative class list comes from the dataset itself).
    test_loader, classes, num_test_samples = load_data(TEST_DIR)

    # Load Model
    model = load_model(MODEL_WEIGHTS_PATH, len(classes), device)
    print("Loaded model with classes:", classes)

    # Evaluate Model
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    # Gather per-image predictions for the CSV and the confusion matrix.
    all_preds, all_labels, all_filenames = get_all_predictions(model, test_loader, device)

    # Dynamic filename construction: checkpoint stem + sample count +
    # accuracy + class count. os.path.join avoids double slashes.
    filename_details = f"{weights_stem}_{num_test_samples}_smpls_{test_acc:.2f}_pct_acc_{len(classes)}_classes"
    csv_filename = os.path.join(RESULTS_DIR, f"pred_{filename_details}.csv")
    confusion_matrix_filename = os.path.join(RESULTS_DIR, f"conf_{filename_details}_.png")

    # Save all details to a CSV file
    with open(csv_filename, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(['Filename', 'True Label', 'Predicted Label'])
        for image_path, true_label, predicted_label in zip(all_filenames, all_labels, all_preds):
            csvwriter.writerow([image_path, classes[true_label], classes[predicted_label]])

    print(f"All prediction details saved in CSV: {csv_filename}")
    plot_and_save_confusion_matrix(all_labels, all_preds, classes, confusion_matrix_filename)

if __name__ == "__main__":
    main()