akore committed on
Commit 1294ede · verified · 1 Parent(s): 7a9f27a

Add rtmw-m-256x192 RTMW/RTMDet HF port

Files changed (6)
  1. README.md +137 -0
  2. config.json +49 -0
  3. configuration_rtmw.py +145 -0
  4. model.safetensors +3 -0
  5. modeling_rtmw.py +1406 -0
  6. preprocessor_config.json +39 -0
README.md ADDED
@@ -0,0 +1,137 @@
---
license: apache-2.0
tags:
- pose-estimation
- wholebody-pose
- rtmpose
- rtmw
- keypoint-detection
- computer-vision
datasets:
- coco-wholebody
pipeline_tag: keypoint-detection
---

# rtmw-m-256x192

This is a Hugging Face-compatible port of **rtmw-m-256x192** from [OpenMMLab MMPose](https://github.com/open-mmlab/mmpose).

RTMW (**R**eal-**T**ime **M**ulti-person **W**holebody pose estimation) extends RTMPose to predict 133 wholebody keypoints covering the body, face, hands, and feet simultaneously.

The model is trained on **Cocktail14** — a mixture of 14 public datasets — and evaluated on COCO-WholeBody v1.0 val.

## Model description

- **Architecture**: CSPNeXt backbone + CSPNeXtPAFPN neck + RTMWHead (SimCC with GAU)
- **Keypoints**: 133 (17 body + 6 feet + 68 face + 21 left hand + 21 right hand)
- **Codec**: SimCC with Gaussian label smoothing
- **Uses custom code** — load with `trust_remote_code=True`

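The SimCC codec represents each keypoint as two 1-D distributions, one per axis, discretized at `simcc_split_ratio` bins per pixel, so decoding reduces to an argmax and a division. The following is a minimal illustrative sketch, assuming the `(192, 256)` input size and split ratio `2.0` from `config.json`; the function name and the min-based confidence proxy are assumptions for illustration, not the port's exact decode (which lives in `modeling_rtmw.py`):

```python
def decode_simcc(simcc_x, simcc_y, split_ratio=2.0):
    """Decode one keypoint from its two 1-D SimCC distributions.

    simcc_x: W * split_ratio logits (here 192 * 2 = 384 bins)
    simcc_y: H * split_ratio logits (here 256 * 2 = 512 bins)
    Returns (x, y) in input-image pixels plus a confidence proxy.
    """
    bx = max(range(len(simcc_x)), key=lambda i: simcc_x[i])  # argmax over x bins
    by = max(range(len(simcc_y)), key=lambda i: simcc_y[i])  # argmax over y bins
    x, y = bx / split_ratio, by / split_ratio  # sub-pixel bins -> pixels
    conf = min(simcc_x[bx], simcc_y[by])  # conservative proxy (assumption)
    return x, y, conf

# Toy distributions peaking at bin 100 (x) and bin 300 (y):
sx = [0.0] * 384
sy = [0.0] * 512
sx[100], sy[300] = 0.9, 0.8
print(decode_simcc(sx, sy))  # (50.0, 150.0, 0.8)
```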
## Performance on COCO-WholeBody v1.0 val

Results use a person detector with human AP = 56.4 on COCO val2017.

| Model | Input | Body AP | Body AR | Foot AP | Foot AR | Face AP | Face AR | Hand AP | Hand AR | Whole AP | Whole AR |
|:------|:------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:--------:|:--------:|
| rtmw-m-256x192 | 256×192 | 0.676 | 0.747 | 0.671 | 0.794 | 0.783 | 0.854 | 0.491 | 0.604 | 0.582 | 0.673 |

## Usage

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor

# The architecture is custom, so loading requires trust_remote_code=True.
# RTMWConfig / RTMWModel are defined in configuration_rtmw.py and
# modeling_rtmw.py in this repository (imported here from a local copy):
from rtmw_modules.configuration_rtmw import RTMWConfig
from rtmw_modules.modeling_rtmw import RTMWModel

config = RTMWConfig.from_pretrained("akore/rtmw-m-256x192", trust_remote_code=True)
model = RTMWModel.from_pretrained("akore/rtmw-m-256x192", trust_remote_code=True)
model.eval()

processor = AutoImageProcessor.from_pretrained("akore/rtmw-m-256x192")
image = Image.open("your_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# outputs.keypoints: (1, 133, 2) — [x, y] in image coordinates
# outputs.scores: (1, 133) — confidence in [0, 1]
print(outputs.keypoints.shape, outputs.scores.shape)
```

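The 133 outputs follow the COCO-WholeBody ordering given above: body, feet, face, left hand, right hand. A small sketch of splitting a decoded keypoint list into those groups, assuming that ordering (`split_wholebody` is a hypothetical helper for illustration, not part of the port):

```python
# Index ranges for the COCO-WholeBody keypoint layout used by RTMW:
# 17 body + 6 feet + 68 face + 21 left hand + 21 right hand = 133.
GROUPS = {
    "body": (0, 17),
    "feet": (17, 23),
    "face": (23, 91),
    "left_hand": (91, 112),
    "right_hand": (112, 133),
}

def split_wholebody(keypoints):
    """Split a sequence of 133 (x, y) pairs into named wholebody groups."""
    assert len(keypoints) == 133
    return {name: keypoints[a:b] for name, (a, b) in GROUPS.items()}

kpts = [(float(i), float(i)) for i in range(133)]
parts = split_wholebody(kpts)
print({k: len(v) for k, v in parts.items()})
# {'body': 17, 'feet': 6, 'face': 68, 'left_hand': 21, 'right_hand': 21}
```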
## Cocktail14 training datasets

| Dataset | Link |
|---------|------|
| AI Challenger | [mmpose docs](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#aic) |
| CrowdPose | [mmpose docs](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#crowdpose) |
| MPII | [mmpose docs](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#mpii) |
| sub-JHMDB | [mmpose docs](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#sub-jhmdb-dataset) |
| Halpe | [mmpose docs](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_wholebody_keypoint.html#halpe) |
| PoseTrack18 | [mmpose docs](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#posetrack18) |
| COCO-WholeBody | [GitHub](https://github.com/jin-s13/COCO-WholeBody/) |
| UBody | [GitHub](https://github.com/IDEA-Research/OSX) |
| Human-Art | [mmpose docs](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#human-art-dataset) |
| WFLW | [project page](https://wywu.github.io/projects/LAB/WFLW.html) |
| 300W | [project page](https://ibug.doc.ic.ac.uk/resources/300-W/) |
| COFW | [project page](http://www.vision.caltech.edu/xpburgos/ICCV13/) |
| LaPa | [GitHub](https://github.com/JDAI-CV/lapa-dataset) |
| InterHand | [project page](https://mks0601.github.io/InterHand2.6M/) |

## Score normalization

Raw SimCC confidence scores vary across model variants (0–1 for 256×192 models, 0–10 for 384×288 models). This port applies fixed min–max normalization so all model variants output scores in **[0, 1]**. The `score_min` and `score_max` hyperparameters used are stored in the config and were determined empirically from real-world inference.

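The normalization itself is plain min–max scaling with clamping. A sketch of the mapping, assuming the `score_min`/`score_max` values stored in `config.json` (defaults `0.0` and `1.0`; the `10.0` below is a hypothetical calibration for illustration, not a value taken from any released config):

```python
def normalize_score(raw, score_min=0.0, score_max=1.0):
    """Map a raw SimCC confidence into [0, 1] via fixed min-max scaling."""
    scaled = (raw - score_min) / (score_max - score_min)
    return min(max(scaled, 0.0), 1.0)  # clamp outliers beyond the fitted range

# E.g. a variant whose raw scores were calibrated to score_max=10.0:
print(normalize_score(7.5, score_min=0.0, score_max=10.0))  # 0.75
```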
## Citation

```bibtex
@article{jiang2024rtmw,
  title={RTMW: Real-Time Multi-Person 2D and 3D Whole-body Pose Estimation},
  author={Jiang, Tao and Xie, Xinchen and Li, Yining},
  journal={arXiv preprint arXiv:2407.08634},
  year={2024}
}

@misc{https://doi.org/10.48550/arxiv.2303.07399,
  doi = {10.48550/ARXIV.2303.07399},
  url = {https://arxiv.org/abs/2303.07399},
  author = {Jiang, Tao and Lu, Peng and Zhang, Li and Ma, Ningsheng and Han, Rui and Lyu, Chengqi and Li, Yining and Chen, Kai},
  title = {RTMPose: Real-Time Multi-Person Pose Estimation based on MMPose},
  publisher = {arXiv},
  year = {2023},
  copyright = {Creative Commons Attribution 4.0 International}
}

@misc{mmpose2020,
  title={OpenMMLab Pose Estimation Toolbox and Benchmark},
  author={MMPose Contributors},
  howpublished = {\url{https://github.com/open-mmlab/mmpose}},
  year={2020}
}

@misc{lyu2022rtmdet,
  title={RTMDet: An Empirical Study of Designing Real-Time Object Detectors},
  author={Chengqi Lyu and Wenwei Zhang and Haian Huang and Yue Zhou and Yudong Wang and Yanyi Liu and Shilong Zhang and Kai Chen},
  year={2022},
  eprint={2212.07784},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}

@inproceedings{jin2020whole,
  title={Whole-Body Human Pose Estimation in the Wild},
  author={Jin, Sheng and Xu, Lumin and Xu, Jin and Wang, Can and Liu, Wentao and Qian, Chen and Ouyang, Wanli and Luo, Ping},
  booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
  year={2020}
}
```
config.json ADDED
@@ -0,0 +1,49 @@
{
  "backbone_arch": "P5",
  "backbone_channel_attention": true,
  "backbone_deepen_factor": 0.67,
  "backbone_expand_ratio": 0.5,
  "backbone_widen_factor": 0.75,
  "decoder_normalize": false,
  "decoder_sigma": [4.9, 5.66],
  "decoder_use_dark": false,
  "gau_act_fn": "SiLU",
  "gau_drop_path": 0.0,
  "gau_dropout_rate": 0.0,
  "gau_expansion_factor": 2,
  "gau_hidden_dims": 256,
  "gau_pos_enc": false,
  "gau_s": 128,
  "gau_use_rel_bias": false,
  "head_final_layer_kernel_size": 7,
  "head_in_channels": 768,
  "head_in_featuremap_size": [6, 8],
  "input_size": [192, 256],
  "model_type": "rtmw",
  "neck_expand_ratio": 0.5,
  "neck_in_channels": [192, 384, 768],
  "neck_num_csp_blocks": 2,
  "neck_out_channels": null,
  "num_keypoints": 133,
  "score_max": 1.0,
  "score_min": 0.0,
  "simcc_split_ratio": 2.0,
  "transformers_version": "5.2.0",
  "auto_map": {
    "AutoConfig": "configuration_rtmw.RTMWConfig",
    "AutoModelForImageProcessing": "modeling_rtmw.RTMWModel"
  }
}
configuration_rtmw.py ADDED
@@ -0,0 +1,145 @@
from typing import List, Optional, Tuple

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class RTMWConfig(PretrainedConfig):
    """
    Configuration class for RTMW models from OpenMMLab.

    Args:
        backbone_arch (`str`, *optional*, defaults to `"P5"`):
            Architecture of the backbone. Can be either "P5" or "P6".
        backbone_expand_ratio (`float`, *optional*, defaults to `0.5`):
            Expand ratio of the backbone channels.
        backbone_deepen_factor (`float`, *optional*, defaults to `0.67`):
            Factor to deepen the backbone stages.
        backbone_widen_factor (`float`, *optional*, defaults to `0.75`):
            Factor to widen the backbone channels.
        backbone_channel_attention (`bool`, *optional*, defaults to `True`):
            Whether to use channel attention in the backbone.
        neck_in_channels (`List[int]`, *optional*, defaults to `[192, 384, 768]`):
            Input channels for the neck.
        neck_out_channels (`int`, *optional*, defaults to `None`):
            Output channels for the neck. If `None`, each pyramid level keeps
            its input channel count.
        neck_num_csp_blocks (`int`, *optional*, defaults to `2`):
            Number of CSP blocks in the neck.
        neck_expand_ratio (`float`, *optional*, defaults to `0.5`):
            Expand ratio for the neck channels.
        num_keypoints (`int`, *optional*, defaults to `133`):
            Number of keypoints to predict.
        input_size (`Tuple[int, int]`, *optional*, defaults to `(192, 256)`):
            Default input image size [width, height].
        simcc_split_ratio (`float`, *optional*, defaults to `2.0`):
            Split ratio of pixels for SimCC.
        decoder_sigma (`Tuple[float, float]`, *optional*, defaults to `(4.9, 5.66)`):
            Sigma values for the Gaussian distribution in the SimCC decoder.
        decoder_normalize (`bool`, *optional*, defaults to `False`):
            Whether to normalize the decoder outputs.
        decoder_use_dark (`bool`, *optional*, defaults to `False`):
            Whether to use DARK post-processing in the decoder.
        gau_hidden_dims (`int`, *optional*, defaults to `256`):
            Hidden dimensions for the Gated Attention Unit.
        gau_s (`int`, *optional*, defaults to `128`):
            Dimension of the query/key projections in the Gated Attention Unit.
        gau_expansion_factor (`int`, *optional*, defaults to `2`):
            Expansion factor for the Gated Attention Unit.
        gau_dropout_rate (`float`, *optional*, defaults to `0.0`):
            Dropout rate for the Gated Attention Unit.
        gau_drop_path (`float`, *optional*, defaults to `0.0`):
            Drop-path rate for the Gated Attention Unit.
        gau_act_fn (`str`, *optional*, defaults to `"SiLU"`):
            Activation function used in the Gated Attention Unit.
        gau_use_rel_bias (`bool`, *optional*, defaults to `False`):
            Whether to use relative position bias in the Gated Attention Unit.
        gau_pos_enc (`bool`, *optional*, defaults to `False`):
            Whether to add positional encoding in the Gated Attention Unit.
        head_in_channels (`int`, *optional*, defaults to `768`):
            Input channels for the detection head.
        head_in_featuremap_size (`Tuple[int, int]`, *optional*, defaults to `(6, 8)`):
            Input feature map size for the head.
        head_final_layer_kernel_size (`int`, *optional*, defaults to `7`):
            Kernel size for the final layer in the head.
        score_min (`float`, *optional*, defaults to `0.0`):
            Minimum raw score used for fixed min-max normalization of keypoint
            confidence scores to the [0, 1] range. Empirically determined from
            the model's score distribution.
        score_max (`float`, *optional*, defaults to `1.0`):
            Maximum raw score used for fixed min-max normalization of keypoint
            confidence scores to the [0, 1] range. Empirically determined from
            the model's score distribution (p99.9 of observed scores).
        **kwargs:
            Additional parameters passed to the parent class.
    """

    model_type = "rtmw"

    def __init__(
        self,
        backbone_arch: str = "P5",
        backbone_expand_ratio: float = 0.5,
        backbone_deepen_factor: float = 0.67,
        backbone_widen_factor: float = 0.75,
        backbone_channel_attention: bool = True,
        neck_in_channels: List[int] = [192, 384, 768],
        neck_out_channels: Optional[int] = None,
        neck_num_csp_blocks: int = 2,
        neck_expand_ratio: float = 0.5,
        num_keypoints: int = 133,
        input_size: Tuple[int, int] = (192, 256),
        simcc_split_ratio: float = 2.0,
        decoder_sigma: Tuple[float, float] = (4.9, 5.66),
        decoder_normalize: bool = False,
        decoder_use_dark: bool = False,
        gau_hidden_dims: int = 256,
        gau_s: int = 128,
        gau_expansion_factor: int = 2,
        gau_dropout_rate: float = 0.0,
        gau_drop_path: float = 0.0,
        gau_act_fn: str = "SiLU",
        gau_use_rel_bias: bool = False,
        gau_pos_enc: bool = False,
        head_in_channels: int = 768,
        head_in_featuremap_size: Tuple[int, int] = (6, 8),
        head_final_layer_kernel_size: int = 7,
        score_min: float = 0.0,
        score_max: float = 1.0,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Backbone config
        self.backbone_arch = backbone_arch
        self.backbone_expand_ratio = backbone_expand_ratio
        self.backbone_deepen_factor = backbone_deepen_factor
        self.backbone_widen_factor = backbone_widen_factor
        self.backbone_channel_attention = backbone_channel_attention

        # Neck config
        self.neck_in_channels = neck_in_channels
        self.neck_out_channels = neck_out_channels
        self.neck_num_csp_blocks = neck_num_csp_blocks
        self.neck_expand_ratio = neck_expand_ratio

        # Pose estimation specific config
        self.num_keypoints = num_keypoints
        self.input_size = input_size
        self.simcc_split_ratio = simcc_split_ratio

        # Decoder config
        self.decoder_sigma = decoder_sigma
        self.decoder_normalize = decoder_normalize
        self.decoder_use_dark = decoder_use_dark

        # GAU config (for RTMWHead)
        self.gau_hidden_dims = gau_hidden_dims
        self.gau_s = gau_s
        self.gau_expansion_factor = gau_expansion_factor
        self.gau_dropout_rate = gau_dropout_rate
        self.gau_drop_path = gau_drop_path
        self.gau_act_fn = gau_act_fn
        self.gau_use_rel_bias = gau_use_rel_bias
        self.gau_pos_enc = gau_pos_enc

        # Head config
        self.head_in_channels = head_in_channels
        self.head_in_featuremap_size = head_in_featuremap_size
        self.head_final_layer_kernel_size = head_final_layer_kernel_size

        # Score normalization config
        self.score_min = score_min
        self.score_max = score_max
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e83a737e23882af050ed44728f8d184b575867d979cefd658d32c9ae3a565775
size 129280692
modeling_rtmw.py ADDED
@@ -0,0 +1,1406 @@
1
+ from typing import Optional, Tuple, Union, Dict, Sequence
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from transformers.modeling_outputs import ModelOutput
9
+ from transformers.modeling_utils import PreTrainedModel
10
+ from transformers.utils import logging
11
+
12
+ from .configuration_rtmw import RTMWConfig
13
+
14
+
15
+ logger = logging.get_logger(__name__)
16
+
17
+
18
+ @dataclass
19
+ class PoseOutput(ModelOutput):
20
+ """
21
+ Output type for pose estimation models.
22
+
23
+ Args:
24
+ keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
25
+ Predicted keypoint coordinates in format [x, y].
26
+ scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`):
27
+ Predicted keypoint confidence scores.
28
+ loss (`torch.FloatTensor`, *optional*):
29
+ Loss value if training.
30
+ pred_x (`torch.FloatTensor`, *optional*):
31
+ X-axis heatmap predictions from the SimCC representation.
32
+ pred_y (`torch.FloatTensor`, *optional*):
33
+ Y-axis heatmap predictions from the SimCC representation.
34
+ """
35
+
36
+ keypoints: torch.FloatTensor = None
37
+ scores: torch.FloatTensor = None
38
+ loss: Optional[torch.FloatTensor] = None
39
+ pred_x: Optional[torch.FloatTensor] = None
40
+ pred_y: Optional[torch.FloatTensor] = None
41
+
42
+
43
+ # Common layers and building blocks from RTMDet with adjustments for RTMW
44
+ class ConvModule(nn.Module):
45
+ """A conv block that bundles conv/norm/activation layers."""
46
+ def __init__(
47
+ self,
48
+ in_channels: int,
49
+ out_channels: int,
50
+ kernel_size: Union[int, Tuple[int, int]],
51
+ stride: Union[int, Tuple[int, int]] = 1,
52
+ padding: Union[int, Tuple[int, int]] = 0,
53
+ dilation: Union[int, Tuple[int, int]] = 1,
54
+ groups: int = 1,
55
+ bias: bool = True,
56
+ norm_cfg: Optional[Dict] = dict(type='BN'),
57
+ act_cfg: Optional[Dict] = dict(type='SiLU'),
58
+ inplace: bool = True,
59
+ ):
60
+ super().__init__()
61
+ self.with_norm = norm_cfg is not None
62
+ self.with_activation = act_cfg is not None
63
+
64
+ # Build convolution layer
65
+ self.conv = nn.Conv2d(
66
+ in_channels,
67
+ out_channels,
68
+ kernel_size=kernel_size,
69
+ stride=stride,
70
+ padding=padding,
71
+ dilation=dilation,
72
+ groups=groups,
73
+ bias=bias and not self.with_norm)
74
+
75
+ # Build normalization layer
76
+ if self.with_norm:
77
+ norm_channels = out_channels
78
+ # Use PyTorch default values to match MMPose's actual BN parameters during inference
79
+ # momentum doesn't affect inference, but eps is critical!
80
+ self.bn = nn.BatchNorm2d(norm_channels, momentum=0.1, eps=1e-05)
81
+
82
+ # Build activation layer
83
+ if self.with_activation:
84
+ if act_cfg['type'] == 'ReLU':
85
+ self.activate = nn.ReLU(inplace=inplace)
86
+ elif act_cfg['type'] == 'LeakyReLU':
87
+ self.activate = nn.LeakyReLU(negative_slope=0.1, inplace=inplace)
88
+ elif act_cfg['type'] == 'SiLU' or act_cfg['type'] == 'Swish':
89
+ self.activate = nn.SiLU(inplace=inplace)
90
+ else:
91
+ raise NotImplementedError(f"Activation {act_cfg['type']} not implemented")
92
+
93
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
94
+ x = self.conv(x)
95
+ if self.with_norm:
96
+ x = self.bn(x)
97
+ if self.with_activation:
98
+ x = self.activate(x)
99
+ return x
100
+
101
+
102
+ class DepthwiseSeparableConvModule(nn.Module):
103
+ """Depthwise separable convolution module."""
104
+ def __init__(
105
+ self,
106
+ in_channels: int,
107
+ out_channels: int,
108
+ kernel_size: Union[int, Tuple[int, int]],
109
+ stride: Union[int, Tuple[int, int]] = 1,
110
+ padding: Union[int, Tuple[int, int]] = 0,
111
+ dilation: Union[int, Tuple[int, int]] = 1,
112
+ norm_cfg: Optional[Dict] = dict(type='BN'),
113
+ act_cfg: Dict = dict(type='SiLU'),
114
+ **kwargs
115
+ ):
116
+ super().__init__()
117
+
118
+ # Depthwise convolution
119
+ self.depthwise_conv = ConvModule(
120
+ in_channels,
121
+ in_channels,
122
+ kernel_size,
123
+ stride=stride,
124
+ padding=padding,
125
+ dilation=dilation,
126
+ groups=in_channels,
127
+ norm_cfg=norm_cfg,
128
+ act_cfg=act_cfg,
129
+ **kwargs)
130
+
131
+ # Pointwise convolution
132
+ self.pointwise_conv = ConvModule(
133
+ in_channels,
134
+ out_channels,
135
+ 1,
136
+ norm_cfg=norm_cfg,
137
+ act_cfg=act_cfg,
138
+ **kwargs)
139
+
140
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
141
+ x = self.depthwise_conv(x)
142
+ x = self.pointwise_conv(x)
143
+ return x
144
+
145
+
146
+ class ChannelAttention(nn.Module):
147
+ """Channel attention Module."""
148
+ def __init__(self, channels: int) -> None:
149
+ super().__init__()
150
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
151
+ self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
152
+ self.act = nn.Hardsigmoid(inplace=True)
153
+
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
+ with torch.amp.autocast(enabled=False, device_type=x.device.type):
156
+ out = self.global_avgpool(x)
157
+ out = self.fc(out)
158
+ out = self.act(out)
159
+ return x * out
160
+
161
+
162
+ class CSPNeXtBlock(nn.Module):
163
+ """The basic bottleneck block used in CSPNeXt."""
164
+ def __init__(
165
+ self,
166
+ in_channels: int,
167
+ out_channels: int,
168
+ expansion: float = 0.5,
169
+ add_identity: bool = True,
170
+ use_depthwise: bool = False,
171
+ kernel_size: int = 5,
172
+ act_cfg: Dict = dict(type='SiLU'),
173
+ ) -> None:
174
+ super().__init__()
175
+ hidden_channels = int(out_channels * expansion)
176
+ conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
177
+
178
+ self.conv1 = conv(
179
+ in_channels,
180
+ hidden_channels,
181
+ 3,
182
+ stride=1,
183
+ padding=1,
184
+ act_cfg=act_cfg)
185
+
186
+ self.conv2 = DepthwiseSeparableConvModule(
187
+ hidden_channels,
188
+ out_channels,
189
+ kernel_size,
190
+ stride=1,
191
+ padding=kernel_size // 2,
192
+ act_cfg=act_cfg)
193
+
194
+ self.add_identity = add_identity and in_channels == out_channels
195
+
196
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
197
+ identity = x
198
+ out = self.conv1(x)
199
+ out = self.conv2(out)
200
+ if self.add_identity:
201
+ return out + identity
202
+ else:
203
+ return out
204
+
205
+
206
+ class CSPLayer(nn.Module):
207
+ """Cross Stage Partial Layer."""
208
+ def __init__(
209
+ self,
210
+ in_channels: int,
211
+ out_channels: int,
212
+ expand_ratio: float = 0.5,
213
+ num_blocks: int = 1,
214
+ add_identity: bool = True,
215
+ use_depthwise: bool = False,
216
+ use_cspnext_block: bool = False,
217
+ channel_attention: bool = False,
218
+ act_cfg: Dict = dict(type='SiLU'),
219
+ ) -> None:
220
+ super().__init__()
221
+ block = CSPNeXtBlock if use_cspnext_block else None # Default to CSPNeXtBlock
222
+ mid_channels = int(out_channels * expand_ratio)
223
+ self.channel_attention = channel_attention
224
+
225
+ self.main_conv = ConvModule(
226
+ in_channels,
227
+ mid_channels,
228
+ 1,
229
+ act_cfg=act_cfg)
230
+
231
+ self.short_conv = ConvModule(
232
+ in_channels,
233
+ mid_channels,
234
+ 1,
235
+ act_cfg=act_cfg)
236
+
237
+ self.final_conv = ConvModule(
238
+ 2 * mid_channels,
239
+ out_channels,
240
+ 1,
241
+ act_cfg=act_cfg)
242
+
243
+ self.blocks = nn.Sequential(*[
244
+ block(
245
+ mid_channels,
246
+ mid_channels,
247
+ 1.0,
248
+ add_identity,
249
+ use_depthwise) for _ in range(num_blocks)
250
+ ])
251
+
252
+ if channel_attention:
253
+ self.attention = ChannelAttention(2 * mid_channels)
254
+
255
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
256
+ x_short = self.short_conv(x)
257
+ x_main = self.main_conv(x)
258
+ x_main = self.blocks(x_main)
259
+ x_final = torch.cat((x_main, x_short), dim=1)
260
+
261
+ if self.channel_attention:
262
+ x_final = self.attention(x_final)
263
+
264
+ return self.final_conv(x_final)
265
+
266
+
267
+ class SPPBottleneck(nn.Module):
268
+ """Spatial pyramid pooling layer."""
269
+ def __init__(
270
+ self,
271
+ in_channels: int,
272
+ out_channels: int,
273
+ kernel_sizes: Tuple[int, ...] = (5, 9, 13),
274
+ act_cfg: Dict = dict(type='SiLU'),
275
+ ):
276
+ super().__init__()
277
+ mid_channels = in_channels // 2
278
+ self.conv1 = ConvModule(
279
+ in_channels,
280
+ mid_channels,
281
+ 1,
282
+ stride=1,
283
+ act_cfg=act_cfg)
284
+
285
+ self.poolings = nn.ModuleList([
286
+ nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
287
+ for ks in kernel_sizes
288
+ ])
289
+
290
+ conv2_channels = mid_channels * (len(kernel_sizes) + 1)
291
+ self.conv2 = ConvModule(
292
+ conv2_channels,
293
+ out_channels,
294
+ 1,
295
+ act_cfg=act_cfg)
296
+
297
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
298
+ x = self.conv1(x)
299
+ with torch.amp.autocast(enabled=False, device_type=x.device.type):
300
+ x = torch.cat(
301
+ [x] + [pooling(x) for pooling in self.poolings], dim=1)
302
+ x = self.conv2(x)
303
+ return x
304
+
305
+
306
+ class CSPNeXt(nn.Module):
307
+ """CSPNeXt backbone used in RTMW."""
308
+
309
+ # From left to right:
310
+ # in_channels, out_channels, num_blocks, add_identity, use_spp
311
+ arch_settings = {
312
+ 'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
313
+ [256, 512, 6, True, False], [512, 1024, 3, False, True]],
314
+ 'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
315
+ [256, 512, 6, True, False], [512, 768, 3, True, False],
316
+ [768, 1024, 3, False, True]]
317
+ }
318
+
319
+ def __init__(
320
+ self,
321
+ arch: str = 'P5',
322
+ deepen_factor: float = 1.0,
323
+ widen_factor: float = 1.0,
324
+ out_indices: Sequence[int] = (2, 3, 4),
325
+ frozen_stages: int = -1,
326
+ use_depthwise: bool = False,
327
+ expand_ratio: float = 0.5,
328
+ channel_attention: bool = True,
329
+ act_cfg: Dict = dict(type='SiLU'),
330
+ ) -> None:
331
+ super().__init__()
332
+ arch_setting = self.arch_settings[arch]
333
+
334
+ self.out_indices = out_indices
335
+ self.frozen_stages = frozen_stages
336
+ self.use_depthwise = use_depthwise
337
+
338
+ conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
339
+
340
+ self.stem = nn.Sequential(
341
+ ConvModule(
342
+ 3,
343
+ int(arch_setting[0][0] * widen_factor // 2),
344
+ 3,
345
+ padding=1,
346
+ stride=2,
347
+ act_cfg=act_cfg),
348
+ ConvModule(
349
+ int(arch_setting[0][0] * widen_factor // 2),
350
+ int(arch_setting[0][0] * widen_factor // 2),
351
+ 3,
352
+ padding=1,
353
+ stride=1,
354
+ act_cfg=act_cfg),
355
+ ConvModule(
356
+ int(arch_setting[0][0] * widen_factor // 2),
357
+ int(arch_setting[0][0] * widen_factor),
358
+ 3,
359
+ padding=1,
360
+ stride=1,
361
+ act_cfg=act_cfg))
362
+
363
+ self.layers = ['stem']
364
+
365
+ for i, (in_channels, out_channels, num_blocks, add_identity,
366
+ use_spp) in enumerate(arch_setting):
367
+ in_channels = int(in_channels * widen_factor)
368
+ out_channels = int(out_channels * widen_factor)
369
+ num_blocks = max(round(num_blocks * deepen_factor), 1)
370
+ stage = []
371
+
372
+ conv_layer = conv(
373
+ in_channels,
374
+ out_channels,
375
+ 3,
376
+ stride=2,
377
+ padding=1,
378
+ act_cfg=act_cfg)
379
+
380
+ stage.append(conv_layer)
381
+
382
+ if use_spp:
383
+ spp = SPPBottleneck(
384
+ out_channels,
385
+ out_channels,
386
+ act_cfg=act_cfg)
387
+ stage.append(spp)
388
+
389
+ csp_layer = CSPLayer(
390
+ out_channels,
391
+ out_channels,
392
+ num_blocks=num_blocks,
393
+ add_identity=add_identity,
394
+ use_depthwise=use_depthwise,
395
+ use_cspnext_block=True,
396
+ expand_ratio=expand_ratio,
397
+ channel_attention=channel_attention,
398
+ act_cfg=act_cfg)
399
+
400
+ stage.append(csp_layer)
401
+
402
+ self.add_module(f'stage{i + 1}', nn.Sequential(*stage))
403
+ self.layers.append(f'stage{i + 1}')
404
+
405
+ def freeze_stages(self) -> None:
406
+ """Freeze stages parameters."""
407
+ if self.frozen_stages >= 0:
408
+ for i in range(self.frozen_stages + 1):
409
+ m = getattr(self, self.layers[i])
410
+ m.eval()
411
+ for param in m.parameters():
412
+ param.requires_grad = False
413
+
414
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
415
+ outs = []
416
+ for i, layer_name in enumerate(self.layers):
417
+ layer = getattr(self, layer_name)
418
+ x = layer(x)
419
+ if i in self.out_indices:
420
+ outs.append(x)
421
+ return tuple(outs)
422
+
423
+
424
+ class CSPNeXtPAFPN(nn.Module):
425
+ """Path Aggregation Network with CSPNeXt blocks."""
426
+     def __init__(
+         self,
+         in_channels: Sequence[int],
+         out_channels: int,
+         out_indices: Tuple[int, ...] = (1, 2),
+         num_csp_blocks: int = 3,
+         use_depthwise: bool = False,
+         expand_ratio: float = 0.5,
+         act_cfg: Dict = dict(type='SiLU'),
+     ) -> None:
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.out_indices = out_indices
+ 
+         conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
+ 
+         # Build top-down blocks
+         self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
+         self.reduce_layers = nn.ModuleList()
+         self.top_down_blocks = nn.ModuleList()
+ 
+         for idx in range(len(in_channels) - 1, 0, -1):
+             self.reduce_layers.append(
+                 ConvModule(
+                     in_channels[idx],
+                     in_channels[idx - 1],
+                     1,
+                     act_cfg=act_cfg))
+ 
+             self.top_down_blocks.append(
+                 CSPLayer(
+                     in_channels[idx - 1] * 2,
+                     in_channels[idx - 1],
+                     num_blocks=num_csp_blocks,
+                     add_identity=False,
+                     use_depthwise=use_depthwise,
+                     use_cspnext_block=True,
+                     expand_ratio=expand_ratio,
+                     act_cfg=act_cfg))
+ 
+         # Build bottom-up blocks
+         self.downsamples = nn.ModuleList()
+         self.bottom_up_blocks = nn.ModuleList()
+ 
+         for idx in range(len(in_channels) - 1):
+             self.downsamples.append(
+                 conv(
+                     in_channels[idx],
+                     in_channels[idx],
+                     3,
+                     stride=2,
+                     padding=1,
+                     act_cfg=act_cfg))
+ 
+             self.bottom_up_blocks.append(
+                 CSPLayer(
+                     in_channels[idx] * 2,
+                     in_channels[idx + 1],
+                     num_blocks=num_csp_blocks,
+                     add_identity=False,
+                     use_depthwise=use_depthwise,
+                     use_cspnext_block=True,
+                     expand_ratio=expand_ratio,
+                     act_cfg=act_cfg))
+ 
+         if self.out_channels is not None:
+             self.out_convs = nn.ModuleList()
+             for i in range(len(in_channels)):
+                 self.out_convs.append(
+                     conv(
+                         in_channels[i],
+                         out_channels,
+                         3,
+                         padding=1,
+                         act_cfg=act_cfg))
+ 
+     def forward(self, inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
+         assert len(inputs) == len(self.in_channels)
+ 
+         # Top-down path
+         inner_outs = [inputs[-1]]
+         for idx in range(len(self.in_channels) - 1, 0, -1):
+             feat_high = inner_outs[0]
+             feat_low = inputs[idx - 1]
+             feat_high = self.reduce_layers[len(self.in_channels) - 1 - idx](
+                 feat_high)
+             inner_outs[0] = feat_high
+ 
+             upsample_feat = self.upsample(feat_high)
+ 
+             inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
+                 torch.cat([upsample_feat, feat_low], 1))
+             inner_outs.insert(0, inner_out)
+ 
+         # Bottom-up path
+         outs = [inner_outs[0]]
+         for idx in range(len(self.in_channels) - 1):
+             feat_low = outs[-1]
+             feat_high = inner_outs[idx + 1]
+             downsample_feat = self.downsamples[idx](feat_low)
+             out = self.bottom_up_blocks[idx](
+                 torch.cat([downsample_feat, feat_high], 1))
+             outs.append(out)
+ 
+         if self.out_channels is not None:
+             # Apply output convolutions
+             for idx in range(len(outs)):
+                 outs[idx] = self.out_convs[idx](outs[idx])
+ 
+         return tuple(outs[i] for i in self.out_indices)
+ 
+ 
+ class ScaleNorm(nn.Module):
+     """Scale normalization layer with scaling factor."""
+ 
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.scale = dim ** -0.5
+         self.eps = eps
+         self.g = nn.Parameter(torch.ones(1))
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+         return x / (norm + self.eps) * self.g
+ 
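A standalone sketch of what `ScaleNorm` computes (illustration only, not part of the ported file): each token vector is rescaled to an L2 norm of roughly `sqrt(dim)` before the learned gain `g` (initialized to 1) is applied.

```python
import torch

# Restatement of the ScaleNorm forward pass with g = 1
dim, eps = 16, 1e-5
x = torch.randn(2, 5, dim)
g = torch.ones(1)  # the learned gain parameter, initialized to 1

scale = dim ** -0.5
norm = torch.norm(x, dim=-1, keepdim=True) * scale
y = x / (norm + eps) * g
```

After this transform every token vector has L2 norm close to `sqrt(dim)`, which keeps activation scale independent of the hidden width.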
+ 
+ class Scale(nn.Module):
+     """Scale vector by element multiplications."""
+ 
+     def __init__(self, dim, init_value=1., trainable=True):
+         super().__init__()
+         self.scale = nn.Parameter(
+             init_value * torch.ones(dim), requires_grad=trainable)
+ 
+     def forward(self, x):
+         return x * self.scale
+ 
+ 
+ def drop_path(x: torch.Tensor,
+               drop_prob: float = 0.,
+               training: bool = False) -> torch.Tensor:
+     """Drop paths (Stochastic Depth) per sample."""
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+     random_tensor = keep_prob + torch.rand(
+         shape, dtype=x.dtype, device=x.device)
+     output = x.div(keep_prob) * random_tensor.floor()
+     return output
+ 
+ 
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample."""
+ 
+     def __init__(self, drop_prob: float = 0.1):
+         super().__init__()
+         self.drop_prob = drop_prob
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return drop_path(x, self.drop_prob, self.training)
+ 
+ 
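A self-contained sketch of the stochastic-depth behaviour above: at inference (`training=False`) the function is the identity, while in training each sample is either zeroed or rescaled by `1 / keep_prob`, so the expected value is preserved.

```python
import torch

def drop_path(x, drop_prob=0., training=False):
    # Same logic as the function above, restated for a standalone demo
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    return x.div(keep_prob) * random_tensor.floor()

torch.manual_seed(0)
x = torch.ones(8, 4)

out_eval = drop_path(x, drop_prob=0.5, training=False)   # identity at inference
out_train = drop_path(x, drop_prob=0.5, training=True)   # each row is all 0 or all 2
```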
+ def rope(x, dim):
+     """Applies Rotary Position Embedding to input tensor."""
+     shape = x.shape
+     if isinstance(dim, int):
+         dim = [dim]
+ 
+     spatial_shape = [shape[i] for i in dim]
+     total_len = 1
+     for i in spatial_shape:
+         total_len *= i
+ 
+     position = torch.reshape(
+         torch.arange(total_len, dtype=torch.int, device=x.device),
+         spatial_shape)
+ 
+     for i in range(dim[-1] + 1, len(shape) - 1, 1):
+         position = torch.unsqueeze(position, dim=-1)
+ 
+     half_size = shape[-1] // 2
+     freq_seq = -torch.arange(
+         half_size, dtype=torch.int, device=x.device) / float(half_size)
+     inv_freq = 10000**-freq_seq
+     sinusoid = position[..., None] * inv_freq[None, None, :]
+     sin = torch.sin(sinusoid)
+     cos = torch.cos(sinusoid)
+ 
+     x1, x2 = torch.chunk(x, 2, dim=-1)
+     return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
+ 
+ 
+ def gaussian_blur1d(simcc: torch.Tensor, kernel: int = 11) -> torch.Tensor:
+     """Modulate simcc distribution with Gaussian using PyTorch.
+ 
+     Args:
+         simcc (torch.Tensor[N, K, Wx]): model predicted simcc.
+         kernel (int): Gaussian kernel size for modulation, which should
+             match the simcc gaussian sigma used in training
+             (kernel=17 for sigma=3 and kernel=11 for sigma=2).
+ 
+     Returns:
+         torch.Tensor ([N, K, Wx]): Modulated simcc distribution.
+     """
+     assert kernel % 2 == 1
+ 
+     border = (kernel - 1) // 2
+     N, K, Wx = simcc.shape
+ 
+     # Create Gaussian kernel
+     sigma = kernel / 6.0  # Approximate conversion from kernel size to sigma
+     x = torch.arange(-border, border + 1, dtype=torch.float, device=simcc.device)
+     kernel_1d = torch.exp(-0.5 * (x / sigma).pow(2))
+     kernel_1d = (kernel_1d / kernel_1d.sum()).view(1, 1, kernel)
+ 
+     # Blur every (n, k) row at once with zero padding, then restore each
+     # row's original maximum (mirroring the cv2.GaussianBlur reference in
+     # MMPose, which also rescales by origin_max / blurred_max)
+     origin_max = simcc.amax(dim=2, keepdim=True)
+     blurred = F.conv1d(
+         simcc.reshape(N * K, 1, Wx), kernel_1d, padding=border)
+     blurred = blurred.reshape(N, K, Wx)
+     blurred_max = blurred.amax(dim=2, keepdim=True).clamp(min=1e-12)
+     return blurred * origin_max / blurred_max
+ 
+ 
+ def get_simcc_maximum(simcc_x: torch.Tensor,
+                       simcc_y: torch.Tensor,
+                       apply_softmax: bool = False
+                       ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Get maximum response location and value from simcc representations.
+ 
+     Args:
+         simcc_x (torch.Tensor): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
+         simcc_y (torch.Tensor): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
+         apply_softmax (bool): whether to apply softmax on the heatmap.
+             Defaults to False.
+ 
+     Returns:
+         tuple:
+         - locs (torch.Tensor): locations of maximum heatmap responses in
+             shape (K, 2) or (N, K, 2)
+         - vals (torch.Tensor): values of maximum heatmap responses in
+             shape (K,) or (N, K)
+     """
+     assert simcc_x.dim() == 2 or simcc_x.dim() == 3, f'Invalid shape {simcc_x.shape}'
+     assert simcc_y.dim() == 2 or simcc_y.dim() == 3, f'Invalid shape {simcc_y.shape}'
+     assert simcc_x.dim() == simcc_y.dim(), f'{simcc_x.shape} != {simcc_y.shape}'
+ 
+     if simcc_x.dim() == 3:
+         N, K, Wx = simcc_x.shape
+         simcc_x_reshape = simcc_x.reshape(N * K, -1)
+         simcc_y_reshape = simcc_y.reshape(N * K, -1)
+     else:
+         N = None
+         simcc_x_reshape = simcc_x
+         simcc_y_reshape = simcc_y
+ 
+     if apply_softmax:
+         simcc_x_reshape = simcc_x_reshape - torch.max(simcc_x_reshape, dim=1, keepdim=True)[0]
+         simcc_y_reshape = simcc_y_reshape - torch.max(simcc_y_reshape, dim=1, keepdim=True)[0]
+         ex, ey = torch.exp(simcc_x_reshape), torch.exp(simcc_y_reshape)
+         simcc_x_reshape = ex / torch.sum(ex, dim=1, keepdim=True)
+         simcc_y_reshape = ey / torch.sum(ey, dim=1, keepdim=True)
+ 
+     # Get argmax locations
+     x_locs = torch.argmax(simcc_x_reshape, dim=1)
+     y_locs = torch.argmax(simcc_y_reshape, dim=1)
+ 
+     # Create combined location tensor
+     locs = torch.stack((x_locs, y_locs), dim=-1).float()
+ 
+     # Get maximum values for each axis
+     max_val_x = torch.amax(simcc_x_reshape, dim=1)
+     max_val_y = torch.amax(simcc_y_reshape, dim=1)
+ 
+     # Take the MINIMUM value between x and y responses
+     # (this is the correct behavior from MMPose)
+     vals = torch.minimum(max_val_x, max_val_y)
+ 
+     # Set invalid locations (where confidence is zero) to -1
+     locs[vals <= 0.] = -1
+ 
+     if N is not None:
+         locs = locs.reshape(N, K, 2)
+         vals = vals.reshape(N, K)
+ 
+     return locs, vals
+ 
+ 
+ def refine_simcc_dark(keypoints: torch.Tensor, simcc: torch.Tensor,
+                       blur_kernel_size: int) -> torch.Tensor:
+     """PyTorch version of SimCC refinement using distribution-aware
+     decoding (DARK).
+ 
+     Args:
+         keypoints (torch.Tensor): The keypoint coordinates along one axis,
+             with N * K elements in total (e.g. shape (N, K) or (K, 1) for
+             a single sample)
+         simcc (torch.Tensor): The 1-D heatmaps in shape (N, K, Wx)
+         blur_kernel_size (int): The Gaussian blur kernel size of the
+             heatmap modulation
+ 
+     Returns:
+         torch.Tensor: Refined keypoint coordinates with the same shape as
+         the input
+     """
+     N, K, Wx = simcc.shape
+ 
+     # Modulate simcc
+     simcc = gaussian_blur1d(simcc, blur_kernel_size)
+     simcc = torch.clamp(simcc, min=1e-3, max=50.)
+     simcc = torch.log(simcc)
+ 
+     # Pad the simcc tensor so that px +/- 2 stays in bounds
+     simcc = F.pad(simcc, (2, 2), mode='replicate')
+ 
+     # Convert keypoints to padded indices (+2 for the padding, +0.5 to round)
+     px = (keypoints.reshape(N, K) + 2.5).long()
+     px = torch.clamp(px, min=2, max=Wx + 1)
+ 
+     idx_n = torch.arange(N, device=simcc.device).view(N, 1).expand(N, K)
+     idx_k = torch.arange(K, device=simcc.device).view(1, K).expand(N, K)
+ 
+     # Sample the log-distribution around each keypoint
+     dx0 = simcc[idx_n, idx_k, px]
+     dx1 = simcc[idx_n, idx_k, px + 1]
+     dx_1 = simcc[idx_n, idx_k, px - 1]
+     dx2 = simcc[idx_n, idx_k, px + 2]
+     dx_2 = simcc[idx_n, idx_k, px - 2]
+ 
+     # First and second derivatives (Taylor expansion around the peak)
+     dx = 0.5 * (dx1 - dx_1)
+     dxx = 1e-9 + 0.25 * (dx2 - 2 * dx0 + dx_2)
+ 
+     # Apply the sub-bin offset to refine the keypoints
+     offset = (dx / dxx).reshape(keypoints.shape)
+     return keypoints - offset
+ 
+ 
+ class SimCCCodec:
+     """Generate keypoint representations via the SimCC approach.
+ 
+     This class implements SimCC (Simple Coordinate Classification) decoding
+     for pose estimation using PyTorch tensors only, with no NumPy dependency.
+ 
+     Args:
+         input_size (tuple): Input image size in [w, h]
+         smoothing_type (str): The SimCC label smoothing strategy. Options are
+             'gaussian' and 'standard'. Defaults to 'gaussian'
+         sigma (float | int | tuple): The sigma value in the Gaussian SimCC
+             label. Defaults to 6.0
+         simcc_split_ratio (float): The ratio of the label size to the input
+             size. For example, if the input width is w, the x label size
+             will be w*simcc_split_ratio. Defaults to 2.0
+         normalize (bool): Whether to normalize the heatmaps. Defaults to
+             False.
+         use_dark (bool): Whether to use the DARK post-processing. Defaults
+             to False.
+     """
+ 
+     def __init__(
+         self,
+         input_size,
+         smoothing_type='gaussian',
+         sigma=6.0,
+         simcc_split_ratio=2.0,
+         normalize=False,
+         use_dark=False
+     ):
+         self.input_size = input_size
+         self.smoothing_type = smoothing_type
+         self.simcc_split_ratio = simcc_split_ratio
+         self.normalize = normalize
+         self.use_dark = use_dark
+ 
+         if isinstance(sigma, (float, int)):
+             sigma = [sigma, sigma]
+         self.sigma = torch.tensor(sigma)
+ 
+     def encode(self, keypoints, keypoints_visible=None):
+         """Encode keypoints into SimCC labels.
+ 
+         Note that the original keypoint coordinates should be in the input
+         image space. Encoding is only needed for training, so it is not
+         implemented in this inference-only port.
+         """
+         raise NotImplementedError(
+             'SimCCCodec.encode() is not implemented; this port only '
+             'supports inference.')
+ 
+     def decode(self, simcc_x: torch.Tensor,
+                simcc_y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Decode keypoint coordinates from SimCC representations. The
+         decoded coordinates are in the input image space.
+ 
+         Args:
+             simcc_x (torch.Tensor): SimCC label for x-axis
+             simcc_y (torch.Tensor): SimCC label for y-axis
+ 
+         Returns:
+             tuple:
+             - keypoints (torch.Tensor): Decoded coordinates in shape (N, K, D)
+             - scores (torch.Tensor): The keypoint scores in shape (N, K),
+                 usually representing the confidence of the keypoint prediction
+         """
+         device = simcc_x.device
+ 
+         # Ensure a batch dimension for processing
+         if simcc_x.dim() == 2:
+             simcc_x = simcc_x.unsqueeze(0)
+         if simcc_y.dim() == 2:
+             simcc_y = simcc_y.unsqueeze(0)
+ 
+         keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
+ 
+         # Apply DARK post-processing if requested
+         if self.use_dark:
+             # Calculate blur kernel sizes based on sigma values
+             sigma_tensor = self.sigma.to(device)
+             x_blur = int((sigma_tensor[0] * 20 - 7) // 3)
+             y_blur = int((sigma_tensor[1] * 20 - 7) // 3)
+ 
+             # Ensure odd kernel sizes
+             x_blur -= int((x_blur % 2) == 0)
+             y_blur -= int((y_blur % 2) == 0)
+ 
+             # Apply DARK refinement separately to x and y coordinates
+             for i in range(keypoints.shape[0]):
+                 keypoints_x = keypoints[i, :, 0:1]
+                 keypoints_y = keypoints[i, :, 1:2]
+ 
+                 keypoints[i, :, 0] = refine_simcc_dark(
+                     keypoints_x, simcc_x[i:i + 1], x_blur)[:, 0]
+                 keypoints[i, :, 1] = refine_simcc_dark(
+                     keypoints_y, simcc_y[i:i + 1], y_blur)[:, 0]
+ 
+         # Convert from SimCC coordinate space back to image coordinate space
+         keypoints /= self.simcc_split_ratio
+ 
+         return keypoints, scores
+ 
+ 
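A tiny standalone illustration of the coordinate mapping at the end of `decode` (not part of the file): SimCC bins live in a space `simcc_split_ratio` times finer than the input image, so decoded bin indices are divided by that ratio to recover input-image pixels.

```python
import torch

simcc_split_ratio = 2.0  # the RTMW default used above

# Suppose argmax decoding returned these bin indices for two keypoints
bin_locs = torch.tensor([[[384.0, 512.0],
                          [100.0, 50.0]]])  # shape (N=1, K=2, 2)

keypoints = bin_locs / simcc_split_ratio     # back to input-image pixels
```

For a 192x256 input with split ratio 2.0, the x-axis classifier therefore has 384 bins and the y-axis 512, giving half-pixel localization granularity before any DARK refinement.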
+ class RTMCCBlock(nn.Module):
+     """Gated Attention Unit (GAU) in RTMBlock."""
+ 
+     def __init__(
+         self,
+         num_token,
+         in_token_dims,
+         out_token_dims,
+         expansion_factor=2,
+         s=128,
+         eps=1e-5,
+         dropout_rate=0.,
+         drop_path=0.,
+         attn_type='self-attn',
+         act_fn='SiLU',
+         bias=False,
+         use_rel_bias=True,
+         pos_enc=False
+     ):
+         super().__init__()
+         self.s = s
+         self.num_token = num_token
+         self.use_rel_bias = use_rel_bias
+         self.attn_type = attn_type
+         self.pos_enc = pos_enc
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.e = int(in_token_dims * expansion_factor)
+ 
+         if use_rel_bias:
+             if attn_type == 'self-attn':
+                 self.w = nn.Parameter(
+                     torch.rand([2 * num_token - 1], dtype=torch.float))
+             else:
+                 self.a = nn.Parameter(torch.rand([1, s], dtype=torch.float))
+                 self.b = nn.Parameter(torch.rand([1, s], dtype=torch.float))
+ 
+         self.o = nn.Linear(self.e, out_token_dims, bias=bias)
+ 
+         if attn_type == 'self-attn':
+             self.uv = nn.Linear(in_token_dims, 2 * self.e + self.s, bias=bias)
+             self.gamma = nn.Parameter(torch.rand((2, self.s)))
+             self.beta = nn.Parameter(torch.rand((2, self.s)))
+         else:
+             self.uv = nn.Linear(in_token_dims, self.e + self.s, bias=bias)
+             self.k_fc = nn.Linear(in_token_dims, self.s, bias=bias)
+             self.v_fc = nn.Linear(in_token_dims, self.e, bias=bias)
+             nn.init.xavier_uniform_(self.k_fc.weight)
+             nn.init.xavier_uniform_(self.v_fc.weight)
+ 
+         self.ln = ScaleNorm(in_token_dims, eps=eps)
+         nn.init.xavier_uniform_(self.uv.weight)
+ 
+         if act_fn == 'SiLU' or act_fn == nn.SiLU:
+             self.act_fn = nn.SiLU(True)
+         elif act_fn == 'ReLU' or act_fn == nn.ReLU:
+             self.act_fn = nn.ReLU(True)
+         else:
+             raise NotImplementedError
+ 
+         if in_token_dims == out_token_dims:
+             self.shortcut = True
+             self.res_scale = Scale(in_token_dims)
+         else:
+             self.shortcut = False
+ 
+         self.sqrt_s = torch.sqrt(torch.tensor(s, dtype=torch.float))
+         self.dropout_rate = dropout_rate
+         if dropout_rate > 0.:
+             self.dropout = nn.Dropout(dropout_rate)
+ 
+     def rel_pos_bias(self, seq_len, k_len=None):
+         """Add relative position bias."""
+         if self.attn_type == 'self-attn':
+             t = F.pad(self.w[:2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
+             t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
+             r = (2 * seq_len - 1) // 2
+             t = t[..., r:-r]
+         else:
+             a = rope(self.a.repeat(seq_len, 1), dim=0)
+             b = rope(self.b.repeat(k_len, 1), dim=0)
+             t = torch.bmm(a, b.permute(0, 2, 1))
+         return t
+ 
+     def _forward(self, inputs):
+         """GAU forward function."""
+         if self.attn_type == 'self-attn':
+             x = inputs
+         else:
+             x, k, v = inputs
+ 
+         x = self.ln(x)
+         uv = self.uv(x)
+         uv = self.act_fn(uv)
+ 
+         if self.attn_type == 'self-attn':
+             # Split into u, v, base
+             u, v, base = torch.split(uv, [self.e, self.e, self.s], dim=2)
+             # Apply gamma and beta parameters
+             base = base.unsqueeze(2) * self.gamma[None, None, :] + self.beta[None, None, :]
+             if self.pos_enc:
+                 base = rope(base, dim=1)
+             # Split base into q, k
+             q, k = torch.unbind(base, dim=2)
+         else:
+             # Split into u, q
+             u, q = torch.split(uv, [self.e, self.s], dim=2)
+             k = self.k_fc(k)  # -> [B, K, s]
+             v = self.v_fc(v)  # -> [B, K, e]
+             if self.pos_enc:
+                 q = rope(q, 1)
+                 k = rope(k, 1)
+ 
+         # Calculate attention
+         qk = torch.bmm(q, k.permute(0, 2, 1))
+ 
+         if self.use_rel_bias:
+             if self.attn_type == 'self-attn':
+                 bias = self.rel_pos_bias(q.size(1))
+             else:
+                 bias = self.rel_pos_bias(q.size(1), k.size(1))
+             qk += bias[:, :q.size(1), :k.size(1)]
+ 
+         # Apply kernel (squared ReLU)
+         kernel = torch.square(F.relu(qk / self.sqrt_s))
+ 
+         if self.dropout_rate > 0.:
+             kernel = self.dropout(kernel)
+ 
+         # Apply gated attention (identical for self- and cross-attention)
+         x = u * torch.bmm(kernel, v)
+ 
+         x = self.o(x)
+         return x
+ 
+     def forward(self, x):
+         """Forward function."""
+         if self.shortcut:
+             if self.attn_type == 'cross-attn':
+                 res_shortcut = x[0]
+             else:
+                 res_shortcut = x
+             main_branch = self.drop_path(self._forward(x))
+             return self.res_scale(res_shortcut) + main_branch
+         else:
+             return self.drop_path(self._forward(x))
+ 
+ 
+ class RTMWHead(nn.Module):
+     """Top-down head introduced in RTMPose-Wholebody (2023).
+ 
+     Updated to use PyTorch-only implementations without NumPy or OpenCV.
+     """
+ 
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         input_size: Tuple[int, int],
+         in_featuremap_size: Tuple[int, int],
+         simcc_split_ratio: float = 2.0,
+         final_layer_kernel_size: int = 7,
+         gau_cfg: Optional[Dict] = None,
+         decoder: Optional[Dict] = None,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.input_size = input_size
+         self.in_featuremap_size = in_featuremap_size
+         self.simcc_split_ratio = simcc_split_ratio
+ 
+         # Default GAU config if not provided
+         if gau_cfg is None:
+             gau_cfg = dict(
+                 hidden_dims=256,
+                 s=128,
+                 expansion_factor=2,
+                 dropout_rate=0.,
+                 drop_path=0.,
+                 act_fn='ReLU',
+                 use_rel_bias=False,
+                 pos_enc=False)
+ 
+         # Define SimCC layers
+         flatten_dims = self.in_featuremap_size[0] * self.in_featuremap_size[1]
+ 
+         ps = 2  # pixel shuffle factor
+         self.ps = nn.PixelShuffle(ps)
+ 
+         self.conv_dec = ConvModule(
+             in_channels // ps**2,
+             in_channels // 4,
+             kernel_size=final_layer_kernel_size,
+             stride=1,
+             padding=final_layer_kernel_size // 2,
+             norm_cfg=dict(type='BN'),
+             act_cfg=dict(type='ReLU'))
+ 
+         self.final_layer = ConvModule(
+             in_channels,
+             out_channels,
+             kernel_size=final_layer_kernel_size,
+             stride=1,
+             padding=final_layer_kernel_size // 2,
+             norm_cfg=dict(type='BN'),
+             act_cfg=dict(type='ReLU'))
+ 
+         self.final_layer2 = ConvModule(
+             in_channels // ps + in_channels // 4,
+             out_channels,
+             kernel_size=final_layer_kernel_size,
+             stride=1,
+             padding=final_layer_kernel_size // 2,
+             norm_cfg=dict(type='BN'),
+             act_cfg=dict(type='ReLU'))
+ 
+         self.mlp = nn.Sequential(
+             ScaleNorm(flatten_dims),
+             nn.Linear(flatten_dims, gau_cfg['hidden_dims'] // 2, bias=False))
+ 
+         self.mlp2 = nn.Sequential(
+             ScaleNorm(flatten_dims * ps**2),
+             nn.Linear(
+                 flatten_dims * ps**2, gau_cfg['hidden_dims'] // 2, bias=False))
+ 
+         W = int(self.input_size[0] * self.simcc_split_ratio)
+         H = int(self.input_size[1] * self.simcc_split_ratio)
+ 
+         self.gau = RTMCCBlock(
+             self.out_channels,
+             gau_cfg['hidden_dims'],
+             gau_cfg['hidden_dims'],
+             s=gau_cfg['s'],
+             expansion_factor=gau_cfg['expansion_factor'],
+             dropout_rate=gau_cfg['dropout_rate'],
+             drop_path=gau_cfg['drop_path'],
+             attn_type='self-attn',
+             act_fn=gau_cfg['act_fn'],
+             use_rel_bias=gau_cfg['use_rel_bias'],
+             pos_enc=gau_cfg['pos_enc'])
+ 
+         self.cls_x = nn.Linear(gau_cfg['hidden_dims'], W, bias=False)
+         self.cls_y = nn.Linear(gau_cfg['hidden_dims'], H, bias=False)
+ 
+         # Create the PyTorch SimCC codec for decoding
+         if decoder is not None:
+             self.decoder = SimCCCodec(
+                 input_size=decoder.get('input_size', self.input_size),
+                 smoothing_type=decoder.get('smoothing_type', 'gaussian'),
+                 sigma=decoder.get('sigma', (4.9, 5.66)),
+                 simcc_split_ratio=self.simcc_split_ratio,
+                 normalize=decoder.get('normalize', False),
+                 use_dark=decoder.get('use_dark', False)
+             )
+         else:
+             self.decoder = SimCCCodec(
+                 input_size=self.input_size,
+                 sigma=(4.9, 5.66),
+                 simcc_split_ratio=self.simcc_split_ratio,
+                 normalize=False,
+                 use_dark=False
+             )
+ 
+     def forward(self, feats: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Forward the network to get SimCC representations.
+ 
+         Args:
+             feats (Tuple[Tensor]): Multi-scale feature maps.
+ 
+         Returns:
+             pred_x (Tensor): 1-D representation of x.
+             pred_y (Tensor): 1-D representation of y.
+         """
+         enc_b, enc_t = feats
+ 
+         feats_t = self.final_layer(enc_t)
+         feats_t = torch.flatten(feats_t, 2)
+         feats_t = self.mlp(feats_t)
+ 
+         dec_t = self.ps(enc_t)
+         dec_t = self.conv_dec(dec_t)
+         enc_b = torch.cat([dec_t, enc_b], dim=1)
+ 
+         feats_b = self.final_layer2(enc_b)
+         feats_b = torch.flatten(feats_b, 2)
+         feats_b = self.mlp2(feats_b)
+ 
+         feats = torch.cat([feats_t, feats_b], dim=2)
+ 
+         feats = self.gau(feats)
+ 
+         pred_x = self.cls_x(feats)
+         pred_y = self.cls_y(feats)
+ 
+         return pred_x, pred_y
+ 
+     def predict(self, feats: Tuple[torch.Tensor, torch.Tensor], flip_test=False, flip_indices=None):
+         """Predict keypoints from features.
+ 
+         Args:
+             feats (Tuple[torch.Tensor]): Features from the backbone + neck
+             flip_test (bool): Whether to use flip test augmentation
+             flip_indices (List[int]): Indices for flipping keypoints
+ 
+         Returns:
+             List[Dict]: Predicted keypoints and scores
+         """
+         if flip_test:
+             assert flip_indices is not None, 'flip_indices must be provided for flip test'
+ 
+             # Original forward pass
+             _batch_pred_x, _batch_pred_y = self.forward(feats)
+ 
+             # Create horizontally flipped input and get predictions
+             feats_flipped = [torch.flip(feat, dims=[-1]) for feat in feats]
+             _batch_pred_x_flip, _batch_pred_y_flip = self.forward(feats_flipped)
+ 
+             # Flip the x predictions back along the width dimension
+             _batch_pred_x_flip = torch.flip(_batch_pred_x_flip, dims=[2])
+ 
+             # Swap paired keypoint channels (e.g. left/right joints) in a
+             # single gather; a sequential in-place loop would overwrite
+             # channels before they are read
+             _batch_pred_x_flip = _batch_pred_x_flip[:, flip_indices]
+             _batch_pred_y_flip = _batch_pred_y_flip[:, flip_indices]
+ 
+             # Average the predictions
+             batch_pred_x = (_batch_pred_x + _batch_pred_x_flip) * 0.5
+             batch_pred_y = (_batch_pred_y + _batch_pred_y_flip) * 0.5
+         else:
+             # Standard forward pass
+             batch_pred_x, batch_pred_y = self.forward(feats)
+ 
+         # Decode keypoints using the PyTorch-based decoder
+         keypoints, scores = self.decoder.decode(batch_pred_x, batch_pred_y)
+ 
+         # Convert to a list of per-sample instances
+         instances = []
+         for i in range(keypoints.shape[0]):
+             instances.append({
+                 'keypoints': keypoints[i],
+                 'keypoint_scores': scores[i]
+             })
+ 
+         return instances
+ 
+ 
+ class RTMWModel(PreTrainedModel):
+     """RTMW model for human pose estimation.
+ 
+     This model consists of a backbone, neck, and pose head for keypoint
+     detection. All implementations use PyTorch only, with no NumPy or
+     OpenCV dependencies.
+     """
+ 
+     def __init__(self, config: RTMWConfig):
+         super().__init__(config)
+         self.config = config
+ 
+         # Build backbone
+         self.backbone = CSPNeXt(
+             arch=config.backbone_arch,
+             deepen_factor=config.backbone_deepen_factor,
+             widen_factor=config.backbone_widen_factor,
+             expand_ratio=config.backbone_expand_ratio,
+             channel_attention=config.backbone_channel_attention,
+             use_depthwise=False,
+         )
+ 
+         # Build neck
+         self.neck = CSPNeXtPAFPN(
+             in_channels=config.neck_in_channels,
+             out_channels=config.neck_out_channels,
+             num_csp_blocks=config.neck_num_csp_blocks,
+             expand_ratio=config.neck_expand_ratio,
+             use_depthwise=False,
+         )
+ 
+         # Build head: create the GAU config from the model configuration
+         gau_cfg = {
+             'hidden_dims': config.gau_hidden_dims,
+             's': config.gau_s,
+             'expansion_factor': config.gau_expansion_factor,
+             'dropout_rate': config.gau_dropout_rate,
+             'drop_path': config.gau_drop_path,
+             'act_fn': config.gau_act_fn,
+             'use_rel_bias': config.gau_use_rel_bias,
+             'pos_enc': config.gau_pos_enc,
+         }
+ 
+         self.head = RTMWHead(
+             in_channels=config.head_in_channels,
+             out_channels=config.num_keypoints,
+             input_size=config.input_size,
+             in_featuremap_size=config.head_in_featuremap_size,
+             simcc_split_ratio=config.simcc_split_ratio,
+             final_layer_kernel_size=config.head_final_layer_kernel_size,
+             gau_cfg=gau_cfg,
+             decoder=dict(
+                 input_size=config.input_size,
+                 sigma=config.decoder_sigma,
+                 simcc_split_ratio=config.simcc_split_ratio,
+                 normalize=config.decoder_normalize,
+                 use_dark=config.decoder_use_dark)
+         )
+ 
+         # Initialize weights
+         self.init_weights()
+ 
+     def init_weights(self):
+         """Initialize the weights of the model."""
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.normal_(m.weight, mean=0, std=0.01)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.Linear):
+                 nn.init.normal_(m.weight, mean=0, std=0.01)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+ 
+     def forward(
+         self,
+         pixel_values=None,
+         labels=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         """Forward pass of the model.
+ 
+         Args:
+             pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                 Pixel values, which can be obtained using RTMWImageProcessor.
+             labels (`List[Dict]`, *optional*):
+                 Labels for computing the pose estimation loss.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a ModelOutput instead of a plain tuple.
+ 
+         Returns:
+             `PoseOutput` or `tuple`:
+                 If return_dict=True, a `PoseOutput` is returned.
+                 If return_dict=False, a tuple of (keypoints, scores) is returned.
+         """
+         return_dict = return_dict if return_dict is not None else True
+ 
+         if pixel_values is None:
+             raise ValueError('You have to specify pixel_values')
+ 
+         # Extract features from the backbone
+         backbone_features = self.backbone(pixel_values)
+ 
+         # Process features through the neck
+         neck_features = self.neck(backbone_features)
+ 
+         # Get SimCC representations from the pose head
+         pred_x, pred_y = self.head.forward(neck_features)
+ 
+         # Decode keypoints directly from the SimCC representations
+         # (avoids running a second forward pass through the head)
+         keypoints, scores = self.head.decoder.decode(pred_x, pred_y)
+ 
+         # Apply fixed min-max normalization to map scores to [0, 1].
+         # Only valid scores (> 0) are normalized; invalid keypoints keep
+         # their raw (≤ 0) values so downstream code can still filter them.
+         score_min = getattr(self.config, 'score_min', None)
+         score_max = getattr(self.config, 'score_max', None)
+         if score_min is not None and score_max is not None and score_max > score_min:
+             valid_mask = scores > 0
+             scores[valid_mask] = torch.clamp(
+                 (scores[valid_mask] - score_min) / (score_max - score_min),
+                 0.0, 1.0,
+             )
+ 
+         if return_dict:
+             return PoseOutput(
+                 keypoints=keypoints,
+                 scores=scores,
+                 pred_x=pred_x,
+                 pred_y=pred_y
+             )
+         else:
+             return (keypoints, scores)
preprocessor_config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "_valid_processor_keys": [
+     "images",
+     "do_resize",
+     "size",
+     "keep_aspect_ratio",
+     "ensure_multiple_of",
+     "resample",
+     "do_rescale",
+     "rescale_factor",
+     "do_normalize",
+     "image_mean",
+     "image_std",
+     "do_pad",
+     "size_divisor",
+     "return_tensors",
+     "data_format",
+     "input_data_format"
+   ],
+   "do_normalize": true,
+   "do_rescale": false,
+   "do_resize": true,
+   "image_mean": [
+     123.675,
+     116.28,
+     103.53
+   ],
+   "image_processor_type": "DPTImageProcessor",
+   "image_std": [
+     58.395,
+     57.12,
+     57.375
+   ],
+   "size": {
+     "height": 256,
+     "width": 192
+   }
+ }
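As a standalone illustration (not part of the repo), the config above implies the following manual preprocessing: resize to height 256 x width 192, keep the 0-255 pixel range (`do_rescale` is false), and normalize with the ImageNet mean/std expressed in 0-255 units.

```python
import torch

# ImageNet mean/std in 0-255 units, as listed in preprocessor_config.json
mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)

# Stand-in for an image already resized to (H=256, W=192), values in 0-255
img = torch.full((3, 256, 192), 128.0)

# Normalize and add a batch dimension -> (1, 3, 256, 192)
pixel_values = ((img - mean) / std).unsqueeze(0)
```

Note that because `do_rescale` is false, the mean/std here are on the 0-255 scale rather than the usual 0-1 scale; dividing the image by 255 first would be a mistake with this config.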