# NOTE(review): removed copy/paste artifacts ("Spaces:", "Runtime error" x2)
# that were not Python and would break the module at import time.
| #!/usr/bin/env python3 | |
| # This file is covered by the LICENSE file in the root of this project. | |
| from collections import OrderedDict | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class BasicBlock(nn.Module):
    """Darknet residual block.

    A 1x1 conv to ``planes[0]`` channels followed by a 3x3 conv back to
    ``planes[1]`` channels, each with BatchNorm + LeakyReLU(0.1), plus an
    additive identity skip. ``planes[1]`` must equal ``inplanes`` for the
    skip addition to be shape-compatible.
    """

    def __init__(self, inplanes, planes, bn_d=0.1):
        super(BasicBlock, self).__init__()
        # 1x1 conv changes channel count to planes[0]
        self.conv1 = nn.Conv2d(inplanes, planes[0], kernel_size=1,
                               stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(planes[0], momentum=bn_d)
        self.relu1 = nn.LeakyReLU(0.1)
        # 3x3 conv restores planes[1] channels, spatial size preserved
        self.conv2 = nn.Conv2d(planes[0], planes[1], kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes[1], momentum=bn_d)
        self.relu2 = nn.LeakyReLU(0.1)

    def forward(self, x):
        shortcut = x
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        # identity residual connection
        out += shortcut
        return out
# ******************************************************************************
# number of layers per model
# Maps the darknet variant depth (21 or 53) to the number of residual
# BasicBlocks in each of the five encoder stages.
model_blocks = {
    21: [1, 1, 2, 2, 1],
    53: [1, 2, 8, 8, 4],
}
class Backbone(nn.Module):
    """
    Class for DarknetSeg. Subclasses PyTorch's own "nn" module.

    Darknet-style encoder. Downsampling happens along the width axis only
    (the encoder convs use stride [1, s]), so the requested output stride
    ``OS`` refers to width reduction — presumably because the input is a
    short-and-wide LiDAR range image (TODO confirm against caller).
    """

    def __init__(self, params):
        """Build the encoder.

        Expected ``params`` keys: ``input_depth.{range,xyz,remission}``
        (bools selecting input channels), ``dropout``, ``bn_d`` (BatchNorm
        momentum), ``OS`` (desired output stride), ``extra.layers``
        (21 or 53, selecting the darknet variant).
        """
        super(Backbone, self).__init__()
        self.use_range = params["input_depth"]["range"]
        self.use_xyz = params["input_depth"]["xyz"]
        self.use_remission = params["input_depth"]["remission"]
        self.drop_prob = params["dropout"]
        self.bn_d = params["bn_d"]
        self.OS = params["OS"]
        self.layers = params["extra"]["layers"]

        # input depth calc: channel 0 = range, 1-3 = xyz, 4 = remission
        self.input_depth = 0
        self.input_idxs = []
        if self.use_range:
            self.input_depth += 1
            self.input_idxs.append(0)
        if self.use_xyz:
            self.input_depth += 3
            self.input_idxs.extend([1, 2, 3])
        if self.use_remission:
            self.input_depth += 1
            self.input_idxs.append(4)

        # nominal per-stage width strides
        self.strides = [2, 2, 2, 2, 2]
        # output stride the full stack would produce
        current_os = 1
        for s in self.strides:
            current_os *= s

        if self.OS > current_os:
            print("Can't do OS, ", self.OS,
                  " because it is bigger than original ", current_os)
        else:
            # redo strides according to needed stride: walk the stages from
            # the last one backwards, turning stride-2 stages into stride-1
            # until the product of strides equals the requested OS
            for i, stride in enumerate(reversed(self.strides), 0):
                if int(current_os) != self.OS:
                    if stride == 2:
                        current_os /= 2
                        self.strides[-1 - i] = 1
                    if int(current_os) == self.OS:
                        break

        # check that darknet exists (membership test on the dict directly;
        # the original ".keys()" was redundant)
        assert self.layers in model_blocks, \
            "unknown darknet depth: {}".format(self.layers)

        # number of residual blocks per encoder stage
        self.blocks = model_blocks[self.layers]

        # input layer (stem): no downsampling
        self.conv1 = nn.Conv2d(self.input_depth, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32, momentum=self.bn_d)
        self.relu1 = nn.LeakyReLU(0.1)

        # encoder
        self.enc1 = self._make_enc_layer(BasicBlock, [32, 64], self.blocks[0],
                                         stride=self.strides[0], bn_d=self.bn_d)
        self.enc2 = self._make_enc_layer(BasicBlock, [64, 128], self.blocks[1],
                                         stride=self.strides[1], bn_d=self.bn_d)
        self.enc3 = self._make_enc_layer(BasicBlock, [128, 256], self.blocks[2],
                                         stride=self.strides[2], bn_d=self.bn_d)
        self.enc4 = self._make_enc_layer(BasicBlock, [256, 512], self.blocks[3],
                                         stride=self.strides[3], bn_d=self.bn_d)
        self.enc5 = self._make_enc_layer(BasicBlock, [512, 1024], self.blocks[4],
                                         stride=self.strides[4], bn_d=self.bn_d)

        # for a bit of fun
        self.dropout = nn.Dropout2d(self.drop_prob)

        # last channels
        self.last_channels = 1024

    def _make_enc_layer(self, block, planes, blocks, stride, bn_d=0.1):
        """One encoder stage: a [1, stride] (width-only) downsampling conv +
        BN + LeakyReLU, followed by ``blocks`` residual blocks at
        ``planes[1]`` channels."""
        layers = []
        # downsample along width only
        layers.append(("conv", nn.Conv2d(planes[0], planes[1],
                                         kernel_size=3,
                                         stride=[1, stride], dilation=1,
                                         padding=1, bias=False)))
        layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d)))
        layers.append(("relu", nn.LeakyReLU(0.1)))
        # residual blocks keep the channel count at planes[1]
        inplanes = planes[1]
        for i in range(0, blocks):
            layers.append(("residual_{}".format(i),
                           block(inplanes, planes, bn_d)))
        return nn.Sequential(OrderedDict(layers))

    def run_layer(self, x, layer, skips, os):
        """Apply ``layer``; when it shrinks the feature map, stash the
        pre-layer tensor in ``skips`` keyed by the current output stride
        and double ``os``. Returns (new x, skips, os)."""
        y = layer(x)
        if y.shape[2] < x.shape[2] or y.shape[3] < x.shape[3]:
            skips[os] = x.detach()
            os *= 2
        x = y
        return x, skips, os

    def forward(self, x, return_logits=False, return_list=None):
        """Run the encoder.

        x: (B, 5, H, W) input; only the configured channels are kept.
        return_logits: return raw (pre-dropout) bottleneck features only.
        return_list: iterable of 'enc_0'..'enc_5' names — detached CPU
        copies of those intermediate activations are also returned.
        """
        # filter input to the configured channels
        x = x[:, self.input_idxs]
        # store for skip connections, keyed by output stride
        skips = {}
        out_dict = {}
        os = 1
        # stem
        x, skips, os = self.run_layer(x, self.conv1, skips, os)
        x, skips, os = self.run_layer(x, self.bn1, skips, os)
        x, skips, os = self.run_layer(x, self.relu1, skips, os)
        if return_list and 'enc_0' in return_list:
            out_dict['enc_0'] = x.detach().cpu()
        # encoder stages with intermediate dropouts; this loop reproduces the
        # original unrolled order exactly: enc1, then [dropout, encN] for
        # N=2..5, capturing after each stage
        for i, stage in enumerate((self.enc1, self.enc2, self.enc3,
                                   self.enc4, self.enc5), start=1):
            if i > 1:
                x, skips, os = self.run_layer(x, self.dropout, skips, os)
            x, skips, os = self.run_layer(x, stage, skips, os)
            key = 'enc_{}'.format(i)
            if return_list and key in return_list:
                out_dict[key] = x.detach().cpu()
        if return_logits:
            # raw bottleneck features, before the final dropout
            return x
        x, skips, os = self.run_layer(x, self.dropout, skips, os)
        if return_list is not None:
            return x, skips, out_dict
        return x, skips

    def get_last_depth(self):
        """Channel count of the bottleneck output (1024)."""
        return self.last_channels

    def get_input_depth(self):
        """Number of input channels actually consumed."""
        return self.input_depth
