diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c9aeeafbf06c1e9208bc88aa7befba9918e8765e --- /dev/null +++ b/.gitattributes @@ -0,0 +1,5 @@ +# Hugging Face Hub stores these via Git LFS / Xet (plain PNG/JPG in git are rejected on push). +demo_data/*.png filter=lfs diff=lfs merge=lfs -text +demo_data/*.jpg filter=lfs diff=lfs merge=lfs -text +demo_data/*.jpeg filter=lfs diff=lfs merge=lfs -text +images/*.png filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c1e616422f53bc9d7b1b2b86b7adc53bece27fe0 --- /dev/null +++ b/README.md @@ -0,0 +1,251 @@ +--- +title: PRIMA Demo +emoji: 🦮 +colorFrom: blue +colorTo: green +sdk: gradio +python_version: "3.10" +app_file: app.py +startup_duration_timeout: 60m +--- + +# PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + + +This is the official implementation of the approach described in the preprint: + +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation \ +Xiaohang Yu, Ti Wang, Mackenzie Weygandt Mathis + +![PRIMA teaser](images/teaser.png) + + +--- + + +## 🚀 TL;DR +PRIMA creates a 3D quadruped mesh from a single 2D image. It leverages BioCLIP-based biological priors for robust cross-species shape understanding, then applies test-time adaptation with 2D reprojection and auxiliary keypoint guidance to refine SMAL pose and shape predictions. + +It can also be used to build Quadruped3D, a large-scale pseudo-3D dataset with diverse species and poses. + +PRIMA achieves state-of-the-art results on the Animal3D, CtrlAni3D, Quadruped2D, and Animal Kingdom datasets. + +## Installation + +### Install from PyPI + +> Recommended: Python 3.10 and a CUDA-enabled PyTorch installation.
+ +```bash +conda create -n prima python=3.10 -y +conda activate prima + +# Install PyTorch matching your CUDA (example: CUDA 11.8) +pip install --index-url https://download.pytorch.org/whl/cu118 \ + "torch==2.2.1" "torchvision==0.17.1" "torchaudio==2.2.1" + +# Install chumpy and PyTorch3D +python -m pip install --no-build-isolation \ + "git+https://github.com/mattloper/chumpy.git" +python -m pip install --no-build-isolation \ + "git+https://github.com/facebookresearch/pytorch3d.git" + +# Install PRIMA from PyPI +pip install prima-animal +``` + +`prima-animal` includes demo runtime dependencies used by `demo.py`, `demo_tta.py`, and `app.py` (including Detectron2 and DeepLabCut). + +### Clean install from this repository + +Use these when developing from a **git clone** (not the PyPI wheel). The shell scripts are **non-interactive** (pip uses `--no-input`; `GIT_TERMINAL_PROMPT=0` for git). Put Hugging Face credentials in your environment or git credential helper before pushing the Space. + +**Local (fresh venv, LFS assets, Hub demo weights, smoke test)** — requires **Python 3.10+** +(Gradio 5.1+ / Space-provided Gradio 6.x and `app.py` type hints). On macOS without `python3.10` on your `PATH`, install +`brew install python@3.10` and set `PRIMA_PYTHON=/opt/homebrew/bin/python3.10`. + +```bash +chmod +x scripts/clean_install_local.sh scripts/clean_redeploy_hf_space.sh scripts/deploy_hf_space.sh +PRIMA_PYTHON=/opt/homebrew/bin/python3.10 ./scripts/clean_install_local.sh +``` + +Options: + +- `PRIMA_VENV=.venv ./scripts/clean_install_local.sh --skip-data` — skip the large `setup_demo_data` download if `data/` is already populated. +- `./scripts/clean_install_local.sh --wipe-data --force-data` — delete downloaded `data/` assets and redownload. +- `./scripts/clean_install_local.sh --no-editable` — only `requirements.txt` (no `pip install -e .`); use if editable install fails and you will install the training stack via conda as in the PyPI section above. 
You still need **Python 3.10+** for Gradio 5.1+. The smoke test sets `PYTHONPATH` to the repo root so `import prima` works without an editable install. +- **macOS:** the script omits the `deeplabcut` line from `pip install` because DeepLabCut’s pinned PyTables version often does not build on Apple Silicon. Use conda/mamba for DeepLabCut if you need SuperAnimal + TTA (`tta_num_iters` > 0). **Linux** (including Hugging Face Space builds) uses the full `requirements.txt` including `deeplabcut`. + +After `requirements.txt`, the script runs **`pip install --no-deps -e .`** so the `prima` package is registered without re-resolving `pyproject.toml` (which would pull **Detectron2** and **DeepLabCut** again and often fail on macOS). Full `pip install -e .` is still recommended from a **conda** environment per the PyPI section if you need every training extra matched exactly. + +**Hugging Face Space (full redeploy from your working tree):** + +Requires [Git LFS / Xet](https://huggingface.co/docs/hub/xet/using-xet-storage#git) tooling (`brew install git-lfs git-xet`, `git xet install`, `git lfs install`). Then: + +```bash +./scripts/clean_redeploy_hf_space.sh +``` + +This is equivalent to `./scripts/deploy_hf_space.sh` and force-pushes a fresh snapshot to the Space. + +--- + +## Demo + +### Checkpoints and data + +The demo scripts auto-download their default Stage 1 PRIMA assets from Hugging +Face when the checkpoint or matching Hydra config is missing. If you want to +pre-download all necessary checkpoints and data ahead of time, run: + +```bash +python scripts/setup_demo_data.py --hf-repo-id MLAdaptiveIntelligence/PRIMA +``` + +Approximate default prefetch volume from Hugging Face is ~5.5 GB total +(`s1ckpt_inference.ckpt` ~3 GB + `amr_vitbb.pth` ~2.5 GB + SMAL files). +Expected time is roughly: +- 100 Mbps: ~7-10 minutes +- 300 Mbps: ~2-4 minutes +- 1 Gbps: ~1 minute + +Existing files are reused by default; pass `--force` only if you need to redownload them. 
If you also need the Stage 3 pretrained model, add `--include-stage3`. + +Expected files in that Hugging Face repo root: +- `my_smpl_00781_4_all.pkl` +- `my_smpl_data_00781_4_all.pkl` +- `walking_toy_symmetric_pose_prior_with_cov_35parts.pkl` +- `amr_vitbb.pth` +- `config_s1_HYDRA.yaml` +- `s1ckpt_inference.ckpt` + +Optional Stage 3 prefetch expects: +- `config_s3_HYDRA.yaml` +- `s3ckpt_inference.ckpt` + +### Demo (without TTA) + +Run animal detection + PRIMA 3D pose/shape inference: + +```bash +bash demo.sh +``` + +Outputs are written to `demo_out/`. Edit `demo.sh` if you want to use a custom +checkpoint path. + +--- + +### Demo (with TTA) + +Run PRIMA inference with test-time adaptation: + +```bash +bash demo_tta.sh +``` + +Outputs are written to `demo_out_tta/` (before/after TTA renders, keypoints, and +optional meshes). Edit `demo_tta.sh` if you want to change the checkpoint, TTA +learning rate, or number of iterations. + +--- + +### Gradio demo + +We also provide a simple Gradio-based web demo for interactive testing in the +browser: + +```bash +python app.py \ + --checkpoint data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt \ + --out_folder demo_out_tta_gradio/ +``` + +This starts a local Gradio app (by default on http://127.0.0.1:7860), where +you can upload images and visualize PRIMA predictions and adaptation results. +The `s1ckpt_inference.ckpt` checkpoint is downloaded automatically if missing. + +#### Hugging Face Space (maintainers) + +Demo images under `demo_data/` and `images/teaser.png` are tracked with **Git LFS** +(see `.gitattributes`) so they can be pushed to a Hugging Face Space under the Hub’s +LFS / **Xet** bridge. 
Install tooling once: + +```bash +brew install git-lfs git-xet +git xet install +git lfs install +``` + +Then from a clean checkout with LFS files present, redeploy the Space (same as `clean_redeploy_hf_space.sh`): + +```bash +./scripts/deploy_hf_space.sh +# or +./scripts/clean_redeploy_hf_space.sh +``` + +The script rsyncs the working tree (not `git archive`) so image files are materialized +before `git add` turns them into LFS blobs. + +--- + + +## Training and Evaluation + +### Dataset Setup + +Download datasets from [Animal3D](https://xujiacong.github.io/Animal3D/), [CtrlAni3D](https://github.com/luoxue-star/AniMer?tab=readme-ov-file#training), Quadruped2D, and [Animal Kingdom](https://drive.google.com/file/d/1dk2a0qB0fbVZ4X6eAgP6VJVXj0rxVfsJ/view?usp=drive_link). For Quadruped2D, download the images from [SuperAnimal-Quadruped80K](https://zenodo.org/records/14016777) and our processed annotations from [here](https://drive.google.com/drive/folders/1eBNboxVwl_eGPoC93zxf-U3hmE6e2f-f?usp=sharing). Put all the datasets under `datasets/`. + +### Training + +Run the two-stage training script: + +```bash +bash train.sh +``` + +Training outputs are written to `logs/train/runs/{run_name}/`. + + +### Evaluation + +```bash +python eval.py \ + --config data/PRIMAS1/.hydra/config.yaml \ + --checkpoint data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt +``` + +The available values for `--dataset` are defined in: +- `configs_hydra/experiment/default_val.yaml` + +--- + + +## Acknowledgements + +This release builds on several open-source projects, including: +- [Detectron2](https://github.com/facebookresearch/detectron2) +- [BioCLIP](https://github.com/Imageomics/BioCLIP) +- [AniMer](https://github.com/luoxue-star/AniMer) +- [DeepLabCut](https://github.com/DeepLabCut/DeepLabCut) +- [SAM3DB](https://github.com/facebookresearch/sam-3d-body) + +--- + +## Citation + +If you use this code in your research, please cite our PRIMA paper.
+ +```bibtex +@misc{yu_prima, + title={PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation}, + author={Xiaohang Yu and Ti Wang and Mackenzie Weygandt Mathis}, +} +``` + +--- + +## Contact + +For issues, please open a GitHub issue in this repository. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d41b252ea27408a81f596b4bc568662135521378 --- /dev/null +++ b/app.py @@ -0,0 +1,660 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +"""Gradio demo for PRIMA + SuperAnimal + TTA. + +This script wraps the ``demo_tta.py`` pipeline into an interactive +Gradio interface. The overall logic follows: + +1. Given an input image, run Detectron2 to detect animals. +2. For each detected animal, run PRIMA for 3D pose/shape estimation. +3. Run the fine-tuned DeepLabCut SuperAnimal model to obtain PRIMA 26-keypoint + 2D predictions. +4. Run test-time adaptation (TTA) with user-specified lr and iters. +5. Render and save before/after TTA results and keypoint visualizations. + +""" + +import argparse +import concurrent.futures +import os +import queue +import sys +import tempfile +import time +import traceback +from types import SimpleNamespace +from typing import Callable, List, Tuple +from pathlib import Path + +import cv2 +import gradio as gr +import numpy as np +import torch +import torch.utils.data + +# Repo-local minimal ``chumpy`` shim (see ``chumpy/__init__.py``) so SMAL pickles load +# without installing the full chumpy package in Space builds. 
_REPO_ROOT = Path(__file__).resolve().parent
if str(_REPO_ROOT) not in sys.path:
    # Make the repo root importable so the repo-local ``chumpy`` shim and the
    # ``prima`` package resolve without an editable install.
    sys.path.insert(0, str(_REPO_ROOT))

from prima.utils.weights import (
    DEFAULT_HF_REPO_ID,
    resolve_prima_checkpoint_path,
)
from prima.utils.detection import select_animal_boxes


# Default checkpoint path following README instructions
DEFAULT_CHECKPOINT = str(_REPO_ROOT / "data" / "PRIMAS1" / "checkpoints" / "s1ckpt_inference.ckpt")
DEFAULT_HF_ASSET_REPO = DEFAULT_HF_REPO_ID

# Output folder for rendered images/meshes and keypoints
DEFAULT_OUT_FOLDER = "demo_out_tta_gradio"

# Fallback heartbeat period (seconds) when PRIMA_GRADIO_HEARTBEAT_SEC is unusable.
_DEFAULT_HEARTBEAT_SEC = 25.0


def _is_truthy_env(var_name: str) -> bool:
    """Return True when env var ``var_name`` is set to 1/true/yes/on (case-insensitive)."""
    return os.environ.get(var_name, "").strip().lower() in {"1", "true", "yes", "on"}


def _running_on_space() -> bool:
    """Return True when ``SPACE_ID`` or ``HF_SPACE_ID`` is set in the environment."""
    return bool(os.environ.get("SPACE_ID") or os.environ.get("HF_SPACE_ID"))


def _gradio_examples_for_interface() -> List[List]:
    """Build the Gradio example rows, keeping only images that exist on disk.

    Gradio prefetches example media at startup. Demo images are tracked with
    Git LFS / Xet (see ``.gitattributes``), so a row is emitted (with an
    absolute path) only when the file exists beside ``app.py``. Setting
    ``PRIMA_DISABLE_GRADIO_EXAMPLES`` to a truthy value disables examples.

    Returns:
        Rows of ``[image_path, tta_lr, tta_num_iters, det_thresh,
        kp_conf_thresh, side_view, save_mesh]``; possibly empty.
    """
    if _is_truthy_env("PRIMA_DISABLE_GRADIO_EXAMPLES"):
        return []
    template: List[Tuple[str, float, int, float, float, bool, bool]] = [
        ("demo_data/000000015956_horse.png", 1e-6, 0, 0.7, 0.1, False, True),
        ("demo_data/n02412080_12159.png", 1e-6, 0, 0.7, 0.1, False, True),
        ("demo_data/000000315905_zebra.jpg", 1e-6, 0, 0.7, 0.1, False, True),
        ("demo_data/beagle.jpg", 1e-6, 0, 0.7, 0.1, False, True),
        ("demo_data/shepherd_hati.jpg", 1e-6, 0, 0.7, 0.1, False, True),
    ]
    return [
        [str(_REPO_ROOT / rel), *rest]
        for rel, *rest in template
        if (_REPO_ROOT / rel).is_file()
    ]


def _should_preload_assets() -> bool:
    """Default to preload on Spaces; configurable via PRIMA_PRELOAD_ASSETS."""
    if os.environ.get("PRIMA_PRELOAD_ASSETS") is not None:
        return _is_truthy_env("PRIMA_PRELOAD_ASSETS")
    return _running_on_space()


def _gradio_heartbeat_interval_sec() -> float:
    """How often to yield status while waiting on long CPU/GPU work (keeps WebSockets alive).

    Set ``PRIMA_GRADIO_HEARTBEAT_SEC`` to ``0`` to run long work on the Gradio thread (old behavior).

    Returns:
        A finite, non-negative number of seconds. Unparseable or non-finite
        values (``inf``, ``nan``) fall back to 25 seconds; negative values are
        clamped to ``0``.
    """
    import math

    raw = os.environ.get("PRIMA_GRADIO_HEARTBEAT_SEC", "25").strip()
    try:
        value = float(raw)
    except ValueError:
        return _DEFAULT_HEARTBEAT_SEC
    # ``float()`` accepts "inf"/"nan": nan would silently collapse to 0 through
    # ``max`` and inf would later break ``Future.result(timeout=...)`` / ``int()``.
    if not math.isfinite(value):
        return _DEFAULT_HEARTBEAT_SEC
    return max(0.0, value)


def _resolve_checkpoint(checkpoint_path: str):
    """Resolve ``checkpoint_path`` via ``resolve_prima_checkpoint_path`` with auto-download.

    The Hub repo id comes from ``PRIMA_HF_REPO_ID`` when set, otherwise
    ``DEFAULT_HF_ASSET_REPO``.
    """
    return resolve_prima_checkpoint_path(
        checkpoint_path,
        data_dir=_REPO_ROOT / "data",
        auto_download=True,
        hf_repo_id=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO),
    )


def _preload_assets_once(checkpoint_path: str) -> None:
    """Resolve (and download if needed) the demo checkpoint before the UI starts."""
    print("[startup] Ensuring demo assets from Hugging Face Hub...")
    _resolve_checkpoint(checkpoint_path)
    print("[startup] Asset preload complete.")


def _load_prima_model(checkpoint_path: str = DEFAULT_CHECKPOINT):
    """Load PRIMA model and renderer once for the Gradio app.

    Args:
        checkpoint_path: Checkpoint location; resolved (and auto-downloaded)
            through ``resolve_prima_checkpoint_path``.

    Returns:
        ``(model, model_cfg, renderer, cam_crop_to_full, device)``.

    Raises:
        FileNotFoundError: If the resolved checkpoint or its sibling
            ``.hydra/config.yaml`` is missing.
    """
    from prima.models import load_prima
    from prima.utils.renderer import Renderer, cam_crop_to_full

    checkpoint_path = _resolve_checkpoint(checkpoint_path)
    checkpoint = Path(checkpoint_path)
    cfg_path = checkpoint.parent.parent / ".hydra" / "config.yaml"
    if not checkpoint.exists():
        raise FileNotFoundError(
            f"Missing checkpoint: {checkpoint}. Download demo checkpoints/data as described in README."
        )
    if not cfg_path.exists():
        raise FileNotFoundError(
            f"Missing model config: {cfg_path}. Ensure the full checkpoint folder layout from README is present."
        )

    model, model_cfg = load_prima(checkpoint_path)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    model.eval()

    renderer = Renderer(model_cfg, faces=model.smal.faces)
    return model, model_cfg, renderer, cam_crop_to_full, device


def _build_detector():
    """Build Detectron2 animal detector (same config as demo_tta/demo.py).

    Returns:
        A ``DefaultPredictor``, or ``None`` when Detectron2 cannot be imported
        (callers then fall back to a full-image bounding box).
    """
    try:
        import detectron2.config
        import detectron2.engine
        from detectron2 import model_zoo
    except Exception as e:
        # Deliberate best-effort: any import-time failure degrades to the
        # full-image fallback instead of taking the demo down.
        print(f"[warn] Detectron2 unavailable ({type(e).__name__}: {e}); using full-image fallback bbox.")
        return None

    cfg = detectron2.config.get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml")
    )
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.WEIGHTS = (
        "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/"
        "faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
    )
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    return detectron2.engine.DefaultPredictor(cfg)


def _load_model_and_detector_for_demo(checkpoint_path: str):
    """Run on a worker thread when using heartbeat polling (single entry point for executor)."""
    model, model_cfg, renderer, cam_crop_to_full_fn, device = _load_prima_model(checkpoint_path)
    detector = _build_detector()
    return model, model_cfg, renderer, cam_crop_to_full_fn, device, detector


# SuperAnimal defaults (same as in demo_tta parser)
SUPER_ANIMAL_ARGS = SimpleNamespace(
    superanimal_name="superanimal_quadruped",
    superanimal_model_name="hrnet_w32",
    superanimal_detector_name="fasterrcnn_resnet50_fpn_v2",
    superanimal_max_individuals=1,
    saved_2d_model_path="",
    pytorch_config_2d_path=str(_REPO_ROOT / "configs" / "sa_finetune_hrnet_w32.yaml"),
)


def _collect_animal_results(
    model,
    model_cfg,
    renderer,
    cam_crop_to_full_fn,
    device,
    detector,
    out_folder: str,
    img_rgb: np.ndarray,
    tta_lr: float,
    tta_num_iters: int,
    det_thresh: float,
    kp_conf_thresh: float,
    side_view: bool,
    save_mesh: bool,
    progress_callback: Callable[[str], None] | None = None,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], str | None, str | None]:
    """Run detection + PRIMA + SuperAnimal + TTA on a single RGB image.

    Args:
        model, model_cfg, renderer, cam_crop_to_full_fn, device: Objects
            returned by ``_load_prima_model``.
        detector: Detectron2 predictor, or ``None`` to treat the full image as
            a single crop.
        out_folder: Directory that receives rendered PNG/OBJ/NPY outputs.
        img_rgb: Input image as an RGB array.
        tta_lr: Learning rate passed to ``tta_optimize``.
        tta_num_iters: TTA iterations; ``<= 0`` skips SuperAnimal/TTA and
            re-renders the initial prediction as the "after" output.
        det_thresh: Score threshold passed to ``select_animal_boxes``.
        kp_conf_thresh: Keypoints below this confidence get confidence 0.
        side_view: Forwarded to ``render_and_save``.
        save_mesh: Forwarded to ``render_and_save``; also enables OBJ collection.
        progress_callback: Optional sink for human-readable stage messages.

    Returns:
        before_imgs: list of HxWx3 RGB images (before TTA) for all animals
        after_imgs: list of HxWx3 RGB images (after TTA) for all animals
        kpt_imgs: list of HxWx3 RGB keypoint visualizations
        first_before_mesh: path to first animal's before-TTA mesh (.obj) or None
        first_after_mesh: path to first animal's after-TTA mesh (.obj) or None
    """
    import uuid

    from prima.utils import recursive_to
    from prima.datasets.vitdet_dataset import ViTDetDataset
    from demo_tta import (
        denorm_patch_to_rgb,
        render_and_save,
        resolve_sa_weights_path,
        run_superanimal_on_patch,
        save_keypoint_vis,
        tta_optimize,
    )

    def report(message: str) -> None:
        if progress_callback is not None:
            progress_callback(message)

    num_iters = int(tta_num_iters)

    # Resolve SuperAnimal weights lazily and cache the path on the shared
    # namespace so later requests skip the lookup.
    if num_iters > 0 and not SUPER_ANIMAL_ARGS.saved_2d_model_path:
        report("Resolving SuperAnimal weights...")
        SUPER_ANIMAL_ARGS.saved_2d_model_path = resolve_sa_weights_path("")

    # Detect animals
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    if detector is None:
        # Fallback for environments where Detectron2 is unavailable: process full image as one crop.
        report("Detectron2 unavailable; using full-image crop...")
        h, w = img_bgr.shape[:2]
        boxes = np.array([[0.0, 0.0, float(max(1, w - 1)), float(max(1, h - 1))]], dtype=np.float32)
    else:
        report("Detecting animals with Detectron2...")
        det_instances = detector(img_bgr)["instances"]
        boxes, suppressed = select_animal_boxes(det_instances, score_threshold=float(det_thresh))
        if suppressed > 0:
            print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s)")
    if len(boxes) == 0:
        return [], [], [], None, None

    report(f"Detected {len(boxes)} animal(s). Preparing crops...")
    dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    before_imgs: List[np.ndarray] = []
    after_imgs: List[np.ndarray] = []
    kpt_imgs: List[np.ndarray] = []
    before_mesh_paths: List[str] = []
    after_mesh_paths: List[str] = []

    # Random per-request file stem (public API instead of the private
    # ``tempfile._get_candidate_names``).
    img_fn = uuid.uuid4().hex[:8]

    def render_stage(out, batch, animal_id: int, suffix: str, images: List[np.ndarray], meshes: List[str]) -> None:
        """Render one prediction to disk, then collect its PNG (as RGB) and optional OBJ path."""
        render_and_save(
            renderer,
            cam_crop_to_full_fn,
            out,
            batch,
            img_fn,
            animal_id,
            out_folder,
            suffix=suffix,
            side_view=side_view,
            save_mesh=save_mesh,
        )
        stem = os.path.join(out_folder, f"{img_fn}_{animal_id}_{suffix}")
        png_path = f"{stem}.png"
        if os.path.exists(png_path):
            bgr = cv2.imread(png_path)
            if bgr is not None:
                images.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        if save_mesh:
            obj_path = f"{stem}.obj"
            if os.path.exists(obj_path):
                meshes.append(obj_path)

    total_batches = len(dataloader)
    for batch_idx, batch in enumerate(dataloader, start=1):
        batch = recursive_to(batch, device)
        tag = f"Animal {batch_idx}/{total_batches}"

        report(f"{tag}: running PRIMA...")
        with torch.no_grad():
            out_before = model(batch)

        animal_id = int(batch["animalid"][0])

        report(f"{tag}: rendering before TTA...")
        render_stage(out_before, batch, animal_id, "before_tta", before_imgs, before_mesh_paths)

        if num_iters <= 0:
            # TTA disabled: the "after" output is the initial prediction.
            report(f"{tag}: rendering final output...")
            render_stage(out_before, batch, animal_id, "after_tta", after_imgs, after_mesh_paths)
            continue

        # Prepare patch for SuperAnimal
        report(f"{tag}: running SuperAnimal keypoints...")
        patch_rgb = denorm_patch_to_rgb(batch["img"][0])
        with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
            bodyparts_xyc = run_superanimal_on_patch(patch_rgb, SUPER_ANIMAL_ARGS, tmp_dir)

        if bodyparts_xyc is None:
            # No keypoints => TTA cannot run. Reuse the initial prediction as
            # the "after" output so the before/after lists stay aligned per
            # animal (mirrors the num_iters <= 0 path).
            report(f"{tag}: no keypoints found; reusing initial prediction...")
            render_stage(out_before, batch, animal_id, "after_tta", after_imgs, after_mesh_paths)
            continue

        kpts_xyc = bodyparts_xyc
        # Zero the confidence of low-confidence keypoints.
        kpts_xyc[kpts_xyc[:, 2] < float(kp_conf_thresh), 2] = 0.0

        # Save keypoint visualization and npy
        kpt_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png")
        save_keypoint_vis(patch_rgb, kpts_xyc, kpt_png_path)
        npy_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy")
        np.save(npy_path, kpts_xyc)

        if os.path.exists(kpt_png_path):
            kpt_bgr = cv2.imread(kpt_png_path)
            if kpt_bgr is not None:
                kpt_imgs.append(cv2.cvtColor(kpt_bgr, cv2.COLOR_BGR2RGB))

        # Normalize keypoints to [-0.5, 0.5] as in demo_tta
        patch_h, patch_w = patch_rgb.shape[:2]
        kpts_norm = kpts_xyc.copy()
        kpts_norm[:, 0] = kpts_norm[:, 0] / float(patch_w) - 0.5
        kpts_norm[:, 1] = kpts_norm[:, 1] / float(patch_h) - 0.5
        gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch["img"].dtype)

        # Run TTA
        report(f"{tag}: running TTA ({num_iters} iterations)...")
        out_after = tta_optimize(
            model,
            batch,
            gt_kpts_norm,
            num_iters=num_iters,
            lr=float(tta_lr),
        )

        report(f"{tag}: rendering after TTA...")
        render_stage(out_after, batch, animal_id, "after_tta", after_imgs, after_mesh_paths)

    first_before_mesh = before_mesh_paths[0] if before_mesh_paths else None
    first_after_mesh = after_mesh_paths[0] if after_mesh_paths else None

    report("Collecting outputs...")
    return before_imgs, after_imgs, kpt_imgs, first_before_mesh, first_after_mesh


def build_demo(checkpoint_path: str = DEFAULT_CHECKPOINT, out_folder: str = DEFAULT_OUT_FOLDER) -> gr.Interface:
    """Create the Gradio interface for the PRIMA + SuperAnimal + TTA demo.

    The model, renderer and detector are loaded lazily on the first request
    and then cached for the lifetime of the returned interface.

    Args:
        checkpoint_path: PRIMA checkpoint to load on first use.
        out_folder: Directory for rendered outputs; created if missing.

    Returns:
        A queued ``gr.Interface``.
    """
    os.makedirs(out_folder, exist_ok=True)
    cache_keys = ("model", "model_cfg", "renderer", "cam_crop_to_full_fn", "device", "detector")
    runtime_cache = dict.fromkeys(cache_keys)

    def gradio_inference(
        image: np.ndarray,
        tta_lr: float,
        tta_num_iters: int,
        det_thresh: float,
        kp_conf_thresh: float,
        side_view: bool,
        save_mesh: bool,
    ):
        """Wrapper for Gradio. ``image`` is an RGB numpy array.

        Yields intermediate status so long first-run (Hub downloads + model load)
        and long inference do not hit silent client/proxy WebSocket timeouts.
        Every yield is ``(before_img, after_img, keypoint_img, status_text)``.
        """

        if image is None:
            yield None, None, None, "No image provided."
            return

        if image.dtype != np.uint8:
            img_rgb = np.clip(image, 0, 255).astype(np.uint8)
        else:
            img_rgb = image

        yield None, None, None, "Queued; preparing run…"

        hb = _gradio_heartbeat_interval_sec()

        if runtime_cache["model"] is None:
            yield (
                None,
                None,
                None,
                "First run: downloading demo assets from Hugging Face (large checkpoint) "
                "and loading the model. This can take many minutes; status updates here "
                "mean the session is still alive.",
            )
            try:
                if hb <= 0:
                    loaded = _load_model_and_detector_for_demo(checkpoint_path)
                else:
                    # Load on a worker thread and poll so status keeps flowing.
                    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                        fut = pool.submit(_load_model_and_detector_for_demo, checkpoint_path)
                        t0 = time.monotonic()
                        while True:
                            try:
                                loaded = fut.result(timeout=hb)
                                break
                            except concurrent.futures.TimeoutError:
                                elapsed = int(time.monotonic() - t0)
                                yield None, None, None, (
                                    f"First run: still loading model and assets ({elapsed}s). "
                                    f"Updates every ~{int(hb)}s keep the browser connection open on Spaces."
                                )
            except Exception:
                # UI boundary: surface the traceback in the status box.
                yield None, None, None, f"Model initialization failed:\n{traceback.format_exc()}"
                return
            runtime_cache.update(zip(cache_keys, loaded))
            yield None, None, None, "Model loaded. Running detection and inference…"

        infer_args = (
            runtime_cache["model"],
            runtime_cache["model_cfg"],
            runtime_cache["renderer"],
            runtime_cache["cam_crop_to_full_fn"],
            runtime_cache["device"],
            runtime_cache["detector"],
            out_folder,
            img_rgb,
            tta_lr,
            tta_num_iters,
            det_thresh,
            kp_conf_thresh,
            side_view,
            save_mesh,
        )
        try:
            if hb <= 0:
                results = _collect_animal_results(*infer_args)
            else:
                stage_updates: queue.Queue[str] = queue.Queue()

                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                    # The worker pushes stage messages; this generator relays them.
                    fut = pool.submit(_collect_animal_results, *infer_args, stage_updates.put)
                    t0 = time.monotonic()
                    latest_stage = "Starting inference..."
                    while True:
                        # Drain every queued stage message before blocking again.
                        while True:
                            try:
                                latest_stage = stage_updates.get_nowait()
                            except queue.Empty:
                                break
                            elapsed = int(time.monotonic() - t0)
                            yield None, None, None, f"{latest_stage}\nElapsed: {elapsed}s"
                        try:
                            results = fut.result(timeout=1.0)
                            break
                        except concurrent.futures.TimeoutError:
                            elapsed = int(time.monotonic() - t0)
                            yield None, None, None, (
                                f"{latest_stage}\n"
                                f"Elapsed: {elapsed}s\n"
                                "CPU inference can take several minutes."
                            )
        except Exception:
            # UI boundary: surface the traceback in the status box.
            yield None, None, None, f"Inference failed:\n{traceback.format_exc()}"
            return

        before_imgs, after_imgs, kpt_imgs, _mesh_before, _mesh_after = results

        # The UI has a single slot per output, so show the first animal only.
        first_before = before_imgs[0] if before_imgs else None
        first_after = after_imgs[0] if after_imgs else None
        first_kpts = kpt_imgs[0] if kpt_imgs else None
        if first_before is None and first_after is None:
            yield (
                None,
                None,
                None,
                "No output generated. Try an image with a clearly visible quadruped.",
            )
            return
        yield first_before, first_after, first_kpts, "OK"

    _gradio_examples = _gradio_examples_for_interface()
    _iface_kw = dict(
        fn=gradio_inference,
        analytics_enabled=False,
        cache_examples=False,
        inputs=[
            gr.Image(
                label="Input image",
                type="numpy",
                sources=["upload", "clipboard"],
            ),
            gr.Slider(
                label="TTA learning rate",
                minimum=1e-7,
                maximum=1e-4,
                value=1e-6,
                step=1e-7,
            ),
            gr.Slider(
                label="TTA iterations",
                minimum=0,
                maximum=100,
                value=0,
                step=1,
                info="Set to 0 to disable TTA and reuse the initial PRIMA prediction.",
            ),
            gr.Slider(
                label="Detection threshold",
                minimum=0.3,
                maximum=0.9,
                value=0.7,
                step=0.05,
            ),
            gr.Slider(
                label="Keypoint confidence threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.1,
                step=0.05,
            ),
            gr.Checkbox(label="Render side view", value=False),
            gr.Checkbox(label="Save meshes (.obj)", value=True),
        ],
        outputs=[
            gr.Image(label="Before TTA"),
            gr.Image(label="After TTA"),
            gr.Image(label="PRIMA 26 keypoints"),
            gr.Textbox(label="Status / Traceback", lines=12),
        ],
        title="PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation",
        description=(
            "Upload an animal image. The demo runs Detectron2 for animal detection, "
            "PRIMA for 3D pose/shape, DeepLabCut SuperAnimal for 2D keypoints, and "
            "test-time adaptation (TTA) with configurable learning rate and iterations. "
            "Set TTA iterations to 0 to disable adaptation.\n\n"
            "Results (PNG/OBJ and 26-keypoint visualizations) are saved under "
            f"'{out_folder}'."
        ),
    )
    if _gradio_examples:
        _iface_kw["examples"] = _gradio_examples
    demo = gr.Interface(**_iface_kw)
    demo.queue(max_size=8, default_concurrency_limit=1)
    return demo


