# cwangrun's picture
# Upload 156 files
# 5ebfade verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from ...ops import resize
from ..builder import HEADS
from .decode_head import DepthBaseDecodeHead
@HEADS.register_module()
class BNHead(DepthBaseDecodeHead):
    """Depth head that predicts from selected backbone feature levels.

    The incoming feature levels are picked via ``in_index`` and, depending on
    ``input_transform``, optionally resized and concatenated along the channel
    axis. Despite the class name, no batch-norm layer is created in this class.

    Args:
        input_transform (str): How the selected levels are combined. Any value
            containing ``"concat"`` concatenates them on dim 1 (after resizing
            when the value also contains ``"resize"``); ``"multiple_select"``
            returns them as a list; any other value indexes ``inputs`` directly
            with ``in_index``.
        in_index (Sequence[int] | int): Indices of the feature levels to use.
        upsample (int): Factor applied to the spatial size of the first
            selected level to obtain the resize target.
        **kwargs: Forwarded unchanged to ``DepthBaseDecodeHead.__init__``.
    """

    def __init__(
        self,
        input_transform="resize_concat",
        in_index=(0, 1, 2, 3),
        upsample=1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_transform = input_transform
        self.in_index = in_index
        self.upsample = upsample

        # 1x1 projection: one channel per bin when classifying, else a single channel.
        out_channels = self.n_bins if self.classify else 1
        self.conv_depth = nn.Conv2d(self.channels, out_channels, kernel_size=1, padding=0, stride=1)

    def _transform_inputs(self, inputs):
        """Select, and optionally resize and concatenate, the feature levels.

        Args:
            inputs (list[Tensor]): Multi-level image features.

        Returns:
            Tensor | list[Tensor]: A single tensor concatenated on dim 1 for the
            ``"concat"`` modes, a list of the selected levels for
            ``"multiple_select"``, otherwise ``inputs[self.in_index]``.
        """
        mode = self.input_transform

        if "concat" not in mode:
            if mode == "multiple_select":
                return [inputs[idx] for idx in self.in_index]
            return inputs[self.in_index]

        picked = [inputs[idx] for idx in self.in_index]
        if "resize" in mode:
            # Every level is brought to the (upsampled) spatial size of the first one.
            picked_resized = [
                resize(
                    input=feat,
                    size=[dim * self.upsample for dim in picked[0].shape[2:]],
                    mode="bilinear",
                    align_corners=self.align_corners,
                )
                for feat in picked
            ]
        else:
            picked_resized = picked
        return torch.cat(picked_resized, dim=1)

    def _forward_feature(self, inputs, img_metas=None, **kwargs):
        """Normalise the per-level inputs and pass them to ``_transform_inputs``.

        Each element of ``inputs`` is an indexable container. Its first entry is
        the feature tensor; a 2-D feature tensor gets two trailing singleton
        axes. When the container has exactly two entries, the second one is
        treated as a class token: it is given two trailing singleton axes,
        expanded to the feature tensor's shape and concatenated to it on dim 1.

        Args:
            inputs (Iterable): Per-level containers as described above.
            img_metas: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The result of ``self._transform_inputs`` on the normalised levels.
        """
        levels = list(inputs)
        for idx, level in enumerate(levels):
            with_cls_token = len(level) == 2

            feat = level[0]
            if feat.ndim == 2:
                feat = feat[:, :, None, None]

            if with_cls_token:
                cls_map = level[1][:, :, None, None].expand_as(feat)
                feat = torch.cat((feat, cls_map), 1)

            levels[idx] = feat

        return self._transform_inputs(levels)

    def forward(self, inputs, img_metas=None, **kwargs):
        """Compute features and feed them to ``self.depth_pred``.

        ``depth_pred`` is not defined in this class; presumably it is provided
        by ``DepthBaseDecodeHead`` - confirm against the base class.
        """
        feats = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
        return self.depth_pred(feats)