Mjolnir65 committed on
Commit
0e51835
·
verified ·
1 Parent(s): 2219500

uploaded the weights

Browse files
Files changed (34) hide show
  1. detection/__init__.py +1 -0
  2. detection/__pycache__/__init__.cpython-311.pyc +0 -0
  3. detection/__pycache__/__init__.cpython-37.pyc +0 -0
  4. detection/__pycache__/_utils.cpython-311.pyc +0 -0
  5. detection/__pycache__/_utils.cpython-37.pyc +0 -0
  6. detection/__pycache__/anchor_utils.cpython-311.pyc +0 -0
  7. detection/__pycache__/anchor_utils.cpython-37.pyc +0 -0
  8. detection/__pycache__/backbone_utils.cpython-311.pyc +0 -0
  9. detection/__pycache__/backbone_utils.cpython-37.pyc +0 -0
  10. detection/__pycache__/faster_rcnn.cpython-311.pyc +0 -0
  11. detection/__pycache__/faster_rcnn.cpython-37.pyc +0 -0
  12. detection/__pycache__/generalized_rcnn.cpython-311.pyc +0 -0
  13. detection/__pycache__/generalized_rcnn.cpython-37.pyc +0 -0
  14. detection/__pycache__/image_list.cpython-311.pyc +0 -0
  15. detection/__pycache__/image_list.cpython-37.pyc +0 -0
  16. detection/__pycache__/roi_heads.cpython-311.pyc +0 -0
  17. detection/__pycache__/roi_heads.cpython-37.pyc +0 -0
  18. detection/__pycache__/rpn.cpython-311.pyc +0 -0
  19. detection/__pycache__/rpn.cpython-37.pyc +0 -0
  20. detection/__pycache__/transform.cpython-311.pyc +0 -0
  21. detection/__pycache__/transform.cpython-37.pyc +0 -0
  22. detection/_utils.py +540 -0
  23. detection/anchor_utils.py +268 -0
  24. detection/backbone_utils.py +121 -0
  25. detection/faster_rcnn.py +390 -0
  26. detection/generalized_rcnn.py +128 -0
  27. detection/image_list.py +25 -0
  28. detection/roi_heads.py +400 -0
  29. detection/rpn.py +385 -0
  30. detection/transform.py +318 -0
  31. infer.py +346 -0
  32. requirements.txt +105 -0
  33. st/tv_frcnn_r50fpn_faster_rcnn_st.pth +3 -0
  34. st/tv_frcnn_r50fpn_faster_rcnn_st_10.pth +3 -0
detection/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .faster_rcnn import *
detection/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (226 Bytes). View file
 
detection/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (169 Bytes). View file
 
detection/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (28.3 kB). View file
 
detection/__pycache__/_utils.cpython-37.pyc ADDED
Binary file (18.2 kB). View file
 
detection/__pycache__/anchor_utils.cpython-311.pyc ADDED
Binary file (18.5 kB). View file
 
detection/__pycache__/anchor_utils.cpython-37.pyc ADDED
Binary file (10.5 kB). View file
 
detection/__pycache__/backbone_utils.cpython-311.pyc ADDED
Binary file (6.98 kB). View file
 
detection/__pycache__/backbone_utils.cpython-37.pyc ADDED
Binary file (4.19 kB). View file
 
detection/__pycache__/faster_rcnn.cpython-311.pyc ADDED
Binary file (20.1 kB). View file
 
detection/__pycache__/faster_rcnn.cpython-37.pyc ADDED
Binary file (14.4 kB). View file
 
detection/__pycache__/generalized_rcnn.cpython-311.pyc ADDED
Binary file (6.28 kB). View file
 
detection/__pycache__/generalized_rcnn.cpython-37.pyc ADDED
Binary file (3.74 kB). View file
 
detection/__pycache__/image_list.cpython-311.pyc ADDED
Binary file (1.63 kB). View file
 
detection/__pycache__/image_list.cpython-37.pyc ADDED
Binary file (1.2 kB). View file
 
detection/__pycache__/roi_heads.cpython-311.pyc ADDED
Binary file (17.9 kB). View file
 
detection/__pycache__/roi_heads.cpython-37.pyc ADDED
Binary file (9.31 kB). View file
 
detection/__pycache__/rpn.cpython-311.pyc ADDED
Binary file (20.1 kB). View file
 
detection/__pycache__/rpn.cpython-37.pyc ADDED
Binary file (11.7 kB). View file
 
detection/__pycache__/transform.cpython-311.pyc ADDED
Binary file (19.5 kB). View file
 
detection/__pycache__/transform.cpython-37.pyc ADDED
Binary file (9.92 kB). View file
 
detection/_utils.py ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections import OrderedDict
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ from torch import nn, Tensor
7
+ from torch.nn import functional as F
8
+ from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
9
+
10
+
11
class BalancedPositiveNegativeSampler:
    """
    Samples a fixed-size batch of indices per image while keeping a fixed
    proportion of positive examples.
    """

    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
        """
        Args:
            batch_size_per_image (int): number of elements to be selected per image
            positive_fraction (float): percentage of positive elements per batch
        """
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction

    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Args:
            matched_idxs: list of tensors (one per image) containing -1, 0 or
                positive values. -1 values are ignored, 0 is treated as a
                negative and > 0 as a positive.

        Returns:
            Two lists of binary (uint8) masks, one entry per image: the first
            marks the sampled positive elements, the second the sampled negatives.
        """
        pos_masks: List[Tensor] = []
        neg_masks: List[Tensor] = []
        # Target number of positives per image; loop-invariant, so hoisted.
        max_pos = int(self.batch_size_per_image * self.positive_fraction)

        for labels in matched_idxs:
            pos_candidates = torch.where(labels >= 1)[0]
            neg_candidates = torch.where(labels == 0)[0]

            # Cap by availability: take positives first, then fill the
            # remaining budget with negatives.
            n_pos = min(pos_candidates.numel(), max_pos)
            n_neg = min(neg_candidates.numel(), self.batch_size_per_image - n_pos)

            # Randomly sample a subset of each candidate pool.
            keep_pos = pos_candidates[
                torch.randperm(pos_candidates.numel(), device=pos_candidates.device)[:n_pos]
            ]
            keep_neg = neg_candidates[
                torch.randperm(neg_candidates.numel(), device=neg_candidates.device)[:n_neg]
            ]

            # Turn the kept indices into binary masks over all elements.
            pos_mask = torch.zeros_like(labels, dtype=torch.uint8)
            neg_mask = torch.zeros_like(labels, dtype=torch.uint8)
            pos_mask[keep_pos] = 1
            neg_mask[keep_neg] = 1

            pos_masks.append(pos_mask)
            neg_masks.append(neg_mask)

        return pos_masks, neg_masks
72
+
73
+
74
@torch.jit._script_if_tracing
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
    """
    Encode a set of proposals relative to their matched reference boxes.

    Args:
        reference_boxes (Tensor): reference (ground-truth) boxes
        proposals (Tensor): boxes to be encoded
        weights (Tensor[4]): the weights for ``(x, y, w, h)``
    """
    # Scalar unpacking keeps the arithmetic JIT-fusion friendly.
    wx = weights[0]
    wy = weights[1]
    ww = weights[2]
    wh = weights[3]

    # Column views, kept 2-D so the final cat yields an Nx4 tensor.
    px1 = proposals[:, 0].unsqueeze(1)
    py1 = proposals[:, 1].unsqueeze(1)
    px2 = proposals[:, 2].unsqueeze(1)
    py2 = proposals[:, 3].unsqueeze(1)

    rx1 = reference_boxes[:, 0].unsqueeze(1)
    ry1 = reference_boxes[:, 1].unsqueeze(1)
    rx2 = reference_boxes[:, 2].unsqueeze(1)
    ry2 = reference_boxes[:, 3].unsqueeze(1)

    # Width/height and center of the proposals ...
    prop_w = px2 - px1
    prop_h = py2 - py1
    prop_cx = px1 + 0.5 * prop_w
    prop_cy = py1 + 0.5 * prop_h

    # ... and of the reference boxes.
    ref_w = rx2 - rx1
    ref_h = ry2 - ry1
    ref_cx = rx1 + 0.5 * ref_w
    ref_cy = ry1 + 0.5 * ref_h

    # Deltas: weighted center offsets (relative to proposal size) and
    # log-scale width/height ratios.
    dx = wx * (ref_cx - prop_cx) / prop_w
    dy = wy * (ref_cy - prop_cy) / prop_h
    dw = ww * torch.log(ref_w / prop_w)
    dh = wh * torch.log(ref_h / prop_h)

    return torch.cat((dx, dy, dw, dh), dim=1)
120
+
121
+
122
class BoxCoder:
    """
    This class encodes and decodes a set of bounding boxes into
    the representation used for training the regressors.
    """

    def __init__(
        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
    ) -> None:
        """
        Args:
            weights (4-element tuple): scaling factors for the (x, y, w, h) deltas
            bbox_xform_clip (float): upper clamp applied to dw/dh before exp()
                during decoding, to prevent overflow
        """
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip

    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
        """
        Encode a batch of per-image boxes; returns one target tensor per image.
        """
        boxes_per_image = [len(b) for b in reference_boxes]
        reference_boxes = torch.cat(reference_boxes, dim=0)
        proposals = torch.cat(proposals, dim=0)
        targets = self.encode_single(reference_boxes, proposals)
        # Split the flat result back into per-image chunks.
        return targets.split(boxes_per_image, 0)

    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some
        reference boxes

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded
        """
        dtype = reference_boxes.dtype
        device = reference_boxes.device
        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
        targets = encode_boxes(reference_boxes, proposals, weights)

        return targets

    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
        """
        Decode regression deltas against a list of per-image reference boxes.

        Returns a tensor of shape (total_boxes, num_classes, 4) when any
        boxes are present.
        """
        torch._assert(
            isinstance(boxes, (list, tuple)),
            "This function expects boxes of type list or tuple.",
        )
        torch._assert(
            isinstance(rel_codes, torch.Tensor),
            "This function expects rel_codes of type torch.Tensor.",
        )
        boxes_per_image = [b.size(0) for b in boxes]
        concat_boxes = torch.cat(boxes, dim=0)
        # NOTE(review): manual accumulation instead of sum() — presumably kept
        # for TorchScript compatibility; confirm before simplifying.
        box_sum = 0
        for val in boxes_per_image:
            box_sum += val
        if box_sum > 0:
            rel_codes = rel_codes.reshape(box_sum, -1)
        pred_boxes = self.decode_single(rel_codes, concat_boxes)
        if box_sum > 0:
            # One row per box, one (x1, y1, x2, y2) quadruple per class.
            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
        return pred_boxes

    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.
        """

        boxes = boxes.to(rel_codes.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        # Strided slicing handles multiple classes: every group of 4 columns
        # is one (dx, dy, dw, dh) set.
        wx, wy, ww, wh = self.weights
        dx = rel_codes[:, 0::4] / wx
        dy = rel_codes[:, 1::4] / wy
        dw = rel_codes[:, 2::4] / ww
        dh = rel_codes[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.bbox_xform_clip)
        dh = torch.clamp(dh, max=self.bbox_xform_clip)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        # Distance from center to box's corner.
        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w

        pred_boxes1 = pred_ctr_x - c_to_c_w
        pred_boxes2 = pred_ctr_y - c_to_c_h
        pred_boxes3 = pred_ctr_x + c_to_c_w
        pred_boxes4 = pred_ctr_y + c_to_c_h
        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
        return pred_boxes
225
+
226
+
227
class BoxLinearCoder:
    """
    The linear box-to-box transform defined in FCOS: boxes are parameterized
    by the distances from the center of the (square) src box to the four
    edges of the target box.
    """

    def __init__(self, normalize_by_size: bool = True) -> None:
        """
        Args:
            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
        """
        self.normalize_by_size = normalize_by_size

    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some reference boxes.

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded

        Returns:
            Tensor: the encoded relative box offsets that can be used to
            decode the boxes.
        """
        # Center of each reference box.
        ref_cx = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
        ref_cy = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])

        # Distances from the reference center to the four proposal edges
        # (left, top, right, bottom).
        deltas = torch.stack(
            (
                ref_cx - proposals[..., 0],
                ref_cy - proposals[..., 1],
                proposals[..., 2] - ref_cx,
                proposals[..., 3] - ref_cy,
            ),
            dim=-1,
        )

        if self.normalize_by_size:
            # Divide each delta by the reference-box extent along its axis.
            ref_w = reference_boxes[..., 2] - reference_boxes[..., 0]
            ref_h = reference_boxes[..., 3] - reference_boxes[..., 1]
            deltas = deltas / torch.stack((ref_w, ref_h, ref_w, ref_h), dim=-1)
        return deltas

    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.

        Returns:
            Tensor: the predicted boxes with the encoded relative box offsets.

        .. note::
            This method assumes that ``rel_codes`` and ``boxes`` have the same
            size for the 0th dimension, i.e. ``len(rel_codes) == len(boxes)``.
        """
        boxes = boxes.to(dtype=rel_codes.dtype)

        anchor_cx = 0.5 * (boxes[..., 0] + boxes[..., 2])
        anchor_cy = 0.5 * (boxes[..., 1] + boxes[..., 3])

        if self.normalize_by_size:
            # Undo the size normalization applied by encode().
            anchor_w = boxes[..., 2] - boxes[..., 0]
            anchor_h = boxes[..., 3] - boxes[..., 1]
            rel_codes = rel_codes * torch.stack((anchor_w, anchor_h, anchor_w, anchor_h), dim=-1)

        # Expand the (l, t, r, b) distances around the anchor center.
        return torch.stack(
            (
                anchor_cx - rel_codes[..., 0],
                anchor_cy - rel_codes[..., 1],
                anchor_cx + rel_codes[..., 2],
                anchor_cy + rel_codes[..., 3],
            ),
            dim=-1,
        )
312
+
313
+
314
class Matcher:
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be assigned to zero or more predicted elements.

    Matching is based on the MxN match_quality_matrix, that characterizes how well
    each (ground-truth, predicted)-pair match. For example, if the elements are
    boxes, the matrix may contain box IoU overlap values.

    The matcher returns a tensor of size N containing the index of the ground-truth
    element m that matches to prediction n. If there is no match, a negative value
    is returned.
    """

    BELOW_LOW_THRESHOLD = -1
    BETWEEN_THRESHOLDS = -2

    __annotations__ = {
        "BELOW_LOW_THRESHOLD": int,
        "BETWEEN_THRESHOLDS": int,
    }

    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
        """
        Args:
            high_threshold (float): quality values greater than or equal to
                this value are candidate matches.
            low_threshold (float): a lower quality threshold used to stratify
                matches into three levels:
                1) matches >= high_threshold
                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions that have only low-quality match candidates. See
                set_low_quality_matches_ for more details.
        """
        # NOTE(review): these re-assign the class-level constants on the
        # instance — presumably required for TorchScript attribute typing;
        # confirm before removing.
        self.BELOW_LOW_THRESHOLD = -1
        self.BETWEEN_THRESHOLDS = -2
        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted elements.

        Returns:
            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
                [0, M - 1] or a negative value indicating that prediction i could not
                be matched.

        Raises:
            ValueError: if the matrix is empty (no ground-truth or no proposals).
        """
        if match_quality_matrix.numel() == 0:
            # empty targets or proposals not supported during training
            if match_quality_matrix.shape[0] == 0:
                raise ValueError("No ground-truth boxes available for one of the images during training")
            else:
                raise ValueError("No proposal boxes available for one of the images during training")

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)
        if self.allow_low_quality_matches:
            # Snapshot the raw argmax before thresholding clobbers it; needed
            # by set_low_quality_matches_ below.
            all_matches = matches.clone()
        else:
            all_matches = None  # type: ignore[assignment]

        # Assign candidate matches with low quality to negative (unassigned) values
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
        matches[between_thresholds] = self.BETWEEN_THRESHOLDS

        if self.allow_low_quality_matches:
            if all_matches is None:
                torch._assert(False, "all_matches should not be None")
            else:
                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        return matches

    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth with which it has the highest
        quality value.

        Mutates ``matches`` in place.
        """
        # For each gt, find the prediction with which it has the highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # Find the highest quality match available, even if it is low, including ties
        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
        # Example gt_pred_pairs_of_highest_quality:
        # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]),
        #  tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
        # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index
        # Note how gt items 1, 2, 3, and 5 each have two ties

        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
        # Restore the pre-threshold argmax for those predictions.
        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
