| import argparse | |
| from torch.utils.data import DataLoader | |
| import lightning as L | |
| from datasets import dataset_dict | |
| from model import PL_RelPose, keypoint_dict | |
| from configs.default import get_cfg_defaults | |
def main(args):
    """Run the test loop for a saved PL_RelPose checkpoint.

    Args:
        args: parsed command-line namespace. ``args.config`` is the YAML
            file merged over ``get_cfg_defaults()``; ``args.ckpt_path`` is
            handed to ``PL_RelPose.load_from_checkpoint``.
    """
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config)

    # Pull every setting out of the config before the dataset builder runs,
    # so the values used below are exactly those produced by the merge.
    task_name = cfg.DATASET.TASK
    source_name = cfg.DATASET.DATA_SOURCE
    loader_kwargs = dict(
        batch_size=cfg.TRAINER.BATCH_SIZE,
        num_workers=cfg.TRAINER.NUM_WORKERS,
        pin_memory=cfg.TRAINER.PIN_MEMORY,
    )
    max_keypoints = cfg.MODEL.TEST_NUM_KEYPOINTS

    # dataset_dict is a two-level registry: task -> data source -> builder.
    make_split = dataset_dict[task_name][source_name]
    test_loader = DataLoader(make_split('test', cfg), **loader_kwargs)

    model = PL_RelPose.load_from_checkpoint(args.ckpt_path)

    # Rebuild the extractor from keypoint_dict, keyed by the checkpoint's
    # saved 'features' hyperparameter, with the test-time keypoint cap.
    # NOTE(review): detection_threshold=0.0 presumably lets the keypoint cap,
    # not a score cut-off, limit detections — confirm against the extractor.
    extractor_factory = keypoint_dict[model.hparams['features']]
    model.extractor = extractor_factory(
        max_num_keypoints=max_keypoints,
        detection_threshold=0.0,
    ).eval()

    # Evaluation is pinned to device index 0.
    L.Trainer(devices=[0]).test(model, dataloaders=test_loader)
def get_parser():
    """Build the command-line parser for this evaluation script.

    Positional arguments:
        config: path to the YAML file merged over the default config.
        ckpt_path: path to the checkpoint that ``main`` passes to
            ``PL_RelPose.load_from_checkpoint``.

    Returns:
        argparse.ArgumentParser: parser yielding a namespace with
        ``config`` and ``ckpt_path`` string attributes.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a PL_RelPose checkpoint on the test split.'
    )
    parser.add_argument('config', type=str, help='.yaml configure file path')
    # Previously undocumented: give users a hint in --help output.
    parser.add_argument(
        'ckpt_path', type=str, help='path to the model checkpoint to evaluate'
    )
    return parser
if __name__ == "__main__":
    # Parse the CLI and hand the resulting namespace straight to main().
    main(get_parser().parse_args())