Spaces:
Configuration error
Configuration error
| import torch | |
| import torch.nn as nn | |
| from copy import copy, deepcopy | |
| from stream3r.dust3r.utils.misc import invalid_to_zeros, invalid_to_nans | |
| from stream3r.dust3r.utils.geometry import inv, geotrf, depthmap_to_pts3d | |
| from stream3r.dust3r.utils.camera import pose_encoding_to_camera | |
class BaseCriterion(nn.Module):
    """Base class for criteria that carry a reduction mode.

    The class only stores ``reduction``; how the mode is applied is left to
    the ``forward`` of each subclass.
    """

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        # Plain attribute: it is reassigned after construction by
        # Criterion.with_reduction().
        self.reduction = reduction
class Criterion(nn.Module):
    """Loss wrapper that owns a private copy of a ``BaseCriterion``."""

    def __init__(self, criterion=None):
        """Store a shallow copy of ``criterion``.

        Args:
            criterion: a ``BaseCriterion`` instance; anything else trips the
                assertion below.
        """
        # Cooperative super() call: under multiple inheritance this runs the
        # next initializer in the MRO, not necessarily nn.Module's.
        super().__init__()
        assert isinstance(
            criterion,
            BaseCriterion), f"{criterion} is not a proper criterion!"
        self.criterion = copy(criterion)

    def get_name(self):
        """Return ``"<ClassName>(<criterion>)"``."""
        return f"{type(self).__name__}({self.criterion})"

    def with_reduction(self, mode="none"):
        """Return a deep copy of this loss whose criteria use ``mode``.

        The copy is walked along its ``_loss2`` links and every node's
        criterion gets ``reduction = mode`` (``"none"`` makes it return the
        loss for each sample). ``self`` is left untouched.

        Args:
            mode: reduction mode written into each criterion.

        Returns:
            The modified deep copy.
        """
        res = loss = deepcopy(self)
        while loss is not None:
            assert isinstance(loss, Criterion)
            loss.criterion.reduction = mode
            # `_loss2` is only defined by MultiLoss; a Criterion that is not
            # also a MultiLoss has no chain, so treat the link as absent
            # instead of raising AttributeError.
            loss = getattr(loss, "_loss2", None)
        return res
class MultiLoss(nn.Module):
    """Easily combinable losses (also keeps track of individual loss values).

        loss = MyLoss1() + 0.1 * MyLoss2()

    Usage:
        Inherit from this class and override get_name() and compute_loss().

    Combined losses form a singly linked chain through ``_loss2``; every node
    carries its own weight ``_alpha``.
    """

    def __init__(self):
        super().__init__()
        self._alpha = 1
        self._loss2 = None

    def compute_loss(self, *args, **kwargs):
        """Return the loss, or a ``(loss, details)`` tuple. Must be overridden."""
        raise NotImplementedError()

    def get_name(self):
        """Return the display name of this loss. Must be overridden."""
        raise NotImplementedError()

    def _shallow_clone(self):
        """Return a shallow copy that owns its own submodule registry.

        ``copy()`` alone leaves ``_modules`` shared between the clone and the
        original, so attaching a ``_loss2`` submodule to the clone would also
        register it on the original.
        """
        dup = copy(self)
        dup._modules = dup._modules.copy()
        return dup

    def __mul__(self, alpha):
        """Return a copy of this loss whose weight is set to ``alpha``.

        The weight is replaced, not multiplied with a previous one, and only
        the head node of a chain is affected.
        """
        assert isinstance(alpha, (int, float))
        res = self._shallow_clone()
        res._alpha = alpha
        return res

    __rmul__ = __mul__  # same

    def __add__(self, loss2):
        """Return a new chain made of this loss's chain followed by ``loss2``.

        Every node of the existing chain is cloned before the new tail is
        attached, so neither ``self`` nor the losses it already chains to are
        modified.
        """
        assert isinstance(loss2, MultiLoss)
        res = cur = self._shallow_clone()
        while cur._loss2 is not None:
            nxt = cur._loss2._shallow_clone()
            cur._loss2 = nxt
            cur = nxt
        cur._loss2 = loss2
        return res

    def __repr__(self):
        name = self.get_name()
        if self._alpha != 1:
            name = f"{self._alpha:g}*{name}"
        if self._loss2 is not None:
            name = f"{name} + {self._loss2}"
        return name

    def forward(self, *args, **kwargs):
        """Evaluate this loss and everything chained after it.

        Returns:
            ``(loss, details)`` where ``loss`` is the weighted sum over the
            chain and ``details`` merges the per-loss dictionaries. A loss
            that returns a 0-dim tensor is logged under its name; a
            non-scalar loss contributes no details of its own.
        """
        loss = self.compute_loss(*args, **kwargs)
        if isinstance(loss, tuple):
            loss, details = loss
        elif loss.ndim == 0:
            details = {self.get_name(): float(loss)}
        else:
            details = {}
        loss = loss * self._alpha

        if self._loss2 is not None:
            loss2, details2 = self._loss2(*args, **kwargs)
            loss = loss + loss2
            # Build a new dict so the one handed back by compute_loss() is
            # not modified behind its owner's back.
            details = details | details2

        return loss, details
