|
|
|
|
|
|
|
|
|
|
|
|
| import numpy as np
|
| from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
|
| import torch
|
| from torch import nn
|
| from torch.nn import functional as F
|
| from torch.nn.init import normal_
|
| from torch.amp import autocast
|
|
|
| from dinov3.eval.segmentation.models.utils.batch_norm import get_norm
|
| from dinov3.eval.segmentation.models.utils.position_encoding import PositionEmbeddingSine
|
| from dinov3.eval.segmentation.models.utils.transformer import _get_clones, _get_activation_fn
|
| from dinov3.eval.segmentation.models.utils.ms_deform_attn import MSDeformAttn
|
|
|
|
|
def c2_xavier_fill(module: nn.Module) -> None:
    """
    Apply Caffe2's "XavierFill" to ``module.weight`` and zero ``module.bias``.

    Caffe2's XavierFill is equivalent to PyTorch's ``kaiming_uniform_`` with
    ``a=1``, i.e. samples from U(-sqrt(3 / fan_in), sqrt(3 / fan_in)).

    Args:
        module (torch.nn.Module): module exposing ``weight`` and ``bias``
            attributes; ``bias`` may be ``None``, in which case it is skipped.
    """
    weight = module.weight
    nn.init.kaiming_uniform_(weight, a=1)

    bias = module.bias
    if bias is None:
        return
    nn.init.constant_(bias, 0)
