milkzheng's picture
Upload folder using huggingface_hub
b2be135 verified
Raw
History Blame Contribute Delete
20.4 kB
# %%
import warnings
from typing import Optional, Sequence, Union, Callable, Literal, Any
import torch
import torch.nn as nn
from segmentation_models_pytorch.base import (
ClassificationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders.timm_vit import TimmViTEncoder
from segmentation_models_pytorch.base.utils import is_torch_compiling
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
from segmentation_models_pytorch.base.modules import Activation
class ReadoutConcatBlock(nn.Module):
    """
    Concatenates the cls token with the features to make use of the global information
    aggregated in the prefix (cls) tokens, then projects the combined feature map back to the
    original embedding dimension with a Linear + GELU layer.

    According to:
    https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L79-L90

    Args:
        embed_dim: Channel dimension of the incoming feature map (and of the output).
        has_prefix_tokens: Whether prefix tokens will be supplied to ``forward``. When True the
            projection expects ``2 * embed_dim`` input features, otherwise ``embed_dim``.
    """

    def __init__(self, embed_dim: int, has_prefix_tokens: bool):
        super().__init__()
        in_features = embed_dim * 2 if has_prefix_tokens else embed_dim
        out_features = embed_dim
        self.project = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.GELU(),
        )

    def forward(
        self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            features: Feature map of shape (batch_size, embed_dim, height, width).
            prefix_tokens: Optional tensor of shape (batch_size, num_prefix_tokens, embed_dim);
                only the first token is used.

        Returns:
            Tensor of shape (batch_size, embed_dim, height, width).
        """
        batch_size, embed_dim, height, width = features.shape

        # Rearrange to (batch_size, height * width, embed_dim).
        # `flatten` falls back to a copy when the spatial dims cannot be merged as a view,
        # so feature maps with non-contiguous spatial layout are accepted as well.
        features = features.flatten(2)
        features = features.transpose(1, 2).contiguous()

        if prefix_tokens is not None:
            # Keep only the first prefix token and broadcast it to every spatial position:
            # (batch_size, num_prefix_tokens, embed_dim) -> (batch_size, height * width, embed_dim)
            prefix_tokens = prefix_tokens[:, :1].expand_as(features)
            features = torch.cat([features, prefix_tokens], dim=2)

        # Project to embedding dimension
        features = self.project(features)

        # Rearrange back to (batch_size, embed_dim, height, width)
        features = features.transpose(1, 2)
        features = features.reshape(batch_size, -1, height, width)
        return features
class ReadoutAddBlock(nn.Module):
    """
    Injects global context into the feature map by adding the mean of the prefix (cls) tokens
    to every spatial position. Without prefix tokens the features pass through untouched.

    According to:
    https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L71-L76
    """

    def forward(
        self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Nothing to add: hand the input back unchanged.
        if prefix_tokens is None:
            return features

        batch_size, embed_dim, _, _ = features.shape
        # Average over the token axis, then shape as (batch, channels, 1, 1) so that the
        # addition broadcasts across height and width.
        pooled_tokens = prefix_tokens.mean(dim=1).view(batch_size, embed_dim, 1, 1)
        return features + pooled_tokens
class ReadoutIgnoreBlock(nn.Module):
    """
    No-op readout: the feature map is returned as received and any prefix tokens are discarded.
    """

    def forward(self, features: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Extra positional / keyword arguments (e.g. prefix tokens) are accepted only so this
        # block can be called the same way as the other readout blocks.
        return features
class ReassembleBlock(nn.Module):
    """
    Projects a feature map to ``mid_channels``, resamples it spatially by ``upsample_factor``
    and projects the result to ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        mid_channels: Number of channels used while resampling.
        out_channels: Number of channels of the returned feature map.
        upsample_factor: Spatial resampling factor. Values above 1 upsample with a transposed
            convolution (kernel = stride = ``int(upsample_factor)``), exactly 1 keeps the
            resolution, and values between 0 and 1 downsample with a strided 3x3 convolution
            (stride = ``int(1 / upsample_factor)``).

    Raises:
        ValueError: If ``upsample_factor`` is not strictly positive.
    """

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        upsample_factor: float,
    ):
        super().__init__()

        # A zero factor would otherwise surface as a bare ZeroDivisionError and a negative one
        # as a negative convolution stride; reject both up front with a clear message.
        if upsample_factor <= 0:
            raise ValueError(
                f"upsample_factor must be positive, got {upsample_factor}"
            )

        self.project_to_out_channel = nn.Conv2d(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
        )

        if upsample_factor > 1.0:
            # Learned upsampling: non-overlapping transposed convolution.
            self.upsample = nn.ConvTranspose2d(
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=int(upsample_factor),
                stride=int(upsample_factor),
            )
        elif upsample_factor == 1.0:
            self.upsample = nn.Identity()
        else:
            # Learned downsampling: strided 3x3 convolution.
            self.upsample = nn.Conv2d(
                in_channels=mid_channels,
                out_channels=mid_channels,
                kernel_size=3,
                stride=int(1 / upsample_factor),
                padding=1,
            )

        self.project_to_feature_dim = nn.Conv2d(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply channel projection, spatial resampling and the final 3x3 projection in order."""
        x = self.project_to_out_channel(x)
        x = self.upsample(x)
        x = self.project_to_feature_dim(x)
        return x
class ResidualConvBlock(nn.Module):
    """
    Residual unit made of two (ReLU -> 3x3 conv -> BatchNorm) stages whose result is added
    back onto the block input. Channel count and spatial size are preserved.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        # Both convolutions share the same geometry: 3x3, "same" padding, no bias
        # (each is followed by a BatchNorm layer).
        conv_options = dict(kernel_size=3, padding=1, bias=False)

        self.conv_1 = nn.Conv2d(feature_dim, feature_dim, **conv_options)
        self.batch_norm_1 = nn.BatchNorm2d(feature_dim)
        self.conv_2 = nn.Conv2d(feature_dim, feature_dim, **conv_options)
        self.batch_norm_2 = nn.BatchNorm2d(feature_dim)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First stage
        hidden = self.batch_norm_1(self.conv_1(self.activation(x)))
        # Second stage
        hidden = self.batch_norm_2(self.conv_2(self.activation(hidden)))
        # Skip connection around both stages
        return hidden + x
class FusionBlock(nn.Module):
    """
    Merges a processed encoder feature with the output of the previous fusion stage (when
    there is one), doubles the spatial resolution and applies a 1x1 projection.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        self.residual_conv_block1 = ResidualConvBlock(feature_dim)
        self.residual_conv_block2 = ResidualConvBlock(feature_dim)
        self.project = nn.Conv2d(feature_dim, feature_dim, kernel_size=1)
        # Registered as a submodule but not applied in `forward`.
        self.activation = nn.ReLU()

    def forward(
        self,
        feature: torch.Tensor,
        previous_feature: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        fused = self.residual_conv_block1(feature)

        # The first fusion stage has nothing to merge with.
        if previous_feature is not None:
            fused = fused + previous_feature

        fused = self.residual_conv_block2(fused)

        # Bilinear x2 upsampling followed by the 1x1 projection.
        upsampled = nn.functional.interpolate(
            fused, scale_factor=2, mode="bilinear", align_corners=True
        )
        return self.project(upsampled)
class DPTDecoder(nn.Module):
    """
    Decoder part for DPT.

    Each encoder feature is first passed through a readout block (which concatenates, adds or
    ignores the prefix tokens) and then through a reassemble block that resamples it so the
    i-th feature ends up at stride ``min(encoder_output_strides) * 2**i`` relative to the
    input image. For a ViT whose layers all share one stride ``s`` this yields strides
    ``[s, 2s, 4s, 8s, ...]``.

    The features are then fused from coarsest to finest; every fusion block upsamples by a
    factor of 2, so the decoder output has stride ``min(encoder_output_strides) / 2``.

    Args:
        encoder_out_channels: Channel count of each encoder feature.
        encoder_output_strides: Stride of each encoder feature relative to the input image.
        encoder_has_prefix_tokens: Whether the encoder provides prefix tokens (e.g. cls_token).
        readout: How prefix tokens are used: ``"cat"``, ``"add"`` or ``"ignore"``.
        intermediate_channels: Channel count used inside each reassemble block.
        fusion_channels: Channel count shared by all fusion blocks and the decoder output.

    Raises:
        ValueError: If the three per-level sequences differ in length or are empty, or if
            ``readout`` is not one of the supported modes.
    """

    def __init__(
        self,
        encoder_out_channels: Sequence[int] = (756, 756, 756, 756),
        encoder_output_strides: Sequence[int] = (16, 16, 16, 16),
        encoder_has_prefix_tokens: bool = True,
        readout: Literal["cat", "add", "ignore"] = "cat",
        intermediate_channels: Sequence[int] = (256, 512, 1024, 1024),
        fusion_channels: int = 256,
    ):
        super().__init__()

        if not (
            len(encoder_out_channels)
            == len(encoder_output_strides)
            == len(intermediate_channels)
        ):
            raise ValueError(
                "encoder_out_channels, encoder_output_strides and intermediate_channels must have the same length"
            )

        num_blocks = len(encoder_out_channels)

        # If encoder has prefix tokens (e.g. cls_token), then we can concat/add/ignore them
        # according to the readout mode
        if readout == "cat":
            blocks = [
                ReadoutConcatBlock(in_channels, encoder_has_prefix_tokens)
                for in_channels in encoder_out_channels
            ]
        elif readout == "add":
            blocks = [ReadoutAddBlock() for _ in encoder_out_channels]
        elif readout == "ignore":
            blocks = [ReadoutIgnoreBlock() for _ in encoder_out_channels]
        else:
            raise ValueError(
                f"Invalid readout mode: {readout}, should be one of: 'cat', 'add', 'ignore'"
            )

        self.projection_blocks = nn.ModuleList(blocks)

        # Without any feature level there is no stride to anchor the pyramid on; fail with an
        # explicit message instead of the opaque error `min()` raises on an empty sequence.
        if num_blocks == 0:
            raise ValueError("DPTDecoder requires at least one encoder feature level")

        # Resampling factors that turn the encoder features into a pyramid of progressively
        # smaller maps: level 0 stays at the finest encoder stride, level 1 is halved, level 2
        # quartered, and so on. A factor below 1 means downsampling.
        min_stride = min(encoder_output_strides)
        scale_factors = [
            stride / (min_stride * (2**level))
            for level, stride in enumerate(encoder_output_strides)
        ]

        self.reassemble_blocks = nn.ModuleList(
            [
                ReassembleBlock(
                    in_channels=in_channels,
                    mid_channels=mid_channels,
                    out_channels=fusion_channels,
                    upsample_factor=scale_factor,
                )
                for in_channels, mid_channels, scale_factor in zip(
                    encoder_out_channels, intermediate_channels, scale_factors
                )
            ]
        )

        # Fusion blocks to fuse the processed features in a sequential manner
        fusion_blocks = [FusionBlock(fusion_channels) for _ in range(num_blocks)]
        self.fusion_blocks = nn.ModuleList(fusion_blocks)

    def forward(
        self, features: list[torch.Tensor], prefix_tokens: list[Optional[torch.Tensor]]
    ) -> torch.Tensor:
        """
        Args:
            features: Encoder feature maps, one per level, ordered as in ``encoder_out_channels``.
            prefix_tokens: Matching prefix tokens for each level (entries may be ``None``).

        Returns:
            The fused feature map produced by the last fusion block.
        """
        # Readout + reassemble every level into the multi-scale pyramid
        processed_features = [
            reassemble(readout(feature, tokens))
            for readout, reassemble, feature, tokens in zip(
                self.projection_blocks, self.reassemble_blocks, features, prefix_tokens
            )
        ]

        # Fusion and progressive upsampling starting from the last (coarsest) processed feature
        fused_feature = None
        for fusion_block, feature in zip(self.fusion_blocks, reversed(processed_features)):
            fused_feature = fusion_block(feature, fused_feature)

        return fused_feature
class DPTSegmentationHead(nn.Module):
    """
    Segmentation head for DPT: conv -> BatchNorm -> ReLU -> Dropout -> 1x1 conv, followed by
    bilinear upsampling and an optional output activation.

    Args:
        in_channels: Number of channels of the decoder output.
        out_channels: Number of output channels (classes).
        activation: Activation specification forwarded to ``Activation``.
        kernel_size: Kernel size of the first convolution.
        upsampling: Scale factor of the final bilinear interpolation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: Optional[Union[str, Callable]] = None,
        kernel_size: int = 3,
        upsampling: float = 2.0,
    ):
        super().__init__()

        self.head = nn.Sequential(
            nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=kernel_size,
                # Derive the padding from the kernel size so the spatial size is preserved for
                # any odd kernel, not only for the default 3x3 (where this is still 1).
                padding=kernel_size // 2,
                bias=False,
            ),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1, inplace=False),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )
        self.activation = Activation(activation)
        self.upsampling_factor = upsampling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class maps from ``x``, upsample them and apply the output activation."""
        head_output = self.head(x)
        resized_output = nn.functional.interpolate(
            head_output,
            scale_factor=self.upsampling_factor,
            mode="bilinear",
            align_corners=True,
        )
        activation_output = self.activation(resized_output)
        return activation_output
class DPT(SegmentationModel):
    """
    DPT is a dense prediction architecture that leverages vision transformers in place of convolutional networks as
    a backbone for dense prediction tasks

    It assembles tokens from various stages of the vision transformer into image-like representations at various resolutions
    and progressively combines them into full-resolution predictions using a convolutional decoder.

    The transformer backbone processes representations at a constant and relatively high resolution and has a global receptive
    field at every stage. These properties allow the dense vision transformer to provide finer-grained and more globally coherent
    predictions when compared to fully-convolutional networks

    Note:
        Since this model uses a Vision Transformer backbone, it typically requires a fixed input image size.
        To handle variable input sizes, you can set `dynamic_img_size=True` in the model initialization
        (if supported by the specific `timm` encoder). You can check if an encoder requires fixed size
        using `model.encoder.is_fixed_input_size`, and get the required input dimensions from
        `model.encoder.input_size`, however it's no guarantee that information is available.

    Note:
        This constructor has no ``encoder_weights`` argument and always builds the encoder with
        ``mlp_layer=SwiGLUPacked`` and ``act_layer=torch.nn.SiLU``. Passing either of those two
        keys through ``kwargs`` therefore raises a duplicate-keyword ``TypeError``.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution. Default is ``"hf-hub:paige-ai/Virchow2"``.
        encoder_depth: A number of stages used in encoder in range [1,4]. Each stage generate features
            smaller by a factor equal to the ViT model patch_size in spatial dimensions.
            Default is 4.
        encoder_output_indices: The indices of the encoder output features to use. If **None** will be sampled uniformly
            across the number of blocks in encoder, e.g. if number of blocks is 4 and encoder has 20 blocks, then
            encoder_output_indices will be (4, 9, 14, 19). If specified the number of indices should be equal to
            encoder_depth. Default is **None**.
        decoder_readout: The strategy to utilize the prefix tokens (e.g. cls_token) from the encoder.
            Can be one of **"cat"**, **"add"**, or **"ignore"**. Default is **"cat"**.
        decoder_intermediate_channels: The number of channels for the intermediate decoder layers. Reduce if you
            want to reduce the number of parameters in the decoder. Default is (224, 448, 896, 896).
        decoder_fusion_channels: The latent dimension to which the encoder features will be projected to before fusion.
            Default is 224.
        in_channels: Number of input channels for the model, default is 3 (RGB images)
        classes: Number of classes for output mask (or you can think as a number of channels of output mask).
            Default is 1.
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**. Default is **None**.
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
            on top of encoder if **aux_params** is not **None** (default). Supported params:
                - **classes** (*int*): A number of classes;
                - **pooling** (*str*): One of "max", "avg". Default is "avg";
                - **dropout** (*float*): Dropout factor in [0, 1);
                - **activation** (*str*): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits).
        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with
            ``None`` values are pruned before passing. Specify ``dynamic_img_size=True`` to allow the model to handle images of different sizes.

    Returns:
        ``torch.nn.Module``: DPT

    Raises:
        ValueError: If ``decoder_readout`` is not one of ``"ignore"``, ``"add"``, ``"cat"``.
    """

    # fails for encoders with prefix tokens
    _is_torch_scriptable = False
    _is_torch_compilable = True
    requires_divisible_input_shape = True

    @supports_config_loading
    def __init__(
        self,
        encoder_name: str = "hf-hub:paige-ai/Virchow2",
        encoder_depth: int = 4,
        encoder_output_indices: Optional[list[int]] = None,
        decoder_readout: Literal["ignore", "add", "cat"] = "cat",
        decoder_intermediate_channels: Sequence[int] = (224, 448, 896, 896),
        decoder_fusion_channels: int = 224,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable]] = None,
        aux_params: Optional[dict] = None,
        **kwargs: dict[str, Any],
    ):
        super().__init__()

        if decoder_readout not in ["ignore", "add", "cat"]:
            raise ValueError(
                f"Invalid decoder readout mode. Must be one of: 'ignore', 'add', 'cat'. Got: {decoder_readout}"
            )

        # Imported lazily so that `timm.layers` is only required when a model is built.
        from timm.layers import SwiGLUPacked

        self.encoder = TimmViTEncoder(
            name=encoder_name,
            mlp_layer=SwiGLUPacked,
            act_layer=torch.nn.SiLU,
            in_channels=in_channels,
            depth=encoder_depth,
            output_indices=encoder_output_indices,
            **kwargs,
        )

        if not self.encoder.has_prefix_tokens and decoder_readout != "ignore":
            warnings.warn(
                f"Encoder does not have prefix tokens (e.g. cls_token), but `decoder_readout` is set to '{decoder_readout}'. "
                f"It's recommended to set `decoder_readout='ignore'` when using a encoder without prefix tokens.",
                UserWarning,
            )

        self.decoder = DPTDecoder(
            encoder_out_channels=self.encoder.out_channels,
            encoder_output_strides=self.encoder.output_strides,
            encoder_has_prefix_tokens=self.encoder.has_prefix_tokens,
            readout=decoder_readout,
            intermediate_channels=decoder_intermediate_channels,
            fusion_channels=decoder_fusion_channels,
        )

        # The decoder anchors its feature pyramid on the smallest encoder stride `s` and every
        # fusion block upsamples by 2, so the decoder output sits at stride `s / 2`. The head
        # must therefore upsample by `s / 2` to get back to the input resolution
        # (e.g. s = 14 -> decoder output at stride 7 -> x7 upsampling).
        # The same minimum is used here as in the decoder; for encoders whose levels all share
        # one stride this equals the maximum.
        output_strides = self.encoder.output_strides
        encoder_stride = min(output_strides) if output_strides else 14
        seg_head_upsampling = float(encoder_stride / 2)

        self.segmentation_head = DPTSegmentationHead(
            in_channels=decoder_fusion_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=3,
            upsampling=seg_head_upsampling,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        # Only the part after the last "/" of the encoder name is kept in the model name.
        self.name = f"dpt-{encoder_name.split('/')[-1]}"
        self.initialize()

    def forward(self, x):
        """Sequentially pass `x` through the model's encoder, decoder and heads.

        Returns:
            The segmentation masks, or a ``(masks, labels)`` tuple when the auxiliary
            classification head is enabled.
        """
        # Shape validation is skipped while scripting, tracing or compiling.
        if not (
            torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling()
        ):
            self.check_input_shape(x)

        features, prefix_tokens = self.encoder(x)
        decoder_output = self.decoder(features, prefix_tokens)
        masks = self.segmentation_head(decoder_output)

        # Ensure contiguous memory layout for DDP compatibility
        masks = masks.contiguous()

        if self.classification_head is not None:
            labels = self.classification_head(features[-1])
            return masks, labels

        return masks
if __name__ == "__main__":
    # Smoke test: build the model with its defaults and run one random RGB image through it.
    net = DPT()
    dummy_batch = torch.randn(1, 3, 224, 224)
    prediction = net(dummy_batch)
    print(prediction.shape)  # torch.Size([1, 1, 224, 224])