NoelShin commited on
Commit
82d5d16
·
1 Parent(s): b8f8de8

Add application file

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
+ </component>
6
+ </project>
.idea/workspace.xml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="AutoImportSettings">
4
+ <option name="autoReloadType" value="SELECTIVE" />
5
+ </component>
6
+ <component name="ChangeListManager">
7
+ <list default="true" id="9fb9e207-fc4f-4ff3-9adc-3c4c1e67daa7" name="Changes" comment="">
8
+ <change beforePath="$PROJECT_DIR$/README.md" beforeDir="false" afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
9
+ </list>
10
+ <option name="SHOW_DIALOG" value="false" />
11
+ <option name="HIGHLIGHT_CONFLICTS" value="true" />
12
+ <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
13
+ <option name="LAST_RESOLUTION" value="IGNORE" />
14
+ </component>
15
+ <component name="Git.Settings">
16
+ <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
17
+ </component>
18
+ <component name="ProjectId" id="2FJJUIvRiY0OO5Dz2zvs0pNxkhb" />
19
+ <component name="ProjectLevelVcsManager" settingsEditedManually="true" />
20
+ <component name="ProjectViewState">
21
+ <option name="hideEmptyMiddlePackages" value="true" />
22
+ <option name="showLibraryContents" value="true" />
23
+ <option name="showMembers" value="true" />
24
+ </component>
25
+ <component name="PropertiesComponent">
26
+ <property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
27
+ <property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
28
+ <property name="WebServerToolWindowFactoryState" value="false" />
29
+ <property name="last_opened_file_path" value="$PROJECT_DIR$" />
30
+ <property name="node.js.detected.package.eslint" value="true" />
31
+ <property name="node.js.detected.package.tslint" value="true" />
32
+ <property name="node.js.selected.package.eslint" value="(autodetect)" />
33
+ <property name="node.js.selected.package.tslint" value="(autodetect)" />
34
+ <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
35
+ </component>
36
+ <component name="RecentsManager">
37
+ <key name="CopyFile.RECENT_KEYS">
38
+ <recent name="$PROJECT_DIR$" />
39
+ </key>
40
+ </component>
41
+ <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
42
+ <component name="TaskManager">
43
+ <task active="true" id="Default" summary="Default task">
44
+ <changelist id="9fb9e207-fc4f-4ff3-9adc-3c4c1e67daa7" name="Changes" comment="" />
45
+ <created>1664204268713</created>
46
+ <option name="number" value="Default" />
47
+ <option name="presentableId" value="Default" />
48
+ <updated>1664204268713</updated>
49
+ <workItem from="1664204270261" duration="37000" />
50
+ <workItem from="1664204316867" duration="4389000" />
51
+ </task>
52
+ <servers />
53
+ </component>
54
+ <component name="TypeScriptGeneratedFilesManager">
55
+ <option name="version" value="3" />
56
+ </component>
57
+ </project>
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Namedmask
3
- emoji: 🌍
4
- colorFrom: red
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 3.4
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
+ title: namedmask
3
+ emoji: 😷
4
+ colorFrom: gray
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 2.9.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser, Namespace
2
+ from typing import Dict, List, Tuple
3
+ import codecs
4
+ import yaml
5
+ import numpy as np
6
+ import cv2
7
+ from PIL import Image
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torchvision.transforms.functional import to_tensor, normalize, resize
11
+ import gradio as gr
12
+ from utils import get_network, colourise_mask
13
+ import os
14
+
15
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): pretrained NamedMask weights are intentionally not loaded;
# the torch.hub.load_state_dict_from_url call is kept for reference.
# state_dict: dict = torch.hub.load_state_dict_from_url(
#     "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
#     map_location=device
# )["model"]

parser = ArgumentParser("NamedMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="voc_val_n500_cp2_ex.yaml"
)

args: Namespace = parser.parse_args()
# Merge CLI options with the YAML config. Fix: close the config file handle
# (the original used a bare open() whose handle was never closed).
with open(args.config, 'r') as f:
    base_args: dict = yaml.safe_load(f)
base_args.pop("dataset_name")
args: dict = vars(args)
args.update(base_args)
args: Namespace = Namespace(**args)

model = get_network().to(device)
# model.load_state_dict(state_dict)  # weights intentionally not loaded (see above)
model.eval()

# Preprocessing constants: shorter-side resize to 384 (capped at 512) and
# ImageNet normalisation statistics.
size: int = 384
max_size: int = 512
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
45
+
46
+
47
@torch.no_grad()
def main(image: Image.Image) -> np.ndarray:
    """Segment an input image and return the prediction overlaid on it.

    Args:
        image: input PIL image uploaded through the Gradio UI.

    Returns:
        An RGB numpy array: the colourised predicted mask alpha-blended
        (50/50) with the resized input image.
    """
    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
    image_t: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W

    # logits: b (=1) x n_categories x H x W, torch.float32
    logits: torch.Tensor = model(image_t[None].to(device))

    # pred: H x W array of per-pixel category indices.
    # Bug fix: convert to numpy exactly once. The original stored the numpy
    # result in `pred` and then called `pred.cpu().numpy()` a second time,
    # which raises AttributeError (numpy arrays have no .cpu()).
    pred: np.ndarray = logits.squeeze(dim=0).argmax(dim=0).cpu().numpy()
    coloured_pred: np.ndarray = colourise_mask(mask=pred)

    # Blend the coloured mask with the (resized) input image.
    super_imposed_img = cv2.addWeighted(coloured_pred, 0.5, np.array(pil_image), 0.5, 0)
    return super_imposed_img
72
+
73
+
74
# Read the HTML description once at startup. Fix: use a context manager so
# the file handle is closed (the original used codecs.open(...).read()).
with codecs.open("description.html", 'r', "utf-8") as f:
    description: str = f.read()

# Build the Gradio demo. The gr.inputs / gr.outputs namespaces match the
# gradio 2.x API pinned in README.md (sdk_version: 2.9.0).
demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil", source="upload", tool="editor"),
    outputs=gr.outputs.Image(type="numpy", label="prediction"),
    examples=[f"images/{fname}.jpg" for fname in [
        "2007_002260",
        "2008_002536",
        "2008_003499",
        "2008_007814",
        "2009_004801",
        "2010_001079",
        "2010_005063"
    ]],
    examples_per_page=10,
    description=description,
    title="NamedMask: Distilling Segmenters from Complementary Foundation Models",
    allow_flagging="never",
    analytics_enabled=False
)