418
+
419
+
420
class SSDMatcher(Matcher):
    """
    SSD-style matcher: a single-threshold matcher that additionally forces
    every ground-truth box to be matched to its best-overlapping prediction.
    """

    def __init__(self, threshold: float) -> None:
        # A single threshold: anything below it is negative; the "between
        # thresholds" band is empty.
        super().__init__(threshold, threshold, allow_low_quality_matches=False)

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        matches = super().__call__(match_quality_matrix)

        # Guarantee each gt gets at least one match: assign gt i to the
        # prediction that overlaps it best, overriding the threshold result.
        _, best_pred_per_gt = match_quality_matrix.max(dim=1)
        gt_indices = torch.arange(
            best_pred_per_gt.size(0), dtype=torch.int64, device=best_pred_per_gt.device
        )
        matches[best_pred_per_gt] = gt_indices

        return matches
+ return matches
434
+
435
+
436
def overwrite_eps(model: nn.Module, eps: float) -> None:
    """
    Overwrite the ``eps`` value of every FrozenBatchNorm2d layer in the model.

    This is necessary to address the BC-breaking change introduced by the
    bug-fix at pytorch/vision#2933; the overwrite is applied only when
    pretrained weights are loaded, to maintain compatibility with previous
    versions.

    Args:
        model (nn.Module): The model on which we perform the overwrite.
        eps (float): The new value of eps.
    """
    frozen_bns = (m for m in model.modules() if isinstance(m, FrozenBatchNorm2d))
    for bn in frozen_bns:
        bn.eps = eps
+
453
+
454
def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
    """
    Estimate the number of output channels of a model by running a dummy
    input through it.

    Args:
        model (nn.Module): The model for which we estimate the out_channels.
            It should return a single Tensor or an OrderedDict[Tensor].
        size (Tuple[int, int]): The size (wxh) of the input.

    Returns:
        out_channels (List[int]): A list of the output channels of the model.
    """
    was_training = model.training
    model.eval()

    with torch.no_grad():
        # A zero image avoids hard-coding feature-map sizes.
        device = next(model.parameters()).device
        dummy = torch.zeros((1, 3, size[1], size[0]), device=device)
        outputs = model(dummy)
        if isinstance(outputs, torch.Tensor):
            # Normalize the single-tensor case to the dict layout.
            outputs = OrderedDict([("0", outputs)])
        out_channels = [feature.size(1) for feature in outputs.values()]

    # Restore the caller's training/eval mode.
    if was_training:
        model.train()

    return out_channels
482
+
483
+
484
@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> int:
    # ONNX-export helper: pretend the 0-dim tensor is an int so downstream
    # ops (e.g. topk) trace correctly; no actual conversion happens.
    return v  # type: ignore[return-value]
