| | |
| | |
| | |
| | |
| |
|
| | import copy |
| | from abc import ABCMeta, abstractmethod |
| |
|
| | import mmcv |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from mmcv.runner import BaseModule, auto_fp16, force_fp32 |
| |
|
| | from ...ops import resize |
| | from ..builder import build_loss |
| |
|
| |
|
class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for depth-estimation decode heads.

    Subclasses implement :meth:`forward`. This base class builds the final
    depth projection (``conv_depth``), turns features into per-pixel depth
    (:meth:`depth_pred`), computes losses (:meth:`losses`) and prepares one
    sample for visual logging (:meth:`log_images`).

    Args:
        in_channels (List): Input channels.
        channels (int): Channels after modules, before conv_depth.
            Default: 96.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        loss_decode (dict | list[dict] | tuple[dict]): Config(s) of the
            decode loss(es). Losses sharing the same ``loss_name`` are summed.
            Default: dict(type='SigLoss', valid_mask=True, loss_weight=10).
        sampler (dict|None): The config of depth map sampler. Currently not
            stored or used by this class. Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        min_depth (float): Min depth in dataset setting. Added to the output
            in the plain regression mode and used as the first bin in the
            classification mode; must be > 0 when ``bins_strategy='SID'``.
            Default: 1e-3.
        max_depth (float|None): Max depth in dataset setting. Required when
            ``classify`` or ``scale_up`` is True. Default: None.
        norm_cfg (dict|None): Config of norm layers.
            Default: None.
        classify (bool): Whether predict depth in a cls.-reg. manner.
            Default: False.
        n_bins (int): The number of bins used in cls. step.
            Default: 256.
        bins_strategy (str): The discrete strategy used in cls. step:
            'UD' (bins evenly spaced in depth) or 'SID' (bins evenly spaced
            in log-depth). Default: 'UD'.
        norm_strategy (str): The norm strategy on cls. probability
            distribution: 'linear', 'softmax' or 'sigmoid'. Default: 'linear'.
        scale_up (bool): Whether predict depth in a scale-up manner, i.e.
            ``sigmoid(x) * max_depth`` (only when ``classify`` is False).
            Default: False.

    Raises:
        TypeError: If ``loss_decode`` is neither a dict nor a list/tuple.
    """

    def __init__(
        self,
        in_channels,
        channels=96,
        conv_cfg=None,
        act_cfg=dict(type="ReLU"),
        loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
        sampler=None,
        align_corners=False,
        min_depth=1e-3,
        max_depth=None,
        norm_cfg=None,
        classify=False,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        scale_up=False,
    ):
        super(DepthBaseDecodeHead, self).__init__()

        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.act_cfg = act_cfg
        if isinstance(loss_decode, dict):
            self.loss_decode = build_loss(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(build_loss(loss))
        else:
            # Fail fast: otherwise ``self.loss_decode`` would stay unset and
            # ``losses()`` would only fail later with an AttributeError.
            raise TypeError(f"loss_decode must be a dict or sequence of dict, but got {type(loss_decode)}")
        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.norm_cfg = norm_cfg
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up

        if self.classify:
            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
            # SID bins are log-spaced, so the smallest bin must be positive.
            assert bins_strategy != "SID" or min_depth > 0, "bins_strategy SID requires min_depth > 0"

            self.bins_strategy = bins_strategy
            self.norm_strategy = norm_strategy
            self.softmax = nn.Softmax(dim=1)
            # One output channel per depth bin.
            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
        else:
            # Single output channel: depth is regressed directly.
            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)

        self.fp16_enabled = False
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def extra_repr(self):
        """Extra repr."""
        s = f"align_corners={self.align_corners}"
        return s

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs, img_metas):
        """Placeholder of forward function."""
        pass

    def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg):
        """Forward function for training.

        Args:
            img (Tensor): Input images; only the first one is used, for
                logging via :meth:`log_images`.
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): GT depth
            train_cfg (dict): The training config. Not used here.

        Returns:
            dict: Loss components from :meth:`losses`, plus the non-loss
            entries 'img_rgb', 'img_depth_pred' and 'img_depth_gt' from
            :meth:`log_images` (presumably filtered out by the caller's
            logging hook before loss summation; verify against caller).
        """
        depth_pred = self.forward(inputs, img_metas)
        losses = self.losses(depth_pred, depth_gt)

        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
        losses.update(**log_imgs)

        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config. Not used here.

        Returns:
            Tensor: Output depth map.
        """
        return self.forward(inputs, img_metas)

    def depth_pred(self, feat):
        """Predict a depth value for each pixel of ``feat``.

        In classification mode, ``conv_depth`` produces one logit per bin.
        The logits are normalized into a distribution over bins along dim 1,
        and the output is the expected depth under that distribution. In
        regression mode, ``conv_depth`` produces depth directly, either as
        ``sigmoid(x) * max_depth`` (``scale_up``) or as ``relu(x) + min_depth``.

        Args:
            feat (Tensor): Feature map fed to ``conv_depth``.

        Returns:
            Tensor: Depth map with a single channel at dim 1.
        """
        if self.classify:
            logit = self.conv_depth(feat)

            if self.bins_strategy == "UD":
                # Bins evenly spaced in depth.
                bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
            elif self.bins_strategy == "SID":
                # Bins evenly spaced in log-depth. ``torch.logspace`` expects
                # base-10 *exponents*, so the depth limits go through log10
                # first; passing them raw would give 10**min .. 10**max.
                bins = torch.logspace(
                    float(np.log10(self.min_depth)),
                    float(np.log10(self.max_depth)),
                    self.n_bins,
                    device=feat.device,
                )

            if self.norm_strategy == "linear":
                # Clamp to non-negative and add a small floor so every bin
                # keeps some mass before normalizing to sum to 1.
                logit = torch.relu(logit)
                eps = 0.1
                logit = logit + eps
                logit = logit / logit.sum(dim=1, keepdim=True)
            elif self.norm_strategy == "softmax":
                logit = torch.softmax(logit, dim=1)
            elif self.norm_strategy == "sigmoid":
                logit = torch.sigmoid(logit)
                logit = logit / logit.sum(dim=1, keepdim=True)

            # Expected depth: weight each bin value by its probability
            # (logit is (N, n_bins, H, W) per the subscripts), then restore
            # a channel dim -> (N, 1, H, W).
            output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
        elif self.scale_up:
            output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
        else:
            output = self.relu(self.conv_depth(feat)) + self.min_depth
        return output

    @force_fp32(apply_to=("depth_pred",))
    def losses(self, depth_pred, depth_gt):
        """Compute depth loss.

        ``depth_pred`` is bilinearly resized to the spatial size of
        ``depth_gt`` (``depth_gt.shape[2:]``) before every configured loss is
        applied. Losses that share a ``loss_name`` are summed into one entry.

        Returns:
            dict[str, Tensor]: Loss value per ``loss_name``.
        """
        loss = dict()
        depth_pred = resize(
            input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
        )
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
            else:
                loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
        return loss

    def log_images(self, img_path, depth_pred, depth_gt, img_meta):
        """Prepare one sample's image and depth maps for visual logging.

        Args:
            img_path (Tensor): Despite the name, the normalized input image
                tensor in (C, H, W) layout (it is permuted with (1, 2, 0)).
            depth_pred (Tensor): Predicted depth for the same sample.
            depth_gt (Tensor): GT depth for the same sample.
            img_meta (dict): Image info; must contain 'img_norm_cfg' with
                'mean', 'std' and 'to_rgb'.

        Returns:
            dict: 'img_rgb' -> uint8 ndarray in (C, H, W) layout, with the
            channel order reversed after denormalization (presumably yielding
            RGB; verify against mmcv.imdenormalize). 'img_depth_pred' and
            'img_depth_gt' -> detached CPU tensors scaled to a max of 1.
        """
        img_norm_cfg = img_meta["img_norm_cfg"]
        # astype() always copies, so the result never aliases the input tensor.
        show_img = img_path.detach().cpu().permute(1, 2, 0).numpy().astype(np.float32)
        show_img = mmcv.imdenormalize(
            show_img,
            img_norm_cfg["mean"],
            img_norm_cfg["std"],
            img_norm_cfg["to_rgb"],
        )
        show_img = np.clip(show_img, 0, 255).astype(np.uint8)
        # Reverse channel order, then HWC -> CHW.
        show_img = show_img[:, :, ::-1].transpose(2, 0, 1)

        return {
            "img_rgb": show_img,
            "img_depth_pred": self._normalize_depth_for_display(depth_pred),
            "img_depth_gt": self._normalize_depth_for_display(depth_gt),
        }

    @staticmethod
    def _normalize_depth_for_display(depth):
        """Return a detached CPU copy of ``depth`` scaled so its max is 1.

        If the maximum is not positive (e.g. an all-zero map), an unscaled
        copy is returned instead of dividing by zero and producing NaN/Inf.
        """
        # Detach first so no autograd graph is built for a logging-only op.
        depth = depth.detach().cpu()
        max_val = depth.max()
        if max_val > 0:
            return depth / max_val
        return depth.clone()
| |
|