import copy
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from mmcv.cnn import Linear
from mmengine.model import bias_init_with_prob, constant_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptInstanceList
from ..layers import inverse_sigmoid
from .detr_head import DETRHead


@MODELS.register_module()
class DeformableDETRHead(DETRHead):
| r"""Head of DeformDETR: Deformable DETR: Deformable Transformers for |
| End-to-End Object Detection. |
| |
| Code is modified from the `official github repo |
| <https://github.com/fundamentalvision/Deformable-DETR>`_. |
| |
| More details can be found in the `paper |
| <https://arxiv.org/abs/2010.04159>`_ . |
| |
| Args: |
| share_pred_layer (bool): Whether to share parameters for all the |
| prediction layers. Defaults to `False`. |
| num_pred_layer (int): The number of the prediction layers. |
| Defaults to 6. |
| as_two_stage (bool, optional): Whether to generate the proposal |
| from the outputs of encoder. Defaults to `False`. |
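
    Example:
        A minimal config sketch for this head (the values follow the
        standard Deformable DETR R-50 config and are illustrative only;
        adjust them to your task). Note that ``share_pred_layer``,
        ``num_pred_layer`` and ``as_two_stage`` are normally filled in
        automatically by the ``DeformableDETR`` detector rather than
        written in the config::

            bbox_head = dict(
                type='DeformableDETRHead',
                num_classes=80,
                sync_cls_avg_factor=True,
                loss_cls=dict(
                    type='FocalLoss',
                    use_sigmoid=True,
                    gamma=2.0,
                    alpha=0.25,
                    loss_weight=2.0),
                loss_bbox=dict(type='L1Loss', loss_weight=5.0),
                loss_iou=dict(type='GIoULoss', loss_weight=2.0))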
| """ |
|
|
    def __init__(self,
                 *args,
                 share_pred_layer: bool = False,
                 num_pred_layer: int = 6,
                 as_two_stage: bool = False,
                 **kwargs) -> None:
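        # These attributes must be set before calling the parent
        # constructor, because `DETRHead.__init__` invokes `_init_layers`,
        # which relies on them.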
        self.share_pred_layer = share_pred_layer
        self.num_pred_layer = num_pred_layer
        self.as_two_stage = as_two_stage

        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize classification branch and regression branch of head."""
        fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU())
        reg_branch.append(Linear(self.embed_dims, 4))
        reg_branch = nn.Sequential(*reg_branch)

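        # When `share_pred_layer` is True, the same branch instances are
        # registered for every decoder layer, so their parameters are
        # shared across layers; otherwise each layer gets an independent
        # deep copy.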
        if self.share_pred_layer:
            self.cls_branches = nn.ModuleList(
                [fc_cls for _ in range(self.num_pred_layer)])
            self.reg_branches = nn.ModuleList(
                [reg_branch for _ in range(self.num_pred_layer)])
        else:
            self.cls_branches = nn.ModuleList(
                [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)])
            self.reg_branches = nn.ModuleList([
                copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer)
            ])

    def init_weights(self) -> None:
        """Initialize weights of the Deformable DETR head."""
        if self.loss_cls.use_sigmoid:
            # Initialize the classification bias so that the initial
            # predicted foreground probability is about 0.01.
            bias_init = bias_init_with_prob(0.01)
            for m in self.cls_branches:
                nn.init.constant_(m.bias, bias_init)
        for m in self.reg_branches:
            constant_init(m[-1], 0, bias=0)
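        # Bias the (w, h) logits of the first regression branch towards
        # small values (sigmoid(-2.0) is roughly 0.12) so that the
        # initially predicted boxes are small.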
        nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
        if self.as_two_stage:
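            # When the proposals come from the encoder (two-stage mode),
            # the references already provide full boxes, so the (w, h)
            # bias offset is reset to zero.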
            for m in self.reg_branches:
                nn.init.constant_(m[-1].bias.data[2:], 0.0)

    def forward(self, hidden_states: Tensor,
                references: List[Tensor]) -> Tuple[Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other num_decoder_layers references are
                `inter_references` (intermediate). The `init_reference` has
                shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and as
                (cx, cy, w, h) when it is 4.

        Returns:
            tuple[Tensor]: Results of the head containing the following
            tensors.

            - all_layers_outputs_classes (Tensor): Outputs from the
              classification head, has shape (num_decoder_layers, bs,
              num_queries, cls_out_channels).
            - all_layers_outputs_coords (Tensor): Sigmoid outputs from the
              regression head with normalized coordinate format
              (cx, cy, w, h), has shape (num_decoder_layers, bs,
              num_queries, 4).
        """
        all_layers_outputs_classes = []
        all_layers_outputs_coords = []

        for layer_id in range(hidden_states.shape[0]):
            reference = inverse_sigmoid(references[layer_id])
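            # NOTE `references` holds num_decoder_layers + 1 entries (the
            # initial reference plus one per decoder layer); the last
            # entry is not used here.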
            hidden_state = hidden_states[layer_id]
            outputs_class = self.cls_branches[layer_id](hidden_state)
            tmp_reg_preds = self.reg_branches[layer_id](hidden_state)
            if reference.shape[-1] == 4:
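                # The reference is a full box (cx, cy, w, h): the
                # regression branch predicts a refinement of all four
                # coordinates in inverse-sigmoid (logit) space.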
                tmp_reg_preds += reference
            else:
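                # The reference is a point (cx, cy): only the center is
                # refined, while width and height are predicted directly.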
                assert reference.shape[-1] == 2
                tmp_reg_preds[..., :2] += reference
            outputs_coord = tmp_reg_preds.sigmoid()
            all_layers_outputs_classes.append(outputs_class)
            all_layers_outputs_coords.append(outputs_coord)

        all_layers_outputs_classes = torch.stack(all_layers_outputs_classes)
        all_layers_outputs_coords = torch.stack(all_layers_outputs_coords)

        return all_layers_outputs_classes, all_layers_outputs_coords

    def loss(self, hidden_states: Tensor, references: List[Tensor],
             enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the queries of the upstream network.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other num_decoder_layers references are
                `inter_references` (intermediate). The `init_reference` has
                shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and as
                (cx, cy, w, h) when it is 4.
            enc_outputs_class (Tensor): The score of each point on the
                encoder feature map, has shape (bs, num_feat_points,
                cls_out_channels). Only passed in when `as_two_stage` is
                `True`, otherwise it is `None`.
            enc_outputs_coord (Tensor): The proposals generated from the
                encoder feature map, has shape (bs, num_feat_points, 4)
                with the last dimension arranged as (cx, cy, w, h). Only
                passed in when `as_two_stage` is `True`, otherwise it is
                `None`.
            batch_data_samples (list[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

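        # The ordering of `loss_inputs` must match the positional
        # arguments of `self.loss_by_feat`.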
        outs = self(hidden_states, references)
        loss_inputs = outs + (enc_outputs_class, enc_outputs_coord,
                              batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_by_feat(
        self,
        all_layers_cls_scores: Tensor,
        all_layers_bbox_preds: Tensor,
        enc_cls_scores: Tensor,
        enc_bbox_preds: Tensor,
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Loss function.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Regression outputs of all
                decoder layers. Each is a 4D tensor with normalized
                coordinate format (cx, cy, w, h) and has shape
                (num_decoder_layers, bs, num_queries, 4).
            enc_cls_scores (Tensor): The score of each point on the
                encoder feature map, has shape (bs, num_feat_points,
                cls_out_channels). Only passed in when `as_two_stage` is
                `True`, otherwise it is `None`.
            enc_bbox_preds (Tensor): The proposals generated from the
                encoder feature map, has shape (bs, num_feat_points, 4)
                with the last dimension arranged as (cx, cy, w, h). Only
                passed in when `as_two_stage` is `True`, otherwise it is
                `None`.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        loss_dict = super().loss_by_feat(all_layers_cls_scores,
                                         all_layers_bbox_preds,
                                         batch_gt_instances, batch_img_metas,
                                         batch_gt_instances_ignore)

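        # Loss of the proposals generated from the encoder feature map
        # (two-stage mode only). The proposals are treated as
        # class-agnostic, so all ground-truth labels are set to 0
        # (foreground) before matching.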
        if enc_cls_scores is not None:
            proposal_gt_instances = copy.deepcopy(batch_gt_instances)
            for i in range(len(proposal_gt_instances)):
                proposal_gt_instances[i].labels = torch.zeros_like(
                    proposal_gt_instances[i].labels)
            enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
                self.loss_by_feat_single(
                    enc_cls_scores, enc_bbox_preds,
                    batch_gt_instances=proposal_gt_instances,
                    batch_img_metas=batch_img_metas)
            loss_dict['enc_loss_cls'] = enc_loss_cls
            loss_dict['enc_loss_bbox'] = enc_losses_bbox
            loss_dict['enc_loss_iou'] = enc_losses_iou
        return loss_dict

    def predict(self,
                hidden_states: Tensor,
                references: List[Tensor],
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the queries of the upstream network.

        Args:
            hidden_states (Tensor): Hidden states output from each decoder
                layer, has shape (num_decoder_layers, bs, num_queries, dim).
            references (list[Tensor]): List of the references from the
                decoder. The first reference is the `init_reference`
                (initial) and the other num_decoder_layers references are
                `inter_references` (intermediate). The `init_reference` has
                shape (bs, num_queries, 4) when `as_two_stage` of the
                detector is `True`, otherwise (bs, num_queries, 2). Each
                `inter_reference` has shape (bs, num_queries, 4) when
                `with_box_refine` of the detector is `True`, otherwise
                (bs, num_queries, 2). The coordinates are arranged as
                (cx, cy) when the last dimension is 2, and as
                (cx, cy, w, h) when it is 4.
            batch_data_samples (list[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): If `True`, return boxes in the
                original image space. Defaults to `True`.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        outs = self(hidden_states, references)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)
        return predictions

    def predict_by_feat(self,
                        all_layers_cls_scores: Tensor,
                        all_layers_bbox_preds: Tensor,
                        batch_img_metas: List[Dict],
                        rescale: bool = False) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Args:
            all_layers_cls_scores (Tensor): Classification scores of all
                decoder layers, has shape (num_decoder_layers, bs,
                num_queries, cls_out_channels).
            all_layers_bbox_preds (Tensor): Regression outputs of all
                decoder layers. Each is a 4D tensor with normalized
                coordinate format (cx, cy, w, h) and shape
                (num_decoder_layers, bs, num_queries, 4).
            batch_img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If `True`, return boxes in the
                original image space. Defaults to `False`.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        cls_scores = all_layers_cls_scores[-1]
        bbox_preds = all_layers_bbox_preds[-1]

        result_list = []
        for img_id in range(len(batch_img_metas)):
            cls_score = cls_scores[img_id]
            bbox_pred = bbox_preds[img_id]
            img_meta = batch_img_metas[img_id]
            results = self._predict_by_feat_single(cls_score, bbox_pred,
                                                   img_meta, rescale)
            result_list.append(results)
        return result_list