487
+
488
+
489
+ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
490
+ """
491
+ ONNX spec requires the k-value to be less than or equal to the number of inputs along
492
+ provided dim. Certain models use the number of elements along a particular axis instead of K
493
+ if K exceeds the number of elements along that axis. Previously, python's min() function was
494
+ used to determine whether to use the provided k-value or the specified dim axis value.
495
+
496
+ However, in cases where the model is being exported in tracing mode, python min() is
497
+ static causing the model to be traced incorrectly and eventually fail at the topk node.
498
+ In order to avoid this situation, in tracing mode, torch.min() is used instead.
499
+
500
+ Args:
501
+ input (Tensor): The original input tensor.
502
+ orig_kval (int): The provided k-value.
503
+ axis(int): Axis along which we retrieve the input size.
504
+
505
+ Returns:
506
+ min_kval (int): Appropriately selected k-value.
507
+ """
508
+ if not torch.jit.is_tracing():
509
+ return min(orig_kval, input.size(axis))
510
+ axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
511
+ min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
512
+ return _fake_cast_onnx(min_kval)
513
+
514
+
515
+ def _box_loss(
516
+ type: str,
517
+ box_coder: BoxCoder,
518
+ anchors_per_image: Tensor,
519
+ matched_gt_boxes_per_image: Tensor,
520
+ bbox_regression_per_image: Tensor,
521
+ cnf: Optional[Dict[str, float]] = None,
522
+ ) -> Tensor:
523
+ torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
524
+
525
+ if type == "l1":
526
+ target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
527
+ return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
528
+ elif type == "smooth_l1":
529
+ target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
530
+ beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
531
+ return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
532
+ else:
533
+ bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
534
+ eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
535
+ if type == "ciou":
536
+ return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
537
+ if type == "diou":
538
+ return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
539
+ # otherwise giou
540
+ return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
detection/anchor_utils.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+
7
+ from .image_list import ImageList
8
+
9
+
10
class AnchorGenerator(nn.Module):
    """
    Module that generates anchors for a set of feature maps and
    image sizes.

    The module supports computing anchors at multiple sizes and aspect ratios
    per feature map. This module assumes aspect ratio = height / width for
    each anchor.

    sizes and aspect_ratios should have the same number of elements, and it should
    correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
    and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
    per spatial location for feature map i.

    Args:
        sizes (Tuple[Tuple[int]]): anchor sizes (in pixels) per feature map
        aspect_ratios (Tuple[Tuple[float]]): aspect ratios per feature map
    """

    __annotations__ = {
        "cell_anchors": List[torch.Tensor],
    }

    def __init__(
        self,
        sizes=((128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    ):
        super().__init__()

        # Normalize flat inputs into the nested per-feature-map layout.
        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        # One zero-centered base-anchor tensor per feature map.
        self.cell_anchors = [
            self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios)
        ]

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
    def generate_anchors(
        self,
        scales: List[int],
        aspect_ratios: List[float],
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
    ) -> Tensor:
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        # sqrt split keeps anchor area constant across aspect ratios.
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        # (x1, y1, x2, y2) boxes centered at the origin.
        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
        # Move cached base anchors to the dtype/device of the feature maps.
        self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]

    def num_anchors_per_location(self) -> List[int]:
        # Anchors per spatial position for each feature map.
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
    def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
        anchors = []
        cell_anchors = self.cell_anchors
        torch._assert(cell_anchors is not None, "cell_anchors should not be None")
        torch._assert(
            len(grid_sizes) == len(strides) == len(cell_anchors),
            "Anchors should be Tuple[Tuple[int]] because each feature "
            "map could potentially have different sizes and aspect ratios. "
            "There needs to be a match between the number of "
            "feature maps passed and the number of sizes / aspect ratios specified.",
        )

        for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # For output anchor, compute [x_center, y_center, x_center, y_center]
            shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
            shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
            anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))

        return anchors

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        """
        Returns one anchor tensor per image, each the concatenation of the
        anchors from every feature map.
        """
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        # Stride of each feature map relative to the input image, kept as
        # 0-dim tensors (TODO confirm: presumably for scripting/tracing).
        strides = [
            [
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
            ]
            for g in grid_sizes
        ]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
        anchors: List[List[torch.Tensor]] = []
        # Anchors depend only on feature-map geometry, so the same set is
        # reused for every image in the batch.
        for _ in range(len(image_list.image_sizes)):
            anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
            anchors.append(anchors_in_image)
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        return anchors
134
+
135
+
136
class DefaultBoxGenerator(nn.Module):
    # Raw docstring: ``\t`` in ``\text{...}`` would otherwise be interpreted as a
    # TAB escape and corrupt the rendered math.
    r"""
    This module generates the default boxes of SSD for a set of feature maps and image sizes.

    Args:
        aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
        min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        scales (List[float], optional): The scales of the default boxes. If not provided it will be estimated using
            the ``min_ratio`` and ``max_ratio`` parameters.
        steps (List[int], optional): It's a hyper-parameter that affects the tiling of default boxes. If not provided
            it will be estimated from the data.
        clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
            is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
    """

    def __init__(
        self,
        aspect_ratios: List[List[int]],
        min_ratio: float = 0.15,
        max_ratio: float = 0.9,
        scales: Optional[List[float]] = None,
        steps: Optional[List[int]] = None,
        clip: bool = True,
    ):
        super().__init__()
        if steps is not None and len(aspect_ratios) != len(steps):
            raise ValueError("aspect_ratios and steps should have the same length")
        self.aspect_ratios = aspect_ratios
        self.steps = steps
        self.clip = clip
        num_outputs = len(aspect_ratios)

        # Estimation of default boxes scales: linear interpolation between
        # min_ratio and max_ratio, plus an extra scale of 1.0 used for s'_k.
        if scales is None:
            if num_outputs > 1:
                range_ratio = max_ratio - min_ratio
                self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
                self.scales.append(1.0)
            else:
                self.scales = [min_ratio, max_ratio]
        else:
            self.scales = scales

        self._wh_pairs = self._generate_wh_pairs(num_outputs)

    def _generate_wh_pairs(
        self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")
    ) -> List[Tensor]:
        """Precompute, per feature map, the (w, h) of every default box in relative units."""
        _wh_pairs: List[Tensor] = []
        for k in range(num_outputs):
            # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
            s_k = self.scales[k]
            s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
            wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]

            # Adding 2 pairs for each aspect ratio of the feature map k
            for ar in self.aspect_ratios[k]:
                sq_ar = math.sqrt(ar)
                w = self.scales[k] * sq_ar
                h = self.scales[k] / sq_ar
                wh_pairs.extend([[w, h], [h, w]])

            _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
        return _wh_pairs

    def num_anchors_per_location(self) -> List[int]:
        # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
        return [2 + 2 * len(r) for r in self.aspect_ratios]

    # Default Boxes calculation based on page 6 of SSD paper
    def _grid_default_boxes(
        self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32
    ) -> Tensor:
        """Tile the default boxes over every grid cell; returns (N, 4) in (cx, cy, w, h) relative coords."""
        default_boxes = []
        for k, f_k in enumerate(grid_sizes):
            # Now add the default boxes for each width-height pair
            if self.steps is not None:
                x_f_k = image_size[1] / self.steps[k]
                y_f_k = image_size[0] / self.steps[k]
            else:
                y_f_k, x_f_k = f_k

            # Box centers sit in the middle of each grid cell, normalized to [0, 1].
            shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
            shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)

            shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
            # Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
            _wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
            wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)

            default_box = torch.cat((shifts, wh_pairs), dim=1)

            default_boxes.append(default_box)

        return torch.cat(default_boxes, dim=0)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"aspect_ratios={self.aspect_ratios}"
            f", clip={self.clip}"
            f", scales={self.scales}"
            f", steps={self.steps}"
            ")"
        )
        return s

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        """Return one (N, 4) tensor of default boxes per image, in (x1, y1, x2, y2) pixel coords."""
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
        default_boxes = default_boxes.to(device)

        dboxes = []
        x_y_size = torch.tensor([image_size[1], image_size[0]], device=default_boxes.device)
        for _ in image_list.image_sizes:
            dboxes_in_image = default_boxes
            # Convert (cx, cy, w, h) relative coords to (x1, y1, x2, y2) pixel coords.
            dboxes_in_image = torch.cat(
                [
                    (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                    (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                ],
                -1,
            )
            dboxes.append(dboxes_in_image)
        return dboxes
detection/backbone_utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Callable, Dict, List, Optional, Union
3
+
4
+ from torch import nn, Tensor
5
+ from torchvision.ops import misc as misc_nn_ops
6
+ from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
7
+
8
+ from torchvision.models import resnet
9
+ from torchvision.models._api import _get_enum_from_fn, WeightsEnum
10
+ from torchvision.models._utils import handle_legacy_interface, IntermediateLayerGetter
11
+
12
+
13
class BackboneWithFPN(nn.Module):
    """Wrap a backbone with a Feature Pyramid Network on top.

    Intermediate activations are extracted from ``backbone`` according to
    ``return_layers`` and fed through an FPN, yielding an ordered dict of
    feature maps that all have ``out_channels`` channels.
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        # By default add a max-pool level on top of the pyramid.
        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        features = self.body(x)
        return self.fpn(features)
41
+
42
+
43
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def resnet_fpn_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 3,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
    """Construct a ResNet backbone with an FPN for detection models.

    Args:
        backbone_name: name of a torchvision resnet constructor, e.g. ``"resnet50"``.
        weights: pretrained weights for the backbone, or ``None``.
        norm_layer: normalization layer factory; frozen batch-norm by default,
            the usual choice when fine-tuning detection backbones.
        trainable_layers: number of trailing ResNet stages left trainable (0-5).
        returned_layers: which ResNet stages feed the FPN (defaults to [1, 2, 3, 4]).
        extra_blocks: optional extra FPN block (defaults to a top max-pool level).
    """
    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
    return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
61
+
62
+
63
def _resnet_fpn_extractor(
    backbone: resnet.ResNet,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:
    """Freeze the early stages of a ResNet and wrap it with an FPN.

    Args:
        backbone: the ResNet instance to adapt.
        trainable_layers: number of trailing stages to keep trainable, in [0, 5].
        returned_layers: which stages feed the FPN (defaults to [1, 2, 3, 4]).
        extra_blocks: optional extra FPN block (defaults to a top max-pool level).
        norm_layer: optional normalization layer for the FPN convolutions.

    Raises:
        ValueError: if ``trainable_layers`` or ``returned_layers`` are out of range.
    """
    # select layers that won't be frozen
    if trainable_layers < 0 or trainable_layers > 5:
        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
    if trainable_layers == 5:
        layers_to_train.append("bn1")
    for name, parameter in backbone.named_parameters():
        # Freeze every parameter that does not belong to a trainable stage.
        if not any(name.startswith(layer) for layer in layers_to_train):
            parameter.requires_grad_(False)

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()

    if returned_layers is None:
        returned_layers = [1, 2, 3, 4]
    if min(returned_layers) <= 0 or max(returned_layers) >= 5:
        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}

    # ResNet's stage-2 channel count is inplanes // 8 (e.g. 256 for resnet50);
    # each subsequent stage doubles it.
    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    out_channels = 256
    return BackboneWithFPN(
        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )
96
+
97
+
98
+ def _validate_trainable_layers(
99
+ is_trained: bool,
100
+ trainable_backbone_layers: Optional[int],
101
+ max_value: int,
102
+ default_value: int,
103
+ ) -> int:
104
+ # don't freeze any layers if pretrained model or backbone is not used
105
+ if not is_trained:
106
+ if trainable_backbone_layers is not None:
107
+ warnings.warn(
108
+ "Changing trainable_backbone_layers has no effect if "
109
+ "neither pretrained nor pretrained_backbone have been set to True, "
110
+ f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
111
+ )
112
+ trainable_backbone_layers = max_value
113
+
114
+ # by default freeze first blocks
115
+ if trainable_backbone_layers is None:
116
+ trainable_backbone_layers = default_value
117
+ if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
118
+ raise ValueError(
119
+ f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
120
+ )
121
+ return trainable_backbone_layers
detection/faster_rcnn.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from torchvision.ops import MultiScaleRoIAlign
7
+
8
+ from torchvision.ops import misc as misc_nn_ops
9
+ from torchvision.transforms._presets import ObjectDetection
10
+ from torchvision.models._api import register_model, Weights, WeightsEnum
11
+ from torchvision.models._meta import _COCO_CATEGORIES
12
+ from torchvision.models._utils import _ovewrite_value_param, handle_legacy_interface
13
+ from torchvision.models.resnet import resnet50, ResNet50_Weights
14
+
15
+ from ._utils import overwrite_eps
16
+ from .anchor_utils import AnchorGenerator
17
+ from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
18
+ from .generalized_rcnn import GeneralizedRCNN
19
+ from .roi_heads import RoIHeads
20
+ from .rpn import RegionProposalNetwork, RPNHead
21
+ from .transform import GeneralizedRCNNTransform
22
+
23
+
24
+ __all__ = [
25
+ "FasterRCNN",
26
+ "FasterRCNN_ResNet50_FPN_Weights",
27
+ "fasterrcnn_resnet50_fpn",
28
+ ]
29
+
30
+
31
def _default_anchorgen():
    """Build the standard FPN anchor generator: one anchor size per pyramid
    level, three aspect ratios at every level."""
    sizes = ((32,), (64,), (128,), (256,), (512,))
    ratios = tuple((0.5, 1.0, 2.0) for _ in sizes)
    return AnchorGenerator(sizes, ratios)
35
+
36
+
37
class FasterRCNN(GeneralizedRCNN):
    """
    Implements Faster R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        **kwargs,
    ):

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        # num_classes and box_predictor are mutually exclusive: a custom
        # predictor already fixes the number of classes.
        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            # NOTE: this fork's FastRCNNPredictor adds a theta (angle) head.
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        # ImageNet normalization statistics by default.
        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)
