Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
def build_alexnet(num_classes=2):
    """Build a randomly initialised AlexNet whose classifier has ``num_classes`` outputs.

    The network is constructed directly with the requested output size.
    It no longer calls ``models.alexnet(pretrained=False)`` and then swaps
    the last classifier layer, for two reasons: the ``pretrained`` keyword is
    deprecated (since torchvision 0.13) and warns on every call, and no
    pretrained weights were requested anyway. The final layer
    ``classifier[6]`` still ends up with ``num_classes`` outputs, so
    checkpoints saved from the previous construction load unchanged.

    Args:
        num_classes: Number of output logits of the final linear layer.

    Returns:
        An untrained ``torchvision.models.AlexNet`` module.
    """
    return models.AlexNet(num_classes=num_classes)
def load_alexnet_model(model_path, device=None):
    """Restore an AlexNet classifier from a saved checkpoint.

    The checkpoint must be a mapping with two entries:
    - ``"classes"``: its length sets the number of model outputs.
    - ``"model_state"``: loaded with ``load_state_dict``.

    Args:
        model_path: Location of the checkpoint, forwarded to ``torch.load``.
        device: Optional device to move the model to; when ``None`` the
            model is not moved.

    Returns:
        A ``(model, classes)`` tuple: the model in eval mode and the
        checkpoint's ``"classes"`` entry.
    """
    # Deserialize onto the CPU first; the model is only moved afterwards,
    # and only if a device was requested.
    # NOTE(review): torch.load unpickles the file (unless the installed torch
    # defaults to weights_only=True) -- only load trusted checkpoints.
    ckpt = torch.load(model_path, map_location="cpu")
    class_names = ckpt["classes"]

    net = build_alexnet(len(class_names))
    net.load_state_dict(ckpt["model_state"])
    if device is not None:
        net.to(device)
    net.eval()
    return net, class_names
# Per-channel normalization statistics, labelled in the original code as the
# CIFAR mean and std. NOTE(review): these presumably match the training
# pipeline -- confirm before changing them.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2470, 0.2435, 0.2616)

# The pipeline holds no per-image state, so it is built once at import time
# instead of on every call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])


def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalized single-image batch for the model.

    The image is processed in four steps:
    1. Converted to RGB.
    2. Resized to 224x224.
    3. Scaled to [0, 1] by ``ToTensor``.
    4. Normalized with ``_CIFAR_MEAN`` / ``_CIFAR_STD``.

    Args:
        image: Input PIL image in any mode that PIL can convert to RGB.

    Returns:
        A float tensor of shape (1, 3, 224, 224).
    """
    # Grayscale ("L"), palette ("P") or RGBA images would otherwise become
    # 1- or 4-channel tensors, which the 3-value Normalize and the 3-channel
    # model reject. Forcing RGB discards any alpha channel.
    rgb = image.convert("RGB")
    return _PREPROCESS(rgb).unsqueeze(0)