import argparse
import importlib
import os

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from lib.core.config import cfg, update_config
from lib.models.model import HACO
from lib.utils.contact_utils import get_contact_thres
from lib.utils.train_utils import worker_init_fn, set_seed
from lib.utils.eval_utils import evaluation

# Command-line interface: which backbone the model is built with and which
# checkpoint file (if any) is restored before evaluation.
parser = argparse.ArgumentParser(description='Test HACO')
parser.add_argument(
    '--backbone',
    type=str,
    default='hamer',
    choices=[
        'hamer',
        'vit-l-16',
        'vit-b-16',
        'vit-s-16',
        'handoccnet',
        'hrnet-w48',
        'hrnet-w32',
        'resnet-152',
        'resnet-101',
        'resnet-50',
        'resnet-34',
        'resnet-18',
    ],
    help='backbone model',
)
parser.add_argument(
    '--checkpoint',
    type=str,
    default='',  # empty string means "do not load a checkpoint"
    help='model path for evaluation',
)
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Dataset class lookup
# ---------------------------------------------------------------------------
# The dataset class is expected at data/<name>/dataset.py under the same name
# as the package. It is resolved with importlib + getattr rather than
# exec()/eval() on a config-derived string, so the config value can only name
# a module and an attribute and can never execute arbitrary code.
test_dataset_name = cfg.DATASET.test_name
test_dataset_cls = getattr(
    importlib.import_module(f'data.{test_dataset_name}.dataset'),
    test_dataset_name,
)

# ---------------------------------------------------------------------------
# Device and CPU threading
# ---------------------------------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_num_threads(cfg.DATASET.workers)
# NOTE(review): these variables are set after numpy/torch have already been
# imported. Processes started later inherit them, but libraries that are
# already initialized in this process may not re-read them -- confirm whether
# they should instead be exported before launching the script.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"

# ---------------------------------------------------------------------------
# Experiment directory and configuration
# ---------------------------------------------------------------------------
experiment_dir = f'experiments_test_{test_dataset_name.lower()}'
checkpoint_dir = os.path.join(experiment_dir, 'full', 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

update_config(backbone_type=args.backbone, exp_dir=experiment_dir)

# Imported here, after update_config(), rather than at the top of the file.
# Presumably the logger is only set up once the experiment directory is
# known -- verify against lib.core.config before moving this import.
from lib.core.config import logger
set_seed(cfg.MODEL.seed)
logger.info(f"Using random seed: {cfg.MODEL.seed}")

# ---------------------------------------------------------------------------
# Test dataset and dataloader
# ---------------------------------------------------------------------------
transform = transforms.ToTensor()
test_dataset = test_dataset_cls(transform, 'test')

# shuffle=False and drop_last=False: every sample is visited, in dataset order.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=cfg.TEST.batch,
    shuffle=False,
    num_workers=cfg.DATASET.workers,
    pin_memory=True,
    drop_last=False,
    worker_init_fn=worker_init_fn,
)

logger.info(f"# of test samples: {len(test_dataset)}")

# Build the network, place it on the selected device and switch it to
# evaluation mode.
model = HACO()
model = model.to(device)
model.eval()

# Restore weights only when a checkpoint path was given on the command line;
# with the default empty string the model keeps the parameters it was
# constructed with.
if args.checkpoint:
    # map_location loads the stored tensors onto the device chosen above.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

# ---------------------------------------------------------------------------
# Evaluation loop
# ---------------------------------------------------------------------------
# NOTE(review): the loop nesting below was reconstructed from a source whose
# indentation had been flattened -- confirm against the upstream script.


def _mean_missing_as_zero(values):
    """Return the mean of *values*, counting ``None`` entries as 0.0."""
    return np.mean([v if v is not None else 0.0 for v in values])


# One slot per dataset sample for each metric. Slots are written by loop
# index, so only the first len(test_dataloader) entries are filled.
eval_result = {
    'cont_pre': [None] * len(test_dataset),
    'cont_rec': [None] * len(test_dataset),
    'cont_f1': [None] * len(test_dataset),
}

test_iterator = tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False)
model.eval()

# The threshold depends only on the chosen backbone, so look it up once
# instead of once per batch.
eval_thres = get_contact_thres(args.backbone)

# Bound up front so the final summary cannot raise NameError when the
# dataloader yields no batches; NaN marks "nothing was evaluated".
total_cont_pre = total_cont_rec = total_cont_f1 = float('nan')

for idx, data in test_iterator:
    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = model(
            {
                'input': data['input_data'],
                'target': data['targets_data'],
                'meta_info': data['meta_info'],
            },
            mode="test",
        )

    # Score this batch and store each returned metric in its slot.
    eval_out = evaluation(outputs, data['targets_data'], data['meta_info'], mode='test', thres=eval_thres)
    for key in eval_out:
        eval_result[key][idx] = eval_out[key]

    # Running averages over every batch processed so far.
    processed = idx + 1
    total_cont_pre = _mean_missing_as_zero(eval_result['cont_pre'][:processed])
    total_cont_rec = _mean_missing_as_zero(eval_result['cont_rec'][:processed])
    total_cont_f1 = _mean_missing_as_zero(eval_result['cont_f1'][:processed])

    logger.info(f"C-Pre: {total_cont_pre:.3f} | C-Rec: {total_cont_rec:.3f} | C-F1: {total_cont_f1:.3f}")

logger.info('Test finished!!!!')
logger.info(f"Final Results --- C-Pre: {total_cont_pre:.3f} | C-Rec: {total_cont_rec:.3f} | C-F1: {total_cont_f1:.3f}")