Spaces:
Sleeping
Sleeping
| # Reference: https://github.com/google-research/deeplab2/blob/main/model/pixel_decoder/kmax.py | |
| # Modified by Qihang Yu | |
| from turtle import forward | |
| from typing import Dict, List | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from timm.models.layers import DropPath | |
| from timm.models.layers import trunc_normal_tf_ as trunc_normal_ | |
| from detectron2.config import configurable | |
| from detectron2.layers import ShapeSpec | |
| from detectron2.modeling import SEM_SEG_HEADS_REGISTRY | |
| from torch.cuda.amp import autocast | |
| from ..backbone.convnext import LayerNorm | |
| import math | |
def get_activation(name):
    """Build the activation module selected by ``name``.

    Args:
        name: ``None`` or ``'none'`` for no activation, or ``'relu'`` / ``'gelu'``.
            Matching is case-insensitive.

    Returns:
        A fresh ``nn.Identity``, ``nn.ReLU`` or ``nn.GELU`` instance.

    Raises:
        ValueError: If ``name`` is not one of the supported activations.
    """
    if name is None:
        return nn.Identity()

    key = name.lower()
    if key == 'none':
        return nn.Identity()
    if key == 'relu':
        return nn.ReLU()
    if key == 'gelu':
        return nn.GELU()

    # An unknown name used to fall through and return None, which only blew up
    # later when the "activation" was called; fail here with a clear message.
    raise ValueError(f"Unsupported activation: {name!r}")
class SyncBNCPU(nn.SyncBatchNorm):
    """``nn.SyncBatchNorm`` whose forward always normalizes with the running statistics.

    The overridden ``forward`` never uses batch statistics: it calls
    ``F.batch_norm`` with ``training=False`` regardless of the module's mode.
    """

    def forward(self, input):
        """Normalize ``input`` with ``running_mean`` / ``running_var`` and the affine parameters."""
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # A momentum of None is passed on as 0.0.
        momentum = 0.0 if self.momentum is None else self.momentum

        # fallback to framework BN when synchronization is not necessary
        return F.batch_norm(
            input,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            False,  # training flag: always use the running statistics
            momentum,
            self.eps,
        )
def get_norm(name, channels):
    """Build the normalization layer selected by ``name``.

    Args:
        name: ``None`` or ``'none'`` for no normalization, or ``'syncbn'``.
            Matching is case-insensitive.
        channels: Number of channels the norm layer operates on.

    Returns:
        ``nn.Identity()`` when no normalization is requested, otherwise a
        ``SyncBNCPU(channels, eps=1e-3, momentum=0.01)``.

    Raises:
        ValueError: If ``name`` is not one of the supported norms.
    """
    if name is None or name.lower() == 'none':
        return nn.Identity()

    if name.lower() == 'syncbn':
        return SyncBNCPU(channels, eps=1e-3, momentum=0.01)

    # An unknown name used to fall through and return None; fail loudly instead.
    raise ValueError(f"Unsupported norm: {name!r}")
