HF Space deploy commited on
Commit
2ba375b
·
0 Parent(s):

Deploy snapshot (LFS for demo images per .gitattributes)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. .github/workflows/check-headers.yml +36 -0
  3. .github/workflows/codespell.yml +21 -0
  4. .github/workflows/release-pypi.yml +48 -0
  5. .gitignore +175 -0
  6. README.md +220 -0
  7. app.py +561 -0
  8. chumpy/__init__.py +16 -0
  9. chumpy/ch.py +37 -0
  10. configs_hydra/experiment/default.yaml +28 -0
  11. configs_hydra/experiment/default_val.yaml +34 -0
  12. configs_hydra/experiment/primaStage1.yaml +83 -0
  13. configs_hydra/experiment/primaStage2.yaml +113 -0
  14. configs_hydra/extras/default.yaml +8 -0
  15. configs_hydra/hydra/default.yaml +26 -0
  16. configs_hydra/launcher/local.yaml +13 -0
  17. configs_hydra/launcher/slurm.yaml +22 -0
  18. configs_hydra/paths/default.yaml +18 -0
  19. configs_hydra/train.yaml +46 -0
  20. configs_hydra/trainer/cpu.yaml +6 -0
  21. configs_hydra/trainer/ddp.yaml +14 -0
  22. configs_hydra/trainer/default.yaml +10 -0
  23. configs_hydra/trainer/default_amr.yaml +9 -0
  24. configs_hydra/trainer/gpu.yaml +6 -0
  25. configs_hydra/trainer/mps.yaml +6 -0
  26. demo.py +144 -0
  27. demo_data/000000015956_horse.png +3 -0
  28. demo_data/000000315905_zebra.jpg +3 -0
  29. demo_data/beagle.jpg +3 -0
  30. demo_data/n02101388_1188.png +3 -0
  31. demo_data/n02412080_12159.png +3 -0
  32. demo_data/shepherd_hati.jpg +3 -0
  33. demo_tta.py +340 -0
  34. eval.py +102 -0
  35. images/teaser.png +3 -0
  36. packages.txt +4 -0
  37. prima/__init__.py +25 -0
  38. prima/configs/__init__.py +99 -0
  39. prima/models/__init__.py +54 -0
  40. prima/models/backbones/__init__.py +19 -0
  41. prima/models/backbones/vit.py +375 -0
  42. prima/models/bioclip_embedding.py +70 -0
  43. prima/models/components/__init__.py +0 -0
  44. prima/models/components/model_utils.py +160 -0
  45. prima/models/components/pose_transformer.py +366 -0
  46. prima/models/components/position_encoding.py +84 -0
  47. prima/models/components/t_cond_mlp.py +204 -0
  48. prima/models/components/transformer.py +400 -0
  49. prima/models/discriminator.py +129 -0
  50. prima/models/heads/__init__.py +1 -0
