fengruilin commited on
Commit
45a5c8a
·
verified ·
1 Parent(s): 86bf76b

trying still.....

Browse files
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -8,23 +8,36 @@ from torchvision import transforms
8
  # --------------------------
9
  # Make sure this matches the architecture you used to train your model
10
 
 
11
  from torchvision.models import resnet50
 
12
 
13
- class MyCarClassifier(torch.nn.Module):
14
  def __init__(self, num_classes=196):
15
  super(MyCarClassifier, self).__init__()
16
- self.model = resnet50(pretrained=False)
17
- in_ch = self.model.fc.in_features
18
- self.model.fc = torch.nn.Linear(in_ch, num_classes)
 
 
 
 
 
 
 
 
 
 
19
  def forward(self, x):
20
- return self.model(x)
 
 
21
 
22
  # --------------------------
23
  # 2. Load model weights
24
  # --------------------------
25
  model = MyCarClassifier()
26
- state_dict = torch.load("best_stanford_cars_transfer_model.pth", map_location="cpu")
27
- model.model.load_state_dict(state_dict, strict=True)
28
  model.eval() # important for inference
29
 
30
  # --------------------------
 
8
  # --------------------------
9
  # Make sure this matches the architecture you used to train your model
10
 
11
+
12
  from torchvision.models import resnet50
13
+ import torch.nn as nn
14
 
15
+ class MyCarClassifier(nn.Module):
16
  def __init__(self, num_classes=196):
17
  super(MyCarClassifier, self).__init__()
18
+ self.backbone = resnet50(weights=None)
19
+ in_ch = self.backbone.fc.in_features
20
+ self.backbone.fc = nn.Identity()
21
+ self.fc = nn.Sequential(
22
+ nn.Linear(in_ch, 512), # fc.1
23
+ nn.ReLU(),
24
+ nn.BatchNorm1d(512), # fc.3
25
+ nn.Linear(512, 256), # fc.5
26
+ nn.ReLU(),
27
+ nn.BatchNorm1d(256), # fc.7
28
+ nn.Linear(256, num_classes) # fc.9
29
+ )
30
+
31
  def forward(self, x):
32
+ x = self.backbone(x)
33
+ x = self.fc(x)
34
+ return x
35
 
36
  # --------------------------
37
  # 2. Load model weights
38
  # --------------------------
39
  model = MyCarClassifier()
40
+ model.load_state_dict(torch.load("best_stanford_cars_transfer_model.pth", map_location="cpu"), strict=True)
 
41
  model.eval() # important for inference
42
 
43
  # --------------------------