# CSSE416_Classifier_Demo / model_loader.py
# Author: gajavegs
# Last change: "Update model_loader.py" (commit 7c50ef1)
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
def build_alexnet(num_classes=2):
    """Return an untrained torchvision AlexNet with a ``num_classes``-way head.

    No pretrained weights are requested; callers such as
    ``load_alexnet_model`` load a trained state dict into the result.

    Args:
        num_classes: Number of output logits of the final classifier layer.

    Returns:
        An ``nn.Module`` whose ``classifier[6]`` is a fresh ``nn.Linear``
        mapping the original layer's input width to ``num_classes``.
    """
    try:
        # torchvision >= 0.13: ``weights=None`` replaces the deprecated
        # ``pretrained=False`` and avoids the deprecation warning.
        model = models.alexnet(weights=None)
    except TypeError:
        # Older torchvision builds do not accept the ``weights`` keyword.
        model = models.alexnet(pretrained=False)
    # Swap the last classifier layer for one sized to this label set.
    in_features = model.classifier[6].in_features
    model.classifier[6] = nn.Linear(in_features, num_classes)
    return model
def load_alexnet_model(model_path, device=None):
    """Rebuild an AlexNet classifier from a saved checkpoint.

    The checkpoint must be a mapping with a ``"classes"`` entry (its length
    sets the number of outputs) and a ``"model_state"`` entry (the state dict
    restored into the network).

    Args:
        model_path: Path or file-like object accepted by ``torch.load``.
        device: Optional target device; when ``None`` the model stays on CPU.

    Returns:
        ``(model, classes)``, with the model switched to eval mode and
        ``classes`` taken unchanged from the checkpoint.
    """
    # Deserialize onto the CPU first so restoring weights never forces CUDA
    # initialisation; the model is moved to ``device`` afterwards if asked.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True if the checkpoint allows it).
    saved = torch.load(model_path, map_location="cpu")
    class_names = saved["classes"]

    net = build_alexnet(len(class_names))
    net.load_state_dict(saved["model_state"])

    if device is not None:
        net.to(device)
    net.eval()
    return net, class_names
# CIFAR-10 per-channel mean and standard deviation used for normalization.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2470, 0.2435, 0.2616)

# The pipeline has no random or stateful steps, so build it once at import
# time instead of on every call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalized single-image batch for the model.

    The image is converted to RGB first. Grayscale ("L") and palette ("P")
    images would otherwise give a 1-channel tensor, and RGBA a 4-channel one.
    In either case the 3-channel ``Normalize`` step fails. Alpha is dropped
    without compositing.

    Args:
        image: Input image in any PIL mode convertible to RGB.

    Returns:
        Float tensor of shape (1, 3, 224, 224), resized to 224x224 and
        normalized with the CIFAR-10 channel statistics above.
    """
    rgb = image.convert("RGB")
    # unsqueeze(0) adds the leading batch dimension.
    return _PREPROCESS(rgb).unsqueeze(0)