.gitattributes ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Hugging Face Hub stores these via Git LFS / Xet (plain PNG/JPG in git are rejected on push).
2
+ demo_data/*.png filter=lfs diff=lfs merge=lfs -text
3
+ demo_data/*.jpg filter=lfs diff=lfs merge=lfs -text
4
+ demo_data/*.jpeg filter=lfs diff=lfs merge=lfs -text
5
+ images/*.png filter=lfs diff=lfs merge=lfs -text
.github/workflows/check-headers.yml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Check File Headers
3
+
4
+ on:
5
+ push:
6
+ branches: [main]
7
+ pull_request:
8
+ branches: [main]
9
+
10
+ jobs:
11
+ check-headers:
12
+ name: Check Python file headers
13
+ runs-on: ubuntu-latest
14
+ permissions:
15
+ contents: read
16
+
17
+ steps:
18
+ - name: Checkout code
19
+ uses: actions/checkout@v3
20
+
21
+ - name: Set up Python
22
+ uses: actions/setup-python@v4
23
+ with:
24
+ python-version: "3.10"
25
+
26
+ - name: Check headers
27
+ run: |
28
+ python scripts/update_headers.py --check
29
+ continue-on-error: false
30
+
31
+ - name: Provide fix instructions
32
+ if: failure()
33
+ run: |
34
+ echo "::error::Some files are missing proper headers."
35
+ echo "To fix this, run: python scripts/update_headers.py"
36
+ echo "Then commit the changes."
.github/workflows/codespell.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Codespell
3
+
4
+ on:
5
+ push:
6
+ branches: [main]
7
+ pull_request:
8
+ branches: [main]
9
+
10
+ jobs:
11
+ codespell:
12
+ name: Check for spelling errors
13
+ runs-on: ubuntu-latest
14
+
15
+ steps:
16
+ - name: Checkout
17
+ uses: actions/checkout@v3
18
+ - name: Codespell
19
+ uses: codespell-project/actions-codespell@v1
20
+ with:
21
+ ignore_words_list: prima-animal, mpjpe, uvd, xyz, hm36, cpn, dbb
.github/workflows/release-pypi.yml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Update pypi release
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - 'v*.*.*'
7
+ pull_request:
8
+ branches:
9
+ - main
10
+ types:
11
+ - labeled
12
+ - opened
13
+ - edited
14
+ - synchronize
15
+ - reopened
16
+
17
+ jobs:
18
+ release:
19
+ runs-on: ubuntu-latest
20
+
21
+ steps:
22
+ - name: Cache dependencies
23
+ id: pip-cache
24
+ uses: actions/cache@v4
25
+ with:
26
+ path: ~/.cache/pip
27
+ key: ${{ runner.os }}-pip
28
+
29
+ - name: Install dependencies
30
+ run: |
31
+ pip install --upgrade pip
32
+ pip install wheel
33
+ # NOTE(stes) see https://github.com/pypa/twine/issues/1216#issuecomment-2629069669
34
+ pip install "packaging>=24.2"
35
+
36
+ - name: Checkout code
37
+ uses: actions/checkout@v3
38
+
39
+ - name: Build and publish to PyPI
40
+ if: ${{ github.event_name == 'push' }}
41
+ env:
42
+ TWINE_USERNAME: __token__
43
+ TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }}
44
+ run: |
45
+ pip install build twine
46
+ python3 -m build
47
+ ls dist/
48
+ python3 -m twine upload --verbose dist/*
.gitignore ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+ # Vscode
164
+ .vscode/
165
+
166
+ # Directory
167
+ .gradio/
168
+ demo_out/
169
+ demo_out*/
170
+ data/PRIMA*/
171
+ data/backbone.pth
172
+ logs/
173
+ *.pth
174
+ *.pkl
175
+ datasets/
README.md ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: PRIMA Demo
3
+ emoji: 🦮
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ python_version: "3.10"
8
+ app_file: app.py
9
+ startup_duration_timeout: 60m
10
+ ---
11
+
12
+ # PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
13
+
14
+
15
+ This is the official implementation of the approach described in the preprint:
16
+
17
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation \
18
+ Xiaohang Yu, Ti Wang, Mackenzie Weygandt Mathis
19
+
20
+ ![PRIMA teaser](images/teaser.png)
21
+
22
+
23
+ ---
24
+
25
+
26
+ ## 🚀 TL;DR
27
+ PRIMA creates a 3D quadruped mesh from a single 2D image. It leverages BioCLIP-based biological priors for robust cross-species shape understanding, then applies test-time adaptation with 2D reprojection and auxiliary keypoint guidance to refine SMAL pose and shape predictions.
28
+
29
+ It further can be used to build Quadruped3D, a large-scale pseudo-3D dataset with diverse species and poses.
30
+
31
+ PRIMA achieves state-of-the-art results on Animal3D, CtrlAni3D, Quadruped2D, and Animal Kingdom datasets.
32
+
33
+ ## Installation
34
+
35
+ ### Install from PyPI
36
+
37
+ > Recommended: Python 3.10 and a CUDA-enabled PyTorch installation.
38
+
39
+ ```bash
40
+ conda create -n prima python=3.10 -y
41
+ conda activate prima
42
+
43
+ # Install PyTorch matching your CUDA (example: CUDA 11.8)
44
+ pip install --index-url https://download.pytorch.org/whl/cu118 \
45
+ "torch==2.2.1" "torchvision==0.17.1" "torchaudio==2.2.1"
46
+
47
+ # Install chumpy and PyTorch3D
48
+ python -m pip install --no-build-isolation \
49
+ "git+https://github.com/mattloper/chumpy.git"
50
+ python -m pip install --no-build-isolation \
51
+ "git+https://github.com/facebookresearch/pytorch3d.git"
52
+
53
+ # Install PRIMA from PyPI
54
+ pip install prima-animal
55
+ ```
56
+
57
+ `prima-animal` includes demo runtime dependencies used by `demo.py`, `demo_tta.py`, and `app.py` (including Detectron2 and DeepLabCut).
58
+
59
+ ---
60
+
61
+ ## Demo
62
+
63
+ ### Checkpoints and data
64
+
65
+ We provide an automated demo-download script for models hosted on Hugging Face.
66
+ Use the helper script to download and place all demo assets automatically in `data/`:
67
+
68
+ ```bash
69
+ python scripts/setup_demo_data.py --hf-repo-id MLAdaptiveIntelligence/PRIMA
70
+ ```
71
+
72
+ Approximate download volume from Hugging Face is ~24 GB total
73
+ (`s1ckpt.ckpt` ~10.2 GB + `s3ckpt.ckpt` ~10.2 GB + `amr_vitbb.pth` ~2.5 GB + SMAL files).
74
+ Expected time is roughly:
75
+ - 100 Mbps: ~35-45 minutes
76
+ - 300 Mbps: ~12-18 minutes
77
+ - 1 Gbps: ~4-8 minutes
78
+
79
+ To avoid re-downloading completed assets, rerun without `--force`. The script now
80
+ re-downloads only missing or invalid checkpoints.
81
+
82
+ Expected files in that Hugging Face repo root:
83
+ - `my_smpl_00781_4_all.pkl`
84
+ - `my_smpl_data_00781_4_all.pkl`
85
+ - `walking_toy_symmetric_pose_prior_with_cov_35parts.pkl`
86
+ - `amr_vitbb.pth`
87
+ - `config_s1_HYDRA.yaml`
88
+ - `config_s3_HYDRA.yaml`
89
+ - `s1ckpt.ckpt`
90
+ - `s3ckpt.ckpt`
91
+
92
+ ### Demo (without TTA)
93
+
94
+ Run animal detection + PRIMA 3D pose/shape inference:
95
+
96
+ ```bash
97
+ python demo.py \
98
+ --checkpoint data/PRIMAS1/checkpoints/s1ckpt.ckpt \
99
+ --img_folder demo_data/ \
100
+ --out_folder demo_out/
101
+ ```
102
+
103
+ Outputs are written to `demo_out/`.
104
+
105
+ ---
106
+
107
+ ### Demo (with TTA)
108
+
109
+ `demo_tta.py` pipeline: specify learning rate and number of iterations:
110
+
111
+ Example:
112
+
113
+ ```bash
114
+ python demo_tta.py \
115
+ --checkpoint data/PRIMAS1/checkpoints/s1ckpt.ckpt \
116
+ --img_folder demo_data/ \
117
+ --out_folder demo_out_tta/ \
118
+ --tta_lr 1e-6 \
119
+ --tta_num_iters 30
120
+ ```
121
+
122
+ Outputs are written to `demo_out_tta/` (before/after TTA renders, keypoints, and optional meshes).
123
+
124
+ ---
125
+
126
+ ### Gradio demo
127
+
128
+ We also provide a simple Gradio-based web demo for interactive testing in the
129
+ browser:
130
+
131
+ ```bash
132
+ python app.py \
133
+ --checkpoint data/PRIMAS1/checkpoints/s1ckpt.ckpt \
134
+ --out_folder demo_out_tta_gradio/
135
+ ```
136
+
137
+ This starts a local Gradio app (by default on http://127.0.0.1:7860), where
138
+ you can upload images and visualize PRIMA predictions and adaptation results.
139
+
140
+ #### Hugging Face Space (maintainers)
141
+
142
+ Demo images under `demo_data/` and `images/teaser.png` are tracked with **Git LFS**
143
+ (see `.gitattributes`) so they can be pushed to a Hugging Face Space under the Hub’s
144
+ LFS / **Xet** bridge. Install tooling once:
145
+
146
+ ```bash
147
+ brew install git-lfs git-xet
148
+ git xet install
149
+ git lfs install
150
+ ```
151
+
152
+ Then from a clean checkout with LFS files present, deploy the Space repo:
153
+
154
+ ```bash
155
+ ./scripts/deploy_hf_space.sh
156
+ ```
157
+
158
+ The script rsyncs the working tree (not `git archive`) so image files are materialized
159
+ before `git add` turns them into LFS blobs.
160
+
161
+ ---
162
+
163
+
164
+ ## Training and Evaluation
165
+
166
+ ### Dataset Setup
167
+
168
+ Download datasets from [Animal3D](https://xujiacong.github.io/Animal3D/), [CtrlAni3D](https://github.com/luoxue-star/AniMer?tab=readme-ov-file#training), Quadruped2D, and [Animal Kingdom](https://drive.google.com/file/d/1dk2a0qB0fbVZ4X6eAgP6VJVXj0rxVfsJ/view?usp=drive_link). For Quadruped2D, download the images from [SuperAnimal-Quadruped80K](https://zenodo.org/records/14016777) and our processed annotations from [here](https://drive.google.com/drive/folders/1eBNboxVwl_eGPoC93zxf-U3hmE6e2f-f?usp=sharing). Put all the datasets under `datasets/`.
169
+
170
+ ### Training
171
+
172
+ Two-stage training script:
173
+
174
+ ```bash
175
+ bash train.sh
176
+ ```
177
+
178
+ Training outputs are written to `logs/train/runs/<exp_name>/`.
179
+
180
+
181
+ ### Evaluation
182
+
183
+ ```bash
184
+ python eval.py \
185
+ --config data/PRIMAS1/.hydra/config.yaml \
186
+ --checkpoint data/PRIMAS1/checkpoints/s1ckpt.ckpt
187
+ ```
188
+
189
+ Common values for `--dataset` are controlled by:
190
+ - `configs_hydra/experiment/default_val.yaml`
191
+
192
+ ---
193
+
194
+
195
+ ## Acknowledgements
196
+
197
+ This release builds on several open-source projects, including:
198
+ - [Detectron2](https://github.com/facebookresearch/detectron2)
199
+ - [BioCLIP](https://github.com/Imageomics/BioCLIP)
200
+ - [AniMer](https://github.com/luoxue-star/AniMer)
201
+ - [DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)
202
+
203
+ ---
204
+
205
+ ## Citation
206
+
207
+ If you use this code in your research, please cite our PRIMA paper.
208
+
209
+ ```bibtex
210
+ @misc{yu_prima,
211
+ title={PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation},
212
+ author={Xiaohang Yu and Ti Wang and Mackenzie Weygandt Mathis},
213
+ }
214
+ ```
215
+
216
+ ---
217
+
218
+ ## Contact
219
+
220
+ For issues, please open a GitHub issue in this repository.
app.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """Gradio demo for PRIMA + SuperAnimal + TTA.
11
+
12
+ This script wraps the ``demo_tta.py`` pipeline into an interactive
13
+ Gradio interface. The overall logic follows:
14
+
15
+ 1. Given an input image, run Detectron2 to detect animals.
16
+ 2. For each detected animal, run PRIMA for 3D pose/shape estimation.
17
+ 3. Run DeepLabCut SuperAnimal to obtain 2D keypoints.
18
+ 4. Map SuperAnimal 39 keypoints to the 26 PRIMA keypoints.
19
+ 5. Run test-time adaptation (TTA) with user-specified lr and iters.
20
+ 6. Render and save before/after TTA results and keypoint visualizations.
21
+
22
+ """
23
+
24
+ import argparse
25
+ import os
26
+ import sys
27
+ import tempfile
28
+ import traceback
29
+ from types import SimpleNamespace
30
+ from typing import List, Tuple
31
+ from pathlib import Path
32
+
33
+ import cv2
34
+ import gradio as gr
35
+ import numpy as np
36
+ import torch
37
+ import torch.utils.data
38
+
39
+ # Repo-local minimal ``chumpy`` shim (see ``chumpy/__init__.py``) so SMAL pickles load
40
+ # without installing the full chumpy package in Space builds.
41
+ _REPO_ROOT = Path(__file__).resolve().parent
42
+ if str(_REPO_ROOT) not in sys.path:
43
+ sys.path.insert(0, str(_REPO_ROOT))
44
+
45
+
46
+ # Default checkpoint path following README instructions
47
+ DEFAULT_CHECKPOINT = "data/PRIMAS1/checkpoints/s1ckpt.ckpt"
48
+ DEFAULT_HF_ASSET_REPO = "MLAdaptiveIntelligence/PRIMA"
49
+
50
+ # Output folder for rendered images/meshes and keypoints
51
+ DEFAULT_OUT_FOLDER = "demo_out_tta_gradio"
52
+
53
+
54
+ def _is_truthy_env(var_name: str) -> bool:
55
+ return os.environ.get(var_name, "").strip().lower() in {"1", "true", "yes", "on"}
56
+
57
+
58
+ def _running_on_space() -> bool:
59
+ return bool(os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID"))
60
+
61
+
62
+ def _gradio_examples_for_interface() -> List[List]:
63
+ """Gradio prefetches example media at startup.
64
+
65
+ Demo images are tracked with Git LFS / Xet (see ``.gitattributes``) so they can live
66
+ in the Hugging Face Space repo. Use absolute paths only when files exist beside ``app.py``.
67
+ """
68
+ if _is_truthy_env("PRIMA_DISABLE_GRADIO_EXAMPLES"):
69
+ return []
70
+ rows: List[List] = []
71
+ template: List[Tuple[str, float, int, float, float, bool, bool]] = [
72
+ ("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, True),
73
+ ("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, True),
74
+ ("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, True),
75
+ ("demo_data/beagle.jpg", 1e-6, 0, 0.7, 0.1, False, True),
76
+ ("demo_data/shepherd_hati.jpg", 1e-6, 0, 0.7, 0.1, False, True),
77
+ ]
78
+ for rel, *rest in template:
79
+ p = _REPO_ROOT / rel
80
+ if p.is_file():
81
+ rows.append([str(p), *rest])
82
+ return rows
83
+
84
+
85
+ def _should_preload_assets() -> bool:
86
+ """Default to preload on Spaces; configurable via PRIMA_PRELOAD_ASSETS."""
87
+ preload_env = os.environ.get("PRIMA_PRELOAD_ASSETS")
88
+ if preload_env is not None:
89
+ return _is_truthy_env("PRIMA_PRELOAD_ASSETS")
90
+ return _running_on_space()
91
+
92
+
93
+ def _ensure_demo_assets(checkpoint_path: str) -> None:
94
+ """Download required demo assets when running in a clean environment."""
95
+ from scripts.setup_demo_data import (
96
+ maybe_download_smal,
97
+ maybe_download_backbone,
98
+ maybe_download_stage,
99
+ )
100
+
101
+ checkpoint = Path(checkpoint_path)
102
+ data_dir = checkpoint.parents[2]
103
+ hf_repo_id = os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO)
104
+
105
+ maybe_download_smal(data_dir, force=False, hf_repo_id=hf_repo_id)
106
+ maybe_download_backbone(data_dir, force=False, hf_repo_id=hf_repo_id)
107
+ maybe_download_stage(
108
+ "PRIMAS1",
109
+ "config_s1_HYDRA.yaml",
110
+ "s1ckpt.ckpt",
111
+ "s1ckpt.ckpt",
112
+ data_dir,
113
+ force=False,
114
+ hf_repo_id=hf_repo_id,
115
+ )
116
+
117
+
118
+ def _preload_assets_once(checkpoint_path: str) -> None:
119
+ checkpoint = Path(checkpoint_path)
120
+ cfg_path = checkpoint.parent.parent / ".hydra" / "config.yaml"
121
+ if checkpoint.exists() and cfg_path.exists():
122
+ print("[startup] Assets already present; skipping preload.")
123
+ return
124
+ print("[startup] Preloading demo assets from Hugging Face Hub...")
125
+ _ensure_demo_assets(checkpoint_path)
126
+ print("[startup] Asset preload complete.")
127
+
128
+
129
+ def _load_prima_model(checkpoint_path: str = DEFAULT_CHECKPOINT):
130
+ """Load PRIMA model and renderer once for the Gradio app."""
131
+ from prima.models import load_prima
132
+ from prima.utils.renderer import Renderer
133
+
134
+ checkpoint = Path(checkpoint_path)
135
+ cfg_path = checkpoint.parent.parent / ".hydra" / "config.yaml"
136
+ if not checkpoint.exists() or not cfg_path.exists():
137
+ _ensure_demo_assets(checkpoint_path)
138
+ if not checkpoint.exists():
139
+ raise FileNotFoundError(
140
+ f"Missing checkpoint: {checkpoint}. Download demo checkpoints/data as described in README."
141
+ )
142
+ if not cfg_path.exists():
143
+ raise FileNotFoundError(
144
+ f"Missing model config: {cfg_path}. Ensure the full checkpoint folder layout from README is present."
145
+ )
146
+
147
+ model, model_cfg = load_prima(checkpoint_path)
148
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
149
+ model = model.to(device)
150
+ model.eval()
151
+
152
+ renderer = Renderer(model_cfg, faces=model.smal.faces)
153
+ return model, model_cfg, renderer, device
154
+
155
+
156
+ def _build_detector():
157
+ """Build Detectron2 animal detector (same config as demo_tta/demo.py)."""
158
+ try:
159
+ import detectron2.config
160
+ import detectron2.engine
161
+ from detectron2 import model_zoo
162
+ except Exception as e:
163
+ print(f"[warn] Detectron2 unavailable ({type(e).__name__}: {e}); using full-image fallback bbox.")
164
+ return None
165
+
166
+ cfg = detectron2.config.get_cfg()
167
+ cfg.merge_from_file(
168
+ model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml")
169
+ )
170
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
171
+ cfg.MODEL.WEIGHTS = (
172
+ "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/"
173
+ "faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
174
+ )
175
+ cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
176
+ detector = detectron2.engine.DefaultPredictor(cfg)
177
+ return detector
178
+
179
+ # SuperAnimal defaults (same as in demo_tta parser)
180
+ SUPER_ANIMAL_ARGS = SimpleNamespace(
181
+ superanimal_name="superanimal_quadruped",
182
+ superanimal_model_name="hrnet_w32",
183
+ superanimal_detector_name="fasterrcnn_resnet50_fpn_v2",
184
+ superanimal_max_individuals=1,
185
+ )
186
+
187
+
188
+ def _collect_animal_results(
189
+ model,
190
+ model_cfg,
191
+ renderer,
192
+ device,
193
+ detector,
194
+ out_folder: str,
195
+ img_rgb: np.ndarray,
196
+ tta_lr: float,
197
+ tta_num_iters: int,
198
+ det_thresh: float,
199
+ kp_conf_thresh: float,
200
+ side_view: bool,
201
+ save_mesh: bool,
202
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], str | None, str | None]:
203
+ """Run detection + PRIMA + SuperAnimal + TTA on a single RGB image.
204
+
205
+ Returns:
206
+ before_imgs: list of HxWx3 RGB images (before TTA) for all animals
207
+ after_imgs: list of HxWx3 RGB images (after TTA) for all animals
208
+ kpt_imgs: list of HxWx3 RGB keypoint visualizations
209
+ first_before_mesh: path to first animal's before-TTA mesh (.obj) or None
210
+ first_after_mesh: path to first animal's after-TTA mesh (.obj) or None
211
+ """
212
+ from prima.utils import recursive_to
213
+ from prima.datasets.vitdet_dataset import ViTDetDataset
214
+ from demo_tta import (
215
+ ANIMAL_COCO_IDS,
216
+ denorm_patch_to_rgb,
217
+ map_superanimal_to_prima,
218
+ run_superanimal_on_patch,
219
+ save_keypoint_vis,
220
+ tta_optimize,
221
+ )
222
+
223
+ # Detect animals
224
+ img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
225
+ if detector is None:
226
+ # Fallback for environments where Detectron2 is unavailable: process full image as one crop.
227
+ h, w = img_bgr.shape[:2]
228
+ boxes = np.array([[0.0, 0.0, float(max(1, w - 1)), float(max(1, h - 1))]], dtype=np.float32)
229
+ else:
230
+ det_out = detector(img_bgr)
231
+ det_instances = det_out["instances"]
232
+
233
+ valid_idx = [
234
+ i
235
+ for i, (c, s) in enumerate(zip(det_instances.pred_classes, det_instances.scores))
236
+ if (int(c) in ANIMAL_COCO_IDS) and (float(s) > float(det_thresh))
237
+ ]
238
+ if len(valid_idx) == 0:
239
+ return [], [], [], None, None
240
+
241
+ boxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy()
242
+
243
+ dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
244
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
245
+
246
+ before_imgs: List[np.ndarray] = []
247
+ after_imgs: List[np.ndarray] = []
248
+ kpt_imgs: List[np.ndarray] = []
249
+ before_mesh_paths: List[str] = []
250
+ after_mesh_paths: List[str] = []
251
+
252
+ img_token = next(tempfile._get_candidate_names())
253
+
254
+ for batch in dataloader:
255
+ batch = recursive_to(batch, device)
256
+
257
+ with torch.no_grad():
258
+ out_before = model(batch)
259
+
260
+ animal_id = int(batch["animalid"][0])
261
+
262
+ # Save/render before TTA
263
+ img_fn = f"{img_token}"
264
+ from demo_tta import render_and_save # imported lazily to avoid circular issues
265
+
266
+ render_and_save(
267
+ renderer,
268
+ out_before,
269
+ batch,
270
+ img_fn,
271
+ animal_id,
272
+ out_folder,
273
+ suffix="before_tta",
274
+ side_view=side_view,
275
+ save_mesh=save_mesh,
276
+ )
277
+
278
+ before_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.png")
279
+ if os.path.exists(before_png_path):
280
+ before_bgr = cv2.imread(before_png_path)
281
+ if before_bgr is not None:
282
+ before_imgs.append(cv2.cvtColor(before_bgr, cv2.COLOR_BGR2RGB))
283
+
284
+ if save_mesh:
285
+ before_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.obj")
286
+ if os.path.exists(before_obj_path):
287
+ before_mesh_paths.append(before_obj_path)
288
+
289
+ if int(tta_num_iters) <= 0:
290
+ render_and_save(
291
+ renderer,
292
+ out_before,
293
+ batch,
294
+ img_fn,
295
+ animal_id,
296
+ out_folder,
297
+ suffix="after_tta",
298
+ side_view=side_view,
299
+ save_mesh=save_mesh,
300
+ )
301
+
302
+ after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
303
+ if os.path.exists(after_png_path):
304
+ after_bgr = cv2.imread(after_png_path)
305
+ if after_bgr is not None:
306
+ after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB))
307
+
308
+ if save_mesh:
309
+ after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
310
+ if os.path.exists(after_obj_path):
311
+ after_mesh_paths.append(after_obj_path)
312
+ continue
313
+
314
+ # Prepare patch for SuperAnimal
315
+ patch_rgb = denorm_patch_to_rgb(batch["img"][0])
316
+ with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
317
+ bodyparts_xyc = run_superanimal_on_patch(patch_rgb, SUPER_ANIMAL_ARGS, tmp_dir)
318
+
319
+ if bodyparts_xyc is None:
320
+ # No keypoints => skip TTA for this animal
321
+ continue
322
+
323
+ mapped_xyc = map_superanimal_to_prima(bodyparts_xyc)
324
+ mapped_xyc[mapped_xyc[:, 2] < float(kp_conf_thresh), 2] = 0.0
325
+
326
+ # Save keypoint visualization and npy
327
+ kpt_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png")
328
+ save_keypoint_vis(patch_rgb, mapped_xyc, kpt_png_path)
329
+ npy_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy")
330
+ np.save(npy_path, mapped_xyc)
331
+
332
+ if os.path.exists(kpt_png_path):
333
+ kpt_bgr = cv2.imread(kpt_png_path)
334
+ if kpt_bgr is not None:
335
+ kpt_imgs.append(cv2.cvtColor(kpt_bgr, cv2.COLOR_BGR2RGB))
336
+
337
+ # Normalize keypoints to [-0.5, 0.5] as in demo_tta
338
+ patch_h, patch_w = patch_rgb.shape[:2]
339
+ mapped_norm = mapped_xyc.copy()
340
+ mapped_norm[:, 0] = mapped_norm[:, 0] / float(patch_w) - 0.5
341
+ mapped_norm[:, 1] = mapped_norm[:, 1] / float(patch_h) - 0.5
342
+ gt_kpts_norm = torch.from_numpy(mapped_norm[None]).to(device=device, dtype=batch["img"].dtype)
343
+
344
+ # Run TTA
345
+ out_after = tta_optimize(
346
+ model,
347
+ batch,
348
+ gt_kpts_norm,
349
+ num_iters=int(tta_num_iters),
350
+ lr=float(tta_lr),
351
+ )
352
+
353
+ render_and_save(
354
+ renderer,
355
+ out_after,
356
+ batch,
357
+ img_fn,
358
+ animal_id,
359
+ out_folder,
360
+ suffix="after_tta",
361
+ side_view=side_view,
362
+ save_mesh=save_mesh,
363
+ )
364
+
365
+ after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
366
+ if os.path.exists(after_png_path):
367
+ after_bgr = cv2.imread(after_png_path)
368
+ if after_bgr is not None:
369
+ after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB))
370
+
371
+ if save_mesh:
372
+ after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
373
+ if os.path.exists(after_obj_path):
374
+ after_mesh_paths.append(after_obj_path)
375
+
376
+ first_before_mesh = before_mesh_paths[0] if before_mesh_paths else None
377
+ first_after_mesh = after_mesh_paths[0] if after_mesh_paths else None
378
+
379
+ return before_imgs, after_imgs, kpt_imgs, first_before_mesh, first_after_mesh
380
+
381
+
382
+ def build_demo(checkpoint_path: str = DEFAULT_CHECKPOINT, out_folder: str = DEFAULT_OUT_FOLDER) -> gr.Interface:
383
+ os.makedirs(out_folder, exist_ok=True)
384
+ runtime_cache = {
385
+ "model": None,
386
+ "model_cfg": None,
387
+ "renderer": None,
388
+ "device": None,
389
+ "detector": None,
390
+ }
391
+
392
+ def gradio_inference(
393
+ image: np.ndarray,
394
+ tta_lr: float,
395
+ tta_num_iters: int,
396
+ det_thresh: float,
397
+ kp_conf_thresh: float,
398
+ side_view: bool,
399
+ save_mesh: bool,
400
+ ):
401
+ """Wrapper for Gradio. ``image`` is an RGB numpy array.
402
+
403
+ Yields intermediate status so long first-run (Hub downloads + model load)
404
+ does not hit silent client/proxy timeouts.
405
+ """
406
+
407
+ if image is None:
408
+ yield None, None, None, "No image provided."
409
+ return
410
+
411
+ if image.dtype != np.uint8:
412
+ img_rgb = np.clip(image, 0, 255).astype(np.uint8)
413
+ else:
414
+ img_rgb = image
415
+
416
+ yield None, None, None, "Queued; preparing run…"
417
+
418
+ if runtime_cache["model"] is None:
419
+ yield (
420
+ None,
421
+ None,
422
+ None,
423
+ "First run: downloading demo assets from Hugging Face (large checkpoint) "
424
+ "and loading the model. This can take many minutes; status updates here "
425
+ "mean the session is still alive.",
426
+ )
427
+ try:
428
+ model, model_cfg, renderer, device = _load_prima_model(checkpoint_path)
429
+ detector = _build_detector()
430
+ except Exception:
431
+ yield None, None, None, f"Model initialization failed:\n{traceback.format_exc()}"
432
+ return
433
+ runtime_cache["model"] = model
434
+ runtime_cache["model_cfg"] = model_cfg
435
+ runtime_cache["renderer"] = renderer
436
+ runtime_cache["device"] = device
437
+ runtime_cache["detector"] = detector
438
+ yield None, None, None, "Model loaded. Running detection and inference…"
439
+
440
+ try:
441
+ before_imgs, after_imgs, kpt_imgs, mesh_before, mesh_after = _collect_animal_results(
442
+ runtime_cache["model"],
443
+ runtime_cache["model_cfg"],
444
+ runtime_cache["renderer"],
445
+ runtime_cache["device"],
446
+ runtime_cache["detector"],
447
+ out_folder,
448
+ img_rgb,
449
+ tta_lr=tta_lr,
450
+ tta_num_iters=tta_num_iters,
451
+ det_thresh=det_thresh,
452
+ kp_conf_thresh=kp_conf_thresh,
453
+ side_view=side_view,
454
+ save_mesh=save_mesh,
455
+ )
456
+ except Exception:
457
+ yield None, None, None, f"Inference failed:\n{traceback.format_exc()}"
458
+ return
459
+
460
+ first_before = before_imgs[0] if before_imgs else None
461
+ first_after = after_imgs[0] if after_imgs else None
462
+ first_kpts = kpt_imgs[0] if kpt_imgs else None
463
+ if first_before is None and first_after is None:
464
+ yield (
465
+ None,
466
+ None,
467
+ None,
468
+ "No output generated. Try an image with a clearly visible quadruped.",
469
+ )
470
+ return
471
+ yield first_before, first_after, first_kpts, "OK"
472
+
473
+ _gradio_examples = _gradio_examples_for_interface()
474
+ _iface_kw = dict(
475
+ fn=gradio_inference,
476
+ analytics_enabled=False,
477
+ cache_examples=False,
478
+ inputs=[
479
+ gr.Image(
480
+ label="Input image",
481
+ type="numpy",
482
+ sources=["upload", "clipboard"],
483
+ ),
484
+ gr.Slider(
485
+ label="TTA learning rate",
486
+ minimum=1e-7,
487
+ maximum=1e-4,
488
+ value=1e-6,
489
+ step=1e-7,
490
+ ),
491
+ gr.Slider(
492
+ label="TTA iterations",
493
+ minimum=0,
494
+ maximum=100,
495
+ value=30,
496
+ step=1,
497
+ info="Set to 0 to disable TTA and reuse the initial PRIMA prediction.",
498
+ ),
499
+ gr.Slider(
500
+ label="Detection threshold",
501
+ minimum=0.3,
502
+ maximum=0.9,
503
+ value=0.7,
504
+ step=0.05,
505
+ ),
506
+ gr.Slider(
507
+ label="Keypoint confidence threshold",
508
+ minimum=0.0,
509
+ maximum=1.0,
510
+ value=0.1,
511
+ step=0.05,
512
+ ),
513
+ gr.Checkbox(label="Render side view", value=False),
514
+ gr.Checkbox(label="Save meshes (.obj)", value=True),
515
+ ],
516
+ outputs=[
517
+ gr.Image(label="Before TTA"),
518
+ gr.Image(label="After TTA"),
519
+ gr.Image(label="PRIMA 26 keypoints"),
520
+ gr.Textbox(label="Status / Traceback", lines=12),
521
+ ],
522
+ title="PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation",
523
+ description=(
524
+ "Upload an animal image. The demo runs Detectron2 for animal detection, "
525
+ "PRIMA for 3D pose/shape, DeepLabCut SuperAnimal for 2D keypoints, and "
526
+ "test-time adaptation (TTA) with configurable learning rate and iterations. "
527
+ "Set TTA iterations to 0 to disable adaptation.\n\n"
528
+ "Results (PNG/OBJ and 26-keypoint visualizations) are saved under "
529
+ f"'{out_folder}'."
530
+ ),
531
+ )
532
+ if _gradio_examples:
533
+ _iface_kw["examples"] = _gradio_examples
534
+ demo = gr.Interface(**_iface_kw)
535
+ demo.queue(max_size=8, default_concurrency_limit=1)
536
+ return demo
537
+
538
+
539
+ def parse_args() -> argparse.Namespace:
540
+ parser = argparse.ArgumentParser(description="Gradio demo for PRIMA + SuperAnimal + TTA")
541
+ parser.add_argument(
542
+ "--checkpoint",
543
+ type=str,
544
+ default=DEFAULT_CHECKPOINT,
545
+ help="Path to the pretrained PRIMA checkpoint",
546
+ )
547
+ parser.add_argument(
548
+ "--out_folder",
549
+ type=str,
550
+ default=DEFAULT_OUT_FOLDER,
551
+ help="Folder used to save rendered outputs and meshes",
552
+ )
553
+ return parser.parse_args()
554
+
555
+
556
+ if __name__ == "__main__":
557
+ args = parse_args()
558
+ if _should_preload_assets():
559
+ _preload_assets_once(args.checkpoint)
560
+ demo = build_demo(checkpoint_path=args.checkpoint, out_folder=args.out_folder)
561
+ demo.launch()
chumpy/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """Minimal ``chumpy`` compatibility for unpickling legacy SMAL model configs."""
11
+
12
+ from __future__ import annotations
13
+
14
+ from .ch import Ch, ChArray
15
+
16
+ __all__ = ["Ch", "ChArray"]
chumpy/ch.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """``chumpy.ch`` namespace expected by legacy SMAL pickles."""
11
+
12
+ from __future__ import annotations
13
+
14
+ import numpy as np
15
+
16
+
17
+ class Ch:
18
+ """Minimal stand-in for ``chumpy.ch.Ch`` (unpickling only)."""
19
+
20
+ def __init__(self, *args, **kwargs):
21
+ self._data = None
22
+ if args:
23
+ self._data = np.asarray(args[0])
24
+
25
+ def r(self):
26
+ if self._data is None:
27
+ return np.zeros((), dtype=np.float32)
28
+ return np.asarray(self._data)
29
+
30
+
31
+ class ChArray(np.ndarray):
32
+ """Minimal stand-in for ``chumpy.ch.ChArray``."""
33
+
34
+ pass
35
+
36
+
37
+ __all__ = ["Ch", "ChArray"]
configs_hydra/experiment/default.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ SMAL:
4
+ DATA_DIR: data/smal
5
+ MODEL_PATH: data/smal/my_smpl_00781_4_all.pkl
6
+ SHAPE_PRIOR_PATH: data/smal/my_smpl_data_00781_4_all.pkl
7
+ POSE_PRIOR_PATH: data/smal/walking_toy_symmetric_pose_prior_with_cov_35parts.pkl
8
+ NUM_JOINTS: 34
9
+
10
+ EXTRA:
11
+ FOCAL_LENGTH: 1000
12
+ NUM_LOG_IMAGES: 4
13
+ NUM_LOG_SAMPLES_PER_IMAGE: 4
14
+ PELVIS_IND: 0
15
+
16
+ DATASETS:
17
+ CONFIG:
18
+ SCALE_FACTOR: 0.3
19
+ ROT_FACTOR: 30
20
+ TRANS_FACTOR: 0.02
21
+ COLOR_SCALE: 0.2
22
+ ROT_AUG_RATE: 0.6
23
+ TRANS_AUG_RATE: 0.5
24
+ DO_FLIP: False
25
+ FLIP_AUG_RATE: 0.0
26
+ EXTREME_CROP_AUG_RATE: 0.0
27
+ EXTREME_CROP_AUG_LEVEL: 1
28
+
configs_hydra/experiment/default_val.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ DATASETS:
4
+ ANIMAL3D:
5
+ ROOT_IMAGE: ./datasets/animal3d/
6
+ JSON_FILE:
7
+ TEST: ./datasets/animal3d/test.json
8
+ CONTROL_ANIMAL3D:
9
+ ROOT_IMAGE: ./datasets/control_animal3dlatest/
10
+ JSON_FILE:
11
+ TEST: ./datasets/control_animal3dlatest/test.json
12
+ QUADRUPED2D:
13
+ ROOT_IMAGE: ./datasets/quadruped2d/
14
+ JSON_FILE:
15
+ TEST: ./datasets/quadruped2d/test.json
16
+ ANIMAL_KINGDOM:
17
+ ROOT_IMAGE: ./datasets/Animal_Kingdom_test/
18
+ JSON_FILE:
19
+ TEST: ./datasets/Animal_Kingdom_test/test.json
20
+ CONFIG:
21
+ SCALE_FACTOR: 0.0
22
+ ROT_FACTOR: 0
23
+ TRANS_FACTOR: 0.0
24
+ COLOR_SCALE: 0.0
25
+ ROT_AUG_RATE: 0.0
26
+ TRANS_AUG_RATE: 0.0
27
+ DO_FLIP: False
28
+ FLIP_AUG_RATE: 0.0
29
+ EXTREME_CROP_AUG_RATE: 0.0
30
+ EXTREME_CROP_AUG_LEVEL: 1
31
+
32
+ METRIC:
33
+ PCK_THRESHOLD: [0.10, 0.15]
34
+
configs_hydra/experiment/primaStage1.yaml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - default.yaml
5
+
6
+ GENERAL:
7
+
8
+
9
+ TOTAL_STEPS: 63_000
10
+ LOG_STEPS: 63
11
+ VAL_STEPS: 63
12
+ VAL_EPOCHS: 1
13
+ CHECKPOINT_EPOCHS: 1
14
+ CHECKPOINT_SAVE_TOP_K: 2
15
+ NUM_WORKERS: 8
16
+ PREFETCH_FACTOR: 2
17
+
18
+ LOSS_WEIGHTS:
19
+ KEYPOINTS_3D: 0.05
20
+ KEYPOINTS_2D: 0.01
21
+ INTERMEDIATE_KP2D: 0.001
22
+ INTERMEDIATE_KP3D: 0.001
23
+ GLOBAL_ORIENT: 0.005
24
+ POSE: 0.001
25
+ BETAS: 0.0005
26
+ TRANSL: 0.0005
27
+ ADVERSARIAL: 0.0005
28
+ SUPCON: 0.0005
29
+
30
+
31
+ TRAIN:
32
+ LR: 3.75e-6
33
+ WEIGHT_DECAY: 1e-4
34
+ BATCH_SIZE: 48
35
+ LOSS_REDUCTION: mean
36
+ NUM_TRAIN_SAMPLES: 2
37
+ NUM_TEST_SAMPLES: 64
38
+ POSE_2D_NOISE_RATIO: 0.01
39
+ SMPL_PARAM_NOISE_RATIO: 0.005
40
+
41
+ MODEL:
42
+ IMAGE_SIZE: 256
43
+ IMAGE_MEAN: [0.485, 0.456, 0.406]
44
+ IMAGE_STD: [0.229, 0.224, 0.225]
45
+ BACKBONE:
46
+ TYPE: vith
47
+ PRETRAINED_WEIGHTS: ./data/amr_vitbb.pth
48
+ FREEZE: False
49
+
50
+ # Enable BioClip embedding
51
+ USE_BIOCLIP_EMBEDDING: True
52
+ BIOCLIP_EMBEDDING:
53
+ EMBED_DIM: 1280 # Match DINOv2 output dimension for token-wise concatenation
54
+ TYPE: bioclip1
55
+
56
+ # Enable 2D keypoint embedding for initialization; NewBioGuidedSMALPoseDecoder updates it dynamically
57
+ USE_KEYPOINT_EMBEDDING: False
58
+
59
+ SMAL_HEAD:
60
+ TYPE: new_bio_pose_transformer_decoder # Use the newer version with SAM3D-style hierarchical updates
61
+ IN_CHANNELS: 1280
62
+ IEF_ITERS: 3
63
+
64
+ # Pose Transformer Decoder configuration
65
+ DECODER_DIM: 1280
66
+ NUM_DECODER_LAYERS: 6
67
+ NUM_HEADS: 8
68
+ MLP_RATIO: 4.0
69
+
70
+ # Keypoint token configuration specific to NewBioGuidedSMALPoseDecoder
71
+ USE_KEYPOINT_2D_TOKENS: True # Enable 2D keypoint tokens with SAM3D-style dynamic updates
72
+ USE_KEYPOINT_3D_TOKENS: True # Enable 3D keypoint tokens with pelvis normalization
73
+ KEYPOINT_TOKEN_UPDATE: True # Enable hierarchical keypoint prediction and token updates
74
+ KP2D_INJECT_IMAGE_FEAT: True # Key setting: inject image features via grid_sample
75
+
76
+
77
+ DATASETS:
78
+ ANIMAL3D:
79
+ ROOT_IMAGE: ./datasets/animal3d/
80
+ JSON_FILE:
81
+ TRAIN: ./datasets/animal3d/train.json
82
+ TEST: ./datasets/animal3d/test.json
83
+ WEIGHT: 1.0
configs_hydra/experiment/primaStage2.yaml ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - default.yaml
5
+
6
+ GENERAL:
7
+
8
+
9
+ TOTAL_STEPS: 450_000
10
+ LOG_STEPS: 533
11
+ VAL_STEPS: 533
12
+ VAL_EPOCHS: 1
13
+ CHECKPOINT_EPOCHS: 1
14
+ CHECKPOINT_SAVE_TOP_K: 2
15
+ NUM_WORKERS: 2
16
+ PREFETCH_FACTOR: 2
17
+
18
+ LOSS_WEIGHTS:
19
+ KEYPOINTS_3D: 0.05
20
+ KEYPOINTS_2D: 0.01
21
+ INTERMEDIATE_KP2D: 0.001
22
+ INTERMEDIATE_KP3D: 0.001
23
+ GLOBAL_ORIENT: 0.005
24
+ POSE: 0.001
25
+ BETAS: 0.0005
26
+ TRANSL: 0.0005
27
+ ADVERSARIAL: 0.0
28
+ SUPCON: 0.0005
29
+
30
+
31
+ TRAIN:
32
+ LR: 3.75e-6
33
+ WEIGHT_DECAY: 1e-4
34
+ BATCH_SIZE: 48
35
+ LOSS_REDUCTION: mean
36
+ NUM_TRAIN_SAMPLES: 2
37
+ NUM_TEST_SAMPLES: 64
38
+ POSE_2D_NOISE_RATIO: 0.01
39
+ SMPL_PARAM_NOISE_RATIO: 0.005
40
+
41
+ MODEL:
42
+ IMAGE_SIZE: 256
43
+ IMAGE_MEAN: [0.485, 0.456, 0.406]
44
+ IMAGE_STD: [0.229, 0.224, 0.225]
45
+ BACKBONE:
46
+ TYPE: vith
47
+ PRETRAINED_WEIGHTS: ./data/amr_vitbb.pth
48
+ FREEZE: False
49
+
50
+ # Enable BioClip embedding
51
+ USE_BIOCLIP_EMBEDDING: True
52
+ BIOCLIP_EMBEDDING:
53
+ EMBED_DIM: 1280 # Match vit output dimension for token-wise concatenation
54
+ TYPE: bioclip1
55
+
56
+ # Enable 2D keypoint embedding
57
+ USE_KEYPOINT_EMBEDDING: False
58
+ KEYPOINT_EMBEDDING:
59
+ NUM_KEYPOINTS: 26 # Number of SMAL keypoints
60
+ KEYPOINT_DIM: 2 # 2D coordinates (x, y)
61
+ EMBED_DIM: 1280 # Match vit output dimension
62
+ HIDDEN_DIM: 512 # Hidden layer dimension in MLP
63
+ TYPE: 'token' # Use token-based embedding (recommended)
64
+
65
+ SMAL_HEAD:
66
+ TYPE: new_bio_pose_transformer_decoder # Use the newer version with SAM3D-style hierarchical updates
67
+ IN_CHANNELS: 1280
68
+ IEF_ITERS: 1
69
+
70
+ # Pose Transformer Decoder configuration
71
+ DECODER_DIM: 1280
72
+ NUM_DECODER_LAYERS: 6
73
+ NUM_HEADS: 8
74
+ MLP_RATIO: 4.0
75
+
76
+ # Keypoint token configuration specific to NewBioGuidedSMALPoseDecoder
77
+ USE_KEYPOINT_2D_TOKENS: True # Enable 2D keypoint tokens with SAM3D-style dynamic updates
78
+ USE_KEYPOINT_3D_TOKENS: True # Enable 3D keypoint tokens with pelvis normalization
79
+ KEYPOINT_TOKEN_UPDATE: True # Enable hierarchical keypoint prediction and token updates
80
+ KP2D_INJECT_IMAGE_FEAT: True # Key setting: inject image features via grid_sample
81
+
82
+ # Legacy transformer config (kept for compatibility)
83
+ TRANSFORMER_DECODER:
84
+ depth: 6
85
+ heads: 8
86
+ mlp_dim: 1024
87
+ dim_head: 64
88
+ dropout: 0.0
89
+ emb_dropout: 0.0
90
+ norm: layer
91
+ context_dim: 1280
92
+
93
+
94
+
95
+ DATASETS:
96
+ ANIMAL3D:
97
+ ROOT_IMAGE: ./datasets/animal3d/
98
+ JSON_FILE:
99
+ TRAIN: ./datasets/animal3d/train.json
100
+ TEST: ./datasets/animal3d/test.json
101
+ WEIGHT: 1.0
102
+ CONTROL_ANIMAL3D:
103
+ ROOT_IMAGE: ./datasets/control_animal3dlatest/
104
+ JSON_FILE:
105
+ TRAIN: ./datasets/control_animal3dlatest/train.json
106
+ TEST: ./datasets/control_animal3dlatest/test.json
107
+ WEIGHT: 0.5
108
+ QUADRUPED2D:
109
+ ROOT_IMAGE: ./datasets/quadruped2d/
110
+ JSON_FILE:
111
+ TRAIN: ./datasets/quadruped2d/train.json
112
+ TEST: ./datasets/quadruped2d/test.json
113
+ WEIGHT: 0.15
configs_hydra/extras/default.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # disable python warnings if they annoy you
2
+ ignore_warnings: False
3
+
4
+ # ask user for tags if none are provided in the config
5
+ enforce_tags: True
6
+
7
+ # pretty print config tree at the start of the run using Rich library
8
+ print_config: True
configs_hydra/hydra/default.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ # https://hydra.cc/docs/configure_hydra/intro/
3
+
4
+ # enable color logging
5
+ defaults:
6
+ - override /hydra/hydra_logging: colorlog
7
+ - override /hydra/job_logging: colorlog
8
+
9
+ # exp_name: ovrd_${hydra:job.override_dirname}
10
+ exp_name: ${now:%Y-%m-%d}_${now:%H-%M-%S}
11
+
12
+ hydra:
13
+ run:
14
+ dir: ${paths.log_dir}/${task_name}/runs/${exp_name}
15
+ sweep:
16
+ dir: ${paths.log_dir}/${task_name}/multiruns/${exp_name}
17
+ subdir: ${hydra.job.num}
18
+ job:
19
+ config:
20
+ override_dirname:
21
+ exclude_keys:
22
+ - trainer
23
+ - trainer.devices
24
+ - trainer.num_nodes
25
+ - callbacks
26
+ - debug
configs_hydra/launcher/local.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /hydra/launcher: submitit_local
5
+
6
+ hydra:
7
+ launcher:
8
+ timeout_min: 10_080 # 7 days
9
+ nodes: 1
10
+ tasks_per_node: ${trainer.devices}
11
+ cpus_per_task: 8
12
+ gpus_per_node: ${trainer.devices}
13
+ name: amr
configs_hydra/launcher/slurm.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /hydra/launcher: submitit_slurm
5
+
6
+ hydra:
7
+ launcher:
8
+ timeout_min: 10_080 # 7 days
9
+ max_num_timeout: 3
10
+ partition: g40
11
+ qos: idle
12
+ nodes: 1
13
+ tasks_per_node: ${trainer.devices}
14
+ gpus_per_task: null
15
+ cpus_per_task: 12
16
+ gpus_per_node: ${trainer.devices}
17
+ cpus_per_gpu: null
18
+ comment: prima
19
+ name: prima
20
+ setup:
21
+ - module load cuda openmpi libfabric-aws
22
+ - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
configs_hydra/paths/default.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # path to root directory
2
+ # this requires PROJECT_ROOT environment variable to exist
3
+ # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py`
4
+ root_dir: ${oc.env:PROJECT_ROOT}
5
+
6
+ # path to data directory
7
+ data_dir: ${paths.root_dir}/data/
8
+
9
+ # path to logging directory
10
+ log_dir: logs/
11
+
12
+ # path to output directory, created dynamically by hydra
13
+ # path generation pattern is specified in `configs/hydra/default.yaml`
14
+ # use it to store all files generated during the run, like ckpts and metrics
15
+ output_dir: ${hydra:runtime.output_dir}
16
+
17
+ # path to working directory
18
+ work_dir: ${hydra:runtime.cwd}
configs_hydra/train.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # specify here default configuration
4
+ # order of defaults determines the order in which configs override each other
5
+ defaults:
6
+ - _self_
7
+ - trainer: ddp.yaml
8
+ - paths: default.yaml
9
+ - extras: default.yaml
10
+ - hydra: default.yaml
11
+
12
+ # experiment configs allow for version control of specific hyperparameters
13
+ # e.g. best hyperparameters for given model and datamodule
14
+ - experiment: null
15
+ - texture_exp: null
16
+
17
+ # optional local config for machine/user specific settings
18
+ # it's optional since it doesn't need to exist and is excluded from version control
19
+ - optional launcher: local.yaml
20
+ # - optional launcher: slurm.yaml
21
+
22
+ # debugging config (enable through command line, e.g. `python train.py debug=default)
23
+ - debug: null
24
+
25
+ # task name, determines output directory path
26
+ task_name: "train"
27
+
28
+ # tags to help you identify your experiments
29
+ # you can overwrite this in experiment configs
30
+ # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
31
+ # appending lists from command line is currently not supported :(
32
+ # https://github.com/facebookresearch/hydra/issues/1547
33
+ tags: ["dev"]
34
+
35
+ # set False to skip model training
36
+ train: True
37
+
38
+ # evaluate on test set, using best model weights achieved during training
39
+ # lightning chooses best weights based on the metric specified in checkpoint callback
40
+ test: False
41
+
42
+ # simply provide checkpoint path to resume training
43
+ ckpt_path: True
44
+
45
+ # seed for random number generators in pytorch, numpy and python.random
46
+ seed: null
configs_hydra/trainer/cpu.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default.yaml
3
+ - default_amr.yaml
4
+
5
+ accelerator: cpu
6
+ devices: 1
configs_hydra/trainer/ddp.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default.yaml
3
+ - default_amr.yaml
4
+
5
+ # use "ddp_spawn" instead of "ddp",
6
+ # it's slower but normal "ddp" currently doesn't work ideally with hydra
7
+ # https://github.com/facebookresearch/hydra/issues/2070
8
+ # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn
9
+ strategy: ddp_spawn
10
+
11
+ accelerator: gpu
12
+ devices: 2
13
+ num_nodes: 1
14
+ sync_batchnorm: True
configs_hydra/trainer/default.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: pytorch_lightning.Trainer
2
+
3
+ default_root_dir: ${paths.output_dir}
4
+
5
+ accelerator: gpu
6
+ devices: 1
7
+
8
+ # set True to to ensure deterministic results
9
+ # makes training slower but gives more reproducibility than just setting seeds
10
+ deterministic: False
configs_hydra/trainer/default_amr.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ num_sanity_val_steps: 0
2
+ log_every_n_steps: ${GENERAL.LOG_STEPS}
3
+ val_check_interval: ${GENERAL.VAL_STEPS} # How often within one training epoch to check the validation set.
4
+ check_val_every_n_epoch: ${GENERAL.VAL_EPOCHS} # Check val every n train epochs.
5
+ precision: 16-mixed # 16-mixed, 32
6
+ max_steps: ${GENERAL.TOTAL_STEPS}
7
+ # move_metrics_to_cpu: True
8
+ limit_val_batches: 80 # How much of validation dataset to check.
9
+ # track_grad_norm: -1
configs_hydra/trainer/gpu.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default.yaml
3
+ - default_amr.yaml
4
+
5
+ accelerator: gpu
6
+ devices: 1
configs_hydra/trainer/mps.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default.yaml
3
+ - default_amr.yaml
4
+
5
+ accelerator: mps
6
+ devices: 1
demo.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from pathlib import Path
11
+ import detectron2.config
12
+ import detectron2.engine
13
+ import torch
14
+ import argparse
15
+ import os
16
+ import cv2
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+ import torch.utils
20
+ import torch.utils.data
21
+ from prima.models import load_prima
22
+ from prima.utils import recursive_to
23
+ from prima.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
24
+ from prima.utils.renderer import Renderer, cam_crop_to_full
25
+ import detectron2
26
+ from detectron2 import model_zoo
27
+ import warnings
28
+ warnings.filterwarnings("ignore")
29
+
30
+ LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353)
31
+ GREEN = (0.65, 0.86, 0.74)
32
+
33
+
34
+
35
+ def main():
36
+ parser = argparse.ArgumentParser(description='prima demo code')
37
+ parser.add_argument('--checkpoint', type=str,
38
+ help='Path to pretrained model checkpoint')
39
+ parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images')
40
+ parser.add_argument('--out_folder', type=str, default='demo_out', help='Output folder to save rendered results')
41
+ parser.add_argument('--side_view', dest='side_view', action='store_true', default=False,
42
+ help='If set, render side view also')
43
+ parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False,
44
+ help='If set, save meshes to disk also')
45
+ parser.add_argument('--batch_size', type=int, default=1, help='Batch size for inference/fitting')
46
+ parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'],
47
+ help='List of file extensions to consider')
48
+
49
+ args = parser.parse_args()
50
+
51
+ model, model_cfg = load_prima(args.checkpoint)
52
+
53
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
54
+ model = model.to(device)
55
+ model.eval()
56
+
57
+ # Setup the renderer
58
+ renderer = Renderer(model_cfg, faces=model.smal.faces)
59
+
60
+ # Make output directory if it does not exist
61
+ os.makedirs(args.out_folder, exist_ok=True)
62
+
63
+ # Load detector
64
+ cfg = detectron2.config.get_cfg()
65
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"))
66
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
67
+ cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
68
+ detector = detectron2.engine.DefaultPredictor(cfg)
69
+
70
+ img_paths = sorted([img for end in args.file_type for img in Path(args.img_folder).glob(end)])
71
+ for img_path in img_paths:
72
+ img_bgr = cv2.imread(str(img_path))
73
+ if img_bgr is None:
74
+ print(f"[WARN] Cannot read image: {img_path}")
75
+ continue
76
+ # Detect animals in image
77
+ det_out = detector(img_bgr)
78
+
79
+ det_instances = det_out['instances']
80
+ valid_idx = [i for i, (c, s) in enumerate(zip(det_instances.pred_classes, det_instances.scores)) if ((c in [15, 16, 17, 18, 19, 21, 22]) & (s > 0.7))]
81
+ boxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy()
82
+
83
+ # Run PRIMA on detected animals
84
+ dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
85
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
86
+ for batch in tqdm(dataloader):
87
+ batch = recursive_to(batch, device)
88
+ with torch.no_grad():
89
+ out = model(batch)
90
+
91
+ pred_cam = out['pred_cam']
92
+ box_center = batch["box_center"].float()
93
+ box_size = batch["box_size"].float()
94
+ img_size = batch["img_size"].float()
95
+ scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max()
96
+ pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size,
97
+ scaled_focal_length).detach().cpu().numpy()
98
+
99
+ # Render the result
100
+ batch_size = batch['img'].shape[0]
101
+ for n in range(batch_size):
102
+ # Get filename from path img_path
103
+ img_fn, _ = os.path.splitext(os.path.basename(img_path))
104
+ animal_id = int(batch['animalid'][n])
105
+ white_img = (torch.ones_like(batch['img'][n]).cpu() - DEFAULT_MEAN[:, None, None] / 255) / (
106
+ DEFAULT_STD[:, None, None] / 255)
107
+ input_patch = (batch['img'][n].cpu() * (DEFAULT_STD[:, None, None]) + (
108
+ DEFAULT_MEAN[:, None, None])) / 255.
109
+ input_patch = input_patch.permute(1, 2, 0).numpy()
110
+
111
+ regression_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(),
112
+ out['pred_cam_t'][n].detach().cpu().numpy(),
113
+ batch['img'][n],
114
+ mesh_base_color=GREEN,
115
+ scene_bg_color=(1, 1, 1),
116
+ )
117
+
118
+ final_img = np.concatenate([input_patch, regression_img], axis=1)
119
+
120
+ if args.side_view:
121
+ side_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(),
122
+ out['pred_cam_t'][n].detach().cpu().numpy(),
123
+ white_img,
124
+ mesh_base_color=GREEN,
125
+ scene_bg_color=(1, 1, 1),
126
+ side_view=True)
127
+ final_img = np.concatenate([final_img, side_img], axis=1)
128
+
129
+ cv2.imwrite(os.path.join(args.out_folder, f'{img_fn}_{animal_id}.png'),
130
+ cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR))
131
+
132
+ # Add all verts and cams to list
133
+ verts = out['pred_vertices'][n].detach().cpu().numpy()
134
+ cam_t = pred_cam_t_full[n]
135
+
136
+ # Save all meshes to disk
137
+ if args.save_mesh:
138
+ camera_translation = cam_t.copy()
139
+ tmesh = renderer.vertices_to_trimesh(verts, camera_translation, LIGHT_BLUE)
140
+ tmesh.export(os.path.join(args.out_folder, f'{img_fn}_{animal_id}.obj'))
141
+
142
+
143
+ if __name__ == '__main__':
144
+ main()
demo_data/000000015956_horse.png ADDED