class Decoder(nn.Module):
    """
    Class for DarknetSeg. Subclasses PyTorch's own "nn" module.

    Mirror of the Backbone: five stages that upsample along the width axis
    only, adding the matching encoder skip tensor after each upsampling
    stage until the features are back at output stride 1.
    """

    def __init__(self, params, OS=32, feature_depth=1024):
        # params: dict with "dropout" and "bn_d" (BatchNorm momentum).
        # OS: output stride the backbone produced; feature_depth: channel
        # count of the bottleneck features fed into dec5.
        super(Decoder, self).__init__()
        self.backbone_OS = OS
        self.backbone_feature_depth = feature_depth
        self.drop_prob = params["dropout"]
        self.bn_d = params["bn_d"]
        self.index = 0  # NOTE(review): appears unused within this class
        # stride play
        self.strides = [2, 2, 2, 2, 2]
        # check current stride (product of all nominal stage strides)
        current_os = 1
        for s in self.strides:
            current_os *= s
        # redo strides according to needed stride: disable stride-2 stages
        # from the front until the product matches the backbone's OS
        for i, stride in enumerate(self.strides):
            if int(current_os) != self.backbone_OS:
                if stride == 2:
                    current_os /= 2
                    self.strides[i] = 1
                if int(current_os) == self.backbone_OS:
                    break
        # decoder: channel counts run 1024 -> 512 -> 256 -> 128 -> 64 -> 32
        self.dec5 = self._make_dec_layer(BasicBlock,
                                         [self.backbone_feature_depth, 512],
                                         bn_d=self.bn_d,
                                         stride=self.strides[0])
        self.dec4 = self._make_dec_layer(BasicBlock, [512, 256], bn_d=self.bn_d,
                                         stride=self.strides[1])
        self.dec3 = self._make_dec_layer(BasicBlock, [256, 128], bn_d=self.bn_d,
                                         stride=self.strides[2])
        self.dec2 = self._make_dec_layer(BasicBlock, [128, 64], bn_d=self.bn_d,
                                         stride=self.strides[3])
        self.dec1 = self._make_dec_layer(BasicBlock, [64, 32], bn_d=self.bn_d,
                                         stride=self.strides[4])
        # layer list to execute with skips
        self.layers = [self.dec5, self.dec4, self.dec3, self.dec2, self.dec1]
        # for a bit of fun
        self.dropout = nn.Dropout2d(self.drop_prob)
        # last channels
        self.last_channels = 32

    def _make_dec_layer(self, block, planes, bn_d=0.1, stride=2):
        # One decoder stage: width-only upsampling (or a plain 3x3 conv when
        # this stage's stride was disabled), BN + LeakyReLU, then one
        # residual block.
        layers = []
        # upsample: kernel [1, 4], stride [1, 2], padding [0, 1] doubles the
        # width and leaves the height unchanged
        if stride == 2:
            layers.append(("upconv", nn.ConvTranspose2d(planes[0], planes[1],
                                                        kernel_size=[1, 4], stride=[1, 2],
                                                        padding=[0, 1])))
        else:
            layers.append(("conv", nn.Conv2d(planes[0], planes[1],
                                             kernel_size=3, padding=1)))
        layers.append(("bn", nn.BatchNorm2d(planes[1], momentum=bn_d)))
        layers.append(("relu", nn.LeakyReLU(0.1)))
        # blocks: BasicBlock runs planes[1] -> planes[0] -> planes[1], so
        # its residual addition stays shape-compatible
        layers.append(("residual", block(planes[1], planes, bn_d)))
        return nn.Sequential(OrderedDict(layers))

    def run_layer(self, x, layer, skips, os):
        # Apply the stage; if it grew the width we crossed a stride
        # boundary: halve os and add the encoder skip stored at that stride.
        feats = layer(x)  # up
        if feats.shape[-1] > x.shape[-1]:
            os //= 2  # match skip
            feats = feats + skips[os].detach()  # add skip
        x = feats
        return x, skips, os

    def forward(self, x, skips, return_logits=False, return_list=None):
        # x: bottleneck features from the backbone; skips: dict of detached
        # encoder tensors keyed by output stride.
        # return_logits: also return a pre-dropout copy of the final features.
        # return_list: iterable of 'dec_0'..'dec_4' names to capture; in
        # this mode ONLY the capture dict is returned.
        os = self.backbone_OS
        out_dict = {}
        # run layers
        x, skips, os = self.run_layer(x, self.dec5, skips, os)
        if return_list and 'dec_4' in return_list:
            out_dict['dec_4'] = x.detach().cpu()  # 512, 64, 64
        x, skips, os = self.run_layer(x, self.dec4, skips, os)
        if return_list and 'dec_3' in return_list:
            out_dict['dec_3'] = x.detach().cpu()  # 256, 64, 128
        x, skips, os = self.run_layer(x, self.dec3, skips, os)
        if return_list and 'dec_2' in return_list:
            out_dict['dec_2'] = x.detach().cpu()  # 128, 64, 256
        x, skips, os = self.run_layer(x, self.dec2, skips, os)
        if return_list and 'dec_1' in return_list:
            out_dict['dec_1'] = x.detach().cpu()  # 64, 64, 512
        x, skips, os = self.run_layer(x, self.dec1, skips, os)
        if return_list and 'dec_0' in return_list:
            out_dict['dec_0'] = x.detach().cpu()  # 32, 64, 1024
        # detached copy of the features BEFORE dropout
        logits = torch.clone(x).detach()
        x = self.dropout(x)
        if return_logits:
            return x, logits
        if return_list is not None:
            return out_dict
        return x

    def get_last_depth(self):
        # channel count of the final decoder output (32)
        return self.last_channels
