Spaces:
Sleeping
Sleeping
File size: 1,003 Bytes
7233ced 151b3b5 7233ced 7c50ef1 7233ced 7c50ef1 151b3b5 7c50ef1 7233ced 7c50ef1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
def build_alexnet(num_classes=2):
    """Build a randomly initialised AlexNet with a ``num_classes``-way output layer.

    Args:
        num_classes: Number of logits produced by the final classifier layer.

    Returns:
        A torchvision AlexNet whose classifier layer at index 6 has been
        replaced by ``nn.Linear(in_features, num_classes)``.
    """
    # ``weights=None`` is the torchvision >= 0.13 spelling of
    # ``pretrained=False``; the boolean ``pretrained`` flag is deprecated and
    # emits a warning on every call.
    model = models.alexnet(weights=None)
    # Keep the input width of the existing output layer, change only its size.
    in_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(in_features, num_classes)
    return model
def load_alexnet_model(model_path, device=None):
    """Restore an AlexNet classifier from a saved checkpoint.

    The checkpoint is read as a mapping with a ``"classes"`` entry (its length
    sets the number of model outputs) and a ``"model_state"`` state dict.

    Args:
        model_path: Location passed straight to ``torch.load``.
        device: Optional device to move the model to; when ``None`` the model
            is left where it was built and loaded.

    Returns:
        Tuple ``(model, classes)`` with the model already in eval mode.
    """
    # Deserialize onto the CPU first (safer with CUDA init); any move to the
    # requested device happens only after the weights are in place.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider ``weights_only=True`` where the torch version
    # supports it).
    state = torch.load(model_path, map_location="cpu")
    class_names = state["classes"]

    net = build_alexnet(len(class_names))
    net.load_state_dict(state["model_state"])

    if device is not None:
        net.to(device)
    net.eval()
    return net, class_names
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalised single-image batch for the model.

    The image is forced to 3-channel RGB, resized to 224x224, converted to a
    tensor by ``ToTensor`` and normalised with CIFAR channel mean/std.

    Args:
        image: Input PIL image in any mode (e.g. RGB, RGBA, L, P).

    Returns:
        Float tensor of shape ``(1, 3, 224, 224)``.
    """
    # Normalize below has exactly three mean/std entries; RGBA, greyscale or
    # palette inputs would otherwise give 4- or 1-channel tensors and fail.
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # CIFAR mean and std — presumably the statistics used in training; confirm.
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
    ])
    # Prepend a batch dimension of size 1.
    return transform(image).unsqueeze(0)
|