| |
| import copy |
| from typing import List, Optional, Tuple |
|
|
| import torch |
| from mmengine.structures import InstanceData |
| from torch import Tensor |
|
|
| from mmdet.models.utils import (filter_gt_instances, rename_loss_dict, |
| reweight_loss_dict) |
| from mmdet.registry import MODELS |
| from mmdet.structures import SampleList |
| from mmdet.structures.bbox import bbox2roi, bbox_project |
| from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig |
| from ..utils.misc import unpack_gt_instances |
| from .semi_base import SemiBaseDetector |
|
|
|
|
| @MODELS.register_module() |
| class SoftTeacher(SemiBaseDetector): |
| r"""Implementation of `End-to-End Semi-Supervised Object Detection |
| with Soft Teacher <https://arxiv.org/abs/2106.09018>`_ |
| |
| Args: |
| detector (:obj:`ConfigDict` or dict): The detector config. |
| semi_train_cfg (:obj:`ConfigDict` or dict, optional): |
| The semi-supervised training config. |
| semi_test_cfg (:obj:`ConfigDict` or dict, optional): |
| The semi-supervised testing config. |
| data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of |
| :class:`DetDataPreprocessor` to process the input data. |
| Defaults to None. |
| init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or |
| list[dict], optional): Initialization config dict. |
| Defaults to None. |
| """ |
|
|
| def __init__(self, |
| detector: ConfigType, |
| semi_train_cfg: OptConfigType = None, |
| semi_test_cfg: OptConfigType = None, |
| data_preprocessor: OptConfigType = None, |
| init_cfg: OptMultiConfig = None) -> None: |
| super().__init__( |
| detector=detector, |
| semi_train_cfg=semi_train_cfg, |
| semi_test_cfg=semi_test_cfg, |
| data_preprocessor=data_preprocessor, |
| init_cfg=init_cfg) |
|
|
| def loss_by_pseudo_instances(self, |
| batch_inputs: Tensor, |
| batch_data_samples: SampleList, |
| batch_info: Optional[dict] = None) -> dict: |
| """Calculate losses from a batch of inputs and pseudo data samples. |
| |
| Args: |
| batch_inputs (Tensor): Input images of shape (N, C, H, W). |
| These should usually be mean centered and std scaled. |
| batch_data_samples (List[:obj:`DetDataSample`]): The batch |
| data samples. It usually includes information such |
| as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, |
| which are `pseudo_instance` or `pseudo_panoptic_seg` |
| or `pseudo_sem_seg` in fact. |
| batch_info (dict): Batch information of teacher model |
| forward propagation process. Defaults to None. |
| |
| Returns: |
| dict: A dictionary of loss components |
| """ |
|
|
| x = self.student.extract_feat(batch_inputs) |
|
|
| losses = {} |
| rpn_losses, rpn_results_list = self.rpn_loss_by_pseudo_instances( |
| x, batch_data_samples) |
| losses.update(**rpn_losses) |
| losses.update(**self.rcnn_cls_loss_by_pseudo_instances( |
| x, rpn_results_list, batch_data_samples, batch_info)) |
| losses.update(**self.rcnn_reg_loss_by_pseudo_instances( |
| x, rpn_results_list, batch_data_samples)) |
| unsup_weight = self.semi_train_cfg.get('unsup_weight', 1.) |
| return rename_loss_dict('unsup_', |
| reweight_loss_dict(losses, unsup_weight)) |
|
|
| @torch.no_grad() |
| def get_pseudo_instances( |
| self, batch_inputs: Tensor, batch_data_samples: SampleList |
| ) -> Tuple[SampleList, Optional[dict]]: |
| """Get pseudo instances from teacher model.""" |
| assert self.teacher.with_bbox, 'Bbox head must be implemented.' |
| x = self.teacher.extract_feat(batch_inputs) |
|
|
| |
| if batch_data_samples[0].get('proposals', None) is None: |
| rpn_results_list = self.teacher.rpn_head.predict( |
| x, batch_data_samples, rescale=False) |
| else: |
| rpn_results_list = [ |
| data_sample.proposals for data_sample in batch_data_samples |
| ] |
|
|
| results_list = self.teacher.roi_head.predict( |
| x, rpn_results_list, batch_data_samples, rescale=False) |
|
|
| for data_samples, results in zip(batch_data_samples, results_list): |
| data_samples.gt_instances = results |
|
|
| batch_data_samples = filter_gt_instances( |
| batch_data_samples, |
| score_thr=self.semi_train_cfg.pseudo_label_initial_score_thr) |
|
|
| reg_uncs_list = self.compute_uncertainty_with_aug( |
| x, batch_data_samples) |
|
|
| for data_samples, reg_uncs in zip(batch_data_samples, reg_uncs_list): |
| data_samples.gt_instances['reg_uncs'] = reg_uncs |
| data_samples.gt_instances.bboxes = bbox_project( |
| data_samples.gt_instances.bboxes, |
| torch.from_numpy(data_samples.homography_matrix).inverse().to( |
| self.data_preprocessor.device), data_samples.ori_shape) |
|
|
| batch_info = { |
| 'feat': x, |
| 'img_shape': [], |
| 'homography_matrix': [], |
| 'metainfo': [] |
| } |
| for data_samples in batch_data_samples: |
| batch_info['img_shape'].append(data_samples.img_shape) |
| batch_info['homography_matrix'].append( |
| torch.from_numpy(data_samples.homography_matrix).to( |
| self.data_preprocessor.device)) |
| batch_info['metainfo'].append(data_samples.metainfo) |
| return batch_data_samples, batch_info |
|
|
| def rpn_loss_by_pseudo_instances(self, x: Tuple[Tensor], |
| batch_data_samples: SampleList) -> dict: |
| """Calculate rpn loss from a batch of inputs and pseudo data samples. |
| |
| Args: |
| x (tuple[Tensor]): Features from FPN. |
| batch_data_samples (List[:obj:`DetDataSample`]): The batch |
| data samples. It usually includes information such |
| as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, |
| which are `pseudo_instance` or `pseudo_panoptic_seg` |
| or `pseudo_sem_seg` in fact. |
| Returns: |
| dict: A dictionary of rpn loss components |
| """ |
|
|
| rpn_data_samples = copy.deepcopy(batch_data_samples) |
| rpn_data_samples = filter_gt_instances( |
| rpn_data_samples, score_thr=self.semi_train_cfg.rpn_pseudo_thr) |
| proposal_cfg = self.student.train_cfg.get('rpn_proposal', |
| self.student.test_cfg.rpn) |
| |
| for data_sample in rpn_data_samples: |
| data_sample.gt_instances.labels = \ |
| torch.zeros_like(data_sample.gt_instances.labels) |
|
|
| rpn_losses, rpn_results_list = self.student.rpn_head.loss_and_predict( |
| x, rpn_data_samples, proposal_cfg=proposal_cfg) |
| for key in rpn_losses.keys(): |
| if 'loss' in key and 'rpn' not in key: |
| rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) |
| return rpn_losses, rpn_results_list |
|
|
| def rcnn_cls_loss_by_pseudo_instances(self, x: Tuple[Tensor], |
| unsup_rpn_results_list: InstanceList, |
| batch_data_samples: SampleList, |
| batch_info: dict) -> dict: |
| """Calculate classification loss from a batch of inputs and pseudo data |
| samples. |
| |
| Args: |
| x (tuple[Tensor]): List of multi-level img features. |
| unsup_rpn_results_list (list[:obj:`InstanceData`]): |
| List of region proposals. |
| batch_data_samples (List[:obj:`DetDataSample`]): The batch |
| data samples. It usually includes information such |
| as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, |
| which are `pseudo_instance` or `pseudo_panoptic_seg` |
| or `pseudo_sem_seg` in fact. |
| batch_info (dict): Batch information of teacher model |
| forward propagation process. |
| |
| Returns: |
| dict[str, Tensor]: A dictionary of rcnn |
| classification loss components |
| """ |
| rpn_results_list = copy.deepcopy(unsup_rpn_results_list) |
| cls_data_samples = copy.deepcopy(batch_data_samples) |
| cls_data_samples = filter_gt_instances( |
| cls_data_samples, score_thr=self.semi_train_cfg.cls_pseudo_thr) |
|
|
| outputs = unpack_gt_instances(cls_data_samples) |
| batch_gt_instances, batch_gt_instances_ignore, _ = outputs |
|
|
| |
| num_imgs = len(cls_data_samples) |
| sampling_results = [] |
| for i in range(num_imgs): |
| |
| rpn_results = rpn_results_list[i] |
| rpn_results.priors = rpn_results.pop('bboxes') |
| assign_result = self.student.roi_head.bbox_assigner.assign( |
| rpn_results, batch_gt_instances[i], |
| batch_gt_instances_ignore[i]) |
| sampling_result = self.student.roi_head.bbox_sampler.sample( |
| assign_result, |
| rpn_results, |
| batch_gt_instances[i], |
| feats=[lvl_feat[i][None] for lvl_feat in x]) |
| sampling_results.append(sampling_result) |
|
|
| selected_bboxes = [res.priors for res in sampling_results] |
| rois = bbox2roi(selected_bboxes) |
| bbox_results = self.student.roi_head._bbox_forward(x, rois) |
| |
| |
| cls_reg_targets = self.student.roi_head.bbox_head.get_targets( |
| sampling_results, self.student.train_cfg.rcnn) |
|
|
| selected_results_list = [] |
| for bboxes, data_samples, teacher_matrix, teacher_img_shape in zip( |
| selected_bboxes, batch_data_samples, |
| batch_info['homography_matrix'], batch_info['img_shape']): |
| student_matrix = torch.tensor( |
| data_samples.homography_matrix, device=teacher_matrix.device) |
| homography_matrix = teacher_matrix @ student_matrix.inverse() |
| projected_bboxes = bbox_project(bboxes, homography_matrix, |
| teacher_img_shape) |
| selected_results_list.append(InstanceData(bboxes=projected_bboxes)) |
|
|
| with torch.no_grad(): |
| results_list = self.teacher.roi_head.predict_bbox( |
| batch_info['feat'], |
| batch_info['metainfo'], |
| selected_results_list, |
| rcnn_test_cfg=None, |
| rescale=False) |
| bg_score = torch.cat( |
| [results.scores[:, -1] for results in results_list]) |
| |
| neg_inds = cls_reg_targets[ |
| 0] == self.student.roi_head.bbox_head.num_classes |
| |
| cls_reg_targets[1][neg_inds] = bg_score[neg_inds].detach() |
|
|
| losses = self.student.roi_head.bbox_head.loss( |
| bbox_results['cls_score'], bbox_results['bbox_pred'], rois, |
| *cls_reg_targets) |
| |
| losses['loss_cls'] = losses['loss_cls'] * len( |
| cls_reg_targets[1]) / max(sum(cls_reg_targets[1]), 1.0) |
| return losses |
|
|
| def rcnn_reg_loss_by_pseudo_instances( |
| self, x: Tuple[Tensor], unsup_rpn_results_list: InstanceList, |
| batch_data_samples: SampleList) -> dict: |
| """Calculate rcnn regression loss from a batch of inputs and pseudo |
| data samples. |
| |
| Args: |
| x (tuple[Tensor]): List of multi-level img features. |
| unsup_rpn_results_list (list[:obj:`InstanceData`]): |
| List of region proposals. |
| batch_data_samples (List[:obj:`DetDataSample`]): The batch |
| data samples. It usually includes information such |
| as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, |
| which are `pseudo_instance` or `pseudo_panoptic_seg` |
| or `pseudo_sem_seg` in fact. |
| |
| Returns: |
| dict[str, Tensor]: A dictionary of rcnn |
| regression loss components |
| """ |
| rpn_results_list = copy.deepcopy(unsup_rpn_results_list) |
| reg_data_samples = copy.deepcopy(batch_data_samples) |
| for data_samples in reg_data_samples: |
| if data_samples.gt_instances.bboxes.shape[0] > 0: |
| data_samples.gt_instances = data_samples.gt_instances[ |
| data_samples.gt_instances.reg_uncs < |
| self.semi_train_cfg.reg_pseudo_thr] |
| roi_losses = self.student.roi_head.loss(x, rpn_results_list, |
| reg_data_samples) |
| return {'loss_bbox': roi_losses['loss_bbox']} |
|
|
| def compute_uncertainty_with_aug( |
| self, x: Tuple[Tensor], |
| batch_data_samples: SampleList) -> List[Tensor]: |
| """Compute uncertainty with augmented bboxes. |
| |
| Args: |
| x (tuple[Tensor]): List of multi-level img features. |
| batch_data_samples (List[:obj:`DetDataSample`]): The batch |
| data samples. It usually includes information such |
| as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`, |
| which are `pseudo_instance` or `pseudo_panoptic_seg` |
| or `pseudo_sem_seg` in fact. |
| |
| Returns: |
| list[Tensor]: A list of uncertainty for pseudo bboxes. |
| """ |
| auged_results_list = self.aug_box(batch_data_samples, |
| self.semi_train_cfg.jitter_times, |
| self.semi_train_cfg.jitter_scale) |
| |
| auged_results_list = [ |
| InstanceData(bboxes=auged.reshape(-1, auged.shape[-1])) |
| for auged in auged_results_list |
| ] |
|
|
| self.teacher.roi_head.test_cfg = None |
| results_list = self.teacher.roi_head.predict( |
| x, auged_results_list, batch_data_samples, rescale=False) |
| self.teacher.roi_head.test_cfg = self.teacher.test_cfg.rcnn |
|
|
| reg_channel = max( |
| [results.bboxes.shape[-1] for results in results_list]) // 4 |
| bboxes = [ |
| results.bboxes.reshape(self.semi_train_cfg.jitter_times, -1, |
| results.bboxes.shape[-1]) |
| if results.bboxes.numel() > 0 else results.bboxes.new_zeros( |
| self.semi_train_cfg.jitter_times, 0, 4 * reg_channel).float() |
| for results in results_list |
| ] |
|
|
| box_unc = [bbox.std(dim=0) for bbox in bboxes] |
| bboxes = [bbox.mean(dim=0) for bbox in bboxes] |
| labels = [ |
| data_samples.gt_instances.labels |
| for data_samples in batch_data_samples |
| ] |
| if reg_channel != 1: |
| bboxes = [ |
| bbox.reshape(bbox.shape[0], reg_channel, |
| 4)[torch.arange(bbox.shape[0]), label] |
| for bbox, label in zip(bboxes, labels) |
| ] |
| box_unc = [ |
| unc.reshape(unc.shape[0], reg_channel, |
| 4)[torch.arange(unc.shape[0]), label] |
| for unc, label in zip(box_unc, labels) |
| ] |
|
|
| box_shape = [(bbox[:, 2:4] - bbox[:, :2]).clamp(min=1.0) |
| for bbox in bboxes] |
| box_unc = [ |
| torch.mean( |
| unc / wh[:, None, :].expand(-1, 2, 2).reshape(-1, 4), dim=-1) |
| if wh.numel() > 0 else unc for unc, wh in zip(box_unc, box_shape) |
| ] |
| return box_unc |
|
|
| @staticmethod |
| def aug_box(batch_data_samples, times, frac): |
| """Augment bboxes with jitter.""" |
|
|
| def _aug_single(box): |
| box_scale = box[:, 2:4] - box[:, :2] |
| box_scale = ( |
| box_scale.clamp(min=1)[:, None, :].expand(-1, 2, |
| 2).reshape(-1, 4)) |
| aug_scale = box_scale * frac |
|
|
| offset = ( |
| torch.randn(times, box.shape[0], 4, device=box.device) * |
| aug_scale[None, ...]) |
| new_box = box.clone()[None, ...].expand(times, box.shape[0], |
| -1) + offset |
| return new_box |
|
|
| return [ |
| _aug_single(data_samples.gt_instances.bboxes) |
| for data_samples in batch_data_samples |
| ] |
|
|