def parse_args() -> argparse.Namespace:
    """Parse ``--checkpoint`` and ``--out_folder`` from the command line."""
    parser = argparse.ArgumentParser(description="Gradio demo for PRIMA + SuperAnimal + TTA")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=DEFAULT_CHECKPOINT,
        help="Path to the pretrained PRIMA checkpoint",
    )
    parser.add_argument(
        "--out_folder",
        type=str,
        default=DEFAULT_OUT_FOLDER,
        help="Folder used to save rendered outputs and meshes",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if _should_preload_assets():
        _preload_assets_once(args.checkpoint)
    demo = build_demo(checkpoint_path=args.checkpoint, out_folder=args.out_folder)
    demo.launch(inbrowser=False, ssr_mode=False)
# ``chumpy.ch`` namespace expected by legacy SMAL pickles.

import numpy as np


class Ch:
    """Minimal stand-in for ``chumpy.ch.Ch`` (unpickling only).

    Instances expose their payload through ``np.asarray(obj)`` and the ``r``
    property; no other chumpy behavior is implemented here.
    """

    def __init__(self, *args, **kwargs):
        # Only the first positional argument is kept; keyword arguments are ignored.
        self._data = None
        if args:
            self._data = np.asarray(args[0])

    def _resolve(self) -> np.ndarray:
        """Return the stored payload as an ndarray, or a 0-d float32 zero if none is found."""
        # Real chumpy Ch instances store the underlying ndarray on attribute ``x``;
        # legacy pickles unpickle by restoring ``__dict__`` without calling ``__init__``,
        # so try common attribute names before falling back to ``_data``.
        for attr in ("x", "_x", "_data"):
            val = self.__dict__.get(attr)
            if val is not None:
                return np.asarray(val)
        return np.zeros((), dtype=np.float32)

    def __array__(self, dtype=None, copy=None):
        """Implement the NumPy array protocol, including the NumPy 2 ``copy`` keyword.

        Args:
            dtype: Optional target dtype.
            copy: ``True`` forces a copy, ``False`` forbids the copy a dtype
                conversion would need, ``None`` copies only for a dtype conversion.

        Raises:
            ValueError: If ``copy=False`` but a dtype conversion is required.
        """
        arr = self._resolve()
        if dtype is not None and arr.dtype != np.dtype(dtype):
            if copy is False:
                raise ValueError("Unable to avoid copy while converting Ch to the requested dtype.")
            return arr.astype(dtype)
        if copy:
            return arr.copy()
        return arr

    @property
    def r(self) -> np.ndarray:
        """Payload as an ndarray (same value ``_resolve`` returns)."""
        return self._resolve()


class ChArray(np.ndarray):
    """Minimal stand-in for ``chumpy.ch.ChArray``."""


__all__ = ["Ch", "ChArray"]
+data: + bbox_margin: 20 + colormode: RGB + inference: + normalize_images: true + top_down_crop: + width: 256 + height: 256 + auto_padding: + pad_width_divisor: 32 + pad_height_divisor: 32 + train: + affine: + p: 0.5 + rotation: 30 + scaling: + - 1.0 + - 1.0 + translation: 0 + gaussian_noise: 12.75 + motion_blur: true + normalize_images: true + top_down_crop: + width: 256 + height: 256 + auto_padding: + pad_width_divisor: 32 + pad_height_divisor: 32 +detector: + data: + colormode: RGB + inference: + normalize_images: true + train: + affine: + p: 0.5 + rotation: 30 + scaling: + - 1.0 + - 1.0 + translation: 40 + collate: + type: ResizeFromDataSizeCollate + min_scale: 0.4 + max_scale: 1.0 + min_short_side: 128 + max_short_side: 1152 + multiple_of: 32 + to_square: false + hflip: true + normalize_images: true + device: auto + model: + type: FasterRCNN + freeze_bn_stats: true + freeze_bn_weights: false + variant: fasterrcnn_resnet50_fpn_v2 + runner: + type: DetectorTrainingRunner + key_metric: test.mAP@50:95 + key_metric_asc: true + eval_interval: 10 + optimizer: + type: AdamW + params: + lr: 0.0001 + scheduler: + type: LRListScheduler + params: + milestones: + - 160 + lr_list: + - - 1e-05 + snapshots: + max_snapshots: 5 + save_epochs: 25 + save_optimizer_state: false + train_settings: + batch_size: 1 + dataloader_workers: 0 + dataloader_pin_memory: false + display_iters: 500 + epochs: 250 +device: auto +inference: + multithreading: + enabled: true + queue_length: 4 + timeout: 30.0 + compile: + enabled: false + backend: inductor + autocast: + enabled: false +metadata: + project_path: "" + pose_config_path: "" + bodyparts: + - left_eye + - right_eye + - chin + - left_front_paw + - right_front_paw + - left_back_paw + - right_back_paw + - tail_base + - left_front_thigh + - right_front_thigh + - left_back_thigh + - right_back_thigh + - left_shoulder + - right_shoulder + - left_front_knee + - right_front_knee + - left_back_knee + - right_back_knee + - neck_base + - tail_mid + 
- left_ear_base + - right_ear_base + - left_mouth_corner + - right_mouth_corner + - nose + - tail_tip_first + unique_bodyparts: [] + individuals: + - individual000 + with_identity: false +method: td +model: + backbone: + type: HRNet + model_name: hrnet_w32 + freeze_bn_stats: true + freeze_bn_weights: false + interpolate_branches: false + increased_channel_count: false + backbone_output_channels: 32 + heads: + bodypart: + type: HeatmapHead + weight_init: normal + predictor: + type: HeatmapPredictor + apply_sigmoid: false + clip_scores: true + location_refinement: true + locref_std: 7.2801 + target_generator: + type: HeatmapGaussianGenerator + num_heatmaps: 26 + pos_dist_thresh: 17 + heatmap_mode: KEYPOINT + gradient_masking: true + background_weight: 0.0 + generate_locref: true + locref_std: 7.2801 + criterion: + heatmap: + type: WeightedMSECriterion + weight: 1.0 + locref: + type: WeightedHuberCriterion + weight: 0.05 + heatmap_config: + channels: + - 32 + kernel_size: [] + strides: [] + final_conv: + out_channels: 26 + kernel_size: 1 + locref_config: + channels: + - 32 + kernel_size: [] + strides: [] + final_conv: + out_channels: 52 + kernel_size: 1 +net_type: hrnet_w32 +runner: + type: PoseTrainingRunner + gpus: + key_metric: test.mAP + key_metric_asc: true + eval_interval: 10 + optimizer: + type: AdamW + params: + lr: 0.0001 + scheduler: + type: LRListScheduler + params: + lr_list: + - - 1e-05 + - - 1e-06 + milestones: + - 160 + - 190 + snapshots: + max_snapshots: 5 + save_epochs: 10 + save_optimizer_state: false +train_settings: + batch_size: 64 + dataloader_workers: 8 + dataloader_pin_memory: false + display_iters: 500 + epochs: 200 + seed: 42 diff --git a/demo_data/000000015956_horse.png b/demo_data/000000015956_horse.png new file mode 100644 index 0000000000000000000000000000000000000000..15fab77e07b6b732852bcb48fad389bcefc94e98 --- /dev/null +++ b/demo_data/000000015956_horse.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:2a2398ba7df40a47c636afefa28be17b55f4b7bc2c378e053aeea507580ad2cb +size 620466 diff --git a/demo_data/000000315905_zebra.jpg b/demo_data/000000315905_zebra.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c476bdf86fb551db280f8bba613b83d673f4768c --- /dev/null +++ b/demo_data/000000315905_zebra.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0a17e1f1650820b020a9025144015c1e27f0f1ab435859f0bde3a0047d8f689 +size 257420 diff --git a/demo_data/beagle.jpg b/demo_data/beagle.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e3ee9223e2597f56e8ebf1e70d0d8bac82e90dbb --- /dev/null +++ b/demo_data/beagle.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac29e6ea6086831dd9806a8cd3fd608e264ac1af567f6fcfc8797c5bd3d5d560 +size 349657 diff --git a/demo_data/n02101388_1188.png b/demo_data/n02101388_1188.png new file mode 100644 index 0000000000000000000000000000000000000000..873d5861f7bff57812afab192a747b02647a3084 --- /dev/null +++ b/demo_data/n02101388_1188.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e45ff508fb8c6437cce22fcb59b4f1b6fe37ddfab1d4cf68d97629f9caa939f4 +size 318688 diff --git a/demo_data/n02412080_12159.png b/demo_data/n02412080_12159.png new file mode 100644 index 0000000000000000000000000000000000000000..a6517fa50d8e1b094a07468ea64bb8bb21cdba07 --- /dev/null +++ b/demo_data/n02412080_12159.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03273c57e8b25b258d3eb96af7b4f77b43b5c40be90da83c21875f3322b487f1 +size 347450 diff --git a/demo_data/shepherd_hati.jpg b/demo_data/shepherd_hati.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4fd79ba7a132f502833df56a884d89a58c8a086a --- /dev/null +++ b/demo_data/shepherd_hati.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65c5878203bc3165dda9011ebfce77cc7d930daed0a215396d8036509d1963c1 +size 209710 diff --git 
a/demo_tta.py b/demo_tta.py new file mode 100644 index 0000000000000000000000000000000000000000..c8257c2712af180319e1f8959eeda8e6d1c583d2 --- /dev/null +++ b/demo_tta.py @@ -0,0 +1,399 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +""" +demo_tta.py: PRIMA inference with fine-tuned DeepLabCut SuperAnimal TTA + +Pipeline: +1. Run Detectron2 to detect animals in the input image. +2. Run PRIMA on each detected animal to obtain 3D pose/shape estimation. +3. Run a fine-tuned DeepLabCut SuperAnimal pose model (Animal3D 26-joint + layout) to obtain 2D keypoints already in PRIMA topology. The fine-tuned + snapshot is wired into DLC's + ``superanimal_analyze_images`` via the ``customized_pose_checkpoint`` + and ``customized_model_config`` kwargs. +4. Run test-time adaptation (TTA) with user-specified lr and num_iters + to further optimize the 3D pose and shape estimation. +5. Render and save before/after TTA results (PNG + OBJ) and the + 26-keypoint visualization (PNG). 
+""" + + +from pathlib import Path +import argparse +import copy +import os +import tempfile +import warnings + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.data +from tqdm import tqdm + +from prima.models import load_prima +from prima.utils import recursive_to +from prima.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD +from prima.utils.detection import ANIMAL_COCO_IDS, select_animal_boxes +from prima.utils.weights import DEFAULT_HF_REPO_ID, resolve_prima_checkpoint_path + +warnings.filterwarnings("ignore") + +LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353) +GREEN = (0.65, 0.86, 0.74) + +REPO_ROOT = Path(__file__).resolve().parent + + +def load_renderer_components(): + try: + from prima.utils.renderer import Renderer, cam_crop_to_full + except Exception as exc: + raise RuntimeError( + "Cannot initialize the PRIMA renderer. Rendering requires a working " + "pyrender/OpenGL backend such as EGL or OSMesa. Install the missing " + "OpenGL runtime for this environment, or run in an environment where " + "PYOPENGL_PLATFORM=egl/osmesa works." 
+ ) from exc + return Renderer, cam_crop_to_full + + +def denorm_patch_to_rgb(img_tensor: torch.Tensor) -> np.ndarray: + patch = (img_tensor.detach().cpu() * (DEFAULT_STD[:, None, None]) + DEFAULT_MEAN[:, None, None]) / 255.0 + patch = patch.permute(1, 2, 0).numpy() + return np.clip(patch, 0.0, 1.0) + + +def save_keypoint_vis(patch_rgb: np.ndarray, kpts_xyc: np.ndarray, save_path: str) -> None: + vis = cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR).copy() + num_kpts = len(kpts_xyc) + + for i, (x, y, c) in enumerate(kpts_xyc): + if c <= 0: + continue + + # Use distinct color for each keypoint (OpenCV uses BGR) + hue = int(179 * i / max(1, num_kpts - 1)) + color_bgr = cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2BGR)[0, 0] + color_bgr = (int(color_bgr[0]), int(color_bgr[1]), int(color_bgr[2])) + + cx, cy = int(round(float(x))), int(round(float(y))) + cv2.circle(vis, (cx, cy), 3, color_bgr, -1) + cv2.putText(vis, str(i), (cx + 3, cy - 3), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1, cv2.LINE_AA) + + cv2.imwrite(save_path, vis) + + +def resolve_sa_weights_path(local_path: str) -> str: + """Return a local path to the fine-tuned SuperAnimal .pt snapshot. + + If ``local_path`` is empty, downloads ``sa_finetune_hrnet_w32.pt`` from the + ``MLAdaptiveIntelligence/FMPose3D`` Hugging Face repo (cached under + ``~/.cache/huggingface``). + """ + if local_path: + return local_path + try: + from huggingface_hub import hf_hub_download + except ImportError: + raise ImportError( + "huggingface_hub is required to auto-download the fine-tuned " + "SuperAnimal weights. Install with `pip install huggingface_hub`, " + "or pass --saved_2d_model_path with a local .pt file." 
+ ) from None + repo_id = "MLAdaptiveIntelligence/FMPose3D" + filename = "sa_finetune_hrnet_w32.pt" + try: + cached_path = hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True) + except Exception: + print(f"No --saved_2d_model_path provided; downloading '{filename}' from {repo_id}...") + return hf_hub_download(repo_id=repo_id, filename=filename) + + print(f"Using cached SuperAnimal weights: {cached_path}") + return cached_path + + +def run_superanimal_on_patch(patch_rgb: np.ndarray, args, tmp_dir: str): + """Predict 26-joint 2D keypoints on a single PRIMA patch using a + fine-tuned DeepLabCut SuperAnimal snapshot. + + Returns an ``(26, 3)`` array of ``(x, y, confidence)`` in patch + pixel coordinates, or ``None`` if no individual was detected. + """ + try: + from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images + except Exception as e: + raise RuntimeError( + "Cannot import DeepLabCut SuperAnimal API. Please install deeplabcut with pose_estimation_pytorch support." 
+ ) from e + + patch_path = os.path.join(tmp_dir, "patch.png") + cv2.imwrite(patch_path, cv2.cvtColor((patch_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + + dlc_device = "cuda" if torch.cuda.is_available() else "cpu" + preds = superanimal_analyze_images( + superanimal_name=args.superanimal_name, + model_name=args.superanimal_model_name, + detector_name=args.superanimal_detector_name, + images=patch_path, + max_individuals=args.superanimal_max_individuals, + out_folder=tmp_dir, + device=dlc_device, + customized_model_config=args.pytorch_config_2d_path, + customized_pose_checkpoint=args.saved_2d_model_path, + ) + + payload = preds.get(patch_path, None) + if payload is None: + return None + bodyparts = payload.get("bodyparts", None) + if bodyparts is None or len(bodyparts) == 0: + return None + + best_idx = int(np.argmax(bodyparts[..., 2].mean(axis=1))) + return bodyparts[best_idx].astype(np.float32) + + +def render_and_save(renderer, cam_crop_to_full_fn, out, batch, img_fn, animal_id, out_folder, suffix, side_view, save_mesh): + pred_cam = out['pred_cam'] + box_center = batch['box_center'].float() + box_size = batch['box_size'].float() + img_size = batch['img_size'].float() + scaled_focal_length = batch['focal_length'][0, 0] / batch['img'].shape[-1] * img_size.max() + pred_cam_t_full = cam_crop_to_full_fn(pred_cam, box_center, box_size, img_size, scaled_focal_length) + + white_img = (torch.ones_like(batch['img'][0]).cpu() - DEFAULT_MEAN[:, None, None] / 255) / ( + DEFAULT_STD[:, None, None] / 255 + ) + input_patch = denorm_patch_to_rgb(batch['img'][0]) + + regression_img = renderer( + out['pred_vertices'][0].detach().cpu().numpy(), + out['pred_cam_t'][0].detach().cpu().numpy(), + batch['img'][0], + mesh_base_color=GREEN, + scene_bg_color=(1, 1, 1), + ) + + final_img = np.concatenate([input_patch, regression_img], axis=1) + if side_view: + side_img = renderer( + out['pred_vertices'][0].detach().cpu().numpy(), + out['pred_cam_t'][0].detach().cpu().numpy(), + 
white_img, + mesh_base_color=GREEN, + scene_bg_color=(1, 1, 1), + side_view=True, + ) + final_img = np.concatenate([final_img, side_img], axis=1) + + cv2.imwrite( + os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.png'), + cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR), + ) + + if save_mesh: + verts = out['pred_vertices'][0].detach().cpu().numpy() + cam_t = pred_cam_t_full[0].detach().cpu().numpy() + tmesh = renderer.vertices_to_trimesh(verts, cam_t.copy(), LIGHT_BLUE) + tmesh.export(os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.obj')) + + +def tta_optimize(model, batch, gt_kpts_norm, num_iters, lr): + model.eval() + + if hasattr(model, 'backbone'): + for p in model.backbone.parameters(): + p.requires_grad = False + + orig_smal_head_state = copy.deepcopy(model.smal_head.state_dict()) + model.smal_head.freeze_except_regression_heads() + tta_params = model.smal_head.get_tta_parameters(mode='all') + optimizer = torch.optim.Adam(tta_params, lr=lr) + + valid_mask = (gt_kpts_norm[..., 2] > 0).float().unsqueeze(-1) + gt_xy = gt_kpts_norm[..., :2] + + for _ in range(num_iters): + optimizer.zero_grad() + out = model(batch) + pred_xy = out['pred_keypoints_2d'] + loss = F.mse_loss(pred_xy * valid_mask, gt_xy * valid_mask, reduction='sum') / (valid_mask.sum() + 1e-6) + loss.backward() + optimizer.step() + + with torch.no_grad(): + out_after = model(batch) + + model.smal_head.load_state_dict(orig_smal_head_state) + model.smal_head.unfreeze_all() + + return out_after + + +def main(): + parser = argparse.ArgumentParser(description='PRIMA + SuperAnimal + TTA demo') + parser.add_argument('--checkpoint', type=str, default='', + help='Path to pretrained PRIMA checkpoint. 
Empty -> auto-download the default Stage 1 checkpoint.') + parser.add_argument('--hf-repo-id', '--hf_repo_id', dest='hf_repo_id', + type=str, default=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_REPO_ID), + help='Hugging Face repo ID containing PRIMA demo assets') + parser.add_argument('--no-auto-download', '--no_auto_download', dest='no_auto_download', action='store_true', + help='Disable automatic download of missing PRIMA demo assets') + parser.add_argument('--img_path', type=str, default=None, help='Single image path') + parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images') + parser.add_argument('--out_folder', type=str, default='demo_out_tta', help='Output folder') + parser.add_argument('--side_view', dest='side_view', action='store_true', default=False, help='Render side view') + parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='Save meshes') + parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'], help='Image globs') + parser.add_argument('--det_thresh', type=float, default=0.7, help='Detectron2 score threshold for animals') + + parser.add_argument('--tta_lr', type=float, default=1e-6, help='TTA learning rate') + parser.add_argument('--tta_num_iters', type=int, default=30, help='TTA iterations') + parser.add_argument('--kp_conf_thresh', type=float, default=0.1, help='Keypoint confidence threshold') + + parser.add_argument('--superanimal_name', type=str, default='superanimal_quadruped') + parser.add_argument('--superanimal_model_name', type=str, default='hrnet_w32') + parser.add_argument('--superanimal_detector_name', type=str, default='fasterrcnn_resnet50_fpn_v2') + parser.add_argument('--superanimal_max_individuals', type=int, default=1) + parser.add_argument('--saved_2d_model_path', type=str, default='', + help='Path to the fine-tuned SuperAnimal 26-joint .pt snapshot. 
' + 'Empty -> auto-download sa_finetune_hrnet_w32.pt from ' + 'MLAdaptiveIntelligence/FMPose3D on Hugging Face Hub.') + parser.add_argument('--pytorch_config_2d_path', type=str, + default=str(Path(__file__).resolve().parent / 'configs' / 'sa_finetune_hrnet_w32.yaml'), + help='Path to the DLC pytorch config yaml for the fine-tuned snapshot. ' + 'Defaults to the bundled configs/sa_finetune_hrnet_w32.yaml.') + + args = parser.parse_args() + checkpoint_path = resolve_prima_checkpoint_path( + args.checkpoint, + data_dir=REPO_ROOT / "data", + auto_download=not args.no_auto_download, + hf_repo_id=args.hf_repo_id, + ) + args.saved_2d_model_path = resolve_sa_weights_path(args.saved_2d_model_path) + + model, model_cfg = load_prima(checkpoint_path) + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + model = model.to(device) + model.eval() + + Renderer, cam_crop_to_full_fn = load_renderer_components() + renderer = Renderer(model_cfg, faces=model.smal.faces) + os.makedirs(args.out_folder, exist_ok=True) + + import detectron2.config + import detectron2.engine + from detectron2 import model_zoo + + cfg = detectron2.config.get_cfg() + cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml")) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 + cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl" + cfg.MODEL.DEVICE = device.type + detector = detectron2.engine.DefaultPredictor(cfg) + + if args.img_path is not None: + img_paths = [Path(args.img_path)] + else: + img_paths = sorted([img for end in args.file_type for img in Path(args.img_folder).glob(end)]) + + for img_path in img_paths: + img_bgr = cv2.imread(str(img_path)) + if img_bgr is None: + print(f"[WARN] Cannot read image: {img_path}") + continue + det_out = detector(img_bgr) + det_instances = det_out['instances'] + boxes, suppressed = select_animal_boxes( + 
det_instances, + animal_class_ids=ANIMAL_COCO_IDS, + score_threshold=args.det_thresh, + ) + if suppressed > 0: + print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s) in {img_path}") + + if len(boxes) == 0: + print(f"[INFO] No animal detected in {img_path}") + continue + + dataset = ViTDetDataset(model_cfg, img_bgr, boxes) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + + for batch in tqdm(dataloader, desc=f"{img_path.name}"): + batch = recursive_to(batch, device) + with torch.no_grad(): + out_before = model(batch) + + img_fn = img_path.stem + animal_id = int(batch['animalid'][0]) + + render_and_save( + renderer, + cam_crop_to_full_fn, + out_before, + batch, + img_fn, + animal_id, + args.out_folder, + suffix='before_tta', + side_view=args.side_view, + save_mesh=args.save_mesh, + ) + + patch_rgb = denorm_patch_to_rgb(batch['img'][0]) + with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir: + kpts_xyc = run_superanimal_on_patch(patch_rgb, args, tmp_dir) + + if kpts_xyc is None: + print(f"[WARN] No SuperAnimal keypoints for {img_fn}_{animal_id}, skip TTA") + continue + + kpts_xyc[kpts_xyc[:, 2] < args.kp_conf_thresh, 2] = 0.0 + + save_keypoint_vis( + patch_rgb, + kpts_xyc, + os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png"), + ) + np.save(os.path.join(args.out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy"), kpts_xyc) + + patch_h, patch_w = patch_rgb.shape[:2] + kpts_norm = kpts_xyc.copy() + kpts_norm[:, 0] = kpts_norm[:, 0] / float(patch_w) - 0.5 + kpts_norm[:, 1] = kpts_norm[:, 1] / float(patch_h) - 0.5 + gt_kpts_norm = torch.from_numpy(kpts_norm[None]).to(device=device, dtype=batch['img'].dtype) + + out_after = tta_optimize( + model, + batch, + gt_kpts_norm, + num_iters=args.tta_num_iters, + lr=args.tta_lr, + ) + + render_and_save( + renderer, + cam_crop_to_full_fn, + out_after, + batch, + img_fn, + animal_id, + args.out_folder, + suffix='after_tta', 
+ side_view=args.side_view, + save_mesh=args.save_mesh, + ) + + +if __name__ == '__main__': + main() diff --git a/images/teaser.png b/images/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..0565e56973c430588bb28395e5f55162e2affef6 --- /dev/null +++ b/images/teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a617ca4fd37de03e2db4ccf397ce9841ed32c3fe18c766c4832d41af574ad746 +size 4287693 diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..aaca01816e9098e0f0ab266b9dfd7e86ec6b4129 --- /dev/null +++ b/packages.txt @@ -0,0 +1,7 @@ +libosmesa6 +libgl1 +libgl1-mesa-dri +libegl-mesa0 +libegl1 +libglx-mesa0 +libgles2 diff --git a/prima/__init__.py b/prima/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a653e908faa00c04476586f4dbeedd876f468831 --- /dev/null +++ b/prima/__init__.py @@ -0,0 +1,25 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +"""Top-level package for PRIMA. + +This package contains models, datasets and utilities for +3D animal pose and shape estimation. 
+""" + +from importlib.metadata import PackageNotFoundError, version + + +try: # pragma: no cover - best effort during development + __version__ = version("prima-animal") +except PackageNotFoundError: # pragma: no cover + __version__ = "0.0.0" + + +__all__ = ["__version__"] diff --git a/prima/configs/__init__.py b/prima/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c465c6b3c894534c973c1e8ba8344716c4554e0 --- /dev/null +++ b/prima/configs/__init__.py @@ -0,0 +1,99 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from typing import Dict +from yacs.config import CfgNode as CN + +def to_lower(x: Dict) -> Dict: + """ + Convert all dictionary keys to lowercase + Args: + x (dict): Input dictionary + Returns: + dict: Output dictionary with all keys converted to lowercase + """ + return {k.lower(): v for k, v in x.items()} + + +_C = CN(new_allowed=True) + +_C.GENERAL = CN(new_allowed=True) +_C.GENERAL.RESUME = True +_C.GENERAL.TIME_TO_RUN = 3300 +_C.GENERAL.VAL_STEPS = 100 +_C.GENERAL.LOG_STEPS = 100 +_C.GENERAL.CHECKPOINT_STEPS = 20000 +_C.GENERAL.CHECKPOINT_DIR = "checkpoints" +_C.GENERAL.SUMMARY_DIR = "tensorboard" +_C.GENERAL.NUM_GPUS = 1 +_C.GENERAL.NUM_WORKERS = 4 +_C.GENERAL.MIXED_PRECISION = True +_C.GENERAL.ALLOW_CUDA = True +_C.GENERAL.PIN_MEMORY = False +_C.GENERAL.DISTRIBUTED = False +_C.GENERAL.LOCAL_RANK = 0 +_C.GENERAL.USE_SYNCBN = False +_C.GENERAL.WORLD_SIZE = 1 +_C.GENERAL.PREFETCH_FACTOR = 2 + +_C.TRAIN = CN(new_allowed=True) +_C.TRAIN.NUM_EPOCHS = 100 +_C.TRAIN.SHUFFLE = True +_C.TRAIN.WARMUP = False +_C.TRAIN.NORMALIZE_PER_IMAGE = False +_C.TRAIN.CLIP_GRAD = False +_C.TRAIN.CLIP_GRAD_VALUE = 1.0 +_C.LOSS_WEIGHTS = CN(new_allowed=True) + 
+_C.DATASETS = CN(new_allowed=True) + +_C.MODEL = CN(new_allowed=True) +_C.MODEL.IMAGE_SIZE = 224 + +_C.EXTRA = CN(new_allowed=True) +_C.EXTRA.FOCAL_LENGTH = 5000 + +_C.DATASETS.CONFIG = CN(new_allowed=True) +_C.DATASETS.CONFIG.SCALE_FACTOR = 0.3 +_C.DATASETS.CONFIG.ROT_FACTOR = 30 +_C.DATASETS.CONFIG.TRANS_FACTOR = 0.02 +_C.DATASETS.CONFIG.COLOR_SCALE = 0.2 +_C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6 +_C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5 +_C.DATASETS.CONFIG.DO_FLIP = False +_C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5 +_C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10 + + +def default_config() -> CN: + """ + Get a yacs CfgNode object with the default config values. + """ + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _C.clone() + + +def get_config(config_file: str, merge: bool = True) -> CN: + """ + Read a config file and optionally merge it with the default config file. + Args: + config_file (str): Path to config file. + merge (bool): Whether to merge with the default config or not. + Returns: + CfgNode: Config as a yacs CfgNode object. 
+ """ + if merge: + cfg = default_config() + else: + cfg = CN(new_allowed=True) + cfg.merge_from_file(config_file) + + cfg.freeze() + return cfg diff --git a/prima/datasets/__init__.py b/prima/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f05024bca78146ece6bdc9ca1ffff821371694f1 --- /dev/null +++ b/prima/datasets/__init__.py @@ -0,0 +1,79 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from typing import Dict, Optional +from torch.utils.data import WeightedRandomSampler +import torch +import pytorch_lightning as pl +from yacs.config import CfgNode +from .datasets import OptionAnimalDataset, TrainDataset +from prima.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +class DataModule(pl.LightningDataModule): + + def __init__(self, cfg: CfgNode) -> None: + """ + Initialize LightningDataModule for AMR training + Args: + cfg (CfgNode): Config file as a yacs CfgNode containing necessary dataset info. 
+ """ + super().__init__() + self.cfg = cfg + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.mocap_dataset = None + self.weight_sampler = None + + def setup(self, stage: Optional[str] = None) -> None: + """ + Load datasets necessary for training + Args: + stage: + """ + if self.train_dataset is None: + self.train_dataset = OptionAnimalDataset(self.cfg) + self.weight_sampler = WeightedRandomSampler(weights=self.train_dataset.weights, + num_samples=len(self.train_dataset)) + if self.val_dataset is None: + self.val_dataset = TrainDataset(self.cfg, is_train=False, + root_image=self.cfg.DATASETS.ANIMAL3D.ROOT_IMAGE, + json_file=self.cfg.DATASETS.ANIMAL3D.JSON_FILE.TEST) + + def train_dataloader(self) -> Dict: + """ + Setup training data loader. + Returns: + Dict: Dictionary containing image and mocap data dataloaders + """ + shuffle = False if self.weight_sampler is not None else True + train_dataloader = torch.utils.data.DataLoader(self.train_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True, + num_workers=self.cfg.GENERAL.NUM_WORKERS, + prefetch_factor=self.cfg.GENERAL.PREFETCH_FACTOR, + pin_memory=True, + shuffle=shuffle, + sampler=self.weight_sampler, + ) + return {'img': train_dataloader} + + def val_dataloader(self) -> torch.utils.data.DataLoader: + """ + Setup val data loader. 
+ Returns: + torch.utils.data.DataLoader: Validation dataloader + """ + val_dataloader = torch.utils.data.DataLoader(self.val_dataset, self.cfg.TRAIN.BATCH_SIZE, drop_last=True, + num_workers=self.cfg.GENERAL.NUM_WORKERS, pin_memory=True) + return val_dataloader + + + diff --git a/prima/datasets/datasets.py b/prima/datasets/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..b28cef295650d1e1ac8178c67b8ab350bea854ed --- /dev/null +++ b/prima/datasets/datasets.py @@ -0,0 +1,278 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import copy +import os +import numpy as np +import torch +from yacs.config import CfgNode +import cv2 +import pyrootutils +from torch.utils.data import ConcatDataset +from typing import List +root = pyrootutils.setup_root( + search_from=__file__, + indicator=[".git", "pyproject.toml"], + pythonpath=True, + dotenv=True, +) + +import json +import hydra +from omegaconf import DictConfig, OmegaConf +from PIL import Image +from torch.utils.data import Dataset, DataLoader +from typing import Optional, Tuple +from .utils import get_example, expand_to_aspect_ratio + + +class TrainDataset(Dataset): + def __init__(self, cfg: CfgNode, is_train: bool, root_image: str, json_file: str): + super().__init__() + self.root_image = root_image + self.focal_length = cfg.SMAL.get("FOCAL_LENGTH", 1000) + + json_file = json_file + with open(json_file, 'r') as f: + self.data = json.load(f) + + self.is_train = is_train + self.IMG_SIZE = cfg.MODEL.IMAGE_SIZE + self.MEAN = 255. * np.array(cfg.MODEL.IMAGE_MEAN) + self.STD = 255. 
* np.array(cfg.MODEL.IMAGE_STD) + self.use_skimage_antialias = cfg.DATASETS.get('USE_SKIMAGE_ANTIALIAS', False) + self.border_mode = { + 'constant': cv2.BORDER_CONSTANT, + 'replicate': cv2.BORDER_REPLICATE, + }[cfg.DATASETS.get('BORDER_MODE', 'constant')] + + self.augm_config = cfg.DATASETS.CONFIG + + def __len__(self): + return len(self.data['data']) + + def __getitem__(self, item): + data = self.data['data'][item] + key = data['img_path'] + image = np.array(Image.open(os.path.join(self.root_image, key)).convert("RGB")) + mask = np.array(Image.open(os.path.join(self.root_image, data['mask_path'])).convert('L')) + category_idx = data['supercategory'] + keypoint_2d = np.array(data['keypoint_2d'], dtype=np.float32) + if 'keypoint_3d' in data: + keypoint_3d = np.concatenate( + (data['keypoint_3d'], np.ones((len(data['keypoint_3d']), 1))), axis=-1).astype(np.float32) + else: + keypoint_3d = np.zeros((len(keypoint_2d), 4), dtype=np.float32) + bbox = data['bbox'] # [x, y, w, h] + center = np.array([(bbox[0] * 2 + bbox[2]) // 2, (bbox[1] * 2 + bbox[3]) // 2]) + pose = np.array(data['pose'], dtype=np.float32) if 'pose' in data else np.zeros(105, dtype=np.float32) # [105, ] + betas = np.array(data['shape'] + data['shape_extra'], dtype=np.float32) if 'shape' in data else np.zeros(41, dtype=np.float32) # [41, ] + translation = np.array(data['trans'], dtype=np.float32) if 'trans' in data else np.zeros(3, dtype=np.float32) # [3, ] + # Fixed: Check if all elements are zero, not if all elements are truthy + has_pose = np.array(1., dtype=np.float32) if not (pose == 0).all() else np.array(0., dtype=np.float32) + has_betas = np.array(1., dtype=np.float32) if not (betas == 0).all() else np.array(0., dtype=np.float32) + has_translation = np.array(1., dtype=np.float32) if not (translation == 0).all() else np.array(0., dtype=np.float32) + ori_keypoint_2d = keypoint_2d.copy() + center_x, center_y = center[0], center[1] + bbox_size = max([bbox[2], bbox[3]]) + + smal_params = 
class EvaluationDataset(Dataset):
    """Evaluation counterpart of the training dataset (no training-time augmentation flag).

    Reads the same JSON layout as the training dataset (``img_path``, ``mask_path``,
    ``supercategory``, ``keypoint_2d``, ``bbox`` plus optional ``keypoint_3d``, ``pose``,
    ``shape``/``shape_extra``, ``trans`` and ``bone``) and additionally returns
    ``imgname``, ``bbox`` and ``bbox_expand_factor`` with every sample.
    """

    # Default per-channel normalisation statistics in [0, 1]; kept as tuples so the
    # defaults can never be mutated between instances.
    DEFAULT_MEAN = (0.485, 0.456, 0.406)
    DEFAULT_STD = (0.229, 0.224, 0.225)

    def __init__(self, root_image: str, json_file: str, augm_config,
                 focal_length: int = 1000, image_size: int = 256,
                 mean: Optional[List[float]] = None, std: Optional[List[float]] = None):
        """
        Args:
            root_image (str): Directory that image and mask paths are relative to.
            json_file (str): Path of the annotation JSON file.
            augm_config: Augmentation config forwarded (as a deep copy) to ``get_example``.
            focal_length (int): Focal length reported with every sample.
            image_size (int): Side length of the square output patch.
            mean (List[float], optional): Per-channel mean in [0, 1]; defaults to DEFAULT_MEAN.
            std (List[float], optional): Per-channel std in [0, 1]; defaults to DEFAULT_STD.
        """
        super().__init__()
        self.root_image = root_image
        self.focal_length = focal_length

        with open(json_file, 'r') as f:
            self.data = json.load(f)

        self.is_train = False
        self.IMG_SIZE = image_size
        # Rescale the statistics to the 0-255 pixel range.
        self.MEAN = 255. * np.array(self.DEFAULT_MEAN if mean is None else mean)
        self.STD = 255. * np.array(self.DEFAULT_STD if std is None else std)
        self.use_skimage_antialias = False
        self.border_mode = cv2.BORDER_CONSTANT
        self.augm_config = augm_config

    def __len__(self):
        """Number of annotated samples."""
        return len(self.data['data'])

    def __getitem__(self, item):
        """Load sample ``item``, crop it with ``get_example`` and pack it into a dict."""
        data = self.data['data'][item]
        key = data['img_path']
        image = np.array(Image.open(os.path.join(self.root_image, key)).convert("RGB"))
        mask = np.array(Image.open(os.path.join(self.root_image, data['mask_path'])).convert('L'))
        category_idx = data['supercategory']
        keypoint_2d = np.array(data['keypoint_2d'], dtype=np.float32)
        # 2D-only datasets have no 'keypoint_3d'; fall back to zeros like the training dataset.
        if 'keypoint_3d' in data:
            keypoint_3d = np.concatenate(
                (data['keypoint_3d'], np.ones((len(data['keypoint_3d']), 1))), axis=-1).astype(np.float32)
        else:
            keypoint_3d = np.zeros((len(keypoint_2d), 4), dtype=np.float32)
        bbox = data['bbox']  # [x, y, w, h]
        center = np.array([(bbox[0] * 2 + bbox[2]) // 2, (bbox[1] * 2 + bbox[3]) // 2])
        pose = np.array(data['pose'], dtype=np.float32) if 'pose' in data else np.zeros(105, dtype=np.float32)  # [105, ]
        betas = np.array(data['shape'] + data['shape_extra'], dtype=np.float32) if 'shape' in data else np.zeros(41, dtype=np.float32)  # [41, ]
        translation = np.array(data['trans'], dtype=np.float32) if 'trans' in data else np.zeros(3, dtype=np.float32)  # [3, ]
        # A parameter that is entirely zero is flagged as "not annotated".
        has_pose = np.array(1., dtype=np.float32) if not (pose == 0).all() else np.array(0., dtype=np.float32)
        has_betas = np.array(1., dtype=np.float32) if not (betas == 0).all() else np.array(0., dtype=np.float32)
        has_translation = np.array(1., dtype=np.float32) if not (translation == 0).all() else np.array(0., dtype=np.float32)
        ori_keypoint_2d = keypoint_2d.copy()
        center_x, center_y = center[0], center[1]

        scale = np.array([bbox[2], bbox[3]], dtype=np.float32) / 200.
        # No target aspect ratio is requested, so the crop side is the longer bbox side.
        bbox_size = expand_to_aspect_ratio(scale * 200, None).max()
        bbox_expand_factor = bbox_size / ((scale * 200).max())

        smal_params = {'global_orient': pose[:3],
                       'pose': pose[3:],
                       'betas': betas,
                       'transl': translation,
                       # float32 like every other SMAL parameter, whether annotated or not.
                       'bone': np.array(data['bone'], dtype=np.float32) if 'bone' in data
                       else np.zeros(24, dtype=np.float32),
                       }
        has_smal_params = {'global_orient': has_pose,
                           'pose': has_pose,
                           'betas': has_betas,
                           'transl': has_translation,
                           'bone': np.array(1., dtype=np.float32) if 'bone' in data else np.array(0., dtype=np.float32),
                           }
        smal_params_is_axis_angle = {'global_orient': True,
                                     'pose': True,
                                     'betas': False,
                                     'transl': False,
                                     'bone': False
                                     }

        # get_example receives its own copy of the augmentation config.
        augm_config = copy.deepcopy(self.augm_config)
        # The mask travels as a fourth channel so it goes through the same crop as the image.
        img_rgba = np.concatenate([image, mask[:, :, None]], axis=2)
        img_patch_rgba, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, trans, img_border_mask = get_example(
            img_rgba,
            center_x, center_y,
            bbox_size, bbox_size,
            keypoint_2d, keypoint_3d,
            smal_params, has_smal_params,
            self.IMG_SIZE, self.IMG_SIZE,
            self.MEAN, self.STD, self.is_train, augm_config,
            is_bgr=False, return_trans=True,
            use_skimage_antialias=self.use_skimage_antialias,
            border_mode=self.border_mode
        )
        img_patch = (img_patch_rgba[:3, :, :])
        mask_patch = (img_patch_rgba[3, :, :] / 255.0).clip(0, 1)
        # An (almost) empty mask is replaced by an all-ones mask.
        if (mask_patch < 0.5).all():
            mask_patch = np.ones_like(mask_patch)

        item = {'img': img_patch,
                'mask': mask_patch,
                'keypoints_2d': keypoints_2d,
                'keypoints_3d': keypoints_3d,
                'orig_keypoints_2d': ori_keypoint_2d,
                'box_center': np.array(center.copy(), dtype=np.float32),
                'box_size': float(bbox_size),
                'img_size': np.array(1.0 * img_size[::-1].copy(), dtype=np.float32),
                'smal_params': smal_params,
                'has_smal_params': has_smal_params,
                'smal_params_is_axis_angle': smal_params_is_axis_angle,
                '_trans': trans,
                'focal_length': np.array([self.focal_length, self.focal_length], dtype=np.float32),
                'category': np.array(category_idx, dtype=np.int32),
                'bbox_expand_factor': bbox_expand_factor,
                'supercategory': np.array(category_idx, dtype=np.int32),
                "img_border_mask": img_border_mask.astype(np.float32),
                'has_mask': np.array(1., dtype=np.float32),
                'imgname': key,
                'bbox': np.array(bbox, dtype=np.float32)}
        return item
class OptionAnimalDataset(Dataset):
    """Concatenation of one training dataset per entry configured under ``cfg.DATASETS``.

    ``self.weights`` holds one sampling weight per sample (the dataset's ``WEIGHT``
    repeated for each of its samples), in the same order as ``self.dataset``.
    """

    # Keys stored under cfg.DATASETS that configure the datasets rather than name one:
    # CONFIG is the augmentation config, the other two are scalar options that the
    # training dataset reads from this same node.
    _NON_DATASET_KEYS = frozenset({"CONFIG", "USE_SKIMAGE_ANTIALIAS", "BORDER_MODE"})

    def __init__(self, cfg: CfgNode):
        """
        Args:
            cfg (CfgNode): Config whose DATASETS node maps dataset names to nodes
                providing ROOT_IMAGE, JSON_FILE.TRAIN and WEIGHT.
        Raises:
            ValueError: If DATASETS contains no dataset entry.
        """
        super().__init__()
        datasets = []
        weights = []

        dataset_configs = cfg.DATASETS
        for dataset_name in dataset_configs:
            if dataset_name in self._NON_DATASET_KEYS:
                # Scalar options have no ROOT_IMAGE/JSON_FILE and must not be treated as datasets.
                continue
            dataset_cfg = dataset_configs[dataset_name]
            datasets.append(TrainDataset(cfg,
                                         is_train=True,
                                         root_image=dataset_cfg.ROOT_IMAGE,
                                         json_file=dataset_cfg.JSON_FILE.TRAIN))
            weights.extend([dataset_cfg.WEIGHT] * len(datasets[-1]))

        if not datasets:
            raise ValueError("No datasets enabled in the configuration.")

        # Concatenate all enabled datasets
        self.dataset = ConcatDataset(datasets)
        self.weights = torch.tensor(weights, dtype=torch.float32)

    def __len__(self):
        """Total number of samples across all concatenated datasets."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` of the concatenated dataset."""
        return self.dataset[idx]
keypoints) to COCO format (26 keypoints ), also image should be extracted from the raw video to save as frames. + +Usage: + python dlc2coco.py --dataset_dir /path/to/dataset --output_dir /path/to/output + +for camera x +dlc keypoint data: //fte_pw/camx_fte.csv, + where video frame index from the video and keypoint coordinates are stored +raw video: //camx.mp4 + +for coco format, please refer to: + ./datasets/quadruped2d/test.json + +also, the relationship of multiview should be saved. + + +keypoint mapping from acinoset to animal3d : +keypoint_mapping = {"acinoset":[2, 1, -1, 13, 10, 19, 16, 5, -1, -1, -1, -1, 11, 8, 12, 9, 18, 15, 3, 7, -1,-1,-1,-1, 0, 6]} + + +''' + +import argparse +import os +import json +import cv2 +import numpy as np +import pandas as pd +from pathlib import Path +from tqdm import tqdm + +# DLC keypoints (20 keypoints from acinoset): +# 0: nose, 1: r_eye, 2: l_eye, 3: neck_base, 4: spine, 5: tail_base, 6: tail1, 7: tail2, +# 8: r_shoulder, 9: r_front_knee, 10: r_front_ankle, 11: l_shoulder, 12: l_front_knee, 13: l_front_ankle, +# 14: r_hip, 15: r_back_knee, 16: r_back_ankle, 17: l_hip, 18: l_back_knee, 19: l_back_ankle + +# Animal3D keypoints (26 keypoints): +# Based on the mapping: [2, 1, -1, 13, 10, 19, 16, 5, -1, -1, -1, -1, 11, 8, 12, 9, 18, 15, 3, 7, -1,-1,-1,-1, 0, 6] +# This means: animal3d_idx 0 maps to acinoset_idx 2 (l_eye), animal3d_idx 1 maps to acinoset_idx 1 (r_eye), etc. 
# Keypoint mapping from acinoset (DLC) to animal3d (COCO format):
# entry i is the acinoset index feeding animal3d keypoint i, or -1 when there is no counterpart.
KEYPOINT_MAPPING = [2, 1, -1, 13, 10, 19, 16, 5, -1, -1, -1, -1, 11, 8, 12, 9, 18, 15, 3, 7, -1, -1, -1, -1, 0, 6]


def read_dlc_csv(csv_path, num_keypoints=20):
    """
    Read a DeepLabCut CSV file and extract per-frame keypoint data.

    The first two rows of the file are skipped and the third is used as the header.
    Column 0 is the frame index; it is followed by ``num_keypoints`` triplets of
    (x, y, likelihood) columns. Missing values (NaN) are replaced by 0.

    Args:
        csv_path: Path of the CSV file.
        num_keypoints (int): Number of keypoints stored per frame (default 20).
    Returns:
        list of dicts ``{'frame_idx': int, 'keypoints': [[x, y, likelihood], ...]}``,
        one per CSV row.
    Raises:
        ValueError: If the file has fewer than ``1 + 3 * num_keypoints`` columns.
    """
    df = pd.read_csv(csv_path, skiprows=2).fillna(0)

    n_cols = 3 * num_keypoints
    if df.shape[1] < 1 + n_cols:
        raise ValueError(
            f"{csv_path}: expected at least {1 + n_cols} columns, found {df.shape[1]}")

    frame_indices = df.iloc[:, 0].values
    # One vectorized extraction instead of a per-cell lookup; np.array makes a writable copy.
    triplets = np.array(df.iloc[:, 1:1 + n_cols], dtype=float).reshape(len(df), num_keypoints, 3)

    xs, ys, likelihoods = triplets[..., 0], triplets[..., 1], triplets[..., 2]
    # A zero likelihood (e.g. from NaN) with non-zero coordinates is assumed to be a
    # valid point and gets a default high confidence.
    recovered = (likelihoods == 0) & ((xs != 0) | (ys != 0))
    triplets[recovered, 2] = 1.0

    return [
        {'frame_idx': int(frame_idx), 'keypoints': frame_keypoints.tolist()}
        for frame_idx, frame_keypoints in zip(frame_indices, triplets)
    ]


def map_keypoints_to_animal3d(acinoset_keypoints):
    """
    Map DLC (acinoset) keypoints to the 26 Animal3D keypoints via KEYPOINT_MAPPING.

    Args:
        acinoset_keypoints: list of [x, y, likelihood], indexed by acinoset keypoint.
    Returns:
        list of 26 [x, y, visibility] entries. Visibility is 2.0 when the mapped
        keypoint has a non-zero coordinate and 0.0 otherwise (including keypoints
        without an acinoset counterpart); the likelihood value itself is not used.
    """
    animal3d_keypoints = []

    for acinoset_idx in KEYPOINT_MAPPING:
        if acinoset_idx == -1:
            # No acinoset counterpart: not labeled.
            animal3d_keypoints.append([0.0, 0.0, 0.0])
            continue

        x, y, _likelihood = acinoset_keypoints[acinoset_idx]
        # NaN coordinates count as missing.
        x = 0.0 if np.isnan(x) else x
        y = 0.0 if np.isnan(y) else y

        # Visibility flag (2 = visible, 0 = not labeled): any non-zero coordinate is "visible".
        visibility = 2.0 if (x != 0.0 or y != 0.0) else 0.0
        animal3d_keypoints.append([float(x), float(y), visibility])

    return animal3d_keypoints
def extract_frames_from_video(video_path, output_dir, frame_indices, behavior, camera_id):
    """
    Extract specific frames from a video and save them as JPEG images.

    Args:
        video_path: Path of the video file.
        output_dir: Directory the frames are written to (created if needed).
        frame_indices: Frame numbers to extract; duplicates are ignored.
        behavior: Behavior name, used in the image file names.
        camera_id: Camera id, used in the image file names.
    Returns:
        dict mapping frame index to the saved image path, expressed relative to the
        grandparent of ``output_dir``. Frames that could not be read or written are
        left out; the dict is empty when the video cannot be opened.
    """
    video_path = Path(video_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print(f"Error: Cannot open video {video_path}")
        cap.release()
        return {}

    frame_paths = {}

    # Sorted and de-duplicated so the video is traversed front to back.
    sorted_frames = sorted(set(frame_indices))

    try:
        with tqdm(total=len(sorted_frames), desc=f"Extracting frames from {video_path.name}") as pbar:
            for target_frame in sorted_frames:
                # Seek to the requested frame before reading it.
                cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
                ret, frame = cap.read()

                if ret and frame is not None:
                    img_filename = f"{behavior}_cam{camera_id}_frame_{target_frame:06d}.jpg"
                    img_path = output_dir / img_filename
                    # Only report frames that actually reached the disk.
                    if cv2.imwrite(str(img_path), frame):
                        frame_paths[target_frame] = str(img_path.relative_to(output_dir.parent.parent))
                    else:
                        print(f"Warning: Failed to write frame {target_frame} to {img_path}")
                else:
                    print(f"Warning: Failed to read frame {target_frame} from {video_path.name}")

                pbar.update(1)
    finally:
        # Release the capture even if seeking, reading or writing raises.
        cap.release()

    return frame_paths


def compute_bbox_from_keypoints(keypoints, padding=20):
    """
    Compute a padded bounding box around the labeled keypoints.

    Args:
        keypoints: list of [x, y, visibility]; only entries with visibility > 0 are used.
        padding: Margin added on every side of the tight box (default 20). The left
            and top edges are clamped at 0.
    Returns:
        [x, y, width, height]; [0, 0, 0, 0] when no keypoint is labeled.
    """
    valid_points = [(kp[0], kp[1]) for kp in keypoints if kp[2] > 0]

    if not valid_points:
        return [0, 0, 0, 0]

    xs, ys = zip(*valid_points)

    x_min = max(0, min(xs) - padding)
    y_min = max(0, min(ys) - padding)
    # Pad the far edges once; the size is measured from the already padded/clamped origin
    # so the margin is the same on both sides.
    x_max = max(xs) + padding
    y_max = max(ys) + padding

    return [float(x_min), float(y_min), float(x_max - x_min), float(y_max - y_min)]
def process_camera(camera_id, base_dir, output_dir, behavior):
    """
    Convert one camera of one behavior: read its CSV, extract its frames, build the entries.

    Args:
        camera_id: Camera id; selects ``fte_pw/cam<id>_fte.csv`` and ``cam<id>.mp4`` in ``base_dir``.
        base_dir: Directory of the behavior being processed.
        output_dir: Root output directory; frames go to ``images/<behavior>/cam<id>`` below it.
        behavior: name of the behavior (e.g., 'run', 'flick')
    Returns:
        list of entry dicts, one per frame whose image could be extracted.
    """
    source_root = Path(base_dir)
    target_root = Path(output_dir)

    csv_path = source_root / "fte_pw" / f"cam{camera_id}_fte.csv"
    video_path = source_root / f"cam{camera_id}.mp4"

    print(f"\nProcessing Camera {camera_id} - Behavior: {behavior}...")
    print(f"CSV: {csv_path}")
    print(f"Video: {video_path}")

    keypoints_data = read_dlc_csv(csv_path)
    print(f"Found {len(keypoints_data)} frames with keypoints")

    frame_paths = extract_frames_from_video(
        video_path,
        target_root / "images" / behavior / f"cam{camera_id}",
        [record['frame_idx'] for record in keypoints_data],
        behavior,
        camera_id,
    )

    coco_entries = []
    for record in tqdm(keypoints_data, desc=f"Converting cam{camera_id} to COCO format"):
        frame_idx = record['frame_idx']
        img_path = frame_paths.get(frame_idx)
        if img_path is None:
            # The frame could not be extracted from the video.
            continue

        # acinoset (20 keypoints) -> animal3d (26 keypoints), then a box around them.
        animal3d_kps = map_keypoints_to_animal3d(record['keypoints'])
        coco_entries.append({
            "img_path": img_path,
            "mask_path": img_path,  # Same as img_path
            "bbox": compute_bbox_from_keypoints(animal3d_kps),
            "keypoint_2d": animal3d_kps,
            "camera_id": camera_id,
            "frame_idx": frame_idx,
            "behavior": behavior
        })

    return coco_entries


def parse_args():
    """Parse the command line options of the DLC-to-COCO conversion script."""
    parser = argparse.ArgumentParser(
        description="Convert DeepLabCut labeled data to COCO format"
    )
    option_specs = (
        ("--dataset_dir", dict(
            type=str, default=".",
            help="Root directory containing behavior subdirectories (run, flick, etc.)")),
        ("--output_dir", dict(
            type=str, default=None,
            help="Output directory for COCO format data (default: {dataset_dir}/coco_format)")),
        ("--behaviors", dict(
            type=str, nargs="+", default=["run", "flick"],
            help="Behavior names to process (default: run flick)")),
        ("--cameras", dict(
            type=int, nargs="+", default=[1, 2, 3, 4, 5, 6],
            help="Camera IDs to process (default: 1 2 3 4 5 6)")),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def main():
    """Convert every requested behavior/camera pair and write the resulting JSON files.

    Written to the output directory: ``all_data.json`` (everything), ``<behavior>.json``,
    ``<behavior>_cam<id>.json`` and ``multiview_mapping.json`` (entries grouped by
    behavior, then frame index, then camera).
    """
    args = parse_args()

    dataset_dir = Path(args.dataset_dir)
    output_dir = Path(args.output_dir) if args.output_dir else dataset_dir / "coco_format"
    output_dir.mkdir(parents=True, exist_ok=True)

    behaviors = args.behaviors
    camera_ids = args.cameras
    banner = '=' * 60

    def _dump(path, payload):
        # All outputs share the same JSON formatting.
        with open(path, 'w') as f:
            json.dump(payload, f, indent=4)

    all_data = []
    behavior_data = {}
    camera_data = {}

    for behavior in behaviors:
        behavior_dir = dataset_dir / behavior
        behavior_data[behavior] = []

        print(f"\n{banner}")
        print(f"Processing Behavior: {behavior.upper()}")
        print(f"{banner}")

        for cam_id in camera_ids:
            coco_data = process_camera(cam_id, behavior_dir, output_dir, behavior)
            all_data.extend(coco_data)
            behavior_data[behavior].extend(coco_data)
            # Per-camera-behavior entries, keyed like the output file name.
            camera_data[f"{behavior}_cam{cam_id}"] = coco_data

    # Combined data (all behaviors and cameras)
    output_json = output_dir / "all_data.json"
    _dump(output_json, {"data": all_data})

    print(f"\n{banner}")
    print("SUMMARY")
    print(f"{banner}")
    print(f"Saved combined data to {output_json}")
    print(f"Total entries: {len(all_data)}")

    # Per-behavior data
    for behavior in behaviors:
        behavior_json = output_dir / f"{behavior}.json"
        _dump(behavior_json, {"data": behavior_data[behavior]})
        print(f"\nSaved {behavior} data to {behavior_json} ({len(behavior_data[behavior])} entries)")

    # Per-camera-behavior data
    for behavior in behaviors:
        for cam_id in camera_ids:
            key = f"{behavior}_cam{cam_id}"
            _dump(output_dir / f"{behavior}_cam{cam_id}.json", {"data": camera_data[key]})
            print(f" - {behavior}_cam{cam_id}: {len(camera_data[key])} entries")

    # Multiview relationships: entries sharing behavior and frame index are views of the same instant.
    multiview_data = {}
    for entry in all_data:
        views = multiview_data.setdefault(entry['behavior'], {}).setdefault(entry['frame_idx'], {})
        views[f"cam{entry['camera_id']}"] = {
            "img_path": entry['img_path'],
            "keypoint_2d": entry['keypoint_2d'],
            "bbox": entry['bbox']
        }

    multiview_json = output_dir / "multiview_mapping.json"
    _dump(multiview_json, multiview_data)

    print(f"\nSaved multiview mapping to {multiview_json}")
    for behavior in behaviors:
        print(f" - {behavior}: {len(multiview_data.get(behavior, {}))} synchronized frames")

    print(f"\n{banner}")
    print("Conversion complete!")
    print(f"{banner}")


if __name__ == "__main__":
    main()
# ------------------------------------------------------------------
# Dataset root. Not referenced by the code below; the command line
# defaults spell out their "datasets/..." paths explicitly.
# ------------------------------------------------------------------
BASE_DIR = Path("datasets")


def split_multiview_data(input_json, output_dir, train_ratio=0.7, seed=42):
    """
    Split multiview mapping data into train and test sets.

    The frames of every behavior are shuffled and split independently, so all camera
    views of a frame always end up in the same split.

    Args:
        input_json: Path to multiview_mapping.json (behavior -> frame index -> views)
        output_dir: Directory to save train.json and test.json (created if needed)
        train_ratio: Ratio of training data (default 0.7 for 70%)
        seed: Random seed for reproducibility (seeds the global ``random`` module)
    """
    # Set random seed
    random.seed(seed)

    # Load data
    print(f"Loading data from {input_json}...")
    with open(input_json, 'r') as f:
        data = json.load(f)

    # Initialize train and test splits
    train_data = defaultdict(dict)
    test_data = defaultdict(dict)

    # Process each behavior
    for behavior, frames in data.items():
        print(f"\nProcessing behavior: {behavior}")

        frame_indices = list(frames.keys())
        total_frames = len(frame_indices)

        random.shuffle(frame_indices)

        # The first train_size shuffled frames go to train, the rest to test.
        train_size = int(total_frames * train_ratio)
        train_frames = frame_indices[:train_size]
        test_frames = frame_indices[train_size:]

        print(f" Total frames: {total_frames}")
        print(f" Train frames: {len(train_frames)}")
        print(f" Test frames: {len(test_frames)}")

        for frame_idx in train_frames:
            train_data[behavior][frame_idx] = frames[frame_idx]

        for frame_idx in test_frames:
            test_data[behavior][frame_idx] = frames[frame_idx]

    # Save train and test splits
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_json = output_dir / "train.json"
    test_json = output_dir / "test.json"

    print(f"\nSaving train data to {train_json}...")
    with open(train_json, 'w') as f:
        json.dump(dict(train_data), f, indent=4)

    print(f"Saving test data to {test_json}...")
    with open(test_json, 'w') as f:
        json.dump(dict(test_data), f, indent=4)

    # Print summary
    print("\n" + "="*50)
    print("Summary:")
    print("="*50)

    total_train_frames = sum(len(frames) for frames in train_data.values())
    total_test_frames = sum(len(frames) for frames in test_data.values())
    total_frames = total_train_frames + total_test_frames

    def _percent(count):
        # An empty input has no meaningful ratio; report 0 instead of dividing by zero.
        return count / total_frames * 100 if total_frames else 0.0

    print(f"Total frames: {total_frames}")
    print(f"Train frames: {total_train_frames} ({_percent(total_train_frames):.1f}%)")
    print(f"Test frames: {total_test_frames} ({_percent(total_test_frames):.1f}%)")
    print("\nPer behavior:")
    # Iterate the input so behaviors that received no train frames are still listed.
    for behavior in data:
        train_count = len(train_data.get(behavior, {}))
        test_count = len(test_data.get(behavior, {}))
        total_count = train_count + test_count
        print(f" {behavior}: train={train_count}, test={test_count}, total={total_count}")

    print("\nDone!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Split multiview_mapping.json into train/test sets (default 7:3)."
    )
    parser.add_argument(
        "--input_json", type=str,
        default="datasets/acinoset/multiview_mapping.json",
        help="Path to multiview_mapping.json (default: datasets/acinoset/multiview_mapping.json)."
    )
    parser.add_argument(
        "--output_dir", type=str,
        default="datasets/acinoset",
        help="Directory to save train.json and test.json (default: datasets/acinoset)."
    )
    parser.add_argument(
        "--train_ratio", type=float, default=0.7,
        help="Fraction of data for training (default: 0.7)."
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for reproducibility (default: 42)."
    )
    args = parser.parse_args()

    split_multiview_data(
        input_json=args.input_json,
        output_dir=args.output_dir,
        train_ratio=args.train_ratio,
        seed=args.seed,
    )
def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
    """Grow a (w, h) box until it matches the target (w, h) aspect ratio.

    The input is returned unchanged when no target ratio is given or when it cannot
    be unpacked into exactly two values; otherwise the result is ``np.array([w, h])``
    with one side enlarged and the other kept.
    """
    if target_aspect_ratio is None:
        return input_shape

    try:
        w, h = input_shape
    except (ValueError, TypeError):
        # Not a (w, h) pair: nothing to expand.
        return input_shape

    w_t, h_t = target_aspect_ratio
    box_is_too_wide = h / w < h_t / w_t
    if box_is_too_wide:
        # Keep the width, grow the height.
        w_new, h_new = w, max(w * h_t / w_t, h)
    else:
        # Keep the height, grow the width.
        w_new, h_new = max(h * w_t / h_t, w), h

    if h_new < h or w_new < w:
        raise ValueError(f"Expanded size ({w_new}, {h_new}) smaller than original ({w}, {h})")
    return np.array([w_new, h_new])
def do_augmentation(aug_config: CfgNode) -> Tuple:
    """
    Draw one set of random augmentation parameters.
    Args:
        aug_config (CfgNode): Config containing augmentation parameters.
    Returns:
        scale (float): Box rescaling factor.
        rot (float): Random image rotation (0 when rotation is not sampled).
        do_flip (bool): Whether to flip image or not.
        do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT).
        extreme_crop_lvl (int): Value of EXTREME_CROP_AUG_LEVEL from the config (0 if unset).
        color_scale (List): Color rescaling factor, one value per channel.
        tx (float): Random translation along the x axis.
        ty (float): Random translation along the y axis.
    """

    def _clipped_normal(limit):
        # Standard normal sample clamped to [-limit, limit].
        return np.clip(np.random.randn(), -limit, limit)

    tx = _clipped_normal(1.0) * aug_config.TRANS_FACTOR
    ty = _clipped_normal(1.0) * aug_config.TRANS_FACTOR
    scale = _clipped_normal(1.0) * aug_config.SCALE_FACTOR + 1.0

    # Rotation is only sampled for a ROT_AUG_RATE fraction of the calls.
    if random.random() <= aug_config.ROT_AUG_RATE:
        rot = _clipped_normal(2.0) * aug_config.ROT_FACTOR
    else:
        rot = 0

    do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
    do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
    extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)

    # Each channel gets its own factor in [1 - COLOR_SCALE, 1 + COLOR_SCALE].
    c_up = 1.0 + aug_config.COLOR_SCALE
    c_low = 1.0 - aug_config.COLOR_SCALE
    color_scale = [random.uniform(c_low, c_up) for _ in range(3)]
    return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty


def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
    """
    Rotate a 2D point about the origin of the x-y plane.
    Args:
        pt_2d (np.array): Input 2D point with shape (2,).
        rot_rad (float): Rotation angle in radians.
    Returns:
        np.array: Rotated 2D point (float32).
    """
    sin_a, cos_a = np.sin(rot_rad), np.cos(rot_rad)
    x, y = pt_2d[0], pt_2d[1]
    return np.array([x * cos_a - y * sin_a, x * sin_a + y * cos_a], dtype=np.float32)
def gen_trans_from_patch_cv(c_x: float, c_y: float,
                            src_width: float, src_height: float,
                            dst_width: float, dst_height: float,
                            scale: float, rot: float) -> np.array:
    """
    Build the affine transform that maps a (scaled, rotated) source box onto the output patch.
    Args:
        c_x (float): Bounding box center x coordinate in the original image.
        c_y (float): Bounding box center y coordinate in the original image.
        src_width (float): Bounding box width.
        src_height (float): Bounding box height.
        dst_width (float): Output box width.
        dst_height (float): Output box height.
        scale (float): Rescaling factor for the bounding box (augmentation).
        rot (float): Rotation applied to the box, in degrees.
    Returns:
        trans (np.array): Target geometric transformation.
    """
    # Source box: scaled, then rotated about its center.
    rot_rad = np.pi * rot / 180
    half_src_h = src_height * scale * 0.5
    half_src_w = src_width * scale * 0.5
    src_center = np.array([c_x, c_y], dtype=np.float64)
    src_downdir = rotate_2d(np.array([0, half_src_h], dtype=np.float32), rot_rad)
    src_rightdir = rotate_2d(np.array([half_src_w, 0], dtype=np.float32), rot_rad)

    # Destination patch: axis-aligned and centered.
    dst_center = np.array([dst_width * 0.5, dst_height * 0.5], dtype=np.float32)
    dst_downdir = np.array([0, dst_height * 0.5], dtype=np.float32)
    dst_rightdir = np.array([dst_width * 0.5, 0], dtype=np.float32)

    # Three point pairs (center, center + down, center + right) define the affine map.
    src = np.array([src_center,
                    src_center + src_downdir,
                    src_center + src_rightdir], dtype=np.float32)
    dst = np.array([dst_center,
                    dst_center + dst_downdir,
                    dst_center + dst_rightdir], dtype=np.float32)

    return cv2.getAffineTransform(np.float32(src), np.float32(dst))


def trans_point2d(pt_2d: np.array, trans: np.array):
    """
    Apply the transformation matrix ``trans`` to a 2D point.
    Args:
        pt_2d (np.array): Input 2D point with shape (2,).
        trans (np.array): Transformation matrix, applied to the point in homogeneous form.
    Returns:
        np.array: Transformed 2D point.
    """
    homogeneous = np.array([pt_2d[0], pt_2d[1], 1.]).T
    return np.dot(trans, homogeneous)[0:2]
def get_transform(center, scale, res, rot=0):
    """
    Generate the 3x3 transformation matrix from image coordinates to crop coordinates.

    Taken from PARE:
    https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py

    The box side is `200 * scale`; `res` is indexed as (res[0], res[1]) with res[1] driving the x axis
    and res[0] the y axis. A non-zero `rot` (degrees) is applied about the crop center.
    """
    box = 200 * scale
    mat = np.array([
        [float(res[1]) / box, 0, res[1] * (-float(center[0]) / box + .5)],
        [0, float(res[0]) / box, res[0] * (-float(center[1]) / box + .5)],
        [0, 0, 1],
    ], dtype=np.float64)
    if rot != 0:
        angle = -rot * np.pi / 180  # sign flipped to match the direction of rotation from cropping
        sn, cs = np.sin(angle), np.cos(angle)
        rotation = np.array([[cs, -sn, 0],
                             [sn, cs, 0],
                             [0, 0, 1]], dtype=np.float64)
        # Rotate around the crop center: shift to origin, rotate, shift back.
        to_origin = np.eye(3)
        to_origin[0, 2] = -res[1] / 2
        to_origin[1, 2] = -res[0] / 2
        from_origin = to_origin.copy()
        from_origin[:2, 2] *= -1
        mat = from_origin.dot(rotation.dot(to_origin.dot(mat)))
    return mat


def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
    """
    Transform a pixel location to a different reference frame.

    Taken from PARE:
    https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py

    `pt` is treated as 1-based: 1 is subtracted before applying the matrix and added back afterwards.
    With `invert` truthy the inverse of the matrix is used; with `as_int` the result is truncated to int.
    """
    mat = get_transform(center, scale, res, rot=rot)
    if invert:
        mat = np.linalg.inv(mat)
    homogeneous = np.array([pt[0] - 1, pt[1] - 1, 1.])
    mapped = mat.dot(homogeneous)
    if as_int:
        mapped = mapped.astype(int)
    return mapped[:2] + 1
def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float,
                                 bb_width: float, bb_height: float,
                                 patch_width: float, patch_height: float,
                                 do_flip: bool, scale: float, rot: float,
                                 border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
    """
    Crop image according to the supplied bounding box (crop -> optional rotate -> resize).
    Only square boxes and square patches are supported.
    Args:
        img (np.array): Input image with three dimensions (H, W, C).
        c_x (float): Bounding box center x coordinate in the original image.
        c_y (float): Bounding box center y coordinate in the original image.
        bb_width (float): Bounding box width.
        bb_height (float): Bounding box height; must equal bb_width.
        patch_width (float): Output box width.
        patch_height (float): Output box height; must equal patch_width.
        do_flip (bool): Whether to flip image or not.
        scale (float): Rescaling factor for the bounding box (augmentation).
        rot (float): Random rotation applied to the box.
        border_mode: Border mode forwarded to crop_img.
        border_value: Border value forwarded to crop_img.
    Returns:
        img_patch (np.array): uint8 image patch resized to (patch_width, patch_height).
        trans (np.array): Affine matrix from gen_trans_from_patch_cv for the same box. It is
            returned for the caller's use and is not what produces img_patch here.
    Raises:
        AssertionError: If the box or the patch is not square.
        RuntimeError: If the rotation padding cannot be computed from the crop corners.
        ValueError: If the cropped patch ends up with an empty spatial dimension.
    """
    _, img_width, _ = img.shape
    if do_flip:
        img = img[:, ::-1, :]
        c_x = img_width - c_x - 1

    trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)

    center = np.zeros(2)
    center[0] = c_x
    center[1] = c_y
    res = np.zeros(2)
    res[0] = patch_width
    res[1] = patch_height
    # The PARE-style transform below works with a single box scale, hence the square-only restriction.
    assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
    assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'
    scale1 = scale * bb_width / 200.

    # Upper-left and bottom-right crop corners in original-image coordinates.
    ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
    br = np.array(transform([res[0] + 1,
                             res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1

    # Padding so that when rotated proper amount of context is included
    try:
        pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
    except Exception as e:
        raise RuntimeError(f"Failed to compute pad: ul={ul}, br={br}") from e
    if not rot == 0:
        ul -= pad
        br += pad

    new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)

    if not rot == 0:
        new_img = rotate(new_img, rot)
        # Remove the padding that was added only to give the rotation enough context.
        new_img = new_img[pad:-pad, pad:-pad]

    if new_img.shape[0] < 1 or new_img.shape[1] < 1:
        raise ValueError(
            f"Image patch too small: {new_img.shape}, original: {img.shape}, "
            f"ul={ul}, br={br}, pad={pad}, rot={rot}"
        )

    # resize image
    new_img = resize(new_img, res)

    new_img = np.clip(new_img, 0, 255).astype(np.uint8)

    return new_img, trans
def convert_cvimg_to_tensor(cvimg: np.array):
    """
    Reorder an image from HWC to CHW layout and cast it to float32.
    Args:
        cvimg (np.array): Image of shape (H, W, C), e.g. as loaded by OpenCV.
    Returns:
        np.array: float32 array of shape (C, H, W); the input is left untouched.
    """
    channels_first = np.transpose(cvimg.copy(), (2, 0, 1))
    return channels_first.astype(np.float32)
def keypoint_3d_processing(keypoints_3d: np.array, rot: float, flip: bool) -> np.array:
    """
    Process 3D keypoints (in-plane rotation, then optional horizontal flip).
    Args:
        keypoints_3d (np.array): Input array of shape (N, 4) containing the 3D keypoints and confidence.
            The input array is not modified.
        rot (float): Rotation in degrees applied to the keypoints (about the z axis, using -rot).
        flip (bool): Whether to mirror the keypoints horizontally (negate the x coordinate).
    Returns:
        np.array: Transformed 3D keypoints with shape (N, 4) and dtype float32.
    """
    # Work on a copy: the rotation below assigns into the array and must not
    # silently alter the caller's annotations.
    keypoints_3d = np.array(keypoints_3d, copy=True)

    # in-plane rotation
    if not rot == 0:
        rot_mat = np.eye(3, dtype=np.float32)
        rot_rad = -rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])

    # flip the x coordinates
    if flip:
        # Same result as fliplr_keypoints(keypoints_3d, 1, identity_permutation):
        # x -> 1 - x - 1 = -x. The previous call omitted the `width` argument and
        # therefore always raised TypeError when flipping.
        keypoints_3d[:, 0] = -keypoints_3d[:, 0]

    return keypoints_3d.astype('float32')
def get_example(img_path: Union[str, np.ndarray], center_x: float, center_y: float,
                width: float, height: float,
                keypoints_2d: np.array, keypoints_3d: np.array,
                smal_params: Dict, has_smal_params: Dict,
                patch_width: int, patch_height: int,
                mean: np.array, std: np.array,
                do_augment: bool, augm_config: CfgNode,
                is_bgr: bool = True,
                use_skimage_antialias: bool = False,
                border_mode: int = cv2.BORDER_CONSTANT,
                return_trans: bool = False,) -> Tuple:
    """
    Get an example from the dataset and (possibly) apply random augmentations.
    Args:
        img_path (str or np.ndarray): Image filename, or an already loaded image array of shape (H, W, C).
        center_x (float): Bounding box center x coordinate in the original image.
        center_y (float): Bounding box center y coordinate in the original image.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image
            coordinates. The input array is not modified.
        keypoints_3d (np.array): Array with shape (N,4) containing the 3D keypoints.
        smal_params (Dict): SMAL parameter annotations.
        has_smal_params (Dict): Whether SMAL annotations are valid.
        patch_width (int): Output box width.
        patch_height (int): Output box height.
        mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
        std (np.array): Array of shape (3,) containing the std for normalizing the input image.
        do_augment (bool): Whether to apply data augmentation or not.
        augm_config (CfgNode): Config containing augmentation parameters.
        is_bgr (bool): If True, the channel order of the patch is reversed before conversion.
        use_skimage_antialias (bool): If True, the image may be Gaussian-blurred before cropping.
        border_mode (int): OpenCV border mode used when warping the patch.
        return_trans (bool): If True, the affine matrix of the crop is included in the result.
    Returns:
        img_patch (np.array): Cropped, channel-first float32 image patch.
        keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints,
            expressed relative to the patch (divided by patch_width, shifted by -0.5).
        keypoints_3d (np.array): Array with shape (N,4) containing the transformed 3D keypoints.
        smal_params (Dict): Transformed SMAL parameters.
        has_smal_params (Dict): Valid flag for transformed SMAL parameters.
        img_size (np.array): (height, width) of the original image.
        trans (np.array): Affine crop matrix; only present when return_trans is True.
        img_border_mask (np.array): Border mask returned by generate_image_patch_cv2.
    Raises:
        IOError: If the image file cannot be read.
        TypeError: If img_path is neither a string nor a numpy array.
        ValueError: If extreme cropping is requested with a level other than 0 or 1.
    """
    # 1. load image
    if isinstance(img_path, str):
        cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        if not isinstance(cvimg, np.ndarray):
            raise IOError("Fail to read %s" % img_path)
    elif isinstance(img_path, np.ndarray):
        cvimg = img_path
    else:
        raise TypeError('img_path must be either a string or a numpy array')
    img_height, img_width, img_channels = cvimg.shape

    img_size = np.array([img_height, img_width], dtype=np.int32)

    # 2. get augmentation params
    if do_augment:
        # box rescale factor, rotation angle, flip or not flip, crop or not crop, ..., color scale, translation x, ...
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
    else:
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl = 1.0, 0, False, False, 0
        color_scale, tx, ty = [1.0, 1.0, 1.0], 0., 0.

    if width < 1 or height < 1:
        # Degenerate boxes are clamped rather than skipped.
        print(f"Warning: Invalid bbox size - width: {width}, height: {height}. Using default size.")
        width = max(width, 1.0)
        height = max(height, 1.0)

    if do_extreme_crop:
        if extreme_crop_lvl == 0:
            crop_fn = extreme_cropping
        elif extreme_crop_lvl == 1:
            crop_fn = extreme_cropping_aggressive
        else:
            # Previously this fell through and failed later with an unbound-variable error.
            raise ValueError(f"Unsupported extreme crop level: {extreme_crop_lvl} (expected 0 or 1)")
        center_x1, center_y1, width1, height1 = crop_fn(center_x, center_y, width, height, keypoints_2d)

        # Keep the original box when the extreme crop collapses to a tiny region.
        THRESH = 4
        if not (width1 < THRESH or height1 < THRESH):
            center_x, center_y, width, height = center_x1, center_y1, width1, height1

    center_x += width * tx
    center_y += height * ty

    # Process 3D keypoints
    keypoints_3d = keypoint_3d_processing(keypoints_3d, rot, do_flip)

    # 3. generate image patch
    if use_skimage_antialias:
        # Blur image to avoid aliasing artifacts
        downsampling_factor = (patch_width / (width * scale))
        if downsampling_factor > 1.1:
            cvimg = gaussian(cvimg, sigma=(downsampling_factor - 1) / 2, channel_axis=2, preserve_range=True,
                             truncate=3.0)
    # augmentation image, translation matrix
    img_patch_cv, trans, img_border_mask = generate_image_patch_cv2(cvimg,
                                                                    center_x, center_y,
                                                                    width, height,
                                                                    patch_width, patch_height,
                                                                    do_flip, scale, rot,
                                                                    border_mode=border_mode)

    # Reverse the channel order if requested; convert_cvimg_to_tensor copies, so a view is enough here.
    image = img_patch_cv[:, :, ::-1] if is_bgr else img_patch_cv
    img_patch = convert_cvimg_to_tensor(image)  # [h, w, c] -> [c, h, w]

    smal_params, has_smal_params = smal_param_processing(smal_params, has_smal_params, rot, do_flip)

    # apply color jitter and normalization to (at most) the first three channels
    for n_c in range(min(img_channels, 3)):
        img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
        if mean is not None and std is not None:
            img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]

    if do_flip:
        keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, list(range(len(keypoints_2d))))
    else:
        # fliplr_keypoints already returns a copy; copy here too so the caller's
        # array is never overwritten by the transformed coordinates.
        keypoints_2d = keypoints_2d.copy()

    for n_jt in range(len(keypoints_2d)):
        keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
    keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5

    if not return_trans:
        return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, img_border_mask
    else:
        return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, trans, img_border_mask
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config) + else: + scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0, + 1.0, + 1.0], 0., 0. + # bounding box height and width + center_x += width * tx + center_y += height * ty + # augmentation image, translation matrix + img_patch_cv, trans, img_border_mask = generate_image_patch_cv2(cvimg, + center_x, center_y, + width, height, + patch_width, patch_height, + do_flip, scale, rot, + border_mode=cv2.BORDER_CONSTANT) + + image = img_patch_cv.copy() + img_patch = convert_cvimg_to_tensor(image) # [h, w, 4] -> [4, h, w] + + # apply normalization + for n_c in range(min(img_channels, 3)): + img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255) + if mean is not None and std is not None: + img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c] + + if do_flip: + keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, list(range(len(keypoints_2d)))) + + for n_jt in range(len(keypoints_2d)): + keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans) + keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5 + + if return_trans: + return img_patch, keypoints_2d, img_size, trans, img_border_mask + else: + return img_patch, keypoints_2d, img_size, img_border_mask + + +def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple: + """ + Extreme cropping: Crop the box up to the hip locations. + Args: + center_x (float): x coordinate of the bounding box center. + center_y (float): y coordinate of the bounding box center. + width (float): Bounding box width. + height (float): Bounding box height. + keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations. + Returns: + center_x (float): x coordinate of the new bounding box center. 
def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: Crop the box up to the shoulder locations.
    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
            The input array is not modified.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
        The input box is returned unchanged when the summed confidence of the remaining
        keypoints is not greater than 1.
    """
    keypoints_2d = keypoints_2d.copy()
    lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in
                                                                                             [0, 1, 2, 3, 4, 5, 6, 7,
                                                                                              10, 11, 14, 15, 16]]
    keypoints_2d[lower_body_keypoints, :] = 0
    # get_bbox is only called behind the visibility guard: it reduces over the valid
    # keypoints and raises when none are left.
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x = center[0]
        center_y = center[1]
        width = 1.2 * scale[0]
        height = 1.2 * scale[1]
    return center_x, center_y, width, height
def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: shrink the box to the keypoints that remain after the non-torso rows are zeroed.
    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
            A copy is masked; the input is left untouched.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    base_ids = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
    offset_ids = [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]
    masked = keypoints_2d.copy()
    masked[base_ids + [25 + k for k in offset_ids], :] = 0

    # Not enough remaining confidence: keep the box as it was.
    if not masked[:, -1].sum() > 1:
        return center_x, center_y, width, height

    center, scale = get_bbox(masked)
    return center[0], center[1], 1.1 * scale[0], 1.1 * scale[1]
def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: shrink the box to the keypoints that remain after the non-leg rows are zeroed.
    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
            A copy is masked; the input is left untouched.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    base_ids = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18]
    offset_ids = [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]
    masked = keypoints_2d.copy()
    masked[base_ids + [25 + k for k in offset_ids], :] = 0

    # Not enough remaining confidence: keep the box as it was.
    if not masked[:, -1].sum() > 1:
        return center_x, center_y, width, height

    center, scale = get_bbox(masked)
    return center[0], center[1], 1.1 * scale[0], 1.1 * scale[1]
def full_body(keypoints_2d: np.array) -> bool:
    """
    Check if all main body joints are visible.

    Each main joint has two candidate rows (one below index 25, one at 25 + i); the joint
    counts as visible when the larger of the two confidences is positive.
    Args:
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        bool: True if all main body joints are visible.
    """
    primary_rows = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
    alternate_rows = [25 + k for k in (8, 7, 6, 9, 10, 11, 1, 0, 4, 5)]
    confidence = keypoints_2d[:, -1]
    best = np.maximum(confidence[alternate_rows], confidence[primary_rows])
    return (best > 0).all()
def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple:
    """
    Get center and scale for bounding box from openpose detections.

    Only keypoints whose confidence (last column) is strictly positive take
    part in the computation.
    Args:
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
        rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
    Returns:
        center (np.array): Array of shape (2,) containing the new bounding box center.
        scale (np.array): Array of shape (2,) with the rescaled box extent along each axis.
    Raises:
        ValueError: If no keypoint has a positive confidence.
    """
    visible = keypoints_2d[:, -1] > 0
    if not visible.any():
        # numpy would raise an opaque "zero-size array to reduction" ValueError
        # here; keep the exception type but say what actually went wrong.
        raise ValueError("get_bbox requires at least one keypoint with positive confidence")

    coords = keypoints_2d[visible][:, :-1]
    lower = coords.min(axis=0)
    upper = coords.max(axis=0)
    center = 0.5 * (upper + lower)
    # adjust bounding box tightness. Out-of-place multiply on purpose: an
    # in-place `*=` cannot store a float result in an integer array and fails
    # for integer keypoints.
    scale = (upper - lower) * rescale
    return center, scale
def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float,
                                keypoints_2d: np.array) -> Tuple:
    """
    Perform aggressive extreme cropping.

    A single uniform random number selects which crop is applied. The set of
    candidate crops depends on whether the full body or only the upper body
    is visible; when neither is the case the box is only squared.
    Args:
        center_x (float): x coordinate of bounding box center.
        center_y (float): y coordinate of bounding box center.
        width (float): bounding box width.
        height (float): bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of bounding box center.
        center_y (float): y coordinate of bounding box center.
        width (float): bounding box width (equal to the larger side).
        height (float): bounding box height (equal to the larger side).
    """
    p = torch.rand(1).item()

    # Each schedule lists (upper bound on p, crop function) in ascending
    # order; the first entry whose bound exceeds p wins, otherwise `fallback`.
    if full_body(keypoints_2d):
        schedule = (
            (0.2, crop_to_hips),
            (0.3, crop_to_shoulders),
            (0.4, crop_to_head),
            (0.5, crop_torso_only),
            (0.6, crop_rightarm_only),
            (0.7, crop_leftarm_only),
            (0.8, crop_legs_only),
            (0.9, crop_rightleg_only),
        )
        fallback = crop_leftleg_only
    elif upper_body(keypoints_2d):
        schedule = (
            (0.2, crop_to_shoulders),
            (0.4, crop_to_head),
            (0.6, crop_torso_only),
            (0.8, crop_rightarm_only),
        )
        fallback = crop_leftarm_only
    else:
        schedule = ()
        fallback = None

    if fallback is not None:
        crop_fn = next((fn for bound, fn in schedule if p < bound), fallback)
        center_x, center_y, width, height = crop_fn(center_x, center_y, width, height, keypoints_2d)

    side = max(width, height)
    return center_x, center_y, side, side
class ViTDetDataset(torch.utils.data.Dataset):
    """
    Inference-only dataset that yields one normalised square crop per detection box.

    Args:
        cfg (CfgNode): Model config; reads MODEL.IMAGE_SIZE, MODEL.IMAGE_MEAN,
            MODEL.IMAGE_STD, the optional MODEL.BBOX_SHAPE and EXTRA.FOCAL_LENGTH.
        img_cv2 (np.array): Source image (presumably BGR as loaded by OpenCV -- confirm with caller).
        boxes (np.array): Array of shape (N, 4+) whose first four columns are x1, y1, x2, y2.
        rescale_factor: Multiplier applied to the box extent.
        train (bool): Must be False.
    """

    def __init__(self,
                 cfg: CfgNode,
                 img_cv2: np.array,
                 boxes: np.array,
                 rescale_factor=1,
                 train: bool = False,
                 **kwargs):
        super().__init__()
        self.cfg = cfg
        self.img_cv2 = img_cv2
        self.boxes = boxes

        assert train is False, "ViTDetDataset is only for inference"
        self.train = train
        self.img_size = cfg.MODEL.IMAGE_SIZE
        # Config statistics are in [0, 1]; the patches are in [0, 255].
        self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
        self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)

        # Turn corner boxes into center / scale (scale is in units of 200 px).
        corners = boxes.astype(np.float32)
        top_left, bottom_right = corners[:, 0:2], corners[:, 2:4]
        self.center = (bottom_right + top_left) / 2.0
        self.scale = rescale_factor * (bottom_right - top_left) / 200.0
        self.animalid = np.arange(len(corners), dtype=np.int32)

    def __len__(self) -> int:
        return len(self.animalid)

    def __getitem__(self, idx: int) -> Dict[str, np.array]:
        """Return the crop for box `idx` together with its box / camera metadata."""
        box_center = self.center[idx].copy()
        target_shape = self.cfg.MODEL.get('BBOX_SHAPE', None)
        bbox_size = expand_to_aspect_ratio(self.scale[idx] * 200, target_aspect_ratio=target_shape).max()
        patch_side = self.img_size

        cvimg = self.img_cv2.copy()
        # Blur image to avoid aliasing artifacts when the crop is shrunk a lot.
        shrink = ((bbox_size * 1.0) / patch_side) / 2.0
        if shrink > 1.1:
            cvimg = gaussian(cvimg, sigma=(shrink - 1) / 2, channel_axis=2, preserve_range=True)

        # Square crop, no flip, unit scale, no rotation.
        patch_cv, _, _ = generate_image_patch_cv2(cvimg,
                                                  box_center[0], box_center[1],
                                                  bbox_size, bbox_size,
                                                  patch_side, patch_side,
                                                  False, 1.0, 0.0,
                                                  border_mode=cv2.BORDER_CONSTANT)
        # Reverse the channel order before converting to a tensor layout.
        img_patch = convert_cvimg_to_tensor(patch_cv[:, :, ::-1])

        # apply normalization channel by channel
        for channel in range(min(self.img_cv2.shape[2], 3)):
            img_patch[channel, :, :] = (img_patch[channel, :, :] - self.mean[channel]) / self.std[channel]

        focal = self.cfg.EXTRA.FOCAL_LENGTH
        return {
            'img': img_patch,
            'animalid': int(self.animalid[idx]),
            'box_center': self.center[idx].copy(),
            'box_size': bbox_size,
            'img_size': 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]]),
            'focal_length': np.array([focal, focal]),
        }
def load_prima(checkpoint_path):
    """
    Load a PRIMA checkpoint together with its config.

    The config is read from `.hydra/config.yaml`, two directory levels above
    the checkpoint file. It is patched in place before the model is built.
    Args:
        checkpoint_path: Path to the checkpoint file.
    Returns:
        model: PRIMA model loaded on CPU without the training renderer.
        model_cfg: The (patched, frozen) config used to build it.
    """
    from pathlib import Path
    from ..configs import get_config

    config_file = Path(checkpoint_path).parent.parent / '.hydra/config.yaml'
    model_cfg = get_config(str(config_file))

    # Override some config values, to crop bbox correctly.
    # backbone type -> (required IMAGE_SIZE, name used in the assertion, BBOX_SHAPE)
    bbox_defaults = {
        'vit': (256, 'ViT', [192, 256]),
        'dinov3': (256, 'dino', [256, 256]),
        'dinov2': (252, 'dino', [252, 252]),
    }
    backbone_type = model_cfg.MODEL.BACKBONE.TYPE
    if backbone_type in bbox_defaults and 'BBOX_SHAPE' not in model_cfg.MODEL:
        required_size, backbone_label, bbox_shape = bbox_defaults[backbone_type]
        model_cfg.defrost()
        assert model_cfg.MODEL.IMAGE_SIZE == required_size, \
            f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be {required_size} for {backbone_label} backbone"
        model_cfg.MODEL.BBOX_SHAPE = bbox_shape
        model_cfg.freeze()

    # Update config to be compatible with demo
    if 'PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE:
        model_cfg.defrost()
        model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS')
        model_cfg.freeze()

    # Offscreen training renderer is not needed for demo/inference startup and
    # can fail on some local OpenGL backends.
    load_kwargs = dict(
        strict=False,
        cfg=model_cfg,
        map_location='cpu',
        init_renderer=False,
    )
    model = PRIMA.load_from_checkpoint(checkpoint_path, **load_kwargs)
    return model, model_cfg
def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
    """
    Resize absolute positional embeddings to a new token grid.

    The grid part of the embedding is bicubically interpolated from
    (ori_h, ori_w) to (h, w) when the two differ. A leading cls-token
    embedding, if present, is passed through untouched and put back in front.
    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        h (int): target grid height.
        w (int): target grid width.
        ori_h (int): grid height the embeddings were created for.
        ori_w (int): grid width the embeddings were created for.
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.

    Returns:
        Positional embeddings of shape (B, h * w [+ 1 if has_cls_token], C).
    """
    batch, _, channels = abs_pos.shape

    if has_cls_token:
        cls_part, grid = abs_pos[:, 0:1], abs_pos[:, 1:]
    else:
        cls_part, grid = None, abs_pos

    if (ori_h, ori_w) != (h, w):
        # interpolate works on (N, C, H, W), so go channels-first and back.
        as_image = grid.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2)
        resized = F.interpolate(
            as_image,
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        )
        grid = resized.permute(0, 2, 3, 1).reshape(batch, -1, channels)

    if cls_part is None:
        return grid
    return torch.cat([cls_part, grid], dim=1)
class Attention(nn.Module):
    """
    Multi-head self-attention with a fused query/key/value projection.

    Args:
        dim: Channel size of the input and output tokens.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused qkv projection has a bias.
        qk_scale: Optional override for the query scaling; defaults to head_dim ** -0.5.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the output projection.
        attn_head_dim: Optional per-head width; defaults to dim // num_heads.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim

        per_head = dim // num_heads if attn_head_dim is None else attn_head_dim
        inner_dim = per_head * self.num_heads

        self.scale = qk_scale or per_head ** -0.5

        self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, _ = x.shape

        # (B, N, 3 * heads * d) -> (3, B, heads, N, d)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, -1)
        packed = packed.permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        query, key, value = packed[0], packed[1], packed[2]

        weights = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(weights.softmax(dim=-1))

        # merge the heads back: (B, heads, N, d) -> (B, N, heads * d)
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, -1)
        return self.proj_drop(self.proj(merged))
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    A single strided convolution turns the image into a grid of patch tokens.
    `ratio` divides the stride, so values above 1 give overlapping patches.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        # Token grid for non-overlapping patches.
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]

        self.patch_shape = (int(grid_h * ratio), int(grid_w * ratio))
        self.origin_patch_shape = (int(grid_h), int(grid_w))
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_w * grid_h * (ratio ** 2)

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=(patch_size[0] // ratio),
            padding=4 + 2 * (ratio // 2 - 1),
        )

    def forward(self, x, **kwargs):
        """Return the flattened tokens (B, Hp * Wp, C) and the grid size (Hp, Wp)."""
        _batch, _chans, _height, _width = x.shape  # input must be 4-D
        feature_map = self.proj(x)
        grid_hw = (feature_map.shape[2], feature_map.shape[3])
        tokens = feature_map.flatten(2).transpose(1, 2)
        return tokens, grid_hw
class ViT(nn.Module):
    """
    Plain Vision Transformer backbone returning a feature map and an optional class token.

    The image is patchified (``PatchEmbed``, or ``HybridEmbed`` when a
    ``hybrid_backbone`` is given), summed with a learned positional embedding,
    optionally prefixed with a class token, passed through ``depth``
    transformer blocks and finally normalised.

    Args:
        img_size, patch_size, in_chans, embed_dim, ratio: Forwarded to the patch embedding.
        num_classes: Stored on the module; not used by the forward pass.
        depth: Number of transformer blocks.
        num_heads, mlp_ratio, qkv_bias, qk_scale: Forwarded to every block.
        drop_rate, attn_drop_rate: Dropout rates forwarded to every block.
        drop_path_rate: Maximum stochastic-depth rate; grows linearly over the blocks.
        hybrid_backbone: Optional CNN used instead of the convolutional patch embedding.
        norm_layer: Normalisation factory; defaults to LayerNorm with eps=1e-6.
        use_checkpoint: Run every block through activation checkpointing.
        frozen_stages: -1 freezes nothing; >= 0 freezes the patch embedding and
            blocks 1..frozen_stages.
        last_norm: Apply a final normalisation layer.
        use_cls: Prepend a class token and return it separately.
        patch_padding: Stored on the module; not used by the forward pass.
        freeze_attn: Freeze the attention and first norm of every block.
        freeze_ffn: Freeze pos_embed, the patch embedding and the MLP and
            second norm of every block.
    """

    def __init__(self,
                 img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
                 frozen_stages=-1, ratio=1, last_norm=True, use_cls=False,
                 patch_padding='pad', freeze_attn=False, freeze_ffn=False,
                 ):
        super(ViT, self).__init__()
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.frozen_stages = frozen_stages
        self.use_checkpoint = use_checkpoint
        self.patch_padding = patch_padding
        self.freeze_attn = freeze_attn
        self.freeze_ffn = freeze_ffn
        self.depth = depth

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
        num_patches = self.patch_embed.num_patches

        # since the pretraining model has class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
            )
            for i in range(depth)])

        self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)

        self.use_cls = use_cls
        # The token is always created (so checkpoints load either way) but is
        # only prepended when use_cls is set.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.cls_token, std=1e-6)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put the configured sub-modules in eval mode and stop their gradients."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        # NOTE(review): this starts at index 1, so blocks[0] is never frozen and
        # frozen_stages == depth indexes past the end -- confirm this is intended.
        for i in range(1, self.frozen_stages + 1):
            m = self.blocks[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

        if self.freeze_attn:
            for i in range(0, self.depth):
                m = self.blocks[i]
                m.attn.eval()
                m.norm1.eval()
                for param in m.attn.parameters():
                    param.requires_grad = False
                for param in m.norm1.parameters():
                    param.requires_grad = False

        if self.freeze_ffn:
            self.pos_embed.requires_grad = False
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            for i in range(0, self.depth):
                m = self.blocks[i]
                m.mlp.eval()
                m.norm2.eval()
                for param in m.mlp.parameters():
                    param.requires_grad = False
                for param in m.norm2.parameters():
                    param.requires_grad = False

    def init_weights(self):
        """Re-initialise every Linear (truncated normal, zero bias) and LayerNorm (weight 1, bias 0)."""

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of the parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        """
        Encode a batch of images.

        Args:
            x (Tensor): Images of shape (B, C, H, W).
        Returns:
            xp (Tensor): Feature map of shape (B, D, Hp, Wp).
            cls (Tensor or None): Class-token features of shape (B, D) when
                use_cls is set, otherwise None.
        """
        B, C, H, W = x.shape
        x, (Hp, Wp) = self.patch_embed(x)

        if self.pos_embed is not None:
            # fit for multiple GPU training
            # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
            x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]

        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) if self.use_cls else x
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        x = self.last_norm(x)

        cls = x[:, 0] if self.use_cls else None
        x = x[:, 1:] if self.use_cls else x
        xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()

        return xp, cls  # shape [B, D, Hp, Wp], [B, D]

    def forward(self, x):
        """Return ``forward_features(x)`` unchanged: (feature map, class token or None)."""
        x, cls = self.forward_features(x)
        return x, cls

    def train(self, mode=True):
        """
        Switch train/eval mode while keeping the frozen parts in eval mode.

        Returns:
            ViT: ``self``, as ``nn.Module.train`` does. ``nn.Module.eval`` returns
            the result of ``self.train(False)``, so without this return both
            ``model.train()`` and ``model.eval()`` would yield ``None``.
        """
        super().train(mode)
        self._freeze_stages()
        return self
