# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import torch.nn as nn

from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
from monai.networks.nets.vit import ViT
from monai.utils import deprecated_arg, ensure_tuple_rep


class UNETR(nn.Module):
"""
UNETR based on: "Hatamizadeh et al.,
UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
"""

    @deprecated_arg(
        name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
    )
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        img_size: Sequence[int] | int,
        feature_size: int = 16,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_heads: int = 12,
        pos_embed: str = "conv",
        proj_type: str = "conv",
        norm_name: tuple | str = "instance",
        conv_block: bool = True,
        res_block: bool = True,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        qkv_bias: bool = False,
        save_attn: bool = False,
    ) -> None:
"""
Args:
in_channels: dimension of input channels.
out_channels: dimension of output channels.
img_size: dimension of input image.
feature_size: dimension of network feature size. Defaults to 16.
hidden_size: dimension of hidden layer. Defaults to 768.
mlp_dim: dimension of feedforward layer. Defaults to 3072.
num_heads: number of attention heads. Defaults to 12.
proj_type: patch embedding layer type. Defaults to "conv".
norm_name: feature normalization type and arguments. Defaults to "instance".
conv_block: if convolutional block is used. Defaults to True.
res_block: if residual block is used. Defaults to True.
dropout_rate: fraction of the input units to drop. Defaults to 0.0.
spatial_dims: number of spatial dims. Defaults to 3.
qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False.
save_attn: to make accessible the attention in self attention block. Defaults to False.
.. deprecated:: 1.4
``pos_embed`` is deprecated in favor of ``proj_type``.
Examples::
# for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
>>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')
# for single channel input 4-channel output with image size of (96,96), feature size of 32 and batch norm
>>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)
# for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
>>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), proj_type='conv', norm_name='instance')
"""
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.num_layers = 12
        img_size = ensure_tuple_rep(img_size, spatial_dims)
        self.patch_size = ensure_tuple_rep(16, spatial_dims)
        # Spatial size of the token grid: each 16-voxel patch becomes one token,
        # so every entry of img_size must be divisible by 16.
        self.feat_size = tuple(img_d // p_d for img_d, p_d in zip(img_size, self.patch_size))
        self.hidden_size = hidden_size
        self.classification = False
        self.vit = ViT(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=self.patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_layers=self.num_layers,
            num_heads=num_heads,
            proj_type=proj_type,
            classification=self.classification,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
            qkv_bias=qkv_bias,
            save_attn=save_attn,
        )
        # encoder1 works on the raw input; encoders 2-4 upsample intermediate ViT features.
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
        # Helpers for proj_feat: permute (B, *feat_size, hidden_size) -> (B, hidden_size, *feat_size).
        self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims))
        self.proj_view_shape = list(self.feat_size) + [self.hidden_size]

    def proj_feat(self, x):
        # Reshape the ViT token sequence (B, N, hidden_size) back into a
        # spatial feature map (B, hidden_size, *feat_size) for the conv blocks.
        new_view = [x.size(0)] + self.proj_view_shape
        x = x.view(new_view)
        x = x.permute(self.proj_axes).contiguous()
        return x
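    # Worked shape example (a sketch using the defaults above, not part of the
    # original file): with spatial_dims=3, img_size=(96, 96, 96) and
    # hidden_size=768, feat_size is (6, 6, 6), so proj_feat maps a
    # (B, 216, 768) token tensor to a (B, 768, 6, 6, 6) feature map.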

    def forward(self, x_in):
        # The ViT returns the final token sequence plus the output of every
        # transformer block; skip connections tap the hidden states at
        # indices 3, 6 and 9.
        x, hidden_states_out = self.vit(x_in)
        enc1 = self.encoder1(x_in)
        x2 = hidden_states_out[3]
        enc2 = self.encoder2(self.proj_feat(x2))
        x3 = hidden_states_out[6]
        enc3 = self.encoder3(self.proj_feat(x3))
        x4 = hidden_states_out[9]
        enc4 = self.encoder4(self.proj_feat(x4))
        # Decode from the ViT bottleneck, fusing one skip at each resolution.
        dec4 = self.proj_feat(x)
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)
        return self.out(out)
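

# A minimal smoke-test sketch (not part of the original module; assumes a CPU
# run and the default 16-voxel patch size, so img_size must be divisible by 16):
if __name__ == "__main__":
    import torch

    net = UNETR(in_channels=1, out_channels=2, img_size=(96, 96, 96))
    with torch.no_grad():
        logits = net(torch.randn(1, 1, 96, 96, 96))
    print(logits.shape)  # expected: torch.Size([1, 2, 96, 96, 96])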