| |
|
|
| import math |
| from typing import Dict, List, Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint as checkpoint |
|
|
| from .model_misc import MLP |
|
|
|
|
class LinearPresenceHead(nn.Sequential):
    """Presence head that maps each input vector to a single logit with one linear layer.

    The container holds ``(Identity, Identity, Linear(d_model, 1))``, so the
    linear layer lives at index ``2`` and its parameters are named
    ``2.weight`` / ``2.bias``. (Presumably the placeholders exist to keep that
    parameter naming stable, e.g. for checkpoint compatibility -- confirm.)
    """

    def __init__(self, d_model):
        placeholders = (nn.Identity(), nn.Identity())
        projection = nn.Linear(d_model, 1)
        super().__init__(*placeholders, projection)

    def forward(self, hs, prompt, prompt_mask):
        """Apply the sequential stack to ``hs`` (last dim ``d_model`` -> 1).

        ``prompt`` and ``prompt_mask`` are accepted so this head can be called
        with the same keyword arguments as other presence heads; they are
        ignored here.
        """
        logits = super().forward(hs)
        return logits
|
|
|
|
class MaskPredictor(nn.Module):
    """Predict mask logits as dot products between projected queries and pixel embeddings."""

    def __init__(self, hidden_dim, mask_dim):
        super().__init__()
        # NOTE(review): MLP argument order presumably is
        # (input_dim, hidden_dim, output_dim, num_layers) -- confirm in model_misc.
        # The output width must match the channel dim ``c`` of ``pixel_embed``.
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, obj_queries, pixel_embed):
        """Correlate projected object queries with a pixel embedding map.

        Shapes, as fixed by the einsum subscripts below:
            obj_queries: ``(b, q, c)`` when 3-D, otherwise ``(l, b, q, c)``.
            pixel_embed: ``(c, h, w)`` shared across the batch when 3-D,
                otherwise ``(b, c, h, w)``.

        Returns:
            Mask logits ``(b, q, h, w)`` or ``(l, b, q, h, w)``.
        """
        query_subscripts = "bqc" if obj_queries.ndim == 3 else "lbqc"
        pixel_subscripts = "chw" if pixel_embed.ndim == 3 else "bchw"
        # The channel axis ``c`` is contracted away and replaced by ``hw``.
        output_subscripts = query_subscripts[:-1] + "hw"
        equation = f"{query_subscripts},{pixel_subscripts}->{output_subscripts}"

        projected_queries = self.mask_embed(obj_queries)
        return torch.einsum(equation, projected_queries, pixel_embed)