class Model(nn.Module):
    """Full segmentation network: Darknet Backbone followed by the Decoder.

    ``config`` must hold "backbone" and "decoder" sub-dicts; the decoder is
    wired to the backbone's output stride and bottleneck depth.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        backbone_cfg = self.config["backbone"]
        self.backbone = Backbone(params=backbone_cfg)
        self.decoder = Decoder(params=self.config["decoder"],
                               OS=backbone_cfg["OS"],
                               feature_depth=self.backbone.get_last_depth())

    def load_pretrained_weights(self, path):
        """Load backbone and decoder state dicts from files under *path*.

        NOTE(review): torch.load without weights_only=True unpickles
        arbitrary objects — only load checkpoints from trusted sources.
        """
        for module, fname in ((self.backbone, "/backbone"),
                              (self.decoder, "/segmentation_decoder")):
            state = torch.load(path + fname,
                               map_location=lambda storage, loc: storage)
            module.load_state_dict(state, strict=True)

    def forward(self, x, return_logits=False, return_final_logits=False, return_list=None, agg_type='depth'):
        """Run the network; the mode flags select what is returned.

        return_logits: pooled backbone features as a numpy array.
        return_list: dict of captured encoder/decoder intermediates.
        return_final_logits: decoder features aggregated per agg_type
        ('all' | 'sector' | 'depth') as a numpy array.
        default: decoder output tensor.
        """
        if return_logits:
            feats = self.backbone(x, return_logits)
            # global average pool to (B, C), then hand back as numpy
            feats = F.adaptive_avg_pool2d(feats, (1, 1)).squeeze()
            return torch.clone(feats).detach().cpu().numpy()

        if return_list is not None:
            y, skips, enc_dict = self.backbone(x, return_list=return_list)
            dec_dict = self.decoder(y, skips, return_list=return_list)
            # merged encoder + decoder capture dicts
            return {**enc_dict, **dec_dict}

        if return_final_logits:
            assert agg_type in ['all', 'sector', 'depth']
            y, skips = self.backbone(x)
            y, logits = self.decoder(y, skips, True)
            B, C, H, W = logits.shape
            N = 16  # number of sectors / depth bands to average over
            if agg_type == 'all':
                # global average over the whole feature map
                logits = logits.mean([2, 3])
            elif agg_type == 'sector':
                # average within N width-sectors
                logits = logits.view(B, C, H, N, W // N).mean([2, 4]).reshape(B, -1)
            else:
                # 'depth': average within N height-bands
                logits = logits.view(B, C, N, H // N, W).mean([3, 4]).reshape(B, -1)
            return torch.clone(logits).detach().cpu().numpy()

        # plain segmentation path
        y, skips = self.backbone(x)
        return self.decoder(y, skips, False)