|
|
from functools import cache |
|
|
import numpy as np |
|
|
import scipy.sparse as sp |
|
|
import torch |
|
|
import cv2 |
|
|
import roma |
|
|
from tqdm import tqdm |
|
|
|
|
|
from cloud_opt.utils import * |
|
|
|
|
|
|
|
|
def compute_edge_scores(edges, edge2conf_i, edge2conf_j):
    """Score every edge of the pairwise graph.

    Args:
        edges: iterable of ``(key, (i, j))`` items, where ``key`` is the edge
            identifier (``'i_j'``) used to index both confidence mappings and
            ``(i, j)`` are the two image indices of that edge.
        edge2conf_i: mapping from edge key to the confidence of view ``i``.
        edge2conf_j: mapping from edge key to the confidence of view ``j``.

    Returns:
        dict mapping ``(i, j)`` to ``edge_conf(edge2conf_i[key], edge2conf_j[key])``.
        If the same ``(i, j)`` appears twice, the later item wins.
    """
    scores = {}
    for key, (i, j) in edges:
        scores[(i, j)] = edge_conf(edge2conf_i[key], edge2conf_j[key])
    return scores
|
|
|
|
|
|
|
|
def dict_to_sparse_graph(dic):
    """Build a square sparse matrix from an ``{(i, j): weight}`` mapping.

    Args:
        dic: mapping whose keys are ``(row, col)`` index pairs and whose values
            are the weights stored at those positions.

    Returns:
        ``scipy.sparse.dok_array`` of shape ``(n, n)``, where ``n`` is one more
        than the largest index found in any key.

    Raises:
        ValueError: if ``dic`` is empty (no largest index exists).
    """
    size = max(map(max, dic)) + 1
    graph = sp.dok_array((size, size))
    for key in dic:
        graph[key] = dic[key]
    return graph
|
|
|
|
|
|
|
|
@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
    """Initialise all camera poses (image-wise and pairwise) from the pairwise estimates.

    Runs ``minimum_spanning_tree`` on the optimizer's pairwise predictions and
    confidences, then passes the resulting point clouds, focals and poses to
    ``init_from_pts3d`` and returns its result unchanged. Gradient tracking is
    disabled for the whole call.

    Extra keyword arguments are forwarded verbatim to ``minimum_spanning_tree``.
    """
    device = self.device
    tree_inputs = (
        self.imshapes,
        self.edges,
        self.edge2pts_i,
        self.edge2pts_j,
        self.edge2conf_i,
        self.edge2conf_j,
        self.im_conf,
        self.min_conf_thr,
    )
    # The spanning-tree edge list (second output) is not needed here.
    pts3d, _tree_edges, im_focals, im_poses = minimum_spanning_tree(
        *tree_inputs,
        device,
        has_im_poses=self.has_im_poses,
        verbose=self.verbose,
        **kw,
    )
    return init_from_pts3d(self, pts3d, im_focals, im_poses)
