File size: 4,377 Bytes
1800173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation

Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""

import numpy as np
from tqdm import tqdm
import torch
from prima.utils import recursive_to
from prima.utils.evaluate_metric import Evaluator
from prima.datasets.datasets import EvaluationDataset
import argparse
from torch.utils.data import DataLoader
from prima.models.prima import PRIMA
from prima.configs import get_config
torch.multiprocessing.set_sharing_strategy('file_system')


def main(args):
    """Load a PRIMA checkpoint and evaluate it on one or all configured datasets.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``config``, ``default_eval_config``, ``checkpoint``, ``dataset``
            and ``device``.

    Raises:
        KeyError: If ``args.dataset`` is neither "ALL" (case-insensitive) nor
            a dataset name present in the evaluation config.
    """
    cfg = get_config(args.config)
    default_cfg = get_config(args.default_eval_config)

    cfg_eval_dataset = dict(default_cfg.DATASETS)
    aug_cfg = cfg_eval_dataset.pop("CONFIG", None)  # augmentation config is not used in evaluation

    # Resolve the dataset selection before the (expensive) checkpoint load so
    # that a mistyped name fails immediately with a helpful message.
    if args.dataset.upper() == "ALL":
        selected = list(cfg_eval_dataset)
    elif args.dataset in cfg_eval_dataset:
        selected = [args.dataset]
    else:
        available = ", ".join(map(str, cfg_eval_dataset))
        raise KeyError(
            f"Unknown dataset {args.dataset!r}; use 'ALL' or one of: {available}"
        )

    # Batches are moved to ``args.device`` during evaluation, so the model
    # must live on the same device; otherwise the forward pass fails with a
    # device mismatch (e.g. ``--device cpu`` with a checkpoint saved on GPU).
    model = PRIMA.load_from_checkpoint(
        args.checkpoint, map_location=args.device, cfg=cfg, strict=False
    )
    model = model.to(args.device)
    model.eval()

    smal_evaluator = Evaluator(smal_model=model.smal, image_size=cfg.MODEL.IMAGE_SIZE)

    for key in selected:
        print(f"-------- Evaluate {key} dataset ------------")
        eval_one_dataset(cfg_eval_dataset[key], default_cfg, cfg, model,
                         evaluator=smal_evaluator,
                         aug_cfg=aug_cfg,
                         key=key,
                         device=args.device)
        print(f"-------{key} Dataset evaluate finish ------")


def eval_one_dataset(dataset_cfg, default_cfg, cfg, model, evaluator, aug_cfg, key, device='cuda'):
    """Run ``model`` over the test split of one dataset and print averaged metrics.

    Args:
        dataset_cfg: Mapping for this dataset; ``ROOT_IMAGE`` and
            ``JSON_FILE['TEST']`` are read from it.
        default_cfg: Evaluation config; ``METRIC.PCK_THRESHOLD`` is read from it.
        cfg: Model config; ``SMAL`` (optional ``FOCAL_LENGTH``, default 1000),
            ``MODEL.IMAGE_SIZE`` and ``GENERAL.NUM_WORKERS`` are read from it.
        model: Callable applied to each batch; its output is handed to
            ``evaluator`` together with the batch.
        evaluator: Object providing ``eval_3d(output, batch)`` and
            ``eval_2d(output, batch, pck_threshold=...)``.
        aug_cfg: Forwarded to the dataset as ``augm_config``.
        key: Dataset name. 3D metrics are computed only for "ANIMAL3D" and
            "CONTROL_ANIMAL3D".
        device: Device every batch is moved to before the forward pass.

    Returns:
        None. Results are printed to stdout.
    """
    dataset = EvaluationDataset(root_image=dataset_cfg['ROOT_IMAGE'],
                                json_file=dataset_cfg['JSON_FILE']['TEST'],
                                augm_config=aug_cfg, focal_length=cfg.SMAL.get("FOCAL_LENGTH", 1000),
                                image_size=cfg.MODEL.IMAGE_SIZE,
                                )
    dataloader = DataLoader(dataset, batch_size=1, num_workers=cfg.GENERAL.NUM_WORKERS)

    thresholds = default_cfg.METRIC.PCK_THRESHOLD
    has_3d = key in ("ANIMAL3D", "CONTROL_ANIMAL3D")

    bar = tqdm(dataloader)
    pa_mpjpe_list, pck_list, auc_list, pa_mpvpe_list = [], [], [], []
    for batch in bar:
        batch = recursive_to(batch, device)
        with torch.no_grad():
            output = model(batch)

        postfix = {}
        if has_3d:
            pa_mpjpe, pa_mpvpe = evaluator.eval_3d(output, batch)
            pa_mpjpe_list.append(pa_mpjpe)
            pa_mpvpe_list.append(pa_mpvpe)
            postfix.update(PA_MPJPE=pa_mpjpe, PA_MPVPE=pa_mpvpe)
        pck, auc = evaluator.eval_2d(output, batch, pck_threshold=thresholds)
        auc_list.append(auc)
        pck_list.append(pck)
        postfix.update(AUC=auc, pck=pck)

        bar.set_postfix(**postfix)

    # With no samples, np.mean would emit NaN warnings and the per-threshold
    # column indexing below would raise IndexError, so stop early instead.
    if not pck_list:
        print(f"No samples found for the {key} dataset; nothing to evaluate.")
        return

    print("---------------- 3D metric -----------------")
    if has_3d:
        print(f"Avg PA-MPJPE: {np.mean(pa_mpjpe_list)}")
        print(f"Avg PA-MPVPE: {np.mean(pa_mpvpe_list)}")
    else:
        # 3D evaluation was not run; averaging placeholder zeros would read
        # like a perfect score, so report the metrics as unavailable.
        print("Avg PA-MPJPE: N/A (3D evaluation is not run for this dataset)")
        print("Avg PA-MPVPE: N/A (3D evaluation is not run for this dataset)")

    print("--------------- 2D metric ------------------")
    print(f"AUC: {np.mean(auc_list)}")
    # Assumes eval_2d returns one PCK value per threshold, so stacking gives a
    # (num_samples, num_thresholds) array -- TODO confirm against Evaluator.
    pck_matrix = np.array(pck_list)
    for idx, th in enumerate(thresholds):
        print(f"PCK@{th}: {np.mean(pck_matrix[:, idx])}")


if __name__ == "__main__":
    # Command-line entry point: collect the evaluation options and hand the
    # parsed namespace to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        help="Path to config file",
        required=True,
    )
    cli.add_argument(
        "--checkpoint",
        type=str,
        help="Path to checkpoint file",
        required=True,
    )
    cli.add_argument(
        "--default_eval_config",
        type=str,
        default="./configs_hydra/experiment/default_val.yaml",
    )
    cli.add_argument(
        "--dataset",
        type=str,
        default="ALL",
    )
    cli.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for evaluation",
    )
    main(cli.parse_args())