demo.launch(
    # share=True  # uncomment to expose a public link
)
description.html ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <title>NamedMask Demo</title>
6
+ </head>
7
+ <body>
8
+ This is a demo of <a href="https://arxiv.org/pdf/2209.11228.pdf">NamedMask: Distilling Segmenters from Complementary Foundation Models</a>.<br/>
9
+ The goal of this work is to segment and name regions of images without access to pixel-level labels during training.
10
+ To tackle this task, we construct segmenters by distilling the complementary strengths of two foundation models.
11
+ The first, CLIP (Radford et al. 2021), exhibits the ability to assign names to image content but lacks an accessible representation of object structure.
12
+ The second, DINO (Caron et al. 2021), captures the spatial extent of objects but has no knowledge of object names.
13
+ Our method, termed NamedMask, begins by using CLIP to construct category-specific archives of images.
14
+ These images are pseudo-labelled with a category-agnostic salient object detector bootstrapped from DINO, then refined by category-specific segmenters using the CLIP archive labels.
15
+ Thanks to the high quality of the refined masks, we show that a standard segmentation architecture trained on these archives with appropriate data augmentation achieves impressive semantic segmentation abilities for both single-object and multi-object images.
16
+ As a result, our proposed NamedMask performs favourably against a range of prior work on five benchmarks including the VOC2012, COCO and large-scale ImageNet-S datasets.
17
+ Code is publicly available at <a href="https://github.com/NoelShin/namedmask">our repo</a>.
18
+ </body>
19
+ </html>
images/2007_002260.jpg ADDED
images/2008_002536.jpg ADDED
images/2008_003499.jpg ADDED
images/2008_007814.jpg ADDED
images/2009_004801.jpg ADDED
images/2010_001079.jpg ADDED
images/2010_005063.jpg ADDED
networks/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from networks.modeling import *
2
+ from networks._deeplab import convert_to_separable_conv, set_bn_momentum
networks/_deeplab.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from networks.deeplab.utils import _SimpleSegmentationModel
6
+
7
+ __all__ = ["DeepLabV3"]
8
+
9
+
10
class DeepLabV3(_SimpleSegmentationModel):
    """DeepLabV3 segmentation model.

    From `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_. All behaviour is inherited from
    _SimpleSegmentationModel; this subclass only names the architecture.

    Arguments:
        backbone (nn.Module): feature extractor returning an
            OrderedDict[Tensor] with key "out" for the last feature map used,
            and "aux" when an auxiliary classifier is attached.
        classifier (nn.Module): maps the "out" feature map to a dense
            prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used
            during training.
    """
    pass
25
+
26
+
27
class DeepLabHeadV3Plus(nn.Module):
    """DeepLabV3+ decoder head.

    Fuses a 48-channel projection of a low-level feature map with upsampled
    ASPP output, then classifies the concatenation.
    """

    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
        super(DeepLabHeadV3Plus, self).__init__()
        # 1x1 projection compressing low-level features to 48 channels.
        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )

        self.aspp = ASPP(in_channels, aspp_dilate)

        # 304 = 48 projected low-level channels + 256 ASPP channels.
        self.classifier = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )
        self._init_weight()

    def forward(self, feature):
        # Upsample the ASPP context to the low-level resolution and fuse.
        skip = self.project(feature['low_level'])
        context = self.aspp(feature['out'])
        context = F.interpolate(context, size=skip.shape[2:], mode='bilinear', align_corners=False)
        return self.classifier(torch.cat([skip, context], dim=1))

    def _init_weight(self):
        # Kaiming init for convs, unit-gain init for norm layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
60
+
61
+
62
class DeepLabHead(nn.Module):
    """Plain DeepLabV3 head: ASPP followed by a 3x3 conv classifier."""

    def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]):
        super(DeepLabHead, self).__init__()
        self.classifier = nn.Sequential(
            ASPP(in_channels, aspp_dilate),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1),
        )
        self._init_weight()

    def forward(self, feature):
        # Only the deepest backbone feature map is used by this head.
        return self.classifier(feature['out'])

    def _init_weight(self):
        # Kaiming init for convs, unit-gain init for norm layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
