diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c9aeeafbf06c1e9208bc88aa7befba9918e8765e --- /dev/null +++ b/.gitattributes @@ -0,0 +1,5 @@ +# Hugging Face Hub stores these via Git LFS / Xet (plain PNG/JPG in git are rejected on push). +demo_data/*.png filter=lfs diff=lfs merge=lfs -text +demo_data/*.jpg filter=lfs diff=lfs merge=lfs -text +demo_data/*.jpeg filter=lfs diff=lfs merge=lfs -text +images/*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/check-headers.yml b/.github/workflows/check-headers.yml new file mode 100644 index 0000000000000000000000000000000000000000..b44c56c2ebfde8f94a341574509bffd8eb03200f --- /dev/null +++ b/.github/workflows/check-headers.yml @@ -0,0 +1,36 @@ +--- + name: Check File Headers + + on: + push: + branches: [main] + pull_request: + branches: [main] + + jobs: + check-headers: + name: Check Python file headers + runs-on: ubuntu-latest + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Check headers + run: | + python scripts/update_headers.py --check + continue-on-error: false + + - name: Provide fix instructions + if: failure() + run: | + echo "::error::Some files are missing proper headers." + echo "To fix this, run: python scripts/update_headers.py" + echo "Then commit the changes." \ No newline at end of file diff --git a/.github/workflows/codespell.yml b/.github/workflows/codespell.yml new file mode 100644 index 0000000000000000000000000000000000000000..d7536076209626bdde3546f8e63f227b65318840 --- /dev/null +++ b/.github/workflows/codespell.yml @@ -0,0 +1,21 @@ +--- + name: Codespell + + on: + push: + branches: [main] + pull_request: + branches: [main] + + jobs: + codespell: + name: Check for spelling errors + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Codespell + uses: codespell-project/actions-codespell@v1 + with: + ignore_words_list: prima-animal, mpjpe, uvd, xyz, hm36, cpn, dbb \ No newline at end of file diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml new file mode 100644 index 0000000000000000000000000000000000000000..e45b97fedf4fbdd335f907541871bdd8a7125a82 --- /dev/null +++ b/.github/workflows/release-pypi.yml @@ -0,0 +1,48 @@ +name: Update pypi release + +on: + push: + tags: + - 'v*.*.*' + pull_request: + branches: + - main + types: + - labeled + - opened + - edited + - synchronize + - reopened + +jobs: + release: + runs-on: ubuntu-latest + + steps: + - name: Cache dependencies + id: pip-cache + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install wheel + # NOTE(stes) see https://github.com/pypa/twine/issues/1216#issuecomment-2629069669 + pip install "packaging>=24.2" + + - name: Checkout code + uses: actions/checkout@v3 + + - name: Build and publish to PyPI + if: ${{ github.event_name == 'push' }} + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} + run: | + pip install build twine + python3 -m build + ls dist/ + python3 -m twine upload --verbose dist/* \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f8b5caa1f1c0cc3d6a32d5f30d6669c5b8ffc504 --- /dev/null +++ b/.gitignore @@ -0,0 +1,175 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ +# Vscode +.vscode/ + +# Directory +.gradio/ +demo_out/ +demo_out*/ +data/PRIMA*/ +data/backbone.pth +logs/ +*.pth +*.pkl +datasets/ diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..22ad91f0370e9c35da1d8397eca3ae4b72bc31f8 --- /dev/null +++ b/README.md @@ -0,0 +1,252 @@ +# PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + + +This is the official implementation of the approach described in the preprint: + +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation \ +Xiaohang Yu, Ti Wang, Mackenzie Weygandt Mathis + +![PRIMA teaser](images/teaser.png) + + +--- + + +## 🚀 TL;DR +PRIMA creates a 3D quadruped mesh from a single 2D image. It leverages BioCLIP-based biological priors for robust cross-species shape understanding, then applies test-time adaptation with 2D reprojection and auxiliary keypoint guidance to refine SMAL pose and shape predictions. + +It further can be used to build Quadruped3D, a large-scale pseudo-3D dataset with diverse species and poses. + +PRIMA achieves state-of-the-art results on Animal3D, CtrlAni3D, Quadruped2D, and Animal Kingdom datasets. + +## Installation + +### Install from PyPI + +> Recommended: Python 3.10 and a CUDA-enabled PyTorch installation. + +```bash +conda create -n prima python=3.10 -y +conda activate prima + +# Install PyTorch matching your CUDA (example: CUDA 11.8) +pip install --index-url https://download.pytorch.org/whl/cu118 \ + "torch==2.2.1" "torchvision==0.17.1" "torchaudio==2.2.1" + +# Install chumpy and PyTorch3D +python -m pip install --no-build-isolation \ + "git+https://github.com/mattloper/chumpy.git" +python -m pip install --no-build-isolation \ + "git+https://github.com/facebookresearch/pytorch3d.git" + +# Install PRIMA from PyPI +pip install prima-animal +``` + +`prima-animal` includes demo runtime dependencies used by `demo.py`, `demo_tta.py`, and `app.py` (including Detectron2 and DeepLabCut). + +### Clean install from this repository + +Use these when developing from a **git clone** (not the PyPI wheel). The shell scripts are **non-interactive** (pip uses `--no-input`; `GIT_TERMINAL_PROMPT=0` for git). Put Hugging Face credentials in your environment or git credential helper before pushing the Space. + +**Local (fresh venv, LFS assets, Hub demo weights, smoke test)** — requires **Python 3.10+** +(Gradio 5.1+ / Space-provided Gradio 6.x and `app.py` type hints). On macOS without `python3.10` on your `PATH`, install +`brew install python@3.10` and set `PRIMA_PYTHON=/opt/homebrew/bin/python3.10`. + +```bash +chmod +x scripts/clean_install_local.sh scripts/clean_redeploy_hf_space.sh scripts/deploy_hf_space.sh +PRIMA_PYTHON=/opt/homebrew/bin/python3.10 ./scripts/clean_install_local.sh +``` + +Options: + +- `PRIMA_VENV=.venv ./scripts/clean_install_local.sh --skip-data` — skip the large `setup_demo_data` download if `data/` is already populated. +- `./scripts/clean_install_local.sh --wipe-data --force-data` — delete downloaded `data/` assets and redownload. +- `./scripts/clean_install_local.sh --no-editable` — only `requirements.txt` (no `pip install -e .`); use if editable install fails and you will install the training stack via conda as in the PyPI section above. You still need **Python 3.10+** for Gradio 5.1+. The smoke test sets `PYTHONPATH` to the repo root so `import prima` works without an editable install. +- **`requirements.txt` pins `deeplabcut==3.0.0rc14`** (SuperAnimal PyTorch API). On macOS, `clean_install_local.sh` installs a PyTables wheel first, then DLC 3.x. Full check: `./scripts/test_local_full.sh`. + +After `requirements.txt`, the script runs **`pip install --no-deps -e .`** so the `prima` package is registered without re-resolving `pyproject.toml` (which would pull **Detectron2** from git again). Install Detectron2 separately if needed: `pip install 'git+https://github.com/facebookresearch/detectron2.git'`. + +**Hugging Face Space (full redeploy from your working tree):** + +Requires [Git LFS / Xet](https://huggingface.co/docs/hub/xet/using-xet-storage#git) tooling (`brew install git-lfs git-xet`, `git xet install`, `git lfs install`). Then: + +```bash +./scripts/clean_redeploy_hf_space.sh +``` + +This is equivalent to `./scripts/deploy_hf_space.sh` and force-pushes a fresh snapshot to the Space. + +--- + +## Demo + +### Checkpoints and data + +The demo scripts auto-download their default Stage 1 PRIMA assets from Hugging +Face when the checkpoint or matching Hydra config is missing. If you want to +pre-download all necessary checkpoints and data ahead of time, run: + +```bash +python scripts/setup_demo_data.py --hf-repo-id MLAdaptiveIntelligence/PRIMA +``` + +Approximate default prefetch volume from Hugging Face is ~5.5 GB total +(`s1ckpt_inference.ckpt` ~3 GB + `amr_vitbb.pth` ~2.5 GB + SMAL files). +Expected time is roughly: +- 100 Mbps: ~7-10 minutes +- 300 Mbps: ~2-4 minutes +- 1 Gbps: ~1 minute + +Existing files are reused by default; pass `--force` only if you need to redownload them. If you also need the Stage 3 pretrained model, add `--include-stage3`. + +Expected files in that Hugging Face repo root: +- `my_smpl_00781_4_all.pkl` +- `my_smpl_data_00781_4_all.pkl` +- `walking_toy_symmetric_pose_prior_with_cov_35parts.pkl` +- `amr_vitbb.pth` +- `config_s1_HYDRA.yaml` +- `s1ckpt_inference.ckpt` + +Optional Stage 3 prefetch expects: +- `config_s3_HYDRA.yaml` +- `s3ckpt_inference.ckpt` + +### Demo (without TTA) + +Run animal detection + PRIMA 3D pose/shape inference: + +```bash +bash demo.sh +``` + +Outputs are written to `demo_out/`. Edit `demo.sh` if you want to use a custom +checkpoint path. + +--- + +### Demo (with TTA) + +Run PRIMA inference with test-time adaptation: + +```bash +bash demo_tta.sh +``` + +Outputs are written to `demo_out_tta/` (before/after TTA renders, keypoints, and +optional meshes). Edit `demo_tta.sh` if you want to change the checkpoint, TTA +learning rate, or number of iterations. + +--- + +### Gradio demo + +We also provide a simple Gradio-based web demo for interactive testing in the +browser: + +```bash +python app.py \ + --checkpoint data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt \ + --out_folder demo_out_tta_gradio/ +``` + +This starts a local Gradio app (by default on http://127.0.0.1:7860), where +you can upload images and visualize PRIMA predictions and adaptation results. +The `s1ckpt_inference.ckpt` checkpoint is downloaded automatically if missing. + +`app.py` picks a **demo profile** automatically: + +| | **Local** (`python app.py`) | **Hugging Face Space** | +|--|--|--| +| PRIMA device | GPU if available, else CPU | CPU only | +| Detectron2 | X-101-FPN | R50-FPN (lighter) | +| Default TTA iterations | 30 | 0 (PRIMA-only by default) | +| Save `.obj` meshes | on | off | +| Preload checkpoint at startup | off | on | + +Override for testing: `PRIMA_DEMO_MODE=local` or `PRIMA_DEMO_MODE=space`. + +#### Hugging Face Space (maintainers) + +Demo images under `demo_data/` and `images/teaser.png` are tracked with **Git LFS** +(see `.gitattributes`) so they can be pushed to a Hugging Face Space under the Hub’s +LFS / **Xet** bridge. Install tooling once: + +```bash +brew install git-lfs git-xet +git xet install +git lfs install +``` + +Then from a clean checkout with LFS files present, redeploy the Space (same as `clean_redeploy_hf_space.sh`): + +```bash +./scripts/deploy_hf_space.sh +# or +./scripts/clean_redeploy_hf_space.sh +``` + +The script rsyncs the working tree (not `git archive`) so image files are materialized +before `git add` turns them into LFS blobs. + +--- + + +## Training and Evaluation + +### Dataset Setup + +Download datasets from [Animal3D](https://xujiacong.github.io/Animal3D/), [CtrlAni3D](https://github.com/luoxue-star/AniMer?tab=readme-ov-file#training), Quadruped2D, and [Animal Kingdom](https://drive.google.com/file/d/1dk2a0qB0fbVZ4X6eAgP6VJVXj0rxVfsJ/view?usp=drive_link). For Quadruped2D, download the images from [SuperAnimal-Quadruped80K](https://zenodo.org/records/14016777) and our processed annotations from [here](https://drive.google.com/drive/folders/1eBNboxVwl_eGPoC93zxf-U3hmE6e2f-f?usp=sharing). Put all the datasets under `datasets/`. + +### Training + +Two-stage training script: + +```bash +bash train.sh +``` + +Training outputs are written to `logs/train/runs//`. + + +### Evaluation + +```bash +python eval.py \ + --config data/PRIMAS1/.hydra/config.yaml \ + --checkpoint data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt +``` + +Common values for `--dataset` are controlled by: +- `configs_hydra/experiment/default_val.yaml` + +--- + + +## Acknowledgements + +This release builds on several open-source projects, including: +- [Detectron2](https://github.com/facebookresearch/detectron2) +- [BioCLIP](https://github.com/Imageomics/BioCLIP) +- [AniMer](https://github.com/luoxue-star/AniMer) +- [DeepLabCut](https://github.com/DeepLabCut/DeepLabCut) +- [SAM3DB](https://github.com/facebookresearch/sam-3d-body) + +--- + +## Citation + +If you use this code in your research, please cite our PRIMA paper. + +```bibtex +@misc{yu_prima, + title={PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation}, + author={Xiaohang Yu and Ti Wang and Mackenzie Weygandt Mathis}, +} +``` + +--- + +## Contact + +For issues, please open a GitHub issue in this repository. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..a5408267a3618dac2cb6f8eb9dcb69266d3104f0 --- /dev/null +++ b/app.py @@ -0,0 +1,713 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +"""Gradio demo for PRIMA + SuperAnimal + TTA. + +This script wraps the ``demo_tta.py`` pipeline into an interactive +Gradio interface. The overall logic follows: + +1. Given an input image, run Detectron2 to detect animals. +2. For each detected animal, run PRIMA for 3D pose/shape estimation. +3. Run the fine-tuned DeepLabCut SuperAnimal model to obtain PRIMA 26-keypoint + 2D predictions. +4. Run test-time adaptation (TTA) with user-specified lr and iters. +5. Render and save before/after TTA results and keypoint visualizations. + +""" + +import argparse +import os +import sys +import tempfile +import traceback +from dataclasses import dataclass +from functools import lru_cache +from types import SimpleNamespace +from typing import List, Optional, Tuple +from pathlib import Path + +import cv2 +import gradio as gr +import numpy as np +import torch +import torch.utils.data + +# Space demo on macOS: limit BLAS threads (PyRender + PyTorch on main thread only). +if sys.platform == "darwin" and os.environ.get("SPACE_ID"): + os.environ.setdefault("OMP_NUM_THREADS", "1") + torch.set_num_threads(1) + +# Repo-local minimal ``chumpy`` shim (see ``chumpy/__init__.py``) so SMAL pickles load +# without installing the full chumpy package in Space builds. +_REPO_ROOT = Path(__file__).resolve().parent +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from prima.utils.weights import ( + DEFAULT_HF_REPO_ID, + resolve_prima_checkpoint_path, +) +from prima.utils.detection import select_animal_boxes + + +# Default checkpoint path following README instructions +DEFAULT_CHECKPOINT = str(_REPO_ROOT / "data" / "PRIMAS1" / "checkpoints" / "s1ckpt_inference.ckpt") +DEFAULT_HF_ASSET_REPO = DEFAULT_HF_REPO_ID + +# Output folder for rendered images/meshes and keypoints +DEFAULT_OUT_FOLDER = "demo_out_tta_gradio" + +_D2_R50_CFG = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml" +_D2_R50_URL = ( + "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/" + "faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl" +) +_D2_X101_CFG = "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml" +_D2_X101_URL = ( + "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/" + "faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl" +) + +# Gradio example row: (image_rel, tta_lr, tta_iters, det_thresh, kp_thresh, side_view, save_mesh) +ExampleRow = Tuple[str, float, int, float, float, bool, bool] + + +@dataclass(frozen=True) +class DemoProfile: + """Runtime settings for either the full local app or the lightweight HF Space demo.""" + + mode: str + prima_device: str # "auto" (CUDA if available) or "cpu" + detectron_config_yaml: str + detectron_weights_url: str + detectron_device: str # "auto" or "cpu" + default_tta_iters: int + max_tta_iters: int + default_save_mesh: bool + default_side_view: bool + preload_assets: bool + example_rows: Tuple[ExampleRow, ...] + description: str + interface_title: str + + def resolve_prima_device(self) -> torch.device: + if self.prima_device == "cpu": + return torch.device("cpu") + return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + def resolve_detectron_device(self) -> str: + if self.detectron_device == "cpu": + return "cpu" + return "cuda" if torch.cuda.is_available() else "cpu" + + +LOCAL_DEMO_PROFILE = DemoProfile( + mode="local", + prima_device="auto", + detectron_config_yaml=_D2_X101_CFG, + detectron_weights_url=_D2_X101_URL, + detectron_device="auto", + default_tta_iters=30, + max_tta_iters=100, + default_save_mesh=True, + default_side_view=False, + preload_assets=False, + example_rows=( + ("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, True), + ("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, True), + ("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, True), + ("demo_data/beagle.jpg", 1e-6, 0, 0.7, 0.1, False, True), + ("demo_data/shepherd_hati.jpg", 1e-6, 0, 0.7, 0.1, False, True), + ), + description=( + "**Local demo** — full pipeline on your machine (GPU when available).\n\n" + "Detectron2 **X-101-FPN**, PRIMA mesh recovery, optional **DeepLabCut SuperAnimal + TTA**. " + "Set TTA iterations to **0** to skip adaptation. Outputs are saved under " + f"`{DEFAULT_OUT_FOLDER}`." + ), + interface_title=( + "PRIMA local demo (GPU/CPU) — detection, mesh recovery, optional TTA" + ), +) + +SPACE_DEMO_PROFILE = DemoProfile( + mode="space", + prima_device="cpu", + detectron_config_yaml=_D2_R50_CFG, + detectron_weights_url=_D2_R50_URL, + detectron_device="cpu", + default_tta_iters=0, + max_tta_iters=30, + default_save_mesh=False, + default_side_view=False, + preload_assets=True, + example_rows=( + ("demo_data/beagle.jpg", 1e-6, 0, 0.7, 0.1, False, False), + ("demo_data/000000015956_horse.png", 1e-6, 0, 0.7, 0.1, False, False), + ("demo_data/000000315905_zebra.jpg", 1e-6, 0, 0.7, 0.1, False, False), + ), + description=( + "**Hugging Face Space (cpu-basic)** — lightweight demo: **CPU-only**, Detectron2 **R50-FPN**, " + "PRIMA inference. TTA is optional (0 by default; increases runtime). Mesh `.obj` export is off " + "by default to save time and disk." + ), + interface_title="PRIMA on Hugging Face — lightweight CPU demo", +) + + +def _is_truthy_env(var_name: str) -> bool: + return os.environ.get(var_name, "").strip().lower() in {"1", "true", "yes", "on"} + + +def _running_on_space() -> bool: + return bool(os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID")) + + +@lru_cache(maxsize=1) +def get_demo_profile() -> DemoProfile: + """Select local vs Space profile. Override with ``PRIMA_DEMO_MODE=local|space``.""" + override = os.environ.get("PRIMA_DEMO_MODE", "").strip().lower() + if override == "local": + return LOCAL_DEMO_PROFILE + if override == "space": + return SPACE_DEMO_PROFILE + return SPACE_DEMO_PROFILE if _running_on_space() else LOCAL_DEMO_PROFILE + + +def _gradio_examples_for_interface(profile: DemoProfile) -> List[List]: + """Gradio prefetches example media at startup (paths must exist beside ``app.py``).""" + if _is_truthy_env("PRIMA_DISABLE_GRADIO_EXAMPLES"): + return [] + rows: List[List] = [] + for rel, *rest in profile.example_rows: + p = _REPO_ROOT / rel + if p.is_file(): + rows.append([str(p), *rest]) + return rows + + +def _should_preload_assets(profile: DemoProfile) -> bool: + preload_env = os.environ.get("PRIMA_PRELOAD_ASSETS") + if preload_env is not None: + return _is_truthy_env("PRIMA_PRELOAD_ASSETS") + return profile.preload_assets + +def _deeplabcut_available() -> bool: + try: + from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images # noqa: F401 + + return True + except Exception: + return False + + +def _preload_assets_once(checkpoint_path: str) -> None: + print("[startup] Ensuring demo assets from Hugging Face Hub...") + resolve_prima_checkpoint_path( + checkpoint_path, + data_dir=_REPO_ROOT / "data", + auto_download=True, + hf_repo_id=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO), + ) + print("[startup] Asset preload complete.") + + +def _load_prima_model(checkpoint_path: str = DEFAULT_CHECKPOINT): + """Load PRIMA model and renderer once for the Gradio app.""" + from prima.models import load_prima + from prima.utils.renderer import Renderer, cam_crop_to_full + + checkpoint_path = resolve_prima_checkpoint_path( + checkpoint_path, + data_dir=_REPO_ROOT / "data", + auto_download=True, + hf_repo_id=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO), + ) + checkpoint = Path(checkpoint_path) + cfg_path = checkpoint.parent.parent / ".hydra" / "config.yaml" + if not checkpoint.exists(): + raise FileNotFoundError( + f"Missing checkpoint: {checkpoint}. Download demo checkpoints/data as described in README." + ) + if not cfg_path.exists(): + raise FileNotFoundError( + f"Missing model config: {cfg_path}. Ensure the full checkpoint folder layout from README is present." + ) + + profile = get_demo_profile() + model, model_cfg = load_prima(checkpoint_path) + device = profile.resolve_prima_device() + model = model.to(device) + model.eval() + + renderer = Renderer(model_cfg, faces=model.smal.faces) + return model, model_cfg, renderer, cam_crop_to_full, device + + +def _build_detector(profile: Optional[DemoProfile] = None): + """Build Detectron2 animal detector (profile selects X-101+GPU locally vs R50+CPU on Space).""" + try: + import detectron2.config + import detectron2.engine + from detectron2 import model_zoo + except Exception as e: + print(f"[warn] Detectron2 unavailable ({type(e).__name__}: {e}); using full-image fallback bbox.") + return None + + if profile is None: + profile = get_demo_profile() + config_yaml = profile.detectron_config_yaml + weights = profile.detectron_weights_url + device_str = profile.resolve_detectron_device() + print(f"[detectron2] mode={profile.mode} config={config_yaml} device={device_str}") + + cfg = detectron2.config.get_cfg() + cfg.merge_from_file(model_zoo.get_config_file(config_yaml)) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 + cfg.MODEL.WEIGHTS = weights + cfg.MODEL.DEVICE = device_str + detector = detectron2.engine.DefaultPredictor(cfg) + return detector + + +def _load_model_and_detector_for_demo(checkpoint_path: str, profile: DemoProfile): + """Load PRIMA and Detectron2 once for the Gradio session (main thread only).""" + model, model_cfg, renderer, cam_crop_to_full_fn, device = _load_prima_model(checkpoint_path) + detector = _build_detector(profile) + return model, model_cfg, renderer, cam_crop_to_full_fn, device, detector + + +def _detect_animal_boxes( + detector, + img_bgr: np.ndarray, + det_thresh: float, +) -> Optional[np.ndarray]: + """Return Nx4 XYXY boxes or None if no animal detections.""" + if detector is None: + h, w = img_bgr.shape[:2] + return np.array([[0.0, 0.0, float(max(1, w - 1)), float(max(1, h - 1))]], dtype=np.float32) + + det_out = detector(img_bgr) + det_instances = det_out["instances"] + boxes, suppressed = select_animal_boxes(det_instances, score_threshold=float(det_thresh)) + if suppressed > 0: + print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s)") + if len(boxes) == 0: + return None + return boxes + + +# SuperAnimal defaults (same as in demo_tta parser) +SUPER_ANIMAL_ARGS = SimpleNamespace( + superanimal_name="superanimal_quadruped", + superanimal_model_name="hrnet_w32", + superanimal_detector_name="fasterrcnn_resnet50_fpn_v2", + superanimal_max_individuals=1, + saved_2d_model_path="", + pytorch_config_2d_path=str(_REPO_ROOT / "configs" / "sa_finetune_hrnet_w32.yaml"), +) + + +def _collect_animal_results( + model, + model_cfg, + renderer, + cam_crop_to_full_fn, + device, + detector, + out_folder: str, + img_rgb: np.ndarray, + tta_lr: float, + tta_num_iters: int, + det_thresh: float, + kp_conf_thresh: float, + side_view: bool, + save_mesh: bool, + boxes: Optional[np.ndarray] = None, +) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], str | None, str | None]: + """Run detection + PRIMA + SuperAnimal + TTA on a single RGB image. + + Returns: + before_imgs: list of HxWx3 RGB images (before TTA) for all animals + after_imgs: list of HxWx3 RGB images (after TTA) for all animals + kpt_imgs: list of HxWx3 RGB keypoint visualizations + first_before_mesh: path to first animal's before-TTA mesh (.obj) or None + first_after_mesh: path to first animal's after-TTA mesh (.obj) or None + """ + from prima.utils import recursive_to + from prima.datasets.vitdet_dataset import ViTDetDataset + from demo_tta import ( + denorm_patch_to_rgb, + resolve_sa_weights_path, + run_superanimal_on_patch, + save_keypoint_vis, + tta_optimize, + ) + + if int(tta_num_iters) > 0 and not SUPER_ANIMAL_ARGS.saved_2d_model_path: + SUPER_ANIMAL_ARGS.saved_2d_model_path = resolve_sa_weights_path("") + + img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) + if boxes is None: + boxes = _detect_animal_boxes(detector, img_bgr, det_thresh) + if boxes is None: + return [], [], [], None, None + + dataset = ViTDetDataset(model_cfg, img_bgr, boxes) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + + before_imgs: List[np.ndarray] = [] + after_imgs: List[np.ndarray] = [] + kpt_imgs: List[np.ndarray] = [] + before_mesh_paths: List[str] = [] + after_mesh_paths: List[str] = [] + + img_token = next(tempfile._get_candidate_names()) + + for batch in dataloader: + batch = recursive_to(batch, device) + + with torch.no_grad(): + out_before = model(batch) + + animal_id = int(batch["animalid"][0]) + + # Save/render before TTA + img_fn = f"{img_token}" + from demo_tta import render_and_save # imported lazily to avoid circular issues + + render_and_save( + renderer, + cam_crop_to_full_fn, + out_before, + batch, + img_fn, + animal_id, + out_folder, + suffix="before_tta", + side_view=side_view, + save_mesh=save_mesh, + ) + + before_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.png") + if os.path.exists(before_png_path): + before_bgr = cv2.imread(before_png_path) + if before_bgr is not None: + before_imgs.append(cv2.cvtColor(before_bgr, cv2.COLOR_BGR2RGB)) + + if save_mesh: + before_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.obj") + if os.path.exists(before_obj_path): + before_mesh_paths.append(before_obj_path) + + if int(tta_num_iters) <= 0: + render_and_save( + renderer, + cam_crop_to_full_fn, + out_before, + batch, + img_fn, + animal_id, + out_folder, + suffix="after_tta", + side_view=side_view, + save_mesh=save_mesh, + ) + + after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png") + if os.path.exists(after_png_path): + after_bgr = cv2.imread(after_png_path) + if after_bgr is not None: + after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB)) + + if save_mesh: + after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj") + if os.path.exists(after_obj_path): + after_mesh_paths.append(after_obj_path) + continue + + # Prepare patch for SuperAnimal + patch_rgb = denorm_patch_to_rgb(batch["img"][0]) + with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir: + bodyparts_xyc = run_superanimal_on_patch(patch_rgb, SUPER_ANIMAL_ARGS, tmp_dir) + + if bodyparts_xyc is None: + # No keypoints => skip TTA for this animal + continue + + kpts_xyc = bodyparts_xyc + kpts_xyc[kpts_xyc[:, 2] < float(kp_conf_thresh), 2] = 0.0 + + # Save keypoint visualization and npy + kpt_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png") + save_keypoint_vis(patch_rgb, kpts_xyc, kpt_png_path) + npy_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy") + np.save(npy_path, kpts_xyc) + + if os.path.exists(kpt_png_path): + kpt_bgr = cv2.imread(kpt_png_path) + if kpt_bgr is not None: + kpt_imgs.append(cv2.cvtColor(kpt_bgr, cv2.COLOR_BGR2RGB)) + + # Normalize keypoints to [-0.5, 0.5] as in demo_tta + patch_h, patch_w = patch_rgb.shape[:2] + kpts_norm = kpts_xyc.copy() + kpts_norm[:, 0] = kpts_norm[:, 0] / float(patch_w) - 0.5 + kpts_norm[:, 1] = kpts_norm[:, 1] / float(patch_h) - 0.5 + gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch["img"].dtype) + + # Run TTA + out_after = tta_optimize( + model, + batch, + gt_kpts_norm, + num_iters=int(tta_num_iters), + lr=float(tta_lr), + ) + + render_and_save( + renderer, + cam_crop_to_full_fn, + out_after, + batch, + img_fn, + animal_id, + out_folder, + suffix="after_tta", + side_view=side_view, + save_mesh=save_mesh, + ) + + after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png") + if os.path.exists(after_png_path): + after_bgr = cv2.imread(after_png_path) + if after_bgr is not None: + after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB)) + + if save_mesh: + after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj") + if os.path.exists(after_obj_path): + after_mesh_paths.append(after_obj_path) + + first_before_mesh = before_mesh_paths[0] if before_mesh_paths else None + first_after_mesh = after_mesh_paths[0] if after_mesh_paths else None + + return before_imgs, after_imgs, kpt_imgs, first_before_mesh, first_after_mesh + + +def build_demo(checkpoint_path: str = DEFAULT_CHECKPOINT, out_folder: str = DEFAULT_OUT_FOLDER) -> gr.Interface: + profile = get_demo_profile() + print( + f"[demo] profile={profile.mode} prima={profile.resolve_prima_device()} " + f"detectron={profile.detectron_config_yaml} d2_device={profile.resolve_detectron_device()}" + ) + os.makedirs(out_folder, exist_ok=True) + runtime_cache = { + "model": None, + "model_cfg": None, + "renderer": None, + "cam_crop_to_full_fn": None, + "device": None, + "detector": None, + } + + def gradio_inference( + image: np.ndarray, + tta_lr: float, + tta_num_iters: int, + det_thresh: float, + kp_conf_thresh: float, + side_view: bool, + save_mesh: bool, + ): + """Wrapper for Gradio. ``image`` is an RGB numpy array. + + Yields intermediate status so long first-run (Hub downloads + model load) + and long inference do not hit silent client/proxy WebSocket timeouts. + """ + + if image is None: + yield None, None, None, "No image provided." + return + + if int(tta_num_iters) > 0 and not _deeplabcut_available(): + yield ( + None, + None, + None, + "DeepLabCut is not installed. Set **TTA iterations** to **0** for PRIMA-only inference, " + "or install `deeplabcut` (see README / requirements.txt).", + ) + return + + if image.dtype != np.uint8: + img_rgb = np.clip(image, 0, 255).astype(np.uint8) + else: + img_rgb = image + + yield None, None, None, "Queued; preparing run…" + + if runtime_cache["model"] is None: + yield ( + None, + None, + None, + "First run: downloading demo assets from Hugging Face (large checkpoint) " + "and loading the model. This can take many minutes.", + ) + try: + model, model_cfg, renderer, cam_crop_to_full_fn, device, detector = _load_model_and_detector_for_demo( + checkpoint_path, profile + ) + except Exception: + yield None, None, None, f"Model initialization failed:\n{traceback.format_exc()}" + return + runtime_cache["model"] = model + runtime_cache["model_cfg"] = model_cfg + runtime_cache["renderer"] = renderer + runtime_cache["cam_crop_to_full_fn"] = cam_crop_to_full_fn + runtime_cache["device"] = device + runtime_cache["detector"] = detector + yield None, None, None, "Model loaded." + + try: + yield None, None, None, "Running animal detection…" + img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) + boxes = _detect_animal_boxes(runtime_cache["detector"], img_bgr, det_thresh) + if boxes is None: + yield ( + None, + None, + None, + "No animal detected. Try lowering the detection threshold or another image.", + ) + return + yield ( + None, + None, + None, + f"Detected {len(boxes)} animal region(s). Running PRIMA (+ SuperAnimal/TTA if enabled)…", + ) + before_imgs, after_imgs, kpt_imgs, mesh_before, mesh_after = _collect_animal_results( + runtime_cache["model"], + runtime_cache["model_cfg"], + runtime_cache["renderer"], + runtime_cache["cam_crop_to_full_fn"], + runtime_cache["device"], + runtime_cache["detector"], + out_folder, + img_rgb, + tta_lr=tta_lr, + tta_num_iters=tta_num_iters, + det_thresh=det_thresh, + kp_conf_thresh=kp_conf_thresh, + side_view=side_view, + save_mesh=save_mesh, + boxes=boxes, + ) + except Exception: + yield None, None, None, f"Inference failed:\n{traceback.format_exc()}" + return + + first_before = before_imgs[0] if before_imgs else None + first_after = after_imgs[0] if after_imgs else None + first_kpts = kpt_imgs[0] if kpt_imgs else None + if first_before is None and first_after is None: + yield ( + None, + None, + None, + "No output generated. Try an image with a clearly visible quadruped.", + ) + return + yield first_before, first_after, first_kpts, "OK" + + _gradio_examples = _gradio_examples_for_interface(profile) + _iface_kw = dict( + fn=gradio_inference, + analytics_enabled=False, + cache_examples=False, + inputs=[ + gr.Image( + label="Input image", + type="numpy", + sources=["upload", "clipboard"], + ), + gr.Slider( + label="TTA learning rate", + minimum=1e-7, + maximum=1e-4, + value=1e-6, + step=1e-7, + ), + gr.Slider( + label="TTA iterations", + minimum=0, + maximum=profile.max_tta_iters, + value=profile.default_tta_iters, + step=1, + info="Set to 0 to disable TTA and reuse the initial PRIMA prediction.", + ), + gr.Slider( + label="Detection threshold", + minimum=0.3, + maximum=0.9, + value=0.7, + step=0.05, + ), + gr.Slider( + label="Keypoint confidence threshold", + minimum=0.0, + maximum=1.0, + value=0.1, + step=0.05, + ), + gr.Checkbox(label="Render side view", value=profile.default_side_view), + gr.Checkbox(label="Save meshes (.obj)", value=profile.default_save_mesh), + ], + outputs=[ + gr.Image(label="Before TTA"), + gr.Image(label="After TTA"), + gr.Image(label="PRIMA 26 keypoints"), + gr.Textbox(label="Status / Traceback", lines=12), + ], + title=profile.interface_title, + description=profile.description, + ) + if _gradio_examples: + _iface_kw["examples"] = _gradio_examples + demo = gr.Interface(**_iface_kw) + demo.queue(max_size=8, default_concurrency_limit=1) + return demo + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Gradio demo for PRIMA + SuperAnimal + TTA") + parser.add_argument( + "--checkpoint", + type=str, + default=DEFAULT_CHECKPOINT, + help="Path to the pretrained PRIMA checkpoint", + ) + parser.add_argument( + "--out_folder", + type=str, + default=DEFAULT_OUT_FOLDER, + help="Folder used to save rendered outputs and meshes", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + profile = get_demo_profile() + if _should_preload_assets(profile): + _preload_assets_once(args.checkpoint) + demo = build_demo(checkpoint_path=args.checkpoint, out_folder=args.out_folder) + demo.launch(inbrowser=False) diff --git a/chumpy/__init__.py b/chumpy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5b170d272607008ce7df4f3321694db1bc6f0a --- /dev/null +++ b/chumpy/__init__.py @@ -0,0 +1,16 @@ +from __future__ import annotations +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + + +"""Minimal ``chumpy`` compatibility for unpickling legacy SMAL model configs.""" + +from .ch import Ch, ChArray, materialize + +__all__ = ["Ch", "ChArray", "materialize"] diff --git a/chumpy/ch.py b/chumpy/ch.py new file mode 100644 index 0000000000000000000000000000000000000000..9bdb31ddfccb79a8b2b12bc58de45a7aef6da718 --- /dev/null +++ b/chumpy/ch.py @@ -0,0 +1,66 @@ +from __future__ import annotations +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + + +"""``chumpy.ch`` namespace expected by legacy SMAL pickles.""" + +import numpy as np + + +class Ch: + """Minimal stand-in for ``chumpy.ch.Ch`` (unpickling only).""" + + def __init__(self, *args, **kwargs): + self._data = None + if args: + self._data = np.asarray(args[0]) + + def _resolve(self) -> np.ndarray: + # Real chumpy Ch instances store the underlying ndarray on attribute ``x``; + # legacy pickles unpickle by restoring ``__dict__`` without calling ``__init__``, + # so try common attribute names before falling back to ``_data``. + for attr in ("x", "_x", "_data"): + val = self.__dict__.get(attr) + if val is not None: + return np.asarray(val) + if self._data is not None: + return np.asarray(self._data) + return np.zeros((), dtype=np.float32) + + @property + def r(self) -> np.ndarray: + return self._resolve() + + def __array__(self, dtype=None): + arr = self.r() + if dtype is not None: + arr = arr.astype(dtype, copy=False) + return arr + + +class ChArray(np.ndarray): + """Minimal stand-in for ``chumpy.ch.ChArray``.""" + + +def materialize(value, dtype=np.float32) -> np.ndarray: + """Recursively unwrap ``Ch`` / object arrays from legacy SMAL pickles.""" + if isinstance(value, Ch): + return np.asarray(value.r(), dtype=dtype) + if isinstance(value, np.ndarray): + if value.dtype == object: + flat = [materialize(x, dtype=dtype) for x in value.ravel()] + return np.stack(flat).reshape(value.shape) + return np.asarray(value, dtype=dtype) + if isinstance(value, (list, tuple)): + return np.asarray([materialize(x, dtype=dtype) for x in value], dtype=dtype) + return np.asarray(value, dtype=dtype) + + +__all__ = ["Ch", "ChArray", "materialize"] diff --git a/configs/sa_finetune_hrnet_w32.yaml b/configs/sa_finetune_hrnet_w32.yaml new file mode 100644 index 0000000000000000000000000000000000000000..83dd64e2b05982bb1e4782a627528bea2c81b85a --- /dev/null +++ b/configs/sa_finetune_hrnet_w32.yaml @@ -0,0 +1,220 @@ +# DeepLabCut pytorch_config for the PRIMA TTA 2D pose model: +# SuperAnimal-Quadruped HRNet-w32 backbone fine-tuned on Animal3D, with +# the heatmap head re-trained for the 26-joint Animal3D / PRIMA layout. +# +# Used by demo_tta.py via DLC's `superanimal_analyze_images(..., +# customized_model_config=, customized_pose_checkpoint=)`. Only the pose model is fine-tuned; the bounding-box +# detector (Faster R-CNN) is the stock SuperAnimal-Quadruped one +# resolved by DLC at runtime. +data: + bbox_margin: 20 + colormode: RGB + inference: + normalize_images: true + top_down_crop: + width: 256 + height: 256 + auto_padding: + pad_width_divisor: 32 + pad_height_divisor: 32 + train: + affine: + p: 0.5 + rotation: 30 + scaling: + - 1.0 + - 1.0 + translation: 0 + gaussian_noise: 12.75 + motion_blur: true + normalize_images: true + top_down_crop: + width: 256 + height: 256 + auto_padding: + pad_width_divisor: 32 + pad_height_divisor: 32 +detector: + data: + colormode: RGB + inference: + normalize_images: true + train: + affine: + p: 0.5 + rotation: 30 + scaling: + - 1.0 + - 1.0 + translation: 40 + collate: + type: ResizeFromDataSizeCollate + min_scale: 0.4 + max_scale: 1.0 + min_short_side: 128 + max_short_side: 1152 + multiple_of: 32 + to_square: false + hflip: true + normalize_images: true + device: auto + model: + type: FasterRCNN + freeze_bn_stats: true + freeze_bn_weights: false + variant: fasterrcnn_resnet50_fpn_v2 + runner: + type: DetectorTrainingRunner + key_metric: test.mAP@50:95 + key_metric_asc: true + eval_interval: 10 + optimizer: + type: AdamW + params: + lr: 0.0001 + scheduler: + type: LRListScheduler + params: + milestones: + - 160 + lr_list: + - - 1e-05 + snapshots: + max_snapshots: 5 + save_epochs: 25 + save_optimizer_state: false + train_settings: + batch_size: 1 + dataloader_workers: 0 + dataloader_pin_memory: false + display_iters: 500 + epochs: 250 +device: auto +inference: + multithreading: + enabled: true + queue_length: 4 + timeout: 30.0 + compile: + enabled: false + backend: inductor + autocast: + enabled: false +metadata: + project_path: "" + pose_config_path: "" + bodyparts: + - left_eye + - right_eye + - chin + - left_front_paw + - right_front_paw + - left_back_paw + - right_back_paw + - tail_base + - left_front_thigh + - right_front_thigh + - left_back_thigh + - right_back_thigh + - left_shoulder + - right_shoulder + - left_front_knee + - right_front_knee + - left_back_knee + - right_back_knee + - neck_base + - tail_mid + - left_ear_base + - right_ear_base + - left_mouth_corner + - right_mouth_corner + - nose + - tail_tip_first + unique_bodyparts: [] + individuals: + - individual000 + with_identity: false +method: td +model: + backbone: + type: HRNet + model_name: hrnet_w32 + freeze_bn_stats: true + freeze_bn_weights: false + interpolate_branches: false + increased_channel_count: false + backbone_output_channels: 32 + heads: + bodypart: + type: HeatmapHead + weight_init: normal + predictor: + type: HeatmapPredictor + apply_sigmoid: false + clip_scores: true + location_refinement: true + locref_std: 7.2801 + target_generator: + type: HeatmapGaussianGenerator + num_heatmaps: 26 + pos_dist_thresh: 17 + heatmap_mode: KEYPOINT + gradient_masking: true + background_weight: 0.0 + generate_locref: true + locref_std: 7.2801 + criterion: + heatmap: + type: WeightedMSECriterion + weight: 1.0 + locref: + type: WeightedHuberCriterion + weight: 0.05 + heatmap_config: + channels: + - 32 + kernel_size: [] + strides: [] + final_conv: + out_channels: 26 + kernel_size: 1 + locref_config: + channels: + - 32 + kernel_size: [] + strides: [] + final_conv: + out_channels: 52 + kernel_size: 1 +net_type: hrnet_w32 +runner: + type: PoseTrainingRunner + gpus: + key_metric: test.mAP + key_metric_asc: true + eval_interval: 10 + optimizer: + type: AdamW + params: + lr: 0.0001 + scheduler: + type: LRListScheduler + params: + lr_list: + - - 1e-05 + - - 1e-06 + milestones: + - 160 + - 190 + snapshots: + max_snapshots: 5 + save_epochs: 10 + save_optimizer_state: false +train_settings: + batch_size: 64 + dataloader_workers: 8 + dataloader_pin_memory: false + display_iters: 500 + epochs: 200 + seed: 42 diff --git a/configs_hydra/experiment/default.yaml b/configs_hydra/experiment/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d5fc29b962bfd8b61f934582801f0f4dff29cc7 --- /dev/null +++ b/configs_hydra/experiment/default.yaml @@ -0,0 +1,28 @@ +# @package _global_ + +SMAL: + DATA_DIR: data/smal + MODEL_PATH: data/smal/my_smpl_00781_4_all.pkl + SHAPE_PRIOR_PATH: data/smal/my_smpl_data_00781_4_all.pkl + POSE_PRIOR_PATH: data/smal/walking_toy_symmetric_pose_prior_with_cov_35parts.pkl + NUM_JOINTS: 34 + +EXTRA: + FOCAL_LENGTH: 1000 + NUM_LOG_IMAGES: 4 + NUM_LOG_SAMPLES_PER_IMAGE: 4 + PELVIS_IND: 0 + +DATASETS: + CONFIG: + SCALE_FACTOR: 0.3 + ROT_FACTOR: 30 + TRANS_FACTOR: 0.02 + COLOR_SCALE: 0.2 + ROT_AUG_RATE: 0.6 + TRANS_AUG_RATE: 0.5 + DO_FLIP: False + FLIP_AUG_RATE: 0.0 + EXTREME_CROP_AUG_RATE: 0.0 + EXTREME_CROP_AUG_LEVEL: 1 + diff --git a/configs_hydra/experiment/default_val.yaml b/configs_hydra/experiment/default_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6656b3ac036f2e5873c87162650d5a46fa3f073f --- /dev/null +++ b/configs_hydra/experiment/default_val.yaml @@ -0,0 +1,34 @@ +# @package _global_ + +DATASETS: + ANIMAL3D: + ROOT_IMAGE: ./datasets/animal3d/ + JSON_FILE: + TEST: ./datasets/animal3d/test.json + CONTROL_ANIMAL3D: + ROOT_IMAGE: ./datasets/control_animal3dlatest/ + JSON_FILE: + TEST: ./datasets/control_animal3dlatest/test.json + QUADRUPED2D: + ROOT_IMAGE: ./datasets/quadruped2d/ + JSON_FILE: + TEST: ./datasets/quadruped2d/test.json + ANIMAL_KINGDOM: + ROOT_IMAGE: ./datasets/Animal_Kingdom_test/ + JSON_FILE: + TEST: ./datasets/Animal_Kingdom_test/test.json + CONFIG: + SCALE_FACTOR: 0.0 + ROT_FACTOR: 0 + TRANS_FACTOR: 0.0 + COLOR_SCALE: 0.0 + ROT_AUG_RATE: 0.0 + TRANS_AUG_RATE: 0.0 + DO_FLIP: False + FLIP_AUG_RATE: 0.0 + EXTREME_CROP_AUG_RATE: 0.0 + EXTREME_CROP_AUG_LEVEL: 1 + +METRIC: + PCK_THRESHOLD: [0.10, 0.15] + diff --git a/configs_hydra/experiment/primaStage1.yaml b/configs_hydra/experiment/primaStage1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7caccd6e22d1572d32b5d37a284b41c75a7bb74b --- /dev/null +++ b/configs_hydra/experiment/primaStage1.yaml @@ -0,0 +1,83 @@ +# @package _global_ + +defaults: + - default.yaml + +GENERAL: + + + TOTAL_STEPS: 63_000 + LOG_STEPS: 63 + VAL_STEPS: 63 + VAL_EPOCHS: 1 + CHECKPOINT_EPOCHS: 1 + CHECKPOINT_SAVE_TOP_K: 2 + NUM_WORKERS: 8 + PREFETCH_FACTOR: 2 + +LOSS_WEIGHTS: + KEYPOINTS_3D: 0.05 + KEYPOINTS_2D: 0.01 + INTERMEDIATE_KP2D: 0.001 + INTERMEDIATE_KP3D: 0.001 + GLOBAL_ORIENT: 0.005 + POSE: 0.001 + BETAS: 0.0005 + TRANSL: 0.0005 + ADVERSARIAL: 0.0005 + SUPCON: 0.0005 + + +TRAIN: + LR: 3.75e-6 + WEIGHT_DECAY: 1e-4 + BATCH_SIZE: 48 + LOSS_REDUCTION: mean + NUM_TRAIN_SAMPLES: 2 + NUM_TEST_SAMPLES: 64 + POSE_2D_NOISE_RATIO: 0.01 + SMPL_PARAM_NOISE_RATIO: 0.005 + +MODEL: + IMAGE_SIZE: 256 + IMAGE_MEAN: [0.485, 0.456, 0.406] + IMAGE_STD: [0.229, 0.224, 0.225] + BACKBONE: + TYPE: vith + PRETRAINED_WEIGHTS: ./data/amr_vitbb.pth + FREEZE: False + + # Enable BioClip embedding + USE_BIOCLIP_EMBEDDING: True + BIOCLIP_EMBEDDING: + EMBED_DIM: 1280 # Match DINOv2 output dimension for token-wise concatenation + TYPE: bioclip1 + + # Enable 2D keypoint embedding for initialization; NewBioGuidedSMALPoseDecoder updates it dynamically + USE_KEYPOINT_EMBEDDING: False + + SMAL_HEAD: + TYPE: new_bio_pose_transformer_decoder # Use the newer version with SAM3D-style hierarchical updates + IN_CHANNELS: 1280 + IEF_ITERS: 3 + + # Pose Transformer Decoder configuration + DECODER_DIM: 1280 + NUM_DECODER_LAYERS: 6 + NUM_HEADS: 8 + MLP_RATIO: 4.0 + + # Keypoint token configuration specific to NewBioGuidedSMALPoseDecoder + USE_KEYPOINT_2D_TOKENS: True # Enable 2D keypoint tokens with SAM3D-style dynamic updates + USE_KEYPOINT_3D_TOKENS: True # Enable 3D keypoint tokens with pelvis normalization + KEYPOINT_TOKEN_UPDATE: True # Enable hierarchical keypoint prediction and token updates + KP2D_INJECT_IMAGE_FEAT: True # Key setting: inject image features via grid_sample + + +DATASETS: + ANIMAL3D: + ROOT_IMAGE: ./datasets/animal3d/ + JSON_FILE: + TRAIN: ./datasets/animal3d/train.json + TEST: ./datasets/animal3d/test.json + WEIGHT: 1.0 diff --git a/configs_hydra/experiment/primaStage2.yaml b/configs_hydra/experiment/primaStage2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..feee9d7f0541028598765d08f5e1052344f737d9 --- /dev/null +++ b/configs_hydra/experiment/primaStage2.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +defaults: + - default.yaml + +GENERAL: + + + TOTAL_STEPS: 450_000 + LOG_STEPS: 533 + VAL_STEPS: 533 + VAL_EPOCHS: 1 + CHECKPOINT_EPOCHS: 1 + CHECKPOINT_SAVE_TOP_K: 2 + NUM_WORKERS: 2 + PREFETCH_FACTOR: 2 + +LOSS_WEIGHTS: + KEYPOINTS_3D: 0.05 + KEYPOINTS_2D: 0.01 + INTERMEDIATE_KP2D: 0.001 + INTERMEDIATE_KP3D: 0.001 + GLOBAL_ORIENT: 0.005 + POSE: 0.001 + BETAS: 0.0005 + TRANSL: 0.0005 + ADVERSARIAL: 0.0 + SUPCON: 0.0005 + + +TRAIN: + LR: 3.75e-6 + WEIGHT_DECAY: 1e-4 + BATCH_SIZE: 48 + LOSS_REDUCTION: mean + NUM_TRAIN_SAMPLES: 2 + NUM_TEST_SAMPLES: 64 + POSE_2D_NOISE_RATIO: 0.01 + SMPL_PARAM_NOISE_RATIO: 0.005 + +MODEL: + IMAGE_SIZE: 256 + IMAGE_MEAN: [0.485, 0.456, 0.406] + IMAGE_STD: [0.229, 0.224, 0.225] + BACKBONE: + TYPE: vith + PRETRAINED_WEIGHTS: ./data/amr_vitbb.pth + FREEZE: False + + # Enable BioClip embedding + USE_BIOCLIP_EMBEDDING: True + BIOCLIP_EMBEDDING: + EMBED_DIM: 1280 # Match vit output dimension for token-wise concatenation + TYPE: bioclip1 + + # Enable 2D keypoint embedding + USE_KEYPOINT_EMBEDDING: False + KEYPOINT_EMBEDDING: + NUM_KEYPOINTS: 26 # Number of SMAL keypoints + KEYPOINT_DIM: 2 # 2D coordinates (x, y) + EMBED_DIM: 1280 # Match vit output dimension + HIDDEN_DIM: 512 # Hidden layer dimension in MLP + TYPE: 'token' # Use token-based embedding (recommended) + + SMAL_HEAD: + TYPE: new_bio_pose_transformer_decoder # Use the newer version with SAM3D-style hierarchical updates + IN_CHANNELS: 1280 + IEF_ITERS: 1 + + # Pose Transformer Decoder configuration + DECODER_DIM: 1280 + NUM_DECODER_LAYERS: 6 + NUM_HEADS: 8 + MLP_RATIO: 4.0 + + # Keypoint token configuration specific to NewBioGuidedSMALPoseDecoder + USE_KEYPOINT_2D_TOKENS: True # Enable 2D keypoint tokens with SAM3D-style dynamic updates + USE_KEYPOINT_3D_TOKENS: True # Enable 3D keypoint tokens with pelvis normalization + KEYPOINT_TOKEN_UPDATE: True # Enable hierarchical keypoint prediction and token updates + KP2D_INJECT_IMAGE_FEAT: True # Key setting: inject image features via grid_sample + + # Legacy transformer config (kept for compatibility) + TRANSFORMER_DECODER: + depth: 6 + heads: 8 + mlp_dim: 1024 + dim_head: 64 + dropout: 0.0 + emb_dropout: 0.0 + norm: layer + context_dim: 1280 + + + +DATASETS: + ANIMAL3D: + ROOT_IMAGE: ./datasets/animal3d/ + JSON_FILE: + TRAIN: ./datasets/animal3d/train.json + TEST: ./datasets/animal3d/test.json + WEIGHT: 1.0 + CONTROL_ANIMAL3D: + ROOT_IMAGE: ./datasets/control_animal3dlatest/ + JSON_FILE: + TRAIN: ./datasets/control_animal3dlatest/train.json + TEST: ./datasets/control_animal3dlatest/test.json + WEIGHT: 0.5 + QUADRUPED2D: + ROOT_IMAGE: ./datasets/quadruped2d/ + JSON_FILE: + TRAIN: ./datasets/quadruped2d/train.json + TEST: ./datasets/quadruped2d/test.json + WEIGHT: 0.15 diff --git a/configs_hydra/extras/default.yaml b/configs_hydra/extras/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9c6b622283a647fbc513166fc14f016cc3ed8a0 --- /dev/null +++ b/configs_hydra/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: False + +# ask user for tags if none are provided in the config +enforce_tags: True + +# pretty print config tree at the start of the run using Rich library +print_config: True diff --git a/configs_hydra/hydra/default.yaml b/configs_hydra/hydra/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c30c188f4e68b205ec0f1e5679345626fe187164 --- /dev/null +++ b/configs_hydra/hydra/default.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override /hydra/hydra_logging: colorlog + - override /hydra/job_logging: colorlog + +# exp_name: ovrd_${hydra:job.override_dirname} +exp_name: ${now:%Y-%m-%d}_${now:%H-%M-%S} + +hydra: + run: + dir: ${paths.log_dir}/${task_name}/runs/${exp_name} + sweep: + dir: ${paths.log_dir}/${task_name}/multiruns/${exp_name} + subdir: ${hydra.job.num} + job: + config: + override_dirname: + exclude_keys: + - trainer + - trainer.devices + - trainer.num_nodes + - callbacks + - debug diff --git a/configs_hydra/launcher/local.yaml b/configs_hydra/launcher/local.yaml new file mode 100644 index 0000000000000000000000000000000000000000..06fa483d5c25b119f3543565a49f0ce52ff8dea7 --- /dev/null +++ b/configs_hydra/launcher/local.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +defaults: + - override /hydra/launcher: submitit_local + +hydra: + launcher: + timeout_min: 10_080 # 7 days + nodes: 1 + tasks_per_node: ${trainer.devices} + cpus_per_task: 8 + gpus_per_node: ${trainer.devices} + name: amr diff --git a/configs_hydra/launcher/slurm.yaml b/configs_hydra/launcher/slurm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f5bd1d2de7bfa68e7d16400f3c48b02fb9fdadd2 --- /dev/null +++ b/configs_hydra/launcher/slurm.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +defaults: + - override /hydra/launcher: submitit_slurm + +hydra: + launcher: + timeout_min: 10_080 # 7 days + max_num_timeout: 3 + partition: g40 + qos: idle + nodes: 1 + tasks_per_node: ${trainer.devices} + gpus_per_task: null + cpus_per_task: 12 + gpus_per_node: ${trainer.devices} + cpus_per_gpu: null + comment: prima + name: prima + setup: + - module load cuda openmpi libfabric-aws + - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 diff --git a/configs_hydra/paths/default.yaml b/configs_hydra/paths/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2afd22a65d1b34d881943cb48ee4ce3ff37d165 --- /dev/null +++ b/configs_hydra/paths/default.yaml @@ -0,0 +1,18 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py` +root_dir: ${oc.env:PROJECT_ROOT} + +# path to data directory +data_dir: ${paths.root_dir}/data/ + +# path to logging directory +log_dir: logs/ + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:runtime.output_dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} diff --git a/configs_hydra/train.yaml b/configs_hydra/train.yaml new file mode 100644 index 0000000000000000000000000000000000000000..74d295905aa006ee8fd05b9df29c330021e10476 --- /dev/null +++ b/configs_hydra/train.yaml @@ -0,0 +1,46 @@ +# @package _global_ + +# specify here default configuration +# order of defaults determines the order in which configs override each other +defaults: + - _self_ + - trainer: ddp.yaml + - paths: default.yaml + - extras: default.yaml + - hydra: default.yaml + + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + - texture_exp: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional launcher: local.yaml + # - optional launcher: slurm.yaml + + # debugging config (enable through command line, e.g. `python train.py debug=default) + - debug: null + +# task name, determines output directory path +task_name: "train" + +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +# appending lists from command line is currently not supported :( +# https://github.com/facebookresearch/hydra/issues/1547 +tags: ["dev"] + +# set False to skip model training +train: True + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: False + +# simply provide checkpoint path to resume training +ckpt_path: True + +# seed for random number generators in pytorch, numpy and python.random +seed: null diff --git a/configs_hydra/trainer/cpu.yaml b/configs_hydra/trainer/cpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c4a08a11630deecaf5996776c70d9a330eca454e --- /dev/null +++ b/configs_hydra/trainer/cpu.yaml @@ -0,0 +1,6 @@ +defaults: + - default.yaml + - default_amr.yaml + +accelerator: cpu +devices: 1 diff --git a/configs_hydra/trainer/ddp.yaml b/configs_hydra/trainer/ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29074dee3c0d3371d7ad275153db31d9d3c5c665 --- /dev/null +++ b/configs_hydra/trainer/ddp.yaml @@ -0,0 +1,14 @@ +defaults: + - default.yaml + - default_amr.yaml + +# use "ddp_spawn" instead of "ddp", +# it's slower but normal "ddp" currently doesn't work ideally with hydra +# https://github.com/facebookresearch/hydra/issues/2070 +# https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn +strategy: ddp_spawn + +accelerator: gpu +devices: 2 +num_nodes: 1 +sync_batchnorm: True diff --git a/configs_hydra/trainer/default.yaml b/configs_hydra/trainer/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..be2fce6edc77762308243a31cce24959ee0fd23d --- /dev/null +++ b/configs_hydra/trainer/default.yaml @@ -0,0 +1,10 @@ +_target_: pytorch_lightning.Trainer + +default_root_dir: ${paths.output_dir} + +accelerator: gpu +devices: 1 + +# set True to to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: False diff --git a/configs_hydra/trainer/default_amr.yaml b/configs_hydra/trainer/default_amr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..937164b7248c81b32e1e4b0a4eba718b4ff6230f --- /dev/null +++ b/configs_hydra/trainer/default_amr.yaml @@ -0,0 +1,9 @@ +num_sanity_val_steps: 0 +log_every_n_steps: ${GENERAL.LOG_STEPS} +val_check_interval: ${GENERAL.VAL_STEPS} # How often within one training epoch to check the validation set. +check_val_every_n_epoch: ${GENERAL.VAL_EPOCHS} # Check val every n train epochs. +precision: 16-mixed # 16-mixed, 32 +max_steps: ${GENERAL.TOTAL_STEPS} +# move_metrics_to_cpu: True +limit_val_batches: 80 # How much of validation dataset to check. +# track_grad_norm: -1 diff --git a/configs_hydra/trainer/gpu.yaml b/configs_hydra/trainer/gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..14eee7741a37a073c235ef279aa55c92229436a1 --- /dev/null +++ b/configs_hydra/trainer/gpu.yaml @@ -0,0 +1,6 @@ +defaults: + - default.yaml + - default_amr.yaml + +accelerator: gpu +devices: 1 diff --git a/configs_hydra/trainer/mps.yaml b/configs_hydra/trainer/mps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6e4997f8c0c50e90ba4f7e101a9102ff337f6944 --- /dev/null +++ b/configs_hydra/trainer/mps.yaml @@ -0,0 +1,6 @@ +defaults: + - default.yaml + - default_amr.yaml + +accelerator: mps +devices: 1 diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ac81553cfa5c409d9fb571d7a6fd5a2b981c3f --- /dev/null +++ b/demo.py @@ -0,0 +1,189 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from pathlib import Path +import detectron2.config +import detectron2.engine +import torch +import argparse +import os +import cv2 +import numpy as np +from tqdm import tqdm +import torch.utils +import torch.utils.data +from prima.models import load_prima +from prima.utils import recursive_to +from prima.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD +from prima.utils.detection import select_animal_boxes +from prima.utils.weights import DEFAULT_HF_REPO_ID, resolve_prima_checkpoint_path +import detectron2 +from detectron2 import model_zoo +import warnings +warnings.filterwarnings("ignore") + +LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353) +GREEN = (0.65, 0.86, 0.74) +REPO_ROOT = Path(__file__).resolve().parent + + +def load_renderer_components(): + try: + from prima.utils.renderer import Renderer, cam_crop_to_full + except Exception as exc: + raise RuntimeError( + "Cannot initialize the PRIMA renderer. Rendering requires a working " + "pyrender/OpenGL backend such as EGL or OSMesa. Install the missing " + "OpenGL runtime for this environment, or run in an environment where " + "PYOPENGL_PLATFORM=egl/osmesa works." + ) from exc + return Renderer, cam_crop_to_full + + +def main(): + parser = argparse.ArgumentParser(description='prima demo code') + parser.add_argument('--checkpoint', type=str, default='', + help='Path to pretrained model checkpoint. Empty -> auto-download the default Stage 1 checkpoint.') + parser.add_argument('--hf-repo-id', '--hf_repo_id', dest='hf_repo_id', + type=str, default=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_REPO_ID), + help='Hugging Face repo ID containing PRIMA demo assets') + parser.add_argument('--no-auto-download', '--no_auto_download', dest='no_auto_download', action='store_true', + help='Disable automatic download of missing PRIMA demo assets') + parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images') + parser.add_argument('--out_folder', type=str, default='demo_out', help='Output folder to save rendered results') + parser.add_argument('--side_view', dest='side_view', action='store_true', default=False, + help='If set, render side view also') + parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, + help='If set, save meshes to disk also') + parser.add_argument('--batch_size', type=int, default=1, help='Batch size for inference/fitting') + parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'], + help='List of file extensions to consider') + + args = parser.parse_args() + + checkpoint_path = resolve_prima_checkpoint_path( + args.checkpoint, + data_dir=REPO_ROOT / "data", + auto_download=not args.no_auto_download, + hf_repo_id=args.hf_repo_id, + ) + + model, model_cfg = load_prima(checkpoint_path) + + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + model = model.to(device) + model.eval() + + # Setup the renderer + Renderer, cam_crop_to_full = load_renderer_components() + renderer = Renderer(model_cfg, faces=model.smal.faces) + + # Make output directory if it does not exist + os.makedirs(args.out_folder, exist_ok=True) + + # Load detector + cfg = detectron2.config.get_cfg() + cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml")) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 + cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl" + cfg.MODEL.DEVICE = device.type + detector = detectron2.engine.DefaultPredictor(cfg) + + img_paths = sorted([img for end in args.file_type for img in Path(args.img_folder).glob(end)]) + num_readable_images = 0 + num_rendered_results = 0 + num_suppressed_detections = 0 + for img_path in img_paths: + img_bgr = cv2.imread(str(img_path)) + if img_bgr is None: + print(f"[WARN] Cannot read image: {img_path}") + continue + num_readable_images += 1 + # Detect animals in image + det_out = detector(img_bgr) + + det_instances = det_out['instances'] + boxes, suppressed = select_animal_boxes(det_instances, score_threshold=0.7) + num_suppressed_detections += suppressed + if suppressed > 0: + print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s) in {img_path}") + if len(boxes) == 0: + print(f"[INFO] No animal detected in {img_path}") + continue + + # Run PRIMA on detected animals + dataset = ViTDetDataset(model_cfg, img_bgr, boxes) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0) + for batch in tqdm(dataloader): + batch = recursive_to(batch, device) + with torch.no_grad(): + out = model(batch) + + pred_cam = out['pred_cam'] + box_center = batch["box_center"].float() + box_size = batch["box_size"].float() + img_size = batch["img_size"].float() + scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max() + pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size, + scaled_focal_length).detach().cpu().numpy() + + # Render the result + batch_size = batch['img'].shape[0] + for n in range(batch_size): + # Get filename from path img_path + img_fn, _ = os.path.splitext(os.path.basename(img_path)) + animal_id = int(batch['animalid'][n]) + white_img = (torch.ones_like(batch['img'][n]).cpu() - DEFAULT_MEAN[:, None, None] / 255) / ( + DEFAULT_STD[:, None, None] / 255) + input_patch = (batch['img'][n].cpu() * (DEFAULT_STD[:, None, None]) + ( + DEFAULT_MEAN[:, None, None])) / 255. + input_patch = input_patch.permute(1, 2, 0).numpy() + + regression_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(), + out['pred_cam_t'][n].detach().cpu().numpy(), + batch['img'][n], + mesh_base_color=GREEN, + scene_bg_color=(1, 1, 1), + ) + + final_img = np.concatenate([input_patch, regression_img], axis=1) + + if args.side_view: + side_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(), + out['pred_cam_t'][n].detach().cpu().numpy(), + white_img, + mesh_base_color=GREEN, + scene_bg_color=(1, 1, 1), + side_view=True) + final_img = np.concatenate([final_img, side_img], axis=1) + + cv2.imwrite(os.path.join(args.out_folder, f'{img_fn}_{animal_id}.png'), + cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR)) + num_rendered_results += 1 + + # Add all verts and cams to list + verts = out['pred_vertices'][n].detach().cpu().numpy() + cam_t = pred_cam_t_full[n] + + # Save all meshes to disk + if args.save_mesh: + camera_translation = cam_t.copy() + tmesh = renderer.vertices_to_trimesh(verts, camera_translation, LIGHT_BLUE) + tmesh.export(os.path.join(args.out_folder, f'{img_fn}_{animal_id}.obj')) + + print( + f"[done] Demo complete. Processed {num_readable_images}/{len(img_paths)} image(s), " + f"saved {num_rendered_results} rendered result(s) to {args.out_folder}." + ) + if num_suppressed_detections > 0: + print(f"[done] Suppressed {num_suppressed_detections} duplicate animal detection(s).") + + +if __name__ == '__main__': + main() diff --git a/demo.sh b/demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..e45cc37a28469c70e6b5769fbb7f95e1626fa1db --- /dev/null +++ b/demo.sh @@ -0,0 +1,12 @@ +# Default PRIMA Stage 1 inference checkpoint: +# data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt +# +# If this local file is missing, it will be downloaded from the PRIMA Hugging Face repo. +# To use another local checkpoint instead, update this path. +# For example: checkpoint='data/PRIMAS3/checkpoints/s3ckpt.ckpt' +checkpoint='data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt' + +python demo.py \ + --checkpoint "${checkpoint}" \ + --img_folder demo_data/ \ + --out_folder demo_out/ diff --git a/demo_data/000000015956_horse.png b/demo_data/000000015956_horse.png new file mode 100644 index 0000000000000000000000000000000000000000..15fab77e07b6b732852bcb48fad389bcefc94e98 --- /dev/null +++ b/demo_data/000000015956_horse.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a2398ba7df40a47c636afefa28be17b55f4b7bc2c378e053aeea507580ad2cb +size 620466 diff --git a/demo_data/000000315905_zebra.jpg b/demo_data/000000315905_zebra.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c476bdf86fb551db280f8bba613b83d673f4768c --- /dev/null +++ b/demo_data/000000315905_zebra.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0a17e1f1650820b020a9025144015c1e27f0f1ab435859f0bde3a0047d8f689 +size 257420 diff --git a/demo_data/beagle.jpg b/demo_data/beagle.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e3ee9223e2597f56e8ebf1e70d0d8bac82e90dbb --- /dev/null +++ b/demo_data/beagle.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac29e6ea6086831dd9806a8cd3fd608e264ac1af567f6fcfc8797c5bd3d5d560 +size 349657 diff --git a/demo_data/n02101388_1188.png b/demo_data/n02101388_1188.png new file mode 100644 index 0000000000000000000000000000000000000000..873d5861f7bff57812afab192a747b02647a3084 --- /dev/null +++ b/demo_data/n02101388_1188.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e45ff508fb8c6437cce22fcb59b4f1b6fe37ddfab1d4cf68d97629f9caa939f4 +size 318688 diff --git a/demo_data/n02412080_12159.png b/demo_data/n02412080_12159.png new file mode 100644 index 0000000000000000000000000000000000000000..a6517fa50d8e1b094a07468ea64bb8bb21cdba07 --- /dev/null +++ b/demo_data/n02412080_12159.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03273c57e8b25b258d3eb96af7b4f77b43b5c40be90da83c21875f3322b487f1 +size 347450 diff --git a/demo_data/shepherd_hati.jpg b/demo_data/shepherd_hati.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4fd79ba7a132f502833df56a884d89a58c8a086a --- /dev/null +++ b/demo_data/shepherd_hati.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65c5878203bc3165dda9011ebfce77cc7d930daed0a215396d8036509d1963c1 +size 209710 diff --git a/demo_tta.py b/demo_tta.py new file mode 100644 index 0000000000000000000000000000000000000000..c8257c2712af180319e1f8959eeda8e6d1c583d2 --- /dev/null +++ b/demo_tta.py @@ -0,0 +1,399 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +""" +demo_tta.py: PRIMA inference with fine-tuned DeepLabCut SuperAnimal TTA + +Pipeline: +1. Run Detectron2 to detect animals in the input image. +2. Run PRIMA on each detected animal to obtain 3D pose/shape estimation. +3. Run a fine-tuned DeepLabCut SuperAnimal pose model (Animal3D 26-joint + layout) to obtain 2D keypoints already in PRIMA topology. The fine-tuned + snapshot is wired into DLC's + ``superanimal_analyze_images`` via the ``customized_pose_checkpoint`` + and ``customized_model_config`` kwargs. +4. Run test-time adaptation (TTA) with user-specified lr and num_iters + to further optimize the 3D pose and shape estimation. +5. Render and save before/after TTA results (PNG + OBJ) and the + 26-keypoint visualization (PNG). +""" + + +from pathlib import Path +import argparse +import copy +import os +import tempfile +import warnings + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.data +from tqdm import tqdm + +from prima.models import load_prima +from prima.utils import recursive_to +from prima.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD +from prima.utils.detection import ANIMAL_COCO_IDS, select_animal_boxes +from prima.utils.weights import DEFAULT_HF_REPO_ID, resolve_prima_checkpoint_path + +warnings.filterwarnings("ignore") + +LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353) +GREEN = (0.65, 0.86, 0.74) + +REPO_ROOT = Path(__file__).resolve().parent + + +def load_renderer_components(): + try: + from prima.utils.renderer import Renderer, cam_crop_to_full + except Exception as exc: + raise RuntimeError( + "Cannot initialize the PRIMA renderer. Rendering requires a working " + "pyrender/OpenGL backend such as EGL or OSMesa. Install the missing " + "OpenGL runtime for this environment, or run in an environment where " + "PYOPENGL_PLATFORM=egl/osmesa works." + ) from exc + return Renderer, cam_crop_to_full + + +def denorm_patch_to_rgb(img_tensor: torch.Tensor) -> np.ndarray: + patch = (img_tensor.detach().cpu() * (DEFAULT_STD[:, None, None]) + DEFAULT_MEAN[:, None, None]) / 255.0 + patch = patch.permute(1, 2, 0).numpy() + return np.clip(patch, 0.0, 1.0) + + +def save_keypoint_vis(patch_rgb: np.ndarray, kpts_xyc: np.ndarray, save_path: str) -> None: + vis = cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR).copy() + num_kpts = len(kpts_xyc) + + for i, (x, y, c) in enumerate(kpts_xyc): + if c <= 0: + continue + + # Use distinct color for each keypoint (OpenCV uses BGR) + hue = int(179 * i / max(1, num_kpts - 1)) + color_bgr = cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2BGR)[0, 0] + color_bgr = (int(color_bgr[0]), int(color_bgr[1]), int(color_bgr[2])) + + cx, cy = int(round(float(x))), int(round(float(y))) + cv2.circle(vis, (cx, cy), 3, color_bgr, -1) + cv2.putText(vis, str(i), (cx + 3, cy - 3), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1, cv2.LINE_AA) + + cv2.imwrite(save_path, vis) + + +def resolve_sa_weights_path(local_path: str) -> str: + """Return a local path to the fine-tuned SuperAnimal .pt snapshot. + + If ``local_path`` is empty, downloads ``sa_finetune_hrnet_w32.pt`` from the + ``MLAdaptiveIntelligence/FMPose3D`` Hugging Face repo (cached under + ``~/.cache/huggingface``). + """ + if local_path: + return local_path + try: + from huggingface_hub import hf_hub_download + except ImportError: + raise ImportError( + "huggingface_hub is required to auto-download the fine-tuned " + "SuperAnimal weights. Install with `pip install huggingface_hub`, " + "or pass --saved_2d_model_path with a local .pt file." + ) from None + repo_id = "MLAdaptiveIntelligence/FMPose3D" + filename = "sa_finetune_hrnet_w32.pt" + try: + cached_path = hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True) + except Exception: + print(f"No --saved_2d_model_path provided; downloading '{filename}' from {repo_id}...") + return hf_hub_download(repo_id=repo_id, filename=filename) + + print(f"Using cached SuperAnimal weights: {cached_path}") + return cached_path + + +def run_superanimal_on_patch(patch_rgb: np.ndarray, args, tmp_dir: str): + """Predict 26-joint 2D keypoints on a single PRIMA patch using a + fine-tuned DeepLabCut SuperAnimal snapshot. + + Returns an ``(26, 3)`` array of ``(x, y, confidence)`` in patch + pixel coordinates, or ``None`` if no individual was detected. + """ + try: + from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images + except Exception as e: + raise RuntimeError( + "Cannot import DeepLabCut SuperAnimal API. Please install deeplabcut with pose_estimation_pytorch support." + ) from e + + patch_path = os.path.join(tmp_dir, "patch.png") + cv2.imwrite(patch_path, cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + + dlc_device = "cuda" if torch.cuda.is_available() else "cpu" + preds = superanimal_analyze_images( + superanimal_name=args.superanimal_name, + model_name=args.superanimal_model_name, + detector_name=args.superanimal_detector_name, + images=patch_path, + max_individuals=args.superanimal_max_individuals, + out_folder=tmp_dir, + device=dlc_device, + customized_model_config=args.pytorch_config_2d_path, + customized_pose_checkpoint=args.saved_2d_model_path, + ) + + payload = preds.get(patch_path, None) + if payload is None: + return None + bodyparts = payload.get("bodyparts", None) + if bodyparts is None or len(bodyparts) == 0: + return None + + best_idx = int(np.argmax(bodyparts[..., 2].mean(axis=1))) + return bodyparts[best_idx].astype(np.float32) + + +def render_and_save(renderer, cam_crop_to_full_fn, out, batch, img_fn, animal_id, out_folder, suffix, side_view, save_mesh): + pred_cam = out['pred_cam'] + box_center = batch['box_center'].float() + box_size = batch['box_size'].float() + img_size = batch['img_size'].float() + scaled_focal_length = batch['focal_length'][0, 0] / batch['img'].shape[-1] * img_size.max() + pred_cam_t_full = cam_crop_to_full_fn(pred_cam, box_center, box_size, img_size, scaled_focal_length) + + white_img = (torch.ones_like(batch['img'][0]).cpu() - DEFAULT_MEAN[:, None, None] / 255) / ( + DEFAULT_STD[:, None, None] / 255 + ) + input_patch = denorm_patch_to_rgb(batch['img'][0]) + + regression_img = renderer( + out['pred_vertices'][0].detach().cpu().numpy(), + out['pred_cam_t'][0].detach().cpu().numpy(), + batch['img'][0], + mesh_base_color=GREEN, + scene_bg_color=(1, 1, 1), + ) + + final_img = np.concatenate([input_patch, regression_img], axis=1) + if side_view: + side_img = renderer( + out['pred_vertices'][0].detach().cpu().numpy(), + out['pred_cam_t'][0].detach().cpu().numpy(), + white_img, + mesh_base_color=GREEN, + scene_bg_color=(1, 1, 1), + side_view=True, + ) + final_img = np.concatenate([final_img, side_img], axis=1) + + cv2.imwrite( + os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.png'), + cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR), + ) + + if save_mesh: + verts = out['pred_vertices'][0].detach().cpu().numpy() + cam_t = pred_cam_t_full[0].detach().cpu().numpy() + tmesh = renderer.vertices_to_trimesh(verts, cam_t.copy(), LIGHT_BLUE) + tmesh.export(os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.obj')) + + +def tta_optimize(model, batch, gt_kpts_norm, num_iters, lr): + model.eval() + + if hasattr(model, 'backbone'): + for p in model.backbone.parameters(): + p.requires_grad = False + + orig_smal_head_state = copy.deepcopy(model.smal_head.state_dict()) + model.smal_head.freeze_except_regression_heads() + tta_params = model.smal_head.get_tta_parameters(mode='all') + optimizer = torch.optim.Adam(tta_params, lr=lr) + + valid_mask = (gt_kpts_norm[..., 2] > 0).float().unsqueeze(-1) + gt_xy = gt_kpts_norm[..., :2] + + for _ in range(num_iters): + optimizer.zero_grad() + out = model(batch) + pred_xy = out['pred_keypoints_2d'] + loss = F.mse_loss(pred_xy * valid_mask, gt_xy * valid_mask, reduction='sum') / (valid_mask.sum() + 1e-6) + loss.backward() + optimizer.step() + + with torch.no_grad(): + out_after = model(batch) + + model.smal_head.load_state_dict(orig_smal_head_state) + model.smal_head.unfreeze_all() + + return out_after + + +def main(): + parser = argparse.ArgumentParser(description='PRIMA + SuperAnimal + TTA demo') + parser.add_argument('--checkpoint', type=str, default='', + help='Path to pretrained PRIMA checkpoint. Empty -> auto-download the default Stage 1 checkpoint.') + parser.add_argument('--hf-repo-id', '--hf_repo_id', dest='hf_repo_id', + type=str, default=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_REPO_ID), + help='Hugging Face repo ID containing PRIMA demo assets') + parser.add_argument('--no-auto-download', '--no_auto_download', dest='no_auto_download', action='store_true', + help='Disable automatic download of missing PRIMA demo assets') + parser.add_argument('--img_path', type=str, default=None, help='Single image path') + parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images') + parser.add_argument('--out_folder', type=str, default='demo_out_tta', help='Output folder') + parser.add_argument('--side_view', dest='side_view', action='store_true', default=False, help='Render side view') + parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='Save meshes') + parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'], help='Image globs') + parser.add_argument('--det_thresh', type=float, default=0.7, help='Detectron2 score threshold for animals') + + parser.add_argument('--tta_lr', type=float, default=1e-6, help='TTA learning rate') + parser.add_argument('--tta_num_iters', type=int, default=30, help='TTA iterations') + parser.add_argument('--kp_conf_thresh', type=float, default=0.1, help='Keypoint confidence threshold') + + parser.add_argument('--superanimal_name', type=str, default='superanimal_quadruped') + parser.add_argument('--superanimal_model_name', type=str, default='hrnet_w32') + parser.add_argument('--superanimal_detector_name', type=str, default='fasterrcnn_resnet50_fpn_v2') + parser.add_argument('--superanimal_max_individuals', type=int, default=1) + parser.add_argument('--saved_2d_model_path', type=str, default='', + help='Path to the fine-tuned SuperAnimal 26-joint .pt snapshot. ' + 'Empty -> auto-download sa_finetune_hrnet_w32.pt from ' + 'MLAdaptiveIntelligence/FMPose3D on Hugging Face Hub.') + parser.add_argument('--pytorch_config_2d_path', type=str, + default=str(Path(__file__).resolve().parent / 'configs' / 'sa_finetune_hrnet_w32.yaml'), + help='Path to the DLC pytorch config yaml for the fine-tuned snapshot. ' + 'Defaults to the bundled configs/sa_finetune_hrnet_w32.yaml.') + + args = parser.parse_args() + checkpoint_path = resolve_prima_checkpoint_path( + args.checkpoint, + data_dir=REPO_ROOT / "data", + auto_download=not args.no_auto_download, + hf_repo_id=args.hf_repo_id, + ) + args.saved_2d_model_path = resolve_sa_weights_path(args.saved_2d_model_path) + + model, model_cfg = load_prima(checkpoint_path) + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + model = model.to(device) + model.eval() + + Renderer, cam_crop_to_full_fn = load_renderer_components() + renderer = Renderer(model_cfg, faces=model.smal.faces) + os.makedirs(args.out_folder, exist_ok=True) + + import detectron2.config + import detectron2.engine + from detectron2 import model_zoo + + cfg = detectron2.config.get_cfg() + cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml")) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 + cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl" + cfg.MODEL.DEVICE = device.type + detector = detectron2.engine.DefaultPredictor(cfg) + + if args.img_path is not None: + img_paths = [Path(args.img_path)] + else: + img_paths = sorted([img for end in args.file_type for img in Path(args.img_folder).glob(end)]) + + for img_path in img_paths: + img_bgr = cv2.imread(str(img_path)) + if img_bgr is None: + print(f"[WARN] Cannot read image: {img_path}") + continue + det_out = detector(img_bgr) + det_instances = det_out['instances'] + boxes, suppressed = select_animal_boxes( + det_instances, + animal_class_ids=ANIMAL_COCO_IDS, + score_threshold=args.det_thresh, + ) + if suppressed > 0: + print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s) in {img_path}") + + if len(boxes) == 0: + print(f"[INFO] No animal detected in {img_path}") + continue + + dataset = ViTDetDataset(model_cfg, img_bgr, boxes) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + + for batch in tqdm(dataloader, desc=f"{img_path.name}"): + batch = recursive_to(batch, device) + with torch.no_grad(): + out_before = model(batch) + + img_fn = img_path.stem + animal_id = int(batch['animalid'][0]) + + render_and_save( + renderer, + cam_crop_to_full_fn, + out_before, + batch, + img_fn, + animal_id, + args.out_folder, + suffix='before_tta', + side_view=args.side_view, + save_mesh=args.save_mesh, + ) + + patch_rgb = denorm_patch_to_rgb(batch['img'][0]) + with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir: + kpts_xyc = run_superanimal_on_patch(patch_rgb, args, tmp_dir) + + if kpts_xyc is None: + print(f"[WARN] No SuperAnimal keypoints for {img_fn}_{animal_id}, skip TTA") + continue + + kpts_xyc[kpts_xyc[:, 2] < args.kp_conf_thresh, 2] = 0.0 + + save_keypoint_vis( + patch_rgb, + kpts_xyc, + os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png"), + ) + np.save(os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy"), kpts_xyc) + + patch_h, patch_w = patch_rgb.shape[:2] + kpts_norm = kpts_xyc.copy() + kpts_norm[:, 0] = kpts_norm[:, 0] / float(patch_w) - 0.5 + kpts_norm[:, 1] = kpts_norm[:, 1] / float(patch_h) - 0.5 + gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch['img'].dtype) + + out_after = tta_optimize( + model, + batch, + gt_kpts_norm, + num_iters=args.tta_num_iters, + lr=args.tta_lr, + ) + + render_and_save( + renderer, + cam_crop_to_full_fn, + out_after, + batch, + img_fn, + animal_id, + args.out_folder, + suffix='after_tta', + side_view=args.side_view, + save_mesh=args.save_mesh, + ) + + +if __name__ == '__main__': + main() diff --git a/demo_tta.sh b/demo_tta.sh new file mode 100644 index 0000000000000000000000000000000000000000..8f0f111c184d2ef4c84a73e7a45fa435d701f6e9 --- /dev/null +++ b/demo_tta.sh @@ -0,0 +1,15 @@ + +# Empty checkpoint uses the default PRIMA Stage 1 inference checkpoint: +# data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt +# +# This standard path is auto-downloaded from the PRIMA Hugging Face repo if missing. +# To use another local checkpoint instead, update this path. +# For example: checkpoint='data/PRIMAS3/checkpoints/s3ckpt.ckpt' +checkpoint='data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt' + +python3 demo_tta.py \ + --checkpoint "${checkpoint}" \ + --img_folder demo_data/ \ + --out_folder demo_out_tta/ \ + --tta_lr 1e-6 \ + --tta_num_iters 30 diff --git a/eval.py b/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..659b7eb92771d9f3e6e4108455b752a4a95efc6c --- /dev/null +++ b/eval.py @@ -0,0 +1,103 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import numpy as np +from tqdm import tqdm +import torch +from prima.utils import recursive_to +from prima.utils.evaluate_metric import Evaluator +from prima.datasets.datasets import EvaluationDataset +import argparse +from torch.utils.data import DataLoader +from prima.models.prima import PRIMA +from prima.configs import get_config +torch.multiprocessing.set_sharing_strategy('file_system') + + +def main(args): + cfg = get_config(args.config) + default_cfg = get_config(args.default_eval_config) + model = PRIMA.load_from_checkpoint(args.checkpoint, cfg=cfg, strict=False) + model.eval() + model = model.to(args.device) + + smal_evaluator = Evaluator(smal_model=model.smal, image_size=cfg.MODEL.IMAGE_SIZE) + cfg_eval_dataset = dict(default_cfg.DATASETS) + aug_cfg = cfg_eval_dataset.pop("CONFIG", None) # augmentation config is not used in evaluation + + if args.dataset.upper() == "ALL": + for key in cfg_eval_dataset.keys(): + print(f"-------- Evaluate {key} dataset ------------") + eval_one_dataset(cfg_eval_dataset[key], default_cfg, cfg, model, + evaluator=smal_evaluator, + aug_cfg=aug_cfg, + key=key, + device=args.device) + print(f"-------{key} Dataset evaluate finish ------") + else: + print(f"-------- Evaluate {args.dataset} dataset ------------") + eval_one_dataset(cfg_eval_dataset[args.dataset], default_cfg, cfg, model, + evaluator=smal_evaluator, + aug_cfg=aug_cfg, + key=args.dataset, + device=args.device) + print(f"-------{args.dataset} Dataset evaluate finish ------") + + +def eval_one_dataset(dataset_cfg, default_cfg, cfg, model, evaluator, aug_cfg, key, device='cuda'): + dataset = EvaluationDataset(root_image=dataset_cfg['ROOT_IMAGE'], + json_file=dataset_cfg['JSON_FILE']['TEST'], + augm_config=aug_cfg, focal_length=cfg.SMAL.get("FOCAL_LENGTH", 1000), + image_size=cfg.MODEL.IMAGE_SIZE, + ) + dataloader = DataLoader(dataset, batch_size=1, num_workers=cfg.GENERAL.NUM_WORKERS) + + bar = tqdm(dataloader) + pa_mpjpe_list, pck_list, auc_list, pa_mpvpe_list = [], [], [], [] + for i, batch in enumerate(bar): + batch = recursive_to(batch, device) + with torch.no_grad(): + output = model(batch) + + if key in ["ANIMAL3D", "CONTROL_ANIMAL3D"]: + pa_mpjpe, pa_mpvpe = evaluator.eval_3d(output, batch) + else: + pa_mpjpe, pa_mpvpe = 0., 0. + pck, auc = evaluator.eval_2d(output, batch, pck_threshold=default_cfg.METRIC.PCK_THRESHOLD) + + pa_mpjpe_list.append(pa_mpjpe) + pa_mpvpe_list.append(pa_mpvpe) + auc_list.append(auc) + pck_list.append(pck) + + bar.set_postfix(PA_MPJPE=pa_mpjpe, + PA_MPVPE=pa_mpvpe, + AUC=auc, + pck=pck,) + + print("---------------- 3D metric -----------------") + print(f"Avg PA-MPJPE: {np.mean(pa_mpjpe_list)}") + print(f"Avg PA-MPVPE: {np.mean(pa_mpvpe_list)}") + + print("--------------- 2D metric ------------------") + print(f"AUC: {np.mean(auc_list)}") + pck_list = np.array(pck_list) + for _, th in enumerate(default_cfg.METRIC.PCK_THRESHOLD): + print(f"PCK@{th}: {np.mean(pck_list[:, _])}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, help="Path to config file", required=True) + parser.add_argument("--checkpoint", type=str, help="Path to checkpoint file", required=True) + parser.add_argument("--default_eval_config", type=str, default="./configs_hydra/experiment/default_val.yaml") + parser.add_argument("--dataset", type=str, default="ALL") + parser.add_argument("--device", type=str, default="cuda", help="Device to use for evaluation") + args = parser.parse_args() + main(args) diff --git a/images/teaser.png b/images/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..0565e56973c430588bb28395e5f55162e2affef6 --- /dev/null +++ b/images/teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a617ca4fd37de03e2db4ccf397ce9841ed32c3fe18c766c4832d41af574ad746 +size 4287693 diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..65d0df2936b816c0327c8a5911b1886f77909e65 --- /dev/null +++ b/packages.txt @@ -0,0 +1,4 @@ +libosmesa6 +libgl1 +libegl1 +libgles2 diff --git a/prima/__init__.py b/prima/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a653e908faa00c04476586f4dbeedd876f468831 --- /dev/null +++ b/prima/__init__.py @@ -0,0 +1,25 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +"""Top-level package for PRIMA. + +This package contains models, datasets and utilities for +3D animal pose and shape estimation. +""" + +from importlib.metadata import PackageNotFoundError, version + + +try: # pragma: no cover - best effort during development + __version__ = version("prima-animal") +except PackageNotFoundError: # pragma: no cover + __version__ = "0.0.0" + + +__all__ = ["__version__"] diff --git a/prima/configs/__init__.py b/prima/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c465c6b3c894534c973c1e8ba8344716c4554e0 --- /dev/null +++ b/prima/configs/__init__.py @@ -0,0 +1,99 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from typing import Dict +from yacs.config import CfgNode as CN + +def to_lower(x: Dict) -> Dict: + """ + Convert all dictionary keys to lowercase + Args: + x (dict): Input dictionary + Returns: + dict: Output dictionary with all keys converted to lowercase + """ + return {k.lower(): v for k, v in x.items()} + + +_C = CN(new_allowed=True) + +_C.GENERAL = CN(new_allowed=True) +_C.GENERAL.RESUME = True +_C.GENERAL.TIME_TO_RUN = 3300 +_C.GENERAL.VAL_STEPS = 100 +_C.GENERAL.LOG_STEPS = 100 +_C.GENERAL.CHECKPOINT_STEPS = 20000 +_C.GENERAL.CHECKPOINT_DIR = "checkpoints" +_C.GENERAL.SUMMARY_DIR = "tensorboard" +_C.GENERAL.NUM_GPUS = 1 +_C.GENERAL.NUM_WORKERS = 4 +_C.GENERAL.MIXED_PRECISION = True +_C.GENERAL.ALLOW_CUDA = True +_C.GENERAL.PIN_MEMORY = False +_C.GENERAL.DISTRIBUTED = False +_C.GENERAL.LOCAL_RANK = 0 +_C.GENERAL.USE_SYNCBN = False +_C.GENERAL.WORLD_SIZE = 1 +_C.GENERAL.PREFETCH_FACTOR = 2 + +_C.TRAIN = CN(new_allowed=True) +_C.TRAIN.NUM_EPOCHS = 100 +_C.TRAIN.SHUFFLE = True +_C.TRAIN.WARMUP = False +_C.TRAIN.NORMALIZE_PER_IMAGE = False +_C.TRAIN.CLIP_GRAD = False +_C.TRAIN.CLIP_GRAD_VALUE = 1.0 +_C.LOSS_WEIGHTS = CN(new_allowed=True) + +_C.DATASETS = CN(new_allowed=True) + +_C.MODEL = CN(new_allowed=True) +_C.MODEL.IMAGE_SIZE = 224 + +_C.EXTRA = CN(new_allowed=True) +_C.EXTRA.FOCAL_LENGTH = 5000 + +_C.DATASETS.CONFIG = CN(new_allowed=True) +_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3 +_C.DATASETS.CONFIG.ROT_FACTOR = 30 +_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02 +_C.DATASETS.CONFIG.COLOR_SCALE = 0.2 +_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6 +_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5 +_C.DATASETS.CONFIG.DO_FLIP = False +_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5 +_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10 + + +def default_config() -> CN: + """ + Get a yacs CfgNode object with the default config values. + """ + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _C.clone() + + +def get_config(config_file: str, merge: bool = True) -> CN: + """ + Read a config file and optionally merge it with the default config file. + Args: + config_file (str): Path to config file. + merge (bool): Whether to merge with the default config or not. + Returns: + CfgNode: Config as a yacs CfgNode object. + """ + if merge: + cfg = default_config() + else: + cfg = CN(new_allowed=True) + cfg.merge_from_file(config_file) + + cfg.freeze() + return cfg diff --git a/prima/models/__init__.py b/prima/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d41a09abd2b42228d2a13373344415dc6d781ba8 --- /dev/null +++ b/prima/models/__init__.py @@ -0,0 +1,54 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from .prima import PRIMA + + +def load_prima(checkpoint_path): + from pathlib import Path + from ..configs import get_config + model_cfg = str(Path(checkpoint_path).parent.parent / '.hydra/config.yaml') + model_cfg = get_config(model_cfg) + + # Override some config values, to crop bbox correctly + if (model_cfg.MODEL.BACKBONE.TYPE == 'vit') and ('BBOX_SHAPE' not in model_cfg.MODEL): + model_cfg.defrost() + assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone" + model_cfg.MODEL.BBOX_SHAPE = [192, 256] + model_cfg.freeze() + if (model_cfg.MODEL.BACKBONE.TYPE == 'dinov3') and ('BBOX_SHAPE' not in model_cfg.MODEL): + model_cfg.defrost() + assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for dino backbone" + model_cfg.MODEL.BBOX_SHAPE = [256, 256] + model_cfg.freeze() + + if (model_cfg.MODEL.BACKBONE.TYPE == 'dinov2') and ('BBOX_SHAPE' not in model_cfg.MODEL): + model_cfg.defrost() + assert model_cfg.MODEL.IMAGE_SIZE == 252, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 252 for dino backbone" + model_cfg.MODEL.BBOX_SHAPE = [252, 252] + model_cfg.freeze() + + + + # Update config to be compatible with demo + if ('PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE): + model_cfg.defrost() + model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS') + model_cfg.freeze() + + # Offscreen training renderer is not needed for demo/inference startup and + # can fail on some local OpenGL backends. + model = PRIMA.load_from_checkpoint( + checkpoint_path, + strict=False, + cfg=model_cfg, + map_location='cpu', + init_renderer=False, + ) + return model, model_cfg diff --git a/prima/models/backbones/__init__.py b/prima/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fedb1b14855557664259adf3c4a331f185410363 --- /dev/null +++ b/prima/models/backbones/__init__.py @@ -0,0 +1,19 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from .vit import vith + + + + +def create_backbone(cfg): + if cfg.MODEL.BACKBONE.TYPE in ['vith','concat','aa']: # vit bb will be used in these three cases - animal feature extractor + return vith(cfg) + else: + raise NotImplementedError('Backbone type is not implemented') diff --git a/prima/models/backbones/vit.py b/prima/models/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b625913a166263119532e3b96de745259157c --- /dev/null +++ b/prima/models/backbones/vit.py @@ -0,0 +1,375 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +from functools import partial +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.layers import drop_path, to_2tuple, trunc_normal_ + + +def vith(cfg): + return ViT( + img_size=(256, 192), + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + ratio=1, + use_checkpoint=False, + # use_checkpoint=True, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.55, + use_cls=True, # cls for animal family classification + ) + + +def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. + + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + cls_token = None + B, L, C = abs_pos.shape + if has_cls_token: + cls_token = abs_pos[:, 0:1] + abs_pos = abs_pos[:, 1:] + + if ori_h != h or ori_w != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2), + size=(h, w), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).reshape(B, -1, C) + + else: + new_abs_pos = abs_pos + + if cls_token is not None: + new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1) + return new_abs_pos + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.dim = dim + + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, + norm_layer=nn.LayerNorm, attn_head_dim=None, + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim + ) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2) + self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio)) + self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1])) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), + padding=4 + 2 * (ratio // 2 - 1)) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class ViT(nn.Module): + + def __init__(self, + img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False, + frozen_stages=-1, ratio=1, last_norm=True, use_cls=False, + patch_padding='pad', freeze_attn=False, freeze_ffn=False, + ): + # Protect mutable default arguments + super(ViT, self).__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.frozen_stages = frozen_stages + self.use_checkpoint = use_checkpoint + self.patch_padding = patch_padding + self.freeze_attn = freeze_attn + self.freeze_ffn = freeze_ffn + self.depth = depth + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio) + num_patches = self.patch_embed.num_patches + + # since the pretraining model has class token + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + ) + for i in range(depth)]) + + self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + + self.use_cls = use_cls + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + nn.init.normal_(self.cls_token, std=1e-6) + + self._freeze_stages() + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if self.freeze_attn: + for i in range(0, self.depth): + m = self.blocks[i] + m.attn.eval() + m.norm1.eval() + for param in m.attn.parameters(): + param.requires_grad = False + for param in m.norm1.parameters(): + param.requires_grad = False + + if self.freeze_ffn: + self.pos_embed.requires_grad = False + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + for i in range(0, self.depth): + m = self.blocks[i] + m.mlp.eval() + m.norm2.eval() + for param in m.mlp.parameters(): + param.requires_grad = False + for param in m.norm2.parameters(): + param.requires_grad = False + + def init_weights(self): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + + if self.pos_embed is not None: + # fit for multiple GPU training + # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference + x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] + + x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) if self.use_cls else x + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x = self.last_norm(x) + + cls = x[:, 0] if self.use_cls else None + x = x[:, 1:] if self.use_cls else x + xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous() + + return xp, cls # shape [B, D, Hp, Wp], [B, D] + + def forward(self, x): + x, cls = self.forward_features(x) + return x, cls + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() diff --git a/prima/models/bioclip_embedding.py b/prima/models/bioclip_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1ebd5c1ada10748c48622c4eb3c5226e2d3136 --- /dev/null +++ b/prima/models/bioclip_embedding.py @@ -0,0 +1,70 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +""" +bioclip Embedding Module +Converts image batch to embeddings that can be concatenated with image features +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +class BioClipEmbedding(nn.Module): + """ + Embeds images into a feature space using BioClip model that can be combined with image features. + + Args: + embed_dim: Output embedding dimension, should match the dimension of image features for concatenation + """ + + def __init__(self, cfg, embed_dim: int = 1024): + super().__init__() + + self.embed_dim = embed_dim + + import open_clip + + if cfg.MODEL.BIOCLIP_EMBEDDING.TYPE == 'bioclip2': + print("[BioClipEmbedding] Using BioClip2 model from Hugging Face Hub") + self.species_model, _,_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip-2') + else: + self.species_model, _,_ = open_clip.create_model_and_transforms('hf-hub:imageomics/bioclip') + # tokenizer = open_clip.get_tokenizer('hf-hub:imageomics/bioclip') + + + self.species_model.eval() + + # Get the output dimension from the model + bioclip_output_dim = self.species_model.visual.output_dim + + # Project to target dimension + self.projection = nn.Sequential( + nn.Linear(bioclip_output_dim, embed_dim), + nn.LayerNorm(embed_dim), + ) + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Args: + images: Tensor of shape (B, C, H, W) representing a batch of images + Returns: + Tensor of shape (B, embed_dim) representing the embedded features + """ + # BioClip expects 224x224 input, resize if needed + if images.shape[-2:] != (224, 224): + images_resized = F.interpolate(images, size=(224, 224), mode='bilinear', align_corners=False) + else: + images_resized = images + + with torch.no_grad(): + image_features = self.species_model.encode_image(images_resized) + + projected_features = self.projection(image_features) + + return projected_features \ No newline at end of file diff --git a/prima/models/components/__init__.py b/prima/models/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/prima/models/components/model_utils.py b/prima/models/components/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..16d8a5ab4140da0335edf9720fa7788180defc13 --- /dev/null +++ b/prima/models/components/model_utils.py @@ -0,0 +1,160 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/prima/models/components/pose_transformer.py b/prima/models/components/pose_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..698e30bd866cc5ea7a85064d67f85b068584f022 --- /dev/null +++ b/prima/models/components/pose_transformer.py @@ -0,0 +1,366 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from inspect import isfunction +from typing import Callable, Optional + +import torch +from einops import rearrange +from einops.layers.torch import Rearrange +from torch import nn + +from .t_cond_mlp import ( + AdaptiveLayerNorm1D, + FrequencyEmbedder, + normalization_layer, +) + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +class PreNorm(nn.Module): + def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1): + super().__init__() + self.norm = normalization_layer(norm, dim, norm_cond_dim) + self.fn = fn + + def forward(self, x: torch.Tensor, *args, **kwargs): + if isinstance(self.norm, AdaptiveLayerNorm1D): + return self.fn(self.norm(x, *args), **kwargs) + else: + return self.fn(self.norm(x), **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout=0.0): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class CrossAttention(nn.Module): + def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + context_dim = default(context_dim, dim) + self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) + self.to_q = nn.Linear(dim, inner_dim, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x, context=None): + context = default(context, x) + k, v = self.to_kv(context).chunk(2, dim=-1) + q = self.to_q(x) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v]) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + norm: str = "layer", + norm_cond_dim: int = -1, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + ff = FeedForward(dim, mlp_dim, dropout=dropout) + self.layers.append( + nn.ModuleList( + [ + PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim), + ] + ) + ) + + def forward(self, x: torch.Tensor, *args): + for attn, ff in self.layers: + x = attn(x, *args) + x + x = ff(x, *args) + x + return x + + +class TransformerCrossAttn(nn.Module): + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + norm: str = "layer", + norm_cond_dim: int = -1, + context_dim: Optional[int] = None, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout) + ca = CrossAttention( + dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout + ) + ff = FeedForward(dim, mlp_dim, dropout=dropout) + self.layers.append( + nn.ModuleList( + [ + PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim), + PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim), + ] + ) + ) + + def forward(self, x: torch.Tensor, *args, context=None, context_list=None): + if context_list is None: + context_list = [context] * len(self.layers) + if len(context_list) != len(self.layers): + raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})") + + for i, (self_attn, cross_attn, ff) in enumerate(self.layers): + x = self_attn(x, *args) + x + x = cross_attn(x, *args, context=context_list[i]) + x + x = ff(x, *args) + x + return x + + +class DropTokenDropout(nn.Module): + def __init__(self, p: float = 0.1): + super().__init__() + if p < 0 or p > 1: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + self.p = p + + def forward(self, x: torch.Tensor): + # x: (batch_size, seq_len, dim) + if self.training and self.p > 0: + zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool() + + if zero_mask.any(): + x = x[:, ~zero_mask, :] + return x + + +class ZeroTokenDropout(nn.Module): + def __init__(self, p: float = 0.1): + super().__init__() + if p < 0 or p > 1: + raise ValueError( + "dropout probability has to be between 0 and 1, " "but got {}".format(p) + ) + self.p = p + + def forward(self, x: torch.Tensor): + # x: (batch_size, seq_len, dim) + if self.training and self.p > 0: + zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool() + # Zero-out the masked tokens + x[zero_mask, :] = 0 + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, + num_tokens: int, + token_dim: int, + dim: int, + depth: int, + heads: int, + mlp_dim: int, + dim_head: int = 64, + dropout: float = 0.0, + emb_dropout: float = 0.0, + emb_dropout_type: str = "drop", + emb_dropout_loc: str = "token", + norm: str = "layer", + norm_cond_dim: int = -1, + token_pe_numfreq: int = -1, + ): + super().__init__() + if token_pe_numfreq > 0: + token_dim_new = token_dim * (2 * token_pe_numfreq + 1) + self.to_token_embedding = nn.Sequential( + Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim), + FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1), + Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new), + nn.Linear(token_dim_new, dim), + ) + else: + self.to_token_embedding = nn.Linear(token_dim, dim) + self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim)) + if emb_dropout_type == "drop": + self.dropout = DropTokenDropout(emb_dropout) + elif emb_dropout_type == "zero": + self.dropout = ZeroTokenDropout(emb_dropout) + else: + raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}") + self.emb_dropout_loc = emb_dropout_loc + + self.transformer = Transformer( + dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim + ) + + def forward(self, inp: torch.Tensor, *args, **kwargs): + x = inp + + if self.emb_dropout_loc == "input": + x = self.dropout(x) + x = self.to_token_embedding(x) + + if self.emb_dropout_loc == "token": + x = self.dropout(x) + b, n, _ = x.shape + x += self.pos_embedding[:, :n] + + if self.emb_dropout_loc == "token_afterpos": + x = self.dropout(x) + x = self.transformer(x, *args) + return x + + +class TransformerDecoder(nn.Module): + def __init__( + self, + num_tokens: int, + token_dim: int, + dim: int, + depth: int, + heads: int, + mlp_dim: int, + dim_head: int = 64, + dropout: float = 0.0, + emb_dropout: float = 0.0, + emb_dropout_type: str = 'drop', + norm: str = "layer", + norm_cond_dim: int = -1, + context_dim: Optional[int] = None, + skip_token_embedding: bool = False, + ): + super().__init__() + if not skip_token_embedding: + self.to_token_embedding = nn.Linear(token_dim, dim) + else: + self.to_token_embedding = nn.Identity() + if token_dim != dim: + raise ValueError( + f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True" + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim)) + if emb_dropout_type == "drop": + self.dropout = DropTokenDropout(emb_dropout) + elif emb_dropout_type == "zero": + self.dropout = ZeroTokenDropout(emb_dropout) + elif emb_dropout_type == "normal": + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = TransformerCrossAttn( + dim, + depth, + heads, + dim_head, + mlp_dim, + dropout, + norm=norm, + norm_cond_dim=norm_cond_dim, + context_dim=context_dim, + ) + + def forward(self, inp: torch.Tensor, *args, context=None, context_list=None): + x = self.to_token_embedding(inp) + b, n, _ = x.shape + + x = self.dropout(x) + x += self.pos_embedding[:, :n] + + x = self.transformer(x, *args, context=context, context_list=context_list) + return x + diff --git a/prima/models/components/position_encoding.py b/prima/models/components/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..8456b77882ac9ad7d8c86dc84209ab63a06d4f46 --- /dev/null +++ b/prima/models/components/position_encoding.py @@ -0,0 +1,84 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np + +import torch +from torch import nn + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/prima/models/components/t_cond_mlp.py b/prima/models/components/t_cond_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..df2e931aea60e82dbe5cb2b08cad111e401f1e41 --- /dev/null +++ b/prima/models/components/t_cond_mlp.py @@ -0,0 +1,204 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import copy +from typing import List, Optional + +import torch + + +class AdaptiveLayerNorm1D(torch.nn.Module): + def __init__(self, data_dim: int, norm_cond_dim: int): + super().__init__() + if data_dim <= 0: + raise ValueError(f"data_dim must be positive, but got {data_dim}") + if norm_cond_dim <= 0: + raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}") + self.norm = torch.nn.LayerNorm(data_dim) + self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim) + torch.nn.init.zeros_(self.linear.weight) + torch.nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + # x: (batch, ..., data_dim) + # t: (batch, norm_cond_dim) + # return: (batch, data_dim) + x = self.norm(x) + alpha, beta = self.linear(t).chunk(2, dim=-1) + + # Add singleton dimensions to alpha and beta + if x.dim() > 2: + alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1]) + beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1]) + + return x * (1 + alpha) + beta + + +class SequentialCond(torch.nn.Sequential): + def forward(self, input, *args, **kwargs): + for module in self: + if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)): + input = module(input, *args, **kwargs) + else: + input = module(input) + return input + + +def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1): + if norm == "batch": + return torch.nn.BatchNorm1d(dim) + elif norm == "layer": + return torch.nn.LayerNorm(dim) + elif norm == "ada": + assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}" + return AdaptiveLayerNorm1D(dim, norm_cond_dim) + elif norm is None: + return torch.nn.Identity() + else: + raise ValueError(f"Unknown norm: {norm}") + + +def linear_norm_activ_dropout( + input_dim: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, +) -> SequentialCond: + layers = [] + layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias)) + if norm is not None: + layers.append(normalization_layer(norm, output_dim, norm_cond_dim)) + layers.append(copy.deepcopy(activation)) + if dropout > 0.0: + layers.append(torch.nn.Dropout(dropout)) + return SequentialCond(*layers) + + +def create_simple_mlp( + input_dim: int, + hidden_dims: List[int], + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, +) -> SequentialCond: + layers = [] + prev_dim = input_dim + for hidden_dim in hidden_dims: + layers.extend( + linear_norm_activ_dropout( + prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim + ) + ) + prev_dim = hidden_dim + layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias)) + return SequentialCond(*layers) + + +class ResidualMLPBlock(torch.nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + num_hidden_layers: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + norm_cond_dim: int = -1, + ): + super().__init__() + if not (input_dim == output_dim == hidden_dim): + raise NotImplementedError( + f"input_dim {input_dim} != output_dim {output_dim} is not implemented" + ) + + layers = [] + prev_dim = input_dim + for i in range(num_hidden_layers): + layers.append( + linear_norm_activ_dropout( + prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim + ) + ) + prev_dim = hidden_dim + self.model = SequentialCond(*layers) + self.skip = torch.nn.Identity() + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return x + self.model(x, *args, **kwargs) + + +class ResidualMLP(torch.nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + num_hidden_layers: int, + output_dim: int, + activation: torch.nn.Module = torch.nn.ReLU(), + bias: bool = True, + norm: Optional[str] = "layer", # Options: ada/batch/layer + dropout: float = 0.0, + num_blocks: int = 1, + norm_cond_dim: int = -1, + ): + super().__init__() + self.input_dim = input_dim + self.model = SequentialCond( + linear_norm_activ_dropout( + input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim + ), + *[ + ResidualMLPBlock( + hidden_dim, + hidden_dim, + num_hidden_layers, + hidden_dim, + activation, + bias, + norm, + dropout, + norm_cond_dim, + ) + for _ in range(num_blocks) + ], + torch.nn.Linear(hidden_dim, output_dim, bias=bias), + ) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return self.model(x, *args, **kwargs) + + +class FrequencyEmbedder(torch.nn.Module): + def __init__(self, num_frequencies, max_freq_log2): + super().__init__() + frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies) + self.register_buffer("frequencies", frequencies) + + def forward(self, x): + # x should be of size (N,) or (N, D) + N = x.size(0) + if x.dim() == 1: # (N,) + x = x.unsqueeze(1) # (N, D) where D=1 + x_unsqueezed = x.unsqueeze(-1) # (N, D, 1) + scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies) + s = torch.sin(scaled) + c = torch.cos(scaled) + embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view( + N, -1 + ) # (N, D * 2 * num_frequencies + D) + return embedded + diff --git a/prima/models/components/transformer.py b/prima/models/components/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5e42ddf114db7fa1ec4e95aef8f3ebb6fbbbd72f --- /dev/null +++ b/prima/models/components/transformer.py @@ -0,0 +1,400 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import math +import warnings +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from .position_encoding import apply_rotary_enc, compute_axial_cis +from .model_utils import MLP + +warnings.simplefilter(action="ignore", category=FutureWarning) + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +# Check whether Flash Attention is available (and use it by default) +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +# A fallback setting to allow all available kernels if Flash Attention fails +ALLOW_ALL_KERNELS = False + + +def sdp_kernel_context(dropout_p): + """ + Get the context for the attention scaled dot-product kernel. We use Flash Attention + by default, but fall back to all available kernels if Flash Attention fails. + """ + if ALLOW_ALL_KERNELS: + return contextlib.nullcontext() + + return torch.backends.cuda.sdp_kernel( + enable_flash=USE_FLASH_ATTN, + # if Flash attention kernel is off, then math kernel needs to be enabled + enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + enable_mem_efficient=OLD_GPU, + ) + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2).contiguous() # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2).contiguous() + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int=0, + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/prima/models/discriminator.py b/prima/models/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..1465965754417f203337d437f665299fc5df8e34 --- /dev/null +++ b/prima/models/discriminator.py @@ -0,0 +1,129 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import torch +import torch.nn as nn + + +class Discriminator(nn.Module): + + def __init__(self): + """ + Pose + Shape discriminator proposed in HMR + """ + super(Discriminator, self).__init__() + + self.num_joints = 34 + # poses_alone + self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1) + nn.init.xavier_uniform_(self.D_conv1.weight) + nn.init.zeros_(self.D_conv1.bias) + self.relu = nn.ReLU(inplace=True) + self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1) + nn.init.xavier_uniform_(self.D_conv2.weight) + nn.init.zeros_(self.D_conv2.bias) + pose_out = [] + for i in range(self.num_joints): + pose_out_temp = nn.Linear(32, 1) + nn.init.xavier_uniform_(pose_out_temp.weight) + nn.init.zeros_(pose_out_temp.bias) + pose_out.append(pose_out_temp) + self.pose_out = nn.ModuleList(pose_out) + + # betas + self.betas_fc1 = nn.Linear(41, 10) # SMAL betas is 41 + nn.init.xavier_uniform_(self.betas_fc1.weight) + nn.init.zeros_(self.betas_fc1.bias) + self.betas_fc2 = nn.Linear(10, 5) + nn.init.xavier_uniform_(self.betas_fc2.weight) + nn.init.zeros_(self.betas_fc2.bias) + self.betas_out = nn.Linear(5, 1) + nn.init.xavier_uniform_(self.betas_out.weight) + nn.init.zeros_(self.betas_out.bias) + + # bones + self.bone_fc1 = nn.Linear(24, 10) # SMAL betas is 41 + nn.init.xavier_uniform_(self.bone_fc1.weight) + nn.init.zeros_(self.bone_fc1.bias) + self.bone_fc2 = nn.Linear(10, 5) + nn.init.xavier_uniform_(self.bone_fc2.weight) + nn.init.zeros_(self.bone_fc2.bias) + self.bone_out = nn.Linear(5, 1) + nn.init.xavier_uniform_(self.bone_out.weight) + nn.init.zeros_(self.bone_out.bias) + + # poses_joint + self.D_alljoints_fc1 = nn.Linear(32 * self.num_joints, 1024) + nn.init.xavier_uniform_(self.D_alljoints_fc1.weight) + nn.init.zeros_(self.D_alljoints_fc1.bias) + self.D_alljoints_fc2 = nn.Linear(1024, 1024) + nn.init.xavier_uniform_(self.D_alljoints_fc2.weight) + nn.init.zeros_(self.D_alljoints_fc2.bias) + self.D_alljoints_out = nn.Linear(1024, 1) + nn.init.xavier_uniform_(self.D_alljoints_out.weight) + nn.init.zeros_(self.D_alljoints_out.bias) + + def forward(self, poses: torch.Tensor, betas: torch.Tensor, bone=None) -> torch.Tensor: + """ + Forward pass of the discriminator. + Args: + poses (torch.Tensor): Tensor of shape (B, 23, 3, 3) containing a batch of poses (excluding the global orientation). + betas (torch.Tensor): Tensor of shape (B, 41) containing a batch of SMAL beta coefficients. + Returns: + torch.Tensor: Discriminator output with shape (B, 25) + """ + # bn = poses.shape[0] + # poses B x 207 + # poses = poses.reshape(bn, -1) + # poses B x num_joints x 1 x 9 + poses = poses.reshape(-1, self.num_joints, 1, 9) + bn = poses.shape[0] + # poses B x 9 x num_joints x 1 + poses = poses.permute(0, 3, 1, 2).contiguous() + + # poses_alone + poses = self.D_conv1(poses) + poses = self.relu(poses) + poses = self.D_conv2(poses) + poses = self.relu(poses) + + poses_out = [] + for i in range(self.num_joints): + poses_out_ = self.pose_out[i](poses[:, :, i, 0]) + poses_out.append(poses_out_) + poses_out = torch.cat(poses_out, dim=1) + + # betas + betas = self.betas_fc1(betas) + betas = self.relu(betas) + betas = self.betas_fc2(betas) + betas = self.relu(betas) + betas_out = self.betas_out(betas) + + # bone + if bone is not None: + bone = self.bone_fc1(bone) + bone = self.relu(bone) + bone = self.bone_fc2(bone) + bone = self.relu(bone) + bone_out = self.bone_out(bone) + + # poses_joint + poses = poses.reshape(bn, -1) + poses_all = self.D_alljoints_fc1(poses) + poses_all = self.relu(poses_all) + poses_all = self.D_alljoints_fc2(poses_all) + poses_all = self.relu(poses_all) + poses_all_out = self.D_alljoints_out(poses_all) + + if bone is not None: + disc_out = torch.cat((poses_out, betas_out, poses_all_out, bone_out), 1) + else: + disc_out = torch.cat((poses_out, betas_out, poses_all_out), 1) + return disc_out diff --git a/prima/models/heads/__init__.py b/prima/models/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61bc900f3c0a0506ca8b45f547c054167726dbc5 --- /dev/null +++ b/prima/models/heads/__init__.py @@ -0,0 +1,10 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from .smal_head import build_smal_head diff --git a/prima/models/heads/classifier_head.py b/prima/models/heads/classifier_head.py new file mode 100644 index 0000000000000000000000000000000000000000..ff034cbe5a25dc90f5fb4926768ca0e41d93b6e1 --- /dev/null +++ b/prima/models/heads/classifier_head.py @@ -0,0 +1,30 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from torch import nn + + +class ClassTokenHead(nn.Module): + def __init__(self, embed_dim=1280, hidden_dim=4096, output_dim=256, num_layers=3, last_bn=True): + super().__init__() + mlp = [] + for l in range(num_layers): + dim1 = embed_dim if l == 0 else hidden_dim + dim2 = output_dim if l == num_layers - 1 else hidden_dim + mlp.append(nn.Linear(dim1, dim2, bias=False)) + if l < num_layers - 1: + mlp.append(nn.BatchNorm1d(dim2)) + mlp.append(nn.ReLU(inplace=True)) + elif last_bn: + mlp.append(nn.BatchNorm1d(dim2, affine=False)) + self.head = nn.Sequential(*mlp) + + def forward(self, x): + cls_feats = self.head(x) + return cls_feats \ No newline at end of file diff --git a/prima/models/heads/smal_head.py b/prima/models/heads/smal_head.py new file mode 100644 index 0000000000000000000000000000000000000000..53e20b0d9b4d020a79bddc6cfaa9e59f1dc16188 --- /dev/null +++ b/prima/models/heads/smal_head.py @@ -0,0 +1,647 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import einops +import pickle as pkl +from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat +from ..components.pose_transformer import TransformerDecoder + + +def build_smal_head(cfg): + smal_head_type = cfg.MODEL.SMAL_HEAD.get('TYPE', 'amr') + + if smal_head_type == 'new_bio_pose_transformer_decoder': + return NewBioGuidedSMALPoseDecoder(cfg) + else: + raise ValueError('Unknown SMAL head type: {}'.format(smal_head_type)) + + + + + + + +class NewBioGuidedSMALPoseDecoder(nn.Module): + """ + Bio-Guided SMAL Decoder with Pose Token Aggregation + + Final version: + - Query tokens = [param token] + [2D keypoint tokens (optional)] + [3D keypoint tokens (optional)] + - SAM3D-body-style layer-wise keypoint token updates: + * 2D: predict (x,y) in [-0.5,0.5] from kp2d tokens -> token_augment position encoding + + grid_sample image features at predicted locations -> add into kp2d token embeddings + + invalid_mask (out-of-bounds and optional vis mask) zeroes updates + * 3D: predict (x,y,z) from kp3d tokens -> pelvis-normalize -> token_augment position encoding + - token_augment is injected by feeding (token_embeddings + token_augment) into each decoder layer. + - Only param token (index 0) is used to regress pose/betas/cam deltas. + - Outputs: + pred_smal_params: dict with global_orient/pose/betas and optional keypoints_2d/3d + pred_cam: [B,3] + extra_outputs: includes bio-guided shape_feat/init_betas and pred_smal_params_list + """ + + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + # ========== Basic config ========== + self.joint_rep_type = cfg.MODEL.SMAL_HEAD.get("JOINT_REP", "6d") + self.joint_rep_dim = {"6d": 6, "aa": 3}[self.joint_rep_type] + self.npose = self.joint_rep_dim * (cfg.SMAL.NUM_JOINTS + 1) + + # ========== Dimensions ========== + self.decoder_dim = cfg.MODEL.SMAL_HEAD.get("DECODER_DIM", 1024) + context_dim = cfg.MODEL.SMAL_HEAD.get("IN_CHANNELS", 1024) + num_layers = cfg.MODEL.SMAL_HEAD.get("NUM_DECODER_LAYERS", 4) + num_heads = cfg.MODEL.SMAL_HEAD.get("NUM_HEADS", 8) + mlp_ratio = cfg.MODEL.SMAL_HEAD.get("MLP_RATIO", 4.0) + + # keypoint config + self.use_keypoint_2d_tokens = cfg.MODEL.SMAL_HEAD.get("USE_KEYPOINT_2D_TOKENS", False) + self.use_keypoint_3d_tokens = cfg.MODEL.SMAL_HEAD.get("USE_KEYPOINT_3D_TOKENS", False) + self.num_keypoints = cfg.SMAL.get("NUM_KEYPOINTS", 26) + self.keypoint_token_update = cfg.MODEL.SMAL_HEAD.get("KEYPOINT_TOKEN_UPDATE", False) + + # 2D update: whether to inject sampled image feature into kp2d tokens + self.kp2d_inject_image_feat = cfg.MODEL.SMAL_HEAD.get("KP2D_INJECT_IMAGE_FEAT", True) + + # IEF iters + self.ief_iters = cfg.MODEL.SMAL_HEAD.get("IEF_ITERS", 3) + + # pelvis indices + self.pelvis_idx = cfg.SMAL.get("PELVIS_IDX", [0, 1]) + + # ========== Test-time optimization config ========== + self._tta_mode = False # Track if in test-time adaptation mode + + # ========== [Coarse] Bio prior ========== + self.bio_to_betas_init = nn.Sequential( + nn.Linear(context_dim, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 41), + ) + self.shape_projector = nn.Sequential( + nn.Linear(41, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 128), + ) + + # ========== Init pose/cam ========== + self.init_pose = nn.Parameter(torch.zeros(1, self.npose)) + self.init_cam = nn.Parameter(torch.tensor([[0.9, 0, 0]], dtype=torch.float32)) + + # params -> param token + param_dim = self.npose + 41 + 3 + self.param_to_token = nn.Sequential( + nn.Linear(param_dim, self.decoder_dim), + nn.LayerNorm(self.decoder_dim), + nn.ReLU(), + ) + + # ========== Keypoint token embeddings ========== + if self.use_keypoint_2d_tokens: + self.keypoint_2d_embeddings = nn.Embedding(self.num_keypoints, self.decoder_dim) + nn.init.normal_(self.keypoint_2d_embeddings.weight, std=0.02) + + # (x,y) -> token augment + self.keypoint_2d_pos_encoder = nn.Sequential( + nn.Linear(2, 256), + nn.ReLU(), + nn.Linear(256, self.decoder_dim), + ) + # sampled image feat -> token dim (add into token embeddings) + self.keypoint_2d_feat_linear = nn.Linear(self.decoder_dim, self.decoder_dim) + + if self.use_keypoint_3d_tokens: + self.keypoint_3d_embeddings = nn.Embedding(self.num_keypoints, self.decoder_dim) + nn.init.normal_(self.keypoint_3d_embeddings.weight, std=0.02) + + # (x,y,z) -> token augment + self.keypoint_3d_pos_encoder = nn.Sequential( + nn.Linear(3, 256), + nn.ReLU(), + nn.Linear(256, self.decoder_dim), + ) + + # ========== Per-token intermediate heads (predict from kp tokens themselves) ========== + if self.keypoint_token_update: + if self.use_keypoint_2d_tokens: + self.kp2d_from_tokens = nn.Sequential( + nn.Linear(self.decoder_dim, self.decoder_dim), + nn.ReLU(), + nn.Linear(self.decoder_dim, 2), + ) + if self.use_keypoint_3d_tokens: + self.kp3d_from_tokens = nn.Sequential( + nn.Linear(self.decoder_dim, self.decoder_dim), + nn.ReLU(), + nn.Linear(self.decoder_dim, 3), + ) + + # ========== Image feature projection + pos encoding ========== + self.image_proj = nn.Identity() if context_dim == self.decoder_dim else nn.Linear(context_dim, self.decoder_dim) + self.image_pos_encoding = PositionalEncoding2D(self.decoder_dim) + + # ========== Transformer decoder layers ========== + self.layers = nn.ModuleList( + [ + PoseTransformerDecoderLayer( + d_model=self.decoder_dim, + nhead=num_heads, + dim_feedforward=int(self.decoder_dim * mlp_ratio), + dropout=0.1, + ) + for _ in range(num_layers) + ] + ) + self.norm = nn.LayerNorm(self.decoder_dim) + + # ========== Regression heads (param token only) ========== + self.decpose = nn.Sequential( + nn.Linear(self.decoder_dim, self.decoder_dim), + nn.ReLU(), + nn.Linear(self.decoder_dim, self.npose), + ) + self.decshape = nn.Sequential( + nn.Linear(self.decoder_dim, self.decoder_dim), + nn.ReLU(), + nn.Linear(self.decoder_dim, 41), + ) + self.deccam = nn.Sequential( + nn.Linear(self.decoder_dim, self.decoder_dim // 2), + nn.ReLU(), + nn.Linear(self.decoder_dim // 2, 3), + ) + + # -------------------------- + # helpers: query token build + # -------------------------- + def _build_query_tokens(self, pred_pose, pred_betas, pred_cam): + B = pred_pose.shape[0] + tokens = [] + + params = torch.cat([pred_pose, pred_betas, pred_cam], dim=1) # [B, param_dim] + param_token = self.param_to_token(params).unsqueeze(1) # [B,1,D] + tokens.append(param_token) + + kp2d_start = None + kp3d_start = None + + if self.use_keypoint_2d_tokens: + kp2d_start = sum(t.shape[1] for t in tokens) + kp2d_tokens = self.keypoint_2d_embeddings.weight.unsqueeze(0).expand(B, -1, -1).contiguous() + tokens.append(kp2d_tokens) + + if self.use_keypoint_3d_tokens: + kp3d_start = sum(t.shape[1] for t in tokens) + kp3d_tokens = self.keypoint_3d_embeddings.weight.unsqueeze(0).expand(B, -1, -1).contiguous() + tokens.append(kp3d_tokens) + + token_embeddings = torch.cat(tokens, dim=1) # [B,Nq,D] + token_augment = torch.zeros_like(token_embeddings) + + return token_embeddings, token_augment, kp2d_start, kp3d_start + + # -------------------------- + # helpers: updates + # -------------------------- + def _kp2d_update(self, token_embeddings, token_augment, image_features, kp2d_start, H, W, vis_mask=None): + """ + SAM3D-body-style 2D keypoint token update. + + image_features: [B, HW, D] projected + pos-encoded, with HW=H*W (expected 12*16) + vis_mask: optional [B,N] bool (True=valid) + """ + if not (self.keypoint_token_update and self.use_keypoint_2d_tokens): + return token_embeddings, token_augment, None + + B = token_embeddings.shape[0] + N = self.num_keypoints + + kp_tokens = token_embeddings[:, kp2d_start : kp2d_start + N, :] # [B,N,D] + + # predict coords in [-0.5,0.5] + pred_xy = self.kp2d_from_tokens(kp_tokens) + pred_xy = torch.tanh(pred_xy) * 0.5 + + # invalid mask (out of bounds + optional vis) + pred_xy_01 = pred_xy + 0.5 + invalid = ( + (pred_xy_01[..., 0] < 0.0) + | (pred_xy_01[..., 0] > 1.0) + | (pred_xy_01[..., 1] < 0.0) + | (pred_xy_01[..., 1] > 1.0) + ) + if vis_mask is not None: + invalid = invalid | (~vis_mask) + valid = (~invalid).unsqueeze(-1).float() # [B,N,1] + + # update token_augment slice + token_augment = token_augment.clone() + token_augment[:, kp2d_start : kp2d_start + N, :] = self.keypoint_2d_pos_encoder(pred_xy) * valid + + # inject sampled image feature into kp2d tokens (optional) + if self.kp2d_inject_image_feat: + img = image_features.view(B, H, W, self.decoder_dim).permute(0, 3, 1, 2).contiguous() # [B,D,H,W] + grid = (pred_xy * 2.0).unsqueeze(2) # [B,N,1,2] in [-1,1] + + sampled = ( + F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=False) + .squeeze(3) + .permute(0, 2, 1) + .contiguous() + ) # [B,N,D] + + sampled = sampled * valid + token_embeddings = token_embeddings.clone() + token_embeddings[:, kp2d_start : kp2d_start + N, :] += self.keypoint_2d_feat_linear(sampled) + + return token_embeddings, token_augment, pred_xy + + def _kp3d_update(self, token_embeddings, token_augment, kp3d_start): + if not (self.keypoint_token_update and self.use_keypoint_3d_tokens): + return token_embeddings, token_augment, None + + N = self.num_keypoints + kp_tokens = token_embeddings[:, kp3d_start : kp3d_start + N, :] # [B,N,D] + + pred_xyz = self.kp3d_from_tokens(kp_tokens) # [B,N,3] + + # pelvis normalize + pelvis_center = pred_xyz[:, self.pelvis_idx, :].mean(dim=1, keepdim=True) # [B,1,3] + pred_xyz_norm = pred_xyz - pelvis_center + + token_augment = token_augment.clone() + token_augment[:, kp3d_start : kp3d_start + N, :] = self.keypoint_3d_pos_encoder(pred_xyz_norm) + + return token_embeddings, token_augment, pred_xyz + + # -------------------------- + # forward + # -------------------------- + def forward(self, x, keypoint_coords_2d=None, keypoint_coords_3d=None, **kwargs): + """ + Inputs: + x: [B, Hp*Wp+1, C] image tokens from backbone concatenated with bio token + (BioCLIP token is the last token in the sequence) + + Note: + keypoint_coords_2d can optionally provide a vis/conf mask: [B,N,3] (x,y,vis) + We do NOT inject GT coords into tokens by default; they are used only as optional masking. + """ + B = x.shape[0] + + # ---- Data preprocessing ---- + # Handle 4D input tensors of shape (B, C, H, W). + if len(x.shape) == 4: + x = einops.rearrange(x, 'b c h w -> b (h w) c') + + bio_token = x[:, -1, :] # [B, C] - the BioCLIP token is the final token + image_features = x[:, :-1, :] # [B, H*W, C] - remaining image features + + # ---- Coarse bio shape ---- + init_betas = self.bio_to_betas_init(bio_token) # [B,41] + shape_feat = F.normalize(self.shape_projector(init_betas), dim=1) + + # ---- Image feature projection ---- + # Project only image features, excluding the bio token. + image_features = self.image_proj(image_features) # [B,HW,D] + + # Your backbone: vit crop 256x192 with patch16 => Hp=12, Wp=16 + H, W = 12, 16 + assert image_features.shape[1] == H * W, f"Expected HW={H*W}, got {image_features.shape[1]}" + + img_pos = self.image_pos_encoding(H, W).to(image_features.device) # [HW,D] + image_features = image_features + img_pos.unsqueeze(0) + + # ---- init params ---- + pred_pose = self.init_pose.expand(B, -1) + pred_betas = init_betas + pred_cam = self.init_cam.expand(B, -1) + + pred_pose_list, pred_betas_list, pred_cam_list = [], [], [] + pred_keypoints_2d_list, pred_keypoints_3d_list = [], [] + + # Optional visibility mask from provided 2D keypoints + vis_mask = None + if keypoint_coords_2d is not None and keypoint_coords_2d.shape[-1] == 3: + vis_mask = keypoint_coords_2d[..., 2] > 0 # [B,N] + + # ---- IEF loop ---- + for _ in range(self.ief_iters): + token_embeddings, token_augment, kp2d_start, kp3d_start = self._build_query_tokens( + pred_pose, pred_betas, pred_cam + ) + + # ---- Transformer layers ---- + for layer_idx, layer in enumerate(self.layers): + # inject dynamic augment + tokens_in = token_embeddings + token_augment + token_embeddings = layer(tokens_in, image_features) + + # layer-wise token update (skip last layer) + if self.keypoint_token_update and (layer_idx < len(self.layers) - 1): + if self.use_keypoint_2d_tokens: + token_embeddings, token_augment, pred_xy = self._kp2d_update( + token_embeddings, token_augment, image_features, kp2d_start, H, W, vis_mask=vis_mask + ) + if pred_xy is not None: + pred_keypoints_2d_list.append(pred_xy) + + if self.use_keypoint_3d_tokens: + token_embeddings, token_augment, pred_xyz = self._kp3d_update( + token_embeddings, token_augment, kp3d_start + ) + if pred_xyz is not None: + pred_keypoints_3d_list.append(pred_xyz) + + # ---- Regress deltas from param token ---- + token_embeddings = self.norm(token_embeddings) + param_token_out = token_embeddings[:, 0, :] + + delta_pose = self.decpose(param_token_out) + delta_betas = self.decshape(param_token_out) + delta_cam = self.deccam(param_token_out) + + pred_pose = pred_pose + delta_pose + pred_betas = pred_betas + delta_betas + pred_cam = pred_cam + delta_cam + + pred_pose_list.append(pred_pose) + pred_betas_list.append(pred_betas) + pred_cam_list.append(pred_cam) + + # ---- Convert joint representation ---- + joint_conversion_fn = { + "6d": rot6d_to_rotmat, + "aa": lambda y: aa_to_rotmat(y.view(-1, 3).contiguous()), + }[self.joint_rep_type] + + pred_smal_params_list = { + "pose": torch.cat( + [joint_conversion_fn(p).view(B, -1, 3, 3)[:, 1:, :, :] for p in pred_pose_list], + dim=0, + ), + "betas": torch.cat(pred_betas_list, dim=0), + "cam": torch.cat(pred_cam_list, dim=0), + "keypoints_2d": torch.cat(pred_keypoints_2d_list, dim=0) if len(pred_keypoints_2d_list) else None, + "keypoints_3d": torch.cat(pred_keypoints_3d_list, dim=0) if len(pred_keypoints_3d_list) else None, + } + + pred_pose_mat = joint_conversion_fn(pred_pose).view(B, self.cfg.SMAL.NUM_JOINTS + 1, 3, 3) + pred_smal_params = { + "global_orient": pred_pose_mat[:, [0]], + "pose": pred_pose_mat[:, 1:], + "betas": pred_betas, + } + + # expose final predicted keypoints for losses + if self.keypoint_token_update: + if self.use_keypoint_2d_tokens and len(pred_keypoints_2d_list): + pred_smal_params["keypoints_2d"] = pred_keypoints_2d_list[-1] + if self.use_keypoint_3d_tokens and len(pred_keypoints_3d_list): + pred_smal_params["keypoints_3d"] = pred_keypoints_3d_list[-1] + + extra_outputs = { + "shape_feat": shape_feat, + "init_betas": init_betas, + "pred_smal_params_list": pred_smal_params_list, + } + return pred_smal_params, pred_cam, extra_outputs + + # -------------------------- + # Test-time optimization helpers + # -------------------------- + def freeze_all_except_keypoint_tokens(self): + """ + Freeze all parameters except keypoint token embeddings and their prediction heads. + Use this before test-time optimization. + """ + # Freeze everything first + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze only keypoint-related parameters + if self.use_keypoint_2d_tokens: + for param in self.keypoint_2d_embeddings.parameters(): + param.requires_grad = True + for param in self.keypoint_2d_pos_encoder.parameters(): + param.requires_grad = True + for param in self.keypoint_2d_feat_linear.parameters(): + param.requires_grad = True + if self.keypoint_token_update: + for param in self.kp2d_from_tokens.parameters(): + param.requires_grad = True + + if self.use_keypoint_3d_tokens: + for param in self.keypoint_3d_embeddings.parameters(): + param.requires_grad = True + for param in self.keypoint_3d_pos_encoder.parameters(): + param.requires_grad = True + if self.keypoint_token_update: + for param in self.kp3d_from_tokens.parameters(): + param.requires_grad = True + + self._tta_mode = True + print("[TTA] Frozen all parameters except keypoint tokens") + + def freeze_backbone_only(self): + """ + Freeze only backbone, keep SMAL head trainable. + Use for full SMAL parameter + keypoint optimization. + """ + # Unfreeze all SMAL head parameters + for param in self.parameters(): + param.requires_grad = True + + self._tta_mode = True + print("[TTA] SMAL head fully trainable (backbone frozen separately)") + + def freeze_except_regression_heads(self): + """ + Freeze everything except the final regression heads (pose/shape/cam) and keypoint embeddings. + Keep transformer frozen to preserve pretrained representations. + """ + # Freeze everything first + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze only the final regression heads (small MLPs) + for param in self.decpose.parameters(): + param.requires_grad = True + for param in self.decshape.parameters(): + param.requires_grad = True + for param in self.deccam.parameters(): + param.requires_grad = True + + # Unfreeze ONLY keypoint embeddings (learned tokens, NOT position encoders) + if self.use_keypoint_2d_tokens: + self.keypoint_2d_embeddings.weight.requires_grad = True + + if self.use_keypoint_3d_tokens: + self.keypoint_3d_embeddings.weight.requires_grad = True + + # DO NOT unfreeze transformer - keep pretrained representations + # DO NOT unfreeze param_to_token - keep initial token mapping stable + + self._tta_mode = True + print("[TTA] Frozen all except regression heads and keypoint embeddings") + + def unfreeze_all(self): + """Restore all parameters to trainable state.""" + for param in self.parameters(): + param.requires_grad = True + self._tta_mode = False + print("[TTA] Unfrozen all parameters") + + def get_tta_parameters(self, mode='keypoints_only'): + """ + Get list of parameters that should be optimized during test-time adaptation. + MUST match what's unfrozen by freeze methods! + + Args: + mode: 'keypoints_only', 'regression_heads', or 'all' + """ + params = [] + + # Keypoint embeddings only (NOT position encoders or feature linears) + if mode in ['keypoints_only', 'regression_heads', 'all']: + if self.use_keypoint_2d_tokens: + params.append(self.keypoint_2d_embeddings.weight) + + if self.use_keypoint_3d_tokens: + params.append(self.keypoint_3d_embeddings.weight) + + # Regression heads only (NO transformer or param_to_token) + if mode in ['regression_heads', 'all']: + params.extend(list(self.decpose.parameters())) + params.extend(list(self.decshape.parameters())) + params.extend(list(self.deccam.parameters())) + + return params + + + + +class PoseTransformerDecoderLayer(nn.Module): + """ + Single-layer transformer decoder for pose-token aggregation. + Includes self-attention over tokens, cross-attention from tokens to + image features, and a feed-forward network. + """ + + def __init__(self, d_model=1024, nhead=8, dim_feedforward=4096, dropout=0.1): + super().__init__() + + # Self-attention over tokens + self.self_attn = nn.MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=True + ) + self.norm1 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + + # Cross-attention from image features into tokens + self.cross_attn = nn.MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=True + ) + self.norm2 = nn.LayerNorm(d_model) + self.dropout2 = nn.Dropout(dropout) + + # Feed-Forward Network + self.ffn = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + nn.Dropout(dropout), + ) + self.norm3 = nn.LayerNorm(d_model) + + def forward(self, tokens, image_features): + """ + Args: + tokens: [B, N_tokens, C] containing pose and keypoint tokens + image_features: [B, N_pixels, C] image features + + Returns: + tokens: [B, N_tokens, C] updated tokens + """ + + # Self-attention lets tokens exchange information. + attn_output, _ = self.self_attn(tokens, tokens, tokens) + tokens = tokens + self.dropout1(attn_output) + tokens = self.norm1(tokens) + + # Cross-attention injects visual information from image features. + attn_output, _ = self.cross_attn( + query=tokens, + key=image_features, + value=image_features, + ) + tokens = tokens + self.dropout2(attn_output) + tokens = self.norm2(tokens) + + # Feed-Forward Network + ffn_output = self.ffn(tokens) + tokens = tokens + ffn_output + tokens = self.norm3(tokens) + + return tokens + + +class PositionalEncoding2D(nn.Module): + """ + 2D sinusoidal positional encoding for image features. + """ + + def __init__(self, embed_dim=1024, temperature=10000): + super().__init__() + self.embed_dim = embed_dim + self.temperature = temperature + + def forward(self, H, W): + """ + Args: + H, W: height and width of the feature map + + Returns: + pos_encoding: [H*W, embed_dim] + """ + # Build grid coordinates. + y_embed = torch.arange(H, dtype=torch.float32).unsqueeze(1).repeat(1, W) + x_embed = torch.arange(W, dtype=torch.float32).unsqueeze(0).repeat(H, 1) + + # Normalize to [0, 1]. + y_embed = y_embed / H + x_embed = x_embed / W + + # Build frequencies. + dim_t = torch.arange(self.embed_dim // 2, dtype=torch.float32) + dim_t = self.temperature ** (2 * dim_t / self.embed_dim) + + # Apply sine/cosine encoding. + pos_x = x_embed[: , : , None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + + pos_x = torch.stack( + [pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()], dim=3 + ).flatten(2) + pos_y = torch. stack( + [pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()], dim=3 + ).flatten(2) + + pos = torch.cat([pos_y, pos_x], dim=2).flatten(0, 1) # [H*W, embed_dim] + + return pos + + diff --git a/prima/models/losses.py b/prima/models/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..2982d17689d36d91ab8d351e354a04c1699ddf75 --- /dev/null +++ b/prima/models/losses.py @@ -0,0 +1,580 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import torch +import torch.nn as nn +import numpy as np +import pickle +import torch.nn.functional as F +from ..utils.geometry import aa_to_rotmat +from typing import Dict + + +def matrix_to_axis_angle(rot_mats: torch.Tensor) -> torch.Tensor: + """Convert rotation matrices (..., 3, 3) to axis-angle vectors (..., 3). + + This local implementation avoids a hard runtime dependency on PyTorch3D. + """ + if rot_mats.shape[-2:] != (3, 3): + raise ValueError(f"Expected (..., 3, 3) rotation matrices, got {rot_mats.shape}") + + trace = rot_mats[..., 0, 0] + rot_mats[..., 1, 1] + rot_mats[..., 2, 2] + cos_theta = (trace - 1.0) * 0.5 + cos_theta = torch.clamp(cos_theta, -1.0, 1.0) + theta = torch.acos(cos_theta) + + vee = torch.stack( + [ + rot_mats[..., 2, 1] - rot_mats[..., 1, 2], + rot_mats[..., 0, 2] - rot_mats[..., 2, 0], + rot_mats[..., 1, 0] - rot_mats[..., 0, 1], + ], + dim=-1, + ) + sin_theta = torch.sin(theta) + eps = 1e-6 + scale = theta / torch.clamp(2.0 * sin_theta, min=eps) + aa = vee * scale.unsqueeze(-1) + + # For very small angles, first-order approximation: aa ~= 0.5 * vee + small = theta < 1e-4 + if small.any(): + aa = torch.where(small.unsqueeze(-1), 0.5 * vee, aa) + return aa + + + +class DepthLoss(nn.Module): + """ + Depth loss between predicted SMAL vertices and GT SMAL vertices. + Compares vertex Z (camera space) after applying camera translation. + Only computes loss for samples that have valid GT SMAL parameters. + """ + def __init__(self, loss_type: str = 'l1'): + super().__init__() + self.loss_type = loss_type + self.l1 = nn.L1Loss(reduction='none') # Changed to 'none' for per-sample masking + self.l2 = nn.MSELoss(reduction='none') # Changed to 'none' for per-sample masking + + def forward(self, + pred_vertices: torch.Tensor, # (B, V, 3) + pred_cam_t: torch.Tensor, # (B, 3) + gt_smal_params: Dict[str, torch.Tensor], + smal_model, # SMAL instance callable + is_axis_angle: Dict[str, torch.Tensor], + gt_cam_t: torch.Tensor = None, # (B, 3) or None -> fallback to pred_cam_t + has_smal_params: Dict[str, torch.Tensor] = None # Added masking support + ) -> torch.Tensor: + batch_size = pred_vertices.shape[0] + device = pred_vertices.device + + # Determine which samples have valid GT SMAL params + # A sample is valid only if it has GT for pose, betas, and global_orient + if has_smal_params is not None: + valid_mask = (has_smal_params['pose'] * + has_smal_params['betas'] * + has_smal_params['global_orient']).bool() + + # If no samples have valid GT, return zero loss + if valid_mask.sum() == 0: + return torch.tensor(0., device=device, dtype=pred_vertices.dtype) + else: + # If not provided, assume all samples are valid + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device) + + # prepare GT params for SMAL + gt_params_for_smal = {} + for k in ['global_orient', 'pose', 'betas']: + val = gt_smal_params[k].to(device=device) + if k == 'betas': + gt_params_for_smal[k] = val.view(batch_size, -1) + else: + gt_val = val.view(batch_size, -1) + if is_axis_angle[k].all(): + gt_val = aa_to_rotmat(gt_val.reshape(-1, 3)).view(batch_size, -1, 3, 3) + else: + gt_val = gt_val.view(batch_size, -1, 3, 3) + gt_params_for_smal[k] = gt_val + + # generate GT vertices (no grad) + with torch.no_grad(): + gt_out = smal_model(**gt_params_for_smal, pose2rot=False) + gt_vertices = gt_out.vertices.view(batch_size, -1, 3) + + if gt_cam_t is None: + gt_cam_t = pred_cam_t + + # depth = z in camera coordinates + pred_depth = (pred_vertices + pred_cam_t.unsqueeze(1))[..., 2] # (B, V) + gt_depth = (gt_vertices + gt_cam_t.unsqueeze(1))[..., 2] # (B, V) + + # Compute loss per sample + if self.loss_type == 'l1': + loss_per_sample = self.l1(pred_depth, gt_depth).mean(dim=1) # (B,) + else: + loss_per_sample = self.l2(pred_depth, gt_depth).mean(dim=1) # (B,) + + # Apply mask: only compute loss for samples with valid GT + masked_loss = loss_per_sample * valid_mask.float() + + # Return mean over valid samples + num_valid = valid_mask.sum().float() + if num_valid > 0: + return masked_loss.sum() / num_valid + else: + return torch.tensor(0., device=device, dtype=pred_vertices.dtype) + + +class MaskLoss(nn.Module): + """ + Mask loss between rendered predicted mesh mask and rendered GT mesh mask. + This loss relies on a MeshRenderer-like object that provides `render_mask(vertices, camera_translation, focal_length)` + returning a single-channel numpy mask (H, W) with values 0/1. + """ + def __init__(self, mesh_renderer=None): + super().__init__() + self.mesh_renderer = mesh_renderer + self.l1 = nn.L1Loss(reduction='mean') + + def forward(self, + pred_vertices: torch.Tensor, # (B, V, 3) + pred_cam_t: torch.Tensor, # (B, 3) + gt_smal_params: Dict[str, torch.Tensor], + smal_model, # SMAL instance callable + is_axis_angle: Dict[str, torch.Tensor], + gt_cam_t: torch.Tensor = None, # optional (B,3) + focal_length: float = 1000.0 + ) -> torch.Tensor: + batch_size = pred_vertices.shape[0] + device = pred_vertices.device + + # if no renderer available, return zero loss + if self.mesh_renderer is None: + return torch.tensor(0., device=device, dtype=pred_vertices.dtype) + + # prepare GT params for SMAL + gt_params_for_smal = {} + for k in ['global_orient', 'pose', 'betas']: + val = gt_smal_params[k].to(device=device) + if k == 'betas': + gt_params_for_smal[k] = val.view(batch_size, -1) + else: + gt_val = val.view(batch_size, -1) + if is_axis_angle[k].all(): + gt_val = aa_to_rotmat(gt_val.reshape(-1, 3)).view(batch_size, -1, 3, 3) + else: + gt_val = gt_val.view(batch_size, -1, 3, 3) + gt_params_for_smal[k] = gt_val + + # generate GT vertices (no grad) + with torch.no_grad(): + gt_out = smal_model(**gt_params_for_smal, pose2rot=False) + gt_vertices = gt_out.vertices + + if gt_cam_t is None: + gt_cam_t = pred_cam_t + + # convert to numpy for renderer + pred_vertices_np = pred_vertices.detach().cpu().numpy() + gt_vertices_np = gt_vertices.detach().cpu().numpy() + cam_np = pred_cam_t.detach().cpu().numpy() if pred_cam_t is not None else np.zeros((batch_size, 3), dtype=np.float32) + + per_item_losses = [] + for i in range(batch_size): + try: + pred_mask = self.mesh_renderer.render_mask(pred_vertices_np[i], cam_np[i], focal_length) + gt_mask_r = self.mesh_renderer.render_mask(gt_vertices_np[i], cam_np[i], focal_length) + pm = torch.from_numpy(pred_mask).to(device=device, dtype=pred_vertices.dtype) + gm = torch.from_numpy(gt_mask_r).to(device=device, dtype=pred_vertices.dtype) + per_item_losses.append(self.l1(pm, gm)) + except Exception: + # ignore render failure for this sample + continue + + if len(per_item_losses) == 0: + return torch.tensor(0., device=device, dtype=pred_vertices.dtype) + + return torch.stack(per_item_losses).mean() + +class Keypoint2DLoss(nn.Module): + + def __init__(self, loss_type: str = 'l1'): + """ + 2D keypoint loss module. + Args: + loss_type (str): Choose between l1 and l2 losses. + """ + super(Keypoint2DLoss, self).__init__() + if loss_type == 'l1': + self.loss_fn = nn.L1Loss(reduction='none') + elif loss_type == 'l2': + self.loss_fn = nn.MSELoss(reduction='none') + else: + raise NotImplementedError('Unsupported loss function') + + def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor: + """ + Compute 2D reprojection loss on the keypoints. + Args: + pred_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 2] containing projected 2D keypoints (B: batch_size, S: num_samples, N: num_keypoints) + gt_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the ground truth 2D keypoints and confidence. + Returns: + torch.Tensor: 2D keypoint loss. + """ + conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() + batch_size = conf.shape[0] + loss = (conf * self.loss_fn(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).sum(dim=(1, 2)) + return loss.sum() + + +class Keypoint3DLoss(nn.Module): + + def __init__(self, loss_type: str = 'l1'): + """ + 3D keypoint loss module. + Args: + loss_type (str): Choose between l1 and l2 losses. + """ + super(Keypoint3DLoss, self).__init__() + if loss_type == 'l1': + self.loss_fn = nn.L1Loss(reduction='none') + elif loss_type == 'l2': + self.loss_fn = nn.MSELoss(reduction='none') + else: + raise NotImplementedError('Unsupported loss function') + + def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 0): + """ + Compute 3D keypoint loss. + Args: + pred_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the predicted 3D keypoints (B: batch_size, S: num_samples, N: num_keypoints) + gt_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 4] containing the ground truth 3D keypoints and confidence. + Returns: + torch.Tensor: 3D keypoint loss. + """ + batch_size = pred_keypoints_3d.shape[0] + gt_keypoints_3d = gt_keypoints_3d.clone() + pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1) + gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, pelvis_id, :-1].unsqueeze(dim=1) + conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone() + gt_keypoints_3d = gt_keypoints_3d[:, :, :-1] + loss = (conf * self.loss_fn(pred_keypoints_3d, gt_keypoints_3d)).sum(dim=(1, 2)) + return loss.sum() + + +class ParameterLoss(nn.Module): + + def __init__(self): + """ + SMAL parameter loss module. + """ + super(ParameterLoss, self).__init__() + self.loss_fn = nn.MSELoss(reduction='none') + + def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor): + """ + Compute SMAL parameter loss. + Args: + pred_param (torch.Tensor): Tensor of shape [B, S, ...] containing the predicted parameters (body pose / global orientation / betas) + gt_param (torch.Tensor): Tensor of shape [B, S, ...] containing the ground truth MANO parameters. + Returns: + torch.Tensor: L2 parameter loss loss. + """ + mask = torch.ones_like(pred_param, device=pred_param.device, dtype=pred_param.dtype) + batch_size = pred_param.shape[0] + num_dims = len(pred_param.shape) + mask_dimension = [batch_size] + [1] * (num_dims - 1) + has_param = has_param.type(pred_param.type()).view(*mask_dimension) + loss_param = (has_param * self.loss_fn(pred_param*mask, gt_param*mask)) + return loss_param.sum() + + +class PosePriorLoss(nn.Module): + def __init__(self, path_prior): + super(PosePriorLoss, self).__init__() + with open(path_prior, "rb") as f: + data_prior = pickle.load(f, encoding="latin1") + + self.register_buffer("mean_pose", torch.from_numpy(data_prior["mean_pose"]).float()) + self.register_buffer("precs", torch.from_numpy(np.array(data_prior["pic"])).float()) + + use_index = np.ones(105, dtype=bool) + use_index[:3] = False # global rotation set False + self.register_buffer("use_index", torch.from_numpy(use_index).float()) + + def forward(self, x, has_gt): + """ + Args: + x: (batch_size, 35, 3, 3) + has_gt: has pose? + Returns: + pose prior loss + """ + if has_gt.sum() == len(has_gt): + return torch.tensor(0.0, dtype=x.dtype, device=x.device) + has_gt = has_gt.type(torch.bool) + x = x[~has_gt] + x = matrix_to_axis_angle(x.reshape(-1, 3, 3)) + delta = x.reshape(-1, 35*3) - self.mean_pose + loss = torch.tensordot(delta, self.precs, dims=([1], [0])) * self.use_index + return (loss ** 2).mean() + + +class ShapePriorLoss(nn.Module): + def __init__(self, path_prior): + super(ShapePriorLoss, self).__init__() + with open(path_prior, "rb") as f: + data_prior = pickle.load(f, encoding="latin1") + + model_covs = np.array(data_prior["cluster_cov"]) # shape: (5, 41, 41) + inverse_covs = np.stack( + [np.linalg.inv(model_cov + 1e-5 * np.eye(model_cov.shape[0])) for model_cov in model_covs], + axis=0) + prec = np.stack([np.linalg.cholesky(inverse_cov) for inverse_cov in inverse_covs], axis=0) + + self.register_buffer("betas_prec", torch.FloatTensor(prec)) + self.register_buffer("mean_betas", torch.FloatTensor(data_prior["cluster_means"])) + + def forward(self, x, category, has_gt): + """ + Args: + x: predicted betas (batch_size, 41) + category: animal category (batch_size,) + has_gt: has shape? + Returns: + shape prior loss + """ + if has_gt.sum() == len(has_gt): + return torch.tensor(0.0, dtype=x.dtype, device=x.device) + has_gt = has_gt.type(torch.bool) + x, category = x[~has_gt], category[~has_gt] + delta = (x - self.mean_betas[category.long()]) # [batch_size, 41] + loss = [] + for x0, c0 in zip(delta, category): + loss.append(torch.tensordot(x0, self.betas_prec[c0], dims=([0], [0]))) + loss = torch.stack(loss, dim=0) + return (loss ** 2).mean() + + + +class PrototypeSupConLoss(nn.Module): + def __init__(self, prototypes_init, feat_dim=128, temperature=0.1): + """ + prototypes_init: precomputed (5, 512) BioCLIP family prototypes + feat_dim: dimension of the projected shape feature (128) + """ + super().__init__() + self.temperature = temperature + + # The prototypes should live in the projected feature space. + # A practical setup is to pass the BioCLIP centers through the projector + # once at the beginning of training to initialize these prototypes. + self.register_buffer("prototypes", torch.randn(5, feat_dim)) + + def forward(self, features, labels): + """ + features: (B, 128) normalized shared features or shape features + labels: (B,) family indices for the 5-way classification setting + """ + # 1. Ensure features are normalized. + features = F.normalize(features, p=2, dim=1) + # 2. Ensure prototypes are normalized as well. + prototypes = F.normalize(self.prototypes, p=2, dim=1) + + # 3. Compute sample-to-prototype similarities with temperature scaling. + logits = torch.matmul(features, prototypes.T) / self.temperature + + # 4. Cross-entropy pulls samples toward their family prototype and + # pushes them away from the other family prototypes. + loss = F.cross_entropy(logits, labels) + + return loss + + @torch.no_grad() + def update_prototypes(self, features, labels, momentum=0.999): + """ + Optional: update prototypes with momentum during training so they + adapt gradually to the 3D task. + """ + for i in range(5): + mask = (labels == i) + if mask.any(): + new_mean = features[mask].mean(dim=0) + self.prototypes[i] = momentum * self.prototypes[i] + (1 - momentum) * new_mean + + +class SupConLoss(nn.Module): + def __init__(self, temperature=0.1, contrast_mode='all', + base_temperature=0.07): + super(SupConLoss, self).__init__() + self.temperature = temperature + self.contrast_mode = contrast_mode + self.base_temperature = base_temperature + + def forward(self, features, labels=None, mask=None): + """ + Args: + features: hidden vector of shape [bsz, ...]. + labels: ground truth of shape [bsz]. + mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j + has the same class as sample i. Can be asymmetric. + Returns: + A loss scalar. + """ + features = torch.stack((features, features), dim=1) + device = features.device + + if len(features.shape) < 3: + raise ValueError('`features` needs to be [bsz, n_views, ...],' + 'at least 3 dimensions are required') + if len(features.shape) > 3: + features = features.view(features.shape[0], features.shape[1], -1) + + batch_size = features.shape[0] + if labels is not None and mask is not None: + raise ValueError('Cannot define both `labels` and `mask`') + elif labels is None and mask is None: + mask = torch.eye(batch_size, dtype=torch.float32).to(device) + elif labels is not None: + labels = labels.contiguous().view(-1, 1) + if labels.shape[0] != batch_size: + raise ValueError('Num of labels does not match num of features') + mask = torch.eq(labels, labels.T).float().to(device) + else: + mask = mask.float().to(device) + + contrast_count = features.shape[1] + contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) + if self.contrast_mode == 'one': + anchor_feature = features[:, 0] + anchor_count = 1 + elif self.contrast_mode == 'all': + anchor_feature = contrast_feature + anchor_count = contrast_count + else: + raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) + + # compute logits + anchor_dot_contrast = torch.div( + torch.matmul(anchor_feature, contrast_feature.T), + self.temperature) + # for numerical stability + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + + # tile mask + mask = mask.repeat(anchor_count, contrast_count) + # mask-out self-contrast cases + logits_mask = torch.scatter( + torch.ones_like(mask), + 1, + torch.arange(batch_size * anchor_count).view(-1, 1).to(device), + 0 + ) + mask = mask * logits_mask + + # compute log_prob + exp_logits = torch.exp(logits) * logits_mask + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + + # compute mean of log-likelihood over positive + mask_pos_pairs = mask.sum(1) + mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) + mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs + + # loss + loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos + loss = loss.view(anchor_count, batch_size).mean() + + return loss + +# Auxiliary intermediate-supervision loss module. + +class InterLoss(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.use_intermediate_supervision = cfg.LOSS.get('USE_INTERMEDIATE_SUPERVISION', True) + self.intermediate_weight = cfg.LOSS.get('INTERMEDIATE_WEIGHT', 0.5) + + # 2D keypoint loss + self.keypoint_2d_loss = nn.MSELoss(reduction='none') + + # 3D keypoint loss + self.keypoint_3d_loss = nn.MSELoss(reduction='none') + + def forward(self, predictions, gt_data): + """ + Args: + predictions: model outputs (pred_smal_params, pred_cam, pred_smal_params_list) + gt_data: dict containing ground-truth data + - 'keypoints_2d': [B, N, 3] (x, y, visibility) + - 'keypoints_3d': [B, N, 3] (x, y, z) or [B, N, 4] (x, y, z, confidence) + """ + pred_smal_params, pred_cam, pred_smal_params_list = predictions + + losses = {} + total_loss = 0.0 + + # ========== Supervision for final predictions ========== + # Final keypoint supervision can be added here after running the + # predicted parameters through the SMAL model. + + # ========== Supervision for intermediate predictions ========== + if self.use_intermediate_supervision and pred_smal_params_list is not None: + + # 2D keypoint supervision + if 'keypoints_2d' in pred_smal_params_list and pred_smal_params_list['keypoints_2d'] is not None: + pred_kps_2d_all = pred_smal_params_list['keypoints_2d'] + # [B*num_iters, N, 2] + + gt_kps_2d = gt_data['keypoints_2d'][: , :, :2] # [B, N, 2] + gt_vis_2d = gt_data['keypoints_2d'][:, :, 2] # [B, N] + + # Repeat the ground truth for each iteration. + num_iters = pred_kps_2d_all.shape[0] // gt_kps_2d.shape[0] + gt_kps_2d_repeated = gt_kps_2d.repeat(num_iters, 1, 1) # [B*num_iters, N, 2] + gt_vis_2d_repeated = gt_vis_2d.repeat(num_iters, 1) # [B*num_iters, N] + + # Compute the loss only over visible keypoints. + loss_2d = self.keypoint_2d_loss(pred_kps_2d_all, gt_kps_2d_repeated) + loss_2d = loss_2d.mean(dim=-1) # [B*num_iters, N] + loss_2d = (loss_2d * gt_vis_2d_repeated).sum() / (gt_vis_2d_repeated.sum() + 1e-6) + + losses['intermediate_keypoints_2d'] = loss_2d * self.intermediate_weight + total_loss += losses['intermediate_keypoints_2d'] + + # 3D keypoint supervision + if 'keypoints_3d' in pred_smal_params_list and pred_smal_params_list['keypoints_3d'] is not None: + pred_kps_3d_all = pred_smal_params_list['keypoints_3d'] + # [B*num_iters, N, 3] + + gt_kps_3d = gt_data['keypoints_3d'][: , :, :3] # [B, N, 3] + if gt_data['keypoints_3d'].shape[-1] == 4: + gt_conf_3d = gt_data['keypoints_3d'][:, :, 3] # [B, N] + else: + gt_conf_3d = torch.ones_like(gt_kps_3d[:, :, 0]) # All keypoints are valid. + + # Repeat the ground truth for each iteration. + num_iters = pred_kps_3d_all.shape[0] // gt_kps_3d.shape[0] + gt_kps_3d_repeated = gt_kps_3d.repeat(num_iters, 1, 1) + gt_conf_3d_repeated = gt_conf_3d.repeat(num_iters, 1) + + # Compute the loss. + loss_3d = self.keypoint_3d_loss(pred_kps_3d_all, gt_kps_3d_repeated) + loss_3d = loss_3d.mean(dim=-1) # [B*num_iters, N] + loss_3d = (loss_3d * gt_conf_3d_repeated).sum() / (gt_conf_3d_repeated.sum() + 1e-6) + + losses['intermediate_keypoints_3d'] = loss_3d * self.intermediate_weight + total_loss += losses['intermediate_keypoints_3d'] + + # ... other losses (pose, shape, etc.) ... + + losses['total'] = total_loss + return losses diff --git a/prima/models/prima.py b/prima/models/prima.py new file mode 100755 index 0000000000000000000000000000000000000000..0cb861dd5b7eb2693c03afe97d3fcd53647f7b1a --- /dev/null +++ b/prima/models/prima.py @@ -0,0 +1,615 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import torch +import pickle +import pytorch_lightning as pl +from typing import Any, Dict +from yacs.config import CfgNode + +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from torchvision.utils import make_grid +from ..utils.geometry import perspective_projection, aa_to_rotmat +from ..utils.pylogger import get_pylogger +from .backbones import create_backbone +from .heads import build_smal_head +from prima.models.smal_wrapper import SMAL +from .discriminator import Discriminator + +from .bioclip_embedding import BioClipEmbedding +import sys +from transformers import AutoModel, AutoFeatureExtractor +import einops + +import open_clip + + +from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss, ShapePriorLoss, PosePriorLoss, SupConLoss +log = get_pylogger(__name__) + + +class PRIMA(pl.LightningModule): + + def __init__(self, cfg: CfgNode, init_renderer: bool = True): + """ + Setup PRIMA model + Args: + cfg (CfgNode): Config file as a yacs CfgNode + """ + super().__init__() + + # Save hyperparameters + self.save_hyperparameters(logger=False, ignore=['init_renderer']) + + self.cfg = cfg + # Create backbone feature extractor + + if cfg.MODEL.BACKBONE.TYPE =='vith': + self.backbone = create_backbone(cfg) # create vit backbone anyway, for inference, no config loading, just load ckpt weights + + if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None): # pretrained exists and not none, then true + + log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}') + state_dict = torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu', weights_only=True)['state_dict'] + state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()} + + missing_keys, unexpected_keys = self.backbone.load_state_dict(state_dict, strict=False) + + + # freeze backbones + if cfg.MODEL.BACKBONE.get('FREEZE', False) and cfg.MODEL.BACKBONE.TYPE == 'vith': + log.info(f'Freezing first 2/3 blocks of vit backbone') + # Freeze patch embedding + if hasattr(self.backbone, 'patch_embed'): + for p in self.backbone.patch_embed.parameters(): + p.requires_grad = False + + # Freeze first 2/3 of transformer blocks + if hasattr(self.backbone, 'blocks'): + total_blocks = len(self.backbone.blocks) + freeze_blocks = int(total_blocks * 2 / 3) + log.info(f'Freezing {freeze_blocks} out of {total_blocks} blocks') + for i in range(freeze_blocks): + for p in self.backbone.blocks[i].parameters(): + p.requires_grad = False + + # Create SMAL head (predicts SMAL params + perspective camera) + self.smal_head = build_smal_head(cfg) + + # Instantiate SMAL model + smal_model_path = cfg.SMAL.MODEL_PATH + with open(smal_model_path, 'rb') as f: + smal_cfg = pickle.load(f, encoding="latin1") + self.smal = SMAL(**smal_cfg) + + # create bioclip model for species classification token extraction + use_bioclip_embedding = cfg.MODEL.get('USE_BIOCLIP_EMBEDDING', False) + if use_bioclip_embedding: + bioclip_config = cfg.MODEL.get('BIOCLIP_EMBEDDING', {}) + embed_dim = bioclip_config.get('EMBED_DIM', 1280) + self.bioclip_embedding = BioClipEmbedding(cfg, embed_dim=embed_dim) + # Freeze BioClip model by default + for param in self.bioclip_embedding.species_model.parameters(): + param.requires_grad = False + else: + self.bioclip_embedding = None + + # Create discriminator + self.discriminator = Discriminator() + + + + + # Define loss functions + self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1') + self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1') + + if self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP2D', 0) > 0: + self.intermediate_kp2d_loss = Keypoint2DLoss(loss_type='l1') + if self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP3D', 0) > 0: + self.intermediate_kp3d_loss = Keypoint3DLoss(loss_type='l1') + self.smal_parameter_loss = ParameterLoss() + self.shape_prior_loss = ShapePriorLoss(path_prior=cfg.SMAL.SHAPE_PRIOR_PATH) + self.pose_prior_loss = PosePriorLoss(path_prior=cfg.SMAL.POSE_PRIOR_PATH) + self.supcon_loss = SupConLoss() + + + self.register_buffer('initialized', torch.tensor(False)) + + # init depth renderer for supervised training + # Setup renderer for visualization + if init_renderer: + from ..utils import MeshRenderer + + self.mesh_renderer = MeshRenderer(self.cfg, faces=self.smal.faces.numpy()) + else: + self.mesh_renderer = None + + # Disable automatic optimization since we use adversarial training + self.automatic_optimization = False + + def get_parameters(self): + all_params = list(self.smal_head.parameters()) + if self.cfg.MODEL.BACKBONE.TYPE in ['vith', 'dinov2', 'dinov3']: + all_params += list(self.backbone.parameters()) + + + if hasattr(self, 'keypoint_projection') and self.keypoint_projection is not None: + all_params += list(self.keypoint_projection.parameters()) + if hasattr(self, 'bioclip_embedding') and self.bioclip_embedding is not None: + # Only add projection parameters as the model itself is frozen + all_params += list(self.bioclip_embedding.projection.parameters()) + return all_params + + def configure_optimizers(self): + """ + Setup model and discriminator Optimizers + Returns: + Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers + """ + # Use separate learning rates only for vith backbone + if self.cfg.MODEL.BACKBONE.TYPE == 'vith': + # Separate backbone parameters and other parameters + backbone_params = [] + other_params = [] + + # Collect backbone parameters + if hasattr(self, 'backbone'): + backbone_params = list(filter(lambda p: p.requires_grad, self.backbone.parameters())) + + # Collect other parameters + other_params += list(self.smal_head.parameters()) + + + if hasattr(self, 'keypoint_projection') and self.keypoint_projection is not None: + other_params += list(self.keypoint_projection.parameters()) + if hasattr(self, 'bioclip_embedding') and self.bioclip_embedding is not None: + other_params += list(self.bioclip_embedding.projection.parameters()) + + + # Filter only trainable parameters + other_params = list(filter(lambda p: p.requires_grad, other_params)) + + # Create parameter groups with different learning rates + param_groups = [ + {'params': backbone_params, 'lr': self.cfg.TRAIN.LR / 10.0}, # Backbone: 1/10 lr + {'params': other_params, 'lr': self.cfg.TRAIN.LR} # Other modules: normal lr + ] + + log.info(f'Using separate LR for vith backbone') + log.info(f'Backbone parameters: {len(backbone_params)}, lr={self.cfg.TRAIN.LR / 10.0}') + log.info(f'Other parameters: {len(other_params)}, lr={self.cfg.TRAIN.LR}') + else: + # Use same learning rate for all parameters + all_params = list(filter(lambda p: p.requires_grad, self.get_parameters())) + param_groups = [{'params': all_params, 'lr': self.cfg.TRAIN.LR}] + log.info(f'Using same LR for all parameters: {len(all_params)}, lr={self.cfg.TRAIN.LR}') + + optimizer = torch.optim.AdamW(params=param_groups, + weight_decay=self.cfg.TRAIN.WEIGHT_DECAY) + if self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0: + optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(), + lr=self.cfg.TRAIN.LR, + weight_decay=self.cfg.TRAIN.WEIGHT_DECAY) + else: + return optimizer, + + return optimizer, optimizer_disc + + def forward_step(self, batch: Dict, train: bool = False) -> Dict: + """ + Run a forward step of the network + Args: + batch (Dict): Dictionary containing batch data + train (bool): Flag indicating whether it is training or validation mode + Returns: + Dict: Dictionary containing the regression output + """ + + # Use RGB image as input + x = batch['img'] # [B, 3, H, W] + batch_size = x.shape[0] + + # Compute conditioning features using the backbone + if self.cfg.MODEL.BACKBONE.TYPE =='vith': # vit backbone return [1, 1280, 12, 16] + conditioning_feats, cls = self.backbone(x[:, :, :, 32:-32]) # reshape the input into [256, 192] + # return shape shape [B, D, Hp, Wp], [B, D] + if conditioning_feats.ndim == 4: + # Flatten spatial dimensions into sequence dimension: [B, D, Hp, Wp] -> [B, Hp*Wp, D] + B, D, Hp, Wp = conditioning_feats.shape + conditioning_feats = conditioning_feats.permute(0, 2, 3, 1).reshape(B, Hp * Wp, D) # [B, Hp*Wp, D] + + + # add bioclip embedding if enabled + if self.bioclip_embedding is not None: + species_feature = self.bioclip_embedding(batch['img']) # [B, embed_dim] + + # concatenate species feature to conditioning_feats along token dimension + if len(conditioning_feats.shape) == 3: + # Token-wise concatenation: add species_feature as a single token + # (B, embed_dim) -> (B, 1, embed_dim) + species_token = species_feature.unsqueeze(1) # (B, 1, embed_dim) + # Concatenate along token dimension: (B, num_tokens, C) + (B, 1, embed_dim) -> (B, num_tokens + 1, C or embed_dim) + # Note: This requires C == embed_dim for consistent feature dimensions + conditioning_feats = torch.cat([conditioning_feats, species_token], dim=1) # (B, num_tokens + 1, C) + else: + # If conditioning_feats is 2D (B, C), concat directly along feature dimension + conditioning_feats = torch.cat([conditioning_feats, species_feature], dim=-1) + + # Predict SMAL parameters and camera + pred_smal_params, pred_cam, extra_outputs = self.smal_head(conditioning_feats) + + + # Store useful regression outputs to the output dict + output = {} + + if 'shape_feat' in extra_outputs: + output['shape_feat'] = extra_outputs['shape_feat'] + + if 'init_betas' in extra_outputs: + output['init_betas'] = extra_outputs['init_betas'].reshape(batch_size, -1) + + + output['pred_cam'] = pred_cam # [B, 3] + output['pred_smal_params'] = {k: v.clone() for k, v in pred_smal_params.items()} + + + + # Compute camera translation + focal_length = batch['focal_length'] + + pred_cam_t = torch.stack([ + pred_cam[:, 1], + pred_cam[:, 2], + 2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9) + ], dim=-1) # [B, 3] + + output['pred_cam_t'] = pred_cam_t # [B, 3] + output['focal_length'] = focal_length # [B, 2] + + # Compute model vertices, joints and the projected joints + pred_smal_params['global_orient'] = pred_smal_params['global_orient'].reshape(batch_size, -1, 3, 3) + pred_smal_params['pose'] = pred_smal_params['pose'].reshape(batch_size, -1, 3, 3) + pred_smal_params['betas'] = pred_smal_params['betas'].reshape(batch_size, -1) + smal_output = self.smal(**pred_smal_params, pose2rot=False) + + pred_keypoints_3d = smal_output.joints + pred_vertices = smal_output.vertices + output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3) + output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3) + + # project 3D keypoints to 2D + pred_keypoints_2d = perspective_projection( + pred_keypoints_3d, + translation=pred_cam_t, + focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE + ) # [B, num_joints, 2] + output['pred_keypoints_2d'] = pred_keypoints_2d + + # get intermediate keypoint predictions if available + + if 'keypoints_3d' in pred_smal_params and pred_smal_params['keypoints_3d'] is not None: + inter_keypoints_3d = pred_smal_params['keypoints_3d'] + output['inter_keypoints_3d'] = inter_keypoints_3d.reshape(batch_size, -1, 3) + # output['use_intermediate_kp3d_loss'] = True + + if 'keypoints_2d' in pred_smal_params and pred_smal_params['keypoints_2d'] is not None: + inter_keypoints_2d = pred_smal_params['keypoints_2d'] + output['inter_keypoints_2d'] = inter_keypoints_2d.reshape(batch_size, -1, 2) + # output['use_intermediate_kp2d_loss'] = True + + return output + + def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor: + """ + Compute losses given the input batch and the regression output + Args: + batch (Dict): Dictionary containing batch data + output (Dict): Dictionary containing the regression output + train (bool): Flag indicating whether it is training or validation mode + Returns: + torch.Tensor : Total loss for current batch + """ + + pred_smal_params = output['pred_smal_params'] + pred_keypoints_2d = output['pred_keypoints_2d'] + pred_keypoints_3d = output['pred_keypoints_3d'] + + if 'inter_keypoints_2d' in output: + inter_keypoints_2d = output['inter_keypoints_2d'] + if 'inter_keypoints_3d' in output: + inter_keypoints_3d = output['inter_keypoints_3d'] + + batch_size = pred_smal_params['pose'].shape[0] + device = pred_smal_params['pose'].device + dtype = pred_smal_params['pose'].dtype + + # Get annotations + gt_keypoints_2d = batch['keypoints_2d'] + gt_keypoints_3d = batch['keypoints_3d'] + gt_smal_params = batch['smal_params'] + gt_mask = batch['mask'] + has_smal_params = batch['has_smal_params'] + is_axis_angle = batch['smal_params_is_axis_angle'] + has_mask = batch['has_mask'] + + # Compute 2D keypoint loss + loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d) + + # Compute 3D keypoint loss + loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0) + + # Compute intermediate 2D keypoint loss if available + loss_intermediate_kp2d = torch.tensor(0., device=device, dtype=dtype) + if 'inter_keypoints_2d' in output: + loss_intermediate_kp2d = self.intermediate_kp2d_loss(inter_keypoints_2d, gt_keypoints_2d) + # loss_keypoints_2d = loss_keypoints_2d + loss_intermediate_kp2d + + # Compute intermediate 3D keypoint loss if available + loss_intermediate_kp3d = torch.tensor(0., device=device, dtype=dtype) + if 'inter_keypoints_3d' in output: + loss_intermediate_kp3d = self.intermediate_kp3d_loss(inter_keypoints_3d, gt_keypoints_3d, pelvis_id=0) + # loss_keypoints_3d = loss_keypoints_3d + loss_intermediate_kp3d + + # add intermediate keypoint losses if available + + # Compute loss on SMAL parameters + loss_smal_params = {} + for k, pred in pred_smal_params.items(): + # Skip keypoint predictions - they're handled separately + if k in ['keypoints_2d', 'keypoints_3d']: + continue + + gt = gt_smal_params[k].view(batch_size, -1) + if is_axis_angle[k].all(): + gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3) + has_gt = has_smal_params[k] + + # Only compute parameter loss if ANY sample has GT + param_loss = self.smal_parameter_loss(pred.reshape(batch_size, -1), + gt.reshape(batch_size, -1), + has_gt) + + if k == "betas": + # Only add shape prior loss if NOT all samples have GT (prior is regularization for samples without GT) + # But the shape_prior_loss already handles this check internally + loss_smal_params[k] = param_loss + self.shape_prior_loss(pred, batch["category"], has_gt) + if 'init_betas' in output: + init_betas = output['init_betas'] + loss_smal_params[k] = loss_smal_params[k] + self.shape_prior_loss(init_betas, batch["category"], has_gt) / 2. + + else: + # Only add pose prior loss if NOT all samples have GT + # The pose_prior_loss already handles this check internally + loss_smal_params[k] = param_loss + \ + self.pose_prior_loss(torch.cat((pred_smal_params["global_orient"], + pred_smal_params["pose"]), + dim=1), has_gt) / 2. + if 'shape_feat' in output: + loss_supcon = self.supcon_loss(output['shape_feat'], labels=batch['category']) + else: + loss_supcon = torch.tensor(0., device=device, dtype=dtype) + loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d + \ + self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d + \ + sum([loss_smal_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_smal_params]) + \ + self.cfg.LOSS_WEIGHTS['SUPCON'] * loss_supcon + + if 'inter_keypoints_2d' in output: + loss = loss + self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP2D', 0) * loss_intermediate_kp2d + if 'inter_keypoints_3d' in output: + loss = loss + self.cfg.LOSS_WEIGHTS.get('INTERMEDIATE_KP3D', 0) * loss_intermediate_kp3d + + + losses = dict(loss=loss.detach(), + loss_keypoints_2d=loss_keypoints_2d.detach(), + loss_keypoints_3d=loss_keypoints_3d.detach(), + loss_supcon=loss_supcon.detach(), + ) + + for k, v in loss_smal_params.items(): + losses['loss_' + k] = v.detach() + + # attach intermediate keypoint losses if computed + if 'inter_keypoints_2d' in output: + losses['loss_inter_keypoints_2d'] = loss_intermediate_kp2d.detach() + if 'inter_keypoints_3d' in output: + losses['loss_inter_keypoints_3d'] = loss_intermediate_kp3d.detach() + + + + output['losses'] = losses + + return loss + + def forward(self, batch: Dict) -> Dict: + """ + Run a forward step of the network in val mode + Args: + batch (Dict): Dictionary containing batch data + Returns: + Dict: Dictionary containing the regression output + """ + return self.forward_step(batch, train=False) + + def training_step_discriminator(self, batch: Dict, + pose: torch.Tensor, + betas: torch.Tensor, + optimizer: torch.optim.Optimizer) -> torch.Tensor: + """ + Run a discriminator training step + Args: + batch (Dict): Dictionary containing mocap batch data + pose (torch.Tensor): Regressed pose from current step + betas (torch.Tensor): Regressed betas from current step + optimizer (torch.optim.Optimizer): Discriminator optimizer + Returns: + torch.Tensor: Discriminator loss + """ + batch_size = pose.shape[0] + gt_pose = batch['pose'] + gt_betas = batch['betas'] + gt_rotmat = aa_to_rotmat(gt_pose.view(-1, 3)).view(batch_size, -1, 3, 3) + disc_fake_out = self.discriminator(pose.detach(), betas.detach()) + loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size + disc_real_out = self.discriminator(gt_rotmat.detach(), gt_betas.detach()) + loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size + loss_disc = loss_fake + loss_real + loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc + optimizer.zero_grad() + self.manual_backward(loss) + optimizer.step() + return loss_disc.detach() + + # Tensoroboard logging should run from first rank only + @pl.utilities.rank_zero.rank_zero_only + def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, + write_to_summary_writer: bool = True) -> None: + """ + Log results to Tensorboard + Args: + batch (Dict): Dictionary containing batch data + output (Dict): Dictionary containing the regression output + step_count (int): Global training step count + train (bool): Flag indicating whether it is training or validation mode + """ + + mode = 'train' if train else 'val' + + images = batch['img'] + gt_keypoints_2d = batch['keypoints_2d'] + batch_size = images.shape[0] + + # mul std then add mean + images = (images) * (torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)) + images = (images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)) + + pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3) + losses = output['losses'] + pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3) + pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2) + + if write_to_summary_writer: + summary_writer = self.logger.experiment + for loss_name, val in losses.items(): + summary_writer.add_scalar(mode + '/' + loss_name, val.detach().item(), step_count) + # if train is False: + # for metric_name, val in output['metric'].items(): + # summary_writer.add_scalar(mode + '/' + metric_name, val, step_count) + num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES) + + predictions = self.mesh_renderer.visualize_tensorboard(pred_vertices[:num_images].cpu().numpy(), + pred_cam_t[:num_images].cpu().numpy(), + images[:num_images].cpu().numpy(), + self.cfg.SMAL.get("FOCAL_LENGTH", 1000), + pred_keypoints_2d[:num_images].cpu().numpy(), + gt_keypoints_2d[:num_images].cpu().numpy(), + pred_masks=output.get('pred_masks', None)[:num_images] if output.get('pred_masks', None) is not None else None, + gt_masks=output.get('gt_masks', None)[:num_images] if output.get('gt_masks', None) is not None else None, + ) + predictions = make_grid(predictions, nrow=5, padding=2) + if write_to_summary_writer: + summary_writer.add_image('%s/predictions' % mode, predictions, step_count) + + return predictions + + def training_step(self, batch: Dict) -> Dict: + """ + Run a full training step + Args: + batch (Dict): Dictionary containing {'img', 'mask', 'keypoints_2d', 'keypoints_3d', 'orig_keypoints_2d', + 'box_center', 'box_size', 'img_size', 'smal_params', + 'smal_params_is_axis_angle', '_trans', 'imgname', 'focal_length'} + Returns: + Dict: Dictionary containing regression output. + """ + batch = batch['img'] + optimizer = self.optimizers(use_pl_optimizer=True) + if self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0: + optimizer, optimizer_disc = optimizer + + batch_size = batch['img'].shape[0] + output = self.forward_step(batch, train=True) + pred_smal_params = output['pred_smal_params'] + loss = self.compute_loss(batch, output, train=True) + if self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0: + disc_out = self.discriminator(pred_smal_params['pose'].reshape(batch_size, -1), + pred_smal_params['betas'].reshape(batch_size, -1)) + loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size + loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv + + # Error if Nan + if torch.isnan(loss): + raise ValueError('Loss is NaN') + + optimizer.zero_grad() + self.manual_backward(loss) + # Clip gradient + if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0: + gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, + error_if_nonfinite=True) + self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + # For compatibility + # if self.cfg.LOSS_WEIGHTS.ADVERSARIAL == 0: + # optimizer.param_groups[0]['capturable'] = True + + optimizer.step() + if self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0: + loss_disc = self.training_step_discriminator(batch['smal_params'], + pred_smal_params['pose'].reshape(batch_size, -1), + pred_smal_params['betas'].reshape(batch_size, -1), + optimizer_disc) + output['losses']['loss_gen'] = loss_adv + output['losses']['loss_disc'] = loss_disc + + if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0: + self.tensorboard_logging(batch, output, self.global_step, train=True) + + # Log training loss to the logger so checkpoint callback can monitor it. + self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, + logger=True, batch_size=batch_size, sync_dist=True) + + return output + + def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict: + """ + Run a validation step and log to Tensorboard + Args: + batch (Dict): Dictionary containing batch data + batch_idx (int): Unused. + Returns: + Dict: Dictionary containing regression output. + """ + # The validation dataloader yields the inner batch dict directly (not wrapped as {'img': loader}). + # Run forward, compute loss and log aggregated validation metrics so ModelCheckpoint can monitor them. + output = self.forward_step(batch, train=False) + # compute_loss will populate output['losses'] and return the scalar loss + loss = self.compute_loss(batch, output, train=False) + + # Ensure losses dict is available + losses = output.get('losses', {}) + + # Log all validation losses to logger with on_epoch=True so checkpoint monitors epoch-level metric + for loss_name, val in losses.items(): + # use prog_bar only for the main loss + prog = True if loss_name == 'loss' else False + # Log as 'val/' e.g. 'val/loss' + self.log(f'val/{loss_name}', val, on_step=False, on_epoch=True, prog_bar=prog, logger=True, + sync_dist=True) + + # Periodically write images/other visuals to tensorboard + # Log visualizations on the first batch of each validation epoch + if batch_idx == 0: + # Use global_step for step count when logging validation visuals + self.tensorboard_logging(batch, output, self.global_step, train=False) + + return output diff --git a/prima/models/smal_wrapper.py b/prima/models/smal_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d0873cf3adcd5758e1c135c8a4115bebd125c --- /dev/null +++ b/prima/models/smal_wrapper.py @@ -0,0 +1,148 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import json +from torch import nn +import torch +import numpy as np +import pickle +import cv2 +from typing import Optional, Tuple, NewType +from dataclasses import dataclass +import smplx +from smplx.lbs import vertices2joints, lbs +from smplx.utils import MANOOutput, to_tensor, ModelOutput +from smplx.vertex_ids import vertex_ids + +Tensor = NewType('Tensor', torch.Tensor) +keypoint_vertices_idx = [[1068, 1080, 1029, 1226], [2660, 3030, 2675, 3038], [910], [360, 1203, 1235, 1230], + [3188, 3156, 2327, 3183], [1976, 1974, 1980, 856], [3854, 2820, 3852, 3858], [452, 1811], + [416, 235, 182], [2156, 2382, 2203], [829], [2793], [60, 114, 186, 59], + [2091, 2037, 2036, 2160], [384, 799, 1169, 431], [2351, 2763, 2397, 3127], + [221, 104], [2754, 2192], [191, 1158, 3116, 2165], + [28, 1109, 1110, 1111, 1835, 1836, 3067, 3068, 3069], + [498, 499, 500, 501, 502, 503], [2463, 2464, 2465, 2466, 2467, 2468], + [764, 915, 916, 917, 934, 935, 956], [2878, 2879, 2880, 2897, 2898, 2919, 3751], + [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762], + [0, 464, 465, 726, 1824, 2429, 2430, 2690]] + +name2id35 = {'RFoot': 14, 'RFootBack': 24, 'spine1': 4, 'Head': 16, 'LLegBack3': 19, 'RLegBack1': 21, 'pelvis0': 1, + 'RLegBack3': 23, 'LLegBack2': 18, 'spine0': 3, 'spine3': 6, 'spine2': 5, 'Mouth': 32, 'Neck': 15, + 'LFootBack': 20, 'LLegBack1': 17, 'RLeg3': 13, 'RLeg2': 12, 'LLeg1': 7, 'LLeg3': 9, 'RLeg1': 11, + 'LLeg2': 8, 'spine': 2, 'LFoot': 10, 'Tail7': 31, 'Tail6': 30, 'Tail5': 29, 'Tail4': 28, 'Tail3': 27, + 'Tail2': 26, 'Tail1': 25, 'RLegBack2': 22, 'root': 0, 'LEar': 33, 'REar': 34, 'EndNose': 35, 'Chin': 36, + 'RightEarTip': 37, 'LeftEarTip': 38, 'LeftEye': 39, 'RightEye': 40} + +@dataclass +class SMALOutput(ModelOutput): + betas: Optional[Tensor] = None + pose: Optional[Tensor] = None + + +class SMALLayer(nn.Module): + def __init__(self, num_betas=41, **kwargs): + super().__init__() + self.num_betas = num_betas + from chumpy.ch import materialize + + self.register_buffer( + "shapedirs", + torch.from_numpy(materialize(kwargs["shapedirs"]))[:, :, :num_betas], + ) # [3889, 3, 41] + self.register_buffer( + "v_template", torch.from_numpy(materialize(kwargs["v_template"])) + ) # [3889, 3] + self.register_buffer( + "posedirs", + torch.from_numpy(materialize(kwargs["posedirs"])).reshape(-1, 34 * 9).T, + ) # [34*9, 11667] + self.register_buffer( + "J_regressor", + torch.from_numpy(kwargs["J_regressor"].toarray().astype(np.float32)), + ) # [33, 3389] + self.register_buffer( + "lbs_weights", torch.from_numpy(materialize(kwargs["weights"])) + ) # [3889, 33] + self.register_buffer("faces", torch.from_numpy(materialize(kwargs["f"], dtype=np.int32))) # [7774, 3] + + kintree_table = kwargs['kintree_table'] + self.register_buffer("parents", torch.from_numpy(kintree_table[0].astype(np.int32))) + + def forward( + self, + betas: Optional[Tensor] = None, + global_orient: Optional[Tensor] = None, + pose: Optional[Tensor] = None, + transl: Optional[Tensor] = None, + return_verts: bool = True, + return_full_pose: bool = False, + **kwargs): + """ + Args: + betas: [batch_size, 10] + global_orient: [batch_size, 1, 3, 3] + pose: [batch_size, num_joints, 3, 3] + transl: [batch_size, num_joints, 3] + return_verts: + return_full_pose: + **kwargs: + Returns: + """ + device, dtype = betas.device, betas.dtype + if global_orient is None: + batch_size = 1 + global_orient = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous() + else: + batch_size = global_orient.shape[0] + if pose is None: + pose = torch.eye(3, device=device, dtype=dtype).view( + 1, 1, 3, 3).expand(batch_size, 34, -1, -1).contiguous() + if betas is None: + betas = torch.zeros( + [batch_size, self.num_betas], dtype=dtype, device=device) + if transl is None: + transl = torch.zeros([batch_size, 3], dtype=dtype, device=device) + + full_pose = torch.cat([global_orient, pose], dim=1) + vertices, joints = lbs(betas, full_pose, self.v_template, + self.shapedirs, self.posedirs, + self.J_regressor, self.parents, + self.lbs_weights, pose2rot=False) + + if transl is not None: + joints = joints + transl.unsqueeze(dim=1) + vertices = vertices + transl.unsqueeze(dim=1) + + output = SMALOutput( + vertices=vertices if return_verts else None, + joints=joints if return_verts else None, + betas=betas, + global_orient=global_orient, + pose=pose, + transl=transl, + full_pose=full_pose if return_full_pose else None, + ) + return output + + +class SMAL(SMALLayer): + def __init__(self, **kwargs): + super(SMAL, self).__init__(**kwargs) + + def forward(self, *args, **kwargs): + smal_output = super(SMAL, self).forward(**kwargs) + + keypoint = [] + for kp_v in keypoint_vertices_idx: + keypoint.append(smal_output.vertices[:, kp_v, :].mean(dim=1)) + smal_output.joints = torch.stack(keypoint, dim=1) + return smal_output + + diff --git a/prima/utils/__init__.py b/prima/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..9b323a57a00b59045862cfeded8f431418235b24 --- /dev/null +++ b/prima/utils/__init__.py @@ -0,0 +1,45 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from typing import Any + + +def recursive_to(x: Any, target: Any): + """ + Recursively transfer a batch of data to the target device + Args: + x (Any): Batch of data. + target (torch.device): Target device. + Returns: + Batch of data where all tensors are transferred to the target device. + """ + import torch + + def move(value: Any): + if isinstance(value, dict): + return {k: move(v) for k, v in value.items()} + if isinstance(value, torch.Tensor): + return value.to(target) + if isinstance(value, list): + return [move(i) for i in value] + return value + + return move(x) + + +def __getattr__(name: str): + if name == "MeshRenderer": + from .mesh_renderer import MeshRenderer + + globals()[name] = MeshRenderer + return MeshRenderer + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = ["MeshRenderer", "recursive_to"] diff --git a/prima/utils/detection.py b/prima/utils/detection.py new file mode 100644 index 0000000000000000000000000000000000000000..72dcae2a21ac74e3fd3592a1c0c0bda324376a5f --- /dev/null +++ b/prima/utils/detection.py @@ -0,0 +1,118 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from __future__ import annotations + +# Utilities for filtering animal detections before PRIMA demo inference. +# +# Detectron2 may return both a full-animal box and a local/partial box for the +# same animal. These helpers keep the demo pipeline from rendering the same +# animal multiple times. + +from typing import Iterable + +import numpy as np + +ANIMAL_COCO_IDS = (15, 16, 17, 18, 19, 21, 22) + + +def _box_areas(boxes: np.ndarray) -> np.ndarray: + widths = np.maximum(0.0, boxes[:, 2] - boxes[:, 0]) + heights = np.maximum(0.0, boxes[:, 3] - boxes[:, 1]) + return widths * heights + + +def _intersection_areas(box: np.ndarray, boxes: np.ndarray) -> np.ndarray: + x1 = np.maximum(box[0], boxes[:, 0]) + y1 = np.maximum(box[1], boxes[:, 1]) + x2 = np.minimum(box[2], boxes[:, 2]) + y2 = np.minimum(box[3], boxes[:, 3]) + return np.maximum(0.0, x2 - x1) * np.maximum(0.0, y2 - y1) + + +def _suppress_duplicate_boxes( + boxes: np.ndarray, + scores: np.ndarray, + *, + iou_threshold: float, + containment_threshold: float, +) -> np.ndarray: + if len(boxes) <= 1: + return np.arange(len(boxes), dtype=np.int64) + + boxes = boxes.astype(np.float32, copy=False) + scores = scores.astype(np.float32, copy=False) + areas = _box_areas(boxes) + + contained = np.zeros(len(boxes), dtype=bool) + for idx, area in enumerate(areas): + if area <= 0: + contained[idx] = True + continue + larger = np.where(areas > area)[0] + if len(larger) == 0: + continue + covered = _intersection_areas(boxes[idx], boxes[larger]) / area + if np.any(covered >= containment_threshold): + contained[idx] = True + + candidates = np.where(~contained)[0] + if len(candidates) <= 1: + return candidates + + order = candidates[np.argsort(scores[candidates])[::-1]] + keep = [] + while len(order) > 0: + current = order[0] + keep.append(current) + rest = order[1:] + if len(rest) == 0: + break + + inter = _intersection_areas(boxes[current], boxes[rest]) + union = areas[current] + areas[rest] - inter + iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0) + order = rest[iou <= iou_threshold] + + return np.array(sorted(keep), dtype=np.int64) + + +def select_animal_boxes( + det_instances, + *, + animal_class_ids: Iterable[int] = ANIMAL_COCO_IDS, + score_threshold: float = 0.7, + iou_threshold: float = 0.5, + containment_threshold: float = 0.9, +) -> tuple[np.ndarray, int]: + """Return filtered animal boxes and the number of duplicate boxes removed.""" + class_ids = set(int(class_id) for class_id in animal_class_ids) + classes = det_instances.pred_classes.detach().cpu().numpy() + scores = det_instances.scores.detach().cpu().numpy() + + valid_idx = np.array( + [ + i + for i, (class_id, score) in enumerate(zip(classes, scores)) + if int(class_id) in class_ids and float(score) > float(score_threshold) + ], + dtype=np.int64, + ) + if len(valid_idx) == 0: + return np.zeros((0, 4), dtype=np.float32), 0 + + boxes = det_instances.pred_boxes.tensor[valid_idx].detach().cpu().numpy() + scores = scores[valid_idx] + keep = _suppress_duplicate_boxes( + boxes, + scores, + iou_threshold=iou_threshold, + containment_threshold=containment_threshold, + ) + return boxes[keep], int(len(boxes) - len(keep)) diff --git a/prima/utils/evaluate_metric.py b/prima/utils/evaluate_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..39720bd53098b21a5cfbe06bf18317c9f217066b --- /dev/null +++ b/prima/utils/evaluate_metric.py @@ -0,0 +1,206 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import torch +import numpy as np +import open3d as o3d +from typing import Dict, List, Union +from pytorch3d.transforms import axis_angle_to_matrix + + +def compute_scale_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor: + """ + Computes a scale transform (s) in a batched way that takes + a set of 3D points S1 (B, N, 3) closest to a set of 3D points S2 (B, N, 3). + Args: + S1 (torch.Tensor): First set of points of shape (B, N, 3). + S2 (torch.Tensor): Second set of points of shape (B, N, 3). + Returns: + (torch.Tensor): The first set of points after applying the scale transformation. + """ + + # 1. Remove mean. + mu1 = S1.mean(dim=1, keepdim=True) + mu2 = S2.mean(dim=1, keepdim=True) + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = (X1 ** 2).sum(dim=(1, 2), keepdim=True) + + # 3. Compute scale. + scale = (X2 * X1).sum(dim=(1, 2), keepdim=True) / var1 + + # 4. Apply scale transform. + S1_hat = scale * X1 + mu2 + + return S1_hat + + +def compute_similarity_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor: + """ + Computes a similarity transform (sR, t) in a batched way that takes + a set of 3D points S1 (B, N, 3) closest to a set of 3D points S2 (B, N, 3), + where R is a 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + Args: + S1 (torch.Tensor): First set of points of shape (B, N, 3). + S2 (torch.Tensor): Second set of points of shape (B, N, 3). + Returns: + (torch.Tensor): The first set of points after applying the similarity transformation. + """ + + batch_size = S1.shape[0] + S1 = S1.permute(0, 2, 1) + S2 = S2.permute(0, 2, 1) + # 1. Remove mean. + mu1 = S1.mean(dim=2, keepdim=True) + mu2 = S2.mean(dim=2, keepdim=True) + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = (X1 ** 2).sum(dim=(1, 2)) + + # 3. The outer product of X1 and X2. + K = torch.matmul(X1.float(), X2.permute(0, 2, 1)) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K. + U, s, V = torch.svd(K.float()) + Vh = V.permute(0, 2, 1) + + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = torch.eye(U.shape[1], device=U.device).unsqueeze(0).repeat(batch_size, 1, 1).float() + Z[:, -1, -1] *= torch.sign(torch.linalg.det(torch.matmul(U.float(), Vh.float()).float())) + + # Construct R. + R = torch.matmul(torch.matmul(V, Z), U.permute(0, 2, 1)) + + # 5. Recover scale. + trace = torch.matmul(R, K).diagonal(offset=0, dim1=-1, dim2=-2).sum(dim=-1) + scale = (trace / var1).unsqueeze(dim=-1).unsqueeze(dim=-1) + + # 6. Recover translation. + t = mu2 - scale * torch.matmul(R.float(), mu1.float()) + + # 7. Error: + S1_hat = scale * torch.matmul(R.float(), S1.float()).float() + t + + return S1_hat.permute(0, 2, 1) + + +def pointcloud(points: np.ndarray): + pcd = o3d.geometry.PointCloud() + points = o3d.utility.Vector3dVector(points) + pcd.points = points + return pcd + + +class Evaluator: + def __init__(self, smal_model, image_size: int=256, pelvis_ind: int = 7): + self.pelvis_ind = pelvis_ind + self.smal_model = smal_model + self.image_size = image_size + + def compute_pck(self, output: Dict, batch: Dict, pck_threshold: Union[List, None]): + if pck_threshold is None or len(pck_threshold) == 0: + return torch.tensor([], dtype=torch.float32) + + pred_keypoints_2d = output['pred_keypoints_2d'].detach().cpu() + gt_keypoints_2d = batch['keypoints_2d'].detach().cpu() + + pred_keypoints_2d = (pred_keypoints_2d + 0.5) * self.image_size + conf = gt_keypoints_2d[:, :, -1] + gt_keypoints_2d = (gt_keypoints_2d[:, :, :-1] + 0.5) * self.image_size + + if 'mask' in batch and batch['mask'] is not None: + seg_area = torch.sum(batch['mask'].detach().cpu().reshape(batch['mask'].shape[0], -1), dim=-1).unsqueeze(-1) + else: + seg_area = torch.tensor([self.image_size * self.image_size] * len(pred_keypoints_2d), dtype=torch.float32).unsqueeze(-1) + + total_visible = torch.sum(conf, dim=-1).clamp_min(1e-6) # (B,) + dist = torch.norm(pred_keypoints_2d - gt_keypoints_2d, dim=-1) # (B, K) + norm_dist = dist / torch.sqrt(seg_area) # (B, K) + + thresholds = torch.tensor(pck_threshold, dtype=torch.float32).view(-1, 1, 1) # (T, 1, 1) + hits = (norm_dist.unsqueeze(0) < thresholds).float() # (T, B, K) + pcks = (hits * conf.unsqueeze(0)).sum(dim=-1) / total_visible.unsqueeze(0) # (T, B) + return pcks.mean(dim=1) # (T,) + + def compute_pa_mpjpe(self, pred_joints, gt_joints): + S1_hat = compute_similarity_transform(pred_joints, gt_joints) + pa_mpjpe = torch.sqrt(((S1_hat - gt_joints) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * 1000 + return pa_mpjpe.mean() + + def compute_pa_mpvpe(self, gt_vertices: torch.Tensor, pred_vertices: torch.Tensor): + batch_size = pred_vertices.shape[0] + S1_hat = compute_similarity_transform(pred_vertices, gt_vertices) + pa_mpvpe = torch.sqrt(((S1_hat - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * 1000 + return pa_mpvpe.mean() + + def eval_3d(self, output: Dict, batch: Dict): + """ + Evaluate current batch + Args: + output: model output + batch: model input + Returns: evaluate metric + """ + if batch['has_smal_params']["betas"].sum() == 0: + return 0., 0. + + pred_keypoints_3d = output["pred_keypoints_3d"].detach() + pred_keypoints_3d = pred_keypoints_3d[:, None, :, :] + batch_size = pred_keypoints_3d.shape[0] + num_samples = pred_keypoints_3d.shape[1] + gt_keypoints_3d = batch['keypoints_3d'][:, :, :-1].unsqueeze(1).repeat(1, num_samples, 1, 1) + gt_vertices = self.smal_forward(batch) + + # Align predictions and ground truth such that the pelvis location is at the origin + pred_keypoints_3d -= pred_keypoints_3d[:, :, [self.pelvis_ind]] + gt_keypoints_3d -= gt_keypoints_3d[:, :, [self.pelvis_ind]] + + pa_mpjpe = self.compute_pa_mpjpe(pred_keypoints_3d.reshape(batch_size * num_samples, -1, 3), + gt_keypoints_3d.reshape(batch_size * num_samples, -1, 3)) + pa_mpvpe = self.compute_pa_mpvpe(gt_vertices, output['pred_vertices']) + return pa_mpjpe, pa_mpvpe + + def eval_2d(self, output: Dict, batch: Dict, pck_threshold: List[float]=[0.10, 0.15]): + pck = self.compute_pck(output, batch, pck_threshold=pck_threshold) + auc = self.compute_auc(batch, output) + return pck.tolist(), auc + + def compute_auc(self, batch: Dict, output: Dict, threshold_min: float=0.0, threshold_max: float=1.0, steps: int=100): + thresholds = np.linspace(threshold_min, threshold_max, steps) + pck_curve = self.compute_pck(output, batch, thresholds.tolist()).numpy() # (steps,) + norm_factor = threshold_max - threshold_min + auc = float(np.trapz(pck_curve, thresholds) / norm_factor) + return auc + + def smal_forward(self, batch: Dict): + batch_size = batch['img'].shape[0] + smal_params = batch['smal_params'] + smal_params['global_orient'] = axis_angle_to_matrix(smal_params['global_orient'].reshape(batch_size, -1)).unsqueeze(1) + smal_params['pose'] = axis_angle_to_matrix(smal_params['pose'].reshape(batch_size, -1, 3)) + # The SMAL model only registers buffers (e.g. shapedirs) and has no trainable parameters, + # so self.smal_model.parameters() can be empty and calling next on it would raise StopIteration. + # Here we first try to get the device from parameters; if there are no parameters, fall back to buffers; + # if there are no buffers either, fall back to the device of the input batch. + try: + device = next(self.smal_model.parameters()).device + except StopIteration: + try: + device = next(self.smal_model.buffers()).device + except StopIteration: + device = batch['img'].device + smal_params = {k: v.to(device) for k, v in smal_params.items()} + with torch.no_grad(): + smal_output = self.smal_model(**smal_params) + vertices = smal_output.vertices + return vertices diff --git a/prima/utils/geometry.py b/prima/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc16e8faf270d761a35dbb7d04381a6cf8ab0c0 --- /dev/null +++ b/prima/utils/geometry.py @@ -0,0 +1,115 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from typing import Optional +import torch +from torch.nn import functional as F + + +def aa_to_rotmat(theta: torch.Tensor): + """ + Convert axis-angle representation to rotation matrix. + Works by first converting it to a quaternion. + Args: + theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations. + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). + """ + norm = torch.norm(theta + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim=1) + return quat_to_rotmat(quat) + + +def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor: + """ + Convert quaternion representation to rotation matrix. + Args: + quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z). + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, + 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx, + 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + + +def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor: + """ + Convert 6D rotation representation to 3x3 rotation matrix. + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Args: + x (torch.Tensor): (B,6) Batch of 6-D rotation representations. + Returns: + torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3). + """ + x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous() + a1 = x[:, :, 0] + a2 = x[:, :, 1] + b1 = F.normalize(a1) + b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) + b3 = torch.cross(b1, b2, dim=1) + return torch.stack((b1, b2, b3), dim=-1) + + +def perspective_projection(points: torch.Tensor, + translation: torch.Tensor, + focal_length: torch.Tensor, + camera_center: Optional[torch.Tensor] = None, + rotation: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Computes the perspective projection of a set of 3D points. + Args: + points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points. + translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation. + focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels. + camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels. + rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation. + Returns: + torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points. + """ + batch_size = points.shape[0] + if rotation is None: + rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1) + if camera_center is None: + camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype) + # Populate intrinsic camera matrix K. + K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype) + K[:, 0, 0] = focal_length[:, 0] + K[:, 1, 1] = focal_length[:, 1] + K[:, 2, 2] = 1. + K[:, :-1, -1] = camera_center + + # Transform points + points = torch.einsum('bij,bkj->bki', rotation, points) + points = points + translation.unsqueeze(1) + + # Apply perspective distortion + projected_points = points / points[:, :, -1].unsqueeze(-1) + + # Apply camera intrinsics + projected_points = torch.einsum('bij,bkj->bki', K, projected_points) + + + return projected_points[:, :, :-1] diff --git a/prima/utils/mesh_renderer.py b/prima/utils/mesh_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..522bc1a8641e05592d7dd7d4792e137e77c590ba --- /dev/null +++ b/prima/utils/mesh_renderer.py @@ -0,0 +1,328 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import os +from ctypes.util import find_library + +if 'PYOPENGL_PLATFORM' not in os.environ and os.uname().sysname != 'Darwin': + # Prefer OSMesa; fall back to EGL where available. + os.environ['PYOPENGL_PLATFORM'] = 'osmesa' if find_library('OSMesa') else 'egl' +import torch +from torchvision.utils import make_grid +import numpy as np +import pyrender +import trimesh +import cv2 +import math +import torch.nn.functional as F +from typing import List, Tuple + + +def create_raymond_lights(): + import pyrender + thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]) + phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0]) + + nodes = [] + + for phi, theta in zip(phis, thetas): + xp = np.sin(theta) * np.cos(phi) + yp = np.sin(theta) * np.sin(phi) + zp = np.cos(theta) + + z = np.array([xp, yp, zp]) + z = z / np.linalg.norm(z) + x = np.array([-z[1], z[0], 0.0]) + if np.linalg.norm(x) == 0: + x = np.array([1.0, 0.0, 0.0]) + x = x / np.linalg.norm(x) + y = np.cross(z, x) + + matrix = np.eye(4) + matrix[:3, :3] = np.c_[x, y, z] + nodes.append(pyrender.Node( + light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0), + matrix=matrix + )) + + return nodes + + +def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]: + """ + Compute rectangle enclosing keypoints above the threshold. + Args: + keypoints (np.array): Keypoint array of shape (N, 3). + threshold (float): Confidence visualization threshold. + Returns: + Tuple[float, float, float]: Rectangle width, height and area. + """ + valid_ind = keypoints[:, -1] > threshold + if valid_ind.sum() > 0: + valid_keypoints = keypoints[valid_ind][:, :-1] + max_x = valid_keypoints[:, 0].max() + max_y = valid_keypoints[:, 1].max() + min_x = valid_keypoints[:, 0].min() + min_y = valid_keypoints[:, 1].min() + width = max_x - min_x + height = max_y - min_y + area = width * height + return width, height, area + else: + return 0, 0, 0 + + +def render_keypoint(img: np.array, keypoint: np.array, threshold=0.1, + use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0) -> np.array: + if use_confidence and map_fn is not None: + thicknessCircleRatioRight = 1. / 50 * map_fn(keypoint[:, -1]) + else: + thicknessCircleRatioRight = 1. / 50 * np.ones(keypoint.shape[0]) + + thicknessLineRatioWRTCircle = 0.75 + if keypoint.shape[0] == 26: + pairs = [0, 24, 1, 24, 2, 24, 3, 14, 4, 15, 5, 16, 6, 17, 7, 18, 8, 12, 9, 13, 10, 7, 11, 7, + 12, 18, 13, 18, 14, 8, 15, 9, 16, 10, 17, 11, 18, 24, 19, 25, 20, 0, 21, 1, 22, 24, + 23, 24, 25, 7] + elif keypoint.shape[0] == 18: + pairs = [9, 8, 8, 2, 2, 3, 3, 4, 2, 0, 2, 1, 4, 5, + 5, 14, 14, 15, 4, 6, 6, 7, 7, 11, 11, 10, + 7, 13, 13, 12, 5, 16, 5, 17] + else: + raise ValueError("Keypoint shape not supported") + pairs = np.array(pairs).reshape(-1, 2) if pairs is not None else None + colors = [255., 0., 85., + 255., 0., 0., + 255., 85., 0., + 255., 170., 0., + 255., 255., 0., + 170., 255., 0., + 85., 255., 0., + 0., 255., 0., + 255., 0., 0., + 0., 255., 85., + 0., 255., 170., + 0., 255., 255., + 0., 170., 255., + 0., 85., 255., + 0., 0., 255., + 255., 0., 170., + 170., 0., 255., + 255., 0., 255., + 85., 0., 255., + 0., 0., 255., + 0., 0., 255., + 0., 0., 255., + 0., 255., 255., + 0., 255., 255., + 0., 255., 255., + 255., 225., 255.] + colors = np.array(colors).reshape(-1, 3) + poseScales = [1] + + img_orig = img.copy() + width, height = img.shape[1], img.shape[2] + area = width * height + + lineType = 8 + shift = 0 + numberColors = len(colors) + thresholdRectangle = 0.1 + + animal_width, animal_height, animal_area = get_keypoints_rectangle(keypoint, thresholdRectangle) + if animal_area > 0: + ratioAreas = min(1, max(animal_width / width, animal_height / height)) + thicknessRatio = np.maximum(np.round(math.sqrt(area) * thicknessCircleRatioRight * ratioAreas), 2) + thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio)) + thicknessLine = np.maximum(1, np.round(thicknessRatio * thicknessLineRatioWRTCircle)) + radius = thicknessRatio / 2 + else: + return img + + img = np.ascontiguousarray(img.copy()) + if pairs is not None: + for i, pair in enumerate(pairs): + index1, index2 = pair + if keypoint[index1, -1] > threshold and keypoint[index2, -1] > threshold: + thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * poseScales[0])) + colorIndex = index2 + color = colors[colorIndex % numberColors] + keypoint1 = keypoint[index1, :-1].astype(np.int32) + keypoint2 = keypoint[index2, :-1].astype(np.int32) + cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()), + thicknessLineScaled, lineType, shift) + for part in range(len(keypoint)): + faceIndex = part + if keypoint[faceIndex, -1] > threshold: + radiusScaled = int(round(radius[faceIndex] * poseScales[0])) + thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * poseScales[0])) + colorIndex = part + color = colors[colorIndex % numberColors] + center = keypoint[faceIndex, :-1].astype(np.int32) + cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled, + lineType, shift) + + return img + + +class MeshRenderer: + + def __init__(self, cfg, faces=None): + self.cfg = cfg + self.img_res = cfg.MODEL.IMAGE_SIZE + self.renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res, + viewport_height=self.img_res, + point_size=1.0) + + self.camera_center = [self.img_res // 2, self.img_res // 2] + self.faces = faces + + def visualize(self, vertices, camera_translation, images, focal_length, nrow=3, padding=2): + images_np = np.transpose(images, (0, 2, 3, 1)) + rend_imgs = [] + for i in range(vertices.shape[0]): + rend_img = torch.from_numpy(np.transpose( + self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=focal_length, side_view=False), + (2, 0, 1))).float() + rend_img_side = torch.from_numpy(np.transpose( + self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=focal_length, side_view=True), + (2, 0, 1))).float() + rend_imgs.append(torch.from_numpy(images[i])) + rend_imgs.append(rend_img) + rend_imgs.append(rend_img_side) + rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding) + return rend_imgs + + def visualize_tensorboard(self, vertices, camera_translation, images, focal_length, pred_keypoints, gt_keypoints, + pred_masks=None, gt_masks=None): + images_np = np.transpose(images, (0, 2, 3, 1)) + rend_imgs = [] + pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1) + pred_keypoints = self.img_res * (pred_keypoints + 0.5) + gt_keypoints[:, :, :-1] = self.img_res * (gt_keypoints[:, :, :-1] + 0.5) + # keypoint_matches = [(1, 12), (2, 8), (3, 7), (4, 6), (5, 9), + # (6, 10), (7, 11), (8, 14), (9, 2), (10, 1), (11, 0), (12, 3), (13, 4), (14, 5)] + # rend_img_pytorch3d = self.render_by_pytorch3d(vertices, camera_translation, + # images_np, focal_length=self.focal_length) + for i in range(vertices.shape[0]): + rend_img = torch.from_numpy(np.transpose( + self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=focal_length, side_view=False), + (2, 0, 1))).float() + rend_img_side = torch.from_numpy(np.transpose( + self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=focal_length, side_view=True), + (2, 0, 1))).float() + keypoints = pred_keypoints[i] + pred_keypoints_img = render_keypoint(255 * images_np[i].copy(), keypoints) / 255 + keypoints = gt_keypoints[i] + gt_keypoints_img = render_keypoint(255 * images_np[i].copy(), keypoints) / 255 + rend_imgs.append(torch.from_numpy(images[i])) + rend_imgs.append(rend_img) + rend_imgs.append(rend_img_side) + if pred_masks is not None: + rend_imgs.append(torch.from_numpy(pred_masks[i])) + if gt_masks is not None: + rend_imgs.append(torch.from_numpy(gt_masks[i])) + rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2, 0, 1)) + rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2, 0, 1)) + return rend_imgs + + def __call__(self, vertices, camera_translation, image, focal_length, text=None, resize=None, side_view=False, + baseColorFactor=(1.0, 1.0, 0.9, 1.0), rot_angle=90): + renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1], + viewport_height=image.shape[0], + point_size=1.0) + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.0, + alphaMode='OPAQUE', + baseColorFactor=baseColorFactor) + + camera_translation_local = camera_translation.copy() + camera_translation_local[0] *= -1. + + mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy()) + if side_view: + rot = trimesh.transformations.rotation_matrix( + np.radians(rot_angle), [0, 1, 0]) + mesh.apply_transform(rot) + rot = trimesh.transformations.rotation_matrix( + np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], + ambient_light=(0.3, 0.3, 0.3)) + scene.add(mesh, 'mesh') + + camera_pose = np.eye(4) + camera_pose[:3, 3] = camera_translation_local + camera_center = [image.shape[1] / 2., image.shape[0] / 2.] + camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length, + cx=camera_center[0], cy=camera_center[1], + zfar=1000) + scene.add(camera, pose=camera_pose) + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + color = color.astype(np.float32) / 255.0 + valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis] + if not side_view: + output_img = (color[:, :, :3] * valid_mask + + (1 - valid_mask) * image) + else: + output_img = color[:, :, :3] + if resize is not None: + output_img = cv2.resize(output_img, resize) + + output_img = output_img.astype(np.float32) + renderer.delete() + return output_img + + def render_mask(self, vertices, camera_translation, focal_length, side_view=False, rot_angle=90): + """ + Render only the visibility mask (alpha>0) of the mesh given vertices and camera translation. + Returns a single-channel float32 numpy array with values 0.0 or 1.0 with shape (H, W). + """ + renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res, + viewport_height=self.img_res, + point_size=1.0) + + mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy()) + if side_view: + rot = trimesh.transformations.rotation_matrix( + np.radians(rot_angle), [0, 1, 0]) + mesh.apply_transform(rot) + rot = trimesh.transformations.rotation_matrix( + np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + mesh = pyrender.Mesh.from_trimesh(mesh) + + scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], + ambient_light=(0.3, 0.3, 0.3)) + scene.add(mesh, 'mesh') + + camera_pose = np.eye(4) + camera_pose[:3, 3] = camera_translation + camera_center = [self.img_res / 2., self.img_res / 2.] + camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length, + cx=camera_center[0], cy=camera_center[1], + zfar=1000) + scene.add(camera, pose=camera_pose) + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + # alpha channel indicates visibility + mask = (color[:, :, -1] > 0).astype(np.float32) + renderer.delete() + return mask diff --git a/prima/utils/misc.py b/prima/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7c4a15159ac5395ac7d82d43e28bfc6de308fbd0 --- /dev/null +++ b/prima/utils/misc.py @@ -0,0 +1,211 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import time +import warnings +from importlib.util import find_spec +from pathlib import Path +from typing import Callable, List + +import hydra +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Callback +from pytorch_lightning.loggers import Logger +from pytorch_lightning.utilities import rank_zero_only + +from . import pylogger, rich_utils + +log = pylogger.get_pylogger(__name__) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that wraps the task function in extra utilities. + + Makes multirun more resistant to failure. + + Utilities: + - Calling the `utils.extras()` before the task is started + - Calling the `utils.close_loggers()` after the task is finished + - Logging the exception if occurs + - Logging the task total execution time + - Logging the output dir + """ + + def wrap(cfg: DictConfig): + start_time = time.time() + try: + # apply extra utilities + extras(cfg) + + # execute the task + ret = task_func(cfg=cfg) + except Exception as ex: + log.exception("") # save exception to `.log` file + raise ex + finally: + path = Path(cfg.paths.output_dir, "exec_time.log") + content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)" + save_file(path, content) # save task execution time (even if exception occurs) + close_loggers() # close loggers (even if exception occurs so multirun won't fail) + + log.info(f"Output dir: {cfg.paths.output_dir}") + + return ret + + return wrap + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + """ + + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +@rank_zero_only +def save_file(path: str, content: str) -> None: + """Save file in rank zero mode (only on one process in multi-GPU setup).""" + with open(path, "w+") as file: + file.write(content) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("Callbacks config is empty.") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("Logger config is empty.") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + for k in cfg.keys(): + hparams[k] = cfg.get(k) + + # Resolve all interpolations + def _resolve(_cfg): + if isinstance(_cfg, DictConfig): + _cfg = OmegaConf.to_container(_cfg, resolve=True) + return _cfg + + hparams = {k: _resolve(v) for k, v in hparams.items()} + + # send hparams to all loggers + trainer.logger.log_hyperparams(hparams) + + +def get_metric_value(metric_dict: dict, metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule.""" + + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") + + return metric_value + + +def close_loggers() -> None: + """Makes sure all loggers closed properly (prevents logging failure during multirun).""" + + log.info("Closing loggers...") + + if find_spec("wandb"): # if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() diff --git a/prima/utils/pylogger.py b/prima/utils/pylogger.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5e2f94074a368f12a78036cc469128188f2af8 --- /dev/null +++ b/prima/utils/pylogger.py @@ -0,0 +1,26 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import logging + +from pytorch_lightning.utilities import rank_zero_only + + +def get_pylogger(name=__name__) -> logging.Logger: + """Initializes multi-GPU-friendly python command line logger.""" + + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + for level in logging_levels: + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger diff --git a/prima/utils/renderer.py b/prima/utils/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..b05a0bf4a4225ad77baa3c6fc86633a423a9e25f --- /dev/null +++ b/prima/utils/renderer.py @@ -0,0 +1,442 @@ +from __future__ import annotations +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + + +import os +from ctypes.util import find_library + +if 'PYOPENGL_PLATFORM' not in os.environ and os.uname().sysname != 'Darwin': + # Prefer OSMesa; fall back to EGL where available. + os.environ['PYOPENGL_PLATFORM'] = 'osmesa' if find_library('OSMesa') else 'egl' +import torch +import numpy as np +import pyrender +import trimesh +import cv2 +from yacs.config import CfgNode +from typing import List, Optional + + +def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.): + # Convert cam_bbox to full image + img_w, img_h = img_size[:, 0], img_size[:, 1] + cx, cy, b = box_center[:, 0], box_center[:, 1], box_size + w_2, h_2 = img_w / 2., img_h / 2. + bs = b * cam_bbox[:, 0] + 1e-9 + tz = 2 * focal_length / bs + tx = (2 * (cx - w_2) / bs) + cam_bbox[:, 1] + ty = (2 * (cy - h_2) / bs) + cam_bbox[:, 2] + full_cam = torch.stack([tx, ty, tz], dim=-1) + return full_cam + + +def get_light_poses(n_lights=5, elevation=np.pi / 3, dist=12): + # get lights in a circle around origin at elevation + thetas = elevation * np.ones(n_lights) + phis = 2 * np.pi * np.arange(n_lights) / n_lights + poses = [] + trans = make_translation(torch.tensor([0, 0, dist])) + for phi, theta in zip(phis, thetas): + rot = make_rotation(rx=-theta, ry=phi, order="xyz") + poses.append((rot @ trans).numpy()) + return poses + + +def make_translation(t): + return make_4x4_pose(torch.eye(3), t) + + +def make_rotation(rx=0, ry=0, rz=0, order="xyz"): + Rx = rotx(rx) + Ry = roty(ry) + Rz = rotz(rz) + if order == "xyz": + R = Rz @ Ry @ Rx + elif order == "xzy": + R = Ry @ Rz @ Rx + elif order == "yxz": + R = Rz @ Rx @ Ry + elif order == "yzx": + R = Rx @ Rz @ Ry + elif order == "zyx": + R = Rx @ Ry @ Rz + elif order == "zxy": + R = Ry @ Rx @ Rz + return make_4x4_pose(R, torch.zeros(3)) + + +def make_4x4_pose(R, t): + """ + :param R (*, 3, 3) + :param t (*, 3) + return (*, 4, 4) + """ + dims = R.shape[:-2] + pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1) + bottom = ( + torch.tensor([0, 0, 0, 1], device=R.device) + .reshape(*(1,) * len(dims), 1, 4) + .expand(*dims, 1, 4) + ) + return torch.cat([pose_3x4, bottom], dim=-2) + + +def rotx(theta): + return torch.tensor( + [ + [1, 0, 0], + [0, np.cos(theta), -np.sin(theta)], + [0, np.sin(theta), np.cos(theta)], + ], + dtype=torch.float32, + ) + + +def roty(theta): + return torch.tensor( + [ + [np.cos(theta), 0, np.sin(theta)], + [0, 1, 0], + [-np.sin(theta), 0, np.cos(theta)], + ], + dtype=torch.float32, + ) + + +def rotz(theta): + return torch.tensor( + [ + [np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1], + ], + dtype=torch.float32, + ) + + +def create_raymond_lights() -> List[pyrender.Node]: + """ + Return raymond light nodes for the scene. + """ + thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]) + phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0]) + + nodes = [] + + for phi, theta in zip(phis, thetas): + xp = np.sin(theta) * np.cos(phi) + yp = np.sin(theta) * np.sin(phi) + zp = np.cos(theta) + + z = np.array([xp, yp, zp]) + z = z / np.linalg.norm(z) + x = np.array([-z[1], z[0], 0.0]) + if np.linalg.norm(x) == 0: + x = np.array([1.0, 0.0, 0.0]) + x = x / np.linalg.norm(x) + y = np.cross(z, x) + + matrix = np.eye(4) + matrix[:3, :3] = np.c_[x, y, z] + nodes.append(pyrender.Node( + light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0), + matrix=matrix + )) + + return nodes + + +class Renderer: + + def __init__(self, cfg: CfgNode, faces: np.array): + """ + Wrapper around the pyrender renderer to render MANO meshes. + Args: + cfg (CfgNode): Model config file. + faces (np.array): Array of shape (F, 3) containing the mesh faces. + """ + self.cfg = cfg + self.focal_length = 1000. if faces.shape[0] == 7774 else 2167. + self.img_res = cfg.MODEL.IMAGE_SIZE + + self.camera_center = [self.img_res // 2, self.img_res // 2] + self.faces = faces.cpu().numpy() + + def __call__(self, + vertices: np.array, + camera_translation: np.array, + image: torch.Tensor, + full_frame: bool = False, + imgname: Optional[str] = None, + side_view=False, rot_angle=90, + mesh_base_color=(1.0, 1.0, 0.9), + scene_bg_color=(0, 0, 0), + return_rgba=False, + depth = False, + focal_length: Optional[float] = None, + ) -> np.array: + """ + Render meshes on input image + Args: + vertices (np.array): Array of shape (V, 3) containing the mesh vertices. + camera_translation (np.array): Array of shape (3,) with the camera translation. + image (torch.Tensor): Tensor of shape (3, H, W) containing the image crop with normalized pixel values. + full_frame (bool): If True, then render on the full image. + imgname (Optional[str]): Contains the original image filenamee. Used only if full_frame == True. + focal_length (Optional[float]): Custom focal length. If None, uses self.focal_length. + """ + + if full_frame: + + image = cv2.imread(imgname) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255. + else: + image = (image.clone()) * (torch.tensor(self.cfg.MODEL.IMAGE_STD, device=image.device).reshape(3, 1, 1)) + image = image + torch.tensor(self.cfg.MODEL.IMAGE_MEAN, device=image.device).reshape(3, 1, 1) + image = image.permute(1, 2, 0).cpu().numpy() + + # Use custom focal length if provided, otherwise use default + focal_length_to_use = focal_length if focal_length is not None else self.focal_length + + try: + renderer = pyrender.OffscreenRenderer( + viewport_width=image.shape[1], + viewport_height=image.shape[0], + point_size=1.0, + ) + except (IndexError, OSError) as exc: + raise RuntimeError( + "PyRender could not open an OpenGL context (common on headless macOS or remote SSH). " + "Run the demo from a normal desktop session, or on Linux/Spaces use OSMesa (see packages.txt). " + f"Original error: {exc}" + ) from exc + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.0, + alphaMode='OPAQUE', + baseColorFactor=(*mesh_base_color, 1.0)) + + camera_translation_local = camera_translation.copy() + camera_translation_local[0] *= -1. + + mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy()) + if side_view: + rot = trimesh.transformations.rotation_matrix( + np.radians(rot_angle), [0, 1, 0]) + mesh.apply_transform(rot) + rot = trimesh.transformations.rotation_matrix( + np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0], + ambient_light=(0.3, 0.3, 0.3)) + scene.add(mesh, 'mesh') + + camera_pose = np.eye(4) + camera_pose[:3, 3] = camera_translation_local + camera_center = [image.shape[1] / 2., image.shape[0] / 2.] + camera = pyrender.IntrinsicsCamera(fx=focal_length_to_use, fy=focal_length_to_use, + cx=camera_center[0], cy=camera_center[1], zfar=1e12) + scene.add(camera, pose=camera_pose) + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + color = color.astype(np.float32) / 255.0 + renderer.delete() + + if depth: + return rend_depth + + if return_rgba: + return color + + valid_mask = (rend_depth > 0).astype(np.float32)[:, :, np.newaxis] + if not side_view: + output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image) + else: + output_img = color[:, :, :3] + + output_img = output_img.astype(np.float32) + return output_img + + def vertices_to_trimesh(self, vertices, camera_translation, mesh_base_color=(1.0, 1.0, 0.9), + rot_axis=[1, 0, 0], rot_angle=0): + # material = pyrender.MetallicRoughnessMaterial( + # metallicFactor=0.0, + # alphaMode='OPAQUE', + # baseColorFactor=(*mesh_base_color, 1.0)) + vertex_colors = np.array([(*mesh_base_color, 1.0)] * vertices.shape[0]) + mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces.copy(), vertex_colors=vertex_colors) + # mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy()) + + rot = trimesh.transformations.rotation_matrix( + np.radians(rot_angle), rot_axis) + mesh.apply_transform(rot) + + rot = trimesh.transformations.rotation_matrix( + np.radians(180), [1, 0, 0]) + mesh.apply_transform(rot) + return mesh + + def render_rgba( + self, + vertices: np.array, + cam_t=None, + rot=None, + rot_axis=[1, 0, 0], + rot_angle=0, + camera_z=3, + # camera_translation: np.array, + mesh_base_color=(1.0, 1.0, 0.9), + scene_bg_color=(0, 0, 0), + render_res=[256, 256], + focal_length=None, + ): + + renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0], + viewport_height=render_res[1], + point_size=1.0) + # material = pyrender.MetallicRoughnessMaterial( + # metallicFactor=0.0, + # alphaMode='OPAQUE', + # baseColorFactor=(*mesh_base_color, 1.0)) + + focal_length = focal_length if focal_length is not None else self.focal_length + + if cam_t is not None: + camera_translation = cam_t.copy() + camera_translation[0] *= -1. + else: + camera_translation = np.array([0, 0, camera_z * focal_length / render_res[1]]) + + mesh = self.vertices_to_trimesh(vertices, np.array([0, 0, 0]), mesh_base_color, rot_axis, rot_angle, + ) + mesh = pyrender.Mesh.from_trimesh(mesh) + # mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0], + ambient_light=(0.3, 0.3, 0.3)) + scene.add(mesh, 'mesh') + + camera_pose = np.eye(4) + camera_pose[:3, 3] = camera_translation + camera_center = [render_res[0] / 2., render_res[1] / 2.] + camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length, + cx=camera_center[0], cy=camera_center[1], zfar=1e12) + + # Create camera node and add it to pyRender scene + camera_node = pyrender.Node(camera=camera, matrix=camera_pose) + scene.add_node(camera_node) + self.add_point_lighting(scene, camera_node) + self.add_lighting(scene, camera_node) + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + color = color.astype(np.float32) / 255.0 + renderer.delete() + + return color + + def render_rgba_multiple( + self, + vertices: List[np.array], + cam_t: List[np.array], + rot_axis=[1, 0, 0], + rot_angle=0, + mesh_base_color=(1.0, 1.0, 0.9), + scene_bg_color=(0, 0, 0), + render_res=[256, 256], + focal_length=None, + ): + + renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0], + viewport_height=render_res[1], + point_size=1.0) + # material = pyrender.MetallicRoughnessMaterial( + # metallicFactor=0.0, + # alphaMode='OPAQUE', + # baseColorFactor=(*mesh_base_color, 1.0)) + + mesh_list = [pyrender.Mesh.from_trimesh( + self.vertices_to_trimesh(vvv, ttt.copy(), mesh_base_color, rot_axis, rot_angle)) for + vvv, ttt in zip(vertices, cam_t)] + + scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0], + ambient_light=(0.3, 0.3, 0.3)) + for i, mesh in enumerate(mesh_list): + scene.add(mesh, f'mesh_{i}') + + camera_pose = np.eye(4) + # camera_pose[:3, 3] = camera_translation + camera_center = [render_res[0] / 2., render_res[1] / 2.] + focal_length = focal_length if focal_length is not None else self.focal_length + camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length, + cx=camera_center[0], cy=camera_center[1], zfar=1e12) + + # Create camera node and add it to pyRender scene + camera_node = pyrender.Node(camera=camera, matrix=camera_pose) + scene.add_node(camera_node) + self.add_point_lighting(scene, camera_node) + self.add_lighting(scene, camera_node) + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) + color = color.astype(np.float32) / 255.0 + renderer.delete() + + return color + + def add_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0): + + light_poses = get_light_poses() + light_poses.append(np.eye(4)) + cam_pose = scene.get_pose(cam_node) + for i, pose in enumerate(light_poses): + matrix = cam_pose @ pose + node = pyrender.Node( + name=f"light-{i:02d}", + light=pyrender.DirectionalLight(color=color, intensity=intensity), + matrix=matrix, + ) + if scene.has_node(node): + continue + scene.add_node(node) + + def add_point_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0): + + light_poses = get_light_poses(dist=0.5) + light_poses.append(np.eye(4)) + cam_pose = scene.get_pose(cam_node) + for i, pose in enumerate(light_poses): + matrix = cam_pose @ pose + # node = pyrender.Node( + # name=f"light-{i:02d}", + # light=pyrender.DirectionalLight(color=color, intensity=intensity), + # matrix=matrix, + # ) + node = pyrender.Node( + name=f"plight-{i:02d}", + light=pyrender.PointLight(color=color, intensity=intensity), + matrix=matrix, + ) + if scene.has_node(node): + continue + scene.add_node(node) + + diff --git a/prima/utils/rich_utils.py b/prima/utils/rich_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..20f8c36a549238eefa0a2096d33e18fbd8775a7a --- /dev/null +++ b/prima/utils/rich_utils.py @@ -0,0 +1,114 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning.utilities import rank_zero_only +from rich.prompt import Prompt + +from . import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "datamodule", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + queue.append(field) if field in cfg else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config.""" + + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) + + +if __name__ == "__main__": + from hydra import compose, initialize + + with initialize(version_base="1.2", config_path="../../configs_hydra"): + cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[]) + print_config_tree(cfg, resolve=False, save_to_file=False) diff --git a/prima/utils/weights.py b/prima/utils/weights.py new file mode 100644 index 0000000000000000000000000000000000000000..c02e593a5713a4261e23120fc01fdcbadb88ddbc --- /dev/null +++ b/prima/utils/weights.py @@ -0,0 +1,337 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from __future__ import annotations + +import os +import shutil +from pathlib import Path +from typing import Iterable, Optional, Sequence, Union + +HF_REPO_ID = "MLAdaptiveIntelligence/PRIMA" +DEFAULT_HF_REPO_ID = HF_REPO_ID + +DEFAULT_STAGE1_CHECKPOINT = Path("data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt") +DEFAULT_STAGE3_CHECKPOINT = Path("data/PRIMAS3/checkpoints/s3ckpt_inference.ckpt") + +SMAL_ASSET_PATHS = [ + "my_smpl_00781_4_all.pkl", + "my_smpl_data_00781_4_all.pkl", + "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl", +] +BACKBONE_ASSET_PATH = "amr_vitbb.pth" +STAGE1_CONFIG_ASSET_PATH = "config_s1_HYDRA.yaml" +STAGE1_CHECKPOINT_ASSET_PATH = "s1ckpt_inference.ckpt" +STAGE3_CONFIG_ASSET_PATH = "config_s3_HYDRA.yaml" +STAGE3_CHECKPOINT_ASSET_PATH = "s3ckpt_inference.ckpt" + +STAGE_ASSETS = { + "PRIMAS1": (STAGE1_CONFIG_ASSET_PATH, STAGE1_CHECKPOINT_ASSET_PATH, "s1ckpt_inference.ckpt"), + "PRIMAS3": (STAGE3_CONFIG_ASSET_PATH, STAGE3_CHECKPOINT_ASSET_PATH, "s3ckpt_inference.ckpt"), +} + +STAGE_CHECKPOINTS = { + "PRIMAS1": Path("PRIMAS1/checkpoints/s1ckpt_inference.ckpt"), + "PRIMAS3": Path("PRIMAS3/checkpoints/s3ckpt_inference.ckpt"), +} + +PathLike = Union[str, Path] + + +def _resolve_hf_repo_id(hf_repo_id: Optional[str]) -> str: + return hf_repo_id or os.environ.get("PRIMA_HF_REPO_ID", HF_REPO_ID) + + +def _default_checkpoint_path(data_dir: PathLike = "data") -> Path: + return Path(data_dir) / STAGE_CHECKPOINTS["PRIMAS1"] + + +def _config_path_for_checkpoint(checkpoint_path: PathLike) -> Path: + checkpoint_path = Path(checkpoint_path) + return checkpoint_path.parent.parent / ".hydra" / "config.yaml" + + +def _stage_for_checkpoint(checkpoint_path: PathLike) -> Optional[str]: + checkpoint_path = Path(checkpoint_path) + if len(checkpoint_path.parents) < 2: + return None + stage_name = checkpoint_path.parent.parent.name + stage_assets = STAGE_ASSETS.get(stage_name) + if stage_assets is None: + return None + _, _, checkpoint_name = stage_assets + if checkpoint_path.name != checkpoint_name: + return None + return stage_name + + +def _download_file( + hf_repo_id: str, + remote_filename: str, + destination: Path, + force_download: bool = False, +) -> None: + try: + from huggingface_hub import hf_hub_download + except ImportError: + raise ImportError( + "huggingface_hub is required to download PRIMA demo assets. " + "Install it with: pip install huggingface_hub\n" + "Or download the assets manually and pass a local checkpoint path." + ) from None + + destination.parent.mkdir(parents=True, exist_ok=True) + downloaded = hf_hub_download( + repo_id=hf_repo_id, + filename=remote_filename, + local_dir=str(destination.parent), + local_dir_use_symlinks=False, + force_download=force_download, + ) + downloaded_path = Path(downloaded).resolve() + target = destination.resolve() + if downloaded_path != target: + if target.exists(): + target.unlink() + shutil.move(str(downloaded_path), str(target)) + + +def _validate_torch_checkpoint(path: Path) -> None: + import inspect + import pickle + import zipfile + + import torch + + if zipfile.is_zipfile(path): + with zipfile.ZipFile(path) as checkpoint_zip: + corrupt_member = checkpoint_zip.testzip() + if corrupt_member is not None: + raise RuntimeError( + f"Checkpoint file is invalid or incomplete: {path}\n" + f"Corrupt archive member: {corrupt_member}\n" + "Please redownload the checkpoint and try again." + ) + + supports_weights_only = "weights_only" in inspect.signature(torch.load).parameters + load_kwargs = {"map_location": "cpu"} + if supports_weights_only: + load_kwargs["weights_only"] = True + + try: + torch.load(path, **load_kwargs) + except pickle.UnpicklingError as exc: + message = str(exc) + if ( + supports_weights_only + and "Weights only load failed" in message + and ("Unsupported global" in message or "Unsupported class" in message) + ): + return + raise RuntimeError( + f"Checkpoint file is invalid or incomplete: {path}\n" + "Downloaded checkpoint is not loadable. " + "Please verify the uploaded Hugging Face file and try again." + ) from exc + except Exception as exc: + raise RuntimeError( + f"Checkpoint file is invalid or incomplete: {path}\n" + "Downloaded checkpoint is not loadable. " + "Please verify the uploaded Hugging Face file and try again." + ) from exc + + +def _ensure_backbone(data_dir: Path, force: bool, hf_repo_id: str) -> None: + target = data_dir / "amr_vitbb.pth" + if target.exists() and not force: + print(f"[skip] {target} already exists") + return + + print("[download] pretrained backbone") + _download_file(hf_repo_id, BACKBONE_ASSET_PATH, target, force_download=force) + print(f"[ok] {target}") + + +def _ensure_smal_assets(data_dir: Path, force: bool, hf_repo_id: str) -> None: + required = [Path(p).name for p in SMAL_ASSET_PATHS] + smal_dir = data_dir / "smal" + if smal_dir.exists() and all((smal_dir / n).exists() for n in required) and not force: + print("[skip] SMAL files already exist") + return + + print("[download] SMAL assets") + for asset_path in SMAL_ASSET_PATHS: + target = smal_dir / Path(asset_path).name + _download_file(hf_repo_id, asset_path, target, force_download=force) + print(f"[ok] {smal_dir}") + + +def _ensure_stage_assets( + stage_name: str, + data_dir: Path, + force: bool, + hf_repo_id: str, + validate_existing: bool = True, +) -> None: + if stage_name not in STAGE_ASSETS: + known = ", ".join(sorted(STAGE_ASSETS)) + raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}") + + config_asset_path, checkpoint_asset_path, checkpoint_name = STAGE_ASSETS[stage_name] + stage_dir = data_dir / stage_name + config_target = stage_dir / ".hydra" / "config.yaml" + checkpoint_target = stage_dir / "checkpoints" / checkpoint_name + redownload_checkpoint = False + + if config_target.exists() and checkpoint_target.exists() and not force: + if validate_existing: + try: + _validate_torch_checkpoint(checkpoint_target) + except RuntimeError: + print(f"[warn] {stage_name} checkpoint is incomplete, redownloading checkpoint only.") + redownload_checkpoint = True + else: + print(f"[skip] {stage_name} assets already exist") + return + else: + print(f"[skip] {stage_name} assets already exist") + return + + print(f"[download] {stage_name} assets") + config_target.parent.mkdir(parents=True, exist_ok=True) + checkpoint_target.parent.mkdir(parents=True, exist_ok=True) + if force or not config_target.exists(): + _download_file(hf_repo_id, config_asset_path, config_target, force_download=force) + if redownload_checkpoint and checkpoint_target.exists(): + checkpoint_target.unlink() + if force or redownload_checkpoint or not checkpoint_target.exists(): + _download_file( + hf_repo_id, + checkpoint_asset_path, + checkpoint_target, + force_download=force or redownload_checkpoint, + ) + _validate_torch_checkpoint(checkpoint_target) + print(f"[ok] {stage_dir}") + + +def _normalize_stages(stages: Union[str, Iterable[str]]) -> Sequence[str]: + if isinstance(stages, str): + return (stages,) + return tuple(stages) + + +def _verify_assets(data_dir: Path, stages: Sequence[str]) -> None: + required_paths = [ + data_dir / "smal" / "my_smpl_00781_4_all.pkl", + data_dir / "smal" / "my_smpl_data_00781_4_all.pkl", + data_dir / "smal" / "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl", + data_dir / "amr_vitbb.pth", + ] + for stage_name in stages: + if stage_name not in STAGE_ASSETS: + known = ", ".join(sorted(STAGE_ASSETS)) + raise ValueError(f"Unknown PRIMA stage '{stage_name}'. Expected one of: {known}") + _, _, checkpoint_name = STAGE_ASSETS[stage_name] + stage_dir = data_dir / stage_name + required_paths.extend( + [ + stage_dir / ".hydra" / "config.yaml", + stage_dir / "checkpoints" / checkpoint_name, + ] + ) + + missing = [p for p in required_paths if not p.exists()] + if missing: + raise FileNotFoundError("Missing required files:\n" + "\n".join(str(p) for p in missing)) + + for stage_name in stages: + _, _, checkpoint_name = STAGE_ASSETS[stage_name] + _validate_torch_checkpoint(data_dir / stage_name / "checkpoints" / checkpoint_name) + + +def _ensure_assets_for_checkpoint( + checkpoint_path: PathLike, + force: bool = False, + hf_repo_id: Optional[str] = None, +) -> None: + checkpoint_path = Path(checkpoint_path) + config_path = _config_path_for_checkpoint(checkpoint_path) + stage_name = _stage_for_checkpoint(checkpoint_path) + if stage_name is None: + if checkpoint_path.exists() and config_path.exists() and not force: + print(f"[skip] Using local PRIMA checkpoint {checkpoint_path}") + return + raise FileNotFoundError( + "Missing checkpoint or config for a custom path:\n" + f" checkpoint: {checkpoint_path}\n" + f" config: {config_path}\n" + "Auto-download supports the standard PRIMA demo layouts only:\n" + " data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt\n" + " data/PRIMAS3/checkpoints/s3ckpt_inference.ckpt\n" + "Pass one of those paths, or download/copy your custom checkpoint manually." + ) + + data_dir = checkpoint_path.parent.parent.parent + repo_id = _resolve_hf_repo_id(hf_repo_id) + print(f"[download] Ensuring PRIMA demo assets under {data_dir}") + _ensure_smal_assets(data_dir, force=force, hf_repo_id=repo_id) + _ensure_backbone(data_dir, force=force, hf_repo_id=repo_id) + _ensure_stage_assets( + stage_name, + data_dir, + force=force, + hf_repo_id=repo_id, + validate_existing=False, + ) + + +def ensure_demo_assets( + data_dir: PathLike = "data", + *, + stages: Union[str, Iterable[str]] = ("PRIMAS1",), + force: bool = False, + hf_repo_id: Optional[str] = None, +) -> None: + """Ensure PRIMA demo assets exist in the expected ``data/`` layout.""" + data_dir = Path(data_dir).resolve() + data_dir.mkdir(parents=True, exist_ok=True) + repo_id = _resolve_hf_repo_id(hf_repo_id) + selected_stages = _normalize_stages(stages) + + _ensure_smal_assets(data_dir, force=force, hf_repo_id=repo_id) + _ensure_backbone(data_dir, force=force, hf_repo_id=repo_id) + for stage_name in selected_stages: + _ensure_stage_assets(stage_name, data_dir, force=force, hf_repo_id=repo_id) + _verify_assets(data_dir, selected_stages) + + +def resolve_prima_checkpoint_path( + checkpoint_path: PathLike = "", + *, + data_dir: PathLike = "data", + auto_download: bool = True, + hf_repo_id: Optional[str] = None, + force: bool = False, +) -> str: + """Return a PRIMA checkpoint path, downloading default demo assets if needed.""" + resolved = Path(checkpoint_path) if checkpoint_path else _default_checkpoint_path(data_dir) + if auto_download: + _ensure_assets_for_checkpoint(resolved, force=force, hf_repo_id=hf_repo_id) + return str(resolved) + + +__all__ = [ + "DEFAULT_HF_REPO_ID", + "DEFAULT_STAGE1_CHECKPOINT", + "DEFAULT_STAGE3_CHECKPOINT", + "HF_REPO_ID", + "ensure_demo_assets", + "resolve_prima_checkpoint_path", +] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..17576dd989d77ac392e33cb1a1cb92014755e078 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,91 @@ +[build-system] +requires = ["setuptools>=61", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "prima-animal" +version = "0.1.7" +description = "PRIMA: 3D animal pose and shape estimation" +readme = "README.md" +requires-python = ">=3.10" + +authors = [ + { name = "Xiaohang Yu", email = "xiaohang.yu@epfl.ch" }, + { name = "Ti Wang", email = "ti.wang@epfl.ch" }, + { name = "Mackenzie Weygandt Mathis", email = "mackenzie.mathis@epfl.ch" }, + +] + +keywords = ["3d", "animal", "pose", "shape", "vision", "pytorch"] + +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] + +dependencies = [ + # Core + "numpy==1.26.1", + "pandas==2.3.2", + + # Vision & geometry + "opencv-python==4.11.0.86", + "pyrender==0.1.45", + "open3d==0.19.0", + "trimesh==4.8.2", + "scikit-image==0.25.2", + "mmcv==1.3.9", + + # Model components + "smplx==0.1.28", + "yacs==0.1.8", + "timm==1.0.24", + "einops==0.8.1", + "xtcocotools==1.14.3", + "open_clip_torch==3.2.0", + "transformers==4.57.0", + + # Config & utilities + "omegaconf==2.3.0", + "hydra-core==1.3.2", + "hydra-submitit-launcher==1.2.0", + "hydra-colorlog==1.2.0", + "pyrootutils==1.0.4", + "rich==14.1.0", + + # IO / misc + "gdown==5.2.0", + # Match HF Space (Gradio 6.x) and local demo; do not pin 5.1 — Space injects gradio==6.x. + "gradio>=5.1,<7", + "pydantic>=2.10,<3", + + # Training framework + "pytorch-lightning==2.5.5", + + # Build/runtime helpers needed for older mmcv installation flows + "setuptools<81", + "packaging<25", + "Cython<3", + "wheel", + + # Demo runtime dependencies (included in main PyPI install) + "detectron2 @ git+https://github.com/facebookresearch/detectron2.git", + "deeplabcut==3.0.0rc14", +] + +[project.optional-dependencies] +all = [] + +[project.urls] +Homepage = "https://github.com/AdaptiveMotorControlLab/PRIMA" +Source = "https://github.com/AdaptiveMotorControlLab/PRIMA" +Issues = "https://github.com/AdaptiveMotorControlLab/PRIMA/issues" + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +where = ["."] +include = ["prima*", "chumpy*"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b6fc5432ced5c1d4d42514615a8f5c86e1a8ac9b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,38 @@ +huggingface_hub<1 +torch==2.2.1 +torchvision==0.17.1 +# Do not pin gradio here: Hugging Face Spaces runs +# pip install -r requirements.txt gradio[oauth,mcp]==6.x ... +# and a pinned older gradio causes ResolutionImpossible. Local installs get Gradio from +# scripts/clean_install_local.sh after this file, or from pyproject.toml for editable installs. +# 3.x rc needed for ``pose_estimation_pytorch`` / SuperAnimal (2.3.x is TensorFlow-only). +deeplabcut==3.0.0rc14 +# PyTables: use a wheel on macOS (DLC 2.x pinned 3.8.0 from source); Linux pip resolves normally. +tables>=3.9.2,<3.11 +# Animal detector (needs torch installed first; local clean_install uses --no-build-isolation). +detectron2 @ git+https://github.com/facebookresearch/detectron2.git +pytorch-lightning==2.5.5 +yacs==0.1.8 +pyrender==0.1.45 +trimesh==4.8.2 +opencv-python==4.11.0.86 +timm==1.0.24 +einops==0.8.1 +smplx==0.1.28 +xtcocotools==1.14.3 +open_clip_torch==3.2.0 +transformers==4.56.2 +omegaconf==2.3.0 +hydra-core==1.3.2 +hydra-submitit-launcher==1.2.0 +hydra-colorlog==1.2.0 +pyrootutils==1.0.4 +rich==14.1.0 +scikit-image==0.25.2 +pandas==2.3.2 +numpy==1.26.1 +gdown==5.2.0 +setuptools<81 +packaging<25 +Cython<3 +wheel diff --git a/scripts/clean_install_local.sh b/scripts/clean_install_local.sh new file mode 100755 index 0000000000000000000000000000000000000000..7730825366cd72382260ee80c27c466a305ff193 --- /dev/null +++ b/scripts/clean_install_local.sh @@ -0,0 +1,180 @@ +#!/usr/bin/env bash +# Fresh local environment: venv, pip deps, LFS assets, demo checkpoints, smoke test. +# +# Requires Python 3.10+ (matches README, Space, and type hints in app.py). +# +# Usage: +# ./scripts/clean_install_local.sh +# PRIMA_PYTHON=/opt/homebrew/bin/python3.10 ./scripts/clean_install_local.sh +# PRIMA_VENV=.venv ./scripts/clean_install_local.sh --skip-data +# ./scripts/clean_install_local.sh --wipe-data --force-data +set -euo pipefail + +# Non-interactive: no pip/git credential prompts on stdin. +export GIT_TERMINAL_PROMPT=0 +export PIP_DISABLE_PIP_VERSION_CHECK=1 +export HF_HUB_DISABLE_SYMLINKS_WARNING=1 + +ROOT="$(git rev-parse --show-toplevel)" +cd "$ROOT" + +VENV="${PRIMA_VENV:-.venv}" +SKIP_DATA=0 +FORCE_DATA=0 +WIPE_DATA=0 +EDITABLE=1 + +while [[ $# -gt 0 ]]; do + case "$1" in + --venv) + VENV="$2" + shift 2 + ;; + --skip-data) + SKIP_DATA=1 + shift + ;; + --force-data) + FORCE_DATA=1 + shift + ;; + --wipe-data) + WIPE_DATA=1 + shift + ;; + --no-editable) + EDITABLE=0 + shift + ;; + -h|--help) + echo "Usage: $0 [--venv DIR] [--skip-data] [--force-data] [--wipe-data] [--no-editable]" + echo "Env: PRIMA_PYTHON=python3.10 PRIMA_VENV=.venv" + exit 0 + ;; + *) + echo "Unknown option: $1" >&2 + exit 1 + ;; + esac +done + +resolve_python() { + if [[ -n "${PRIMA_PYTHON:-}" ]]; then + if [[ -x "${PRIMA_PYTHON}" ]] || command -v "${PRIMA_PYTHON}" >/dev/null 2>&1; then + echo "${PRIMA_PYTHON}" + return 0 + fi + echo "[clean-install] ERROR: PRIMA_PYTHON=${PRIMA_PYTHON} is not executable." >&2 + return 1 + fi + local c p + for c in python3.12 python3.11 python3.10; do + if command -v "$c" >/dev/null 2>&1; then + if "$c" -c 'import sys; raise SystemExit(0 if sys.version_info >= (3, 10) else 1)'; then + command -v "$c" + return 0 + fi + fi + done + for p in /opt/homebrew/bin/python3.10 /usr/local/bin/python3.10; do + if [[ -x "$p" ]]; then + echo "$p" + return 0 + fi + done + return 1 +} + +echo "[clean-install] Repository: ${ROOT}" + +if ! PY="$(resolve_python)"; then + echo "[clean-install] ERROR: Need Python 3.10 or newer (Gradio 5 + app type hints)." >&2 + echo " macOS: brew install python@3.10" >&2 + echo " Then: PRIMA_PYTHON=/opt/homebrew/bin/python3.10 $0 ..." >&2 + exit 1 +fi +echo "[clean-install] Using Python: $("$PY" -c 'import sys; print(sys.executable, sys.version.split()[0])')" + +if command -v git-lfs >/dev/null 2>&1; then + echo "[clean-install] git lfs pull (demo images / teaser) ..." + git lfs install + git lfs pull +else + echo "[clean-install] WARN: git-lfs not found; demo images may be LFS pointer stubs. Install: brew install git-lfs && git lfs install" >&2 +fi + +if [[ -d "$VENV" ]]; then + echo "[clean-install] Removing existing venv: ${VENV}" + rm -rf "$VENV" +fi + +echo "[clean-install] Creating venv: ${VENV}" +"$PY" -m venv "$VENV" +# shellcheck disable=SC1090 +source "${VENV}/bin/activate" + +python -m pip install --no-input -U pip wheel +# Match requirements.txt / pyproject pins before pulling the rest +python -m pip install --no-input "setuptools<81" "packaging<25" "Cython<3" + +echo "[clean-install] pip install -r requirements.txt (this can take a long time) ..." +REQ_TMP="$(mktemp)" +grep -vE '^[[:space:]]*(deeplabcut|detectron2)' "${ROOT}/requirements.txt" > "${REQ_TMP}" +python -m pip install --no-input -r "${REQ_TMP}" +rm -f "${REQ_TMP}" + +if [[ "$(uname -s)" == "Darwin" ]]; then + echo "[clean-install] macOS: PyTables wheel then DeepLabCut 3.x (SuperAnimal pytorch API) ..." + python -m pip install --no-input "tables>=3.9.2,<3.11" + python -m pip install --no-input "deeplabcut==3.0.0rc14" || { + echo "[clean-install] ERROR: deeplabcut install failed. Try: brew install hdf5 && retry." >&2 + exit 1 + } +else + python -m pip install --no-input "deeplabcut==3.0.0rc14" +fi + +echo "[clean-install] Detectron2 (needs torch in venv; --no-build-isolation) ..." +python -m pip install --no-input --no-build-isolation \ + "detectron2 @ git+https://github.com/facebookresearch/detectron2.git" + +# Spaces install Gradio separately; local venv needs it for app.py. +echo "[clean-install] Installing Gradio for local demo (HF Space provides its own) ..." +python -m pip install --no-input "gradio>=5.1,<7" + +if [[ "$EDITABLE" -eq 1 ]]; then + echo "[clean-install] pip install --no-deps -e . (register package; runtime deps from requirements.txt) ..." + python -m pip install --no-input --no-deps -e "${ROOT}" +fi + +if [[ "$WIPE_DATA" -eq 1 ]]; then + echo "[clean-install] Wiping downloaded demo data under data/ ..." + rm -rf "${ROOT}/data/PRIMAS1" "${ROOT}/data/PRIMAS3" "${ROOT}/data/smal" "${ROOT}/data/amr_vitbb.pth" 2>/dev/null || true +fi + +if [[ "$SKIP_DATA" -eq 0 ]]; then + FORCE_ARGS=() + if [[ "$FORCE_DATA" -eq 1 ]]; then + FORCE_ARGS=(--force) + fi + echo "[clean-install] Downloading demo assets (large) ..." + python "${ROOT}/scripts/setup_demo_data.py" "${FORCE_ARGS[@]}" +else + echo "[clean-install] Skipping setup_demo_data (--skip-data)." +fi + +export PYTHONPATH="${ROOT}${PYTHONPATH:+:${PYTHONPATH}}" + +echo "[clean-install] Smoke test: import app + build_demo + DeepLabCut API ..." +python -c " +import app +app.get_demo_profile.cache_clear() +p = app.get_demo_profile() +print('[clean-install] demo profile:', p.mode) +app.build_demo() +print('[clean-install] DeepLabCut SuperAnimal (may take ~30s on first import) ...') +from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images # noqa: F401 +print('[clean-install] Gradio demo build + DeepLabCut 3.x: OK') +" + +echo "[clean-install] Done. Activate with: source ${VENV}/bin/activate" diff --git a/scripts/clean_redeploy_hf_space.sh b/scripts/clean_redeploy_hf_space.sh new file mode 100755 index 0000000000000000000000000000000000000000..1192a4f158f4fb13f8d5dfdabfe5bcd1d98bff12 --- /dev/null +++ b/scripts/clean_redeploy_hf_space.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash +# Clean redeploy of the Hugging Face Space from the current working tree. +# Same as scripts/deploy_hf_space.sh; use after a local clean install or any code change. +set -euo pipefail +ROOT="$(git rev-parse --show-toplevel)" +exec "${ROOT}/scripts/deploy_hf_space.sh" diff --git a/scripts/deploy_hf_space.sh b/scripts/deploy_hf_space.sh new file mode 100755 index 0000000000000000000000000000000000000000..aeed09a80ddb3eeaa2310cccfb607bf02c3bed38 --- /dev/null +++ b/scripts/deploy_hf_space.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash +# Deploy working tree to Hugging Face Space MLAdaptiveIntelligence/PRIMA-demo. +# +# Demo PNG/JPG are tracked with Git LFS (Hugging Face Hub Xet bridge); see .gitattributes. +# We rsync the working tree (not ``git archive``) so LFS-tracked files are real bytes here, +# then ``git add`` stores them as LFS objects on push. +# +# Prerequisites: brew install git-lfs git-xet && git xet install && git lfs install +set -euo pipefail + +export GIT_TERMINAL_PROMPT=0 + +ROOT="$(git rev-parse --show-toplevel)" +cd "$ROOT" +SPACE_URL="${HF_SPACE_GIT_URL:-https://huggingface.co/spaces/MLAdaptiveIntelligence/PRIMA-demo.git}" + +if ! command -v git-lfs >/dev/null 2>&1; then + echo "[deploy] ERROR: git-lfs is required. Install: brew install git-lfs && git lfs install" >&2 + exit 1 +fi + +TMP="$(mktemp -d)" +cleanup() { rm -rf "$TMP"; } +trap cleanup EXIT + +echo "[deploy] Rsync working tree from ${ROOT} ..." +rsync -a \ + --exclude=".git/" \ + --exclude="__pycache__/" \ + --exclude="*.pyc" \ + --exclude=".DS_Store" \ + --exclude="data/" \ + --exclude=".venv/" \ + --exclude="venv/" \ + --exclude=".pytest_cache/" \ + --exclude=".tmp_gradio_info.json" \ + --exclude="demo_out_tta_gradio/" \ + --exclude=".gradio/" \ + "${ROOT}/" "${TMP}/" + +cd "$TMP" + +echo "[deploy] Git init + LFS commit ..." +git init -q +git lfs install +git add -A +git -c user.email="space-deploy@users.noreply.github.com" -c user.name="HF Space deploy" commit -q -m "Deploy snapshot (LFS for demo images per .gitattributes)" + +git remote add hf "$SPACE_URL" +echo "[deploy] Force-pushing to Hugging Face Space ..." +GIT_TERMINAL_PROMPT=0 git -c credential.helper=osxkeychain push hf HEAD:main --force +echo "[deploy] Done." diff --git a/scripts/local_infer.py b/scripts/local_infer.py new file mode 100755 index 0000000000000000000000000000000000000000..1de893bf70d738420209170e8ecb75ea121773e0 --- /dev/null +++ b/scripts/local_infer.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +import cv2 + + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Local PRIMA inference (no Gradio).") + p.add_argument( + "--image", + type=str, + default=str(ROOT / "demo_data" / "beagle.jpg"), + help="Path to an input image.", + ) + p.add_argument( + "--out", + type=str, + default=str(ROOT / "demo_out_local_cli"), + help="Output folder for PNG renders / artifacts.", + ) + p.add_argument("--tta_lr", type=float, default=1e-6) + p.add_argument("--tta_iters", type=int, default=0, help="0 disables TTA.") + p.add_argument("--det_thresh", type=float, default=0.7) + p.add_argument("--kp_conf_thresh", type=float, default=0.1) + p.add_argument("--side_view", action="store_true") + p.add_argument("--save_mesh", action="store_true") + return p.parse_args() + + +def main() -> int: + # Ensure local defaults (GPU if available) but no Space-only preload behavior. + os.environ.setdefault("PRIMA_DEMO_MODE", "local") + os.environ.setdefault("PRIMA_PRELOAD_ASSETS", "0") + + import numpy as np # noqa: E402 + + import app # noqa: E402 + + args = parse_args() + out_dir = Path(args.out) + out_dir.mkdir(parents=True, exist_ok=True) + + img_path = Path(args.image) + if not img_path.is_file(): + raise FileNotFoundError(f"Missing image: {img_path}") + + img_bgr = cv2.imread(str(img_path)) + if img_bgr is None: + raise RuntimeError(f"Failed to read image: {img_path}") + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.uint8) + + print("[local_infer] Loading PRIMA model ...") + model, model_cfg, renderer, device = app._load_prima_model() + print(f"[local_infer] device={device}") + + print("[local_infer] Building detector (Detectron2 if installed, else fallback) ...") + detector = app._build_detector() + print(f"[local_infer] detector={'detectron2' if detector is not None else 'fallback'}") + + print("[local_infer] Running inference ...") + before, after, kpts, mesh_before, mesh_after = app._collect_animal_results( + model, + model_cfg, + renderer, + device, + detector, + str(out_dir), + img_rgb, + tta_lr=float(args.tta_lr), + tta_num_iters=int(args.tta_iters), + det_thresh=float(args.det_thresh), + kp_conf_thresh=float(args.kp_conf_thresh), + side_view=bool(args.side_view), + save_mesh=bool(args.save_mesh), + ) + + print(f"[local_infer] renders: before={len(before)} after={len(after)} kpts={len(kpts)}") + if mesh_before or mesh_after: + print(f"[local_infer] meshes: before={mesh_before} after={mesh_after}") + + pngs = sorted(out_dir.glob("*.png")) + for p in pngs: + print(f"[local_infer] output: {p}") + + if not pngs: + raise RuntimeError("No PNG outputs produced.") + + print("[local_infer] OK") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) + diff --git a/scripts/run_local_demo_once.py b/scripts/run_local_demo_once.py new file mode 100644 index 0000000000000000000000000000000000000000..828edaadb79e47f9f3f95dd36ffd287aee5e4c22 --- /dev/null +++ b/scripts/run_local_demo_once.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license + +One-shot local smoke: load PRIMA, run beagle demo (TTA off), print paths to outputs. +""" + +from __future__ import annotations + +import os +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT)) +os.environ.setdefault("PRIMA_PRELOAD_ASSETS", "0") + +import cv2 # noqa: E402 + +import app # noqa: E402 + + +def main() -> int: + out_dir = ROOT / "demo_out_tta_gradio_local_proof" + out_dir.mkdir(parents=True, exist_ok=True) + img_path = ROOT / "demo_data" / "beagle.jpg" + if not img_path.is_file(): + print(f"ERROR: missing {img_path}") + return 1 + + print("[1/4] Loading PRIMA checkpoint …") + model, cfg, renderer, device = app._load_prima_model() + print(f" device={device}") + + print("[2/4] Building detector (Detectron2 if installed, else full-image bbox) …") + det = app._build_detector() + print(f" detector={'detectron2' if det is not None else 'full-image fallback'}") + + img = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB) + print(f"[3/4] Running inference on {img_path.name} (TTA iterations=0) …") + before, after, kpts, _, _ = app._collect_animal_results( + model, + cfg, + renderer, + device, + det, + str(out_dir), + img, + 1e-6, + 0, + 0.7, + 0.1, + False, + False, + ) + print(f" renders: before={len(before)} after={len(after)} kpts={len(kpts)}") + + pngs = sorted(out_dir.glob("*.png")) + print("[4/4] Output files:") + for p in pngs: + print(f" {p}") + + if not pngs: + print("FAIL: no PNG outputs (often pyrender/display on headless macOS).") + return 1 + + print("OK: local demo produced outputs.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/setup_demo_data.py b/scripts/setup_demo_data.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca3659668355bea3f0511d3a944816493286f58 --- /dev/null +++ b/scripts/setup_demo_data.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" +# Download and arrange PRIMA demo assets into the expected data/ layout. +# Usage: +# python scripts/setup_demo_data.py +# python scripts/setup_demo_data.py --include-stage3 +# python scripts/setup_demo_data.py --force + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +_REPO_ROOT = Path(__file__).resolve().parent.parent +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from prima.utils.weights import ( + DEFAULT_HF_REPO_ID, + ensure_demo_assets, +) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Download PRIMA demo checkpoints and data") + parser.add_argument("--data-dir", type=Path, default=Path("data"), help="Target data directory") + parser.add_argument("--force", action="store_true", help="Redownload and overwrite existing files") + parser.add_argument( + "--include-stage3", + action="store_true", + help="Also prefetch the Stage 3 checkpoint and config", + ) + parser.add_argument( + "--hf-repo-id", + type=str, + default=DEFAULT_HF_REPO_ID, + help="Hugging Face repo ID containing demo assets (e.g., org/repo)", + ) + args = parser.parse_args() + stages = ("PRIMAS1", "PRIMAS3") if args.include_stage3 else ("PRIMAS1",) + ensure_demo_assets( + args.data_dir, + stages=stages, + force=args.force, + hf_repo_id=args.hf_repo_id, + ) + + print("\n[done] Demo assets ready.") + print("Run demo:") + print(" python demo.py --img_folder demo_data/ --out_folder demo_out/") + print("Run demo with TTA:") + print(" python demo_tta.py --img_folder demo_data/ --out_folder demo_out_tta/ --tta_lr 1e-6 --tta_num_iters 30") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/test_local_full.sh b/scripts/test_local_full.sh new file mode 100755 index 0000000000000000000000000000000000000000..9bb90fbb24f0ab422b5a8aff972b59f2e34aaa23 --- /dev/null +++ b/scripts/test_local_full.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# Full local smoke: local CLI inference (no Gradio). +set -euo pipefail + +export GIT_TERMINAL_PROMPT=0 +ROOT="$(git rev-parse --show-toplevel)" +cd "$ROOT" +VENV="${PRIMA_VENV:-.venv}" + +if [[ ! -x "${VENV}/bin/python" ]]; then + echo "ERROR: missing ${VENV}. Run: ./scripts/clean_install_local.sh --skip-data" >&2 + exit 1 +fi + +# shellcheck disable=SC1090 +source "${VENV}/bin/activate" +export PYTHONPATH="${ROOT}${PYTHONPATH:+:${PYTHONPATH}}" +export PRIMA_PRELOAD_ASSETS=0 +export PRIMA_DEMO_MODE=local + +echo "=== [1/3] Demo profile (local) ===" +python -c "import app; app.get_demo_profile.cache_clear(); p=app.get_demo_profile(); print('profile:', p.mode)" + +echo "=== [2/3] DeepLabCut SuperAnimal API ===" +python -c " +from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images # noqa: F401 +print('DeepLabCut SuperAnimal: OK') +" + +echo "=== [3/3] PRIMA local CLI inference (beagle, TTA off) ===" +python "${ROOT}/scripts/local_infer.py" --tta_iters 0 2>&1 + +echo "=== All local checks passed ===" diff --git a/scripts/update_headers.py b/scripts/update_headers.py new file mode 100644 index 0000000000000000000000000000000000000000..96b8b16de5a62fa8a39950b7f42ac7684133d9c5 --- /dev/null +++ b/scripts/update_headers.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import os +import sys +from pathlib import Path + +# Define the standard header for the project +STANDARD_HEADER = '''""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +"""''' + +# Old headers that should be replaced +OLD_HEADERS = [ + '''""" + +"""''' +] + + +def should_skip_file(file_path): + """ + Determine if a file should be skipped for header addition. + + Args: + file_path: Path to check + + Returns: + True if the file should be skipped, False otherwise + """ + skip_dirs = {'.git', '__pycache__', '.pytest_cache', 'venv', '.venv', 'env', '.tox', 'build', 'dist', '.eggs', 'site-packages'} + + # Skip if in excluded directory + for part in file_path.parts: + if part in skip_dirs: + return True + + return False + + +def has_header(content): + """ + Check if content already has the standard header or a valid variant. + + Args: + content: File content to check + + Returns: + True if the file has the standard header or acceptable variant, False otherwise + """ + # Check for exact match + if STANDARD_HEADER.strip() in content: + return True + + required_elements = [ + 'PRIMA: Boosting Animal Mesh Recovery with Biological Priors', + 'Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis', + 'Licensed under a modified MIT license', + ] + + return all(elem in content for elem in required_elements) + + +def needs_header_update(content): + """ + Check if content has an old header that needs updating. + + Args: + content: File content to check + + Returns: + Old header if found, None otherwise + """ + for old_header in OLD_HEADERS: + if old_header.strip() in content: + return old_header + return None + + +def add_or_update_header(file_path, check_only=False): + """ + Add or update the header in a single file. + + Args: + file_path: Path to the file to update + check_only: If True, only check without modifying + + Returns: + Tuple of (status, message) where status is 'ok', 'updated', 'added', or 'error' + """ + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + # Check if file already has the correct header + if has_header(content): + return ('ok', 'Already has correct header') + + # Check if file has an old header that needs replacing + old_header = needs_header_update(content) + if old_header: + if not check_only: + new_content = content.replace(old_header, STANDARD_HEADER) + with open(file_path, 'w', encoding='utf-8') as f: + f.write(new_content) + return ('updated', 'Replaced old header with standard header') + + # File has no header, add one + # Skip adding header to files that start with shebang or are very short + lines = content.split('\n') + if content.strip() and len(content.strip()) > 10: + if not check_only: + # Handle special cases for header placement + new_lines = [] + insert_index = 0 + + # If file starts with shebang, keep it at the top + if lines[0].startswith('#!'): + new_lines.append(lines[0]) + insert_index = 1 + + # Check for 'from __future__' imports which must be very early + # Find the first non-comment, non-shebang, non-empty line + future_import_index = None + for i in range(insert_index, min(len(lines), 10)): + line = lines[i].strip() + if line.startswith('from __future__'): + future_import_index = i + break + elif line and not line.startswith('#'): + # Found a non-comment line that isn't a future import + break + + if future_import_index is not None: + # ``from __future__`` must stay first; PRIMA header docstring follows it. + new_lines.extend(lines[insert_index:future_import_index + 1]) + if STANDARD_HEADER.strip() not in content: + new_lines.append(STANDARD_HEADER) + new_lines.append('') + new_lines.extend(lines[future_import_index + 1:]) + else: + # Otherwise, add header at the beginning (after shebang if present) + new_lines.append(STANDARD_HEADER) + new_lines.append('') + new_lines.extend(lines[insert_index:]) + + new_content = '\n'.join(new_lines) + + with open(file_path, 'w', encoding='utf-8') as f: + f.write(new_content) + return ('added', 'Added standard header') + + return ('ok', 'Skipped (file too short or empty)') + + except Exception as e: + return ('error', f"Error processing file: {e}") + + +def find_and_process_headers(root_dir, check_only=False): + """ + Find and process all Python files. + + Args: + root_dir: Root directory to search from + check_only: If True, only check without modifying files + + Returns: + Dictionary with statistics about processed files + """ + root_path = Path(root_dir) + stats = { + 'ok': [], + 'updated': [], + 'added': [], + 'error': [] + } + + # Find all Python files + for py_file in root_path.rglob('*.py'): + # Skip files that should not be processed + if should_skip_file(py_file): + continue + + status, message = add_or_update_header(py_file, check_only) + stats[status].append((py_file, message)) + + if status in ['updated', 'added']: + rel_path = py_file.relative_to(root_path) + print(f"{'[CHECK]' if check_only else '✓'} {rel_path}: {message}") + elif status == 'error': + rel_path = py_file.relative_to(root_path) + print(f"✗ {rel_path}: {message}") + + return stats + + +def main(): + """Main function to run the header update script.""" + check_only = '--check' in sys.argv + + if len(sys.argv) > 1 and not sys.argv[1].startswith('--'): + root_dir = Path(sys.argv[1]) + else: + root_dir = Path(os.getcwd()) + + mode = "Checking" if check_only else "Processing" + print(f"{mode} files for headers in: {root_dir}") + print("-" * 60) + + stats = find_and_process_headers(root_dir, check_only) + + print("-" * 60) + + # Print summary + total_changes = len(stats['updated']) + len(stats['added']) + + if check_only: + if total_changes > 0: + print(f"\n⚠ Found {total_changes} file(s) needing header updates:") + for file_path, msg in stats['updated']: + print(f" - {file_path.relative_to(root_dir)}: {msg}") + for file_path, msg in stats['added']: + print(f" - {file_path.relative_to(root_dir)}: {msg}") + return 1 + else: + print("\n✓ All Python files have correct headers!") + return 0 + else: + if total_changes > 0: + print(f"\n✓ Successfully processed {total_changes} file(s):") + if stats['updated']: + print(f" - Updated: {len(stats['updated'])} file(s)") + if stats['added']: + print(f" - Added headers: {len(stats['added'])} file(s)") + else: + print("\n✓ No files needed header updates.") + + if stats['error']: + print(f"\n✗ Errors: {len(stats['error'])} file(s)") + for file_path, msg in stats['error']: + print(f" - {file_path.relative_to(root_dir)}: {msg}") + return 1 + + return 0 + + +if __name__ == '__main__': + sys.exit(main()) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d9bcdf70159c3a02894e7fd6d6f700fa9cd374d8 --- /dev/null +++ b/train.py @@ -0,0 +1,152 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from typing import Optional +import pyrootutils + + +root = pyrootutils.setup_root( + search_from=__file__, + indicator=[".git", "pyproject.toml"], + pythonpath=True, + dotenv=True, +) + +import os +import sys + +import hydra +import pytorch_lightning as pl +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.plugins.environments import SLURMEnvironment +from pytorch_lightning.callbacks import TQDMProgressBar +from tqdm import tqdm +from prima.datasets import DataModule +from prima.models.prima import PRIMA +from prima.utils.pylogger import get_pylogger +from prima.utils.misc import log_hyperparameters +import signal + +signal.signal(signal.SIGUSR1, signal.SIG_DFL) + + +class MyTQDMProgressBar(TQDMProgressBar): + + def __init__(self): + super(MyTQDMProgressBar, self).__init__() + + def init_train_tqdm(self): + bar = super().init_train_tqdm() + bar.ncols = 150 + bar.dynamic_ncols=False + return bar + + def init_validation_tqdm(self): + bar = tqdm( + desc=self.validation_description, + position=0, + disable=self.is_disabled, + leave=True, + # dynamic_ncols=True, + file=sys.stdout, + dynamic_ncols= False, + ncols = 150, + ) + return bar + + +@hydra.main(version_base="1.2", config_path= "./configs_hydra", config_name="train.yaml") + +def main(cfg: DictConfig) -> Optional[float]: + datamodule = DataModule(cfg) + model = PRIMA(cfg) + + # Setup Tensorboard logger + logger = TensorBoardLogger(os.path.join(cfg.paths.output_dir, 'tensorboard'), name='', version='', + default_hp_metric=False) + loggers = [logger] + + # Setup checkpoint saving + checkpoint_callback = pl.callbacks.ModelCheckpoint( + dirpath=os.path.join(cfg.paths.output_dir, 'checkpoints'), + # every_n_train_steps=cfg.GENERAL.CHECKPOINT_STEPS, + every_n_epochs=cfg.GENERAL.CHECKPOINT_EPOCHS, + save_last=True, + # Monitor a metric so `save_top_k` keeps the best checkpoint instead of the last one. + # We monitor the validation loss logged as 'val/loss' (lower is better). + monitor='val/loss', + mode='min', + save_top_k=cfg.GENERAL.CHECKPOINT_SAVE_TOP_K, + filename="best-{epoch:03d}-{val_loss:.4f}", # Clearly label the best checkpoint + ) + + lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step') + callbacks = [ + checkpoint_callback, + lr_monitor, + # rich_callback + MyTQDMProgressBar() + ] + + log = get_pylogger(__name__) + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + cfg.trainer, + callbacks=callbacks, + logger=loggers, + plugins=(SLURMEnvironment(requeue_signal=signal.SIGUSR2) if (cfg.get('launcher', None) is not None) else None), + sync_batchnorm=True, + ) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + log_hyperparameters(object_dict) + + # Train the model + # Determine checkpoint path + ckpt_path = None + last_v1_ckpt = os.path.join(cfg.paths.output_dir, 'checkpoints', 'last-v1.ckpt') + last_ckpt = os.path.join(cfg.paths.output_dir, 'checkpoints', 'last.ckpt') + + if os.path.exists(last_v1_ckpt): + ckpt_path = last_v1_ckpt + log.info(f"Resuming from checkpoint: {ckpt_path}") + elif os.path.exists(last_ckpt): + ckpt_path = last_ckpt + log.info(f"Resuming from checkpoint: {ckpt_path}") + else: + log.info("No checkpoint found, starting from scratch") + + trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info("Fitting done") + + +if __name__ == "__main__": + import torch + import gc + + gc.collect() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + for i in range(torch.cuda.device_count()): + print(f"GPU {i}: {torch.cuda.memory_allocated(i)/1024**2:.2f} MiB allocated, " + f"{torch.cuda.memory_reserved(i)/1024**2:.2f} MiB reserved") + main() diff --git a/train.sh b/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..3c52fb4329dc6047efb4c3bc6e0b21f7d99d0e2a --- /dev/null +++ b/train.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -e + +exp_name_stage1=primaStage1 +exp_name_stage2=primaStage2 +experiment1=primaStage1 +experiment2=primaStage2 +HYDRA_FULL_ERROR=1 python train.py exp_name=$exp_name_stage1 experiment=$experiment1 trainer=gpu launcher=local +cp -r ./logs/train/runs/$exp_name_stage1 ./logs/train/runs/$exp_name_stage2 +HYDRA_FULL_ERROR=1 python train.py exp_name=$exp_name_stage2 experiment=$experiment2 trainer=gpu launcher=local