File size: 8,653 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import copy
from abc import ABCMeta, abstractmethod
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from ...ops import resize
from ..builder import build_loss
class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for depth decode heads.

    Subclasses implement :meth:`forward`. This base class owns the final
    ``conv_depth`` projection, the conversion of its output into a depth map
    (:meth:`depth_pred`), the loss computation (:meth:`losses`) and the
    preparation of images for logging (:meth:`log_images`).

    Args:
        in_channels (List): Input channels.
        channels (int): Channels after modules, before conv_depth.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        loss_decode (dict|list[dict]|tuple[dict]): Config of decode loss, or
            a sequence of such configs.
            Default: dict(type='SigLoss', valid_mask=True, loss_weight=10).
        sampler (dict|None): The config of depth map sampler. Not used by
            this base class. Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        min_depth (float): Min depth in dataset setting.
            Default: 1e-3.
        max_depth (float|None): Max depth in dataset setting. Required when
            ``classify`` or ``scale_up`` is enabled. Default: None.
        norm_cfg (dict|None): Config of norm layers.
            Default: None.
        classify (bool): Whether predict depth in a cls.-reg. manner.
            Default: False.
        n_bins (int): The number of bins used in cls. step.
            Default: 256.
        bins_strategy (str): The discrete strategy used in cls. step,
            'UD' (uniform) or 'SID' (log-spaced). Default: 'UD'.
        norm_strategy (str): The norm strategy on cls. probability
            distribution: 'linear', 'softmax' or 'sigmoid'.
            Default: 'linear'
        scale_up (bool): Whether predict depth in a scale-up manner.
            Default: False.

    Raises:
        TypeError: If ``loss_decode`` is neither a dict nor a list/tuple.
    """

    # Offset added after ReLU in the 'linear' norm strategy so that every bin
    # keeps a non-zero weight and the normalising sum can never be zero.
    _LINEAR_NORM_EPS = 0.1

    def __init__(
        self,
        in_channels,
        channels=96,
        conv_cfg=None,
        act_cfg=dict(type="ReLU"),
        loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
        sampler=None,
        align_corners=False,
        min_depth=1e-3,
        max_depth=None,
        norm_cfg=None,
        classify=False,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        scale_up=False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.act_cfg = act_cfg

        if isinstance(loss_decode, dict):
            self.loss_decode = build_loss(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList([build_loss(cfg) for cfg in loss_decode])
        else:
            # Fail at construction time instead of leaving the attribute
            # unset and hitting an AttributeError later in ``losses``.
            raise TypeError(
                f"loss_decode must be a dict or sequence of dict, but got {type(loss_decode)}"
            )

        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.norm_cfg = norm_cfg
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up

        if self.classify:
            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"

            self.bins_strategy = bins_strategy
            self.norm_strategy = norm_strategy
            self.softmax = nn.Softmax(dim=1)
            # One output channel per depth bin.
            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
        else:
            # Direct regression: a single depth channel.
            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)

        self.fp16_enabled = False
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def extra_repr(self):
        """Extra repr."""
        return f"align_corners={self.align_corners}"

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs, img_metas):
        """Placeholder of forward function."""
        pass

    def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg):
        """Forward function for training.

        Args:
            img (Tensor): Batch of input images; only the first one is used,
                to build the logging images.
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): GT depth
            train_cfg (dict): The training config. Not used here.

        Returns:
            dict[str, Tensor]: a dictionary of loss components, extended with
                the entries returned by :meth:`log_images` for the first
                sample of the batch.
        """
        depth_pred = self.forward(inputs, img_metas)
        losses = self.losses(depth_pred, depth_gt)

        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
        losses.update(**log_imgs)

        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config. Not used here.

        Returns:
            Tensor: Output depth map.
        """
        return self.forward(inputs, img_metas)

    def _require_max_depth(self, mode):
        """Raise a clear error when ``max_depth`` is needed but was not set."""
        if self.max_depth is None:
            raise ValueError(f"max_depth must be set when {mode}, but got None")

    def _depth_bins(self, device):
        """Build the 1-D tensor of ``n_bins`` depth values for the cls. step."""
        self._require_max_depth("classify=True")

        if self.bins_strategy == "UD":
            return torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=device)

        if self.bins_strategy == "SID":
            if self.min_depth <= 0:
                raise ValueError(f"min_depth must be positive for SID bins, but got {self.min_depth}")
            # torch.logspace takes base-10 *exponents*, so pass log10 of the
            # bounds to get bins that actually span [min_depth, max_depth].
            return torch.logspace(
                float(np.log10(self.min_depth)),
                float(np.log10(self.max_depth)),
                self.n_bins,
                device=device,
            )

        raise ValueError(f"Unsupported bins_strategy: {self.bins_strategy}")

    def depth_pred(self, feat):
        """Prediction each pixel.

        Args:
            feat (Tensor): Feature map fed to ``conv_depth``; it must have
                ``channels`` channels.

        Returns:
            Tensor: Depth map with a single channel. In classification mode
                it is the per-pixel expectation of the bin depths under the
                normalised bin weights; otherwise it is either
                ``sigmoid(conv) * max_depth`` (``scale_up``) or
                ``relu(conv) + min_depth``.

        Raises:
            ValueError: If ``max_depth`` is None while ``classify`` or
                ``scale_up`` is enabled, or if SID bins are requested with a
                non-positive ``min_depth``.
        """
        if not self.classify:
            if self.scale_up:
                self._require_max_depth("scale_up=True")
                return self.sigmoid(self.conv_depth(feat)) * self.max_depth
            return self.relu(self.conv_depth(feat)) + self.min_depth

        logit = self.conv_depth(feat)
        # Match the logits' dtype so einsum does not see mixed precisions.
        bins = self._depth_bins(feat.device).to(dtype=logit.dtype)

        # following Adabins, default linear
        if self.norm_strategy == "linear":
            logit = torch.relu(logit) + self._LINEAR_NORM_EPS
            logit = logit / logit.sum(dim=1, keepdim=True)
        elif self.norm_strategy == "softmax":
            logit = torch.softmax(logit, dim=1)
        elif self.norm_strategy == "sigmoid":
            logit = torch.sigmoid(logit)
            logit = logit / logit.sum(dim=1, keepdim=True)

        # Weighted sum over the bin axis (dim 1), then restore a channel dim.
        return torch.einsum("ikmn,k->imn", logit, bins).unsqueeze(dim=1)

    @force_fp32(apply_to=("depth_pred",))
    def losses(self, depth_pred, depth_gt):
        """Compute depth loss.

        The prediction is bilinearly resized to the spatial size of
        ``depth_gt`` (``depth_gt.shape[2:]``) before every configured loss is
        evaluated.

        Args:
            depth_pred (Tensor): Predicted depth map.
            depth_gt (Tensor): GT depth.

        Returns:
            dict[str, Tensor]: Loss values keyed by each loss module's
                ``loss_name``; losses sharing a name are summed.
        """
        depth_pred = resize(
            input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
        )

        if isinstance(self.loss_decode, nn.ModuleList):
            criteria = self.loss_decode
        else:
            criteria = [self.loss_decode]

        loss = dict()
        for criterion in criteria:
            name = criterion.loss_name
            value = criterion(depth_pred, depth_gt)
            loss[name] = loss[name] + value if name in loss else value
        return loss

    @staticmethod
    def _normalize_for_display(depth):
        """Return a detached CPU copy of ``depth`` scaled by its maximum.

        When the maximum is not positive (e.g. an all-zero map) the values
        are returned unscaled instead of dividing by zero.
        """
        depth = depth.detach().cpu()
        peak = torch.max(depth)
        if peak > 0:
            return depth / peak
        return depth.clone()

    def log_images(self, img_path, depth_pred, depth_gt, img_meta):
        """Prepare one sample's image and depth maps for logging.

        Args:
            img_path (Tensor): Despite the name, this is the normalised image
                tensor of one sample (it is permuted and denormalised here),
                presumably laid out as (C, H, W) - confirm against caller.
            depth_pred (Tensor): Predicted depth of the same sample.
            depth_gt (Tensor): GT depth of the same sample.
            img_meta (dict): Image info dict; its 'img_norm_cfg' entry must
                provide 'mean', 'std' and 'to_rgb'.

        Returns:
            dict: 'img_rgb' (uint8 ndarray, channel axis first),
                'img_depth_pred' and 'img_depth_gt' (CPU tensors divided by
                their own maximum).
        """
        norm_cfg = img_meta["img_norm_cfg"]

        # astype() already returns a fresh array, so no explicit copy needed.
        show_img = img_path.detach().cpu().permute(1, 2, 0).numpy().astype(np.float32)
        show_img = mmcv.imdenormalize(
            show_img,
            norm_cfg["mean"],
            norm_cfg["std"],
            norm_cfg["to_rgb"],
        )
        show_img = np.clip(show_img, 0, 255).astype(np.uint8)
        # Reverse the channel order (presumably BGR -> RGB after
        # denormalisation - confirm against mmcv.imdenormalize).
        show_img = show_img[:, :, ::-1]
        # Move the channel axis to the front: (H, W, C) -> (C, H, W).
        show_img = show_img.transpose(2, 0, 1)

        return {
            "img_rgb": show_img,
            "img_depth_pred": self._normalize_for_display(depth_pred),
            "img_depth_gt": self._normalize_for_display(depth_gt),
        }
|