File size: 1,574 Bytes
b39a019
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import matplotlib.pyplot as plt
import torch
import Utilities as Utils
import classifierModel
import CrosswalkDataset as Dataset


# Build the two inputs of this evaluation: the classifier restored from its
# saved weights, and the dataset read from the "test" split directory
# (annotations CSV + image folder).
# NOTE(review): whether load_model_for_eval also switches the model to eval
# mode is decided inside Utils — confirm there.
model = Utils.load_model_for_eval(
    'trainedClassifier_weights.pth',
    classifierModel.BasicClassificationModel,
)
dataset = Dataset.CrosswalkDataset(
    "Crosswalk.v7-crosswalk-t3.tensorflow/test/_annotations.csv",
    "Crosswalk.v7-crosswalk-t3.tensorflow/test",
)


# Evaluate `model` on `dataset` without tracking gradients.
# Each image is displayed with its softmax output, its ground-truth label and
# whether the two agree. The script then prints the average loss per sample
# and the number of correctly / incorrectly classified samples.
with torch.no_grad():
    loss = 0.0
    batch_size = 3
    # Shuffling only changes the order in which images are displayed; every
    # metric below is order-independent.
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=batch_size)
    loss_function = Utils.BasicClassificationLoss()
    count, notCount = 0, 0  # correctly / incorrectly classified samples
    # Count samples actually seen. The final batch may hold fewer than
    # `batch_size` items, so len(dataloader) * batch_size would overcount.
    num_samples = 0
    for images, gt_labels in dataloader:
        predictions = model(images)
        softmax_probabilities = Utils.softmax(predictions)
        for image, gt_label, probs in zip(images, gt_labels, softmax_probabilities):
            # assumes each image is a CHW tensor with 0-255 pixel values — TODO confirm
            plt.imshow(image.permute(1, 2, 0).numpy() / 255.0)
            # Two-way decision: column 1 wins only when strictly greater than
            # column 0; ties fall to column 0, for both label and prediction.
            classif = bool(gt_label[1] > gt_label[0]) == bool(probs[1] > probs[0])
            if classif:
                count += 1
            else:
                notCount += 1
            plt.title(f"{probs} {gt_label} {classif}")
            plt.show()
            print(probs)

        batch_loss = loss_function(predictions, gt_labels)
        # float() keeps the running total a plain number instead of a tensor.
        # Assumes the loss is a scalar (as required for training) — TODO confirm.
        loss += float(batch_loss)
        num_samples += len(images)

    if num_samples == 0:
        print("Test dataset is empty; nothing to evaluate.")
    else:
        # Same per-sample normalisation the script intended; this presumes the
        # loss returns a per-batch sum rather than a mean — TODO confirm in Utils.
        print("Loss is: ", loss / num_samples)
        print("Accuracy: ", count / num_samples)
    print(count, notCount)