| """ |
| Oracle Ensemble Loss Functions: Uncertainty-weighted losses using continuous confidence. |
| |
| Uses oracle uncertainty propagation to create continuous confidence masks that |
| weight training by uncertainty rather than binary rejection. |
| """ |
|
|
| import logging |
| from typing import Dict, Optional |
| import torch |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def oracle_uncertainty_weighted_pose_loss(
    poses_pred: torch.Tensor,
    poses_target: torch.Tensor,
    confidence: torch.Tensor,
    uncertainty: Optional[torch.Tensor] = None,
    weight_rotation: float = 1.0,
    weight_translation: float = 0.1,
    use_uncertainty_weighting: bool = True,
) -> Dict[str, torch.Tensor]:
    """
    Compute pose loss weighted by continuous oracle confidence/uncertainty.

    Uses continuous confidence scores rather than binary rejection.

    Args:
        poses_pred: Predicted poses (N, 3, 4) w2c
        poses_target: Target poses (N, 3, 4) w2c
        confidence: Frame-level confidence scores (N,) [0.0-1.0] from oracle ensemble
        uncertainty: Optional pose uncertainty (N, 6); currently unused, kept
            for interface compatibility with covariance-aware variants
        weight_rotation: Weight for rotation loss
        weight_translation: Weight for translation loss
        use_uncertainty_weighting: If True, weight by confidence; if False, use uniform

    Returns:
        Dict with:
            - 'total_loss': Combined weighted loss
            - 'rotation_loss': Rotation component
            - 'translation_loss': Translation component
            - 'mean_confidence': Average confidence of frames
            - 'num_frames': Total number of frames
    """
    N = len(poses_pred)

    # Split [R|t] components of the 3x4 pose matrices.
    R_pred = poses_pred[:, :3, :3]
    t_pred = poses_pred[:, :3, 3]
    R_target = poses_target[:, :3, :3]
    t_target = poses_target[:, :3, 3]

    # Geodesic rotation error: angle of the relative rotation R_pred @ R_target^T,
    # recovered from its trace via cos(theta) = (trace - 1) / 2.
    R_diff = torch.matmul(R_pred, R_target.transpose(-2, -1))
    trace = torch.diagonal(R_diff, dim1=-2, dim2=-1).sum(dim=-1)
    # Clamp the cosine STRICTLY inside (-1, 1): acos has an infinite derivative
    # at +/-1, so clamping to the closed interval (as before) produced NaN
    # gradients whenever predicted and target rotations were (nearly) identical
    # or opposite — i.e. exactly in the converged regime.
    cos_theta = torch.clamp((trace - 1.0) / 2.0, -1.0 + 1e-7, 1.0 - 1e-7)
    rot_errors = torch.acos(cos_theta)

    # Euclidean translation error per frame.
    trans_errors = torch.norm(t_pred - t_target, dim=1)

    if use_uncertainty_weighting:
        # Normalize confidences into weights summing to ~1; low-confidence
        # frames contribute proportionally less to the loss.
        weights = confidence / (confidence.sum() + 1e-6)
        weighted_rot_loss = (rot_errors * weights).sum()
        weighted_trans_loss = (trans_errors * weights).sum()
    else:
        # Uniform averaging over all frames.
        weighted_rot_loss = rot_errors.mean()
        weighted_trans_loss = trans_errors.mean()

    total_loss = weight_rotation * weighted_rot_loss + weight_translation * weighted_trans_loss

    return {
        "total_loss": total_loss,
        "rotation_loss": weighted_rot_loss,
        "translation_loss": weighted_trans_loss,
        "mean_confidence": confidence.mean(),
        "num_frames": torch.tensor(N, device=poses_pred.device),
    }
