fengruilin commited on
Commit
bbf5407
·
verified ·
1 Parent(s): 92c7f7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -25
app.py CHANGED
@@ -9,35 +9,32 @@ from torchvision import transforms
9
  # Make sure this matches the architecture you used to train your model
10
 
11
 
12
- from torchvision.models import resnet50
13
- import torch.nn as nn
14
-
15
- class MyCarClassifier(nn.Module):
16
- def __init__(self, num_classes=196):
17
- super(MyCarClassifier, self).__init__()
18
- self.backbone = resnet50(weights=None)
19
- in_ch = self.backbone.fc.in_features
20
- self.backbone.fc = nn.Identity()
21
- self.fc = nn.Sequential(
22
- nn.Linear(in_ch, 1024), # fc.1
23
- nn.ReLU(),
24
- nn.BatchNorm1d(1024), # fc.3
25
- nn.Linear(1024, 512), # fc.5
26
- nn.ReLU(),
27
- nn.BatchNorm1d(512), # fc.7
28
- nn.Linear(512, num_classes) # fc.9
29
- )
30
-
31
- def forward(self, x):
32
- x = self.backbone(x)
33
- x = self.fc(x)
34
- return x
35
 
36
  # --------------------------
37
  # 2. Load model weights
38
  # --------------------------
39
- model = MyCarClassifier()
40
- model.load_state_dict(torch.load("best_stanford_cars_transfer_model.pth", map_location="cpu"), strict=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  model.eval() # important for inference
42
 
43
  # --------------------------
 
9
  # Make sure this matches the architecture you used to train your model
10
 
11
 
12
+ # Model definition is now handled directly in the loading section
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # --------------------------
15
  # 2. Load model weights
16
  # --------------------------
17
+ # Load the checkpoint directly as it was saved (a plain ResNet50 with custom fc head)
18
+ state_dict = torch.load("best_stanford_cars_transfer_model.pth", map_location="cpu")
19
+
20
+ # Create a ResNet50 and modify its fc to match the checkpoint
21
+ from torchvision.models import resnet50
22
+ model = resnet50(weights=None)
23
+
24
+ # Replace fc with the multi-layer head that matches checkpoint
25
+ in_ch = model.fc.in_features
26
+ model.fc = torch.nn.Sequential(
27
+ torch.nn.Linear(in_ch, 1024), # fc.1
28
+ torch.nn.ReLU(),
29
+ torch.nn.BatchNorm1d(1024), # fc.3
30
+ torch.nn.Linear(1024, 512), # fc.5
31
+ torch.nn.ReLU(),
32
+ torch.nn.BatchNorm1d(512), # fc.7
33
+ torch.nn.Linear(512, 196) # fc.9
34
+ )
35
+
36
+ # Load the state dict
37
+ model.load_state_dict(state_dict, strict=True)
38
  model.eval() # important for inference
39
 
40
  # --------------------------