"""Inference helpers: load a trained checkpoint, preprocess an image, predict."""
from collections import OrderedDict

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import get_model

# Shared preprocessing pipeline:
# - resize to a fixed 256x256;
# - convert to a float tensor (C, H, W);
# - normalize with the ImageNet per-channel mean/std.
# The 3-element mean/std means the input image must have exactly 3 channels.
_preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Prefix that nn.DataParallel / DistributedDataParallel put in front of every
# parameter name when the wrapped model's state_dict is saved.
_DATA_PARALLEL_PREFIX = "module."


def _strip_data_parallel_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``"module."`` removed from each key.

    Only a prefix is stripped. A ``"module."`` that appears later in a key
    (e.g. from a submodule literally named ``module``) is preserved, so keys
    are neither mangled nor collapsed onto each other.
    """
    prefix = _DATA_PARALLEL_PREFIX
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def load_model(weights_path: str, device: torch.device, *,
               weights_only: bool = False) -> torch.nn.Module:
    """Build the model with ``get_model()`` and load weights from a checkpoint.

    Args:
        weights_path: Path to a checkpoint readable by ``torch.load``. It must
            deserialize to a mapping with a ``"model_state_dict"`` entry.
        device: Device used as ``map_location`` when loading; the model is
            also moved to it.
        weights_only: Forwarded to ``torch.load``. Defaults to ``False`` to
            preserve existing behaviour. Pass ``True`` for checkpoints that do
            not come from a trusted source.

    Returns:
        The model with the checkpoint weights, on ``device``, in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    # SECURITY: with weights_only=False, torch.load unpickles arbitrary
    # objects and can execute code embedded in the file. Only rely on the
    # default for checkpoints you trust.
    checkpoint = torch.load(weights_path, map_location=device,
                            weights_only=weights_only)
    if "model_state_dict" not in checkpoint:
        raise KeyError("model_state_dict not found in checkpoint")

    # Checkpoints saved from a DataParallel-wrapped model carry a leading
    # "module." on every key; drop it so the bare model's keys match.
    state_dict = _strip_data_parallel_prefix(checkpoint["model_state_dict"])

    model = get_model()
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def preprocess(image) -> torch.Tensor:
    """Turn one image into a normalized ``(1, 3, 256, 256)`` batch tensor.

    Args:
        image: A ``PIL.Image.Image`` or a NumPy array.
            - Arrays are wrapped with ``Image.fromarray``, so they presumably
              need to be ``uint8`` ``(H, W)`` or ``(H, W, C)`` data.
            - NOTE(review): channel order is taken as-is; arrays from OpenCV
              are BGR — confirm callers convert them.

    Returns:
        Float tensor of shape ``(1, 3, 256, 256)``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Normalize uses 3-channel statistics. Grayscale ("L"), palette ("P") and
    # RGBA images would otherwise fail there with a shape mismatch.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    return _preprocess(image).unsqueeze(0)


def predict(image, model, device) -> torch.Tensor:
    """Run ``model`` on one image and return the argmax over dimension 1.

    Args:
        image: A PIL image or NumPy array accepted by ``preprocess``.
        model: Model to run. It is put into eval mode as a side effect.
        device: Device the input tensor is moved to; should match the
            model's device.

    Returns:
        A CPU tensor of argmax indices over dim 1 of the model output, with
        the size-1 batch dimension squeezed away. Its shape depends on the
        model output; presumably ``(H, W)`` for a ``(1, C, H, W)`` output
        (confirm against ``get_model``).
    """
    model.eval()
    with torch.no_grad():
        tensor = preprocess(image).to(device)
        output = model(tensor)
        pred = torch.argmax(output, dim=1).squeeze(0).cpu()
    return pred