|
|
|
|
def oracle_uncertainty_weighted_depth_loss(
    depth_pred: torch.Tensor,
    depth_target: torch.Tensor,
    confidence: torch.Tensor,
    uncertainty: Optional[torch.Tensor] = None,
    valid_mask: Optional[torch.Tensor] = None,
    loss_type: str = "l1",
    relative_error: bool = True,
    use_uncertainty_weighting: bool = True,
) -> Dict[str, torch.Tensor]:
    """
    Compute depth loss weighted by continuous oracle confidence/uncertainty.

    Uses continuous confidence scores rather than binary rejection.

    Args:
        depth_pred: Predicted depth (N, H, W)
        depth_target: Target depth (N, H, W)
        confidence: Pixel-level confidence scores (N, H, W) [0.0-1.0]
        uncertainty: Optional depth uncertainty (N, H, W); currently unused,
            kept for interface compatibility with covariance-aware variants
        valid_mask: Additional validity mask (e.g., finite depth, > 0)
        loss_type: 'l1' or 'l2'
        relative_error: Use relative error (depth_diff / depth) instead of absolute
        use_uncertainty_weighting: If True, weight by confidence; if False, use uniform

    Returns:
        Dict with:
            - 'total_loss': Weighted depth loss
            - 'num_pixels': Total number of valid pixels
            - 'mean_confidence': Average confidence of valid pixels
    """
    if valid_mask is not None:
        combined_mask = valid_mask
    else:
        combined_mask = torch.ones_like(confidence, dtype=torch.bool)

    num_valid = combined_mask.sum().item()

    if num_valid == 0:
        logger.warning("No valid pixels for depth loss")
        return {
            "total_loss": torch.tensor(0.0, device=depth_pred.device),
            "num_pixels": torch.tensor(0, device=depth_pred.device),
            "mean_confidence": torch.tensor(0.0, device=depth_pred.device),
        }

    # Sanitize invalid pixels BEFORE any arithmetic. The previous version
    # computed the error over all pixels and relied on zero weights to remove
    # invalid ones — but NaN/Inf in masked-out target pixels (e.g. missing
    # LiDAR returns) then poisoned the sums, since 0 * NaN == NaN.
    safe_pred = torch.where(combined_mask, depth_pred, torch.zeros_like(depth_pred))
    safe_target = torch.where(combined_mask, depth_target, torch.ones_like(depth_target))

    depth_diff = torch.abs(safe_pred - safe_target)

    if relative_error:
        # Scale-invariant error; epsilon guards against zero-depth targets.
        depth_error = depth_diff / (safe_target + 1e-6)
    else:
        depth_error = depth_diff

    if use_uncertainty_weighting:
        # Per-pixel confidence weighting, restricted to valid pixels.
        weights = confidence * combined_mask.float()
        if loss_type == "l1":
            weighted_error = depth_error * weights
        else:
            weighted_error = (depth_error**2) * weights

        # Normalize by total weight so the loss magnitude is comparable
        # across batches with different confidence distributions.
        total_loss = weighted_error.sum() / (weights.sum() + 1e-6)
    else:
        # Uniform averaging over valid pixels only.
        if loss_type == "l1":
            total_loss = (depth_error * combined_mask.float()).sum() / (num_valid + 1e-6)
        else:
            total_loss = ((depth_error**2) * combined_mask.float()).sum() / (num_valid + 1e-6)

    return {
        "total_loss": total_loss,
        "num_pixels": torch.tensor(num_valid, device=depth_pred.device),
        "mean_confidence": confidence[combined_mask].mean(),
    }
