File size: 754 Bytes
8aeeee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from model.cnn import CNN
import torch
from dataset.dataset import inverted_translation
import matplotlib.pyplot as plt 
from torchvision.transforms import functional as F

# Checkpoint file the classifier is restored from.
# NOTE(review): this is a relative path, so it presumably resolves against the
# process's working directory rather than this file's location — confirm.
path = "model/epoch=599-step=187800.ckpt"
# Restore the CNN once at module import; predict() below reuses this instance
# on every call.
model = CNN.load_from_checkpoint(path)
# Put the model in evaluation mode before it is used for predictions.
model.eval()

def predict(image, get_dictionary=False):
    """Classify a single image with the module-level ``model``.

    Args:
        image: Tensor holding exactly 3 * 32 * 32 values; it is reshaped to a
            batch of one ``(1, 3, 32, 32)`` input before the forward pass.
        get_dictionary: When true, return the score of every class instead of
            only the winning label.

    Returns:
        With ``get_dictionary`` false, ``inverted_translation[k]`` where ``k``
        is the index of the highest softmax score. With it true, a dict
        mapping ``inverted_translation[i]`` to the softmax score (a Python
        float) of output index ``i``.
    """
    # reshape() accepts every input view() does and additionally copes with
    # non-contiguous tensors that view() rejects.
    image_tensor = image.reshape(1, 3, 32, 32)

    # Inference only: skip autograd bookkeeping so no graph is built or kept.
    with torch.no_grad():
        logits = model(image_tensor)

    # Softmax over the class dimension, then drop the batch dimension.
    probabilities = torch.softmax(logits, dim=1)[0]

    if get_dictionary:
        # tolist() converts all scores to Python floats in a single call.
        return {
            inverted_translation[index]: score
            for index, score in enumerate(probabilities.tolist())
        }

    return inverted_translation[int(torch.argmax(probabilities))]