| | |
| | |
| | |
| | |
| | |
| | |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from dust3r.cloud_opt.base_opt import BasePCOptimizer |
| | from dust3r.utils.geometry import geotrf |
| | from dust3r.utils.device import to_cpu, to_numpy |
| | from dust3r.utils.geometry import depthmap_to_pts3d |
| |
|
| |
|
class ModularPointCloudOptimizer(BasePCOptimizer):
    """Optimize a global scene, given a list of pairwise observations.

    Unlike PointCloudOptimizer, you can fix parts of the optimization process
    (partial poses/intrinsics).

    Graph node: images
    Graph edges: observations = (pred1, pred2)

    Per-image parameters are stored in a re-parameterized form:

    * ``im_depthmaps``: log-depth (``get_depthmaps`` applies ``exp``);
    * ``im_focals``: ``focal_brake * log(focal)`` (see ``get_focals``);
    * ``im_pp``: principal point offset from the image center, divided by 10
      (see ``get_principal_points``);
    * ``im_poses``: vectors of size ``POSE_DIM``, decoded by ``_get_poses``.
    """

    def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs):
        """Create the per-image parameters.

        Args:
            *args, **kwargs: forwarded unchanged to ``BasePCOptimizer``.
            optimize_pp: whether the principal-point parameters require grad.
            fx_and_fy: store two focal values per image instead of one.
            focal_brake: scale factor applied to the log-focal parameters.
        """
        super().__init__(*args, **kwargs)
        self.has_im_poses = True
        self.focal_brake = focal_brake

        # Log-depth maps, one (H, W) parameter per image.
        self.im_depthmaps = nn.ParameterList(torch.randn(H, W) / 10 - 3 for H, W in self.imshapes)
        # One pose vector per image.
        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))
        # Default focal is max(H, W), stored as focal_brake * log(focal).
        default_focals = [self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes]
        self.im_focals = nn.ParameterList(
            torch.FloatTensor([f, f] if fx_and_fy else [f]) for f in default_focals)
        # Zero offset == principal point at the image center.
        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))
        self.im_pp.requires_grad_(optimize_pp)

    def preset_pose(self, known_poses, pose_msk=None):
        """Set and freeze the poses of the images selected by ``pose_msk``.

        Args:
            known_poses: a single 2-D pose matrix (torch tensor or numpy
                array) or an iterable of such matrices. Poses are paired with
                the selected indices via ``zip``, so extra entries on either
                side are silently ignored.
            pose_msk: image selector, see ``_get_msk_indices``.
        """
        # A lone 2-D matrix (tensor or ndarray) is treated as a one-element list.
        if isinstance(known_poses, (torch.Tensor, np.ndarray)) and known_poses.ndim == 2:
            known_poses = [known_poses]
        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            if self.verbose:
                print(f' (setting pose #{idx} = {pose[:3,3]})')
            # Copy the pose; detach().clone() avoids the copy-construct warning
            # that torch.tensor() emits when handed an existing tensor.
            if isinstance(pose, torch.Tensor):
                pose = pose.detach().clone()
            else:
                pose = torch.tensor(pose)
            self._no_grad(self._set_pose(self.im_poses, idx, pose, force=True))

        # Pairwise-scale normalization is only kept when at most one pose is fixed.
        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
        self.norm_pw_scale = (n_known_poses <= 1)

    def preset_intrinsics(self, known_intrinsics, msk=None):
        """Set and freeze focal and principal point from 3x3 intrinsic matrices.

        The focal is taken as the mean of the first two diagonal entries and
        the principal point as ``K[:2, 2]``.

        Args:
            known_intrinsics: a single 3x3 matrix (torch tensor or numpy
                array) or an iterable of 3x3 matrices.
            msk: image selector, see ``_get_msk_indices``.
        """
        if isinstance(known_intrinsics, (torch.Tensor, np.ndarray)) and known_intrinsics.ndim == 2:
            known_intrinsics = [known_intrinsics]
        # Materialize once: the collection is traversed several times below,
        # which would exhaust a generator.
        known_intrinsics = list(known_intrinsics)
        for K in known_intrinsics:
            assert K.shape == (3, 3)
        self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk)
        self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk)

    def preset_focal(self, known_focals, msk=None):
        """Set and freeze the focal of the images selected by ``msk``.

        Focals are paired with the selected indices via ``zip``.
        """
        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            if self.verbose:
                print(f' (setting focal #{idx} = {focal})')
            self._no_grad(self._set_focal(idx, focal, force=True))

    def preset_principal_point(self, known_pp, msk=None):
        """Set and freeze the principal point of the images selected by ``msk``.

        Principal points are paired with the selected indices via ``zip``.
        """
        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            if self.verbose:
                print(f' (setting principal point #{idx} = {pp})')
            self._no_grad(self._set_principal_point(idx, pp, force=True))

    def _no_grad(self, tensor):
        """Disable gradient tracking on ``tensor`` in place and return it."""
        return tensor.requires_grad_(False)

    def _get_msk_indices(self, msk):
        """Translate an image selector into an iterable of image indices.

        Args:
            msk: ``None`` (all images), a single integer index, a tuple/list,
                or a numpy array / torch tensor that is either a boolean mask
                of length ``n_imgs`` or an array of integer indices.

        Raises:
            ValueError: if the selector has an unsupported dtype.
        """
        if msk is None:
            return range(self.n_imgs)
        if isinstance(msk, (int, np.integer)):
            return [msk]
        if isinstance(msk, (tuple, list)):
            return self._get_msk_indices(np.array(msk))
        if isinstance(msk, torch.Tensor):
            # NumPy cannot interpret torch dtypes, nor read non-CPU tensors.
            msk = msk.detach().cpu().numpy()
        if msk.dtype in (bool, torch.bool, np.bool_):
            assert len(msk) == self.n_imgs
            return np.where(msk)[0]
        if np.issubdtype(msk.dtype, np.integer):
            return msk
        raise ValueError(f'bad {msk=}')

    def _set_focal(self, idx, focal, force=False):
        """Write ``focal`` into the focal parameter of image ``idx``.

        The value is stored as ``focal_brake * log(focal)``. Nothing is
        written when the parameter is frozen, unless ``force`` is true.

        Args:
            idx: image index.
            focal: scalar or array-like focal value(s); torch tensors on any
                device are accepted.
            force: write even if the parameter does not require grad.

        Returns:
            The focal parameter of image ``idx``.
        """
        param = self.im_focals[idx]
        if param.requires_grad or force:
            if isinstance(focal, torch.Tensor):
                # Bring the value to host memory so that np.log can process it.
                focal = focal.detach().cpu().numpy()
            log_focal = self.focal_brake * np.log(np.asarray(focal))
            param.data[:] = torch.as_tensor(log_focal, dtype=param.dtype, device=param.device)
        return param

    def get_focals(self):
        """Return the focals of all images, decoded from their log form."""
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_brake).exp()

    def _set_principal_point(self, idx, pp, force=False):
        """Write principal point ``pp`` into the parameter of image ``idx``.

        The value is stored as ``(pp - (W/2, H/2)) / 10``. Nothing is written
        when the parameter is frozen, unless ``force`` is true.

        Returns:
            The principal-point parameter of image ``idx``.
        """
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if param.requires_grad or force:
            param.data[:] = to_cpu(to_numpy(pp) - (W / 2, H / 2)) / 10
        return param

    def get_principal_points(self):
        """Return the stacked principal points: image center + 10 * offset."""
        return torch.stack([pp.new_tensor((W / 2, H / 2)) + 10 * pp
                            for pp, (H, W) in zip(self.im_pp, self.imshapes)])

    def get_intrinsics(self):
        """Return an ``(n_imgs, 3, 3)`` tensor of intrinsic matrices."""
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().view(self.n_imgs, -1)
        # With a single focal per image, columns 0 and -1 are the same value.
        K[:, 0, 0] = focals[:, 0]
        K[:, 1, 1] = focals[:, -1]
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):
        """Return the poses decoded by ``_get_poses`` from the stacked pose parameters."""
        cam2world = self._get_poses(torch.stack(list(self.im_poses)))
        return cam2world

    def _set_depthmap(self, idx, depth, force=False):
        """Write ``depth`` into the depth parameter of image ``idx``.

        The value is stored as ``log(depth)``; ``-inf`` entries (zero depth)
        are replaced by 0. Nothing is written when the parameter is frozen,
        unless ``force`` is true.

        Returns:
            The depth parameter of image ``idx``.
        """
        param = self.im_depthmaps[idx]
        if param.requires_grad or force:
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self):
        """Return the list of depth maps, decoded from their log form."""
        return [d.exp() for d in self.im_depthmaps]

    def focal_ex(self, focals, i):
        """Expand ``focals[i]`` to a constant field of shape ``(1, *focals[i].shape, H, W)``."""
        return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i])

    def depth_to_pts3d(self):
        """Return one pointmap per image, built from the current depths, intrinsics and poses."""
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps()

        # Per-image pointmaps from depth, focal field and principal point.
        rel_ptmaps = [depthmap_to_pts3d(depth[i][None], self.focal_ex(focals, i), pp=pp[i:i+1])[0]
                      for i in range(im_poses.shape[0])]
        # Apply each image's pose to its pointmap.
        return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)]

    def get_pts3d(self):
        """Return the pointmaps computed by ``depth_to_pts3d``."""
        return self.depth_to_pts3d()
| |
|