# PRIMA-demo / app.py
# Uploaded by mwmathis via huggingface_hub (commit 18f34cb, verified).
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""
"""Gradio demo for PRIMA + SuperAnimal + TTA.
This script wraps the ``demo_tta.py`` pipeline into an interactive
Gradio interface. The overall logic follows:
1. Given an input image, run Detectron2 to detect animals.
2. For each detected animal, run PRIMA for 3D pose/shape estimation.
3. Run DeepLabCut SuperAnimal to obtain 2D keypoints.
4. Map SuperAnimal 39 keypoints to the 26 PRIMA keypoints.
5. Run test-time adaptation (TTA) with user-specified lr and iters.
6. Render and save before/after TTA results and keypoint visualizations.
"""
import argparse
import os
import sys
import tempfile
import traceback
from types import SimpleNamespace
from typing import List, Tuple
from pathlib import Path
import cv2
import gradio as gr
import numpy as np
import torch
import torch.utils.data
# Make the repository root importable so the bundled minimal ``chumpy`` shim
# (see ``chumpy/__init__.py``) resolves when SMAL pickles are loaded, without
# installing the full chumpy package in Space builds.
_REPO_ROOT = Path(__file__).resolve().parent
_repo_root_str = str(_REPO_ROOT)
if _repo_root_str not in sys.path:
    sys.path.insert(0, _repo_root_str)

# Default checkpoint path following README instructions
DEFAULT_CHECKPOINT = "data/PRIMAS1/checkpoints/s1ckpt.ckpt"
DEFAULT_HF_ASSET_REPO = "MLAdaptiveIntelligence/PRIMA"
# Output folder for rendered images/meshes and keypoints
DEFAULT_OUT_FOLDER = "demo_out_tta_gradio"
def _ensure_demo_assets(checkpoint_path: str) -> None:
    """Fetch SMAL, backbone, and stage-1 demo assets if missing locally.

    Downloads are no-ops when the files already exist (``force=False``); the
    HF repo can be overridden via the ``PRIMA_HF_REPO_ID`` environment variable.
    """
    from scripts.setup_demo_data import (
        maybe_download_backbone,
        maybe_download_smal,
        maybe_download_stage,
    )

    # data/ is two levels above the checkpoint file (data/PRIMAS1/checkpoints/x.ckpt).
    data_dir = Path(checkpoint_path).parents[2]
    repo_id = os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_ASSET_REPO)
    maybe_download_smal(data_dir, force=False, hf_repo_id=repo_id)
    maybe_download_backbone(data_dir, force=False, hf_repo_id=repo_id)
    maybe_download_stage(
        "PRIMAS1",
        "config_s1_HYDRA.yaml",
        "s1ckpt.ckpt",
        "s1ckpt.ckpt",
        data_dir,
        force=False,
        hf_repo_id=repo_id,
    )
