# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
from glob import glob
import nibabel as nib
import numpy as np
import torch
from ignite.metrics import Accuracy
import monai
from monai.data import create_test_image_3d
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
CheckpointSaver,
LrScheduleHandler,
MeanDice,
StatsHandler,
TensorBoardImageHandler,
TensorBoardStatsHandler,
ValidationHandler,
)
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.transforms import (
Activationsd,
AsChannelFirstd,
AsDiscreted,
Compose,
KeepLargestConnectedComponentd,
LoadNiftid,
RandCropByPosNegLabeld,
RandRotate90d,
ScaleIntensityd,
ToTensord,
)
def _generate_synthetic_data(tempdir, num_pairs=40):
    """Write ``num_pairs`` random 3D image/segmentation pairs into ``tempdir``.

    Each pair is saved as ``img{i}.nii.gz`` / ``seg{i}.nii.gz`` with an identity affine.
    Volumes are 128x128x128 and carry a trailing channel axis (``channel_dim=-1``).
    """
    for i in range(num_pairs):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
        nib.save(nib.Nifti1Image(im, np.eye(4)), os.path.join(tempdir, f"img{i:d}.nii.gz"))
        nib.save(nib.Nifti1Image(seg, np.eye(4)), os.path.join(tempdir, f"seg{i:d}.nii.gz"))


def _make_post_transforms():
    """Build the prediction post-processing chain used by both trainer and evaluator.

    Applies a sigmoid to ``pred``, thresholds it, then keeps only the largest connected
    component of label 1. A fresh ``Compose`` is returned on every call so the trainer
    and the evaluator each get their own instance.
    """
    return Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True),
            KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
        ]
    )


def main(tempdir):
    """Train and validate a 3D UNet on synthetic data with MONAI's ignite-based engines.

    Generates 40 synthetic image/label pairs in ``tempdir``. It uses the first 20 (in
    sorted filename order) for training and the last 20 for validation. It then runs a
    ``SupervisedTrainer`` for 5 epochs and triggers a ``SupervisedEvaluator`` every 2
    epochs. TensorBoard logs and checkpoints go to ``./runs/`` relative to the current
    working directory, not to ``tempdir``.

    Args:
        tempdir: existing directory that receives the generated NIfTI files. The caller
            owns it and is responsible for cleaning it up.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # populate the caller-provided directory with 40 random image/mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    _generate_synthetic_data(tempdir, num_pairs=40)
    # images and segs are sorted the same (lexicographic) way, so img{i} pairs with seg{i}
    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{"image": img, "label": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"image": img, "label": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            ToTensord(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # create a training data loader
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
    # create a validation data loader
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss = monai.losses.DiceLoss(sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

    # torch.cuda.amp needs PyTorch >= 1.6 and only takes effect on a CUDA device, so AMP
    # is enabled for both training and evaluation only when both conditions hold
    amp = device.type == "cuda" and monai.config.get_torch_version_tuple() >= (1, 6)
    log_dir = "./runs/"

    val_handlers = [
        # output_transform returns None: no per-iteration output is logged, only metrics
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: None),
        TensorBoardImageHandler(
            log_dir=log_dir,
            batch_transform=lambda x: (x["image"], x["label"]),
            output_transform=lambda x: x["pred"],
        ),
        # keep the network weights of the best validation key metric (val_mean_dice)
        CheckpointSaver(save_dir=log_dir, save_dict={"net": net}, save_key_metric=True),
    ]
    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        # window size matches the 96^3 training crops
        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
        post_transform=_make_post_transforms(),
        key_val_metric={
            "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"]))
        },
        additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        val_handlers=val_handlers,
        amp=amp,
    )

    train_handlers = [
        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        # run the evaluator every 2 training epochs
        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(log_dir=log_dir, tag_name="train_loss", output_transform=lambda x: x["loss"]),
        # periodic checkpoint of network and optimizer state every 2 epochs
        CheckpointSaver(save_dir=log_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
    ]
    trainer = SupervisedTrainer(
        device=device,
        max_epochs=5,
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=_make_post_transforms(),
        key_train_metric={"train_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))},
        train_handlers=train_handlers,
        amp=amp,
    )
    trainer.run()
if __name__ == "__main__":
    # Run the example in a scratch directory that TemporaryDirectory deletes on exit;
    # note that main() writes its logs/checkpoints to ./runs/, which is not removed.
    with tempfile.TemporaryDirectory() as scratch_dir:
        main(tempdir=scratch_dir)