| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from ultralytics.utils.loss import FocalLoss, VarifocalLoss |
| from ultralytics.utils.metrics import bbox_iou |
| from .ops import HungarianMatcher |
|
|
|
|
class DETRLoss(nn.Module):
    """
    DETR (DEtection TRansformer) Loss class.

    This class calculates and returns the different loss components for the DETR object detection model. It computes
    the classification loss, the bounding box (L1) loss, the GIoU loss, and optionally auxiliary losses for the
    intermediate decoder layers.

    Attributes:
        nc (int): The number of classes.
        loss_gain (dict): Coefficients for the different loss components ("class", "bbox" and "giou" are read here).
        aux_loss (bool): Whether to compute auxiliary losses for every decoder layer except the last one.
        use_uni_match (bool): Whether to reuse the matching of one fixed layer for all auxiliary layers.
        uni_match_ind (int): Index of the layer whose matching is reused when `use_uni_match` is True.
        matcher (HungarianMatcher): Object to compute matching cost and indices.
        fl (FocalLoss | None): Focal Loss object if `use_fl` was True at construction, otherwise None.
        vfl (VarifocalLoss | None): Varifocal Loss object if `use_vfl` was True at construction, otherwise None.
        device (torch.device | None): Device of the predictions seen by the last `forward` call (None before that).
    """

    def __init__(
        self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
    ):
        """
        Initialize the DETR loss.

        Args:
            nc (int): The number of classes.
            loss_gain (dict, optional): Coefficients of the loss components. Defaults to
                {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}.
            aux_loss (bool): If True, losses are also computed for every intermediate decoder layer.
            use_fl (bool): Use FocalLoss for the classification loss (otherwise plain BCE-with-logits is used).
            use_vfl (bool): Use VarifocalLoss for the classification loss; only takes effect when `use_fl` is True.
            use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch.
            uni_match_ind (int): Index of the layer used for matching when `use_uni_match` is True.
        """
        super().__init__()

        if loss_gain is None:
            loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
        self.nc = nc
        self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
        self.loss_gain = loss_gain
        self.aux_loss = aux_loss
        self.fl = FocalLoss() if use_fl else None
        self.vfl = VarifocalLoss() if use_vfl else None

        self.use_uni_match = use_uni_match
        self.uni_match_ind = uni_match_ind
        self.device = None

    def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
        """
        Compute the classification loss.

        Args:
            pred_scores (torch.Tensor): Class logits of shape [bs, nq, nc].
            targets (torch.Tensor): Integer class index per query, shape [bs, nq]; the value `nc` means "no object".
            gt_scores (torch.Tensor): Soft (IoU) target per query with bs * nq elements, zero for unmatched queries.
            num_gts (int): Number of matched ground-truth boxes, used for normalization.
            postfix (str): Suffix appended to the loss name.

        Returns:
            (dict): {"loss_class{postfix}": classification loss scaled by loss_gain["class"]}.
        """
        name_class = f"loss_class{postfix}"
        bs, nq = pred_scores.shape[:2]

        # One-hot over nc + 1 columns so that "no object" targets (== nc) fall into the extra column, which is dropped.
        one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
        one_hot.scatter_(2, targets.unsqueeze(-1), 1)
        one_hot = one_hot[..., :-1]
        # Place each query's soft score in the column of its target class; all other entries stay zero.
        gt_scores = gt_scores.view(bs, nq, 1) * one_hot

        if self.fl is not None:
            # Varifocal loss needs at least one matched ground truth; otherwise fall back to focal loss.
            if num_gts and self.vfl is not None:
                loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
            else:
                loss_cls = self.fl(pred_scores, one_hot.float())
            loss_cls /= max(num_gts, 1) / nq
        else:
            loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum()

        return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}

    def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
        """
        Compute the L1 bounding box loss and the GIoU loss for matched prediction / ground-truth pairs.

        Args:
            pred_bboxes (torch.Tensor): Matched predicted boxes, shape [n, 4], xywh format.
            gt_bboxes (torch.Tensor): Matched ground-truth boxes, shape [n, 4], xywh format.
            postfix (str): Suffix appended to the loss names.

        Returns:
            (dict): "loss_bbox{postfix}" and "loss_giou{postfix}", each averaged over the n matched boxes and scaled
                by the corresponding loss_gain entry. Both are zero tensors when there are no matched boxes.
        """
        name_bbox = f"loss_bbox{postfix}"
        name_giou = f"loss_giou{postfix}"

        num_boxes = len(gt_bboxes)
        if num_boxes == 0:
            # Use the input's device instead of self.device, which is only set once forward() has run.
            return {
                name_bbox: torch.tensor(0.0, device=pred_bboxes.device),
                name_giou: torch.tensor(0.0, device=pred_bboxes.device),
            }

        loss = {}
        loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / num_boxes
        giou_loss = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
        loss[name_giou] = self.loss_gain["giou"] * (giou_loss.sum() / num_boxes)
        return {k: v.squeeze() for k, v in loss.items()}

    def _get_loss_aux(
        self,
        pred_bboxes,
        pred_scores,
        gt_bboxes,
        gt_cls,
        gt_groups,
        match_indices=None,
        postfix="",
        masks=None,
        gt_mask=None,
    ):
        """
        Compute the losses of the auxiliary (intermediate) decoder layers, summed over layers.

        Args:
            pred_bboxes (torch.Tensor): Per-layer predicted boxes, iterated over the first (layer) dimension.
            pred_scores (torch.Tensor): Per-layer class logits, iterated over the first (layer) dimension.
            gt_bboxes (torch.Tensor): Ground-truth boxes of the whole batch.
            gt_cls (torch.Tensor): Ground-truth class indices of the whole batch.
            gt_groups (List[int]): Number of ground truths per image.
            match_indices (list, optional): Precomputed matches shared by all layers. When None and
                `use_uni_match` is set, the matching of layer `uni_match_ind` is computed once and reused;
                otherwise each layer is matched independently.
            postfix (str): Suffix appended to the loss names.
            masks (torch.Tensor, optional): Per-layer predicted masks (only forwarded to the matcher / _get_loss).
            gt_mask (torch.Tensor, optional): Ground-truth masks (only forwarded to the matcher / _get_loss).

        Returns:
            (dict): "loss_class_aux{postfix}", "loss_bbox_aux{postfix}" and "loss_giou_aux{postfix}".
        """
        loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
        if match_indices is None and self.use_uni_match:
            match_indices = self.matcher(
                pred_bboxes[self.uni_match_ind],
                pred_scores[self.uni_match_ind],
                gt_bboxes,
                gt_cls,
                gt_groups,
                masks=masks[self.uni_match_ind] if masks is not None else None,
                gt_mask=gt_mask,
            )
        for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
            aux_masks = masks[i] if masks is not None else None
            loss_ = self._get_loss(
                aux_bboxes,
                aux_scores,
                gt_bboxes,
                gt_cls,
                gt_groups,
                masks=aux_masks,
                gt_mask=gt_mask,
                postfix=postfix,
                match_indices=match_indices,
            )
            loss[0] += loss_[f"loss_class{postfix}"]
            loss[1] += loss_[f"loss_bbox{postfix}"]
            loss[2] += loss_[f"loss_giou{postfix}"]

        loss = {
            f"loss_class_aux{postfix}": loss[0],
            f"loss_bbox_aux{postfix}": loss[1],
            f"loss_giou_aux{postfix}": loss[2],
        }
        return loss

    @staticmethod
    def _get_index(match_indices):
        """
        Flatten per-image match indices into batch-wide index tensors.

        Args:
            match_indices (List[Tuple[torch.Tensor, torch.Tensor]]): Per image, (query indices, ground-truth indices).

        Returns:
            (tuple): ((batch_idx, src_idx), dst_idx), where (batch_idx, src_idx) indexes [bs, nq, ...] predictions
                and dst_idx indexes the ground-truth tensors.
        """
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
        src_idx = torch.cat([src for (src, _) in match_indices])
        dst_idx = torch.cat([dst for (_, dst) in match_indices])
        return (batch_idx, src_idx), dst_idx

    def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
        """
        Gather matched predicted and ground-truth boxes image by image.

        Args:
            pred_bboxes (Iterable[torch.Tensor]): Per-image predicted boxes.
            gt_bboxes (Iterable[torch.Tensor]): Per-image ground-truth boxes.
            match_indices (List[Tuple[torch.Tensor, torch.Tensor]]): Per image, (prediction indices, gt indices).

        Returns:
            (tuple): (pred_assigned, gt_assigned), the concatenated matched boxes.
        """
        # Empty placeholders are created on each tensor's own device (self.device is None before forward()).
        pred_assigned = torch.cat(
            [
                t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=t.device)
                for t, (i, _) in zip(pred_bboxes, match_indices)
            ]
        )
        gt_assigned = torch.cat(
            [
                t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=t.device)
                for t, (_, j) in zip(gt_bboxes, match_indices)
            ]
        )
        return pred_assigned, gt_assigned

    def _get_loss(
        self,
        pred_bboxes,
        pred_scores,
        gt_bboxes,
        gt_cls,
        gt_groups,
        masks=None,
        gt_mask=None,
        postfix="",
        match_indices=None,
    ):
        """
        Compute classification, L1 box and GIoU losses for a single decoder layer.

        Args:
            pred_bboxes (torch.Tensor): Predicted boxes of shape [bs, nq, 4].
            pred_scores (torch.Tensor): Class logits of shape [bs, nq, nc].
            gt_bboxes (torch.Tensor): Ground-truth boxes of the whole batch.
            gt_cls (torch.Tensor): Ground-truth class indices of the whole batch.
            gt_groups (List[int]): Number of ground truths per image.
            masks (torch.Tensor, optional): Predicted masks, only forwarded to the matcher.
            gt_mask (torch.Tensor, optional): Ground-truth masks, only forwarded to the matcher.
            postfix (str): Suffix appended to the loss names.
            match_indices (list, optional): Precomputed matches; when None the Hungarian matcher is run.

        Returns:
            (dict): "loss_class{postfix}", "loss_bbox{postfix}" and "loss_giou{postfix}".
        """
        if match_indices is None:
            match_indices = self.matcher(
                pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
            )

        idx, gt_idx = self._get_index(match_indices)
        pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]

        # Unmatched queries get the "no object" label nc; matched ones get their ground-truth class.
        bs, nq = pred_scores.shape[:2]
        targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
        targets[idx] = gt_cls[gt_idx]

        # Soft classification target: IoU between each matched prediction and its ground truth (no gradient).
        gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
        if len(gt_bboxes):
            gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)

        loss = {}
        loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
        loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
        return loss

    def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
        """
        Compute the DETR losses for the last decoder layer and, if enabled, the auxiliary layers.

        Args:
            pred_bboxes (torch.Tensor): [l, b, query, 4]
            pred_scores (torch.Tensor): [l, b, query, num_classes]
            batch (dict): A dict that includes:
                cls (torch.Tensor): Ground-truth class indices with shape [num_gts, ].
                bboxes (torch.Tensor): Ground-truth boxes with shape [num_gts, 4].
                gt_groups (List[int]): A list of batch-size length with the number of ground truths of each image.
            postfix (str): Suffix appended to every loss name.
            **kwargs (Any): Optionally `match_indices`, precomputed matches that bypass the Hungarian matcher.

        Returns:
            (dict): Loss tensors keyed by name: "loss_class", "loss_bbox", "loss_giou" (each with `postfix`), plus
                "loss_class_aux", "loss_bbox_aux", "loss_giou_aux" (with `postfix`) when `aux_loss` is True.
        """
        self.device = pred_bboxes.device
        match_indices = kwargs.get("match_indices", None)
        gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]

        # The last layer holds the final predictions.
        total_loss = self._get_loss(
            pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
        )

        if self.aux_loss:
            total_loss.update(
                self._get_loss_aux(
                    pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
                )
            )

        return total_loss
