File size: 1,003 Bytes
7233ced
 
 
 
151b3b5
7233ced
 
 
 
 
 
7c50ef1
 
 
 
 
 
 
7233ced
7c50ef1
151b3b5
7c50ef1
 
 
7233ced
7c50ef1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

def build_alexnet(num_classes=2):
    """Create a randomly initialized AlexNet whose final layer has ``num_classes`` outputs.

    No pretrained weights are downloaded. The returned model is intended to
    receive weights via ``load_state_dict`` (as ``load_alexnet_model`` does)
    or to be trained from scratch.

    Args:
        num_classes: Number of output units in the last classifier layer.

    Returns:
        A ``torchvision.models.AlexNet`` instance, in training mode.
    """
    # Build the architecture class directly rather than calling
    # ``models.alexnet(pretrained=False)``. The ``pretrained`` argument is
    # deprecated since torchvision 0.13 and warns on every call, while
    # ``AlexNet(num_classes=...)`` is accepted by old and new releases alike.
    # The last layer is Linear(4096, num_classes), matching the previous
    # "build with 1000 classes, then replace classifier[6]" approach. The
    # layers and state-dict keys are therefore identical, so existing
    # checkpoints still load.
    return models.AlexNet(num_classes=num_classes)

def load_alexnet_model(model_path, device=None):
    """Rebuild an AlexNet classifier from a saved checkpoint, ready for inference.

    The checkpoint must be a mapping with two entries:

    - ``"classes"``: a sized collection. Its length sets the number of
      output units, and it is returned unchanged.
    - ``"model_state"``: a state dict accepted by ``load_state_dict``.

    Args:
        model_path: Checkpoint location, handed straight to ``torch.load``.
        device: Optional target device. When ``None`` the model stays on CPU.

    Returns:
        A ``(model, classes)`` tuple, with ``model`` switched to eval mode.

    Raises:
        KeyError: if either expected checkpoint entry is missing.
    """
    # Deserialize onto the CPU no matter where the tensors were saved, so
    # loading does not depend on CUDA being initialized. The model is moved
    # to the requested device only afterwards.
    # NOTE(review): torch.load unpickles the file, so only pass trusted
    # checkpoints (consider weights_only=True where the torch version allows).
    ckpt = torch.load(model_path, map_location="cpu")
    class_names = ckpt["classes"]

    net = build_alexnet(len(class_names))
    net.load_state_dict(ckpt["model_state"])

    if device is not None:
        net.to(device)  # nn.Module.to moves parameters in place
    net.eval()  # disable dropout for inference
    return net, class_names

# Per-channel (R, G, B) mean and standard deviation of the CIFAR-10 training
# set, used to normalize inputs.
# NOTE(review): presumably these match the statistics used when the model was
# trained; confirm against the training pipeline.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Built once at import time. Resize, ToTensor and Normalize keep no per-call
# state, so reusing one pipeline is equivalent to rebuilding it on every call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalized single-image batch for the model.

    Steps:

    1. Convert the image to RGB if needed.
    2. Resize it to exactly 224x224; aspect ratio is not preserved.
    3. Convert it to a float tensor via ``ToTensor``.
    4. Normalize each channel with the CIFAR-10 mean and std.

    Args:
        image: Any PIL image. Non-RGB modes such as ``"L"``, ``"RGBA"`` or
            ``"P"`` are converted to RGB first.

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``.
    """
    # ToTensor yields one channel per image band. Grayscale, alpha or palette
    # images would otherwise give 1, 2 or 4 channels and make the 3-channel
    # Normalize raise. Skipping the conversion for RGB images avoids a
    # needless copy.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return _PREPROCESS(image).unsqueeze(0)  # add the batch dimension