| """ |
| PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation |
| |
| Official implementation of the paper: |
| "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" |
| by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis |
| Licensed under a modified MIT license |
| """ |
|
|
| """Gradio demo for PRIMA + SuperAnimal + TTA. |
| |
| This script wraps the ``demo_tta.py`` pipeline into an interactive |
| Gradio interface. The overall logic follows: |
| |
| 1. Given an input image, run Detectron2 to detect animals. |
| 2. For each detected animal, run PRIMA for 3D pose/shape estimation. |
| 3. Run DeepLabCut SuperAnimal to obtain 2D keypoints. |
| 4. Map SuperAnimal 39 keypoints to the 26 PRIMA keypoints. |
| 5. Run test-time adaptation (TTA) with user-specified lr and iters. |
| 6. Render and save before/after TTA results and keypoint visualizations. |
| |
| """ |
|
|
| import argparse |
| import os |
| import sys |
| import tempfile |
| import traceback |
| from types import SimpleNamespace |
| from typing import List, Tuple |
| from pathlib import Path |
|
|
| import cv2 |
| import gradio as gr |
| import numpy as np |
| import torch |
| import torch.utils.data |
|
|
| |
| |
# Make repo-local packages (prima, scripts, demo_tta) importable even when this
# script is launched from a different working directory.
_REPO_ROOT = Path(__file__).resolve().parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
|
|
|
|
| |
# Default PRIMA stage-1 checkpoint path, relative to the repo root.
DEFAULT_CHECKPOINT = "data/PRIMAS1/checkpoints/s1ckpt.ckpt"
# Hugging Face repo hosting the demo assets; can be overridden at runtime via
# the PRIMA_HF_REPO_ID environment variable (see _ensure_demo_assets).
DEFAULT_HF_ASSET_REPO = "MLAdaptiveIntelligence/PRIMA"


# Folder where rendered PNG/OBJ outputs and keypoint files are written.
DEFAULT_OUT_FOLDER = "demo_out_tta_gradio"
|
|
|
|
def _ensure_demo_assets(checkpoint_path: str) -> None:
    """Download required demo assets when running in a clean environment."""
    from scripts.setup_demo_data import (
        maybe_download_smal,
        maybe_download_backbone,
        maybe_download_stage,
    )

    # The data directory sits two levels above the checkpoint file
    # (data/PRIMAS1/checkpoints/<ckpt>).
    data_root = Path(checkpoint_path).parents[2]
    repo_id = os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO)

    maybe_download_smal(data_root, force=False, hf_repo_id=repo_id)
    maybe_download_backbone(data_root, force=False, hf_repo_id=repo_id)
    maybe_download_stage(
        "PRIMAS1",
        "config_s1_HYDRA.yaml",
        "s1ckpt.ckpt",
        "s1ckpt.ckpt",
        data_root,
        force=False,
        hf_repo_id=repo_id,
    )
