Maikuuuu commited on
Commit
08f9621
·
verified ·
1 Parent(s): e949d8e

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +28 -0
inference.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torchvision.models as models
5
+
6
+ # Load model
7
+ model = models.resnet18(pretrained=False, num_classes=3)
8
+ model.load_state_dict(torch.load("pytorch_model.bin", map_location=torch.device("cpu")))
9
+ model.eval()
10
+
11
+ # Preprocessing function
12
+ transform = transforms.Compose([
13
+ transforms.Resize((224, 224)),
14
+ transforms.ToTensor(),
15
+ transforms.Normalize([0.485, 0.456, 0.406],
16
+ [0.229, 0.224, 0.225])
17
+ ])
18
+
19
+ # Labels
20
+ labels = ["A", "B", "C", "D", "E", "F", "G"]
21
+
22
+ # Required function
23
+ def predict(image: Image.Image):
24
+ img_tensor = transform(image).unsqueeze(0)
25
+ with torch.no_grad():
26
+ outputs = model(img_tensor)
27
+ probs = torch.nn.functional.softmax(outputs[0], dim=0)
28
+ return {labels[i]: float(probs[i]) for i in range(len(labels))}