|
|
|
|
class RTDETRDetectionLoss(DETRLoss):
    """
    Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends the DETRLoss.

    This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
    an additional denoising training loss when provided with denoising metadata.
    """

    def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
        """
        Forward pass to compute the detection loss.

        Args:
            preds (tuple): Predicted bounding boxes and scores.
            batch (dict): Batch data containing ground truth information.
            dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None.
            dn_scores (torch.Tensor, optional): Denoising scores. Default is None.
            dn_meta (dict, optional): Metadata for denoising ("dn_pos_idx" and "dn_num_group" are read).
                Default is None.

        Returns:
            (dict): Dictionary containing the detection losses and the denoising losses (keys suffixed "_dn").
                When `dn_meta` is None, every "_dn" entry is a zero tensor.
        """
        pred_bboxes, pred_scores = preds
        total_loss = super().forward(pred_bboxes, pred_scores, batch)

        if dn_meta is not None:
            dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
            assert len(batch["gt_groups"]) == len(dn_pos_idx), (
                f"Expected one entry of dn_pos_idx per image, "
                f"but got {len(dn_pos_idx)} and {len(batch['gt_groups'])} images."
            )

            # Matching for the denoising queries comes from dn_meta instead of the Hungarian matcher.
            match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])

            dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
            total_loss.update(dn_loss)
        else:
            # No denoising: report a zero for every "_dn" counterpart (presumably so the set of loss keys is stable).
            total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})

        return total_loss

    @staticmethod
    def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
        """
        Get the match indices for denoising.

        Args:
            dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
            dn_num_group (int): Number of denoising groups.
            gt_groups (List[int]): List of integers representing the number of ground truths for each image.

        Returns:
            (List[tuple]): Per image, a (dn_pos_idx[i], gt_idx) pair, where gt_idx holds batch-wide ground-truth
                indices (offset by the ground-truth counts of the preceding images) repeated `dn_num_group` times.
                Images without ground truths get a pair of empty long tensors.

        Raises:
            AssertionError: If dn_pos_idx[i] and the repeated ground-truth indices differ in length.
        """
        dn_match_indices = []
        # Offset of each image's first ground truth in the concatenated (batch-wide) ground-truth tensors.
        idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
        for i, num_gt in enumerate(gt_groups):
            if num_gt > 0:
                gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
                gt_idx = gt_idx.repeat(dn_num_group)
                # Parenthesized so that both string parts form the assert message.
                assert len(dn_pos_idx[i]) == len(gt_idx), (
                    f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
                )
                dn_match_indices.append((dn_pos_idx[i], gt_idx))
            else:
                dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
        return dn_match_indices
|
|