Spaces:
Configuration error
Configuration error
| """ | |
| PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation | |
| Official implementation of the paper: | |
| "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" | |
| by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis | |
| Licensed under a modified MIT license | |
| """ | |
| from pathlib import Path | |
| import detectron2.config | |
| import detectron2.engine | |
| import torch | |
| import argparse | |
| import os | |
| import cv2 | |
| import numpy as np | |
| from tqdm import tqdm | |
| import torch.utils | |
| import torch.utils.data | |
| from prima.models import load_prima | |
| from prima.utils import recursive_to | |
| from prima.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD | |
| from prima.utils.detection import select_animal_boxes | |
| from prima.utils.weights import DEFAULT_HF_REPO_ID, resolve_prima_checkpoint_path | |
| import detectron2 | |
| from detectron2 import model_zoo | |
| import warnings | |
# Silence every warning category for the whole process so the demo output
# only contains the script's own [INFO]/[WARN]/[done] messages.
warnings.filterwarnings("ignore")

# RGB colours with components in [0, 1], passed as mesh base colours.
LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353)  # exported .obj meshes
GREEN = (0.65, 0.86, 0.74)  # rendered overlays

# Directory that contains this script; the "data" folder is resolved from it.
REPO_ROOT = Path(__file__).resolve().parent
def load_renderer_components():
    """Import the rendering helpers lazily and return them.

    The import happens at call time rather than at module import time, so a
    failure while importing ``prima.utils.renderer`` is reported with an
    actionable message instead of a raw import traceback.

    Returns:
        A ``(Renderer, cam_crop_to_full)`` tuple taken from
        ``prima.utils.renderer``.

    Raises:
        RuntimeError: If importing the renderer module fails for any reason.
            The original exception is attached as ``__cause__``.
    """
    failure_message = (
        "Cannot initialize the PRIMA renderer. Rendering requires a working "
        "pyrender/OpenGL backend such as EGL or OSMesa. Install the missing "
        "OpenGL runtime for this environment, or run in an environment where "
        "PYOPENGL_PLATFORM=egl/osmesa works."
    )
    try:
        from prima.utils.renderer import Renderer, cam_crop_to_full
    except Exception as import_error:
        # Any failure here (missing package, broken GL backend, ...) is
        # surfaced uniformly; chaining keeps the root cause visible.
        raise RuntimeError(failure_message) from import_error
    else:
        return Renderer, cam_crop_to_full
def _build_arg_parser():
    """Create the command-line parser for the demo script."""
    parser = argparse.ArgumentParser(description='prima demo code')
    parser.add_argument('--checkpoint', type=str, default='',
                        help='Path to pretrained model checkpoint. Empty -> auto-download the default Stage 1 checkpoint.')
    parser.add_argument('--hf-repo-id', '--hf_repo_id', dest='hf_repo_id',
                        type=str, default=os.environ.get("PRIMA_HF_REPO_ID", DEFAULT_HF_REPO_ID),
                        help='Hugging Face repo ID containing PRIMA demo assets')
    parser.add_argument('--no-auto-download', '--no_auto_download', dest='no_auto_download', action='store_true',
                        help='Disable automatic download of missing PRIMA demo assets')
    parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images')
    parser.add_argument('--out_folder', type=str, default='demo_out', help='Output folder to save rendered results')
    parser.add_argument('--side_view', dest='side_view', action='store_true', default=False,
                        help='If set, render side view also')
    parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False,
                        help='If set, save meshes to disk also')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size for inference/fitting')
    parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'],
                        help='List of file extensions to consider')
    return parser


def _build_detector(device):
    """Build the detectron2 Faster R-CNN X-101 predictor that runs on *device*."""
    cfg = detectron2.config.get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml"))
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
    cfg.MODEL.DEVICE = device.type
    return detectron2.engine.DefaultPredictor(cfg)


def _collect_image_paths(img_folder, patterns):
    """Return the sorted, de-duplicated paths in *img_folder* matching any glob pattern."""
    folder = Path(img_folder)
    # A set removes files matched by more than one pattern (e.g. '*.jpeg' and
    # '*.JPEG' where glob matching is case-insensitive), which would otherwise
    # be processed and counted twice.
    return sorted({path for pattern in patterns for path in folder.glob(pattern)})


def _to_numpy(tensor):
    """Detach *tensor* and return it as a NumPy array on the CPU."""
    return tensor.detach().cpu().numpy()


