| import torch |
| import yaml |
| import time |
| from collections import OrderedDict, namedtuple |
| import os |
| import sys |
|
|
| ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) |
| sys.path.insert(0, ROOT_DIR) |
|
|
| from sgmnet import matcher as SGM_Model |
| from superglue import matcher as SG_Model |
|
|
|
|
| import argparse |
|
|
# Command-line interface for the matcher cost benchmark.
parser = argparse.ArgumentParser(
    description="Measure average forward-pass time and peak GPU memory of a matcher."
)
parser.add_argument(
    "--matcher_name",
    type=str,
    default="SGM",
    # Only these two names are dispatched in the __main__ section below.
    choices=("SGM", "SG"),
    help="matcher to benchmark: SGM (sgmnet) or SG (superglue), default: %(default)s",
)
parser.add_argument(
    "--config_path",
    type=str,
    default="configs/cost/sgm_cost.yaml",
    help="path to the matcher's YAML config file, default: %(default)s",
)
parser.add_argument(
    "--num_kpt",
    type=int,
    default=4000,
    help="number of keypoints in each random input set, default: %(default)s",
)
parser.add_argument(
    "--iter_num",
    type=int,
    default=100,
    help="number of timed forward passes, default: %(default)s",
)
|
|
|
|
def test_cost(test_data, model, iter_num=None):
    """Benchmark ``model`` on ``test_data`` and print latency and peak memory.

    Runs one untimed warm-up forward pass, then ``iter_num`` timed passes,
    all under ``torch.no_grad()``. ``torch.cuda.synchronize()`` is called
    before each clock read so that queued GPU work is not left out of (or
    leaked into) the measured interval. Results are printed, not returned.

    Args:
        test_data: Input handed unchanged to ``model`` on every call.
        model: Callable invoked as ``model(test_data)``; its output is
            discarded.
        iter_num: Number of timed passes. When omitted, falls back to the
            module-level ``args.iter_num`` set in the ``__main__`` section,
            which keeps the original two-argument call working.

    Raises:
        ValueError: If the iteration count is not a positive integer.
    """
    if iter_num is None:
        # Backward-compatible fallback to the parsed command-line arguments.
        iter_num = args.iter_num
    iter_num = int(iter_num)
    if iter_num <= 0:
        # Avoids a ZeroDivisionError when averaging below.
        raise ValueError("iter_num must be a positive integer, got %d" % iter_num)

    with torch.no_grad():
        # Warm-up pass: keeps one-time setup cost out of the timed loop.
        _ = model(test_data)
        torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(iter_num):
            _ = model(test_data)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        print("Average time per run(ms): ", elapsed / iter_num * 1e3)
        # NOTE(review): peak stats are never reset here, so this figure covers
        # every allocation made before this call as well - confirm intended.
        print("Peak memory(MB): ", torch.cuda.max_memory_allocated() / 1e6)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the selected matcher from its YAML config,
    # move it to the GPU, and benchmark it on random keypoints/descriptors.
    torch.backends.cudnn.benchmark = False
    args = parser.parse_args()

    # yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6
    # and can construct arbitrary objects on older versions; safe_load parses
    # plain-data configs on every version.
    with open(args.config_path, "r") as f:
        config_dict = yaml.safe_load(f)
    if not isinstance(config_dict, dict):
        raise ValueError(
            "config file %r must contain a YAML mapping" % args.config_path
        )
    # The matcher constructors receive the config as an attribute-style record.
    model_config = namedtuple("model_config", config_dict.keys())(
        *config_dict.values()
    )

    if args.matcher_name == "SGM":
        model = SGM_Model(model_config)
    elif args.matcher_name == "SG":
        model = SG_Model(model_config)
    else:
        # Previously fell through and failed later with a NameError on `model`.
        raise ValueError(
            "unsupported matcher_name %r; expected 'SGM' or 'SG'" % args.matcher_name
        )
    model.cuda()
    model.eval()

    # Random inputs: 2-D coordinates shifted into [-0.5, 0.5) and 128-D
    # descriptors, one batch entry with num_kpt points per side.
    test_data = {
        "x1": torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        "x2": torch.rand(1, args.num_kpt, 2).cuda() - 0.5,
        "desc1": torch.rand(1, args.num_kpt, 128).cuda(),
        "desc2": torch.rand(1, args.num_kpt, 128).cuda(),
    }

    test_cost(test_data, model)
|
|