# 3d_model/ylff/utils/oracle_losses.py
# Author: Azan
# Clean deployment build (squashed), commit 7a87926
"""
Oracle Ensemble Loss Functions: Uncertainty-weighted losses using continuous confidence.
Uses oracle uncertainty propagation to create continuous confidence masks that
weight training by uncertainty rather than binary rejection.
"""
import logging
from typing import Dict, Optional
import torch
logger = logging.getLogger(__name__)
def oracle_uncertainty_weighted_pose_loss(
    poses_pred: torch.Tensor,  # (N, 3, 4) w2c
    poses_target: torch.Tensor,  # (N, 3, 4) w2c
    confidence: torch.Tensor,  # (N,) frame-level confidence [0.0-1.0]
    uncertainty: Optional[torch.Tensor] = None,  # (N, 6) pose uncertainty
    weight_rotation: float = 1.0,
    weight_translation: float = 0.1,
    use_uncertainty_weighting: bool = True,
) -> Dict[str, torch.Tensor]:
    """
    Compute pose loss weighted by continuous oracle confidence/uncertainty.

    Uses continuous confidence scores rather than binary rejection: each frame's
    rotation/translation error contributes in proportion to its confidence.

    Args:
        poses_pred: Predicted poses (N, 3, 4) w2c
        poses_target: Target poses (N, 3, 4) w2c
        confidence: Frame-level confidence scores (N,) [0.0-1.0] from oracle ensemble
        uncertainty: Optional pose uncertainty (N, 6).
            NOTE(review): currently accepted but unused — TODO wire into a
            covariance-aware loss or drop from the docstring contract.
        weight_rotation: Weight for rotation loss
        weight_translation: Weight for translation loss
        use_uncertainty_weighting: If True, weight by confidence; if False, use uniform

    Returns:
        Dict with:
            - 'total_loss': Combined weighted loss
            - 'rotation_loss': Rotation component
            - 'translation_loss': Translation component
            - 'mean_confidence': Average confidence of frames
            - 'num_frames': Total number of frames
    """
    N = len(poses_pred)
    device = poses_pred.device
    # Guard the empty batch: .mean() over zero elements returns NaN, which
    # would silently poison the training loss.
    if N == 0:
        zero = torch.tensor(0.0, device=device)
        return {
            "total_loss": zero,
            "rotation_loss": zero.clone(),
            "translation_loss": zero.clone(),
            "mean_confidence": zero.clone(),
            "num_frames": torch.tensor(0, device=device),
        }
    # Extract rotation and translation components
    R_pred = poses_pred[:, :3, :3]
    t_pred = poses_pred[:, :3, 3]
    R_target = poses_target[:, :3, :3]
    t_target = poses_target[:, :3, 3]
    # Rotation loss: geodesic distance per frame from the relative rotation
    # R_pred @ R_target^T via angle = acos((trace - 1) / 2).
    R_diff = torch.matmul(R_pred, R_target.transpose(-2, -1))
    trace = torch.diagonal(R_diff, dim1=-2, dim2=-1).sum(dim=-1)
    # Clamp strictly *inside* [-1, 3]: at the exact endpoints the acos input
    # hits ±1 where d/dx acos(x) is infinite, producing NaN/inf gradients for
    # identical or antipodal rotations. The eps trades a ~1e-3 rad bias near
    # zero error for numerically stable backprop.
    eps = 1e-6
    trace_clamped = torch.clamp(trace, -1.0 + eps, 3.0 - eps)
    rot_errors = torch.acos((trace_clamped - 1.0) / 2.0)  # (N,)
    # Translation loss: L2 norm of the translation residual per frame
    trans_errors = torch.norm(t_pred - t_target, dim=1)  # (N,)
    # Weight by continuous confidence (not binary rejection)
    if use_uncertainty_weighting:
        # Normalize by the confidence sum so the result is a weighted average;
        # eps guards an all-zero confidence batch.
        weights = confidence / (confidence.sum() + 1e-6)
        weighted_rot_loss = (rot_errors * weights).sum()
        weighted_trans_loss = (trans_errors * weights).sum()
    else:
        # Uniform weighting
        weighted_rot_loss = rot_errors.mean()
        weighted_trans_loss = trans_errors.mean()
    total_loss = weight_rotation * weighted_rot_loss + weight_translation * weighted_trans_loss
    return {
        "total_loss": total_loss,
        "rotation_loss": weighted_rot_loss,
        "translation_loss": weighted_trans_loss,
        "mean_confidence": confidence.mean(),
        "num_frames": torch.tensor(N, device=device),
    }
