| | |
| | |
| | |
| | |
| | |
| | |
| | from copy import copy, deepcopy |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from dust3r.inference import get_pred_pts3d, find_opt_scaling |
| | from dust3r.utils.geometry import inv, geotrf, normalize_pointcloud |
| | from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale |
| |
|
| |
|
def Sum(*losses_and_masks):
    """Combine several ``(loss, mask)`` pairs.

    Args:
        *losses_and_masks: one or more ``(loss, mask)`` 2-tuples.

    Returns:
        The input tuple of pairs, unchanged, when the first loss has at
        least one dimension (losses are kept separate); otherwise the sum
        of all the scalar losses (masks are ignored).
    """
    first_loss, _ = losses_and_masks[0]

    # Non-scalar losses: nothing is reduced, hand every pair back to the caller.
    if first_loss.ndim > 0:
        return losses_and_masks

    # Scalar losses: accumulate them into one global value.
    total = first_loss
    for extra_loss, _ in losses_and_masks[1:]:
        total = total + extra_loss
    return total
| |
|
| |
|
class BaseCriterion(nn.Module):
    """Base class for criteria: a module that only stores a reduction mode."""

    def __init__(self, reduction: str = 'mean'):
        """Store ``reduction`` on the instance; no validation is done here."""
        super().__init__()
        self.reduction = reduction
| |
|
| |
|
class LLoss(BaseCriterion):
    """ L-norm loss

    Subclasses provide ``distance(a, b)``, which must drop the last
    dimension; ``forward`` then applies ``self.reduction`` to the result.
    """

    def forward(self, a, b):
        """Return the reduced distance between ``a`` and ``b``.

        Both inputs must have the same shape, at least 2 dimensions and a
        last dimension of size 1, 2 or 3.

        Raises:
            ValueError: if ``self.reduction`` is not 'none', 'sum' or 'mean'.
        """
        assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}'
        dist = self.distance(a, b)
        assert dist.ndim == a.ndim - 1

        mode = self.reduction
        if mode == 'none':
            return dist
        if mode == 'sum':
            return dist.sum()
        if mode == 'mean':
            # the mean of an empty tensor is NaN, so return a scalar zero instead
            if dist.numel() > 0:
                return dist.mean()
            return dist.new_zeros(())
        raise ValueError(f'bad {self.reduction=} mode')

    def distance(self, a, b):
        """Distance along the last dimension; to be implemented by subclasses."""
        raise NotImplementedError()
| |
|
| |
|
class L21Loss(LLoss):
    """ Euclidean distance between 3d points """

    def distance(self, a, b):
        """Return the 2-norm of ``a - b`` taken over the last dimension."""
        diff = a - b
        return diff.norm(dim=-1)
| |
|
| |
|
# Shared module-level instance; 'mean' is the reduction the constructor defaults to.
L21 = L21Loss(reduction='mean')
| |
|
| |
|
class Criterion(nn.Module):
    """Module wrapping a ``BaseCriterion`` (stored as ``self.criterion``)."""

    def __init__(self, criterion=None):
        """Keep a shallow copy of ``criterion``.

        ``criterion`` must be a ``BaseCriterion`` instance (the default
        ``None`` is rejected by the assertion).
        """
        super().__init__()
        assert isinstance(criterion, BaseCriterion), f'{criterion} is not a proper criterion!'
        self.criterion = copy(criterion)

    def get_name(self):
        """Return ``'<ClassName>(<criterion>)'``."""
        return f'{type(self).__name__}({self.criterion})'

    def with_reduction(self, mode='none'):
        """Return a deep copy of this loss whose criteria use reduction ``mode``.

        The copy is walked through its ``_loss2`` links, if any, and every
        loss found has ``criterion.reduction`` set to ``mode``. ``self`` is
        left untouched.
        """
        res = loss = deepcopy(self)
        while loss is not None:
            assert isinstance(loss, Criterion)
            loss.criterion.reduction = mode
            # `_loss2` is not defined by Criterion itself, so a standalone
            # Criterion has no such attribute: treat that as the end of the
            # chain instead of raising AttributeError.
            loss = getattr(loss, '_loss2', None)
        return res