def main():
    """Run the PRIMA demo: detect animals, regress meshes, and save renderings.

    Reads its options from the command line. For every readable image in
    ``--img_folder`` that matches ``--file_type`` it runs the detectron2
    detector, feeds the selected animal boxes through the PRIMA model, and
    writes ``<image stem>_<animal id>.png`` (input crop next to the rendered
    mesh, plus a side view with ``--side_view``) to ``--out_folder``. With
    ``--save_mesh`` a matching ``.obj`` file is exported as well. A summary
    is printed at the end.

    Raises:
        RuntimeError: If the renderer cannot be imported (raised by
            ``load_renderer_components``).
    """
    args = _build_arg_parser().parse_args()

    checkpoint_path = resolve_prima_checkpoint_path(
        args.checkpoint,
        data_dir=REPO_ROOT / "data",
        auto_download=not args.no_auto_download,
        hf_repo_id=args.hf_repo_id,
    )
    model, model_cfg = load_prima(checkpoint_path)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(device)
    model.eval()

    # Setup the renderer
    Renderer, cam_crop_to_full = load_renderer_components()
    renderer = Renderer(model_cfg, faces=model.smal.faces)

    # Make output directory if it does not exist
    os.makedirs(args.out_folder, exist_ok=True)

    # Load detector
    detector = _build_detector(device)

    img_paths = _collect_image_paths(args.img_folder, args.file_type)
    if not img_paths:
        print(f"[WARN] No images matching {args.file_type} found in {args.img_folder}")

    num_readable_images = 0
    num_rendered_results = 0
    num_suppressed_detections = 0

    for img_path in img_paths:
        img_bgr = cv2.imread(str(img_path))
        if img_bgr is None:
            print(f"[WARN] Cannot read image: {img_path}")
            continue
        num_readable_images += 1

        # Detect animals in image
        det_out = detector(img_bgr)
        det_instances = det_out['instances']
        boxes, suppressed = select_animal_boxes(det_instances, score_threshold=0.7)
        num_suppressed_detections += suppressed
        if suppressed > 0:
            print(f"[INFO] Suppressed {suppressed} duplicate animal detection(s) in {img_path}")
        if len(boxes) == 0:
            print(f"[INFO] No animal detected in {img_path}")
            continue

        # The output file stem depends only on the image, not on the detection.
        img_fn, _ = os.path.splitext(os.path.basename(img_path))

        # Run PRIMA on detected animals
        dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

        for batch in tqdm(dataloader):
            batch = recursive_to(batch, device)
            with torch.no_grad():
                out = model(batch)

            pred_cam = out['pred_cam']
            box_center = batch["box_center"].float()
            box_size = batch["box_size"].float()
            img_size = batch["img_size"].float()
            scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max()
            pred_cam_t_full = _to_numpy(
                cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length)
            )

            # Render the result
            batch_size = batch['img'].shape[0]
            for n in range(batch_size):
                animal_id = int(batch['animalid'][n])

                # Undo the dataset normalisation to recover the input crop.
                input_patch = (batch['img'][n].cpu() * (DEFAULT_STD[:, None, None]) + (
                    DEFAULT_MEAN[:, None, None])) / 255.
                input_patch = input_patch.permute(1, 2, 0).numpy()

                # NOTE(review): arrays are converted afresh for every renderer
                # call, as in the original; the renderer is not visible here
                # and may modify its inputs in place, so they are not shared.
                regression_img = renderer(_to_numpy(out['pred_vertices'][n]),
                                          _to_numpy(out['pred_cam_t'][n]),
                                          batch['img'][n],
                                          mesh_base_color=GREEN,
                                          scene_bg_color=(1, 1, 1),
                                          )
                final_img = np.concatenate([input_patch, regression_img], axis=1)

                if args.side_view:
                    # A white image expressed in the dataset's normalised
                    # space; only needed as the side-view background.
                    white_img = (torch.ones_like(batch['img'][n]).cpu() - DEFAULT_MEAN[:, None, None] / 255) / (
                        DEFAULT_STD[:, None, None] / 255)
                    side_img = renderer(_to_numpy(out['pred_vertices'][n]),
                                        _to_numpy(out['pred_cam_t'][n]),
                                        white_img,
                                        mesh_base_color=GREEN,
                                        scene_bg_color=(1, 1, 1),
                                        side_view=True)
                    final_img = np.concatenate([final_img, side_img], axis=1)

                out_png = os.path.join(args.out_folder, f'{img_fn}_{animal_id}.png')
                # Clip before the uint8 cast so slightly out-of-range values
                # saturate instead of wrapping around.
                img_uint8 = np.clip(255 * final_img, 0, 255).astype(np.uint8)
                # cv2.imwrite reports failure through its return value; only
                # count results that were actually written.
                if cv2.imwrite(out_png, cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR)):
                    num_rendered_results += 1
                else:
                    print(f"[WARN] Failed to write rendered result: {out_png}")

                # Save all meshes to disk
                if args.save_mesh:
                    verts = _to_numpy(out['pred_vertices'][n])
                    camera_translation = pred_cam_t_full[n].copy()
                    tmesh = renderer.vertices_to_trimesh(verts, camera_translation, LIGHT_BLUE)
                    tmesh.export(os.path.join(args.out_folder, f'{img_fn}_{animal_id}.obj'))

    print(
        f"[done] Demo complete. Processed {num_readable_images}/{len(img_paths)} image(s), "
        f"saved {num_rendered_results} rendered result(s) to {args.out_folder}."
    )
    if num_suppressed_detections > 0:
        print(f"[done] Suppressed {num_suppressed_detections} duplicate animal detection(s).")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()