Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/kiwisolver/_cext.cpython-38-x86_64-linux-gnu.so +3 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/matplotlib/ft2font.cpython-38-x86_64-linux-gnu.so +3 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/dataloader.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/png_saver.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/png_writer.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/__pycache__/utils.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/acti_norm.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/activation.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/convolutions.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/dints_block.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/downsample.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/dynunet_block.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/localnet_block.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/mlp.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/regunet_block.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/selfattention.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/squeeze_and_excitation.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/transformerblock.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/warp.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__init__.py +31 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/utils.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/convutils.py +227 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/drop_path.py +45 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/factories.py +371 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/filtering.py +101 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/gmm.py +85 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/simplelayers.py +532 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/spatial_transforms.py +564 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/utils.py +116 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/weight_init.py +64 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__init__.py +91 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/attentionunet.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/autoencoder.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/basic_unet.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/dynunet.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/efficientnet.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/fullyconnectednet.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/netadapter.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/regressor.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/regunet.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/resnet.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/segresnet.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/senet.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/swin_unetr.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/torchvision_fc.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/unet.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/unetr.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/varautoencoder.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/vit.cpython-38.pyc +0 -0
.gitattributes
CHANGED
|
@@ -339,3 +339,6 @@ my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/opencv_pyth
|
|
| 339 |
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/opencv_python.libs/libxkbcommon-71ae2972.so.0.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 340 |
my_container_sandbox/workspace/anaconda3/lib/libnvvm.so.4.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 341 |
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/sklearn/_isotonic.cpython-38-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/opencv_python.libs/libxkbcommon-71ae2972.so.0.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 340 |
my_container_sandbox/workspace/anaconda3/lib/libnvvm.so.4.0.0 filter=lfs diff=lfs merge=lfs -text
|
| 341 |
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/sklearn/_isotonic.cpython-38-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 342 |
+
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/matplotlib/ft2font.cpython-38-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 343 |
+
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/torchvision.libs/libnvjpeg.90286a3c.so.11 filter=lfs diff=lfs merge=lfs -text
|
| 344 |
+
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/kiwisolver/_cext.cpython-38-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/kiwisolver/_cext.cpython-38-x86_64-linux-gnu.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:847be32e761c90dedc0a587a45e0b9b1bc1d9c6a9322e3e278bad141fb8ebe90
|
| 3 |
+
size 4260107
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/matplotlib/ft2font.cpython-38-x86_64-linux-gnu.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:03e5cffa72fb1aaef4acef0341ad5f6bb19501fcb384aaa7e7321e570813ac51
|
| 3 |
+
size 4120999
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/dataloader.cpython-38.pyc
ADDED
|
Binary file (2.96 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/png_saver.cpython-38.pyc
ADDED
|
Binary file (6.5 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/png_writer.cpython-38.pyc
ADDED
|
Binary file (2.92 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/__pycache__/utils.cpython-38.pyc
ADDED
|
Binary file (23.8 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/acti_norm.cpython-38.pyc
ADDED
|
Binary file (3.55 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/activation.cpython-38.pyc
ADDED
|
Binary file (5.8 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/convolutions.cpython-38.pyc
ADDED
|
Binary file (9.72 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/dints_block.cpython-38.pyc
ADDED
|
Binary file (6.55 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/downsample.cpython-38.pyc
ADDED
|
Binary file (2.11 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/dynunet_block.cpython-38.pyc
ADDED
|
Binary file (8.65 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/localnet_block.cpython-38.pyc
ADDED
|
Binary file (9.86 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/mlp.cpython-38.pyc
ADDED
|
Binary file (2.4 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/regunet_block.cpython-38.pyc
ADDED
|
Binary file (8.18 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/selfattention.cpython-38.pyc
ADDED
|
Binary file (2.07 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/squeeze_and_excitation.cpython-38.pyc
ADDED
|
Binary file (9.84 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/transformerblock.cpython-38.pyc
ADDED
|
Binary file (1.71 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/warp.cpython-38.pyc
ADDED
|
Binary file (5.24 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
|
| 13 |
+
from .drop_path import DropPath
|
| 14 |
+
from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
|
| 15 |
+
from .filtering import BilateralFilter, PHLFilter
|
| 16 |
+
from .gmm import GaussianMixtureModel
|
| 17 |
+
from .simplelayers import (
|
| 18 |
+
LLTM,
|
| 19 |
+
ChannelPad,
|
| 20 |
+
Flatten,
|
| 21 |
+
GaussianFilter,
|
| 22 |
+
HilbertTransform,
|
| 23 |
+
Reshape,
|
| 24 |
+
SavitzkyGolayFilter,
|
| 25 |
+
SkipConnection,
|
| 26 |
+
apply_filter,
|
| 27 |
+
separable_filtering,
|
| 28 |
+
)
|
| 29 |
+
from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
|
| 30 |
+
from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer
|
| 31 |
+
from .weight_init import _no_grad_trunc_normal_, trunc_normal_
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/utils.cpython-38.pyc
ADDED
|
Binary file (3.53 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/convutils.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
from typing import List, Optional, Sequence, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
__all__ = ["same_padding", "stride_minus_kernel_padding", "calculate_out_shape", "gaussian_1d", "polyval"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def same_padding(
|
| 21 |
+
kernel_size: Union[Sequence[int], int], dilation: Union[Sequence[int], int] = 1
|
| 22 |
+
) -> Union[Tuple[int, ...], int]:
|
| 23 |
+
"""
|
| 24 |
+
Return the padding value needed to ensure a convolution using the given kernel size produces an output of the same
|
| 25 |
+
shape as the input for a stride of 1, otherwise ensure a shape of the input divided by the stride rounded down.
|
| 26 |
+
|
| 27 |
+
Raises:
|
| 28 |
+
NotImplementedError: When ``np.any((kernel_size - 1) * dilation % 2 == 1)``.
|
| 29 |
+
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
kernel_size_np = np.atleast_1d(kernel_size)
|
| 33 |
+
dilation_np = np.atleast_1d(dilation)
|
| 34 |
+
|
| 35 |
+
if np.any((kernel_size_np - 1) * dilation % 2 == 1):
|
| 36 |
+
raise NotImplementedError(
|
| 37 |
+
f"Same padding not available for kernel_size={kernel_size_np} and dilation={dilation_np}."
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
padding_np = (kernel_size_np - 1) / 2 * dilation_np
|
| 41 |
+
padding = tuple(int(p) for p in padding_np)
|
| 42 |
+
|
| 43 |
+
return padding if len(padding) > 1 else padding[0]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def stride_minus_kernel_padding(
|
| 47 |
+
kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int]
|
| 48 |
+
) -> Union[Tuple[int, ...], int]:
|
| 49 |
+
kernel_size_np = np.atleast_1d(kernel_size)
|
| 50 |
+
stride_np = np.atleast_1d(stride)
|
| 51 |
+
|
| 52 |
+
out_padding_np = stride_np - kernel_size_np
|
| 53 |
+
out_padding = tuple(int(p) for p in out_padding_np)
|
| 54 |
+
|
| 55 |
+
return out_padding if len(out_padding) > 1 else out_padding[0]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def calculate_out_shape(
|
| 59 |
+
in_shape: Union[Sequence[int], int, np.ndarray],
|
| 60 |
+
kernel_size: Union[Sequence[int], int],
|
| 61 |
+
stride: Union[Sequence[int], int],
|
| 62 |
+
padding: Union[Sequence[int], int],
|
| 63 |
+
) -> Union[Tuple[int, ...], int]:
|
| 64 |
+
"""
|
| 65 |
+
Calculate the output tensor shape when applying a convolution to a tensor of shape `inShape` with kernel size
|
| 66 |
+
`kernel_size`, stride value `stride`, and input padding value `padding`. All arguments can be scalars or multiple
|
| 67 |
+
values, return value is a scalar if all inputs are scalars.
|
| 68 |
+
"""
|
| 69 |
+
in_shape_np = np.atleast_1d(in_shape)
|
| 70 |
+
kernel_size_np = np.atleast_1d(kernel_size)
|
| 71 |
+
stride_np = np.atleast_1d(stride)
|
| 72 |
+
padding_np = np.atleast_1d(padding)
|
| 73 |
+
|
| 74 |
+
out_shape_np = ((in_shape_np - kernel_size_np + padding_np + padding_np) // stride_np) + 1
|
| 75 |
+
out_shape = tuple(int(s) for s in out_shape_np)
|
| 76 |
+
|
| 77 |
+
return out_shape if len(out_shape) > 1 else out_shape[0]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def gaussian_1d(
|
| 81 |
+
sigma: torch.Tensor, truncated: float = 4.0, approx: str = "erf", normalize: bool = False
|
| 82 |
+
) -> torch.Tensor:
|
| 83 |
+
"""
|
| 84 |
+
one dimensional Gaussian kernel.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
sigma: std of the kernel
|
| 88 |
+
truncated: tail length
|
| 89 |
+
approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".
|
| 90 |
+
|
| 91 |
+
- ``erf`` approximation interpolates the error function;
|
| 92 |
+
- ``sampled`` uses a sampled Gaussian kernel;
|
| 93 |
+
- ``scalespace`` corresponds to
|
| 94 |
+
https://en.wikipedia.org/wiki/Scale_space_implementation#The_discrete_Gaussian_kernel
|
| 95 |
+
based on the modified Bessel functions.
|
| 96 |
+
|
| 97 |
+
normalize: whether to normalize the kernel with `kernel.sum()`.
|
| 98 |
+
|
| 99 |
+
Raises:
|
| 100 |
+
ValueError: When ``truncated`` is non-positive.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
1D torch tensor
|
| 104 |
+
|
| 105 |
+
"""
|
| 106 |
+
sigma = torch.as_tensor(sigma, dtype=torch.float, device=sigma.device if isinstance(sigma, torch.Tensor) else None)
|
| 107 |
+
device = sigma.device
|
| 108 |
+
if truncated <= 0.0:
|
| 109 |
+
raise ValueError(f"truncated must be positive, got {truncated}.")
|
| 110 |
+
tail = int(max(float(sigma) * truncated, 0.5) + 0.5)
|
| 111 |
+
if approx.lower() == "erf":
|
| 112 |
+
x = torch.arange(-tail, tail + 1, dtype=torch.float, device=device)
|
| 113 |
+
t = 0.70710678 / torch.abs(sigma)
|
| 114 |
+
out = 0.5 * ((t * (x + 0.5)).erf() - (t * (x - 0.5)).erf())
|
| 115 |
+
out = out.clamp(min=0)
|
| 116 |
+
elif approx.lower() == "sampled":
|
| 117 |
+
x = torch.arange(-tail, tail + 1, dtype=torch.float, device=sigma.device)
|
| 118 |
+
out = torch.exp(-0.5 / (sigma * sigma) * x**2)
|
| 119 |
+
if not normalize: # compute the normalizer
|
| 120 |
+
out = out / (2.5066282 * sigma)
|
| 121 |
+
elif approx.lower() == "scalespace":
|
| 122 |
+
sigma2 = sigma * sigma
|
| 123 |
+
out_pos: List[Optional[torch.Tensor]] = [None] * (tail + 1)
|
| 124 |
+
out_pos[0] = _modified_bessel_0(sigma2)
|
| 125 |
+
out_pos[1] = _modified_bessel_1(sigma2)
|
| 126 |
+
for k in range(2, len(out_pos)):
|
| 127 |
+
out_pos[k] = _modified_bessel_i(k, sigma2)
|
| 128 |
+
out = out_pos[:0:-1]
|
| 129 |
+
out.extend(out_pos)
|
| 130 |
+
out = torch.stack(out) * torch.exp(-sigma2)
|
| 131 |
+
else:
|
| 132 |
+
raise NotImplementedError(f"Unsupported option: approx='{approx}'.")
|
| 133 |
+
return out / out.sum() if normalize else out # type: ignore
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def polyval(coef, x) -> torch.Tensor:
|
| 137 |
+
"""
|
| 138 |
+
Evaluates the polynomial defined by `coef` at `x`.
|
| 139 |
+
|
| 140 |
+
For a 1D sequence of coef (length n), evaluate::
|
| 141 |
+
|
| 142 |
+
y = coef[n-1] + x * (coef[n-2] + ... + x * (coef[1] + x * coef[0]))
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
coef: a sequence of floats representing the coefficients of the polynomial
|
| 146 |
+
x: float or a sequence of floats representing the variable of the polynomial
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
1D torch tensor
|
| 150 |
+
"""
|
| 151 |
+
device = x.device if isinstance(x, torch.Tensor) else None
|
| 152 |
+
coef = torch.as_tensor(coef, dtype=torch.float, device=device)
|
| 153 |
+
if coef.ndim == 0 or (len(coef) < 1):
|
| 154 |
+
return torch.zeros(x.shape)
|
| 155 |
+
x = torch.as_tensor(x, dtype=torch.float, device=device)
|
| 156 |
+
ans = coef[0]
|
| 157 |
+
for c in coef[1:]:
|
| 158 |
+
ans = ans * x + c
|
| 159 |
+
return ans # type: ignore
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _modified_bessel_0(x: torch.Tensor) -> torch.Tensor:
|
| 163 |
+
x = torch.as_tensor(x, dtype=torch.float, device=x.device if isinstance(x, torch.Tensor) else None)
|
| 164 |
+
if torch.abs(x) < 3.75:
|
| 165 |
+
y = x * x / 14.0625
|
| 166 |
+
return polyval([0.45813e-2, 0.360768e-1, 0.2659732, 1.2067492, 3.0899424, 3.5156229, 1.0], y)
|
| 167 |
+
ax = torch.abs(x)
|
| 168 |
+
y = 3.75 / ax
|
| 169 |
+
_coef = [
|
| 170 |
+
0.392377e-2,
|
| 171 |
+
-0.1647633e-1,
|
| 172 |
+
0.2635537e-1,
|
| 173 |
+
-0.2057706e-1,
|
| 174 |
+
0.916281e-2,
|
| 175 |
+
-0.157565e-2,
|
| 176 |
+
0.225319e-2,
|
| 177 |
+
0.1328592e-1,
|
| 178 |
+
0.39894228,
|
| 179 |
+
]
|
| 180 |
+
return polyval(_coef, y) * torch.exp(ax) / torch.sqrt(ax)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _modified_bessel_1(x: torch.Tensor) -> torch.Tensor:
|
| 184 |
+
x = torch.as_tensor(x, dtype=torch.float, device=x.device if isinstance(x, torch.Tensor) else None)
|
| 185 |
+
if torch.abs(x) < 3.75:
|
| 186 |
+
y = x * x / 14.0625
|
| 187 |
+
_coef = [0.32411e-3, 0.301532e-2, 0.2658733e-1, 0.15084934, 0.51498869, 0.87890594, 0.5]
|
| 188 |
+
return torch.abs(x) * polyval(_coef, y)
|
| 189 |
+
ax = torch.abs(x)
|
| 190 |
+
y = 3.75 / ax
|
| 191 |
+
_coef = [
|
| 192 |
+
-0.420059e-2,
|
| 193 |
+
0.1787654e-1,
|
| 194 |
+
-0.2895312e-1,
|
| 195 |
+
0.2282967e-1,
|
| 196 |
+
-0.1031555e-1,
|
| 197 |
+
0.163801e-2,
|
| 198 |
+
-0.362018e-2,
|
| 199 |
+
-0.3988024e-1,
|
| 200 |
+
0.39894228,
|
| 201 |
+
]
|
| 202 |
+
ans = polyval(_coef, y) * torch.exp(ax) / torch.sqrt(ax)
|
| 203 |
+
return -ans if x < 0.0 else ans
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _modified_bessel_i(n: int, x: torch.Tensor) -> torch.Tensor:
|
| 207 |
+
if n < 2:
|
| 208 |
+
raise ValueError(f"n must be greater than 1, got n={n}.")
|
| 209 |
+
x = torch.as_tensor(x, dtype=torch.float, device=x.device if isinstance(x, torch.Tensor) else None)
|
| 210 |
+
if x == 0.0:
|
| 211 |
+
return x
|
| 212 |
+
device = x.device
|
| 213 |
+
tox = 2.0 / torch.abs(x)
|
| 214 |
+
ans, bip, bi = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)
|
| 215 |
+
m = int(2 * (n + np.floor(np.sqrt(40.0 * n))))
|
| 216 |
+
for j in range(m, 0, -1):
|
| 217 |
+
bim = bip + float(j) * tox * bi
|
| 218 |
+
bip = bi
|
| 219 |
+
bi = bim
|
| 220 |
+
if abs(bi) > 1.0e10:
|
| 221 |
+
ans = ans * 1.0e-10
|
| 222 |
+
bi = bi * 1.0e-10
|
| 223 |
+
bip = bip * 1.0e-10
|
| 224 |
+
if j == n:
|
| 225 |
+
ans = bip
|
| 226 |
+
ans = ans * _modified_bessel_0(x) / bi
|
| 227 |
+
return -ans if x < 0.0 and (n % 2) == 1 else ans
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/drop_path.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DropPath(nn.Module):
|
| 16 |
+
"""Stochastic drop paths per sample for residual blocks.
|
| 17 |
+
Based on:
|
| 18 |
+
https://github.com/rwightman/pytorch-image-models
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None:
|
| 22 |
+
"""
|
| 23 |
+
Args:
|
| 24 |
+
drop_prob: drop path probability.
|
| 25 |
+
scale_by_keep: scaling by non-dropped probability.
|
| 26 |
+
"""
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.drop_prob = drop_prob
|
| 29 |
+
self.scale_by_keep = scale_by_keep
|
| 30 |
+
|
| 31 |
+
if not (0 <= drop_prob <= 1):
|
| 32 |
+
raise ValueError("Drop path prob should be between 0 and 1.")
|
| 33 |
+
|
| 34 |
+
def drop_path(self, x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
|
| 35 |
+
if drop_prob == 0.0 or not training:
|
| 36 |
+
return x
|
| 37 |
+
keep_prob = 1 - drop_prob
|
| 38 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
| 39 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 40 |
+
if keep_prob > 0.0 and scale_by_keep:
|
| 41 |
+
random_tensor.div_(keep_prob)
|
| 42 |
+
return x * random_tensor
|
| 43 |
+
|
| 44 |
+
def forward(self, x):
|
| 45 |
+
return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/factories.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Defines factories for creating layers in generic, extensible, and dimensionally independent ways. A separate factory
|
| 14 |
+
object is created for each type of layer, and factory functions keyed to names are added to these objects. Whenever
|
| 15 |
+
a layer is requested the factory name and any necessary arguments are passed to the factory object. The return value
|
| 16 |
+
is typically a type but can be any callable producing a layer object.
|
| 17 |
+
|
| 18 |
+
The factory objects contain functions keyed to names converted to upper case, these names can be referred to as members
|
| 19 |
+
of the factory so that they can function as constant identifiers. eg. instance normalization is named `Norm.INSTANCE`.
|
| 20 |
+
|
| 21 |
+
For example, to get a transpose convolution layer the name is needed and then a dimension argument is provided which is
|
| 22 |
+
passed to the factory function:
|
| 23 |
+
|
| 24 |
+
.. code-block:: python
|
| 25 |
+
|
| 26 |
+
dimension = 3
|
| 27 |
+
name = Conv.CONVTRANS
|
| 28 |
+
conv = Conv[name, dimension]
|
| 29 |
+
|
| 30 |
+
This allows the `dimension` value to be set in the constructor, for example so that the dimensionality of a network is
|
| 31 |
+
parameterizable. Not all factories require arguments after the name, the caller must be aware which are required.
|
| 32 |
+
|
| 33 |
+
Defining new factories involves creating the object then associating it with factory functions:
|
| 34 |
+
|
| 35 |
+
.. code-block:: python
|
| 36 |
+
|
| 37 |
+
fact = LayerFactory()
|
| 38 |
+
|
| 39 |
+
@fact.factory_function('test')
|
| 40 |
+
def make_something(x, y):
|
| 41 |
+
# do something with x and y to choose which layer type to return
|
| 42 |
+
return SomeLayerType
|
| 43 |
+
...
|
| 44 |
+
|
| 45 |
+
# request object from factory TEST with 1 and 2 as values for x and y
|
| 46 |
+
layer = fact[fact.TEST, 1, 2]
|
| 47 |
+
|
| 48 |
+
Typically the caller of a factory would know what arguments to pass (ie. the dimensionality of the requested type) but
|
| 49 |
+
can be parameterized with the factory name and the arguments to pass to the created type at instantiation time:
|
| 50 |
+
|
| 51 |
+
.. code-block:: python
|
| 52 |
+
|
| 53 |
+
def use_factory(fact_args):
|
| 54 |
+
fact_name, type_args = split_args
|
| 55 |
+
layer_type = fact[fact_name, 1, 2]
|
| 56 |
+
return layer_type(**type_args)
|
| 57 |
+
...
|
| 58 |
+
|
| 59 |
+
kw_args = {'arg0':0, 'arg1':True}
|
| 60 |
+
layer = use_factory( (fact.TEST, kwargs) )
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
import warnings
|
| 64 |
+
from typing import Any, Callable, Dict, Tuple, Type, Union
|
| 65 |
+
|
| 66 |
+
import torch
|
| 67 |
+
import torch.nn as nn
|
| 68 |
+
|
| 69 |
+
from monai.utils import look_up_option, optional_import
|
| 70 |
+
|
| 71 |
+
InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class LayerFactory:
|
| 78 |
+
"""
|
| 79 |
+
Factory object for creating layers, this uses given factory functions to actually produce the types or constructing
|
| 80 |
+
callables. These functions are referred to by name and can be added at any time.
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
def __init__(self) -> None:
|
| 84 |
+
self.factories: Dict[str, Callable] = {}
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def names(self) -> Tuple[str, ...]:
|
| 88 |
+
"""
|
| 89 |
+
Produces all factory names.
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
return tuple(self.factories)
|
| 93 |
+
|
| 94 |
+
def add_factory_callable(self, name: str, func: Callable) -> None:
|
| 95 |
+
"""
|
| 96 |
+
Add the factory function to this object under the given name.
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
self.factories[name.upper()] = func
|
| 100 |
+
self.__doc__ = (
|
| 101 |
+
"The supported member"
|
| 102 |
+
+ ("s are: " if len(self.names) > 1 else " is: ")
|
| 103 |
+
+ ", ".join(f"``{name}``" for name in self.names)
|
| 104 |
+
+ ".\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing."
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
def factory_function(self, name: str) -> Callable:
|
| 108 |
+
"""
|
| 109 |
+
Decorator for adding a factory function with the given name.
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
def _add(func: Callable) -> Callable:
|
| 113 |
+
self.add_factory_callable(name, func)
|
| 114 |
+
return func
|
| 115 |
+
|
| 116 |
+
return _add
|
| 117 |
+
|
| 118 |
+
def get_constructor(self, factory_name: str, *args) -> Any:
|
| 119 |
+
"""
|
| 120 |
+
Get the constructor for the given factory name and arguments.
|
| 121 |
+
|
| 122 |
+
Raises:
|
| 123 |
+
TypeError: When ``factory_name`` is not a ``str``.
|
| 124 |
+
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
if not isinstance(factory_name, str):
|
| 128 |
+
raise TypeError(f"factory_name must a str but is {type(factory_name).__name__}.")
|
| 129 |
+
|
| 130 |
+
func = look_up_option(factory_name.upper(), self.factories)
|
| 131 |
+
return func(*args)
|
| 132 |
+
|
| 133 |
+
def __getitem__(self, args) -> Any:
|
| 134 |
+
"""
|
| 135 |
+
Get the given name or name/arguments pair. If `args` is a callable it is assumed to be the constructor
|
| 136 |
+
itself and is returned, otherwise it should be the factory name or a pair containing the name and arguments.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
# `args[0]` is actually a type or constructor
|
| 140 |
+
if callable(args):
|
| 141 |
+
return args
|
| 142 |
+
|
| 143 |
+
# `args` is a factory name or a name with arguments
|
| 144 |
+
if isinstance(args, str):
|
| 145 |
+
name_obj, args = args, ()
|
| 146 |
+
else:
|
| 147 |
+
name_obj, *args = args
|
| 148 |
+
|
| 149 |
+
return self.get_constructor(name_obj, *args)
|
| 150 |
+
|
| 151 |
+
def __getattr__(self, key):
|
| 152 |
+
"""
|
| 153 |
+
If `key` is a factory name, return it, otherwise behave as inherited. This allows referring to factory names
|
| 154 |
+
as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo.
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
if key in self.factories:
|
| 158 |
+
return key
|
| 159 |
+
|
| 160 |
+
return super().__getattribute__(key)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def split_args(args):
|
| 164 |
+
"""
|
| 165 |
+
Split arguments in a way to be suitable for using with the factory types. If `args` is a string it's interpreted as
|
| 166 |
+
the type name.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
args (str or a tuple of object name and kwarg dict): input arguments to be parsed.
|
| 170 |
+
|
| 171 |
+
Raises:
|
| 172 |
+
TypeError: When ``args`` type is not in ``Union[str, Tuple[Union[str, Callable], dict]]``.
|
| 173 |
+
|
| 174 |
+
Examples::
|
| 175 |
+
|
| 176 |
+
>>> act_type, args = split_args("PRELU")
|
| 177 |
+
>>> monai.networks.layers.Act[act_type]
|
| 178 |
+
<class 'torch.nn.modules.activation.PReLU'>
|
| 179 |
+
|
| 180 |
+
>>> act_type, args = split_args(("PRELU", {"num_parameters": 1, "init": 0.25}))
|
| 181 |
+
>>> monai.networks.layers.Act[act_type](**args)
|
| 182 |
+
PReLU(num_parameters=1)
|
| 183 |
+
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
if isinstance(args, str):
|
| 187 |
+
return args, {}
|
| 188 |
+
name_obj, name_args = args
|
| 189 |
+
|
| 190 |
+
if not isinstance(name_obj, (str, Callable)) or not isinstance(name_args, dict):
|
| 191 |
+
msg = "Layer specifiers must be single strings or pairs of the form (name/object-types, argument dict)"
|
| 192 |
+
raise TypeError(msg)
|
| 193 |
+
|
| 194 |
+
return name_obj, name_args
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# Define factories for these layer types
|
| 198 |
+
|
| 199 |
+
Dropout = LayerFactory()
|
| 200 |
+
Norm = LayerFactory()
|
| 201 |
+
Act = LayerFactory()
|
| 202 |
+
Conv = LayerFactory()
|
| 203 |
+
Pool = LayerFactory()
|
| 204 |
+
Pad = LayerFactory()
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@Dropout.factory_function("dropout")
|
| 208 |
+
def dropout_factory(dim: int) -> Type[Union[nn.Dropout, nn.Dropout2d, nn.Dropout3d]]:
|
| 209 |
+
types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
|
| 210 |
+
return types[dim - 1]
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@Dropout.factory_function("alphadropout")
|
| 214 |
+
def alpha_dropout_factory(_dim):
|
| 215 |
+
return nn.AlphaDropout
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
@Norm.factory_function("instance")
|
| 219 |
+
def instance_factory(dim: int) -> Type[Union[nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d]]:
|
| 220 |
+
types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)
|
| 221 |
+
return types[dim - 1]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@Norm.factory_function("batch")
|
| 225 |
+
def batch_factory(dim: int) -> Type[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]]:
|
| 226 |
+
types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
|
| 227 |
+
return types[dim - 1]
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@Norm.factory_function("group")
|
| 231 |
+
def group_factory(_dim) -> Type[nn.GroupNorm]:
|
| 232 |
+
return nn.GroupNorm
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@Norm.factory_function("layer")
|
| 236 |
+
def layer_factory(_dim) -> Type[nn.LayerNorm]:
|
| 237 |
+
return nn.LayerNorm
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@Norm.factory_function("localresponse")
|
| 241 |
+
def local_response_factory(_dim) -> Type[nn.LocalResponseNorm]:
|
| 242 |
+
return nn.LocalResponseNorm
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
@Norm.factory_function("syncbatch")
|
| 246 |
+
def sync_batch_factory(_dim) -> Type[nn.SyncBatchNorm]:
|
| 247 |
+
return nn.SyncBatchNorm
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@Norm.factory_function("instance_nvfuser")
|
| 251 |
+
def instance_nvfuser_factory(dim):
|
| 252 |
+
"""
|
| 253 |
+
`InstanceNorm3dNVFuser` is a faster verison of InstanceNorm layer and implemented in `apex`.
|
| 254 |
+
It only supports 3d tensors as the input. It also requires to use with CUDA and non-Windows OS.
|
| 255 |
+
In this function, if the required library `apex.normalization.InstanceNorm3dNVFuser` does not exist,
|
| 256 |
+
`nn.InstanceNorm3d` will be returned instead.
|
| 257 |
+
This layer is based on a customized autograd function, which is not supported in TorchScript currently.
|
| 258 |
+
Please switch to use `nn.InstanceNorm3d` if TorchScript is necessary.
|
| 259 |
+
|
| 260 |
+
Please check the following link for more details about how to install `apex`:
|
| 261 |
+
https://github.com/NVIDIA/apex#installation
|
| 262 |
+
|
| 263 |
+
"""
|
| 264 |
+
types = (nn.InstanceNorm1d, nn.InstanceNorm2d)
|
| 265 |
+
if dim != 3:
|
| 266 |
+
warnings.warn(f"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.")
|
| 267 |
+
return types[dim - 1]
|
| 268 |
+
# test InstanceNorm3dNVFuser installation with a basic example
|
| 269 |
+
has_nvfuser_flag = has_nvfuser
|
| 270 |
+
if not torch.cuda.is_available():
|
| 271 |
+
return nn.InstanceNorm3d
|
| 272 |
+
try:
|
| 273 |
+
layer = InstanceNorm3dNVFuser(num_features=1, affine=True).to("cuda:0")
|
| 274 |
+
inp = torch.randn([1, 1, 1, 1, 1]).to("cuda:0")
|
| 275 |
+
out = layer(inp)
|
| 276 |
+
del inp, out, layer
|
| 277 |
+
except Exception:
|
| 278 |
+
has_nvfuser_flag = False
|
| 279 |
+
if not has_nvfuser_flag:
|
| 280 |
+
warnings.warn(
|
| 281 |
+
"`apex.normalization.InstanceNorm3dNVFuser` is not installed properly, use nn.InstanceNorm3d instead."
|
| 282 |
+
)
|
| 283 |
+
return nn.InstanceNorm3d
|
| 284 |
+
return InstanceNorm3dNVFuser
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
Act.add_factory_callable("elu", lambda: nn.modules.ELU)
|
| 288 |
+
Act.add_factory_callable("relu", lambda: nn.modules.ReLU)
|
| 289 |
+
Act.add_factory_callable("leakyrelu", lambda: nn.modules.LeakyReLU)
|
| 290 |
+
Act.add_factory_callable("prelu", lambda: nn.modules.PReLU)
|
| 291 |
+
Act.add_factory_callable("relu6", lambda: nn.modules.ReLU6)
|
| 292 |
+
Act.add_factory_callable("selu", lambda: nn.modules.SELU)
|
| 293 |
+
Act.add_factory_callable("celu", lambda: nn.modules.CELU)
|
| 294 |
+
Act.add_factory_callable("gelu", lambda: nn.modules.GELU)
|
| 295 |
+
Act.add_factory_callable("sigmoid", lambda: nn.modules.Sigmoid)
|
| 296 |
+
Act.add_factory_callable("tanh", lambda: nn.modules.Tanh)
|
| 297 |
+
Act.add_factory_callable("softmax", lambda: nn.modules.Softmax)
|
| 298 |
+
Act.add_factory_callable("logsoftmax", lambda: nn.modules.LogSoftmax)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
@Act.factory_function("swish")
|
| 302 |
+
def swish_factory():
|
| 303 |
+
from monai.networks.blocks.activation import Swish
|
| 304 |
+
|
| 305 |
+
return Swish
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
@Act.factory_function("memswish")
|
| 309 |
+
def memswish_factory():
|
| 310 |
+
from monai.networks.blocks.activation import MemoryEfficientSwish
|
| 311 |
+
|
| 312 |
+
return MemoryEfficientSwish
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@Act.factory_function("mish")
|
| 316 |
+
def mish_factory():
|
| 317 |
+
from monai.networks.blocks.activation import Mish
|
| 318 |
+
|
| 319 |
+
return Mish
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@Conv.factory_function("conv")
|
| 323 |
+
def conv_factory(dim: int) -> Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]]:
|
| 324 |
+
types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
|
| 325 |
+
return types[dim - 1]
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@Conv.factory_function("convtrans")
|
| 329 |
+
def convtrans_factory(dim: int) -> Type[Union[nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d]]:
|
| 330 |
+
types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
|
| 331 |
+
return types[dim - 1]
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
@Pool.factory_function("max")
|
| 335 |
+
def maxpooling_factory(dim: int) -> Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]]:
|
| 336 |
+
types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)
|
| 337 |
+
return types[dim - 1]
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
@Pool.factory_function("adaptivemax")
|
| 341 |
+
def adaptive_maxpooling_factory(
|
| 342 |
+
dim: int,
|
| 343 |
+
) -> Type[Union[nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d]]:
|
| 344 |
+
types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d)
|
| 345 |
+
return types[dim - 1]
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
@Pool.factory_function("avg")
|
| 349 |
+
def avgpooling_factory(dim: int) -> Type[Union[nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d]]:
|
| 350 |
+
types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)
|
| 351 |
+
return types[dim - 1]
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
@Pool.factory_function("adaptiveavg")
|
| 355 |
+
def adaptive_avgpooling_factory(
|
| 356 |
+
dim: int,
|
| 357 |
+
) -> Type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]]:
|
| 358 |
+
types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d)
|
| 359 |
+
return types[dim - 1]
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
@Pad.factory_function("replicationpad")
|
| 363 |
+
def replication_pad_factory(dim: int) -> Type[Union[nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d]]:
|
| 364 |
+
types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d)
|
| 365 |
+
return types[dim - 1]
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
@Pad.factory_function("constantpad")
|
| 369 |
+
def constant_pad_factory(dim: int) -> Type[Union[nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]]:
|
| 370 |
+
types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
|
| 371 |
+
return types[dim - 1]
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/filtering.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from monai.utils.module import optional_import
|
| 15 |
+
|
| 16 |
+
_C, _ = optional_import("monai._C")
|
| 17 |
+
|
| 18 |
+
__all__ = ["BilateralFilter", "PHLFilter"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BilateralFilter(torch.autograd.Function):
|
| 22 |
+
"""
|
| 23 |
+
Blurs the input tensor spatially whilst preserving edges. Can run on 1D, 2D, or 3D,
|
| 24 |
+
tensors (on top of Batch and Channel dimensions). Two implementations are provided,
|
| 25 |
+
an exact solution and a much faster approximation which uses a permutohedral lattice.
|
| 26 |
+
|
| 27 |
+
See:
|
| 28 |
+
https://en.wikipedia.org/wiki/Bilateral_filter
|
| 29 |
+
https://graphics.stanford.edu/papers/permutohedral/
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
input: input tensor.
|
| 33 |
+
|
| 34 |
+
spatial sigma: the standard deviation of the spatial blur. Higher values can
|
| 35 |
+
hurt performance when not using the approximate method (see fast approx).
|
| 36 |
+
|
| 37 |
+
color sigma: the standard deviation of the color blur. Lower values preserve
|
| 38 |
+
edges better whilst higher values tend to a simple gaussian spatial blur.
|
| 39 |
+
|
| 40 |
+
fast approx: This flag chooses between two implementations. The approximate method may
|
| 41 |
+
produce artifacts in some scenarios whereas the exact solution may be intolerably
|
| 42 |
+
slow for high spatial standard deviations.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
output (torch.Tensor): output tensor.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
@staticmethod
|
| 49 |
+
def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True):
|
| 50 |
+
ctx.ss = spatial_sigma
|
| 51 |
+
ctx.cs = color_sigma
|
| 52 |
+
ctx.fa = fast_approx
|
| 53 |
+
output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
|
| 54 |
+
return output_data
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def backward(ctx, grad_output):
|
| 58 |
+
spatial_sigma, color_sigma, fast_approx = ctx.ss, ctx.cs, ctx.fa
|
| 59 |
+
grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx)
|
| 60 |
+
return grad_input, None, None, None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class PHLFilter(torch.autograd.Function):
|
| 64 |
+
"""
|
| 65 |
+
Filters input based on arbitrary feature vectors. Uses a permutohedral
|
| 66 |
+
lattice data structure to efficiently approximate n-dimensional gaussian
|
| 67 |
+
filtering. Complexity is broadly independent of kernel size. Most applicable
|
| 68 |
+
to higher filter dimensions and larger kernel sizes.
|
| 69 |
+
|
| 70 |
+
See:
|
| 71 |
+
https://graphics.stanford.edu/papers/permutohedral/
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
input: input tensor to be filtered.
|
| 75 |
+
|
| 76 |
+
features: feature tensor used to filter the input.
|
| 77 |
+
|
| 78 |
+
sigmas: the standard deviations of each feature in the filter.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
output (torch.Tensor): output tensor.
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
@staticmethod
|
| 85 |
+
def forward(ctx, input, features, sigmas=None):
|
| 86 |
+
|
| 87 |
+
scaled_features = features
|
| 88 |
+
if sigmas is not None:
|
| 89 |
+
for i in range(features.size(1)):
|
| 90 |
+
scaled_features[:, i, ...] /= sigmas[i]
|
| 91 |
+
|
| 92 |
+
ctx.save_for_backward(scaled_features)
|
| 93 |
+
output_data = _C.phl_filter(input, scaled_features)
|
| 94 |
+
return output_data
|
| 95 |
+
|
| 96 |
+
@staticmethod
|
| 97 |
+
def backward(ctx, grad_output):
|
| 98 |
+
raise NotImplementedError("PHLFilter does not currently support Backpropagation")
|
| 99 |
+
# scaled_features, = ctx.saved_variables
|
| 100 |
+
# grad_input = _C.phl_filter(grad_output, scaled_features)
|
| 101 |
+
# return grad_input
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/gmm.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from monai._extensions.loader import load_module
|
| 15 |
+
|
| 16 |
+
__all__ = ["GaussianMixtureModel"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class GaussianMixtureModel:
|
| 20 |
+
"""
|
| 21 |
+
Takes an initial labeling and uses a mixture of Gaussians to approximate each classes
|
| 22 |
+
distribution in the feature space. Each unlabeled element is then assigned a probability
|
| 23 |
+
of belonging to each class based on it's fit to each classes approximated distribution.
|
| 24 |
+
|
| 25 |
+
See:
|
| 26 |
+
https://en.wikipedia.org/wiki/Mixture_model
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, channel_count: int, mixture_count: int, mixture_size: int, verbose_build: bool = False):
|
| 30 |
+
"""
|
| 31 |
+
Args:
|
| 32 |
+
channel_count: The number of features per element.
|
| 33 |
+
mixture_count: The number of class distributions.
|
| 34 |
+
mixture_size: The number Gaussian components per class distribution.
|
| 35 |
+
verbose_build: If ``True``, turns on verbose logging of load steps.
|
| 36 |
+
"""
|
| 37 |
+
if not torch.cuda.is_available():
|
| 38 |
+
raise NotImplementedError("GaussianMixtureModel is currently implemented for CUDA.")
|
| 39 |
+
self.channel_count = channel_count
|
| 40 |
+
self.mixture_count = mixture_count
|
| 41 |
+
self.mixture_size = mixture_size
|
| 42 |
+
self.compiled_extension = load_module(
|
| 43 |
+
"gmm",
|
| 44 |
+
{"CHANNEL_COUNT": channel_count, "MIXTURE_COUNT": mixture_count, "MIXTURE_SIZE": mixture_size},
|
| 45 |
+
verbose_build=verbose_build,
|
| 46 |
+
)
|
| 47 |
+
self.params, self.scratch = self.compiled_extension.init()
|
| 48 |
+
|
| 49 |
+
def reset(self):
|
| 50 |
+
"""
|
| 51 |
+
Resets the parameters of the model.
|
| 52 |
+
"""
|
| 53 |
+
self.params, self.scratch = self.compiled_extension.init()
|
| 54 |
+
|
| 55 |
+
def learn(self, features, labels):
|
| 56 |
+
"""
|
| 57 |
+
Learns, from scratch, the distribution of each class from the provided labels.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
features (torch.Tensor): features for each element.
|
| 61 |
+
labels (torch.Tensor): initial labeling for each element.
|
| 62 |
+
"""
|
| 63 |
+
self.compiled_extension.learn(self.params, self.scratch, features, labels)
|
| 64 |
+
|
| 65 |
+
def apply(self, features):
|
| 66 |
+
"""
|
| 67 |
+
Applies the current model to a set of feature vectors.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
features (torch.Tensor): feature vectors for each element.
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
output (torch.Tensor): class assignment probabilities for each element.
|
| 74 |
+
"""
|
| 75 |
+
return _ApplyFunc.apply(self.params, features, self.compiled_extension)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class _ApplyFunc(torch.autograd.Function):
|
| 79 |
+
@staticmethod
|
| 80 |
+
def forward(ctx, params, features, compiled_extension):
|
| 81 |
+
return compiled_extension.apply(params, features)
|
| 82 |
+
|
| 83 |
+
@staticmethod
|
| 84 |
+
def backward(ctx, grad_output):
|
| 85 |
+
raise NotImplementedError("GMM does not support backpropagation")
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/simplelayers.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
from copy import deepcopy
|
| 14 |
+
from typing import List, Sequence, Union
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch import nn
|
| 19 |
+
from torch.autograd import Function
|
| 20 |
+
|
| 21 |
+
from monai.networks.layers.convutils import gaussian_1d
|
| 22 |
+
from monai.networks.layers.factories import Conv
|
| 23 |
+
from monai.utils import ChannelMatching, SkipMode, look_up_option, optional_import, pytorch_after
|
| 24 |
+
from monai.utils.misc import issequenceiterable
|
| 25 |
+
|
| 26 |
+
_C, _ = optional_import("monai._C")
|
| 27 |
+
fft, _ = optional_import("torch.fft")
|
| 28 |
+
|
| 29 |
+
__all__ = [
|
| 30 |
+
"ChannelPad",
|
| 31 |
+
"Flatten",
|
| 32 |
+
"GaussianFilter",
|
| 33 |
+
"HilbertTransform",
|
| 34 |
+
"LLTM",
|
| 35 |
+
"Reshape",
|
| 36 |
+
"SavitzkyGolayFilter",
|
| 37 |
+
"SkipConnection",
|
| 38 |
+
"apply_filter",
|
| 39 |
+
"separable_filtering",
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ChannelPad(nn.Module):
|
| 44 |
+
"""
|
| 45 |
+
Expand the input tensor's channel dimension from length `in_channels` to `out_channels`,
|
| 46 |
+
by padding or a projection.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
spatial_dims: int,
|
| 52 |
+
in_channels: int,
|
| 53 |
+
out_channels: int,
|
| 54 |
+
mode: Union[ChannelMatching, str] = ChannelMatching.PAD,
|
| 55 |
+
):
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
spatial_dims: number of spatial dimensions of the input image.
|
| 60 |
+
in_channels: number of input channels.
|
| 61 |
+
out_channels: number of output channels.
|
| 62 |
+
mode: {``"pad"``, ``"project"``}
|
| 63 |
+
Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``.
|
| 64 |
+
|
| 65 |
+
- ``"pad"``: with zero padding.
|
| 66 |
+
- ``"project"``: with a trainable conv with kernel size one.
|
| 67 |
+
"""
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.project = None
|
| 70 |
+
self.pad = None
|
| 71 |
+
if in_channels == out_channels:
|
| 72 |
+
return
|
| 73 |
+
mode = look_up_option(mode, ChannelMatching)
|
| 74 |
+
if mode == ChannelMatching.PROJECT:
|
| 75 |
+
conv_type = Conv[Conv.CONV, spatial_dims]
|
| 76 |
+
self.project = conv_type(in_channels, out_channels, kernel_size=1)
|
| 77 |
+
return
|
| 78 |
+
if mode == ChannelMatching.PAD:
|
| 79 |
+
if in_channels > out_channels:
|
| 80 |
+
raise ValueError('Incompatible values: channel_matching="pad" and in_channels > out_channels.')
|
| 81 |
+
pad_1 = (out_channels - in_channels) // 2
|
| 82 |
+
pad_2 = out_channels - in_channels - pad_1
|
| 83 |
+
pad = [0, 0] * spatial_dims + [pad_1, pad_2] + [0, 0]
|
| 84 |
+
self.pad = tuple(pad)
|
| 85 |
+
return
|
| 86 |
+
|
| 87 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 88 |
+
if self.project is not None:
|
| 89 |
+
return torch.as_tensor(self.project(x)) # as_tensor used to get around mypy typing bug
|
| 90 |
+
if self.pad is not None:
|
| 91 |
+
return F.pad(x, self.pad)
|
| 92 |
+
return x
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class SkipConnection(nn.Module):
|
| 96 |
+
"""
|
| 97 |
+
Combine the forward pass input with the result from the given submodule::
|
| 98 |
+
|
| 99 |
+
--+--submodule--o--
|
| 100 |
+
|_____________|
|
| 101 |
+
|
| 102 |
+
The available modes are ``"cat"``, ``"add"``, ``"mul"``.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
def __init__(self, submodule, dim: int = 1, mode: Union[str, SkipMode] = "cat") -> None:
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
submodule: the module defines the trainable branch.
|
| 110 |
+
dim: the dimension over which the tensors are concatenated.
|
| 111 |
+
Used when mode is ``"cat"``.
|
| 112 |
+
mode: ``"cat"``, ``"add"``, ``"mul"``. defaults to ``"cat"``.
|
| 113 |
+
"""
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.submodule = submodule
|
| 116 |
+
self.dim = dim
|
| 117 |
+
self.mode = look_up_option(mode, SkipMode).value
|
| 118 |
+
|
| 119 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 120 |
+
y = self.submodule(x)
|
| 121 |
+
|
| 122 |
+
if self.mode == "cat":
|
| 123 |
+
return torch.cat([x, y], dim=self.dim)
|
| 124 |
+
if self.mode == "add":
|
| 125 |
+
return torch.add(x, y)
|
| 126 |
+
if self.mode == "mul":
|
| 127 |
+
return torch.mul(x, y)
|
| 128 |
+
raise NotImplementedError(f"Unsupported mode {self.mode}.")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class Flatten(nn.Module):
|
| 132 |
+
"""
|
| 133 |
+
Flattens the given input in the forward pass to be [B,-1] in shape.
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 137 |
+
return x.view(x.size(0), -1)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class Reshape(nn.Module):
|
| 141 |
+
"""
|
| 142 |
+
Reshapes input tensors to the given shape (minus batch dimension), retaining original batch size.
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
def __init__(self, *shape: int) -> None:
|
| 146 |
+
"""
|
| 147 |
+
Given a shape list/tuple `shape` of integers (s0, s1, ... , sn), this layer will reshape input tensors of
|
| 148 |
+
shape (batch, s0 * s1 * ... * sn) to shape (batch, s0, s1, ... , sn).
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
shape: list/tuple of integer shape dimensions
|
| 152 |
+
"""
|
| 153 |
+
super().__init__()
|
| 154 |
+
self.shape = (1,) + tuple(shape)
|
| 155 |
+
|
| 156 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 157 |
+
shape = list(self.shape)
|
| 158 |
+
shape[0] = x.shape[0] # done this way for Torchscript
|
| 159 |
+
return x.reshape(shape)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _separable_filtering_conv(
|
| 163 |
+
input_: torch.Tensor,
|
| 164 |
+
kernels: List[torch.Tensor],
|
| 165 |
+
pad_mode: str,
|
| 166 |
+
d: int,
|
| 167 |
+
spatial_dims: int,
|
| 168 |
+
paddings: List[int],
|
| 169 |
+
num_channels: int,
|
| 170 |
+
) -> torch.Tensor:
|
| 171 |
+
|
| 172 |
+
if d < 0:
|
| 173 |
+
return input_
|
| 174 |
+
|
| 175 |
+
s = [1] * len(input_.shape)
|
| 176 |
+
s[d + 2] = -1
|
| 177 |
+
_kernel = kernels[d].reshape(s)
|
| 178 |
+
|
| 179 |
+
# if filter kernel is unity, don't convolve
|
| 180 |
+
if _kernel.numel() == 1 and _kernel[0] == 1:
|
| 181 |
+
return _separable_filtering_conv(input_, kernels, pad_mode, d - 1, spatial_dims, paddings, num_channels)
|
| 182 |
+
|
| 183 |
+
_kernel = _kernel.repeat([num_channels, 1] + [1] * spatial_dims)
|
| 184 |
+
_padding = [0] * spatial_dims
|
| 185 |
+
_padding[d] = paddings[d]
|
| 186 |
+
conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1]
|
| 187 |
+
|
| 188 |
+
# translate padding for input to torch.nn.functional.pad
|
| 189 |
+
_reversed_padding_repeated_twice: List[List[int]] = [[p, p] for p in reversed(_padding)]
|
| 190 |
+
_sum_reversed_padding_repeated_twice: List[int] = sum(_reversed_padding_repeated_twice, [])
|
| 191 |
+
padded_input = F.pad(input_, _sum_reversed_padding_repeated_twice, mode=pad_mode)
|
| 192 |
+
|
| 193 |
+
return conv_type(
|
| 194 |
+
input=_separable_filtering_conv(padded_input, kernels, pad_mode, d - 1, spatial_dims, paddings, num_channels),
|
| 195 |
+
weight=_kernel,
|
| 196 |
+
groups=num_channels,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def separable_filtering(x: torch.Tensor, kernels: List[torch.Tensor], mode: str = "zeros") -> torch.Tensor:
|
| 201 |
+
"""
|
| 202 |
+
Apply 1-D convolutions along each spatial dimension of `x`.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
x: the input image. must have shape (batch, channels, H[, W, ...]).
|
| 206 |
+
kernels: kernel along each spatial dimension.
|
| 207 |
+
could be a single kernel (duplicated for all spatial dimensions), or
|
| 208 |
+
a list of `spatial_dims` number of kernels.
|
| 209 |
+
mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``
|
| 210 |
+
or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information.
|
| 211 |
+
|
| 212 |
+
Raises:
|
| 213 |
+
TypeError: When ``x`` is not a ``torch.Tensor``.
|
| 214 |
+
|
| 215 |
+
Examples:
|
| 216 |
+
|
| 217 |
+
.. code-block:: python
|
| 218 |
+
|
| 219 |
+
>>> import torch
|
| 220 |
+
>>> from monai.networks.layers import separable_filtering
|
| 221 |
+
>>> img = torch.randn(2, 4, 32, 32) # batch_size 2, channels 4, 32x32 2D images
|
| 222 |
+
# applying a [-1, 0, 1] filter along each of the spatial dimensions.
|
| 223 |
+
# the output shape is the same as the input shape.
|
| 224 |
+
>>> out = separable_filtering(img, torch.tensor((-1., 0., 1.)))
|
| 225 |
+
# applying `[-1, 0, 1]`, `[1, 0, -1]` filters along two spatial dimensions respectively.
|
| 226 |
+
# the output shape is the same as the input shape.
|
| 227 |
+
>>> out = separable_filtering(img, [torch.tensor((-1., 0., 1.)), torch.tensor((1., 0., -1.))])
|
| 228 |
+
|
| 229 |
+
"""
|
| 230 |
+
|
| 231 |
+
if not isinstance(x, torch.Tensor):
|
| 232 |
+
raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")
|
| 233 |
+
|
| 234 |
+
spatial_dims = len(x.shape) - 2
|
| 235 |
+
if isinstance(kernels, torch.Tensor):
|
| 236 |
+
kernels = [kernels] * spatial_dims
|
| 237 |
+
_kernels = [s.to(x) for s in kernels]
|
| 238 |
+
_paddings = [(k.shape[0] - 1) // 2 for k in _kernels]
|
| 239 |
+
n_chs = x.shape[1]
|
| 240 |
+
pad_mode = "constant" if mode == "zeros" else mode
|
| 241 |
+
|
| 242 |
+
return _separable_filtering_conv(x, _kernels, pad_mode, spatial_dims - 1, spatial_dims, _paddings, n_chs)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 246 |
+
"""
|
| 247 |
+
Filtering `x` with `kernel` independently for each batch and channel respectively.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
x: the input image, must have shape (batch, channels, H[, W, D]).
|
| 251 |
+
kernel: `kernel` must at least have the spatial shape (H_k[, W_k, D_k]).
|
| 252 |
+
`kernel` shape must be broadcastable to the `batch` and `channels` dimensions of `x`.
|
| 253 |
+
kwargs: keyword arguments passed to `conv*d()` functions.
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
The filtered `x`.
|
| 257 |
+
|
| 258 |
+
Examples:
|
| 259 |
+
|
| 260 |
+
.. code-block:: python
|
| 261 |
+
|
| 262 |
+
>>> import torch
|
| 263 |
+
>>> from monai.networks.layers import apply_filter
|
| 264 |
+
>>> img = torch.rand(2, 5, 10, 10) # batch_size 2, channels 5, 10x10 2D images
|
| 265 |
+
>>> out = apply_filter(img, torch.rand(3, 3)) # spatial kernel
|
| 266 |
+
>>> out = apply_filter(img, torch.rand(5, 3, 3)) # channel-wise kernels
|
| 267 |
+
>>> out = apply_filter(img, torch.rand(2, 5, 3, 3)) # batch-, channel-wise kernels
|
| 268 |
+
|
| 269 |
+
"""
|
| 270 |
+
if not isinstance(x, torch.Tensor):
|
| 271 |
+
raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")
|
| 272 |
+
batch, chns, *spatials = x.shape
|
| 273 |
+
n_spatial = len(spatials)
|
| 274 |
+
if n_spatial > 3:
|
| 275 |
+
raise NotImplementedError(f"Only spatial dimensions up to 3 are supported but got {n_spatial}.")
|
| 276 |
+
k_size = len(kernel.shape)
|
| 277 |
+
if k_size < n_spatial or k_size > n_spatial + 2:
|
| 278 |
+
raise ValueError(
|
| 279 |
+
f"kernel must have {n_spatial} ~ {n_spatial + 2} dimensions to match the input shape {x.shape}."
|
| 280 |
+
)
|
| 281 |
+
kernel = kernel.to(x)
|
| 282 |
+
# broadcast kernel size to (batch chns, spatial_kernel_size)
|
| 283 |
+
kernel = kernel.expand(batch, chns, *kernel.shape[(k_size - n_spatial) :])
|
| 284 |
+
kernel = kernel.reshape(-1, 1, *kernel.shape[2:]) # group=1
|
| 285 |
+
x = x.view(1, kernel.shape[0], *spatials)
|
| 286 |
+
conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1]
|
| 287 |
+
if "padding" not in kwargs:
|
| 288 |
+
if pytorch_after(1, 10):
|
| 289 |
+
kwargs["padding"] = "same"
|
| 290 |
+
else:
|
| 291 |
+
# even-sized kernels are not supported
|
| 292 |
+
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
|
| 293 |
+
|
| 294 |
+
if "stride" not in kwargs:
|
| 295 |
+
kwargs["stride"] = 1
|
| 296 |
+
output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs)
|
| 297 |
+
return output.view(batch, chns, *output.shape[2:])
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class SavitzkyGolayFilter(nn.Module):
|
| 301 |
+
"""
|
| 302 |
+
Convolve a Tensor along a particular axis with a Savitzky-Golay kernel.
|
| 303 |
+
|
| 304 |
+
Args:
|
| 305 |
+
window_length: Length of the filter window, must be a positive odd integer.
|
| 306 |
+
order: Order of the polynomial to fit to each window, must be less than ``window_length``.
|
| 307 |
+
axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension).
|
| 308 |
+
mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or
|
| 309 |
+
``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information.
|
| 310 |
+
"""
|
| 311 |
+
|
| 312 |
+
def __init__(self, window_length: int, order: int, axis: int = 2, mode: str = "zeros"):
|
| 313 |
+
|
| 314 |
+
super().__init__()
|
| 315 |
+
if order >= window_length:
|
| 316 |
+
raise ValueError("order must be less than window_length.")
|
| 317 |
+
|
| 318 |
+
self.axis = axis
|
| 319 |
+
self.mode = mode
|
| 320 |
+
self.coeffs = self._make_coeffs(window_length, order)
|
| 321 |
+
|
| 322 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 323 |
+
"""
|
| 324 |
+
Args:
|
| 325 |
+
x: Tensor or array-like to filter. Must be real, in shape ``[Batch, chns, spatial1, spatial2, ...]`` and
|
| 326 |
+
have a device type of ``'cpu'``.
|
| 327 |
+
Returns:
|
| 328 |
+
torch.Tensor: ``x`` filtered by Savitzky-Golay kernel with window length ``self.window_length`` using
|
| 329 |
+
polynomials of order ``self.order``, along axis specified in ``self.axis``.
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
# Make input a real tensor on the CPU
|
| 333 |
+
x = torch.as_tensor(x, device=x.device if isinstance(x, torch.Tensor) else None)
|
| 334 |
+
if torch.is_complex(x):
|
| 335 |
+
raise ValueError("x must be real.")
|
| 336 |
+
x = x.to(dtype=torch.float)
|
| 337 |
+
|
| 338 |
+
if (self.axis < 0) or (self.axis > len(x.shape) - 1):
|
| 339 |
+
raise ValueError(f"Invalid axis for shape of x, got axis {self.axis} and shape {x.shape}.")
|
| 340 |
+
|
| 341 |
+
# Create list of filter kernels (1 per spatial dimension). The kernel for self.axis will be the savgol coeffs,
|
| 342 |
+
# while the other kernels will be set to [1].
|
| 343 |
+
n_spatial_dims = len(x.shape) - 2
|
| 344 |
+
spatial_processing_axis = self.axis - 2
|
| 345 |
+
new_dims_before = spatial_processing_axis
|
| 346 |
+
new_dims_after = n_spatial_dims - spatial_processing_axis - 1
|
| 347 |
+
kernel_list = [self.coeffs.to(device=x.device, dtype=x.dtype)]
|
| 348 |
+
for _ in range(new_dims_before):
|
| 349 |
+
kernel_list.insert(0, torch.ones(1, device=x.device, dtype=x.dtype))
|
| 350 |
+
for _ in range(new_dims_after):
|
| 351 |
+
kernel_list.append(torch.ones(1, device=x.device, dtype=x.dtype))
|
| 352 |
+
|
| 353 |
+
return separable_filtering(x, kernel_list, mode=self.mode)
|
| 354 |
+
|
| 355 |
+
@staticmethod
|
| 356 |
+
def _make_coeffs(window_length, order):
|
| 357 |
+
|
| 358 |
+
half_length, rem = divmod(window_length, 2)
|
| 359 |
+
if rem == 0:
|
| 360 |
+
raise ValueError("window_length must be odd.")
|
| 361 |
+
|
| 362 |
+
idx = torch.arange(window_length - half_length - 1, -half_length - 1, -1, dtype=torch.float, device="cpu")
|
| 363 |
+
a = idx ** torch.arange(order + 1, dtype=torch.float, device="cpu").reshape(-1, 1)
|
| 364 |
+
y = torch.zeros(order + 1, dtype=torch.float, device="cpu")
|
| 365 |
+
y[0] = 1.0
|
| 366 |
+
return torch.lstsq(y, a).solution.squeeze()
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class HilbertTransform(nn.Module):
|
| 370 |
+
"""
|
| 371 |
+
Determine the analytical signal of a Tensor along a particular axis.
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
axis: Axis along which to apply Hilbert transform. Default 2 (first spatial dimension).
|
| 375 |
+
n: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``.
|
| 376 |
+
"""
|
| 377 |
+
|
| 378 |
+
def __init__(self, axis: int = 2, n: Union[int, None] = None) -> None:
|
| 379 |
+
|
| 380 |
+
super().__init__()
|
| 381 |
+
self.axis = axis
|
| 382 |
+
self.n = n
|
| 383 |
+
|
| 384 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 385 |
+
"""
|
| 386 |
+
Args:
|
| 387 |
+
x: Tensor or array-like to transform. Must be real and in shape ``[Batch, chns, spatial1, spatial2, ...]``.
|
| 388 |
+
Returns:
|
| 389 |
+
torch.Tensor: Analytical signal of ``x``, transformed along axis specified in ``self.axis`` using
|
| 390 |
+
FFT of size ``self.N``. The absolute value of ``x_ht`` relates to the envelope of ``x`` along axis ``self.axis``.
|
| 391 |
+
"""
|
| 392 |
+
|
| 393 |
+
# Make input a real tensor
|
| 394 |
+
x = torch.as_tensor(x, device=x.device if isinstance(x, torch.Tensor) else None)
|
| 395 |
+
if torch.is_complex(x):
|
| 396 |
+
raise ValueError("x must be real.")
|
| 397 |
+
x = x.to(dtype=torch.float)
|
| 398 |
+
|
| 399 |
+
if (self.axis < 0) or (self.axis > len(x.shape) - 1):
|
| 400 |
+
raise ValueError(f"Invalid axis for shape of x, got axis {self.axis} and shape {x.shape}.")
|
| 401 |
+
|
| 402 |
+
n = x.shape[self.axis] if self.n is None else self.n
|
| 403 |
+
if n <= 0:
|
| 404 |
+
raise ValueError("N must be positive.")
|
| 405 |
+
x = torch.as_tensor(x, dtype=torch.complex64)
|
| 406 |
+
# Create frequency axis
|
| 407 |
+
f = torch.cat(
|
| 408 |
+
[
|
| 409 |
+
torch.true_divide(torch.arange(0, (n - 1) // 2 + 1, device=x.device), float(n)),
|
| 410 |
+
torch.true_divide(torch.arange(-(n // 2), 0, device=x.device), float(n)),
|
| 411 |
+
]
|
| 412 |
+
)
|
| 413 |
+
xf = fft.fft(x, n=n, dim=self.axis)
|
| 414 |
+
# Create step function
|
| 415 |
+
u = torch.heaviside(f, torch.tensor([0.5], device=f.device))
|
| 416 |
+
u = torch.as_tensor(u, dtype=x.dtype, device=u.device)
|
| 417 |
+
new_dims_before = self.axis
|
| 418 |
+
new_dims_after = len(xf.shape) - self.axis - 1
|
| 419 |
+
for _ in range(new_dims_before):
|
| 420 |
+
u.unsqueeze_(0)
|
| 421 |
+
for _ in range(new_dims_after):
|
| 422 |
+
u.unsqueeze_(-1)
|
| 423 |
+
|
| 424 |
+
ht = fft.ifft(xf * 2 * u, dim=self.axis)
|
| 425 |
+
|
| 426 |
+
# Apply transform
|
| 427 |
+
return torch.as_tensor(ht, device=ht.device, dtype=ht.dtype)
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
class GaussianFilter(nn.Module):
|
| 431 |
+
def __init__(
|
| 432 |
+
self,
|
| 433 |
+
spatial_dims: int,
|
| 434 |
+
sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor],
|
| 435 |
+
truncated: float = 4.0,
|
| 436 |
+
approx: str = "erf",
|
| 437 |
+
requires_grad: bool = False,
|
| 438 |
+
) -> None:
|
| 439 |
+
"""
|
| 440 |
+
Args:
|
| 441 |
+
spatial_dims: number of spatial dimensions of the input image.
|
| 442 |
+
must have shape (Batch, channels, H[, W, ...]).
|
| 443 |
+
sigma: std. could be a single value, or `spatial_dims` number of values.
|
| 444 |
+
truncated: spreads how many stds.
|
| 445 |
+
approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".
|
| 446 |
+
|
| 447 |
+
- ``erf`` approximation interpolates the error function;
|
| 448 |
+
- ``sampled`` uses a sampled Gaussian kernel;
|
| 449 |
+
- ``scalespace`` corresponds to
|
| 450 |
+
https://en.wikipedia.org/wiki/Scale_space_implementation#The_discrete_Gaussian_kernel
|
| 451 |
+
based on the modified Bessel functions.
|
| 452 |
+
|
| 453 |
+
requires_grad: whether to store the gradients for sigma.
|
| 454 |
+
if True, `sigma` will be the initial value of the parameters of this module
|
| 455 |
+
(for example `parameters()` iterator could be used to get the parameters);
|
| 456 |
+
otherwise this module will fix the kernels using `sigma` as the std.
|
| 457 |
+
"""
|
| 458 |
+
if issequenceiterable(sigma):
|
| 459 |
+
if len(sigma) != spatial_dims: # type: ignore
|
| 460 |
+
raise ValueError
|
| 461 |
+
else:
|
| 462 |
+
sigma = [deepcopy(sigma) for _ in range(spatial_dims)] # type: ignore
|
| 463 |
+
super().__init__()
|
| 464 |
+
self.sigma = [
|
| 465 |
+
torch.nn.Parameter(
|
| 466 |
+
torch.as_tensor(s, dtype=torch.float, device=s.device if isinstance(s, torch.Tensor) else None),
|
| 467 |
+
requires_grad=requires_grad,
|
| 468 |
+
)
|
| 469 |
+
for s in sigma # type: ignore
|
| 470 |
+
]
|
| 471 |
+
self.truncated = truncated
|
| 472 |
+
self.approx = approx
|
| 473 |
+
for idx, param in enumerate(self.sigma):
|
| 474 |
+
self.register_parameter(f"kernel_sigma_{idx}", param)
|
| 475 |
+
|
| 476 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 477 |
+
"""
|
| 478 |
+
Args:
|
| 479 |
+
x: in shape [Batch, chns, H, W, D].
|
| 480 |
+
"""
|
| 481 |
+
_kernel = [gaussian_1d(s, truncated=self.truncated, approx=self.approx) for s in self.sigma]
|
| 482 |
+
return separable_filtering(x=x, kernels=_kernel)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class LLTMFunction(Function):
|
| 486 |
+
@staticmethod
|
| 487 |
+
def forward(ctx, input, weights, bias, old_h, old_cell):
|
| 488 |
+
outputs = _C.lltm_forward(input, weights, bias, old_h, old_cell)
|
| 489 |
+
new_h, new_cell = outputs[:2]
|
| 490 |
+
variables = outputs[1:] + [weights]
|
| 491 |
+
ctx.save_for_backward(*variables)
|
| 492 |
+
|
| 493 |
+
return new_h, new_cell
|
| 494 |
+
|
| 495 |
+
@staticmethod
|
| 496 |
+
def backward(ctx, grad_h, grad_cell):
|
| 497 |
+
outputs = _C.lltm_backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
|
| 498 |
+
d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs[:5]
|
| 499 |
+
|
| 500 |
+
return d_input, d_weights, d_bias, d_old_h, d_old_cell
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class LLTM(nn.Module):
|
| 504 |
+
"""
|
| 505 |
+
This recurrent unit is similar to an LSTM, but differs in that it lacks a forget
|
| 506 |
+
gate and uses an Exponential Linear Unit (ELU) as its internal activation function.
|
| 507 |
+
Because this unit never forgets, call it LLTM, or Long-Long-Term-Memory unit.
|
| 508 |
+
It has both C++ and CUDA implementation, automatically switch according to the
|
| 509 |
+
target device where put this module to.
|
| 510 |
+
|
| 511 |
+
Args:
|
| 512 |
+
input_features: size of input feature data
|
| 513 |
+
state_size: size of the state of recurrent unit
|
| 514 |
+
|
| 515 |
+
Referring to: https://pytorch.org/tutorials/advanced/cpp_extension.html
|
| 516 |
+
"""
|
| 517 |
+
|
| 518 |
+
def __init__(self, input_features: int, state_size: int):
|
| 519 |
+
super().__init__()
|
| 520 |
+
self.input_features = input_features
|
| 521 |
+
self.state_size = state_size
|
| 522 |
+
self.weights = nn.Parameter(torch.empty(3 * state_size, input_features + state_size))
|
| 523 |
+
self.bias = nn.Parameter(torch.empty(1, 3 * state_size))
|
| 524 |
+
self.reset_parameters()
|
| 525 |
+
|
| 526 |
+
def reset_parameters(self):
|
| 527 |
+
stdv = 1.0 / math.sqrt(self.state_size)
|
| 528 |
+
for weight in self.parameters():
|
| 529 |
+
weight.data.uniform_(-stdv, +stdv)
|
| 530 |
+
|
| 531 |
+
def forward(self, input, state):
|
| 532 |
+
return LLTMFunction.apply(input, self.weights, self.bias, *state)
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/spatial_transforms.py
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
from typing import Optional, Sequence, Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
from monai.networks import to_norm_affine
|
| 18 |
+
from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, look_up_option, optional_import
|
| 19 |
+
|
| 20 |
+
_C, _ = optional_import("monai._C")
|
| 21 |
+
|
| 22 |
+
__all__ = ["AffineTransform", "grid_pull", "grid_push", "grid_count", "grid_grad"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class _GridPull(torch.autograd.Function):
|
| 26 |
+
@staticmethod
|
| 27 |
+
def forward(ctx, input, grid, interpolation, bound, extrapolate):
|
| 28 |
+
opt = (bound, interpolation, extrapolate)
|
| 29 |
+
output = _C.grid_pull(input, grid, *opt)
|
| 30 |
+
if input.requires_grad or grid.requires_grad:
|
| 31 |
+
ctx.opt = opt
|
| 32 |
+
ctx.save_for_backward(input, grid)
|
| 33 |
+
|
| 34 |
+
return output
|
| 35 |
+
|
| 36 |
+
@staticmethod
|
| 37 |
+
def backward(ctx, grad):
|
| 38 |
+
if not (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]):
|
| 39 |
+
return None, None, None, None, None
|
| 40 |
+
var = ctx.saved_tensors
|
| 41 |
+
opt = ctx.opt
|
| 42 |
+
grads = _C.grid_pull_backward(grad, *var, *opt)
|
| 43 |
+
if ctx.needs_input_grad[0]:
|
| 44 |
+
return grads[0], grads[1] if ctx.needs_input_grad[1] else None, None, None, None
|
| 45 |
+
if ctx.needs_input_grad[1]:
|
| 46 |
+
return None, grads[0], None, None, None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def grid_pull(
|
| 50 |
+
input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True
|
| 51 |
+
) -> torch.Tensor:
|
| 52 |
+
"""
|
| 53 |
+
Sample an image with respect to a deformation field.
|
| 54 |
+
|
| 55 |
+
`interpolation` can be an int, a string or an InterpolationType.
|
| 56 |
+
Possible values are::
|
| 57 |
+
|
| 58 |
+
- 0 or 'nearest' or InterpolationType.nearest
|
| 59 |
+
- 1 or 'linear' or InterpolationType.linear
|
| 60 |
+
- 2 or 'quadratic' or InterpolationType.quadratic
|
| 61 |
+
- 3 or 'cubic' or InterpolationType.cubic
|
| 62 |
+
- 4 or 'fourth' or InterpolationType.fourth
|
| 63 |
+
- 5 or 'fifth' or InterpolationType.fifth
|
| 64 |
+
- 6 or 'sixth' or InterpolationType.sixth
|
| 65 |
+
- 7 or 'seventh' or InterpolationType.seventh
|
| 66 |
+
|
| 67 |
+
A list of values can be provided, in the order [W, H, D],
|
| 68 |
+
to specify dimension-specific interpolation orders.
|
| 69 |
+
|
| 70 |
+
`bound` can be an int, a string or a BoundType.
|
| 71 |
+
Possible values are::
|
| 72 |
+
|
| 73 |
+
- 0 or 'replicate' or 'nearest' or BoundType.replicate or 'border'
|
| 74 |
+
- 1 or 'dct1' or 'mirror' or BoundType.dct1
|
| 75 |
+
- 2 or 'dct2' or 'reflect' or BoundType.dct2
|
| 76 |
+
- 3 or 'dst1' or 'antimirror' or BoundType.dst1
|
| 77 |
+
- 4 or 'dst2' or 'antireflect' or BoundType.dst2
|
| 78 |
+
- 5 or 'dft' or 'wrap' or BoundType.dft
|
| 79 |
+
- 7 or 'zero' or 'zeros' or BoundType.zero
|
| 80 |
+
|
| 81 |
+
A list of values can be provided, in the order [W, H, D],
|
| 82 |
+
to specify dimension-specific boundary conditions.
|
| 83 |
+
`sliding` is a specific condition than only applies to flow fields
|
| 84 |
+
(with as many channels as dimensions). It cannot be dimension-specific.
|
| 85 |
+
Note that:
|
| 86 |
+
|
| 87 |
+
- `dft` corresponds to circular padding
|
| 88 |
+
- `dct2` corresponds to Neumann boundary conditions (symmetric)
|
| 89 |
+
- `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)
|
| 90 |
+
|
| 91 |
+
See Also:
|
| 92 |
+
- https://en.wikipedia.org/wiki/Discrete_cosine_transform
|
| 93 |
+
- https://en.wikipedia.org/wiki/Discrete_sine_transform
|
| 94 |
+
- ``help(monai._C.BoundType)``
|
| 95 |
+
- ``help(monai._C.InterpolationType)``
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
input: Input image. `(B, C, Wi, Hi, Di)`.
|
| 99 |
+
grid: Deformation field. `(B, Wo, Ho, Do, 1|2|3)`.
|
| 100 |
+
interpolation (int or list[int] , optional): Interpolation order.
|
| 101 |
+
Defaults to `'linear'`.
|
| 102 |
+
bound (BoundType, or list[BoundType], optional): Boundary conditions.
|
| 103 |
+
Defaults to `'zero'`.
|
| 104 |
+
extrapolate: Extrapolate out-of-bound data.
|
| 105 |
+
Defaults to `True`.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
output (torch.Tensor): Deformed image `(B, C, Wo, Ho, Do)`.
|
| 109 |
+
|
| 110 |
+
"""
|
| 111 |
+
# Convert parameters
|
| 112 |
+
bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)]
|
| 113 |
+
interpolation = [
|
| 114 |
+
_C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i)
|
| 115 |
+
for i in ensure_tuple(interpolation)
|
| 116 |
+
]
|
| 117 |
+
out: torch.Tensor
|
| 118 |
+
out = _GridPull.apply(input, grid, interpolation, bound, extrapolate)
|
| 119 |
+
return out
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class _GridPush(torch.autograd.Function):
|
| 123 |
+
@staticmethod
|
| 124 |
+
def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
|
| 125 |
+
opt = (bound, interpolation, extrapolate)
|
| 126 |
+
output = _C.grid_push(input, grid, shape, *opt)
|
| 127 |
+
if input.requires_grad or grid.requires_grad:
|
| 128 |
+
ctx.opt = opt
|
| 129 |
+
ctx.save_for_backward(input, grid)
|
| 130 |
+
|
| 131 |
+
return output
|
| 132 |
+
|
| 133 |
+
@staticmethod
|
| 134 |
+
def backward(ctx, grad):
|
| 135 |
+
if not (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]):
|
| 136 |
+
return None, None, None, None, None, None
|
| 137 |
+
var = ctx.saved_tensors
|
| 138 |
+
opt = ctx.opt
|
| 139 |
+
grads = _C.grid_push_backward(grad, *var, *opt)
|
| 140 |
+
if ctx.needs_input_grad[0]:
|
| 141 |
+
return grads[0], grads[1] if ctx.needs_input_grad[1] else None, None, None, None, None
|
| 142 |
+
if ctx.needs_input_grad[1]:
|
| 143 |
+
return None, grads[0], None, None, None, None
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def grid_push(
|
| 147 |
+
input: torch.Tensor, grid: torch.Tensor, shape=None, interpolation="linear", bound="zero", extrapolate: bool = True
|
| 148 |
+
):
|
| 149 |
+
"""
|
| 150 |
+
Splat an image with respect to a deformation field (pull adjoint).
|
| 151 |
+
|
| 152 |
+
`interpolation` can be an int, a string or an InterpolationType.
|
| 153 |
+
Possible values are::
|
| 154 |
+
|
| 155 |
+
- 0 or 'nearest' or InterpolationType.nearest
|
| 156 |
+
- 1 or 'linear' or InterpolationType.linear
|
| 157 |
+
- 2 or 'quadratic' or InterpolationType.quadratic
|
| 158 |
+
- 3 or 'cubic' or InterpolationType.cubic
|
| 159 |
+
- 4 or 'fourth' or InterpolationType.fourth
|
| 160 |
+
- 5 or 'fifth' or InterpolationType.fifth
|
| 161 |
+
- 6 or 'sixth' or InterpolationType.sixth
|
| 162 |
+
- 7 or 'seventh' or InterpolationType.seventh
|
| 163 |
+
|
| 164 |
+
A list of values can be provided, in the order `[W, H, D]`,
|
| 165 |
+
to specify dimension-specific interpolation orders.
|
| 166 |
+
|
| 167 |
+
`bound` can be an int, a string or a BoundType.
|
| 168 |
+
Possible values are::
|
| 169 |
+
|
| 170 |
+
- 0 or 'replicate' or 'nearest' or BoundType.replicate
|
| 171 |
+
- 1 or 'dct1' or 'mirror' or BoundType.dct1
|
| 172 |
+
- 2 or 'dct2' or 'reflect' or BoundType.dct2
|
| 173 |
+
- 3 or 'dst1' or 'antimirror' or BoundType.dst1
|
| 174 |
+
- 4 or 'dst2' or 'antireflect' or BoundType.dst2
|
| 175 |
+
- 5 or 'dft' or 'wrap' or BoundType.dft
|
| 176 |
+
- 7 or 'zero' or BoundType.zero
|
| 177 |
+
|
| 178 |
+
A list of values can be provided, in the order `[W, H, D]`,
|
| 179 |
+
to specify dimension-specific boundary conditions.
|
| 180 |
+
`sliding` is a specific condition than only applies to flow fields
|
| 181 |
+
(with as many channels as dimensions). It cannot be dimension-specific.
|
| 182 |
+
Note that:
|
| 183 |
+
|
| 184 |
+
- `dft` corresponds to circular padding
|
| 185 |
+
- `dct2` corresponds to Neumann boundary conditions (symmetric)
|
| 186 |
+
- `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)
|
| 187 |
+
|
| 188 |
+
See Also:
|
| 189 |
+
|
| 190 |
+
- https://en.wikipedia.org/wiki/Discrete_cosine_transform
|
| 191 |
+
- https://en.wikipedia.org/wiki/Discrete_sine_transform
|
| 192 |
+
- ``help(monai._C.BoundType)``
|
| 193 |
+
- ``help(monai._C.InterpolationType)``
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
input: Input image `(B, C, Wi, Hi, Di)`.
|
| 197 |
+
grid: Deformation field `(B, Wi, Hi, Di, 1|2|3)`.
|
| 198 |
+
shape: Shape of the source image.
|
| 199 |
+
interpolation (int or list[int] , optional): Interpolation order.
|
| 200 |
+
Defaults to `'linear'`.
|
| 201 |
+
bound (BoundType, or list[BoundType], optional): Boundary conditions.
|
| 202 |
+
Defaults to `'zero'`.
|
| 203 |
+
extrapolate: Extrapolate out-of-bound data.
|
| 204 |
+
Defaults to `True`.
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
output (torch.Tensor): Splatted image `(B, C, Wo, Ho, Do)`.
|
| 208 |
+
|
| 209 |
+
"""
|
| 210 |
+
# Convert parameters
|
| 211 |
+
bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)]
|
| 212 |
+
interpolation = [
|
| 213 |
+
_C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i)
|
| 214 |
+
for i in ensure_tuple(interpolation)
|
| 215 |
+
]
|
| 216 |
+
|
| 217 |
+
if shape is None:
|
| 218 |
+
shape = tuple(input.shape[2:])
|
| 219 |
+
|
| 220 |
+
return _GridPush.apply(input, grid, shape, interpolation, bound, extrapolate)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class _GridCount(torch.autograd.Function):
|
| 224 |
+
@staticmethod
|
| 225 |
+
def forward(ctx, grid, shape, interpolation, bound, extrapolate):
|
| 226 |
+
opt = (bound, interpolation, extrapolate)
|
| 227 |
+
output = _C.grid_count(grid, shape, *opt)
|
| 228 |
+
if grid.requires_grad:
|
| 229 |
+
ctx.opt = opt
|
| 230 |
+
ctx.save_for_backward(grid)
|
| 231 |
+
|
| 232 |
+
return output
|
| 233 |
+
|
| 234 |
+
@staticmethod
|
| 235 |
+
def backward(ctx, grad):
|
| 236 |
+
if ctx.needs_input_grad[0]:
|
| 237 |
+
var = ctx.saved_tensors
|
| 238 |
+
opt = ctx.opt
|
| 239 |
+
return _C.grid_count_backward(grad, *var, *opt), None, None, None, None
|
| 240 |
+
return None, None, None, None, None
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def grid_count(grid: torch.Tensor, shape=None, interpolation="linear", bound="zero", extrapolate: bool = True):
|
| 244 |
+
"""
|
| 245 |
+
Splatting weights with respect to a deformation field (pull adjoint).
|
| 246 |
+
|
| 247 |
+
This function is equivalent to applying grid_push to an image of ones.
|
| 248 |
+
|
| 249 |
+
`interpolation` can be an int, a string or an InterpolationType.
|
| 250 |
+
Possible values are::
|
| 251 |
+
|
| 252 |
+
- 0 or 'nearest' or InterpolationType.nearest
|
| 253 |
+
- 1 or 'linear' or InterpolationType.linear
|
| 254 |
+
- 2 or 'quadratic' or InterpolationType.quadratic
|
| 255 |
+
- 3 or 'cubic' or InterpolationType.cubic
|
| 256 |
+
- 4 or 'fourth' or InterpolationType.fourth
|
| 257 |
+
- 5 or 'fifth' or InterpolationType.fifth
|
| 258 |
+
- 6 or 'sixth' or InterpolationType.sixth
|
| 259 |
+
- 7 or 'seventh' or InterpolationType.seventh
|
| 260 |
+
|
| 261 |
+
A list of values can be provided, in the order [W, H, D],
|
| 262 |
+
to specify dimension-specific interpolation orders.
|
| 263 |
+
|
| 264 |
+
`bound` can be an int, a string or a BoundType.
|
| 265 |
+
Possible values are::
|
| 266 |
+
|
| 267 |
+
- 0 or 'replicate' or 'nearest' or BoundType.replicate
|
| 268 |
+
- 1 or 'dct1' or 'mirror' or BoundType.dct1
|
| 269 |
+
- 2 or 'dct2' or 'reflect' or BoundType.dct2
|
| 270 |
+
- 3 or 'dst1' or 'antimirror' or BoundType.dst1
|
| 271 |
+
- 4 or 'dst2' or 'antireflect' or BoundType.dst2
|
| 272 |
+
- 5 or 'dft' or 'wrap' or BoundType.dft
|
| 273 |
+
- 7 or 'zero' or BoundType.zero
|
| 274 |
+
|
| 275 |
+
A list of values can be provided, in the order [W, H, D],
|
| 276 |
+
to specify dimension-specific boundary conditions.
|
| 277 |
+
`sliding` is a specific condition than only applies to flow fields
|
| 278 |
+
(with as many channels as dimensions). It cannot be dimension-specific.
|
| 279 |
+
Note that:
|
| 280 |
+
|
| 281 |
+
- `dft` corresponds to circular padding
|
| 282 |
+
- `dct2` corresponds to Neumann boundary conditions (symmetric)
|
| 283 |
+
- `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)
|
| 284 |
+
|
| 285 |
+
See Also:
|
| 286 |
+
|
| 287 |
+
- https://en.wikipedia.org/wiki/Discrete_cosine_transform
|
| 288 |
+
- https://en.wikipedia.org/wiki/Discrete_sine_transform
|
| 289 |
+
- ``help(monai._C.BoundType)``
|
| 290 |
+
- ``help(monai._C.InterpolationType)``
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
grid: Deformation field `(B, Wi, Hi, Di, 2|3)`.
|
| 294 |
+
shape: shape of the source image.
|
| 295 |
+
interpolation (int or list[int] , optional): Interpolation order.
|
| 296 |
+
Defaults to `'linear'`.
|
| 297 |
+
bound (BoundType, or list[BoundType], optional): Boundary conditions.
|
| 298 |
+
Defaults to `'zero'`.
|
| 299 |
+
extrapolate (bool, optional): Extrapolate out-of-bound data.
|
| 300 |
+
Defaults to `True`.
|
| 301 |
+
|
| 302 |
+
Returns:
|
| 303 |
+
output (torch.Tensor): Splat weights `(B, 1, Wo, Ho, Do)`.
|
| 304 |
+
|
| 305 |
+
"""
|
| 306 |
+
# Convert parameters
|
| 307 |
+
bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)]
|
| 308 |
+
interpolation = [
|
| 309 |
+
_C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i)
|
| 310 |
+
for i in ensure_tuple(interpolation)
|
| 311 |
+
]
|
| 312 |
+
|
| 313 |
+
if shape is None:
|
| 314 |
+
shape = tuple(grid.shape[2:])
|
| 315 |
+
|
| 316 |
+
return _GridCount.apply(grid, shape, interpolation, bound, extrapolate)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class _GridGrad(torch.autograd.Function):
|
| 320 |
+
@staticmethod
|
| 321 |
+
def forward(ctx, input, grid, interpolation, bound, extrapolate):
|
| 322 |
+
opt = (bound, interpolation, extrapolate)
|
| 323 |
+
output = _C.grid_grad(input, grid, *opt)
|
| 324 |
+
if input.requires_grad or grid.requires_grad:
|
| 325 |
+
ctx.opt = opt
|
| 326 |
+
ctx.save_for_backward(input, grid)
|
| 327 |
+
|
| 328 |
+
return output
|
| 329 |
+
|
| 330 |
+
@staticmethod
|
| 331 |
+
def backward(ctx, grad):
|
| 332 |
+
if not (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]):
|
| 333 |
+
return None, None, None, None, None
|
| 334 |
+
var = ctx.saved_tensors
|
| 335 |
+
opt = ctx.opt
|
| 336 |
+
grads = _C.grid_grad_backward(grad, *var, *opt)
|
| 337 |
+
if ctx.needs_input_grad[0]:
|
| 338 |
+
return grads[0], grads[1] if ctx.needs_input_grad[1] else None, None, None, None
|
| 339 |
+
if ctx.needs_input_grad[1]:
|
| 340 |
+
return None, grads[0], None, None, None
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def grid_grad(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", bound="zero", extrapolate: bool = True):
|
| 344 |
+
"""
|
| 345 |
+
Sample an image with respect to a deformation field.
|
| 346 |
+
|
| 347 |
+
`interpolation` can be an int, a string or an InterpolationType.
|
| 348 |
+
Possible values are::
|
| 349 |
+
|
| 350 |
+
- 0 or 'nearest' or InterpolationType.nearest
|
| 351 |
+
- 1 or 'linear' or InterpolationType.linear
|
| 352 |
+
- 2 or 'quadratic' or InterpolationType.quadratic
|
| 353 |
+
- 3 or 'cubic' or InterpolationType.cubic
|
| 354 |
+
- 4 or 'fourth' or InterpolationType.fourth
|
| 355 |
+
- 5 or 'fifth' or InterpolationType.fifth
|
| 356 |
+
- 6 or 'sixth' or InterpolationType.sixth
|
| 357 |
+
- 7 or 'seventh' or InterpolationType.seventh
|
| 358 |
+
|
| 359 |
+
A list of values can be provided, in the order [W, H, D],
|
| 360 |
+
to specify dimension-specific interpolation orders.
|
| 361 |
+
|
| 362 |
+
`bound` can be an int, a string or a BoundType.
|
| 363 |
+
Possible values are::
|
| 364 |
+
|
| 365 |
+
- 0 or 'replicate' or 'nearest' or BoundType.replicate
|
| 366 |
+
- 1 or 'dct1' or 'mirror' or BoundType.dct1
|
| 367 |
+
- 2 or 'dct2' or 'reflect' or BoundType.dct2
|
| 368 |
+
- 3 or 'dst1' or 'antimirror' or BoundType.dst1
|
| 369 |
+
- 4 or 'dst2' or 'antireflect' or BoundType.dst2
|
| 370 |
+
- 5 or 'dft' or 'wrap' or BoundType.dft
|
| 371 |
+
- 7 or 'zero' or BoundType.zero
|
| 372 |
+
|
| 373 |
+
A list of values can be provided, in the order [W, H, D],
|
| 374 |
+
to specify dimension-specific boundary conditions.
|
| 375 |
+
`sliding` is a specific condition than only applies to flow fields
|
| 376 |
+
(with as many channels as dimensions). It cannot be dimension-specific.
|
| 377 |
+
Note that:
|
| 378 |
+
|
| 379 |
+
- `dft` corresponds to circular padding
|
| 380 |
+
- `dct2` corresponds to Neumann boundary conditions (symmetric)
|
| 381 |
+
- `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)
|
| 382 |
+
|
| 383 |
+
See Also:
|
| 384 |
+
|
| 385 |
+
- https://en.wikipedia.org/wiki/Discrete_cosine_transform
|
| 386 |
+
- https://en.wikipedia.org/wiki/Discrete_sine_transform
|
| 387 |
+
- ``help(monai._C.BoundType)``
|
| 388 |
+
- ``help(monai._C.InterpolationType)``
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
Args:
|
| 392 |
+
input: Input image. `(B, C, Wi, Hi, Di)`.
|
| 393 |
+
grid: Deformation field. `(B, Wo, Ho, Do, 2|3)`.
|
| 394 |
+
interpolation (int or list[int] , optional): Interpolation order.
|
| 395 |
+
Defaults to `'linear'`.
|
| 396 |
+
bound (BoundType, or list[BoundType], optional): Boundary conditions.
|
| 397 |
+
Defaults to `'zero'`.
|
| 398 |
+
extrapolate: Extrapolate out-of-bound data. Defaults to `True`.
|
| 399 |
+
|
| 400 |
+
Returns:
|
| 401 |
+
output (torch.Tensor): Sampled gradients (B, C, Wo, Ho, Do, 1|2|3).
|
| 402 |
+
|
| 403 |
+
"""
|
| 404 |
+
# Convert parameters
|
| 405 |
+
bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)]
|
| 406 |
+
interpolation = [
|
| 407 |
+
_C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i)
|
| 408 |
+
for i in ensure_tuple(interpolation)
|
| 409 |
+
]
|
| 410 |
+
|
| 411 |
+
return _GridGrad.apply(input, grid, interpolation, bound, extrapolate)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class AffineTransform(nn.Module):
|
| 415 |
+
def __init__(
|
| 416 |
+
self,
|
| 417 |
+
spatial_size: Optional[Union[Sequence[int], int]] = None,
|
| 418 |
+
normalized: bool = False,
|
| 419 |
+
mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
|
| 420 |
+
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.ZEROS,
|
| 421 |
+
align_corners: bool = False,
|
| 422 |
+
reverse_indexing: bool = True,
|
| 423 |
+
zero_centered: Optional[bool] = None,
|
| 424 |
+
) -> None:
|
| 425 |
+
"""
|
| 426 |
+
Apply affine transformations with a batch of affine matrices.
|
| 427 |
+
|
| 428 |
+
When `normalized=False` and `reverse_indexing=True`,
|
| 429 |
+
it does the commonly used resampling in the 'pull' direction
|
| 430 |
+
following the ``scipy.ndimage.affine_transform`` convention.
|
| 431 |
+
In this case `theta` is equivalent to (ndim+1, ndim+1) input ``matrix`` of ``scipy.ndimage.affine_transform``,
|
| 432 |
+
operates on homogeneous coordinates.
|
| 433 |
+
See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html
|
| 434 |
+
|
| 435 |
+
When `normalized=True` and `reverse_indexing=False`,
|
| 436 |
+
it applies `theta` to the normalized coordinates (coords. in the range of [-1, 1]) directly.
|
| 437 |
+
This is often used with `align_corners=False` to achieve resolution-agnostic resampling,
|
| 438 |
+
thus useful as a part of trainable modules such as the spatial transformer networks.
|
| 439 |
+
See also: https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
spatial_size: output spatial shape, the full output shape will be
|
| 443 |
+
`[N, C, *spatial_size]` where N and C are inferred from the `src` input of `self.forward`.
|
| 444 |
+
normalized: indicating whether the provided affine matrix `theta` is defined
|
| 445 |
+
for the normalized coordinates. If `normalized=False`, `theta` will be converted
|
| 446 |
+
to operate on normalized coordinates as pytorch affine_grid works with the normalized
|
| 447 |
+
coordinates.
|
| 448 |
+
mode: {``"bilinear"``, ``"nearest"``}
|
| 449 |
+
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
|
| 450 |
+
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
|
| 451 |
+
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
|
| 452 |
+
Padding mode for outside grid values. Defaults to ``"zeros"``.
|
| 453 |
+
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
|
| 454 |
+
align_corners: see also https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html.
|
| 455 |
+
reverse_indexing: whether to reverse the spatial indexing of image and coordinates.
|
| 456 |
+
set to `False` if `theta` follows pytorch's default "D, H, W" convention.
|
| 457 |
+
set to `True` if `theta` follows `scipy.ndimage` default "i, j, k" convention.
|
| 458 |
+
zero_centered: whether the affine is applied to coordinates in a zero-centered value range.
|
| 459 |
+
With `zero_centered=True`, for example, the center of rotation will be the
|
| 460 |
+
spatial center of the input; with `zero_centered=False`, the center of rotation will be the
|
| 461 |
+
origin of the input. This option is only available when `normalized=False`,
|
| 462 |
+
where the default behaviour is `False` if unspecified.
|
| 463 |
+
See also: :py:func:`monai.networks.utils.normalize_transform`.
|
| 464 |
+
"""
|
| 465 |
+
super().__init__()
|
| 466 |
+
self.spatial_size = ensure_tuple(spatial_size) if spatial_size is not None else None
|
| 467 |
+
self.normalized = normalized
|
| 468 |
+
self.mode: GridSampleMode = look_up_option(mode, GridSampleMode)
|
| 469 |
+
self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode)
|
| 470 |
+
self.align_corners = align_corners
|
| 471 |
+
self.reverse_indexing = reverse_indexing
|
| 472 |
+
if zero_centered is not None and self.normalized:
|
| 473 |
+
raise ValueError("`normalized=True` is not compatible with the `zero_centered` option.")
|
| 474 |
+
self.zero_centered = zero_centered if zero_centered is not None else False
|
| 475 |
+
|
| 476 |
+
def forward(
|
| 477 |
+
self, src: torch.Tensor, theta: torch.Tensor, spatial_size: Optional[Union[Sequence[int], int]] = None
|
| 478 |
+
) -> torch.Tensor:
|
| 479 |
+
"""
|
| 480 |
+
``theta`` must be an affine transformation matrix with shape
|
| 481 |
+
3x3 or Nx3x3 or Nx2x3 or 2x3 for spatial 2D transforms,
|
| 482 |
+
4x4 or Nx4x4 or Nx3x4 or 3x4 for spatial 3D transforms,
|
| 483 |
+
where `N` is the batch size. `theta` will be converted into float Tensor for the computation.
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
src (array_like): image in spatial 2D or 3D (N, C, spatial_dims),
|
| 487 |
+
where N is the batch dim, C is the number of channels.
|
| 488 |
+
theta (array_like): Nx3x3, Nx2x3, 3x3, 2x3 for spatial 2D inputs,
|
| 489 |
+
Nx4x4, Nx3x4, 3x4, 4x4 for spatial 3D inputs. When the batch dimension is omitted,
|
| 490 |
+
`theta` will be repeated N times, N is the batch dim of `src`.
|
| 491 |
+
spatial_size: output spatial shape, the full output shape will be
|
| 492 |
+
`[N, C, *spatial_size]` where N and C are inferred from the `src`.
|
| 493 |
+
|
| 494 |
+
Raises:
|
| 495 |
+
TypeError: When ``theta`` is not a ``torch.Tensor``.
|
| 496 |
+
ValueError: When ``theta`` is not one of [Nxdxd, dxd].
|
| 497 |
+
ValueError: When ``theta`` is not one of [Nx3x3, Nx4x4].
|
| 498 |
+
TypeError: When ``src`` is not a ``torch.Tensor``.
|
| 499 |
+
ValueError: When ``src`` spatially is not one of [2D, 3D].
|
| 500 |
+
ValueError: When affine and image batch dimension differ.
|
| 501 |
+
|
| 502 |
+
"""
|
| 503 |
+
# validate `theta`
|
| 504 |
+
if not isinstance(theta, torch.Tensor):
|
| 505 |
+
raise TypeError(f"theta must be torch.Tensor but is {type(theta).__name__}.")
|
| 506 |
+
if theta.dim() not in (2, 3):
|
| 507 |
+
raise ValueError(f"theta must be Nxdxd or dxd, got {theta.shape}.")
|
| 508 |
+
if theta.dim() == 2:
|
| 509 |
+
theta = theta[None] # adds a batch dim.
|
| 510 |
+
theta = theta.clone() # no in-place change of theta
|
| 511 |
+
theta_shape = tuple(theta.shape[1:])
|
| 512 |
+
if theta_shape in ((2, 3), (3, 4)): # needs padding to dxd
|
| 513 |
+
pad_affine = torch.tensor([0, 0, 1] if theta_shape[0] == 2 else [0, 0, 0, 1])
|
| 514 |
+
pad_affine = pad_affine.repeat(theta.shape[0], 1, 1).to(theta)
|
| 515 |
+
pad_affine.requires_grad = False
|
| 516 |
+
theta = torch.cat([theta, pad_affine], dim=1)
|
| 517 |
+
if tuple(theta.shape[1:]) not in ((3, 3), (4, 4)):
|
| 518 |
+
raise ValueError(f"theta must be Nx3x3 or Nx4x4, got {theta.shape}.")
|
| 519 |
+
|
| 520 |
+
# validate `src`
|
| 521 |
+
if not isinstance(src, torch.Tensor):
|
| 522 |
+
raise TypeError(f"src must be torch.Tensor but is {type(src).__name__}.")
|
| 523 |
+
sr = src.dim() - 2 # input spatial rank
|
| 524 |
+
if sr not in (2, 3):
|
| 525 |
+
raise ValueError(f"Unsupported src dimension: {sr}, available options are [2, 3].")
|
| 526 |
+
|
| 527 |
+
# set output shape
|
| 528 |
+
src_size = tuple(src.shape)
|
| 529 |
+
dst_size = src_size # default to the src shape
|
| 530 |
+
if self.spatial_size is not None:
|
| 531 |
+
dst_size = src_size[:2] + self.spatial_size
|
| 532 |
+
if spatial_size is not None:
|
| 533 |
+
dst_size = src_size[:2] + ensure_tuple(spatial_size)
|
| 534 |
+
|
| 535 |
+
# reverse and normalize theta if needed
|
| 536 |
+
if not self.normalized:
|
| 537 |
+
theta = to_norm_affine(
|
| 538 |
+
affine=theta,
|
| 539 |
+
src_size=src_size[2:],
|
| 540 |
+
dst_size=dst_size[2:],
|
| 541 |
+
align_corners=self.align_corners,
|
| 542 |
+
zero_centered=self.zero_centered,
|
| 543 |
+
)
|
| 544 |
+
if self.reverse_indexing:
|
| 545 |
+
rev_idx = torch.as_tensor(range(sr - 1, -1, -1), device=src.device)
|
| 546 |
+
theta[:, :sr] = theta[:, rev_idx]
|
| 547 |
+
theta[:, :, :sr] = theta[:, :, rev_idx]
|
| 548 |
+
if (theta.shape[0] == 1) and src_size[0] > 1:
|
| 549 |
+
# adds a batch dim to `theta` in order to match `src`
|
| 550 |
+
theta = theta.repeat(src_size[0], 1, 1)
|
| 551 |
+
if theta.shape[0] != src_size[0]:
|
| 552 |
+
raise ValueError(
|
| 553 |
+
f"affine and image batch dimension must match, got affine={theta.shape[0]} image={src_size[0]}."
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
grid = nn.functional.affine_grid(theta=theta[:, :sr], size=list(dst_size), align_corners=self.align_corners)
|
| 557 |
+
dst = nn.functional.grid_sample(
|
| 558 |
+
input=src.contiguous(),
|
| 559 |
+
grid=grid,
|
| 560 |
+
mode=self.mode.value,
|
| 561 |
+
padding_mode=self.padding_mode.value,
|
| 562 |
+
align_corners=self.align_corners,
|
| 563 |
+
)
|
| 564 |
+
return dst
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/utils.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
from typing import Optional, Tuple, Union
|
| 13 |
+
|
| 14 |
+
from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args
|
| 15 |
+
from monai.utils import has_option
|
| 16 |
+
|
| 17 |
+
__all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_norm_layer(name: Union[Tuple, str], spatial_dims: Optional[int] = 1, channels: Optional[int] = 1):
|
| 21 |
+
"""
|
| 22 |
+
Create a normalization layer instance.
|
| 23 |
+
|
| 24 |
+
For example, to create normalization layers:
|
| 25 |
+
|
| 26 |
+
.. code-block:: python
|
| 27 |
+
|
| 28 |
+
from monai.networks.layers import get_norm_layer
|
| 29 |
+
|
| 30 |
+
g_layer = get_norm_layer(name=("group", {"num_groups": 1}))
|
| 31 |
+
n_layer = get_norm_layer(name="instance", spatial_dims=2)
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
name: a normalization type string or a tuple of type string and parameters.
|
| 35 |
+
spatial_dims: number of spatial dimensions of the input.
|
| 36 |
+
channels: number of features/channels when the normalization layer requires this parameter
|
| 37 |
+
but it is not specified in the norm parameters.
|
| 38 |
+
"""
|
| 39 |
+
norm_name, norm_args = split_args(name)
|
| 40 |
+
norm_type = Norm[norm_name, spatial_dims]
|
| 41 |
+
kw_args = dict(norm_args)
|
| 42 |
+
if has_option(norm_type, "num_features") and "num_features" not in kw_args:
|
| 43 |
+
kw_args["num_features"] = channels
|
| 44 |
+
if has_option(norm_type, "num_channels") and "num_channels" not in kw_args:
|
| 45 |
+
kw_args["num_channels"] = channels
|
| 46 |
+
return norm_type(**kw_args)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_act_layer(name: Union[Tuple, str]):
|
| 50 |
+
"""
|
| 51 |
+
Create an activation layer instance.
|
| 52 |
+
|
| 53 |
+
For example, to create activation layers:
|
| 54 |
+
|
| 55 |
+
.. code-block:: python
|
| 56 |
+
|
| 57 |
+
from monai.networks.layers import get_act_layer
|
| 58 |
+
|
| 59 |
+
s_layer = get_act_layer(name="swish")
|
| 60 |
+
p_layer = get_act_layer(name=("prelu", {"num_parameters": 1, "init": 0.25}))
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
name: an activation type string or a tuple of type string and parameters.
|
| 64 |
+
"""
|
| 65 |
+
act_name, act_args = split_args(name)
|
| 66 |
+
act_type = Act[act_name]
|
| 67 |
+
return act_type(**act_args)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_dropout_layer(name: Union[Tuple, str, float, int], dropout_dim: Optional[int] = 1):
|
| 71 |
+
"""
|
| 72 |
+
Create a dropout layer instance.
|
| 73 |
+
|
| 74 |
+
For example, to create dropout layers:
|
| 75 |
+
|
| 76 |
+
.. code-block:: python
|
| 77 |
+
|
| 78 |
+
from monai.networks.layers import get_dropout_layer
|
| 79 |
+
|
| 80 |
+
d_layer = get_dropout_layer(name="dropout")
|
| 81 |
+
a_layer = get_dropout_layer(name=("alphadropout", {"p": 0.25}))
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
name: a dropout ratio or a tuple of dropout type and parameters.
|
| 85 |
+
dropout_dim: the spatial dimension of the dropout operation.
|
| 86 |
+
"""
|
| 87 |
+
if isinstance(name, (int, float)):
|
| 88 |
+
# if dropout was specified simply as a p value, use default name and make a keyword map with the value
|
| 89 |
+
drop_name = Dropout.DROPOUT
|
| 90 |
+
drop_args = {"p": float(name)}
|
| 91 |
+
else:
|
| 92 |
+
drop_name, drop_args = split_args(name)
|
| 93 |
+
drop_type = Dropout[drop_name, dropout_dim]
|
| 94 |
+
return drop_type(**drop_args)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_pool_layer(name: Union[Tuple, str], spatial_dims: Optional[int] = 1):
|
| 98 |
+
"""
|
| 99 |
+
Create a pooling layer instance.
|
| 100 |
+
|
| 101 |
+
For example, to create adaptiveavg layer:
|
| 102 |
+
|
| 103 |
+
.. code-block:: python
|
| 104 |
+
|
| 105 |
+
from monai.networks.layers import get_pool_layer
|
| 106 |
+
|
| 107 |
+
pool_layer = get_pool_layer(("adaptiveavg", {"output_size": (1, 1, 1)}), spatial_dims=3)
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
name: a pooling type string or a tuple of type string and parameters.
|
| 111 |
+
spatial_dims: number of spatial dimensions of the input.
|
| 112 |
+
|
| 113 |
+
"""
|
| 114 |
+
pool_name, pool_args = split_args(name)
|
| 115 |
+
pool_type = Pool[pool_name, spatial_dims]
|
| 116 |
+
return pool_type(**pool_args)
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/weight_init.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 18 |
+
"""Tensor initialization with truncated normal distribution.
|
| 19 |
+
Based on:
|
| 20 |
+
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 21 |
+
https://github.com/rwightman/pytorch-image-models
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
tensor: an n-dimensional `torch.Tensor`.
|
| 25 |
+
mean: the mean of the normal distribution.
|
| 26 |
+
std: the standard deviation of the normal distribution.
|
| 27 |
+
a: the minimum cutoff value.
|
| 28 |
+
b: the maximum cutoff value.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def norm_cdf(x):
|
| 32 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
| 33 |
+
|
| 34 |
+
with torch.no_grad():
|
| 35 |
+
l = norm_cdf((a - mean) / std)
|
| 36 |
+
u = norm_cdf((b - mean) / std)
|
| 37 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 38 |
+
tensor.erfinv_()
|
| 39 |
+
tensor.mul_(std * math.sqrt(2.0))
|
| 40 |
+
tensor.add_(mean)
|
| 41 |
+
tensor.clamp_(min=a, max=b)
|
| 42 |
+
return tensor
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
| 46 |
+
"""Tensor initialization with truncated normal distribution.
|
| 47 |
+
Based on:
|
| 48 |
+
https://github.com/rwightman/pytorch-image-models
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
tensor: an n-dimensional `torch.Tensor`
|
| 52 |
+
mean: the mean of the normal distribution
|
| 53 |
+
std: the standard deviation of the normal distribution
|
| 54 |
+
a: the minimum cutoff value
|
| 55 |
+
b: the maximum cutoff value
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
if not std > 0:
|
| 59 |
+
raise ValueError("the standard deviation should be greater than zero.")
|
| 60 |
+
|
| 61 |
+
if a >= b:
|
| 62 |
+
raise ValueError("minimum cutoff value (a) should be smaller than maximum cutoff value (b).")
|
| 63 |
+
|
| 64 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__init__.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
from .ahnet import AHnet, Ahnet, AHNet
|
| 13 |
+
from .attentionunet import AttentionUnet
|
| 14 |
+
from .autoencoder import AutoEncoder
|
| 15 |
+
from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet
|
| 16 |
+
from .classifier import Classifier, Critic, Discriminator
|
| 17 |
+
from .densenet import (
|
| 18 |
+
DenseNet,
|
| 19 |
+
Densenet,
|
| 20 |
+
DenseNet121,
|
| 21 |
+
Densenet121,
|
| 22 |
+
DenseNet169,
|
| 23 |
+
Densenet169,
|
| 24 |
+
DenseNet201,
|
| 25 |
+
Densenet201,
|
| 26 |
+
DenseNet264,
|
| 27 |
+
Densenet264,
|
| 28 |
+
densenet121,
|
| 29 |
+
densenet169,
|
| 30 |
+
densenet201,
|
| 31 |
+
densenet264,
|
| 32 |
+
)
|
| 33 |
+
from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch
|
| 34 |
+
from .dynunet import DynUNet, DynUnet, Dynunet
|
| 35 |
+
from .efficientnet import (
|
| 36 |
+
BlockArgs,
|
| 37 |
+
EfficientNet,
|
| 38 |
+
EfficientNetBN,
|
| 39 |
+
EfficientNetBNFeatures,
|
| 40 |
+
drop_connect,
|
| 41 |
+
get_efficientnet_image_size,
|
| 42 |
+
)
|
| 43 |
+
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
|
| 44 |
+
from .generator import Generator
|
| 45 |
+
from .highresnet import HighResBlock, HighResNet
|
| 46 |
+
from .milmodel import MILModel
|
| 47 |
+
from .netadapter import NetAdapter
|
| 48 |
+
from .regressor import Regressor
|
| 49 |
+
from .regunet import GlobalNet, LocalNet, RegUNet
|
| 50 |
+
from .resnet import ResNet, resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
|
| 51 |
+
from .segresnet import SegResNet, SegResNetVAE
|
| 52 |
+
from .senet import (
|
| 53 |
+
SENet,
|
| 54 |
+
SEnet,
|
| 55 |
+
Senet,
|
| 56 |
+
SENet154,
|
| 57 |
+
SEnet154,
|
| 58 |
+
Senet154,
|
| 59 |
+
SEResNet50,
|
| 60 |
+
SEresnet50,
|
| 61 |
+
Seresnet50,
|
| 62 |
+
SEResNet101,
|
| 63 |
+
SEresnet101,
|
| 64 |
+
Seresnet101,
|
| 65 |
+
SEResNet152,
|
| 66 |
+
SEresnet152,
|
| 67 |
+
Seresnet152,
|
| 68 |
+
SEResNext50,
|
| 69 |
+
SEResNeXt50,
|
| 70 |
+
SEresnext50,
|
| 71 |
+
Seresnext50,
|
| 72 |
+
SEResNext101,
|
| 73 |
+
SEResNeXt101,
|
| 74 |
+
SEresnext101,
|
| 75 |
+
Seresnext101,
|
| 76 |
+
senet154,
|
| 77 |
+
seresnet50,
|
| 78 |
+
seresnet101,
|
| 79 |
+
seresnet152,
|
| 80 |
+
seresnext50,
|
| 81 |
+
seresnext101,
|
| 82 |
+
)
|
| 83 |
+
from .swin_unetr import SwinUNETR
|
| 84 |
+
from .torchvision_fc import TorchVisionFCModel
|
| 85 |
+
from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
|
| 86 |
+
from .unet import UNet, Unet
|
| 87 |
+
from .unetr import UNETR
|
| 88 |
+
from .varautoencoder import VarAutoEncoder
|
| 89 |
+
from .vit import ViT
|
| 90 |
+
from .vitautoenc import ViTAutoEnc
|
| 91 |
+
from .vnet import VNet
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/attentionunet.cpython-38.pyc
ADDED
|
Binary file (6.42 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/autoencoder.cpython-38.pyc
ADDED
|
Binary file (9.21 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/basic_unet.cpython-38.pyc
ADDED
|
Binary file (9.93 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/dynunet.cpython-38.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/efficientnet.cpython-38.pyc
ADDED
|
Binary file (26.8 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/fullyconnectednet.cpython-38.pyc
ADDED
|
Binary file (6.5 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/netadapter.cpython-38.pyc
ADDED
|
Binary file (3.6 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/regressor.cpython-38.pyc
ADDED
|
Binary file (5.33 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/regunet.cpython-38.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/resnet.cpython-38.pyc
ADDED
|
Binary file (13.3 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/segresnet.cpython-38.pyc
ADDED
|
Binary file (11.5 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/senet.cpython-38.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/swin_unetr.cpython-38.pyc
ADDED
|
Binary file (27.9 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/torchvision_fc.cpython-38.pyc
ADDED
|
Binary file (2.99 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/unet.cpython-38.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/unetr.cpython-38.pyc
ADDED
|
Binary file (5.31 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/varautoencoder.cpython-38.pyc
ADDED
|
Binary file (5.77 kB). View file
|
|
|
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/nets/__pycache__/vit.cpython-38.pyc
ADDED
|
Binary file (3.93 kB). View file
|
|
|