class BioClipEmbedding(nn.Module):
    """
    Encode images with a BioCLIP vision tower and project the result to `embed_dim`.

    The BioCLIP encoder runs under `torch.no_grad()`; the projection head
    (Linear + LayerNorm) runs outside it.

    Args:
        cfg: Config; `cfg.MODEL.BIOCLIP_EMBEDDING.TYPE == 'bioclip2'` selects BioCLIP 2,
            any other value selects the original BioCLIP.
        embed_dim: Output embedding dimension, should match the dimension of image features for concatenation
    """

    def __init__(self, cfg, embed_dim: int = 1024):
        super().__init__()

        self.embed_dim = embed_dim

        import open_clip

        use_v2 = cfg.MODEL.BIOCLIP_EMBEDDING.TYPE == 'bioclip2'
        if use_v2:
            print("[BioClipEmbedding] Using BioClip2 model from Hugging Face Hub")
        hub_id = 'hf-hub:imageomics/bioclip-2' if use_v2 else 'hf-hub:imageomics/bioclip'
        # Only the model is kept; the two returned transforms are discarded.
        encoder, _, _ = open_clip.create_model_and_transforms(hub_id)
        self.species_model = encoder
        self.species_model.eval()

        # Project from the encoder's output width to the target dimension.
        encoder_dim = self.species_model.visual.output_dim
        self.projection = nn.Sequential(
            nn.Linear(encoder_dim, embed_dim),
            nn.LayerNorm(embed_dim),
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Args:
            images: Tensor of shape (B, C, H, W) representing a batch of images
        Returns:
            Tensor of shape (B, embed_dim) representing the embedded features
        """
        # BioClip expects 224x224 input, resize if needed
        needs_resize = images.shape[-2:] != (224, 224)
        encoder_input = (
            F.interpolate(images, size=(224, 224), mode='bilinear', align_corners=False)
            if needs_resize
            else images
        )

        with torch.no_grad():
            species_features = self.species_model.encode_image(encoder_input)

        return self.projection(species_features)
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """
    Pick at most `max_cond_frame_num` conditioning frames nearest in time to `frame_idx`.

    Selection order:
    - a) the latest conditioning frame strictly before `frame_idx` (if any);
    - b) the earliest conditioning frame at or after `frame_idx` (if any);
    - c) the remaining frames by increasing temporal distance, until the
      budget of `max_cond_frame_num` frames is used up.

    When `max_cond_frame_num` is -1, or there are no more candidates than the
    budget allows, the input dict itself is returned as the selection.

    Outputs:
    - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
    - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
    """
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        return cond_frame_outputs, {}

    assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
    chosen = {}

    # a) nearest frame in the past
    earlier = [t for t in cond_frame_outputs if t < frame_idx]
    if earlier:
        nearest_before = max(earlier)
        chosen[nearest_before] = cond_frame_outputs[nearest_before]

    # b) nearest frame in the present/future
    later = [t for t in cond_frame_outputs if t >= frame_idx]
    if later:
        nearest_after = min(later)
        chosen[nearest_after] = cond_frame_outputs[nearest_after]

    # c) fill the remaining budget with the next-closest frames
    budget = max_cond_frame_num - len(chosen)
    leftovers = [t for t in cond_frame_outputs if t not in chosen]
    leftovers.sort(key=lambda t: abs(t - frame_idx))
    for t in leftovers[:budget]:
        chosen[t] = cond_frame_outputs[t]

    skipped = {t: v for t, v in cond_frame_outputs.items() if t not in chosen}
    return chosen, skipped
def get_activation_fn(activation):
    """
    Return an activation function given a string.

    Args:
        activation: One of "relu", "gelu" or "glu".
    Returns:
        The matching function from `torch.nn.functional`.
    Raises:
        RuntimeError: If `activation` is not one of the supported names.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu), ("glu", F.glu))
    # Compared with == rather than looked up in a dict so that any value,
    # hashable or not, ends in the RuntimeError below.
    for name, fn in supported:
        if activation == name:
            return fn
    # The message lists every accepted name (it used to omit "glu").
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
class LayerNorm2d(nn.Module):
    """
    Layer normalisation over dimension 1 of a channels-first feature map.

    Every position is normalised independently across dimension 1 using the
    biased variance, then scaled and shifted by one learnable weight and bias
    per channel.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channel_mean = x.mean(1, keepdim=True)
        centered = x - channel_mean
        channel_var = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(channel_var + self.eps)
        # Two trailing singleton axes let the per-channel affine broadcast over
        # the last two dimensions.
        gain = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return gain * normalized + shift
class FeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head self-attention over a ``(b, n, dim)`` token sequence.

    The output projection is skipped (identity) only for the single-head case
    where ``dim_head == dim``.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # A single projection produces queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x):
        # Split the fused projection and move the head axis in front of tokens.
        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        mixed = torch.matmul(weights, v)
        merged = rearrange(mixed, "b h n d -> b n (h d)")
        return self.to_out(merged)
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm blocks: self-attention then feed-forward.

    Each sub-layer is wrapped in ``PreNorm`` and added back to its input.
    Extra positional ``*args`` passed to ``forward`` are handed to every
    ``PreNorm`` wrapper.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float = 0.0,
        norm: str = "layer",
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        norm_kwargs = {"norm": norm, "norm_cond_dim": norm_cond_dim}
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self_attention = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            block = nn.ModuleList(
                [
                    PreNorm(dim, self_attention, **norm_kwargs),
                    PreNorm(dim, feed_forward, **norm_kwargs),
                ]
            )
            self.layers.append(block)

    def forward(self, x: torch.Tensor, *args):
        for self_attention, feed_forward in self.layers:
            # Residual connection around each pre-normalised sub-layer.
            x = self_attention(x, *args) + x
            x = feed_forward(x, *args) + x
        return x
class DropTokenDropout(nn.Module):
    """Token dropout that removes whole tokens from the sequence.

    During training each sequence position is dropped with probability ``p``;
    the same positions are removed for every sample in the batch, so the
    output has shape ``(batch_size, kept_len, dim)``. In eval mode, or when
    ``p == 0``, the input is returned unchanged.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, dim)
        if self.training and self.p > 0:
            # One Bernoulli draw per sequence position, shared across the batch.
            zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()

            if zero_mask.any():
                x = x[:, ~zero_mask, :]
        return x