|
|
|
|
def oracle_uncertainty_ensemble_loss(
    da3_output: Dict[str, torch.Tensor],
    oracle_targets: Dict[str, torch.Tensor],
    uncertainty_results: Dict[str, torch.Tensor],
    loss_weights: Optional[Dict[str, float]] = None,
    use_uncertainty_weighting: bool = True,
) -> Dict[str, torch.Tensor]:
    """
    Combined loss using continuous oracle uncertainty propagation.

    Uses continuous confidence scores rather than binary rejection.

    Args:
        da3_output: DA3 predictions dict with:
            - 'poses': (N, 3, 4) predicted poses w2c
            - 'depth': (N, H, W) predicted depth maps
        oracle_targets: Oracle target values dict with:
            - 'poses': (N, 3, 4) target poses w2c
            - 'depth': (N, H, W) target depth maps (LiDAR or BA)
        uncertainty_results: Uncertainty propagation results dict with:
            - 'pose_confidence': (N,) frame-level confidence [0.0-1.0]
            - 'depth_confidence': (N, H, W) pixel-level confidence [0.0-1.0]
            - 'pose_uncertainty': (N, 6) optional pose uncertainty
            - 'depth_uncertainty': (N, H, W) optional depth uncertainty
        loss_weights: Optional weights for each loss component
        use_uncertainty_weighting: If True, weight by confidence; if False, use uniform

    Returns:
        Dict with all loss components and statistics
    """
    if loss_weights is None:
        loss_weights = {
            "pose": 1.0,
            "depth": 1.0,
        }

    def _fallback_device() -> torch.device:
        # Pick the device of any available tensor for zero-loss placeholders.
        # The previous code read da3_output["depth"].device unconditionally,
        # which raised KeyError precisely in the fallback case where 'depth'
        # (or 'poses') was absent from the prediction dict.
        for source in (da3_output, oracle_targets, uncertainty_results):
            for value in source.values():
                if isinstance(value, torch.Tensor):
                    return value.device
        return torch.device("cpu")

    results = {}

    # Pose loss: frame-level confidence weighting.
    if "poses" in da3_output and "poses" in oracle_targets:
        pose_loss_dict = oracle_uncertainty_weighted_pose_loss(
            da3_output["poses"],
            oracle_targets["poses"],
            uncertainty_results["pose_confidence"],
            uncertainty=uncertainty_results.get("pose_uncertainty"),
            use_uncertainty_weighting=use_uncertainty_weighting,
        )
        results.update({f"pose_{k}": v for k, v in pose_loss_dict.items()})
        results["pose_loss"] = pose_loss_dict["total_loss"] * loss_weights["pose"]
    else:
        results["pose_loss"] = torch.tensor(0.0, device=_fallback_device())

    # Depth loss: pixel-level confidence weighting.
    if "depth" in da3_output and "depth" in oracle_targets:
        # Only supervise pixels where both prediction and target are finite
        # and strictly positive.
        valid_depth = (
            torch.isfinite(oracle_targets["depth"])
            & (oracle_targets["depth"] > 0)
            & torch.isfinite(da3_output["depth"])
            & (da3_output["depth"] > 0)
        )

        depth_loss_dict = oracle_uncertainty_weighted_depth_loss(
            da3_output["depth"],
            oracle_targets["depth"],
            uncertainty_results["depth_confidence"],
            uncertainty=uncertainty_results.get("depth_uncertainty"),
            valid_mask=valid_depth,
            relative_error=True,
            use_uncertainty_weighting=use_uncertainty_weighting,
        )
        results.update({f"depth_{k}": v for k, v in depth_loss_dict.items()})
        results["depth_loss"] = depth_loss_dict["total_loss"] * loss_weights["depth"]
    else:
        results["depth_loss"] = torch.tensor(0.0, device=_fallback_device())

    results["total_loss"] = results["pose_loss"] + results["depth_loss"]

    # Confidence statistics for logging/monitoring.
    if "pose_confidence" in uncertainty_results:
        pose_conf = uncertainty_results["pose_confidence"]
        results["mean_pose_confidence"] = pose_conf.mean()
        results["min_pose_confidence"] = pose_conf.min()
        results["max_pose_confidence"] = pose_conf.max()

    if "depth_confidence" in uncertainty_results:
        depth_conf = uncertainty_results["depth_confidence"]
        results["mean_depth_confidence"] = depth_conf.mean()
        results["min_depth_confidence"] = depth_conf.min()
        results["max_depth_confidence"] = depth_conf.max()

    return results
|
|