# vmem/extern/CUT3R/cloud_opt/init_all.py
from functools import cache
import numpy as np
import scipy.sparse as sp
import torch
import cv2
import roma
from tqdm import tqdm
from cloud_opt.utils import *
def compute_edge_scores(edges, edge2conf_i, edge2conf_j):
"""
edges: 'i_j', (i,j)
"""
score_dict = {
(i, j): edge_conf(edge2conf_i[e], edge2conf_j[e]) for e, (i, j) in edges
}
return score_dict
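
# Sketch (hypothetical confidence maps conf_i/conf_j): for edges "0_1" and
# "1_2", compute_edge_scores(map(i_j_ij, ["0_1", "1_2"]), conf_i, conf_j)
# returns
#   {(0, 1): edge_conf(conf_i["0_1"], conf_j["0_1"]),
#    (1, 2): edge_conf(conf_i["1_2"], conf_j["1_2"])}
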
def dict_to_sparse_graph(dic):
n_imgs = max(max(e) for e in dic) + 1
res = sp.dok_array((n_imgs, n_imgs))
for edge, value in dic.items():
res[edge] = value
return res
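
# Sketch: dict_to_sparse_graph({(0, 1): 9.0, (1, 2): 7.0}) yields a 3x3
# scipy DOK array with [0, 1] = 9.0 and [1, 2] = 7.0 set and every other
# cell empty, the sparse convention sp.csgraph treats as "no edge".
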
@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
"""Init all camera poses (image-wise and pairwise poses) given
an initial set of pairwise estimations.
"""
device = self.device
pts3d, _, im_focals, im_poses = minimum_spanning_tree(
self.imshapes,
self.edges,
self.edge2pts_i,
self.edge2pts_j,
self.edge2conf_i,
self.edge2conf_j,
self.im_conf,
self.min_conf_thr,
device,
has_im_poses=self.has_im_poses,
verbose=self.verbose,
**kw,
)
return init_from_pts3d(self, pts3d, im_focals, im_poses)
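
# Usage sketch: despite the `self` parameter this is a module-level helper,
# presumably bound to (or called on) the cloud_opt optimizer. Assuming a
# hypothetical instance `scene` exposing .imshapes, .edges, .edge2pts_i/j,
# .edge2conf_i/j, .im_conf, .min_conf_thr, .has_im_poses, and .verbose:
#   init_minimum_spanning_tree(scene, niter_PnP=10)
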
def minimum_spanning_tree(
imshapes,
edges,
edge2pred_i,
edge2pred_j,
edge2conf_i,
edge2conf_j,
im_conf,
min_conf_thr,
device,
has_im_poses=True,
niter_PnP=10,
verbose=True,
save_score_path=None,
):
n_imgs = len(imshapes)
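    # scipy's minimum_spanning_tree keeps the smallest total edge weight, so
    # the confidence scores are negated below to retain the strongest edges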
    edge_and_scores = compute_edge_scores(map(i_j_ij, edges), edge2conf_i, edge2conf_j)
    sparse_graph = -dict_to_sparse_graph(edge_and_scores)
msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()
# temp variable to store 3d points
pts3d = [None] * len(imshapes)
    todo = sorted(zip(-msp.data, msp.row, msp.col))  # ascending scores; pop() yields the strongest remaining edge
im_poses = [None] * n_imgs
im_focals = [None] * n_imgs
# init with strongest edge
score, i, j = todo.pop()
if verbose:
print(f" init edge ({i}*,{j}*) {score=}")
i_j = edge_str(i, j)
pts3d[i] = edge2pred_i[i_j].clone()
pts3d[j] = edge2pred_j[i_j].clone()
done = {i, j}
if has_im_poses:
im_poses[i] = torch.eye(4, device=device)
im_focals[i] = estimate_focal(edge2pred_i[i_j])
# set initial pointcloud based on pairwise graph
msp_edges = [(i, j)]
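    # Grow the reconstruction one MST edge at a time: an edge whose endpoint i
    # (resp. j) is already registered aligns the pairwise prediction onto the
    # global frame with a scaled rigid fit, then places the other view; edges
    # with neither endpoint registered yet are pushed back for later.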
while todo:
        # each time, process the strongest remaining edge
        score, i, j = todo.pop()
        i_j = edge_str(i, j)
        if im_focals[i] is None:
            im_focals[i] = estimate_focal(edge2pred_i[i_j])
if i in done:
if verbose:
print(f" init edge ({i},{j}*) {score=}")
assert j not in done
            # align pred[i] with pts3d[i], then place view j accordingly
s, R, T = rigid_points_registration(
edge2pred_i[i_j], pts3d[i], conf=edge2conf_i[i_j]
)
trf = sRT_to_4x4(s, R, T, device)
pts3d[j] = geotrf(trf, edge2pred_j[i_j])
done.add(j)
msp_edges.append((i, j))
if has_im_poses and im_poses[i] is None:
im_poses[i] = sRT_to_4x4(1, R, T, device)
elif j in done:
if verbose:
print(f" init edge ({i}*,{j}) {score=}")
assert i not in done
s, R, T = rigid_points_registration(
edge2pred_j[i_j], pts3d[j], conf=edge2conf_j[i_j]
)
trf = sRT_to_4x4(s, R, T, device)
pts3d[i] = geotrf(trf, edge2pred_i[i_j])
done.add(i)
msp_edges.append((i, j))
if has_im_poses and im_poses[i] is None:
im_poses[i] = sRT_to_4x4(1, R, T, device)
else:
# let's try again later
todo.insert(0, (score, i, j))
if has_im_poses:
        # complete all missing information
        pair_scores = list(sparse_graph.values())  # negated scores: lower is better
edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[
np.argsort(pair_scores)
]
for i, j in edges_from_best_to_worse.tolist():
if im_focals[i] is None:
im_focals[i] = estimate_focal(edge2pred_i[edge_str(i, j)])
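        # recover any still-missing camera poses via PnP against the
        # already-initialized 3D points, keeping only confident pixels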
for i in range(n_imgs):
if im_poses[i] is None:
msk = im_conf[i] > min_conf_thr
res = fast_pnp(
pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP
)
if res:
im_focals[i], im_poses[i] = res
if im_poses[i] is None:
im_poses[i] = torch.eye(4, device=device)
im_poses = torch.stack(im_poses)
else:
im_poses = im_focals = None
return pts3d, msp_edges, im_focals, im_poses
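
# Worked mini-example: with three images and edge scores (0,1)=9, (1,2)=7,
# (0,2)=4, the negated-score MST keeps (0,1) and (1,2). The strongest edge
# (0,1) seeds the reconstruction (image 0 at the identity pose), then edge
# (1,2) is rigidly aligned on the shared points of image 1 to place image 2.
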
def init_from_pts3d(self, pts3d, im_focals, im_poses):
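    """Write the MST initialization back into the optimizer: align to any
    known poses, then set pairwise poses, per-image poses, focals, and
    depthmaps as the starting point for global optimization.
    """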
# init poses
nkp, known_poses_msk, known_poses = self.get_known_poses()
if nkp == 1:
raise NotImplementedError(
"Would be simpler to just align everything afterwards on the single known pose"
)
elif nkp > 1:
# global rigid SE3 alignment
s, R, T = align_multiple_poses(
im_poses[known_poses_msk], known_poses[known_poses_msk]
)
trf = sRT_to_4x4(s, R, T, device=known_poses.device)
# rotate everything
im_poses = trf @ im_poses
im_poses[:, :3, :3] /= s # undo scaling on the rotation part
for img_pts3d in pts3d:
img_pts3d[:] = geotrf(trf, img_pts3d)
else:
pass # no known poses
# set all pairwise poses
for e, (i, j) in enumerate(self.edges):
i_j = edge_str(i, j)
# compute transform that goes from cam to world
s, R, T = rigid_points_registration(
self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j]
)
self._set_pose(self.pw_poses, e, R, T, scale=s)
# take into account the scale normalization
s_factor = self.get_pw_norm_scale_factor()
im_poses[:, :3, 3] *= s_factor # apply downscaling factor
for img_pts3d in pts3d:
img_pts3d *= s_factor
# init all image poses
if self.has_im_poses:
for i in range(self.n_imgs):
cam2world = im_poses[i]
depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
self._set_depthmap(i, depth)
self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None and not self.shared_focal:
                self._set_focal(i, im_focals[i])
        if self.shared_focal:
            self._set_focal(0, sum(im_focals) / self.n_imgs)
if self.n_imgs > 2:
self._set_init_depthmap()
if self.verbose:
with torch.no_grad():
print(" init loss =", float(self()))