Upload matcher.py with huggingface_hub
Browse files- matcher.py +113 -0
matcher.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
"""
|
| 3 |
+
Mostly copy-paste from DETR (https://github.com/facebookresearch/detr).
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
from scipy.optimize import linear_sum_assignment
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HungarianMatcher_Crowd(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best
    predictions, while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_point: float = 1,
        override_multiclass: bool = False,
        pointmatch: bool = False,
    ):
        """Creates the matcher.

        Params:
            cost_class: This is the relative weight of the foreground object
            cost_point: This is the relative weight of the L1 error of the points coordinates
                in the matching cost
            override_multiclass: if True, collapse every target label to class 1 so the
                matching ignores ground-truth class identity
            pointmatch: if True, match on point distance alone and ignore the class cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_point = cost_point
        assert cost_class != 0 or cost_point != 0, "all costs cant be 0"
        self.override_multiclass = override_multiclass
        self.pointmatch = pointmatch

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Performs the matching.

        Params:
            outputs: This is a dict that contains at least these entries:
                "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the
                    classification logits
                "pred_points": Tensor of dim [batch_size, num_queries, 2] with the predicted
                    point coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target
                is a dict containing:
                "labels": Tensor of dim [num_target_points] (where num_target_points is the
                    number of ground-truth objects in the target) containing the class labels
                "point": Tensor of dim [num_target_points, 2] containing the target point
                    coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_points)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = (
            outputs["pred_logits"].flatten(0, 1).softmax(-1)
        )  # [batch_size * num_queries, num_classes]
        out_points = outputs["pred_points"].flatten(
            0, 1
        )  # [batch_size * num_queries, 2]

        # Also concat the target labels and points across the batch
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_points = torch.cat([v["point"] for v in targets])

        if self.override_multiclass:
            # Treat every ground-truth object as class 1. ones_like preserves the
            # device and integer dtype of the original labels (the previous
            # torch.ones(..., dtype=torch.int) always allocated on CPU, which could
            # mismatch CUDA inputs).
            tgt_ids = torch.ones_like(tgt_ids)

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L2 cost between points
        cost_point = torch.cdist(out_points, tgt_points, p=2)

        # Final cost matrix
        if self.pointmatch:
            C = cost_point
        else:
            C = self.cost_point * cost_point + self.cost_class * cost_class

        # Reshape back to [batch_size, num_queries, num_target_points];
        # linear_sum_assignment runs on CPU numpy, so move the costs there once.
        C = C.view(bs, num_queries, -1).cpu()

        # Solve the assignment independently per batch element
        sizes = [len(v["point"]) for v in targets]
        indices = [
            linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))
        ]

        return [
            (
                torch.as_tensor(i, dtype=torch.int64),
                torch.as_tensor(j, dtype=torch.int64),
            )
            for i, j in indices
        ]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def build_matcher_crowd(args, override_multiclass: bool = False):
    """Construct a HungarianMatcher_Crowd from a parsed argument namespace.

    Params:
        args: namespace providing set_cost_class, set_cost_point and pointmatch
        override_multiclass: forwarded to the matcher; collapses target labels
            to a single foreground class when True
    """
    cost_weights = {
        "cost_class": args.set_cost_class,
        "cost_point": args.set_cost_point,
    }
    return HungarianMatcher_Crowd(
        override_multiclass=override_multiclass,
        pointmatch=args.pointmatch,
        **cost_weights,
    )
|