class ZeroTokenDropout(nn.Module):
    """Token dropout that zeroes whole tokens but keeps the sequence length.

    During training each token of each sample is zeroed independently with
    probability ``p``. The input tensor is never modified; a new tensor is
    returned whenever dropout is active. In eval mode, or when ``p == 0``,
    the input is returned unchanged.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, p: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, " "but got {}".format(p)
            )
        self.p = p

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, dim)
        if self.training and self.p > 0:
            zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
            # Out-of-place on purpose: writing zeros into ``x`` would corrupt
            # the caller's tensor and raises for leaf tensors that require grad.
            x = x.masked_fill(zero_mask.unsqueeze(-1), 0)
        return x
class TransformerDecoder(nn.Module):
    """Cross-attention transformer decoder over a set of query tokens.

    Tokens are embedded (or passed through unchanged when
    ``skip_token_embedding`` is True), optionally token-dropped, given a
    learned positional embedding and processed by ``TransformerCrossAttn``.

    Args:
        num_tokens: Number of learned positional embeddings.
        token_dim: Feature size of the incoming tokens.
        dim: Transformer width.
        depth, heads, mlp_dim, dim_head, dropout, norm, norm_cond_dim,
            context_dim: Forwarded to ``TransformerCrossAttn``.
        emb_dropout: Probability used by the embedding dropout layer.
        emb_dropout_type: ``"drop"``, ``"zero"`` or ``"normal"``.
        skip_token_embedding: Use the tokens as-is; requires ``token_dim == dim``.

    Raises:
        ValueError: If ``skip_token_embedding`` is True while
            ``token_dim != dim``, or if ``emb_dropout_type`` is unknown.
    """

    def __init__(
        self,
        num_tokens: int,
        token_dim: int,
        dim: int,
        depth: int,
        heads: int,
        mlp_dim: int,
        dim_head: int = 64,
        dropout: float = 0.0,
        emb_dropout: float = 0.0,
        emb_dropout_type: str = 'drop',
        norm: str = "layer",
        norm_cond_dim: int = -1,
        context_dim: Optional[int] = None,
        skip_token_embedding: bool = False,
    ):
        super().__init__()
        if not skip_token_embedding:
            self.to_token_embedding = nn.Linear(token_dim, dim)
        else:
            self.to_token_embedding = nn.Identity()
            if token_dim != dim:
                raise ValueError(
                    f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
                )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
        if emb_dropout_type == "drop":
            self.dropout = DropTokenDropout(emb_dropout)
        elif emb_dropout_type == "zero":
            self.dropout = ZeroTokenDropout(emb_dropout)
        elif emb_dropout_type == "normal":
            self.dropout = nn.Dropout(emb_dropout)
        else:
            # Fail at construction (as TransformerEncoder does) instead of
            # leaving ``self.dropout`` undefined until the first forward pass.
            raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")

        self.transformer = TransformerCrossAttn(
            dim,
            depth,
            heads,
            dim_head,
            mlp_dim,
            dropout,
            norm=norm,
            norm_cond_dim=norm_cond_dim,
            context_dim=context_dim,
        )

    def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
        x = self.to_token_embedding(inp)
        x = self.dropout(x)

        # Read the sequence length after dropout: "drop" may shorten the
        # sequence, and the positional slice has to match what is left (the
        # same convention TransformerEncoder uses for its "token" location).
        _, n, _ = x.shape
        # Out-of-place add: with an Identity embedding and inactive dropout,
        # ``x`` is the caller's tensor and must not be modified.
        x = x + self.pos_embedding[:, :n]

        x = self.transformer(x, *args, context=context, context_list=context_list)
        return x
def init_t_xy(end_x: int, end_y: int):
    """Return per-position x and y coordinates of a row-major ``end_x`` by ``end_y`` grid.

    Both results are float tensors of length ``end_x * end_y``.
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    col = torch.remainder(flat_index, end_x).float()
    row = torch.div(flat_index, end_x, rounding_mode="floor").float()
    return col, row


