Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This source code is licensed under the Apache License, Version 2.0 | |
| # found in the LICENSE file in the root directory of this source tree. | |
| import copy | |
| from functools import partial | |
| import math | |
| import warnings | |
| import torch | |
| import torch.nn as nn | |
| from .ops import resize | |
| # XXX: (Untested) replacement for mmcv.imdenormalize() | |
| def _imdenormalize(img, mean, std, to_bgr=True): | |
| import numpy as np | |
| mean = mean.reshape(1, -1).astype(np.float64) | |
| std = std.reshape(1, -1).astype(np.float64) | |
| img = (img * std) + mean | |
| if to_bgr: | |
| img = img[::-1] | |
| return img | |
class DepthBaseDecodeHead(nn.Module):
    """Base class for depth decode heads.

    Subclasses implement ``forward``. This class provides the per-pixel depth
    prediction (``depth_pred``), loss aggregation (``losses``) and image
    logging (``log_images``).

    Args:
        in_channels (List): Input channels.
        conv_layer (nn.Module): Conv layers. Stored as ``self.conf_layer``.
            Default: None.
        act_layer (nn.Module): Activation layers. Default: nn.ReLU.
        channels (int): Channels after modules, before conv_depth.
            Default: 96.
        loss_decode (nn.Module | nn.ModuleList | list | tuple): Decode loss
            module(s). Each must expose a ``loss_name`` attribute.
            Default: () (no losses).
        sampler (dict|None): Accepted for config compatibility; not used by
            this class. Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        min_depth (float): Min depth in dataset setting. Default: 1e-3.
        max_depth (float): Max depth in dataset setting. Default: None.
        norm_layer (nn.Module|None): Norm layers. Default: None.
        classify (bool): Whether predict depth in a cls.-reg. manner.
            Default: False.
        n_bins (int): The number of bins used in cls. step. Default: 256.
        bins_strategy (str): The discrete strategy used in cls. step:
            'UD' (uniformly spaced bins) or 'SID' (log-spaced bins), both
            spanning [min_depth, max_depth]. Default: 'UD'.
        norm_strategy (str): The norm strategy on cls. probability
            distribution: 'linear', 'softmax' or 'sigmoid'. Default: 'linear'.
        scale_up (bool): Whether predict depth in a scale-up manner
            (sigmoid output scaled by max_depth). Default: False.
    """

    def __init__(
        self,
        in_channels,
        conv_layer=None,
        act_layer=nn.ReLU,
        channels=96,
        loss_decode=(),
        sampler=None,
        align_corners=False,
        min_depth=1e-3,
        max_depth=None,
        norm_layer=None,
        classify=False,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        scale_up=False,
    ):
        super(DepthBaseDecodeHead, self).__init__()
        self.in_channels = in_channels
        self.channels = channels
        # Attribute name is historically misspelled ("conf"); kept as-is so
        # existing code (and state_dict keys, if a module is passed) still match.
        self.conf_layer = conv_layer
        self.act_layer = act_layer
        self.loss_decode = loss_decode
        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.norm_layer = norm_layer
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up
        if self.classify:
            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
            self.bins_strategy = bins_strategy
            self.norm_strategy = norm_strategy
            self.softmax = nn.Softmax(dim=1)
            # One logit per depth bin.
            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
        else:
            # Direct single-channel depth regression.
            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, img_metas):
        """Placeholder of forward function; subclasses override it."""
        pass

    def forward_train(self, img, inputs, img_metas, depth_gt):
        """Forward function for training.

        Args:
            img (Tensor): Input images; only ``img[0]`` is used, for logging.
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): GT depth

        Returns:
            dict[str, Tensor]: a dictionary of loss components, plus the
            logging entries produced by ``log_images`` for the first sample.
        """
        depth_pred = self.forward(inputs, img_metas)
        losses = self.losses(depth_pred, depth_gt)
        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
        losses.update(**log_imgs)
        return losses

    def forward_test(self, inputs, img_metas):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.

        Returns:
            Tensor: Output depth map.
        """
        return self.forward(inputs, img_metas)

    def depth_pred(self, feat):
        """Predict a depth value for each pixel.

        Args:
            feat (Tensor): Decoder feature map, fed to ``self.conv_depth``.

        Returns:
            Tensor: Depth map. In classification mode this is the
            probability-weighted average of the bin centres with a singleton
            dimension inserted at dim 1.
        """
        if self.classify:
            logit = self.conv_depth(feat)
            if self.bins_strategy == "UD":
                bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
            elif self.bins_strategy == "SID":
                # torch.logspace takes base-10 *exponents*; passing the depths
                # directly would give bins from 10**min_depth to 10**max_depth.
                bins = torch.logspace(
                    math.log10(self.min_depth), math.log10(self.max_depth), self.n_bins, device=feat.device
                )
            # following Adabins, default linear
            if self.norm_strategy == "linear":
                logit = torch.relu(logit)
                eps = 0.1  # keeps every bin weight strictly positive
                logit = logit + eps
                logit = logit / logit.sum(dim=1, keepdim=True)
            elif self.norm_strategy == "softmax":
                logit = torch.softmax(logit, dim=1)
            elif self.norm_strategy == "sigmoid":
                logit = torch.sigmoid(logit)
                logit = logit / logit.sum(dim=1, keepdim=True)
            # Weighted sum of bin centres over the bin dimension (dim 1).
            output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
        else:
            if self.scale_up:
                output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
            else:
                output = self.relu(self.conv_depth(feat)) + self.min_depth
        return output

    def losses(self, depth_pred, depth_gt):
        """Compute depth losses.

        The prediction is bilinearly resized to the spatial size of
        ``depth_gt`` before each loss module is applied. Values of losses that
        share a ``loss_name`` are summed.

        Returns:
            dict[str, Tensor]: loss name -> value (empty if there are no losses).
        """
        loss = dict()
        depth_pred = resize(
            input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
        )
        if isinstance(self.loss_decode, (nn.ModuleList, list, tuple)):
            # Iterate containers (including the default empty tuple) directly;
            # wrapping one in a list would look up `.loss_name` on the container.
            losses_decode = self.loss_decode
        else:
            losses_decode = [self.loss_decode]
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
            else:
                # Out-of-place add avoids mutating a tensor autograd may need.
                loss[loss_decode.loss_name] = loss[loss_decode.loss_name] + loss_decode(depth_pred, depth_gt)
        return loss

    def log_images(self, img_path, depth_pred, depth_gt, img_meta):
        """Build visualisation data for one sample.

        Args:
            img_path (Tensor): Normalized image, channels first (it is permuted
                with (1, 2, 0) before being denormalized).
            depth_pred (Tensor): Predicted depth map.
            depth_gt (Tensor): Ground-truth depth map.
            img_meta (dict): Must contain ``img_norm_cfg`` with ``mean``,
                ``std`` and ``to_rgb``.

        Returns:
            dict: ``img_rgb`` (uint8 array, channels first) and the depth maps
            divided by their maximum, detached and moved to CPU.
        """
        import numpy as np

        show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
        show_img = show_img.numpy().astype(np.float32)
        show_img = _imdenormalize(
            show_img,
            img_meta["img_norm_cfg"]["mean"],
            img_meta["img_norm_cfg"]["std"],
            img_meta["img_norm_cfg"]["to_rgb"],
        )
        show_img = np.clip(show_img, 0, 255)
        show_img = show_img.astype(np.uint8)
        show_img = show_img[:, :, ::-1]
        # (H, W, C) -> (H, C, W) -> (C, H, W)
        show_img = show_img.transpose(0, 2, 1)
        show_img = show_img.transpose(1, 0, 2)
        depth_pred = depth_pred / torch.max(depth_pred)
        depth_gt = depth_gt / torch.max(depth_gt)
        depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
        depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
        return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
class BNHead(DepthBaseDecodeHead):
    """Linear depth head on (optionally concatenated) backbone features.

    Despite the name, no batch norm is applied. Features are gathered by
    ``_transform_inputs`` and passed straight to a 1x1 ``conv_depth``.

    Args:
        input_transform (str): How multi-level inputs are combined. If it
            contains "concat", the ``in_index`` levels are concatenated along
            channels, after being resized to the first selected level when it
            also contains "resize". "multiple_select" returns the selected
            levels as a list. Anything else picks ``inputs[in_index]``.
            Default: "resize_concat".
        in_index (int | Sequence[int]): Indices of the levels to use.
            Default: (0, 1, 2, 3).
        upsample (int): Extra spatial scale applied to the resize target.
            Default: 1.
        **kwargs: Forwarded to ``DepthBaseDecodeHead``.
    """

    def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
        super().__init__(**kwargs)
        self.input_transform = input_transform
        self.in_index = in_index
        self.upsample = upsample
        # Replace the base class's 3x3 prediction conv with a 1x1 one.
        n_out = self.n_bins if self.classify else 1
        self.conv_depth = nn.Conv2d(self.channels, n_out, kernel_size=1, padding=0, stride=1)

    def _transform_inputs(self, inputs):
        """Select (and optionally resize + concatenate) multi-level features.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs.
        """
        if "concat" in self.input_transform:
            chosen = [inputs[idx] for idx in self.in_index]
            if "resize" in self.input_transform:
                target = [dim * self.upsample for dim in chosen[0].shape[2:]]
                chosen = [
                    resize(input=level, size=target, mode="bilinear", align_corners=self.align_corners)
                    for level in chosen
                ]
            return torch.cat(chosen, dim=1)
        if self.input_transform == "multiple_select":
            return [inputs[idx] for idx in self.in_index]
        return inputs[self.in_index]

    def _forward_feature(self, inputs, img_metas=None, **kwargs):
        """Turn backbone outputs into the feature map fed to ``conv_depth``.

        Each entry of ``inputs`` is either a (patch_features, cls_token) pair,
        in which case the cls token is broadcast to the patch layout and
        concatenated on dim 1, or a sequence whose first element is used.
        2-D features are given two trailing singleton spatial dims.

        Args:
            inputs (list): List of multi-level img features.

        Returns:
            Tensor: Output of ``_transform_inputs`` on the prepared levels.
        """
        levels = list(inputs)
        for pos, entry in enumerate(levels):
            if len(entry) == 2:
                patches, cls_token = entry[0], entry[1]
                if patches.dim() == 2:
                    patches = patches[:, :, None, None]
                cls_map = cls_token[:, :, None, None].expand_as(patches)
                levels[pos] = torch.cat((patches, cls_map), 1)
            else:
                patches = entry[0]
                if patches.dim() == 2:
                    patches = patches[:, :, None, None]
                levels[pos] = patches
        return self._transform_inputs(levels)

    def forward(self, inputs, img_metas=None, **kwargs):
        """Forward function: features -> per-pixel depth."""
        return self.depth_pred(self._forward_feature(inputs, img_metas=img_metas, **kwargs))
class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    It is a dependency-free port of mmcv's ``ConvModule``:

    1. ``bias`` of the conv layer can be set automatically ("auto").
    2. Spectral norm is supported.
    3. "zeros" and "circular" padding are done by the conv layer itself;
       "reflect" and "replicate" use an explicit padding layer.

    Args:
        in_channels (int): Number of channels in the input feature map.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
        stride (int | tuple[int]): Stride of the convolution.
        padding (int | tuple[int]): Padding added to both sides of the input.
        dilation (int | tuple[int]): Spacing between kernel elements.
        groups (int): Number of blocked connections from input channels to
            output channels.
        bias (bool | str): If "auto", the conv has a bias only when there is
            no norm layer. Default: "auto".
        conv_layer (type): Convolution layer class. Default: nn.Conv2d.
        norm_layer (type | callable | None): Normalization layer factory,
            called with the number of normalized channels, e.g.
            ``nn.BatchNorm2d`` or ``partial(nn.GroupNorm, 32)``.
            Default: None (no norm).
        act_layer (type | callable | None): Activation layer factory.
            Default: nn.ReLU.
        inplace (bool): Whether to use inplace mode for activation.
            Default: True.
        with_spectral_norm (bool): Whether use spectral norm in conv module.
            Default: False.
        padding_mode (str): "zeros", "circular", "reflect" or "replicate".
            Default: "zeros".
        order (tuple[str]): The order of conv/norm/activation layers, a
            permutation of ("conv", "norm", "act").
            Default: ("conv", "norm", "act").
    """

    _abbr_ = "conv_block"

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias="auto",
        conv_layer=nn.Conv2d,
        norm_layer=None,
        act_layer=nn.ReLU,
        inplace=True,
        with_spectral_norm=False,
        padding_mode="zeros",
        order=("conv", "norm", "act"),
    ):
        super(ConvModule, self).__init__()
        # Modes the conv layer implements itself ...
        official_padding_mode = ["zeros", "circular"]
        # ... and modes that need an explicit padding layer in front of it.
        explicit_padding_layers = {"reflect": nn.ReflectionPad2d, "replicate": nn.ReplicationPad2d}
        self.conv_layer = conv_layer
        self.norm_layer = norm_layer
        self.act_layer = act_layer
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(["conv", "norm", "act"])

        self.with_norm = norm_layer is not None
        self.with_activation = act_layer is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == "auto":
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_explicit_padding:
            if padding_mode not in explicit_padding_layers:
                raise AssertionError(f"Unsupported padding mode: {padding_mode}")
            self.pad = explicit_padding_layers[padding_mode](padding)
        # The explicit pad layer already padded the input, so the conv must not.
        conv_padding = 0 if self.with_explicit_padding else padding
        conv_kwargs = {}
        if padding_mode == "circular":
            # Forwarded only when needed so that custom conv_layer classes
            # without a ``padding_mode`` argument still work for "zeros".
            conv_kwargs["padding_mode"] = "circular"
        # build convolution layer
        self.conv = self.conv_layer(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **conv_kwargs,
        )
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        self.norm_name = None
        if self.with_norm:
            # norm layer is after conv layer
            if order.index("norm") > order.index("conv"):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            norm = norm_layer(norm_channels)
            # Registered under an mmcv-style name ("bn", "gn", ...): the name
            # "norm" itself is taken by the ``norm`` property.
            self.norm_name = self._norm_name_for(norm)
            self.add_module(self.norm_name, norm)
            if self.with_bias and self.norm_name in ("bn", "in"):
                warnings.warn("Unnecessary conv bias before batch/instance norm")

        # build activation layer
        if self.with_activation:
            # These activations take no ``inplace`` argument. ``act_layer`` is
            # a class (or factory), so test it with issubclass, not isinstance.
            no_inplace = (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)
            if not (isinstance(act_layer, type) and issubclass(act_layer, no_inplace)):
                act_layer = partial(act_layer, inplace=inplace)
            self.activate = act_layer()

        # Use msra init by default
        self.init_weights()

    @staticmethod
    def _norm_name_for(norm):
        """Return the mmcv-style attribute name ("bn", "in", "gn", "ln") for a norm module."""
        from torch.nn.modules.batchnorm import _BatchNorm
        from torch.nn.modules.instancenorm import _InstanceNorm

        if isinstance(norm, _InstanceNorm):
            return "in"
        if isinstance(norm, _BatchNorm):
            return "bn"
        if isinstance(norm, nn.GroupNorm):
            return "gn"
        if isinstance(norm, nn.LayerNorm):
            return "ln"
        return "norm_module"

    @property
    def norm(self):
        """The normalization sub-module, or None when no norm layer is used."""
        if self.norm_name:
            return getattr(self, self.norm_name)
        return None

    def init_weights(self):
        """Kaiming-init the conv (unless it has its own init) and reset the norm.

        1. Customized conv layers that define their own ``init_weights()`` are
           left untouched.
        2. Other conv layers, including PyTorch's, get ``kaiming_normal_``
           (fan_out) weights and zero bias.
        3. The norm layer, if any, gets weight 1 and bias 0.
        """
        if not hasattr(self.conv, "init_weights"):
            act = self.act_layer
            # act_layer is stored as a class/factory, hence issubclass.
            if self.with_activation and isinstance(act, type) and issubclass(act, nn.LeakyReLU):
                nonlinearity = "leaky_relu"
                a = 0.01  # XXX: default negative_slope
            else:
                nonlinearity = "relu"
                a = 0
            if getattr(self.conv, "weight", None) is not None:
                nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
            if getattr(self.conv, "bias", None) is not None:
                nn.init.constant_(self.conv.bias, 0)
        if self.with_norm:
            norm = self.norm
            if getattr(norm, "weight", None) is not None:
                nn.init.constant_(norm.weight, 1)
            if getattr(norm, "bias", None) is not None:
                nn.init.constant_(norm.bias, 0)

    def forward(self, x, activate=True, norm=True):
        """Apply conv/norm/act in ``self.order``; ``activate``/``norm`` can skip steps."""
        for layer in self.order:
            if layer == "conv":
                if self.with_explicit_padding:
                    x = self.pad(x)
                x = self.conv(x)
            elif layer == "norm" and norm and self.with_norm:
                x = self.norm(x)
            elif layer == "act" and activate and self.with_activation:
                x = self.activate(x)
        return x
