Spaces: Running on Zero
| import matplotlib.pyplot as plt | |
| import warnings | |
| import numpy as np | |
| import cv2 | |
| import os | |
| import os.path as osp | |
| import imageio | |
| from copy import deepcopy | |
| import loguru | |
| import torch | |
| from ..models.loftr import LoFTR, default_cfg | |
| from .utils3d import rect_to_img, canonical_to_camera, calc_pose | |
class ElevEstHelper:
    """Lazily loads and caches a single shared LoFTR feature matcher."""

    # Cached matcher; populated on the first call to get_feature_matcher.
    _feature_matcher = None

    @classmethod
    def get_feature_matcher(cls, ckpt_path, device):
        """Return the cached LoFTR matcher, loading it from disk on first use.

        Args:
            ckpt_path: path to the LoFTR checkpoint (must exist).
            device: torch device the matcher is moved to.

        Returns:
            The LoFTR model in eval mode on ``device``.

        Raises:
            FileNotFoundError: if ``ckpt_path`` does not exist.
        """
        # BUGFIX: this method takes `cls` but had no @classmethod decorator, so
        # calling ElevEstHelper.get_feature_matcher(path, dev) bound `path` to
        # `cls` and failed.
        if cls._feature_matcher is None:
            loguru.logger.info("Loading feature matcher...")
            # Explicit raise instead of assert: asserts vanish under -O.
            if not os.path.exists(ckpt_path):
                raise FileNotFoundError(f"checkpoint not found: {ckpt_path}")
            _default_cfg = deepcopy(default_cfg)
            _default_cfg['coarse']['temp_bug_fix'] = True  # set to False when using the old ckpt
            matcher = LoFTR(config=_default_cfg)
            # map_location avoids requiring the checkpoint's original device;
            # the model is moved to `device` right after.
            matcher.load_state_dict(
                torch.load(ckpt_path, map_location="cpu")['state_dict'])
            matcher = matcher.eval().to(device)
            cls._feature_matcher = matcher
        return cls._feature_matcher
def mask_out_bkgd(img):
    """Return a 2D boolean foreground mask for an image.

    Args:
        img: (H, W, 3) RGB or (H, W, 4) RGBA uint8 array.

    Returns:
        (H, W) boolean array, True on foreground pixels.
    """
    if img.shape[-1] == 4:
        # BUGFIX: previously returned img[:, :, :3] — the RGB channels, an
        # (H, W, 3) array — instead of a 2D mask. Callers index the result as
        # mask[y, x], which requires a 2D boolean array; the alpha channel is
        # what actually encodes foreground coverage.
        fg_mask = img[:, :, 3] > 0
    else:
        loguru.logger.info("Image has no alpha channel, using thresholding to mask out background")
        # Near-white pixels (every channel > 245) are treated as background.
        fg_mask = ~(img > 245).all(axis=-1)
    return fg_mask
def get_feature_matching(matcher, images):
    """Run LoFTR matching on every unordered pair of the 4 input views.

    Args:
        matcher: LoFTR model (on CUDA) that, when called on a batch dict with
            'image0'/'image1', writes 'mkpts0_f', 'mkpts1_f' and 'mconf' back
            into the dict.
        images: list of exactly 4 same-sized RGB(A) uint8 images.

    Returns:
        dict mapping "i_j" (i < j) to an (N, 5) float array of
        [x0, y0, x1, y1, confidence] matches whose endpoints both lie on
        foreground pixels of their respective images.
    """
    assert len(images) == 4
    feature_matching = {}
    # Foreground masks, used below to drop matches landing on background.
    masks = [mask_out_bkgd(images[k]) for k in range(4)]
    for i in range(0, 4):
        for j in range(i + 1, 4):
            mask0 = masks[i]
            mask1 = masks[j]
            img0_raw = cv2.cvtColor(images[i], cv2.COLOR_RGB2GRAY)
            img1_raw = cv2.cvtColor(images[j], cv2.COLOR_RGB2GRAY)
            original_shape = img0_raw.shape
            # LoFTR runs at a fixed 480x480; keypoints are mapped back below.
            img0_raw_resized = cv2.resize(img0_raw, (480, 480))
            img1_raw_resized = cv2.resize(img1_raw, (480, 480))
            img0 = torch.from_numpy(img0_raw_resized)[None][None].cuda() / 255.
            img1 = torch.from_numpy(img1_raw_resized)[None][None].cuda() / 255.
            batch = {'image0': img0, 'image1': img1}
            # Inference with LoFTR; predictions are written into `batch`.
            with torch.no_grad():
                matcher(batch)
            mkpts0 = batch['mkpts0_f'].cpu().numpy()
            mkpts1 = batch['mkpts1_f'].cpu().numpy()
            mconf = batch['mconf'].cpu().numpy()
            # Rescale keypoints from the 480x480 matcher frame back to the
            # original resolution (x scales by width, y by height).
            mkpts0[:, 0] = mkpts0[:, 0] * original_shape[1] / 480
            mkpts0[:, 1] = mkpts0[:, 1] * original_shape[0] / 480
            mkpts1[:, 0] = mkpts1[:, 0] * original_shape[1] / 480
            mkpts1[:, 1] = mkpts1[:, 1] * original_shape[0] / 480
            # BUGFIX: keep0 previously used mkpts1[:, 0] for the x index, so
            # image-0 keypoints were filtered with image-1 x-coordinates.
            # Masks are indexed as mask[y, x].
            keep0 = mask0[mkpts0[:, 1].astype(int), mkpts0[:, 0].astype(int)]
            keep1 = mask1[mkpts1[:, 1].astype(int), mkpts1[:, 0].astype(int)]
            keep = np.logical_and(keep0, keep1)
            mkpts0 = mkpts0[keep]
            mkpts1 = mkpts1[keep]
            mconf = mconf[keep]
            feature_matching[f"{i}_{j}"] = np.concatenate(
                [mkpts0, mkpts1, mconf[:, None]], axis=1)
    return feature_matching