def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
    """Build complex rotary factors for a 2-D grid.

    The x and y axes share one frequency table; their unit-modulus complex
    rotations are concatenated on the last dim, giving a tensor of shape
    ``(end_x * end_y, 2 * (dim // 4))``.
    """
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    base_freqs = 1.0 / (theta ** exponents)

    t_x, t_y = init_t_xy(end_x, end_y)
    per_axis = []
    for coords in (t_x, t_y):
        angles = torch.outer(coords, base_freqs)
        # polar(1, angle) == cos(angle) + i * sin(angle)
        per_axis.append(torch.polar(torch.ones_like(angles), angles))
    return torch.cat(per_axis, dim=-1)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so it broadcasts against ``x`` over all leading dims.

    ``freqs_cis`` must match the last two dims of ``x``; the result has size 1
    on every other dim.
    """
    ndim = x.ndim
    assert ndim > 1
    trailing = tuple(x.shape[-2:])
    assert freqs_cis.shape == trailing
    target = [1] * (ndim - 2) + list(trailing)
    return freqs_cis.view(*target)
class AdaptiveLayerNorm1D(torch.nn.Module):
    """LayerNorm whose scale and shift are predicted from a conditioning vector.

    The conditioning projection is zero-initialised, so a freshly built module
    behaves exactly like its inner ``LayerNorm``.

    Raises:
        ValueError: If ``data_dim`` or ``norm_cond_dim`` is not positive.
    """

    def __init__(self, data_dim: int, norm_cond_dim: int):
        super().__init__()
        if data_dim <= 0:
            raise ValueError(f"data_dim must be positive, but got {data_dim}")
        if norm_cond_dim <= 0:
            raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
        self.norm = torch.nn.LayerNorm(data_dim)
        self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
        # Start from "no modulation": scale offset and shift are both zero.
        for param in (self.linear.weight, self.linear.bias):
            torch.nn.init.zeros_(param)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: (batch, ..., data_dim)
        # t: (batch, norm_cond_dim)
        # return: same shape as x
        normed = self.norm(x)
        scale, shift = self.linear(t).chunk(2, dim=-1)

        # Insert singleton dims so (batch, data_dim) broadcasts over the
        # middle dims of x.
        extra_dims = x.dim() - 2
        if extra_dims > 0:
            scale = scale.view(scale.shape[0], *([1] * extra_dims), scale.shape[1])
            shift = shift.view(shift.shape[0], *([1] * extra_dims), shift.shape[1])

        return normed * (1 + scale) + shift
def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
    """Build the normalisation module selected by ``norm``.

    ``"batch"`` gives ``BatchNorm1d``, ``"layer"`` gives ``LayerNorm``,
    ``"ada"`` gives ``AdaptiveLayerNorm1D`` (needs a positive
    ``norm_cond_dim``) and ``None`` gives ``Identity``.

    Raises:
        ValueError: If ``norm`` is not one of the options above.
    """
    if norm is None:
        return torch.nn.Identity()
    if norm == "batch":
        return torch.nn.BatchNorm1d(dim)
    if norm == "layer":
        return torch.nn.LayerNorm(dim)
    if norm == "ada":
        assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
        return AdaptiveLayerNorm1D(dim, norm_cond_dim)
    raise ValueError(f"Unknown norm: {norm}")


def linear_norm_activ_dropout(
    input_dim: int,
    output_dim: int,
    activation: torch.nn.Module = torch.nn.ReLU(),
    bias: bool = True,
    norm: Optional[str] = "layer",  # Options: ada/batch/layer
    dropout: float = 0.0,
    norm_cond_dim: int = -1,
) -> SequentialCond:
    """Build one ``Linear -> [norm] -> activation -> [Dropout]`` stage.

    The activation instance is deep-copied, so the shared default is never
    inserted into a model directly. The normalisation is omitted when
    ``norm`` is None and the dropout when ``dropout`` is 0.
    """
    stage = [torch.nn.Linear(input_dim, output_dim, bias=bias)]
    if norm is not None:
        stage.append(normalization_layer(norm, output_dim, norm_cond_dim))
    stage.append(copy.deepcopy(activation))
    if dropout > 0.0:
        stage.append(torch.nn.Dropout(dropout))
    return SequentialCond(*stage)


def create_simple_mlp(
    input_dim: int,
    hidden_dims: List[int],
    output_dim: int,
    activation: torch.nn.Module = torch.nn.ReLU(),
    bias: bool = True,
    norm: Optional[str] = "layer",  # Options: ada/batch/layer
    dropout: float = 0.0,
    norm_cond_dim: int = -1,
) -> SequentialCond:
    """Build an MLP with one stage per hidden size and a final plain ``Linear``.

    The modules of every stage are flattened into a single ``SequentialCond``.
    """
    widths = [input_dim] + list(hidden_dims)
    modules = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        # Iterating the stage yields its modules, which flattens the nesting.
        modules.extend(
            linear_norm_activ_dropout(
                fan_in, fan_out, activation, bias, norm, dropout, norm_cond_dim
            )
        )
    modules.append(torch.nn.Linear(widths[-1], output_dim, bias=bias))
    return SequentialCond(*modules)
class ResidualMLP(torch.nn.Module):
    """Input stage, ``num_blocks`` residual MLP blocks, then a linear head.

    Extra arguments passed to ``forward`` are handed to the underlying
    ``SequentialCond``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_hidden_layers: int,
        output_dim: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        bias: bool = True,
        norm: Optional[str] = "layer",  # Options: ada/batch/layer
        dropout: float = 0.0,
        num_blocks: int = 1,
        norm_cond_dim: int = -1,
    ):
        super().__init__()
        self.input_dim = input_dim
        stem = linear_norm_activ_dropout(
            input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
        )
        residual_blocks = [
            ResidualMLPBlock(
                hidden_dim,
                hidden_dim,
                num_hidden_layers,
                hidden_dim,
                activation,
                bias,
                norm,
                dropout,
                norm_cond_dim,
            )
            for _ in range(num_blocks)
        ]
        head = torch.nn.Linear(hidden_dim, output_dim, bias=bias)
        self.model = SequentialCond(stem, *residual_blocks, head)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return self.model(x, *args, **kwargs)


class FrequencyEmbedder(torch.nn.Module):
    """Sinusoidal feature embedding with power-of-two frequencies.

    Frequencies are ``2 ** linspace(0, max_freq_log2, num_frequencies)``. Each
    input scalar is expanded to ``[sin(f_i * x)..., cos(f_i * x)..., x]``.
    """

    def __init__(self, num_frequencies, max_freq_log2):
        super().__init__()
        exponents = torch.linspace(0, max_freq_log2, steps=num_frequencies)
        self.register_buffer("frequencies", 2 ** exponents)

    def forward(self, x):
        # x should be of size (N,) or (N, D)
        batch = x.size(0)
        if x.dim() == 1:
            x = x.unsqueeze(1)  # (N,) -> (N, 1)
        x_col = x.unsqueeze(-1)  # (N, D, 1)
        phase = self.frequencies.view(1, 1, -1) * x_col  # (N, D, num_frequencies)
        features = torch.cat([torch.sin(phase), torch.cos(phase), x_col], dim=-1)
        # (N, D * 2 * num_frequencies + D)
        return features.view(batch, -1)
def get_sdpa_settings():
    """Choose scaled-dot-product-attention kernel flags for this machine.

    Returns:
        Tuple ``(old_gpu, use_flash_attn, math_kernel_on)``. Without CUDA this
        is ``(True, False, True)``. With CUDA, device 0's major compute
        capability decides: ``old_gpu`` is ``major < 7``, ``use_flash_attn``
        is ``major >= 8``, and the math kernel stays on when Flash Attention
        is off or PyTorch is older than 2.2.
    """
    if not torch.cuda.is_available():
        return True, False, True

    major = torch.cuda.get_device_properties(0).major
    old_gpu = major < 7
    # only use Flash Attention on Ampere (8.0) or newer GPUs
    use_flash_attn = major >= 8
    if not use_flash_attn:
        warnings.warn(
            "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
            category=UserWarning,
            stacklevel=2,
        )

    # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
    # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
    pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
    before_2_2 = pytorch_version < (2, 2)
    if before_2_2:
        warnings.warn(
            f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
            "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
            category=UserWarning,
            stacklevel=2,
        )
    math_kernel_on = before_2_2 or not use_flash_attn

    return old_gpu, use_flash_attn, math_kernel_on
