HF Space deploy commited on
Commit
cdad419
·
0 Parent(s):

Deploy snapshot (LFS for demo images per .gitattributes)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. .github/workflows/check-headers.yml +36 -0
  3. .github/workflows/codespell.yml +21 -0
  4. .github/workflows/release-pypi.yml +48 -0
  5. .gitignore +175 -0
  6. README.md +252 -0
  7. app.py +713 -0
  8. chumpy/__init__.py +16 -0
  9. chumpy/ch.py +66 -0
  10. configs/sa_finetune_hrnet_w32.yaml +220 -0
  11. configs_hydra/experiment/default.yaml +28 -0
  12. configs_hydra/experiment/default_val.yaml +34 -0
  13. configs_hydra/experiment/primaStage1.yaml +83 -0
  14. configs_hydra/experiment/primaStage2.yaml +113 -0
  15. configs_hydra/extras/default.yaml +8 -0
  16. configs_hydra/hydra/default.yaml +26 -0
  17. configs_hydra/launcher/local.yaml +13 -0
  18. configs_hydra/launcher/slurm.yaml +22 -0
  19. configs_hydra/paths/default.yaml +18 -0
  20. configs_hydra/train.yaml +46 -0
  21. configs_hydra/trainer/cpu.yaml +6 -0
  22. configs_hydra/trainer/ddp.yaml +14 -0
  23. configs_hydra/trainer/default.yaml +10 -0
  24. configs_hydra/trainer/default_amr.yaml +9 -0
  25. configs_hydra/trainer/gpu.yaml +6 -0
  26. configs_hydra/trainer/mps.yaml +6 -0
  27. demo.py +189 -0
  28. demo.sh +12 -0
  29. demo_data/000000015956_horse.png +3 -0
  30. demo_data/000000315905_zebra.jpg +3 -0
  31. demo_data/beagle.jpg +3 -0
  32. demo_data/n02101388_1188.png +3 -0
  33. demo_data/n02412080_12159.png +3 -0
  34. demo_data/shepherd_hati.jpg +3 -0
  35. demo_tta.py +399 -0
  36. demo_tta.sh +15 -0
  37. eval.py +103 -0
  38. images/teaser.png +3 -0
  39. packages.txt +4 -0
  40. prima/__init__.py +25 -0
  41. prima/configs/__init__.py +99 -0
  42. prima/models/__init__.py +54 -0
  43. prima/models/backbones/__init__.py +19 -0
  44. prima/models/backbones/vit.py +375 -0
  45. prima/models/bioclip_embedding.py +70 -0
  46. prima/models/components/__init__.py +0 -0
  47. prima/models/components/model_utils.py +160 -0
  48. prima/models/components/pose_transformer.py +366 -0
  49. prima/models/components/position_encoding.py +84 -0
  50. prima/models/components/t_cond_mlp.py +204 -0