class LLoss(BaseCriterion):
    """L-norm loss.

    Subclasses supply ``distance``; this class validates the inputs and
    applies the stored reduction to the resulting distances.
    """

    def forward(self, a, b):
        """Return the reduced distance between ``a`` and ``b``.

        Both tensors must have the same shape, at least two dimensions and a
        last dimension of size 1 to 3.

        Raises:
            ValueError: if ``self.reduction`` is not "none", "sum" or "mean".
        """
        assert (a.shape == b.shape and a.ndim >= 2
                and 1 <= a.shape[-1] <= 3), f"Bad shape = {a.shape}"

        per_item = self.distance(a, b)
        mode = self.reduction

        if mode == "mean":
            # An empty input yields a zero scalar rather than the NaN that
            # mean() would produce.
            if per_item.numel() > 0:
                return per_item.mean()
            return per_item.new_zeros(())
        if mode == "sum":
            return per_item.sum()
        if mode == "none":
            return per_item
        raise ValueError(f"bad {self.reduction=} mode")

    def distance(self, a, b):
        """Return the element-wise distance. Must be overridden."""
        raise NotImplementedError()
class L21Loss(LLoss):
    """Euclidean distance between 3d points."""

    def distance(self, a, b):
        # Norm of the difference, taken along the last axis.
        offset = a - b
        return torch.norm(offset, dim=-1)


# Module-level instance built with the default constructor arguments.
L21 = L21Loss()
def get_pred_pts3d(gt, pred, use_pose=False):
    """Extract the predicted 3D points from a prediction dictionary.

    The source of the points depends on which keys ``pred`` holds, checked in
    this order:

    * ``"depth"`` and ``"pseudo_focal"``: points are built with
      ``depthmap_to_pts3d(**pred, pp=...)``, the principal point being read
      from ``gt["camera_intrinsics"]`` when that key exists.
    * ``"pts3d"``: used as is (points from my camera).
    * ``"pts3d_in_other_view"``: returned immediately (points from the other
      camera, already transformed); ``use_pose`` must be False.

    When ``use_pose`` is true, the returned points are instead
    ``pred["pts3d_in_self_view"]`` mapped through the camera obtained from
    ``pred["camera_pose"]``, replacing whatever was selected above.

    Args:
        gt: ground-truth dictionary; only ``"camera_intrinsics"`` is read.
        pred: prediction dictionary.
        use_pose: whether to apply the predicted camera pose.

    Returns:
        The selected 3D points.

    Raises:
        KeyError: if ``pred`` holds none of the recognized keys and
            ``use_pose`` is false.
    """
    pts3d = None
    if "depth" in pred and "pseudo_focal" in pred:
        try:
            pp = gt["camera_intrinsics"][..., :2, 2]
        except KeyError:
            pp = None
        pts3d = depthmap_to_pts3d(**pred, pp=pp)
    elif "pts3d" in pred:
        # pts3d from my camera
        pts3d = pred["pts3d"]
    elif "pts3d_in_other_view" in pred:
        # pts3d from the other camera, already transformed
        assert use_pose is False
        return pred["pts3d_in_other_view"]  # return!
    elif not use_pose:
        # Nothing usable: fail here with a clear message rather than on an
        # unbound local at the return statement.
        raise KeyError(
            "pred must contain 'depth' and 'pseudo_focal', 'pts3d' or "
            "'pts3d_in_other_view'")

    if use_pose:
        camera_pose = pred.get("camera_pose")
        pts3d = pred.get("pts3d_in_self_view")
        assert camera_pose is not None
        assert pts3d is not None
        pts3d = geotrf(pose_encoding_to_camera(camera_pose), pts3d)

    return pts3d
