# ChordDetectorv2 / inference.py
# Author: Maikuuuu — commit 08f9621 ("Create inference.py", verified)
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.models as models
# Build a ResNet-18 with a 3-way classification head (no pretrained weights
# requested), restore its parameters from the local checkpoint file mapped onto
# the CPU, and switch it to evaluation mode for inference.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
# `weights=None`; switch once the minimum torchvision version is known.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (on torch>=1.13 consider passing weights_only=True).
# NOTE(review): num_classes=3 here, but `labels` lists 7 names — confirm which
# count matches the checkpoint.
model = models.resnet18(pretrained=False, num_classes=3)
model.load_state_dict(
    state_dict=torch.load("pytorch_model.bin", map_location="cpu"),
)
model.eval()
# Preprocessing function
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
# Class names. predict() pairs them with the model's output scores by position.
# NOTE(review): seven names are listed here, but the model in this file is built
# with num_classes=3 — confirm which count matches the trained checkpoint.
labels = list("ABCDEFG")
# Required function
def predict(image: Image.Image):
    """Classify *image* and return a probability for every label.

    Args:
        image: PIL image to classify. It is converted to RGB first, because the
            3-channel normalisation and the network's input layer both expect
            exactly three channels (grayscale or RGBA images would fail).

    Returns:
        dict mapping each entry of the module-level ``labels`` list to its
        softmax probability (a Python float).

    Raises:
        ValueError: if the model produces a different number of class scores
            than there are labels, so no probability is silently attributed
            to the wrong label.
    """
    rgb_image = image.convert("RGB")
    # Add a leading batch dimension of size 1 before running the network.
    batch = transform(rgb_image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits[0], dim=0).tolist()
    if len(probs) != len(labels):
        raise ValueError(
            f"model produced {len(probs)} class scores but {len(labels)} "
            "labels are defined; the two must match"
        )
    return dict(zip(labels, probs))