# 3d_model/ylff/utils/losses.py
# Author: Azan
# Clean deployment build (squashed), commit 7a87926
"""
Loss functions for pose and depth estimation.
"""
from typing import Callable, Optional

import torch
import torch.nn.functional as F
def geodesic_rotation_loss(R_pred: torch.Tensor, R_target: torch.Tensor) -> torch.Tensor:
    """
    Compute the mean geodesic distance between rotation matrices.

    Args:
        R_pred: Predicted rotation matrices (..., 3, 3)
        R_target: Target rotation matrices (..., 3, 3)

    Returns:
        Mean geodesic distance in radians (scalar tensor).
    """
    # Relative rotation: R_diff = R_pred @ R_target^T
    R_diff = torch.matmul(R_pred, R_target.transpose(-2, -1))
    # For a rotation matrix, tr(R) = 1 + 2*cos(theta)
    trace = torch.diagonal(R_diff, dim1=-2, dim2=-1).sum(dim=-1)
    cos_angle = (trace - 1.0) / 2.0
    # Clamp strictly inside [-1, 1]: d/dx acos(x) = -1/sqrt(1 - x^2) blows up
    # at the boundaries, so exact values of +/-1 (identical or opposite
    # rotations) produce NaN/inf gradients in backprop without this margin.
    cos_angle = torch.clamp(cos_angle, -1.0 + 1e-7, 1.0 - 1e-7)
    # theta = arccos((tr(R) - 1) / 2)
    return torch.acos(cos_angle).mean()
def pose_loss(
    poses_pred: torch.Tensor,
    poses_target: torch.Tensor,
    weight_rotation: float = 1.0,
    weight_translation: float = 0.1,
) -> torch.Tensor:
    """
    Compute a combined pose loss (geodesic rotation + L1 translation).

    Args:
        poses_pred: Predicted poses (N, 3, 4) or (N, 4, 4)
        poses_target: Target poses (N, 3, 4) or (N, 4, 4)
        weight_rotation: Weight for the rotation term
        weight_translation: Weight for the translation term

    Returns:
        weight_rotation * rotation_loss + weight_translation * translation_loss
    """
    # Both (N, 3, 4) and (N, 4, 4) layouts keep the rotation in the top-left
    # 3x3 block and the translation in the first three rows of the last
    # column, so a single slicing path covers both shapes (the previous
    # shape check duplicated identical code in both branches).
    R_pred = poses_pred[:, :3, :3]
    t_pred = poses_pred[:, :3, 3]
    R_target = poses_target[:, :3, :3]
    t_target = poses_target[:, :3, 3]
    # Rotation term: geodesic distance on SO(3)
    loss_rot = geodesic_rotation_loss(R_pred, R_target)
    # Translation term: L1 distance
    loss_trans = F.l1_loss(t_pred, t_target)
    return weight_rotation * loss_rot + weight_translation * loss_trans
def depth_loss(
    depth_pred: torch.Tensor,
    depth_target: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    loss_type: str = "l1",
) -> torch.Tensor:
    """
    Compute a depth regression loss, averaged over valid pixels.

    Args:
        depth_pred: Predicted depth (N, H, W)
        depth_target: Target depth (N, H, W)
        mask: Optional validity mask (N, H, W); masked-out pixels contribute
            zero to the loss and are excluded from the normalizer.
        loss_type: 'l1' or 'l2'

    Returns:
        Sum of per-pixel losses divided by the number of valid pixels.

    Raises:
        ValueError: If loss_type is neither 'l1' nor 'l2'.
    """
    # Zeroing both tensors where the mask is 0 makes those pixels contribute
    # exactly zero to the summed loss, so only the normalizer needs the count.
    if mask is None:
        denominator = depth_pred.numel()
    else:
        depth_pred = depth_pred * mask
        depth_target = depth_target * mask
        denominator = mask.sum()
    if loss_type == "l1":
        total = F.l1_loss(depth_pred, depth_target, reduction="sum")
    elif loss_type == "l2":
        total = F.mse_loss(depth_pred, depth_target, reduction="sum")
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")
    # Small epsilon guards against division by zero when no pixels are valid.
    return total / (denominator + 1e-8)
def confidence_weighted_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    confidence: torch.Tensor,
    base_loss_fn: Callable[..., torch.Tensor] = F.l1_loss,
) -> torch.Tensor:
    """
    Compute a confidence-weighted loss.

    Args:
        pred: Predictions
        target: Targets (same shape as pred)
        confidence: Confidence weights (higher = more confident); must
            broadcast against the element-wise loss tensor.
        base_loss_fn: Element-wise loss function that accepts
            reduction="none" (e.g. F.l1_loss, F.mse_loss).

    Returns:
        Mean of the per-element losses scaled by their confidence weights.
    """
    # reduction="none" keeps the per-element losses so each element can be
    # scaled by its own confidence before averaging.
    loss = base_loss_fn(pred, target, reduction="none")
    weighted_loss = (loss * confidence).mean()
    return weighted_loss