| |
|
| |
|
class MultiLoss(nn.Module):
    """ Easily combinable losses (also keep track of individual loss values):
        loss = MyLoss1() + 0.1*MyLoss2()
    Usage:
        Inherit from this class and override get_name() and compute_loss()

    Combined losses form a singly linked list: ``_alpha`` is the weight of
    this term and ``_loss2`` is the next term (or None).
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1
        self._loss2 = None

    def compute_loss(self, *args, **kwargs):
        """Return either a loss, or a ``(loss, details_dict)`` tuple."""
        raise NotImplementedError()

    def get_name(self):
        """Return the display name of this term."""
        raise NotImplementedError()

    def __mul__(self, alpha):
        """Return a shallow copy of this term weighted by ``alpha``.

        NOTE: the weight is replaced, not multiplied with the current one,
        and only this term's ``_alpha`` is changed.
        """
        assert isinstance(alpha, (int, float))
        scaled = copy(self)
        scaled._alpha = alpha
        return scaled
    __rmul__ = __mul__

    def __add__(self, loss2):
        """Return a shallow copy of this term with ``loss2`` appended to the chain.

        NOTE: only the head is copied; when the chain already has several
        terms, ``loss2`` is attached to the existing last term object.
        """
        assert isinstance(loss2, MultiLoss)
        head = copy(self)
        # walk to the last term of the chain
        tail = head
        while tail._loss2 is not None:
            tail = tail._loss2
        tail._loss2 = loss2
        return head

    def __repr__(self):
        text = self.get_name()
        if self._alpha != 1:
            text = f'{self._alpha:g}*{text}'
        return f'{text} + {self._loss2}' if self._loss2 else text

    def forward(self, *args, **kwargs):
        """Evaluate the whole chain.

        Returns:
            ``(loss, details)``: the weighted sum of every term and the
            merged per-term details dictionary.
        """
        out = self.compute_loss(*args, **kwargs)
        if isinstance(out, tuple):
            loss, details = out
        else:
            loss = out
            # only scalar losses can be logged as a single float
            details = {self.get_name(): float(loss)} if loss.ndim == 0 else {}
        loss = loss * self._alpha

        if self._loss2:
            next_loss, next_details = self._loss2(*args, **kwargs)
            loss = loss + next_loss
            details |= next_details

        return loss, details
| |
|
| |
|
class Regr3D(Criterion, MultiLoss):
    """ Ensure that all 3D points are correct.
        Asymmetric loss: view1 is supposed to be the anchor.

        P1 = RT1 @ D1
        P2 = RT2 @ D2
        loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1)
        loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2)
              = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2)
    """

    def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False):
        """
        Args:
            criterion: passed to ``Criterion.__init__``.
            norm_mode: mode forwarded to ``normalize_pointcloud``; a falsy
                value disables normalization altogether.
            gt_scale: when true, the ground-truth points are not normalized.
        """
        super().__init__(criterion)
        self.norm_mode = norm_mode
        self.gt_scale = gt_scale

    def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None):
        """Gather ground-truth and predicted points plus their validity masks.

        Args:
            gt1, gt2: mappings read for 'camera_pose' (gt1 only), 'pts3d'
                and 'valid_mask'.
            pred1, pred2: predictions handed to ``get_pred_pts3d``.
            dist_clip: optional threshold; ground-truth points whose norm
                exceeds it are marked invalid.

        Returns:
            ``(gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, monitoring)``
            where ``monitoring`` is an empty dict.
        """
        # both ground-truth point sets are mapped with the inverse pose of view 1
        cam1_inv = inv(gt1['camera_pose'])
        gt_pts1 = geotrf(cam1_inv, gt1['pts3d'])
        gt_pts2 = geotrf(cam1_inv, gt2['pts3d'])

        # cloned so the masks stored in gt1/gt2 are never modified
        valid1 = gt1['valid_mask'].clone()
        valid2 = gt2['valid_mask'].clone()
        if dist_clip is not None:
            # points that are too far-away == invalid
            valid1 = valid1 & (gt_pts1.norm(dim=-1) <= dist_clip)
            valid2 = valid2 & (gt_pts2.norm(dim=-1) <= dist_clip)

        pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False)
        pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True)

        # normalize the predictions, and the ground truth too unless gt_scale is set
        if self.norm_mode:
            pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2)
            if not self.gt_scale:
                gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2)

        return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, {}

    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
        """Apply the criterion to the valid points of each view.

        Returns:
            ``(Sum((l1, mask1), (l2, mask2)), details)`` where ``details``
            holds the mean of each view's loss and the monitoring values.
        """
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \
            self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw)

        # loss on view 1, then on view 2 (masked points only)
        l1 = self.criterion(pred_pts1[mask1], gt_pts1[mask1])
        l2 = self.criterion(pred_pts2[mask2], gt_pts2[mask2])

        prefix = type(self).__name__
        details = {
            prefix + '_pts3d_1': float(l1.mean()),
            prefix + '_pts3d_2': float(l2.mean()),
        }
        return Sum((l1, mask1), (l2, mask2)), (details | monitoring)
| |
|
| |
|
class ConfLoss(MultiLoss):
    """ Weighted regression by learned confidence.
        Assuming the input pixel_loss is a pixel-level regression loss.

    Principle (x is the pixel loss):
        conf_loss = x * conf - alpha * log(conf)
        conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
        conf = 10  ==> conf_loss = x * 10 - alpha*log(10)

        alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        """
        Args:
            pixel_loss: loss exposing ``with_reduction``; a copy with
                reduction 'none' is stored.
            alpha: strictly positive weight of the log-confidence term.
        """
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        self.pixel_loss = pixel_loss.with_reduction('none')

    def get_name(self):
        return f'ConfLoss({self.pixel_loss})'

    def get_conf_log(self, x):
        """Return the confidence and its logarithm."""
        return x, torch.log(x)

    @staticmethod
    def _print_forced(message):
        """Print ``message``, asking for a forced print when that is supported.

        ``force=True`` is not a keyword of the builtin ``print``; it is only
        accepted when ``print`` has been replaced by a wrapper that takes
        it. Fall back to a plain print otherwise, so that a view without
        valid points yields a warning and not a TypeError.
        """
        try:
            print(message, force=True)
        except TypeError:
            print(message)

    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
        """Weight the per-pixel losses by the predicted confidences.

        Returns:
            ``(loss, details)``: the sum of the two per-view mean
            confidence-weighted losses, and the details of the pixel loss
            extended with both per-view values.
        """
        # compute per-pixel loss
        ((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
        if loss1.numel() == 0:
            self._print_forced('NO VALID POINTS in img1')
        if loss2.numel() == 0:
            self._print_forced('NO VALID POINTS in img2')

        # weight by confidence
        conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1])
        conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2])
        conf_loss1 = loss1 * conf1 - self.alpha * log_conf1
        conf_loss2 = loss2 * conf2 - self.alpha * log_conf2

        # average + nan protection (in case of no valid pixels at all)
        conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0
        conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0

        # the two key spellings differ ('conf_loss_1' / 'conf_loss2'); they are
        # kept as-is because consumers of the details dict may rely on them
        return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details)
