# Attention Mechanisms & Activation Functions

## Current State

### Attention Mechanisms in DA3

**DA3 uses a DinoV2 Vision Transformer with custom attention:**

1. **Alternating Local/Global Attention**
   - **Local attention** (layers < `alt_start`): each view is processed independently
     ```python
     # Flatten batch and sequence: [B, S, N, C] -> [(B*S), N, C]
     x = rearrange(x, "b s n c -> (b s) n c")
     x = block(x, pos=pos)  # Process each view independently
     x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s)
     ```
   - **Global attention** (layers ≥ `alt_start`, odd): cross-view attention
     ```python
     # Concatenate all views: [B, S, N, C] -> [B, (S*N), C]
     x = rearrange(x, "b s n c -> b (s n) c")
     x = block(x, pos=pos)  # Process all views together
     ```

2. **Additional Features:**
   - **RoPE (Rotary Position Embedding)**: better spatial understanding
   - **QK Normalization**: stabilizes training
   - **Multi-head attention**: standard transformer attention

**Configuration:**

- **DA3-Large**: `alt_start: 8` (layers 0-7 local, then alternating)
- **DA3-Giant**: `alt_start: 13`
- **DA3Metric-Large**: `alt_start: -1` (disabled, all local)

### Activation Functions in DA3

**Output activations (not hidden-layer activations):**

1. **Depth**: `exp` (exponential)
   ```python
   depth = exp(logits)  # Range: (0, +∞)
   ```
2. **Confidence**: `expp1` (exponential + 1)
   ```python
   confidence = exp(logits) + 1  # Range: (1, +∞)
   ```
3. **Ray**: `linear` (no activation)
   ```python
   ray = logits  # Range: (-∞, +∞)
   ```

**Note:** Hidden-layer activations (ReLU, GELU, SiLU, etc.) live in the DinoV2 backbone, which we don't control.

## What We Control

### ✅ What We Can Modify

1. **Loss Functions** (`ylff/utils/oracle_losses.py`)
   - Custom loss weighting
   - Uncertainty propagation
   - Confidence-based weighting
2. **Training Pipeline** (`ylff/services/pretrain.py`, `ylff/services/fine_tune.py`)
   - Training loop
   - Data loading
   - Optimization strategies
3. **Preprocessing** (`ylff/services/preprocessing.py`)
   - Oracle uncertainty computation
   - Data augmentation
   - Sequence processing
4. **FlashAttention Wrapper** (`ylff/utils/flash_attention.py`)
   - The utility exists, but integrating it requires access to the model code

### ❌ What We Cannot Modify (Without Model Code Access)

1. **Model Architecture** (DinoV2 backbone)
   - Attention mechanisms (local/global alternating)
   - Hidden-layer activations
   - Transformer blocks
2. **Output Activations** (depth, confidence, ray)
   - These are part of the DA3 model definition

## Implementing Custom Approaches

### Option 1: Custom Attention Wrapper (Requires Model Access)

If you have access to the DA3 model code, you can:

1. **Replace Attention Layers**
   ```python
   # Custom attention mechanism
   class CustomAttention(nn.Module):
       def __init__(self, dim, num_heads):
           super().__init__()
           self.attention = YourCustomAttention(dim, num_heads)

       def forward(self, x):
           return self.attention(x)

   # Replace in the model
   model.encoder.layers[8].attn = CustomAttention(...)
   ```
2. **Modify the Alternating Pattern**
   ```python
   # Change when global attention starts
   model.dinov2.alt_start = 10  # Start global attention later
   ```
3. **Add Custom Position Embeddings**
   ```python
   # Replace RoPE with your own
   model.dinov2.rope = YourCustomPositionEmbedding(...)
   ```

### Option 2: Post-Processing with Custom Logic

You can add custom logic **after** model inference:

1. **Custom Confidence Computation**
   ```python
   # In ylff/utils/oracle_uncertainty.py
   def compute_custom_confidence(da3_output, oracle_data):
       # Your custom confidence computation
       custom_conf = your_confidence_function(da3_output, oracle_data)
       return custom_conf
   ```
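One concrete possibility for such a custom confidence is inverse-variance fusion of the DA3 confidence with the oracle's per-pixel uncertainty. A minimal sketch, assuming the oracle provides a variance per pixel (the function name and its inputs are hypothetical, not DA3 or `ylff` API):

```python
def fuse_confidence(da3_confidence: float, oracle_variance: float,
                    eps: float = 1e-6) -> float:
    """Fuse model confidence with oracle uncertainty (hypothetical sketch).

    Treats the reciprocal of DA3's `expp1` confidence as a pseudo-variance
    and combines it with the oracle variance by adding precisions.
    """
    # Map confidence (> 1) to a pseudo-variance; eps guards division by zero.
    model_variance = 1.0 / max(da3_confidence, eps)
    # Inverse-variance weighting: precisions (1/variance) add up.
    return 1.0 / (model_variance + eps) + 1.0 / (oracle_variance + eps)
```

A more confident model (higher `da3_confidence`) or a more certain oracle (lower `oracle_variance`) both raise the fused value.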
2. **Custom Attention-Based Fusion**
   ```python
   # Attention-based fusion of multiple views
   class AttentionFusion(nn.Module):
       def forward(self, features_list):
           # Cross-attention between views
           fused = self.cross_attention(features_list)
           return fused
   ```

### Option 3: Custom Activation Functions (Output Layer)

If you modify the model, you can change the output activations:

1. **Custom Depth Activation**
   ```python
   # Instead of exp, use your own activation
   def custom_depth_activation(logits):
       # Your custom function
       return your_function(logits)
   ```
2. **Custom Confidence Activation**
   ```python
   # Instead of expp1, use your own activation
   def custom_confidence_activation(logits):
       # Your custom function
       return your_function(logits)
   ```

## Recommended Approach

### For Custom Attention

1. **If you have model code access:**
   - Modify `src/depth_anything_3/model/dinov2/vision_transformer.py`
   - Replace attention blocks with your custom implementation
   - Test with small models first
2. **If you don't have model code access:**
   - Use post-processing attention (Option 2)
   - Add attention-based fusion layers after model inference
   - Implement in `ylff/utils/oracle_uncertainty.py` or a new utility

### For Custom Activations

1. **Output activations:**
   - Modify the model code if available
   - Or add post-processing to transform the outputs
2. **Hidden activations:**
   - Requires model code access
   - Or create a wrapper model that processes features

## Example: Custom Cross-View Attention

```python
# ylff/utils/custom_attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class CustomCrossViewAttention(nn.Module):
    """
    Custom attention mechanism for multi-view depth estimation.

    This can be used as a post-processing step or integrated into the model.
""" def __init__(self, dim, num_heads=8): super().__init__() self.num_heads = num_heads self.head_dim = dim // num_heads self.q_proj = nn.Linear(dim, dim) self.k_proj = nn.Linear(dim, dim) self.v_proj = nn.Linear(dim, dim) self.out_proj = nn.Linear(dim, dim) def forward(self, features_list): """ Args: features_list: List of feature tensors from different views Each: [B, N, C] where N is spatial dimensions Returns: Fused features: [B, N, C] """ # Stack views: [B, S, N, C] x = torch.stack(features_list, dim=1) B, S, N, C = x.shape # Reshape for multi-head attention x = x.view(B * S, N, C) # Compute Q, K, V q = self.q_proj(x).view(B * S, N, self.num_heads, self.head_dim) k = self.k_proj(x).view(B * S, N, self.num_heads, self.head_dim) v = self.v_proj(x).view(B * S, N, self.num_heads, self.head_dim) # Transpose for attention: [B*S, num_heads, N, head_dim] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) # Cross-view attention: reshape to [B, S*N, num_heads, head_dim] q = q.view(B, S * N, self.num_heads, self.head_dim) k = k.view(B, S * N, self.num_heads, self.head_dim) v = v.view(B, S * N, self.num_heads, self.head_dim) # Compute attention scale = 1.0 / (self.head_dim ** 0.5) scores = torch.matmul(q, k.transpose(-2, -1)) * scale attn_weights = F.softmax(scores, dim=-1) # Apply attention out = torch.matmul(attn_weights, v) # Reshape back and project out = out.view(B * S, N, C) out = self.out_proj(out) # Average across views or use reference view out = out.view(B, S, N, C) out = out.mean(dim=1) # [B, N, C] return out ``` ## Example: Custom Activation Functions ```python # ylff/utils/custom_activations.py import torch import torch.nn as nn import torch.nn.functional as F class SwishDepthActivation(nn.Module): """Swish activation for depth (smooth, bounded).""" def forward(self, logits): # Swish: x * sigmoid(x) depth = logits * torch.sigmoid(logits) # Ensure positive depth = F.relu(depth) + 0.1 # Minimum depth return depth class 
class SoftplusConfidenceActivation(nn.Module):
    """Softplus activation for confidence (smooth, positive)."""
    def forward(self, logits):
        # Softplus: log(1 + exp(x)), shifted for a minimum confidence of 1
        confidence = F.softplus(logits) + 1.0
        return confidence


class ClampedRayActivation(nn.Module):
    """Soft-clamped activation for rays (bounded directions)."""
    def forward(self, logits):
        # Soft-clamp to a reasonable range via tanh
        rays = torch.tanh(logits) * 10.0  # Scale to (-10, 10)
        return rays
```

## Next Steps

1. **Decide what you want to customize:**
   - Attention mechanism?
   - Activation functions?
   - Both?
2. **Check model code access:**
   - Do you have access to `src/depth_anything_3/model/`?
   - Or do you need post-processing approaches?
3. **Implement incrementally:**
   - Start with post-processing (easier)
   - Move to model modifications if needed
   - Test on small datasets first
4. **Integrate with training:**
   - Add to `ylff/services/pretrain.py` or `ylff/services/fine_tune.py`
   - Update loss functions if needed
   - Add CLI/API options

Let me know what specific attention mechanism or activation function you want to implement, and I can help you build it! 🚀
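One more note on Option 3: if the model exposes raw logits (or you strip its output activations), the same custom activations can be applied purely as post-processing. A minimal pure-Python sketch; the `raw_logits` dict and its keys are hypothetical stand-ins for the tensors a real pipeline would carry:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def postprocess_outputs(raw_logits: dict) -> dict:
    """Apply custom output activations after inference (hypothetical sketch)."""
    return {
        # Swish depth with a 0.1 floor (mirrors SwishDepthActivation)
        "depth": [max(z * sigmoid(z), 0.0) + 0.1 for z in raw_logits["depth"]],
        # Softplus confidence with a minimum of 1 (mirrors SoftplusConfidenceActivation)
        "confidence": [softplus(z) + 1.0 for z in raw_logits["confidence"]],
        # tanh soft-clamp for rays (mirrors ClampedRayActivation)
        "ray": [math.tanh(z) * 10.0 for z in raw_logits["ray"]],
    }
```

This keeps the model untouched, at the cost of discarding whatever the original `exp`/`expp1` heads learned to calibrate.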