def Sum(losses, masks, conf=None):
    """Combine per-frame losses.

    The first entry of ``losses`` decides the mode:

    * if it has at least one dimension, the losses are per-pixel and the
      inputs are handed back untouched as ``(losses, masks)``, or as
      ``(losses, masks, conf)`` when ``conf`` is given;
    * if it is a 0-dim value, all entries are global losses and their sum is
      returned; ``masks`` and ``conf`` are not used in that case.

    Args:
        losses: non-empty sequence of loss tensors.
        masks: sequence of masks matching ``losses``.
        conf: optional sequence of confidences matching ``losses``.
    """
    first = losses[0]
    if first.ndim > 0:
        # we are actually returning the loss for every pixel
        if conf is not None:
            return losses, masks, conf
        return losses, masks

    # we are returning the global loss
    return sum(losses[1:], first)
def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
    """Compute the factor by which a sequence of point maps is normalized.

    ``norm_mode`` has the form ``"<norm>_<dis>"``. Only ``norm == "avg"`` is
    implemented: the distances to the origin of the valid points (optionally
    passed through ``log1p`` when ``dis == "log1p"``) are summed and divided
    by the number of valid points.

    Args:
        pts: sequence of point maps whose last dimension is 3; at least two
            entries are expected (the second may be None).
        norm_mode: ``"avg_dis"`` or ``"avg_log1p"``.
        valids: per-view validity masks, indexed like ``pts``.
        fix_first: when true only the first view contributes to the factor.

    Returns:
        The factor, clipped to at least 1e-8, with trailing singleton axes
        appended until it has as many dimensions as ``pts[0]``.

    Raises:
        ValueError: for an unknown norm or distance mode.
    """
    reference = pts[0]
    assert reference.ndim >= 3 and reference.shape[-1] == 3
    assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3)

    norm_mode, dis_mode = norm_mode.split("_")
    if norm_mode != "avg":
        raise ValueError(f"Not implemented {norm_mode=}")

    # gather all points together (joint normalization); with fix_first only
    # the first view is used.
    selected = pts[:1] if fix_first else pts
    zeroed_pts = []
    valid_counts = []
    for idx, view_pts in enumerate(selected):
        zeroed, count = invalid_to_zeros(view_pts, valids[idx], ndim=3)
        zeroed_pts.append(zeroed)
        valid_counts.append(count)

    # compute distance to origin
    distances = torch.cat(zeroed_pts, dim=1).norm(dim=-1)
    if dis_mode == "log1p":
        distances = torch.log1p(distances)
    elif dis_mode != "dis":
        raise ValueError(f"bad {dis_mode=}")

    # NOTE(review): the denominator is one scalar summed over everything
    # invalid_to_zeros returned as counts, while the numerator keeps dim 0 —
    # confirm this pooling is intended when dim 0 holds several samples.
    summed = distances.sum(dim=1)
    norm_factor = summed / (torch.cat(valid_counts).sum() + 1e-8)
    norm_factor = norm_factor.clip(min=1e-8)

    for _ in range(reference.ndim - norm_factor.ndim):
        norm_factor = norm_factor.unsqueeze(-1)
    return norm_factor
def normalize_pointcloud_t(pts,
                           norm_mode="avg_dis",
                           valids=None,
                           fix_first=True,
                           gt=False):
    """Divide every point map in ``pts`` by a shared norm factor.

    Args:
        pts: sequence of point maps.
        norm_mode: forwarded to ``get_norm_factor``.
        valids: forwarded to ``get_norm_factor``.
        fix_first: forwarded to ``get_norm_factor``.
        gt: accepted for the caller's convenience; ground truth and
            predictions go through exactly the same computation.

    Returns:
        ``(normalized, norm_factor)`` where ``normalized`` is a new list with
        one entry per input point map.
    """
    norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
    normalized = [view_pts / norm_factor for view_pts in pts]
    return normalized, norm_factor
