akore committed (verified)
Commit fa156ee · 1 Parent(s): bbf01fb

feat: add coordinate_mode arg (model/image/root_relative) to forward()

Files changed (1):
modeling_rtmw.py +83 -4
modeling_rtmw.py CHANGED
@@ -22,9 +22,15 @@ class PoseOutput(ModelOutput):
 
     Args:
         keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
-            Predicted keypoint coordinates in format [x, y].
+            Predicted keypoint coordinates in format [x, y]. The coordinate system
+            depends on the `coordinate_mode` passed to `forward()`:
+            - ``"model"``: raw SimCC space (model input resolution, e.g. 288×384 px)
+            - ``"image"``: original image space, scaled via the supplied `bbox`
+            - ``"root_relative"``: root-normalised (origin at mid-hip, unit = half the inter-hip distance)
         scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`):
-            Predicted keypoint confidence scores.
+            Predicted keypoint confidence scores in [0, 1].
+        coordinate_mode (`str`):
+            The coordinate system `keypoints` is expressed in (mirrors the argument passed to `forward()`).
         loss (`torch.FloatTensor`, *optional*):
            Loss value if training.
         pred_x (`torch.FloatTensor`, *optional*):
@@ -35,6 +41,7 @@ class PoseOutput(ModelOutput):
 
     keypoints: torch.FloatTensor = None
     scores: torch.FloatTensor = None
+    coordinate_mode: Optional[str] = None
     loss: Optional[torch.FloatTensor] = None
     pred_x: Optional[torch.FloatTensor] = None
     pred_y: Optional[torch.FloatTensor] = None
@@ -1338,6 +1345,8 @@ class RTMWModel(PreTrainedModel):
     def forward(
         self,
         pixel_values=None,
+        bbox=None,
+        coordinate_mode: str = "image",
         labels=None,
         output_hidden_states=None,
         return_dict=None,
@@ -1347,8 +1356,28 @@ class RTMWModel(PreTrainedModel):
 
         Args:
             pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
-                Pixel values. Pixel values can be obtained using
-                RTMWImageProcessor.
+                Pixel values cropped and resized to the model's input resolution
+                (e.g. 288×384). Use `RTMWImageProcessor` or prepare manually with
+                ImageNet normalisation.
+            bbox (`torch.FloatTensor` of shape `(batch_size, 4)` or `(4,)`, *optional*):
+                Person bounding boxes in the **original** image, as
+                ``[x1, y1, x2, y2]`` pixel coordinates. Required when
+                ``coordinate_mode="image"``; ignored otherwise.
+            coordinate_mode (`str`, *optional*, defaults to ``"image"``):
+                How to express the returned keypoint coordinates:
+
+                - ``"model"``: raw SimCC space (same resolution as the
+                  model input, e.g. 288×384 px). No extra arguments needed.
+                - ``"image"``: rescaled back to the original image pixel
+                  space using the supplied ``bbox``. If ``bbox`` is ``None`` the
+                  output falls back to ``"model"`` space with a warning.
+                - ``"root_relative"``: root-normalised coordinates. The root is
+                  the midpoint of the left-hip (kp 11) and right-hip (kp 12)
+                  joints. All keypoints are translated so the root sits at the
+                  origin, then divided by half the inter-hip distance, so each
+                  hip lands at unit distance from the origin. Mutually exclusive
+                  with ``"image"``: the normalisation is applied to model-space
+                  coordinates, so pick exactly one mode per call.
            labels (`List[Dict]`, *optional*):
                Labels for computing the pose estimation loss.
            output_hidden_states (`bool`, *optional*):
@@ -1361,6 +1390,7 @@ class RTMWModel(PreTrainedModel):
             If return_dict=True, `PoseOutput` is returned.
             If return_dict=False, a tuple is returned with keypoints and scores.
         """
+        import warnings
         return_dict = return_dict if return_dict is not None else True
 
         # Get inputs
@@ -1400,10 +1430,59 @@ class RTMWModel(PreTrainedModel):
             0.0, 1.0,
         )
 
+        # ── Coordinate transform ────────────────────────────────────────────
+        # Keypoints are currently in model-input space:
+        #   x in [0, model_w), y in [0, model_h)
+        #   e.g. model_w=288, model_h=384 for rtmw-l-384x288.
+        if coordinate_mode == "image":
+            if bbox is None:
+                warnings.warn(
+                    "coordinate_mode='image' requires bbox=[x1,y1,x2,y2] per image. "
+                    "Falling back to model-space coordinates.",
+                    UserWarning, stacklevel=2,
+                )
+                coordinate_mode = "model"
+            else:
+                # bbox: (4,) or (B, 4); broadcast a single box over the batch
+                bbox_t = torch.as_tensor(bbox, dtype=keypoints.dtype, device=keypoints.device)
+                if bbox_t.dim() == 1:
+                    bbox_t = bbox_t.unsqueeze(0).expand(keypoints.shape[0], -1)
+                model_h = pixel_values.shape[2]  # H dim of model input
+                model_w = pixel_values.shape[3]  # W dim of model input
+                x1 = bbox_t[:, 0:1]  # (B, 1)
+                y1 = bbox_t[:, 1:2]
+                x2 = bbox_t[:, 2:3]
+                y2 = bbox_t[:, 3:4]
+                scale_x = (x2 - x1) / model_w  # (B, 1)
+                scale_y = (y2 - y1) / model_h  # (B, 1)
+                # keypoints: (B, K, 2); the (B, 1) scales broadcast over K
+                keypoints = keypoints.clone()
+                keypoints[:, :, 0] = keypoints[:, :, 0] * scale_x + x1
+                keypoints[:, :, 1] = keypoints[:, :, 1] * scale_y + y1
+
+        elif coordinate_mode == "root_relative":
+            # Root = midpoint of left_hip (11) and right_hip (12).
+            # Scale = half the inter-hip distance so each hip is at unit
+            # distance from the root. Clamp to >= 1 px to guard against
+            # degenerate detections where the hips are co-located.
+            left_hip = keypoints[:, 11, :]   # (B, 2)
+            right_hip = keypoints[:, 12, :]  # (B, 2)
+            root = 0.5 * (left_hip + right_hip)  # (B, 2)
+            half_hip = 0.5 * torch.norm(right_hip - left_hip, dim=-1, keepdim=True)  # (B, 1)
+            scale = half_hip.clamp(min=1.0)
+            keypoints = (keypoints - root.unsqueeze(1)) / scale.unsqueeze(1)
+
+        elif coordinate_mode != "model":
+            raise ValueError(
+                f"coordinate_mode must be 'model', 'image', or 'root_relative', got {coordinate_mode!r}"
+            )
+        # ────────────────────────────────────────────────────────────────────
+
         if return_dict:
             return PoseOutput(
                 keypoints=keypoints,
                 scores=scores,
+                coordinate_mode=coordinate_mode,
                 pred_x=pred_x,
                 pred_y=pred_y
            )
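
As a sanity check on the "image" branch above, here is the back-projection reimplemented standalone in plain torch; a minimal sketch, with the 288×384 input size taken from the diff's own comments and hypothetical bbox numbers. A keypoint at the centre of the model input should land at the centre of the box, and the model-space origin at its top-left corner:

import torch

# Mirrors the "image" branch of the coordinate transform above, outside the model.
model_w, model_h = 288, 384                          # rtmw-l-384x288 input (W, H)
bbox = torch.tensor([[100.0, 50.0, 388.0, 434.0]])   # one box: [x1, y1, x2, y2] (hypothetical)

keypoints = torch.tensor([[[144.0, 192.0],           # model-space centre
                           [0.0, 0.0]]])             # model-space top-left corner

x1, y1 = bbox[:, 0:1], bbox[:, 1:2]
x2, y2 = bbox[:, 2:3], bbox[:, 3:4]
scale_x = (x2 - x1) / model_w                        # (B, 1)
scale_y = (y2 - y1) / model_h

img = keypoints.clone()
img[:, :, 0] = img[:, :, 0] * scale_x + x1
img[:, :, 1] = img[:, :, 1] * scale_y + y1

print(img[0, 0])  # tensor([244., 242.]) -> centre of the bbox
print(img[0, 1])  # tensor([100.,  50.]) -> bbox top-left corner

The box here is deliberately 288×384 px, so both scale factors are exactly 1 and the expected outputs can be read off directly.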
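Same approach for the "root_relative" branch, using a hypothetical pose and the COCO hip indices 11/12 assumed by the diff: after normalisation both hips should sit at unit distance from the mid-hip origin.

import torch

K = 17                                      # COCO-style body keypoints (assumed layout)
kpts = torch.zeros(1, K, 2)                 # model-space coordinates, hypothetical pose
kpts[0, 11] = torch.tensor([120.0, 300.0])  # left hip  (kp 11)
kpts[0, 12] = torch.tensor([168.0, 300.0])  # right hip (kp 12)
kpts[0, 0]  = torch.tensor([144.0, 100.0])  # nose, above the hips

left_hip, right_hip = kpts[:, 11, :], kpts[:, 12, :]
root = 0.5 * (left_hip + right_hip)         # (B, 2) mid-hip origin
half_hip = 0.5 * torch.norm(right_hip - left_hip, dim=-1, keepdim=True)
scale = half_hip.clamp(min=1.0)             # guard against co-located hips

rel = (kpts - root.unsqueeze(1)) / scale.unsqueeze(1)
print(torch.norm(rel[0, 11]), torch.norm(rel[0, 12]))  # tensor(1.), tensor(1.)
print(rel[0, 0])  # tensor([ 0.0000, -8.3333]): nose ~8.3 half-hip-widths above the root

The clamp only bites when a degenerate detection places both hips within a pixel of each other; for any reasonable pose it is a no-op.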