import torch
from torchvision.transforms import Resize
from .matchers import LightGlue, LoFTR
# from .__models import SuperGlue, SGMNet, ASpanFormer, DKM
from .pose_solver import EssentialMatrixSolver, EssentialMatrixMetricSolver, PnPSolver, ProcrustesSolver
import time
class PoseRecover:
    """Recover the relative pose (R, t) between two images.

    Keypoints are matched with a learned matcher (LightGlue or LoFTR);
    the pose is then estimated by a geometric solver: a scale-free
    essential-matrix solver when no depth maps are given, or a metric
    solver (essential / PnP / Procrustes) when depth is supplied.
    """

    def __init__(self, matcher='lightglue', solver='procrustes', img_resize=None, device='cuda'):
        """Build the matcher and solvers.

        Args:
            matcher: 'lightglue' or 'loftr'.
            solver: metric solver used when depth is available:
                'essential', 'pnp', or 'procrustes'.
            img_resize: if not None, longer image side is resized to this
                many pixels before matching (aspect ratio preserved).
            device: torch device string passed to the matcher.

        Raises:
            NotImplementedError: for an unknown matcher or solver name.
        """
        self.device = device
        if matcher == 'lightglue':
            self.matcher = LightGlue(device=device)
        elif matcher == 'loftr':
            self.matcher = LoFTR(device=device)
        else:
            raise NotImplementedError
        self.img_resize = img_resize
        # Scale-free fallback used whenever depth maps are missing.
        self.basic_solver = EssentialMatrixSolver()
        if solver == 'essential':
            self.scaled_solver = EssentialMatrixMetricSolver()
        elif solver == 'pnp':
            self.scaled_solver = PnPSolver()
        elif solver == 'procrustes':
            self.scaled_solver = ProcrustesSolver()
        else:
            # Fail fast here rather than with an AttributeError on
            # self.scaled_solver deep inside recover().
            raise NotImplementedError

    def recover(self, image0, image1, K0, K1, bbox0=None, bbox1=None,
                mask0=None, mask1=None, depth0=None, depth1=None):
        """Match the two images and solve for their relative pose.

        Args:
            image0, image1: image tensors with spatial dims in the last
                two axes (H, W).
            K0, K1: camera intrinsics for image0 / image1.
            bbox0, bbox1: optional (x1, y1, x2, y2) crop offsets; matched
                points are shifted back into full-image coordinates.
            mask0, mask1: optional 2-D maps indexed [y, x]; a match is kept
                only if both endpoints land on truthy mask values.
                NOTE(review): indexed *after* the bbox shift, so masks are
                presumably in full-image coordinates — confirm with callers.
            depth0, depth1: optional depth maps; when both are given the
                metric solver chosen at construction is used.

        Returns:
            Tuple of (R_est, t_est, points0, points1, preprocess_time,
            extract_time, match_time, solver_time). points* are numpy
            arrays of matched 2-D coordinates; solver_time is seconds
            spent in pose estimation.
        """
        scale0 = scale1 = None
        if self.img_resize is not None:
            h, w = image0.shape[-2:]
            # Resize so the longer side equals img_resize, keeping aspect.
            if h > w:
                h_new = self.img_resize
                w_new = int(w * h_new / h)
            else:
                w_new = self.img_resize
                h_new = int(h * w_new / w)
            resize = Resize((h_new, w_new), antialias=True)
            # Per-axis (x, y) factors mapping matched points back to the
            # original resolutions; computed before the images are replaced.
            scale0 = torch.tensor([image0.shape[-1] / w_new, image0.shape[-2] / h_new], dtype=torch.float)
            scale1 = torch.tensor([image1.shape[-1] / w_new, image1.shape[-2] / h_new], dtype=torch.float)
            image0 = resize(image0)
            image1 = resize(image1)

        points0, points1, preprocess_time, extract_time, match_time = self.matcher.match(image0, image1)

        if self.img_resize is not None:
            # Undo the resize so points align with intrinsics/masks/depths.
            points0 *= scale0.unsqueeze(0).to(points0.device)
            points1 *= scale1.unsqueeze(0).to(points1.device)
        if bbox0 is not None and bbox1 is not None:
            # Shift crop-local coordinates back into the full-image frame.
            x1, y1, _, _ = bbox0
            u1, v1, _, _ = bbox1
            points0[:, 0] += x1
            points0[:, 1] += y1
            points1[:, 0] += u1
            points1[:, 1] += v1
        if mask0 is not None and mask1 is not None:
            keep0 = mask0[points0[:, 1].int(), points0[:, 0].int()]
            keep1 = mask1[points1[:, 1].int(), points1[:, 0].int()]
            keep = keep0 * keep1  # elementwise AND of the two mask hits
            points0 = points0[keep]
            points1 = points1[keep]

        points0, points1 = points0.cpu().numpy(), points1.cpu().numpy()
        start_time = time.time()
        if depth0 is None or depth1 is None:
            # Without depth only a scale-free pose can be recovered.
            R_est, t_est, _ = self.basic_solver.estimate_pose(
                points0, points1, {'K_color0': K0, 'K_color1': K1})
        else:
            R_est, t_est, _ = self.scaled_solver.estimate_pose(
                points0, points1,
                {'K_color0': K0, 'K_color1': K1, 'depth0': depth0, 'depth1': depth1})
        solver_time = time.time() - start_time
        return (R_est, t_est, points0, points1,
                preprocess_time, extract_time, match_time, solver_time)
|