# g-loubna
# Space: download weights from model repo
# commit: c5bce9d
import torch
from collections import OrderedDict
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from model import get_model
# Input preprocessing pipeline: resize to a fixed 256x256, convert to a
# float tensor in [0, 1], then normalize with the standard ImageNet
# channel statistics.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_preprocess = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
def load_model(weights_path: str, device: torch.device):
    """Load a checkpoint and return the model ready for inference.

    Args:
        weights_path: Path to a torch checkpoint file that contains a
            "model_state_dict" entry.
        device: Device the weights are mapped to and the model is moved onto.

    Returns:
        The model (from ``get_model()``) with the checkpoint weights loaded,
        moved to ``device`` and set to eval mode.

    Raises:
        KeyError: If the checkpoint has no "model_state_dict" entry.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects from the
    # checkpoint — only load checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
    if "model_state_dict" not in checkpoint:
        raise KeyError("model_state_dict not found in checkpoint")
    state_dict = checkpoint["model_state_dict"]
    # Strip the "module." prefix added by DataParallel/DistributedDataParallel.
    # Use removeprefix rather than str.replace: replace would also mangle any
    # key that merely *contains* "module." somewhere in the middle.
    new_state_dict = OrderedDict(
        (k.removeprefix("module."), v) for k, v in state_dict.items()
    )
    model = get_model()
    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()
    return model
def preprocess(image):
    """Turn a PIL image or numpy array into a batched, normalized tensor.

    Numpy arrays are first converted to a PIL image; the module-level
    ``_preprocess`` pipeline is then applied and a batch dimension of 1 is
    prepended to the resulting tensor.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    return _preprocess(pil_image).unsqueeze(0)
def predict(image, model, device):
    """Run the model on a single image and return its argmax prediction.

    The image is preprocessed, sent to ``device``, and passed through the
    model under ``torch.no_grad()``. The output is argmaxed over dim=1 (the
    class dimension), the batch dimension is dropped, and the result is
    returned on the CPU.
    """
    model.eval()
    with torch.no_grad():
        batch = preprocess(image).to(device)
        logits = model(batch)
        prediction = torch.argmax(logits, dim=1).squeeze(0).cpu()
    return prediction