Spaces:
Sleeping
Sleeping
uploaded the weights
Browse files- detection/__init__.py +1 -0
- detection/__pycache__/__init__.cpython-311.pyc +0 -0
- detection/__pycache__/__init__.cpython-37.pyc +0 -0
- detection/__pycache__/_utils.cpython-311.pyc +0 -0
- detection/__pycache__/_utils.cpython-37.pyc +0 -0
- detection/__pycache__/anchor_utils.cpython-311.pyc +0 -0
- detection/__pycache__/anchor_utils.cpython-37.pyc +0 -0
- detection/__pycache__/backbone_utils.cpython-311.pyc +0 -0
- detection/__pycache__/backbone_utils.cpython-37.pyc +0 -0
- detection/__pycache__/faster_rcnn.cpython-311.pyc +0 -0
- detection/__pycache__/faster_rcnn.cpython-37.pyc +0 -0
- detection/__pycache__/generalized_rcnn.cpython-311.pyc +0 -0
- detection/__pycache__/generalized_rcnn.cpython-37.pyc +0 -0
- detection/__pycache__/image_list.cpython-311.pyc +0 -0
- detection/__pycache__/image_list.cpython-37.pyc +0 -0
- detection/__pycache__/roi_heads.cpython-311.pyc +0 -0
- detection/__pycache__/roi_heads.cpython-37.pyc +0 -0
- detection/__pycache__/rpn.cpython-311.pyc +0 -0
- detection/__pycache__/rpn.cpython-37.pyc +0 -0
- detection/__pycache__/transform.cpython-311.pyc +0 -0
- detection/__pycache__/transform.cpython-37.pyc +0 -0
- detection/_utils.py +540 -0
- detection/anchor_utils.py +268 -0
- detection/backbone_utils.py +121 -0
- detection/faster_rcnn.py +390 -0
- detection/generalized_rcnn.py +128 -0
- detection/image_list.py +25 -0
- detection/roi_heads.py +400 -0
- detection/rpn.py +385 -0
- detection/transform.py +318 -0
- infer.py +346 -0
- requirements.txt +105 -0
- st/tv_frcnn_r50fpn_faster_rcnn_st.pth +3 -0
- st/tv_frcnn_r50fpn_faster_rcnn_st_10.pth +3 -0
detection/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .faster_rcnn import *
|
detection/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (226 Bytes). View file
|
|
|
detection/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (169 Bytes). View file
|
|
|
detection/__pycache__/_utils.cpython-311.pyc
ADDED
|
Binary file (28.3 kB). View file
|
|
|
detection/__pycache__/_utils.cpython-37.pyc
ADDED
|
Binary file (18.2 kB). View file
|
|
|
detection/__pycache__/anchor_utils.cpython-311.pyc
ADDED
|
Binary file (18.5 kB). View file
|
|
|
detection/__pycache__/anchor_utils.cpython-37.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
detection/__pycache__/backbone_utils.cpython-311.pyc
ADDED
|
Binary file (6.98 kB). View file
|
|
|
detection/__pycache__/backbone_utils.cpython-37.pyc
ADDED
|
Binary file (4.19 kB). View file
|
|
|
detection/__pycache__/faster_rcnn.cpython-311.pyc
ADDED
|
Binary file (20.1 kB). View file
|
|
|
detection/__pycache__/faster_rcnn.cpython-37.pyc
ADDED
|
Binary file (14.4 kB). View file
|
|
|
detection/__pycache__/generalized_rcnn.cpython-311.pyc
ADDED
|
Binary file (6.28 kB). View file
|
|
|
detection/__pycache__/generalized_rcnn.cpython-37.pyc
ADDED
|
Binary file (3.74 kB). View file
|
|
|
detection/__pycache__/image_list.cpython-311.pyc
ADDED
|
Binary file (1.63 kB). View file
|
|
|
detection/__pycache__/image_list.cpython-37.pyc
ADDED
|
Binary file (1.2 kB). View file
|
|
|
detection/__pycache__/roi_heads.cpython-311.pyc
ADDED
|
Binary file (17.9 kB). View file
|
|
|
detection/__pycache__/roi_heads.cpython-37.pyc
ADDED
|
Binary file (9.31 kB). View file
|
|
|
detection/__pycache__/rpn.cpython-311.pyc
ADDED
|
Binary file (20.1 kB). View file
|
|
|
detection/__pycache__/rpn.cpython-37.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
detection/__pycache__/transform.cpython-311.pyc
ADDED
|
Binary file (19.5 kB). View file
|
|
|
detection/__pycache__/transform.cpython-37.pyc
ADDED
|
Binary file (9.92 kB). View file
|
|
|
detection/_utils.py
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from typing import Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn, Tensor
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BalancedPositiveNegativeSampler:
    """
    Sample a fixed-size batch of examples per image with a target positive fraction.

    Given per-image match labels, draws up to ``batch_size_per_image`` examples,
    aiming for ``positive_fraction`` of them to be positives and filling the
    remainder with negatives.
    """

    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
        """
        Args:
            batch_size_per_image (int): number of elements to be selected per image
            positive_fraction (float): percentage of positive elements per batch
        """
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction

    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Args:
            matched_idxs: list of tensors (one per image) containing -1, 0 or
                positive values. -1 entries are ignored, 0 entries are treated
                as negatives and values >= 1 as positives.

        Returns:
            pos_idx (list[tensor]), neg_idx (list[tensor]): two lists of uint8
            binary masks (one per image) marking, respectively, the sampled
            positive and negative elements.
        """
        pos_masks: List[Tensor] = []
        neg_masks: List[Tensor] = []
        for labels in matched_idxs:
            pos_candidates = torch.where(labels >= 1)[0]
            neg_candidates = torch.where(labels == 0)[0]

            # Cap positives by both the requested fraction and the number
            # actually available; negatives fill whatever remains of the batch.
            num_pos = min(
                pos_candidates.numel(),
                int(self.batch_size_per_image * self.positive_fraction),
            )
            num_neg = min(neg_candidates.numel(), self.batch_size_per_image - num_pos)

            # Draw a random subset of each candidate pool.
            pos_perm = torch.randperm(pos_candidates.numel(), device=pos_candidates.device)[:num_pos]
            neg_perm = torch.randperm(neg_candidates.numel(), device=neg_candidates.device)[:num_neg]
            pos_pick = pos_candidates[pos_perm]
            neg_pick = neg_candidates[neg_perm]

            # Turn the selected indices into binary masks.
            pos_mask = torch.zeros_like(labels, dtype=torch.uint8)
            neg_mask = torch.zeros_like(labels, dtype=torch.uint8)
            pos_mask[pos_pick] = 1
            neg_mask[neg_pick] = 1

            pos_masks.append(pos_mask)
            neg_masks.append(neg_mask)

        return pos_masks, neg_masks
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@torch.jit._script_if_tracing
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
    """
    Encode a set of proposals with respect to some reference boxes.

    Args:
        reference_boxes (Tensor): reference (ground-truth) boxes
        proposals (Tensor): boxes to be encoded
        weights (Tensor[4]): the weights for ``(x, y, w, h)``
    """
    # Unpack the weights individually so the JIT fuser can specialize on them.
    wx = weights[0]
    wy = weights[1]
    ww = weights[2]
    wh = weights[3]

    ex_x1 = proposals[:, 0].unsqueeze(1)
    ex_y1 = proposals[:, 1].unsqueeze(1)
    ex_x2 = proposals[:, 2].unsqueeze(1)
    ex_y2 = proposals[:, 3].unsqueeze(1)

    gt_x1 = reference_boxes[:, 0].unsqueeze(1)
    gt_y1 = reference_boxes[:, 1].unsqueeze(1)
    gt_x2 = reference_boxes[:, 2].unsqueeze(1)
    gt_y2 = reference_boxes[:, 3].unsqueeze(1)

    # Proposal (anchor) sizes and centers.
    ex_w = ex_x2 - ex_x1
    ex_h = ex_y2 - ex_y1
    ex_cx = ex_x1 + 0.5 * ex_w
    ex_cy = ex_y1 + 0.5 * ex_h

    # Ground-truth sizes and centers.
    gt_w = gt_x2 - gt_x1
    gt_h = gt_y2 - gt_y1
    gt_cx = gt_x1 + 0.5 * gt_w
    gt_cy = gt_y1 + 0.5 * gt_h

    # Weighted deltas: center offsets normalized by anchor size, log-scale sizes.
    dx = wx * (gt_cx - ex_cx) / ex_w
    dy = wy * (gt_cy - ex_cy) / ex_h
    dw = ww * torch.log(gt_w / ex_w)
    dh = wh * torch.log(gt_h / ex_h)

    return torch.cat((dx, dy, dw, dh), dim=1)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class BoxCoder:
    """
    Encode/decode bounding boxes to and from the relative-offset
    representation used when training the box regressors.
    """

    def __init__(
        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
    ) -> None:
        """
        Args:
            weights (4-element tuple): scaling factors for the (x, y, w, h) deltas
            bbox_xform_clip (float): upper clamp applied to dw/dh before exp()
        """
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip

    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
        """Encode per-image lists of boxes, returning per-image target tensors."""
        counts = [len(b) for b in reference_boxes]
        targets = self.encode_single(torch.cat(reference_boxes, dim=0), torch.cat(proposals, dim=0))
        return targets.split(counts, 0)

    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some reference boxes.

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded
        """
        weights = torch.as_tensor(self.weights, dtype=reference_boxes.dtype, device=reference_boxes.device)
        return encode_boxes(reference_boxes, proposals, weights)

    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
        """Decode flattened relative codes against per-image reference boxes."""
        torch._assert(
            isinstance(boxes, (list, tuple)),
            "This function expects boxes of type list or tuple.",
        )
        torch._assert(
            isinstance(rel_codes, torch.Tensor),
            "This function expects rel_codes of type torch.Tensor.",
        )
        counts = [b.size(0) for b in boxes]
        concat_boxes = torch.cat(boxes, dim=0)
        total = sum(counts)
        if total > 0:
            rel_codes = rel_codes.reshape(total, -1)
        pred_boxes = self.decode_single(rel_codes, concat_boxes)
        if total > 0:
            # One row per box, one (x1, y1, x2, y2) quadruple per class column.
            pred_boxes = pred_boxes.reshape(total, -1, 4)
        return pred_boxes

    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of reference boxes and encoded relative offsets,
        recover the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.
        """
        boxes = boxes.to(rel_codes.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.weights
        dx = rel_codes[:, 0::4] / wx
        dy = rel_codes[:, 1::4] / wy
        # Clamp the size deltas so torch.exp() cannot overflow.
        dw = torch.clamp(rel_codes[:, 2::4] / ww, max=self.bbox_xform_clip)
        dh = torch.clamp(rel_codes[:, 3::4] / wh, max=self.bbox_xform_clip)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        # Half-sizes: distance from each predicted center to its corners.
        half_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
        half_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w

        corners = (
            pred_ctr_x - half_w,
            pred_ctr_y - half_h,
            pred_ctr_x + half_w,
            pred_ctr_y + half_h,
        )
        return torch.stack(corners, dim=2).flatten(1)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class BoxLinearCoder:
    """
    The linear box-to-box transform defined in FCOS: a box is parameterized
    by the distances from the center of the (square) src box to the four
    edges of the target box.
    """

    def __init__(self, normalize_by_size: bool = True) -> None:
        """
        Args:
            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
        """
        self.normalize_by_size = normalize_by_size

    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some reference boxes.

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded

        Returns:
            Tensor: the encoded relative box offsets that can be used to
            decode the boxes.
        """
        # Centers of the reference boxes.
        ref_cx = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
        ref_cy = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])

        # Distances from the reference center to each proposal edge (l, t, r, b).
        deltas = torch.stack(
            (
                ref_cx - proposals[..., 0],
                ref_cy - proposals[..., 1],
                proposals[..., 2] - ref_cx,
                proposals[..., 3] - ref_cy,
            ),
            dim=-1,
        )

        if self.normalize_by_size:
            ref_w = reference_boxes[..., 2] - reference_boxes[..., 0]
            ref_h = reference_boxes[..., 3] - reference_boxes[..., 1]
            deltas = deltas / torch.stack((ref_w, ref_h, ref_w, ref_h), dim=-1)
        return deltas

    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.

        Returns:
            Tensor: the predicted boxes with the encoded relative box offsets.

        .. note::
            This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.
        """
        boxes = boxes.to(dtype=rel_codes.dtype)

        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])

        if self.normalize_by_size:
            # Undo the per-anchor normalization applied in encode().
            box_w = boxes[..., 2] - boxes[..., 0]
            box_h = boxes[..., 3] - boxes[..., 1]
            rel_codes = rel_codes * torch.stack((box_w, box_h, box_w, box_h), dim=-1)

        return torch.stack(
            (
                ctr_x - rel_codes[..., 0],
                ctr_y - rel_codes[..., 1],
                ctr_x + rel_codes[..., 2],
                ctr_y + rel_codes[..., 3],
            ),
            dim=-1,
        )
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class Matcher:
    """
    Assigns each predicted "element" (e.g., a box) to at most one ground-truth
    element based on the MxN match_quality_matrix (e.g., box IoU values). Each
    prediction gets exactly zero or one matches; a ground-truth element may be
    matched to any number of predictions.

    The result is an N tensor whose i-th entry is the index of the matched
    ground-truth element, or a negative sentinel when prediction i could not
    be matched.
    """

    BELOW_LOW_THRESHOLD = -1
    BETWEEN_THRESHOLDS = -2

    __annotations__ = {
        "BELOW_LOW_THRESHOLD": int,
        "BETWEEN_THRESHOLDS": int,
    }

    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
        """
        Args:
            high_threshold (float): quality values greater than or equal to
                this value are candidate matches.
            low_threshold (float): a lower quality threshold used to stratify
                matches into three levels:
                1) matches >= high_threshold
                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions that have only low-quality match candidates. See
                set_low_quality_matches_ for more details.
        """
        self.BELOW_LOW_THRESHOLD = -1
        self.BETWEEN_THRESHOLDS = -2
        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted elements.

        Returns:
            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
            [0, M - 1] or a negative value indicating that prediction i could not
            be matched.
        """
        if match_quality_matrix.numel() == 0:
            # Empty targets or proposals are not supported during training.
            if match_quality_matrix.shape[0] == 0:
                raise ValueError("No ground-truth boxes available for one of the images during training")
            raise ValueError("No proposal boxes available for one of the images during training")

        # match_quality_matrix is M (gt) x N (predicted):
        # take the best gt candidate (max over dim 0) for every prediction.
        matched_vals, matches = match_quality_matrix.max(dim=0)
        if self.allow_low_quality_matches:
            # Keep the pre-thresholding assignment for later restoration.
            all_matches = matches.clone()
        else:
            all_matches = None  # type: ignore[assignment]

        # Demote low-quality candidates to the negative sentinels.
        too_low = matched_vals < self.low_threshold
        in_between = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
        matches[too_low] = self.BELOW_LOW_THRESHOLD
        matches[in_between] = self.BETWEEN_THRESHOLDS

        if self.allow_low_quality_matches:
            if all_matches is None:
                torch._assert(False, "all_matches should not be None")
            else:
                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        return matches

    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
        """
        Produce additional matches for predictions that have only low-quality
        matches: for each ground-truth, every prediction tied for the maximum
        overlap with it is (re-)matched to its own best ground-truth, even if
        that overlap is below the thresholds.
        """
        # Highest quality each ground-truth achieves over all predictions.
        best_per_gt, _ = match_quality_matrix.max(dim=1)
        # All (gt, pred) pairs attaining that per-gt maximum, ties included.
        # e.g. (tensor([0, 1, 1, ...]), tensor([39796, 32055, 32070, ...]))
        # where gt items with several entries have tied best predictions.
        tied_pairs = torch.where(match_quality_matrix == best_per_gt[:, None])
        # Restore the pre-thresholding assignment for those predictions.
        preds_to_update = tied_pairs[1]
        matches[preds_to_update] = all_matches[preds_to_update]
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
class SSDMatcher(Matcher):
    """Matcher used by SSD: a single IoU threshold, plus a guarantee that
    every ground-truth box is matched to the prediction it overlaps most."""

    def __init__(self, threshold: float) -> None:
        super().__init__(threshold, threshold, allow_low_quality_matches=False)

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        matches = super().__call__(match_quality_matrix)

        # Force-match each ground-truth to its highest-quality prediction.
        _, best_pred_per_gt = match_quality_matrix.max(dim=1)
        gt_indices = torch.arange(
            best_pred_per_gt.size(0), dtype=torch.int64, device=best_pred_per_gt.device
        )
        matches[best_pred_per_gt] = gt_indices

        return matches
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def overwrite_eps(model: nn.Module, eps: float) -> None:
    """
    Overwrite the default eps value of every FrozenBatchNorm2d layer of the
    model with the provided value.

    This is necessary to address the BC-breaking change introduced by the
    bug-fix at pytorch/vision#2933; the overwrite is applied only when
    pretrained weights are loaded, to maintain compatibility with previous
    versions.

    Args:
        model (nn.Module): The model on which we perform the overwrite.
        eps (float): The new value of eps.
    """
    frozen_bns = (m for m in model.modules() if isinstance(m, FrozenBatchNorm2d))
    for bn in frozen_bns:
        bn.eps = eps
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
    """
    Estimate the number of output channels of a model via a dummy forward pass.

    Args:
        model (nn.Module): The model for which we estimate the out_channels.
            It should return a single Tensor or an OrderedDict[Tensor].
        size (Tuple[int, int]): The size (wxh) of the input.

    Returns:
        out_channels (List[int]): A list of the output channels of the model.
    """
    was_training = model.training
    model.eval()

    with torch.no_grad():
        # A zero image is enough to discover the feature-map channel counts
        # without hard-coding them.
        device = next(model.parameters()).device
        dummy = torch.zeros((1, 3, size[1], size[0]), device=device)
        features = model(dummy)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])
        out_channels = [feature_map.size(1) for feature_map in features.values()]

    # Restore the caller's training/eval mode.
    if was_training:
        model.train()

    return out_channels
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
@torch.jit.unused
|
| 485 |
+
def _fake_cast_onnx(v: Tensor) -> int:
|
| 486 |
+
return v # type: ignore[return-value]
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
|
| 490 |
+
"""
|
| 491 |
+
ONNX spec requires the k-value to be less than or equal to the number of inputs along
|
| 492 |
+
provided dim. Certain models use the number of elements along a particular axis instead of K
|
| 493 |
+
if K exceeds the number of elements along that axis. Previously, python's min() function was
|
| 494 |
+
used to determine whether to use the provided k-value or the specified dim axis value.
|
| 495 |
+
|
| 496 |
+
However, in cases where the model is being exported in tracing mode, python min() is
|
| 497 |
+
static causing the model to be traced incorrectly and eventually fail at the topk node.
|
| 498 |
+
In order to avoid this situation, in tracing mode, torch.min() is used instead.
|
| 499 |
+
|
| 500 |
+
Args:
|
| 501 |
+
input (Tensor): The original input tensor.
|
| 502 |
+
orig_kval (int): The provided k-value.
|
| 503 |
+
axis(int): Axis along which we retrieve the input size.
|
| 504 |
+
|
| 505 |
+
Returns:
|
| 506 |
+
min_kval (int): Appropriately selected k-value.
|
| 507 |
+
"""
|
| 508 |
+
if not torch.jit.is_tracing():
|
| 509 |
+
return min(orig_kval, input.size(axis))
|
| 510 |
+
axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
|
| 511 |
+
min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
|
| 512 |
+
return _fake_cast_onnx(min_kval)
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def _box_loss(
    type: str,
    box_coder: BoxCoder,
    anchors_per_image: Tensor,
    matched_gt_boxes_per_image: Tensor,
    bbox_regression_per_image: Tensor,
    cnf: Optional[Dict[str, float]] = None,
) -> Tensor:
    """
    Compute the box-regression loss selected by ``type``.

    "l1" and "smooth_l1" compare the raw regression output against encoded
    targets; "ciou", "diou" and "giou" decode the regression output first and
    compare boxes directly. ``cnf`` may optionally provide "beta" (smooth_l1)
    or "eps" (IoU-based losses).
    """
    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")

    if type in ("l1", "smooth_l1"):
        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        if type == "l1":
            return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)

    # IoU-family losses operate on decoded boxes.
    decoded_boxes = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
    eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
    if type == "ciou":
        return complete_box_iou_loss(decoded_boxes, matched_gt_boxes_per_image, reduction="sum", eps=eps)
    if type == "diou":
        return distance_box_iou_loss(decoded_boxes, matched_gt_boxes_per_image, reduction="sum", eps=eps)
    # Remaining case: "giou".
    return generalized_box_iou_loss(decoded_boxes, matched_gt_boxes_per_image, reduction="sum", eps=eps)