def gen_pose_hypothesis(center_elevation):
    """Generate four camera-pose hypotheses around a central elevation.

    Args:
        center_elevation: central elevation angle in degrees.

    Returns:
        Poses for elevations (c-10, c+10, c, c) at azimuths (30, 30, 20, 40)
        degrees, with Y and Z axes flipped to match the camera convention.
    """
    elevation_degs = [
        center_elevation,
        center_elevation - 10,
        center_elevation + 10,
        center_elevation,
        center_elevation,
    ]  # 45~120
    azimuth_degs = [30, 30, 30, 20, 40]
    elevations = np.radians(elevation_degs)
    azimuths = np.radians(azimuth_degs)
    poses = calc_pose(elevations, azimuths, len(azimuths))
    # Drop the first (reference) pose, keeping the four hypotheses.
    poses = poses[1:]
    # Flip the Y and Z axes.
    poses[..., 1] = -poses[..., 1]
    poses[..., 2] = -poses[..., 2]
    return poses
def ba_error_general(K, matches, poses):
    """Bundle-adjustment-style reprojection error for one pose hypothesis.

    Triangulates the matches between the first two views, then reprojects the
    resulting 3D points into each remaining view and accumulates a
    confidence-weighted reprojection error.

    Args:
        K: (3, 3) camera intrinsics tensor.
        matches: list of (N, 5) tensors [x0, y0, x1, y1, conf]; matches[0]
            relates view 0 to view 1, subsequent entries relate view 0 to the
            later views — assumed ordering, confirm against caller.
        poses: camera poses aligned with the views; poses[0]/poses[1] are used
            for triangulation, poses[2:] for reprojection.

    Returns:
        Scalar tensor (or 0 if nothing was associated) with the summed
        per-view confidence-weighted mean reprojection errors.
    """
    # World->pixel projection matrices for the first two views.
    projmat0 = K @ poses[0].inverse()[:3, :4]
    projmat1 = K @ poses[1].inverse()[:3, :4]
    match_01 = matches[0]
    pts0 = match_01[:, :2]
    pts1 = match_01[:, 2:4]
    # Triangulate in homogeneous coordinates, then dehomogenize to (N, 3).
    Xref = cv2.triangulatePoints(projmat0.cpu().numpy(), projmat1.cpu().numpy(),
                                 pts0.cpu().numpy().T, pts1.cpu().numpy().T)
    Xref = Xref[:3] / Xref[3:]
    Xref = Xref.T
    Xref = torch.from_numpy(Xref).float()
    reproj_error = 0
    for match, cp in zip(matches[1:], poses[2:]):
        # Pairwise distances between view-0 keypoints of match_01 and of
        # `match`, used to associate triangulated points with this view.
        dist = (torch.norm(match_01[:, :2][:, None, :] - match[:, :2][None, :, :], dim=-1))
        if dist.numel() > 0:
            # print("dist.shape", dist.shape)
            m0to2_index = dist.argmin(1)
            # Only keep associations whose view-0 keypoints agree within 1 px.
            keep = dist[torch.arange(match_01.shape[0]), m0to2_index] < 1
            if keep.sum() > 0:
                # Reproject the triangulated points into the current view.
                xref_in2 = rect_to_img(K, canonical_to_camera(Xref, cp.inverse()))
                reproj_error2 = torch.norm(match[m0to2_index][keep][:, 2:4] - xref_in2[keep], dim=-1)
                conf02 = match[m0to2_index][keep][:, -1]
                # Confidence-weighted mean error for this view.
                reproj_error += (reproj_error2 * conf02).sum() / (conf02.sum())
    return reproj_error