class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate``.

    Args:
        scale_factor (float | tuple): Multiplier for the spatial size.
        mode (str): Interpolation mode, e.g. "bilinear" or "nearest".
        align_corners (bool): Forwarded only for modes that accept it
            ("linear", "bilinear", "bicubic", "trilinear"). For other modes
            ``None`` is passed. Default: False.
    """

    # Modes for which F.interpolate accepts a non-None align_corners.
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, scale_factor, mode, align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resize ``x`` by ``self.scale_factor`` using ``self.mode``."""
        # F.interpolate raises ValueError if align_corners is set (even to
        # False) for modes such as "nearest" or "area".
        align_corners = self.align_corners if self.mode in self._ALIGN_CORNERS_MODES else None
        return self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=align_corners)
class HeadDepth(nn.Module):
    """Final single-channel depth projection used by DPTHead.

    Layers, in order:
    - 3x3 conv halving the channels;
    - 2x bilinear upsampling (align_corners=True);
    - 3x3 conv to 32 channels followed by ReLU;
    - 1x1 conv to one output channel.

    Args:
        features (int): Number of input channels.
    """

    def __init__(self, features):
        super(HeadDepth, self).__init__()
        half = features // 2
        # Built in this exact order so Sequential indices (and therefore the
        # state_dict keys "head.<i>.*") stay the same.
        stages = [
            nn.Conv2d(features, half, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(half, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        ]
        self.head = nn.Sequential(*stages)

    def forward(self, x):
        """Run the projection stack on ``x``."""
        return self.head(x)
class ReassembleBlocks(nn.Module):
    """ViTPostProcessBlock, process cls_token in ViT backbone output and
    rearrange the feature vector to feature map.

    Stage ``i`` projects the features to ``out_channels[i]`` with a 1x1
    ConvModule and then resizes them. Stage 0 uses a x4 transposed conv,
    stage 1 a x2 transposed conv, stage 2 identity, and stage 3 a stride-2
    3x3 conv. Exactly four stages are therefore expected.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (Sequence[int]): output channels of each of the four
            stages. Default: (96, 192, 384, 768).
        readout_type (str): How the cls token is merged into the patch
            features. 'ignore' drops it, 'add' adds it, and 'project'
            concatenates it and applies Linear + GELU. Default: 'ignore'.
        patch_size (int): The patch size. Stored but not used by this module.
            Default: 16.
    """

    def __init__(self, in_channels=768, out_channels=(96, 192, 384, 768), readout_type="ignore", patch_size=16):
        # Tuple default: a list default would be one object shared by every call.
        super(ReassembleBlocks, self).__init__()
        assert readout_type in ["ignore", "add", "project"]
        self.readout_type = readout_type
        self.patch_size = patch_size

        self.projects = nn.ModuleList(
            [
                ConvModule(
                    in_channels=in_channels,
                    out_channels=out_channel,
                    kernel_size=1,
                    act_layer=None,
                )
                for out_channel in out_channels
            ]
        )

        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )
        if self.readout_type == "project":
            # One readout projection per stage: Linear(2C -> C) + GELU.
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))

    def forward(self, inputs):
        """Reassemble each (features, cls_token) pair into a stage feature map.

        Args:
            inputs (list): One ``(x, cls_token)`` pair per stage. ``x`` is
                presumably (B, C, H, W) and ``cls_token`` (B, C), judging by the
                ``flatten(2)`` / ``unsqueeze`` calls below; confirm with the
                backbone.

        Returns:
            list[Tensor]: One projected and resized feature map per stage.
        """
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == "project":
                # Tokens on the last axis so the Linear acts on channels.
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == "add":
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            # "ignore": the cls token is dropped.
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out
class PreActResidualConvUnit(nn.Module):
    """ResidualConvUnit, pre-activate residual unit.

    Two 3x3 ConvModules in ("act", "conv", "norm") order, with the input
    added back to the result.

    Args:
        in_channels (int): number of channels in the input feature map.
        act_layer (nn.Module): activation layer.
        norm_layer (nn.Module): norm layer.
        stride (int): stride of the first block. Default: 1
        dilation (int): dilation rate for convs layers. Default: 1.
    """

    def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
        super(PreActResidualConvUnit, self).__init__()
        common = dict(
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )
        self.conv1 = ConvModule(
            in_channels, in_channels, 3, stride=stride, padding=dilation, dilation=dilation, **common
        )
        self.conv2 = ConvModule(in_channels, in_channels, 3, padding=1, **common)

    def forward(self, inputs):
        """Return ``conv2(conv1(inputs)) + inputs``."""
        # The activation runs first and may be applied in place (ConvModule's
        # inplace flag), so keep an untouched copy for the skip connection.
        shortcut = inputs.clone()
        residual = self.conv2(self.conv1(inputs))
        return residual + shortcut
