gajavegs commited on
Commit
151b3b5
·
verified ·
1 Parent(s): e47b149

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +23 -14
model_loader.py CHANGED
@@ -2,6 +2,9 @@ import torch
2
  import torch.nn as nn
3
  from torchvision import models, transforms
4
  from PIL import Image
 
 
 
5
 
6
  def build_alexnet(num_classes=2):
7
  model = models.alexnet(pretrained=False)
@@ -9,20 +12,26 @@ def build_alexnet(num_classes=2):
9
  model.classifier[6] = nn.Linear(in_features, num_classes)
10
  return model
11
 
12
- def load_alexnet_model(model_path, device=None):
13
- # Load weights on CPU first (safer with CUDA init)
14
- checkpoint = torch.load(model_path, map_location="cpu")
15
- model = build_alexnet(len(checkpoint["classes"]))
16
- model.load_state_dict(checkpoint["model_state"])
17
- if device is not None:
18
- model.to(device)
19
  model.eval()
20
- return model, checkpoint["classes"]
 
 
21
 
22
- def preprocess_image(image: Image.Image) -> torch.Tensor:
23
- transform = transforms.Compose([
24
- transforms.Resize((224, 224)),
25
  transforms.ToTensor(),
26
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
27
- ])
28
- return transform(image).unsqueeze(0)
 
 
 
 
 
 
 
 
2
  import torch.nn as nn
3
  from torchvision import models, transforms
4
  from PIL import Image
5
+ import os
6
+
7
+ DEFAULT_IMG_SIZE = int(os.getenv("IMG_SIZE", "32"))
8
 
9
  def build_alexnet(num_classes=2):
10
  model = models.alexnet(pretrained=False)
 
12
  model.classifier[6] = nn.Linear(in_features, num_classes)
13
  return model
14
 
15
+ def load_alexnet_model(model_path):
16
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
17
+ model = build_alexnet(len(checkpoint['classes']))
18
+ model.load_state_dict(checkpoint['model_state'])
 
 
 
19
  model.eval()
20
+ return model, checkpoint['classes']
21
+
22
+ DEFAULT_IMG_SIZE = int(os.getenv("IMG_SIZE", "224")) # changed default to match training (224x224)
23
 
24
+ def preprocess_image(image: Image.Image, normalize: bool = True) -> torch.Tensor:
25
+ transform_list = [
26
+ transforms.Resize((DEFAULT_IMG_SIZE, DEFAULT_IMG_SIZE), Image.Resampling.LANCZOS),
27
  transforms.ToTensor(),
28
+ ]
29
+
30
+ # default to ImageNet stats used in training; can override with IMG_MEAN / IMG_STD env vars
31
+ IMAGENET_MEAN = list(map(float, os.getenv("IMG_MEAN", "0.485,0.456,0.406").split(",")))
32
+ IMAGENET_STD = list(map(float, os.getenv("IMG_STD", "0.229,0.224,0.225").split(",")))
33
+
34
+ if normalize:
35
+ transform_list.append(transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD))
36
+ transform = transforms.Compose(transform_list)
37
+ return transform(image).unsqueeze(0)