| """ |
| Uncertainty Output Head: Predicts per-pixel depth uncertainty and per-frame pose uncertainty. |
| |
| This head can be added to DA3 models to output uncertainty estimates alongside |
| depth and pose predictions, enabling uncertainty-aware training and inference. |
| """ |
|
|
| import logging |
| from typing import Dict, Optional |
| import torch |
| import torch.nn as nn |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class DepthUncertaintyHead(nn.Module):
    """
    Output head that predicts depth with per-pixel uncertainty.

    Outputs:
        - depth: [B, H, W] predicted depth in meters
        - uncertainty: [B, H, W] predicted uncertainty (std) in meters
        - confidence: [B, H, W] confidence score [0, 1] (derived from uncertainty)
    """

    def __init__(
        self,
        in_dim: int,
        min_depth: float = 0.1,
        max_depth: float = 100.0,
        min_uncertainty: float = 0.01,
        max_uncertainty: float = 10.0,
        use_shared_features: bool = True,
    ):
        """
        Args:
            in_dim: Input feature dimension
            min_depth: Minimum depth in meters
            max_depth: Maximum depth in meters
            min_uncertainty: Minimum uncertainty in meters (std)
            max_uncertainty: Maximum uncertainty in meters (std)
            use_shared_features: If True, share features between depth and uncertainty
        """
        super().__init__()
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.min_uncertainty = min_uncertainty
        self.max_uncertainty = max_uncertainty

        # Optional shared trunk that compresses the input features before the
        # depth and uncertainty branches split off.
        if use_shared_features:
            half = in_dim // 2
            self.shared_conv = nn.Sequential(
                nn.Conv2d(in_dim, half, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(half, half // 2, 3, padding=1),
                nn.ReLU(inplace=True),
            )
            branch_dim = half // 2
        else:
            self.shared_conv = None
            branch_dim = in_dim

        # Depth branch: final ReLU keeps the raw (pre-scaling) output non-negative.
        self.depth_head = nn.Sequential(
            nn.Conv2d(branch_dim, branch_dim // 2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_dim // 2, 1, 1),
            nn.ReLU(inplace=True),
        )

        # Uncertainty branch: Softplus keeps the raw std strictly positive.
        self.uncertainty_head = nn.Sequential(
            nn.Conv2d(branch_dim, branch_dim // 2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_dim // 2, 1, 1),
            nn.Softplus(),
        )

    def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            features: [B, C, H, W] feature map

        Returns:
            Dict with:
                - 'depth': [B, H, W] depth in meters
                - 'uncertainty': [B, H, W] uncertainty (std) in meters
                - 'confidence': [B, H, W] confidence [0, 1]
        """
        trunk_out = features if self.shared_conv is None else self.shared_conv(features)

        # Scale the raw depth output into metric range, then clamp to
        # [min_depth, max_depth].
        raw_depth = self.depth_head(trunk_out).squeeze(1)
        span = self.max_depth - self.min_depth
        depth = (raw_depth * span + self.min_depth).clamp(self.min_depth, self.max_depth)

        # Clamp the predicted std to the configured range.
        raw_uncertainty = self.uncertainty_head(trunk_out).squeeze(1)
        uncertainty = raw_uncertainty.clamp(self.min_uncertainty, self.max_uncertainty)

        # Monotonic map of uncertainty into (0, 1]: zero std -> confidence 1.
        confidence = 1.0 / (1.0 + uncertainty)

        return {
            "depth": depth,
            "uncertainty": uncertainty,
            "confidence": confidence,
        }
|
|
|
|
class PoseUncertaintyHead(nn.Module):
    """
    Output head that predicts pose with per-frame uncertainty.

    Outputs:
        - pose: [B, N, 3, 4] predicted pose (w2c)
        - uncertainty: [B, N, 6] predicted pose uncertainty (6D: 3 rot + 3 trans)
        - confidence: [B, N] frame-level confidence [0, 1]
    """

    def __init__(
        self,
        in_dim: int,
        min_rot_uncertainty: float = 0.001,
        max_rot_uncertainty: float = 0.175,
        min_trans_uncertainty: float = 0.001,
        max_trans_uncertainty: float = 1.0,
    ):
        """
        Args:
            in_dim: Input feature dimension (typically 2*C from concatenated features)
            min_rot_uncertainty: Minimum rotation uncertainty in radians
            max_rot_uncertainty: Maximum rotation uncertainty in radians
            min_trans_uncertainty: Minimum translation uncertainty in meters
            max_trans_uncertainty: Maximum translation uncertainty in meters
        """
        super().__init__()
        self.min_rot_uncertainty = min_rot_uncertainty
        self.max_rot_uncertainty = max_rot_uncertainty
        self.min_trans_uncertainty = min_trans_uncertainty
        self.max_trans_uncertainty = max_trans_uncertainty

        # Pose branch: 6 params per frame (3 axis-angle rotation + 3 translation).
        self.pose_head = nn.Sequential(
            nn.Linear(in_dim, in_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim // 2, 6),
        )

        # Uncertainty branch: Softplus keeps all 6 stds strictly positive.
        self.uncertainty_head = nn.Sequential(
            nn.Linear(in_dim, in_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim // 2, 6),
            nn.Softplus(),
        )

    def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            features: [B, N, C] feature vectors (one per frame)

        Returns:
            Dict with:
                - 'pose': [B, N, 3, 4] pose (w2c)
                - 'uncertainty': [B, N, 6] pose uncertainty (3 rot + 3 trans)
                - 'confidence': [B, N] frame-level confidence [0, 1]
        """
        B, N, C = features.shape

        # Predict 6-DoF pose: first 3 channels axis-angle rotation, last 3 translation.
        pose_params = self.pose_head(features)
        rot_params = pose_params[:, :, :3]
        trans_params = pose_params[:, :, 3:]

        rot_matrices = self._axis_angle_to_rotation_matrix(rot_params)

        # Assemble [R | t] -> [B, N, 3, 4].
        poses = torch.cat([rot_matrices, trans_params.unsqueeze(-1)], dim=-1)

        # Predict per-frame uncertainty, clamped separately for rotation (radians)
        # and translation (meters) since they live on different scales.
        uncertainty_params = self.uncertainty_head(features)
        rot_uncertainty = uncertainty_params[:, :, :3]
        trans_uncertainty = uncertainty_params[:, :, 3:]

        rot_uncertainty = torch.clamp(
            rot_uncertainty,
            min=self.min_rot_uncertainty,
            max=self.max_rot_uncertainty,
        )
        trans_uncertainty = torch.clamp(
            trans_uncertainty,
            min=self.min_trans_uncertainty,
            max=self.max_trans_uncertainty,
        )

        uncertainty = torch.cat([rot_uncertainty, trans_uncertainty], dim=-1)

        # Frame confidence: geometric mean of the mean rot/trans stds, mapped
        # monotonically into (0, 1] (same 1/(1+u) map as the depth head).
        rot_uncertainty_mean = rot_uncertainty.mean(dim=-1)
        trans_uncertainty_mean = trans_uncertainty.mean(dim=-1)
        combined_uncertainty = (rot_uncertainty_mean * trans_uncertainty_mean) ** 0.5

        confidence = 1.0 / (1.0 + combined_uncertainty)

        return {
            "pose": poses,
            "uncertainty": uncertainty,
            "confidence": confidence,
        }

    def _axis_angle_to_rotation_matrix(self, axis_angle: torch.Tensor) -> torch.Tensor:
        """
        Convert axis-angle representation to rotation matrix using Rodrigues' formula.

        Args:
            axis_angle: [B, N, 3] axis-angle representation

        Returns:
            rotation_matrix: [B, N, 3, 3] rotation matrices
        """
        B, N, _ = axis_angle.shape
        device = axis_angle.device
        dtype = axis_angle.dtype

        # Rotation angle is the vector norm; clamp away from zero so the axis
        # normalization below stays well-defined for (near-)zero rotations.
        angle = torch.norm(axis_angle, dim=-1, keepdim=True)
        angle = torch.clamp(angle, min=1e-8)

        axis = axis_angle / angle

        cos_angle = torch.cos(angle)
        sin_angle = torch.sin(angle)

        # Skew-symmetric cross-product matrix K for each axis.
        # Fix: allocate with the input's dtype — the original used the default
        # float32, silently down-casting float64 (or half) inputs through K.
        K = torch.zeros(B, N, 3, 3, device=device, dtype=dtype)
        K[:, :, 0, 1] = -axis[:, :, 2]
        K[:, :, 0, 2] = axis[:, :, 1]
        K[:, :, 1, 0] = axis[:, :, 2]
        K[:, :, 1, 2] = -axis[:, :, 0]
        K[:, :, 2, 0] = -axis[:, :, 1]
        K[:, :, 2, 1] = axis[:, :, 0]

        I = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1)
        K_squared = torch.matmul(K, K)

        # Rodrigues' formula: R = I + sin(theta) * K + (1 - cos(theta)) * K^2.
        R = I + sin_angle.unsqueeze(-1) * K + (1.0 - cos_angle).unsqueeze(-1) * K_squared

        return R
|
|
|
|
class UncertaintyAwareDA3Wrapper(nn.Module):
    """
    Wrapper that adds uncertainty prediction to DA3 model.

    This wraps the existing DA3 model and adds uncertainty heads for depth and pose.
    The uncertainty heads take features from the DA3 model and predict uncertainty.

    NOTE(review): in the current forward() the uncertainty heads are constructed
    but never invoked — all uncertainty/confidence outputs are returned as None.
    Wiring the heads requires access to intermediate DA3 features, which this
    wrapper does not yet extract.
    """

    def __init__(
        self,
        da3_model: nn.Module,
        depth_uncertainty_head: Optional[DepthUncertaintyHead] = None,
        pose_uncertainty_head: Optional[PoseUncertaintyHead] = None,
        freeze_base_model: bool = False,
    ):
        """
        Args:
            da3_model: Base DA3 model
            depth_uncertainty_head: Optional depth uncertainty head (auto-created if None)
            pose_uncertainty_head: Optional pose uncertainty head (auto-created if None)
            freeze_base_model: If True, freeze base model weights (only train uncertainty heads)
        """
        super().__init__()
        self.da3_model = da3_model
        self.freeze_base_model = freeze_base_model

        # Freezing only disables gradients; the base model still runs in forward().
        if freeze_base_model:
            for param in self.da3_model.parameters():
                param.requires_grad = False

        # NOTE(review): in_dim=1024 / in_dim=2048 below appear to be assumed DA3
        # feature widths — TODO confirm against the actual backbone's feature
        # dimension before relying on the auto-created heads.
        if depth_uncertainty_head is None:
            self.depth_uncertainty_head = DepthUncertaintyHead(in_dim=1024)
        else:
            self.depth_uncertainty_head = depth_uncertainty_head

        if pose_uncertainty_head is None:
            self.pose_uncertainty_head = PoseUncertaintyHead(in_dim=2048)
        else:
            self.pose_uncertainty_head = pose_uncertainty_head

    def forward(self, images: list, extract_features: bool = False) -> Dict[str, torch.Tensor]:
        """
        Forward pass with uncertainty prediction.

        Args:
            images: List of input images
            extract_features: If True, also return intermediate features

        Returns:
            Dict with:
                - 'depth': [N, H, W] depth maps
                - 'depth_uncertainty': [N, H, W] depth uncertainty
                - 'depth_confidence': [N, H, W] depth confidence
                - 'poses': [N, 3, 4] camera poses
                - 'pose_uncertainty': [N, 6] pose uncertainty
                - 'pose_confidence': [N] pose confidence
                - 'features': (optional) intermediate features if extract_features=True

        NOTE(review): despite the docstring above, all uncertainty/confidence
        entries (and 'features') are currently None placeholders — the heads are
        never called because no intermediate features are extracted from the
        base model.
        """
        # Delegate to the base model; assumes da3_model exposes an inference()
        # method whose result has .depth and .extrinsics attributes — TODO
        # confirm against the DA3 model's actual API.
        da3_output = self.da3_model.inference(images)

        depth = da3_output.depth
        poses = da3_output.extrinsics

        # Placeholder outputs: uncertainty heads are not wired in yet.
        depth_uncertainty = None
        depth_confidence = None
        pose_uncertainty = None
        pose_confidence = None

        result = {
            "depth": depth,
            "poses": poses,
            "depth_uncertainty": depth_uncertainty,
            "depth_confidence": depth_confidence,
            "pose_uncertainty": pose_uncertainty,
            "pose_confidence": pose_confidence,
        }

        if extract_features:
            # Placeholder as well — no feature extraction path exists yet.
            result["features"] = None

        return result
|
|
|
|
def create_uncertainty_head_from_features(
    feature_dim: int,
    head_type: str = "depth",
    **kwargs,
) -> nn.Module:
    """
    Create uncertainty head with specified feature dimension.

    Args:
        feature_dim: Input feature dimension
        head_type: 'depth' or 'pose'
        **kwargs: Additional arguments for head initialization

    Returns:
        Uncertainty head module
    """
    # Guard-clause dispatch on the requested head flavor.
    if head_type == "depth":
        return DepthUncertaintyHead(in_dim=feature_dim, **kwargs)
    if head_type == "pose":
        return PoseUncertaintyHead(in_dim=feature_dim, **kwargs)
    raise ValueError(f"Unknown head type: {head_type}")
|
|
|
|
def uncertainty_prediction_loss(
    uncertainty_pred: torch.Tensor,
    uncertainty_target: torch.Tensor,
    confidence_target: Optional[torch.Tensor] = None,
    loss_type: str = "l1",
) -> torch.Tensor:
    """
    Loss function for training uncertainty prediction.

    Args:
        uncertainty_pred: Predicted uncertainty
        uncertainty_target: Target uncertainty (from oracle)
        confidence_target: Optional target confidence (alternative to uncertainty);
            when given it overrides uncertainty_target via u = 1/c - 1, the
            inverse of the confidence map c = 1/(1 + u) used by the heads
        loss_type: 'l1' for absolute error; any other value uses squared (L2) error

    Returns:
        Scalar uncertainty prediction loss (zero if no target is valid)
    """
    if confidence_target is not None:
        # Invert confidence = 1 / (1 + u) => u = 1 / c - 1 (eps guards c == 0).
        uncertainty_target = 1.0 / (confidence_target + 1e-8) - 1.0

    # Ignore non-positive or absurdly large targets (degenerate oracle values).
    valid_mask = (uncertainty_target > 0) & (uncertainty_target < 100.0)

    if valid_mask.sum() == 0:
        # Graph-connected zero: preserves dtype/device and keeps autograd happy
        # when this is the only loss term. The original returned a detached
        # constant, which makes backward() fail in that case.
        return uncertainty_pred.sum() * 0.0

    diff = uncertainty_pred[valid_mask] - uncertainty_target[valid_mask]
    if loss_type == "l1":
        error = diff.abs()
    else:
        # Any non-'l1' value falls back to squared error (original behavior).
        error = diff.pow(2)

    return error.mean()
|
|