def find_optim_elev(elevs, nimgs, matches, K):
    """Return the elevation from `elevs` with the smallest reprojection error.

    For every candidate elevation, pose hypotheses are generated and the
    bundle-adjustment error is accumulated over all rotations of the view
    ring; the candidate minimizing the total error wins.

    Args:
        elevs: array of candidate elevations in degrees.
        nimgs: number of views (and poses per hypothesis).
        matches: dict mapping "i_j" keys to (N, 5) match tensors.
        K: (3, 3) intrinsics tensor.

    Returns:
        The optimal elevation as a Python scalar.
    """
    total_errors = []
    for elev in elevs:
        cam_poses = gen_pose_hypothesis(elev)
        total = 0
        for start in range(nimgs - 1):
            # Poses for the ring of views beginning at `start`.
            batch_poses = [cam_poses[idx % nimgs]
                           for idx in range(start, start + nimgs)]
            # Matches from `start` to each of the other views, in ring order.
            batch_matches = [matches[f"{start}_{(start + off + 1) % nimgs}"]
                             for off in range(nimgs - 1)]
            total += ba_error_general(K, batch_matches, batch_poses)
        total_errors.append(total)
    best_idx = int(torch.tensor(total_errors).argmin())
    return elevs[best_idx].item()
def get_elev_est(feature_matching, min_elev=30, max_elev=150, K=None):
    """Coarse-to-fine search for the camera elevation.

    Args:
        feature_matching: dict mapping "i_j" (i < j, 4 views) to (N, 5)
            arrays [x0, y0, x1, y1, conf].
        min_elev: lower bound of the coarse sweep, degrees.
        max_elev: upper bound of the coarse sweep, degrees.
        K: (3, 3) intrinsics tensor.

    Returns:
        Estimated elevation in degrees, or None if any pair has zero matches.
    """
    any_empty = False
    matches = {}
    for i in range(4):
        for j in range(i + 1, 4):
            pair = feature_matching[f"{i}_{j}"]
            if len(pair) == 0:
                any_empty = True
            # Build the reversed pair by swapping endpoints, keeping confidence.
            swapped = np.concatenate(
                [pair[:, 2:4], pair[:, 0:2], pair[:, 4:5]], axis=1)
            matches[f"{i}_{j}"] = torch.from_numpy(pair).float()
            matches[f"{j}_{i}"] = torch.from_numpy(swapped).float()
    if any_empty:
        loguru.logger.info("0 matches, could not estimate elevation")
        return None
    # Coarse sweep in 10-degree steps.
    coarse_elevs = np.arange(min_elev, max_elev, 10)
    coarse_best = find_optim_elev(coarse_elevs, 4, matches, K)
    # Refine in 1-degree steps around the coarse optimum, excluding
    # degenerate candidates where any hypothesis elevation (e, e-10, e+10)
    # is a multiple of 180 degrees.
    fine_elevs = np.arange(coarse_best - 10, coarse_best + 10, 1)
    fine_elevs = fine_elevs[fine_elevs % 180 != 0]
    fine_elevs = fine_elevs[(fine_elevs - 10) % 180 != 0]
    fine_elevs = fine_elevs[(fine_elevs + 10) % 180 != 0]
    return find_optim_elev(fine_elevs, 4, matches, K)
def elev_est_api(matcher, images, min_elev=30, max_elev=150, K=None):
    """Estimate the camera elevation (degrees) from four multi-view images.

    Args:
        matcher: LoFTR feature matcher (see ElevEstHelper.get_feature_matcher).
        images: list of exactly 4 RGB(A) uint8 images.
        min_elev: lower search bound in degrees.
        max_elev: upper search bound in degrees.
        K: optional (3, 3) intrinsics tensor; a default is used when omitted.

    Returns:
        Estimated elevation in degrees, or None if matching failed.
    """
    feature_matching = get_feature_matching(matcher, images)
    if K is None:
        loguru.logger.warning("K is not provided, using default K")
        default_K = np.array([
            [280.0, 0, 128.0],
            [0, 280.0, 128.0],
            [0, 0, 1],
        ])
        K = torch.from_numpy(default_K).float()
    return get_elev_est(feature_matching, min_elev, max_elev, K)