# BuildPlay / sketch / sketch.py — Kim Adams, "updates" (6960ade), 1.51 kB
import warnings
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torch import nn
# Location of the class-name list (resolved against the working directory).
PATH = "sketch/class_names.txt"
# One class name per line; position i names output unit i of the model.
# The encoding is pinned so the result does not depend on the platform's
# default locale encoding.
LABELS = Path(PATH).read_text(encoding="utf-8").splitlines()
# Small CNN classifier: three (conv -> ReLU -> 2x max-pool) stages followed by
# a two-layer fully connected head with one output per entry in LABELS.
# The stages are flattened into a single nn.Sequential so the sub-module
# indices (0, 3, 6, 10, 12) stay the ones a checkpoint's keys refer to.
model = nn.Sequential(
    *(
        layer
        for in_channels, out_channels in ((1, 32), (32, 64), (64, 128))
        for layer in (
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
    ),
    nn.Flatten(),
    # 1152 = 128 channels * 3 * 3, i.e. a 28x28 input halved three times
    # (28 -> 14 -> 7 -> 3).
    nn.Linear(in_features=1152, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=len(LABELS)),
)
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code. Only load files from a trusted source; recent PyTorch
# releases accept weights_only=True to restrict what may be unpickled.
state_dict = torch.load('sketch/pytorch_model.bin', map_location='cpu')
# strict=False tolerates key mismatches, so report them explicitly rather than
# silently running with randomly initialised layers.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "sketch/pytorch_model.bin does not match the model: "
        f"missing keys {_load_result.missing_keys}, "
        f"unexpected keys {_load_result.unexpected_keys}",
        RuntimeWarning,
    )
# The model is only used for inference, so switch it to evaluation mode.
model.eval()
def Predict(img):
    """Classify a sketch and return its most likely labels.

    The image is reduced to one grayscale channel, resized to 28x28, scaled
    by 1/255 and passed through the module-level ``model``.

    Args:
        img: NumPy image array, 2-D (H, W) or 3-D (H, W, C), or ``None``.

    Returns:
        A dict mapping label name to softmax probability for the top
        predictions (five, or fewer if there are fewer labels), or ``None``
        when ``img`` is ``None`` or contains only zeros.
    """
    if img is None or not img.any():
        return None

    # Collapse colour images to grayscale. Keying on ndim (not on the last
    # axis being 3) avoids misreading a 2-D image that is 3 pixels wide, and
    # the [:3] slice ignores an alpha channel if one is present.
    if img.ndim == 3:
        img = img[..., :3].mean(axis=-1)

    pil_image = Image.fromarray(img.astype('uint8'))
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    resized = pil_image.resize((28, 28), Image.LANCZOS)

    # Shape (1, 1, 28, 28): batch and channel axes added for the first conv.
    x = torch.tensor(np.array(resized), dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0

    with torch.no_grad():
        out = model(x)
    probabilities = torch.nn.functional.softmax(out[0], dim=0)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probabilities.numel())
    values, indices = torch.topk(probabilities, k)
    return {
        LABELS[index]: value
        for index, value in zip(indices.tolist(), values.tolist())
    }