Update app.py
Browse files
app.py
CHANGED
|
@@ -24,16 +24,17 @@ model = resnet50(weights=None)
|
|
| 24 |
|
| 25 |
# Replace fc with the multi-layer head that matches checkpoint
|
| 26 |
in_ch = model.fc.in_features
|
| 27 |
-
model.fc = torch.nn.
|
| 28 |
-
model.fc
|
| 29 |
-
model.fc.add_module('
|
| 30 |
-
model.fc.add_module('
|
| 31 |
-
model.fc.add_module('
|
| 32 |
-
model.fc.add_module('
|
| 33 |
-
model.fc.add_module('
|
| 34 |
-
model.fc.add_module('
|
| 35 |
-
model.fc.add_module('
|
| 36 |
-
model.fc.add_module('
|
|
|
|
| 37 |
|
| 38 |
# Load the state dict
|
| 39 |
model.load_state_dict(state_dict, strict=True)
|
|
|
|
| 24 |
|
| 25 |
# Replace fc with the multi-layer head that matches checkpoint
|
| 26 |
in_ch = model.fc.in_features
|
| 27 |
+
model.fc = torch.nn.Linear(in_ch, num_classes)
|
| 28 |
+
# model.fc = torch.nn.Sequential()
|
| 29 |
+
# model.fc.add_module('1', torch.nn.Linear(in_ch, 1024)) # fc.1
|
| 30 |
+
# model.fc.add_module('2', torch.nn.ReLU())
|
| 31 |
+
# model.fc.add_module('3', torch.nn.BatchNorm1d(1024)) # fc.3
|
| 32 |
+
# model.fc.add_module('4', torch.nn.ReLU())
|
| 33 |
+
# model.fc.add_module('5', torch.nn.Linear(1024, 512)) # fc.5
|
| 34 |
+
# model.fc.add_module('6', torch.nn.ReLU())
|
| 35 |
+
# model.fc.add_module('7', torch.nn.BatchNorm1d(512)) # fc.7
|
| 36 |
+
# model.fc.add_module('8', torch.nn.ReLU())
|
| 37 |
+
# model.fc.add_module('9', torch.nn.Linear(512, 196)) # fc.9
|
| 38 |
|
| 39 |
# Load the state dict
|
| 40 |
model.load_state_dict(state_dict, strict=True)
|