.gitattributes ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Hugging Face Hub stores these via Git LFS / Xet (plain PNG/JPG in git are rejected on push).
2
+ demo_data/*.png filter=lfs diff=lfs merge=lfs -text
3
+ demo_data/*.jpg filter=lfs diff=lfs merge=lfs -text
4
+ demo_data/*.jpeg filter=lfs diff=lfs merge=lfs -text
5
+ images/*.png filter=lfs diff=lfs merge=lfs -text
.github/workflows/check-headers.yml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Check File Headers
3
+
4
+ on:
5
+ push:
6
+ branches: [main]
7
+ pull_request:
8
+ branches: [main]
9
+
10
+ jobs:
11
+ check-headers:
12
+ name: Check Python file headers
13
+ runs-on: ubuntu-latest
14
+ permissions:
15
+ contents: read
16
+
17
+ steps:
18
+ - name: Checkout code
19
+ uses: actions/checkout@v3
20
+
21
+ - name: Set up Python
22
+ uses: actions/setup-python@v4
23
+ with:
24
+ python-version: "3.10"
25
+
26
+ - name: Check headers
27
+ run: |
28
+ python scripts/update_headers.py --check
29
+ continue-on-error: false
30
+
31
+ - name: Provide fix instructions
32
+ if: failure()
33
+ run: |
34
+ echo "::error::Some files are missing proper headers."
35
+ echo "To fix this, run: python scripts/update_headers.py"
36
+ echo "Then commit the changes."
.github/workflows/codespell.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Codespell
3
+
4
+ on:
5
+ push:
6
+ branches: [main]
7
+ pull_request:
8
+ branches: [main]
9
+
10
+ jobs:
11
+ codespell:
12
+ name: Check for spelling errors
13
+ runs-on: ubuntu-latest
14
+
15
+ steps:
16
+ - name: Checkout
17
+ uses: actions/checkout@v3
18
+ - name: Codespell
19
+ uses: codespell-project/actions-codespell@v1
20
+ with:
21
+ ignore_words_list: prima-animal, mpjpe, uvd, xyz, hm36, cpn, dbb
.github/workflows/release-pypi.yml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Update pypi release
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - 'v*.*.*'
7
+ pull_request:
8
+ branches:
9
+ - main
10
+ types:
11
+ - labeled
12
+ - opened
13
+ - edited
14
+ - synchronize
15
+ - reopened
16
+
17
+ jobs:
18
+ release:
19
+ runs-on: ubuntu-latest
20
+
21
+ steps:
22
+ - name: Cache dependencies
23
+ id: pip-cache
24
+ uses: actions/cache@v4
25
+ with:
26
+ path: ~/.cache/pip
27
+ key: ${{ runner.os }}-pip
28
+
29
+ - name: Install dependencies
30
+ run: |
31
+ pip install --upgrade pip
32
+ pip install wheel
33
+ # NOTE(stes) see https://github.com/pypa/twine/issues/1216#issuecomment-2629069669
34
+ pip install "packaging>=24.2"
35
+
36
+ - name: Checkout code
37
+ uses: actions/checkout@v3
38
+
39
+ - name: Build and publish to PyPI
40
+ if: ${{ github.event_name == 'push' }}
41
+ env:
42
+ TWINE_USERNAME: __token__
43
+ TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }}
44
+ run: |
45
+ pip install build twine
46
+ python3 -m build
47
+ ls dist/
48
+ python3 -m twine upload --verbose dist/*
.gitignore ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+ # Vscode
164
+ .vscode/
165
+
166
+ # Directory
167
+ .gradio/
168
+ demo_out/
169
+ demo_out*/
170
+ data/PRIMA*/
171
+ data/backbone.pth
172
+ logs/
173
+ *.pth
174
+ *.pkl
175
+ datasets/
README.md ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
2
+
3
+
4
+ This is the official implementation of the approach described in the preprint:
5
+
6
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation \
7
+ Xiaohang Yu, Ti Wang, Mackenzie Weygandt Mathis
8
+
9
+ ![PRIMA teaser](images/teaser.png)
10
+
11
+
12
+ ---
13
+
14
+
15
+ ## 🚀 TL;DR
16
+ PRIMA creates a 3D quadruped mesh from a single 2D image. It leverages BioCLIP-based biological priors for robust cross-species shape understanding, then applies test-time adaptation with 2D reprojection and auxiliary keypoint guidance to refine SMAL pose and shape predictions.
17
+
18
+ It further can be used to build Quadruped3D, a large-scale pseudo-3D dataset with diverse species and poses.
19
+
20
+ PRIMA achieves state-of-the-art results on Animal3D, CtrlAni3D, Quadruped2D, and Animal Kingdom datasets.
21
+
22
+ ## Installation
23
+
24
+ ### Install from PyPI
25
+
26
+ > Recommended: Python 3.10 and a CUDA-enabled PyTorch installation.
27
+
28
+ ```bash
29
+ conda create -n prima python=3.10 -y
30
+ conda activate prima
31
+
32
+ # Install PyTorch matching your CUDA (example: CUDA 11.8)
33
+ pip install --index-url https://download.pytorch.org/whl/cu118 \
34
+ "torch==2.2.1" "torchvision==0.17.1" "torchaudio==2.2.1"
35
+
36
+ # Install chumpy and PyTorch3D
37
+ python -m pip install --no-build-isolation \
38
+ "git+https://github.com/mattloper/chumpy.git"
39
+ python -m pip install --no-build-isolation \
40
+ "git+https://github.com/facebookresearch/pytorch3d.git"
41
+
42
+ # Install PRIMA from PyPI
43
+ pip install prima-animal
44
+ ```
45
+
46
+ `prima-animal` includes demo runtime dependencies used by `demo.py`, `demo_tta.py`, and `app.py` (including Detectron2 and DeepLabCut).
47
+
48
+ ### Clean install from this repository
49
+
50
+ Use these when developing from a **git clone** (not the PyPI wheel). The shell scripts are **non-interactive** (pip uses `--no-input`; `GIT_TERMINAL_PROMPT=0` for git). Put Hugging Face credentials in your environment or git credential helper before pushing the Space.
51
+
52
+ **Local (fresh venv, LFS assets, Hub demo weights, smoke test)** — requires **Python 3.10+**
53
+ (Gradio 5.1+ / Space-provided Gradio 6.x and `app.py` type hints). On macOS without `python3.10` on your `PATH`, install
54
+ `brew install python@3.10` and set `PRIMA_PYTHON=/opt/homebrew/bin/python3.10`.
55
+
56
+ ```bash
57
+ chmod +x scripts/clean_install_local.sh scripts/clean_redeploy_hf_space.sh scripts/deploy_hf_space.sh
58
+ PRIMA_PYTHON=/opt/homebrew/bin/python3.10 ./scripts/clean_install_local.sh
59
+ ```
60
+
61
+ Options:
62
+
63
+ - `PRIMA_VENV=.venv ./scripts/clean_install_local.sh --skip-data` — skip the large `setup_demo_data` download if `data/` is already populated.
64
+ - `./scripts/clean_install_local.sh --wipe-data --force-data` — delete downloaded `data/` assets and redownload.
65
+ - `./scripts/clean_install_local.sh --no-editable` — only `requirements.txt` (no `pip install -e .`); use if editable install fails and you will install the training stack via conda as in the PyPI section above. You still need **Python 3.10+** for Gradio 5.1+. The smoke test sets `PYTHONPATH` to the repo root so `import prima` works without an editable install.
66
+ - **`requirements.txt` pins `deeplabcut==3.0.0rc14`** (SuperAnimal PyTorch API). On macOS, `clean_install_local.sh` installs a PyTables wheel first, then DLC 3.x. Full check: `./scripts/test_local_full.sh`.
67
+
68
+ After `requirements.txt`, the script runs **`pip install --no-deps -e .`** so the `prima` package is registered without re-resolving `pyproject.toml` (which would pull **Detectron2** from git again). Install Detectron2 separately if needed: `pip install 'git+https://github.com/facebookresearch/detectron2.git'`.
69
+
70
+ **Hugging Face Space (full redeploy from your working tree):**
71
+
72
+ Requires [Git LFS / Xet](https://huggingface.co/docs/hub/xet/using-xet-storage#git) tooling (`brew install git-lfs git-xet`, `git xet install`, `git lfs install`). Then:
73
+
74
+ ```bash
75
+ ./scripts/clean_redeploy_hf_space.sh
76
+ ```
77
+
78
+ This is equivalent to `./scripts/deploy_hf_space.sh` and force-pushes a fresh snapshot to the Space.
79
+
80
+ ---
81
+
82
+ ## Demo
83
+
84
+ ### Checkpoints and data
85
+
86
+ The demo scripts auto-download their default Stage 1 PRIMA assets from Hugging
87
+ Face when the checkpoint or matching Hydra config is missing. If you want to
88
+ pre-download all necessary checkpoints and data ahead of time, run:
89
+
90
+ ```bash
91
+ python scripts/setup_demo_data.py --hf-repo-id MLAdaptiveIntelligence/PRIMA
92
+ ```
93
+
94
+ Approximate default prefetch volume from Hugging Face is ~5.5 GB total
95
+ (`s1ckpt_inference.ckpt` ~3 GB + `amr_vitbb.pth` ~2.5 GB + SMAL files).
96
+ Expected time is roughly:
97
+ - 100 Mbps: ~7-10 minutes
98
+ - 300 Mbps: ~2-4 minutes
99
+ - 1 Gbps: ~1 minute
100
+
101
+ Existing files are reused by default; pass `--force` only if you need to redownload them. If you also need the Stage 3 pretrained model, add `--include-stage3`.
102
+
103
+ Expected files in that Hugging Face repo root:
104
+ - `my_smpl_00781_4_all.pkl`
105
+ - `my_smpl_data_00781_4_all.pkl`
106
+ - `walking_toy_symmetric_pose_prior_with_cov_35parts.pkl`
107
+ - `amr_vitbb.pth`
108
+ - `config_s1_HYDRA.yaml`
109
+ - `s1ckpt_inference.ckpt`
110
+
111
+ Optional Stage 3 prefetch expects:
112
+ - `config_s3_HYDRA.yaml`
113
+ - `s3ckpt_inference.ckpt`
114
+
115
+ ### Demo (without TTA)
116
+
117
+ Run animal detection + PRIMA 3D pose/shape inference:
118
+
119
+ ```bash
120
+ bash demo.sh
121
+ ```
122
+
123
+ Outputs are written to `demo_out/`. Edit `demo.sh` if you want to use a custom
124
+ checkpoint path.
125
+
126
+ ---
127
+
128
+ ### Demo (with TTA)
129
+
130
+ Run PRIMA inference with test-time adaptation:
131
+
132
+ ```bash
133
+ bash demo_tta.sh
134
+ ```
135
+
136
+ Outputs are written to `demo_out_tta/` (before/after TTA renders, keypoints, and
137
+ optional meshes). Edit `demo_tta.sh` if you want to change the checkpoint, TTA
138
+ learning rate, or number of iterations.
139
+
140
+ ---
141
+
142
+ ### Gradio demo
143
+
144
+ We also provide a simple Gradio-based web demo for interactive testing in the
145
+ browser:
146
+
147
+ ```bash
148
+ python app.py \
149
+ --checkpoint data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt \
150
+ --out_folder demo_out_tta_gradio/
151
+ ```
152
+
153
+ This starts a local Gradio app (by default on http://127.0.0.1:7860), where
154
+ you can upload images and visualize PRIMA predictions and adaptation results.
155
+ The `s1ckpt_inference.ckpt` checkpoint is downloaded automatically if missing.
156
+
157
+ `app.py` picks a **demo profile** automatically:
158
+
159
+ | | **Local** (`python app.py`) | **Hugging Face Space** |
160
+ |--|--|--|
161
+ | PRIMA device | GPU if available, else CPU | CPU only |
162
+ | Detectron2 | X-101-FPN | R50-FPN (lighter) |
163
+ | Default TTA iterations | 30 | 0 (PRIMA-only by default) |
164
+ | Save `.obj` meshes | on | off |
165
+ | Preload checkpoint at startup | off | on |
166
+
167
+ Override for testing: `PRIMA_DEMO_MODE=local` or `PRIMA_DEMO_MODE=space`.
168
+
169
+ #### Hugging Face Space (maintainers)
170
+
171
+ Demo images under `demo_data/` and `images/teaser.png` are tracked with **Git LFS**
172
+ (see `.gitattributes`) so they can be pushed to a Hugging Face Space under the Hub’s
173
+ LFS / **Xet** bridge. Install tooling once:
174
+
175
+ ```bash
176
+ brew install git-lfs git-xet
177
+ git xet install
178
+ git lfs install
179
+ ```
180
+
181
+ Then from a clean checkout with LFS files present, redeploy the Space (same as `clean_redeploy_hf_space.sh`):
182
+
183
+ ```bash
184
+ ./scripts/deploy_hf_space.sh
185
+ # or
186
+ ./scripts/clean_redeploy_hf_space.sh
187
+ ```
188
+
189
+ The script rsyncs the working tree (not `git archive`) so image files are materialized
190
+ before `git add` turns them into LFS blobs.
191
+
192
+ ---
193
+
194
+
195
+ ## Training and Evaluation
196
+
197
+ ### Dataset Setup
198
+
199
+ Download datasets from [Animal3D](https://xujiacong.github.io/Animal3D/), [CtrlAni3D](https://github.com/luoxue-star/AniMer?tab=readme-ov-file#training), Quadruped2D, and [Animal Kingdom](https://drive.google.com/file/d/1dk2a0qB0fbVZ4X6eAgP6VJVXj0rxVfsJ/view?usp=drive_link). For Quadruped2D, download the images from [SuperAnimal-Quadruped80K](https://zenodo.org/records/14016777) and our processed annotations from [here](https://drive.google.com/drive/folders/1eBNboxVwl_eGPoC93zxf-U3hmE6e2f-f?usp=sharing). Put all the datasets under `datasets/`.
200
+
201
+ ### Training
202
+
203
+ Two-stage training script:
204
+
205
+ ```bash
206
+ bash train.sh
207
+ ```
208
+
209
+ Training outputs are written to `logs/train/runs/<exp_name>/`.
210
+
211
+
212
+ ### Evaluation
213
+
214
+ ```bash
215
+ python eval.py \
216
+ --config data/PRIMAS1/.hydra/config.yaml \
217
+ --checkpoint data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt
218
+ ```
219
+
220
+ Common values for `--dataset` are controlled by:
221
+ - `configs_hydra/experiment/default_val.yaml`
222
+
223
+ ---
224
+
225
+
226
+ ## Acknowledgements
227
+
228
+ This release builds on several open-source projects, including:
229
+ - [Detectron2](https://github.com/facebookresearch/detectron2)
230
+ - [BioCLIP](https://github.com/Imageomics/BioCLIP)
231
+ - [AniMer](https://github.com/luoxue-star/AniMer)
232
+ - [DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)
233
+ - [SAM3DB](https://github.com/facebookresearch/sam-3d-body)
234
+
235
+ ---
236
+
237
+ ## Citation
238
+
239
+ If you use this code in your research, please cite our PRIMA paper.
240
+
241
+ ```bibtex
242
+ @misc{yu_prima,
243
+ title={PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation},
244
+ author={Xiaohang Yu and Ti Wang and Mackenzie Weygandt Mathis},
245
+ }
246
+ ```
247
+
248
+ ---
249
+
250
+ ## Contact
251
+
252
+ For issues, please open a GitHub issue in this repository.
app.py ADDED
@@ -0,0 +1,713 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """Gradio demo for PRIMA + SuperAnimal + TTA.
11
+
12
+ This script wraps the ``demo_tta.py`` pipeline into an interactive
13
+ Gradio interface. The overall logic follows:
14
+
15
+ 1. Given an input image, run Detectron2 to detect animals.
16
+ 2. For each detected animal, run PRIMA for 3D pose/shape estimation.
17
+ 3. Run the fine-tuned DeepLabCut SuperAnimal model to obtain PRIMA 26-keypoint
18
+ 2D predictions.
19
+ 4. Run test-time adaptation (TTA) with user-specified lr and iters.
20
+ 5. Render and save before/after TTA results and keypoint visualizations.
21
+
22
+ """
23
+
24
+ import argparse
25
+ import os
26
+ import sys
27
+ import tempfile
28
+ import traceback
29
+ from dataclasses import dataclass
30
+ from functools import lru_cache
31
+ from types import SimpleNamespace
32
+ from typing import List, Optional, Tuple
33
+ from pathlib import Path
34
+
35
+ import cv2
36
+ import gradio as gr
37
+ import numpy as np
38
+ import torch
39
+ import torch.utils.data
40
+
41
+ # Space demo on macOS: limit BLAS threads (PyRender + PyTorch on main thread only).
42
+ if sys.platform == "darwin" and os.environ.get("SPACE_ID"):
43
+ os.environ.setdefault("OMP_NUM_THREADS", "1")
44
+ torch.set_num_threads(1)
45
+
46
+ # Repo-local minimal ``chumpy`` shim (see ``chumpy/__init__.py``) so SMAL pickles load
47
+ # without installing the full chumpy package in Space builds.
48
+ _REPO_ROOT = Path(__file__).resolve().parent
49
+ if str(_REPO_ROOT) not in sys.path:
50
+ sys.path.insert(0, str(_REPO_ROOT))
51
+
52
+ from prima.utils.weights import (
53
+ DEFAULT_HF_REPO_ID,
54
+ resolve_prima_checkpoint_path,
55
+ )
56
+ from prima.utils.detection import select_animal_boxes
57
+
58
+
59
+ # Default checkpoint path following README instructions
60
+ DEFAULT_CHECKPOINT = str(_REPO_ROOT / "data" / "PRIMAS1" / "checkpoints" / "s1ckpt_inference.ckpt")
61
+ DEFAULT_HF_ASSET_REPO = DEFAULT_HF_REPO_ID
62
+
63
+ # Output folder for rendered images/meshes and keypoints
64
+ DEFAULT_OUT_FOLDER = "demo_out_tta_gradio"
65
+
66
+ _D2_R50_CFG = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
67
+ _D2_R50_URL = (
68
+ "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/"
69
+ "faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl"
70
+ )
71
+ _D2_X101_CFG = "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"
72
+ _D2_X101_URL = (
73
+ "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/"
74
+ "faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
75
+ )
76
+
77
+ # Gradio example row: (image_rel, tta_lr, tta_iters, det_thresh, kp_thresh, side_view, save_mesh)
78
+ ExampleRow = Tuple[str, float, int, float, float, bool, bool]
79
+
80
+
81
+ @dataclass(frozen=True)
82
+ class DemoProfile:
83
+ """Runtime settings for either the full local app or the lightweight HF Space demo."""
84
+
85
+ mode: str
86
+ prima_device: str # "auto" (CUDA if available) or "cpu"
87
+ detectron_config_yaml: str
88
+ detectron_weights_url: str
89
+ detectron_device: str # "auto" or "cpu"
90
+ default_tta_iters: int
91
+ max_tta_iters: int
92
+ default_save_mesh: bool
93
+ default_side_view: bool
94
+ preload_assets: bool
95
+ example_rows: Tuple[ExampleRow, ...]
96
+ description: str
97
+ interface_title: str
98
+
99
+ def resolve_prima_device(self) -> torch.device:
100
+ if self.prima_device == "cpu":
101
+ return torch.device("cpu")
102
+ return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
103
+
104
+ def resolve_detectron_device(self) -> str:
105
+ if self.detectron_device == "cpu":
106
+ return "cpu"
107
+ return "cuda" if torch.cuda.is_available() else "cpu"
108
+
109
+
110
+ LOCAL_DEMO_PROFILE = DemoProfile(
111
+ mode="local",
112
+ prima_device="auto",
113
+ detectron_config_yaml=_D2_X101_CFG,
114
+ detectron_weights_url=_D2_X101_URL,
115
+ detectron_device="auto",
116
+ default_tta_iters=30,
117
+ max_tta_iters=100,
118
+ default_save_mesh=True,
119
+ default_side_view=False,
120
+ preload_assets=False,
121
+ example_rows=(
122
+ ("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, True),
123
+ ("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, True),
124
+ ("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, True),
125
+ ("demo_data/beagle.jpg", 1e-6, 0, 0.7, 0.1, False, True),
126
+ ("demo_data/shepherd_hati.jpg", 1e-6, 0, 0.7, 0.1, False, True),
127
+ ),
128
+ description=(
129
+ "**Local demo** — full pipeline on your machine (GPU when available).\n\n"
130
+ "Detectron2 **X-101-FPN**, PRIMA mesh recovery, optional **DeepLabCut SuperAnimal + TTA**. "
131
+ "Set TTA iterations to **0** to skip adaptation. Outputs are saved under "
132
+ f"`{DEFAULT_OUT_FOLDER}`."
133
+ ),
134
+ interface_title=(
135
+ "PRIMA local demo (GPU/CPU) — detection, mesh recovery, optional TTA"
136
+ ),
137
+ )
138
+
139
+ SPACE_DEMO_PROFILE = DemoProfile(
140
+ mode="space",
141
+ prima_device="cpu",
142
+ detectron_config_yaml=_D2_R50_CFG,
143
+ detectron_weights_url=_D2_R50_URL,
144
+ detectron_device="cpu",
145
+ default_tta_iters=0,
146
+ max_tta_iters=30,
147
+ default_save_mesh=False,
148
+ default_side_view=False,
149
+ preload_assets=True,
150
+ example_rows=(
151
+ ("demo_data/beagle.jpg", 1e-6, 0, 0.7, 0.1, False, False),
152
+ ("demo_data/000000015956_horse.png", 1e-6, 0, 0.7, 0.1, False, False),
153
+ ("demo_data/000000315905_zebra.jpg", 1e-6, 0, 0.7, 0.1, False, False),
154
+ ),
155
+ description=(
156
+ "**Hugging Face Space (cpu-basic)** — lightweight demo: **CPU-only**, Detectron2 **R50-FPN**, "
157
+ "PRIMA inference. TTA is optional (0 by default; increases runtime). Mesh `.obj` export is off "
158
+ "by default to save time and disk."
159
+ ),
160
+ interface_title="PRIMA on Hugging Face — lightweight CPU demo",
161
+ )
162
+
163
+
164
+ def _is_truthy_env(var_name: str) -> bool:
165
+ return os.environ.get(var_name, "").strip().lower() in {"1", "true", "yes", "on"}
166
+
167
+
168
+ def _running_on_space() -> bool:
169
+ return bool(os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID"))
170
+
171
+
172
+ @lru_cache(maxsize=1)
173
+ def get_demo_profile() -> DemoProfile:
174
+ """Select local vs Space profile. Override with ``PRIMA_DEMO_MODE=local|space``."""
175
+ override = os.environ.get("PRIMA_DEMO_MODE", "").strip().lower()
176
+ if override == "local":
177
+ return LOCAL_DEMO_PROFILE
178
+ if override == "space":
179
+ return SPACE_DEMO_PROFILE
180
+ return SPACE_DEMO_PROFILE if _running_on_space() else LOCAL_DEMO_PROFILE
181
+
182
+
183
+ def _gradio_examples_for_interface(profile: DemoProfile) -> List[List]:
184
+ """Gradio prefetches example media at startup (paths must exist beside ``app.py``)."""
185
+ if _is_truthy_env("PRIMA_DISABLE_GRADIO_EXAMPLES"):
186
+ return []
187
+ rows: List[List] = []
188
+ for rel, *rest in profile.example_rows:
189
+ p = _REPO_ROOT / rel
190
+ if p.is_file():
191
+ rows.append([str(p), *rest])
192
+ return rows
193
+
194
+
195
+ def _should_preload_assets(profile: DemoProfile) -> bool:
196
+ preload_env = os.environ.get("PRIMA_PRELOAD_ASSETS")
197
+ if preload_env is not None:
198
+ return _is_truthy_env("PRIMA_PRELOAD_ASSETS")
199
+ return profile.preload_assets
200
+
201
+ def _deeplabcut_available() -> bool:
202
+ try:
203
+ from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images # noqa: F401
204
+
205
+ return True
206
+ except Exception:
207
+ return False
208
+
209
+
210
+ def _preload_assets_once(checkpoint_path: str) -> None:
211
+ print("[startup] Ensuring demo assets from Hugging Face Hub...")
212
+ resolve_prima_checkpoint_path(
213
+ checkpoint_path,
214
+ data_dir=_REPO_ROOT / "data",
215
+ auto_download=True,
216
+ hf_repo_id=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO),
217
+ )
218
+ print("[startup] Asset preload complete.")
219
+
220
+
221
+ def _load_prima_model(checkpoint_path: str = DEFAULT_CHECKPOINT):
222
+ """Load PRIMA model and renderer once for the Gradio app."""
223
+ from prima.models import load_prima
224
+ from prima.utils.renderer import Renderer, cam_crop_to_full
225
+
226
+ checkpoint_path = resolve_prima_checkpoint_path(
227
+ checkpoint_path,
228
+ data_dir=_REPO_ROOT / "data",
229
+ auto_download=True,
230
+ hf_repo_id=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO),
231
+ )
232
+ checkpoint = Path(checkpoint_path)
233
+ cfg_path = checkpoint.parent.parent / ".hydra" / "config.yaml"
234
+ if not checkpoint.exists():
235
+ raise FileNotFoundError(
236
+ f"Missing checkpoint: {checkpoint}. Download demo checkpoints/data as described in README."
237
+ )
238
+ if not cfg_path.exists():
239
+ raise FileNotFoundError(
240
+ f"Missing model config: {cfg_path}. Ensure the full checkpoint folder layout from README is present."
241
+ )
242
+
243
+ profile = get_demo_profile()
244
+ model, model_cfg = load_prima(checkpoint_path)
245
+ device = profile.resolve_prima_device()
246
+ model = model.to(device)
247
+ model.eval()
248
+
249
+ renderer = Renderer(model_cfg, faces=model.smal.faces)
250
+ return model, model_cfg, renderer, cam_crop_to_full, device
251
+
252
+
253
+ def _build_detector(profile: Optional[DemoProfile] = None):
254
+ """Build Detectron2 animal detector (profile selects X-101+GPU locally vs R50+CPU on Space)."""
255
+ try:
256
+ import detectron2.config
257
+ import detectron2.engine
258
+ from detectron2 import model_zoo
259
+ except Exception as e:
260
+ print(f"[warn] Detectron2 unavailable ({type(e).__name__}: {e}); using full-image fallback bbox.")
261
+ return None
262
+
263
+ if profile is None:
264
+ profile = get_demo_profile()
265
+ config_yaml = profile.detectron_config_yaml
266
+ weights = profile.detectron_weights_url
267
+ device_str = profile.resolve_detectron_device()
268
+ print(f"[detectron2] mode={profile.mode} config={config_yaml} device={device_str}")
269
+
270
+ cfg = detectron2.config.get_cfg()
271
+ cfg.merge_from_file(model_zoo.get_config_file(config_yaml))
272
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
273
+ cfg.MODEL.WEIGHTS = weights
274
+ cfg.MODEL.DEVICE = device_str
275
+ detector = detectron2.engine.DefaultPredictor(cfg)
276
+ return detector
277
+
278
+
279
+ def _load_model_and_detector_for_demo(checkpoint_path: str, profile: DemoProfile):
280
+ """Load PRIMA and Detectron2 once for the Gradio session (main thread only)."""
281
+ model, model_cfg, renderer, cam_crop_to_full_fn, device = _load_prima_model(checkpoint_path)
282
+ detector = _build_detector(profile)
283
+ return model, model_cfg, renderer, cam_crop_to_full_fn, device, detector
284
+
285
+
286
+ def _detect_animal_boxes(
287
+ detector,
288
+ img_bgr: np.ndarray,
289
+ det_thresh: float,
290
+ ) -> Optional[np.ndarray]:
291
+ """Return Nx4 XYXY boxes or None if no animal detections."""
292
+ if detector is None:
293
+ h, w = img_bgr.shape[:2]
294
+ return np.array([[0.0, 0.0, float(max(1, w - 1)), float(max(1, h - 1))]], dtype=np.float32)
295
+
296
+ det_out = detector(img_bgr)
297
+ det_instances = det_out["instances"]
298
+ boxes, suppressed = select_animal_boxes(det_instances, score_threshold=float(det_thresh))
299
+ if suppressed > 0:
300
+ print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s)")
301
+ if len(boxes) == 0:
302
+ return None
303
+ return boxes
304
+
305
+
306
+ # SuperAnimal defaults (same as in demo_tta parser)
307
+ SUPER_ANIMAL_ARGS = SimpleNamespace(
308
+ superanimal_name="superanimal_quadruped",
309
+ superanimal_model_name="hrnet_w32",
310
+ superanimal_detector_name="fasterrcnn_resnet50_fpn_v2",
311
+ superanimal_max_individuals=1,
312
+ saved_2d_model_path="",
313
+ pytorch_config_2d_path=str(_REPO_ROOT / "configs" / "sa_finetune_hrnet_w32.yaml"),
314
+ )
315
+
316
+
317
+ def _collect_animal_results(
318
+ model,
319
+ model_cfg,
320
+ renderer,
321
+ cam_crop_to_full_fn,
322
+ device,
323
+ detector,
324
+ out_folder: str,
325
+ img_rgb: np.ndarray,
326
+ tta_lr: float,
327
+ tta_num_iters: int,
328
+ det_thresh: float,
329
+ kp_conf_thresh: float,
330
+ side_view: bool,
331
+ save_mesh: bool,
332
+ boxes: Optional[np.ndarray] = None,
333
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], str | None, str | None]:
334
+ """Run detection + PRIMA + SuperAnimal + TTA on a single RGB image.
335
+
336
+ Returns:
337
+ before_imgs: list of HxWx3 RGB images (before TTA) for all animals
338
+ after_imgs: list of HxWx3 RGB images (after TTA) for all animals
339
+ kpt_imgs: list of HxWx3 RGB keypoint visualizations
340
+ first_before_mesh: path to first animal's before-TTA mesh (.obj) or None
341
+ first_after_mesh: path to first animal's after-TTA mesh (.obj) or None
342
+ """
343
+ from prima.utils import recursive_to
344
+ from prima.datasets.vitdet_dataset import ViTDetDataset
345
+ from demo_tta import (
346
+ denorm_patch_to_rgb,
347
+ resolve_sa_weights_path,
348
+ run_superanimal_on_patch,
349
+ save_keypoint_vis,
350
+ tta_optimize,
351
+ )
352
+
353
+ if int(tta_num_iters) > 0 and not SUPER_ANIMAL_ARGS.saved_2d_model_path:
354
+ SUPER_ANIMAL_ARGS.saved_2d_model_path = resolve_sa_weights_path("")
355
+
356
+ img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
357
+ if boxes is None:
358
+ boxes = _detect_animal_boxes(detector, img_bgr, det_thresh)
359
+ if boxes is None:
360
+ return [], [], [], None, None
361
+
362
+ dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
363
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
364
+
365
+ before_imgs: List[np.ndarray] = []
366
+ after_imgs: List[np.ndarray] = []
367
+ kpt_imgs: List[np.ndarray] = []
368
+ before_mesh_paths: List[str] = []
369
+ after_mesh_paths: List[str] = []
370
+
371
+ img_token = next(tempfile._get_candidate_names())
372
+
373
+ for batch in dataloader:
374
+ batch = recursive_to(batch, device)
375
+
376
+ with torch.no_grad():
377
+ out_before = model(batch)
378
+
379
+ animal_id = int(batch["animalid"][0])
380
+
381
+ # Save/render before TTA
382
+ img_fn = f"{img_token}"
383
+ from demo_tta import render_and_save # imported lazily to avoid circular issues
384
+
385
+ render_and_save(
386
+ renderer,
387
+ cam_crop_to_full_fn,
388
+ out_before,
389
+ batch,
390
+ img_fn,
391
+ animal_id,
392
+ out_folder,
393
+ suffix="before_tta",
394
+ side_view=side_view,
395
+ save_mesh=save_mesh,
396
+ )
397
+
398
+ before_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.png")
399
+ if os.path.exists(before_png_path):
400
+ before_bgr = cv2.imread(before_png_path)
401
+ if before_bgr is not None:
402
+ before_imgs.append(cv2.cvtColor(before_bgr, cv2.COLOR_BGR2RGB))
403
+
404
+ if save_mesh:
405
+ before_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.obj")
406
+ if os.path.exists(before_obj_path):
407
+ before_mesh_paths.append(before_obj_path)
408
+
409
+ if int(tta_num_iters) <= 0:
410
+ render_and_save(
411
+ renderer,
412
+ cam_crop_to_full_fn,
413
+ out_before,
414
+ batch,
415
+ img_fn,
416
+ animal_id,
417
+ out_folder,
418
+ suffix="after_tta",
419
+ side_view=side_view,
420
+ save_mesh=save_mesh,
421
+ )
422
+
423
+ after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
424
+ if os.path.exists(after_png_path):
425
+ after_bgr = cv2.imread(after_png_path)
426
+ if after_bgr is not None:
427
+ after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB))
428
+
429
+ if save_mesh:
430
+ after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
431
+ if os.path.exists(after_obj_path):
432
+ after_mesh_paths.append(after_obj_path)
433
+ continue
434
+
435
+ # Prepare patch for SuperAnimal
436
+ patch_rgb = denorm_patch_to_rgb(batch["img"][0])
437
+ with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
438
+ bodyparts_xyc = run_superanimal_on_patch(patch_rgb, SUPER_ANIMAL_ARGS, tmp_dir)
439
+
440
+ if bodyparts_xyc is None:
441
+ # No keypoints => skip TTA for this animal
442
+ continue
443
+
444
+ kpts_xyc = bodyparts_xyc
445
+ kpts_xyc[kpts_xyc[:, 2] < float(kp_conf_thresh), 2] = 0.0
446
+
447
+ # Save keypoint visualization and npy
448
+ kpt_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png")
449
+ save_keypoint_vis(patch_rgb, kpts_xyc, kpt_png_path)
450
+ npy_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy")
451
+ np.save(npy_path, kpts_xyc)
452
+
453
+ if os.path.exists(kpt_png_path):
454
+ kpt_bgr = cv2.imread(kpt_png_path)
455
+ if kpt_bgr is not None:
456
+ kpt_imgs.append(cv2.cvtColor(kpt_bgr, cv2.COLOR_BGR2RGB))
457
+
458
+ # Normalize keypoints to [-0.5, 0.5] as in demo_tta
459
+ patch_h, patch_w = patch_rgb.shape[:2]
460
+ kpts_norm = kpts_xyc.copy()
461
+ kpts_norm[:, 0] = kpts_norm[:, 0] / float(patch_w) - 0.5
462
+ kpts_norm[:, 1] = kpts_norm[:, 1] / float(patch_h) - 0.5
463
+ gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch["img"].dtype)
464
+
465
+ # Run TTA
466
+ out_after = tta_optimize(
467
+ model,
468
+ batch,
469
+ gt_kpts_norm,
470
+ num_iters=int(tta_num_iters),
471
+ lr=float(tta_lr),
472
+ )
473
+
474
+ render_and_save(
475
+ renderer,
476
+ cam_crop_to_full_fn,
477
+ out_after,
478
+ batch,
479
+ img_fn,
480
+ animal_id,
481
+ out_folder,
482
+ suffix="after_tta",
483
+ side_view=side_view,
484
+ save_mesh=save_mesh,
485
+ )
486
+
487
+ after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
488
+ if os.path.exists(after_png_path):
489
+ after_bgr = cv2.imread(after_png_path)
490
+ if after_bgr is not None:
491
+ after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB))
492
+
493
+ if save_mesh:
494
+ after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
495
+ if os.path.exists(after_obj_path):
496
+ after_mesh_paths.append(after_obj_path)
497
+
498
+ first_before_mesh = before_mesh_paths[0] if before_mesh_paths else None
499
+ first_after_mesh = after_mesh_paths[0] if after_mesh_paths else None
500
+
501
+ return before_imgs, after_imgs, kpt_imgs, first_before_mesh, first_after_mesh
502
+
503
+
504
+ def build_demo(checkpoint_path: str = DEFAULT_CHECKPOINT, out_folder: str = DEFAULT_OUT_FOLDER) -> gr.Interface:
505
+ profile = get_demo_profile()
506
+ print(
507
+ f"[demo] profile={profile.mode} prima={profile.resolve_prima_device()} "
508
+ f"detectron={profile.detectron_config_yaml} d2_device={profile.resolve_detectron_device()}"
509
+ )
510
+ os.makedirs(out_folder, exist_ok=True)
511
+ runtime_cache = {
512
+ "model": None,
513
+ "model_cfg": None,
514
+ "renderer": None,
515
+ "cam_crop_to_full_fn": None,
516
+ "device": None,
517
+ "detector": None,
518
+ }
519
+
520
+ def gradio_inference(
521
+ image: np.ndarray,
522
+ tta_lr: float,
523
+ tta_num_iters: int,
524
+ det_thresh: float,
525
+ kp_conf_thresh: float,
526
+ side_view: bool,
527
+ save_mesh: bool,
528
+ ):
529
+ """Wrapper for Gradio. ``image`` is an RGB numpy array.
530
+
531
+ Yields intermediate status so long first-run (Hub downloads + model load)
532
+ and long inference do not hit silent client/proxy WebSocket timeouts.
533
+ """
534
+
535
+ if image is None:
536
+ yield None, None, None, "No image provided."
537
+ return
538
+
539
+ if int(tta_num_iters) > 0 and not _deeplabcut_available():
540
+ yield (
541
+ None,
542
+ None,
543
+ None,
544
+ "DeepLabCut is not installed. Set **TTA iterations** to **0** for PRIMA-only inference, "
545
+ "or install `deeplabcut` (see README / requirements.txt).",
546
+ )
547
+ return
548
+
549
+ if image.dtype != np.uint8:
550
+ img_rgb = np.clip(image, 0, 255).astype(np.uint8)
551
+ else:
552
+ img_rgb = image
553
+
554
+ yield None, None, None, "Queued; preparing run…"
555
+
556
+ if runtime_cache["model"] is None:
557
+ yield (
558
+ None,
559
+ None,
560
+ None,
561
+ "First run: downloading demo assets from Hugging Face (large checkpoint) "
562
+ "and loading the model. This can take many minutes.",
563
+ )
564
+ try:
565
+ model, model_cfg, renderer, cam_crop_to_full_fn, device, detector = _load_model_and_detector_for_demo(
566
+ checkpoint_path, profile
567
+ )
568
+ except Exception:
569
+ yield None, None, None, f"Model initialization failed:\n{traceback.format_exc()}"
570
+ return
571
+ runtime_cache["model"] = model
572
+ runtime_cache["model_cfg"] = model_cfg
573
+ runtime_cache["renderer"] = renderer
574
+ runtime_cache["cam_crop_to_full_fn"] = cam_crop_to_full_fn
575
+ runtime_cache["device"] = device
576
+ runtime_cache["detector"] = detector
577
+ yield None, None, None, "Model loaded."
578
+
579
+ try:
580
+ yield None, None, None, "Running animal detection…"
581
+ img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
582
+ boxes = _detect_animal_boxes(runtime_cache["detector"], img_bgr, det_thresh)
583
+ if boxes is None:
584
+ yield (
585
+ None,
586
+ None,
587
+ None,
588
+ "No animal detected. Try lowering the detection threshold or another image.",
589
+ )
590
+ return
591
+ yield (
592
+ None,
593
+ None,
594
+ None,
595
+ f"Detected {len(boxes)} animal region(s). Running PRIMA (+ SuperAnimal/TTA if enabled)…",
596
+ )
597
+ before_imgs, after_imgs, kpt_imgs, mesh_before, mesh_after = _collect_animal_results(
598
+ runtime_cache["model"],
599
+ runtime_cache["model_cfg"],
600
+ runtime_cache["renderer"],
601
+ runtime_cache["cam_crop_to_full_fn"],
602
+ runtime_cache["device"],
603
+ runtime_cache["detector"],
604
+ out_folder,
605
+ img_rgb,
606
+ tta_lr=tta_lr,
607
+ tta_num_iters=tta_num_iters,
608
+ det_thresh=det_thresh,
609
+ kp_conf_thresh=kp_conf_thresh,
610
+ side_view=side_view,
611
+ save_mesh=save_mesh,
612
+ boxes=boxes,
613
+ )
614
+ except Exception:
615
+ yield None, None, None, f"Inference failed:\n{traceback.format_exc()}"
616
+ return
617
+
618
+ first_before = before_imgs[0] if before_imgs else None
619
+ first_after = after_imgs[0] if after_imgs else None
620
+ first_kpts = kpt_imgs[0] if kpt_imgs else None
621
+ if first_before is None and first_after is None:
622
+ yield (
623
+ None,
624
+ None,
625
+ None,
626
+ "No output generated. Try an image with a clearly visible quadruped.",
627
+ )
628
+ return
629
+ yield first_before, first_after, first_kpts, "OK"
630
+
631
+ _gradio_examples = _gradio_examples_for_interface(profile)
632
+ _iface_kw = dict(
633
+ fn=gradio_inference,
634
+ analytics_enabled=False,
635
+ cache_examples=False,
636
+ inputs=[
637
+ gr.Image(
638
+ label="Input image",
639
+ type="numpy",
640
+ sources=["upload", "clipboard"],
641
+ ),
642
+ gr.Slider(
643
+ label="TTA learning rate",
644
+ minimum=1e-7,
645
+ maximum=1e-4,
646
+ value=1e-6,
647
+ step=1e-7,
648
+ ),
649
+ gr.Slider(
650
+ label="TTA iterations",
651
+ minimum=0,
652
+ maximum=profile.max_tta_iters,
653
+ value=profile.default_tta_iters,
654
+ step=1,
655
+ info="Set to 0 to disable TTA and reuse the initial PRIMA prediction.",
656
+ ),
657
+ gr.Slider(
658
+ label="Detection threshold",
659
+ minimum=0.3,
660
+ maximum=0.9,
661
+ value=0.7,
662
+ step=0.05,
663
+ ),
664
+ gr.Slider(
665
+ label="Keypoint confidence threshold",
666
+ minimum=0.0,
667
+ maximum=1.0,
668
+ value=0.1,
669
+ step=0.05,
670
+ ),
671
+ gr.Checkbox(label="Render side view", value=profile.default_side_view),
672
+ gr.Checkbox(label="Save meshes (.obj)", value=profile.default_save_mesh),
673
+ ],
674
+ outputs=[
675
+ gr.Image(label="Before TTA"),
676
+ gr.Image(label="After TTA"),
677
+ gr.Image(label="PRIMA 26 keypoints"),
678
+ gr.Textbox(label="Status / Traceback", lines=12),
679
+ ],
680
+ title=profile.interface_title,
681
+ description=profile.description,
682
+ )
683
+ if _gradio_examples:
684
+ _iface_kw["examples"] = _gradio_examples
685
+ demo = gr.Interface(**_iface_kw)
686
+ demo.queue(max_size=8, default_concurrency_limit=1)
687
+ return demo
688
+
689
+
690
+ def parse_args() -> argparse.Namespace:
691
+ parser = argparse.ArgumentParser(description="Gradio demo for PRIMA + SuperAnimal + TTA")
692
+ parser.add_argument(
693
+ "--checkpoint",
694
+ type=str,
695
+ default=DEFAULT_CHECKPOINT,
696
+ help="Path to the pretrained PRIMA checkpoint",
697
+ )
698
+ parser.add_argument(
699
+ "--out_folder",
700
+ type=str,
701
+ default=DEFAULT_OUT_FOLDER,
702
+ help="Folder used to save rendered outputs and meshes",
703
+ )
704
+ return parser.parse_args()
705
+
706
+
707
+ if __name__ == "__main__":
708
+ args = parse_args()
709
+ profile = get_demo_profile()
710
+ if _should_preload_assets(profile):
711
+ _preload_assets_once(args.checkpoint)
712
+ demo = build_demo(checkpoint_path=args.checkpoint, out_folder=args.out_folder)
713
+ demo.launch(inbrowser=False)
chumpy/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ """
3
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
4
+
5
+ Official implementation of the paper:
6
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
7
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
8
+ Licensed under a modified MIT license
9
+ """
10
+
11
+
12
+ """Minimal ``chumpy`` compatibility for unpickling legacy SMAL model configs."""
13
+
14
+ from .ch import Ch, ChArray, materialize
15
+
16
+ __all__ = ["Ch", "ChArray", "materialize"]
chumpy/ch.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ """
3
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
4
+
5
+ Official implementation of the paper:
6
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
7
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
8
+ Licensed under a modified MIT license
9
+ """
10
+
11
+
12
+ """``chumpy.ch`` namespace expected by legacy SMAL pickles."""
13
+
14
+ import numpy as np
15
+
16
+
17
+ class Ch:
18
+ """Minimal stand-in for ``chumpy.ch.Ch`` (unpickling only)."""
19
+
20
+ def __init__(self, *args, **kwargs):
21
+ self._data = None
22
+ if args:
23
+ self._data = np.asarray(args[0])
24
+
25
+ def _resolve(self) -> np.ndarray:
26
+ # Real chumpy Ch instances store the underlying ndarray on attribute ``x``;
27
+ # legacy pickles unpickle by restoring ``__dict__`` without calling ``__init__``,
28
+ # so try common attribute names before falling back to ``_data``.
29
+ for attr in ("x", "_x", "_data"):
30
+ val = self.__dict__.get(attr)
31
+ if val is not None:
32
+ return np.asarray(val)
33
+ if self._data is not None:
34
+ return np.asarray(self._data)
35
+ return np.zeros((), dtype=np.float32)
36
+
37
+ @property
38
+ def r(self) -> np.ndarray:
39
+ return self._resolve()
40
+
41
+ def __array__(self, dtype=None):
42
+ arr = self.r()
43
+ if dtype is not None:
44
+ arr = arr.astype(dtype, copy=False)
45
+ return arr
46
+
47
+
48
+ class ChArray(np.ndarray):
49
+ """Minimal stand-in for ``chumpy.ch.ChArray``."""
50
+
51
+
52
+ def materialize(value, dtype=np.float32) -> np.ndarray:
53
+ """Recursively unwrap ``Ch`` / object arrays from legacy SMAL pickles."""
54
+ if isinstance(value, Ch):
55
+ return np.asarray(value.r(), dtype=dtype)
56
+ if isinstance(value, np.ndarray):
57
+ if value.dtype == object:
58
+ flat = [materialize(x, dtype=dtype) for x in value.ravel()]
59
+ return np.stack(flat).reshape(value.shape)
60
+ return np.asarray(value, dtype=dtype)
61
+ if isinstance(value, (list, tuple)):
62
+ return np.asarray([materialize(x, dtype=dtype) for x in value], dtype=dtype)
63
+ return np.asarray(value, dtype=dtype)
64
+
65
+
66
+ __all__ = ["Ch", "ChArray", "materialize"]
configs/sa_finetune_hrnet_w32.yaml ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DeepLabCut pytorch_config for the PRIMA TTA 2D pose model:
2
+ # SuperAnimal-Quadruped HRNet-w32 backbone fine-tuned on Animal3D, with
3
+ # the heatmap head re-trained for the 26-joint Animal3D / PRIMA layout.
4
+ #
5
+ # Used by demo_tta.py via DLC's `superanimal_analyze_images(...,
6
+ # customized_model_config=<this yaml>, customized_pose_checkpoint=<your
7
+ # fine-tuned .pt>)`. Only the pose model is fine-tuned; the bounding-box
8
+ # detector (Faster R-CNN) is the stock SuperAnimal-Quadruped one
9
+ # resolved by DLC at runtime.
10
+ data:
11
+ bbox_margin: 20
12
+ colormode: RGB
13
+ inference:
14
+ normalize_images: true
15
+ top_down_crop:
16
+ width: 256
17
+ height: 256
18
+ auto_padding:
19
+ pad_width_divisor: 32
20
+ pad_height_divisor: 32
21
+ train:
22
+ affine:
23
+ p: 0.5
24
+ rotation: 30
25
+ scaling:
26
+ - 1.0
27
+ - 1.0
28
+ translation: 0
29
+ gaussian_noise: 12.75
30
+ motion_blur: true
31
+ normalize_images: true
32
+ top_down_crop:
33
+ width: 256
34
+ height: 256
35
+ auto_padding:
36
+ pad_width_divisor: 32
37
+ pad_height_divisor: 32
38
+ detector:
39
+ data:
40
+ colormode: RGB
41
+ inference:
42
+ normalize_images: true
43
+ train:
44
+ affine:
45
+ p: 0.5
46
+ rotation: 30
47
+ scaling:
48
+ - 1.0
49
+ - 1.0
50
+ translation: 40
51
+ collate:
52
+ type: ResizeFromDataSizeCollate
53
+ min_scale: 0.4
54
+ max_scale: 1.0
55
+ min_short_side: 128
56
+ max_short_side: 1152
57
+ multiple_of: 32
58
+ to_square: false
59
+ hflip: true
60
+ normalize_images: true
61
+ device: auto
62
+ model:
63
+ type: FasterRCNN
64
+ freeze_bn_stats: true
65
+ freeze_bn_weights: false
66
+ variant: fasterrcnn_resnet50_fpn_v2
67
+ runner:
68
+ type: DetectorTrainingRunner
69
+ key_metric: test.mAP@50:95
70
+ key_metric_asc: true
71
+ eval_interval: 10
72
+ optimizer:
73
+ type: AdamW
74
+ params:
75
+ lr: 0.0001
76
+ scheduler:
77
+ type: LRListScheduler
78
+ params:
79
+ milestones:
80
+ - 160
81
+ lr_list:
82
+ - - 1e-05
83
+ snapshots:
84
+ max_snapshots: 5
85
+ save_epochs: 25
86
+ save_optimizer_state: false
87
+ train_settings:
88
+ batch_size: 1
89
+ dataloader_workers: 0
90
+ dataloader_pin_memory: false
91
+ display_iters: 500
92
+ epochs: 250
93
+ device: auto
94
+ inference:
95
+ multithreading:
96
+ enabled: true
97
+ queue_length: 4
98
+ timeout: 30.0
99
+ compile:
100
+ enabled: false
101
+ backend: inductor
102
+ autocast:
103
+ enabled: false
104
+ metadata:
105
+ project_path: ""
106
+ pose_config_path: ""
107
+ bodyparts:
108
+ - left_eye
109
+ - right_eye
110
+ - chin
111
+ - left_front_paw
112
+ - right_front_paw
113
+ - left_back_paw
114
+ - right_back_paw
115
+ - tail_base
116
+ - left_front_thigh
117
+ - right_front_thigh
118
+ - left_back_thigh
119
+ - right_back_thigh
120
+ - left_shoulder
121
+ - right_shoulder
122
+ - left_front_knee
123
+ - right_front_knee
124
+ - left_back_knee
125
+ - right_back_knee
126
+ - neck_base
127
+ - tail_mid
128
+ - left_ear_base
129
+ - right_ear_base
130
+ - left_mouth_corner
131
+ - right_mouth_corner
132
+ - nose
133
+ - tail_tip_first
134
+ unique_bodyparts: []
135
+ individuals:
136
+ - individual000
137
+ with_identity: false
138
+ method: td
139
+ model:
140
+ backbone:
141
+ type: HRNet
142
+ model_name: hrnet_w32
143
+ freeze_bn_stats: true
144
+ freeze_bn_weights: false
145
+ interpolate_branches: false
146
+ increased_channel_count: false
147
+ backbone_output_channels: 32
148
+ heads:
149
+ bodypart:
150
+ type: HeatmapHead
151
+ weight_init: normal
152
+ predictor:
153
+ type: HeatmapPredictor
154
+ apply_sigmoid: false
155
+ clip_scores: true
156
+ location_refinement: true
157
+ locref_std: 7.2801
158
+ target_generator:
159
+ type: HeatmapGaussianGenerator
160
+ num_heatmaps: 26
161
+ pos_dist_thresh: 17
162
+ heatmap_mode: KEYPOINT
163
+ gradient_masking: true
164
+ background_weight: 0.0
165
+ generate_locref: true
166
+ locref_std: 7.2801
167
+ criterion:
168
+ heatmap:
169
+ type: WeightedMSECriterion
170
+ weight: 1.0
171
+ locref:
172
+ type: WeightedHuberCriterion
173
+ weight: 0.05
174
+ heatmap_config:
175
+ channels:
176
+ - 32
177
+ kernel_size: []
178
+ strides: []
179
+ final_conv:
180
+ out_channels: 26
181
+ kernel_size: 1
182
+ locref_config:
183
+ channels:
184
+ - 32
185
+ kernel_size: []
186
+ strides: []
187
+ final_conv:
188
+ out_channels: 52
189
+ kernel_size: 1
190
+ net_type: hrnet_w32
191
+ runner:
192
+ type: PoseTrainingRunner
193
+ gpus:
194
+ key_metric: test.mAP
195
+ key_metric_asc: true
196
+ eval_interval: 10
197
+ optimizer:
198
+ type: AdamW
199
+ params:
200
+ lr: 0.0001
201
+ scheduler:
202
+ type: LRListScheduler
203
+ params:
204
+ lr_list:
205
+ - - 1e-05
206
+ - - 1e-06
207
+ milestones:
208
+ - 160
209
+ - 190
210
+ snapshots:
211
+ max_snapshots: 5
212
+ save_epochs: 10
213
+ save_optimizer_state: false
214
+ train_settings:
215
+ batch_size: 64
216
+ dataloader_workers: 8
217
+ dataloader_pin_memory: false
218
+ display_iters: 500
219
+ epochs: 200
220
+ seed: 42
configs_hydra/experiment/default.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ SMAL:
4
+ DATA_DIR: data/smal
5
+ MODEL_PATH: data/smal/my_smpl_00781_4_all.pkl
6
+ SHAPE_PRIOR_PATH: data/smal/my_smpl_data_00781_4_all.pkl
7
+ POSE_PRIOR_PATH: data/smal/walking_toy_symmetric_pose_prior_with_cov_35parts.pkl
8
+ NUM_JOINTS: 34
9
+
10
+ EXTRA:
11
+ FOCAL_LENGTH: 1000
12
+ NUM_LOG_IMAGES: 4
13
+ NUM_LOG_SAMPLES_PER_IMAGE: 4
14
+ PELVIS_IND: 0
15
+
16
+ DATASETS:
17
+ CONFIG:
18
+ SCALE_FACTOR: 0.3
19
+ ROT_FACTOR: 30
20
+ TRANS_FACTOR: 0.02
21
+ COLOR_SCALE: 0.2
22
+ ROT_AUG_RATE: 0.6
23
+ TRANS_AUG_RATE: 0.5
24
+ DO_FLIP: False
25
+ FLIP_AUG_RATE: 0.0
26
+ EXTREME_CROP_AUG_RATE: 0.0
27
+ EXTREME_CROP_AUG_LEVEL: 1
28
+
configs_hydra/experiment/default_val.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ DATASETS:
4
+ ANIMAL3D:
5
+ ROOT_IMAGE: ./datasets/animal3d/
6
+ JSON_FILE:
7
+ TEST: ./datasets/animal3d/test.json
8
+ CONTROL_ANIMAL3D:
9
+ ROOT_IMAGE: ./datasets/control_animal3dlatest/
10
+ JSON_FILE:
11
+ TEST: ./datasets/control_animal3dlatest/test.json
12
+ QUADRUPED2D:
13
+ ROOT_IMAGE: ./datasets/quadruped2d/
14
+ JSON_FILE:
15
+ TEST: ./datasets/quadruped2d/test.json
16
+ ANIMAL_KINGDOM:
17
+ ROOT_IMAGE: ./datasets/Animal_Kingdom_test/
18
+ JSON_FILE:
19
+ TEST: ./datasets/Animal_Kingdom_test/test.json
20
+ CONFIG:
21
+ SCALE_FACTOR: 0.0
22
+ ROT_FACTOR: 0
23
+ TRANS_FACTOR: 0.0
24
+ COLOR_SCALE: 0.0
25
+ ROT_AUG_RATE: 0.0
26
+ TRANS_AUG_RATE: 0.0
27
+ DO_FLIP: False
28
+ FLIP_AUG_RATE: 0.0
29
+ EXTREME_CROP_AUG_RATE: 0.0
30
+ EXTREME_CROP_AUG_LEVEL: 1
31
+
32
+ METRIC:
33
+ PCK_THRESHOLD: [0.10, 0.15]
34
+
configs_hydra/experiment/primaStage1.yaml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - default.yaml
5
+
6
+ GENERAL:
7
+
8
+
9
+ TOTAL_STEPS: 63_000
10
+ LOG_STEPS: 63
11
+ VAL_STEPS: 63
12
+ VAL_EPOCHS: 1
13
+ CHECKPOINT_EPOCHS: 1
14
+ CHECKPOINT_SAVE_TOP_K: 2
15
+ NUM_WORKERS: 8
16
+ PREFETCH_FACTOR: 2
17
+
18
+ LOSS_WEIGHTS:
19
+ KEYPOINTS_3D: 0.05
20
+ KEYPOINTS_2D: 0.01
21
+ INTERMEDIATE_KP2D: 0.001
22
+ INTERMEDIATE_KP3D: 0.001
23
+ GLOBAL_ORIENT: 0.005
24
+ POSE: 0.001
25
+ BETAS: 0.0005
26
+ TRANSL: 0.0005
27
+ ADVERSARIAL: 0.0005
28
+ SUPCON: 0.0005
29
+
30
+
31
+ TRAIN:
32
+ LR: 3.75e-6
33
+ WEIGHT_DECAY: 1e-4
34
+ BATCH_SIZE: 48
35
+ LOSS_REDUCTION: mean
36
+ NUM_TRAIN_SAMPLES: 2
37
+ NUM_TEST_SAMPLES: 64
38
+ POSE_2D_NOISE_RATIO: 0.01
39
+ SMPL_PARAM_NOISE_RATIO: 0.005
40
+
41
+ MODEL:
42
+ IMAGE_SIZE: 256
43
+ IMAGE_MEAN: [0.485, 0.456, 0.406]
44
+ IMAGE_STD: [0.229, 0.224, 0.225]
45
+ BACKBONE:
46
+ TYPE: vith
47
+ PRETRAINED_WEIGHTS: ./data/amr_vitbb.pth
48
+ FREEZE: False
49
+
50
+ # Enable BioClip embedding
51
+ USE_BIOCLIP_EMBEDDING: True
52
+ BIOCLIP_EMBEDDING:
53
+ EMBED_DIM: 1280 # Match DINOv2 output dimension for token-wise concatenation
54
+ TYPE: bioclip1
55
+
56
+ # Enable 2D keypoint embedding for initialization; NewBioGuidedSMALPoseDecoder updates it dynamically
57
+ USE_KEYPOINT_EMBEDDING: False
58
+
59
+ SMAL_HEAD:
60
+ TYPE: new_bio_pose_transformer_decoder # Use the newer version with SAM3D-style hierarchical updates
61
+ IN_CHANNELS: 1280
62
+ IEF_ITERS: 3
63
+
64
+ # Pose Transformer Decoder configuration
65
+ DECODER_DIM: 1280
66
+ NUM_DECODER_LAYERS: 6
67
+ NUM_HEADS: 8
68
+ MLP_RATIO: 4.0
69
+
70
+ # Keypoint token configuration specific to NewBioGuidedSMALPoseDecoder
71
+ USE_KEYPOINT_2D_TOKENS: True # Enable 2D keypoint tokens with SAM3D-style dynamic updates
72
+ USE_KEYPOINT_3D_TOKENS: True # Enable 3D keypoint tokens with pelvis normalization
73
+ KEYPOINT_TOKEN_UPDATE: True # Enable hierarchical keypoint prediction and token updates
74
+ KP2D_INJECT_IMAGE_FEAT: True # Key setting: inject image features via grid_sample
75
+
76
+
77
+ DATASETS:
78
+ ANIMAL3D:
79
+ ROOT_IMAGE: ./datasets/animal3d/
80
+ JSON_FILE:
81
+ TRAIN: ./datasets/animal3d/train.json
82
+ TEST: ./datasets/animal3d/test.json
83
+ WEIGHT: 1.0
configs_hydra/experiment/primaStage2.yaml ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - default.yaml
5
+
6
+ GENERAL:
7
+
8
+
9
+ TOTAL_STEPS: 450_000
10
+ LOG_STEPS: 533
11
+ VAL_STEPS: 533
12
+ VAL_EPOCHS: 1
13
+ CHECKPOINT_EPOCHS: 1
14
+ CHECKPOINT_SAVE_TOP_K: 2
15
+ NUM_WORKERS: 2
16
+ PREFETCH_FACTOR: 2
17
+
18
+ LOSS_WEIGHTS:
19
+ KEYPOINTS_3D: 0.05
20
+ KEYPOINTS_2D: 0.01
21
+ INTERMEDIATE_KP2D: 0.001
22
+ INTERMEDIATE_KP3D: 0.001
23
+ GLOBAL_ORIENT: 0.005
24
+ POSE: 0.001
25
+ BETAS: 0.0005
26
+ TRANSL: 0.0005
27
+ ADVERSARIAL: 0.0
28
+ SUPCON: 0.0005
29
+
30
+
31
+ TRAIN:
32
+ LR: 3.75e-6
33
+ WEIGHT_DECAY: 1e-4
34
+ BATCH_SIZE: 48
35
+ LOSS_REDUCTION: mean
36
+ NUM_TRAIN_SAMPLES: 2
37
+ NUM_TEST_SAMPLES: 64
38
+ POSE_2D_NOISE_RATIO: 0.01
39
+ SMPL_PARAM_NOISE_RATIO: 0.005
40
+
41
+ MODEL:
42
+ IMAGE_SIZE: 256
43
+ IMAGE_MEAN: [0.485, 0.456, 0.406]
44
+ IMAGE_STD: [0.229, 0.224, 0.225]
45
+ BACKBONE:
46
+ TYPE: vith
47
+ PRETRAINED_WEIGHTS: ./data/amr_vitbb.pth
48
+ FREEZE: False
49
+
50
+ # Enable BioClip embedding
51
+ USE_BIOCLIP_EMBEDDING: True
52
+ BIOCLIP_EMBEDDING:
53
+ EMBED_DIM: 1280 # Match vit output dimension for token-wise concatenation
54
+ TYPE: bioclip1
55
+
56
+ # Enable 2D keypoint embedding
57
+ USE_KEYPOINT_EMBEDDING: False
58
+ KEYPOINT_EMBEDDING:
59
+ NUM_KEYPOINTS: 26 # Number of SMAL keypoints
60
+ KEYPOINT_DIM: 2 # 2D coordinates (x, y)
61
+ EMBED_DIM: 1280 # Match vit output dimension
62
+ HIDDEN_DIM: 512 # Hidden layer dimension in MLP
63
+ TYPE: 'token' # Use token-based embedding (recommended)
64
+
65
+ SMAL_HEAD:
66
+ TYPE: new_bio_pose_transformer_decoder # Use the newer version with SAM3D-style hierarchical updates
67
+ IN_CHANNELS: 1280
68
+ IEF_ITERS: 1
69
+
70
+ # Pose Transformer Decoder configuration
71
+ DECODER_DIM: 1280
72
+ NUM_DECODER_LAYERS: 6
73
+ NUM_HEADS: 8
74
+ MLP_RATIO: 4.0
75
+
76
+ # Keypoint token configuration specific to NewBioGuidedSMALPoseDecoder
77
+ USE_KEYPOINT_2D_TOKENS: True # Enable 2D keypoint tokens with SAM3D-style dynamic updates
78
+ USE_KEYPOINT_3D_TOKENS: True # Enable 3D keypoint tokens with pelvis normalization
79
+ KEYPOINT_TOKEN_UPDATE: True # Enable hierarchical keypoint prediction and token updates
80
+ KP2D_INJECT_IMAGE_FEAT: True # Key setting: inject image features via grid_sample
81
+
82
+ # Legacy transformer config (kept for compatibility)
83
+ TRANSFORMER_DECODER:
84
+ depth: 6
85
+ heads: 8
86
+ mlp_dim: 1024
87
+ dim_head: 64
88
+ dropout: 0.0
89
+ emb_dropout: 0.0
90
+ norm: layer
91
+ context_dim: 1280
92
+
93
+
94
+
95
+ DATASETS:
96
+ ANIMAL3D:
97
+ ROOT_IMAGE: ./datasets/animal3d/
98
+ JSON_FILE:
99
+ TRAIN: ./datasets/animal3d/train.json
100
+ TEST: ./datasets/animal3d/test.json
101
+ WEIGHT: 1.0
102
+ CONTROL_ANIMAL3D:
103
+ ROOT_IMAGE: ./datasets/control_animal3dlatest/
104
+ JSON_FILE:
105
+ TRAIN: ./datasets/control_animal3dlatest/train.json
106
+ TEST: ./datasets/control_animal3dlatest/test.json
107
+ WEIGHT: 0.5
108
+ QUADRUPED2D:
109
+ ROOT_IMAGE: ./datasets/quadruped2d/
110
+ JSON_FILE:
111
+ TRAIN: ./datasets/quadruped2d/train.json
112
+ TEST: ./datasets/quadruped2d/test.json
113
+ WEIGHT: 0.15
configs_hydra/extras/default.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # disable python warnings if they annoy you
2
+ ignore_warnings: False
3
+
4
+ # ask user for tags if none are provided in the config
5
+ enforce_tags: True
6
+
7
+ # pretty print config tree at the start of the run using Rich library
8
+ print_config: True
configs_hydra/hydra/default.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ # https://hydra.cc/docs/configure_hydra/intro/
3
+
4
+ # enable color logging
5
+ defaults:
6
+ - override /hydra/hydra_logging: colorlog
7
+ - override /hydra/job_logging: colorlog
8
+
9
+ # exp_name: ovrd_${hydra:job.override_dirname}
10
+ exp_name: ${now:%Y-%m-%d}_${now:%H-%M-%S}
11
+
12
+ hydra:
13
+ run:
14
+ dir: ${paths.log_dir}/${task_name}/runs/${exp_name}
15
+ sweep:
16
+ dir: ${paths.log_dir}/${task_name}/multiruns/${exp_name}
17
+ subdir: ${hydra.job.num}
18
+ job:
19
+ config:
20
+ override_dirname:
21
+ exclude_keys:
22
+ - trainer
23
+ - trainer.devices
24
+ - trainer.num_nodes
25
+ - callbacks
26
+ - debug
configs_hydra/launcher/local.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /hydra/launcher: submitit_local
5
+
6
+ hydra:
7
+ launcher:
8
+ timeout_min: 10_080 # 7 days
9
+ nodes: 1
10
+ tasks_per_node: ${trainer.devices}
11
+ cpus_per_task: 8
12
+ gpus_per_node: ${trainer.devices}
13
+ name: amr
configs_hydra/launcher/slurm.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ defaults:
4
+ - override /hydra/launcher: submitit_slurm
5
+
6
+ hydra:
7
+ launcher:
8
+ timeout_min: 10_080 # 7 days
9
+ max_num_timeout: 3
10
+ partition: g40
11
+ qos: idle
12
+ nodes: 1
13
+ tasks_per_node: ${trainer.devices}
14
+ gpus_per_task: null
15
+ cpus_per_task: 12
16
+ gpus_per_node: ${trainer.devices}
17
+ cpus_per_gpu: null
18
+ comment: prima
19
+ name: prima
20
+ setup:
21
+ - module load cuda openmpi libfabric-aws
22
+ - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
configs_hydra/paths/default.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # path to root directory
2
+ # this requires PROJECT_ROOT environment variable to exist
3
+ # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py`
4
+ root_dir: ${oc.env:PROJECT_ROOT}
5
+
6
+ # path to data directory
7
+ data_dir: ${paths.root_dir}/data/
8
+
9
+ # path to logging directory
10
+ log_dir: logs/
11
+
12
+ # path to output directory, created dynamically by hydra
13
+ # path generation pattern is specified in `configs/hydra/default.yaml`
14
+ # use it to store all files generated during the run, like ckpts and metrics
15
+ output_dir: ${hydra:runtime.output_dir}
16
+
17
+ # path to working directory
18
+ work_dir: ${hydra:runtime.cwd}
configs_hydra/train.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # specify here default configuration
4
+ # order of defaults determines the order in which configs override each other
5
+ defaults:
6
+ - _self_
7
+ - trainer: ddp.yaml
8
+ - paths: default.yaml
9
+ - extras: default.yaml
10
+ - hydra: default.yaml
11
+
12
+ # experiment configs allow for version control of specific hyperparameters
13
+ # e.g. best hyperparameters for given model and datamodule
14
+ - experiment: null
15
+ - texture_exp: null
16
+
17
+ # optional local config for machine/user specific settings
18
+ # it's optional since it doesn't need to exist and is excluded from version control
19
+ - optional launcher: local.yaml
20
+ # - optional launcher: slurm.yaml
21
+
22
+ # debugging config (enable through command line, e.g. `python train.py debug=default)
23
+ - debug: null
24
+
25
+ # task name, determines output directory path
26
+ task_name: "train"
27
+
28
+ # tags to help you identify your experiments
29
+ # you can overwrite this in experiment configs
30
+ # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
31
+ # appending lists from command line is currently not supported :(
32
+ # https://github.com/facebookresearch/hydra/issues/1547
33
+ tags: ["dev"]
34
+
35
+ # set False to skip model training
36
+ train: True
37
+
38
+ # evaluate on test set, using best model weights achieved during training
39
+ # lightning chooses best weights based on the metric specified in checkpoint callback
40
+ test: False
41
+
42
+ # simply provide checkpoint path to resume training
43
+ ckpt_path: True
44
+
45
+ # seed for random number generators in pytorch, numpy and python.random
46
+ seed: null
configs_hydra/trainer/cpu.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default.yaml
3
+ - default_amr.yaml
4
+
5
+ accelerator: cpu
6
+ devices: 1
configs_hydra/trainer/ddp.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default.yaml
3
+ - default_amr.yaml
4
+
5
+ # use "ddp_spawn" instead of "ddp",
6
+ # it's slower but normal "ddp" currently doesn't work ideally with hydra
7
+ # https://github.com/facebookresearch/hydra/issues/2070
8
+ # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn
9
+ strategy: ddp_spawn
10
+
11
+ accelerator: gpu
12
+ devices: 2
13
+ num_nodes: 1
14
+ sync_batchnorm: True
configs_hydra/trainer/default.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: pytorch_lightning.Trainer
2
+
3
+ default_root_dir: ${paths.output_dir}
4
+
5
+ accelerator: gpu
6
+ devices: 1
7
+
8
+ # set True to to ensure deterministic results
9
+ # makes training slower but gives more reproducibility than just setting seeds
10
+ deterministic: False
configs_hydra/trainer/default_amr.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ num_sanity_val_steps: 0
2
+ log_every_n_steps: ${GENERAL.LOG_STEPS}
3
+ val_check_interval: ${GENERAL.VAL_STEPS} # How often within one training epoch to check the validation set.
4
+ check_val_every_n_epoch: ${GENERAL.VAL_EPOCHS} # Check val every n train epochs.
5
+ precision: 16-mixed # 16-mixed, 32
6
+ max_steps: ${GENERAL.TOTAL_STEPS}
7
+ # move_metrics_to_cpu: True
8
+ limit_val_batches: 80 # How much of validation dataset to check.
9
+ # track_grad_norm: -1
configs_hydra/trainer/gpu.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default.yaml
3
+ - default_amr.yaml
4
+
5
+ accelerator: gpu
6
+ devices: 1
configs_hydra/trainer/mps.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default.yaml
3
+ - default_amr.yaml
4
+
5
+ accelerator: mps
6
+ devices: 1
demo.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from pathlib import Path
11
+ import detectron2.config
12
+ import detectron2.engine
13
+ import torch
14
+ import argparse
15
+ import os
16
+ import cv2
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+ import torch.utils
20
+ import torch.utils.data
21
+ from prima.models import load_prima
22
+ from prima.utils import recursive_to
23
+ from prima.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
24
+ from prima.utils.detection import select_animal_boxes
25
+ from prima.utils.weights import DEFAULT_HF_REPO_ID, resolve_prima_checkpoint_path
26
+ import detectron2
27
+ from detectron2 import model_zoo
28
+ import warnings
29
+ warnings.filterwarnings("ignore")
30
+
31
+ LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353)
32
+ GREEN = (0.65, 0.86, 0.74)
33
+ REPO_ROOT = Path(__file__).resolve().parent
34
+
35
+
36
+ def load_renderer_components():
37
+ try:
38
+ from prima.utils.renderer import Renderer, cam_crop_to_full
39
+ except Exception as exc:
40
+ raise RuntimeError(
41
+ "Cannot initialize the PRIMA renderer. Rendering requires a working "
42
+ "pyrender/OpenGL backend such as EGL or OSMesa. Install the missing "
43
+ "OpenGL runtime for this environment, or run in an environment where "
44
+ "PYOPENGL_PLATFORM=egl/osmesa works."
45
+ ) from exc
46
+ return Renderer, cam_crop_to_full
47
+
48
+
49
+ def main():
50
+ parser = argparse.ArgumentParser(description='prima demo code')
51
+ parser.add_argument('--checkpoint', type=str, default='',
52
+ help='Path to pretrained model checkpoint. Empty -> auto-download the default Stage 1 checkpoint.')
53
+ parser.add_argument('--hf-repo-id', '--hf_repo_id', dest='hf_repo_id',
54
+ type=str, default=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_REPO_ID),
55
+ help='Hugging Face repo ID containing PRIMA demo assets')
56
+ parser.add_argument('--no-auto-download', '--no_auto_download', dest='no_auto_download', action='store_true',
57
+ help='Disable automatic download of missing PRIMA demo assets')
58
+ parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images')
59
+ parser.add_argument('--out_folder', type=str, default='demo_out', help='Output folder to save rendered results')
60
+ parser.add_argument('--side_view', dest='side_view', action='store_true', default=False,
61
+ help='If set, render side view also')
62
+ parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False,
63
+ help='If set, save meshes to disk also')
64
+ parser.add_argument('--batch_size', type=int, default=1, help='Batch size for inference/fitting')
65
+ parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'],
66
+ help='List of file extensions to consider')
67
+
68
+ args = parser.parse_args()
69
+
70
+ checkpoint_path = resolve_prima_checkpoint_path(
71
+ args.checkpoint,
72
+ data_dir=REPO_ROOT / "data",
73
+ auto_download=not args.no_auto_download,
74
+ hf_repo_id=args.hf_repo_id,
75
+ )
76
+
77
+ model, model_cfg = load_prima(checkpoint_path)
78
+
79
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
80
+ model = model.to(device)
81
+ model.eval()
82
+
83
+ # Setup the renderer
84
+ Renderer, cam_crop_to_full = load_renderer_components()
85
+ renderer = Renderer(model_cfg, faces=model.smal.faces)
86
+
87
+ # Make output directory if it does not exist
88
+ os.makedirs(args.out_folder, exist_ok=True)
89
+
90
+ # Load detector
91
+ cfg = detectron2.config.get_cfg()
92
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"))
93
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
94
+ cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
95
+ cfg.MODEL.DEVICE = device.type
96
+ detector = detectron2.engine.DefaultPredictor(cfg)
97
+
98
+ img_paths = sorted([img for end in args.file_type for img in Path(args.img_folder).glob(end)])
99
+ num_readable_images = 0
100
+ num_rendered_results = 0
101
+ num_suppressed_detections = 0
102
+ for img_path in img_paths:
103
+ img_bgr = cv2.imread(str(img_path))
104
+ if img_bgr is None:
105
+ print(f"[WARN] Cannot read image: {img_path}")
106
+ continue
107
+ num_readable_images += 1
108
+ # Detect animals in image
109
+ det_out = detector(img_bgr)
110
+
111
+ det_instances = det_out['instances']
112
+ boxes, suppressed = select_animal_boxes(det_instances, score_threshold=0.7)
113
+ num_suppressed_detections += suppressed
114
+ if suppressed > 0:
115
+ print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s) in {img_path}")
116
+ if len(boxes) == 0:
117
+ print(f"[INFO] No animal detected in {img_path}")
118
+ continue
119
+
120
+ # Run PRIMA on detected animals
121
+ dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
122
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
123
+ for batch in tqdm(dataloader):
124
+ batch = recursive_to(batch, device)
125
+ with torch.no_grad():
126
+ out = model(batch)
127
+
128
+ pred_cam = out['pred_cam']
129
+ box_center = batch["box_center"].float()
130
+ box_size = batch["box_size"].float()
131
+ img_size = batch["img_size"].float()
132
+ scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max()
133
+ pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size,
134
+ scaled_focal_length).detach().cpu().numpy()
135
+
136
+ # Render the result
137
+ batch_size = batch['img'].shape[0]
138
+ for n in range(batch_size):
139
+ # Get filename from path img_path
140
+ img_fn, _ = os.path.splitext(os.path.basename(img_path))
141
+ animal_id = int(batch['animalid'][n])
142
+ white_img = (torch.ones_like(batch['img'][n]).cpu() - DEFAULT_MEAN[:, None, None] / 255) / (
143
+ DEFAULT_STD[:, None, None] / 255)
144
+ input_patch = (batch['img'][n].cpu() * (DEFAULT_STD[:, None, None]) + (
145
+ DEFAULT_MEAN[:, None, None])) / 255.
146
+ input_patch = input_patch.permute(1, 2, 0).numpy()
147
+
148
+ regression_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(),
149
+ out['pred_cam_t'][n].detach().cpu().numpy(),
150
+ batch['img'][n],
151
+ mesh_base_color=GREEN,
152
+ scene_bg_color=(1, 1, 1),
153
+ )
154
+
155
+ final_img = np.concatenate([input_patch, regression_img], axis=1)
156
+
157
+ if args.side_view:
158
+ side_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(),
159
+ out['pred_cam_t'][n].detach().cpu().numpy(),
160
+ white_img,
161
+ mesh_base_color=GREEN,
162
+ scene_bg_color=(1, 1, 1),
163
+ side_view=True)
164
+ final_img = np.concatenate([final_img, side_img], axis=1)
165
+
166
+ cv2.imwrite(os.path.join(args.out_folder, f'{img_fn}_{animal_id}.png'),
167
+ cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR))
168
+ num_rendered_results += 1
169
+
170
+ # Add all verts and cams to list
171
+ verts = out['pred_vertices'][n].detach().cpu().numpy()
172
+ cam_t = pred_cam_t_full[n]
173
+
174
+ # Save all meshes to disk
175
+ if args.save_mesh:
176
+ camera_translation = cam_t.copy()
177
+ tmesh = renderer.vertices_to_trimesh(verts, camera_translation, LIGHT_BLUE)
178
+ tmesh.export(os.path.join(args.out_folder, f'{img_fn}_{animal_id}.obj'))
179
+
180
+ print(
181
+ f"[done] Demo complete. Processed {num_readable_images}/{len(img_paths)} image(s), "
182
+ f"saved {num_rendered_results} rendered result(s) to {args.out_folder}."
183
+ )
184
+ if num_suppressed_detections > 0:
185
+ print(f"[done] Suppressed {num_suppressed_detections} duplicate animal detection(s).")
186
+
187
+
188
+ if __name__ == '__main__':
189
+ main()
demo.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Default PRIMA Stage 1 inference checkpoint:
2
+ # data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt
3
+ #
4
+ # If this local file is missing, it will be downloaded from the PRIMA Hugging Face repo.
5
+ # To use another local checkpoint instead, update this path.
6
+ # For example: checkpoint='data/PRIMAS3/checkpoints/s3ckpt.ckpt'
7
+ checkpoint='data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt'
8
+
9
+ python demo.py \
10
+ --checkpoint "${checkpoint}" \
11
+ --img_folder demo_data/ \
12
+ --out_folder demo_out/
demo_data/000000015956_horse.png ADDED

