import copy

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BCELoss

from utils import box_ops

INF = 100000000


class ObjectNormalizedL2Loss(nn.Module):
    """Squared L2 loss on a density map, normalized by the number of objects."""

    def __init__(self):
        super().__init__()

    def forward(self, output, dmap, num_objects):
        return ((output - dmap) ** 2).sum() / num_objects

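
# Hedged usage sketch (illustrative only, not part of the training pipeline):
# the loss is a plain squared error divided by the object count, so the
# penalty per map shrinks as the number of annotated objects grows.
def _object_normalized_l2_example():
    criterion = ObjectNormalizedL2Loss()
    output = torch.rand(2, 1, 64, 64)  # predicted density maps
    dmap = torch.rand(2, 1, 64, 64)    # ground-truth density maps
    return criterion(output, dmap, num_objects=10)
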

class Detection_criterion(nn.Module):
    """FCOS-style detection criterion: assigns targets to FPN locations and
    applies an IoU-based regression loss at positive locations."""

    def __init__(
        self, sizes, iou_loss_type, center_sample, fpn_strides, pos_radius, aux=False
    ):
        super().__init__()

        self.sizes = sizes
        self.box_loss = IOULoss(iou_loss_type)
        self.aux = aux
        self.center_sample = center_sample
        self.strides = fpn_strides
        self.radius = pos_radius

    def prepare_target(self, points, targets):
        ex_size_of_interest = []

        for i, point_per_level in enumerate(points):
            size_of_interest_per_level = point_per_level.new_tensor(self.sizes[i])
            ex_size_of_interest.append(
                size_of_interest_per_level[None].expand(len(point_per_level), -1)
            )

        ex_size_of_interest = torch.cat(ex_size_of_interest, 0)
        n_point_per_level = [len(point_per_level) for point_per_level in points]
        point_all = torch.cat(points, dim=0)
        label, box_target = self.compute_target_for_location(
            point_all, targets, ex_size_of_interest, n_point_per_level
        )

        for i in range(len(label)):
            label[i] = torch.split(label[i], n_point_per_level, 0)
            box_target[i] = torch.split(box_target[i], n_point_per_level, 0)

        label_level_first = []
        box_target_level_first = []

        for level in range(len(points)):
            label_level_first.append(
                torch.cat([label_per_img[level] for label_per_img in label], 0).to(
                    points[0].device
                )
            )
            box_target_level_first.append(
                torch.cat(
                    [box_target_per_img[level] for box_target_per_img in box_target], 0
                )
            )

        return label_level_first, box_target_level_first

    def get_sample_region(self, gt, strides, n_point_per_level, xs, ys, radius=1):
        n_gt = gt.shape[0]
        n_loc = len(xs)
        gt = gt[None].expand(n_loc, n_gt, 4)
        center_x = (gt[..., 0] + gt[..., 2]) / 2
        center_y = (gt[..., 1] + gt[..., 3]) / 2

        if center_x[..., 0].sum() == 0:
            return xs.new_zeros(xs.shape, dtype=torch.uint8)

        begin = 0
        center_gt = gt.new_zeros(gt.shape)

        for level, n_p in enumerate(n_point_per_level):
            end = begin + n_p
            stride = strides[level] * radius
            x_min = center_x[begin:end] - stride
            y_min = center_y[begin:end] - stride
            x_max = center_x[begin:end] + stride
            y_max = center_y[begin:end] + stride

            # Clamp the sampling region so it never exceeds the GT box.
            center_gt[begin:end, :, 0] = torch.where(
                x_min > gt[begin:end, :, 0], x_min, gt[begin:end, :, 0]
            )
            center_gt[begin:end, :, 1] = torch.where(
                y_min > gt[begin:end, :, 1], y_min, gt[begin:end, :, 1]
            )
            center_gt[begin:end, :, 2] = torch.where(
                x_max > gt[begin:end, :, 2], gt[begin:end, :, 2], x_max
            )
            center_gt[begin:end, :, 3] = torch.where(
                y_max > gt[begin:end, :, 3], gt[begin:end, :, 3], y_max
            )

            begin = end

        left = xs[:, None] - center_gt[..., 0]
        right = center_gt[..., 2] - xs[:, None]
        top = ys[:, None] - center_gt[..., 1]
        bottom = center_gt[..., 3] - ys[:, None]
        center_bbox = torch.stack((left, top, right, bottom), -1)
        is_in_boxes = center_bbox.min(-1)[0] > 0

        return is_in_boxes

    def compute_target_for_location(
        self, locations, targets, sizes_of_interest, n_point_per_level
    ):
        labels = []
        box_targets = []
        xs, ys = locations[:, 0], locations[:, 1]

        for i in range(len(targets)):
            targets_per_img = targets[i]
            targets_per_img = targets_per_img.clip(remove_empty=True)
            assert targets_per_img.mode == 'xyxy'
            targets_per_img = targets_per_img[:50]  # cap at 50 boxes per image
            bboxes = targets_per_img.box
            labels_per_img = torch.ones(
                len(bboxes), dtype=torch.long, device=locations.device
            )
            area = targets_per_img.area()

            # Distances from each location to the four sides of each box.
            l = xs[:, None] - bboxes[:, 0][None]
            t = ys[:, None] - bboxes[:, 1][None]
            r = bboxes[:, 2][None] - xs[:, None]
            b = bboxes[:, 3][None] - ys[:, None]
            box_targets_per_img = torch.stack([l, t, r, b], 2)

            if self.center_sample:
                is_in_boxes = self.get_sample_region(
                    bboxes, self.strides, n_point_per_level, xs, ys, radius=self.radius
                )
            else:
                is_in_boxes = box_targets_per_img.min(2)[0] > 0

            max_box_targets_per_img = box_targets_per_img.max(2)[0]
            is_cared_in_level = (
                max_box_targets_per_img >= sizes_of_interest[:, [0]]
            ) & (max_box_targets_per_img <= sizes_of_interest[:, [1]])

            # Assign each location to the smallest valid GT box.
            locations_to_gt_area = area[None].repeat(len(locations), 1)
            locations_to_gt_area[is_in_boxes == 0] = INF
            locations_to_gt_area[is_cared_in_level == 0] = INF

            locations_to_min_area, locations_to_gt_id = locations_to_gt_area.min(1)

            box_targets_per_img = box_targets_per_img[
                range(len(locations)), locations_to_gt_id
            ]
            labels_per_img = labels_per_img.to(locations_to_gt_id.device)[
                locations_to_gt_id
            ]
            labels_per_img[locations_to_min_area == INF] = 0

            labels.append(labels_per_img)
            box_targets.append(box_targets_per_img)

        return labels, box_targets

    def compute_centerness_targets(self, box_targets):
        left_right = box_targets[:, [0, 2]]
        top_bottom = box_targets[:, [1, 3]]
        centerness = (left_right.min(-1)[0] / left_right.max(-1)[0]) * (
            top_bottom.min(-1)[0] / top_bottom.max(-1)[0]
        )

        return torch.sqrt(centerness)

    def forward(self, locations, box_pred, targets):
        labels, box_targets = self.prepare_target(locations, targets)

        box_flat = []
        labels_flat = []
        box_targets_flat = []

        for i in range(len(labels)):
            box_flat.append(box_pred[i].permute(0, 2, 3, 1).reshape(-1, 4))
            labels_flat.append(labels[i].reshape(-1))
            box_targets_flat.append(box_targets[i].reshape(-1, 4))

        box_flat = torch.cat(box_flat, 0)
        labels_flat = torch.cat(labels_flat, 0)
        box_targets_flat = torch.cat(box_targets_flat, 0)

        pos_id = torch.nonzero(labels_flat > 0).squeeze(1)

        box_flat = box_flat[pos_id]
        box_targets_flat = box_targets_flat[pos_id]

        if pos_id.numel() > 0:
            center_targets = self.compute_centerness_targets(box_targets_flat)
            box_loss = self.box_loss(box_flat, box_targets_flat, center_targets)
        else:
            box_loss = box_flat.sum()

        return box_loss


class IOULoss(nn.Module):
    def __init__(self, loc_loss_type):
        super().__init__()

        self.loc_loss_type = loc_loss_type

    def forward(self, out, target, weight=None):
        pred_left, pred_top, pred_right, pred_bottom = out.unbind(1)
        target_left, target_top, target_right, target_bottom = target.unbind(1)

        target_area = (target_left + target_right) * (target_top + target_bottom)
        pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)

        w_intersect = torch.min(pred_left, target_left) + torch.min(
            pred_right, target_right
        )
        h_intersect = torch.min(pred_bottom, target_bottom) + torch.min(
            pred_top, target_top
        )

        area_intersect = w_intersect * h_intersect
        area_union = target_area + pred_area - area_intersect
        ious = (area_intersect + 1) / (area_union + 1)

        if self.loc_loss_type == 'iou':
            loss = -torch.log(ious)
        elif self.loc_loss_type == 'giou':
            g_w_intersect = torch.max(pred_left, target_left) + torch.max(
                pred_right, target_right
            )
            g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max(
                pred_top, target_top
            )
            g_intersect = g_w_intersect * g_h_intersect + 1e-7
            gious = ious - (g_intersect - area_union) / g_intersect
            loss = 1 - gious
        else:
            raise ValueError(f'unknown loc_loss_type: {self.loc_loss_type}')

        if weight is not None and weight.sum() > 0:
            return (loss * weight).sum() / weight.sum()
        else:
            return loss.mean()

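
# Sanity-check sketch (assumption: inputs are per-location (l, t, r, b)
# distances to the box sides, as produced by Detection_criterion, not corner
# coordinates). A prediction identical to its target should incur ~zero
# GIoU loss.
def _iou_loss_sanity_check():
    criterion = IOULoss('giou')
    target = torch.rand(8, 4) * 10 + 1  # positive l/t/r/b distances
    pred = target.clone()               # perfect predictions
    weight = torch.ones(8)              # e.g. centerness targets
    loss = criterion(pred, target, weight)
    assert loss.item() < 1e-4
    return loss
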

class SetCriterion(nn.Module):
    """This class computes the loss for DETR.

    The process happens in two steps:
        1) we compute a Hungarian assignment between the ground-truth boxes
           and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction
           (supervise class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight
            losses: list of all the losses to be applied. See get_loss for the list of available losses
            focal_alpha: alpha in Focal Loss
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha
        self.cross_entropy = BCELoss()

    def loss_boxes(self, outputs, targets, indices, num_boxes, centerness, centerness_gt, mask):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.

        The targets dicts must contain the key "boxes" with a tensor of dim
        [nb_target_boxes, 4]. The boxes are passed to generalized_box_iou
        without conversion, so they are expected in (x0, y0, x1, y1) format,
        normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(src_boxes, target_boxes))
        losses['loss_giou'] = loss_giou.sum() / num_boxes

        return losses

    def ce_loss(self, outputs, targets, indices, num_boxes, centerness, centerness_gt, mask):
        """Squared error on the centerness map, restricted to masked locations."""
        l2 = (centerness[mask > 0] - centerness_gt[mask > 0]) ** 2
        losses = {}
        losses['loss_ce'] = l2.sum() / num_boxes

        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, centerness, centerness_gt, mask, **kwargs):
        loss_map = {
            'bboxes': self.loss_boxes,
            'ce': self.ce_loss,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, centerness, centerness_gt, mask, **kwargs)
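
    # Illustration (toy values, assumed for clarity): for a batch of two
    # images where the matcher returns
    #     indices = [(tensor([0, 2]), tensor([1, 0])), (tensor([1]), tensor([0]))]
    # _get_src_permutation_idx yields
    #     batch_idx = tensor([0, 0, 1]), src_idx = tensor([0, 2, 1])
    # so that outputs['pred_boxes'][batch_idx, src_idx] gathers the three
    # matched predictions into one flat tensor.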

    def generate_centerness_gt(self, indices, FN_idx, FP_idx, outputs, targets, centerness, ref_points):
        """Build the centerness supervision map and its validity mask.

        TP reference points and FN box centers are labelled 1; FP reference
        points are labelled 0 but included in the mask so they are penalized.
        """
        FN_bboxes = targets[0]['boxes'][FN_idx] * centerness.shape[1]
        centerness_gt = torch.zeros_like(centerness)
        mask = torch.zeros_like(centerness)

        # FP -> non-matched PRED boxes get 0 at their reference point, so only 1 in the mask
        FP_locs = ref_points.permute(1, 0)[FP_idx]
        mask[0][FP_locs[:, 0], FP_locs[:, 1]] = 1

        # FN -> non-matched GT boxes get 1 at the center of the box
        if len(FN_bboxes) > 0:
            FN_y_loc = torch.clamp(((FN_bboxes[:, 3] + FN_bboxes[:, 1]) / 2).int(), min=0, max=centerness.shape[1] - 1)
            FN_x_loc = torch.clamp(((FN_bboxes[:, 2] + FN_bboxes[:, 0]) / 2).int(), min=0, max=centerness.shape[1] - 1)
            centerness_gt[0][FN_y_loc, FN_x_loc] = 1
            mask[0][FN_y_loc, FN_x_loc] = 1

        # TP -> matched PRED boxes get 1 at their reference point
        TP_locs = ref_points.permute(1, 0)[indices[0][0]]
        centerness_gt[0][TP_locs[:, 0], TP_locs[:, 1]] = 1
        mask[0][TP_locs[:, 0], TP_locs[:, 1]] = 1

        # Fallback: if collisions left fewer positives than GT boxes, place a
        # positive at every GT box center and supervise the whole map.
        if centerness_gt.sum() < targets[0]['boxes'].shape[0]:
            centerness_gt = torch.zeros_like(centerness)
            FN_bboxes = targets[0]['boxes'] * centerness.shape[1]
            FN_y_loc = torch.clamp(((FN_bboxes[:, 3] + FN_bboxes[:, 1]) / 2).int(), min=0, max=centerness.shape[1] - 1)
            FN_x_loc = torch.clamp(((FN_bboxes[:, 2] + FN_bboxes[:, 0]) / 2).int(), min=0, max=centerness.shape[1] - 1)
            centerness_gt[0][FN_y_loc, FN_x_loc] = 1
            mask = torch.ones_like(centerness)

        return centerness_gt, mask
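
    # Worked illustration (toy values assumed for clarity): on a 4x4 map,
    # with one matched prediction whose reference point is (row=1, col=2),
    # one unmatched GT box centred at (row=3, col=0), and one false-positive
    # prediction at (row=0, col=3), the supervision becomes
    #     centerness_gt[0] = [[0,0,0,0],      mask[0] = [[0,0,0,1],
    #                         [0,0,1,0],                 [0,0,1,0],
    #                         [0,0,0,0],                 [0,0,0,0],
    #                         [1,0,0,0]]                 [1,0,0,0]]
    # i.e. the FP location is supervised toward 0, the TP and FN locations
    # toward 1, and every unmasked location is ignored.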

    def forward(self, outputs, targets, centerness, ref_points):
        """This performs the loss computation.

        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        # Retrieve the matching between the outputs of the last layer and the targets
        indices, FN_idx, FP_idx = self.matcher(outputs, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        num_boxes = torch.clamp(num_boxes, min=1).item()

        centerness_gt, mask = self.generate_centerness_gt(indices, FN_idx, FP_idx, outputs, targets, centerness, ref_points)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            kwargs = {}
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, centerness, centerness_gt, mask, **kwargs))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices, _, _ = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs['log'] = False
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, centerness, centerness_gt, mask, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)
        if 'enc_outputs' in outputs:
            enc_outputs = outputs['enc_outputs']
            bin_targets = copy.deepcopy(targets)
            for bt in bin_targets:
                bt['labels'] = torch.zeros_like(bt['labels'])
            indices, _, _ = self.matcher(enc_outputs, bin_targets)
            for loss in self.losses:
                if loss == 'masks':
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                kwargs = {}
                if loss == 'labels':
                    # Logging is enabled only for the last layer
                    kwargs['log'] = False
                l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, centerness, centerness_gt, mask, **kwargs)
                l_dict = {k + '_enc': v for k, v in l_dict.items()}
                losses.update(l_dict)

        return losses
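

if __name__ == '__main__':
    # Minimal smoke test (illustrative only): exercises loss_boxes with a
    # hand-written perfect matching, so matcher=None is never called. Assumes
    # utils.box_ops.generalized_box_iou follows the DETR convention of taking
    # two (N, 4) xyxy tensors and returning an (N, N) matrix.
    torch.manual_seed(0)
    criterion = SetCriterion(
        num_classes=1,
        matcher=None,
        weight_dict={'loss_bbox': 5, 'loss_giou': 2},
        losses=['bboxes'],
    )
    boxes = torch.tensor([[0.1, 0.1, 0.3, 0.3], [0.5, 0.5, 0.8, 0.9]])
    outputs = {'pred_boxes': boxes[None]}  # batch of one image
    targets = [{'boxes': boxes, 'labels': torch.zeros(2, dtype=torch.long)}]
    indices = [(torch.tensor([0, 1]), torch.tensor([0, 1]))]
    print(criterion.loss_boxes(outputs, targets, indices, num_boxes=2,
                               centerness=None, centerness_gt=None, mask=None))
    # expected: loss_bbox == 0 and loss_giou == 0 for a perfect match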