|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from cloud_opt.dust3r_opt.base_opt import BasePCOptimizer |
|
|
from dust3r.utils.geometry import xy_grid, geotrf |
|
|
from dust3r.utils.device import to_cpu, to_numpy |
|
|
|
|
|
|
|
|
class PointCloudOptimizer(BasePCOptimizer):
    """Optimize a global scene, given a list of pairwise observations.

    Graph node: images
    Graph edges: observations = (pred1, pred2)

    Per image, the optimizer stores:

    * ``im_depthmaps`` -- a log-depth map of shape ``(H, W)``;
    * ``im_poses``     -- a pose vector of length ``POSE_DIM`` (layout is
      defined by the base class);
    * ``im_focals``    -- ``focal_break * log(focal)``;
    * ``im_pp``        -- principal-point offset from the image centre,
      divided by 10.
    """

    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
        """Create the per-image parameters and the stacked per-edge buffers.

        Args:
            *args, **kwargs: forwarded unchanged to ``BasePCOptimizer``.
            optimize_pp: whether the principal-point offsets require
                gradients.
            focal_break: scale applied to the log-focal parameterisation
                (stored value is ``focal_break * log(focal)``).
        """
        super().__init__(*args, **kwargs)

        self.has_im_poses = True
        self.focal_break = focal_break

        # Optimisable per-image quantities (depth is kept in log space).
        self.im_depthmaps = nn.ParameterList(
            torch.randn(H, W) / 10 - 3 for H, W in self.imshapes
        )
        self.im_poses = nn.ParameterList(
            self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)
        )
        self.im_focals = nn.ParameterList(
            torch.FloatTensor([self.focal_break * np.log(max(H, W))])
            for H, W in self.imshapes
        )
        self.im_pp = nn.ParameterList(
            torch.zeros((2,)) for _ in range(self.n_imgs)
        )
        self.im_pp.requires_grad_(optimize_pp)

        self.imshape = self.imshapes[0]
        im_areas = [h * w for h, w in self.imshapes]
        self.max_area = max(im_areas)

        # Poses, focals and principal points become single stacked
        # parameters; the depth maps stay a ParameterList of (H, W) maps.
        self.im_poses = ParameterStack(self.im_poses, is_param=True)
        self.im_focals = ParameterStack(self.im_focals, is_param=True)
        self.im_pp = ParameterStack(self.im_pp, is_param=True)
        self.register_buffer(
            "_pp", torch.tensor([(w / 2, h / 2) for h, w in self.imshapes])
        )
        self.register_buffer(
            "_grid",
            ParameterStack(
                [xy_grid(W, H, device=self.device) for H, W in self.imshapes],
                fill=self.max_area,
            ),
        )

        # Per-edge confidence weights, flattened and padded to max_area.
        self.register_buffer(
            "_weight_i",
            ParameterStack(
                [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges],
                fill=self.max_area,
            ),
        )
        self.register_buffer(
            "_weight_j",
            ParameterStack(
                [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges],
                fill=self.max_area,
            ),
        )

        # Per-edge pairwise predictions, flattened and padded to max_area.
        self.register_buffer(
            "_stacked_pred_i",
            ParameterStack(self.pred_i, self.str_edges, fill=self.max_area),
        )
        self.register_buffer(
            "_stacked_pred_j",
            ParameterStack(self.pred_j, self.str_edges, fill=self.max_area),
        )
        self.register_buffer("_ei", torch.tensor([i for i, j in self.edges]))
        self.register_buffer("_ej", torch.tensor([j for i, j in self.edges]))
        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
        self.total_area_j = sum([im_areas[j] for i, j in self.edges])

    def _check_all_imgs_are_selected(self, msk):
        """Assert that ``msk`` selects images ``0 .. n_imgs-1`` in order.

        Raises:
            AssertionError: if the selected indices differ from
                ``arange(n_imgs)``.
        """
        selected = np.asarray(self._get_msk_indices(msk))
        # array_equal is shape-safe: a length mismatch is simply "not equal"
        # instead of going through element-wise broadcasting.
        assert np.array_equal(
            selected, np.arange(self.n_imgs)
        ), "incomplete mask!"

    def preset_pose(self, known_poses, pose_msk=None):
        """Set known poses for all images and freeze them.

        Args:
            known_poses: iterable of poses, or a single 2-D tensor which is
                treated as one pose. Each pose is indexed as ``pose[:3, 3]``
                for logging, so it is presumably a 4x4 matrix -- confirm
                against ``_set_pose`` in the base class.
            pose_msk: mask / indices of the images concerned; must select
                every image.
        """
        self._check_all_imgs_are_selected(pose_msk)

        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
            known_poses = [known_poses]
        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            if self.verbose:
                print(f" (setting pose #{idx} = {pose[:3,3]})")
            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))

        self.im_poses.requires_grad_(False)
        if self.verbose:
            # Diagnostic dump; previously printed unconditionally.
            for p in self.im_poses:
                print(p.requires_grad)
                print(p.data)

        # The original computed `n_known_poses <= 1` here and then
        # overwrote it with False; only the effective value is kept.
        self.norm_pw_scale = False

    def preset_focal(self, known_focals, msk=None):
        """Set known focal lengths for all images and freeze them.

        Args:
            known_focals: iterable of focal lengths (linear, not log).
            msk: mask / indices of the images concerned; must select every
                image.
        """
        self._check_all_imgs_are_selected(msk)

        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            if self.verbose:
                print(f" (setting focal #{idx} = {focal})")
            self._no_grad(self._set_focal(idx, focal))

        self.im_focals.requires_grad_(False)

    def preset_principal_point(self, known_pp, msk=None):
        """Set known principal points for all images and freeze them.

        Args:
            known_pp: iterable of principal points, in the same ``(x, y)``
                order as ``(W / 2, H / 2)``.
            msk: mask / indices of the images concerned; must select every
                image.
        """
        self._check_all_imgs_are_selected(msk)

        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            if self.verbose:
                print(f" (setting principal point #{idx} = {pp})")
            self._no_grad(self._set_principal_point(idx, pp))

        self.im_pp.requires_grad_(False)

    def _get_msk_indices(self, msk):
        """Translate a mask specification into image indices.

        Accepted forms: ``None`` (all images), a single ``int``, a
        tuple/list (converted to a numpy array and re-dispatched), a boolean
        array/tensor of length ``n_imgs``, or a numpy integer array.

        Raises:
            ValueError: if ``msk`` matches none of the accepted forms.
        """
        if msk is None:
            return range(self.n_imgs)
        elif isinstance(msk, int):
            return [msk]
        elif isinstance(msk, (tuple, list)):
            return self._get_msk_indices(np.array(msk))
        elif msk.dtype in (bool, torch.bool, np.bool_):
            assert len(msk) == self.n_imgs
            return np.where(msk)[0]
        elif np.issubdtype(msk.dtype, np.integer):
            return msk
        else:
            raise ValueError(f"bad {msk=}")

    def _no_grad(self, tensor):
        """Assert that ``tensor`` still requires gradients.

        This does not change the tensor; it only checks that the preceding
        setter found a modifiable parameter.
        """
        assert (
            tensor.requires_grad
        ), "it must be True at this point, otherwise no modification occurs"

    def _set_focal(self, idx, focal, force=False):
        """Write ``focal`` for image ``idx`` in the log parameterisation.

        The value is only written when the parameter requires gradients or
        ``force`` is true. Returns the (possibly unchanged) parameter slice.
        """
        param = self.im_focals[idx]
        if param.requires_grad or force:
            param.data[:] = self.focal_break * np.log(focal)
        return param

    def get_focals(self):
        """Return the focal lengths, decoded from the log parameterisation."""
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_break).exp()

    def get_known_focal_mask(self):
        """Return a bool tensor, true where a focal does not require grad."""
        return torch.tensor([not (p.requires_grad) for p in self.im_focals])

    def _set_principal_point(self, idx, pp, force=False):
        """Write principal point ``pp`` for image ``idx``.

        Stored as the offset from the image centre ``(W / 2, H / 2)``
        divided by 10. Only written when the parameter requires gradients or
        ``force`` is true. Returns the (possibly unchanged) parameter slice.
        """
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if param.requires_grad or force:
            param.data[:] = to_cpu(to_numpy(pp) - (W / 2, H / 2)) / 10
        return param

    def get_principal_points(self):
        """Return principal points: image centre plus 10x the stored offset."""
        return self._pp + 10 * self.im_pp

    def get_intrinsics(self):
        """Return an ``(n_imgs, 3, 3)`` tensor of pinhole intrinsics.

        Both focal entries share the same value; the last row/column entry
        is 1 and the remaining off-diagonal entries are 0.
        """
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().flatten()
        K[:, 0, 0] = K[:, 1, 1] = focals
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):
        """Return the image poses decoded by the base class' ``_get_poses``."""
        cam2world = self._get_poses(self.im_poses)
        return cam2world

    def preset_depth(self, known_depths, msk=None):
        """Preset known depth maps for all images and freeze them.

        Args:
            known_depths: List of depth maps or a single 2-D depth map
                (torch tensor or numpy array), in normal depth space, not
                log space. Each map must hold exactly ``H * W`` values for
                its image.
            msk: Mask or indices indicating which images to preset; must
                select every image. If None, applies to all images.
        """
        self._check_all_imgs_are_selected(msk)

        if isinstance(known_depths, (torch.Tensor, np.ndarray)) and known_depths.ndim == 2:
            known_depths = [known_depths]

        for idx, depth in zip(self._get_msk_indices(msk), known_depths):
            if self.verbose:
                print(f" (setting depth #{idx})")

            # _set_depthmap handles tensor conversion and reshaping. The
            # earlier pad-to-max_area + view(H, W) here failed for any image
            # smaller than the largest one, and for numpy inputs.
            self._no_grad(self._set_depthmap(idx, depth))
            self.im_depthmaps[idx].requires_grad_(False)

    def _set_depthmap(self, idx, depth, force=False):
        """Set a depth map for an image.

        Args:
            idx: Image index
            depth: Depth map in normal space (not log space), torch tensor
                or numpy array with ``H * W`` values for image ``idx``.
            force: Whether to write even if the parameter does not require
                gradients.

        Returns:
            The (possibly unchanged) log-depth parameter of image ``idx``.
        """
        # No zero-padding here: the parameter itself is (H, W), so the input
        # only has to be reshaped to the image shape.
        depth = torch.as_tensor(depth).detach().reshape(self.imshapes[idx])
        depth = depth.nan_to_num(neginf=0)
        param = self.im_depthmaps[idx]
        if param.requires_grad or force:
            # log(0) = -inf and log(<0) = nan are both mapped to 0.
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self, raw=False):
        """Return the depth maps (exponential of the stored log-depths).

        Args:
            raw: if true, return the stacked tensor; otherwise a list with
                one ``(h, w)`` view per image.

        NOTE(review): ``ParameterStack`` detaches the stacked tensor, so the
        result is not connected to ``self.im_depthmaps`` in the autograd
        graph -- confirm this is intended when depths are not preset.
        """
        res = ParameterStack(self.im_depthmaps, is_param=False).exp()
        if not raw:
            res = [dm[: h * w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def depth_to_pts3d(self):
        """Back-project all depth maps and map them through the image poses."""
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps(raw=True)

        # Camera-frame points from depth, pixel grid and intrinsics.
        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)

        # presumably applies the cam-to-world transform -- see
        # dust3r.utils.geometry.geotrf
        return geotrf(im_poses, rel_ptmaps)

    def get_pts3d(self, raw=False):
        """Return the 3-D points of every image.

        Args:
            raw: if true, return the stacked flattened tensor; otherwise a
                list with one ``(h, w, 3)`` view per image.
        """
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[: h * w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def forward(self):
        """Return the alignment loss summed over both sides of every edge.

        Each side is the weighted distance between the images' projected
        points and the adapted, pose-transformed pairwise predictions,
        summed and divided by the total pixel area of that side.
        """
        pw_poses = self.get_pw_poses()
        pw_adapt = self.get_adaptors().unsqueeze(1)
        proj_pts3d = self.get_pts3d(raw=True)

        # Bring the pairwise predictions into the common frame.
        aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
        aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)

        li = (
            self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum()
            / self.total_area_i
        )
        lj = (
            self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum()
            / self.total_area_j
        )

        return li + lj
|
|
|
|
|
|
|
|
def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
    """Back-project a batch of depth maps to camera-frame 3-D points.

    Args:
        depth: ``(B, N)`` depths, or ``(B, H, W)`` which is flattened to
            ``(B, H * W)``.
        pixel_grid: ``(B, N, 2)`` pixel coordinates.
        focal: ``(B, 1)`` focal lengths.
        pp: ``(B, 2)`` principal points.

    Returns:
        ``(B, N, 3)`` tensor: ``depth * (pixel - pp) / focal`` for the first
        two channels and ``depth`` for the third.
    """
    principal = pp.unsqueeze(1)
    focal_len = focal.unsqueeze(1)
    flat_depth = depth.view(depth.shape[0], -1) if depth.ndim == 3 else depth

    batch = len(flat_depth)
    assert focal_len.shape == (batch, 1, 1)
    assert principal.shape == (batch, 1, 2)
    assert pixel_grid.shape == flat_depth.shape + (2,)

    z = flat_depth.unsqueeze(-1)
    xy = z * (pixel_grid - principal) / focal_len
    return torch.cat((xy, z), dim=-1)
|
|
|
|
|
|
|
|
def ParameterStack(params, keys=None, is_param=None, fill=0):
    """Stack a collection of tensors into one float tensor or parameter.

    Args:
        params: sequence of tensors, or a mapping when ``keys`` is given.
        keys: optional keys/indices selecting (and ordering) the entries of
            ``params`` to stack.
        is_param: if truthy, always return an ``nn.Parameter`` and assert
            that all entries share the same ``requires_grad`` flag.
        fill: if positive, each entry is first flattened and zero-padded to
            this length with ``_ravel_hw``.

    Returns:
        The stacked, detached float tensor. It is wrapped in an
        ``nn.Parameter`` (carrying the first entry's ``requires_grad``) when
        ``is_param`` is truthy or the first entry requires gradients.
    """
    items = params if keys is None else [params[key] for key in keys]

    if fill > 0:
        items = [_ravel_hw(item, fill) for item in items]

    requires_grad = items[0].requires_grad
    if is_param:
        assert all(item.requires_grad == requires_grad for item in items)

    stacked = torch.stack(list(items)).float().detach()
    if not (is_param or requires_grad):
        return stacked
    return nn.Parameter(stacked, requires_grad=requires_grad)
|
|
|
|
|
|
|
|
def _ravel_hw(tensor, fill=0):
    """Merge the first two dimensions of ``tensor`` and zero-pad the result.

    Args:
        tensor: tensor with at least two dimensions ``(H, W, ...)``.
        fill: minimum length of the merged first dimension; shorter results
            are padded with zeros at the end.

    Returns:
        Tensor of shape ``(max(H * W, fill), ...)``.
    """
    shape = tensor.shape
    n_pixels = shape[0] * shape[1]
    flat = tensor.view(n_pixels, *shape[2:])

    missing = fill - len(flat)
    if missing > 0:
        padding = flat.new_zeros((missing, *flat.shape[1:]))
        flat = torch.cat((flat, padding))
    return flat
|
|
|
|
|
|
|
|
def acceptable_focal_range(H, W, minf=0.5, maxf=3.5, fov_deg=60):
    """Return the ``(min, max)`` focal lengths considered acceptable.

    The range is expressed relative to a base focal: the focal length at
    which the longer image side spans a field of view of ``fov_deg``
    degrees.

    Args:
        H: image height in pixels.
        W: image width in pixels.
        minf: lower bound, as a multiple of the base focal.
        maxf: upper bound, as a multiple of the base focal.
        fov_deg: reference field of view in degrees (previously hard-coded
            to 60, which remains the default).

    Returns:
        Tuple ``(minf * focal_base, maxf * focal_base)``.
    """
    focal_base = max(H, W) / (
        2 * np.tan(np.deg2rad(fov_deg) / 2)
    )
    return minf * focal_base, maxf * focal_base
|
|
|
|
|
|
|
|
def apply_mask(img, msk):
    """Return a copy of ``img`` with the entries selected by ``msk`` zeroed.

    The input image is left untouched.
    """
    masked = img.copy()
    masked[msk] = 0
    return masked
|
|
|