85
+
86
+
87
+ class AtrousSeparableConvolution(nn.Module):
88
+ """ Atrous Separable Convolution
89
+ """
90
+
91
+ def __init__(self, in_channels, out_channels, kernel_size,
92
+ stride=1, padding=0, dilation=1, bias=True):
93
+ super(AtrousSeparableConvolution, self).__init__()
94
+ self.body = nn.Sequential(
95
+ # Separable Conv
96
+ nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding,
97
+ dilation=dilation, bias=bias, groups=in_channels),
98
+ # PointWise Conv
99
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
100
+ )
101
+
102
+ self._init_weight()
103
+
104
+ def forward(self, x):
105
+ return self.body(x)
106
+
107
+ def _init_weight(self):
108
+ for m in self.modules():
109
+ if isinstance(m, nn.Conv2d):
110
+ nn.init.kaiming_normal_(m.weight)
111
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
112
+ nn.init.constant_(m.weight, 1)
113
+ nn.init.constant_(m.bias, 0)
114
+
115
+
116
+ class ASPPConv(nn.Sequential):
117
+ def __init__(self, in_channels, out_channels, dilation):
118
+ modules = [
119
+ nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
120
+ nn.BatchNorm2d(out_channels),
121
+ nn.ReLU(inplace=True)
122
+ ]
123
+ super(ASPPConv, self).__init__(*modules)
124
+
125
+
126
+ class ASPPPooling(nn.Sequential):
127
+ def __init__(self, in_channels, out_channels):
128
+ super(ASPPPooling, self).__init__(
129
+ nn.AdaptiveAvgPool2d(1),
130
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
131
+ nn.BatchNorm2d(out_channels),
132
+ nn.ReLU(inplace=True))
133
+
134
+ def forward(self, x):
135
+ size = x.shape[-2:]
136
+ x = super(ASPPPooling, self).forward(x)
137
+ return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
138
+
139
+
140
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Five parallel branches — a 1x1 conv, three dilated 3x3 convs (one per
    atrous rate), and an image-level pooling branch — whose outputs are
    concatenated and projected back down to 256 channels.
    """

    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256

        # Branch 0: plain 1x1 conv.
        branches = [nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))]

        # Branches 1-3: dilated convs; branch 4: image pooling.
        rate1, rate2, rate3 = tuple(atrous_rates)
        branches.append(ASPPConv(in_channels, out_channels, rate1))
        branches.append(ASPPConv(in_channels, out_channels, rate2))
        branches.append(ASPPConv(in_channels, out_channels, rate3))
        branches.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(branches)

        # Project the 5-branch concatenation back to out_channels.
        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1), )

    def forward(self, x):
        pyramid = [branch(x) for branch in self.convs]
        return self.project(torch.cat(pyramid, dim=1))
170
+
171
+
172
def convert_to_separable_conv(module):
    """Recursively replace every Conv2d with kernel size > 1 by an
    AtrousSeparableConvolution of the same configuration.

    Note: replacements are freshly initialised — the trained weights of the
    original convolutions are not transferred.

    Returns:
        The converted module (a new module where a conv was replaced,
        otherwise the input module with converted children).
    """
    new_module = module
    if isinstance(module, nn.Conv2d) and module.kernel_size[0] > 1:
        # Bug fix: AtrousSeparableConvolution expects a boolean `bias` flag.
        # The original passed `module.bias` (an nn.Parameter or None);
        # truth-testing a multi-element Parameter raises a RuntimeError
        # inside nn.Conv2d, so convs carrying a bias could not be converted.
        new_module = AtrousSeparableConvolution(module.in_channels,
                                                module.out_channels,
                                                module.kernel_size,
                                                module.stride,
                                                module.padding,
                                                module.dilation,
                                                module.bias is not None)
    for name, child in module.named_children():
        new_module.add_module(name, convert_to_separable_conv(child))
    return new_module
185
+
186
+
187
def set_bn_momentum(model, momentum=0.1):
    """Set the running-stats momentum of every BatchNorm2d inside *model*."""
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.momentum = momentum
networks/backbone/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from networks.deeplab.backbone import resnet
2
+ from networks.deeplab.backbone import mobilenetv2
3
+ from networks.deeplab.backbone import hrnetv2
networks/backbone/hrnetv2.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import os
5
+
6
+ __all__ = ['HRNet', 'hrnetv2_48', 'hrnetv2_32']
7
+
8
# Path to the pre-trained HRNetV2-32 backbone checkpoint (edit to your path).
# Download from:
# https://drive.google.com/file/d/1NxCK7Zgn5PmeS7W1jYLt5J9E0RRZ2oyF/view?usp=sharing
# and place under ./checkpoints. Only usable for HRNetv2-32
# (HRNetv2-48 weights are not available yet; train from scratch instead).
CKPT_PATH = './checkpoints/hrnetv2_32_model_best_epoch96.pth'
# Bug fix: the original wrapped this assignment in try/except with a bare
# except, but assigning a string literal can never raise, so the
# "checkpoint missing" branch was dead code. An explicit existence check
# restores the intended behaviour.
if os.path.isfile(CKPT_PATH):
    print(f"Backbone HRNet Pretrained weights at: {CKPT_PATH}, only usable for HRNetv2-32")
else:
    print("No backbone checkpoint found for HRNetv2, please set pretrained=False when calling model")
16
+
17
+ # HRNetv2-48 not available yet, but you can train the whole model from scratch.
18
+
19
+ class Bottleneck(nn.Module):
20
+ expansion = 4
21
+
22
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
23
+ super(Bottleneck, self).__init__()
24
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
25
+ self.bn1 = nn.BatchNorm2d(planes)
26
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
27
+ self.bn2 = nn.BatchNorm2d(planes)
28
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
29
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
30
+ self.relu = nn.ReLU(inplace=True)
31
+ self.downsample = downsample
32
+
33
+ def forward(self, x):
34
+ identity = x
35
+
36
+ out = self.conv1(x)
37
+ out = self.bn1(out)
38
+ out = self.relu(out)
39
+ out = self.conv2(out)
40
+ out = self.bn2(out)
41
+ out = self.relu(out)
42
+ out = self.conv3(out)
43
+ out = self.bn3(out)
44
+
45
+ if self.downsample is not None:
46
+ identity = self.downsample(x)
47
+
48
+ out += identity
49
+ out = self.relu(out)
50
+
51
+ return out
52
+
53
+
54
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 convs, no channel expansion."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Bug fix: conv2 consumes the output of conv1, which has `planes`
        # channels — the original declared it with `inplanes` input channels,
        # which crashes whenever inplanes != planes. (Latent in this repo:
        # StageModule always builds BasicBlock(channels, channels).)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional projection matching the identity to the output shape.
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
82
+
83
+
84
class StageModule(nn.Module):
    """One HRNet stage: parallel multi-resolution branches followed by full
    cross-resolution fusion.

    Args:
        stage: stage index; also the number of parallel branches.
        output_branches: number of fused outputs to produce.
        c: channel width of the highest-resolution branch; branch i carries
            c * 2**i channels at 1/2**i of the base resolution.
    """

    def __init__(self, stage, output_branches, c):
        super(StageModule, self).__init__()

        self.number_of_branches = stage
        self.output_branches = output_branches

        # Per-branch trunk: four BasicBlocks that preserve both resolution
        # and channel width (the paper's x4 basic-block sequence).
        self.branches = nn.ModuleList()
        for branch_idx in range(self.number_of_branches):
            width = c * (2 ** branch_idx)
            self.branches.append(nn.Sequential(*[BasicBlock(width, width) for _ in range(4)]))

        # fuse_layers[o][j] maps branch j's output onto branch o's
        # resolution/width so all branch outputs can be summed into output o.
        self.fuse_layers = nn.ModuleList()
        for out_idx in range(self.output_branches):
            per_output = nn.ModuleList()
            for in_idx in range(self.number_of_branches):
                if in_idx == out_idx:
                    # Identity; an empty Sequential is used because it is callable.
                    per_output.append(nn.Sequential())
                elif in_idx > out_idx:
                    # Lower-resolution input: 1x1 conv to shrink channels,
                    # then nearest-neighbour upsample.
                    per_output.append(nn.Sequential(
                        nn.Conv2d(c * (2 ** in_idx), c * (2 ** out_idx), kernel_size=1, stride=1,
                                  bias=False),
                        nn.BatchNorm2d(c * (2 ** out_idx), eps=1e-05, momentum=0.1, affine=True,
                                       track_running_stats=True),
                        nn.Upsample(scale_factor=(2.0 ** (in_idx - out_idx)), mode='nearest'),
                    ))
                else:
                    # Higher-resolution input: chain of stride-2 3x3 convs;
                    # only the last step changes the channel count.
                    down_steps = []
                    for _ in range(out_idx - in_idx - 1):
                        down_steps.append(nn.Sequential(
                            nn.Conv2d(c * (2 ** in_idx), c * (2 ** in_idx), kernel_size=3, stride=2,
                                      padding=1,
                                      bias=False),
                            nn.BatchNorm2d(c * (2 ** in_idx), eps=1e-05, momentum=0.1, affine=True,
                                           track_running_stats=True),
                            nn.ReLU(inplace=True),
                        ))
                    down_steps.append(nn.Sequential(
                        nn.Conv2d(c * (2 ** in_idx), c * (2 ** out_idx), kernel_size=3,
                                  stride=2, padding=1,
                                  bias=False),
                        nn.BatchNorm2d(c * (2 ** out_idx), eps=1e-05, momentum=0.1, affine=True,
                                       track_running_stats=True),
                    ))
                    per_output.append(nn.Sequential(*down_steps))
            self.fuse_layers.append(per_output)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # x is a list with one tensor per branch.
        x = [branch(branch_input) for branch, branch_input in zip(self.branches, x)]

        # Sum every (resampled) branch into each output branch.
        fused = []
        for out_idx in range(self.output_branches):
            for in_idx in range(self.number_of_branches):
                contribution = self.fuse_layers[out_idx][in_idx](x[in_idx])
                if in_idx == 0:
                    fused.append(contribution)
                else:
                    fused[out_idx] = fused[out_idx] + contribution

        # Final non-linearity on each fused output.
        return [self.relu(t) for t in fused]
162
+
163
+
164
class HRNet(nn.Module):
    """HRNetV2 backbone with a classification head.

    Args:
        c: width of the highest-resolution branch (48 -> HRNetV2-W48,
            32 -> HRNetV2-W32).
        num_blocks: number of StageModules in stages 2, 3 and 4 respectively.
        num_classes: size of the final classification layer.

    The attribute names and module nesting mirror the official pretrained
    checkpoints, so state_dicts remain loadable.
    """

    def __init__(self, c=48, num_blocks=[1, 4, 3], num_classes=1000):
        super(HRNet, self).__init__()

        # Stem: two stride-2 3x3 convs -> 1/4 input resolution, 64 channels.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-05, affine=True, track_running_stats=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, eps=1e-05, affine=True, track_running_stats=True)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: four residual bottlenecks; Bottleneck.expansion grows the
        # width from 64 to 256 channels, hence the 1x1 projection shortcut.
        downsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(256, eps=1e-05, affine=True, track_running_stats=True),
        )
        bn_expansion = Bottleneck.expansion
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),
            Bottleneck(bn_expansion * 64, 64),
            Bottleneck(bn_expansion * 64, 64),
            Bottleneck(bn_expansion * 64, 64),
        )

        # Transition 1: fan out into a full-resolution branch (c channels)
        # and a half-resolution branch (2c channels).
        self.transition1 = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, c, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c, eps=1e-05, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ),
            nn.Sequential(nn.Sequential(  # double Sequential matches official pretrained weight names
                nn.Conv2d(256, c * 2, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c * 2, eps=1e-05, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),
        ])

        # Stages 2-4, each preceded by a transition that derives one extra,
        # lower-resolution branch from the current lowest-resolution one.
        self.stage2 = nn.Sequential(
            *[StageModule(stage=2, output_branches=2, c=c) for _ in range(num_blocks[0])])
        self.transition2 = self._make_transition_layers(c, transition_number=2)

        self.stage3 = nn.Sequential(
            *[StageModule(stage=3, output_branches=3, c=c) for _ in range(num_blocks[1])])
        self.transition3 = self._make_transition_layers(c, transition_number=3)

        self.stage4 = nn.Sequential(
            *[StageModule(stage=4, output_branches=4, c=c) for _ in range(num_blocks[2])])

        # Classification head: concatenate all upsampled branches, reduce
        # channels 4x, pool to 8x8, flatten, classify.
        out_channels = sum([c * 2 ** i for i in range(len(num_blocks) + 1)])  # total HRNetV2 width
        pool_feature_map = 8
        self.bn_classifier = nn.Sequential(
            nn.Conv2d(out_channels, out_channels // 4, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels // 4, eps=1e-05, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(pool_feature_map),
            nn.Flatten(),
            nn.Linear(pool_feature_map * pool_feature_map * (out_channels // 4), num_classes),
        )

    @staticmethod
    def _make_transition_layers(c, transition_number):
        # Stride-2 conv creating the next (half-resolution, double-width) branch.
        return nn.Sequential(
            nn.Conv2d(c * (2 ** (transition_number - 1)), c * (2 ** transition_number), kernel_size=3, stride=2,
                      padding=1, bias=False),
            nn.BatchNorm2d(c * (2 ** transition_number), eps=1e-05, affine=True,
                           track_running_stats=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Stem
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))

        # Stage 1, then split into two parallel branches.
        x = self.layer1(x)
        x = [trans(x) for trans in self.transition1]

        # Stages 2-4; each transition appends a new lowest-resolution branch.
        x = self.stage2(x)
        x.append(self.transition2(x[-1]))

        x = self.stage3(x)
        x.append(self.transition3(x[-1]))

        x = self.stage4(x)

        # HRNetV2 head: bilinearly upsample every branch to the highest
        # resolution and concatenate (rather than summing as in HRNetV1).
        target_size = (x[0].size(2), x[0].size(3))
        upsampled = [x[0]] + [
            F.interpolate(t, size=target_size, mode='bilinear', align_corners=False)
            for t in x[1:]
        ]
        return self.bn_classifier(torch.cat(upsampled, dim=1))
282
+
283
+
284
def _hrnet(arch, channels, num_blocks, pretrained, progress, **kwargs):
    """Build an HRNet and optionally load backbone weights from CKPT_PATH.

    NOTE(review): `arch` and `progress` are currently unused; they are kept
    for API symmetry with torchvision-style builder functions.
    """
    model = HRNet(channels, num_blocks, **kwargs)
    if pretrained:
        print("Loading pretrained backbone HRNetV2 model .....")
        checkpoint = torch.load(CKPT_PATH)
        model.load_state_dict(checkpoint['state_dict'])
    return model
291
+
292
+
293
def hrnetv2_48(pretrained=False, progress=True, number_blocks=[1, 4, 3], **kwargs):
    """HRNetV2-W48: highest-resolution branch is 48 channels wide."""
    return _hrnet('hrnetv2_48', 48, number_blocks, pretrained, progress, **kwargs)
297
+
298
+
299
def hrnetv2_32(pretrained=False, progress=True, number_blocks=[1, 4, 3], **kwargs):
    """HRNetV2-W32: highest-resolution branch is 32 channels wide."""
    return _hrnet('hrnetv2_32', 32, number_blocks, pretrained, progress, **kwargs)
303
+
304
+
305
if __name__ == '__main__':
    # When run directly, the checkpoint lives relative to this file rather
    # than the project root, so CKPT_PATH is overridden here.
    # Bug fix: the original wrapped this assignment in try/except with a bare
    # except, but the assignment cannot raise, so the fallback branch was
    # dead code; an explicit existence check restores the intended behaviour.
    CKPT_PATH = os.path.join(os.path.abspath("."), '../../checkpoints/hrnetv2_32_model_best_epoch96.pth')
    print("--- Running file as MAIN ---")
    if os.path.isfile(CKPT_PATH):
        print(f"Backbone HRNET Pretrained weights as __main__ at: {CKPT_PATH}")
    else:
        print("No backbone checkpoint found for HRNetv2, please set pretrained=False when calling model")

    # Smoke test: build the model and run one forward pass.
    model = hrnetv2_32(pretrained=True)
    # model = hrnetv2_48(pretrained=False)

    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model.to(device)

    dummy_input = torch.ones(1, 3, 768, 768).to(device)
    output = model(dummy_input)
    print(output.shape)

    # Total trainable parameters:
    # print(sum(p.numel() for p in model.parameters() if p.requires_grad))
networks/backbone/mobilenetv2.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ try: # for torchvision<0.4
3
+ from torchvision.models.utils import load_state_dict_from_url
4
+ except: # for torchvision>=0.4
5
+ from torch.hub import load_state_dict_from_url
6
+ import torch.nn.functional as F
7
+
8
+ __all__ = ['MobileNetV2', 'mobilenet_v2']
9
+
10
+
11
+ model_urls = {
12
+ 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
13
+ }
14
+
15
+
16
+ def _make_divisible(v, divisor, min_value=None):
17
+ """
18
+ This function is taken from the original tf repo.
19
+ It ensures that all layers have a channel number that is divisible by 8
20
+ It can be seen here:
21
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
22
+ :param v:
23
+ :param divisor:
24
+ :param min_value:
25
+ :return:
26
+ """
27
+ if min_value is None:
28
+ min_value = divisor
29
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
30
+ # Make sure that round down does not go down by more than 10%.
31
+ if new_v < 0.9 * v:
32
+ new_v += divisor
33
+ return new_v
34
+
35
+
36
class ConvBNReLU(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU6 with *no* internal padding.

    Padding is intentionally 0: callers apply explicit TF-style fixed padding
    (see ``fixed_padding``) before the convolution.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1):
        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0,
                         dilation=dilation, groups=groups, bias=False)
        super(ConvBNReLU, self).__init__(
            conv,
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True),
        )
44
+
45
def fixed_padding(kernel_size, dilation):
    """Compute TF-style 'SAME' padding for a dilated convolution.

    Returns a 4-tuple (left, right, top, bottom) suitable for ``F.pad`` so
    that a stride-1 convolution keeps the spatial size.
    """
    effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    total = effective - 1
    before = total // 2
    after = total - before
    return (before, after, before, after)
51
+
52
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block (expand -> depthwise -> project).

    The input is padded explicitly (TF 'SAME' style) because ConvBNReLU uses
    padding=0; the skip connection is taken only when stride is 1 and the
    channel count is unchanged.
    """

    def __init__(self, inp, oup, stride, dilation, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        # Channels of the expanded (hidden) representation.
        hidden_dim = int(round(inp * expand_ratio))
        # Residual add is only valid when the shape is preserved.
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw: 1x1 pointwise expansion (skipped when expand_ratio == 1)
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))

        layers.extend([
            # dw: 3x3 depthwise conv (groups == channels), carries stride/dilation
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim),
            # pw-linear: 1x1 projection with no activation (linear bottleneck)
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

        # Padding sized for the 3x3 depthwise conv at this dilation.
        self.input_padding = fixed_padding( 3, dilation )

    def forward(self, x):
        # Pad first (ConvBNReLU itself pads with 0), then run the branch.
        x_pad = F.pad(x, self.input_padding)
        if self.use_res_connect:
            # Note: the skip uses the *unpadded* input, matching shapes at stride 1.
            return x + self.conv(x_pad)
        else:
            return self.conv(x_pad)
83
+
84
class MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
        """
        MobileNet V2 main class
        Args:
            num_classes (int): Number of classes
            output_stride (int): total downsampling factor; once reached,
                further stages use dilation instead of stride
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        self.output_stride = output_stride
        current_stride = 1
        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s  (expansion, out channels, repeats, first stride)
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        current_stride *= 2
        dilation=1
        previous_dilation = 1

        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            previous_dilation = dilation
            if current_stride == output_stride:
                # Target stride reached: keep resolution, grow dilation instead.
                stride = 1
                dilation *= s
            else:
                stride = s
                current_stride *= s
            # NOTE(review): this overrides the _make_divisible result computed
            # just above — presumably intentional to match existing pretrained
            # checkpoints; confirm before "fixing".
            output_channel = int(c * width_mult)

            for i in range(n):
                if i==0:
                    # First block of the stage carries the stride and the
                    # dilation that was in effect *before* this stage.
                    features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t))
                else:
                    features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Features -> global average pool over H,W -> linear classifier."""
        x = self.features(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x
173
+
174
+
175
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    net = MobileNetV2(**kwargs)
    if pretrained:
        weights = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                           progress=progress)
        net.load_state_dict(weights)
    return net
networks/backbone/resnet.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ try: # for torchvision<0.4
4
+ from torchvision.models.utils import load_state_dict_from_url
5
+ except: # for torchvision>=0.4
6
+ from torch.hub import load_state_dict_from_url
7
+
8
+
9
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
11
+ 'wide_resnet50_2', 'wide_resnet101_2']
12
+
13
+
14
+ model_urls = {
15
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
16
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
17
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
18
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
19
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
20
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
21
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
22
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
23
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
24
+ }
25
+
26
+
27
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution; padding equals dilation so size is kept at stride 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
31
+
32
+
33
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (channel projection, no padding, no bias)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
36
+
37
+
38
class BasicBlock(nn.Module):
    """Two-conv residual block used by ResNet-18/34 (channel expansion 1)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """conv-bn-relu, conv-bn, add the (optionally downsampled) skip, relu."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            # Project the skip so shapes match for the addition.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
