fengruilin commited on
Commit
9bed718
·
verified ·
1 Parent(s): b089015

fix resnet mismatch

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -11,8 +11,8 @@ class MyCarClassifier(torch.nn.Module):
11
  def __init__(self, num_classes=196): # Stanford Cars has 196 classes
12
  super(MyCarClassifier, self).__init__()
13
  # Example: replace with your actual model architecture
14
- from torchvision.models import resnet18
15
- self.model = resnet18(pretrained=False)
16
  self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
17
 
18
  def forward(self, x):
@@ -22,7 +22,7 @@ class MyCarClassifier(torch.nn.Module):
22
  # 2. Load model weights
23
  # --------------------------
24
  model = MyCarClassifier()
25
- model.load_state_dict(torch.load("best_stanford_cars_transfer_model.pth", map_location="cpu"), strict=False)
26
  model.eval() # important for inference
27
 
28
  # --------------------------
 
11
  def __init__(self, num_classes=196): # Stanford Cars has 196 classes
12
  super(MyCarClassifier, self).__init__()
13
  # Example: replace with your actual model architecture
14
+ from torchvision.models import resnet50
15
+ self.model = resnet50(pretrained=False)
16
  self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
17
 
18
  def forward(self, x):
 
22
  # 2. Load model weights
23
  # --------------------------
24
  model = MyCarClassifier()
25
+ model.load_state_dict(torch.load("best_stanford_cars_transfer_model.pth", map_location="cpu"), strict=True)
26
  model.eval() # important for inference
27
 
28
  # --------------------------