|
|
|
|
class SegmentationHead(nn.Module):
    """Decode backbone features into a pixel embedding and predict masks from it.

    Args:
        hidden_dim: channel width used by the default pixel decoder and the
            mask predictor.
        upsampling_stages: number of stages of the default ``PixelDecoder``
            (unused when ``pixel_decoder`` is supplied).
        use_encoder_inputs: if True, the last backbone level is replaced by the
            first ``H*W`` tokens of ``encoder_hidden_states`` before decoding.
        aux_masks: if True, the full ``obj_queries`` tensor is passed to the
            mask predictor; otherwise only ``obj_queries[-1]``.
        no_dec: if True, masks come from a 3x3 conv on the pixel embedding and
            ``obj_queries`` is ignored.
        pixel_decoder: optional pre-built pixel decoder (callable on a list of
            feature maps).
        act_ckpt: if True, the pixel decoder runs under activation
            checkpointing (only on the ``use_encoder_inputs`` path).
        shared_conv, compile_mode_pixel_decoder: forwarded to the default
            ``PixelDecoder``.
    """

    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        use_encoder_inputs=False,
        aux_masks=False,
        no_dec=False,
        pixel_decoder=None,
        act_ckpt=False,
        shared_conv=False,
        compile_mode_pixel_decoder=None,
    ):
        super().__init__()
        self.use_encoder_inputs = use_encoder_inputs
        self.aux_masks = aux_masks
        if pixel_decoder is not None:
            self.pixel_decoder = pixel_decoder
        else:
            self.pixel_decoder = PixelDecoder(
                hidden_dim,
                upsampling_stages,
                shared_conv=shared_conv,
                compile_mode=compile_mode_pixel_decoder,
            )
        self.no_dec = no_dec
        if no_dec:
            self.mask_predictor = nn.Conv2d(
                hidden_dim, 1, kernel_size=3, stride=1, padding=1
            )
        else:
            self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim)

        self.act_ckpt = act_ckpt

        # Keys of the output dict that hold per-instance predictions.
        self.instance_keys = ["pred_masks"]

    @property
    def device(self):
        """Device of this module's parameters, cached after the first lookup."""
        if getattr(self, "_device", None) is None:
            self._device = next(self.parameters()).device
        return self._device

    def _apply(self, *args, **kwargs):
        # Invalidate the cached device on every parameter conversion.
        # ``to``, ``cuda``, ``cpu``, ``float``, ... all go through ``_apply``,
        # and a parent module's ``to`` reaches this module only via
        # ``_apply`` (it never calls the child's ``to``). Resetting only in a
        # ``to`` override therefore left a stale device after ``.cuda()`` or
        # after moving an enclosing model.
        self._device = None
        return super()._apply(*args, **kwargs)

    def _embed_pixels(
        self,
        backbone_feats: List[torch.Tensor],
        image_ids,
        encoder_hidden_states,
    ) -> torch.Tensor:
        """Run the pixel decoder and return the pixel embedding for each query set.

        Args:
            backbone_feats: multi-level feature maps; the last entry's trailing
                dims ``(C, H, W)`` define the shape of the visual tokens.
            image_ids: index tensor selecting, per output row, which backbone
                image to use (only used when the backbone batch is > 1).
            encoder_hidden_states: required when ``use_encoder_inputs``. It is
                permuted ``(1, 2, 0)``; this assumes a ``(S, B, C)``
                sequence-first layout, as that permute implies -- confirm with
                callers.

        Returns:
            The decoder output. On the non-encoder path with a single image,
            the batch axis is squeezed, giving ``(C, H, W)``.
        """
        model_device = self.device
        if self.use_encoder_inputs:
            if backbone_feats[0].shape[0] > 1:
                # Gather each query set's image. The indices must live on the
                # features' device; the results are then moved to the model.
                feature_device = backbone_feats[0].device
                image_ids_ = image_ids.to(feature_device)
                backbone_visual_feats = [
                    feat[image_ids_, ...].to(model_device) for feat in backbone_feats
                ]
            else:
                # Single image: keep batch size 1. Also move to the model
                # device, consistent with the multi-image branch above.
                backbone_visual_feats = [
                    bb_feat.clone().to(model_device) for bb_feat in backbone_feats
                ]

            # (S, B, C) -> (B, C, S); the first H*W tokens are reshaped into the
            # last backbone level's (C, H, W) grid, and any remaining tokens are
            # dropped.
            last_level_shape = backbone_feats[-1].shape
            spatial_dim = math.prod(last_level_shape[-2:])
            encoder_visual_embed = (
                encoder_hidden_states.permute(1, 2, 0)[..., :spatial_dim]
                .reshape(-1, *last_level_shape[1:])
            )
            backbone_visual_feats[-1] = encoder_visual_embed

            if self.act_ckpt:
                pixel_embed = checkpoint.checkpoint(
                    self.pixel_decoder, backbone_visual_feats, use_reentrant=False
                )
            else:
                pixel_embed = self.pixel_decoder(backbone_visual_feats)
        else:
            backbone_feats = [x.to(model_device) for x in backbone_feats]
            pixel_embed = self.pixel_decoder(backbone_feats)
            if pixel_embed.shape[0] == 1:
                # One image shared by all query sets: drop the batch axis so the
                # predictor receives an unbatched (C, H, W) embedding.
                pixel_embed = pixel_embed.squeeze(0)
            else:
                # Align the index tensor with the embedding's device (indexing a
                # CPU tensor with CUDA indices would fail).
                pixel_embed = pixel_embed[image_ids.to(pixel_embed.device), ...]
        return pixel_embed

    def forward(
        self,
        backbone_feats: List[torch.Tensor],
        obj_queries: torch.Tensor,
        image_ids,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Predict masks; returns ``{"pred_masks": ...}``.

        ``obj_queries`` is ignored when ``no_dec`` is set. Extra keyword
        arguments are accepted and ignored.
        """
        if self.use_encoder_inputs:
            assert encoder_hidden_states is not None

        pixel_embed = self._embed_pixels(
            backbone_feats=backbone_feats,
            image_ids=image_ids,
            encoder_hidden_states=encoder_hidden_states,
        )

        if self.no_dec:
            mask_pred = self.mask_predictor(pixel_embed)
        elif self.aux_masks:
            mask_pred = self.mask_predictor(obj_queries, pixel_embed)
        else:
            mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed)

        return {"pred_masks": mask_pred}
|
|
|
|
class PixelDecoder(nn.Module):
    """Top-down decoder that fuses a list of feature maps into one map.

    Starting from the last entry of ``backbone_feats`` (presumably the
    coarsest level -- confirm with callers), each earlier level in reverse
    order is added to the previous result, which is first resized to that
    level's spatial size. The sum then passes through conv3x3 -> GroupNorm ->
    ReLU.

    Args:
        hidden_dim: channel width of every conv/norm (must be divisible by 8
            for ``GroupNorm(8, hidden_dim)``).
        num_upsampling_stages: number of conv/norm pairs (one per fused level)
            when ``shared_conv`` is False. Zero is allowed: the module then
            returns a single-level input unchanged.
        interpolation_mode: ``mode`` passed to ``F.interpolate``.
        shared_conv: if True, a single conv/norm pair is reused for every level.
        compile_mode: if not None, ``forward`` is wrapped with ``torch.compile``.
    """

    def __init__(
        self,
        hidden_dim,
        num_upsampling_stages,
        interpolation_mode="nearest",
        shared_conv=False,
        compile_mode=None,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_upsampling_stages = num_upsampling_stages
        self.interpolation_mode = interpolation_mode
        conv_layers = []
        norms = []
        num_convs = 1 if shared_conv else num_upsampling_stages
        for _ in range(num_convs):
            conv_layers.append(nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1))
            norms.append(nn.GroupNorm(8, self.hidden_dim))

        self.conv_layers = nn.ModuleList(conv_layers)
        self.norms = nn.ModuleList(norms)
        self.shared_conv = shared_conv
        # Every conv maps hidden_dim -> hidden_dim, so this equals
        # ``conv_layers[-1].out_channels``. Reading it from the last conv
        # raised IndexError when num_upsampling_stages == 0 (no convs).
        self.out_dim = self.hidden_dim
        if compile_mode is not None:
            self.forward = torch.compile(
                self.forward, mode=compile_mode, dynamic=True, fullgraph=True
            )
            # NOTE: this mutates global dynamo config for the whole process.
            # Presumably needed for compiled + checkpointed modules -- confirm.
            torch._dynamo.config.optimize_ddp = False

    def forward(self, backbone_feats: List[torch.Tensor]):
        """Fuse ``backbone_feats`` top-down and return the finest fused map.

        Raises IndexError (from ``conv_layers``) if more than
        ``num_upsampling_stages`` levels must be fused without ``shared_conv``.
        """
        prev_fpn = backbone_feats[-1]
        fpn_feats = backbone_feats[:-1]
        for layer_idx, curr_fpn in enumerate(fpn_feats[::-1]):
            prev_fpn = curr_fpn + F.interpolate(
                prev_fpn, size=curr_fpn.shape[-2:], mode=self.interpolation_mode
            )
            # With a shared conv, every level reuses the single conv/norm pair.
            conv_idx = 0 if self.shared_conv else layer_idx
            prev_fpn = self.conv_layers[conv_idx](prev_fpn)
            prev_fpn = F.relu(self.norms[conv_idx](prev_fpn))

        return prev_fpn
|
|
|
|
class UniversalSegmentationHead(SegmentationHead):
    """Semantic + instance segmentation head.

    Always decodes from encoder tokens (``use_encoder_inputs=True``). On top
    of the instance masks it produces a one-channel semantic map and,
    optionally, a presence logit pooled from the encoder tokens. The encoder
    tokens may first be refined by cross-attending to a prompt.
    """

    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        pixel_decoder,
        aux_masks=False,
        no_dec=False,
        act_ckpt=False,
        presence_head: bool = False,
        dot_product_scorer=None,
        cross_attend_prompt=None,
    ):
        super().__init__(
            hidden_dim=hidden_dim,
            upsampling_stages=upsampling_stages,
            use_encoder_inputs=True,
            aux_masks=aux_masks,
            no_dec=no_dec,
            pixel_decoder=pixel_decoder,
            act_ckpt=act_ckpt,
        )
        self.d_model = hidden_dim

        # A custom scorer only makes sense as the presence head.
        assert presence_head or dot_product_scorer is None, "Specifying a dot product scorer without a presence head is likely a mistake"

        if not presence_head:
            self.presence_head = None
        elif dot_product_scorer is None:
            self.presence_head = LinearPresenceHead(self.d_model)
        else:
            self.presence_head = dot_product_scorer

        # Optional attention module called as
        # ``(query=..., key=..., value=..., key_padding_mask=...)``;
        # element [0] of its result is used as the update.
        self.cross_attend_prompt = cross_attend_prompt
        if cross_attend_prompt is not None:
            self.cross_attn_norm = nn.LayerNorm(self.d_model)

        decoder_channels = self.pixel_decoder.out_dim
        self.semantic_seg_head = nn.Conv2d(decoder_channels, 1, kernel_size=1)
        self.instance_seg_head = nn.Conv2d(
            decoder_channels, self.d_model, kernel_size=1
        )

    def forward(
        self,
        backbone_feats: List[torch.Tensor],
        obj_queries: torch.Tensor,
        image_ids,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        prompt: Optional[torch.Tensor] = None,
        prompt_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Return ``pred_masks``, ``semantic_seg`` and ``presence_logit``.

        ``encoder_hidden_states`` is required. Its dim 1 is treated as the
        batch and dim 0 is averaged for the presence pooling (presumably a
        ``(S, B, d_model)`` layout -- confirm). ``presence_logit`` is None
        when no presence head is configured. Extra keyword arguments are
        ignored.
        """
        assert encoder_hidden_states is not None
        batch_size = encoder_hidden_states.shape[1]

        if self.cross_attend_prompt is not None:
            # Pre-norm residual update: normalized tokens attend to the prompt.
            normed_tokens = self.cross_attn_norm(encoder_hidden_states)
            attn_result = self.cross_attend_prompt(
                query=normed_tokens,
                key=prompt,
                value=prompt,
                key_padding_mask=prompt_mask,
            )
            encoder_hidden_states = attn_result[0] + encoder_hidden_states

        if self.presence_head is None:
            presence_logit = None
        else:
            pooled_tokens = encoder_hidden_states.mean(0).view(
                1, batch_size, 1, self.d_model
            )
            head_output = self.presence_head(
                pooled_tokens, prompt=prompt, prompt_mask=prompt_mask
            )
            presence_logit = head_output.squeeze(0).squeeze(1)

        pixel_embed = self._embed_pixels(
            backbone_feats=backbone_feats,
            image_ids=image_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        instance_embeds = self.instance_seg_head(pixel_embed)

        if self.no_dec:
            mask_pred = self.mask_predictor(instance_embeds)
        else:
            queries = obj_queries if self.aux_masks else obj_queries[-1]
            mask_pred = self.mask_predictor(queries, instance_embeds)

        semantic_seg = self.semantic_seg_head(pixel_embed)
        return {
            "pred_masks": mask_pred,
            "semantic_seg": semantic_seg,
            "presence_logit": presence_logit,
        }
|
|