# %%
import warnings
from typing import Optional, Sequence, Union, Callable, Literal, Any

import torch
import torch.nn as nn

from segmentation_models_pytorch.base import (
    ClassificationHead,
    SegmentationModel,
)
from segmentation_models_pytorch.encoders.timm_vit import TimmViTEncoder
from segmentation_models_pytorch.base.utils import is_torch_compiling
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
from segmentation_models_pytorch.base.modules import Activation


class ReadoutConcatBlock(nn.Module):
    """
    Concatenates the cls tokens with the features to make use of the global information
    aggregated in the prefix (cls) tokens. Projects the combined feature map to the
    original embedding dimension using a MLP.

    According to:
    https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L79-L90

    Args:
        embed_dim: Channel dimension of the incoming feature map.
        has_prefix_tokens: If True the projection expects ``2 * embed_dim`` input
            features (spatial feature + prefix token), otherwise ``embed_dim``.
    """

    def __init__(self, embed_dim: int, has_prefix_tokens: bool):
        super().__init__()
        in_features = embed_dim * 2 if has_prefix_tokens else embed_dim
        out_features = embed_dim
        self.project = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.GELU(),
        )

    def forward(
        self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            features: Tensor of shape (batch_size, embed_dim, height, width).
            prefix_tokens: Optional tensor of shape
                (batch_size, num_prefix_tokens, embed_dim); only the first token is used.

        Returns:
            Tensor of shape (batch_size, embed_dim, height, width).
        """
        batch_size, _, height, width = features.shape

        # Rearrange to (batch_size, height * width, embed_dim).
        # `flatten` (unlike `view`) also accepts non-contiguous inputs.
        features = features.flatten(2).transpose(1, 2).contiguous()

        if prefix_tokens is not None:
            # (batch_size, num_prefix_tokens, embed_dim) -> (batch_size, 1, embed_dim),
            # then broadcast the first prefix token to every spatial position
            prefix_tokens = prefix_tokens[:, :1].expand_as(features)
            features = torch.cat([features, prefix_tokens], dim=2)

        # Project to embedding dimension
        features = self.project(features)

        # Rearrange back to (batch_size, embed_dim, height, width)
        features = features.transpose(1, 2)
        features = features.reshape(batch_size, -1, height, width)

        return features


class ReadoutAddBlock(nn.Module):
    """
    Adds the mean of the prefix tokens to the features to make use of the global
    information aggregated in the prefix (cls) tokens.

    According to:
    https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L71-L76
    """

    def forward(
        self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            features: Tensor of shape (batch_size, embed_dim, height, width).
            prefix_tokens: Optional tensor of shape
                (batch_size, num_prefix_tokens, embed_dim).

        Returns:
            ``features`` with the token-averaged prefix embedding added at every
            spatial position, or ``features`` unchanged if ``prefix_tokens`` is None.
        """
        if prefix_tokens is not None:
            batch_size, embed_dim, _, _ = features.shape
            # Average over the token axis, then broadcast over height and width
            prefix_tokens = prefix_tokens.mean(dim=1)
            prefix_tokens = prefix_tokens.view(batch_size, embed_dim, 1, 1)
            features = features + prefix_tokens
        return features


class ReadoutIgnoreBlock(nn.Module):
    """
    Ignores the prefix tokens and returns the features as is.
    """

    def forward(self, features: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return features


class ReassembleBlock(nn.Module):
    """
    Processes the features such that they have progressively increasing embedding size
    and progressively decreasing spatial dimension.

    The block applies a 1x1 projection to ``mid_channels``, a spatial resize and a
    3x3 projection to ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        mid_channels: Number of channels used during the spatial resize.
        out_channels: Number of channels of the returned feature map.
        upsample_factor: Spatial resize factor. Values ``> 1`` upsample with a
            transposed convolution, ``1`` keeps the resolution and values ``< 1``
            downsample with a strided 3x3 convolution (stride ``int(1 / upsample_factor)``).

    Raises:
        ValueError: If ``upsample_factor`` is not positive.
    """

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        upsample_factor: float,
    ):
        super().__init__()

        if upsample_factor <= 0:
            raise ValueError(
                f"upsample_factor must be positive, got: {upsample_factor}"
            )

        self.project_to_out_channel = nn.Conv2d(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
        )

        if upsample_factor > 1.0:
            self.upsample = nn.ConvTranspose2d(
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=int(upsample_factor),
                stride=int(upsample_factor),
            )
        elif upsample_factor == 1.0:
            self.upsample = nn.Identity()
        else:
            self.upsample = nn.Conv2d(
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=3,
                stride=int(1 / upsample_factor),
                padding=1,
            )

        self.project_to_feature_dim = nn.Conv2d(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.project_to_out_channel(x)
        x = self.upsample(x)
        x = self.project_to_feature_dim(x)
        return x


class ResidualConvBlock(nn.Module):
    """
    Pre-activation residual block: two (ReLU -> 3x3 conv -> batch norm) stages
    followed by a skip connection. Channel count and spatial size are preserved.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        self.conv_1 = nn.Conv2d(
            in_channels=feature_dim,
            out_channels=feature_dim,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.batch_norm_1 = nn.BatchNorm2d(num_features=feature_dim)
        self.conv_2 = nn.Conv2d(
            in_channels=feature_dim,
            out_channels=feature_dim,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.batch_norm_2 = nn.BatchNorm2d(num_features=feature_dim)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x

        # Block 1
        x = self.activation(x)
        x = self.conv_1(x)
        x = self.batch_norm_1(x)

        # Block 2
        x = self.activation(x)
        x = self.conv_2(x)
        x = self.batch_norm_2(x)

        # Add residual
        x = x + residual

        return x


class FusionBlock(nn.Module):
    """
    Fuses the processed encoder features in a residual manner and upsamples them
    by a factor of 2.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        self.residual_conv_block1 = ResidualConvBlock(feature_dim)
        self.residual_conv_block2 = ResidualConvBlock(feature_dim)
        self.project = nn.Conv2d(feature_dim, feature_dim, kernel_size=1)
        # Not used in `forward`; kept so the module's attributes stay unchanged.
        self.activation = nn.ReLU()

    def forward(
        self,
        feature: torch.Tensor,
        previous_feature: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            feature: Feature map of the current level, (batch, channels, height, width).
            previous_feature: Output of the previous (coarser) fusion block, if any.

        Returns:
            Fused feature map with twice the spatial size of ``feature``.
        """
        feature = self.residual_conv_block1(feature)
        if previous_feature is not None:
            # Strided downsampling of an odd-sized token grid rounds up, so the
            # 2x-upsampled coarser map can be off by a pixel. Align it to the
            # current level instead of failing on the addition.
            if previous_feature.shape[-2:] != feature.shape[-2:]:
                previous_feature = nn.functional.interpolate(
                    previous_feature,
                    size=feature.shape[-2:],
                    mode="bilinear",
                    align_corners=True,
                )
            feature = feature + previous_feature
        feature = self.residual_conv_block2(feature)
        feature = nn.functional.interpolate(
            feature, scale_factor=2, align_corners=True, mode="bilinear"
        )
        feature = self.project(feature)
        return feature


class DPTDecoder(nn.Module):
    """
    Decoder part for DPT

    Processes the encoder features and prefix tokens (if the encoder has them) into a
    feature pyramid. With ``s = min(encoder_output_strides)``, the i-th feature is
    resized to a stride of ``s * 2**i`` relative to the input image.

    The decoder then fuses these features in a residual manner, starting from the
    coarsest one, and upsamples by a factor of 2 after every fusion step, so that the
    output has a stride of ``s / 2`` relative to the input image.

    Args:
        encoder_out_channels: Number of channels of each encoder feature map.
        encoder_output_strides: Stride of each encoder feature map relative to the input.
        encoder_has_prefix_tokens: Whether the encoder provides prefix tokens.
        readout: How prefix tokens are used, one of ``"cat"``, ``"add"``, ``"ignore"``.
        intermediate_channels: Channels used inside each reassemble block.
        fusion_channels: Channels of the features that are fused (and of the output).

    Raises:
        ValueError: If the per-level sequences are empty or differ in length, or if
            ``readout`` is not a supported mode.
    """

    def __init__(
        self,
        encoder_out_channels: Sequence[int] = (756, 756, 756, 756),
        encoder_output_strides: Sequence[int] = (16, 16, 16, 16),
        encoder_has_prefix_tokens: bool = True,
        readout: Literal["cat", "add", "ignore"] = "cat",
        intermediate_channels: Sequence[int] = (256, 512, 1024, 1024),
        fusion_channels: int = 256,
    ):
        super().__init__()

        if not (
            len(encoder_out_channels)
            == len(encoder_output_strides)
            == len(intermediate_channels)
        ):
            raise ValueError(
                "encoder_out_channels, encoder_output_strides and intermediate_channels must have the same length"
            )

        num_blocks = len(encoder_out_channels)
        if num_blocks == 0:
            raise ValueError(
                "encoder_out_channels, encoder_output_strides and intermediate_channels must not be empty"
            )

        # If encoder has prefix tokens (e.g. cls_token), then we can concat/add/ignore them
        # according to the readout mode
        if readout == "cat":
            blocks = [
                ReadoutConcatBlock(in_channels, encoder_has_prefix_tokens)
                for in_channels in encoder_out_channels
            ]
        elif readout == "add":
            blocks = [ReadoutAddBlock() for _ in encoder_out_channels]
        elif readout == "ignore":
            blocks = [ReadoutIgnoreBlock() for _ in encoder_out_channels]
        else:
            raise ValueError(
                f"Invalid readout mode: {readout}, should be one of: 'cat', 'add', 'ignore'"
            )

        self.projection_blocks = nn.ModuleList(blocks)

        # Resize factors that bring features to progressively smaller scales.
        # For ViT models where all layers have the same stride, multi-scale features
        # are created by progressively downsampling from the encoder output:
        # i=0 keeps the finest stride, i=1 halves the resolution, i=2 quarters it, etc.
        min_stride = min(encoder_output_strides)
        scale_factors = [
            stride / (min_stride * (2**i))
            for i, stride in enumerate(encoder_output_strides)
        ]

        self.reassemble_blocks = nn.ModuleList(
            ReassembleBlock(
                in_channels=in_channels,
                mid_channels=mid_channels,
                out_channels=fusion_channels,
                upsample_factor=scale_factor,
            )
            for in_channels, mid_channels, scale_factor in zip(
                encoder_out_channels, intermediate_channels, scale_factors
            )
        )

        # Fusion blocks to fuse the processed features in a sequential manner
        fusion_blocks = [FusionBlock(fusion_channels) for _ in range(num_blocks)]
        self.fusion_blocks = nn.ModuleList(fusion_blocks)

    def forward(
        self, features: list[torch.Tensor], prefix_tokens: list[Optional[torch.Tensor]]
    ) -> torch.Tensor:
        """
        Args:
            features: Encoder feature maps, one per decoder level.
            prefix_tokens: Prefix tokens matching ``features`` (entries may be None).

        Returns:
            The fused feature map produced by the last fusion block.
        """
        # Apply the readout and bring every feature to its target scale
        processed_features = []
        for i, (feature, prefix_tokens_i) in enumerate(zip(features, prefix_tokens)):
            projected_feature = self.projection_blocks[i](feature, prefix_tokens_i)
            processed_feature = self.reassemble_blocks[i](projected_feature)
            processed_features.append(processed_feature)

        # Fusion and progressive upsampling starting from the last processed feature
        processed_features = processed_features[::-1]
        fused_feature = None
        for fusion_block, feature in zip(self.fusion_blocks, processed_features):
            fused_feature = fusion_block(feature, fused_feature)

        return fused_feature


class DPTSegmentationHead(nn.Module):
    """
    Segmentation head: conv -> batch norm -> ReLU -> dropout -> 1x1 conv, followed by
    bilinear upsampling and an optional activation.

    Args:
        in_channels: Number of channels of the decoder output.
        out_channels: Number of output channels (classes).
        activation: Activation specification forwarded to ``Activation``.
        kernel_size: Kernel size of the first convolution. The padding is
            ``kernel_size // 2`` so odd kernel sizes preserve the spatial size.
        upsampling: Scale factor of the final bilinear interpolation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: Optional[Union[str, Callable]] = None,
        kernel_size: int = 3,
        upsampling: float = 2.0,
    ):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=kernel_size,
                # Derived from kernel_size (was hard-coded to 1, only valid for 3x3)
                padding=kernel_size // 2,
                bias=False,
            ),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1, inplace=False),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )
        self.activation = Activation(activation)
        self.upsampling_factor = upsampling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        head_output = self.head(x)
        resized_output = nn.functional.interpolate(
            head_output,
            scale_factor=self.upsampling_factor,
            mode="bilinear",
            align_corners=True,
        )
        activation_output = self.activation(resized_output)
        return activation_output