def _load_prima_model(checkpoint_path: str = DEFAULT_CHECKPOINT):
    """Load the PRIMA model, its config, a renderer, and the device once.

    Attempts an asset download when the checkpoint or its Hydra config is
    absent, then fails loudly with actionable messages if either is still missing.
    """
    from prima.models import load_prima
    from prima.utils.renderer import Renderer

    ckpt = Path(checkpoint_path)
    cfg_file = ckpt.parent.parent / ".hydra" / "config.yaml"
    # Fresh environment (e.g. a new Space build): pull the demo assets first.
    if not (ckpt.exists() and cfg_file.exists()):
        _ensure_demo_assets(checkpoint_path)
    if not ckpt.exists():
        raise FileNotFoundError(
            f"Missing checkpoint: {ckpt}. Download demo checkpoints/data as described in README."
        )
    if not cfg_file.exists():
        raise FileNotFoundError(
            f"Missing model config: {cfg_file}. Ensure the full checkpoint folder layout from README is present."
        )
    model, model_cfg = load_prima(checkpoint_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    renderer = Renderer(model_cfg, faces=model.smal.faces)
    return model, model_cfg, renderer, device
def _build_detector():
"""Build Detectron2 animal detector (same config as demo_tta/demo.py)."""
try:
import detectron2.config
import detectron2.engine
from detectron2 import model_zoo
except Exception as e:
print(f"[warn] Detectron2 unavailable ({type(e).__name__}: {e}); using full-image fallback bbox.")
return None
cfg = detectron2.config.get_cfg()
cfg.merge_from_file(
model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml")
)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = (
"https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/"
"faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
)
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
detector = detectron2.engine.DefaultPredictor(cfg)
return detector
# SuperAnimal defaults (same as in demo_tta parser); this namespace is passed
# verbatim as the args object of run_superanimal_on_patch.
SUPER_ANIMAL_ARGS = SimpleNamespace(
    superanimal_name="superanimal_quadruped",
    superanimal_model_name="hrnet_w32",
    superanimal_detector_name="fasterrcnn_resnet50_fpn_v2",
    superanimal_max_individuals=1,  # one animal expected per PRIMA crop
)
def _collect_animal_results(
    model,
    model_cfg,
    renderer,
    device,
    detector,
    out_folder: str,
    img_rgb: np.ndarray,
    tta_lr: float,
    tta_num_iters: int,
    det_thresh: float,
    kp_conf_thresh: float,
    side_view: bool,
    save_mesh: bool,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], str | None, str | None]:
    """Run detection + PRIMA + SuperAnimal + TTA on a single RGB image.

    Rendered PNGs, keypoint .npy files, and (optionally) .obj meshes are
    written into ``out_folder`` as a side effect.

    Returns:
        before_imgs: list of HxWx3 RGB images (before TTA) for all animals
        after_imgs: list of HxWx3 RGB images (after TTA) for all animals
        kpt_imgs: list of HxWx3 RGB keypoint visualizations
        first_before_mesh: path to first animal's before-TTA mesh (.obj) or None
        first_after_mesh: path to first animal's after-TTA mesh (.obj) or None
    """
    from prima.utils import recursive_to
    from prima.datasets.vitdet_dataset import ViTDetDataset
    from demo_tta import (
        ANIMAL_COCO_IDS,
        denorm_patch_to_rgb,
        map_superanimal_to_prima,
        run_superanimal_on_patch,
        save_keypoint_vis,
        tta_optimize,
    )

    # Detect animals (Detectron2 expects BGR input)
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    if detector is None:
        # Fallback for environments where Detectron2 is unavailable: process full image as one crop.
        h, w = img_bgr.shape[:2]
        boxes = np.array([[0.0, 0.0, float(max(1, w - 1)), float(max(1, h - 1))]], dtype=np.float32)
    else:
        det_out = detector(img_bgr)
        det_instances = det_out["instances"]
        # Keep detections that are animal COCO classes above the score threshold.
        valid_idx = [
            i
            for i, (c, s) in enumerate(zip(det_instances.pred_classes, det_instances.scores))
            if (int(c) in ANIMAL_COCO_IDS) and (float(s) > float(det_thresh))
        ]
        if len(valid_idx) == 0:
            return [], [], [], None, None
        boxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy()

    dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    before_imgs: List[np.ndarray] = []
    after_imgs: List[np.ndarray] = []
    kpt_imgs: List[np.ndarray] = []
    before_mesh_paths: List[str] = []
    after_mesh_paths: List[str] = []
    # FIX: the original used the private tempfile._get_candidate_names() API,
    # which is not part of the stdlib contract and may break across Python
    # versions. A random hex token serves the same purpose here: a unique
    # per-image prefix for the output filenames.
    img_token = os.urandom(8).hex()
    for batch in dataloader:
        batch = recursive_to(batch, device)
        # Initial (pre-TTA) PRIMA prediction for this crop.
        with torch.no_grad():
            out_before = model(batch)
        animal_id = int(batch["animalid"][0])
        # Save/render before TTA
        img_fn = img_token
        from demo_tta import render_and_save  # imported lazily to avoid circular issues

        render_and_save(
            renderer,
            out_before,
            batch,
            img_fn,
            animal_id,
            out_folder,
            suffix="before_tta",
            side_view=side_view,
            save_mesh=save_mesh,
        )
        before_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.png")
        if os.path.exists(before_png_path):
            before_bgr = cv2.imread(before_png_path)
            if before_bgr is not None:
                before_imgs.append(cv2.cvtColor(before_bgr, cv2.COLOR_BGR2RGB))
        if save_mesh:
            before_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.obj")
            if os.path.exists(before_obj_path):
                before_mesh_paths.append(before_obj_path)

        # TTA disabled: reuse the initial prediction as the "after" result so the
        # UI always has both panels, then continue to the next detection.
        if int(tta_num_iters) <= 0:
            render_and_save(
                renderer,
                out_before,
                batch,
                img_fn,
                animal_id,
                out_folder,
                suffix="after_tta",
                side_view=side_view,
                save_mesh=save_mesh,
            )
            after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
            if os.path.exists(after_png_path):
                after_bgr = cv2.imread(after_png_path)
                if after_bgr is not None:
                    after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB))
            if save_mesh:
                after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
                if os.path.exists(after_obj_path):
                    after_mesh_paths.append(after_obj_path)
            continue

        # Prepare patch for SuperAnimal (denormalized RGB crop as fed to PRIMA)
        patch_rgb = denorm_patch_to_rgb(batch["img"][0])
        with tempfile.TemporaryDirectory(prefix=f"dlc_{img_fn}_{animal_id}_") as tmp_dir:
            bodyparts_xyc = run_superanimal_on_patch(patch_rgb, SUPER_ANIMAL_ARGS, tmp_dir)
        if bodyparts_xyc is None:
            # No keypoints => skip TTA for this animal
            continue
        mapped_xyc = map_superanimal_to_prima(bodyparts_xyc)
        # Zero-out confidence for keypoints below the user threshold so TTA ignores them.
        mapped_xyc[mapped_xyc[:, 2] < float(kp_conf_thresh), 2] = 0.0
        # Save keypoint visualization and npy
        kpt_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.png")
        save_keypoint_vis(patch_rgb, mapped_xyc, kpt_png_path)
        npy_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_prima26_kpts.npy")
        np.save(npy_path, mapped_xyc)
        if os.path.exists(kpt_png_path):
            kpt_bgr = cv2.imread(kpt_png_path)
            if kpt_bgr is not None:
                kpt_imgs.append(cv2.cvtColor(kpt_bgr, cv2.COLOR_BGR2RGB))

        # Normalize keypoints to [-0.5, 0.5] as in demo_tta
        patch_h, patch_w = patch_rgb.shape[:2]
        mapped_norm = mapped_xyc.copy()
        mapped_norm[:, 0] = mapped_norm[:, 0] / float(patch_w) - 0.5
        mapped_norm[:, 1] = mapped_norm[:, 1] / float(patch_h) - 0.5
        gt_kpts_norm = torch.from_numpy(mapped_norm[None]).to(device=device, dtype=batch["img"].dtype)

        # Run TTA
        out_after = tta_optimize(
            model,
            batch,
            gt_kpts_norm,
            num_iters=int(tta_num_iters),
            lr=float(tta_lr),
        )
        render_and_save(
            renderer,
            out_after,
            batch,
            img_fn,
            animal_id,
            out_folder,
            suffix="after_tta",
            side_view=side_view,
            save_mesh=save_mesh,
        )
        after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
        if os.path.exists(after_png_path):
            after_bgr = cv2.imread(after_png_path)
            if after_bgr is not None:
                after_imgs.append(cv2.cvtColor(after_bgr, cv2.COLOR_BGR2RGB))
        if save_mesh:
            after_obj_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.obj")
            if os.path.exists(after_obj_path):
                after_mesh_paths.append(after_obj_path)

    first_before_mesh = before_mesh_paths[0] if before_mesh_paths else None
    first_after_mesh = after_mesh_paths[0] if after_mesh_paths else None
    return before_imgs, after_imgs, kpt_imgs, first_before_mesh, first_after_mesh
