Image Segmentation
Transformers
Safetensors
metpredict_dpt
feature-extraction
pathology
dpt
custom_code
Instructions to use RendeiroLab/MetPredict-lung-structure-segmentation with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use RendeiroLab/MetPredict-lung-structure-segmentation with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("image-segmentation", model="RendeiroLab/MetPredict-lung-structure-segmentation", trust_remote_code=True)
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("RendeiroLab/MetPredict-lung-structure-segmentation", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| # %% | |
| import warnings | |
| from typing import Optional, Sequence, Union, Callable, Literal, Any | |
| import torch | |
| import torch.nn as nn | |
| from segmentation_models_pytorch.base import ( | |
| ClassificationHead, | |
| SegmentationModel, | |
| ) | |
| from segmentation_models_pytorch.encoders.timm_vit import TimmViTEncoder | |
| from segmentation_models_pytorch.base.utils import is_torch_compiling | |
| from segmentation_models_pytorch.base.hub_mixin import supports_config_loading | |
| from segmentation_models_pytorch.base.modules import Activation | |
class ReadoutConcatBlock(nn.Module):
    """
    Inject the first prefix (cls) token into every spatial position by concatenation.

    The first prefix token is broadcast over all ``height * width`` positions and concatenated
    to the features along the channel axis; a ``Linear -> GELU`` projection then maps the
    result back to ``embed_dim`` channels. When ``forward`` receives no prefix tokens, only the
    features are projected (this requires ``has_prefix_tokens=False``).

    According to:
        https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L79-L90

    Args:
        embed_dim: Channel dimension of the incoming feature map (and of the output).
        has_prefix_tokens: Whether prefix tokens will be concatenated; doubles the input width
            of the projection when True.
    """

    def __init__(self, embed_dim: int, has_prefix_tokens: bool):
        super().__init__()
        width_multiplier = 2 if has_prefix_tokens else 1
        # Keep the Sequential layout (0 = Linear, 1 = GELU): it defines the state_dict keys
        # ("project.0.weight", ...) that saved checkpoints rely on.
        self.project = nn.Sequential(
            nn.Linear(embed_dim * width_multiplier, embed_dim),
            nn.GELU(),
        )

    def forward(
        self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            features: Feature map of shape (batch_size, embed_dim, height, width).
            prefix_tokens: Optional tokens of shape (batch_size, num_prefix_tokens, embed_dim);
                only the first token is used.

        Returns:
            Tensor of shape (batch_size, embed_dim, height, width).
        """
        n, c, h, w = features.shape
        # (n, c, h, w) -> (n, h * w, c): one row per spatial position
        tokens = features.view(n, c, -1).transpose(1, 2).contiguous()
        if prefix_tokens is not None:
            cls_token = prefix_tokens[:, :1].expand_as(tokens)
            tokens = torch.cat([tokens, cls_token], dim=2)
        projected = self.project(tokens)
        # (n, h * w, c) -> (n, c, h, w)
        return projected.transpose(1, 2).view(n, -1, h, w)
class ReadoutAddBlock(nn.Module):
    """
    Inject the prefix (cls) tokens into the features by broadcast addition.

    All prefix tokens are averaged into one vector per sample, which is added to every
    spatial position of the feature map. Without prefix tokens the input is returned as is.

    According to:
        https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L71-L76
    """

    def forward(
        self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            features: Feature map of shape (batch_size, embed_dim, height, width).
            prefix_tokens: Optional tokens of shape (batch_size, num_prefix_tokens, embed_dim).

        Returns:
            ``features`` itself when ``prefix_tokens`` is None, otherwise a new tensor of the
            same shape.
        """
        if prefix_tokens is None:
            return features
        n, c, _, _ = features.shape
        # Average over the token axis, then shape as (n, c, 1, 1) to broadcast over h and w.
        pooled = prefix_tokens.mean(dim=1).view(n, c, 1, 1)
        return features + pooled
class ReadoutIgnoreBlock(nn.Module):
    """
    Discard any prefix tokens and pass the feature map through untouched.

    Extra positional/keyword arguments (such as prefix tokens) are accepted only so that this
    block can be called exactly like the other readout blocks; they are ignored.
    """

    def forward(self, features: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        del args, kwargs  # accepted for call compatibility, deliberately unused
        return features
class ReassembleBlock(nn.Module):
    """
    Prepare one encoder feature map for fusion: project channels, resample, project again.

    Pipeline: 1x1 conv (``in_channels -> mid_channels``) -> spatial resampling by
    ``upsample_factor`` -> 3x3 conv without bias (``mid_channels -> out_channels``).

    The resampling step depends on ``upsample_factor``:
        * ``> 1``: transposed conv with kernel size and stride ``int(upsample_factor)``;
        * ``== 1``: identity;
        * ``< 1``: 3x3 conv, padding 1, stride ``int(1 / upsample_factor)``.
    Non-integer ratios are truncated by ``int()``.
    """

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        upsample_factor: int,
    ):
        super().__init__()
        # Creation order (and attribute names) are kept stable: they fix parameter
        # initialisation order and the state_dict keys of saved checkpoints.
        self.project_to_out_channel = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.upsample = self._make_resampler(mid_channels, upsample_factor)
        self.project_to_feature_dim = nn.Conv2d(
            mid_channels, out_channels, kernel_size=3, padding=1, bias=False
        )

    @staticmethod
    def _make_resampler(channels: int, factor: float) -> nn.Module:
        """Build the channel-preserving module that rescales the spatial size by ``factor``."""
        if factor == 1.0:
            return nn.Identity()
        if factor > 1.0:
            step = int(factor)
            return nn.ConvTranspose2d(channels, channels, kernel_size=step, stride=step)
        return nn.Conv2d(
            channels, channels, kernel_size=3, stride=int(1 / factor), padding=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.project_to_feature_dim(self.upsample(self.project_to_out_channel(x)))
class ResidualConvBlock(nn.Module):
    """
    Pre-activation residual unit: two ``ReLU -> Conv3x3 -> BatchNorm`` stages plus an identity skip.

    Channel count and spatial size are preserved.

    Args:
        feature_dim: Number of input (and output) channels.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv_1 = nn.Conv2d(feature_dim, feature_dim, **conv_kwargs)
        self.batch_norm_1 = nn.BatchNorm2d(feature_dim)
        self.conv_2 = nn.Conv2d(feature_dim, feature_dim, **conv_kwargs)
        self.batch_norm_2 = nn.BatchNorm2d(feature_dim)
        # Not in-place: the raw input must survive for the skip connection.
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        stages = ((self.conv_1, self.batch_norm_1), (self.conv_2, self.batch_norm_2))
        for conv, norm in stages:
            out = norm(conv(self.activation(out)))
        return out + x
class FusionBlock(nn.Module):
    """
    Fuse one reassembled feature map with the running decoder output and upsample it by 2.

    Steps: residual unit on ``feature`` -> add ``previous_feature`` (if given) -> second
    residual unit -> bilinear x2 upsampling (``align_corners=True``) -> 1x1 convolution.

    Args:
        feature_dim: Channel count of both inputs and of the output.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        self.residual_conv_block1 = ResidualConvBlock(feature_dim)
        self.residual_conv_block2 = ResidualConvBlock(feature_dim)
        self.project = nn.Conv2d(feature_dim, feature_dim, kernel_size=1)
        # NOTE(review): registered but never called in ``forward``. It holds no parameters,
        # so it does not affect checkpoints.
        self.activation = nn.ReLU()

    def forward(
        self,
        feature: torch.Tensor,
        previous_feature: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            feature: Feature map for this fusion stage.
            previous_feature: Output of the preceding fusion stage, added element-wise after the
                first residual unit (so it must be shape-compatible with ``feature``).

        Returns:
            Fused map with twice the spatial size of ``feature``.
        """
        merged = self.residual_conv_block1(feature)
        if previous_feature is not None:
            merged = merged + previous_feature
        refined = self.residual_conv_block2(merged)
        upsampled = nn.functional.interpolate(
            refined, scale_factor=2, mode="bilinear", align_corners=True
        )
        return self.project(upsampled)
class DPTDecoder(nn.Module):
    """
    Decoder part for DPT.

    Every encoder feature map first goes through a readout block (how the prefix / cls tokens
    are used is chosen by ``readout``), then through a ``ReassembleBlock`` that projects it to
    ``fusion_channels`` channels and resamples it by ``stride_i / (min_stride * 2**i)``. Here
    ``min_stride = min(encoder_output_strides)``. Feature ``i`` therefore lands at a nominal
    output stride of ``min_stride * 2**i``, e.g. 14, 28, 56 and 112 for a patch-14 ViT with four
    feature maps.

    The reassembled maps are fused from coarsest to finest by ``FusionBlock``s, each of which
    upsamples by 2. The decoder output therefore has an output stride of ``min_stride / 2``
    relative to the input image.

    Args:
        encoder_out_channels: Channels of each encoder feature map.
        encoder_output_strides: Output stride of each encoder feature map.
        encoder_has_prefix_tokens: Whether the encoder provides prefix tokens (only affects the
            "cat" readout's projection width).
        readout: "cat" (concatenate the first prefix token and project), "add" (add the mean
            prefix token) or "ignore".
        intermediate_channels: Channels used inside each ``ReassembleBlock``.
        fusion_channels: Channels of every reassembled / fused map and of the output.

    Raises:
        ValueError: if the three per-stage sequences differ in length, or ``readout`` is invalid.
    """

    def __init__(
        self,
        encoder_out_channels: Sequence[int] = (756, 756, 756, 756),
        encoder_output_strides: Sequence[int] = (16, 16, 16, 16),
        encoder_has_prefix_tokens: bool = True,
        readout: Literal["cat", "add", "ignore"] = "cat",
        intermediate_channels: Sequence[int] = (256, 512, 1024, 1024),
        fusion_channels: int = 256,
    ):
        super().__init__()

        same_length = (
            len(encoder_out_channels)
            == len(encoder_output_strides)
            == len(intermediate_channels)
        )
        if not same_length:
            raise ValueError(
                "encoder_out_channels, encoder_output_strides and intermediate_channels must have the same length"
            )

        # One readout block per encoder stage.
        if readout == "cat":
            readout_blocks = [
                ReadoutConcatBlock(channels, encoder_has_prefix_tokens)
                for channels in encoder_out_channels
            ]
        elif readout in ("add", "ignore"):
            block_cls = ReadoutAddBlock if readout == "add" else ReadoutIgnoreBlock
            readout_blocks = [block_cls() for _ in encoder_out_channels]
        else:
            raise ValueError(
                f"Invalid readout mode: {readout}, should be one of: 'cat', 'add', 'ignore'"
            )
        self.projection_blocks = nn.ModuleList(readout_blocks)

        # Stage i is resampled to a nominal stride of finest_stride * 2**i; for ViTs, where every
        # stage shares one stride, this builds a multi-scale pyramid from single-scale features.
        finest_stride = min(encoder_output_strides)
        scale_factors = [
            stride / (finest_stride * (2 ** level))
            for level, stride in enumerate(encoder_output_strides)
        ]

        self.reassemble_blocks = nn.ModuleList(
            [
                ReassembleBlock(
                    in_channels=in_ch,
                    mid_channels=mid_ch,
                    out_channels=fusion_channels,
                    upsample_factor=factor,
                )
                for in_ch, mid_ch, factor in zip(
                    encoder_out_channels, intermediate_channels, scale_factors
                )
            ]
        )

        # Fusion blocks merge the reassembled maps sequentially, coarsest first.
        self.fusion_blocks = nn.ModuleList(
            [FusionBlock(fusion_channels) for _ in encoder_out_channels]
        )

    def forward(
        self, features: list[torch.Tensor], prefix_tokens: list[Optional[torch.Tensor]]
    ) -> torch.Tensor:
        """
        Args:
            features: One feature map per stage, in encoder order.
            prefix_tokens: One entry per stage (may be None), fed to that stage's readout block.

        Returns:
            The fused feature map with ``fusion_channels`` channels.
        """
        reassembled = [
            self.reassemble_blocks[idx](self.projection_blocks[idx](feature, tokens))
            for idx, (feature, tokens) in enumerate(zip(features, prefix_tokens))
        ]

        # Start from the coarsest map; each fusion block doubles the spatial size.
        fused = None
        for fusion_block, feature in zip(self.fusion_blocks, reversed(reassembled)):
            fused = fusion_block(feature, fused)
        return fused
class DPTSegmentationHead(nn.Module):
    """
    Prediction head for DPT.

    ``Conv(kernel_size) -> BatchNorm -> ReLU -> Dropout(0.1) -> Conv1x1(out_channels)``, then
    bilinear upsampling by ``upsampling`` (``align_corners=True``), then ``activation``.

    Args:
        in_channels: Channels of the decoder output.
        out_channels: Number of output channels (e.g. segmentation classes).
        activation: Activation spec forwarded to
            ``segmentation_models_pytorch.base.modules.Activation``.
        kernel_size: Kernel size of the first convolution. Padding is ``kernel_size // 2``, so
            odd kernel sizes keep the spatial size unchanged.
        upsampling: Scale factor of the final bilinear interpolation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: Optional[Union[str, Callable]] = None,
        kernel_size: int = 3,
        upsampling: float = 2.0,
    ):
        super().__init__()
        # Sequential indices (0..4) define the checkpoint keys; keep the layout unchanged.
        self.head = nn.Sequential(
            nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=kernel_size,
                # Previously hard-coded to 1, which only preserved size for kernel_size=3.
                padding=kernel_size // 2,
                bias=False,
            ),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1, inplace=False),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )
        self.activation = Activation(activation)
        self.upsampling_factor = upsampling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict per-pixel outputs and upsample them by ``self.upsampling_factor``."""
        head_output = self.head(x)
        resized_output = nn.functional.interpolate(
            head_output,
            scale_factor=self.upsampling_factor,
            mode="bilinear",
            align_corners=True,
        )
        return self.activation(resized_output)
class DPT(SegmentationModel):
    """
    DPT is a dense prediction architecture that leverages vision transformers in place of convolutional
    networks as a backbone for dense prediction tasks.

    It assembles tokens from various stages of the vision transformer into image-like representations at
    various resolutions and progressively combines them into full-resolution predictions using a
    convolutional decoder.

    The transformer backbone processes representations at a constant and relatively high resolution and has
    a global receptive field at every stage. These properties allow the dense vision transformer to provide
    finer-grained and more globally coherent predictions when compared to fully-convolutional networks.

    This variant defaults to the Virchow2 backbone (``"hf-hub:paige-ai/Virchow2"``). The encoder is built
    with ``mlp_layer=SwiGLUPacked`` and ``act_layer=torch.nn.SiLU`` unless those keys are given in ``kwargs``.

    Note:
        Since this model uses a Vision Transformer backbone, it typically requires a fixed input image size.
        To handle variable input sizes, you can set `dynamic_img_size=True` in the model initialization
        (if supported by the specific `timm` encoder). You can check if an encoder requires fixed size
        using `model.encoder.is_fixed_input_size`, and get the required input dimensions from
        `model.encoder.input_size`; however, there is no guarantee that this information is available.

        For a plain ViT, the decoder halves the token grid once per extra feature map (stride-2 convs) and
        doubles it back during fusion. The token grid's height and width must therefore be divisible by
        ``2 ** (number_of_feature_maps - 1)``, e.g. multiples of ``14 * 8 = 112`` pixels for a patch-14
        encoder with four feature maps. Otherwise the fusion stage receives mismatched feature sizes.

        There is no ``encoder_weights`` argument. Whether pretrained backbone weights are loaded is decided
        by ``TimmViTEncoder`` (options can be passed through ``kwargs``).
        TODO(review): confirm its default.

    Args:
        encoder_name: Name of the ``timm`` model used as the encoder (a.k.a. backbone) to extract features.
            Default is ``"hf-hub:paige-ai/Virchow2"``.
        encoder_depth: Number of encoder stages (feature maps) used by the decoder, in range [1, 4].
            For a ViT all stages share the same spatial resolution. Default is 4.
        encoder_output_indices: The indices of the encoder output features to use. If **None**, they are
            sampled uniformly across the encoder blocks, e.g. if the number of stages is 4 and the encoder
            has 20 blocks, encoder_output_indices will be (4, 9, 14, 19). If specified, the number of
            indices should be equal to encoder_depth. Default is **None**.
        decoder_readout: The strategy to utilize the prefix tokens (e.g. cls_token) from the encoder.
            Can be one of **"cat"**, **"add"**, or **"ignore"**. Default is **"cat"**.
        decoder_intermediate_channels: The number of channels for the intermediate decoder layers. Reduce
            to lower the number of decoder parameters. Default is (224, 448, 896, 896).
        decoder_fusion_channels: The latent dimension to which the encoder features are projected before
            fusion. Default is 224.
        in_channels: Number of input channels for the model, default is 3 (RGB images).
        classes: Number of classes for the output mask (i.e. number of output mask channels).
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**,
            **"identity"**, **callable** and **None**. Default is **None**.
        aux_params: Dictionary with parameters of the auxiliary output (classification head). The auxiliary
            output is built on top of the encoder if **aux_params** is not **None** (default is **None**).
            Supported params:
                - **classes** (*int*): A number of classes;
                - **pooling** (*str*): One of "max", "avg". Default is "avg";
                - **dropout** (*float*): Dropout factor in [0, 1);
                - **activation** (*str*): An activation function to apply "sigmoid"/"softmax"
                  (could be **None** to return logits).
        kwargs: Extra keyword arguments forwarded as-is to ``TimmViTEncoder``. Specify
            ``dynamic_img_size=True`` to allow the model to handle images of different sizes.

    Returns:
        ``torch.nn.Module``: DPT
    """

    # fails for encoders with prefix tokens
    _is_torch_scriptable = False
    _is_torch_compilable = True
    requires_divisible_input_shape = True

    def __init__(
        self,
        encoder_name: str = "hf-hub:paige-ai/Virchow2",
        encoder_depth: int = 4,
        encoder_output_indices: Optional[list[int]] = None,
        decoder_readout: Literal["ignore", "add", "cat"] = "cat",
        decoder_intermediate_channels: Sequence[int] = (224, 448, 896, 896),
        decoder_fusion_channels: int = 224,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable]] = None,
        aux_params: Optional[dict] = None,
        **kwargs: Any,
    ):
        super().__init__()

        if decoder_readout not in ["ignore", "add", "cat"]:
            raise ValueError(
                f"Invalid decoder readout mode. Must be one of: 'ignore', 'add', 'cat'. Got: {decoder_readout}"
            )

        from timm.layers import SwiGLUPacked

        # Layer types matching the default Virchow2 backbone. They are only defaults, so other
        # timm ViTs can be used by passing their own ``mlp_layer`` / ``act_layer`` in kwargs
        # (previously that raised "got multiple values for keyword argument").
        kwargs.setdefault("mlp_layer", SwiGLUPacked)
        kwargs.setdefault("act_layer", torch.nn.SiLU)

        self.encoder = TimmViTEncoder(
            name=encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            output_indices=encoder_output_indices,
            **kwargs,
        )

        if not self.encoder.has_prefix_tokens and decoder_readout != "ignore":
            warnings.warn(
                f"Encoder does not have prefix tokens (e.g. cls_token), but `decoder_readout` is set to '{decoder_readout}'. "
                f"It's recommended to set `decoder_readout='ignore'` when using a encoder without prefix tokens.",
                UserWarning,
            )

        self.decoder = DPTDecoder(
            encoder_out_channels=self.encoder.out_channels,
            encoder_output_strides=self.encoder.output_strides,
            encoder_has_prefix_tokens=self.encoder.has_prefix_tokens,
            readout=decoder_readout,
            intermediate_channels=decoder_intermediate_channels,
            fusion_channels=decoder_fusion_channels,
        )

        # DPTDecoder reassembles stage i at stride ``min_stride * 2**i`` and each fusion block
        # upsamples by 2, so its output sits at stride ``min_stride / 2``. Upsampling by that amount
        # restores the input resolution (e.g. 14 / 2 = 7 for a patch-14 ViT). ``min`` (not ``max``)
        # matches the decoder when stage strides differ. The fallback of 14 is presumably the patch
        # size of the default Virchow2 encoder.
        output_strides = self.encoder.output_strides
        finest_stride = min(output_strides) if output_strides else 14
        seg_head_upsampling = float(finest_stride / 2)

        self.segmentation_head = DPTSegmentationHead(
            in_channels=decoder_fusion_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=3,
            upsampling=seg_head_upsampling,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = f"dpt-{encoder_name.split('/')[-1]}"
        self.initialize()

    def forward(self, x):
        """Sequentially pass `x` through the model's encoder, decoder and heads.

        Returns the masks, or ``(masks, labels)`` when a classification head was configured.
        """
        # Shape validation is skipped while scripting, tracing or compiling.
        if not (
            torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling()
        ):
            self.check_input_shape(x)

        features, prefix_tokens = self.encoder(x)

        decoder_output = self.decoder(features, prefix_tokens)
        masks = self.segmentation_head(decoder_output)
        # Ensure contiguous memory layout for DDP compatibility
        masks = masks.contiguous()

        if self.classification_head is not None:
            labels = self.classification_head(features[-1])
            return masks, labels

        return masks
if __name__ == "__main__":
    # Smoke test: build the default model and check the output shape of one forward pass.
    # NOTE(review): the default encoder is "hf-hub:paige-ai/Virchow2"; building it may require
    # network access / Hugging Face credentials depending on TimmViTEncoder's pretrained default.
    model = DPT()
    # Inference mode: BatchNorm uses running stats (which are not updated) and dropout is off.
    model.eval()
    img = torch.randn(1, 3, 224, 224)
    with torch.no_grad():  # no autograd graph needed for a shape check
        out = model(img)
    print(out.shape)  # torch.Size([1, 1, 224, 224])