Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +3 -0
- my_container_sandbox/workspace/anaconda3/lib/libnvvm.so.4.0.0 +3 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/__init__.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/box_utils.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/csv_saver.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/dataset.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/dataset_summary.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/folder_layout.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/grid_dataset.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_writer.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/iterable_dataset.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/meta_obj.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/meta_tensor.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/nifti_saver.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/nifti_writer.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/samplers.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/synthetic.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/test_time_augmentation.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/thread_buffer.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/torchscript_utils.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/utils.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/wsi_datasets.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/wsi_reader.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/handlers/__pycache__/checkpoint_loader.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/handlers/__pycache__/parameter_scheduler.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/handlers/__pycache__/utils.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/__init__.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/contrastive.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/deform.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/dice.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/focal_loss.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/giou_loss.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/image_dissimilarity.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/multi_scale.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/spatial_mask.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/tversky.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/__pycache__/__init__.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__init__.py +40 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/acti_norm.py +101 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/activation.py +162 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/aspp.py +105 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/backbone_fpn_utils.py +176 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/convolutions.py +326 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/crf.py +120 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/dints_block.py +272 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/downsample.py +61 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/dynunet_block.py +288 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/fcn.py +242 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/feature_pyramid_network.py +263 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/localnet_block.py +286 -0
.gitattributes CHANGED
@@ -336,3 +336,6 @@ my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/scipy.libs/
 my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/brotli/_brotli.abi3.so filter=lfs diff=lfs merge=lfs -text
 my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/opencv_python.libs/libxcb-xkb-9ba31ab3.so.1.0.0 filter=lfs diff=lfs merge=lfs -text
 my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/opencv_python.libs/libssl-28bef1ac.so.1.1 filter=lfs diff=lfs merge=lfs -text
+my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/opencv_python.libs/libxkbcommon-71ae2972.so.0.0.0 filter=lfs diff=lfs merge=lfs -text
+my_container_sandbox/workspace/anaconda3/lib/libnvvm.so.4.0.0 filter=lfs diff=lfs merge=lfs -text
+my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/sklearn/_isotonic.cpython-38-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
my_container_sandbox/workspace/anaconda3/lib/libnvvm.so.4.0.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0c95d249f3d60d67dbd191174fe1d8b9f393b1887ae1a54b4988823c7e42a31
+size 26650200
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (4.41 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/box_utils.cpython-38.pyc ADDED
Binary file (37 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/csv_saver.cpython-38.pyc ADDED
Binary file (4.88 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (56.6 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/dataset_summary.cpython-38.pyc ADDED
Binary file (8.38 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/folder_layout.cpython-38.pyc ADDED
Binary file (3.58 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/grid_dataset.cpython-38.pyc ADDED
Binary file (11.6 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_writer.cpython-38.pyc ADDED
Binary file (32 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/iterable_dataset.cpython-38.pyc ADDED
Binary file (11.9 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/meta_obj.cpython-38.pyc ADDED
Binary file (7.96 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/meta_tensor.cpython-38.pyc ADDED
Binary file (10.5 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/nifti_saver.cpython-38.pyc ADDED
Binary file (8.83 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/nifti_writer.cpython-38.pyc ADDED
Binary file (6.94 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/samplers.cpython-38.pyc ADDED
Binary file (4.75 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/synthetic.cpython-38.pyc ADDED
Binary file (5.22 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/test_time_augmentation.cpython-38.pyc ADDED
Binary file (8.9 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/thread_buffer.cpython-38.pyc ADDED
Binary file (8.63 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/torchscript_utils.cpython-38.pyc ADDED
Binary file (4.68 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/utils.cpython-38.pyc ADDED
Binary file (54 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/wsi_datasets.cpython-38.pyc ADDED
Binary file (14.3 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/wsi_reader.cpython-38.pyc ADDED
Binary file (17.5 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/handlers/__pycache__/checkpoint_loader.cpython-38.pyc ADDED
Binary file (5.49 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/handlers/__pycache__/parameter_scheduler.cpython-38.pyc ADDED
Binary file (6.53 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/handlers/__pycache__/utils.cpython-38.pyc ADDED
Binary file (9.52 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (990 Bytes)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/contrastive.cpython-38.pyc ADDED
Binary file (2.93 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/deform.cpython-38.pyc ADDED
Binary file (4.09 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/dice.cpython-38.pyc ADDED
Binary file (36.3 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/focal_loss.cpython-38.pyc ADDED
Binary file (7.24 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/giou_loss.cpython-38.pyc ADDED
Binary file (2.43 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/image_dissimilarity.cpython-38.pyc ADDED
Binary file (11.4 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/multi_scale.cpython-38.pyc ADDED
Binary file (3.05 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/spatial_mask.cpython-38.pyc ADDED
Binary file (2.36 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/losses/__pycache__/tversky.cpython-38.pyc ADDED
Binary file (5.19 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (583 Bytes)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__init__.py ADDED
@@ -0,0 +1,40 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .acti_norm import ADN
+from .activation import MemoryEfficientSwish, Mish, Swish
+from .aspp import SimpleASPP
+from .backbone_fpn_utils import BackboneWithFPN
+from .convolutions import Convolution, ResidualUnit
+from .crf import CRF
+from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
+from .downsample import MaxAvgPool
+from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding
+from .fcn import FCN, GCN, MCFCN, Refine
+from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
+from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
+from .mlp import MLPBlock
+from .patchembedding import PatchEmbed, PatchEmbeddingBlock
+from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
+from .segresnet_block import ResBlock
+from .selfattention import SABlock
+from .squeeze_and_excitation import (
+    ChannelSELayer,
+    ResidualSELayer,
+    SEBlock,
+    SEBottleneck,
+    SEResNetBottleneck,
+    SEResNeXtBottleneck,
+)
+from .transformerblock import TransformerBlock
+from .unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
+from .upsample import SubpixelUpsample, Subpixelupsample, SubpixelUpSample, Upsample, UpSample
+from .warp import DVF2DDF, Warp
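
Because this `__init__.py` re-exports each block at package level, downstream code can import the classes without spelling out the submodules. A minimal sketch, assuming this MONAI build is on the import path:

# the re-exports above make the submodule classes importable directly
from monai.networks.blocks import ADN, Convolution, SimpleASPP

print(ADN.__module__)  # monai.networks.blocks.acti_norm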
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/acti_norm.py ADDED
@@ -0,0 +1,101 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch.nn as nn
+
+from monai.networks.layers.utils import get_act_layer, get_dropout_layer, get_norm_layer
+
+
+class ADN(nn.Sequential):
+    """
+    Constructs a sequential module of optional activation (A), dropout (D), and normalization (N) layers
+    with an arbitrary order::
+
+        -- (Norm) -- (Dropout) -- (Acti) --
+
+    Args:
+        ordering: a string representing the ordering of activation, dropout, and normalization. Defaults to "NDA".
+        in_channels: `C` from an expected input of size (N, C, H[, W, D]).
+        act: activation type and arguments. Defaults to PReLU.
+        norm: feature normalization type and arguments. Defaults to instance norm.
+        norm_dim: determine the spatial dimensions of the normalization layer.
+            defaults to `dropout_dim` if unspecified.
+        dropout: dropout ratio. Defaults to no dropout.
+        dropout_dim: determine the spatial dimensions of dropout.
+            defaults to `norm_dim` if unspecified.
+
+            - When dropout_dim = 1, randomly zeroes some of the elements for each channel.
+            - When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map).
+            - When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map).
+
+    Examples::
+
+        # activation, group norm, dropout
+        >>> norm_params = ("GROUP", {"num_groups": 1, "affine": False})
+        >>> ADN(norm=norm_params, in_channels=1, dropout_dim=1, dropout=0.8, ordering="AND")
+        ADN(
+            (A): ReLU()
+            (N): GroupNorm(1, 1, eps=1e-05, affine=False)
+            (D): Dropout(p=0.8, inplace=False)
+        )
+
+        # LeakyReLU, dropout
+        >>> act_params = ("leakyrelu", {"negative_slope": 0.1, "inplace": True})
+        >>> ADN(act=act_params, in_channels=1, dropout_dim=1, dropout=0.8, ordering="AD")
+        ADN(
+            (A): LeakyReLU(negative_slope=0.1, inplace=True)
+            (D): Dropout(p=0.8, inplace=False)
+        )
+
+    See also:
+
+        :py:class:`monai.networks.layers.Dropout`
+        :py:class:`monai.networks.layers.Act`
+        :py:class:`monai.networks.layers.Norm`
+        :py:class:`monai.networks.layers.split_args`
+
+    """
+
+    def __init__(
+        self,
+        ordering: str = "NDA",
+        in_channels: Optional[int] = None,
+        act: Optional[Union[Tuple, str]] = "RELU",
+        norm: Optional[Union[Tuple, str]] = None,
+        norm_dim: Optional[int] = None,
+        dropout: Optional[Union[Tuple, str, float]] = None,
+        dropout_dim: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        op_dict = {"A": None, "D": None, "N": None}
+        # define the normalization type and the arguments to the constructor
+        if norm is not None:
+            if norm_dim is None and dropout_dim is None:
+                raise ValueError("norm_dim or dropout_dim needs to be specified.")
+            op_dict["N"] = get_norm_layer(name=norm, spatial_dims=norm_dim or dropout_dim, channels=in_channels)
+
+        # define the activation type and the arguments to the constructor
+        if act is not None:
+            op_dict["A"] = get_act_layer(act)
+
+        if dropout is not None:
+            if norm_dim is None and dropout_dim is None:
+                raise ValueError("norm_dim or dropout_dim needs to be specified.")
+            op_dict["D"] = get_dropout_layer(name=dropout, dropout_dim=dropout_dim or norm_dim)
+
+        for item in ordering.upper():
+            if item not in op_dict:
+                raise ValueError(f"ordering must be a string of {op_dict}, got {item} in it.")
+            if op_dict[item] is not None:
+                self.add_module(item, op_dict[item])
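
A short usage sketch for the `ADN` block above (not part of the diff; it assumes this MONAI build and PyTorch are importable). Here `norm_dim=2` selects the 2D layer variants:

import torch
from monai.networks.blocks import ADN

# norm -> dropout -> activation, built for 2D feature maps
adn = ADN(ordering="NDA", in_channels=8, norm="INSTANCE", dropout=0.1, norm_dim=2)
x = torch.randn(4, 8, 32, 32)  # (N, C, H, W)
print(adn(x).shape)  # torch.Size([4, 8, 32, 32]); ADN never changes the shape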
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/activation.py ADDED
@@ -0,0 +1,162 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+
+from monai.utils import optional_import
+
+if optional_import("torch.nn.functional", name="mish")[1]:
+
+    def monai_mish(x, inplace: bool = False):
+        return torch.nn.functional.mish(x, inplace=inplace)
+
+else:
+
+    def monai_mish(x, inplace: bool = False):
+        return x * torch.tanh(torch.nn.functional.softplus(x))
+
+
+if optional_import("torch.nn.functional", name="silu")[1]:
+
+    def monai_swish(x, inplace: bool = False):
+        return torch.nn.functional.silu(x, inplace=inplace)
+
+else:
+
+    def monai_swish(x, inplace: bool = False):
+        return SwishImplementation.apply(x)
+
+
+class Swish(nn.Module):
+    r"""Applies the element-wise function:
+
+    .. math::
+        \text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha.
+
+    Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.
+
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means any number of additional dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+
+    Examples::
+
+        >>> import torch
+        >>> from monai.networks.layers.factories import Act
+        >>> m = Act['swish']()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def __init__(self, alpha=1.0):
+        super().__init__()
+        self.alpha = alpha
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return input * torch.sigmoid(self.alpha * input)
+
+
+class SwishImplementation(torch.autograd.Function):
+    r"""Memory efficient implementation for training
+    Follows recommendation from:
+    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853
+
+    Results in ~ 30% memory saving during training as compared to Swish()
+    """
+
+    @staticmethod
+    def forward(ctx, input):
+        result = input * torch.sigmoid(input)
+        ctx.save_for_backward(input)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        sigmoid_input = torch.sigmoid(input)
+        return grad_output * (sigmoid_input * (1 + input * (1 - sigmoid_input)))
+
+
+class MemoryEfficientSwish(nn.Module):
+    r"""Applies the element-wise function:
+
+    .. math::
+        \text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha=1.
+
+    Memory efficient implementation for training following recommendation from:
+    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853
+
+    Results in ~ 30% memory saving during training as compared to Swish()
+
+    Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.
+
+    From PyTorch 1.7.0+, the optimized version of `Swish` named `SiLU` is available;
+    this class will use `torch.nn.functional.silu` when the installed version provides it.
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+
+    Examples::
+
+        >>> import torch
+        >>> from monai.networks.layers.factories import Act
+        >>> m = Act['memswish']()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def __init__(self, inplace: bool = False):
+        super().__init__()
+        # inplace only works when using torch.nn.functional.silu
+        self.inplace = inplace
+
+    def forward(self, input: torch.Tensor):
+        return monai_swish(input, self.inplace)
+
+
+class Mish(nn.Module):
+    r"""Applies the element-wise function:
+
+    .. math::
+        \text{Mish}(x) = x * tanh(\text{softplus}(x)).
+
+    Citation: Mish: A Self Regularized Non-Monotonic Activation Function, Diganta Misra, 2019, https://arxiv.org/abs/1908.08681.
+
+    From PyTorch 1.9.0+, the optimized version of `Mish` is available;
+    this class will use `torch.nn.functional.mish` when the installed version provides it.
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means any number of additional dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+
+    Examples::
+
+        >>> import torch
+        >>> from monai.networks.layers.factories import Act
+        >>> m = Act['mish']()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def __init__(self, inplace: bool = False):
+        super().__init__()
+        # inplace only works when using torch.nn.functional.mish
+        self.inplace = inplace
+
+    def forward(self, input: torch.Tensor):
+        return monai_mish(input, self.inplace)
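
A small sanity-check sketch for the activation classes above (an illustration, assuming this MONAI build and PyTorch are importable):

import torch
from monai.networks.blocks import MemoryEfficientSwish, Mish, Swish

x = torch.randn(2, 3)
plain = Swish()               # x * sigmoid(alpha * x), alpha defaults to 1.0
mem = MemoryEfficientSwish()  # dispatches to F.silu on PyTorch 1.7.0+

# with alpha=1 both compute x * sigmoid(x), so the outputs should agree
assert torch.allclose(plain(x), mem(x), atol=1e-6)
print(Mish()(x))  # x * tanh(softplus(x))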
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/aspp.py ADDED
@@ -0,0 +1,105 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks.convolutions import Convolution
+from monai.networks.layers import same_padding
+from monai.networks.layers.factories import Conv
+
+
+class SimpleASPP(nn.Module):
+    """
+    A simplified version of the atrous spatial pyramid pooling (ASPP) module.
+
+    Chen et al., Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.
+    https://arxiv.org/abs/1802.02611
+
+    Wang et al., A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions
+    from CT Images. https://ieeexplore.ieee.org/document/9109297
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        conv_out_channels: int,
+        kernel_sizes: Sequence[int] = (1, 3, 3, 3),
+        dilations: Sequence[int] = (1, 2, 4, 6),
+        norm_type: Optional[Union[Tuple, str]] = "BATCH",
+        acti_type: Optional[Union[Tuple, str]] = "LEAKYRELU",
+        bias: bool = False,
+    ) -> None:
+        """
+        Args:
+            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+            in_channels: number of input channels.
+            conv_out_channels: number of output channels of each atrous conv.
+                The final number of output channels is conv_out_channels * len(kernel_sizes).
+            kernel_sizes: a sequence of four convolutional kernel sizes.
+                Defaults to (1, 3, 3, 3) for four (dilated) convolutions.
+            dilations: a sequence of four convolutional dilation parameters.
+                Defaults to (1, 2, 4, 6) for four (dilated) convolutions.
+            norm_type: final kernel-size-one convolution normalization type.
+                Defaults to batch norm.
+            acti_type: final kernel-size-one convolution activation type.
+                Defaults to leaky ReLU.
+            bias: whether to have a bias term in convolution blocks. Defaults to False.
+                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
+                if a conv layer is directly followed by a batch norm layer, bias should be False.
+
+        Raises:
+            ValueError: When ``kernel_sizes`` length differs from ``dilations``.
+
+        See also:
+
+            :py:class:`monai.networks.layers.Act`
+            :py:class:`monai.networks.layers.Conv`
+            :py:class:`monai.networks.layers.Norm`
+
+        """
+        super().__init__()
+        if len(kernel_sizes) != len(dilations):
+            raise ValueError(
+                "kernel_sizes and dilations length must match, "
+                f"got kernel_sizes={len(kernel_sizes)} dilations={len(dilations)}."
+            )
+        pads = tuple(same_padding(k, d) for k, d in zip(kernel_sizes, dilations))
+
+        self.convs = nn.ModuleList()
+        for k, d, p in zip(kernel_sizes, dilations, pads):
+            _conv = Conv[Conv.CONV, spatial_dims](
+                in_channels=in_channels, out_channels=conv_out_channels, kernel_size=k, dilation=d, padding=p
+            )
+            self.convs.append(_conv)
+
+        out_channels = conv_out_channels * len(pads)  # final conv. output channels
+        self.conv_k1 = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            act=acti_type,
+            norm=norm_type,
+            bias=bias,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: in shape (batch, channel, spatial_1[, spatial_2, ...]).
+        """
+        x_out = torch.cat([conv(x) for conv in self.convs], dim=1)
+        x_out = self.conv_k1(x_out)
+        return x_out
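
A shape-level usage sketch for `SimpleASPP` (assumes this MONAI build and PyTorch are importable; the sizes are arbitrary):

import torch
from monai.networks.blocks import SimpleASPP

# four parallel dilated convs, each producing 4 channels -> 4 * 4 = 16 out
aspp = SimpleASPP(spatial_dims=2, in_channels=8, conv_out_channels=4)
x = torch.randn(2, 8, 64, 64)
print(aspp(x).shape)  # torch.Size([2, 16, 64, 64]); same_padding keeps H, W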
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/backbone_fpn_utils.py ADDED
@@ -0,0 +1,176 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# =========================================================================
+# Adapted from https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py
+# which has the following license...
+# https://github.com/pytorch/vision/blob/main/LICENSE
+#
+# BSD 3-Clause License
+
+# Copyright (c) Soumith Chintala 2016,
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""
+This script is modified from torchvision to support N-D images,
+by overriding the definition of convolutional layers and pooling layers.
+
+https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py
+"""
+
+from typing import Dict, List, Optional, Union
+
+from torch import Tensor, nn
+
+from monai.networks.nets import resnet
+from monai.utils import optional_import
+
+from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
+
+torchvision_models, _ = optional_import("torchvision.models")
+
+__all__ = ["BackboneWithFPN"]
+
+
+class BackboneWithFPN(nn.Module):
+    """
+    Adds an FPN on top of a model.
+    Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
+    extract a submodel that returns the feature maps specified in return_layers.
+    The same limitations of IntermediateLayerGetter apply here.
+
+    Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py
+    Except that this class uses spatial_dims
+
+    Args:
+        backbone: backbone network
+        return_layers: a dict containing the names
+            of the modules for which the activations will be returned as
+            the key of the dict, and the value of the dict is the name
+            of the returned activation (which the user can specify).
+        in_channels_list: number of channels for each feature map
+            that is returned, in the order they are present in the OrderedDict
+        out_channels: number of channels in the FPN.
+        spatial_dims: 2D or 3D images
+    """
+
+    def __init__(
+        self,
+        backbone: nn.Module,
+        return_layers: Dict[str, str],
+        in_channels_list: List[int],
+        out_channels: int,
+        spatial_dims: Union[int, None] = None,
+        extra_blocks: Optional[ExtraFPNBlock] = None,
+    ) -> None:
+        super().__init__()
+
+        # if spatial_dims is not specified, try to find it from backbone.
+        if spatial_dims is None:
+            if hasattr(backbone, "spatial_dims") and isinstance(backbone.spatial_dims, int):
+                spatial_dims = backbone.spatial_dims
+            elif isinstance(backbone.conv1, nn.Conv2d):
+                spatial_dims = 2
+            elif isinstance(backbone.conv1, nn.Conv3d):
+                spatial_dims = 3
+            else:
+                raise ValueError("Could not find spatial_dims of backbone, please specify it.")
+
+        if extra_blocks is None:
+            extra_blocks = LastLevelMaxPool(spatial_dims)
+
+        self.body = torchvision_models._utils.IntermediateLayerGetter(backbone, return_layers=return_layers)
+        self.fpn = FeaturePyramidNetwork(
+            spatial_dims=spatial_dims,
+            in_channels_list=in_channels_list,
+            out_channels=out_channels,
+            extra_blocks=extra_blocks,
+        )
+        self.out_channels = out_channels
+
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        """
+        Computes the resulting feature maps of the network.
+
+        Args:
+            x: input images
+
+        Returns:
+            feature maps after FPN layers. They are ordered from highest resolution first.
+        """
+        x = self.body(x)  # backbone
+        y: Dict[str, Tensor] = self.fpn(x)  # FPN
+        return y
+
+
+def _resnet_fpn_extractor(
+    backbone: resnet.ResNet,
+    spatial_dims: int,
+    trainable_layers: int = 5,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> BackboneWithFPN:
+    """
+    Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py
+    Except that ``in_channels_stage2 = backbone.in_planes // 8`` instead of ``in_channels_stage2 = backbone.inplanes // 8``,
+    and it requires spatial_dims: 2D or 3D images.
+    """
+
+    # select layers that won't be frozen
+    if trainable_layers < 0 or trainable_layers > 5:
+        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
+    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
+    if trainable_layers == 5:
+        layers_to_train.append("bn1")
+    for name, parameter in backbone.named_parameters():
+        if all([not name.startswith(layer) for layer in layers_to_train]):
+            parameter.requires_grad_(False)
+
+    if extra_blocks is None:
+        extra_blocks = LastLevelMaxPool(spatial_dims)
+
+    if returned_layers is None:
+        returned_layers = [1, 2, 3, 4]
+    if min(returned_layers) <= 0 or max(returned_layers) >= 5:
+        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
+    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
+
+    in_channels_stage2 = backbone.in_planes // 8
+    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
+    out_channels = 256
+    return BackboneWithFPN(
+        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, spatial_dims=spatial_dims
+    )
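
A hedged sketch of how `_resnet_fpn_extractor` might be driven. Note it is a module-private helper, torchvision must be installed for `IntermediateLayerGetter`, and the `resnet18(spatial_dims=2, n_input_channels=1)` factory call is an assumption about this particular MONAI build:

import torch
from monai.networks.nets import resnet
from monai.networks.blocks.backbone_fpn_utils import _resnet_fpn_extractor

# assumed factory signature: a 2D, single-channel ResNet-18 backbone
backbone = resnet.resnet18(spatial_dims=2, n_input_channels=1)
model = _resnet_fpn_extractor(backbone, spatial_dims=2, returned_layers=[1, 2, 3])
features = model(torch.randn(1, 1, 128, 128))
for name, fmap in features.items():
    print(name, fmap.shape)  # one FPN level per entry, highest resolution first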
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/convolutions.py ADDED
@@ -0,0 +1,326 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks import ADN
+from monai.networks.layers.convutils import same_padding, stride_minus_kernel_padding
+from monai.networks.layers.factories import Conv
+from monai.utils.deprecate_utils import deprecated_arg
+
+
+class Convolution(nn.Sequential):
+    """
+    Constructs a convolution with normalization, optional dropout, and optional activation layers::
+
+        -- (Conv|ConvTrans) -- (Norm -- Dropout -- Acti) --
+
+    if ``conv_only`` is set to ``True``::
+
+        -- (Conv|ConvTrans) --
+
+    For example:
+
+    .. code-block:: python
+
+        from monai.networks.blocks import Convolution
+
+        conv = Convolution(
+            dimensions=3,
+            in_channels=1,
+            out_channels=1,
+            adn_ordering="ADN",
+            act=("prelu", {"init": 0.2}),
+            dropout=0.1,
+            norm=("layer", {"normalized_shape": (10, 10, 10)}),
+        )
+        print(conv)
+
+    output::
+
+        Convolution(
+            (conv): Conv3d(1, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
+            (adn): ADN(
+                (A): PReLU(num_parameters=1)
+                (D): Dropout(p=0.1, inplace=False)
+                (N): LayerNorm((10, 10, 10), eps=1e-05, elementwise_affine=True)
+            )
+        )
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        out_channels: number of output channels.
+        strides: convolution stride. Defaults to 1.
+        kernel_size: convolution kernel size. Defaults to 3.
+        adn_ordering: a string representing the ordering of activation, normalization, and dropout.
+            Defaults to "NDA".
+        act: activation type and arguments. Defaults to PReLU.
+        norm: feature normalization type and arguments. Defaults to instance norm.
+        dropout: dropout ratio. Defaults to no dropout.
+        dropout_dim: determine the spatial dimensions of dropout. Defaults to 1.
+
+            - When dropout_dim = 1, randomly zeroes some of the elements for each channel.
+            - When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map).
+            - When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map).
+
+            The value of dropout_dim should be no larger than the value of `spatial_dims`.
+        dilation: dilation rate. Defaults to 1.
+        groups: controls the connections between inputs and outputs. Defaults to 1.
+        bias: whether to have a bias term. Defaults to True.
+        conv_only: whether to use the convolutional layer only. Defaults to False.
+        is_transposed: if True uses ConvTrans instead of Conv. Defaults to False.
+        padding: controls the amount of implicit zero-paddings on both sides for padding number of points
+            for each dimension. Defaults to None.
+        output_padding: controls the additional size added to one side of the output shape.
+            Defaults to None.
+
+    .. deprecated:: 0.6.0
+        ``dimensions`` is deprecated, use ``spatial_dims`` instead.
+
+    See also:
+
+        :py:class:`monai.networks.layers.Conv`
+        :py:class:`monai.networks.blocks.ADN`
+
+    """
+
+    @deprecated_arg(
+        name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead."
+    )
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        strides: Union[Sequence[int], int] = 1,
+        kernel_size: Union[Sequence[int], int] = 3,
+        adn_ordering: str = "NDA",
+        act: Optional[Union[Tuple, str]] = "PRELU",
+        norm: Optional[Union[Tuple, str]] = "INSTANCE",
+        dropout: Optional[Union[Tuple, str, float]] = None,
+        dropout_dim: Optional[int] = 1,
+        dilation: Union[Sequence[int], int] = 1,
+        groups: int = 1,
+        bias: bool = True,
+        conv_only: bool = False,
+        is_transposed: bool = False,
+        padding: Optional[Union[Sequence[int], int]] = None,
+        output_padding: Optional[Union[Sequence[int], int]] = None,
+        dimensions: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self.dimensions = spatial_dims if dimensions is None else dimensions
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.is_transposed = is_transposed
+        if padding is None:
+            padding = same_padding(kernel_size, dilation)
+        conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, self.dimensions]
+
+        conv: nn.Module
+        if is_transposed:
+            if output_padding is None:
+                output_padding = stride_minus_kernel_padding(1, strides)
+            conv = conv_type(
+                in_channels,
+                out_channels,
+                kernel_size=kernel_size,
+                stride=strides,
+                padding=padding,
+                output_padding=output_padding,
+                groups=groups,
+                bias=bias,
+                dilation=dilation,
+            )
+        else:
+            conv = conv_type(
+                in_channels,
+                out_channels,
+                kernel_size=kernel_size,
+                stride=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                bias=bias,
+            )
+
+        self.add_module("conv", conv)
+
+        if not conv_only:
+            self.add_module(
+                "adn",
+                ADN(
+                    ordering=adn_ordering,
+                    in_channels=out_channels,
+                    act=act,
+                    norm=norm,
+                    norm_dim=self.dimensions,
+                    dropout=dropout,
+                    dropout_dim=dropout_dim,
+                ),
+            )
+
+
+class ResidualUnit(nn.Module):
+    """
+    Residual module with multiple convolutions and a residual connection.
+
+    For example:
+
+    .. code-block:: python
+
+        from monai.networks.blocks import ResidualUnit
+
+        convs = ResidualUnit(
+            spatial_dims=3,
+            in_channels=1,
+            out_channels=1,
+            adn_ordering="AN",
+            act=("prelu", {"init": 0.2}),
+            norm=("layer", {"normalized_shape": (10, 10, 10)}),
+        )
+        print(convs)
+
+    output::
+
+        ResidualUnit(
+            (conv): Sequential(
+                (unit0): Convolution(
+                    (conv): Conv3d(1, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
+                    (adn): ADN(
+                        (A): PReLU(num_parameters=1)
+                        (N): LayerNorm((10, 10, 10), eps=1e-05, elementwise_affine=True)
+                    )
+                )
+                (unit1): Convolution(
+                    (conv): Conv3d(1, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
+                    (adn): ADN(
+                        (A): PReLU(num_parameters=1)
+                        (N): LayerNorm((10, 10, 10), eps=1e-05, elementwise_affine=True)
+                    )
+                )
+            )
+            (residual): Identity()
+        )
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        out_channels: number of output channels.
+        strides: convolution stride. Defaults to 1.
+        kernel_size: convolution kernel size. Defaults to 3.
+        subunits: number of convolutions. Defaults to 2.
+        adn_ordering: a string representing the ordering of activation, normalization, and dropout.
+            Defaults to "NDA".
+        act: activation type and arguments. Defaults to PReLU.
+        norm: feature normalization type and arguments. Defaults to instance norm.
+        dropout: dropout ratio. Defaults to no dropout.
+        dropout_dim: determine the dimensions of dropout. Defaults to 1.
+
+            - When dropout_dim = 1, randomly zeroes some of the elements for each channel.
+            - When dropout_dim = 2, Randomly zero out entire channels (a channel is a 2D feature map).
+            - When dropout_dim = 3, Randomly zero out entire channels (a channel is a 3D feature map).
+
+            The value of dropout_dim should be no larger than the value of `dimensions`.
+        dilation: dilation rate. Defaults to 1.
+        bias: whether to have a bias term. Defaults to True.
+        last_conv_only: for the last subunit, whether to use the convolutional layer only.
+            Defaults to False.
+        padding: controls the amount of implicit zero-paddings on both sides for padding number of points
+            for each dimension. Defaults to None.
+
+    .. deprecated:: 0.6.0
+        ``dimensions`` is deprecated, use ``spatial_dims`` instead.
+
+    See also:
+
+        :py:class:`monai.networks.blocks.Convolution`
+
+    """
+
+    @deprecated_arg(name="dimensions", since="0.6", msg_suffix="Please use `spatial_dims` instead.")
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        strides: Union[Sequence[int], int] = 1,
+        kernel_size: Union[Sequence[int], int] = 3,
+        subunits: int = 2,
+        adn_ordering: str = "NDA",
+        act: Optional[Union[Tuple, str]] = "PRELU",
+        norm: Optional[Union[Tuple, str]] = "INSTANCE",
+        dropout: Optional[Union[Tuple, str, float]] = None,
+        dropout_dim: Optional[int] = 1,
+        dilation: Union[Sequence[int], int] = 1,
+        bias: bool = True,
+        last_conv_only: bool = False,
+        padding: Optional[Union[Sequence[int], int]] = None,
+        dimensions: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self.dimensions = spatial_dims if dimensions is None else dimensions
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.conv = nn.Sequential()
+        self.residual = nn.Identity()
+        if not padding:
+            padding = same_padding(kernel_size, dilation)
+        schannels = in_channels
+        sstrides = strides
+        subunits = max(1, subunits)
+
+        for su in range(subunits):
+            conv_only = last_conv_only and su == (subunits - 1)
+            unit = Convolution(
+                self.dimensions,
+                schannels,
+                out_channels,
+                strides=sstrides,
+                kernel_size=kernel_size,
+                adn_ordering=adn_ordering,
+                act=act,
+                norm=norm,
+                dropout=dropout,
+                dropout_dim=dropout_dim,
+                dilation=dilation,
+                bias=bias,
+                conv_only=conv_only,
+                padding=padding,
+            )
+
+            self.conv.add_module(f"unit{su:d}", unit)
+
+            # after first loop set channels and strides to what they should be for subsequent units
+            schannels = out_channels
+            sstrides = 1
+
+        # apply convolution to input to change number of output channels and size to match that coming from self.conv
+        if np.prod(strides) != 1 or in_channels != out_channels:
+            rkernel_size = kernel_size
+            rpadding = padding
+
+            if np.prod(strides) == 1:  # if only adapting number of channels a 1x1 kernel is used with no padding
+                rkernel_size = 1
+                rpadding = 0
+
+            conv_type = Conv[Conv.CONV, self.dimensions]
+            self.residual = conv_type(in_channels, out_channels, rkernel_size, strides, rpadding, bias=bias)
|
| 322 |
+
|
| 323 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 324 |
+
res: torch.Tensor = self.residual(x) # create the additive residual from x
|
| 325 |
+
cx: torch.Tensor = self.conv(x) # apply x to sequence of operations
|
| 326 |
+
return cx + res # add the residual to the output
|
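A minimal shape check for the two blocks above (a sketch: it assumes ``monai`` and ``torch`` are installed, and the tensor sizes are illustrative, not taken from the source):

.. code-block:: python

    import torch

    from monai.networks.blocks import Convolution, ResidualUnit

    # transposed Convolution: with stride 2, the computed output_padding makes the
    # layer exactly double each spatial dimension, inverting a stride-2 convolution
    up = Convolution(spatial_dims=2, in_channels=8, out_channels=4, strides=2, is_transposed=True)
    print(up(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 4, 32, 32])

    # strided ResidualUnit: the residual convolution (kernel 3, stride 2 here)
    # downsamples the skip branch so it stays shape-compatible with the main branch
    res = ResidualUnit(spatial_dims=2, in_channels=8, out_channels=16, strides=2)
    print(res(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 16, 8, 8])
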
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/crf.py
ADDED
@@ -0,0 +1,120 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from torch.nn.functional import softmax

from monai.networks.layers.filtering import PHLFilter
from monai.networks.utils import meshgrid_ij

__all__ = ["CRF"]


class CRF(torch.nn.Module):
    """
    Conditional Random Field: Combines message passing with a class
    compatibility convolution into an iterative process designed
    to successively minimise the energy of the class labeling.

    In this implementation, the message passing step is a weighted
    combination of a gaussian filter and a bilateral filter.
    The bilateral term is included to respect existing structure
    within the reference tensor.

    See:
        https://arxiv.org/abs/1502.03240
    """

    def __init__(
        self,
        iterations: int = 5,
        bilateral_weight: float = 1.0,
        gaussian_weight: float = 1.0,
        bilateral_spatial_sigma: float = 5.0,
        bilateral_color_sigma: float = 0.5,
        gaussian_spatial_sigma: float = 5.0,
        update_factor: float = 3.0,
        compatibility_matrix: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            iterations: the number of iterations.
            bilateral_weight: the weighting of the bilateral term in the message passing step.
            gaussian_weight: the weighting of the gaussian term in the message passing step.
            bilateral_spatial_sigma: standard deviation in spatial coordinates for the bilateral term.
            bilateral_color_sigma: standard deviation in color space for the bilateral term.
            gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term.
            update_factor: determines the magnitude of each update.
            compatibility_matrix: a matrix describing class compatibility,
                should be NxN where N is the number of classes.
        """
        super().__init__()
        self.iterations = iterations
        self.bilateral_weight = bilateral_weight
        self.gaussian_weight = gaussian_weight
        self.bilateral_spatial_sigma = bilateral_spatial_sigma
        self.bilateral_color_sigma = bilateral_color_sigma
        self.gaussian_spatial_sigma = gaussian_spatial_sigma
        self.update_factor = update_factor
        self.compatibility_matrix = compatibility_matrix

    def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor):
        """
        Args:
            input_tensor: tensor containing initial class logits.
            reference_tensor: the reference tensor used to guide the message passing.

        Returns:
            output (torch.Tensor): output tensor.
        """

        # constructing spatial feature tensor
        spatial_features = _create_coordinate_tensor(reference_tensor)

        # constructing final feature tensors for bilateral and gaussian kernel
        bilateral_features = torch.cat(
            [spatial_features / self.bilateral_spatial_sigma, reference_tensor / self.bilateral_color_sigma], dim=1
        )
        gaussian_features = spatial_features / self.gaussian_spatial_sigma

        # setting up output tensor
        output_tensor = softmax(input_tensor, dim=1)

        # mean field loop
        for _ in range(self.iterations):

            # message passing step for both kernels
            bilateral_output = PHLFilter.apply(output_tensor, bilateral_features)
            gaussian_output = PHLFilter.apply(output_tensor, gaussian_features)

            # combining filter outputs
            combined_output = self.bilateral_weight * bilateral_output + self.gaussian_weight * gaussian_output

            # optionally running a compatibility transform
            if self.compatibility_matrix is not None:
                flat = combined_output.flatten(start_dim=2).permute(0, 2, 1)
                flat = torch.matmul(flat, self.compatibility_matrix)
                combined_output = flat.permute(0, 2, 1).reshape(combined_output.shape)

            # update and normalize
            output_tensor = softmax(input_tensor + self.update_factor * combined_output, dim=1)

        return output_tensor


# helper methods
def _create_coordinate_tensor(tensor):
    axes = [torch.arange(tensor.size(i)) for i in range(2, tensor.dim())]
    grids = meshgrid_ij(axes)
    coords = torch.stack(grids).to(device=tensor.device, dtype=tensor.dtype)
    return torch.stack(tensor.size(0) * [coords], dim=0)
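Illustrative use of the CRF block above as a segmentation post-processor (a sketch: ``PHLFilter`` relies on MONAI's compiled C++/CUDA extensions, so this only runs where those are built, and the tensor sizes are assumptions):

.. code-block:: python

    import torch

    from monai.networks.blocks import CRF

    logits = torch.randn(1, 4, 32, 32, 32)  # initial class logits, shape (B, C, spatial...)
    image = torch.rand(1, 1, 32, 32, 32)    # reference image guiding the bilateral term
    crf = CRF(iterations=3)
    probs = crf(logits, image)              # refined per-class probabilities, same shape as logits
    print(probs.shape, probs.sum(dim=1).mean())  # channel dim sums to 1 after the final softmax
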
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/dints_block.py
ADDED
@@ -0,0 +1,272 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Tuple, Union

import torch

from monai.networks.layers.factories import Conv
from monai.networks.layers.utils import get_act_layer, get_norm_layer

__all__ = ["FactorizedIncreaseBlock", "FactorizedReduceBlock", "P3DActiConvNormBlock", "ActiConvNormBlock"]


class FactorizedIncreaseBlock(torch.nn.Sequential):
    """
    Up-sampling the features by two using linear interpolation and convolutions.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        spatial_dims: int = 3,
        act_name: Union[Tuple, str] = "RELU",
        norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
    ):
        """
        Args:
            in_channel: number of input channels
            out_channel: number of output channels
            spatial_dims: number of spatial dimensions
            act_name: activation layer type and arguments.
            norm_name: feature normalization type and arguments.
        """
        super().__init__()
        self._in_channel = in_channel
        self._out_channel = out_channel
        self._spatial_dims = spatial_dims
        if self._spatial_dims not in (2, 3):
            raise ValueError("spatial_dims must be 2 or 3.")

        conv_type = Conv[Conv.CONV, self._spatial_dims]
        mode = "trilinear" if self._spatial_dims == 3 else "bilinear"
        self.add_module("up", torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=True))
        self.add_module("acti", get_act_layer(name=act_name))
        self.add_module(
            "conv",
            conv_type(
                in_channels=self._in_channel,
                out_channels=self._out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                bias=False,
                dilation=1,
            ),
        )
        self.add_module(
            "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel)
        )


class FactorizedReduceBlock(torch.nn.Module):
    """
    Down-sampling the feature maps by two using strided convolutions.
    The length along each spatial dimension must be a multiple of 2.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        spatial_dims: int = 3,
        act_name: Union[Tuple, str] = "RELU",
        norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
    ):
        """
        Args:
            in_channel: number of input channels
            out_channel: number of output channels.
            spatial_dims: number of spatial dimensions.
            act_name: activation layer type and arguments.
            norm_name: feature normalization type and arguments.
        """
        super().__init__()
        self._in_channel = in_channel
        self._out_channel = out_channel
        self._spatial_dims = spatial_dims
        if self._spatial_dims not in (2, 3):
            raise ValueError("spatial_dims must be 2 or 3.")

        conv_type = Conv[Conv.CONV, self._spatial_dims]

        self.act = get_act_layer(name=act_name)
        self.conv_1 = conv_type(
            in_channels=self._in_channel,
            out_channels=self._out_channel // 2,
            kernel_size=1,
            stride=2,
            padding=0,
            groups=1,
            bias=False,
            dilation=1,
        )
        self.conv_2 = conv_type(
            in_channels=self._in_channel,
            out_channels=self._out_channel - self._out_channel // 2,
            kernel_size=1,
            stride=2,
            padding=0,
            groups=1,
            bias=False,
            dilation=1,
        )
        self.norm = get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        The length along each spatial dimension must be a multiple of 2.
        """
        x = self.act(x)
        if self._spatial_dims == 3:
            out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:, 1:])], dim=1)
        else:
            out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
        out = self.norm(out)
        return out


class P3DActiConvNormBlock(torch.nn.Sequential):
    """
    -- (act) -- (conv) -- (norm) --
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int,
        padding: int,
        mode: int = 0,
        act_name: Union[Tuple, str] = "RELU",
        norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
    ):
        """
        Args:
            in_channel: number of input channels.
            out_channel: number of output channels.
            kernel_size: kernel size to be expanded to 3D.
            padding: padding size to be expanded to 3D.
            mode: mode for the anisotropic kernels:

                - 0: ``(k, k, 1)``, ``(1, 1, k)``,
                - 1: ``(k, 1, k)``, ``(1, k, 1)``,
                - 2: ``(1, k, k)``, ``(k, 1, 1)``.

            act_name: activation layer type and arguments.
            norm_name: feature normalization type and arguments.
        """
        super().__init__()
        self._in_channel = in_channel
        self._out_channel = out_channel
        self._p3dmode = int(mode)

        conv_type = Conv[Conv.CONV, 3]

        if self._p3dmode == 0:  # (k, k, 1), (1, 1, k)
            kernel_size0 = (kernel_size, kernel_size, 1)
            kernel_size1 = (1, 1, kernel_size)
            padding0 = (padding, padding, 0)
            padding1 = (0, 0, padding)
        elif self._p3dmode == 1:  # (k, 1, k), (1, k, 1)
            kernel_size0 = (kernel_size, 1, kernel_size)
            kernel_size1 = (1, kernel_size, 1)
            padding0 = (padding, 0, padding)
            padding1 = (0, padding, 0)
        elif self._p3dmode == 2:  # (1, k, k), (k, 1, 1)
            kernel_size0 = (1, kernel_size, kernel_size)
            kernel_size1 = (kernel_size, 1, 1)
            padding0 = (0, padding, padding)
            padding1 = (padding, 0, 0)
        else:
            raise ValueError("`mode` must be 0, 1, or 2.")

        self.add_module("acti", get_act_layer(name=act_name))
        self.add_module(
            "conv",
            conv_type(
                in_channels=self._in_channel,
                out_channels=self._in_channel,
                kernel_size=kernel_size0,
                stride=1,
                padding=padding0,
                groups=1,
                bias=False,
                dilation=1,
            ),
        )
        self.add_module(
            "conv_1",
            conv_type(
                in_channels=self._in_channel,
                out_channels=self._out_channel,
                kernel_size=kernel_size1,
                stride=1,
                padding=padding1,
                groups=1,
                bias=False,
                dilation=1,
            ),
        )
        self.add_module("norm", get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel))


class ActiConvNormBlock(torch.nn.Sequential):
    """
    -- (Acti) -- (Conv) -- (Norm) --
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: int = 3,
        padding: int = 1,
        spatial_dims: int = 3,
        act_name: Union[Tuple, str] = "RELU",
        norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
    ):
        """
        Args:
            in_channel: number of input channels.
            out_channel: number of output channels.
            kernel_size: kernel size of the convolution.
            padding: padding size of the convolution.
            spatial_dims: number of spatial dimensions.
            act_name: activation layer type and arguments.
            norm_name: feature normalization type and arguments.
        """
        super().__init__()
        self._in_channel = in_channel
        self._out_channel = out_channel
        self._spatial_dims = spatial_dims

        conv_type = Conv[Conv.CONV, self._spatial_dims]
        self.add_module("acti", get_act_layer(name=act_name))
        self.add_module(
            "conv",
            conv_type(
                in_channels=self._in_channel,
                out_channels=self._out_channel,
                kernel_size=kernel_size,
                stride=1,
                padding=padding,
                groups=1,
                bias=False,
                dilation=1,
            ),
        )
        self.add_module(
            "norm", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel)
        )
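A quick shape check for the two resampling blocks above (a sketch; the tensor sizes are illustrative assumptions):

.. code-block:: python

    import torch

    from monai.networks.blocks.dints_block import FactorizedIncreaseBlock, FactorizedReduceBlock

    x = torch.randn(1, 8, 16, 16, 16)
    up = FactorizedIncreaseBlock(in_channel=8, out_channel=4, spatial_dims=3)
    down = FactorizedReduceBlock(in_channel=8, out_channel=16, spatial_dims=3)
    print(up(x).shape)    # torch.Size([1, 4, 32, 32, 32]): trilinear upsample then 1x1x1 conv
    print(down(x).shape)  # torch.Size([1, 16, 8, 8, 8]): two stride-2 1x1x1 convs, concatenated
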
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/downsample.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Union

import torch
import torch.nn as nn

from monai.networks.layers.factories import Pool
from monai.utils import ensure_tuple_rep


class MaxAvgPool(nn.Module):
    """
    Downsample with both maxpooling and avgpooling,
    double the channel size by concatenating the downsampled feature maps.
    """

    def __init__(
        self,
        spatial_dims: int,
        kernel_size: Union[Sequence[int], int],
        stride: Optional[Union[Sequence[int], int]] = None,
        padding: Union[Sequence[int], int] = 0,
        ceil_mode: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            kernel_size: the kernel size of both pooling operations.
            stride: the stride of the window. Default value is `kernel_size`.
            padding: implicit zero padding to be added to both pooling operations.
            ceil_mode: when True, will use ceil instead of floor to compute the output shape.
        """
        super().__init__()
        _params = {
            "kernel_size": ensure_tuple_rep(kernel_size, spatial_dims),
            "stride": None if stride is None else ensure_tuple_rep(stride, spatial_dims),
            "padding": ensure_tuple_rep(padding, spatial_dims),
            "ceil_mode": ceil_mode,
        }
        self.max_pool = Pool[Pool.MAX, spatial_dims](**_params)
        self.avg_pool = Pool[Pool.AVG, spatial_dims](**_params)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...]).

        Returns:
            Tensor in shape (batch, 2*channel, spatial_1[, spatial_2, ...]).
        """
        return torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1)
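With the default stride equal to ``kernel_size``, the block halves each spatial dimension while doubling the channels (a minimal sketch; sizes are illustrative):

.. code-block:: python

    import torch

    from monai.networks.blocks import MaxAvgPool

    pool = MaxAvgPool(spatial_dims=2, kernel_size=2)
    x = torch.randn(2, 3, 32, 32)
    print(pool(x).shape)  # torch.Size([2, 6, 16, 16]): max- and avg-pooled maps concatenated
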
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/dynunet_block.py
ADDED
@@ -0,0 +1,288 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.utils import get_act_layer, get_norm_layer


class UnetResBlock(nn.Module):
    """
    A skip-connection based module that can be used for DynUNet, based on:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.

    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
    ):
        super().__init__()
        self.conv1 = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dropout=dropout,
            conv_only=True,
        )
        self.conv2 = get_conv_layer(
            spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True
        )
        self.lrelu = get_act_layer(name=act_name)
        self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
        self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
        self.downsample = in_channels != out_channels
        stride_np = np.atleast_1d(stride)
        if not np.all(stride_np == 1):
            self.downsample = True
        if self.downsample:
            self.conv3 = get_conv_layer(
                spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True
            )
            self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)

    def forward(self, inp):
        residual = inp
        out = self.conv1(inp)
        out = self.norm1(out)
        out = self.lrelu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        if hasattr(self, "conv3"):
            residual = self.conv3(residual)
        if hasattr(self, "norm3"):
            residual = self.norm3(residual)
        out += residual
        out = self.lrelu(out)
        return out


class UnetBasicBlock(nn.Module):
    """
    A CNN module that can be used for DynUNet, based on:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.

    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
    ):
        super().__init__()
        self.conv1 = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dropout=dropout,
            conv_only=True,
        )
        self.conv2 = get_conv_layer(
            spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True
        )
        self.lrelu = get_act_layer(name=act_name)
        self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
        self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)

    def forward(self, inp):
        out = self.conv1(inp)
        out = self.norm1(out)
        out = self.lrelu(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.lrelu(out)
        return out


class UnetUpBlock(nn.Module):
    """
    An upsampling module that can be used for DynUNet, based on:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        upsample_kernel_size: convolution kernel size for transposed convolution layers.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.
        trans_bias: transposed convolution bias.

    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
        trans_bias: bool = False,
    ):
        super().__init__()
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            dropout=dropout,
            bias=trans_bias,
            conv_only=True,
            is_transposed=True,
        )
        self.conv_block = UnetBasicBlock(
            spatial_dims,
            out_channels + out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            dropout=dropout,
            norm_name=norm_name,
            act_name=act_name,
        )

    def forward(self, inp, skip):
        # number of channels for skip should equal out_channels
        out = self.transp_conv(inp)
        out = torch.cat((out, skip), dim=1)
        out = self.conv_block(out)
        return out


class UnetOutBlock(nn.Module):
    def __init__(
        self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None
    ):
        super().__init__()
        self.conv = get_conv_layer(
            spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True
        )

    def forward(self, inp):
        return self.conv(inp)


def get_conv_layer(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int] = 3,
    stride: Union[Sequence[int], int] = 1,
    act: Optional[Union[Tuple, str]] = Act.PRELU,
    norm: Union[Tuple, str] = Norm.INSTANCE,
    dropout: Optional[Union[Tuple, str, float]] = None,
    bias: bool = False,
    conv_only: bool = True,
    is_transposed: bool = False,
):
    padding = get_padding(kernel_size, stride)
    output_padding = None
    if is_transposed:
        output_padding = get_output_padding(kernel_size, stride, padding)
    return Convolution(
        spatial_dims,
        in_channels,
        out_channels,
        strides=stride,
        kernel_size=kernel_size,
        act=act,
        norm=norm,
        dropout=dropout,
        bias=bias,
        conv_only=conv_only,
        is_transposed=is_transposed,
        padding=padding,
        output_padding=output_padding,
    )


def get_padding(
    kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int]
) -> Union[Tuple[int, ...], int]:

    kernel_size_np = np.atleast_1d(kernel_size)
    stride_np = np.atleast_1d(stride)
    padding_np = (kernel_size_np - stride_np + 1) / 2
    if np.min(padding_np) < 0:
        raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.")
    padding = tuple(int(p) for p in padding_np)

    return padding if len(padding) > 1 else padding[0]


def get_output_padding(
    kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int]
) -> Union[Tuple[int, ...], int]:
    kernel_size_np = np.atleast_1d(kernel_size)
    stride_np = np.atleast_1d(stride)
    padding_np = np.atleast_1d(padding)

    out_padding_np = 2 * padding_np + stride_np - kernel_size_np
    if np.min(out_padding_np) < 0:
        raise AssertionError("out_padding value should not be negative, please change the kernel size and/or stride.")
    out_padding = tuple(int(p) for p in out_padding_np)

    return out_padding if len(out_padding) > 1 else out_padding[0]
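A small worked check of the two padding helpers above, plus the decoder block they feed (a sketch; the shapes are illustrative assumptions):

.. code-block:: python

    import torch

    from monai.networks.blocks.dynunet_block import UnetUpBlock, get_output_padding, get_padding

    k, s = 3, 2
    p = get_padding(k, s)             # (3 - 2 + 1) / 2 = 1
    op = get_output_padding(k, s, p)  # 2*1 + 2 - 3 = 1
    print(p, op)  # 1 1: a stride-2 transposed conv with these values exactly doubles even sizes

    up = UnetUpBlock(
        spatial_dims=2,
        in_channels=32,
        out_channels=16,
        kernel_size=3,
        stride=2,
        upsample_kernel_size=2,
        norm_name="instance",
    )
    low = torch.randn(1, 32, 8, 8)     # decoder input
    skip = torch.randn(1, 16, 16, 16)  # encoder skip; its channels must equal out_channels
    print(up(low, skip).shape)         # torch.Size([1, 16, 16, 16])
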
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/fcn.py
ADDED
@@ -0,0 +1,242 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.upsample import UpSample
from monai.networks.layers.factories import Act, Conv, Norm
from monai.utils import optional_import

models, _ = optional_import("torchvision", name="models")


class GCN(nn.Module):
    """
    The Global Convolutional Network module using large 1D
    Kx1 and 1xK kernels to represent 2D kernels.
    """

    def __init__(self, inplanes: int, planes: int, ks: int = 7):
        """
        Args:
            inplanes: number of input channels.
            planes: number of output channels.
            ks: kernel size for one dimension. Defaults to 7.
        """
        super().__init__()

        conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2]
        self.conv_l1 = conv2d_type(in_channels=inplanes, out_channels=planes, kernel_size=(ks, 1), padding=(ks // 2, 0))
        self.conv_l2 = conv2d_type(in_channels=planes, out_channels=planes, kernel_size=(1, ks), padding=(0, ks // 2))
        self.conv_r1 = conv2d_type(in_channels=inplanes, out_channels=planes, kernel_size=(1, ks), padding=(0, ks // 2))
        self.conv_r2 = conv2d_type(in_channels=planes, out_channels=planes, kernel_size=(ks, 1), padding=(ks // 2, 0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: in shape (batch, inplanes, spatial_1, spatial_2).
        """
        x_l = self.conv_l1(x)
        x_l = self.conv_l2(x_l)
        x_r = self.conv_r1(x)
        x_r = self.conv_r2(x_r)
        x = x_l + x_r
        return x


class Refine(nn.Module):
    """
    Simple residual block to refine the details of the activation maps.
    """

    def __init__(self, planes: int):
        """
        Args:
            planes: number of input channels.
        """
        super().__init__()

        relu_type: Type[nn.ReLU] = Act[Act.RELU]
        conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2]
        norm2d_type: Type[nn.BatchNorm2d] = Norm[Norm.BATCH, 2]

        self.bn = norm2d_type(num_features=planes)
        self.relu = relu_type(inplace=True)
        self.conv1 = conv2d_type(in_channels=planes, out_channels=planes, kernel_size=3, padding=1)
        self.conv2 = conv2d_type(in_channels=planes, out_channels=planes, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: in shape (batch, planes, spatial_1, spatial_2).
        """
        residual = x
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv1(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv2(x)

        return residual + x


class FCN(nn.Module):
    """
    2D FCN network with 3 input channels. The small decoder is built
    with the GCN and Refine modules.
    The code is adapted from `lsqshr's official 2D code <https://github.com/lsqshr/AH-Net/blob/master/net2d.py>`_.

    Args:
        out_channels: number of output channels. Defaults to 1.
        upsample_mode: [``"transpose"``, ``"bilinear"``]
            The mode of upsampling manipulations.
            Using the second mode cannot guarantee the model's reproducibility. Defaults to ``bilinear``.

            - ``transpose``, uses transposed convolution layers.
            - ``bilinear``, uses bilinear interpolation.

        pretrained: If True, returns a model pre-trained on ImageNet
        progress: If True, displays a progress bar of the download to stderr.
    """

    def __init__(
        self, out_channels: int = 1, upsample_mode: str = "bilinear", pretrained: bool = True, progress: bool = True
    ):
        super().__init__()

        conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2]

        self.upsample_mode = upsample_mode
        self.conv2d_type = conv2d_type
        self.out_channels = out_channels
        resnet = models.resnet50(pretrained=pretrained, progress=progress)

        self.conv1 = resnet.conv1
        self.bn0 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        self.gcn1 = GCN(2048, self.out_channels)
        self.gcn2 = GCN(1024, self.out_channels)
        self.gcn3 = GCN(512, self.out_channels)
        self.gcn4 = GCN(64, self.out_channels)
        self.gcn5 = GCN(64, self.out_channels)

        self.refine1 = Refine(self.out_channels)
        self.refine2 = Refine(self.out_channels)
        self.refine3 = Refine(self.out_channels)
        self.refine4 = Refine(self.out_channels)
        self.refine5 = Refine(self.out_channels)
        self.refine6 = Refine(self.out_channels)
        self.refine7 = Refine(self.out_channels)
        self.refine8 = Refine(self.out_channels)
        self.refine9 = Refine(self.out_channels)
        self.refine10 = Refine(self.out_channels)
        self.transformer = self.conv2d_type(in_channels=256, out_channels=64, kernel_size=1)

        if self.upsample_mode == "transpose":
            self.up_conv = UpSample(spatial_dims=2, in_channels=self.out_channels, scale_factor=2, mode="deconv")

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: in shape (batch, 3, spatial_1, spatial_2).
        """
        org_input = x
        x = self.conv1(x)
        x = self.bn0(x)
        x = self.relu(x)
        conv_x = x
        x = self.maxpool(x)
        pool_x = x

        fm1 = self.layer1(x)
        fm2 = self.layer2(fm1)
        fm3 = self.layer3(fm2)
        fm4 = self.layer4(fm3)

        gcfm1 = self.refine1(self.gcn1(fm4))
        gcfm2 = self.refine2(self.gcn2(fm3))
        gcfm3 = self.refine3(self.gcn3(fm2))
        gcfm4 = self.refine4(self.gcn4(pool_x))
        gcfm5 = self.refine5(self.gcn5(conv_x))

        if self.upsample_mode == "transpose":
            fs1 = self.refine6(self.up_conv(gcfm1) + gcfm2)
            fs2 = self.refine7(self.up_conv(fs1) + gcfm3)
            fs3 = self.refine8(self.up_conv(fs2) + gcfm4)
            fs4 = self.refine9(self.up_conv(fs3) + gcfm5)
            return self.refine10(self.up_conv(fs4))
        fs1 = self.refine6(F.interpolate(gcfm1, fm3.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm2)
        fs2 = self.refine7(F.interpolate(fs1, fm2.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm3)
        fs3 = self.refine8(F.interpolate(fs2, pool_x.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm4)
        fs4 = self.refine9(F.interpolate(fs3, conv_x.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm5)
        return self.refine10(F.interpolate(fs4, org_input.size()[2:], mode=self.upsample_mode, align_corners=True))


class MCFCN(FCN):
    """
    The multi-channel version of the 2D FCN module.
    Adds a projection layer to take arbitrary number of inputs.

    Args:
        in_channels: number of input channels. Defaults to 3.
        out_channels: number of output channels. Defaults to 1.
        upsample_mode: [``"transpose"``, ``"bilinear"``]
            The mode of upsampling manipulations.
            Using the second mode cannot guarantee the model's reproducibility. Defaults to ``bilinear``.

            - ``transpose``, uses transposed convolution layers.
            - ``bilinear``, uses bilinear interpolation.
        pretrained: If True, returns a model pre-trained on ImageNet
        progress: If True, displays a progress bar of the download to stderr.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 1,
        upsample_mode: str = "bilinear",
        pretrained: bool = True,
        progress: bool = True,
    ):
        super().__init__(
            out_channels=out_channels, upsample_mode=upsample_mode, pretrained=pretrained, progress=progress
        )

        self.init_proj = Convolution(
            spatial_dims=2,
            in_channels=in_channels,
            out_channels=3,
            kernel_size=1,
            act=("relu", {"inplace": True}),
            norm=Norm.BATCH,
            bias=False,
        )

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: in shape (batch, in_channels, spatial_1, spatial_2).
        """
        x = self.init_proj(x)
        return super().forward(x)
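The decoder above stacks these two pieces repeatedly: GCN approximates a large ``ks x ks`` 2D kernel with two separable branches (``ks x 1`` then ``1 x ks``, and the reverse), keeping the parameter count linear in ``ks``, while Refine is a shape-preserving residual cleanup. A quick check (a sketch; sizes are illustrative assumptions):

.. code-block:: python

    import torch

    from monai.networks.blocks.fcn import GCN, Refine

    x = torch.randn(1, 64, 32, 32)
    gcn = GCN(inplanes=64, planes=4, ks=7)  # 7x7 effective receptive field from 1D kernels
    refine = Refine(planes=4)               # residual block, keeps the (B, 4, 32, 32) shape
    print(refine(gcn(x)).shape)             # torch.Size([1, 4, 32, 32])
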
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/feature_pyramid_network.py
ADDED
@@ -0,0 +1,263 @@
| 1 |
+
# Copyright (c) MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
# =========================================================================
|
| 13 |
+
# Adapted from https://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py
|
| 14 |
+
# which has the following license...
|
| 15 |
+
# https://github.com/pytorch/vision/blob/main/LICENSE
|
| 16 |
+
#
|
| 17 |
+
# BSD 3-Clause License
|
| 18 |
+
|
| 19 |
+
# Copyright (c) Soumith Chintala 2016,
|
| 20 |
+
# All rights reserved.
|
| 21 |
+
|
| 22 |
+
# Redistribution and use in source and binary forms, with or without
|
| 23 |
+
# modification, are permitted provided that the following conditions are met:
|
| 24 |
+
|
| 25 |
+
# * Redistributions of source code must retain the above copyright notice, this
|
| 26 |
+
# list of conditions and the following disclaimer.
|
| 27 |
+
|
| 28 |
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
| 29 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 30 |
+
# and/or other materials provided with the distribution.
|
| 31 |
+
|
| 32 |
+
# * Neither the name of the copyright holder nor the names of its
|
| 33 |
+
# contributors may be used to endorse or promote products derived from
|
| 34 |
+
# this software without specific prior written permission.
|
| 35 |
+
|
| 36 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 37 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 38 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 39 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 40 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 41 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 42 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 43 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 44 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 45 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 46 |
+
|
| 47 |
+
"""
|
| 48 |
+
This script is modified from from torchvision to support N-D images,
|
| 49 |
+
by overriding the definition of convolutional layers and pooling layers.
|
| 50 |
+
|
| 51 |
+
https://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
from collections import OrderedDict
|
| 55 |
+
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
| 56 |
+
|
| 57 |
+
import torch.nn.functional as F
|
| 58 |
+
from torch import Tensor, nn
|
| 59 |
+
|
| 60 |
+
from monai.networks.layers.factories import Conv, Pool
|
| 61 |
+
|
| 62 |
+
__all__ = ["ExtraFPNBlock", "LastLevelMaxPool", "LastLevelP6P7", "FeaturePyramidNetwork"]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class ExtraFPNBlock(nn.Module):
    """
    Base class for the extra block in the FPN.

    Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py
    """

    def forward(self, results: List[Tensor], x: List[Tensor], names: List[str]) -> Tuple[List[Tensor], List[str]]:
        """
        Compute extended set of results of the FPN and their names.

        Args:
            results: the result of the FPN
            x: the original feature maps
            names: the names for each one of the original feature maps

        Returns:
            - the extended set of results of the FPN
            - the extended set of names for the results
        """
        pass


class LastLevelMaxPool(ExtraFPNBlock):
    """
    Applies a max_pool2d or max_pool3d on top of the last feature map. Serves as an ``extra_blocks``
    in :class:`~monai.networks.blocks.feature_pyramid_network.FeaturePyramidNetwork` .
    """

    def __init__(self, spatial_dims: int):
        super().__init__()
        pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims]
        self.maxpool = pool_type(kernel_size=1, stride=2, padding=0)

    def forward(self, results: List[Tensor], x: List[Tensor], names: List[str]) -> Tuple[List[Tensor], List[str]]:
        names.append("pool")
        results.append(self.maxpool(results[-1]))
        return results, names


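# Editor's note: a minimal sketch (not part of the original file) of what the
# extra pooling level produces; with kernel_size=1 and stride=2 the last FPN
# output is subsampled, so each spatial dimension is halved (rounding up).
# Assumes ``import torch`` alongside the imports above:
#
#     >>> pool = LastLevelMaxPool(spatial_dims=2)
#     >>> feats = [torch.rand(1, 5, 8, 8)]
#     >>> results, names = pool(feats, feats, ["feat3"])
#     >>> results[-1].shape, names
#     (torch.Size([1, 5, 4, 4]), ['feat3', 'pool'])
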
class LastLevelP6P7(ExtraFPNBlock):
|
| 106 |
+
"""
|
| 107 |
+
This module is used in RetinaNet to generate extra layers, P6 and P7.
|
| 108 |
+
Serves as an ``extra_blocks``
|
| 109 |
+
in :class:`~monai.networks.blocks.feature_pyramid_network.FeaturePyramidNetwork` .
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int):
|
| 113 |
+
super().__init__()
|
| 114 |
+
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
|
| 115 |
+
self.p6 = conv_type(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
|
| 116 |
+
self.p7 = conv_type(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
|
| 117 |
+
for module in [self.p6, self.p7]:
|
| 118 |
+
nn.init.kaiming_uniform_(module.weight, a=1)
|
| 119 |
+
nn.init.constant_(module.bias, 0)
|
| 120 |
+
self.use_P5 = in_channels == out_channels
|
| 121 |
+
|
| 122 |
+
def forward(self, results: List[Tensor], x: List[Tensor], names: List[str]) -> Tuple[List[Tensor], List[str]]:
|
| 123 |
+
p5, c5 = results[-1], x[-1]
|
| 124 |
+
x5 = p5 if self.use_P5 else c5
|
| 125 |
+
p6 = self.p6(x5)
|
| 126 |
+
p7 = self.p7(F.relu(p6))
|
| 127 |
+
results.extend([p6, p7])
|
| 128 |
+
names.extend(["p6", "p7"])
|
| 129 |
+
return results, names
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
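# Editor's note: a brief sketch (not part of the original file) of the P6/P7
# shapes; each extra level is a stride-2, kernel-3, padding-1 convolution, so
# the spatial size of P5 is halved once for P6 and again for P7:
#
#     >>> block = LastLevelP6P7(spatial_dims=2, in_channels=5, out_channels=5)
#     >>> p5 = [torch.rand(1, 5, 8, 8)]
#     >>> results, names = block(p5, p5, ["p5"])
#     >>> [r.shape[-1] for r in results]  # [p5, p6, p7]
#     [8, 4, 2]
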
class FeaturePyramidNetwork(nn.Module):
    """
    Module that adds an FPN on top of a set of feature maps. This is based on
    `"Feature Pyramid Networks for Object Detection" <https://arxiv.org/abs/1612.03144>`_.

    The feature maps are currently supposed to be in increasing depth order.

    The input to the model is expected to be an OrderedDict[Tensor], containing
    the feature maps on top of which the FPN will be added.

    Args:
        spatial_dims: 2D or 3D images
        in_channels_list: number of channels for each feature map that
            is passed to the module
        out_channels: number of channels of the FPN representation
        extra_blocks: if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and to return
            a new list of feature maps and their corresponding names

    Examples::

        >>> m = FeaturePyramidNetwork(2, [10, 20, 30], 5)
        >>> # get some dummy data
        >>> x = OrderedDict()
        >>> x['feat0'] = torch.rand(1, 10, 64, 64)
        >>> x['feat2'] = torch.rand(1, 20, 16, 16)
        >>> x['feat3'] = torch.rand(1, 30, 8, 8)
        >>> # compute the FPN on top of x
        >>> output = m(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>> [('feat0', torch.Size([1, 5, 64, 64])),
        >>>  ('feat2', torch.Size([1, 5, 16, 16])),
        >>>  ('feat3', torch.Size([1, 5, 8, 8]))]

    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
    ):
        super().__init__()

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]

        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        for in_channels in in_channels_list:
            if in_channels == 0:
                raise ValueError("in_channels=0 is currently not supported")
            inner_block_module = conv_type(in_channels, out_channels, 1)
            layer_block_module = conv_type(out_channels, out_channels, 3, padding=1)
            self.inner_blocks.append(inner_block_module)
            self.layer_blocks.append(layer_block_module)

        # initialize parameters now to avoid modifying the initialization of top_blocks
        conv_type_: Type[nn.Module] = Conv[Conv.CONV, spatial_dims]
        for m in self.modules():
            if isinstance(m, conv_type_):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0.0)  # type: ignore

        if extra_blocks is not None:
            if not isinstance(extra_blocks, ExtraFPNBlock):
                raise AssertionError
        self.extra_blocks = extra_blocks

    def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
        """
        This is equivalent to self.inner_blocks[idx](x),
        but torchscript doesn't support this yet
        """
        num_blocks = len(self.inner_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.inner_blocks):
            if i == idx:
                out = module(x)
        return out

    def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
        """
        This is equivalent to self.layer_blocks[idx](x),
        but torchscript doesn't support this yet
        """
        num_blocks = len(self.layer_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.layer_blocks):
            if i == idx:
                out = module(x)
        return out

    def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Computes the FPN for a set of feature maps.

        Args:
            x: feature maps for each feature level.

        Returns:
            feature maps after FPN layers. They are ordered from highest resolution first.
        """
        # unpack OrderedDict into two lists for easier handling
        names = list(x.keys())
        x_values: List[Tensor] = list(x.values())

        last_inner = self.get_result_from_inner_blocks(x_values[-1], -1)
        results = []
        results.append(self.get_result_from_layer_blocks(last_inner, -1))

        for idx in range(len(x_values) - 2, -1, -1):
            inner_lateral = self.get_result_from_inner_blocks(x_values[idx], idx)
            feat_shape = inner_lateral.shape[2:]
            inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
            last_inner = inner_lateral + inner_top_down
            results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, x_values, names)

        # make it back an OrderedDict
        out = OrderedDict([(k, v) for k, v in zip(names, results)])

        return out
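
# Editor's note: a minimal 3-D usage sketch (not part of the original file),
# combining the network with an extra pooling level; the names "c4"/"c5" and
# the shapes are illustrative only:
#
#     >>> fpn = FeaturePyramidNetwork(
#     ...     spatial_dims=3, in_channels_list=[16, 32], out_channels=8,
#     ...     extra_blocks=LastLevelMaxPool(spatial_dims=3),
#     ... )
#     >>> x = OrderedDict([("c4", torch.rand(1, 16, 16, 16, 16)), ("c5", torch.rand(1, 32, 8, 8, 8))])
#     >>> [(k, tuple(v.shape)) for k, v in fpn(x).items()]
#     [('c4', (1, 8, 16, 16, 16)), ('c5', (1, 8, 8, 8, 8)), ('pool', (1, 8, 4, 4, 4))]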
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/localnet_block.py
ADDED
@@ -0,0 +1,286 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Sequence, Tuple, Type, Union

import torch
from torch import nn
from torch.nn import functional as F

from monai.networks.blocks import Convolution
from monai.networks.layers import same_padding
from monai.networks.layers.factories import Conv, Norm, Pool


def get_conv_block(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int] = 3,
    act: Optional[Union[Tuple, str]] = "RELU",
    norm: Optional[Union[Tuple, str]] = "BATCH",
) -> nn.Module:
    padding = same_padding(kernel_size)
    mod: nn.Module = Convolution(
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        act=act,
        norm=norm,
        bias=False,
        conv_only=False,
        padding=padding,
    )
    return mod


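# Editor's note (not part of the original file): with ``conv_only=False`` the
# returned module is the convolution followed by the requested normalization
# and activation; ``same_padding`` keeps the spatial size unchanged for odd
# kernels, e.g. get_conv_block(2, 4, 8) maps (1, 4, 32, 32) to (1, 8, 32, 32).
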
def get_conv_layer(
    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int] = 3
) -> nn.Module:
    padding = same_padding(kernel_size)
    mod: nn.Module = Convolution(
        spatial_dims, in_channels, out_channels, kernel_size=kernel_size, bias=False, conv_only=True, padding=padding
    )
    return mod


def get_deconv_block(spatial_dims: int, in_channels: int, out_channels: int) -> nn.Module:
    mod: nn.Module = Convolution(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        strides=2,
        act="RELU",
        norm="BATCH",
        bias=False,
        is_transposed=True,
        padding=1,
        output_padding=1,
    )
    return mod


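# Editor's note (not part of the original file): assuming Convolution's default
# kernel size of 3, the transposed convolution above doubles each spatial
# dimension, since
#     out = (in - 1) * stride - 2 * padding + kernel + output_padding
#         = (in - 1) * 2 - 2 + 3 + 1 = 2 * in.
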
class ResidualBlock(nn.Module):
    def __init__(
        self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int]
    ) -> None:
        super().__init__()
        if in_channels != out_channels:
            raise ValueError(
                f"expecting in_channels == out_channels, got in_channels={in_channels}, out_channels={out_channels}"
            )
        self.conv_block = get_conv_block(
            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size
        )
        self.conv = get_conv_layer(
            spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size
        )
        self.norm = Norm[Norm.BATCH, spatial_dims](out_channels)
        self.relu = nn.ReLU()

    def forward(self, x) -> torch.Tensor:
        out: torch.Tensor = self.relu(self.norm(self.conv(self.conv_block(x))) + x)
        return out


class LocalNetResidualBlock(nn.Module):
    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
        super().__init__()
        if in_channels != out_channels:
            raise ValueError(
                f"expecting in_channels == out_channels, got in_channels={in_channels}, out_channels={out_channels}"
            )
        self.conv_layer = get_conv_layer(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels)
        self.norm = Norm[Norm.BATCH, spatial_dims](out_channels)
        self.relu = nn.ReLU()

    def forward(self, x, mid) -> torch.Tensor:
        out: torch.Tensor = self.relu(self.norm(self.conv_layer(x)) + mid)
        return out


class LocalNetDownSampleBlock(nn.Module):
    """
    A down-sample module that can be used for LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Union[Sequence[int], int]
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
        Raises:
            NotImplementedError: when ``kernel_size`` is even
        """
        super().__init__()
        self.conv_block = get_conv_block(
            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size
        )
        self.residual_block = ResidualBlock(
            spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size
        )
        self.max_pool = Pool[Pool.MAX, spatial_dims](kernel_size=2)

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Halves the spatial dimensions.
        A tuple of (x, mid) is returned:

            - x is the downsample result, in shape (batch, ``out_channels``, insize_1 / 2, insize_2 / 2, [insize_3 / 2]),
            - mid is the mid-level feature, in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3])

        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])

        Raises:
            ValueError: when input spatial dimensions are not even.
        """
        for i in x.shape[2:]:
            if i % 2 != 0:
                raise ValueError(f"expecting x spatial dimensions to be even, got x of shape {x.shape}")
        x = self.conv_block(x)
        mid = self.residual_block(x)
        x = self.max_pool(mid)
        return x, mid


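# Editor's note: a minimal usage sketch (not part of the original file); the
# shapes below are illustrative, assuming a 2-D input with even spatial sizes:
#
#     >>> down = LocalNetDownSampleBlock(spatial_dims=2, in_channels=1, out_channels=4, kernel_size=3)
#     >>> y, mid = down(torch.rand(1, 1, 32, 32))
#     >>> tuple(y.shape), tuple(mid.shape)
#     ((1, 4, 16, 16), (1, 4, 32, 32))
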
class LocalNetUpSampleBlock(nn.Module):
    """
    An up-sample module that can be used for LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
        Raises:
            ValueError: when ``in_channels != 2 * out_channels``
        """
        super().__init__()
        self.deconv_block = get_deconv_block(
            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels
        )
        self.conv_block = get_conv_block(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels)
        self.residual_block = LocalNetResidualBlock(
            spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels
        )
        if in_channels / out_channels != 2:
            raise ValueError(
                f"expecting in_channels == 2 * out_channels, "
                f"got in_channels={in_channels}, out_channels={out_channels}"
            )
        self.out_channels = out_channels

    def addictive_upsampling(self, x, mid) -> torch.Tensor:
        # note: "addictive" preserves the upstream spelling; this implements additive
        # upsampling: interpolate x to mid's spatial size, then sum the channel-split halves
        x = F.interpolate(x, mid.shape[2:])
        # [(batch, out_channels, ...), (batch, out_channels, ...)]
        x = x.split(split_size=int(self.out_channels), dim=1)
        # (batch, out_channels, ...)
        out: torch.Tensor = torch.sum(torch.stack(x, dim=-1), dim=-1)
        return out

    def forward(self, x, mid) -> torch.Tensor:
        """
        Halves the channel and doubles the spatial dimensions.

        Args:
            x: feature to be up-sampled, in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
            mid: mid-level feature saved during down-sampling,
                in shape (batch, ``out_channels``, midsize_1, midsize_2, [midsize_3])

        Raises:
            ValueError: when ``midsize != insize * 2``
        """
        for i, j in zip(x.shape[2:], mid.shape[2:]):
            if j != 2 * i:
                raise ValueError(
                    "expecting mid spatial dimensions to be exactly double the x spatial dimensions, "
                    f"got x of shape {x.shape}, mid of shape {mid.shape}"
                )
        h0 = self.deconv_block(x) + self.addictive_upsampling(x, mid)
        r1 = h0 + mid
        r2 = self.conv_block(h0)
        out: torch.Tensor = self.residual_block(r2, r1)
        return out


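# Editor's note: a minimal usage sketch (not part of the original file); the
# up-sample block halves the channels and doubles the spatial dimensions:
#
#     >>> up = LocalNetUpSampleBlock(spatial_dims=2, in_channels=8, out_channels=4)
#     >>> out = up(torch.rand(1, 8, 16, 16), mid=torch.rand(1, 4, 32, 32))
#     >>> tuple(out.shape)
#     (1, 4, 32, 32)
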
class LocalNetFeatureExtractorBlock(nn.Module):
    """
    A feature-extraction module that can be used for LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        act: Optional[Union[Tuple, str]] = "RELU",
        initializer: str = "kaiming_uniform",
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            act: activation type and arguments. Defaults to ReLU.
            initializer: weight initializer, ``"kaiming_uniform"`` or ``"zeros"``. Defaults to ``"kaiming_uniform"``.
        """
        super().__init__()
        self.conv_block = get_conv_block(
            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None
        )
        conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
        for m in self.conv_block.modules():
            if isinstance(m, conv_type):
                if initializer == "kaiming_uniform":
                    nn.init.kaiming_normal_(torch.as_tensor(m.weight))
                elif initializer == "zeros":
                    nn.init.zeros_(torch.as_tensor(m.weight))
                else:
                    raise ValueError(
                        f"initializer {initializer} is not supported, currently supporting kaiming_uniform and zeros"
                    )

    def forward(self, x) -> torch.Tensor:
        """
        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
        """
        out: torch.Tensor = self.conv_block(x)
        return out
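
# Editor's note: a minimal usage sketch (not part of the original file); the
# feature extractor keeps the spatial dimensions and only changes channels:
#
#     >>> extract = LocalNetFeatureExtractorBlock(spatial_dims=3, in_channels=4, out_channels=2)
#     >>> tuple(extract(torch.rand(1, 4, 8, 8, 8)).shape)
#     (1, 2, 8, 8, 8)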