def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
    """Return a joint depth statistic over several depth maps.

    Each map is passed through ``invalid_to_nans`` with its mask, flattened
    to ``(len(z), -1)`` and concatenated along the last axis; the NaN-aware
    median (or the requested quantile) is then taken along that axis.

    Args:
        zs: sequence of depth tensors.
        valid_masks: optional sequence of masks indexed like ``zs``.
        quantile: 0.5 selects ``torch.nanmedian``; any other value is passed
            to ``torch.nanquantile``.

    Returns:
        One value per row of the flattened maps, i.e. shape (B,).
    """
    # set invalid points to NaN
    flattened = []
    for idx, depth in enumerate(zs):
        mask = None if valid_masks is None else valid_masks[idx]
        flattened.append(
            invalid_to_nans(depth, mask).reshape(len(depth), -1))
    joint = torch.cat(flattened, dim=-1)

    # compute median depth overall (ignoring nans)
    if quantile == 0.5:
        return torch.nanmedian(joint, dim=-1).values
    return torch.nanquantile(joint, quantile, dim=-1)
def get_joint_pointcloud_center_scale(pts,
                                      valid_masks=None,
                                      z_only=False,
                                      center=True):
    """Return the joint median center and median norm of several point maps.

    Each map is passed through ``invalid_to_nans`` with its mask, reshaped to
    ``(len(pt), -1, 3)`` and concatenated along dim 1; medians ignore NaNs.

    Args:
        pts: sequence of point maps.
        valid_masks: optional sequence of masks indexed like ``pts``.
        z_only: when true the X and Y components of the center are zeroed.
        center: when true the norm is measured from the center, otherwise
            from the origin.

    Returns:
        ``(center, scale)``: the center with one extra axis inserted at dim 1
        and the scale with three trailing singleton axes.
    """
    # set invalid points to NaN
    joint = torch.cat(
        [
            invalid_to_nans(
                view_pts,
                None if valid_masks is None else valid_masks[idx],
            ).reshape(len(view_pts), -1, 3)
            for idx, view_pts in enumerate(pts)
        ],
        dim=1,
    )

    # compute median center
    _center = torch.nanmedian(joint, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        _center[..., :2] = 0  # do not center X and Y

    # compute median norm
    measured = joint - _center if center else joint
    scale = torch.nanmedian(measured.norm(dim=-1), dim=1).values
    return _center[:, None, :, :], scale[:, None, None, None]
class Regr3D_t(Criterion, MultiLoss):
    """Point-map regression loss over a sequence of views.

    Ground-truth points of every view are moved into the camera frame of the
    first view and, depending on ``norm_mode`` / ``gt_scale``, predictions and
    ground truth are divided by a norm factor before the wrapped criterion
    compares them.
    """

    def __init__(self,
                 criterion,
                 norm_mode="avg_dis",
                 gt_scale=False,
                 fix_first=True):
        """
        Args:
            criterion: ``BaseCriterion`` comparing predicted and ground-truth
                points.
            norm_mode: mode string forwarded to ``normalize_pointcloud_t``;
                a falsy value disables normalization.
            gt_scale: when true the ground truth is not normalized.
            fix_first: forwarded to ``normalize_pointcloud_t``.
        """
        super().__init__(criterion)
        self.norm_mode = norm_mode
        self.gt_scale = gt_scale
        self.fix_first = fix_first

    def get_all_pts3d_t(self, gts, preds, dist_clip=None):
        """Collect ground-truth points, predicted points and validity masks.

        Args:
            gts: per-view ground-truth dicts providing ``"camera_pose"``,
                ``"pts3d"`` and ``"valid_mask"``.
            preds: per-view prediction dicts, read by ``get_pred_pts3d``.
            dist_clip: optional distance beyond which ground-truth points are
                marked invalid.

        Returns:
            ``(gt_pts, pr_pts, gt_factor, pr_factor, valids, monitoring)``;
            a factor is None when the corresponding normalization is skipped
            and ``monitoring`` is an empty dict here.
        """
        # everything is normalized w.r.t. camera of view1
        in_camera1 = inv(gts[0]["camera_pose"])

        gt_pts = []
        valids = []
        pr_pts = []
        for i, gt in enumerate(gts):
            # original note — in_camera1: Bs, 4, 4   gt['pts3d']: Bs, H, W, 3
            gt_pts.append(geotrf(in_camera1, gt["pts3d"]))
            valid = gt["valid_mask"].clone()
            if dist_clip is not None:
                # points that are too far-away == invalid
                dis = gt["pts3d"].norm(dim=-1)
                valid = valid & (dis <= dist_clip)
            valids.append(valid)
            pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=False))

        if self.norm_mode:
            pr_pts, pr_factor = normalize_pointcloud_t(
                pr_pts,
                self.norm_mode,
                valids,
                fix_first=self.fix_first,
                gt=False)
        else:
            pr_factor = None

        if self.norm_mode and not self.gt_scale:
            gt_pts, gt_factor = normalize_pointcloud_t(
                gt_pts,
                self.norm_mode,
                valids,
                fix_first=self.fix_first,
                gt=True)
        else:
            gt_factor = None

        return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}

    def compute_frame_loss(self, gts, preds, **kw):
        """Compute per-frame losses plus a penalty on the prediction scale.

        NOTE(review): this method unpacks the predicted points into a
        (left, right) pair and reads ``preds[i][0]`` / ``preds[i][1]``,
        whereas ``get_all_pts3d_t`` builds one flat list and hands
        ``preds[i]`` straight to ``get_pred_pts3d`` — confirm which layout
        callers actually provide.

        Args:
            gts: per-view ground-truth dicts.
            preds: per-view predictions (see the note above).
            **kw: forwarded to ``get_all_pts3d_t``.

        Returns:
            ``(Sum(losses, masks, confs), details, factor_loss)`` where
            ``factor_loss`` is 0.0 unless both norm factors exist and the
            predicted one exceeds the ground-truth one somewhere.
        """
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            self.get_all_pts3d_t(gts, preds, **kw))

        pred_pts_l, pred_pts_r = pred_pts

        loss_all = []
        mask_all = []
        conf_all = []

        # Python-side running totals, only used for monitoring.
        loss_left = 0
        loss_right = 0
        pred_conf_l = 0
        pred_conf_r = 0

        last = len(gt_pts) - 1
        for i in range(len(gt_pts)):
            # Left (Reference)
            if i != last:
                frame_loss = self.criterion(pred_pts_l[i][masks[i]],
                                            gt_pts[i][masks[i]])
                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i][0]["conf"])

                # To compare target/reference loss
                if i != 0:
                    loss_left += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy(
                    ).mean()

            # Right (Target)
            if i != 0:
                frame_loss = self.criterion(pred_pts_r[i - 1][masks[i]],
                                            gt_pts[i][masks[i]])
                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i - 1][1]["conf"])

                # To compare target/reference loss
                if i != last:
                    loss_right += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_r += preds[
                        i - 1][1]["conf"].cpu().detach().numpy().mean()

        # Penalize predicted factors that exceed the ground-truth ones. The
        # difference is taken first and masked afterwards so that every
        # predicted factor is compared with its own ground-truth factor and
        # not broadcast against all of them.
        factor_loss = 0.0
        if pr_factor is not None and gt_factor is not None:
            too_large = pr_factor > gt_factor
            if too_large.any():
                factor_loss = (pr_factor - gt_factor)[too_large].abs().mean()

        self_name = type(self).__name__
        details = {
            self_name + "_pts3d_1": float(loss_all[0].mean()),
            self_name + "_pts3d_2": float(loss_all[1].mean()),
            self_name + "loss_left": float(loss_left),
            self_name + "loss_right": float(loss_right),
            self_name + "conf_left": float(pred_conf_l),
            self_name + "conf_right": float(pred_conf_r),
        }

        return Sum(loss_all, mask_all,
                   conf_all), (details | monitoring), factor_loss
