gajavegs commited on
Commit
7233ced
·
verified ·
1 Parent(s): b2b7ddc

Create model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +28 -0
model_loader.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+
6
+ def build_alexnet(num_classes=2):
7
+ model = models.alexnet(pretrained=False)
8
+ in_features = model.classifier[6].in_features
9
+ model.classifier[6] = nn.Linear(in_features, num_classes)
10
+ return model
11
+
12
+ def load_alexnet_model(model_path, device=None):
13
+ # Load weights on CPU first (safer with CUDA init)
14
+ checkpoint = torch.load(model_path, map_location="cpu")
15
+ model = build_alexnet(len(checkpoint["classes"]))
16
+ model.load_state_dict(checkpoint["model_state"])
17
+ if device is not None:
18
+ model.to(device)
19
+ model.eval()
20
+ return model, checkpoint["classes"]
21
+
22
+ def preprocess_image(image: Image.Image) -> torch.Tensor:
23
+ transform = transforms.Compose([
24
+ transforms.Resize((224, 224)),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
27
+ ])
28
+ return transform(image).unsqueeze(0)