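# NOTE: the three lines below appear to be web-page residue (uploader avatar
# caption, commit message and commit hash) pasted above the code; they are
# kept as comments so the module remains importable.
# maverickrzw's picture
# des
# 2402804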
from typing import Callable, Optional, Union
from torch import nn
from torch.nn import functional as F
from mmdet.registry import MODELS
from .transformer_blocks import (Conv2d, PositionEmbeddingSine,
TransformerEncoder, TransformerEncoderLayer,
get_norm)
# modified from https://github.com/microsoft/X-Decoder/blob/main/xdecoder/body/encoder/transformer_encoder_fpn.py # noqa
class TransformerEncoderOnly(nn.Module):
    """Transformer encoder stack applied to a single 2-D feature map.

    The (N, C, H, W) input is flattened into an (H*W, N, C) token sequence,
    passed through ``TransformerEncoder`` and reshaped back to (N, C, H, W).

    Args:
        d_model: Embedding width of the encoder; expected to equal the
            channel count ``C`` of the map given to ``forward``.
        nhead: Number of attention heads per encoder layer.
        num_encoder_layers: How many copies of the encoder layer to stack.
        dim_feedforward: Hidden width of each layer's feed-forward block.
        dropout: Dropout value forwarded to each encoder layer.
        activation: Activation name forwarded to each encoder layer.
        normalize_before: Pre-norm flag forwarded to each encoder layer; when
            True a ``LayerNorm(d_model)`` is also handed to the encoder.
    """

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        layer_template = TransformerEncoderLayer(d_model, nhead,
                                                 dim_feedforward, dropout,
                                                 activation, normalize_before)
        # The encoder only receives a final norm in the pre-norm setup.
        if normalize_before:
            final_norm = nn.LayerNorm(d_model)
        else:
            final_norm = None
        self.encoder = TransformerEncoder(layer_template, num_encoder_layers,
                                          final_norm)
        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than 1 dim."""
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, src, mask, pos_embed):
        """Encode a feature map and return it in its original layout.

        Args:
            src: Feature map of shape (N, C, H, W).
            mask: Optional padding mask; flattened from dim 1 onward before
                being used as ``src_key_padding_mask``. ``None`` disables it.
            pos_embed: Positional embedding with the same layout as ``src``.

        Returns:
            Tensor of shape (N, C, H, W), assuming the encoder keeps the
            (H*W, N, C) token shape.
        """
        batch, channels, height, width = src.shape
        # NxCxHxW -> HWxNxC (sequence-first token layout).
        tokens = src.flatten(2).permute(2, 0, 1)
        pos_tokens = pos_embed.flatten(2).permute(2, 0, 1)
        padding_mask = None if mask is None else mask.flatten(1)
        encoded = self.encoder(tokens,
                               src_key_padding_mask=padding_mask,
                               pos=pos_tokens)
        # HWxNxC -> NxCxHW -> NxCxHxW
        restored = encoded.permute(1, 2, 0)
        return restored.view(batch, channels, height, width)
class BasePixelDecoder(nn.Module):
    """FPN-style pixel decoder skeleton holding per-level lateral/output convs.

    For every input level except the last, a 1x1 lateral conv (registered as
    ``adapter_<k>``) maps ``in_channel`` to ``conv_dim`` and a 3x3 conv with
    ReLU (registered as ``layer_<k>``) refines a ``conv_dim`` map. The last
    level has no lateral conv; its 3x3 ``layer_<k>`` conv reads
    ``in_channel`` directly. ``k`` is the 1-based position in ``in_channels``.

    ``self.lateral_convs`` / ``self.output_convs`` reference the same modules
    in reversed order (index 0 = last input level, whose lateral entry is
    ``None``) so a top-down pass can iterate them in step. This class defines
    no ``forward`` of its own.

    Args:
        in_channels: Channel count of each input level; the reversal assumes
            they are listed from high to low resolution.
        conv_dim: Channel width of every lateral and output conv.
        mask_dim: Output channels of ``mask_features`` (only when ``mask_on``).
        mask_on: Whether to build the 3x3 ``mask_features`` conv.
        norm: Norm spec passed to ``get_norm`` for each conv.
    """

    def __init__(
        self,
        in_channels,
        conv_dim: int,
        mask_dim: int,
        mask_on: bool,
        norm: Optional[Union[str, Callable]] = None,
    ):
        super().__init__()
        # NOTE(review): convs get a bias only when ``norm == ''``; with the
        # default ``norm=None`` they get no bias, and whether they get a norm
        # depends on ``get_norm(None, ...)`` -- confirm this is intended.
        use_bias = norm == ''
        last_idx = len(in_channels) - 1
        laterals = []
        outputs = []
        # Construction order (norms, then lateral conv, then output conv) is
        # kept stable so parameter initialisation order does not change.
        for idx, in_channel in enumerate(in_channels):
            level = idx + 1
            if idx != last_idx:
                lateral_norm = get_norm(norm, conv_dim)
                output_norm = get_norm(norm, conv_dim)
                lateral = Conv2d(
                    in_channel,
                    conv_dim,
                    kernel_size=1,
                    bias=use_bias,
                    norm=lateral_norm)
                refine_in = conv_dim
            else:
                output_norm = get_norm(norm, conv_dim)
                lateral = None
                refine_in = in_channel
            refine = Conv2d(
                refine_in,
                conv_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=output_norm,
                activation=F.relu,
            )
            if lateral is not None:
                self.add_module('adapter_{}'.format(level), lateral)
            self.add_module('layer_{}'.format(level), refine)
            laterals.append(lateral)
            outputs.append(refine)

        # Top-down order (from low to high resolution) for the forward pass.
        self.lateral_convs = list(reversed(laterals))
        self.output_convs = list(reversed(outputs))

        self.mask_on = mask_on
        if mask_on:
            self.mask_dim = mask_dim
            self.mask_features = Conv2d(
                conv_dim,
                mask_dim,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        self.maskformer_num_feature_levels = 3
# To prevent conflicts with TransformerEncoderPixelDecoder in mask2former,
# we change the name to XTransformerEncoderPixelDecoder
@MODELS.register_module()
class XTransformerEncoderPixelDecoder(BasePixelDecoder):
    """Pixel decoder whose last input level is refined by a transformer.

    Builds on ``BasePixelDecoder`` (with ``mask_on=True``). The last input
    level ('res5', presumably the coarsest) is projected to ``conv_dim``
    channels by ``input_proj``, refined by ``TransformerEncoderOnly`` using
    ``PositionEmbeddingSine`` embeddings, and then merged top-down with the
    finer levels FPN-style.

    Args:
        in_channels: Channel counts of the 'res2', 'res3', 'res4' and 'res5'
            inputs, in that order. Exactly one entry per name in
            ``in_features`` (four) is required.
        transformer_dropout: Dropout value for the encoder layers.
        transformer_nheads: Number of attention heads.
        transformer_dim_feedforward: Hidden width of the encoder FFN.
        transformer_enc_layers: Number of stacked encoder layers.
        transformer_pre_norm: Pre-norm flag for the encoder layers.
        conv_dim: Channel width of the decoder convs and of the encoder.
        mask_dim: Output channels of the mask-feature conv.
        norm: Norm spec passed to ``get_norm``; convs only get a bias when
            this is ``''``.

    Raises:
        ValueError: If ``in_channels`` does not have exactly one entry per
            feature name in ``in_features``.
    """

    def __init__(
        self,
        in_channels,
        transformer_dropout: float = 0.0,
        transformer_nheads: int = 8,
        transformer_dim_feedforward: int = 2048,
        transformer_enc_layers: int = 6,
        transformer_pre_norm: bool = False,
        conv_dim: int = 512,
        mask_dim: int = 512,
        norm: Optional[Union[str, Callable]] = 'GN',
    ):
        # Feature keys read in ``forward``, ordered like ``in_channels``.
        in_features = ['res2', 'res3', 'res4', 'res5']
        # Fail early: with a different level count the layer replacement
        # below would delete a missing or wrong ``layer_<k>`` module (leaving
        # an unregistered conv in ``output_convs``) and ``forward`` would
        # pair features with the wrong convs.
        if len(in_channels) != len(in_features):
            raise ValueError(
                'XTransformerEncoderPixelDecoder expects {} in_channels '
                'entries (one per {}), got {}.'.format(
                    len(in_features), in_features, len(in_channels)))
        super().__init__(
            in_channels,
            conv_dim=conv_dim,
            mask_dim=mask_dim,
            norm=norm,
            mask_on=True)
        self.in_features = in_features

        # The last level is projected to ``conv_dim`` before the encoder.
        self.input_proj = Conv2d(in_channels[-1], conv_dim, kernel_size=1)
        self.transformer = TransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            normalize_before=transformer_pre_norm,
        )
        self.pe_layer = PositionEmbeddingSine(conv_dim // 2, normalize=True)

        # Replace the base class's output conv for the last level: it now
        # consumes the ``conv_dim``-channel encoder output instead of the raw
        # ``in_channels[-1]`` feature map.
        use_bias = norm == ''
        output_norm = get_norm(norm, conv_dim)
        output_conv = Conv2d(
            conv_dim,
            conv_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias,
            norm=output_norm,
            activation=F.relu,
        )
        last_layer_name = 'layer_{}'.format(len(self.in_features))
        delattr(self, last_layer_name)
        self.add_module(last_layer_name, output_conv)
        # ``output_convs`` is in top-down order, so index 0 is the last level.
        self.output_convs[0] = output_conv

    def forward(self, features):
        """Run the top-down decoder over the input feature maps.

        Args:
            features (dict): Maps every name in ``self.in_features`` to its
                feature map (presumably (N, C, H, W) tensors).

        Returns:
            tuple: ``(mask_features, multi_scale_features)``.
            ``mask_features`` is ``self.mask_features`` applied to the final
            (last processed, 'res2') decoder output; ``multi_scale_features``
            holds the first ``self.maskformer_num_feature_levels`` decoder
            outputs in processing order ('res5' first).
        """
        multi_scale_features = []
        # Reverse feature maps into top-down order (from low to high
        # resolution); the base class reversed the conv lists to match.
        for idx, name in enumerate(self.in_features[::-1]):
            x = features[name]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                # Only index 0 has no lateral conv, so ``y`` is always set
                # here before the FPN branch below reads it.
                projected = self.input_proj(x)
                pos = self.pe_layer(x)
                encoded = self.transformer(projected, None, pos)
                y = output_conv(encoded)
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(
                    y, size=cur_fpn.shape[-2:], mode='nearest')
                y = output_conv(y)
            if len(multi_scale_features) < self.maskformer_num_feature_levels:
                multi_scale_features.append(y)
        mask_features = self.mask_features(y)
        return mask_features, multi_scale_features