|
detection/anchor_utils.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn, Tensor
|
| 6 |
+
|
| 7 |
+
from .image_list import ImageList
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AnchorGenerator(nn.Module):
    """
    Module that generates anchors for a set of feature maps and
    image sizes.

    The module support computing anchors at multiple sizes and aspect ratios
    per feature map. This module assumes aspect ratio = height / width for
    each anchor.

    sizes and aspect_ratios should have the same number of elements, and it should
    correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
    and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
    per spatial location for feature map i.

    Args:
        sizes (Tuple[Tuple[int]]):
        aspect_ratios (Tuple[Tuple[float]]):
    """

    __annotations__ = {
        "cell_anchors": List[torch.Tensor],
    }

    def __init__(
        self,
        sizes=((128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    ):
        super().__init__()

        # Normalize flat inputs into the nested per-feature-map form.
        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        # One zero-centered base-anchor tensor per feature map, each of shape
        # (len(sizes[i]) * len(aspect_ratios[i]), 4).
        self.cell_anchors = [
            self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios)
        ]

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
    def generate_anchors(
        self,
        scales: List[int],
        aspect_ratios: List[float],
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
    ) -> Tensor:
        """Return zero-centered [x1, y1, x2, y2] base anchors, rounded to ints."""
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        # With ratio = h/w and area = scale**2: h = scale*sqrt(ratio), w = scale/sqrt(ratio).
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
        """Move the cached base anchors to the given dtype/device."""
        self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]

    def num_anchors_per_location(self) -> List[int]:
        """Number of anchors per spatial location, one entry per feature map."""
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
    def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
        """Tile the base anchors over each feature-map grid.

        Returns one (grid_h * grid_w * A, 4) tensor per feature map, in image
        coordinates.
        """
        anchors = []
        cell_anchors = self.cell_anchors
        torch._assert(cell_anchors is not None, "cell_anchors should not be None")
        torch._assert(
            len(grid_sizes) == len(strides) == len(cell_anchors),
            "Anchors should be Tuple[Tuple[int]] because each feature "
            "map could potentially have different sizes and aspect ratios. "
            "There needs to be a match between the number of "
            "feature maps passed and the number of sizes / aspect ratios specified.",
        )

        for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # For output anchor, compute [x_center, y_center, x_center, y_center]
            shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
            shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
            anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))

        return anchors

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        """Return one (total_anchors, 4) tensor per image in the batch.

        All images share the same anchors since anchors depend only on the
        padded batch shape and the feature-map grid sizes.
        """
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        # Stride of each feature map relative to the input, as 0-dim tensors
        # (kept as tensors for scripting/tracing friendliness).
        strides = [
            [
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
            ]
            for g in grid_sizes
        ]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
        anchors: List[List[torch.Tensor]] = []
        for _ in range(len(image_list.image_sizes)):
            anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
            anchors.append(anchors_in_image)
        # Concatenate across feature maps so each image gets a single tensor.
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        return anchors
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class DefaultBoxGenerator(nn.Module):
    r"""
    This module generates the default boxes of SSD for a set of feature maps and image sizes.

    Args:
        aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
        min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        scales (List[float]], optional): The scales of the default boxes. If not provided it will be estimated using
            the ``min_ratio`` and ``max_ratio`` parameters.
        steps (List[int]], optional): It's a hyper-parameter that affects the tiling of default boxes. If not provided
            it will be estimated from the data.
        clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
            is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
    """

    def __init__(
        self,
        aspect_ratios: List[List[int]],
        min_ratio: float = 0.15,
        max_ratio: float = 0.9,
        scales: Optional[List[float]] = None,
        steps: Optional[List[int]] = None,
        clip: bool = True,
    ):
        super().__init__()
        if steps is not None and len(aspect_ratios) != len(steps):
            raise ValueError("aspect_ratios and steps should have the same length")
        self.aspect_ratios = aspect_ratios
        self.steps = steps
        self.clip = clip
        num_outputs = len(aspect_ratios)

        # Estimation of default boxes scales
        if scales is None:
            if num_outputs > 1:
                # Linearly interpolate num_outputs scales in [min_ratio, max_ratio];
                # the trailing 1.0 is only used for s'_k of the last feature map.
                range_ratio = max_ratio - min_ratio
                self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
                self.scales.append(1.0)
            else:
                self.scales = [min_ratio, max_ratio]
        else:
            self.scales = scales

        self._wh_pairs = self._generate_wh_pairs(num_outputs)

    def _generate_wh_pairs(
        self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")
    ) -> List[Tensor]:
        """Precompute the normalized (w, h) pairs of the default boxes per feature map."""
        _wh_pairs: List[Tensor] = []
        for k in range(num_outputs):
            # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
            s_k = self.scales[k]
            s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
            wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]

            # Adding 2 pairs for each aspect ratio of the feature map k
            for ar in self.aspect_ratios[k]:
                sq_ar = math.sqrt(ar)
                w = self.scales[k] * sq_ar
                h = self.scales[k] / sq_ar
                wh_pairs.extend([[w, h], [h, w]])

            _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
        return _wh_pairs

    def num_anchors_per_location(self) -> List[int]:
        # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map.
        return [2 + 2 * len(r) for r in self.aspect_ratios]

    # Default Boxes calculation based on page 6 of SSD paper
    def _grid_default_boxes(
        self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32
    ) -> Tensor:
        """Tile the default boxes over every grid cell of every feature map.

        Returns a single (total_boxes, 4) tensor in normalized (cx, cy, w, h)
        format, concatenated across feature maps.
        """
        default_boxes = []
        for k, f_k in enumerate(grid_sizes):
            # Now add the default boxes for each width-height pair
            if self.steps is not None:
                x_f_k = image_size[1] / self.steps[k]
                y_f_k = image_size[0] / self.steps[k]
            else:
                y_f_k, x_f_k = f_k

            # Box centers at normalized (cell + 0.5) / grid positions.
            shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
            shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)

            shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
            # Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
            _wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
            wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)

            default_box = torch.cat((shifts, wh_pairs), dim=1)

            default_boxes.append(default_box)

        return torch.cat(default_boxes, dim=0)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"aspect_ratios={self.aspect_ratios}"
            f", clip={self.clip}"
            f", scales={self.scales}"
            f", steps={self.steps}"
            ")"
        )
        return s

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        """Return one (total_boxes, 4) tensor of [x1, y1, x2, y2] boxes per image.

        The same default boxes are produced for every image, scaled from
        normalized (cx, cy, w, h) to absolute corner coordinates using the
        padded batch image size.
        """
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
        default_boxes = default_boxes.to(device)

        dboxes = []
        x_y_size = torch.tensor([image_size[1], image_size[0]], device=default_boxes.device)
        for _ in image_list.image_sizes:
            dboxes_in_image = default_boxes
            # (cx, cy, w, h) -> (x1, y1, x2, y2) in absolute pixel coordinates.
            dboxes_in_image = torch.cat(
                [
                    (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                    (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                ],
                -1,
            )
            dboxes.append(dboxes_in_image)
        return dboxes
|
detection/backbone_utils.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Callable, Dict, List, Optional, Union
|
| 3 |
+
|
| 4 |
+
from torch import nn, Tensor
|
| 5 |
+
from torchvision.ops import misc as misc_nn_ops
|
| 6 |
+
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
|
| 7 |
+
|
| 8 |
+
from torchvision.models import resnet
|
| 9 |
+
from torchvision.models._api import _get_enum_from_fn, WeightsEnum
|
| 10 |
+
from torchvision.models._utils import handle_legacy_interface, IntermediateLayerGetter
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BackboneWithFPN(nn.Module):
    """
    Adds a FeaturePyramidNetwork on top of a backbone.

    Internally uses IntermediateLayerGetter to extract the feature maps named
    in ``return_layers`` from ``backbone``, then feeds them to the FPN.

    Args:
        backbone (nn.Module): the network whose intermediate outputs are used.
        return_layers (Dict[str, str]): maps module names inside ``backbone``
            to the keys of the returned feature dict.
        in_channels_list (List[int]): number of channels of each feature map
            passed to the FPN (same order as ``return_layers``).
        out_channels (int): number of channels of every FPN output map; also
            exposed as ``self.out_channels`` for downstream heads.
        extra_blocks (ExtraFPNBlock, optional): extra FPN operations; defaults
            to LastLevelMaxPool when None.
        norm_layer (callable, optional): normalization layer passed to the FPN.
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the backbone, then the FPN; returns a dict of feature maps."""
        x = self.body(x)
        x = self.fpn(x)
        return x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def resnet_fpn_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 3,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:
    """Build a ResNet-FPN backbone by name.

    Looks up ``backbone_name`` (e.g. "resnet50") in ``torchvision.models.resnet``,
    constructs it with the given ``weights`` and ``norm_layer``, and wraps it
    via ``_resnet_fpn_extractor``. The legacy ``pretrained=True`` keyword is
    translated to the IMAGENET1K_V1 weights by the decorator.

    Args:
        backbone_name: name of a ResNet constructor in ``torchvision.models.resnet``.
        weights: pretrained weights to load, or None.
        norm_layer: normalization layer for the ResNet; defaults to FrozenBatchNorm2d.
        trainable_layers: number of trailing ResNet blocks left trainable (0-5).
        returned_layers: ResNet layer indices (1-4) fed to the FPN.
        extra_blocks: optional extra FPN block.

    Returns:
        BackboneWithFPN: the wrapped backbone.
    """
    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
    return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _resnet_fpn_extractor(
    backbone: resnet.ResNet,
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:
    """Freeze part of a ResNet and wrap it with a FeaturePyramidNetwork.

    Args:
        backbone: the ResNet whose intermediate layer outputs feed the FPN.
        trainable_layers: number of trailing blocks left trainable, counted
            from ["layer4", "layer3", "layer2", "layer1", "conv1"]; must be in
            [0, 5]. With 5, "bn1" is trainable as well.
        returned_layers: indices (1-4) of the ResNet layers whose outputs are
            fed to the FPN; defaults to [1, 2, 3, 4].
        extra_blocks: optional extra FPN block; defaults to LastLevelMaxPool.
        norm_layer: optional normalization layer passed to the FPN.

    Returns:
        A BackboneWithFPN producing 256-channel feature maps.

    Raises:
        ValueError: if ``trainable_layers`` or ``returned_layers`` are out of range.
    """
    # select layers that won't be frozen
    if trainable_layers < 0 or trainable_layers > 5:
        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
    if trainable_layers == 5:
        layers_to_train.append("bn1")
    for name, parameter in backbone.named_parameters():
        # Freeze every parameter that does not belong to a trainable layer.
        # (not any(...) avoids the original's materialized all([not ...]) list.)
        if not any(name.startswith(layer) for layer in layers_to_train):
            parameter.requires_grad_(False)

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()

    if returned_layers is None:
        returned_layers = [1, 2, 3, 4]
    if min(returned_layers) <= 0 or max(returned_layers) >= 5:
        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}

    # Channel counts double per stage; assumes backbone.inplanes reflects the
    # final stage's channel count (true for torchvision ResNets) — stage 1 has 1/8 of it.
    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    out_channels = 256
    return BackboneWithFPN(
        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
    )
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _validate_trainable_layers(
|
| 99 |
+
is_trained: bool,
|
| 100 |
+
trainable_backbone_layers: Optional[int],
|
| 101 |
+
max_value: int,
|
| 102 |
+
default_value: int,
|
| 103 |
+
) -> int:
|
| 104 |
+
# don't freeze any layers if pretrained model or backbone is not used
|
| 105 |
+
if not is_trained:
|
| 106 |
+
if trainable_backbone_layers is not None:
|
| 107 |
+
warnings.warn(
|
| 108 |
+
"Changing trainable_backbone_layers has no effect if "
|
| 109 |
+
"neither pretrained nor pretrained_backbone have been set to True, "
|
| 110 |
+
f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
|
| 111 |
+
)
|
| 112 |
+
trainable_backbone_layers = max_value
|
| 113 |
+
|
| 114 |
+
# by default freeze first blocks
|
| 115 |
+
if trainable_backbone_layers is None:
|
| 116 |
+
trainable_backbone_layers = default_value
|
| 117 |
+
if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
|
| 118 |
+
raise ValueError(
|
| 119 |
+
f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
|
| 120 |
+
)
|
| 121 |
+
return trainable_backbone_layers
|
detection/faster_rcnn.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torchvision.ops import MultiScaleRoIAlign
|
| 7 |
+
|
| 8 |
+
from torchvision.ops import misc as misc_nn_ops
|
| 9 |
+
from torchvision.transforms._presets import ObjectDetection
|
| 10 |
+
from torchvision.models._api import register_model, Weights, WeightsEnum
|
| 11 |
+
from torchvision.models._meta import _COCO_CATEGORIES
|
| 12 |
+
from torchvision.models._utils import _ovewrite_value_param, handle_legacy_interface
|
| 13 |
+
from torchvision.models.resnet import resnet50, ResNet50_Weights
|
| 14 |
+
|
| 15 |
+
from ._utils import overwrite_eps
|
| 16 |
+
from .anchor_utils import AnchorGenerator
|
| 17 |
+
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
|
| 18 |
+
from .generalized_rcnn import GeneralizedRCNN
|
| 19 |
+
from .roi_heads import RoIHeads
|
| 20 |
+
from .rpn import RegionProposalNetwork, RPNHead
|
| 21 |
+
from .transform import GeneralizedRCNNTransform
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
__all__ = [
|
| 25 |
+
"FasterRCNN",
|
| 26 |
+
"FasterRCNN_ResNet50_FPN_Weights",
|
| 27 |
+
"fasterrcnn_resnet50_fpn",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _default_anchorgen():
    """Build the default FPN anchor generator: one size per pyramid level
    (32..512), each with aspect ratios 0.5, 1.0 and 2.0."""
    sizes = tuple((s,) for s in (32, 64, 128, 256, 512))
    ratios = ((0.5, 1.0, 2.0),) * len(sizes)
    return AnchorGenerator(sizes, ratios)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class FasterRCNN(GeneralizedRCNN):
|
| 38 |
+
"""
|
| 39 |
+
Implements Faster R-CNN.
|
| 40 |
+
|
| 41 |
+
The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
|
| 42 |
+
image, and should be in 0-1 range. Different images can have different sizes.
|
| 43 |
+
|
| 44 |
+
The behavior of the model changes depending on if it is in training or evaluation mode.
|
| 45 |
+
|
| 46 |
+
During training, the model expects both the input tensors and targets (list of dictionary),
|
| 47 |
+
containing:
|
| 48 |
+
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
|
| 49 |
+
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
|
| 50 |
+
- labels (Int64Tensor[N]): the class label for each ground-truth box
|
| 51 |
+
|
| 52 |
+
The model returns a Dict[Tensor] during training, containing the classification and regression
|
| 53 |
+
losses for both the RPN and the R-CNN.
|
| 54 |
+
|
| 55 |
+
During inference, the model requires only the input tensors, and returns the post-processed
|
| 56 |
+
predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
|
| 57 |
+
follows:
|
| 58 |
+
- boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
|
| 59 |
+
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
|
| 60 |
+
- labels (Int64Tensor[N]): the predicted labels for each image
|
| 61 |
+
- scores (Tensor[N]): the scores or each prediction
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
backbone (nn.Module): the network used to compute the features for the model.
|
| 65 |
+
It should contain an out_channels attribute, which indicates the number of output
|
| 66 |
+
channels that each feature map has (and it should be the same for all feature maps).
|
| 67 |
+
The backbone should return a single Tensor or and OrderedDict[Tensor].
|
| 68 |
+
num_classes (int): number of output classes of the model (including the background).
|
| 69 |
+
If box_predictor is specified, num_classes should be None.
|
| 70 |
+
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
|
| 71 |
+
max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
|
| 72 |
+
image_mean (Tuple[float, float, float]): mean values used for input normalization.
|
| 73 |
+
They are generally the mean values of the dataset on which the backbone has been trained
|
| 74 |
+
on
|
| 75 |
+
image_std (Tuple[float, float, float]): std values used for input normalization.
|
| 76 |
+
They are generally the std values of the dataset on which the backbone has been trained on
|
| 77 |
+
rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
|
| 78 |
+
maps.
|
| 79 |
+
rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
|
| 80 |
+
rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
|
| 81 |
+
rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
|
| 82 |
+
rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
|
| 83 |
+
rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
|
| 84 |
+
rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
|
| 85 |
+
rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
|
| 86 |
+
considered as positive during training of the RPN.
|
| 87 |
+
rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
|
| 88 |
+
considered as negative during training of the RPN.
|
| 89 |
+
rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
|
| 90 |
+
for computing the loss
|
| 91 |
+
rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
|
| 92 |
+
of the RPN
|
| 93 |
+
rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
|
| 94 |
+
box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
|
| 95 |
+
the locations indicated by the bounding boxes
|
| 96 |
+
box_head (nn.Module): module that takes the cropped feature maps as input
|
| 97 |
+
box_predictor (nn.Module): module that takes the output of box_head and returns the
|
| 98 |
+
classification logits and box regression deltas.
|
| 99 |
+
box_score_thresh (float): during inference, only return proposals with a classification score
|
| 100 |
+
greater than box_score_thresh
|
| 101 |
+
box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
|
| 102 |
+
box_detections_per_img (int): maximum number of detections per image, for all classes.
|
| 103 |
+
box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
|
| 104 |
+
considered as positive during training of the classification head
|
| 105 |
+
box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
|
| 106 |
+
considered as negative during training of the classification head
|
| 107 |
+
box_batch_size_per_image (int): number of proposals that are sampled during training of the
|
| 108 |
+
classification head
|
| 109 |
+
box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
|
| 110 |
+
of the classification head
|
| 111 |
+
bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
|
| 112 |
+
bounding boxes
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        **kwargs,
    ):
        """Assemble the transform, RPN, and RoI heads, then delegate to the base class.

        Any component argument left as None is replaced with the standard
        default (anchor generator, RPN head, RoI pooler, TwoMLPHead,
        FastRCNNPredictor, ImageNet normalization statistics).
        """
        # The backbone must expose one channel count shared by all FPN levels.
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        # num_classes and box_predictor are mutually exclusive: a supplied
        # predictor already fixes the number of classes.
        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels

        # Region Proposal Network: fill in any missing components.
        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        # Separate train/test budgets for proposals kept before and after NMS.
        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        # Box branch: RoI pooling -> feature head -> class/box/theta predictor.
        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        # ImageNet normalization statistics by default.
        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models: two fully-connected layers with ReLU.

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()
        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        # Collapse everything past the batch dimension, then apply fc6/fc7.
        flat = x.flatten(start_dim=1)
        hidden = F.relu(self.fc6(flat))
        return F.relu(self.fc7(hidden))
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class FastRCNNConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        layers = []
        prev = in_channels
        # Convolutional stack, each block = conv + (optional) norm + activation.
        for width in conv_layers:
            layers.append(misc_nn_ops.Conv2dNormActivation(prev, width, norm_layer=norm_layer))
            prev = width
        layers.append(nn.Flatten())
        # After flattening, the feature count includes the spatial extent.
        prev = prev * in_height * in_width
        # Fully-connected stack with ReLU after every layer.
        for width in fc_layers:
            layers.append(nn.Linear(prev, width))
            layers.append(nn.ReLU(inplace=True))
            prev = width

        super().__init__(*layers)
        # Kaiming (fan-out, ReLU) init for conv weights; zero biases.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class FastRCNNPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers + theta
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
        num_theta_bins (int): number of theta output bins; the theta head
            emits 1 + num_theta_bins columns.
    """

    def __init__(self, in_channels, num_classes, num_theta_bins=1):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
        self.theta_pred = nn.Linear(in_channels, 1 + num_theta_bins)

    def forward(self, x):
        # A 4-D input must be a pooled 1x1 feature map per RoI.
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        feats = x.flatten(start_dim=1)
        return self.cls_score(feats), self.bbox_pred(feats), self.theta_pred(feats)
|
| 323 |
+
|
| 324 |
+
# Metadata fields shared by every pretrained-weight entry below.
_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
    # Official torchvision COCO checkpoint for Faster R-CNN ResNet-50-FPN.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.0,
                }
            },
            "_ops": 134.38,
            "_file_size": 159.743,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# @register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn(
    *,
    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """Build a Faster R-CNN model with a ResNet-50-FPN backbone (plus theta head).

    Args:
        weights: optional full-model pretrained weights; if given they override
            ``weights_backbone`` and fix ``num_classes`` from the weight metadata.
        progress: show a download progress bar.
        num_classes: number of output classes including background; defaults to
            91 (COCO) when neither weights nor a value are supplied.
        weights_backbone: optional backbone-only pretrained weights.
        trainable_backbone_layers: how many backbone stages to unfreeze.
        **kwargs: forwarded to the ``FasterRCNN`` constructor.
    """
    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights take precedence over backbone-only weights.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    # Pretrained models keep BatchNorm statistics frozen.
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        # strict=False tolerates keys missing from the checkpoint — presumably
        # the theta_pred parameters, which are (re)initialized just below.
        model.load_state_dict(weights.get_state_dict(progress=progress), strict=False)
        torch.nn.init.kaiming_normal_(model.roi_heads.box_predictor.theta_pred.weight, mode="fan_out", nonlinearity="relu")
        torch.nn.init.constant_(model.roi_heads.box_predictor.theta_pred.bias, 0)
        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model
|
detection/generalized_rcnn.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Implements the Generalized R-CNN framework
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import warnings
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn, Tensor
|
| 11 |
+
|
| 12 |
+
from torchvision.utils import _log_api_usage_once
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    Args:
        backbone (nn.Module):
        rpn (nn.Module):
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model
    """

    def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
        # Eager mode: losses while training, detections while evaluating.
        if self.training:
            return losses

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).

        """
        # Training requires targets with well-formed [N, 4] box tensors.
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    if isinstance(boxes, torch.Tensor):
                        torch._assert(
                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                        )
                    else:
                        torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

        # Record pre-transform sizes so detections can be mapped back later.
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:4] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            # Single-feature-map backbones are wrapped to look like an FPN dict.
            features = OrderedDict([("0", features)])

        # modify targets to remove theta for rpn
        # print(f"{len(targets)=}")
        # print(f"{targets[0]=}")
        # targets_rpn = []
        # for target in targets:
        #     target_rpn = target.copy()
        #     target_rpn['boxes'] = target_rpn['boxes'][:, :-1]
        #     targets_rpn.append(target_rpn)
        # print(f"{targets_rpn[0]=}")
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        # Map predicted boxes back to original (pre-resize) image coordinates.
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)
|
detection/image_list.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ImageList:
    """
    Structure that holds a list of images (of possibly
    varying sizes) as a single tensor.
    This works by padding the images to the same size,
    and storing in a field the original sizes of each image

    Args:
        tensors (tensor): Tensor containing images.
        image_sizes (list[tuple[int, int]]): List of Tuples each containing size of images.
    """

    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
        self.tensors = tensors
        self.image_sizes = image_sizes

    def to(self, device: torch.device) -> "ImageList":
        # Return a new ImageList whose tensor lives on `device`;
        # the per-image sizes are device-independent and shared.
        return ImageList(self.tensors.to(device), self.image_sizes)