|
|
|
|
|
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` that can apply an optional
    normalization layer and an optional activation after the convolution.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that norm layer is used before activation.
        """
        # Strip the extra options before nn.Conv2d sees (and rejects) them.
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        """Convolve ``x``, then apply ``norm`` and ``activation`` when they are set."""
        # Delegate to nn.Conv2d.forward so that a non-"zeros" ``padding_mode``
        # (e.g. "reflect") is honoured. Calling F.conv2d directly with
        # ``self.padding`` would silently fall back to zero padding.
        x = super().forward(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
|
|
|
|
|
|
|
class MSDeformAttnTransformerEncoderOnly(nn.Module):
    """
    Encoder-only multi-scale deformable-attention transformer.

    Flattens a list of per-level feature maps into one token sequence, adds a
    learned per-level embedding to the positional encodings and runs a stack of
    ``num_encoder_layers`` :class:`MSDeformAttnTransformerEncoderLayer` copies.
    """

    def __init__(
        self,
        d_model=256,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=1024,
        dropout=0.1,
        activation="relu",
        num_feature_levels=4,
        enc_n_points=4,
    ):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        encoder_layer = MSDeformAttnTransformerEncoderLayer(
            d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points
        )
        self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers)

        # One learned embedding per feature level. Allocated uninitialised here
        # and filled by normal_ in _reset_parameters.
        self.level_encoding = nn.Parameter(torch.empty(num_feature_levels, d_model))

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init every matrix parameter, re-init MSDeformAttn modules, normal-init level embeddings."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # MSDeformAttn has its own initialisation, which overrides the generic
        # Xavier init above.
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        normal_(self.level_encoding)

    def get_valid_ratio(self, mask):
        """
        Return the fraction of unpadded rows/columns per sample.

        Args:
            mask: boolean tensor of shape (B, H, W), True marking padded positions.

        Returns:
            Float tensor of shape (B, 2) holding ``(valid_ratio_w, valid_ratio_h)``,
            measured along the first column (height) and first row (width).
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(self, srcs, pos_embeds):
        """
        Run the encoder over multi-level features.

        Args:
            srcs: list of 4-D feature maps ``(B, C, H_l, W_l)``, one per level.
            pos_embeds: list of positional encodings, one per level, flattened the
                same way as ``srcs`` (presumably also ``(B, C, H_l, W_l)`` — confirm).

        Returns:
            Tuple ``(memory, spatial_shapes, level_start_index)``:
            - ``memory`` is the encoder output for the concatenated tokens of all
              levels, with ``sum(H_l * W_l)`` tokens on dim 1.
            - ``spatial_shapes`` is a long tensor of shape (L, 2) holding ``(H_l, W_l)``.
            - ``level_start_index`` is a long tensor of shape (L,) giving the token
              offset where each level starts.
        """
        # No padding is used: every mask is all-False, so all valid ratios are 1.
        masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]

        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            _, _, h, w = src.shape
            spatial_shapes.append((h, w))
            # (B, C, H, W) -> (B, H*W, C)
            src_flatten.append(src.flatten(2).transpose(1, 2))
            mask_flatten.append(mask.flatten(1))
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            # Broadcast this level's learned embedding over batch and tokens.
            lvl_pos_embed_flatten.append(pos_embed + self.level_encoding[lvl].view(1, 1, -1))
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        # Exclusive cumulative sum of H_l * W_l: first token index of each level.
        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        memory = self.encoder(
            src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten
        )

        return memory, spatial_shapes, level_start_index
|
|
|
|
|
class MSDeformAttnTransformerEncoderLayer(nn.Module):
    """
    A single post-norm encoder layer.

    The layer applies multi-scale deformable self-attention and then a two-layer
    feed-forward network. Each sub-block is wrapped in a residual connection
    followed by LayerNorm.
    """

    def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # Attention sub-block. Registration order is kept stable because parent
        # modules initialise parameters in parameters() order.
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Feed-forward sub-block: d_model -> d_ffn -> d_model.
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_ffn(self, src):
        """Residual feed-forward block followed by LayerNorm."""
        hidden = self.activation(self.linear1(src))
        projected = self.linear2(self.dropout2(hidden))
        return self.norm2(src + self.dropout3(projected))

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
        """Run deformable self-attention (queries carry ``pos``) and then the FFN."""
        query = self.with_pos_embed(src, pos)
        # Values are the raw ``src`` without positional encoding.
        attended = self.self_attn(query, reference_points, src, spatial_shapes, level_start_index, padding_mask)
        normed = self.norm1(src + self.dropout1(attended))
        return self.forward_ffn(normed)
|
|
|
|
|
class MSDeformAttnTransformerEncoder(nn.Module):
    """Stack of ``num_layers`` cloned encoder layers that share one set of reference points."""

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """
        Build normalised reference points at every pixel centre of every level.

        Args:
            spatial_shapes: iterable of ``(H_l, W_l)`` pairs, one per level.
            valid_ratios: tensor of shape (B, L, 2) holding ``(ratio_w, ratio_h)``
                per sample and level.
            device: device for the generated coordinate grids.

        Returns:
            Tensor of shape (B, sum(H_l * W_l), L, 2) with ``(x, y)`` coordinates.
            Each point is expressed relative to the valid region of its own level,
            then rescaled into the valid region of every level.
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # Pixel-centre coordinates 0.5, 1.5, ..., size - 0.5. Pass indexing="ij"
            # explicitly (the current default) to avoid torch.meshgrid's
            # deprecation warning and future default changes.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
                indexing="ij",
            )
            # Normalise by the valid (unpadded) extent of each sample at this level.
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            reference_points_list.append(torch.stack((ref_x, ref_y), -1))
        reference_points = torch.cat(reference_points_list, 1)
        # (B, N, 1, 2) * (B, 1, L, 2) -> (B, N, L, 2)
        return reference_points[:, :, None] * valid_ratios[:, None]

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
        """Apply every layer in sequence; reference points are computed once and reused."""
        output = src
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for layer in self.layers:
            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)

        return output
|
|
|
|
|
|
|
class MSDeformAttnPixelDecoder(nn.Module):
    """
    Pixel decoder combining a multi-scale deformable-attention encoder with a one-level FPN.

    Inputs are sorted by stride. Every transformer input feature except the
    finest one is projected to ``conv_dim`` channels and refined jointly by
    :class:`MSDeformAttnTransformerEncoderOnly`. The finest input feature
    (``in_features[0]``) is then fused with the upsampled finest encoder output
    through a single lateral/output convolution pair. ``mask_feature`` projects
    that result to ``mask_dim`` channels.
    """

    def __init__(
        self,
        input_shape: Dict[str, Tuple[int]],
        *,
        transformer_dropout: float,
        transformer_nheads: int,
        transformer_dim_feedforward: int,
        transformer_enc_layers: int,
        conv_dim: int,
        mask_dim: int,
        norm: Optional[Union[str, Callable]] = None,
        transformer_in_features: List[str],
        common_stride: int,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: shapes of the input features, keyed by feature name.
                ``v[0]`` is read as the channel count and ``v[-1]`` as the stride.
            transformer_dropout: dropout probability in transformer
            transformer_nheads: number of heads in transformer
            transformer_dim_feedforward: dimension of feedforward network
            transformer_enc_layers: number of transformer encoder layers
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): normalization for the FPN lateral/output conv
                layers, passed to ``get_norm``.
            transformer_in_features: names of the input features used by the
                transformer branch. The finest of them is not fed to the encoder,
                so at least two are required.
            common_stride: target stride of the mask features; stored and used to
                compute ``num_fpn_levels``.

        Raises:
            ValueError: if fewer than two of ``transformer_in_features`` are present
                in ``input_shape``.
        """
        super().__init__()
        transformer_input_shape = {k: v for k, v in input_shape.items() if k in transformer_in_features}

        # Sort all inputs by stride, finest resolution first.
        input_shape = sorted(input_shape.items(), key=lambda x: x[1][-1])
        self.in_features = [k for k, _ in input_shape]
        self.feature_strides = [v[-1] for _, v in input_shape]
        self.feature_channels = [v[0] for _, v in input_shape]

        transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1][-1])
        self.transformer_in_features = [k for k, _ in transformer_input_shape]
        transformer_in_channels = [v[0] for _, v in transformer_input_shape]
        self.transformer_feature_strides = [v[-1] for _, v in transformer_input_shape]

        if len(self.transformer_in_features) < 2:
            raise ValueError(
                "MSDeformAttnPixelDecoder needs at least two transformer_in_features present in input_shape, "
                f"got {self.transformer_in_features}"
            )

        # The finest transformer feature is skipped by the encoder (see forward_features),
        # so the encoder sees one level fewer than len(transformer_in_features).
        # Deriving this (instead of hard-coding 3) keeps it consistent with the number
        # of input projections built below.
        self.transformer_num_feature_levels = len(self.transformer_in_features) - 1

        # One 1x1 projection + GroupNorm per encoder level, ordered coarsest to finest.
        self.input_convs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(in_channels, conv_dim, kernel_size=1),
                    nn.GroupNorm(32, conv_dim),
                )
                for in_channels in transformer_in_channels[::-1][:-1]
            ]
        )

        for proj in self.input_convs:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

        self.encoder = MSDeformAttnTransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            num_feature_levels=self.transformer_num_feature_levels,
        )
        N_steps = conv_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        self.mask_dim = mask_dim

        self.mask_feature = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        c2_xavier_fill(self.mask_feature)

        # Number of leading entries of the decoder outputs returned as multi-scale features.
        self.maskformer_num_feature_levels = 3
        self.common_stride = common_stride

        # Not used by forward_features; kept as a public attribute.
        stride = min(self.transformer_feature_strides)
        self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride))

        lateral_convs = []
        output_convs = []

        # NOTE(review): with the default ``norm=None`` this gives ``use_bias=False``.
        # If get_norm returns no layer for None, the FPN convs then have neither
        # bias nor normalization. Kept unchanged for checkpoint compatibility;
        # confirm it is intended.
        use_bias = norm == ""
        # Only the single finest input feature gets an FPN level.
        for in_channels in self.feature_channels[:1]:
            lateral_norm = get_norm(norm, conv_dim)
            output_norm = get_norm(norm, conv_dim)

            lateral_conv = Conv2d(in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm)
            output_conv = Conv2d(
                conv_dim,
                conv_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=output_norm,
                activation=F.relu,
            )
            c2_xavier_fill(lateral_conv)
            c2_xavier_fill(output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)

        self.lateral_convs = nn.ModuleList(lateral_convs[::-1])
        self.output_convs = nn.ModuleList(output_convs[::-1])

    @autocast(device_type="cuda", enabled=False)
    def forward_features(self, features):
        """
        Compute mask features and multi-scale features. Runs in float32 with autocast disabled.

        Args:
            features: mapping from feature name to feature map. It must contain the
                transformer features and ``in_features[0]``. Maps are presumably
                (B, C, H, W); confirm with the backbone.

        Returns:
            Tuple ``(mask_features, coarsest_encoder_output, multi_scale_features)``:
            - ``mask_features`` is ``mask_feature`` applied to the FPN output.
            - ``coarsest_encoder_output`` is the first (coarsest) encoder level.
            - ``multi_scale_features`` holds the first
              ``maskformer_num_feature_levels`` outputs, ordered coarse to fine.
        """
        srcs = []
        pos = []
        # Coarsest to finest, skipping the finest transformer feature (handled by the FPN).
        for idx, f in enumerate(self.transformer_in_features[::-1][:-1]):
            x = features[f].float()
            srcs.append(self.input_convs[idx](x))
            pos.append(self.pe_layer(x))

        y, spatial_shapes, level_start_index = self.encoder(srcs, pos)
        bs = y.shape[0]

        # Tokens per level: gap between consecutive start indices; the last level runs to the end.
        split_size_or_sections = [None] * self.transformer_num_feature_levels
        for i in range(self.transformer_num_feature_levels):
            if i < self.transformer_num_feature_levels - 1:
                split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
            else:
                split_size_or_sections[i] = y.shape[1] - level_start_index[i]
        y = torch.split(y, split_size_or_sections, dim=1)

        # Un-flatten each level from (B, H*W, C) back to (B, C, H, W).
        out = [z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]) for i, z in enumerate(y)]

        # One-level FPN: fuse the finest input with the upsampled finest encoder output.
        # Slice rather than index: ``self.in_features[0]`` is a string, so enumerating it
        # would iterate its characters and break on multi-character feature names.
        for idx, f in enumerate(self.in_features[:1]):
            x = features[f].float()
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            cur_fpn = lateral_conv(x)
            fused = cur_fpn + F.interpolate(out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
            out.append(output_conv(fused))

        multi_scale_features = out[: self.maskformer_num_feature_levels]

        return self.mask_feature(out[-1]), out[0], multi_scale_features
|
|
|