# NOTE: removed non-code page residue ("Spaces: Running on Zero") left over
# from a web extraction of this file.
| # Building Hungarian Matcher | |
| # Borrow code from AnchorDETR | |
| # We replace bounding box matching with point location matching | |
| import numpy as np | |
| import torch | |
| from scipy.optimize import linear_sum_assignment | |
| from torch import nn | |
| from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou, box_iou | |
class HungarianMatcher(nn.Module):
    """Hungarian matcher that pairs predicted points with ground-truth points.

    The assignment minimises a weighted sum of a focal-style classification
    cost and an L1 point-distance cost, solved per image with
    ``scipy.optimize.linear_sum_assignment``.
    """

    def __init__(self, cost_class: float = 1.0, cost_points: float = 1.0):
        """Create the matcher.

        Params:
            cost_class: relative weight of the classification cost
            cost_points: relative weight of the L1 point-distance cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_points = cost_points

    def forward(self, outputs, targets):
        """Matching pipeline.

        Args:
            outputs (dict): contains at least:
                pred_logits: [batch_size, num_queries, num_classes] classification logits
                pred_points: [batch_size, num_queries, 2] predicted points
            targets: list of dicts (len == batch_size), each containing:
                labels: [num_targets] class labels
                points: [num_targets, 2] target point coordinates

        Returns:
            A list of size batch_size of tuples (index_i, index_j) where:
                - index_i: indices of the selected predictions (in order)
                - index_j: indices of the corresponding selected targets
        """
        with torch.no_grad():
            batch, queries = outputs["pred_logits"].shape[:2]

            # Flatten predictions so one cost matrix covers the whole batch.
            prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
            pts = outputs["pred_points"].flatten(0, 1)  # [batch * queries, 2]

            # Concatenate target labels and points across the batch.
            gt_labels = torch.cat([t["labels"] for t in targets])
            gt_points = torch.cat([t["points"] for t in targets])

            # Focal-loss style classification cost, indexed per target class.
            alpha, gamma = 0.25, 2.0
            neg = (1 - alpha) * prob.pow(gamma) * (-(1 - prob + 1e-8).log())
            pos = alpha * (1 - prob).pow(gamma) * (-(prob + 1e-8).log())
            cls_cost = pos[:, gt_labels] - neg[:, gt_labels]

            # L1 distance between every prediction and every target point.
            pt_cost = torch.cdist(pts, gt_points, p=1)

            # Weighted total cost, reshaped back to per-image slices.
            total = self.cost_class * cls_cost + self.cost_points * pt_cost
            total = total.view(batch, queries, -1).cpu()

            counts = [len(t["points"]) for t in targets]
            matches = [
                linear_sum_assignment(block[i])
                for i, block in enumerate(total.split(counts, -1))
            ]
            return [
                (torch.as_tensor(r, dtype=torch.int64), torch.as_tensor(c, dtype=torch.int64))
                for r, c in matches
            ]
class PointsDistance(nn.Module):
    """Point-set distances (earth mover's / Chamfer) between predicted and
    target point sets, returned together with the matching indices used to
    compute them.
    """

    def __init__(self, dist_type):
        """
        Accept two distance types: "emd" (earth mover's distance, solved with
        the Hungarian algorithm) and "chamfer".
        """
        super().__init__()
        self.dist_type = dist_type

    def _get_src_permutation_idx(self, indices):
        # Convert per-image (src, tgt) index pairs into a flat
        # (batch_idx, src_idx) pair for advanced indexing into
        # [bs, num_queries, ...] tensors.
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def em_distance(self, outputs, targets):
        # Earth mover's distance: 1-to-1 match predictions to targets via the
        # Hungarian algorithm on pairwise L2 costs, then average the matched
        # distances. Returns (mean_distance, indices).
        with torch.no_grad():
            bs, num_queries = outputs["pred_points"].shape[:2]
            out_points = outputs["pred_points"].flatten(0, 1)  # [batch_size * num_queries, 2]
            tgt_points = torch.cat([v["points"] for v in targets])  # [batch_size * num_targets, 2]
            # Pairwise L2 distances between every prediction and every target.
            C = torch.norm(
                out_points[:, None, :] - tgt_points[None, :, :], p=2, dim=-1
            )  # [batch_size*num_queries, batch_size*num_targets]
            C = C.view(bs, num_queries, -1).cpu()
            sizes = [len(v["points"]) for v in targets]
            # Solve the assignment per image on that image's slice of targets.
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            indices = [
                (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
            ]
            idx = self._get_src_permutation_idx(indices)
            src_points = outputs["pred_points"][idx]
            # Reorder targets so they line up with the matched predictions.
            tgt_points = torch.cat([t["points"][i] for t, (_, i) in zip(targets, indices)])
            dists = torch.norm(src_points - tgt_points, p=2, dim=-1)
            return torch.mean(dists), indices

    def chamfer_distance(self, outputs, targets):
        # Chamfer distance: for every prediction take its nearest target and
        # for every target its nearest prediction, then sum the two mean
        # distances. Returns (distance, indices_src).
        with torch.no_grad():
            bs, num_queries = outputs["pred_points"].shape[:2]
            out_points = outputs["pred_points"].flatten(0, 1)  # [batch_size * num_queries, 2]
            tgt_points = torch.cat([v["points"] for v in targets])  # [batch_size * num_targets, 2]
            C = torch.norm(
                out_points[:, None, :] - tgt_points[None, :, :], p=2, dim=-1
            )  # [batch_size * num_queries, batch_size * num_targets]
            C = C.view(bs, num_queries, -1)  # [batch_size, num_queries, num_targets]
            indices_src = torch.argmin(C, dim=1)  # nearest prediction per target
            indices_tgt = torch.argmin(C, dim=2)  # nearest target per prediction
            src_points = outputs["pred_points"]
            # NOTE(review): torch.stack requires every image to contribute the
            # same number of target points — confirm with the data pipeline.
            tgt_points = torch.stack([v["points"] for v in targets])
            # NOTE(review): the advanced indexing below pairs arange(bs) with a
            # flattened [bs * n] index; those shapes only broadcast when
            # bs == 1 — verify before using with larger batch sizes.
            matched_src = tgt_points[torch.arange(indices_tgt.shape[0]), torch.reshape(indices_tgt, [-1])]
            matched_tgt = src_points[torch.arange(indices_src.shape[0]), torch.reshape(indices_src, [-1])]
            src_points = src_points.flatten(0, 1)
            tgt_points = tgt_points.flatten(0, 1)
            # Symmetric Chamfer term: pred->nearest-target plus target->nearest-pred.
            chamfer_dist = torch.mean(torch.norm(src_points - matched_src, p=2, dim=-1)) + torch.mean(
                torch.norm(matched_tgt - tgt_points, p=2, dim=-1)
            )
            return chamfer_dist, indices_src

    def forward(self, outputs, targets):
        """Dispatch to the distance chosen at construction time.

        Returns:
            (distance, indices) for dist_type "emd" or "chamfer";
            raises NotImplementedError for any other value.
        """
        if self.dist_type == "emd":
            return self.em_distance(outputs, targets)
        elif self.dist_type == "chamfer":
            return self.chamfer_distance(outputs, targets)
        else:
            raise NotImplementedError("not support other distance")
class ChamferDistanceMatching(nn.Module):
    """Nearest-neighbour (Chamfer-style) box matching using a weighted sum of
    a centre-distance cost and a GIoU cost.
    """

    def __init__(self, point_cost, giou_cost):
        super().__init__()
        self.point_cost = point_cost
        self.giou_cost = giou_cost

    def forward(self, outputs, targets):
        """Match predicted boxes against targets by minimum cost in each direction.

        Args:
            outputs (dict): contains at least:
                pred_boxes: [batch_size, num_queries, 4] predicted boxes
                    (NOTE(review): the original doc said [l,t,r,b] but the boxes
                    are converted with box_cxcywh_to_xyxy below — confirm format)
            targets: list of dicts, each with:
                boxes: [num_targets, 4] target boxes in [cx, cy, w, h]

        Returns:
            (indices_src, indices_tgt): per-target nearest prediction indices
            and per-prediction nearest target indices (argmin over the cost).
        """
        with torch.no_grad():
            batch, queries = outputs["pred_boxes"].shape[:2]
            pred = outputs["pred_boxes"].flatten(0, 1)  # [batch * queries, 4]
            gt = torch.cat([t["boxes"] for t in targets])  # [total targets, 4]

            # L2 distance between box centres (first two coordinates).
            center_cost = torch.cdist(pred[..., :2], gt[..., :2])
            # Negated GIoU so that better overlap means lower cost.
            giou_cost = -generalized_box_iou(box_cxcywh_to_xyxy(pred), box_cxcywh_to_xyxy(gt))

            cost = self.point_cost * center_cost + self.giou_cost * giou_cost
            cost = cost.view(batch, queries, -1).cpu()

            return torch.argmin(cost, dim=1), torch.argmin(cost, dim=2)
def match_points_to_boxes(ref_points, param):
    """Split reference points into those covered by at least one box and the rest.

    Args:
        ref_points: [2, num_points] point coordinates; row 0 is tested against
            box columns 1/3 and row 1 against columns 0/2 — the axes are
            swapped relative to the box layout (presumably (y, x) points vs
            (x0, y0, x1, y1) boxes; confirm against the caller).
        param: [num_boxes, 4] box extents.

    Returns:
        (mask_points_in, mask_points_out): boolean masks of shape [num_points].
        A point is "in" when it lies inside at least one box.
    """
    pts = ref_points.type(torch.float32)
    boxes = param.type(torch.float32)

    # Per-box lower/upper bounds, shaped [num_boxes, 1] to broadcast over points.
    lo_a, hi_a = boxes[:, 0].unsqueeze(1), boxes[:, 2].unsqueeze(1)
    lo_b, hi_b = boxes[:, 1].unsqueeze(1), boxes[:, 3].unsqueeze(1)

    inside_a = (pts[1] >= lo_a) & (pts[1] <= hi_a)
    inside_b = (pts[0] >= lo_b) & (pts[0] <= hi_b)
    covered = inside_a & inside_b  # [num_boxes, num_points]

    # A point counts as inside if any box covers it.
    mask_points_in = covered.sum(dim=0) > 0
    mask_points_out = torch.logical_not(mask_points_in)
    return mask_points_in, mask_points_out
class PointLossHungarianMatcher(nn.Module):
    """Hungarian matcher over boxes (L1 + GIoU cost) that also reports which
    ground-truth boxes and which predictions were left unmatched.

    NOTE(review): the forward pass only inspects ``indices[0]``, so it
    effectively assumes batch_size == 1 — confirm with the caller.
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """Creates the matcher.

        Params:
            cost_class: relative weight of the classification error
                (currently not part of the cost matrix; kept for interface
                compatibility)
            cost_bbox: relative weight of the L1 error of the box coordinates
            cost_giou: relative weight of the GIoU loss of the boxes
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

    def forward(self, outputs, targets, ref_points=None):
        """Performs the matching.

        Params:
            outputs: dict with at least these entries:
                "box_v": [batch_size, num_queries, num_classes] classification
                    logits (only used here for the batch/query sizes)
                "pred_boxes": [batch_size, num_queries, 4] predicted boxes
            targets: list of dicts (len == batch_size), each containing:
                "labels": [num_target_boxes] class labels
                "boxes": [num_target_boxes, 4] target box coordinates
            ref_points: unused; kept for interface compatibility.

        Returns:
            match_indexes: single-element list of (pred_idx, gt_idx) int64
                tensors, restricted to ground truths that stayed matched
            non_matched_gt_bbox_idx: single-element list with the unique
                indices of ground-truth boxes left unmatched or with zero IoU
                against every prediction
            non_matched_pred_bbox_idx: numpy array of prediction indices not
                assigned to any ground truth
        """
        with torch.no_grad():
            bs, num_queries = outputs["box_v"].shape[:2]

            # Flatten to compute the cost matrices in a batch.
            # (Removed dead code: the original also computed sigmoid class
            # probabilities and concatenated target labels but never used them.)
            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
            tgt_bbox = torch.cat([v["boxes"] for v in targets])

            # L1 cost between box coordinates.
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

            # GIoU cost; plain IoU is reused below to spot ground truths that
            # no prediction overlaps at all.
            # NOTE(review): boxes go into box_iou/generalized_box_iou without a
            # cxcywh->xyxy conversion, unlike ChamferDistanceMatching — verify
            # the box format used on this code path.
            iou, _ = box_iou(out_bbox, tgt_bbox)
            cost_giou = -generalized_box_iou(out_bbox, tgt_bbox)

            # Final cost matrix.
            C = self.cost_bbox * cost_bbox + self.cost_giou * cost_giou
            C = C.view(bs, num_queries, -1).cpu()

            sizes = [len(v["boxes"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]

            # Ground truths that did not receive a prediction in the assignment.
            non_matched_gt_bbox_idx = np.nonzero(
                np.logical_not(np.in1d(np.arange(tgt_bbox.shape[0]), indices[0][1]))
            )[0]
            # Also treat ground truths with zero IoU against every prediction
            # as unmatched.
            non_matched_gt_bbox_idx = np.concatenate(
                (non_matched_gt_bbox_idx, torch.where(iou.max(dim=0)[0] == 0)[0].cpu().numpy())
            )
            non_matched_gt_bbox_idx = [torch.tensor(non_matched_gt_bbox_idx, dtype=torch.int64).unique()]

            # Drop assignment pairs whose ground truth ended up unmatched.
            remove_mask = np.logical_not(np.in1d(indices[0][1], non_matched_gt_bbox_idx[0].cpu()))
            ind0 = indices[0][0][remove_mask]
            ind1 = indices[0][1][remove_mask]

            # Predictions the assignment never used.
            non_matched_pred_bbox_idx = np.nonzero(
                np.logical_not(np.in1d(np.arange(out_bbox.shape[0]), indices[0][0]))
            )[0]

            match_indexes = [(torch.as_tensor(ind0, dtype=torch.int64), torch.as_tensor(ind1, dtype=torch.int64))]
            return match_indexes, non_matched_gt_bbox_idx, non_matched_pred_bbox_idx
| # from matplotlib import pyplot as plt | |
| # import matplotlib.colors as mcolors | |
| # # colors = mcolors.CSS4_COLORS#['r', 'g','b','y','c','gray','brown','lightblue'] | |
| # # colors = sorted( | |
| # # colors, key=lambda c: tuple(mcolors.rgb_to_hsv(mcolors.to_rgb(c)))) | |
| # | |
| # colors = [ | |
| # 'violet', 'khaki', 'aquamarine', 'darkslategray', 'orchid', 'cornflowerblue', | |
| # 'darkgreen', 'peru', 'darkorange', 'mediumseagreen', 'darkviolet', 'dodgerblue', | |
| # 'rosybrown', 'mediumorchid', 'cadetblue', 'darkgoldenrod', 'slateblue', 'springgreen', 'firebrick', | |
| # 'blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', | |
| # 'navy', 'coral', 'lime', 'tomato', 'indigo', 'sienna', 'magenta', 'silver', 'gold', 'teal' | |
| # ] | |
| # | |
| # | |
| # plt.clf() | |
| # # for i in range(out_bbox.shape[0]): | |
| # # box = out_bbox[i].cpu() | |
| # # plt.plot([box[0], box[0], box[2], box[2], box[0]], | |
| # # [box[1], box[3], box[3], box[1], box[1]], color='black') | |
| # | |
| # for i in range(indices[0][0].shape[0]): | |
| # box = out_bbox[indices[0][0][i]].cpu() | |
| # plt.plot([box[0], box[0], box[2], box[2], box[0]], | |
| # [box[1], box[3], box[3], box[1], box[1]], color=colors[i]) | |
| # | |
| # box = tgt_bbox[indices[0][1][i]].cpu() | |
| # if indices[0][1][i] == 1: | |
| # plt.plot([box[0], box[0], box[2], box[2], box[0]], | |
| # [box[1], box[3], box[3], box[1], box[1]], color=colors[i], linewidth=3) | |
| # plt.plot([box[0], box[0], box[2], box[2], box[0]], | |
| # [box[1], box[3], box[3], box[1], box[1]], color=colors[i]) | |
| # plt.savefig("Matcbed_bboxes_9") | |
| # # | |
| # print(sorted(indices[0][1])) | |
def build_matcher(args):
    """Construct a PointLossHungarianMatcher from parsed CLI arguments."""
    return PointLossHungarianMatcher(
        cost_class=args.cost_class,
        cost_bbox=args.cost_bbox,
        cost_giou=args.cost_giou,
    )
def build_chamfer_matcher(args):
    """Construct a ChamferDistanceMatching module from parsed CLI arguments."""
    return ChamferDistanceMatching(
        point_cost=args.chamfer_point_cost,
        giou_cost=args.chamfer_giou_cost,
    )
class PointHungarianMatcher(nn.Module):
    """Computes an assignment between target points and network predictions.

    For efficiency reasons, the targets don't include the no_object class.
    Because of this there are in general more predictions than targets; we do
    a 1-to-1 matching of the best predictions while the others are left
    un-matched (and thus treated as non-objects).
    """

    def __init__(
        self, cost_point: float = 1,
    ):
        """Creates the matcher.

        Params:
            cost_point: relative weight of the L1 point error in the matching cost
        """
        super().__init__()
        self.cost_point = cost_point
        assert cost_point != 0, "all costs cant be 0"

    def forward(self, outputs, targets):
        """Performs the matching.

        Params:
            outputs: dict with at least these entries:
                "pred_logits": [batch_size, num_queries, num_classes]
                    classification logits (only used for the batch/query sizes)
                "pred_points": [batch_size, num_queries, 2] predicted points
            targets: list of dicts (len == batch_size), each containing:
                "points": [num_target_points, 2] target point coordinates

        Returns:
            A list of size batch_size of tuples (index_i, index_j) where:
                - index_i: indices of the selected predictions (in order)
                - index_j: indices of the corresponding selected targets (in order)
            For each batch element:
                len(index_i) == len(index_j) == min(num_queries, num_target_points)
        """
        with torch.no_grad():
            bs, num_queries = outputs["pred_logits"].shape[:2]

            # We flatten to compute the cost matrices in a batch.
            out_point = outputs["pred_points"].flatten(0, 1)  # [batch_size * num_queries, 2]

            # Also concat the target points across the batch.
            tgt_point = torch.cat([v["points"] for v in targets])

            # L1 cost between predicted and target points.
            cost_point = torch.cdist(out_point, tgt_point, p=1)

            # Final cost matrix, reshaped to per-image slices.
            C = self.cost_point * cost_point
            C = C.view(bs, num_queries, -1).cpu()

            # BUGFIX: split by the number of target points ("points"), not by
            # an unrelated "boxes" entry — the original required every target
            # to also carry a "boxes" key it never otherwise used, and crashed
            # on point-only targets.
            sizes = [len(v["points"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return [
                (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
                for i, j in indices
            ]
def build_centerness_matcher(args):
    """Construct a PointHungarianMatcher from parsed CLI arguments."""
    matcher = PointHungarianMatcher(cost_point=args.set_cost_points)
    return matcher