class ConfLoss_t(MultiLoss):
    """Weighted regression by learned confidence.

    Assuming the input pixel_loss is a pixel-level regression loss.

    Principle:
        every per-pixel loss x with confidence conf becomes
            conf_loss = x * conf - alpha * log(conf)
        conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
        conf = 10  ==> conf_loss = x * 10 - alpha*log(10)

        alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        """
        Args:
            pixel_loss: loss exposing ``with_reduction`` and
                ``compute_frame_loss``; a copy switched to reduction "none"
                is stored.
            alpha: strictly positive weight of the log-confidence term.
        """
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        self.pixel_loss = pixel_loss.with_reduction("none")

    def get_name(self):
        return f"ConfLoss({self.pixel_loss})"

    def get_conf_log(self, x):
        """Return the confidence and its natural logarithm."""
        return x, torch.log(x)

    def compute_frame_loss(self, gts, preds, **kw):
        """Weight the per-pixel losses of ``pixel_loss`` by confidence.

        Returns:
            ``(loss, details, loss_factor)``: the mean over frames of the
            confidence-weighted losses (each multiplied by 2.0), a dict of
            monitoring values merged with those of ``pixel_loss``, and the
            factor loss passed through from ``pixel_loss``.
        """
        # compute per-pixel loss
        (losses, masks,
         confs), details, loss_factor = (self.pixel_loss.compute_frame_loss(
             gts, preds, **kw))

        # weight by confidence
        conf_losses = []
        conf_sum = 0
        for i, pixel_loss in enumerate(losses):
            conf, log_conf = self.get_conf_log(confs[i][masks[i]])
            conf_sum += conf.mean()
            conf_loss = pixel_loss * conf - self.alpha * log_conf
            # A frame without any selected pixel contributes a zero *tensor*:
            # torch.stack() below rejects plain Python numbers.
            if conf_loss.numel() > 0:
                conf_loss = conf_loss.mean()
            else:
                conf_loss = conf_loss.new_zeros(())
            conf_losses.append(conf_loss)

        conf_losses = torch.stack(conf_losses) * 2.0
        conf_loss_mean = conf_losses.mean()

        return (
            conf_loss_mean,
            dict(
                conf_loss_1=float(conf_losses[0]),
                conf_loss2=float(conf_losses[1]),
                conf_mean=conf_sum / len(losses),
                **details,
            ),
            loss_factor,
        )
class Regr3D_t_ShiftInv(Regr3D_t):
    """Same as Regr3D_t but invariant to depth shift."""

    def get_all_pts3d_t(self, gts, preds, **kw):
        """Return the parent's points with the joint median depth removed.

        Args:
            gts: per-view ground-truth dicts.
            preds: per-view prediction dicts.
            **kw: forwarded to the parent implementation (e.g. ``dist_clip``).

        Returns:
            Same tuple as the parent; ``monitoring`` additionally holds the
            mean ground-truth and predicted shifts.
        """
        # points as prepared by the parent implementation
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds, **kw))

        gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts]
        pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts]

        # compute median depth
        gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None,
                                                                  None]

        # subtract the median depth
        for gt_pt in gt_pts:
            gt_pt[..., 2] -= gt_shift_z
        for i in range(len(pred_pts)):
            # work on a clone to avoid modifying the prediction in place
            pred_pts[i] = pred_pts[i].clone()
            pred_pts[i][..., 2] -= pred_shift_z

        monitoring = dict(
            monitoring,
            gt_shift_z=gt_shift_z.mean().detach(),
            pred_shift_z=pred_shift_z.mean().detach(),
        )
        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
class Regr3D_t_ScaleInv(Regr3D_t):
    """Same as Regr3D_t but invariant to scale.

    if gt_scale == True: enforce the prediction to take the same scale as GT
    """

    def get_all_pts3d_t(self, gts, preds, **kw):
        """Return the parent's points rescaled by their joint median norms.

        Args:
            gts: per-view ground-truth dicts.
            preds: per-view prediction dicts.
            **kw: forwarded to the parent implementation (e.g. ``dist_clip``).

        Returns:
            Same tuple as the parent; ``monitoring`` additionally holds the
            mean ground-truth and predicted scales.
        """
        # points as prepared by the parent implementation
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds, **kw))

        # The scale of the predictions is measured on clones; the rescaling
        # below is applied in place to the tensors in pred_pts themselves.
        pred_pts_all = [x.clone() for x in pred_pts]

        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks)

        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        # rescale the point clouds
        if self.gt_scale:
            for i in range(len(pred_pts)):
                pred_pts[i] *= gt_scale / pred_scale
        else:
            # NOTE(review): predictions are multiplied by pred_scale/gt_scale
            # and ground truth by gt_scale/pred_scale instead of each being
            # divided by its own scale — confirm this is intended.
            for i in range(len(pred_pts)):
                pred_pts[i] *= pred_scale / gt_scale
            for i in range(len(gt_pts)):
                gt_pts[i] *= gt_scale / pred_scale

        monitoring = dict(monitoring,
                          gt_scale=gt_scale.mean(),
                          pred_scale=pred_scale.mean().detach())

        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
    """Scale- and shift-invariant variant of Regr3D_t.

    Adds no code of its own. Through the method resolution order,
    Regr3D_t_ScaleInv.get_all_pts3d_t starts by delegating to
    Regr3D_t_ShiftInv.get_all_pts3d_t, so the shift is removed first and the
    scale is handled afterwards.
    """