File size: 18,664 Bytes
34a4bcb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 |
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from torch import nn
from torch.nn import functional as F
from monai.networks.blocks.regunet_block import (
RegistrationDownSampleBlock,
RegistrationExtractionBlock,
RegistrationResidualConvBlock,
get_conv_block,
get_deconv_block,
)
from monai.networks.utils import meshgrid_ij
# Public API of this module. AdditiveUpSampleBlock is not listed; within this file it is only used by LocalNet.
__all__ = ["RegUNet", "AffineHead", "GlobalNet", "LocalNet"]
class RegUNet(nn.Module):
    """
    Class that implements an adapted UNet. This class also serves as the parent class of LocalNet and GlobalNet.

    Reference:
        O. Ronneberger, P. Fischer, and T. Brox,
        "U-net: Convolutional networks for biomedical image segmentation,",
        Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
        https://arxiv.org/abs/1505.04597

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str | None = "kaiming_uniform",
        out_activation: str | None = None,
        out_channels: int = 3,
        extract_levels: tuple[int, ...] | None = None,
        pooling: bool = True,
        concat_skip: bool = False,
        encode_kernel_sizes: int | list[int] = 3,
    ):
        """
        Args:
            spatial_dims: number of spatial dims
            in_channels: number of input channels
            num_channel_initial: number of initial channels; level ``d`` uses ``num_channel_initial * 2**d`` channels
            depth: input is at level 0, bottom is at level depth.
            out_kernel_initializer: kernel initializer for the last layer
            out_activation: activation at the last layer
            out_channels: number of channels for the output
            extract_levels: which levels from net to extract. All levels must be non-negative and the maximum
                level must equal ``depth``. Defaults to ``(depth,)`` when ``None`` or empty.
            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv
            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
            encode_kernel_sizes: kernel size(s) of the encoding conv blocks, either one int for every level or
                a sequence of ``depth + 1`` ints (the last entry is used by the bottom block)

        Raises:
            AssertionError: if ``extract_levels`` contains a negative level, if ``max(extract_levels) != depth``,
                or if ``encode_kernel_sizes`` does not have ``depth + 1`` entries.
        """
        super().__init__()
        if not extract_levels:
            extract_levels = (depth,)
        if min(extract_levels) < 0:
            # negative levels would index skips/num_channels from the end and fail in forward
            raise AssertionError(f"extract_levels must be non-negative, got extract_levels={extract_levels}.")
        if max(extract_levels) != depth:
            raise AssertionError(
                f"max(extract_levels) must equal depth={depth}, got extract_levels={extract_levels}."
            )

        # save parameters
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.num_channel_initial = num_channel_initial
        self.depth = depth
        self.out_kernel_initializer = out_kernel_initializer
        self.out_activation = out_activation
        self.out_channels = out_channels
        self.extract_levels = extract_levels
        self.pooling = pooling
        self.concat_skip = concat_skip

        if isinstance(encode_kernel_sizes, int):
            encode_kernel_sizes = [encode_kernel_sizes] * (self.depth + 1)
        if len(encode_kernel_sizes) != self.depth + 1:
            raise AssertionError(
                f"encode_kernel_sizes must have depth + 1 = {self.depth + 1} entries, "
                f"got {len(encode_kernel_sizes)}: {encode_kernel_sizes}."
            )
        # copy so later mutation of the caller's list does not change this network's configuration
        self.encode_kernel_sizes: list[int] = list(encode_kernel_sizes)

        self.num_channels = [self.num_channel_initial * (2**d) for d in range(self.depth + 1)]
        self.min_extract_level = min(self.extract_levels)

        # init layers
        # all lists start with d = 0
        self.encode_convs: nn.ModuleList
        self.encode_pools: nn.ModuleList
        self.bottom_block: nn.Sequential
        self.decode_deconvs: nn.ModuleList
        self.decode_convs: nn.ModuleList
        self.output_block: nn.Module

        # build layers
        self.build_layers()

    def build_layers(self):
        """Build the encoder (including the bottom block) and then the decoder and output block."""
        self.build_encode_layers()
        self.build_decode_layers()

    def build_encode_layers(self):
        """Build one conv block and one down-sampling block per level ``0..depth-1``, plus the bottom block."""
        # encoding / down-sampling
        self.encode_convs = nn.ModuleList(
            [
                self.build_conv_block(
                    in_channels=self.in_channels if d == 0 else self.num_channels[d - 1],
                    out_channels=self.num_channels[d],
                    kernel_size=self.encode_kernel_sizes[d],
                )
                for d in range(self.depth)
            ]
        )
        self.encode_pools = nn.ModuleList(
            [self.build_down_sampling_block(channels=self.num_channels[d]) for d in range(self.depth)]
        )
        # with depth == 0 there is no encoder level, so the bottom block reads the raw input
        # (num_channels[-2] would not exist in that case)
        self.bottom_block = self.build_bottom_block(
            in_channels=self.in_channels if self.depth == 0 else self.num_channels[-2],
            out_channels=self.num_channels[-1],
        )

    def build_conv_block(self, in_channels: int, out_channels: int, kernel_size: int) -> nn.Module:
        """Return a conv block followed by a residual conv block, both with ``kernel_size``."""
        return nn.Sequential(
            get_conv_block(
                spatial_dims=self.spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
            ),
            RegistrationResidualConvBlock(
                spatial_dims=self.spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
            ),
        )

    def build_down_sampling_block(self, channels: int) -> nn.Module:
        """Return the down-sampling block (pooling or conv, depending on ``self.pooling``)."""
        return RegistrationDownSampleBlock(spatial_dims=self.spatial_dims, channels=channels, pooling=self.pooling)

    def build_bottom_block(self, in_channels: int, out_channels: int) -> nn.Module:
        """Return the bottom (level ``depth``) block, using the last entry of ``encode_kernel_sizes``."""
        kernel_size = self.encode_kernel_sizes[self.depth]
        return nn.Sequential(
            get_conv_block(
                spatial_dims=self.spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
            ),
            RegistrationResidualConvBlock(
                spatial_dims=self.spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
            ),
        )

    def build_decode_layers(self):
        """Build decoder levels ``depth-1`` down to ``min_extract_level`` (deepest first) and the output block."""
        decode_levels = range(self.depth - 1, self.min_extract_level - 1, -1)
        self.decode_deconvs = nn.ModuleList(
            [
                self.build_up_sampling_block(in_channels=self.num_channels[d + 1], out_channels=self.num_channels[d])
                for d in decode_levels
            ]
        )
        self.decode_convs = nn.ModuleList(
            [
                self.build_conv_block(
                    # concatenation with the skip doubles the channel count; addition keeps it
                    in_channels=(2 * self.num_channels[d] if self.concat_skip else self.num_channels[d]),
                    out_channels=self.num_channels[d],
                    kernel_size=3,
                )
                for d in decode_levels
            ]
        )

        # extraction
        self.output_block = self.build_output_block()

    def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module:
        """Return the up-sampling block used between decoder levels."""
        return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels)

    def build_output_block(self) -> nn.Module:
        """Return the block that maps the decoded feature maps at ``extract_levels`` to the output."""
        return RegistrationExtractionBlock(
            spatial_dims=self.spatial_dims,
            extract_levels=self.extract_levels,
            num_channels=self.num_channels,
            out_channels=self.out_channels,
            kernel_initializer=self.out_kernel_initializer,
            activation=self.out_activation,
        )

    def forward(self, x):
        """
        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])

        Returns:
            Tensor in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3]), with the same spatial size as ``x``
        """
        image_size = x.shape[2:]
        skips = []  # skips[d] is the encoder output at level d, d = 0, ..., depth - 1
        encoded = x
        for encode_conv, encode_pool in zip(self.encode_convs, self.encode_pools):
            skip = encode_conv(encoded)
            encoded = encode_pool(skip)
            skips.append(skip)
        decoded = self.bottom_block(encoded)

        # outs[0] is the bottom level; each decoder step appends the next shallower level
        outs = [decoded]
        for i, (decode_deconv, decode_conv) in enumerate(zip(self.decode_deconvs, self.decode_convs)):
            decoded = decode_deconv(decoded)
            # the i-th decoder step (deepest first) pairs with skip level depth - 1 - i
            if self.concat_skip:
                decoded = torch.cat([decoded, skips[-i - 1]], dim=1)
            else:
                decoded = decoded + skips[-i - 1]
            decoded = decode_conv(decoded)
            outs.append(decoded)

        out = self.output_block(outs, image_size=image_size)
        return out
class AffineHead(nn.Module):
    """
    Output head of GlobalNet: regresses affine parameters from a feature map and returns the dense
    displacement field (warped reference grid minus the reference grid) that this affine transform induces.
    """

    def __init__(
        self,
        spatial_dims: int,
        image_size: list[int],
        decode_size: list[int],
        in_channels: int,
        save_theta: bool = False,
    ):
        """
        Args:
            spatial_dims: number of spatial dimensions, 2 or 3
            image_size: output spatial size
            decode_size: input spatial size (two or three integers depending on ``spatial_dims``)
            in_channels: number of input channels
            save_theta: whether to save the theta matrix estimation; if true, ``forward`` stores a detached
                copy of the predicted parameters in ``self.theta``

        Raises:
            ValueError: if ``spatial_dims`` is neither 2 nor 3.
        """
        super().__init__()
        self.spatial_dims = spatial_dims
        if spatial_dims == 2:
            in_features = in_channels * decode_size[0] * decode_size[1]
            out_features = 6
            # flattened 2x3 identity affine matrix
            out_init = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
        elif spatial_dims == 3:
            in_features = in_channels * decode_size[0] * decode_size[1] * decode_size[2]
            out_features = 12
            # flattened 3x4 identity affine matrix
            out_init = torch.tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], dtype=torch.float)
        else:
            raise ValueError(f"only support 2D/3D operation, got spatial_dims={spatial_dims}")

        self.fc = nn.Linear(in_features=in_features, out_features=out_features)
        # plain attribute rather than a registered buffer; forward() moves it to the input's device
        self.grid = self.get_reference_grid(image_size)  # (spatial_dims, ...)

        # zero weights + identity bias: the initial prediction is the identity transform (zero displacement)
        self.fc.weight.data.zero_()
        self.fc.bias.data.copy_(out_init)

        self.save_theta = save_theta
        self.theta = torch.Tensor()

    @staticmethod
    def get_reference_grid(image_size: tuple[int] | list[int]) -> torch.Tensor:
        """
        Return a float32 grid of integer voxel coordinates with shape ``(len(image_size), *image_size)``,
        built from ``torch.arange`` over each axis via ``meshgrid_ij``.
        """
        mesh_points = [torch.arange(0, dim) for dim in image_size]
        grid = torch.stack(meshgrid_ij(*mesh_points), dim=0)  # (spatial_dims, ...)
        return grid.to(dtype=torch.float)

    def affine_transform(self, theta: torch.Tensor) -> torch.Tensor:
        """
        Apply the affine parameters to the reference grid.

        Args:
            theta: affine parameters, reshaped to (batch, 2, 3) in 2D or (batch, 3, 4) in 3D.

        Returns:
            warped grid of shape (batch, spatial_dims, *image_size).

        Raises:
            ValueError: if ``self.spatial_dims`` is neither 2 nor 3.
        """
        # einsum does not mix dtypes, so bring the float32 grid to theta's dtype/device
        # (otherwise e.g. ``model.double()`` or ``model.half()`` fails here)
        grid = self.grid.to(device=theta.device, dtype=theta.dtype)
        # (spatial_dims, ...) -> (spatial_dims + 1, ...): homogeneous coordinates
        grid_padded = torch.cat([grid, torch.ones_like(grid[:1])])

        # grid_warped[b,p,...] = sum_over_q(grid_padded[q,...] * theta[b,p,q])
        if self.spatial_dims == 2:
            grid_warped = torch.einsum("qij,bpq->bpij", grid_padded, theta.reshape(-1, 2, 3))
        elif self.spatial_dims == 3:
            grid_warped = torch.einsum("qijk,bpq->bpijk", grid_padded, theta.reshape(-1, 3, 4))
        else:
            raise ValueError(f"do not support spatial_dims={self.spatial_dims}")
        return grid_warped

    def forward(self, x: list[torch.Tensor], image_size: list[int]) -> torch.Tensor:
        """
        Args:
            x: list of feature maps; only ``x[0]`` is used (in GlobalNet this is the bottom-level output).
            image_size: unused; the output spatial size is fixed by ``image_size`` given at construction.

        Returns:
            displacement field of shape (batch, spatial_dims, *image_size) in the dtype of the fc output.
        """
        f = x[0]
        self.grid = self.grid.to(device=f.device)
        theta = self.fc(f.reshape(f.shape[0], -1))
        if self.save_theta:
            self.theta = theta.detach()
        out: torch.Tensor = self.affine_transform(theta) - self.grid.to(dtype=theta.dtype)
        return out
class GlobalNet(RegUNet):
    """
    Build GlobalNet for image registration.

    Reference:
        Hu, Yipeng, et al.
        "Label-driven weakly-supervised learning
        for multimodal deformable image registration,"
        https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: list[int],
        spatial_dims: int,
        in_channels: int,
        num_channel_initial: int,
        depth: int,
        out_kernel_initializer: str | None = "kaiming_uniform",
        out_activation: str | None = None,
        pooling: bool = True,
        concat_skip: bool = False,
        encode_kernel_sizes: int | list[int] = 3,
        save_theta: bool = False,
    ):
        """
        Args:
            image_size: output displacement field spatial size; must have ``spatial_dims`` entries,
                each divisible by ``2 ** depth``
            spatial_dims: number of spatial dims
            in_channels: number of input channels
            num_channel_initial: number of initial channels
            depth: input is at level 0, bottom is at level depth.
            out_kernel_initializer: kernel initializer for the last layer
            out_activation: activation at the last layer
            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv
            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
            encode_kernel_sizes: kernel size(s) of the encoding conv blocks
            save_theta: whether to save the theta matrix estimation

        Raises:
            ValueError: if ``image_size`` does not have ``spatial_dims`` entries, or if any entry is
                not divisible by ``2 ** depth``.
        """
        # the affine head indexes decode_size per spatial dim and builds a spatial_dims-D grid,
        # so a length mismatch would otherwise fail later with an unrelated IndexError/einsum error
        if len(image_size) != spatial_dims:
            raise ValueError(
                f"image_size must have {spatial_dims} entries (one per spatial dim), got {image_size}"
            )
        for size in image_size:
            if size % (2**depth) != 0:
                raise ValueError(
                    f"given depth {depth}, "
                    f"all input spatial dimension must be divisible by {2 ** depth}, "
                    f"got input of size {image_size}"
                )
        # these must exist before super().__init__(), which calls build_output_block()
        self.image_size = image_size
        self.decode_size = [size // (2**depth) for size in image_size]
        self.save_theta = save_theta
        super().__init__(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            num_channel_initial=num_channel_initial,
            depth=depth,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=spatial_dims,
            pooling=pooling,
            concat_skip=concat_skip,
            encode_kernel_sizes=encode_kernel_sizes,
        )

    def build_output_block(self):
        """Return an AffineHead reading the bottom-level features (``num_channels[-1]`` channels)."""
        return AffineHead(
            spatial_dims=self.spatial_dims,
            image_size=self.image_size,
            decode_size=self.decode_size,
            in_channels=self.num_channels[-1],
            save_theta=self.save_theta,
        )
class AdditiveUpSampleBlock(nn.Module):
    """
    Up-sampling block that adds two branches: a learned deconvolution of the input, and a parameter-free
    interpolation to twice the input spatial size whose channels are reduced to the deconvolution's channel
    count by summing consecutive channel groups.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        mode: str = "nearest",
        align_corners: bool | None = None,
    ):
        """
        Args:
            spatial_dims: number of spatial dims
            in_channels: number of input channels; must be a positive multiple of ``out_channels``
            out_channels: number of output channels
            mode: interpolation mode passed to ``F.interpolate``
            align_corners: ``align_corners`` passed to ``F.interpolate``

        Raises:
            ValueError: if ``in_channels`` is not a positive multiple of ``out_channels``.
        """
        super().__init__()
        # the interpolation branch sums groups of out_channels channels; any other ratio would either
        # fail in torch.stack or silently broadcast against the deconvolution output
        if out_channels <= 0 or in_channels < out_channels or in_channels % out_channels != 0:
            raise ValueError(
                f"in_channels must be a positive multiple of out_channels, "
                f"got in_channels={in_channels}, out_channels={out_channels}"
            )
        self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels)
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape (batch, in_channels, *spatial)

        Returns:
            ``deconv(x)`` plus the interpolated, channel-reduced ``x``; the interpolation branch has spatial
            size ``2 * spatial`` (the deconvolution is assumed to produce the same size — TODO confirm
            against ``get_deconv_block``).
        """
        output_size = [size * 2 for size in x.shape[2:]]
        deconved = self.deconv(x)
        resized = F.interpolate(x, output_size, mode=self.mode, align_corners=self.align_corners)
        # (B, k * out_channels, ...) -> k groups of out_channels channels -> summed to (B, out_channels, ...);
        # for the usual in_channels == 2 * out_channels this is the same as splitting in halves
        groups = resized.split(split_size=deconved.shape[1], dim=1)
        resized = torch.sum(torch.stack(groups, dim=-1), dim=-1)
        out: torch.Tensor = deconved + resized
        return out
class LocalNet(RegUNet):
    """
    Reimplementation of LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_channel_initial: int,
        extract_levels: tuple[int],
        out_kernel_initializer: str | None = "kaiming_uniform",
        out_activation: str | None = None,
        out_channels: int = 3,
        pooling: bool = True,
        use_additive_sampling: bool = True,
        concat_skip: bool = False,
        mode: str = "nearest",
        align_corners: bool | None = None,
    ):
        """
        Args:
            spatial_dims: number of spatial dims
            in_channels: number of input channels
            num_channel_initial: number of initial channels
            extract_levels: which levels from net to extract; the network depth is ``max(extract_levels)``
            out_kernel_initializer: kernel initializer for the last layer
            out_activation: activation at the last layer
            out_channels: number of channels for the output
            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv
            use_additive_sampling: whether to use the additive up-sampling layer for decoding
            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
            mode: interpolation mode used when ``use_additive_sampling`` is true, default is "nearest".
            align_corners: ``align_corners`` for interpolation when ``use_additive_sampling`` is true,
                default is None.
        """
        # stored before super().__init__(): the parent constructor builds the decoder, which calls
        # build_up_sampling_block() and reads these attributes
        self.use_additive_upsampling = use_additive_sampling
        self.mode = mode
        self.align_corners = align_corners

        deepest_level = max(extract_levels)
        # 7-wide kernel at the input level, 3-wide for every deeper level including the bottom
        kernel_sizes = [7, *([3] * deepest_level)]

        super().__init__(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            num_channel_initial=num_channel_initial,
            extract_levels=extract_levels,
            depth=deepest_level,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            pooling=pooling,
            concat_skip=concat_skip,
            encode_kernel_sizes=kernel_sizes,
        )

    def build_bottom_block(self, in_channels: int, out_channels: int):
        """Return a single conv block (no residual block, unlike the parent class) for the bottom level."""
        return get_conv_block(
            spatial_dims=self.spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.encode_kernel_sizes[self.depth],
        )

    def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module:
        """Return a plain deconvolution block, or an AdditiveUpSampleBlock when additive up-sampling is enabled."""
        if not self.use_additive_upsampling:
            return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels)
        return AdditiveUpSampleBlock(
            spatial_dims=self.spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            mode=self.mode,
            align_corners=self.align_corners,
        )
|