sposhiy committed on
Commit
02094a4
·
verified ·
1 Parent(s): f92b56e

Upload matcher.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. matcher.py +113 -0
matcher.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Mostly copy-paste from DETR (https://github.com/facebookresearch/detr).
4
+ """
5
+ import torch
6
+ from scipy.optimize import linear_sum_assignment
7
+ from torch import nn
8
+
9
+
10
class HungarianMatcher_Crowd(nn.Module):
    """Computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object class. Because of
    this, in general, there are more predictions than targets. In this case, we do a
    1-to-1 matching of the best predictions, while the others are un-matched (and thus
    treated as non-objects).
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_point: float = 1,
        override_multiclass: bool = False,
        pointmatch: bool = False,
    ):
        """Creates the matcher.

        Params:
            cost_class: Relative weight of the classification term in the matching cost.
            cost_point: Relative weight of the L2 error of the point coordinates in the
                matching cost.
            override_multiclass: If True, every target label is replaced by class 1
                before computing the classification cost (single-foreground-class mode).
            pointmatch: If True, match on the point-distance cost alone, ignoring the
                classification cost and both weights.
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_point = cost_point
        assert cost_class != 0 or cost_point != 0, "all costs cant be 0"
        self.override_multiclass = override_multiclass
        self.pointmatch = pointmatch

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Performs the matching.

        Params:
            outputs: A dict that contains at least these entries:
                "pred_logits": Tensor of dim [batch_size, num_queries, num_classes]
                    with the classification logits.
                "pred_points": Tensor of dim [batch_size, num_queries, 2] with the
                    predicted point coordinates.
            targets: A list of targets (len(targets) == batch_size); each target is a
                dict containing:
                "labels": Tensor of dim [num_target_points] with the class labels.
                "point": Tensor of dim [num_target_points, 2] with the target point
                    coordinates.

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element:
                len(index_i) == len(index_j) == min(num_queries, num_target_points)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch.
        out_prob = (
            outputs["pred_logits"].flatten(0, 1).softmax(-1)
        )  # [batch_size * num_queries, num_classes]
        out_points = outputs["pred_points"].flatten(
            0, 1
        )  # [batch_size * num_queries, 2]

        # Also concat the target labels and points across the batch.
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_points = torch.cat([v["point"] for v in targets])

        if self.override_multiclass:
            # Collapse all targets onto class 1. ones_like (instead of the original
            # torch.ones(n, dtype=torch.int)) keeps the device and dtype of tgt_ids,
            # so the index tensor never diverges from the tensors it indexes into.
            tgt_ids = torch.ones_like(tgt_ids)

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, so it can be omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L2 cost between predicted and target points.
        cost_point = torch.cdist(out_points, tgt_points, p=2)

        # Final cost matrix.
        if self.pointmatch:
            C = cost_point
        else:
            C = self.cost_point * cost_point + self.cost_class * cost_class

        # Reshape back to [batch_size, num_queries, num_target_points];
        # linear_sum_assignment runs on CPU, so move the cost matrix there.
        C = C.view(bs, num_queries, -1).cpu()

        # Solve one assignment problem per batch element on its slice of columns.
        sizes = [len(v["point"]) for v in targets]
        indices = [
            linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))
        ]

        return [
            (
                torch.as_tensor(i, dtype=torch.int64),
                torch.as_tensor(j, dtype=torch.int64),
            )
            for i, j in indices
        ]
105
+
106
+
107
def build_matcher_crowd(args, override_multiclass: bool = False):
    """Factory: build a HungarianMatcher_Crowd from a parsed-args namespace.

    Reads ``set_cost_class``, ``set_cost_point`` and ``pointmatch`` from ``args``;
    ``override_multiclass`` is forwarded unchanged to the matcher.
    """
    matcher_kwargs = {
        "cost_class": args.set_cost_class,
        "cost_point": args.set_cost_point,
        "override_multiclass": override_multiclass,
        "pointmatch": args.pointmatch,
    }
    return HungarianMatcher_Crowd(**matcher_kwargs)