feat: add coordinate_mode arg (model/image/root_relative) to forward()
modeling_rtmw.py  CHANGED  (+83, -4)
@@ -22,9 +22,15 @@ class PoseOutput(ModelOutput):
 
     Args:
         keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
-            Predicted keypoint coordinates in format [x, y].
+            Predicted keypoint coordinates in format [x, y]. The coordinate system
+            depends on the `coordinate_mode` passed to `forward()`:
+            - ``"model"`` → raw SimCC space (model input resolution, e.g. 288×384 px)
+            - ``"image"`` → original image space, scaled via the supplied `bbox`
+            - ``"root_relative"`` → root-normalised: origin at mid-hip, unit = half the inter-hip distance
         scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`):
-            Predicted keypoint confidence scores.
+            Predicted keypoint confidence scores in [0, 1].
+        coordinate_mode (`str`):
+            The coordinate system `keypoints` is expressed in (mirrors the argument passed to `forward()`).
         loss (`torch.FloatTensor`, *optional*):
             Loss value if training.
         pred_x (`torch.FloatTensor`, *optional*):
@@ -35,6 +41,7 @@ class PoseOutput(ModelOutput):
 
     keypoints: torch.FloatTensor = None
     scores: torch.FloatTensor = None
+    coordinate_mode: Optional[str] = None
     loss: Optional[torch.FloatTensor] = None
     pred_x: Optional[torch.FloatTensor] = None
     pred_y: Optional[torch.FloatTensor] = None
@@ -1338,6 +1345,8 @@ class RTMWModel(PreTrainedModel):
     def forward(
         self,
         pixel_values=None,
+        bbox=None,
+        coordinate_mode: str = "image",
         labels=None,
         output_hidden_states=None,
         return_dict=None,
@@ -1347,8 +1356,28 @@ class RTMWModel(PreTrainedModel):
 
         Args:
             pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
-                Pixel values
-
+                Pixel values cropped and resized to the model's input resolution
+                (e.g. 288×384). Use `RTMWImageProcessor` or prepare manually with
+                ImageNet normalisation.
+            bbox (`torch.FloatTensor` of shape `(batch_size, 4)` or `(4,)`, *optional*):
+                Person bounding boxes in the **original** image, as
+                ``[x1, y1, x2, y2]`` pixel coordinates. Required when
+                ``coordinate_mode="image"``; ignored otherwise.
+            coordinate_mode (`str`, *optional*, defaults to ``"image"``):
+                How to express the returned keypoint coordinates:
+
+                - ``"model"`` → raw SimCC space (same resolution as the
+                  model input, e.g. 288×384 px). No extra arguments needed.
+                - ``"image"`` → rescaled back to the original image pixel
+                  space using the supplied ``bbox``. If ``bbox`` is ``None`` the
+                  output falls back to ``"model"`` space with a warning.
+                - ``"root_relative"`` → root-normalised coordinates. The root is
+                  the midpoint of the left-hip (kp 11) and right-hip (kp 12)
+                  joints. All keypoints are translated so the root is at the
+                  origin, then divided by half the inter-hip distance so that
+                  each hip lands at unit distance from the origin. The three
+                  modes are mutually exclusive; the normalisation operates on
+                  model-space coordinates.
             labels (`List[Dict]`, *optional*):
                 Labels for computing the pose estimation loss.
             output_hidden_states (`bool`, *optional*):
@@ -1361,6 +1390,7 @@ class RTMWModel(PreTrainedModel):
             If return_dict=True, `PoseOutput` is returned.
             If return_dict=False, a tuple is returned with keypoints and scores.
         """
+        import warnings
         return_dict = return_dict if return_dict is not None else True
 
         # Get inputs
@@ -1400,10 +1430,59 @@ class RTMWModel(PreTrainedModel):
             0.0, 1.0,
         )
 
+        # ── Coordinate transform ─────────────────────────────────────────────
+        # Keypoints are currently in model-input space:
+        #   x in [0, model_w), y in [0, model_h)
+        # e.g. model_w=288, model_h=384 for rtmw-l-384x288.
+        if coordinate_mode == "image":
+            if bbox is None:
+                warnings.warn(
+                    "coordinate_mode='image' requires bbox=[x1,y1,x2,y2] per image. "
+                    "Falling back to model-space coordinates.",
+                    UserWarning, stacklevel=2,
+                )
+                coordinate_mode = "model"
+            else:
+                # bbox: (B, 4) or (4,) → normalise to (B, 4)
+                bbox_t = torch.as_tensor(bbox, dtype=keypoints.dtype, device=keypoints.device)
+                if bbox_t.dim() == 1:
+                    bbox_t = bbox_t.unsqueeze(0).expand(keypoints.shape[0], -1)
+                model_h = pixel_values.shape[2]  # H dim of model input
+                model_w = pixel_values.shape[3]  # W dim of model input
+                x1 = bbox_t[:, 0:1]  # (B, 1)
+                y1 = bbox_t[:, 1:2]
+                x2 = bbox_t[:, 2:3]
+                y2 = bbox_t[:, 3:4]
+                scale_x = (x2 - x1) / model_w  # (B, 1)
+                scale_y = (y2 - y1) / model_h  # (B, 1)
+                # keypoints (B, K, 2): the (B, 1) scales broadcast over K
+                keypoints = keypoints.clone()
+                keypoints[:, :, 0] = keypoints[:, :, 0] * scale_x + x1
+                keypoints[:, :, 1] = keypoints[:, :, 1] * scale_y + y1
+
+        elif coordinate_mode == "root_relative":
+            # Root = midpoint of left_hip (11) and right_hip (12).
+            # Scale = half the inter-hip distance so each hip is at unit
+            # distance from the root. Clamp to ≥1 px to guard against
+            # degenerate detections where the hips are co-located.
+            left_hip = keypoints[:, 11, :]   # (B, 2)
+            right_hip = keypoints[:, 12, :]  # (B, 2)
+            root = 0.5 * (left_hip + right_hip)  # (B, 2)
+            scale = 0.5 * torch.norm(right_hip - left_hip, dim=-1, keepdim=True)  # (B, 1)
+            scale = scale.clamp(min=1.0)
+            keypoints = (keypoints - root.unsqueeze(1)) / scale.unsqueeze(1)
+
+        elif coordinate_mode != "model":
+            raise ValueError(
+                f"coordinate_mode must be 'model', 'image', or 'root_relative', got {coordinate_mode!r}"
+            )
+        # ─────────────────────────────────────────────────────────────────────
+
         if return_dict:
             return PoseOutput(
                 keypoints=keypoints,
                 scores=scores,
+                coordinate_mode=coordinate_mode,
                 pred_x=pred_x,
                 pred_y=pred_y
             )
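A minimal usage sketch of the new arguments (the checkpoint id and the random tensors are illustrative only; `RTMWModel` and the 384×288 input size come from the diff above):

    import torch

    # Illustrative checkpoint id; substitute a real RTMW checkpoint.
    model = RTMWModel.from_pretrained("rtmw-l-384x288")
    model.eval()

    pixel_values = torch.randn(1, 3, 384, 288)          # person crop at model input size (H=384, W=288)
    bbox = torch.tensor([[120.0, 80.0, 420.0, 560.0]])  # [x1, y1, x2, y2] in the original image

    with torch.no_grad():
        # keypoints mapped back to original-image pixels via bbox
        out_img = model(pixel_values=pixel_values, bbox=bbox, coordinate_mode="image")
        # origin at mid-hip, hips at unit distance; no bbox needed
        out_rel = model(pixel_values=pixel_values, coordinate_mode="root_relative")

    print(out_img.coordinate_mode, out_img.keypoints.shape)  # image (1, num_keypoints, 2)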
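To make the root-relative convention concrete, here is a small self-contained check (a sketch; indices 11 and 12 are the COCO left/right hips used by the diff):

    import torch

    # Toy pose: 17 COCO-format keypoints, batch of 1; only the hips matter here.
    kpts = torch.zeros(1, 17, 2)
    kpts[0, 11] = torch.tensor([100.0, 200.0])  # left hip
    kpts[0, 12] = torch.tensor([140.0, 200.0])  # right hip

    root = 0.5 * (kpts[:, 11] + kpts[:, 12])  # mid-hip: (120, 200)
    scale = (0.5 * torch.norm(kpts[:, 12] - kpts[:, 11], dim=-1,
                              keepdim=True)).clamp(min=1.0)  # half inter-hip distance: 20.0
    rel = (kpts - root.unsqueeze(1)) / scale.unsqueeze(1)

    print(rel[0, 11], rel[0, 12])  # tensor([-1., 0.]) tensor([1., 0.])

The clamp mirrors the degenerate-hip guard in the diff; applying it after the 0.5 factor keeps the scale itself from dropping below one model-space pixel.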