Spaces:
Runtime error
Runtime error
| import copy | |
| import warnings | |
| from typing import List, Optional, Tuple, Union, Dict | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mmcv.cnn import ConvModule | |
| from mmengine import ConfigDict | |
| from mmengine.structures import InstanceData | |
| from torch import Tensor | |
| from mmdet.models import BaseDetector, TwoStageDetector, StandardRoIHead, SinePositionalEncoding, FCNMaskHead, \ | |
| BaseRoIHead | |
| from mmdet.models.task_modules import SamplingResult | |
| from mmdet.models.utils import multi_apply, unpack_gt_instances, empty_instances | |
| from mmdet.structures import SampleList, DetDataSample | |
| from mmdet.structures.bbox import bbox2roi | |
| from mmdet.structures.mask import mask_target | |
| from mmdet.utils import InstanceList, reduce_mean, OptMultiConfig | |
| from mmpl.registry import MODELS, TASK_UTILS | |
| from mmengine.model import BaseModel, BaseModule | |
| from einops import rearrange, repeat | |
| from mmpl.utils import ConfigType, OptConfigType | |
| from mmdet.models.dense_heads import Mask2FormerHead | |
| from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead | |
class SAMInstanceHead(Mask2FormerHead):
    """Query-based instance segmentation head that prompts a SAM model.

    A ``prompt_neck`` converts multi-scale features into per-query
    embeddings. Each query embedding is (a) classified into
    ``num_classes + 1`` categories (background included) and (b) projected
    into ``per_query_point`` sparse prompt embeddings that are fed to SAM's
    mask decoder together with the image embedding, yielding one mask per
    query. Target assignment, ``preprocess_gt`` and ``loss_by_feat`` are
    inherited from :class:`Mask2FormerHead`.

    Args:
        num_things_classes (int): Number of "thing" classes. Defaults to 1.
        num_stuff_classes (int): Number of "stuff" classes. Defaults to 0.
        prompt_neck (ConfigType): Config of the module producing the query
            embeddings; must expose ``num_queries``, ``per_query_point`` and
            ``out_channels``.
        with_iou (bool): Reserved switch for rescaling class scores with
            SAM's IoU predictions (currently unused). Defaults to False.
        with_multiscale (bool): If True, a classification prediction is made
            for every multi-scale query feature returned by the prompt neck.
            Defaults to False.
        with_sincos (bool): If True, the point-embedding MLP outputs twice
            the channels and sin(even)/odd halves are summed.
            Defaults to False.
        with_res_imgfeat (bool): If True, an upsampled residual image
            feature is forwarded to the SAM mask decoder. Defaults to False.
        loss_cls (ConfigType): Classification loss config.
        loss_mask (ConfigType): Mask (BCE) loss config.
        loss_dice (ConfigType): Dice loss config.
        train_cfg (OptConfigType): Training config holding the assigner,
            sampler and point-sampling hyper-parameters.
        test_cfg (OptConfigType): Testing config.
        init_cfg (OptMultiConfig): Initialization config.
        norm_cfg (dict): Norm config of the residual-feature conv.
        act_cfg (dict): Activation config of the residual-feature conv.
    """

    def __init__(
            self,
            num_things_classes: int = 1,
            num_stuff_classes: int = 0,
            prompt_neck: ConfigType = ...,
            with_iou: bool = False,
            with_multiscale: bool = False,
            with_sincos: bool = False,
            with_res_imgfeat: bool = False,
            loss_cls: ConfigType = dict(
                type='CrossEntropyLoss',
                use_sigmoid=False,
                loss_weight=2.0,
                reduction='mean',
                class_weight=[1.0] * 133 + [0.1]),
            loss_mask: ConfigType = dict(
                type='CrossEntropyLoss',
                use_sigmoid=True,
                reduction='mean',
                loss_weight=5.0),
            loss_dice: ConfigType = dict(
                type='DiceLoss',
                use_sigmoid=True,
                activate=True,
                reduction='mean',
                naive_dice=True,
                eps=1.0,
                loss_weight=5.0),
            train_cfg: OptConfigType = None,
            test_cfg: OptConfigType = None,
            init_cfg: OptMultiConfig = None,
            norm_cfg=dict(type='BN', requires_grad=True),
            act_cfg=dict(type='ReLU', inplace=True),
            **kwargs
    ):
        # Deliberately bypass Mask2FormerHead/AnchorFreeHead __init__ (they
        # would build a pixel decoder and a transformer decoder this head
        # does not use) and initialise from BaseModule directly.
        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = self.num_things_classes + self.num_stuff_classes
        self.with_iou = with_iou
        self.with_multiscale = with_multiscale
        self.with_sincos = with_sincos
        self.with_res_imgfeat = with_res_imgfeat

        self.prompt_neck = MODELS.build(prompt_neck)
        self.num_queries = self.prompt_neck.num_queries
        self.per_query_point = self.prompt_neck.per_query_point
        out_channels = self.prompt_neck.out_channels

        # Per-query classification branch (+1 output for background).
        self.cls_embed = nn.Sequential(
            nn.Linear(out_channels, out_channels // 2),
            nn.ReLU(inplace=True),
            nn.Linear(out_channels // 2, self.num_classes + 1)
        )

        # MLP mapping each query embedding to ``per_query_point`` SAM sparse
        # prompt embeddings; width doubled when sin/cos mixing is enabled.
        if self.with_sincos:
            self.point_emb = nn.Sequential(
                nn.Linear(out_channels, out_channels),
                nn.ReLU(inplace=True),
                nn.Linear(out_channels, out_channels),
                nn.ReLU(inplace=True),
                nn.Linear(out_channels, self.per_query_point * out_channels * 2)
            )
        else:
            self.point_emb = nn.Sequential(
                nn.Linear(out_channels, out_channels),
                nn.ReLU(inplace=True),
                nn.Linear(out_channels, out_channels),
                nn.ReLU(inplace=True),
                nn.Linear(out_channels, self.per_query_point * out_channels)
            )

        if self.with_res_imgfeat:
            # Upsample the residual image feature before handing it to the
            # mask decoder.
            self.res_imgfeat = nn.Sequential(
                nn.UpsamplingBilinear2d(scale_factor=2),
                ConvModule(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg
                )
            )

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                self.train_cfg['sampler'], default_args=dict(context=self))
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        # BUGFIX: was ``loss_cls.class_weight``, which raises AttributeError
        # when ``loss_cls`` is a plain dict (as the declared default is);
        # dict-style access works for both dict and ConfigDict.
        self.class_weight = loss_cls.get('class_weight')
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

    def forward(self, x: List[Tensor],
                batch_data_samples: SampleList,
                sam
                ) -> Tuple[List[Tensor]]:
        """Forward function.

        Args:
            x (list[Tensor]): Multi scale Features from the
                upstream network, each is a 4D-tensor. ``x[0]`` is also
                used as the SAM image embedding.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            sam: SAM model exposing ``prompt_encoder`` and ``mask_decoder``.

        Returns:
            tuple[list[Tensor]]: A tuple contains two elements.

            - cls_pred_list (list[Tensor]): Classification logits, each a
              3D-tensor with shape (batch_size, num_queries,
              cls_out_channels). ``cls_out_channels`` includes background.
            - mask_pred_list (list[Tensor]): Mask logits, each with shape
              (batch_size, num_queries, h, w).
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        batch_size = len(batch_img_metas)

        decoder_out, query_feat_list, res_img_feat = self.prompt_neck(x)

        if self.with_multiscale:
            cls_pred_list = [self.cls_embed(query_feat) for query_feat in query_feat_list]
        else:
            # shape (batch_size, num_queries, c)
            cls_pred_list = [self.cls_embed(decoder_out)]

        # shape (batch_size, num_queries, c)
        point_emb = self.point_emb(decoder_out)
        # shape (batch_size, num_queries, per_query_point, c)
        point_emb = point_emb.view(batch_size, self.num_queries, self.per_query_point, -1)

        img_seg_feat = x[0]
        point_emb = rearrange(point_emb, 'b n p c -> (b n) p c')
        if self.with_sincos:
            # Mix interleaved channel pairs: sin(even channels) + odd channels.
            point_emb = torch.sin(point_emb[..., ::2]) + point_emb[..., 1::2]

        # "No mask" dense prompt, broadcast to one per (image, query) pair.
        nomask_dense_embeddings = sam.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            point_emb.shape[0], -1, *img_seg_feat.shape[-2:]
        )
        # Each image embedding is repeated once per query.
        img_embeddings = torch.repeat_interleave(img_seg_feat, self.num_queries, dim=0)
        img_pe = sam.prompt_encoder.get_dense_pe()
        img_pe = repeat(img_pe, 'b c h w -> (b n) c h w', n=img_embeddings.shape[0])

        if self.with_res_imgfeat:
            res_img_feat = self.res_imgfeat(res_img_feat)
            res_img_feat = torch.repeat_interleave(res_img_feat, self.num_queries, dim=0)
        else:
            res_img_feat = None

        low_res_masks, iou_predictions = sam.mask_decoder.forward_batch(
            image_embeddings=img_embeddings,
            image_pe=img_pe,
            sparse_prompt_embeddings=point_emb,
            dense_prompt_embeddings=nomask_dense_embeddings,
            multimask_output=False,
            res_img_feat=res_img_feat,
        )
        mask_pred = rearrange(low_res_masks.squeeze(1), '(b n) h w -> b n h w', b=batch_size)

        # ``with_iou`` rescaling of class scores is intentionally disabled.

        if self.with_multiscale:
            mask_pred_list = [mask_pred] * len(cls_pred_list)
        else:
            mask_pred_list = [mask_pred]
        return cls_pred_list, mask_pred_list

    def predict(self, x: Tuple[Tensor],
                batch_data_samples: SampleList,
                sam
                ) -> Tuple[Tensor]:
        """Test without augmentaton.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            sam: SAM model exposing ``prompt_encoder`` and ``mask_decoder``.

        Returns:
            tuple[Tensor]: A tuple contains two tensors.

            - mask_cls_results (Tensor): Mask classification logits,
              shape (batch_size, num_queries, cls_out_channels).
              Note `cls_out_channels` should includes background.
            - mask_pred_results (Tensor): Mask logits, shape
              (batch_size, num_queries, h, w), upsampled to the batch
              input shape.
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        all_cls_scores, all_mask_preds = self(x, batch_data_samples, sam)
        # Only the last-level predictions are used at test time.
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        # upsample masks to the padded batch input resolution
        img_shape = batch_img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)

        return mask_cls_results, mask_pred_results

    def loss(
        self,
        x: Tuple[Tensor],
        batch_data_samples: SampleList,
        sam,
    ) -> Dict[str, Tensor]:
        """Perform forward propagation and loss calculation of the panoptic
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            sam: SAM model exposing ``prompt_encoder`` and ``mask_decoder``.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        batch_img_metas = []
        batch_gt_instances = []
        batch_gt_semantic_segs = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)
            if 'gt_sem_seg' in data_sample:
                batch_gt_semantic_segs.append(data_sample.gt_sem_seg)
            else:
                batch_gt_semantic_segs.append(None)

        # forward
        all_cls_scores, all_mask_preds = self(x, batch_data_samples, sam)

        # preprocess ground truth (inherited from Mask2FormerHead)
        batch_gt_instances = self.preprocess_gt(batch_gt_instances,
                                                batch_gt_semantic_segs)

        # loss (inherited from Mask2FormerHead)
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses
class SAMAnchorInstanceHead(TwoStageDetector):
    """Two-stage (RPN + RoI) instance head operating on precomputed features.

    NOTE(review): despite the docstrings below, ``batch_inputs`` is indexed
    with ``[0]`` in :meth:`loss`/:meth:`predict`, so it is expected to be a
    sequence of feature maps (presumably the SAM image embedding first) rather
    than a raw image tensor — confirm against the caller.

    Args:
        sam_head (bool): If True, the RoI head additionally receives the SAM
            model and the image embedding so its mask branch can prompt SAM.
        neck (OptConfigType): Neck config (built unconditionally).
        rpn_head (OptConfigType): RPN head config; forced to one class.
        roi_head (OptConfigType): RoI head config.
        train_cfg (OptConfigType): Training config with ``rpn``/``rcnn`` keys.
        test_cfg (OptConfigType): Testing config with ``rpn``/``rcnn`` keys.
    """

    def __init__(
            self,
            sam_head=True,
            neck: OptConfigType = None,
            rpn_head: OptConfigType = None,
            roi_head: OptConfigType = None,
            train_cfg: OptConfigType = None,
            test_cfg: OptConfigType = None,
            **kwargs
    ):
        # Deliberately skip TwoStageDetector.__init__ (no backbone/data
        # preprocessor is built here) and initialise from its parent instead.
        super(TwoStageDetector, self).__init__()
        self.neck = MODELS.build(neck)
        self.sam_head = sam_head
        if rpn_head is not None:
            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
            rpn_head_ = rpn_head.copy()
            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
            # RPN is class-agnostic: force num_classes to 1 and warn if the
            # user configured something else.
            rpn_head_num_classes = rpn_head_.get('num_classes', None)
            if rpn_head_num_classes is None:
                rpn_head_.update(num_classes=1)
            else:
                if rpn_head_num_classes != 1:
                    warnings.warn(
                        'The `num_classes` should be 1 in RPN, but get '
                        f'{rpn_head_num_classes}, please set '
                        'rpn_head.num_classes = 1 in your config file.')
                    rpn_head_.update(num_classes=1)
            self.rpn_head = MODELS.build(rpn_head_)

        if roi_head is not None:
            # update train and test cfg here for now
            # TODO: refactor assigner & sampler
            rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
            roi_head.update(train_cfg=rcnn_train_cfg)
            roi_head.update(test_cfg=test_cfg.rcnn)
            self.roi_head = MODELS.build(roi_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, x):
        """Run the neck on already-extracted features (no backbone here)."""
        x = self.neck(x)
        return x

    def loss(self,
             batch_inputs,
             batch_data_samples: SampleList,
             sam
             ) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
                NOTE(review): indexed with ``[0]`` below, so in practice a
                sequence of feature maps — TODO confirm.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
            sam: SAM model, forwarded to the RoI head when ``sam_head``.

        Returns:
            dict: A dictionary of loss components
        """
        x = self.extract_feat(batch_inputs)
        # First element doubles as the SAM image embedding.
        img_seg_feat = batch_inputs[0]

        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_data_samples = copy.deepcopy(batch_data_samples)
            # set cat_id of gt_labels to 0 in RPN
            for data_sample in rpn_data_samples:
                data_sample.gt_instances.labels = \
                    torch.zeros_like(data_sample.gt_instances.labels)

            rpn_losses, rpn_results_list = self.rpn_head.loss_and_predict(
                x, rpn_data_samples, proposal_cfg=proposal_cfg)
            # avoid get same name with roi_head loss
            keys = rpn_losses.keys()
            for key in list(keys):
                if 'loss' in key and 'rpn' not in key:
                    rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
            losses.update(rpn_losses)
        else:
            assert batch_data_samples[0].get('proposals', None) is not None
            # use pre-defined proposals in InstanceData for the second stage
            # to extract ROI features.
            rpn_results_list = [
                data_sample.proposals for data_sample in batch_data_samples
            ]

        if self.sam_head:
            roi_losses = self.roi_head.loss(x, rpn_results_list,
                                            batch_data_samples,
                                            sam, img_seg_feat
                                            )
        else:
            roi_losses = self.roi_head.loss(x, rpn_results_list,
                                            batch_data_samples
                                            )
        losses.update(roi_losses)

        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                sam,
                rescale: bool = True
                ) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
                NOTE(review): indexed with ``[0]`` below, so in practice a
                sequence of feature maps — TODO confirm.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            sam: SAM model, forwarded to the RoI head when ``sam_head``.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Return the detection results of the
            input images. The returns value is DetDataSample,
            which usually contain 'pred_instances'. And the
            ``pred_instances`` usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        x = self.extract_feat(batch_inputs)
        # First element doubles as the SAM image embedding.
        img_seg_feat = batch_inputs[0]

        # If there are no pre-defined proposals, use RPN to get proposals
        if batch_data_samples[0].get('proposals', None) is None:
            rpn_results_list = self.rpn_head.predict(
                x, batch_data_samples, rescale=False)
        else:
            rpn_results_list = [
                data_sample.proposals for data_sample in batch_data_samples
            ]
        if self.sam_head:
            results_list = self.roi_head.predict(
                x, rpn_results_list, batch_data_samples, sam, img_seg_feat, rescale=rescale)
        else:
            results_list = self.roi_head.predict(
                x, rpn_results_list, batch_data_samples, rescale=rescale)

        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples
