| from model.cnn import CNN | |
| import torch | |
| from dataset.dataset import inverted_translation | |
| import matplotlib.pyplot as plt | |
| from torchvision.transforms import functional as F | |
# Checkpoint file the classifier weights are restored from.
# NOTE(review): the file name suggests epoch 599 / step 187800 — confirm.
path = "model/epoch=599-step=187800.ckpt"

# Build the CNN from the checkpoint once, at import time, so every call to
# predict() reuses the same instance.
model = CNN.load_from_checkpoint(path)

# Presumably the standard torch.nn.Module.eval() switch to evaluation mode;
# CNN is defined elsewhere, so verify it does not override eval().
model.eval()
def predict(image, get_dictionary=False):
    """Classify a single image with the module-level ``model``.

    Args:
        image: Tensor holding exactly 3 * 32 * 32 elements. It is reshaped
            to a one-element batch of shape ``(1, 3, 32, 32)`` before being
            passed to the model.
        get_dictionary: When true, return the probability of every class
            instead of only the most likely label.

    Returns:
        If ``get_dictionary`` is false, ``inverted_translation[k]`` where
        ``k`` is the index with the highest softmax probability.
        If ``get_dictionary`` is true, a dict mapping
        ``inverted_translation[i]`` to the softmax probability (a Python
        float) of index ``i``, for every index the model outputs.
    """
    # Inference only: disable autograd so the forward pass does not record
    # a graph that is never used.
    with torch.no_grad():
        # reshape() rather than view(): identical for contiguous tensors,
        # but also accepts non-contiguous ones (e.g. a permuted image),
        # for which view() raises.
        batch = image.reshape(1, 3, 32, 32)
        output = model(batch)
        # Softmax over the class dimension, then drop the batch dimension.
        probabilities = torch.softmax(output, dim=1)[0]

    if get_dictionary:
        return {
            inverted_translation[index]: probability
            for index, probability in enumerate(probabilities.tolist())
        }
    return inverted_translation[int(torch.argmax(probabilities))]