try to fix mismatch
Browse files
app.py
CHANGED
|
@@ -7,14 +7,15 @@ from torchvision import transforms
|
|
| 7 |
# 1. Define your model class
|
| 8 |
# --------------------------
|
| 9 |
# Make sure this matches the architecture you used to train your model
|
|
|
|
|
|
|
|
|
|
| 10 |
class MyCarClassifier(torch.nn.Module):
|
| 11 |
-
def __init__(self, num_classes=196):
|
| 12 |
super(MyCarClassifier, self).__init__()
|
| 13 |
-
# Example: replace with your actual model architecture
|
| 14 |
-
from torchvision.models import resnet50
|
| 15 |
self.model = resnet50(pretrained=False)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
def forward(self, x):
|
| 19 |
return self.model(x)
|
| 20 |
|
|
|
|
| 7 |
# 1. Define your model class
|
| 8 |
# --------------------------
|
| 9 |
# Make sure this matches the architecture you used to train your model
|
| 10 |
+
|
| 11 |
+
from torchvision.models import resnet50
|
| 12 |
+
|
| 13 |
class MyCarClassifier(torch.nn.Module):
|
| 14 |
+
def __init__(self, num_classes=196):
|
| 15 |
super(MyCarClassifier, self).__init__()
|
|
|
|
|
|
|
| 16 |
self.model = resnet50(pretrained=False)
|
| 17 |
+
in_ch = self.model.fc.in_features
|
| 18 |
+
self.model.fc = torch.nn.Linear(in_ch, num_classes)
|
| 19 |
def forward(self, x):
|
| 20 |
return self.model(x)
|
| 21 |
|