File size: 1,286 Bytes
c5bce9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from collections import OrderedDict
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from model import get_model

# Standard ImageNet channel statistics used to normalize input tensors.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Shared preprocessing pipeline: resize to 256x256, convert to a CHW float
# tensor in [0, 1], then normalize with the ImageNet statistics above.
_preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

def load_model(weights_path: str, device: torch.device):
    """Load weights from a checkpoint and return the model in eval mode.

    Args:
        weights_path: Path to a torch checkpoint file containing a
            "model_state_dict" entry.
        device: Device to map the weights onto and move the model to.

    Returns:
        The model from ``get_model()`` with the checkpoint weights loaded,
        moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no "model_state_dict" entry.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
    if "model_state_dict" not in checkpoint:
        raise KeyError("model_state_dict not found in checkpoint")
    # Strip the DataParallel/DDP "module." prefix. Use removeprefix so only
    # a *leading* occurrence is removed — .replace() would also mangle keys
    # that happen to contain "module." in the middle.
    state_dict = {
        k.removeprefix("module."): v
        for k, v in checkpoint["model_state_dict"].items()
    }
    model = get_model()
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

def preprocess(image):
    """Convert a PIL image or numpy array into a normalized 4-D batch tensor.

    Numpy arrays are first converted to PIL images so the shared
    ``_preprocess`` pipeline can be applied; a leading batch dimension of
    size 1 is added to the result.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    tensor = _preprocess(pil_image)
    return tensor.unsqueeze(0)

def predict(image, model, device):
    """Run inference on a single image and return the argmax prediction.

    The image is preprocessed into a batch of one, passed through the model
    with gradients disabled, and the class-dimension argmax is returned as a
    CPU tensor with the batch dimension removed.
    """
    model.eval()
    with torch.no_grad():
        batch = preprocess(image).to(device)
        logits = model(batch)
        prediction = logits.argmax(dim=1).squeeze(0).cpu()
    return prediction