"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation

Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis

Licensed under a modified MIT license

Evaluation script: loads a PRIMA checkpoint and reports PA-MPJPE / PA-MPVPE (3D)
and AUC / PCK (2D) metrics on one or all configured evaluation datasets.
"""
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from prima.configs import get_config
from prima.datasets.datasets import EvaluationDataset
from prima.models.prima import PRIMA
from prima.utils import recursive_to
from prima.utils.evaluate_metric import Evaluator

torch.multiprocessing.set_sharing_strategy('file_system')

# Dataset keys for which 3D metrics (PA-MPJPE / PA-MPVPE) are computed via
# Evaluator.eval_3d; every other dataset is evaluated with 2D metrics only.
_DATASETS_WITH_3D_METRICS = frozenset({"ANIMAL3D", "CONTROL_ANIMAL3D"})


def main(args):
    """Evaluate a PRIMA checkpoint on one or all configured evaluation datasets.

    Args:
        args: Parsed command-line namespace providing ``config``, ``checkpoint``,
            ``default_eval_config``, ``dataset`` (a dataset key, matched
            case-insensitively, or ``"ALL"``) and ``device``.

    Raises:
        KeyError: If ``args.dataset`` is neither ``"ALL"`` nor a configured dataset.
    """
    cfg = get_config(args.config)
    default_cfg = get_config(args.default_eval_config)

    model = PRIMA.load_from_checkpoint(args.checkpoint, cfg=cfg, strict=False)
    # Batches are moved to ``args.device`` in eval_one_dataset, so the model must
    # live there too; otherwise --device has no effect on the model and a
    # device-mismatch error occurs when it differs from the checkpoint's device.
    model = model.to(args.device)
    model.eval()
    smal_evaluator = Evaluator(smal_model=model.smal, image_size=cfg.MODEL.IMAGE_SIZE)

    cfg_eval_dataset = dict(default_cfg.DATASETS)
    # "CONFIG" holds the shared augmentation settings, not a dataset: pop it so it
    # is not iterated as a dataset, then forward it to every EvaluationDataset.
    aug_cfg = cfg_eval_dataset.pop("CONFIG", None)

    if args.dataset.upper() == "ALL":
        keys = list(cfg_eval_dataset)
    else:
        keys = [_resolve_dataset_key(args.dataset, cfg_eval_dataset)]

    for key in keys:
        print(f"-------- Evaluate {key} dataset ------------")
        eval_one_dataset(cfg_eval_dataset[key], default_cfg, cfg, model,
                         evaluator=smal_evaluator, aug_cfg=aug_cfg, key=key,
                         device=args.device)
        print(f"-------{key} Dataset evaluate finish ------")


def _resolve_dataset_key(name, available):
    """Return the key of ``available`` matching ``name``.

    An exact match wins; otherwise a unique case-insensitive match is accepted.

    Raises:
        KeyError: If no key (or more than one key, case-insensitively) matches.
    """
    if name in available:
        return name
    matches = [k for k in available if str(k).upper() == name.upper()]
    if len(matches) == 1:
        return matches[0]
    raise KeyError(
        f"Unknown evaluation dataset {name!r}; available datasets: {list(available)}"
    )


def eval_one_dataset(dataset_cfg, default_cfg, cfg, model, evaluator, aug_cfg, key,
                     device='cuda'):
    """Run ``model`` over one evaluation dataset and print averaged metrics.

    Args:
        dataset_cfg: Dataset config entry; reads ``ROOT_IMAGE`` and ``JSON_FILE['TEST']``.
        default_cfg: Evaluation config; reads ``METRIC.PCK_THRESHOLD``.
        cfg: Model config; reads ``SMAL.FOCAL_LENGTH`` (default 1000),
            ``MODEL.IMAGE_SIZE`` and ``GENERAL.NUM_WORKERS``.
        model: PRIMA model, expected to already be on ``device`` and in eval mode.
        evaluator: Evaluator providing ``eval_3d`` and ``eval_2d``.
        aug_cfg: Augmentation config forwarded to ``EvaluationDataset``.
        key: Dataset name; 3D metrics are computed only for keys in
            ``_DATASETS_WITH_3D_METRICS``.
        device: Device the batches are moved to.
    """
    dataset = EvaluationDataset(
        root_image=dataset_cfg['ROOT_IMAGE'],
        json_file=dataset_cfg['JSON_FILE']['TEST'],
        augm_config=aug_cfg,
        focal_length=cfg.SMAL.get("FOCAL_LENGTH", 1000),
        image_size=cfg.MODEL.IMAGE_SIZE,
    )
    dataloader = DataLoader(dataset, batch_size=1, num_workers=cfg.GENERAL.NUM_WORKERS)
    pck_thresholds = default_cfg.METRIC.PCK_THRESHOLD
    compute_3d = key in _DATASETS_WITH_3D_METRICS

    pa_mpjpe_list, pa_mpvpe_list, auc_list, pck_list = [], [], [], []
    bar = tqdm(dataloader)
    for batch in bar:
        batch = recursive_to(batch, device)
        with torch.no_grad():
            output = model(batch)

        postfix = {}
        if compute_3d:
            pa_mpjpe, pa_mpvpe = evaluator.eval_3d(output, batch)
            pa_mpjpe_list.append(pa_mpjpe)
            pa_mpvpe_list.append(pa_mpvpe)
            postfix.update(PA_MPJPE=pa_mpjpe, PA_MPVPE=pa_mpvpe)
        pck, auc = evaluator.eval_2d(output, batch, pck_threshold=pck_thresholds)
        auc_list.append(auc)
        pck_list.append(pck)
        postfix.update(AUC=auc, pck=pck)
        bar.set_postfix(**postfix)

    _print_metrics(pa_mpjpe_list, pa_mpvpe_list, auc_list, pck_list, pck_thresholds)


def _print_metrics(pa_mpjpe_list, pa_mpvpe_list, auc_list, pck_list, pck_thresholds):
    """Print dataset-level averages of the per-sample metrics.

    3D metrics are reported as N/A when none were computed (rather than as a
    misleading 0.0). ``pck_list`` presumably holds one value per threshold per
    sample — TODO confirm against Evaluator.eval_2d.
    """
    if not auc_list:
        # np.mean([]) is nan and indexing an empty array by column raises.
        print("No samples evaluated; skipping metric report.")
        return

    print("---------------- 3D metric -----------------")
    if pa_mpjpe_list:
        print(f"Avg PA-MPJPE: {np.mean(pa_mpjpe_list)}")
        print(f"Avg PA-MPVPE: {np.mean(pa_mpvpe_list)}")
    else:
        print("Avg PA-MPJPE: N/A (3D metrics not evaluated for this dataset)")
        print("Avg PA-MPVPE: N/A (3D metrics not evaluated for this dataset)")

    print("--------------- 2D metric ------------------")
    print(f"AUC: {np.mean(auc_list)}")
    # One row per sample, one column per threshold (also handles scalar PCK).
    pck_array = np.asarray(pck_list).reshape(len(pck_list), -1)
    for idx, threshold in enumerate(pck_thresholds):
        print(f"PCK@{threshold}: {np.mean(pck_array[:, idx])}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="Path to config file", required=True)
    parser.add_argument("--checkpoint", type=str, help="Path to checkpoint file", required=True)
    parser.add_argument("--default_eval_config", type=str,
                        default="./configs_hydra/experiment/default_val.yaml")
    parser.add_argument("--dataset", type=str, default="ALL")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use for evaluation")
    args = parser.parse_args()
    main(args)