class SAMAnchorPromptRoIHead(StandardRoIHead):
    """Standard RoI head whose mask branch prompts a SAM model.

    A sine positional encoding is added to every feature level before the
    bbox/mask branches run, and the SAM model plus its image embedding are
    forwarded down to the mask head.

    Args:
        positional_encoding (dict): Config for
            :class:`SinePositionalEncoding` applied to the features.
    """

    def __init__(
            self,
            positional_encoding=dict(num_feats=128, normalize=True),
            *args,
            **kwargs
    ):
        # Skips StandardRoIHead.__init__ and initialises from its parent
        # directly, forwarding the remaining head/extractor configs.
        super(StandardRoIHead, self).__init__(*args, **kwargs)
        self.generator_pe = SinePositionalEncoding(**positional_encoding)

    def _mask_forward(self,
                      x: Tuple[Tensor],
                      rois: Tensor = None,
                      pos_inds: Optional[Tensor] = None,
                      bbox_feats: Optional[Tensor] = None,
                      sam=None, img_seg_feat=None
                      ) -> dict:
        """Mask head forward function used in both training and testing.

        Args:
            x (tuple[Tensor]): Tuple of multi-level img features.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.
            pos_inds (Tensor, optional): Indices of positive samples.
                Defaults to None.
            bbox_feats (Tensor): Extract bbox RoI features. Defaults to None.
            sam: SAM model forwarded to the mask head.
            img_seg_feat (Tensor): SAM image embedding forwarded to the
                mask head.

        Returns:
            dict[str, Tensor]: Usually returns a dictionary with keys:

                - `mask_preds` (Tensor): Mask prediction.
                - `mask_iou` (Tensor): Predicted mask IoU from SAM.
                - `mask_feats` (Tensor): Extract mask RoI features.
        """
        assert ((rois is not None) ^
                (pos_inds is not None and bbox_feats is not None))
        if rois is not None:
            mask_feats = self.mask_roi_extractor(
                x[:self.mask_roi_extractor.num_inputs], rois)
            if self.with_shared_head:
                mask_feats = self.shared_head(mask_feats)
        else:
            assert bbox_feats is not None
            mask_feats = bbox_feats[pos_inds]

        # NOTE(review): `rois[:, 0]` assumes the `rois` path was taken; the
        # pos_inds/bbox_feats path above would fail here (rois is None).
        mask_preds = self.mask_head(mask_feats, sam, img_seg_feat, img_flag_ids=rois[:, 0])
        mask_results = dict(mask_preds=mask_preds[0], mask_iou=mask_preds[1], mask_feats=mask_feats)
        return mask_results

    def mask_loss(self, x: Tuple[Tensor],
                  sampling_results: List[SamplingResult], bbox_feats: Tensor,
                  batch_gt_instances: InstanceList,
                  sam, img_seg_feat
                  ) -> dict:
        """Perform forward propagation and loss calculation of the mask head on
        the features of the upstream network.

        Args:
            x (tuple[Tensor]): Tuple of multi-level img features.
            sampling_results (list["obj:`SamplingResult`]): Sampling results.
            bbox_feats (Tensor): Extract bbox RoI features.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes``, ``labels``, and
                ``masks`` attributes.
            sam: SAM model forwarded to the mask head.
            img_seg_feat (Tensor): SAM image embedding.

        Returns:
            dict: Usually returns a dictionary with keys:

                - `mask_preds` (Tensor): Mask prediction.
                - `mask_feats` (Tensor): Extract mask RoI features.
                - `mask_targets` (Tensor): Mask target of each positive\
                    proposals in the image.
                - `loss_mask` (dict): A dictionary of mask loss components.
        """
        if not self.share_roi_extractor:
            pos_rois = bbox2roi([res.pos_priors for res in sampling_results])
            mask_results = self._mask_forward(
                x, pos_rois, sam=sam, img_seg_feat=img_seg_feat)
        else:
            # NOTE(review): this branch does not forward sam/img_seg_feat and
            # hits the broken pos_inds path of `_mask_forward`; it appears to
            # be unused with this head — confirm before relying on it.
            pos_inds = []
            device = bbox_feats.device
            for res in sampling_results:
                pos_inds.append(
                    torch.ones(
                        res.pos_priors.shape[0],
                        device=device,
                        dtype=torch.uint8))
                pos_inds.append(
                    torch.zeros(
                        res.neg_priors.shape[0],
                        device=device,
                        dtype=torch.uint8))
            pos_inds = torch.cat(pos_inds)

            mask_results = self._mask_forward(
                x, pos_inds=pos_inds, bbox_feats=bbox_feats)

        mask_loss_and_target = self.mask_head.loss_and_target(
            mask_preds=mask_results['mask_preds'],
            sampling_results=sampling_results,
            batch_gt_instances=batch_gt_instances,
            rcnn_train_cfg=self.train_cfg)

        mask_results.update(loss_mask=mask_loss_and_target['loss_mask'])
        return mask_results

    def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
             batch_data_samples: List[DetDataSample],
             sam, img_seg_feat
             ) -> dict:
        """Perform forward propagation and loss calculation of the detection
        roi on the features of the upstream network.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
            sam: SAM model forwarded to the mask branch.
            img_seg_feat (Tensor): SAM image embedding.

        Returns:
            dict[str, Tensor]: A dictionary of loss components
        """
        x = list(x)
        # Add a sine positional encoding (computed at the coarsest level,
        # bilinearly resized) to every feature level.
        bs, _, h, w = x[-1].shape
        mask_pe = torch.zeros((bs, h, w), device=x[0].device, dtype=torch.bool)
        img_feats_pe = self.generator_pe(mask_pe)
        for i in range(len(x)):
            x[i] = x[i] + torch.nn.functional.interpolate(img_feats_pe, size=x[i].shape[-2:], mode='bilinear')

        assert len(rpn_results_list) == len(batch_data_samples)
        outputs = unpack_gt_instances(batch_data_samples)
        batch_gt_instances, batch_gt_instances_ignore, _ = outputs

        # assign gts and sample proposals
        num_imgs = len(batch_data_samples)
        sampling_results = []
        for i in range(num_imgs):
            # rename rpn_results.bboxes to rpn_results.priors
            rpn_results = rpn_results_list[i]
            rpn_results.priors = rpn_results.pop('bboxes')

            assign_result = self.bbox_assigner.assign(
                rpn_results, batch_gt_instances[i],
                batch_gt_instances_ignore[i])
            sampling_result = self.bbox_sampler.sample(
                assign_result,
                rpn_results,
                batch_gt_instances[i],
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)

        losses = dict()
        # bbox head loss
        if self.with_bbox:
            bbox_results = self.bbox_loss(x, sampling_results)
            losses.update(bbox_results['loss_bbox'])

        # mask head forward and loss
        if self.with_mask:
            mask_results = self.mask_loss(x, sampling_results,
                                          bbox_results['bbox_feats'],
                                          batch_gt_instances,
                                          sam, img_seg_feat
                                          )
            losses.update(mask_results['loss_mask'])

        return losses

    def predict_mask(self,
                     x: Tuple[Tensor],
                     batch_img_metas: List[dict],
                     results_list: InstanceList,
                     rescale: bool = False,
                     sam=None, img_seg_feat=None
                     ) -> InstanceList:
        """Perform forward propagation of the mask head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            batch_img_metas (list[dict]): List of image information.
            results_list (list[:obj:`InstanceData`]): Detection results of
                each image.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            sam: SAM model forwarded to the mask head.
            img_seg_feat (Tensor): SAM image embedding.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        # don't need to consider aug_test.
        bboxes = [res.bboxes for res in results_list]
        mask_rois = bbox2roi(bboxes)
        if mask_rois.shape[0] == 0:
            # No detections at all: return empty instance results.
            results_list = empty_instances(
                batch_img_metas,
                mask_rois.device,
                task_type='mask',
                instance_results=results_list,
                mask_thr_binary=self.test_cfg.mask_thr_binary)
            return results_list

        mask_results = self._mask_forward(x, mask_rois, sam=sam, img_seg_feat=img_seg_feat)
        mask_preds = mask_results['mask_preds']
        # split batch mask prediction back to each image
        num_mask_rois_per_img = [len(res) for res in results_list]
        mask_preds = mask_preds.split(num_mask_rois_per_img, 0)

        # TODO: Handle the case where rescale is false
        results_list = self.mask_head.predict_by_feat(
            mask_preds=mask_preds,
            results_list=results_list,
            batch_img_metas=batch_img_metas,
            rcnn_test_cfg=self.test_cfg,
            rescale=rescale)
        return results_list

    def predict(self,
                x: Tuple[Tensor],
                rpn_results_list: InstanceList,
                batch_data_samples: SampleList,
                sam, img_seg_feat,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the roi head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Features from upstream network. Each
                has shape (N, C, H, W).
            rpn_results_list (list[:obj:`InstanceData`]): list of region
                proposals.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            sam: SAM model forwarded to the mask branch.
            img_seg_feat (Tensor): SAM image embedding.
            rescale (bool): Whether to rescale the results to
                the original image. Defaults to True.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image.
            Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
                - masks (Tensor): Has a shape (num_instances, H, W).
        """
        x = list(x)
        # Same positional-encoding addition as in `loss` — keep in sync.
        bs, _, h, w = x[-1].shape
        mask_pe = torch.zeros((bs, h, w), device=x[0].device, dtype=torch.bool)
        img_feats_pe = self.generator_pe(mask_pe)
        for i in range(len(x)):
            x[i] = x[i] + torch.nn.functional.interpolate(img_feats_pe, size=x[i].shape[-2:], mode='bilinear')

        assert self.with_bbox, 'Bbox head must be implemented.'
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        # TODO: nms_op in mmcv need be enhanced, the bbox result may get
        #  difference when not rescale in bbox_head

        # If it has the mask branch, the bbox branch does not need
        # to be scaled to the original image scale, because the mask
        # branch will scale both bbox and mask at the same time.
        bbox_rescale = rescale if not self.with_mask else False
        results_list = self.predict_bbox(
            x,
            batch_img_metas,
            rpn_results_list,
            rcnn_test_cfg=self.test_cfg,
            rescale=bbox_rescale)

        if self.with_mask:
            results_list = self.predict_mask(
                x, batch_img_metas, results_list, rescale=rescale, sam=sam, img_seg_feat=img_seg_feat)

        return results_list