Git LFS Details

  • SHA256: 2a2398ba7df40a47c636afefa28be17b55f4b7bc2c378e053aeea507580ad2cb
  • Pointer size: 131 Bytes
  • Size of remote file: 620 kB
demo_data/000000315905_zebra.jpg ADDED

Git LFS Details

  • SHA256: e0a17e1f1650820b020a9025144015c1e27f0f1ab435859f0bde3a0047d8f689
  • Pointer size: 131 Bytes
  • Size of remote file: 257 kB
demo_data/beagle.jpg ADDED

Git LFS Details

  • SHA256: ac29e6ea6086831dd9806a8cd3fd608e264ac1af567f6fcfc8797c5bd3d5d560
  • Pointer size: 131 Bytes
  • Size of remote file: 350 kB
demo_data/n02101388_1188.png ADDED

Git LFS Details

  • SHA256: e45ff508fb8c6437cce22fcb59b4f1b6fe37ddfab1d4cf68d97629f9caa939f4
  • Pointer size: 131 Bytes
  • Size of remote file: 319 kB
demo_data/n02412080_12159.png ADDED

Git LFS Details

  • SHA256: 03273c57e8b25b258d3eb96af7b4f77b43b5c40be90da83c21875f3322b487f1
  • Pointer size: 131 Bytes
  • Size of remote file: 347 kB
