Spaces:
Running
on
Zero
Running
on
Zero
File size: 13,806 Bytes
a249588 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from mmengine.structures import InstanceData
from torch import Tensor
from mmpose.registry import TASK_UTILS
from mmpose.utils.typing import ConfigType
# Large constant: additive cost for priors outside the gt box-and-center
# region, and (negated) fill value for overlaps of unassigned priors.
INF = 1e5
# Small offset that keeps -log(IoU / OKS) finite when the overlap is zero.
EPS = 1e-7
@TASK_UTILS.register_module()
class SimOTAAssigner:
    """Computes matching between predictions and ground truth.

    Args:
        center_radius (float): Radius of center area to determine
            if a prior is in the center of a gt. Defaults to 2.5.
        candidate_topk (int): Top-k ious candidates to calculate dynamic-k.
            Defaults to 10.
        iou_weight (float): Weight of bbox iou cost. Defaults to 3.0.
        cls_weight (float): Weight of classification cost. Defaults to 1.0.
        oks_weight (float): Weight of keypoint OKS cost. Defaults to 3.0.
        vis_weight (float): Weight of keypoint visibility cost.
            Defaults to 0.0.
        dynamic_k_indicator (str): Cost type for calculating dynamic-k,
            either 'iou' or 'oks'. Defaults to 'iou'.
        use_keypoints_for_center (bool): Whether to use keypoints to determine
            if a prior is in the center of a gt. Defaults to False.
        iou_calculator (dict): Config of IoU calculation method.
            Defaults to dict(type='BBoxOverlaps2D').
        oks_calculator (dict): Config of OKS calculation method.
            Defaults to dict(type='PoseOKS').
    """

    def __init__(self,
                 center_radius: float = 2.5,
                 candidate_topk: int = 10,
                 iou_weight: float = 3.0,
                 cls_weight: float = 1.0,
                 oks_weight: float = 3.0,
                 vis_weight: float = 0.0,
                 dynamic_k_indicator: str = 'iou',
                 use_keypoints_for_center: bool = False,
                 iou_calculator: ConfigType = dict(type='BBoxOverlaps2D'),
                 oks_calculator: ConfigType = dict(type='PoseOKS')):
        self.center_radius = center_radius
        self.candidate_topk = candidate_topk
        self.iou_weight = iou_weight
        self.cls_weight = cls_weight
        self.oks_weight = oks_weight
        self.vis_weight = vis_weight

        assert dynamic_k_indicator in ('iou', 'oks'), f'the argument ' \
            f'`dynamic_k_indicator` should be either \'iou\' or \'oks\', ' \
            f'but got {dynamic_k_indicator}'
        self.dynamic_k_indicator = dynamic_k_indicator

        self.use_keypoints_for_center = use_keypoints_for_center
        # NOTE(review): the default config dicts are shared between calls;
        # this relies on TASK_UTILS.build not mutating them -- confirm.
        self.iou_calculator = TASK_UTILS.build(iou_calculator)
        self.oks_calculator = TASK_UTILS.build(oks_calculator)

    @staticmethod
    def _empty_assignment(decoded_bboxes: Tensor, assigned_gt_inds: Tensor,
                          num_gt: int) -> dict:
        """Build the result returned when no prior can be assigned.

        Every prior keeps gt index 0 (background), overlap 0 and label -1.
        """
        num_bboxes = decoded_bboxes.size(0)
        return dict(
            num_gts=num_gt,
            gt_inds=assigned_gt_inds,
            max_overlaps=decoded_bboxes.new_zeros((num_bboxes, )),
            labels=decoded_bboxes.new_full((num_bboxes, ),
                                           -1,
                                           dtype=torch.long))

    @staticmethod
    def _bce_cost(pred: Tensor, target: Tensor) -> Tensor:
        """Element-wise binary cross entropy summed over the last dim.

        The BCE is evaluated in float32 with AMP autocast disabled (to avoid
        overflow), then cast back to ``pred``'s original dtype. ``target``
        must already have ``pred``'s shape.
        """
        with torch.cuda.amp.autocast(enabled=False):
            cost = F.binary_cross_entropy(
                pred.to(dtype=torch.float32),
                target,
                reduction='none',
            ).sum(-1).to(dtype=pred.dtype)
        return cost

    def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
               **kwargs) -> dict:
        """Assign gt to priors using SimOTA.

        Args:
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. Reads ``priors``, ``bboxes``, ``scores``,
                ``keypoints`` and ``keypoints_visible``. Columns 0-1 of
                ``priors`` are used as point coordinates and columns 2-3 as
                strides (see :meth:`get_in_gt_and_in_center_info`).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. Reads ``bboxes``, ``labels``, ``keypoints``,
                ``keypoints_visible`` and ``areas``.

        Returns:
            dict: Assignment result with keys

            - ``num_gts``: number of gt instances.
            - ``gt_inds``: per prior, 1-based index of the assigned gt, or 0
              for background.
            - ``max_overlaps``: per prior, the IoU/OKS with its assigned gt.
              Unassigned priors get -INF, or 0 when nothing could be
              assigned at all.
            - ``labels``: per prior, the assigned gt label, or -1.
        """
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        gt_keypoints = gt_instances.keypoints
        gt_keypoints_visible = gt_instances.keypoints_visible
        gt_areas = gt_instances.areas
        num_gt = gt_bboxes.size(0)

        decoded_bboxes = pred_instances.bboxes
        pred_scores = pred_instances.scores
        priors = pred_instances.priors
        keypoints = pred_instances.keypoints
        keypoints_visible = pred_instances.keypoints_visible
        num_bboxes = decoded_bboxes.size(0)

        # assign 0 (background) by default
        assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ),
                                                   0,
                                                   dtype=torch.long)
        if num_gt == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            return self._empty_assignment(decoded_bboxes, assigned_gt_inds,
                                          num_gt)

        valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info(
            priors, gt_bboxes, gt_keypoints, gt_keypoints_visible)
        valid_decoded_bbox = decoded_bboxes[valid_mask]
        valid_pred_scores = pred_scores[valid_mask]
        valid_pred_kpts = keypoints[valid_mask]
        valid_pred_kpts_vis = keypoints_visible[valid_mask]
        num_valid = valid_decoded_bbox.size(0)
        if num_valid == 0:
            # No prior lies in any gt box or center region
            return self._empty_assignment(decoded_bboxes, assigned_gt_inds,
                                          num_gt)

        # Candidates that are not inside BOTH the gt box and its center
        # region receive a large constant penalty.
        cost_matrix = (~is_in_boxes_and_center) * INF

        # IoU cost. The IoUs are always needed: they may drive dynamic-k.
        pairwise_ious = self.iou_calculator(valid_decoded_bbox, gt_bboxes)
        if self.iou_weight > 0:
            iou_cost = -torch.log(pairwise_ious + EPS)
            cost_matrix = cost_matrix + iou_cost * self.iou_weight

        # OKS cost (also computed when OKS drives dynamic-k).
        pairwise_oks = None
        if self.oks_weight > 0 or self.dynamic_k_indicator == 'oks':
            pairwise_oks = self.oks_calculator(
                valid_pred_kpts.unsqueeze(1),  # [num_valid, 1, k, 2]
                target=gt_keypoints.unsqueeze(0),  # [1, num_gt, k, 2]
                target_weights=gt_keypoints_visible.unsqueeze(
                    0),  # [1, num_gt, k]
                areas=gt_areas.unsqueeze(0),  # [1, num_gt]
            )  # -> [num_valid, num_gt]
            oks_cost = -torch.log(pairwise_oks + EPS)
            cost_matrix = cost_matrix + oks_cost * self.oks_weight

        # Classification cost: BCE against one-hot gt labels.
        if self.cls_weight > 0:
            gt_onehot_label = (
                F.one_hot(gt_labels.to(torch.int64),
                          pred_scores.shape[-1]).float().unsqueeze(0).repeat(
                              num_valid, 1, 1))
            pair_scores = valid_pred_scores.unsqueeze(1).repeat(
                1, num_gt, 1)  # [num_valid, num_gt, num_classes]
            cls_cost = self._bce_cost(pair_scores, gt_onehot_label)
            cost_matrix = cost_matrix + cls_cost * self.cls_weight

        # Keypoint visibility cost: BCE against gt visibility.
        if self.vis_weight > 0:
            pair_kpt_vis = valid_pred_kpts_vis.unsqueeze(1).repeat(
                1, num_gt, 1)  # [num_valid, num_gt, k]
            gt_kpt_vis = gt_keypoints_visible.unsqueeze(0).float().repeat(
                num_valid, 1, 1)  # [num_valid, num_gt, k]
            vis_cost = self._bce_cost(pair_kpt_vis, gt_kpt_vis)
            cost_matrix = cost_matrix + vis_cost * self.vis_weight

        # __init__ guarantees the indicator is either 'iou' or 'oks'.
        overlaps = (
            pairwise_ious
            if self.dynamic_k_indicator == 'iou' else pairwise_oks)
        matched_pred_ious, matched_gt_inds = self.dynamic_k_matching(
            cost_matrix, overlaps, num_gt, valid_mask)

        # convert to AssignResult format; dynamic_k_matching has narrowed
        # valid_mask in place to the priors that received a match.
        assigned_gt_inds[valid_mask] = matched_gt_inds + 1
        assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
        assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
        max_overlaps = assigned_gt_inds.new_full((num_bboxes, ),
                                                 -INF,
                                                 dtype=torch.float32)
        max_overlaps[valid_mask] = matched_pred_ious.to(max_overlaps)
        return dict(
            num_gts=num_gt,
            gt_inds=assigned_gt_inds,
            max_overlaps=max_overlaps,
            labels=assigned_labels)

    def get_in_gt_and_in_center_info(
        self,
        priors: Tensor,
        gt_bboxes: Tensor,
        gt_keypoints: Optional[Tensor] = None,
        gt_keypoints_visible: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Get the information of which prior is in gt bboxes and gt center
        priors.

        Columns 0-1 of ``priors`` are read as point coordinates and columns
        2-3 as x/y strides. ``gt_bboxes`` columns are read as
        (x1, y1, x2, y2). When ``use_keypoints_for_center`` is set and
        keypoints are given, the center of a gt with at least one visible
        keypoint is the visibility-weighted mean of its keypoints instead of
        the bbox center.

        Returns:
            Tuple[Tensor, Tensor]:

            - Bool mask of shape [num_priors]: the prior lies inside at least
              one gt box or at least one gt center region.
            - Bool matrix of shape [num_selected, num_gt], restricted to the
              priors selected by the mask: the prior lies inside both the gt
              box and its center region.
        """
        # [num_priors, 1] columns; broadcasting against [num_gt] vectors
        # yields [num_priors, num_gt] without materialized copies.
        prior_x = priors[:, 0].unsqueeze(1)
        prior_y = priors[:, 1].unsqueeze(1)
        stride_x = priors[:, 2].unsqueeze(1)
        stride_y = priors[:, 3].unsqueeze(1)

        # is prior centers in gt bboxes, shape: [n_prior, n_gt]
        l_ = prior_x - gt_bboxes[:, 0]
        t_ = prior_y - gt_bboxes[:, 1]
        r_ = gt_bboxes[:, 2] - prior_x
        b_ = gt_bboxes[:, 3] - prior_y
        deltas = torch.stack([l_, t_, r_, b_], dim=1)
        is_in_gts = deltas.min(dim=1).values > 0
        is_in_gts_all = is_in_gts.sum(dim=1) > 0

        # is prior centers in gt centers
        gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
        gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
        if (self.use_keypoints_for_center and gt_keypoints is not None
                and gt_keypoints_visible is not None):
            vis_sum = gt_keypoints_visible.sum(dim=-1, keepdim=True)
            has_visible = vis_sum.squeeze(-1) > 0
            # Instances without visible keypoints keep their bbox center
            # (they are excluded by `has_visible`); dividing by 1 rather than
            # 0 for them just avoids producing NaNs.
            safe_vis_sum = torch.where(vis_sum > 0, vis_sum,
                                       torch.ones_like(vis_sum))
            gt_kpts_cts = (gt_keypoints * gt_keypoints_visible.unsqueeze(-1)
                           ).sum(dim=-2) / safe_vis_sum
            gt_kpts_cts = gt_kpts_cts.to(gt_bboxes)
            gt_cxs[has_visible] = gt_kpts_cts[has_visible][..., 0]
            gt_cys[has_visible] = gt_kpts_cts[has_visible][..., 1]

        # Center region: a square of half-size center_radius * stride.
        radius = self.center_radius
        ct_box_l = gt_cxs - radius * stride_x
        ct_box_t = gt_cys - radius * stride_y
        ct_box_r = gt_cxs + radius * stride_x
        ct_box_b = gt_cys + radius * stride_y

        cl_ = prior_x - ct_box_l
        ct_ = prior_y - ct_box_t
        cr_ = ct_box_r - prior_x
        cb_ = ct_box_b - prior_y
        ct_deltas = torch.stack([cl_, ct_, cr_, cb_], dim=1)
        is_in_cts = ct_deltas.min(dim=1).values > 0
        is_in_cts_all = is_in_cts.sum(dim=1) > 0

        # in boxes or in centers, shape: [num_priors]
        is_in_gts_or_centers = is_in_gts_all | is_in_cts_all

        # both in boxes and centers, shape: [num_fg, num_gt]
        is_in_boxes_and_centers = (
            is_in_gts[is_in_gts_or_centers, :]
            & is_in_cts[is_in_gts_or_centers, :])
        return is_in_gts_or_centers, is_in_boxes_and_centers

    def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
                           num_gt: int,
                           valid_mask: Tensor) -> Tuple[Tensor, Tensor]:
        """Use IoU and matching cost to calculate the dynamic top-k positive
        targets.

        For each gt, k is the (floored, at least 1) sum of its
        ``candidate_topk`` largest overlaps, and the k lowest-cost candidates
        are matched to it. A candidate matched to several gts keeps only its
        lowest-cost gt.

        Args:
            cost (Tensor): Cost matrix of shape [num_candidates, num_gt].
            pairwise_ious (Tensor): Overlaps (IoU or OKS) with the same shape
                as ``cost``.
            num_gt (int): Number of gt instances (columns of ``cost``).
            valid_mask (Tensor): Bool mask over all priors whose True entries
                correspond to the candidate rows. It is modified IN PLACE so
                that afterwards only matched candidates remain True.

        Returns:
            Tuple[Tensor, Tensor]: Overlap of each matched candidate with its
            gt, and the 0-based gt index of each matched candidate.
        """
        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
        # select candidate topk ious for dynamic-k calculation
        candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
        # Dynamic k for each gt, converted once to Python ints so the loop
        # below does not read a tensor element on every iteration.
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1).tolist()
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
            matching_matrix[pos_idx, gt_idx] = 1

        # Resolve candidates matched to more than one gt: keep the cheapest.
        prior_match_gt_mask = matching_matrix.sum(1) > 1
        if prior_match_gt_mask.any():
            _, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1)
            matching_matrix[prior_match_gt_mask, :] = 0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1

        # get foreground mask inside box and center prior
        fg_mask_inboxes = matching_matrix.sum(1) > 0
        valid_mask[valid_mask.clone()] = fg_mask_inboxes

        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        matched_pred_ious = (matching_matrix *
                             pairwise_ious).sum(1)[fg_mask_inboxes]
        return matched_pred_ious, matched_gt_inds
|