chengan-jaime commited on
Commit
dfdb0da
·
1 Parent(s): ca0e2d5

code update

Browse files
Files changed (1) hide show
  1. model_loader.py +24 -0
model_loader.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision
4
+
5
+ def build_SurgFM(nclasses: int = 2, pretrained: bool = True, pretrained_weights = None):
6
+
7
+
8
+ #net of ConvNext
9
+ net = torchvision.models.convnext_large()
10
+ input_emdim = net.classifier[2].in_features
11
+ net.classifier[2] = nn.Identity()
12
+
13
+ if os.path.isfile(pretrained_weights):
14
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
15
+ state_dict = state_dict['teacher']
16
+
17
+ # remove `backbone.` prefix induced by multicrop wrapper
18
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items() if k.startswith('backbone.')}
19
+ msg = net.load_state_dict(state_dict, strict=False)
20
+ print(msg, input_emdim)
21
+
22
+ net.cuda()
23
+
24
+ return net