demo_data/shepherd_hati.jpg ADDED

Git LFS Details

  • SHA256: 65c5878203bc3165dda9011ebfce77cc7d930daed0a215396d8036509d1963c1
  • Pointer size: 131 Bytes
  • Size of remote file: 210 kB
demo_tta.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """
11
+ demo_tta.py: PRIMA inference with fine-tuned DeepLabCut SuperAnimal TTA
12
+
13
+ Pipeline:
14
+ 1. Run Detectron2 to detect animals in the input image.
15
+ 2. Run PRIMA on each detected animal to obtain 3D pose/shape estimation.
16
+ 3. Run a fine-tuned DeepLabCut SuperAnimal pose model (Animal3D 26-joint
17
+ layout) to obtain 2D keypoints already in PRIMA topology. The fine-tuned
18
+ snapshot is wired into DLC's
19
+ ``superanimal_analyze_images`` via the ``customized_pose_checkpoint``
20
+ and ``customized_model_config`` kwargs.
21
+ 4. Run test-time adaptation (TTA) with user-specified lr and num_iters
22
+ to further optimize the 3D pose and shape estimation.
23
+ 5. Render and save before/after TTA results (PNG + OBJ) and the
24
+ 26-keypoint visualization (PNG).
25
+ """
26
+
27
+
28
+ from pathlib import Path
29
+ import argparse
30
+ import copy
31
+ import os
32
+ import tempfile
33
+ import warnings
34
+
35
+ import cv2
36
+ import numpy as np
37
+ import torch
38
+ import torch.nn.functional as F
39
+ import torch.utils.data
40
+ from tqdm import tqdm
41
+
42
+ from prima.models import load_prima
43
+ from prima.utils import recursive_to
44
+ from prima.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
45
+ from prima.utils.detection import ANIMAL_COCO_IDS, select_animal_boxes
46
+ from prima.utils.weights import DEFAULT_HF_REPO_ID, resolve_prima_checkpoint_path
47
+
48
+ warnings.filterwarnings("ignore")
49
+
50
+ LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353)
51
+ GREEN = (0.65, 0.86, 0.74)
52
+
53
+ REPO_ROOT = Path(__file__).resolve().parent
54
+
55
+
56
+ def load_renderer_components():
57
+ try:
58
+ from prima.utils.renderer import Renderer, cam_crop_to_full
59
+ except Exception as exc:
60
+ raise RuntimeError(
61
+ "Cannot initialize the PRIMA renderer. Rendering requires a working "
62
+ "pyrender/OpenGL backend such as EGL or OSMesa. Install the missing "
63
+ "OpenGL runtime for this environment, or run in an environment where "
64
+ "PYOPENGL_PLATFORM=egl/osmesa works."
65
+ ) from exc
66
+ return Renderer, cam_crop_to_full
67
+
68
+
69
+ def denorm_patch_to_rgb(img_tensor: torch.Tensor) -> np.ndarray:
70
+ patch = (img_tensor.detach().cpu() * (DEFAULT_STD[:, None, None]) + DEFAULT_MEAN[:, None, None]) / 255.0
71
+ patch = patch.permute(1, 2, 0).numpy()
72
+ return np.clip(patch, 0.0, 1.0)
73
+
74
+
75
+ def save_keypoint_vis(patch_rgb: np.ndarray, kpts_xyc: np.ndarray, save_path: str) -> None:
76
+ vis = cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR).copy()
77
+ num_kpts = len(kpts_xyc)
78
+
79
+ for i, (x, y, c) in enumerate(kpts_xyc):
80
+ if c <= 0:
81
+ continue
82
+
83
+ # Use distinct color for each keypoint (OpenCV uses BGR)
84
+ hue = int(179 * i / max(1, num_kpts - 1))
85
+ color_bgr = cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2BGR)[0, 0]
86
+ color_bgr = (int(color_bgr[0]), int(color_bgr[1]), int(color_bgr[2]))
87
+
88
+ cx, cy = int(round(float(x))), int(round(float(y)))
89
+ cv2.circle(vis, (cx, cy), 3, color_bgr, -1)
90
+ cv2.putText(vis, str(i), (cx + 3, cy - 3), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1, cv2.LINE_AA)
91
+
92
+ cv2.imwrite(save_path, vis)
93
+
94
+
95
+ def resolve_sa_weights_path(local_path: str) -> str:
96
+ """Return a local path to the fine-tuned SuperAnimal .pt snapshot.
97
+
98
+ If ``local_path`` is empty, downloads ``sa_finetune_hrnet_w32.pt`` from the
99
+ ``MLAdaptiveIntelligence/FMPose3D`` Hugging Face repo (cached under
100
+ ``~/.cache/huggingface``).
101
+ """
102
+ if local_path:
103
+ return local_path
104
+ try:
105
+ from huggingface_hub import hf_hub_download
106
+ except ImportError:
107
+ raise ImportError(
108
+ "huggingface_hub is required to auto-download the fine-tuned "
109
+ "SuperAnimal weights. Install with `pip install huggingface_hub`, "
110
+ "or pass --saved_2d_model_path with a local .pt file."
111
+ ) from None
112
+ repo_id = "MLAdaptiveIntelligence/FMPose3D"
113
+ filename = "sa_finetune_hrnet_w32.pt"
114
+ try:
115
+ cached_path = hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True)
116
+ except Exception:
117
+ print(f"No --saved_2d_model_path provided; downloading '{filename}' from {repo_id}...")
118
+ return hf_hub_download(repo_id=repo_id, filename=filename)
119
+
120
+ print(f"Using cached SuperAnimal weights: {cached_path}")
121
+ return cached_path
122
+
123
+
124
+ def run_superanimal_on_patch(patch_rgb: np.ndarray, args, tmp_dir: str):
125
+ """Predict 26-joint 2D keypoints on a single PRIMA patch using a
126
+ fine-tuned DeepLabCut SuperAnimal snapshot.
127
+
128
+ Returns an ``(26, 3)`` array of ``(x, y, confidence)`` in patch
129
+ pixel coordinates, or ``None`` if no individual was detected.
130
+ """
131
+ try:
132
+ from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images
133
+ except Exception as e:
134
+ raise RuntimeError(
135
+ "Cannot import DeepLabCut SuperAnimal API. Please install deeplabcut with pose_estimation_pytorch support."
136
+ ) from e
137
+
138
+ patch_path = os.path.join(tmp_dir, "patch.png")
139
+ cv2.imwrite(patch_path, cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
140
+
141
+ dlc_device = "cuda" if torch.cuda.is_available() else "cpu"
142
+ preds = superanimal_analyze_images(
143
+ superanimal_name=args.superanimal_name,
144
+ model_name=args.superanimal_model_name,
145
+ detector_name=args.superanimal_detector_name,
146
+ images=patch_path,
147
+ max_individuals=args.superanimal_max_individuals,
148
+ out_folder=tmp_dir,
149
+ device=dlc_device,
150
+ customized_model_config=args.pytorch_config_2d_path,
151
+ customized_pose_checkpoint=args.saved_2d_model_path,
152
+ )
153
+
154
+ payload = preds.get(patch_path, None)
155
+ if payload is None:
156
+ return None
157
+ bodyparts = payload.get("bodyparts", None)
158
+ if bodyparts is None or len(bodyparts) == 0:
159
+ return None
160
+
161
+ best_idx = int(np.argmax(bodyparts[..., 2].mean(axis=1)))
162
+ return bodyparts[best_idx].astype(np.float32)
163
+
164
+
165
+ def render_and_save(renderer, cam_crop_to_full_fn, out, batch, img_fn, animal_id, out_folder, suffix, side_view, save_mesh):
166
+ pred_cam = out['pred_cam']
167
+ box_center = batch['box_center'].float()
168
+ box_size = batch['box_size'].float()
169
+ img_size = batch['img_size'].float()
170
+ scaled_focal_length = batch['focal_length'][0, 0] / batch['img'].shape[-1] * img_size.max()
171
+ pred_cam_t_full = cam_crop_to_full_fn(pred_cam, box_center, box_size, img_size, scaled_focal_length)
172
+
173
+ white_img = (torch.ones_like(batch['img'][0]).cpu() - DEFAULT_MEAN[:, None, None] / 255) / (
174
+ DEFAULT_STD[:, None, None] / 255
175
+ )
176
+ input_patch = denorm_patch_to_rgb(batch['img'][0])
177
+
178
+ regression_img = renderer(
179
+ out['pred_vertices'][0].detach().cpu().numpy(),
180
+ out['pred_cam_t'][0].detach().cpu().numpy(),
181
+ batch['img'][0],
182
+ mesh_base_color=GREEN,
183
+ scene_bg_color=(1, 1, 1),
184
+ )
185
+
186
+ final_img = np.concatenate([input_patch, regression_img], axis=1)
187
+ if side_view:
188
+ side_img = renderer(
189
+ out['pred_vertices'][0].detach().cpu().numpy(),
190
+ out['pred_cam_t'][0].detach().cpu().numpy(),
191
+ white_img,
192
+ mesh_base_color=GREEN,
193
+ scene_bg_color=(1, 1, 1),
194
+ side_view=True,
195
+ )
196
+ final_img = np.concatenate([final_img, side_img], axis=1)
197
+
198
+ cv2.imwrite(
199
+ os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.png'),
200
+ cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR),
201
+ )
202
+
203
+ if save_mesh:
204
+ verts = out['pred_vertices'][0].detach().cpu().numpy()
205
+ cam_t = pred_cam_t_full[0].detach().cpu().numpy()
206
+ tmesh = renderer.vertices_to_trimesh(verts, cam_t.copy(), LIGHT_BLUE)
207
+ tmesh.export(os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.obj'))
208
+
209
+
210
+ def tta_optimize(model, batch, gt_kpts_norm, num_iters, lr):
211
+ model.eval()
212
+
213
+ if hasattr(model, 'backbone'):
214
+ for p in model.backbone.parameters():
215
+ p.requires_grad = False
216
+
217
+ orig_smal_head_state = copy.deepcopy(model.smal_head.state_dict())
218
+ model.smal_head.freeze_except_regression_heads()
219
+ tta_params = model.smal_head.get_tta_parameters(mode='all')
220
+ optimizer = torch.optim.Adam(tta_params, lr=lr)
221
+
222
+ valid_mask = (gt_kpts_norm[..., 2] > 0).float().unsqueeze(-1)
223
+ gt_xy = gt_kpts_norm[..., :2]
224
+
225
+ for _ in range(num_iters):
226
+ optimizer.zero_grad()
227
+ out = model(batch)
228
+ pred_xy = out['pred_keypoints_2d']
229
+ loss = F.mse_loss(pred_xy * valid_mask, gt_xy * valid_mask, reduction='sum') / (valid_mask.sum() + 1e-6)
230
+ loss.backward()
231
+ optimizer.step()
232
+
233
+ with torch.no_grad():
234
+ out_after = model(batch)
235
+
236
+ model.smal_head.load_state_dict(orig_smal_head_state)
237
+ model.smal_head.unfreeze_all()
238
+
239
+ return out_after
240
+
241
+
242
+ def main():
243
+ parser = argparse.ArgumentParser(description='PRIMA + SuperAnimal + TTA demo')
244
+ parser.add_argument('--checkpoint', type=str, default='',
245
+ help='Path to pretrained PRIMA checkpoint. Empty -> auto-download the default Stage 1 checkpoint.')
246
+ parser.add_argument('--hf-repo-id', '--hf_repo_id', dest='hf_repo_id',
247
+ type=str, default=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_REPO_ID),
248
+ help='Hugging Face repo ID containing PRIMA demo assets')
249
+ parser.add_argument('--no-auto-download', '--no_auto_download', dest='no_auto_download', action='store_true',
250
+ help='Disable automatic download of missing PRIMA demo assets')
251
+ parser.add_argument('--img_path', type=str, default=None, help='Single image path')
252
+ parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images')
253
+ parser.add_argument('--out_folder', type=str, default='demo_out_tta', help='Output folder')
254
+ parser.add_argument('--side_view', dest='side_view', action='store_true', default=False, help='Render side view')
255
+ parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='Save meshes')
256
+ parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'], help='Image globs')
257
+ parser.add_argument('--det_thresh', type=float, default=0.7, help='Detectron2 score threshold for animals')
258
+
259
+ parser.add_argument('--tta_lr', type=float, default=1e-6, help='TTA learning rate')
260
+ parser.add_argument('--tta_num_iters', type=int, default=30, help='TTA iterations')
261
+ parser.add_argument('--kp_conf_thresh', type=float, default=0.1, help='Keypoint confidence threshold')
262
+
263
+ parser.add_argument('--superanimal_name', type=str, default='superanimal_quadruped')
264
+ parser.add_argument('--superanimal_model_name', type=str, default='hrnet_w32')
265
+ parser.add_argument('--superanimal_detector_name', type=str, default='fasterrcnn_resnet50_fpn_v2')
266
+ parser.add_argument('--superanimal_max_individuals', type=int, default=1)
267
+ parser.add_argument('--saved_2d_model_path', type=str, default='',
268
+ help='Path to the fine-tuned SuperAnimal 26-joint .pt snapshot. '
269
+ 'Empty -> auto-download sa_finetune_hrnet_w32.pt from '
270
+ 'MLAdaptiveIntelligence/FMPose3D on Hugging Face Hub.')
271
+ parser.add_argument('--pytorch_config_2d_path', type=str,
272
+ default=str(Path(__file__).resolve().parent / 'configs' / 'sa_finetune_hrnet_w32.yaml'),
273
+ help='Path to the DLC pytorch config yaml for the fine-tuned snapshot. '
274
+ 'Defaults to the bundled configs/sa_finetune_hrnet_w32.yaml.')
275
+
276
+ args = parser.parse_args()
277
+ checkpoint_path = resolve_prima_checkpoint_path(
278
+ args.checkpoint,
279
+ data_dir=REPO_ROOT / "data",
280
+ auto_download=not args.no_auto_download,
281
+ hf_repo_id=args.hf_repo_id,
282
+ )
283
+ args.saved_2d_model_path = resolve_sa_weights_path(args.saved_2d_model_path)
284
+
285
+ model, model_cfg = load_prima(checkpoint_path)
286
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
287
+ model = model.to(device)
288
+ model.eval()
289
+
290
+ Renderer, cam_crop_to_full_fn = load_renderer_components()
291
+ renderer = Renderer(model_cfg, faces=model.smal.faces)
292
+ os.makedirs(args.out_folder, exist_ok=True)
293
+
294
+ import detectron2.config
295
+ import detectron2.engine
296
+ from detectron2 import model_zoo
297
+
298
+ cfg = detectron2.config.get_cfg()
299
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"))
300
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
301
+ cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
302
+ cfg.MODEL.DEVICE = device.type
303
+ detector = detectron2.engine.DefaultPredictor(cfg)
304
+
305
+ if args.img_path is not None:
306
+ img_paths = [Path(args.img_path)]
307
+ else:
308
+ img_paths = sorted([img for end in args.file_type for img in Path(args.img_folder).glob(end)])
309
+
310
+ for img_path in img_paths:
311
+ img_bgr = cv2.imread(str(img_path))
312
+ if img_bgr is None:
313
+ print(f"[WARN] Cannot read image: {img_path}")
314
+ continue
315
+ det_out = detector(img_bgr)
316
+ det_instances = det_out['instances']
317
+ boxes, suppressed = select_animal_boxes(
318
+ det_instances,
319
+ animal_class_ids=ANIMAL_COCO_IDS,
320
+ score_threshold=args.det_thresh,
321
+ )
322
+ if suppressed > 0:
323
+ print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s) in {img_path}")
324
+
325
+ if len(boxes) == 0:
326
+ print(f"[INFO] No animal detected in {img_path}")
327
+ continue
328
+
329
+ dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
330
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
331
+
332
+ for batch in tqdm(dataloader, desc=f"{img_path.name}"):
333
+ batch = recursive_to(batch, device)
334
+ with torch.no_grad():
335
+ out_before = model(batch)
336
+
337
+ img_fn = img_path.stem
338
+ animal_id = int(batch['animalid'][0])
339
+
340
+ render_and_save(
341
+ renderer,
342
+ cam_crop_to_full_fn,
343
+ out_before,
344
+ batch,
345
+ img_fn,
346
+ animal_id,
347
+ args.out_folder,
348
+ suffix='before_tta',
349
+ side_view=args.side_view,
350
+ save_mesh=args.save_mesh,
351
+ )
352
+
353
+ patch_rgb = denorm_patch_to_rgb(batch['img'][0])
354
+ with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
355
+ kpts_xyc = run_superanimal_on_patch(patch_rgb, args, tmp_dir)
356
+
357
+ if kpts_xyc is None:
358
+ print(f"[WARN] No SuperAnimal keypoints for {img_fn}_{animal_id}, skip TTA")
359
+ continue
360
+
361
+ kpts_xyc[kpts_xyc[:, 2] < args.kp_conf_thresh, 2] = 0.0
362
+
363
+ save_keypoint_vis(
364
+ patch_rgb,
365
+ kpts_xyc,
366
+ os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png"),
367
+ )
368
+ np.save(os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy"), kpts_xyc)
369
+
370
+ patch_h, patch_w = patch_rgb.shape[:2]
371
+ kpts_norm = kpts_xyc.copy()
372
+ kpts_norm[:, 0] = kpts_norm[:, 0] / float(patch_w) - 0.5
373
+ kpts_norm[:, 1] = kpts_norm[:, 1] / float(patch_h) - 0.5
374
+ gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch['img'].dtype)
375
+
376
+ out_after = tta_optimize(
377
+ model,
378
+ batch,
379
+ gt_kpts_norm,
380
+ num_iters=args.tta_num_iters,
381
+ lr=args.tta_lr,
382
+ )
383
+
384
+ render_and_save(
385
+ renderer,
386
+ cam_crop_to_full_fn,
387
+ out_after,
388
+ batch,
389
+ img_fn,
390
+ animal_id,
391
+ args.out_folder,
392
+ suffix='after_tta',
393
+ side_view=args.side_view,
394
+ save_mesh=args.save_mesh,
395
+ )
396
+
397
+
398
+ if __name__ == '__main__':
399
+ main()
demo_tta.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Empty checkpoint uses the default PRIMA Stage 1 inference checkpoint:
3
+ # data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt
4
+ #
5
+ # This standard path is auto-downloaded from the PRIMA Hugging Face repo if missing.
6
+ # To use another local checkpoint instead, update this path.
7
+ # For example: checkpoint='data/PRIMAS3/checkpoints/s3ckpt.ckpt'
8
+ checkpoint='data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt'
9
+
10
+ python3 demo_tta.py \
11
+ --checkpoint "${checkpoint}" \
12
+ --img_folder demo_data/ \
13
+ --out_folder demo_out_tta/ \
14
+ --tta_lr 1e-6 \
15
+ --tta_num_iters 30
eval.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+ import torch
13
+ from prima.utils import recursive_to
14
+ from prima.utils.evaluate_metric import Evaluator
15
+ from prima.datasets.datasets import EvaluationDataset
16
+ import argparse
17
+ from torch.utils.data import DataLoader
18
+ from prima.models.prima import PRIMA
19
+ from prima.configs import get_config
20
+ torch.multiprocessing.set_sharing_strategy('file_system')
21
+
22
+
23
+ def main(args):
24
+ cfg = get_config(args.config)
25
+ default_cfg = get_config(args.default_eval_config)
26
+ model = PRIMA.load_from_checkpoint(args.checkpoint, cfg=cfg, strict=False)
27
+ model.eval()
28
+ model = model.to(args.device)
29
+
30
+ smal_evaluator = Evaluator(smal_model=model.smal, image_size=cfg.MODEL.IMAGE_SIZE)
31
+ cfg_eval_dataset = dict(default_cfg.DATASETS)
32
+ aug_cfg = cfg_eval_dataset.pop("CONFIG", None) # augmentation config is not used in evaluation
33
+
34
+ if args.dataset.upper() == "ALL":
35
+ for key in cfg_eval_dataset.keys():
36
+ print(f"-------- Evaluate {key} dataset ------------")
37
+ eval_one_dataset(cfg_eval_dataset[key], default_cfg, cfg, model,
38
+ evaluator=smal_evaluator,
39
+ aug_cfg=aug_cfg,
40
+ key=key,
41
+ device=args.device)
42
+ print(f"-------{key} Dataset evaluate finish ------")
43
+ else:
44
+ print(f"-------- Evaluate {args.dataset} dataset ------------")
45
+ eval_one_dataset(cfg_eval_dataset[args.dataset], default_cfg, cfg, model,
46
+ evaluator=smal_evaluator,
47
+ aug_cfg=aug_cfg,
48
+ key=args.dataset,
49
+ device=args.device)
50
+ print(f"-------{args.dataset} Dataset evaluate finish ------")
51
+
52
+
53
+ def eval_one_dataset(dataset_cfg, default_cfg, cfg, model, evaluator, aug_cfg, key, device='cuda'):
54
+ dataset = EvaluationDataset(root_image=dataset_cfg['ROOT_IMAGE'],
55
+ json_file=dataset_cfg['JSON_FILE']['TEST'],
56
+ augm_config=aug_cfg, focal_length=cfg.SMAL.get("FOCAL_LENGTH", 1000),
57
+ image_size=cfg.MODEL.IMAGE_SIZE,
58
+ )
59
+ dataloader = DataLoader(dataset, batch_size=1, num_workers=cfg.GENERAL.NUM_WORKERS)
60
+
61
+ bar = tqdm(dataloader)
62
+ pa_mpjpe_list, pck_list, auc_list, pa_mpvpe_list = [], [], [], []
63
+ for i, batch in enumerate(bar):
64
+ batch = recursive_to(batch, device)
65
+ with torch.no_grad():
66
+ output = model(batch)
67
+
68
+ if key in ["ANIMAL3D", "CONTROL_ANIMAL3D"]:
69
+ pa_mpjpe, pa_mpvpe = evaluator.eval_3d(output, batch)
70
+ else:
71
+ pa_mpjpe, pa_mpvpe = 0., 0.
72
+ pck, auc = evaluator.eval_2d(output, batch, pck_threshold=default_cfg.METRIC.PCK_THRESHOLD)
73
+
74
+ pa_mpjpe_list.append(pa_mpjpe)
75
+ pa_mpvpe_list.append(pa_mpvpe)
76
+ auc_list.append(auc)
77
+ pck_list.append(pck)
78
+
79
+ bar.set_postfix(PA_MPJPE=pa_mpjpe,
80
+ PA_MPVPE=pa_mpvpe,
81
+ AUC=auc,
82
+ pck=pck,)
83
+
84
+ print("---------------- 3D metric -----------------")
85
+ print(f"Avg PA-MPJPE: {np.mean(pa_mpjpe_list)}")
86
+ print(f"Avg PA-MPVPE: {np.mean(pa_mpvpe_list)}")
87
+
88
+ print("--------------- 2D metric ------------------")
89
+ print(f"AUC: {np.mean(auc_list)}")
90
+ pck_list = np.array(pck_list)
91
+ for _, th in enumerate(default_cfg.METRIC.PCK_THRESHOLD):
92
+ print(f"PCK@{th}: {np.mean(pck_list[:, _])}")
93
+
94
+
95
+ if __name__ == "__main__":
96
+ parser = argparse.ArgumentParser()
97
+ parser.add_argument("--config", type=str, help="Path to config file", required=True)
98
+ parser.add_argument("--checkpoint", type=str, help="Path to checkpoint file", required=True)
99
+ parser.add_argument("--default_eval_config", type=str, default="./configs_hydra/experiment/default_val.yaml")
100
+ parser.add_argument("--dataset", type=str, default="ALL")
101
+ parser.add_argument("--device", type=str, default="cuda", help="Device to use for evaluation")
102
+ args = parser.parse_args()
103
+ main(args)
images/teaser.png ADDED