def oracle_uncertainty_weighted_depth_loss(
    depth_pred: torch.Tensor,  # (N, H, W)
    depth_target: torch.Tensor,  # (N, H, W)
    confidence: torch.Tensor,  # (N, H, W) pixel-level confidence [0.0-1.0]
    uncertainty: Optional[torch.Tensor] = None,  # (N, H, W) depth uncertainty
    valid_mask: Optional[torch.Tensor] = None,  # (N, H, W) additional validity mask
    loss_type: str = "l1",
    relative_error: bool = True,
    use_uncertainty_weighting: bool = True,
) -> Dict[str, torch.Tensor]:
    """
    Compute depth loss weighted by continuous oracle confidence/uncertainty.

    Uses continuous confidence scores rather than binary rejection: each pixel
    contributes in proportion to its confidence (within the validity mask).

    Args:
        depth_pred: Predicted depth (N, H, W)
        depth_target: Target depth (N, H, W)
        confidence: Pixel-level confidence scores (N, H, W) [0.0-1.0]
        uncertainty: Optional depth uncertainty (N, H, W).
            NOTE(review): currently accepted but unused — TODO wire into a
            covariance-aware loss or drop from the docstring contract.
        valid_mask: Additional validity mask (e.g. finite depth, > 0). Any
            dtype accepted; coerced to bool.
        loss_type: 'l1' or 'l2'
        relative_error: Use relative error (depth_diff / depth) instead of absolute
        use_uncertainty_weighting: If True, weight by confidence; if False, use uniform

    Returns:
        Dict with:
            - 'total_loss': Weighted depth loss
            - 'num_pixels': Total number of valid pixels
            - 'mean_confidence': Average confidence of valid pixels
    """
    # Combine with validity mask (coerce so indexing below is boolean)
    if valid_mask is not None:
        combined_mask = valid_mask.bool()
    else:
        combined_mask = torch.ones_like(confidence, dtype=torch.bool)
    num_valid = int(combined_mask.sum().item())
    if num_valid == 0:
        logger.warning("No valid pixels for depth loss")
        return {
            "total_loss": torch.tensor(0.0, device=depth_pred.device),
            "num_pixels": torch.tensor(0, device=depth_pred.device),
            "mean_confidence": torch.tensor(0.0, device=depth_pred.device),
        }
    # Per-pixel depth error
    depth_diff = torch.abs(depth_pred - depth_target)
    if relative_error:
        # Relative error: |pred - target| / target (eps guards target == 0)
        depth_error = depth_diff / (depth_target + 1e-6)
    else:
        depth_error = depth_diff
    # Zero out invalid pixels *before* weighting. Invalid entries often hold
    # inf/NaN (e.g. missing LiDAR returns); masking by multiplication alone
    # fails because NaN * 0 == NaN and would poison the sum.
    depth_error = torch.where(combined_mask, depth_error, torch.zeros_like(depth_error))
    if loss_type != "l1":  # l2
        depth_error = depth_error**2
    # Weight by continuous confidence (not binary rejection)
    if use_uncertainty_weighting:
        # Higher confidence = more weight; normalize by weight sum so the
        # result is a confidence-weighted average over valid pixels.
        weights = confidence * combined_mask.float()
        total_loss = (depth_error * weights).sum() / (weights.sum() + 1e-6)
    else:
        # Uniform weighting over valid pixels (errors outside the mask are
        # already zeroed above).
        total_loss = depth_error.sum() / (num_valid + 1e-6)
    return {
        "total_loss": total_loss,
        "num_pixels": torch.tensor(num_valid, device=depth_pred.device),
        "mean_confidence": confidence[combined_mask].mean(),
    }