class FeatureFusionBlock(nn.Module):
    """FeatureFusionBlock, merge feature map from different stages.

    Optionally adds a (resized, residual-refined) skip feature, refines the
    sum, upsamples it 2x bilinearly and projects it with a 1x1 ConvModule.

    Args:
        in_channels (int): Input channels.
        act_layer (nn.Module): activation layer for ResidualConvUnit.
        norm_layer (nn.Module): normalization layer.
        expand (bool): Whether to halve the channels in the output
            projection. Default: False.
        align_corners (bool): align_corner setting for bilinear upsample.
            Default: True.
    """

    def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
        super(FeatureFusionBlock, self).__init__()
        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners
        self.out_channels = in_channels // 2 if self.expand else in_channels

        self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)
        unit_kwargs = dict(in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer)
        self.res_conv_unit1 = PreActResidualConvUnit(**unit_kwargs)
        self.res_conv_unit2 = PreActResidualConvUnit(**unit_kwargs)

    def forward(self, *inputs):
        """Fuse ``inputs[0]`` with the optional skip feature ``inputs[1]``."""
        fused = inputs[0]
        if len(inputs) == 2:
            skip = inputs[1]
            if fused.shape != skip.shape:
                skip = resize(skip, size=(fused.shape[2], fused.shape[3]), mode="bilinear", align_corners=False)
            fused = fused + self.res_conv_unit1(skip)
        fused = self.res_conv_unit2(fused)
        fused = resize(fused, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
        return self.project(fused)
class DPTHead(DepthBaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.
    Backbone outputs are reassembled into four feature maps and projected to
    ``self.channels``. They are then fused from the deepest stage to the
    shallowest, projected once more, and turned into depth by ``HeadDepth``.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (Sequence[int]): Out channels of post process
            conv layers. Default: (96, 192, 384, 768).
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether expand the channels in post process
            block. Default: False.
        **kwargs: Forwarded to ``DepthBaseDecodeHead``.
    """

    def __init__(
        self,
        embed_dims=768,
        post_process_channels=(96, 192, 384, 768),
        readout_type="ignore",
        patch_size=16,
        expand_channels=False,
        **kwargs,
    ):
        super(DPTHead, self).__init__(**kwargs)
        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)

        # Integer arithmetic: math.pow returned floats, which nn.Conv2d rejects.
        # NOTE(review): reassemble_blocks still emits the *unexpanded* channel
        # counts, so expand_channels=True appears to mismatch self.convs;
        # confirm the intended behaviour before relying on it.
        self.post_process_channels = [
            channel * 2**i if expand_channels else channel for i, channel in enumerate(post_process_channels)
        ]
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
        # The deepest fusion block receives no skip input, so it needs no unit 1.
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels
        # Replaces the base class's conv_depth. NOTE(review): HeadDepth emits a
        # single channel, so the classify=True path of depth_pred (which
        # expects n_bins channels) looks unsupported here.
        self.conv_depth = HeadDepth(self.channels)

    def forward(self, inputs, img_metas=None):
        """Predict depth from the backbone outputs.

        Args:
            inputs (Sequence): One ``(features, cls_token)`` pair per stage.
            img_metas: Unused; optional for consistency with ``BNHead.forward``.

        Returns:
            Tensor: Output of ``depth_pred`` on the fused features.
        """
        assert len(inputs) == self.num_reassemble_blocks
        feats = self.reassemble_blocks(list(inputs))
        feats = [conv(feat) for conv, feat in zip(self.convs, feats)]
        # Fuse from the deepest stage (last) to the shallowest (first).
        out = self.fusion_blocks[0](feats[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, feats[-(i + 1)])
        out = self.project(out)
        return self.depth_pred(out)