fix resnet mismatch
Browse files
app.py
CHANGED
|
@@ -11,8 +11,8 @@ class MyCarClassifier(torch.nn.Module):
|
|
| 11 |
def __init__(self, num_classes=196): # Stanford Cars has 196 classes
|
| 12 |
super(MyCarClassifier, self).__init__()
|
| 13 |
# Example: replace with your actual model architecture
|
| 14 |
-
from torchvision.models import
|
| 15 |
-
self.model =
|
| 16 |
self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
|
| 17 |
|
| 18 |
def forward(self, x):
|
|
@@ -22,7 +22,7 @@ class MyCarClassifier(torch.nn.Module):
|
|
| 22 |
# 2. Load model weights
|
| 23 |
# --------------------------
|
| 24 |
model = MyCarClassifier()
|
| 25 |
-
model.load_state_dict(torch.load("best_stanford_cars_transfer_model.pth", map_location="cpu"), strict=
|
| 26 |
model.eval() # important for inference
|
| 27 |
|
| 28 |
# --------------------------
|
|
|
|
| 11 |
def __init__(self, num_classes=196): # Stanford Cars has 196 classes
|
| 12 |
super(MyCarClassifier, self).__init__()
|
| 13 |
# Example: replace with your actual model architecture
|
| 14 |
+
from torchvision.models import resnet50
|
| 15 |
+
self.model = resnet50(pretrained=False)
|
| 16 |
self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
|
| 17 |
|
| 18 |
def forward(self, x):
|
|
|
|
| 22 |
# 2. Load model weights
|
| 23 |
# --------------------------
|
| 24 |
model = MyCarClassifier()
|
| 25 |
+
model.load_state_dict(torch.load("best_stanford_cars_transfer_model.pth", map_location="cpu"), strict=True)
|
| 26 |
model.eval() # important for inference
|
| 27 |
|
| 28 |
# --------------------------
|