File size: 754 Bytes
8aeeee9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from model.cnn import CNN
import torch
from dataset.dataset import inverted_translation
import matplotlib.pyplot as plt
from torchvision.transforms import functional as F
path = "model/epoch=599-step=187800.ckpt"
# Restore the trained network once, at import time, and put it in evaluation
# mode right away. Module.eval() returns the module itself, so the call can be
# chained onto the load.
model = CNN.load_from_checkpoint(path).eval()
def predict(image, get_dictionary=False):
    """Classify a single image with the module-level ``model``.

    Args:
        image: Tensor holding exactly 3 * 32 * 32 elements. It is reshaped
            to a batch of one, ``(1, 3, 32, 32)``, before the forward pass.
        get_dictionary: If true, return the full class distribution instead
            of only the top prediction.

    Returns:
        When ``get_dictionary`` is false: ``inverted_translation[k]`` for the
        class index ``k`` with the highest softmax probability.
        When ``get_dictionary`` is true: a dict mapping
        ``inverted_translation[k]`` to the softmax probability (Python float)
        of class ``k``, for every output class of the model.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. an
    # image that was permute()d from HWC to CHW layout.
    batch = image.reshape(1, 3, 32, 32)
    # Pure inference: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)[0]
    if get_dictionary:
        return {inverted_translation[k]: p for k, p in enumerate(probs.tolist())}
    return inverted_translation[int(torch.argmax(probs))]
|