76
+
77
+
78
class Bottleneck(nn.Module):
    """1x1 reduce -> 3x3 -> 1x1 expand residual block (expansion 4),
    used by ResNet-50 and deeper variants."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Inner width scales with base_width and groups (ResNeXt / wide variants).
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Three conv-bn stages, skip added before the final relu."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            # Project the skip so shapes match for the addition.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
119
+
120
+
121
class ResNet(nn.Module):
    """Torchvision-style ResNet.

    ``replace_stride_with_dilation`` lets layers 2-4 trade stride for dilation,
    which is how the DeepLab builders in this repo control output stride.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max pool (overall stride 4).
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of `blocks` blocks.

        When `dilate` is True the stage's stride is converted into dilation,
        keeping spatial resolution constant.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection on the skip path when shape changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Stem -> 4 stages -> global average pool -> fully-connected logits."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
214
+
215
+
216
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ResNet variant and optionally load its ImageNet weights."""
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(state_dict)
    return net
223
+
224
+
225
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, progress, **kwargs)
234
+
235
+
236
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, progress, **kwargs)
245
+
246
+
247
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, progress, **kwargs)
256
+
257
+
258
def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, stage_depths, pretrained, progress, **kwargs)
267
+
268
+
269
def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_depths = [3, 8, 36, 3]
    return _resnet('resnet152', Bottleneck, stage_depths, pretrained, progress, **kwargs)
278
+
279
+
280
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d from `"Aggregated Residual Transformation for Deep
    Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Cardinality 32, bottleneck base width 4.
    kwargs.update(groups=32, width_per_group=4)
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)
291
+
292
+
293
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d from `"Aggregated Residual Transformation for Deep
    Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Cardinality 32, bottleneck base width 8.
    kwargs.update(groups=32, width_per_group=8)
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
304
+
305
+
306
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 from `"Wide Residual Networks"
    <https://arxiv.org/pdf/1605.07146.pdf>`_.

    Identical to ResNet-50 except that every bottleneck's inner 3x3 conv is
    twice as wide (outer 1x1 widths are unchanged: e.g. the last block is
    2048-1024-2048 instead of 2048-512-2048).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(width_per_group=64 * 2)
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)
320
+
321
+
322
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 from `"Wide Residual Networks"
    <https://arxiv.org/pdf/1605.07146.pdf>`_.

    Identical to ResNet-101 except that every bottleneck's inner 3x3 conv is
    twice as wide (outer 1x1 widths are unchanged).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(width_per_group=64 * 2)
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
networks/modeling.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from networks.deeplab.utils import IntermediateLayerGetter
2
+ from networks.deeplab._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3
3
+ from networks.deeplab.backbone import resnet, mobilenetv2, hrnetv2
4
+
5
+
6
def _segm_hrnet(name, backbone_name, num_classes, pretrained_backbone):
    """Assemble a DeepLab head ('deeplabv3' or 'deeplabv3plus') on an HRNetV2
    backbone. `backbone_name` is e.g. 'hrnetv2_32' or 'hrnetv2_48'."""
    backbone = hrnetv2.__dict__[backbone_name](pretrained_backbone)
    # HRNetV2 config:
    # the final output channels is dependent on highest resolution channel config (c).
    # output of backbone will be the inplanes to assp:
    hrnet_channels = int(backbone_name.split('_')[-1])
    # Concatenation of the four streams: c + 2c + 4c + 8c.
    inplanes = sum([hrnet_channels * 2 ** i for i in range(4)])
    low_level_planes = 256 # all hrnet version channel output from bottleneck is the same
    aspp_dilate = [12, 24, 36] # If follow paper trend, can put [24, 48, 72].

    # NOTE(review): if `name` is neither branch below, `classifier` (and for
    # deeplabv3 also `return_layers`) is left unbound and the code raises
    # NameError rather than a clear error message.
    if name == 'deeplabv3plus':
        return_layers = {'stage4': 'out', 'layer1': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    elif name == 'deeplabv3':
        return_layers = {'stage4': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)

    # hrnet_flag=True makes the getter handle HRNet's multi-stream transitions.
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers, hrnet_flag=True)
    model = DeepLabV3(backbone, classifier)
    return model
26
+
27
+
28
def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
    """Assemble a DeepLab head ('deeplabv3' or 'deeplabv3plus') on a ResNet
    backbone, using dilated convs to reach the requested output stride."""
    if output_stride == 8:
        # Dilate layers 3 and 4 -> overall stride 8; larger ASPP rates.
        replace_stride_with_dilation = [False, True, True]
        aspp_dilate = [12, 24, 36]
    else:
        # Dilate layer 4 only -> overall stride 16.
        replace_stride_with_dilation = [False, False, True]
        aspp_dilate = [6, 12, 18]

    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=replace_stride_with_dilation)

    # ResNet bottleneck output widths: layer4 = 2048, layer1 = 256.
    inplanes = 2048
    low_level_planes = 256

    if name == 'deeplabv3plus':
        return_layers = {'layer4': 'out', 'layer1': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    elif name == 'deeplabv3':
        return_layers = {'layer4': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    model = DeepLabV3(backbone, classifier)
    return model
53
+
54
+
55
def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
    """Assemble a DeepLab head ('deeplabv3' or 'deeplabv3plus') on a
    MobileNetV2 backbone split into low-/high-level feature stages."""
    if output_stride == 8:
        aspp_dilate = [12, 24, 36]
    else:
        aspp_dilate = [6, 12, 18]

    backbone = mobilenetv2.mobilenet_v2(pretrained=pretrained_backbone, output_stride=output_stride)

    # rename layers: split the feature stack and drop the unused pieces.
    backbone.low_level_features = backbone.features[0:4]
    backbone.high_level_features = backbone.features[4:-1]  # drops the final 1x1 ConvBNReLU
    backbone.features = None
    backbone.classifier = None

    # Channel widths at the split points of the default MobileNetV2 config.
    inplanes = 320
    low_level_planes = 24

    if name == 'deeplabv3plus':
        return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'}
        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
    elif name == 'deeplabv3':
        return_layers = {'high_level_features': 'out'}
        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    model = DeepLabV3(backbone, classifier)
    return model
82
+
83
+
84
def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
    """Dispatch to the backbone-specific segmentation model builder.

    `arch_type` is 'deeplabv3' or 'deeplabv3plus'; `backbone` selects the
    builder ('mobilenetv2', 'resnet*', or 'hrnetv2_*').
    """
    if backbone == 'mobilenetv2':
        return _segm_mobilenet(arch_type, backbone, num_classes, output_stride=output_stride,
                               pretrained_backbone=pretrained_backbone)
    if backbone.startswith('resnet'):
        return _segm_resnet(arch_type, backbone, num_classes, output_stride=output_stride,
                            pretrained_backbone=pretrained_backbone)
    if backbone.startswith('hrnetv2'):
        # HRNet's output stride is fixed by its architecture, so it is not passed.
        return _segm_hrnet(arch_type, backbone, num_classes, pretrained_backbone=pretrained_backbone)
    raise NotImplementedError
96
+
97
+
98
+ # Deeplab v3
99
def deeplabv3_hrnetv2_48(num_classes=21, output_stride=4, pretrained_backbone=False): # no pretrained backbone yet
    """Constructs a DeepLabV3 model with an HRNetV2-48 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): kept for API symmetry (HRNet ignores it).
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    # Bug fix: num_classes and output_stride were previously passed to
    # _load_model in swapped positions (its signature is
    # (arch_type, backbone, num_classes, output_stride, ...)).
    return _load_model('deeplabv3', 'hrnetv2_48', num_classes, output_stride,
                       pretrained_backbone=pretrained_backbone)