Git LFS Details

  • SHA256: 2a2398ba7df40a47c636afefa28be17b55f4b7bc2c378e053aeea507580ad2cb
  • Pointer size: 131 Bytes
  • Size of remote file: 620 kB
demo_data/000000315905_zebra.jpg ADDED

Git LFS Details

  • SHA256: e0a17e1f1650820b020a9025144015c1e27f0f1ab435859f0bde3a0047d8f689
  • Pointer size: 131 Bytes
  • Size of remote file: 257 kB
demo_data/beagle.jpg ADDED

Git LFS Details

  • SHA256: ac29e6ea6086831dd9806a8cd3fd608e264ac1af567f6fcfc8797c5bd3d5d560
  • Pointer size: 131 Bytes
  • Size of remote file: 350 kB
demo_data/n02101388_1188.png ADDED

Git LFS Details

  • SHA256: e45ff508fb8c6437cce22fcb59b4f1b6fe37ddfab1d4cf68d97629f9caa939f4
  • Pointer size: 131 Bytes
  • Size of remote file: 319 kB
demo_data/n02412080_12159.png ADDED

Git LFS Details

  • SHA256: 03273c57e8b25b258d3eb96af7b4f77b43b5c40be90da83c21875f3322b487f1
  • Pointer size: 131 Bytes
  • Size of remote file: 347 kB