class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        Transformer decoder in which query tokens and image tokens attend to
        each other; the query positional embedding is supplied by the caller.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
          attention_downsample_rate (int): downsample rate passed to the
            cross-attention layers
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        # Only the first block skips the positional encoding in self-attention.
        self.layers = nn.ModuleList(
            [
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(layer_idx == 0),
                )
                for layer_idx in range(depth)
            ]
        )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # Unpacking also enforces a 4-D (B x C x H x W) input.
        bs, c, h, w = image_embedding.shape

        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        image_tokens = image_embedding.flatten(2).permute(0, 2, 1)
        image_pos = image_pe.flatten(2).permute(0, 2, 1)

        tokens, keys = point_embedding, image_tokens
        for block in self.layers:
            tokens, keys = block(
                queries=tokens,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pos,
            )

        # Final attention from the points to the image, then layernorm
        attended = self.final_attn_token_to_image(
            q=tokens + point_embedding, k=keys + image_pos, v=keys
        )
        tokens = self.norm_final_attn(tokens + attended)

        return tokens, keys
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.

    Args:
        embedding_dim: channel size of the query input and of the output.
        num_heads: number of attention heads; must divide
            ``embedding_dim // downsample_rate``.
        downsample_rate: factor by which the internal (projected) dimension
            is reduced relative to ``embedding_dim``.
        dropout: attention dropout probability, applied only in training mode.
        kv_in_dim: channel size of the key/value inputs; defaults to
            ``embedding_dim`` when None.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: int = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """Reshape ``B x N_tokens x C`` into ``B x N_heads x N_tokens x C_per_head``."""
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2).contiguous()  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """Inverse of ``_separate_heads``: merge the heads back into the channel dim."""
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Project q/k/v, run scaled dot-product attention and project back.

        The first attempt runs under ``sdp_kernel_context``. If it raises, a
        warning is emitted, the module-level ``ALLOW_ALL_KERNELS`` flag is set
        to True and the attention call is retried outside that context.
        """
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention dropout is disabled outside training mode.
        dropout_p = self.dropout_p if self.training else 0.0
        # Attention
        try:
            with sdp_kernel_context(dropout_p):
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        except Exception as e:
            # Fall back to all kernels if the Flash attention kernel fails
            warnings.warn(
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
                category=UserWarning,
                stacklevel=2,
            )
            # The flag is module-global, so the fallback persists for every
            # later call that goes through sdp_kernel_context.
            global ALLOW_ALL_KERNELS
            ALLOW_ALL_KERNELS = True
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
class Discriminator(nn.Module):
    """Pose + shape discriminator in the style proposed in HMR.

    It scores every joint rotation on its own, all joint rotations together,
    the shape coefficients and, optionally, a bone vector.

    Args:
        num_joints: Number of joint rotations scored (default 34).
        num_betas: Number of shape (beta) coefficients (default 41).
        num_bones: Length of the optional bone input (default 24).
    """

    def __init__(self, num_joints: int = 34, num_betas: int = 41, num_bones: int = 24):
        super().__init__()

        self.num_joints = num_joints
        # poses_alone: 1x1 convolutions shared across joints, then one linear
        # scoring head per joint.
        self.D_conv1 = self._xavier(nn.Conv2d(9, 32, kernel_size=1))
        self.relu = nn.ReLU(inplace=True)
        self.D_conv2 = self._xavier(nn.Conv2d(32, 32, kernel_size=1))
        self.pose_out = nn.ModuleList(
            [self._xavier(nn.Linear(32, 1)) for _ in range(self.num_joints)]
        )

        # betas
        self.betas_fc1 = self._xavier(nn.Linear(num_betas, 10))
        self.betas_fc2 = self._xavier(nn.Linear(10, 5))
        self.betas_out = self._xavier(nn.Linear(5, 1))

        # bones
        self.bone_fc1 = self._xavier(nn.Linear(num_bones, 10))
        self.bone_fc2 = self._xavier(nn.Linear(10, 5))
        self.bone_out = self._xavier(nn.Linear(5, 1))

        # poses_joint: all per-joint features concatenated
        self.D_alljoints_fc1 = self._xavier(nn.Linear(32 * self.num_joints, 1024))
        self.D_alljoints_fc2 = self._xavier(nn.Linear(1024, 1024))
        self.D_alljoints_out = self._xavier(nn.Linear(1024, 1))

    @staticmethod
    def _xavier(layer):
        """Xavier-uniform the weight, zero the bias, and return the layer."""
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, poses: torch.Tensor, betas: torch.Tensor, bone=None) -> torch.Tensor:
        """
        Forward pass of the discriminator.
        Args:
            poses (torch.Tensor): Batch of joint rotation matrices, reshaped
                internally to (B, num_joints, 1, 9); e.g. (B, num_joints, 3, 3).
            betas (torch.Tensor): Tensor of shape (B, num_betas) with the shape coefficients.
            bone (torch.Tensor, optional): Tensor of shape (B, num_bones). When
                given, one extra score is appended to the output.
        Returns:
            torch.Tensor: Discriminator output of shape (B, num_joints + 2), or
                (B, num_joints + 3) when ``bone`` is given. Columns are the
                per-joint scores, the betas score, the all-joints score and,
                if present, the bone score.
        """
        # poses B x num_joints x 1 x 9
        poses = poses.reshape(-1, self.num_joints, 1, 9)
        bn = poses.shape[0]
        # poses B x 9 x num_joints x 1
        poses = poses.permute(0, 3, 1, 2).contiguous()

        # poses_alone
        poses = self.D_conv1(poses)
        poses = self.relu(poses)
        poses = self.D_conv2(poses)
        poses = self.relu(poses)

        poses_out = torch.cat(
            [head(poses[:, :, i, 0]) for i, head in enumerate(self.pose_out)], dim=1
        )

        # betas
        betas = self.betas_fc1(betas)
        betas = self.relu(betas)
        betas = self.betas_fc2(betas)
        betas = self.relu(betas)
        betas_out = self.betas_out(betas)

        # bone
        if bone is not None:
            bone = self.bone_fc1(bone)
            bone = self.relu(bone)
            bone = self.bone_fc2(bone)
            bone = self.relu(bone)
            bone_out = self.bone_out(bone)

        # poses_joint
        poses = poses.reshape(bn, -1)
        poses_all = self.D_alljoints_fc1(poses)
        poses_all = self.relu(poses_all)
        poses_all = self.D_alljoints_fc2(poses_all)
        poses_all = self.relu(poses_all)
        poses_all_out = self.D_alljoints_out(poses_all)

        if bone is not None:
            disc_out = torch.cat((poses_out, betas_out, poses_all_out, bone_out), 1)
        else:
            disc_out = torch.cat((poses_out, betas_out, poses_all_out), 1)
        return disc_out
+1,30 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from torch import nn + + +class ClassTokenHead(nn.Module): + def __init__(self, embed_dim=1280, hidden_dim=4096, output_dim=256, num_layers=3, last_bn=True): + super().__init__() + mlp = [] + for l in range(num_layers): + dim1 = embed_dim if l == 0 else hidden_dim + dim2 = output_dim if l == num_layers - 1 else hidden_dim + mlp.append(nn.Linear(dim1, dim2, bias=False)) + if l < num_layers - 1: + mlp.append(nn.BatchNorm1d(dim2)) + mlp.append(nn.ReLU(inplace=True)) + elif last_bn: + mlp.append(nn.BatchNorm1d(dim2, affine=False)) + self.head = nn.Sequential(*mlp) + + def forward(self, x): + cls_feats = self.head(x) + return cls_feats \ No newline at end of file diff --git a/prima/models/heads/smal_head.py b/prima/models/heads/smal_head.py new file mode 100644 index 0000000000000000000000000000000000000000..53e20b0d9b4d020a79bddc6cfaa9e59f1dc16188 --- /dev/null +++ b/prima/models/heads/smal_head.py @@ -0,0 +1,647 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import einops +import pickle as pkl +from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat +from ..components.pose_transformer import TransformerDecoder + + +def build_smal_head(cfg): + smal_head_type = cfg.MODEL.SMAL_HEAD.get('TYPE', 'amr') + + if smal_head_type == 'new_bio_pose_transformer_decoder': + return 
NewBioGuidedSMALPoseDecoder(cfg) + else: + raise ValueError('Unknown SMAL head type: {}'.format(smal_head_type)) + + + + + + + +class NewBioGuidedSMALPoseDecoder(nn.Module): + """ + Bio-Guided SMAL Decoder with Pose Token Aggregation + + Final version: + - Query tokens = [param token] + [2D keypoint tokens (optional)] + [3D keypoint tokens (optional)] + - SAM3D-body-style layer-wise keypoint token updates: + * 2D: predict (x,y) in [-0.5,0.5] from kp2d tokens -> token_augment position encoding + + grid_sample image features at predicted locations -> add into kp2d token embeddings + + invalid_mask (out-of-bounds and optional vis mask) zeroes updates + * 3D: predict (x,y,z) from kp3d tokens -> pelvis-normalize -> token_augment position encoding + - token_augment is injected by feeding (token_embeddings + token_augment) into each decoder layer. + - Only param token (index 0) is used to regress pose/betas/cam deltas. + - Outputs: + pred_smal_params: dict with global_orient/pose/betas and optional keypoints_2d/3d + pred_cam: [B,3] + extra_outputs: includes bio-guided shape_feat/init_betas and pred_smal_params_list + """ + + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + # ========== Basic config ========== + self.joint_rep_type = cfg.MODEL.SMAL_HEAD.get("JOINT_REP", "6d") + self.joint_rep_dim = {"6d": 6, "aa": 3}[self.joint_rep_type] + self.npose = self.joint_rep_dim * (cfg.SMAL.NUM_JOINTS + 1) + + # ========== Dimensions ========== + self.decoder_dim = cfg.MODEL.SMAL_HEAD.get("DECODER_DIM", 1024) + context_dim = cfg.MODEL.SMAL_HEAD.get("IN_CHANNELS", 1024) + num_layers = cfg.MODEL.SMAL_HEAD.get("NUM_DECODER_LAYERS", 4) + num_heads = cfg.MODEL.SMAL_HEAD.get("NUM_HEADS", 8) + mlp_ratio = cfg.MODEL.SMAL_HEAD.get("MLP_RATIO", 4.0) + + # keypoint config + self.use_keypoint_2d_tokens = cfg.MODEL.SMAL_HEAD.get("USE_KEYPOINT_2D_TOKENS", False) + self.use_keypoint_3d_tokens = cfg.MODEL.SMAL_HEAD.get("USE_KEYPOINT_3D_TOKENS", False) + self.num_keypoints = 
cfg.SMAL.get("NUM_KEYPOINTS", 26) + self.keypoint_token_update = cfg.MODEL.SMAL_HEAD.get("KEYPOINT_TOKEN_UPDATE", False) + + # 2D update: whether to inject sampled image feature into kp2d tokens + self.kp2d_inject_image_feat = cfg.MODEL.SMAL_HEAD.get("KP2D_INJECT_IMAGE_FEAT", True) + + # IEF iters + self.ief_iters = cfg.MODEL.SMAL_HEAD.get("IEF_ITERS", 3) + + # pelvis indices + self.pelvis_idx = cfg.SMAL.get("PELVIS_IDX", [0, 1]) + + # ========== Test-time optimization config ========== + self._tta_mode = False # Track if in test-time adaptation mode + + # ========== [Coarse] Bio prior ========== + self.bio_to_betas_init = nn.Sequential( + nn.Linear(context_dim, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 41), + ) + self.shape_projector = nn.Sequential( + nn.Linear(41, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 128), + ) + + # ========== Init pose/cam ========== + self.init_pose = nn.Parameter(torch.zeros(1, self.npose)) + self.init_cam = nn.Parameter(torch.tensor([[0.9, 0, 0]], dtype=torch.float32)) + + # params -> param token + param_dim = self.npose + 41 + 3 + self.param_to_token = nn.Sequential( + nn.Linear(param_dim, self.decoder_dim), + nn.LayerNorm(self.decoder_dim), + nn.ReLU(), + ) + + # ========== Keypoint token embeddings ========== + if self.use_keypoint_2d_tokens: + self.keypoint_2d_embeddings = nn.Embedding(self.num_keypoints, self.decoder_dim) + nn.init.normal_(self.keypoint_2d_embeddings.weight, std=0.02) + + # (x,y) -> token augment + self.keypoint_2d_pos_encoder = nn.Sequential( + nn.Linear(2, 256), + nn.ReLU(), + nn.Linear(256, self.decoder_dim), + ) + # sampled image feat -> token dim (add into token embeddings) + self.keypoint_2d_feat_linear = nn.Linear(self.decoder_dim, self.decoder_dim) + + if self.use_keypoint_3d_tokens: + self.keypoint_3d_embeddings = nn.Embedding(self.num_keypoints, self.decoder_dim) + nn.init.normal_(self.keypoint_3d_embeddings.weight, std=0.02) + + # (x,y,z) -> token augment + 
self.keypoint_3d_pos_encoder = nn.Sequential( + nn.Linear(3, 256), + nn.ReLU(), + nn.Linear(256, self.decoder_dim), + ) + + # ========== Per-token intermediate heads (predict from kp tokens themselves) ========== + if self.keypoint_token_update: + if self.use_keypoint_2d_tokens: + self.kp2d_from_tokens = nn.Sequential( + nn.Linear(self.decoder_dim, self.decoder_dim), + nn.ReLU(), + nn.Linear(self.decoder_dim, 2), + ) + if self.use_keypoint_3d_tokens: + self.kp3d_from_tokens = nn.Sequential( + nn.Linear(self.decoder_dim, self.decoder_dim), + nn.ReLU(), + nn.Linear(self.decoder_dim, 3), + ) + + # ========== Image feature projection + pos encoding ========== + self.image_proj = nn.Identity() if context_dim == self.decoder_dim else nn.Linear(context_dim, self.decoder_dim) + self.image_pos_encoding = PositionalEncoding2D(self.decoder_dim) + + # ========== Transformer decoder layers ========== + self.layers = nn.ModuleList( + [ + PoseTransformerDecoderLayer( + d_model=self.decoder_dim, + nhead=num_heads, + dim_feedforward=int(self.decoder_dim * mlp_ratio), + dropout=0.1, + ) + for _ in range(num_layers) + ] + ) + self.norm = nn.LayerNorm(self.decoder_dim) + + # ========== Regression heads (param token only) ========== + self.decpose = nn.Sequential( + nn.Linear(self.decoder_dim, self.decoder_dim), + nn.ReLU(), + nn.Linear(self.decoder_dim, self.npose), + ) + self.decshape = nn.Sequential( + nn.Linear(self.decoder_dim, self.decoder_dim), + nn.ReLU(), + nn.Linear(self.decoder_dim, 41), + ) + self.deccam = nn.Sequential( + nn.Linear(self.decoder_dim, self.decoder_dim // 2), + nn.ReLU(), + nn.Linear(self.decoder_dim // 2, 3), + ) + + # -------------------------- + # helpers: query token build + # -------------------------- + def _build_query_tokens(self, pred_pose, pred_betas, pred_cam): + B = pred_pose.shape[0] + tokens = [] + + params = torch.cat([pred_pose, pred_betas, pred_cam], dim=1) # [B, param_dim] + param_token = self.param_to_token(params).unsqueeze(1) # [B,1,D] + 
tokens.append(param_token) + + kp2d_start = None + kp3d_start = None + + if self.use_keypoint_2d_tokens: + kp2d_start = sum(t.shape[1] for t in tokens) + kp2d_tokens = self.keypoint_2d_embeddings.weight.unsqueeze(0).expand(B, -1, -1).contiguous() + tokens.append(kp2d_tokens) + + if self.use_keypoint_3d_tokens: + kp3d_start = sum(t.shape[1] for t in tokens) + kp3d_tokens = self.keypoint_3d_embeddings.weight.unsqueeze(0).expand(B, -1, -1).contiguous() + tokens.append(kp3d_tokens) + + token_embeddings = torch.cat(tokens, dim=1) # [B,Nq,D] + token_augment = torch.zeros_like(token_embeddings) + + return token_embeddings, token_augment, kp2d_start, kp3d_start + + # -------------------------- + # helpers: updates + # -------------------------- + def _kp2d_update(self, token_embeddings, token_augment, image_features, kp2d_start, H, W, vis_mask=None): + """ + SAM3D-body-style 2D keypoint token update. + + image_features: [B, HW, D] projected + pos-encoded, with HW=H*W (expected 12*16) + vis_mask: optional [B,N] bool (True=valid) + """ + if not (self.keypoint_token_update and self.use_keypoint_2d_tokens): + return token_embeddings, token_augment, None + + B = token_embeddings.shape[0] + N = self.num_keypoints + + kp_tokens = token_embeddings[:, kp2d_start : kp2d_start + N, :] # [B,N,D] + + # predict coords in [-0.5,0.5] + pred_xy = self.kp2d_from_tokens(kp_tokens) + pred_xy = torch.tanh(pred_xy) * 0.5 + + # invalid mask (out of bounds + optional vis) + pred_xy_01 = pred_xy + 0.5 + invalid = ( + (pred_xy_01[..., 0] < 0.0) + | (pred_xy_01[..., 0] > 1.0) + | (pred_xy_01[..., 1] < 0.0) + | (pred_xy_01[..., 1] > 1.0) + ) + if vis_mask is not None: + invalid = invalid | (~vis_mask) + valid = (~invalid).unsqueeze(-1).float() # [B,N,1] + + # update token_augment slice + token_augment = token_augment.clone() + token_augment[:, kp2d_start : kp2d_start + N, :] = self.keypoint_2d_pos_encoder(pred_xy) * valid + + # inject sampled image feature into kp2d tokens (optional) + if 
self.kp2d_inject_image_feat: + img = image_features.view(B, H, W, self.decoder_dim).permute(0, 3, 1, 2).contiguous() # [B,D,H,W] + grid = (pred_xy * 2.0).unsqueeze(2) # [B,N,1,2] in [-1,1] + + sampled = ( + F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=False) + .squeeze(3) + .permute(0, 2, 1) + .contiguous() + ) # [B,N,D] + + sampled = sampled * valid + token_embeddings = token_embeddings.clone() + token_embeddings[:, kp2d_start : kp2d_start + N, :] += self.keypoint_2d_feat_linear(sampled) + + return token_embeddings, token_augment, pred_xy + + def _kp3d_update(self, token_embeddings, token_augment, kp3d_start): + if not (self.keypoint_token_update and self.use_keypoint_3d_tokens): + return token_embeddings, token_augment, None + + N = self.num_keypoints + kp_tokens = token_embeddings[:, kp3d_start : kp3d_start + N, :] # [B,N,D] + + pred_xyz = self.kp3d_from_tokens(kp_tokens) # [B,N,3] + + # pelvis normalize + pelvis_center = pred_xyz[:, self.pelvis_idx, :].mean(dim=1, keepdim=True) # [B,1,3] + pred_xyz_norm = pred_xyz - pelvis_center + + token_augment = token_augment.clone() + token_augment[:, kp3d_start : kp3d_start + N, :] = self.keypoint_3d_pos_encoder(pred_xyz_norm) + + return token_embeddings, token_augment, pred_xyz + + # -------------------------- + # forward + # -------------------------- + def forward(self, x, keypoint_coords_2d=None, keypoint_coords_3d=None, **kwargs): + """ + Inputs: + x: [B, Hp*Wp+1, C] image tokens from backbone concatenated with bio token + (BioCLIP token is the last token in the sequence) + + Note: + keypoint_coords_2d can optionally provide a vis/conf mask: [B,N,3] (x,y,vis) + We do NOT inject GT coords into tokens by default; they are used only as optional masking. + """ + B = x.shape[0] + + # ---- Data preprocessing ---- + # Handle 4D input tensors of shape (B, C, H, W). 
+ if len(x.shape) == 4: + x = einops.rearrange(x, 'b c h w -> b (h w) c') + + bio_token = x[:, -1, :] # [B, C] - the BioCLIP token is the final token + image_features = x[:, :-1, :] # [B, H*W, C] - remaining image features + + # ---- Coarse bio shape ---- + init_betas = self.bio_to_betas_init(bio_token) # [B,41] + shape_feat = F.normalize(self.shape_projector(init_betas), dim=1) + + # ---- Image feature projection ---- + # Project only image features, excluding the bio token. + image_features = self.image_proj(image_features) # [B,HW,D] + + # Your backbone: vit crop 256x192 with patch16 => Hp=12, Wp=16 + H, W = 12, 16 + assert image_features.shape[1] == H * W, f"Expected HW={H*W}, got {image_features.shape[1]}" + + img_pos = self.image_pos_encoding(H, W).to(image_features.device) # [HW,D] + image_features = image_features + img_pos.unsqueeze(0) + + # ---- init params ---- + pred_pose = self.init_pose.expand(B, -1) + pred_betas = init_betas + pred_cam = self.init_cam.expand(B, -1) + + pred_pose_list, pred_betas_list, pred_cam_list = [], [], [] + pred_keypoints_2d_list, pred_keypoints_3d_list = [], [] + + # Optional visibility mask from provided 2D keypoints + vis_mask = None + if keypoint_coords_2d is not None and keypoint_coords_2d.shape[-1] == 3: + vis_mask = keypoint_coords_2d[..., 2] > 0 # [B,N] + + # ---- IEF loop ---- + for _ in range(self.ief_iters): + token_embeddings, token_augment, kp2d_start, kp3d_start = self._build_query_tokens( + pred_pose, pred_betas, pred_cam + ) + + # ---- Transformer layers ---- + for layer_idx, layer in enumerate(self.layers): + # inject dynamic augment + tokens_in = token_embeddings + token_augment + token_embeddings = layer(tokens_in, image_features) + + # layer-wise token update (skip last layer) + if self.keypoint_token_update and (layer_idx < len(self.layers) - 1): + if self.use_keypoint_2d_tokens: + token_embeddings, token_augment, pred_xy = self._kp2d_update( + token_embeddings, token_augment, image_features, kp2d_start, H, 
W, vis_mask=vis_mask + ) + if pred_xy is not None: + pred_keypoints_2d_list.append(pred_xy) + + if self.use_keypoint_3d_tokens: + token_embeddings, token_augment, pred_xyz = self._kp3d_update( + token_embeddings, token_augment, kp3d_start + ) + if pred_xyz is not None: + pred_keypoints_3d_list.append(pred_xyz) + + # ---- Regress deltas from param token ---- + token_embeddings = self.norm(token_embeddings) + param_token_out = token_embeddings[:, 0, :] + + delta_pose = self.decpose(param_token_out) + delta_betas = self.decshape(param_token_out) + delta_cam = self.deccam(param_token_out) + + pred_pose = pred_pose + delta_pose + pred_betas = pred_betas + delta_betas + pred_cam = pred_cam + delta_cam + + pred_pose_list.append(pred_pose) + pred_betas_list.append(pred_betas) + pred_cam_list.append(pred_cam) + + # ---- Convert joint representation ---- + joint_conversion_fn = { + "6d": rot6d_to_rotmat, + "aa": lambda y: aa_to_rotmat(y.view(-1, 3).contiguous()), + }[self.joint_rep_type] + + pred_smal_params_list = { + "pose": torch.cat( + [joint_conversion_fn(p).view(B, -1, 3, 3)[:, 1:, :, :] for p in pred_pose_list], + dim=0, + ), + "betas": torch.cat(pred_betas_list, dim=0), + "cam": torch.cat(pred_cam_list, dim=0), + "keypoints_2d": torch.cat(pred_keypoints_2d_list, dim=0) if len(pred_keypoints_2d_list) else None, + "keypoints_3d": torch.cat(pred_keypoints_3d_list, dim=0) if len(pred_keypoints_3d_list) else None, + } + + pred_pose_mat = joint_conversion_fn(pred_pose).view(B, self.cfg.SMAL.NUM_JOINTS + 1, 3, 3) + pred_smal_params = { + "global_orient": pred_pose_mat[:, [0]], + "pose": pred_pose_mat[:, 1:], + "betas": pred_betas, + } + + # expose final predicted keypoints for losses + if self.keypoint_token_update: + if self.use_keypoint_2d_tokens and len(pred_keypoints_2d_list): + pred_smal_params["keypoints_2d"] = pred_keypoints_2d_list[-1] + if self.use_keypoint_3d_tokens and len(pred_keypoints_3d_list): + pred_smal_params["keypoints_3d"] = pred_keypoints_3d_list[-1] + 
+ extra_outputs = { + "shape_feat": shape_feat, + "init_betas": init_betas, + "pred_smal_params_list": pred_smal_params_list, + } + return pred_smal_params, pred_cam, extra_outputs + + # -------------------------- + # Test-time optimization helpers + # -------------------------- + def freeze_all_except_keypoint_tokens(self): + """ + Freeze all parameters except keypoint token embeddings and their prediction heads. + Use this before test-time optimization. + """ + # Freeze everything first + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze only keypoint-related parameters + if self.use_keypoint_2d_tokens: + for param in self.keypoint_2d_embeddings.parameters(): + param.requires_grad = True + for param in self.keypoint_2d_pos_encoder.parameters(): + param.requires_grad = True + for param in self.keypoint_2d_feat_linear.parameters(): + param.requires_grad = True + if self.keypoint_token_update: + for param in self.kp2d_from_tokens.parameters(): + param.requires_grad = True + + if self.use_keypoint_3d_tokens: + for param in self.keypoint_3d_embeddings.parameters(): + param.requires_grad = True + for param in self.keypoint_3d_pos_encoder.parameters(): + param.requires_grad = True + if self.keypoint_token_update: + for param in self.kp3d_from_tokens.parameters(): + param.requires_grad = True + + self._tta_mode = True + print("[TTA] Frozen all parameters except keypoint tokens") + + def freeze_backbone_only(self): + """ + Freeze only backbone, keep SMAL head trainable. + Use for full SMAL parameter + keypoint optimization. + """ + # Unfreeze all SMAL head parameters + for param in self.parameters(): + param.requires_grad = True + + self._tta_mode = True + print("[TTA] SMAL head fully trainable (backbone frozen separately)") + + def freeze_except_regression_heads(self): + """ + Freeze everything except the final regression heads (pose/shape/cam) and keypoint embeddings. + Keep transformer frozen to preserve pretrained representations. 
+ """ + # Freeze everything first + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze only the final regression heads (small MLPs) + for param in self.decpose.parameters(): + param.requires_grad = True + for param in self.decshape.parameters(): + param.requires_grad = True + for param in self.deccam.parameters(): + param.requires_grad = True + + # Unfreeze ONLY keypoint embeddings (learned tokens, NOT position encoders) + if self.use_keypoint_2d_tokens: + self.keypoint_2d_embeddings.weight.requires_grad = True + + if self.use_keypoint_3d_tokens: + self.keypoint_3d_embeddings.weight.requires_grad = True + + # DO NOT unfreeze transformer - keep pretrained representations + # DO NOT unfreeze param_to_token - keep initial token mapping stable + + self._tta_mode = True + print("[TTA] Frozen all except regression heads and keypoint embeddings") + + def unfreeze_all(self): + """Restore all parameters to trainable state.""" + for param in self.parameters(): + param.requires_grad = True + self._tta_mode = False + print("[TTA] Unfrozen all parameters") + + def get_tta_parameters(self, mode='keypoints_only'): + """ + Get list of parameters that should be optimized during test-time adaptation. + MUST match what's unfrozen by freeze methods! 
+ + Args: + mode: 'keypoints_only', 'regression_heads', or 'all' + """ + params = [] + + # Keypoint embeddings only (NOT position encoders or feature linears) + if mode in ['keypoints_only', 'regression_heads', 'all']: + if self.use_keypoint_2d_tokens: + params.append(self.keypoint_2d_embeddings.weight) + + if self.use_keypoint_3d_tokens: + params.append(self.keypoint_3d_embeddings.weight) + + # Regression heads only (NO transformer or param_to_token) + if mode in ['regression_heads', 'all']: + params.extend(list(self.decpose.parameters())) + params.extend(list(self.decshape.parameters())) + params.extend(list(self.deccam.parameters())) + + return params + + + + +class PoseTransformerDecoderLayer(nn.Module): + """ + Single-layer transformer decoder for pose-token aggregation. + Includes self-attention over tokens, cross-attention from tokens to + image features, and a feed-forward network. + """ + + def __init__(self, d_model=1024, nhead=8, dim_feedforward=4096, dropout=0.1): + super().__init__() + + # Self-attention over tokens + self.self_attn = nn.MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=True + ) + self.norm1 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + + # Cross-attention from image features into tokens + self.cross_attn = nn.MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=True + ) + self.norm2 = nn.LayerNorm(d_model) + self.dropout2 = nn.Dropout(dropout) + + # Feed-Forward Network + self.ffn = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + nn.Dropout(dropout), + ) + self.norm3 = nn.LayerNorm(d_model) + + def forward(self, tokens, image_features): + """ + Args: + tokens: [B, N_tokens, C] containing pose and keypoint tokens + image_features: [B, N_pixels, C] image features + + Returns: + tokens: [B, N_tokens, C] updated tokens + """ + + # Self-attention lets tokens exchange information. 
+ attn_output, _ = self.self_attn(tokens, tokens, tokens) + tokens = tokens + self.dropout1(attn_output) + tokens = self.norm1(tokens) + + # Cross-attention injects visual information from image features. + attn_output, _ = self.cross_attn( + query=tokens, + key=image_features, + value=image_features, + ) + tokens = tokens + self.dropout2(attn_output) + tokens = self.norm2(tokens) + + # Feed-Forward Network + ffn_output = self.ffn(tokens) + tokens = tokens + ffn_output + tokens = self.norm3(tokens) + + return tokens + + +class PositionalEncoding2D(nn.Module): + """ + 2D sinusoidal positional encoding for image features. + """ + + def __init__(self, embed_dim=1024, temperature=10000): + super().__init__() + self.embed_dim = embed_dim + self.temperature = temperature + + def forward(self, H, W): + """ + Args: + H, W: height and width of the feature map + + Returns: + pos_encoding: [H*W, embed_dim] + """ + # Build grid coordinates. + y_embed = torch.arange(H, dtype=torch.float32).unsqueeze(1).repeat(1, W) + x_embed = torch.arange(W, dtype=torch.float32).unsqueeze(0).repeat(H, 1) + + # Normalize to [0, 1]. + y_embed = y_embed / H + x_embed = x_embed / W + + # Build frequencies. + dim_t = torch.arange(self.embed_dim // 2, dtype=torch.float32) + dim_t = self.temperature ** (2 * dim_t / self.embed_dim) + + # Apply sine/cosine encoding. + pos_x = x_embed[: , : , None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + + pos_x = torch.stack( + [pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()], dim=3 + ).flatten(2) + pos_y = torch. 
stack( + [pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()], dim=3 + ).flatten(2) + + pos = torch.cat([pos_y, pos_x], dim=2).flatten(0, 1) # [H*W, embed_dim] + + return pos + + diff --git a/prima/models/losses.py b/prima/models/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..2982d17689d36d91ab8d351e354a04c1699ddf75 --- /dev/null +++ b/prima/models/losses.py @@ -0,0 +1,580 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import torch +import torch.nn as nn +import numpy as np +import pickle +import torch.nn.functional as F +from ..utils.geometry import aa_to_rotmat +from typing import Dict + + +def matrix_to_axis_angle(rot_mats: torch.Tensor) -> torch.Tensor: + """Convert rotation matrices (..., 3, 3) to axis-angle vectors (..., 3). + + This local implementation avoids a hard runtime dependency on PyTorch3D. 
+ """ + if rot_mats.shape[-2:] != (3, 3): + raise ValueError(f"Expected (..., 3, 3) rotation matrices, got {rot_mats.shape}") + + trace = rot_mats[..., 0, 0] + rot_mats[..., 1, 1] + rot_mats[..., 2, 2] + cos_theta = (trace - 1.0) * 0.5 + cos_theta = torch.clamp(cos_theta, -1.0, 1.0) + theta = torch.acos(cos_theta) + + vee = torch.stack( + [ + rot_mats[..., 2, 1] - rot_mats[..., 1, 2], + rot_mats[..., 0, 2] - rot_mats[..., 2, 0], + rot_mats[..., 1, 0] - rot_mats[..., 0, 1], + ], + dim=-1, + ) + sin_theta = torch.sin(theta) + eps = 1e-6 + scale = theta / torch.clamp(2.0 * sin_theta, min=eps) + aa = vee * scale.unsqueeze(-1) + + # For very small angles, first-order approximation: aa ~= 0.5 * vee + small = theta < 1e-4 + if small.any(): + aa = torch.where(small.unsqueeze(-1), 0.5 * vee, aa) + return aa + + + +class DepthLoss(nn.Module): + """ + Depth loss between predicted SMAL vertices and GT SMAL vertices. + Compares vertex Z (camera space) after applying camera translation. + Only computes loss for samples that have valid GT SMAL parameters. 
+ """ + def __init__(self, loss_type: str = 'l1'): + super().__init__() + self.loss_type = loss_type + self.l1 = nn.L1Loss(reduction='none') # Changed to 'none' for per-sample masking + self.l2 = nn.MSELoss(reduction='none') # Changed to 'none' for per-sample masking + + def forward(self, + pred_vertices: torch.Tensor, # (B, V, 3) + pred_cam_t: torch.Tensor, # (B, 3) + gt_smal_params: Dict[str, torch.Tensor], + smal_model, # SMAL instance callable + is_axis_angle: Dict[str, torch.Tensor], + gt_cam_t: torch.Tensor = None, # (B, 3) or None -> fallback to pred_cam_t + has_smal_params: Dict[str, torch.Tensor] = None # Added masking support + ) -> torch.Tensor: + batch_size = pred_vertices.shape[0] + device = pred_vertices.device + + # Determine which samples have valid GT SMAL params + # A sample is valid only if it has GT for pose, betas, and global_orient + if has_smal_params is not None: + valid_mask = (has_smal_params['pose'] * + has_smal_params['betas'] * + has_smal_params['global_orient']).bool() + + # If no samples have valid GT, return zero loss + if valid_mask.sum() == 0: + return torch.tensor(0., device=device, dtype=pred_vertices.dtype) + else: + # If not provided, assume all samples are valid + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device) + + # prepare GT params for SMAL + gt_params_for_smal = {} + for k in ['global_orient', 'pose', 'betas']: + val = gt_smal_params[k].to(device=device) + if k == 'betas': + gt_params_for_smal[k] = val.view(batch_size, -1) + else: + gt_val = val.view(batch_size, -1) + if is_axis_angle[k].all(): + gt_val = aa_to_rotmat(gt_val.reshape(-1, 3)).view(batch_size, -1, 3, 3) + else: + gt_val = gt_val.view(batch_size, -1, 3, 3) + gt_params_for_smal[k] = gt_val + + # generate GT vertices (no grad) + with torch.no_grad(): + gt_out = smal_model(**gt_params_for_smal, pose2rot=False) + gt_vertices = gt_out.vertices.view(batch_size, -1, 3) + + if gt_cam_t is None: + gt_cam_t = pred_cam_t + + # depth = z in camera 
class MaskLoss(nn.Module):
    """
    L1 loss between the rendered silhouette of the predicted mesh and the
    rendered silhouette of the ground-truth (GT) SMAL mesh.

    The module relies on a MeshRenderer-like object exposing
    `render_mask(vertices, camera_translation, focal_length)` that returns a
    NumPy array (documented upstream as a single-channel (H, W) mask with
    values 0/1).

    NOTE: both masks are built from detached NumPy copies of the vertices, so
    the returned value carries no autograd history back to `pred_vertices`.
    """
    def __init__(self, mesh_renderer=None):
        super().__init__()
        self.mesh_renderer = mesh_renderer
        self.l1 = nn.L1Loss(reduction='mean')

    def forward(self,
                pred_vertices: torch.Tensor,              # (B, V, 3)
                pred_cam_t: torch.Tensor,                 # (B, 3)
                gt_smal_params: Dict[str, torch.Tensor],
                smal_model,                               # SMAL instance callable
                is_axis_angle: Dict[str, torch.Tensor],
                gt_cam_t: torch.Tensor = None,            # optional (B, 3)
                focal_length: float = 1000.0
                ) -> torch.Tensor:
        """
        Args:
            pred_vertices: predicted mesh vertices, indexed per sample along dim 0.
            pred_cam_t: camera translation used to render the predicted mesh.
                If None, a zero translation is used.
            gt_smal_params: dict with 'global_orient', 'pose' and 'betas'.
            smal_model: callable returning an object with a `.vertices` attribute.
            is_axis_angle: per-key flags; when all flags of a key are set the GT
                rotation is converted from axis-angle to rotation matrices.
            gt_cam_t: camera translation used to render the GT mesh. Falls back
                to the predicted camera when None.
            focal_length: focal length forwarded to the renderer.
        Returns:
            Scalar tensor: mean over successfully rendered samples of the mean
            absolute mask difference, or 0 when no renderer is configured or
            every render failed.
        """
        batch_size = pred_vertices.shape[0]
        device = pred_vertices.device
        dtype = pred_vertices.dtype

        # Without a renderer there is nothing to compare.
        if self.mesh_renderer is None:
            return torch.tensor(0., device=device, dtype=dtype)

        # Bring the GT parameters into the layout expected by the SMAL model.
        gt_params_for_smal = {}
        for k in ['global_orient', 'pose', 'betas']:
            val = gt_smal_params[k].to(device=device).view(batch_size, -1)
            if k != 'betas':
                if is_axis_angle[k].all():
                    val = aa_to_rotmat(val.reshape(-1, 3)).view(batch_size, -1, 3, 3)
                else:
                    val = val.view(batch_size, -1, 3, 3)
            gt_params_for_smal[k] = val

        # GT vertices are a fixed target, so no gradient is tracked.
        with torch.no_grad():
            gt_vertices = smal_model(**gt_params_for_smal, pose2rot=False).vertices

        # The renderer consumes NumPy arrays.
        pred_vertices_np = pred_vertices.detach().cpu().numpy()
        gt_vertices_np = gt_vertices.detach().cpu().numpy()
        if pred_cam_t is not None:
            pred_cam_np = pred_cam_t.detach().cpu().numpy()
        else:
            pred_cam_np = np.zeros((batch_size, 3), dtype=np.float32)
        # Render the GT mesh with its own camera when one is supplied; the
        # previous code resolved gt_cam_t but then ignored it.
        if gt_cam_t is not None:
            gt_cam_np = gt_cam_t.detach().cpu().numpy()
        else:
            gt_cam_np = pred_cam_np

        per_item_losses = []
        for i in range(batch_size):
            try:
                pred_mask = self.mesh_renderer.render_mask(pred_vertices_np[i], pred_cam_np[i], focal_length)
                gt_mask_r = self.mesh_renderer.render_mask(gt_vertices_np[i], gt_cam_np[i], focal_length)
                pm = torch.from_numpy(pred_mask).to(device=device, dtype=dtype)
                gm = torch.from_numpy(gt_mask_r).to(device=device, dtype=dtype)
                per_item_losses.append(self.l1(pm, gm))
            except Exception:
                # Deliberately best-effort: a sample whose render fails is
                # dropped from the average instead of aborting the whole step.
                continue

        if not per_item_losses:
            return torch.tensor(0., device=device, dtype=dtype)

        return torch.stack(per_item_losses).mean()
class Keypoint3DLoss(nn.Module):

    def __init__(self, loss_type: str = 'l1'):
        """
        Root-relative 3D keypoint loss.
        Args:
            loss_type (str): 'l1' or 'l2'; any other value raises NotImplementedError.
        """
        super(Keypoint3DLoss, self).__init__()
        criteria = {'l1': nn.L1Loss, 'l2': nn.MSELoss}
        if loss_type not in criteria:
            raise NotImplementedError('Unsupported loss function')
        # Element-wise loss so the per-keypoint confidence can weight it.
        self.loss_fn = criteria[loss_type](reduction='none')

    def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 0):
        """
        Confidence-weighted loss between root-centred predicted and ground-truth keypoints.
        Args:
            pred_keypoints_3d (torch.Tensor): predicted keypoints; dim 1 is indexed as the
                keypoint axis and the last dim holds the coordinates.
            gt_keypoints_3d (torch.Tensor): ground-truth keypoints with the confidence
                stored in the last channel. The input tensor is not modified.
            pelvis_id (int): index (along dim 1) of the keypoint subtracted from all others.
        Returns:
            torch.Tensor: scalar sum of the weighted element-wise loss.
        """
        # Centre the prediction on its root keypoint.
        pred_root = pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1)
        centred_pred = pred_keypoints_3d - pred_root

        # Same for the ground truth, leaving the confidence channel out.
        gt_coords = gt_keypoints_3d[:, :, :-1]
        centred_gt = gt_coords - gt_coords[:, pelvis_id, :].unsqueeze(dim=1)

        weights = gt_keypoints_3d[:, :, -1].unsqueeze(-1)
        per_sample = (weights * self.loss_fn(centred_pred, centred_gt)).sum(dim=(1, 2))
        return per_sample.sum()
class PosePriorLoss(nn.Module):
    """
    Pose prior penalty applied to the samples that have no ground-truth pose.

    The prior is read from a pickle file providing "mean_pose" and "pic".
    NOTE(review): pickle.load executes arbitrary code; only point this at trusted files.
    """
    def __init__(self, path_prior):
        super(PosePriorLoss, self).__init__()
        with open(path_prior, "rb") as handle:
            prior = pickle.load(handle, encoding="latin1")

        mean_pose = torch.from_numpy(prior["mean_pose"]).float()
        precision = torch.from_numpy(np.array(prior["pic"])).float()

        # 105 = 35 * 3 axis-angle entries; the first 3 (global rotation) are excluded.
        active = torch.ones(105, dtype=torch.float32)
        active[:3] = 0.0

        self.register_buffer("mean_pose", mean_pose)
        self.register_buffer("precs", precision)
        self.register_buffer("use_index", active)

    def forward(self, x, has_gt):
        """
        Args:
            x: rotation matrices, documented upstream as (batch_size, 35, 3, 3)
            has_gt: per-sample flag, non-zero where a ground-truth pose exists
        Returns:
            pose prior loss (zero when every sample has ground truth)
        """
        # The prior only regularises samples lacking GT; nothing to do otherwise.
        if has_gt.sum() == len(has_gt):
            return torch.tensor(0.0, dtype=x.dtype, device=x.device)

        without_gt = ~has_gt.type(torch.bool)
        axis_angle = matrix_to_axis_angle(x[without_gt].reshape(-1, 3, 3))
        residual = axis_angle.reshape(-1, 35*3) - self.mean_pose
        projected = torch.tensordot(residual, self.precs, dims=([1], [0])) * self.use_index
        return (projected ** 2).mean()
class PrototypeSupConLoss(nn.Module):
    """
    Prototype-based contrastive loss: cross-entropy over temperature-scaled
    cosine similarities between features and one prototype per class.
    """
    def __init__(self, prototypes_init, feat_dim=128, temperature=0.1):
        """
        prototypes_init: optional initial prototypes. They are used only when they
            already live in the projected feature space, i.e. have shape (K, feat_dim).
            Anything else (None, or e.g. raw (5, 512) BioCLIP centres) falls back to
            a random (5, feat_dim) initialisation.
        feat_dim: dimension of the projected shape feature (128)
        temperature: softmax temperature applied to the similarities
        """
        super().__init__()
        self.temperature = temperature

        prototypes = self._usable_init(prototypes_init, feat_dim)
        if prototypes is None:
            # The prototypes must live in the projected feature space; without a
            # matching initialisation start from random directions.
            prototypes = torch.randn(5, feat_dim)
        self.register_buffer("prototypes", prototypes)

    @staticmethod
    def _usable_init(prototypes_init, feat_dim):
        """Return a float copy of `prototypes_init` if it is (K, feat_dim), else None."""
        if prototypes_init is None:
            return None
        try:
            candidate = torch.as_tensor(prototypes_init)
        except (TypeError, ValueError, RuntimeError):
            return None
        if candidate.dim() != 2 or candidate.shape[1] != feat_dim:
            return None
        # Copy so later momentum updates never write into the caller's tensor.
        return candidate.detach().clone().to(torch.float32)

    def forward(self, features, labels):
        """
        features: (B, feat_dim) features (normalized here regardless)
        labels: (B,) class indices into the prototypes
        """
        # Cosine similarity needs unit-norm features and prototypes.
        features = F.normalize(features, p=2, dim=1)
        prototypes = F.normalize(self.prototypes, p=2, dim=1)

        # Sample-to-prototype similarities with temperature scaling.
        logits = torch.matmul(features, prototypes.T) / self.temperature

        # Cross-entropy pulls samples toward their own prototype and pushes
        # them away from the others.
        return F.cross_entropy(logits, labels)

    @torch.no_grad()
    def update_prototypes(self, features, labels, momentum=0.999):
        """
        Optional: move each prototype toward the mean feature of its class with
        an exponential moving average. Classes absent from the batch are left as is.
        """
        for i in range(self.prototypes.shape[0]):
            mask = (labels == i)
            if mask.any():
                new_mean = features[mask].mean(dim=0)
                self.prototypes[i] = momentum * self.prototypes[i] + (1 - momentum) * new_mean
anchor_dot_contrast = torch.div( + torch.matmul(anchor_feature, contrast_feature.T), + self.temperature) + # for numerical stability + logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) + logits = anchor_dot_contrast - logits_max.detach() + + # tile mask + mask = mask.repeat(anchor_count, contrast_count) + # mask-out self-contrast cases + logits_mask = torch.scatter( + torch.ones_like(mask), + 1, + torch.arange(batch_size * anchor_count).view(-1, 1).to(device), + 0 + ) + mask = mask * logits_mask + + # compute log_prob + exp_logits = torch.exp(logits) * logits_mask + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) + + # compute mean of log-likelihood over positive + mask_pos_pairs = mask.sum(1) + mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs) + mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs + + # loss + loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos + loss = loss.view(anchor_count, batch_size).mean() + + return loss + +# Auxiliary intermediate-supervision loss module. 
class InterLoss(nn.Module):
    """Auxiliary loss supervising the intermediate keypoint predictions."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        loss_cfg = cfg.LOSS
        self.use_intermediate_supervision = loss_cfg.get('USE_INTERMEDIATE_SUPERVISION', True)
        self.intermediate_weight = loss_cfg.get('INTERMEDIATE_WEIGHT', 0.5)

        # Element-wise criteria so that per-keypoint weights can be applied.
        self.keypoint_2d_loss = nn.MSELoss(reduction='none')
        self.keypoint_3d_loss = nn.MSELoss(reduction='none')

    @staticmethod
    def _weighted_keypoint_loss(criterion, pred_all, gt_coords, gt_weights):
        """Weighted mean of the per-keypoint loss, with the GT tiled to match `pred_all`."""
        # Predictions stack several iterations along dim 0; tile the GT to match.
        num_iters = pred_all.shape[0] // gt_coords.shape[0]
        tiled_coords = gt_coords.repeat(num_iters, 1, 1)
        tiled_weights = gt_weights.repeat(num_iters, 1)

        per_keypoint = criterion(pred_all, tiled_coords).mean(dim=-1)
        # Epsilon keeps the division finite when every weight is zero.
        return (per_keypoint * tiled_weights).sum() / (tiled_weights.sum() + 1e-6)

    def forward(self, predictions, gt_data):
        """
        Args:
            predictions: tuple (pred_smal_params, pred_cam, pred_smal_params_list); only the
                last element is read, for its optional 'keypoints_2d' / 'keypoints_3d' entries
            gt_data: dict containing ground-truth data
                - 'keypoints_2d': [B, N, 3] (x, y, visibility)
                - 'keypoints_3d': [B, N, 3] (x, y, z) or [B, N, 4] (x, y, z, confidence)
        Returns:
            dict with one entry per computed term plus 'total' (the float 0.0 when
            no term was computed).
        """
        _, _, intermediates = predictions

        losses = {}
        total_loss = 0.0

        if not self.use_intermediate_supervision or intermediates is None:
            losses['total'] = total_loss
            return losses

        # 2D keypoints: the visibility flag acts as the per-keypoint weight.
        if 'keypoints_2d' in intermediates and intermediates['keypoints_2d'] is not None:
            gt_2d = gt_data['keypoints_2d']
            term_2d = self._weighted_keypoint_loss(
                self.keypoint_2d_loss,
                intermediates['keypoints_2d'],
                gt_2d[:, :, :2],
                gt_2d[:, :, 2],
            )
            losses['intermediate_keypoints_2d'] = term_2d * self.intermediate_weight
            total_loss = total_loss + losses['intermediate_keypoints_2d']

        # 3D keypoints: use the confidence channel when present, else weight all equally.
        if 'keypoints_3d' in intermediates and intermediates['keypoints_3d'] is not None:
            gt_3d = gt_data['keypoints_3d']
            coords_3d = gt_3d[:, :, :3]
            if gt_3d.shape[-1] == 4:
                conf_3d = gt_3d[:, :, 3]
            else:
                conf_3d = torch.ones_like(coords_3d[:, :, 0])
            term_3d = self._weighted_keypoint_loss(
                self.keypoint_3d_loss,
                intermediates['keypoints_3d'],
                coords_3d,
                conf_3d,
            )
            losses['intermediate_keypoints_3d'] = term_3d * self.intermediate_weight
            total_loss = total_loss + losses['intermediate_keypoints_3d']

        losses['total'] = total_loss
        return losses