101
+
102
+
103
def deeplabv3_hrnetv2_32(num_classes=21, output_stride=4, pretrained_backbone=True):
    """Constructs a DeepLabV3 model with an HRNetV2-32 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): kept for API symmetry (HRNet ignores it).
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    # Bug fix: num_classes and output_stride were previously passed to
    # _load_model in swapped positions (its signature is
    # (arch_type, backbone, num_classes, output_stride, ...)).
    return _load_model('deeplabv3', 'hrnetv2_32', num_classes, output_stride,
                       pretrained_backbone=pretrained_backbone)
105
+
106
+
107
def deeplabv3_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Constructs a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3', 'resnet50', num_classes,
                       output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)
116
+
117
+
118
def deeplabv3_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Constructs a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3', 'resnet101', num_classes,
                       output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)
127
+
128
+
129
def deeplabv3_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True, **kwargs):
    """Constructs a DeepLabV3 model with a MobileNetv2 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.

    Note: extra keyword arguments are accepted but not forwarded anywhere.
    """
    return _load_model('deeplabv3', 'mobilenetv2', num_classes,
                       output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)
138
+
139
+
140
+ # Deeplab v3+
141
def deeplabv3plus_hrnetv2_48(num_classes=21, output_stride=4, pretrained_backbone=False): # no pretrained backbone yet
    """Constructs a DeepLabV3+ model with an HRNetV2-48 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): kept for API symmetry (HRNet ignores it).
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3plus', 'hrnetv2_48',
                       num_classes, output_stride,
                       pretrained_backbone=pretrained_backbone)
