EarthLoc2
/
image-matching-models
/matching
/third_party
/accelerated_features
/modules
/eval
/scannet1500.py
| """ | |
| "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024." | |
| https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/ | |
| Camera pose metrics adapted from LoFTR https://github.com/zju3dv/LoFTR/blob/master/src/utils/metrics.py | |
| The main difference is the use of poselib instead of OpenCV's vanilla RANSAC for E_mat, which is more stable and MUCH faster. | |
| """ | |
| import argparse | |
| import numpy as np | |
| import os | |
| import cv2 | |
| from tqdm import tqdm | |
| import json | |
| import multiprocessing as mp | |
| # Disable scientific notation | |
| np.set_printoptions(suppress=True) | |
def intrinsics_to_camera(K):
    """Convert a pinhole intrinsics matrix into a poselib camera dictionary.

    ``K`` carries no image size, so width/height are approximated as twice the
    principal point (i.e. the principal point is assumed to be the image centre).

    Args:
        K: intrinsics matrix indexed as K[0, 0]=fx, K[1, 1]=fy, K[0, 2]=cx, K[1, 2]=cy.

    Returns:
        dict with keys "model", "width", "height" and "params" ([fx, fy, cx, cy]).
    """
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    camera = dict(
        model="PINHOLE",
        width=int(2 * cx),
        height=int(2 * cy),
        params=[fx, fy, cx, cy],
    )
    return camera
def angle_error_vec(v1, v2):
    """Return the angle, in degrees, between vectors ``v1`` and ``v2``.

    The cosine is clipped to [-1, 1] so floating-point noise cannot push it
    outside the domain of arccos.
    """
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    cosine = np.clip(np.dot(v1, v2) / denom, -1.0, 1.0)
    return np.rad2deg(np.arccos(cosine))
def angle_error_mat(R1, R2):
    """Return the angle, in degrees, of the relative rotation ``R1^T R2``.

    Uses trace(R) = 1 + 2 cos(theta). The cosine is clipped to [-1, 1]
    because numerical errors can make it slightly out of bounds.
    """
    relative_trace = np.trace(np.dot(R1.T, R2))
    cos_theta = np.clip((relative_trace - 1) / 2, -1., 1.)
    return np.rad2deg(np.abs(np.arccos(cos_theta)))
def compute_pose_error(T_0to1, R, t):
    """Compare an estimated relative pose against the ground truth.

    Args:
        T_0to1: ground-truth transform; rotation in [:3, :3], translation in [:3, 3].
        R: estimated rotation matrix.
        t: estimated translation vector.

    Returns:
        (err_t, err_R): translation-direction and rotation errors in degrees.
    """
    rot_gt, trans_gt = T_0to1[:3, :3], T_0to1[:3, 3]
    # An essential matrix fixes translation only up to sign, so t and -t are
    # equally valid; fold the angle so both give the same error.
    raw_err_t = angle_error_vec(t, trans_gt)
    err_t = np.minimum(raw_err_t, 180 - raw_err_t)
    err_R = angle_error_mat(R, rot_gt)
    return err_t, err_R