demo_data/shepherd_hati.jpg ADDED

Git LFS Details

  • SHA256: 65c5878203bc3165dda9011ebfce77cc7d930daed0a215396d8036509d1963c1
  • Pointer size: 131 Bytes
  • Size of remote file: 210 kB
demo_tta.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """
11
+ demo_tta.py: PRIMA inference with DeepLabCut SuperAnimal TTA
12
+
13
+ Pipeline:
14
+ 1. Run Detectron2 to detect animals in the input image.
15
+ 2. Run PRIMA on each detected animal to obtain 3D pose/shape estimation.
16
+ 3. Run DeepLabCut SuperAnimal to obtain 2D keypoint estimation.
17
+ 4. Map the 39 SuperAnimal keypoints to the 26 PRIMA keypoints.
18
+ 5. Run test-time adaptation (TTA) with user-specified lr and num_iters
19
+ to further optimize the 3D pose and shape estimation.
20
+ 6. Render and save before/after TTA results (PNG + OBJ) and the
21
+ 26-keypoint visualization (PNG).
22
+
23
+ Reference code:
24
+ - Test-time adaptation: prima/../eval_with_tta.py
25
+ - DeepLabCut: https://github.com/AdaptiveMotorControlLab/FMPose3D/blob/main/animals/demo/vis_animals.py
26
+ - Keypoint mapping (SuperAnimal 39 → PRIMA 26):
27
+ keypoint_mapping = {"quadruped80k":[10, 5, -1, 26, 29, 30, 35, 22, 24, 27, 31, 32, -1, -1,
28
+ 25, 28, 33, 34, 15, 23, 11, 6, 4, 3, 0, -1]}
29
+ """
30
+
31
+
32
+ from pathlib import Path
33
+ import argparse
34
+ import copy
35
+ import os
36
+ import tempfile
37
+ import warnings
38
+
39
+ import cv2
40
+ import numpy as np
41
+ import torch
42
+ import torch.nn.functional as F
43
+ import torch.utils.data
44
+ from tqdm import tqdm
45
+
46
+ from prima.models import load_prima
47
+ from prima.utils import recursive_to
48
+ from prima.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
49
+ from prima.utils.renderer import Renderer, cam_crop_to_full
50
+
51
+ warnings.filterwarnings("ignore")
52
+
53
+ LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353)
54
+ GREEN = (0.65, 0.86, 0.74)
55
+
56
+ ANIMAL_COCO_IDS = [15, 16, 17, 18, 19, 21, 22]
57
+ keypoint_mapping = {
58
+ "quadruped80k": [10, 5, -1, 26, 29, 30, 35, 22, 24, 27, 31, 32, -1, -1, 25, 28, 33, 34, 15, 23, 11, 6, 4, 3, 0, -1]
59
+ }
60
+
61
+
62
+ def denorm_patch_to_rgb(img_tensor: torch.Tensor) -> np.ndarray:
63
+ patch = (img_tensor.detach().cpu() * (DEFAULT_STD[:, None, None]) + DEFAULT_MEAN[:, None, None]) / 255.0
64
+ patch = patch.permute(1, 2, 0).numpy()
65
+ return np.clip(patch, 0.0, 1.0)
66
+
67
+
68
+ def map_superanimal_to_prima(bodyparts_xyc: np.ndarray) -> np.ndarray:
69
+ mapping = keypoint_mapping["quadruped80k"]
70
+ num_src = bodyparts_xyc.shape[0]
71
+ mapped = np.zeros((len(mapping), 3), dtype=np.float32)
72
+
73
+ for tgt_i, src_i in enumerate(mapping):
74
+ if src_i >= 0 and src_i < num_src:
75
+ mapped[tgt_i] = bodyparts_xyc[src_i]
76
+
77
+ return mapped
78
+
79
+
80
+ def save_keypoint_vis(patch_rgb: np.ndarray, kpts_xyc: np.ndarray, save_path: str) -> None:
81
+ vis = cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR).copy()
82
+ num_kpts = len(kpts_xyc)
83
+
84
+ for i, (x, y, c) in enumerate(kpts_xyc):
85
+ if c <= 0:
86
+ continue
87
+
88
+ # Use distinct color for each keypoint (OpenCV uses BGR)
89
+ hue = int(179 * i / max(1, num_kpts - 1))
90
+ color_bgr = cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2BGR)[0, 0]
91
+ color_bgr = (int(color_bgr[0]), int(color_bgr[1]), int(color_bgr[2]))
92
+
93
+ cx, cy = int(round(float(x))), int(round(float(y)))
94
+ cv2.circle(vis, (cx, cy), 3, color_bgr, -1)
95
+ cv2.putText(vis, str(i), (cx + 3, cy - 3), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1, cv2.LINE_AA)
96
+
97
+ cv2.imwrite(save_path, vis)
98
+
99
+
100
+ def run_superanimal_on_patch(patch_rgb: np.ndarray, args, tmp_dir: str):
101
+ try:
102
+ from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images
103
+ except Exception as e:
104
+ raise RuntimeError(
105
+ "Cannot import DeepLabCut SuperAnimal API. Please install deeplabcut with pose_estimation_pytorch support."
106
+ ) from e
107
+
108
+ patch_path = os.path.join(tmp_dir, "patch.png")
109
+ cv2.imwrite(patch_path, cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
110
+
111
+ preds = superanimal_analyze_images(
112
+ args.superanimal_name,
113
+ args.superanimal_model_name,
114
+ args.superanimal_detector_name,
115
+ patch_path,
116
+ args.superanimal_max_individuals,
117
+ out_folder=tmp_dir,
118
+ )
119
+
120
+ payload = preds.get(patch_path, None)
121
+ if payload is None:
122
+ return None
123
+ bodyparts = payload.get("bodyparts", None)
124
+ if bodyparts is None or len(bodyparts) == 0:
125
+ return None
126
+
127
+ best_idx = int(np.argmax(bodyparts[..., 2].mean(axis=1)))
128
+ return bodyparts[best_idx]
129
+
130
+
131
+ def render_and_save(renderer, out, batch, img_fn, animal_id, out_folder, suffix, side_view, save_mesh):
132
+ pred_cam = out['pred_cam']
133
+ box_center = batch['box_center'].float()
134
+ box_size = batch['box_size'].float()
135
+ img_size = batch['img_size'].float()
136
+ scaled_focal_length = batch['focal_length'][0, 0] / batch['img'].shape[-1] * img_size.max()
137
+ pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length)
138
+
139
+ white_img = (torch.ones_like(batch['img'][0]).cpu() - DEFAULT_MEAN[:, None, None] / 255) / (
140
+ DEFAULT_STD[:, None, None] / 255
141
+ )
142
+ input_patch = denorm_patch_to_rgb(batch['img'][0])
143
+
144
+ regression_img = renderer(
145
+ out['pred_vertices'][0].detach().cpu().numpy(),
146
+ out['pred_cam_t'][0].detach().cpu().numpy(),
147
+ batch['img'][0],
148
+ mesh_base_color=GREEN,
149
+ scene_bg_color=(1, 1, 1),
150
+ )
151
+
152
+ final_img = np.concatenate([input_patch, regression_img], axis=1)
153
+ if side_view:
154
+ side_img = renderer(
155
+ out['pred_vertices'][0].detach().cpu().numpy(),
156
+ out['pred_cam_t'][0].detach().cpu().numpy(),
157
+ white_img,
158
+ mesh_base_color=GREEN,
159
+ scene_bg_color=(1, 1, 1),
160
+ side_view=True,
161
+ )
162
+ final_img = np.concatenate([final_img, side_img], axis=1)
163
+
164
+ cv2.imwrite(
165
+ os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.png'),
166
+ cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR),
167
+ )
168
+
169
+ if save_mesh:
170
+ verts = out['pred_vertices'][0].detach().cpu().numpy()
171
+ cam_t = pred_cam_t_full[0].detach().cpu().numpy()
172
+ tmesh = renderer.vertices_to_trimesh(verts, cam_t.copy(), LIGHT_BLUE)
173
+ tmesh.export(os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.obj'))
174
+
175
+
176
+ def tta_optimize(model, batch, gt_kpts_norm, num_iters, lr):
177
+ model.eval()
178
+
179
+ if hasattr(model, 'backbone'):
180
+ for p in model.backbone.parameters():
181
+ p.requires_grad = False
182
+
183
+ orig_smal_head_state = copy.deepcopy(model.smal_head.state_dict())
184
+ model.smal_head.freeze_except_regression_heads()
185
+ tta_params = model.smal_head.get_tta_parameters(mode='all')
186
+ optimizer = torch.optim.Adam(tta_params, lr=lr)
187
+
188
+ valid_mask = (gt_kpts_norm[..., 2] > 0).float().unsqueeze(-1)
189
+ gt_xy = gt_kpts_norm[..., :2]
190
+
191
+ for _ in range(num_iters):
192
+ optimizer.zero_grad()
193
+ out = model(batch)
194
+ pred_xy = out['pred_keypoints_2d']
195
+ loss = F.mse_loss(pred_xy * valid_mask, gt_xy * valid_mask, reduction='sum') / (valid_mask.sum() + 1e-6)
196
+ loss.backward()
197
+ optimizer.step()
198
+
199
+ with torch.no_grad():
200
+ out_after = model(batch)
201
+
202
+ model.smal_head.load_state_dict(orig_smal_head_state)
203
+ model.smal_head.unfreeze_all()
204
+
205
+ return out_after
206
+
207
+
208
+ def main():
209
+ parser = argparse.ArgumentParser(description='PRIMA + SuperAnimal + TTA demo')
210
+ parser.add_argument('--checkpoint', type=str, required=True, help='Path to pretrained PRIMA checkpoint')
211
+ parser.add_argument('--img_path', type=str, default=None, help='Single image path')
212
+ parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images')
213
+ parser.add_argument('--out_folder', type=str, default='demo_out_tta', help='Output folder')
214
+ parser.add_argument('--side_view', dest='side_view', action='store_true', default=False, help='Render side view')
215
+ parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='Save meshes')
216
+ parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'], help='Image globs')
217
+ parser.add_argument('--det_thresh', type=float, default=0.7, help='Detectron2 score threshold for animals')
218
+
219
+ parser.add_argument('--tta_lr', type=float, default=1e-6, help='TTA learning rate')
220
+ parser.add_argument('--tta_num_iters', type=int, default=30, help='TTA iterations')
221
+ parser.add_argument('--kp_conf_thresh', type=float, default=0.1, help='Keypoint confidence threshold')
222
+
223
+ parser.add_argument('--superanimal_name', type=str, default='superanimal_quadruped')
224
+ parser.add_argument('--superanimal_model_name', type=str, default='hrnet_w32')
225
+ parser.add_argument('--superanimal_detector_name', type=str, default='fasterrcnn_resnet50_fpn_v2')
226
+ parser.add_argument('--superanimal_max_individuals', type=int, default=1)
227
+
228
+ args = parser.parse_args()
229
+
230
+ model, model_cfg = load_prima(args.checkpoint)
231
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
232
+ model = model.to(device)
233
+ model.eval()
234
+
235
+ renderer = Renderer(model_cfg, faces=model.smal.faces)
236
+ os.makedirs(args.out_folder, exist_ok=True)
237
+
238
+ import detectron2.config
239
+ import detectron2.engine
240
+ from detectron2 import model_zoo
241
+
242
+ cfg = detectron2.config.get_cfg()
243
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"))
244
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
245
+ cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
246
+ detector = detectron2.engine.DefaultPredictor(cfg)
247
+
248
+ if args.img_path is not None:
249
+ img_paths = [Path(args.img_path)]
250
+ else:
251
+ img_paths = sorted([img for end in args.file_type for img in Path(args.img_folder).glob(end)])
252
+
253
+ for img_path in img_paths:
254
+ img_bgr = cv2.imread(str(img_path))
255
+ if img_bgr is None:
256
+ print(f"[WARN] Cannot read image: {img_path}")
257
+ continue
258
+ det_out = detector(img_bgr)
259
+ det_instances = det_out['instances']
260
+ valid_idx = [
261
+ i for i, (c, s) in enumerate(zip(det_instances.pred_classes, det_instances.scores))
262
+ if (int(c) in ANIMAL_COCO_IDS) and (float(s) > args.det_thresh)
263
+ ]
264
+ boxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy()
265
+
266
+ if len(boxes) == 0:
267
+ print(f"[INFO] No animal detected in {img_path}")
268
+ continue
269
+
270
+ dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
271
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
272
+
273
+ for batch in tqdm(dataloader, desc=f"{img_path.name}"):
274
+ batch = recursive_to(batch, device)
275
+ with torch.no_grad():
276
+ out_before = model(batch)
277
+
278
+ img_fn = img_path.stem
279
+ animal_id = int(batch['animalid'][0])
280
+
281
+ render_and_save(
282
+ renderer,
283
+ out_before,
284
+ batch,
285
+ img_fn,
286
+ animal_id,
287
+ args.out_folder,
288
+ suffix='before_tta',
289
+ side_view=args.side_view,
290
+ save_mesh=args.save_mesh,
291
+ )
292
+
293
+ patch_rgb = denorm_patch_to_rgb(batch['img'][0])
294
+ with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
295
+ bodyparts_xyc = run_superanimal_on_patch(patch_rgb, args, tmp_dir)
296
+
297
+ if bodyparts_xyc is None:
298
+ print(f"[WARN] No SuperAnimal keypoints for {img_fn}_{animal_id}, skip TTA")
299
+ continue
300
+
301
+ mapped_xyc = map_superanimal_to_prima(bodyparts_xyc)
302
+ mapped_xyc[mapped_xyc[:, 2] < args.kp_conf_thresh, 2] = 0.0
303
+
304
+ save_keypoint_vis(
305
+ patch_rgb,
306
+ mapped_xyc,
307
+ os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png"),
308
+ )
309
+ np.save(os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy"), mapped_xyc)
310
+
311
+ patch_h, patch_w = patch_rgb.shape[:2]
312
+ mapped_norm = mapped_xyc.copy()
313
+ mapped_norm[:, 0] = mapped_norm[:, 0] / float(patch_w) - 0.5
314
+ mapped_norm[:, 1] = mapped_norm[:, 1] / float(patch_h) - 0.5
315
+ gt_kpts_norm = torch.from_numpy(mapped_norm[None]).to(device=device, dtype=batch['img'].dtype)
316
+
317
+ out_after = tta_optimize(
318
+ model,
319
+ batch,
320
+ gt_kpts_norm,
321
+ num_iters=args.tta_num_iters,
322
+ lr=args.tta_lr,
323
+ )
324
+
325
+ render_and_save(
326
+ renderer,
327
+ out_after,
328
+ batch,
329
+ img_fn,
330
+ animal_id,
331
+ args.out_folder,
332
+ suffix='after_tta',
333
+ side_view=args.side_view,
334
+ save_mesh=args.save_mesh,
335
+ )
336
+
337
+
338
+ if __name__ == '__main__':
339
+ main()
340
+
eval.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+ import torch
13
+ from prima.utils import recursive_to
14
+ from prima.utils.evaluate_metric import Evaluator
15
+ from prima.datasets.datasets import EvaluationDataset
16
+ import argparse
17
+ from torch.utils.data import DataLoader
18
+ from prima.models.prima import PRIMA
19
+ from prima.configs import get_config
20
+ torch.multiprocessing.set_sharing_strategy('file_system')
21
+
22
+
23
+ def main(args):
24
+ cfg = get_config(args.config)
25
+ default_cfg = get_config(args.default_eval_config)
26
+ model = PRIMA.load_from_checkpoint(args.checkpoint, cfg=cfg, strict=False)
27
+ model.eval()
28
+
29
+ smal_evaluator = Evaluator(smal_model=model.smal, image_size=cfg.MODEL.IMAGE_SIZE)
30
+ cfg_eval_dataset = dict(default_cfg.DATASETS)
31
+ aug_cfg = cfg_eval_dataset.pop("CONFIG", None) # augmentation config is not used in evaluation
32
+
33
+ if args.dataset.upper() == "ALL":
34
+ for key in cfg_eval_dataset.keys():
35
+ print(f"-------- Evaluate {key} dataset ------------")
36
+ eval_one_dataset(cfg_eval_dataset[key], default_cfg, cfg, model,
37
+ evaluator=smal_evaluator,
38
+ aug_cfg=aug_cfg,
39
+ key=key,
40
+ device=args.device)
41
+ print(f"-------{key} Dataset evaluate finish ------")
42
+ else:
43
+ print(f"-------- Evaluate {args.dataset} dataset ------------")
44
+ eval_one_dataset(cfg_eval_dataset[args.dataset], default_cfg, cfg, model,
45
+ evaluator=smal_evaluator,
46
+ aug_cfg=aug_cfg,
47
+ key=args.dataset,
48
+ device=args.device)
49
+ print(f"-------{args.dataset} Dataset evaluate finish ------")
50
+
51
+
52
+ def eval_one_dataset(dataset_cfg, default_cfg, cfg, model, evaluator, aug_cfg, key, device='cuda'):
53
+ dataset = EvaluationDataset(root_image=dataset_cfg['ROOT_IMAGE'],
54
+ json_file=dataset_cfg['JSON_FILE']['TEST'],
55
+ augm_config=aug_cfg, focal_length=cfg.SMAL.get("FOCAL_LENGTH", 1000),
56
+ image_size=cfg.MODEL.IMAGE_SIZE,
57
+ )
58
+ dataloader = DataLoader(dataset, batch_size=1, num_workers=cfg.GENERAL.NUM_WORKERS)
59
+
60
+ bar = tqdm(dataloader)
61
+ pa_mpjpe_list, pck_list, auc_list, pa_mpvpe_list = [], [], [], []
62
+ for i, batch in enumerate(bar):
63
+ batch = recursive_to(batch, device)
64
+ with torch.no_grad():
65
+ output = model(batch)
66
+
67
+ if key in ["ANIMAL3D", "CONTROL_ANIMAL3D"]:
68
+ pa_mpjpe, pa_mpvpe = evaluator.eval_3d(output, batch)
69
+ else:
70
+ pa_mpjpe, pa_mpvpe = 0., 0.
71
+ pck, auc = evaluator.eval_2d(output, batch, pck_threshold=default_cfg.METRIC.PCK_THRESHOLD)
72
+
73
+ pa_mpjpe_list.append(pa_mpjpe)
74
+ pa_mpvpe_list.append(pa_mpvpe)
75
+ auc_list.append(auc)
76
+ pck_list.append(pck)
77
+
78
+ bar.set_postfix(PA_MPJPE=pa_mpjpe,
79
+ PA_MPVPE=pa_mpvpe,
80
+ AUC=auc,
81
+ pck=pck,)
82
+
83
+ print("---------------- 3D metric -----------------")
84
+ print(f"Avg PA-MPJPE: {np.mean(pa_mpjpe_list)}")
85
+ print(f"Avg PA-MPVPE: {np.mean(pa_mpvpe_list)}")
86
+
87
+ print("--------------- 2D metric ------------------")
88
+ print(f"AUC: {np.mean(auc_list)}")
89
+ pck_list = np.array(pck_list)
90
+ for _, th in enumerate(default_cfg.METRIC.PCK_THRESHOLD):
91
+ print(f"PCK@{th}: {np.mean(pck_list[:, _])}")
92
+
93
+
94
+ if __name__ == "__main__":
95
+ parser = argparse.ArgumentParser()
96
+ parser.add_argument("--config", type=str, help="Path to config file", required=True)
97
+ parser.add_argument("--checkpoint", type=str, help="Path to checkpoint file", required=True)
98
+ parser.add_argument("--default_eval_config", type=str, default="./configs_hydra/experiment/default_val.yaml")
99
+ parser.add_argument("--dataset", type=str, default="ALL")
100
+ parser.add_argument("--device", type=str, default="cuda", help="Device to use for evaluation")
101
+ args = parser.parse_args()
102
+ main(args)
images/teaser.png ADDED

Git LFS Details

  • SHA256: a617ca4fd37de03e2db4ccf397ce9841ed32c3fe18c766c4832d41af574ad746
  • Pointer size: 132 Bytes
  • Size of remote file: 4.29 MB
packages.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ libosmesa6
2
+ libgl1
3
+ libegl1
4
+ libgles2
prima/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """Top-level package for PRIMA.
11
+
12
+ This package contains models, datasets and utilities for
13
+ 3D animal pose and shape estimation.
14
+ """
15
+
16
+ from importlib.metadata import PackageNotFoundError, version
17
+
18
+
19
+ try: # pragma: no cover - best effort during development
20
+ __version__ = version("prima-animal")
21
+ except PackageNotFoundError: # pragma: no cover
22
+ __version__ = "0.0.0"
23
+
24
+
25
+ __all__ = ["__version__"]
prima/configs/__init__.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from typing import Dict
11
+ from yacs.config import CfgNode as CN
12
+
13
+ def to_lower(x: Dict) -> Dict:
14
+ """
15
+ Convert all dictionary keys to lowercase
16
+ Args:
17
+ x (dict): Input dictionary
18
+ Returns:
19
+ dict: Output dictionary with all keys converted to lowercase
20
+ """
21
+ return {k.lower(): v for k, v in x.items()}
22
+
23
+
24
+ _C = CN(new_allowed=True)
25
+
26
+ _C.GENERAL = CN(new_allowed=True)
27
+ _C.GENERAL.RESUME = True
28
+ _C.GENERAL.TIME_TO_RUN = 3300
29
+ _C.GENERAL.VAL_STEPS = 100
30
+ _C.GENERAL.LOG_STEPS = 100
31
+ _C.GENERAL.CHECKPOINT_STEPS = 20000
32
+ _C.GENERAL.CHECKPOINT_DIR = "checkpoints"
33
+ _C.GENERAL.SUMMARY_DIR = "tensorboard"
34
+ _C.GENERAL.NUM_GPUS = 1
35
+ _C.GENERAL.NUM_WORKERS = 4
36
+ _C.GENERAL.MIXED_PRECISION = True
37
+ _C.GENERAL.ALLOW_CUDA = True
38
+ _C.GENERAL.PIN_MEMORY = False
39
+ _C.GENERAL.DISTRIBUTED = False
40
+ _C.GENERAL.LOCAL_RANK = 0
41
+ _C.GENERAL.USE_SYNCBN = False
42
+ _C.GENERAL.WORLD_SIZE = 1
43
+ _C.GENERAL.PREFETCH_FACTOR = 2
44
+
45
+ _C.TRAIN = CN(new_allowed=True)
46
+ _C.TRAIN.NUM_EPOCHS = 100
47
+ _C.TRAIN.SHUFFLE = True
48
+ _C.TRAIN.WARMUP = False
49
+ _C.TRAIN.NORMALIZE_PER_IMAGE = False
50
+ _C.TRAIN.CLIP_GRAD = False
51
+ _C.TRAIN.CLIP_GRAD_VALUE = 1.0
52
+ _C.LOSS_WEIGHTS = CN(new_allowed=True)
53
+
54
+ _C.DATASETS = CN(new_allowed=True)
55
+
56
+ _C.MODEL = CN(new_allowed=True)
57
+ _C.MODEL.IMAGE_SIZE = 224
58
+
59
+ _C.EXTRA = CN(new_allowed=True)
60
+ _C.EXTRA.FOCAL_LENGTH = 5000
61
+
62
+ _C.DATASETS.CONFIG = CN(new_allowed=True)
63
+ _C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
64
+ _C.DATASETS.CONFIG.ROT_FACTOR = 30
65
+ _C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
66
+ _C.DATASETS.CONFIG.COLOR_SCALE = 0.2
67
+ _C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
68
+ _C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
69
+ _C.DATASETS.CONFIG.DO_FLIP = False
70
+ _C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
71
+ _C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
72
+
73
+
74
+ def default_config() -> CN:
75
+ """
76
+ Get a yacs CfgNode object with the default config values.
77
+ """
78
+ # Return a clone so that the defaults will not be altered
79
+ # This is for the "local variable" use pattern
80
+ return _C.clone()
81
+
82
+
83
+ def get_config(config_file: str, merge: bool = True) -> CN:
84
+ """
85
+ Read a config file and optionally merge it with the default config file.
86
+ Args:
87
+ config_file (str): Path to config file.
88
+ merge (bool): Whether to merge with the default config or not.
89
+ Returns:
90
+ CfgNode: Config as a yacs CfgNode object.
91
+ """
92
+ if merge:
93
+ cfg = default_config()
94
+ else:
95
+ cfg = CN(new_allowed=True)
96
+ cfg.merge_from_file(config_file)
97
+
98
+ cfg.freeze()
99
+ return cfg
prima/models/__init__.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from .prima import PRIMA
11
+
12
+
13
+ def load_prima(checkpoint_path):
14
+ from pathlib import Path
15
+ from ..configs import get_config
16
+ model_cfg = str(Path(checkpoint_path).parent.parent / '.hydra/config.yaml')
17
+ model_cfg = get_config(model_cfg)
18
+
19
+ # Override some config values, to crop bbox correctly
20
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'vit') and ('BBOX_SHAPE' not in model_cfg.MODEL):
21
+ model_cfg.defrost()
22
+ assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
23
+ model_cfg.MODEL.BBOX_SHAPE = [192, 256]
24
+ model_cfg.freeze()
25
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'dinov3') and ('BBOX_SHAPE' not in model_cfg.MODEL):
26
+ model_cfg.defrost()
27
+ assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for dino backbone"
28
+ model_cfg.MODEL.BBOX_SHAPE = [256, 256]
29
+ model_cfg.freeze()
30
+
31
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'dinov2') and ('BBOX_SHAPE' not in model_cfg.MODEL):
32
+ model_cfg.defrost()
33
+ assert model_cfg.MODEL.IMAGE_SIZE == 252, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 252 for dino backbone"
34
+ model_cfg.MODEL.BBOX_SHAPE = [252, 252]
35
+ model_cfg.freeze()
36
+
37
+
38
+
39
+ # Update config to be compatible with demo
40
+ if ('PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE):
41
+ model_cfg.defrost()
42
+ model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS')
43
+ model_cfg.freeze()
44
+
45
+ # Offscreen training renderer is not needed for demo/inference startup and
46
+ # can fail on some local OpenGL backends.
47
+ model = PRIMA.load_from_checkpoint(
48
+ checkpoint_path,
49
+ strict=False,
50
+ cfg=model_cfg,
51
+ map_location='cpu',
52
+ init_renderer=False,
53
+ )
54
+ return model, model_cfg
prima/models/backbones/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from .vit import vith
11
+
12
+
13
+
14
+
15
+ def create_backbone(cfg):
16
+ if cfg.MODEL.BACKBONE.TYPE in ['vith','concat','aa']: # vit bb will be used in these three cases - animal feature extractor
17
+ return vith(cfg)
18
+ else:
19
+ raise NotImplementedError('Backbone type is not implemented')
prima/models/backbones/vit.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ # Copyright (c) OpenMMLab. All rights reserved.
11
+ import math
12
+
13
+ import torch
14
+ from functools import partial
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import torch.utils.checkpoint as checkpoint
18
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
19
+
20
+
21
+ def vith(cfg):
22
+ return ViT(
23
+ img_size=(256, 192),
24
+ patch_size=16,
25
+ embed_dim=1280,
26
+ depth=32,
27
+ num_heads=16,
28
+ ratio=1,
29
+ use_checkpoint=False,
30
+ # use_checkpoint=True,
31
+ mlp_ratio=4,
32
+ qkv_bias=True,
33
+ drop_path_rate=0.55,
34
+ use_cls=True, # cls for animal family classification
35
+ )
36
+
37
+
38
+ def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
39
+ """
40
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
41
+ dimension for the original embeddings.
42
+ Args:
43
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
44
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
45
+ hw (Tuple): size of input image tokens.
46
+
47
+ Returns:
48
+ Absolute positional embeddings after processing with shape (1, H, W, C)
49
+ """
50
+ cls_token = None
51
+ B, L, C = abs_pos.shape
52
+ if has_cls_token:
53
+ cls_token = abs_pos[:, 0:1]
54
+ abs_pos = abs_pos[:, 1:]
55
+
56
+ if ori_h != h or ori_w != w:
57
+ new_abs_pos = F.interpolate(
58
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
59
+ size=(h, w),
60
+ mode="bicubic",
61
+ align_corners=False,
62
+ ).permute(0, 2, 3, 1).reshape(B, -1, C)
63
+
64
+ else:
65
+ new_abs_pos = abs_pos
66
+
67
+ if cls_token is not None:
68
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
69
+ return new_abs_pos
70
+
71
+
72
+ class DropPath(nn.Module):
73
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
74
+ """
75
+
76
+ def __init__(self, drop_prob=None):
77
+ super(DropPath, self).__init__()
78
+ self.drop_prob = drop_prob
79
+
80
+ def forward(self, x):
81
+ return drop_path(x, self.drop_prob, self.training)
82
+
83
+ def extra_repr(self):
84
+ return 'p={}'.format(self.drop_prob)
85
+
86
+
87
+ class Mlp(nn.Module):
88
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
89
+ super().__init__()
90
+ out_features = out_features or in_features
91
+ hidden_features = hidden_features or in_features
92
+ self.fc1 = nn.Linear(in_features, hidden_features)
93
+ self.act = act_layer()
94
+ self.fc2 = nn.Linear(hidden_features, out_features)
95
+ self.drop = nn.Dropout(drop)
96
+
97
+ def forward(self, x):
98
+ x = self.fc1(x)
99
+ x = self.act(x)
100
+ x = self.fc2(x)
101
+ x = self.drop(x)
102
+ return x
103
+
104
+
105
+ class Attention(nn.Module):
106
+ def __init__(
107
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
108
+ proj_drop=0., attn_head_dim=None):
109
+ super().__init__()
110
+ self.num_heads = num_heads
111
+ head_dim = dim // num_heads
112
+ self.dim = dim
113
+
114
+ if attn_head_dim is not None:
115
+ head_dim = attn_head_dim
116
+ all_head_dim = head_dim * self.num_heads
117
+
118
+ self.scale = qk_scale or head_dim ** -0.5
119
+
120
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
121
+
122
+ self.attn_drop = nn.Dropout(attn_drop)
123
+ self.proj = nn.Linear(all_head_dim, dim)
124
+ self.proj_drop = nn.Dropout(proj_drop)
125
+
126
+ def forward(self, x):
127
+ B, N, C = x.shape
128
+ qkv = self.qkv(x)
129
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
130
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
131
+
132
+ q = q * self.scale
133
+ attn = (q @ k.transpose(-2, -1))
134
+ attn = attn.softmax(dim=-1)
135
+ attn = self.attn_drop(attn)
136
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
137
+
138
+ x = self.proj(x)
139
+ x = self.proj_drop(x)
140
+
141
+ return x
142
+
143
+
144
+ class Block(nn.Module):
145
+
146
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
147
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
148
+ norm_layer=nn.LayerNorm, attn_head_dim=None,
149
+ ):
150
+ super().__init__()
151
+
152
+ self.norm1 = norm_layer(dim)
153
+ self.attn = Attention(
154
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
155
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
156
+ )
157
+
158
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
159
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
160
+ self.norm2 = norm_layer(dim)
161
+ mlp_hidden_dim = int(dim * mlp_ratio)
162
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
163
+
164
+ def forward(self, x):
165
+ x = x + self.drop_path(self.attn(self.norm1(x)))
166
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
167
+ return x
168
+
169
+
170
+ class PatchEmbed(nn.Module):
171
+ """ Image to Patch Embedding
172
+ """
173
+
174
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
175
+ super().__init__()
176
+ img_size = to_2tuple(img_size)
177
+ patch_size = to_2tuple(patch_size)
178
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
179
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
180
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
181
+ self.img_size = img_size
182
+ self.patch_size = patch_size
183
+ self.num_patches = num_patches
184
+
185
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio),
186
+ padding=4 + 2 * (ratio // 2 - 1))
187
+
188
+ def forward(self, x, **kwargs):
189
+ B, C, H, W = x.shape
190
+ x = self.proj(x)
191
+ Hp, Wp = x.shape[2], x.shape[3]
192
+
193
+ x = x.flatten(2).transpose(1, 2)
194
+ return x, (Hp, Wp)
195
+
196
+
197
+ class HybridEmbed(nn.Module):
198
+ """ CNN Feature Map Embedding
199
+ Extract feature map from CNN, flatten, project to embedding dim.
200
+ """
201
+
202
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
203
+ super().__init__()
204
+ assert isinstance(backbone, nn.Module)
205
+ img_size = to_2tuple(img_size)
206
+ self.img_size = img_size
207
+ self.backbone = backbone
208
+ if feature_size is None:
209
+ with torch.no_grad():
210
+ training = backbone.training
211
+ if training:
212
+ backbone.eval()
213
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
214
+ feature_size = o.shape[-2:]
215
+ feature_dim = o.shape[1]
216
+ backbone.train(training)
217
+ else:
218
+ feature_size = to_2tuple(feature_size)
219
+ feature_dim = self.backbone.feature_info.channels()[-1]
220
+ self.num_patches = feature_size[0] * feature_size[1]
221
+ self.proj = nn.Linear(feature_dim, embed_dim)
222
+
223
+ def forward(self, x):
224
+ x = self.backbone(x)[-1]
225
+ x = x.flatten(2).transpose(1, 2)
226
+ x = self.proj(x)
227
+ return x
228
+
229
+
230
+ class ViT(nn.Module):
231
+
232
+ def __init__(self,
233
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
234
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
235
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
236
+ frozen_stages=-1, ratio=1, last_norm=True, use_cls=False,
237
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
238
+ ):
239
+ # Protect mutable default arguments
240
+ super(ViT, self).__init__()
241
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
242
+ self.num_classes = num_classes
243
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
244
+ self.frozen_stages = frozen_stages
245
+ self.use_checkpoint = use_checkpoint
246
+ self.patch_padding = patch_padding
247
+ self.freeze_attn = freeze_attn
248
+ self.freeze_ffn = freeze_ffn
249
+ self.depth = depth
250
+
251
+ if hybrid_backbone is not None:
252
+ self.patch_embed = HybridEmbed(
253
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
254
+ else:
255
+ self.patch_embed = PatchEmbed(
256
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
257
+ num_patches = self.patch_embed.num_patches
258
+
259
+ # since the pretraining model has class token
260
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
261
+
262
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
263
+
264
+ self.blocks = nn.ModuleList([
265
+ Block(
266
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
267
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
268
+ )
269
+ for i in range(depth)])
270
+
271
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
272
+
273
+ if self.pos_embed is not None:
274
+ trunc_normal_(self.pos_embed, std=.02)
275
+
276
+ self.use_cls = use_cls
277
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
278
+ nn.init.normal_(self.cls_token, std=1e-6)
279
+
280
+ self._freeze_stages()
281
+
282
+ def _freeze_stages(self):
283
+ """Freeze parameters."""
284
+ if self.frozen_stages >= 0:
285
+ self.patch_embed.eval()
286
+ for param in self.patch_embed.parameters():
287
+ param.requires_grad = False
288
+
289
+ for i in range(1, self.frozen_stages + 1):
290
+ m = self.blocks[i]
291
+ m.eval()
292
+ for param in m.parameters():
293
+ param.requires_grad = False
294
+
295
+ if self.freeze_attn:
296
+ for i in range(0, self.depth):
297
+ m = self.blocks[i]
298
+ m.attn.eval()
299
+ m.norm1.eval()
300
+ for param in m.attn.parameters():
301
+ param.requires_grad = False
302
+ for param in m.norm1.parameters():
303
+ param.requires_grad = False
304
+
305
+ if self.freeze_ffn:
306
+ self.pos_embed.requires_grad = False
307
+ self.patch_embed.eval()
308
+ for param in self.patch_embed.parameters():
309
+ param.requires_grad = False
310
+ for i in range(0, self.depth):
311
+ m = self.blocks[i]
312
+ m.mlp.eval()
313
+ m.norm2.eval()
314
+ for param in m.mlp.parameters():
315
+ param.requires_grad = False
316
+ for param in m.norm2.parameters():
317
+ param.requires_grad = False
318
+
319
+ def init_weights(self):
320
+ """Initialize the weights in backbone.
321
+ Args:
322
+ pretrained (str, optional): Path to pre-trained weights.
323
+ Defaults to None.
324
+ """
325
+
326
+ def _init_weights(m):
327
+ if isinstance(m, nn.Linear):
328
+ trunc_normal_(m.weight, std=.02)
329
+ if isinstance(m, nn.Linear) and m.bias is not None:
330
+ nn.init.constant_(m.bias, 0)
331
+ elif isinstance(m, nn.LayerNorm):
332
+ nn.init.constant_(m.bias, 0)
333
+ nn.init.constant_(m.weight, 1.0)
334
+
335
+ self.apply(_init_weights)
336
+
337
+ def get_num_layers(self):
338
+ return len(self.blocks)
339
+
340
+ @torch.jit.ignore
341
+ def no_weight_decay(self):
342
+ return {'pos_embed', 'cls_token'}
343
+
344
+ def forward_features(self, x):
345
+ B, C, H, W = x.shape
346
+ x, (Hp, Wp) = self.patch_embed(x)
347
+
348
+ if self.pos_embed is not None:
349
+ # fit for multiple GPU training
350
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
351
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
352
+
353
+ x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) if self.use_cls else x
354
+ for blk in self.blocks:
355
+ if self.use_checkpoint:
356
+ x = checkpoint.checkpoint(blk, x)
357
+ else:
358
+ x = blk(x)
359
+
360
+ x = self.last_norm(x)
361
+
362
+ cls = x[:, 0] if self.use_cls else None
363
+ x = x[:, 1:] if self.use_cls else x
364
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
365
+
366
+ return xp, cls # shape [B, D, Hp, Wp], [B, D]
367
+
368
+ def forward(self, x):
369
+ x, cls = self.forward_features(x)
370
+ return x, cls
371
+
372
+ def train(self, mode=True):
373
+ """Convert the model into training mode."""
374
+ super().train(mode)
375
+ self._freeze_stages()
prima/models/bioclip_embedding.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """
11
+ bioclip Embedding Module
12
+ Converts image batch to embeddings that can be concatenated with image features
13
+ """
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ class BioClipEmbedding(nn.Module):
19
+ """
20
+ Embeds images into a feature space using BioClip model that can be combined with image features.
21
+
22
+ Args:
23
+ embed_dim: Output embedding dimension, should match the dimension of image features for concatenation
24
+ """
25
+
26
+ def __init__(self, cfg, embed_dim: int = 1024):
27
+ super().__init__()
28
+
29
+ self.embed_dim = embed_dim
30
+
31
+ import open_clip
32
+
33
+ if cfg.MODEL.BIOCLIP_EMBEDDING.TYPE == 'bioclip2':
34
+ print("[BioClipEmbedding] Using BioClip2 model from Hugging Face Hub")
35
+ self.species_model, _,_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
36
+ else:
37
+ self.species_model, _,_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
38
+ # tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip')
39
+
40
+
41
+ self.species_model.eval()
42
+
43
+ # Get the output dimension from the model
44
+ bioclip_output_dim = self.species_model.visual.output_dim
45
+
46
+ # Project to target dimension
47
+ self.projection = nn.Sequential(
48
+ nn.Linear(bioclip_output_dim, embed_dim),
49
+ nn.LayerNorm(embed_dim),
50
+ )
51
+
52
+ def forward(self, images: torch.Tensor) -> torch.Tensor:
53
+ """
54
+ Args:
55
+ images: Tensor of shape (B, C, H, W) representing a batch of images
56
+ Returns:
57
+ Tensor of shape (B, embed_dim) representing the embedded features
58
+ """
59
+ # BioClip expects 224x224 input, resize if needed
60
+ if images.shape[-2:] != (224, 224):
61
+ images_resized = F.interpolate(images, size=(224, 224), mode='bilinear', align_corners=False)
62
+ else:
63
+ images_resized = images
64
+
65
+ with torch.no_grad():
66
+ image_features = self.species_model.encode_image(images_resized)
67
+
68
+ projected_features = self.projection(image_features)
69
+
70
+ return projected_features
prima/models/components/__init__.py ADDED
File without changes
prima/models/components/model_utils.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
11
+ # All rights reserved.
12
+
13
+ # This source code is licensed under the license found in the
14
+ # LICENSE file in the root directory of this source tree.
15
+
16
+
17
+ import copy
18
+ from typing import Tuple
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+
26
+ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
27
+ """
28
+ Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
29
+ that are temporally closest to the current frame at `frame_idx`. Here, we take
30
+ - a) the closest conditioning frame before `frame_idx` (if any);
31
+ - b) the closest conditioning frame after `frame_idx` (if any);
32
+ - c) any other temporally closest conditioning frames until reaching a total
33
+ of `max_cond_frame_num` conditioning frames.
34
+
35
+ Outputs:
36
+ - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
37
+ - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
38
+ """
39
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
40
+ selected_outputs = cond_frame_outputs
41
+ unselected_outputs = {}
42
+ else:
43
+ assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
44
+ selected_outputs = {}
45
+
46
+ # the closest conditioning frame before `frame_idx` (if any)
47
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
48
+ if idx_before is not None:
49
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
50
+
51
+ # the closest conditioning frame after `frame_idx` (if any)
52
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
53
+ if idx_after is not None:
54
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
55
+
56
+ # add other temporally closest conditioning frames until reaching a total
57
+ # of `max_cond_frame_num` conditioning frames.
58
+ num_remain = max_cond_frame_num - len(selected_outputs)
59
+ inds_remain = sorted(
60
+ (t for t in cond_frame_outputs if t not in selected_outputs),
61
+ key=lambda x: abs(x - frame_idx),
62
+ )[:num_remain]
63
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
64
+ unselected_outputs = {
65
+ t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
66
+ }
67
+
68
+ return selected_outputs, unselected_outputs
69
+
70
+
71
+ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
72
+ """
73
+ Get 1D sine positional embedding as in the original Transformer paper.
74
+ """
75
+ pe_dim = dim // 2
76
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
77
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
78
+
79
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
80
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
81
+ return pos_embed
82
+
83
+
84
+ def get_activation_fn(activation):
85
+ """Return an activation function given a string"""
86
+ if activation == "relu":
87
+ return F.relu
88
+ if activation == "gelu":
89
+ return F.gelu
90
+ if activation == "glu":
91
+ return F.glu
92
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
93
+
94
+
95
+ def get_clones(module, N):
96
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
97
+
98
+
99
+ class DropPath(nn.Module):
100
+ # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
101
+ def __init__(self, drop_prob=0.0, scale_by_keep=True):
102
+ super(DropPath, self).__init__()
103
+ self.drop_prob = drop_prob
104
+ self.scale_by_keep = scale_by_keep
105
+
106
+ def forward(self, x):
107
+ if self.drop_prob == 0.0 or not self.training:
108
+ return x
109
+ keep_prob = 1 - self.drop_prob
110
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
111
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
112
+ if keep_prob > 0.0 and self.scale_by_keep:
113
+ random_tensor.div_(keep_prob)
114
+ return x * random_tensor
115
+
116
+
117
+ # Lightly adapted from
118
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
119
+ class MLP(nn.Module):
120
+ def __init__(
121
+ self,
122
+ input_dim: int,
123
+ hidden_dim: int,
124
+ output_dim: int,
125
+ num_layers: int,
126
+ activation: nn.Module = nn.ReLU,
127
+ sigmoid_output: bool = False,
128
+ ) -> None:
129
+ super().__init__()
130
+ self.num_layers = num_layers
131
+ h = [hidden_dim] * (num_layers - 1)
132
+ self.layers = nn.ModuleList(
133
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
134
+ )
135
+ self.sigmoid_output = sigmoid_output
136
+ self.act = activation()
137
+
138
+ def forward(self, x):
139
+ for i, layer in enumerate(self.layers):
140
+ x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
141
+ if self.sigmoid_output:
142
+ x = F.sigmoid(x)
143
+ return x
144
+
145
+
146
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
147
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
148
+ class LayerNorm2d(nn.Module):
149
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
150
+ super().__init__()
151
+ self.weight = nn.Parameter(torch.ones(num_channels))
152
+ self.bias = nn.Parameter(torch.zeros(num_channels))
153
+ self.eps = eps
154
+
155
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
156
+ u = x.mean(1, keepdim=True)
157
+ s = (x - u).pow(2).mean(1, keepdim=True)
158
+ x = (x - u) / torch.sqrt(s + self.eps)
159
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
160
+ return x
prima/models/components/pose_transformer.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from inspect import isfunction
11
+ from typing import Callable, Optional
12
+
13
+ import torch
14
+ from einops import rearrange
15
+ from einops.layers.torch import Rearrange
16
+ from torch import nn
17
+
18
+ from .t_cond_mlp import (
19
+ AdaptiveLayerNorm1D,
20
+ FrequencyEmbedder,
21
+ normalization_layer,
22
+ )
23
+
24
+
25
+ def exists(val):
26
+ return val is not None
27
+
28
+
29
+ def default(val, d):
30
+ if exists(val):
31
+ return val
32
+ return d() if isfunction(d) else d
33
+
34
+
35
+ class PreNorm(nn.Module):
36
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
37
+ super().__init__()
38
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
39
+ self.fn = fn
40
+
41
+ def forward(self, x: torch.Tensor, *args, **kwargs):
42
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
43
+ return self.fn(self.norm(x, *args), **kwargs)
44
+ else:
45
+ return self.fn(self.norm(x), **kwargs)
46
+
47
+
48
+ class FeedForward(nn.Module):
49
+ def __init__(self, dim, hidden_dim, dropout=0.0):
50
+ super().__init__()
51
+ self.net = nn.Sequential(
52
+ nn.Linear(dim, hidden_dim),
53
+ nn.GELU(),
54
+ nn.Dropout(dropout),
55
+ nn.Linear(hidden_dim, dim),
56
+ nn.Dropout(dropout),
57
+ )
58
+
59
+ def forward(self, x):
60
+ return self.net(x)
61
+
62
+
63
+ class Attention(nn.Module):
64
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
65
+ super().__init__()
66
+ inner_dim = dim_head * heads
67
+ project_out = not (heads == 1 and dim_head == dim)
68
+
69
+ self.heads = heads
70
+ self.scale = dim_head**-0.5
71
+
72
+ self.attend = nn.Softmax(dim=-1)
73
+ self.dropout = nn.Dropout(dropout)
74
+
75
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
76
+
77
+ self.to_out = (
78
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
79
+ if project_out
80
+ else nn.Identity()
81
+ )
82
+
83
+ def forward(self, x):
84
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
85
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
86
+
87
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
88
+
89
+ attn = self.attend(dots)
90
+ attn = self.dropout(attn)
91
+
92
+ out = torch.matmul(attn, v)
93
+ out = rearrange(out, "b h n d -> b n (h d)")
94
+ return self.to_out(out)
95
+
96
+
97
+ class CrossAttention(nn.Module):
98
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
99
+ super().__init__()
100
+ inner_dim = dim_head * heads
101
+ project_out = not (heads == 1 and dim_head == dim)
102
+
103
+ self.heads = heads
104
+ self.scale = dim_head**-0.5
105
+
106
+ self.attend = nn.Softmax(dim=-1)
107
+ self.dropout = nn.Dropout(dropout)
108
+
109
+ context_dim = default(context_dim, dim)
110
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
111
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
112
+
113
+ self.to_out = (
114
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
115
+ if project_out
116
+ else nn.Identity()
117
+ )
118
+
119
+ def forward(self, x, context=None):
120
+ context = default(context, x)
121
+ k, v = self.to_kv(context).chunk(2, dim=-1)
122
+ q = self.to_q(x)
123
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
124
+
125
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
126
+
127
+ attn = self.attend(dots)
128
+ attn = self.dropout(attn)
129
+
130
+ out = torch.matmul(attn, v)
131
+ out = rearrange(out, "b h n d -> b n (h d)")
132
+ return self.to_out(out)
133
+
134
+
135
+ class Transformer(nn.Module):
136
+ def __init__(
137
+ self,
138
+ dim: int,
139
+ depth: int,
140
+ heads: int,
141
+ dim_head: int,
142
+ mlp_dim: int,
143
+ dropout: float = 0.0,
144
+ norm: str = "layer",
145
+ norm_cond_dim: int = -1,
146
+ ):
147
+ super().__init__()
148
+ self.layers = nn.ModuleList([])
149
+ for _ in range(depth):
150
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
151
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
152
+ self.layers.append(
153
+ nn.ModuleList(
154
+ [
155
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
156
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
157
+ ]
158
+ )
159
+ )
160
+
161
+ def forward(self, x: torch.Tensor, *args):
162
+ for attn, ff in self.layers:
163
+ x = attn(x, *args) + x
164
+ x = ff(x, *args) + x
165
+ return x
166
+
167
+
168
+ class TransformerCrossAttn(nn.Module):
169
+ def __init__(
170
+ self,
171
+ dim: int,
172
+ depth: int,
173
+ heads: int,
174
+ dim_head: int,
175
+ mlp_dim: int,
176
+ dropout: float = 0.0,
177
+ norm: str = "layer",
178
+ norm_cond_dim: int = -1,
179
+ context_dim: Optional[int] = None,
180
+ ):
181
+ super().__init__()
182
+ self.layers = nn.ModuleList([])
183
+ for _ in range(depth):
184
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
185
+ ca = CrossAttention(
186
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
187
+ )
188
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
189
+ self.layers.append(
190
+ nn.ModuleList(
191
+ [
192
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
193
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
194
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
195
+ ]
196
+ )
197
+ )
198
+
199
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
200
+ if context_list is None:
201
+ context_list = [context] * len(self.layers)
202
+ if len(context_list) != len(self.layers):
203
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
204
+
205
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
206
+ x = self_attn(x, *args) + x
207
+ x = cross_attn(x, *args, context=context_list[i]) + x
208
+ x = ff(x, *args) + x
209
+ return x
210
+
211
+
212
+ class DropTokenDropout(nn.Module):
213
+ def __init__(self, p: float = 0.1):
214
+ super().__init__()
215
+ if p < 0 or p > 1:
216
+ raise ValueError(
217
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
218
+ )
219
+ self.p = p
220
+
221
+ def forward(self, x: torch.Tensor):
222
+ # x: (batch_size, seq_len, dim)
223
+ if self.training and self.p > 0:
224
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
225
+
226
+ if zero_mask.any():
227
+ x = x[:, ~zero_mask, :]
228
+ return x
229
+
230
+
231
+ class ZeroTokenDropout(nn.Module):
232
+ def __init__(self, p: float = 0.1):
233
+ super().__init__()
234
+ if p < 0 or p > 1:
235
+ raise ValueError(
236
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
237
+ )
238
+ self.p = p
239
+
240
+ def forward(self, x: torch.Tensor):
241
+ # x: (batch_size, seq_len, dim)
242
+ if self.training and self.p > 0:
243
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
244
+ # Zero-out the masked tokens
245
+ x[zero_mask, :] = 0
246
+ return x
247
+
248
+
249
+ class TransformerEncoder(nn.Module):
250
+ def __init__(
251
+ self,
252
+ num_tokens: int,
253
+ token_dim: int,
254
+ dim: int,
255
+ depth: int,
256
+ heads: int,
257
+ mlp_dim: int,
258
+ dim_head: int = 64,
259
+ dropout: float = 0.0,
260
+ emb_dropout: float = 0.0,
261
+ emb_dropout_type: str = "drop",
262
+ emb_dropout_loc: str = "token",
263
+ norm: str = "layer",
264
+ norm_cond_dim: int = -1,
265
+ token_pe_numfreq: int = -1,
266
+ ):
267
+ super().__init__()
268
+ if token_pe_numfreq > 0:
269
+ token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
270
+ self.to_token_embedding = nn.Sequential(
271
+ Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
272
+ FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
273
+ Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
274
+ nn.Linear(token_dim_new, dim),
275
+ )
276
+ else:
277
+ self.to_token_embedding = nn.Linear(token_dim, dim)
278
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
279
+ if emb_dropout_type == "drop":
280
+ self.dropout = DropTokenDropout(emb_dropout)
281
+ elif emb_dropout_type == "zero":
282
+ self.dropout = ZeroTokenDropout(emb_dropout)
283
+ else:
284
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
285
+ self.emb_dropout_loc = emb_dropout_loc
286
+
287
+ self.transformer = Transformer(
288
+ dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
289
+ )
290
+
291
+ def forward(self, inp: torch.Tensor, *args, **kwargs):
292
+ x = inp
293
+
294
+ if self.emb_dropout_loc == "input":
295
+ x = self.dropout(x)
296
+ x = self.to_token_embedding(x)
297
+
298
+ if self.emb_dropout_loc == "token":
299
+ x = self.dropout(x)
300
+ b, n, _ = x.shape
301
+ x += self.pos_embedding[:, :n]
302
+
303
+ if self.emb_dropout_loc == "token_afterpos":
304
+ x = self.dropout(x)
305
+ x = self.transformer(x, *args)
306
+ return x
307
+
308
+
309
+ class TransformerDecoder(nn.Module):
310
+ def __init__(
311
+ self,
312
+ num_tokens: int,
313
+ token_dim: int,
314
+ dim: int,
315
+ depth: int,
316
+ heads: int,
317
+ mlp_dim: int,
318
+ dim_head: int = 64,
319
+ dropout: float = 0.0,
320
+ emb_dropout: float = 0.0,
321
+ emb_dropout_type: str = 'drop',
322
+ norm: str = "layer",
323
+ norm_cond_dim: int = -1,
324
+ context_dim: Optional[int] = None,
325
+ skip_token_embedding: bool = False,
326
+ ):
327
+ super().__init__()
328
+ if not skip_token_embedding:
329
+ self.to_token_embedding = nn.Linear(token_dim, dim)
330
+ else:
331
+ self.to_token_embedding = nn.Identity()
332
+ if token_dim != dim:
333
+ raise ValueError(
334
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
335
+ )
336
+
337
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
338
+ if emb_dropout_type == "drop":
339
+ self.dropout = DropTokenDropout(emb_dropout)
340
+ elif emb_dropout_type == "zero":
341
+ self.dropout = ZeroTokenDropout(emb_dropout)
342
+ elif emb_dropout_type == "normal":
343
+ self.dropout = nn.Dropout(emb_dropout)
344
+
345
+ self.transformer = TransformerCrossAttn(
346
+ dim,
347
+ depth,
348
+ heads,
349
+ dim_head,
350
+ mlp_dim,
351
+ dropout,
352
+ norm=norm,
353
+ norm_cond_dim=norm_cond_dim,
354
+ context_dim=context_dim,
355
+ )
356
+
357
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
358
+ x = self.to_token_embedding(inp)
359
+ b, n, _ = x.shape
360
+
361
+ x = self.dropout(x)
362
+ x += self.pos_embedding[:, :n]
363
+
364
+ x = self.transformer(x, *args, context=context, context_list=context_list)
365
+ return x
366
+
prima/models/components/position_encoding.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
11
+ # All rights reserved.
12
+
13
+ # This source code is licensed under the license found in the
14
+ # LICENSE file in the root directory of this source tree.
15
+
16
+ import math
17
+ from typing import Any, Optional, Tuple
18
+
19
+ import numpy as np
20
+
21
+ import torch
22
+ from torch import nn
23
+
24
+ # Rotary Positional Encoding, adapted from:
25
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
26
+ # 2. https://github.com/naver-ai/rope-vit
27
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
28
+
29
+
30
+ def init_t_xy(end_x: int, end_y: int):
31
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
32
+ t_x = (t % end_x).float()
33
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
34
+ return t_x, t_y
35
+
36
+
37
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
38
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
39
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
40
+
41
+ t_x, t_y = init_t_xy(end_x, end_y)
42
+ freqs_x = torch.outer(t_x, freqs_x)
43
+ freqs_y = torch.outer(t_y, freqs_y)
44
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
45
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
46
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
47
+
48
+
49
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
50
+ ndim = x.ndim
51
+ assert 0 <= 1 < ndim
52
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
53
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
54
+ return freqs_cis.view(*shape)
55
+
56
+
57
+ def apply_rotary_enc(
58
+ xq: torch.Tensor,
59
+ xk: torch.Tensor,
60
+ freqs_cis: torch.Tensor,
61
+ repeat_freqs_k: bool = False,
62
+ ):
63
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
64
+ xk_ = (
65
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
66
+ if xk.shape[-2] != 0
67
+ else None
68
+ )
69
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
70
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
71
+ if xk_ is None:
72
+ # no keys to rotate, due to dropout
73
+ return xq_out.type_as(xq).to(xq.device), xk
74
+ # repeat freqs along seq_len dim to match k seq_len
75
+ if repeat_freqs_k:
76
+ r = xk_.shape[-2] // xq_.shape[-2]
77
+ if freqs_cis.is_cuda:
78
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
79
+ else:
80
+ # torch.repeat on complex numbers may not be supported on non-CUDA devices
81
+ # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
82
+ freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
83
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
84
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
prima/models/components/t_cond_mlp.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ import copy
11
+ from typing import List, Optional
12
+
13
+ import torch
14
+
15
+
16
+ class AdaptiveLayerNorm1D(torch.nn.Module):
17
+ def __init__(self, data_dim: int, norm_cond_dim: int):
18
+ super().__init__()
19
+ if data_dim <= 0:
20
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
21
+ if norm_cond_dim <= 0:
22
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
23
+ self.norm = torch.nn.LayerNorm(data_dim)
24
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
25
+ torch.nn.init.zeros_(self.linear.weight)
26
+ torch.nn.init.zeros_(self.linear.bias)
27
+
28
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
29
+ # x: (batch, ..., data_dim)
30
+ # t: (batch, norm_cond_dim)
31
+ # return: (batch, data_dim)
32
+ x = self.norm(x)
33
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
34
+
35
+ # Add singleton dimensions to alpha and beta
36
+ if x.dim() > 2:
37
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
38
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
39
+
40
+ return x * (1 + alpha) + beta
41
+
42
+
43
+ class SequentialCond(torch.nn.Sequential):
44
+ def forward(self, input, *args, **kwargs):
45
+ for module in self:
46
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
47
+ input = module(input, *args, **kwargs)
48
+ else:
49
+ input = module(input)
50
+ return input
51
+
52
+
53
+ def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
54
+ if norm == "batch":
55
+ return torch.nn.BatchNorm1d(dim)
56
+ elif norm == "layer":
57
+ return torch.nn.LayerNorm(dim)
58
+ elif norm == "ada":
59
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
60
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
61
+ elif norm is None:
62
+ return torch.nn.Identity()
63
+ else:
64
+ raise ValueError(f"Unknown norm: {norm}")
65
+
66
+
67
+ def linear_norm_activ_dropout(
68
+ input_dim: int,
69
+ output_dim: int,
70
+ activation: torch.nn.Module = torch.nn.ReLU(),
71
+ bias: bool = True,
72
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
73
+ dropout: float = 0.0,
74
+ norm_cond_dim: int = -1,
75
+ ) -> SequentialCond:
76
+ layers = []
77
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
78
+ if norm is not None:
79
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
80
+ layers.append(copy.deepcopy(activation))
81
+ if dropout > 0.0:
82
+ layers.append(torch.nn.Dropout(dropout))
83
+ return SequentialCond(*layers)
84
+
85
+
86
+ def create_simple_mlp(
87
+ input_dim: int,
88
+ hidden_dims: List[int],
89
+ output_dim: int,
90
+ activation: torch.nn.Module = torch.nn.ReLU(),
91
+ bias: bool = True,
92
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
93
+ dropout: float = 0.0,
94
+ norm_cond_dim: int = -1,
95
+ ) -> SequentialCond:
96
+ layers = []
97
+ prev_dim = input_dim
98
+ for hidden_dim in hidden_dims:
99
+ layers.extend(
100
+ linear_norm_activ_dropout(
101
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
102
+ )
103
+ )
104
+ prev_dim = hidden_dim
105
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
106
+ return SequentialCond(*layers)
107
+
108
+
109
+ class ResidualMLPBlock(torch.nn.Module):
110
+ def __init__(
111
+ self,
112
+ input_dim: int,
113
+ hidden_dim: int,
114
+ num_hidden_layers: int,
115
+ output_dim: int,
116
+ activation: torch.nn.Module = torch.nn.ReLU(),
117
+ bias: bool = True,
118
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
119
+ dropout: float = 0.0,
120
+ norm_cond_dim: int = -1,
121
+ ):
122
+ super().__init__()
123
+ if not (input_dim == output_dim == hidden_dim):
124
+ raise NotImplementedError(
125
+ f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
126
+ )
127
+
128
+ layers = []
129
+ prev_dim = input_dim
130
+ for i in range(num_hidden_layers):
131
+ layers.append(
132
+ linear_norm_activ_dropout(
133
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
134
+ )
135
+ )
136
+ prev_dim = hidden_dim
137
+ self.model = SequentialCond(*layers)
138
+ self.skip = torch.nn.Identity()
139
+
140
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
141
+ return x + self.model(x, *args, **kwargs)
142
+
143
+
144
+ class ResidualMLP(torch.nn.Module):
145
+ def __init__(
146
+ self,
147
+ input_dim: int,
148
+ hidden_dim: int,
149
+ num_hidden_layers: int,
150
+ output_dim: int,
151
+ activation: torch.nn.Module = torch.nn.ReLU(),
152
+ bias: bool = True,
153
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
154
+ dropout: float = 0.0,
155
+ num_blocks: int = 1,
156
+ norm_cond_dim: int = -1,
157
+ ):
158
+ super().__init__()
159
+ self.input_dim = input_dim
160
+ self.model = SequentialCond(
161
+ linear_norm_activ_dropout(
162
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
163
+ ),
164
+ *[
165
+ ResidualMLPBlock(
166
+ hidden_dim,
167
+ hidden_dim,
168
+ num_hidden_layers,
169
+ hidden_dim,
170
+ activation,
171
+ bias,
172
+ norm,
173
+ dropout,
174
+ norm_cond_dim,
175
+ )
176
+ for _ in range(num_blocks)
177
+ ],
178
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
179
+ )
180
+
181
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
182
+ return self.model(x, *args, **kwargs)
183
+
184
+
185
+ class FrequencyEmbedder(torch.nn.Module):
186
+ def __init__(self, num_frequencies, max_freq_log2):
187
+ super().__init__()
188
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
189
+ self.register_buffer("frequencies", frequencies)
190
+
191
+ def forward(self, x):
192
+ # x should be of size (N,) or (N, D)
193
+ N = x.size(0)
194
+ if x.dim() == 1: # (N,)
195
+ x = x.unsqueeze(1) # (N, D) where D=1
196
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
197
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
198
+ s = torch.sin(scaled)
199
+ c = torch.cos(scaled)
200
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
201
+ N, -1
202
+ ) # (N, D * 2 * num_frequencies + D)
203
+ return embedded
204
+
prima/models/components/transformer.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
11
+ # All rights reserved.
12
+
13
+ # This source code is licensed under the license found in the
14
+ # LICENSE file in the root directory of this source tree.
15
+
16
+ import contextlib
17
+ import math
18
+ import warnings
19
+ from functools import partial
20
+ from typing import Tuple, Type
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import nn, Tensor
25
+
26
+ from .position_encoding import apply_rotary_enc, compute_axial_cis
27
+ from .model_utils import MLP
28
+
29
+ warnings.simplefilter(action="ignore", category=FutureWarning)
30
+
31
+
32
+ def get_sdpa_settings():
33
+ if torch.cuda.is_available():
34
+ old_gpu = torch.cuda.get_device_properties(0).major < 7
35
+ # only use Flash Attention on Ampere (8.0) or newer GPUs
36
+ use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
37
+ if not use_flash_attn:
38
+ warnings.warn(
39
+ "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
40
+ category=UserWarning,
41
+ stacklevel=2,
42
+ )
43
+ # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
44
+ # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
45
+ pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
46
+ if pytorch_version < (2, 2):
47
+ warnings.warn(
48
+ f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
49
+ "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
50
+ category=UserWarning,
51
+ stacklevel=2,
52
+ )
53
+ math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
54
+ else:
55
+ old_gpu = True
56
+ use_flash_attn = False
57
+ math_kernel_on = True
58
+
59
+ return old_gpu, use_flash_attn, math_kernel_on
60
+
61
+
62
+ # Check whether Flash Attention is available (and use it by default)
63
+ OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
64
+ # A fallback setting to allow all available kernels if Flash Attention fails
65
+ ALLOW_ALL_KERNELS = False
66
+
67
+
68
+ def sdp_kernel_context(dropout_p):
69
+ """
70
+ Get the context for the attention scaled dot-product kernel. We use Flash Attention
71
+ by default, but fall back to all available kernels if Flash Attention fails.
72
+ """
73
+ if ALLOW_ALL_KERNELS:
74
+ return contextlib.nullcontext()
75
+
76
+ return torch.backends.cuda.sdp_kernel(
77
+ enable_flash=USE_FLASH_ATTN,
78
+ # if Flash attention kernel is off, then math kernel needs to be enabled
79
+ enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
80
+ enable_mem_efficient=OLD_GPU,
81
+ )
82
+
83
+
84
+ class TwoWayTransformer(nn.Module):
85
+ def __init__(
86
+ self,
87
+ depth: int,
88
+ embedding_dim: int,
89
+ num_heads: int,
90
+ mlp_dim: int,
91
+ activation: Type[nn.Module] = nn.ReLU,
92
+ attention_downsample_rate: int = 2,
93
+ ) -> None:
94
+ """
95
+ A transformer decoder that attends to an input image using
96
+ queries whose positional embedding is supplied.
97
+
98
+ Args:
99
+ depth (int): number of layers in the transformer
100
+ embedding_dim (int): the channel dimension for the input embeddings
101
+ num_heads (int): the number of heads for multihead attention. Must
102
+ divide embedding_dim
103
+ mlp_dim (int): the channel dimension internal to the MLP block
104
+ activation (nn.Module): the activation to use in the MLP block
105
+ """
106
+ super().__init__()
107
+ self.depth = depth
108
+ self.embedding_dim = embedding_dim
109
+ self.num_heads = num_heads
110
+ self.mlp_dim = mlp_dim
111
+ self.layers = nn.ModuleList()
112
+
113
+ for i in range(depth):
114
+ self.layers.append(
115
+ TwoWayAttentionBlock(
116
+ embedding_dim=embedding_dim,
117
+ num_heads=num_heads,
118
+ mlp_dim=mlp_dim,
119
+ activation=activation,
120
+ attention_downsample_rate=attention_downsample_rate,
121
+ skip_first_layer_pe=(i == 0),
122
+ )
123
+ )
124
+
125
+ self.final_attn_token_to_image = Attention(
126
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
127
+ )
128
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
129
+
130
+ def forward(
131
+ self,
132
+ image_embedding: Tensor,
133
+ image_pe: Tensor,
134
+ point_embedding: Tensor,
135
+ ) -> Tuple[Tensor, Tensor]:
136
+ """
137
+ Args:
138
+ image_embedding (torch.Tensor): image to attend to. Should be shape
139
+ B x embedding_dim x h x w for any h and w.
140
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
141
+ have the same shape as image_embedding.
142
+ point_embedding (torch.Tensor): the embedding to add to the query points.
143
+ Must have shape B x N_points x embedding_dim for any N_points.
144
+
145
+ Returns:
146
+ torch.Tensor: the processed point_embedding
147
+ torch.Tensor: the processed image_embedding
148
+ """
149
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
150
+ bs, c, h, w = image_embedding.shape
151
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
152
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
153
+
154
+ # Prepare queries
155
+ queries = point_embedding
156
+ keys = image_embedding
157
+
158
+ # Apply transformer blocks and final layernorm
159
+ for layer in self.layers:
160
+ queries, keys = layer(
161
+ queries=queries,
162
+ keys=keys,
163
+ query_pe=point_embedding,
164
+ key_pe=image_pe,
165
+ )
166
+
167
+ # Apply the final attention layer from the points to the image
168
+ q = queries + point_embedding
169
+ k = keys + image_pe
170
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
171
+ queries = queries + attn_out
172
+ queries = self.norm_final_attn(queries)
173
+
174
+ return queries, keys
175
+
176
+
177
+ class TwoWayAttentionBlock(nn.Module):
178
+ def __init__(
179
+ self,
180
+ embedding_dim: int,
181
+ num_heads: int,
182
+ mlp_dim: int = 2048,
183
+ activation: Type[nn.Module] = nn.ReLU,
184
+ attention_downsample_rate: int = 2,
185
+ skip_first_layer_pe: bool = False,
186
+ ) -> None:
187
+ """
188
+ A transformer block with four layers: (1) self-attention of sparse
189
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
190
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
191
+ inputs.
192
+
193
+ Arguments:
194
+ embedding_dim (int): the channel dimension of the embeddings
195
+ num_heads (int): the number of heads in the attention layers
196
+ mlp_dim (int): the hidden dimension of the mlp block
197
+ activation (nn.Module): the activation of the mlp block
198
+ skip_first_layer_pe (bool): skip the PE on the first layer
199
+ """
200
+ super().__init__()
201
+ self.self_attn = Attention(embedding_dim, num_heads)
202
+ self.norm1 = nn.LayerNorm(embedding_dim)
203
+
204
+ self.cross_attn_token_to_image = Attention(
205
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
206
+ )
207
+ self.norm2 = nn.LayerNorm(embedding_dim)
208
+
209
+ self.mlp = MLP(
210
+ embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
211
+ )
212
+ self.norm3 = nn.LayerNorm(embedding_dim)
213
+
214
+ self.norm4 = nn.LayerNorm(embedding_dim)
215
+ self.cross_attn_image_to_token = Attention(
216
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
217
+ )
218
+
219
+ self.skip_first_layer_pe = skip_first_layer_pe
220
+
221
+ def forward(
222
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
223
+ ) -> Tuple[Tensor, Tensor]:
224
+ # Self attention block
225
+ if self.skip_first_layer_pe:
226
+ queries = self.self_attn(q=queries, k=queries, v=queries)
227
+ else:
228
+ q = queries + query_pe
229
+ attn_out = self.self_attn(q=q, k=q, v=queries)
230
+ queries = queries + attn_out
231
+ queries = self.norm1(queries)
232
+
233
+ # Cross attention block, tokens attending to image embedding
234
+ q = queries + query_pe
235
+ k = keys + key_pe
236
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
237
+ queries = queries + attn_out
238
+ queries = self.norm2(queries)
239
+
240
+ # MLP block
241
+ mlp_out = self.mlp(queries)
242
+ queries = queries + mlp_out
243
+ queries = self.norm3(queries)
244
+
245
+ # Cross attention block, image embedding attending to tokens
246
+ q = queries + query_pe
247
+ k = keys + key_pe
248
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
249
+ keys = keys + attn_out
250
+ keys = self.norm4(keys)
251
+
252
+ return queries, keys
253
+
254
+
255
+ class Attention(nn.Module):
256
+ """
257
+ An attention layer that allows for downscaling the size of the embedding
258
+ after projection to queries, keys, and values.
259
+ """
260
+
261
+ def __init__(
262
+ self,
263
+ embedding_dim: int,
264
+ num_heads: int,
265
+ downsample_rate: int = 1,
266
+ dropout: float = 0.0,
267
+ kv_in_dim: int = None,
268
+ ) -> None:
269
+ super().__init__()
270
+ self.embedding_dim = embedding_dim
271
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
272
+ self.internal_dim = embedding_dim // downsample_rate
273
+ self.num_heads = num_heads
274
+ assert (
275
+ self.internal_dim % num_heads == 0
276
+ ), "num_heads must divide embedding_dim."
277
+
278
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
279
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
280
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
281
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
282
+
283
+ self.dropout_p = dropout
284
+
285
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
286
+ b, n, c = x.shape
287
+ x = x.reshape(b, n, num_heads, c // num_heads)
288
+ return x.transpose(1, 2).contiguous() # B x N_heads x N_tokens x C_per_head
289
+
290
+ def _recombine_heads(self, x: Tensor) -> Tensor:
291
+ b, n_heads, n_tokens, c_per_head = x.shape
292
+ x = x.transpose(1, 2).contiguous()
293
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
294
+
295
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
296
+ # Input projections
297
+ q = self.q_proj(q)
298
+ k = self.k_proj(k)
299
+ v = self.v_proj(v)
300
+
301
+ # Separate into heads
302
+ q = self._separate_heads(q, self.num_heads)
303
+ k = self._separate_heads(k, self.num_heads)
304
+ v = self._separate_heads(v, self.num_heads)
305
+
306
+ dropout_p = self.dropout_p if self.training else 0.0
307
+ # Attention
308
+ try:
309
+ with sdp_kernel_context(dropout_p):
310
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
311
+ except Exception as e:
312
+ # Fall back to all kernels if the Flash attention kernel fails
313
+ warnings.warn(
314
+ f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
315
+ f"kernels for scaled_dot_product_attention (which may have a slower speed).",
316
+ category=UserWarning,
317
+ stacklevel=2,
318
+ )
319
+ global ALLOW_ALL_KERNELS
320
+ ALLOW_ALL_KERNELS = True
321
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
322
+
323
+ out = self._recombine_heads(out)
324
+ out = self.out_proj(out)
325
+
326
+ return out
327
+
328
+
329
+ class RoPEAttention(Attention):
330
+ """Attention with rotary position encoding."""
331
+
332
+ def __init__(
333
+ self,
334
+ *args,
335
+ rope_theta=10000.0,
336
+ # whether to repeat q rope to match k length
337
+ # this is needed for cross-attention to memories
338
+ rope_k_repeat=False,
339
+ feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
340
+ **kwargs,
341
+ ):
342
+ super().__init__(*args, **kwargs)
343
+
344
+ self.compute_cis = partial(
345
+ compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
346
+ )
347
+ freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
348
+ self.freqs_cis = freqs_cis
349
+ self.rope_k_repeat = rope_k_repeat
350
+
351
+ def forward(
352
+ self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int=0,
353
+ ) -> Tensor:
354
+ # Input projections
355
+ q = self.q_proj(q)
356
+ k = self.k_proj(k)
357
+ v = self.v_proj(v)
358
+
359
+ # Separate into heads
360
+ q = self._separate_heads(q, self.num_heads)
361
+ k = self._separate_heads(k, self.num_heads)
362
+ v = self._separate_heads(v, self.num_heads)
363
+
364
+ # Apply rotary position encoding
365
+ w = h = math.sqrt(q.shape[-2])
366
+ self.freqs_cis = self.freqs_cis.to(q.device)
367
+ if self.freqs_cis.shape[0] != q.shape[-2]:
368
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
369
+ if q.shape[-2] != k.shape[-2]:
370
+ assert self.rope_k_repeat
371
+
372
+ num_k_rope = k.size(-2) - num_k_exclude_rope
373
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
374
+ q,
375
+ k[:, :, :num_k_rope],
376
+ freqs_cis=self.freqs_cis,
377
+ repeat_freqs_k=self.rope_k_repeat,
378
+ )
379
+
380
+ dropout_p = self.dropout_p if self.training else 0.0
381
+ # Attention
382
+ try:
383
+ with sdp_kernel_context(dropout_p):
384
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
385
+ except Exception as e:
386
+ # Fall back to all kernels if the Flash attention kernel fails
387
+ warnings.warn(
388
+ f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
389
+ f"kernels for scaled_dot_product_attention (which may have a slower speed).",
390
+ category=UserWarning,
391
+ stacklevel=2,
392
+ )
393
+ global ALLOW_ALL_KERNELS
394
+ ALLOW_ALL_KERNELS = True
395
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
396
+
397
+ out = self._recombine_heads(out)
398
+ out = self.out_proj(out)
399
+
400
+ return out
prima/models/discriminator.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+
14
+ class Discriminator(nn.Module):
15
+
16
+ def __init__(self):
17
+ """
18
+ Pose + Shape discriminator proposed in HMR
19
+ """
20
+ super(Discriminator, self).__init__()
21
+
22
+ self.num_joints = 34
23
+ # poses_alone
24
+ self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1)
25
+ nn.init.xavier_uniform_(self.D_conv1.weight)
26
+ nn.init.zeros_(self.D_conv1.bias)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1)
29
+ nn.init.xavier_uniform_(self.D_conv2.weight)
30
+ nn.init.zeros_(self.D_conv2.bias)
31
+ pose_out = []
32
+ for i in range(self.num_joints):
33
+ pose_out_temp = nn.Linear(32, 1)
34
+ nn.init.xavier_uniform_(pose_out_temp.weight)
35
+ nn.init.zeros_(pose_out_temp.bias)
36
+ pose_out.append(pose_out_temp)
37
+ self.pose_out = nn.ModuleList(pose_out)
38
+
39
+ # betas
40
+ self.betas_fc1 = nn.Linear(41, 10) # SMAL betas is 41
41
+ nn.init.xavier_uniform_(self.betas_fc1.weight)
42
+ nn.init.zeros_(self.betas_fc1.bias)
43
+ self.betas_fc2 = nn.Linear(10, 5)
44
+ nn.init.xavier_uniform_(self.betas_fc2.weight)
45
+ nn.init.zeros_(self.betas_fc2.bias)
46
+ self.betas_out = nn.Linear(5, 1)
47
+ nn.init.xavier_uniform_(self.betas_out.weight)
48
+ nn.init.zeros_(self.betas_out.bias)
49
+
50
+ # bones
51
+ self.bone_fc1 = nn.Linear(24, 10) # SMAL betas is 41
52
+ nn.init.xavier_uniform_(self.bone_fc1.weight)
53
+ nn.init.zeros_(self.bone_fc1.bias)
54
+ self.bone_fc2 = nn.Linear(10, 5)
55
+ nn.init.xavier_uniform_(self.bone_fc2.weight)
56
+ nn.init.zeros_(self.bone_fc2.bias)
57
+ self.bone_out = nn.Linear(5, 1)
58
+ nn.init.xavier_uniform_(self.bone_out.weight)
59
+ nn.init.zeros_(self.bone_out.bias)
60
+
61
+ # poses_joint
62
+ self.D_alljoints_fc1 = nn.Linear(32 * self.num_joints, 1024)
63
+ nn.init.xavier_uniform_(self.D_alljoints_fc1.weight)
64
+ nn.init.zeros_(self.D_alljoints_fc1.bias)
65
+ self.D_alljoints_fc2 = nn.Linear(1024, 1024)
66
+ nn.init.xavier_uniform_(self.D_alljoints_fc2.weight)
67
+ nn.init.zeros_(self.D_alljoints_fc2.bias)
68
+ self.D_alljoints_out = nn.Linear(1024, 1)
69
+ nn.init.xavier_uniform_(self.D_alljoints_out.weight)
70
+ nn.init.zeros_(self.D_alljoints_out.bias)
71
+
72
+ def forward(self, poses: torch.Tensor, betas: torch.Tensor, bone=None) -> torch.Tensor:
73
+ """
74
+ Forward pass of the discriminator.
75
+ Args:
76
+ poses (torch.Tensor): Tensor of shape (B, 23, 3, 3) containing a batch of poses (excluding the global orientation).
77
+ betas (torch.Tensor): Tensor of shape (B, 41) containing a batch of SMAL beta coefficients.
78
+ Returns:
79
+ torch.Tensor: Discriminator output with shape (B, 25)
80
+ """
81
+ # bn = poses.shape[0]
82
+ # poses B x 207
83
+ # poses = poses.reshape(bn, -1)
84
+ # poses B x num_joints x 1 x 9
85
+ poses = poses.reshape(-1, self.num_joints, 1, 9)
86
+ bn = poses.shape[0]
87
+ # poses B x 9 x num_joints x 1
88
+ poses = poses.permute(0, 3, 1, 2).contiguous()
89
+
90
+ # poses_alone
91
+ poses = self.D_conv1(poses)
92
+ poses = self.relu(poses)
93
+ poses = self.D_conv2(poses)
94
+ poses = self.relu(poses)
95
+
96
+ poses_out = []
97
+ for i in range(self.num_joints):
98
+ poses_out_ = self.pose_out[i](poses[:, :, i, 0])
99
+ poses_out.append(poses_out_)
100
+ poses_out = torch.cat(poses_out, dim=1)
101
+
102
+ # betas
103
+ betas = self.betas_fc1(betas)
104
+ betas = self.relu(betas)
105
+ betas = self.betas_fc2(betas)
106
+ betas = self.relu(betas)
107
+ betas_out = self.betas_out(betas)
108
+
109
+ # bone
110
+ if bone is not None:
111
+ bone = self.bone_fc1(bone)
112
+ bone = self.relu(bone)
113
+ bone = self.bone_fc2(bone)
114
+ bone = self.relu(bone)
115
+ bone_out = self.bone_out(bone)
116
+
117
+ # poses_joint
118
+ poses = poses.reshape(bn, -1)
119
+ poses_all = self.D_alljoints_fc1(poses)
120
+ poses_all = self.relu(poses_all)
121
+ poses_all = self.D_alljoints_fc2(poses_all)
122
+ poses_all = self.relu(poses_all)
123
+ poses_all_out = self.D_alljoints_out(poses_all)
124
+
125
+ if bone is not None:
126
+ disc_out = torch.cat((poses_out, betas_out, poses_all_out, bone_out), 1)
127
+ else:
128
+ disc_out = torch.cat((poses_out, betas_out, poses_all_out), 1)
129
+ return disc_out
prima/models/heads/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .smal_head import build_smal_head