| class SAMPromptMaskHead(FCNMaskHead): | |
| def __init__(self, | |
| per_query_point: int = 5, | |
| with_sincos: bool = True, | |
| class_agnostic: bool = False, | |
| loss_mask: ConfigType = dict( | |
| type='CrossEntropyLoss', use_mask=True, loss_weight=1.0), | |
| *args, | |
| **kwargs | |
| ) -> None: | |
| super(BaseModule, self).__init__() | |
| self.per_query_point = per_query_point | |
| self.with_sincos = with_sincos | |
| self.class_agnostic = class_agnostic | |
| self.loss_mask = MODELS.build(loss_mask) | |
| if with_sincos: | |
| sincos = 2 | |
| else: | |
| sincos = 1 | |
| self.point_emb = nn.Sequential( | |
| nn.Conv2d(256, 256, 3, stride=2, padding=1), | |
| nn.BatchNorm2d(256), | |
| nn.ReLU(inplace=True), | |
| nn.Flatten(), | |
| nn.Linear(7*7*256, 256), | |
| nn.ReLU(inplace=True), | |
| nn.Linear(256, 256), | |
| nn.ReLU(inplace=True), | |
| nn.Linear(256, 256*sincos*per_query_point) | |
| ) | |
    def forward(self, x, sam, img_seg_feat, img_flag_ids) -> Tuple[Tensor, Tensor]:
        """Predict SAM masks for a batch of RoI features.

        Args:
            x (Tensor): RoI features; assumed (num_rois, 256, 14, 14) given
                the stride-2 conv followed by a 7*7*256 flatten — TODO
                confirm against the RoI extractor config.
            sam: SAM model exposing ``prompt_encoder`` and ``mask_decoder``.
            img_seg_feat (Tensor): Per-image SAM image embeddings,
                (num_imgs, C, H, W).
            img_flag_ids (Tensor): Batch index of each RoI (first RoI
                column); used to repeat each image embedding once per RoI.

        Returns:
            tuple[Tensor, Tensor]: Low-resolution mask logits
            (num_rois, h, w) and IoU predictions (num_rois,).
        """
        batch_size = x.shape[0]
        point_emb = self.point_emb(x)
        # (num_rois, per_query_point, C)
        point_emb = point_emb.view(batch_size, self.per_query_point, -1)

        if self.with_sincos:
            # Mix interleaved channel pairs: sin(even channels) + odd channels.
            point_emb = torch.sin(point_emb[..., ::2]) + point_emb[..., 1::2]

        # "No mask" dense prompt, broadcast to one per RoI.
        nomask_dense_embeddings = sam.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            point_emb.shape[0], -1, *img_seg_feat.shape[-2:]
        )

        # Count RoIs per image; pad with zeros for trailing images that have
        # no RoIs so repeat_interleave gets exactly one count per image.
        img_flag_ids = torch.bincount(img_flag_ids.long())
        padding = torch.zeros((len(img_seg_feat)-len(img_flag_ids),), device=img_flag_ids.device, dtype=img_flag_ids.dtype)
        img_flag_ids = torch.cat([img_flag_ids, padding])

        img_embeddings = torch.repeat_interleave(img_seg_feat, img_flag_ids, dim=0)
        img_pe = sam.prompt_encoder.get_dense_pe()
        img_pe = repeat(img_pe, 'b c h w -> (b n) c h w', n=img_embeddings.shape[0])

        res_img_feat = None
        low_res_masks, iou_predictions = sam.mask_decoder.forward_batch(
            image_embeddings=img_embeddings,
            image_pe=img_pe,
            sparse_prompt_embeddings=point_emb,
            dense_prompt_embeddings=nomask_dense_embeddings,
            multimask_output=False,
            res_img_feat=res_img_feat,
        )
        mask_pred = low_res_masks.squeeze(1)
        iou_predictions = iou_predictions.squeeze(1)
        return mask_pred, iou_predictions