def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999, type='poselib'):
    """Estimate the relative pose between two views from matched keypoints.

    Args:
        kpts0, kpts1: matched pixel coordinates in image 0 / image 1, one row
            per match (assumed (N, 2) numpy arrays — TODO confirm with matchers).
        K0, K1: intrinsics matrices of the two cameras.
        thresh: RANSAC inlier threshold in pixels. For poselib this is the max
            epipolar error; for OpenCV it is converted to normalized image
            coordinates by dividing by the mean focal length.
        conf: RANSAC success probability.
        type: 'poselib' or 'opencv'.

    Returns:
        (R, t, inliers), or None when there are fewer than 5 matches or OpenCV
        could not produce an essential matrix / valid pose.

    Raises:
        NotImplementedError: if ``type`` is not a supported estimator.
    """
    # The 5-point essential-matrix solver needs at least 5 correspondences.
    if len(kpts0) < 5:
        return None

    if type == 'poselib':
        import poselib
        pose, details = poselib.estimate_relative_pose(
            kpts0.tolist(),
            kpts1.tolist(),
            intrinsics_to_camera(K0),
            intrinsics_to_camera(K1),
            ransac_opt={
                'max_iterations': 10000,  # default 100000
                'success_prob': conf,  # default 0.99999
                'max_epipolar_error': thresh,  # default 1.0
            },
            bundle_opt={  # all defaults
            },
        )
        return pose.R, pose.t, details['inliers']

    if type == 'opencv':
        # Mean of fx/fy of BOTH cameras, used to express the pixel threshold
        # in normalized image coordinates.
        f_mean = np.mean([K0[0, 0], K0[1, 1], K1[0, 0], K1[1, 1]])
        norm_thresh = thresh / f_mean
        # Normalize keypoints: subtract principal point, divide by focal lengths.
        kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
        kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
        E, mask = cv2.findEssentialMat(
            kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf,
            method=cv2.RANSAC)
        # Degenerate input: report failure instead of aborting the benchmark.
        if E is None or E.size == 0:
            return None
        best_num_inliers = 0
        ret = None
        # findEssentialMat may stack several 3x3 candidate solutions vertically;
        # keep the one with the most cheirality-consistent inliers.
        for _E in np.split(E, len(E) // 3):
            n, R, t, _ = cv2.recoverPose(
                _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
            if n > best_num_inliers:
                best_num_inliers = n
                ret = (R, t[:, 0], mask.ravel() > 0)
        return ret

    raise NotImplementedError(f"Unknown pose estimator: {type!r}")
def estimate_pose_parallel(args):
    """Single-argument adapter around ``estimate_pose`` for ``Pool.imap``.

    ``Pool.imap`` hands each task exactly one object, so the positional
    arguments of ``estimate_pose`` arrive bundled in one sequence.
    """
    positional = tuple(args)
    return estimate_pose(*positional)
def pose_auc(errors, thresholds):
    """Area under the recall-vs-error curve up to each threshold.

    Each AUC is normalized by its threshold, so it lies in [0, 1].

    Args:
        errors: per-pair pose errors (degrees); failures may be ``inf``.
        thresholds: iterable of positive error thresholds.

    Returns:
        list of AUC values, one per threshold.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    sort_idx = np.argsort(errors)
    errors = np.array(errors, copy=True)[sort_idx]
    recall = (np.arange(len(errors)) + 1) / len(errors)
    # Anchor the curve at (0, 0).
    errors = np.r_[0., errors]
    recall = np.r_[0., recall]
    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t)
        # Extend the curve horizontally from the last error below t up to t.
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        aucs.append(trapezoid(r, x=e) / t)
    return aucs
def pose_accuracy(errors, thresholds):
    """Percentage of pairs whose error is strictly below each threshold.

    Args:
        errors: per-pair pose errors; any array-like (lists are accepted).
        thresholds: iterable of error thresholds.

    Returns:
        list of percentages in [0, 100], one per threshold.
    """
    # Coerce so plain lists work too (``list < int`` raises TypeError).
    errors = np.asarray(errors)
    return [np.mean(errors < t) * 100 for t in thresholds]
def get_relative_transform(pose0, pose1):
    """Return the top 3x4 rows of ``inv(pose1) @ pose0``.

    The result is ``[R1^T R0 | R1^T (t0 - t1)]``. If the poses map camera to
    world coordinates (presumably the ScanNet convention — confirm), this maps
    camera-0 coordinates into camera-1 coordinates.

    Args:
        pose0, pose1: [R | t] matrices of shape (..., 4, 4) (at least 3x4);
            leading batch dimensions are supported.

    Returns:
        array of shape (..., 3, 4).
    """
    R0 = pose0[..., :3, :3]  # ...x3x3
    t0 = pose0[..., :3, [3]]  # ...x3x1
    R1 = pose1[..., :3, :3]  # ...x3x3
    t1 = pose1[..., :3, [3]]  # ...x3x1
    # swapaxes (unlike ndarray.transpose(-1, -2)) also works for batched input.
    R1_T = np.swapaxes(R1, -1, -2)
    R_0to1 = R1_T @ R0  # ...x3x3
    t_0to1 = R1_T @ (t0 - t1)  # ...x3x1
    return np.concatenate([R_0to1, t_0to1], axis=-1)  # ...x3x4
class Scannet1500:
    """ScanNet-1500 relative-pose benchmark.

    Reads the ground-truth pairs, runs a matcher on every pair, estimates the
    relative pose for each RANSAC threshold and reports pose AUC / accuracy.
    Matches, per-threshold errors and summaries are cached in
    ``config['output']``.
    """

    default_config = {
        'scannet_path': os.path.abspath(os.path.join(os.path.dirname(__file__), '../../data/ScanNet/scannet_test_1500')),
        'gt_path': os.path.abspath(os.path.join(os.path.dirname(__file__), '../../data/ScanNet/test.npz')),
        'pose_estimator': 'poselib',  # poselib, opencv
        'cache_images': True,
        'ransac_thresholds': [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0],
        'ransac_conf': 0.99999,  # RANSAC success probability passed to estimate_pose
        'pose_thresholds': [5, 10, 20],
        'max_pairs': -1,  # <= 0 means "use every pair"
        'output': './output/scannet/',
        'n_workers': 8,  # -1 means "use every CPU core"; 1 runs serially
    }

    def __init__(self, config=None) -> None:
        """Merge ``config`` over ``default_config``, load pairs and prepare output.

        Raises:
            RuntimeError: if ``config['scannet_path']`` does not exist.
        """
        self.config = {**self.default_config, **(config or {})}
        if not os.path.exists(self.config['scannet_path']):
            raise RuntimeError(
                f"Dataset {self.config['scannet_path']} does not exist! \n "
                "> If you didn't download the dataset, use the downloader tool: "
                "python3 -m modules.dataset.download -h"
            )
        self.pairs = self.read_gt()
        os.makedirs(self.config['output'], exist_ok=True)
        if self.config['n_workers'] == -1:
            self.config['n_workers'] = mp.cpu_count()
        self.image_cache = {}
        if self.config['cache_images']:
            self.load_images()

    def load_images(self):
        """Read every image referenced by ``self.pairs`` into ``self.image_cache`` once."""
        for pair in tqdm(self.pairs, desc='Caching images'):
            for key in ('image0', 'image1'):
                path = pair[key]
                if path not in self.image_cache:
                    self.image_cache[path] = cv2.imread(path)

    def read_image(self, path):
        """Return the image at ``path``, from the cache when caching is enabled."""
        if self.config['cache_images']:
            return self.image_cache[path]
        return cv2.imread(path)

    def read_gt(self):
        """Build the list of evaluation pairs from ``gt_path``.

        Returns:
            list of dicts with keys 'image0', 'image1', 'K0', 'K1', 'T_0to1',
            truncated to ``max_pairs`` entries when that setting is positive.
        """
        gt_poses = np.load(self.config['gt_path'])
        names = gt_poses['name']
        num_pairs = len(names)
        max_pairs = self.config['max_pairs']
        if max_pairs is not None and max_pairs > 0:
            num_pairs = min(num_pairs, max_pairs)
        root = os.path.join(self.config['scannet_path'], 'scannet_test_1500')
        pairs = []
        for i in range(num_pairs):
            scene_id = names[i, 0]
            scene_idx = names[i, 1]
            scene = f'scene{scene_id:04d}_{scene_idx:02d}'
            scene_dir = os.path.join(root, scene)
            image0 = str(int(names[i, 2]))
            image1 = str(int(names[i, 3]))
            K0 = np.loadtxt(os.path.join(scene_dir, 'intrinsic/intrinsic_color.txt'))
            # Both images of a pair come from the same scene and share its intrinsics file.
            K1 = K0
            pose_0 = np.loadtxt(os.path.join(scene_dir, 'pose', image0 + '.txt'))
            pose_1 = np.loadtxt(os.path.join(scene_dir, 'pose', image1 + '.txt'))
            T_0to1 = get_relative_transform(pose_0, pose_1)
            pairs.append({
                'image0': os.path.join(scene_dir, 'color', image0 + '.jpg'),
                'image1': os.path.join(scene_dir, 'color', image1 + '.jpg'),
                'K0': K0,
                'K1': K1,
                'T_0to1': T_0to1,
            })
        return pairs

    def extract_and_save_matches(self, matcher_fn, name='', force=False):
        """Run ``matcher_fn`` on every pair and cache the matches to ``{name}_matches.npz``.

        A cached file is reused (unless ``force``) only when it holds exactly
        one entry per current pair, since ``max_pairs`` can change between runs.

        Returns:
            sequence of dicts with keys 'image0', 'image1', 'mkpts0', 'mkpts1'.
        """
        if name == '':
            name = matcher_fn.__name__
        fname = os.path.join(self.config['output'], f'{name}_matches.npz')
        if not force and os.path.exists(fname):
            # allow_pickle is needed for the list of dicts; only load caches written by this tool.
            with np.load(fname, allow_pickle=True) as data:
                cached = data['all_matches']
            if len(cached) == len(self.pairs):
                return cached
        all_matches = []
        for pair in tqdm(self.pairs, desc='Extracting matches'):
            image0 = self.read_image(pair['image0'])
            image1 = self.read_image(pair['image1'])
            mkpts0, mkpts1 = matcher_fn(image0, image1)
            all_matches.append({
                'image0': pair['image0'],
                'image1': pair['image1'],
                'mkpts0': mkpts0,
                'mkpts1': mkpts1,
            })
        np.savez(fname, all_matches=all_matches)
        return all_matches

    def _load_cached_errors(self, fname):
        """Return the [err_t, err_R] rows stored in ``fname``, or None if missing or stale."""
        if not os.path.exists(fname):
            return None
        with open(fname, 'r') as f:
            lines = f.readlines()
        # A row count mismatch means an interrupted run or a different pair subset.
        if len(lines) != len(self.pairs):
            return None
        errors = []
        for line in lines:
            err_t, err_R = line.replace('\n', '').split(' ')
            errors.append([float(err_t), float(err_R)])
        return errors

    def _compute_errors(self, all_matches, ransac_thresh, fname):
        """Estimate every pair's pose at one RANSAC threshold, write errors to ``fname`` and return them."""
        pairs = self.pairs
        estimator = self.config['pose_estimator']
        conf = self.config['ransac_conf']
        desc = f'Running benchmark for th={ransac_thresh}'
        # Positional arguments of estimate_pose: (kpts0, kpts1, K0, K1, thresh, conf, type).
        pose_args = [
            (all_matches[pair_idx]['mkpts0'], all_matches[pair_idx]['mkpts1'],
             pair['K0'], pair['K1'], ransac_thresh, conf, estimator)
            for pair_idx, pair in enumerate(pairs)
        ]
        errors = []
        with open(fname, 'w') as errors_file:
            if self.config['n_workers'] != 1:
                with mp.Pool(self.config['n_workers']) as pool:
                    results = list(tqdm(pool.imap(estimate_pose_parallel, pose_args),
                                        total=len(pose_args), desc=desc, leave=False))
            else:
                results = (estimate_pose(*a) for a in tqdm(pose_args, desc=desc, leave=False))
            for pair, ret in zip(pairs, results):
                if ret is None:
                    err_t, err_R = np.inf, np.inf
                else:
                    R, t, _inliers = ret
                    err_t, err_R = compute_pose_error(pair['T_0to1'], R, t)
                errors_file.write(f'{err_t} {err_R}\n')
                # Flush so partial results are kept if the run is interrupted.
                errors_file.flush()
                errors.append([err_t, err_R])
        return errors

    def run_benchmark(self, matcher_fn, name='', force=False):
        """Evaluate ``matcher_fn`` for every configured RANSAC threshold.

        Per-threshold errors go to ``{name}_{estimator}_{thresh}.txt`` and a
        JSON summary to ``{name}_{estimator}_summary.json`` in the output dir.

        Returns:
            dict mapping RANSAC threshold -> {pose threshold: AUC in percent}.
        """
        if name == '':
            name = matcher_fn.__name__
        all_matches = self.extract_and_save_matches(matcher_fn, name=name, force=force)
        estimator = self.config['pose_estimator']
        pose_thresholds = self.config['pose_thresholds']
        aucs_by_thresh = {}
        accs_by_thresh = {}
        for ransac_thresh in self.config['ransac_thresholds']:
            fname = os.path.join(self.config['output'], f'{name}_{estimator}_{ransac_thresh}.txt')
            errors = None if force else self._load_cached_errors(fname)
            if errors is None:
                errors = self._compute_errors(all_matches, ransac_thresh, fname)
            # A pair's pose error is the worse of its translation and rotation errors.
            errors = np.asarray(errors, dtype=float).reshape(-1, 2).max(axis=1)
            aucs = pose_auc(errors, pose_thresholds)
            accs = pose_accuracy(errors, pose_thresholds)
            aucs_by_thresh[ransac_thresh] = {k: v * 100 for k, v in zip(pose_thresholds, aucs)}
            accs_by_thresh[ransac_thresh] = dict(zip(pose_thresholds, accs))
        # dump summary for this method
        summary = {
            'name': name,
            'aucs_by_thresh': aucs_by_thresh,
            'accs_by_thresh': accs_by_thresh,
        }
        summary_path = os.path.join(self.config['output'], f'{name}_{estimator}_summary.json')
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)
        return aucs_by_thresh
def get_xfeat():
    """Return the ``match_xfeat`` method of a default-constructed XFeat model."""
    from modules.xfeat import XFeat
    return XFeat().match_xfeat
def get_xfeat_star():
    """Return the ``match_xfeat_star`` method of an XFeat model built with top_k=10_000."""
    from modules.xfeat import XFeat
    return XFeat(top_k=10_000).match_xfeat_star
def get_alike():
    """Return ``match_alike`` from the bundled ALIKE wrapper."""
    from third_party import alike_wrapper
    matcher = alike_wrapper.match_alike
    return matcher
def print_fancy(d):
    """Print ``d`` to stdout as JSON indented by two spaces."""
    rendered = json.dumps(d, indent=2)
    print(rendered)
def parse():
    """Define the benchmark's command-line options and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--scannet_path", type=str, required=True, help="Path to the Scannet 1500 dataset")
    add("--output", type=str, default="./output/scannet/", help="Path to the output directory")
    add("--max_pairs", type=int, default=-1, help="Maximum number of pairs to run the benchmark on")
    add("--force", action='store_true', help="Force running the benchmark again")
    add("--pose_estimator", type=str, default='poselib',
        help="Which pose estimator to use: poselib, opencv", choices=['poselib', 'opencv'])
    add("--show", action='store_true', help="Show the matches")
    add("--accuracy", action='store_true', help="Show the accuracy instead of AUC")
    add("--filter", type=str, nargs='+', help="Filter the results by the given names")
    return parser.parse_args()
if __name__ == "__main__":
    args = parse()

    if not args.show:
        scannet = Scannet1500({
            'scannet_path': args.scannet_path,
            'gt_path': args.scannet_path + "/test.npz",
            'cache_images': False,
            'output': args.output,
            'max_pairs': args.max_pairs,
            'pose_estimator': args.pose_estimator,
            'ransac_thresholds': [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0],
            'n_workers': 8,
        })
        functions = {
            'xfeat': get_xfeat(),
            'xfeat_star': get_xfeat_star(),
            'alike': get_alike(),
        }
        # save all results to a file
        all_results = {}
        for name, fn in functions.items():
            print(name)
            all_results[name] = scannet.run_benchmark(matcher_fn=fn, name=name, force=args.force)
        with open(os.path.join(args.output, 'summary.json'), 'w') as f:
            json.dump(all_results, f, indent=2)

    if args.show:
        import glob
        import pandas as pd

        dataset_name = 'scannet'
        all_summary_files = glob.glob(os.path.join(args.output, "**_summary.json"), recursive=True)
        if args.filter:
            all_summary_files = [f for f in all_summary_files if any(fil in f for fil in args.filter)]

        dfs = []
        names = []
        estimators = []
        cols = None
        # Must match the keys written by Scannet1500.run_benchmark.
        metric_key = 'accs_by_thresh' if args.accuracy else 'aucs_by_thresh'
        for summary in all_summary_files:
            with open(summary, 'r') as f:
                summary_data = json.load(f)
            if metric_key not in summary_data:
                continue
            metrics_by_thresh = summary_data[metric_key]
            # The estimator is encoded in the summary file name.
            estimator = 'opencv' if 'opencv' in summary else 'poselib'
            # make sure everything is float
            for thresh in metrics_by_thresh:
                for k in metrics_by_thresh[thresh]:
                    value = metrics_by_thresh[thresh][k]
                    if isinstance(value, str):
                        metrics_by_thresh[thresh][k] = float(value.replace(' ', ''))
            # rows: RANSAC thresholds, columns: pose thresholds (+ their mean)
            df = pd.DataFrame(metrics_by_thresh).T.astype(float)
            df['mean'] = df.mean(axis=1)
            cols = df.columns.tolist()
            dfs.append(df)
            names.append(summary_data['name'])
            estimators.append(estimator)

        if not dfs:
            print(f"No summaries containing '{metric_key}' found in {args.output}")
        else:
            # the best RANSAC threshold is chosen by the mean over pose thresholds
            col = 'mean'
            final_df = pd.DataFrame()
            final_df['name'] = names
            final_df['best_thresh'] = ''
            final_df['estimator'] = estimators
            final_df[cols] = -1.0
            for df, name, estimator in zip(dfs, names, estimators):
                best_thresh = df[col].idxmax()
                best_results = df.loc[best_thresh]
                row_mask = (final_df['name'] == name) & (final_df['estimator'] == estimator)
                final_df.loc[row_mask, 'best_thresh'] = best_thresh
                for _col in cols:
                    final_df.loc[row_mask, _col] = best_results[_col]
            # sort by mean, reset index, drop the helper estimator column
            final_df = final_df.sort_values(by=['mean'])
            final_df = final_df.reset_index(drop=True)
            final_df = final_df.drop(columns=['estimator'])
            # set max float precision to 1
            final_df = final_df.round(1)
            print(f"Dataset: {dataset_name}")
            print(f"Sorting by {col}")
            print(final_df)
            print()
            final_df.to_csv(os.path.join(args.output, f"{dataset_name}_{col}.csv"), index=False)