|
|
|
|
|
|
|
|
def minimum_spanning_tree(
    imshapes,
    edges,
    edge2pred_i,
    edge2pred_j,
    edge2conf_i,
    edge2conf_j,
    im_conf,
    min_conf_thr,
    device,
    has_im_poses=True,
    niter_PnP=10,
    verbose=True,
    save_score_path=None,
):
    """Chain pairwise point maps into one common frame along a spanning tree.

    Every edge is scored with ``compute_edge_scores``. The scores are negated, so
    scipy's *minimum* spanning tree keeps the highest-scoring edges. The best
    tree edge seeds the reconstruction: view ``i`` of that edge gets the
    identity pose. The remaining tree edges are consumed best-first. Each one
    attaches a new view by rigidly registering the pair's prediction onto the
    points already reconstructed for the other view.

    Args:
        imshapes: per-image shapes; only ``len(imshapes)`` is used here.
        edges: iterable of edges, converted with ``i_j_ij``.
        edge2pred_i, edge2pred_j: edge key -> predicted points of view i / view j.
            One transform is applied to both, so they share a frame
            (presumably camera i's frame — confirm).
        edge2conf_i, edge2conf_j: edge key -> matching confidences.
        im_conf: per-image confidences, thresholded by ``min_conf_thr`` for PnP.
        min_conf_thr: confidence threshold for the PnP mask.
        device: torch device for the created pose matrices.
        has_im_poses: when False, no focals/poses are produced.
        niter_PnP: forwarded to ``fast_pnp``.
        verbose: print each edge as it is used.
        save_score_path: accepted for interface compatibility; unused here.

    Returns:
        ``(pts3d, msp_edges, im_focals, im_poses)``:
        - ``pts3d``: list of per-image points in the common frame; entries for
          images not reached by the tree stay ``None``.
        - ``msp_edges``: tree edges ``(i, j)`` in the order they were applied.
        - ``im_focals``: list of per-image focal estimates, or ``None``.
        - ``im_poses``: stacked ``(n_imgs, 4, 4)`` tensor, or ``None``.
        The last two are ``None`` when ``has_im_poses`` is False.

    Raises:
        ValueError: if some spanning-tree edges are not connected to the seed
            edge (disconnected pairwise graph). Previously this looped forever.
    """
    n_imgs = len(imshapes)

    # Negate the scores so the *minimum* spanning tree keeps the best edges.
    edge_scores = compute_edge_scores(map(i_j_ij, edges), edge2conf_i, edge2conf_j)
    sparse_graph = -dict_to_sparse_graph(edge_scores)
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # Per-image points in the common frame, filled while walking the tree.
    pts3d = [None] * n_imgs

    # Tree edges sorted by ascending score: pop() yields the best remaining edge.
    todo = sorted(zip(-msp.data, msp.row, msp.col))
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # Seed with the strongest edge; view i defines the common frame.
    score, i, j = todo.pop()
    if verbose:
        print(f" init edge ({i}*,{j}*) {score=}")
    i_j = edge_str(i, j)
    pts3d[i] = edge2pred_i[i_j].clone()
    pts3d[j] = edge2pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(edge2pred_i[i_j])

    msp_edges = [(i, j)]
    n_stalled = 0  # consecutive edges re-queued without attaching a new view
    while todo:
        score, i, j = todo.pop()
        # Key of the *current* edge. It must be computed before the focal
        # estimate; the old code reused the previous edge's key here.
        i_j = edge_str(i, j)

        # Focals are discarded when has_im_poses is False, so skip the work.
        if has_im_poses and im_focals[i] is None:
            im_focals[i] = estimate_focal(edge2pred_i[i_j])

        if i in done:
            if verbose:
                print(f" init edge ({i},{j}*) {score=}")
            assert j not in done
            # Align pred_i onto the known points of i, then carry j along.
            s, R, T = rigid_points_registration(
                edge2pred_i[i_j], pts3d[i], conf=edge2conf_i[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, edge2pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))
            n_stalled = 0

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            if verbose:
                print(f" init edge ({i}*,{j}) {score=}")
            assert i not in done
            # Align pred_j onto the known points of j, then carry i along.
            s, R, T = rigid_points_registration(
                edge2pred_j[i_j], pts3d[j], conf=edge2conf_j[i_j]
            )
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, edge2pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))
            n_stalled = 0

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        else:
            # Neither endpoint is reconstructed yet: retry after the others.
            todo.insert(0, (score, i, j))
            n_stalled += 1
            # A full pass without progress means the remaining edges can never
            # attach to the seed component; stop instead of spinning forever.
            if n_stalled >= len(todo):
                raise ValueError(
                    f"minimum_spanning_tree: {len(todo)} spanning-tree edge(s) are "
                    "not connected to the initial edge; the pairwise graph must be "
                    "connected"
                )

    if has_im_poses:
        # Fill missing focals, visiting edges best-first (scores are negated,
        # so ascending order means best first).
        pair_scores = list(sparse_graph.values())
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[
            np.argsort(pair_scores)
        ]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(edge2pred_i[edge_str(i, j)])

        # Poses not set by the tree walk: try PnP, else fall back to identity.
        for i in range(n_imgs):
            if im_poses[i] is None:
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(
                    pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP
                )
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses
|
|
|
|
|
|
|
|
def init_from_pts3d(self, pts3d, im_focals, im_poses):
    """Initialise the optimizer's pairwise and image-wise parameters from point clouds.

    Args:
        self: the scene optimizer being initialised.
        pts3d: list with one point tensor per image, modified in place. Points
            are aligned to the known poses when several exist, then rescaled
            by the pairwise normalisation factor.
        im_focals: per-image focal estimates (entries may be ``None``), or ``None``.
        im_poses: stacked camera-to-world matrices, or ``None``.
            ``minimum_spanning_tree`` returns ``None`` when ``has_im_poses`` is False.

    Raises:
        NotImplementedError: if exactly one camera pose is known.
    """
    nkp, known_poses_msk, known_poses = self.get_known_poses()
    if nkp == 1:
        raise NotImplementedError(
            "Would be simpler to just align everything afterwards on the single known pose"
        )
    if nkp > 1:
        # Global similarity alignment of the estimated poses onto the known ones.
        s, R, T = align_multiple_poses(
            im_poses[known_poses_msk], known_poses[known_poses_msk]
        )
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # Move every pose and point cloud into the frame of the known poses.
        im_poses = trf @ im_poses
        # Undo the scale s that trf folded into the rotation block
        # (presumably sRT_to_4x4 stores s * R there — confirm).
        im_poses[:, :3, :3] /= s
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)

    # Pairwise poses: register each edge's prediction onto the points of view i.
    # NOTE(review): init_minimum_spanning_tree reads self.edge2pts_i /
    # self.edge2conf_i; confirm self.pred_i / self.conf_i hold the same data.
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        s, R, T = rigid_points_registration(
            self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]
        )
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # Apply the pairwise scale normalisation to translations and points.
    s_factor = self.get_pw_norm_scale_factor()
    if im_poses is not None:  # None when image poses are not estimated
        im_poses[:, :3, 3] *= s_factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            # Depth = z coordinate of the points expressed in camera i's frame.
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None and not self.shared_focal:
                self._set_focal(i, im_focals[i])
        if self.shared_focal:
            # Average only the available estimates; this equals the old
            # sum / n_imgs when every focal is present, and avoids a TypeError
            # on a None entry.
            known_focals = [f for f in im_focals if f is not None]
            if known_focals:
                self._set_focal(0, sum(known_focals) / len(known_focals))
        if self.n_imgs > 2:
            self._set_init_depthmap()

    if self.verbose:
        with torch.no_grad():
            print(" init loss =", float(self()))
|
|
|