HF Space deploy committed on
Commit ·
9d665dd
0
Parent(s):
Deploy snapshot (LFS for demo images per .gitattributes)
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +5 -0
- README.md +251 -0
- app.py +660 -0
- chumpy/__init__.py +16 -0
- chumpy/ch.py +52 -0
- configs/sa_finetune_hrnet_w32.yaml +220 -0
- demo_data/000000015956_horse.png +3 -0
- demo_data/000000315905_zebra.jpg +3 -0
- demo_data/beagle.jpg +3 -0
- demo_data/n02101388_1188.png +3 -0
- demo_data/n02412080_12159.png +3 -0
- demo_data/shepherd_hati.jpg +3 -0
- demo_tta.py +399 -0
- images/teaser.png +3 -0
- packages.txt +7 -0
- prima/__init__.py +25 -0
- prima/configs/__init__.py +99 -0
- prima/datasets/__init__.py +79 -0
- prima/datasets/datasets.py +278 -0
- prima/datasets/dlc2coco.py +362 -0
- prima/datasets/split_acinoset.py +153 -0
- prima/datasets/utils.py +1106 -0
- prima/datasets/vitdet_dataset.py +100 -0
- prima/models/__init__.py +54 -0
- prima/models/backbones/__init__.py +19 -0
- prima/models/backbones/vit.py +375 -0
- prima/models/bioclip_embedding.py +70 -0
- prima/models/components/__init__.py +0 -0
- prima/models/components/model_utils.py +160 -0
- prima/models/components/pose_transformer.py +366 -0
- prima/models/components/position_encoding.py +84 -0
- prima/models/components/t_cond_mlp.py +204 -0
- prima/models/components/transformer.py +400 -0
- prima/models/discriminator.py +129 -0
- prima/models/heads/__init__.py +1 -0
- prima/models/heads/classifier_head.py +30 -0
- prima/models/heads/smal_head.py +647 -0
- prima/models/losses.py +580 -0
- prima/models/prima.py +615 -0
- prima/models/smal_wrapper.py +134 -0
- prima/utils/__init__.py +45 -0
- prima/utils/detection.py +118 -0
- prima/utils/evaluate_metric.py +206 -0
- prima/utils/geometry.py +115 -0
- prima/utils/mesh_renderer.py +330 -0
- prima/utils/misc.py +211 -0
- prima/utils/pylogger.py +26 -0
- prima/utils/renderer.py +433 -0
- prima/utils/rich_utils.py +114 -0
- prima/utils/weights.py +337 -0
.gitattributes
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Hub stores these via Git LFS / Xet (plain PNG/JPG in git are rejected on push).
|
| 2 |
+
demo_data/*.png filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
demo_data/*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
demo_data/*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
images/*.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: PRIMA Demo
|
| 3 |
+
emoji: 🦮
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: gradio
|
| 7 |
+
python_version: "3.10"
|
| 8 |
+
app_file: app.py
|
| 9 |
+
startup_duration_timeout: 60m
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
This is the official implementation of the approach described in the preprint:
|
| 16 |
+
|
| 17 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation \
|
| 18 |
+
Xiaohang Yu, Ti Wang, Mackenzie Weygandt Mathis
|
| 19 |
+
|
| 20 |
+

|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
## 🚀 TL;DR
|
| 27 |
+
PRIMA creates a 3D quadruped mesh from a single 2D image. It leverages BioCLIP-based biological priors for robust cross-species shape understanding, then applies test-time adaptation with 2D reprojection and auxiliary keypoint guidance to refine SMAL pose and shape predictions.
|
| 28 |
+
|
| 29 |
+
It can further be used to build Quadruped3D, a large-scale pseudo-3D dataset with diverse species and poses.
|
| 30 |
+
|
| 31 |
+
PRIMA achieves state-of-the-art results on Animal3D, CtrlAni3D, Quadruped2D, and Animal Kingdom datasets.
|
| 32 |
+
|
| 33 |
+
## Installation
|
| 34 |
+
|
| 35 |
+
### Install from PyPI
|
| 36 |
+
|
| 37 |
+
> Recommended: Python 3.10 and a CUDA-enabled PyTorch installation.
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
conda create -n prima python=3.10 -y
|
| 41 |
+
conda activate prima
|
| 42 |
+
|
| 43 |
+
# Install PyTorch matching your CUDA (example: CUDA 11.8)
|
| 44 |
+
pip install --index-url https://download.pytorch.org/whl/cu118 \
|
| 45 |
+
"torch==2.2.1" "torchvision==0.17.1" "torchaudio==2.2.1"
|
| 46 |
+
|
| 47 |
+
# Install chumpy and PyTorch3D
|
| 48 |
+
python -m pip install --no-build-isolation \
|
| 49 |
+
"git+https://github.com/mattloper/chumpy.git"
|
| 50 |
+
python -m pip install --no-build-isolation \
|
| 51 |
+
"git+https://github.com/facebookresearch/pytorch3d.git"
|
| 52 |
+
|
| 53 |
+
# Install PRIMA from PyPI
|
| 54 |
+
pip install prima-animal
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
`prima-animal` includes demo runtime dependencies used by `demo.py`, `demo_tta.py`, and `app.py` (including Detectron2 and DeepLabCut).
|
| 58 |
+
|
| 59 |
+
### Clean install from this repository
|
| 60 |
+
|
| 61 |
+
Use these when developing from a **git clone** (not the PyPI wheel). The shell scripts are **non-interactive** (pip uses `--no-input`; `GIT_TERMINAL_PROMPT=0` for git). Put Hugging Face credentials in your environment or git credential helper before pushing the Space.
|
| 62 |
+
|
| 63 |
+
**Local (fresh venv, LFS assets, Hub demo weights, smoke test)** — requires **Python 3.10+**
|
| 64 |
+
(Gradio 5.1+ / Space-provided Gradio 6.x and `app.py` type hints). On macOS without `python3.10` on your `PATH`, install it with
|
| 65 |
+
`brew install python@3.10` and set `PRIMA_PYTHON=/opt/homebrew/bin/python3.10`.
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
chmod +x scripts/clean_install_local.sh scripts/clean_redeploy_hf_space.sh scripts/deploy_hf_space.sh
|
| 69 |
+
PRIMA_PYTHON=/opt/homebrew/bin/python3.10 ./scripts/clean_install_local.sh
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
Options:
|
| 73 |
+
|
| 74 |
+
- `PRIMA_VENV=.venv ./scripts/clean_install_local.sh --skip-data` — skip the large `setup_demo_data` download if `data/` is already populated.
|
| 75 |
+
- `./scripts/clean_install_local.sh --wipe-data --force-data` — delete downloaded `data/` assets and redownload.
|
| 76 |
+
- `./scripts/clean_install_local.sh --no-editable` — only `requirements.txt` (no `pip install -e .`); use if editable install fails and you will install the training stack via conda as in the PyPI section above. You still need **Python 3.10+** for Gradio 5.1+. The smoke test sets `PYTHONPATH` to the repo root so `import prima` works without an editable install.
|
| 77 |
+
- **macOS:** the script omits the `deeplabcut` line from `pip install` because DeepLabCut’s pinned PyTables version often does not build on Apple Silicon. Use conda/mamba for DeepLabCut if you need SuperAnimal + TTA (`tta_num_iters` > 0). **Linux** (including Hugging Face Space builds) uses the full `requirements.txt` including `deeplabcut`.
|
| 78 |
+
|
| 79 |
+
After `requirements.txt`, the script runs **`pip install --no-deps -e .`** so the `prima` package is registered without re-resolving `pyproject.toml` (which would pull **Detectron2** and **DeepLabCut** again and often fail on macOS). Full `pip install -e .` is still recommended from a **conda** environment per the PyPI section if you need every training extra matched exactly.
|
| 80 |
+
|
| 81 |
+
**Hugging Face Space (full redeploy from your working tree):**
|
| 82 |
+
|
| 83 |
+
Requires [Git LFS / Xet](https://huggingface.co/docs/hub/xet/using-xet-storage#git) tooling (`brew install git-lfs git-xet`, `git xet install`, `git lfs install`). Then:
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
./scripts/clean_redeploy_hf_space.sh
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
This is equivalent to `./scripts/deploy_hf_space.sh` and force-pushes a fresh snapshot to the Space.
|
| 90 |
+
|
| 91 |
+
---
|
| 92 |
+
|
| 93 |
+
## Demo
|
| 94 |
+
|
| 95 |
+
### Checkpoints and data
|
| 96 |
+
|
| 97 |
+
The demo scripts auto-download their default Stage 1 PRIMA assets from Hugging
|
| 98 |
+
Face when the checkpoint or matching Hydra config is missing. If you want to
|
| 99 |
+
pre-download all necessary checkpoints and data ahead of time, run:
|
| 100 |
+
|
| 101 |
+
```bash
|
| 102 |
+
python scripts/setup_demo_data.py --hf-repo-id MLAdaptiveIntelligence/PRIMA
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
Approximate default prefetch volume from Hugging Face is ~5.5 GB total
|
| 106 |
+
(`s1ckpt_inference.ckpt` ~3 GB + `amr_vitbb.pth` ~2.5 GB + SMAL files).
|
| 107 |
+
Expected time is roughly:
|
| 108 |
+
- 100 Mbps: ~7-10 minutes
|
| 109 |
+
- 300 Mbps: ~2-4 minutes
|
| 110 |
+
- 1 Gbps: ~1 minute
|
| 111 |
+
|
| 112 |
+
Existing files are reused by default; pass `--force` only if you need to redownload them. If you also need the Stage 3 pretrained model, add `--include-stage3`.
|
| 113 |
+
|
| 114 |
+
Expected files in that Hugging Face repo root:
|
| 115 |
+
- `my_smpl_00781_4_all.pkl`
|
| 116 |
+
- `my_smpl_data_00781_4_all.pkl`
|
| 117 |
+
- `walking_toy_symmetric_pose_prior_with_cov_35parts.pkl`
|
| 118 |
+
- `amr_vitbb.pth`
|
| 119 |
+
- `config_s1_HYDRA.yaml`
|
| 120 |
+
- `s1ckpt_inference.ckpt`
|
| 121 |
+
|
| 122 |
+
Optional Stage 3 prefetch expects:
|
| 123 |
+
- `config_s3_HYDRA.yaml`
|
| 124 |
+
- `s3ckpt_inference.ckpt`
|
| 125 |
+
|
| 126 |
+
### Demo (without TTA)
|
| 127 |
+
|
| 128 |
+
Run animal detection + PRIMA 3D pose/shape inference:
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
bash demo.sh
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
Outputs are written to `demo_out/`. Edit `demo.sh` if you want to use a custom
|
| 135 |
+
checkpoint path.
|
| 136 |
+
|
| 137 |
+
---
|
| 138 |
+
|
| 139 |
+
### Demo (with TTA)
|
| 140 |
+
|
| 141 |
+
Run PRIMA inference with test-time adaptation:
|
| 142 |
+
|
| 143 |
+
```bash
|
| 144 |
+
bash demo_tta.sh
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
Outputs are written to `demo_out_tta/` (before/after TTA renders, keypoints, and
|
| 148 |
+
optional meshes). Edit `demo_tta.sh` if you want to change the checkpoint, TTA
|
| 149 |
+
learning rate, or number of iterations.
|
| 150 |
+
|
| 151 |
+
---
|
| 152 |
+
|
| 153 |
+
### Gradio demo
|
| 154 |
+
|
| 155 |
+
We also provide a simple Gradio-based web demo for interactive testing in the
|
| 156 |
+
browser:
|
| 157 |
+
|
| 158 |
+
```bash
|
| 159 |
+
python app.py \
|
| 160 |
+
--checkpoint data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt \
|
| 161 |
+
--out_folder demo_out_tta_gradio/
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
This starts a local Gradio app (by default on http://127.0.0.1:7860), where
|
| 165 |
+
you can upload images and visualize PRIMA predictions and adaptation results.
|
| 166 |
+
The `s1ckpt_inference.ckpt` checkpoint is downloaded automatically if missing.
|
| 167 |
+
|
| 168 |
+
#### Hugging Face Space (maintainers)
|
| 169 |
+
|
| 170 |
+
Demo images under `demo_data/` and `images/teaser.png` are tracked with **Git LFS**
|
| 171 |
+
(see `.gitattributes`) so they can be pushed to a Hugging Face Space under the Hub’s
|
| 172 |
+
LFS / **Xet** bridge. Install tooling once:
|
| 173 |
+
|
| 174 |
+
```bash
|
| 175 |
+
brew install git-lfs git-xet
|
| 176 |
+
git xet install
|
| 177 |
+
git lfs install
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
Then from a clean checkout with LFS files present, redeploy the Space (same as `clean_redeploy_hf_space.sh`):
|
| 181 |
+
|
| 182 |
+
```bash
|
| 183 |
+
./scripts/deploy_hf_space.sh
|
| 184 |
+
# or
|
| 185 |
+
./scripts/clean_redeploy_hf_space.sh
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
The script rsyncs the working tree (not `git archive`) so image files are materialized
|
| 189 |
+
before `git add` turns them into LFS blobs.
|
| 190 |
+
|
| 191 |
+
---
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
## Training and Evaluation
|
| 195 |
+
|
| 196 |
+
### Dataset Setup
|
| 197 |
+
|
| 198 |
+
Download datasets from [Animal3D](https://xujiacong.github.io/Animal3D/), [CtrlAni3D](https://github.com/luoxue-star/AniMer?tab=readme-ov-file#training), Quadruped2D, and [Animal Kingdom](https://drive.google.com/file/d/1dk2a0qB0fbVZ4X6eAgP6VJVXj0rxVfsJ/view?usp=drive_link). For Quadruped2D, download the images from [SuperAnimal-Quadruped80K](https://zenodo.org/records/14016777) and our processed annotations from [here](https://drive.google.com/drive/folders/1eBNboxVwl_eGPoC93zxf-U3hmE6e2f-f?usp=sharing). Put all the datasets under `datasets/`.
|
| 199 |
+
|
| 200 |
+
### Training
|
| 201 |
+
|
| 202 |
+
Two-stage training script:
|
| 203 |
+
|
| 204 |
+
```bash
|
| 205 |
+
bash train.sh
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
Training outputs are written to `logs/train/runs/<exp_name>/`.
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
### Evaluation
|
| 212 |
+
|
| 213 |
+
```bash
|
| 214 |
+
python eval.py \
|
| 215 |
+
--config data/PRIMAS1/.hydra/config.yaml \
|
| 216 |
+
--checkpoint data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
Common values for `--dataset` are controlled by:
|
| 220 |
+
- `configs_hydra/experiment/default_val.yaml`
|
| 221 |
+
|
| 222 |
+
---
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
## Acknowledgements
|
| 226 |
+
|
| 227 |
+
This release builds on several open-source projects, including:
|
| 228 |
+
- [Detectron2](https://github.com/facebookresearch/detectron2)
|
| 229 |
+
- [BioCLIP](https://github.com/Imageomics/BioCLIP)
|
| 230 |
+
- [AniMer](https://github.com/luoxue-star/AniMer)
|
| 231 |
+
- [DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)
|
| 232 |
+
- [SAM3DB](https://github.com/facebookresearch/sam-3d-body)
|
| 233 |
+
|
| 234 |
+
---
|
| 235 |
+
|
| 236 |
+
## Citation
|
| 237 |
+
|
| 238 |
+
If you use this code in your research, please cite our PRIMA paper.
|
| 239 |
+
|
| 240 |
+
```bibtex
|
| 241 |
+
@misc{yu_prima,
|
| 242 |
+
title={PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation},
|
| 243 |
+
author={Xiaohang Yu and Ti Wang and Mackenzie Weygandt Mathis},
|
| 244 |
+
}
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
---
|
| 248 |
+
|
| 249 |
+
## Contact
|
| 250 |
+
|
| 251 |
+
For issues, please open a GitHub issue in this repository.
|
app.py
ADDED
|
@@ -0,0 +1,660 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
"""Gradio demo for PRIMA + SuperAnimal + TTA.
|
| 11 |
+
|
| 12 |
+
This script wraps the ``demo_tta.py`` pipeline into an interactive
|
| 13 |
+
Gradio interface. The overall logic follows:
|
| 14 |
+
|
| 15 |
+
1. Given an input image, run Detectron2 to detect animals.
|
| 16 |
+
2. For each detected animal, run PRIMA for 3D pose/shape estimation.
|
| 17 |
+
3. Run the fine-tuned DeepLabCut SuperAnimal model to obtain PRIMA 26-keypoint
|
| 18 |
+
2D predictions.
|
| 19 |
+
4. Run test-time adaptation (TTA) with user-specified lr and iters.
|
| 20 |
+
5. Render and save before/after TTA results and keypoint visualizations.
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import concurrent.futures
|
| 26 |
+
import os
|
| 27 |
+
import queue
|
| 28 |
+
import sys
|
| 29 |
+
import tempfile
|
| 30 |
+
import time
|
| 31 |
+
import traceback
|
| 32 |
+
from types import SimpleNamespace
|
| 33 |
+
from typing import Callable, List, Tuple
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
import cv2
|
| 37 |
+
import gradio as gr
|
| 38 |
+
import numpy as np
|
| 39 |
+
import torch
|
| 40 |
+
import torch.utils.data
|
| 41 |
+
|
| 42 |
+
# The repository ships a minimal ``chumpy`` shim (see ``chumpy/__init__.py``).
# Putting the repo root at the front of ``sys.path`` lets SMAL pickles load
# without installing the full chumpy package in Space builds.
_REPO_ROOT = Path(__file__).resolve().parents[0]
if not any(entry == str(_REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_REPO_ROOT))
|
| 47 |
+
|
| 48 |
+
from prima.utils.weights import (
|
| 49 |
+
DEFAULT_HF_REPO_ID,
|
| 50 |
+
resolve_prima_checkpoint_path,
|
| 51 |
+
)
|
| 52 |
+
from prima.utils.detection import select_animal_boxes
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Default Stage 1 checkpoint location, following the layout given in the README.
DEFAULT_CHECKPOINT = str(
    _REPO_ROOT.joinpath("data", "PRIMAS1", "checkpoints", "s1ckpt_inference.ckpt")
)
# Hugging Face repo used for demo assets unless ``PRIMA_HF_REPO_ID`` overrides it.
DEFAULT_HF_ASSET_REPO = DEFAULT_HF_REPO_ID

# Folder that receives rendered images, meshes and keypoint visualizations.
DEFAULT_OUT_FOLDER = "demo_out_tta_gradio"
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _is_truthy_env(var_name: str) -> bool:
    """Report whether an environment variable holds an affirmative flag.

    The value is stripped of surrounding whitespace and compared
    case-insensitively against ``1``, ``true``, ``yes`` and ``on``.
    An unset variable counts as false.
    """
    raw_value = os.environ.get(var_name, "")
    normalized = raw_value.strip().lower()
    return normalized in {"1", "true", "yes", "on"}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _running_on_space() -> bool:
    """Return True when ``SPACE_ID`` or ``HF_SPACE_ID`` is set to a non-empty value."""
    space_markers = ("SPACE_ID", "HF_SPACE_ID")
    return any(os.environ.get(marker) for marker in space_markers)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _gradio_examples_for_interface() -> List[List]:
    """Build the example rows offered in the Gradio interface.

    Gradio prefetches example media at startup, so only demo images that
    actually exist beside ``app.py`` are listed, each as an absolute path.
    The images are tracked with Git LFS / Xet (see ``.gitattributes``) so they
    can live in the Hugging Face Space repo.

    Returns:
        One row per available demo image: the absolute image path followed by
        six preset values. The list is empty when
        ``PRIMA_DISABLE_GRADIO_EXAMPLES`` is truthy or no demo image is found.
    """
    if _is_truthy_env("PRIMA_DISABLE_GRADIO_EXAMPLES"):
        return []

    # Every example shares the same presets.
    # NOTE(review): presumably (tta_lr, tta_num_iters, det_thresh,
    # kp_conf_thresh, side_view, save_mesh) -- confirm against the UI inputs.
    presets: Tuple[float, int, float, float, bool, bool] = (1e-6, 0, 0.7, 0.1, False, True)
    relative_images: Tuple[str, ...] = (
        "demo_data/000000015956_horse.png",
        "demo_data/n02412080_12159.png",
        "demo_data/000000315905_zebra.jpg",
        "demo_data/beagle.jpg",
        "demo_data/shepherd_hati.jpg",
    )
    candidates = (_REPO_ROOT / rel for rel in relative_images)
    return [[str(image), *presets] for image in candidates if image.is_file()]
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _should_preload_assets() -> bool:
    """Decide whether demo assets should be fetched at startup.

    An explicitly set ``PRIMA_PRELOAD_ASSETS`` always wins (interpreted as a
    truthy/falsy flag); otherwise preloading is enabled only on Spaces.
    """
    if "PRIMA_PRELOAD_ASSETS" in os.environ:
        return _is_truthy_env("PRIMA_PRELOAD_ASSETS")
    return _running_on_space()
|
| 100 |
+
|
| 101 |
+
def _gradio_heartbeat_interval_sec() -> float:
    """Return how often, in seconds, to yield status during long CPU/GPU work.

    Periodic status updates keep WebSockets alive while the app waits on long
    work. The interval is read from ``PRIMA_GRADIO_HEARTBEAT_SEC`` (default
    ``25``). Set it to ``0`` to run long work on the Gradio thread (old
    behavior).

    Returns:
        The configured interval, clamped to be non-negative. Values that
        cannot be parsed as a number, and non-finite values (``nan``, ``inf``,
        ``-inf``), fall back to the 25-second default.
    """
    import math  # function-scope import, matching the lazy imports used elsewhere in this file

    default_interval = 25.0
    raw = os.environ.get("PRIMA_GRADIO_HEARTBEAT_SEC", "25").strip()
    try:
        interval = float(raw)
    except ValueError:
        return default_interval
    # ``float()`` happily parses "nan"/"inf"; neither is a usable polling
    # interval, so treat them like any other invalid setting.
    if not math.isfinite(interval):
        return default_interval
    return max(0.0, interval)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _preload_assets_once(checkpoint_path: str) -> None:
    """Resolve the demo checkpoint up front, downloading it if necessary.

    Calls ``resolve_prima_checkpoint_path`` with auto-download enabled, using
    the repo's ``data`` directory and the Hub repo named by
    ``PRIMA_HF_REPO_ID`` (falling back to ``DEFAULT_HF_ASSET_REPO``). The
    resolved path is discarded; progress is printed before and after.
    """
    print("[startup] Ensuring demo assets from Hugging Face Hub...")
    resolve_kwargs = {
        "data_dir": _REPO_ROOT / "data",
        "auto_download": True,
        "hf_repo_id": os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO),
    }
    resolve_prima_checkpoint_path(checkpoint_path, **resolve_kwargs)
    print("[startup] Asset preload complete.")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _load_prima_model(checkpoint_path: str = DEFAULT_CHECKPOINT):
    """Load the PRIMA model and its renderer once for the Gradio app.

    The checkpoint path is first resolved through
    ``resolve_prima_checkpoint_path`` (auto-download enabled). The Hydra config
    is expected at ``<checkpoint dir>/../.hydra/config.yaml``.

    Args:
        checkpoint_path: Requested checkpoint location.

    Returns:
        ``(model, model_cfg, renderer, cam_crop_to_full, device)`` where the
        model has been moved to ``device`` (CUDA when available, else CPU) and
        put in eval mode.

    Raises:
        FileNotFoundError: If the resolved checkpoint or its Hydra config is
            missing.
    """
    # Imported lazily so that importing this module does not pull in the model stack.
    from prima.models import load_prima
    from prima.utils.renderer import Renderer, cam_crop_to_full

    hub_repo = os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO)
    checkpoint_path = resolve_prima_checkpoint_path(
        checkpoint_path,
        data_dir=_REPO_ROOT / "data",
        auto_download=True,
        hf_repo_id=hub_repo,
    )

    ckpt_file = Path(checkpoint_path)
    hydra_cfg = ckpt_file.parents[1] / ".hydra" / "config.yaml"
    if not ckpt_file.exists():
        raise FileNotFoundError(
            f"Missing checkpoint: {ckpt_file}. Download demo checkpoints/data as described in README."
        )
    if not hydra_cfg.exists():
        raise FileNotFoundError(
            f"Missing model config: {hydra_cfg}. Ensure the full checkpoint folder layout from README is present."
        )

    model, model_cfg = load_prima(checkpoint_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    mesh_renderer = Renderer(model_cfg, faces=model.smal.faces)
    return model, model_cfg, mesh_renderer, cam_crop_to_full, device
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _build_detector():
    """Build the Detectron2 animal detector (same config as demo_tta/demo.py).

    Returns:
        A ``DefaultPredictor`` for the COCO Faster R-CNN X-101 32x8d FPN 3x
        model with a 0.5 test score threshold, placed on CUDA when available
        and on CPU otherwise. Returns ``None`` (after printing a warning) when
        Detectron2 cannot be imported, so callers can fall back to a
        full-image bounding box.
    """
    try:
        from detectron2 import model_zoo
        from detectron2.config import get_cfg
        from detectron2.engine import DefaultPredictor
    except Exception as exc:
        # Best-effort: any import-time failure disables detection rather than crashing the app.
        print(f"[warn] Detectron2 unavailable ({type(exc).__name__}: {exc}); using full-image fallback bbox.")
        return None

    config_name = "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"
    weights_url = (
        "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/"
        "faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
    )

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.WEIGHTS = weights_url
    if torch.cuda.is_available():
        cfg.MODEL.DEVICE = "cuda"
    else:
        cfg.MODEL.DEVICE = "cpu"
    return DefaultPredictor(cfg)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _load_model_and_detector_for_demo(checkpoint_path: str):
    """Load the PRIMA model bundle and then the detector in one call.

    Intended as the single entry point submitted to a worker thread when
    heartbeat polling is used.

    Returns:
        ``(model, model_cfg, renderer, cam_crop_to_full_fn, device, detector)``;
        ``detector`` is whatever ``_build_detector`` returns.
    """
    prima_bundle = _load_prima_model(checkpoint_path)
    model, model_cfg, renderer, cam_crop_to_full_fn, device = prima_bundle
    animal_detector = _build_detector()
    return model, model_cfg, renderer, cam_crop_to_full_fn, device, animal_detector
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# SuperAnimal settings, mirroring the defaults of the demo_tta argument parser.
# ``saved_2d_model_path`` starts empty; ``_collect_animal_results`` fills it in
# the first time TTA is requested.
SUPER_ANIMAL_ARGS = SimpleNamespace(
    **{
        "superanimal_name": "superanimal_quadruped",
        "superanimal_model_name": "hrnet_w32",
        "superanimal_detector_name": "fasterrcnn_resnet50_fpn_v2",
        "superanimal_max_individuals": 1,
        "saved_2d_model_path": "",
        "pytorch_config_2d_path": str(_REPO_ROOT / "configs" / "sa_finetune_hrnet_w32.yaml"),
    }
)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _collect_animal_results(
|
| 199 |
+
model,
|
| 200 |
+
model_cfg,
|
| 201 |
+
renderer,
|
| 202 |
+
cam_crop_to_full_fn,
|
| 203 |
+
device,
|
| 204 |
+
detector,
|
| 205 |
+
out_folder: str,
|
| 206 |
+
img_rgb: np.ndarray,
|
| 207 |
+
tta_lr: float,
|
| 208 |
+
tta_num_iters: int,
|
| 209 |
+
det_thresh: float,
|
| 210 |
+
kp_conf_thresh: float,
|
| 211 |
+
side_view: bool,
|
| 212 |
+
save_mesh: bool,
|
| 213 |
+
progress_callback: Callable[[str], None] | None = None,
|
| 214 |
+
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], str | None, str | None]:
|
| 215 |
+
"""Run detection + PRIMA + SuperAnimal + TTA on a single RGB image.
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
before_imgs: list of HxWx3 RGB images (before TTA) for all animals
|
| 219 |
+
after_imgs: list of HxWx3 RGB images (after TTA) for all animals
|
| 220 |
+
kpt_imgs: list of HxWx3 RGB keypoint visualizations
|
| 221 |
+
first_before_mesh: path to first animal's before-TTA mesh (.obj) or None
|
| 222 |
+
first_after_mesh: path to first animal's after-TTA mesh (.obj) or None
|
| 223 |
+
"""
|
| 224 |
+
from prima.utils import recursive_to
|
| 225 |
+
from prima.datasets.vitdet_dataset import ViTDetDataset
|
| 226 |
+
from demo_tta import (
|
| 227 |
+
denorm_patch_to_rgb,
|
| 228 |
+
resolve_sa_weights_path,
|
| 229 |
+
run_superanimal_on_patch,
|
| 230 |
+
save_keypoint_vis,
|
| 231 |
+
tta_optimize,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
def report(message: str) -> None:
|
| 235 |
+
if progress_callback is not None:
|
| 236 |
+
progress_callback(message)
|
| 237 |
+
|
| 238 |
+
if int(tta_num_iters) > 0 and not SUPER_ANIMAL_ARGS.saved_2d_model_path:
|
| 239 |
+
report("Resolving SuperAnimal weights...")
|
| 240 |
+
SUPER_ANIMAL_ARGS.saved_2d_model_path = resolve_sa_weights_path("")
|
| 241 |
+
|
| 242 |
+
# Detect animals
|
| 243 |
+
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
| 244 |
+
if detector is None:
|
| 245 |
+
# Fallback for environments where Detectron2 is unavailable: process full image as one crop.
|
| 246 |
+
report("Detectron2 unavailable; using full-image crop...")
|
| 247 |
+
h, w = img_bgr.shape[:2]
|
| 248 |
+
boxes = np.array([[0.0, 0.0, float(max(1, w - 1)), float(max(1, h - 1))]], dtype=np.float32)
|
| 249 |
+
else:
|
| 250 |
+
report("Detecting animals with Detectron2...")
|
| 251 |
+
det_out = detector(img_bgr)
|
| 252 |
+
det_instances = det_out["instances"]
|
| 253 |
+
|
| 254 |
+
boxes, suppressed = select_animal_boxes(det_instances, score_threshold=float(det_thresh))
|
| 255 |
+
if suppressed > 0:
|
| 256 |
+
print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s)")
|
| 257 |
+
if len(boxes) == 0:
|
| 258 |
+
return [], [], [], None, None
|
| 259 |
+
|
| 260 |
+
report(f"Detected {len(boxes)} animal(s). Preparing crops...")
|
| 261 |
+
dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
|
| 262 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
|
| 263 |
+
|
| 264 |
+
before_imgs: List[np.ndarray] = []
|
| 265 |
+
after_imgs: List[np.ndarray] = []
|
| 266 |
+
kpt_imgs: List[np.ndarray] = []
|
| 267 |
+
before_mesh_paths: List[str] = []
|
| 268 |
+
after_mesh_paths: List[str] = []
|
| 269 |
+
|
| 270 |
+
img_token = next(tempfile._get_candidate_names())
|
| 271 |
+
|
| 272 |
+
total_batches = len(dataloader)
|
| 273 |
+
for batch_idx, batch in enumerate(dataloader, start=1):
|
| 274 |
+
batch = recursive_to(batch, device)
|
| 275 |
+
|
| 276 |
+
report(f"Animal {batch_idx}/{total_batches}: running PRIMA...")
|
| 277 |
+
with torch.no_grad():
|
| 278 |
+
out_before = model(batch)
|
| 279 |
+
|
| 280 |
+
animal_id = int(batch["animalid"][0])
|
| 281 |
+
|
| 282 |
+
# Save/render before TTA
|
| 283 |
+
img_fn = f"{img_token}"
|
| 284 |
+
from demo_tta import render_and_save # imported lazily to avoid circular issues
|
| 285 |
+
|
| 286 |
+
report(f"Animal {batch_idx}/{total_batches}: rendering before TTA...")
|
| 287 |
+
render_and_save(
|
| 288 |
+
renderer,
|
| 289 |
+
cam_crop_to_full_fn,
|
| 290 |
+
out_before,
|
| 291 |
+
batch,
|
| 292 |
+
img_fn,
|
| 293 |
+
animal_id,
|
| 294 |
+
out_folder,
|
| 295 |
+
suffix="before_tta",
|
| 296 |
+
side_view=side_view,
|
| 297 |
+
save_mesh=save_mesh,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
before_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.png")
|
| 301 |
+
if os.path.exists(before_png_path):
|
| 302 |
+
before_bgr = cv2.imread(before_png_path)
|
| 303 |
+
if before_bgr is not None:
|
| 304 |
+
before_imgs.append(cv2.cvtColor(before_bgr, cv2.COLOR_BGR2RGB))
|
| 305 |
+
|
| 306 |
+
if save_mesh:
|
| 307 |
+
before_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.obj")
|
| 308 |
+
if os.path.exists(before_obj_path):
|
| 309 |
+
before_mesh_paths.append(before_obj_path)
|
| 310 |
+
|
| 311 |
+
if int(tta_num_iters) <= 0:
|
| 312 |
+
report(f"Animal {batch_idx}/{total_batches}: rendering final output...")
|
| 313 |
+
render_and_save(
|
| 314 |
+
renderer,
|
| 315 |
+
cam_crop_to_full_fn,
|
| 316 |
+
out_before,
|
| 317 |
+
batch,
|
| 318 |
+
img_fn,
|
| 319 |
+
animal_id,
|
| 320 |
+
out_folder,
|
| 321 |
+
suffix="after_tta",
|
| 322 |
+
side_view=side_view,
|
| 323 |
+
save_mesh=save_mesh,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
|
| 327 |
+
if os.path.exists(after_png_path):
|
| 328 |
+
after_bgr = cv2.imread(after_png_path)
|
| 329 |
+
if after_bgr is not None:
|
| 330 |
+
after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB))
|
| 331 |
+
|
| 332 |
+
if save_mesh:
|
| 333 |
+
after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
|
| 334 |
+
if os.path.exists(after_obj_path):
|
| 335 |
+
after_mesh_paths.append(after_obj_path)
|
| 336 |
+
continue
|
| 337 |
+
|
| 338 |
+
# Prepare patch for SuperAnimal
|
| 339 |
+
report(f"Animal {batch_idx}/{total_batches}: running SuperAnimal keypoints...")
|
| 340 |
+
patch_rgb = denorm_patch_to_rgb(batch["img"][0])
|
| 341 |
+
with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
|
| 342 |
+
bodyparts_xyc = run_superanimal_on_patch(patch_rgb, SUPER_ANIMAL_ARGS, tmp_dir)
|
| 343 |
+
|
| 344 |
+
if bodyparts_xyc is None:
|
| 345 |
+
# No keypoints => skip TTA for this animal
|
| 346 |
+
continue
|
| 347 |
+
|
| 348 |
+
kpts_xyc = bodyparts_xyc
|
| 349 |
+
kpts_xyc[kpts_xyc[:, 2] < float(kp_conf_thresh), 2] = 0.0
|
| 350 |
+
|
| 351 |
+
# Save keypoint visualization and npy
|
| 352 |
+
kpt_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png")
|
| 353 |
+
save_keypoint_vis(patch_rgb, kpts_xyc, kpt_png_path)
|
| 354 |
+
npy_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy")
|
| 355 |
+
np.save(npy_path, kpts_xyc)
|
| 356 |
+
|
| 357 |
+
if os.path.exists(kpt_png_path):
|
| 358 |
+
kpt_bgr = cv2.imread(kpt_png_path)
|
| 359 |
+
if kpt_bgr is not None:
|
| 360 |
+
kpt_imgs.append(cv2.cvtColor(kpt_bgr, cv2.COLOR_BGR2RGB))
|
| 361 |
+
|
| 362 |
+
# Normalize keypoints to [-0.5, 0.5] as in demo_tta
|
| 363 |
+
patch_h, patch_w = patch_rgb.shape[:2]
|
| 364 |
+
kpts_norm = kpts_xyc.copy()
|
| 365 |
+
kpts_norm[:, 0] = kpts_norm[:, 0] / float(patch_w) - 0.5
|
| 366 |
+
kpts_norm[:, 1] = kpts_norm[:, 1] / float(patch_h) - 0.5
|
| 367 |
+
gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch["img"].dtype)
|
| 368 |
+
|
| 369 |
+
# Run TTA
|
| 370 |
+
report(f"Animal {batch_idx}/{total_batches}: running TTA ({int(tta_num_iters)} iterations)...")
|
| 371 |
+
out_after = tta_optimize(
|
| 372 |
+
model,
|
| 373 |
+
batch,
|
| 374 |
+
gt_kpts_norm,
|
| 375 |
+
num_iters=int(tta_num_iters),
|
| 376 |
+
lr=float(tta_lr),
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
report(f"Animal {batch_idx}/{total_batches}: rendering after TTA...")
|
| 380 |
+
render_and_save(
|
| 381 |
+
renderer,
|
| 382 |
+
cam_crop_to_full_fn,
|
| 383 |
+
out_after,
|
| 384 |
+
batch,
|
| 385 |
+
img_fn,
|
| 386 |
+
animal_id,
|
| 387 |
+
out_folder,
|
| 388 |
+
suffix="after_tta",
|
| 389 |
+
side_view=side_view,
|
| 390 |
+
save_mesh=save_mesh,
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
|
| 394 |
+
if os.path.exists(after_png_path):
|
| 395 |
+
after_bgr = cv2.imread(after_png_path)
|
| 396 |
+
if after_bgr is not None:
|
| 397 |
+
after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB))
|
| 398 |
+
|
| 399 |
+
if save_mesh:
|
| 400 |
+
after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
|
| 401 |
+
if os.path.exists(after_obj_path):
|
| 402 |
+
after_mesh_paths.append(after_obj_path)
|
| 403 |
+
|
| 404 |
+
first_before_mesh = before_mesh_paths[0] if before_mesh_paths else None
|
| 405 |
+
first_after_mesh = after_mesh_paths[0] if after_mesh_paths else None
|
| 406 |
+
|
| 407 |
+
report("Collecting outputs...")
|
| 408 |
+
return before_imgs, after_imgs, kpt_imgs, first_before_mesh, first_after_mesh
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def build_demo(checkpoint_path: str = DEFAULT_CHECKPOINT, out_folder: str = DEFAULT_OUT_FOLDER) -> gr.Interface:
    """Build the queued Gradio interface for the PRIMA + SuperAnimal + TTA demo.

    The model, renderer and detector are not loaded here; the request handler
    loads them on its first call and keeps them in a closure-local cache.

    Args:
        checkpoint_path: Forwarded to ``_load_model_and_detector_for_demo`` on
            the first request.
        out_folder: Directory (created if missing) passed to
            ``_collect_animal_results`` for rendered outputs.

    Returns:
        The configured ``gr.Interface`` with its request queue enabled.
    """
    os.makedirs(out_folder, exist_ok=True)
    # Filled on the first request and reused by every later one.
    runtime_cache = {
        "model": None,
        "model_cfg": None,
        "renderer": None,
        "cam_crop_to_full_fn": None,
        "device": None,
        "detector": None,
    }

    def gradio_inference(
        image: np.ndarray,
        tta_lr: float,
        tta_num_iters: int,
        det_thresh: float,
        kp_conf_thresh: float,
        side_view: bool,
        save_mesh: bool,
    ):
        """Wrapper for Gradio. ``image`` is an RGB numpy array.

        Yields intermediate status so long first-run (Hub downloads + model load)
        and long inference do not hit silent client/proxy WebSocket timeouts.

        Every yielded item is ``(before_image, after_image, keypoint_image,
        status_text)``; the three images are ``None`` until the final yield.
        """

        if image is None:
            yield None, None, None, "No image provided."
            return

        if image.dtype != np.uint8:
            img_rgb = np.clip(image, 0, 255).astype(np.uint8)
        else:
            img_rgb = image

        yield None, None, None, "Queued; preparing run…"

        # hb <= 0 disables heartbeats: the work then runs inline in this thread.
        hb = _gradio_heartbeat_interval_sec()

        if runtime_cache["model"] is None:
            yield (
                None,
                None,
                None,
                "First run: downloading demo assets from Hugging Face (large checkpoint) "
                "and loading the model. This can take many minutes; status updates here "
                "mean the session is still alive.",
            )
            try:
                if hb <= 0:
                    model, model_cfg, renderer, cam_crop_to_full_fn, device, detector = _load_model_and_detector_for_demo(
                        checkpoint_path
                    )
                else:
                    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                        fut = pool.submit(_load_model_and_detector_for_demo, checkpoint_path)
                        t0 = time.monotonic()
                        # Poll with wait() instead of result(timeout=...): on
                        # Python 3.11+ concurrent.futures.TimeoutError is the
                        # builtin TimeoutError, so a TimeoutError raised by the
                        # worker itself would be mistaken for "not finished
                        # yet" and the loop would spin forever.
                        while not concurrent.futures.wait([fut], timeout=hb).done:
                            elapsed = int(time.monotonic() - t0)
                            yield None, None, None, (
                                f"First run: still loading model and assets ({elapsed}s). "
                                f"Updates every ~{int(hb)}s keep the browser connection open on Spaces."
                            )
                        # The future is done: result() returns at once or
                        # re-raises the worker's exception for the handler below.
                        model, model_cfg, renderer, cam_crop_to_full_fn, device, detector = fut.result()
            except Exception:
                yield None, None, None, f"Model initialization failed:\n{traceback.format_exc()}"
                return
            runtime_cache["model"] = model
            runtime_cache["model_cfg"] = model_cfg
            runtime_cache["renderer"] = renderer
            runtime_cache["cam_crop_to_full_fn"] = cam_crop_to_full_fn
            runtime_cache["device"] = device
            runtime_cache["detector"] = detector
            yield None, None, None, "Model loaded. Running detection and inference…"

        # Leading positional arguments shared by the inline and threaded calls.
        base_args = (
            runtime_cache["model"],
            runtime_cache["model_cfg"],
            runtime_cache["renderer"],
            runtime_cache["cam_crop_to_full_fn"],
            runtime_cache["device"],
            runtime_cache["detector"],
            out_folder,
            img_rgb,
        )

        try:
            if hb <= 0:
                # Mesh paths are returned too, but this interface has no
                # component that displays them.
                before_imgs, after_imgs, kpt_imgs, _, _ = _collect_animal_results(
                    *base_args,
                    tta_lr=tta_lr,
                    tta_num_iters=tta_num_iters,
                    det_thresh=det_thresh,
                    kp_conf_thresh=kp_conf_thresh,
                    side_view=side_view,
                    save_mesh=save_mesh,
                )
            else:
                # The worker posts stage messages here; this generator relays them.
                stage_updates: queue.Queue[str] = queue.Queue()

                def report_stage(message: str) -> None:
                    stage_updates.put(message)

                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                    fut = pool.submit(
                        _collect_animal_results,
                        *base_args,
                        tta_lr,
                        tta_num_iters,
                        det_thresh,
                        kp_conf_thresh,
                        side_view,
                        save_mesh,
                        report_stage,
                    )
                    t0 = time.monotonic()
                    latest_stage = "Starting inference..."
                    while True:
                        # Relay every stage message posted since the last poll.
                        while True:
                            try:
                                latest_stage = stage_updates.get_nowait()
                            except queue.Empty:
                                break
                            elapsed = int(time.monotonic() - t0)
                            yield None, None, None, f"{latest_stage}\nElapsed: {elapsed}s"
                        # Same reasoning as the model-load loop: wait() keeps a
                        # worker-raised TimeoutError from looking like a poll timeout.
                        if concurrent.futures.wait([fut], timeout=1.0).done:
                            break
                        elapsed = int(time.monotonic() - t0)
                        yield None, None, None, (
                            f"{latest_stage}\n"
                            f"Elapsed: {elapsed}s\n"
                            "CPU inference can take several minutes."
                        )
                    before_imgs, after_imgs, kpt_imgs, _, _ = fut.result()
        except Exception:
            yield None, None, None, f"Inference failed:\n{traceback.format_exc()}"
            return

        # Only the first animal's results are shown in the UI.
        first_before = before_imgs[0] if before_imgs else None
        first_after = after_imgs[0] if after_imgs else None
        first_kpts = kpt_imgs[0] if kpt_imgs else None
        if first_before is None and first_after is None:
            yield (
                None,
                None,
                None,
                "No output generated. Try an image with a clearly visible quadruped.",
            )
            return
        yield first_before, first_after, first_kpts, "OK"

    _gradio_examples = _gradio_examples_for_interface()
    _iface_kw = dict(
        fn=gradio_inference,
        analytics_enabled=False,
        cache_examples=False,
        inputs=[
            gr.Image(
                label="Input image",
                type="numpy",
                sources=["upload", "clipboard"],
            ),
            gr.Slider(
                label="TTA learning rate",
                minimum=1e-7,
                maximum=1e-4,
                value=1e-6,
                step=1e-7,
            ),
            gr.Slider(
                label="TTA iterations",
                minimum=0,
                maximum=100,
                value=0,
                step=1,
                info="Set to 0 to disable TTA and reuse the initial PRIMA prediction.",
            ),
            gr.Slider(
                label="Detection threshold",
                minimum=0.3,
                maximum=0.9,
                value=0.7,
                step=0.05,
            ),
            gr.Slider(
                label="Keypoint confidence threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.1,
                step=0.05,
            ),
            gr.Checkbox(label="Render side view", value=False),
            gr.Checkbox(label="Save meshes (.obj)", value=True),
        ],
        outputs=[
            gr.Image(label="Before TTA"),
            gr.Image(label="After TTA"),
            gr.Image(label="PRIMA 26 keypoints"),
            gr.Textbox(label="Status / Traceback", lines=12),
        ],
        title="PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation",
        description=(
            "Upload an animal image. The demo runs Detectron2 for animal detection, "
            "PRIMA for 3D pose/shape, DeepLabCut SuperAnimal for 2D keypoints, and "
            "test-time adaptation (TTA) with configurable learning rate and iterations. "
            "Set TTA iterations to 0 to disable adaptation.\n\n"
            "Results (PNG/OBJ and 26-keypoint visualizations) are saved under "
            f"'{out_folder}'."
        ),
    )
    # Only pass ``examples`` when there is something to show.
    if _gradio_examples:
        _iface_kw["examples"] = _gradio_examples
    demo = gr.Interface(**_iface_kw)
    demo.queue(max_size=8, default_concurrency_limit=1)
    return demo
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options.

    Returns:
        Namespace with string attributes ``checkpoint`` and ``out_folder``.
    """
    cli = argparse.ArgumentParser(description="Gradio demo for PRIMA + SuperAnimal + TTA")
    # (flag, default, help) for each string-valued option, in registration order.
    string_options = (
        ("--checkpoint", DEFAULT_CHECKPOINT, "Path to the pretrained PRIMA checkpoint"),
        ("--out_folder", DEFAULT_OUT_FOLDER, "Folder used to save rendered outputs and meshes"),
    )
    for flag, default_value, help_text in string_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    return cli.parse_args()
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
if __name__ == "__main__":
    # Script entry point: parse options, preload assets when requested,
    # then build and launch the Gradio app.
    args = parse_args()
    preload_requested = _should_preload_assets()
    if preload_requested:
        _preload_assets_once(args.checkpoint)
    demo = build_demo(out_folder=args.out_folder, checkpoint_path=args.checkpoint)
    demo.launch(ssr_mode=False, inbrowser=False)
|
chumpy/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
# Minimal ``chumpy`` compatibility for unpickling legacy SMAL model configs.
|
| 13 |
+
|
| 14 |
+
from .ch import Ch, ChArray
|
| 15 |
+
|
| 16 |
+
__all__ = ["Ch", "ChArray"]
|
chumpy/ch.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
# ``chumpy.ch`` namespace expected by legacy SMAL pickles.
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Ch:
    """Minimal stand-in for ``chumpy.ch.Ch`` (unpickling only).

    It provides just enough surface for ``np.asarray(obj)`` and ``obj.r`` to
    return the stored array; none of chumpy's differentiation machinery exists.
    """

    def __init__(self, *args, **kwargs):
        """Store the first positional argument, if any, as an ndarray.

        Further positional arguments and all keyword arguments are accepted
        and ignored.
        """
        self._data = None
        if args:
            self._data = np.asarray(args[0])

    def _resolve(self) -> np.ndarray:
        """Return the wrapped array, or a float32 zero scalar if none is stored."""
        # Real chumpy Ch instances store the underlying ndarray on attribute ``x``;
        # legacy pickles unpickle by restoring ``__dict__`` without calling ``__init__``,
        # so try common attribute names before falling back to ``_data``.
        for attr in ("x", "_x", "_data"):
            val = self.__dict__.get(attr)
            if val is not None:
                return np.asarray(val)
        return np.zeros((), dtype=np.float32)

    def __array__(self, dtype=None, copy=None):
        """Implement NumPy's array-conversion protocol.

        Args:
            dtype: Optional target dtype; no conversion copy is made when the
                stored array already has it.
            copy: NumPy 2 passes this keyword. ``True`` guarantees the result
                does not alias the stored array; ``None``/``False`` return the
                stored array as-is whenever possible. Without this parameter
                NumPy 2 emits a DeprecationWarning and retries the call.
        """
        resolved = self._resolve()
        arr = resolved if dtype is None else resolved.astype(dtype, copy=False)
        # astype() may already have produced a fresh array; copy only when
        # the result would still alias the resolved storage.
        if copy and arr is resolved:
            arr = arr.copy()
        return arr

    @property
    def r(self) -> np.ndarray:
        """The wrapped array (mirrors chumpy's ``.r`` accessor)."""
        return self._resolve()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ChArray(np.ndarray):
    """Empty ``np.ndarray`` subclass standing in for ``chumpy.ch.ChArray``."""
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
__all__ = ["Ch", "ChArray"]
|
configs/sa_finetune_hrnet_w32.yaml
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeepLabCut pytorch_config for the PRIMA TTA 2D pose model:
|
| 2 |
+
# SuperAnimal-Quadruped HRNet-w32 backbone fine-tuned on Animal3D, with
|
| 3 |
+
# the heatmap head re-trained for the 26-joint Animal3D / PRIMA layout.
|
| 4 |
+
#
|
| 5 |
+
# Used by demo_tta.py via DLC's `superanimal_analyze_images(...,
|
| 6 |
+
# customized_model_config=<this yaml>, customized_pose_checkpoint=<your
|
| 7 |
+
# fine-tuned .pt>)`. Only the pose model is fine-tuned; the bounding-box
|
| 8 |
+
# detector (Faster R-CNN) is the stock SuperAnimal-Quadruped one
|
| 9 |
+
# resolved by DLC at runtime.
|
| 10 |
+
data:
|
| 11 |
+
bbox_margin: 20
|
| 12 |
+
colormode: RGB
|
| 13 |
+
inference:
|
| 14 |
+
normalize_images: true
|
| 15 |
+
top_down_crop:
|
| 16 |
+
width: 256
|
| 17 |
+
height: 256
|
| 18 |
+
auto_padding:
|
| 19 |
+
pad_width_divisor: 32
|
| 20 |
+
pad_height_divisor: 32
|
| 21 |
+
train:
|
| 22 |
+
affine:
|
| 23 |
+
p: 0.5
|
| 24 |
+
rotation: 30
|
| 25 |
+
scaling:
|
| 26 |
+
- 1.0
|
| 27 |
+
- 1.0
|
| 28 |
+
translation: 0
|
| 29 |
+
gaussian_noise: 12.75
|
| 30 |
+
motion_blur: true
|
| 31 |
+
normalize_images: true
|
| 32 |
+
top_down_crop:
|
| 33 |
+
width: 256
|
| 34 |
+
height: 256
|
| 35 |
+
auto_padding:
|
| 36 |
+
pad_width_divisor: 32
|
| 37 |
+
pad_height_divisor: 32
|
| 38 |
+
detector:
|
| 39 |
+
data:
|
| 40 |
+
colormode: RGB
|
| 41 |
+
inference:
|
| 42 |
+
normalize_images: true
|
| 43 |
+
train:
|
| 44 |
+
affine:
|
| 45 |
+
p: 0.5
|
| 46 |
+
rotation: 30
|
| 47 |
+
scaling:
|
| 48 |
+
- 1.0
|
| 49 |
+
- 1.0
|
| 50 |
+
translation: 40
|
| 51 |
+
collate:
|
| 52 |
+
type: ResizeFromDataSizeCollate
|
| 53 |
+
min_scale: 0.4
|
| 54 |
+
max_scale: 1.0
|
| 55 |
+
min_short_side: 128
|
| 56 |
+
max_short_side: 1152
|
| 57 |
+
multiple_of: 32
|
| 58 |
+
to_square: false
|
| 59 |
+
hflip: true
|
| 60 |
+
normalize_images: true
|
| 61 |
+
device: auto
|
| 62 |
+
model:
|
| 63 |
+
type: FasterRCNN
|
| 64 |
+
freeze_bn_stats: true
|
| 65 |
+
freeze_bn_weights: false
|
| 66 |
+
variant: fasterrcnn_resnet50_fpn_v2
|
| 67 |
+
runner:
|
| 68 |
+
type: DetectorTrainingRunner
|
| 69 |
+
key_metric: test.mAP@50:95
|
| 70 |
+
key_metric_asc: true
|
| 71 |
+
eval_interval: 10
|
| 72 |
+
optimizer:
|
| 73 |
+
type: AdamW
|
| 74 |
+
params:
|
| 75 |
+
lr: 0.0001
|
| 76 |
+
scheduler:
|
| 77 |
+
type: LRListScheduler
|
| 78 |
+
params:
|
| 79 |
+
milestones:
|
| 80 |
+
- 160
|
| 81 |
+
lr_list:
|
| 82 |
+
- - 1e-05
|
| 83 |
+
snapshots:
|
| 84 |
+
max_snapshots: 5
|
| 85 |
+
save_epochs: 25
|
| 86 |
+
save_optimizer_state: false
|
| 87 |
+
train_settings:
|
| 88 |
+
batch_size: 1
|
| 89 |
+
dataloader_workers: 0
|
| 90 |
+
dataloader_pin_memory: false
|
| 91 |
+
display_iters: 500
|
| 92 |
+
epochs: 250
|
| 93 |
+
device: auto
|
| 94 |
+
inference:
|
| 95 |
+
multithreading:
|
| 96 |
+
enabled: true
|
| 97 |
+
queue_length: 4
|
| 98 |
+
timeout: 30.0
|
| 99 |
+
compile:
|
| 100 |
+
enabled: false
|
| 101 |
+
backend: inductor
|
| 102 |
+
autocast:
|
| 103 |
+
enabled: false
|
| 104 |
+
metadata:
|
| 105 |
+
project_path: ""
|
| 106 |
+
pose_config_path: ""
|
| 107 |
+
bodyparts:
|
| 108 |
+
- left_eye
|
| 109 |
+
- right_eye
|
| 110 |
+
- chin
|
| 111 |
+
- left_front_paw
|
| 112 |
+
- right_front_paw
|
| 113 |
+
- left_back_paw
|
| 114 |
+
- right_back_paw
|
| 115 |
+
- tail_base
|
| 116 |
+
- left_front_thigh
|
| 117 |
+
- right_front_thigh
|
| 118 |
+
- left_back_thigh
|
| 119 |
+
- right_back_thigh
|
| 120 |
+
- left_shoulder
|
| 121 |
+
- right_shoulder
|
| 122 |
+
- left_front_knee
|
| 123 |
+
- right_front_knee
|
| 124 |
+
- left_back_knee
|
| 125 |
+
- right_back_knee
|
| 126 |
+
- neck_base
|
| 127 |
+
- tail_mid
|
| 128 |
+
- left_ear_base
|
| 129 |
+
- right_ear_base
|
| 130 |
+
- left_mouth_corner
|
| 131 |
+
- right_mouth_corner
|
| 132 |
+
- nose
|
| 133 |
+
- tail_tip_first
|
| 134 |
+
unique_bodyparts: []
|
| 135 |
+
individuals:
|
| 136 |
+
- individual000
|
| 137 |
+
with_identity: false
|
| 138 |
+
method: td
|
| 139 |
+
model:
|
| 140 |
+
backbone:
|
| 141 |
+
type: HRNet
|
| 142 |
+
model_name: hrnet_w32
|
| 143 |
+
freeze_bn_stats: true
|
| 144 |
+
freeze_bn_weights: false
|
| 145 |
+
interpolate_branches: false
|
| 146 |
+
increased_channel_count: false
|
| 147 |
+
backbone_output_channels: 32
|
| 148 |
+
heads:
|
| 149 |
+
bodypart:
|
| 150 |
+
type: HeatmapHead
|
| 151 |
+
weight_init: normal
|
| 152 |
+
predictor:
|
| 153 |
+
type: HeatmapPredictor
|
| 154 |
+
apply_sigmoid: false
|
| 155 |
+
clip_scores: true
|
| 156 |
+
location_refinement: true
|
| 157 |
+
locref_std: 7.2801
|
| 158 |
+
target_generator:
|
| 159 |
+
type: HeatmapGaussianGenerator
|
| 160 |
+
num_heatmaps: 26
|
| 161 |
+
pos_dist_thresh: 17
|
| 162 |
+
heatmap_mode: KEYPOINT
|
| 163 |
+
gradient_masking: true
|
| 164 |
+
background_weight: 0.0
|
| 165 |
+
generate_locref: true
|
| 166 |
+
locref_std: 7.2801
|
| 167 |
+
criterion:
|
| 168 |
+
heatmap:
|
| 169 |
+
type: WeightedMSECriterion
|
| 170 |
+
weight: 1.0
|
| 171 |
+
locref:
|
| 172 |
+
type: WeightedHuberCriterion
|
| 173 |
+
weight: 0.05
|
| 174 |
+
heatmap_config:
|
| 175 |
+
channels:
|
| 176 |
+
- 32
|
| 177 |
+
kernel_size: []
|
| 178 |
+
strides: []
|
| 179 |
+
final_conv:
|
| 180 |
+
out_channels: 26
|
| 181 |
+
kernel_size: 1
|
| 182 |
+
locref_config:
|
| 183 |
+
channels:
|
| 184 |
+
- 32
|
| 185 |
+
kernel_size: []
|
| 186 |
+
strides: []
|
| 187 |
+
final_conv:
|
| 188 |
+
out_channels: 52
|
| 189 |
+
kernel_size: 1
|
| 190 |
+
net_type: hrnet_w32
|
| 191 |
+
runner:
|
| 192 |
+
type: PoseTrainingRunner
|
| 193 |
+
gpus:
|
| 194 |
+
key_metric: test.mAP
|
| 195 |
+
key_metric_asc: true
|
| 196 |
+
eval_interval: 10
|
| 197 |
+
optimizer:
|
| 198 |
+
type: AdamW
|
| 199 |
+
params:
|
| 200 |
+
lr: 0.0001
|
| 201 |
+
scheduler:
|
| 202 |
+
type: LRListScheduler
|
| 203 |
+
params:
|
| 204 |
+
lr_list:
|
| 205 |
+
- - 1e-05
|
| 206 |
+
- - 1e-06
|
| 207 |
+
milestones:
|
| 208 |
+
- 160
|
| 209 |
+
- 190
|
| 210 |
+
snapshots:
|
| 211 |
+
max_snapshots: 5
|
| 212 |
+
save_epochs: 10
|
| 213 |
+
save_optimizer_state: false
|
| 214 |
+
train_settings:
|
| 215 |
+
batch_size: 64
|
| 216 |
+
dataloader_workers: 8
|
| 217 |
+
dataloader_pin_memory: false
|
| 218 |
+
display_iters: 500
|
| 219 |
+
epochs: 200
|
| 220 |
+
seed: 42
|
demo_data/000000015956_horse.png
ADDED
|
Git LFS Details
|
demo_data/000000315905_zebra.jpg
ADDED
|
Git LFS Details
|
demo_data/beagle.jpg
ADDED
|
Git LFS Details
|
demo_data/n02101388_1188.png
ADDED
|
Git LFS Details
|
demo_data/n02412080_12159.png
ADDED
|
Git LFS Details
|
demo_data/shepherd_hati.jpg
ADDED
|
Git LFS Details
|
demo_tta.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
demo_tta.py: PRIMA inference with fine-tuned DeepLabCut SuperAnimal TTA
|
| 12 |
+
|
| 13 |
+
Pipeline:
|
| 14 |
+
1. Run Detectron2 to detect animals in the input image.
|
| 15 |
+
2. Run PRIMA on each detected animal to obtain 3D pose/shape estimation.
|
| 16 |
+
3. Run a fine-tuned DeepLabCut SuperAnimal pose model (Animal3D 26-joint
|
| 17 |
+
layout) to obtain 2D keypoints already in PRIMA topology. The fine-tuned
|
| 18 |
+
snapshot is wired into DLC's
|
| 19 |
+
``superanimal_analyze_images`` via the ``customized_pose_checkpoint``
|
| 20 |
+
and ``customized_model_config`` kwargs.
|
| 21 |
+
4. Run test-time adaptation (TTA) with user-specified lr and num_iters
|
| 22 |
+
to further optimize the 3D pose and shape estimation.
|
| 23 |
+
5. Render and save before/after TTA results (PNG + OBJ) and the
|
| 24 |
+
26-keypoint visualization (PNG).
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
import argparse
|
| 30 |
+
import copy
|
| 31 |
+
import os
|
| 32 |
+
import tempfile
|
| 33 |
+
import warnings
|
| 34 |
+
|
| 35 |
+
import cv2
|
| 36 |
+
import numpy as np
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn.functional as F
|
| 39 |
+
import torch.utils.data
|
| 40 |
+
from tqdm import tqdm
|
| 41 |
+
|
| 42 |
+
from prima.models import load_prima
|
| 43 |
+
from prima.utils import recursive_to
|
| 44 |
+
from prima.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
|
| 45 |
+
from prima.utils.detection import ANIMAL_COCO_IDS, select_animal_boxes
|
| 46 |
+
from prima.utils.weights import DEFAULT_HF_REPO_ID, resolve_prima_checkpoint_path
|
| 47 |
+
|
| 48 |
+
warnings.filterwarnings("ignore")
|
| 49 |
+
|
| 50 |
+
LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353)
|
| 51 |
+
GREEN = (0.65, 0.86, 0.74)
|
| 52 |
+
|
| 53 |
+
REPO_ROOT = Path(__file__).resolve().parent
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def load_renderer_components():
    """Import and return ``(Renderer, cam_crop_to_full)`` from ``prima.utils.renderer``.

    The import happens at call time rather than module import time.

    Raises:
        RuntimeError: If the import fails for any reason; the original
            exception is chained as ``__cause__``.
    """
    failure_hint = (
        "Cannot initialize the PRIMA renderer. Rendering requires a working "
        "pyrender/OpenGL backend such as EGL or OSMesa. Install the missing "
        "OpenGL runtime for this environment, or run in an environment where "
        "PYOPENGL_PLATFORM=egl/osmesa works."
    )
    try:
        from prima.utils.renderer import Renderer, cam_crop_to_full
    except Exception as import_error:
        raise RuntimeError(failure_hint) from import_error
    else:
        return Renderer, cam_crop_to_full
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def denorm_patch_to_rgb(img_tensor: torch.Tensor) -> np.ndarray:
    """Undo the dataset normalisation of a patch and return it as an image array.

    The tensor is multiplied by ``DEFAULT_STD``, offset by ``DEFAULT_MEAN``,
    divided by 255, moved from channel-first to channel-last order and
    clipped to ``[0, 1]``.

    Args:
        img_tensor: Normalised patch; presumably ``(C, H, W)`` given the
            ``permute(1, 2, 0)`` below — confirm against the dataset.

    Returns:
        Channel-last numpy array with values in ``[0, 1]``.
    """
    # Add two trailing axes so the per-channel statistics broadcast over H and W.
    channel_std = DEFAULT_STD[:, None, None]
    channel_mean = DEFAULT_MEAN[:, None, None]
    unit_range = (img_tensor.detach().cpu() * channel_std + channel_mean) / 255.0
    channels_last = unit_range.permute(1, 2, 0).numpy()
    return np.clip(channels_last, 0.0, 1.0)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def save_keypoint_vis(patch_rgb: np.ndarray, kpts_xyc: np.ndarray, save_path: str) -> None:
    """Draw numbered keypoints on a patch and write the image to ``save_path``.

    Keypoints whose confidence is not strictly positive, or whose confidence
    or coordinates are NaN/inf, are skipped instead of crashing the rounding
    step.

    Args:
        patch_rgb: RGB patch; it is scaled by 255 and cast to ``uint8`` before
            drawing, so values are presumably in ``[0, 1]``.
        kpts_xyc: Rows of ``(x, y, confidence)`` in patch pixel coordinates.
        save_path: Destination handed to ``cv2.imwrite``.
    """
    vis = cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR).copy()
    num_kpts = len(kpts_xyc)

    for i, (x, y, c) in enumerate(kpts_xyc):
        # ``not (c > 0)`` rejects NaN confidence as well as c <= 0; a plain
        # ``c <= 0`` test lets NaN through.
        if not (c > 0) or not (np.isfinite(x) and np.isfinite(y)):
            continue

        # Use distinct color for each keypoint (OpenCV uses BGR)
        hue = int(179 * i / max(1, num_kpts - 1))
        color_bgr = cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2BGR)[0, 0]
        # Convert to a tuple of Python ints for the drawing call.
        color_bgr = (int(color_bgr[0]), int(color_bgr[1]), int(color_bgr[2]))

        cx, cy = int(round(float(x))), int(round(float(y)))
        cv2.circle(vis, (cx, cy), 3, color_bgr, -1)
        cv2.putText(vis, str(i), (cx + 3, cy - 3), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1, cv2.LINE_AA)

    cv2.imwrite(save_path, vis)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def resolve_sa_weights_path(local_path: str) -> str:
    """Return a local path to the fine-tuned SuperAnimal .pt snapshot.

    A non-empty ``local_path`` is returned unchanged. Otherwise
    ``sa_finetune_hrnet_w32.pt`` from the ``MLAdaptiveIntelligence/FMPose3D``
    Hugging Face repo is used: an already cached copy is preferred, and the
    file is downloaded only if the cache lookup fails.

    Raises:
        ImportError: If ``huggingface_hub`` is needed but not installed.
    """
    if local_path:
        return local_path

    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        raise ImportError(
            "huggingface_hub is required to auto-download the fine-tuned "
            "SuperAnimal weights. Install with `pip install huggingface_hub`, "
            "or pass --saved_2d_model_path with a local .pt file."
        ) from None

    repo_id = "MLAdaptiveIntelligence/FMPose3D"
    filename = "sa_finetune_hrnet_w32.pt"

    try:
        # First try to resolve the file from the local cache only.
        cached_path = hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True)
    except Exception:
        # Cache miss (or any lookup problem): fall back to a real download.
        print(f"No --saved_2d_model_path provided; downloading '{filename}' from {repo_id}...")
        return hf_hub_download(repo_id=repo_id, filename=filename)
    else:
        print(f"Using cached SuperAnimal weights: {cached_path}")
        return cached_path
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def run_superanimal_on_patch(patch_rgb: np.ndarray, args, tmp_dir: str):
    """Predict 26-joint 2D keypoints on a single PRIMA patch using a
    fine-tuned DeepLabCut SuperAnimal snapshot.

    The patch is written to ``<tmp_dir>/patch.png`` and analysed with
    ``superanimal_analyze_images``. When several individuals are returned,
    the one with the highest mean keypoint confidence is kept.

    Args:
        patch_rgb: RGB patch; scaled by 255 and cast to ``uint8`` before
            being written.
        args: Parsed CLI namespace providing the ``superanimal_*``,
            ``pytorch_config_2d_path`` and ``saved_2d_model_path`` options.
        tmp_dir: Scratch directory for the patch image and DLC outputs.

    Returns:
        An ``(26, 3)`` array of ``(x, y, confidence)`` in patch pixel
        coordinates, or ``None`` if no individual was detected.

    Raises:
        RuntimeError: If the DeepLabCut SuperAnimal API cannot be imported.
    """
    try:
        from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images
    except Exception as e:
        raise RuntimeError(
            "Cannot import DeepLabCut SuperAnimal API. Please install deeplabcut with pose_estimation_pytorch support."
        ) from e

    patch_path = os.path.join(tmp_dir, "patch.png")
    patch_bgr = cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    cv2.imwrite(patch_path, patch_bgr)

    if torch.cuda.is_available():
        dlc_device = "cuda"
    else:
        dlc_device = "cpu"

    analyze_kwargs = dict(
        superanimal_name=args.superanimal_name,
        model_name=args.superanimal_model_name,
        detector_name=args.superanimal_detector_name,
        images=patch_path,
        max_individuals=args.superanimal_max_individuals,
        out_folder=tmp_dir,
        device=dlc_device,
        customized_model_config=args.pytorch_config_2d_path,
        customized_pose_checkpoint=args.saved_2d_model_path,
    )
    preds = superanimal_analyze_images(**analyze_kwargs)

    # Predictions are looked up by the image path that was analysed.
    payload = preds.get(patch_path, None)
    if payload is None:
        return None

    bodyparts = payload.get("bodyparts", None)
    if bodyparts is None or len(bodyparts) == 0:
        return None

    # Keep the individual whose keypoints have the highest mean confidence.
    mean_confidence = bodyparts[..., 2].mean(axis=1)
    best_idx = int(np.argmax(mean_confidence))
    return bodyparts[best_idx].astype(np.float32)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def render_and_save(renderer, cam_crop_to_full_fn, out, batch, img_fn, animal_id, out_folder, suffix, side_view, save_mesh):
    """Render the predicted mesh for the first batch item and save the results.

    Writes ``<img_fn>_<animal_id>_<suffix>.png`` into ``out_folder``. The image
    is the input patch next to the mesh rendering, plus a side-view rendering
    when ``side_view`` is set. With ``save_mesh`` the mesh is also exported as
    ``<img_fn>_<animal_id>_<suffix>.obj``.

    Args:
        renderer: Callable renderer that also offers ``vertices_to_trimesh``.
        cam_crop_to_full_fn: Function converting the crop camera to a
            full-image camera translation.
        out: Model output dict; ``pred_cam``, ``pred_vertices`` and
            ``pred_cam_t`` are read.
        batch: Input batch dict; ``img``, ``box_center``, ``box_size``,
            ``img_size`` and ``focal_length`` are read.
        img_fn: Image name stem used in output file names.
        animal_id: Identifier of the animal instance used in file names.
        out_folder: Destination directory.
        suffix: File name suffix.
        side_view: Whether to append a side-view rendering.
        save_mesh: Whether to export the mesh as ``.obj``.
    """

    def _as_numpy(tensor):
        return tensor.detach().cpu().numpy()

    pred_cam = out['pred_cam']
    box_center = batch['box_center'].float()
    box_size = batch['box_size'].float()
    img_size = batch['img_size'].float()
    # Rescale the focal length from patch resolution to the full image.
    scaled_focal_length = batch['focal_length'][0, 0] / batch['img'].shape[-1] * img_size.max()
    pred_cam_t_full = cam_crop_to_full_fn(pred_cam, box_center, box_size, img_size, scaled_focal_length)

    # An all-ones image expressed in the same normalised space as batch['img'].
    white_img = (torch.ones_like(batch['img'][0]).cpu() - DEFAULT_MEAN[:, None, None] / 255) / (
        DEFAULT_STD[:, None, None] / 255
    )
    input_patch = denorm_patch_to_rgb(batch['img'][0])

    regression_img = renderer(
        _as_numpy(out['pred_vertices'][0]),
        _as_numpy(out['pred_cam_t'][0]),
        batch['img'][0],
        mesh_base_color=GREEN,
        scene_bg_color=(1, 1, 1),
    )
    panels = [input_patch, regression_img]

    if side_view:
        side_img = renderer(
            _as_numpy(out['pred_vertices'][0]),
            _as_numpy(out['pred_cam_t'][0]),
            white_img,
            mesh_base_color=GREEN,
            scene_bg_color=(1, 1, 1),
            side_view=True,
        )
        panels.append(side_img)

    # Panels are laid out left to right.
    final_img = np.concatenate(panels, axis=1)

    image_path = os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.png')
    image_bgr = cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR)
    cv2.imwrite(image_path, image_bgr)

    if save_mesh:
        verts = _as_numpy(out['pred_vertices'][0])
        cam_t = _as_numpy(pred_cam_t_full[0])
        tmesh = renderer.vertices_to_trimesh(verts, cam_t.copy(), LIGHT_BLUE)
        tmesh.export(os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.obj'))
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def tta_optimize(model, batch, gt_kpts_norm, num_iters, lr):
    """Run test-time adaptation of the SMAL head on one batch.

    The parameters returned by ``model.smal_head.get_tta_parameters(mode='all')``
    are optimised with Adam so that ``out['pred_keypoints_2d']`` matches the
    given 2D keypoints. The model is left as it was found: the head weights
    and the backbone ``requires_grad`` flags are restored even if the
    optimisation raises.

    Args:
        model: Model exposing ``smal_head`` (and optionally ``backbone``).
        batch: Input batch passed to ``model`` unchanged.
        gt_kpts_norm: Target keypoints whose last axis is
            ``(x, y, confidence)``; only entries with confidence > 0
            contribute to the loss.
        num_iters: Number of optimisation steps.
        lr: Adam learning rate.

    Returns:
        The model output for ``batch`` computed with the adapted weights.
    """
    model.eval()

    # Freeze the backbone, remembering each flag so it can be restored.
    backbone_flags = []
    if hasattr(model, 'backbone'):
        for p in model.backbone.parameters():
            backbone_flags.append((p, p.requires_grad))
            p.requires_grad = False

    orig_smal_head_state = copy.deepcopy(model.smal_head.state_dict())

    try:
        model.smal_head.freeze_except_regression_heads()
        tta_params = model.smal_head.get_tta_parameters(mode='all')
        optimizer = torch.optim.Adam(tta_params, lr=lr)

        valid_mask = (gt_kpts_norm[..., 2] > 0).float().unsqueeze(-1)
        gt_xy = gt_kpts_norm[..., :2]

        # Make the loop work even if the caller has gradients disabled.
        with torch.enable_grad():
            for _ in range(num_iters):
                optimizer.zero_grad()
                out = model(batch)
                pred_xy = out['pred_keypoints_2d']
                # Summed squared error over valid keypoints, normalised by
                # their count; the epsilon guards against an empty mask.
                loss = F.mse_loss(pred_xy * valid_mask, gt_xy * valid_mask, reduction='sum') / (valid_mask.sum() + 1e-6)
                loss.backward()
                optimizer.step()

        with torch.no_grad():
            out_after = model(batch)
    finally:
        # Always undo the adaptation so the next sample starts from the
        # original weights and the original trainability flags.
        model.smal_head.load_state_dict(orig_smal_head_state)
        model.smal_head.unfreeze_all()
        for p, flag in backbone_flags:
            p.requires_grad = flag

    return out_after
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def _build_arg_parser():
    """Create the command-line parser for the PRIMA + SuperAnimal + TTA demo."""
    parser = argparse.ArgumentParser(description='PRIMA + SuperAnimal + TTA demo')
    parser.add_argument('--checkpoint', type=str, default='',
                        help='Path to pretrained PRIMA checkpoint. Empty -> auto-download the default Stage 1 checkpoint.')
    parser.add_argument('--hf-repo-id', '--hf_repo_id', dest='hf_repo_id',
                        type=str, default=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_REPO_ID),
                        help='Hugging Face repo ID containing PRIMA demo assets')
    parser.add_argument('--no-auto-download', '--no_auto_download', dest='no_auto_download', action='store_true',
                        help='Disable automatic download of missing PRIMA demo assets')
    parser.add_argument('--img_path', type=str, default=None, help='Single image path')
    parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images')
    parser.add_argument('--out_folder', type=str, default='demo_out_tta', help='Output folder')
    parser.add_argument('--side_view', dest='side_view', action='store_true', default=False, help='Render side view')
    parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='Save meshes')
    parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'], help='Image globs')
    parser.add_argument('--det_thresh', type=float, default=0.7, help='Detectron2 score threshold for animals')

    parser.add_argument('--tta_lr', type=float, default=1e-6, help='TTA learning rate')
    parser.add_argument('--tta_num_iters', type=int, default=30, help='TTA iterations')
    parser.add_argument('--kp_conf_thresh', type=float, default=0.1, help='Keypoint confidence threshold')

    parser.add_argument('--superanimal_name', type=str, default='superanimal_quadruped')
    parser.add_argument('--superanimal_model_name', type=str, default='hrnet_w32')
    parser.add_argument('--superanimal_detector_name', type=str, default='fasterrcnn_resnet50_fpn_v2')
    parser.add_argument('--superanimal_max_individuals', type=int, default=1)
    parser.add_argument('--saved_2d_model_path', type=str, default='',
                        help='Path to the fine-tuned SuperAnimal 26-joint .pt snapshot. '
                             'Empty -> auto-download sa_finetune_hrnet_w32.pt from '
                             'MLAdaptiveIntelligence/FMPose3D on Hugging Face Hub.')
    parser.add_argument('--pytorch_config_2d_path', type=str,
                        default=str(Path(__file__).resolve().parent / 'configs' / 'sa_finetune_hrnet_w32.yaml'),
                        help='Path to the DLC pytorch config yaml for the fine-tuned snapshot. '
                             'Defaults to the bundled configs/sa_finetune_hrnet_w32.yaml.')
    return parser


def main():
    """Run the demo: detect animals, fit PRIMA, refine with TTA, save results.

    For every input image, animals are detected with a Detectron2 Faster
    R-CNN. Each detection is passed through PRIMA and rendered
    (``before_tta``). 2D keypoints are then predicted on the patch with
    SuperAnimal, saved as an image and ``.npy``, and used as targets for
    test-time adaptation; the adapted result is rendered as ``after_tta``.
    """
    args = _build_arg_parser().parse_args()
    checkpoint_path = resolve_prima_checkpoint_path(
        args.checkpoint,
        data_dir=REPO_ROOT / "data",
        auto_download=not args.no_auto_download,
        hf_repo_id=args.hf_repo_id,
    )
    args.saved_2d_model_path = resolve_sa_weights_path(args.saved_2d_model_path)

    model, model_cfg = load_prima(checkpoint_path)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(device)
    model.eval()

    Renderer, cam_crop_to_full_fn = load_renderer_components()
    renderer = Renderer(model_cfg, faces=model.smal.faces)
    os.makedirs(args.out_folder, exist_ok=True)

    import detectron2.config
    import detectron2.engine
    from detectron2 import model_zoo

    cfg = detectron2.config.get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"))
    # The detector discards boxes below this score before --det_thresh is
    # applied, so it must not exceed the user threshold; otherwise values
    # of --det_thresh below 0.5 could never take effect.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = min(0.5, args.det_thresh)
    cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
    cfg.MODEL.DEVICE = device.type
    detector = detectron2.engine.DefaultPredictor(cfg)

    if args.img_path is not None:
        img_paths = [Path(args.img_path)]
    else:
        # A set removes files matched by more than one glob (e.g. '*.jpeg'
        # and '*.JPEG' on a case-insensitive filesystem).
        img_paths = sorted({img for end in args.file_type for img in Path(args.img_folder).glob(end)})

    if not img_paths:
        print(f"[WARN] No input images found in {args.img_folder}")

    for img_path in img_paths:
        img_bgr = cv2.imread(str(img_path))
        if img_bgr is None:
            print(f"[WARN] Cannot read image: {img_path}")
            continue
        det_out = detector(img_bgr)
        det_instances = det_out['instances']
        boxes, suppressed = select_animal_boxes(
            det_instances,
            animal_class_ids=ANIMAL_COCO_IDS,
            score_threshold=args.det_thresh,
        )
        if suppressed > 0:
            print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s) in {img_path}")

        if len(boxes) == 0:
            print(f"[INFO] No animal detected in {img_path}")
            continue

        dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

        for batch in tqdm(dataloader, desc=f"{img_path.name}"):
            batch = recursive_to(batch, device)
            with torch.no_grad():
                out_before = model(batch)

            img_fn = img_path.stem
            animal_id = int(batch['animalid'][0])

            render_and_save(
                renderer,
                cam_crop_to_full_fn,
                out_before,
                batch,
                img_fn,
                animal_id,
                args.out_folder,
                suffix='before_tta',
                side_view=args.side_view,
                save_mesh=args.save_mesh,
            )

            patch_rgb = denorm_patch_to_rgb(batch['img'][0])
            with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
                kpts_xyc = run_superanimal_on_patch(patch_rgb, args, tmp_dir)

            if kpts_xyc is None:
                print(f"[WARN] No SuperAnimal keypoints for {img_fn}_{animal_id}, skip TTA")
                continue

            # Zero the confidence of weak keypoints so they are ignored by
            # both the visualisation and the TTA loss.
            kpts_xyc[kpts_xyc[:, 2] < args.kp_conf_thresh, 2] = 0.0

            save_keypoint_vis(
                patch_rgb,
                kpts_xyc,
                os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png"),
            )
            np.save(os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy"), kpts_xyc)

            # Map pixel coordinates to the [-0.5, 0.5] range used for the
            # TTA targets.
            patch_h, patch_w = patch_rgb.shape[:2]
            kpts_norm = kpts_xyc.copy()
            kpts_norm[:, 0] = kpts_norm[:, 0] / float(patch_w) - 0.5
            kpts_norm[:, 1] = kpts_norm[:, 1] / float(patch_h) - 0.5
            gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch['img'].dtype)

            out_after = tta_optimize(
                model,
                batch,
                gt_kpts_norm,
                num_iters=args.tta_num_iters,
                lr=args.tta_lr,
            )

            render_and_save(
                renderer,
                cam_crop_to_full_fn,
                out_after,
                batch,
                img_fn,
                animal_id,
                args.out_folder,
                suffix='after_tta',
                side_view=args.side_view,
                save_mesh=args.save_mesh,
            )


if __name__ == '__main__':
    main()
|
images/teaser.png
ADDED
|
Git LFS Details
|
packages.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
libosmesa6
|
| 2 |
+
libgl1
|
| 3 |
+
libgl1-mesa-dri
|
| 4 |
+
libegl-mesa0
|
| 5 |
+
libegl1
|
| 6 |
+
libglx-mesa0
|
| 7 |
+
libgles2
|
prima/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation

Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license

Top-level package for PRIMA.

This package contains models, datasets and utilities for
3D animal pose and shape estimation.
"""

from importlib.metadata import PackageNotFoundError, version


# Resolve the installed distribution version; fall back to a placeholder
# when the package is used from a source checkout without being installed.
try:  # pragma: no cover - best effort during development
    __version__ = version("prima-animal")
except PackageNotFoundError:  # pragma: no cover
    __version__ = "0.0.0"


__all__ = ["__version__"]
|
prima/configs/__init__.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import Dict
|
| 11 |
+
from yacs.config import CfgNode as CN
|
| 12 |
+
|
| 13 |
+
def to_lower(x: Dict) -> Dict:
    """
    Build a copy of a dictionary whose keys are lowercased.

    Values are kept as they are. If several keys collapse to the same
    lowercase key, the value of the last one in iteration order is kept.

    Args:
        x (dict): Input dictionary with string keys.
    Returns:
        dict: New dictionary with every key converted to lowercase.
    """
    lowered = {}
    for key, value in x.items():
        lowered[key.lower()] = value
    return lowered
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Default configuration tree. Every node is created with new_allowed=True so
# that merged config files may introduce keys that are not listed here.
_C = CN(new_allowed=True)

# --- General run settings ---------------------------------------------------
_C.GENERAL = CN(new_allowed=True)
_C.GENERAL.RESUME = True
_C.GENERAL.TIME_TO_RUN = 3300
_C.GENERAL.VAL_STEPS = 100
_C.GENERAL.LOG_STEPS = 100
_C.GENERAL.CHECKPOINT_STEPS = 20000
_C.GENERAL.CHECKPOINT_DIR = "checkpoints"
_C.GENERAL.SUMMARY_DIR = "tensorboard"
_C.GENERAL.NUM_GPUS = 1
_C.GENERAL.NUM_WORKERS = 4
_C.GENERAL.MIXED_PRECISION = True
_C.GENERAL.ALLOW_CUDA = True
_C.GENERAL.PIN_MEMORY = False
_C.GENERAL.DISTRIBUTED = False
_C.GENERAL.LOCAL_RANK = 0
_C.GENERAL.USE_SYNCBN = False
_C.GENERAL.WORLD_SIZE = 1
_C.GENERAL.PREFETCH_FACTOR = 2

# --- Training settings ------------------------------------------------------
_C.TRAIN = CN(new_allowed=True)
_C.TRAIN.NUM_EPOCHS = 100
_C.TRAIN.SHUFFLE = True
_C.TRAIN.WARMUP = False
_C.TRAIN.NORMALIZE_PER_IMAGE = False
_C.TRAIN.CLIP_GRAD = False
_C.TRAIN.CLIP_GRAD_VALUE = 1.0

# --- Loss weights (no defaults; filled in by the merged config) --------------
_C.LOSS_WEIGHTS = CN(new_allowed=True)

# --- Datasets and augmentation ----------------------------------------------
_C.DATASETS = CN(new_allowed=True)
_C.DATASETS.CONFIG = CN(new_allowed=True)
_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
_C.DATASETS.CONFIG.ROT_FACTOR = 30
_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
_C.DATASETS.CONFIG.COLOR_SCALE = 0.2
_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
_C.DATASETS.CONFIG.DO_FLIP = False
_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10

# --- Model --------------------------------------------------------------------
_C.MODEL = CN(new_allowed=True)
_C.MODEL.IMAGE_SIZE = 224

# --- Extra ----------------------------------------------------------------
_C.EXTRA = CN(new_allowed=True)
_C.EXTRA.FOCAL_LENGTH = 5000
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def default_config() -> CN:
    """
    Return a fresh copy of the default configuration.

    Returns:
        CfgNode: Clone of the module-level defaults, so callers can modify
        the result without altering the defaults themselves.
    """
    defaults_copy = _C.clone()
    return defaults_copy
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_config(config_file: str, merge: bool = True) -> CN:
    """
    Load a config file into a frozen yacs CfgNode.

    Args:
        config_file (str): Path to config file.
        merge (bool): If True, the file is merged on top of the default
            config; if False, it is merged into an empty node.
    Returns:
        CfgNode: The resulting config, frozen before it is returned.
    """
    # Choose the base node the file is merged into.
    cfg = default_config() if merge else CN(new_allowed=True)
    cfg.merge_from_file(config_file)

    cfg.freeze()
    return cfg
|
prima/datasets/__init__.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import Dict, Optional
|
| 11 |
+
from torch.utils.data import WeightedRandomSampler
|
| 12 |
+
import torch
|
| 13 |
+
import pytorch_lightning as pl
|
| 14 |
+
from yacs.config import CfgNode
|
| 15 |
+
from .datasets import OptionAnimalDataset, TrainDataset
|
| 16 |
+
from prima.utils.pylogger import get_pylogger
|
| 17 |
+
|
| 18 |
+
log = get_pylogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DataModule(pl.LightningDataModule):
    """LightningDataModule providing the training and validation loaders."""

    def __init__(self, cfg: CfgNode) -> None:
        """
        Initialize LightningDataModule for AMR training
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info.
        """
        super().__init__()
        self.cfg = cfg
        # Datasets and the sampler are created lazily in ``setup``.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.mocap_dataset = None
        self.weight_sampler = None

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Load datasets necessary for training. Datasets that already exist
        are kept, so calling this more than once is harmless.
        Args:
            stage: Unused; accepted for the Lightning interface.
        """
        if self.train_dataset is None:
            self.train_dataset = OptionAnimalDataset(self.cfg)
            # Draw len(train_dataset) samples per epoch according to the
            # per-sample weights exposed by the dataset.
            self.weight_sampler = WeightedRandomSampler(weights=self.train_dataset.weights,
                                                        num_samples=len(self.train_dataset))
        if self.val_dataset is None:
            self.val_dataset = TrainDataset(self.cfg, is_train=False,
                                            root_image=self.cfg.DATASETS.ANIMAL3D.ROOT_IMAGE,
                                            json_file=self.cfg.DATASETS.ANIMAL3D.JSON_FILE.TEST)

    def train_dataloader(self) -> Dict:
        """
        Setup training data loader.
        Returns:
            Dict: Dictionary with the image dataloader under the key ``'img'``.
        """
        num_workers = self.cfg.GENERAL.NUM_WORKERS
        loader_kwargs = dict(
            batch_size=self.cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            num_workers=num_workers,
            pin_memory=True,
            # ``shuffle`` and ``sampler`` are mutually exclusive in DataLoader.
            shuffle=self.weight_sampler is None,
            sampler=self.weight_sampler,
        )
        # DataLoader raises ValueError if prefetch_factor is given while
        # loading in the main process (num_workers == 0), so only pass it
        # when worker processes are used.
        if num_workers > 0:
            loader_kwargs['prefetch_factor'] = self.cfg.GENERAL.PREFETCH_FACTOR
        train_dataloader = torch.utils.data.DataLoader(self.train_dataset, **loader_kwargs)
        return {'img': train_dataloader}

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """
        Setup val data loader.
        Returns:
            torch.utils.data.DataLoader: Validation dataloader
        """
        val_dataloader = torch.utils.data.DataLoader(self.val_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True,
                                                     num_workers=self.cfg.GENERAL.NUM_WORKERS, pin_memory=True)
        return val_dataloader
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
prima/datasets/datasets.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import copy
|
| 11 |
+
import os
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from yacs.config import CfgNode
|
| 15 |
+
import cv2
|
| 16 |
+
import pyrootutils
|
| 17 |
+
from torch.utils.data import ConcatDataset
|
| 18 |
+
from typing import List
|
| 19 |
+
root = pyrootutils.setup_root(
|
| 20 |
+
search_from=__file__,
|
| 21 |
+
indicator=[".git", "pyproject.toml"],
|
| 22 |
+
pythonpath=True,
|
| 23 |
+
dotenv=True,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
import json
|
| 27 |
+
import hydra
|
| 28 |
+
from omegaconf import DictConfig, OmegaConf
|
| 29 |
+
from PIL import Image
|
| 30 |
+
from torch.utils.data import Dataset, DataLoader
|
| 31 |
+
from typing import Optional, Tuple
|
| 32 |
+
from .utils import get_example, expand_to_aspect_ratio
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class TrainDataset(Dataset):
    """Dataset of animal images with masks, 2D/3D keypoints and SMAL parameters.

    Samples are described by a JSON annotation file whose ``'data'`` list
    holds one record per sample; image and mask paths in a record are
    resolved relative to ``root_image``.
    """

    def __init__(self, cfg: CfgNode, is_train: bool, root_image: str, json_file: str):
        """
        Args:
            cfg (CfgNode): Config providing SMAL, MODEL and DATASETS settings.
            is_train (bool): Forwarded to ``get_example`` for every sample.
            root_image (str): Directory that image and mask paths are relative to.
            json_file (str): Path to the JSON annotation file.
        """
        super().__init__()
        self.root_image = root_image
        self.focal_length = cfg.SMAL.get("FOCAL_LENGTH", 1000)

        with open(json_file, 'r') as f:
            self.data = json.load(f)

        self.is_train = is_train
        self.IMG_SIZE = cfg.MODEL.IMAGE_SIZE
        # Mean/std are configured for [0, 1] images; rescale them to [0, 255].
        self.MEAN = 255. * np.array(cfg.MODEL.IMAGE_MEAN)
        self.STD = 255. * np.array(cfg.MODEL.IMAGE_STD)
        self.use_skimage_antialias = cfg.DATASETS.get('USE_SKIMAGE_ANTIALIAS', False)
        border_modes = {
            'constant': cv2.BORDER_CONSTANT,
            'replicate': cv2.BORDER_REPLICATE,
        }
        self.border_mode = border_modes[cfg.DATASETS.get('BORDER_MODE', 'constant')]

        self.augm_config = cfg.DATASETS.CONFIG

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.data['data'])

    def __getitem__(self, item):
        """Load, crop and normalise the sample at index ``item``.

        Returns:
            dict: Image patch, mask, keypoints, bounding-box information,
            SMAL parameters with their availability flags, and metadata.
        """

        def _presence_flag(values):
            # 1.0 unless every entry is exactly zero (i.e. annotation missing).
            return np.array(0. if (values == 0).all() else 1., dtype=np.float32)

        record = self.data['data'][item]
        image = np.array(Image.open(os.path.join(self.root_image, record['img_path'])).convert("RGB"))
        mask = np.array(Image.open(os.path.join(self.root_image, record['mask_path'])).convert('L'))
        category_idx = record['supercategory']
        keypoint_2d = np.array(record['keypoint_2d'], dtype=np.float32)

        if 'keypoint_3d' in record:
            # Append a column of ones to the annotated 3D keypoints.
            ones_column = np.ones((len(record['keypoint_3d']), 1))
            keypoint_3d = np.concatenate((record['keypoint_3d'], ones_column), axis=-1).astype(np.float32)
        else:
            keypoint_3d = np.zeros((len(keypoint_2d), 4), dtype=np.float32)

        bbox = record['bbox']  # [x, y, w, h]
        center = np.array([(bbox[0] * 2 + bbox[2]) // 2, (bbox[1] * 2 + bbox[3]) // 2])

        # Missing SMAL annotations are replaced by zero vectors.
        if 'pose' in record:
            pose = np.array(record['pose'], dtype=np.float32)  # [105, ]
        else:
            pose = np.zeros(105, dtype=np.float32)
        if 'shape' in record:
            betas = np.array(record['shape'] + record['shape_extra'], dtype=np.float32)  # [41, ]
        else:
            betas = np.zeros(41, dtype=np.float32)
        if 'trans' in record:
            translation = np.array(record['trans'], dtype=np.float32)  # [3, ]
        else:
            translation = np.zeros(3, dtype=np.float32)

        has_pose = _presence_flag(pose)
        has_betas = _presence_flag(betas)
        has_translation = _presence_flag(translation)

        ori_keypoint_2d = keypoint_2d.copy()
        center_x, center_y = center[0], center[1]
        # Square crop sized by the longer side of the box.
        bbox_size = max(bbox[2], bbox[3])

        smal_params = {
            'global_orient': pose[:3],
            'pose': pose[3:],
            'betas': betas,
            'transl': translation,
        }
        has_smal_params = {
            'global_orient': has_pose,
            'pose': has_pose,
            'betas': has_betas,
            'transl': has_translation,
        }
        smal_params_is_axis_angle = {
            'global_orient': True,
            'pose': True,
            'betas': False,
            'transl': False,
        }

        augm_config = copy.deepcopy(self.augm_config)
        # The mask travels through cropping as a fourth image channel.
        img_rgba = np.concatenate([image, mask[:, :, None]], axis=2)
        (img_patch_rgba, keypoints_2d, keypoints_3d, smal_params, has_smal_params,
         img_size, trans, img_border_mask) = get_example(
            img_rgba,
            center_x, center_y,
            bbox_size, bbox_size,
            keypoint_2d, keypoint_3d,
            smal_params, has_smal_params,
            self.IMG_SIZE, self.IMG_SIZE,
            self.MEAN, self.STD, self.is_train, augm_config,
            is_bgr=False, return_trans=True,
            use_skimage_antialias=self.use_skimage_antialias,
            border_mode=self.border_mode
        )
        img_patch = img_patch_rgba[:3, :, :]
        mask_patch = (img_patch_rgba[3, :, :] / 255.0).clip(0, 1)
        # A mask that is empty after cropping is replaced by an all-ones mask.
        if (mask_patch < 0.5).all():
            mask_patch = np.ones_like(mask_patch)

        sample = {
            'img': img_patch,
            'mask': mask_patch,
            'keypoints_2d': keypoints_2d,
            'keypoints_3d': keypoints_3d,
            'orig_keypoints_2d': ori_keypoint_2d,
            'box_center': np.array(center.copy(), dtype=np.float32),
            'box_size': float(bbox_size),
            'img_size': np.array(1.0 * img_size[::-1].copy(), dtype=np.float32),
            'smal_params': smal_params,
            'has_smal_params': has_smal_params,
            'smal_params_is_axis_angle': smal_params_is_axis_angle,
            '_trans': trans,
            'focal_length': np.array([self.focal_length, self.focal_length], dtype=np.float32),
            'category': np.array(category_idx, dtype=np.int32),
            'supercategory': np.array(category_idx, dtype=np.int32),
            "img_border_mask": img_border_mask.astype(np.float32),
            "has_mask": np.array(1, dtype=np.float32),
        }
        return sample
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class EvaluationDataset(Dataset):
    """Evaluation dataset yielding one cropped sample per annotation entry.

    The annotation file is a JSON document whose ``'data'`` list holds one
    dict per sample. Each entry must provide ``img_path``, ``mask_path``,
    ``supercategory``, ``keypoint_2d`` and ``bbox``; ``keypoint_3d``,
    ``pose``, ``shape``/``shape_extra``, ``trans`` and ``bone`` are optional
    and are replaced by zero arrays (with a zero availability flag) when absent.
    """

    def __init__(self, root_image: str, json_file: str, augm_config,
                 focal_length: int=1000, image_size: int=256,
                 mean: List[float]=[0.485, 0.456, 0.406], std: List[float]=[0.229, 0.224, 0.225]):
        """
        Args:
            root_image: Directory that image and mask paths are joined onto.
            json_file: Path of the annotation JSON file.
            augm_config: Augmentation configuration forwarded to ``get_example``
                (a deep copy is made for every sample).
            focal_length: Focal length reported for both axes of each sample.
            image_size: Side length passed to ``get_example`` for the output patch.
            mean: Per-channel mean in [0, 1]; stored scaled by 255.
            std: Per-channel std in [0, 1]; stored scaled by 255.
        """
        super().__init__()
        self.root_image = root_image
        self.focal_length = focal_length

        with open(json_file, 'r') as f:
            self.data = json.load(f)

        self.is_train = False
        self.IMG_SIZE = image_size
        self.MEAN = 255. * np.array(mean)
        self.STD = 255. * np.array(std)
        self.use_skimage_antialias = False
        self.border_mode = cv2.BORDER_CONSTANT
        self.augm_config = augm_config

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.data['data'])

    def __getitem__(self, item):
        """Load, crop and package the sample at index ``item`` as a dict."""
        data = self.data['data'][item]
        key = data['img_path']
        # Context managers close the underlying files as soon as the pixel
        # data has been copied into numpy arrays.
        with Image.open(os.path.join(self.root_image, key)) as image_file:
            image = np.array(image_file.convert("RGB"))
        with Image.open(os.path.join(self.root_image, data['mask_path'])) as mask_file:
            mask = np.array(mask_file.convert('L'))
        category_idx = data['supercategory']
        keypoint_2d = np.array(data['keypoint_2d'], dtype=np.float32)
        # 3D keypoints are optional so that 2D-only datasets can be used too.
        # When present, a column of ones is appended as the fourth component.
        if 'keypoint_3d' in data:
            keypoint_3d = np.concatenate(
                (data['keypoint_3d'], np.ones((len(data['keypoint_3d']), 1))), axis=-1).astype(np.float32)
        else:
            keypoint_3d = np.zeros((len(keypoint_2d), 4), dtype=np.float32)
        bbox = data['bbox']  # [x, y, w, h]
        center = np.array([(bbox[0] * 2 + bbox[2]) // 2, (bbox[1] * 2 + bbox[3]) // 2])
        pose = np.array(data['pose'], dtype=np.float32) if 'pose' in data else np.zeros(105, dtype=np.float32)  # [105, ]
        betas = np.array(data['shape'] + data['shape_extra'], dtype=np.float32) if 'shape' in data else np.zeros(41, dtype=np.float32)  # [41, ]
        translation = np.array(data['trans'], dtype=np.float32) if 'trans' in data else np.zeros(3, dtype=np.float32)  # [3, ]
        # A parameter counts as available unless every element is exactly zero
        # (which is also what the missing-key fallbacks above produce).
        has_pose = np.array(1., dtype=np.float32) if not (pose == 0).all() else np.array(0., dtype=np.float32)
        has_betas = np.array(1., dtype=np.float32) if not (betas == 0).all() else np.array(0., dtype=np.float32)
        has_translation = np.array(1., dtype=np.float32) if not (translation == 0).all() else np.array(0., dtype=np.float32)
        ori_keypoint_2d = keypoint_2d.copy()
        center_x, center_y = center[0], center[1]

        scale = np.array([bbox[2], bbox[3]], dtype=np.float32) / 200.
        bbox_size = expand_to_aspect_ratio(scale*200, None).max()
        bbox_expand_factor = bbox_size / ((scale*200).max())

        smal_params = {'global_orient': pose[:3],
                       'pose': pose[3:],
                       'betas': betas,
                       'transl': translation,
                       # Load with an explicit float32 dtype so annotated samples
                       # match the float32 fallback and the other parameters.
                       'bone': np.zeros(24, dtype=np.float32) if 'bone' not in data else np.array(data['bone'], dtype=np.float32)
                       }
        has_smal_params = {'global_orient': has_pose,
                           'pose': has_pose,
                           'betas': has_betas,
                           'transl': has_translation,
                           'bone': np.array(1., dtype=np.float32) if 'bone' in data else np.array(0., dtype=np.float32),
                           }
        smal_params_is_axis_angle = {'global_orient': True,
                                     'pose': True,
                                     'betas': False,
                                     'transl': False,
                                     'bone': False
                                     }

        augm_config = copy.deepcopy(self.augm_config)
        # The mask travels through get_example as a fourth image channel.
        img_rgba = np.concatenate([image, mask[:, :, None]], axis=2)
        img_patch_rgba, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, trans, img_border_mask = get_example(
            img_rgba,
            center_x, center_y,
            bbox_size, bbox_size,
            keypoint_2d, keypoint_3d,
            smal_params, has_smal_params,
            self.IMG_SIZE, self.IMG_SIZE,
            self.MEAN, self.STD, self.is_train, augm_config,
            is_bgr=False, return_trans=True,
            use_skimage_antialias=self.use_skimage_antialias,
            border_mode=self.border_mode
        )
        img_patch = (img_patch_rgba[:3, :, :])
        mask_patch = (img_patch_rgba[3, :, :] / 255.0).clip(0, 1)
        # If no pixel of the cropped mask reaches 0.5, fall back to an all-ones mask.
        if (mask_patch < 0.5).all():
            mask_patch = np.ones_like(mask_patch)

        item = {'img': img_patch,
                'mask': mask_patch,
                'keypoints_2d': keypoints_2d,
                'keypoints_3d': keypoints_3d,
                'orig_keypoints_2d': ori_keypoint_2d,
                'box_center': np.array(center.copy(), dtype=np.float32),
                'box_size': float(bbox_size),
                # NOTE(review): img_size is reversed here; its axis order comes
                # from get_example and should be confirmed there.
                'img_size': np.array(1.0 * img_size[::-1].copy(), dtype=np.float32),
                'smal_params': smal_params,
                'has_smal_params': has_smal_params,
                'smal_params_is_axis_angle': smal_params_is_axis_angle,
                '_trans': trans,
                'focal_length': np.array([self.focal_length, self.focal_length], dtype=np.float32),
                'category': np.array(category_idx, dtype=np.int32),
                'bbox_expand_factor': bbox_expand_factor,
                'supercategory': np.array(category_idx, dtype=np.int32),
                "img_border_mask": img_border_mask.astype(np.float32),
                'has_mask': np.array(1., dtype=np.float32),
                'imgname': key,
                'bbox': np.array(bbox, dtype=np.float32)}
        return item
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class OptionAnimalDataset(Dataset):
    """Concatenation of one ``TrainDataset`` per entry of ``cfg.DATASETS``.

    The ``"CONFIG"`` key of ``cfg.DATASETS`` is skipped. ``self.weights``
    holds one value per sample: the ``WEIGHT`` of the dataset it came from.
    """

    def __init__(self, cfg: CfgNode):
        members = []
        per_sample_weights = []

        dataset_configs = cfg.DATASETS
        for dataset_name in dataset_configs:
            if dataset_name == "CONFIG":
                continue
            entry = dataset_configs[dataset_name]
            subset = TrainDataset(cfg,
                                  is_train=True,
                                  root_image=entry.ROOT_IMAGE,
                                  json_file=entry.JSON_FILE.TRAIN)
            members.append(subset)
            # Repeat the dataset weight once for every sample it contributes.
            per_sample_weights.extend([entry.WEIGHT] * len(subset))

        # Refuse to build an empty dataset.
        if not members:
            raise ValueError("No datasets enabled in the configuration.")
        self.dataset = ConcatDataset(members)
        self.weights = torch.tensor(per_sample_weights, dtype=torch.float32)

    def __len__(self):
        """Return the total number of samples over all member datasets."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` of the concatenated dataset."""
        return self.dataset[idx]
|
| 278 |
+
|
prima/datasets/dlc2coco.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
'''
|
| 11 |
+
This script converts DeepLabCut labeled data (20 keypoints) to COCO format (26 keypoints). The labeled frames are also extracted from the raw video and saved as images.
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python dlc2coco.py --dataset_dir /path/to/dataset --output_dir /path/to/output
|
| 15 |
+
|
| 16 |
+
for camera x
|
| 17 |
+
dlc keypoint data: <dataset_dir>/<behavior>/fte_pw/camx_fte.csv,
|
| 18 |
+
where video frame index from the video and keypoint coordinates are stored
|
| 19 |
+
raw video: <dataset_dir>/<behavior>/camx.mp4
|
| 20 |
+
|
| 21 |
+
for coco format, please refer to:
|
| 22 |
+
./datasets/quadruped2d/test.json
|
| 23 |
+
|
| 24 |
+
also, the relationship of multiview should be saved.
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
keypoint mapping from acinoset to animal3d :
|
| 28 |
+
keypoint_mapping = {"acinoset":[2, 1, -1, 13, 10, 19, 16, 5, -1, -1, -1, -1, 11, 8, 12, 9, 18, 15, 3, 7, -1,-1,-1,-1, 0, 6]}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
'''
|
| 32 |
+
|
| 33 |
+
import argparse
|
| 34 |
+
import os
|
| 35 |
+
import json
|
| 36 |
+
import cv2
|
| 37 |
+
import numpy as np
|
| 38 |
+
import pandas as pd
|
| 39 |
+
from pathlib import Path
|
| 40 |
+
from tqdm import tqdm
|
| 41 |
+
|
| 42 |
+
# DLC keypoints (20 keypoints from acinoset), by index:
# 0: nose, 1: r_eye, 2: l_eye, 3: neck_base, 4: spine, 5: tail_base, 6: tail1, 7: tail2,
# 8: r_shoulder, 9: r_front_knee, 10: r_front_ankle, 11: l_shoulder, 12: l_front_knee, 13: l_front_ankle,
# 14: r_hip, 15: r_back_knee, 16: r_back_ankle, 17: l_hip, 18: l_back_knee, 19: l_back_ankle

# Animal3D keypoints (26 keypoints):
# Position i of the list below is the Animal3D keypoint index; the value stored
# there is the acinoset (DLC) keypoint index it is copied from.
# Example: animal3d_idx 0 maps to acinoset_idx 2 (l_eye), animal3d_idx 1 maps to acinoset_idx 1 (r_eye), etc.
# A value of -1 marks an Animal3D keypoint with no acinoset counterpart;
# map_keypoints_to_animal3d emits [0, 0, 0] for those positions.

# Keypoint mapping from acinoset (DLC) to animal3d (COCO format)
KEYPOINT_MAPPING = [2, 1, -1, 13, 10, 19, 16, 5, -1, -1, -1, -1, 11, 8, 12, 9, 18, 15, 3, 7, -1, -1, -1, -1, 0, 6]
|
| 53 |
+
|
| 54 |
+
def read_dlc_csv(csv_path):
    """
    Load a DeepLabCut CSV file and return its keypoints frame by frame.

    The first two lines of the file are skipped and the next line is used as
    the column header. Column 0 holds the frame index; it is followed by
    (x, y, likelihood) triplets for 20 keypoints. Missing values become 0.

    Returns: list of dicts with keys 'frame_idx' (int) and 'keypoints'
    (list of 20 [x, y, likelihood] lists)
    """
    num_keypoints = 20

    table = pd.read_csv(csv_path, skiprows=2).fillna(0)

    records = []
    for row in table.itertuples(index=False, name=None):
        triplets = []
        for kp_idx in range(num_keypoints):
            base = 1 + kp_idx * 3
            x = float(row[base])
            y = float(row[base + 1])
            likelihood = float(row[base + 2])

            # A zero likelihood (e.g. from a filled NaN) next to a non-zero
            # coordinate is treated as a confident, valid point.
            if likelihood == 0 and not (x == 0 and y == 0):
                likelihood = 1.0

            triplets.append([x, y, likelihood])

        records.append({'frame_idx': int(row[0]), 'keypoints': triplets})

    return records
|
| 93 |
+
|
| 94 |
+
def map_keypoints_to_animal3d(acinoset_keypoints):
    """
    Reorder 20 DLC keypoints into the 26-keypoint Animal3D layout.

    acinoset_keypoints: list of [x, y, likelihood] for 20 keypoints
    Returns: list of [x, y, visibility] for 26 keypoints, following
    KEYPOINT_MAPPING. Visibility is 2.0 when either coordinate is non-zero
    and 0.0 otherwise; the likelihood value does not influence it.
    """
    def _nan_to_zero(value):
        # NaN entries are treated as missing and collapse to 0.
        return 0.0 if np.isnan(value) else value

    remapped = []
    for source_idx in KEYPOINT_MAPPING:
        if source_idx == -1:
            # No acinoset counterpart: emit an unlabeled placeholder.
            remapped.append([0.0, 0.0, 0.0])
            continue

        x, y, _likelihood = (_nan_to_zero(v) for v in acinoset_keypoints[source_idx])

        # Visibility flag: 2 = visible, 0 = not labeled.
        visibility = 0.0 if (x == 0.0 and y == 0.0) else 2.0
        remapped.append([float(x), float(y), visibility])

    return remapped
|
| 126 |
+
|
| 127 |
+
def extract_frames_from_video(video_path, output_dir, frame_indices, behavior, camera_id):
    """
    Extract specific frames from a video and save them as JPEG images.

    video_path: path of the video file (str or Path)
    output_dir: directory the images are written to (created if missing)
    frame_indices: iterable of frame indices to extract; duplicates are ignored
    behavior, camera_id: used to build the image file names
    Returns: dict mapping frame_idx to the image path, expressed relative to
    the directory two levels above output_dir. Frames that cannot be read
    or written are reported on stdout and left out of the dict. An empty dict
    is returned when the video cannot be opened.
    """
    video_path = Path(video_path)  # accept plain strings as well as Path objects
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print(f"Error: Cannot open video {video_path}")
        cap.release()
        return {}

    frame_paths = {}

    # Sorted, de-duplicated indices so seeking moves through the video in order.
    sorted_frames = sorted(set(frame_indices))
    path_root = output_dir.parent.parent

    try:
        # The context manager closes the progress bar even if a frame fails.
        with tqdm(total=len(sorted_frames), desc=f"Extracting frames from {video_path.name}") as pbar:
            for target_frame in sorted_frames:
                # Seek to the requested frame before reading it.
                cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
                ret, frame = cap.read()

                if ret and frame is not None:
                    img_filename = f"{behavior}_cam{camera_id}_frame_{target_frame:06d}.jpg"
                    img_path = output_dir / img_filename
                    # Only record the frame when the image really reached disk,
                    # so callers never reference a file that does not exist.
                    if cv2.imwrite(str(img_path), frame):
                        frame_paths[target_frame] = str(img_path.relative_to(path_root))
                    else:
                        print(f"Warning: Failed to write frame {target_frame} to {img_path}")
                else:
                    print(f"Warning: Failed to read frame {target_frame} from {video_path.name}")

                pbar.update(1)
    finally:
        # Always release the capture handle, including on errors.
        cap.release()

    return frame_paths
|
| 167 |
+
|
| 168 |
+
def compute_bbox_from_keypoints(keypoints, padding=20):
    """
    Compute a padded bounding box around the labeled keypoints.

    keypoints: list of [x, y, visibility]; only entries with visibility > 0 are used
    padding: margin added on every side of the keypoint extent
    Returns: [x, y, width, height], or [0, 0, 0, 0] when no keypoint is labeled.
    The top-left corner is clamped to be non-negative; width and height are
    measured from that clamped corner to the padded bottom-right corner.
    """
    valid_points = [(kp[0], kp[1]) for kp in keypoints if kp[2] > 0]

    if not valid_points:
        return [0, 0, 0, 0]

    xs, ys = zip(*valid_points)

    # Pad each side exactly once. Deriving the size from both padded corners
    # keeps the box symmetric around the keypoints and consistent with the
    # clamping of the top-left corner.
    x_min = max(0, min(xs) - padding)
    y_min = max(0, min(ys) - padding)
    x_max = max(xs) + padding
    y_max = max(ys) + padding

    return [float(x_min), float(y_min), float(x_max - x_min), float(y_max - y_min)]
|
| 191 |
+
|
| 192 |
+
def process_camera(camera_id, base_dir, output_dir, behavior):
    """
    Convert the annotations of one camera into COCO-style entries.

    Reads <base_dir>/fte_pw/cam<camera_id>_fte.csv, extracts the annotated
    frames from <base_dir>/cam<camera_id>.mp4 into
    <output_dir>/images/<behavior>/cam<camera_id>, and builds one entry per
    frame whose image was extracted.
    behavior: name of the behavior (e.g., 'run', 'flick')
    Returns: list of entry dicts
    """
    base_dir = Path(base_dir)
    output_dir = Path(output_dir)

    # Input locations for this camera
    csv_path = base_dir / "fte_pw" / f"cam{camera_id}_fte.csv"
    video_path = base_dir / f"cam{camera_id}.mp4"

    print(f"\nProcessing Camera {camera_id} - Behavior: {behavior}...")
    print(f"CSV: {csv_path}")
    print(f"Video: {video_path}")

    keypoints_data = read_dlc_csv(csv_path)
    print(f"Found {len(keypoints_data)} frames with keypoints")

    # Dump every annotated frame of the video to disk
    images_dir = output_dir / "images" / behavior / f"cam{camera_id}"
    wanted_frames = [record['frame_idx'] for record in keypoints_data]
    frame_paths = extract_frames_from_video(video_path, images_dir, wanted_frames, behavior, camera_id)

    entries = []
    for record in tqdm(keypoints_data, desc=f"Converting cam{camera_id} to COCO format"):
        frame_idx = record['frame_idx']

        # Frames without an extracted image are dropped.
        img_path = frame_paths.get(frame_idx)
        if img_path is None:
            continue

        # acinoset (20 keypoints) -> animal3d (26 keypoints)
        animal3d_kps = map_keypoints_to_animal3d(record['keypoints'])

        entries.append({
            "img_path": img_path,
            "mask_path": img_path,  # Same as img_path
            "bbox": compute_bbox_from_keypoints(animal3d_kps),
            "keypoint_2d": animal3d_kps,
            "camera_id": camera_id,
            "frame_idx": frame_idx,
            "behavior": behavior
        })

    return entries
|
| 247 |
+
|
| 248 |
+
def parse_args():
    """Parse the command-line options of the DLC-to-COCO converter."""
    parser = argparse.ArgumentParser(
        description="Convert DeepLabCut labeled data to COCO format"
    )

    # (flag, keyword arguments) pairs, registered in order below.
    option_specs = (
        ("--dataset_dir", dict(
            type=str, default=".",
            help="Root directory containing behavior subdirectories (run, flick, etc.)")),
        ("--output_dir", dict(
            type=str, default=None,
            help="Output directory for COCO format data (default: {dataset_dir}/coco_format)")),
        ("--behaviors", dict(
            type=str, nargs="+", default=["run", "flick"],
            help="Behavior names to process (default: run flick)")),
        ("--cameras", dict(
            type=int, nargs="+", default=[1, 2, 3, 4, 5, 6],
            help="Camera IDs to process (default: 1 2 3 4 5 6)")),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def main():
    """Run the conversion for every requested behavior and camera.

    Writes all_data.json, one <behavior>.json per behavior, one
    <behavior>_cam<id>.json per behavior/camera pair and
    multiview_mapping.json into the output directory.
    """
    args = parse_args()

    dataset_dir = Path(args.dataset_dir)
    output_dir = Path(args.output_dir) if args.output_dir else dataset_dir / "coco_format"
    output_dir.mkdir(parents=True, exist_ok=True)

    behaviors = args.behaviors
    camera_ids = args.cameras

    def write_json(target, payload):
        # Every output file uses the same 4-space indented layout.
        with open(target, 'w') as f:
            json.dump(payload, f, indent=4)

    all_data = []
    behavior_data = {}
    camera_data = {}

    for behavior in behaviors:
        behavior_dir = dataset_dir / behavior
        per_behavior = behavior_data[behavior] = []

        print(f"\n{'='*60}")
        print(f"Processing Behavior: {behavior.upper()}")
        print(f"{'='*60}")

        for cam_id in camera_ids:
            coco_data = process_camera(cam_id, behavior_dir, output_dir, behavior)
            all_data.extend(coco_data)
            per_behavior.extend(coco_data)

            # Keep the entries of this behavior/camera pair for its own file
            camera_data[f"{behavior}_cam{cam_id}"] = coco_data

    # Combined file (all behaviors and cameras)
    output_json = output_dir / "all_data.json"
    write_json(output_json, {"data": all_data})

    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    print(f"Saved combined data to {output_json}")
    print(f"Total entries: {len(all_data)}")

    # One file per behavior
    for behavior in behaviors:
        behavior_json = output_dir / f"{behavior}.json"
        write_json(behavior_json, {"data": behavior_data[behavior]})
        print(f"\nSaved {behavior} data to {behavior_json} ({len(behavior_data[behavior])} entries)")

    # One file per behavior/camera pair
    for behavior in behaviors:
        for cam_id in camera_ids:
            key = f"{behavior}_cam{cam_id}"
            write_json(output_dir / f"{behavior}_cam{cam_id}.json", {"data": camera_data[key]})
            print(f"  - {behavior}_cam{cam_id}: {len(camera_data[key])} entries")

    # Multiview relationships: entries sharing a behavior and frame index are
    # grouped together, keyed by camera.
    multiview_data = {}
    for entry in all_data:
        views = multiview_data.setdefault(entry['behavior'], {}).setdefault(entry['frame_idx'], {})
        views[f"cam{entry['camera_id']}"] = {
            "img_path": entry['img_path'],
            "keypoint_2d": entry['keypoint_2d'],
            "bbox": entry['bbox']
        }

    multiview_json = output_dir / "multiview_mapping.json"
    write_json(multiview_json, multiview_data)

    print(f"\nSaved multiview mapping to {multiview_json}")
    for behavior in behaviors:
        print(f"  - {behavior}: {len(multiview_data.get(behavior, {}))} synchronized frames")

    print(f"\n{'='*60}")
    print("Conversion complete!")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()
|
prima/datasets/split_acinoset.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Split acinoset multiview_mapping.json into train and test sets (7:3 ratio).
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python split_acinoset.py \
|
| 15 |
+
--input_json /path/to/multiview_mapping.json \
|
| 16 |
+
--output_dir /path/to/output \
|
| 17 |
+
--train_ratio 0.7 \
|
| 18 |
+
--seed 42
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import random
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from collections import defaultdict
|
| 26 |
+
|
| 27 |
+
# ------------------------------------------------------------------
# Dataset root directory; edit it to point to your own dataset root.
# NOTE(review): BASE_DIR is not referenced by split_multiview_data or by the
# command-line defaults in this file (those spell out "datasets/..."
# directly) - confirm whether anything else still relies on it.
# ------------------------------------------------------------------
BASE_DIR = Path("datasets")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def split_multiview_data(input_json, output_dir, train_ratio=0.7, seed=42):
    """
    Split multiview mapping data into train and test sets.

    Frames are shuffled and split independently for every behavior, so all
    camera views of one frame always end up in the same split.

    Args:
        input_json: Path to multiview_mapping.json
            (a {behavior: {frame_idx: views}} mapping)
        output_dir: Directory to save train.json and test.json
            (created if missing)
        train_ratio: Ratio of training data (default 0.7 for 70%); the number
            of training frames per behavior is rounded down
        seed: Random seed for reproducibility (seeds the global ``random`` module)
    """
    # Set random seed
    random.seed(seed)

    # Load data
    print(f"Loading data from {input_json}...")
    with open(input_json, 'r') as f:
        data = json.load(f)

    # Initialize train and test splits
    train_data = defaultdict(dict)
    test_data = defaultdict(dict)

    # Process each behavior
    for behavior, frames in data.items():
        print(f"\nProcessing behavior: {behavior}")

        # Get all frame indices
        frame_indices = list(frames.keys())
        total_frames = len(frame_indices)

        # Shuffle frame indices
        random.shuffle(frame_indices)

        # Calculate split point
        train_size = int(total_frames * train_ratio)

        # Split frames
        train_frames = frame_indices[:train_size]
        test_frames = frame_indices[train_size:]

        print(f"  Total frames: {total_frames}")
        print(f"  Train frames: {len(train_frames)}")
        print(f"  Test frames: {len(test_frames)}")

        # Assign to train and test
        for frame_idx in train_frames:
            train_data[behavior][frame_idx] = frames[frame_idx]

        for frame_idx in test_frames:
            test_data[behavior][frame_idx] = frames[frame_idx]

    # Save train and test splits
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_json = output_dir / "train.json"
    test_json = output_dir / "test.json"

    print(f"\nSaving train data to {train_json}...")
    with open(train_json, 'w') as f:
        json.dump(dict(train_data), f, indent=4)

    print(f"Saving test data to {test_json}...")
    with open(test_json, 'w') as f:
        json.dump(dict(test_data), f, indent=4)

    # Print summary
    print("\n" + "="*50)
    print("Summary:")
    print("="*50)

    total_train_frames = sum(len(frames) for frames in train_data.values())
    total_test_frames = sum(len(frames) for frames in test_data.values())
    total_frames = total_train_frames + total_test_frames

    def _share(count):
        # Percentage of all frames; an empty input must not divide by zero.
        return count / total_frames * 100 if total_frames else 0.0

    print(f"Total frames: {total_frames}")
    print(f"Train frames: {total_train_frames} ({_share(total_train_frames):.1f}%)")
    print(f"Test frames: {total_test_frames} ({_share(total_test_frames):.1f}%)")
    print("\nPer behavior:")
    # Iterate over the input behaviors so that a behavior whose frames all
    # landed in the test split is still listed.
    for behavior in data:
        train_count = len(train_data.get(behavior, {}))
        test_count = len(test_data.get(behavior, {}))
        total_count = train_count + test_count
        print(f"  {behavior}: train={train_count}, test={test_count}, total={total_count}")

    print("\nDone!")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
if __name__ == "__main__":
    # Command-line entry point: collect the options and run the split.
    cli = argparse.ArgumentParser(
        description="Split multiview_mapping.json into train/test sets (default 7:3)."
    )

    # (flag, keyword arguments) pairs, registered in order below.
    cli_options = (
        ("--input_json", dict(
            type=str,
            default="datasets/acinoset/multiview_mapping.json",
            help="Path to multiview_mapping.json (default: datasets/acinoset/multiview_mapping.json).")),
        ("--output_dir", dict(
            type=str,
            default="datasets/acinoset",
            help="Directory to save train.json and test.json (default: datasets/acinoset).")),
        ("--train_ratio", dict(
            type=float, default=0.7,
            help="Fraction of data for training (default: 0.7).")),
        ("--seed", dict(
            type=int, default=42,
            help="Random seed for reproducibility (default: 42).")),
    )
    for flag, spec in cli_options:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()

    split_multiview_data(
        input_json=options.input_json,
        output_dir=options.output_dir,
        train_ratio=options.train_ratio,
        seed=options.seed,
    )
|
prima/datasets/utils.py
ADDED
|
@@ -0,0 +1,1106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Parts of the code are taken or adapted from
|
| 12 |
+
https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py
|
| 13 |
+
"""
|
| 14 |
+
import torch
|
| 15 |
+
import numpy as np
|
| 16 |
+
from skimage.transform import rotate, resize
|
| 17 |
+
from skimage.filters import gaussian
|
| 18 |
+
import random
|
| 19 |
+
import cv2
|
| 20 |
+
from typing import List, Dict, Tuple
|
| 21 |
+
from yacs.config import CfgNode
|
| 22 |
+
from typing import Union
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
    """
    Grow a (width, height) box so that it matches a target aspect ratio.

    Only one side is enlarged; the box is never shrunk.
    Args:
        input_shape: Pair (w, h). Anything that cannot be unpacked into two values is returned unchanged.
        target_aspect_ratio: Pair (w_t, h_t) describing the desired ratio, or None to return the input as is.
    Returns:
        np.array: Expanded (w_new, h_new), or `input_shape` itself in the pass-through cases.
    Raises:
        ValueError: If the computed size is smaller than the original on either side.
    """
    if target_aspect_ratio is None:
        return input_shape

    try:
        w, h = input_shape
    except (ValueError, TypeError):
        # Not a two-element shape: hand it back untouched.
        return input_shape

    w_t, h_t = target_aspect_ratio
    box_is_too_flat = h / w < h_t / w_t
    if box_is_too_flat:
        # Keep the width and stretch the height.
        w_new, h_new = w, max(w * h_t / w_t, h)
    else:
        # Keep the height and stretch the width.
        w_new, h_new = max(h * w_t / h_t, w), h

    if h_new < h or w_new < w:
        raise ValueError(f"Expanded size ({w_new}, {h_new}) smaller than original ({w}, {h})")
    return np.array([w_new, h_new])
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def do_augmentation(aug_config: CfgNode) -> Tuple:
    """
    Draw one set of random augmentation parameters.
    Args:
        aug_config (CfgNode): Config containing augmentation parameters.
    Returns:
        scale (float): Box rescaling factor.
        rot (float): Random image rotation (0 when the rotation draw is skipped).
        do_flip (bool): Whether to flip image or not.
        do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT).
        extreme_crop_lvl: Value of EXTREME_CROP_AUG_LEVEL from the config (0 if absent).
        color_scale (List): Color rescaling factor, one entry per channel.
        tx (float): Random translation along the x axis.
        ty (float): Random translation along the y axis.
    """

    def _clipped_normal(bound):
        # Standard-normal sample truncated to [-bound, bound].
        return np.clip(np.random.randn(), -bound, bound)

    # NOTE: the order of the draws below is kept fixed so that seeded runs
    # consume the random streams in the same sequence.
    tx = _clipped_normal(1.0) * aug_config.TRANS_FACTOR
    ty = _clipped_normal(1.0) * aug_config.TRANS_FACTOR
    scale = _clipped_normal(1.0) * aug_config.SCALE_FACTOR + 1.0

    if random.random() <= aug_config.ROT_AUG_RATE:
        rot = _clipped_normal(2.0) * aug_config.ROT_FACTOR
    else:
        rot = 0

    # Short-circuits: no random number is consumed when flipping is disabled.
    do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
    do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
    extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)

    c_up = 1.0 + aug_config.COLOR_SCALE
    c_low = 1.0 - aug_config.COLOR_SCALE
    color_scale = [random.uniform(c_low, c_up) for _ in range(3)]

    return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
    """
    Rotate a 2D point about the origin in the x-y plane.
    Args:
        pt_2d (np.array): Input 2D point with shape (2,).
        rot_rad (float): Rotation angle in radians (passed directly to sin/cos).
    Returns:
        np.array: Rotated 2D point as a float32 array of shape (2,).
    """
    px, py = pt_2d[0], pt_2d[1]
    sin_r = np.sin(rot_rad)
    cos_r = np.cos(rot_rad)
    rotated = [px * cos_r - py * sin_r,
               px * sin_r + py * cos_r]
    return np.array(rotated, dtype=np.float32)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def gen_trans_from_patch_cv(c_x: float, c_y: float,
                            src_width: float, src_height: float,
                            dst_width: float, dst_height: float,
                            scale: float, rot: float) -> np.array:
    """
    Build the affine transform that maps a (scaled, rotated) source box onto the output patch.

    The transform is defined by three point correspondences: the box center,
    the point half a box-height "down" from it and the point half a box-width
    "right" of it.
    Args:
        c_x (float): Bounding box center x coordinate in the original image.
        c_y (float): Bounding box center y coordinate in the original image.
        src_width (float): Bounding box width.
        src_height (float): Bounding box height.
        dst_width (float): Output box width.
        dst_height (float): Output box height.
        scale (float): Rescaling factor for the bounding box (augmentation).
        rot (float): Rotation applied to the box, in degrees.
    Returns:
        trans (np.array): Affine matrix returned by cv2.getAffineTransform.
    """
    # Source side: scaled box half-extents, rotated by `rot` degrees.
    angle_rad = np.pi * rot / 180
    scaled_w = src_width * scale
    scaled_h = src_height * scale
    src_center = np.array([c_x, c_y], dtype=np.float64)
    src_down = rotate_2d(np.array([0, scaled_h * 0.5], dtype=np.float32), angle_rad)
    src_right = rotate_2d(np.array([scaled_w * 0.5, 0], dtype=np.float32), angle_rad)

    # Destination side: axis-aligned patch of size (dst_width, dst_height).
    dst_center = np.array([dst_width * 0.5, dst_height * 0.5], dtype=np.float32)
    dst_down = np.array([0, dst_height * 0.5], dtype=np.float32)
    dst_right = np.array([dst_width * 0.5, 0], dtype=np.float32)

    src = np.stack([src_center,
                    src_center + src_down,
                    src_center + src_right]).astype(np.float32)
    dst = np.stack([dst_center,
                    dst_center + dst_down,
                    dst_center + dst_right]).astype(np.float32)

    return cv2.getAffineTransform(src, dst)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def trans_point2d(pt_2d: np.array, trans: np.array):
    """
    Apply a transformation matrix to a single 2D point.

    The point is lifted to homogeneous coordinates (x, y, 1) before the
    matrix product, and only the first two components are returned.
    Args:
        pt_2d (np.array): Input 2D point with shape (2,).
        trans (np.array): Transformation matrix with three columns.
    Returns:
        np.array: Transformed 2D point.
    """
    homogeneous = np.array([pt_2d[0], pt_2d[1], 1.])
    return np.dot(trans, homogeneous)[:2]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_transform(center, scale, res, rot=0):
    """
    Generate the 3x3 transformation matrix for a crop described by center/scale.

    Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py
    Args:
        center: Crop center; center[0] is used for x and center[1] for y.
        scale: Crop scale; the box side is taken to be 200 * scale.
        res: Output resolution; res[1] scales the x axis and res[0] the y axis.
        rot: Rotation in degrees, applied about the middle of the output.
    Returns:
        np.array: 3x3 homogeneous transformation matrix.
    """
    h = 200 * scale
    res_y, res_x = res[0], res[1]

    t = np.zeros((3, 3))
    t[0, 0] = float(res_x) / h
    t[1, 1] = float(res_y) / h
    t[0, 2] = res_x * (-float(center[0]) / h + .5)
    t[1, 2] = res_y * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if rot == 0:
        return t

    # Sign is flipped to match the direction of rotation used when cropping.
    rot_rad = -rot * np.pi / 180
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
    rot_mat = np.zeros((3, 3))
    rot_mat[0, :2] = [cs, -sn]
    rot_mat[1, :2] = [sn, cs]
    rot_mat[2, 2] = 1

    # Rotate around the output center: shift to origin, rotate, shift back.
    to_origin = np.eye(3)
    to_origin[0, 2] = -res_x / 2
    to_origin[1, 2] = -res_y / 2
    from_origin = to_origin.copy()
    from_origin[:2, 2] *= -1
    return np.dot(from_origin, np.dot(rot_mat, np.dot(to_origin, t)))
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
    """
    Transform a pixel location to a different reference frame.

    Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py
    Args:
        pt: Pixel location (x, y); handled as 1-based (1 is subtracted before and added after).
        center, scale, res, rot: Forwarded to get_transform().
        invert: If truthy, apply the inverse of the matrix from get_transform().
        as_int (bool): If True, cast the transformed point to int before the final +1 shift.
    Returns:
        np.array: Transformed (x, y) location.
    """
    mat = get_transform(center, scale, res, rot=rot)
    if invert:
        mat = np.linalg.inv(mat)

    zero_based = np.array([pt[0] - 1, pt[1] - 1, 1.])
    mapped = np.dot(mat, zero_based)
    if as_int:
        mapped = mapped.astype(int)
    return mapped[:2] + 1
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0):
    """
    Crop the axis-aligned box spanned by `ul` and `br` out of `img`.

    The crop is done with an affine warp (scale 1, no rotation), so regions
    outside the image are filled according to `border_mode`.
    Args:
        img (np.array): Input image, either (H, W) or (H, W, C).
        ul: Upper-left corner (x, y) of the box.
        br: Bottom-right corner (x, y) of the box.
        border_mode: OpenCV border mode used for the warp.
        border_value: Fill value passed to OpenCV as `borderValue`.
    Returns:
        np.array: Cropped patch with int(br[0] - ul[0]) columns and int(br[1] - ul[1]) rows.
    """
    c_x = (ul[0] + br[0]) / 2
    c_y = (ul[1] + br[1]) / 2
    # Source box and output patch have the same size: a pure crop.
    bb_width = patch_width = br[0] - ul[0]
    bb_height = patch_height = br[1] - ul[1]
    out_size = (int(patch_width), int(patch_height))

    trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0)
    img_patch = cv2.warpAffine(img, trans, out_size,
                               flags=cv2.INTER_LINEAR,
                               borderMode=border_mode,
                               borderValue=border_value
                               )

    # Force borderMode=cv2.BORDER_CONSTANT for the alpha channel so padded
    # pixels stay at the default fill instead of replicated/reflected alpha.
    # The ndim check keeps 2-D (single-channel) inputs from raising IndexError.
    has_alpha = img.ndim == 3 and img.shape[2] == 4
    if has_alpha and border_mode != cv2.BORDER_CONSTANT:
        img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, out_size,
                                            flags=cv2.INTER_LINEAR,
                                            borderMode=cv2.BORDER_CONSTANT,
                                            )

    return img_patch
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float,
                                 bb_width: float, bb_height: float,
                                 patch_width: float, patch_height: float,
                                 do_flip: bool, scale: float, rot: float,
                                 border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
    """
    Crop image according to the supplied bounding box, resampling with skimage.

    The box is cropped with crop_img(), optionally rotated with
    skimage.transform.rotate, and resized with skimage.transform.resize.
    Both the bounding box and the output patch must be square.
    Args:
        img (np.array): Input image of shape (H, W, C).
        c_x (float): Bounding box center x coordinate in the original image.
        c_y (float): Bounding box center y coordinate in the original image.
        bb_width (float): Bounding box width.
        bb_height (float): Bounding box height (must equal bb_width).
        patch_width (float): Output box width.
        patch_height (float): Output box height (must equal patch_width).
        do_flip (bool): Whether to flip image or not.
        scale (float): Rescaling factor for the bounding box (augmentation).
        rot (float): Random rotation applied to the box.
        border_mode: OpenCV border mode forwarded to crop_img().
        border_value: Border fill value forwarded to crop_img().
    Returns:
        img_patch (np.array): Cropped uint8 image patch.
        trans (np.array): Transformation matrix from gen_trans_from_patch_cv().
    Raises:
        RuntimeError: If the rotation padding cannot be computed.
        ValueError: If the cropped patch ends up with an empty dimension.
    """
    # Unpacking also enforces a 3-dimensional input.
    _, img_width, _ = img.shape
    if do_flip:
        img = img[:, ::-1, :]
        c_x = img_width - c_x - 1

    trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)

    center = np.array([c_x, c_y], dtype=np.float64)
    res = np.array([patch_width, patch_height], dtype=np.float64)
    # The PARE-style helpers below assume square boxes and square patches.
    assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
    assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'
    # transform()/get_transform() measure the box side in units of 200 px.
    scale1 = scale * bb_width / 200.

    # Corners of the crop box in source-image coordinates.
    ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
    br = np.array(transform([res[0] + 1, res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1

    # Padding so that when rotated proper amount of context is included
    try:
        pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
    except Exception as e:
        raise RuntimeError(f"Failed to compute pad: ul={ul}, br={br}") from e

    rotating = not rot == 0
    if rotating:
        ul -= pad
        br += pad

    new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)

    if rotating:
        # Rotate the padded crop, then strip the padding again.
        new_img = rotate(new_img, rot)
        new_img = new_img[pad:-pad, pad:-pad]

    if new_img.shape[0] < 1 or new_img.shape[1] < 1:
        raise ValueError(
            f"Image patch too small: {new_img.shape}, original: {img.shape}, "
            f"ul={ul}, br={br}, pad={pad}, rot={rot}"
        )

    # Resize to the requested patch resolution and convert back to 8-bit.
    new_img = resize(new_img, res)
    new_img = np.clip(new_img, 0, 255).astype(np.uint8)

    return new_img, trans
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float,
                             bb_width: float, bb_height: float,
                             patch_width: float, patch_height: float,
                             do_flip: bool, scale: float, rot: float,
                             border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array, np.array]:
    """
    Crop the input image and return the crop, the transformation matrix and a border mask.
    Args:
        img (np.array): Input image of shape (H, W, C).
        c_x (float): Bounding box center x coordinate in the original image.
        c_y (float): Bounding box center y coordinate in the original image.
        bb_width (float): Bounding box width.
        bb_height (float): Bounding box height.
        patch_width (float): Output box width.
        patch_height (float): Output box height.
        do_flip (bool): Whether to flip image or not.
        scale (float): Rescaling factor for the bounding box (augmentation).
        rot (float): Random rotation applied to the box.
        border_mode: OpenCV border mode used for the warp.
        border_value: Fill value passed to OpenCV as `borderValue`.
    Returns:
        img_patch (np.array): Cropped image patch with int(patch_height) rows and int(patch_width) columns.
        trans (np.array): Transformation matrix.
        img_border_mask (np.array): Boolean (rows, cols) mask that is False where
            every color channel of the patch equals the border fill, True elsewhere.
    """
    # Unpacking also enforces a 3-dimensional input.
    _, img_width, _ = img.shape
    if do_flip:
        img = img[:, ::-1, :]
        c_x = img_width - c_x - 1

    trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
    out_size = (int(patch_width), int(patch_height))

    img_patch = cv2.warpAffine(img, trans, out_size,
                               flags=cv2.INTER_LINEAR,
                               borderMode=border_mode,
                               borderValue=border_value,
                               )
    # Force borderMode=cv2.BORDER_CONSTANT for the alpha channel
    if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
        img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, out_size,
                                            flags=cv2.INTER_LINEAR,
                                            borderMode=cv2.BORDER_CONSTANT,
                                            )

    # NOTE(review): the 4-channel branch compares against `border_value`, the
    # other branch against a literal 0. They agree for the default
    # border_value=0; confirm the intended behavior for non-zero fills.
    if img_patch.shape[2] == 4:
        is_border = np.all(img_patch[:, :, :-1] == border_value, axis=2)
    else:
        is_border = np.all(img_patch == 0, axis=2)
    img_border_mask = ~is_border
    return img_patch, trans, img_border_mask
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def convert_cvimg_to_tensor(cvimg: np.array):
    """
    Convert an image from HWC to CHW layout and cast it to float32.

    Despite the name, the result is a NumPy array; the input is not modified.
    Args:
        cvimg (np.array): Image of shape (H, W, C) as loaded by OpenCV.
    Returns:
        np.array: float32 image of shape (C, H, W).
    """
    # h,w,c (OpenCV) -> c,h,w on a private copy, then int -> float.
    channels_first = np.transpose(cvimg.copy(), (2, 0, 1))
    return channels_first.astype(np.float32)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def fliplr_params(smal_params: Dict, has_smal_params: Dict) -> Tuple[Dict, Dict]:
    """
    Flip SMAL parameters when flipping the image.

    For 'global_orient', 'pose' and 'transl', every element at index 1 and 2
    of each consecutive triple is negated; 'betas' is passed through. All
    values are copied, so the input dictionaries are left untouched.
    Args:
        smal_params (Dict): SMAL parameter annotations with keys
            'global_orient', 'pose', 'betas' and 'transl'.
        has_smal_params (Dict): Whether SMAL annotations are valid (same keys).
    Returns:
        Dict, Dict: Flipped SMAL parameters (float32) and copied valid flags.
    """
    all_keys = ('global_orient', 'pose', 'betas', 'transl')
    negated_keys = ('global_orient', 'pose', 'transl')

    flipped = {key: smal_params[key].copy() for key in all_keys}
    valid_flags = {key: has_smal_params[key].copy() for key in all_keys}

    for key in negated_keys:
        values = flipped[key]
        values[1::3] *= -1
        values[2::3] *= -1

    flipped = {key: values.astype(np.float32) for key, values in flipped.items()}
    return flipped, valid_flags
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def fliplr_keypoints(joints: np.array, width: float, flip_permutation: List[int]) -> np.array:
    """
    Flip 2D or 3D keypoints horizontally.

    The x column is mirrored as `width - x - 1`, then the rows are reordered
    with `flip_permutation`. The input array is not modified.
    Args:
        joints (np.array): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence.
        width (float): Width used to mirror the x coordinate.
        flip_permutation (List): Permutation to apply after flipping.
    Returns:
        np.array: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively.
    """
    mirrored = joints.copy()
    mirrored_x = width - mirrored[:, 0] - 1
    mirrored[:, 0] = mirrored_x
    # Row i of the result is row flip_permutation[i] of the mirrored array.
    return mirrored[flip_permutation, :]
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def keypoint_3d_processing(keypoints_3d: np.array, rot: float, flip: bool) -> np.array:
    """
    Process 3D keypoints (rotation/flipping).

    The xyz columns are rotated in the x-y plane by `-rot` degrees and, if
    requested, the x coordinate is mirrored. The last column (confidence) is
    passed through. The input array is not modified.
    Args:
        keypoints_3d (np.array): Input array of shape (N, 4) containing the 3D keypoints and confidence.
        rot (float): Random rotation applied to the keypoints, in degrees.
        flip (bool): Whether to mirror the x coordinates.
    Returns:
        np.array: Transformed 3D keypoints with shape (N, 4), dtype float32.
    """
    # Work on a copy so the caller's array is never overwritten in place.
    keypoints_3d = keypoints_3d.copy()

    # in-plane rotation
    rot_mat = np.eye(3, dtype=np.float32)
    if not rot == 0:
        rot_rad = -rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
    keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])

    # flip the x coordinates
    if flip:
        # The previous call to fliplr_keypoints() omitted its required `width`
        # argument and raised TypeError. Mirroring about x = 0 matches
        # fliplr_keypoints(..., width=1, identity permutation).
        # NOTE(review): assumes the 3D keypoints are expressed relative to the
        # mirror plane x = 0 and that no left/right joint swap is wanted here.
        keypoints_3d[:, 0] = -keypoints_3d[:, 0]

    return keypoints_3d.astype('float32')
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def rot_aa(aa: np.array, rot: float) -> np.array:
    """
    Rotate axis angle parameters about the z axis.
    Args:
        aa (np.array): Axis-angle vector of shape (3,).
        rot (float): Rotation angle in degrees; the rotation applied is by `-rot`.
    Returns:
        np.array: Rotated axis-angle vector (float32).
    """
    theta = np.deg2rad(-rot)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    # In-plane rotation about the z axis.
    R = np.array([[cos_t, -sin_t, 0],
                  [sin_t, cos_t, 0],
                  [0, 0, 1]])

    # Axis-angle -> rotation matrix, pre-multiply by the in-plane rotation,
    # then convert the product back to axis-angle.
    orient_mat, _ = cv2.Rodrigues(aa)
    rotated_aa, _ = cv2.Rodrigues(np.dot(R, orient_mat))
    return (rotated_aa.T)[0].astype(np.float32)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def smal_param_processing(smal_params: Dict, has_smal_params: Dict, rot: float, do_flip: bool) -> Tuple[Dict, Dict]:
    """
    Apply the sampled flip/rotation augmentation to the SMAL parameters.

    When `do_flip` is False the 'global_orient' entry of the passed-in
    dictionary is replaced in place; when it is True the update is made on the
    new dictionary returned by fliplr_params().
    Args:
        smal_params (Dict): SMAL parameter annotations.
        has_smal_params (Dict): Whether SMAL annotations are valid.
        rot (float): Random rotation applied to the keypoints.
        do_flip (bool): Whether to flip keypoints or not.
    Returns:
        Dict, Dict: Transformed SMAL parameters and valid flags.
    """
    params, valid_flags = smal_params, has_smal_params
    if do_flip:
        params, valid_flags = fliplr_params(params, valid_flags)

    # Only the global orientation is rotated. The camera location does not
    # change, so the translation is left as it is.
    params['global_orient'] = rot_aa(params['global_orient'], rot)
    return params, valid_flags
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def get_example(img_path: Union[str,np.ndarray], center_x: float, center_y: float,
                width: float, height: float,
                keypoints_2d: np.array, keypoints_3d: np.array,
                smal_params: Dict, has_smal_params: Dict,
                patch_width: int, patch_height: int,
                mean: np.array, std: np.array,
                do_augment: bool, augm_config: CfgNode,
                is_bgr: bool = True,
                use_skimage_antialias: bool = False,
                border_mode: int = cv2.BORDER_CONSTANT,
                return_trans: bool = False,) -> Tuple:
    """
    Get an example from the dataset and (possibly) apply random augmentations.

    Args:
        img_path (str | np.ndarray): Image filename, or an already loaded image array
            with three dimensions (height, width, channels).
        center_x (float): Bounding box center x coordinate in the original image.
        center_y (float): Bounding box center y coordinate in the original image.
        width (float): Bounding box width. Values below 1 are clamped to 1.
        height (float): Bounding box height. Values below 1 are clamped to 1.
        keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
        keypoints_3d (np.array): Array with shape (N,4) containing the 3D keypoints.
        smal_params (Dict): SMAL parameter annotations.
        has_smal_params (Dict): Whether SMAL annotations are valid.
        patch_width (int): Output patch width.
        patch_height (int): Output patch height.
        mean (np.array): Array of shape (3,) containing the mean for normalizing the input image (or None).
        std (np.array): Array of shape (3,) containing the std for normalizing the input image (or None).
        do_augment (bool): Whether to apply data augmentation or not.
        augm_config (CfgNode): Config containing augmentation parameters.
        is_bgr (bool): If True, the channel order of the patch is reversed (BGR -> RGB).
        use_skimage_antialias (bool): If True, blur the image before cropping whenever the
            crop is shrunk by more than 10%, to reduce aliasing.
        border_mode (int): OpenCV border mode forwarded to generate_image_patch_cv2.
        return_trans (bool): If True, also return the crop transform.
    Returns:
        img_patch (np.array): Cropped, color-scaled and normalized image patch (channel-first).
        keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints,
            with coordinates divided by patch_width and shifted by -0.5.
        keypoints_3d (np.array): Array with shape (N,4) containing the transformed 3D keypoints.
        smal_params (Dict): Transformed SMAL parameters.
        has_smal_params (Dict): Valid flag for transformed SMAL parameters.
        img_size (np.array): (height, width) of the original image.
        trans: Transform returned by generate_image_patch_cv2 (only if return_trans is True).
        img_border_mask: Border mask returned by generate_image_patch_cv2.
    Raises:
        IOError: If the image file cannot be read.
        TypeError: If img_path is neither a string nor a numpy array.
    Note:
        The transformed 2D keypoints are written into the `keypoints_2d` array itself;
        pass a copy if the original coordinates are still needed.
    """
    # 1. load image
    if isinstance(img_path, str):
        cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        if not isinstance(cvimg, np.ndarray):
            raise IOError("Fail to read %s" % img_path)
    elif isinstance(img_path, np.ndarray):
        cvimg = img_path
    else:
        raise TypeError('img_path must be either a string or a numpy array')
    img_height, img_width, img_channels = cvimg.shape
    img_size = np.array([img_height, img_width], dtype=np.int32)

    # 2. get augmentation params:
    # box rescale factor, rotation angle, flip, extreme crop (+ level), color scale, translation x/y
    if do_augment:
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
    else:
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl = 1.0, 0, False, False, 0
        color_scale, tx, ty = [1.0, 1.0, 1.0], 0., 0.

    if width < 1 or height < 1:
        # Degenerate boxes are clamped to one pixel instead of being dropped.
        print(f"Warning: Invalid bbox size - width: {width}, height: {height}. Using default size.")
        width = max(width, 1.0)
        height = max(height, 1.0)

    if do_extreme_crop:
        if extreme_crop_lvl == 0:
            cropped_box = extreme_cropping(center_x, center_y, width, height, keypoints_2d)
        elif extreme_crop_lvl == 1:
            cropped_box = extreme_cropping_aggressive(center_x, center_y, width, height, keypoints_2d)
        else:
            # Unknown level: keep the original box rather than failing on an unbound name.
            cropped_box = None

        if cropped_box is not None:
            THRESH = 4
            _, _, width1, height1 = cropped_box
            # Reject crops that collapse to (almost) nothing and keep the original box.
            if not (width1 < THRESH or height1 < THRESH):
                center_x, center_y, width, height = cropped_box

    center_x += width * tx
    center_y += height * ty

    # Process 3D keypoints
    keypoints_3d = keypoint_3d_processing(keypoints_3d, rot, do_flip)

    # 3. generate image patch
    if use_skimage_antialias:
        # Blur image to avoid aliasing artifacts. The factor is source size over
        # target size, so it exceeds 1 exactly when the crop is being shrunk.
        downsampling_factor = (width * scale) / patch_width
        if downsampling_factor > 1.1:
            cvimg = gaussian(cvimg, sigma=(downsampling_factor - 1) / 2, channel_axis=2, preserve_range=True,
                             truncate=3.0)

    # augmented image patch, crop transform, border mask
    img_patch_cv, trans, img_border_mask = generate_image_patch_cv2(cvimg,
                                                                    center_x, center_y,
                                                                    width, height,
                                                                    patch_width, patch_height,
                                                                    do_flip, scale, rot,
                                                                    border_mode=border_mode)

    image = img_patch_cv.copy()
    if is_bgr:
        image = image[:, :, ::-1]
    img_patch = convert_cvimg_to_tensor(image)  # presumably HWC -> CHW; see convert_cvimg_to_tensor

    smal_params, has_smal_params = smal_param_processing(smal_params, has_smal_params, rot, do_flip)

    # apply color jitter and normalization (at most the first three channels)
    for n_c in range(min(img_channels, 3)):
        img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
        if mean is not None and std is not None:
            img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]

    if do_flip:
        keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, list(range(len(keypoints_2d))))

    for n_jt in range(len(keypoints_2d)):
        keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
    # NOTE: both x and y are divided by patch_width (kept for backward compatibility).
    keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5

    if not return_trans:
        return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, img_border_mask
    else:
        return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, trans, img_border_mask
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def get_cub17_example(cvimg: np.array,
                      keypoints_2d: np.array,
                      center_x: float, center_y: float,
                      width: float, height: float,
                      patch_width: int, patch_height: int,
                      mean: np.array, std: np.array,
                      do_augment: bool, augm_config: CfgNode,
                      return_trans=True) -> Tuple:
    """
    Crop an (optionally augmented) patch from an already loaded image and
    transform its 2D keypoints accordingly.

    Args:
        cvimg (np.ndarray): Image with three dimensions (height, width, channels).
        keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
        center_x (float): Bounding box center x coordinate in the original image.
        center_y (float): Bounding box center y coordinate in the original image.
        width (float): Bounding box width.
        height (float): Bounding box height.
        patch_width (int): Output box width.
        patch_height (int): Output box height.
        mean (np.array): Array of shape (3,) containing the mean for normalizing the input image (or None).
        std (np.array): Array of shape (3,) containing the std for normalizing the input image (or None).
        do_augment (bool): Whether to apply data augmentation or not.
        augm_config (CfgNode): Config containing augmentation parameters.
        return_trans (bool): Whether to include the crop transform in the result.
    Returns:
        img_patch (np.array): Cropped, color-scaled and normalized image patch (channel-first).
        keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
        img_size (np.array): (height, width) of the input image.
        trans: Transform returned by generate_image_patch_cv2 (only if return_trans is True).
        img_border_mask: Border mask returned by generate_image_patch_cv2.
    """
    src_height, src_width, n_channels = cvimg.shape
    img_size = np.array([src_height, src_width], dtype=np.int32)

    # Augmentation parameters: box rescale factor, rotation angle, flip,
    # extreme crop (+ level, both unused here), color scale, translation x/y.
    if not do_augment:
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = (
            1.0, 0, False, False, 0, [1.0, 1.0, 1.0], 0., 0.)
    else:
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)

    # Shift the box center by a fraction of the box size.
    center_x += width * tx
    center_y += height * ty

    # Augmented image patch, crop transform and border mask.
    patch_cv, trans, img_border_mask = generate_image_patch_cv2(cvimg,
                                                                center_x, center_y,
                                                                width, height,
                                                                patch_width, patch_height,
                                                                do_flip, scale, rot,
                                                                border_mode=cv2.BORDER_CONSTANT)
    img_patch = convert_cvimg_to_tensor(patch_cv.copy())

    # Color jitter followed by optional mean/std normalization, on at most three channels.
    normalize = mean is not None and std is not None
    for ch in range(min(n_channels, 3)):
        img_patch[ch, :, :] = np.clip(img_patch[ch, :, :] * color_scale[ch], 0, 255)
        if normalize:
            img_patch[ch, :, :] = (img_patch[ch, :, :] - mean[ch]) / std[ch]

    if do_flip:
        n_joints = len(keypoints_2d)
        keypoints_2d = fliplr_keypoints(keypoints_2d, src_width, list(range(n_joints)))

    # Map every keypoint into patch coordinates, then scale by patch_width and shift by -0.5.
    for joint_idx in range(len(keypoints_2d)):
        keypoints_2d[joint_idx, 0:2] = trans_point2d(keypoints_2d[joint_idx, 0:2], trans)
    keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5

    if not return_trans:
        return img_patch, keypoints_2d, img_size, img_border_mask
    return img_patch, keypoints_2d, img_size, trans, img_border_mask
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
    """
    Extreme cropping: Crop the box up to the hip locations.

    The masked keypoints are zeroed on a copy; the box is recomputed from the
    remaining ones only if their confidences sum to more than 1.

    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    remaining = keypoints_2d.copy()
    # Indices below 25, plus extra keypoints stored at an offset of 25.
    masked_ids = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in (0, 1, 4, 5)]
    remaining[masked_ids, :] = 0

    if not remaining[:, -1].sum() > 1:
        # Not enough evidence left: keep the incoming box.
        return center_x, center_y, width, height

    box_center, box_size = get_bbox(remaining)
    return box_center[0], box_center[1], 1.1 * box_size[0], 1.1 * box_size[1]
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box up to the shoulder locations.

    The masked keypoints are zeroed on a copy; the box is recomputed from the
    remaining ones only if their confidences sum to more than 1. Otherwise the
    incoming box is returned unchanged (get_bbox is never called on a keypoint
    set without sufficient confidence).

    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    keypoints_2d = keypoints_2d.copy()
    lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [
        25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]]
    keypoints_2d[lower_body_keypoints, :] = 0
    # Only compute the box inside the guard: get_bbox fails when no keypoint
    # with positive confidence is left after masking.
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x = center[0]
        center_y = center[1]
        width = 1.2 * scale[0]
        height = 1.2 * scale[1]
    return center_x, center_y, width, height
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the head.

    The masked keypoints are zeroed on a copy; the box is recomputed from the
    remaining ones only if their confidences sum to more than 1.

    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    remaining = keypoints_2d.copy()
    # Indices below 25, plus extra keypoints stored at an offset of 25.
    masked_ids = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24]
    masked_ids += [25 + i for i in (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16)]
    remaining[masked_ids, :] = 0

    if not remaining[:, -1].sum() > 1:
        # Not enough evidence left: keep the incoming box.
        return center_x, center_y, width, height

    box_center, box_size = get_bbox(remaining)
    return box_center[0], box_center[1], 1.3 * box_size[0], 1.3 * box_size[1]
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the torso.

    The masked keypoints are zeroed on a copy; the box is recomputed from the
    remaining ones only if their confidences sum to more than 1.

    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    remaining = keypoints_2d.copy()
    # Indices below 25, plus extra keypoints stored at an offset of 25.
    masked_ids = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
    masked_ids += [25 + i for i in (0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18)]
    remaining[masked_ids, :] = 0

    if not remaining[:, -1].sum() > 1:
        # Not enough evidence left: keep the incoming box.
        return center_x, center_y, width, height

    box_center, box_size = get_bbox(remaining)
    return box_center[0], box_center[1], 1.1 * box_size[0], 1.1 * box_size[1]
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the right arm.

    The masked keypoints are zeroed on a copy; the box is recomputed from the
    remaining ones only if their confidences sum to more than 1.

    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    remaining = keypoints_2d.copy()
    # Everything except indices 2-4 below 25 and extra (offset 25) indices 6-8.
    masked_ids = [0, 1, *range(5, 25)]
    masked_ids += [25 + i for i in (*range(0, 6), *range(9, 19))]
    remaining[masked_ids, :] = 0

    if not remaining[:, -1].sum() > 1:
        # Not enough evidence left: keep the incoming box.
        return center_x, center_y, width, height

    box_center, box_size = get_bbox(remaining)
    return box_center[0], box_center[1], 1.1 * box_size[0], 1.1 * box_size[1]
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the left arm.

    The masked keypoints are zeroed on a copy; the box is recomputed from the
    remaining ones only if their confidences sum to more than 1.

    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    remaining = keypoints_2d.copy()
    # Everything except indices 5-7 below 25 and extra (offset 25) indices 9-11.
    masked_ids = [*range(0, 5), *range(8, 25)]
    masked_ids += [25 + i for i in (*range(0, 9), *range(12, 19))]
    remaining[masked_ids, :] = 0

    if not remaining[:, -1].sum() > 1:
        # Not enough evidence left: keep the incoming box.
        return center_x, center_y, width, height

    box_center, box_size = get_bbox(remaining)
    return box_center[0], box_center[1], 1.1 * box_size[0], 1.1 * box_size[1]
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the legs.

    The masked keypoints are zeroed on a copy; the box is recomputed from the
    remaining ones only if their confidences sum to more than 1.

    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    remaining = keypoints_2d.copy()
    # Indices below 25, plus extra keypoints stored at an offset of 25.
    masked_ids = [*range(0, 8), 15, 16, 17, 18]
    masked_ids += [25 + i for i in (6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18)]
    remaining[masked_ids, :] = 0

    if not remaining[:, -1].sum() > 1:
        # Not enough evidence left: keep the incoming box.
        return center_x, center_y, width, height

    box_center, box_size = get_bbox(remaining)
    return box_center[0], box_center[1], 1.1 * box_size[0], 1.1 * box_size[1]
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the right leg.

    The masked keypoints are zeroed on a copy; the box is recomputed from the
    remaining ones only if their confidences sum to more than 1.

    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    remaining = keypoints_2d.copy()
    # Everything except indices 9-11 and 22-24 below 25 and extra (offset 25) indices 0-2.
    masked_ids = [*range(0, 9), *range(12, 22)]
    masked_ids += [25 + i for i in range(3, 19)]
    remaining[masked_ids, :] = 0

    if not remaining[:, -1].sum() > 1:
        # Not enough evidence left: keep the incoming box.
        return center_x, center_y, width, height

    box_center, box_size = get_bbox(remaining)
    return box_center[0], box_center[1], 1.1 * box_size[0], 1.1 * box_size[1]
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box and keep on only the left leg.

    The masked keypoints are zeroed on a copy; the box is recomputed from the
    remaining ones only if their confidences sum to more than 1.

    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    remaining = keypoints_2d.copy()
    # Everything except indices 12-14 and 19-21 below 25 and extra (offset 25) indices 3-5.
    masked_ids = [*range(0, 12), 15, 16, 17, 18, 22, 23, 24]
    masked_ids += [25 + i for i in (0, 1, 2, *range(6, 19))]
    remaining[masked_ids, :] = 0

    if not remaining[:, -1].sum() > 1:
        # Not enough evidence left: keep the incoming box.
        return center_x, center_y, width, height

    box_center, box_size = get_bbox(remaining)
    return box_center[0], box_center[1], 1.1 * box_size[0], 1.1 * box_size[1]
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
def full_body(keypoints_2d: np.array) -> bool:
    """
    Check if all main body joints are visible.

    Each main joint has two candidate rows: one among the first 25 indices and
    one among the extra keypoints stored at an offset of 25. A joint counts as
    visible when at least one of the two has positive confidence.

    Args:
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        bool: True if all main body joints are visible.
    """
    primary_ids = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
    extra_ids = [25 + offset for offset in (8, 7, 6, 9, 10, 11, 1, 0, 4, 5)]

    extra_conf = keypoints_2d[extra_ids, -1]
    primary_conf = keypoints_2d[primary_ids, -1]
    # Best confidence per joint across the two annotation sources.
    best_conf = np.maximum(extra_conf, primary_conf)
    n_visible = (best_conf > 0).sum()
    return n_visible == len(extra_ids)
|
| 987 |
+
|
| 988 |
+
|
| 989 |
+
def upper_body(keypoints_2d: np.array):
    """
    Check whether only the upper body is visible.

    Args:
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        bool: True if none of the lower-body keypoints has positive confidence
        and at least two of the upper-body keypoints do.
    """
    # Extra keypoints (offset 25) first, then the indices below 25, for each group.
    lower_ids = [25 + offset for offset in (1, 0, 4, 5)] + [10, 11, 13, 14]
    upper_ids = [25 + offset for offset in (8, 9, 12, 13, 17, 18)] + [0, 1, 15, 16, 17, 18]

    n_lower_visible = (keypoints_2d[lower_ids, -1] > 0).sum()
    n_upper_visible = (keypoints_2d[upper_ids, -1] > 0).sum()

    no_lower_body = n_lower_visible == 0
    enough_upper_body = n_upper_visible >= 2
    return no_lower_body and enough_upper_body
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple:
    """
    Get center and scale for bounding box from openpose detections.

    Only keypoints whose last column (confidence) is > 0 are used.

    Args:
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
        rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
    Returns:
        center (np.array): Array of shape (2,) containing the new bounding box center.
        scale (np.array): Array of shape (2,) containing the rescaled bounding box
            extent (width, height) of the valid keypoints.
    Raises:
        ValueError: If no keypoint has positive confidence.
    """
    valid = keypoints_2d[:, -1] > 0
    valid_keypoints = keypoints_2d[valid][:, :-1]
    if len(valid_keypoints) == 0:
        # max()/min() over an empty array would fail with an opaque reduction error.
        raise ValueError("get_bbox requires at least one keypoint with confidence > 0")
    upper = valid_keypoints.max(axis=0)
    lower = valid_keypoints.min(axis=0)
    center = 0.5 * (upper + lower)
    # adjust bounding box tightness; out-of-place multiply so that integer
    # keypoint arrays are promoted to float instead of failing the in-place cast
    scale = (upper - lower) * rescale
    return center, scale
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
    """
    Perform extreme cropping.

    One random number p is drawn with torch.rand. For a fully visible body the
    crop is hips (p < 0.7), shoulders (p < 0.9) or head; for an upper-body-only
    sample it is shoulders (p < 0.9) or head; otherwise the box is kept. The
    returned box is always squared to the larger side.

    Args:
        center_x (float): x coordinate of bounding box center.
        center_y (float): y coordinate of bounding box center.
        width (float): bounding box width.
        height (float): bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of bounding box center.
        center_y (float): y coordinate of bounding box center.
        width (float): bounding box width.
        height (float): bounding box height.
    """
    # Always draw exactly one sample, even if no crop ends up being applied.
    p = torch.rand(1).item()

    if full_body(keypoints_2d):
        if p < 0.7:
            crop_fn = crop_to_hips
        else:
            crop_fn = crop_to_shoulders if p < 0.9 else crop_to_head
    elif upper_body(keypoints_2d):
        crop_fn = crop_to_shoulders if p < 0.9 else crop_to_head
    else:
        crop_fn = None

    if crop_fn is not None:
        center_x, center_y, width, height = crop_fn(center_x, center_y, width, height, keypoints_2d)

    return center_x, center_y, max(width, height), max(width, height)
|
| 1056 |
+
|
| 1057 |
+
|
| 1058 |
+
def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float,
                                keypoints_2d: np.array) -> Tuple:
    """
    Perform aggressive extreme cropping.

    One random number p is drawn with torch.rand and compared against a list of
    increasing upper bounds; the first crop whose bound exceeds p is applied
    (the last crop of the list is the fallback). A fully visible body chooses
    among nine crops, an upper-body-only sample among five; otherwise the box
    is kept. The returned box is always squared to the larger side.

    Args:
        center_x (float): x coordinate of bounding box center.
        center_y (float): y coordinate of bounding box center.
        width (float): bounding box width.
        height (float): bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of bounding box center.
        center_y (float): y coordinate of bounding box center.
        width (float): bounding box width.
        height (float): bounding box height.
    """
    # Always draw exactly one sample, even if no crop ends up being applied.
    p = torch.rand(1).item()

    if full_body(keypoints_2d):
        schedule = (
            (0.2, crop_to_hips),
            (0.3, crop_to_shoulders),
            (0.4, crop_to_head),
            (0.5, crop_torso_only),
            (0.6, crop_rightarm_only),
            (0.7, crop_leftarm_only),
            (0.8, crop_legs_only),
            (0.9, crop_rightleg_only),
        )
        fallback = crop_leftleg_only
    elif upper_body(keypoints_2d):
        schedule = (
            (0.2, crop_to_shoulders),
            (0.4, crop_to_head),
            (0.6, crop_torso_only),
            (0.8, crop_rightarm_only),
        )
        fallback = crop_leftarm_only
    else:
        schedule = None
        fallback = None

    if schedule is not None:
        # First crop whose upper bound exceeds p, otherwise the fallback.
        crop_fn = next((fn for bound, fn in schedule if p < bound), fallback)
        center_x, center_y, width, height = crop_fn(center_x, center_y, width, height, keypoints_2d)

    return center_x, center_y, max(width, height), max(width, height)
|
prima/datasets/vitdet_dataset.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import Dict
|
| 11 |
+
|
| 12 |
+
import cv2
|
| 13 |
+
import numpy as np
|
| 14 |
+
from skimage.filters import gaussian
|
| 15 |
+
from yacs.config import CfgNode
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from .utils import (convert_cvimg_to_tensor,
|
| 19 |
+
expand_to_aspect_ratio,
|
| 20 |
+
generate_image_patch_cv2)
|
| 21 |
+
|
| 22 |
+
# Per-channel normalization statistics expressed on the 0-255 pixel scale.
# NOTE(review): these look like the usual ImageNet RGB mean/std - confirm.
# ViTDetDataset reads its own statistics from cfg.MODEL rather than from here.
DEFAULT_MEAN = np.array([0.485, 0.456, 0.406]) * 255.
DEFAULT_STD = np.array([0.229, 0.224, 0.225]) * 255.
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ViTDetDataset(torch.utils.data.Dataset):
    """Inference-only dataset yielding one normalized square crop per detection box.

    All items come from a single image; item ``i`` is the patch cut around
    ``boxes[i]`` together with the box/image bookkeeping stored alongside it.
    """

    def __init__(self,
                 cfg: CfgNode,
                 img_cv2: np.array,
                 boxes: np.array,
                 rescale_factor=1,
                 train: bool = False,
                 **kwargs):
        """
        Args:
            cfg: Config node. ``MODEL.IMAGE_SIZE``, ``MODEL.IMAGE_MEAN`` and
                ``MODEL.IMAGE_STD`` are read here; ``MODEL.BBOX_SHAPE`` and
                ``EXTRA.FOCAL_LENGTH`` are read when an item is built.
            img_cv2: Image array (presumably BGR as loaded by OpenCV, since the
                channel axis is reversed before tensor conversion - confirm).
            boxes: Array whose columns 0:2 and 2:4 hold the two box corners
                (presumably x1, y1, x2, y2 - confirm against the detector).
            rescale_factor: Multiplier applied to the box extent.
            train: Must be ``False``; the dataset only supports inference.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.cfg = cfg
        self.img_cv2 = img_cv2
        self.boxes = boxes

        assert train is False, "ViTDetDataset is only for inference"
        self.train = train
        self.img_size = cfg.MODEL.IMAGE_SIZE
        # Mean/std are stored on the 0-255 scale to match the un-normalized patch.
        self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
        self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)

        # Derive per-box center and scale (box extent divided by 200).
        boxes = boxes.astype(np.float32)
        corner_min = boxes[:, 0:2]
        corner_max = boxes[:, 2:4]
        self.center = (corner_max + corner_min) / 2.0
        self.scale = rescale_factor * (corner_max - corner_min) / 200.0
        self.animalid = np.arange(len(boxes), dtype=np.int32)

    def __len__(self) -> int:
        """Number of boxes, i.e. number of crops this dataset yields."""
        return len(self.animalid)

    def __getitem__(self, idx: int) -> Dict[str, np.array]:
        """Build the normalized crop and metadata for box ``idx``.

        Returns:
            Dict with keys ``img``, ``animalid``, ``box_center``, ``box_size``,
            ``img_size`` (width, height) and ``focal_length``.
        """
        box_center = self.center[idx].copy()
        center_x, center_y = box_center[0], box_center[1]

        # Undo the /200 scaling, expand to the configured aspect ratio and
        # use the longer side as the (square) crop size.
        BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
        box_extent = self.scale[idx] * 200
        bbox_size = expand_to_aspect_ratio(box_extent, target_aspect_ratio=BBOX_SHAPE).max()

        patch_width = patch_height = self.img_size
        flip = False

        cvimg = self.img_cv2.copy()
        # Blur image to avoid aliasing artifacts when the crop is shrunk a lot.
        downsampling_factor = ((bbox_size * 1.0) / patch_width)
        downsampling_factor = downsampling_factor / 2.0
        if downsampling_factor > 1.1:
            blur_sigma = (downsampling_factor - 1) / 2
            cvimg = gaussian(cvimg, sigma=blur_sigma, channel_axis=2, preserve_range=True)

        # No flip, unit scale and zero rotation: a plain crop-and-resize.
        img_patch_cv, trans, _ = generate_image_patch_cv2(cvimg,
                                                          center_x, center_y,
                                                          bbox_size, bbox_size,
                                                          patch_width, patch_height,
                                                          flip, 1.0, 0.0,
                                                          border_mode=cv2.BORDER_CONSTANT)
        # Reverse the channel axis before converting to a tensor layout.
        img_patch_cv = img_patch_cv[:, :, ::-1]
        img_patch = convert_cvimg_to_tensor(img_patch_cv)

        # Per-channel normalization of (at most) the first three channels.
        num_channels = min(self.img_cv2.shape[2], 3)
        for ch in range(num_channels):
            img_patch[ch, :, :] = (img_patch[ch, :, :] - self.mean[ch]) / self.std[ch]

        focal = self.cfg.EXTRA.FOCAL_LENGTH
        item = {
            'img': img_patch,
            'animalid': int(self.animalid[idx]),
            'box_center': self.center[idx].copy(),
            'box_size': bbox_size,
            'img_size': 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]]),
            'focal_length': np.array([focal, focal]),
        }
        return item
|
prima/models/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from .prima import PRIMA
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load_prima(checkpoint_path):
    """Load a PRIMA checkpoint together with the config stored next to it.

    The config is read from ``.hydra/config.yaml`` located two directory
    levels above the checkpoint file. Before the model is built the config is
    patched so that it works for demo/inference.

    Args:
        checkpoint_path: Path to the checkpoint file.

    Returns:
        Tuple ``(model, model_cfg)``; the model is mapped to CPU.
    """
    from pathlib import Path
    from ..configs import get_config

    config_path = Path(checkpoint_path).parent.parent / '.hydra/config.yaml'
    model_cfg = get_config(str(config_path))

    # Override some config values, to crop bbox correctly.
    # Each row: (backbone type, required MODEL.IMAGE_SIZE,
    #            backbone name used in the assertion text, default BBOX_SHAPE).
    # NOTE(review): create_backbone() accepts 'vith'/'concat'/'aa', none of
    # which appear below - confirm those configs always carry MODEL.BBOX_SHAPE.
    bbox_presets = (
        ('vit', 256, 'ViT', [192, 256]),
        ('dinov3', 256, 'dino', [256, 256]),
        ('dinov2', 252, 'dino', [252, 252]),
    )
    for backbone_type, image_size, label, bbox_shape in bbox_presets:
        # Only fill in a default when the config does not define one already.
        if model_cfg.MODEL.BACKBONE.TYPE != backbone_type or 'BBOX_SHAPE' in model_cfg.MODEL:
            continue
        model_cfg.defrost()
        assert model_cfg.MODEL.IMAGE_SIZE == image_size, \
            f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be {image_size} for {label} backbone"
        model_cfg.MODEL.BBOX_SHAPE = list(bbox_shape)
        model_cfg.freeze()

    # Update config to be compatible with demo: drop the pretrained-weights entry.
    if 'PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE:
        model_cfg.defrost()
        model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS')
        model_cfg.freeze()

    # Offscreen training renderer is not needed for demo/inference startup and
    # can fail on some local OpenGL backends, hence init_renderer=False.
    load_kwargs = dict(
        strict=False,
        cfg=model_cfg,
        map_location='cpu',
        init_renderer=False,
    )
    model = PRIMA.load_from_checkpoint(checkpoint_path, **load_kwargs)
    return model, model_cfg
|
prima/models/backbones/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from .vit import vith
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def create_backbone(cfg):
    """Build the backbone selected by ``cfg.MODEL.BACKBONE.TYPE``.

    The types 'vith', 'concat' and 'aa' all use the ViT built by ``vith``
    (the ViT backbone acts as the animal feature extractor in these cases).

    Raises:
        NotImplementedError: For any other backbone type.
    """
    backbone_type = cfg.MODEL.BACKBONE.TYPE
    if backbone_type not in ('vith', 'concat', 'aa'):
        raise NotImplementedError('Backbone type is not implemented')
    return vith(cfg)
|
prima/models/backbones/vit.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 11 |
+
import math
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from functools import partial
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
import torch.utils.checkpoint as checkpoint
|
| 18 |
+
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def vith(cfg):
    """Build the ViT backbone (1280-dim, 32 blocks, 16 heads) for 256x192 inputs.

    ``cfg`` is not read here; every hyper-parameter is hard-coded.
    """
    vit_kwargs = dict(
        img_size=(256, 192),
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        ratio=1,
        use_checkpoint=False,
        mlp_ratio=4,
        qkv_bias=True,
        drop_path_rate=0.55,
        use_cls=True,  # cls for animal family classification
    )
    return ViT(**vit_kwargs)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
    """Return absolute positional embeddings resized to an ``h x w`` token grid.

    If the requested grid differs from the original one, the patch embeddings
    are resampled with bicubic interpolation; otherwise they are returned
    unchanged. A leading class-token embedding, when present, is never
    interpolated and is re-attached in front of the result.

    Args:
        abs_pos (Tensor): Positional embeddings of shape (B, num_position, C),
            where num_position is ``ori_h * ori_w`` (+1 with a class token).
        h (int): Target grid height in tokens.
        w (int): Target grid width in tokens.
        ori_h (int): Grid height the embeddings were created for.
        ori_w (int): Grid width the embeddings were created for.
        has_cls_token (bool): If true, ``abs_pos[:, 0]`` is the class-token
            embedding.

    Returns:
        Tensor of shape (B, h * w, C), or (B, 1 + h * w, C) when
        ``has_cls_token`` is true.
    """
    B, _, C = abs_pos.shape

    cls_token = None
    if has_cls_token:
        cls_token = abs_pos[:, 0:1]
        abs_pos = abs_pos[:, 1:]

    if ori_h != h or ori_w != w:
        # Reshape with the real batch size and channel count: hard-coding a
        # batch of 1 and inferring the channels would silently fold extra
        # batch entries into the channel axis and scramble the embeddings.
        grid = abs_pos.reshape(B, ori_h, ori_w, C).permute(0, 3, 1, 2)
        resized = F.interpolate(
            grid,
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        )
        new_abs_pos = resized.permute(0, 2, 3, 1).reshape(B, -1, C)
    else:
        new_abs_pos = abs_pos

    if cls_token is not None:
        new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
    return new_abs_pos
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class DropPath(nn.Module):
    """Stochastic depth: drop whole residual paths per sample.

    Thin module wrapper around the functional ``drop_path``; dropping is only
    active while the module is in training mode.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegates to the functional form, passing the current train/eval flag.
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        # Shown inside the module's repr, e.g. "DropPath(p=0.1)".
        return f'p={self.drop_prob}'
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Linear -> Dropout.

    Args:
        in_features: Size of the input features.
        hidden_features: Hidden width; falls back to ``in_features`` when unset.
        out_features: Output width; falls back to ``in_features`` when unset.
        act_layer: Activation module class, instantiated without arguments.
        drop: Dropout probability applied after the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        # Dropout is applied once, on the output of the second projection.
        return self.drop(self.fc2(hidden))
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence of shape (B, N, C).

    Args:
        dim: Input/output feature size.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        qk_scale: Optional override for the query scaling factor; defaults to
            ``head_dim ** -0.5``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        attn_head_dim: Optional per-head width; defaults to ``dim // num_heads``.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim

        default_head_dim = dim // num_heads
        head_dim = default_head_dim if attn_head_dim is None else attn_head_dim
        all_head_dim = head_dim * self.num_heads

        self.scale = qk_scale or head_dim ** -0.5

        # One linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*H*D) -> (3, B, H, N, D)
        packed = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = packed[0], packed[1], packed[2]

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into one feature axis: (B, H, N, D) -> (B, N, H*D).
        merged = (weights @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj_drop(self.proj(merged))
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each inside a residual branch.

    Both residual branches pass through the same drop-path module, which is an
    identity when ``drop_path`` is not positive.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, attn_head_dim=None,
                 ):
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            attn_head_dim=attn_head_dim,
        )

        # Stochastic depth on the residual branches (as opposed to plain dropout).
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def forward(self, x):
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class PatchEmbed(nn.Module):
    """Image to patch embedding via a strided convolution.

    Args:
        img_size: Input size (int or (H, W)).
        patch_size: Patch size (int or (H, W)).
        in_chans: Number of input channels.
        embed_dim: Embedding size of every patch token.
        ratio: Factor that divides the convolution stride and scales
            ``patch_shape``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]

        self.patch_shape = (int(grid_h * ratio), int(grid_w * ratio))
        self.origin_patch_shape = (int(grid_h), int(grid_w))
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_w * grid_h * (ratio ** 2)

        # Stride comes from the patch height only; padding depends on ratio.
        conv_stride = patch_size[0] // ratio
        conv_padding = 4 + 2 * (ratio // 2 - 1)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=conv_stride, padding=conv_padding)

    def forward(self, x, **kwargs):
        """Return ``(tokens, (Hp, Wp))`` with tokens shaped (B, Hp * Wp, embed_dim)."""
        B, C, H, W = x.shape  # unpacking also requires a 4-D input
        feat = self.proj(x)
        grid = (feat.shape[2], feat.shape[3])

        # (B, D, Hp, Wp) -> (B, Hp * Wp, D)
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, grid
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class HybridEmbed(nn.Module):
    """CNN feature-map embedding.

    Runs a CNN backbone, takes its last feature map, flattens the spatial
    axes and projects every location to the embedding dimension.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone

        if feature_size is not None:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        else:
            # Feature size unknown: probe the backbone with a zero image,
            # temporarily in eval mode, then restore its previous mode.
            with torch.no_grad():
                was_training = backbone.training
                if was_training:
                    backbone.eval()
                probe = torch.zeros(1, in_chans, img_size[0], img_size[1])
                last_map = self.backbone(probe)[-1]
                feature_size = last_map.shape[-2:]
                feature_dim = last_map.shape[1]
                backbone.train(was_training)

        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        last_map = self.backbone(x)[-1]
        # (B, C, H, W) -> (B, H * W, C) before the linear projection.
        tokens = last_map.flatten(2).transpose(1, 2)
        return self.proj(tokens)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class ViT(nn.Module):
    """Vision Transformer backbone returning a patch feature map and an optional class token.

    Args:
        img_size: Input size the positional embedding is built for.
        patch_size: Patch size of the patch embedding.
        in_chans: Number of input channels.
        num_classes: Stored on the module; not used by the visible code.
        embed_dim: Token embedding size.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Hidden/width ratio of each block's MLP.
        qkv_bias: Whether q/k/v projections have a bias.
        qk_scale: Optional override of the attention scaling factor.
        drop_rate: Dropout probability inside the blocks.
        attn_drop_rate: Dropout probability on attention weights.
        drop_path_rate: Maximum stochastic-depth rate; grows linearly from 0
            at the first block to this value at the last one.
        hybrid_backbone: Optional CNN used via ``HybridEmbed`` instead of
            ``PatchEmbed``.
        norm_layer: Normalization layer factory; defaults to
            ``LayerNorm(eps=1e-6)``.
        use_checkpoint: Run every block through ``torch.utils.checkpoint``.
        frozen_stages: ``-1`` freezes nothing, ``0`` freezes the patch
            embedding, ``n > 0`` additionally freezes the first ``n`` blocks.
        ratio: Forwarded to ``PatchEmbed``.
        last_norm: Apply a final normalization layer after the blocks.
        use_cls: Prepend a learnable class token and return its output.
        patch_padding: Stored on the module; not used by the visible code.
        freeze_attn: Freeze attention (and its norm) in every block.
        freeze_ffn: Freeze the MLP (and its norm) in every block, plus the
            positional and patch embeddings.
    """

    def __init__(self,
                 img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
                 frozen_stages=-1, ratio=1, last_norm=True, use_cls=False,
                 patch_padding='pad', freeze_attn=False, freeze_ffn=False,
                 ):
        super(ViT, self).__init__()
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.frozen_stages = frozen_stages
        self.use_checkpoint = use_checkpoint
        self.patch_padding = patch_padding
        self.freeze_attn = freeze_attn
        self.freeze_ffn = freeze_ffn
        self.depth = depth

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
        num_patches = self.patch_embed.num_patches

        # since the pretraining model has class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
            )
            for i in range(depth)])

        self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)

        self.use_cls = use_cls
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.cls_token, std=1e-6)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put the configured sub-modules in eval mode and stop their gradients."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        # Freeze the first `frozen_stages` blocks, i.e. blocks[0] up to
        # blocks[frozen_stages - 1]. Clamped so that frozen_stages >= depth
        # freezes every block instead of indexing past the end.
        for i in range(min(self.frozen_stages, len(self.blocks))):
            m = self.blocks[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

        if self.freeze_attn:
            for i in range(0, self.depth):
                m = self.blocks[i]
                m.attn.eval()
                m.norm1.eval()
                for param in m.attn.parameters():
                    param.requires_grad = False
                for param in m.norm1.parameters():
                    param.requires_grad = False

        if self.freeze_ffn:
            self.pos_embed.requires_grad = False
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            for i in range(0, self.depth):
                m = self.blocks[i]
                m.mlp.eval()
                m.norm2.eval()
                for param in m.mlp.parameters():
                    param.requires_grad = False
                for param in m.norm2.parameters():
                    param.requires_grad = False

    def init_weights(self):
        """Re-initialize all Linear and LayerNorm layers of the backbone.

        Linear weights get a truncated normal (std 0.02) and zero bias;
        LayerNorm gets unit weight and zero bias. No pretrained weights are
        loaded here.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        """Encode a batch of images.

        The positional embedding is added without resizing, so the input must
        produce the same number of patches as the configured ``img_size``.

        Returns:
            Tuple ``(xp, cls)``: ``xp`` has shape [B, D, Hp, Wp]; ``cls`` has
            shape [B, D] when ``use_cls`` is set and is ``None`` otherwise.
        """
        B, C, H, W = x.shape
        x, (Hp, Wp) = self.patch_embed(x)

        if self.pos_embed is not None:
            # fit for multiple GPU training
            # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
            x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]

        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) if self.use_cls else x
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        x = self.last_norm(x)

        # Split the class token (if any) from the patch tokens.
        cls = x[:, 0] if self.use_cls else None
        x = x[:, 1:] if self.use_cls else x
        xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()

        return xp, cls  # shape [B, D, Hp, Wp], [B, D]

    def forward(self, x):
        """Return ``forward_features(x)`` unchanged."""
        x, cls = self.forward_features(x)
        return x, cls

    def train(self, mode=True):
        """Switch training mode while keeping frozen parts in eval mode.

        Returns:
            ``self``, matching ``nn.Module.train`` so that ``model.train()``
            and ``model.eval()`` can be chained or assigned.
        """
        super().train(mode)
        # Re-apply freezing: super().train() flips every sub-module's mode.
        self._freeze_stages()
        return self
|
prima/models/bioclip_embedding.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
bioclip Embedding Module
|
| 12 |
+
Converts image batch to embeddings that can be concatenated with image features
|
| 13 |
+
"""
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
class BioClipEmbedding(nn.Module):
    """Embed images with a BioClip image encoder followed by a learned projection.

    The encoder runs under ``torch.no_grad()``, so only the projection
    receives gradients.

    Args:
        cfg: Config node; ``MODEL.BIOCLIP_EMBEDDING.TYPE`` selects the
            checkpoint ('bioclip2' or, for any other value, the original
            bioclip).
        embed_dim: Output embedding dimension, should match the dimension of
            image features for concatenation.
    """

    def __init__(self, cfg, embed_dim: int = 1024):
        super().__init__()
        self.embed_dim = embed_dim

        # Local import: open_clip is only required when this module is built.
        import open_clip

        if cfg.MODEL.BIOCLIP_EMBEDDING.TYPE == 'bioclip2':
            print("[BioClipEmbedding] Using BioClip2 model from Hugging Face Hub")
            hub_name = 'hf-hub:imageomics/bioclip-2'
        else:
            hub_name = 'hf-hub:imageomics/bioclip'

        # Only the model is kept; the two transforms returned with it are unused.
        encoder, _transform_a, _transform_b = open_clip.create_model_and_transforms(hub_name)
        self.species_model = encoder
        # NOTE(review): a later .train() on a parent module switches this
        # encoder back to training mode - confirm that is acceptable.
        self.species_model.eval()

        # Project the encoder's visual output width to the target dimension.
        encoder_dim = self.species_model.visual.output_dim
        self.projection = nn.Sequential(
            nn.Linear(encoder_dim, embed_dim),
            nn.LayerNorm(embed_dim),
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Args:
            images: Tensor of shape (B, C, H, W) representing a batch of images
        Returns:
            Tensor of shape (B, embed_dim) representing the embedded features
        """
        # BioClip expects 224x224 input, resize if needed
        needs_resize = images.shape[-2:] != (224, 224)
        if needs_resize:
            images = F.interpolate(images, size=(224, 224), mode='bilinear', align_corners=False)

        with torch.no_grad():
            encoded = self.species_model.encode_image(images)

        return self.projection(encoded)
|
prima/models/components/__init__.py
ADDED
|
File without changes
|
prima/models/components/model_utils.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 11 |
+
# All rights reserved.
|
| 12 |
+
|
| 13 |
+
# This source code is licensed under the license found in the
|
| 14 |
+
# LICENSE file in the root directory of this source tree.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """
    Pick at most `max_cond_frame_num` entries of `cond_frame_outputs` whose
    frame indices are temporally closest to `frame_idx`.

    Selection order:
    - the nearest conditioning frame strictly before `frame_idx` (if any);
    - the nearest conditioning frame at or after `frame_idx` (if any);
    - the remaining frames by increasing distance to `frame_idx` until the
      limit is reached.

    When `max_cond_frame_num` is -1, or the dict already fits the limit, the
    input dict itself is returned as the selection.

    Outputs:
    - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
    - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
    """
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        return cond_frame_outputs, {}

    assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
    chosen = {}

    # Nearest conditioning frame before the current one.
    earlier = [t for t in cond_frame_outputs if t < frame_idx]
    if earlier:
        nearest_before = max(earlier)
        chosen[nearest_before] = cond_frame_outputs[nearest_before]

    # Nearest conditioning frame at or after the current one.
    later = [t for t in cond_frame_outputs if t >= frame_idx]
    if later:
        nearest_after = min(later)
        chosen[nearest_after] = cond_frame_outputs[nearest_after]

    # Fill the remaining slots with the next-closest frames.
    slots_left = max_cond_frame_num - len(chosen)
    candidates = [t for t in cond_frame_outputs if t not in chosen]
    candidates.sort(key=lambda t: abs(t - frame_idx))
    for t in candidates[:slots_left]:
        chosen[t] = cond_frame_outputs[t]

    leftover = {}
    for t, value in cond_frame_outputs.items():
        if t not in chosen:
            leftover[t] = value

    return chosen, leftover
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """
    Build a 1D sine/cosine positional embedding for the given positions.

    Args:
        pos_inds: tensor of positions; one trailing axis is appended to it.
        dim: requested embedding width. ``dim // 2`` frequencies are used, so the
            returned last axis has ``2 * (dim // 2)`` entries.
        temperature: base of the frequency progression.

    Returns:
        Tensor shaped ``pos_inds.shape + (2 * (dim // 2),)``: all sine terms
        followed by all cosine terms along the last axis.
    """
    half_dim = dim // 2
    channel = torch.arange(half_dim, dtype=torch.float32, device=pos_inds.device)
    # Channels (0, 1), (2, 3), ... share one divisor because of the floor division.
    divisor = temperature ** (2 * (channel // 2) / half_dim)

    angles = pos_inds.unsqueeze(-1) / divisor
    return torch.cat([angles.sin(), angles.cos()], dim=-1)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_activation_fn(activation):
    """Return the functional activation matching a name.

    Args:
        activation: one of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        The matching callable from ``torch.nn.functional``.

    Raises:
        RuntimeError: if ``activation`` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # The message lists every accepted name (the previous text omitted "glu").
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class DropPath(nn.Module):
    """Randomly zero whole samples (entries along the first axis) during training.

    In eval mode, or when ``drop_prob`` is 0, the input is returned untouched.
    Otherwise each sample is kept with probability ``1 - drop_prob``; kept
    samples are divided by that probability when ``scale_by_keep`` is set.
    """

    # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over every remaining axis.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if self.scale_by_keep and keep_prob > 0.0:
            mask.div_(keep_prob)
        return x * mask
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Lightly adapted from
|
| 118 |
+
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
|
| 119 |
+
class MLP(nn.Module):
    """Stack of linear layers with an activation between consecutive layers.

    ``num_layers - 1`` hidden layers of width ``hidden_dim`` are followed by a
    final linear layer to ``output_dim``. The activation is not applied after
    the last layer; an optional sigmoid is applied to the final output.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: nn.Module = nn.ReLU,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        # widths[i] -> widths[i + 1] is the shape of the i-th linear layer.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )
        self.sigmoid_output = sigmoid_output
        # ``activation`` is a class (or factory); it is instantiated here.
        self.act = activation()

    def forward(self, x):
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                x = self.act(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
| 147 |
+
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
| 148 |
+
class LayerNorm2d(nn.Module):
    """Normalize over axis 1 and apply a learned scale and shift per index of that axis.

    The ``[:, None, None]`` broadcasting of ``weight`` and ``bias`` means axis 1
    must have ``num_channels`` entries and be followed by two more axes.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        # Biased variance over axis 1.
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
|
prima/models/components/pose_transformer.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from inspect import isfunction
|
| 11 |
+
from typing import Callable, Optional
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from einops import rearrange
|
| 15 |
+
from einops.layers.torch import Rearrange
|
| 16 |
+
from torch import nn
|
| 17 |
+
|
| 18 |
+
from .t_cond_mlp import (
|
| 19 |
+
AdaptiveLayerNorm1D,
|
| 20 |
+
FrequencyEmbedder,
|
| 21 |
+
normalization_layer,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def exists(val):
    """Return ``True`` unless ``val`` is ``None``."""
    return not (val is None)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain Python function (per ``inspect.isfunction``) it is
    called with no arguments and its result is returned; any other object,
    including classes and builtins, is returned as-is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class PreNorm(nn.Module):
    """Apply a normalization layer to the input, then call the wrapped ``fn``.

    The normalization is built by ``normalization_layer(norm, dim, norm_cond_dim)``.
    """

    def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
        super().__init__()
        self.norm = normalization_layer(norm, dim, norm_cond_dim)
        self.fn = fn

    def forward(self, x: torch.Tensor, *args, **kwargs):
        # Positional extras go only to the adaptive norm; every other norm
        # ignores them. Keyword arguments always go to the wrapped ``fn``.
        norm_args = args if isinstance(self.norm, AdaptiveLayerNorm1D) else ()
        normalized = self.norm(x, *norm_args)
        return self.fn(normalized, **kwargs)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class FeedForward(nn.Module):
    """Two-layer MLP: ``dim -> hidden_dim -> dim`` with GELU and dropout.

    Dropout with probability ``dropout`` follows both the activation and the
    output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a ``(b, n, dim)`` input.

    Queries, keys and values come from one bias-free linear projection of the
    input. The output projection back to ``dim`` is replaced by an identity
    only for a single head whose width already equals ``dim``.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        # Split the fused projection into q/k/v and move heads to their own axis.
        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        # Weighted sum of values, then merge the heads back into one feature axis.
        merged = rearrange(torch.matmul(weights, v), "b h n d -> b n (h d)")
        return self.to_out(merged)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class CrossAttention(nn.Module):
    """Multi-head attention whose keys and values come from a context tensor.

    Queries are projected from ``x``; keys and values are projected from
    ``context`` (width ``context_dim``, defaulting to ``dim``). When no context
    is passed to ``forward``, ``x`` itself is used as the context.
    """

    def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        context_dim = default(context_dim, dim)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x, context=None):
        context = default(context, x)
        keys, values = self.to_kv(context).chunk(2, dim=-1)
        queries = self.to_q(x)
        # Move the heads onto their own axis for all three projections.
        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in (queries, keys, values)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        merged = rearrange(torch.matmul(weights, v), "b h n d -> b n (h d)")
        return self.to_out(merged)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm blocks: self-attention then feed-forward.

    Both sub-layers are wrapped in ``PreNorm`` and added back to their input as
    a residual. Extra positional arguments to ``forward`` are passed on to
    every ``PreNorm`` wrapper.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
    ):
        super().__init__()

        def build_block():
            # Sub-layers are created before their PreNorm wrappers so the
            # order of parameter initialisation is: attention, feed-forward, norms.
            self_attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            return nn.ModuleList(
                [
                    PreNorm(dim, self_attn, norm=norm, norm_cond_dim=norm_cond_dim),
                    PreNorm(dim, feed_forward, norm=norm, norm_cond_dim=norm_cond_dim),
                ]
            )

        self.layers = nn.ModuleList(build_block() for _ in range(depth))

    def forward(self, x: torch.Tensor, *args):
        for self_attn, feed_forward in self.layers:
            x = self_attn(x, *args) + x
            x = feed_forward(x, *args) + x
        return x
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class TransformerCrossAttn(nn.Module):
    """Stack of ``depth`` pre-norm blocks: self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped in ``PreNorm`` and added back to its input as a
    residual. The cross-attention context is either one ``context`` shared by
    all blocks or a per-block ``context_list``.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
    ):
        super().__init__()

        def build_block():
            # Sub-layers are created before their PreNorm wrappers so the
            # order of parameter initialisation matches a plain sequential build.
            self_attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            cross_attn = CrossAttention(
                dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
            )
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            return nn.ModuleList(
                [
                    PreNorm(dim, self_attn, norm=norm, norm_cond_dim=norm_cond_dim),
                    PreNorm(dim, cross_attn, norm=norm, norm_cond_dim=norm_cond_dim),
                    PreNorm(dim, feed_forward, norm=norm, norm_cond_dim=norm_cond_dim),
                ]
            )

        self.layers = nn.ModuleList(build_block() for _ in range(depth))

    def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
        """Run all blocks.

        Raises:
            ValueError: if ``context_list`` does not hold one entry per block.
        """
        if context_list is None:
            # One shared context, repeated for every block.
            context_list = [context] * len(self.layers)
        if len(context_list) != len(self.layers):
            raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")

        for idx, block in enumerate(self.layers):
            self_attn, cross_attn, feed_forward = block
            x = self_attn(x, *args) + x
            x = cross_attn(x, *args, context=context_list[idx]) + x
            x = feed_forward(x, *args) + x
        return x
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class DropTokenDropout(nn.Module):
    """Training-time dropout that removes whole tokens from the sequence axis.

    One Bernoulli(``p``) draw is made per position of axis 1, using the first
    batch element as the template, so the same positions are removed for every
    batch element and the output sequence may be shorter than the input.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p > 1 or p < 0:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        self.p = p

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, dim)
        if not (self.training and self.p > 0):
            return x
        drop = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
        if drop.any():
            x = x[:, ~drop, :]
        return x
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class ZeroTokenDropout(nn.Module):
    """Training-time dropout that zeroes whole tokens of a ``(b, n, dim)`` tensor.

    One Bernoulli(``p``) draw is made for every (batch, token) pair; selected
    tokens have their full feature vector set to zero. The sequence length is
    unchanged and no rescaling is applied. In eval mode, or when ``p`` is 0,
    the input is returned as-is.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, dim)
        if not (self.training and self.p > 0):
            return x
        zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
        # Out-of-place fill: an in-place write would also overwrite the
        # caller's tensor, which is the raw model input when this layer is
        # applied before the token embedding.
        return x.masked_fill(zero_mask.unsqueeze(-1), 0)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class TransformerEncoder(nn.Module):
    """Embed input tokens, add a learned positional embedding, run a ``Transformer``.

    Tokens are embedded with a linear layer, optionally preceded by a
    ``FrequencyEmbedder`` expansion when ``token_pe_numfreq > 0``. Token
    dropout (``"drop"`` or ``"zero"``) is applied at the stage named by
    ``emb_dropout_loc``: ``"input"`` (before embedding), ``"token"`` (after
    embedding) or ``"token_afterpos"`` (after the positional embedding).

    Raises:
        ValueError: if ``emb_dropout_type`` is neither ``"drop"`` nor ``"zero"``.
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = "drop",
        emb_dropout_loc: str = "token",
        norm: str = "layer",
        norm_cond_dim: int = -1,
        token_pe_numfreq: int = -1,
    ):
        super().__init__()
        if token_pe_numfreq > 0:
            # The frequency embedder yields sin and cos per frequency plus the
            # raw value, for each input feature.
            expanded_dim = token_dim * (2 * token_pe_numfreq + 1)
            embed_stages = [
                Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
                FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
                Rearrange("(b n) d -> b n d", n=num_tokens, d=expanded_dim),
                nn.Linear(expanded_dim, dim),
            ]
            self.to_token_embedding = nn.Sequential(*embed_stages)
        else:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))

        if emb_dropout_type == "drop":
            self.dropout = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            self.dropout = ZeroTokenDropout(emb_dropout)
        else:
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
        self.emb_dropout_loc = emb_dropout_loc

        self.transformer = Transformer(
            dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
        )

    def forward(self, inp: torch.Tensor, *args, **kwargs):
        """Encode ``inp``; positional extras go to the transformer, ``kwargs`` are unused."""
        dropout_stage = self.emb_dropout_loc
        x = inp

        if dropout_stage == "input":
            x = self.dropout(x)
        x = self.to_token_embedding(x)

        if dropout_stage == "token":
            x = self.dropout(x)
        _, n, _ = x.shape
        # Only the first ``n`` positional rows are used; ``n`` can be smaller
        # than ``num_tokens`` after token-dropping.
        x += self.pos_embedding[:, :n]

        if dropout_stage == "token_afterpos":
            x = self.dropout(x)
        return self.transformer(x, *args)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class TransformerDecoder(nn.Module):
    """Embed query tokens, add a learned positional embedding, run a ``TransformerCrossAttn``.

    Args:
        num_tokens: number of rows of the learned positional embedding.
        token_dim: width of the incoming tokens.
        dim: model width.
        depth, heads, mlp_dim, dim_head, dropout, norm, norm_cond_dim,
            context_dim: passed on to ``TransformerCrossAttn``.
        emb_dropout: probability used by the embedding dropout layer.
        emb_dropout_type: ``"drop"`` (remove tokens), ``"zero"`` (zero tokens)
            or ``"normal"`` (element-wise ``nn.Dropout``).
        skip_token_embedding: use an identity instead of the linear token
            embedding; requires ``token_dim == dim``.

    Raises:
        ValueError: if ``skip_token_embedding`` is set while
            ``token_dim != dim``, or if ``emb_dropout_type`` is not one of the
            supported names.
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = 'drop',
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
        skip_token_embedding: bool = False,
    ):
        super().__init__()
        if not skip_token_embedding:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        else:
            self.to_token_embedding = nn.Identity()
            if token_dim != dim:
                raise ValueError(
                    f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
                )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
        if emb_dropout_type == "drop":
            self.dropout = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            self.dropout = ZeroTokenDropout(emb_dropout)
        elif emb_dropout_type == "normal":
            self.dropout = nn.Dropout(emb_dropout)
        else:
            # Fail at construction time; otherwise ``self.dropout`` would be
            # missing and the first ``forward`` call would raise AttributeError.
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")

        self.transformer = TransformerCrossAttn(
            dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            norm=norm,
            norm_cond_dim=norm_cond_dim,
            context_dim=context_dim,
        )

    def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
        """Decode ``inp`` while attending to ``context`` (or per-layer ``context_list``)."""
        x = self.to_token_embedding(inp)
        x = self.dropout(x)
        _, n, _ = x.shape

        # Out-of-place add: with an identity embedding and an inactive dropout
        # layer, ``x`` is the caller's ``inp`` tensor, and an in-place ``+=``
        # would silently modify it.
        x = x + self.pos_embedding[:, :n]

        return self.transformer(x, *args, context=context, context_list=context_list)
|
| 366 |
+
|
prima/models/components/position_encoding.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 11 |
+
# All rights reserved.
|
| 12 |
+
|
| 13 |
+
# This source code is licensed under the license found in the
|
| 14 |
+
# LICENSE file in the root directory of this source tree.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from typing import Any, Optional, Tuple
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch import nn
|
| 23 |
+
|
| 24 |
+
# Rotary Positional Encoding, adapted from:
|
| 25 |
+
# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
| 26 |
+
# 2. https://github.com/naver-ai/rope-vit
|
| 27 |
+
# 3. https://github.com/lucidrains/rotary-embedding-torch
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def init_t_xy(end_x: int, end_y: int):
    """Return x and y coordinates for a row-major flattened ``end_x`` by ``end_y`` grid.

    Both outputs are float tensors of length ``end_x * end_y``: the first is
    the flat index modulo ``end_x``, the second the flat index floor-divided
    by ``end_x``.
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    x_coord = (flat_index % end_x).float()
    y_coord = torch.div(flat_index, end_x, rounding_mode="floor").float()
    return x_coord, y_coord
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
    """Build unit-magnitude complex rotation factors for a 2D grid.

    ``dim // 4`` frequencies are derived from ``theta`` and applied to the x
    and y coordinates from ``init_t_xy``. The result concatenates the x-axis
    and y-axis factors along the last axis, giving shape
    ``(end_x * end_y, 2 * (dim // 4))``.
    """
    # The same frequency ladder is used for both axes.
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    base_freqs = 1.0 / (theta**exponents)

    t_x, t_y = init_t_xy(end_x, end_y)
    axis_angles = (torch.outer(t_x, base_freqs), torch.outer(t_y, base_freqs))
    # torch.polar(1, angle) gives cos(angle) + i*sin(angle).
    axis_cis = [torch.polar(torch.ones_like(angle), angle) for angle in axis_angles]
    return torch.cat(axis_cis, dim=-1)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` with leading singleton axes so it broadcasts against ``x``.

    ``freqs_cis`` must match the last two axes of ``x`` exactly; every earlier
    axis of ``x`` becomes size 1 in the returned view.

    Raises:
        AssertionError: if ``x`` has fewer than two axes or the shapes mismatch.
    """
    ndim = x.ndim
    assert ndim > 1
    trailing = (x.shape[-2], x.shape[-1])
    assert freqs_cis.shape == trailing
    leading_ones = [1] * (ndim - 2)
    return freqs_cis.view(*leading_ones, *trailing)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def apply_rotary_enc(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    repeat_freqs_k: bool = False,
):
    """Rotate queries and keys by complex factors ``freqs_cis``.

    Consecutive pairs along the last axis of ``xq``/``xk`` are treated as the
    real and imaginary parts of a complex number and multiplied by
    ``freqs_cis``. Results are cast back to the dtype of the matching input.

    Args:
        xq: query tensor. NOTE(review): ``flatten(3)`` and the 5-axis expand
            below imply 4 axes, e.g. (batch, heads, seq, dim) — confirm with callers.
        xk: key tensor; returned unchanged when its second-to-last axis is empty.
        freqs_cis: complex factors matching the last two axes of the complex
            view of ``xq``.
        repeat_freqs_k: tile the factors along the sequence axis by the integer
            ratio of key length to query length before rotating the keys.

    Returns:
        Tuple ``(rotated_xq, rotated_xk)``.
    """

    def as_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair up the last axis: (..., d) -> (..., d / 2) complex.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    xq_ = as_complex(xq)
    xk_ = as_complex(xk) if xk.shape[-2] != 0 else None

    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xq_out = xq_out.type_as(xq).to(xq.device)
    if xk_ is None:
        # no keys to rotate, due to dropout
        return xq_out, xk

    # repeat freqs along seq_len dim to match k seq_len
    if repeat_freqs_k:
        r = xk_.shape[-2] // xq_.shape[-2]
        if freqs_cis.is_cuda:
            freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
        else:
            # torch.repeat on complex numbers may not be supported on non-CUDA devices
            # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten
            freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3)

    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out, xk_out.type_as(xk).to(xk.device)
|
prima/models/components/t_cond_mlp.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import copy
|
| 11 |
+
from typing import List, Optional
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class AdaptiveLayerNorm1D(torch.nn.Module):
    """LayerNorm whose output is scaled and shifted by a conditioning vector.

    A linear layer maps the conditioning input to a per-feature scale and
    shift. Its weights and bias start at zero, so at initialisation the module
    behaves like a plain ``LayerNorm``.

    Raises:
        ValueError: if ``data_dim`` or ``norm_cond_dim`` is not positive.
    """

    def __init__(self, data_dim: int, norm_cond_dim: int):
        super().__init__()
        if data_dim <= 0:
            raise ValueError(f"data_dim must be positive, but got {data_dim}")
        if norm_cond_dim <= 0:
            raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
        self.norm = torch.nn.LayerNorm(data_dim)
        self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
        torch.nn.init.zeros_(self.linear.weight)
        torch.nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: (batch, ..., data_dim)
        # t: (batch, norm_cond_dim)
        # return: same shape as x
        normalized = self.norm(x)
        scale, shift = self.linear(t).chunk(2, dim=-1)

        # Insert singleton axes so (batch, data_dim) broadcasts over the
        # middle axes of x.
        middle_axes = normalized.dim() - 2
        if middle_axes > 0:
            broadcast_shape = (scale.shape[0],) + (1,) * middle_axes + (scale.shape[1],)
            scale = scale.view(broadcast_shape)
            shift = shift.view(broadcast_shape)

        return normalized * (1 + scale) + shift
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class SequentialCond(torch.nn.Sequential):
    """``Sequential`` that passes extra arguments to conditioning-aware children.

    Children that are ``AdaptiveLayerNorm1D``, ``SequentialCond`` or
    ``ResidualMLPBlock`` get the extra positional and keyword arguments; every
    other child is called with the running activation only.
    """

    def forward(self, input, *args, **kwargs):
        conditioned_types = (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)
        activation = input
        for child in self:
            if isinstance(child, conditioned_types):
                activation = child(activation, *args, **kwargs)
            else:
                activation = child(activation)
        return activation
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
    """Create the normalization module selected by ``norm``.

    Args:
        norm: ``"batch"`` (``BatchNorm1d``), ``"layer"`` (``LayerNorm``),
            ``"ada"`` (``AdaptiveLayerNorm1D``) or ``None`` (``Identity``).
        dim: feature width passed to the normalization module.
        norm_cond_dim: conditioning width; only used for ``"ada"``.

    Raises:
        ValueError: if ``norm`` is not one of the options above.
    """
    if norm is None:
        return torch.nn.Identity()
    if norm == "batch":
        return torch.nn.BatchNorm1d(dim)
    if norm == "layer":
        return torch.nn.LayerNorm(dim)
    if norm == "ada":
        assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
        return AdaptiveLayerNorm1D(dim, norm_cond_dim)
    raise ValueError(f"Unknown norm: {norm}")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def linear_norm_activ_dropout(
    input_dim: int,
    output_dim: int,
    activation: torch.nn.Module = torch.nn.ReLU(),
    bias: bool = True,
    norm: Optional[str] = "layer",  # Options: ada/batch/layer
    dropout: float = 0.0,
    norm_cond_dim: int = -1,
) -> SequentialCond:
    """Build one ``Linear -> [norm] -> activation -> [dropout]`` unit.

    The normalization is omitted when ``norm`` is ``None`` and the dropout
    layer when ``dropout`` is not positive. ``activation`` is deep-copied, so
    the passed instance (including the shared default) is not reused.
    """
    stages = [torch.nn.Linear(input_dim, output_dim, bias=bias)]
    if norm is not None:
        stages.append(normalization_layer(norm, output_dim, norm_cond_dim))
    stages.append(copy.deepcopy(activation))
    if dropout > 0.0:
        stages.append(torch.nn.Dropout(dropout))
    return SequentialCond(*stages)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def create_simple_mlp(
    input_dim: int,
    hidden_dims: List[int],
    output_dim: int,
    activation: torch.nn.Module = torch.nn.ReLU(),
    bias: bool = True,
    norm: Optional[str] = "layer",  # Options: ada/batch/layer
    dropout: float = 0.0,
    norm_cond_dim: int = -1,
) -> SequentialCond:
    """Build a flat MLP: one linear/norm/activation/dropout unit per hidden width, then a linear head.

    The modules of every unit are spliced directly into a single
    ``SequentialCond`` rather than nested.
    """
    widths = [input_dim, *hidden_dims]
    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        unit = linear_norm_activ_dropout(
            fan_in, fan_out, activation, bias, norm, dropout, norm_cond_dim
        )
        # Iterating the unit yields its child modules, which keeps the result flat.
        modules.extend(unit)
    modules.append(torch.nn.Linear(widths[-1], output_dim, bias=bias))
    return SequentialCond(*modules)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class ResidualMLPBlock(torch.nn.Module):
    """Residual block: ``x + f(x)`` where ``f`` stacks ``num_hidden_layers`` linear units.

    Each unit is built by ``linear_norm_activ_dropout``. Only the
    equal-width case is implemented, so ``input_dim``, ``hidden_dim`` and
    ``output_dim`` must all match. Extra arguments to ``forward`` are passed on
    to the inner ``SequentialCond``.

    Raises:
        NotImplementedError: if the three widths are not all equal.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_hidden_layers: int,
        output_dim: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        bias: bool = True,
        norm: Optional[str] = "layer",  # Options: ada/batch/layer
        dropout: float = 0.0,
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        if not (input_dim == output_dim == hidden_dim):
            # Report all three widths: the check also fails when only
            # hidden_dim differs from the other two.
            raise NotImplementedError(
                "ResidualMLPBlock requires input_dim == hidden_dim == output_dim, "
                f"got input_dim={input_dim}, hidden_dim={hidden_dim}, output_dim={output_dim}"
            )

        layers = [
            linear_norm_activ_dropout(
                hidden_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
            )
            for _ in range(num_hidden_layers)
        ]
        self.model = SequentialCond(*layers)
        # Kept as an attribute for compatibility; forward adds ``x`` directly.
        self.skip = torch.nn.Identity()

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x + self.model(x, *args, **kwargs)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class ResidualMLP(torch.nn.Module):
    """MLP made of an input unit, ``num_blocks`` residual blocks and a linear head.

    The input unit maps ``input_dim`` to ``hidden_dim``, every
    ``ResidualMLPBlock`` keeps width ``hidden_dim``, and the final linear layer
    maps to ``output_dim``. Extra arguments to ``forward`` are passed on to the
    inner ``SequentialCond``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_hidden_layers: int,
        output_dim: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        bias: bool = True,
        norm: Optional[str] = "layer",  # Options: ada/batch/layer
        dropout: float = 0.0,
        num_blocks: int = 1,
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        self.input_dim = input_dim

        # Built in forward order so parameters are initialised in the same sequence.
        stem = linear_norm_activ_dropout(
            input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
        )
        blocks = [
            ResidualMLPBlock(
                hidden_dim,
                hidden_dim,
                num_hidden_layers,
                hidden_dim,
                activation,
                bias,
                norm,
                dropout,
                norm_cond_dim,
            )
            for _ in range(num_blocks)
        ]
        head = torch.nn.Linear(hidden_dim, output_dim, bias=bias)
        self.model = SequentialCond(stem, *blocks, head)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return self.model(x, *args, **kwargs)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class FrequencyEmbedder(torch.nn.Module):
    """Expand each input feature into sine/cosine features plus the raw value.

    Frequencies are ``2 ** linspace(0, max_freq_log2, num_frequencies)`` and
    are stored as a buffer. For an ``(N, D)`` input the output is
    ``(N, D * (2 * num_frequencies + 1))``, laid out per feature as
    ``[sin..., cos..., x]``.
    """

    def __init__(self, num_frequencies, max_freq_log2):
        super().__init__()
        frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
        self.register_buffer("frequencies", frequencies)

    def forward(self, x):
        # x should be of size (N,) or (N, D)
        num_rows = x.size(0)
        if x.dim() == 1:
            x = x.unsqueeze(1)  # treat (N,) as (N, 1)
        column = x.unsqueeze(-1)  # (N, D, 1)
        phase = self.frequencies.view(1, 1, -1) * column  # (N, D, num_frequencies)
        features = torch.cat([torch.sin(phase), torch.cos(phase), column], dim=-1)
        # Flatten per-feature expansions: (N, D * (2 * num_frequencies + 1))
        return features.view(num_rows, -1)
|
| 204 |
+
|
prima/models/components/transformer.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 11 |
+
# All rights reserved.
|
| 12 |
+
|
| 13 |
+
# This source code is licensed under the license found in the
|
| 14 |
+
# LICENSE file in the root directory of this source tree.
|
| 15 |
+
|
| 16 |
+
import contextlib
|
| 17 |
+
import math
|
| 18 |
+
import warnings
|
| 19 |
+
from functools import partial
|
| 20 |
+
from typing import Tuple, Type
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from torch import nn, Tensor
|
| 25 |
+
|
| 26 |
+
from .position_encoding import apply_rotary_enc, compute_axial_cis
|
| 27 |
+
from .model_utils import MLP
|
| 28 |
+
|
| 29 |
+
warnings.simplefilter(action="ignore", category=FutureWarning)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_sdpa_settings():
    """Decide which scaled-dot-product-attention kernels may be used.

    Returns:
        A ``(old_gpu, use_flash_attn, math_kernel_on)`` tuple of booleans.
        Without CUDA this is ``(True, False, True)``. With CUDA, device 0's
        major compute capability decides ``old_gpu`` (< 7) and
        ``use_flash_attn`` (>= 8); the math kernel is kept on whenever Flash
        Attention is off or PyTorch is older than 2.2.
    """
    if not torch.cuda.is_available():
        return True, False, True

    cc_major = torch.cuda.get_device_properties(0).major
    old_gpu = cc_major < 7
    # only use Flash Attention on Ampere (8.0) or newer GPUs
    use_flash_attn = cc_major >= 8
    if not use_flash_attn:
        warnings.warn(
            "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
            category=UserWarning,
            stacklevel=2,
        )

    # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
    # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
    major_minor = torch.__version__.split(".")[:2]
    pytorch_version = tuple(int(part) for part in major_minor)
    torch_too_old = pytorch_version < (2, 2)
    if torch_too_old:
        warnings.warn(
            f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
            "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
            category=UserWarning,
            stacklevel=2,
        )

    math_kernel_on = torch_too_old or not use_flash_attn
    return old_gpu, use_flash_attn, math_kernel_on
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Check whether Flash Attention is available (and use it by default)
|
| 63 |
+
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
| 64 |
+
# A fallback setting to allow all available kernels if Flash Attention fails
|
| 65 |
+
ALLOW_ALL_KERNELS = False
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def sdp_kernel_context(dropout_p):
    """Return the context manager that selects the attention kernels.

    Once ``ALLOW_ALL_KERNELS`` has been switched on (after a kernel failure)
    no restriction is applied and a null context is returned. Otherwise the
    kernels are restricted according to the module-level settings computed by
    ``get_sdpa_settings``.

    Args:
        dropout_p: Dropout probability that will be used for the attention
            call; a positive value on an old GPU enables the math kernel.
    """
    if not ALLOW_ALL_KERNELS:
        # if Flash attention kernel is off, then math kernel needs to be enabled
        math_enabled = (OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON
        return torch.backends.cuda.sdp_kernel(
            enable_flash=USE_FLASH_ATTN,
            enable_math=math_enabled,
            enable_mem_efficient=OLD_GPU,
        )

    return contextlib.nullcontext()
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class TwoWayTransformer(nn.Module):
    """Stack of two-way attention blocks followed by a final token-to-image attention."""

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
          attention_downsample_rate (int): downsample rate handed to the
            cross-attention layers
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim

        # Only the very first block skips the positional encoding in its
        # self-attention step.
        self.layers = nn.ModuleList(
            TwoWayAttentionBlock(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_dim=mlp_dim,
                activation=activation,
                attention_downsample_rate=attention_downsample_rate,
                skip_first_layer_pe=(layer_idx == 0),
            )
            for layer_idx in range(depth)
        )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # Unpacking doubles as a check that the image embedding is 4-D.
        _bs, _c, _h, _w = image_embedding.shape

        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        image_tokens = image_embedding.flatten(2).permute(0, 2, 1)
        image_tokens_pe = image_pe.flatten(2).permute(0, 2, 1)

        # The point embedding acts both as the initial queries and as their
        # positional encoding in every block.
        queries, keys = point_embedding, image_tokens
        for block in self.layers:
            queries, keys = block(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_tokens_pe,
            )

        # Final attention from the points to the image, then layernorm
        attn_out = self.final_attn_token_to_image(
            q=queries + point_embedding, k=keys + image_tokens_pe, v=keys
        )
        queries = self.norm_final_attn(queries + attn_out)

        return queries, keys
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class TwoWayAttentionBlock(nn.Module):
    """One two-way block: token self-attention, token-to-image attention, MLP, image-to-token attention."""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          attention_downsample_rate (int): downsample rate handed to both
            cross-attention layers
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        cross_attn_kwargs = dict(downsample_rate=attention_downsample_rate)

        # (1) self-attention on the sparse tokens
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        # (2) sparse tokens attend to the dense image tokens
        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, **cross_attn_kwargs
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        # (3) MLP on the sparse tokens
        self.mlp = MLP(
            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        # (4) dense image tokens attend to the sparse tokens
        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, **cross_attn_kwargs
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Update the sparse ``queries`` and dense ``keys`` and return both.

        ``query_pe`` and ``key_pe`` are added to the attention queries/keys
        but never to the values.
        """
        # Self attention block
        if not self.skip_first_layer_pe:
            q_with_pe = queries + query_pe
            queries = queries + self.self_attn(q=q_with_pe, k=q_with_pe, v=queries)
        else:
            # No positional encoding and no residual connection here.
            queries = self.self_attn(q=queries, k=queries, v=queries)
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        token_to_image = self.cross_attn_token_to_image(
            q=queries + query_pe, k=keys + key_pe, v=keys
        )
        queries = self.norm2(queries + token_to_image)

        # MLP block
        queries = self.norm3(queries + self.mlp(queries))

        # Cross attention block, image embedding attending to tokens
        image_to_token = self.cross_attn_image_to_token(
            q=keys + key_pe, k=queries + query_pe, v=queries
        )
        keys = self.norm4(keys + image_to_token)

        return queries, keys
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.

    Queries are projected from ``embedding_dim`` and keys/values from
    ``kv_in_dim`` (defaulting to ``embedding_dim``) to
    ``embedding_dim // downsample_rate`` channels, split into ``num_heads``
    heads, and projected back to ``embedding_dim`` afterwards. Dropout is only
    applied while the module is in training mode.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: int = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = embedding_dim if kv_in_dim is None else kv_in_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """Reshape ``B x N_tokens x C`` into ``B x N_heads x N_tokens x C_per_head``."""
        batch, tokens, channels = x.shape
        per_head = x.reshape(batch, tokens, num_heads, channels // num_heads)
        return per_head.transpose(1, 2).contiguous()

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """Reshape ``B x N_heads x N_tokens x C_per_head`` back into ``B x N_tokens x C``."""
        batch, heads, tokens, head_dim = x.shape
        merged = x.transpose(1, 2).contiguous()
        return merged.reshape(batch, tokens, heads * head_dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Compute multi-head attention of ``q`` over ``k``/``v``."""
        global ALLOW_ALL_KERNELS

        # Input projections, then separate into heads
        q = self._separate_heads(self.q_proj(q), self.num_heads)
        k = self._separate_heads(self.k_proj(k), self.num_heads)
        v = self._separate_heads(self.v_proj(v), self.num_heads)

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        try:
            with sdp_kernel_context(dropout_p):
                attended = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        except Exception as e:
            # Fall back to all kernels if the Flash attention kernel fails
            warnings.warn(
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
                category=UserWarning,
                stacklevel=2,
            )
            # The flag is module-wide: every later call skips the restriction.
            ALLOW_ALL_KERNELS = True
            attended = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        return self.out_proj(self._recombine_heads(attended))
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class RoPEAttention(Attention):
    """Attention with rotary position encoding."""

    def __init__(
        self,
        *args,
        rope_theta=10000.0,
        # whether to repeat q rope to match k length
        # this is needed for cross-attention to memories
        rope_k_repeat=False,
        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        head_dim = self.internal_dim // self.num_heads
        self.compute_cis = partial(compute_axial_cis, dim=head_dim, theta=rope_theta)
        # Precomputed for the default feature size; recomputed in forward()
        # whenever the number of query tokens does not match.
        self.freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
        self.rope_k_repeat = rope_k_repeat

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0,
    ) -> Tensor:
        """Attention where ``q`` and the leading keys are rotary-encoded first.

        Args:
            q, k, v: Unprojected query, key and value tensors.
            num_k_exclude_rope: Number of trailing key tokens that are left
                without rotary encoding.
        """
        global ALLOW_ALL_KERNELS

        # Input projections, then separate into heads
        q = self._separate_heads(self.q_proj(q), self.num_heads)
        k = self._separate_heads(self.k_proj(k), self.num_heads)
        v = self._separate_heads(self.v_proj(v), self.num_heads)

        # Apply rotary position encoding
        num_q = q.shape[-2]
        w = h = math.sqrt(num_q)
        self.freqs_cis = self.freqs_cis.to(q.device)
        if self.freqs_cis.shape[0] != num_q:
            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
        if num_q != k.shape[-2]:
            assert self.rope_k_repeat

        num_k_rope = k.size(-2) - num_k_exclude_rope
        q, k_rotated = apply_rotary_enc(
            q,
            k[:, :, :num_k_rope],
            freqs_cis=self.freqs_cis,
            repeat_freqs_k=self.rope_k_repeat,
        )
        # Written back in place so the excluded trailing keys stay untouched.
        k[:, :, :num_k_rope] = k_rotated

        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        try:
            with sdp_kernel_context(dropout_p):
                attended = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        except Exception as e:
            # Fall back to all kernels if the Flash attention kernel fails
            warnings.warn(
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
                category=UserWarning,
                stacklevel=2,
            )
            ALLOW_ALL_KERNELS = True
            attended = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        return self.out_proj(self._recombine_heads(attended))
|
prima/models/discriminator.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Discriminator(nn.Module):
    """Pose + Shape discriminator proposed in HMR.

    Four heads are evaluated and their scores concatenated:

    * one score per joint, from a shared 1x1-conv feature extractor over the
      flattened 3x3 rotation of each joint;
    * one score for the shape coefficients (betas);
    * one score for all joints together;
    * optionally one score for the bone vector.
    """

    def __init__(self, num_joints: int = 34, num_betas: int = 41, num_bones: int = 24):
        """
        Args:
            num_joints: Number of joint rotations per sample (default 34).
            num_betas: Length of the shape-coefficient vector (default 41).
            num_bones: Length of the optional bone vector (default 24).
        """
        super().__init__()

        self.num_joints = num_joints

        # Layers are created and initialised in a fixed order so that seeded
        # construction yields the same weights as before.

        # poses_alone
        self.D_conv1 = self._xavier(nn.Conv2d(9, 32, kernel_size=1))
        self.relu = nn.ReLU(inplace=True)
        self.D_conv2 = self._xavier(nn.Conv2d(32, 32, kernel_size=1))
        self.pose_out = nn.ModuleList(
            self._xavier(nn.Linear(32, 1)) for _ in range(self.num_joints)
        )

        # betas
        self.betas_fc1 = self._xavier(nn.Linear(num_betas, 10))
        self.betas_fc2 = self._xavier(nn.Linear(10, 5))
        self.betas_out = self._xavier(nn.Linear(5, 1))

        # bones
        self.bone_fc1 = self._xavier(nn.Linear(num_bones, 10))
        self.bone_fc2 = self._xavier(nn.Linear(10, 5))
        self.bone_out = self._xavier(nn.Linear(5, 1))

        # poses_joint
        self.D_alljoints_fc1 = self._xavier(nn.Linear(32 * self.num_joints, 1024))
        self.D_alljoints_fc2 = self._xavier(nn.Linear(1024, 1024))
        self.D_alljoints_out = self._xavier(nn.Linear(1024, 1))

    @staticmethod
    def _xavier(layer: nn.Module) -> nn.Module:
        """Xavier-uniform the weight, zero the bias, and return the layer."""
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, poses: torch.Tensor, betas: torch.Tensor, bone=None) -> torch.Tensor:
        """
        Forward pass of the discriminator.

        Args:
            poses (torch.Tensor): Batch of joint rotation matrices; anything that
                reshapes to (B, num_joints, 1, 9), e.g. (B, num_joints, 3, 3).
            betas (torch.Tensor): Tensor of shape (B, num_betas) containing a batch
                of shape coefficients.
            bone (torch.Tensor, optional): Tensor of shape (B, num_bones). When
                omitted the bone head is skipped.

        Returns:
            torch.Tensor: Discriminator output of shape (B, num_joints + 2), or
            (B, num_joints + 3) when ``bone`` is given. Columns are ordered as
            per-joint scores, betas score, all-joints score, bone score.
        """
        # poses B x num_joints x 1 x 9
        poses = poses.reshape(-1, self.num_joints, 1, 9)
        bn = poses.shape[0]
        # poses B x 9 x num_joints x 1
        poses = poses.permute(0, 3, 1, 2).contiguous()

        # poses_alone: shared per-joint features, B x 32 x num_joints x 1
        poses = self.relu(self.D_conv1(poses))
        poses = self.relu(self.D_conv2(poses))

        # Each joint has its own linear scoring layer.
        poses_out = torch.cat(
            [head(poses[:, :, i, 0]) for i, head in enumerate(self.pose_out)], dim=1
        )

        # betas
        betas = self.relu(self.betas_fc1(betas))
        betas = self.relu(self.betas_fc2(betas))
        betas_out = self.betas_out(betas)

        # poses_joint: all joint features flattened into one vector per sample
        poses_all = self.relu(self.D_alljoints_fc1(poses.reshape(bn, -1)))
        poses_all = self.relu(self.D_alljoints_fc2(poses_all))
        poses_all_out = self.D_alljoints_out(poses_all)

        scores = [poses_out, betas_out, poses_all_out]

        # bone (optional head)
        if bone is not None:
            bone = self.relu(self.bone_fc1(bone))
            bone = self.relu(self.bone_fc2(bone))
            scores.append(self.bone_out(bone))

        return torch.cat(scores, 1)
|
prima/models/heads/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .smal_head import build_smal_head
|
prima/models/heads/classifier_head.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ClassTokenHead(nn.Module):
    """Bias-free MLP head with batch normalisation between layers.

    Every layer but the last is ``Linear -> BatchNorm1d -> ReLU``. The last
    layer is a ``Linear`` optionally followed by a non-affine ``BatchNorm1d``.

    Args:
        embed_dim: Input width of the first linear layer.
        hidden_dim: Width of the intermediate layers.
        output_dim: Output width of the last linear layer.
        num_layers: Number of linear layers.
        last_bn: Whether to append the non-affine batch norm after the last
            linear layer.
    """

    def __init__(self, embed_dim=1280, hidden_dim=4096, output_dim=256, num_layers=3, last_bn=True):
        super().__init__()
        modules = []
        final_idx = num_layers - 1
        for idx in range(num_layers):
            is_last = idx == final_idx
            in_features = hidden_dim if idx else embed_dim
            out_features = output_dim if is_last else hidden_dim

            modules.append(nn.Linear(in_features, out_features, bias=False))
            if not is_last:
                modules.append(nn.BatchNorm1d(out_features))
                modules.append(nn.ReLU(inplace=True))
            elif last_bn:
                modules.append(nn.BatchNorm1d(out_features, affine=False))
        self.head = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the head to ``x`` and return the resulting features."""
        return self.head(x)
|
prima/models/heads/smal_head.py
ADDED
|
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import numpy as np
|
| 14 |
+
import einops
|
| 15 |
+
import pickle as pkl
|
| 16 |
+
from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
|
| 17 |
+
from ..components.pose_transformer import TransformerDecoder
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def build_smal_head(cfg):
    """Instantiate the SMAL head selected by ``cfg.MODEL.SMAL_HEAD``.

    The head type is read from the ``'TYPE'`` entry, falling back to
    ``'amr'`` when the entry is missing.  The only type handled here is
    ``'new_bio_pose_transformer_decoder'``; every other value (including
    the ``'amr'`` fallback) raises ``ValueError``.

    Args:
        cfg: config object exposing ``cfg.MODEL.SMAL_HEAD.get(key, default)``.

    Returns:
        A ``NewBioGuidedSMALPoseDecoder`` built from ``cfg``.

    Raises:
        ValueError: if the configured head type is not recognised.
    """
    head_kind = cfg.MODEL.SMAL_HEAD.get('TYPE', 'amr')

    # Guard clause: reject anything that is not the single supported head.
    if head_kind != 'new_bio_pose_transformer_decoder':
        raise ValueError('Unknown SMAL head type: {}'.format(head_kind))

    return NewBioGuidedSMALPoseDecoder(cfg)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class NewBioGuidedSMALPoseDecoder(nn.Module):
|
| 35 |
+
"""
|
| 36 |
+
Bio-Guided SMAL Decoder with Pose Token Aggregation
|
| 37 |
+
|
| 38 |
+
Final version:
|
| 39 |
+
- Query tokens = [param token] + [2D keypoint tokens (optional)] + [3D keypoint tokens (optional)]
|
| 40 |
+
- SAM3D-body-style layer-wise keypoint token updates:
|
| 41 |
+
* 2D: predict (x,y) in [-0.5,0.5] from kp2d tokens -> token_augment position encoding
|
| 42 |
+
+ grid_sample image features at predicted locations -> add into kp2d token embeddings
|
| 43 |
+
+ invalid_mask (out-of-bounds and optional vis mask) zeroes updates
|
| 44 |
+
* 3D: predict (x,y,z) from kp3d tokens -> pelvis-normalize -> token_augment position encoding
|
| 45 |
+
- token_augment is injected by feeding (token_embeddings + token_augment) into each decoder layer.
|
| 46 |
+
- Only param token (index 0) is used to regress pose/betas/cam deltas.
|
| 47 |
+
- Outputs:
|
| 48 |
+
pred_smal_params: dict with global_orient/pose/betas and optional keypoints_2d/3d
|
| 49 |
+
pred_cam: [B,3]
|
| 50 |
+
extra_outputs: includes bio-guided shape_feat/init_betas and pred_smal_params_list
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(self, cfg):
    """Create all sub-modules of the decoder from ``cfg``.

    Options are read from ``cfg.MODEL.SMAL_HEAD`` through ``.get`` with the
    defaults written inline below, plus ``cfg.SMAL.NUM_JOINTS`` and the
    optional ``cfg.SMAL`` entries ``NUM_KEYPOINTS`` / ``PELVIS_IDX``.

    Sub-modules are assigned in a fixed order and under fixed attribute
    names, so ``state_dict`` keys and the order in which random
    initialisation is drawn do not depend on anything but the config.

    Raises:
        KeyError: if ``JOINT_REP`` is neither ``"6d"`` nor ``"aa"``.
    """
    super().__init__()
    self.cfg = cfg
    head_cfg = cfg.MODEL.SMAL_HEAD
    smal_cfg = cfg.SMAL

    # Size of the shape (betas) vector used throughout this head.
    num_betas = 41

    def two_layer_mlp(in_dim, hidden_dim, out_dim):
        # Linear -> ReLU -> Linear; indices 0/1/2 inside the Sequential.
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    # ---- joint representation -------------------------------------------
    self.joint_rep_type = head_cfg.get("JOINT_REP", "6d")
    self.joint_rep_dim = {"6d": 6, "aa": 3}[self.joint_rep_type]
    # +1 accounts for the extra leading rotation that forward() exposes
    # as "global_orient".
    self.npose = self.joint_rep_dim * (smal_cfg.NUM_JOINTS + 1)

    # ---- transformer sizes ------------------------------------------------
    self.decoder_dim = head_cfg.get("DECODER_DIM", 1024)
    context_dim = head_cfg.get("IN_CHANNELS", 1024)
    num_layers = head_cfg.get("NUM_DECODER_LAYERS", 4)
    num_heads = head_cfg.get("NUM_HEADS", 8)
    mlp_ratio = head_cfg.get("MLP_RATIO", 4.0)
    width = self.decoder_dim

    # ---- keypoint-token switches -----------------------------------------
    self.use_keypoint_2d_tokens = head_cfg.get("USE_KEYPOINT_2D_TOKENS", False)
    self.use_keypoint_3d_tokens = head_cfg.get("USE_KEYPOINT_3D_TOKENS", False)
    self.num_keypoints = smal_cfg.get("NUM_KEYPOINTS", 26)
    self.keypoint_token_update = head_cfg.get("KEYPOINT_TOKEN_UPDATE", False)
    # Whether _kp2d_update adds sampled image features to the 2D tokens.
    self.kp2d_inject_image_feat = head_cfg.get("KP2D_INJECT_IMAGE_FEAT", True)

    # Number of iterative refinement passes run by forward().
    self.ief_iters = head_cfg.get("IEF_ITERS", 3)

    # Keypoint indices averaged to obtain the pelvis centre in _kp3d_update.
    self.pelvis_idx = smal_cfg.get("PELVIS_IDX", [0, 1])

    # Flag toggled by the freeze_* / unfreeze_all helpers.
    self._tta_mode = False

    # ---- coarse bio prior: bio token -> initial betas -> shape feature ----
    self.bio_to_betas_init = nn.Sequential(
        nn.Linear(context_dim, 256),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(256, num_betas),
    )
    self.shape_projector = nn.Sequential(
        nn.Linear(num_betas, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 128),
    )

    # ---- learnable starting point for pose and camera --------------------
    self.init_pose = nn.Parameter(torch.zeros(1, self.npose))
    self.init_cam = nn.Parameter(torch.tensor([[0.9, 0, 0]], dtype=torch.float32))

    # ---- [pose | betas | cam] -> parameter token -------------------------
    param_dim = self.npose + num_betas + 3
    self.param_to_token = nn.Sequential(
        nn.Linear(param_dim, width),
        nn.LayerNorm(width),
        nn.ReLU(),
    )

    # ---- learned keypoint tokens and their coordinate encoders -----------
    if self.use_keypoint_2d_tokens:
        self.keypoint_2d_embeddings = nn.Embedding(self.num_keypoints, width)
        nn.init.normal_(self.keypoint_2d_embeddings.weight, std=0.02)
        # (x, y) -> additive token augment
        self.keypoint_2d_pos_encoder = two_layer_mlp(2, 256, width)
        # sampled image feature -> token dimension
        self.keypoint_2d_feat_linear = nn.Linear(width, width)

    if self.use_keypoint_3d_tokens:
        self.keypoint_3d_embeddings = nn.Embedding(self.num_keypoints, width)
        nn.init.normal_(self.keypoint_3d_embeddings.weight, std=0.02)
        # (x, y, z) -> additive token augment
        self.keypoint_3d_pos_encoder = two_layer_mlp(3, 256, width)

    # ---- per-token coordinate heads used by the layer-wise updates -------
    if self.keypoint_token_update:
        if self.use_keypoint_2d_tokens:
            self.kp2d_from_tokens = two_layer_mlp(width, width, 2)
        if self.use_keypoint_3d_tokens:
            self.kp3d_from_tokens = two_layer_mlp(width, width, 3)

    # ---- image feature projection and positional encoding ----------------
    if context_dim == width:
        self.image_proj = nn.Identity()
    else:
        self.image_proj = nn.Linear(context_dim, width)
    self.image_pos_encoding = PositionalEncoding2D(width)

    # ---- decoder stack -----------------------------------------------------
    self.layers = nn.ModuleList(
        PoseTransformerDecoderLayer(
            d_model=width,
            nhead=num_heads,
            dim_feedforward=int(width * mlp_ratio),
            dropout=0.1,
        )
        for _ in range(num_layers)
    )
    self.norm = nn.LayerNorm(width)

    # ---- regression heads fed by the parameter token only ----------------
    self.decpose = two_layer_mlp(width, width, self.npose)
    self.decshape = two_layer_mlp(width, width, num_betas)
    self.deccam = two_layer_mlp(width, width // 2, 3)
|
| 185 |
+
|
| 186 |
+
# --------------------------
|
| 187 |
+
# helpers: query token build
|
| 188 |
+
# --------------------------
|
| 189 |
+
def _build_query_tokens(self, pred_pose, pred_betas, pred_cam):
|
| 190 |
+
B = pred_pose.shape[0]
|
| 191 |
+
tokens = []
|
| 192 |
+
|
| 193 |
+
params = torch.cat([pred_pose, pred_betas, pred_cam], dim=1) # [B, param_dim]
|
| 194 |
+
param_token = self.param_to_token(params).unsqueeze(1) # [B,1,D]
|
| 195 |
+
tokens.append(param_token)
|
| 196 |
+
|
| 197 |
+
kp2d_start = None
|
| 198 |
+
kp3d_start = None
|
| 199 |
+
|
| 200 |
+
if self.use_keypoint_2d_tokens:
|
| 201 |
+
kp2d_start = sum(t.shape[1] for t in tokens)
|
| 202 |
+
kp2d_tokens = self.keypoint_2d_embeddings.weight.unsqueeze(0).expand(B, -1, -1).contiguous()
|
| 203 |
+
tokens.append(kp2d_tokens)
|
| 204 |
+
|
| 205 |
+
if self.use_keypoint_3d_tokens:
|
| 206 |
+
kp3d_start = sum(t.shape[1] for t in tokens)
|
| 207 |
+
kp3d_tokens = self.keypoint_3d_embeddings.weight.unsqueeze(0).expand(B, -1, -1).contiguous()
|
| 208 |
+
tokens.append(kp3d_tokens)
|
| 209 |
+
|
| 210 |
+
token_embeddings = torch.cat(tokens, dim=1) # [B,Nq,D]
|
| 211 |
+
token_augment = torch.zeros_like(token_embeddings)
|
| 212 |
+
|
| 213 |
+
return token_embeddings, token_augment, kp2d_start, kp3d_start
|
| 214 |
+
|
| 215 |
+
# --------------------------
|
| 216 |
+
# helpers: updates
|
| 217 |
+
# --------------------------
|
| 218 |
+
def _kp2d_update(self, token_embeddings, token_augment, image_features, kp2d_start, H, W, vis_mask=None):
|
| 219 |
+
"""
|
| 220 |
+
SAM3D-body-style 2D keypoint token update.
|
| 221 |
+
|
| 222 |
+
image_features: [B, HW, D] projected + pos-encoded, with HW=H*W (expected 12*16)
|
| 223 |
+
vis_mask: optional [B,N] bool (True=valid)
|
| 224 |
+
"""
|
| 225 |
+
if not (self.keypoint_token_update and self.use_keypoint_2d_tokens):
|
| 226 |
+
return token_embeddings, token_augment, None
|
| 227 |
+
|
| 228 |
+
B = token_embeddings.shape[0]
|
| 229 |
+
N = self.num_keypoints
|
| 230 |
+
|
| 231 |
+
kp_tokens = token_embeddings[:, kp2d_start : kp2d_start + N, :] # [B,N,D]
|
| 232 |
+
|
| 233 |
+
# predict coords in [-0.5,0.5]
|
| 234 |
+
pred_xy = self.kp2d_from_tokens(kp_tokens)
|
| 235 |
+
pred_xy = torch.tanh(pred_xy) * 0.5
|
| 236 |
+
|
| 237 |
+
# invalid mask (out of bounds + optional vis)
|
| 238 |
+
pred_xy_01 = pred_xy + 0.5
|
| 239 |
+
invalid = (
|
| 240 |
+
(pred_xy_01[..., 0] < 0.0)
|
| 241 |
+
| (pred_xy_01[..., 0] > 1.0)
|
| 242 |
+
| (pred_xy_01[..., 1] < 0.0)
|
| 243 |
+
| (pred_xy_01[..., 1] > 1.0)
|
| 244 |
+
)
|
| 245 |
+
if vis_mask is not None:
|
| 246 |
+
invalid = invalid | (~vis_mask)
|
| 247 |
+
valid = (~invalid).unsqueeze(-1).float() # [B,N,1]
|
| 248 |
+
|
| 249 |
+
# update token_augment slice
|
| 250 |
+
token_augment = token_augment.clone()
|
| 251 |
+
token_augment[:, kp2d_start : kp2d_start + N, :] = self.keypoint_2d_pos_encoder(pred_xy) * valid
|
| 252 |
+
|
| 253 |
+
# inject sampled image feature into kp2d tokens (optional)
|
| 254 |
+
if self.kp2d_inject_image_feat:
|
| 255 |
+
img = image_features.view(B, H, W, self.decoder_dim).permute(0, 3, 1, 2).contiguous() # [B,D,H,W]
|
| 256 |
+
grid = (pred_xy * 2.0).unsqueeze(2) # [B,N,1,2] in [-1,1]
|
| 257 |
+
|
| 258 |
+
sampled = (
|
| 259 |
+
F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
|
| 260 |
+
.squeeze(3)
|
| 261 |
+
.permute(0, 2, 1)
|
| 262 |
+
.contiguous()
|
| 263 |
+
) # [B,N,D]
|
| 264 |
+
|
| 265 |
+
sampled = sampled * valid
|
| 266 |
+
token_embeddings = token_embeddings.clone()
|
| 267 |
+
token_embeddings[:, kp2d_start : kp2d_start + N, :] += self.keypoint_2d_feat_linear(sampled)
|
| 268 |
+
|
| 269 |
+
return token_embeddings, token_augment, pred_xy
|
| 270 |
+
|
| 271 |
+
def _kp3d_update(self, token_embeddings, token_augment, kp3d_start):
|
| 272 |
+
if not (self.keypoint_token_update and self.use_keypoint_3d_tokens):
|
| 273 |
+
return token_embeddings, token_augment, None
|
| 274 |
+
|
| 275 |
+
N = self.num_keypoints
|
| 276 |
+
kp_tokens = token_embeddings[:, kp3d_start : kp3d_start + N, :] # [B,N,D]
|
| 277 |
+
|
| 278 |
+
pred_xyz = self.kp3d_from_tokens(kp_tokens) # [B,N,3]
|
| 279 |
+
|
| 280 |
+
# pelvis normalize
|
| 281 |
+
pelvis_center = pred_xyz[:, self.pelvis_idx, :].mean(dim=1, keepdim=True) # [B,1,3]
|
| 282 |
+
pred_xyz_norm = pred_xyz - pelvis_center
|
| 283 |
+
|
| 284 |
+
token_augment = token_augment.clone()
|
| 285 |
+
token_augment[:, kp3d_start : kp3d_start + N, :] = self.keypoint_3d_pos_encoder(pred_xyz_norm)
|
| 286 |
+
|
| 287 |
+
return token_embeddings, token_augment, pred_xyz
|
| 288 |
+
|
| 289 |
+
# --------------------------
|
| 290 |
+
# forward
|
| 291 |
+
# --------------------------
|
| 292 |
+
def forward(self, x, keypoint_coords_2d=None, keypoint_coords_3d=None, **kwargs):
    """Iteratively regress SMAL pose, shape and camera from backbone tokens.

    Args:
        x: token sequence whose LAST token is read as the bio (BioCLIP)
            token and whose remaining tokens are read as image features.
            A 4-D ``(B, C, H, W)`` tensor is first flattened to
            ``(B, H*W, C)``.  After the bio token is split off, exactly
            ``12 * 16`` image tokens must remain.
        keypoint_coords_2d: optional tensor; if its last dimension is 3,
            channel 2 (``> 0``) becomes the validity mask handed to the 2D
            token update.  The coordinates themselves are not read.
        keypoint_coords_3d: accepted but not used in this method.
        **kwargs: accepted and ignored.

    Returns:
        pred_smal_params: dict with ``global_orient`` (first rotation),
            ``pose`` (remaining rotations) and ``betas`` from the final
            iteration, plus ``keypoints_2d`` / ``keypoints_3d`` (the last
            layer-wise prediction) when token updates produced any.
        pred_cam: camera parameters after the final iteration.
        extra_outputs: dict with ``shape_feat``, ``init_betas`` and
            ``pred_smal_params_list`` (per-iteration pose/betas/cam and all
            intermediate keypoint predictions, concatenated along dim 0).
    """
    batch = x.shape[0]

    # ---- split the input into image tokens and the trailing bio token ----
    if x.dim() == 4:
        x = einops.rearrange(x, 'b c h w -> b (h w) c')
    bio_token = x[:, -1, :]        # [B, C]
    image_features = x[:, :-1, :]  # [B, HW, C]

    # ---- coarse shape from the bio token ---------------------------------
    init_betas = self.bio_to_betas_init(bio_token)  # [B,41]
    shape_feat = F.normalize(self.shape_projector(init_betas), dim=1)

    # ---- image features: project, then add 2D positional encoding --------
    image_features = self.image_proj(image_features)  # [B,HW,D]

    # Fixed feature-map size.  NOTE(review): the original comment attributes
    # this to a 256x192 ViT crop with patch size 16; confirm that the
    # (rows, cols) order matches the backbone's token layout.
    H, W = 12, 16
    assert image_features.shape[1] == H * W, f"Expected HW={H*W}, got {image_features.shape[1]}"

    grid_encoding = self.image_pos_encoding(H, W).to(image_features.device)  # [HW,D]
    image_features = image_features + grid_encoding.unsqueeze(0)

    # ---- starting estimates ----------------------------------------------
    pred_pose = self.init_pose.expand(batch, -1)
    pred_betas = init_betas
    pred_cam = self.init_cam.expand(batch, -1)

    pose_history, betas_history, cam_history = [], [], []
    kp2d_history, kp3d_history = [], []

    # Optional per-keypoint validity taken from the third input channel.
    vis_mask = None
    if keypoint_coords_2d is not None and keypoint_coords_2d.shape[-1] == 3:
        vis_mask = keypoint_coords_2d[..., 2] > 0  # [B,N]

    final_layer = len(self.layers) - 1

    # ---- iterative error feedback ----------------------------------------
    for _ in range(self.ief_iters):
        token_embeddings, token_augment, kp2d_start, kp3d_start = self._build_query_tokens(
            pred_pose, pred_betas, pred_cam
        )

        for depth, layer in enumerate(self.layers):
            # The augment is added on the way IN to every layer.
            token_embeddings = layer(token_embeddings + token_augment, image_features)

            # Keypoint tokens are refreshed after every layer but the last.
            if not self.keypoint_token_update or depth >= final_layer:
                continue

            if self.use_keypoint_2d_tokens:
                token_embeddings, token_augment, xy = self._kp2d_update(
                    token_embeddings, token_augment, image_features, kp2d_start, H, W, vis_mask=vis_mask
                )
                if xy is not None:
                    kp2d_history.append(xy)

            if self.use_keypoint_3d_tokens:
                token_embeddings, token_augment, xyz = self._kp3d_update(
                    token_embeddings, token_augment, kp3d_start
                )
                if xyz is not None:
                    kp3d_history.append(xyz)

        # Only the parameter token (index 0) feeds the regression heads.
        token_embeddings = self.norm(token_embeddings)
        param_token = token_embeddings[:, 0, :]

        pred_pose = pred_pose + self.decpose(param_token)
        pred_betas = pred_betas + self.decshape(param_token)
        pred_cam = pred_cam + self.deccam(param_token)

        pose_history.append(pred_pose)
        betas_history.append(pred_betas)
        cam_history.append(pred_cam)

    # ---- joint representation -> rotation matrices -----------------------
    to_rotmat = {
        "6d": rot6d_to_rotmat,
        "aa": lambda y: aa_to_rotmat(y.view(-1, 3).contiguous()),
    }[self.joint_rep_type]

    # Per-iteration outputs; the first rotation of each pose is dropped here.
    per_iter_pose = [to_rotmat(p).view(batch, -1, 3, 3)[:, 1:, :, :] for p in pose_history]
    pred_smal_params_list = {
        "pose": torch.cat(per_iter_pose, dim=0),
        "betas": torch.cat(betas_history, dim=0),
        "cam": torch.cat(cam_history, dim=0),
        "keypoints_2d": torch.cat(kp2d_history, dim=0) if kp2d_history else None,
        "keypoints_3d": torch.cat(kp3d_history, dim=0) if kp3d_history else None,
    }

    final_rotmats = to_rotmat(pred_pose).view(batch, self.cfg.SMAL.NUM_JOINTS + 1, 3, 3)
    pred_smal_params = {
        "global_orient": final_rotmats[:, [0]],
        "pose": final_rotmats[:, 1:],
        "betas": pred_betas,
    }

    # Expose the most recent layer-wise keypoint predictions.
    if self.keypoint_token_update:
        if self.use_keypoint_2d_tokens and kp2d_history:
            pred_smal_params["keypoints_2d"] = kp2d_history[-1]
        if self.use_keypoint_3d_tokens and kp3d_history:
            pred_smal_params["keypoints_3d"] = kp3d_history[-1]

    extra_outputs = {
        "shape_feat": shape_feat,
        "init_betas": init_betas,
        "pred_smal_params_list": pred_smal_params_list,
    }
    return pred_smal_params, pred_cam, extra_outputs
|
| 421 |
+
|
| 422 |
+
# --------------------------
|
| 423 |
+
# Test-time optimization helpers
|
| 424 |
+
# --------------------------
|
| 425 |
+
def freeze_all_except_keypoint_tokens(self):
    """Leave only the keypoint-token machinery trainable.

    Every parameter of this module is frozen first.  Then, for each
    enabled keypoint family, the following are made trainable again:

    * 2D: ``keypoint_2d_embeddings``, ``keypoint_2d_pos_encoder``,
      ``keypoint_2d_feat_linear`` and, if ``keypoint_token_update`` is on,
      ``kp2d_from_tokens``.
    * 3D: ``keypoint_3d_embeddings``, ``keypoint_3d_pos_encoder`` and, if
      ``keypoint_token_update`` is on, ``kp3d_from_tokens``.

    Also sets ``_tta_mode`` to ``True`` and prints a status line.
    """
    self.requires_grad_(False)

    trainable = []
    if self.use_keypoint_2d_tokens:
        trainable += [
            self.keypoint_2d_embeddings,
            self.keypoint_2d_pos_encoder,
            self.keypoint_2d_feat_linear,
        ]
        if self.keypoint_token_update:
            trainable.append(self.kp2d_from_tokens)

    if self.use_keypoint_3d_tokens:
        trainable += [self.keypoint_3d_embeddings, self.keypoint_3d_pos_encoder]
        if self.keypoint_token_update:
            trainable.append(self.kp3d_from_tokens)

    for module in trainable:
        module.requires_grad_(True)

    self._tta_mode = True
    print("[TTA] Frozen all parameters except keypoint tokens")
|
| 457 |
+
|
| 458 |
+
def freeze_backbone_only(self):
    """Make every parameter of this head trainable for test-time adaptation.

    Despite the name, nothing is frozen here: this method only touches the
    parameters of this module, marking all of them as requiring gradients.
    Per the printed message, the backbone is expected to be frozen
    separately by the caller.

    Also sets ``_tta_mode`` to ``True`` and prints a status line.
    """
    self.requires_grad_(True)
    self._tta_mode = True
    print("[TTA] SMAL head fully trainable (backbone frozen separately)")
|
| 469 |
+
|
| 470 |
+
def freeze_except_regression_heads(self):
    """Leave only the regression heads and keypoint embeddings trainable.

    Every parameter is frozen first.  The three regression MLPs
    (``decpose``, ``decshape``, ``deccam``) are then made trainable, along
    with the weight of each enabled keypoint embedding table.  The
    keypoint position encoders, the transformer layers and
    ``param_to_token`` stay frozen.

    Also sets ``_tta_mode`` to ``True`` and prints a status line.
    """
    self.requires_grad_(False)

    # Small output MLPs.
    for head in (self.decpose, self.decshape, self.deccam):
        head.requires_grad_(True)

    # Learned keypoint tokens only -- not their position encoders.
    if self.use_keypoint_2d_tokens:
        self.keypoint_2d_embeddings.weight.requires_grad_(True)
    if self.use_keypoint_3d_tokens:
        self.keypoint_3d_embeddings.weight.requires_grad_(True)

    self._tta_mode = True
    print("[TTA] Frozen all except regression heads and keypoint embeddings")
|
| 499 |
+
|
| 500 |
+
def unfreeze_all(self):
    """Mark every parameter trainable again and leave TTA mode.

    Sets ``requires_grad`` on all parameters of this module, resets
    ``_tta_mode`` to ``False`` and prints a status line.
    """
    self.requires_grad_(True)
    self._tta_mode = False
    print("[TTA] Unfrozen all parameters")
|
| 506 |
+
|
| 507 |
+
def get_tta_parameters(self, mode='keypoints_only'):
    """Collect the parameters to hand to a test-time-adaptation optimizer.

    What each mode returns:

    * ``'keypoints_only'``: the weight of every enabled keypoint embedding
      table (2D first, then 3D).  Position encoders, feature linears and
      per-token heads are NOT included, even though
      ``freeze_all_except_keypoint_tokens`` leaves them trainable.
    * ``'regression_heads'``: the above plus all parameters of
      ``decpose``, ``decshape`` and ``deccam``.  This is exactly the set
      left trainable by ``freeze_except_regression_heads``.
    * ``'all'``: currently the same set as ``'regression_heads'``; the
      transformer layers and ``param_to_token`` are not included.

    Args:
        mode: one of ``'keypoints_only'``, ``'regression_heads'``, ``'all'``.

    Returns:
        A list of parameters.  It is empty for ``'keypoints_only'`` when
        no keypoint tokens are enabled.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.  An
            unrecognised mode previously produced a silent empty list.
    """
    supported_modes = ('keypoints_only', 'regression_heads', 'all')
    if mode not in supported_modes:
        raise ValueError(
            f"Unknown TTA mode {mode!r}; expected one of {supported_modes}"
        )

    params = []

    # Learned keypoint tokens are part of every mode.
    if self.use_keypoint_2d_tokens:
        params.append(self.keypoint_2d_embeddings.weight)
    if self.use_keypoint_3d_tokens:
        params.append(self.keypoint_3d_embeddings.weight)

    # Output MLPs only; the transformer and param_to_token are left out.
    if mode in ('regression_heads', 'all'):
        for head in (self.decpose, self.decshape, self.deccam):
            params.extend(head.parameters())

    return params
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
class PoseTransformerDecoderLayer(nn.Module):
    """
    One post-norm decoder layer over the query tokens.

    Three residual sub-blocks are applied in sequence, each followed by
    its own LayerNorm: self-attention among the tokens, cross-attention
    from the tokens to the image features, and a GELU feed-forward
    network.
    """

    def __init__(self, d_model=1024, nhead=8, dim_feedforward=4096, dropout=0.1):
        super().__init__()
        attn_kwargs = dict(dropout=dropout, batch_first=True)

        # Token <-> token attention.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, **attn_kwargs)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Token -> image attention.
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, **attn_kwargs)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

        # Position-wise MLP; its trailing Dropout plays the role of the
        # residual dropout for this sub-block.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout),
        )
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, tokens, image_features):
        """
        Args:
            tokens: [B, N_tokens, C] query tokens (parameter + keypoint tokens)
            image_features: [B, N_pixels, C] keys/values for cross-attention

        Returns:
            [B, N_tokens, C] updated tokens
        """
        # 1) tokens attend to each other
        mixed = self.self_attn(tokens, tokens, tokens)[0]
        hidden = self.norm1(tokens + self.dropout1(mixed))

        # 2) tokens read from the image features
        visual = self.cross_attn(query=hidden, key=image_features, value=image_features)[0]
        hidden = self.norm2(hidden + self.dropout2(visual))

        # 3) feed-forward refinement
        return self.norm3(hidden + self.ffn(hidden))
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
class PositionalEncoding2D(nn.Module):
    """
    Fixed (non-learned) 2D sinusoidal positional encoding for a feature map.

    Half of the channels encode the row coordinate and half the column
    coordinate.  Within each half, sine is taken of the even-indexed
    frequency channels and cosine of the odd-indexed ones, and the results
    are interleaved.  ``embed_dim`` must therefore be a multiple of 4.
    """

    def __init__(self, embed_dim=1024, temperature=10000):
        super().__init__()
        # Each coordinate gets embed_dim // 2 frequency channels, which are
        # then split evenly into sin / cos.  Any other size used to fail
        # later with an opaque torch.stack shape error inside forward().
        if embed_dim <= 0 or embed_dim % 4 != 0:
            raise ValueError(
                f"embed_dim must be a positive multiple of 4, got {embed_dim}"
            )
        self.embed_dim = embed_dim
        self.temperature = temperature

    @staticmethod
    def _interleave_sin_cos(scaled):
        """Map [H, W, K] scaled coordinates to [H, W, K] interleaved sin/cos."""
        return torch.stack(
            [scaled[:, :, 0::2].sin(), scaled[:, :, 1::2].cos()], dim=3
        ).flatten(2)

    def forward(self, H, W, device=None):
        """
        Args:
            H, W: height and width of the feature map
            device: optional device on which to build the encoding
                (defaults to torch's default device, as before)

        Returns:
            pos_encoding: float32 tensor [H*W, embed_dim], row-major over
                (H, W); the first half of the channels encodes the row,
                the second half the column.
        """
        # Grid coordinates, normalised to [0, 1).
        y_embed = torch.arange(H, dtype=torch.float32, device=device).unsqueeze(1).repeat(1, W)
        x_embed = torch.arange(W, dtype=torch.float32, device=device).unsqueeze(0).repeat(H, 1)
        y_embed = y_embed / H
        x_embed = x_embed / W

        # Per-channel frequency divisors.
        dim_t = torch.arange(self.embed_dim // 2, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * dim_t / self.embed_dim)

        pos_x = self._interleave_sin_cos(x_embed[:, :, None] / dim_t)
        pos_y = self._interleave_sin_cos(y_embed[:, :, None] / dim_t)

        return torch.cat([pos_y, pos_x], dim=2).flatten(0, 1)  # [H*W, embed_dim]
|
| 646 |
+
|
| 647 |
+
|
prima/models/losses.py
ADDED
|
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pickle
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from ..utils.geometry import aa_to_rotmat
|
| 16 |
+
from typing import Dict
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def matrix_to_axis_angle(rot_mats: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices (..., 3, 3) to axis-angle vectors (..., 3).

    This local implementation avoids a hard runtime dependency on PyTorch3D.

    Three regimes are handled:

    * generic angles: ``aa = theta / (2 * sin(theta)) * vee`` where ``vee`` is
      the vector built from the antisymmetric part ``R - R^T``;
    * angles below 1e-4 rad: first-order approximation ``aa ~= 0.5 * vee``;
    * angles close to pi: ``vee`` (which equals ``2 * sin(theta) * axis``)
      vanishes, so the axis is recovered from the symmetric part of ``R``
      instead, using ``(R + R^T) / 2 = cos(theta) * I + (1 - cos(theta)) * n n^T``.
      At exactly pi the sign of the axis is arbitrary (both signs describe the
      same rotation).

    Args:
        rot_mats: Tensor whose last two dimensions are (3, 3).

    Returns:
        Tensor of shape ``rot_mats.shape[:-2] + (3,)`` holding axis * angle.

    Raises:
        ValueError: If the last two dimensions of ``rot_mats`` are not (3, 3).
    """
    if rot_mats.shape[-2:] != (3, 3):
        raise ValueError(f"Expected (..., 3, 3) rotation matrices, got {rot_mats.shape}")

    trace = rot_mats[..., 0, 0] + rot_mats[..., 1, 1] + rot_mats[..., 2, 2]
    cos_theta = (trace - 1.0) * 0.5
    # Clamp so that numerical noise cannot push acos outside its domain.
    cos_theta = torch.clamp(cos_theta, -1.0, 1.0)
    theta = torch.acos(cos_theta)

    # vee(R - R^T) = 2 * sin(theta) * axis
    vee = torch.stack(
        [
            rot_mats[..., 2, 1] - rot_mats[..., 1, 2],
            rot_mats[..., 0, 2] - rot_mats[..., 2, 0],
            rot_mats[..., 1, 0] - rot_mats[..., 0, 1],
        ],
        dim=-1,
    )
    sin_theta = torch.sin(theta)
    eps = 1e-6
    scale = theta / torch.clamp(2.0 * sin_theta, min=eps)
    aa = vee * scale.unsqueeze(-1)

    # For very small angles, first-order approximation: aa ~= 0.5 * vee
    small = theta < 1e-4
    if small.any():
        aa = torch.where(small.unsqueeze(-1), 0.5 * vee, aa)

    # Near pi the formula above degenerates (vee -> 0, sin -> 0) and would
    # return a zero vector; rebuild the axis from the symmetric part of R.
    near_pi = (cos_theta < 0.0) & (sin_theta < 1e-2)
    if near_pi.any():
        cos_col = cos_theta.unsqueeze(-1)
        # In this branch 1 - cos > 1; the clamp only keeps the entries that
        # are NOT selected below free of divisions by zero.
        denom = torch.clamp(1.0 - cos_col, min=1.0)
        diag = torch.stack(
            [rot_mats[..., 0, 0], rot_mats[..., 1, 1], rot_mats[..., 2, 2]],
            dim=-1,
        )
        # n_i^2 = (R_ii - cos) / (1 - cos)
        axis_sq = torch.clamp((diag - cos_col) / denom, min=0.0)
        # Use the dominant axis component as pivot for numerical stability.
        pivot = axis_sq.argmax(dim=-1, keepdim=True)
        onehot = torch.zeros_like(axis_sq).scatter(-1, pivot, 1.0)
        sym = 0.5 * (rot_mats + rot_mats.transpose(-1, -2))
        # Column `pivot` of the symmetric part, shape (..., 3).
        sym_col = (sym * onehot.unsqueeze(-2)).sum(dim=-1)
        # Column `pivot` of n n^T, i.e. n * n_pivot.
        outer_col = (sym_col - cos_col * onehot) / denom
        pivot_sq = (outer_col * onehot).sum(dim=-1, keepdim=True)
        axis = outer_col / torch.sqrt(torch.clamp(pivot_sq, min=eps))
        axis = axis * torch.rsqrt(torch.clamp((axis * axis).sum(dim=-1, keepdim=True), min=eps))
        # Align the sign with vee whenever vee still carries information.
        flip = ((axis * vee).sum(dim=-1, keepdim=True) < 0).to(axis.dtype)
        axis = axis * (1.0 - 2.0 * flip)
        aa = torch.where(near_pi.unsqueeze(-1), axis * theta.unsqueeze(-1), aa)
    return aa
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class DepthLoss(nn.Module):
    """
    Depth loss between predicted SMAL vertices and GT SMAL vertices.
    Compares vertex Z (camera space) after applying camera translation.
    Only computes loss for samples that have valid GT SMAL parameters.
    """
    def __init__(self, loss_type: str = 'l1'):
        """
        Args:
            loss_type (str): 'l1' selects the L1 criterion; any other value
                selects the squared-error criterion.
        """
        super().__init__()
        self.loss_type = loss_type
        # reduction='none' keeps per-vertex values so samples can be masked.
        self.l1 = nn.L1Loss(reduction='none')
        self.l2 = nn.MSELoss(reduction='none')

    def forward(self,
                pred_vertices: torch.Tensor,      # (B, V, 3)
                pred_cam_t: torch.Tensor,         # (B, 3)
                gt_smal_params: Dict[str, torch.Tensor],
                smal_model,                       # SMAL instance callable
                is_axis_angle: Dict[str, torch.Tensor],
                gt_cam_t: torch.Tensor = None,     # (B, 3) or None -> fallback to pred_cam_t
                has_smal_params: Dict[str, torch.Tensor] = None  # Added masking support
                ) -> torch.Tensor:
        """
        Compute the mean per-vertex depth error over samples with valid GT.

        Args:
            pred_vertices: Predicted vertices, (B, V, 3).
            pred_cam_t: Predicted camera translation, (B, 3).
            gt_smal_params: Dict with keys 'global_orient', 'pose', 'betas'.
            smal_model: Callable invoked as
                ``smal_model(global_orient=..., pose=..., betas=..., pose2rot=False)``
                whose result exposes ``.vertices``.
            is_axis_angle: Per-key flags; when all entries for a key are true,
                that GT parameter is converted with ``aa_to_rotmat``, otherwise
                it is reshaped to (B, -1, 3, 3) rotation matrices.
            gt_cam_t: GT camera translation; ``pred_cam_t`` is used when None.
            has_smal_params: Optional dict with keys 'pose', 'betas',
                'global_orient'; a sample is valid only if all three flags
                are non-zero. When None every sample is treated as valid.

        Returns:
            Scalar tensor: mean loss over valid samples, or 0 if there is none.
        """
        batch_size = pred_vertices.shape[0]
        device = pred_vertices.device

        # Determine which samples have valid GT SMAL params
        # A sample is valid only if it has GT for pose, betas, and global_orient
        if has_smal_params is not None:
            valid_mask = (has_smal_params['pose'] *
                          has_smal_params['betas'] *
                          has_smal_params['global_orient']).bool()
            # Flatten / move so the mask can index the (B,) per-sample loss.
            valid_mask = valid_mask.to(device=device).reshape(-1)

            # If no samples have valid GT, return zero loss
            if not valid_mask.any():
                return torch.tensor(0., device=device, dtype=pred_vertices.dtype)
        else:
            # If not provided, assume all samples are valid
            valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device)

        # prepare GT params for SMAL
        gt_params_for_smal = {}
        for k in ['global_orient', 'pose', 'betas']:
            val = gt_smal_params[k].to(device=device)
            if k == 'betas':
                gt_params_for_smal[k] = val.view(batch_size, -1)
            else:
                gt_val = val.view(batch_size, -1)
                if is_axis_angle[k].all():
                    gt_val = aa_to_rotmat(gt_val.reshape(-1, 3)).view(batch_size, -1, 3, 3)
                else:
                    gt_val = gt_val.view(batch_size, -1, 3, 3)
                gt_params_for_smal[k] = gt_val

        # generate GT vertices (no grad)
        with torch.no_grad():
            gt_out = smal_model(**gt_params_for_smal, pose2rot=False)
            gt_vertices = gt_out.vertices.view(batch_size, -1, 3)

        if gt_cam_t is None:
            gt_cam_t = pred_cam_t

        # depth = z in camera coordinates
        pred_depth = (pred_vertices + pred_cam_t.unsqueeze(1))[..., 2]  # (B, V)
        gt_depth = (gt_vertices + gt_cam_t.unsqueeze(1))[..., 2]        # (B, V)

        # Compute loss per sample
        criterion = self.l1 if self.loss_type == 'l1' else self.l2
        loss_per_sample = criterion(pred_depth, gt_depth).mean(dim=1)  # (B,)

        # Select valid samples by indexing rather than multiplying by the
        # mask: NaN/Inf produced by placeholder GT of an invalid sample would
        # otherwise survive the multiplication (nan * 0 == nan).
        valid_losses = loss_per_sample[valid_mask]
        if valid_losses.numel() == 0:
            return torch.tensor(0., device=device, dtype=pred_vertices.dtype)

        # Return mean over valid samples
        return valid_losses.mean()
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class MaskLoss(nn.Module):
    """
    Mask loss between rendered predicted mesh mask and rendered GT mesh mask.
    This loss relies on a MeshRenderer-like object that provides `render_mask(vertices, camera_translation, focal_length)`
    returning a single-channel numpy mask (H, W) with values 0/1.

    The vertices are detached and rendered through numpy, so the returned
    value carries no gradient back to ``pred_vertices``.
    """
    def __init__(self, mesh_renderer=None):
        """
        Args:
            mesh_renderer: Object exposing ``render_mask``; when None the
                loss is always zero.
        """
        super().__init__()
        self.mesh_renderer = mesh_renderer
        self.l1 = nn.L1Loss(reduction='mean')

    def forward(self,
                pred_vertices: torch.Tensor,      # (B, V, 3)
                pred_cam_t: torch.Tensor,         # (B, 3)
                gt_smal_params: Dict[str, torch.Tensor],
                smal_model,                       # SMAL instance callable
                is_axis_angle: Dict[str, torch.Tensor],
                gt_cam_t: torch.Tensor = None,     # optional (B,3)
                focal_length: float = 1000.0
                ) -> torch.Tensor:
        """
        Render predicted and GT meshes to masks and return their mean L1 gap.

        Args:
            pred_vertices: Predicted vertices, (B, V, 3).
            pred_cam_t: Predicted camera translation, (B, 3); zeros are used
                when it is None.
            gt_smal_params: Dict with keys 'global_orient', 'pose', 'betas'.
            smal_model: Callable invoked as
                ``smal_model(global_orient=..., pose=..., betas=..., pose2rot=False)``
                whose result exposes ``.vertices``.
            is_axis_angle: Per-key flags; when all entries for a key are true,
                that GT parameter is converted with ``aa_to_rotmat``.
            gt_cam_t: Camera translation used to render the GT mask; the
                predicted translation is used when None.
            focal_length: Focal length forwarded to the renderer.

        Returns:
            Scalar tensor: mean over the samples that rendered successfully,
            or 0 when no renderer is set or every render failed.
        """
        batch_size = pred_vertices.shape[0]
        device = pred_vertices.device

        # if no renderer available, return zero loss
        if self.mesh_renderer is None:
            return torch.tensor(0., device=device, dtype=pred_vertices.dtype)

        # prepare GT params for SMAL
        gt_params_for_smal = {}
        for k in ['global_orient', 'pose', 'betas']:
            val = gt_smal_params[k].to(device=device)
            if k == 'betas':
                gt_params_for_smal[k] = val.view(batch_size, -1)
            else:
                gt_val = val.view(batch_size, -1)
                if is_axis_angle[k].all():
                    gt_val = aa_to_rotmat(gt_val.reshape(-1, 3)).view(batch_size, -1, 3, 3)
                else:
                    gt_val = gt_val.view(batch_size, -1, 3, 3)
                gt_params_for_smal[k] = gt_val

        # generate GT vertices (no grad)
        with torch.no_grad():
            gt_out = smal_model(**gt_params_for_smal, pose2rot=False)
            gt_vertices = gt_out.vertices

        # convert to numpy for renderer
        pred_vertices_np = pred_vertices.detach().cpu().numpy()
        gt_vertices_np = gt_vertices.detach().cpu().numpy()
        if pred_cam_t is not None:
            pred_cam_np = pred_cam_t.detach().cpu().numpy()
        else:
            pred_cam_np = np.zeros((batch_size, 3), dtype=np.float32)
        # The GT mask is rendered with the GT translation when one is given
        # (previously `gt_cam_t` was accepted but silently ignored).
        if gt_cam_t is None:
            gt_cam_np = pred_cam_np
        else:
            gt_cam_np = gt_cam_t.detach().cpu().numpy()

        per_item_losses = []
        for i in range(batch_size):
            try:
                pred_mask = self.mesh_renderer.render_mask(pred_vertices_np[i], pred_cam_np[i], focal_length)
                gt_mask_r = self.mesh_renderer.render_mask(gt_vertices_np[i], gt_cam_np[i], focal_length)
                pm = torch.from_numpy(pred_mask).to(device=device, dtype=pred_vertices.dtype)
                gm = torch.from_numpy(gt_mask_r).to(device=device, dtype=pred_vertices.dtype)
                per_item_losses.append(self.l1(pm, gm))
            except Exception:
                # Deliberate best effort: a render failure only drops this
                # sample from the average instead of aborting the step.
                continue

        if len(per_item_losses) == 0:
            return torch.tensor(0., device=device, dtype=pred_vertices.dtype)

        return torch.stack(per_item_losses).mean()
|
| 204 |
+
|
| 205 |
+
class Keypoint2DLoss(nn.Module):
    """Confidence-weighted 2D keypoint reprojection loss."""

    def __init__(self, loss_type: str = 'l1'):
        """
        Build the element-wise criterion.

        Args:
            loss_type (str): 'l1' or 'l2'.

        Raises:
            NotImplementedError: For any other ``loss_type``.
        """
        super().__init__()
        if loss_type == 'l1':
            criterion = nn.L1Loss(reduction='none')
        elif loss_type == 'l2':
            criterion = nn.MSELoss(reduction='none')
        else:
            raise NotImplementedError('Unsupported loss function')
        self.loss_fn = criterion

    def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor:
        """
        Sum the confidence-weighted error between projected and GT keypoints.

        The indexing below treats dimension 2 as the coordinate/confidence
        axis, i.e. it expects three-dimensional inputs.

        Args:
            pred_keypoints_2d (torch.Tensor): Projected 2D keypoints, [B, N, 2].
            gt_keypoints_2d (torch.Tensor): GT 2D keypoints with the
                confidence stored in the last channel, [B, N, 3].

        Returns:
            torch.Tensor: Scalar loss summed over batch, keypoints and channels.
        """
        # Last channel is the per-keypoint confidence; keep a trailing axis
        # so it broadcasts over the coordinate channels.
        confidence = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
        target_xy = gt_keypoints_2d[:, :, :-1]
        weighted_error = confidence * self.loss_fn(pred_keypoints_2d, target_xy)
        per_sample = weighted_error.sum(dim=(1, 2))
        return per_sample.sum()
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class Keypoint3DLoss(nn.Module):
    """Confidence-weighted, root-relative 3D keypoint loss."""

    def __init__(self, loss_type: str = 'l1'):
        """
        Build the element-wise criterion.

        Args:
            loss_type (str): 'l1' or 'l2'.

        Raises:
            NotImplementedError: For any other ``loss_type``.
        """
        super().__init__()
        if loss_type == 'l1':
            criterion = nn.L1Loss(reduction='none')
        elif loss_type == 'l2':
            criterion = nn.MSELoss(reduction='none')
        else:
            raise NotImplementedError('Unsupported loss function')
        self.loss_fn = criterion

    def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 0):
        """
        Sum the confidence-weighted error after subtracting the root joint.

        The indexing below treats dimension 1 as the keypoint axis and
        dimension 2 as the coordinate/confidence axis.

        Args:
            pred_keypoints_3d (torch.Tensor): Predicted 3D keypoints, [B, N, 3].
            gt_keypoints_3d (torch.Tensor): GT 3D keypoints with the
                confidence stored in the last channel, [B, N, 4].
            pelvis_id (int): Index of the keypoint used as the origin for
                both prediction and ground truth.

        Returns:
            torch.Tensor: Scalar loss summed over batch, keypoints and channels.
        """
        # Work on a copy so the caller's GT tensor is not modified in place.
        target = gt_keypoints_3d.clone()

        pred_root = pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1)
        pred_relative = pred_keypoints_3d - pred_root

        # Only the coordinate channels are re-centred; confidence is untouched.
        target[:, :, :-1] = target[:, :, :-1] - target[:, pelvis_id, :-1].unsqueeze(dim=1)

        confidence = target[:, :, -1].unsqueeze(-1).clone()
        target_xyz = target[:, :, :-1]

        weighted_error = confidence * self.loss_fn(pred_relative, target_xyz)
        per_sample = weighted_error.sum(dim=(1, 2))
        return per_sample.sum()
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class ParameterLoss(nn.Module):
    """Squared-error loss on SMAL parameters, gated per sample."""

    def __init__(self):
        """
        SMAL parameter loss module (element-wise MSE, no reduction).
        """
        super().__init__()
        self.loss_fn = nn.MSELoss(reduction='none')

    def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor):
        """
        Compute the summed squared error over samples that have GT.

        Args:
            pred_param (torch.Tensor): Predicted parameters (body pose /
                global orientation / betas), batch on dimension 0.
            gt_param (torch.Tensor): Ground-truth SMAL parameters.
            has_param (torch.Tensor): One flag per sample; it is reshaped to
                [B, 1, ..., 1] and multiplied into the element-wise loss.

        Returns:
            torch.Tensor: Scalar, sum of the gated squared errors.
        """
        # Shape [B, 1, ..., 1] so the per-sample flag broadcasts over the
        # remaining parameter dimensions.
        broadcast_shape = [pred_param.shape[0]] + [1] * (len(pred_param.shape) - 1)
        sample_weight = has_param.type(pred_param.type()).view(*broadcast_shape)

        # Multiplying by ones leaves the values unchanged; it is kept so the
        # inputs reach the criterion exactly as before.
        ones = torch.ones_like(pred_param, device=pred_param.device, dtype=pred_param.dtype)
        squared_error = self.loss_fn(pred_param * ones, gt_param * ones)

        return (sample_weight * squared_error).sum()
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class PosePriorLoss(nn.Module):
    """Pose prior penalty applied to samples that have no GT pose."""

    def __init__(self, path_prior):
        """
        Load the prior statistics and register them as buffers.

        Args:
            path_prior: Path to a pickle file providing the keys
                "mean_pose" and "pic".

        NOTE: unpickling can execute arbitrary code; only point this at
        trusted prior files.
        """
        super().__init__()
        with open(path_prior, "rb") as handle:
            prior = pickle.load(handle, encoding="latin1")

        mean_pose = torch.from_numpy(prior["mean_pose"]).float()
        precision = torch.from_numpy(np.array(prior["pic"])).float()
        self.register_buffer("mean_pose", mean_pose)
        self.register_buffer("precs", precision)

        # 105 = 35 joints * 3 axis-angle components; the first three entries
        # (global rotation) are excluded from the penalty.
        keep = np.ones(105, dtype=bool)
        keep[:3] = False  # global rotation set False
        self.register_buffer("use_index", torch.from_numpy(keep).float())

    def forward(self, x, has_gt):
        """
        Args:
            x: rotation matrices, (batch_size, 35, 3, 3)
            has_gt: per-sample flag, non-zero when the sample has a GT pose
        Returns:
            pose prior loss (zero when every sample has GT)
        """
        # Nothing to regularise when all samples are supervised by GT.
        if has_gt.sum() == len(has_gt):
            return torch.tensor(0.0, dtype=x.dtype, device=x.device)

        without_gt = ~has_gt.type(torch.bool)
        axis_angle = matrix_to_axis_angle(x[without_gt].reshape(-1, 3, 3))
        deviation = axis_angle.reshape(-1, 35*3) - self.mean_pose
        projected = torch.tensordot(deviation, self.precs, dims=([1], [0]))
        masked = projected * self.use_index
        return (masked ** 2).mean()
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class ShapePriorLoss(nn.Module):
    """Per-category shape (betas) prior penalty for samples without GT shape."""

    def __init__(self, path_prior):
        """
        Load per-cluster statistics and register them as buffers.

        Args:
            path_prior: Path to a pickle file providing the keys
                "cluster_cov" and "cluster_means".

        NOTE: unpickling can execute arbitrary code; only point this at
        trusted prior files.
        """
        super(ShapePriorLoss, self).__init__()
        with open(path_prior, "rb") as f:
            data_prior = pickle.load(f, encoding="latin1")

        model_covs = np.array(data_prior["cluster_cov"])  # shape: (5, 41, 41)
        # Small diagonal term regularises the covariance before inversion.
        inverse_covs = np.stack(
            [np.linalg.inv(model_cov + 1e-5 * np.eye(model_cov.shape[0])) for model_cov in model_covs],
            axis=0)
        # Cholesky factor of each inverse covariance.
        prec = np.stack([np.linalg.cholesky(inverse_cov) for inverse_cov in inverse_covs], axis=0)

        self.register_buffer("betas_prec", torch.from_numpy(prec).float())
        self.register_buffer(
            "mean_betas",
            torch.from_numpy(np.asarray(data_prior["cluster_means"])).float(),
        )

    def forward(self, x, category, has_gt):
        """
        Args:
            x: predicted betas (batch_size, 41)
            category: animal category index (batch_size,)
            has_gt: per-sample flag, non-zero when the sample has GT shape
        Returns:
            shape prior loss (zero when every sample has GT)
        """
        if has_gt.sum() == len(has_gt):
            return torch.tensor(0.0, dtype=x.dtype, device=x.device)
        has_gt = has_gt.type(torch.bool)
        x, category = x[~has_gt], category[~has_gt]
        if x.shape[0] == 0:
            # Non-binary flags can pass the check above yet select nothing.
            return torch.tensor(0.0, dtype=x.dtype, device=x.device)

        # One integer index drives both per-category lookups.
        index = category.long()
        delta = x - self.mean_betas[index]  # [num_selected, 41]
        prec = self.betas_prec[index]       # [num_selected, 41, 41]
        # Batched delta^T @ prec, replacing the per-sample tensordot loop.
        loss = torch.einsum('bi,bij->bj', delta, prec)
        return (loss ** 2).mean()
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class PrototypeSupConLoss(nn.Module):
    """Cross-entropy between samples and five class prototypes."""

    def __init__(self, prototypes_init, feat_dim=128, temperature=0.1):
        """
        prototypes_init: precomputed (5, 512) BioCLIP family prototypes.
            NOTE(review): this argument is accepted but not used here; the
            buffer below is always initialised randomly.
        feat_dim: dimension of the projected shape feature (128)
        temperature: divisor applied to the cosine similarities
        """
        super().__init__()
        self.temperature = temperature

        # The prototypes should live in the projected feature space.
        # A practical setup is to pass the BioCLIP centers through the
        # projector once at the beginning of training and write the result
        # into this buffer.
        random_prototypes = torch.randn(5, feat_dim)
        self.register_buffer("prototypes", random_prototypes)

    def forward(self, features, labels):
        """
        features: (B, feat_dim) features; they are L2-normalised here
        labels: (B,) family indices for the 5-way classification setting
        Returns the mean cross-entropy over the batch.
        """
        # Cosine similarity: both sides are projected onto the unit sphere.
        unit_features = F.normalize(features, p=2, dim=1)
        unit_prototypes = F.normalize(self.prototypes, p=2, dim=1)

        # Temperature-scaled sample-to-prototype similarities.
        logits = torch.matmul(unit_features, unit_prototypes.T) / self.temperature

        # Cross-entropy pulls each sample toward its own family prototype
        # and pushes it away from the other ones.
        return F.cross_entropy(logits, labels)

    @torch.no_grad()
    def update_prototypes(self, features, labels, momentum=0.999):
        """
        Optional: exponential-moving-average update of each prototype with
        the mean feature of the samples carrying that label. Prototypes whose
        label is absent from the batch are left untouched.
        """
        for family in range(5):
            members = (labels == family)
            if not members.any():
                continue
            batch_mean = features[members].mean(dim=0)
            blended = momentum * self.prototypes[family] + (1 - momentum) * batch_mean
            self.prototypes[family] = blended
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class SupConLoss(nn.Module):
    """Supervised contrastive loss over a batch of feature vectors.

    Each input feature is duplicated into two identical views before the
    contrastive terms are computed.
    """

    def __init__(self, temperature=0.1, contrast_mode='all',
                 base_temperature=0.07):
        """
        Args:
            temperature: divisor applied to the pairwise dot products.
            contrast_mode: 'all' uses every view as anchor, 'one' only the
                first view.
            base_temperature: the loss is scaled by
                ``temperature / base_temperature``.
        """
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """
        Args:
            features: hidden vector of shape [bsz, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        Raises:
            ValueError: if both `labels` and `mask` are given, if the number
                of labels differs from the batch size, or if the contrast
                mode is unknown.
        """
        # Two identical views per sample: [bsz, 2, ...].
        features = torch.stack((features, features), dim=1)
        device = features.device

        n_dims = len(features.shape)
        if n_dims < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if n_dims > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]

        # Build the [bsz, bsz] positive-pair mask.
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        if labels is None and mask is None:
            # No supervision: a sample is only positive with itself.
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        # Stack the views along the batch axis: [bsz * n_views, dim].
        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature, anchor_count = features[:, 0], 1
        elif self.contrast_mode == 'all':
            anchor_feature, anchor_count = contrast_feature, contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # Temperature-scaled similarities between anchors and all views.
        similarity = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # Subtract the row maximum for numerical stability.
        row_max, _ = torch.max(similarity, dim=1, keepdim=True)
        logits = similarity - row_max.detach()

        # Tile the mask to the anchor/contrast layout.
        mask = mask.repeat(anchor_count, contrast_count)
        # Zero on the diagonal: an anchor is never contrasted with itself.
        self_index = torch.arange(batch_size * anchor_count).view(-1, 1).to(device)
        not_self = torch.scatter(torch.ones_like(mask), 1, self_index, 0)
        mask = mask * not_self

        # Log-probability of each contrast sample given the anchor.
        exp_logits = torch.exp(logits) * not_self
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood over positives; anchors without positives use
        # a denominator of 1 to avoid dividing by zero.
        positives_per_anchor = mask.sum(1)
        positives_per_anchor = torch.where(positives_per_anchor < 1e-6, 1, positives_per_anchor)
        mean_log_prob_pos = (mask * log_prob).sum(1) / positives_per_anchor

        scaled = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        return scaled.view(anchor_count, batch_size).mean()
|
| 496 |
+
|
| 497 |
+
# Auxiliary intermediate-supervision loss module.
|
| 498 |
+
|
| 499 |
+
class InterLoss(nn.Module):
    """Weighted keypoint losses on intermediate (per-iteration) predictions."""

    def __init__(self, cfg):
        """
        Args:
            cfg: config object whose ``LOSS`` node supports ``.get``; reads
                'USE_INTERMEDIATE_SUPERVISION' (default True) and
                'INTERMEDIATE_WEIGHT' (default 0.5).
        """
        super().__init__()
        self.cfg = cfg
        loss_cfg = cfg.LOSS
        self.use_intermediate_supervision = loss_cfg.get('USE_INTERMEDIATE_SUPERVISION', True)
        self.intermediate_weight = loss_cfg.get('INTERMEDIATE_WEIGHT', 0.5)

        # Element-wise criteria; reduction is done manually with weights.
        self.keypoint_2d_loss = nn.MSELoss(reduction='none')
        self.keypoint_3d_loss = nn.MSELoss(reduction='none')

    @staticmethod
    def _weighted_keypoint_mse(criterion, pred_all, target, weight):
        """Weighted mean of per-keypoint MSE, with GT tiled along dim 0.

        ``pred_all`` holds ``num_iters`` times as many rows as ``target``;
        target and weight are repeated as whole batches, which presumably
        matches predictions stacked iteration by iteration - TODO confirm
        against the producer of ``pred_all``.
        """
        num_iters = pred_all.shape[0] // target.shape[0]
        target_tiled = target.repeat(num_iters, 1, 1)   # [B*num_iters, N, C]
        weight_tiled = weight.repeat(num_iters, 1)      # [B*num_iters, N]

        per_keypoint = criterion(pred_all, target_tiled).mean(dim=-1)  # [B*num_iters, N]
        # The epsilon keeps the division finite when all weights are zero.
        return (per_keypoint * weight_tiled).sum() / (weight_tiled.sum() + 1e-6)

    def forward(self, predictions, gt_data):
        """
        Args:
            predictions: model outputs (pred_smal_params, pred_cam, pred_smal_params_list);
                only the last element is used here and it may be None.
            gt_data: dict containing ground-truth data
                - 'keypoints_2d': [B, N, 3] (x, y, visibility)
                - 'keypoints_3d': [B, N, 3] (x, y, z) or [B, N, 4] (x, y, z, confidence)
        Returns:
            dict with the individual weighted terms that were computed and
            their sum under 'total' (the float 0.0 when nothing was computed).
        """
        _, _, intermediate = predictions

        losses = {}
        total = 0.0

        # Supervision of the final predictions is not implemented here; it
        # would require running the predicted parameters through SMAL.

        if self.use_intermediate_supervision and intermediate is not None:

            # 2D keypoint supervision (visible keypoints only)
            if 'keypoints_2d' in intermediate and intermediate['keypoints_2d'] is not None:
                gt_2d = gt_data['keypoints_2d']
                term_2d = self._weighted_keypoint_mse(
                    self.keypoint_2d_loss,
                    intermediate['keypoints_2d'],   # [B*num_iters, N, 2]
                    gt_2d[:, :, :2],                # [B, N, 2]
                    gt_2d[:, :, 2],                 # [B, N]
                )
                losses['intermediate_keypoints_2d'] = term_2d * self.intermediate_weight
                total = total + losses['intermediate_keypoints_2d']

            # 3D keypoint supervision
            if 'keypoints_3d' in intermediate and intermediate['keypoints_3d'] is not None:
                gt_3d = gt_data['keypoints_3d']
                target_3d = gt_3d[:, :, :3]         # [B, N, 3]
                if gt_3d.shape[-1] == 4:
                    conf_3d = gt_3d[:, :, 3]        # [B, N]
                else:
                    # No confidence channel: every keypoint counts equally.
                    conf_3d = torch.ones_like(target_3d[:, :, 0])
                term_3d = self._weighted_keypoint_mse(
                    self.keypoint_3d_loss,
                    intermediate['keypoints_3d'],   # [B*num_iters, N, 3]
                    target_3d,
                    conf_3d,
                )
                losses['intermediate_keypoints_3d'] = term_3d * self.intermediate_weight
                total = total + losses['intermediate_keypoints_3d']

        # ... other losses (pose, shape, etc.) ...

        losses['total'] = total
        return losses
|
prima/models/prima.py
ADDED
|
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import pickle
|
| 12 |
+
import pytorch_lightning as pl
|
| 13 |
+
from typing import Any, Dict
|
| 14 |
+
from yacs.config import CfgNode
|
| 15 |
+
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
from torchvision.utils import make_grid
|
| 21 |
+
from ..utils.geometry import perspective_projection, aa_to_rotmat
|
| 22 |
+
from ..utils.pylogger import get_pylogger
|
| 23 |
+
from .backbones import create_backbone
|
| 24 |
+
from .heads import build_smal_head
|
| 25 |
+
from prima.models.smal_wrapper import SMAL
|
| 26 |
+
from .discriminator import Discriminator
|
| 27 |
+
|
| 28 |
+
from .bioclip_embedding import BioClipEmbedding
|
| 29 |
+
import sys
|
| 30 |
+
from transformers import AutoModel, AutoFeatureExtractor
|
| 31 |
+
import einops
|
| 32 |
+
|
| 33 |
+
import open_clip
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss, ShapePriorLoss, PosePriorLoss, SupConLoss
|
| 37 |
+
log = get_pylogger(__name__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class PRIMA(pl.LightningModule):
|
| 41 |
+
|
| 42 |
+
def __init__(self, cfg: CfgNode, init_renderer: bool = True):
    """
    Setup PRIMA model
    Args:
        cfg (CfgNode): Config file as a yacs CfgNode
        init_renderer (bool): If True, build a MeshRenderer for visualization;
            otherwise ``self.mesh_renderer`` is set to None.
    """
    super().__init__()

    # Save hyperparameters (the renderer flag is runtime-only, so it is excluded)
    self.save_hyperparameters(logger=False, ignore=['init_renderer'])

    self.cfg = cfg

    # Create backbone feature extractor. Only the 'vith' type is built here.
    if cfg.MODEL.BACKBONE.TYPE == 'vith':
        # create vit backbone anyway, for inference, no config loading, just load ckpt weights
        self.backbone = create_backbone(cfg)

        if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):  # set and not None/empty
            log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
            state_dict = torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS,
                                    map_location='cpu', weights_only=True)['state_dict']
            state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()}

            # strict=False never raises on a key mismatch, so report what was skipped
            # instead of silently training from partially-initialized weights.
            missing_keys, unexpected_keys = self.backbone.load_state_dict(state_dict, strict=False)
            if missing_keys:
                log.warning(f'Backbone checkpoint is missing {len(missing_keys)} keys: {missing_keys}')
            if unexpected_keys:
                log.warning(f'Backbone checkpoint has {len(unexpected_keys)} unexpected keys: {unexpected_keys}')

    # freeze backbones
    if cfg.MODEL.BACKBONE.get('FREEZE', False) and cfg.MODEL.BACKBONE.TYPE == 'vith':
        log.info('Freezing first 2/3 blocks of vit backbone')
        # Freeze patch embedding
        if hasattr(self.backbone, 'patch_embed'):
            for p in self.backbone.patch_embed.parameters():
                p.requires_grad = False

        # Freeze first 2/3 of transformer blocks
        if hasattr(self.backbone, 'blocks'):
            total_blocks = len(self.backbone.blocks)
            freeze_blocks = int(total_blocks * 2 / 3)
            log.info(f'Freezing {freeze_blocks} out of {total_blocks} blocks')
            for i in range(freeze_blocks):
                for p in self.backbone.blocks[i].parameters():
                    p.requires_grad = False

    # Create SMAL head (predicts SMAL params + perspective camera)
    self.smal_head = build_smal_head(cfg)

    # Instantiate SMAL model
    # NOTE(review): pickle.load executes arbitrary code from the file; cfg.SMAL.MODEL_PATH
    # must only ever point at a trusted model file.
    smal_model_path = cfg.SMAL.MODEL_PATH
    with open(smal_model_path, 'rb') as f:
        smal_cfg = pickle.load(f, encoding="latin1")
    self.smal = SMAL(**smal_cfg)

    # create bioclip model for species classification token extraction
    use_bioclip_embedding = cfg.MODEL.get('USE_BIOCLIP_EMBEDDING', False)
    if use_bioclip_embedding:
        bioclip_config = cfg.MODEL.get('BIOCLIP_EMBEDDING', {})
        embed_dim = bioclip_config.get('EMBED_DIM', 1280)
        self.bioclip_embedding = BioClipEmbedding(cfg, embed_dim=embed_dim)
        # Freeze BioClip model by default
        for param in self.bioclip_embedding.species_model.parameters():
            param.requires_grad = False
    else:
        self.bioclip_embedding = None

    # Create discriminator
    self.discriminator = Discriminator()

    # Define loss functions
    self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
    self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')

    # Intermediate keypoint losses only exist when their weight is positive.
    if self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP2D', 0) > 0:
        self.intermediate_kp2d_loss = Keypoint2DLoss(loss_type='l1')
    if self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP3D', 0) > 0:
        self.intermediate_kp3d_loss = Keypoint3DLoss(loss_type='l1')
    self.smal_parameter_loss = ParameterLoss()
    self.shape_prior_loss = ShapePriorLoss(path_prior=cfg.SMAL.SHAPE_PRIOR_PATH)
    self.pose_prior_loss = PosePriorLoss(path_prior=cfg.SMAL.POSE_PRIOR_PATH)
    self.supcon_loss = SupConLoss()

    self.register_buffer('initialized', torch.tensor(False))

    # Setup renderer for visualization
    if init_renderer:
        # Imported lazily so the renderer is only required when it is actually used.
        from ..utils import MeshRenderer

        self.mesh_renderer = MeshRenderer(self.cfg, faces=self.smal.faces.numpy())
    else:
        self.mesh_renderer = None

    # Disable automatic optimization since we use adversarial training
    self.automatic_optimization = False
|
| 139 |
+
|
| 140 |
+
def get_parameters(self):
    """
    Collect the parameters handled by the main (non-discriminator) optimizer.

    Returns:
        list: Parameters of the SMAL head, followed by the backbone (when one was
        built and its type is 'vith', 'dinov2' or 'dinov3'), the optional keypoint
        projection and the BioClip projection layer.
    """
    all_params = list(self.smal_head.parameters())

    # __init__ only builds self.backbone for some backbone types, so the attribute
    # may legitimately be absent; skip it instead of raising AttributeError.
    backbone = getattr(self, 'backbone', None)
    if backbone is not None and self.cfg.MODEL.BACKBONE.TYPE in ('vith', 'dinov2', 'dinov3'):
        all_params += list(backbone.parameters())

    keypoint_projection = getattr(self, 'keypoint_projection', None)
    if keypoint_projection is not None:
        all_params += list(keypoint_projection.parameters())

    bioclip_embedding = getattr(self, 'bioclip_embedding', None)
    if bioclip_embedding is not None:
        # Only add projection parameters as the model itself is frozen
        all_params += list(bioclip_embedding.projection.parameters())
    return all_params
|
| 152 |
+
|
| 153 |
+
def configure_optimizers(self):
    """
    Build the model optimizer and, when adversarial training is enabled,
    the discriminator optimizer.

    Returns:
        Tuple[torch.optim.Optimizer, ...]: ``(optimizer,)`` when the ADVERSARIAL loss
        weight is not positive, otherwise ``(optimizer, optimizer_disc)``.
    """
    base_lr = self.cfg.TRAIN.LR

    def _trainable(params):
        # Keep only the parameters that still receive gradients.
        return [p for p in params if p.requires_grad]

    if self.cfg.MODEL.BACKBONE.TYPE == 'vith':
        # The vith backbone trains at a tenth of the learning rate of the other modules.
        backbone_lr = base_lr / 10.0
        backbone_params = _trainable(self.backbone.parameters()) if hasattr(self, 'backbone') else []

        head_side = list(self.smal_head.parameters())
        if hasattr(self, 'keypoint_projection') and self.keypoint_projection is not None:
            head_side.extend(self.keypoint_projection.parameters())
        if hasattr(self, 'bioclip_embedding') and self.bioclip_embedding is not None:
            head_side.extend(self.bioclip_embedding.projection.parameters())
        other_params = _trainable(head_side)

        param_groups = [
            {'params': backbone_params, 'lr': backbone_lr},
            {'params': other_params, 'lr': base_lr},
        ]

        log.info('Using separate LR for vith backbone')
        log.info(f'Backbone parameters: {len(backbone_params)}, lr={backbone_lr}')
        log.info(f'Other parameters: {len(other_params)}, lr={base_lr}')
    else:
        # Every other backbone type shares a single learning rate.
        all_params = _trainable(self.get_parameters())
        param_groups = [{'params': all_params, 'lr': base_lr}]
        log.info(f'Using same LR for all parameters: {len(all_params)}, lr={base_lr}')

    optimizer = torch.optim.AdamW(params=param_groups,
                                  weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)

    if not self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0:
        return optimizer,

    optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
                                       lr=self.cfg.TRAIN.LR,
                                       weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
    return optimizer, optimizer_disc
|
| 207 |
+
|
| 208 |
+
def forward_step(self, batch: Dict, train: bool = False) -> Dict:
    """
    Run a forward step of the network
    Args:
        batch (Dict): Dictionary containing batch data; 'img' and 'focal_length' are read here
        train (bool): Flag indicating whether it is training or validation mode
            (not used inside this method)
    Returns:
        Dict: Dictionary containing the regression output
    Raises:
        NotImplementedError: If cfg.MODEL.BACKBONE.TYPE is not 'vith', the only
            backbone this method knows how to run.
    """

    # Use RGB image as input
    x = batch['img']  # [B, 3, H, W]
    batch_size = x.shape[0]

    # Fail with an explicit message instead of an unbound-variable error further down.
    backbone_type = self.cfg.MODEL.BACKBONE.TYPE
    if backbone_type != 'vith':
        raise NotImplementedError(
            f"Unsupported backbone type {backbone_type!r}: forward_step only implements 'vith'")

    # Compute conditioning features using the backbone
    # vit backbone return [1, 1280, 12, 16]; 32 columns are cropped from each side of the width
    conditioning_feats, _ = self.backbone(x[:, :, :, 32:-32])  # reshape the input into [256, 192]
    # return shape [B, D, Hp, Wp], [B, D]
    if conditioning_feats.ndim == 4:
        # Flatten spatial dimensions into sequence dimension: [B, D, Hp, Wp] -> [B, Hp*Wp, D]
        B, D, Hp, Wp = conditioning_feats.shape
        conditioning_feats = conditioning_feats.permute(0, 2, 3, 1).reshape(B, Hp * Wp, D)

    # add bioclip embedding if enabled
    if self.bioclip_embedding is not None:
        species_feature = self.bioclip_embedding(batch['img'])  # [B, embed_dim]

        # concatenate species feature to conditioning_feats along token dimension
        if len(conditioning_feats.shape) == 3:
            # Token-wise concatenation: add species_feature as a single token
            species_token = species_feature.unsqueeze(1)  # (B, 1, embed_dim)
            # Note: This requires C == embed_dim for consistent feature dimensions
            conditioning_feats = torch.cat([conditioning_feats, species_token], dim=1)  # (B, num_tokens + 1, C)
        else:
            # If conditioning_feats is 2D (B, C), concat directly along feature dimension
            conditioning_feats = torch.cat([conditioning_feats, species_feature], dim=-1)

    # Predict SMAL parameters and camera
    pred_smal_params, pred_cam, extra_outputs = self.smal_head(conditioning_feats)

    # Store useful regression outputs to the output dict
    output = {}

    if 'shape_feat' in extra_outputs:
        output['shape_feat'] = extra_outputs['shape_feat']

    if 'init_betas' in extra_outputs:
        output['init_betas'] = extra_outputs['init_betas'].reshape(batch_size, -1)

    output['pred_cam'] = pred_cam  # [B, 3]
    # Cloned so the reshapes applied to pred_smal_params below do not alias the stored copy.
    output['pred_smal_params'] = {k: v.clone() for k, v in pred_smal_params.items()}

    # Compute camera translation
    focal_length = batch['focal_length']

    pred_cam_t = torch.stack([
        pred_cam[:, 1],
        pred_cam[:, 2],
        # 1e-9 keeps the division finite when the predicted scale is exactly zero
        2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)
    ], dim=-1)  # [B, 3]

    output['pred_cam_t'] = pred_cam_t  # [B, 3]
    output['focal_length'] = focal_length  # [B, 2]

    # Compute model vertices, joints and the projected joints
    pred_smal_params['global_orient'] = pred_smal_params['global_orient'].reshape(batch_size, -1, 3, 3)
    pred_smal_params['pose'] = pred_smal_params['pose'].reshape(batch_size, -1, 3, 3)
    pred_smal_params['betas'] = pred_smal_params['betas'].reshape(batch_size, -1)
    smal_output = self.smal(**pred_smal_params, pose2rot=False)

    pred_keypoints_3d = smal_output.joints
    pred_vertices = smal_output.vertices
    output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
    output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)

    # project 3D keypoints to 2D
    pred_keypoints_2d = perspective_projection(
        pred_keypoints_3d,
        translation=pred_cam_t,
        focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE
    )  # [B, num_joints, 2]
    output['pred_keypoints_2d'] = pred_keypoints_2d

    # get intermediate keypoint predictions if available
    if 'keypoints_3d' in pred_smal_params and pred_smal_params['keypoints_3d'] is not None:
        inter_keypoints_3d = pred_smal_params['keypoints_3d']
        output['inter_keypoints_3d'] = inter_keypoints_3d.reshape(batch_size, -1, 3)

    if 'keypoints_2d' in pred_smal_params and pred_smal_params['keypoints_2d'] is not None:
        inter_keypoints_2d = pred_smal_params['keypoints_2d']
        output['inter_keypoints_2d'] = inter_keypoints_2d.reshape(batch_size, -1, 2)

    return output
|
| 311 |
+
|
| 312 |
+
def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
    """
    Compute losses given the input batch and the regression output
    Args:
        batch (Dict): Dictionary containing batch data
        output (Dict): Dictionary containing the regression output; a detached
            per-term breakdown is written to ``output['losses']``
        train (bool): Flag indicating whether it is training or validation mode
            (not used inside this method)
    Returns:
        torch.Tensor : Total loss for current batch
    """

    pred_smal_params = output['pred_smal_params']
    pred_keypoints_2d = output['pred_keypoints_2d']
    pred_keypoints_3d = output['pred_keypoints_3d']

    batch_size = pred_smal_params['pose'].shape[0]
    device = pred_smal_params['pose'].device
    dtype = pred_smal_params['pose'].dtype

    # Get annotations
    gt_keypoints_2d = batch['keypoints_2d']
    gt_keypoints_3d = batch['keypoints_3d']
    gt_smal_params = batch['smal_params']
    has_smal_params = batch['has_smal_params']
    is_axis_angle = batch['smal_params_is_axis_angle']

    # Compute 2D keypoint loss
    loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)

    # Compute 3D keypoint loss
    loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)

    # The intermediate loss modules are only created in __init__ when their weight is
    # positive, so the head may emit intermediate keypoints while the module is absent.
    # In that case the term stays at zero instead of raising AttributeError.
    intermediate_kp2d_fn = getattr(self, 'intermediate_kp2d_loss', None)
    intermediate_kp3d_fn = getattr(self, 'intermediate_kp3d_loss', None)

    # Compute intermediate 2D keypoint loss if available
    loss_intermediate_kp2d = torch.tensor(0., device=device, dtype=dtype)
    if 'inter_keypoints_2d' in output and intermediate_kp2d_fn is not None:
        loss_intermediate_kp2d = intermediate_kp2d_fn(output['inter_keypoints_2d'], gt_keypoints_2d)

    # Compute intermediate 3D keypoint loss if available
    loss_intermediate_kp3d = torch.tensor(0., device=device, dtype=dtype)
    if 'inter_keypoints_3d' in output and intermediate_kp3d_fn is not None:
        loss_intermediate_kp3d = intermediate_kp3d_fn(output['inter_keypoints_3d'], gt_keypoints_3d,
                                                      pelvis_id=0)

    # Compute loss on SMAL parameters
    loss_smal_params = {}
    for k, pred in pred_smal_params.items():
        # Skip keypoint predictions - they're handled separately
        if k in ['keypoints_2d', 'keypoints_3d']:
            continue

        gt = gt_smal_params[k].view(batch_size, -1)
        if is_axis_angle[k].all():
            # Ground truth is stored as axis-angle; convert to rotation matrices first
            gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
        has_gt = has_smal_params[k]

        param_loss = self.smal_parameter_loss(pred.reshape(batch_size, -1),
                                              gt.reshape(batch_size, -1),
                                              has_gt)

        if k == "betas":
            # Shape prior acts as a regularizer; has_gt is forwarded to the prior loss
            loss_smal_params[k] = param_loss + self.shape_prior_loss(pred, batch["category"], has_gt)
            if 'init_betas' in output:
                init_betas = output['init_betas']
                loss_smal_params[k] = loss_smal_params[k] + \
                    self.shape_prior_loss(init_betas, batch["category"], has_gt) / 2.
        else:
            # Pose prior is evaluated on the full pose (global orientation + joints) and
            # halved, since it is added for every non-betas parameter key
            loss_smal_params[k] = param_loss + \
                self.pose_prior_loss(torch.cat((pred_smal_params["global_orient"],
                                                pred_smal_params["pose"]),
                                               dim=1), has_gt) / 2.

    if 'shape_feat' in output:
        loss_supcon = self.supcon_loss(output['shape_feat'], labels=batch['category'])
    else:
        loss_supcon = torch.tensor(0., device=device, dtype=dtype)

    loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d + \
        self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d + \
        sum(loss_smal_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_smal_params) + \
        self.cfg.LOSS_WEIGHTS['SUPCON'] * loss_supcon

    if 'inter_keypoints_2d' in output:
        loss = loss + self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP2D', 0) * loss_intermediate_kp2d
    if 'inter_keypoints_3d' in output:
        loss = loss + self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP3D', 0) * loss_intermediate_kp3d

    losses = dict(loss=loss.detach(),
                  loss_keypoints_2d=loss_keypoints_2d.detach(),
                  loss_keypoints_3d=loss_keypoints_3d.detach(),
                  loss_supcon=loss_supcon.detach(),
                  )

    for k, v in loss_smal_params.items():
        losses['loss_' + k] = v.detach()

    # attach intermediate keypoint losses if the head produced intermediate keypoints
    if 'inter_keypoints_2d' in output:
        losses['loss_inter_keypoints_2d'] = loss_intermediate_kp2d.detach()
    if 'inter_keypoints_3d' in output:
        losses['loss_inter_keypoints_3d'] = loss_intermediate_kp3d.detach()

    output['losses'] = losses

    return loss
|
| 432 |
+
|
| 433 |
+
def forward(self, batch: Dict) -> Dict:
    """
    Inference entry point: delegates to ``forward_step`` with ``train=False``.
    Args:
        batch (Dict): Dictionary containing batch data
    Returns:
        Dict: Dictionary containing the regression output
    """
    output = self.forward_step(batch, train=False)
    return output
|
| 442 |
+
|
| 443 |
+
def training_step_discriminator(self, batch: Dict,
                                pose: torch.Tensor,
                                betas: torch.Tensor,
                                optimizer: torch.optim.Optimizer) -> torch.Tensor:
    """
    Perform one optimization step of the discriminator.

    Regressed samples are pushed toward a score of 0 and ground-truth samples
    toward a score of 1, both with a squared-error objective.
    Args:
        batch (Dict): Dictionary containing mocap batch data ('pose' and 'betas' are read)
        pose (torch.Tensor): Regressed pose from current step
        betas (torch.Tensor): Regressed betas from current step
        optimizer (torch.optim.Optimizer): Discriminator optimizer
    Returns:
        torch.Tensor: Discriminator loss (unweighted, detached)
    """
    num_samples = pose.shape[0]
    real_pose = batch['pose']
    real_betas = batch['betas']
    # Ground-truth pose is axis-angle; the discriminator is fed rotation matrices
    real_rotmat = aa_to_rotmat(real_pose.view(-1, 3)).view(num_samples, -1, 3, 3)

    # Inputs are detached so only the discriminator receives gradients here
    fake_scores = self.discriminator(pose.detach(), betas.detach())
    fake_term = ((fake_scores - 0.0) ** 2).sum() / num_samples

    real_scores = self.discriminator(real_rotmat.detach(), real_betas.detach())
    real_term = ((real_scores - 1.0) ** 2).sum() / num_samples

    loss_disc = fake_term + real_term
    weighted_loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc

    optimizer.zero_grad()
    self.manual_backward(weighted_loss)
    optimizer.step()
    return loss_disc.detach()
|
| 471 |
+
|
| 472 |
+
# TensorBoard logging should run from the first rank only
|
| 473 |
+
@pl.utilities.rank_zero.rank_zero_only
def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True,
                        write_to_summary_writer: bool = True) -> None:
    """
    Log results to Tensorboard
    Args:
        batch (Dict): Dictionary containing batch data
        output (Dict): Dictionary containing the regression output
        step_count (int): Global training step count
        train (bool): Flag indicating whether it is training or validation mode
        write_to_summary_writer (bool): When False, nothing is written to the logger;
            the rendered grid is only returned
    Returns:
        The image grid produced by ``make_grid`` (despite the ``None`` annotation),
        or None when the model was built without a mesh renderer.
    """

    mode = 'train' if train else 'val'

    images = batch['img']
    gt_keypoints_2d = batch['keypoints_2d']
    batch_size = images.shape[0]

    # Undo the input normalization: multiply by std, then add mean
    images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
    images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)

    pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3)
    losses = output['losses']
    pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3)
    pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2)

    summary_writer = None
    if write_to_summary_writer:
        summary_writer = self.logger.experiment
        for loss_name, val in losses.items():
            summary_writer.add_scalar(mode + '/' + loss_name, val.detach().item(), step_count)

    # The model can be constructed with init_renderer=False; scalars are still
    # logged above, but there is nothing to render images with.
    if self.mesh_renderer is None:
        return None

    num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)

    # Masks are optional entries of the output dict
    pred_masks = output.get('pred_masks', None)
    gt_masks = output.get('gt_masks', None)

    predictions = self.mesh_renderer.visualize_tensorboard(
        pred_vertices[:num_images].cpu().numpy(),
        pred_cam_t[:num_images].cpu().numpy(),
        images[:num_images].cpu().numpy(),
        self.cfg.SMAL.get("FOCAL_LENGTH", 1000),
        pred_keypoints_2d[:num_images].cpu().numpy(),
        gt_keypoints_2d[:num_images].cpu().numpy(),
        pred_masks=pred_masks[:num_images] if pred_masks is not None else None,
        gt_masks=gt_masks[:num_images] if gt_masks is not None else None,
    )
    predictions = make_grid(predictions, nrow=5, padding=2)
    if summary_writer is not None:
        summary_writer.add_image('%s/predictions' % mode, predictions, step_count)

    return predictions
|
| 523 |
+
|
| 524 |
+
def training_step(self, batch: Dict) -> Dict:
    """
    Run a full training step
    Args:
        batch (Dict): Wrapper dictionary whose 'img' entry is the actual batch containing
                      {'img', 'mask', 'keypoints_2d', 'keypoints_3d', 'orig_keypoints_2d',
                       'box_center', 'box_size', 'img_size', 'smal_params',
                       'smal_params_is_axis_angle', '_trans', 'imgname', 'focal_length'}
    Returns:
        Dict: Dictionary containing regression output.
    Raises:
        ValueError: If the total loss is NaN.
    """
    # The training dataloader wraps the real batch under the 'img' key
    batch = batch['img']
    use_adversarial = self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0

    optimizer = self.optimizers(use_pl_optimizer=True)
    if use_adversarial:
        optimizer, optimizer_disc = optimizer

    batch_size = batch['img'].shape[0]
    output = self.forward_step(batch, train=True)
    pred_smal_params = output['pred_smal_params']
    loss = self.compute_loss(batch, output, train=True)
    if use_adversarial:
        # Generator term: the regressed parameters should be scored as real (1)
        disc_out = self.discriminator(pred_smal_params['pose'].reshape(batch_size, -1),
                                      pred_smal_params['betas'].reshape(batch_size, -1))
        loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size
        loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv

    # Error if Nan
    if torch.isnan(loss):
        raise ValueError('Loss is NaN')

    optimizer.zero_grad()
    self.manual_backward(loss)
    # Clip gradient
    if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
        gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL,
                                            error_if_nonfinite=True)
        self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True,
                 batch_size=batch_size, sync_dist=True)

    optimizer.step()
    if use_adversarial:
        loss_disc = self.training_step_discriminator(batch['smal_params'],
                                                     pred_smal_params['pose'].reshape(batch_size, -1),
                                                     pred_smal_params['betas'].reshape(batch_size, -1),
                                                     optimizer_disc)
        # Detached like every other entry of output['losses'], so the returned
        # dict does not keep the autograd graph alive.
        output['losses']['loss_gen'] = loss_adv.detach()
        output['losses']['loss_disc'] = loss_disc

    if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
        self.tensorboard_logging(batch, output, self.global_step, train=True)

    # Log training loss to the logger so checkpoint callback can monitor it.
    self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True,
             logger=True, batch_size=batch_size, sync_dist=True)

    return output
|
| 582 |
+
|
| 583 |
+
def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
    """
    Run a validation step and log to Tensorboard
    Args:
        batch (Dict): Dictionary containing batch data
        batch_idx (int): Index of the batch; visuals are logged only for batch 0
        dataloader_idx: Unused.
    Returns:
        Dict: Dictionary containing regression output.
    """
    # The validation dataloader yields the inner batch dict directly (not wrapped as {'img': loader}).
    # Run forward, compute loss and log aggregated validation metrics so ModelCheckpoint can monitor them.
    output = self.forward_step(batch, train=False)
    # compute_loss populates output['losses']; the returned scalar is not needed here
    self.compute_loss(batch, output, train=False)

    # Passed explicitly (as training_step does) so epoch-level averages are weighted
    # by the real batch size rather than one inferred from the batch dict.
    batch_size = batch['img'].shape[0]

    # Log all validation losses with on_epoch=True so checkpoint monitors epoch-level metric
    for loss_name, val in output.get('losses', {}).items():
        # Log as 'val/<loss_name>' e.g. 'val/loss'; only the main loss goes to the progress bar
        self.log(f'val/{loss_name}', val, on_step=False, on_epoch=True,
                 prog_bar=(loss_name == 'loss'), logger=True,
                 batch_size=batch_size, sync_dist=True)

    # Log visualizations on the first batch of each validation epoch
    if batch_idx == 0:
        # Use global_step for step count when logging validation visuals
        self.tensorboard_logging(batch, output, self.global_step, train=False)

    return output
|
prima/models/smal_wrapper.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
from torch import nn
|
| 12 |
+
import torch
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pickle
|
| 15 |
+
import cv2
|
| 16 |
+
from typing import Optional, Tuple, NewType
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
import smplx
|
| 19 |
+
from smplx.lbs import vertices2joints, lbs
|
| 20 |
+
from smplx.utils import MANOOutput, to_tensor, ModelOutput
|
| 21 |
+
from smplx.vertex_ids import vertex_ids
|
| 22 |
+
|
| 23 |
+
# Nominal alias so signatures in this module can refer to a torch tensor as ``Tensor``.
Tensor = NewType('Tensor', torch.Tensor)

# Groups of mesh vertex indices, one group per keypoint (26 groups in total).
# NOTE(review): presumably each keypoint is derived from its vertex group — confirm against the consumer.
keypoint_vertices_idx = [
    [1068, 1080, 1029, 1226],
    [2660, 3030, 2675, 3038],
    [910],
    [360, 1203, 1235, 1230],
    [3188, 3156, 2327, 3183],
    [1976, 1974, 1980, 856],
    [3854, 2820, 3852, 3858],
    [452, 1811],
    [416, 235, 182],
    [2156, 2382, 2203],
    [829],
    [2793],
    [60, 114, 186, 59],
    [2091, 2037, 2036, 2160],
    [384, 799, 1169, 431],
    [2351, 2763, 2397, 3127],
    [221, 104],
    [2754, 2192],
    [191, 1158, 3116, 2165],
    [28, 1109, 1110, 1111, 1835, 1836, 3067, 3068, 3069],
    [498, 499, 500, 501, 502, 503],
    [2463, 2464, 2465, 2466, 2467, 2468],
    [764, 915, 916, 917, 934, 935, 956],
    [2878, 2879, 2880, 2897, 2898, 2919, 3751],
    [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762],
    [0, 464, 465, 726, 1824, 2429, 2430, 2690],
]

# Joint / landmark name -> integer id (41 entries, ids 0..40). Insertion order is kept as-is.
name2id35 = {
    'RFoot': 14, 'RFootBack': 24, 'spine1': 4, 'Head': 16,
    'LLegBack3': 19, 'RLegBack1': 21, 'pelvis0': 1,
    'RLegBack3': 23, 'LLegBack2': 18, 'spine0': 3, 'spine3': 6,
    'spine2': 5, 'Mouth': 32, 'Neck': 15,
    'LFootBack': 20, 'LLegBack1': 17, 'RLeg3': 13, 'RLeg2': 12,
    'LLeg1': 7, 'LLeg3': 9, 'RLeg1': 11,
    'LLeg2': 8, 'spine': 2, 'LFoot': 10,
    'Tail7': 31, 'Tail6': 30, 'Tail5': 29, 'Tail4': 28, 'Tail3': 27,
    'Tail2': 26, 'Tail1': 25, 'RLegBack2': 22, 'root': 0,
    'LEar': 33, 'REar': 34, 'EndNose': 35, 'Chin': 36,
    'RightEarTip': 37, 'LeftEarTip': 38, 'LeftEye': 39, 'RightEye': 40,
}
|
| 41 |
+
|
| 42 |
+
@dataclass
class SMALOutput(ModelOutput):
    """Output container for the SMAL model: smplx's ``ModelOutput`` plus shape and pose fields."""

    # Shape coefficients; None when not set.
    betas: Optional[Tensor] = None
    # Pose tensor; None when not set.
    pose: Optional[Tensor] = None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class SMALLayer(nn.Module):
    """SMAL body layer that poses the template mesh with linear blend skinning.

    All model data is stored in non-trainable buffers, so the layer follows
    ``.to(device)`` calls but exposes no parameters.
    """

    def __init__(self, num_betas=41, **kwargs):
        """
        Args:
            num_betas: Number of shape components kept from ``shapedirs``.
            **kwargs: Model data. The keys read here are ``shapedirs``,
                ``v_template``, ``posedirs``, ``J_regressor`` (must provide
                ``toarray()``), ``weights``, ``f`` and ``kintree_table``.
        """
        super().__init__()
        self.num_betas = num_betas

        shapedirs = torch.from_numpy(np.array(kwargs['shapedirs'], dtype=np.float32))
        # Keep only the first ``num_betas`` shape components. Author's note: [3889, 3, 41]
        self.register_buffer("shapedirs", shapedirs[:, :, :num_betas])
        # Author's note: [3889, 3]
        self.register_buffer(
            "v_template",
            torch.from_numpy(np.array(kwargs['v_template']).astype(np.float32)))
        # Flattened to one row per rotation-matrix entry of the 34 non-root joints.
        # Author's note: [34*9, 11667]
        self.register_buffer(
            "posedirs",
            torch.from_numpy(np.array(kwargs['posedirs'], dtype=np.float32)).reshape(-1, 34 * 9).T)
        # Author's note: [33, 3389] -- TODO confirm, the other notes use 3889 vertices.
        self.register_buffer(
            "J_regressor",
            torch.from_numpy(kwargs['J_regressor'].toarray().astype(np.float32)))
        # Author's note: [3889, 33]
        self.register_buffer(
            "lbs_weights",
            torch.from_numpy(np.array(kwargs['weights'], dtype=np.float32)))
        # Author's note: [7774, 3]
        self.register_buffer(
            "faces",
            torch.from_numpy(np.array(kwargs['f'], dtype=np.int32)))

        kintree_table = kwargs['kintree_table']
        # The first row of the kinematic tree table is used as the parent index of each joint.
        self.register_buffer("parents", torch.from_numpy(kintree_table[0].astype(np.int32)))

    def forward(
            self,
            betas: Optional[Tensor] = None,
            global_orient: Optional[Tensor] = None,
            pose: Optional[Tensor] = None,
            transl: Optional[Tensor] = None,
            return_verts: bool = True,
            return_full_pose: bool = False,
            **kwargs):
        """Pose the model and return vertices and joints.

        Every tensor argument is optional; a missing one is replaced by a
        neutral default (zero shape, identity rotations, zero translation).

        Args:
            betas: [batch_size, num_betas] shape coefficients.
            global_orient: [batch_size, 1, 3, 3] root rotation matrix.
            pose: [batch_size, 34, 3, 3] rotation matrices of the other joints.
            transl: [batch_size, 3] translation added to vertices and joints.
            return_verts: When False, ``vertices`` and ``joints`` are None in the output.
            return_full_pose: When True, the concatenated rotations are returned as ``full_pose``.
            **kwargs: Ignored.
        Returns:
            SMALOutput holding the results and the (possibly defaulted) inputs.
        """
        given = [t for t in (betas, global_orient, pose, transl) if t is not None]
        # Take device/dtype from the first supplied tensor (``betas`` when present);
        # with no inputs at all fall back to the template buffer so that
        # ``betas=None`` no longer raises.
        reference = given[0] if given else self.v_template
        device, dtype = reference.device, reference.dtype

        # ``global_orient`` keeps priority for the batch size; otherwise infer it
        # from any other supplied tensor instead of assuming a batch of one.
        if global_orient is not None:
            batch_size = global_orient.shape[0]
        elif given:
            batch_size = given[0].shape[0]
        else:
            batch_size = 1

        if global_orient is None:
            global_orient = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
        if pose is None:
            pose = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, 34, -1, -1).contiguous()
        if betas is None:
            betas = torch.zeros(
                [batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat([global_orient, pose], dim=1)
        # pose2rot=False: the pose is already given as rotation matrices.
        vertices, joints = lbs(betas, full_pose, self.v_template,
                               self.shapedirs, self.posedirs,
                               self.J_regressor, self.parents,
                               self.lbs_weights, pose2rot=False)

        # ``transl`` is never None here because of the default above.
        joints = joints + transl.unsqueeze(dim=1)
        vertices = vertices + transl.unsqueeze(dim=1)

        output = SMALOutput(
            vertices=vertices if return_verts else None,
            joints=joints if return_verts else None,
            betas=betas,
            global_orient=global_orient,
            pose=pose,
            transl=transl,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class SMAL(SMALLayer):
    """SMAL layer whose ``joints`` output is replaced by surface keypoints."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, *args, **kwargs):
        """Run ``SMALLayer.forward`` and overwrite ``joints`` with keypoints.

        Each keypoint is the mean of the vertices listed in the matching entry
        of ``keypoint_vertices_idx``. The vertices are required, so
        ``return_verts`` must stay True.

        Args:
            *args: Positional arguments forwarded to ``SMALLayer.forward``
                (previously they were silently dropped).
            **kwargs: Keyword arguments forwarded to ``SMALLayer.forward``.
        Returns:
            The SMALOutput of the parent layer with ``joints`` replaced.
        """
        smal_output = super().forward(*args, **kwargs)

        keypoints = [
            smal_output.vertices[:, vertex_ids_for_kp, :].mean(dim=1)
            for vertex_ids_for_kp in keypoint_vertices_idx
        ]
        smal_output.joints = torch.stack(keypoints, dim=1)
        return smal_output
|
| 133 |
+
|
| 134 |
+
|
prima/utils/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def recursive_to(x: Any, target: Any):
    """
    Recursively transfer a batch of data to the target device
    Args:
        x (Any): Batch of data. Dicts, lists, tuples and namedtuples are
            traversed; any other non-tensor value is returned unchanged.
        target (torch.device): Target device (anything accepted by ``Tensor.to``).
    Returns:
        Batch of data where all tensors are transferred to the target device.
    """
    # Imported lazily so that importing this package does not require torch.
    import torch

    def move(value: Any):
        if isinstance(value, torch.Tensor):
            return value.to(target)
        if isinstance(value, dict):
            return {k: move(v) for k, v in value.items()}
        if isinstance(value, list):
            return [move(i) for i in value]
        if isinstance(value, tuple):
            # Tensors nested in tuples used to be skipped. Only plain tuples and
            # namedtuples are rebuilt; other tuple subclasses (e.g. torch.Size)
            # are passed through untouched.
            if hasattr(value, '_fields'):
                return type(value)(*(move(i) for i in value))
            if type(value) is tuple:
                return tuple(move(i) for i in value)
        return value

    return move(x)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def __getattr__(name: str):
    """Resolve ``MeshRenderer`` lazily on first attribute access.

    The class is imported from ``.mesh_renderer`` only when requested and then
    cached in the module globals, so later lookups bypass this hook.

    Raises:
        AttributeError: For every other attribute name.
    """
    if name != "MeshRenderer":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from .mesh_renderer import MeshRenderer as renderer_cls

    globals()[name] = renderer_cls
    return renderer_cls
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# "MeshRenderer" is not bound at import time; it is resolved on demand by the
# module-level __getattr__ defined in this file.
__all__ = ["MeshRenderer", "recursive_to"]
|
prima/utils/detection.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
# Utilities for filtering animal detections before PRIMA demo inference.
|
| 13 |
+
#
|
| 14 |
+
# Detectron2 may return both a full-animal box and a local/partial box for the
|
| 15 |
+
# same animal. These helpers keep the demo pipeline from rendering the same
|
| 16 |
+
# animal multiple times.
|
| 17 |
+
|
| 18 |
+
from typing import Iterable
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
# Default class ids accepted by select_animal_boxes.
# NOTE(review): these look like Detectron2's contiguous COCO indices
# (cat, dog, horse, sheep, cow, bear, zebra) -- confirm against the detector metadata.
ANIMAL_COCO_IDS = (15, 16, 17, 18, 19, 21, 22)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _box_areas(boxes: np.ndarray) -> np.ndarray:
    """Return the area of every ``(x1, y1, x2, y2)`` row in ``boxes``.

    Negative extents (x2 < x1 or y2 < y1) are clamped to zero, so inverted
    boxes get an area of 0.
    """
    extents = np.maximum(0.0, boxes[:, 2:4] - boxes[:, 0:2])
    return extents[:, 0] * extents[:, 1]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _intersection_areas(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Return the overlap area between ``box`` and every row of ``boxes``.

    All boxes use the ``(x1, y1, x2, y2)`` layout; disjoint boxes yield 0.
    """
    top_left = np.maximum(box[0:2], boxes[:, 0:2])
    bottom_right = np.minimum(box[2:4], boxes[:, 2:4])
    overlap = np.maximum(0.0, bottom_right - top_left)
    return overlap[:, 0] * overlap[:, 1]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _suppress_duplicate_boxes(
    boxes: np.ndarray,
    scores: np.ndarray,
    *,
    iou_threshold: float,
    containment_threshold: float,
) -> np.ndarray:
    """Return the sorted indices of the boxes that survive duplicate removal.

    Two passes are applied:
      1. Containment: a box is dropped when its area is not positive, or when
         at least ``containment_threshold`` of its area lies inside some
         strictly larger box.
      2. Greedy NMS on the survivors, highest score first: a box is dropped
         when its IoU with an already kept box exceeds ``iou_threshold``.
    """
    total = len(boxes)
    if total <= 1:
        return np.arange(total, dtype=np.int64)

    boxes = boxes.astype(np.float32, copy=False)
    scores = scores.astype(np.float32, copy=False)
    areas = _box_areas(boxes)

    def is_contained(idx: int) -> bool:
        own_area = areas[idx]
        if own_area <= 0:
            return True
        bigger = np.where(areas > own_area)[0]
        if len(bigger) == 0:
            return False
        coverage = _intersection_areas(boxes[idx], boxes[bigger]) / own_area
        return bool(np.any(coverage >= containment_threshold))

    contained = np.array([is_contained(i) for i in range(total)], dtype=bool)

    survivors = np.where(~contained)[0]
    if len(survivors) <= 1:
        return survivors

    # Highest score first.
    queue = survivors[np.argsort(scores[survivors])[::-1]]
    kept = []
    while len(queue) > 0:
        best, queue = queue[0], queue[1:]
        kept.append(best)
        if len(queue) == 0:
            break

        inter = _intersection_areas(boxes[best], boxes[queue])
        union = areas[best] + areas[queue] - inter
        # Entries with an empty union keep an IoU of 0.
        iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
        queue = queue[iou <= iou_threshold]

    return np.array(sorted(kept), dtype=np.int64)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def select_animal_boxes(
    det_instances,
    *,
    animal_class_ids: Iterable[int] = ANIMAL_COCO_IDS,
    score_threshold: float = 0.7,
    iou_threshold: float = 0.5,
    containment_threshold: float = 0.9,
) -> tuple[np.ndarray, int]:
    """Return filtered animal boxes and the number of duplicate boxes removed.

    Args:
        det_instances: Detection result exposing ``pred_classes``, ``scores``
            and ``pred_boxes.tensor`` as tensors.
        animal_class_ids: Class ids that count as animals.
        score_threshold: Detections must score strictly above this value.
        iou_threshold: Forwarded to the duplicate suppression.
        containment_threshold: Forwarded to the duplicate suppression.
    Returns:
        ``(boxes, removed)``: the kept boxes as rows of four values, and how
        many score-filtered animal boxes were discarded as duplicates. With no
        valid detection the result is an empty ``(0, 4)`` float32 array and 0.
    """
    wanted = {int(class_id) for class_id in animal_class_ids}
    classes = det_instances.pred_classes.detach().cpu().numpy()
    scores = det_instances.scores.detach().cpu().numpy()

    selected = []
    for position, (class_id, score) in enumerate(zip(classes, scores)):
        if int(class_id) not in wanted:
            continue
        if float(score) > float(score_threshold):
            selected.append(position)
    valid_idx = np.array(selected, dtype=np.int64)

    if len(valid_idx) == 0:
        return np.zeros((0, 4), dtype=np.float32), 0

    boxes = det_instances.pred_boxes.tensor[valid_idx].detach().cpu().numpy()
    kept = _suppress_duplicate_boxes(
        boxes,
        scores[valid_idx],
        iou_threshold=iou_threshold,
        containment_threshold=containment_threshold,
    )
    removed = int(len(boxes) - len(kept))
    return boxes[kept], removed
|
prima/utils/evaluate_metric.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
import open3d as o3d
|
| 13 |
+
from typing import Dict, List, Union
|
| 14 |
+
from pytorch3d.transforms import axis_angle_to_matrix
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def compute_scale_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor:
    """
    Fit a per-sample isotropic scale (plus re-centring) that brings S1 closest to S2.

    Both point sets are centred on their own mean, the least-squares scale
    between the centred sets is computed, and the scaled S1 is placed on the
    mean of S2. No rotation is estimated.
    Args:
        S1 (torch.Tensor): First set of points of shape (B, N, 3).
        S2 (torch.Tensor): Second set of points of shape (B, N, 3).
    Returns:
        (torch.Tensor): The first set of points after applying the scale transformation.
    """
    source_mean = S1.mean(dim=1, keepdim=True)
    target_mean = S2.mean(dim=1, keepdim=True)
    source = S1 - source_mean
    target = S2 - target_mean

    # Least-squares scale: <target, source> / <source, source>, one value per sample.
    source_energy = (source ** 2).sum(dim=(1, 2), keepdim=True)
    correlation = (target * source).sum(dim=(1, 2), keepdim=True)
    scale = correlation / source_energy

    return scale * source + target_mean
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def compute_similarity_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor:
    """
    Computes a similarity transform (sR, t) in a batched way that takes
    a set of 3D points S1 (B, N, 3) closest to a set of 3D points S2 (B, N, 3),
    where R is a 3x3 rotation matrix, t 3x1 translation, s scale.
    i.e. solves the orthogonal Procrustes problem.
    Args:
        S1 (torch.Tensor): First set of points of shape (B, N, 3).
        S2 (torch.Tensor): Second set of points of shape (B, N, 3).
    Returns:
        (torch.Tensor): The first set of points after applying the similarity transformation.
    """
    batch_size = S1.shape[0]
    # Work on (B, 3, N) so that points are columns.
    S1 = S1.permute(0, 2, 1)
    S2 = S2.permute(0, 2, 1)

    # 1. Remove mean.
    mu1 = S1.mean(dim=2, keepdim=True)
    mu2 = S2.mean(dim=2, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = (X1 ** 2).sum(dim=(1, 2))

    # 3. The outer product of X1 and X2. Both operands are cast to float32 so
    #    that inputs of different dtypes no longer make matmul fail.
    K = torch.matmul(X1.float(), X2.float().permute(0, 2, 1))

    # 4. Solution that maximizes trace(R'K) is R=V*U', where U, V are singular
    #    vectors of K. torch.linalg.svd replaces the deprecated torch.svd and
    #    returns V already transposed.
    U, _, Vh = torch.linalg.svd(K, full_matrices=False)
    V = Vh.transpose(-2, -1)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    Z = torch.eye(U.shape[1], device=U.device).unsqueeze(0).repeat(batch_size, 1, 1).float()
    Z[:, -1, -1] *= torch.sign(torch.linalg.det(torch.matmul(U, Vh)))

    # Construct R.
    R = torch.matmul(torch.matmul(V, Z), U.transpose(-2, -1))

    # 5. Recover scale.
    trace = torch.matmul(R, K).diagonal(offset=0, dim1=-1, dim2=-2).sum(dim=-1)
    scale = (trace / var1).unsqueeze(dim=-1).unsqueeze(dim=-1)

    # 6. Recover translation.
    t = mu2 - scale * torch.matmul(R, mu1.float())

    # 7. Apply the transform to the (uncentred) source points.
    S1_hat = scale * torch.matmul(R, S1.float()) + t

    return S1_hat.permute(0, 2, 1)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def pointcloud(points: np.ndarray):
    """Wrap an array of 3-D points in an ``open3d`` ``PointCloud`` and return it."""
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    return cloud
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class Evaluator:
    """Computes 2-D (PCK, AUC) and 3-D (PA-MPJPE, PA-MPVPE) metrics for one batch.

    None of the methods modify the ``output`` or ``batch`` dictionaries they
    receive.
    """

    def __init__(self, smal_model, image_size: int=256, pelvis_ind: int = 7):
        """
        Args:
            smal_model: Callable SMAL layer used to build ground-truth vertices.
            image_size: Side length used to convert normalized 2-D keypoints to pixels.
            pelvis_ind: Index of the keypoint that is moved to the origin in ``eval_3d``.
        """
        self.pelvis_ind = pelvis_ind
        self.smal_model = smal_model
        self.image_size = image_size

    def compute_pck(self, output: Dict, batch: Dict, pck_threshold: Union[List, None]):
        """Return the batch-averaged PCK for every threshold.

        Keypoint distances are measured in pixels and normalized by the square
        root of the segmentation area (``batch['mask']`` when present, the full
        image otherwise). Only keypoints weighted by the ground-truth
        confidence (last channel of ``batch['keypoints_2d']``) contribute.

        Args:
            output: Must contain ``pred_keypoints_2d`` of shape (B, K, 2).
            batch: Must contain ``keypoints_2d`` of shape (B, K, 3).
            pck_threshold: Thresholds on the normalized distance.
        Returns:
            Float tensor of shape (T,); empty when no threshold is given.
        """
        if pck_threshold is None or len(pck_threshold) == 0:
            return torch.tensor([], dtype=torch.float32)

        pred_keypoints_2d = output['pred_keypoints_2d'].detach().cpu()
        gt_keypoints_2d = batch['keypoints_2d'].detach().cpu()

        # Shift/scale the normalized coordinates to pixels.
        pred_keypoints_2d = (pred_keypoints_2d + 0.5) * self.image_size
        conf = gt_keypoints_2d[:, :, -1]
        gt_keypoints_2d = (gt_keypoints_2d[:, :, :-1] + 0.5) * self.image_size

        if 'mask' in batch and batch['mask'] is not None:
            seg_area = torch.sum(batch['mask'].detach().cpu().reshape(batch['mask'].shape[0], -1), dim=-1).unsqueeze(-1)
        else:
            seg_area = torch.tensor([self.image_size * self.image_size] * len(pred_keypoints_2d), dtype=torch.float32).unsqueeze(-1)

        # Avoid a division by zero for samples without any visible keypoint.
        total_visible = torch.sum(conf, dim=-1).clamp_min(1e-6)  # (B,)
        dist = torch.norm(pred_keypoints_2d - gt_keypoints_2d, dim=-1)  # (B, K)
        norm_dist = dist / torch.sqrt(seg_area)  # (B, K)

        thresholds = torch.tensor(pck_threshold, dtype=torch.float32).view(-1, 1, 1)  # (T, 1, 1)
        hits = (norm_dist.unsqueeze(0) < thresholds).float()  # (T, B, K)
        pcks = (hits * conf.unsqueeze(0)).sum(dim=-1) / total_visible.unsqueeze(0)  # (T, B)
        return pcks.mean(dim=1)  # (T,)

    def compute_pa_mpjpe(self, pred_joints, gt_joints):
        """Return the mean per-joint error after similarity alignment, times 1000.

        NOTE(review): the factor 1000 presumably converts metres to millimetres -- confirm.
        """
        S1_hat = compute_similarity_transform(pred_joints, gt_joints)
        per_sample = torch.sqrt(((S1_hat - gt_joints) ** 2).sum(dim=-1)).mean(dim=-1)
        # detach(): predictions may still be attached to the autograd graph,
        # in which case .numpy() would raise.
        pa_mpjpe = per_sample.detach().cpu().numpy() * 1000
        return pa_mpjpe.mean()

    def compute_pa_mpvpe(self, gt_vertices: torch.Tensor, pred_vertices: torch.Tensor):
        """Return the mean per-vertex error after similarity alignment, times 1000.

        Note the argument order: ground truth first, prediction second.
        """
        S1_hat = compute_similarity_transform(pred_vertices, gt_vertices)
        per_sample = torch.sqrt(((S1_hat - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1)
        # detach(): see compute_pa_mpjpe.
        pa_mpvpe = per_sample.detach().cpu().numpy() * 1000
        return pa_mpvpe.mean()

    def eval_3d(self, output: Dict, batch: Dict):
        """
        Evaluate current batch
        Args:
            output: model output; reads ``pred_keypoints_3d`` and ``pred_vertices``.
            batch: model input; reads ``has_smal_params``, ``keypoints_3d``,
                ``smal_params`` and ``img``.
        Returns: ``(pa_mpjpe, pa_mpvpe)``; ``(0., 0.)`` when the batch carries
            no SMAL shape annotation.
        """
        if batch['has_smal_params']["betas"].sum() == 0:
            return 0., 0.

        pred_keypoints_3d = output["pred_keypoints_3d"].detach()
        pred_keypoints_3d = pred_keypoints_3d[:, None, :, :]
        batch_size = pred_keypoints_3d.shape[0]
        num_samples = pred_keypoints_3d.shape[1]
        gt_keypoints_3d = batch['keypoints_3d'][:, :, :-1].unsqueeze(1).repeat(1, num_samples, 1, 1)
        gt_vertices = self.smal_forward(batch)

        # Align predictions and ground truth such that the pelvis location is at the origin.
        # Out-of-place on purpose: ``pred_keypoints_3d`` is a view that shares
        # storage with ``output['pred_keypoints_3d']``, so ``-=`` would silently
        # rewrite the caller's tensor.
        pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, :, [self.pelvis_ind]]
        gt_keypoints_3d = gt_keypoints_3d - gt_keypoints_3d[:, :, [self.pelvis_ind]]

        pa_mpjpe = self.compute_pa_mpjpe(pred_keypoints_3d.reshape(batch_size * num_samples, -1, 3),
                                         gt_keypoints_3d.reshape(batch_size * num_samples, -1, 3))
        pa_mpvpe = self.compute_pa_mpvpe(gt_vertices, output['pred_vertices'])
        return pa_mpjpe, pa_mpvpe

    def eval_2d(self, output: Dict, batch: Dict, pck_threshold: List[float]=[0.10, 0.15]):
        """Return ``(pck, auc)``: PCK per threshold as a list, and the PCK AUC.

        The default threshold list is only read, never modified.
        """
        pck = self.compute_pck(output, batch, pck_threshold=pck_threshold)
        auc = self.compute_auc(batch, output)
        return pck.tolist(), auc

    def compute_auc(self, batch: Dict, output: Dict, threshold_min: float=0.0, threshold_max: float=1.0, steps: int=100):
        """Return the area under the PCK curve, normalized by the threshold range.

        The curve is sampled at ``steps`` evenly spaced thresholds between
        ``threshold_min`` and ``threshold_max`` and integrated with the
        trapezoidal rule.
        """
        thresholds = np.linspace(threshold_min, threshold_max, steps)
        pck_curve = self.compute_pck(output, batch, thresholds.tolist()).numpy()  # (steps,)
        norm_factor = threshold_max - threshold_min
        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
        # use whichever this NumPy provides.
        integrate = getattr(np, "trapezoid", None) or np.trapz
        auc = float(integrate(pck_curve, thresholds) / norm_factor)
        return auc

    def smal_forward(self, batch: Dict):
        """Run the SMAL model on the ground-truth parameters and return its vertices.

        ``batch['smal_params']`` is expected to hold axis-angle ``global_orient``
        and ``pose``; both are converted to rotation matrices on a copy so the
        batch itself stays untouched and the method can be called repeatedly.
        """
        batch_size = batch['img'].shape[0]
        # Shallow copy: replacing entries below must not write into the batch.
        smal_params = dict(batch['smal_params'])
        smal_params['global_orient'] = axis_angle_to_matrix(smal_params['global_orient'].reshape(batch_size, -1)).unsqueeze(1)
        smal_params['pose'] = axis_angle_to_matrix(smal_params['pose'].reshape(batch_size, -1, 3))

        # The SMAL model only registers buffers (e.g. shapedirs) and has no trainable parameters,
        # so its parameter iterator can be empty. Try parameters first, then buffers,
        # and finally fall back to the device of the input batch.
        reference = next(self.smal_model.parameters(), None)
        if reference is None:
            reference = next(self.smal_model.buffers(), None)
        device = batch['img'].device if reference is None else reference.device

        smal_params = {k: v.to(device) for k, v in smal_params.items()}
        with torch.no_grad():
            smal_output = self.smal_model(**smal_params)
        vertices = smal_output.vertices
        return vertices
|
prima/utils/geometry.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import Optional
|
| 11 |
+
import torch
|
| 12 |
+
from torch.nn import functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def aa_to_rotmat(theta: torch.Tensor):
    """
    Convert axis-angle vectors to rotation matrices via unit quaternions.
    Args:
        theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    """
    # The small offset keeps the norm (and the division below) away from zero
    # for all-zero rotations.
    angle = torch.norm(theta + 1e-8, p=2, dim=1).unsqueeze(-1)
    axis = theta / angle
    half_angle = angle * 0.5
    # Quaternion layout is (w, x, y, z).
    quat = torch.cat([torch.cos(half_angle), torch.sin(half_angle) * axis], dim=1)
    return quat_to_rotmat(quat)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
    """
    Turn quaternions into rotation matrices; the input is normalized first.
    Args:
        quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    """
    unit = quat / quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = (unit[:, i] for i in range(4))

    ww, xx, yy, zz = w ** 2, x ** 2, y ** 2, z ** 2
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    # Assemble the matrix one row at a time.
    first_row = torch.stack([ww + xx - yy - zz, 2 * xy - 2 * wz, 2 * wy + 2 * xz], dim=1)
    second_row = torch.stack([2 * wz + 2 * xy, ww - xx + yy - zz, 2 * yz - 2 * wx], dim=1)
    third_row = torch.stack([2 * xz - 2 * wy, 2 * wx + 2 * yz, ww - xx - yy + zz], dim=1)
    return torch.stack([first_row, second_row, third_row], dim=1)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
    """
    Map the 6D rotation representation to rotation matrices with Gram-Schmidt.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Args:
        x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
    Returns:
        torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
    """
    # The first three numbers are the first raw vector, the last three the second.
    raw = x.reshape(-1, 2, 3)
    first, second = raw[:, 0], raw[:, 1]

    col1 = F.normalize(first)
    # Remove the component of ``second`` along ``col1`` before normalizing.
    projection = torch.einsum('bi,bi->b', col1, second).unsqueeze(-1)
    col2 = F.normalize(second - projection * col1)
    col3 = torch.cross(col1, col2, dim=1)
    # The three unit vectors become the columns of the matrix.
    return torch.stack((col1, col2, col3), dim=-1)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def perspective_projection(points: torch.Tensor,
                           translation: torch.Tensor,
                           focal_length: torch.Tensor,
                           camera_center: Optional[torch.Tensor] = None,
                           rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Project 3D points to the image plane of a pinhole camera.
    Args:
        points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
        translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
        focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
        camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
            Defaults to the origin.
        rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
            Defaults to the identity.
    Returns:
        torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
    """
    num_cameras = points.shape[0]
    like_points = {'device': points.device, 'dtype': points.dtype}

    if rotation is None:
        rotation = torch.eye(3, **like_points).unsqueeze(0).expand(num_cameras, -1, -1)
    if camera_center is None:
        camera_center = torch.zeros(num_cameras, 2, **like_points)

    # Intrinsics: focal lengths on the diagonal, principal point in the last column.
    intrinsics = torch.zeros([num_cameras, 3, 3], **like_points)
    intrinsics[:, 0, 0] = focal_length[:, 0]
    intrinsics[:, 1, 1] = focal_length[:, 1]
    intrinsics[:, 2, 2] = 1.
    intrinsics[:, :-1, -1] = camera_center

    # Rotate, then translate, into the camera frame.
    camera_points = torch.einsum('bij,bkj->bki', rotation, points) + translation.unsqueeze(1)

    # Divide by depth; the last coordinate becomes 1.
    normalized = camera_points / camera_points[:, :, -1:]

    pixels = torch.einsum('bij,bkj->bki', intrinsics, normalized)
    return pixels[:, :, :-1]
|
prima/utils/mesh_renderer.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
from ctypes.util import find_library
|
| 12 |
+
|
| 13 |
+
import sys

# Choose an off-screen OpenGL backend before pyrender is imported, unless the
# caller already set one. sys.platform is used instead of os.uname(), which
# does not exist on Windows and would make this module fail to import there.
if 'PYOPENGL_PLATFORM' not in os.environ and sys.platform != 'darwin':
    # Prefer EGL; PyOpenGL's OSMesa bindings can lack symbols required by pyrender.
    os.environ['PYOPENGL_PLATFORM'] = 'egl' if find_library('EGL') else 'osmesa'
# .get(): the variable can still be unset here (macOS without an explicit choice).
if os.environ.get('PYOPENGL_PLATFORM') == 'egl':
    os.environ.setdefault('EGL_PLATFORM', 'surfaceless')
|
| 18 |
+
import torch
|
| 19 |
+
from torchvision.utils import make_grid
|
| 20 |
+
import numpy as np
|
| 21 |
+
import pyrender
|
| 22 |
+
import trimesh
|
| 23 |
+
import cv2
|
| 24 |
+
import math
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
from typing import List, Tuple
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def create_raymond_lights():
    """Build three directional-light nodes arranged around the vertical axis.

    All lights share the same polar angle (pi/6) and are spaced 2*pi/3 apart in
    azimuth. Each node gets a rotation whose z axis is the light direction.

    Returns:
        List of ``pyrender.Node`` objects, one per light.
    """
    import pyrender

    polar_angles = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
    azimuths = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])

    light_nodes = []
    for phi, theta in zip(azimuths, polar_angles):
        direction = np.array([
            np.sin(theta) * np.cos(phi),
            np.sin(theta) * np.sin(phi),
            np.cos(theta),
        ])
        z_axis = direction / np.linalg.norm(direction)

        # A vector perpendicular to z in the xy plane; fall back to +x when z
        # is exactly vertical and that vector degenerates to zero.
        x_axis = np.array([-z_axis[1], z_axis[0], 0.0])
        if np.linalg.norm(x_axis) == 0:
            x_axis = np.array([1.0, 0.0, 0.0])
        x_axis = x_axis / np.linalg.norm(x_axis)
        y_axis = np.cross(z_axis, x_axis)

        pose = np.eye(4)
        pose[:3, :3] = np.c_[x_axis, y_axis, z_axis]
        light = pyrender.DirectionalLight(color=np.ones(3), intensity=1.0)
        light_nodes.append(pyrender.Node(light=light, matrix=pose))

    return light_nodes
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]:
    """
    Measure the axis-aligned box around the confident keypoints.

    Args:
        keypoints (np.array): Keypoint array of shape (N, 3); the last column is the confidence.
        threshold (float): Only keypoints with confidence strictly above this value are used.
    Returns:
        Tuple[float, float, float]: Box width, height and area, or ``(0, 0, 0)``
        when no keypoint passes the threshold.
    """
    confident = keypoints[:, -1] > threshold
    if not confident.any():
        return 0, 0, 0

    coords = keypoints[confident][:, :-1]
    xs, ys = coords[:, 0], coords[:, 1]
    width = xs.max() - xs.min()
    height = ys.max() - ys.min()
    return width, height, width * height
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def render_keypoint(img: np.array, keypoint: np.array, threshold=0.1,
                    use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0) -> np.array:
    """
    Draw a keypoint skeleton (limbs first, then joints) onto a copy of an image.

    Args:
        img (np.array): Image laid out as (height, width, channels), as required by the
            ``cv2`` drawing calls. The input array is not modified.
        keypoint (np.array): Array with 26 or 18 rows; the last column is a confidence and
            the preceding columns are the pixel coordinates handed to ``cv2``.
        threshold (float): Keypoints whose confidence is not above this value are not drawn.
        use_confidence (bool): If True (and ``map_fn`` is not None), per-keypoint thickness
            is scaled by ``map_fn(confidence)``.
        map_fn (Callable): Maps the confidence column to per-keypoint thickness factors.
        alpha (float): Unused; kept for interface compatibility.
    Returns:
        np.array: A new image with the skeleton drawn, or ``img`` itself when the bounding
        rectangle of keypoints with confidence above 0.1 has no positive area.
    Raises:
        ValueError: If the number of keypoints is neither 26 nor 18.
    """
    if use_confidence and map_fn is not None:
        circle_ratio = 1. / 50 * map_fn(keypoint[:, -1])
    else:
        circle_ratio = 1. / 50 * np.ones(keypoint.shape[0])

    line_ratio_wrt_circle = 0.75
    # Flat list of (start, end) keypoint indices describing the skeleton limbs.
    if keypoint.shape[0] == 26:
        pairs = [0, 24, 1, 24, 2, 24, 3, 14, 4, 15, 5, 16, 6, 17, 7, 18, 8, 12, 9, 13, 10, 7, 11, 7,
                 12, 18, 13, 18, 14, 8, 15, 9, 16, 10, 17, 11, 18, 24, 19, 25, 20, 0, 21, 1, 22, 24,
                 23, 24, 25, 7]
    elif keypoint.shape[0] == 18:
        pairs = [9, 8, 8, 2, 2, 3, 3, 4, 2, 0, 2, 1, 4, 5,
                 5, 14, 14, 15, 4, 6, 6, 7, 7, 11, 11, 10,
                 7, 13, 13, 12, 5, 16, 5, 17]
    else:
        raise ValueError("Keypoint shape not supported")
    pairs = np.array(pairs).reshape(-1, 2)
    colors = [255., 0., 85.,
              255., 0., 0.,
              255., 85., 0.,
              255., 170., 0.,
              255., 255., 0.,
              170., 255., 0.,
              85., 255., 0.,
              0., 255., 0.,
              255., 0., 0.,
              0., 255., 85.,
              0., 255., 170.,
              0., 255., 255.,
              0., 170., 255.,
              0., 85., 255.,
              0., 0., 255.,
              255., 0., 170.,
              170., 0., 255.,
              255., 0., 255.,
              85., 0., 255.,
              0., 0., 255.,
              0., 0., 255.,
              0., 0., 255.,
              0., 255., 255.,
              0., 255., 255.,
              0., 255., 255.,
              255., 225., 255.]
    colors = np.array(colors).reshape(-1, 3)
    pose_scale = 1

    # The image is (height, width, channels); axis 2 is the channel count and must
    # not be used as the height when scaling the drawing thickness.
    height, width = img.shape[0], img.shape[1]
    area = width * height

    line_type = 8
    shift = 0
    num_colors = len(colors)
    rectangle_threshold = 0.1

    animal_width, animal_height, animal_area = get_keypoints_rectangle(keypoint, rectangle_threshold)
    if not animal_area > 0:
        # Nothing confident enough to size the drawing against.
        return img

    # Thickness grows with image size and with the fraction of the image the animal spans.
    ratio_areas = min(1, max(animal_width / width, animal_height / height))
    thickness_ratio = np.maximum(np.round(math.sqrt(area) * circle_ratio * ratio_areas), 2)
    thickness_circle = np.maximum(1, thickness_ratio if ratio_areas > 0.05 else -np.ones_like(thickness_ratio))
    thickness_line = np.maximum(1, np.round(thickness_ratio * line_ratio_wrt_circle))
    radius = thickness_ratio / 2

    # cv2 needs a contiguous buffer; copying also keeps the caller's image untouched.
    img = np.ascontiguousarray(img.copy())
    for index1, index2 in pairs:
        if keypoint[index1, -1] > threshold and keypoint[index2, -1] > threshold:
            line_thickness = int(round(min(thickness_line[index1], thickness_line[index2]) * pose_scale))
            color = colors[index2 % num_colors]
            start = keypoint[index1, :-1].astype(np.int32)
            end = keypoint[index2, :-1].astype(np.int32)
            cv2.line(img, tuple(start.tolist()), tuple(end.tolist()), tuple(color.tolist()),
                     line_thickness, line_type, shift)
    for part in range(len(keypoint)):
        if keypoint[part, -1] > threshold:
            radius_scaled = int(round(radius[part] * pose_scale))
            circle_thickness = int(round(thickness_circle[part] * pose_scale))
            color = colors[part % num_colors]
            center = keypoint[part, :-1].astype(np.int32)
            cv2.circle(img, tuple(center.tolist()), radius_scaled, tuple(color.tolist()), circle_thickness,
                       line_type, shift)

    return img
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class MeshRenderer:
    """
    Offscreen pyrender-based renderer that draws a triangle mesh over images.

    Every render call creates its own ``pyrender.OffscreenRenderer`` and releases
    it again, even when rendering fails.
    """

    def __init__(self, cfg, faces=None):
        """
        Args:
            cfg: Config object; ``cfg.MODEL.IMAGE_SIZE`` is used as the square render resolution.
            faces: Mesh face indices, copied into a ``trimesh.Trimesh`` on every render.
        """
        self.cfg = cfg
        self.img_res = cfg.MODEL.IMAGE_SIZE
        # NOTE(review): no method of this class uses this renderer (each render call
        # builds its own); it is kept because external code may read the attribute.
        self.renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res,
                                                   viewport_height=self.img_res,
                                                   point_size=1.0)

        self.camera_center = [self.img_res // 2, self.img_res // 2]
        self.faces = faces

    def _render_front_and_side(self, vertices, camera_translation, image, focal_length):
        """Render one sample from the front and from the side as two (C, H, W) float tensors."""
        front = self(vertices, camera_translation, image, focal_length=focal_length, side_view=False)
        side = self(vertices, camera_translation, image, focal_length=focal_length, side_view=True)
        return (torch.from_numpy(np.transpose(front, (2, 0, 1))).float(),
                torch.from_numpy(np.transpose(side, (2, 0, 1))).float())

    def _build_scene(self, vertices, camera_translation, focal_length, width, height,
                     side_view, rot_angle, material=None):
        """Assemble the pyrender scene (mesh, camera, lights) for a viewport of the given size."""
        mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
        if side_view:
            rot = trimesh.transformations.rotation_matrix(
                np.radians(rot_angle), [0, 1, 0])
            mesh.apply_transform(rot)
        # Always applied: rotate the mesh 180 degrees about the x-axis.
        rot = trimesh.transformations.rotation_matrix(
            np.radians(180), [1, 0, 0])
        mesh.apply_transform(rot)
        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(mesh, 'mesh')

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_translation
        camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
                                           cx=width / 2., cy=height / 2.,
                                           zfar=1000)
        scene.add(camera, pose=camera_pose)

        for node in create_raymond_lights():
            scene.add_node(node)
        return scene

    @staticmethod
    def _render_rgba(scene, width, height):
        """Render ``scene`` to an RGBA image, always releasing the offscreen renderer."""
        renderer = pyrender.OffscreenRenderer(viewport_width=width,
                                              viewport_height=height,
                                              point_size=1.0)
        try:
            color, _ = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
        finally:
            # Release the renderer even if rendering raised.
            renderer.delete()
        return color

    def visualize(self, vertices, camera_translation, images, focal_length, nrow=3, padding=2):
        """
        Build an image grid of (input, front render, side render) for every sample.

        Args:
            vertices: Per-sample mesh vertices, indexed along the first axis.
            camera_translation: Per-sample camera translations.
            images: Numpy batch in (N, C, H, W) layout.
            focal_length: Focal length used for both camera axes.
            nrow (int): Forwarded to ``make_grid``.
            padding (int): Forwarded to ``make_grid``.
        Returns:
            torch.Tensor: The grid produced by ``torchvision.utils.make_grid``.
        """
        images_np = np.transpose(images, (0, 2, 3, 1))
        rend_imgs = []
        for i in range(vertices.shape[0]):
            rend_img, rend_img_side = self._render_front_and_side(
                vertices[i], camera_translation[i], images_np[i], focal_length)
            rend_imgs.append(torch.from_numpy(images[i]))
            rend_imgs.append(rend_img)
            rend_imgs.append(rend_img_side)
        rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding)
        return rend_imgs

    def visualize_tensorboard(self, vertices, camera_translation, images, focal_length, pred_keypoints, gt_keypoints,
                              pred_masks=None, gt_masks=None):
        """
        Collect per-sample visualizations: input, front/side renders, optional masks, keypoint overlays.

        Args:
            vertices: Per-sample mesh vertices, indexed along the first axis.
            camera_translation: Per-sample camera translations.
            images: Numpy batch in (N, C, H, W) layout.
            focal_length: Focal length used for both camera axes.
            pred_keypoints: Predicted keypoints without a confidence column; a column of
                ones is appended here. Presumably normalized to [-0.5, 0.5] (inferred
                from the ``img_res * (x + 0.5)`` mapping) -- TODO confirm.
            gt_keypoints: Ground-truth keypoints whose last column is a confidence.
                The caller's array is left unmodified.
            pred_masks: Optional per-sample predicted masks.
            gt_masks: Optional per-sample ground-truth masks.
        Returns:
            list: Flat list of tensors, in the order listed above for each sample.
        """
        images_np = np.transpose(images, (0, 2, 3, 1))
        rend_imgs = []
        pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1)
        pred_keypoints = self.img_res * (pred_keypoints + 0.5)
        # Work on a copy: rescaling in place would alter the caller's array (and any
        # tensor sharing its memory), and would compound on repeated calls.
        gt_keypoints = gt_keypoints.copy()
        gt_keypoints[:, :, :-1] = self.img_res * (gt_keypoints[:, :, :-1] + 0.5)
        for i in range(vertices.shape[0]):
            rend_img, rend_img_side = self._render_front_and_side(
                vertices[i], camera_translation[i], images_np[i], focal_length)
            pred_keypoints_img = render_keypoint(255 * images_np[i].copy(), pred_keypoints[i]) / 255
            gt_keypoints_img = render_keypoint(255 * images_np[i].copy(), gt_keypoints[i]) / 255
            rend_imgs.append(torch.from_numpy(images[i]))
            rend_imgs.append(rend_img)
            rend_imgs.append(rend_img_side)
            if pred_masks is not None:
                rend_imgs.append(torch.from_numpy(pred_masks[i]))
            if gt_masks is not None:
                rend_imgs.append(torch.from_numpy(gt_masks[i]))
            rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2, 0, 1))
            rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2, 0, 1))
        return rend_imgs

    def __call__(self, vertices, camera_translation, image, focal_length, text=None, resize=None, side_view=False,
                 baseColorFactor=(1.0, 1.0, 0.9, 1.0), rot_angle=90):
        """
        Render the mesh at the resolution of ``image``.

        Args:
            vertices: Mesh vertices for one sample.
            camera_translation: Camera translation; a copy with negated x is used.
            image: Image in (H, W, C) layout; defines the viewport size and is the
                background of the front view.
            focal_length: Focal length used for both camera axes.
            text: Unused; kept for interface compatibility.
            resize: Optional size passed to ``cv2.resize`` for the output.
            side_view (bool): If True, rotate the mesh by ``rot_angle`` about the y-axis
                and return the render without the image background.
            baseColorFactor: RGBA base color of the mesh material.
            rot_angle: Side-view rotation angle in degrees.
        Returns:
            np.ndarray: float32 image.
        """
        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            baseColorFactor=baseColorFactor)

        camera_translation_local = camera_translation.copy()
        camera_translation_local[0] *= -1.

        width, height = image.shape[1], image.shape[0]
        scene = self._build_scene(vertices, camera_translation_local, focal_length,
                                  width, height, side_view, rot_angle, material=material)
        color = self._render_rgba(scene, width, height)

        color = color.astype(np.float32) / 255.0
        # Pixels with non-zero alpha are covered by the mesh.
        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        if not side_view:
            output_img = (color[:, :, :3] * valid_mask +
                          (1 - valid_mask) * image)
        else:
            output_img = color[:, :, :3]
        if resize is not None:
            output_img = cv2.resize(output_img, resize)

        output_img = output_img.astype(np.float32)
        return output_img

    def render_mask(self, vertices, camera_translation, focal_length, side_view=False, rot_angle=90):
        """
        Render only the visibility mask (alpha>0) of the mesh given vertices and camera translation.
        Returns a single-channel float32 numpy array with values 0.0 or 1.0 with shape (H, W).
        """
        # NOTE(review): unlike __call__, the x component of camera_translation is not
        # negated here -- confirm this difference is intended.
        scene = self._build_scene(vertices, camera_translation, focal_length,
                                  self.img_res, self.img_res, side_view, rot_angle)
        color = self._render_rgba(scene, self.img_res, self.img_res)
        # alpha channel indicates visibility
        mask = (color[:, :, -1] > 0).astype(np.float32)
        return mask
|
prima/utils/misc.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import time
|
| 11 |
+
import warnings
|
| 12 |
+
from importlib.util import find_spec
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Callable, List
|
| 15 |
+
|
| 16 |
+
import hydra
|
| 17 |
+
from omegaconf import DictConfig, OmegaConf
|
| 18 |
+
from pytorch_lightning import Callback
|
| 19 |
+
from pytorch_lightning.loggers import Logger
|
| 20 |
+
from pytorch_lightning.utilities import rank_zero_only
|
| 21 |
+
|
| 22 |
+
from . import pylogger, rich_utils
|
| 23 |
+
|
| 24 |
+
log = pylogger.get_pylogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that wraps the task function in extra utilities.

    Makes multirun more resistant to failure.

    Utilities:
    - Calling the `utils.extras()` before the task is started
    - Calling the `utils.close_loggers()` after the task is finished
    - Logging the exception if occurs
    - Logging the task total execution time
    - Logging the output dir

    Args:
        task_func: Task callable; it is invoked as ``task_func(cfg=cfg)``.

    Returns:
        A callable taking ``cfg`` that returns whatever ``task_func`` returns.
        Exceptions raised by the task are logged and re-raised unchanged.
    """
    import functools  # local import so this block does not depend on other edits

    # Preserve the wrapped task's name/docstring for logs and introspection.
    @functools.wraps(task_func)
    def wrap(cfg: DictConfig):
        start_time = time.time()
        try:
            # apply extra utilities
            extras(cfg)

            # execute the task
            ret = task_func(cfg=cfg)
        except Exception:
            log.exception("")  # save exception to `.log` file
            raise  # bare raise keeps the original traceback intact
        finally:
            path = Path(cfg.paths.output_dir, "exec_time.log")
            content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
            save_file(path, content)  # save task execution time (even if exception occurs)
            close_loggers()  # close loggers (even if exception occurs so multirun won't fail)

        log.info(f"Output dir: {cfg.paths.output_dir}")

        return ret

    return wrap
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def extras(cfg: DictConfig) -> None:
    """Run the optional pre-task utilities enabled under ``cfg.extras``.

    Each utility runs only when its flag is truthy, in this order:
    - ``ignore_warnings``: silence python warnings
    - ``enforce_tags``: enforce tags through ``rich_utils.enforce_tags``
    - ``print_config``: print the config tree with Rich

    Does nothing (apart from a warning) when ``cfg.extras`` is missing or empty.
    """
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # (flag name, log message, action) triples, executed top to bottom.
    steps = (
        (
            "ignore_warnings",
            "Disabling python warnings! <cfg.extras.ignore_warnings=True>",
            lambda: warnings.filterwarnings("ignore"),
        ),
        (
            "enforce_tags",
            "Enforcing tags! <cfg.extras.enforce_tags=True>",
            lambda: rich_utils.enforce_tags(cfg, save_to_file=True),
        ),
        (
            "print_config",
            "Printing config tree with Rich! <cfg.extras.print_config=True>",
            lambda: rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True),
        ),
    )
    for flag, message, action in steps:
        if cfg.extras.get(flag):
            log.info(message)
            action()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@rank_zero_only
def save_file(path: str, content: str) -> None:
    """Write ``content`` to ``path``, replacing any existing file.

    Decorated with ``rank_zero_only`` so that only one process writes in a
    multi-GPU setup.
    """
    with open(path, "w+") as handle:
        handle.write(content)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Build callbacks from a config mapping.

    Every entry that is a ``DictConfig`` with a ``_target_`` key is instantiated
    through ``hydra.utils.instantiate``; other entries are ignored.

    Returns:
        The instantiated callbacks (empty when the config is empty).

    Raises:
        TypeError: If a non-empty config is not a ``DictConfig``.
    """
    if not callbacks_cfg:
        log.warning("Callbacks config is empty.")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    instantiated: List[Callback] = []
    for _name, entry in callbacks_cfg.items():
        if not (isinstance(entry, DictConfig) and "_target_" in entry):
            continue
        log.info(f"Instantiating callback <{entry._target_}>")
        instantiated.append(hydra.utils.instantiate(entry))
    return instantiated
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Build loggers from a config mapping.

    Every entry that is a ``DictConfig`` with a ``_target_`` key is instantiated
    through ``hydra.utils.instantiate``; other entries are ignored.

    Returns:
        The instantiated loggers (empty when the config is empty).

    Raises:
        TypeError: If a non-empty config is not a ``DictConfig``.
    """
    if not logger_cfg:
        log.warning("Logger config is empty.")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    instantiated: List[Logger] = []
    for _name, entry in logger_cfg.items():
        if not (isinstance(entry, DictConfig) and "_target_" in entry):
            continue
        log.info(f"Instantiating logger <{entry._target_}>")
        instantiated.append(hydra.utils.instantiate(entry))
    return instantiated
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Send the config and model parameter counts to the trainer's logger.

    Reads ``cfg``, ``model`` and ``trainer`` from ``object_dict``. Logs the total,
    trainable and non-trainable parameter counts followed by every top-level
    config entry, with ``DictConfig`` values converted to plain containers
    (interpolations resolved). Does nothing but warn when the trainer has no logger.
    """
    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    # Count parameters in a single pass over the model.
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count

    hparams = {
        "model/params/total": total,
        "model/params/trainable": trainable,
        "model/params/non_trainable": total - trainable,
    }

    for key in cfg.keys():
        value = cfg.get(key)
        if isinstance(value, DictConfig):
            # Resolve all interpolations
            value = OmegaConf.to_container(value, resolve=True)
        hparams[key] = value

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def get_metric_value(metric_dict: dict, metric_name: str) -> float:
    """Look up a logged metric and return it as a Python scalar via ``.item()``.

    Returns ``None`` when ``metric_name`` is empty or ``None``.

    Raises:
        Exception: If ``metric_name`` is not a key of ``metric_dict``.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        message = (
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )
        raise Exception(message)

    value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={value}>")
    return value
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def close_loggers() -> None:
    """Finish any active wandb run so later runs in a multirun can log cleanly.

    Does nothing beyond logging when wandb is not installed or has no active run.
    """
    log.info("Closing loggers...")

    if not find_spec("wandb"):  # wandb is not installed
        return

    import wandb

    if wandb.run:
        log.info("Closing wandb!")
        wandb.finish()
|
prima/utils/pylogger.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
from pytorch_lightning.utilities import rank_zero_only
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_pylogger(name=__name__) -> logging.Logger:
    """Return the named logger with its level methods wrapped in ``rank_zero_only``.

    Wrapping every level method keeps log lines from being repeated by each
    GPU process in a multi-GPU setup.
    """
    logger = logging.getLogger(name)

    for method_name in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
        guarded = rank_zero_only(getattr(logger, method_name))
        setattr(logger, method_name, guarded)

    return logger
|
prima/utils/renderer.py
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
"""
|
| 3 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 4 |
+
|
| 5 |
+
Official implementation of the paper:
|
| 6 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 7 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 8 |
+
Licensed under a modified MIT license
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
from ctypes.util import find_library
|
| 14 |
+
|
| 15 |
+
if 'PYOPENGL_PLATFORM' not in os.environ and os.uname().sysname != 'Darwin':
|
| 16 |
+
# Prefer EGL; PyOpenGL's OSMesa bindings can lack symbols required by pyrender.
|
| 17 |
+
os.environ['PYOPENGL_PLATFORM'] = 'egl' if find_library('EGL') else 'osmesa'
|
| 18 |
+
if os.environ['PYOPENGL_PLATFORM'] == 'egl':
|
| 19 |
+
os.environ.setdefault('EGL_PLATFORM', 'surfaceless')
|
| 20 |
+
import torch
|
| 21 |
+
import numpy as np
|
| 22 |
+
import pyrender
|
| 23 |
+
import trimesh
|
| 24 |
+
import cv2
|
| 25 |
+
from yacs.config import CfgNode
|
| 26 |
+
from typing import List, Optional
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.):
    """
    Convert per-crop camera parameters into a full-image camera translation.

    Args:
        cam_bbox: Tensor indexed as ``[:, 0]`` (scale), ``[:, 1]`` and ``[:, 2]`` (x/y offsets).
        box_center: Tensor whose columns 0 and 1 are the box center x and y.
        box_size: Per-sample box size.
        img_size: Tensor whose columns 0 and 1 are the image width and height.
        focal_length: Focal length used to derive the depth component.
    Returns:
        Tensor with ``(tx, ty, tz)`` stacked along the last dimension.
    """
    half_w = img_size[:, 0] / 2.
    half_h = img_size[:, 1] / 2.
    # Small epsilon avoids division by zero for a zero scale.
    scaled_box = box_size * cam_bbox[:, 0] + 1e-9
    depth = 2 * focal_length / scaled_box
    shift_x = (2 * (box_center[:, 0] - half_w) / scaled_box) + cam_bbox[:, 1]
    shift_y = (2 * (box_center[:, 1] - half_h) / scaled_box) + cam_bbox[:, 2]
    return torch.stack([shift_x, shift_y, depth], dim=-1)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_light_poses(n_lights=5, elevation=np.pi / 3, dist=12):
    """
    Compute ``n_lights`` 4x4 poses evenly spaced in azimuth at a common elevation.

    Each pose is a rotation (``rx=-elevation``, ``ry=azimuth``) applied to a
    translation of ``dist`` along z.

    Returns:
        list: One 4x4 numpy array per light.
    """
    elevations = elevation * np.ones(n_lights)
    azimuths = 2 * np.pi * np.arange(n_lights) / n_lights
    offset = make_translation(torch.tensor([0, 0, dist]))
    return [
        (make_rotation(rx=-elev, ry=azim, order="xyz") @ offset).numpy()
        for azim, elev in zip(azimuths, elevations)
    ]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def make_translation(t):
    """Return the 4x4 pose with identity rotation and translation ``t``."""
    identity_rotation = torch.eye(3)
    return make_4x4_pose(identity_rotation, t)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def make_rotation(rx=0, ry=0, rz=0, order="xyz"):
    """
    Build a 4x4 pose (zero translation) from rotations about the x, y and z axes.

    Args:
        rx: Rotation angle about the x-axis, in radians.
        ry: Rotation angle about the y-axis, in radians.
        rz: Rotation angle about the z-axis, in radians.
        order (str): Order in which the axis rotations are applied; a permutation
            of ``"xyz"``. For ``"xyz"`` the result is ``Rz @ Ry @ Rx``.
    Returns:
        4x4 tensor produced by ``make_4x4_pose``.
    Raises:
        ValueError: If ``order`` is not one of the six supported permutations.
    """
    if order not in ("xyz", "xzy", "yxz", "yzx", "zyx", "zxy"):
        # Previously an unknown order fell through every branch and failed
        # later with an unbound local; fail early with a clear message instead.
        raise ValueError(f"Unsupported rotation order: {order!r}")

    axis_rotations = {"x": rotx(rx), "y": roty(ry), "z": rotz(rz)}
    first, second, third = order
    # The rotation applied first is the right-most factor of the product.
    R = axis_rotations[third] @ axis_rotations[second] @ axis_rotations[first]
    return make_4x4_pose(R, torch.zeros(3))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def make_4x4_pose(R, t):
    """
    Assemble homogeneous poses from rotations and translations.

    :param R (*, 3, 3) rotation part
    :param t (*, 3) translation part
    return (*, 4, 4) pose whose last row is [0, 0, 0, 1]
    """
    batch_shape = R.shape[:-2]
    translation_col = t.view(*batch_shape, 3, 1)
    top_rows = torch.cat([R, translation_col], dim=-1)

    # Broadcast the constant last row over the batch dimensions.
    last_row = torch.tensor([0, 0, 0, 1], device=R.device)
    last_row = last_row.reshape((1,) * len(batch_shape) + (1, 4))
    last_row = last_row.expand(*batch_shape, 1, 4)

    return torch.cat([top_rows, last_row], dim=-2)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def rotx(theta):
    """Return the 3x3 float32 rotation matrix for angle ``theta`` (radians) about the x-axis."""
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rows = [
        [1, 0, 0],
        [0, cos_t, -sin_t],
        [0, sin_t, cos_t],
    ]
    return torch.tensor(rows, dtype=torch.float32)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def roty(theta):
    """Return the 3x3 float32 rotation matrix for angle ``theta`` (radians) about the y-axis."""
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rows = [
        [cos_t, 0, sin_t],
        [0, 1, 0],
        [-sin_t, 0, cos_t],
    ]
    return torch.tensor(rows, dtype=torch.float32)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def rotz(theta):
    """Return the 3x3 float32 rotation matrix for angle ``theta`` (radians) about the z-axis."""
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rows = [
        [cos_t, -sin_t, 0],
        [sin_t, cos_t, 0],
        [0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def create_raymond_lights() -> List[pyrender.Node]:
    """
    Return raymond light nodes for the scene.

    Three white directional lights (intensity 1.0) are created. Each light's
    local z axis is the unit vector given by the spherical angles
    theta = pi/6 (measured from +z) and phi in {0, 2*pi/3, 4*pi/3}.
    """
    polar_angles = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
    azimuths = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])

    def _light_pose(phi, theta):
        # Unit direction from spherical coordinates becomes the z axis.
        z_axis = np.array([
            np.sin(theta) * np.cos(phi),
            np.sin(theta) * np.sin(phi),
            np.cos(theta),
        ])
        z_axis = z_axis / np.linalg.norm(z_axis)

        # Pick an x axis perpendicular to z in the xy plane; fall back to +x
        # when z has no xy component.
        x_axis = np.array([-z_axis[1], z_axis[0], 0.0])
        if np.linalg.norm(x_axis) == 0:
            x_axis = np.array([1.0, 0.0, 0.0])
        x_axis = x_axis / np.linalg.norm(x_axis)
        y_axis = np.cross(z_axis, x_axis)

        pose = np.eye(4)
        pose[:3, :3] = np.c_[x_axis, y_axis, z_axis]
        return pose

    return [
        pyrender.Node(
            light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
            matrix=_light_pose(phi, theta),
        )
        for phi, theta in zip(azimuths, polar_angles)
    ]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class Renderer:
    """Offscreen pyrender wrapper that renders meshes sharing one face topology."""

    def __init__(self, cfg: CfgNode, faces: torch.Tensor):
        """
        Wrapper around the pyrender renderer to render meshes.
        Args:
            cfg (CfgNode): Model config file. ``MODEL.IMAGE_SIZE`` is read here;
                ``MODEL.IMAGE_MEAN`` / ``MODEL.IMAGE_STD`` are read when rendering on a crop.
            faces (torch.Tensor): Tensor of shape (F, 3) containing the mesh faces.
                It is moved to the CPU and stored as a NumPy array.
        """
        self.cfg = cfg
        # The default focal length is keyed on the face count of the mesh.
        self.focal_length = 1000. if faces.shape[0] == 7774 else 2167.
        self.img_res = cfg.MODEL.IMAGE_SIZE

        self.camera_center = [self.img_res // 2, self.img_res // 2]
        self.faces = faces.cpu().numpy()

    @staticmethod
    def _render_scene(scene, width, height):
        """Render ``scene`` offscreen; return (RGBA float32 in [0, 1], depth)."""
        renderer = pyrender.OffscreenRenderer(viewport_width=width,
                                              viewport_height=height,
                                              point_size=1.0)
        try:
            color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
        finally:
            # Free the offscreen renderer even when rendering raises, so a
            # failed call does not leak it.
            renderer.delete()
        return color.astype(np.float32) / 255.0, rend_depth

    def __call__(self,
                 vertices: np.array,
                 camera_translation: np.array,
                 image: torch.Tensor,
                 full_frame: bool = False,
                 imgname: Optional[str] = None,
                 side_view=False, rot_angle=90,
                 mesh_base_color=(1.0, 1.0, 0.9),
                 scene_bg_color=(0, 0, 0),
                 return_rgba=False,
                 depth=False,
                 focal_length: Optional[float] = None,
                 ) -> np.array:
        """
        Render meshes on input image
        Args:
            vertices (np.array): Array of shape (V, 3) containing the mesh vertices.
            camera_translation (np.array): Array of shape (3,) with the camera translation.
            image (torch.Tensor): Tensor of shape (3, H, W) containing the image crop with normalized pixel values.
            full_frame (bool): If True, then render on the full image.
            imgname (Optional[str]): Contains the original image filename. Used only if full_frame == True.
            side_view (bool): If True, rotate the mesh by ``rot_angle`` degrees about the y axis
                and return the render without blending it over the image.
            rot_angle: Side-view rotation angle in degrees.
            mesh_base_color: RGB base colour of the mesh material.
            scene_bg_color: RGB background colour of the scene (alpha is set to 0).
            return_rgba (bool): If True, return the RGBA render instead of the composite.
            depth (bool): If True, return the rendered depth map; takes precedence over ``return_rgba``.
            focal_length (Optional[float]): Custom focal length. If None, uses self.focal_length.
        Returns:
            The depth map, the RGBA render, or a float32 RGB image, depending on the flags above.
        Raises:
            ValueError: If ``full_frame`` is True and ``imgname`` is None.
            FileNotFoundError: If ``imgname`` cannot be read by OpenCV.
        """

        if full_frame:
            if imgname is None:
                raise ValueError("imgname must be provided when full_frame=True")
            image = cv2.imread(imgname)
            # cv2.imread signals failure by returning None instead of raising.
            if image is None:
                raise FileNotFoundError(f"Could not read image file: {imgname}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
        else:
            # Undo the input normalisation (x * std + mean) and go CHW -> HWC.
            image = (image.clone()) * (torch.tensor(self.cfg.MODEL.IMAGE_STD, device=image.device).reshape(3, 1, 1))
            image = image + torch.tensor(self.cfg.MODEL.IMAGE_MEAN, device=image.device).reshape(3, 1, 1)
            image = image.permute(1, 2, 0).cpu().numpy()

        # Use custom focal length if provided, otherwise use default
        focal_length_to_use = focal_length if focal_length is not None else self.focal_length

        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            baseColorFactor=(*mesh_base_color, 1.0))

        # Work on a copy so the caller's array is not modified by the x flip.
        camera_translation_local = camera_translation.copy()
        camera_translation_local[0] *= -1.

        mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
        if side_view:
            rot = trimesh.transformations.rotation_matrix(
                np.radians(rot_angle), [0, 1, 0])
            mesh.apply_transform(rot)
        # Always applied: 180 degree rotation about the x axis.
        rot = trimesh.transformations.rotation_matrix(
            np.radians(180), [1, 0, 0])
        mesh.apply_transform(rot)
        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)

        scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(mesh, 'mesh')

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_translation_local
        camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
        camera = pyrender.IntrinsicsCamera(fx=focal_length_to_use, fy=focal_length_to_use,
                                           cx=camera_center[0], cy=camera_center[1], zfar=1e12)
        scene.add(camera, pose=camera_pose)

        light_nodes = create_raymond_lights()
        for node in light_nodes:
            scene.add_node(node)

        color, rend_depth = self._render_scene(scene, image.shape[1], image.shape[0])

        if depth:
            return rend_depth

        if return_rgba:
            return color

        # Pixels with positive depth are covered by the mesh.
        valid_mask = (rend_depth > 0).astype(np.float32)[:, :, np.newaxis]
        if not side_view:
            output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image)
        else:
            output_img = color[:, :, :3]

        output_img = output_img.astype(np.float32)
        return output_img

    def vertices_to_trimesh(self, vertices, camera_translation, mesh_base_color=(1.0, 1.0, 0.9),
                            rot_axis=(1, 0, 0), rot_angle=0):
        """
        Build a vertex-coloured trimesh from vertices offset by ``camera_translation``.

        The mesh is rotated by ``rot_angle`` degrees about ``rot_axis`` and then
        by 180 degrees about the x axis.
        """
        vertex_colors = np.array([(*mesh_base_color, 1.0)] * vertices.shape[0])
        mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces.copy(), vertex_colors=vertex_colors)

        rot = trimesh.transformations.rotation_matrix(
            np.radians(rot_angle), rot_axis)
        mesh.apply_transform(rot)

        rot = trimesh.transformations.rotation_matrix(
            np.radians(180), [1, 0, 0])
        mesh.apply_transform(rot)
        return mesh

    def render_rgba(
            self,
            vertices: np.array,
            cam_t=None,
            rot=None,
            rot_axis=(1, 0, 0),
            rot_angle=0,
            camera_z=3,
            mesh_base_color=(1.0, 1.0, 0.9),
            scene_bg_color=(0, 0, 0),
            render_res=(256, 256),
            focal_length=None,
    ):
        """
        Render a single mesh to an RGBA float32 image with values in [0, 1].

        Args:
            vertices (np.array): Mesh vertices.
            cam_t: Optional camera translation; its x component is negated on a copy.
                If None, the camera is placed at ``z = camera_z * focal_length / render_res[1]``.
            rot: Unused; kept so existing callers keep working.
            rot_axis, rot_angle: Rotation forwarded to ``vertices_to_trimesh``.
            camera_z: Scale of the default camera distance when ``cam_t`` is None.
            mesh_base_color: RGB vertex colour of the mesh.
            scene_bg_color: RGB background colour of the scene (alpha is set to 0).
            render_res: (width, height) of the viewport.
            focal_length: Custom focal length. If None, uses self.focal_length.
        """
        focal_length = focal_length if focal_length is not None else self.focal_length

        if cam_t is not None:
            camera_translation = cam_t.copy()
            camera_translation[0] *= -1.
        else:
            camera_translation = np.array([0, 0, camera_z * focal_length / render_res[1]])

        # The mesh stays at the origin; the translation is applied to the camera.
        mesh = self.vertices_to_trimesh(vertices, np.array([0, 0, 0]), mesh_base_color, rot_axis, rot_angle)
        mesh = pyrender.Mesh.from_trimesh(mesh)

        scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(mesh, 'mesh')

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_translation
        camera_center = [render_res[0] / 2., render_res[1] / 2.]
        camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
                                           cx=camera_center[0], cy=camera_center[1], zfar=1e12)

        # Create camera node and add it to pyRender scene
        camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
        scene.add_node(camera_node)
        self.add_point_lighting(scene, camera_node)
        self.add_lighting(scene, camera_node)

        light_nodes = create_raymond_lights()
        for node in light_nodes:
            scene.add_node(node)

        color, _ = self._render_scene(scene, render_res[0], render_res[1])
        return color

    def render_rgba_multiple(
            self,
            vertices: List[np.array],
            cam_t: List[np.array],
            rot_axis=(1, 0, 0),
            rot_angle=0,
            mesh_base_color=(1.0, 1.0, 0.9),
            scene_bg_color=(0, 0, 0),
            render_res=(256, 256),
            focal_length=None,
    ):
        """
        Render several meshes into one RGBA float32 image with values in [0, 1].

        Each mesh is offset by its own entry of ``cam_t``; the camera stays at
        the identity pose.

        Args:
            vertices (List[np.array]): One vertex array per mesh.
            cam_t (List[np.array]): One translation per mesh, paired with ``vertices`` by ``zip``.
            rot_axis, rot_angle: Rotation forwarded to ``vertices_to_trimesh``.
            mesh_base_color: RGB vertex colour shared by all meshes.
            scene_bg_color: RGB background colour of the scene (alpha is set to 0).
            render_res: (width, height) of the viewport.
            focal_length: Custom focal length. If None, uses self.focal_length.
        """
        mesh_list = [pyrender.Mesh.from_trimesh(
            self.vertices_to_trimesh(vvv, ttt.copy(), mesh_base_color, rot_axis, rot_angle)) for
            vvv, ttt in zip(vertices, cam_t)]

        scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        for i, mesh in enumerate(mesh_list):
            scene.add(mesh, f'mesh_{i}')

        camera_pose = np.eye(4)
        camera_center = [render_res[0] / 2., render_res[1] / 2.]
        focal_length = focal_length if focal_length is not None else self.focal_length
        camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
                                           cx=camera_center[0], cy=camera_center[1], zfar=1e12)

        # Create camera node and add it to pyRender scene
        camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
        scene.add_node(camera_node)
        self.add_point_lighting(scene, camera_node)
        self.add_lighting(scene, camera_node)

        light_nodes = create_raymond_lights()
        for node in light_nodes:
            scene.add_node(node)

        color, _ = self._render_scene(scene, render_res[0], render_res[1])
        return color

    def add_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
        """Add directional lights at the poses from ``get_light_poses()`` plus the camera pose itself."""
        light_poses = get_light_poses()
        light_poses.append(np.eye(4))
        cam_pose = scene.get_pose(cam_node)
        for i, pose in enumerate(light_poses):
            # Light poses are expressed relative to the camera.
            matrix = cam_pose @ pose
            node = pyrender.Node(
                name=f"light-{i:02d}",
                light=pyrender.DirectionalLight(color=color, intensity=intensity),
                matrix=matrix,
            )
            if scene.has_node(node):
                continue
            scene.add_node(node)

    def add_point_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
        """Add point lights at the poses from ``get_light_poses(dist=0.5)`` plus the camera pose itself."""
        light_poses = get_light_poses(dist=0.5)
        light_poses.append(np.eye(4))
        cam_pose = scene.get_pose(cam_node)
        for i, pose in enumerate(light_poses):
            # Light poses are expressed relative to the camera.
            matrix = cam_pose @ pose
            node = pyrender.Node(
                name=f"plight-{i:02d}",
                light=pyrender.PointLight(color=color, intensity=intensity),
                matrix=matrix,
            )
            if scene.has_node(node):
                continue
            scene.add_node(node)
|
prima/utils/rich_utils.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Sequence
|
| 12 |
+
|
| 13 |
+
import rich
|
| 14 |
+
import rich.syntax
|
| 15 |
+
import rich.tree
|
| 16 |
+
from hydra.core.hydra_config import HydraConfig
|
| 17 |
+
from omegaconf import DictConfig, OmegaConf, open_dict
|
| 18 |
+
from pytorch_lightning.utilities import rank_zero_only
|
| 19 |
+
from rich.prompt import Prompt
|
| 20 |
+
|
| 21 |
+
from . import pylogger
|
| 22 |
+
|
| 23 |
+
log = pylogger.get_pylogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "datamodule",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.
        print_order (Sequence[str], optional): Determines in what order config components are printed.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
        save_to_file (bool, optional): Whether to export config to the hydra output folder
            (written to ``cfg.paths.output_dir / "config_tree.log"``).
    """

    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # add fields from `print_order` to queue
    for field in print_order:
        if field in cfg:
            queue.append(field)
        else:
            log.warning(
                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
            )

    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    # print config tree
    rich.print(tree)

    # save config tree to file
    if save_to_file:
        # Explicit UTF-8: the platform default encoding is locale dependent and
        # may be unable to encode non-ASCII text in the rendered config.
        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file:
            rich.print(tree, file=file)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    Args:
        cfg (DictConfig): Configuration composed by Hydra; ``cfg.tags`` is set in place when missing.
        save_to_file (bool, optional): Whether to write the tags to
            ``cfg.paths.output_dir / "tags.log"``.

    Raises:
        ValueError: If no tags are configured and the Hydra job config has an "id" entry.
    """

    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip first, then filter: "a, ,b" or a trailing ", " must not leave
        # empty-string tags behind.
        tags = [tag for tag in (part.strip() for part in tags.split(",")) if tag]

        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        # Explicit UTF-8 so non-ASCII tags do not depend on the locale encoding.
        with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
            rich.print(cfg.tags, file=file)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == "__main__":
    # Manual check: compose the training config and print it as a tree.
    from hydra import compose, initialize

    with initialize(version_base="1.2", config_path="../../configs_hydra"):
        print_config_tree(
            compose(config_name="train.yaml", return_hydra_config=False, overrides=[]),
            resolve=False,
            save_to_file=False,
        )
|
prima/utils/weights.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
|
| 3 |
+
|
| 4 |
+
Official implementation of the paper:
|
| 5 |
+
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
|
| 6 |
+
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
|
| 7 |
+
Licensed under a modified MIT license
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import shutil
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Iterable, Optional, Sequence, Union
|
| 16 |
+
|
| 17 |
+
# Hugging Face repository that hosts the demo assets; DEFAULT_HF_REPO_ID is an alias.
HF_REPO_ID = "MLAdaptiveIntelligence/PRIMA"
DEFAULT_HF_REPO_ID = HF_REPO_ID

# File names of the individual assets inside the repository.
BACKBONE_ASSET_PATH = "amr_vitbb.pth"
STAGE1_CONFIG_ASSET_PATH = "config_s1_HYDRA.yaml"
STAGE1_CHECKPOINT_ASSET_PATH = "s1ckpt_inference.ckpt"
STAGE3_CONFIG_ASSET_PATH = "config_s3_HYDRA.yaml"
STAGE3_CHECKPOINT_ASSET_PATH = "s3ckpt_inference.ckpt"
SMAL_ASSET_PATHS = [
    "my_smpl_00781_4_all.pkl",
    "my_smpl_data_00781_4_all.pkl",
    "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl",
]

# Per stage: (remote config name, remote checkpoint name, local checkpoint name).
STAGE_ASSETS = {
    "PRIMAS1": (
        STAGE1_CONFIG_ASSET_PATH,
        STAGE1_CHECKPOINT_ASSET_PATH,
        STAGE1_CHECKPOINT_ASSET_PATH,
    ),
    "PRIMAS3": (
        STAGE3_CONFIG_ASSET_PATH,
        STAGE3_CHECKPOINT_ASSET_PATH,
        STAGE3_CHECKPOINT_ASSET_PATH,
    ),
}

# Checkpoint location of each stage relative to the data directory.
STAGE_CHECKPOINTS = {
    stage: Path(stage) / "checkpoints" / names[2]
    for stage, names in STAGE_ASSETS.items()
}

# Default checkpoint locations under the conventional "data" directory.
DEFAULT_STAGE1_CHECKPOINT = Path("data") / STAGE_CHECKPOINTS["PRIMAS1"]
DEFAULT_STAGE3_CHECKPOINT = Path("data") / STAGE_CHECKPOINTS["PRIMAS3"]

PathLike = Union[str, Path]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _resolve_hf_repo_id(hf_repo_id: Optional[str]) -> str:
|
| 48 |
+
return hf_repo_id or os.environ.get("PRIMA_HF_REPO_ID", HF_REPO_ID)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _default_checkpoint_path(data_dir: PathLike = "data") -> Path:
    """Return the PRIMAS1 checkpoint location inside ``data_dir``."""
    root = Path(data_dir)
    return root / STAGE_CHECKPOINTS["PRIMAS1"]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _config_path_for_checkpoint(checkpoint_path: PathLike) -> Path:
|
| 56 |
+
checkpoint_path = Path(checkpoint_path)
|
| 57 |
+
return checkpoint_path.parent.parent / ".hydra" / "config.yaml"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _stage_for_checkpoint(checkpoint_path: PathLike) -> Optional[str]:
|
| 61 |
+
checkpoint_path = Path(checkpoint_path)
|
| 62 |
+
if len(checkpoint_path.parents) < 2:
|
| 63 |
+
return None
|
| 64 |
+
stage_name = checkpoint_path.parent.parent.name
|
| 65 |
+
stage_assets = STAGE_ASSETS.get(stage_name)
|
| 66 |
+
if stage_assets is None:
|
| 67 |
+
return None
|
| 68 |
+
_, _, checkpoint_name = stage_assets
|
| 69 |
+
if checkpoint_path.name != checkpoint_name:
|
| 70 |
+
return None
|
| 71 |
+
return stage_name
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _download_file(
    hf_repo_id: str,
    remote_filename: str,
    destination: Path,
    force_download: bool = False,
) -> None:
    """Download one file from a Hugging Face repo and place it at ``destination``.

    The file is fetched into ``destination.parent``; if the downloaded path
    differs from ``destination`` (e.g. the remote name is different), it is
    moved there, replacing any existing file.

    Args:
        hf_repo_id: Repository id passed to ``hf_hub_download``.
        remote_filename: File name inside the repository.
        destination: Final local path of the file.
        force_download: Forwarded to ``hf_hub_download``.

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    import inspect

    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        raise ImportError(
            "huggingface_hub is required to download PRIMA demo assets. "
            "Install it with: pip install huggingface_hub\n"
            "Or download the assets manually and pass a local checkpoint path."
        ) from None

    destination.parent.mkdir(parents=True, exist_ok=True)
    download_kwargs = {
        "repo_id": hf_repo_id,
        "filename": remote_filename,
        "local_dir": str(destination.parent),
        "force_download": force_download,
    }
    # NOTE(review): `local_dir_use_symlinks` is deprecated in newer
    # huggingface_hub releases; only pass it when the installed version still
    # accepts it so the call keeps working if the parameter is removed.
    if "local_dir_use_symlinks" in inspect.signature(hf_hub_download).parameters:
        download_kwargs["local_dir_use_symlinks"] = False
    downloaded = hf_hub_download(**download_kwargs)

    downloaded_path = Path(downloaded).resolve()
    target = destination.resolve()
    if downloaded_path != target:
        if target.exists():
            target.unlink()
        shutil.move(str(downloaded_path), str(target))
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _validate_torch_checkpoint(path: Path) -> None:
|
| 106 |
+
import inspect
|
| 107 |
+
import pickle
|
| 108 |
+
import zipfile
|
| 109 |
+
|
| 110 |
+
import torch
|
| 111 |
+
|
| 112 |
+
if zipfile.is_zipfile(path):
|
| 113 |
+
with zipfile.ZipFile(path) as checkpoint_zip:
|
| 114 |
+
corrupt_member = checkpoint_zip.testzip()
|
| 115 |
+
if corrupt_member is not None:
|
| 116 |
+
raise RuntimeError(
|
| 117 |
+
f"Checkpoint file is invalid or incomplete: {path}\n"
|
| 118 |
+
f"Corrupt archive member: {corrupt_member}\n"
|
| 119 |
+
"Please redownload the checkpoint and try again."
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
supports_weights_only = "weights_only" in inspect.signature(torch.load).parameters
|
| 123 |
+
load_kwargs = {"map_location": "cpu"}
|
| 124 |
+
if supports_weights_only:
|
| 125 |
+
load_kwargs["weights_only"] = True
|
| 126 |
+
|
| 127 |
+
try:
|
| 128 |
+
torch.load(path, **load_kwargs)
|
| 129 |
+
except pickle.UnpicklingError as exc:
|
| 130 |
+
message = str(exc)
|
| 131 |
+
if (
|
| 132 |
+
supports_weights_only
|
| 133 |
+
and "Weights only load failed" in message
|
| 134 |
+
and ("Unsupported global" in message or "Unsupported class" in message)
|
| 135 |
+
):
|
| 136 |
+
return
|
| 137 |
+
raise RuntimeError(
|
| 138 |
+
f"Checkpoint file is invalid or incomplete: {path}\n"
|
| 139 |
+
"Downloaded checkpoint is not loadable. "
|
| 140 |
+
"Please verify the uploaded Hugging Face file and try again."
|
| 141 |
+
) from exc
|
| 142 |
+
except Exception as exc:
|
| 143 |
+
raise RuntimeError(
|
| 144 |
+
f"Checkpoint file is invalid or incomplete: {path}\n"
|
| 145 |
+
"Downloaded checkpoint is not loadable. "
|
| 146 |
+
"Please verify the uploaded Hugging Face file and try again."
|
| 147 |
+
) from exc
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _ensure_backbone(data_dir: Path, force: bool, hf_repo_id: str) -> None:
|
| 151 |
+
target = data_dir / "amr_vitbb.pth"
|
| 152 |
+
if target.exists() and not force:
|
| 153 |
+
print(f"[skip] {target} already exists")
|
| 154 |
+
return
|
| 155 |
+
|
| 156 |
+
print("[download] pretrained backbone")
|
| 157 |
+
_download_file(hf_repo_id, BACKBONE_ASSET_PATH, target, force_download=force)
|
| 158 |
+
print(f"[ok] {target}")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _ensure_smal_assets(data_dir: Path, force: bool, hf_repo_id: str) -> None:
    """Download every file in ``SMAL_ASSET_PATHS`` into ``data_dir / "smal"`` unless all are present."""
    required = [Path(asset).name for asset in SMAL_ASSET_PATHS]
    smal_dir = data_dir / "smal"
    all_present = smal_dir.exists() and all((smal_dir / name).exists() for name in required)
    if all_present and not force:
        print("[skip] SMAL files already exist")
        return

    print("[download] SMAL assets")
    # All assets are fetched again, including ones that already exist.
    for asset_path in SMAL_ASSET_PATHS:
        local_file = smal_dir / Path(asset_path).name
        _download_file(hf_repo_id, asset_path, local_file, force_download=force)
    print(f"[ok] {smal_dir}")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _ensure_stage_assets(
    stage_name: str,
    data_dir: Path,
    force: bool,
    hf_repo_id: str,
    validate_existing: bool = True,
) -> None:
    """Make sure the config and checkpoint of one stage exist under ``data_dir / stage_name``.

    Existing files are kept unless ``force`` is set. With ``validate_existing``
    an existing checkpoint is validated, and only the checkpoint is fetched
    again when validation fails. The checkpoint is validated after the
    download phase.

    Raises:
        ValueError: If ``stage_name`` is not a key of ``STAGE_ASSETS``.
    """
    if stage_name not in STAGE_ASSETS:
        known = ", ".join(sorted(STAGE_ASSETS))
        raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}")

    config_asset_path, checkpoint_asset_path, checkpoint_name = STAGE_ASSETS[stage_name]
    stage_dir = data_dir / stage_name
    config_target = stage_dir / ".hydra" / "config.yaml"
    checkpoint_target = stage_dir / "checkpoints" / checkpoint_name

    redownload_checkpoint = False
    if config_target.exists() and checkpoint_target.exists() and not force:
        checkpoint_usable = True
        if validate_existing:
            try:
                _validate_torch_checkpoint(checkpoint_target)
            except RuntimeError:
                print(f"[warn] {stage_name} checkpoint is incomplete, redownloading checkpoint only.")
                checkpoint_usable = False
        if checkpoint_usable:
            print(f"[skip] {stage_name} assets already exist")
            return
        redownload_checkpoint = True

    print(f"[download] {stage_name} assets")
    for folder in (config_target.parent, checkpoint_target.parent):
        folder.mkdir(parents=True, exist_ok=True)

    if force or not config_target.exists():
        _download_file(hf_repo_id, config_asset_path, config_target, force_download=force)

    # Drop the broken checkpoint before fetching a fresh copy.
    if redownload_checkpoint and checkpoint_target.exists():
        checkpoint_target.unlink()

    refresh_checkpoint = force or redownload_checkpoint
    if refresh_checkpoint or not checkpoint_target.exists():
        _download_file(
            hf_repo_id,
            checkpoint_asset_path,
            checkpoint_target,
            force_download=refresh_checkpoint,
        )
    _validate_torch_checkpoint(checkpoint_target)
    print(f"[ok] {stage_dir}")
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def _normalize_stages(stages: Union[str, Iterable[str]]) -> Sequence[str]:
|
| 225 |
+
if isinstance(stages, str):
|
| 226 |
+
return (stages,)
|
| 227 |
+
return tuple(stages)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _verify_assets(data_dir: Path, stages: Sequence[str]) -> None:
|
| 231 |
+
required_paths = [
|
| 232 |
+
data_dir / "smal" / "my_smpl_00781_4_all.pkl",
|
| 233 |
+
data_dir / "smal" / "my_smpl_data_00781_4_all.pkl",
|
| 234 |
+
data_dir / "smal" / "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl",
|
| 235 |
+
data_dir / "amr_vitbb.pth",
|
| 236 |
+
]
|
| 237 |
+
for stage_name in stages:
|
| 238 |
+
if stage_name not in STAGE_ASSETS:
|
| 239 |
+
known = ", ".join(sorted(STAGE_ASSETS))
|
| 240 |
+
raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}")
|
| 241 |
+
_, _, checkpoint_name = STAGE_ASSETS[stage_name]
|
| 242 |
+
stage_dir = data_dir / stage_name
|
| 243 |
+
required_paths.extend(
|
| 244 |
+
[
|
| 245 |
+
stage_dir / ".hydra" / "config.yaml",
|
| 246 |
+
stage_dir / "checkpoints" / checkpoint_name,
|
| 247 |
+
]
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
missing = [p for p in required_paths if not p.exists()]
|
| 251 |
+
if missing:
|
| 252 |
+
raise FileNotFoundError("Missing required files:\n" + "\n".join(str(p) for p in missing))
|
| 253 |
+
|
| 254 |
+
for stage_name in stages:
|
| 255 |
+
_, _, checkpoint_name = STAGE_ASSETS[stage_name]
|
| 256 |
+
_validate_torch_checkpoint(data_dir / stage_name / "checkpoints" / checkpoint_name)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _ensure_assets_for_checkpoint(
    checkpoint_path: PathLike,
    force: bool = False,
    hf_repo_id: Optional[str] = None,
) -> None:
    """Ensure the assets needed by ``checkpoint_path`` are available.

    For a path in a standard stage layout, the SMAL files, the backbone and
    the stage assets are ensured under the directory three levels above the
    checkpoint. For any other path nothing is downloaded: the checkpoint and
    its ``.hydra/config.yaml`` must already exist.

    Raises:
        FileNotFoundError: For a non-standard path when the checkpoint or
            config is missing, or when ``force`` is set.
    """
    checkpoint_path = Path(checkpoint_path)
    config_path = _config_path_for_checkpoint(checkpoint_path)
    stage_name = _stage_for_checkpoint(checkpoint_path)

    if stage_name is not None:
        data_dir = checkpoint_path.parent.parent.parent
        repo_id = _resolve_hf_repo_id(hf_repo_id)
        print(f"[download] Ensuring PRIMA demo assets under {data_dir}")
        _ensure_smal_assets(data_dir, force=force, hf_repo_id=repo_id)
        _ensure_backbone(data_dir, force=force, hf_repo_id=repo_id)
        _ensure_stage_assets(
            stage_name,
            data_dir,
            force=force,
            hf_repo_id=repo_id,
            validate_existing=False,
        )
        return

    # Custom layout: usable only if both files are already on disk.
    # NOTE(review): with force=True this raises even when both files exist.
    if checkpoint_path.exists() and config_path.exists() and not force:
        print(f"[skip] Using local PRIMA checkpoint {checkpoint_path}")
        return

    raise FileNotFoundError(
        "Missing checkpoint or config for a custom path:\n"
        f" checkpoint: {checkpoint_path}\n"
        f" config: {config_path}\n"
        "Auto-download supports the standard PRIMA demo layouts only:\n"
        " data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt\n"
        " data/PRIMAS3/checkpoints/s3ckpt_inference.ckpt\n"
        "Pass one of those paths, or download/copy your custom checkpoint manually."
    )
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def ensure_demo_assets(
    data_dir: PathLike = "data",
    *,
    stages: Union[str, Iterable[str]] = ("PRIMAS1",),
    force: bool = False,
    hf_repo_id: Optional[str] = None,
) -> None:
    """Ensure PRIMA demo assets exist in the expected ``data/`` layout.

    Creates ``data_dir`` if needed, ensures the SMAL files, the backbone and
    the assets of every requested stage, then verifies the result.
    """
    root = Path(data_dir).resolve()
    root.mkdir(parents=True, exist_ok=True)
    repo_id = _resolve_hf_repo_id(hf_repo_id)
    selected_stages = _normalize_stages(stages)

    # Assets shared by all stages first, then the per-stage files.
    _ensure_smal_assets(root, force=force, hf_repo_id=repo_id)
    _ensure_backbone(root, force=force, hf_repo_id=repo_id)
    for stage in selected_stages:
        _ensure_stage_assets(stage, root, force=force, hf_repo_id=repo_id)

    _verify_assets(root, selected_stages)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def resolve_prima_checkpoint_path(
    checkpoint_path: PathLike = "",
    *,
    data_dir: PathLike = "data",
    auto_download: bool = True,
    hf_repo_id: Optional[str] = None,
    force: bool = False,
) -> str:
    """Return a PRIMA checkpoint path, downloading default demo assets if needed.

    An empty ``checkpoint_path`` selects the default checkpoint inside
    ``data_dir``. With ``auto_download`` the assets for the chosen path are
    ensured before the path is returned as a string.
    """
    if checkpoint_path:
        resolved = Path(checkpoint_path)
    else:
        resolved = _default_checkpoint_path(data_dir)

    if auto_download:
        _ensure_assets_for_checkpoint(resolved, force=force, hf_repo_id=hf_repo_id)

    return str(resolved)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
# Names exported by `from ... import *`; the underscore-prefixed helpers in
# this module are intentionally left out.
__all__ = [
    "DEFAULT_HF_REPO_ID",
    "DEFAULT_STAGE1_CHECKPOINT",
    "DEFAULT_STAGE3_CHECKPOINT",
    "HF_REPO_ID",
    "ensure_demo_assets",
    "resolve_prima_checkpoint_path",
]
|