fengruilin commited on
Commit
5dc2dc5
·
verified ·
1 Parent(s): 9bed718

try to fix mismatch

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -7,14 +7,15 @@ from torchvision import transforms
7
  # 1. Define your model class
8
  # --------------------------
9
  # Make sure this matches the architecture you used to train your model
 
 
 
10
  class MyCarClassifier(torch.nn.Module):
11
- def __init__(self, num_classes=196): # Stanford Cars has 196 classes
12
  super(MyCarClassifier, self).__init__()
13
- # Example: replace with your actual model architecture
14
- from torchvision.models import resnet50
15
  self.model = resnet50(pretrained=False)
16
- self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
17
-
18
  def forward(self, x):
19
  return self.model(x)
20
 
 
7
  # 1. Define your model class
8
  # --------------------------
9
  # Make sure this matches the architecture you used to train your model
10
+
11
+ from torchvision.models import resnet50
12
+
13
  class MyCarClassifier(torch.nn.Module):
14
+ def __init__(self, num_classes=196):
15
  super(MyCarClassifier, self).__init__()
 
 
16
  self.model = resnet50(pretrained=False)
17
+ in_ch = self.model.fc.in_features
18
+ self.model.fc = torch.nn.Linear(in_ch, num_classes)
19
  def forward(self, x):
20
  return self.model(x)
21