| """ |
| Loss functions for pose and depth estimation. |
| """ |
|
|
from typing import Callable, Optional

import torch
import torch.nn.functional as F
|
|
|
|
def geodesic_rotation_loss(
    R_pred: torch.Tensor,
    R_target: torch.Tensor,
    *,
    eps: float = 1e-7,
) -> torch.Tensor:
    """
    Compute the mean geodesic distance between rotation matrices.

    The relative rotation ``R_pred @ R_target^T`` has rotation angle theta
    with ``cos(theta) = (trace - 1) / 2``. The angle is computed for every
    matrix pair and averaged over all leading dimensions.

    Args:
        R_pred: Predicted rotation matrices (..., 3, 3)
        R_target: Target rotation matrices (..., 3, 3)
        eps: Margin kept between the cosine and +/-1 before ``acos``.
            ``acos`` has an infinite derivative at +/-1, so without this
            margin a perfect (or exactly antipodal) prediction produces
            inf/NaN gradients. The cost is a tiny floor on the returned
            angle (roughly ``sqrt(2 * eps)`` radians, subject to the
            input dtype's rounding).

    Returns:
        Scalar tensor: mean geodesic distance in radians.
    """
    # Relative rotation; the identity when prediction equals target.
    R_diff = torch.matmul(R_pred, R_target.transpose(-2, -1))
    trace = torch.diagonal(R_diff, dim1=-2, dim2=-1).sum(dim=-1)

    # Clamp the cosine strictly inside (-1, 1). This absorbs round-off that
    # pushes |cos| slightly past 1 and keeps acos differentiable; a
    # saturated clamp simply passes zero gradient.
    cos_angle = torch.clamp((trace - 1.0) / 2.0, -1.0 + eps, 1.0 - eps)
    return torch.acos(cos_angle).mean()
|
|
|
|
def pose_loss(
    poses_pred: torch.Tensor,
    poses_target: torch.Tensor,
    weight_rotation: float = 1.0,
    weight_translation: float = 0.1,
) -> torch.Tensor:
    """
    Compute pose loss (rotation + translation).

    The rotation term is ``geodesic_rotation_loss`` on the upper-left 3x3
    blocks. The translation term is the mean L1 error on the last column of
    the top three rows.

    Args:
        poses_pred: Predicted poses (..., 3, 4) or (..., 4, 4), e.g. (N, 3, 4)
        poses_target: Target poses with the same layout as ``poses_pred``
        weight_rotation: Weight for rotation loss
        weight_translation: Weight for translation loss

    Returns:
        Combined pose loss (scalar tensor)
    """
    # The (3, 4) and (4, 4) layouts share the same top three rows [R | t],
    # so one slice handles both; the bottom row of a 4x4 pose is ignored.
    R_pred, t_pred = poses_pred[..., :3, :3], poses_pred[..., :3, 3]
    R_target, t_target = poses_target[..., :3, :3], poses_target[..., :3, 3]

    loss_rot = geodesic_rotation_loss(R_pred, R_target)
    loss_trans = F.l1_loss(t_pred, t_target)

    return weight_rotation * loss_rot + weight_translation * loss_trans
|
|
|
|
def depth_loss(
    depth_pred: torch.Tensor,
    depth_target: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    loss_type: str = "l1",
) -> torch.Tensor:
    """
    Compute depth loss averaged over valid pixels.

    Args:
        depth_pred: Predicted depth (N, H, W)
        depth_target: Target depth (N, H, W)
        mask: Valid depth mask (N, H, W), optional. Boolean or numeric;
            numeric values act as per-pixel weights (the error is taken
            between ``mask * depth_pred`` and ``mask * depth_target``).
            Zero entries are excluded entirely, even if the depth there is
            inf/NaN.
        loss_type: 'l1' or 'l2'

    Returns:
        Scalar tensor: summed per-pixel loss divided by the number of valid
        pixels (the mask sum, or the element count when no mask is given).

    Raises:
        ValueError: If ``loss_type`` is not 'l1' or 'l2'.
    """
    if loss_type == "l1":
        elementwise_loss = F.l1_loss
    elif loss_type == "l2":
        elementwise_loss = F.mse_loss
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    if mask is None:
        valid_pixels = depth_pred.numel()
    else:
        weights = mask.to(depth_pred.dtype)
        # Broadcast to one common shape so the pixel count matches the
        # number of terms that are actually summed.
        weights, depth_pred, depth_target = torch.broadcast_tensors(
            weights, depth_pred, depth_target
        )
        valid = weights != 0
        # Select rather than only multiply: invalid pixels often hold inf/NaN
        # sentinels, and 0 * inf = NaN would poison the whole sum (and the
        # gradients).
        depth_pred = torch.where(valid, depth_pred * weights, depth_pred.new_zeros(()))
        depth_target = torch.where(
            valid, depth_target * weights, depth_target.new_zeros(())
        )
        valid_pixels = weights.sum()

    loss = elementwise_loss(depth_pred, depth_target, reduction="sum")
    # The epsilon makes an all-invalid mask return 0 instead of dividing by 0.
    return loss / (valid_pixels + 1e-8)
|
|
|
|
def confidence_weighted_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    confidence: torch.Tensor,
    base_loss_fn: Callable[..., torch.Tensor] = F.l1_loss,
    *,
    alpha: float = 0.0,
) -> torch.Tensor:
    """
    Compute confidence-weighted loss.

    Each elementwise loss term is multiplied by its confidence, and the
    result is averaged.

    Args:
        pred: Predictions
        target: Targets
        confidence: Confidence weights (higher = more confident); must
            broadcast against the elementwise loss
        base_loss_fn: Base loss function; called as
            ``base_loss_fn(pred, target, reduction="none")``
        alpha: Weight of an optional ``-alpha * log(confidence)`` penalty.
            If confidence is itself predicted, the unregularized objective is
            minimized trivially by driving confidence to zero; a positive
            ``alpha`` counteracts that. It requires strictly positive
            confidence. The default 0 disables the term, which reproduces the
            plain weighted mean.

    Returns:
        Weighted loss (scalar tensor)
    """
    loss = base_loss_fn(pred, target, reduction="none")
    weighted_loss = loss * confidence
    if alpha != 0.0:
        # Skipped entirely when disabled so zero confidence never yields
        # 0 * log(0) = NaN.
        weighted_loss = weighted_loss - alpha * torch.log(confidence)
    return weighted_loss.mean()
|
|