# Building Hungarian Matcher
# Borrow code from AnchorDETR
# We replace bounding box matching with point location matching
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn

from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou, box_iou


class HungarianMatcher(nn.Module):
    """Hungarian (optimal 1-to-1) matcher between predicted points and target points.

    The matching cost combines a focal-style classification cost and an L1
    distance between predicted and target point coordinates.
    """

    def __init__(self, cost_class: float = 1.0, cost_points: float = 1.0):
        """Create the matcher.

        Params:
            cost_class: relative weight of the classification cost
            cost_points: relative weight of the L1 point-distance cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_points = cost_points

    def forward(self, outputs, targets):
        """Matching pipeline.

        Args:
            outputs (dict): contains at least two entries:
                pred_logits: [batch_size, num_queries, num_classes] classification logits
                pred_points: [batch_size, num_queries, 2] predicted points
            targets (list): len(targets) == batch_size; each target is a dict with
                labels: tensor of dim [num_target_points] containing the class label
                points: tensor of dim [num_target_points, 2] target point coordinates

        Returns:
            A list of size batch_size, containing tuples (index_i, index_j) where:
                - index_i: indices of the selected predictions (in order)
                - index_j: indices of the corresponding selected targets
        """
        with torch.no_grad():
            bs, num_queries = outputs["pred_logits"].shape[:2]

            # Flatten batch and query dims to compute one big cost matrix.
            out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
            out_points = outputs["pred_points"].flatten(0, 1)  # [bs * num_queries, 2]

            # Concat target labels and points across the batch.
            tgt_ids = torch.cat([v["labels"] for v in targets])
            tgt_points = torch.cat([v["points"] for v in targets])  # [total_targets, 2]

            # Focal-loss-style classification cost (alpha/gamma as in RetinaNet).
            alpha = 0.25
            gamma = 2.0
            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]  # [bs*num_queries, total_targets]

            # L1 distance between predicted and target points.
            cost_points = torch.cdist(out_points, tgt_points, p=1)

            # Weighted total cost; solved per batch element on CPU.
            C = self.cost_class * cost_class + self.cost_points * cost_points
            C = C.view(bs, num_queries, -1).cpu()

            sizes = [len(v["points"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return [
                (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
                for i, j in indices
            ]


class PointsDistance(nn.Module):
    """Point-set distance between predicted and target points.

    Accepts two distance types: "emd" (earth mover's distance via Hungarian
    matching) and "chamfer" (nearest-neighbour Chamfer distance).
    """

    def __init__(self, dist_type):
        super().__init__()
        self.dist_type = dist_type

    def _get_src_permutation_idx(self, indices):
        # Turn per-image (src, tgt) index pairs into a (batch_idx, src_idx)
        # pair usable for advanced indexing into [bs, num_queries, ...] tensors.
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def em_distance(self, outputs, targets):
        """Mean L2 distance under the optimal 1-to-1 (Hungarian) assignment.

        Returns:
            (mean_distance, indices) — indices in the same format as the matchers.
        """
        with torch.no_grad():
            bs, num_queries = outputs["pred_points"].shape[:2]
            out_points = outputs["pred_points"].flatten(0, 1)  # [bs * num_queries, 2]
            tgt_points = torch.cat([v["points"] for v in targets])  # [total_targets, 2]

            # Pairwise Euclidean distances: [bs*num_queries, total_targets].
            C = torch.norm(out_points[:, None, :] - tgt_points[None, :, :], p=2, dim=-1)
            C = C.view(bs, num_queries, -1).cpu()

            sizes = [len(v["points"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            indices = [
                (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
                for i, j in indices
            ]

            # Gather the matched pairs and average their distances.
            idx = self._get_src_permutation_idx(indices)
            src_points = outputs["pred_points"][idx]
            tgt_points = torch.cat([t["points"][i] for t, (_, i) in zip(targets, indices)])
            dists = torch.norm(src_points - tgt_points, p=2, dim=-1)
            return torch.mean(dists), indices

    def chamfer_distance(self, outputs, targets):
        """Symmetric Chamfer distance between predicted and target point sets.

        NOTE(review): the advanced indexing below pairs torch.arange(bs) with a
        flattened index of length bs * N, which only broadcasts when bs == 1 —
        looks batch-size-1 only; confirm against callers.
        """
        with torch.no_grad():
            bs, num_queries = outputs["pred_points"].shape[:2]
            out_points = outputs["pred_points"].flatten(0, 1)  # [bs * num_queries, 2]
            tgt_points = torch.cat([v["points"] for v in targets])  # [total_targets, 2]

            # Pairwise Euclidean distances: [bs, num_queries, num_targets].
            C = torch.norm(out_points[:, None, :] - tgt_points[None, :, :], p=2, dim=-1)
            C = C.view(bs, num_queries, -1)

            indices_src = torch.argmin(C, dim=1)  # nearest prediction for each target
            indices_tgt = torch.argmin(C, dim=2)  # nearest target for each prediction

            src_points = outputs["pred_points"]
            tgt_points = torch.stack([v["points"] for v in targets])
            matched_src = tgt_points[torch.arange(indices_tgt.shape[0]), torch.reshape(indices_tgt, [-1])]
            matched_tgt = src_points[torch.arange(indices_src.shape[0]), torch.reshape(indices_src, [-1])]

            src_points = src_points.flatten(0, 1)
            tgt_points = tgt_points.flatten(0, 1)
            chamfer_dist = torch.mean(torch.norm(src_points - matched_src, p=2, dim=-1)) + torch.mean(
                torch.norm(matched_tgt - tgt_points, p=2, dim=-1)
            )
            return chamfer_dist, indices_src

    def forward(self, outputs, targets):
        if self.dist_type == "emd":
            return self.em_distance(outputs, targets)
        elif self.dist_type == "chamfer":
            return self.chamfer_distance(outputs, targets)
        else:
            raise NotImplementedError("not support other distance")


class ChamferDistanceMatching(nn.Module):
    """Nearest-neighbour (Chamfer-style) matching between predicted and target boxes.

    Cost is a weighted sum of an L2 distance between box centers and a
    negative generalized IoU; matching is argmin in each direction rather
    than a 1-to-1 Hungarian assignment.
    """

    def __init__(self, point_cost, giou_cost):
        super().__init__()
        self.point_cost = point_cost
        self.giou_cost = giou_cost

    def forward(self, outputs, targets):
        """Match by per-row / per-column argmin of the cost matrix.

        Expected parameters:
            outputs (dict):
                pred_boxes: [batch_size, num_queries, 4] in [cx, cy, w, h]
                            (first two entries are treated as the anchor point)
            targets (list of dict):
                boxes: [num_targets, 4] target boxes in [cx, cy, w, h]

        Returns:
            (indices_src, indices_tgt): for each target the nearest prediction,
            and for each prediction the nearest target.
        """
        with torch.no_grad():
            bs, num_queries = outputs["pred_boxes"].shape[:2]
            out_boxes = outputs["pred_boxes"].flatten(0, 1)  # [bs * num_queries, 4]
            tgt_boxes = torch.cat([v["boxes"] for v in targets])  # [total_targets, 4]

            # Center-point distance (first two coords) + negative GIoU.
            cost_points = torch.cdist(out_boxes[..., :2], tgt_boxes[..., :2])
            cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_boxes), box_cxcywh_to_xyxy(tgt_boxes))

            C = self.point_cost * cost_points + self.giou_cost * cost_giou
            C = C.view(bs, num_queries, -1).cpu()

            indices_src = torch.argmin(C, dim=1)
            indices_tgt = torch.argmin(C, dim=2)
            return indices_src, indices_tgt


def match_points_to_boxes(ref_points, param):
    """Split reference points into those inside / outside any ground-truth box.

    Args:
        ref_points: [2, num_points] point coordinates (row 0: x, row 1: y)
        param: [num_boxes, 4] boxes; columns are compared as
               (param[:, 0] <= y <= param[:, 2]) and (param[:, 1] <= x <= param[:, 3])

    Returns:
        (mask_points_in, mask_points_out): boolean masks of length num_points;
        a point is "in" when it falls inside at least one box.
    """
    ref_points = ref_points.type(torch.float32)
    param = param.type(torch.float32)
    # [num_boxes, num_points] membership matrix via broadcasting.
    points_in_boxes = torch.logical_and(
        torch.logical_and(
            ref_points[1] >= param[:, 0].unsqueeze(1), ref_points[1] <= param[:, 2].unsqueeze(1)
        ),
        torch.logical_and(
            ref_points[0] >= param[:, 1].unsqueeze(1), ref_points[0] <= param[:, 3].unsqueeze(1)
        ),
    )
    mask_points_in = points_in_boxes.sum(dim=0) > 0
    mask_points_out = torch.logical_not(mask_points_in)
    return mask_points_in, mask_points_out


class PointLossHungarianMatcher(nn.Module):
    """Hungarian matcher over boxes that also reports unmatched GT / prediction indices.

    For efficiency reasons, the targets don't include the no_object class.
    Because of this, in general, there are more predictions than targets;
    the best predictions are matched 1-to-1 and the rest are unmatched.
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """Creates the matcher.

        Params:
            cost_class: relative weight of the classification error in the matching cost
            cost_bbox: relative weight of the L1 error of the box coordinates
            cost_giou: relative weight of the giou loss of the bounding box
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

    def forward(self, outputs, targets, ref_points=None):
        """Performs the matching.

        Params:
            outputs: dict containing at least:
                "box_v": [batch_size, num_queries, num_classes] classification logits
                "pred_boxes": [batch_size, num_queries, 4] predicted box coordinates
            targets: list (len == batch_size) of dicts containing:
                "labels": [num_target_boxes] class labels
                "boxes": [num_target_boxes, 4] target box coordinates
            ref_points: unused; kept for interface compatibility.

        Returns:
            (match_indexes, non_matched_gt_bbox_idx, non_matched_pred_bbox_idx):
            a single-element list of (pred_idx, tgt_idx) tensors, the GT indices
            left unmatched (or with zero IoU against every prediction), and the
            prediction indices left unmatched.

        NOTE(review): only indices[0] is consumed below, so this effectively
        assumes batch_size == 1 — confirm against callers.
        """
        with torch.no_grad():
            bs, num_queries = outputs["box_v"].shape[:2]

            # Flatten to compute the cost matrices in a batch.
            # NOTE: out_prob / tgt_ids are computed but the classification cost
            # is not part of C below — only bbox L1 and GIoU are used.
            out_prob = outputs["box_v"].flatten(0, 1).sigmoid()
            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [bs * num_queries, 4]

            tgt_ids = torch.cat([v["labels"] for v in targets])
            tgt_bbox = torch.cat([v["boxes"] for v in targets])

            # L1 cost between boxes.
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

            # GIoU cost between boxes; plain IoU is kept to detect GT boxes
            # that no prediction overlaps at all.
            iou, unions = box_iou(out_bbox, tgt_bbox)
            cost_giou = -generalized_box_iou(out_bbox, tgt_bbox)

            # Final cost matrix, solved per batch element on CPU.
            C = self.cost_bbox * cost_bbox + self.cost_giou * cost_giou
            C = C.view(bs, num_queries, -1).cpu()

            sizes = [len(v["boxes"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]

            # GT boxes that received no assignment, plus GT boxes with zero IoU
            # against every prediction (their match is meaningless).
            non_matched_gt_bbox_idx = \
                np.nonzero(np.logical_not(np.isin(np.arange(tgt_bbox.shape[0]), indices[0][1])))[0]
            non_matched_gt_bbox_idx = np.concatenate(
                (non_matched_gt_bbox_idx, torch.where(iou.max(dim=0)[0] == 0)[0].cpu().numpy()))
            non_matched_gt_bbox_idx = [torch.tensor(non_matched_gt_bbox_idx, dtype=torch.int64).unique()]

            # Drop assignments whose GT ended up in the unmatched set.
            remove_mask = np.logical_not(np.isin(indices[0][1], non_matched_gt_bbox_idx[0].cpu()))
            ind0 = indices[0][0][remove_mask]
            ind1 = indices[0][1][remove_mask]

            # Predictions that received no assignment at all.
            non_matched_pred_bbox_idx = \
                np.nonzero(np.logical_not(np.isin(np.arange(out_bbox.shape[0]), indices[0][0])))[0]

            match_indexes = [(torch.as_tensor(ind0, dtype=torch.int64),
                              torch.as_tensor(ind1, dtype=torch.int64))]
            return match_indexes, non_matched_gt_bbox_idx, non_matched_pred_bbox_idx


def build_matcher(args):
    return PointLossHungarianMatcher(args.cost_class, args.cost_bbox, args.cost_giou)


def build_chamfer_matcher(args):
    return ChamferDistanceMatching(args.chamfer_point_cost, args.chamfer_giou_cost)


class PointHungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of
    this, in general, there are more predictions than targets. In this case, we
    do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_point: float = 1):
        """Creates the matcher.

        Params:
            cost_point: relative weight of the L1 error of the point in the matching cost
        """
        super().__init__()
        self.cost_point = cost_point
        assert cost_point != 0, "all costs cant be 0"

    def forward(self, outputs, targets):
        """Performs the matching.

        Params:
            outputs: dict containing at least:
                "pred_logits": [batch_size, num_queries, num_classes] classification logits
                "pred_points": [batch_size, num_queries, 2] predicted points
            targets: list (len == batch_size) of dicts containing:
                "points": [num_target_points, 2] target point coordinates
                "boxes": used only for its length when splitting the cost matrix

        Returns:
            A list of size batch_size, containing tuples (index_i, index_j) where:
                - index_i: indices of the selected predictions (in order)
                - index_j: indices of the corresponding selected targets (in order)
            For each batch element: len(index_i) == len(index_j)
                                    == min(num_queries, num_target_points)
        """
        with torch.no_grad():
            bs, num_queries = outputs["pred_logits"].shape[:2]

            # Flatten to compute the cost matrices in a batch.
            out_point = outputs["pred_points"].flatten(0, 1)  # [bs * num_queries, 2]
            tgt_point = torch.cat([v["points"] for v in targets])

            # L1 cost between points; the only cost term for this matcher.
            cost_point = torch.cdist(out_point, tgt_point, p=1)

            C = self.cost_point * cost_point
            C = C.view(bs, num_queries, -1).cpu()

            # NOTE(review): sizes are taken from "boxes" while the cost uses
            # "points" — presumably each target has one point per box; confirm.
            sizes = [len(v["boxes"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return [
                (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
                for i, j in indices
            ]


def build_centerness_matcher(args):
    return PointHungarianMatcher(cost_point=args.set_cost_points)