# flake8: noqa: E501
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from addict import Dict

from depth_anything_3.model.dpt import _make_fusion_block, _make_scratch
from depth_anything_3.model.utils.head_utils import (
    Permute,
    create_uv_grid,
    custom_interpolate,
    position_grid_to_embed,
)


class DualDPT(nn.Module):
    """
    Dual-head DPT for dense prediction with an always-on auxiliary head.

    Architectural notes:
      - Sky/object branches are removed.
      - `intermediate_layer_idx` is fixed to (0, 1, 2, 3).
      - The auxiliary head has its **own** fusion blocks (no fusion_inplace / no sharing).
      - The auxiliary head is internally multi-level; **only the final level** is returned.
      - Returns a **dict** with keys from `head_names`, i.e.:
        { main_name, f"{main_name}_conf", aux_name, f"{aux_name}_conf" }
      - `feature_only` is fixed to False.
    """

    def __init__(
        self,
        dim_in: int,
        *,
        patch_size: int = 14,
        output_dim: int = 2,
        activation: str = "exp",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: Sequence[int] = (256, 512, 1024, 1024),
        pos_embed: bool = True,
        down_ratio: int = 1,
        aux_pyramid_levels: int = 4,
        aux_out1_conv_num: int = 5,
        head_names: Tuple[str, str] = ("depth", "ray"),
    ) -> None:
        super().__init__()

        # -------------------- configuration --------------------
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.down_ratio = down_ratio
        self.aux_levels = aux_pyramid_levels
        self.aux_out1_conv_num = aux_out1_conv_num

        # Names come ONLY from config (no hard-coded strings elsewhere).
        self.head_main, self.head_aux = head_names

        # Always expect 4 scales; enforce intermediate idx = (0, 1, 2, 3).
        self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)

        # -------------------- token pre-norm + per-stage projection --------------------
        self.norm = nn.LayerNorm(dim_in)
        self.projects = nn.ModuleList(
            [nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
        )

        # -------------------- spatial re-sizers (align to a common scale before fusion) --------------------
        # Design: stage strides (x4, x2, x1, /2) relative to the patch grid,
        # aligning all stages to a common pivot scale.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1),
            ]
        )

        # -------------------- scratch: stage adapters + fusion (main & aux are separate) --------------------
        self.scratch = _make_scratch(list(out_channels), features, expand=False)

        # Main fusion chain (independent)
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        # Primary head neck + head (independent)
        head_features_1 = features
        head_features_2 = 32
        self.scratch.output_conv1 = nn.Conv2d(
            head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
        )
        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
        )

        # Auxiliary fusion chain (completely separate; no sharing, i.e., "fusion_inplace=False")
        self.scratch.refinenet1_aux = _make_fusion_block(features)
        self.scratch.refinenet2_aux = _make_fusion_block(features)
        self.scratch.refinenet3_aux = _make_fusion_block(features)
        self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False)

        # Aux pre-head per level (only the final level is returned)
        self.scratch.output_conv1_aux = nn.ModuleList(
            [self._make_aux_out1_block(head_features_1) for _ in range(self.aux_levels)]
        )

        # Aux final projection per level
        use_ln = True
        ln_seq = (
            [Permute((0, 2, 3, 1)), nn.LayerNorm(head_features_2), Permute((0, 3, 1, 2))]
            if use_ln
            else []
        )
        self.scratch.output_conv2_aux = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1
                    ),
                    *ln_seq,
                    nn.ReLU(inplace=True),
                    nn.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0),
                )
                for _ in range(self.aux_levels)
            ]
        )
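
    # Scale-alignment sketch for `resize_layers` (illustrative numbers only,
    # assuming a 37x37 patch grid, e.g. H = W = 518 with patch_size = 14):
    #   stage 0: 37 --ConvTranspose2d(k=4, s=4)--> 148
    #   stage 1: 37 --ConvTranspose2d(k=2, s=2)--> 74
    #   stage 2: 37 --Identity-------------------> 37
    #   stage 3: 37 --Conv2d(k=3, s=2, p=1)------> 19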

    # -------------------------------------------------------------------------
    # Public forward (supports frame chunking for memory)
    # -------------------------------------------------------------------------
    def forward(
        self,
        feats: List[Sequence[torch.Tensor]],
        H: int,
        W: int,
        patch_start_idx: int,
        chunk_size: Optional[int] = 8,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            feats: List of 4 per-stage transformer outputs; each entry's first
                element is a token tensor of shape [B, S, N, C].
            H, W: Input image height and width.
            patch_start_idx: Index of the first patch token in the token
                sequence (earlier tokens, e.g. cls/register tokens, are dropped).
            chunk_size: Optional chunking along the flattened frame axis for
                memory; None disables chunking.

        Returns:
            An (addict) Dict[str, Tensor] with keys based on `head_names`:
                self.head_main, f"{self.head_main}_conf",
                self.head_aux, f"{self.head_aux}_conf"
            Shapes (channel-last; H' = H / down_ratio, W' = W / down_ratio; the
            aux head stays at the finest fusion scale, 8x the patch grid):
                main:      [B, S, H', W']  (value dim squeezed when output_dim == 2)
                main_conf: [B, S, H', W']
                aux:       [B, S, h, w, 6]  with h = 8 * H // patch_size, w = 8 * W // patch_size
                aux_conf:  [B, S, h, w]
        """
        B, S, N, C = feats[0][0].shape
        feats = [feat[0].reshape(B * S, N, C) for feat in feats]
        total = B * S

        if chunk_size is None or chunk_size >= total:
            out_dict = self._forward_impl(feats, H, W, patch_start_idx)
            out_dict = {k: v.reshape(B, S, *v.shape[1:]) for k, v in out_dict.items()}
            return Dict(out_dict)

        # Chunk along the flattened B*S frame axis to bound peak memory.
        out_dicts = []
        for s0 in range(0, total, chunk_size):
            s1 = min(s0 + chunk_size, total)
            out_dicts.append(
                self._forward_impl([feat[s0:s1] for feat in feats], H, W, patch_start_idx)
            )
        out_dict = {
            k: torch.cat([d[k] for d in out_dicts], dim=0) for k in out_dicts[0].keys()
        }
        out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
        return Dict(out_dict)
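
    # Usage sketch (illustrative; `backbone_feats` is a hypothetical list of 4
    # per-stage outputs whose first elements are [B, S, N, C] token tensors,
    # and patch_start_idx=5 is a made-up register-token count):
    #   head = DualDPT(dim_in=1024, head_names=("depth", "ray"))
    #   out = head(backbone_feats, H=518, W=518, patch_start_idx=5)
    #   out["depth"], out["depth_conf"]  -> [B, S, 518, 518]
    #   out["ray"]                       -> [B, S, 296, 296, 6]
    #   out["ray_conf"]                  -> [B, S, 296, 296]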

    # -------------------------------------------------------------------------
    # Internal forward (single chunk)
    # -------------------------------------------------------------------------
    def _forward_impl(
        self,
        feats: List[torch.Tensor],
        H: int,
        W: int,
        patch_start_idx: int,
    ) -> Dict[str, torch.Tensor]:
        B, _, C = feats[0].shape  # B here is the flattened frame count (B * S upstream)
        ph, pw = H // self.patch_size, W // self.patch_size

        # 1) Project each stage's patch tokens to a feature map and align scales
        resized_feats = []
        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
            x = feats[take_idx][:, patch_start_idx:]
            x = self.norm(x)
            x = x.permute(0, 2, 1).reshape(B, C, ph, pw)  # [B*S, C, ph, pw]
            x = self.projects[stage_idx](x)
            if self.pos_embed:
                x = self._add_pos_embed(x, W, H)
            x = self.resize_layers[stage_idx](x)  # align scales
            resized_feats.append(x)

        # 2) Fuse pyramid (main & aux are completely independent)
        fused_main, fused_aux_pyr = self._fuse(resized_feats)

        # 3) Upsample to target resolution and (optionally) add pos-embed again
        h_out = int(ph * self.patch_size / self.down_ratio)
        w_out = int(pw * self.patch_size / self.down_ratio)
        fused_main = custom_interpolate(
            fused_main, (h_out, w_out), mode="bilinear", align_corners=True
        )
        if self.pos_embed:
            fused_main = self._add_pos_embed(fused_main, W, H)

        # Primary head: conv2 -> activate (output_conv1 is already applied in _fuse)
        main_logits = self.scratch.output_conv2(fused_main)
        fmap = main_logits.permute(0, 2, 3, 1)  # channel-last: [B*S, h, w, output_dim]
        main_pred = self._apply_activation_single(fmap[..., :-1], self.activation)
        main_conf = self._apply_activation_single(fmap[..., -1], self.conf_activation)

        # Auxiliary head (multi-level inside) -> only the last level is returned;
        # its per-level pre-conv (output_conv1_aux) is already applied in _fuse.
        last_aux = fused_aux_pyr[-1]
        if self.pos_embed:
            last_aux = self._add_pos_embed(last_aux, W, H)
        last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
        fmap_last = last_aux_logits.permute(0, 2, 3, 1)  # channel-last: [B*S, h, w, 7]
        aux_pred = self._apply_activation_single(fmap_last[..., :-1], "linear")
        aux_conf = self._apply_activation_single(fmap_last[..., -1], self.conf_activation)

        return {
            self.head_main: main_pred.squeeze(-1),
            f"{self.head_main}_conf": main_conf,
            self.head_aux: aux_pred,
            f"{self.head_aux}_conf": aux_conf,
        }
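
    # Head channel layout (as split above, channel-last after permute):
    #   main logits: output_dim channels -> [..., :-1] value(s) via `activation`,
    #                                       [..., -1] confidence via `conf_activation`
    #   aux logits:  7 channels          -> [..., :6] values kept linear,
    #                                       [..., 6] confidence via `conf_activation`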

    # -------------------------------------------------------------------------
    # Subroutines
    # -------------------------------------------------------------------------
    def _fuse(self, feats: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Feature pyramid fusion.

        Returns:
            fused_main: Tensor at the finest scale (after refinenet1 and output_conv1).
            aux_pyr: List of aux tensors, one per level, each already passed
                through its per-level output_conv1_aux pre-head.
        """
        l1, l2, l3, l4 = feats
        l1_rn = self.scratch.layer1_rn(l1)
        l2_rn = self.scratch.layer2_rn(l2)
        l3_rn = self.scratch.layer3_rn(l3)
        l4_rn = self.scratch.layer4_rn(l4)

        # level 4 -> 3
        out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
        aux_out = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
        aux_list: List[torch.Tensor] = []
        if self.aux_levels >= 4:
            aux_list.append(aux_out)

        # level 3 -> 2
        out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
        aux_out = self.scratch.refinenet3_aux(aux_out, l3_rn, size=l2_rn.shape[2:])
        if self.aux_levels >= 3:
            aux_list.append(aux_out)

        # level 2 -> 1
        out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
        aux_out = self.scratch.refinenet2_aux(aux_out, l2_rn, size=l1_rn.shape[2:])
        if self.aux_levels >= 2:
            aux_list.append(aux_out)

        # level 1 (final)
        out = self.scratch.refinenet1(out, l1_rn)
        aux_out = self.scratch.refinenet1_aux(aux_out, l1_rn)
        aux_list.append(aux_out)

        out = self.scratch.output_conv1(out)
        aux_list = [self.scratch.output_conv1_aux[i](aux) for i, aux in enumerate(aux_list)]
        return out, aux_list
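
    # Fusion shape flow (illustrative, same 37x37 patch grid as above; the
    # *_aux chain follows the identical schedule):
    #   l1_rn / l2_rn / l3_rn / l4_rn : 148 / 74 / 37 / 19
    #   refinenet4: 19            -> 37   (no lateral residual)
    #   refinenet3: 37  + l3_rn   -> 74
    #   refinenet2: 74  + l2_rn   -> 148
    #   refinenet1: 148 + l1_rn   -> 296  (x2 upsample; no explicit size)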

    def _add_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """Simple UV positional embedding added to feature maps."""
        pw, ph = x.shape[-1], x.shape[-2]
        pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pe = position_grid_to_embed(pe, x.shape[1]) * ratio
        pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pe

    def _make_aux_out1_block(self, in_ch: int) -> nn.Sequential:
        """Factory for the aux pre-head stack before the final 1x1 projection."""
        if self.aux_out1_conv_num == 5:
            return nn.Sequential(
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
                nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
                nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
            )
        if self.aux_out1_conv_num == 3:
            return nn.Sequential(
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
                nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
            )
        if self.aux_out1_conv_num == 1:
            return nn.Sequential(nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1))
        raise ValueError(f"aux_out1_conv_num {self.aux_out1_conv_num} not supported")

    def _apply_activation_single(
        self, x: torch.Tensor, activation: str = "linear"
    ) -> torch.Tensor:
        """
        Apply an activation to a single-channel output, keeping the semantics
        consistent with the value branch in the multi-channel case.

        Supports: exp / expm1 / expp1 / relu / sigmoid / softplus / tanh / linear
        """
        act = activation.lower() if isinstance(activation, str) else activation
        if act == "exp":
            return torch.exp(x)
        if act == "expm1":
            return torch.expm1(x)
        if act == "expp1":
            return torch.exp(x) + 1
        if act == "relu":
            return torch.relu(x)
        if act == "sigmoid":
            return torch.sigmoid(x)
        if act == "softplus":
            return torch.nn.functional.softplus(x)
        if act == "tanh":
            return torch.tanh(x)
        # Default: linear (identity)
        return x
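
# Activation reference for `_apply_activation_single` (a summary of the code
# above, not new behavior):
#   "exp"    x -> exp(x)      strictly positive (default for the main values)
#   "expm1"  x -> exp(x) - 1  bounded below by -1
#   "expp1"  x -> exp(x) + 1  strictly greater than 1 (default for confidences)
#   "relu", "sigmoid", "softplus", "tanh", "linear" behave as named.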

# # -----------------------------------------------------------------------------
# # Building blocks (tidy; kept for reference -- the active implementations are
# # imported from depth_anything_3.model.dpt above)
# # -----------------------------------------------------------------------------
# def _make_fusion_block(
#     features: int,
#     size: Tuple[int, int] = None,
#     has_residual: bool = True,
#     groups: int = 1,
#     inplace: bool = False,  # <- the ReLU's inplace flag; not related to "fusion_inplace"
# ) -> nn.Module:
#     return FeatureFusionBlock(
#         features=features,
#         activation=nn.ReLU(inplace=inplace),
#         deconv=False,
#         bn=False,
#         expand=False,
#         align_corners=True,
#         size=size,
#         has_residual=has_residual,
#         groups=groups,
#     )


# def _make_scratch(
#     in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
# ) -> nn.Module:
#     scratch = nn.Module()
#     # optionally expand widths by stage
#     c1 = out_shape
#     c2 = out_shape * (2 if expand else 1)
#     c3 = out_shape * (4 if expand else 1)
#     c4 = out_shape * (8 if expand else 1)
#     scratch.layer1_rn = nn.Conv2d(in_shape[0], c1, 3, 1, 1, bias=False, groups=groups)
#     scratch.layer2_rn = nn.Conv2d(in_shape[1], c2, 3, 1, 1, bias=False, groups=groups)
#     scratch.layer3_rn = nn.Conv2d(in_shape[2], c3, 3, 1, 1, bias=False, groups=groups)
#     scratch.layer4_rn = nn.Conv2d(in_shape[3], c4, 3, 1, 1, bias=False, groups=groups)
#     return scratch


# class ResidualConvUnit(nn.Module):
#     """Lightweight residual conv block used within fusion."""
#
#     def __init__(self, features: int, activation: nn.Module, bn: bool, groups: int = 1) -> None:
#         super().__init__()
#         self.bn = bn
#         self.groups = groups
#         self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
#         self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
#         self.norm1 = None
#         self.norm2 = None
#         self.activation = activation
#         self.skip_add = nn.quantized.FloatFunctional()
#
#     def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
#         out = self.activation(x)
#         out = self.conv1(out)
#         if self.norm1 is not None:
#             out = self.norm1(out)
#         out = self.activation(out)
#         out = self.conv2(out)
#         if self.norm2 is not None:
#             out = self.norm2(out)
#         return self.skip_add.add(out, x)


# class FeatureFusionBlock(nn.Module):
#     """Top-down fusion block: (optional) residual merge + upsample + 1x1 shrink."""
#
#     def __init__(
#         self,
#         features: int,
#         activation: nn.Module,
#         deconv: bool = False,
#         bn: bool = False,
#         expand: bool = False,
#         align_corners: bool = True,
#         size: Tuple[int, int] = None,
#         has_residual: bool = True,
#         groups: int = 1,
#     ) -> None:
#         super().__init__()
#         self.align_corners = align_corners
#         self.size = size
#         self.has_residual = has_residual
#         self.resConfUnit1 = (
#             ResidualConvUnit(features, activation, bn, groups=groups) if has_residual else None
#         )
#         self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups)
#         out_features = (features // 2) if expand else features
#         self.out_conv = nn.Conv2d(features, out_features, 1, 1, 0, bias=True, groups=groups)
#         self.skip_add = nn.quantized.FloatFunctional()
#
#     def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor:  # type: ignore[override]
#         """
#         xs:
#           - xs[0]: top input
#           - xs[1]: (optional) lateral input, merged via the residual unit
#         """
#         y = xs[0]
#         if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
#             y = self.skip_add.add(y, self.resConfUnit1(xs[1]))
#         y = self.resConfUnit2(y)
#         # upsample
#         if (size is None) and (self.size is None):
#             up_kwargs = {"scale_factor": 2}
#         elif size is None:
#             up_kwargs = {"size": self.size}
#         else:
#             up_kwargs = {"size": size}
#         y = custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners)
#         y = self.out_conv(y)
#         return y
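

if __name__ == "__main__":
    # Minimal smoke test: a sketch only. Random tokens stand in for real
    # transformer features, and patch_start_idx=5 is a hypothetical count of
    # non-patch (cls/register) tokens, not a value taken from the codebase.
    B, S, C, H, W = 1, 2, 1024, 518, 518
    n_patch = (H // 14) * (W // 14)
    feats = [(torch.randn(B, S, 5 + n_patch, C),) for _ in range(4)]
    head = DualDPT(dim_in=C).eval()
    with torch.no_grad():
        out = head(feats, H=H, W=W, patch_start_idx=5)
    for k, v in out.items():
        # Expected: depth/depth_conf at (1, 2, 518, 518),
        # ray at (1, 2, 296, 296, 6), ray_conf at (1, 2, 296, 296).
        print(k, tuple(v.shape))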