File size: 1,574 Bytes
b39a019 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import matplotlib.pyplot as plt
import torch
import Utilities as Utils
import classifierModel
import CrosswalkDataset as Dataset
# Evaluate the trained classifier on the test split: display every test image
# together with its predicted probabilities, ground-truth label and verdict,
# then print the normalized loss and the correct / incorrect counts.
model = Utils.load_model_for_eval('trainedClassifier_weights.pth', classifierModel.BasicClassificationModel)
dataset = Dataset.CrosswalkDataset("Crosswalk.v7-crosswalk-t3.tensorflow/test/_annotations.csv",
                                   "Crosswalk.v7-crosswalk-t3.tensorflow/test")

# No gradients are needed for evaluation, so autograd is disabled for the
# whole pass.
with torch.no_grad():
    loss = 0.0
    batch_size = 3
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=batch_size)
    loss_function = Utils.BasicClassificationLoss()
    correct_count, incorrect_count = 0, 0
    # Number of samples actually processed. The last batch may hold fewer
    # than `batch_size` items, so `len(dataloader) * batch_size` would
    # overcount whenever the dataset size is not a multiple of `batch_size`.
    num_samples = 0

    for images, gt_labels in dataloader:
        predictions = model(images)
        softmax_probabilities = Utils.softmax(predictions)
        num_samples += len(images)

        for image, gt_label, probabilities in zip(images, gt_labels, softmax_probabilities):
            # Move the first axis last for imshow and rescale by 255
            # (presumably CHW images with 0-255 pixel values -- TODO confirm
            # against CrosswalkDataset).
            plt.imshow(image.permute(1, 2, 0).numpy() / 255.0)

            # A sample is correct when the ground truth and the prediction
            # agree on whether entry 1 is strictly larger than entry 0
            # (ties fall on the "entry 0" side for both).
            gt_0, gt_1 = gt_label[0], gt_label[1]
            prob_0, prob_1 = probabilities[0], probabilities[1]
            classif = bool(
                (gt_1 > gt_0 and prob_1 > prob_0)
                or (gt_1 <= gt_0 and prob_1 <= prob_0)
            )
            if classif:
                correct_count += 1
            else:
                incorrect_count += 1

            # Space-separated so the label tensor and the verdict do not run
            # together in the title.
            plt.title(f"{probabilities!s} {gt_label!s} {classif!s}")
            # NOTE: plt.show() is called once per image, so every test sample
            # gets its own figure.
            plt.show()
            print(probabilities)

        batch_loss = loss_function(predictions, gt_labels)
        loss += batch_loss

    # Per-sample normalization, as in the original intent.
    # NOTE(review): this assumes BasicClassificationLoss returns a per-batch
    # sum; if it already averages over the batch, divide by len(dataloader)
    # instead -- confirm against Utilities.
    print("Loss is: ", loss / num_samples)
    print(correct_count, incorrect_count)
|