cwangrun's picture
Upload 156 files
5ebfade verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.ops import resize
@HEADS.register_module()
class BNHead(BaseDecodeHead):
    """Decode head that batch-normalizes the selected features.

    The features picked by ``_transform_inputs`` go through a single
    ``nn.SyncBatchNorm`` layer and are then handed to ``self.cls_seg``.

    Args:
        resize_factors (sequence | None): Optional scale factors, one per
            selected input, applied in ``"resize_concat"`` mode before the
            inputs are brought to a common size. ``None`` skips this step.
        **kwargs: Forwarded unchanged to the parent class constructor.
    """

    def __init__(self, resize_factors=None, **kwargs):
        super().__init__(**kwargs)
        # The head adds no projection, so the channel count must not change.
        assert self.in_channels == self.channels
        self.resize_factors = resize_factors
        self.bn = nn.SyncBatchNorm(self.in_channels)

    def _forward_feature(self, inputs):
        """Compute the feature map that is fed to ``self.cls_seg``.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs after batch normalization.
        """
        return self.bn(self._transform_inputs(inputs))

    def _transform_inputs(self, inputs):
        """Select (and, if configured, resize and concatenate) the inputs.

        Args:
            inputs (list[Tensor]): List of multi-level img features. In
                ``"resize_concat"`` mode an entry may itself be a list, and
                entries with a 2-D shape are accepted.

        Returns:
            Tensor | list[Tensor]: A single concatenated tensor for
            ``"resize_concat"``, a list of the selected entries for
            ``"multiple_select"``, otherwise ``inputs[self.in_index]``.
        """
        transform = self.input_transform

        if transform == "multiple_select":
            return [inputs[idx] for idx in self.in_index]
        if transform != "resize_concat":
            return inputs[self.in_index]

        # Nested lists are spliced in place so that every entry is one level
        # deep (this is how e.g. a cls token can be passed along).
        flat = []
        for entry in inputs:
            flat += entry if isinstance(entry, list) else [entry]

        # A 2-D entry is an image-level descriptor; give it a 1x1 spatial
        # extent so it can be resized like the other feature maps.
        flat = [
            feat[:, :, None, None] if len(feat.shape) == 2 else feat
            for feat in flat
        ]

        selected = [flat[idx] for idx in self.in_index]

        factors = self.resize_factors
        if factors is not None:
            assert len(factors) == len(selected), (len(factors), len(selected))
            rescaled = []
            for feat, factor in zip(selected, factors):
                # Bilinear when enlarging (or keeping size), area when shrinking.
                interp = "bilinear" if factor >= 1 else "area"
                rescaled.append(resize(input=feat, scale_factor=factor, mode=interp))
            selected = rescaled

        # Bring every map to the spatial size of the first one, then stack
        # them along the channel dimension.
        aligned = [
            resize(
                input=feat,
                size=selected[0].shape[2:],
                mode="bilinear",
                align_corners=self.align_corners,
            )
            for feat in selected
        ]
        return torch.cat(aligned, dim=1)

    def forward(self, inputs):
        """Forward function."""
        return self.cls_seg(self._forward_feature(inputs))