def build_demo(checkpoint_path: str = DEFAULT_CHECKPOINT, out_folder: str = DEFAULT_OUT_FOLDER) -> gr.Interface:
    """Build the Gradio Interface; heavy models are loaded lazily on first request."""
    os.makedirs(out_folder, exist_ok=True)
    # Lazily-initialized singletons shared across requests; filled on the
    # first call to gradio_inference so app startup stays fast.
    runtime_cache = {
        "model": None,
        "model_cfg": None,
        "renderer": None,
        "device": None,
        "detector": None,
    }

    def gradio_inference(
        image: np.ndarray,
        tta_lr: float,
        tta_num_iters: int,
        det_thresh: float,
        kp_conf_thresh: float,
        side_view: bool,
        save_mesh: bool,
    ):
        """Wrapper for Gradio. ``image`` is an RGB numpy array."""
        if image is None:
            return None, None, None, "No image provided."
        # Gradio may deliver non-uint8 arrays; clamp and convert for OpenCV.
        if image.dtype != np.uint8:
            img_rgb = np.clip(image, 0, 255).astype(np.uint8)
        else:
            img_rgb = image
        # First request: load PRIMA + renderer + detector and cache them.
        if runtime_cache["model"] is None:
            try:
                model, model_cfg, renderer, device = _load_prima_model(checkpoint_path)
                detector = _build_detector()
            except Exception as e:
                # Full traceback goes to the status textbox for debuggability.
                return None, None, None, f"Model initialization failed:\n{traceback.format_exc()}"
            runtime_cache["model"] = model
            runtime_cache["model_cfg"] = model_cfg
            runtime_cache["renderer"] = renderer
            runtime_cache["device"] = device
            runtime_cache["detector"] = detector
        try:
            before_imgs, after_imgs, kpt_imgs, mesh_before, mesh_after = _collect_animal_results(
                runtime_cache["model"],
                runtime_cache["model_cfg"],
                runtime_cache["renderer"],
                runtime_cache["device"],
                runtime_cache["detector"],
                out_folder,
                img_rgb,
                tta_lr=tta_lr,
                tta_num_iters=tta_num_iters,
                det_thresh=det_thresh,
                kp_conf_thresh=kp_conf_thresh,
                side_view=side_view,
                save_mesh=save_mesh,
            )
        except Exception as e:
            return None, None, None, f"Inference failed:\n{traceback.format_exc()}"
        # The UI shows only the first detected animal; all renders are still on disk.
        first_before = before_imgs[0] if before_imgs else None
        first_after = after_imgs[0] if after_imgs else None
        first_kpts = kpt_imgs[0] if kpt_imgs else None
        if first_before is None and first_after is None:
            return (
                None,
                None,
                None,
                "No output generated. Try an image with a clearly visible quadruped.",
            )
        return first_before, first_after, first_kpts, "OK"

    # Input sliders mirror the demo_tta CLI defaults (lr=1e-6, 30 iters, etc.).
    return gr.Interface(
        fn=gradio_inference,
        analytics_enabled=False,
        cache_examples=False,
        inputs=[
            gr.Image(
                label="Input image",
                type="numpy",
                sources=["upload", "clipboard"],
            ),
            gr.Slider(
                label="TTA learning rate",
                minimum=1e-7,
                maximum=1e-4,
                value=1e-6,
                step=1e-7,
            ),
            gr.Slider(
                label="TTA iterations",
                minimum=0,
                maximum=100,
                value=30,
                step=1,
                info="Set to 0 to disable TTA and reuse the initial PRIMA prediction.",
            ),
            gr.Slider(
                label="Detection threshold",
                minimum=0.3,
                maximum=0.9,
                value=0.7,
                step=0.05,
            ),
            gr.Slider(
                label="Keypoint confidence threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.1,
                step=0.05,
            ),
            gr.Checkbox(label="Render side view", value=False),
            gr.Checkbox(label="Save meshes (.obj)", value=True),
        ],
        outputs=[
            gr.Image(label="Before TTA"),
            gr.Image(label="After TTA"),
            gr.Image(label="PRIMA 26 keypoints"),
            gr.Textbox(label="Status / Traceback", lines=12),
        ],
        title="PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation",
        description=(
            "Upload an animal image. The demo runs Detectron2 for animal detection, "
            "PRIMA for 3D pose/shape, DeepLabCut SuperAnimal for 2D keypoints, and "
            "test-time adaptation (TTA) with configurable learning rate and iterations. "
            "Set TTA iterations to 0 to disable adaptation.\n\n"
            "Results (PNG/OBJ and 26-keypoint visualizations) are saved under "
            f"'{out_folder}'."
        ),
        # Example rows match the input widget order: image, lr, iters,
        # det_thresh, kp_conf_thresh, side_view, save_mesh.
        examples=[
            [
                "demo_data/000000015956_horse.png",
                1e-6,
                30,
                0.7,
                0.1,
                False,
                True,
            ],
            [
                "demo_data/n02412080_12159.png",
                1e-6,
                30,
                0.7,
                0.1,
                False,
                True,
            ],
            [
                "demo_data/000000315905_zebra.jpg",
                1e-6,
                30,
                0.7,
                0.1,
                False,
                True,
            ],
            [
                "demo_data/beagle.jpg",
                1e-6,
                0,
                0.7,
                0.1,
                False,
                True,
            ],
            [
                "demo_data/shepherd_hati.jpg",
                1e-6,
                0,
                0.7,
                0.1,
                False,
                True,
            ],
        ],
    )
def parse_args() -> argparse.Namespace:
    """Parse CLI options for the standalone Gradio demo."""
    ap = argparse.ArgumentParser(description="Gradio demo for PRIMA + SuperAnimal + TTA")
    # (flag, default, help) triples — both options are plain string paths.
    string_options = (
        ("--checkpoint", DEFAULT_CHECKPOINT, "Path to the pretrained PRIMA checkpoint"),
        ("--out_folder", DEFAULT_OUT_FOLDER, "Folder used to save rendered outputs and meshes"),
    )
    for flag, default, help_text in string_options:
        ap.add_argument(flag, type=str, default=default, help=help_text)
    return ap.parse_args()
if __name__ == "__main__":
    # Standalone entry point: launch the Gradio app with CLI-configurable
    # checkpoint path and output folder.
    args = parse_args()
    demo = build_demo(checkpoint_path=args.checkpoint, out_folder=args.out_folder)
    demo.launch()