Git LFS Details

  • SHA256: a617ca4fd37de03e2db4ccf397ce9841ed32c3fe18c766c4832d41af574ad746
  • Pointer size: 132 Bytes
  • Size of remote file: 4.29 MB
packages.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ libosmesa6
2
+ libgl1
3
+ libegl1
4
+ libgles2
prima/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """Top-level package for PRIMA.
11
+
12
+ This package contains models, datasets and utilities for
13
+ 3D animal pose and shape estimation.
14
+ """
15
+
16
+ from importlib.metadata import PackageNotFoundError, version
17
+
18
+
19
+ try: # pragma: no cover - best effort during development
20
+ __version__ = version("prima-animal")
21
+ except PackageNotFoundError: # pragma: no cover
22
+ __version__ = "0.0.0"
23
+
24
+
25
+ __all__ = ["__version__"]
prima/configs/__init__.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from typing import Dict
11
+ from yacs.config import CfgNode as CN
12
+
13
+ def to_lower(x: Dict) -> Dict:
14
+ """
15
+ Convert all dictionary keys to lowercase
16
+ Args:
17
+ x (dict): Input dictionary
18
+ Returns:
19
+ dict: Output dictionary with all keys converted to lowercase
20
+ """
21
+ return {k.lower(): v for k, v in x.items()}
22
+
23
+
24
+ _C = CN(new_allowed=True)
25
+
26
+ _C.GENERAL = CN(new_allowed=True)
27
+ _C.GENERAL.RESUME = True
28
+ _C.GENERAL.TIME_TO_RUN = 3300
29
+ _C.GENERAL.VAL_STEPS = 100
30
+ _C.GENERAL.LOG_STEPS = 100
31
+ _C.GENERAL.CHECKPOINT_STEPS = 20000
32
+ _C.GENERAL.CHECKPOINT_DIR = "checkpoints"
33
+ _C.GENERAL.SUMMARY_DIR = "tensorboard"
34
+ _C.GENERAL.NUM_GPUS = 1
35
+ _C.GENERAL.NUM_WORKERS = 4
36
+ _C.GENERAL.MIXED_PRECISION = True
37
+ _C.GENERAL.ALLOW_CUDA = True
38
+ _C.GENERAL.PIN_MEMORY = False
39
+ _C.GENERAL.DISTRIBUTED = False
40
+ _C.GENERAL.LOCAL_RANK = 0
41
+ _C.GENERAL.USE_SYNCBN = False
42
+ _C.GENERAL.WORLD_SIZE = 1
43
+ _C.GENERAL.PREFETCH_FACTOR = 2
44
+
45
+ _C.TRAIN = CN(new_allowed=True)
46
+ _C.TRAIN.NUM_EPOCHS = 100
47
+ _C.TRAIN.SHUFFLE = True
48
+ _C.TRAIN.WARMUP = False
49
+ _C.TRAIN.NORMALIZE_PER_IMAGE = False
50
+ _C.TRAIN.CLIP_GRAD = False
51
+ _C.TRAIN.CLIP_GRAD_VALUE = 1.0
52
+ _C.LOSS_WEIGHTS = CN(new_allowed=True)
53
+
54
+ _C.DATASETS = CN(new_allowed=True)
55
+
56
+ _C.MODEL = CN(new_allowed=True)
57
+ _C.MODEL.IMAGE_SIZE = 224
58
+
59
+ _C.EXTRA = CN(new_allowed=True)
60
+ _C.EXTRA.FOCAL_LENGTH = 5000
61
+
62
+ _C.DATASETS.CONFIG = CN(new_allowed=True)
63
+ _C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
64
+ _C.DATASETS.CONFIG.ROT_FACTOR = 30
65
+ _C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
66
+ _C.DATASETS.CONFIG.COLOR_SCALE = 0.2
67
+ _C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
68
+ _C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
69
+ _C.DATASETS.CONFIG.DO_FLIP = False
70
+ _C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
71
+ _C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
72
+
73
+
74
+ def default_config() -> CN:
75
+ """
76
+ Get a yacs CfgNode object with the default config values.
77
+ """
78
+ # Return a clone so that the defaults will not be altered
79
+ # This is for the "local variable" use pattern
80
+ return _C.clone()
81
+
82
+
83
+ def get_config(config_file: str, merge: bool = True) -> CN:
84
+ """
85
+ Read a config file and optionally merge it with the default config file.
86
+ Args:
87
+ config_file (str): Path to config file.
88
+ merge (bool): Whether to merge with the default config or not.
89
+ Returns:
90
+ CfgNode: Config as a yacs CfgNode object.
91
+ """
92
+ if merge:
93
+ cfg = default_config()
94
+ else:
95
+ cfg = CN(new_allowed=True)
96
+ cfg.merge_from_file(config_file)
97
+
98
+ cfg.freeze()
99
+ return cfg
prima/models/__init__.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from .prima import PRIMA
11
+
12
+
13
+ def load_prima(checkpoint_path):
14
+ from pathlib import Path
15
+ from ..configs import get_config
16
+ model_cfg = str(Path(checkpoint_path).parent.parent / '.hydra/config.yaml')
17
+ model_cfg = get_config(model_cfg)
18
+
19
+ # Override some config values, to crop bbox correctly
20
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'vit') and ('BBOX_SHAPE' not in model_cfg.MODEL):
21
+ model_cfg.defrost()
22
+ assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
23
+ model_cfg.MODEL.BBOX_SHAPE = [192, 256]
24
+ model_cfg.freeze()
25
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'dinov3') and ('BBOX_SHAPE' not in model_cfg.MODEL):
26
+ model_cfg.defrost()
27
+ assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for dino backbone"
28
+ model_cfg.MODEL.BBOX_SHAPE = [256, 256]
29
+ model_cfg.freeze()
30
+
31
+ if (model_cfg.MODEL.BACKBONE.TYPE == 'dinov2') and ('BBOX_SHAPE' not in model_cfg.MODEL):
32
+ model_cfg.defrost()
33
+ assert model_cfg.MODEL.IMAGE_SIZE == 252, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 252 for dino backbone"
34
+ model_cfg.MODEL.BBOX_SHAPE = [252, 252]
35
+ model_cfg.freeze()
36
+
37
+
38
+
39
+ # Update config to be compatible with demo
40
+ if ('PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE):
41
+ model_cfg.defrost()
42
+ model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS')
43
+ model_cfg.freeze()
44
+
45
+ # Offscreen training renderer is not needed for demo/inference startup and
46
+ # can fail on some local OpenGL backends.
47
+ model = PRIMA.load_from_checkpoint(
48
+ checkpoint_path,
49
+ strict=False,
50
+ cfg=model_cfg,
51
+ map_location='cpu',
52
+ init_renderer=False,
53
+ )
54
+ return model, model_cfg
prima/models/backbones/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from .vit import vith
11
+
12
+
13
+
14
+
15
+ def create_backbone(cfg):
16
+ if cfg.MODEL.BACKBONE.TYPE in ['vith','concat','aa']: # vit bb will be used in these three cases - animal feature extractor
17
+ return vith(cfg)
18
+ else:
19
+ raise NotImplementedError('Backbone type is not implemented')
prima/models/backbones/vit.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ # Copyright (c) OpenMMLab. All rights reserved.
11
+ import math
12
+
13
+ import torch
14
+ from functools import partial
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import torch.utils.checkpoint as checkpoint
18
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
19
+
20
+
21
+ def vith(cfg):
22
+ return ViT(
23
+ img_size=(256, 192),
24
+ patch_size=16,
25
+ embed_dim=1280,
26
+ depth=32,
27
+ num_heads=16,
28
+ ratio=1,
29
+ use_checkpoint=False,
30
+ # use_checkpoint=True,
31
+ mlp_ratio=4,
32
+ qkv_bias=True,
33
+ drop_path_rate=0.55,
34
+ use_cls=True, # cls for animal family classification
35
+ )
36
+
37
+
38
+ def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
39
+ """
40
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
41
+ dimension for the original embeddings.
42
+ Args:
43
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
44
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
45
+ hw (Tuple): size of input image tokens.
46
+
47
+ Returns:
48
+ Absolute positional embeddings after processing with shape (1, H, W, C)
49
+ """
50
+ cls_token = None
51
+ B, L, C = abs_pos.shape
52
+ if has_cls_token:
53
+ cls_token = abs_pos[:, 0:1]
54
+ abs_pos = abs_pos[:, 1:]
55
+
56
+ if ori_h != h or ori_w != w:
57
+ new_abs_pos = F.interpolate(
58
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
59
+ size=(h, w),
60
+ mode="bicubic",
61
+ align_corners=False,
62
+ ).permute(0, 2, 3, 1).reshape(B, -1, C)
63
+
64
+ else:
65
+ new_abs_pos = abs_pos
66
+
67
+ if cls_token is not None:
68
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
69
+ return new_abs_pos
70
+
71
+
72
+ class DropPath(nn.Module):
73
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
74
+ """
75
+
76
+ def __init__(self, drop_prob=None):
77
+ super(DropPath, self).__init__()
78
+ self.drop_prob = drop_prob
79
+
80
+ def forward(self, x):
81
+ return drop_path(x, self.drop_prob, self.training)
82
+
83
+ def extra_repr(self):
84
+ return 'p={}'.format(self.drop_prob)
85
+
86
+
87
+ class Mlp(nn.Module):
88
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
89
+ super().__init__()
90
+ out_features = out_features or in_features
91
+ hidden_features = hidden_features or in_features
92
+ self.fc1 = nn.Linear(in_features, hidden_features)
93
+ self.act = act_layer()
94
+ self.fc2 = nn.Linear(hidden_features, out_features)
95
+ self.drop = nn.Dropout(drop)
96
+
97
+ def forward(self, x):
98
+ x = self.fc1(x)
99
+ x = self.act(x)
100
+ x = self.fc2(x)
101
+ x = self.drop(x)
102
+ return x
103
+
104
+
105
+ class Attention(nn.Module):
106
+ def __init__(
107
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
108
+ proj_drop=0., attn_head_dim=None):
109
+ super().__init__()
110
+ self.num_heads = num_heads
111
+ head_dim = dim // num_heads
112
+ self.dim = dim
113
+
114
+ if attn_head_dim is not None:
115
+ head_dim = attn_head_dim
116
+ all_head_dim = head_dim * self.num_heads
117
+
118
+ self.scale = qk_scale or head_dim ** -0.5
119
+
120
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
121
+
122
+ self.attn_drop = nn.Dropout(attn_drop)
123
+ self.proj = nn.Linear(all_head_dim, dim)
124
+ self.proj_drop = nn.Dropout(proj_drop)
125
+
126
+ def forward(self, x):
127
+ B, N, C = x.shape
128
+ qkv = self.qkv(x)
129
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
130
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
131
+
132
+ q = q * self.scale
133
+ attn = (q @ k.transpose(-2, -1))
134
+ attn = attn.softmax(dim=-1)
135
+ attn = self.attn_drop(attn)
136
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
137
+
138
+ x = self.proj(x)
139
+ x = self.proj_drop(x)
140
+
141
+ return x
142
+
143
+
144
+ class Block(nn.Module):
145
+
146
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
147
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
148
+ norm_layer=nn.LayerNorm, attn_head_dim=None,
149
+ ):
150
+ super().__init__()
151
+
152
+ self.norm1 = norm_layer(dim)
153
+ self.attn = Attention(
154
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
155
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
156
+ )
157
+
158
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
159
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
160
+ self.norm2 = norm_layer(dim)
161
+ mlp_hidden_dim = int(dim * mlp_ratio)
162
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
163
+
164
+ def forward(self, x):
165
+ x = x + self.drop_path(self.attn(self.norm1(x)))
166
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
167
+ return x
168
+
169
+
170
+ class PatchEmbed(nn.Module):
171
+ """ Image to Patch Embedding
172
+ """
173
+
174
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
175
+ super().__init__()
176
+ img_size = to_2tuple(img_size)
177
+ patch_size = to_2tuple(patch_size)
178
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
179
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
180
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
181
+ self.img_size = img_size
182
+ self.patch_size = patch_size
183
+ self.num_patches = num_patches
184
+
185
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio),
186
+ padding=4 + 2 * (ratio // 2 - 1))
187
+
188
+ def forward(self, x, **kwargs):
189
+ B, C, H, W = x.shape
190
+ x = self.proj(x)
191
+ Hp, Wp = x.shape[2], x.shape[3]
192
+
193
+ x = x.flatten(2).transpose(1, 2)
194
+ return x, (Hp, Wp)
195
+
196
+
197
+ class HybridEmbed(nn.Module):
198
+ """ CNN Feature Map Embedding
199
+ Extract feature map from CNN, flatten, project to embedding dim.
200
+ """
201
+
202
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
203
+ super().__init__()
204
+ assert isinstance(backbone, nn.Module)
205
+ img_size = to_2tuple(img_size)
206
+ self.img_size = img_size
207
+ self.backbone = backbone
208
+ if feature_size is None:
209
+ with torch.no_grad():
210
+ training = backbone.training
211
+ if training:
212
+ backbone.eval()
213
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
214
+ feature_size = o.shape[-2:]
215
+ feature_dim = o.shape[1]
216
+ backbone.train(training)
217
+ else:
218
+ feature_size = to_2tuple(feature_size)
219
+ feature_dim = self.backbone.feature_info.channels()[-1]
220
+ self.num_patches = feature_size[0] * feature_size[1]
221
+ self.proj = nn.Linear(feature_dim, embed_dim)
222
+
223
+ def forward(self, x):
224
+ x = self.backbone(x)[-1]
225
+ x = x.flatten(2).transpose(1, 2)
226
+ x = self.proj(x)
227
+ return x
228
+
229
+
230
+ class ViT(nn.Module):
231
+
232
+ def __init__(self,
233
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
234
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
235
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
236
+ frozen_stages=-1, ratio=1, last_norm=True, use_cls=False,
237
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
238
+ ):
239
+ # Protect mutable default arguments
240
+ super(ViT, self).__init__()
241
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
242
+ self.num_classes = num_classes
243
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
244
+ self.frozen_stages = frozen_stages
245
+ self.use_checkpoint = use_checkpoint
246
+ self.patch_padding = patch_padding
247
+ self.freeze_attn = freeze_attn
248
+ self.freeze_ffn = freeze_ffn
249
+ self.depth = depth
250
+
251
+ if hybrid_backbone is not None:
252
+ self.patch_embed = HybridEmbed(
253
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
254
+ else:
255
+ self.patch_embed = PatchEmbed(
256
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
257
+ num_patches = self.patch_embed.num_patches
258
+
259
+ # since the pretraining model has class token
260
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
261
+
262
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
263
+
264
+ self.blocks = nn.ModuleList([
265
+ Block(
266
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
267
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
268
+ )
269
+ for i in range(depth)])
270
+
271
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
272
+
273
+ if self.pos_embed is not None:
274
+ trunc_normal_(self.pos_embed, std=.02)
275
+
276
+ self.use_cls = use_cls
277
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
278
+ nn.init.normal_(self.cls_token, std=1e-6)
279
+
280
+ self._freeze_stages()
281
+
282
+ def _freeze_stages(self):
283
+ """Freeze parameters."""
284
+ if self.frozen_stages >= 0:
285
+ self.patch_embed.eval()
286
+ for param in self.patch_embed.parameters():
287
+ param.requires_grad = False
288
+
289
+ for i in range(1, self.frozen_stages + 1):
290
+ m = self.blocks[i]
291
+ m.eval()
292
+ for param in m.parameters():
293
+ param.requires_grad = False
294
+
295
+ if self.freeze_attn:
296
+ for i in range(0, self.depth):
297
+ m = self.blocks[i]
298
+ m.attn.eval()
299
+ m.norm1.eval()
300
+ for param in m.attn.parameters():
301
+ param.requires_grad = False
302
+ for param in m.norm1.parameters():
303
+ param.requires_grad = False
304
+
305
+ if self.freeze_ffn:
306
+ self.pos_embed.requires_grad = False
307
+ self.patch_embed.eval()
308
+ for param in self.patch_embed.parameters():
309
+ param.requires_grad = False
310
+ for i in range(0, self.depth):
311
+ m = self.blocks[i]
312
+ m.mlp.eval()
313
+ m.norm2.eval()
314
+ for param in m.mlp.parameters():
315
+ param.requires_grad = False
316
+ for param in m.norm2.parameters():
317
+ param.requires_grad = False
318
+
319
+ def init_weights(self):
320
+ """Initialize the weights in backbone.
321
+ Args:
322
+ pretrained (str, optional): Path to pre-trained weights.
323
+ Defaults to None.
324
+ """
325
+
326
+ def _init_weights(m):
327
+ if isinstance(m, nn.Linear):
328
+ trunc_normal_(m.weight, std=.02)
329
+ if isinstance(m, nn.Linear) and m.bias is not None:
330
+ nn.init.constant_(m.bias, 0)
331
+ elif isinstance(m, nn.LayerNorm):
332
+ nn.init.constant_(m.bias, 0)
333
+ nn.init.constant_(m.weight, 1.0)
334
+
335
+ self.apply(_init_weights)
336
+
337
+ def get_num_layers(self):
338
+ return len(self.blocks)
339
+
340
+ @torch.jit.ignore
341
+ def no_weight_decay(self):
342
+ return {'pos_embed', 'cls_token'}
343
+
344
+ def forward_features(self, x):
345
+ B, C, H, W = x.shape
346
+ x, (Hp, Wp) = self.patch_embed(x)
347
+
348
+ if self.pos_embed is not None:
349
+ # fit for multiple GPU training
350
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
351
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
352
+
353
+ x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) if self.use_cls else x
354
+ for blk in self.blocks:
355
+ if self.use_checkpoint:
356
+ x = checkpoint.checkpoint(blk, x)
357
+ else:
358
+ x = blk(x)
359
+
360
+ x = self.last_norm(x)
361
+
362
+ cls = x[:, 0] if self.use_cls else None
363
+ x = x[:, 1:] if self.use_cls else x
364
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
365
+
366
+ return xp, cls # shape [B, D, Hp, Wp], [B, D]
367
+
368
+ def forward(self, x):
369
+ x, cls = self.forward_features(x)
370
+ return x, cls
371
+
372
+ def train(self, mode=True):
373
+ """Convert the model into training mode."""
374
+ super().train(mode)
375
+ self._freeze_stages()
prima/models/bioclip_embedding.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ """
11
+ bioclip Embedding Module
12
+ Converts image batch to embeddings that can be concatenated with image features
13
+ """
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ class BioClipEmbedding(nn.Module):
19
+ """
20
+ Embeds images into a feature space using BioClip model that can be combined with image features.
21
+
22
+ Args:
23
+ embed_dim: Output embedding dimension, should match the dimension of image features for concatenation
24
+ """
25
+
26
+ def __init__(self, cfg, embed_dim: int = 1024):
27
+ super().__init__()
28
+
29
+ self.embed_dim = embed_dim
30
+
31
+ import open_clip
32
+
33
+ if cfg.MODEL.BIOCLIP_EMBEDDING.TYPE == 'bioclip2':
34
+ print("[BioClipEmbedding] Using BioClip2 model from Hugging Face Hub")
35
+ self.species_model, _,_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2')
36
+ else:
37
+ self.species_model, _,_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip')
38
+ # tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip')
39
+
40
+
41
+ self.species_model.eval()
42
+
43
+ # Get the output dimension from the model
44
+ bioclip_output_dim = self.species_model.visual.output_dim
45
+
46
+ # Project to target dimension
47
+ self.projection = nn.Sequential(
48
+ nn.Linear(bioclip_output_dim, embed_dim),
49
+ nn.LayerNorm(embed_dim),
50
+ )
51
+
52
+ def forward(self, images: torch.Tensor) -> torch.Tensor:
53
+ """
54
+ Args:
55
+ images: Tensor of shape (B, C, H, W) representing a batch of images
56
+ Returns:
57
+ Tensor of shape (B, embed_dim) representing the embedded features
58
+ """
59
+ # BioClip expects 224x224 input, resize if needed
60
+ if images.shape[-2:] != (224, 224):
61
+ images_resized = F.interpolate(images, size=(224, 224), mode='bilinear', align_corners=False)
62
+ else:
63
+ images_resized = images
64
+
65
+ with torch.no_grad():
66
+ image_features = self.species_model.encode_image(images_resized)
67
+
68
+ projected_features = self.projection(image_features)
69
+
70
+ return projected_features
prima/models/components/__init__.py ADDED
File without changes
prima/models/components/model_utils.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
11
+ # All rights reserved.
12
+
13
+ # This source code is licensed under the license found in the
14
+ # LICENSE file in the root directory of this source tree.
15
+
16
+
17
+ import copy
18
+ from typing import Tuple
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+
25
+
26
+ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
27
+ """
28
+ Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
29
+ that are temporally closest to the current frame at `frame_idx`. Here, we take
30
+ - a) the closest conditioning frame before `frame_idx` (if any);
31
+ - b) the closest conditioning frame after `frame_idx` (if any);
32
+ - c) any other temporally closest conditioning frames until reaching a total
33
+ of `max_cond_frame_num` conditioning frames.
34
+
35
+ Outputs:
36
+ - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
37
+ - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
38
+ """
39
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
40
+ selected_outputs = cond_frame_outputs
41
+ unselected_outputs = {}
42
+ else:
43
+ assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
44
+ selected_outputs = {}
45
+
46
+ # the closest conditioning frame before `frame_idx` (if any)
47
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
48
+ if idx_before is not None:
49
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
50
+
51
+ # the closest conditioning frame after `frame_idx` (if any)
52
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
53
+ if idx_after is not None:
54
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
55
+
56
+ # add other temporally closest conditioning frames until reaching a total
57
+ # of `max_cond_frame_num` conditioning frames.
58
+ num_remain = max_cond_frame_num - len(selected_outputs)
59
+ inds_remain = sorted(
60
+ (t for t in cond_frame_outputs if t not in selected_outputs),
61
+ key=lambda x: abs(x - frame_idx),
62
+ )[:num_remain]
63
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
64
+ unselected_outputs = {
65
+ t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
66
+ }
67
+
68
+ return selected_outputs, unselected_outputs
69
+
70
+
71
+ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
72
+ """
73
+ Get 1D sine positional embedding as in the original Transformer paper.
74
+ """
75
+ pe_dim = dim // 2
76
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
77
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
78
+
79
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
80
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
81
+ return pos_embed
82
+
83
+
84
+ def get_activation_fn(activation):
85
+ """Return an activation function given a string"""
86
+ if activation == "relu":
87
+ return F.relu
88
+ if activation == "gelu":
89
+ return F.gelu
90
+ if activation == "glu":
91
+ return F.glu
92
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
93
+
94
+
95
+ def get_clones(module, N):
96
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
97
+
98
+
99
+ class DropPath(nn.Module):
100
+ # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
101
+ def __init__(self, drop_prob=0.0, scale_by_keep=True):
102
+ super(DropPath, self).__init__()
103
+ self.drop_prob = drop_prob
104
+ self.scale_by_keep = scale_by_keep
105
+
106
+ def forward(self, x):
107
+ if self.drop_prob == 0.0 or not self.training:
108
+ return x
109
+ keep_prob = 1 - self.drop_prob
110
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
111
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
112
+ if keep_prob > 0.0 and self.scale_by_keep:
113
+ random_tensor.div_(keep_prob)
114
+ return x * random_tensor
115
+
116
+
117
+ # Lightly adapted from
118
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
119
+ class MLP(nn.Module):
120
+ def __init__(
121
+ self,
122
+ input_dim: int,
123
+ hidden_dim: int,
124
+ output_dim: int,
125
+ num_layers: int,
126
+ activation: nn.Module = nn.ReLU,
127
+ sigmoid_output: bool = False,
128
+ ) -> None:
129
+ super().__init__()
130
+ self.num_layers = num_layers
131
+ h = [hidden_dim] * (num_layers - 1)
132
+ self.layers = nn.ModuleList(
133
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
134
+ )
135
+ self.sigmoid_output = sigmoid_output
136
+ self.act = activation()
137
+
138
+ def forward(self, x):
139
+ for i, layer in enumerate(self.layers):
140
+ x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
141
+ if self.sigmoid_output:
142
+ x = F.sigmoid(x)
143
+ return x
144
+
145
+
146
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
147
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
148
+ class LayerNorm2d(nn.Module):
149
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
150
+ super().__init__()
151
+ self.weight = nn.Parameter(torch.ones(num_channels))
152
+ self.bias = nn.Parameter(torch.zeros(num_channels))
153
+ self.eps = eps
154
+
155
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
156
+ u = x.mean(1, keepdim=True)
157
+ s = (x - u).pow(2).mean(1, keepdim=True)
158
+ x = (x - u) / torch.sqrt(s + self.eps)
159
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
160
+ return x
prima/models/components/pose_transformer.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ from inspect import isfunction
11
+ from typing import Callable, Optional
12
+
13
+ import torch
14
+ from einops import rearrange
15
+ from einops.layers.torch import Rearrange
16
+ from torch import nn
17
+
18
+ from .t_cond_mlp import (
19
+ AdaptiveLayerNorm1D,
20
+ FrequencyEmbedder,
21
+ normalization_layer,
22
+ )
23
+
24
+
25
+ def exists(val):
26
+ return val is not None
27
+
28
+
29
+ def default(val, d):
30
+ if exists(val):
31
+ return val
32
+ return d() if isfunction(d) else d
33
+
34
+
35
+ class PreNorm(nn.Module):
36
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
37
+ super().__init__()
38
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
39
+ self.fn = fn
40
+
41
+ def forward(self, x: torch.Tensor, *args, **kwargs):
42
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
43
+ return self.fn(self.norm(x, *args), **kwargs)
44
+ else:
45
+ return self.fn(self.norm(x), **kwargs)
46
+
47
+
48
+ class FeedForward(nn.Module):
49
+ def __init__(self, dim, hidden_dim, dropout=0.0):
50
+ super().__init__()
51
+ self.net = nn.Sequential(
52
+ nn.Linear(dim, hidden_dim),
53
+ nn.GELU(),
54
+ nn.Dropout(dropout),
55
+ nn.Linear(hidden_dim, dim),
56
+ nn.Dropout(dropout),
57
+ )
58
+
59
+ def forward(self, x):
60
+ return self.net(x)
61
+
62
+
63
+ class Attention(nn.Module):
64
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
65
+ super().__init__()
66
+ inner_dim = dim_head * heads
67
+ project_out = not (heads == 1 and dim_head == dim)
68
+
69
+ self.heads = heads
70
+ self.scale = dim_head**-0.5
71
+
72
+ self.attend = nn.Softmax(dim=-1)
73
+ self.dropout = nn.Dropout(dropout)
74
+
75
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
76
+
77
+ self.to_out = (
78
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
79
+ if project_out
80
+ else nn.Identity()
81
+ )
82
+
83
+ def forward(self, x):
84
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
85
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
86
+
87
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
88
+
89
+ attn = self.attend(dots)
90
+ attn = self.dropout(attn)
91
+
92
+ out = torch.matmul(attn, v)
93
+ out = rearrange(out, "b h n d -> b n (h d)")
94
+ return self.to_out(out)
95
+
96
+
97
+ class CrossAttention(nn.Module):
98
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
99
+ super().__init__()
100
+ inner_dim = dim_head * heads
101
+ project_out = not (heads == 1 and dim_head == dim)
102
+
103
+ self.heads = heads
104
+ self.scale = dim_head**-0.5
105
+
106
+ self.attend = nn.Softmax(dim=-1)
107
+ self.dropout = nn.Dropout(dropout)
108
+
109
+ context_dim = default(context_dim, dim)
110
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
111
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
112
+
113
+ self.to_out = (
114
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
115
+ if project_out
116
+ else nn.Identity()
117
+ )
118
+
119
+ def forward(self, x, context=None):
120
+ context = default(context, x)
121
+ k, v = self.to_kv(context).chunk(2, dim=-1)
122
+ q = self.to_q(x)
123
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
124
+
125
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
126
+
127
+ attn = self.attend(dots)
128
+ attn = self.dropout(attn)
129
+
130
+ out = torch.matmul(attn, v)
131
+ out = rearrange(out, "b h n d -> b n (h d)")
132
+ return self.to_out(out)
133
+
134
+
135
+ class Transformer(nn.Module):
136
+ def __init__(
137
+ self,
138
+ dim: int,
139
+ depth: int,
140
+ heads: int,
141
+ dim_head: int,
142
+ mlp_dim: int,
143
+ dropout: float = 0.0,
144
+ norm: str = "layer",
145
+ norm_cond_dim: int = -1,
146
+ ):
147
+ super().__init__()
148
+ self.layers = nn.ModuleList([])
149
+ for _ in range(depth):
150
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
151
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
152
+ self.layers.append(
153
+ nn.ModuleList(
154
+ [
155
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
156
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
157
+ ]
158
+ )
159
+ )
160
+
161
+ def forward(self, x: torch.Tensor, *args):
162
+ for attn, ff in self.layers:
163
+ x = attn(x, *args) + x
164
+ x = ff(x, *args) + x
165
+ return x
166
+
167
+
168
+ class TransformerCrossAttn(nn.Module):
169
+ def __init__(
170
+ self,
171
+ dim: int,
172
+ depth: int,
173
+ heads: int,
174
+ dim_head: int,
175
+ mlp_dim: int,
176
+ dropout: float = 0.0,
177
+ norm: str = "layer",
178
+ norm_cond_dim: int = -1,
179
+ context_dim: Optional[int] = None,
180
+ ):
181
+ super().__init__()
182
+ self.layers = nn.ModuleList([])
183
+ for _ in range(depth):
184
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
185
+ ca = CrossAttention(
186
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
187
+ )
188
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
189
+ self.layers.append(
190
+ nn.ModuleList(
191
+ [
192
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
193
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
194
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
195
+ ]
196
+ )
197
+ )
198
+
199
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
200
+ if context_list is None:
201
+ context_list = [context] * len(self.layers)
202
+ if len(context_list) != len(self.layers):
203
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
204
+
205
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
206
+ x = self_attn(x, *args) + x
207
+ x = cross_attn(x, *args, context=context_list[i]) + x
208
+ x = ff(x, *args) + x
209
+ return x
210
+
211
+
212
+ class DropTokenDropout(nn.Module):
213
+ def __init__(self, p: float = 0.1):
214
+ super().__init__()
215
+ if p < 0 or p > 1:
216
+ raise ValueError(
217
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
218
+ )
219
+ self.p = p
220
+
221
+ def forward(self, x: torch.Tensor):
222
+ # x: (batch_size, seq_len, dim)
223
+ if self.training and self.p > 0:
224
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
225
+
226
+ if zero_mask.any():
227
+ x = x[:, ~zero_mask, :]
228
+ return x
229
+
230
+
231
+ class ZeroTokenDropout(nn.Module):
232
+ def __init__(self, p: float = 0.1):
233
+ super().__init__()
234
+ if p < 0 or p > 1:
235
+ raise ValueError(
236
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
237
+ )
238
+ self.p = p
239
+
240
+ def forward(self, x: torch.Tensor):
241
+ # x: (batch_size, seq_len, dim)
242
+ if self.training and self.p > 0:
243
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
244
+ # Zero-out the masked tokens
245
+ x[zero_mask, :] = 0
246
+ return x
247
+
248
+
249
+ class TransformerEncoder(nn.Module):
250
+ def __init__(
251
+ self,
252
+ num_tokens: int,
253
+ token_dim: int,
254
+ dim: int,
255
+ depth: int,
256
+ heads: int,
257
+ mlp_dim: int,
258
+ dim_head: int = 64,
259
+ dropout: float = 0.0,
260
+ emb_dropout: float = 0.0,
261
+ emb_dropout_type: str = "drop",
262
+ emb_dropout_loc: str = "token",
263
+ norm: str = "layer",
264
+ norm_cond_dim: int = -1,
265
+ token_pe_numfreq: int = -1,
266
+ ):
267
+ super().__init__()
268
+ if token_pe_numfreq > 0:
269
+ token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
270
+ self.to_token_embedding = nn.Sequential(
271
+ Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
272
+ FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
273
+ Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
274
+ nn.Linear(token_dim_new, dim),
275
+ )
276
+ else:
277
+ self.to_token_embedding = nn.Linear(token_dim, dim)
278
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
279
+ if emb_dropout_type == "drop":
280
+ self.dropout = DropTokenDropout(emb_dropout)
281
+ elif emb_dropout_type == "zero":
282
+ self.dropout = ZeroTokenDropout(emb_dropout)
283
+ else:
284
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
285
+ self.emb_dropout_loc = emb_dropout_loc
286
+
287
+ self.transformer = Transformer(
288
+ dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
289
+ )
290
+
291
+ def forward(self, inp: torch.Tensor, *args, **kwargs):
292
+ x = inp
293
+
294
+ if self.emb_dropout_loc == "input":
295
+ x = self.dropout(x)
296
+ x = self.to_token_embedding(x)
297
+
298
+ if self.emb_dropout_loc == "token":
299
+ x = self.dropout(x)
300
+ b, n, _ = x.shape
301
+ x += self.pos_embedding[:, :n]
302
+
303
+ if self.emb_dropout_loc == "token_afterpos":
304
+ x = self.dropout(x)
305
+ x = self.transformer(x, *args)
306
+ return x
307
+
308
+
309
+ class TransformerDecoder(nn.Module):
310
+ def __init__(
311
+ self,
312
+ num_tokens: int,
313
+ token_dim: int,
314
+ dim: int,
315
+ depth: int,
316
+ heads: int,
317
+ mlp_dim: int,
318
+ dim_head: int = 64,
319
+ dropout: float = 0.0,
320
+ emb_dropout: float = 0.0,
321
+ emb_dropout_type: str = 'drop',
322
+ norm: str = "layer",
323
+ norm_cond_dim: int = -1,
324
+ context_dim: Optional[int] = None,
325
+ skip_token_embedding: bool = False,
326
+ ):
327
+ super().__init__()
328
+ if not skip_token_embedding:
329
+ self.to_token_embedding = nn.Linear(token_dim, dim)
330
+ else:
331
+ self.to_token_embedding = nn.Identity()
332
+ if token_dim != dim:
333
+ raise ValueError(
334
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
335
+ )
336
+
337
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
338
+ if emb_dropout_type == "drop":
339
+ self.dropout = DropTokenDropout(emb_dropout)
340
+ elif emb_dropout_type == "zero":
341
+ self.dropout = ZeroTokenDropout(emb_dropout)
342
+ elif emb_dropout_type == "normal":
343
+ self.dropout = nn.Dropout(emb_dropout)
344
+
345
+ self.transformer = TransformerCrossAttn(
346
+ dim,
347
+ depth,
348
+ heads,
349
+ dim_head,
350
+ mlp_dim,
351
+ dropout,
352
+ norm=norm,
353
+ norm_cond_dim=norm_cond_dim,
354
+ context_dim=context_dim,
355
+ )
356
+
357
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
358
+ x = self.to_token_embedding(inp)
359
+ b, n, _ = x.shape
360
+
361
+ x = self.dropout(x)
362
+ x += self.pos_embedding[:, :n]
363
+
364
+ x = self.transformer(x, *args, context=context, context_list=context_list)
365
+ return x
366
+
prima/models/components/position_encoding.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
11
+ # All rights reserved.
12
+
13
+ # This source code is licensed under the license found in the
14
+ # LICENSE file in the root directory of this source tree.
15
+
16
+ import math
17
+ from typing import Any, Optional, Tuple
18
+
19
+ import numpy as np
20
+
21
+ import torch
22
+ from torch import nn
23
+
24
+ # Rotary Positional Encoding, adapted from:
25
+ # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
26
+ # 2. https://github.com/naver-ai/rope-vit
27
+ # 3. https://github.com/lucidrains/rotary-embedding-torch
28
+
29
+
30
+ def init_t_xy(end_x: int, end_y: int):
31
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
32
+ t_x = (t % end_x).float()
33
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
34
+ return t_x, t_y
35
+
36
+
37
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
38
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
39
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
40
+
41
+ t_x, t_y = init_t_xy(end_x, end_y)
42
+ freqs_x = torch.outer(t_x, freqs_x)
43
+ freqs_y = torch.outer(t_y, freqs_y)
44
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
45
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
46
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
47
+
48
+
49
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
50
+ ndim = x.ndim
51
+ assert 0 <= 1 < ndim
52
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
53
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
54
+ return freqs_cis.view(*shape)
55
+
56
+
57
+ def apply_rotary_enc(
58
+ xq: torch.Tensor,
59
+ xk: torch.Tensor,
60
+ freqs_cis: torch.Tensor,
61
+ repeat_freqs_k: bool = False,
62
+ ):
63
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
64
+ xk_ = (
65
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
66
+ if xk.shape[-2] != 0
67
+ else None
68
+ )
69
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
70
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
71
+ if xk_ is None:
72
+ # no keys to rotate, due to dropout
73
+ return xq_out.type_as(xq).to(xq.device), xk
74
+ # repeat freqs along seq_len dim to match k seq_len
75
+ if repeat_freqs_k:
76
+ r = xk_.shape[-2] // xq_.shape[-2]
77
+ if freqs_cis.is_cuda:
78
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
79
+ else:
80
+ # torch.repeat on complex numbers may not be supported on non-CUDA devices
81
+ # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
82
+ freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)
83
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
84
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
prima/models/components/t_cond_mlp.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
3
+
4
+ Official implementation of the paper:
5
+ "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
6
+ by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
7
+ Licensed under a modified MIT license
8
+ """
9
+
10
+ import copy
11
+ from typing import List, Optional
12
+
13
+ import torch
14
+
15
+
16
+ class AdaptiveLayerNorm1D(torch.nn.Module):
17
+ def __init__(self, data_dim: int, norm_cond_dim: int):
18
+ super().__init__()
19
+ if data_dim <= 0:
20
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
21
+ if norm_cond_dim <= 0:
22
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
23
+ self.norm = torch.nn.LayerNorm(data_dim)
24
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
25
+ torch.nn.init.zeros_(self.linear.weight)
26
+ torch.nn.init.zeros_(self.linear.bias)
27
+
28
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
29
+ # x: (batch, ..., data_dim)
30
+ # t: (batch, norm_cond_dim)
31
+ # return: (batch, data_dim)
32
+ x = self.norm(x)
33
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
34
+
35
+ # Add singleton dimensions to alpha and beta
36
+ if x.dim() > 2:
37
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
38
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
39
+
40
+ return x * (1 + alpha) + beta
41
+
42
+
43
+ class SequentialCond(torch.nn.Sequential):
44
+ def forward(self, input, *args, **kwargs):
45
+ for module in self:
46
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
47
+ input = module(input, *args, **kwargs)
48
+ else:
49
+ input = module(input)
50
+ return input
51
+
52
+
53
+ def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
54
+ if norm == "batch":
55
+ return torch.nn.BatchNorm1d(dim)
56
+ elif norm == "layer":
57
+ return torch.nn.LayerNorm(dim)
58
+ elif norm == "ada":
59
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
60
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
61
+ elif norm is None:
62
+ return torch.nn.Identity()
63
+ else:
64
+ raise ValueError(f"Unknown norm: {norm}")
65
+
66
+
67
+ def linear_norm_activ_dropout(
68
+ input_dim: int,
69
+ output_dim: int,
70
+ activation: torch.nn.Module = torch.nn.ReLU(),
71
+ bias: bool = True,
72
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
73
+ dropout: float = 0.0,
74
+ norm_cond_dim: int = -1,
75
+ ) -> SequentialCond:
76
+ layers = []
77
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
78
+ if norm is not None:
79
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
80
+ layers.append(copy.deepcopy(activation))
81
+ if dropout > 0.0:
82
+ layers.append(torch.nn.Dropout(dropout))
83
+ return SequentialCond(*layers)
84
+
85
+
86
+ def create_simple_mlp(
87
+ input_dim: int,
88
+ hidden_dims: List[int],
89
+ output_dim: int,
90
+ activation: torch.nn.Module = torch.nn.ReLU(),
91
+ bias: bool = True,
92
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
93
+ dropout: float = 0.0,
94
+ norm_cond_dim: int = -1,
95
+ ) -> SequentialCond:
96
+ layers = []
97
+ prev_dim = input_dim
98
+ for hidden_dim in hidden_dims:
99
+ layers.extend(
100
+ linear_norm_activ_dropout(
101
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
102
+ )
103
+ )
104
+ prev_dim = hidden_dim
105
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
106
+ return SequentialCond(*layers)
107
+
108
+
109
+ class ResidualMLPBlock(torch.nn.Module):
110
+ def __init__(
111
+ self,
112
+ input_dim: int,
113
+ hidden_dim: int,
114
+ num_hidden_layers: int,
115
+ output_dim: int,
116
+ activation: torch.nn.Module = torch.nn.ReLU(),
117
+ bias: bool = True,
118
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
119
+ dropout: float = 0.0,
120
+ norm_cond_dim: int = -1,
121
+ ):
122
+ super().__init__()
123
+ if not (input_dim == output_dim == hidden_dim):
124
+ raise NotImplementedError(
125
+ f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
126
+ )
127
+
128
+ layers = []
129
+ prev_dim = input_dim
130
+ for i in range(num_hidden_layers):
131
+ layers.append(
132
+ linear_norm_activ_dropout(
133
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
134
+ )
135
+ )
136
+ prev_dim = hidden_dim
137
+ self.model = SequentialCond(*layers)
138
+ self.skip = torch.nn.Identity()
139
+
140
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
141
+ return x + self.model(x, *args, **kwargs)
142
+
143
+
144
+ class ResidualMLP(torch.nn.Module):
145
+ def __init__(
146
+ self,
147
+ input_dim: int,
148
+ hidden_dim: int,
149
+ num_hidden_layers: int,
150
+ output_dim: int,
151
+ activation: torch.nn.Module = torch.nn.ReLU(),
152
+ bias: bool = True,
153
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
154
+ dropout: float = 0.0,
155
+ num_blocks: int = 1,
156
+ norm_cond_dim: int = -1,
157
+ ):
158
+ super().__init__()
159
+ self.input_dim = input_dim
160
+ self.model = SequentialCond(
161
+ linear_norm_activ_dropout(
162
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
163
+ ),
164
+ *[
165
+ ResidualMLPBlock(
166
+ hidden_dim,
167
+ hidden_dim,
168
+ num_hidden_layers,
169
+ hidden_dim,
170
+ activation,
171
+ bias,
172
+ norm,
173
+ dropout,
174
+ norm_cond_dim,
175
+ )
176
+ for _ in range(num_blocks)
177
+ ],
178
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
179
+ )
180
+
181
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
182
+ return self.model(x, *args, **kwargs)
183
+
184
+
185
+ class FrequencyEmbedder(torch.nn.Module):
186
+ def __init__(self, num_frequencies, max_freq_log2):
187
+ super().__init__()
188
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
189
+ self.register_buffer("frequencies", frequencies)
190
+
191
+ def forward(self, x):
192
+ # x should be of size (N,) or (N, D)
193
+ N = x.size(0)
194
+ if x.dim() == 1: # (N,)
195
+ x = x.unsqueeze(1) # (N, D) where D=1
196
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
197
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
198
+ s = torch.sin(scaled)
199
+ c = torch.cos(scaled)
200
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
201
+ N, -1
202
+ ) # (N, D * 2 * num_frequencies + D)
203
+ return embedded
204
+