# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
import logging
import fvcore.nn.weight_init as weight_init
from typing import Optional
import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d
from detectron2.utils.registry import Registry

from .position_encoding import PositionEmbeddingSine


TRANSFORMER_DECODER_REGISTRY = Registry("TRANSFORMER_MODULE")
TRANSFORMER_DECODER_REGISTRY.__doc__ = """
Registry for transformer module in MaskFormer.
"""

def build_transformer_decoder(cfg, in_channels, mask_classification=True):
    """
    Build a transformer decoder from `cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME`.
    """
    name = cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME
    return TRANSFORMER_DECODER_REGISTRY.get(name)(cfg, in_channels, mask_classification)

def get_classification_logits(x, text_classifier, logit_scale, num_templates=None):
    # x in shape of [B, *, C]
    # text_classifier in shape of [num_classes, C]
    # logit_scale is a learnable scalar https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/model.py#L201
    # return: [B, *, num_classes]
    x = F.normalize(x, dim=-1)
    logit_scale = torch.clamp(logit_scale.exp(), max=100)
    pred_logits = logit_scale * x @ text_classifier.T  # [B, *, N + 1]

    # max ensemble over templates, as in OpenSeg/ODISE
    final_pred_logits = []
    cur_idx = 0
    for num_t in num_templates:
        final_pred_logits.append(pred_logits[:, :, cur_idx: cur_idx + num_t].max(-1).values)
        cur_idx += num_t
    final_pred_logits.append(pred_logits[:, :, -1])  # the last classifier is for void
    final_pred_logits = torch.stack(final_pred_logits, dim=-1)
    return final_pred_logits
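
# Usage sketch (illustrative, not part of the original file). Suppose two classes expand
# to 3 and 2 prompt templates respectively and one extra embedding is appended for void:
# `text_classifier` then has 3 + 2 + 1 = 6 rows, `num_templates = [3, 2]`, and each class
# score is the max over that class's template columns. The shapes below are assumptions:
#
#   x = torch.randn(2, 100, 512)                     # [B, Q, C] query embeddings
#   text_classifier = torch.randn(6, 512)            # [sum(num_templates) + 1, C]
#   logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
#   logits = get_classification_logits(x, text_classifier, logit_scale, [3, 2])
#   assert logits.shape == (2, 100, 3)               # [B, Q, num_classes + 1]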

# Ref: https://github.com/NVlabs/ODISE/blob/e97b06c424c575fec9fc5368dd4b3e050d91abc4/odise/modeling/meta_arch/odise.py#L923
class MaskPooling(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, mask):
        """
        Args:
            x: [B, C, H, W]
            mask: [B, Q, H, W]
        """
        if x.shape[-2:] != mask.shape[-2:]:
            # reshape mask to x
            mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
        with torch.no_grad():
            mask = mask.detach()
            mask = (mask > 0).to(mask.dtype)
            denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8

        mask_pooled_x = torch.einsum(
            "bchw,bqhw->bqc",
            x,
            mask / denorm,
        )
        return mask_pooled_x
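
# Illustrative check (assumed shapes, not part of the original file): for a query whose
# binarized mask covers exactly k pixels, the pooled vector is (up to the 1e-8 term) the
# mean of those k pixel features, since the mask is divided by its per-query pixel count:
#
#   pool = MaskPooling()
#   x = torch.randn(1, 256, 32, 32)          # [B, C, H, W]
#   mask = torch.full((1, 1, 32, 32), -1.0)  # [B, Q, H, W], one query
#   mask[0, 0, :4, :4] = 1.0                 # 16 positive pixels
#   out = pool(x, mask)                      # [1, 1, 256]
#   ref = x[0, :, :4, :4].reshape(256, -1).mean(-1)
#   assert torch.allclose(out[0, 0], ref, atol=1e-5)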

class SelfAttentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask,
                                    tgt_key_padding_mask, query_pos)
        return self.forward_post(tgt, tgt_mask,
                                 tgt_key_padding_mask, query_pos)

class CrossAttentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)

class kMaXCrossAttentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.nhead = nhead
        self.v_proj = nn.Linear(d_model, d_model)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        # memory_mask is in shape [B*num_heads, Q, HW]
        tgt2 = self.v_proj(memory)  # [HW, B, C]
        # per-query cluster assignment, identical across heads (kept from the kMaX-style
        # formulation; not used further in this path)
        clustering_result = memory_mask.view(-1, self.nhead, *memory_mask.shape[-2:])[:, 0]  # [B, Q, HW]
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)

class FFNLayer(nn.Module):

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
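
# Illustrative note (assumed values, not part of the original file): MLP(input_dim=256,
# hidden_dim=256, output_dim=128, num_layers=3) builds Linear(256, 256) -> Linear(256, 256)
# -> Linear(256, 128), applying ReLU after every layer except the last.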

@TRANSFORMER_DECODER_REGISTRY.register()
class MultiScaleMaskedTransformerDecoder(nn.Module):

    @configurable
    def __init__(
        self,
        in_channels,
        mask_classification=True,
        *,
        num_classes: int,
        hidden_dim: int,
        num_queries: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        pre_norm: bool,
        mask_dim: int,
        enforce_input_project: bool,
        clip_embedding_dim: int,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            mask_classification: whether to add mask classifier or not
            num_classes: number of classes
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dim_feedforward: feature dimension in feedforward network
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            mask_dim: mask feature dimension
            enforce_input_project: add input project 1x1 conv even if input
                channels and hidden dim is identical
            clip_embedding_dim: dimension of the CLIP embedding space used by the class head
        """
        super().__init__()

        assert mask_classification, "Only support mask classification model"
        self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            if in_channels != hidden_dim or enforce_input_project:
                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
                weight_init.c2_xavier_fill(self.input_proj[-1])
            else:
                self.input_proj.append(nn.Sequential())

        # output FFNs
        # if self.mask_classification:
        #     self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

        # FC-CLIP
        self.mask_pooling = MaskPooling()
        self._mask_pooling_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.class_embed = MLP(hidden_dim, hidden_dim, clip_embedding_dim, 3)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    @classmethod
    def from_config(cls, cfg, in_channels, mask_classification):
        ret = {}
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification

        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        # Transformer parameters:
        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD

        # NOTE: because we add learnable query features which require supervision,
        # we subtract 1 from the decoder layers to be consistent with our loss
        # implementation: that is, the number of auxiliary losses is always
        # equal to the number of decoder layers. With learnable query features, the number
        # of auxiliary losses equals the number of decoder layers plus 1.
        assert cfg.MODEL.MASK_FORMER.DEC_LAYERS >= 1
        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1
        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ

        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
        ret["clip_embedding_dim"] = cfg.MODEL.FC_CLIP.EMBED_DIM

        return ret
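
    # Illustrative note (assumed config value, not part of the original file): with
    # cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10, from_config sets dec_layers = 9, so forward()
    # below produces 10 sets of predictions (one from the learnable queries before the first
    # layer, plus one after each of the 9 layers), i.e. 9 auxiliary losses and 1 final loss.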

    def forward(self, x, mask_features, mask=None, text_classifier=None, num_templates=None):
        # x is a list of multi-scale features
        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        # disable mask, it does not affect performance
        del mask

        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        predictions_class = []
        predictions_mask = []

        # prediction heads on learnable query features
        outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
            output, mask_features, attn_mask_target_size=size_list[0],
            text_classifier=text_classifier, num_templates=num_templates)
        predictions_class.append(outputs_class)
        predictions_mask.append(outputs_mask)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
            # attention: cross-attention first
            output = self.transformer_cross_attention_layers[i](
                output, src[level_index],
                memory_mask=attn_mask,
                memory_key_padding_mask=None,  # here we do not apply masking on padded region
                pos=pos[level_index], query_pos=query_embed
            )

            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed
            )

            # FFN
            output = self.transformer_ffn_layers[i](
                output
            )

            outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(
                output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels],
                text_classifier=text_classifier, num_templates=num_templates)
            predictions_class.append(outputs_class)
            predictions_mask.append(outputs_mask)

        assert len(predictions_class) == self.num_layers + 1

        out = {
            'pred_logits': predictions_class[-1],
            'pred_masks': predictions_mask[-1],
            'aux_outputs': self._set_aux_loss(
                predictions_class if self.mask_classification else None, predictions_mask
            )
        }
        return out

    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, text_classifier, num_templates):
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)
        mask_embed = self.mask_embed(decoder_output)
        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)

        # fcclip head
        maskpool_embeddings = self.mask_pooling(x=mask_features, mask=outputs_mask)  # [B, Q, C]
        maskpool_embeddings = self._mask_pooling_proj(maskpool_embeddings)
        class_embed = self.class_embed(maskpool_embeddings + decoder_output)
        outputs_class = get_classification_logits(class_embed, text_classifier, self.logit_scale, num_templates)

        # NOTE: prediction is of higher resolution
        # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
        attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
        # must use bool type
        # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
        attn_mask = attn_mask.detach()
        return outputs_class, outputs_mask, attn_mask

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        if self.mask_classification:
            return [
                {"pred_logits": a, "pred_masks": b}
                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
            ]
        else:
            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
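

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file). All sizes below are
    # assumptions chosen only to exercise the forward pass on CPU with random tensors.
    # Because this module uses a relative import, run it with `python -m <package path>`
    # from the repository root rather than executing the file directly.
    decoder = MultiScaleMaskedTransformerDecoder(
        in_channels=256,
        mask_classification=True,
        num_classes=2,
        hidden_dim=256,
        num_queries=100,
        nheads=8,
        dim_feedforward=2048,
        dec_layers=3,
        pre_norm=False,
        mask_dim=256,
        enforce_input_project=False,
        clip_embedding_dim=512,
    )
    # three multi-scale feature maps plus higher-resolution mask features
    features = [torch.randn(1, 256, s, s) for s in (8, 16, 32)]
    mask_features = torch.randn(1, 256, 64, 64)
    # two classes with 3 and 2 prompt templates respectively, plus one void embedding
    text_classifier = F.normalize(torch.randn(3 + 2 + 1, 512), dim=-1)
    out = decoder(features, mask_features, text_classifier=text_classifier, num_templates=[3, 2])
    print(out["pred_logits"].shape)  # [1, 100, 3] -> [B, Q, num_classes + 1]
    print(out["pred_masks"].shape)   # [1, 100, 64, 64]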