| def get_targets(self, sampling_results: List[SamplingResult], | |
| batch_gt_instances: InstanceList, | |
| rcnn_train_cfg: ConfigDict) -> Tensor: | |
| """Calculate the ground truth for all samples in a batch according to | |
| the sampling_results. | |
| Args: | |
| sampling_results (List[obj:SamplingResult]): Assign results of | |
| all images in a batch after sampling. | |
| batch_gt_instances (list[:obj:`InstanceData`]): Batch of | |
| gt_instance. It usually includes ``bboxes``, ``labels``, and | |
| ``masks`` attributes. | |
| rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. | |
| Returns: | |
| Tensor: Mask target of each positive proposals in the image. | |
| """ | |
| pos_proposals = [res.pos_priors for res in sampling_results] | |
| pos_assigned_gt_inds = [ | |
| res.pos_assigned_gt_inds for res in sampling_results | |
| ] | |
| gt_masks = [res.masks for res in batch_gt_instances] | |
| mask_targets_list = [] | |
| mask_size = (rcnn_train_cfg.mask_size,) * 2 | |
| device = pos_proposals[0].device | |
| for pos_gt_inds, gt_mask in zip(pos_assigned_gt_inds, gt_masks): | |
| if len(pos_gt_inds) == 0: | |
| mask_targets = torch.zeros((0,) + mask_size, device=device, dytpe=torch.float32) | |
| else: | |
| mask_targets = gt_mask[pos_gt_inds.cpu()].to_tensor(dtype=torch.float32, device=device) | |
| mask_targets_list.append(mask_targets) | |
| mask_targets = torch.cat(mask_targets_list) | |
| return mask_targets | |
| def loss_and_target(self, mask_preds: Tensor, | |
| sampling_results: List[SamplingResult], | |
| batch_gt_instances: InstanceList, | |
| rcnn_train_cfg: ConfigDict) -> dict: | |
| """Calculate the loss based on the features extracted by the mask head. | |
| Args: | |
| mask_preds (Tensor): Predicted foreground masks, has shape | |
| (num_pos, num_classes, h, w). | |
| sampling_results (List[obj:SamplingResult]): Assign results of | |
| all images in a batch after sampling. | |
| batch_gt_instances (list[:obj:`InstanceData`]): Batch of | |
| gt_instance. It usually includes ``bboxes``, ``labels``, and | |
| ``masks`` attributes. | |
| rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. | |
| Returns: | |
| dict: A dictionary of loss and targets components. | |
| """ | |
| mask_targets = self.get_targets( | |
| sampling_results=sampling_results, | |
| batch_gt_instances=batch_gt_instances, | |
| rcnn_train_cfg=rcnn_train_cfg) | |
| pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results]) | |
| mask_preds = torch.nn.functional.interpolate( | |
| mask_preds.unsqueeze(1), size=mask_targets.shape[-2:], mode='bilinear', align_corners=False) | |
| loss = dict() | |
| if mask_preds.size(0) == 0: | |
| loss_mask = mask_preds.sum() | |
| else: | |
| if self.class_agnostic: | |
| loss_mask = self.loss_mask(mask_preds, mask_targets, | |
| torch.zeros_like(pos_labels)) | |
| else: | |
| loss_mask = self.loss_mask(mask_preds, mask_targets, | |
| pos_labels) | |
| loss['loss_mask'] = loss_mask | |
| # TODO: which algorithm requires mask_targets? | |
| return dict(loss_mask=loss, mask_targets=mask_targets) | |
    def _predict_by_feat_single(self,
                                mask_preds: Tensor,
                                bboxes: Tensor,
                                labels: Tensor,
                                img_meta: dict,
                                rcnn_test_cfg: ConfigDict,
                                rescale: bool = False,
                                activate_map: bool = False) -> Tensor:
        """Get segmentation masks from mask_preds and bboxes.

        Resizes whole-image mask predictions directly to the output canvas
        (no per-RoI pasting), then binarizes them with
        ``rcnn_test_cfg.mask_thr_binary``.

        Args:
            mask_preds (Tensor): Predicted foreground masks, has shape
                (n, num_classes, h, w).
            bboxes (Tensor): Predicted bboxes, has shape (n, 4)
            labels (Tensor): Labels of bboxes, has shape (n, )
            img_meta (dict): image information.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            activate_map (bool): Whether get results with augmentations test.
                If True, the `mask_preds` will not process with sigmoid.
                Defaults to False.

        Returns:
            Tensor: Encoded masks, has shape (n, img_w, img_h)

        Example:
            >>> from mmengine.config import Config
            >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import *  # NOQA
            >>> N = 7  # N = number of extracted ROIs
            >>> C, H, W = 11, 32, 32
            >>> # Create example instance of FCN Mask Head.
            >>> self = FCNMaskHead(num_classes=C, num_convs=0)
            >>> inputs = torch.rand(N, self.in_channels, H, W)
            >>> mask_preds = self.forward(inputs)
            >>> # Each input is associated with some bounding box
            >>> bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N)
            >>> labels = torch.randint(0, C, size=(N,))
            >>> rcnn_test_cfg = Config({'mask_thr_binary': 0, })
            >>> ori_shape = (H * 4, W * 4)
            >>> scale_factor = (1, 1)
            >>> rescale = False
            >>> img_meta = {'scale_factor': scale_factor,
            ...             'ori_shape': ori_shape}
            >>> # Encoded masks are a list for each category.
            >>> encoded_masks = self._get_seg_masks_single(
            ...     mask_preds, bboxes, labels,
            ...     img_meta, rcnn_test_cfg, rescale)
            >>> assert encoded_masks.size()[0] == N
            >>> assert encoded_masks.size()[1:] == ori_shape
        """
        # Tile (w_scale, h_scale) -> (1, 4) so it can divide xyxy boxes.
        scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat(
            (1, 2))
        img_h, img_w = img_meta['ori_shape'][:2]
        # NOTE(review): ``device`` and ``labels`` are not used below.
        device = bboxes.device

        if not activate_map:
            mask_preds = mask_preds.sigmoid()
        else:
            # In AugTest, has been activated before
            mask_preds = bboxes.new_tensor(mask_preds)

        if rescale:  # in-placed rescale the bboxes
            # WARNING: this divides the caller's ``bboxes`` tensor in place.
            bboxes /= scale_factor
        else:
            # Keep boxes in network-input scale and enlarge the output
            # canvas to the scaled resolution instead.
            w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1]
            img_h = np.round(img_h * h_scale.item()).astype(np.int32)
            img_w = np.round(img_w * w_scale.item()).astype(np.int32)

        threshold = rcnn_test_cfg.mask_thr_binary
        # One-shot bilinear resize of all mask maps to the final canvas.
        im_mask = torch.nn.functional.interpolate(
            mask_preds.unsqueeze(1), size=(img_h, img_w), mode='bilinear', align_corners=False).squeeze(1)

        if threshold >= 0:
            im_mask = im_mask >= threshold
        else:
            # for visualization and debugging
            im_mask = (im_mask * 255).to(dtype=torch.uint8)
        return im_mask