KurtHHHHHH commited on
Commit
3891ac3
·
verified ·
1 Parent(s): 5d48d9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -24,16 +24,17 @@ model = resnet50(weights=None)
24
 
25
  # Replace fc with the multi-layer head that matches checkpoint
26
  in_ch = model.fc.in_features
27
- model.fc = torch.nn.Sequential()
28
- model.fc.add_module('1', torch.nn.Linear(in_ch, 1024)) # fc.1
29
- model.fc.add_module('2', torch.nn.ReLU())
30
- model.fc.add_module('3', torch.nn.BatchNorm1d(1024)) # fc.3
31
- model.fc.add_module('4', torch.nn.ReLU())
32
- model.fc.add_module('5', torch.nn.Linear(1024, 512)) # fc.5
33
- model.fc.add_module('6', torch.nn.ReLU())
34
- model.fc.add_module('7', torch.nn.BatchNorm1d(512)) # fc.7
35
- model.fc.add_module('8', torch.nn.ReLU())
36
- model.fc.add_module('9', torch.nn.Linear(512, 196)) # fc.9
 
37
 
38
  # Load the state dict
39
  model.load_state_dict(state_dict, strict=True)
 
24
 
25
  # Replace fc with the multi-layer head that matches checkpoint
26
  in_ch = model.fc.in_features
27
+ model.fc = torch.nn.Linear(in_ch, num_classes)
28
+ # model.fc = torch.nn.Sequential()
29
+ # model.fc.add_module('1', torch.nn.Linear(in_ch, 1024)) # fc.1
30
+ # model.fc.add_module('2', torch.nn.ReLU())
31
+ # model.fc.add_module('3', torch.nn.BatchNorm1d(1024)) # fc.3
32
+ # model.fc.add_module('4', torch.nn.ReLU())
33
+ # model.fc.add_module('5', torch.nn.Linear(1024, 512)) # fc.5
34
+ # model.fc.add_module('6', torch.nn.ReLU())
35
+ # model.fc.add_module('7', torch.nn.BatchNorm1d(512)) # fc.7
36
+ # model.fc.add_module('8', torch.nn.ReLU())
37
+ # model.fc.add_module('9', torch.nn.Linear(512, 196)) # fc.9
38
 
39
  # Load the state dict
40
  model.load_state_dict(state_dict, strict=True)