|
|
|
|
def _load_prima_model(checkpoint_path: str = DEFAULT_CHECKPOINT):
    """Load PRIMA model and renderer once for the Gradio app."""
    from prima.models import load_prima
    from prima.utils.renderer import Renderer

    ckpt = Path(checkpoint_path)
    config_file = ckpt.parents[1] / ".hydra" / "config.yaml"

    # On a fresh environment, try fetching the assets once before failing.
    if not (ckpt.exists() and config_file.exists()):
        _ensure_demo_assets(checkpoint_path)
    if not ckpt.exists():
        raise FileNotFoundError(
            f"Missing checkpoint: {ckpt}. Download demo checkpoints/data as described in README."
        )
    if not config_file.exists():
        raise FileNotFoundError(
            f"Missing model config: {config_file}. Ensure the full checkpoint folder layout from README is present."
        )

    model, model_cfg = load_prima(checkpoint_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    renderer = Renderer(model_cfg, faces=model.smal.faces)
    return model, model_cfg, renderer, device
|
|
|
|
| def _build_detector(): |
| """Build Detectron2 animal detector (same config as demo_tta/demo.py).""" |
| try: |
| import detectron2.config |
| import detectron2.engine |
| from detectron2 import model_zoo |
| except Exception as e: |
| print(f"[warn] Detectron2 unavailable ({type(e).__name__}: {e}); using full-image fallback bbox.") |
| return None |
|
|
| cfg = detectron2.config.get_cfg() |
| cfg.merge_from_file( |
| model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml") |
| ) |
| cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 |
| cfg.MODEL.WEIGHTS = ( |
| "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/" |
| "faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl" |
| ) |
| cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| detector = detectron2.engine.DefaultPredictor(cfg) |
| return detector |
|
|
| |
# Fixed DeepLabCut SuperAnimal configuration passed to run_superanimal_on_patch.
# max_individuals=1 because each detector crop is processed as its own patch.
SUPER_ANIMAL_ARGS = SimpleNamespace(
    superanimal_name="superanimal_quadruped",
    superanimal_model_name="hrnet_w32",
    superanimal_detector_name="fasterrcnn_resnet50_fpn_v2",
    superanimal_max_individuals=1,
)
|
|
|
|
def _read_rgb_image(path: str) -> np.ndarray | None:
    """Read the image at ``path`` as RGB, or return None if it is missing/unreadable."""
    if not os.path.exists(path):
        return None
    bgr = cv2.imread(path)
    if bgr is None:
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _collect_animal_results(
    model,
    model_cfg,
    renderer,
    device,
    detector,
    out_folder: str,
    img_rgb: np.ndarray,
    tta_lr: float,
    tta_num_iters: int,
    det_thresh: float,
    kp_conf_thresh: float,
    side_view: bool,
    save_mesh: bool,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], str | None, str | None]:
    """Run detection + PRIMA + SuperAnimal + TTA on a single RGB image.

    Returns:
        before_imgs: list of HxWx3 RGB images (before TTA) for all animals
        after_imgs: list of HxWx3 RGB images (after TTA) for all animals
        kpt_imgs: list of HxWx3 RGB keypoint visualizations
        first_before_mesh: path to first animal's before-TTA mesh (.obj) or None
        first_after_mesh: path to first animal's after-TTA mesh (.obj) or None
    """
    from prima.utils import recursive_to
    from prima.datasets.vitdet_dataset import ViTDetDataset
    # render_and_save is imported here once instead of inside the per-batch
    # loop as before (the import statement was re-executed every iteration).
    from demo_tta import (
        ANIMAL_COCO_IDS,
        denorm_patch_to_rgb,
        map_superanimal_to_prima,
        render_and_save,
        run_superanimal_on_patch,
        save_keypoint_vis,
        tta_optimize,
    )

    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    if detector is None:
        # No detector available: fall back to one bbox covering the whole image.
        h, w = img_bgr.shape[:2]
        boxes = np.array([[0.0, 0.0, float(max(1, w - 1)), float(max(1, h - 1))]], dtype=np.float32)
    else:
        det_out = detector(img_bgr)
        det_instances = det_out["instances"]

        # Keep only animal COCO classes above the user detection threshold.
        valid_idx = [
            i
            for i, (c, s) in enumerate(zip(det_instances.pred_classes, det_instances.scores))
            if (int(c) in ANIMAL_COCO_IDS) and (float(s) > float(det_thresh))
        ]
        if len(valid_idx) == 0:
            return [], [], [], None, None

        boxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy()

    dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    before_imgs: List[np.ndarray] = []
    after_imgs: List[np.ndarray] = []
    kpt_imgs: List[np.ndarray] = []
    before_mesh_paths: List[str] = []
    after_mesh_paths: List[str] = []

    # Random token so concurrent requests do not clobber each other's output
    # files. (Replaces tempfile._get_candidate_names(), a private CPython API
    # with no stability guarantee.)
    img_fn = os.urandom(8).hex()

    for batch in dataloader:
        batch = recursive_to(batch, device)

        # Initial PRIMA prediction (no adaptation).
        with torch.no_grad():
            out_before = model(batch)

        animal_id = int(batch["animalid"][0])

        render_and_save(
            renderer,
            out_before,
            batch,
            img_fn,
            animal_id,
            out_folder,
            suffix="before_tta",
            side_view=side_view,
            save_mesh=save_mesh,
        )

        before_rgb = _read_rgb_image(os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.png"))
        if before_rgb is not None:
            before_imgs.append(before_rgb)

        if save_mesh:
            before_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.obj")
            if os.path.exists(before_obj_path):
                before_mesh_paths.append(before_obj_path)

        if int(tta_num_iters) <= 0:
            # TTA disabled: reuse the initial prediction as the "after" result.
            render_and_save(
                renderer,
                out_before,
                batch,
                img_fn,
                animal_id,
                out_folder,
                suffix="after_tta",
                side_view=side_view,
                save_mesh=save_mesh,
            )

            after_rgb = _read_rgb_image(os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png"))
            if after_rgb is not None:
                after_imgs.append(after_rgb)

            if save_mesh:
                after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
                if os.path.exists(after_obj_path):
                    after_mesh_paths.append(after_obj_path)
            continue

        # 2D keypoints from SuperAnimal, computed on the model's input patch.
        patch_rgb = denorm_patch_to_rgb(batch["img"][0])
        with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
            bodyparts_xyc = run_superanimal_on_patch(patch_rgb, SUPER_ANIMAL_ARGS, tmp_dir)

        if bodyparts_xyc is None:
            # No keypoints detected -> cannot run TTA for this animal; skip it.
            continue

        mapped_xyc = map_superanimal_to_prima(bodyparts_xyc)
        # Zero-out low-confidence keypoints so the TTA loss ignores them.
        mapped_xyc[mapped_xyc[:, 2] < float(kp_conf_thresh), 2] = 0.0

        # Save the 26-keypoint visualization and raw keypoints for inspection.
        kpt_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png")
        save_keypoint_vis(patch_rgb, mapped_xyc, kpt_png_path)
        npy_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy")
        np.save(npy_path, mapped_xyc)

        kpt_rgb = _read_rgb_image(kpt_png_path)
        if kpt_rgb is not None:
            kpt_imgs.append(kpt_rgb)

        # Normalize pixel coordinates to the [-0.5, 0.5] patch frame used by TTA.
        patch_h, patch_w = patch_rgb.shape[:2]
        mapped_norm = mapped_xyc.copy()
        mapped_norm[:, 0] = mapped_norm[:, 0] / float(patch_w) - 0.5
        mapped_norm[:, 1] = mapped_norm[:, 1] / float(patch_h) - 0.5
        gt_kpts_norm = torch.from_numpy(mapped_norm[None]).to(device=device, dtype=batch["img"].dtype)

        # Test-time adaptation against the SuperAnimal keypoints.
        out_after = tta_optimize(
            model,
            batch,
            gt_kpts_norm,
            num_iters=int(tta_num_iters),
            lr=float(tta_lr),
        )

        render_and_save(
            renderer,
            out_after,
            batch,
            img_fn,
            animal_id,
            out_folder,
            suffix="after_tta",
            side_view=side_view,
            save_mesh=save_mesh,
        )

        after_rgb = _read_rgb_image(os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png"))
        if after_rgb is not None:
            after_imgs.append(after_rgb)

        if save_mesh:
            after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
            if os.path.exists(after_obj_path):
                after_mesh_paths.append(after_obj_path)

    first_before_mesh = before_mesh_paths[0] if before_mesh_paths else None
    first_after_mesh = after_mesh_paths[0] if after_mesh_paths else None

    return before_imgs, after_imgs, kpt_imgs, first_before_mesh, first_after_mesh
|
|
|
|
def build_demo(checkpoint_path: str = DEFAULT_CHECKPOINT, out_folder: str = DEFAULT_OUT_FOLDER) -> gr.Interface:
    """Construct the Gradio interface wrapping the PRIMA + SuperAnimal + TTA pipeline.

    Args:
        checkpoint_path: Path to the pretrained PRIMA checkpoint.
        out_folder: Folder where rendered PNG/OBJ outputs are written.

    Returns:
        A configured ``gr.Interface``; call ``.launch()`` on it to serve.
    """
    os.makedirs(out_folder, exist_ok=True)
    # Heavy objects (model, renderer, detector, ...) are loaded lazily on the
    # first request and cached here for the lifetime of the app.
    runtime_cache = {
        "model": None,
        "model_cfg": None,
        "renderer": None,
        "device": None,
        "detector": None,
    }

    def gradio_inference(
        image: np.ndarray,
        tta_lr: float,
        tta_num_iters: int,
        det_thresh: float,
        kp_conf_thresh: float,
        side_view: bool,
        save_mesh: bool,
    ):
        """Wrapper for Gradio. ``image`` is an RGB numpy array."""

        if image is None:
            return None, None, None, "No image provided."

        # Gradio may hand over non-uint8 arrays; normalize to uint8 RGB.
        if image.dtype != np.uint8:
            img_rgb = np.clip(image, 0, 255).astype(np.uint8)
        else:
            img_rgb = image

        # Lazy one-time initialization of the model and detector.
        if runtime_cache["model"] is None:
            try:
                model, model_cfg, renderer, device = _load_prima_model(checkpoint_path)
                detector = _build_detector()
            except Exception:
                # Surface the full traceback in the UI status box.
                return None, None, None, f"Model initialization failed:\n{traceback.format_exc()}"
            runtime_cache["model"] = model
            runtime_cache["model_cfg"] = model_cfg
            runtime_cache["renderer"] = renderer
            runtime_cache["device"] = device
            runtime_cache["detector"] = detector

        try:
            # Mesh paths are returned for the first animal but not shown in
            # this UI; the files remain on disk under out_folder.
            before_imgs, after_imgs, kpt_imgs, _mesh_before, _mesh_after = _collect_animal_results(
                runtime_cache["model"],
                runtime_cache["model_cfg"],
                runtime_cache["renderer"],
                runtime_cache["device"],
                runtime_cache["detector"],
                out_folder,
                img_rgb,
                tta_lr=tta_lr,
                tta_num_iters=tta_num_iters,
                det_thresh=det_thresh,
                kp_conf_thresh=kp_conf_thresh,
                side_view=side_view,
                save_mesh=save_mesh,
            )
        except Exception:
            return None, None, None, f"Inference failed:\n{traceback.format_exc()}"

        # Only the first detected animal is displayed; all animals' results
        # are still saved to out_folder.
        first_before = before_imgs[0] if before_imgs else None
        first_after = after_imgs[0] if after_imgs else None
        first_kpts = kpt_imgs[0] if kpt_imgs else None
        if first_before is None and first_after is None:
            return (
                None,
                None,
                None,
                "No output generated. Try an image with a clearly visible quadruped.",
            )
        return first_before, first_after, first_kpts, "OK"

    return gr.Interface(
        fn=gradio_inference,
        analytics_enabled=False,
        cache_examples=False,
        inputs=[
            gr.Image(
                label="Input image",
                type="numpy",
                sources=["upload", "clipboard"],
            ),
            gr.Slider(
                label="TTA learning rate",
                minimum=1e-7,
                maximum=1e-4,
                value=1e-6,
                step=1e-7,
            ),
            gr.Slider(
                label="TTA iterations",
                minimum=0,
                maximum=100,
                value=30,
                step=1,
                info="Set to 0 to disable TTA and reuse the initial PRIMA prediction.",
            ),
            gr.Slider(
                label="Detection threshold",
                minimum=0.3,
                maximum=0.9,
                value=0.7,
                step=0.05,
            ),
            gr.Slider(
                label="Keypoint confidence threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.1,
                step=0.05,
            ),
            gr.Checkbox(label="Render side view", value=False),
            gr.Checkbox(label="Save meshes (.obj)", value=True),
        ],
        outputs=[
            gr.Image(label="Before TTA"),
            gr.Image(label="After TTA"),
            gr.Image(label="PRIMA 26 keypoints"),
            gr.Textbox(label="Status / Traceback", lines=12),
        ],
        title="PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation",
        description=(
            "Upload an animal image. The demo runs Detectron2 for animal detection, "
            "PRIMA for 3D pose/shape, DeepLabCut SuperAnimal for 2D keypoints, and "
            "test-time adaptation (TTA) with configurable learning rate and iterations. "
            "Set TTA iterations to 0 to disable adaptation.\n\n"
            "Results (PNG/OBJ and 26-keypoint visualizations) are saved under "
            f"'{out_folder}'."
        ),
        examples=[
            [
                "demo_data/000000015956_horse.png",
                1e-6,
                30,
                0.7,
                0.1,
                False,
                True,
            ],
            [
                "demo_data/n02412080_12159.png",
                1e-6,
                30,
                0.7,
                0.1,
                False,
                True,
            ],
            [
                "demo_data/000000315905_zebra.jpg",
                1e-6,
                30,
                0.7,
                0.1,
                False,
                True,
            ],
            [
                "demo_data/beagle.jpg",
                1e-6,
                0,
                0.7,
                0.1,
                False,
                True,
            ],
            [
                "demo_data/shepherd_hati.jpg",
                1e-6,
                0,
                0.7,
                0.1,
                False,
                True,
            ],
        ],
    )
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the Gradio demo launcher."""
    parser = argparse.ArgumentParser(description="Gradio demo for PRIMA + SuperAnimal + TTA")
    # Both options are plain string paths; register them data-driven.
    for flag, default, help_text in (
        ("--checkpoint", DEFAULT_CHECKPOINT, "Path to the pretrained PRIMA checkpoint"),
        ("--out_folder", DEFAULT_OUT_FOLDER, "Folder used to save rendered outputs and meshes"),
    ):
        parser.add_argument(flag, type=str, default=default, help=help_text)
    return parser.parse_args()
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse options, build the Gradio app, and serve it.
    args = parse_args()
    demo = build_demo(checkpoint_path=args.checkpoint, out_folder=args.out_folder)
    demo.launch()
|
|