Spaces:
Sleeping
Sleeping
Update core/networks.py
Browse files- core/networks.py +33 -33
core/networks.py
CHANGED
|
@@ -34,45 +34,45 @@ def group_norm(features):
|
|
| 34 |
#######################################################################
|
| 35 |
|
| 36 |
class Backbone(nn.Module, ABC_Model):
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
-
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
| 56 |
else:
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
self.stage3 = nn.Sequential(self.model.layer2)
|
| 74 |
-
self.stage4 = nn.Sequential(self.model.layer3)
|
| 75 |
-
self.stage5 = nn.Sequential(self.model.layer4)
|
| 76 |
|
| 77 |
class Classifier(Backbone):
|
| 78 |
def __init__(self, model_name, state_path, num_classes=20, mode='fix'):
|
|
|
|
| 34 |
#######################################################################
|
| 35 |
|
| 36 |
class Backbone(nn.Module, ABC_Model):
|
| 37 |
+
def __init__(self, model_name, num_classes=20, mode='fix', segmentation=False):
|
| 38 |
+
super().__init__()
|
| 39 |
|
| 40 |
+
self.mode = mode
|
| 41 |
|
| 42 |
+
if self.mode == 'fix':
|
| 43 |
+
self.norm_fn = FixedBatchNorm
|
| 44 |
+
else:
|
| 45 |
+
self.norm_fn = nn.BatchNorm2d
|
| 46 |
|
| 47 |
+
if 'resnet' in model_name:
|
| 48 |
+
self.model = resnet.ResNet(resnet.Bottleneck, resnet.layers_dic[model_name], strides=(2, 2, 2, 1),
|
| 49 |
+
batch_norm_fn=self.norm_fn)
|
| 50 |
|
| 51 |
+
state_dict = model_zoo.load_url(resnet.urls_dic[model_name])
|
| 52 |
+
state_dict.pop('fc.weight')
|
| 53 |
+
state_dict.pop('fc.bias')
|
| 54 |
|
| 55 |
+
self.model.load_state_dict(state_dict)
|
| 56 |
+
else:
|
| 57 |
+
if segmentation:
|
| 58 |
+
dilation, dilated = 4, True
|
| 59 |
else:
|
| 60 |
+
dilation, dilated = 2, False
|
| 61 |
+
|
| 62 |
+
self.model = eval("resnest." + model_name)(pretrained=True, dilated=dilated, dilation=dilation,
|
| 63 |
+
norm_layer=self.norm_fn)
|
| 64 |
+
|
| 65 |
+
del self.model.avgpool
|
| 66 |
+
del self.model.fc
|
| 67 |
+
|
| 68 |
+
self.stage1 = nn.Sequential(self.model.conv1,
|
| 69 |
+
self.model.bn1,
|
| 70 |
+
self.model.relu,
|
| 71 |
+
self.model.maxpool)
|
| 72 |
+
self.stage2 = nn.Sequential(self.model.layer1)
|
| 73 |
+
self.stage3 = nn.Sequential(self.model.layer2)
|
| 74 |
+
self.stage4 = nn.Sequential(self.model.layer3)
|
| 75 |
+
self.stage5 = nn.Sequential(self.model.layer4)
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
class Classifier(Backbone):
|
| 78 |
def __init__(self, model_name, state_path, num_classes=20, mode='fix'):
|