File size: 743 Bytes
c3f9ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torchvision
from PIL import Image
from model import *

# Single-image inference script: load one image, resize it to 32x32, run it
# through the saved Mini-Vision-V1 model and print the raw output and the
# predicted class index.

# The CIFAR10 test split is loaded only to print its class-name -> index
# mapping, so the index printed at the end can be translated by hand.
# download=False: the dataset must already exist under ./CIFAR10.
test_data = torchvision.datasets.CIFAR10("CIFAR10", train=False, download=False)
print(test_data.class_to_idx)

image_path = ""     # Your test image

# Open the file in a context manager so the underlying file handle is closed
# as soon as the pixel data has been converted into an independent RGB copy.
with Image.open(image_path) as source_image:
    print(source_image)
    # Force three channels (drops alpha / expands greyscale) so the tensor
    # below always has 3 channels.
    image = source_image.convert("RGB")

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])

image = transform(image)
print(image.shape)

# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects; only
# load checkpoints from a trusted source.
# map_location="cpu" keeps the model on the same device as the input tensor
# (which is created on the CPU above) and lets a checkpoint that was saved
# from a CUDA model load on a machine without a GPU.
# NOTE(review): presumably the file holds a whole pickled module whose class
# comes from `model` (hence the star import at the top) -- confirm.
model = torch.load("./Mini-Vision-V1.pth", weights_only=False, map_location="cpu")

# Add the leading batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
image = torch.reshape(image, (1, 3, 32, 32))

# Inference only: switch the model to eval mode and disable gradient tracking.
model.eval()
with torch.no_grad():
    output = model(image)

print(output)
# Index of the largest value along dim 1; look it up in the class_to_idx
# mapping printed at the start.
print(output.argmax(1))