# 3d_model/ylff/utils/uncertainty_head.py
# Author: Azan
# Clean deployment build (squashed), commit 7a87926
"""
Uncertainty Output Head: Predicts per-pixel depth uncertainty and per-frame pose uncertainty.
This head can be added to DA3 models to output uncertainty estimates alongside
depth and pose predictions, enabling uncertainty-aware training and inference.
"""
import logging
from typing import Dict, Optional
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
class DepthUncertaintyHead(nn.Module):
    """
    Output head that predicts depth with per-pixel uncertainty.

    Outputs:
        - depth: [B, H, W] predicted depth in meters
        - uncertainty: [B, H, W] predicted uncertainty (std) in meters
        - confidence: [B, H, W] confidence score [0, 1] (derived from uncertainty)
    """

    def __init__(
        self,
        in_dim: int,
        min_depth: float = 0.1,
        max_depth: float = 100.0,
        min_uncertainty: float = 0.01,
        max_uncertainty: float = 10.0,
        use_shared_features: bool = True,
    ):
        """
        Args:
            in_dim: Input feature dimension
            min_depth: Minimum depth in meters
            max_depth: Maximum depth in meters
            min_uncertainty: Minimum uncertainty in meters (std)
            max_uncertainty: Maximum uncertainty in meters (std)
            use_shared_features: If True, share features between depth and uncertainty
        """
        super().__init__()
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.min_uncertainty = min_uncertainty
        self.max_uncertainty = max_uncertainty
        if use_shared_features:
            # Shared feature extraction trunk used by both prediction heads.
            self.shared_conv = nn.Sequential(
                nn.Conv2d(in_dim, in_dim // 2, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_dim // 2, in_dim // 4, 3, padding=1),
                nn.ReLU(inplace=True),
            )
            shared_dim = in_dim // 4
        else:
            self.shared_conv = None
            shared_dim = in_dim
        # Depth prediction head.
        # BUGFIX: the final activation was ReLU, whose unbounded non-negative
        # output was then scaled by (max_depth - min_depth) in forward() as if
        # it were normalized to [0, 1]. Predictions saturated against the clamp
        # bounds and the clamp killed gradients. Sigmoid bounds the output to
        # (0, 1), which the affine rescaling in forward() assumes.
        self.depth_head = nn.Sequential(
            nn.Conv2d(shared_dim, shared_dim // 2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(shared_dim // 2, 1, 1),
            nn.Sigmoid(),  # Normalized (0, 1) output, rescaled to metric depth in forward()
        )
        # Uncertainty prediction head.
        self.uncertainty_head = nn.Sequential(
            nn.Conv2d(shared_dim, shared_dim // 2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(shared_dim // 2, 1, 1),
            nn.Softplus(),  # Ensure positive uncertainty
        )

    def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            features: [B, C, H, W] feature map
        Returns:
            Dict with:
                - 'depth': [B, H, W] depth in meters
                - 'uncertainty': [B, H, W] uncertainty (std) in meters
                - 'confidence': [B, H, W] confidence [0, 1]
        """
        # Extract shared features if enabled
        if self.shared_conv is not None:
            shared_features = self.shared_conv(features)
        else:
            shared_features = features
        # Predict depth (absolute scale in meters). The head's sigmoid output
        # lies in (0, 1), so this affine map lands in (min_depth, max_depth).
        depth_logits = self.depth_head(shared_features)  # [B, 1, H, W]
        depth = depth_logits.squeeze(1) * (self.max_depth - self.min_depth) + self.min_depth
        # Clamp kept as a numerical safety net (sigmoid already bounds the range).
        depth = torch.clamp(depth, min=self.min_depth, max=self.max_depth)
        # Predict uncertainty (std in meters); softplus guarantees positivity,
        # clamp restricts it to the configured range.
        uncertainty_logits = self.uncertainty_head(shared_features)  # [B, 1, H, W]
        uncertainty = uncertainty_logits.squeeze(1)
        uncertainty = torch.clamp(uncertainty, min=self.min_uncertainty, max=self.max_uncertainty)
        # Derive confidence from uncertainty.
        # Higher uncertainty → lower confidence, via conf = 1 / (1 + uncertainty),
        # which maps uncertainty in [0, inf) to confidence in (0, 1].
        confidence = 1.0 / (1.0 + uncertainty)
        return {
            "depth": depth,
            "uncertainty": uncertainty,
            "confidence": confidence,
        }
class PoseUncertaintyHead(nn.Module):
    """
    Output head that predicts pose with per-frame uncertainty.

    Outputs:
        - pose: [B, N, 3, 4] predicted pose (w2c)
        - uncertainty: [B, N, 6] predicted pose uncertainty (6D: 3 rot + 3 trans)
        - confidence: [B, N] frame-level confidence [0, 1]
    """

    def __init__(
        self,
        in_dim: int,
        min_rot_uncertainty: float = 0.001,  # radians (~0.06 degrees)
        max_rot_uncertainty: float = 0.175,  # radians (~10 degrees)
        min_trans_uncertainty: float = 0.001,  # meters
        max_trans_uncertainty: float = 1.0,  # meters
    ):
        """
        Args:
            in_dim: Input feature dimension (typically 2*C from concatenated features)
            min_rot_uncertainty: Minimum rotation uncertainty in radians
            max_rot_uncertainty: Maximum rotation uncertainty in radians
            min_trans_uncertainty: Minimum translation uncertainty in meters
            max_trans_uncertainty: Maximum translation uncertainty in meters
        """
        super().__init__()
        self.min_rot_uncertainty = min_rot_uncertainty
        self.max_rot_uncertainty = max_rot_uncertainty
        self.min_trans_uncertainty = min_trans_uncertainty
        self.max_trans_uncertainty = max_trans_uncertainty
        # Pose prediction (rotation + translation)
        self.pose_head = nn.Sequential(
            nn.Linear(in_dim, in_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim // 2, 6),  # 3 rot (axis-angle) + 3 trans
        )
        # Uncertainty prediction (6D: 3 rot + 3 trans)
        self.uncertainty_head = nn.Sequential(
            nn.Linear(in_dim, in_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim // 2, 6),  # 6D uncertainty
            nn.Softplus(),  # Ensure positive uncertainty
        )

    def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            features: [B, N, C] feature vectors (one per frame)
        Returns:
            Dict with:
                - 'pose': [B, N, 3, 4] pose (w2c)
                - 'uncertainty': [B, N, 6] pose uncertainty (3 rot + 3 trans)
                - 'confidence': [B, N] frame-level confidence [0, 1]
        """
        B, N, C = features.shape
        # Predict pose (axis-angle rotation + translation)
        pose_params = self.pose_head(features)  # [B, N, 6]
        rot_params = pose_params[:, :, :3]  # [B, N, 3] axis-angle
        trans_params = pose_params[:, :, 3:]  # [B, N, 3] translation
        # Convert axis-angle to rotation matrix
        rot_matrices = self._axis_angle_to_rotation_matrix(rot_params)  # [B, N, 3, 3]
        # Combine into pose matrix [B, N, 3, 4]
        poses = torch.cat([rot_matrices, trans_params.unsqueeze(-1)], dim=-1)
        # Predict uncertainty
        uncertainty_params = self.uncertainty_head(features)  # [B, N, 6]
        rot_uncertainty = uncertainty_params[:, :, :3]  # [B, N, 3]
        trans_uncertainty = uncertainty_params[:, :, 3:]  # [B, N, 3]
        # Clamp uncertainty to reasonable ranges
        rot_uncertainty = torch.clamp(
            rot_uncertainty,
            min=self.min_rot_uncertainty,
            max=self.max_rot_uncertainty,
        )
        trans_uncertainty = torch.clamp(
            trans_uncertainty,
            min=self.min_trans_uncertainty,
            max=self.max_trans_uncertainty,
        )
        # Combine into 6D uncertainty
        uncertainty = torch.cat([rot_uncertainty, trans_uncertainty], dim=-1)  # [B, N, 6]
        # Derive confidence from uncertainty.
        # Geometric mean of mean rotation / translation uncertainties balances
        # the two components despite their different units (radians vs meters).
        rot_uncertainty_mean = rot_uncertainty.mean(dim=-1)  # [B, N]
        trans_uncertainty_mean = trans_uncertainty.mean(dim=-1)  # [B, N]
        combined_uncertainty = (rot_uncertainty_mean * trans_uncertainty_mean) ** 0.5
        # Convert to confidence: conf = 1 / (1 + uncertainty)
        confidence = 1.0 / (1.0 + combined_uncertainty)
        return {
            "pose": poses,
            "uncertainty": uncertainty,
            "confidence": confidence,
        }

    def _axis_angle_to_rotation_matrix(self, axis_angle: torch.Tensor) -> torch.Tensor:
        """
        Convert axis-angle representation to rotation matrix using Rodrigues' formula.

        Args:
            axis_angle: [B, N, 3] axis-angle representation
        Returns:
            rotation_matrix: [B, N, 3, 3] rotation matrices (same dtype/device as input)
        """
        B, N, _ = axis_angle.shape
        device = axis_angle.device
        dtype = axis_angle.dtype
        # Compute angle and axis
        angle = torch.norm(axis_angle, dim=-1, keepdim=True)  # [B, N, 1]
        angle = torch.clamp(angle, min=1e-8)  # Avoid division by zero
        axis = axis_angle / angle  # [B, N, 3]
        # Rodrigues' rotation formula
        cos_angle = torch.cos(angle)  # [B, N, 1]
        sin_angle = torch.sin(angle)  # [B, N, 1]
        # Cross product matrix K = [0, -z, y; z, 0, -x; -y, x, 0]
        # BUGFIX: K and I previously used the default dtype (float32) regardless
        # of the input dtype, which broke half precision (dtype-mismatch in the
        # caller's torch.cat) and silently degraded float64 precision.
        K = torch.zeros(B, N, 3, 3, device=device, dtype=dtype)
        K[:, :, 0, 1] = -axis[:, :, 2]
        K[:, :, 0, 2] = axis[:, :, 1]
        K[:, :, 1, 0] = axis[:, :, 2]
        K[:, :, 1, 2] = -axis[:, :, 0]
        K[:, :, 2, 0] = -axis[:, :, 1]
        K[:, :, 2, 1] = axis[:, :, 0]
        # Rotation matrix: R = I + sin(θ)K + (1 - cos(θ))K²
        I = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1)
        K_squared = torch.matmul(K, K)
        R = I + sin_angle.unsqueeze(-1) * K + (1.0 - cos_angle).unsqueeze(-1) * K_squared
        return R
class UncertaintyAwareDA3Wrapper(nn.Module):
    """
    Wrapper that adds uncertainty prediction to a DA3 model.

    Wraps an existing DA3 model and attaches uncertainty heads for depth and
    pose; the heads consume features from the DA3 model and emit uncertainty.
    """

    def __init__(
        self,
        da3_model: nn.Module,
        depth_uncertainty_head: Optional[DepthUncertaintyHead] = None,
        pose_uncertainty_head: Optional[PoseUncertaintyHead] = None,
        freeze_base_model: bool = False,
    ):
        """
        Args:
            da3_model: Base DA3 model
            depth_uncertainty_head: Optional depth uncertainty head (auto-created if None)
            pose_uncertainty_head: Optional pose uncertainty head (auto-created if None)
            freeze_base_model: If True, freeze base model weights (only train uncertainty heads)
        """
        super().__init__()
        self.da3_model = da3_model
        self.freeze_base_model = freeze_base_model
        if freeze_base_model:
            # Only the uncertainty heads remain trainable.
            for weight in self.da3_model.parameters():
                weight.requires_grad = False
        # Auto-create uncertainty heads when none are supplied. Feature
        # dimensions really depend on the backbone architecture, so these
        # defaults (1024 = ViT-Large, 2048 = concatenated local+global) are
        # placeholders that callers can override.
        self.depth_uncertainty_head = (
            DepthUncertaintyHead(in_dim=1024)
            if depth_uncertainty_head is None
            else depth_uncertainty_head
        )
        self.pose_uncertainty_head = (
            PoseUncertaintyHead(in_dim=2048)
            if pose_uncertainty_head is None
            else pose_uncertainty_head
        )

    def forward(self, images: list, extract_features: bool = False) -> Dict[str, torch.Tensor]:
        """
        Forward pass with uncertainty prediction.

        Args:
            images: List of input images
            extract_features: If True, also return intermediate features
        Returns:
            Dict with:
                - 'depth': [N, H, W] depth maps
                - 'depth_uncertainty': [N, H, W] depth uncertainty
                - 'depth_confidence': [N, H, W] depth confidence
                - 'poses': [N, 3, 4] camera poses
                - 'pose_uncertainty': [N, 6] pose uncertainty
                - 'pose_confidence': [N] pose confidence
                - 'features': (optional) intermediate features if extract_features=True
        """
        # Run the base DA3 model and pull out its depth / pose outputs
        # (numpy arrays of shape [N, H, W] and [N, 3, 4] respectively).
        base_output = self.da3_model.inference(images)
        # NOTE: Full uncertainty prediction needs intermediate features from
        # the DA3 backbone (depth features for the depth head, camera-token or
        # aggregated features for the pose head), which requires access to
        # model internals. Until that plumbing exists, the uncertainty fields
        # are None here and uncertainty is predicted during training, when the
        # features are available.
        result = {
            "depth": base_output.depth,
            "poses": base_output.extrinsics,
            "depth_uncertainty": None,
            "depth_confidence": None,
            "pose_uncertainty": None,
            "pose_confidence": None,
        }
        if extract_features:
            result["features"] = None  # Placeholder
        return result
def create_uncertainty_head_from_features(
    feature_dim: int,
    head_type: str = "depth",
    **kwargs,
) -> nn.Module:
    """
    Create an uncertainty head with the specified feature dimension.

    Args:
        feature_dim: Input feature dimension
        head_type: 'depth' or 'pose'
        **kwargs: Additional arguments for head initialization
    Returns:
        Uncertainty head module
    Raises:
        ValueError: If head_type is neither 'depth' nor 'pose'.
    """
    # Reject unknown head types up front, then dispatch to the matching class.
    if head_type not in ("depth", "pose"):
        raise ValueError(f"Unknown head type: {head_type}")
    if head_type == "depth":
        return DepthUncertaintyHead(in_dim=feature_dim, **kwargs)
    return PoseUncertaintyHead(in_dim=feature_dim, **kwargs)
def uncertainty_prediction_loss(
    uncertainty_pred: torch.Tensor,  # [B, ...] predicted uncertainty
    uncertainty_target: torch.Tensor,  # [B, ...] target uncertainty (from oracle)
    confidence_target: Optional[torch.Tensor] = None,  # [B, ...] target confidence
    loss_type: str = "l1",
) -> torch.Tensor:
    """
    Loss function for training uncertainty prediction.

    Args:
        uncertainty_pred: Predicted uncertainty
        uncertainty_target: Target uncertainty (from oracle); ignored if
            confidence_target is given.
        confidence_target: Optional target confidence (alternative to uncertainty)
        loss_type: 'l1' or 'l2' (any value other than 'l1' falls back to l2,
            matching the original behavior)
    Returns:
        Scalar uncertainty prediction loss (connected to the autograd graph).
    """
    if confidence_target is not None:
        # Invert conf = 1 / (1 + uncertainty): uncertainty = 1 / conf - 1.
        # The epsilon guards against division by zero confidence.
        uncertainty_target = 1.0 / (confidence_target + 1e-8) - 1.0
    # Only supervise where the target is positive and not absurdly large.
    valid_mask = (uncertainty_target > 0) & (uncertainty_target < 100.0)
    if valid_mask.sum() == 0:
        # BUGFIX: previously returned a freshly constructed constant tensor,
        # which is detached from the autograd graph (and uses the default
        # dtype). If this were the only loss term, .backward() would fail.
        # Multiplying the prediction sum by zero yields a graph-connected
        # zero with the correct dtype and device.
        return uncertainty_pred.sum() * 0.0
    diff = uncertainty_pred[valid_mask] - uncertainty_target[valid_mask]
    if loss_type == "l1":
        error = diff.abs()
    else:  # l2
        error = diff ** 2
    return error.mean()