|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import math |
|
|
from collections.abc import Sequence |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock |
|
|
from monai.networks.blocks.transformerblock import TransformerBlock |
|
|
from monai.networks.layers import Conv |
|
|
from monai.utils import deprecated_arg, ensure_tuple_rep, is_sqrt |
|
|
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = [
    "ViTAutoEnc",
]
|
|
|
|
|
|
|
|
class ViTAutoEnc(nn.Module):
    """
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    Modified to also give same dimension outputs as the input size of the image: after the transformer,
    the token sequence is reshaped back onto the patch grid and upsampled by two transposed convolutions,
    each with kernel size and stride equal to ``sqrt(patch_size)`` per dimension, so the combined
    upsampling factor per dimension equals ``patch_size``.
    """

    @deprecated_arg(
        name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
    )
    def __init__(
        self,
        in_channels: int,
        img_size: Sequence[int] | int,
        patch_size: Sequence[int] | int,
        out_channels: int = 1,
        deconv_chns: int = 16,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "conv",
        proj_type: str = "conv",
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        qkv_bias: bool = False,
        save_attn: bool = False,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels or the number of channels for input.
            img_size: dimension of input image; a single int is repeated for every spatial dimension.
            patch_size: dimension of patch size; a single int is repeated for every spatial dimension.
                Every entry must be a square number and must divide the matching ``img_size`` entry.
            out_channels: number of output channels. Defaults to 1.
            deconv_chns: number of channels for the deconvolution layers. Defaults to 16.
            hidden_size: dimension of hidden layer. Defaults to 768.
            mlp_dim: dimension of feedforward layer. Defaults to 3072.
            num_layers: number of transformer blocks. Defaults to 12.
            num_heads: number of attention heads. Defaults to 12.
            pos_embed: deprecated alias of ``proj_type``; use ``proj_type`` instead.
            proj_type: patch embedding projection type, forwarded to ``PatchEmbeddingBlock``. Defaults to "conv".
            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
            spatial_dims: number of spatial dimensions. Defaults to 3.
            qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
            save_attn: to make accessible the attention in self attention block. Defaults to False.

        Raises:
            ValueError: when ``patch_size`` is not a square number (as judged by ``is_sqrt``).
            ValueError: when an ``img_size`` entry is not divisible by the matching ``patch_size`` entry.

        .. deprecated:: 1.2
            ``pos_embed`` is deprecated in favor of ``proj_type`` and is scheduled for removal in 1.4.

        Examples::

            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
            # It will provide an output of same size as that of the input
            >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), proj_type='conv')

            # for 3-channel with image size of (128,128,128), output will be same size as of input
            >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), proj_type='conv')

        """
        # NOTE: ``pos_embed`` is not read in this body; the ``deprecated_arg`` decorator above declares
        # ``proj_type`` as its replacement (new_name) — presumably it forwards the value; confirm in monai.utils.
        super().__init__()

        # The decoder upsamples with two transposed convs of stride sqrt(patch), so each patch
        # size must be a perfect square for the round trip to restore the input size exactly.
        if not is_sqrt(patch_size):
            raise ValueError(f"patch_size should be square number, got {patch_size}.")
        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
        self.img_size = ensure_tuple_rep(img_size, spatial_dims)
        self.spatial_dims = spatial_dims
        for m, p in zip(self.img_size, self.patch_size):
            if m % p != 0:
                raise ValueError(f"img_size={img_size} should be divisible by patch_size={patch_size}.")

        self.patch_embedding = PatchEmbeddingBlock(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            num_heads=num_heads,
            proj_type=proj_type,
            dropout_rate=dropout_rate,
            spatial_dims=self.spatial_dims,
        )
        self.blocks = nn.ModuleList(
            [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn) for _ in range(num_layers)]
        )
        self.norm = nn.LayerNorm(hidden_size)

        conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims]
        # kernel_size == stride (no padding) makes each transposed conv scale every spatial size by exactly
        # sqrt(p); applying it twice scales by p, undoing the patch grid reduction.
        up_kernel_size = [int(math.sqrt(i)) for i in self.patch_size]
        # Attribute names say "3d" even for 2D inputs; they are kept because renaming would change state_dict keys.
        self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=up_kernel_size, stride=up_kernel_size)
        self.conv3d_transpose_1 = conv_trans(
            in_channels=deconv_chns, out_channels=out_channels, kernel_size=up_kernel_size, stride=up_kernel_size
        )

    def forward(self, x):
        """
        Args:
            x: input tensor of shape ``[batch_size, channels, *spatial]``. Each spatial size is divided by the
                matching ``patch_size`` entry to recover the patch grid, so it must be divisible by it.
                NOTE(review): presumably the spatial size must also equal ``img_size``, since the patch
                embedding is built for that size — confirm against ``PatchEmbeddingBlock``.

        Returns:
            A tuple ``(out, hidden_states_out)``: ``out`` has ``out_channels`` channels and the same spatial
            size as ``x``; ``hidden_states_out`` is a list with the output of every transformer block, in order.
        """
        spatial_size = x.shape[2:]
        x = self.patch_embedding(x)
        hidden_states_out = []
        for blk in self.blocks:
            x = blk(x)
            hidden_states_out.append(x)
        x = self.norm(x)
        # (batch, tokens, hidden) -> (batch, hidden, tokens), then unflatten tokens onto the patch grid.
        x = x.transpose(1, 2)
        d = [s // p for s, p in zip(spatial_size, self.patch_size)]
        x = torch.reshape(x, [x.shape[0], x.shape[1], *d])
        x = self.conv3d_transpose(x)
        x = self.conv3d_transpose_1(x)
        return x, hidden_states_out
|
|
|