| |
| |
| |
| |
| |
|
|
|
|
| |
|
|
|
|
| import os |
| from typing import List, Dict, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from .head_act import activate_head |
| from .utils import create_uv_grid, position_grid_to_embed |
|
|
|
|
class DPTHead(nn.Module):
    """
    DPT Head for dense prediction tasks.

    This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
    backbone and produces dense predictions by fusing multi-scale features.

    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
            The resize layers and the fusion path are hard-wired to four stages.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers
            and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super().__init__()

        # The list defaults are kept in the signature for compatibility, but they are
        # copied here so that no instance aliases the shared default object (or a list
        # the caller may mutate later).
        out_channels = list(out_channels)

        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = list(intermediate_layer_idx)

        # Normalises tokens over their channel dimension before they are turned into maps.
        self.norm = nn.LayerNorm(dim_in)

        # One 1x1 projection per selected layer: dim_in -> out_channels[i].
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Per-stage resampling of the patch grid: x4 and x2 transposed convolutions,
        # identity, and a stride-2 convolution.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(
            out_channels,
            features,
            expand=False,
        )

        # Attach the fusion blocks to the scratch container. refinenet4 is the first
        # block of the fusion chain and therefore has no skip input to merge.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            # Feature mode keeps the full channel width and has no prediction head.
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
            )
            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: Union[int, None] = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or not smaller than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: Feature maps with shape [B, S, C, H, W]
                - Otherwise: Tuple of (predictions, confidence) as produced by ``activate_head``,
                  reshaped to [B, S, ...].

        Raises:
            ValueError: If chunking is required and ``frames_chunk_size`` is not positive.
        """
        _, S, _, _, _ = images.shape

        # Everything fits in a single pass.
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Explicit check instead of `assert`, which disappears under `python -O`.
        if frames_chunk_size <= 0:
            raise ValueError(f"frames_chunk_size must be a positive integer or None, got {frames_chunk_size}")

        # Run the head chunk by chunk along the frame axis to bound peak memory.
        chunk_outputs = [
            self._forward_impl(
                aggregated_tokens_list,
                images,
                patch_start_idx,
                frames_start_idx,
                min(frames_start_idx + frames_chunk_size, S),
            )
            for frames_start_idx in range(0, S, frames_chunk_size)
        ]

        # Stitch the chunks back together along the frame axis.
        if self.feature_only:
            return torch.cat(chunk_outputs, dim=1)

        all_preds, all_conf = zip(*chunk_outputs)
        return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: Union[int, None] = None,
        frames_end_idx: Union[int, None] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Implementation of the forward pass through the DPT head.

        This method processes a specific chunk of frames from the sequence, or the
        whole sequence when no frame range is given.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index (exclusive) for frames to process.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        is_chunk = frames_start_idx is not None and frames_end_idx is not None
        if is_chunk:
            images = images[:, frames_start_idx:frames_end_idx]

        B, S, _, H, W = images.shape
        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        # Heuristic kept from the original code: a list longer than 10 is indexed by
        # absolute layer index; a shorter list is indexed by position.
        # NOTE(review): presumably a short list holds only the selected layers, in
        # order -- confirm against the caller.
        index_by_layer = len(aggregated_tokens_list) > 10

        out = []
        for dpt_idx, layer_idx in enumerate(self.intermediate_layer_idx):
            # Use the loop position rather than `.index(layer_idx)`, which returns the
            # first match and would pick the wrong entry if an index is repeated.
            source_idx = layer_idx if index_by_layer else dpt_idx
            x = aggregated_tokens_list[source_idx][:, :, patch_start_idx:]

            # Restrict the tokens to the frames of the current chunk.
            if is_chunk:
                x = x[:, frames_start_idx:frames_end_idx].contiguous()

            # Fold batch and frame axes together, normalise over channels.
            x = x.view(B * S, -1, x.shape[-1])
            x = self.norm(x)

            # Tokens -> channel-first map on the patch grid.
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)

            x = self.resize_layers[dpt_idx](x)
            out.append(x)

        # Fuse the multi-scale maps into a single feature map.
        out = self.scratch_forward(out)

        # Resize to the (optionally down-scaled) patch-aligned image resolution.
        out = custom_interpolate(
            out,
            (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out.view(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)

        # Split the folded batch axis back into batch and frame axes.
        preds = preds.view(B, S, *preds.shape[1:])
        conf = conf.view(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Apply positional embedding to tensor x.

        Args:
            x (Tensor): Channel-first map; its last two dimensions define the grid size.
            W (int): Image width, used together with H for the grid aspect ratio.
            H (int): Image height.
            ratio (float, optional): Scale applied to the embedding before it is added. Default is 0.1.

        Returns:
            Tensor: ``x`` plus the scaled embedding, with the same shape as ``x``.
        """
        patch_h, patch_w = x.shape[-2:]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        # Move the embedding axis first and broadcast over the batch.
        # NOTE(review): assumes position_grid_to_embed returns (h, w, channels) -- confirm.
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): List of exactly four feature maps from different layers.

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse from the last stage to the first; each block resizes its output to the
        # spatial size of the next skip input. Local references are dropped as soon as
        # they are no longer needed.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        # No explicit size here, so the block falls back to its own default resizing.
        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
|
|
|
|
| |
| |
| |
|
|
|
|
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    """
    Build a FeatureFusionBlock with the fixed settings used by the DPT head.

    Args:
        features (int): Channel count of the block.
        size (int, optional): Fixed output size forwarded to the block. Default is None.
        has_residual (bool, optional): Whether the block merges a second (skip) input. Default is True.
        groups (int, optional): Convolution groups forwarded to the block. Default is 1.

    Returns:
        nn.Module: The configured fusion block, with its own in-place ReLU instance.
    """
    block_options = dict(
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
    )
    relu = nn.ReLU(inplace=True)
    return FeatureFusionBlock(features, relu, **block_options)
|
|
|
|
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
    """
    Create the container holding one bias-free 3x3 convolution per input stage.

    The convolutions are registered as ``layer1_rn`` ... ``layer3_rn``, plus
    ``layer4_rn`` when ``in_shape`` has at least four entries.

    Args:
        in_shape (List[int]): Input channel count of each stage (at least three entries).
        out_shape (int): Base output channel count.
        groups (int, optional): Convolution groups. Default is 1.
        expand (bool, optional): If True, stage outputs use 1x, 2x, 4x and 8x ``out_shape``;
            otherwise every stage outputs ``out_shape`` channels. Default is False.

    Returns:
        nn.Module: Plain module carrying the ``layer*_rn`` convolutions.
    """
    width_factors = (1, 2, 4, 8) if expand else (1, 1, 1, 1)
    num_stages = 4 if len(in_shape) >= 4 else 3

    scratch = nn.Module()
    # Stages are created in ascending order so module registration order is stable.
    for stage in range(num_stages):
        stage_conv = nn.Conv2d(
            in_shape[stage],
            out_shape * width_factors[stage],
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            groups=groups,
        )
        setattr(scratch, f"layer{stage + 1}_rn", stage_conv)
    return scratch
|
|
|
|
class ResidualConvUnit(nn.Module):
    """Residual convolution module: two activation + 3x3 convolution stages plus a skip connection."""

    def __init__(self, features, activation, bn, groups=1):
        """Init.

        Args:
            features (int): number of input and output channels
            activation (callable): activation applied before each convolution
            bn: stored on the module; no normalisation layer is created here
            groups (int): groups used by both convolutions
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        conv_options = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_options)
        self.conv2 = nn.Conv2d(features, features, **conv_options)

        # Normalisation slots; left empty, so forward() skips them.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: convolution branch added to ``x``. If the activation works in place,
            ``x`` has already been overwritten by its activated value at that point.
        """
        out = x
        stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        for conv, norm in stages:
            out = conv(self.activation(out))
            if norm is not None:
                out = norm(out)

        return self.skip_add.add(out, x)
|
|
|
|
class FeatureFusionBlock(nn.Module):
    """Feature fusion block.

    Optionally merges a skip input into the main input through a residual unit,
    refines the result with a second residual unit, resizes it bilinearly and
    applies a 1x1 output convolution.
    """

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Init.

        Args:
            features (int): number of features
            activation (callable): activation handed to the residual units
            deconv (bool): stored on the module; not used by this block's forward pass
            bn: forwarded to the residual units
            expand (bool): if truthy, the output convolution halves the channel count
            align_corners (bool): forwarded to the bilinear interpolation
            size: default output size used when forward() receives none
            has_residual (bool): whether forward() merges a second (skip) input
            groups (int): groups used by the convolutions
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand

        # Plain truthiness test instead of comparing against the literal `True`.
        out_features = features // 2 if expand else features

        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )

        # The skip-branch unit only exists when a skip input will be merged.
        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Args:
            *xs (tensor): main input, followed by the skip input when the block was
                built with ``has_residual=True``.
            size: output spatial size; takes precedence over the size given at
                construction. If neither is set, the spatial size is doubled.

        Returns:
            tensor: output
        """
        output = xs[0]

        if self.has_residual:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        # Resolve the resize target: call-time size, then constructor size, then 2x.
        if size is not None:
            modifier = {"size": size}
        elif self.size is not None:
            modifier = {"size": self.size}
        else:
            modifier = {"scale_factor": 2}

        output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
        output = self.out_conv(output)

        return output
|
|
|
|
def custom_interpolate(
    x: torch.Tensor,
    size: Union[Tuple[int, int], None] = None,
    scale_factor: Union[float, None] = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.

    When the requested output would hold more elements than a fixed threshold, the
    input is split along dimension 0, each piece is interpolated separately and the
    results are concatenated again.

    Args:
        x (Tensor): Input whose last two dimensions are resized; dimensions 0 and 1
            are used to count output elements.
        size (Tuple[int, int], optional): Target size of the last two dimensions.
        scale_factor (float, optional): Multiplier for the last two dimensions, used
            only when ``size`` is None.
        mode (str, optional): Interpolation mode. Default is "bilinear".
        align_corners (bool, optional): Forwarded to ``interpolate``. Default is True.

    Returns:
        Tensor: The resized tensor.

    Raises:
        ValueError: If neither ``size`` nor ``scale_factor`` is given.
    """
    if size is None:
        if scale_factor is None:
            # Previously this surfaced as an opaque TypeError from `int * None`.
            raise ValueError("custom_interpolate requires either `size` or `scale_factor`.")
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    # Element-count threshold above which the work is split; it sits below 2**31 - 1.
    INT_MAX = 1610612736

    output_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if output_elements <= INT_MAX:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)

    # Split along dimension 0 so that each piece is interpolated on its own.
    num_chunks = (output_elements // INT_MAX) + 1
    interpolated_chunks = [
        nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners)
        for chunk in torch.chunk(x, chunks=num_chunks, dim=0)
    ]
    return torch.cat(interpolated_chunks, dim=0).contiguous()
|
|