144
+
145
+
146
def deeplabv3plus_hrnetv2_32(num_classes=21, output_stride=4, pretrained_backbone=True):
    """Constructs a DeepLabV3+ model with an HRNetV2-32 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): kept for API symmetry (HRNet ignores it).
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3plus', 'hrnetv2_32',
                       num_classes, output_stride,
                       pretrained_backbone=pretrained_backbone)
149
+
150
+
151
def deeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Constructs a DeepLabV3+ model with a ResNet-50 backbone.

    (The previous docstring said "DeepLabV3"; this builder produces V3+.)

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3plus', 'resnet50', num_classes,
                       output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)
160
+
161
+
162
def deeplabv3plus_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Constructs a DeepLabV3+ model with a ResNet-101 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3plus', 'resnet101', num_classes,
                       output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)
171
+
172
+
173
def deeplabv3plus_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True):
    """Constructs a DeepLabV3+ model with a MobileNetv2 backbone.

    Args:
        num_classes (int): number of classes.
        output_stride (int): output stride for deeplab.
        pretrained_backbone (bool): If True, use the pretrained backbone.
    """
    return _load_model('deeplabv3plus', 'mobilenetv2', num_classes,
                       output_stride=output_stride,
                       pretrained_backbone=pretrained_backbone)
networks/utils.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ from collections import OrderedDict
6
+
7
+
8
+ class _SimpleSegmentationModel(nn.Module):
9
+ def __init__(self, backbone, classifier):
10
+ super(_SimpleSegmentationModel, self).__init__()
11
+ self.backbone = backbone
12
+ self.classifier = classifier
13
+
14
+ def forward(self, x):
15
+ input_shape = x.shape[-2:]
16
+ features = self.backbone(x)
17
+ x = self.classifier(features)
18
+ x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
19
+ return x
20
+
21
+
22
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model
    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.
    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.
    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        hrnet_flag (bool): if True, handle HRNet's multi-stream 'transition'
            modules and concatenate the four streams at 'stage4'.
    Examples::
        >>> m = torchvision.models.resnet18(pretrained=True)
        >>> # extract layer1 and layer3, giving as names `feat1` and feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>> [('feat1', torch.Size([1, 64, 56, 56])),
        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    def __init__(self, model, return_layers, hrnet_flag=False):
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")

        # NOTE(review): assigned before super().__init__(); fine for a plain
        # bool, but any nn.Module/Parameter value would have to come after.
        self.hrnet_flag = hrnet_flag

        orig_return_layers = return_layers
        return_layers = {k: v for k, v in return_layers.items()}
        layers = OrderedDict()
        # Keep children (in registration order) only up to the last requested layer.
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.named_children():
            if self.hrnet_flag and name.startswith('transition'): # if using hrnet, you need to take care of transition
                if name == 'transition1': # in transition1, you need to split the module to two streams first
                    # x becomes a list of per-resolution streams from here on.
                    x = [trans(x) for trans in module]
                else: # all other transition is just an extra one stream split
                    x.append(module(x[-1]))
            else: # other models (ex:resnet,mobilenet) are convolutions in series.
                x = module(x)

            if name in self.return_layers:
                out_name = self.return_layers[name]
                if name == 'stage4' and self.hrnet_flag: # In HRNetV2, we upsample and concat all outputs streams together
                    output_h, output_w = x[0].size(2), x[0].size(3) # Upsample to size of highest resolution stream
                    x1 = F.interpolate(x[1], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x2 = F.interpolate(x[2], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x3 = F.interpolate(x[3], size=(output_h, output_w), mode='bilinear', align_corners=False)
                    x = torch.cat([x[0], x1, x2, x3], dim=1)
                    out[out_name] = x
                else:
                    out[out_name] = x
        return out
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ opencv-contrib-python==4.5.5.62
2
+ torch==1.11.0
3
+ torchvision==0.12.0
4
+ timm==0.4.12
5
+ scipy==1.6.2
utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple, Union
2
+ import numpy as np
3
+ import torch
4
+ from networks import deeplabv3plus_resnet50
5
+ from networks import convert_to_separable_conv, set_bn_momentum
6
+
7
+
8
def get_network() -> torch.nn.Module:
    """Build the NamedMask DeepLabv3+/ResNet-50 segmenter with released weights.

    Returns:
        A ``deeplabv3plus_resnet50`` network (21 VOC classes, no pretrained
        backbone) with the published NamedMask checkpoint loaded into its
        backbone, separable convolutions in the classifier head, and backbone
        BatchNorm momentum set to 0.01.
    """
    network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
    state_dict = torch.hub.load_state_dict_from_url(
        "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
        # fix: deserialise on CPU so CPU-only hosts (e.g. demo Spaces) don't
        # fail when the checkpoint was saved from CUDA tensors.
        map_location="cpu",
    )
    # NOTE(review): the whole checkpoint is loaded into the *backbone* only,
    # with strict=True — confirm the released file contains backbone keys only.
    network.backbone.load_state_dict(state_dict, strict=True)
    convert_to_separable_conv(network.classifier)
    set_bn_momentum(network.backbone, momentum=0.01)
    return network
17
+
18
+
19
# Standard PASCAL VOC 2012 colour palette (label id -> RGB); 255 is "ignore".
_VOC2012_PALETTE: Dict[int, List[int]] = {
    0: [0, 0, 0],
    1: [128, 0, 0],
    2: [0, 128, 0],
    3: [128, 128, 0],
    4: [0, 0, 128],
    5: [128, 0, 128],
    6: [0, 128, 128],
    7: [128, 128, 128],
    8: [64, 0, 0],
    9: [192, 0, 0],
    10: [64, 128, 0],
    11: [192, 128, 0],
    12: [64, 0, 128],
    13: [192, 0, 128],
    14: [64, 128, 128],
    15: [192, 128, 128],
    16: [0, 64, 0],
    17: [128, 64, 0],
    18: [0, 192, 0],
    19: [128, 192, 0],
    20: [0, 64, 128],
    255: [255, 255, 255]
}


def colourise_mask(
    mask: np.ndarray,
) -> np.ndarray:
    """Convert a 2D label mask into an RGB image using the VOC 2012 palette.

    Args:
        mask: (H, W) integer array of VOC label ids (0-20, or 255 for ignore).

    Returns:
        (H, W, 3) uint8 RGB array with each pixel coloured by its label.

    Raises:
        ValueError: if *mask* is not two-dimensional.
        KeyError: if the mask contains a label id not in the palette.
    """
    # fix: was `assert ..., ValueError(...)` — asserts are stripped under -O
    # and raise AssertionError, while the original clearly intended ValueError.
    if mask.ndim != 2:
        raise ValueError(f"Expected a 2D mask, got shape {mask.shape}")
    h, w = mask.shape
    grid = np.zeros((h, w, 3), dtype=np.uint8)

    for label in np.unique(mask):
        # fix: the original assigned once OUTSIDE the try (so an unknown label
        # escaped unguarded) and caught IndexError, but a missing dict key
        # raises KeyError — the handler could never fire.
        try:
            colour = _VOC2012_PALETTE[int(label)]
        except KeyError:
            raise KeyError(f"No colour is found for a label id: {label}") from None
        grid[mask == label] = np.array(colour)
    return grid
voc_val_n500_cp2_ex.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # base directories
2
+ category_to_p_images_fp: "/home/cs-shin1/datasets/ImageNet2012/voc2012_category_to_p_images_n500.json"
3
+ dir_ckpt: "/home/cs-shin1/namedmask/ckpt"
4
+ dir_train_dataset: "/home/cs-shin1/datasets/ImageNet2012"
5
+ dir_val_dataset: "/home/cs-shin1/datasets/VOCdevkit/VOC2012"
6
+
7
+ # augmentations
8
+ max_n_masks: 2
9
+ scale_range: [ 0.1, 1.0 ]
10
+
11
+ use_expert_pseudo_masks: true
12
+ category_agnostic: false
13
+
14
+ n_categories: 21
15
+ categories: [
16
+ "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "dining table",
17
+ "dog", "horse", "motorbike", "person", "potted plant", "sheep", "sofa", "train", "tv/monitor"
18
+ ]
19
+ n_images: 500
20
+
21
+ # dataset
22
+ dataset_name: "voc2012"
23
+ split: "val"
24
+ train_image_size: 384
25
+
26
+ # dataloader:
27
+ train_dataloader_kwargs:
28
+ batch_size: 16
29
+ num_workers: 16
30
+ pin_memory: true
31
+ shuffle: true
32
+
33
+ val_dataloader_kwargs:
34
+ batch_size: 1
35
+ num_workers: 4
36
+ pin_memory: true
37
+
38
+ # Segmenter configuration
39
+ # ["deeplabv3plus_resnet101", "deeplabv3plus_resnet50", "deeplabv3plus_mobilenet"]
40
+ segmenter_name: "deeplabv3plus_resnet50"
41
+
42
+ # optimiser
43
+ lr: 0.0005
44
+ momentum: 0.9
45
+ weight_decay: 0.0002
46
+ betas: [0.9, 0.999]
47
+ n_iters: 20000
48
+
49
+ iter_eval: 1000
50
+ iter_log: 100