|
detection/roi_heads.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torchvision
|
| 6 |
+
from torch import nn, Tensor
|
| 7 |
+
from torchvision.ops import boxes as box_ops, roi_align
|
| 8 |
+
|
| 9 |
+
from . import _utils as det_utils
|
| 10 |
+
|
| 11 |
+
def compute_theta_loss(preds, targets):
    """
    Loss for the rotation (theta) head.

    A single output column means direct angle regression (MSE); several
    columns mean the [0, pi) range is split into equal bins and the head is
    trained as a classifier (cross-entropy).

    Args:
        preds (Tensor[N, num_bins]): raw theta predictions.
        targets (Tensor[N]): target angles in radians — assumed to lie in
            [0, pi); TODO confirm with the data pipeline.

    Returns:
        Tensor: scalar loss.
    """
    num_bins = preds.shape[1]
    if num_bins == 1:
        # Regression path: compare the lone predicted angle to the target.
        return F.mse_loss(preds[:, 0], targets)
    # Classification path: quantize each target angle into its bin.
    bin_width = torch.pi / num_bins
    bin_idx = torch.clamp((targets / bin_width).long(), 0, num_bins - 1)
    return F.cross_entropy(preds, bin_idx)
|
| 24 |
+
|
| 25 |
+
def fastrcnn_loss(class_logits, box_regression, theta_preds, labels, regression_targets, theta_targets):
    # type: (Tensor, Tensor, Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.

    Args:
        class_logits (Tensor): per-proposal class scores, [N, num_classes].
        box_regression (Tensor): per-proposal, per-class box deltas, [N, num_classes * 4].
        theta_preds (Tensor): per-proposal theta outputs; column 0 is skipped
            below, the remaining columns feed compute_theta_loss.
        labels (list[BoxList]): per-image matched class labels (0 = background).
        regression_targets (Tensor): per-image encoded box regression targets.
        theta_targets (Tensor): per-image target angles.

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
        theta_loss (Tensor)
    """
    # Flatten the per-image lists into single batch tensors.
    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)
    theta_targets = torch.cat(theta_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    # View deltas as [N, num_classes, 4] so each positive proposal selects
    # the deltas of its own ground-truth class.
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    # Normalize by the total (positive + negative) sample count.
    box_loss = box_loss / labels.numel()

    # theta loss, on positive proposals only; column 0 of theta_preds is skipped.
    # NOTE(review): the explicit .to(...) calls suggest the tensors may live
    # on different devices — confirm against the caller.
    preds = theta_preds[sampled_pos_inds_subset.to(theta_preds.device)][:, 1:]
    targets = theta_targets[sampled_pos_inds_subset.to(theta_targets.device)].to(preds.device)

    theta_loss = compute_theta_loss(preds, targets)

    return classification_loss, box_loss, theta_loss
|
| 74 |
+
|
| 75 |
+
def fastrcnn_theta_loss(theta_preds, theta_targets):
    """Placeholder theta loss that always returns 0.

    The real theta loss is computed inside ``fastrcnn_loss`` via
    ``compute_theta_loss``; this stub is kept for interface compatibility.
    """
    return 0
|
| 81 |
+
|
| 82 |
+
def convert_xyxytheta_to_xywha(boxes):
    """
    Convert boxes from (x1, y1, x2, y2, theta_degrees) to
    (x_center, y_center, width, height, theta_radians).

    Args:
        boxes (Tensor[N, 5]): corner-format boxes with the angle in degrees.

    Returns:
        Tensor[N, 5]: center-format boxes with the angle in radians.
    """
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    center_x = (x1 + x2) / 2
    center_y = (y1 + y2) / 2
    width = x2 - x1
    height = y2 - y1
    angle_rad = torch.deg2rad(boxes[:, 4])
    return torch.stack([center_x, center_y, width, height, angle_rad], dim=1)
|
| 89 |
+
|
| 90 |
+
class RoIHeads(nn.Module):
|
| 91 |
+
__annotations__ = {
|
| 92 |
+
"box_coder": det_utils.BoxCoder,
|
| 93 |
+
"proposal_matcher": det_utils.Matcher,
|
| 94 |
+
"fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
    ):
        """Set up proposal matching/sampling, box coding, and the box branch."""
        super().__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        # Default Faster R-CNN box-delta weights (dx, dy, dw, dh).
        if bbox_reg_weights is None:
            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        # Inference-time filtering parameters.
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
|
| 132 |
+
|
| 133 |
+
    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        """Match every proposal to a ground-truth box and derive its class label.

        Returns per-image matched GT indices (clamped to >= 0) and labels,
        where label 0 marks background and -1 marks proposals the sampler
        should ignore.
        """
        matched_idxs = []
        labels = []
        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):

            if gt_boxes_in_image.numel() == 0:
                # Background image: everything is labeled background (0).
                device = proposals_in_image.device
                clamped_matched_idxs_in_image = torch.zeros(
                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
                )
                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            else:
                # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
                # match_quality_matrix = box_iou_rotated(convert_xyxytheta_to_xywha(gt_boxes_in_image), convert_xyxytheta_to_xywha(proposals_in_image))
                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

                # Clamp the matcher's negative sentinels so indexing into the
                # GT tensors below stays valid.
                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
                labels_in_image = labels_in_image.to(dtype=torch.int64)

                # Label background (below the low threshold)
                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_in_image[bg_inds] = 0

                # Label ignore proposals (between low and high thresholds)
                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

            matched_idxs.append(clamped_matched_idxs_in_image)
            labels.append(labels_in_image)
        return matched_idxs, labels
|
| 170 |
+
|
| 171 |
+
def subsample(self, labels):
|
| 172 |
+
# type: (List[Tensor]) -> List[Tensor]
|
| 173 |
+
sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
|
| 174 |
+
sampled_inds = []
|
| 175 |
+
for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
|
| 176 |
+
img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
|
| 177 |
+
sampled_inds.append(img_sampled_inds)
|
| 178 |
+
return sampled_inds
|
| 179 |
+
|
| 180 |
+
def add_gt_proposals(self, proposals, gt_boxes):
|
| 181 |
+
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
|
| 182 |
+
# print(f"{len(proposals)=}")
|
| 183 |
+
# print(f"{len(gt_boxes)=}")
|
| 184 |
+
# print(f"{proposals[0].shape=}")
|
| 185 |
+
# print(f"{gt_boxes[0].shape=}")
|
| 186 |
+
# print(f"{proposals[0]=}")
|
| 187 |
+
# proposals_with_theta = []
|
| 188 |
+
# for proposal in proposals:
|
| 189 |
+
# proposal_with_theta = torch.cat((proposal, torch.zeros((proposal.shape[0], 1), dtype=proposal.dtype, device=proposal.device)), dim=1)
|
| 190 |
+
# proposals_with_theta.append(proposal_with_theta)
|
| 191 |
+
# gt_boxes_without_theta = [gt_box[:, :-1] for gt_box in gt_boxes]
|
| 192 |
+
# print(f"{len(proposal_with_theta)=}")
|
| 193 |
+
# print(f"{proposals_with_theta[0].shape=}")
|
| 194 |
+
# print(f"{proposals_with_theta[0]=}")
|
| 195 |
+
proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
|
| 196 |
+
|
| 197 |
+
return proposals
|
| 198 |
+
|
| 199 |
+
def check_targets(self, targets):
|
| 200 |
+
# type: (Optional[List[Dict[str, Tensor]]]) -> None
|
| 201 |
+
if targets is None:
|
| 202 |
+
raise ValueError("targets should not be None")
|
| 203 |
+
if not all(["boxes" in t for t in targets]):
|
| 204 |
+
raise ValueError("Every element of targets should have a boxes key")
|
| 205 |
+
if not all(["labels" in t for t in targets]):
|
| 206 |
+
raise ValueError("Every element of targets should have a labels key")
|
| 207 |
+
|
| 208 |
+
    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        targets,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        """Prepare the training mini-batch for the box/theta heads.

        Appends GT boxes to the proposals, matches and subsamples them, and
        builds encoded box regression targets plus per-proposal theta targets.
        """
        self.check_targets(targets)
        if targets is None:
            raise ValueError("targets should not be None")
        dtype = proposals[0].dtype
        device = proposals[0].device

        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["labels"] for t in targets]
        gt_thetas = [t["thetas"] for t in targets]

        # append ground-truth bboxes to propos
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)
        matched_gt_boxes = []
        matched_gt_thetas = []

        num_images = len(proposals)
        for img_id in range(num_images):
            # Keep only the sampled proposals (and their matches/labels).
            img_sampled_inds = sampled_inds[img_id]
            proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]

            gt_boxes_in_image = gt_boxes[img_id]
            gt_thetas_in_image = gt_thetas[img_id]

            if gt_boxes_in_image.numel() == 0:
                # Background-only image: dummy GT box keeps indexing valid.
                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
            # NOTE(review): thetas may live on a different device than the
            # matched indices — confirm against the data pipeline.
            matched_gt_thetas.append(gt_thetas_in_image[matched_idxs[img_id].to(gt_thetas_in_image.device)])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        theta_targets = matched_gt_thetas
        return proposals, matched_idxs, labels, regression_targets, theta_targets
|
| 252 |
+
|
| 253 |
+
def postprocess_detections(
    self,
    class_logits,  # type: Tensor
    box_regression,  # type: Tensor
    theta_preds,  # type: Tensor
    proposals,  # type: List[Tensor]
    image_shapes,  # type: List[Tuple[int, int]]
):
    # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
    """Convert raw box-head outputs into final per-image detections.

    Decodes box regression deltas against the proposals, turns class logits
    into softmax scores, then per image: clips boxes to the image bounds,
    drops the background column, thresholds by score, removes tiny boxes,
    applies per-class NMS and keeps the top ``detections_per_img`` results.

    Returns four parallel lists (one entry per image): boxes, scores,
    labels, thetas.
    """
    device = class_logits.device
    num_classes = class_logits.shape[-1]

    boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
    pred_boxes = self.box_coder.decode(box_regression, proposals)

    pred_scores = F.softmax(class_logits, -1)

    # Split the flat batch predictions back into per-image chunks.
    pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
    pred_scores_list = pred_scores.split(boxes_per_image, 0)
    pred_theta_list = theta_preds.split(boxes_per_image, 0)

    all_boxes = []
    all_scores = []
    all_labels = []
    all_thetas = []
    for boxes, scores, thetas, image_shape in zip(pred_boxes_list, pred_scores_list, pred_theta_list, image_shapes):
        boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

        # create labels for each prediction
        labels = torch.arange(num_classes, device=device)
        labels = labels.view(1, -1).expand_as(scores)

        # remove predictions with the background label (column 0)
        boxes = boxes[:, 1:]
        scores = scores[:, 1:]
        labels = labels[:, 1:]
        thetas = thetas[:, 1:]

        # If theta is predicted as a distribution over angle bins, take the
        # most likely bin and convert it to a continuous angle in [0, pi).
        if thetas.shape[1] != 1:
            nbins = thetas.shape[1]
            angle_per_bin = torch.pi / nbins
            max_val_idx = torch.argmax(thetas, dim=1)
            max_val_theta = angle_per_bin * max_val_idx.float()

            thetas = max_val_theta

        # batch everything, by making every class prediction be a separate instance
        # NOTE(review): in the binned branch above `thetas` collapses to one
        # value per proposal, while boxes/scores expand to one row per
        # (proposal, class). The reshapes below only stay aligned when
        # num_classes == 2 (one foreground class) — confirm for multi-class use.
        boxes = boxes.reshape(-1, 4)
        scores = scores.reshape(-1)
        labels = labels.reshape(-1)
        thetas = thetas.reshape(-1)

        # remove low scoring boxes
        inds = torch.where(scores > self.score_thresh)[0]
        boxes, scores, labels, thetas = boxes[inds], scores[inds], labels[inds], thetas[inds]

        # remove empty boxes
        keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
        boxes, scores, labels, thetas = boxes[keep], scores[keep], labels[keep], thetas[keep]

        # non-maximum suppression, independently done per class
        keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
        # keep only topk scoring predictions
        keep = keep[: self.detections_per_img]
        boxes, scores, labels, thetas = boxes[keep], scores[keep], labels[keep], thetas[keep]

        all_boxes.append(boxes)
        all_scores.append(scores)
        all_labels.append(labels)
        all_thetas.append(thetas)

    return all_boxes, all_scores, all_labels, all_thetas
|
| 325 |
+
|
| 326 |
+
def forward(
    self,
    features,  # type: Dict[str, Tensor]
    proposals,  # type: List[Tensor]
    image_shapes,  # type: List[Tuple[int, int]]
    targets=None,  # type: Optional[List[Dict[str, Tensor]]]
):
    # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
    """
    Run the RoI heads: pool features for each proposal, predict class,
    box deltas and theta, then either compute losses (training) or
    postprocess into final detections (inference).

    Args:
        features (Dict[str, Tensor]): feature maps keyed by level name
        proposals (List[Tensor[N, 4]]): per-image RPN proposals
        image_shapes (List[Tuple[H, W]]): per-image (pre-padding) sizes
        targets (List[Dict]): ground truth; required when self.training

    Returns:
        result: per-image dicts with "boxes", "labels", "scores", "thetas"
            (empty during training)
        losses: dict with classifier/box/theta losses during training
            (empty during inference)
    """
    # Validate target dtypes up front so failures are loud and early.
    if targets is not None:
        for t in targets:
            # TODO: https://github.com/pytorch/pytorch/issues/26731
            floating_point_types = (torch.float, torch.double, torch.half)
            if not t["boxes"].dtype in floating_point_types:
                raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
            if not t["labels"].dtype == torch.int64:
                raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")

    # targets_without_theta = []
    # if targets is not None:
    #     for target in targets:
    #         target_without_theta = {"boxes": target["boxes"][:, :-1], "labels": target["labels"]}
    #         targets_without_theta.append(target_without_theta)
    if self.training:
        # Subsample proposals and build matched regression / theta targets.
        proposals, matched_idxs, labels, regression_targets, theta_targets = self.select_training_samples(proposals, targets)
        # print("---------")
        # print(f"{theta_targets.shape=}")
    else:
        labels = None
        regression_targets = None
        theta_targets = None
        matched_idxs = None

    # Pool per-proposal features, then run the shared head and predictor.
    box_features = self.box_roi_pool(features, proposals, image_shapes)
    box_features = self.box_head(box_features)

    class_logits, box_regression, theta_preds = self.box_predictor(box_features)
    # print(f"{class_logits.shape=}")
    # print(f"{box_regression.shape=}")
    # print(f"{theta_preds.shape=}")

    result: List[Dict[str, torch.Tensor]] = []
    losses = {}
    if self.training:
        if labels is None:
            raise ValueError("labels cannot be None")
        if regression_targets is None:
            raise ValueError("regression_targets cannot be None")
        loss_classifier, loss_box_reg, loss_theta = fastrcnn_loss(class_logits, box_regression, theta_preds, labels, regression_targets, theta_targets)
        losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg, "loss_theta": loss_theta}
    else:
        boxes, scores, labels, thetas = self.postprocess_detections(class_logits, box_regression, theta_preds, proposals, image_shapes)
        # print(f"{scores[0]=}")
        # print(f"{labels[0]=}")
        # print(f"{thetas[0]=}")
        num_images = len(boxes)
        for i in range(num_images):
            result.append(
                {
                    "boxes": boxes[i],
                    "labels": labels[i],
                    "scores": scores[i],
                    "thetas": thetas[i]
                }
            )
        # print(f"{result}")

    return result, losses
|
| 400 |
+
|
detection/rpn.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn, Tensor
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from torchvision.ops import boxes as box_ops, Conv2dNormActivation
|
| 7 |
+
|
| 8 |
+
from . import _utils as det_utils
|
| 9 |
+
|
| 10 |
+
# Import AnchorGenerator to keep compatibility.
|
| 11 |
+
from .anchor_utils import AnchorGenerator # noqa: 401
|
| 12 |
+
from .image_list import ImageList
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        conv_depth (int, optional): number of 3x3 convolutions in the shared stem
    """

    _version = 2

    def __init__(self, in_channels: int, num_anchors: int, conv_depth=1) -> None:
        super().__init__()
        convs = []
        for _ in range(conv_depth):
            convs.append(Conv2dNormActivation(in_channels, in_channels, kernel_size=3, norm_layer=None))
        self.conv = nn.Sequential(*convs)
        # One objectness logit and 4 box deltas per anchor at every spatial location.
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Remap version-1 checkpoint keys before standard loading."""
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # v1 checkpoints stored the single conv as "conv.*"; remap to the
            # nn.Sequential layout ("conv.0.0.*") used since version 2.
            # (loop variable renamed from `type` to avoid shadowing the builtin)
            for param_kind in ["weight", "bias"]:
                old_key = f"{prefix}conv.{param_kind}"
                new_key = f"{prefix}conv.0.0.{param_kind}"
                if old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """Apply the head to every feature level; returns per-level
        objectness logits and box deltas."""
        logits = []
        bbox_reg = []
        for feature in x:
            t = self.conv(feature)
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
    """Rearrange a head output of shape (N, A*C, H, W) into (N, H*W*A, C).

    The channel axis is split into (anchors, per-anchor channels), moved
    behind the spatial axes, and the (H, W, A) axes are flattened together.
    """
    return (
        layer.view(N, -1, C, H, W)
        .permute(0, 3, 4, 1, 2)
        .reshape(N, -1, C)
    )
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
    """Flatten per-level RPN predictions into two level-concatenated tensors.

    For each feature level the outputs are permuted into the same layout as
    the labels (which are computed over all levels concatenated), then the
    levels are concatenated and flattened: classification becomes
    (sum_anchors, C) and regression becomes (sum_anchors, 4).
    """
    cls_flat: List[Tensor] = []
    reg_flat: List[Tensor] = []
    for cls_level, reg_level in zip(box_cls, box_regression):
        N, AxC, H, W = cls_level.shape
        A = reg_level.shape[1] // 4
        C = AxC // A
        # Inlined permute_and_flatten: (N, A*K, H, W) -> (N, H*W*A, K)
        cls_flat.append(cls_level.view(N, -1, C, H, W).permute(0, 3, 4, 1, 2).reshape(N, -1, C))
        reg_flat.append(reg_level.view(N, -1, 4, H, W).permute(0, 3, 4, 1, 2).reshape(N, -1, 4))
    # Concatenate along the anchor axis (dim 1) so all feature levels line up
    # with the label layout, then flatten the batch dimension in.
    merged_cls = torch.cat(cls_flat, dim=1).flatten(0, -2)
    merged_reg = torch.cat(reg_flat, dim=1).reshape(-1, 4)
    return merged_cls, merged_reg
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Args:
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        score_thresh (float): only return proposals with an objectness score greater than score_thresh

    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        head: nn.Module,
        # Faster-RCNN Training
        fg_iou_thresh: float,
        bg_iou_thresh: float,
        batch_size_per_image: int,
        positive_fraction: float,
        # Faster-RCNN Inference
        pre_nms_top_n: Dict[str, int],
        post_nms_top_n: Dict[str, int],
        nms_thresh: float,
        score_thresh: float = 0.0,
    ) -> None:
        super().__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        # Unit weights: RPN regression targets are not re-scaled.
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used during training
        self.box_similarity = box_ops.box_iou

        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
        # used during testing
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.score_thresh = score_thresh
        # minimum box side length kept by filter_proposals
        self.min_size = 1e-3

    def pre_nms_top_n(self) -> int:
        """Return the per-level pre-NMS proposal budget for the current mode."""
        if self.training:
            return self._pre_nms_top_n["training"]
        return self._pre_nms_top_n["testing"]

    def post_nms_top_n(self) -> int:
        """Return the post-NMS proposal budget for the current mode."""
        if self.training:
            return self._post_nms_top_n["training"]
        return self._post_nms_top_n["testing"]

    def assign_targets_to_anchors(
        self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Label each anchor (1.0 fg / 0.0 bg / -1.0 discard) and pair it
        with its matched ground-truth box, per image."""
        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background image (negative example)
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the targets corresponding GT for each proposal
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = 0.0

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes

    def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
        """Select top-k objectness indices independently per feature level.

        Indices are offset so they address the level-concatenated anchor axis.
        """
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            num_anchors = ob.shape[1]
            # clamp k to the number of anchors actually available at this level
            pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(
        self,
        proposals: Tensor,
        objectness: Tensor,
        image_shapes: List[Tuple[int, int]],
        num_anchors_per_level: List[int],
    ) -> Tuple[List[Tensor], List[Tensor]]:
        """Reduce decoded proposals to the final per-image RPN output.

        Steps: per-level top-k by objectness, clip to image, drop small and
        low-scoring boxes, per-level NMS, then keep the post-NMS top-k.
        """
        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        objectness_prob = torch.sigmoid(objectness)

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)

            # remove small boxes
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # remove low scoring boxes
            # use >= for Backwards compatibility
            keep = torch.where(scores >= self.score_thresh)[0]
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

            # keep only topk scoring predictions
            keep = keep[: self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]

            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores

    def compute_loss(
        self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """
        # Sample a balanced fg/bg subset of anchors; losses are computed
        # only on that subset.
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        # Box loss only on positives, normalized by the total sample count.
        box_loss = F.smooth_l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1 / 9,
            reduction="sum",
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])

        return objectness_loss, box_loss

    def forward(
        self,
        images: ImageList,
        features: Dict[str, Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[List[Tensor], Dict[str, Tensor]]:

        """
        Args:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[str, Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                correspond to different feature levels
            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                If provided, each element in the dict should contain a field `boxes`,
                with the locations of the ground-truth boxes.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        anchors = self.anchor_generator(images, features)
        num_images = len(anchors)
        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN do not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
        losses = {}
        if self.training:
            if targets is None:
                raise ValueError("targets should not be None")
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets
            )
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        return boxes, losses
|
detection/transform.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchvision
|
| 6 |
+
from torch import nn, Tensor
|
| 7 |
+
|
| 8 |
+
from .image_list import ImageList
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@torch.jit.unused
|
| 12 |
+
def _get_shape_onnx(image: Tensor) -> Tensor:
|
| 13 |
+
from torch.onnx import operators
|
| 14 |
+
|
| 15 |
+
return operators.shape_as_tensor(image)[-2:]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> float:
    # ONNX requires a tensor but here we fake its type for JIT.
    # At runtime this is an identity passthrough; the `float` return
    # annotation only satisfies the scripting type checker on the ONNX path.
    return v
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _resize_image_and_masks(
    image: Tensor,
    self_min_size: int,
    self_max_size: int,
    target: Optional[Dict[str, Tensor]] = None,
    fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
    """Bilinearly resize `image` (and any "masks" in `target`) in place of the batch dim.

    When `fixed_size` is given, resizes to exactly that size (note it is
    reversed into interpolate's (H, W) order below, so `fixed_size` is
    presumably (W, H) — confirm against callers). Otherwise scales so the
    shorter side reaches `self_min_size` without the longer side exceeding
    `self_max_size`. Shape extraction and scale computation have separate
    paths for ONNX tracing and TorchScript scripting.
    """
    if torchvision._is_tracing():
        # ONNX tracing needs the shape as a tensor op, not Python ints.
        im_shape = _get_shape_onnx(image)
    elif torch.jit.is_scripting():
        im_shape = torch.tensor(image.shape[-2:])
    else:
        im_shape = image.shape[-2:]

    size: Optional[List[int]] = None
    scale_factor: Optional[float] = None
    recompute_scale_factor: Optional[bool] = None
    if fixed_size is not None:
        # interpolate expects (H, W); fixed_size is reversed here.
        size = [fixed_size[1], fixed_size[0]]
    else:
        if torch.jit.is_scripting() or torchvision._is_tracing():
            min_size = torch.min(im_shape).to(dtype=torch.float32)
            max_size = torch.max(im_shape).to(dtype=torch.float32)
            self_min_size_f = float(self_min_size)
            self_max_size_f = float(self_max_size)
            # scale up to min_size unless that would push the long side past max_size
            scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)

            if torchvision._is_tracing():
                scale_factor = _fake_cast_onnx(scale)
            else:
                scale_factor = scale.item()

        else:
            # Do it the normal way
            min_size = min(im_shape)
            max_size = max(im_shape)
            scale_factor = min(self_min_size / min_size, self_max_size / max_size)

        recompute_scale_factor = True

    image = torch.nn.functional.interpolate(
        image[None],
        size=size,
        scale_factor=scale_factor,
        mode="bilinear",
        recompute_scale_factor=recompute_scale_factor,
        align_corners=False,
    )[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        # Masks use default (nearest-style float) interpolation, then are
        # re-binarized via byte() to stay integer-valued.
        mask = torch.nn.functional.interpolate(
            mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
        )[:, 0].byte()
        target["masks"] = mask
    return image, target
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match min_size / max_size

    It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
    """

    def __init__(
        self,
        min_size: int,
        max_size: int,
        image_mean: List[float],
        image_std: List[float],
        size_divisible: int = 32,
        fixed_size: Optional[Tuple[int, int]] = None,
        **kwargs: Any,
    ):
        super().__init__()
        # min_size may be given as a single int or as a collection of
        # candidate sizes; normalize to a tuple so torch_choice() can sample.
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std
        # Batched images are padded so H and W are multiples of this stride.
        self.size_divisible = size_divisible
        self.fixed_size = fixed_size
        # Private escape hatch: when True, resize() is a no-op during training.
        self._skip_resize = kwargs.pop("_skip_resize", False)

    def forward(
        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
    ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
        """Normalize and resize each image (and its target), then batch them.

        Returns an ImageList holding the padded batch plus the per-image
        post-resize sizes, and the (possibly resized) targets.
        """
        # Shallow-copy the list so the caller's list object is not mutated.
        images = [img for img in images]
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place
            # once torchscript supports dict comprehension
            # this can be simplified as follows
            # targets = [{k: v for k,v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        # Record per-image (H, W) before padding; needed later by postprocess().
        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images, size_divisible=self.size_divisible)
        image_sizes_list: List[Tuple[int, int]] = []
        for image_size in image_sizes:
            torch._assert(
                len(image_size) == 2,
                f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
            )
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets

    def normalize(self, image: Tensor) -> Tensor:
        """Subtract per-channel mean and divide by per-channel std."""
        if not image.is_floating_point():
            raise TypeError(
                f"Expected input images to be of floating type (in range [0, 1]), "
                f"but found type {image.dtype} instead"
            )
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        # Broadcast (C,) stats over the (C, H, W) image.
        return (image - mean[:, None, None]) / std[:, None, None]

    def torch_choice(self, k: List[int]) -> int:
        """
        Implements `random.choice` via torch ops, so it can be compiled with
        TorchScript and we use PyTorch's RNG (not native RNG)
        """
        index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
        return k[index]

    def resize(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Resize the image (and its boxes/keypoints/masks) to the configured size.

        In training a min_size is sampled from self.min_size; in eval the last
        entry is used deterministically.
        """
        h, w = image.shape[-2:]
        if self.training:
            if self._skip_resize:
                return image, target
            size = self.torch_choice(self.min_size)
        else:
            size = self.min_size[-1]
        image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)

        if target is None:
            return image, target

        # Rescale box coordinates to match the resized image.
        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        if "keypoints" in target:
            keypoints = target["keypoints"]
            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
            target["keypoints"] = keypoints
        return image, target

    # _onnx_batch_images() is an implementation of
    # batch_images() that is supported by ONNX tracing.
    @torch.jit.unused
    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        # Compute the per-dimension max shape using traced tensor ops only.
        max_size = []
        for i in range(images[0].dim()):
            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
            max_size.append(max_size_i)
        stride = size_divisible
        # Round H (index 1) and W (index 2) up to multiples of the stride.
        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size = tuple(max_size)

        # work around for
        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        # which is not yet supported in onnx
        padded_imgs = []
        for img in images:
            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
            padded_imgs.append(padded_img)

        return torch.stack(padded_imgs)

    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
        """Element-wise maximum across a list of equal-length int lists."""
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        """Zero-pad images to a common (stride-aligned) size and stack them."""
        if torchvision._is_tracing():
            # batch_images() does not export well to ONNX
            # call _onnx_batch_images() instead
            return self._onnx_batch_images(images, size_divisible)

        max_size = self.max_by_axis([list(img.shape) for img in images])
        stride = float(size_divisible)
        max_size = list(max_size)
        # Round H and W up to the next multiple of the stride.
        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)

        batch_shape = [len(images)] + max_size
        batched_imgs = images[0].new_full(batch_shape, 0)
        for i in range(batched_imgs.shape[0]):
            img = images[i]
            # Copy each image into the top-left corner of its padded slot.
            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

        return batched_imgs

    def postprocess(
        self,
        result: List[Dict[str, Tensor]],
        image_shapes: List[Tuple[int, int]],
        original_image_sizes: List[Tuple[int, int]],
    ) -> List[Dict[str, Tensor]]:
        """Map predictions back from resized coordinates to original image sizes."""
        if self.training:
            # Losses are computed in resized space; nothing to undo in training.
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
            if "masks" in pred:
                masks = pred["masks"]
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["masks"] = masks
            if "keypoints" in pred:
                keypoints = pred["keypoints"]
                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                result[i]["keypoints"] = keypoints
        return result

    def __repr__(self) -> str:
        format_string = f"{self.__class__.__name__}("
        _indent = "\n    "
        format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
        format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
        format_string += "\n)"
        return format_string
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    """Rescale keypoint (x, y) coordinates from ``original_size`` to ``new_size``.

    Sizes are (H, W) pairs; the visibility channel (index 2) is left untouched.
    """
    device = keypoints.device
    scale_h, scale_w = (
        torch.tensor(dst, dtype=torch.float32, device=device)
        / torch.tensor(src, dtype=torch.float32, device=device)
        for dst, src in zip(new_size, original_size)
    )
    scaled = keypoints.clone()
    if torch._C._get_tracing_state():
        # In-place indexed writes do not trace well; rebuild the tensor instead.
        xs = scaled[:, :, 0] * scale_w
        ys = scaled[:, :, 1] * scale_h
        scaled = torch.stack((xs, ys, scaled[:, :, 2]), dim=2)
    else:
        scaled[..., 0] *= scale_w
        scaled[..., 1] *= scale_h
    return scaled
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    """Rescale xyxy boxes from ``original_size`` (H, W) to ``new_size`` (H, W)."""
    scale_h, scale_w = [
        torch.tensor(dst, dtype=torch.float32, device=boxes.device)
        / torch.tensor(src, dtype=torch.float32, device=boxes.device)
        for dst, src in zip(new_size, original_size)
    ]
    x1, y1, x2, y2 = boxes.unbind(1)
    return torch.stack(
        (x1 * scale_w, y1 * scale_h, x2 * scale_w, y2 * scale_h), dim=1
    )