class ConvBN(nn.Module):
    """Convolution followed by an optional norm layer and an optional activation.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias:
            Forwarded unchanged to ``nn.Conv2d`` / ``nn.Conv1d``.
        norm: Norm name understood by ``get_norm`` (``None`` disables it).
        act: Activation name understood by ``get_activation`` (``None`` disables it).
        conv_type: ``'2d'`` or ``'1d'``.
        conv_init: ``'normal'``, ``'trunc_normal'``, ``'he_normal'`` or
            ``'xavier_uniform'``. Any other value keeps PyTorch's default
            convolution initialization.
        norm_init: Constant the norm layer's weight is initialized to.

    Raises:
        ValueError: If ``conv_type`` is neither ``'2d'`` nor ``'1d'``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, norm=None, act=None,
                 conv_type='2d', conv_init='he_normal', norm_init=1.0):
        super().__init__()

        if conv_type == '2d':
            conv_cls = nn.Conv2d
        elif conv_type == '1d':
            conv_cls = nn.Conv1d
        else:
            # Previously self.conv was simply never created and a later attribute
            # access failed with an unrelated-looking AttributeError.
            raise ValueError(f"Unsupported conv_type: {conv_type!r} (expected '2d' or '1d')")

        self.conv = conv_cls(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                             padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.norm = get_norm(norm, out_channels)
        self.act = get_activation(act)

        if conv_init == 'normal':
            nn.init.normal_(self.conv.weight, std=.02)
        elif conv_init == 'trunc_normal':
            trunc_normal_(self.conv.weight, std=.02)
        elif conv_init == 'he_normal':
            # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/HeNormal
            # NOTE(review): the std uses in_channels only (kernel size and groups are
            # not part of the fan-in here) - presumably intentional; confirm.
            trunc_normal_(self.conv.weight, std=math.sqrt(2.0 / in_channels))
        elif conv_init == 'xavier_uniform':
            nn.init.xavier_uniform_(self.conv.weight)

        if bias:
            nn.init.zeros_(self.conv.bias)

        # get_norm() maps both None and the string 'none' to nn.Identity, which has
        # no weight; checking for the attribute (instead of `norm is not None`)
        # keeps norm='none' from raising AttributeError.
        if getattr(self.norm, 'weight', None) is not None:
            nn.init.constant_(self.norm.weight, norm_init)

    def forward(self, x):
        """Apply convolution, then norm, then activation."""
        return self.act(self.norm(self.conv(x)))
| MAX_SPAN = 255 | |
| def _compute_relative_distance_matrix(query_length, key_length): | |
| if (key_length - query_length) % 2: | |
| raise ValueError('Key_length should be query_length + 2 * memory_flange.') | |
| key_index = torch.arange(key_length) | |
| query_index = torch.arange(query_length) + (key_length - query_length) // 2 | |
| distance_matrix = key_index[None, :] - query_index[:, None] | |
| # Shift the distance_matrix so that it is >= 0. Each entry of the | |
| # distance_matrix distance will index a relative positional embedding. | |
| distance_matrix = distance_matrix + MAX_SPAN - 1 | |
| return distance_matrix | |
class RelativePositionalEncoding(nn.Module):
    """Learned relative positional encoding looked up by query/key offset.

    Holds an embedding table with ``2 * MAX_SPAN - 1`` rows of size ``depth``
    and a fixed index matrix from ``_compute_relative_distance_matrix``.

    Args:
        query_length: Number of query positions.
        key_length: Number of key positions.
        depth: Size of each embedding vector.
    """

    def __init__(self, query_length, key_length, depth):
        super().__init__()
        self._embeddings = nn.Embedding(MAX_SPAN * 2 - 1, depth)
        trunc_normal_(self._embeddings.weight, std=1.0)
        # Registered as a buffer so the index matrix follows the module across
        # .to()/.cuda(); non-persistent so it stays out of state_dict and existing
        # checkpoints keep loading unchanged.
        self.register_buffer(
            '_relative_distance_matrix',
            _compute_relative_distance_matrix(query_length, key_length),
            persistent=False)
        self.query_length = query_length
        self.key_length = key_length
        self.depth = depth

    def forward(self):
        """Return the encodings as a ``(query_length, key_length, depth)`` tensor."""
        # An embedding lookup with a (query_length, key_length) index matrix
        # yields exactly that shape, so no flatten/reshape is needed.
        return self._embeddings(self._relative_distance_matrix)
| # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_layers.py#L36 | |
# Multi-head self-attention along one axis. forward() takes a (N, C, L) tensor and
# returns (N, total_value_depth, L). Relative positional terms are added for the
# queries, the keys and the values.
# NOTE(review): indentation is lost in this copy of the file, so the code is left
# untouched here; see the note at the autocast block in forward().
| class AxialAttention(nn.Module): | |
| def __init__(self, in_planes, query_shape=56, total_key_depth=512, total_value_depth=1024, num_heads=8): | |
# Both depths are split evenly across heads, so they must be divisible by num_heads.
| assert (total_key_depth % num_heads == 0) and (total_value_depth % num_heads == 0) | |
| super().__init__() | |
| self._in_planes = in_planes | |
| self._query_shape = query_shape | |
| self._total_key_depth = total_key_depth | |
| self._total_value_depth = total_value_depth | |
| self._num_heads = num_heads | |
| self._key_depth_per_head = total_key_depth // num_heads | |
# One 1x1 Conv1d (no bias, norm or activation) produces q, k and v in a single pass:
# total_key_depth channels each for q and k, total_value_depth channels for v.
| self.qkv_transform = ConvBN(in_planes, self._total_key_depth * 2 + self._total_value_depth, kernel_size=1, stride=1, | |
| padding=0, bias=False, norm=None, act=None, conv_type='1d') | |
# Re-initialize the projection weight, replacing the init ConvBN applied in its constructor.
| trunc_normal_(self.qkv_transform.conv.weight, std=in_planes ** -0.5) | |
# Separate relative positional tables for queries, keys and values; query and key
# lengths are both query_shape.
| self._query_rpe = RelativePositionalEncoding(query_shape, query_shape, self._key_depth_per_head) | |
| self._key_rpe = RelativePositionalEncoding(query_shape, query_shape, self._key_depth_per_head) | |
| self._value_rpe = RelativePositionalEncoding(query_shape, query_shape, total_value_depth // num_heads) | |
# Norm layers for the qkv projection, for the three stacked similarity terms
# (3 * num_heads channels) and for the two stacked retrieved terms.
| self._batch_norm_qkv = get_norm('syncbn', self._total_key_depth * 2 + self._total_value_depth) | |
| self._batch_norm_similarity = get_norm('syncbn', num_heads * 3) | |
| self._batch_norm_retrieved_output = get_norm('syncbn', self._total_value_depth * 2) | |
| def forward(self, x): | |
# C is unpacked but not used below.
| N, C, L = x.shape | |
| qkv = self._batch_norm_qkv(self.qkv_transform(x)) | |
| q, k, v = torch.split(qkv, [self._total_key_depth, self._total_key_depth, self._total_value_depth], dim=1) | |
# Split the channel dimension into (heads, depth per head).
| q = q.reshape(N, self._num_heads, self._total_key_depth // self._num_heads, L) | |
| k = k.reshape(N, self._num_heads, self._total_key_depth // self._num_heads, L) | |
| v = v.reshape(N, self._num_heads, self._total_value_depth // self._num_heads, L) | |
# This empty list is never read: similarity_logits is reassigned by torch.cat below.
| similarity_logits = [] | |
# Three (N, heads, L, L) terms: content (q.k), query-vs-position and key-vs-position.
| content_similarity = torch.einsum('bhdl,bhdm->bhlm', q, k) | |
| query_rpe = self._query_rpe() | |
| query_rpe_similarity = torch.einsum('bhdl,lmd->bhlm', q, query_rpe) | |
| key_rpe = self._key_rpe() | |
| key_rpe_similarity = torch.einsum('bhdm,lmd->bhlm', k, key_rpe) | |
# Stack the three terms on the head axis, normalize them jointly, then sum the three groups.
| similarity_logits = torch.cat([content_similarity, query_rpe_similarity, key_rpe_similarity], dim=1) | |
| similarity_logits = self._batch_norm_similarity(similarity_logits).reshape(N, 3, self._num_heads, L, L).sum(dim=1) | |
# Softmax over the key axis, computed on float32 logits with autocast disabled.
# NOTE(review): with indentation lost it is unclear how many of the statements after
# the softmax belong to this `with` body - confirm against the original file.
| with autocast(enabled=False): | |
| weights = F.softmax(similarity_logits.float(), dim=-1) | |
| retrieved_content = torch.einsum('bhlm,bhdm->bhdl', weights, v) | |
| value_rpe = self._value_rpe() | |
| retrieved_rpe = torch.einsum('bhlm,lmd->bhdl', weights, value_rpe) | |
# Stack content and positional retrievals, normalize jointly, then sum the two groups
# to get (N, total_value_depth, L).
| retrieved_output = torch.cat([retrieved_content, retrieved_rpe], dim=1).reshape(N, 2*self._total_value_depth, L) | |
| retrieved_output = self._batch_norm_retrieved_output(retrieved_output).reshape(N, 2, self._total_value_depth, L).sum(1) | |
| return retrieved_output | |
| # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_layers.py#L316 | |
class AxialAttention2D(nn.Module):
    """Axial attention over a 2D map: height-axis attention followed by width-axis attention.

    Args:
        in_planes: Channel count of the input map.
        query_shape: ``(height, width)`` the two axial layers are built for.
        filters: Base channel count; the key depth is ``round(filters * key_expansion)``
            and the value depth is ``round(filters * value_expansion)``.
        key_expansion: Multiplier applied to ``filters`` for the total key depth.
        value_expansion: Multiplier applied to ``filters`` for the total value depth.
        num_heads: Number of attention heads in each axial layer.
    """

    # The default query_shape is a tuple rather than a list so the shared default
    # object cannot be mutated; it is only ever indexed.
    def __init__(self, in_planes, query_shape=(56, 56), filters=512, key_expansion=1, value_expansion=2, num_heads=8):
        super().__init__()
        total_key_depth = int(round(filters * key_expansion))
        total_value_depth = int(round(filters * value_expansion))
        self._total_key_depth = total_key_depth
        self._total_value_depth = total_value_depth
        self._height_axis = AxialAttention(
            in_planes=in_planes,
            query_shape=query_shape[0],
            total_key_depth=total_key_depth,
            total_value_depth=total_value_depth,
            num_heads=num_heads)
        # The width layer consumes the height layer's output, hence total_value_depth inputs.
        self._width_axis = AxialAttention(
            in_planes=total_value_depth,
            query_shape=query_shape[1],
            total_key_depth=total_key_depth,
            total_value_depth=total_value_depth,
            num_heads=num_heads)

    def forward(self, x):
        """Map ``(N, C, H, W)`` to ``(N, total_value_depth, H, W)``."""
        batch, channels, height, width = x.shape
        value_depth = self._total_value_depth

        # N C H W -> N W C H, then fold W into the batch for height-axis attention.
        x = x.permute(0, 3, 1, 2).contiguous()
        x = x.reshape(batch * width, channels, height)
        x = self._height_axis(x)

        # N W C H -> N H C W, then fold H into the batch for width-axis attention.
        x = x.reshape(batch, width, value_depth, height).permute(0, 3, 2, 1).contiguous()
        x = x.reshape(batch * height, value_depth, width)
        x = self._width_axis(x)

        # N H C W -> N C H W.
        x = x.reshape(batch, height, value_depth, width).permute(0, 2, 1, 3).contiguous()
        return x.reshape(batch, value_depth, height, width)
| # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_blocks.py#L36 | |
class SingleBlock(nn.Module):
    """Residual block: 1x1 conv, then axial attention or a 3x3 conv, then a 1x1 conv.

    Args:
        inplanes: Channel count of the input.
        filter_list: Three channel counts ``[conv1_out, middle, conv3_out]``.
        block_type: ``'axial'`` or ``'bottleneck'`` (case-insensitive).
        query_shape: ``(height, width)`` passed to ``AxialAttention2D``
            (only used by ``'axial'`` blocks).
        key_expansion, value_expansion, num_heads: Passed to ``AxialAttention2D``
            (only used by ``'axial'`` blocks).
        drop_path_prob: Drop-path probability on the residual branch;
            values ``<= 0`` disable it.

    Raises:
        ValueError: If ``block_type`` is neither ``'axial'`` nor ``'bottleneck'``.
    """

    # The default query_shape is a tuple so the shared default cannot be mutated.
    def __init__(self, inplanes, filter_list, block_type, query_shape=(56, 56), key_expansion=1, value_expansion=2, num_heads=8, drop_path_prob=0.0):
        super().__init__()
        self._block_type = block_type.lower()
        if self._block_type not in ('axial', 'bottleneck'):
            # Previously an unknown type left `output_channel` unbound and failed
            # a few lines later with an UnboundLocalError.
            raise ValueError(f"Unsupported block_type: {block_type!r} (expected 'axial' or 'bottleneck')")

        self._filter_list = filter_list
        self._conv1_bn_act = ConvBN(inplanes, self._filter_list[0], kernel_size=1, bias=False, norm='syncbn', act='gelu')

        if self._block_type == 'axial':
            self._attention = AxialAttention2D(in_planes=self._filter_list[0], query_shape=query_shape, filters=self._filter_list[1],
                                               key_expansion=key_expansion, value_expansion=value_expansion, num_heads=num_heads)
            # Rounded the same way AxialAttention2D derives its value depth, so a
            # fractional value_expansion still gives matching integer channel counts.
            output_channel = int(round(filter_list[1] * value_expansion))
        else:
            self._conv2_bn_act = ConvBN(self._filter_list[0], self._filter_list[1], kernel_size=3, padding=1, bias=False, norm='syncbn', act='gelu')
            output_channel = filter_list[1]

        # norm_init=0.0 makes the last norm's weight start at zero.
        self._conv3_bn = ConvBN(output_channel, self._filter_list[2], kernel_size=1, bias=False, norm='syncbn', act=None, norm_init=0.0)

        # Project the shortcut only when the channel count changes.
        self._shortcut = None
        if inplanes != self._filter_list[-1]:
            self._shortcut = ConvBN(inplanes, self._filter_list[-1], kernel_size=1, bias=False, norm='syncbn', act=None)

        self.drop_path = DropPath(drop_path_prob) if drop_path_prob > 0. else nn.Identity()

    def forward(self, x):
        """Apply GELU to the input, then return the residual branch plus the shortcut."""
        x = F.gelu(x)

        shortcut = x
        if self._shortcut is not None:
            shortcut = self._shortcut(shortcut)

        x = self._conv1_bn_act(x)
        if self._block_type == 'axial':
            x = self._attention(x)
            x = F.gelu(x)
        elif self._block_type == 'bottleneck':
            x = self._conv2_bn_act(x)
        x = self._conv3_bn(x)

        return self.drop_path(x) + shortcut
| # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L42 | |
class BlockGroup(nn.Module):
    """A sequence of ``num_blocks`` ``SingleBlock`` modules of one type.

    Args:
        inplanes: Channel count of the input to the first block.
        base_filter: Base width; every block outputs ``base_filter * 4`` channels.
        num_blocks: Number of blocks in the group.
        block_type: ``'axial'`` or ``'bottleneck'`` (case-insensitive).
        **kwargs: Forwarded to every ``SingleBlock``.

    Raises:
        ValueError: If ``block_type`` is neither ``'axial'`` nor ``'bottleneck'``.
    """

    def __init__(self, inplanes, base_filter, num_blocks, block_type, **kwargs):
        super().__init__()
        self._num_blocks = num_blocks
        block_type = block_type.lower()
        if block_type == 'axial':
            # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L247
            filter_list = [base_filter * 2, base_filter, base_filter * 4]
        elif block_type == 'bottleneck':
            # https://github.com/google-research/deeplab2/blob/main/model/layers/axial_block_groups.py#L250
            filter_list = [base_filter, base_filter, base_filter * 4]
        else:
            # Previously `filter_list` stayed unbound and the loop below raised an
            # UnboundLocalError.
            raise ValueError(f"Unsupported block_type: {block_type!r} (expected 'axial' or 'bottleneck')")

        self._blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self._blocks.append(SingleBlock(inplanes=inplanes, filter_list=filter_list, block_type=block_type, **kwargs))
            # Every block after the first consumes the previous block's output width.
            inplanes = filter_list[-1]

    def forward(self, x):
        """Run ``x`` through the blocks in order."""
        for block in self._blocks:
            x = block(x)
        return x
| # https://github.com/google-research/deeplab2/blob/7a01a7165e97b3325ad7ea9b6bcc02d67fecd07a/model/layers/resized_fuse.py#L31 | |
class ResizedFuse(nn.Module):
    """Fuse a low-resolution map into a high-resolution one by resize-and-add.

    Either input gets a GELU followed by a 1x1 ``ConvBN`` projection, but only
    when its channel count differs from ``out_channels``.
    """

    def __init__(self, low_in_channels, high_in_channels, out_channels):
        super().__init__()
        self.low_in_channels = low_in_channels
        self.high_in_channels = high_in_channels
        self.out_channels = out_channels

        # Projections exist only for inputs whose width differs from the output width.
        if out_channels != low_in_channels:
            self._conv_bn_low = ConvBN(low_in_channels, out_channels, kernel_size=1, bias=False, norm='syncbn', act=None)
        if out_channels != high_in_channels:
            self._conv_bn_high = ConvBN(high_in_channels, out_channels, kernel_size=1, bias=False, norm='syncbn', act=None)

    def forward(self, lowres_x, highres_x):
        """Return ``resize(project(lowres_x)) + project(highres_x)`` at the high resolution."""
        # Decided from the low-res width before any projection: odd widths use align_corners=True.
        odd_width = lowres_x.shape[-1] % 2 == 1
        target_size = highres_x.shape[2:]

        if self.low_in_channels != self.out_channels:
            lowres_x = self._conv_bn_low(F.gelu(lowres_x))
        upsampled = F.interpolate(lowres_x, size=target_size, mode='bilinear', align_corners=odd_width)

        if self.high_in_channels != self.out_channels:
            highres_x = self._conv_bn_high(F.gelu(highres_x))

        return upsampled + highres_x
# NOTE(review): `configurable` and `SEM_SEG_HEADS_REGISTRY` are imported by this file
# but not applied anywhere in the visible code; confirm whether `__init__` should be
# decorated with `@configurable` and the class registered.
class kMaXPixelDecoder(nn.Module):
    """Pixel decoder that refines backbone features from coarse to fine.

    Each stage is a ``BlockGroup``; between stages the running feature map is
    fused with the next (higher-resolution) layer-normed backbone feature
    through a ``ResizedFuse``.
    """

    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        dec_layers: List[int],
        dec_channels: List[int],
        layer_types: List[str],
        drop_path_prob: float,
        spatial_shape: List[int],
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: Backbone feature name -> ``ShapeSpec``. Features are
                processed in order of decreasing stride.
            dec_layers: Number of blocks per stage.
            dec_channels: Base filter count per stage.
            layer_types: One entry per stage. Only its length is checked; the
                block types are hard-coded below.
            drop_path_prob: Accepted but not used; every stage is built with
                ``drop_path_prob=0.0``.
            spatial_shape: ``[height, width]`` of the input image, used to
                derive the axial-attention query shapes.
        """
        super().__init__()
        self.num_stages = len(input_shape)
        assert self.num_stages == len(dec_layers) and self.num_stages == len(dec_channels) and self.num_stages == len(layer_types)

        # For now, we hard code all hyper-parameters.
        block_types = ['axial', 'axial', 'bottleneck', 'bottleneck']

        input_shape = sorted(input_shape.items(), key=lambda x: -x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res5" to "res2"
        in_channels = [v.channels for k, v in input_shape]

        # Per-stage (height, width) at output strides 32, 16, 8 and 4; odd input
        # sizes add one to the floor division.
        add_one = (spatial_shape[0] % 2, spatial_shape[1] % 2)
        query_shape = [
            (spatial_shape[0] // stride + add_one[0], spatial_shape[1] // stride + add_one[1])
            for stride in (32, 16, 8, 4)]

        self._in_norms = nn.ModuleList()
        self._stages = nn.ModuleList()
        self._resized_fuses = nn.ModuleList()

        for i in range(self.num_stages):
            self._in_norms.append(LayerNorm(in_channels[i], data_format="channels_first"))
            # The first stage consumes the backbone feature directly; later stages
            # consume the ResizedFuse output, which has dec_channels[i] channels.
            inplanes = in_channels[i] if i == 0 else dec_channels[i]
            self._stages.append(BlockGroup(inplanes=inplanes,
                                           base_filter=dec_channels[i], num_blocks=dec_layers[i], block_type=block_types[i],
                                           query_shape=query_shape[i], key_expansion=1, value_expansion=2, num_heads=8, drop_path_prob=0.0))

            if i > 0:
                # The previous stage outputs base_filter * 4 channels.
                self._resized_fuses.append(ResizedFuse(
                    low_in_channels=dec_channels[i-1] * 4,
                    high_in_channels=in_channels[i],
                    out_channels=dec_channels[i]))

    # Declared with `cls` as first parameter, so it has to be a classmethod to be
    # callable as `kMaXPixelDecoder.from_config(cfg, input_shape)`.
    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Translate a config node into the keyword arguments of ``__init__``."""
        ret = {}
        ret["input_shape"] = {
            k: v for k, v in input_shape.items() if k in cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.IN_FEATURES
        }
        ret["dec_layers"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_LAYERS
        ret["dec_channels"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DEC_CHANNELS
        ret["layer_types"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.LAYER_TYPES
        ret["drop_path_prob"] = cfg.MODEL.KMAX_DEEPLAB.PIXEL_DEC.DROP_PATH_PROB
        ret["spatial_shape"] = cfg.INPUT.IMAGE_SIZE  # We expect the height == width
        return ret

    def forward_features(self, features):
        """Decode backbone ``features`` (a mapping from feature name to tensor).

        Returns:
            A tuple ``(panoptic_features, semantic_features, multi_scale_features)``:
            the last stage's output, the raw backbone features at indices 0, 2
            and 3 of ``self.in_features``, and the first three stage outputs.
        """
        out = []
        x = self._in_norms[0](features[self.in_features[0]])

        for idx in range(self.num_stages - 1):
            x = self._stages[idx](x)
            out.append(x)
            highres = self._in_norms[idx + 1](features[self.in_features[idx + 1]])
            x = self._resized_fuses[idx](lowres_x=x, highres_x=highres)

        x = self._stages[-1](x)
        out.append(x)

        multi_scale_features = out[:3]  # OS32, 16, 8, they are used for kmax_transformer_decoder.
        panoptic_features = out[-1]  # OS4, it is used for final mask prediction.
        # OS 32, 8, 4
        semantic_features = [features[self.in_features[0]], features[self.in_features[2]], features[self.in_features[3]]]
        return panoptic_features, semantic_features, multi_scale_features