HF Space deploy committed on
Commit 9d665dd · 0 Parent(s)

Deploy snapshot (LFS for demo images per .gitattributes)

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +5 -0
  2. README.md +251 -0
  3. app.py +660 -0
  4. chumpy/__init__.py +16 -0
  5. chumpy/ch.py +52 -0
  6. configs/sa_finetune_hrnet_w32.yaml +220 -0
  7. demo_data/000000015956_horse.png +3 -0
  8. demo_data/000000315905_zebra.jpg +3 -0
  9. demo_data/beagle.jpg +3 -0
  10. demo_data/n02101388_1188.png +3 -0
  11. demo_data/n02412080_12159.png +3 -0
  12. demo_data/shepherd_hati.jpg +3 -0
  13. demo_tta.py +399 -0
  14. images/teaser.png +3 -0
  15. packages.txt +7 -0
  16. prima/__init__.py +25 -0
  17. prima/configs/__init__.py +99 -0
  18. prima/datasets/__init__.py +79 -0
  19. prima/datasets/datasets.py +278 -0
  20. prima/datasets/dlc2coco.py +362 -0
  21. prima/datasets/split_acinoset.py +153 -0
  22. prima/datasets/utils.py +1106 -0
  23. prima/datasets/vitdet_dataset.py +100 -0
  24. prima/models/__init__.py +54 -0
  25. prima/models/backbones/__init__.py +19 -0
  26. prima/models/backbones/vit.py +375 -0
  27. prima/models/bioclip_embedding.py +70 -0
  28. prima/models/components/__init__.py +0 -0
  29. prima/models/components/model_utils.py +160 -0
  30. prima/models/components/pose_transformer.py +366 -0
  31. prima/models/components/position_encoding.py +84 -0
  32. prima/models/components/t_cond_mlp.py +204 -0
  33. prima/models/components/transformer.py +400 -0
  34. prima/models/discriminator.py +129 -0
  35. prima/models/heads/__init__.py +1 -0
  36. prima/models/heads/classifier_head.py +30 -0
  37. prima/models/heads/smal_head.py +647 -0
  38. prima/models/losses.py +580 -0
  39. prima/models/prima.py +615 -0
  40. prima/models/smal_wrapper.py +134 -0
  41. prima/utils/__init__.py +45 -0
  42. prima/utils/detection.py +118 -0
  43. prima/utils/evaluate_metric.py +206 -0
  44. prima/utils/geometry.py +115 -0
  45. prima/utils/mesh_renderer.py +330 -0
  46. prima/utils/misc.py +211 -0
  47. prima/utils/pylogger.py +26 -0
  48. prima/utils/renderer.py +433 -0
  49. prima/utils/rich_utils.py +114 -0
  50. prima/utils/weights.py +337 -0
.gitattributes ADDED
@@ -0,0 +1,5 @@
+ # Hugging Face Hub stores these via Git LFS / Xet (plain PNG/JPG in git are rejected on push).
+ demo_data/*.png filter=lfs diff=lfs merge=lfs -text
+ demo_data/*.jpg filter=lfs diff=lfs merge=lfs -text
+ demo_data/*.jpeg filter=lfs diff=lfs merge=lfs -text
+ images/*.png filter=lfs diff=lfs merge=lfs -text
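To illustrate which paths these patterns cover, here is a minimal Python sketch. It is not Git's own matcher: it assumes every pattern has the simple `dir/*.ext` form used above, where `*` does not cross a `/`, and `is_lfs_tracked` is a hypothetical helper name introduced only for this example.

```python
from fnmatch import fnmatchcase
from pathlib import PurePosixPath

# Patterns copied from the .gitattributes hunk above.
LFS_PATTERNS = [
    "demo_data/*.png",
    "demo_data/*.jpg",
    "demo_data/*.jpeg",
    "images/*.png",
]


def is_lfs_tracked(rel_path: str) -> bool:
    """Return True if a repo-relative path matches one of the LFS patterns.

    Simplification (assumption): each pattern is "<directory>/<glob>", so a
    path matches only when it sits directly in that directory.
    """
    path = PurePosixPath(rel_path)
    for pattern in LFS_PATTERNS:
        pat = PurePosixPath(pattern)
        if path.parent == pat.parent and fnmatchcase(path.name, pat.name):
            return True
    return False


if __name__ == "__main__":
    for candidate in ("demo_data/beagle.jpg", "images/teaser.png", "README.md"):
        print(candidate, is_lfs_tracked(candidate))
```

Under these assumptions, files in nested folders such as `demo_data/extra/cat.png` would not be routed through LFS, so a new image directory would need its own pattern line.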
README.md ADDED
@@ -0,0 +1,251 @@
+ ---
+ title: PRIMA Demo
+ emoji: 🦮
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ python_version: "3.10"
+ app_file: app.py
+ startup_duration_timeout: 60m
+ ---
+
+ # PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+
+ This is the official implementation of the approach described in the preprint:
+
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation \
+ Xiaohang Yu, Ti Wang, Mackenzie Weygandt Mathis
+
+ ![PRIMA teaser](images/teaser.png)
+
+
+ ---
+
+
+ ## 🚀 TL;DR
+ PRIMA creates a 3D quadruped mesh from a single 2D image. It leverages BioCLIP-based biological priors for robust cross-species shape understanding, then applies test-time adaptation with 2D reprojection and auxiliary keypoint guidance to refine SMAL pose and shape predictions.
+
+ It can further be used to build Quadruped3D, a large-scale pseudo-3D dataset with diverse species and poses.
+
+ PRIMA achieves state-of-the-art results on the Animal3D, CtrlAni3D, Quadruped2D, and Animal Kingdom datasets.
+
+ ## Installation
+
+ ### Install from PyPI
+
+ > Recommended: Python 3.10 and a CUDA-enabled PyTorch installation.
+
+ ```bash
+ conda create -n prima python=3.10 -y
+ conda activate prima
+
+ # Install PyTorch matching your CUDA (example: CUDA 11.8)
+ pip install --index-url https://download.pytorch.org/whl/cu118 \
+     "torch==2.2.1" "torchvision==0.17.1" "torchaudio==2.2.1"
+
+ # Install chumpy and PyTorch3D
+ python -m pip install --no-build-isolation \
+     "git+https://github.com/mattloper/chumpy.git"
+ python -m pip install --no-build-isolation \
+     "git+https://github.com/facebookresearch/pytorch3d.git"
+
+ # Install PRIMA from PyPI
+ pip install prima-animal
+ ```
+
+ `prima-animal` includes demo runtime dependencies used by `demo.py`, `demo_tta.py`, and `app.py` (including Detectron2 and DeepLabCut).
+
+ ### Clean install from this repository
+
+ Use these steps when developing from a **git clone** (not the PyPI wheel). The shell scripts are **non-interactive** (pip uses `--no-input`, and git runs with `GIT_TERMINAL_PROMPT=0`). Put Hugging Face credentials in your environment or git credential helper before pushing the Space.
+
+ **Local (fresh venv, LFS assets, Hub demo weights, smoke test)**: requires **Python 3.10+**
+ (Gradio 5.1+ / Space-provided Gradio 6.x and `app.py` type hints). On macOS without `python3.10` on your `PATH`, run
+ `brew install python@3.10` and set `PRIMA_PYTHON=/opt/homebrew/bin/python3.10`.
+
+ ```bash
+ chmod +x scripts/clean_install_local.sh scripts/clean_redeploy_hf_space.sh scripts/deploy_hf_space.sh
+ PRIMA_PYTHON=/opt/homebrew/bin/python3.10 ./scripts/clean_install_local.sh
+ ```
+
+ Options:
+
+ - `PRIMA_VENV=.venv ./scripts/clean_install_local.sh --skip-data`: skip the large `setup_demo_data` download if `data/` is already populated.
+ - `./scripts/clean_install_local.sh --wipe-data --force-data`: delete downloaded `data/` assets and redownload.
+ - `./scripts/clean_install_local.sh --no-editable`: install only `requirements.txt` (no `pip install -e .`); use this if the editable install fails and you will install the training stack via conda as in the PyPI section above. You still need **Python 3.10+** for Gradio 5.1+. The smoke test sets `PYTHONPATH` to the repo root so `import prima` works without an editable install.
+ - **macOS:** the script omits the `deeplabcut` line from `pip install` because DeepLabCut’s pinned PyTables version often does not build on Apple Silicon. Use conda/mamba for DeepLabCut if you need SuperAnimal + TTA (`tta_num_iters` > 0). **Linux** (including Hugging Face Space builds) uses the full `requirements.txt` including `deeplabcut`.
+
+ After `requirements.txt`, the script runs **`pip install --no-deps -e .`** so the `prima` package is registered without re-resolving `pyproject.toml` (which would pull **Detectron2** and **DeepLabCut** again and often fail on macOS). Full `pip install -e .` is still recommended from a **conda** environment per the PyPI section if you need every training extra matched exactly.
+
+ **Hugging Face Space (full redeploy from your working tree):**
+
+ Requires [Git LFS / Xet](https://huggingface.co/docs/hub/xet/using-xet-storage#git) tooling (`brew install git-lfs git-xet`, `git xet install`, `git lfs install`). Then:
+
+ ```bash
+ ./scripts/clean_redeploy_hf_space.sh
+ ```
+
+ This is equivalent to `./scripts/deploy_hf_space.sh` and force-pushes a fresh snapshot to the Space.
+
+ ---
+
+ ## Demo
+
+ ### Checkpoints and data
+
+ The demo scripts auto-download their default Stage 1 PRIMA assets from Hugging
+ Face when the checkpoint or matching Hydra config is missing. If you want to
+ pre-download all necessary checkpoints and data ahead of time, run:
+
+ ```bash
+ python scripts/setup_demo_data.py --hf-repo-id MLAdaptiveIntelligence/PRIMA
+ ```
+
+ The default prefetch from Hugging Face is approximately 5.5 GB in total
+ (`s1ckpt_inference.ckpt` ~3 GB + `amr_vitbb.pth` ~2.5 GB + SMAL files).
+ Expected download time is roughly:
+ - 100 Mbps: ~7-10 minutes
+ - 300 Mbps: ~2-4 minutes
+ - 1 Gbps: ~1 minute
+
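The time estimates above follow from simple size-over-bandwidth arithmetic. The sketch below reproduces that calculation under stated assumptions: decimal units (1 GB = 8000 megabits), a fully saturated link, and no protocol overhead. `estimate_download_minutes` is a hypothetical helper, not part of the repository.

```python
def estimate_download_minutes(size_gb: float, link_mbps: float) -> float:
    """Ideal transfer time in minutes for `size_gb` gigabytes over `link_mbps`.

    Assumes decimal units (1 GB = 8000 megabits) and zero overhead, so real
    downloads will usually take longer than this lower bound.
    """
    if link_mbps <= 0:
        raise ValueError("link_mbps must be positive")
    megabits = size_gb * 8000.0
    seconds = megabits / link_mbps
    return seconds / 60.0


if __name__ == "__main__":
    for mbps in (100, 300, 1000):
        minutes = estimate_download_minutes(5.5, mbps)
        print(f"{mbps} Mbps: {minutes:.1f} min")
```

For 5.5 GB this gives about 7.3, 2.4, and 0.7 minutes, which sits at the low end of the ranges quoted in the README; the upper ends presumably allow for overhead and variable throughput.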
+ Existing files are reused by default; pass `--force` only if you need to redownload them. If you also need the Stage 3 pretrained model, add `--include-stage3`.
+
+ Expected files in that Hugging Face repo root:
+ - `my_smpl_00781_4_all.pkl`
+ - `my_smpl_data_00781_4_all.pkl`
+ - `walking_toy_symmetric_pose_prior_with_cov_35parts.pkl`
+ - `amr_vitbb.pth`
+ - `config_s1_HYDRA.yaml`
+ - `s1ckpt_inference.ckpt`
+
+ Optional Stage 3 prefetch expects:
+ - `config_s3_HYDRA.yaml`
+ - `s3ckpt_inference.ckpt`
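A quick way to verify a prefetch is to compare a directory listing against the file names above. The sketch below assumes the files have been mirrored into one flat local directory; that layout is an assumption for illustration, since `setup_demo_data.py` may place files elsewhere under `data/` (for example, `app.py` looks for the checkpoint under `data/PRIMAS1/checkpoints/`). `missing_assets` is a hypothetical helper.

```python
from pathlib import Path
from typing import List

# File names copied from the README lists above.
STAGE1_FILES = [
    "my_smpl_00781_4_all.pkl",
    "my_smpl_data_00781_4_all.pkl",
    "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl",
    "amr_vitbb.pth",
    "config_s1_HYDRA.yaml",
    "s1ckpt_inference.ckpt",
]
STAGE3_FILES = ["config_s3_HYDRA.yaml", "s3ckpt_inference.ckpt"]


def missing_assets(root: str, include_stage3: bool = False) -> List[str]:
    """List expected asset names that are absent from the flat directory `root`."""
    expected = STAGE1_FILES + (STAGE3_FILES if include_stage3 else [])
    base = Path(root)
    return [name for name in expected if not (base / name).is_file()]


if __name__ == "__main__":
    # "." is only a placeholder; point this at wherever the files were mirrored.
    print(missing_assets(".", include_stage3=True))
```

An empty list means every expected file is present; the check looks at names only and does not validate sizes or checksums.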
+
+ ### Demo (without TTA)
+
+ Run animal detection + PRIMA 3D pose/shape inference:
+
+ ```bash
+ bash demo.sh
+ ```
+
+ Outputs are written to `demo_out/`. Edit `demo.sh` if you want to use a custom
+ checkpoint path.
+
+ ---
+
+ ### Demo (with TTA)
+
+ Run PRIMA inference with test-time adaptation:
+
+ ```bash
+ bash demo_tta.sh
+ ```
+
+ Outputs are written to `demo_out_tta/` (before/after TTA renders, keypoints, and
+ optional meshes). Edit `demo_tta.sh` if you want to change the checkpoint, TTA
+ learning rate, or number of iterations.
+
+ ---
+
+ ### Gradio demo
+
+ We also provide a simple Gradio-based web demo for interactive testing in the
+ browser:
+
+ ```bash
+ python app.py \
+     --checkpoint data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt \
+     --out_folder demo_out_tta_gradio/
+ ```
+
+ This starts a local Gradio app (by default on http://127.0.0.1:7860), where
+ you can upload images and visualize PRIMA predictions and adaptation results.
+ The `s1ckpt_inference.ckpt` checkpoint is downloaded automatically if missing.
+
+ #### Hugging Face Space (maintainers)
+
+ Demo images under `demo_data/` and `images/teaser.png` are tracked with **Git LFS**
+ (see `.gitattributes`) so they can be pushed to a Hugging Face Space through the Hub’s
+ LFS / **Xet** bridge. Install the tooling once:
+
+ ```bash
+ brew install git-lfs git-xet
+ git xet install
+ git lfs install
+ ```
+
+ Then, from a clean checkout with LFS files present, redeploy the Space (same as `clean_redeploy_hf_space.sh`):
+
+ ```bash
+ ./scripts/deploy_hf_space.sh
+ # or
+ ./scripts/clean_redeploy_hf_space.sh
+ ```
+
+ The script rsyncs the working tree (rather than using `git archive`) so image files are materialized
+ before `git add` turns them into LFS blobs.
+
+ ---
+
+
+ ## Training and Evaluation
+
+ ### Dataset Setup
+
+ Download datasets from [Animal3D](https://xujiacong.github.io/Animal3D/), [CtrlAni3D](https://github.com/luoxue-star/AniMer?tab=readme-ov-file#training), Quadruped2D, and [Animal Kingdom](https://drive.google.com/file/d/1dk2a0qB0fbVZ4X6eAgP6VJVXj0rxVfsJ/view?usp=drive_link). For Quadruped2D, download the images from [SuperAnimal-Quadruped80K](https://zenodo.org/records/14016777) and our processed annotations from [here](https://drive.google.com/drive/folders/1eBNboxVwl_eGPoC93zxf-U3hmE6e2f-f?usp=sharing). Put all the datasets under `datasets/`.
+
+ ### Training
+
+ Two-stage training script:
+
+ ```bash
+ bash train.sh
+ ```
+
+ Training outputs are written to `logs/train/runs/<exp_name>/`.
+
+
+ ### Evaluation
+
+ ```bash
+ python eval.py \
+     --config data/PRIMAS1/.hydra/config.yaml \
+     --checkpoint data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt
+ ```
+
+ Common values for `--dataset` are controlled by:
+ - `configs_hydra/experiment/default_val.yaml`
+
+ ---
+
+
+ ## Acknowledgements
+
+ This release builds on several open-source projects, including:
+ - [Detectron2](https://github.com/facebookresearch/detectron2)
+ - [BioCLIP](https://github.com/Imageomics/BioCLIP)
+ - [AniMer](https://github.com/luoxue-star/AniMer)
+ - [DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)
+ - [SAM3DB](https://github.com/facebookresearch/sam-3d-body)
+
+ ---
+
+ ## Citation
+
+ If you use this code in your research, please cite our PRIMA paper.
+
+ ```bibtex
+ @misc{yu_prima,
+   title={PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation},
+   author={Xiaohang Yu and Ti Wang and Mackenzie Weygandt Mathis},
+ }
+ ```
+
+ ---
+
+ ## Contact
+
+ For issues, please open a GitHub issue in this repository.
app.py ADDED
@@ -0,0 +1,660 @@
+ """
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+ Official implementation of the paper:
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+ Licensed under a modified MIT license
+ """
+
+ """Gradio demo for PRIMA + SuperAnimal + TTA.
+
+ This script wraps the ``demo_tta.py`` pipeline into an interactive
+ Gradio interface. The overall logic follows:
+
+ 1. Given an input image, run Detectron2 to detect animals.
+ 2. For each detected animal, run PRIMA for 3D pose/shape estimation.
+ 3. Run the fine-tuned DeepLabCut SuperAnimal model to obtain PRIMA 26-keypoint
+    2D predictions.
+ 4. Run test-time adaptation (TTA) with user-specified lr and iters.
+ 5. Render and save before/after TTA results and keypoint visualizations.
+
+ """
+
+ import argparse
+ import concurrent.futures
+ import os
+ import queue
+ import sys
+ import tempfile
+ import time
+ import traceback
+ from types import SimpleNamespace
+ from typing import Callable, List, Tuple
+ from pathlib import Path
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torch.utils.data
+
+ # Repo-local minimal ``chumpy`` shim (see ``chumpy/__init__.py``) so SMAL pickles load
+ # without installing the full chumpy package in Space builds.
+ _REPO_ROOT = Path(__file__).resolve().parent
+ if str(_REPO_ROOT) not in sys.path:
+     sys.path.insert(0, str(_REPO_ROOT))
+
+ from prima.utils.weights import (
+     DEFAULT_HF_REPO_ID,
+     resolve_prima_checkpoint_path,
+ )
+ from prima.utils.detection import select_animal_boxes
+
+
+ # Default checkpoint path following README instructions
+ DEFAULT_CHECKPOINT = str(_REPO_ROOT / "data" / "PRIMAS1" / "checkpoints" / "s1ckpt_inference.ckpt")
+ DEFAULT_HF_ASSET_REPO = DEFAULT_HF_REPO_ID
+
+ # Output folder for rendered images/meshes and keypoints
+ DEFAULT_OUT_FOLDER = "demo_out_tta_gradio"
+
+
+ def _is_truthy_env(var_name: str) -> bool:
+     return os.environ.get(var_name, "").strip().lower() in {"1", "true", "yes", "on"}
+
+
+ def _running_on_space() -> bool:
+     return bool(os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID"))
+
+
+ def _gradio_examples_for_interface() -> List[List]:
+     """Gradio prefetches example media at startup.
+
+     Demo images are tracked with Git LFS / Xet (see ``.gitattributes``) so they can live
+     in the Hugging Face Space repo. Use absolute paths only when files exist beside ``app.py``.
+     """
+     if _is_truthy_env("PRIMA_DISABLE_GRADIO_EXAMPLES"):
+         return []
+     rows: List[List] = []
+     template: List[Tuple[str, float, int, float, float, bool, bool]] = [
+         ("demo_data/000000015956_horse.png", 1e-6, 0, 0.7, 0.1, False, True),
+         ("demo_data/n02412080_12159.png", 1e-6, 0, 0.7, 0.1, False, True),
+         ("demo_data/000000315905_zebra.jpg", 1e-6, 0, 0.7, 0.1, False, True),
+         ("demo_data/beagle.jpg", 1e-6, 0, 0.7, 0.1, False, True),
+         ("demo_data/shepherd_hati.jpg", 1e-6, 0, 0.7, 0.1, False, True),
+     ]
+     for rel, *rest in template:
+         p = _REPO_ROOT / rel
+         if p.is_file():
+             rows.append([str(p), *rest])
+     return rows
+
+
+ def _should_preload_assets() -> bool:
+     """Default to preload on Spaces; configurable via PRIMA_PRELOAD_ASSETS."""
+     preload_env = os.environ.get("PRIMA_PRELOAD_ASSETS")
+     if preload_env is not None:
+         return _is_truthy_env("PRIMA_PRELOAD_ASSETS")
+     return _running_on_space()
+
+ def _gradio_heartbeat_interval_sec() -> float:
+     """How often to yield status while waiting on long CPU/GPU work (keeps WebSockets alive).
+
+     Set ``PRIMA_GRADIO_HEARTBEAT_SEC`` to ``0`` to run long work on the Gradio thread (old behavior).
+     """
+     raw = os.environ.get("PRIMA_GRADIO_HEARTBEAT_SEC", "25").strip()
+     try:
+         v = float(raw)
+     except ValueError:
+         return 25.0
+     return max(0.0, v)
+
+
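The heartbeat interval above feeds a polling pattern used later in `gradio_inference`: long work runs on a worker thread while the generator wakes up every interval to yield a status message. The following standalone sketch shows that pattern in isolation; `run_with_heartbeat` and its `("status", ...)` / `("result", ...)` event tuples are hypothetical names chosen for this illustration and are not part of `app.py`.

```python
import concurrent.futures
import time
from typing import Callable, Iterator, Tuple


def run_with_heartbeat(
    work: Callable[[], object], interval_sec: float
) -> Iterator[Tuple[str, object]]:
    """Yield ("status", message) while `work` runs, then ("result", value)."""
    if interval_sec <= 0:
        # Heartbeat disabled: run the work inline on the calling thread.
        yield "result", work()
        return
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(work)
        started = time.monotonic()
        while True:
            try:
                # Wait at most one interval; exceptions from `work` re-raise here.
                value = future.result(timeout=interval_sec)
                break
            except concurrent.futures.TimeoutError:
                elapsed = time.monotonic() - started
                yield "status", f"still working ({elapsed:.1f}s)"
    yield "result", value


if __name__ == "__main__":
    def slow_job() -> int:
        time.sleep(0.3)
        return 42

    for kind, payload in run_with_heartbeat(slow_job, interval_sec=0.1):
        print(kind, payload)
```

Because a generator-based Gradio handler streams each yielded value to the client, periodic status yields of this kind are what keep a long first run from looking like a dead connection.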
+ def _preload_assets_once(checkpoint_path: str) -> None:
+     print("[startup] Ensuring demo assets from Hugging Face Hub...")
+     resolve_prima_checkpoint_path(
+         checkpoint_path,
+         data_dir=_REPO_ROOT / "data",
+         auto_download=True,
+         hf_repo_id=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO),
+     )
+     print("[startup] Asset preload complete.")
+
+
+ def _load_prima_model(checkpoint_path: str = DEFAULT_CHECKPOINT):
+     """Load PRIMA model and renderer once for the Gradio app."""
+     from prima.models import load_prima
+     from prima.utils.renderer import Renderer, cam_crop_to_full
+
+     checkpoint_path = resolve_prima_checkpoint_path(
+         checkpoint_path,
+         data_dir=_REPO_ROOT / "data",
+         auto_download=True,
+         hf_repo_id=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO),
+     )
+     checkpoint = Path(checkpoint_path)
+     cfg_path = checkpoint.parent.parent / ".hydra" / "config.yaml"
+     if not checkpoint.exists():
+         raise FileNotFoundError(
+             f"Missing checkpoint: {checkpoint}. Download demo checkpoints/data as described in README."
+         )
+     if not cfg_path.exists():
+         raise FileNotFoundError(
+             f"Missing model config: {cfg_path}. Ensure the full checkpoint folder layout from README is present."
+         )
+
+     model, model_cfg = load_prima(checkpoint_path)
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     model = model.to(device)
+     model.eval()
+
+     renderer = Renderer(model_cfg, faces=model.smal.faces)
+     return model, model_cfg, renderer, cam_crop_to_full, device
+
+
+ def _build_detector():
+     """Build Detectron2 animal detector (same config as demo_tta/demo.py)."""
+     try:
+         import detectron2.config
+         import detectron2.engine
+         from detectron2 import model_zoo
+     except Exception as e:
+         print(f"[warn] Detectron2 unavailable ({type(e).__name__}: {e}); using full-image fallback bbox.")
+         return None
+
+     cfg = detectron2.config.get_cfg()
+     cfg.merge_from_file(
+         model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml")
+     )
+     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
+     cfg.MODEL.WEIGHTS = (
+         "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/"
+         "faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
+     )
+     cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     detector = detectron2.engine.DefaultPredictor(cfg)
+     return detector
+
+
+ def _load_model_and_detector_for_demo(checkpoint_path: str):
+     """Run on a worker thread when using heartbeat polling (single entry point for executor)."""
+     model, model_cfg, renderer, cam_crop_to_full_fn, device = _load_prima_model(checkpoint_path)
+     detector = _build_detector()
+     return model, model_cfg, renderer, cam_crop_to_full_fn, device, detector
+
+
+ # SuperAnimal defaults (same as in demo_tta parser)
+ SUPER_ANIMAL_ARGS = SimpleNamespace(
+     superanimal_name="superanimal_quadruped",
+     superanimal_model_name="hrnet_w32",
+     superanimal_detector_name="fasterrcnn_resnet50_fpn_v2",
+     superanimal_max_individuals=1,
+     saved_2d_model_path="",
+     pytorch_config_2d_path=str(_REPO_ROOT / "configs" / "sa_finetune_hrnet_w32.yaml"),
+ )
+
+
+ def _collect_animal_results(
+     model,
+     model_cfg,
+     renderer,
+     cam_crop_to_full_fn,
+     device,
+     detector,
+     out_folder: str,
+     img_rgb: np.ndarray,
+     tta_lr: float,
+     tta_num_iters: int,
+     det_thresh: float,
+     kp_conf_thresh: float,
+     side_view: bool,
+     save_mesh: bool,
+     progress_callback: Callable[[str], None] | None = None,
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], str | None, str | None]:
+     """Run detection + PRIMA + SuperAnimal + TTA on a single RGB image.
+
+     Returns:
+         before_imgs: list of HxWx3 RGB images (before TTA) for all animals
+         after_imgs: list of HxWx3 RGB images (after TTA) for all animals
+         kpt_imgs: list of HxWx3 RGB keypoint visualizations
+         first_before_mesh: path to first animal's before-TTA mesh (.obj) or None
+         first_after_mesh: path to first animal's after-TTA mesh (.obj) or None
+     """
+     from prima.utils import recursive_to
+     from prima.datasets.vitdet_dataset import ViTDetDataset
+     from demo_tta import (
+         denorm_patch_to_rgb,
+         resolve_sa_weights_path,
+         run_superanimal_on_patch,
+         save_keypoint_vis,
+         tta_optimize,
+     )
+
+     def report(message: str) -> None:
+         if progress_callback is not None:
+             progress_callback(message)
+
+     if int(tta_num_iters) > 0 and not SUPER_ANIMAL_ARGS.saved_2d_model_path:
+         report("Resolving SuperAnimal weights...")
+         SUPER_ANIMAL_ARGS.saved_2d_model_path = resolve_sa_weights_path("")
+
+     # Detect animals
+     img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+     if detector is None:
+         # Fallback for environments where Detectron2 is unavailable: process full image as one crop.
+         report("Detectron2 unavailable; using full-image crop...")
+         h, w = img_bgr.shape[:2]
+         boxes = np.array([[0.0, 0.0, float(max(1, w - 1)), float(max(1, h - 1))]], dtype=np.float32)
+     else:
+         report("Detecting animals with Detectron2...")
+         det_out = detector(img_bgr)
+         det_instances = det_out["instances"]
+
+         boxes, suppressed = select_animal_boxes(det_instances, score_threshold=float(det_thresh))
+         if suppressed > 0:
+             print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s)")
+     if len(boxes) == 0:
+         return [], [], [], None, None
+
+     report(f"Detected {len(boxes)} animal(s). Preparing crops...")
+     dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
+     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
+
+     before_imgs: List[np.ndarray] = []
+     after_imgs: List[np.ndarray] = []
+     kpt_imgs: List[np.ndarray] = []
+     before_mesh_paths: List[str] = []
+     after_mesh_paths: List[str] = []
+
+     img_token = next(tempfile._get_candidate_names())
+
+     total_batches = len(dataloader)
+     for batch_idx, batch in enumerate(dataloader, start=1):
+         batch = recursive_to(batch, device)
+
+         report(f"Animal {batch_idx}/{total_batches}: running PRIMA...")
+         with torch.no_grad():
+             out_before = model(batch)
+
+         animal_id = int(batch["animalid"][0])
+
+         # Save/render before TTA
+         img_fn = f"{img_token}"
+         from demo_tta import render_and_save  # imported lazily to avoid circular issues
+
+         report(f"Animal {batch_idx}/{total_batches}: rendering before TTA...")
+         render_and_save(
+             renderer,
+             cam_crop_to_full_fn,
+             out_before,
+             batch,
+             img_fn,
+             animal_id,
+             out_folder,
+             suffix="before_tta",
+             side_view=side_view,
+             save_mesh=save_mesh,
+         )
+
+         before_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.png")
+         if os.path.exists(before_png_path):
+             before_bgr = cv2.imread(before_png_path)
+             if before_bgr is not None:
+                 before_imgs.append(cv2.cvtColor(before_bgr, cv2.COLOR_BGR2RGB))
+
+         if save_mesh:
+             before_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.obj")
+             if os.path.exists(before_obj_path):
+                 before_mesh_paths.append(before_obj_path)
+
+         if int(tta_num_iters) <= 0:
+             report(f"Animal {batch_idx}/{total_batches}: rendering final output...")
+             render_and_save(
+                 renderer,
+                 cam_crop_to_full_fn,
+                 out_before,
+                 batch,
+                 img_fn,
+                 animal_id,
+                 out_folder,
+                 suffix="after_tta",
+                 side_view=side_view,
+                 save_mesh=save_mesh,
+             )
+
+             after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
+             if os.path.exists(after_png_path):
+                 after_bgr = cv2.imread(after_png_path)
+                 if after_bgr is not None:
+                     after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB))
+
+             if save_mesh:
+                 after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
+                 if os.path.exists(after_obj_path):
+                     after_mesh_paths.append(after_obj_path)
+             continue
+
+         # Prepare patch for SuperAnimal
+         report(f"Animal {batch_idx}/{total_batches}: running SuperAnimal keypoints...")
+         patch_rgb = denorm_patch_to_rgb(batch["img"][0])
+         with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
+             bodyparts_xyc = run_superanimal_on_patch(patch_rgb, SUPER_ANIMAL_ARGS, tmp_dir)
+
+         if bodyparts_xyc is None:
+             # No keypoints => skip TTA for this animal
+             continue
+
+         kpts_xyc = bodyparts_xyc
+         kpts_xyc[kpts_xyc[:, 2] < float(kp_conf_thresh), 2] = 0.0
+
+         # Save keypoint visualization and npy
+         kpt_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png")
+         save_keypoint_vis(patch_rgb, kpts_xyc, kpt_png_path)
+         npy_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy")
+         np.save(npy_path, kpts_xyc)
+
+         if os.path.exists(kpt_png_path):
+             kpt_bgr = cv2.imread(kpt_png_path)
+             if kpt_bgr is not None:
+                 kpt_imgs.append(cv2.cvtColor(kpt_bgr, cv2.COLOR_BGR2RGB))
+
+         # Normalize keypoints to [-0.5, 0.5] as in demo_tta
+         patch_h, patch_w = patch_rgb.shape[:2]
+         kpts_norm = kpts_xyc.copy()
+         kpts_norm[:, 0] = kpts_norm[:, 0] / float(patch_w) - 0.5
+         kpts_norm[:, 1] = kpts_norm[:, 1] / float(patch_h) - 0.5
+         gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch["img"].dtype)
+
+         # Run TTA
+         report(f"Animal {batch_idx}/{total_batches}: running TTA ({int(tta_num_iters)} iterations)...")
+         out_after = tta_optimize(
+             model,
+             batch,
+             gt_kpts_norm,
+             num_iters=int(tta_num_iters),
+             lr=float(tta_lr),
+         )
+
+         report(f"Animal {batch_idx}/{total_batches}: rendering after TTA...")
+         render_and_save(
+             renderer,
+             cam_crop_to_full_fn,
+             out_after,
+             batch,
+             img_fn,
+             animal_id,
+             out_folder,
+             suffix="after_tta",
+             side_view=side_view,
+             save_mesh=save_mesh,
+         )
+
+         after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
+         if os.path.exists(after_png_path):
+             after_bgr = cv2.imread(after_png_path)
+             if after_bgr is not None:
+                 after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB))
+
+         if save_mesh:
+             after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
+             if os.path.exists(after_obj_path):
+                 after_mesh_paths.append(after_obj_path)
+
+     first_before_mesh = before_mesh_paths[0] if before_mesh_paths else None
+     first_after_mesh = after_mesh_paths[0] if after_mesh_paths else None
+
+     report("Collecting outputs...")
+     return before_imgs, after_imgs, kpt_imgs, first_before_mesh, first_after_mesh
+
+
+ def build_demo(checkpoint_path: str = DEFAULT_CHECKPOINT, out_folder: str = DEFAULT_OUT_FOLDER) -> gr.Interface:
+     os.makedirs(out_folder, exist_ok=True)
+     runtime_cache = {
+         "model": None,
+         "model_cfg": None,
+         "renderer": None,
+         "cam_crop_to_full_fn": None,
+         "device": None,
+         "detector": None,
+     }
+
+     def gradio_inference(
+         image: np.ndarray,
+         tta_lr: float,
+         tta_num_iters: int,
+         det_thresh: float,
+         kp_conf_thresh: float,
+         side_view: bool,
+         save_mesh: bool,
+     ):
+         """Wrapper for Gradio. ``image`` is an RGB numpy array.
+
+         Yields intermediate status so long first-run (Hub downloads + model load)
+         and long inference do not hit silent client/proxy WebSocket timeouts.
+         """
+
+         if image is None:
+             yield None, None, None, "No image provided."
+             return
+
+         if image.dtype != np.uint8:
+             img_rgb = np.clip(image, 0, 255).astype(np.uint8)
+         else:
+             img_rgb = image
+
+         yield None, None, None, "Queued; preparing run…"
+
+         hb = _gradio_heartbeat_interval_sec()
+
+         if runtime_cache["model"] is None:
+             yield (
+                 None,
+                 None,
+                 None,
+                 "First run: downloading demo assets from Hugging Face (large checkpoint) "
+                 "and loading the model. This can take many minutes; status updates here "
+                 "mean the session is still alive.",
+             )
+             try:
+                 if hb <= 0:
+                     model, model_cfg, renderer, cam_crop_to_full_fn, device, detector = _load_model_and_detector_for_demo(
+                         checkpoint_path
+                     )
+                 else:
+                     with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
+                         fut = pool.submit(_load_model_and_detector_for_demo, checkpoint_path)
+                         t0 = time.monotonic()
+                         while True:
+                             try:
+                                 model, model_cfg, renderer, cam_crop_to_full_fn, device, detector = fut.result(timeout=hb)
+                                 break
+                             except concurrent.futures.TimeoutError:
+                                 elapsed = int(time.monotonic() - t0)
+                                 yield None, None, None, (
+                                     f"First run: still loading model and assets ({elapsed}s). "
+                                     f"Updates every ~{int(hb)}s keep the browser connection open on Spaces."
+                                 )
+             except Exception:
+                 yield None, None, None, f"Model initialization failed:\n{traceback.format_exc()}"
+                 return
+             runtime_cache["model"] = model
+             runtime_cache["model_cfg"] = model_cfg
+             runtime_cache["renderer"] = renderer
+             runtime_cache["cam_crop_to_full_fn"] = cam_crop_to_full_fn
+             runtime_cache["device"] = device
+             runtime_cache["detector"] = detector
+             yield None, None, None, "Model loaded. Running detection and inference…"
+
+         try:
+             if hb <= 0:
+                 before_imgs, after_imgs, kpt_imgs, mesh_before, mesh_after = _collect_animal_results(
+                     runtime_cache["model"],
+                     runtime_cache["model_cfg"],
+                     runtime_cache["renderer"],
+                     runtime_cache["cam_crop_to_full_fn"],
+                     runtime_cache["device"],
+                     runtime_cache["detector"],
+                     out_folder,
+                     img_rgb,
+                     tta_lr=tta_lr,
+                     tta_num_iters=tta_num_iters,
+                     det_thresh=det_thresh,
+                     kp_conf_thresh=kp_conf_thresh,
+                     side_view=side_view,
+                     save_mesh=save_mesh,
+                 )
+             else:
+                 stage_updates: queue.Queue[str] = queue.Queue()
+
+                 def report_stage(message: str) -> None:
+                     stage_updates.put(message)
+
+                 with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
+                     fut = pool.submit(
+                         _collect_animal_results,
+                         runtime_cache["model"],
+                         runtime_cache["model_cfg"],
+                         runtime_cache["renderer"],
+                         runtime_cache["cam_crop_to_full_fn"],
+                         runtime_cache["device"],
+                         runtime_cache["detector"],
+                         out_folder,
+                         img_rgb,
+                         tta_lr,
+                         tta_num_iters,
+                         det_thresh,
+                         kp_conf_thresh,
+                         side_view,
+                         save_mesh,
+                         report_stage,
+                     )
+                     t0 = time.monotonic()
+                     latest_stage = "Starting inference..."
+                     while True:
+                         while True:
+                             try:
+                                 latest_stage = stage_updates.get_nowait()
538
+ except queue.Empty:
539
+ break
540
+ else:
541
+ elapsed = int(time.monotonic() - t0)
542
+ yield None, None, None, f"{latest_stage}\nElapsed: {elapsed}s"
543
+ try:
544
+ before_imgs, after_imgs, kpt_imgs, mesh_before, mesh_after = fut.result(
545
+ timeout=1.0
546
+ )
547
+ break
548
+ except concurrent.futures.TimeoutError:
549
+ elapsed = int(time.monotonic() - t0)
550
+ yield None, None, None, (
551
+ f"{latest_stage}\n"
552
+ f"Elapsed: {elapsed}s\n"
553
+ "CPU inference can take several minutes."
554
+ )
555
+ except Exception:
556
+ yield None, None, None, f"Inference failed:\n{traceback.format_exc()}"
557
+ return
558
+
559
+ first_before = before_imgs[0] if before_imgs else None
560
+ first_after = after_imgs[0] if after_imgs else None
561
+ first_kpts = kpt_imgs[0] if kpt_imgs else None
562
+ if first_before is None and first_after is None:
563
+ yield (
564
+ None,
565
+ None,
566
+ None,
567
+ "No output generated. Try an image with a clearly visible quadruped.",
568
+ )
569
+ return
570
+ yield first_before, first_after, first_kpts, "OK"
571
+
572
+ _gradio_examples = _gradio_examples_for_interface()
573
+ _iface_kw = dict(
574
+ fn=gradio_inference,
575
+ analytics_enabled=False,
576
+ cache_examples=False,
577
+ inputs=[
578
+ gr.Image(
579
+ label="Input image",
580
+ type="numpy",
581
+ sources=["upload", "clipboard"],
582
+ ),
583
+ gr.Slider(
584
+ label="TTA learning rate",
585
+ minimum=1e-7,
586
+ maximum=1e-4,
587
+ value=1e-6,
588
+ step=1e-7,
589
+ ),
590
+ gr.Slider(
591
+ label="TTA iterations",
592
+ minimum=0,
593
+ maximum=100,
594
+ value=0,
595
+ step=1,
596
+ info="Set to 0 to disable TTA and reuse the initial PRIMA prediction.",
597
+ ),
598
+ gr.Slider(
599
+ label="Detection threshold",
600
+ minimum=0.3,
601
+ maximum=0.9,
602
+ value=0.7,
603
+ step=0.05,
604
+ ),
605
+ gr.Slider(
606
+ label="Keypoint confidence threshold",
607
+ minimum=0.0,
608
+ maximum=1.0,
609
+ value=0.1,
610
+ step=0.05,
611
+ ),
612
+ gr.Checkbox(label="Render side view", value=False),
613
+ gr.Checkbox(label="Save meshes (.obj)", value=True),
614
+ ],
615
+ outputs=[
616
+ gr.Image(label="Before TTA"),
617
+ gr.Image(label="After TTA"),
618
+ gr.Image(label="PRIMA 26 keypoints"),
619
+ gr.Textbox(label="Status / Traceback", lines=12),
620
+ ],
621
+ title="PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation",
622
+ description=(
623
+ "Upload an animal image. The demo runs Detectron2 for animal detection, "
624
+ "PRIMA for 3D pose/shape, DeepLabCut SuperAnimal for 2D keypoints, and "
625
+ "test-time adaptation (TTA) with configurable learning rate and iterations. "
626
+ "Set TTA iterations to 0 to disable adaptation.\n\n"
627
+ "Results (PNG/OBJ and 26-keypoint visualizations) are saved under "
628
+ f"'{out_folder}'."
629
+ ),
630
+ )
631
+ if _gradio_examples:
632
+ _iface_kw["examples"] = _gradio_examples
633
+ demo = gr.Interface(**_iface_kw)
634
+ demo.queue(max_size=8, default_concurrency_limit=1)
635
+ return demo
636
+
637
+
638
+ def parse_args() -> argparse.Namespace:
639
+ parser = argparse.ArgumentParser(description="Gradio demo for PRIMA + SuperAnimal + TTA")
640
+ parser.add_argument(
641
+ "--checkpoint",
642
+ type=str,
643
+ default=DEFAULT_CHECKPOINT,
644
+ help="Path to the pretrained PRIMA checkpoint",
645
+ )
646
+ parser.add_argument(
647
+ "--out_folder",
648
+ type=str,
649
+ default=DEFAULT_OUT_FOLDER,
650
+ help="Folder used to save rendered outputs and meshes",
651
+ )
652
+ return parser.parse_args()
653
+
654
+
655
+ if __name__ == "__main__":
656
+ args = parse_args()
657
+ if _should_preload_assets():
658
+ _preload_assets_once(args.checkpoint)
659
+ demo = build_demo(checkpoint_path=args.checkpoint, out_folder=args.out_folder)
660
+ demo.launch(inbrowser=False, ssr_mode=False)
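The heartbeat pattern used by `gradio_inference` above (submit the slow call to a single worker thread, poll the future with a timeout, and yield a status message on every timeout) can be sketched in isolation. This is a minimal sketch under stated assumptions, not the app's code: `run_with_heartbeat` and `slow_task` are hypothetical names, and the status tuples stand in for the four-element Gradio outputs.

```python
import concurrent.futures
import time
from typing import Any, Callable, Iterator, Tuple


def run_with_heartbeat(
    task: Callable[[], Any],
    interval_sec: float,
) -> Iterator[Tuple[str, Any]]:
    """Run ``task`` in a worker thread and yield status tuples while waiting.

    Yields ("waiting", elapsed_seconds) whenever ``interval_sec`` passes without
    a result, then ("done", result) exactly once.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        fut = pool.submit(task)
        t0 = time.monotonic()
        while True:
            try:
                result = fut.result(timeout=interval_sec)
            except concurrent.futures.TimeoutError:
                if not fut.done():
                    # Still running: emit a heartbeat and keep polling.
                    yield "waiting", int(time.monotonic() - t0)
                    continue
                # Finished in the meantime, or the task itself raised a
                # TimeoutError: a second call returns the value or re-raises
                # the task's own exception instead of looping forever.
                result = fut.result()
            yield "done", result
            return


def slow_task() -> str:
    time.sleep(0.3)  # stands in for a checkpoint download or model load
    return "model-ready"


events = list(run_with_heartbeat(slow_task, interval_sec=0.1))
print(events[-1])
```

The `fut.done()` check matters on Python 3.11 and later, where `concurrent.futures.TimeoutError` is the built-in `TimeoutError`; without it, a task that fails with a timeout of its own would be indistinguishable from a heartbeat.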
chumpy/__init__.py ADDED
@@ -0,0 +1,16 @@
+ """
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+ Official implementation of the paper:
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+ Licensed under a modified MIT license
+ """
+
+ from __future__ import annotations
+
+ # Minimal ``chumpy`` compatibility for unpickling legacy SMAL model configs.
+
+ from .ch import Ch, ChArray
+
+ __all__ = ["Ch", "ChArray"]
chumpy/ch.py ADDED
@@ -0,0 +1,52 @@
+ """
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+ Official implementation of the paper:
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+ Licensed under a modified MIT license
+ """
+
+ from __future__ import annotations
+
+ # ``chumpy.ch`` namespace expected by legacy SMAL pickles.
+
+ import numpy as np
+
+
+ class Ch:
+     """Minimal stand-in for ``chumpy.ch.Ch`` (unpickling only)."""
+
+     def __init__(self, *args, **kwargs):
+         self._data = None
+         if args:
+             self._data = np.asarray(args[0])
+
+     def _resolve(self) -> np.ndarray:
+         # Real chumpy Ch instances store the underlying ndarray on attribute ``x``;
+         # legacy pickles unpickle by restoring ``__dict__`` without calling ``__init__``,
+         # so try common attribute names before falling back to ``_data``.
+         for attr in ("x", "_x", "_data"):
+             val = self.__dict__.get(attr)
+             if val is not None:
+                 return np.asarray(val)
+         return np.zeros((), dtype=np.float32)
+
+     def __array__(self, dtype=None):
+         arr = self._resolve()
+         if dtype is not None:
+             arr = arr.astype(dtype, copy=False)
+         return arr
+
+     @property
+     def r(self) -> np.ndarray:
+         return self._resolve()
+
+
+ class ChArray(np.ndarray):
+     """Minimal stand-in for ``chumpy.ch.ChArray``."""
+
+     pass
+
+
+ __all__ = ["Ch", "ChArray"]
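The comment in `Ch._resolve` relies on the fact that default unpickling allocates the object with `__new__` and restores `__dict__` without running `__init__`. The sketch below illustrates why the attribute lookup order matters. It is an illustration only: `MiniCh` and `restore_like_pickle` are hypothetical names, plain lists replace NumPy arrays, and the restore step is simulated rather than going through `pickle` itself.

```python
class MiniCh:
    """Hypothetical stand-in that mirrors the ("x", "_x", "_data") lookup order."""

    def __init__(self, *args):
        self._data = list(args[0]) if args else None

    def resolve(self):
        # Look directly in __dict__, so a missing attribute is simply skipped.
        for attr in ("x", "_x", "_data"):
            val = self.__dict__.get(attr)
            if val is not None:
                return val
        return None


def restore_like_pickle(cls, state):
    """Rebuild an instance the way default unpickling does:
    allocate with ``__new__`` and update ``__dict__``; ``__init__`` never runs."""
    obj = cls.__new__(cls)
    obj.__dict__.update(state)
    return obj


# A "legacy" object only carries the attribute the original library wrote.
legacy = restore_like_pickle(MiniCh, {"x": [1.0, 2.0, 3.0]})
# A freshly constructed object only carries ``_data``.
fresh = MiniCh([4.0, 5.0])

print(legacy.resolve(), "_data" in legacy.__dict__)
print(fresh.resolve())
```

Because the restored instance never had `__init__` called, `_data` does not exist on it at all, which is why the lookup goes through `__dict__.get` rather than plain attribute access.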
configs/sa_finetune_hrnet_w32.yaml ADDED
@@ -0,0 +1,220 @@
1
+ # DeepLabCut pytorch_config for the PRIMA TTA 2D pose model:
2
+ # SuperAnimal-Quadruped HRNet-w32 backbone fine-tuned on Animal3D, with
3
+ # the heatmap head re-trained for the 26-joint Animal3D / PRIMA layout.
4
+ #
5
+ # Used by demo_tta.py via DLC's `superanimal_analyze_images(...,
6
+ # customized_model_config=<this yaml>, customized_pose_checkpoint=<your
7
+ # fine-tuned .pt>)`. Only the pose model is fine-tuned; the bounding-box
8
+ # detector (Faster R-CNN) is the stock SuperAnimal-Quadruped one
9
+ # resolved by DLC at runtime.
10
+ data:
11
+ bbox_margin: 20
12
+ colormode: RGB
13
+ inference:
14
+ normalize_images: true
15
+ top_down_crop:
16
+ width: 256
17
+ height: 256
18
+ auto_padding:
19
+ pad_width_divisor: 32
20
+ pad_height_divisor: 32
21
+ train:
22
+ affine:
23
+ p: 0.5
24
+ rotation: 30
25
+ scaling:
26
+ - 1.0
27
+ - 1.0
28
+ translation: 0
29
+ gaussian_noise: 12.75
30
+ motion_blur: true
31
+ normalize_images: true
32
+ top_down_crop:
33
+ width: 256
34
+ height: 256
35
+ auto_padding:
36
+ pad_width_divisor: 32
37
+ pad_height_divisor: 32
38
+ detector:
39
+ data:
40
+ colormode: RGB
41
+ inference:
42
+ normalize_images: true
43
+ train:
44
+ affine:
45
+ p: 0.5
46
+ rotation: 30
47
+ scaling:
48
+ - 1.0
49
+ - 1.0
50
+ translation: 40
51
+ collate:
52
+ type: ResizeFromDataSizeCollate
53
+ min_scale: 0.4
54
+ max_scale: 1.0
55
+ min_short_side: 128
56
+ max_short_side: 1152
57
+ multiple_of: 32
58
+ to_square: false
59
+ hflip: true
60
+ normalize_images: true
61
+ device: auto
62
+ model:
63
+ type: FasterRCNN
64
+ freeze_bn_stats: true
65
+ freeze_bn_weights: false
66
+ variant: fasterrcnn_resnet50_fpn_v2
67
+ runner:
68
+ type: DetectorTrainingRunner
69
+ key_metric: test.mAP@50:95
70
+ key_metric_asc: true
71
+ eval_interval: 10
72
+ optimizer:
73
+ type: AdamW
74
+ params:
75
+ lr: 0.0001
76
+ scheduler:
77
+ type: LRListScheduler
78
+ params:
79
+ milestones:
80
+ - 160
81
+ lr_list:
82
+ - - 1e-05
83
+ snapshots:
84
+ max_snapshots: 5
85
+ save_epochs: 25
86
+ save_optimizer_state: false
87
+ train_settings:
88
+ batch_size: 1
89
+ dataloader_workers: 0
90
+ dataloader_pin_memory: false
91
+ display_iters: 500
92
+ epochs: 250
93
+ device: auto
94
+ inference:
95
+ multithreading:
96
+ enabled: true
97
+ queue_length: 4
98
+ timeout: 30.0
99
+ compile:
100
+ enabled: false
101
+ backend: inductor
102
+ autocast:
103
+ enabled: false
104
+ metadata:
105
+ project_path: ""
106
+ pose_config_path: ""
107
+ bodyparts:
108
+ - left_eye
109
+ - right_eye
110
+ - chin
111
+ - left_front_paw
112
+ - right_front_paw
113
+ - left_back_paw
114
+ - right_back_paw
115
+ - tail_base
116
+ - left_front_thigh
117
+ - right_front_thigh
118
+ - left_back_thigh
119
+ - right_back_thigh
120
+ - left_shoulder
121
+ - right_shoulder
122
+ - left_front_knee
123
+ - right_front_knee
124
+ - left_back_knee
125
+ - right_back_knee
126
+ - neck_base
127
+ - tail_mid
128
+ - left_ear_base
129
+ - right_ear_base
130
+ - left_mouth_corner
131
+ - right_mouth_corner
132
+ - nose
133
+ - tail_tip_first
134
+ unique_bodyparts: []
135
+ individuals:
136
+ - individual000
137
+ with_identity: false
138
+ method: td
139
+ model:
140
+ backbone:
141
+ type: HRNet
142
+ model_name: hrnet_w32
143
+ freeze_bn_stats: true
144
+ freeze_bn_weights: false
145
+ interpolate_branches: false
146
+ increased_channel_count: false
147
+ backbone_output_channels: 32
148
+ heads:
149
+ bodypart:
150
+ type: HeatmapHead
151
+ weight_init: normal
152
+ predictor:
153
+ type: HeatmapPredictor
154
+ apply_sigmoid: false
155
+ clip_scores: true
156
+ location_refinement: true
157
+ locref_std: 7.2801
158
+ target_generator:
159
+ type: HeatmapGaussianGenerator
160
+ num_heatmaps: 26
161
+ pos_dist_thresh: 17
162
+ heatmap_mode: KEYPOINT
163
+ gradient_masking: true
164
+ background_weight: 0.0
165
+ generate_locref: true
166
+ locref_std: 7.2801
167
+ criterion:
168
+ heatmap:
169
+ type: WeightedMSECriterion
170
+ weight: 1.0
171
+ locref:
172
+ type: WeightedHuberCriterion
173
+ weight: 0.05
174
+ heatmap_config:
175
+ channels:
176
+ - 32
177
+ kernel_size: []
178
+ strides: []
179
+ final_conv:
180
+ out_channels: 26
181
+ kernel_size: 1
182
+ locref_config:
183
+ channels:
184
+ - 32
185
+ kernel_size: []
186
+ strides: []
187
+ final_conv:
188
+ out_channels: 52
189
+ kernel_size: 1
190
+ net_type: hrnet_w32
191
+ runner:
192
+ type: PoseTrainingRunner
193
+ gpus:
194
+ key_metric: test.mAP
195
+ key_metric_asc: true
196
+ eval_interval: 10
197
+ optimizer:
198
+ type: AdamW
199
+ params:
200
+ lr: 0.0001
201
+ scheduler:
202
+ type: LRListScheduler
203
+ params:
204
+ lr_list:
205
+ - - 1e-05
206
+ - - 1e-06
207
+ milestones:
208
+ - 160
209
+ - 190
210
+ snapshots:
211
+ max_snapshots: 5
212
+ save_epochs: 10
213
+ save_optimizer_state: false
214
+ train_settings:
215
+ batch_size: 64
216
+ dataloader_workers: 8
217
+ dataloader_pin_memory: false
218
+ display_iters: 500
219
+ epochs: 200
220
+ seed: 42
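The config above lists 26 body parts, sets `num_heatmaps: 26`, and gives the heatmap and location-refinement branches 26 and 52 output channels respectively, which is consistent with one heatmap per keypoint and an (x, y) offset pair per keypoint. A small consistency check along these lines can catch a mismatch after editing the body-part list. This is a sketch under that assumption: `check_head_channels` is a hypothetical helper and the flat `head` dict is a simplification of the nested YAML.

```python
from typing import Dict, List

BODYPARTS: List[str] = [
    "left_eye", "right_eye", "chin",
    "left_front_paw", "right_front_paw", "left_back_paw", "right_back_paw",
    "tail_base",
    "left_front_thigh", "right_front_thigh", "left_back_thigh", "right_back_thigh",
    "left_shoulder", "right_shoulder",
    "left_front_knee", "right_front_knee", "left_back_knee", "right_back_knee",
    "neck_base", "tail_mid",
    "left_ear_base", "right_ear_base",
    "left_mouth_corner", "right_mouth_corner",
    "nose", "tail_tip_first",
]


def check_head_channels(bodyparts: List[str], head: Dict[str, int]) -> List[str]:
    """Return a list of human-readable problems; empty means consistent."""
    n = len(bodyparts)
    problems = []
    if head["num_heatmaps"] != n:
        problems.append(f"num_heatmaps={head['num_heatmaps']} but {n} bodyparts")
    if head["heatmap_out_channels"] != n:
        problems.append(f"heatmap out_channels={head['heatmap_out_channels']} != {n}")
    # Assumption: location refinement predicts one (dx, dy) pair per keypoint.
    if head["locref_out_channels"] != 2 * n:
        problems.append(f"locref out_channels={head['locref_out_channels']} != {2 * n}")
    return problems


head = {"num_heatmaps": 26, "heatmap_out_channels": 26, "locref_out_channels": 52}
print(check_head_channels(BODYPARTS, head))
```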
demo_data/000000015956_horse.png ADDED

Git LFS Details

  • SHA256: 2a2398ba7df40a47c636afefa28be17b55f4b7bc2c378e053aeea507580ad2cb
  • Pointer size: 131 Bytes
  • Size of remote file: 620 kB
demo_data/000000315905_zebra.jpg ADDED

Git LFS Details

  • SHA256: e0a17e1f1650820b020a9025144015c1e27f0f1ab435859f0bde3a0047d8f689
  • Pointer size: 131 Bytes
  • Size of remote file: 257 kB
demo_data/beagle.jpg ADDED

Git LFS Details

  • SHA256: ac29e6ea6086831dd9806a8cd3fd608e264ac1af567f6fcfc8797c5bd3d5d560
  • Pointer size: 131 Bytes
  • Size of remote file: 350 kB
demo_data/n02101388_1188.png ADDED

Git LFS Details

  • SHA256: e45ff508fb8c6437cce22fcb59b4f1b6fe37ddfab1d4cf68d97629f9caa939f4
  • Pointer size: 131 Bytes
  • Size of remote file: 319 kB
demo_data/n02412080_12159.png ADDED

Git LFS Details

  • SHA256: 03273c57e8b25b258d3eb96af7b4f77b43b5c40be90da83c21875f3322b487f1
  • Pointer size: 131 Bytes
  • Size of remote file: 347 kB
demo_data/shepherd_hati.jpg ADDED

Git LFS Details

  • SHA256: 65c5878203bc3165dda9011ebfce77cc7d930daed0a215396d8036509d1963c1
  • Pointer size: 131 Bytes
  • Size of remote file: 210 kB
demo_tta.py ADDED
@@ -0,0 +1,399 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """
11
+ demo_tta.py: PRIMA inference with fine-tuned DeepLabCut SuperAnimal TTA
12
+
13
+ Pipeline:
14
+ 1. Run Detectron2 to detect animals in the input image.
15
+ 2. Run PRIMA on each detected animal to obtain a 3D pose/shape estimate.
16
+ 3. Run a fine-tuned DeepLabCut SuperAnimal pose model (Animal3D 26-joint
17
+ layout) to obtain 2D keypoints already in PRIMA topology. The fine-tuned
18
+ snapshot is wired into DLC's
19
+ ``superanimal_analyze_images`` via the ``customized_pose_checkpoint``
20
+ and ``customized_model_config`` kwargs.
21
+ 4. Run test-time adaptation (TTA) with user-specified lr and num_iters
22
+ to further refine the 3D pose and shape estimate.
23
+ 5. Render and save before/after TTA results (PNG + OBJ) and the
24
+ 26-keypoint visualization (PNG).
25
+ """
26
+
27
+
28
+ from pathlib import Path
29
+ import argparse
30
+ import copy
31
+ import os
32
+ import tempfile
33
+ import warnings
34
+
35
+ import cv2
36
+ import numpy as np
37
+ import torch
38
+ import torch.nn.functional as F
39
+ import torch.utils.data
40
+ from tqdm import tqdm
41
+
42
+ from prima.models import load_prima
43
+ from prima.utils import recursive_to
44
+ from prima.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
45
+ from prima.utils.detection import ANIMAL_COCO_IDS, select_animal_boxes
46
+ from prima.utils.weights import DEFAULT_HF_REPO_ID, resolve_prima_checkpoint_path
47
+
48
+ warnings.filterwarnings("ignore")
49
+
50
+ LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353)
51
+ GREEN = (0.65, 0.86, 0.74)
52
+
53
+ REPO_ROOT = Path(__file__).resolve().parent
54
+
55
+
56
+ def load_renderer_components():
57
+ try:
58
+ from prima.utils.renderer import Renderer, cam_crop_to_full
59
+ except Exception as exc:
60
+ raise RuntimeError(
61
+ "Cannot initialize the PRIMA renderer. Rendering requires a working "
62
+ "pyrender/OpenGL backend such as EGL or OSMesa. Install the missing "
63
+ "OpenGL runtime for this environment, or run in an environment where "
64
+ "PYOPENGL_PLATFORM=egl/osmesa works."
65
+ ) from exc
66
+ return Renderer, cam_crop_to_full
67
+
68
+
69
+ def denorm_patch_to_rgb(img_tensor: torch.Tensor) -> np.ndarray:
70
+ patch = (img_tensor.detach().cpu() * (DEFAULT_STD[:, None, None]) + DEFAULT_MEAN[:, None, None]) / 255.0
71
+ patch = patch.permute(1, 2, 0).numpy()
72
+ return np.clip(patch, 0.0, 1.0)
73
+
74
+
75
+ def save_keypoint_vis(patch_rgb: np.ndarray, kpts_xyc: np.ndarray, save_path: str) -> None:
76
+ vis = cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR).copy()
77
+ num_kpts = len(kpts_xyc)
78
+
79
+ for i, (x, y, c) in enumerate(kpts_xyc):
80
+ if c <= 0:
81
+ continue
82
+
83
+ # Use a distinct color for each keypoint (OpenCV uses BGR)
84
+ hue = int(179 * i / max(1, num_kpts - 1))
85
+ color_bgr = cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2BGR)[0, 0]
86
+ color_bgr = (int(color_bgr[0]), int(color_bgr[1]), int(color_bgr[2]))
87
+
88
+ cx, cy = int(round(float(x))), int(round(float(y)))
89
+ cv2.circle(vis, (cx, cy), 3, color_bgr, -1)
90
+ cv2.putText(vis, str(i), (cx + 3, cy - 3), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1, cv2.LINE_AA)
91
+
92
+ cv2.imwrite(save_path, vis)
93
+
94
+
95
+ def resolve_sa_weights_path(local_path: str) -> str:
96
+ """Return a local path to the fine-tuned SuperAnimal .pt snapshot.
97
+
98
+ If ``local_path`` is empty, downloads ``sa_finetune_hrnet_w32.pt`` from the
99
+ ``MLAdaptiveIntelligence/FMPose3D`` Hugging Face repo (cached under
100
+ ``~/.cache/huggingface``).
101
+ """
102
+ if local_path:
103
+ return local_path
104
+ try:
105
+ from huggingface_hub import hf_hub_download
106
+ except ImportError:
107
+ raise ImportError(
108
+ "huggingface_hub is required to auto-download the fine-tuned "
109
+ "SuperAnimal weights. Install with `pip install huggingface_hub`, "
110
+ "or pass --saved_2d_model_path with a local .pt file."
111
+ ) from None
112
+ repo_id = "MLAdaptiveIntelligence/FMPose3D"
113
+ filename = "sa_finetune_hrnet_w32.pt"
114
+ try:
115
+ cached_path = hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True)
116
+ except Exception:
117
+ print(f"No --saved_2d_model_path provided; downloading '{filename}' from {repo_id}...")
118
+ return hf_hub_download(repo_id=repo_id, filename=filename)
119
+
120
+ print(f"Using cached SuperAnimal weights: {cached_path}")
121
+ return cached_path
122
+
123
+
124
+ def run_superanimal_on_patch(patch_rgb: np.ndarray, args, tmp_dir: str):
125
+ """Predict 26-joint 2D keypoints on a single PRIMA patch using a
126
+ fine-tuned DeepLabCut SuperAnimal snapshot.
127
+
128
+ Returns a ``(26, 3)`` array of ``(x, y, confidence)`` in patch
129
+ pixel coordinates, or ``None`` if no individual was detected.
130
+ """
131
+ try:
132
+ from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images
133
+ except Exception as e:
134
+ raise RuntimeError(
135
+ "Cannot import DeepLabCut SuperAnimal API. Please install deeplabcut with pose_estimation_pytorch support."
136
+ ) from e
137
+
138
+ patch_path = os.path.join(tmp_dir, "patch.png")
139
+ cv2.imwrite(patch_path, cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
140
+
141
+ dlc_device = "cuda" if torch.cuda.is_available() else "cpu"
142
+ preds = superanimal_analyze_images(
143
+ superanimal_name=args.superanimal_name,
144
+ model_name=args.superanimal_model_name,
145
+ detector_name=args.superanimal_detector_name,
146
+ images=patch_path,
147
+ max_individuals=args.superanimal_max_individuals,
148
+ out_folder=tmp_dir,
149
+ device=dlc_device,
150
+ customized_model_config=args.pytorch_config_2d_path,
151
+ customized_pose_checkpoint=args.saved_2d_model_path,
152
+ )
153
+
154
+ payload = preds.get(patch_path, None)
155
+ if payload is None:
156
+ return None
157
+ bodyparts = payload.get("bodyparts", None)
158
+ if bodyparts is None or len(bodyparts) == 0:
159
+ return None
160
+
161
+ best_idx = int(np.argmax(bodyparts[..., 2].mean(axis=1)))
162
+ return bodyparts[best_idx].astype(np.float32)
163
+
164
+
165
+ def render_and_save(renderer, cam_crop_to_full_fn, out, batch, img_fn, animal_id, out_folder, suffix, side_view, save_mesh):
166
+ pred_cam = out['pred_cam']
167
+ box_center = batch['box_center'].float()
168
+ box_size = batch['box_size'].float()
169
+ img_size = batch['img_size'].float()
170
+ scaled_focal_length = batch['focal_length'][0, 0] / batch['img'].shape[-1] * img_size.max()
171
+ pred_cam_t_full = cam_crop_to_full_fn(pred_cam, box_center, box_size, img_size, scaled_focal_length)
172
+
173
+ white_img = (torch.ones_like(batch['img'][0]).cpu() - DEFAULT_MEAN[:, None, None] / 255) / (
174
+ DEFAULT_STD[:, None, None] / 255
175
+ )
176
+ input_patch = denorm_patch_to_rgb(batch['img'][0])
177
+
178
+ regression_img = renderer(
179
+ out['pred_vertices'][0].detach().cpu().numpy(),
180
+ out['pred_cam_t'][0].detach().cpu().numpy(),
181
+ batch['img'][0],
182
+ mesh_base_color=GREEN,
183
+ scene_bg_color=(1, 1, 1),
184
+ )
185
+
186
+ final_img = np.concatenate([input_patch, regression_img], axis=1)
187
+ if side_view:
188
+ side_img = renderer(
189
+ out['pred_vertices'][0].detach().cpu().numpy(),
190
+ out['pred_cam_t'][0].detach().cpu().numpy(),
191
+ white_img,
192
+ mesh_base_color=GREEN,
193
+ scene_bg_color=(1, 1, 1),
194
+ side_view=True,
195
+ )
196
+ final_img = np.concatenate([final_img, side_img], axis=1)
197
+
198
+ cv2.imwrite(
199
+ os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.png'),
200
+ cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR),
201
+ )
202
+
203
+ if save_mesh:
204
+ verts = out['pred_vertices'][0].detach().cpu().numpy()
205
+ cam_t = pred_cam_t_full[0].detach().cpu().numpy()
206
+ tmesh = renderer.vertices_to_trimesh(verts, cam_t.copy(), LIGHT_BLUE)
207
+ tmesh.export(os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.obj'))
208
+
209
+
210
+ def tta_optimize(model, batch, gt_kpts_norm, num_iters, lr):
211
+ model.eval()
212
+
213
+ if hasattr(model, 'backbone'):
214
+ for p in model.backbone.parameters():
215
+ p.requires_grad = False
216
+
217
+ orig_smal_head_state = copy.deepcopy(model.smal_head.state_dict())
218
+ model.smal_head.freeze_except_regression_heads()
219
+ tta_params = model.smal_head.get_tta_parameters(mode='all')
220
+ optimizer = torch.optim.Adam(tta_params, lr=lr)
221
+
222
+ valid_mask = (gt_kpts_norm[..., 2] > 0).float().unsqueeze(-1)
223
+ gt_xy = gt_kpts_norm[..., :2]
224
+
225
+ for _ in range(num_iters):
226
+ optimizer.zero_grad()
227
+ out = model(batch)
228
+ pred_xy = out['pred_keypoints_2d']
229
+ loss = F.mse_loss(pred_xy * valid_mask, gt_xy * valid_mask, reduction='sum') / (valid_mask.sum() + 1e-6)
230
+ loss.backward()
231
+ optimizer.step()
232
+
233
+ with torch.no_grad():
234
+ out_after = model(batch)
235
+
236
+ model.smal_head.load_state_dict(orig_smal_head_state)
237
+ model.smal_head.unfreeze_all()
238
+
239
+ return out_after
240
+
241
+
242
+ def main():
243
+ parser = argparse.ArgumentParser(description='PRIMA + SuperAnimal + TTA demo')
244
+ parser.add_argument('--checkpoint', type=str, default='',
245
+ help='Path to pretrained PRIMA checkpoint. Empty -> auto-download the default Stage 1 checkpoint.')
246
+ parser.add_argument('--hf-repo-id', '--hf_repo_id', dest='hf_repo_id',
247
+ type=str, default=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_REPO_ID),
248
+ help='Hugging Face repo ID containing PRIMA demo assets')
249
+ parser.add_argument('--no-auto-download', '--no_auto_download', dest='no_auto_download', action='store_true',
250
+ help='Disable automatic download of missing PRIMA demo assets')
251
+ parser.add_argument('--img_path', type=str, default=None, help='Single image path')
252
+ parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images')
253
+ parser.add_argument('--out_folder', type=str, default='demo_out_tta', help='Output folder')
254
+ parser.add_argument('--side_view', dest='side_view', action='store_true', default=False, help='Render side view')
255
+ parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='Save meshes')
256
+ parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'], help='Image globs')
257
+ parser.add_argument('--det_thresh', type=float, default=0.7, help='Detectron2 score threshold for animals')
258
+
259
+ parser.add_argument('--tta_lr', type=float, default=1e-6, help='TTA learning rate')
260
+ parser.add_argument('--tta_num_iters', type=int, default=30, help='TTA iterations')
261
+ parser.add_argument('--kp_conf_thresh', type=float, default=0.1, help='Keypoint confidence threshold')
262
+
263
+ parser.add_argument('--superanimal_name', type=str, default='superanimal_quadruped')
264
+ parser.add_argument('--superanimal_model_name', type=str, default='hrnet_w32')
265
+ parser.add_argument('--superanimal_detector_name', type=str, default='fasterrcnn_resnet50_fpn_v2')
266
+ parser.add_argument('--superanimal_max_individuals', type=int, default=1)
267
+ parser.add_argument('--saved_2d_model_path', type=str, default='',
268
+ help='Path to the fine-tuned SuperAnimal 26-joint .pt snapshot. '
269
+ 'Empty -> auto-download sa_finetune_hrnet_w32.pt from '
270
+ 'MLAdaptiveIntelligence/FMPose3D on Hugging Face Hub.')
271
+ parser.add_argument('--pytorch_config_2d_path', type=str,
272
+ default=str(Path(__file__).resolve().parent / 'configs' / 'sa_finetune_hrnet_w32.yaml'),
273
+ help='Path to the DLC pytorch config yaml for the fine-tuned snapshot. '
274
+ 'Defaults to the bundled configs/sa_finetune_hrnet_w32.yaml.')
275
+
276
+ args = parser.parse_args()
277
+ checkpoint_path = resolve_prima_checkpoint_path(
278
+ args.checkpoint,
279
+ data_dir=REPO_ROOT / "data",
280
+ auto_download=not args.no_auto_download,
281
+ hf_repo_id=args.hf_repo_id,
282
+ )
283
+ args.saved_2d_model_path = resolve_sa_weights_path(args.saved_2d_model_path)
284
+
285
+ model, model_cfg = load_prima(checkpoint_path)
286
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
287
+ model = model.to(device)
288
+ model.eval()
289
+
290
+ Renderer, cam_crop_to_full_fn = load_renderer_components()
291
+ renderer = Renderer(model_cfg, faces=model.smal.faces)
292
+ os.makedirs(args.out_folder, exist_ok=True)
293
+
294
+ import detectron2.config
295
+ import detectron2.engine
296
+ from detectron2 import model_zoo
297
+
298
+ cfg = detectron2.config.get_cfg()
299
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"))
300
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = min(0.5, args.det_thresh)  # never pre-filter more strictly than --det_thresh
301
+ cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
302
+ cfg.MODEL.DEVICE = device.type
303
+ detector = detectron2.engine.DefaultPredictor(cfg)
304
+
305
+ if args.img_path is not None:
306
+ img_paths = [Path(args.img_path)]
307
+ else:
308
+ img_paths = sorted([img for end in args.file_type for img in Path(args.img_folder).glob(end)])
309
+
310
+ for img_path in img_paths:
311
+ img_bgr = cv2.imread(str(img_path))
312
+ if img_bgr is None:
313
+ print(f"[WARN] Cannot read image: {img_path}")
314
+ continue
315
+ det_out = detector(img_bgr)
316
+ det_instances = det_out['instances']
317
+ boxes, suppressed = select_animal_boxes(
318
+ det_instances,
319
+ animal_class_ids=ANIMAL_COCO_IDS,
320
+ score_threshold=args.det_thresh,
321
+ )
322
+ if suppressed > 0:
323
+ print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s) in {img_path}")
324
+
325
+ if len(boxes) == 0:
326
+ print(f"[INFO] No animal detected in {img_path}")
327
+ continue
328
+
329
+ dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
330
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
331
+
332
+ for batch in tqdm(dataloader, desc=f"{img_path.name}"):
333
+ batch = recursive_to(batch, device)
334
+ with torch.no_grad():
335
+ out_before = model(batch)
336
+
337
+ img_fn = img_path.stem
338
+ animal_id = int(batch['animalid'][0])
339
+
340
+ render_and_save(
341
+ renderer,
342
+ cam_crop_to_full_fn,
343
+ out_before,
344
+ batch,
345
+ img_fn,
346
+ animal_id,
347
+ args.out_folder,
348
+ suffix='before_tta',
349
+ side_view=args.side_view,
350
+ save_mesh=args.save_mesh,
351
+ )
352
+
353
+ patch_rgb = denorm_patch_to_rgb(batch['img'][0])
354
+ with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
355
+ kpts_xyc = run_superanimal_on_patch(patch_rgb, args, tmp_dir)
356
+
357
+ if kpts_xyc is None:
358
+ print(f"[WARN] No SuperAnimal keypoints for {img_fn}_{animal_id}, skip TTA")
359
+ continue
360
+
361
+ kpts_xyc[kpts_xyc[:, 2] < args.kp_conf_thresh, 2] = 0.0
362
+
363
+ save_keypoint_vis(
364
+ patch_rgb,
365
+ kpts_xyc,
366
+ os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png"),
367
+ )
368
+ np.save(os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy"), kpts_xyc)
369
+
370
+ patch_h, patch_w = patch_rgb.shape[:2]
371
+ kpts_norm = kpts_xyc.copy()
372
+ kpts_norm[:, 0] = kpts_norm[:, 0] / float(patch_w) - 0.5
373
+ kpts_norm[:, 1] = kpts_norm[:, 1] / float(patch_h) - 0.5
374
+ gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch['img'].dtype)
375
+
376
+ out_after = tta_optimize(
377
+ model,
378
+ batch,
379
+ gt_kpts_norm,
380
+ num_iters=args.tta_num_iters,
381
+ lr=args.tta_lr,
382
+ )
383
+
384
+ render_and_save(
385
+ renderer,
386
+ cam_crop_to_full_fn,
387
+ out_after,
388
+ batch,
389
+ img_fn,
390
+ animal_id,
391
+ args.out_folder,
392
+ suffix='after_tta',
393
+ side_view=args.side_view,
394
+ save_mesh=args.save_mesh,
395
+ )
396
+
397
+
398
+ if __name__ == '__main__':
399
+ main()
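The TTA objective in `demo_tta.py` combines three steps: keypoints below the confidence threshold are zeroed out, pixel coordinates are mapped to the range [-0.5, 0.5] by dividing by the patch size and subtracting 0.5, and the squared error is summed over valid keypoints and divided by their count. The sketch below reproduces that arithmetic in plain Python so the numbers can be checked by hand. It is an illustration under stated assumptions: `normalize_keypoints` and `masked_keypoint_loss` are hypothetical names, and lists of tuples replace the tensors used in the script.

```python
from typing import List, Sequence, Tuple

Keypoint = Tuple[float, float, float]  # (x, y, confidence)


def normalize_keypoints(
    kpts: Sequence[Keypoint], patch_w: int, patch_h: int, conf_thresh: float
) -> List[Keypoint]:
    """Zero the confidence of weak keypoints and map pixels to [-0.5, 0.5]."""
    out = []
    for x, y, c in kpts:
        c = c if c >= conf_thresh else 0.0
        out.append((x / float(patch_w) - 0.5, y / float(patch_h) - 0.5, c))
    return out


def masked_keypoint_loss(
    pred_xy: Sequence[Tuple[float, float]],
    target: Sequence[Keypoint],
    eps: float = 1e-6,
) -> float:
    """Sum of squared x/y errors over valid keypoints, divided by their count."""
    valid = [1.0 if c > 0 else 0.0 for _, _, c in target]
    total = sum(
        v * ((px - tx) ** 2 + (py - ty) ** 2)
        for (px, py), (tx, ty, _), v in zip(pred_xy, target, valid)
    )
    return total / (sum(valid) + eps)


detected = [(128.0, 64.0, 0.9), (0.0, 256.0, 0.05)]  # second one is below threshold
target = normalize_keypoints(detected, patch_w=256, patch_h=256, conf_thresh=0.1)
loss = masked_keypoint_loss([(0.1, -0.25), (0.4, 0.4)], target)
print(target, loss)
```

Note that the denominator counts keypoints, not coordinates, so each valid keypoint contributes the sum of its x and y squared errors.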
images/teaser.png ADDED

Git LFS Details

  • SHA256: a617ca4fd37de03e2db4ccf397ce9841ed32c3fe18c766c4832d41af574ad746
  • Pointer size: 132 Bytes
  • Size of remote file: 4.29 MB
packages.txt ADDED
@@ -0,0 +1,7 @@
1
+ libosmesa6
2
+ libgl1
3
+ libgl1-mesa-dri
4
+ libegl-mesa0
5
+ libegl1
6
+ libglx-mesa0
7
+ libgles2
prima/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """Top-level package for PRIMA.
11
+
12
+ This package contains models, datasets and utilities for
13
+ 3D animal pose and shape estimation.
14
+ """
15
+
16
+ from importlib.metadata import PackageNotFoundError, version
17
+
18
+
19
+ try: # pragma: no cover - best effort during development
20
+ __version__ = version("prima-animal")
21
+ except PackageNotFoundError: # pragma: no cover
22
+ __version__ = "0.0.0"
23
+
24
+
25
+ __all__ = ["__version__"]
prima/configs/__init__.py ADDED
@@ -0,0 +1,99 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from typing import Dict
11
+ from yacs.config import CfgNode as CN
12
+
13
+ def to_lower(x: Dict) -> Dict:
14
+ """
15
+ Convert all dictionary keys to lowercase
16
+ Args:
17
+ x (dict): Input dictionary
18
+ Returns:
19
+ dict: Output dictionary with all keys converted to lowercase
20
+ """
21
+ return {k.lower(): v for k, v in x.items()}
22
+
23
+
24
+ _C = CN(new_allowed=True)
25
+
26
+ _C.GENERAL = CN(new_allowed=True)
27
+ _C.GENERAL.RESUME = True
28
+ _C.GENERAL.TIME_TO_RUN = 3300
29
+ _C.GENERAL.VAL_STEPS = 100
30
+ _C.GENERAL.LOG_STEPS = 100
31
+ _C.GENERAL.CHECKPOINT_STEPS = 20000
32
+ _C.GENERAL.CHECKPOINT_DIR = "checkpoints"
33
+ _C.GENERAL.SUMMARY_DIR = "tensorboard"
34
+ _C.GENERAL.NUM_GPUS = 1
35
+ _C.GENERAL.NUM_WORKERS = 4
36
+ _C.GENERAL.MIXED_PRECISION = True
37
+ _C.GENERAL.ALLOW_CUDA = True
38
+ _C.GENERAL.PIN_MEMORY = False
39
+ _C.GENERAL.DISTRIBUTED = False
40
+ _C.GENERAL.LOCAL_RANK = 0
41
+ _C.GENERAL.USE_SYNCBN = False
42
+ _C.GENERAL.WORLD_SIZE = 1
43
+ _C.GENERAL.PREFETCH_FACTOR = 2
44
+
45
+ _C.TRAIN = CN(new_allowed=True)
46
+ _C.TRAIN.NUM_EPOCHS = 100
47
+ _C.TRAIN.SHUFFLE = True
48
+ _C.TRAIN.WARMUP = False
49
+ _C.TRAIN.NORMALIZE_PER_IMAGE = False
50
+ _C.TRAIN.CLIP_GRAD = False
51
+ _C.TRAIN.CLIP_GRAD_VALUE = 1.0
52
+ _C.LOSS_WEIGHTS = CN(new_allowed=True)
53
+
54
+ _C.DATASETS = CN(new_allowed=True)
55
+
56
+ _C.MODEL = CN(new_allowed=True)
57
+ _C.MODEL.IMAGE_SIZE = 224
58
+
59
+ _C.EXTRA = CN(new_allowed=True)
60
+ _C.EXTRA.FOCAL_LENGTH = 5000
61
+
62
+ _C.DATASETS.CONFIG = CN(new_allowed=True)
63
+ _C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
64
+ _C.DATASETS.CONFIG.ROT_FACTOR = 30
65
+ _C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
66
+ _C.DATASETS.CONFIG.COLOR_SCALE = 0.2
67
+ _C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
68
+ _C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
69
+ _C.DATASETS.CONFIG.DO_FLIP = False
70
+ _C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
71
+ _C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
72
+
73
+
74
+ def default_config() -> CN:
75
+ """
76
+ Get a yacs CfgNode object with the default config values.
77
+ """
78
+ # Return a clone so that the defaults will not be altered
79
+ # This is for the "local variable" use pattern
80
+ return _C.clone()
81
+
82
+
83
+ def get_config(config_file: str, merge: bool = True) -> CN:
84
+ """
85
+ Read a config file and optionally merge it with the default config file.
86
+ Args:
87
+ config_file (str): Path to config file.
88
+ merge (bool): Whether to merge with the default config or not.
89
+ Returns:
90
+ CfgNode: Config as a yacs CfgNode object.
91
+ """
92
+ if merge:
93
+ cfg = default_config()
94
+ else:
95
+ cfg = CN(new_allowed=True)
96
+ cfg.merge_from_file(config_file)
97
+
98
+ cfg.freeze()
99
+ return cfg
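The clone-then-merge pattern used by `default_config` and `get_config` can be illustrated without yacs. The sketch below is a stand-in built on plain dictionaries. `DEFAULTS`, `merged_config`, and the override values are hypothetical, and real `CfgNode` objects add type checks and freezing that this sketch omits.

```python
import copy
from typing import Any, Dict

# Hypothetical subset of the default values defined above.
DEFAULTS: Dict[str, Any] = {
    "GENERAL": {"NUM_WORKERS": 4, "LOG_STEPS": 100},
    "MODEL": {"IMAGE_SIZE": 224},
}


def merged_config(overrides: Dict[str, Any], merge: bool = True) -> Dict[str, Any]:
    """Return a fresh config; the module-level defaults are never mutated."""
    cfg = copy.deepcopy(DEFAULTS) if merge else {}

    def _merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None:
        for key, value in src.items():
            if isinstance(value, dict) and isinstance(dst.get(key), dict):
                _merge(dst[key], value)  # recurse into nested sections
            else:
                dst[key] = copy.deepcopy(value)  # override or add a new key

    _merge(cfg, overrides)
    return cfg


cfg = merged_config({"MODEL": {"IMAGE_SIZE": 256}, "TRAIN": {"BATCH_SIZE": 8}})
```

Copying the defaults before merging lets every caller start from the same baseline, which is the purpose of returning `_C.clone()` in the module above.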
prima/datasets/__init__.py ADDED
@@ -0,0 +1,79 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from typing import Dict, Optional
11
+ from torch.utils.data import WeightedRandomSampler
12
+ import torch
13
+ import pytorch_lightning as pl
14
+ from yacs.config import CfgNode
15
+ from .datasets import OptionAnimalDataset, TrainDataset
16
+ from prima.utils.pylogger import get_pylogger
17
+
18
+ log = get_pylogger(__name__)
19
+
20
+
21
+ class DataModule(pl.LightningDataModule):
22
+
23
+ def __init__(self, cfg: CfgNode) -> None:
24
+ """
25
+ Initialize LightningDataModule for AMR training
26
+ Args:
27
+ cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
28
+ """
29
+ super().__init__()
30
+ self.cfg = cfg
31
+ self.train_dataset = None
32
+ self.val_dataset = None
33
+ self.test_dataset = None
34
+ self.mocap_dataset = None
35
+ self.weight_sampler = None
36
+
37
+ def setup(self, stage: Optional[str] = None) -> None:
38
+ """
39
+ Load datasets necessary for training
40
+ Args:
41
+ stage (Optional[str]): Stage name passed by Lightning; not used here.
42
+ """
43
+ if self.train_dataset is None:
44
+ self.train_dataset = OptionAnimalDataset(self.cfg)
45
+ self.weight_sampler = WeightedRandomSampler(weights=self.train_dataset.weights,
46
+ num_samples=len(self.train_dataset))
47
+ if self.val_dataset is None:
48
+ self.val_dataset = TrainDataset(self.cfg, is_train=False,
49
+ root_image=self.cfg.DATASETS.ANIMAL3D.ROOT_IMAGE,
50
+ json_file=self.cfg.DATASETS.ANIMAL3D.JSON_FILE.TEST)
51
+
52
+ def train_dataloader(self) -> Dict:
53
+ """
54
+ Setup training data loader.
55
+ Returns:
56
+ Dict: Dictionary holding the training dataloader under the key 'img'
57
+ """
58
+ shuffle = False if self.weight_sampler is not None else True
59
+ train_dataloader = torch.utils.data.DataLoader(self.train_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True,
60
+ num_workers=self.cfg.GENERAL.NUM_WORKERS,
61
+ prefetch_factor=self.cfg.GENERAL.PREFETCH_FACTOR,
62
+ pin_memory=True,
63
+ shuffle=shuffle,
64
+ sampler=self.weight_sampler,
65
+ )
66
+ return {'img': train_dataloader}
67
+
68
+ def val_dataloader(self) -> torch.utils.data.DataLoader:
69
+ """
70
+ Setup val data loader.
71
+ Returns:
72
+ torch.utils.data.DataLoader: Validation dataloader
73
+ """
74
+ val_dataloader = torch.utils.data.DataLoader(self.val_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True,
75
+ num_workers=self.cfg.GENERAL.NUM_WORKERS, pin_memory=True)
76
+ return val_dataloader
77
+
78
+
79
+
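The training loader above draws samples through a `WeightedRandomSampler`, with one weight per sample and as many draws as there are samples. The effect on dataset balance can be sketched with the standard library, using `random.choices` as a stand-in for the PyTorch sampler. The dataset names, sizes, and weights below are hypothetical.

```python
import random
from collections import Counter

# Hypothetical datasets: (name, number of samples, per-sample weight).
datasets = [("a", 100, 1.0), ("b", 10, 5.0)]

labels, weights = [], []
for name, size, weight in datasets:
    labels.extend([name] * size)
    weights.extend([weight] * size)  # every sample inherits its dataset's weight

rng = random.Random(0)
# Draw as many indices as there are samples, with replacement, proportional to weight.
indices = rng.choices(range(len(labels)), weights=weights, k=len(labels))
counts = Counter(labels[i] for i in indices)
```

With these numbers, dataset `b` holds about 9% of the samples but carries a third of the total weight, so it is expected to supply roughly a third of the draws in each epoch.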
prima/datasets/datasets.py ADDED
@@ -0,0 +1,278 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ import copy
11
+ import os
12
+ import numpy as np
13
+ import torch
14
+ from yacs.config import CfgNode
15
+ import cv2
16
+ import pyrootutils
17
+ from torch.utils.data import ConcatDataset
18
+ from typing import List
19
+ root = pyrootutils.setup_root(
20
+ search_from=__file__,
21
+ indicator=[".git", "pyproject.toml"],
22
+ pythonpath=True,
23
+ dotenv=True,
24
+ )
25
+
26
+ import json
27
+ import hydra
28
+ from omegaconf import DictConfig, OmegaConf
29
+ from PIL import Image
30
+ from torch.utils.data import Dataset, DataLoader
31
+ from typing import Optional, Tuple
32
+ from .utils import get_example, expand_to_aspect_ratio
33
+
34
+
35
+ class TrainDataset(Dataset):
36
+ def __init__(self, cfg: CfgNode, is_train: bool, root_image: str, json_file: str):
37
+ super().__init__()
38
+ self.root_image = root_image
39
+ self.focal_length = cfg.SMAL.get("FOCAL_LENGTH", 1000)
40
+
41
+ json_file = json_file
42
+ with open(json_file, 'r') as f:
43
+ self.data = json.load(f)
44
+
45
+ self.is_train = is_train
46
+ self.IMG_SIZE = cfg.MODEL.IMAGE_SIZE
47
+ self.MEAN = 255. * np.array(cfg.MODEL.IMAGE_MEAN)
48
+ self.STD = 255. * np.array(cfg.MODEL.IMAGE_STD)
49
+ self.use_skimage_antialias = cfg.DATASETS.get('USE_SKIMAGE_ANTIALIAS', False)
50
+ self.border_mode = {
51
+ 'constant': cv2.BORDER_CONSTANT,
52
+ 'replicate': cv2.BORDER_REPLICATE,
53
+ }[cfg.DATASETS.get('BORDER_MODE', 'constant')]
54
+
55
+ self.augm_config = cfg.DATASETS.CONFIG
56
+
57
+ def __len__(self):
58
+ return len(self.data['data'])
59
+
60
+ def __getitem__(self, item):
61
+ data = self.data['data'][item]
62
+ key = data['img_path']
63
+ image = np.array(Image.open(os.path.join(self.root_image, key)).convert("RGB"))
64
+ mask = np.array(Image.open(os.path.join(self.root_image, data['mask_path'])).convert('L'))
65
+ category_idx = data['supercategory']
66
+ keypoint_2d = np.array(data['keypoint_2d'], dtype=np.float32)
67
+ if 'keypoint_3d' in data:
68
+ keypoint_3d = np.concatenate(
69
+ (data['keypoint_3d'], np.ones((len(data['keypoint_3d']), 1))), axis=-1).astype(np.float32)
70
+ else:
71
+ keypoint_3d = np.zeros((len(keypoint_2d), 4), dtype=np.float32)
72
+ bbox = data['bbox'] # [x, y, w, h]
73
+ center = np.array([(bbox[0] * 2 + bbox[2]) // 2, (bbox[1] * 2 + bbox[3]) // 2])
74
+ pose = np.array(data['pose'], dtype=np.float32) if 'pose' in data else np.zeros(105, dtype=np.float32) # [105, ]
75
+ betas = np.array(data['shape'] + data['shape_extra'], dtype=np.float32) if 'shape' in data else np.zeros(41, dtype=np.float32) # [41, ]
76
+ translation = np.array(data['trans'], dtype=np.float32) if 'trans' in data else np.zeros(3, dtype=np.float32) # [3, ]
77
+ # Fixed: Check if all elements are zero, not if all elements are truthy
78
+ has_pose = np.array(1., dtype=np.float32) if not (pose == 0).all() else np.array(0., dtype=np.float32)
79
+ has_betas = np.array(1., dtype=np.float32) if not (betas == 0).all() else np.array(0., dtype=np.float32)
80
+ has_translation = np.array(1., dtype=np.float32) if not (translation == 0).all() else np.array(0., dtype=np.float32)
81
+ ori_keypoint_2d = keypoint_2d.copy()
82
+ center_x, center_y = center[0], center[1]
83
+ bbox_size = max([bbox[2], bbox[3]])
84
+
85
+ smal_params = {'global_orient': pose[:3],
86
+ 'pose': pose[3:],
87
+ 'betas': betas,
88
+ 'transl': translation,
89
+ }
90
+ has_smal_params = {'global_orient': has_pose,
91
+ 'pose': has_pose,
92
+ 'betas': has_betas,
93
+ 'transl': has_translation,
94
+ }
95
+ smal_params_is_axis_angle = {'global_orient': True,
96
+ 'pose': True,
97
+ 'betas': False,
98
+ 'transl': False,
99
+ }
100
+
101
+ augm_config = copy.deepcopy(self.augm_config)
102
+ img_rgba = np.concatenate([image, mask[:, :, None]], axis=2)
103
+ img_patch_rgba, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, trans, img_border_mask = get_example(
104
+ img_rgba,
105
+ center_x, center_y,
106
+ bbox_size, bbox_size,
107
+ keypoint_2d, keypoint_3d,
108
+ smal_params, has_smal_params,
109
+ self.IMG_SIZE, self.IMG_SIZE,
110
+ self.MEAN, self.STD, self.is_train, augm_config,
111
+ is_bgr=False, return_trans=True,
112
+ use_skimage_antialias=self.use_skimage_antialias,
113
+ border_mode=self.border_mode
114
+ )
115
+ img_patch = (img_patch_rgba[:3, :, :])
116
+ mask_patch = (img_patch_rgba[3, :, :] / 255.0).clip(0, 1)
117
+ if (mask_patch < 0.5).all():
118
+ mask_patch = np.ones_like(mask_patch)
119
+
120
+ item = {'img': img_patch,
121
+ 'mask': mask_patch,
122
+ 'keypoints_2d': keypoints_2d,
123
+ 'keypoints_3d': keypoints_3d,
124
+ 'orig_keypoints_2d': ori_keypoint_2d,
125
+ 'box_center': np.array(center.copy(), dtype=np.float32),
126
+ 'box_size': float(bbox_size),
127
+ 'img_size': np.array(1.0 * img_size[::-1].copy(), dtype=np.float32),
128
+ 'smal_params': smal_params,
129
+ 'has_smal_params': has_smal_params,
130
+ 'smal_params_is_axis_angle': smal_params_is_axis_angle,
131
+ '_trans': trans,
132
+ 'focal_length': np.array([self.focal_length, self.focal_length], dtype=np.float32),
133
+ 'category': np.array(category_idx, dtype=np.int32),
134
+ 'supercategory': np.array(category_idx, dtype=np.int32),
135
+ "img_border_mask": img_border_mask.astype(np.float32),
136
+ "has_mask": np.array(1, dtype=np.float32)}
137
+ return item
138
+
139
+
140
+ class EvaluationDataset(Dataset):
141
+ def __init__(self, root_image: str, json_file: str, augm_config,
142
+ focal_length: int=1000, image_size: int=256,
143
+ mean: List[float]=[0.485, 0.456, 0.406], std: List[float]=[0.229, 0.224, 0.225]):
144
+ super().__init__()
145
+ self.root_image = root_image
146
+ self.focal_length = focal_length
147
+
148
+ with open(json_file, 'r') as f:
149
+ self.data = json.load(f)
150
+
151
+ self.is_train = False
152
+ self.IMG_SIZE = image_size
153
+ self.MEAN = 255. * np.array(mean)
154
+ self.STD = 255. * np.array(std)
155
+ self.use_skimage_antialias = False
156
+ self.border_mode = cv2.BORDER_CONSTANT
157
+ self.augm_config = augm_config
158
+
159
+ def __len__(self):
160
+ return len(self.data['data'])
161
+
162
+ def __getitem__(self, item):
163
+ data = self.data['data'][item]
164
+ key = data['img_path']
165
+ image = np.array(Image.open(os.path.join(self.root_image, key)).convert("RGB"))
166
+ mask = np.array(Image.open(os.path.join(self.root_image, data['mask_path'])).convert('L'))
167
+ category_idx = data['supercategory']
168
+ keypoint_2d = np.array(data['keypoint_2d'], dtype=np.float32)
169
+ # Fall back to zero-filled 3D keypoints for 2D-only datasets, matching TrainDataset
170
+ if 'keypoint_3d' in data:
171
+ keypoint_3d = np.concatenate(
172
+ (data['keypoint_3d'], np.ones((len(data['keypoint_3d']), 1))), axis=-1).astype(np.float32)
173
+ else:
174
+ keypoint_3d = np.zeros((len(keypoint_2d), 4), dtype=np.float32)
175
+ bbox = data['bbox'] # [x, y, w, h]
176
+ center = np.array([(bbox[0] * 2 + bbox[2]) // 2, (bbox[1] * 2 + bbox[3]) // 2])
177
+ pose = np.array(data['pose'], dtype=np.float32) if 'pose' in data else np.zeros(105, dtype=np.float32) # [105, ]
178
+ betas = np.array(data['shape'] + data['shape_extra'], dtype=np.float32) if 'shape' in data else np.zeros(41, dtype=np.float32) # [41, ]
179
+ translation = np.array(data['trans'], dtype=np.float32) if 'trans' in data else np.zeros(3, dtype=np.float32) # [3, ]
180
+ # Fixed: Check if all elements are zero, not if all elements are truthy
181
+ has_pose = np.array(1., dtype=np.float32) if not (pose == 0).all() else np.array(0., dtype=np.float32)
182
+ has_betas = np.array(1., dtype=np.float32) if not (betas == 0).all() else np.array(0., dtype=np.float32)
183
+ has_translation = np.array(1., dtype=np.float32) if not (translation == 0).all() else np.array(0., dtype=np.float32)
184
+ ori_keypoint_2d = keypoint_2d.copy()
185
+ center_x, center_y = center[0], center[1]
186
+
187
+ scale = np.array([bbox[2], bbox[3]], dtype=np.float32) / 200.
188
+ bbox_size = expand_to_aspect_ratio(scale*200, None).max()
189
+ bbox_expand_factor = bbox_size / ((scale*200).max())
190
+
191
+ smal_params = {'global_orient': pose[:3],
192
+ 'pose': pose[3:],
193
+ 'betas': betas,
194
+ 'transl': translation,
195
+ 'bone': np.zeros(24, dtype=np.float32) if 'bone' not in data else np.array(data['bone'])
196
+ }
197
+ has_smal_params = {'global_orient': has_pose,
198
+ 'pose': has_pose,
199
+ 'betas': has_betas,
200
+ 'transl': has_translation,
201
+ 'bone': np.array(1., dtype=np.float32) if 'bone' in data else np.array(0., dtype=np.float32),
202
+ }
203
+ smal_params_is_axis_angle = {'global_orient': True,
204
+ 'pose': True,
205
+ 'betas': False,
206
+ 'transl': False,
207
+ 'bone': False
208
+ }
209
+
210
+ augm_config = copy.deepcopy(self.augm_config)
211
+ img_rgba = np.concatenate([image, mask[:, :, None]], axis=2)
212
+ img_patch_rgba, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, trans, img_border_mask = get_example(
213
+ img_rgba,
214
+ center_x, center_y,
215
+ bbox_size, bbox_size,
216
+ keypoint_2d, keypoint_3d,
217
+ smal_params, has_smal_params,
218
+ self.IMG_SIZE, self.IMG_SIZE,
219
+ self.MEAN, self.STD, self.is_train, augm_config,
220
+ is_bgr=False, return_trans=True,
221
+ use_skimage_antialias=self.use_skimage_antialias,
222
+ border_mode=self.border_mode
223
+ )
224
+ img_patch = (img_patch_rgba[:3, :, :])
225
+ mask_patch = (img_patch_rgba[3, :, :] / 255.0).clip(0, 1)
226
+ if (mask_patch < 0.5).all():
227
+ mask_patch = np.ones_like(mask_patch)
228
+
229
+ item = {'img': img_patch,
230
+ 'mask': mask_patch,
231
+ 'keypoints_2d': keypoints_2d,
232
+ 'keypoints_3d': keypoints_3d,
233
+ 'orig_keypoints_2d': ori_keypoint_2d,
234
+ 'box_center': np.array(center.copy(), dtype=np.float32),
235
+ 'box_size': float(bbox_size),
236
+ 'img_size': np.array(1.0 * img_size[::-1].copy(), dtype=np.float32),
237
+ 'smal_params': smal_params,
238
+ 'has_smal_params': has_smal_params,
239
+ 'smal_params_is_axis_angle': smal_params_is_axis_angle,
240
+ '_trans': trans,
241
+ 'focal_length': np.array([self.focal_length, self.focal_length], dtype=np.float32),
242
+ 'category': np.array(category_idx, dtype=np.int32),
243
+ 'bbox_expand_factor': bbox_expand_factor,
244
+ 'supercategory': np.array(category_idx, dtype=np.int32),
245
+ "img_border_mask": img_border_mask.astype(np.float32),
246
+ 'has_mask': np.array(1., dtype=np.float32),
247
+ 'imgname': key,
248
+ 'bbox': np.array(bbox, dtype=np.float32)}
249
+ return item
250
+
251
+
252
+ class OptionAnimalDataset(Dataset):
253
+ def __init__(self, cfg: CfgNode):
254
+ datasets = []
255
+ weights = []
256
+
257
+ dataset_configs = cfg.DATASETS
258
+ for dataset_name in dataset_configs:
259
+ if dataset_name != "CONFIG":
260
+ datasets.append(TrainDataset(cfg,
261
+ is_train=True,
262
+ root_image=dataset_configs[dataset_name].ROOT_IMAGE,
263
+ json_file=dataset_configs[dataset_name].JSON_FILE.TRAIN))
264
+ weights.extend([dataset_configs[dataset_name].WEIGHT] * len(datasets[-1]))
265
+
266
+ # Concatenate all enabled datasets
267
+ if datasets:
268
+ self.dataset = ConcatDataset(datasets)
269
+ self.weights = torch.tensor(weights, dtype=torch.float32)
270
+ else:
271
+ raise ValueError("No datasets enabled in the configuration.")
272
+
273
+ def __len__(self):
274
+ return len(self.dataset)
275
+
276
+ def __getitem__(self, idx):
277
+ return self.dataset[idx]
278
+
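Two small computations in the dataset classes above are easy to check by hand. One is the box center derived from an `[x, y, w, h]` bounding box. The other is the `has_*` flags, which mark an annotation as present unless every element is zero. The sketch below restates both in plain Python. The function names are hypothetical, and the original code uses NumPy.

```python
from typing import Sequence, Tuple


def bbox_center(bbox: Sequence[float]) -> Tuple[float, float]:
    """Center of an [x, y, w, h] box, using the same floor division as the dataset code."""
    x, y, w, h = bbox
    return ((x * 2 + w) // 2, (y * 2 + h) // 2)


def has_annotation(values: Sequence[float]) -> float:
    """1.0 if any element is non-zero, else 0.0 (an all-zero vector means 'missing')."""
    return 1.0 if any(v != 0 for v in values) else 0.0


center = bbox_center([10, 20, 101, 50])
flags = [has_annotation([0.0] * 105), has_annotation([0.0, 0.3, 0.0])]
```

Because of the floor division, a box of odd width has its center rounded down: the box above spans x = 10 to 111, and its center is reported as 60 rather than 60.5.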
prima/datasets/dlc2coco.py ADDED
@@ -0,0 +1,362 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ '''
11
+ This script converts DeepLabCut-labeled data (20 keypoints) to COCO format (26 keypoints). It also extracts the labeled frames from the raw videos and saves them as images.
12
+
13
+ Usage:
14
+ python dlc2coco.py --dataset_dir /path/to/dataset --output_dir /path/to/output
15
+
16
+ for camera x
17
+ dlc keypoint data: <dataset_dir>/<behavior>/fte_pw/camx_fte.csv,
18
+ where the video frame indices and keypoint coordinates are stored
19
+ raw video: <dataset_dir>/<behavior>/camx.mp4
20
+
21
+ for coco format, please refer to:
22
+ ./datasets/quadruped2d/test.json
23
+
24
+ The multiview correspondence between cameras is also saved.
25
+
26
+
27
+ Keypoint mapping from acinoset to animal3d:
28
+ keypoint_mapping = {"acinoset":[2, 1, -1, 13, 10, 19, 16, 5, -1, -1, -1, -1, 11, 8, 12, 9, 18, 15, 3, 7, -1,-1,-1,-1, 0, 6]}
29
+
30
+
31
+ '''
32
+
33
+ import argparse
34
+ import os
35
+ import json
36
+ import cv2
37
+ import numpy as np
38
+ import pandas as pd
39
+ from pathlib import Path
40
+ from tqdm import tqdm
41
+
42
+ # DLC keypoints (20 keypoints from acinoset):
43
+ # 0: nose, 1: r_eye, 2: l_eye, 3: neck_base, 4: spine, 5: tail_base, 6: tail1, 7: tail2,
44
+ # 8: r_shoulder, 9: r_front_knee, 10: r_front_ankle, 11: l_shoulder, 12: l_front_knee, 13: l_front_ankle,
45
+ # 14: r_hip, 15: r_back_knee, 16: r_back_ankle, 17: l_hip, 18: l_back_knee, 19: l_back_ankle
46
+
47
+ # Animal3D keypoints (26 keypoints):
48
+ # Based on the mapping: [2, 1, -1, 13, 10, 19, 16, 5, -1, -1, -1, -1, 11, 8, 12, 9, 18, 15, 3, 7, -1,-1,-1,-1, 0, 6]
49
+ # This means: animal3d_idx 0 maps to acinoset_idx 2 (l_eye), animal3d_idx 1 maps to acinoset_idx 1 (r_eye), etc.
50
+
51
+ # Keypoint mapping from acinoset (DLC) to animal3d (COCO format)
52
+ KEYPOINT_MAPPING = [2, 1, -1, 13, 10, 19, 16, 5, -1, -1, -1, -1, 11, 8, 12, 9, 18, 15, 3, 7, -1, -1, -1, -1, 0, 6]
53
+
54
+ def read_dlc_csv(csv_path):
55
+ """
56
+ Read DeepLabCut CSV file and extract keypoint data
57
+ Returns: DataFrame with frame index and keypoint coordinates
58
+ """
59
+ # Read the CSV file, skip the first 2 rows (header rows)
60
+ df = pd.read_csv(csv_path, skiprows=2)
61
+
62
+ # Replace NaN with 0
63
+ df = df.fillna(0)
64
+
65
+ # The first column is frame index
66
+ frame_indices = df.iloc[:, 0].values
67
+
68
+ # Extract keypoint coordinates (x, y, likelihood)
69
+ # DLC format: each keypoint has 3 columns (x, y, likelihood)
70
+ num_keypoints = 20
71
+ keypoints_data = []
72
+
73
+ for idx, frame_idx in enumerate(frame_indices):
74
+ keypoints = []
75
+ for kp_idx in range(num_keypoints):
76
+ col_start = 1 + kp_idx * 3
77
+ x = float(df.iloc[idx, col_start])
78
+ y = float(df.iloc[idx, col_start + 1])
79
+ likelihood = float(df.iloc[idx, col_start + 2])
80
+
81
+ # If likelihood is 0 (from NaN), but x and y are not 0, assume it's a valid point
82
+ if likelihood == 0 and (x != 0 or y != 0):
83
+ likelihood = 1.0 # Default to high confidence
84
+
85
+ keypoints.append([x, y, likelihood])
86
+
87
+ keypoints_data.append({
88
+ 'frame_idx': int(frame_idx),
89
+ 'keypoints': keypoints
90
+ })
91
+
92
+ return keypoints_data
93
+
94
+ def map_keypoints_to_animal3d(acinoset_keypoints):
95
+ """
96
+ Map 20 DLC keypoints to 26 Animal3D keypoints using the provided mapping
97
+ acinoset_keypoints: list of [x, y, likelihood] for 20 keypoints
98
+ Returns: list of [x, y, visibility] for 26 keypoints
99
+ """
100
+ animal3d_keypoints = []
101
+
102
+ for animal3d_idx, acinoset_idx in enumerate(KEYPOINT_MAPPING):
103
+ if acinoset_idx == -1:
104
+ # Missing keypoint, set to [0, 0, 0]
105
+ animal3d_keypoints.append([0.0, 0.0, 0.0])
106
+ else:
107
+ x, y, likelihood = acinoset_keypoints[acinoset_idx]
108
+ # Replace NaN with 0
109
+ if np.isnan(x):
110
+ x = 0.0
111
+ if np.isnan(y):
112
+ y = 0.0
113
+ if np.isnan(likelihood):
114
+ likelihood = 0.0
115
+
116
+ # Set the visibility flag from the coordinates (2 = visible, 0 = not labeled); the likelihood value is not used here
117
+ # If the keypoint has valid coordinates, mark as visible
118
+ if x != 0.0 or y != 0.0:
119
+ visibility = 2.0
120
+ else:
121
+ visibility = 0.0
122
+
123
+ animal3d_keypoints.append([float(x), float(y), visibility])
124
+
125
+ return animal3d_keypoints
126
+
127
+ def extract_frames_from_video(video_path, output_dir, frame_indices, behavior, camera_id):
128
+ """
129
+ Extract specific frames from video and save as images
130
+ Returns: dict mapping frame_idx to image path
131
+ """
132
+ output_dir = Path(output_dir)
133
+ output_dir.mkdir(parents=True, exist_ok=True)
134
+
135
+ cap = cv2.VideoCapture(str(video_path))
136
+ if not cap.isOpened():
137
+ print(f"Error: Cannot open video {video_path}")
138
+ return {}
139
+
140
+ frame_paths = {}
141
+
142
+ # Sort frame indices for efficient extraction
143
+ sorted_frames = sorted(set(frame_indices)) # Remove duplicates
144
+
145
+ pbar = tqdm(total=len(sorted_frames), desc=f"Extracting frames from {video_path.name}")
146
+
147
+ for target_frame in sorted_frames:
148
+ # Use CAP_PROP_POS_FRAMES to seek to the exact frame
149
+ cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
150
+ ret, frame = cap.read()
151
+
152
+ if ret and frame is not None:
153
+ # Save frame as image with behavior name in filename
154
+ img_filename = f"{behavior}_cam{camera_id}_frame_{target_frame:06d}.jpg"
155
+ img_path = output_dir / img_filename
156
+ cv2.imwrite(str(img_path), frame)
157
+ frame_paths[target_frame] = str(img_path.relative_to(output_dir.parent.parent))
158
+ else:
159
+ print(f"Warning: Failed to read frame {target_frame} from {video_path.name}")
160
+
161
+ pbar.update(1)
162
+
163
+ pbar.close()
164
+ cap.release()
165
+
166
+ return frame_paths
167
+
168
+ def compute_bbox_from_keypoints(keypoints):
169
+ """
170
+ Compute bounding box from keypoints
171
+ keypoints: list of [x, y, visibility]
172
+ Returns: [x, y, width, height]
173
+ """
174
+ valid_points = [(kp[0], kp[1]) for kp in keypoints if kp[2] > 0]
175
+
176
+ if not valid_points:
177
+ return [0, 0, 0, 0]
178
+
179
+ xs, ys = zip(*valid_points)
180
+ x_min, x_max = min(xs), max(xs)
181
+ y_min, y_max = min(ys), max(ys)
182
+
183
+ # Add some padding
184
+ padding = 20
185
+ x_min = max(0, x_min - padding)
186
+ y_min = max(0, y_min - padding)
187
+ width = (x_max + padding) - x_min # x_min already includes the left padding
188
+ height = (y_max + padding) - y_min # y_min already includes the top padding
189
+
190
+ return [float(x_min), float(y_min), float(width), float(height)]
191
+
192
+ def process_camera(camera_id, base_dir, output_dir, behavior):
193
+ """
194
+ Process one camera: read CSV, extract frames, convert to COCO format
195
+ behavior: name of the behavior (e.g., 'run', 'flick')
196
+ """
197
+ base_dir = Path(base_dir)
198
+ output_dir = Path(output_dir)
199
+
200
+ # Paths
201
+ csv_path = base_dir / "fte_pw" / f"cam{camera_id}_fte.csv"
202
+ video_path = base_dir / f"cam{camera_id}.mp4"
203
+
204
+ print(f"\nProcessing Camera {camera_id} - Behavior: {behavior}...")
205
+ print(f"CSV: {csv_path}")
206
+ print(f"Video: {video_path}")
207
+
208
+ # Read keypoint data from CSV
209
+ keypoints_data = read_dlc_csv(csv_path)
210
+ print(f"Found {len(keypoints_data)} frames with keypoints")
211
+
212
+ # Extract frames from video
213
+ frame_indices = [kp_data['frame_idx'] for kp_data in keypoints_data]
214
+ images_dir = output_dir / "images" / behavior / f"cam{camera_id}"
215
+ frame_paths = extract_frames_from_video(video_path, images_dir, frame_indices, behavior, camera_id)
216
+
217
+ # Convert to COCO format
218
+ coco_data = []
219
+ for kp_data in tqdm(keypoints_data, desc=f"Converting cam{camera_id} to COCO format"):
220
+ frame_idx = kp_data['frame_idx']
221
+
222
+ if frame_idx not in frame_paths:
223
+ continue
224
+
225
+ # Map keypoints from acinoset (20) to animal3d (26)
226
+ acinoset_kps = kp_data['keypoints']
227
+ animal3d_kps = map_keypoints_to_animal3d(acinoset_kps)
228
+
229
+ # Compute bounding box
230
+ bbox = compute_bbox_from_keypoints(animal3d_kps)
231
+
232
+ # Create COCO entry
233
+ img_path = frame_paths[frame_idx]
234
+ coco_entry = {
235
+ "img_path": img_path,
236
+ "mask_path": img_path, # Same as img_path
237
+ "bbox": bbox,
238
+ "keypoint_2d": animal3d_kps,
239
+ "camera_id": camera_id,
240
+ "frame_idx": frame_idx,
241
+ "behavior": behavior
242
+ }
243
+
244
+ coco_data.append(coco_entry)
245
+
246
+ return coco_data
247
+
248
+ def parse_args():
249
+ parser = argparse.ArgumentParser(
250
+ description="Convert DeepLabCut labeled data to COCO format"
251
+ )
252
+ parser.add_argument(
253
+ "--dataset_dir", type=str, default=".",
254
+ help="Root directory containing behavior subdirectories (run, flick, etc.)"
255
+ )
256
+ parser.add_argument(
257
+ "--output_dir", type=str, default=None,
258
+ help="Output directory for COCO format data (default: {dataset_dir}/coco_format)"
259
+ )
260
+ parser.add_argument(
261
+ "--behaviors", type=str, nargs="+", default=["run", "flick"],
262
+ help="Behavior names to process (default: run flick)"
263
+ )
264
+ parser.add_argument(
265
+ "--cameras", type=int, nargs="+", default=[1, 2, 3, 4, 5, 6],
266
+ help="Camera IDs to process (default: 1 2 3 4 5 6)"
267
+ )
268
+ return parser.parse_args()
269
+
270
+
271
+ def main():
272
+ args = parse_args()
273
+
274
+ dataset_dir = Path(args.dataset_dir)
275
+ output_dir = Path(args.output_dir) if args.output_dir else dataset_dir / "coco_format"
276
+ output_dir.mkdir(parents=True, exist_ok=True)
277
+
278
+ behaviors = args.behaviors
279
+ camera_ids = args.cameras
280
+
281
+ all_data = []
282
+ behavior_data = {}
283
+ camera_data = {}
284
+
285
+ for behavior in behaviors:
286
+ behavior_dir = dataset_dir / behavior
287
+ behavior_data[behavior] = []
288
+
289
+ print(f"\n{'='*60}")
290
+ print(f"Processing Behavior: {behavior.upper()}")
291
+ print(f"{'='*60}")
292
+
293
+ for cam_id in camera_ids:
294
+ coco_data = process_camera(cam_id, behavior_dir, output_dir, behavior)
295
+ all_data.extend(coco_data)
296
+ behavior_data[behavior].extend(coco_data)
297
+
298
+ # Store per-camera-behavior data
299
+ key = f"{behavior}_cam{cam_id}"
300
+ camera_data[key] = coco_data
301
+
302
+ # Save combined data (all behaviors and cameras)
303
+ output_json = output_dir / "all_data.json"
304
+ with open(output_json, 'w') as f:
305
+ json.dump({"data": all_data}, f, indent=4)
306
+
307
+ print(f"\n{'='*60}")
308
+ print(f"SUMMARY")
309
+ print(f"{'='*60}")
310
+ print(f"Saved combined data to {output_json}")
311
+ print(f"Total entries: {len(all_data)}")
312
+
313
+ # Save per-behavior data
314
+ for behavior in behaviors:
315
+ behavior_json = output_dir / f"{behavior}.json"
316
+ with open(behavior_json, 'w') as f:
317
+ json.dump({"data": behavior_data[behavior]}, f, indent=4)
318
+ print(f"\nSaved {behavior} data to {behavior_json} ({len(behavior_data[behavior])} entries)")
319
+
320
+ # Save per-camera-behavior data
321
+ for behavior in behaviors:
322
+ for cam_id in camera_ids:
323
+ key = f"{behavior}_cam{cam_id}"
324
+ cam_json = output_dir / f"{behavior}_cam{cam_id}.json"
325
+ with open(cam_json, 'w') as f:
326
+ json.dump({"data": camera_data[key]}, f, indent=4)
327
+ print(f" - {behavior}_cam{cam_id}: {len(camera_data[key])} entries")
328
+
329
+ # Save multiview relationships
330
+ # Group by behavior and frame index to establish multiview correspondence
331
+ multiview_data = {}
332
+ for entry in all_data:
333
+ behavior = entry['behavior']
334
+ frame_idx = entry['frame_idx']
335
+ cam_id = entry['camera_id']
336
+
337
+ if behavior not in multiview_data:
338
+ multiview_data[behavior] = {}
339
+
340
+ if frame_idx not in multiview_data[behavior]:
341
+ multiview_data[behavior][frame_idx] = {}
342
+
343
+ multiview_data[behavior][frame_idx][f"cam{cam_id}"] = {
344
+ "img_path": entry['img_path'],
345
+ "keypoint_2d": entry['keypoint_2d'],
346
+ "bbox": entry['bbox']
347
+ }
348
+
349
+ multiview_json = output_dir / "multiview_mapping.json"
350
+ with open(multiview_json, 'w') as f:
351
+ json.dump(multiview_data, f, indent=4)
352
+
353
+ print(f"\nSaved multiview mapping to {multiview_json}")
354
+ for behavior in behaviors:
355
+ print(f" - {behavior}: {len(multiview_data.get(behavior, {}))} synchronized frames")
356
+
357
+ print(f"\n{'='*60}")
358
+ print("Conversion complete!")
359
+ print(f"{'='*60}")
360
+
361
+ if __name__ == "__main__":
362
+ main()
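Note on the multiview mapping written above: it nests behavior, then frame index, then camera. The sketch below rebuilds that shape with `dict.setdefault` on made-up entries. The helper name `build_multiview_mapping`, the behavior name, and the image paths are illustrative assumptions, not part of dlc2coco.py. It also shows that integer keys come back as strings after a JSON round trip, which matters if `frame_idx` is an integer and another script reloads multiview_mapping.json.

```python
import json


def build_multiview_mapping(entries):
    """Group flat entries as behavior -> frame_idx -> 'cam<id>' -> annotations."""
    mapping = {}
    for entry in entries:
        frames = mapping.setdefault(entry["behavior"], {})
        views = frames.setdefault(entry["frame_idx"], {})
        views[f"cam{entry['camera_id']}"] = {
            "img_path": entry["img_path"],
            "keypoint_2d": entry["keypoint_2d"],
            "bbox": entry["bbox"],
        }
    return mapping


# Two hypothetical views of the same frame.
entries = [
    {"behavior": "run", "frame_idx": 0, "camera_id": 1,
     "img_path": "run/cam1/0.jpg", "keypoint_2d": [], "bbox": [0, 0, 10, 10]},
    {"behavior": "run", "frame_idx": 0, "camera_id": 2,
     "img_path": "run/cam2/0.jpg", "keypoint_2d": [], "bbox": [5, 5, 10, 10]},
]
mapping = build_multiview_mapping(entries)

# JSON object keys are always strings, so the int key 0 reloads as "0".
reloaded = json.loads(json.dumps(mapping))
```

Code that reads the saved mapping should therefore treat frame indices as strings, or convert them back explicitly.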
prima/datasets/split_acinoset.py ADDED
@@ -0,0 +1,153 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """
11
+ Split the AcinoSet multiview_mapping.json into train and test sets (7:3 ratio by default).
12
+
13
+ Usage:
14
+ python split_acinoset.py \
15
+ --input_json /path/to/multiview_mapping.json \
16
+ --output_dir /path/to/output \
17
+ --train_ratio 0.7 \
18
+ --seed 42
19
+ """
20
+
21
+ import argparse
22
+ import json
23
+ import random
24
+ from pathlib import Path
25
+ from collections import defaultdict
26
+
27
+ # ------------------------------------------------------------------
28
+ # EDIT THIS to point to your dataset root (see examples above).
29
+ # All paths below are relative to this directory.
30
+ # ------------------------------------------------------------------
31
+ BASE_DIR = Path("datasets")
32
+
33
+
34
+ def split_multiview_data(input_json, output_dir, train_ratio=0.7, seed=42):
35
+ """
36
+ Split multiview mapping data into train and test sets.
37
+
38
+
39
+ Args:
40
+ input_json: Path to multiview_mapping.json
41
+ output_dir: Directory to save train.json and test.json
42
+ train_ratio: Fraction of frames assigned to the training set (default 0.7, i.e. 70%)
44
+ seed: Random seed for reproducibility
45
+ """
46
+ # Set random seed
47
+ random.seed(seed)
48
+
49
+ # Load data
50
+ print(f"Loading data from {input_json}...")
51
+ with open(input_json, 'r') as f:
52
+ data = json.load(f)
53
+
54
+ # Initialize train and test splits
55
+ train_data = defaultdict(dict)
56
+ test_data = defaultdict(dict)
57
+
58
+ # Process each behavior
59
+ for behavior, frames in data.items():
60
+ print(f"\nProcessing behavior: {behavior}")
61
+
62
+ # Get all frame indices
63
+ frame_indices = list(frames.keys())
64
+ total_frames = len(frame_indices)
65
+
66
+ # Shuffle frame indices
67
+ random.shuffle(frame_indices)
68
+
69
+ # Calculate split point
70
+ train_size = int(total_frames * train_ratio)
71
+
72
+ # Split frames
73
+ train_frames = frame_indices[:train_size]
74
+ test_frames = frame_indices[train_size:]
75
+
76
+ print(f" Total frames: {total_frames}")
77
+ print(f" Train frames: {len(train_frames)}")
78
+ print(f" Test frames: {len(test_frames)}")
79
+
80
+ # Assign to train and test
81
+ for frame_idx in train_frames:
82
+ train_data[behavior][frame_idx] = frames[frame_idx]
83
+
84
+ for frame_idx in test_frames:
85
+ test_data[behavior][frame_idx] = frames[frame_idx]
86
+
87
+ # Save train and test splits
88
+ output_dir = Path(output_dir)
89
+ output_dir.mkdir(parents=True, exist_ok=True)
90
+
91
+ train_json = output_dir / "train.json"
92
+ test_json = output_dir / "test.json"
93
+
94
+ print(f"\nSaving train data to {train_json}...")
95
+ with open(train_json, 'w') as f:
96
+ json.dump(dict(train_data), f, indent=4)
97
+
98
+ print(f"Saving test data to {test_json}...")
99
+ with open(test_json, 'w') as f:
100
+ json.dump(dict(test_data), f, indent=4)
101
+
102
+ # Print summary
103
+ print("\n" + "="*50)
104
+ print("Summary:")
105
+ print("="*50)
106
+
107
+ total_train_frames = sum(len(frames) for frames in train_data.values())
108
+ total_test_frames = sum(len(frames) for frames in test_data.values())
109
+ total_frames = total_train_frames + total_test_frames
110
+
111
+ print(f"Total frames: {total_frames}")
112
+ print(f"Train frames: {total_train_frames} ({total_train_frames/total_frames*100:.1f}%)")
113
+ print(f"Test frames: {total_test_frames} ({total_test_frames/total_frames*100:.1f}%)")
114
+ print("\nPer behavior:")
115
+ for behavior in train_data.keys():
116
+ train_count = len(train_data[behavior])
117
+ test_count = len(test_data[behavior])
118
+ total_count = train_count + test_count
119
+ print(f" {behavior}: train={train_count}, test={test_count}, total={total_count}")
120
+
121
+ print("\nDone!")
122
+
123
+
124
+ if __name__ == "__main__":
125
+ parser = argparse.ArgumentParser(
126
+ description="Split multiview_mapping.json into train/test sets (default 7:3)."
127
+ )
128
+ parser.add_argument(
129
+ "--input_json", type=str,
130
+ default=str(BASE_DIR / "acinoset" / "multiview_mapping.json"),
131
+ help="Path to multiview_mapping.json (default: datasets/acinoset/multiview_mapping.json)."
132
+ )
133
+ parser.add_argument(
134
+ "--output_dir", type=str,
135
+ default=str(BASE_DIR / "acinoset"),
136
+ help="Directory to save train.json and test.json (default: datasets/acinoset)."
137
+ )
138
+ parser.add_argument(
139
+ "--train_ratio", type=float, default=0.7,
140
+ help="Fraction of data for training (default: 0.7)."
141
+ )
142
+ parser.add_argument(
143
+ "--seed", type=int, default=42,
144
+ help="Random seed for reproducibility (default: 42)."
145
+ )
146
+ args = parser.parse_args()
147
+
148
+ split_multiview_data(
149
+ input_json=args.input_json,
150
+ output_dir=args.output_dir,
151
+ train_ratio=args.train_ratio,
152
+ seed=args.seed,
153
+ )
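Note on the split above: frames are shuffled per behavior and cut at `int(n * train_ratio)`, so any remainder goes to the test set. A minimal sketch of that step follows. It deliberately swaps the module-level `random.seed(seed)` used in the script for a private `random.Random(seed)` instance, so the global random state is left alone; the helper name `split_frames` and the frame keys are hypothetical.

```python
import random


def split_frames(frame_ids, train_ratio=0.7, seed=42):
    """Shuffle frame ids with a private RNG and cut at int(n * train_ratio)."""
    rng = random.Random(seed)      # local generator; global random state is untouched
    shuffled = list(frame_ids)     # copy so the caller's list is not reordered
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * train_ratio)  # truncates, so the remainder goes to test
    return shuffled[:cut], shuffled[cut:]


frame_ids = [str(i) for i in range(10)]   # hypothetical frame keys
train, test = split_frames(frame_ids)
```

Calling `split_frames` again with the same seed returns the same partition. Note that the split is per frame, so all camera views of one frame end up on the same side.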
prima/datasets/utils.py ADDED
@@ -0,0 +1,1106 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """
11
+ Parts of the code are taken or adapted from
12
+ https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py
13
+ """
14
+ import torch
15
+ import numpy as np
16
+ from skimage.transform import rotate, resize
17
+ from skimage.filters import gaussian
18
+ import random
19
+ import cv2
20
+ from typing import List, Dict, Tuple
21
+ from yacs.config import CfgNode
22
+ from typing import Union
23
+
24
+
25
+ def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
26
+ """Increase the size of the bounding box to match the target aspect ratio."""
27
+ if target_aspect_ratio is None:
28
+ return input_shape
29
+
30
+ try:
31
+ w, h = input_shape
32
+ except (ValueError, TypeError):
33
+ return input_shape
34
+
35
+ w_t, h_t = target_aspect_ratio
36
+ if h / w < h_t / w_t:
37
+ h_new = max(w * h_t / w_t, h)
38
+ w_new = w
39
+ else:
40
+ h_new = h
41
+ w_new = max(h * w_t / h_t, w)
42
+ if h_new < h or w_new < w:
43
+ raise ValueError(f"Expanded size ({w_new}, {h_new}) smaller than original ({w}, {h})")
44
+ return np.array([w_new, h_new])
45
+
46
+
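Note on `expand_to_aspect_ratio`: the box only ever grows, on whichever side is too short for the target ratio. The sketch below restates that rule in plain Python with two worked cases. The name `expand_box` and the example sizes are assumptions for illustration; the real function takes and returns array-like shapes.

```python
def expand_box(w, h, target_w, target_h):
    """Grow (w, h) to the target aspect ratio without shrinking either side."""
    if h / w < target_h / target_w:
        # Box is too wide for the target: keep the width, grow the height.
        return w, max(w * target_h / target_w, h)
    # Box is too tall (or already matches): keep the height, grow the width.
    return max(h * target_w / target_h, w), h


wide = expand_box(100, 50, 1, 1)      # landscape box, square target
tall = expand_box(50, 100, 192, 256)  # portrait box, 3:4 target
```

The landscape box gains height to become 100 x 100, and the portrait box gains width to become 75 x 100.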
47
+ def do_augmentation(aug_config: CfgNode) -> Tuple:
48
+ """
49
+ Compute random augmentation parameters.
50
+ Args:
51
+ aug_config (CfgNode): Config containing augmentation parameters.
52
+ Returns:
53
+ scale (float): Box rescaling factor.
54
+ rot (float): Random image rotation.
55
+ do_flip (bool): Whether to flip image or not.
56
+ do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT).
57
+ color_scale (List): Color rescaling factor
58
+ tx (float): Random translation along the x axis.
59
+ ty (float): Random translation along the y axis.
60
+ """
61
+
62
+ tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
63
+ ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
64
+ scale = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR + 1.0
65
+ rot = np.clip(np.random.randn(), -2.0,
66
+ 2.0) * aug_config.ROT_FACTOR if random.random() <= aug_config.ROT_AUG_RATE else 0
67
+ do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
68
+ do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
69
+ extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)
70
+ # extreme_crop_lvl = 0
71
+ c_up = 1.0 + aug_config.COLOR_SCALE
72
+ c_low = 1.0 - aug_config.COLOR_SCALE
73
+ color_scale = [random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)]
74
+ return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty
75
+
76
+
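Note on `do_augmentation`: each geometric parameter is a standard normal sample clipped to a fixed range and then multiplied by a config factor, which bounds the augmentation strength. The sketch below reproduces the scale and rotation draws with the standard library's `random.gauss` in place of `np.random.randn`. The factor values are hypothetical stand-ins for `aug_config.SCALE_FACTOR` and `aug_config.ROT_FACTOR`.

```python
import random


def clipped_gauss(rng, limit):
    """Standard normal sample clipped to [-limit, limit]."""
    return max(-limit, min(limit, rng.gauss(0.0, 1.0)))


rng = random.Random(0)
scale_factor = 0.3   # hypothetical value standing in for aug_config.SCALE_FACTOR
rot_factor = 30.0    # hypothetical value standing in for aug_config.ROT_FACTOR

# Scale stays within [1 - scale_factor, 1 + scale_factor].
scales = [clipped_gauss(rng, 1.0) * scale_factor + 1.0 for _ in range(1000)]
# Rotation stays within +/- 2 * rot_factor degrees.
rots = [clipped_gauss(rng, 2.0) * rot_factor for _ in range(1000)]
```

Because of the clipping, a noticeable share of samples lands exactly on the range limits rather than tapering off.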
77
+ def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
78
+ """
79
+ Rotate a 2D point on the x-y plane.
80
+ Args:
81
+ pt_2d (np.array): Input 2D point with shape (2,).
82
+ rot_rad (float): Rotation angle
83
+ Returns:
84
+ np.array: Rotated 2D point.
85
+ """
86
+ x = pt_2d[0]
87
+ y = pt_2d[1]
88
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
89
+ xx = x * cs - y * sn
90
+ yy = x * sn + y * cs
91
+ return np.array([xx, yy], dtype=np.float32)
92
+
93
+
94
+ def gen_trans_from_patch_cv(c_x: float, c_y: float,
95
+ src_width: float, src_height: float,
96
+ dst_width: float, dst_height: float,
97
+ scale: float, rot: float) -> np.array:
98
+ """
99
+ Create transformation matrix for the bounding box crop.
100
+ Args:
101
+ c_x (float): Bounding box center x coordinate in the original image.
102
+ c_y (float): Bounding box center y coordinate in the original image.
103
+ src_width (float): Bounding box width.
104
+ src_height (float): Bounding box height.
105
+ dst_width (float): Output box width.
106
+ dst_height (float): Output box height.
107
+ scale (float): Rescaling factor for the bounding box (augmentation).
108
+ rot (float): Random rotation applied to the box.
109
+ Returns:
110
+ trans (np.array): Target geometric transformation.
111
+ """
112
+ # augment size with scale
113
+ src_w = src_width * scale
114
+ src_h = src_height * scale
115
+ src_center = np.zeros(2)
116
+ src_center[0] = c_x
117
+ src_center[1] = c_y
118
+ # augment rotation
119
+ rot_rad = np.pi * rot / 180
120
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
121
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
122
+
123
+ dst_w = dst_width
124
+ dst_h = dst_height
125
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
126
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
127
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
128
+
129
+ src = np.zeros((3, 2), dtype=np.float32)
130
+ src[0, :] = src_center
131
+ src[1, :] = src_center + src_downdir
132
+ src[2, :] = src_center + src_rightdir
133
+
134
+ dst = np.zeros((3, 2), dtype=np.float32)
135
+ dst[0, :] = dst_center
136
+ dst[1, :] = dst_center + dst_downdir
137
+ dst[2, :] = dst_center + dst_rightdir
138
+
139
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
140
+
141
+ return trans
142
+
143
+
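Note on `gen_trans_from_patch_cv`: the function hands three point pairs (box center, center plus rotated "down" half-extent, center plus rotated "right" half-extent) to `cv2.getAffineTransform`, which returns the unique affine map through three non-collinear pairs. That map can also be written in closed form: undo the box rotation, zoom each axis, then translate the center. The sketch below is a plain-Python restatement under that reading; `crop_affine` and `apply_affine` are hypothetical helper names and the box values are made up.

```python
import math


def crop_affine(c_x, c_y, src_w, src_h, dst_w, dst_h, scale, rot_deg):
    """Closed-form 2x3 matrix for the three point pairs used by the crop transform."""
    c, s = math.cos(math.radians(rot_deg)), math.sin(math.radians(rot_deg))
    sx = dst_w / (src_w * scale)   # horizontal zoom from the scaled box to the patch
    sy = dst_h / (src_h * scale)   # vertical zoom
    # Linear part: undo the box rotation, then zoom each axis.
    a = [[sx * c, sx * s], [-sy * s, sy * c]]
    # Translation chosen so that the box center lands on the patch center.
    tx = dst_w / 2 - (a[0][0] * c_x + a[0][1] * c_y)
    ty = dst_h / 2 - (a[1][0] * c_x + a[1][1] * c_y)
    return [a[0] + [tx], a[1] + [ty]]


def apply_affine(m, x, y):
    """Apply a 2x3 affine matrix to the point (x, y)."""
    return (m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2])


# Hypothetical 100x100 box centered at (320, 240), rotated 30 degrees, cropped to 256x256.
m = crop_affine(320, 240, 100, 100, 256, 256, scale=1.2, rot_deg=30)
center = apply_affine(m, 320, 240)
```

The box center maps to the patch center, and the rotated half-extents map to the bottom-center and right-center of the patch, which are exactly the three constraints set up in the function above.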
144
+ def trans_point2d(pt_2d: np.array, trans: np.array):
145
+ """
146
+ Transform a 2D point using the affine transformation matrix trans.
147
+ Args:
148
+ pt_2d (np.array): Input 2D point with shape (2,).
149
+ trans (np.array): Transformation matrix.
150
+ Returns:
151
+ np.array: Transformed 2D point.
152
+ """
153
+ src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T
154
+ dst_pt = np.dot(trans, src_pt)
155
+ return dst_pt[0:2]
156
+
157
+
158
+ def get_transform(center, scale, res, rot=0):
159
+ """Generate transformation matrix.
160
+ Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
161
+ h = 200 * scale
162
+ t = np.zeros((3, 3))
163
+ t[0, 0] = float(res[1]) / h
164
+ t[1, 1] = float(res[0]) / h
165
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
166
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
167
+ t[2, 2] = 1
168
+ if not rot == 0:
169
+ rot = -rot # To match direction of rotation from cropping
170
+ rot_mat = np.zeros((3, 3))
171
+ rot_rad = rot * np.pi / 180
172
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
173
+ rot_mat[0, :2] = [cs, -sn]
174
+ rot_mat[1, :2] = [sn, cs]
175
+ rot_mat[2, 2] = 1
176
+ # Need to rotate around center
177
+ t_mat = np.eye(3)
178
+ t_mat[0, 2] = -res[1] / 2
179
+ t_mat[1, 2] = -res[0] / 2
180
+ t_inv = t_mat.copy()
181
+ t_inv[:2, 2] *= -1
182
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
183
+ return t
184
+
185
+
186
+ def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
187
+ """Transform pixel location to different reference.
188
+ Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
189
+ t = get_transform(center, scale, res, rot=rot)
190
+ if invert:
191
+ t = np.linalg.inv(t)
192
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
193
+ new_pt = np.dot(t, new_pt)
194
+ if as_int:
195
+ new_pt = new_pt.astype(int)
196
+ return new_pt[:2] + 1
197
+
198
+
199
+ def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0):
200
+ c_x = (ul[0] + br[0]) / 2
201
+ c_y = (ul[1] + br[1]) / 2
202
+ bb_width = patch_width = br[0] - ul[0]
203
+ bb_height = patch_height = br[1] - ul[1]
204
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0)
205
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
206
+ flags=cv2.INTER_LINEAR,
207
+ borderMode=border_mode,
208
+ borderValue=border_value
209
+ )
210
+
211
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
212
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
213
+ img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, (int(patch_width), int(patch_height)),
214
+ flags=cv2.INTER_LINEAR,
215
+ borderMode=cv2.BORDER_CONSTANT,
216
+ )
217
+
218
+ return img_patch
219
+
220
+
221
+ def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float,
222
+ bb_width: float, bb_height: float,
223
+ patch_width: float, patch_height: float,
224
+ do_flip: bool, scale: float, rot: float,
225
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
226
+ """
227
+ Crop image according to the supplied bounding box.
228
+ Args:
229
+ img (np.array): Input image of shape (H, W, 3)
230
+ c_x (float): Bounding box center x coordinate in the original image.
231
+ c_y (float): Bounding box center y coordinate in the original image.
232
+ bb_width (float): Bounding box width.
233
+ bb_height (float): Bounding box height.
234
+ patch_width (float): Output box width.
235
+ patch_height (float): Output box height.
236
+ do_flip (bool): Whether to flip image or not.
237
+ scale (float): Rescaling factor for the bounding box (augmentation).
238
+ rot (float): Random rotation applied to the box.
239
+ Returns:
240
+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_width, 3)
241
+ trans (np.array): Transformation matrix.
242
+ """
243
+
244
+ img_height, img_width, img_channels = img.shape
245
+ if do_flip:
246
+ img = img[:, ::-1, :]
247
+ c_x = img_width - c_x - 1
248
+
249
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
250
+
251
+ # img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), flags=cv2.INTER_LINEAR)
252
+
253
+ # skimage
254
+ center = np.zeros(2)
255
+ center[0] = c_x
256
+ center[1] = c_y
257
+ res = np.zeros(2)
258
+ res[0] = patch_width
259
+ res[1] = patch_height
260
+ # assumes bb_width = bb_height
261
+ # assumes patch_width = patch_height
262
+ assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
263
+ assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'
264
+ scale1 = scale * bb_width / 200.
265
+
266
+ # Upper left point
267
+ ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
268
+ # Bottom right point
269
+ br = np.array(transform([res[0] + 1,
270
+ res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1
271
+
272
+ # Padding so that when rotated proper amount of context is included
273
+ try:
274
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
275
+ except Exception as e:
276
+ raise RuntimeError(f"Failed to compute pad: ul={ul}, br={br}") from e
277
+ if not rot == 0:
278
+ ul -= pad
279
+ br += pad
280
+
281
+ if False:
282
+ # Old way of cropping image
283
+ ul_int = ul.astype(int)
284
+ br_int = br.astype(int)
285
+ new_shape = [br_int[1] - ul_int[1], br_int[0] - ul_int[0]]
286
+ if len(img.shape) > 2:
287
+ new_shape += [img.shape[2]]
288
+ new_img = np.zeros(new_shape)
289
+
290
+ # Range to fill new array
291
+ new_x = max(0, -ul_int[0]), min(br_int[0], len(img[0])) - ul_int[0]
292
+ new_y = max(0, -ul_int[1]), min(br_int[1], len(img)) - ul_int[1]
293
+ # Range to sample from original image
294
+ old_x = max(0, ul_int[0]), min(len(img[0]), br_int[0])
295
+ old_y = max(0, ul_int[1]), min(len(img), br_int[1])
296
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
297
+ old_x[0]:old_x[1]]
298
+
299
+ # New way of cropping image
300
+ new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)
301
+
302
+ # print(f'{new_img.shape=}')
303
+ # print(f'{new_img1.shape=}')
304
+ # print(f'{np.allclose(new_img, new_img1)=}')
305
+ # print(f'{img.dtype=}')
306
+
307
+ if not rot == 0:
308
+ # Remove padding
309
+
310
+ new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
311
+ new_img = new_img[pad:-pad, pad:-pad]
312
+
313
+ if new_img.shape[0] < 1 or new_img.shape[1] < 1:
314
+ raise ValueError(
315
+ f"Image patch too small: {new_img.shape}, original: {img.shape}, "
316
+ f"ul={ul}, br={br}, pad={pad}, rot={rot}"
317
+ )
318
+
319
+ # resize image
320
+ new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
321
+
322
+ new_img = np.clip(new_img, 0, 255).astype(np.uint8)
323
+
324
+ return new_img, trans
325
+
326
+
327
+ def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float,
328
+ bb_width: float, bb_height: float,
329
+ patch_width: float, patch_height: float,
330
+ do_flip: bool, scale: float, rot: float,
331
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array, np.array]:
332
+ """
333
+ Crop the input image and return the crop and the corresponding transformation matrix.
334
+ Args:
335
+ img (np.array): Input image of shape (H, W, 3)
336
+ c_x (float): Bounding box center x coordinate in the original image.
337
+ c_y (float): Bounding box center y coordinate in the original image.
338
+ bb_width (float): Bounding box width.
339
+ bb_height (float): Bounding box height.
340
+ patch_width (float): Output box width.
341
+ patch_height (float): Output box height.
342
+ do_flip (bool): Whether to flip image or not.
343
+ scale (float): Rescaling factor for the bounding box (augmentation).
344
+ rot (float): Random rotation applied to the box.
345
+ Returns:
346
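+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_width, 3)
347
+ trans (np.array): Transformation matrix.
+ img_border_mask (np.array): Boolean mask of shape (patch_height, patch_width); False where all color channels equal the border fill.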
348
+ """
349
+
350
+ img_height, img_width, img_channels = img.shape
351
+ if do_flip:
352
+ img = img[:, ::-1, :]
353
+ c_x = img_width - c_x - 1
354
+
355
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
356
+
357
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
358
+ flags=cv2.INTER_LINEAR,
359
+ borderMode=border_mode,
360
+ borderValue=border_value,
361
+ )
362
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
363
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
364
+ img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, (int(patch_width), int(patch_height)),
365
+ flags=cv2.INTER_LINEAR,
366
+ borderMode=cv2.BORDER_CONSTANT,
367
+ )
368
+
369
+ is_border = np.all(img_patch[:, :, :-1] == border_value, axis=2) if img_patch.shape[2] == 4 else np.all(img_patch == border_value, axis=2)
370
+ img_border_mask = ~is_border
371
+ return img_patch, trans, img_border_mask
372
+
373
+
374
+ def convert_cvimg_to_tensor(cvimg: np.array):
375
+ """
376
+ Convert image from HWC to CHW format.
377
+ Args:
378
+ cvimg (np.array): Image of shape (H, W, 3) as loaded by OpenCV.
379
+ Returns:
380
+ np.array: Output image of shape (3, H, W).
381
+ """
382
+ # from h,w,c(OpenCV) to c,h,w
383
+ img = cvimg.copy()
384
+ img = np.transpose(img, (2, 0, 1))
385
+ # from int to float
386
+ img = img.astype(np.float32)
387
+ return img
388
+
389
+
390
+ def fliplr_params(smal_params: Dict, has_smal_params: Dict) -> Tuple[Dict, Dict]:
391
+ """
392
+ Flip SMAL parameters when flipping the image.
393
+ Args:
394
+ smal_params (Dict): SMAL parameter annotations.
395
+ has_smal_params (Dict): Whether SMAL annotations are valid.
396
+ Returns:
397
+ Dict, Dict: Flipped SMAL parameters and valid flags.
398
+ """
399
+ global_orient = smal_params['global_orient'].copy()
400
+ pose = smal_params['pose'].copy()
401
+ betas = smal_params['betas'].copy()
402
+ transl = smal_params['transl'].copy()
403
+ has_global_orient = has_smal_params['global_orient'].copy()
404
+ has_pose = has_smal_params['pose'].copy()
405
+ has_betas = has_smal_params['betas'].copy()
406
+ has_transl = has_smal_params['transl'].copy()
407
+
408
+ global_orient[1::3] *= -1
409
+ global_orient[2::3] *= -1
410
+ pose[1::3] *= -1
411
+ pose[2::3] *= -1
412
+ transl[1::3] *= -1
413
+ transl[2::3] *= -1
414
+
415
+ smal_params = {'global_orient': global_orient.astype(np.float32),
416
+ 'pose': pose.astype(np.float32),
417
+ 'betas': betas.astype(np.float32),
418
+ 'transl': transl.astype(np.float32)
419
+ }
420
+
421
+ has_smal_params = {'global_orient': has_global_orient,
422
+ 'pose': has_pose,
423
+ 'betas': has_betas,
424
+ 'transl': has_transl
425
+ }
426
+
427
+ return smal_params, has_smal_params
428
+
429
+
430
+ def fliplr_keypoints(joints: np.array, width: float, flip_permutation: List[int]) -> np.array:
431
+ """
432
+ Flip 2D or 3D keypoints.
433
+ Args:
434
+ joints (np.array): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence.
435
+ flip_permutation (List): Permutation to apply after flipping.
436
+ Returns:
437
+ np.array: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively.
438
+ """
439
+ joints = joints.copy()
440
+ # Flip horizontal
441
+ joints[:, 0] = width - joints[:, 0] - 1
442
+ joints = joints[flip_permutation, :]
443
+
444
+ return joints
445
+
446
+
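Note on `fliplr_keypoints`: mirroring x is only half of a horizontal flip; `flip_permutation` is what reassigns left/right keypoint rows afterwards. The sketch below shows the difference on a made-up three-keypoint layout. The helper `flip_keypoints`, the keypoint names, and the coordinates are assumptions for illustration, written with plain lists instead of NumPy arrays.

```python
def flip_keypoints(joints, width, flip_permutation):
    """Mirror x as (width - x - 1), then reorder rows by flip_permutation."""
    mirrored = [[width - x - 1, *rest] for x, *rest in joints]
    return [mirrored[i] for i in flip_permutation]


# Hypothetical layout: 0 = nose, 1 = left ear, 2 = right ear; rows are (x, y, conf).
joints = [[50.0, 10.0, 1.0], [30.0, 5.0, 1.0], [70.0, 5.0, 1.0]]
width = 100

swapped = flip_keypoints(joints, width, [0, 2, 1])   # swap the left/right pair
identity = flip_keypoints(joints, width, [0, 1, 2])  # keep the original row order
```

With the swapping permutation, row 1 ends up at x = 29, so each side label stays on the same side of the mirrored image. With the identity permutation, row 1 moves to x = 69 and the labels trade sides. The calls in this file pass `list(range(len(...)))`, which is the identity permutation; whether that is appropriate depends on how the keypoint set defines left and right.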
447
+ def keypoint_3d_processing(keypoints_3d: np.array, rot: float, flip: bool) -> np.array:
448
+ """
449
+ Process 3D keypoints (rotation/flipping).
450
+ Args:
451
+ keypoints_3d (np.array): Input array of shape (N, 4) containing the 3D keypoints and confidence.
452
+ rot (float): Random rotation applied to the keypoints.
453
+ Returns:
454
+ np.array: Transformed 3D keypoints with shape (N, 4).
455
+ """
456
+ # in-plane rotation
457
+ rot_mat = np.eye(3, dtype=np.float32)
458
+ if not rot == 0:
459
+ rot_rad = -rot * np.pi / 180
460
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
461
+ rot_mat[0, :2] = [cs, -sn]
462
+ rot_mat[1, :2] = [sn, cs]
463
+ keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])
464
+ # flip the x coordinates
465
+ if flip:
466
+ keypoints_3d = fliplr_keypoints(keypoints_3d, 1, list(range(len(keypoints_3d))))  # width=1 turns (width - x - 1) into -x
467
+ keypoints_3d = keypoints_3d.astype('float32')
468
+ return keypoints_3d
469
+
470
+
471
+ def rot_aa(aa: np.array, rot: float) -> np.array:
472
+ """
473
+ Rotate axis angle parameters.
474
+ Args:
475
+ aa (np.array): Axis-angle vector of shape (3,).
476
+ rot (float): Rotation angle in degrees.
477
+ Returns:
478
+ np.array: Rotated axis-angle vector.
479
+ """
480
+ # pose parameters
481
+ R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
482
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
483
+ [0, 0, 1]])
484
+ # convert the axis-angle vector to a rotation matrix in the camera frame
485
+ per_rdg, _ = cv2.Rodrigues(aa)
486
+ # apply the global rotation to the global orientation
487
+ resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
488
+ aa = (resrot.T)[0]
489
+ return aa.astype(np.float32)
490
+
491
+
492
+ def smal_param_processing(smal_params: Dict, has_smal_params: Dict, rot: float, do_flip: bool) -> Tuple[Dict, Dict]:
493
+ """
494
+ Apply random augmentations to the SMAL parameters.
495
+ Args:
496
+ smal_params (Dict): SMAL parameter annotations.
497
+ has_smal_params (Dict): Whether SMAL annotations are valid.
498
+ rot (float): Random rotation applied to the keypoints.
499
+ do_flip (bool): Whether to flip keypoints or not.
500
+ Returns:
501
+ Dict, Dict: Transformed SMAL parameters and valid flags.
502
+ """
503
+ if do_flip:
504
+ smal_params, has_smal_params = fliplr_params(smal_params, has_smal_params)
505
+ smal_params['global_orient'] = rot_aa(smal_params['global_orient'], rot)
506
+ # The camera location does not change, so the translation is left unchanged.
507
+ # smal_params['transl'] = np.dot(np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
508
+ # [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
509
+ # [0, 0, 1]], dtype=np.float32), smal_params['transl'])
510
+ return smal_params, has_smal_params
511
+
512
+
513
+ def get_example(img_path: Union[str,np.ndarray], center_x: float, center_y: float,
514
+ width: float, height: float,
515
+ keypoints_2d: np.array, keypoints_3d: np.array,
516
+ smal_params: Dict, has_smal_params: Dict,
517
+ patch_width: int, patch_height: int,
518
+ mean: np.array, std: np.array,
519
+ do_augment: bool, augm_config: CfgNode,
520
+ is_bgr: bool = True,
521
+ use_skimage_antialias: bool = False,
522
+ border_mode: int = cv2.BORDER_CONSTANT,
523
+ return_trans: bool = False,) -> Tuple:
524
+ """
525
+ Get an example from the dataset and (possibly) apply random augmentations.
526
+ Args:
527
+ img_path (str or np.ndarray): Image filename, or an already-loaded image array.
528
+ center_x (float): Bounding box center x coordinate in the original image.
529
+ center_y (float): Bounding box center y coordinate in the original image.
530
+ width (float): Bounding box width.
531
+ height (float): Bounding box height.
532
+ keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
533
+ keypoints_3d (np.array): Array with shape (N,4) containing the 3D keypoints.
534
+ smal_params (Dict): SMAL parameter annotations.
535
+ has_smal_params (Dict): Whether SMAL annotations are valid.
536
+ patch_width (float): Output box width.
537
+ patch_height (float): Output box height.
538
+ mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
539
+ std (np.array): Array of shape (3,) containing the std for normalizing the input image.
540
+ do_augment (bool): Whether to apply data augmentation or not.
541
+ augm_config (CfgNode): Config containing augmentation parameters.
542
+ Returns:
543
+ return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size
544
+ img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_width)
545
+ keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
546
+ keypoints_3d (np.array): Array with shape (N,4) containing the transformed 3D keypoints.
547
+ smal_params (Dict): Transformed SMAL parameters.
548
+ has_smal_params (Dict): Valid flag for transformed SMAL parameters.
549
+ img_size (np.array): Image size of the original image.
550
+ """
551
+ if isinstance(img_path, str):
552
+ # 1. load image
553
+ cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
554
+ if not isinstance(cvimg, np.ndarray):
555
+ raise IOError("Failed to read %s" % img_path)
556
+ elif isinstance(img_path, np.ndarray):
557
+ cvimg = img_path
558
+ else:
559
+ raise TypeError('img_path must be either a string or a numpy array')
560
+ img_height, img_width, img_channels = cvimg.shape
561
+
562
+ img_size = np.array([img_height, img_width], dtype=np.int32)
563
+
564
+ # 2. get augmentation params
565
+ if do_augment:
566
+ # box rescale factor, rotation angle, flip or not flip, crop or not crop, ..., color scale, translation x, ...
567
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
568
+ else:
569
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0,
570
+ 1.0,
571
+ 1.0], 0., 0.
572
+ if width < 1 or height < 1:
573
+ # Clamp degenerate boxes so that width and height are at least 1 pixel
574
+ print(f"Warning: Invalid bbox size - width: {width}, height: {height}. Clamping both to a minimum of 1.")
575
+ width = max(width, 1.0)
576
+ height = max(height, 1.0)
577
+
578
+ if do_extreme_crop:
579
+ if extreme_crop_lvl == 0:
580
+ center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d)
581
+ elif extreme_crop_lvl == 1:
582
+ center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height,
583
+ keypoints_2d)
584
+
585
+ THRESH = 4
586
+ if width1 < THRESH or height1 < THRESH:
587
+ pass
588
+ else:
589
+ center_x, center_y, width, height = center_x1, center_y1, width1, height1
590
+
591
+ center_x += width * tx
592
+ center_y += height * ty
593
+
594
+ # Process 3D keypoints
595
+ keypoints_3d = keypoint_3d_processing(keypoints_3d, rot, do_flip)
596
+
597
+ # 3. generate image patch
598
+ if use_skimage_antialias:
599
+ # Blur image to avoid aliasing artifacts
600
+ downsampling_factor = (patch_width / (width * scale))
601
+ if downsampling_factor > 1.1:
602
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor - 1) / 2, channel_axis=2, preserve_range=True,
603
+ truncate=3.0)
604
+ # augmentation image, translation matrix
605
+ img_patch_cv, trans, img_border_mask = generate_image_patch_cv2(cvimg,
606
+ center_x, center_y,
607
+ width, height,
608
+ patch_width, patch_height,
609
+ do_flip, scale, rot,
610
+ border_mode=border_mode)
611
+
612
+ image = img_patch_cv.copy()
613
+ if is_bgr:
614
+ image = image[:, :, ::-1]
615
+ img_patch_cv = image.copy()
616
+ img_patch = convert_cvimg_to_tensor(image) # [h, w, 4] -> [4, h, w]
617
+
618
+ smal_params, has_smal_params = smal_param_processing(smal_params, has_smal_params, rot, do_flip)
619
+
620
+ # apply normalization
621
+ for n_c in range(min(img_channels, 3)):
622
+ img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
623
+ if mean is not None and std is not None:
624
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]
625
+
626
+ if do_flip:
627
+ keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, list(range(len(keypoints_2d))))
628
+
629
+ for n_jt in range(len(keypoints_2d)):
630
+ keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
631
+ keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5
632
+
633
+ if not return_trans:
634
+ return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, img_border_mask
635
+ else:
636
+ return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, trans, img_border_mask
637
+
638
+
639
+ def get_cub17_example(cvimg: np.array,
640
+ keypoints_2d: np.array,
641
+ center_x: float, center_y: float,
642
+ width: float, height: float,
643
+ patch_width: int, patch_height: int,
644
+ mean: np.array, std: np.array,
645
+ do_augment: bool, augm_config: CfgNode,
646
+ return_trans=True) -> Tuple:
647
+ """
648
+ Get an example from the dataset and (possibly) apply random augmentations.
649
+ Args:
650
+ cvimg (np.ndarray): Image
651
+ keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
652
+ center_x (float): Bounding box center x coordinate in the original image.
653
+ center_y (float): Bounding box center y coordinate in the original image.
654
+ width (float): Bounding box width.
655
+ height (float): Bounding box height.
656
+ patch_width (int): Output box width.
657
+ patch_height (int): Output box height.
658
+ mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
659
+ std (np.array): Array of shape (3,) containing the std for normalizing the input image.
660
+ do_augment (bool): Whether to apply data augmentation or not.
661
+ augm_config (CfgNode): Config containing augmentation parameters.
662
+ Returns:
663
+ Tuple (img_patch, keypoints_2d, img_size, trans, img_border_mask); trans is omitted when return_trans is False.
664
+ img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_width)
665
+ keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
666
+ """
667
+ img_height, img_width, img_channels = cvimg.shape
668
+
669
+ img_size = np.array([img_height, img_width], dtype=np.int32)
670
+
671
+ # 2. get augmentation params
672
+ if do_augment:
673
+ # box rescale factor, rotation angle, flip or not flip, crop or not crop, ..., color scale, translation x, ...
674
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
675
+ else:
676
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0,
677
+ 1.0,
678
+ 1.0], 0., 0.
679
+ # shift the bounding box center by the sampled translation (fractions of box width and height)
680
+ center_x += width * tx
681
+ center_y += height * ty
682
+ # augmentation image, translation matrix
683
+ img_patch_cv, trans, img_border_mask = generate_image_patch_cv2(cvimg,
684
+ center_x, center_y,
685
+ width, height,
686
+ patch_width, patch_height,
687
+ do_flip, scale, rot,
688
+ border_mode=cv2.BORDER_CONSTANT)
689
+
690
+ image = img_patch_cv.copy()
691
+ img_patch = convert_cvimg_to_tensor(image) # [h, w, 4] -> [4, h, w]
692
+
693
+ # apply normalization
694
+ for n_c in range(min(img_channels, 3)):
695
+ img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
696
+ if mean is not None and std is not None:
697
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]
698
+
699
+ if do_flip:
700
+ keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, list(range(len(keypoints_2d))))
701
+
702
+ for n_jt in range(len(keypoints_2d)):
703
+ keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
704
+ keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5
705
+
706
+ if return_trans:
707
+ return img_patch, keypoints_2d, img_size, trans, img_border_mask
708
+ else:
709
+ return img_patch, keypoints_2d, img_size, img_border_mask
710
+
711
+
712
+ def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
713
+ """
714
+ Extreme cropping: Crop the box up to the hip locations.
715
+ Args:
716
+ center_x (float): x coordinate of the bounding box center.
717
+ center_y (float): y coordinate of the bounding box center.
718
+ width (float): Bounding box width.
719
+ height (float): Bounding box height.
720
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
721
+ Returns:
722
+ center_x (float): x coordinate of the new bounding box center.
723
+ center_y (float): y coordinate of the new bounding box center.
724
+ width (float): New bounding box width.
725
+ height (float): New bounding box height.
726
+ """
727
+ keypoints_2d = keypoints_2d.copy()
728
+ lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25 + 0, 25 + 1, 25 + 4, 25 + 5]
729
+ keypoints_2d[lower_body_keypoints, :] = 0
730
+ if keypoints_2d[:, -1].sum() > 1:
731
+ center, scale = get_bbox(keypoints_2d)
732
+ center_x = center[0]
733
+ center_y = center[1]
734
+ width = 1.1 * scale[0]
735
+ height = 1.1 * scale[1]
736
+ return center_x, center_y, width, height
737
+
738
+
739
+ def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
740
+ """
741
+ Extreme cropping: Crop the box up to the shoulder locations.
742
+ Args:
743
+ center_x (float): x coordinate of the bounding box center.
744
+ center_y (float): y coordinate of the bounding box center.
745
+ width (float): Bounding box width.
746
+ height (float): Bounding box height.
747
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
748
+ Returns:
749
+ center_x (float): x coordinate of the new bounding box center.
750
+ center_y (float): y coordinate of the new bounding box center.
751
+ width (float): New bounding box width.
752
+ height (float): New bounding box height.
753
+ """
754
+ keypoints_2d = keypoints_2d.copy()
755
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in
756
+ [0, 1, 2, 3, 4, 5, 6, 7,
757
+ 10, 11, 14, 15, 16]]
758
+ keypoints_2d[lower_body_keypoints, :] = 0
759
+ # Recompute the box only if the remaining keypoints pass the confidence check below (get_bbox fails when none are visible)
760
+ if keypoints_2d[:, -1].sum() > 1:
761
+ center, scale = get_bbox(keypoints_2d)
762
+ center_x = center[0]
763
+ center_y = center[1]
764
+ width = 1.2 * scale[0]
765
+ height = 1.2 * scale[1]
766
+ return center_x, center_y, width, height
767
+
768
+
769
+ def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
770
+ """
771
+ Extreme cropping: Crop the box and keep only the head.
772
+ Args:
773
+ center_x (float): x coordinate of the bounding box center.
774
+ center_y (float): y coordinate of the bounding box center.
775
+ width (float): Bounding box width.
776
+ height (float): Bounding box height.
777
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
778
+ Returns:
779
+ center_x (float): x coordinate of the new bounding box center.
780
+ center_y (float): y coordinate of the new bounding box center.
781
+ width (float): New bounding box width.
782
+ height (float): New bounding box height.
783
+ """
784
+ keypoints_2d = keypoints_2d.copy()
785
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in
786
+ [0, 1, 2, 3, 4, 5, 6, 7, 8,
787
+ 9, 10, 11, 14, 15, 16]]
788
+ keypoints_2d[lower_body_keypoints, :] = 0
789
+ if keypoints_2d[:, -1].sum() > 1:
790
+ center, scale = get_bbox(keypoints_2d)
791
+ center_x = center[0]
792
+ center_y = center[1]
793
+ width = 1.3 * scale[0]
794
+ height = 1.3 * scale[1]
795
+ return center_x, center_y, width, height
796
+
797
+
798
+ def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
799
+ """
800
+ Extreme cropping: Crop the box and keep only the torso.
801
+ Args:
802
+ center_x (float): x coordinate of the bounding box center.
803
+ center_y (float): y coordinate of the bounding box center.
804
+ width (float): Bounding box width.
805
+ height (float): Bounding box height.
806
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
807
+ Returns:
808
+ center_x (float): x coordinate of the new bounding box center.
809
+ center_y (float): y coordinate of the new bounding box center.
810
+ width (float): New bounding box width.
811
+ height (float): New bounding box height.
812
+ """
813
+ keypoints_2d = keypoints_2d.copy()
814
+ nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in
815
+ [0, 1, 4, 5, 6,
816
+ 7, 10, 11, 13,
817
+ 17, 18]]
818
+ keypoints_2d[nontorso_body_keypoints, :] = 0
819
+ if keypoints_2d[:, -1].sum() > 1:
820
+ center, scale = get_bbox(keypoints_2d)
821
+ center_x = center[0]
822
+ center_y = center[1]
823
+ width = 1.1 * scale[0]
824
+ height = 1.1 * scale[1]
825
+ return center_x, center_y, width, height
826
+
827
+
828
+ def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
829
+ """
830
+ Extreme cropping: Crop the box and keep only the right arm.
831
+ Args:
832
+ center_x (float): x coordinate of the bounding box center.
833
+ center_y (float): y coordinate of the bounding box center.
834
+ width (float): Bounding box width.
835
+ height (float): Bounding box height.
836
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
837
+ Returns:
838
+ center_x (float): x coordinate of the new bounding box center.
839
+ center_y (float): y coordinate of the new bounding box center.
840
+ width (float): New bounding box width.
841
+ height (float): New bounding box height.
842
+ """
843
+ keypoints_2d = keypoints_2d.copy()
844
+ nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [
845
+ 25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
846
+ keypoints_2d[nonrightarm_body_keypoints, :] = 0
847
+ if keypoints_2d[:, -1].sum() > 1:
848
+ center, scale = get_bbox(keypoints_2d)
849
+ center_x = center[0]
850
+ center_y = center[1]
851
+ width = 1.1 * scale[0]
852
+ height = 1.1 * scale[1]
853
+ return center_x, center_y, width, height
854
+
855
+
856
+ def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
857
+ """
858
+ Extreme cropping: Crop the box and keep only the left arm.
859
+ Args:
860
+ center_x (float): x coordinate of the bounding box center.
861
+ center_y (float): y coordinate of the bounding box center.
862
+ width (float): Bounding box width.
863
+ height (float): Bounding box height.
864
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
865
+ Returns:
866
+ center_x (float): x coordinate of the new bounding box center.
867
+ center_y (float): y coordinate of the new bounding box center.
868
+ width (float): New bounding box width.
869
+ height (float): New bounding box height.
870
+ """
871
+ keypoints_2d = keypoints_2d.copy()
872
+ nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [
873
+ 25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]]
874
+ keypoints_2d[nonleftarm_body_keypoints, :] = 0
875
+ if keypoints_2d[:, -1].sum() > 1:
876
+ center, scale = get_bbox(keypoints_2d)
877
+ center_x = center[0]
878
+ center_y = center[1]
879
+ width = 1.1 * scale[0]
880
+ height = 1.1 * scale[1]
881
+ return center_x, center_y, width, height
882
+
883
+
884
+ def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
885
+ """
886
+ Extreme cropping: Crop the box and keep only the legs.
887
+ Args:
888
+ center_x (float): x coordinate of the bounding box center.
889
+ center_y (float): y coordinate of the bounding box center.
890
+ width (float): Bounding box width.
891
+ height (float): Bounding box height.
892
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
893
+ Returns:
894
+ center_x (float): x coordinate of the new bounding box center.
895
+ center_y (float): y coordinate of the new bounding box center.
896
+ width (float): New bounding box width.
897
+ height (float): New bounding box height.
898
+ """
899
+ keypoints_2d = keypoints_2d.copy()
900
+ nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in
901
+ [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]]
902
+ keypoints_2d[nonlegs_body_keypoints, :] = 0
903
+ if keypoints_2d[:, -1].sum() > 1:
904
+ center, scale = get_bbox(keypoints_2d)
905
+ center_x = center[0]
906
+ center_y = center[1]
907
+ width = 1.1 * scale[0]
908
+ height = 1.1 * scale[1]
909
+ return center_x, center_y, width, height
910
+
911
+
912
+ def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
913
+ """
914
+ Extreme cropping: Crop the box and keep only the right leg.
915
+ Args:
916
+ center_x (float): x coordinate of the bounding box center.
917
+ center_y (float): y coordinate of the bounding box center.
918
+ width (float): Bounding box width.
919
+ height (float): Bounding box height.
920
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
921
+ Returns:
922
+ center_x (float): x coordinate of the new bounding box center.
923
+ center_y (float): y coordinate of the new bounding box center.
924
+ width (float): New bounding box width.
925
+ height (float): New bounding box height.
926
+ """
927
+ keypoints_2d = keypoints_2d.copy()
928
+ nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in
929
+ [3, 4, 5, 6, 7,
930
+ 8, 9, 10, 11,
931
+ 12, 13, 14, 15,
932
+ 16, 17, 18]]
933
+ keypoints_2d[nonrightleg_body_keypoints, :] = 0
934
+ if keypoints_2d[:, -1].sum() > 1:
935
+ center, scale = get_bbox(keypoints_2d)
936
+ center_x = center[0]
937
+ center_y = center[1]
938
+ width = 1.1 * scale[0]
939
+ height = 1.1 * scale[1]
940
+ return center_x, center_y, width, height
941
+
942
+
943
+ def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
944
+ """
945
+ Extreme cropping: Crop the box and keep only the left leg.
946
+ Args:
947
+ center_x (float): x coordinate of the bounding box center.
948
+ center_y (float): y coordinate of the bounding box center.
949
+ width (float): Bounding box width.
950
+ height (float): Bounding box height.
951
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
952
+ Returns:
953
+ center_x (float): x coordinate of the new bounding box center.
954
+ center_y (float): y coordinate of the new bounding box center.
955
+ width (float): New bounding box width.
956
+ height (float): New bounding box height.
957
+ """
958
+ keypoints_2d = keypoints_2d.copy()
959
+ nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in
960
+ [0, 1, 2, 6, 7, 8,
961
+ 9, 10, 11, 12,
962
+ 13, 14, 15, 16,
963
+ 17, 18]]
964
+ keypoints_2d[nonleftleg_body_keypoints, :] = 0
965
+ if keypoints_2d[:, -1].sum() > 1:
966
+ center, scale = get_bbox(keypoints_2d)
967
+ center_x = center[0]
968
+ center_y = center[1]
969
+ width = 1.1 * scale[0]
970
+ height = 1.1 * scale[1]
971
+ return center_x, center_y, width, height
972
+
973
+
974
+ def full_body(keypoints_2d: np.array) -> bool:
975
+ """
976
+ Check if all main body joints are visible.
977
+ Args:
978
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
979
+ Returns:
980
+ bool: True if all main body joints are visible.
981
+ """
982
+
983
+ body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
984
+ body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]]
985
+ return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(
986
+ body_keypoints)
987
+
988
+
989
+ def upper_body(keypoints_2d: np.array):
990
+ """
991
+ Check whether only the upper body is visible.
992
+ Args:
993
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
994
+ Returns:
995
+ bool: True if no lower-body joint is visible and at least two upper-body joints are.
996
+ """
997
+ lower_body_keypoints_openpose = [10, 11, 13, 14]
998
+ lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]]
999
+ upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18]
1000
+ upper_body_keypoints = [25 + 8, 25 + 9, 25 + 12, 25 + 13, 25 + 17, 25 + 18]
1001
+ return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0) \
1002
+ and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2)
1003
+
1004
+
1005
+ def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple:
1006
+ """
1007
+ Get center and scale for bounding box from openpose detections.
1008
+ Args:
1009
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
1010
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
1011
+ Returns:
1012
+ center (np.array): Array of shape (2,) containing the new bounding box center.
1013
+ scale (np.array): Array of shape (2,) containing the rescaled bounding box width and height.
1014
+ """
1015
+ valid = keypoints_2d[:, -1] > 0
1016
+ valid_keypoints = keypoints_2d[valid][:, :-1]
1017
+ center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0))
1018
+ bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0))
1019
+ # adjust bounding box tightness
1020
+ scale = bbox_size
1021
+ scale *= rescale
1022
+ return center, scale
1023
+
1024
+
1025
+ def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
1026
+ """
1027
+ Perform extreme cropping
1028
+ Args:
1029
+ center_x (float): x coordinate of bounding box center.
1030
+ center_y (float): y coordinate of bounding box center.
1031
+ width (float): bounding box width.
1032
+ height (float): bounding box height.
1033
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
1035
+ Returns:
1036
+ center_x (float): x coordinate of bounding box center.
1037
+ center_y (float): y coordinate of bounding box center.
1038
+ width (float): bounding box width.
1039
+ height (float): bounding box height.
1040
+ """
1041
+ p = torch.rand(1).item()
1042
+ if full_body(keypoints_2d):
1043
+ if p < 0.7:
1044
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
1045
+ elif p < 0.9:
1046
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
1047
+ else:
1048
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
1049
+ elif upper_body(keypoints_2d):
1050
+ if p < 0.9:
1051
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
1052
+ else:
1053
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
1054
+
1055
+ return center_x, center_y, max(width, height), max(width, height)
1056
+
1057
+
1058
+ def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float,
1059
+ keypoints_2d: np.array) -> Tuple:
1060
+ """
1061
+ Perform aggressive extreme cropping
1062
+ Args:
1063
+ center_x (float): x coordinate of bounding box center.
1064
+ center_y (float): y coordinate of bounding box center.
1065
+ width (float): bounding box width.
1066
+ height (float): bounding box height.
1067
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
1069
+ Returns:
1070
+ center_x (float): x coordinate of bounding box center.
1071
+ center_y (float): y coordinate of bounding box center.
1072
+ width (float): bounding box width.
1073
+ height (float): bounding box height.
1074
+ """
1075
+ p = torch.rand(1).item()
1076
+ if full_body(keypoints_2d):
1077
+ if p < 0.2:
1078
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
1079
+ elif p < 0.3:
1080
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
1081
+ elif p < 0.4:
1082
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
1083
+ elif p < 0.5:
1084
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
1085
+ elif p < 0.6:
1086
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
1087
+ elif p < 0.7:
1088
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
1089
+ elif p < 0.8:
1090
+ center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d)
1091
+ elif p < 0.9:
1092
+ center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d)
1093
+ else:
1094
+ center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d)
1095
+ elif upper_body(keypoints_2d):
1096
+ if p < 0.2:
1097
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
1098
+ elif p < 0.4:
1099
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
1100
+ elif p < 0.6:
1101
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
1102
+ elif p < 0.8:
1103
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
1104
+ else:
1105
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
1106
+ return center_x, center_y, max(width, height), max(width, height)
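Review note on `get_bbox` and the crop helpers above: the box is built only from keypoints whose confidence is above zero, which is why each `crop_*` helper checks the remaining confidence before calling it. The following is a minimal, dependency-free sketch of that computation, not the repository's implementation; the name `bbox_from_keypoints` and the explicit `ValueError` are assumptions made for illustration.

```python
from typing import List, Sequence, Tuple


def bbox_from_keypoints(keypoints_2d: Sequence[Sequence[float]],
                        rescale: float = 1.2) -> Tuple[List[float], List[float]]:
    """Return (center, size) of the box around keypoints with confidence > 0.

    Hypothetical pure-Python counterpart of get_bbox; each keypoint is (x, y, conf).
    """
    visible = [(x, y) for x, y, conf in keypoints_2d if conf > 0]
    if not visible:
        # The NumPy version would fail on an empty selection; fail loudly here too.
        raise ValueError("no visible keypoints to build a box from")
    xs = [p[0] for p in visible]
    ys = [p[1] for p in visible]
    center = [0.5 * (max(xs) + min(xs)), 0.5 * (max(ys) + min(ys))]
    # Loosen the tight keypoint box by the rescale factor.
    size = [rescale * (max(xs) - min(xs)), rescale * (max(ys) - min(ys))]
    return center, size


keypoints = [
    [10.0, 20.0, 1.0],
    [30.0, 60.0, 0.8],
    [500.0, 500.0, 0.0],  # confidence 0: ignored when building the box
]
center, size = bbox_from_keypoints(keypoints)
print(center, size)
```

The zero-confidence keypoint at (500, 500) does not stretch the box, which is the same mechanism the `crop_*` helpers rely on when they zero out the joints they want to exclude.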
prima/datasets/vitdet_dataset.py ADDED
@@ -0,0 +1,100 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from typing import Dict
11
+
12
+ import cv2
13
+ import numpy as np
14
+ from skimage.filters import gaussian
15
+ from yacs.config import CfgNode
16
+ import torch
17
+
18
+ from .utils import (convert_cvimg_to_tensor,
19
+ expand_to_aspect_ratio,
20
+ generate_image_patch_cv2)
21
+
22
+ DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
23
+ DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
24
+
25
+
26
+ class ViTDetDataset(torch.utils.data.Dataset):
27
+
28
+ def __init__(self,
29
+ cfg: CfgNode,
30
+ img_cv2: np.array,
31
+ boxes: np.array,
32
+ rescale_factor=1,
33
+ train: bool = False,
34
+ **kwargs):
35
+ super().__init__()
36
+ self.cfg = cfg
37
+ self.img_cv2 = img_cv2
38
+ self.boxes = boxes
39
+
40
+ assert train is False, "ViTDetDataset is only for inference"
41
+ self.train = train
42
+ self.img_size = cfg.MODEL.IMAGE_SIZE
43
+ self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
44
+ self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)
45
+
46
+ # Preprocess annotations
47
+ boxes = boxes.astype(np.float32)
48
+ self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0
49
+ self.scale = rescale_factor * (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
50
+ self.animalid = np.arange(len(boxes), dtype=np.int32)
51
+
52
+ def __len__(self) -> int:
53
+ return len(self.animalid)
54
+
55
+ def __getitem__(self, idx: int) -> Dict[str, np.array]:
56
+
57
+ center = self.center[idx].copy()
58
+ center_x = center[0]
59
+ center_y = center[1]
60
+
61
+ scale = self.scale[idx]
62
+ BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
63
+ bbox_size = expand_to_aspect_ratio(scale * 200, target_aspect_ratio=BBOX_SHAPE).max()
64
+
65
+ patch_width = patch_height = self.img_size
66
+
67
+ flip = False
68
+
69
+ # 3. generate image patch
70
+ # if use_skimage_antialias:
71
+ cvimg = self.img_cv2.copy()
72
+ if True:
73
+ # Blur image to avoid aliasing artifacts
74
+ downsampling_factor = ((bbox_size * 1.0) / patch_width)
75
+ downsampling_factor = downsampling_factor / 2.0
76
+ if downsampling_factor > 1.1:
77
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor - 1) / 2, channel_axis=2, preserve_range=True)
78
+
79
+ img_patch_cv, trans, _ = generate_image_patch_cv2(cvimg,
80
+ center_x, center_y,
81
+ bbox_size, bbox_size,
82
+ patch_width, patch_height,
83
+ flip, 1.0, 0.0,
84
+ border_mode=cv2.BORDER_CONSTANT)
85
+ img_patch_cv = img_patch_cv[:, :, ::-1]
86
+ img_patch = convert_cvimg_to_tensor(img_patch_cv)
87
+
88
+ # apply normalization
89
+ for n_c in range(min(self.img_cv2.shape[2], 3)):
90
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c]
91
+
92
+ item = {
93
+ 'img': img_patch,
94
+ 'animalid': int(self.animalid[idx]),
95
+ 'box_center': self.center[idx].copy(),
96
+ 'box_size': bbox_size,
97
+ 'img_size': 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]]),
98
+ 'focal_length': np.array([self.cfg.EXTRA.FOCAL_LENGTH, self.cfg.EXTRA.FOCAL_LENGTH]),
99
+ }
100
+ return item
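Review note on `ViTDetDataset.__getitem__` above: a detector box `(x1, y1, x2, y2)` becomes a crop center, a crop side length, and an optional Gaussian blur applied before downsampling. The sketch below traces that arithmetic under one stated assumption: the crop is taken as a square on the longer box side, which ignores whatever `expand_to_aspect_ratio` does with `BBOX_SHAPE` (its body is not shown in this part of the diff). The function name `box_to_crop_params` is hypothetical.

```python
from typing import Optional, Sequence, Tuple


def box_to_crop_params(box: Sequence[float],
                       patch_size: int,
                       rescale_factor: float = 1.0
                       ) -> Tuple[Tuple[float, float], float, Optional[float]]:
    """Return (center, side, sigma) for one detection box.

    Assumption: square crop on the longer side; no aspect-ratio expansion.
    """
    x1, y1, x2, y2 = box
    center = ((x1 + x2) / 2.0, (y1 + y2) / 2.0)
    # The dataset stores scale = size / 200 and multiplies by 200 again later,
    # so the net effect is the (rescaled) box size itself.
    side = rescale_factor * max(x2 - x1, y2 - y1)
    # Same heuristic as the dataset: halve the raw ratio, blur only above 1.1.
    downsampling_factor = (side / patch_size) / 2.0
    sigma = (downsampling_factor - 1) / 2 if downsampling_factor > 1.1 else None
    return center, side, sigma


print(box_to_crop_params((100.0, 50.0, 1124.0, 818.0), patch_size=256))
print(box_to_crop_params((0.0, 0.0, 128.0, 128.0), patch_size=256))
```

A box much larger than the patch gets a blur sigma, while a box smaller than the patch returns `None` and is resampled without blurring.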
prima/models/__init__.py ADDED
@@ -0,0 +1,54 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from .prima import PRIMA
11
+
12
+
13
+ def load_prima(checkpoint_path):
14
+ from pathlib import Path
15
+ from ..configs import get_config
16
+ model_cfg = str(Path(checkpoint_path).parent.parent / '.hydra/config.yaml')
17
+ model_cfg = get_config(model_cfg)
18
+
19
+ # Override some config values, to crop bbox correctly
20
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'vit') and ('BBOX_SHAPE' not in model_cfg.MODEL):
21
+ model_cfg.defrost()
22
+ assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
23
+ model_cfg.MODEL.BBOX_SHAPE = [192, 256]
24
+ model_cfg.freeze()
25
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'dinov3') and ('BBOX_SHAPE' not in model_cfg.MODEL):
26
+ model_cfg.defrost()
27
+ assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for DINOv3 backbone"
28
+ model_cfg.MODEL.BBOX_SHAPE = [256, 256]
29
+ model_cfg.freeze()
30
+
31
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'dinov2') and ('BBOX_SHAPE' not in model_cfg.MODEL):
32
+ model_cfg.defrost()
33
+ assert model_cfg.MODEL.IMAGE_SIZE == 252, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 252 for DINOv2 backbone"
34
+ model_cfg.MODEL.BBOX_SHAPE = [252, 252]
35
+ model_cfg.freeze()
36
+
37
+
38
+
39
+ # Update config to be compatible with demo
40
+ if ('PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE):
41
+ model_cfg.defrost()
42
+ model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS')
43
+ model_cfg.freeze()
44
+
45
+ # Offscreen training renderer is not needed for demo/inference startup and
46
+ # can fail on some local OpenGL backends.
47
+ model = PRIMA.load_from_checkpoint(
48
+ checkpoint_path,
49
+ strict=False,
50
+ cfg=model_cfg,
51
+ map_location='cpu',
52
+ init_renderer=False,
53
+ )
54
+ return model, model_cfg
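Review note on `load_prima` above: the config file is located relative to the checkpoint (two directories up, then `.hydra/config.yaml`), and a default `BBOX_SHAPE` is filled in per backbone type when the config lacks one. The sketch below isolates those two steps using only the standard library. The helper names and the example path are hypothetical, and `PurePosixPath` is used instead of `Path` only so the printed result does not depend on the platform.

```python
from pathlib import PurePosixPath
from typing import List, Optional

# Defaults copied from the three branches in load_prima.
DEFAULT_BBOX_SHAPES = {
    "vit": [192, 256],
    "dinov3": [256, 256],
    "dinov2": [252, 252],
}


def hydra_config_path(checkpoint_path: str) -> str:
    """Return the config path load_prima would derive for a checkpoint path."""
    return str(PurePosixPath(checkpoint_path).parent.parent / ".hydra/config.yaml")


def default_bbox_shape(backbone_type: str) -> Optional[List[int]]:
    """Return the fallback BBOX_SHAPE for a backbone type, or None if there is none."""
    return DEFAULT_BBOX_SHAPES.get(backbone_type)


# Hypothetical run directory layout: <run>/checkpoints/<file>.ckpt
print(hydra_config_path("logs/run1/checkpoints/last.ckpt"))
print(default_bbox_shape("dinov2"))
```

A checkpoint that has been moved out of its run directory will therefore not find its config, since the path is derived from the checkpoint location and not passed in separately.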
prima/models/backbones/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from .vit import vith
11
+
12
+
13
+
14
+
15
+ def create_backbone(cfg):
16
+ if cfg.MODEL.BACKBONE.TYPE in ['vith', 'concat', 'aa']:  # all three types use the ViT-H backbone as the animal feature extractor
17
+ return vith(cfg)
18
+ else:
19
+ raise NotImplementedError('Backbone type is not implemented')
prima/models/backbones/vit.py ADDED
@@ -0,0 +1,375 @@
+"""
+PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+Official implementation of the paper:
+"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+Licensed under a modified MIT license
+"""
+
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+from functools import partial
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.layers import drop_path, to_2tuple, trunc_normal_
+
+
+def vith(cfg):
+    return ViT(
+        img_size=(256, 192),
+        patch_size=16,
+        embed_dim=1280,
+        depth=32,
+        num_heads=16,
+        ratio=1,
+        use_checkpoint=False,
+        # use_checkpoint=True,
+        mlp_ratio=4,
+        qkv_bias=True,
+        drop_path_rate=0.55,
+        use_cls=True,  # cls for animal family classification
+    )
+
+
+def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
+    """
+    Calculate absolute positional embeddings. If needed, resize the patch embeddings; the
+    cls_token embedding is split off before resizing and re-attached afterwards.
+    Args:
+        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
+        h, w (int): target token grid size; ori_h, ori_w (int): token grid size of abs_pos.
+
+    Returns:
+        Absolute positional embeddings with shape (1, h * w (+ 1 if cls token), C)
+    """
+    cls_token = None
+    B, L, C = abs_pos.shape
+    if has_cls_token:
+        cls_token = abs_pos[:, 0:1]
+        abs_pos = abs_pos[:, 1:]
+
+    if ori_h != h or ori_w != w:
+        new_abs_pos = F.interpolate(
+            abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
+            size=(h, w),
+            mode="bicubic",
+            align_corners=False,
+        ).permute(0, 2, 3, 1).reshape(B, -1, C)
+
+    else:
+        new_abs_pos = abs_pos
+
+    if cls_token is not None:
+        new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
+    return new_abs_pos
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self):
+        return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+            proj_drop=0., attn_head_dim=None):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.dim = dim
+
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        all_head_dim = head_dim * self.num_heads
+
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
+                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm, attn_head_dim=None,
+                 ):
+        super().__init__()
+
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
+        )
+
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
+        self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
+        self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio),
+                              padding=4 + 2 * (ratio // 2 - 1))
+
+    def forward(self, x, **kwargs):
+        B, C, H, W = x.shape
+        x = self.proj(x)
+        Hp, Wp = x.shape[2], x.shape[3]
+
+        x = x.flatten(2).transpose(1, 2)
+        return x, (Hp, Wp)
+
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+
+    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        img_size = to_2tuple(img_size)
+        self.img_size = img_size
+        self.backbone = backbone
+        if feature_size is None:
+            with torch.no_grad():
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = to_2tuple(feature_size)
+            feature_dim = self.backbone.feature_info.channels()[-1]
+        self.num_patches = feature_size[0] * feature_size[1]
+        self.proj = nn.Linear(feature_dim, embed_dim)
+
+    def forward(self, x):
+        x = self.backbone(x)[-1]
+        x = x.flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        return x
+
+
+class ViT(nn.Module):
+
+    def __init__(self,
+                 img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
+                 frozen_stages=-1, ratio=1, last_norm=True, use_cls=False,
+                 patch_padding='pad', freeze_attn=False, freeze_ffn=False,
+                 ):
+        # Protect mutable default arguments
+        super(ViT, self).__init__()
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.frozen_stages = frozen_stages
+        self.use_checkpoint = use_checkpoint
+        self.patch_padding = patch_padding
+        self.freeze_attn = freeze_attn
+        self.freeze_ffn = freeze_ffn
+        self.depth = depth
+
+        if hybrid_backbone is not None:
+            self.patch_embed = HybridEmbed(
+                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+        else:
+            self.patch_embed = PatchEmbed(
+                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
+        num_patches = self.patch_embed.num_patches
+
+        # since the pretraining model has class token
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+            )
+            for i in range(depth)])
+
+        self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+
+        self.use_cls = use_cls
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        nn.init.normal_(self.cls_token, std=1e-6)
+
+        self._freeze_stages()
+
+    def _freeze_stages(self):
+        """Freeze parameters."""
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            m = self.blocks[i]
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+        if self.freeze_attn:
+            for i in range(0, self.depth):
+                m = self.blocks[i]
+                m.attn.eval()
+                m.norm1.eval()
+                for param in m.attn.parameters():
+                    param.requires_grad = False
+                for param in m.norm1.parameters():
+                    param.requires_grad = False
+
+        if self.freeze_ffn:
+            self.pos_embed.requires_grad = False
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+            for i in range(0, self.depth):
+                m = self.blocks[i]
+                m.mlp.eval()
+                m.norm2.eval()
+                for param in m.mlp.parameters():
+                    param.requires_grad = False
+                for param in m.norm2.parameters():
+                    param.requires_grad = False
+
+    def init_weights(self):
+        """Initialize the weights in backbone.
+
+        Linear layers get truncated-normal weights (std=0.02) and zero biases;
+        LayerNorm layers get unit weights and zero biases.
+        """
+
+        def _init_weights(m):
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=.02)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+        self.apply(_init_weights)
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def forward_features(self, x):
+        B, C, H, W = x.shape
+        x, (Hp, Wp) = self.patch_embed(x)
+
+        if self.pos_embed is not None:
+            # fit for multiple GPU training
+            # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
+            x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
+
+        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) if self.use_cls else x
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+
+        x = self.last_norm(x)
+
+        cls = x[:, 0] if self.use_cls else None
+        x = x[:, 1:] if self.use_cls else x
+        xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
+
+        return xp, cls  # shape [B, D, Hp, Wp], [B, D]
+
+    def forward(self, x):
+        x, cls = self.forward_features(x)
+        return x, cls
+
+    def train(self, mode=True):
+        """Convert the model into training mode."""
+        super().train(mode)
+        self._freeze_stages()
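A note on the numbers in `vith`: with `img_size=(256, 192)`, `patch_size=16` and `ratio=1`, the padding expression in `PatchEmbed` evaluates to 2, and the patch convolution should yield a 16 x 12 token grid (192 tokens), which matches `num_patches`. The sketch below is a plain-Python check of that arithmetic and of the linear drop-path schedule. The helper names `conv_out_size`, `patch_grid` and `drop_path_schedule` are hypothetical and are not part of the repository; the convolution size formula is the standard one for dilation 1.

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def patch_grid(img_size=(256, 192), patch_size=16, ratio=1):
    """Return (Hp, Wp), the token grid produced by the patch-embedding convolution."""
    stride = patch_size // ratio
    padding = 4 + 2 * (ratio // 2 - 1)  # same expression as in PatchEmbed
    hp = conv_out_size(img_size[0], patch_size, stride, padding)
    wp = conv_out_size(img_size[1], patch_size, stride, padding)
    return hp, wp


def drop_path_schedule(drop_path_rate=0.55, depth=32):
    """Per-block drop-path rates, linearly spaced from 0 to drop_path_rate."""
    if depth == 1:
        return [0.0]
    return [drop_path_rate * i / (depth - 1) for i in range(depth)]


hp, wp = patch_grid()
print(hp, wp, hp * wp)  # 16 12 192
rates = drop_path_schedule()
print(len(rates), rates[0])  # 32 0.0
```

Under these assumptions the first block never drops its residual branch and the last block drops it with probability close to 0.55 during training.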
prima/models/bioclip_embedding.py ADDED
@@ -0,0 +1,70 @@
+"""
+PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+Official implementation of the paper:
+"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+Licensed under a modified MIT license
+"""
+
+"""
+BioCLIP Embedding Module
+Converts an image batch to embeddings that can be concatenated with image features
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class BioClipEmbedding(nn.Module):
+    """
+    Embeds images into a feature space using BioClip model that can be combined with image features.
+
+    Args:
+        embed_dim: Output embedding dimension, should match the dimension of image features for concatenation
+    """
+
+    def __init__(self, cfg, embed_dim: int = 1024):
+        super().__init__()
+
+        self.embed_dim = embed_dim
+
+        import open_clip
+
+        if cfg.MODEL.BIOCLIP_EMBEDDING.TYPE == 'bioclip2':
+            print("[BioClipEmbedding] Using BioClip2 model from Hugging Face Hub")
+            self.species_model, _,_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
+        else:
+            self.species_model, _,_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
+        # tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip')
+
+
+        self.species_model.eval()
+
+        # Get the output dimension from the model
+        bioclip_output_dim = self.species_model.visual.output_dim
+
+        # Project to target dimension
+        self.projection = nn.Sequential(
+            nn.Linear(bioclip_output_dim, embed_dim),
+            nn.LayerNorm(embed_dim),
+        )
+
+    def forward(self, images: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            images: Tensor of shape (B, C, H, W) representing a batch of images
+        Returns:
+            Tensor of shape (B, embed_dim) representing the embedded features
+        """
+        # BioClip expects 224x224 input, resize if needed
+        if images.shape[-2:] != (224, 224):
+            images_resized = F.interpolate(images, size=(224, 224), mode='bilinear', align_corners=False)
+        else:
+            images_resized = images
+
+        with torch.no_grad():
+            image_features = self.species_model.encode_image(images_resized)
+
+        projected_features = self.projection(image_features)
+
+        return projected_features
prima/models/components/__init__.py ADDED
File without changes
prima/models/components/model_utils.py ADDED
@@ -0,0 +1,160 @@
+"""
+PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+Official implementation of the paper:
+"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+Licensed under a modified MIT license
+"""
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import copy
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
+    """
+    Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
+    that are temporally closest to the current frame at `frame_idx`. Here, we take
+    - a) the closest conditioning frame before `frame_idx` (if any);
+    - b) the closest conditioning frame after `frame_idx` (if any);
+    - c) any other temporally closest conditioning frames until reaching a total
+         of `max_cond_frame_num` conditioning frames.
+
+    Outputs:
+    - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
+    - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
+    """
+    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
+        selected_outputs = cond_frame_outputs
+        unselected_outputs = {}
+    else:
+        assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
+        selected_outputs = {}
+
+        # the closest conditioning frame before `frame_idx` (if any)
+        idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
+        if idx_before is not None:
+            selected_outputs[idx_before] = cond_frame_outputs[idx_before]
+
+        # the closest conditioning frame after `frame_idx` (if any)
+        idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
+        if idx_after is not None:
+            selected_outputs[idx_after] = cond_frame_outputs[idx_after]
+
+        # add other temporally closest conditioning frames until reaching a total
+        # of `max_cond_frame_num` conditioning frames.
+        num_remain = max_cond_frame_num - len(selected_outputs)
+        inds_remain = sorted(
+            (t for t in cond_frame_outputs if t not in selected_outputs),
+            key=lambda x: abs(x - frame_idx),
+        )[:num_remain]
+        selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
+        unselected_outputs = {
+            t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
+        }
+
+    return selected_outputs, unselected_outputs
+
+
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+    """
+    Get 1D sine positional embedding as in the original Transformer paper.
+    """
+    pe_dim = dim // 2
+    dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+    dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+
+    pos_embed = pos_inds.unsqueeze(-1) / dim_t
+    pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+    return pos_embed
+
+
+def get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+def get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+class DropPath(nn.Module):
+    # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    def __init__(self, drop_prob=0.0, scale_by_keep=True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+        if keep_prob > 0.0 and self.scale_by_keep:
+            random_tensor.div_(keep_prob)
+        return x * random_tensor
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        activation: nn.Module = nn.ReLU,
+        sigmoid_output: bool = False,
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+        self.sigmoid_output = sigmoid_output
+        self.act = activation()
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
+        if self.sigmoid_output:
+            x = F.sigmoid(x)
+        return x
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
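The selection rule in `select_closest_cond_frames` is easiest to follow on a small case. The sketch below restates the rule in plain Python under a hypothetical name, `pick_cond_frames`, so that it runs without the repository or PyTorch. It is meant to mirror the docstring (closest frame before, closest frame at or after, then the nearest remaining frames), not to replace the function above, and the frame indices and string payloads are made-up illustration data.

```python
def pick_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """Restatement of the closest-conditioning-frame rule, for illustration only."""
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        return dict(cond_frame_outputs), {}

    selected = {}
    # a) closest frame strictly before frame_idx, b) closest frame at or after it
    before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
    after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
    for t in (before, after):
        if t is not None:
            selected[t] = cond_frame_outputs[t]

    # c) fill the remaining slots with the temporally nearest leftover frames
    remaining = sorted(
        (t for t in cond_frame_outputs if t not in selected),
        key=lambda t: abs(t - frame_idx),
    )[: max_cond_frame_num - len(selected)]
    selected.update((t, cond_frame_outputs[t]) for t in remaining)

    unselected = {t: v for t, v in cond_frame_outputs.items() if t not in selected}
    return selected, unselected


outputs = {0: "f0", 5: "f5", 10: "f10", 20: "f20"}  # made-up conditioning frames
selected, unselected = pick_cond_frames(8, outputs, 3)
print(sorted(selected), sorted(unselected))  # [0, 5, 10] [20]
```

For frame 8 with a budget of three, frames 5 and 10 are taken first as the nearest neighbours on either side; frame 0 (distance 8) then wins the last slot over frame 20 (distance 12).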
prima/models/components/pose_transformer.py ADDED
@@ -0,0 +1,366 @@
+"""
+PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+Official implementation of the paper:
+"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+Licensed under a modified MIT license
+"""
+
+from inspect import isfunction
+from typing import Callable, Optional
+
+import torch
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import nn
+
+from .t_cond_mlp import (
+    AdaptiveLayerNorm1D,
+    FrequencyEmbedder,
+    normalization_layer,
+)
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+class PreNorm(nn.Module):
+    def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
+        super().__init__()
+        self.norm = normalization_layer(norm, dim, norm_cond_dim)
+        self.fn = fn
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        if isinstance(self.norm, AdaptiveLayerNorm1D):
+            return self.fn(self.norm(x, *args), **kwargs)
+        else:
+            return self.fn(self.norm(x), **kwargs)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout=0.0):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        project_out = not (heads == 1 and dim_head == dim)
+
+        self.heads = heads
+        self.scale = dim_head**-0.5
+
+        self.attend = nn.Softmax(dim=-1)
+        self.dropout = nn.Dropout(dropout)
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+
+        self.to_out = (
+            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+            if project_out
+            else nn.Identity()
+        )
+
+    def forward(self, x):
+        qkv = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
+
+        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+        attn = self.attend(dots)
+        attn = self.dropout(attn)
+
+        out = torch.matmul(attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        return self.to_out(out)
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = dim_head * heads
+        project_out = not (heads == 1 and dim_head == dim)
+
+        self.heads = heads
+        self.scale = dim_head**-0.5
+
+        self.attend = nn.Softmax(dim=-1)
+        self.dropout = nn.Dropout(dropout)
+
+        context_dim = default(context_dim, dim)
+        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+
+        self.to_out = (
+            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+            if project_out
+            else nn.Identity()
+        )
+
+    def forward(self, x, context=None):
+        context = default(context, x)
+        k, v = self.to_kv(context).chunk(2, dim=-1)
+        q = self.to_q(x)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
+
+        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+        attn = self.attend(dots)
+        attn = self.dropout(attn)
+
+        out = torch.matmul(attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        return self.to_out(out)
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        heads: int,
+        dim_head: int,
+        mlp_dim: int,
+        dropout: float = 0.0,
+        norm: str = "layer",
+        norm_cond_dim: int = -1,
+    ):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
+            ff = FeedForward(dim, mlp_dim, dropout=dropout)
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
+                        PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
+                    ]
+                )
+            )
+
+    def forward(self, x: torch.Tensor, *args):
+        for attn, ff in self.layers:
+            x = attn(x, *args) + x
+            x = ff(x, *args) + x
+        return x
+
+
+class TransformerCrossAttn(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        depth: int,
+        heads: int,
+        dim_head: int,
+        mlp_dim: int,
+        dropout: float = 0.0,
+        norm: str = "layer",
+        norm_cond_dim: int = -1,
+        context_dim: Optional[int] = None,
+    ):
+        super().__init__()
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
+            ca = CrossAttention(
+                dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
+            )
+            ff = FeedForward(dim, mlp_dim, dropout=dropout)
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
+                        PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
+                        PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
+                    ]
+                )
+            )
+
+    def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
+        if context_list is None:
+            context_list = [context] * len(self.layers)
+        if len(context_list) != len(self.layers):
+            raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
+
+        for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
+            x = self_attn(x, *args) + x
+            x = cross_attn(x, *args, context=context_list[i]) + x
+            x = ff(x, *args) + x
+        return x
+
+
+class DropTokenDropout(nn.Module):
+    def __init__(self, p: float = 0.1):
+        super().__init__()
+        if p < 0 or p > 1:
+            raise ValueError(
+                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
+            )
+        self.p = p
+
+    def forward(self, x: torch.Tensor):
+        # x: (batch_size, seq_len, dim)
+        if self.training and self.p > 0:
+            zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
+
+            if zero_mask.any():
+                x = x[:, ~zero_mask, :]
+        return x
+
+
+class ZeroTokenDropout(nn.Module):
+    def __init__(self, p: float = 0.1):
+        super().__init__()
+        if p < 0 or p > 1:
+            raise ValueError(
+                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
+            )
+        self.p = p
+
+    def forward(self, x: torch.Tensor):
+        # x: (batch_size, seq_len, dim)
+        if self.training and self.p > 0:
+            zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
+            # Zero-out the masked tokens
+            x[zero_mask, :] = 0
+        return x
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(
+        self,
+        num_tokens: int,
+        token_dim: int,
+        dim: int,
+        depth: int,
+        heads: int,
+        mlp_dim: int,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        emb_dropout: float = 0.0,
+        emb_dropout_type: str = "drop",
+        emb_dropout_loc: str = "token",
+        norm: str = "layer",
+        norm_cond_dim: int = -1,
+        token_pe_numfreq: int = -1,
+    ):
+        super().__init__()
+        if token_pe_numfreq > 0:
+            token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
+            self.to_token_embedding = nn.Sequential(
+                Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
+                FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
+                Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
+                nn.Linear(token_dim_new, dim),
+            )
+        else:
+            self.to_token_embedding = nn.Linear(token_dim, dim)
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
+        if emb_dropout_type == "drop":
+            self.dropout = DropTokenDropout(emb_dropout)
+        elif emb_dropout_type == "zero":
+            self.dropout = ZeroTokenDropout(emb_dropout)
+        else:
+            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
+        self.emb_dropout_loc = emb_dropout_loc
+
+        self.transformer = Transformer(
+            dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
+        )
+
+    def forward(self, inp: torch.Tensor, *args, **kwargs):
+        x = inp
+
+        if self.emb_dropout_loc == "input":
+            x = self.dropout(x)
+        x = self.to_token_embedding(x)
+
+        if self.emb_dropout_loc == "token":
+            x = self.dropout(x)
+        b, n, _ = x.shape
+        x += self.pos_embedding[:, :n]
+
+        if self.emb_dropout_loc == "token_afterpos":
+            x = self.dropout(x)
+        x = self.transformer(x, *args)
+        return x
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(
+        self,
+        num_tokens: int,
+        token_dim: int,
+        dim: int,
+        depth: int,
+        heads: int,
+        mlp_dim: int,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        emb_dropout: float = 0.0,
+        emb_dropout_type: str = 'drop',
+        norm: str = "layer",
+        norm_cond_dim: int = -1,
+        context_dim: Optional[int] = None,
+        skip_token_embedding: bool = False,
+    ):
+        super().__init__()
+        if not skip_token_embedding:
+            self.to_token_embedding = nn.Linear(token_dim, dim)
+        else:
+            self.to_token_embedding = nn.Identity()
+            if token_dim != dim:
+                raise ValueError(
+                    f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
+                )
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
+        if emb_dropout_type == "drop":
+            self.dropout = DropTokenDropout(emb_dropout)
+        elif emb_dropout_type == "zero":
+            self.dropout = ZeroTokenDropout(emb_dropout)
+        elif emb_dropout_type == "normal":
+            self.dropout = nn.Dropout(emb_dropout)
+
+        self.transformer = TransformerCrossAttn(
+            dim,
+            depth,
+            heads,
+            dim_head,
+            mlp_dim,
+            dropout,
+            norm=norm,
+            norm_cond_dim=norm_cond_dim,
+            context_dim=context_dim,
+        )
+
+    def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
+        x = self.to_token_embedding(inp)
+        b, n, _ = x.shape
+
+        x = self.dropout(x)
+        x += self.pos_embedding[:, :n]
+
+        x = self.transformer(x, *args, context=context, context_list=context_list)
+        return x
+
prima/models/components/position_encoding.py ADDED
@@ -0,0 +1,84 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
11
+ # All rights reserved.
12
+
13
+ # This source code is licensed under the license found in the
14
+ # LICENSE file in the root directory of this source tree.
15
+
16
+ import math
17
+ from typing import Any, Optional, Tuple
18
+
19
+ import numpy as np
20
+
21
+ import torch
22
+ from torch import nn
23
+
24
+ # Rotary Positional Encoding, adapted from:
25
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
26
+ # 2. https://github.com/naver-ai/rope-vit
27
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
28
+
29
+
30
+ def init_t_xy(end_x: int, end_y: int):
31
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
32
+ t_x = (t % end_x).float()
33
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
34
+ return t_x, t_y
35
+
36
+
37
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
38
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
39
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
40
+
41
+ t_x, t_y = init_t_xy(end_x, end_y)
42
+ freqs_x = torch.outer(t_x, freqs_x)
43
+ freqs_y = torch.outer(t_y, freqs_y)
44
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
45
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
46
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
47
+
48
+
49
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
50
+ ndim = x.ndim
51
+ assert 0 <= 1 < ndim
52
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
53
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
54
+ return freqs_cis.view(*shape)
55
+
56
+
57
+ def apply_rotary_enc(
58
+ xq: torch.Tensor,
59
+ xk: torch.Tensor,
60
+ freqs_cis: torch.Tensor,
61
+ repeat_freqs_k: bool = False,
62
+ ):
63
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
64
+ xk_ = (
65
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
66
+ if xk.shape[-2] != 0
67
+ else None
68
+ )
69
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
70
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
71
+ if xk_ is None:
72
+ # no keys to rotate, due to dropout
73
+ return xq_out.type_as(xq).to(xq.device), xk
74
+ # repeat freqs along seq_len dim to match k seq_len
75
+ if repeat_freqs_k:
76
+ r = xk_.shape[-2] // xq_.shape[-2]
77
+ if freqs_cis.is_cuda:
78
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
79
+ else:
80
+ # torch.repeat on complex numbers may not be supported on non-CUDA devices
81
+ # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
82
+ freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
83
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
84
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
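The axial rotary encoding above can be hard to follow from tensor code alone. The following is a minimal, dependency-free sketch that mirrors `init_t_xy` and `compute_axial_cis` with the standard-library `cmath` module in place of `torch.polar` and `torch.view_as_complex`. The helper `rotate` and the sample vector `q` are illustrative assumptions and are not part of the repository. It rotates one token's channels and confirms that the rotation leaves the vector norm unchanged.

```python
import cmath
import math


def init_t_xy(end_x, end_y):
    """Return the x and y grid coordinate of every flattened token index."""
    t_x = [float(i % end_x) for i in range(end_x * end_y)]
    t_y = [float(i // end_x) for i in range(end_x * end_y)]
    return t_x, t_y


def compute_axial_cis(dim, end_x, end_y, theta=10000.0):
    """One row of dim // 2 unit complex numbers per token: x half, then y half."""
    freqs = [1.0 / (theta ** (k / dim)) for k in range(0, dim, 4)][: dim // 4]
    t_x, t_y = init_t_xy(end_x, end_y)
    rows = []
    for x, y in zip(t_x, t_y):
        row = [cmath.rect(1.0, x * f) for f in freqs]   # angles driven by x
        row += [cmath.rect(1.0, y * f) for f in freqs]  # angles driven by y
        rows.append(row)
    return rows


def rotate(vec, cis_row):
    """Pair consecutive channels into complex numbers and rotate each pair."""
    pairs = [complex(vec[i], vec[i + 1]) for i in range(0, len(vec), 2)]
    out = []
    for p, c in zip(pairs, cis_row):
        r = p * c
        out.extend([r.real, r.imag])
    return out


cis = compute_axial_cis(dim=8, end_x=4, end_y=4)
q = [1.0, 0.0, 0.5, -0.5, 2.0, 1.0, 0.0, 3.0]
q_rot = rotate(q, cis[5])  # token 5 sits at (x=1, y=1) on the 4x4 grid

norm_before = math.sqrt(sum(v * v for v in q))
norm_after = math.sqrt(sum(v * v for v in q_rot))
```

Token 0 sits at the grid origin, so every angle is zero and `rotate(q, cis[0])` returns `q` unchanged. All other tokens are rotated by angles that depend only on their grid position.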
prima/models/components/t_cond_mlp.py ADDED
@@ -0,0 +1,204 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ import copy
11
+ from typing import List, Optional
12
+
13
+ import torch
14
+
15
+
16
+ class AdaptiveLayerNorm1D(torch.nn.Module):
17
+ def __init__(self, data_dim: int, norm_cond_dim: int):
18
+ super().__init__()
19
+ if data_dim <= 0:
20
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
21
+ if norm_cond_dim <= 0:
22
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
23
+ self.norm = torch.nn.LayerNorm(data_dim)
24
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
25
+ torch.nn.init.zeros_(self.linear.weight)
26
+ torch.nn.init.zeros_(self.linear.bias)
27
+
28
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
29
+ # x: (batch, ..., data_dim)
30
+ # t: (batch, norm_cond_dim)
31
+ # return: (batch, ..., data_dim)
32
+ x = self.norm(x)
33
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
34
+
35
+ # Add singleton dimensions to alpha and beta
36
+ if x.dim() > 2:
37
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
38
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
39
+
40
+ return x * (1 + alpha) + beta
41
+
42
+
43
+ class SequentialCond(torch.nn.Sequential):
44
+ def forward(self, input, *args, **kwargs):
45
+ for module in self:
46
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
47
+ input = module(input, *args, **kwargs)
48
+ else:
49
+ input = module(input)
50
+ return input
51
+
52
+
53
+ def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
54
+ if norm == "batch":
55
+ return torch.nn.BatchNorm1d(dim)
56
+ elif norm == "layer":
57
+ return torch.nn.LayerNorm(dim)
58
+ elif norm == "ada":
59
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
60
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
61
+ elif norm is None:
62
+ return torch.nn.Identity()
63
+ else:
64
+ raise ValueError(f"Unknown norm: {norm}")
65
+
66
+
67
+ def linear_norm_activ_dropout(
68
+ input_dim: int,
69
+ output_dim: int,
70
+ activation: torch.nn.Module = torch.nn.ReLU(),
71
+ bias: bool = True,
72
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
73
+ dropout: float = 0.0,
74
+ norm_cond_dim: int = -1,
75
+ ) -> SequentialCond:
76
+ layers = []
77
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
78
+ if norm is not None:
79
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
80
+ layers.append(copy.deepcopy(activation))
81
+ if dropout > 0.0:
82
+ layers.append(torch.nn.Dropout(dropout))
83
+ return SequentialCond(*layers)
84
+
85
+
86
+ def create_simple_mlp(
87
+ input_dim: int,
88
+ hidden_dims: List[int],
89
+ output_dim: int,
90
+ activation: torch.nn.Module = torch.nn.ReLU(),
91
+ bias: bool = True,
92
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
93
+ dropout: float = 0.0,
94
+ norm_cond_dim: int = -1,
95
+ ) -> SequentialCond:
96
+ layers = []
97
+ prev_dim = input_dim
98
+ for hidden_dim in hidden_dims:
99
+ layers.extend(
100
+ linear_norm_activ_dropout(
101
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
102
+ )
103
+ )
104
+ prev_dim = hidden_dim
105
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
106
+ return SequentialCond(*layers)
107
+
108
+
109
+ class ResidualMLPBlock(torch.nn.Module):
110
+ def __init__(
111
+ self,
112
+ input_dim: int,
113
+ hidden_dim: int,
114
+ num_hidden_layers: int,
115
+ output_dim: int,
116
+ activation: torch.nn.Module = torch.nn.ReLU(),
117
+ bias: bool = True,
118
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
119
+ dropout: float = 0.0,
120
+ norm_cond_dim: int = -1,
121
+ ):
122
+ super().__init__()
123
+ if not (input_dim == output_dim == hidden_dim):
124
+ raise NotImplementedError(
125
+ f"input_dim ({input_dim}), hidden_dim ({hidden_dim}) and output_dim ({output_dim}) must all be equal"
126
+ )
127
+
128
+ layers = []
129
+ prev_dim = input_dim
130
+ for i in range(num_hidden_layers):
131
+ layers.append(
132
+ linear_norm_activ_dropout(
133
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
134
+ )
135
+ )
136
+ prev_dim = hidden_dim
137
+ self.model = SequentialCond(*layers)
138
+ self.skip = torch.nn.Identity()
139
+
140
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
141
+ return x + self.model(x, *args, **kwargs)
142
+
143
+
144
+ class ResidualMLP(torch.nn.Module):
145
+ def __init__(
146
+ self,
147
+ input_dim: int,
148
+ hidden_dim: int,
149
+ num_hidden_layers: int,
150
+ output_dim: int,
151
+ activation: torch.nn.Module = torch.nn.ReLU(),
152
+ bias: bool = True,
153
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
154
+ dropout: float = 0.0,
155
+ num_blocks: int = 1,
156
+ norm_cond_dim: int = -1,
157
+ ):
158
+ super().__init__()
159
+ self.input_dim = input_dim
160
+ self.model = SequentialCond(
161
+ linear_norm_activ_dropout(
162
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
163
+ ),
164
+ *[
165
+ ResidualMLPBlock(
166
+ hidden_dim,
167
+ hidden_dim,
168
+ num_hidden_layers,
169
+ hidden_dim,
170
+ activation,
171
+ bias,
172
+ norm,
173
+ dropout,
174
+ norm_cond_dim,
175
+ )
176
+ for _ in range(num_blocks)
177
+ ],
178
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
179
+ )
180
+
181
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
182
+ return self.model(x, *args, **kwargs)
183
+
184
+
185
+ class FrequencyEmbedder(torch.nn.Module):
186
+ def __init__(self, num_frequencies, max_freq_log2):
187
+ super().__init__()
188
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
189
+ self.register_buffer("frequencies", frequencies)
190
+
191
+ def forward(self, x):
192
+ # x should be of size (N,) or (N, D)
193
+ N = x.size(0)
194
+ if x.dim() == 1: # (N,)
195
+ x = x.unsqueeze(1) # (N, D) where D=1
196
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
197
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
198
+ s = torch.sin(scaled)
199
+ c = torch.cos(scaled)
200
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
201
+ N, -1
202
+ ) # (N, D * 2 * num_frequencies + D)
203
+ return embedded
204
+
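The output layout of `FrequencyEmbedder` explains why `TransformerEncoder` sizes its projection as `token_dim * (2 * token_pe_numfreq + 1)`. Below is a minimal pure-Python sketch of the same computation for a single input row. The function name `frequency_embed` is an assumption made for illustration; the repository uses the `torch` module above. Each input value contributes its sines, then its cosines, then the raw value.

```python
import math


def frequency_embed(row, num_frequencies, max_freq_log2):
    """Embed each value x as [sin(f_i * x)..., cos(f_i * x)..., x]."""
    # Same spacing as torch.linspace(0, max_freq_log2, steps=num_frequencies).
    if num_frequencies == 1:
        exponents = [0.0]
    else:
        step = max_freq_log2 / (num_frequencies - 1)
        exponents = [i * step for i in range(num_frequencies)]
    freqs = [2.0 ** e for e in exponents]

    out = []
    for x in row:
        out.extend(math.sin(f * x) for f in freqs)
        out.extend(math.cos(f * x) for f in freqs)
        out.append(x)
    return out


# Two input dimensions, three frequencies (1, 2, 4).
emb = frequency_embed([0.0, 0.5], num_frequencies=3, max_freq_log2=2)
```

With `D` input values and `F` frequencies the result has `D * (2 * F + 1)` entries, which is 14 in this example.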
prima/models/components/transformer.py ADDED
@@ -0,0 +1,400 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
11
+ # All rights reserved.
12
+
13
+ # This source code is licensed under the license found in the
14
+ # LICENSE file in the root directory of this source tree.
15
+
16
+ import contextlib
17
+ import math
18
+ import warnings
19
+ from functools import partial
20
+ from typing import Tuple, Type
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import nn, Tensor
25
+
26
+ from .position_encoding import apply_rotary_enc, compute_axial_cis
27
+ from .model_utils import MLP
28
+
29
+ warnings.simplefilter(action="ignore", category=FutureWarning)
30
+
31
+
32
+ def get_sdpa_settings():
33
+ if torch.cuda.is_available():
34
+ old_gpu = torch.cuda.get_device_properties(0).major < 7
35
+ # only use Flash Attention on Ampere (8.0) or newer GPUs
36
+ use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
37
+ if not use_flash_attn:
38
+ warnings.warn(
39
+ "Flash Attention is disabled as it requires a GPU with CUDA capability 8.0 (Ampere) or newer.",
40
+ category=UserWarning,
41
+ stacklevel=2,
42
+ )
43
+ # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
44
+ # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
45
+ pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
46
+ if pytorch_version < (2, 2):
47
+ warnings.warn(
48
+ f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
49
+ "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
50
+ category=UserWarning,
51
+ stacklevel=2,
52
+ )
53
+ math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
54
+ else:
55
+ old_gpu = True
56
+ use_flash_attn = False
57
+ math_kernel_on = True
58
+
59
+ return old_gpu, use_flash_attn, math_kernel_on
60
+
61
+
62
+ # Check whether Flash Attention is available (and use it by default)
63
+ OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
64
+ # A fallback setting to allow all available kernels if Flash Attention fails
65
+ ALLOW_ALL_KERNELS = False
66
+
67
+
68
+ def sdp_kernel_context(dropout_p):
69
+ """
70
+ Get the context for the attention scaled dot-product kernel. We use Flash Attention
71
+ by default, but fall back to all available kernels if Flash Attention fails.
72
+ """
73
+ if ALLOW_ALL_KERNELS:
74
+ return contextlib.nullcontext()
75
+
76
+ return torch.backends.cuda.sdp_kernel(
77
+ enable_flash=USE_FLASH_ATTN,
78
+ # if Flash attention kernel is off, then math kernel needs to be enabled
79
+ enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
80
+ enable_mem_efficient=OLD_GPU,
81
+ )
82
+
83
+
84
+ class TwoWayTransformer(nn.Module):
85
+ def __init__(
86
+ self,
87
+ depth: int,
88
+ embedding_dim: int,
89
+ num_heads: int,
90
+ mlp_dim: int,
91
+ activation: Type[nn.Module] = nn.ReLU,
92
+ attention_downsample_rate: int = 2,
93
+ ) -> None:
94
+ """
95
+ A transformer decoder that attends to an input image using
96
+ queries whose positional embedding is supplied.
97
+
98
+ Args:
99
+ depth (int): number of layers in the transformer
100
+ embedding_dim (int): the channel dimension for the input embeddings
101
+ num_heads (int): the number of heads for multihead attention. Must
102
+ divide embedding_dim
103
+ mlp_dim (int): the channel dimension internal to the MLP block
104
+ activation (nn.Module): the activation to use in the MLP block
105
+ """
106
+ super().__init__()
107
+ self.depth = depth
108
+ self.embedding_dim = embedding_dim
109
+ self.num_heads = num_heads
110
+ self.mlp_dim = mlp_dim
111
+ self.layers = nn.ModuleList()
112
+
113
+ for i in range(depth):
114
+ self.layers.append(
115
+ TwoWayAttentionBlock(
116
+ embedding_dim=embedding_dim,
117
+ num_heads=num_heads,
118
+ mlp_dim=mlp_dim,
119
+ activation=activation,
120
+ attention_downsample_rate=attention_downsample_rate,
121
+ skip_first_layer_pe=(i == 0),
122
+ )
123
+ )
124
+
125
+ self.final_attn_token_to_image = Attention(
126
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
127
+ )
128
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
129
+
130
+ def forward(
131
+ self,
132
+ image_embedding: Tensor,
133
+ image_pe: Tensor,
134
+ point_embedding: Tensor,
135
+ ) -> Tuple[Tensor, Tensor]:
136
+ """
137
+ Args:
138
+ image_embedding (torch.Tensor): image to attend to. Should be shape
139
+ B x embedding_dim x h x w for any h and w.
140
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
141
+ have the same shape as image_embedding.
142
+ point_embedding (torch.Tensor): the embedding to add to the query points.
143
+ Must have shape B x N_points x embedding_dim for any N_points.
144
+
145
+ Returns:
146
+ torch.Tensor: the processed point_embedding
147
+ torch.Tensor: the processed image_embedding
148
+ """
149
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
150
+ bs, c, h, w = image_embedding.shape
151
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
152
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
153
+
154
+ # Prepare queries
155
+ queries = point_embedding
156
+ keys = image_embedding
157
+
158
+ # Apply transformer blocks and final layernorm
159
+ for layer in self.layers:
160
+ queries, keys = layer(
161
+ queries=queries,
162
+ keys=keys,
163
+ query_pe=point_embedding,
164
+ key_pe=image_pe,
165
+ )
166
+
167
+ # Apply the final attention layer from the points to the image
168
+ q = queries + point_embedding
169
+ k = keys + image_pe
170
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
171
+ queries = queries + attn_out
172
+ queries = self.norm_final_attn(queries)
173
+
174
+ return queries, keys
175
+
176
+
177
+ class TwoWayAttentionBlock(nn.Module):
178
+ def __init__(
179
+ self,
180
+ embedding_dim: int,
181
+ num_heads: int,
182
+ mlp_dim: int = 2048,
183
+ activation: Type[nn.Module] = nn.ReLU,
184
+ attention_downsample_rate: int = 2,
185
+ skip_first_layer_pe: bool = False,
186
+ ) -> None:
187
+ """
188
+ A transformer block with four layers: (1) self-attention of sparse
189
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
190
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
191
+ inputs.
192
+
193
+ Arguments:
194
+ embedding_dim (int): the channel dimension of the embeddings
195
+ num_heads (int): the number of heads in the attention layers
196
+ mlp_dim (int): the hidden dimension of the mlp block
197
+ activation (nn.Module): the activation of the mlp block
198
+ skip_first_layer_pe (bool): skip the PE on the first layer
199
+ """
200
+ super().__init__()
201
+ self.self_attn = Attention(embedding_dim, num_heads)
202
+ self.norm1 = nn.LayerNorm(embedding_dim)
203
+
204
+ self.cross_attn_token_to_image = Attention(
205
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
206
+ )
207
+ self.norm2 = nn.LayerNorm(embedding_dim)
208
+
209
+ self.mlp = MLP(
210
+ embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
211
+ )
212
+ self.norm3 = nn.LayerNorm(embedding_dim)
213
+
214
+ self.norm4 = nn.LayerNorm(embedding_dim)
215
+ self.cross_attn_image_to_token = Attention(
216
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
217
+ )
218
+
219
+ self.skip_first_layer_pe = skip_first_layer_pe
220
+
221
+ def forward(
222
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
223
+ ) -> Tuple[Tensor, Tensor]:
224
+ # Self attention block
225
+ if self.skip_first_layer_pe:
226
+ queries = self.self_attn(q=queries, k=queries, v=queries)
227
+ else:
228
+ q = queries + query_pe
229
+ attn_out = self.self_attn(q=q, k=q, v=queries)
230
+ queries = queries + attn_out
231
+ queries = self.norm1(queries)
232
+
233
+ # Cross attention block, tokens attending to image embedding
234
+ q = queries + query_pe
235
+ k = keys + key_pe
236
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
237
+ queries = queries + attn_out
238
+ queries = self.norm2(queries)
239
+
240
+ # MLP block
241
+ mlp_out = self.mlp(queries)
242
+ queries = queries + mlp_out
243
+ queries = self.norm3(queries)
244
+
245
+ # Cross attention block, image embedding attending to tokens
246
+ q = queries + query_pe
247
+ k = keys + key_pe
248
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
249
+ keys = keys + attn_out
250
+ keys = self.norm4(keys)
251
+
252
+ return queries, keys
253
+
254
+
255
+ class Attention(nn.Module):
256
+ """
257
+ An attention layer that allows for downscaling the size of the embedding
258
+ after projection to queries, keys, and values.
259
+ """
260
+
261
+ def __init__(
262
+ self,
263
+ embedding_dim: int,
264
+ num_heads: int,
265
+ downsample_rate: int = 1,
266
+ dropout: float = 0.0,
267
+ kv_in_dim: int = None,
268
+ ) -> None:
269
+ super().__init__()
270
+ self.embedding_dim = embedding_dim
271
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
272
+ self.internal_dim = embedding_dim // downsample_rate
273
+ self.num_heads = num_heads
274
+ assert (
275
+ self.internal_dim % num_heads == 0
276
+ ), "num_heads must divide embedding_dim // downsample_rate."
277
+
278
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
279
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
280
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
281
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
282
+
283
+ self.dropout_p = dropout
284
+
285
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
286
+ b, n, c = x.shape
287
+ x = x.reshape(b, n, num_heads, c // num_heads)
288
+ return x.transpose(1, 2).contiguous() # B x N_heads x N_tokens x C_per_head
289
+
290
+ def _recombine_heads(self, x: Tensor) -> Tensor:
291
+ b, n_heads, n_tokens, c_per_head = x.shape
292
+ x = x.transpose(1, 2).contiguous()
293
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
294
+
295
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
296
+ # Input projections
297
+ q = self.q_proj(q)
298
+ k = self.k_proj(k)
299
+ v = self.v_proj(v)
300
+
301
+ # Separate into heads
302
+ q = self._separate_heads(q, self.num_heads)
303
+ k = self._separate_heads(k, self.num_heads)
304
+ v = self._separate_heads(v, self.num_heads)
305
+
306
+ dropout_p = self.dropout_p if self.training else 0.0
307
+ # Attention
308
+ try:
309
+ with sdp_kernel_context(dropout_p):
310
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
311
+ except Exception as e:
312
+ # Fall back to all kernels if the Flash attention kernel fails
313
+ warnings.warn(
314
+ f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
315
+ f"kernels for scaled_dot_product_attention (which may have a slower speed).",
316
+ category=UserWarning,
317
+ stacklevel=2,
318
+ )
319
+ global ALLOW_ALL_KERNELS
320
+ ALLOW_ALL_KERNELS = True
321
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
322
+
323
+ out = self._recombine_heads(out)
324
+ out = self.out_proj(out)
325
+
326
+ return out
327
+
328
+
329
+ class RoPEAttention(Attention):
330
+ """Attention with rotary position encoding."""
331
+
332
+ def __init__(
333
+ self,
334
+ *args,
335
+ rope_theta=10000.0,
336
+ # whether to repeat q rope to match k length
337
+ # this is needed for cross-attention to memories
338
+ rope_k_repeat=False,
339
+ feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
340
+ **kwargs,
341
+ ):
342
+ super().__init__(*args, **kwargs)
343
+
344
+ self.compute_cis = partial(
345
+ compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
346
+ )
347
+ freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
348
+ self.freqs_cis = freqs_cis
349
+ self.rope_k_repeat = rope_k_repeat
350
+
351
+ def forward(
352
+ self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int=0,
353
+ ) -> Tensor:
354
+ # Input projections
355
+ q = self.q_proj(q)
356
+ k = self.k_proj(k)
357
+ v = self.v_proj(v)
358
+
359
+ # Separate into heads
360
+ q = self._separate_heads(q, self.num_heads)
361
+ k = self._separate_heads(k, self.num_heads)
362
+ v = self._separate_heads(v, self.num_heads)
363
+
364
+ # Apply rotary position encoding
365
+ w = h = math.sqrt(q.shape[-2])
366
+ self.freqs_cis = self.freqs_cis.to(q.device)
367
+ if self.freqs_cis.shape[0] != q.shape[-2]:
368
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
369
+ if q.shape[-2] != k.shape[-2]:
370
+ assert self.rope_k_repeat
371
+
372
+ num_k_rope = k.size(-2) - num_k_exclude_rope
373
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
374
+ q,
375
+ k[:, :, :num_k_rope],
376
+ freqs_cis=self.freqs_cis,
377
+ repeat_freqs_k=self.rope_k_repeat,
378
+ )
379
+
380
+ dropout_p = self.dropout_p if self.training else 0.0
381
+ # Attention
382
+ try:
383
+ with sdp_kernel_context(dropout_p):
384
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
385
+ except Exception as e:
386
+ # Fall back to all kernels if the Flash attention kernel fails
387
+ warnings.warn(
388
+ f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
389
+ f"kernels for scaled_dot_product_attention (which may have a slower speed).",
390
+ category=UserWarning,
391
+ stacklevel=2,
392
+ )
393
+ global ALLOW_ALL_KERNELS
394
+ ALLOW_ALL_KERNELS = True
395
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
396
+
397
+ out = self._recombine_heads(out)
398
+ out = self.out_proj(out)
399
+
400
+ return out
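The kernel selection in `get_sdpa_settings` and `sdp_kernel_context` is easier to reason about when it is separated from the hardware queries. The sketch below restates the same boolean logic as pure functions that take the device facts as arguments. The function names and the parameters `cuda_available`, `cc_major` and `torch_version` are assumptions introduced for illustration, and nothing here calls into PyTorch.

```python
def sdpa_settings(cuda_available, cc_major, torch_version):
    """Return (old_gpu, use_flash_attn, math_kernel_on) as in get_sdpa_settings."""
    if not cuda_available:
        return True, False, True
    old_gpu = cc_major < 7
    use_flash_attn = cc_major >= 8  # Ampere (8.0) or newer
    version = tuple(int(v) for v in torch_version.split(".")[:2])
    math_kernel_on = version < (2, 2) or not use_flash_attn
    return old_gpu, use_flash_attn, math_kernel_on


def kernel_flags(settings, dropout_p):
    """Return the three enable_* flags that sdp_kernel_context would pass on."""
    old_gpu, use_flash_attn, math_kernel_on = settings
    return {
        "enable_flash": use_flash_attn,
        "enable_math": (old_gpu and dropout_p > 0.0) or math_kernel_on,
        "enable_mem_efficient": old_gpu,
    }


ampere = kernel_flags(sdpa_settings(True, 8, "2.3.1+cu121"), dropout_p=0.0)
turing = kernel_flags(sdpa_settings(True, 7, "2.3.1"), dropout_p=0.0)
cpu_only = kernel_flags(sdpa_settings(False, 0, "2.1.0"), dropout_p=0.1)
```

Read this way, a compute capability 7.x GPU ends up with only the math kernel enabled, while an Ampere GPU on PyTorch 2.2 or later runs with Flash Attention alone until the `except` branch sets `ALLOW_ALL_KERNELS`.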
prima/models/discriminator.py ADDED
@@ -0,0 +1,129 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+
14
+ class Discriminator(nn.Module):
15
+
16
+ def __init__(self):
17
+ """
18
+ Pose + Shape discriminator proposed in HMR
19
+ """
20
+ super(Discriminator, self).__init__()
21
+
22
+ self.num_joints = 34
23
+ # poses_alone
24
+ self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1)
25
+ nn.init.xavier_uniform_(self.D_conv1.weight)
26
+ nn.init.zeros_(self.D_conv1.bias)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1)
29
+ nn.init.xavier_uniform_(self.D_conv2.weight)
30
+ nn.init.zeros_(self.D_conv2.bias)
31
+ pose_out = []
32
+ for i in range(self.num_joints):
33
+ pose_out_temp = nn.Linear(32, 1)
34
+ nn.init.xavier_uniform_(pose_out_temp.weight)
35
+ nn.init.zeros_(pose_out_temp.bias)
36
+ pose_out.append(pose_out_temp)
37
+ self.pose_out = nn.ModuleList(pose_out)
38
+
39
+ # betas
40
+ self.betas_fc1 = nn.Linear(41, 10) # SMAL betas is 41
41
+ nn.init.xavier_uniform_(self.betas_fc1.weight)
42
+ nn.init.zeros_(self.betas_fc1.bias)
43
+ self.betas_fc2 = nn.Linear(10, 5)
44
+ nn.init.xavier_uniform_(self.betas_fc2.weight)
45
+ nn.init.zeros_(self.betas_fc2.bias)
46
+ self.betas_out = nn.Linear(5, 1)
47
+ nn.init.xavier_uniform_(self.betas_out.weight)
48
+ nn.init.zeros_(self.betas_out.bias)
49
+
50
+ # bones
51
+ self.bone_fc1 = nn.Linear(24, 10) # bone feature vector has 24 entries
52
+ nn.init.xavier_uniform_(self.bone_fc1.weight)
53
+ nn.init.zeros_(self.bone_fc1.bias)
54
+ self.bone_fc2 = nn.Linear(10, 5)
55
+ nn.init.xavier_uniform_(self.bone_fc2.weight)
56
+ nn.init.zeros_(self.bone_fc2.bias)
57
+ self.bone_out = nn.Linear(5, 1)
58
+ nn.init.xavier_uniform_(self.bone_out.weight)
59
+ nn.init.zeros_(self.bone_out.bias)
60
+
61
+ # poses_joint
62
+ self.D_alljoints_fc1 = nn.Linear(32 * self.num_joints, 1024)
63
+ nn.init.xavier_uniform_(self.D_alljoints_fc1.weight)
64
+ nn.init.zeros_(self.D_alljoints_fc1.bias)
65
+ self.D_alljoints_fc2 = nn.Linear(1024, 1024)
66
+ nn.init.xavier_uniform_(self.D_alljoints_fc2.weight)
67
+ nn.init.zeros_(self.D_alljoints_fc2.bias)
68
+ self.D_alljoints_out = nn.Linear(1024, 1)
69
+ nn.init.xavier_uniform_(self.D_alljoints_out.weight)
70
+ nn.init.zeros_(self.D_alljoints_out.bias)
71
+
72
+ def forward(self, poses: torch.Tensor, betas: torch.Tensor, bone=None) -> torch.Tensor:
73
+ """
74
+ Forward pass of the discriminator.
75
+ Args:
76
+ poses (torch.Tensor): Tensor of shape (B, 34, 3, 3) containing a batch of poses (excluding the global orientation).
77
+ betas (torch.Tensor): Tensor of shape (B, 41) containing a batch of SMAL beta coefficients.
78
+ Returns:
79
+ torch.Tensor: Discriminator output with shape (B, 36), or (B, 37) when `bone` is provided
80
+ """
81
+ # bn = poses.shape[0]
82
+ # poses B x (num_joints * 9)
83
+ # poses = poses.reshape(bn, -1)
84
+ # poses B x num_joints x 1 x 9
85
+ poses = poses.reshape(-1, self.num_joints, 1, 9)
86
+ bn = poses.shape[0]
87
+ # poses B x 9 x num_joints x 1
88
+ poses = poses.permute(0, 3, 1, 2).contiguous()
89
+
90
+ # poses_alone
91
+ poses = self.D_conv1(poses)
92
+ poses = self.relu(poses)
93
+ poses = self.D_conv2(poses)
94
+ poses = self.relu(poses)
95
+
96
+ poses_out = []
97
+ for i in range(self.num_joints):
98
+ poses_out_ = self.pose_out[i](poses[:, :, i, 0])
99
+ poses_out.append(poses_out_)
100
+ poses_out = torch.cat(poses_out, dim=1)
101
+
102
+ # betas
103
+ betas = self.betas_fc1(betas)
104
+ betas = self.relu(betas)
105
+ betas = self.betas_fc2(betas)
106
+ betas = self.relu(betas)
107
+ betas_out = self.betas_out(betas)
108
+
109
+ # bone
110
+ if bone is not None:
111
+ bone = self.bone_fc1(bone)
112
+ bone = self.relu(bone)
113
+ bone = self.bone_fc2(bone)
114
+ bone = self.relu(bone)
115
+ bone_out = self.bone_out(bone)
116
+
117
+ # poses_joint
118
+ poses = poses.reshape(bn, -1)
119
+ poses_all = self.D_alljoints_fc1(poses)
120
+ poses_all = self.relu(poses_all)
121
+ poses_all = self.D_alljoints_fc2(poses_all)
122
+ poses_all = self.relu(poses_all)
123
+ poses_all_out = self.D_alljoints_out(poses_all)
124
+
125
+ if bone is not None:
126
+ disc_out = torch.cat((poses_out, betas_out, poses_all_out, bone_out), 1)
127
+ else:
128
+ disc_out = torch.cat((poses_out, betas_out, poses_all_out), 1)
129
+ return disc_out
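The width of the tensor returned by `Discriminator.forward` follows directly from the final `torch.cat`: one logit per joint, one for the betas, one for the all-joints branch, and one more when `bone` is passed. The helper below is a small arithmetic sketch of that bookkeeping. The function names are hypothetical and exist only to make the sizes explicit.

```python
def discriminator_output_width(num_joints=34, with_bone=False):
    """Number of logits per sample produced by the concatenation in forward()."""
    per_joint = num_joints  # one nn.Linear(32, 1) head per joint
    betas = 1               # betas_out
    all_joints = 1          # D_alljoints_out
    bone = 1 if with_bone else 0
    return per_joint + betas + all_joints + bone


def all_joints_input_width(num_joints=34, channels=32):
    """Flattened feature size fed to D_alljoints_fc1."""
    return channels * num_joints


width_default = discriminator_output_width()
width_with_bone = discriminator_output_width(with_bone=True)
flat_features = all_joints_input_width()
```

For the 34 joints set in `__init__`, this gives 36 logits without `bone`, 37 with it, and 1088 flattened features for the all-joints branch.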
prima/models/heads/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .smal_head import build_smal_head
prima/models/heads/classifier_head.py ADDED
@@ -0,0 +1,30 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from torch import nn
11
+
12
+
13
+ class ClassTokenHead(nn.Module):
14
+ def __init__(self, embed_dim=1280, hidden_dim=4096, output_dim=256, num_layers=3, last_bn=True):
15
+ super().__init__()
16
+ mlp = []
17
+ for l in range(num_layers):
18
+ dim1 = embed_dim if l == 0 else hidden_dim
19
+ dim2 = output_dim if l == num_layers - 1 else hidden_dim
20
+ mlp.append(nn.Linear(dim1, dim2, bias=False))
21
+ if l < num_layers - 1:
22
+ mlp.append(nn.BatchNorm1d(dim2))
23
+ mlp.append(nn.ReLU(inplace=True))
24
+ elif last_bn:
25
+ mlp.append(nn.BatchNorm1d(dim2, affine=False))
26
+ self.head = nn.Sequential(*mlp)
27
+
28
+ def forward(self, x):
29
+ cls_feats = self.head(x)
30
+ return cls_feats
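To make the layer sizing in `ClassTokenHead` concrete, the sketch below reproduces only the dimension schedule from the constructor loop. The helper name `mlp_layer_dims` is hypothetical and not part of the repository; it simply mirrors the `dim1`/`dim2` logic and ignores the BatchNorm and ReLU layers.

```python
def mlp_layer_dims(embed_dim=1280, hidden_dim=4096, output_dim=256, num_layers=3):
    """Return the (in_features, out_features) pair of each Linear layer.

    Hypothetical helper: mirrors the dim1/dim2 selection in ClassTokenHead.
    """
    dims = []
    for layer_idx in range(num_layers):
        # The first layer reads the backbone embedding, later layers read hidden features.
        in_dim = embed_dim if layer_idx == 0 else hidden_dim
        # The last layer projects to the output size, earlier layers stay at hidden size.
        out_dim = output_dim if layer_idx == num_layers - 1 else hidden_dim
        dims.append((in_dim, out_dim))
    return dims


if __name__ == "__main__":
    for in_dim, out_dim in mlp_layer_dims():
        print(f"Linear({in_dim}, {out_dim})")
```

With a single layer the schedule collapses to one `embed_dim -> output_dim` projection, which is the case where only the `last_bn` branch can add a normalization layer.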
prima/models/heads/smal_head.py ADDED
@@ -0,0 +1,647 @@
+ """
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+ Official implementation of the paper:
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+ Licensed under a modified MIT license
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import einops
+ import pickle as pkl
+ from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
+ from ..components.pose_transformer import TransformerDecoder
+
+
+ def build_smal_head(cfg):
+     smal_head_type = cfg.MODEL.SMAL_HEAD.get('TYPE', 'amr')
+
+     if smal_head_type == 'new_bio_pose_transformer_decoder':
+         return NewBioGuidedSMALPoseDecoder(cfg)
+     else:
+         raise ValueError('Unknown SMAL head type: {}'.format(smal_head_type))
+
+
+
+
+
+
+
+ class NewBioGuidedSMALPoseDecoder(nn.Module):
+     """
+     Bio-Guided SMAL Decoder with Pose Token Aggregation
+
+     Final version:
+     - Query tokens = [param token] + [2D keypoint tokens (optional)] + [3D keypoint tokens (optional)]
+     - SAM3D-body-style layer-wise keypoint token updates:
+         * 2D: predict (x,y) in [-0.5,0.5] from kp2d tokens -> token_augment position encoding
+               + grid_sample image features at predicted locations -> add into kp2d token embeddings
+               + invalid_mask (out-of-bounds and optional vis mask) zeroes updates
+         * 3D: predict (x,y,z) from kp3d tokens -> pelvis-normalize -> token_augment position encoding
+     - token_augment is injected by feeding (token_embeddings + token_augment) into each decoder layer.
+     - Only param token (index 0) is used to regress pose/betas/cam deltas.
+     - Outputs:
+         pred_smal_params: dict with global_orient/pose/betas and optional keypoints_2d/3d
+         pred_cam: [B,3]
+         extra_outputs: includes bio-guided shape_feat/init_betas and pred_smal_params_list
+     """
+
+     def __init__(self, cfg):
+         super().__init__()
+         self.cfg = cfg
+
+         # ========== Basic config ==========
+         self.joint_rep_type = cfg.MODEL.SMAL_HEAD.get("JOINT_REP", "6d")
+         self.joint_rep_dim = {"6d": 6, "aa": 3}[self.joint_rep_type]
+         self.npose = self.joint_rep_dim * (cfg.SMAL.NUM_JOINTS + 1)
+
+         # ========== Dimensions ==========
+         self.decoder_dim = cfg.MODEL.SMAL_HEAD.get("DECODER_DIM", 1024)
+         context_dim = cfg.MODEL.SMAL_HEAD.get("IN_CHANNELS", 1024)
+         num_layers = cfg.MODEL.SMAL_HEAD.get("NUM_DECODER_LAYERS", 4)
+         num_heads = cfg.MODEL.SMAL_HEAD.get("NUM_HEADS", 8)
+         mlp_ratio = cfg.MODEL.SMAL_HEAD.get("MLP_RATIO", 4.0)
+
+         # keypoint config
+         self.use_keypoint_2d_tokens = cfg.MODEL.SMAL_HEAD.get("USE_KEYPOINT_2D_TOKENS", False)
+         self.use_keypoint_3d_tokens = cfg.MODEL.SMAL_HEAD.get("USE_KEYPOINT_3D_TOKENS", False)
+         self.num_keypoints = cfg.SMAL.get("NUM_KEYPOINTS", 26)
+         self.keypoint_token_update = cfg.MODEL.SMAL_HEAD.get("KEYPOINT_TOKEN_UPDATE", False)
+
+         # 2D update: whether to inject sampled image feature into kp2d tokens
+         self.kp2d_inject_image_feat = cfg.MODEL.SMAL_HEAD.get("KP2D_INJECT_IMAGE_FEAT", True)
+
+         # IEF iters
+         self.ief_iters = cfg.MODEL.SMAL_HEAD.get("IEF_ITERS", 3)
+
+         # pelvis indices
+         self.pelvis_idx = cfg.SMAL.get("PELVIS_IDX", [0, 1])
+
+         # ========== Test-time optimization config ==========
+         self._tta_mode = False  # Track if in test-time adaptation mode
+
+         # ========== [Coarse] Bio prior ==========
+         self.bio_to_betas_init = nn.Sequential(
+             nn.Linear(context_dim, 256),
+             nn.LeakyReLU(0.2, inplace=True),
+             nn.Linear(256, 41),
+         )
+         self.shape_projector = nn.Sequential(
+             nn.Linear(41, 128),
+             nn.ReLU(inplace=True),
+             nn.Linear(128, 128),
+         )
+
+         # ========== Init pose/cam ==========
+         self.init_pose = nn.Parameter(torch.zeros(1, self.npose))
+         self.init_cam = nn.Parameter(torch.tensor([[0.9, 0, 0]], dtype=torch.float32))
+
+         # params -> param token
+         param_dim = self.npose + 41 + 3
+         self.param_to_token = nn.Sequential(
+             nn.Linear(param_dim, self.decoder_dim),
+             nn.LayerNorm(self.decoder_dim),
+             nn.ReLU(),
+         )
+
+         # ========== Keypoint token embeddings ==========
+         if self.use_keypoint_2d_tokens:
+             self.keypoint_2d_embeddings = nn.Embedding(self.num_keypoints, self.decoder_dim)
+             nn.init.normal_(self.keypoint_2d_embeddings.weight, std=0.02)
+
+             # (x,y) -> token augment
+             self.keypoint_2d_pos_encoder = nn.Sequential(
+                 nn.Linear(2, 256),
+                 nn.ReLU(),
+                 nn.Linear(256, self.decoder_dim),
+             )
+             # sampled image feat -> token dim (add into token embeddings)
+             self.keypoint_2d_feat_linear = nn.Linear(self.decoder_dim, self.decoder_dim)
+
+         if self.use_keypoint_3d_tokens:
+             self.keypoint_3d_embeddings = nn.Embedding(self.num_keypoints, self.decoder_dim)
+             nn.init.normal_(self.keypoint_3d_embeddings.weight, std=0.02)
+
+             # (x,y,z) -> token augment
+             self.keypoint_3d_pos_encoder = nn.Sequential(
+                 nn.Linear(3, 256),
+                 nn.ReLU(),
+                 nn.Linear(256, self.decoder_dim),
+             )
+
+         # ========== Per-token intermediate heads (predict from kp tokens themselves) ==========
+         if self.keypoint_token_update:
+             if self.use_keypoint_2d_tokens:
+                 self.kp2d_from_tokens = nn.Sequential(
+                     nn.Linear(self.decoder_dim, self.decoder_dim),
+                     nn.ReLU(),
+                     nn.Linear(self.decoder_dim, 2),
+                 )
+             if self.use_keypoint_3d_tokens:
+                 self.kp3d_from_tokens = nn.Sequential(
+                     nn.Linear(self.decoder_dim, self.decoder_dim),
+                     nn.ReLU(),
+                     nn.Linear(self.decoder_dim, 3),
+                 )
+
+         # ========== Image feature projection + pos encoding ==========
+         self.image_proj = nn.Identity() if context_dim == self.decoder_dim else nn.Linear(context_dim, self.decoder_dim)
+         self.image_pos_encoding = PositionalEncoding2D(self.decoder_dim)
+
+         # ========== Transformer decoder layers ==========
+         self.layers = nn.ModuleList(
+             [
+                 PoseTransformerDecoderLayer(
+                     d_model=self.decoder_dim,
+                     nhead=num_heads,
+                     dim_feedforward=int(self.decoder_dim * mlp_ratio),
+                     dropout=0.1,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.norm = nn.LayerNorm(self.decoder_dim)
+
+         # ========== Regression heads (param token only) ==========
+         self.decpose = nn.Sequential(
+             nn.Linear(self.decoder_dim, self.decoder_dim),
+             nn.ReLU(),
+             nn.Linear(self.decoder_dim, self.npose),
+         )
+         self.decshape = nn.Sequential(
+             nn.Linear(self.decoder_dim, self.decoder_dim),
+             nn.ReLU(),
+             nn.Linear(self.decoder_dim, 41),
+         )
+         self.deccam = nn.Sequential(
+             nn.Linear(self.decoder_dim, self.decoder_dim // 2),
+             nn.ReLU(),
+             nn.Linear(self.decoder_dim // 2, 3),
+         )
+
+     # --------------------------
+     # helpers: query token build
+     # --------------------------
+     def _build_query_tokens(self, pred_pose, pred_betas, pred_cam):
+         B = pred_pose.shape[0]
+         tokens = []
+
+         params = torch.cat([pred_pose, pred_betas, pred_cam], dim=1)  # [B, param_dim]
+         param_token = self.param_to_token(params).unsqueeze(1)  # [B,1,D]
+         tokens.append(param_token)
+
+         kp2d_start = None
+         kp3d_start = None
+
+         if self.use_keypoint_2d_tokens:
+             kp2d_start = sum(t.shape[1] for t in tokens)
+             kp2d_tokens = self.keypoint_2d_embeddings.weight.unsqueeze(0).expand(B, -1, -1).contiguous()
+             tokens.append(kp2d_tokens)
+
+         if self.use_keypoint_3d_tokens:
+             kp3d_start = sum(t.shape[1] for t in tokens)
+             kp3d_tokens = self.keypoint_3d_embeddings.weight.unsqueeze(0).expand(B, -1, -1).contiguous()
+             tokens.append(kp3d_tokens)
+
+         token_embeddings = torch.cat(tokens, dim=1)  # [B,Nq,D]
+         token_augment = torch.zeros_like(token_embeddings)
+
+         return token_embeddings, token_augment, kp2d_start, kp3d_start
+
+     # --------------------------
+     # helpers: updates
+     # --------------------------
+     def _kp2d_update(self, token_embeddings, token_augment, image_features, kp2d_start, H, W, vis_mask=None):
+         """
+         SAM3D-body-style 2D keypoint token update.
+
+         image_features: [B, HW, D] projected + pos-encoded, with HW=H*W (expected 12*16)
+         vis_mask: optional [B,N] bool (True=valid)
+         """
+         if not (self.keypoint_token_update and self.use_keypoint_2d_tokens):
+             return token_embeddings, token_augment, None
+
+         B = token_embeddings.shape[0]
+         N = self.num_keypoints
+
+         kp_tokens = token_embeddings[:, kp2d_start : kp2d_start + N, :]  # [B,N,D]
+
+         # predict coords in [-0.5,0.5]
+         pred_xy = self.kp2d_from_tokens(kp_tokens)
+         pred_xy = torch.tanh(pred_xy) * 0.5
+
+         # invalid mask (out of bounds + optional vis)
+         # tanh already keeps pred_xy inside [-0.5, 0.5], so the bounds test is only a safeguard
+         pred_xy_01 = pred_xy + 0.5
+         invalid = (
+             (pred_xy_01[..., 0] < 0.0)
+             | (pred_xy_01[..., 0] > 1.0)
+             | (pred_xy_01[..., 1] < 0.0)
+             | (pred_xy_01[..., 1] > 1.0)
+         )
+         if vis_mask is not None:
+             invalid = invalid | (~vis_mask)
+         valid = (~invalid).unsqueeze(-1).float()  # [B,N,1]
+
+         # update token_augment slice
+         token_augment = token_augment.clone()
+         token_augment[:, kp2d_start : kp2d_start + N, :] = self.keypoint_2d_pos_encoder(pred_xy) * valid
+
+         # inject sampled image feature into kp2d tokens (optional)
+         if self.kp2d_inject_image_feat:
+             img = image_features.view(B, H, W, self.decoder_dim).permute(0, 3, 1, 2).contiguous()  # [B,D,H,W]
+             grid = (pred_xy * 2.0).unsqueeze(2)  # [B,N,1,2] in [-1,1]
+
+             sampled = (
+                 F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
+                 .squeeze(3)
+                 .permute(0, 2, 1)
+                 .contiguous()
+             )  # [B,N,D]
+
+             sampled = sampled * valid
+             token_embeddings = token_embeddings.clone()
+             token_embeddings[:, kp2d_start : kp2d_start + N, :] += self.keypoint_2d_feat_linear(sampled)
+
+         return token_embeddings, token_augment, pred_xy
+
+     def _kp3d_update(self, token_embeddings, token_augment, kp3d_start):
+         if not (self.keypoint_token_update and self.use_keypoint_3d_tokens):
+             return token_embeddings, token_augment, None
+
+         N = self.num_keypoints
+         kp_tokens = token_embeddings[:, kp3d_start : kp3d_start + N, :]  # [B,N,D]
+
+         pred_xyz = self.kp3d_from_tokens(kp_tokens)  # [B,N,3]
+
+         # pelvis normalize
+         pelvis_center = pred_xyz[:, self.pelvis_idx, :].mean(dim=1, keepdim=True)  # [B,1,3]
+         pred_xyz_norm = pred_xyz - pelvis_center
+
+         token_augment = token_augment.clone()
+         token_augment[:, kp3d_start : kp3d_start + N, :] = self.keypoint_3d_pos_encoder(pred_xyz_norm)
+
+         return token_embeddings, token_augment, pred_xyz
+
+     # --------------------------
+     # forward
+     # --------------------------
+     def forward(self, x, keypoint_coords_2d=None, keypoint_coords_3d=None, **kwargs):
+         """
+         Inputs:
+             x: [B, Hp*Wp+1, C] image tokens from backbone concatenated with bio token
+                (BioCLIP token is the last token in the sequence)
+
+         Note:
+             keypoint_coords_2d can optionally provide a vis/conf mask: [B,N,3] (x,y,vis)
+             We do NOT inject GT coords into tokens by default; they are used only as optional masking.
+         """
+         B = x.shape[0]
+
+         # ---- Data preprocessing ----
+         # Handle 4D input tensors of shape (B, C, H, W).
+         if len(x.shape) == 4:
+             x = einops.rearrange(x, 'b c h w -> b (h w) c')
+
+         bio_token = x[:, -1, :]  # [B, C] - the BioCLIP token is the final token
+         image_features = x[:, :-1, :]  # [B, H*W, C] - remaining image features
+
+         # ---- Coarse bio shape ----
+         init_betas = self.bio_to_betas_init(bio_token)  # [B,41]
+         shape_feat = F.normalize(self.shape_projector(init_betas), dim=1)
+
+         # ---- Image feature projection ----
+         # Project only image features, excluding the bio token.
+         image_features = self.image_proj(image_features)  # [B,HW,D]
+
+         # Assumes a ViT backbone on a 256x192 crop with patch size 16 => Hp=12, Wp=16
+         H, W = 12, 16
+         assert image_features.shape[1] == H * W, f"Expected HW={H*W}, got {image_features.shape[1]}"
+
+         img_pos = self.image_pos_encoding(H, W).to(image_features.device)  # [HW,D]
+         image_features = image_features + img_pos.unsqueeze(0)
+
+         # ---- init params ----
+         pred_pose = self.init_pose.expand(B, -1)
+         pred_betas = init_betas
+         pred_cam = self.init_cam.expand(B, -1)
+
+         pred_pose_list, pred_betas_list, pred_cam_list = [], [], []
+         pred_keypoints_2d_list, pred_keypoints_3d_list = [], []
+
+         # Optional visibility mask from provided 2D keypoints
+         vis_mask = None
+         if keypoint_coords_2d is not None and keypoint_coords_2d.shape[-1] == 3:
+             vis_mask = keypoint_coords_2d[..., 2] > 0  # [B,N]
+
+         # ---- IEF loop ----
+         for _ in range(self.ief_iters):
+             token_embeddings, token_augment, kp2d_start, kp3d_start = self._build_query_tokens(
+                 pred_pose, pred_betas, pred_cam
+             )
+
+             # ---- Transformer layers ----
+             for layer_idx, layer in enumerate(self.layers):
+                 # inject dynamic augment
+                 tokens_in = token_embeddings + token_augment
+                 token_embeddings = layer(tokens_in, image_features)
+
+                 # layer-wise token update (skip last layer)
+                 if self.keypoint_token_update and (layer_idx < len(self.layers) - 1):
+                     if self.use_keypoint_2d_tokens:
+                         token_embeddings, token_augment, pred_xy = self._kp2d_update(
+                             token_embeddings, token_augment, image_features, kp2d_start, H, W, vis_mask=vis_mask
+                         )
+                         if pred_xy is not None:
+                             pred_keypoints_2d_list.append(pred_xy)
+
+                     if self.use_keypoint_3d_tokens:
+                         token_embeddings, token_augment, pred_xyz = self._kp3d_update(
+                             token_embeddings, token_augment, kp3d_start
+                         )
+                         if pred_xyz is not None:
+                             pred_keypoints_3d_list.append(pred_xyz)
+
+             # ---- Regress deltas from param token ----
+             token_embeddings = self.norm(token_embeddings)
+             param_token_out = token_embeddings[:, 0, :]
+
+             delta_pose = self.decpose(param_token_out)
+             delta_betas = self.decshape(param_token_out)
+             delta_cam = self.deccam(param_token_out)
+
+             pred_pose = pred_pose + delta_pose
+             pred_betas = pred_betas + delta_betas
+             pred_cam = pred_cam + delta_cam
+
+             pred_pose_list.append(pred_pose)
+             pred_betas_list.append(pred_betas)
+             pred_cam_list.append(pred_cam)
+
+         # ---- Convert joint representation ----
+         joint_conversion_fn = {
+             "6d": rot6d_to_rotmat,
+             "aa": lambda y: aa_to_rotmat(y.view(-1, 3).contiguous()),
+         }[self.joint_rep_type]
+
+         pred_smal_params_list = {
+             "pose": torch.cat(
+                 [joint_conversion_fn(p).view(B, -1, 3, 3)[:, 1:, :, :] for p in pred_pose_list],
+                 dim=0,
+             ),
+             "betas": torch.cat(pred_betas_list, dim=0),
+             "cam": torch.cat(pred_cam_list, dim=0),
+             "keypoints_2d": torch.cat(pred_keypoints_2d_list, dim=0) if len(pred_keypoints_2d_list) else None,
+             "keypoints_3d": torch.cat(pred_keypoints_3d_list, dim=0) if len(pred_keypoints_3d_list) else None,
+         }
+
+         pred_pose_mat = joint_conversion_fn(pred_pose).view(B, self.cfg.SMAL.NUM_JOINTS + 1, 3, 3)
+         pred_smal_params = {
+             "global_orient": pred_pose_mat[:, [0]],
+             "pose": pred_pose_mat[:, 1:],
+             "betas": pred_betas,
+         }
+
+         # expose final predicted keypoints for losses
+         if self.keypoint_token_update:
+             if self.use_keypoint_2d_tokens and len(pred_keypoints_2d_list):
+                 pred_smal_params["keypoints_2d"] = pred_keypoints_2d_list[-1]
+             if self.use_keypoint_3d_tokens and len(pred_keypoints_3d_list):
+                 pred_smal_params["keypoints_3d"] = pred_keypoints_3d_list[-1]
+
+         extra_outputs = {
+             "shape_feat": shape_feat,
+             "init_betas": init_betas,
+             "pred_smal_params_list": pred_smal_params_list,
+         }
+         return pred_smal_params, pred_cam, extra_outputs
+
+     # --------------------------
+     # Test-time optimization helpers
+     # --------------------------
+     def freeze_all_except_keypoint_tokens(self):
+         """
+         Freeze all parameters except keypoint token embeddings and their prediction heads.
+         Use this before test-time optimization.
+         """
+         # Freeze everything first
+         for param in self.parameters():
+             param.requires_grad = False
+
+         # Unfreeze only keypoint-related parameters
+         if self.use_keypoint_2d_tokens:
+             for param in self.keypoint_2d_embeddings.parameters():
+                 param.requires_grad = True
+             for param in self.keypoint_2d_pos_encoder.parameters():
+                 param.requires_grad = True
+             for param in self.keypoint_2d_feat_linear.parameters():
+                 param.requires_grad = True
+             if self.keypoint_token_update:
+                 for param in self.kp2d_from_tokens.parameters():
+                     param.requires_grad = True
+
+         if self.use_keypoint_3d_tokens:
+             for param in self.keypoint_3d_embeddings.parameters():
+                 param.requires_grad = True
+             for param in self.keypoint_3d_pos_encoder.parameters():
+                 param.requires_grad = True
+             if self.keypoint_token_update:
+                 for param in self.kp3d_from_tokens.parameters():
+                     param.requires_grad = True
+
+         self._tta_mode = True
+         print("[TTA] Frozen all parameters except keypoint tokens")
+
+     def freeze_backbone_only(self):
+         """
+         Make the whole SMAL head trainable.
+         The backbone is not touched here; it has to be frozen separately by the caller.
+         Use for full SMAL parameter + keypoint optimization.
+         """
+         # Unfreeze all SMAL head parameters
+         for param in self.parameters():
+             param.requires_grad = True
+
+         self._tta_mode = True
+         print("[TTA] SMAL head fully trainable (backbone frozen separately)")
+
+     def freeze_except_regression_heads(self):
+         """
+         Freeze everything except the final regression heads (pose/shape/cam) and keypoint embeddings.
+         Keep transformer frozen to preserve pretrained representations.
+         """
+         # Freeze everything first
+         for param in self.parameters():
+             param.requires_grad = False
+
+         # Unfreeze only the final regression heads (small MLPs)
+         for param in self.decpose.parameters():
+             param.requires_grad = True
+         for param in self.decshape.parameters():
+             param.requires_grad = True
+         for param in self.deccam.parameters():
+             param.requires_grad = True
+
+         # Unfreeze ONLY keypoint embeddings (learned tokens, NOT position encoders)
+         if self.use_keypoint_2d_tokens:
+             self.keypoint_2d_embeddings.weight.requires_grad = True
+
+         if self.use_keypoint_3d_tokens:
+             self.keypoint_3d_embeddings.weight.requires_grad = True
+
+         # DO NOT unfreeze transformer - keep pretrained representations
+         # DO NOT unfreeze param_to_token - keep initial token mapping stable
+
+         self._tta_mode = True
+         print("[TTA] Frozen all except regression heads and keypoint embeddings")
+
+     def unfreeze_all(self):
+         """Restore all parameters to trainable state."""
+         for param in self.parameters():
+             param.requires_grad = True
+         self._tta_mode = False
+         print("[TTA] Unfrozen all parameters")
+
+     def get_tta_parameters(self, mode='keypoints_only'):
+         """
+         Get the list of parameters to optimize during test-time adaptation.
+
+         The returned list should be a subset of what the freeze methods leave trainable:
+         - 'regression_heads' matches freeze_except_regression_heads().
+         - 'keypoints_only' returns only the keypoint embeddings, which is narrower than
+           freeze_all_except_keypoint_tokens() (that method also unfreezes the position
+           encoders, the 2D feature linear and the per-token keypoint heads).
+         - 'all' currently returns the same parameters as 'regression_heads'.
+
+         Args:
+             mode: 'keypoints_only', 'regression_heads', or 'all'
+         """
+         params = []
+
+         # Keypoint embeddings only (NOT position encoders or feature linears)
+         if mode in ['keypoints_only', 'regression_heads', 'all']:
+             if self.use_keypoint_2d_tokens:
+                 params.append(self.keypoint_2d_embeddings.weight)
+
+             if self.use_keypoint_3d_tokens:
+                 params.append(self.keypoint_3d_embeddings.weight)
+
+         # Regression heads only (NO transformer or param_to_token)
+         if mode in ['regression_heads', 'all']:
+             params.extend(list(self.decpose.parameters()))
+             params.extend(list(self.decshape.parameters()))
+             params.extend(list(self.deccam.parameters()))
+
+         return params
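The freeze helpers above only toggle `requires_grad`; an optimizer still has to be built from the parameters that remain trainable. The following is a minimal sketch of that pattern, not the repository's test-time adaptation loop. `TinyHead` and `trainable_parameters` are hypothetical stand-ins, and Adam with this learning rate and the squared-output loss are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn


class TinyHead(nn.Module):
    """Hypothetical stand-in for the decoder: one embedding table, one regression head."""

    def __init__(self, num_keypoints=26, dim=8):
        super().__init__()
        self.keypoint_embeddings = nn.Embedding(num_keypoints, dim)
        self.regressor = nn.Linear(dim, 3)

    def forward(self):
        # Regress a 3-vector from every keypoint embedding.
        return self.regressor(self.keypoint_embeddings.weight)


def trainable_parameters(module):
    """Collect only the parameters that are currently unfrozen."""
    return [p for p in module.parameters() if p.requires_grad]


torch.manual_seed(0)
head = TinyHead()

# Freeze everything, then unfreeze only the keypoint embeddings.
for p in head.parameters():
    p.requires_grad = False
head.keypoint_embeddings.weight.requires_grad = True

params = trainable_parameters(head)
optimizer = torch.optim.Adam(params, lr=1e-2)

regressor_before = head.regressor.weight.detach().clone()
embeddings_before = head.keypoint_embeddings.weight.detach().clone()

# One optimization step on a placeholder objective.
loss = head().pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Gradients still flow through the frozen regression head to the embeddings, so only the embedding table changes after the step. Passing just the unfrozen parameters to the optimizer also avoids allocating optimizer state for frozen weights.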
+
+
+
+
+ class PoseTransformerDecoderLayer(nn.Module):
+     """
+     Single-layer transformer decoder for pose-token aggregation.
+     Includes self-attention over tokens, cross-attention from tokens to
+     image features, and a feed-forward network.
+     """
+
+     def __init__(self, d_model=1024, nhead=8, dim_feedforward=4096, dropout=0.1):
+         super().__init__()
+
+         # Self-attention over tokens
+         self.self_attn = nn.MultiheadAttention(
+             d_model, nhead, dropout=dropout, batch_first=True
+         )
+         self.norm1 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+
+         # Cross-attention from image features into tokens
+         self.cross_attn = nn.MultiheadAttention(
+             d_model, nhead, dropout=dropout, batch_first=True
+         )
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout2 = nn.Dropout(dropout)
+
+         # Feed-Forward Network
+         self.ffn = nn.Sequential(
+             nn.Linear(d_model, dim_feedforward),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(dim_feedforward, d_model),
+             nn.Dropout(dropout),
+         )
+         self.norm3 = nn.LayerNorm(d_model)
+
+     def forward(self, tokens, image_features):
+         """
+         Args:
+             tokens: [B, N_tokens, C] containing pose and keypoint tokens
+             image_features: [B, N_pixels, C] image features
+
+         Returns:
+             tokens: [B, N_tokens, C] updated tokens
+         """
+
+         # Self-attention lets tokens exchange information.
+         attn_output, _ = self.self_attn(tokens, tokens, tokens)
+         tokens = tokens + self.dropout1(attn_output)
+         tokens = self.norm1(tokens)
+
+         # Cross-attention injects visual information from image features.
+         attn_output, _ = self.cross_attn(
+             query=tokens,
+             key=image_features,
+             value=image_features,
+         )
+         tokens = tokens + self.dropout2(attn_output)
+         tokens = self.norm2(tokens)
+
+         # Feed-Forward Network
+         ffn_output = self.ffn(tokens)
+         tokens = tokens + ffn_output
+         tokens = self.norm3(tokens)
+
+         return tokens
+
+
+ class PositionalEncoding2D(nn.Module):
+     """
+     2D sinusoidal positional encoding for image features.
+     """
+
+     def __init__(self, embed_dim=1024, temperature=10000):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.temperature = temperature
+
+     def forward(self, H, W):
+         """
+         Args:
+             H, W: height and width of the feature map
+
+         Returns:
+             pos_encoding: [H*W, embed_dim]
+         """
+         # Build grid coordinates.
+         y_embed = torch.arange(H, dtype=torch.float32).unsqueeze(1).repeat(1, W)
+         x_embed = torch.arange(W, dtype=torch.float32).unsqueeze(0).repeat(H, 1)
+
+         # Normalize to [0, 1).
+         y_embed = y_embed / H
+         x_embed = x_embed / W
+
+         # Build frequencies.
+         dim_t = torch.arange(self.embed_dim // 2, dtype=torch.float32)
+         dim_t = self.temperature ** (2 * dim_t / self.embed_dim)
+
+         # Apply sine/cosine encoding.
+         pos_x = x_embed[:, :, None] / dim_t
+         pos_y = y_embed[:, :, None] / dim_t
+
+         pos_x = torch.stack(
+             [pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()], dim=3
+         ).flatten(2)
+         pos_y = torch.stack(
+             [pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()], dim=3
+         ).flatten(2)
+
+         pos = torch.cat([pos_y, pos_x], dim=2).flatten(0, 1)  # [H*W, embed_dim]
+
+         return pos
+
+
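The coordinate convention used in `_kp2d_update` (predictions in [-0.5, 0.5], doubled to the [-1, 1] range that `F.grid_sample` expects, with `align_corners=False`) can be checked on a toy feature map. The sketch below is illustrative only: the two-channel map that stores column and row indices, and the sample points, are assumptions made up for this example.

```python
import torch
import torch.nn.functional as F

B, H, W = 1, 12, 16

# Toy feature map: channel 0 stores the column index, channel 1 the row index.
cols = torch.arange(W, dtype=torch.float32).view(1, 1, 1, W).expand(B, 1, H, W)
rows = torch.arange(H, dtype=torch.float32).view(1, 1, H, 1).expand(B, 1, H, W)
img = torch.cat([cols, rows], dim=1)  # [B, D=2, H, W]

# Keypoint predictions in [-0.5, 0.5], stored in (x, y) order.
pred_xy = torch.tensor([[[0.0, 0.0], [-0.25, 0.25], [0.4, -0.4]]])  # [B, N=3, 2]

# grid_sample expects coordinates in [-1, 1] with shape [B, H_out, W_out, 2].
grid = (pred_xy * 2.0).unsqueeze(2)  # [B, N, 1, 2]

sampled = F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
sampled = sampled.squeeze(3).permute(0, 2, 1)  # [B, N, D]

# Each row holds the (column, row) position, in pixel units, that was sampled.
print(sampled[0])
```

With `align_corners=False`, a normalized coordinate `g` maps to the pixel position `((g + 1) * size - 1) / 2`, so the image centre falls between pixels on an even-sized grid. Because the toy map is linear in the pixel index, bilinear interpolation returns that position directly for points that stay inside the grid.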
prima/models/losses.py ADDED
@@ -0,0 +1,580 @@
+ """
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+ Official implementation of the paper:
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+ Licensed under a modified MIT license
+ """
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import pickle
+ import torch.nn.functional as F
+ from ..utils.geometry import aa_to_rotmat
+ from typing import Dict
+
+
+ def matrix_to_axis_angle(rot_mats: torch.Tensor) -> torch.Tensor:
+     """Convert rotation matrices (..., 3, 3) to axis-angle vectors (..., 3).
+
+     This local implementation avoids a hard runtime dependency on PyTorch3D.
+     Rotations by an angle at or very near pi are not handled: the antisymmetric
+     part of the matrix vanishes there, so the result degenerates toward zero.
+     """
+     if rot_mats.shape[-2:] != (3, 3):
+         raise ValueError(f"Expected (..., 3, 3) rotation matrices, got {rot_mats.shape}")
+
+     trace = rot_mats[..., 0, 0] + rot_mats[..., 1, 1] + rot_mats[..., 2, 2]
+     cos_theta = (trace - 1.0) * 0.5
+     cos_theta = torch.clamp(cos_theta, -1.0, 1.0)
+     theta = torch.acos(cos_theta)
+
+     vee = torch.stack(
+         [
+             rot_mats[..., 2, 1] - rot_mats[..., 1, 2],
+             rot_mats[..., 0, 2] - rot_mats[..., 2, 0],
+             rot_mats[..., 1, 0] - rot_mats[..., 0, 1],
+         ],
+         dim=-1,
+     )
+     sin_theta = torch.sin(theta)
+     eps = 1e-6
+     scale = theta / torch.clamp(2.0 * sin_theta, min=eps)
+     aa = vee * scale.unsqueeze(-1)
+
+     # For very small angles, first-order approximation: aa ~= 0.5 * vee
+     small = theta < 1e-4
+     if small.any():
+         aa = torch.where(small.unsqueeze(-1), 0.5 * vee, aa)
+     return aa
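As a worked check of the trace-based formula, a rotation by angle θ about the z-axis should map back to the axis-angle vector (0, 0, θ). The sketch below re-implements the same formula for a single 3x3 matrix so that it runs on its own; `rotation_about_z` and `axis_angle_from_matrix` are hypothetical helper names, not functions from the repository.

```python
import math

import torch


def rotation_about_z(angle):
    """Build the 3x3 rotation matrix for a rotation of `angle` radians about z."""
    c, s = math.cos(angle), math.sin(angle)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def axis_angle_from_matrix(R, eps=1e-6):
    """Trace-based conversion for one 3x3 matrix (small-angle branch omitted)."""
    # The trace of a rotation matrix is 1 + 2*cos(theta).
    cos_theta = torch.clamp((torch.trace(R) - 1.0) * 0.5, -1.0, 1.0)
    theta = torch.acos(cos_theta)
    # The antisymmetric part of R equals 2*sin(theta) times the unit axis.
    vee = torch.stack([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    scale = theta / torch.clamp(2.0 * torch.sin(theta), min=eps)
    return vee * scale


angle = math.pi / 3
aa = axis_angle_from_matrix(rotation_about_z(angle))
print(aa)
```

For this matrix the antisymmetric part is (0, 0, 2 sin θ), so multiplying by θ / (2 sin θ) recovers (0, 0, θ).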
+
+
+
+ class DepthLoss(nn.Module):
+     """
+     Depth loss between predicted SMAL vertices and GT SMAL vertices.
+     Compares vertex Z (camera space) after applying camera translation.
+     Only computes loss for samples that have valid GT SMAL parameters.
+     """
+     def __init__(self, loss_type: str = 'l1'):
+         super().__init__()
+         self.loss_type = loss_type
+         self.l1 = nn.L1Loss(reduction='none')  # per-element loss, reduced per sample in forward()
+         self.l2 = nn.MSELoss(reduction='none')  # per-element loss, reduced per sample in forward()
+
+     def forward(self,
+                 pred_vertices: torch.Tensor,  # (B, V, 3)
+                 pred_cam_t: torch.Tensor,  # (B, 3)
+                 gt_smal_params: Dict[str, torch.Tensor],
+                 smal_model,  # SMAL instance callable
+                 is_axis_angle: Dict[str, torch.Tensor],
+                 gt_cam_t: torch.Tensor = None,  # (B, 3) or None -> fallback to pred_cam_t
+                 has_smal_params: Dict[str, torch.Tensor] = None  # optional per-sample GT availability flags
+                 ) -> torch.Tensor:
+         batch_size = pred_vertices.shape[0]
+         device = pred_vertices.device
+
+         # Determine which samples have valid GT SMAL params
+         # A sample is valid only if it has GT for pose, betas, and global_orient
+         if has_smal_params is not None:
+             valid_mask = (has_smal_params['pose'] *
+                           has_smal_params['betas'] *
+                           has_smal_params['global_orient']).bool()
+
+             # If no samples have valid GT, return zero loss
+             if valid_mask.sum() == 0:
+                 return torch.tensor(0., device=device, dtype=pred_vertices.dtype)
+         else:
+             # If not provided, assume all samples are valid
+             valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
+
+         # prepare GT params for SMAL
+         gt_params_for_smal = {}
+         for k in ['global_orient', 'pose', 'betas']:
+             val = gt_smal_params[k].to(device=device)
+             if k == 'betas':
+                 gt_params_for_smal[k] = val.view(batch_size, -1)
+             else:
+                 gt_val = val.view(batch_size, -1)
+                 if is_axis_angle[k].all():
+                     gt_val = aa_to_rotmat(gt_val.reshape(-1, 3)).view(batch_size, -1, 3, 3)
+                 else:
+                     gt_val = gt_val.view(batch_size, -1, 3, 3)
+                 gt_params_for_smal[k] = gt_val
+
+         # generate GT vertices (no grad)
+         with torch.no_grad():
+             gt_out = smal_model(**gt_params_for_smal, pose2rot=False)
+             gt_vertices = gt_out.vertices.view(batch_size, -1, 3)
+
+         if gt_cam_t is None:
+             gt_cam_t = pred_cam_t
+
+         # depth = z in camera coordinates
+         pred_depth = (pred_vertices + pred_cam_t.unsqueeze(1))[..., 2]  # (B, V)
+         gt_depth = (gt_vertices + gt_cam_t.unsqueeze(1))[..., 2]  # (B, V)
+
+         # Compute loss per sample
+         if self.loss_type == 'l1':
+             loss_per_sample = self.l1(pred_depth, gt_depth).mean(dim=1)  # (B,)
+         else:
+             loss_per_sample = self.l2(pred_depth, gt_depth).mean(dim=1)  # (B,)
+
+         # Apply mask: only compute loss for samples with valid GT
+         masked_loss = loss_per_sample * valid_mask.float()
+
+         # Return mean over valid samples
+         num_valid = valid_mask.sum().float()
+         if num_valid > 0:
+             return masked_loss.sum() / num_valid
+         else:
+             return torch.tensor(0., device=device, dtype=pred_vertices.dtype)
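The masking at the end of `DepthLoss.forward` averages the per-sample losses over valid samples only, instead of over the whole batch. A minimal sketch of that reduction in isolation follows; `masked_mean` is a hypothetical helper name and the numbers are made up for illustration.

```python
import torch


def masked_mean(loss_per_sample, valid_mask):
    """Average per-sample losses over valid samples only; zero if none are valid."""
    num_valid = valid_mask.sum()
    if num_valid == 0:
        # No supervision available in this batch.
        return loss_per_sample.new_zeros(())
    # Invalid samples contribute zero to the sum and are excluded from the count.
    masked = loss_per_sample * valid_mask.to(loss_per_sample.dtype)
    return masked.sum() / num_valid


losses = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([True, False, True])
print(masked_mean(losses, mask))
```

Dividing by the number of valid samples rather than the batch size keeps the loss scale independent of how many samples in a batch carry ground-truth SMAL parameters.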
132
+
133
+
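The masking step above averages the per-sample loss only over samples that carry valid ground truth, and falls back to zero when none do. A minimal sketch of that reduction in plain Python follows; the helper name `masked_mean` is hypothetical and not part of the repository.

```python
def masked_mean(loss_per_sample, valid_mask):
    """Average per-sample losses over valid samples only; 0.0 if none are valid."""
    num_valid = sum(1 for v in valid_mask if v)
    if num_valid == 0:
        # Mirrors the early return of a zero loss when no sample has valid GT.
        return 0.0
    masked_sum = sum(l for l, v in zip(loss_per_sample, valid_mask) if v)
    return masked_sum / num_valid


print(masked_mean([1.0, 3.0, 5.0], [True, False, True]))
```

Dividing by the number of valid samples rather than the batch size keeps the loss scale independent of how many samples in a batch happen to lack annotations.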
134
+ class MaskLoss(nn.Module):
135
+ """
136
+ Mask loss between rendered predicted mesh mask and rendered GT mesh mask.
137
+ This loss relies on a MeshRenderer-like object that provides `render_mask(vertices, camera_translation, focal_length)`
138
+ returning a single-channel numpy mask (H, W) with values 0/1. Vertices are detached before rendering, so this loss carries no gradient back to the predictions.
139
+ """
140
+ def __init__(self, mesh_renderer=None):
141
+ super().__init__()
142
+ self.mesh_renderer = mesh_renderer
143
+ self.l1 = nn.L1Loss(reduction='mean')
144
+
145
+ def forward(self,
146
+ pred_vertices: torch.Tensor, # (B, V, 3)
147
+ pred_cam_t: torch.Tensor, # (B, 3)
148
+ gt_smal_params: Dict[str, torch.Tensor],
149
+ smal_model, # SMAL instance callable
150
+ is_axis_angle: Dict[str, torch.Tensor],
151
+ gt_cam_t: torch.Tensor = None, # optional (B,3)
152
+ focal_length: float = 1000.0
153
+ ) -> torch.Tensor:
154
+ batch_size = pred_vertices.shape[0]
155
+ device = pred_vertices.device
156
+
157
+ # if no renderer available, return zero loss
158
+ if self.mesh_renderer is None:
159
+ return torch.tensor(0., device=device, dtype=pred_vertices.dtype)
160
+
161
+ # prepare GT params for SMAL
162
+ gt_params_for_smal = {}
163
+ for k in ['global_orient', 'pose', 'betas']:
164
+ val = gt_smal_params[k].to(device=device)
165
+ if k == 'betas':
166
+ gt_params_for_smal[k] = val.view(batch_size, -1)
167
+ else:
168
+ gt_val = val.view(batch_size, -1)
169
+ if is_axis_angle[k].all():
170
+ gt_val = aa_to_rotmat(gt_val.reshape(-1, 3)).view(batch_size, -1, 3, 3)
171
+ else:
172
+ gt_val = gt_val.view(batch_size, -1, 3, 3)
173
+ gt_params_for_smal[k] = gt_val
174
+
175
+ # generate GT vertices (no grad)
176
+ with torch.no_grad():
177
+ gt_out = smal_model(**gt_params_for_smal, pose2rot=False)
178
+ gt_vertices = gt_out.vertices
179
+
180
+ if gt_cam_t is None:
181
+ gt_cam_t = pred_cam_t
182
+
183
+ # convert to numpy for renderer
184
+ pred_vertices_np = pred_vertices.detach().cpu().numpy()
185
+ gt_vertices_np = gt_vertices.detach().cpu().numpy()
186
+ cam_np = pred_cam_t.detach().cpu().numpy() if pred_cam_t is not None else np.zeros((batch_size, 3), dtype=np.float32)
187
+
188
+ per_item_losses = []
189
+ for i in range(batch_size):
190
+ try:
191
+ pred_mask = self.mesh_renderer.render_mask(pred_vertices_np[i], cam_np[i], focal_length)
192
+ gt_mask_r = self.mesh_renderer.render_mask(gt_vertices_np[i], cam_np[i], focal_length)
193
+ pm = torch.from_numpy(pred_mask).to(device=device, dtype=pred_vertices.dtype)
194
+ gm = torch.from_numpy(gt_mask_r).to(device=device, dtype=pred_vertices.dtype)
195
+ per_item_losses.append(self.l1(pm, gm))
196
+ except Exception:
197
+ # ignore render failure for this sample
198
+ continue
199
+
200
+ if len(per_item_losses) == 0:
201
+ return torch.tensor(0., device=device, dtype=pred_vertices.dtype)
202
+
203
+ return torch.stack(per_item_losses).mean()
204
+
205
+ class Keypoint2DLoss(nn.Module):
206
+
207
+ def __init__(self, loss_type: str = 'l1'):
208
+ """
209
+ 2D keypoint loss module.
210
+ Args:
211
+ loss_type (str): Choose between l1 and l2 losses.
212
+ """
213
+ super(Keypoint2DLoss, self).__init__()
214
+ if loss_type == 'l1':
215
+ self.loss_fn = nn.L1Loss(reduction='none')
216
+ elif loss_type == 'l2':
217
+ self.loss_fn = nn.MSELoss(reduction='none')
218
+ else:
219
+ raise NotImplementedError('Unsupported loss function')
220
+
221
+ def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor:
222
+ """
223
+ Compute 2D reprojection loss on the keypoints.
224
+ Args:
225
+ pred_keypoints_2d (torch.Tensor): Tensor of shape [B, N, 2] containing projected 2D keypoints (B: batch_size, N: num_keypoints)
226
+ gt_keypoints_2d (torch.Tensor): Tensor of shape [B, N, 3] containing the ground truth 2D keypoints and confidence.
227
+ Returns:
228
+ torch.Tensor: 2D keypoint loss.
229
+ """
230
+ conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
231
+ batch_size = conf.shape[0]
232
+ loss = (conf * self.loss_fn(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).sum(dim=(1, 2))
233
+ return loss.sum()
234
+
235
+
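The 2D keypoint loss above weights each keypoint's error by its ground-truth confidence and then sums, rather than averages, over keypoints. A small pure-Python sketch of the L1 variant, assuming per-keypoint tuples instead of tensors (the function name is illustrative only):

```python
def weighted_l1_keypoint_loss(pred, gt):
    """pred: list of (x, y); gt: list of (x, y, conf). Returns the summed loss."""
    total = 0.0
    for (px, py), (gx, gy, conf) in zip(pred, gt):
        # A confidence of 0 removes the keypoint from the loss entirely.
        total += conf * (abs(px - gx) + abs(py - gy))
    return total


pred = [(0.0, 0.0), (1.0, 1.0)]
gt = [(1.0, 0.0, 1.0), (5.0, 5.0, 0.0)]  # second keypoint is unlabeled
print(weighted_l1_keypoint_loss(pred, gt))
```

Because the result is a sum, its magnitude grows with the number of labeled keypoints; any normalization is left to the caller.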
236
+ class Keypoint3DLoss(nn.Module):
237
+
238
+ def __init__(self, loss_type: str = 'l1'):
239
+ """
240
+ 3D keypoint loss module.
241
+ Args:
242
+ loss_type (str): Choose between l1 and l2 losses.
243
+ """
244
+ super(Keypoint3DLoss, self).__init__()
245
+ if loss_type == 'l1':
246
+ self.loss_fn = nn.L1Loss(reduction='none')
247
+ elif loss_type == 'l2':
248
+ self.loss_fn = nn.MSELoss(reduction='none')
249
+ else:
250
+ raise NotImplementedError('Unsupported loss function')
251
+
252
+ def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 0):
253
+ """
254
+ Compute 3D keypoint loss.
255
+ Args:
256
+ pred_keypoints_3d (torch.Tensor): Tensor of shape [B, N, 3] containing the predicted 3D keypoints (B: batch_size, N: num_keypoints)
257
+ gt_keypoints_3d (torch.Tensor): Tensor of shape [B, N, 4] containing the ground truth 3D keypoints and confidence.
258
+ Returns:
259
+ torch.Tensor: 3D keypoint loss.
260
+ """
261
+ batch_size = pred_keypoints_3d.shape[0]
262
+ gt_keypoints_3d = gt_keypoints_3d.clone()
263
+ pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1)
264
+ gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, pelvis_id, :-1].unsqueeze(dim=1)
265
+ conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
266
+ gt_keypoints_3d = gt_keypoints_3d[:, :, :-1]
267
+ loss = (conf * self.loss_fn(pred_keypoints_3d, gt_keypoints_3d)).sum(dim=(1, 2))
268
+ return loss.sum()
269
+
270
+
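Before comparing 3D keypoints, the module above subtracts the root joint (`pelvis_id`) from both prediction and ground truth, so the loss ignores global translation. A minimal sketch of that centering step in plain Python, with a hypothetical helper name:

```python
def center_on_root(keypoints, root_id=0):
    """Express 3D keypoints relative to the chosen root joint."""
    rx, ry, rz = keypoints[root_id]
    return [(x - rx, y - ry, z - rz) for x, y, z in keypoints]


joints = [(1.0, 2.0, 3.0), (2.0, 4.0, 6.0)]
print(center_on_root(joints))
```

After centering, two skeletons that differ only by a rigid shift produce identical coordinates, and therefore zero loss.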
271
+ class ParameterLoss(nn.Module):
272
+
273
+ def __init__(self):
274
+ """
275
+ SMAL parameter loss module.
276
+ """
277
+ super(ParameterLoss, self).__init__()
278
+ self.loss_fn = nn.MSELoss(reduction='none')
279
+
280
+ def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor):
281
+ """
282
+ Compute SMAL parameter loss.
283
+ Args:
284
+ pred_param (torch.Tensor): Tensor of shape [B, S, ...] containing the predicted parameters (body pose / global orientation / betas)
285
+ gt_param (torch.Tensor): Tensor of shape [B, S, ...] containing the ground truth SMAL parameters.
286
+ Returns:
287
+ torch.Tensor: L2 parameter loss.
288
+ """
289
+ mask = torch.ones_like(pred_param, device=pred_param.device, dtype=pred_param.dtype)
290
+ batch_size = pred_param.shape[0]
291
+ num_dims = len(pred_param.shape)
292
+ mask_dimension = [batch_size] + [1] * (num_dims - 1)
293
+ has_param = has_param.type(pred_param.type()).view(*mask_dimension)
294
+ loss_param = (has_param * self.loss_fn(pred_param*mask, gt_param*mask))
295
+ return loss_param.sum()
296
+
297
+
298
+ class PosePriorLoss(nn.Module):
299
+ def __init__(self, path_prior):
300
+ super(PosePriorLoss, self).__init__()
301
+ with open(path_prior, "rb") as f:
302
+ data_prior = pickle.load(f, encoding="latin1")
303
+
304
+ self.register_buffer("mean_pose", torch.from_numpy(data_prior["mean_pose"]).float())
305
+ self.register_buffer("precs", torch.from_numpy(np.array(data_prior["pic"])).float())
306
+
307
+ use_index = np.ones(105, dtype=bool)
308
+ use_index[:3] = False # global rotation set False
309
+ self.register_buffer("use_index", torch.from_numpy(use_index).float())
310
+
311
+ def forward(self, x, has_gt):
312
+ """
313
+ Args:
314
+ x: (batch_size, 35, 3, 3)
315
+ has_gt: has pose?
316
+ Returns:
317
+ pose prior loss
318
+ """
319
+ if has_gt.sum() == len(has_gt):
320
+ return torch.tensor(0.0, dtype=x.dtype, device=x.device)
321
+ has_gt = has_gt.type(torch.bool)
322
+ x = x[~has_gt]
323
+ x = matrix_to_axis_angle(x.reshape(-1, 3, 3))
324
+ delta = x.reshape(-1, 35*3) - self.mean_pose
325
+ loss = torch.tensordot(delta, self.precs, dims=([1], [0])) * self.use_index
326
+ return (loss ** 2).mean()
327
+
328
+
329
+ class ShapePriorLoss(nn.Module):
330
+ def __init__(self, path_prior):
331
+ super(ShapePriorLoss, self).__init__()
332
+ with open(path_prior, "rb") as f:
333
+ data_prior = pickle.load(f, encoding="latin1")
334
+
335
+ model_covs = np.array(data_prior["cluster_cov"]) # shape: (5, 41, 41)
336
+ inverse_covs = np.stack(
337
+ [np.linalg.inv(model_cov + 1e-5 * np.eye(model_cov.shape[0])) for model_cov in model_covs],
338
+ axis=0)
339
+ prec = np.stack([np.linalg.cholesky(inverse_cov) for inverse_cov in inverse_covs], axis=0)
340
+
341
+ self.register_buffer("betas_prec", torch.FloatTensor(prec))
342
+ self.register_buffer("mean_betas", torch.FloatTensor(data_prior["cluster_means"]))
343
+
344
+ def forward(self, x, category, has_gt):
345
+ """
346
+ Args:
347
+ x: predicted betas (batch_size, 41)
348
+ category: animal category (batch_size,)
349
+ has_gt: has shape?
350
+ Returns:
351
+ shape prior loss
352
+ """
353
+ if has_gt.sum() == len(has_gt):
354
+ return torch.tensor(0.0, dtype=x.dtype, device=x.device)
355
+ has_gt = has_gt.type(torch.bool)
356
+ x, category = x[~has_gt], category[~has_gt]
357
+ delta = (x - self.mean_betas[category.long()]) # [batch_size, 41]
358
+ loss = []
359
+ for x0, c0 in zip(delta, category):
360
+ loss.append(torch.tensordot(x0, self.betas_prec[c0], dims=([0], [0])))
361
+ loss = torch.stack(loss, dim=0)
362
+ return (loss ** 2).mean()
363
+
364
+
365
+
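The shape prior above stores the Cholesky factor of each inverse covariance and multiplies the mean-centered betas by it. The squared norm of that product equals the squared Mahalanobis distance. The sketch below checks this identity on a 2x2 example in plain Python; the helper names are hypothetical, and the `1e-5` diagonal regularizer used in the module is omitted for brevity.

```python
import math


def inv2(m):
    """Inverse of a 2x2 matrix given as nested lists."""
    (a, b), (c, d) = m
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


def cholesky2(m):
    """Lower-triangular L with L @ L.T == m for a 2x2 SPD matrix."""
    l00 = math.sqrt(m[0][0])
    l10 = m[1][0] / l00
    l11 = math.sqrt(m[1][1] - l10 * l10)
    return [[l00, 0.0], [l10, l11]]


def whiten(delta, L):
    """Compute delta^T L, i.e. tensordot(delta, L, dims=([0], [0]))."""
    return [delta[0] * L[0][j] + delta[1] * L[1][j] for j in range(2)]


cov = [[2.0, 0.5], [0.5, 1.0]]
prec = inv2(cov)
L = cholesky2(prec)
delta = [1.0, -2.0]

sq_norm = sum(v * v for v in whiten(delta, L))
mahalanobis_sq = sum(delta[i] * prec[i][j] * delta[j]
                     for i in range(2) for j in range(2))
print(sq_norm, mahalanobis_sq)
```

Note that the module takes the mean of the squared entries rather than their sum, so its value is the squared Mahalanobis distance divided by the number of shape coefficients, averaged over the samples without ground truth.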
366
+ class PrototypeSupConLoss(nn.Module):
367
+ def __init__(self, prototypes_init, feat_dim=128, temperature=0.1):
368
+ """
369
+ prototypes_init: precomputed (5, 512) BioCLIP family prototypes (currently unused; the buffer below is randomly initialized)
370
+ feat_dim: dimension of the projected shape feature (128)
371
+ """
372
+ super().__init__()
373
+ self.temperature = temperature
374
+
375
+ # The prototypes should live in the projected feature space.
376
+ # A practical setup is to pass the BioCLIP centers through the projector
377
+ # once at the beginning of training to initialize these prototypes.
378
+ self.register_buffer("prototypes", torch.randn(5, feat_dim))
379
+
380
+ def forward(self, features, labels):
381
+ """
382
+ features: (B, 128) normalized shared features or shape features
383
+ labels: (B,) family indices for the 5-way classification setting
384
+ """
385
+ # 1. Ensure features are normalized.
386
+ features = F.normalize(features, p=2, dim=1)
387
+ # 2. Ensure prototypes are normalized as well.
388
+ prototypes = F.normalize(self.prototypes, p=2, dim=1)
389
+
390
+ # 3. Compute sample-to-prototype similarities with temperature scaling.
391
+ logits = torch.matmul(features, prototypes.T) / self.temperature
392
+
393
+ # 4. Cross-entropy pulls samples toward their family prototype and
394
+ # pushes them away from the other family prototypes.
395
+ loss = F.cross_entropy(logits, labels)
396
+
397
+ return loss
398
+
399
+ @torch.no_grad()
400
+ def update_prototypes(self, features, labels, momentum=0.999):
401
+ """
402
+ Optional: update prototypes with momentum during training so they
403
+ adapt gradually to the 3D task.
404
+ """
405
+ for i in range(5):
406
+ mask = (labels == i)
407
+ if mask.any():
408
+ new_mean = features[mask].mean(dim=0)
409
+ self.prototypes[i] = momentum * self.prototypes[i] + (1 - momentum) * new_mean
410
+
411
+
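The prototype loss above is a cross-entropy over temperature-scaled cosine similarities between a feature and each family prototype, and the optional update is an exponential moving average. A plain-Python sketch for a single sample, with hypothetical helper names and a clamp that mimics the default `eps` of `F.normalize`:

```python
import math


def normalize(v, eps=1e-12):
    """L2-normalize a vector, clamping the norm from below."""
    n = max(math.sqrt(sum(x * x for x in v)), eps)
    return [x / n for x in v]


def prototype_ce(feature, prototypes, label, temperature=0.1):
    """Cross-entropy of one feature against all prototypes."""
    f = normalize(feature)
    logits = [sum(a * b for a, b in zip(f, normalize(p))) / temperature
              for p in prototypes]
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum_exp - logits[label]


def ema_update(prototype, new_mean, momentum=0.999):
    """Momentum update of one prototype toward the current batch mean."""
    return [momentum * p + (1 - momentum) * n for p, n in zip(prototype, new_mean)]


protos = [[1.0, 0.0], [0.0, 1.0]]
loss_correct = prototype_ce([1.0, 0.0], protos, label=0)
loss_wrong = prototype_ce([1.0, 0.0], protos, label=1)
print(loss_correct, loss_wrong)
```

With a temperature of 0.1, a cosine gap of 1.0 between the matching and non-matching prototype becomes a logit gap of 10, which is why the loss for the correct label is close to zero here.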
412
+ class SupConLoss(nn.Module):
413
+ def __init__(self, temperature=0.1, contrast_mode='all',
414
+ base_temperature=0.07):
415
+ super(SupConLoss, self).__init__()
416
+ self.temperature = temperature
417
+ self.contrast_mode = contrast_mode
418
+ self.base_temperature = base_temperature
419
+
420
+ def forward(self, features, labels=None, mask=None):
421
+ """
422
+ Args:
423
+ features: hidden vector of shape [bsz, ...].
424
+ labels: ground truth of shape [bsz].
425
+ mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
426
+ has the same class as sample i. Can be asymmetric.
427
+ Returns:
428
+ A loss scalar.
429
+ """
430
+ features = torch.stack((features, features), dim=1)
431
+ device = features.device
432
+
433
+ if len(features.shape) < 3:
434
+ raise ValueError('`features` needs to be [bsz, n_views, ...],'
435
+ 'at least 3 dimensions are required')
436
+ if len(features.shape) > 3:
437
+ features = features.view(features.shape[0], features.shape[1], -1)
438
+
439
+ batch_size = features.shape[0]
440
+ if labels is not None and mask is not None:
441
+ raise ValueError('Cannot define both `labels` and `mask`')
442
+ elif labels is None and mask is None:
443
+ mask = torch.eye(batch_size, dtype=torch.float32).to(device)
444
+ elif labels is not None:
445
+ labels = labels.contiguous().view(-1, 1)
446
+ if labels.shape[0] != batch_size:
447
+ raise ValueError('Num of labels does not match num of features')
448
+ mask = torch.eq(labels, labels.T).float().to(device)
449
+ else:
450
+ mask = mask.float().to(device)
451
+
452
+ contrast_count = features.shape[1]
453
+ contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
454
+ if self.contrast_mode == 'one':
455
+ anchor_feature = features[:, 0]
456
+ anchor_count = 1
457
+ elif self.contrast_mode == 'all':
458
+ anchor_feature = contrast_feature
459
+ anchor_count = contrast_count
460
+ else:
461
+ raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
462
+
463
+ # compute logits
464
+ anchor_dot_contrast = torch.div(
465
+ torch.matmul(anchor_feature, contrast_feature.T),
466
+ self.temperature)
467
+ # for numerical stability
468
+ logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
469
+ logits = anchor_dot_contrast - logits_max.detach()
470
+
471
+ # tile mask
472
+ mask = mask.repeat(anchor_count, contrast_count)
473
+ # mask-out self-contrast cases
474
+ logits_mask = torch.scatter(
475
+ torch.ones_like(mask),
476
+ 1,
477
+ torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
478
+ 0
479
+ )
480
+ mask = mask * logits_mask
481
+
482
+ # compute log_prob
483
+ exp_logits = torch.exp(logits) * logits_mask
484
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
485
+
486
+ # compute mean of log-likelihood over positive
487
+ mask_pos_pairs = mask.sum(1)
488
+ mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
489
+ mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs
490
+
491
+ # loss
492
+ loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
493
+ loss = loss.view(anchor_count, batch_size).mean()
494
+
495
+ return loss
496
+
497
+ # Auxiliary intermediate-supervision loss module.
498
+
499
+ class InterLoss(nn.Module):
500
+ def __init__(self, cfg):
501
+ super().__init__()
502
+ self.cfg = cfg
503
+ self.use_intermediate_supervision = cfg.LOSS.get('USE_INTERMEDIATE_SUPERVISION', True)
504
+ self.intermediate_weight = cfg.LOSS.get('INTERMEDIATE_WEIGHT', 0.5)
505
+
506
+ # 2D keypoint loss
507
+ self.keypoint_2d_loss = nn.MSELoss(reduction='none')
508
+
509
+ # 3D keypoint loss
510
+ self.keypoint_3d_loss = nn.MSELoss(reduction='none')
511
+
512
+ def forward(self, predictions, gt_data):
513
+ """
514
+ Args:
515
+ predictions: model outputs (pred_smal_params, pred_cam, pred_smal_params_list)
516
+ gt_data: dict containing ground-truth data
517
+ - 'keypoints_2d': [B, N, 3] (x, y, visibility)
518
+ - 'keypoints_3d': [B, N, 3] (x, y, z) or [B, N, 4] (x, y, z, confidence)
519
+ """
520
+ pred_smal_params, pred_cam, pred_smal_params_list = predictions
521
+
522
+ losses = {}
523
+ total_loss = 0.0
524
+
525
+ # ========== Supervision for final predictions ==========
526
+ # Final keypoint supervision can be added here after running the
527
+ # predicted parameters through the SMAL model.
528
+
529
+ # ========== Supervision for intermediate predictions ==========
530
+ if self.use_intermediate_supervision and pred_smal_params_list is not None:
531
+
532
+ # 2D keypoint supervision
533
+ if 'keypoints_2d' in pred_smal_params_list and pred_smal_params_list['keypoints_2d'] is not None:
534
+ pred_kps_2d_all = pred_smal_params_list['keypoints_2d']
535
+ # [B*num_iters, N, 2]
536
+
537
+ gt_kps_2d = gt_data['keypoints_2d'][: , :, :2] # [B, N, 2]
538
+ gt_vis_2d = gt_data['keypoints_2d'][:, :, 2] # [B, N]
539
+
540
+ # Repeat the ground truth for each iteration.
541
+ num_iters = pred_kps_2d_all.shape[0] // gt_kps_2d.shape[0]
542
+ gt_kps_2d_repeated = gt_kps_2d.repeat(num_iters, 1, 1) # [B*num_iters, N, 2]
543
+ gt_vis_2d_repeated = gt_vis_2d.repeat(num_iters, 1) # [B*num_iters, N]
544
+
545
+ # Compute the loss only over visible keypoints.
546
+ loss_2d = self.keypoint_2d_loss(pred_kps_2d_all, gt_kps_2d_repeated)
547
+ loss_2d = loss_2d.mean(dim=-1) # [B*num_iters, N]
548
+ loss_2d = (loss_2d * gt_vis_2d_repeated).sum() / (gt_vis_2d_repeated.sum() + 1e-6)
549
+
550
+ losses['intermediate_keypoints_2d'] = loss_2d * self.intermediate_weight
551
+ total_loss += losses['intermediate_keypoints_2d']
552
+
553
+ # 3D keypoint supervision
554
+ if 'keypoints_3d' in pred_smal_params_list and pred_smal_params_list['keypoints_3d'] is not None:
555
+ pred_kps_3d_all = pred_smal_params_list['keypoints_3d']
556
+ # [B*num_iters, N, 3]
557
+
558
+ gt_kps_3d = gt_data['keypoints_3d'][: , :, :3] # [B, N, 3]
559
+ if gt_data['keypoints_3d'].shape[-1] == 4:
560
+ gt_conf_3d = gt_data['keypoints_3d'][:, :, 3] # [B, N]
561
+ else:
562
+ gt_conf_3d = torch.ones_like(gt_kps_3d[:, :, 0]) # All keypoints are valid.
563
+
564
+ # Repeat the ground truth for each iteration.
565
+ num_iters = pred_kps_3d_all.shape[0] // gt_kps_3d.shape[0]
566
+ gt_kps_3d_repeated = gt_kps_3d.repeat(num_iters, 1, 1)
567
+ gt_conf_3d_repeated = gt_conf_3d.repeat(num_iters, 1)
568
+
569
+ # Compute the loss.
570
+ loss_3d = self.keypoint_3d_loss(pred_kps_3d_all, gt_kps_3d_repeated)
571
+ loss_3d = loss_3d.mean(dim=-1) # [B*num_iters, N]
572
+ loss_3d = (loss_3d * gt_conf_3d_repeated).sum() / (gt_conf_3d_repeated.sum() + 1e-6)
573
+
574
+ losses['intermediate_keypoints_3d'] = loss_3d * self.intermediate_weight
575
+ total_loss += losses['intermediate_keypoints_3d']
576
+
577
+ # ... other losses (pose, shape, etc.) ...
578
+
579
+ losses['total'] = total_loss
580
+ return losses
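The intermediate supervision above tiles the ground truth once per refinement iteration and then takes a visibility-weighted mean of the per-keypoint MSE. The sketch below reproduces that reduction with nested lists, assuming the predictions are stacked iteration-major (all samples of iteration 0, then iteration 1, and so on), which is what `Tensor.repeat(num_iters, 1, 1)` implies. The function name is hypothetical.

```python
def visibility_weighted_mse(pred_all, gt, vis, eps=1e-6):
    """pred_all: num_iters * B samples; gt and vis: B samples each."""
    num_iters = len(pred_all) // len(gt)
    gt_rep = gt * num_iters    # list repetition tiles the batch per iteration
    vis_rep = vis * num_iters
    num, den = 0.0, 0.0
    for p_s, g_s, v_s in zip(pred_all, gt_rep, vis_rep):
        for p, g, v in zip(p_s, g_s, v_s):
            mse = sum((a - b) ** 2 for a, b in zip(p, g)) / len(p)
            num += v * mse
            den += v
    return num / (den + eps)


gt = [[(0.0, 0.0), (1.0, 1.0)]]   # B = 1, N = 2
vis = [[1.0, 0.0]]                # second keypoint is not visible
pred_all = [
    [(1.0, 1.0), (9.0, 9.0)],     # iteration 0
    [(0.0, 0.0), (9.0, 9.0)],     # iteration 1
]
result = visibility_weighted_mse(pred_all, gt, vis)
print(result)
```

The small `eps` in the denominator only guards against batches with no visible keypoints; it slightly lowers the result otherwise.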
prima/models/prima.py ADDED
@@ -0,0 +1,615 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ import torch
11
+ import pickle
12
+ import pytorch_lightning as pl
13
+ from typing import Any, Dict
14
+ from yacs.config import CfgNode
15
+
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import numpy as np
19
+
20
+ from torchvision.utils import make_grid
21
+ from ..utils.geometry import perspective_projection, aa_to_rotmat
22
+ from ..utils.pylogger import get_pylogger
23
+ from .backbones import create_backbone
24
+ from .heads import build_smal_head
25
+ from prima.models.smal_wrapper import SMAL
26
+ from .discriminator import Discriminator
27
+
28
+ from .bioclip_embedding import BioClipEmbedding
29
+ import sys
30
+ from transformers import AutoModel, AutoFeatureExtractor
31
+ import einops
32
+
33
+ import open_clip
34
+
35
+
36
+ from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss, ShapePriorLoss, PosePriorLoss, SupConLoss
37
+ log = get_pylogger(__name__)
38
+
39
+
40
+ class PRIMA(pl.LightningModule):
41
+
42
+ def __init__(self, cfg: CfgNode, init_renderer: bool = True):
43
+ """
44
+ Setup PRIMA model
45
+ Args:
46
+ cfg (CfgNode): Config file as a yacs CfgNode
47
+ """
48
+ super().__init__()
49
+
50
+ # Save hyperparameters
51
+ self.save_hyperparameters(logger=False, ignore=['init_renderer'])
52
+
53
+ self.cfg = cfg
54
+ # Create backbone feature extractor
55
+
56
+ if cfg.MODEL.BACKBONE.TYPE =='vith':
57
+ self.backbone = create_backbone(cfg) # create vit backbone anyway, for inference, no config loading, just load ckpt weights
58
+
59
+ if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None): # true only when PRETRAINED_WEIGHTS is set and non-empty
60
+
61
+ log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
62
+ state_dict = torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu', weights_only=True)['state_dict']
63
+ state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()}
64
+
65
+ missing_keys, unexpected_keys = self.backbone.load_state_dict(state_dict, strict=False)
66
+
67
+
68
+ # freeze backbones
69
+ if cfg.MODEL.BACKBONE.get('FREEZE', False) and cfg.MODEL.BACKBONE.TYPE == 'vith':
70
+ log.info('Freezing first 2/3 blocks of vit backbone')
71
+ # Freeze patch embedding
72
+ if hasattr(self.backbone, 'patch_embed'):
73
+ for p in self.backbone.patch_embed.parameters():
74
+ p.requires_grad = False
75
+
76
+ # Freeze first 2/3 of transformer blocks
77
+ if hasattr(self.backbone, 'blocks'):
78
+ total_blocks = len(self.backbone.blocks)
79
+ freeze_blocks = int(total_blocks * 2 / 3)
80
+ log.info(f'Freezing {freeze_blocks} out of {total_blocks} blocks')
81
+ for i in range(freeze_blocks):
82
+ for p in self.backbone.blocks[i].parameters():
83
+ p.requires_grad = False
84
+
85
+ # Create SMAL head (predicts SMAL params + perspective camera)
86
+ self.smal_head = build_smal_head(cfg)
87
+
88
+ # Instantiate SMAL model
89
+ smal_model_path = cfg.SMAL.MODEL_PATH
90
+ with open(smal_model_path, 'rb') as f:
91
+ smal_cfg = pickle.load(f, encoding="latin1")
92
+ self.smal = SMAL(**smal_cfg)
93
+
94
+ # create bioclip model for species classification token extraction
95
+ use_bioclip_embedding = cfg.MODEL.get('USE_BIOCLIP_EMBEDDING', False)
96
+ if use_bioclip_embedding:
97
+ bioclip_config = cfg.MODEL.get('BIOCLIP_EMBEDDING', {})
98
+ embed_dim = bioclip_config.get('EMBED_DIM', 1280)
99
+ self.bioclip_embedding = BioClipEmbedding(cfg, embed_dim=embed_dim)
100
+ # Freeze BioClip model by default
101
+ for param in self.bioclip_embedding.species_model.parameters():
102
+ param.requires_grad = False
103
+ else:
104
+ self.bioclip_embedding = None
105
+
106
+ # Create discriminator
107
+ self.discriminator = Discriminator()
108
+
109
+
110
+
111
+
112
+ # Define loss functions
113
+ self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
114
+ self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
115
+
116
+ if self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP2D', 0) > 0:
117
+ self.intermediate_kp2d_loss = Keypoint2DLoss(loss_type='l1')
118
+ if self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP3D', 0) > 0:
119
+ self.intermediate_kp3d_loss = Keypoint3DLoss(loss_type='l1')
120
+ self.smal_parameter_loss = ParameterLoss()
121
+ self.shape_prior_loss = ShapePriorLoss(path_prior=cfg.SMAL.SHAPE_PRIOR_PATH)
122
+ self.pose_prior_loss = PosePriorLoss(path_prior=cfg.SMAL.POSE_PRIOR_PATH)
123
+ self.supcon_loss = SupConLoss()
124
+
125
+
126
+ self.register_buffer('initialized', torch.tensor(False))
127
+
128
+ # init depth renderer for supervised training
129
+ # Setup renderer for visualization
130
+ if init_renderer:
131
+ from ..utils import MeshRenderer
132
+
133
+ self.mesh_renderer = MeshRenderer(self.cfg, faces=self.smal.faces.numpy())
134
+ else:
135
+ self.mesh_renderer = None
136
+
137
+ # Disable automatic optimization since we use adversarial training
138
+ self.automatic_optimization = False
139
+
140
+ def get_parameters(self):
141
+ all_params = list(self.smal_head.parameters())
142
+ if self.cfg.MODEL.BACKBONE.TYPE in ['vith', 'dinov2', 'dinov3']:
143
+ all_params += list(self.backbone.parameters())
144
+
145
+
146
+ if hasattr(self, 'keypoint_projection') and self.keypoint_projection is not None:
147
+ all_params += list(self.keypoint_projection.parameters())
148
+ if hasattr(self, 'bioclip_embedding') and self.bioclip_embedding is not None:
149
+ # Only add projection parameters as the model itself is frozen
150
+ all_params += list(self.bioclip_embedding.projection.parameters())
151
+ return all_params
152
+
153
+ def configure_optimizers(self):
154
+ """
155
+ Setup model and discriminator Optimizers
156
+ Returns:
157
+ Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
158
+ """
159
+ # Use separate learning rates only for vith backbone
160
+ if self.cfg.MODEL.BACKBONE.TYPE == 'vith':
161
+ # Separate backbone parameters and other parameters
162
+ backbone_params = []
163
+ other_params = []
164
+
165
+ # Collect backbone parameters
166
+ if hasattr(self, 'backbone'):
167
+ backbone_params = list(filter(lambda p: p.requires_grad, self.backbone.parameters()))
168
+
169
+ # Collect other parameters
170
+ other_params += list(self.smal_head.parameters())
171
+
172
+
173
+ if hasattr(self, 'keypoint_projection') and self.keypoint_projection is not None:
174
+ other_params += list(self.keypoint_projection.parameters())
175
+ if hasattr(self, 'bioclip_embedding') and self.bioclip_embedding is not None:
176
+ other_params += list(self.bioclip_embedding.projection.parameters())
177
+
178
+
179
+ # Filter only trainable parameters
180
+ other_params = list(filter(lambda p: p.requires_grad, other_params))
181
+
182
+ # Create parameter groups with different learning rates
183
+ param_groups = [
184
+ {'params': backbone_params, 'lr': self.cfg.TRAIN.LR / 10.0}, # Backbone: 1/10 lr
185
+ {'params': other_params, 'lr': self.cfg.TRAIN.LR} # Other modules: normal lr
186
+ ]
187
+
188
+ log.info('Using separate LR for vith backbone')
189
+ log.info(f'Backbone parameters: {len(backbone_params)}, lr={self.cfg.TRAIN.LR / 10.0}')
190
+ log.info(f'Other parameters: {len(other_params)}, lr={self.cfg.TRAIN.LR}')
191
+ else:
192
+ # Use same learning rate for all parameters
193
+ all_params = list(filter(lambda p: p.requires_grad, self.get_parameters()))
194
+ param_groups = [{'params': all_params, 'lr': self.cfg.TRAIN.LR}]
195
+ log.info(f'Using same LR for all parameters: {len(all_params)}, lr={self.cfg.TRAIN.LR}')
196
+
197
+ optimizer = torch.optim.AdamW(params=param_groups,
198
+ weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
199
+ if self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0:
200
+ optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
201
+ lr=self.cfg.TRAIN.LR,
202
+ weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
203
+ else:
204
+ return optimizer,
205
+
206
+ return optimizer, optimizer_disc
207
+
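For the `vith` backbone, `configure_optimizers` builds two parameter groups so that the backbone trains at one tenth of the base learning rate. A minimal sketch of that grouping follows, using placeholder strings instead of real parameters; the helper name is hypothetical. The resulting list of dicts has the per-group form that `torch.optim` optimizers accept as their `params` argument.

```python
def build_param_groups(backbone_params, other_params, base_lr):
    """Two groups: backbone at base_lr / 10, everything else at base_lr."""
    return [
        {"params": backbone_params, "lr": base_lr / 10.0},
        {"params": other_params, "lr": base_lr},
    ]


groups = build_param_groups(["backbone.w"], ["head.w", "head.b"], base_lr=1e-4)
for g in groups:
    print(len(g["params"]), g["lr"])
```

A lower backbone rate is a common choice when fine-tuning pretrained features alongside freshly initialized heads.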
208
+ def forward_step(self, batch: Dict, train: bool = False) -> Dict:
209
+ """
210
+ Run a forward step of the network
211
+ Args:
212
+ batch (Dict): Dictionary containing batch data
213
+ train (bool): Flag indicating whether it is training or validation mode
214
+ Returns:
215
+ Dict: Dictionary containing the regression output
216
+ """
217
+
218
+ # Use RGB image as input
219
+ x = batch['img'] # [B, 3, H, W]
220
+ batch_size = x.shape[0]
221
+
222
+ # Compute conditioning features using the backbone
223
+ if self.cfg.MODEL.BACKBONE.TYPE == 'vith': # ViT backbone returns a feature map such as [1, 1280, 12, 16]
224
+ conditioning_feats, cls = self.backbone(x[:, :, :, 32:-32]) # crop 32 px from each side of the width so the input is [256, 192]
225
+ # returns tensors of shape [B, D, Hp, Wp] and [B, D]
226
+ if conditioning_feats.ndim == 4:
227
+ # Flatten spatial dimensions into sequence dimension: [B, D, Hp, Wp] -> [B, Hp*Wp, D]
228
+ B, D, Hp, Wp = conditioning_feats.shape
229
+ conditioning_feats = conditioning_feats.permute(0, 2, 3, 1).reshape(B, Hp * Wp, D) # [B, Hp*Wp, D]
230
+
231
+
232
+ # add bioclip embedding if enabled
233
+ if self.bioclip_embedding is not None:
234
+ species_feature = self.bioclip_embedding(batch['img']) # [B, embed_dim]
235
+
236
+ # concatenate species feature to conditioning_feats along token dimension
237
+ if len(conditioning_feats.shape) == 3:
238
+ # Token-wise concatenation: add species_feature as a single token
239
+ # (B, embed_dim) -> (B, 1, embed_dim)
240
+ species_token = species_feature.unsqueeze(1) # (B, 1, embed_dim)
241
+ # Concatenate along token dimension: (B, num_tokens, C) + (B, 1, embed_dim) -> (B, num_tokens + 1, C or embed_dim)
242
+ # Note: This requires C == embed_dim for consistent feature dimensions
243
+ conditioning_feats = torch.cat([conditioning_feats, species_token], dim=1) # (B, num_tokens + 1, C)
244
+ else:
245
+ # If conditioning_feats is 2D (B, C), concat directly along feature dimension
246
+ conditioning_feats = torch.cat([conditioning_feats, species_feature], dim=-1)
247
+
248
+ # Predict SMAL parameters and camera
249
+ pred_smal_params, pred_cam, extra_outputs = self.smal_head(conditioning_feats)
250
+
251
+
252
+ # Store useful regression outputs to the output dict
253
+ output = {}
254
+
255
+ if 'shape_feat' in extra_outputs:
256
+ output['shape_feat'] = extra_outputs['shape_feat']
257
+
258
+ if 'init_betas' in extra_outputs:
259
+ output['init_betas'] = extra_outputs['init_betas'].reshape(batch_size, -1)
260
+
261
+
262
+ output['pred_cam'] = pred_cam # [B, 3]
263
+ output['pred_smal_params'] = {k: v.clone() for k, v in pred_smal_params.items()}
264
+
265
+
266
+
267
+ # Compute camera translation
268
+ focal_length = batch['focal_length']
269
+
270
+ pred_cam_t = torch.stack([
271
+ pred_cam[:, 1],
272
+ pred_cam[:, 2],
273
+ 2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)
274
+ ], dim=-1) # [B, 3]
275
+
276
+ output['pred_cam_t'] = pred_cam_t # [B, 3]
277
+ output['focal_length'] = focal_length # [B, 2]
278
+
279
+ # Compute model vertices, joints and the projected joints
280
+ pred_smal_params['global_orient'] = pred_smal_params['global_orient'].reshape(batch_size, -1, 3, 3)
281
+ pred_smal_params['pose'] = pred_smal_params['pose'].reshape(batch_size, -1, 3, 3)
282
+ pred_smal_params['betas'] = pred_smal_params['betas'].reshape(batch_size, -1)
283
+ smal_output = self.smal(**pred_smal_params, pose2rot=False)
284
+
285
+ pred_keypoints_3d = smal_output.joints
286
+ pred_vertices = smal_output.vertices
287
+ output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
288
+ output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
289
+
290
+ # project 3D keypoints to 2D
291
+ pred_keypoints_2d = perspective_projection(
292
+ pred_keypoints_3d,
293
+ translation=pred_cam_t,
294
+ focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE
295
+ ) # [B, num_joints, 2]
296
+ output['pred_keypoints_2d'] = pred_keypoints_2d
297
+
298
+ # get intermediate keypoint predictions if available
299
+
300
+ if 'keypoints_3d' in pred_smal_params and pred_smal_params['keypoints_3d'] is not None:
301
+ inter_keypoints_3d = pred_smal_params['keypoints_3d']
302
+ output['inter_keypoints_3d'] = inter_keypoints_3d.reshape(batch_size, -1, 3)
303
+ # output['use_intermediate_kp3d_loss'] = True
304
+
305
+ if 'keypoints_2d' in pred_smal_params and pred_smal_params['keypoints_2d'] is not None:
306
+ inter_keypoints_2d = pred_smal_params['keypoints_2d']
307
+ output['inter_keypoints_2d'] = inter_keypoints_2d.reshape(batch_size, -1, 2)
308
+ # output['use_intermediate_kp2d_loss'] = True
309
+
310
+ return output
311
+
312
+ def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
313
+ """
314
+ Compute losses given the input batch and the regression output
315
+ Args:
316
+ batch (Dict): Dictionary containing batch data
317
+ output (Dict): Dictionary containing the regression output
318
+ train (bool): Flag indicating whether it is training or validation mode
319
+ Returns:
320
+ torch.Tensor : Total loss for current batch
321
+ """
322
+
323
+ pred_smal_params = output['pred_smal_params']
324
+ pred_keypoints_2d = output['pred_keypoints_2d']
325
+ pred_keypoints_3d = output['pred_keypoints_3d']
326
+
327
+ if 'inter_keypoints_2d' in output:
328
+ inter_keypoints_2d = output['inter_keypoints_2d']
329
+ if 'inter_keypoints_3d' in output:
330
+ inter_keypoints_3d = output['inter_keypoints_3d']
331
+
332
+ batch_size = pred_smal_params['pose'].shape[0]
333
+ device = pred_smal_params['pose'].device
334
+ dtype = pred_smal_params['pose'].dtype
335
+
336
+ # Get annotations
337
+ gt_keypoints_2d = batch['keypoints_2d']
338
+ gt_keypoints_3d = batch['keypoints_3d']
339
+ gt_smal_params = batch['smal_params']
340
+ gt_mask = batch['mask']
341
+ has_smal_params = batch['has_smal_params']
342
+ is_axis_angle = batch['smal_params_is_axis_angle']
343
+ has_mask = batch['has_mask']
344
+
345
+ # Compute 2D keypoint loss
346
+ loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
347
+
348
+ # Compute 3D keypoint loss
349
+ loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
350
+
351
+ # Compute intermediate 2D keypoint loss if available
352
+ loss_intermediate_kp2d = torch.tensor(0., device=device, dtype=dtype)
353
+ if 'inter_keypoints_2d' in output:
354
+ loss_intermediate_kp2d = self.intermediate_kp2d_loss(inter_keypoints_2d, gt_keypoints_2d)
355
+ # loss_keypoints_2d = loss_keypoints_2d + loss_intermediate_kp2d
356
+
357
+ # Compute intermediate 3D keypoint loss if available
358
+ loss_intermediate_kp3d = torch.tensor(0., device=device, dtype=dtype)
359
+ if 'inter_keypoints_3d' in output:
360
+ loss_intermediate_kp3d = self.intermediate_kp3d_loss(inter_keypoints_3d, gt_keypoints_3d, pelvis_id=0)
361
+ # loss_keypoints_3d = loss_keypoints_3d + loss_intermediate_kp3d
362
+
363
+ # add intermediate keypoint losses if available
364
+
365
+ # Compute loss on SMAL parameters
366
+ loss_smal_params = {}
367
+ for k, pred in pred_smal_params.items():
368
+ # Skip keypoint predictions - they're handled separately
369
+ if k in ['keypoints_2d', 'keypoints_3d']:
370
+ continue
371
+
372
+ gt = gt_smal_params[k].view(batch_size, -1)
373
+ if is_axis_angle[k].all():
374
+ gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
375
+ has_gt = has_smal_params[k]
376
+
377
+ # Only compute parameter loss if ANY sample has GT
378
+ param_loss = self.smal_parameter_loss(pred.reshape(batch_size, -1),
379
+ gt.reshape(batch_size, -1),
380
+ has_gt)
381
+
382
+ if k == "betas":
383
+ # Only add shape prior loss if NOT all samples have GT (prior is regularization for samples without GT)
384
+ # But the shape_prior_loss already handles this check internally
385
+ loss_smal_params[k] = param_loss + self.shape_prior_loss(pred, batch["category"], has_gt)
386
+ if 'init_betas' in output:
387
+ init_betas = output['init_betas']
388
+ loss_smal_params[k] = loss_smal_params[k] + self.shape_prior_loss(init_betas, batch["category"], has_gt) / 2.
389
+
390
+ else:
391
+ # Only add pose prior loss if NOT all samples have GT
392
+ # The pose_prior_loss already handles this check internally
393
+ loss_smal_params[k] = param_loss + \
394
+ self.pose_prior_loss(torch.cat((pred_smal_params["global_orient"],
395
+ pred_smal_params["pose"]),
396
+ dim=1), has_gt) / 2.
397
+ if 'shape_feat' in output:
398
+ loss_supcon = self.supcon_loss(output['shape_feat'], labels=batch['category'])
399
+ else:
400
+ loss_supcon = torch.tensor(0., device=device, dtype=dtype)
401
+ loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d + \
402
+ self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d + \
403
+ sum([loss_smal_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_smal_params]) + \
404
+ self.cfg.LOSS_WEIGHTS['SUPCON'] * loss_supcon
405
+
406
+ if 'inter_keypoints_2d' in output:
407
+ loss = loss + self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP2D', 0) * loss_intermediate_kp2d
408
+ if 'inter_keypoints_3d' in output:
409
+ loss = loss + self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP3D', 0) * loss_intermediate_kp3d
410
+
411
+
412
+ losses = dict(loss=loss.detach(),
413
+ loss_keypoints_2d=loss_keypoints_2d.detach(),
414
+ loss_keypoints_3d=loss_keypoints_3d.detach(),
415
+ loss_supcon=loss_supcon.detach(),
416
+ )
417
+
418
+ for k, v in loss_smal_params.items():
419
+ losses['loss_' + k] = v.detach()
420
+
421
+ # attach intermediate keypoint losses if computed
422
+ if 'inter_keypoints_2d' in output:
423
+ losses['loss_inter_keypoints_2d'] = loss_intermediate_kp2d.detach()
424
+ if 'inter_keypoints_3d' in output:
425
+ losses['loss_inter_keypoints_3d'] = loss_intermediate_kp3d.detach()
426
+
427
+
428
+
429
+ output['losses'] = losses
430
+
431
+ return loss
432
+
433
+ def forward(self, batch: Dict) -> Dict:
434
+ """
435
+ Run a forward step of the network in val mode
436
+ Args:
437
+ batch (Dict): Dictionary containing batch data
438
+ Returns:
439
+ Dict: Dictionary containing the regression output
440
+ """
441
+ return self.forward_step(batch, train=False)
442
+
443
+ def training_step_discriminator(self, batch: Dict,
444
+ pose: torch.Tensor,
445
+ betas: torch.Tensor,
446
+ optimizer: torch.optim.Optimizer) -> torch.Tensor:
447
+ """
448
+ Run a discriminator training step
449
+ Args:
450
+ batch (Dict): Dictionary containing mocap batch data
451
+ pose (torch.Tensor): Regressed pose from current step
452
+ betas (torch.Tensor): Regressed betas from current step
453
+ optimizer (torch.optim.Optimizer): Discriminator optimizer
454
+ Returns:
455
+ torch.Tensor: Discriminator loss
456
+ """
457
+ batch_size = pose.shape[0]
458
+ gt_pose = batch['pose']
459
+ gt_betas = batch['betas']
460
+ gt_rotmat = aa_to_rotmat(gt_pose.view(-1, 3)).view(batch_size, -1, 3, 3)
461
+ disc_fake_out = self.discriminator(pose.detach(), betas.detach())
462
+ loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size
463
+ disc_real_out = self.discriminator(gt_rotmat.detach(), gt_betas.detach())
464
+ loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size
465
+ loss_disc = loss_fake + loss_real
466
+ loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc
467
+ optimizer.zero_grad()
468
+ self.manual_backward(loss)
469
+ optimizer.step()
470
+ return loss_disc.detach()
471
+
472
+ # TensorBoard logging should run on the first rank only
473
+ @pl.utilities.rank_zero.rank_zero_only
474
+ def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True,
475
+ write_to_summary_writer: bool = True) -> None:
476
+ """
477
+ Log results to Tensorboard
478
+ Args:
479
+ batch (Dict): Dictionary containing batch data
480
+ output (Dict): Dictionary containing the regression output
481
+ step_count (int): Global training step count
482
+ train (bool): Flag indicating whether it is training or validation mode
483
+ """
484
+
485
+ mode = 'train' if train else 'val'
486
+
487
+ images = batch['img']
488
+ gt_keypoints_2d = batch['keypoints_2d']
489
+ batch_size = images.shape[0]
490
+
491
+ # mul std then add mean
492
+ images = (images) * (torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1))
493
+ images = (images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1))
494
+
495
+ pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3)
496
+ losses = output['losses']
497
+ pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3)
498
+ pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2)
499
+
500
+ if write_to_summary_writer:
501
+ summary_writer = self.logger.experiment
502
+ for loss_name, val in losses.items():
503
+ summary_writer.add_scalar(mode + '/' + loss_name, val.detach().item(), step_count)
504
+ # if train is False:
505
+ # for metric_name, val in output['metric'].items():
506
+ # summary_writer.add_scalar(mode + '/' + metric_name, val, step_count)
507
+ num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)
508
+
509
+ predictions = self.mesh_renderer.visualize_tensorboard(pred_vertices[:num_images].cpu().numpy(),
510
+ pred_cam_t[:num_images].cpu().numpy(),
511
+ images[:num_images].cpu().numpy(),
512
+ self.cfg.SMAL.get("FOCAL_LENGTH", 1000),
513
+ pred_keypoints_2d[:num_images].cpu().numpy(),
514
+ gt_keypoints_2d[:num_images].cpu().numpy(),
515
+ pred_masks=output.get('pred_masks', None)[:num_images] if output.get('pred_masks', None) is not None else None,
516
+ gt_masks=output.get('gt_masks', None)[:num_images] if output.get('gt_masks', None) is not None else None,
517
+ )
518
+ predictions = make_grid(predictions, nrow=5, padding=2)
519
+ if write_to_summary_writer:
520
+ summary_writer.add_image('%s/predictions' % mode, predictions, step_count)
521
+
522
+ return predictions
523
+
524
+ def training_step(self, batch: Dict) -> Dict:
525
+ """
526
+ Run a full training step
527
+ Args:
528
+ batch (Dict): Dictionary containing {'img', 'mask', 'keypoints_2d', 'keypoints_3d', 'orig_keypoints_2d',
529
+ 'box_center', 'box_size', 'img_size', 'smal_params',
530
+ 'smal_params_is_axis_angle', '_trans', 'imgname', 'focal_length'}
531
+ Returns:
532
+ Dict: Dictionary containing regression output.
533
+ """
534
+ batch = batch['img']
535
+ optimizer = self.optimizers(use_pl_optimizer=True)
536
+ if self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0:
537
+ optimizer, optimizer_disc = optimizer
538
+
539
+ batch_size = batch['img'].shape[0]
540
+ output = self.forward_step(batch, train=True)
541
+ pred_smal_params = output['pred_smal_params']
542
+ loss = self.compute_loss(batch, output, train=True)
543
+ if self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0:
544
+ disc_out = self.discriminator(pred_smal_params['pose'].reshape(batch_size, -1),
545
+ pred_smal_params['betas'].reshape(batch_size, -1))
546
+ loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size
547
+ loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv
548
+
549
+ # Raise an error if the loss is NaN
550
+ if torch.isnan(loss):
551
+ raise ValueError('Loss is NaN')
552
+
553
+ optimizer.zero_grad()
554
+ self.manual_backward(loss)
555
+ # Clip gradient
556
+ if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
557
+ gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL,
558
+ error_if_nonfinite=True)
559
+ self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
560
+
561
+ # For compatibility
562
+ # if self.cfg.LOSS_WEIGHTS.ADVERSARIAL == 0:
563
+ # optimizer.param_groups[0]['capturable'] = True
564
+
565
+ optimizer.step()
566
+ if self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0:
567
+ loss_disc = self.training_step_discriminator(batch['smal_params'],
568
+ pred_smal_params['pose'].reshape(batch_size, -1),
569
+ pred_smal_params['betas'].reshape(batch_size, -1),
570
+ optimizer_disc)
571
+ output['losses']['loss_gen'] = loss_adv
572
+ output['losses']['loss_disc'] = loss_disc
573
+
574
+ if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
575
+ self.tensorboard_logging(batch, output, self.global_step, train=True)
576
+
577
+ # Log training loss to the logger so checkpoint callback can monitor it.
578
+ self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True,
579
+ logger=True, batch_size=batch_size, sync_dist=True)
580
+
581
+ return output
582
+
583
+ def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
584
+ """
585
+ Run a validation step and log to Tensorboard
586
+ Args:
587
+ batch (Dict): Dictionary containing batch data
588
+ batch_idx (int): Unused.
589
+ Returns:
590
+ Dict: Dictionary containing regression output.
591
+ """
592
+ # The validation dataloader yields the inner batch dict directly (not wrapped as {'img': loader}).
593
+ # Run forward, compute loss and log aggregated validation metrics so ModelCheckpoint can monitor them.
594
+ output = self.forward_step(batch, train=False)
595
+ # compute_loss will populate output['losses'] and return the scalar loss
596
+ loss = self.compute_loss(batch, output, train=False)
597
+
598
+ # Ensure losses dict is available
599
+ losses = output.get('losses', {})
600
+
601
+ # Log all validation losses to logger with on_epoch=True so checkpoint monitors epoch-level metric
602
+ for loss_name, val in losses.items():
603
+ # use prog_bar only for the main loss
604
+ prog = loss_name == 'loss'
605
+ # Log as 'val/<loss_name>' e.g. 'val/loss'
606
+ self.log(f'val/{loss_name}', val, on_step=False, on_epoch=True, prog_bar=prog, logger=True,
607
+ sync_dist=True)
608
+
609
+ # Periodically write images/other visuals to tensorboard
610
+ # Log visualizations on the first batch of each validation epoch
611
+ if batch_idx == 0:
612
+ # Use global_step for step count when logging validation visuals
613
+ self.tensorboard_logging(batch, output, self.global_step, train=False)
614
+
615
+ return output
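Review note on `forward_step` in `prima/models/prima.py`: the camera translation is derived from the predicted weak-perspective camera `(s, tx, ty)` as `tz = 2 * f / (IMAGE_SIZE * s + 1e-9)`. The following is a minimal pure-Python sketch of that one expression, for a single sample. The function name `weak_perspective_to_translation` and the example values are assumptions made for illustration and do not appear in the commit.

```python
def weak_perspective_to_translation(cam, focal_length, image_size, eps=1e-9):
    """Convert a weak-perspective camera (s, tx, ty) into a 3D translation.

    Mirrors the expression used in forward_step:
        tz = 2 * focal_length / (image_size * s + eps)
    The small eps guards against division by zero when s is 0.
    """
    s, tx, ty = cam
    tz = 2.0 * focal_length / (image_size * s + eps)
    return (tx, ty, tz)


# Hypothetical values: scale 1.0, focal length 1000 px, 256 px crop.
t = weak_perspective_to_translation((1.0, 0.1, -0.2), focal_length=1000.0, image_size=256)
```

A larger predicted scale `s` gives a smaller depth `tz`, so the animal is placed closer to the camera.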
prima/models/smal_wrapper.py ADDED
@@ -0,0 +1,134 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ import json
11
+ from torch import nn
12
+ import torch
13
+ import numpy as np
14
+ import pickle
15
+ import cv2
16
+ from typing import Optional, Tuple, NewType
17
+ from dataclasses import dataclass
18
+ import smplx
19
+ from smplx.lbs import vertices2joints, lbs
20
+ from smplx.utils import MANOOutput, to_tensor, ModelOutput
21
+ from smplx.vertex_ids import vertex_ids
22
+
23
+ Tensor = NewType('Tensor', torch.Tensor)
24
+ keypoint_vertices_idx = [[1068, 1080, 1029, 1226], [2660, 3030, 2675, 3038], [910], [360, 1203, 1235, 1230],
25
+ [3188, 3156, 2327, 3183], [1976, 1974, 1980, 856], [3854, 2820, 3852, 3858], [452, 1811],
26
+ [416, 235, 182], [2156, 2382, 2203], [829], [2793], [60, 114, 186, 59],
27
+ [2091, 2037, 2036, 2160], [384, 799, 1169, 431], [2351, 2763, 2397, 3127],
28
+ [221, 104], [2754, 2192], [191, 1158, 3116, 2165],
29
+ [28, 1109, 1110, 1111, 1835, 1836, 3067, 3068, 3069],
30
+ [498, 499, 500, 501, 502, 503], [2463, 2464, 2465, 2466, 2467, 2468],
31
+ [764, 915, 916, 917, 934, 935, 956], [2878, 2879, 2880, 2897, 2898, 2919, 3751],
32
+ [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762],
33
+ [0, 464, 465, 726, 1824, 2429, 2430, 2690]]
34
+
35
+ name2id35 = {'RFoot': 14, 'RFootBack': 24, 'spine1': 4, 'Head': 16, 'LLegBack3': 19, 'RLegBack1': 21, 'pelvis0': 1,
36
+ 'RLegBack3': 23, 'LLegBack2': 18, 'spine0': 3, 'spine3': 6, 'spine2': 5, 'Mouth': 32, 'Neck': 15,
37
+ 'LFootBack': 20, 'LLegBack1': 17, 'RLeg3': 13, 'RLeg2': 12, 'LLeg1': 7, 'LLeg3': 9, 'RLeg1': 11,
38
+ 'LLeg2': 8, 'spine': 2, 'LFoot': 10, 'Tail7': 31, 'Tail6': 30, 'Tail5': 29, 'Tail4': 28, 'Tail3': 27,
39
+ 'Tail2': 26, 'Tail1': 25, 'RLegBack2': 22, 'root': 0, 'LEar': 33, 'REar': 34, 'EndNose': 35, 'Chin': 36,
40
+ 'RightEarTip': 37, 'LeftEarTip': 38, 'LeftEye': 39, 'RightEye': 40}
41
+
42
+ @dataclass
43
+ class SMALOutput(ModelOutput):
44
+ betas: Optional[Tensor] = None
45
+ pose: Optional[Tensor] = None
46
+
47
+
48
+ class SMALLayer(nn.Module):
49
+ def __init__(self, num_betas=41, **kwargs):
50
+ super().__init__()
51
+ self.num_betas = num_betas
52
+ self.register_buffer("shapedirs", torch.from_numpy(np.array(kwargs['shapedirs'], dtype=np.float32))[:, :, :num_betas]) # [3889, 3, 41]
53
+ self.register_buffer("v_template", torch.from_numpy(np.array(kwargs['v_template']).astype(np.float32))) # [3889, 3]
54
+ self.register_buffer("posedirs", torch.from_numpy(np.array(kwargs['posedirs'], dtype=np.float32)).reshape(-1,
55
+ 34*9).T) # [34*9, 11667]
56
+ self.register_buffer("J_regressor", torch.from_numpy(kwargs['J_regressor'].toarray().astype(np.float32))) # [33, 3889]
57
+ self.register_buffer("lbs_weights", torch.from_numpy(np.array(kwargs['weights'], dtype=np.float32))) # [3889, 33]
58
+ self.register_buffer("faces", torch.from_numpy(np.array(kwargs['f'], dtype=np.int32))) # [7774, 3]
59
+
60
+ kintree_table = kwargs['kintree_table']
61
+ self.register_buffer("parents", torch.from_numpy(kintree_table[0].astype(np.int32)))
62
+
63
+ def forward(
64
+ self,
65
+ betas: Optional[Tensor] = None,
66
+ global_orient: Optional[Tensor] = None,
67
+ pose: Optional[Tensor] = None,
68
+ transl: Optional[Tensor] = None,
69
+ return_verts: bool = True,
70
+ return_full_pose: bool = False,
71
+ **kwargs):
72
+ """
73
+ Args:
74
+ betas: [batch_size, num_betas]
75
+ global_orient: [batch_size, 1, 3, 3]
76
+ pose: [batch_size, num_joints, 3, 3]
77
+ transl: [batch_size, 3]
78
+ return_verts:
79
+ return_full_pose:
80
+ **kwargs:
81
+ Returns:
82
+ """
83
+ device, dtype = self.v_template.device, self.v_template.dtype # betas may be None here, so read device/dtype from a registered buffer
84
+ if global_orient is None:
85
+ batch_size = 1
86
+ global_orient = torch.eye(3, device=device, dtype=dtype).view(
87
+ 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
88
+ else:
89
+ batch_size = global_orient.shape[0]
90
+ if pose is None:
91
+ pose = torch.eye(3, device=device, dtype=dtype).view(
92
+ 1, 1, 3, 3).expand(batch_size, 34, -1, -1).contiguous()
93
+ if betas is None:
94
+ betas = torch.zeros(
95
+ [batch_size, self.num_betas], dtype=dtype, device=device)
96
+ if transl is None:
97
+ transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)
98
+
99
+ full_pose = torch.cat([global_orient, pose], dim=1)
100
+ vertices, joints = lbs(betas, full_pose, self.v_template,
101
+ self.shapedirs, self.posedirs,
102
+ self.J_regressor, self.parents,
103
+ self.lbs_weights, pose2rot=False)
104
+
105
+ if transl is not None:
106
+ joints = joints + transl.unsqueeze(dim=1)
107
+ vertices = vertices + transl.unsqueeze(dim=1)
108
+
109
+ output = SMALOutput(
110
+ vertices=vertices if return_verts else None,
111
+ joints=joints if return_verts else None,
112
+ betas=betas,
113
+ global_orient=global_orient,
114
+ pose=pose,
115
+ transl=transl,
116
+ full_pose=full_pose if return_full_pose else None,
117
+ )
118
+ return output
119
+
120
+
121
+ class SMAL(SMALLayer):
122
+ def __init__(self, **kwargs):
123
+ super(SMAL, self).__init__(**kwargs)
124
+
125
+ def forward(self, *args, **kwargs):
126
+ smal_output = super(SMAL, self).forward(**kwargs)
127
+
128
+ keypoint = []
129
+ for kp_v in keypoint_vertices_idx:
130
+ keypoint.append(smal_output.vertices[:, kp_v, :].mean(dim=1))
131
+ smal_output.joints = torch.stack(keypoint, dim=1)
132
+ return smal_output
133
+
134
+
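Review note on `SMAL.forward` in `prima/models/smal_wrapper.py`: the joints returned by the parent layer are replaced with keypoints, each computed as the mean of a small group of mesh vertices listed in `keypoint_vertices_idx`. The sketch below shows the same averaging on a toy mesh using plain Python lists. The function name `keypoints_from_vertices` and the toy data are assumptions for illustration; the real code operates on batched tensors.

```python
def keypoints_from_vertices(vertices, groups):
    """Average the vertices in each index group to get one keypoint per group."""
    keypoints = []
    for group in groups:
        n = len(group)
        # Mean over the selected vertices, one coordinate axis at a time.
        keypoints.append(tuple(sum(vertices[i][axis] for i in group) / n for axis in range(3)))
    return keypoints


# Toy mesh with four vertices and three hypothetical keypoint groups.
toy_vertices = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (0.0, 4.0, 0.0), (1.0, 1.0, 1.0)]
toy_groups = [[0, 1], [2], [0, 1, 2, 3]]
toy_keypoints = keypoints_from_vertices(toy_vertices, toy_groups)
```

Because the groups have different sizes, the tensor version in the commit loops over the groups and stacks the per-group means rather than using a single gather.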
prima/utils/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from typing import Any
11
+
12
+
13
+ def recursive_to(x: Any, target: Any):
14
+ """
15
+ Recursively transfer a batch of data to the target device
16
+ Args:
17
+ x (Any): Batch of data.
18
+ target (torch.device): Target device.
19
+ Returns:
20
+ Batch of data where all tensors are transferred to the target device.
21
+ """
22
+ import torch
23
+
24
+ def move(value: Any):
25
+ if isinstance(value, dict):
26
+ return {k: move(v) for k, v in value.items()}
27
+ if isinstance(value, torch.Tensor):
28
+ return value.to(target)
29
+ if isinstance(value, list):
30
+ return [move(i) for i in value]
31
+ return value
32
+
33
+ return move(x)
34
+
35
+
36
+ def __getattr__(name: str):
37
+ if name == "MeshRenderer":
38
+ from .mesh_renderer import MeshRenderer
39
+
40
+ globals()[name] = MeshRenderer
41
+ return MeshRenderer
42
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
43
+
44
+
45
+ __all__ = ["MeshRenderer", "recursive_to"]
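Review note on `recursive_to` in `prima/utils/__init__.py`: the helper walks dicts and lists, applies `.to(target)` to tensors, and returns every other value unchanged (including tuples, which are not traversed). The sketch below reproduces that traversal without a PyTorch dependency by taking the leaf test and the leaf transform as arguments. The name `map_nested` and both callables are hypothetical and exist only for this illustration.

```python
from typing import Any, Callable


def map_nested(x: Any, is_leaf: Callable[[Any], bool], fn: Callable[[Any], Any]) -> Any:
    """Apply fn to every leaf found inside nested dicts and lists."""
    if isinstance(x, dict):
        return {k: map_nested(v, is_leaf, fn) for k, v in x.items()}
    if is_leaf(x):
        return fn(x)
    if isinstance(x, list):
        return [map_nested(i, is_leaf, fn) for i in x]
    # Anything else (strings, tuples, None, ...) is passed through untouched.
    return x


# Floats stand in for tensors; doubling stands in for `.to(device)`.
batch = {"a": 1.5, "b": [2.5, "name"], "c": (3.5,)}
moved = map_nested(batch, lambda v: isinstance(v, float), lambda v: v * 2)
```

Note that the float inside the tuple under `"c"` is left as is, which matches how the original helper treats tuples.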
prima/utils/detection.py ADDED
@@ -0,0 +1,118 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ # Utilities for filtering animal detections before PRIMA demo inference.
13
+ #
14
+ # Detectron2 may return both a full-animal box and a local/partial box for the
15
+ # same animal. These helpers keep the demo pipeline from rendering the same
16
+ # animal multiple times.
17
+
18
+ from typing import Iterable
19
+
20
+ import numpy as np
21
+
22
+ ANIMAL_COCO_IDS = (15, 16, 17, 18, 19, 21, 22)
23
+
24
+
25
+ def _box_areas(boxes: np.ndarray) -> np.ndarray:
26
+ widths = np.maximum(0.0, boxes[:, 2] - boxes[:, 0])
27
+ heights = np.maximum(0.0, boxes[:, 3] - boxes[:, 1])
28
+ return widths * heights
29
+
30
+
31
+ def _intersection_areas(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
32
+ x1 = np.maximum(box[0], boxes[:, 0])
33
+ y1 = np.maximum(box[1], boxes[:, 1])
34
+ x2 = np.minimum(box[2], boxes[:, 2])
35
+ y2 = np.minimum(box[3], boxes[:, 3])
36
+ return np.maximum(0.0, x2 - x1) * np.maximum(0.0, y2 - y1)
37
+
38
+
39
+ def _suppress_duplicate_boxes(
40
+ boxes: np.ndarray,
41
+ scores: np.ndarray,
42
+ *,
43
+ iou_threshold: float,
44
+ containment_threshold: float,
45
+ ) -> np.ndarray:
46
+ if len(boxes) <= 1:
47
+ return np.arange(len(boxes), dtype=np.int64)
48
+
49
+ boxes = boxes.astype(np.float32, copy=False)
50
+ scores = scores.astype(np.float32, copy=False)
51
+ areas = _box_areas(boxes)
52
+
53
+ contained = np.zeros(len(boxes), dtype=bool)
54
+ for idx, area in enumerate(areas):
55
+ if area <= 0:
56
+ contained[idx] = True
57
+ continue
58
+ larger = np.where(areas > area)[0]
59
+ if len(larger) == 0:
60
+ continue
61
+ covered = _intersection_areas(boxes[idx], boxes[larger]) / area
62
+ if np.any(covered >= containment_threshold):
63
+ contained[idx] = True
64
+
65
+ candidates = np.where(~contained)[0]
66
+ if len(candidates) <= 1:
67
+ return candidates
68
+
69
+ order = candidates[np.argsort(scores[candidates])[::-1]]
70
+ keep = []
71
+ while len(order) > 0:
72
+ current = order[0]
73
+ keep.append(current)
74
+ rest = order[1:]
75
+ if len(rest) == 0:
76
+ break
77
+
78
+ inter = _intersection_areas(boxes[current], boxes[rest])
79
+ union = areas[current] + areas[rest] - inter
80
+ iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
81
+ order = rest[iou <= iou_threshold]
82
+
83
+ return np.array(sorted(keep), dtype=np.int64)
84
+
85
+
86
+ def select_animal_boxes(
87
+ det_instances,
88
+ *,
89
+ animal_class_ids: Iterable[int] = ANIMAL_COCO_IDS,
90
+ score_threshold: float = 0.7,
91
+ iou_threshold: float = 0.5,
92
+ containment_threshold: float = 0.9,
93
+ ) -> tuple[np.ndarray, int]:
94
+ """Return filtered animal boxes and the number of duplicate boxes removed."""
95
+ class_ids = set(int(class_id) for class_id in animal_class_ids)
96
+ classes = det_instances.pred_classes.detach().cpu().numpy()
97
+ scores = det_instances.scores.detach().cpu().numpy()
98
+
99
+ valid_idx = np.array(
100
+ [
101
+ i
102
+ for i, (class_id, score) in enumerate(zip(classes, scores))
103
+ if int(class_id) in class_ids and float(score) > float(score_threshold)
104
+ ],
105
+ dtype=np.int64,
106
+ )
107
+ if len(valid_idx) == 0:
108
+ return np.zeros((0, 4), dtype=np.float32), 0
109
+
110
+ boxes = det_instances.pred_boxes.tensor[valid_idx].detach().cpu().numpy()
111
+ scores = scores[valid_idx]
112
+ keep = _suppress_duplicate_boxes(
113
+ boxes,
114
+ scores,
115
+ iou_threshold=iou_threshold,
116
+ containment_threshold=containment_threshold,
117
+ )
118
+ return boxes[keep], int(len(boxes) - len(keep))
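Review note on `_suppress_duplicate_boxes` in `prima/utils/detection.py`: the function works in two stages. It first drops boxes that are degenerate or mostly covered by a strictly larger box, then runs greedy IoU suppression in descending score order. The sketch below restates the two stages in plain Python for `(x1, y1, x2, y2)` boxes. It is a simplified illustration rather than the committed implementation: the helper names are assumptions, and it omits the early return for inputs with at most one box.

```python
def area(b):
    return max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1])


def intersection(a, b):
    w = min(a[2], b[2]) - max(a[0], b[0])
    h = min(a[3], b[3]) - max(a[1], b[1])
    return max(0.0, w) * max(0.0, h)


def suppress(boxes, scores, iou_threshold=0.5, containment_threshold=0.9):
    areas = [area(b) for b in boxes]
    # Stage 1: drop empty boxes and boxes mostly covered by a strictly larger box.
    candidates = []
    for i, b in enumerate(boxes):
        if areas[i] <= 0:
            continue
        covered = any(
            areas[j] > areas[i]
            and intersection(b, boxes[j]) / areas[i] >= containment_threshold
            for j in range(len(boxes))
        )
        if not covered:
            candidates.append(i)
    # Stage 2: greedy suppression, highest score first.
    order = sorted(candidates, key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        cur = order.pop(0)
        keep.append(cur)
        remaining = []
        for j in order:
            inter = intersection(boxes[cur], boxes[j])
            union = areas[cur] + areas[j] - inter
            iou = inter / union if union > 0 else 0.0
            if iou <= iou_threshold:
                remaining.append(j)
        order = remaining
    return sorted(keep)


# Box 1 lies inside box 0; box 2 overlaps box 0 heavily; box 3 is separate.
boxes = [(0, 0, 100, 100), (10, 10, 50, 50), (5, 5, 105, 105), (200, 200, 300, 300)]
scores = [0.9, 0.95, 0.8, 0.85]
kept = suppress(boxes, scores)
```

In this example the partial box 1 is removed by the containment stage even though it has the highest score, which is the behaviour the module comment describes for full-animal versus partial detections.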
prima/utils/evaluate_metric.py ADDED
@@ -0,0 +1,206 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ import torch
11
+ import numpy as np
12
+ import open3d as o3d
13
+ from typing import Dict, List, Union
14
+ from pytorch3d.transforms import axis_angle_to_matrix
15
+
16
+
17
+ def compute_scale_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor:
18
+ """
19
+ Computes a scale transform (s) in a batched way that takes
20
+ a set of 3D points S1 (B, N, 3) closest to a set of 3D points S2 (B, N, 3).
21
+ Args:
22
+ S1 (torch.Tensor): First set of points of shape (B, N, 3).
23
+ S2 (torch.Tensor): Second set of points of shape (B, N, 3).
24
+ Returns:
25
+ (torch.Tensor): The first set of points after applying the scale transformation.
26
+ """
27
+
28
+ # 1. Remove mean.
29
+ mu1 = S1.mean(dim=1, keepdim=True)
30
+ mu2 = S2.mean(dim=1, keepdim=True)
31
+ X1 = S1 - mu1
32
+ X2 = S2 - mu2
33
+
34
+ # 2. Compute variance of X1 used for scale.
35
+ var1 = (X1 ** 2).sum(dim=(1, 2), keepdim=True)
36
+
37
+ # 3. Compute scale.
38
+ scale = (X2 * X1).sum(dim=(1, 2), keepdim=True) / var1
39
+
40
+ # 4. Apply scale transform.
41
+ S1_hat = scale * X1 + mu2
42
+
43
+ return S1_hat
44
+
45
+
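Review note on `compute_scale_transform` in `prima/utils/evaluate_metric.py`: after centering both point sets, the scale is the least-squares solution `s = sum(X2 * X1) / sum(X1 ** 2)`, and the result is `s * X1 + mu2`. The sketch below follows the same four steps for a single, unbatched point set using plain Python lists. The name `scale_align` and the toy points are assumptions for illustration.

```python
def scale_align(s1, s2):
    """Scale and re-center s1 so that it best matches s2 in the least-squares sense."""
    n = len(s1)
    # 1. Remove the mean of each point set.
    mu1 = [sum(p[a] for p in s1) / n for a in range(3)]
    mu2 = [sum(p[a] for p in s2) / n for a in range(3)]
    x1 = [[p[a] - mu1[a] for a in range(3)] for p in s1]
    x2 = [[p[a] - mu2[a] for a in range(3)] for p in s2]
    # 2. Sum of squares of the centered source points.
    var1 = sum(v * v for p in x1 for v in p)
    # 3. Least-squares scale.
    scale = sum(a * b for p, q in zip(x1, x2) for a, b in zip(p, q)) / var1
    # 4. Apply the scale and move to the target centroid.
    return [[scale * p[a] + mu2[a] for a in range(3)] for p in x1]


# The target is the source scaled by 2 and shifted by 5 along every axis.
source = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
target = [[2.0 * v + 5.0 for v in p] for p in source]
aligned = scale_align(source, target)
```

Unlike the similarity transform that follows in the file, this alignment does not estimate a rotation, so it only recovers the target exactly when the two sets differ by scale and translation.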
46
+ def compute_similarity_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor:
47
+ """
48
+ Computes a similarity transform (sR, t) in a batched way that takes
49
+ a set of 3D points S1 (B, N, 3) closest to a set of 3D points S2 (B, N, 3),
50
+ where R is a 3x3 rotation matrix, t 3x1 translation, s scale.
51
+ i.e. solves the orthogonal Procrustes problem.
52
+ Args:
53
+ S1 (torch.Tensor): First set of points of shape (B, N, 3).
54
+ S2 (torch.Tensor): Second set of points of shape (B, N, 3).
55
+ Returns:
56
+ (torch.Tensor): The first set of points after applying the similarity transformation.
57
+ """
58
+
59
+ batch_size = S1.shape[0]
60
+ S1 = S1.permute(0, 2, 1)
61
+ S2 = S2.permute(0, 2, 1)
62
+ # 1. Remove mean.
63
+ mu1 = S1.mean(dim=2, keepdim=True)
64
+ mu2 = S2.mean(dim=2, keepdim=True)
65
+ X1 = S1 - mu1
66
+ X2 = S2 - mu2
67
+
68
+ # 2. Compute variance of X1 used for scale.
69
+ var1 = (X1 ** 2).sum(dim=(1, 2))
70
+
71
+ # 3. The outer product of X1 and X2.
72
+ K = torch.matmul(X1.float(), X2.permute(0, 2, 1))
73
+
74
+ # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K.
75
+ U, s, V = torch.svd(K.float())
76
+ Vh = V.permute(0, 2, 1)
77
+
78
+ # Construct Z that fixes the orientation of R to get det(R)=1.
79
+ Z = torch.eye(U.shape[1], device=U.device).unsqueeze(0).repeat(batch_size, 1, 1).float()
80
+ Z[:, -1, -1] *= torch.sign(torch.linalg.det(torch.matmul(U.float(), Vh.float()).float()))
81
+
82
+ # Construct R.
83
+ R = torch.matmul(torch.matmul(V, Z), U.permute(0, 2, 1))
84
+
85
+ # 5. Recover scale.
86
+ trace = torch.matmul(R, K).diagonal(offset=0, dim1=-1, dim2=-2).sum(dim=-1)
87
+ scale = (trace / var1).unsqueeze(dim=-1).unsqueeze(dim=-1)
88
+
89
+ # 6. Recover translation.
90
+ t = mu2 - scale * torch.matmul(R.float(), mu1.float())
91
+
92
+ # 7. Apply the similarity transform to S1.
93
+ S1_hat = scale * torch.matmul(R.float(), S1.float()).float() + t
94
+
95
+ return S1_hat.permute(0, 2, 1)
96
+
97
+
98
+ def pointcloud(points: np.ndarray):
99
+ pcd = o3d.geometry.PointCloud()
100
+ points = o3d.utility.Vector3dVector(points)
101
+ pcd.points = points
102
+ return pcd
103
+
104
+
105
+ class Evaluator:
106
+ def __init__(self, smal_model, image_size: int=256, pelvis_ind: int = 7):
107
+ self.pelvis_ind = pelvis_ind
108
+ self.smal_model = smal_model
109
+ self.image_size = image_size
110
+
111
+ def compute_pck(self, output: Dict, batch: Dict, pck_threshold: Union[List, None]):
112
+ if pck_threshold is None or len(pck_threshold) == 0:
113
+ return torch.tensor([], dtype=torch.float32)
114
+
115
+ pred_keypoints_2d = output['pred_keypoints_2d'].detach().cpu()
116
+ gt_keypoints_2d = batch['keypoints_2d'].detach().cpu()
117
+
118
+ pred_keypoints_2d = (pred_keypoints_2d + 0.5) * self.image_size
119
+ conf = gt_keypoints_2d[:, :, -1]
120
+ gt_keypoints_2d = (gt_keypoints_2d[:, :, :-1] + 0.5) * self.image_size
121
+
122
+         if 'mask' in batch and batch['mask'] is not None:
+             seg_area = torch.sum(batch['mask'].detach().cpu().reshape(batch['mask'].shape[0], -1), dim=-1).unsqueeze(-1)
+         else:
+             seg_area = torch.tensor([self.image_size * self.image_size] * len(pred_keypoints_2d), dtype=torch.float32).unsqueeze(-1)
+
+         total_visible = torch.sum(conf, dim=-1).clamp_min(1e-6)  # (B,)
+         dist = torch.norm(pred_keypoints_2d - gt_keypoints_2d, dim=-1)  # (B, K)
+         norm_dist = dist / torch.sqrt(seg_area)  # (B, K)
+
+         thresholds = torch.tensor(pck_threshold, dtype=torch.float32).view(-1, 1, 1)  # (T, 1, 1)
+         hits = (norm_dist.unsqueeze(0) < thresholds).float()  # (T, B, K)
+         pcks = (hits * conf.unsqueeze(0)).sum(dim=-1) / total_visible.unsqueeze(0)  # (T, B)
+         return pcks.mean(dim=1)  # (T,)
+
+     def compute_pa_mpjpe(self, pred_joints, gt_joints):
+         S1_hat = compute_similarity_transform(pred_joints, gt_joints)
+         pa_mpjpe = torch.sqrt(((S1_hat - gt_joints) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * 1000
+         return pa_mpjpe.mean()
+
+     def compute_pa_mpvpe(self, gt_vertices: torch.Tensor, pred_vertices: torch.Tensor):
+         batch_size = pred_vertices.shape[0]
+         S1_hat = compute_similarity_transform(pred_vertices, gt_vertices)
+         pa_mpvpe = torch.sqrt(((S1_hat - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * 1000
+         return pa_mpvpe.mean()
+
+     def eval_3d(self, output: Dict, batch: Dict):
+         """
+         Evaluate the current batch.
+         Args:
+             output: model output
+             batch: model input
+         Returns:
+             Tuple (pa_mpjpe, pa_mpvpe), each scaled by 1000; (0., 0.) when the batch has no SMAL betas.
+         """
+         if batch['has_smal_params']["betas"].sum() == 0:
+             return 0., 0.
+
+         pred_keypoints_3d = output["pred_keypoints_3d"].detach()
+         pred_keypoints_3d = pred_keypoints_3d[:, None, :, :]
+         batch_size = pred_keypoints_3d.shape[0]
+         num_samples = pred_keypoints_3d.shape[1]
+         gt_keypoints_3d = batch['keypoints_3d'][:, :, :-1].unsqueeze(1).repeat(1, num_samples, 1, 1)
+         gt_vertices = self.smal_forward(batch)
+
+         # Align predictions and ground truth such that the pelvis location is at the origin
+         pred_keypoints_3d -= pred_keypoints_3d[:, :, [self.pelvis_ind]]
+         gt_keypoints_3d -= gt_keypoints_3d[:, :, [self.pelvis_ind]]
+
+         pa_mpjpe = self.compute_pa_mpjpe(pred_keypoints_3d.reshape(batch_size * num_samples, -1, 3),
+                                          gt_keypoints_3d.reshape(batch_size * num_samples, -1, 3))
+         pa_mpvpe = self.compute_pa_mpvpe(gt_vertices, output['pred_vertices'])
+         return pa_mpjpe, pa_mpvpe
+
+     def eval_2d(self, output: Dict, batch: Dict, pck_threshold: List[float]=[0.10, 0.15]):
+         pck = self.compute_pck(output, batch, pck_threshold=pck_threshold)
+         auc = self.compute_auc(batch, output)
+         return pck.tolist(), auc
+
+     def compute_auc(self, batch: Dict, output: Dict, threshold_min: float=0.0, threshold_max: float=1.0, steps: int=100):
+         thresholds = np.linspace(threshold_min, threshold_max, steps)
+         pck_curve = self.compute_pck(output, batch, thresholds.tolist()).numpy()  # (steps,)
+         norm_factor = threshold_max - threshold_min
+         auc = float(np.trapz(pck_curve, thresholds) / norm_factor)
+         return auc
+
+     def smal_forward(self, batch: Dict):
+         batch_size = batch['img'].shape[0]
+         # Shallow copy so the axis-angle entries stored in the batch are not overwritten in place.
+         smal_params = dict(batch['smal_params'])
+         smal_params['global_orient'] = axis_angle_to_matrix(smal_params['global_orient'].reshape(batch_size, -1)).unsqueeze(1)
+         smal_params['pose'] = axis_angle_to_matrix(smal_params['pose'].reshape(batch_size, -1, 3))
+         # The SMAL model only registers buffers (e.g. shapedirs) and has no trainable parameters,
+         # so self.smal_model.parameters() can be empty and calling next on it would raise StopIteration.
+         # Here we first try to get the device from parameters; if there are no parameters, fall back to buffers;
+         # if there are no buffers either, fall back to the device of the input batch.
+         try:
+             device = next(self.smal_model.parameters()).device
+         except StopIteration:
+             try:
+                 device = next(self.smal_model.buffers()).device
+             except StopIteration:
+                 device = batch['img'].device
+         smal_params = {k: v.to(device) for k, v in smal_params.items()}
+         with torch.no_grad():
+             smal_output = self.smal_model(**smal_params)
+         vertices = smal_output.vertices
+         return vertices
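
The PCK computation in `compute_pck` above can be illustrated for a single sample in plain Python. This is only a sketch under stated assumptions: `pck_single` and its arguments are hypothetical names that do not appear in the repository, and the real method works on batched tensors and averages over the batch.

```python
import math


def pck_single(pred, gt, conf, seg_area, thresholds):
    """Fraction of visible keypoints whose normalized error is below each threshold.

    pred, gt: lists of (x, y) pairs; conf: per-keypoint visibility weights;
    seg_area: segmentation (or image) area used for normalization.
    Hypothetical single-sample helper, not part of the repository.
    """
    scale = math.sqrt(seg_area)
    total_visible = max(sum(conf), 1e-6)  # same guard as clamp_min(1e-6)
    results = []
    for t in thresholds:
        hits = 0.0
        for (px, py), (gx, gy), c in zip(pred, gt, conf):
            norm_dist = math.hypot(px - gx, py - gy) / scale
            if norm_dist < t:
                hits += c  # invisible keypoints (c == 0) never count
        results.append(hits / total_visible)
    return results


if __name__ == "__main__":
    pred = [(0.0, 0.0), (12.0, 0.0), (0.0, 30.0)]
    gt = [(0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
    print(pck_single(pred, gt, [1, 1, 1], 10000.0, [0.10, 0.15, 0.5]))
```

Normalizing by the square root of the area makes the thresholds independent of how large the animal appears in the image.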
prima/utils/geometry.py ADDED
@@ -0,0 +1,115 @@
+ """
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+ Official implementation of the paper:
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+ Licensed under a modified MIT license
+ """
+
+ from typing import Optional
+ import torch
+ from torch.nn import functional as F
+
+
+ def aa_to_rotmat(theta: torch.Tensor):
+     """
+     Convert axis-angle representation to rotation matrix.
+     Works by first converting it to a quaternion.
+     Args:
+         theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
+     Returns:
+         torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
+     """
+     norm = torch.norm(theta + 1e-8, p=2, dim=1)
+     angle = torch.unsqueeze(norm, -1)
+     normalized = torch.div(theta, angle)
+     angle = angle * 0.5
+     v_cos = torch.cos(angle)
+     v_sin = torch.sin(angle)
+     quat = torch.cat([v_cos, v_sin * normalized], dim=1)
+     return quat_to_rotmat(quat)
+
+
+ def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
+     """
+     Convert quaternion representation to rotation matrix.
+     Args:
+         quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
+     Returns:
+         torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
+     """
+     norm_quat = quat
+     norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
+     w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
+
+     B = quat.size(0)
+
+     w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+     wx, wy, wz = w * x, w * y, w * z
+     xy, xz, yz = x * y, x * z, y * z
+
+     rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
+                           2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
+                           2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+     return rotMat
+
+
+ def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
+     """
+     Convert 6D rotation representation to 3x3 rotation matrix.
+     Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+     Args:
+         x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
+     Returns:
+         torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
+     """
+     x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous()
+     a1 = x[:, :, 0]
+     a2 = x[:, :, 1]
+     b1 = F.normalize(a1)
+     b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+     b3 = torch.cross(b1, b2, dim=1)
+     return torch.stack((b1, b2, b3), dim=-1)
+
+
+ def perspective_projection(points: torch.Tensor,
+                            translation: torch.Tensor,
+                            focal_length: torch.Tensor,
+                            camera_center: Optional[torch.Tensor] = None,
+                            rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
+     """
+     Computes the perspective projection of a set of 3D points.
+     Args:
+         points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
+         translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
+         focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
+         camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
+         rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
+     Returns:
+         torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
+     """
+     batch_size = points.shape[0]
+     if rotation is None:
+         rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
+     if camera_center is None:
+         camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
+     # Populate intrinsic camera matrix K.
+     K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
+     K[:, 0, 0] = focal_length[:, 0]
+     K[:, 1, 1] = focal_length[:, 1]
+     K[:, 2, 2] = 1.
+     K[:, :-1, -1] = camera_center
+
+     # Transform points
+     points = torch.einsum('bij,bkj->bki', rotation, points)
+     points = points + translation.unsqueeze(1)
+
+     # Apply perspective distortion
+     projected_points = points / points[:, :, -1].unsqueeze(-1)
+
+     # Apply camera intrinsics
+     projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
+
+     return projected_points[:, :, :-1]
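
To make the conventions in `prima/utils/geometry.py` concrete, here is a minimal plain-Python sketch of the axis-angle to quaternion to rotation-matrix path and of the pinhole projection, for a single rotation and a single point. The names `aa_to_rotmat_sketch` and `project_point` are hypothetical. Unlike `aa_to_rotmat`, which adds `1e-8` before taking the norm, the sketch handles the zero rotation with an explicit branch, and the projection assumes an identity camera rotation.

```python
import math


def aa_to_rotmat_sketch(aa):
    """Axis-angle (3 floats) -> 3x3 rotation matrix via a (w, x, y, z) quaternion."""
    angle = math.sqrt(sum(a * a for a in aa))
    if angle < 1e-8:
        w, x, y, z = 1.0, 0.0, 0.0, 0.0  # zero rotation: identity quaternion
    else:
        s = math.sin(angle / 2.0) / angle
        w, x, y, z = math.cos(angle / 2.0), aa[0] * s, aa[1] * s, aa[2] * s
    # Same matrix layout as quat_to_rotmat in the file above.
    return [
        [w * w + x * x - y * y - z * z, 2 * (x * y - w * z), 2 * (w * y + x * z)],
        [2 * (w * z + x * y), w * w - x * x + y * y - z * z, 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (w * x + y * z), w * w - x * x - y * y + z * z],
    ]


def project_point(point, translation, focal_length, camera_center=(0.0, 0.0)):
    """Pinhole projection of one 3D point, assuming an identity camera rotation."""
    X, Y, Z = (p + t for p, t in zip(point, translation))
    u = focal_length[0] * X / Z + camera_center[0]
    v = focal_length[1] * Y / Z + camera_center[1]
    return u, v


if __name__ == "__main__":
    print(aa_to_rotmat_sketch((0.0, 0.0, math.pi / 2)))  # quarter turn about z
    print(project_point((1.0, 2.0, 0.0), (0.0, 0.0, 10.0), (1000.0, 1000.0), (128.0, 128.0)))
```

Dividing by the translated depth before applying the intrinsics mirrors the order used in `perspective_projection`.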
prima/utils/mesh_renderer.py ADDED
@@ -0,0 +1,330 @@
+ """
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+ Official implementation of the paper:
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+ Licensed under a modified MIT license
+ """
+
+ import os
+ from ctypes.util import find_library
+
+ if 'PYOPENGL_PLATFORM' not in os.environ and os.uname().sysname != 'Darwin':
+     # Prefer EGL; PyOpenGL's OSMesa bindings can lack symbols required by pyrender.
+     os.environ['PYOPENGL_PLATFORM'] = 'egl' if find_library('EGL') else 'osmesa'
+ # .get() avoids a KeyError when the variable is still unset (e.g. on macOS).
+ if os.environ.get('PYOPENGL_PLATFORM') == 'egl':
+     os.environ.setdefault('EGL_PLATFORM', 'surfaceless')
+ import torch
+ from torchvision.utils import make_grid
+ import numpy as np
+ import pyrender
+ import trimesh
+ import cv2
+ import math
+ import torch.nn.functional as F
+ from typing import List, Tuple
+
+
+ def create_raymond_lights():
+     import pyrender
+     thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
+     phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])
+
+     nodes = []
+
+     for phi, theta in zip(phis, thetas):
+         xp = np.sin(theta) * np.cos(phi)
+         yp = np.sin(theta) * np.sin(phi)
+         zp = np.cos(theta)
+
+         z = np.array([xp, yp, zp])
+         z = z / np.linalg.norm(z)
+         x = np.array([-z[1], z[0], 0.0])
+         if np.linalg.norm(x) == 0:
+             x = np.array([1.0, 0.0, 0.0])
+         x = x / np.linalg.norm(x)
+         y = np.cross(z, x)
+
+         matrix = np.eye(4)
+         matrix[:3, :3] = np.c_[x, y, z]
+         nodes.append(pyrender.Node(
+             light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
+             matrix=matrix
+         ))
+
+     return nodes
+
+
+ def get_keypoints_rectangle(keypoints: np.ndarray, threshold: float) -> Tuple[float, float, float]:
+     """
+     Compute rectangle enclosing keypoints above the threshold.
+     Args:
+         keypoints (np.ndarray): Keypoint array of shape (N, 3).
+         threshold (float): Confidence visualization threshold.
+     Returns:
+         Tuple[float, float, float]: Rectangle width, height and area.
+     """
+     valid_ind = keypoints[:, -1] > threshold
+     if valid_ind.sum() > 0:
+         valid_keypoints = keypoints[valid_ind][:, :-1]
+         max_x = valid_keypoints[:, 0].max()
+         max_y = valid_keypoints[:, 1].max()
+         min_x = valid_keypoints[:, 0].min()
+         min_y = valid_keypoints[:, 1].min()
+         width = max_x - min_x
+         height = max_y - min_y
+         area = width * height
+         return width, height, area
+     else:
+         return 0, 0, 0
+
+
+ def render_keypoint(img: np.ndarray, keypoint: np.ndarray, threshold=0.1,
+                     use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0) -> np.ndarray:
+     if use_confidence and map_fn is not None:
+         thicknessCircleRatioRight = 1. / 50 * map_fn(keypoint[:, -1])
+     else:
+         thicknessCircleRatioRight = 1. / 50 * np.ones(keypoint.shape[0])
+
+     thicknessLineRatioWRTCircle = 0.75
+     if keypoint.shape[0] == 26:
+         pairs = [0, 24, 1, 24, 2, 24, 3, 14, 4, 15, 5, 16, 6, 17, 7, 18, 8, 12, 9, 13, 10, 7, 11, 7,
+                  12, 18, 13, 18, 14, 8, 15, 9, 16, 10, 17, 11, 18, 24, 19, 25, 20, 0, 21, 1, 22, 24,
+                  23, 24, 25, 7]
+     elif keypoint.shape[0] == 18:
+         pairs = [9, 8, 8, 2, 2, 3, 3, 4, 2, 0, 2, 1, 4, 5,
+                  5, 14, 14, 15, 4, 6, 6, 7, 7, 11, 11, 10,
+                  7, 13, 13, 12, 5, 16, 5, 17]
+     else:
+         raise ValueError("Keypoint shape not supported")
+     pairs = np.array(pairs).reshape(-1, 2) if pairs is not None else None
+     colors = [255., 0., 85.,
+               255., 0., 0.,
+               255., 85., 0.,
+               255., 170., 0.,
+               255., 255., 0.,
+               170., 255., 0.,
+               85., 255., 0.,
+               0., 255., 0.,
+               255., 0., 0.,
+               0., 255., 85.,
+               0., 255., 170.,
+               0., 255., 255.,
+               0., 170., 255.,
+               0., 85., 255.,
+               0., 0., 255.,
+               255., 0., 170.,
+               170., 0., 255.,
+               255., 0., 255.,
+               85., 0., 255.,
+               0., 0., 255.,
+               0., 0., 255.,
+               0., 0., 255.,
+               0., 255., 255.,
+               0., 255., 255.,
+               0., 255., 255.,
+               255., 225., 255.]
+     colors = np.array(colors).reshape(-1, 3)
+     poseScales = [1]
+
+     img_orig = img.copy()
+     # img is laid out as (H, W, C); the original read shape[1] and shape[2],
+     # which used the channel count as the image height.
+     height, width = img.shape[0], img.shape[1]
+     area = width * height
+
+     lineType = 8
+     shift = 0
+     numberColors = len(colors)
+     thresholdRectangle = 0.1
+
+     animal_width, animal_height, animal_area = get_keypoints_rectangle(keypoint, thresholdRectangle)
+     if animal_area > 0:
+         ratioAreas = min(1, max(animal_width / width, animal_height / height))
+         thicknessRatio = np.maximum(np.round(math.sqrt(area) * thicknessCircleRatioRight * ratioAreas), 2)
+         thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio))
+         thicknessLine = np.maximum(1, np.round(thicknessRatio * thicknessLineRatioWRTCircle))
+         radius = thicknessRatio / 2
+     else:
+         return img
+
+     img = np.ascontiguousarray(img.copy())
+     if pairs is not None:
+         for i, pair in enumerate(pairs):
+             index1, index2 = pair
+             if keypoint[index1, -1] > threshold and keypoint[index2, -1] > threshold:
+                 thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * poseScales[0]))
+                 colorIndex = index2
+                 color = colors[colorIndex % numberColors]
+                 keypoint1 = keypoint[index1, :-1].astype(np.int32)
+                 keypoint2 = keypoint[index2, :-1].astype(np.int32)
+                 cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()),
+                          thicknessLineScaled, lineType, shift)
+     for part in range(len(keypoint)):
+         faceIndex = part
+         if keypoint[faceIndex, -1] > threshold:
+             radiusScaled = int(round(radius[faceIndex] * poseScales[0]))
+             thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * poseScales[0]))
+             colorIndex = part
+             color = colors[colorIndex % numberColors]
+             center = keypoint[faceIndex, :-1].astype(np.int32)
+             cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled,
+                        lineType, shift)
+
+     return img
+
+
+ class MeshRenderer:
+
+     def __init__(self, cfg, faces=None):
+         self.cfg = cfg
+         self.img_res = cfg.MODEL.IMAGE_SIZE
+         self.renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res,
+                                                    viewport_height=self.img_res,
+                                                    point_size=1.0)
+
+         self.camera_center = [self.img_res // 2, self.img_res // 2]
+         self.faces = faces
+
+     def visualize(self, vertices, camera_translation, images, focal_length, nrow=3, padding=2):
+         images_np = np.transpose(images, (0, 2, 3, 1))
+         rend_imgs = []
+         for i in range(vertices.shape[0]):
+             rend_img = torch.from_numpy(np.transpose(
+                 self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=focal_length, side_view=False),
+                 (2, 0, 1))).float()
+             rend_img_side = torch.from_numpy(np.transpose(
+                 self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=focal_length, side_view=True),
+                 (2, 0, 1))).float()
+             rend_imgs.append(torch.from_numpy(images[i]))
+             rend_imgs.append(rend_img)
+             rend_imgs.append(rend_img_side)
+         rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding)
+         return rend_imgs
+
+     def visualize_tensorboard(self, vertices, camera_translation, images, focal_length, pred_keypoints, gt_keypoints,
+                               pred_masks=None, gt_masks=None):
+         images_np = np.transpose(images, (0, 2, 3, 1))
+         rend_imgs = []
+         pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1)
+         pred_keypoints = self.img_res * (pred_keypoints + 0.5)
+         gt_keypoints[:, :, :-1] = self.img_res * (gt_keypoints[:, :, :-1] + 0.5)
+         # keypoint_matches = [(1, 12), (2, 8), (3, 7), (4, 6), (5, 9),
+         # (6, 10), (7, 11), (8, 14), (9, 2), (10, 1), (11, 0), (12, 3), (13, 4), (14, 5)]
+         # rend_img_pytorch3d = self.render_by_pytorch3d(vertices, camera_translation,
+         # images_np, focal_length=self.focal_length)
+         for i in range(vertices.shape[0]):
+             rend_img = torch.from_numpy(np.transpose(
+                 self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=focal_length, side_view=False),
+                 (2, 0, 1))).float()
+             rend_img_side = torch.from_numpy(np.transpose(
+                 self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=focal_length, side_view=True),
+                 (2, 0, 1))).float()
+             keypoints = pred_keypoints[i]
+             pred_keypoints_img = render_keypoint(255 * images_np[i].copy(), keypoints) / 255
+             keypoints = gt_keypoints[i]
+             gt_keypoints_img = render_keypoint(255 * images_np[i].copy(), keypoints) / 255
+             rend_imgs.append(torch.from_numpy(images[i]))
+             rend_imgs.append(rend_img)
+             rend_imgs.append(rend_img_side)
+             if pred_masks is not None:
+                 rend_imgs.append(torch.from_numpy(pred_masks[i]))
+             if gt_masks is not None:
+                 rend_imgs.append(torch.from_numpy(gt_masks[i]))
+             rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2, 0, 1))
+             rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2, 0, 1))
+         return rend_imgs
+
+     def __call__(self, vertices, camera_translation, image, focal_length, text=None, resize=None, side_view=False,
+                  baseColorFactor=(1.0, 1.0, 0.9, 1.0), rot_angle=90):
+         renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
+                                               viewport_height=image.shape[0],
+                                               point_size=1.0)
+         material = pyrender.MetallicRoughnessMaterial(
+             metallicFactor=0.0,
+             alphaMode='OPAQUE',
+             baseColorFactor=baseColorFactor)
+
+         camera_translation_local = camera_translation.copy()
+         camera_translation_local[0] *= -1.
+
+         mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
+         if side_view:
+             rot = trimesh.transformations.rotation_matrix(
+                 np.radians(rot_angle), [0, 1, 0])
+             mesh.apply_transform(rot)
+         rot = trimesh.transformations.rotation_matrix(
+             np.radians(180), [1, 0, 0])
+         mesh.apply_transform(rot)
+         mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+
+         scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
+                                ambient_light=(0.3, 0.3, 0.3))
+         scene.add(mesh, 'mesh')
+
+         camera_pose = np.eye(4)
+         camera_pose[:3, 3] = camera_translation_local
+         camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
+         camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
+                                            cx=camera_center[0], cy=camera_center[1],
+                                            zfar=1000)
+         scene.add(camera, pose=camera_pose)
+
+         light_nodes = create_raymond_lights()
+         for node in light_nodes:
+             scene.add_node(node)
+
+         color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+         color = color.astype(np.float32) / 255.0
+         valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
+         if not side_view:
+             output_img = (color[:, :, :3] * valid_mask +
+                           (1 - valid_mask) * image)
+         else:
+             output_img = color[:, :, :3]
+         if resize is not None:
+             output_img = cv2.resize(output_img, resize)
+
+         output_img = output_img.astype(np.float32)
+         renderer.delete()
+         return output_img
+
+     def render_mask(self, vertices, camera_translation, focal_length, side_view=False, rot_angle=90):
+         """
+         Render only the visibility mask (alpha>0) of the mesh given vertices and camera translation.
+         Returns a single-channel float32 numpy array with values 0.0 or 1.0 with shape (H, W).
+         """
+         renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res,
+                                               viewport_height=self.img_res,
+                                               point_size=1.0)
+
+         mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
+         if side_view:
+             rot = trimesh.transformations.rotation_matrix(
+                 np.radians(rot_angle), [0, 1, 0])
+             mesh.apply_transform(rot)
+         rot = trimesh.transformations.rotation_matrix(
+             np.radians(180), [1, 0, 0])
+         mesh.apply_transform(rot)
+         mesh = pyrender.Mesh.from_trimesh(mesh)
+
+         scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
+                                ambient_light=(0.3, 0.3, 0.3))
+         scene.add(mesh, 'mesh')
+
+         camera_pose = np.eye(4)
+         camera_pose[:3, 3] = camera_translation
+         camera_center = [self.img_res / 2., self.img_res / 2.]
+         camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
+                                            cx=camera_center[0], cy=camera_center[1],
+                                            zfar=1000)
+         scene.add(camera, pose=camera_pose)
+
+         light_nodes = create_raymond_lights()
+         for node in light_nodes:
+             scene.add_node(node)
+
+         color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+         # alpha channel indicates visibility
+         mask = (color[:, :, -1] > 0).astype(np.float32)
+         renderer.delete()
+         return mask
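
Two small steps in `prima/utils/mesh_renderer.py` are easy to check by hand: mapping keypoints from the normalized range around zero to pixel coordinates (`img_res * (kp + 0.5)` in `visualize_tensorboard`) and computing the rectangle that encloses confident keypoints. The following plain-Python sketch uses hypothetical names (`to_pixels`, `keypoints_rectangle_sketch`) and lists instead of NumPy arrays; it is an illustration, not the repository's code.

```python
def to_pixels(keypoints_norm, img_res):
    """Map (x, y) pairs from [-0.5, 0.5] to pixel coordinates in [0, img_res]."""
    return [(img_res * (x + 0.5), img_res * (y + 0.5)) for x, y in keypoints_norm]


def keypoints_rectangle_sketch(keypoints, threshold):
    """Width, height and area of the box around keypoints with confidence > threshold.

    keypoints: list of (x, y, confidence) triples.
    """
    valid = [(x, y) for x, y, c in keypoints if c > threshold]
    if not valid:
        return 0.0, 0.0, 0.0
    xs = [p[0] for p in valid]
    ys = [p[1] for p in valid]
    width = max(xs) - min(xs)
    height = max(ys) - min(ys)
    return width, height, width * height


if __name__ == "__main__":
    print(to_pixels([(-0.5, 0.5), (0.0, 0.0)], 256))
    print(keypoints_rectangle_sketch([(10.0, 20.0, 1.0), (50.0, 80.0, 0.9), (500.0, 500.0, 0.0)], 0.1))
```

Low-confidence keypoints are excluded before the extent is measured, so a stray undetected joint does not inflate the drawing scale.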
prima/utils/misc.py ADDED
@@ -0,0 +1,211 @@
+ """
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+ Official implementation of the paper:
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+ Licensed under a modified MIT license
+ """
+
+ import time
+ import warnings
+ from importlib.util import find_spec
+ from pathlib import Path
+ from typing import Callable, List
+
+ import hydra
+ from omegaconf import DictConfig, OmegaConf
+ from pytorch_lightning import Callback
+ from pytorch_lightning.loggers import Logger
+ from pytorch_lightning.utilities import rank_zero_only
+
+ from . import pylogger, rich_utils
+
+ log = pylogger.get_pylogger(__name__)
+
+
+ def task_wrapper(task_func: Callable) -> Callable:
+     """Optional decorator that wraps the task function in extra utilities.
+
+     Makes multirun more resistant to failure.
+
+     Utilities:
+     - Calling the `utils.extras()` before the task is started
+     - Calling the `utils.close_loggers()` after the task is finished
+     - Logging the exception if occurs
+     - Logging the task total execution time
+     - Logging the output dir
+     """
+
+     def wrap(cfg: DictConfig):
+         start_time = time.time()
+         try:
+             # apply extra utilities
+             extras(cfg)
+
+             # execute the task
+             ret = task_func(cfg=cfg)
+         except Exception as ex:
+             log.exception("")  # save exception to `.log` file
+             raise ex
+         finally:
+             path = Path(cfg.paths.output_dir, "exec_time.log")
+             content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
+             save_file(path, content)  # save task execution time (even if exception occurs)
+             close_loggers()  # close loggers (even if exception occurs so multirun won't fail)
+
+         log.info(f"Output dir: {cfg.paths.output_dir}")
+
+         return ret
+
+     return wrap
+
+
+ def extras(cfg: DictConfig) -> None:
+     """Applies optional utilities before the task is started.
+
+     Utilities:
+     - Ignoring python warnings
+     - Setting tags from command line
+     - Rich config printing
+     """
+
+     # return if no `extras` config
+     if not cfg.get("extras"):
+         log.warning("Extras config not found! <cfg.extras=null>")
+         return
+
+     # disable python warnings
+     if cfg.extras.get("ignore_warnings"):
+         log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
+         warnings.filterwarnings("ignore")
+
+     # prompt user to input tags from command line if none are provided in the config
+     if cfg.extras.get("enforce_tags"):
+         log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
+         rich_utils.enforce_tags(cfg, save_to_file=True)
+
+     # pretty print config tree using Rich library
+     if cfg.extras.get("print_config"):
+         log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
+         rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+ @rank_zero_only
+ def save_file(path: str, content: str) -> None:
+     """Save file in rank zero mode (only on one process in multi-GPU setup)."""
+     with open(path, "w+") as file:
+         file.write(content)
+
+
+ def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+     """Instantiates callbacks from config."""
+     callbacks: List[Callback] = []
+
+     if not callbacks_cfg:
+         log.warning("Callbacks config is empty.")
+         return callbacks
+
+     if not isinstance(callbacks_cfg, DictConfig):
+         raise TypeError("Callbacks config must be a DictConfig!")
+
+     for _, cb_conf in callbacks_cfg.items():
+         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+             log.info(f"Instantiating callback <{cb_conf._target_}>")
+             callbacks.append(hydra.utils.instantiate(cb_conf))
+
+     return callbacks
+
+
+ def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+     """Instantiates loggers from config."""
+     logger: List[Logger] = []
+
+     if not logger_cfg:
+         log.warning("Logger config is empty.")
+         return logger
+
+     if not isinstance(logger_cfg, DictConfig):
+         raise TypeError("Logger config must be a DictConfig!")
+
+     for _, lg_conf in logger_cfg.items():
+         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+             log.info(f"Instantiating logger <{lg_conf._target_}>")
+             logger.append(hydra.utils.instantiate(lg_conf))
+
+     return logger
+
+
+ @rank_zero_only
+ def log_hyperparameters(object_dict: dict) -> None:
+     """Controls which config parts are saved by lightning loggers.
+
+     Additionally saves:
+     - Number of model parameters
+     """
+
+     hparams = {}
+
+     cfg = object_dict["cfg"]
+     model = object_dict["model"]
+     trainer = object_dict["trainer"]
+
+     if not trainer.logger:
+         log.warning("Logger not found! Skipping hyperparameter logging...")
+         return
+
+     # save number of model parameters
+     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+     hparams["model/params/trainable"] = sum(
+         p.numel() for p in model.parameters() if p.requires_grad
+     )
+     hparams["model/params/non_trainable"] = sum(
+         p.numel() for p in model.parameters() if not p.requires_grad
+     )
+
+     for k in cfg.keys():
+         hparams[k] = cfg.get(k)
+
+     # Resolve all interpolations
+     def _resolve(_cfg):
+         if isinstance(_cfg, DictConfig):
+             _cfg = OmegaConf.to_container(_cfg, resolve=True)
+         return _cfg
+
+     hparams = {k: _resolve(v) for k, v in hparams.items()}
+
+     # send hparams to all loggers
+     trainer.logger.log_hyperparams(hparams)
+
+
+ def get_metric_value(metric_dict: dict, metric_name: str) -> float:
+     """Safely retrieves value of the metric logged in LightningModule."""
+
+     if not metric_name:
+         log.info("Metric name is None! Skipping metric value retrieval...")
+         return None
+
+     if metric_name not in metric_dict:
+         raise Exception(
+             f"Metric value not found! <metric_name={metric_name}>\n"
+             "Make sure metric name logged in LightningModule is correct!\n"
+             "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+         )
+
+     metric_value = metric_dict[metric_name].item()
+     log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+     return metric_value
+
+
+ def close_loggers() -> None:
+     """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
+
+     log.info("Closing loggers...")
+
+     if find_spec("wandb"):  # if wandb is installed
+         import wandb
+
+         if wandb.run:
+             log.info("Closing wandb!")
+             wandb.finish()
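
The core of `task_wrapper` is a `try`/`finally` pattern: the timing and cleanup code runs whether the task succeeds or raises. A stripped-down sketch without Hydra or Lightning follows; `timed_task` and the `cleanup` callback are hypothetical names introduced only for illustration.

```python
import time
from functools import wraps


def timed_task(cleanup):
    """Decorator factory: run `cleanup(elapsed_seconds)` after the task, even on failure."""

    def decorator(task_func):
        @wraps(task_func)
        def wrap(*args, **kwargs):
            start = time.perf_counter()
            try:
                return task_func(*args, **kwargs)
            finally:
                # Runs on success and on exception; the exception still propagates.
                cleanup(time.perf_counter() - start)

        return wrap

    return decorator


if __name__ == "__main__":
    durations = []

    @timed_task(durations.append)
    def train():
        return "done"

    print(train(), len(durations))
```

Because the cleanup sits in `finally`, a failed run in a multirun sweep still records its duration and releases its loggers before the next job starts.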
prima/utils/pylogger.py ADDED
@@ -0,0 +1,26 @@
+ """
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+ Official implementation of the paper:
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+ Licensed under a modified MIT license
+ """
+
+ import logging
+
+ from pytorch_lightning.utilities import rank_zero_only
+
+
+ def get_pylogger(name=__name__) -> logging.Logger:
+     """Initializes multi-GPU-friendly python command line logger."""
+
+     logger = logging.getLogger(name)
+
+     # this ensures all logging levels get marked with the rank zero decorator
+     # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+     logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+     for level in logging_levels:
+         setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+     return logger
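
The idea behind `get_pylogger` can be sketched without Lightning: wrap each logging method so that it only runs on the process with rank 0. In this sketch the rank is read from a `RANK` environment variable, which is an assumption made for illustration; Lightning's `rank_zero_only` determines the rank through its own mechanism. The names `rank_zero_only_sketch` and `get_logger_sketch` are hypothetical.

```python
import logging
import os
from functools import wraps


def _env_rank():
    # Assumption for this sketch: the launcher exports RANK; default to 0.
    return int(os.environ.get("RANK", "0"))


def rank_zero_only_sketch(fn, get_rank=_env_rank):
    """Return a wrapper that calls `fn` only when the current rank is 0."""

    @wraps(fn)
    def wrapped(*args, **kwargs):
        if get_rank() == 0:
            return fn(*args, **kwargs)
        return None  # silently skipped on other ranks

    return wrapped


def get_logger_sketch(name, get_rank=_env_rank):
    """Logger whose level methods are muted on every rank except 0."""
    logger = logging.getLogger(name)
    for level in ("debug", "info", "warning", "error", "exception", "critical"):
        setattr(logger, level, rank_zero_only_sketch(getattr(logger, level), get_rank))
    return logger


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    get_logger_sketch("demo").info("printed once, on rank 0 only")
```

Without such a guard, every GPU process in a multi-GPU run would emit the same message.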
prima/utils/renderer.py ADDED
@@ -0,0 +1,433 @@
+ """
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
+
+ Official implementation of the paper:
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
+ Licensed under a modified MIT license
+ """
+ from __future__ import annotations
+
+
+ import os
+ from ctypes.util import find_library
+
+ if 'PYOPENGL_PLATFORM' not in os.environ and os.uname().sysname != 'Darwin':
+     # Prefer EGL; PyOpenGL's OSMesa bindings can lack symbols required by pyrender.
+     os.environ['PYOPENGL_PLATFORM'] = 'egl' if find_library('EGL') else 'osmesa'
+ # .get() avoids a KeyError when the variable is still unset (e.g. on macOS).
+ if os.environ.get('PYOPENGL_PLATFORM') == 'egl':
+     os.environ.setdefault('EGL_PLATFORM', 'surfaceless')
+ import torch
+ import numpy as np
+ import pyrender
+ import trimesh
+ import cv2
+ from yacs.config import CfgNode
+ from typing import List, Optional
+
+
+ def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.):
+     # Convert cam_bbox to full image
+     img_w, img_h = img_size[:, 0], img_size[:, 1]
+     cx, cy, b = box_center[:, 0], box_center[:, 1], box_size
+     w_2, h_2 = img_w / 2., img_h / 2.
+     bs = b * cam_bbox[:, 0] + 1e-9
+     tz = 2 * focal_length / bs
+     tx = (2 * (cx - w_2) / bs) + cam_bbox[:, 1]
+     ty = (2 * (cy - h_2) / bs) + cam_bbox[:, 2]
+     full_cam = torch.stack([tx, ty, tz], dim=-1)
+     return full_cam
+
+
+ def get_light_poses(n_lights=5, elevation=np.pi / 3, dist=12):
+     # get lights in a circle around origin at elevation
+     thetas = elevation * np.ones(n_lights)
+     phis = 2 * np.pi * np.arange(n_lights) / n_lights
+     poses = []
+     trans = make_translation(torch.tensor([0, 0, dist]))
+     for phi, theta in zip(phis, thetas):
+         rot = make_rotation(rx=-theta, ry=phi, order="xyz")
+         poses.append((rot @ trans).numpy())
+     return poses
+
+
+ def make_translation(t):
+     return make_4x4_pose(torch.eye(3), t)
+
+
+ def make_rotation(rx=0, ry=0, rz=0, order="xyz"):
+     Rx = rotx(rx)
+     Ry = roty(ry)
+     Rz = rotz(rz)
+     if order == "xyz":
+         R = Rz @ Ry @ Rx
+     elif order == "xzy":
+         R = Ry @ Rz @ Rx
+     elif order == "yxz":
+         R = Rz @ Rx @ Ry
+     elif order == "yzx":
+         R = Rx @ Rz @ Ry
+     elif order == "zyx":
+         R = Rx @ Ry @ Rz
+     elif order == "zxy":
+         R = Ry @ Rx @ Rz
+     else:
+         # Without this branch an unknown order would fail later with an unbound R.
+         raise ValueError(f"Unsupported rotation order: {order}")
+     return make_4x4_pose(R, torch.zeros(3))
+
+
+ def make_4x4_pose(R, t):
+     """
+     :param R (*, 3, 3)
+     :param t (*, 3)
+     return (*, 4, 4)
+     """
+     dims = R.shape[:-2]
+     pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1)
+     bottom = (
+         torch.tensor([0, 0, 0, 1], device=R.device)
+         .reshape(*(1,) * len(dims), 1, 4)
+         .expand(*dims, 1, 4)
+     )
+     return torch.cat([pose_3x4, bottom], dim=-2)
+
+
+ def rotx(theta):
+     return torch.tensor(
+         [
+             [1, 0, 0],
+             [0, np.cos(theta), -np.sin(theta)],
+             [0, np.sin(theta), np.cos(theta)],
+         ],
+         dtype=torch.float32,
+     )
+
+
+ def roty(theta):
+     return torch.tensor(
+         [
+             [np.cos(theta), 0, np.sin(theta)],
+             [0, 1, 0],
+             [-np.sin(theta), 0, np.cos(theta)],
+         ],
+         dtype=torch.float32,
+     )
+
+
+ def rotz(theta):
+     return torch.tensor(
+         [
+             [np.cos(theta), -np.sin(theta), 0],
+             [np.sin(theta), np.cos(theta), 0],
+             [0, 0, 1],
+         ],
+         dtype=torch.float32,
+     )
122
+ dtype=torch.float32,
123
+ )
124
+
125
+
126
+ def create_raymond_lights() -> List[pyrender.Node]:
127
+ """
128
+ Return raymond light nodes for the scene.
129
+ """
130
+ thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
131
+ phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])
132
+
133
+ nodes = []
134
+
135
+ for phi, theta in zip(phis, thetas):
136
+ xp = np.sin(theta) * np.cos(phi)
137
+ yp = np.sin(theta) * np.sin(phi)
138
+ zp = np.cos(theta)
139
+
140
+ z = np.array([xp, yp, zp])
141
+ z = z / np.linalg.norm(z)
142
+ x = np.array([-z[1], z[0], 0.0])
143
+ if np.linalg.norm(x) == 0:
144
+ x = np.array([1.0, 0.0, 0.0])
145
+ x = x / np.linalg.norm(x)
146
+ y = np.cross(z, x)
147
+
148
+ matrix = np.eye(4)
149
+ matrix[:3, :3] = np.c_[x, y, z]
150
+ nodes.append(pyrender.Node(
151
+ light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
152
+ matrix=matrix
153
+ ))
154
+
155
+ return nodes
156
+
157
+
158
+ class Renderer:
159
+
160
+ def __init__(self, cfg: CfgNode, faces: torch.Tensor):
161
+ """
162
+ Wrapper around the pyrender renderer to render SMAL meshes.
163
+ Args:
164
+ cfg (CfgNode): Model config file.
165
+ faces (torch.Tensor): Tensor of shape (F, 3) containing the mesh faces; it is moved to the CPU and stored as a NumPy array.
166
+ """
167
+ self.cfg = cfg
168
+ self.focal_length = 1000. if faces.shape[0] == 7774 else 2167.
169
+ self.img_res = cfg.MODEL.IMAGE_SIZE
170
+
171
+ self.camera_center = [self.img_res // 2, self.img_res // 2]
172
+ self.faces = faces.cpu().numpy()
173
+
174
+ def __call__(self,
175
+ vertices: np.array,
176
+ camera_translation: np.array,
177
+ image: torch.Tensor,
178
+ full_frame: bool = False,
179
+ imgname: Optional[str] = None,
180
+ side_view=False, rot_angle=90,
181
+ mesh_base_color=(1.0, 1.0, 0.9),
182
+ scene_bg_color=(0, 0, 0),
183
+ return_rgba=False,
184
+ depth = False,
185
+ focal_length: Optional[float] = None,
186
+ ) -> np.array:
187
+ """
188
+ Render meshes on input image
189
+ Args:
190
+ vertices (np.array): Array of shape (V, 3) containing the mesh vertices.
191
+ camera_translation (np.array): Array of shape (3,) with the camera translation.
192
+ image (torch.Tensor): Tensor of shape (3, H, W) containing the image crop with normalized pixel values.
193
+ full_frame (bool): If True, then render on the full image.
194
+ imgname (Optional[str]): Path of the original image file. Used only if full_frame == True.
195
+ focal_length (Optional[float]): Custom focal length. If None, uses self.focal_length.
196
+ """
197
+
198
+ if full_frame:
199
+
200
+ image = cv2.imread(imgname)
201
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
202
+ else:
203
+ image = (image.clone()) * (torch.tensor(self.cfg.MODEL.IMAGE_STD, device=image.device).reshape(3, 1, 1))
204
+ image = image + torch.tensor(self.cfg.MODEL.IMAGE_MEAN, device=image.device).reshape(3, 1, 1)
205
+ image = image.permute(1, 2, 0).cpu().numpy()
206
+
207
+ # Use custom focal length if provided, otherwise use default
208
+ focal_length_to_use = focal_length if focal_length is not None else self.focal_length
209
+
210
+ renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
211
+ viewport_height=image.shape[0],
212
+ point_size=1.0)
213
+ material = pyrender.MetallicRoughnessMaterial(
214
+ metallicFactor=0.0,
215
+ alphaMode='OPAQUE',
216
+ baseColorFactor=(*mesh_base_color, 1.0))
217
+
218
+ camera_translation_local = camera_translation.copy()
219
+ camera_translation_local[0] *= -1.
220
+
221
+ mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
222
+ if side_view:
223
+ rot = trimesh.transformations.rotation_matrix(
224
+ np.radians(rot_angle), [0, 1, 0])
225
+ mesh.apply_transform(rot)
226
+ rot = trimesh.transformations.rotation_matrix(
227
+ np.radians(180), [1, 0, 0])
228
+ mesh.apply_transform(rot)
229
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
230
+
231
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
232
+ ambient_light=(0.3, 0.3, 0.3))
233
+ scene.add(mesh, 'mesh')
234
+
235
+ camera_pose = np.eye(4)
236
+ camera_pose[:3, 3] = camera_translation_local
237
+ camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
238
+ camera = pyrender.IntrinsicsCamera(fx=focal_length_to_use, fy=focal_length_to_use,
239
+ cx=camera_center[0], cy=camera_center[1], zfar=1e12)
240
+ scene.add(camera, pose=camera_pose)
241
+
242
+ light_nodes = create_raymond_lights()
243
+ for node in light_nodes:
244
+ scene.add_node(node)
245
+
246
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
247
+ color = color.astype(np.float32) / 255.0
248
+ renderer.delete()
249
+
250
+ if depth:
251
+ return rend_depth
252
+
253
+ if return_rgba:
254
+ return color
255
+
256
+ valid_mask = (rend_depth > 0).astype(np.float32)[:, :, np.newaxis]
257
+ if not side_view:
258
+ output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image)
259
+ else:
260
+ output_img = color[:, :, :3]
261
+
262
+ output_img = output_img.astype(np.float32)
263
+ return output_img
264
+
265
+ def vertices_to_trimesh(self, vertices, camera_translation, mesh_base_color=(1.0, 1.0, 0.9),
266
+ rot_axis=[1, 0, 0], rot_angle=0):
267
+ # material = pyrender.MetallicRoughnessMaterial(
268
+ # metallicFactor=0.0,
269
+ # alphaMode='OPAQUE',
270
+ # baseColorFactor=(*mesh_base_color, 1.0))
271
+ vertex_colors = np.array([(*mesh_base_color, 1.0)] * vertices.shape[0])
272
+ mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces.copy(), vertex_colors=vertex_colors)
273
+ # mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
274
+
275
+ rot = trimesh.transformations.rotation_matrix(
276
+ np.radians(rot_angle), rot_axis)
277
+ mesh.apply_transform(rot)
278
+
279
+ rot = trimesh.transformations.rotation_matrix(
280
+ np.radians(180), [1, 0, 0])
281
+ mesh.apply_transform(rot)
282
+ return mesh
283
+
284
+ def render_rgba(
285
+ self,
286
+ vertices: np.array,
287
+ cam_t=None,
288
+ rot=None,
289
+ rot_axis=[1, 0, 0],
290
+ rot_angle=0,
291
+ camera_z=3,
292
+ # camera_translation: np.array,
293
+ mesh_base_color=(1.0, 1.0, 0.9),
294
+ scene_bg_color=(0, 0, 0),
295
+ render_res=[256, 256],
296
+ focal_length=None,
297
+ ):
298
+
299
+ renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0],
300
+ viewport_height=render_res[1],
301
+ point_size=1.0)
302
+ # material = pyrender.MetallicRoughnessMaterial(
303
+ # metallicFactor=0.0,
304
+ # alphaMode='OPAQUE',
305
+ # baseColorFactor=(*mesh_base_color, 1.0))
306
+
307
+ focal_length = focal_length if focal_length is not None else self.focal_length
308
+
309
+ if cam_t is not None:
310
+ camera_translation = cam_t.copy()
311
+ camera_translation[0] *= -1.
312
+ else:
313
+ camera_translation = np.array([0, 0, camera_z * focal_length / render_res[1]])
314
+
315
+ mesh = self.vertices_to_trimesh(vertices, np.array([0, 0, 0]), mesh_base_color, rot_axis, rot_angle,
316
+ )
317
+ mesh = pyrender.Mesh.from_trimesh(mesh)
318
+ # mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
319
+
320
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
321
+ ambient_light=(0.3, 0.3, 0.3))
322
+ scene.add(mesh, 'mesh')
323
+
324
+ camera_pose = np.eye(4)
325
+ camera_pose[:3, 3] = camera_translation
326
+ camera_center = [render_res[0] / 2., render_res[1] / 2.]
327
+ camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
328
+ cx=camera_center[0], cy=camera_center[1], zfar=1e12)
329
+
330
+ # Create camera node and add it to pyRender scene
331
+ camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
332
+ scene.add_node(camera_node)
333
+ self.add_point_lighting(scene, camera_node)
334
+ self.add_lighting(scene, camera_node)
335
+
336
+ light_nodes = create_raymond_lights()
337
+ for node in light_nodes:
338
+ scene.add_node(node)
339
+
340
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
341
+ color = color.astype(np.float32) / 255.0
342
+ renderer.delete()
343
+
344
+ return color
345
+
346
+ def render_rgba_multiple(
347
+ self,
348
+ vertices: List[np.array],
349
+ cam_t: List[np.array],
350
+ rot_axis=[1, 0, 0],
351
+ rot_angle=0,
352
+ mesh_base_color=(1.0, 1.0, 0.9),
353
+ scene_bg_color=(0, 0, 0),
354
+ render_res=[256, 256],
355
+ focal_length=None,
356
+ ):
357
+
358
+ renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0],
359
+ viewport_height=render_res[1],
360
+ point_size=1.0)
361
+ # material = pyrender.MetallicRoughnessMaterial(
362
+ # metallicFactor=0.0,
363
+ # alphaMode='OPAQUE',
364
+ # baseColorFactor=(*mesh_base_color, 1.0))
365
+
366
+ mesh_list = [pyrender.Mesh.from_trimesh(
367
+ self.vertices_to_trimesh(vvv, ttt.copy(), mesh_base_color, rot_axis, rot_angle)) for
368
+ vvv, ttt in zip(vertices, cam_t)]
369
+
370
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
371
+ ambient_light=(0.3, 0.3, 0.3))
372
+ for i, mesh in enumerate(mesh_list):
373
+ scene.add(mesh, f'mesh_{i}')
374
+
375
+ camera_pose = np.eye(4)
376
+ # camera_pose[:3, 3] = camera_translation
377
+ camera_center = [render_res[0] / 2., render_res[1] / 2.]
378
+ focal_length = focal_length if focal_length is not None else self.focal_length
379
+ camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
380
+ cx=camera_center[0], cy=camera_center[1], zfar=1e12)
381
+
382
+ # Create camera node and add it to pyRender scene
383
+ camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
384
+ scene.add_node(camera_node)
385
+ self.add_point_lighting(scene, camera_node)
386
+ self.add_lighting(scene, camera_node)
387
+
388
+ light_nodes = create_raymond_lights()
389
+ for node in light_nodes:
390
+ scene.add_node(node)
391
+
392
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
393
+ color = color.astype(np.float32) / 255.0
394
+ renderer.delete()
395
+
396
+ return color
397
+
398
+ def add_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
399
+
400
+ light_poses = get_light_poses()
401
+ light_poses.append(np.eye(4))
402
+ cam_pose = scene.get_pose(cam_node)
403
+ for i, pose in enumerate(light_poses):
404
+ matrix = cam_pose @ pose
405
+ node = pyrender.Node(
406
+ name=f"light-{i:02d}",
407
+ light=pyrender.DirectionalLight(color=color, intensity=intensity),
408
+ matrix=matrix,
409
+ )
410
+ if scene.has_node(node):
411
+ continue
412
+ scene.add_node(node)
413
+
414
+ def add_point_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
415
+
416
+ light_poses = get_light_poses(dist=0.5)
417
+ light_poses.append(np.eye(4))
418
+ cam_pose = scene.get_pose(cam_node)
419
+ for i, pose in enumerate(light_poses):
420
+ matrix = cam_pose @ pose
421
+ # node = pyrender.Node(
422
+ # name=f"light-{i:02d}",
423
+ # light=pyrender.DirectionalLight(color=color, intensity=intensity),
424
+ # matrix=matrix,
425
+ # )
426
+ node = pyrender.Node(
427
+ name=f"plight-{i:02d}",
428
+ light=pyrender.PointLight(color=color, intensity=intensity),
429
+ matrix=matrix,
430
+ )
431
+ if scene.has_node(node):
432
+ continue
433
+ scene.add_node(node)
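
The `cam_crop_to_full` helper near the top of this file maps a crop-space camera `(s, tx, ty)` to a full-image translation `(tx, ty, tz)`. As a hedged illustration of the same arithmetic, here is a minimal pure-Python sketch for a single sample. The name `crop_cam_to_full_scalar` and the example numbers are assumptions made for illustration; the module itself operates on batched `torch` tensors.

```python
from typing import Tuple


def crop_cam_to_full_scalar(
    cam_bbox: Tuple[float, float, float],
    box_center: Tuple[float, float],
    box_size: float,
    img_size: Tuple[float, float],
    focal_length: float = 5000.0,
) -> Tuple[float, float, float]:
    """Scalar version of the crop-to-full camera conversion (illustrative only)."""
    s, tx_crop, ty_crop = cam_bbox      # weak-perspective scale and crop-space offsets
    cx, cy = box_center                 # bounding-box center in full-image pixels
    img_w, img_h = img_size             # full-image width and height in pixels

    bs = box_size * s + 1e-9            # scaled box size; epsilon avoids division by zero
    tz = 2.0 * focal_length / bs        # depth implied by the scale
    tx = 2.0 * (cx - img_w / 2.0) / bs + tx_crop  # shift by the box offset from the image center
    ty = 2.0 * (cy - img_h / 2.0) / bs + ty_crop
    return tx, ty, tz


# A box centered in a 640x480 image keeps its crop-space offsets unchanged.
full_cam = crop_cam_to_full_scalar((1.0, 0.1, -0.2), (320.0, 240.0), 200.0, (640.0, 480.0))
print(full_cam)
```

When the box is centered in the image, the pixel-offset terms vanish and only the crop-space offsets remain; moving the box away from the center adds a translation proportional to that offset divided by the scaled box size.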
prima/utils/rich_utils.py ADDED
@@ -0,0 +1,114 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from pathlib import Path
11
+ from typing import Sequence
12
+
13
+ import rich
14
+ import rich.syntax
15
+ import rich.tree
16
+ from hydra.core.hydra_config import HydraConfig
17
+ from omegaconf import DictConfig, OmegaConf, open_dict
18
+ from pytorch_lightning.utilities import rank_zero_only
19
+ from rich.prompt import Prompt
20
+
21
+ from . import pylogger
22
+
23
+ log = pylogger.get_pylogger(__name__)
24
+
25
+
26
+ @rank_zero_only
27
+ def print_config_tree(
28
+ cfg: DictConfig,
29
+ print_order: Sequence[str] = (
30
+ "datamodule",
31
+ "model",
32
+ "callbacks",
33
+ "logger",
34
+ "trainer",
35
+ "paths",
36
+ "extras",
37
+ ),
38
+ resolve: bool = False,
39
+ save_to_file: bool = False,
40
+ ) -> None:
41
+ """Prints content of DictConfig using Rich library and its tree structure.
42
+
43
+ Args:
44
+ cfg (DictConfig): Configuration composed by Hydra.
45
+ print_order (Sequence[str], optional): Determines in what order config components are printed.
46
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig.
47
+ save_to_file (bool, optional): Whether to export config to the hydra output folder.
48
+ """
49
+
50
+ style = "dim"
51
+ tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
52
+
53
+ queue = []
54
+
55
+ # add fields from `print_order` to queue
56
+ for field in print_order:
57
+ queue.append(field) if field in cfg else log.warning(
58
+ f"Field '{field}' not found in config. Skipping '{field}' config printing..."
59
+ )
60
+
61
+ # add all the other fields to queue (not specified in `print_order`)
62
+ for field in cfg:
63
+ if field not in queue:
64
+ queue.append(field)
65
+
66
+ # generate config tree from queue
67
+ for field in queue:
68
+ branch = tree.add(field, style=style, guide_style=style)
69
+
70
+ config_group = cfg[field]
71
+ if isinstance(config_group, DictConfig):
72
+ branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
73
+ else:
74
+ branch_content = str(config_group)
75
+
76
+ branch.add(rich.syntax.Syntax(branch_content, "yaml"))
77
+
78
+ # print config tree
79
+ rich.print(tree)
80
+
81
+ # save config tree to file
82
+ if save_to_file:
83
+ with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
84
+ rich.print(tree, file=file)
85
+
86
+
87
+ @rank_zero_only
88
+ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
89
+ """Prompts user to input tags from command line if no tags are provided in config."""
90
+
91
+ if not cfg.get("tags"):
92
+ if "id" in HydraConfig().cfg.hydra.job:
93
+ raise ValueError("Specify tags before launching a multirun!")
94
+
95
+ log.warning("No tags provided in config. Prompting user to input tags...")
96
+ tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
97
+ tags = [t.strip() for t in tags.split(",") if t.strip()]
98
+
99
+ with open_dict(cfg):
100
+ cfg.tags = tags
101
+
102
+ log.info(f"Tags: {cfg.tags}")
103
+
104
+ if save_to_file:
105
+ with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
106
+ rich.print(cfg.tags, file=file)
107
+
108
+
109
+ if __name__ == "__main__":
110
+ from hydra import compose, initialize
111
+
112
+ with initialize(version_base="1.2", config_path="../../configs_hydra"):
113
+ cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[])
114
+ print_config_tree(cfg, resolve=False, save_to_file=False)
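
To make the queueing logic of `print_config_tree` and the tag parsing of `enforce_tags` easier to follow, here is a small standalone sketch that uses only the standard library. The helper names `order_fields` and `parse_tags` are hypothetical and do not exist in this module. The sketch also assumes that whitespace-only tags should be dropped, which is an assumption about the intended behavior.

```python
from typing import Iterable, List, Sequence


def order_fields(cfg_keys: Iterable[str], print_order: Sequence[str]) -> List[str]:
    """Put the fields named in print_order first, then every remaining config key."""
    keys = list(cfg_keys)
    # Keep only the requested fields that actually exist in the config.
    queue = [field for field in print_order if field in keys]
    # Append the fields that print_order did not mention, in config order.
    queue += [key for key in keys if key not in queue]
    return queue


def parse_tags(raw: str) -> List[str]:
    """Split a comma-separated string into tags, dropping blank entries (assumed behavior)."""
    return [tag.strip() for tag in raw.split(",") if tag.strip()]


print(order_fields(["trainer", "seed", "model"], ("datamodule", "model", "trainer")))
print(parse_tags("dev, ,baseline,"))
```

Fields missing from the config are skipped here without comment, whereas the module logs a warning for each of them.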
prima/utils/weights.py ADDED
@@ -0,0 +1,337 @@
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import os
13
+ import shutil
14
+ from pathlib import Path
15
+ from typing import Iterable, Optional, Sequence, Union
16
+
17
+ HF_REPO_ID = "MLAdaptiveIntelligence/PRIMA"
18
+ DEFAULT_HF_REPO_ID = HF_REPO_ID
19
+
20
+ DEFAULT_STAGE1_CHECKPOINT = Path("data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt")
21
+ DEFAULT_STAGE3_CHECKPOINT = Path("data/PRIMAS3/checkpoints/s3ckpt_inference.ckpt")
22
+
23
+ SMAL_ASSET_PATHS = [
24
+ "my_smpl_00781_4_all.pkl",
25
+ "my_smpl_data_00781_4_all.pkl",
26
+ "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl",
27
+ ]
28
+ BACKBONE_ASSET_PATH = "amr_vitbb.pth"
29
+ STAGE1_CONFIG_ASSET_PATH = "config_s1_HYDRA.yaml"
30
+ STAGE1_CHECKPOINT_ASSET_PATH = "s1ckpt_inference.ckpt"
31
+ STAGE3_CONFIG_ASSET_PATH = "config_s3_HYDRA.yaml"
32
+ STAGE3_CHECKPOINT_ASSET_PATH = "s3ckpt_inference.ckpt"
33
+
34
+ STAGE_ASSETS = {
35
+ "PRIMAS1": (STAGE1_CONFIG_ASSET_PATH, STAGE1_CHECKPOINT_ASSET_PATH, "s1ckpt_inference.ckpt"),
36
+ "PRIMAS3": (STAGE3_CONFIG_ASSET_PATH, STAGE3_CHECKPOINT_ASSET_PATH, "s3ckpt_inference.ckpt"),
37
+ }
38
+
39
+ STAGE_CHECKPOINTS = {
40
+ "PRIMAS1": Path("PRIMAS1/checkpoints/s1ckpt_inference.ckpt"),
41
+ "PRIMAS3": Path("PRIMAS3/checkpoints/s3ckpt_inference.ckpt"),
42
+ }
43
+
44
+ PathLike = Union[str, Path]
45
+
46
+
47
+ def _resolve_hf_repo_id(hf_repo_id: Optional[str]) -> str:
48
+ return hf_repo_id or os.environ.get("PRIMA_HF_REPO_ID", HF_REPO_ID)
49
+
50
+
51
+ def _default_checkpoint_path(data_dir: PathLike = "data") -> Path:
52
+ return Path(data_dir) / STAGE_CHECKPOINTS["PRIMAS1"]
53
+
54
+
55
+ def _config_path_for_checkpoint(checkpoint_path: PathLike) -> Path:
56
+ checkpoint_path = Path(checkpoint_path)
57
+ return checkpoint_path.parent.parent / ".hydra" / "config.yaml"
58
+
59
+
60
+ def _stage_for_checkpoint(checkpoint_path: PathLike) -> Optional[str]:
61
+ checkpoint_path = Path(checkpoint_path)
62
+ if len(checkpoint_path.parents) < 2:
63
+ return None
64
+ stage_name = checkpoint_path.parent.parent.name
65
+ stage_assets = STAGE_ASSETS.get(stage_name)
66
+ if stage_assets is None:
67
+ return None
68
+ _, _, checkpoint_name = stage_assets
69
+ if checkpoint_path.name != checkpoint_name:
70
+ return None
71
+ return stage_name
72
+
73
+
74
+ def _download_file(
75
+ hf_repo_id: str,
76
+ remote_filename: str,
77
+ destination: Path,
78
+ force_download: bool = False,
79
+ ) -> None:
80
+ try:
81
+ from huggingface_hub import hf_hub_download
82
+ except ImportError:
83
+ raise ImportError(
84
+ "huggingface_hub is required to download PRIMA demo assets. "
85
+ "Install it with: pip install huggingface_hub\n"
86
+ "Or download the assets manually and pass a local checkpoint path."
87
+ ) from None
88
+
89
+ destination.parent.mkdir(parents=True, exist_ok=True)
90
+ downloaded = hf_hub_download(
91
+ repo_id=hf_repo_id,
92
+ filename=remote_filename,
93
+ local_dir=str(destination.parent),
94
+ local_dir_use_symlinks=False,
95
+ force_download=force_download,
96
+ )
97
+ downloaded_path = Path(downloaded).resolve()
98
+ target = destination.resolve()
99
+ if downloaded_path != target:
100
+ if target.exists():
101
+ target.unlink()
102
+ shutil.move(str(downloaded_path), str(target))
103
+
104
+
105
+ def _validate_torch_checkpoint(path: Path) -> None:
106
+ import inspect
107
+ import pickle
108
+ import zipfile
109
+
110
+ import torch
111
+
112
+ if zipfile.is_zipfile(path):
113
+ with zipfile.ZipFile(path) as checkpoint_zip:
114
+ corrupt_member = checkpoint_zip.testzip()
115
+ if corrupt_member is not None:
116
+ raise RuntimeError(
117
+ f"Checkpoint file is invalid or incomplete: {path}\n"
118
+ f"Corrupt archive member: {corrupt_member}\n"
119
+ "Please redownload the checkpoint and try again."
120
+ )
121
+
122
+ supports_weights_only = "weights_only" in inspect.signature(torch.load).parameters
123
+ load_kwargs = {"map_location": "cpu"}
124
+ if supports_weights_only:
125
+ load_kwargs["weights_only"] = True
126
+
127
+ try:
128
+ torch.load(path, **load_kwargs)
129
+ except pickle.UnpicklingError as exc:
130
+ message = str(exc)
131
+ if (
132
+ supports_weights_only
133
+ and "Weights only load failed" in message
134
+ and ("Unsupported global" in message or "Unsupported class" in message)
135
+ ):
136
+ return
137
+ raise RuntimeError(
138
+ f"Checkpoint file is invalid or incomplete: {path}\n"
139
+ "Downloaded checkpoint is not loadable. "
140
+ "Please verify the uploaded Hugging Face file and try again."
141
+ ) from exc
142
+ except Exception as exc:
143
+ raise RuntimeError(
144
+ f"Checkpoint file is invalid or incomplete: {path}\n"
145
+ "Downloaded checkpoint is not loadable. "
146
+ "Please verify the uploaded Hugging Face file and try again."
147
+ ) from exc
148
+
149
+
150
+ def _ensure_backbone(data_dir: Path, force: bool, hf_repo_id: str) -> None:
151
+ target = data_dir / "amr_vitbb.pth"
152
+ if target.exists() and not force:
153
+ print(f"[skip] {target} already exists")
154
+ return
155
+
156
+ print("[download] pretrained backbone")
157
+ _download_file(hf_repo_id, BACKBONE_ASSET_PATH, target, force_download=force)
158
+ print(f"[ok] {target}")
159
+
160
+
161
+ def _ensure_smal_assets(data_dir: Path, force: bool, hf_repo_id: str) -> None:
162
+ required = [Path(p).name for p in SMAL_ASSET_PATHS]
163
+ smal_dir = data_dir / "smal"
164
+ if smal_dir.exists() and all((smal_dir / n).exists() for n in required) and not force:
165
+ print("[skip] SMAL files already exist")
166
+ return
167
+
168
+ print("[download] SMAL assets")
169
+ for asset_path in SMAL_ASSET_PATHS:
170
+ target = smal_dir / Path(asset_path).name
171
+ _download_file(hf_repo_id, asset_path, target, force_download=force)
172
+ print(f"[ok] {smal_dir}")
173
+
174
+
175
+ def _ensure_stage_assets(
176
+ stage_name: str,
177
+ data_dir: Path,
178
+ force: bool,
179
+ hf_repo_id: str,
180
+ validate_existing: bool = True,
181
+ ) -> None:
182
+ if stage_name not in STAGE_ASSETS:
183
+ known = ", ".join(sorted(STAGE_ASSETS))
184
+ raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}")
185
+
186
+ config_asset_path, checkpoint_asset_path, checkpoint_name = STAGE_ASSETS[stage_name]
187
+ stage_dir = data_dir / stage_name
188
+ config_target = stage_dir / ".hydra" / "config.yaml"
189
+ checkpoint_target = stage_dir / "checkpoints" / checkpoint_name
190
+ redownload_checkpoint = False
191
+
192
+ if config_target.exists() and checkpoint_target.exists() and not force:
193
+ if validate_existing:
194
+ try:
195
+ _validate_torch_checkpoint(checkpoint_target)
196
+ except RuntimeError:
197
+ print(f"[warn] {stage_name} checkpoint is incomplete, redownloading checkpoint only.")
198
+ redownload_checkpoint = True
199
+ else:
200
+ print(f"[skip] {stage_name} assets already exist")
201
+ return
202
+ else:
203
+ print(f"[skip] {stage_name} assets already exist")
204
+ return
205
+
206
+ print(f"[download] {stage_name} assets")
207
+ config_target.parent.mkdir(parents=True, exist_ok=True)
208
+ checkpoint_target.parent.mkdir(parents=True, exist_ok=True)
209
+ if force or not config_target.exists():
210
+ _download_file(hf_repo_id, config_asset_path, config_target, force_download=force)
211
+ if redownload_checkpoint and checkpoint_target.exists():
212
+ checkpoint_target.unlink()
213
+ if force or redownload_checkpoint or not checkpoint_target.exists():
214
+ _download_file(
215
+ hf_repo_id,
216
+ checkpoint_asset_path,
217
+ checkpoint_target,
218
+ force_download=force or redownload_checkpoint,
219
+ )
220
+ _validate_torch_checkpoint(checkpoint_target)
221
+ print(f"[ok] {stage_dir}")
222
+
223
+
224
+ def _normalize_stages(stages: Union[str, Iterable[str]]) -> Sequence[str]:
225
+ if isinstance(stages, str):
226
+ return (stages,)
227
+ return tuple(stages)
228
+
229
+
230
+ def _verify_assets(data_dir: Path, stages: Sequence[str]) -> None:
231
+ required_paths = [
232
+ data_dir / "smal" / "my_smpl_00781_4_all.pkl",
233
+ data_dir / "smal" / "my_smpl_data_00781_4_all.pkl",
234
+ data_dir / "smal" / "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl",
235
+ data_dir / "amr_vitbb.pth",
236
+ ]
237
+ for stage_name in stages:
238
+ if stage_name not in STAGE_ASSETS:
239
+ known = ", ".join(sorted(STAGE_ASSETS))
240
+ raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}")
241
+ _, _, checkpoint_name = STAGE_ASSETS[stage_name]
242
+ stage_dir = data_dir / stage_name
243
+ required_paths.extend(
244
+ [
245
+ stage_dir / ".hydra" / "config.yaml",
246
+ stage_dir / "checkpoints" / checkpoint_name,
247
+ ]
248
+ )
249
+
250
+ missing = [p for p in required_paths if not p.exists()]
251
+ if missing:
252
+ raise FileNotFoundError("Missing required files:\n" + "\n".join(str(p) for p in missing))
253
+
254
+ for stage_name in stages:
255
+ _, _, checkpoint_name = STAGE_ASSETS[stage_name]
256
+ _validate_torch_checkpoint(data_dir / stage_name / "checkpoints" / checkpoint_name)
257
+
258
+
259
+ def _ensure_assets_for_checkpoint(
260
+ checkpoint_path: PathLike,
261
+ force: bool = False,
262
+ hf_repo_id: Optional[str] = None,
263
+ ) -> None:
264
+ checkpoint_path = Path(checkpoint_path)
265
+ config_path = _config_path_for_checkpoint(checkpoint_path)
266
+ stage_name = _stage_for_checkpoint(checkpoint_path)
267
+ if stage_name is None:
268
+ if checkpoint_path.exists() and config_path.exists() and not force:
269
+ print(f"[skip] Using local PRIMA checkpoint {checkpoint_path}")
270
+ return
271
+ raise FileNotFoundError(
272
+ "Missing checkpoint or config for a custom path:\n"
273
+ f" checkpoint: {checkpoint_path}\n"
274
+ f" config: {config_path}\n"
275
+ "Auto-download supports the standard PRIMA demo layouts only:\n"
276
+ " data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt\n"
277
+ " data/PRIMAS3/checkpoints/s3ckpt_inference.ckpt\n"
278
+ "Pass one of those paths, or download/copy your custom checkpoint manually."
279
+ )
280
+
281
+ data_dir = checkpoint_path.parent.parent.parent
282
+ repo_id = _resolve_hf_repo_id(hf_repo_id)
283
+ print(f"[download] Ensuring PRIMA demo assets under {data_dir}")
284
+ _ensure_smal_assets(data_dir, force=force, hf_repo_id=repo_id)
285
+ _ensure_backbone(data_dir, force=force, hf_repo_id=repo_id)
286
+ _ensure_stage_assets(
287
+ stage_name,
288
+ data_dir,
289
+ force=force,
290
+ hf_repo_id=repo_id,
291
+ validate_existing=False,
292
+ )
293
+
294
+
295
+ def ensure_demo_assets(
296
+ data_dir: PathLike = "data",
297
+ *,
298
+ stages: Union[str, Iterable[str]] = ("PRIMAS1",),
299
+ force: bool = False,
300
+ hf_repo_id: Optional[str] = None,
301
+ ) -> None:
302
+ """Ensure PRIMA demo assets exist in the expected ``data/`` layout."""
303
+ data_dir = Path(data_dir).resolve()
304
+ data_dir.mkdir(parents=True, exist_ok=True)
305
+ repo_id = _resolve_hf_repo_id(hf_repo_id)
306
+ selected_stages = _normalize_stages(stages)
307
+
308
+ _ensure_smal_assets(data_dir, force=force, hf_repo_id=repo_id)
309
+ _ensure_backbone(data_dir, force=force, hf_repo_id=repo_id)
310
+ for stage_name in selected_stages:
311
+ _ensure_stage_assets(stage_name, data_dir, force=force, hf_repo_id=repo_id)
312
+ _verify_assets(data_dir, selected_stages)
313
+
314
+
315
+ def resolve_prima_checkpoint_path(
316
+ checkpoint_path: PathLike = "",
317
+ *,
318
+ data_dir: PathLike = "data",
319
+ auto_download: bool = True,
320
+ hf_repo_id: Optional[str] = None,
321
+ force: bool = False,
322
+ ) -> str:
323
+ """Return a PRIMA checkpoint path, downloading default demo assets if needed."""
324
+ resolved = Path(checkpoint_path) if checkpoint_path else _default_checkpoint_path(data_dir)
325
+ if auto_download:
326
+ _ensure_assets_for_checkpoint(resolved, force=force, hf_repo_id=hf_repo_id)
327
+ return str(resolved)
328
+
329
+
330
+ __all__ = [
331
+ "DEFAULT_HF_REPO_ID",
332
+ "DEFAULT_STAGE1_CHECKPOINT",
333
+ "DEFAULT_STAGE3_CHECKPOINT",
334
+ "HF_REPO_ID",
335
+ "ensure_demo_assets",
336
+ "resolve_prima_checkpoint_path",
337
+ ]
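
The auto-download path above depends on a fixed directory layout: a checkpoint at `<data_dir>/<stage>/checkpoints/<name>.ckpt` is paired with a config at `<data_dir>/<stage>/.hydra/config.yaml`. The following standalone sketch mirrors that path arithmetic with `pathlib` only. The names `KNOWN_CHECKPOINTS`, `stage_for`, and `config_for` are illustrative stand-ins for `STAGE_ASSETS`, `_stage_for_checkpoint`, and `_config_path_for_checkpoint`, not part of the module.

```python
from pathlib import Path
from typing import Optional

# Stage directory name -> expected checkpoint file name (copied from the constants above).
KNOWN_CHECKPOINTS = {
    "PRIMAS1": "s1ckpt_inference.ckpt",
    "PRIMAS3": "s3ckpt_inference.ckpt",
}


def stage_for(checkpoint: Path) -> Optional[str]:
    """Return the stage name if the path follows the standard demo layout, else None."""
    stage = checkpoint.parent.parent.name      # <stage>/checkpoints/<file>
    if KNOWN_CHECKPOINTS.get(stage) != checkpoint.name:
        return None                            # unknown stage or unexpected file name
    return stage


def config_for(checkpoint: Path) -> Path:
    """Return the Hydra config path that sits next to the checkpoints directory."""
    return checkpoint.parent.parent / ".hydra" / "config.yaml"


ckpt = Path("data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt")
print(stage_for(ckpt), config_for(ckpt))
print(stage_for(Path("runs/exp/checkpoints/last.ckpt")))
```

A custom checkpoint path yields `None`, which corresponds to the branch where the module refuses to auto-download and expects both the checkpoint and its config to already exist locally.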