class DPT(SegmentationModel):
    """
    DPT is a dense prediction architecture that leverages vision transformers in place of
    convolutional networks as a backbone for dense prediction tasks

    It assembles tokens from various stages of the vision transformer into image-like
    representations at various resolutions and progressively combines them into
    full-resolution predictions using a convolutional decoder.

    The transformer backbone processes representations at a constant and relatively high
    resolution and has a global receptive field at every stage. These properties allow the
    dense vision transformer to provide finer-grained and more globally coherent predictions
    when compared to fully-convolutional networks

    Note:
        Since this model uses a Vision Transformer backbone, it typically requires a fixed
        input image size. To handle variable input sizes, you can set
        `dynamic_img_size=True` in the model initialization (if supported by the specific
        `timm` encoder). You can check if an encoder requires fixed size using
        `model.encoder.is_fixed_input_size`, and get the required input dimensions from
        `model.encoder.input_size`, however it's no guarantee that information is available.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder
            (a.k.a backbone) to extract features of different spatial resolution.
            Default is ``"hf-hub:paige-ai/Virchow2"``.
        encoder_depth: A number of stages used in encoder in range [1,4]. Each stage generate
            features smaller by a factor equal to the ViT model patch_size in spatial
            dimensions. Default is 4.
        encoder_output_indices: The indices of the encoder output features to use. If **None**
            will be sampled uniformly across the number of blocks in encoder, e.g. if number
            of blocks is 4 and encoder has 20 blocks, then encoder_output_indices will be
            (4, 9, 14, 19). If specified the number of indices should be equal to
            encoder_depth. Default is **None**.
        decoder_readout: The strategy to utilize the prefix tokens (e.g. cls_token) from the
            encoder. Can be one of **"cat"**, **"add"**, or **"ignore"**. Default is **"cat"**.
        decoder_intermediate_channels: The number of channels for the intermediate decoder
            layers. Reduce if you want to reduce the number of parameters in the decoder.
            Default is (224, 448, 896, 896).
        decoder_fusion_channels: The latent dimension to which the encoder features will be
            projected to before fusion. Default is 224.
        in_channels: Number of input channels for the model, default is 3 (RGB images)
        classes: Number of classes for output mask (or you can think as a number of channels
            of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**,
            **"identity"**, **callable** and **None**. Default is **None**.
        aux_params: Dictionary with parameters of the auxiliary output (classification head).
            Auxiliary output is build on top of encoder if **aux_params** is not **None**
            (default). Supported params:
                - **classes** (*int*): A number of classes;
                - **pooling** (*str*): One of "max", "avg". Default is "avg";
                - **dropout** (*float*): Dropout factor in [0, 1);
                - **activation** (*str*): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits).
        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only
            to ``timm`` models. Keys with ``None`` values are pruned before passing. Specify
            ``dynamic_img_size=True`` to allow the model to handle images of different sizes.
            ``mlp_layer`` and ``act_layer`` default to ``SwiGLUPacked`` and ``torch.nn.SiLU``
            and can be overridden here. This constructor has no separate ``encoder_weights``
            argument; everything in ``kwargs`` is forwarded to ``TimmViTEncoder``.

    Returns:
        ``torch.nn.Module``: DPT

    Raises:
        ValueError: If ``decoder_readout`` is not one of the supported modes.
    """

    # fails for encoders with prefix tokens
    _is_torch_scriptable = False
    _is_torch_compilable = True
    requires_divisible_input_shape = True

    @supports_config_loading
    def __init__(
        self,
        encoder_name: str = "hf-hub:paige-ai/Virchow2",
        encoder_depth: int = 4,
        encoder_output_indices: Optional[list[int]] = None,
        decoder_readout: Literal["ignore", "add", "cat"] = "cat",
        decoder_intermediate_channels: Sequence[int] = (224, 448, 896, 896),
        decoder_fusion_channels: int = 224,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable]] = None,
        aux_params: Optional[dict] = None,
        **kwargs: Any,
    ):
        super().__init__()

        if decoder_readout not in ["ignore", "add", "cat"]:
            raise ValueError(
                f"Invalid decoder readout mode. Must be one of: 'ignore', 'add', 'cat'. Got: {decoder_readout}"
            )

        from timm.layers import SwiGLUPacked

        # Defaults for the default encoder; user-supplied kwargs take precedence
        # (passing them explicitly next to **kwargs raised a duplicate-keyword error).
        encoder_kwargs = {
            "mlp_layer": SwiGLUPacked,
            "act_layer": torch.nn.SiLU,
            **kwargs,
        }

        self.encoder = TimmViTEncoder(
            name=encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            output_indices=encoder_output_indices,
            **encoder_kwargs,
        )

        if not self.encoder.has_prefix_tokens and decoder_readout != "ignore":
            warnings.warn(
                f"Encoder does not have prefix tokens (e.g. cls_token), but `decoder_readout` is set to '{decoder_readout}'. "
                "It's recommended to set `decoder_readout='ignore'` when using a encoder without prefix tokens.",
                UserWarning,
                stacklevel=2,
            )

        self.decoder = DPTDecoder(
            encoder_out_channels=self.encoder.out_channels,
            encoder_output_strides=self.encoder.output_strides,
            encoder_has_prefix_tokens=self.encoder.has_prefix_tokens,
            readout=decoder_readout,
            intermediate_channels=decoder_intermediate_channels,
            fusion_channels=decoder_fusion_channels,
        )

        # The decoder resizes level i to a stride of `min_stride * 2**i` and upsamples
        # by 2 after each of its fusion blocks, so its output has a stride of
        # `min_stride / 2` relative to the input. Upsampling by that factor in the
        # segmentation head restores the input resolution.
        decoder_output_stride = min(self.encoder.output_strides) / 2

        self.segmentation_head = DPTSegmentationHead(
            in_channels=decoder_fusion_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=3,
            upsampling=float(decoder_output_stride),
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = f"dpt-{encoder_name.split('/')[-1]}"
        self.initialize()

    def forward(self, x):
        """Sequentially pass `x` through the model's encoder, decoder and heads.

        Returns the masks, or a ``(masks, labels)`` tuple when a classification
        head is configured.
        """

        # Shape validation is skipped while scripting, tracing or compiling
        if not (
            torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling()
        ):
            self.check_input_shape(x)

        features, prefix_tokens = self.encoder(x)
        decoder_output = self.decoder(features, prefix_tokens)
        masks = self.segmentation_head(decoder_output)

        # Ensure contiguous memory layout for DDP compatibility
        masks = masks.contiguous()

        if self.classification_head is not None:
            labels = self.classification_head(features[-1])
            return masks, labels

        return masks


if __name__ == "__main__":
    # Smoke test: instantiating the default model requires access to its encoder weights.
    model = DPT()
    model.eval()
    img = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = model(img)
    print(out.shape)  # torch.Size([1, 1, 224, 224])