| """Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
|
| Dense Prediction Transformer Decoder architecture.
|
|
|
| Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| from typing import Iterable
|
|
|
| import torch
|
| from torch import nn
|
|
|
|
|
class MultiresConvDecoder(nn.Module):
    """Decoder for multi-resolution encodings.

    Every encoder level is first projected to ``dim_decoder`` channels by a
    bias-free convolution. The projected levels are then merged one at a time
    by ``FeatureFusionBlock2d`` modules, starting at the last level and
    finishing at level 0.
    """

    def __init__(
        self,
        dims_encoder: Iterable[int],
        dim_decoder: int,
    ):
        """Initialize multiresolution convolutional decoder.

        Args:
        ----
            dims_encoder: Expected dims at each level from the encoder.
            dim_decoder: Dim of decoder features.

        """
        super().__init__()
        self.dims_encoder = list(dims_encoder)
        self.dim_decoder = dim_decoder
        self.dim_out = dim_decoder

        num_encoders = len(self.dims_encoder)

        # Level 0 gets a 1x1 projection, or no projection at all when its
        # channel count already equals the decoder width.
        conv0 = (
            nn.Conv2d(self.dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
            if self.dims_encoder[0] != dim_decoder
            else nn.Identity()
        )

        # All remaining levels get a shape-preserving 3x3 projection.
        self.convs = nn.ModuleList(
            [
                conv0,
                *(
                    nn.Conv2d(
                        dim_encoder,
                        dim_decoder,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False,
                    )
                    for dim_encoder in self.dims_encoder[1:]
                ),
            ]
        )

        # One fusion block per level; every block except level 0 applies its
        # transposed convolution before the output convolution.
        self.fusions = nn.ModuleList(
            [
                FeatureFusionBlock2d(
                    num_features=dim_decoder,
                    deconv=(level != 0),
                    batch_norm=False,
                )
                for level in range(num_encoders)
            ]
        )

    def forward(
        self, encodings: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode the multi-resolution encodings.

        Args:
        ----
            encodings: One feature map per encoder level, in the same order as
                ``dims_encoder``. Presumably ordered from highest to lowest
                spatial resolution, since levels are fused from last to
                first -- confirm against the encoder.

        Returns:
        -------
            A tuple ``(features, lowres_features)``: the output of the level-0
            fusion block, and the projected (not yet fused) last level.

        Raises:
        ------
            ValueError: If the number of encodings differs from the number of
                encoder levels this decoder was built for.

        """
        num_levels = len(encodings)
        num_encoders = len(self.dims_encoder)

        if num_levels != num_encoders:
            raise ValueError(
                f"Got encoder output levels={num_levels}, expected levels={num_encoders}."
            )

        # Project the last level and keep the projection as a separate output
        # before it goes through any fusion block.
        features = self.convs[-1](encodings[-1])
        lowres_features = features
        features = self.fusions[-1](features)

        # Walk the remaining levels from second-to-last down to level 0,
        # fusing each projected encoding into the running features.
        for i in range(num_levels - 2, -1, -1):
            features_i = self.convs[i](encodings[i])
            features = self.fusions[i](features, features_i)
        return features, lowres_features
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Generic residual block built from arbitrary sub-modules.

    Follows the residual formulation of
        He et al. - Identity Mappings in Deep Residual Networks (2016),
        https://arxiv.org/abs/1603.05027
    The residual branch and the optional shortcut are supplied by the caller,
    so factory functions can customize the block.
    """

    def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
        """Store the residual branch and the optional shortcut branch."""
        super().__init__()
        self.residual = residual
        self.shortcut = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the shortcut output (or ``x`` itself) plus the residual output."""
        # The residual branch is evaluated first, on the untouched input.
        branch = self.residual(x)
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + branch
|
|
|
|
|
class FeatureFusionBlock2d(nn.Module):
    """Feature fusion for DPT.

    Optionally adds a residually-processed skip input to the main input,
    refines the sum with a second residual block, optionally applies a
    stride-2 transposed convolution, and finishes with a 1x1 convolution.
    """

    def __init__(
        self,
        num_features: int,
        deconv: bool = False,
        batch_norm: bool = False,
    ):
        """Initialize feature fusion block.

        Args:
        ----
            num_features: Input and output dimensions.
            deconv: Whether to use deconv before the final output conv.
            batch_norm: Whether to use batch normalization in resnet blocks.

        """
        super().__init__()

        # resnet1 processes the skip input, resnet2 the (possibly fused) features.
        self.resnet1 = self._residual_block(num_features, batch_norm)
        self.resnet2 = self._residual_block(num_features, batch_norm)

        self.use_deconv = deconv
        if deconv:
            # Kernel 2 / stride 2 / no padding: doubles height and width.
            self.deconv = nn.ConvTranspose2d(
                in_channels=num_features,
                out_channels=num_features,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=False,
            )

        self.out_conv = nn.Conv2d(
            num_features,
            num_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # The skip addition goes through FloatFunctional rather than a bare
        # `+`; presumably to keep the module quantization-friendly -- confirm.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
        """Process and fuse input features.

        Args:
        ----
            x0: Main input features.
            x1: Optional skip features. When given, they are passed through
                the first residual block and added to ``x0``.

        Returns:
        -------
            The fused features after the second residual block, the optional
            transposed convolution, and the output convolution.

        """
        x = x0

        if x1 is not None:
            res = self.resnet1(x1)
            x = self.skip_add.add(x, res)

        x = self.resnet2(x)

        if self.use_deconv:
            x = self.deconv(x)
        x = self.out_conv(x)

        return x

    @staticmethod
    def _residual_block(num_features: int, batch_norm: bool) -> nn.Module:
        """Create a residual block of two ReLU -> 3x3 conv (-> batch norm) stages.

        Args:
        ----
            num_features: Channel count of every convolution in the block.
            batch_norm: Whether to append batch normalization after each
                convolution; the convolution bias is dropped when it is used.

        Returns:
        -------
            A ``ResidualBlock`` wrapping the two stages, without a shortcut module.

        """

        def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
            # All layers of a stage are sized from `dim`, so the convolution
            # and the batch norm can never disagree on channel count.
            layers: list[nn.Module] = [
                nn.ReLU(False),
                nn.Conv2d(
                    dim,
                    dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=not batch_norm,
                ),
            ]
            if batch_norm:
                layers.append(nn.BatchNorm2d(dim))
            return layers

        residual = nn.Sequential(
            *_create_block(dim=num_features, batch_norm=batch_norm),
            *_create_block(dim=num_features, batch_norm=batch_norm),
        )
        return ResidualBlock(residual)
|
|
|