| import torch | |
| from collections import OrderedDict | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from model import get_model | |
# Shared input pipeline: resize to a fixed 256x256, convert to a tensor, then
# normalize each of the three channels with the statistics below.
# NOTE(review): these mean/std values match the widely used ImageNet channel
# statistics; presumably the model was trained with them -- confirm.
_preprocess = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def _strip_module_prefix(key: str, prefix: str = "module.") -> str:
    """Return ``key`` with any leading ``prefix`` occurrences removed."""
    while key.startswith(prefix):
        key = key[len(prefix):]
    return key


def load_model(weights_path: str, device: torch.device):
    """Build the model, load checkpoint weights into it, and return it in eval mode.

    Args:
        weights_path: Path to a checkpoint readable by ``torch.load``. The
            checkpoint must contain a ``"model_state_dict"`` entry.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The object returned by ``get_model()`` with the checkpoint weights
        loaded, moved to ``device`` and switched to evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects. Only load checkpoints from a trusted source; switch to
    # weights_only=True if the checkpoint holds nothing but tensors.
    checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
    if "model_state_dict" not in checkpoint:
        raise KeyError("model_state_dict not found in checkpoint")

    # Keys saved from a wrapped model carry a leading "module." (presumably
    # from nn.DataParallel / DistributedDataParallel). Strip it only at the
    # start of the key: a plain str.replace would also mangle layer names that
    # merely contain the text, e.g. "se_module.fc.weight" -> "se_fc.weight".
    new_state_dict = OrderedDict(
        (_strip_module_prefix(name), tensor)
        for name, tensor in checkpoint["model_state_dict"].items()
    )

    model = get_model()
    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()
    return model
def preprocess(image):
    """Turn an image into a normalized, single-item batch tensor.

    Args:
        image: Either a ``numpy.ndarray`` (converted with ``Image.fromarray``)
            or an object the ``_preprocess`` pipeline already accepts, such as
            a PIL image.

    Returns:
        The result of ``_preprocess(image)`` with a leading batch dimension of
        size 1 added via ``unsqueeze(0)``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # The pipeline normalizes with three per-channel statistics, so grayscale,
    # palette or RGBA images would fail there. Bring PIL images to 3-channel
    # RGB first; RGB images and non-PIL inputs pass through unchanged.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    return _preprocess(image).unsqueeze(0)
def predict(image, model, device):
    """Run the model on one image and return the arg-max prediction.

    Args:
        image: Input accepted by ``preprocess``.
        model: Callable model; it is switched to eval mode before inference.
        device: Device the preprocessed batch is moved to before the forward
            pass.

    Returns:
        A CPU tensor holding the indices of the maximum along dimension 1 of
        the model output, with dimension 0 squeezed out.
        NOTE(review): dimension 1 is presumably the class axis -- confirm
        against the model's output layout.
    """
    model.eval()
    with torch.no_grad():
        batch = preprocess(image)
        logits = model(batch.to(device))
        labels = logits.argmax(dim=1)
    return labels.squeeze(0).cpu()