None): # pretrained exists and not none, then true + + log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}') + state_dict = torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu', weights_only=True)['state_dict'] + state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()} + + missing_keys, unexpected_keys = self.backbone.load_state_dict(state_dict, strict=False) + + + # freeze backbones + if cfg.MODEL.BACKBONE.get('FREEZE', False) and cfg.MODEL.BACKBONE.TYPE == 'vith': + log.info(f'Freezing first 2/3 blocks of vit backbone') + # Freeze patch embedding + if hasattr(self.backbone, 'patch_embed'): + for p in self.backbone.patch_embed.parameters(): + p.requires_grad = False + + # Freeze first 2/3 of transformer blocks + if hasattr(self.backbone, 'blocks'): + total_blocks = len(self.backbone.blocks) + freeze_blocks = int(total_blocks * 2 / 3) + log.info(f'Freezing {freeze_blocks} out of {total_blocks} blocks') + for i in range(freeze_blocks): + for p in self.backbone.blocks[i].parameters(): + p.requires_grad = False + + # Create SMAL head (predicts SMAL params + perspective camera) + self.smal_head = build_smal_head(cfg) + + # Instantiate SMAL model + smal_model_path = cfg.SMAL.MODEL_PATH + with open(smal_model_path, 'rb') as f: + smal_cfg = pickle.load(f, encoding="latin1") + self.smal = SMAL(**smal_cfg) + + # create bioclip model for species classification token extraction + use_bioclip_embedding = cfg.MODEL.get('USE_BIOCLIP_EMBEDDING', False) + if use_bioclip_embedding: + bioclip_config = cfg.MODEL.get('BIOCLIP_EMBEDDING', {}) + embed_dim = bioclip_config.get('EMBED_DIM', 1280) + self.bioclip_embedding = BioClipEmbedding(cfg, embed_dim=embed_dim) + # Freeze BioClip model by default + for param in self.bioclip_embedding.species_model.parameters(): + param.requires_grad = False + else: + self.bioclip_embedding = None + + # Create discriminator + self.discriminator = Discriminator() + + + + + # Define loss 
def get_parameters(self):
    """
    Collect the parameters handed to the main (non-discriminator) optimizer.

    Always includes the SMAL head; adds the backbone for the listed backbone
    types, the keypoint projection when that attribute exists and is set, and
    only the projection layer of the BioCLIP embedding.
    Returns:
        list: parameters in the order head, backbone, keypoint projection, BioCLIP projection
    """
    params = list(self.smal_head.parameters())

    if self.cfg.MODEL.BACKBONE.TYPE in ['vith', 'dinov2', 'dinov3']:
        params.extend(self.backbone.parameters())

    keypoint_projection = getattr(self, 'keypoint_projection', None)
    if keypoint_projection is not None:
        params.extend(keypoint_projection.parameters())

    bioclip_embedding = getattr(self, 'bioclip_embedding', None)
    if bioclip_embedding is not None:
        # Only the projection is collected here, not the embedding model itself.
        params.extend(bioclip_embedding.projection.parameters())

    return params
def forward_step(self, batch: Dict, train: bool = False) -> Dict:
    """
    Run a forward step of the network
    Args:
        batch (Dict): Dictionary containing batch data; reads 'img' and 'focal_length'
        train (bool): Flag indicating whether it is training or validation mode
            (currently not used by this method)
    Returns:
        Dict: Dictionary containing the regression output
    Raises:
        NotImplementedError: if the configured backbone type is not 'vith', the only
            type for which conditioning features are computed here
    """
    # Use RGB image as input
    x = batch['img']  # [B, 3, H, W]
    batch_size = x.shape[0]

    # Fail early with a clear message instead of hitting an unbound
    # `conditioning_feats` further down.
    backbone_type = self.cfg.MODEL.BACKBONE.TYPE
    if backbone_type != 'vith':
        raise NotImplementedError(
            f"forward_step only computes conditioning features for the 'vith' backbone, got {backbone_type!r}")

    # Compute conditioning features using the backbone; 32 columns are cropped
    # from each side of the width before the image is fed to the ViT.
    conditioning_feats, _ = self.backbone(x[:, :, :, 32:-32])
    if conditioning_feats.ndim == 4:
        # Flatten spatial dimensions into a token dimension: [B, D, Hp, Wp] -> [B, Hp*Wp, D]
        n, channels, height, width = conditioning_feats.shape
        conditioning_feats = conditioning_feats.permute(0, 2, 3, 1).reshape(n, height * width, channels)

    # add bioclip embedding if enabled
    if self.bioclip_embedding is not None:
        species_feature = self.bioclip_embedding(batch['img'])  # [B, embed_dim]

        if len(conditioning_feats.shape) == 3:
            # Append the species feature as one extra token: (B, T, C) -> (B, T + 1, C).
            # This requires C == embed_dim.
            species_token = species_feature.unsqueeze(1)
            conditioning_feats = torch.cat([conditioning_feats, species_token], dim=1)
        else:
            # 2D features (B, C): concatenate along the feature dimension instead.
            conditioning_feats = torch.cat([conditioning_feats, species_feature], dim=-1)

    # Predict SMAL parameters and camera
    pred_smal_params, pred_cam, extra_outputs = self.smal_head(conditioning_feats)

    # Store useful regression outputs to the output dict
    output = {}

    if 'shape_feat' in extra_outputs:
        output['shape_feat'] = extra_outputs['shape_feat']

    if 'init_betas' in extra_outputs:
        output['init_betas'] = extra_outputs['init_betas'].reshape(batch_size, -1)

    output['pred_cam'] = pred_cam  # [B, 3]
    # Cloned so the reshapes applied to pred_smal_params below do not leak into the output.
    output['pred_smal_params'] = {k: v.clone() for k, v in pred_smal_params.items()}

    # Compute camera translation from the predicted (scale, tx, ty) camera
    focal_length = batch['focal_length']

    pred_cam_t = torch.stack([
        pred_cam[:, 1],
        pred_cam[:, 2],
        2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)
    ], dim=-1)  # [B, 3]

    output['pred_cam_t'] = pred_cam_t  # [B, 3]
    output['focal_length'] = focal_length  # [B, 2]

    # Compute model vertices, joints and the projected joints
    pred_smal_params['global_orient'] = pred_smal_params['global_orient'].reshape(batch_size, -1, 3, 3)
    pred_smal_params['pose'] = pred_smal_params['pose'].reshape(batch_size, -1, 3, 3)
    pred_smal_params['betas'] = pred_smal_params['betas'].reshape(batch_size, -1)
    smal_output = self.smal(**pred_smal_params, pose2rot=False)

    pred_keypoints_3d = smal_output.joints
    pred_vertices = smal_output.vertices
    output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
    output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)

    # project 3D keypoints to 2D
    pred_keypoints_2d = perspective_projection(
        pred_keypoints_3d,
        translation=pred_cam_t,
        focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE
    )  # [B, num_joints, 2]
    output['pred_keypoints_2d'] = pred_keypoints_2d

    # get intermediate keypoint predictions if the head provides them
    if 'keypoints_3d' in pred_smal_params and pred_smal_params['keypoints_3d'] is not None:
        output['inter_keypoints_3d'] = pred_smal_params['keypoints_3d'].reshape(batch_size, -1, 3)

    if 'keypoints_2d' in pred_smal_params and pred_smal_params['keypoints_2d'] is not None:
        output['inter_keypoints_2d'] = pred_smal_params['keypoints_2d'].reshape(batch_size, -1, 2)

    return output
self.intermediate_kp3d_loss(inter_keypoints_3d, gt_keypoints_3d, pelvis_id=0) + # loss_keypoints_3d = loss_keypoints_3d + loss_intermediate_kp3d + + # add intermediate keypoint losses if available + + # Compute loss on SMAL parameters + loss_smal_params = {} + for k, pred in pred_smal_params.items(): + # Skip keypoint predictions - they're handled separately + if k in ['keypoints_2d', 'keypoints_3d']: + continue + + gt = gt_smal_params[k].view(batch_size, -1) + if is_axis_angle[k].all(): + gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3) + has_gt = has_smal_params[k] + + # Only compute parameter loss if ANY sample has GT + param_loss = self.smal_parameter_loss(pred.reshape(batch_size, -1), + gt.reshape(batch_size, -1), + has_gt) + + if k == "betas": + # Only add shape prior loss if NOT all samples have GT (prior is regularization for samples without GT) + # But the shape_prior_loss already handles this check internally + loss_smal_params[k] = param_loss + self.shape_prior_loss(pred, batch["category"], has_gt) + if 'init_betas' in output: + init_betas = output['init_betas'] + loss_smal_params[k] = loss_smal_params[k] + self.shape_prior_loss(init_betas, batch["category"], has_gt) / 2. + + else: + # Only add pose prior loss if NOT all samples have GT + # The pose_prior_loss already handles this check internally + loss_smal_params[k] = param_loss + \ + self.pose_prior_loss(torch.cat((pred_smal_params["global_orient"], + pred_smal_params["pose"]), + dim=1), has_gt) / 2. 
def training_step_discriminator(self, batch: Dict,
                                pose: torch.Tensor,
                                betas: torch.Tensor,
                                optimizer: torch.optim.Optimizer) -> torch.Tensor:
    """
    Run one least-squares update of the discriminator
    Args:
        batch (Dict): Dictionary containing mocap batch data ('pose' in axis-angle, 'betas')
        pose (torch.Tensor): Regressed pose from current step
        betas (torch.Tensor): Regressed betas from current step
        optimizer (torch.optim.Optimizer): Discriminator optimizer
    Returns:
        torch.Tensor: Detached, unweighted discriminator loss
    """
    batch_size = pose.shape[0]

    def _mean_squared_gap(scores, target):
        # Squared distance to the target label, summed and averaged over the batch.
        return ((scores - target) ** 2).sum() / batch_size

    real_pose = batch['pose']
    real_betas = batch['betas']
    real_rotmat = aa_to_rotmat(real_pose.view(-1, 3)).view(batch_size, -1, 3, 3)

    # Regressed samples are pushed toward 0, mocap samples toward 1. Inputs are
    # detached so only the discriminator receives gradients.
    fake_scores = self.discriminator(pose.detach(), betas.detach())
    loss_fake = _mean_squared_gap(fake_scores, 0.0)
    real_scores = self.discriminator(real_rotmat.detach(), real_betas.detach())
    loss_real = _mean_squared_gap(real_scores, 1.0)

    loss_disc = loss_fake + loss_real
    weighted_loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc

    # Manual optimization step.
    optimizer.zero_grad()
    self.manual_backward(weighted_loss)
    optimizer.step()
    return loss_disc.detach()
def training_step(self, batch: Dict) -> Dict:
    """
    Run a full training step with manual optimisation.
    Args:
        batch (Dict): Wrapper dictionary whose 'img' entry is the actual sample
            dictionary (it must provide at least 'img', and 'smal_params' when the
            adversarial loss is enabled).
    Returns:
        Dict: Dictionary containing regression output; ``output['losses']`` holds
            detached loss values only.
    Raises:
        ValueError: If the total loss is NaN.
    """
    batch = batch['img']
    # Evaluate the switch once so the three adversarial branches cannot disagree.
    use_adversarial = self.cfg.LOSS_WEIGHTS.get("ADVERSARIAL", 0) > 0

    optimizer = self.optimizers(use_pl_optimizer=True)
    if use_adversarial:
        optimizer, optimizer_disc = optimizer

    batch_size = batch['img'].shape[0]
    output = self.forward_step(batch, train=True)
    pred_smal_params = output['pred_smal_params']
    loss = self.compute_loss(batch, output, train=True)
    if use_adversarial:
        # Generator side: push the discriminator score of the prediction towards 1.
        disc_out = self.discriminator(pred_smal_params['pose'].reshape(batch_size, -1),
                                      pred_smal_params['betas'].reshape(batch_size, -1))
        loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size
        loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv

    # Error if NaN
    if torch.isnan(loss):
        raise ValueError('Loss is NaN')

    optimizer.zero_grad()
    self.manual_backward(loss)
    # Clip gradient
    if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
        gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL,
                                            error_if_nonfinite=True)
        self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    optimizer.step()
    if use_adversarial:
        loss_disc = self.training_step_discriminator(batch['smal_params'],
                                                     pred_smal_params['pose'].reshape(batch_size, -1),
                                                     pred_smal_params['betas'].reshape(batch_size, -1),
                                                     optimizer_disc)
        # Detach like every other entry of output['losses'] so the logged value
        # does not keep a reference to the autograd graph.
        output['losses']['loss_gen'] = loss_adv.detach()
        output['losses']['loss_disc'] = loss_disc

    if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
        self.tensorboard_logging(batch, output, self.global_step, train=True)

    # Log training loss to the logger so checkpoint callback can monitor it.
    self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True,
             logger=True, batch_size=batch_size, sync_dist=True)

    return output
'val/loss' + self.log(f'val/{loss_name}', val, on_step=False, on_epoch=True, prog_bar=prog, logger=True, + sync_dist=True) + + # Periodically write images/other visuals to tensorboard + # Log visualizations on the first batch of each validation epoch + if batch_idx == 0: + # Use global_step for step count when logging validation visuals + self.tensorboard_logging(batch, output, self.global_step, train=False) + + return output diff --git a/prima/models/smal_wrapper.py b/prima/models/smal_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ea9de81756a8a8535be645f854843db1017f9d82 --- /dev/null +++ b/prima/models/smal_wrapper.py @@ -0,0 +1,134 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import json +from torch import nn +import torch +import numpy as np +import pickle +import cv2 +from typing import Optional, Tuple, NewType +from dataclasses import dataclass +import smplx +from smplx.lbs import vertices2joints, lbs +from smplx.utils import MANOOutput, to_tensor, ModelOutput +from smplx.vertex_ids import vertex_ids + +Tensor = NewType('Tensor', torch.Tensor) +keypoint_vertices_idx = [[1068, 1080, 1029, 1226], [2660, 3030, 2675, 3038], [910], [360, 1203, 1235, 1230], + [3188, 3156, 2327, 3183], [1976, 1974, 1980, 856], [3854, 2820, 3852, 3858], [452, 1811], + [416, 235, 182], [2156, 2382, 2203], [829], [2793], [60, 114, 186, 59], + [2091, 2037, 2036, 2160], [384, 799, 1169, 431], [2351, 2763, 2397, 3127], + [221, 104], [2754, 2192], [191, 1158, 3116, 2165], + [28, 1109, 1110, 1111, 1835, 1836, 3067, 3068, 3069], + [498, 499, 500, 501, 502, 503], [2463, 2464, 2465, 2466, 2467, 2468], + [764, 915, 916, 917, 934, 935, 956], [2878, 2879, 2880, 2897, 2898, 2919, 3751], + 
class SMALLayer(nn.Module):
    """
    Layer that poses the SMAL template mesh with linear blend skinning.

    All model data (shape/pose blend shapes, template, joint regressor, skinning
    weights, faces, kinematic tree) is passed through ``kwargs`` and stored as
    buffers, so the layer has no trainable parameters.
    """

    def __init__(self, num_betas=41, **kwargs):
        super().__init__()
        self.num_betas = num_betas
        # Shape notes below are the original author's annotations — TODO confirm against the model file.
        self.register_buffer("shapedirs", torch.from_numpy(np.array(kwargs['shapedirs'], dtype=np.float32))[:, :, :num_betas])  # [3889, 3, 41]
        self.register_buffer("v_template", torch.from_numpy(np.array(kwargs['v_template']).astype(np.float32)))  # [3889, 3]
        self.register_buffer("posedirs", torch.from_numpy(np.array(kwargs['posedirs'], dtype=np.float32)).reshape(-1,
                                                                                                                    34*9).T)  # [34*9, 11667]
        self.register_buffer("J_regressor", torch.from_numpy(kwargs['J_regressor'].toarray().astype(np.float32)))  # [33, 3389]
        self.register_buffer("lbs_weights", torch.from_numpy(np.array(kwargs['weights'], dtype=np.float32)))  # [3889, 33]
        self.register_buffer("faces", torch.from_numpy(np.array(kwargs['f'], dtype=np.int32)))  # [7774, 3]

        kintree_table = kwargs['kintree_table']
        self.register_buffer("parents", torch.from_numpy(kintree_table[0].astype(np.int32)))

    def forward(
            self,
            betas: Optional[Tensor] = None,
            global_orient: Optional[Tensor] = None,
            pose: Optional[Tensor] = None,
            transl: Optional[Tensor] = None,
            return_verts: bool = True,
            return_full_pose: bool = False,
            **kwargs):
        """
        Args:
            betas: [batch_size, num_betas] shape coefficients; zeros when omitted
            global_orient: [batch_size, 1, 3, 3] root rotation; identity when omitted
            pose: [batch_size, 34, 3, 3] joint rotations; identity when omitted
            transl: [batch_size, 3] translation added to vertices and joints; zeros when omitted
            return_verts: when False, ``vertices`` and ``joints`` are None in the output
            return_full_pose: when True, the concatenated [global_orient, pose] is returned
            **kwargs: ignored
        Returns:
            SMALOutput with vertices, joints, betas, global_orient, pose, transl, full_pose
        """
        # Take device/dtype from the first tensor actually supplied (betas first, as
        # before); fall back to the template buffer so that omitting betas no longer
        # fails on `None.device`.
        provided = [t for t in (betas, global_orient, pose, transl) if t is not None]
        reference = provided[0] if provided else self.v_template
        device, dtype = reference.device, reference.dtype

        # global_orient keeps priority for the batch size; otherwise use whichever
        # input was given instead of silently assuming a batch of one.
        if global_orient is not None:
            batch_size = global_orient.shape[0]
        elif provided:
            batch_size = provided[0].shape[0]
        else:
            batch_size = 1

        if global_orient is None:
            global_orient = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
        if pose is None:
            pose = torch.eye(3, device=device, dtype=dtype).view(
                1, 1, 3, 3).expand(batch_size, 34, -1, -1).contiguous()
        if betas is None:
            betas = torch.zeros(
                [batch_size, self.num_betas], dtype=dtype, device=device)
        if transl is None:
            transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)

        full_pose = torch.cat([global_orient, pose], dim=1)
        vertices, joints = lbs(betas, full_pose, self.v_template,
                               self.shapedirs, self.posedirs,
                               self.J_regressor, self.parents,
                               self.lbs_weights, pose2rot=False)

        # transl is always a tensor at this point (defaulted to zeros above).
        joints = joints + transl.unsqueeze(dim=1)
        vertices = vertices + transl.unsqueeze(dim=1)

        output = SMALOutput(
            vertices=vertices if return_verts else None,
            joints=joints if return_verts else None,
            betas=betas,
            global_orient=global_orient,
            pose=pose,
            transl=transl,
            full_pose=full_pose if return_full_pose else None,
        )
        return output
