3d_model/docs/ATTENTION_AND_ACTIVATIONS.md

Attention Mechanisms & Activation Functions

Current State

Attention Mechanisms in DA3

DA3 uses DinoV2 Vision Transformer with custom attention:

  1. Alternating Local/Global Attention

    • Local attention (layers < alt_start): Process each view independently

      # Flatten batch and sequence: [B, S, N, C] -> [(B*S), N, C]
      x = rearrange(x, "b s n c -> (b s) n c")
      x = block(x, pos=pos)  # Process independently
      x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s)
      
    • Global attention (layers β‰₯ alt_start, odd): Cross-view attention

      # Concatenate all views: [B, S, N, C] -> [B, (S*N), C]
      x = rearrange(x, "b s n c -> b (s n) c")
      x = block(x, pos=pos)  # Process all views together
      
  2. Additional Features:

    • RoPE (Rotary Position Embedding): Better spatial understanding
    • QK Normalization: Stabilizes training
    • Multi-head attention: Standard transformer attention

Configuration:

  • DA3-Large: alt_start: 8 (layers 0-7 local, then alternating)
  • DA3-Giant: alt_start: 13
  • DA3Metric-Large: alt_start: -1 (disabled, all local)

Activation Functions in DA3

Output activations (not hidden layer activations):

  1. Depth: exp (exponential)

    depth = exp(logits)  # Range: (0, +∞)
    
  2. Confidence: expp1 (exponential + 1)

    confidence = exp(logits) + 1  # Range: (1, +∞)
    
  3. Ray: linear (no activation)

    ray = logits  # Range: (-∞, +∞)
    

Note: Hidden layer activations (ReLU, GELU, SiLU, etc.) are in the DinoV2 backbone, which we don't control.
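Taken together, the three heads can be applied in one place. A minimal sketch; `apply_output_activations` is a hypothetical helper name, not a DA3 API.

```python
import torch

def apply_output_activations(depth_logits, conf_logits, ray_logits):
    """Map raw head outputs to the ranges listed above."""
    depth = torch.exp(depth_logits)          # exp: (0, +inf)
    confidence = torch.exp(conf_logits) + 1  # expp1: (1, +inf)
    rays = ray_logits                        # linear: (-inf, +inf)
    return depth, confidence, rays
```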

What We Control

βœ… What We Can Modify

  1. Loss Functions (ylff/utils/oracle_losses.py)

    • Custom loss weighting
    • Uncertainty propagation
    • Confidence-based weighting
  2. Training Pipeline (ylff/services/pretrain.py, ylff/services/fine_tune.py)

    • Training loop
    • Data loading
    • Optimization strategies
  3. Preprocessing (ylff/services/preprocessing.py)

    • Oracle uncertainty computation
    • Data augmentation
    • Sequence processing
  4. FlashAttention Wrapper (ylff/utils/flash_attention.py)

    • Utility exists but requires model code access to integrate

❌ What We Cannot Modify (Without Model Code Access)

  1. Model Architecture (DinoV2 backbone)

    • Attention mechanisms (local/global alternating)
    • Hidden layer activations
    • Transformer blocks
  2. Output Activations (depth, confidence, ray)

    • These are part of the DA3 model definition

Implementing Custom Approaches

Option 1: Custom Attention Wrapper (Requires Model Access)

If you have access to the DA3 model code, you can:

  1. Replace Attention Layers

    # Custom attention mechanism
    class CustomAttention(nn.Module):
        def __init__(self, dim, num_heads):
            super().__init__()
            self.attention = YourCustomAttention(dim, num_heads)
    
        def forward(self, x):
            return self.attention(x)
    
    # Replace in model
    model.encoder.layers[8].attn = CustomAttention(...)
    
  2. Modify Alternating Pattern

    # Change when global attention starts
    model.dinov2.alt_start = 10  # Start global attention later
    
  3. Add Custom Position Embeddings

    # Replace RoPE with your own
    model.dinov2.rope = YourCustomPositionEmbedding(...)
    

Option 2: Post-Processing with Custom Logic

You can add custom logic after model inference:

  1. Custom Confidence Computation

    # In ylff/utils/oracle_uncertainty.py
    def compute_custom_confidence(da3_output, oracle_data):
        # Your custom confidence computation
        custom_conf = your_confidence_function(da3_output, oracle_data)
        return custom_conf
    
  2. Custom Attention-Based Fusion

    # Add attention-based fusion of multiple views
    class AttentionFusion(nn.Module):
        def __init__(self, dim, num_heads=8):
            super().__init__()
            self.cross_attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        def forward(self, features_list):
            # Query with the reference view, attend over the remaining views
            ref = features_list[0]                     # [B, N, C]
            ctx = torch.cat(features_list[1:], dim=1)  # [B, (S-1)*N, C]
            fused, _ = self.cross_attention(ref, ctx, ctx)
            return fused
    

Option 3: Custom Activation Functions (Output Layer)

If you modify the model, you can change output activations:

  1. Custom Depth Activation

    # Instead of exp, use your activation
    def custom_depth_activation(logits):
        # Your custom function
        return your_function(logits)
    
  2. Custom Confidence Activation

    # Instead of expp1, use your activation
    def custom_confidence_activation(logits):
        # Your custom function
        return your_function(logits)
    

Recommended Approach

For Custom Attention

  1. If you have model code access:

    • Modify src/depth_anything_3/model/dinov2/vision_transformer.py
    • Replace attention blocks with your custom implementation
    • Test with small models first
  2. If you don't have model code access:

    • Use post-processing attention (Option 2)
    • Add attention-based fusion layers after model inference
    • Implement in ylff/utils/oracle_uncertainty.py or new utility

For Custom Activations

  1. Output activations:

    • Modify model code if available
    • Or add post-processing to transform outputs
  2. Hidden activations:

    • Requires model code access
    • Or create a wrapper model that processes features

Example: Custom Cross-View Attention

# ylff/utils/custom_attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomCrossViewAttention(nn.Module):
    """
    Custom attention mechanism for multi-view depth estimation.

    This can be used as a post-processing step or integrated into the model.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, features_list):
        """
        Args:
            features_list: List of per-view feature tensors,
                           each [B, N, C] where N is the number of spatial tokens.

        Returns:
            Fused features: [B, N, C]
        """
        # Stack views and flatten into one token sequence: [B, S*N, C]
        x = torch.stack(features_list, dim=1)
        B, S, N, C = x.shape
        x = x.reshape(B, S * N, C)

        # Project and split into heads: [B, num_heads, S*N, head_dim]
        q = self.q_proj(x).view(B, S * N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S * N, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S * N, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention over tokens from ALL views
        scale = self.head_dim ** -0.5
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # [B, H, S*N, S*N]
        attn_weights = F.softmax(scores, dim=-1)
        out = torch.matmul(attn_weights, v)  # [B, H, S*N, head_dim]

        # Merge heads and project back: [B, S*N, C]
        out = out.transpose(1, 2).reshape(B, S * N, C)
        out = self.out_proj(out)

        # Average across views (or select a reference view): [B, N, C]
        out = out.view(B, S, N, C).mean(dim=1)

        return out

Example: Custom Activation Functions

# ylff/utils/custom_activations.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwishDepthActivation(nn.Module):
    """Swish (SiLU) activation for depth: smooth, with a hard positive floor."""

    def forward(self, logits):
        # Swish: x * sigmoid(x); slightly negative for negative inputs
        depth = logits * torch.sigmoid(logits)
        # Floor the output so depth stays strictly positive
        depth = F.relu(depth) + 0.1  # Minimum depth of 0.1
        return depth

class SoftplusConfidenceActivation(nn.Module):
    """Softplus activation for confidence: smooth and always greater than 1."""

    def forward(self, logits):
        # Softplus: log(1 + exp(x)), strictly positive
        confidence = F.softplus(logits) + 1.0  # Minimum confidence of 1
        return confidence

class ClampedRayActivation(nn.Module):
    """Tanh-bounded activation for rays (keeps directions in a fixed range)."""

    def forward(self, logits):
        # Bound smoothly with tanh, then scale to (-10, 10)
        rays = torch.tanh(logits) * 10.0
        return rays

Next Steps

  1. Decide what you want to customize:

    • Attention mechanism?
    • Activation functions?
    • Both?
  2. Check model code access:

    • Do you have access to src/depth_anything_3/model/?
    • Or do you need post-processing approaches?
  3. Implement incrementally:

    • Start with post-processing (easier)
    • Move to model modifications if needed
    • Test on small datasets first
  4. Integrate with training:

    • Add to ylff/services/pretrain.py or ylff/services/fine_tune.py
    • Update loss functions if needed
    • Add CLI/API options