|
infer.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import torchvision
|
| 5 |
+
import argparse
|
| 6 |
+
import random
|
| 7 |
+
import os
|
| 8 |
+
import yaml
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from dataset.st import SceneTextDataset
|
| 11 |
+
from torch.utils.data.dataloader import DataLoader
|
| 12 |
+
|
| 13 |
+
import detection
|
| 14 |
+
from detection.faster_rcnn import FastRCNNPredictor
|
| 15 |
+
from shapely.geometry import Polygon
|
| 16 |
+
from detection.anchor_utils import AnchorGenerator
|
| 17 |
+
|
| 18 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_iou(det, gt):
    """Return the IoU of two rotated boxes given as (cx, cy, w, h, theta[rad])."""

    def as_polygon(cx, cy, w, h, theta):
        # Axis-aligned corner offsets before rotation.
        half_w, half_h = w / 2, h / 2
        offsets = np.array(
            [[-half_w, -half_h], [half_w, -half_h], [half_w, half_h], [-half_w, half_h]]
        )
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
        # Rotate about the box center, then translate to (cx, cy).
        return Polygon(offsets @ rotation.T + np.array([cx, cy]))

    det_poly = as_polygon(*det)
    gt_poly = as_polygon(*gt)

    if not det_poly.intersects(gt_poly):
        return 0.0

    intersection = det_poly.intersection(gt_poly).area
    # Small epsilon guards against division by zero for degenerate boxes.
    union = det_poly.area + gt_poly.area - intersection + 1E-6
    return intersection / union
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def compute_map(det_boxes, gt_boxes, iou_threshold=0.5, method="area", return_pr=False):
    """Compute mean Average Precision for rotated-box detections.

    Args:
        det_boxes: per-image dicts mapping label -> list of
            [cx, cy, w, h, theta, score] detections.
        gt_boxes: per-image dicts mapping label -> list of
            [cx, cy, w, h, theta] ground-truth boxes.
            NOTE(review): the code below indexes im_gts[label] for every image,
            so every image dict is assumed to contain every label key — confirm
            against the caller.
        iou_threshold: minimum IoU for a detection to count as a true positive.
        method: "area" for all-point interpolation, "interp" for 11-point.
        return_pr: if True also return per-class precision/recall curves.

    Returns:
        (mean_ap, all_aps) or (mean_ap, all_aps, all_precisions, all_recalls).
    """
    # Collect the set of labels present in any ground-truth image.
    gt_labels = {cls_key for im_gt in gt_boxes for cls_key in im_gt.keys()}
    gt_labels = sorted(gt_labels)

    all_aps = {}
    all_precisions = {}
    all_recalls = {}

    aps = []

    for idx, label in enumerate(gt_labels):
        # Get detection predictions of this class
        cls_dets = [
            [im_idx, im_dets_label]
            for im_idx, im_dets in enumerate(det_boxes)
            if label in im_dets
            for im_dets_label in im_dets[label]
        ]

        # Sort by confidence score (descending)
        cls_dets = sorted(cls_dets, key=lambda k: -k[1][-1])

        # Track matched GT boxes
        gt_matched = [[False for _ in im_gts[label]] for im_gts in gt_boxes]
        num_gts = sum([len(im_gts[label]) for im_gts in gt_boxes])

        tp = np.zeros(len(cls_dets))
        fp = np.zeros(len(cls_dets))

        # Process each detection
        for det_idx, (im_idx, det_pred) in enumerate(cls_dets):
            im_gts = gt_boxes[im_idx][label]
            max_iou_found = -1
            max_iou_gt_idx = -1

            # Find the best-matching GT box
            for gt_box_idx, gt_box in enumerate(im_gts):
                # det_pred[:-1] strips the trailing confidence score.
                gt_box_iou = get_iou(det_pred[:-1], gt_box)
                if gt_box_iou > max_iou_found:
                    max_iou_found = gt_box_iou
                    max_iou_gt_idx = gt_box_idx

            # True Positive if IoU >= threshold & GT box is not already matched
            # (short-circuit keeps max_iou_gt_idx == -1 from being indexed when
            # the image has no GT boxes for this label).
            if max_iou_found < iou_threshold or gt_matched[im_idx][max_iou_gt_idx]:
                fp[det_idx] = 1
            else:
                tp[det_idx] = 1
                gt_matched[im_idx][max_iou_gt_idx] = True

        # Compute cumulative sums for TP and FP
        tp = np.cumsum(tp)
        fp = np.cumsum(fp)

        eps = np.finfo(np.float32).eps
        recalls = tp / np.maximum(num_gts, eps)
        precisions = tp / np.maximum(tp + fp, eps)

        # Compute AP
        if method == "area":
            # All-point interpolation: make the precision envelope monotonically
            # non-increasing, then sum the area under the step function.
            recalls = np.concatenate(([0.0], recalls, [1.0]))
            precisions = np.concatenate(([0.0], precisions, [0.0]))

            for i in range(len(precisions) - 1, 0, -1):
                precisions[i - 1] = np.maximum(precisions[i - 1], precisions[i])

            i = np.where(recalls[1:] != recalls[:-1])[0]
            ap = np.sum((recalls[i + 1] - recalls[i]) * precisions[i + 1])

        elif method == "interp":
            # Classic 11-point interpolation (PASCAL VOC pre-2010).
            ap = (
                sum(
                    [
                        max(precisions[recalls >= t]) if any(recalls >= t) else 0
                        for t in np.arange(0, 1.1, 0.1)
                    ]
                )
                / 11.0
            )
        else:
            raise ValueError("Method must be 'area' or 'interp'")

        # Classes with no ground truth contribute NaN instead of skewing the mean.
        if num_gts > 0:
            aps.append(ap)
            all_aps[label] = ap
            all_precisions[label] = precisions.tolist()
            all_recalls[label] = recalls.tolist()
        else:
            all_aps[label] = np.nan
            all_precisions[label] = []
            all_recalls[label] = []

    mean_ap = sum(aps) / len(aps) if aps else 0.0

    if return_pr:
        return mean_ap, all_aps, all_precisions, all_recalls
    else:
        return mean_ap, all_aps
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def load_model_and_dataset(args):
    """Build the dataset, dataloader, and checkpointed Faster R-CNN model.

    Reads the YAML config at args.config_path, seeds all RNGs, constructs the
    SceneTextDataset and a batch-size-1 DataLoader, then loads a ResNet50-FPN
    Faster R-CNN whose box predictor is replaced to also predict theta bins.

    Returns:
        (faster_rcnn_model, dataset, dataloader)
    """
    # Read the config file #
    with open(args.config_path, "r") as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            # NOTE(review): a parse failure only prints the error; the
            # undefined `config` then raises NameError below.
            print(exc)
        print(config)
    ########################

    dataset_config = config["dataset_params"]
    model_config = config["model_params"]
    train_config = config["train_params"]

    # Seed every RNG source for reproducible sampling.
    seed = train_config["seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # NOTE(review): `device` is a torch.device, not a str; this comparison
    # relies on torch.device == str support — verify on the target torch version.
    if device == "cuda":
        torch.cuda.manual_seed_all(seed)

    st = SceneTextDataset(args.split_type, root_dir=dataset_config["root_dir"])
    test_dataset = DataLoader(st, batch_size=1, shuffle=False)

    faster_rcnn_model = detection.fasterrcnn_resnet50_fpn(
        pretrained=True,
        min_size=600,
        max_size=1000,
        box_score_thresh=0.7,
    )
    # Swap the classification head for one sized to this dataset's classes
    # and the requested number of rotation (theta) bins.
    faster_rcnn_model.roi_heads.box_predictor = FastRCNNPredictor(
        faster_rcnn_model.roi_heads.box_predictor.cls_score.in_features,
        num_classes=dataset_config["num_classes"],
        num_theta_bins=args.num_theta_bins,
    )

    faster_rcnn_model.eval()
    faster_rcnn_model.to(device)
    # Load weights on CPU first; the model has already been moved to `device`.
    faster_rcnn_model.load_state_dict(
        torch.load(
            os.path.join(
                train_config["task_name"],
                "tv_frcnn_r50fpn_" + train_config["ckpt_name"],
            ),
            map_location='cpu',
        )
    )

    return faster_rcnn_model, st, test_dataset
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def evaluate_metrics(args):
    """Run the model over the dataloader and report mAP / precision / recall.

    Collects per-image prediction and ground-truth dicts keyed by class name,
    then delegates the metric computation to compute_map().

    Returns:
        (mean_ap, mean_precision, mean_recall)
    """
    faster_rcnn_model, voc, test_dataset = load_model_and_dataset(args)

    gts = []
    preds = []

    for im, target, fname in tqdm(test_dataset):
        im_name = fname
        im = im.float().to(device)
        # DataLoader batch size is 1; index [0] drops the batch dimension.
        target_boxes = target["bboxes"].float().to(device)[0]
        target_labels = target["labels"].long().to(device)[0]
        target_thetas = target["thetas"].float().to(device)[0]
        # NOTE(review): the second positional arg (targets) is None at
        # inference time; output is per-image dicts, [0] takes the first.
        frcnn_output = faster_rcnn_model(im, None)[0]

        boxes = frcnn_output["boxes"]
        labels = frcnn_output["labels"]
        scores = frcnn_output["scores"]
        thetas = frcnn_output["thetas"]

        pred_boxes = {label_name: [] for label_name in voc.label2idx}
        gt_boxes = {label_name: [] for label_name in voc.label2idx}

        # Gather predictions as [x1, y1, x2, y2, theta, score] per class.
        for idx, box in enumerate(boxes):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            label = labels[idx].detach().cpu().item()
            score = scores[idx].detach().cpu().item()
            theta = thetas[idx].detach().cpu().item()
            label_name = voc.idx2label[label]
            pred_boxes[label_name].append([x1, y1, x2, y2, theta, score])

        # Gather ground truth as [x1, y1, x2, y2, theta] per class.
        for idx, box in enumerate(target_boxes):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            label = target_labels[idx].detach().cpu().item()
            label_name = voc.idx2label[label]
            theta = target_thetas[idx].detach().cpu().item()
            gt_boxes[label_name].append([x1, y1, x2, y2, theta])

        gts.append(gt_boxes)
        preds.append(pred_boxes)

    # Compute Mean Average Precision and Precision-Recall values
    mean_ap, all_aps, precisions, recalls = compute_map(
        preds, gts, method="interp", return_pr=True
    )

    mean_precision = 0
    mean_recall = 0
    num_classes = len(voc.idx2label)

    for idx in range(num_classes):
        class_name = voc.idx2label[idx]
        ap = all_aps[class_name]
        prec = precisions[class_name]
        rec = recalls[class_name]

        # Per-class precision/recall averaged over the whole PR curve.
        mean_precision += sum(prec) / len(prec) if len(prec) > 0 else 0
        mean_recall += sum(rec) / len(rec) if len(rec) > 0 else 0

        print(f"Class: {class_name}")
        print(
            f"  AP: {ap:.4f}, Precision: {sum(prec) / len(prec) if len(prec) > 0 else 0:.4f}, Recall: {sum(rec) / len(rec) if len(rec) > 0 else 0:.4f}"
        )

    mean_precision /= num_classes
    mean_recall /= num_classes

    print(f"Mean Average Precision (mAP): {mean_ap:.4f}")
    print(f"Mean Precision: {mean_precision:.4f}")
    print(f"Mean Recall: {mean_recall:.4f}")
    return mean_ap, mean_precision, mean_recall
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def infer(args):
    """Visualize ground-truth and predicted rotated boxes on random samples.

    Draws 10 randomly chosen dataset images twice — once with ground-truth
    boxes and once with model predictions — and writes both renderings as
    semi-transparent overlays into ``samples_tv_r50fpn/``.
    """
    output_dir = "samples_tv_r50fpn"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    faster_rcnn_model, voc, test_dataset = load_model_and_dataset(args)

    for sample_count in tqdm(range(10)):
        # BUG FIX: random.randint is inclusive on both ends, so the original
        # randint(0, len(voc)) could index one past the end (IndexError).
        random_idx = random.randint(0, len(voc) - 1)
        im, target, fname = voc[random_idx]
        im = im.unsqueeze(0).float().to(device)

        gt_im = cv2.imread(fname)
        gt_im_copy = gt_im.copy()

        # Saving images with ground truth boxes
        for idx, box in enumerate(target["bboxes"]):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            # cv2.boxPoints expects degrees; thetas are stored in radians.
            theta = target["thetas"][idx].detach().cpu().numpy() * 180 / np.pi

            cx, cy, w, h = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
            box = cv2.boxPoints(((cx, cy), (w, h), theta))
            box = box.astype(np.int32)
            cv2.drawContours(gt_im, [box], 0, (0, 255, 0), 2)
            cv2.drawContours(gt_im_copy, [box], 0, (0, 255, 0), 2)

        # Blend the two renderings for a semi-transparent overlay effect.
        cv2.addWeighted(gt_im_copy, 0.7, gt_im, 0.3, 0, gt_im)
        cv2.imwrite("{}/output_frcnn_gt_{}.png".format(output_dir, sample_count), gt_im)

        # Getting predictions from trained model
        frcnn_output = faster_rcnn_model(im, None)[0]
        boxes = frcnn_output["boxes"]
        labels = frcnn_output["labels"]
        scores = frcnn_output["scores"]
        thetas = frcnn_output["thetas"]
        im = cv2.imread(fname)
        im_copy = im.copy()

        # Saving images with predicted boxes
        for idx, box in enumerate(boxes):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            theta = thetas[idx].detach().cpu().numpy() * 180 / np.pi
            cx, cy, w, h = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
            box = cv2.boxPoints(((cx, cy), (w, h), theta))
            box = box.astype(np.int32)
            cv2.drawContours(im, [box], 0, (0, 255, 0), 2)
            cv2.drawContours(im_copy, [box], 0, (0, 255, 0), 2)
        cv2.addWeighted(im_copy, 0.7, im, 0.3, 0, im)
        cv2.imwrite("{}/output_frcnn_{}.jpg".format(output_dir, sample_count), im)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
if __name__ == "__main__":

    def _str_to_bool(value):
        """Parse common boolean spellings from the command line."""
        return str(value).strip().lower() in ("1", "true", "yes", "y")

    parser = argparse.ArgumentParser(
        description="Arguments for inference using torchvision code faster rcnn"
    )
    parser.add_argument(
        "--config", dest="config_path", default="config/st.yaml", type=str
    )
    # BUG FIX: argparse `type=bool` treats ANY non-empty string as True, so
    # "--evaluate False" used to parse as True. Parse the string explicitly.
    parser.add_argument(
        "--evaluate", dest="evaluate", default=False, type=_str_to_bool
    )
    parser.add_argument(
        "--infer_samples", dest="infer_samples", default=True, type=_str_to_bool
    )
    args = parser.parse_args()
    # Fixed settings not exposed on the CLI.
    args.split_type = "train"
    args.num_theta_bins = 359
    if args.infer_samples:
        infer(args)
    else:
        print("Not Inferring for samples as `infer_samples` argument is False")

    if args.evaluate:
        evaluate_metrics(args)
    else:
        print("Not Evaluating as `evaluate` argument is False")
|
requirements.txt
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets==3.4.0
|
| 2 |
+
deepdiff==8.4.1
|
| 3 |
+
dill==0.3.8
|
| 4 |
+
distlib==0.3.9
|
| 5 |
+
docker-pycreds==0.4.0
|
| 6 |
+
evaluate==0.4.3
|
| 7 |
+
fastapi==0.115.12
|
| 8 |
+
fastrlock==0.8.3
|
| 9 |
+
ffmpy==0.5.0
|
| 10 |
+
filelock==3.17.0
|
| 11 |
+
flatbuffers==25.2.10
|
| 12 |
+
fonttools==4.56.0
|
| 13 |
+
frozenlist==1.5.0
|
| 14 |
+
fsspec==2024.12.0
|
| 15 |
+
gast==0.6.0
|
| 16 |
+
gensim==4.3.3
|
| 17 |
+
gitdb==4.0.12
|
| 18 |
+
gradio==5.25.2
|
| 19 |
+
gradio_client==1.8.0
|
| 20 |
+
groovy==0.1.2
|
| 21 |
+
grpcio==1.71.0
|
| 22 |
+
h11==0.14.0
|
| 23 |
+
h5py==3.13.0
|
| 24 |
+
httpcore==1.0.8
|
| 25 |
+
httpx==0.28.1
|
| 26 |
+
huggingface-hub==0.29.3
|
| 27 |
+
humanfriendly==10.0
|
| 28 |
+
idna==3.10
|
| 29 |
+
imageio==2.37.0
|
| 30 |
+
Jinja2==3.1.6
|
| 31 |
+
joblib==1.4.2
|
| 32 |
+
jsonlines==4.0.0
|
| 33 |
+
kiwisolver==1.4.8
|
| 34 |
+
libclang==18.1.1
|
| 35 |
+
lightning-utilities==0.14.3
|
| 36 |
+
lm_eval==0.4.8
|
| 37 |
+
lxml==5.3.1
|
| 38 |
+
Markdown==3.7
|
| 39 |
+
markdown-it-py==3.0.0
|
| 40 |
+
MarkupSafe==3.0.2
|
| 41 |
+
matplotlib==3.10.1
|
| 42 |
+
mbstrdecoder==1.1.4
|
| 43 |
+
mdurl==0.1.2
|
| 44 |
+
ml_dtypes==0.5.1
|
| 45 |
+
more-itertools==10.6.0
|
| 46 |
+
mpmath==1.3.0
|
| 47 |
+
multidict==6.1.0
|
| 48 |
+
multiprocess==0.70.16
|
| 49 |
+
namex==0.0.8
|
| 50 |
+
networkx==3.4.2
|
| 51 |
+
nltk==3.9.1
|
| 52 |
+
numexpr==2.10.2
|
| 53 |
+
numpy==1.26.4
|
| 54 |
+
opencv-python==4.11.0.86
|
| 55 |
+
opt_einsum==3.4.0
|
| 56 |
+
optree==0.14.1
|
| 57 |
+
orderly-set==5.3.0
|
| 58 |
+
orjson==3.10.16
|
| 59 |
+
pandas==2.2.3
|
| 60 |
+
pathvalidate==3.2.3
|
| 61 |
+
peft==0.14.0
|
| 62 |
+
pillow==11.1.0
|
| 63 |
+
pluggy==1.5.0
|
| 64 |
+
portalocker==3.1.1
|
| 65 |
+
propcache==0.3.0
|
| 66 |
+
protobuf==5.29.3
|
| 67 |
+
pyahocorasick==2.1.0
|
| 68 |
+
pyarrow==19.0.1
|
| 69 |
+
pybind11==2.13.6
|
| 70 |
+
pydantic==2.10.6
|
| 71 |
+
pydantic_core==2.27.2
|
| 72 |
+
pydub==0.25.1
|
| 73 |
+
pyparsing==3.2.1
|
| 74 |
+
pyproject-api==1.9.0
|
| 75 |
+
pytablewriter==1.2.1
|
| 76 |
+
python-multipart==0.0.20
|
| 77 |
+
pytz==2025.1
|
| 78 |
+
PyYAML==6.0.2
|
| 79 |
+
regex==2024.11.6
|
| 80 |
+
requests==2.32.3
|
| 81 |
+
rich==13.9.4
|
| 82 |
+
rootpath==0.1.1
|
| 83 |
+
rouge_score==0.1.2
|
| 84 |
+
ruff==0.11.5
|
| 85 |
+
sacrebleu==2.5.1
|
| 86 |
+
safehttpx==0.1.6
|
| 87 |
+
safetensors==0.5.3
|
| 88 |
+
scikit-learn==1.6.1
|
| 89 |
+
scipy==1.13.1
|
| 90 |
+
seaborn==0.13.2
|
| 91 |
+
semantic-version==2.10.0
|
| 92 |
+
sentencepiece==0.2.0
|
| 93 |
+
sentry-sdk==2.22.0
|
| 94 |
+
setproctitle==1.3.5
|
| 95 |
+
shapely==2.0.7
|
| 96 |
+
shellingham==1.5.4
|
| 97 |
+
smart-open==7.1.0
|
| 98 |
+
smmap==5.0.2
|
| 99 |
+
sniffio==1.3.1
|
| 100 |
+
sns==0.1
|
| 101 |
+
torch==2.6.0
|
| 102 |
+
torchaudio==2.6.0
|
| 103 |
+
torchmetrics==1.7.1
|
| 104 |
+
torchvision==0.21.0
|
| 105 |
+
tqdm==4.67.1
|
st/tv_frcnn_r50fpn_faster_rcnn_st.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:df8a754ab48d6c49503dc74fc9840a205d8b11561c2c8f68f736ad1c39a810dc
|
| 3 |
+
size 167210987
|
st/tv_frcnn_r50fpn_faster_rcnn_st_10.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0c4bf291c6993de1c5d9d52da604551de7c6b5f6d499db05da7da29c4230bb4b
|
| 3 |
+
size 165780139
|