def oracle_uncertainty_ensemble_loss(
    da3_output: Dict[str, torch.Tensor],
    oracle_targets: Dict[str, torch.Tensor],
    uncertainty_results: Dict[str, torch.Tensor],
    loss_weights: Optional[Dict[str, float]] = None,
    use_uncertainty_weighting: bool = True,
) -> Dict[str, torch.Tensor]:
    """
    Combined loss using continuous oracle uncertainty propagation.

    Uses continuous confidence scores rather than binary rejection.

    Args:
        da3_output: DA3 predictions dict with:
            - 'poses': (N, 3, 4) predicted poses w2c
            - 'depth': (N, H, W) predicted depth maps
        oracle_targets: Oracle target values dict with:
            - 'poses': (N, 3, 4) target poses w2c
            - 'depth': (N, H, W) target depth maps (LiDAR or BA)
        uncertainty_results: Uncertainty propagation results dict with:
            - 'pose_confidence': (N,) frame-level confidence [0.0-1.0]
            - 'depth_confidence': (N, H, W) pixel-level confidence [0.0-1.0]
            - 'pose_uncertainty': (N, 6) optional pose uncertainty
            - 'depth_uncertainty': (N, H, W) optional depth uncertainty
        loss_weights: Optional weights for each loss component; missing keys
            default to 1.0 (a partial dict no longer raises KeyError)
        use_uncertainty_weighting: If True, weight by confidence; if False, use uniform

    Returns:
        Dict with all loss components and statistics
    """
    if loss_weights is None:
        loss_weights = {}
    # .get with default 1.0 keeps the old behavior for the None default while
    # tolerating partially-specified weight dicts.
    pose_weight = loss_weights.get("pose", 1.0)
    depth_weight = loss_weights.get("depth", 1.0)

    def _fallback_device() -> torch.device:
        """Device for zero-valued fallback losses.

        The original code read da3_output["depth"].device unconditionally in
        both `else` branches — exactly the situations where that key may be
        missing — raising KeyError. Use the first tensor found in any input
        dict instead, defaulting to CPU.
        """
        for d in (da3_output, oracle_targets, uncertainty_results):
            for v in d.values():
                if torch.is_tensor(v):
                    return v.device
        return torch.device("cpu")

    results: Dict[str, torch.Tensor] = {}
    # Pose loss (weighted by continuous confidence)
    if "poses" in da3_output and "poses" in oracle_targets:
        pose_loss_dict = oracle_uncertainty_weighted_pose_loss(
            da3_output["poses"],
            oracle_targets["poses"],
            uncertainty_results["pose_confidence"],
            uncertainty=uncertainty_results.get("pose_uncertainty"),
            use_uncertainty_weighting=use_uncertainty_weighting,
        )
        results.update({f"pose_{k}": v for k, v in pose_loss_dict.items()})
        results["pose_loss"] = pose_loss_dict["total_loss"] * pose_weight
    else:
        results["pose_loss"] = torch.tensor(0.0, device=_fallback_device())
    # Depth loss (weighted by continuous confidence)
    if "depth" in da3_output and "depth" in oracle_targets:
        # Valid pixels: finite and strictly positive in both prediction and target
        valid_depth = (
            torch.isfinite(oracle_targets["depth"])
            & (oracle_targets["depth"] > 0)
            & torch.isfinite(da3_output["depth"])
            & (da3_output["depth"] > 0)
        )
        depth_loss_dict = oracle_uncertainty_weighted_depth_loss(
            da3_output["depth"],
            oracle_targets["depth"],
            uncertainty_results["depth_confidence"],
            uncertainty=uncertainty_results.get("depth_uncertainty"),
            valid_mask=valid_depth,
            relative_error=True,
            use_uncertainty_weighting=use_uncertainty_weighting,
        )
        results.update({f"depth_{k}": v for k, v in depth_loss_dict.items()})
        results["depth_loss"] = depth_loss_dict["total_loss"] * depth_weight
    else:
        results["depth_loss"] = torch.tensor(0.0, device=_fallback_device())
    # Total loss
    results["total_loss"] = results["pose_loss"] + results["depth_loss"]
    # Confidence statistics (for logging/monitoring)
    for key, prefix in (("pose_confidence", "pose"), ("depth_confidence", "depth")):
        if key in uncertainty_results:
            conf = uncertainty_results[key]
            results[f"mean_{prefix}_confidence"] = conf.mean()
            results[f"min_{prefix}_confidence"] = conf.min()
            results[f"max_{prefix}_confidence"] = conf.max()
    return results