232
+
233
+
234
class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()
        # Two fully-connected layers applied to the flattened RoI features.
        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        flat = x.flatten(start_dim=1)
        hidden = F.relu(self.fc6(flat))
        return F.relu(self.fc7(hidden))
256
+
257
+
258
class FastRCNNConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        modules = []
        channels = in_channels
        # Convolutional stage: Conv2d + optional norm + activation per entry.
        for width in conv_layers:
            modules.append(misc_nn_ops.Conv2dNormActivation(channels, width, norm_layer=norm_layer))
            channels = width
        modules.append(nn.Flatten())
        # Fully-connected stage operates on the flattened feature map.
        channels = channels * in_height * in_width
        for width in fc_layers:
            modules.append(nn.Linear(channels, width))
            modules.append(nn.ReLU(inplace=True))
            channels = width

        super().__init__(*modules)
        # He initialization for all conv layers, zero bias.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
293
+
294
+
295
class FastRCNNPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers + theta
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
        num_theta_bins (int): number of angle bins predicted alongside each RoI
    """

    def __init__(self, in_channels, num_classes, num_theta_bins=1):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
        # Extra head relative to stock torchvision: per-RoI angle prediction.
        self.theta_pred = nn.Linear(in_channels, 1 + num_theta_bins)

    def forward(self, x):
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        flattened = x.flatten(start_dim=1)
        return (
            self.cls_score(flattened),
            self.bbox_pred(flattened),
            self.theta_pred(flattened),
        )
323
+
324
# Metadata shared by every weights entry defined below.
_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}
328
+
329
+
330
class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
    """Pretrained weights for the Faster R-CNN ResNet-50-FPN detector."""

    # Official torchvision COCO checkpoint (37.0 box mAP on COCO-val2017).
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.0,
                }
            },
            "_ops": 134.38,
            "_file_size": 159.743,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1
349
+
350
+
351
# @register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn(
    *,
    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """Build a Faster R-CNN model with a ResNet-50-FPN backbone.

    Args:
        weights: full-detector pretrained weights, or ``None``.
        progress: show a download progress bar.
        num_classes: number of classes including background (defaults to COCO's 91).
        weights_backbone: backbone-only pretrained weights (ignored if ``weights`` is set).
        trainable_backbone_layers: how many trailing backbone stages stay trainable.
        **kwargs: forwarded to the :class:`FasterRCNN` constructor.
    """
    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # Full-detector weights subsume the backbone weights and fix num_classes.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    # Frozen batch-norm when fine-tuning from pretrained weights.
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        # strict=False: the stock checkpoint contains no weights for the extra
        # theta head, which is freshly initialized below.
        model.load_state_dict(weights.get_state_dict(progress=progress), strict=False)
        torch.nn.init.kaiming_normal_(model.roi_heads.box_predictor.theta_pred.weight, mode="fan_out", nonlinearity="relu")
        torch.nn.init.constant_(model.roi_heads.box_predictor.theta_pred.bias, 0)
        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model
detection/generalized_rcnn.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implements the Generalized R-CNN framework
3
+ """
4
+
5
+ import warnings
6
+ from collections import OrderedDict
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ from torch import nn, Tensor
11
+
12
+ from torchvision.utils import _log_api_usage_once
13
+
14
+
15
class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    Args:
        backbone (nn.Module): computes feature maps from the batched images.
        rpn (nn.Module): proposes candidate object regions from the features.
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model
    """

    def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
        # Eager mode returns losses when training, detections when evaluating.
        if self.training:
            return losses

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).

        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # Validate that every target carries well-formed [N, 4] boxes.
                for target in targets:
                    boxes = target["boxes"]
                    if isinstance(boxes, torch.Tensor):
                        torch._assert(
                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                        )
                    else:
                        torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

        # Remember pre-transform sizes so detections can be mapped back later.
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:4] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        # NOTE(review): an earlier revision stripped the theta column from targets
        # before the RPN; targets are now passed through unchanged.
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)
detection/image_list.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+
7
class ImageList:
    """
    Batches images of possibly different sizes into a single padded tensor
    while remembering each image's original (pre-padding) size.

    Args:
        tensors (tensor): the padded, batched image data.
        image_sizes (list[tuple[int, int]]): original size of every image.
    """

    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
        self.tensors = tensors
        self.image_sizes = image_sizes

    def to(self, device: torch.device) -> "ImageList":
        # Only the pixel data moves between devices; the recorded sizes are
        # plain Python tuples and are shared with the new instance.
        moved = self.tensors.to(device)
        return ImageList(moved, self.image_sizes)
detection/roi_heads.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torchvision
6
+ from torch import nn, Tensor
7
+ from torchvision.ops import boxes as box_ops, roi_align
8
+
9
+ from . import _utils as det_utils
10
+
11
def compute_theta_loss(preds, targets):
    """Loss for the orientation ("theta") head.

    A single output column means direct angle regression (MSE against the
    raw target angles). Several columns mean the angle range [0, pi) is
    discretised into equal-width bins and predicted via classification
    (cross-entropy against the target's bin index).
    """
    n_bins = preds.shape[1]
    if n_bins == 1:
        # Regression: compare the lone column against the raw angle targets.
        return F.mse_loss(preds[:, 0], targets)
    # Classification: map each target angle onto its bin index, clamped into
    # the valid range [0, n_bins - 1] to guard against angles at/above pi.
    width = torch.pi / n_bins
    bin_idx = torch.clamp((targets / width).long(), 0, n_bins - 1)
    return F.cross_entropy(preds, bin_idx)
24
+
25
def fastrcnn_loss(class_logits, box_regression, theta_preds, labels, regression_targets, theta_targets):
    # type: (Tensor, Tensor, Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.

    Args:
        class_logits (Tensor)
        box_regression (Tensor)
        theta_preds (Tensor): per-proposal orientation outputs; column 0 is
            dropped below (presumably a background slot — TODO confirm
            against the box predictor's theta branch)
        labels (list[BoxList])
        regression_targets (Tensor)
        theta_targets (Tensor)

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
        theta_loss (Tensor)
    """
    # print(f"{class_logits.shape=} {box_regression.shape=} {theta_preds.shape=}")
    # print(f"{labels[0].shape=} {regression_targets[0].shape=} {theta_targets[0].shape=}")
    # Per-image target lists are concatenated so the losses are computed over
    # the whole batch at once.
    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)
    theta_targets = torch.cat(theta_targets, dim=0)
    # print(f"{labels.shape=} {regression_targets.shape=} {theta_targets.shape=}")

    classification_loss = F.cross_entropy(class_logits, labels)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    # One 4-vector of box deltas per class: (N, num_classes*4) -> (N, num_classes, 4).
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    # Box loss only on positive proposals, using each positive's own class
    # column; normalised by the total number of sampled proposals below.
    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    box_loss = box_loss / labels.numel()

    # theta loss
    # NOTE(review): predictions and targets are explicitly moved across
    # devices here, suggesting the theta head may live on a different device
    # than the sampler output — confirm this is intentional.
    preds = theta_preds[sampled_pos_inds_subset.to(theta_preds.device)][:,1:]
    targets = theta_targets[sampled_pos_inds_subset.to(theta_targets.device)].to(preds.device)

    theta_loss = compute_theta_loss(preds, targets)

    return classification_loss, box_loss, theta_loss
74
+
75
def fastrcnn_theta_loss(theta_preds, theta_targets):
    """Placeholder theta loss kept for API compatibility.

    The actual orientation loss is computed inside ``fastrcnn_loss`` via
    ``compute_theta_loss``; this stub always contributes nothing.
    """
    return 0
81
+
82
def convert_xyxytheta_to_xywha(boxes):
    """Convert (x1, y1, x2, y2, theta_deg) rows into the
    (xc, yc, w, h, theta_rad) layout used by rotated-box ops."""
    center_x = (boxes[:, 0] + boxes[:, 2]) * 0.5
    center_y = (boxes[:, 1] + boxes[:, 3]) * 0.5
    width = boxes[:, 2] - boxes[:, 0]
    height = boxes[:, 3] - boxes[:, 1]
    # Angle arrives in degrees; downstream rotated-IoU code expects radians.
    angle = torch.deg2rad(boxes[:, 4])
    return torch.stack((center_x, center_y, width, height, angle), dim=1)
89
+
90
class RoIHeads(nn.Module):
    """Second-stage (box) head of Faster R-CNN, extended to carry an
    orientation ("theta") target/prediction alongside the usual class and
    box-regression outputs, both through training and inference.
    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
    ):
        super().__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        # Default regression weights follow the standard Faster R-CNN setup.
        if bbox_reg_weights is None:
            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        # Inference-time filtering knobs.
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        """Match every proposal to a ground-truth box per image and derive
        per-proposal class labels (0 = background, -1 = ignored)."""
        matched_idxs = []
        labels = []
        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):

            if gt_boxes_in_image.numel() == 0:
                # Background image: everything matches GT index 0 and is labelled background.
                device = proposals_in_image.device
                clamped_matched_idxs_in_image = torch.zeros(
                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
                )
                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            else:
                # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
                # print(f"{gt_boxes_in_image.shape=}")
                # print(f"{proposals_in_image.shape=}")
                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
                # match_quality_matrix = box_iou_rotated(convert_xyxytheta_to_xywha(gt_boxes_in_image), convert_xyxytheta_to_xywha(proposals_in_image))
                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

                # Matcher returns negative sentinels for unmatched proposals;
                # clamp so they can still be used for (dummy) GT indexing.
                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
                labels_in_image = labels_in_image.to(dtype=torch.int64)

                # Label background (below the low threshold)
                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_in_image[bg_inds] = 0

                # Label ignore proposals (between low and high thresholds)
                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

            matched_idxs.append(clamped_matched_idxs_in_image)
            labels.append(labels_in_image)
        return matched_idxs, labels

    def subsample(self, labels):
        # type: (List[Tensor]) -> List[Tensor]
        """Sample a balanced subset of positive/negative proposal indices per image."""
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_inds = []
        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
            sampled_inds.append(img_sampled_inds)
        return sampled_inds

    def add_gt_proposals(self, proposals, gt_boxes):
        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
        """Append ground-truth boxes to the proposal list of each image so
        GT boxes are always available as training samples."""
        # print(f"{len(proposals)=}")
        # print(f"{len(gt_boxes)=}")
        # print(f"{proposals[0].shape=}")
        # print(f"{gt_boxes[0].shape=}")
        # print(f"{proposals[0]=}")
        # proposals_with_theta = []
        # for proposal in proposals:
        #     proposal_with_theta = torch.cat((proposal, torch.zeros((proposal.shape[0], 1), dtype=proposal.dtype, device=proposal.device)), dim=1)
        #     proposals_with_theta.append(proposal_with_theta)
        # gt_boxes_without_theta = [gt_box[:, :-1] for gt_box in gt_boxes]
        # print(f"{len(proposal_with_theta)=}")
        # print(f"{proposals_with_theta[0].shape=}")
        # print(f"{proposals_with_theta[0]=}")
        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]

        return proposals

    def check_targets(self, targets):
        # type: (Optional[List[Dict[str, Tensor]]]) -> None
        """Validate that training targets exist and carry the required keys."""
        if targets is None:
            raise ValueError("targets should not be None")
        if not all(["boxes" in t for t in targets]):
            raise ValueError("Every element of targets should have a boxes key")
        if not all(["labels" in t for t in targets]):
            raise ValueError("Every element of targets should have a labels key")

    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        targets,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        """Prepare the second-stage training batch: match proposals to GT,
        subsample a balanced set, and build box-delta and theta targets."""
        self.check_targets(targets)
        if targets is None:
            raise ValueError("targets should not be None")
        dtype = proposals[0].dtype
        device = proposals[0].device

        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["labels"] for t in targets]
        # Orientation targets are carried separately from the 4-coordinate boxes.
        gt_thetas = [t["thetas"] for t in targets]

        # append ground-truth bboxes to propos
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)
        matched_gt_boxes = []
        matched_gt_thetas = []

        num_images = len(proposals)
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]

            gt_boxes_in_image = gt_boxes[img_id]
            gt_thetas_in_image = gt_thetas[img_id]

            # NOTE(review): for an empty-GT image a dummy zero box is added,
            # but gt_thetas_in_image stays empty — the theta indexing below
            # would fail for such images; confirm background-only images
            # cannot reach this path.
            if gt_boxes_in_image.numel() == 0:
                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
            matched_gt_thetas.append(gt_thetas_in_image[matched_idxs[img_id].to(gt_thetas_in_image.device)])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        theta_targets = matched_gt_thetas
        return proposals, matched_idxs, labels, regression_targets, theta_targets

    def postprocess_detections(
        self,
        class_logits,  # type: Tensor
        box_regression,  # type: Tensor
        theta_preds,  # type: Tensor
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
        """Decode raw head outputs into per-image detections: clip boxes,
        drop background, threshold scores, NMS, and convert theta bin
        logits back to angles (radians)."""
        device = class_logits.device
        num_classes = class_logits.shape[-1]

        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        pred_scores = F.softmax(class_logits, -1)

        # Split the batch-flat predictions back into per-image chunks.
        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
        pred_scores_list = pred_scores.split(boxes_per_image, 0)
        pred_theta_list = theta_preds.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        all_thetas = []
        for boxes, scores, thetas, image_shape in zip(pred_boxes_list, pred_scores_list, pred_theta_list, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.arange(num_classes, device=device)
            labels = labels.view(1, -1).expand_as(scores)

            # remove predictions with the background label
            boxes = boxes[:, 1:]
            scores = scores[:, 1:]
            labels = labels[:, 1:]
            thetas = thetas[:, 1:]

            # Multi-bin theta head: pick the highest-scoring bin and map it
            # back to an angle (bin index * bin width, in radians).
            if thetas.shape[1] != 1:
                nbins = thetas.shape[1]
                angle_per_bin = torch.pi / nbins
                max_val_idx = torch.argmax(thetas, dim=1)
                max_val_theta = angle_per_bin * max_val_idx.float()

                thetas = max_val_theta

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.reshape(-1)
            labels = labels.reshape(-1)
            thetas = thetas.reshape(-1)

            # remove low scoring boxes
            inds = torch.where(scores > self.score_thresh)[0]
            boxes, scores, labels, thetas = boxes[inds], scores[inds], labels[inds], thetas[inds]

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
            boxes, scores, labels, thetas = boxes[keep], scores[keep], labels[keep], thetas[keep]

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
            boxes, scores, labels, thetas = boxes[keep], scores[keep], labels[keep], thetas[keep]

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)
            all_thetas.append(thetas)

        return all_boxes, all_scores, all_labels, all_thetas

    def forward(
        self,
        features,  # type: Dict[str, Tensor]
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
        """
        Args:
            features (List[Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])

        Returns losses (training) or per-image detection dicts with keys
        "boxes", "labels", "scores" and "thetas" (inference).
        """
        if targets is not None:
            for t in targets:
                # TODO: https://github.com/pytorch/pytorch/issues/26731
                floating_point_types = (torch.float, torch.double, torch.half)
                if not t["boxes"].dtype in floating_point_types:
                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
                if not t["labels"].dtype == torch.int64:
                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")

        # targets_without_theta = []
        # if targets is not None:
        #     for target in targets:
        #         target_without_theta = {"boxes": target["boxes"][:, :-1], "labels": target["labels"]}
        #         targets_without_theta.append(target_without_theta)
        if self.training:
            proposals, matched_idxs, labels, regression_targets, theta_targets = self.select_training_samples(proposals, targets)
            # print("---------")
            # print(f"{theta_targets.shape=}")
        else:
            labels = None
            regression_targets = None
            theta_targets = None
            matched_idxs = None

        # Pool per-proposal features, then run the shared head and predictor.
        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)

        class_logits, box_regression, theta_preds = self.box_predictor(box_features)
        # print(f"{class_logits.shape=}")
        # print(f"{box_regression.shape=}")
        # print(f"{theta_preds.shape=}")

        result: List[Dict[str, torch.Tensor]] = []
        losses = {}
        if self.training:
            if labels is None:
                raise ValueError("labels cannot be None")
            if regression_targets is None:
                raise ValueError("regression_targets cannot be None")
            loss_classifier, loss_box_reg, loss_theta = fastrcnn_loss(class_logits, box_regression, theta_preds, labels, regression_targets, theta_targets)
            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg, "loss_theta": loss_theta}
        else:
            boxes, scores, labels, thetas = self.postprocess_detections(class_logits, box_regression, theta_preds, proposals, image_shapes)
            # print(f"{scores[0]=}")
            # print(f"{labels[0]=}")
            # print(f"{thetas[0]=}")
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                        "thetas": thetas[i]
                    }
                )
        # print(f"{result}")

        return result, losses
400
+
detection/rpn.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Tuple
2
+
3
+ import torch
4
+ from torch import nn, Tensor
5
+ from torch.nn import functional as F
6
+ from torchvision.ops import boxes as box_ops, Conv2dNormActivation
7
+
8
+ from . import _utils as det_utils
9
+
10
+ # Import AnchorGenerator to keep compatibility.
11
+ from .anchor_utils import AnchorGenerator # noqa: 401
12
+ from .image_list import ImageList
13
+
14
+
15
class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        conv_depth (int, optional): number of convolutions
    """

    # Bumped when the conv stack changed layout; see _load_from_state_dict.
    _version = 2

    def __init__(self, in_channels: int, num_anchors: int, conv_depth=1) -> None:
        super().__init__()
        convs = []
        for _ in range(conv_depth):
            convs.append(Conv2dNormActivation(in_channels, in_channels, kernel_size=3, norm_layer=None))
        self.conv = nn.Sequential(*convs)
        # One objectness logit and 4 box deltas per anchor at each position.
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

        # Standard RPN initialisation: small-std normal weights, zero bias.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Remap pre-version-2 checkpoints, where ``conv`` was a single
        Conv2d, onto the new ``conv.0.0`` Sequential layout."""
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            for type in ["weight", "bias"]:
                old_key = f"{prefix}conv.{type}"
                new_key = f"{prefix}conv.0.0.{type}"
                if old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """Run the shared conv stack on every feature level and return the
        per-level objectness logits and box-delta maps."""
        logits = []
        bbox_reg = []
        for feature in x:
            t = self.conv(feature)
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg
79
+
80
+
81
def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
    """Reorder a head output from (N, A*C, H, W) to (N, H*W*A, C) so that
    predictions line up with the flattened per-location anchor targets."""
    split = layer.view(N, -1, C, H, W)          # (N, A, C, H, W)
    spatial_first = split.permute(0, 3, 4, 1, 2)  # (N, H, W, A, C)
    return spatial_first.reshape(N, -1, C)
86
+
87
+
88
def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
    """Flatten per-level RPN head outputs into two 2-D tensors.

    For each feature level the outputs are permuted so the spatial/anchor
    dimensions come first and the per-prediction channels come last; the
    levels are then concatenated, matching the layout in which the anchor
    labels were generated (all levels concatenated together).
    """
    cls_levels = []
    reg_levels = []
    for cls_level, reg_level in zip(box_cls, box_regression):
        N, AxC, H, W = cls_level.shape
        # The regression map always carries 4 deltas per anchor, which lets
        # us recover A (anchors per location) and C (classes per anchor).
        A = reg_level.shape[1] // 4
        C = AxC // A
        cls_levels.append(
            cls_level.view(N, -1, C, H, W).permute(0, 3, 4, 1, 2).reshape(N, -1, C)
        )
        reg_levels.append(
            reg_level.view(N, -1, 4, H, W).permute(0, 3, 4, 1, 2).reshape(N, -1, 4)
        )
    flat_cls = torch.cat(cls_levels, dim=1).flatten(0, -2)
    flat_reg = torch.cat(reg_levels, dim=1).reshape(-1, 4)
    return flat_cls, flat_reg
111
+
112
+
113
class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Args:
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        score_thresh (float): only return proposals with an objectness score greater than score_thresh

    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        head: nn.Module,
        # Faster-RCNN Training
        fg_iou_thresh: float,
        bg_iou_thresh: float,
        batch_size_per_image: int,
        positive_fraction: float,
        # Faster-RCNN Inference
        pre_nms_top_n: Dict[str, int],
        post_nms_top_n: Dict[str, int],
        nms_thresh: float,
        score_thresh: float = 0.0,
    ) -> None:
        super().__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        # Unit weights: the RPN regresses raw anchor deltas.
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used during training
        self.box_similarity = box_ops.box_iou

        # Low-quality matches are allowed so every GT gets at least one anchor.
        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
        # used during testing
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.score_thresh = score_thresh
        self.min_size = 1e-3

    def pre_nms_top_n(self) -> int:
        # Resolve the train/eval-specific pre-NMS budget.
        if self.training:
            return self._pre_nms_top_n["training"]
        return self._pre_nms_top_n["testing"]

    def post_nms_top_n(self) -> int:
        # Resolve the train/eval-specific post-NMS budget.
        if self.training:
            return self._post_nms_top_n["training"]
        return self._post_nms_top_n["testing"]

    def assign_targets_to_anchors(
        self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Per image, match anchors to GT boxes and produce binary objectness
        labels (1 = fg, 0 = bg, -1 = ignore) plus the matched GT box for
        every anchor."""
        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background image (negative example)
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the targets corresponding GT for each proposal
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = 0.0

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes

    def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
        """Pick the top pre-NMS indices independently for each feature level,
        offsetting into the level-concatenated anchor ordering."""
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            num_anchors = ob.shape[1]
            pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(
        self,
        proposals: Tensor,
        objectness: Tensor,
        image_shapes: List[Tuple[int, int]],
        num_anchors_per_level: List[int],
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Reduce decoded proposals per image: top-N per level, clip to the
        image, drop tiny/low-score boxes, per-level NMS, final top-K."""
        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        # Level index for each anchor so NMS can run per level.
        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        objectness_prob = torch.sigmoid(objectness)

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)

            # remove small boxes
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # remove low scoring boxes
            # use >= for Backwards compatibility
            keep = torch.where(scores >= self.score_thresh)[0]
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

            # keep only topk scoring predictions
            keep = keep[: self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]

            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores

    def compute_loss(
        self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """

        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        # Box loss only on positive anchors, normalised by the total number
        # of sampled (pos + neg) anchors.
        box_loss = F.smooth_l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1 / 9,
            reduction="sum",
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])

        return objectness_loss, box_loss

    def forward(
        self,
        images: ImageList,
        features: Dict[str, Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[List[Tensor], Dict[str, Tensor]]:

        """
        Args:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[str, Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                correspond to different feature levels
            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                If provided, each element in the dict should contain a field `boxes`,
                with the locations of the ground-truth boxes.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        anchors = self.anchor_generator(images, features)
        num_images = len(anchors)
        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN do not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
        losses = {}
        if self.training:
            if targets is None:
                raise ValueError("targets should not be None")
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets
            )
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        return boxes, losses
+ return boxes, losses
detection/transform.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ import torchvision
6
+ from torch import nn, Tensor
7
+
8
+ from .image_list import ImageList
9
+
10
+
11
@torch.jit.unused
def _get_shape_onnx(image: Tensor) -> Tensor:
    """Return the spatial (H, W) shape of ``image`` as a Tensor for ONNX tracing.

    ``operators.shape_as_tensor`` keeps the shape symbolic in the traced graph
    instead of baking the concrete size in as a constant, so the exported model
    works for dynamic input sizes.
    """
    from torch.onnx import operators

    return operators.shape_as_tensor(image)[-2:]
16
+
17
+
18
@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> float:
    """Return ``v`` unchanged while annotating the result as ``float``.

    During ONNX export the scale factor must remain a Tensor, but TorchScript's
    type checker expects a float; the annotation satisfies the checker without
    performing an actual cast.
    """
    # ONNX requires a tensor but here we fake its type for JIT.
    return v
22
+
23
+
24
def _resize_image_and_masks(
    image: Tensor,
    self_min_size: int,
    self_max_size: int,
    target: Optional[Dict[str, Tensor]] = None,
    fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
    """Resize ``image`` (and any ``masks`` in ``target``) via bilinear interpolation.

    When ``fixed_size`` (w, h) is given the image is resized to exactly that
    size; otherwise a single scale factor is chosen so the shorter side is at
    most ``self_min_size`` and the longer side at most ``self_max_size``.
    Returns the resized image and the (possibly updated) target dict; boxes and
    keypoints are NOT rescaled here — callers do that separately.
    """
    # Three ways to read the shape so the function works under ONNX tracing,
    # TorchScript scripting, and plain eager execution.
    if torchvision._is_tracing():
        im_shape = _get_shape_onnx(image)
    elif torch.jit.is_scripting():
        im_shape = torch.tensor(image.shape[-2:])
    else:
        im_shape = image.shape[-2:]

    size: Optional[List[int]] = None
    scale_factor: Optional[float] = None
    recompute_scale_factor: Optional[bool] = None
    if fixed_size is not None:
        # interpolate() wants (h, w); fixed_size is stored as (w, h).
        size = [fixed_size[1], fixed_size[0]]
    else:
        if torch.jit.is_scripting() or torchvision._is_tracing():
            min_size = torch.min(im_shape).to(dtype=torch.float32)
            max_size = torch.max(im_shape).to(dtype=torch.float32)
            self_min_size_f = float(self_min_size)
            self_max_size_f = float(self_max_size)
            # The smaller of the two ratios guarantees both constraints hold.
            scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)

            if torchvision._is_tracing():
                scale_factor = _fake_cast_onnx(scale)
            else:
                scale_factor = scale.item()

        else:
            # Do it the normal way
            min_size = min(im_shape)
            max_size = max(im_shape)
            scale_factor = min(self_min_size / min_size, self_max_size / max_size)

        recompute_scale_factor = True

    image = torch.nn.functional.interpolate(
        image[None],
        size=size,
        scale_factor=scale_factor,
        mode="bilinear",
        recompute_scale_factor=recompute_scale_factor,
        align_corners=False,
    )[0]

    if target is None:
        return image, target

    if "masks" in target:
        # Masks are interpolated in float then re-binarized to uint8.
        mask = target["masks"]
        mask = torch.nn.functional.interpolate(
            mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
        )[:, 0].byte()
        target["masks"] = mask
    return image, target
83
+
84
+
85
class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match min_size / max_size

    It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
    """

    def __init__(
        self,
        min_size: int,
        max_size: int,
        image_mean: List[float],
        image_std: List[float],
        size_divisible: int = 32,
        fixed_size: Optional[Tuple[int, int]] = None,
        **kwargs: Any,
    ):
        super().__init__()
        # Scalar min_size is normalized to a tuple; during training one entry
        # is sampled per batch (multi-scale training), eval uses the last one.
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std
        # Padded batch dimensions are rounded up to a multiple of this stride.
        self.size_divisible = size_divisible
        self.fixed_size = fixed_size
        # Private escape hatch: skip resizing entirely during training.
        self._skip_resize = kwargs.pop("_skip_resize", False)

    def forward(
        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
    ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
        """Normalize and resize each image, then batch them into one padded tensor.

        Returns an ``ImageList`` (padded batch + original per-image sizes) and
        the targets with boxes/keypoints/masks rescaled to match.
        """
        images = [img for img in images]
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place
            # once torchscript supports dict comprehension
            # this can be simplified as follows
            # targets = [{k: v for k,v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        # Record post-resize (pre-padding) sizes so predictions can be mapped
        # back in postprocess().
        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images, size_divisible=self.size_divisible)
        image_sizes_list: List[Tuple[int, int]] = []
        for image_size in image_sizes:
            torch._assert(
                len(image_size) == 2,
                f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
            )
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets

    def normalize(self, image: Tensor) -> Tensor:
        """Channel-wise standardize ``image`` using ``image_mean``/``image_std``."""
        if not image.is_floating_point():
            raise TypeError(
                f"Expected input images to be of floating type (in range [0, 1]), "
                f"but found type {image.dtype} instead"
            )
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        # Broadcast (C,) stats over the (C, H, W) image.
        return (image - mean[:, None, None]) / std[:, None, None]

    def torch_choice(self, k: List[int]) -> int:
        """
        Implements `random.choice` via torch ops, so it can be compiled with
        TorchScript and we use PyTorch's RNG (not native RNG)
        """
        index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
        return k[index]

    def resize(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Resize one image and rescale its target's boxes/keypoints to match."""
        h, w = image.shape[-2:]
        if self.training:
            if self._skip_resize:
                return image, target
            # Multi-scale training: pick a random entry of min_size.
            size = self.torch_choice(self.min_size)
        else:
            size = self.min_size[-1]
        image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        if "keypoints" in target:
            keypoints = target["keypoints"]
            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
            target["keypoints"] = keypoints
        return image, target

    # _onnx_batch_images() is an implementation of
    # batch_images() that is supported by ONNX tracing.
    @torch.jit.unused
    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        max_size = []
        for i in range(images[0].dim()):
            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
            max_size.append(max_size_i)
        stride = size_divisible
        # Round H and W up to a multiple of the stride.
        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size = tuple(max_size)

        # work around for
        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        # which is not yet supported in onnx
        padded_imgs = []
        for img in images:
            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
            padded_imgs.append(padded_img)

        return torch.stack(padded_imgs)

    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
        """Element-wise maximum over a list of equal-length int lists."""
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        """Pad variable-size images into one (N, C, H, W) tensor.

        H and W are the batch maxima rounded up to a multiple of
        ``size_divisible``; images are placed top-left, padding is zero.
        """
        if torchvision._is_tracing():
            # batch_images() does not export well to ONNX
            # call _onnx_batch_images() instead
            return self._onnx_batch_images(images, size_divisible)

        max_size = self.max_by_axis([list(img.shape) for img in images])
        stride = float(size_divisible)
        max_size = list(max_size)
        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)

        batch_shape = [len(images)] + max_size
        batched_imgs = images[0].new_full(batch_shape, 0)
        for i in range(batched_imgs.shape[0]):
            img = images[i]
            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

        return batched_imgs

    def postprocess(
        self,
        result: List[Dict[str, Tensor]],
        image_shapes: List[Tuple[int, int]],
        original_image_sizes: List[Tuple[int, int]],
    ) -> List[Dict[str, Tensor]]:
        """Map predictions back from resized coordinates to original image sizes.

        No-op during training (losses are computed in resized space).
        """
        if self.training:
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
            if "masks" in pred:
                masks = pred["masks"]
                # NOTE(review): paste_masks_in_image is not defined in this
                # module — presumably imported elsewhere in the file; confirm.
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["masks"] = masks
            if "keypoints" in pred:
                keypoints = pred["keypoints"]
                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                result[i]["keypoints"] = keypoints
        return result

    def __repr__(self) -> str:
        """Human-readable summary of the normalization and resize settings."""
        format_string = f"{self.__class__.__name__}("
        _indent = "\n    "
        format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
        format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
        format_string += "\n)"
        return format_string
285
+
286
+
287
def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    """Rescale (x, y, visibility) keypoints from ``original_size`` to ``new_size``.

    Sizes are (height, width). The visibility channel is left untouched and the
    input tensor is not modified.
    """
    def _axis_ratio(new_dim: int, old_dim: int) -> Tensor:
        return (
            torch.tensor(new_dim, dtype=torch.float32, device=keypoints.device)
            / torch.tensor(old_dim, dtype=torch.float32, device=keypoints.device)
        )

    ratio_h = _axis_ratio(new_size[0], original_size[0])
    ratio_w = _axis_ratio(new_size[1], original_size[1])

    resized = keypoints.clone()
    if torch._C._get_tracing_state():
        # In-place slice assignment does not trace well; rebuild instead.
        scaled_x = resized[:, :, 0] * ratio_w
        scaled_y = resized[:, :, 1] * ratio_h
        resized = torch.stack((scaled_x, scaled_y, resized[:, :, 2]), dim=2)
    else:
        resized[..., 0] *= ratio_w
        resized[..., 1] *= ratio_h
    return resized
303
+
304
+
305
def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    """Rescale (xmin, ymin, xmax, ymax) boxes from ``original_size`` to ``new_size``.

    Sizes are (height, width); x coordinates scale with width, y with height.
    """
    def _axis_ratio(new_dim: int, old_dim: int) -> Tensor:
        return (
            torch.tensor(new_dim, dtype=torch.float32, device=boxes.device)
            / torch.tensor(old_dim, dtype=torch.float32, device=boxes.device)
        )

    ratio_height = _axis_ratio(new_size[0], original_size[0])
    ratio_width = _axis_ratio(new_size[1], original_size[1])

    xmin, ymin, xmax, ymax = boxes.unbind(1)
    scaled = torch.stack(
        (xmin * ratio_width, ymin * ratio_height, xmax * ratio_width, ymax * ratio_height),
        dim=1,
    )
    return scaled
infer.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ import torchvision
5
+ import argparse
6
+ import random
7
+ import os
8
+ import yaml
9
+ from tqdm import tqdm
10
+ from dataset.st import SceneTextDataset
11
+ from torch.utils.data.dataloader import DataLoader
12
+
13
+ import detection
14
+ from detection.faster_rcnn import FastRCNNPredictor
15
+ from shapely.geometry import Polygon
16
+ from detection.anchor_utils import AnchorGenerator
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+
21
def get_iou(det, gt):
    """IoU between two rotated boxes, each given as (cx, cy, w, h, theta).

    ``theta`` is in radians; the boxes are converted to shapely polygons and
    the intersection-over-union of their areas is returned (0.0 if disjoint).
    """
    def _to_polygon(cx, cy, w, h, theta):
        # Corners in the box's local frame, rotated then translated to (cx, cy).
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        half_w, half_h = w / 2, h / 2
        local_corners = np.array(
            [[-half_w, -half_h], [half_w, -half_h], [half_w, half_h], [-half_w, half_h]]
        )
        rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
        return Polygon(local_corners @ rotation.T + np.array([cx, cy]))

    det_poly = _to_polygon(*det)
    gt_poly = _to_polygon(*gt)

    if not det_poly.intersects(gt_poly):
        return 0.0

    inter = det_poly.intersection(gt_poly).area
    # Small epsilon guards against division by zero for degenerate boxes.
    union = det_poly.area + gt_poly.area - inter + 1E-6
    return inter / union
44
+
45
+
46
def compute_map(det_boxes, gt_boxes, iou_threshold=0.5, method="area", return_pr=False):
    """Compute mean Average Precision for rotated-box detections.

    Args:
        det_boxes: per-image dicts, label -> [[*box_params, score], ...]; box
            params are whatever ``get_iou`` expects (here x, y, w, h, theta).
        gt_boxes: per-image dicts, label -> [[*box_params], ...].
        iou_threshold: minimum IoU for a detection to count as a true positive.
        method: "area" for all-point interpolation, "interp" for the 11-point
            PASCAL VOC scheme.
        return_pr: if True, also return per-class precision/recall curves.

    Returns:
        (mean_ap, all_aps) or (mean_ap, all_aps, all_precisions, all_recalls).

    Raises:
        ValueError: if ``method`` is neither "area" nor "interp".
    """
    gt_labels = {cls_key for im_gt in gt_boxes for cls_key in im_gt.keys()}
    gt_labels = sorted(gt_labels)

    all_aps = {}
    all_precisions = {}
    all_recalls = {}

    aps = []

    for label in gt_labels:
        # All detections of this class, tagged with their image index.
        cls_dets = [
            [im_idx, im_dets_label]
            for im_idx, im_dets in enumerate(det_boxes)
            if label in im_dets
            for im_dets_label in im_dets[label]
        ]

        # Sort by confidence score (descending) — last element is the score.
        cls_dets = sorted(cls_dets, key=lambda k: -k[1][-1])

        # Track matched GT boxes. Use .get() so an image whose GT dict has no
        # entry for this class does not raise KeyError (robustness fix).
        gt_matched = [[False for _ in im_gts.get(label, [])] for im_gts in gt_boxes]
        num_gts = sum(len(im_gts.get(label, [])) for im_gts in gt_boxes)

        tp = np.zeros(len(cls_dets))
        fp = np.zeros(len(cls_dets))

        # Greedily match each detection (highest score first) to its best GT.
        for det_idx, (im_idx, det_pred) in enumerate(cls_dets):
            im_gts = gt_boxes[im_idx].get(label, [])
            max_iou_found = -1
            max_iou_gt_idx = -1

            for gt_box_idx, gt_box in enumerate(im_gts):
                gt_box_iou = get_iou(det_pred[:-1], gt_box)
                if gt_box_iou > max_iou_found:
                    max_iou_found = gt_box_iou
                    max_iou_gt_idx = gt_box_idx

            # True positive only if IoU clears the threshold and the GT box
            # has not been claimed by a higher-scoring detection.
            if max_iou_found < iou_threshold or gt_matched[im_idx][max_iou_gt_idx]:
                fp[det_idx] = 1
            else:
                tp[det_idx] = 1
                gt_matched[im_idx][max_iou_gt_idx] = True

        tp = np.cumsum(tp)
        fp = np.cumsum(fp)

        eps = np.finfo(np.float32).eps
        recalls = tp / np.maximum(num_gts, eps)
        precisions = tp / np.maximum(tp + fp, eps)

        if method == "area":
            # All-point interpolation: make precision monotonically
            # non-increasing right-to-left, then integrate over recall steps.
            recalls = np.concatenate(([0.0], recalls, [1.0]))
            precisions = np.concatenate(([0.0], precisions, [0.0]))

            for i in range(len(precisions) - 1, 0, -1):
                precisions[i - 1] = np.maximum(precisions[i - 1], precisions[i])

            i = np.where(recalls[1:] != recalls[:-1])[0]
            ap = np.sum((recalls[i + 1] - recalls[i]) * precisions[i + 1])

        elif method == "interp":
            # Classic 11-point PASCAL VOC interpolation over recall levels.
            ap = (
                sum(
                    max(precisions[recalls >= t]) if any(recalls >= t) else 0
                    for t in np.arange(0, 1.1, 0.1)
                )
                / 11.0
            )
        else:
            raise ValueError("Method must be 'area' or 'interp'")

        if num_gts > 0:
            aps.append(ap)
            all_aps[label] = ap
            all_precisions[label] = precisions.tolist()
            all_recalls[label] = recalls.tolist()
        else:
            # No ground truth for this class anywhere: AP is undefined.
            all_aps[label] = np.nan
            all_precisions[label] = []
            all_recalls[label] = []

    mean_ap = sum(aps) / len(aps) if aps else 0.0

    if return_pr:
        return mean_ap, all_aps, all_precisions, all_recalls
    return mean_ap, all_aps
143
+
144
+
145
def load_model_and_dataset(args):
    """Build the Faster R-CNN model and the evaluation dataset from a YAML config.

    Reads ``args.config_path`` (expects ``dataset_params``, ``model_params``
    and ``train_params`` sections), seeds all RNGs, constructs the dataset for
    ``args.split_type``, and loads trained weights from
    ``<task_name>/tv_frcnn_r50fpn_<ckpt_name>``.

    Returns:
        (model, dataset, dataloader) — the model is in eval mode on ``device``.
    """
    # Read the config file #
    with open(args.config_path, "r") as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
    print(config)
    ########################

    dataset_config = config["dataset_params"]
    model_config = config["model_params"]
    train_config = config["train_params"]

    # Seed every RNG source for reproducible evaluation runs.
    seed = train_config["seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # NOTE(review): `device` is a torch.device; torch.device("cuda") == "cuda"
    # compares equal, so this branch does fire on GPU machines.
    if device == "cuda":
        torch.cuda.manual_seed_all(seed)

    st = SceneTextDataset(args.split_type, root_dir=dataset_config["root_dir"])
    # Batch size 1: images may have different sizes and targets are per-image.
    test_dataset = DataLoader(st, batch_size=1, shuffle=False)

    faster_rcnn_model = detection.fasterrcnn_resnet50_fpn(
        pretrained=True,
        min_size=600,
        max_size=1000,
        box_score_thresh=0.7,
    )
    # Swap the COCO head for one matching this dataset's classes and the
    # rotated-box theta bins.
    faster_rcnn_model.roi_heads.box_predictor = FastRCNNPredictor(
        faster_rcnn_model.roi_heads.box_predictor.cls_score.in_features,
        num_classes=dataset_config["num_classes"],
        num_theta_bins=args.num_theta_bins,
    )

    faster_rcnn_model.eval()
    faster_rcnn_model.to(device)
    # Load on CPU first; the model has already been moved to `device`, and
    # load_state_dict copies into the existing (device-resident) parameters.
    faster_rcnn_model.load_state_dict(
        torch.load(
            os.path.join(
                train_config["task_name"],
                "tv_frcnn_r50fpn_" + train_config["ckpt_name"],
            ),
            map_location='cpu',
        )
    )

    return faster_rcnn_model, st, test_dataset
194
+
195
+
196
def evaluate_metrics(args):
    """Run the model over the dataset and report mAP, mean precision and recall.

    For every image, collects predicted and ground-truth rotated boxes grouped
    by class name, then scores them with ``compute_map`` (11-point interp).

    Returns:
        (mean_ap, mean_precision, mean_recall) — precision/recall here are the
        averages of each class's full PR curve, not values at a fixed operating
        point.
    """
    faster_rcnn_model, voc, test_dataset = load_model_and_dataset(args)

    gts = []
    preds = []

    for im, target, fname in tqdm(test_dataset):
        im_name = fname
        im = im.float().to(device)
        # DataLoader batch dimension is 1; index [0] unwraps it.
        target_boxes = target["bboxes"].float().to(device)[0]
        target_labels = target["labels"].long().to(device)[0]
        target_thetas = target["thetas"].float().to(device)[0]
        frcnn_output = faster_rcnn_model(im, None)[0]

        boxes = frcnn_output["boxes"]
        labels = frcnn_output["labels"]
        scores = frcnn_output["scores"]
        thetas = frcnn_output["thetas"]

        # One dict per image: class name -> list of boxes.
        pred_boxes = {label_name: [] for label_name in voc.label2idx}
        gt_boxes = {label_name: [] for label_name in voc.label2idx}

        # Predictions carry a trailing confidence score; GTs do not.
        for idx, box in enumerate(boxes):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            label = labels[idx].detach().cpu().item()
            score = scores[idx].detach().cpu().item()
            theta = thetas[idx].detach().cpu().item()
            label_name = voc.idx2label[label]
            pred_boxes[label_name].append([x1, y1, x2, y2, theta, score])

        for idx, box in enumerate(target_boxes):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            label = target_labels[idx].detach().cpu().item()
            label_name = voc.idx2label[label]
            theta = target_thetas[idx].detach().cpu().item()
            gt_boxes[label_name].append([x1, y1, x2, y2, theta])

        gts.append(gt_boxes)
        preds.append(pred_boxes)

    # Compute Mean Average Precision and Precision-Recall values
    mean_ap, all_aps, precisions, recalls = compute_map(
        preds, gts, method="interp", return_pr=True
    )

    mean_precision = 0
    mean_recall = 0
    num_classes = len(voc.idx2label)

    for idx in range(num_classes):
        class_name = voc.idx2label[idx]
        ap = all_aps[class_name]
        prec = precisions[class_name]
        rec = recalls[class_name]

        mean_precision += sum(prec) / len(prec) if len(prec) > 0 else 0
        mean_recall += sum(rec) / len(rec) if len(rec) > 0 else 0

        print(f"Class: {class_name}")
        print(
            f"  AP: {ap:.4f}, Precision: {sum(prec) / len(prec) if len(prec) > 0 else 0:.4f}, Recall: {sum(rec) / len(rec) if len(rec) > 0 else 0:.4f}"
        )

    mean_precision /= num_classes
    mean_recall /= num_classes

    print(f"Mean Average Precision (mAP): {mean_ap:.4f}")
    print(f"Mean Precision: {mean_precision:.4f}")
    print(f"Mean Recall: {mean_recall:.4f}")
    return mean_ap, mean_precision, mean_recall
266
+
267
+
268
def infer(args):
    """Visualize predictions on 10 random dataset images.

    For each sampled image, writes two files to ``samples_tv_r50fpn``: one with
    ground-truth rotated boxes drawn (``output_frcnn_gt_<i>.png``) and one with
    the model's predictions (``output_frcnn_<i>.jpg``). Boxes are drawn as
    rotated rectangles via ``cv2.boxPoints`` with theta converted to degrees.
    """
    output_dir = "samples_tv_r50fpn"
    # makedirs(exist_ok=True) avoids a crash if the directory already exists
    # or an intermediate path component is missing.
    os.makedirs(output_dir, exist_ok=True)
    faster_rcnn_model, voc, test_dataset = load_model_and_dataset(args)

    for sample_count in tqdm(range(10)):
        # randint is inclusive on both ends: the upper bound must be
        # len(voc) - 1, otherwise we can index one past the dataset end.
        random_idx = random.randint(0, len(voc) - 1)
        im, target, fname = voc[random_idx]
        im = im.unsqueeze(0).float().to(device)

        gt_im = cv2.imread(fname)
        gt_im_copy = gt_im.copy()

        # Saving images with ground truth boxes
        for idx, box in enumerate(target["bboxes"]):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            # cv2.boxPoints expects the angle in degrees.
            theta = target["thetas"][idx].detach().cpu().numpy() * 180 / np.pi

            cx, cy, w, h = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
            box = cv2.boxPoints(((cx, cy), (w, h), theta))
            box = box.astype(np.int32)
            cv2.drawContours(gt_im, [box], 0, (0, 255, 0), 2)
            cv2.drawContours(gt_im_copy, [box], 0, (0, 255, 0), 2)

        # Blend the two copies for a semi-transparent overlay effect.
        cv2.addWeighted(gt_im_copy, 0.7, gt_im, 0.3, 0, gt_im)
        cv2.imwrite("{}/output_frcnn_gt_{}.png".format(output_dir, sample_count), gt_im)

        # Getting predictions from trained model
        frcnn_output = faster_rcnn_model(im, None)[0]
        boxes = frcnn_output["boxes"]
        labels = frcnn_output["labels"]
        scores = frcnn_output["scores"]
        thetas = frcnn_output["thetas"]
        im = cv2.imread(fname)
        im_copy = im.copy()

        # Saving images with predicted boxes
        for idx, box in enumerate(boxes):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            theta = thetas[idx].detach().cpu().numpy() * 180 / np.pi
            cx, cy, w, h = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
            box = cv2.boxPoints(((cx, cy), (w, h), theta))
            box = box.astype(np.int32)
            cv2.drawContours(im, [box], 0, (0, 255, 0), 2)
            cv2.drawContours(im_copy, [box], 0, (0, 255, 0), 2)
        cv2.addWeighted(im_copy, 0.7, im, 0.3, 0, im)
        cv2.imwrite("{}/output_frcnn_{}.jpg".format(output_dir, sample_count), im)
319
+
320
+
321
if __name__ == "__main__":

    def _str2bool(value):
        """Parse a textual boolean from the command line.

        argparse's ``type=bool`` treats any non-empty string — including
        "False" — as True; this parses common spellings explicitly.
        """
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() in ("1", "true", "yes", "y")

    parser = argparse.ArgumentParser(
        description="Arguments for inference using torchvision code faster rcnn"
    )
    parser.add_argument(
        "--config", dest="config_path", default="config/st.yaml", type=str
    )
    parser.add_argument("--evaluate", dest="evaluate", default=False, type=_str2bool)
    parser.add_argument(
        "--infer_samples", dest="infer_samples", default=True, type=_str2bool
    )
    args = parser.parse_args()
    # Hard-coded run settings (not exposed as CLI flags).
    args.split_type = "train"
    args.num_theta_bins = 359
    if args.infer_samples:
        infer(args)
    else:
        print("Not Inferring for samples as `infer_samples` argument is False")

    if args.evaluate:
        evaluate_metrics(args)
    else:
        print("Not Evaluating as `evaluate` argument is False")
requirements.txt ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets==3.4.0
2
+ deepdiff==8.4.1
3
+ dill==0.3.8
4
+ distlib==0.3.9
5
+ docker-pycreds==0.4.0
6
+ evaluate==0.4.3
7
+ fastapi==0.115.12
8
+ fastrlock==0.8.3
9
+ ffmpy==0.5.0
10
+ filelock==3.17.0
11
+ flatbuffers==25.2.10
12
+ fonttools==4.56.0
13
+ frozenlist==1.5.0
14
+ fsspec==2024.12.0
15
+ gast==0.6.0
16
+ gensim==4.3.3
17
+ gitdb==4.0.12
18
+ gradio==5.25.2
19
+ gradio_client==1.8.0
20
+ groovy==0.1.2
21
+ grpcio==1.71.0
22
+ h11==0.14.0
23
+ h5py==3.13.0
24
+ httpcore==1.0.8
25
+ httpx==0.28.1
26
+ huggingface-hub==0.29.3
27
+ humanfriendly==10.0
28
+ idna==3.10
29
+ imageio==2.37.0
30
+ Jinja2==3.1.6
31
+ joblib==1.4.2
32
+ jsonlines==4.0.0
33
+ kiwisolver==1.4.8
34
+ libclang==18.1.1
35
+ lightning-utilities==0.14.3
36
+ lm_eval==0.4.8
37
+ lxml==5.3.1
38
+ Markdown==3.7
39
+ markdown-it-py==3.0.0
40
+ MarkupSafe==3.0.2
41
+ matplotlib==3.10.1
42
+ mbstrdecoder==1.1.4
43
+ mdurl==0.1.2
44
+ ml_dtypes==0.5.1
45
+ more-itertools==10.6.0
46
+ mpmath==1.3.0
47
+ multidict==6.1.0
48
+ multiprocess==0.70.16
49
+ namex==0.0.8
50
+ networkx==3.4.2
51
+ nltk==3.9.1
52
+ numexpr==2.10.2
53
+ numpy==1.26.4
54
+ opencv-python==4.11.0.86
55
+ opt_einsum==3.4.0
56
+ optree==0.14.1
57
+ orderly-set==5.3.0
58
+ orjson==3.10.16
59
+ pandas==2.2.3
60
+ pathvalidate==3.2.3
61
+ peft==0.14.0
62
+ pillow==11.1.0
63
+ pluggy==1.5.0
64
+ portalocker==3.1.1
65
+ propcache==0.3.0
66
+ protobuf==5.29.3
67
+ pyahocorasick==2.1.0
68
+ pyarrow==19.0.1
69
+ pybind11==2.13.6
70
+ pydantic==2.10.6
71
+ pydantic_core==2.27.2
72
+ pydub==0.25.1
73
+ pyparsing==3.2.1
74
+ pyproject-api==1.9.0
75
+ pytablewriter==1.2.1
76
+ python-multipart==0.0.20
77
+ pytz==2025.1
78
+ PyYAML==6.0.2
79
+ regex==2024.11.6
80
+ requests==2.32.3
81
+ rich==13.9.4
82
+ rootpath==0.1.1
83
+ rouge_score==0.1.2
84
+ ruff==0.11.5
85
+ sacrebleu==2.5.1
86
+ safehttpx==0.1.6
87
+ safetensors==0.5.3
88
+ scikit-learn==1.6.1
89
+ scipy==1.13.1
90
+ seaborn==0.13.2
91
+ semantic-version==2.10.0
92
+ sentencepiece==0.2.0
93
+ sentry-sdk==2.22.0
94
+ setproctitle==1.3.5
95
+ shapely==2.0.7
96
+ shellingham==1.5.4
97
+ smart-open==7.1.0
98
+ smmap==5.0.2
99
+ sniffio==1.3.1
100
+ sns==0.1
101
+ torch==2.6.0
102
+ torchaudio==2.6.0
103
+ torchmetrics==1.7.1
104
+ torchvision==0.21.0
105
+ tqdm==4.67.1
st/tv_frcnn_r50fpn_faster_rcnn_st.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df8a754ab48d6c49503dc74fc9840a205d8b11561c2c8f68f736ad1c39a810dc
3
+ size 167210987
st/tv_frcnn_r50fpn_faster_rcnn_st_10.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c4bf291c6993de1c5d9d52da604551de7c6b5f6d499db05da7da29c4230bb4b
3
+ size 165780139