File size: 6,250 Bytes
34a4bcb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import torch
import torch.nn as nn
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.utils import deprecated_arg
# Public API of this module: only the ViT class is exported by `from ... import *`.
__all__ = ["ViT"]
class ViT(nn.Module):
    """
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    ViT supports Torchscript but only works for Pytorch after 1.8.
    """

    @deprecated_arg(
        name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
    )
    def __init__(
        self,
        in_channels: int,
        img_size: Sequence[int] | int,
        patch_size: Sequence[int] | int,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "conv",
        proj_type: str = "conv",
        pos_embed_type: str = "learnable",
        classification: bool = False,
        num_classes: int = 2,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        post_activation: str | None = "Tanh",
        qkv_bias: bool = False,
        save_attn: bool = False,
    ) -> None:
        """
        Args:
            in_channels (int): dimension of input channels.
            img_size (Union[Sequence[int], int]): dimension of input image.
            patch_size (Union[Sequence[int], int]): dimension of patch size.
            hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
            mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
            num_layers (int, optional): number of transformer blocks. Defaults to 12.
            num_heads (int, optional): number of attention heads. Defaults to 12.
            pos_embed (str, optional): deprecated, use ``proj_type`` instead. This constructor body does not
                read it; the ``deprecated_arg`` decorator is configured with ``new_name="proj_type"``.
            proj_type (str, optional): patch embedding layer type. Defaults to "conv".
            pos_embed_type (str, optional): position embedding type. Defaults to "learnable".
            classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
            num_classes (int, optional): number of classes if classification is used. Defaults to 2.
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
            post_activation (str, optional): add a final activation function to the classification head
                when `classification` is True. Defaults to "Tanh" for `nn.Tanh()`.
                Set to any other value (e.g. None) to remove this function.
            qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
            save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.

        Raises:
            ValueError: if ``dropout_rate`` is outside [0, 1], or ``hidden_size`` is not divisible by ``num_heads``.

        .. deprecated:: 1.2
            ``pos_embed`` is deprecated in favor of ``proj_type`` and is scheduled for removal in 1.4.

        Examples::

            # for single channel input with image size of (96,96,96), conv patch projection,
            # sin-cos position embedding and segmentation backbone
            >>> net = ViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16), proj_type='conv',
            ...           pos_embed_type='sincos')

            # for 3-channel with image size of (128,128,128), 24 layers and classification backbone
            >>> net = ViT(in_channels=3, img_size=(128,128,128), patch_size=(16,16,16), num_layers=24,
            ...           proj_type='conv', pos_embed_type='sincos', classification=True)

            # for 3-channel with image size of (224,224), 12 layers and classification backbone
            >>> net = ViT(in_channels=3, img_size=(224,224), patch_size=(16,16), proj_type='conv',
            ...           pos_embed_type='sincos', classification=True, spatial_dims=2)

        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.classification = classification
        self.patch_embedding = PatchEmbeddingBlock(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            num_heads=num_heads,
            proj_type=proj_type,
            pos_embed_type=pos_embed_type,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
        )
        # `num_layers` identical transformer blocks applied in sequence by `forward`.
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
                for _ in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(hidden_size)
        if self.classification:
            # The learnable class token and the head are only created for classification.
            # `forward` detects them with `hasattr`.
            self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
            if post_activation == "Tanh":
                self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh())
            else:
                self.classification_head = nn.Linear(hidden_size, num_classes)  # type: ignore

    def forward(self, x):
        """
        Args:
            x: input image tensor fed to the patch embedding.
                NOTE(review): presumably (batch, in_channels, *img_size); confirm against PatchEmbeddingBlock.

        Returns:
            A tuple ``(out, hidden_states_out)``:
              - ``out``: the layer-normalized token sequence. When a classification head exists, it is
                instead the head applied to the first (class) token, ``x[:, 0]``.
              - ``hidden_states_out``: list with the output of each transformer block, in order, taken
                before the final layer norm.
        """
        x = self.patch_embedding(x)
        # `hasattr` rather than `self.classification`: these attributes only exist in classification mode.
        # NOTE(review): presumably chosen for TorchScript compatibility; confirm.
        if hasattr(self, "cls_token"):
            # Broadcast the (1, 1, hidden_size) class token over the batch and prepend it along dim 1.
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_token, x), dim=1)
        hidden_states_out = []
        for blk in self.blocks:
            x = blk(x)
            hidden_states_out.append(x)
        x = self.norm(x)
        if hasattr(self, "classification_head"):
            x = self.classification_head(x[:, 0])
        return x, hidden_states_out
|