0000000000000000000000000000000000000000..9b323a57a00b59045862cfeded8f431418235b24 --- /dev/null +++ b/prima/utils/__init__.py @@ -0,0 +1,45 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from typing import Any + + +def recursive_to(x: Any, target: Any): + """ + Recursively transfer a batch of data to the target device + Args: + x (Any): Batch of data. + target (torch.device): Target device. + Returns: + Batch of data where all tensors are transferred to the target device. + """ + import torch + + def move(value: Any): + if isinstance(value, dict): + return {k: move(v) for k, v in value.items()} + if isinstance(value, torch.Tensor): + return value.to(target) + if isinstance(value, list): + return [move(i) for i in value] + return value + + return move(x) + + +def __getattr__(name: str): + if name == "MeshRenderer": + from .mesh_renderer import MeshRenderer + + globals()[name] = MeshRenderer + return MeshRenderer + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = ["MeshRenderer", "recursive_to"] diff --git a/prima/utils/detection.py b/prima/utils/detection.py new file mode 100644 index 0000000000000000000000000000000000000000..72dcae2a21ac74e3fd3592a1c0c0bda324376a5f --- /dev/null +++ b/prima/utils/detection.py @@ -0,0 +1,118 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from __future__ import annotations + +# Utilities for filtering animal detections before PRIMA demo inference. 
+# +# Detectron2 may return both a full-animal box and a local/partial box for the +# same animal. These helpers keep the demo pipeline from rendering the same +# animal multiple times. + +from typing import Iterable + +import numpy as np + +ANIMAL_COCO_IDS = (15, 16, 17, 18, 19, 21, 22) + + +def _box_areas(boxes: np.ndarray) -> np.ndarray: + widths = np.maximum(0.0, boxes[:, 2] - boxes[:, 0]) + heights = np.maximum(0.0, boxes[:, 3] - boxes[:, 1]) + return widths * heights + + +def _intersection_areas(box: np.ndarray, boxes: np.ndarray) -> np.ndarray: + x1 = np.maximum(box[0], boxes[:, 0]) + y1 = np.maximum(box[1], boxes[:, 1]) + x2 = np.minimum(box[2], boxes[:, 2]) + y2 = np.minimum(box[3], boxes[:, 3]) + return np.maximum(0.0, x2 - x1) * np.maximum(0.0, y2 - y1) + + +def _suppress_duplicate_boxes( + boxes: np.ndarray, + scores: np.ndarray, + *, + iou_threshold: float, + containment_threshold: float, +) -> np.ndarray: + if len(boxes) <= 1: + return np.arange(len(boxes), dtype=np.int64) + + boxes = boxes.astype(np.float32, copy=False) + scores = scores.astype(np.float32, copy=False) + areas = _box_areas(boxes) + + contained = np.zeros(len(boxes), dtype=bool) + for idx, area in enumerate(areas): + if area <= 0: + contained[idx] = True + continue + larger = np.where(areas > area)[0] + if len(larger) == 0: + continue + covered = _intersection_areas(boxes[idx], boxes[larger]) / area + if np.any(covered >= containment_threshold): + contained[idx] = True + + candidates = np.where(~contained)[0] + if len(candidates) <= 1: + return candidates + + order = candidates[np.argsort(scores[candidates])[::-1]] + keep = [] + while len(order) > 0: + current = order[0] + keep.append(current) + rest = order[1:] + if len(rest) == 0: + break + + inter = _intersection_areas(boxes[current], boxes[rest]) + union = areas[current] + areas[rest] - inter + iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0) + order = rest[iou <= iou_threshold] + + return 
np.array(sorted(keep), dtype=np.int64) + + +def select_animal_boxes( + det_instances, + *, + animal_class_ids: Iterable[int] = ANIMAL_COCO_IDS, + score_threshold: float = 0.7, + iou_threshold: float = 0.5, + containment_threshold: float = 0.9, +) -> tuple[np.ndarray, int]: + """Return filtered animal boxes and the number of duplicate boxes removed.""" + class_ids = set(int(class_id) for class_id in animal_class_ids) + classes = det_instances.pred_classes.detach().cpu().numpy() + scores = det_instances.scores.detach().cpu().numpy() + + valid_idx = np.array( + [ + i + for i, (class_id, score) in enumerate(zip(classes, scores)) + if int(class_id) in class_ids and float(score) > float(score_threshold) + ], + dtype=np.int64, + ) + if len(valid_idx) == 0: + return np.zeros((0, 4), dtype=np.float32), 0 + + boxes = det_instances.pred_boxes.tensor[valid_idx].detach().cpu().numpy() + scores = scores[valid_idx] + keep = _suppress_duplicate_boxes( + boxes, + scores, + iou_threshold=iou_threshold, + containment_threshold=containment_threshold, + ) + return boxes[keep], int(len(boxes) - len(keep)) diff --git a/prima/utils/evaluate_metric.py b/prima/utils/evaluate_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..39720bd53098b21a5cfbe06bf18317c9f217066b --- /dev/null +++ b/prima/utils/evaluate_metric.py @@ -0,0 +1,206 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +import torch +import numpy as np +import open3d as o3d +from typing import Dict, List, Union +from pytorch3d.transforms import axis_angle_to_matrix + + +def compute_scale_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor: + """ + Computes a scale transform (s) in a batched way that takes + a set of 3D 
def compute_similarity_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor:
    """
    Computes a similarity transform (sR, t) in a batched way that takes
    a set of 3D points S1 (B, N, 3) closest to a set of 3D points S2 (B, N, 3),
    where R is a 3x3 rotation matrix, t 3x1 translation, s scale.
    i.e. solves the orthogonal Procrustes problem.
    Args:
        S1 (torch.Tensor): First set of points of shape (B, N, 3).
        S2 (torch.Tensor): Second set of points of shape (B, N, 3).
    Returns:
        (torch.Tensor): The first set of points after applying the similarity
            transformation, as float32 with shape (B, N, 3). The result is
            non-finite when all points of a sample in S1 coincide (zero variance).
    """

    batch_size = S1.shape[0]
    # Cast both inputs once: the solve runs in float32, and casting only some of
    # the operands made float64 / half inputs fail inside matmul.
    S1 = S1.float().permute(0, 2, 1)
    S2 = S2.float().permute(0, 2, 1)

    # 1. Remove mean.
    mu1 = S1.mean(dim=2, keepdim=True)
    mu2 = S2.mean(dim=2, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = (X1 ** 2).sum(dim=(1, 2))

    # 3. The outer product of X1 and X2.
    K = torch.matmul(X1, X2.permute(0, 2, 1))

    # 4. Solution that maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K.
    #    torch.linalg.svd returns V' directly (torch.svd is deprecated).
    U, _, Vh = torch.linalg.svd(K)
    V = Vh.permute(0, 2, 1)

    # Construct Z that fixes the orientation of R to get det(R)=1.
    Z = torch.eye(U.shape[1], device=U.device, dtype=U.dtype).unsqueeze(0).repeat(batch_size, 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.linalg.det(torch.matmul(U, Vh)))

    # Construct R.
    R = torch.matmul(torch.matmul(V, Z), U.permute(0, 2, 1))

    # 5. Recover scale.
    trace = torch.matmul(R, K).diagonal(offset=0, dim1=-1, dim2=-2).sum(dim=-1)
    scale = (trace / var1).unsqueeze(dim=-1).unsqueeze(dim=-1)

    # 6. Recover translation.
    t = mu2 - scale * torch.matmul(R, mu1)

    # 7. Apply the transform to the original (un-centred) points.
    S1_hat = scale * torch.matmul(R, S1) + t

    return S1_hat.permute(0, 2, 1)
dtype=torch.float32).view(-1, 1, 1) # (T, 1, 1) + hits = (norm_dist.unsqueeze(0) < thresholds).float() # (T, B, K) + pcks = (hits * conf.unsqueeze(0)).sum(dim=-1) / total_visible.unsqueeze(0) # (T, B) + return pcks.mean(dim=1) # (T,) + + def compute_pa_mpjpe(self, pred_joints, gt_joints): + S1_hat = compute_similarity_transform(pred_joints, gt_joints) + pa_mpjpe = torch.sqrt(((S1_hat - gt_joints) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * 1000 + return pa_mpjpe.mean() + + def compute_pa_mpvpe(self, gt_vertices: torch.Tensor, pred_vertices: torch.Tensor): + batch_size = pred_vertices.shape[0] + S1_hat = compute_similarity_transform(pred_vertices, gt_vertices) + pa_mpvpe = torch.sqrt(((S1_hat - gt_vertices) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() * 1000 + return pa_mpvpe.mean() + + def eval_3d(self, output: Dict, batch: Dict): + """ + Evaluate current batch + Args: + output: model output + batch: model input + Returns: evaluate metric + """ + if batch['has_smal_params']["betas"].sum() == 0: + return 0., 0. 
+ + pred_keypoints_3d = output["pred_keypoints_3d"].detach() + pred_keypoints_3d = pred_keypoints_3d[:, None, :, :] + batch_size = pred_keypoints_3d.shape[0] + num_samples = pred_keypoints_3d.shape[1] + gt_keypoints_3d = batch['keypoints_3d'][:, :, :-1].unsqueeze(1).repeat(1, num_samples, 1, 1) + gt_vertices = self.smal_forward(batch) + + # Align predictions and ground truth such that the pelvis location is at the origin + pred_keypoints_3d -= pred_keypoints_3d[:, :, [self.pelvis_ind]] + gt_keypoints_3d -= gt_keypoints_3d[:, :, [self.pelvis_ind]] + + pa_mpjpe = self.compute_pa_mpjpe(pred_keypoints_3d.reshape(batch_size * num_samples, -1, 3), + gt_keypoints_3d.reshape(batch_size * num_samples, -1, 3)) + pa_mpvpe = self.compute_pa_mpvpe(gt_vertices, output['pred_vertices']) + return pa_mpjpe, pa_mpvpe + + def eval_2d(self, output: Dict, batch: Dict, pck_threshold: List[float]=[0.10, 0.15]): + pck = self.compute_pck(output, batch, pck_threshold=pck_threshold) + auc = self.compute_auc(batch, output) + return pck.tolist(), auc + + def compute_auc(self, batch: Dict, output: Dict, threshold_min: float=0.0, threshold_max: float=1.0, steps: int=100): + thresholds = np.linspace(threshold_min, threshold_max, steps) + pck_curve = self.compute_pck(output, batch, thresholds.tolist()).numpy() # (steps,) + norm_factor = threshold_max - threshold_min + auc = float(np.trapz(pck_curve, thresholds) / norm_factor) + return auc + + def smal_forward(self, batch: Dict): + batch_size = batch['img'].shape[0] + smal_params = batch['smal_params'] + smal_params['global_orient'] = axis_angle_to_matrix(smal_params['global_orient'].reshape(batch_size, -1)).unsqueeze(1) + smal_params['pose'] = axis_angle_to_matrix(smal_params['pose'].reshape(batch_size, -1, 3)) + # The SMAL model only registers buffers (e.g. shapedirs) and has no trainable parameters, + # so self.smal_model.parameters() can be empty and calling next on it would raise StopIteration. 
+ # Here we first try to get the device from parameters; if there are no parameters, fall back to buffers; + # if there are no buffers either, fall back to the device of the input batch. + try: + device = next(self.smal_model.parameters()).device + except StopIteration: + try: + device = next(self.smal_model.buffers()).device + except StopIteration: + device = batch['img'].device + smal_params = {k: v.to(device) for k, v in smal_params.items()} + with torch.no_grad(): + smal_output = self.smal_model(**smal_params) + vertices = smal_output.vertices + return vertices diff --git a/prima/utils/geometry.py b/prima/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc16e8faf270d761a35dbb7d04381a6cf8ab0c0 --- /dev/null +++ b/prima/utils/geometry.py @@ -0,0 +1,115 @@ +""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from typing import Optional +import torch +from torch.nn import functional as F + + +def aa_to_rotmat(theta: torch.Tensor): + """ + Convert axis-angle representation to rotation matrix. + Works by first converting it to a quaternion. + Args: + theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations. + Returns: + torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3). + """ + norm = torch.norm(theta + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim=1) + return quat_to_rotmat(quat) + + +def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor: + """ + Convert quaternion representation to rotation matrix. 
def perspective_projection(points: torch.Tensor,
                           translation: torch.Tensor,
                           focal_length: torch.Tensor,
                           camera_center: Optional[torch.Tensor] = None,
                           rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Project 3D points to the image plane with a pinhole camera.
    Args:
        points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
        translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
        focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
        camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels;
            zeros when omitted.
        rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation;
            identity when omitted.
    Returns:
        torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
    """
    num_batches = points.shape[0]
    device, dtype = points.device, points.dtype

    if rotation is None:
        rotation = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(num_batches, -1, -1)
    if camera_center is None:
        camera_center = torch.zeros(num_batches, 2, device=device, dtype=dtype)

    # Intrinsic matrix: focal lengths on the diagonal, principal point in the last column.
    intrinsics = torch.zeros([num_batches, 3, 3], device=device, dtype=dtype)
    intrinsics[:, 0, 0] = focal_length[:, 0]
    intrinsics[:, 1, 1] = focal_length[:, 1]
    intrinsics[:, 2, 2] = 1.
    intrinsics[:, :-1, -1] = camera_center

    # Rotate, then translate the points.
    cam_points = torch.einsum('bij,bkj->bki', rotation, points) + translation.unsqueeze(1)

    # Divide by depth, then map to pixels.
    normalized = cam_points / cam_points[..., -1:]
    pixels = torch.einsum('bij,bkj->bki', intrinsics, normalized)

    return pixels[..., :-1]
+ os.environ['PYOPENGL_PLATFORM'] = 'egl' if find_library('EGL') else 'osmesa' + if os.environ['PYOPENGL_PLATFORM'] == 'egl': + os.environ.setdefault('EGL_PLATFORM', 'surfaceless') +import torch +from torchvision.utils import make_grid +import numpy as np +import pyrender +import trimesh +import cv2 +import math +import torch.nn.functional as F +from typing import List, Tuple + + +def create_raymond_lights(): + import pyrender + thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]) + phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0]) + + nodes = [] + + for phi, theta in zip(phis, thetas): + xp = np.sin(theta) * np.cos(phi) + yp = np.sin(theta) * np.sin(phi) + zp = np.cos(theta) + + z = np.array([xp, yp, zp]) + z = z / np.linalg.norm(z) + x = np.array([-z[1], z[0], 0.0]) + if np.linalg.norm(x) == 0: + x = np.array([1.0, 0.0, 0.0]) + x = x / np.linalg.norm(x) + y = np.cross(z, x) + + matrix = np.eye(4) + matrix[:3, :3] = np.c_[x, y, z] + nodes.append(pyrender.Node( + light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0), + matrix=matrix + )) + + return nodes + + +def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]: + """ + Compute rectangle enclosing keypoints above the threshold. + Args: + keypoints (np.array): Keypoint array of shape (N, 3). + threshold (float): Confidence visualization threshold. + Returns: + Tuple[float, float, float]: Rectangle width, height and area. 
+ """ + valid_ind = keypoints[:, -1] > threshold + if valid_ind.sum() > 0: + valid_keypoints = keypoints[valid_ind][:, :-1] + max_x = valid_keypoints[:, 0].max() + max_y = valid_keypoints[:, 1].max() + min_x = valid_keypoints[:, 0].min() + min_y = valid_keypoints[:, 1].min() + width = max_x - min_x + height = max_y - min_y + area = width * height + return width, height, area + else: + return 0, 0, 0 + + +def render_keypoint(img: np.array, keypoint: np.array, threshold=0.1, + use_confidence=False, map_fn=lambda x: np.ones_like(x), alpha=1.0) -> np.array: + if use_confidence and map_fn is not None: + thicknessCircleRatioRight = 1. / 50 * map_fn(keypoint[:, -1]) + else: + thicknessCircleRatioRight = 1. / 50 * np.ones(keypoint.shape[0]) + + thicknessLineRatioWRTCircle = 0.75 + if keypoint.shape[0] == 26: + pairs = [0, 24, 1, 24, 2, 24, 3, 14, 4, 15, 5, 16, 6, 17, 7, 18, 8, 12, 9, 13, 10, 7, 11, 7, + 12, 18, 13, 18, 14, 8, 15, 9, 16, 10, 17, 11, 18, 24, 19, 25, 20, 0, 21, 1, 22, 24, + 23, 24, 25, 7] + elif keypoint.shape[0] == 18: + pairs = [9, 8, 8, 2, 2, 3, 3, 4, 2, 0, 2, 1, 4, 5, + 5, 14, 14, 15, 4, 6, 6, 7, 7, 11, 11, 10, + 7, 13, 13, 12, 5, 16, 5, 17] + else: + raise ValueError("Keypoint shape not supported") + pairs = np.array(pairs).reshape(-1, 2) if pairs is not None else None + colors = [255., 0., 85., + 255., 0., 0., + 255., 85., 0., + 255., 170., 0., + 255., 255., 0., + 170., 255., 0., + 85., 255., 0., + 0., 255., 0., + 255., 0., 0., + 0., 255., 85., + 0., 255., 170., + 0., 255., 255., + 0., 170., 255., + 0., 85., 255., + 0., 0., 255., + 255., 0., 170., + 170., 0., 255., + 255., 0., 255., + 85., 0., 255., + 0., 0., 255., + 0., 0., 255., + 0., 0., 255., + 0., 255., 255., + 0., 255., 255., + 0., 255., 255., + 255., 225., 255.] 
+ colors = np.array(colors).reshape(-1, 3) + poseScales = [1] + + img_orig = img.copy() + width, height = img.shape[1], img.shape[2] + area = width * height + + lineType = 8 + shift = 0 + numberColors = len(colors) + thresholdRectangle = 0.1 + + animal_width, animal_height, animal_area = get_keypoints_rectangle(keypoint, thresholdRectangle) + if animal_area > 0: + ratioAreas = min(1, max(animal_width / width, animal_height / height)) + thicknessRatio = np.maximum(np.round(math.sqrt(area) * thicknessCircleRatioRight * ratioAreas), 2) + thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio)) + thicknessLine = np.maximum(1, np.round(thicknessRatio * thicknessLineRatioWRTCircle)) + radius = thicknessRatio / 2 + else: + return img + + img = np.ascontiguousarray(img.copy()) + if pairs is not None: + for i, pair in enumerate(pairs): + index1, index2 = pair + if keypoint[index1, -1] > threshold and keypoint[index2, -1] > threshold: + thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * poseScales[0])) + colorIndex = index2 + color = colors[colorIndex % numberColors] + keypoint1 = keypoint[index1, :-1].astype(np.int32) + keypoint2 = keypoint[index2, :-1].astype(np.int32) + cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()), + thicknessLineScaled, lineType, shift) + for part in range(len(keypoint)): + faceIndex = part + if keypoint[faceIndex, -1] > threshold: + radiusScaled = int(round(radius[faceIndex] * poseScales[0])) + thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * poseScales[0])) + colorIndex = part + color = colors[colorIndex % numberColors] + center = keypoint[faceIndex, :-1].astype(np.int32) + cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled, + lineType, shift) + + return img + + +class MeshRenderer: + + def __init__(self, cfg, faces=None): + self.cfg = cfg + self.img_res = 
def visualize(self, vertices, camera_translation, images, focal_length, nrow=3, padding=2):
    """
    Render every mesh of a batch from the front and from the side and tile the
    results into a single image grid.

    For each sample three tiles are appended, in this order: the input image,
    the front-view render and the side-view render.

    Args:
        vertices: Batch of mesh vertices, iterated along its first axis.
        camera_translation: Per-sample camera translations.
        images: Batch of images as a numpy array; axes are permuted with
            (0, 2, 3, 1) before rendering, so this is presumably (B, C, H, W).
        focal_length: Focal length forwarded to the renderer call.
        nrow (int): Forwarded to ``make_grid``.
        padding (int): Forwarded to ``make_grid``.
    Returns:
        The grid tensor produced by ``make_grid``.
    """
    images_hwc = np.transpose(images, (0, 2, 3, 1))

    def _render_chw(idx, side_view):
        # Render one sample and convert the channel-last result to a float tensor
        # with channels first.
        rendered = self.__call__(vertices[idx], camera_translation[idx], images_hwc[idx],
                                 focal_length=focal_length, side_view=side_view)
        return torch.from_numpy(np.transpose(rendered, (2, 0, 1))).float()

    tiles = []
    for idx in range(vertices.shape[0]):
        front_view = _render_chw(idx, side_view=False)
        side_view_img = _render_chw(idx, side_view=True)
        tiles.extend([torch.from_numpy(images[idx]), front_view, side_view_img])

    return make_grid(tiles, nrow=nrow, padding=padding)
def __call__(self, vertices, camera_translation, image, focal_length, text=None, resize=None, side_view=False,
             baseColorFactor=(1.0, 1.0, 0.9, 1.0), rot_angle=90):
    """
    Render a mesh with pyrender and, for the front view, composite it over ``image``.

    Args:
        vertices: Mesh vertices; combined with ``self.faces`` into a trimesh mesh.
        camera_translation: Camera translation (array with ``.copy()``); its first
            component is negated before it is used as the camera position.
        image: Background image, indexed as (H, W, ...); its height and width
            define the viewport. Presumably float values in [0, 1] - confirm with caller.
        focal_length: Focal length used for both fx and fy.
        text: Unused.
        resize: Optional size passed to ``cv2.resize`` for the final image.
        side_view (bool): If True, the mesh is first rotated by ``rot_angle``
            degrees about the y-axis and the render is returned without the
            background image.
        baseColorFactor: RGBA base colour of the mesh material.
        rot_angle: Side-view rotation in degrees.
    Returns:
        np.array: float32 image with three colour channels.
    """
    renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
                                          viewport_height=image.shape[0],
                                          point_size=1.0)
    # Release the offscreen renderer even if scene construction or rendering
    # raises; previously it was only deleted on the success path.
    try:
        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            baseColorFactor=baseColorFactor)

        camera_translation_local = camera_translation.copy()
        camera_translation_local[0] *= -1.

        mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
        if side_view:
            rot = trimesh.transformations.rotation_matrix(
                np.radians(rot_angle), [0, 1, 0])
            mesh.apply_transform(rot)
        # Flip the mesh 180 degrees about the x-axis before rendering.
        rot = trimesh.transformations.rotation_matrix(
            np.radians(180), [1, 0, 0])
        mesh.apply_transform(rot)
        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(mesh, 'mesh')

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_translation_local
        camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
        camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
                                           cx=camera_center[0], cy=camera_center[1],
                                           zfar=1000)
        scene.add(camera, pose=camera_pose)

        for node in create_raymond_lights():
            scene.add_node(node)

        color, _ = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
    finally:
        renderer.delete()

    color = color.astype(np.float32) / 255.0
    # Pixels with non-zero alpha belong to the mesh.
    valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
    if not side_view:
        output_img = (color[:, :, :3] * valid_mask +
                      (1 - valid_mask) * image)
    else:
        output_img = color[:, :, :3]
    if resize is not None:
        output_img = cv2.resize(output_img, resize)

    return output_img.astype(np.float32)
def task_wrapper(task_func: Callable) -> Callable:
    """Decorate a task function with setup, timing and cleanup steps.

    The returned callable:
    - runs `extras(cfg)` before the task,
    - logs any exception raised by the task and re-raises it,
    - always writes the elapsed time to `exec_time.log` in `cfg.paths.output_dir`,
    - always calls `close_loggers()`, so a failed run does not break a multirun,
    - logs the output directory and returns the task's result on success.
    """

    def wrap(cfg: DictConfig):
        started_at = time.time()
        try:
            # optional pre-task utilities
            extras(cfg)

            # run the wrapped task
            result = task_func(cfg=cfg)
        except Exception as ex:
            log.exception("")  # record the traceback in the `.log` file
            raise ex
        finally:
            # executed on success and on failure alike
            exec_time_file = Path(cfg.paths.output_dir, "exec_time.log")
            summary = f"'{cfg.task_name}' execution time: {time.time() - started_at} (s)"
            save_file(exec_time_file, summary)
            close_loggers()

        log.info(f"Output dir: {cfg.paths.output_dir}")

        return result

    return wrap
") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +@rank_zero_only +def save_file(path: str, content: str) -> None: + """Save file in rank zero mode (only on one process in multi-GPU setup).""" + with open(path, "w+") as file: + file.write(content) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("Callbacks config is empty.") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("Logger config is empty.") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! 
def get_metric_value(metric_dict: dict, metric_name: str) -> float:
    """Safely retrieves value of the metric logged in LightningModule.

    Args:
        metric_dict (dict): Mapping from metric name to an object exposing `.item()`.
        metric_name (str): Name of the metric to look up.

    Returns:
        The result of calling `.item()` on the stored metric, or `None` when
        `metric_name` is empty/None (the lookup is skipped in that case).

    Raises:
        Exception: If `metric_name` is not a key of `metric_dict`. The message
            names the missing metric.
    """

    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        # Name the missing metric so a typo in the config is easy to spot.
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.):
    """
    Convert camera parameters predicted for a bounding-box crop into a camera
    translation expressed for the full image.

    Args:
        cam_bbox: Tensor indexed as ``[:, 0]`` (scale), ``[:, 1]`` and ``[:, 2]``
            (x / y offsets) - the crop camera.
        box_center: Tensor whose columns 0 and 1 are the box centre x and y.
        box_size: Per-sample box size.
        img_size: Tensor whose columns 0 and 1 are the image width and height.
        focal_length: Focal length used to derive the depth component.
    Returns:
        torch.Tensor: ``(tx, ty, tz)`` stacked along the last dimension.
    """
    # Small epsilon keeps the divisions below finite when the scale is zero.
    scaled_box = box_size * cam_bbox[:, 0] + 1e-9

    half_width = img_size[:, 0] / 2.
    half_height = img_size[:, 1] / 2.

    shift_x = 2 * (box_center[:, 0] - half_width) / scaled_box
    shift_y = 2 * (box_center[:, 1] - half_height) / scaled_box
    depth = 2 * focal_length / scaled_box

    return torch.stack([shift_x + cam_bbox[:, 1], shift_y + cam_bbox[:, 2], depth], dim=-1)
def make_4x4_pose(R, t):
    """
    Assemble homogeneous pose matrices from rotations and translations.

    :param R (*, 3, 3) rotation part
    :param t (*, 3) translation part
    return (*, 4, 4) matrices laid out as [[R, t], [0, 0, 0, 1]]
    """
    batch_shape = R.shape[:-2]

    # Upper three rows: rotation with the translation appended as a fourth column.
    translation_column = t.view(*batch_shape, 3, 1)
    top_rows = torch.cat([R, translation_column], dim=-1)

    # Constant homogeneous row, broadcast over the leading batch dimensions.
    last_row = torch.tensor([0, 0, 0, 1], device=R.device)
    last_row = last_row.reshape(*([1] * len(batch_shape)), 1, 4)
    last_row = last_row.expand(*batch_shape, 1, 4)

    return torch.cat([top_rows, last_row], dim=-2)
def __call__(self,
             vertices: np.array,
             camera_translation: np.array,
             image: torch.Tensor,
             full_frame: bool = False,
             imgname: Optional[str] = None,
             side_view=False, rot_angle=90,
             mesh_base_color=(1.0, 1.0, 0.9),
             scene_bg_color=(0, 0, 0),
             return_rgba=False,
             depth = False,
             focal_length: Optional[float] = None,
             ) -> np.array:
    """
    Render meshes on input image
    Args:
        vertices (np.array): Array of shape (V, 3) containing the mesh vertices.
        camera_translation (np.array): Array of shape (3,) with the camera translation.
            Its x component is negated before it is used as the camera position.
        image (torch.Tensor): Tensor of shape (3, H, W) containing the image crop with normalized pixel values.
            Ignored when full_frame == True.
        full_frame (bool): If True, then render on the full image.
        imgname (Optional[str]): Contains the original image filename. Used only if full_frame == True.
        side_view (bool): If True, rotate the mesh by rot_angle degrees about the y-axis and
            return the render without the background image.
        rot_angle: Side-view rotation in degrees.
        mesh_base_color: RGB base colour of the mesh material.
        scene_bg_color: RGB background colour of the scene (alpha is set to 0).
        return_rgba (bool): If True, return the RGBA render scaled to [0, 1] instead of compositing.
        depth (bool): If True, return the rendered depth map; takes precedence over return_rgba.
        focal_length (Optional[float]): Custom focal length. If None, uses self.focal_length.
    Returns:
        np.array: Depending on the flags, the depth map, the RGBA render, or a float32 RGB image.
    """

    if full_frame:
        # Load the full image from disk as RGB floats in [0, 1].
        image = cv2.imread(imgname)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
    else:
        # Undo the mean/std normalisation and convert to a channel-last numpy image.
        image = (image.clone()) * (torch.tensor(self.cfg.MODEL.IMAGE_STD, device=image.device).reshape(3, 1, 1))
        image = image + torch.tensor(self.cfg.MODEL.IMAGE_MEAN, device=image.device).reshape(3, 1, 1)
        image = image.permute(1, 2, 0).cpu().numpy()

    # Use custom focal length if provided, otherwise use default
    focal_length_to_use = focal_length if focal_length is not None else self.focal_length

    renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
                                          viewport_height=image.shape[0],
                                          point_size=1.0)
    # Release the offscreen renderer even if scene construction or rendering
    # raises; previously it was only deleted on the success path.
    try:
        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            baseColorFactor=(*mesh_base_color, 1.0))

        camera_translation_local = camera_translation.copy()
        camera_translation_local[0] *= -1.

        mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
        if side_view:
            rot = trimesh.transformations.rotation_matrix(
                np.radians(rot_angle), [0, 1, 0])
            mesh.apply_transform(rot)
        # Flip the mesh 180 degrees about the x-axis before rendering.
        rot = trimesh.transformations.rotation_matrix(
            np.radians(180), [1, 0, 0])
        mesh.apply_transform(rot)
        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)

        scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(mesh, 'mesh')

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_translation_local
        camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
        camera = pyrender.IntrinsicsCamera(fx=focal_length_to_use, fy=focal_length_to_use,
                                           cx=camera_center[0], cy=camera_center[1], zfar=1e12)
        scene.add(camera, pose=camera_pose)

        for node in create_raymond_lights():
            scene.add_node(node)

        color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
    finally:
        renderer.delete()

    color = color.astype(np.float32) / 255.0

    if depth:
        return rend_depth

    if return_rgba:
        return color

    # Pixels with positive depth belong to the mesh.
    valid_mask = (rend_depth > 0).astype(np.float32)[:, :, np.newaxis]
    if not side_view:
        output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image)
    else:
        output_img = color[:, :, :3]

    return output_img.astype(np.float32)
def render_rgba(
        self,
        vertices: np.array,
        cam_t=None,
        rot=None,
        rot_axis=[1, 0, 0],
        rot_angle=0,
        camera_z=3,
        mesh_base_color=(1.0, 1.0, 0.9),
        scene_bg_color=(0, 0, 0),
        render_res=[256, 256],
        focal_length=None,
):
    """
    Render a single mesh to an RGBA image on a plain background.

    Args:
        vertices (np.array): Mesh vertices, forwarded to ``self.vertices_to_trimesh``.
        cam_t: Optional camera translation (array with ``.copy()``). Its x component
            is negated before use. If None, the camera is placed on the z-axis at
            ``camera_z * focal_length / render_res[1]``.
        rot: Unused.
        rot_axis: Rotation axis forwarded to ``self.vertices_to_trimesh``.
        rot_angle: Rotation angle forwarded to ``self.vertices_to_trimesh``.
        camera_z: Scale of the default camera distance (used only if cam_t is None).
        mesh_base_color: RGB colour forwarded to ``self.vertices_to_trimesh``.
        scene_bg_color: RGB background colour of the scene (alpha is set to 0).
        render_res: Viewport (width, height) in pixels.
        focal_length: Custom focal length. If None, uses self.focal_length.
    Returns:
        np.array: float32 RGBA image with values scaled to [0, 1].
    """

    renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0],
                                          viewport_height=render_res[1],
                                          point_size=1.0)
    # Release the offscreen renderer even if scene construction or rendering
    # raises; previously it was only deleted on the success path.
    try:
        focal_length = focal_length if focal_length is not None else self.focal_length

        if cam_t is not None:
            camera_translation = cam_t.copy()
            camera_translation[0] *= -1.
        else:
            camera_translation = np.array([0, 0, camera_z * focal_length / render_res[1]])

        # The mesh itself is not translated; the camera pose carries the offset.
        mesh = self.vertices_to_trimesh(vertices, np.array([0, 0, 0]), mesh_base_color, rot_axis, rot_angle,
                                        )
        mesh = pyrender.Mesh.from_trimesh(mesh)

        scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(mesh, 'mesh')

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_translation
        camera_center = [render_res[0] / 2., render_res[1] / 2.]
        camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
                                           cx=camera_center[0], cy=camera_center[1], zfar=1e12)

        # Create camera node and add it to pyRender scene
        camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
        scene.add_node(camera_node)
        self.add_point_lighting(scene, camera_node)
        self.add_lighting(scene, camera_node)

        for node in create_raymond_lights():
            scene.add_node(node)

        color, _ = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
    finally:
        renderer.delete()

    return color.astype(np.float32) / 255.0
def add_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
    """
    Add directional lights to ``scene``, positioned relative to the camera node.

    The light poses come from ``get_light_poses()`` plus one identity pose; each
    is composed with the camera node's pose in the scene. Nodes are named
    ``light-00``, ``light-01``, ...

    Args:
        scene: pyrender scene that receives the light nodes.
        cam_node: Camera node already present in ``scene``.
        color: Light colour.
        intensity: Light intensity.
    """
    cam_pose = scene.get_pose(cam_node)

    relative_poses = get_light_poses()
    relative_poses.append(np.eye(4))

    for idx, relative_pose in enumerate(relative_poses):
        light_node = pyrender.Node(
            name=f"light-{idx:02d}",
            light=pyrender.DirectionalLight(color=color, intensity=intensity),
            matrix=cam_pose @ relative_pose,
        )
        if not scene.has_node(light_node):
            scene.add_node(light_node)
+""" +PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation + +Official implementation of the paper: +"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" +by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis +Licensed under a modified MIT license +""" + +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning.utilities import rank_zero_only +from rich.prompt import Prompt + +from . import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "datamodule", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + queue.append(field) if field in cfg else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." 
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompt for tags on the command line when the config provides none.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. ``cfg.tags`` is set
            in place when it is missing or empty.
        save_to_file (bool, optional): Whether to write the tags to
            ``tags.log`` inside ``cfg.paths.output_dir``.

    Raises:
        ValueError: If no tags are configured and the Hydra job config has an
            ``id`` entry (the error message asks for tags before a multirun).
    """

    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        raw_tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip first, then filter: otherwise whitespace-only entries such as the
        # middle item of "a, ,b" or the tail of "a, " survive as empty-string tags.
        stripped = (part.strip() for part in raw_tags.split(","))
        tags = [tag for tag in stripped if tag]

        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
            rich.print(cfg.tags, file=file)
"s3ckpt_inference.ckpt" + +STAGE_ASSETS = { + "PRIMAS1": (STAGE1_CONFIG_ASSET_PATH, STAGE1_CHECKPOINT_ASSET_PATH, "s1ckpt_inference.ckpt"), + "PRIMAS3": (STAGE3_CONFIG_ASSET_PATH, STAGE3_CHECKPOINT_ASSET_PATH, "s3ckpt_inference.ckpt"), +} + +STAGE_CHECKPOINTS = { + "PRIMAS1": Path("PRIMAS1/checkpoints/s1ckpt_inference.ckpt"), + "PRIMAS3": Path("PRIMAS3/checkpoints/s3ckpt_inference.ckpt"), +} + +PathLike = Union[str, Path] + + +def _resolve_hf_repo_id(hf_repo_id: Optional[str]) -> str: + return hf_repo_id or os.environ.get("PRIMA_HF_REPO_ID", HF_REPO_ID) + + +def _default_checkpoint_path(data_dir: PathLike = "data") -> Path: + return Path(data_dir) / STAGE_CHECKPOINTS["PRIMAS1"] + + +def _config_path_for_checkpoint(checkpoint_path: PathLike) -> Path: + checkpoint_path = Path(checkpoint_path) + return checkpoint_path.parent.parent / ".hydra" / "config.yaml" + + +def _stage_for_checkpoint(checkpoint_path: PathLike) -> Optional[str]: + checkpoint_path = Path(checkpoint_path) + if len(checkpoint_path.parents) < 2: + return None + stage_name = checkpoint_path.parent.parent.name + stage_assets = STAGE_ASSETS.get(stage_name) + if stage_assets is None: + return None + _, _, checkpoint_name = stage_assets + if checkpoint_path.name != checkpoint_name: + return None + return stage_name + + +def _download_file( + hf_repo_id: str, + remote_filename: str, + destination: Path, + force_download: bool = False, +) -> None: + try: + from huggingface_hub import hf_hub_download + except ImportError: + raise ImportError( + "huggingface_hub is required to download PRIMA demo assets. " + "Install it with: pip install huggingface_hub\n" + "Or download the assets manually and pass a local checkpoint path." 
def _validate_torch_checkpoint(path: Path) -> None:
    """Raise ``RuntimeError`` if ``path`` is not a loadable torch checkpoint.

    Two checks run in order:

    1. If the file is a zip archive, every member is tested with
       ``ZipFile.testzip`` and a corrupt member is reported.
    2. The file is loaded with ``torch.load`` on CPU, using
       ``weights_only=True`` when this torch version accepts that argument.

    A weights-only load that fails solely because of an unsupported global or
    class is treated as a valid file and the function returns normally.

    Args:
        path: Checkpoint file to validate.

    Raises:
        RuntimeError: If the archive has a corrupt member or the file cannot
            be loaded.
    """
    import inspect
    import pickle
    import zipfile

    import torch

    invalid_prefix = f"Checkpoint file is invalid or incomplete: {path}\n"

    if zipfile.is_zipfile(path):
        with zipfile.ZipFile(path) as archive:
            bad_member = archive.testzip()
        if bad_member is not None:
            raise RuntimeError(
                invalid_prefix
                + f"Corrupt archive member: {bad_member}\n"
                + "Please redownload the checkpoint and try again."
            )

    weights_only_available = "weights_only" in inspect.signature(torch.load).parameters
    torch_load_kwargs = {"map_location": "cpu"}
    if weights_only_available:
        torch_load_kwargs["weights_only"] = True

    try:
        torch.load(path, **torch_load_kwargs)
    except Exception as err:
        if isinstance(err, pickle.UnpicklingError) and weights_only_available:
            text = str(err)
            rejected_type_only = "Weights only load failed" in text and (
                "Unsupported global" in text or "Unsupported class" in text
            )
            if rejected_type_only:
                # The file unpickled far enough to hit a type the weights-only
                # loader refuses; the original code accepts this as valid.
                return
        raise RuntimeError(
            invalid_prefix
            + "Downloaded checkpoint is not loadable. "
            + "Please verify the uploaded Hugging Face file and try again."
        ) from err
def _normalize_stages(stages: Union[str, Iterable[str]]) -> Sequence[str]:
    """Return ``stages`` as a tuple of stage names.

    A single string is wrapped in a one-element tuple rather than being split
    into characters; any other iterable is materialised with ``tuple``.
    """
    return (stages,) if isinstance(stages, str) else tuple(stages)
def _ensure_assets_for_checkpoint(
    checkpoint_path: PathLike,
    force: bool = False,
    hf_repo_id: Optional[str] = None,
) -> None:
    """Make sure the files needed to use ``checkpoint_path`` are present.

    For a checkpoint in one of the standard demo layouts (as recognised by
    ``_stage_for_checkpoint``) the SMAL assets, the backbone and that stage's
    config/checkpoint are ensured under the data directory three levels above
    the checkpoint file. For any other path nothing can be downloaded, so the
    checkpoint and its ``.hydra/config.yaml`` must already exist locally.

    Args:
        checkpoint_path: Path to the checkpoint that is about to be used.
        force: Forwarded to the download helpers for standard layouts. For a
            custom path it is ignored when the local files exist.
        hf_repo_id: Optional Hugging Face repo id override, resolved through
            ``_resolve_hf_repo_id``.

    Raises:
        FileNotFoundError: If a custom checkpoint or its config is missing.
    """
    checkpoint_path = Path(checkpoint_path)
    config_path = _config_path_for_checkpoint(checkpoint_path)
    stage_name = _stage_for_checkpoint(checkpoint_path)
    if stage_name is None:
        if checkpoint_path.exists() and config_path.exists():
            # Previously force=True fell through to the "Missing checkpoint or
            # config" error even though both files were present.
            if force:
                print(
                    "[warn] force=True ignored: custom checkpoints cannot be "
                    "re-downloaded automatically."
                )
            print(f"[skip] Using local PRIMA checkpoint {checkpoint_path}")
            return
        raise FileNotFoundError(
            "Missing checkpoint or config for a custom path:\n"
            f"  checkpoint: {checkpoint_path}\n"
            f"  config: {config_path}\n"
            "Auto-download supports the standard PRIMA demo layouts only:\n"
            "  data/PRIMAS1/checkpoints/s1ckpt_inference.ckpt\n"
            "  data/PRIMAS3/checkpoints/s3ckpt_inference.ckpt\n"
            "Pass one of those paths, or download/copy your custom checkpoint manually."
        )

    # <data_dir>/<stage>/checkpoints/<file>: three levels up is the data dir.
    data_dir = checkpoint_path.parent.parent.parent
    repo_id = _resolve_hf_repo_id(hf_repo_id)
    print(f"[download] Ensuring PRIMA demo assets under {data_dir}")
    _ensure_smal_assets(data_dir, force=force, hf_repo_id=repo_id)
    _ensure_backbone(data_dir, force=force, hf_repo_id=repo_id)
    _ensure_stage_assets(
        stage_name,
        data_dir,
        force=force,
        hf_repo_id=repo_id,
        validate_existing=False,
    )
b/pyproject.toml @@ -0,0 +1,91 @@ +[build-system] +requires = ["setuptools>=61", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "prima-animal" +version = "0.1.7" +description = "PRIMA: 3D animal pose and shape estimation" +readme = "README.md" +requires-python = ">=3.10" + +authors = [ + { name = "Xiaohang Yu", email = "xiaohang.yu@epfl.ch" }, + { name = "Ti Wang", email = "ti.wang@epfl.ch" }, + { name = "Mackenzie Weygandt Mathis", email = "mackenzie.mathis@epfl.ch" }, + +] + +keywords = ["3d", "animal", "pose", "shape", "vision", "pytorch"] + +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] + +dependencies = [ + # Core + "numpy==1.26.1", + "pandas==2.3.2", + + # Vision & geometry + "opencv-python==4.11.0.86", + "pyrender==0.1.45", + "open3d==0.19.0", + "trimesh==4.8.2", + "scikit-image==0.25.2", + "mmcv==1.3.9", + + # Model components + "smplx==0.1.28", + "yacs==0.1.8", + "timm==1.0.24", + "einops==0.8.1", + "xtcocotools==1.14.3", + "open_clip_torch==3.2.0", + "transformers==4.57.0", + + # Config & utilities + "omegaconf==2.3.0", + "hydra-core==1.3.2", + "hydra-submitit-launcher==1.2.0", + "hydra-colorlog==1.2.0", + "pyrootutils==1.0.4", + "rich==14.1.0", + + # IO / misc + "gdown==5.2.0", + # Match HF Space (Gradio 6.x) and local demo; do not pin 5.1 — Space injects gradio==6.x. 
+ "gradio>=5.1,<7", + "pydantic>=2.10,<3", + + # Training framework + "pytorch-lightning==2.5.5", + + # Build/runtime helpers needed for older mmcv installation flows + "setuptools<81", + "packaging<25", + "Cython<3", + "wheel", + + # Demo runtime dependencies (included in main PyPI install) + "detectron2 @ git+https://github.com/facebookresearch/detectron2.git", + "deeplabcut", +] + +[project.optional-dependencies] +all = [] + +[project.urls] +Homepage = "https://github.com/AdaptiveMotorControlLab/PRIMA" +Source = "https://github.com/AdaptiveMotorControlLab/PRIMA" +Issues = "https://github.com/AdaptiveMotorControlLab/PRIMA/issues" + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +where = ["."] +include = ["prima*", "chumpy*"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c15280034016c0380b97f742163b4b69d1b05078 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,33 @@ +huggingface_hub<1 +torch==2.2.1 +torchvision==0.17.1 +# Do not pin gradio here: Hugging Face Spaces runs +# pip install -r requirements.txt gradio[oauth,mcp]==6.x ... +# and a pinned older gradio causes ResolutionImpossible. Local installs get Gradio from +# scripts/clean_install_local.sh after this file, or from pyproject.toml for editable installs. +deeplabcut==3.0.0rc14 +pytorch-lightning==2.5.5 +yacs==0.1.8 +pyrender==0.1.45 +trimesh==4.8.2 +opencv-python==4.11.0.86 +timm==1.0.24 +einops==0.8.1 +smplx==0.1.28 +xtcocotools==1.14.3 +open_clip_torch==3.2.0 +transformers==4.56.2 +omegaconf==2.3.0 +hydra-core==1.3.2 +hydra-submitit-launcher==1.2.0 +hydra-colorlog==1.2.0 +pyrootutils==1.0.4 +rich==14.1.0 +scikit-image==0.25.2 +pandas==2.3.2 +numpy==1.26.1 +gdown==5.2.0 +setuptools<81 +packaging<25 +Cython<3 +wheel