| from matching import get_matcher, available_models, get_default_device | |
| from pathlib import Path | |
| from argparse import ArgumentParser | |
| import cv2 | |
| import time | |
| from tqdm.auto import tqdm | |
| import torch | |
| import numpy as np | |
def parse_args():
    """Parse command-line arguments for the benchmark / test runner.

    Returns:
        argparse.Namespace with ``task``, ``matcher`` (list of model names, or
        ``available_models`` when "all" was requested), ``img_size``,
        ``device`` and ``num_iters``.

    Exits via ``parser.error`` (status 2) when a CUDA device is requested but
    CUDA is unavailable.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--task",
        type=str,
        default="benchmark",
        choices=["benchmark", "test"],
        help="run benchmark or unit tests",
    )
    parser.add_argument(
        "--matcher",
        type=str,
        nargs="+",
        default="all",
        help="which model or list of models to benchmark",
    )
    parser.add_argument(
        "--img-size",
        type=int,
        default=512,
        help="image size to run matching on (resized to square)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=get_default_device(),
        help="Device to run benchmark on",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=5,
        help="number of iterations to run benchmark and average over",
    )
    args = parser.parse_args()

    # Cover indexed devices such as "cuda:1" too. Use parser.error rather than
    # assert so the check survives `python -O`.
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        parser.error(
            "Chosen cuda as device but cuda unavailable! Try another device (cpu)"
        )

    # The default is the plain string "all", but with nargs="+" a user-supplied
    # `--matcher all` arrives as ["all"]. Normalize to a list before checking.
    if isinstance(args.matcher, str):
        args.matcher = [args.matcher]
    if "all" in args.matcher:
        args.matcher = available_models
    return args
def get_img_pairs(asset_dir=None):
    """Collect image pairs from the per-pair subdirectories of ``asset_dir``.

    Each immediate subdirectory of ``asset_dir`` is treated as one pair. Its
    visible (non-dot) files are returned sorted by name, so ``pair[0]`` and
    ``pair[1]`` are deterministic. Subdirectories are also visited in sorted
    order.

    Args:
        asset_dir: Directory holding one subdirectory per pair. Defaults to
            ``assets/example_pairs`` next to this file.

    Returns:
        list[list[Path]]: one sorted list of file paths per pair directory.
    """
    if asset_dir is None:
        asset_dir = Path(__file__).parent / "assets" / "example_pairs"
    asset_dir = Path(asset_dir)
    return [
        sorted(
            f for f in pair_dir.iterdir() if f.is_file() and not f.name.startswith(".")
        )
        for pair_dir in sorted(asset_dir.iterdir())
        if pair_dir.is_dir()
    ]
def test_H_est(matcher, img_size=512):
    """Estimate a homography on a reference pair and return its corner error.

    Matches ``assets/example_test/warped.jpg`` against ``original.jpg`` (both
    resized to ``img_size``). The image corners are projected through the
    estimated homography ``result["H"]`` and normalized by ``img_size``. They
    are then compared with known ground-truth corner positions.

    Per the original note, sift-lg scores about 0.002 at img_size=500, so a
    healthy matcher should stay roughly below 0.01.

    Args:
        matcher: Object exposing ``load_image(path, resize=...)`` and callable
            on two images, returning a mapping with key ``"H"``.
        img_size: Side length images are resized to.

    Returns:
        Maximum absolute difference between predicted and ground-truth corner
        coordinates, expressed as a fraction of ``img_size``.

    Raises:
        RuntimeError: if the matcher returns no homography.
    """
    # Resolve assets next to this file (as get_img_pairs does) so the check
    # does not depend on the current working directory.
    test_dir = Path(__file__).parent / "assets" / "example_test"
    img0_path = test_dir / "warped.jpg"
    img1_path = test_dir / "original.jpg"
    # Expected normalized (x, y) positions of the four image corners.
    ground_truth = np.array(
        [[0.1500, 0.3500], [0.9500, 0.1500], [0.9000, 0.7000], [0.2500, 0.7000]]
    )

    image0 = matcher.load_image(img0_path, resize=img_size)
    image1 = matcher.load_image(img1_path, resize=img_size)
    result = matcher(image0, image1)

    H = result["H"]
    if H is None:
        raise RuntimeError(f"Matcher returned no homography (size={img_size} px)")

    # cv2.perspectiveTransform expects points shaped (N, 1, 2).
    corners = np.array(
        [[0, 0], [img_size, 0], [img_size, img_size], [0, img_size]], dtype=np.float32
    ).reshape(4, 1, 2)
    prediction = cv2.perspectiveTransform(corners, H)[:, 0] / img_size
    return np.abs(ground_truth - prediction).max()
def test(matcher, img_sizes=(512, 256), error_thresh=0.05):
    """Run the homography check (``test_H_est``) at several image sizes.

    Args:
        matcher: Matcher under test, forwarded to ``test_H_est``.
        img_sizes: Image sizes to evaluate; must be non-empty.
        error_thresh: Largest acceptable normalized corner error.

    Returns:
        tuple(bool, float): ``(True, worst_error)``, where ``worst_error`` is
        the largest error over ``img_sizes``. Failures raise instead of
        returning False.

    Raises:
        ValueError: if ``img_sizes`` is empty.
        RuntimeError: if any size's error exceeds ``error_thresh`` or is NaN.
    """
    if not img_sizes:
        raise ValueError("img_sizes must contain at least one image size")

    worst_error = 0.0
    for img_size in img_sizes:
        error = test_H_est(matcher, img_size=img_size)
        # Written as `not <=` so a NaN error (degenerate homography) fails too.
        if not error <= error_thresh:
            raise RuntimeError(
                f"Large homography error in matcher (size={img_size} px): {error}"
            )
        worst_error = max(worst_error, error)
    return True, worst_error
def benchmark(matcher, num_iters=1, img_size=512):
    """Time ``matcher`` on every example pair, repeated ``num_iters`` times.

    Only the matcher call is timed; image loading is excluded.

    Args:
        matcher: Object exposing ``load_image(path, resize=...)`` and callable
            on two images.
        num_iters: Number of passes over all example pairs.
        img_size: Side length images are resized to.

    Returns:
        tuple(list[float], float): per-call durations in seconds and their
        mean. The mean is NaN when nothing was timed.
    """
    # The pair list does not change between iterations, so scan it once.
    pairs = get_img_pairs()
    # CUDA work is asynchronous. Synchronize before starting and before
    # stopping the clock so we time the kernels rather than their launch.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    runtimes = []
    for _ in range(num_iters):
        for pair in pairs:
            img0 = matcher.load_image(pair[0], resize=img_size)
            img1 = matcher.load_image(pair[1], resize=img_size)
            sync()
            start = time.perf_counter()  # monotonic, high-resolution clock
            _ = matcher(img0, img1)
            sync()
            runtimes.append(time.perf_counter() - start)

    mean_runtime = np.mean(runtimes) if runtimes else float("nan")
    return runtimes, mean_runtime
def _run_benchmark(args):
    """Benchmark each matcher in ``args.matcher``; log results to runtime_results.txt."""
    with open("runtime_results.txt", "w") as f:
        for model in tqdm(args.matcher):
            try:
                matcher = get_matcher(model, device=args.device)
                _, avg_runtime = benchmark(
                    matcher, num_iters=args.num_iters, img_size=args.img_size
                )
            except Exception as e:  # one broken model must not abort the sweep
                line = f"{model: <15} NOT OK - exception: {e}"
            else:
                line = f"{model: <15} OK {avg_runtime=:.3f}"
            # Record failures in the file as well, not just on the console.
            f.write(line + "\n")
            tqdm.write(line)


def _run_tests(args):
    """Run the homography test for each matcher; log results to test_results.txt."""
    with open("test_results.txt", "w") as f:
        # test_H_est divides by img_size, so the error is normalized, not pixels.
        header = "Matcher, Passing Tests, Max Error (normalized)"
        f.write(header + "\n" + "-" * 40 + "\n")
        tqdm.write(header)
        for model in tqdm(args.matcher):
            try:
                matcher = get_matcher(model, device=args.device)
                passing, error_val = test(matcher)
            except Exception as e:  # one broken model must not abort the sweep
                line = f"Error with {model}: {e}"
            else:
                line = f"{model}, {passing}, {error_val}"
            # Every record ends with a newline, including error records.
            f.write(line + "\n")
            tqdm.write(line)


def main(args):
    """Dispatch to the benchmark or test runner according to ``args.task``.

    Raises:
        ValueError: if ``args.task`` is neither "benchmark" nor "test".
    """
    runners = {"benchmark": _run_benchmark, "test": _run_tests}
    try:
        runner = runners[args.task]
    except KeyError:
        raise ValueError(
            f"Unknown task {args.task!r}; expected one of {sorted(runners)}"
        ) from None
    runner(args)
if __name__ == "__main__":
    import warnings

    args = parse_args()
    # From here on, suppress every warning so the console stays readable.
    # Warnings raised during argument parsing are still shown.
    warnings.filterwarnings("ignore")
    print(f"Running with args: {args}")
    main(args)