"""AlexNet construction, checkpoint loading, and image preprocessing helpers."""

import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# CIFAR channel mean and std (RGB order), used to normalize model inputs.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2470, 0.2435, 0.2616)

# Built once at import time instead of on every preprocess_image() call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])


def build_alexnet(num_classes: int = 2) -> nn.Module:
    """Build an untrained AlexNet whose final classifier layer has ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits of the last linear layer.

    Returns:
        A randomly initialised ``torchvision`` AlexNet (no pretrained weights).
    """
    # Passing num_classes to the builder sizes classifier[6] directly. This avoids
    # the deprecated ``pretrained=`` keyword (UserWarning on torchvision >= 0.13);
    # the default (no pretrained weights) is the same on old and new versions.
    return models.alexnet(num_classes=num_classes)


def load_alexnet_model(model_path, device=None):
    """Load an AlexNet checkpoint and return it in eval mode together with its class list.

    The checkpoint must be a mapping with a ``"model_state"`` entry (the state
    dict) and a ``"classes"`` entry whose length sets the number of outputs.

    Args:
        model_path: Anything ``torch.load`` accepts (path or file-like object).
        device: Optional device to move the model to; left on CPU when ``None``.

    Returns:
        ``(model, classes)`` where ``classes`` is ``checkpoint["classes"]`` as stored.
    """
    # Load weights on CPU first (safer with CUDA init).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location="cpu")
    classes = checkpoint["classes"]

    model = build_alexnet(len(classes))
    model.load_state_dict(checkpoint["model_state"])
    if device is not None:
        model.to(device)  # nn.Module.to moves parameters in place
    model.eval()
    return model, classes


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized ``(1, 3, 224, 224)`` float tensor.

    Non-RGB images (RGBA, grayscale, palette, ...) are converted to RGB first,
    so the result always has three channels matching the normalization stats.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    # unsqueeze(0) adds the leading batch dimension.
    return _PREPROCESS(image).unsqueeze(0)