fengruilin commited on
Commit
4e3b057
·
verified ·
1 Parent(s): bbf5407

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -23,15 +23,16 @@ model = resnet50(weights=None)
23
 
24
  # Replace fc with the multi-layer head that matches checkpoint
25
  in_ch = model.fc.in_features
26
- model.fc = torch.nn.Sequential(
27
- torch.nn.Linear(in_ch, 1024), # fc.1
28
- torch.nn.ReLU(),
29
- torch.nn.BatchNorm1d(1024), # fc.3
30
- torch.nn.Linear(1024, 512), # fc.5
31
- torch.nn.ReLU(),
32
- torch.nn.BatchNorm1d(512), # fc.7
33
- torch.nn.Linear(512, 196) # fc.9
34
- )
 
35
 
36
  # Load the state dict
37
  model.load_state_dict(state_dict, strict=True)
 
23
 
24
  # Replace fc with the multi-layer head that matches checkpoint
25
  in_ch = model.fc.in_features
26
+ model.fc = torch.nn.Sequential()
27
+ model.fc.add_module('1', torch.nn.Linear(in_ch, 1024)) # fc.1
28
+ model.fc.add_module('2', torch.nn.ReLU())
29
+ model.fc.add_module('3', torch.nn.BatchNorm1d(1024)) # fc.3
30
+ model.fc.add_module('4', torch.nn.ReLU())
31
+ model.fc.add_module('5', torch.nn.Linear(1024, 512)) # fc.5
32
+ model.fc.add_module('6', torch.nn.ReLU())
33
+ model.fc.add_module('7', torch.nn.BatchNorm1d(512)) # fc.7
34
+ model.fc.add_module('8', torch.nn.ReLU())
35
+ model.fc.add_module('9', torch.nn.Linear(512, 196)) # fc.9
36
 
37
  # Load the state dict
38
  model.load_state_dict(state_dict, strict=True)