| |
|
| |
|
class Regr3D_ShiftInv(Regr3D):
    """ Same than Regr3D but invariant to depth shift.
    """

    def get_all_pts3d(self, gt1, gt2, pred1, pred2, **kw):
        """Return the parent's points with a joint z offset removed.

        Args:
            gt1, gt2, pred1, pred2: forwarded to the parent implementation.
            **kw: extra keyword arguments (e.g. ``dist_clip``) forwarded to
                the parent, so the keywords passed by ``compute_loss`` are
                accepted here as well.

        Returns:
            Same 7-tuple as the parent; the z channel of the four point
            sets has been shifted in place.
        """
        # compute unnormalized points
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \
            super().get_all_pts3d(gt1, gt2, pred1, pred2, **kw)

        # z channel of every point set (these are views into the point tensors)
        gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2]
        pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2]

        # joint depth offset over both views, given two trailing axes so it
        # broadcasts against the z maps
        # (presumably one value per batch element - confirm against helper)
        gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None]

        # subtract the offset; in-place on the views, so the point tensors
        # returned below are updated too
        gt_z1 -= gt_shift_z
        gt_z2 -= gt_shift_z
        pred_z1 -= pred_shift_z
        pred_z2 -= pred_shift_z

        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring
| |
|
| |
|
class Regr3D_ScaleInv(Regr3D):
    """ Same than Regr3D but invariant to scale.
        if gt_scale == True: enforce the prediction to take the same scale than GT
    """

    def get_all_pts3d(self, gt1, gt2, pred1, pred2, **kw):
        """Return the parent's points after rescaling.

        Args:
            gt1, gt2, pred1, pred2: forwarded to the parent implementation.
            **kw: extra keyword arguments (e.g. ``dist_clip``) forwarded to
                the parent, so the keywords passed by ``compute_loss`` are
                accepted here as well.

        Returns:
            Same 7-tuple as the parent, with the point sets rescaled in place.
        """
        # compute depth-normalized points
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \
            super().get_all_pts3d(gt1, gt2, pred1, pred2, **kw)

        # joint scale of each pair of point sets (the returned centers are unused)
        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2)

        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        if self.gt_scale:
            # bring the prediction to the ground-truth scale
            pred_pts1 *= gt_scale / pred_scale
            pred_pts2 *= gt_scale / pred_scale
        else:
            # normalize ground truth and prediction by their own scale
            gt_pts1 /= gt_scale
            gt_pts2 /= gt_scale
            pred_pts1 /= pred_scale
            pred_pts2 /= pred_scale

        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring
| |
|
| |
|
class Regr3D_ScaleShiftInv(Regr3D_ScaleInv, Regr3D_ShiftInv):
    """ Combination of Regr3D_ScaleInv and Regr3D_ShiftInv.

    Per the MRO, Regr3D_ScaleInv.get_all_pts3d runs first and its super()
    call resolves to Regr3D_ShiftInv.get_all_pts3d, so the shift correction
    is applied before the scale correction.
    """
| |
|