diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..a785b0ca2e286ee29f99228542d49762420c821f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +doc/images/DLCS_1419_ann0_slice134_triple.png filter=lfs diff=lfs merge=lfs -text +doc/images/DLCS_1419_ann1_slice204_triple.png filter=lfs diff=lfs merge=lfs -text +doc/images/DLCS_1443_ann1_slice125_triple.png filter=lfs diff=lfs merge=lfs -text +doc/images/DLCS_1446_ann0_slice122_triple.png filter=lfs diff=lfs merge=lfs -text +doc/images/DLCS_1447_ann0_slice206_triple.png filter=lfs diff=lfs merge=lfs -text +doc/images/DLCS_1453_ann0_slice204_triple.png filter=lfs diff=lfs merge=lfs -text +doc/images/DLCS_1508_ann0_slice46_triple.png filter=lfs diff=lfs merge=lfs -text +doc/images/DLCS_1519_ann3_slice155_triple.png filter=lfs diff=lfs merge=lfs -text +doc/images/GanAI_fid_scatter_marker_legend.png filter=lfs diff=lfs merge=lfs -text +doc/images/NoMAISI_train_and_infer.png filter=lfs diff=lfs merge=lfs -text +doc/images/TaskCls.png filter=lfs diff=lfs merge=lfs -text +doc/images/workflow.png filter=lfs diff=lfs merge=lfs -text +NoMAISI_logo.png filter=lfs diff=lfs merge=lfs -text diff --git a/NoMAISI_logo.png b/NoMAISI_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..01e3f0b86680cdbb3fb179d64c6f45a8a43389f3 --- /dev/null +++ b/NoMAISI_logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59e28b561fa2a934150fa912146fc81f75aa8b526defd5c698c46cac09995c94 +size 185945 diff --git a/configs/config_maisi3d-rflow.json b/configs/config_maisi3d-rflow.json new file mode 100644 index 0000000000000000000000000000000000000000..1f2ee4619b464ba47e619191085214bcf042752e --- /dev/null +++ b/configs/config_maisi3d-rflow.json @@ -0,0 +1,150 @@ +{ + "spatial_dims": 3, + "image_channels": 1, + "latent_channels": 4, + "include_body_region": false, + "mask_generation_latent_shape": [ + 4, + 64, + 64, + 64 + ], + "autoencoder_def": { + "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", + "spatial_dims": "@spatial_dims", + "in_channels": "@image_channels", + "out_channels": "@image_channels", + "latent_channels": "@latent_channels", + "num_channels": [ + 64, + 128, + 256 + ], + "num_res_blocks": [2,2,2], + "norm_num_groups": 32, + "norm_eps": 1e-06, + "attention_levels": [ + false, + false, + false + ], + "with_encoder_nonlocal_attn": false, + "with_decoder_nonlocal_attn": false, + "use_checkpointing": false, + "use_convtranspose": false, + "norm_float16": true, + "num_splits": 4, + "dim_split": 1 + }, + "diffusion_unet_def": { + "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi", + "spatial_dims": "@spatial_dims", + "in_channels": "@latent_channels", + "out_channels": "@latent_channels", + "num_channels": [64, 128, 256, 512], + "attention_levels": [ + false, + false, + true, + true + ], + "num_head_channels": [ + 0, + 0, + 32, + 32 + ], + "num_res_blocks": 2, + "use_flash_attention": true, + "include_top_region_index_input": "@include_body_region", + "include_bottom_region_index_input": "@include_body_region", + "include_spacing_input": true, + "num_class_embeds": 128, + "resblock_updown": true, + "include_fc": true + }, + "controlnet_def": { + "_target_": 
"monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi", + "spatial_dims": "@spatial_dims", + "in_channels": "@latent_channels", + "num_channels": [64, 128, 256, 512], + "attention_levels": [ + false, + false, + true, + true + ], + "num_head_channels": [ + 0, + 0, + 32, + 32 + ], + "num_res_blocks": 2, + "use_flash_attention": true, + "conditioning_embedding_in_channels": 8, + "conditioning_embedding_num_channels": [8, 32, 64], + "num_class_embeds": 128, + "resblock_updown": true, + "include_fc": true + }, + "mask_generation_autoencoder_def": { + "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", + "spatial_dims": "@spatial_dims", + "in_channels": 8, + "out_channels": 125, + "latent_channels": "@latent_channels", + "num_channels": [ + 32, + 64, + 128 + ], + "num_res_blocks": [1, 2, 2], + "norm_num_groups": 32, + "norm_eps": 1e-06, + "attention_levels": [ + false, + false, + false + ], + "with_encoder_nonlocal_attn": false, + "with_decoder_nonlocal_attn": false, + "use_flash_attention": false, + "use_checkpointing": true, + "use_convtranspose": true, + "norm_float16": true, + "num_splits": 8, + "dim_split": 1 + }, + "mask_generation_diffusion_def": { + "_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet", + "spatial_dims": "@spatial_dims", + "in_channels": "@latent_channels", + "out_channels": "@latent_channels", + "channels":[64, 128, 256, 512], + "attention_levels":[false, false, true, true], + "num_head_channels":[0, 0, 32, 32], + "num_res_blocks": 2, + "use_flash_attention": true, + "with_conditioning": true, + "upcast_attention": true, + "cross_attention_dim": 10 + }, + "mask_generation_scale_factor": 1.0055984258651733, + "noise_scheduler": { + "_target_": "monai.networks.schedulers.rectified_flow.RFlowScheduler", + "num_train_timesteps": 1000, + "use_discrete_timesteps": false, + "use_timestep_transform": true, + "sample_method": "uniform", + "scale":1.4 + }, + "mask_generation_noise_scheduler": { + "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler", + "num_train_timesteps": 1000, + "beta_start": 0.0015, + "beta_end": 0.0195, + "schedule": "scaled_linear_beta", + "clip_sample": false + } +} diff --git a/configs/infr_config_NoMAISI_controlnet.json b/configs/infr_config_NoMAISI_controlnet.json new file mode 100644 index 0000000000000000000000000000000000000000..95f5dee643dfb75659cc11b0c8c77feaf1a5ff54 --- /dev/null +++ b/configs/infr_config_NoMAISI_controlnet.json @@ -0,0 +1,17 @@ +{ + "controlnet_train": { + "batch_size": 2, + "cache_rate": 0.0, + "fold": 1, + "lr": 1e-5, + "n_epochs": 500, + "weighted_loss_label": [23], + "weighted_loss": 100 + }, + "controlnet_infer": { + "num_inference_steps": 30, + "autoencoder_sliding_window_infer_size": [80, 80, 64], + "autoencoder_sliding_window_infer_overlap": 0.25, + "modality": 1 + } +} diff --git a/configs/infr_env_NoMAISI_DLCSD24_demo.json b/configs/infr_env_NoMAISI_DLCSD24_demo.json new file mode 100644 index 0000000000000000000000000000000000000000..68e7cee8ec6a678022ea5b27e25876eab7fde3c5 --- /dev/null +++ b/configs/infr_env_NoMAISI_DLCSD24_demo.json @@ -0,0 +1,11 @@ +{ + "model_dir": "./models/", + "output_dir": "./outputs/NoMAISI_DLCSD24_demo_512xy_256z_771p25m", + "tfevent_path": "./outputs/tfevent", + "trained_autoencoder_path": "./models/autoencoder.pt", + "trained_diffusion_path": "./models/diffusion_unet.pt", + "trained_controlnet_path": "./models/Experiments_NoMAISI_512xy_256z_771p25m_finetune_500epoch_best.pt", + "exp_name": 
"NoMAISI_DLCSD24_demo_512xy_256z_771p25m", + "data_base_dir": ["/home/ft42/NoMAISI/data"], + "json_data_list": ["/home/ft42/NoMAISI/data/infr_NoMAISI_DLCSD24_demo_512xy_256z_771p25m_dataset.json"] +} \ No newline at end of file diff --git a/data/DLCS_1419_seg_sh.nii.gz b/data/DLCS_1419_seg_sh.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..0cb4ba85a21e78694b4e4a67f6c462bed8fb4f9a --- /dev/null +++ b/data/DLCS_1419_seg_sh.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83da8dbf3b165023f3ffcec571fe5766177b65aabfa143f3a0bef5be41af757b +size 2265286 diff --git a/data/infr_NoMAISI_DLCSD24_demo_512xy_256z_771p25m_dataset.json b/data/infr_NoMAISI_DLCSD24_demo_512xy_256z_771p25m_dataset.json new file mode 100644 index 0000000000000000000000000000000000000000..889d87880e4f2709ad37e978082dc6910b6d7925 --- /dev/null +++ b/data/infr_NoMAISI_DLCSD24_demo_512xy_256z_771p25m_dataset.json @@ -0,0 +1,32 @@ +{ + "name": "NoMAISI_DLCSD24_demo_512xy_256z_771p25m", + "numTest": 1, + "testing": [ + { + + "label": "DLCS_1419_seg_sh.nii.gz", + "fold": 0, + "dim": [ + 512, + 512, + 256 + ], + "spacing": [ + 0.703125, + 0.703125, + 1.25 + ], + "top_region_index": [ + 0, + 1, + 0, + 0 + ], + "bottom_region_index": [ + 0, + 0, + 1, + 0 + ] + }] +} diff --git a/doc/images/DLCS_1419_ann0_slice134_triple.png b/doc/images/DLCS_1419_ann0_slice134_triple.png new file mode 100644 index 0000000000000000000000000000000000000000..18c489595953872db1bf883c30b027a38aa92e57 --- /dev/null +++ b/doc/images/DLCS_1419_ann0_slice134_triple.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9729e15104e9f3b6ae675f57bf7d5f9f1aec3e191a4d7a68209bde4a3d148363 +size 1236594 diff --git a/doc/images/DLCS_1419_ann1_slice204_triple.png b/doc/images/DLCS_1419_ann1_slice204_triple.png new file mode 100644 index 0000000000000000000000000000000000000000..6111e022c84b3b6621fec2f47f69703f84587bd8 --- /dev/null +++ b/doc/images/DLCS_1419_ann1_slice204_triple.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5bbcd3ddca8a3623f38764984fed7f9a36c92d8e2c98336c4b3e5e0aadb29e0a +size 1131619 diff --git a/doc/images/DLCS_1443_ann1_slice125_triple.png b/doc/images/DLCS_1443_ann1_slice125_triple.png new file mode 100644 index 0000000000000000000000000000000000000000..3167fcde06f5f1c6642785d7adad551d44a97895 --- /dev/null +++ b/doc/images/DLCS_1443_ann1_slice125_triple.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6336851a8174aeedd990f169f4dfa1ec8f2524adbbfd048f1d491ba0973ae72 +size 1112723 diff --git a/doc/images/DLCS_1446_ann0_slice122_triple.png b/doc/images/DLCS_1446_ann0_slice122_triple.png new file mode 100644 index 0000000000000000000000000000000000000000..42724a4d3d8e565b06c72d6880ddf44ee9fed33c --- /dev/null +++ b/doc/images/DLCS_1446_ann0_slice122_triple.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29706ad025325e95dd9ad6cc56e52ea9481866a23d97fb3033198f76a5b65a13 +size 954995 diff --git a/doc/images/DLCS_1447_ann0_slice206_triple.png b/doc/images/DLCS_1447_ann0_slice206_triple.png new file mode 100644 index 0000000000000000000000000000000000000000..2b577c8a01aaf8cff0c5f91ac3fb8bab2f7a5cab --- /dev/null +++ b/doc/images/DLCS_1447_ann0_slice206_triple.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad5d313eee8c53edb67c8240963c29c340f2e4456db8cdfc538f7c10fcbf7f2f +size 892809 diff --git a/doc/images/DLCS_1453_ann0_slice204_triple.png 
b/doc/images/DLCS_1453_ann0_slice204_triple.png new file mode 100644 index 0000000000000000000000000000000000000000..e730d48a8fbb82a2ec2dc4c40af6964152325e51 --- /dev/null +++ b/doc/images/DLCS_1453_ann0_slice204_triple.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c1cb674c92523eab8a008367b658561cfb94ebc3dfbc84b6f666d609097f2863 +size 1196470 diff --git a/doc/images/DLCS_1508_ann0_slice46_triple.png b/doc/images/DLCS_1508_ann0_slice46_triple.png new file mode 100644 index 0000000000000000000000000000000000000000..e924cbb6217b99f26147bfe30c4b60ef4777b01b --- /dev/null +++ b/doc/images/DLCS_1508_ann0_slice46_triple.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d3f245b13e4495d01e8585058239c02f2cbc17b72557d8306b58bce23747334 +size 1642446 diff --git a/doc/images/DLCS_1519_ann3_slice155_triple.png b/doc/images/DLCS_1519_ann3_slice155_triple.png new file mode 100644 index 0000000000000000000000000000000000000000..123b31591b1024d2e6d8b12b55669a01fe70511b --- /dev/null +++ b/doc/images/DLCS_1519_ann3_slice155_triple.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0a7db06ba28e1412d546d2ef917c50f04dd2cff9f06c5d83d611c406185fd13 +size 1362753 diff --git a/doc/images/GanAI_fid_scatter_marker_legend.png b/doc/images/GanAI_fid_scatter_marker_legend.png new file mode 100644 index 0000000000000000000000000000000000000000..c26e0d7cce9e42fe26cc2a1896a49d8d19efc097 --- /dev/null +++ b/doc/images/GanAI_fid_scatter_marker_legend.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60c1e2e2be297fd13de2600aa2559c853db277d9ef3238da7a166c1e3472a237 +size 179343 diff --git a/doc/images/NoMAISI_train_and_infer.png b/doc/images/NoMAISI_train_and_infer.png new file mode 100644 index 0000000000000000000000000000000000000000..555b5449fc9247410b996ae47353d1718ab4d92b --- /dev/null +++ b/doc/images/NoMAISI_train_and_infer.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffc762231f799865c8a36898ae6e23434f0f188edd45fec1be88bbd9f582a3f4 +size 456748 diff --git a/doc/images/TaskCls.png b/doc/images/TaskCls.png new file mode 100644 index 0000000000000000000000000000000000000000..227d7dad001769094386cfe95369653fc3b6235a --- /dev/null +++ b/doc/images/TaskCls.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d23c4d5110aab51b39e9772122eb98edaa5d260e1fcc3de24ff486fb5feaa06 +size 280285 diff --git a/doc/images/workflow.png b/doc/images/workflow.png new file mode 100644 index 0000000000000000000000000000000000000000..5463d902da87ad63de9cafd413fe1084bb7d09db --- /dev/null +++ b/doc/images/workflow.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bfeafa6ca6729ce6808c39e13afaa222d7ce102277b0f5fcb7d3eb29148ef93 +size 610149 diff --git a/inference.sub b/inference.sub new file mode 100644 index 0000000000000000000000000000000000000000..4d791b1f2fa6a34b18730b5b45b9d1c564476536 --- /dev/null +++ b/inference.sub @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=nomaisi +#SBATCH --mail-type=END,FAIL +#SBATCH --mail-user=ft42@duke.edu +#SBATCH -p vram48 +#SBATCH --ntasks=1 # +#SBATCH --gpus=1 # 2 GPU per task, chose more if model is capable of multi gpu training +#SBATCH --cpus-per-task=16 # More if it is CPU intensive job too NNUNET demands lot of CPU + +## Make sure logs directory is present on current directory (same as this script) +#SBATCH --output=logs/NoMAISI-infr-log-%j.out +#SBATCH --error=logs/NoMAISI-infr-log-%j.out + + + +echo "Job 
starting" +echo "GPUs Given: $CUDA_VISIBLE_DEVICES" +module load miniconda/py39_4.12.0 +source activate monai-auto3dseg + + +# Add the correct path to PYTHONPATH +export MONAI_DATA_DIRECTORY=/home/ft42/NoMAISI/ + +python -m scripts.infer_testV2_controlnet -c ./configs/config_maisi3d-rflow.json -e ./configs/infr_env_NoMAISI_DLCSD24_demo.json -t ./configs/infr_config_NoMAISI_controlnet.json diff --git a/logs/NoMAISI-infr-log-38612.out b/logs/NoMAISI-infr-log-38612.out new file mode 100644 index 0000000000000000000000000000000000000000..db024ef7f287d0cf8d2463b53afd84babc5fcdf1 --- /dev/null +++ b/logs/NoMAISI-infr-log-38612.out @@ -0,0 +1,18 @@ +Job starting +GPUs Given: 0 +[2025-09-24 13:42:58.511][ INFO](maisi.controlnet.infer) - Number of GPUs: 1 +[2025-09-24 13:42:58.512][ INFO](maisi.controlnet.infer) - World_size: 1 +[2025-09-24 13:42:59.541][ INFO](maisi.controlnet.infer) - Load trained diffusion model from ./models/autoencoder.pt. +[2025-09-24 13:43:03.285][ INFO](maisi.controlnet.infer) - Load trained diffusion model from ./models/diffusion_unet.pt. +[2025-09-24 13:43:03.287][ INFO](maisi.controlnet.infer) - loaded scale_factor from diffusion model ckpt -> 1.0311251878738403. +2025-09-24 13:43:03,824 - INFO - 'dst' model updated: 180 of 231 variables. +[2025-09-24 13:43:04.077][ INFO](maisi.controlnet.infer) - load trained controlnet model from ./models/Experiments_NoMAISI_512xy_256z_771p25m_finetune_500epoch_best.pt +[2025-09-24 13:43:07.130][ INFO](root) - `controllable_anatomy_size` is not provided. +[2025-09-24 13:43:07.133][ INFO](root) - ---- Start generating latent features... ---- + 0%| | 0/30 [00:00 0, 1.0, 0.0) + + return output.squeeze(0).squeeze(0) + + +def augmentation_tumor_bone(pt_nda, output_size, random_seed=None): + volume = pt_nda.squeeze(0) + real_l_volume_ = torch.zeros_like(volume) + real_l_volume_[volume == 128] = 1 + real_l_volume_ = real_l_volume_.to(torch.uint8) + + elastic = RandAffine( + mode="nearest", + prob=1.0, + translate_range=(5, 5, 0), + rotate_range=(0, 0, 0.1), + scale_range=(0.15, 0.15, 0), + padding_mode="zeros", + ) + elastic.set_random_state(seed=random_seed) + + tumor_szie = torch.sum((real_l_volume_ > 0).float()) + ########################### + # remove pred in pseudo_label in real lesion region + volume[real_l_volume_ > 0] = 200 + ########################### + if tumor_szie > 0: + # get organ mask + organ_mask = ( + torch.logical_and(33 <= volume, volume <= 56).float() + + torch.logical_and(63 <= volume, volume <= 97).float() + + (volume == 127).float() + + (volume == 114).float() + + real_l_volume_ + ) + organ_mask = (organ_mask > 0).float() + cnt = 0 + while True: + threshold = 0.8 if cnt < 40 else 0.75 + real_l_volume = real_l_volume_ + # random distor mask + distored_mask = elastic((real_l_volume > 0).cuda(), spatial_size=tuple(output_size)).as_tensor() + real_l_volume = distored_mask * organ_mask + cnt += 1 + print(torch.sum(real_l_volume), "|", tumor_szie * threshold) + if torch.sum(real_l_volume) >= tumor_szie * threshold: + real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5) + real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8) + break + else: + real_l_volume = real_l_volume_ + + volume[real_l_volume == 1] = 128 + + pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation_tumor_liver(pt_nda, output_size, random_seed=None): + volume = pt_nda.squeeze(0) + real_l_volume_ = torch.zeros_like(volume) + real_l_volume_[volume == 1] = 1 + real_l_volume_[volume == 26] = 2 + real_l_volume_ = 
real_l_volume_.to(torch.uint8) + + elastic = Rand3DElastic( + mode="nearest", + prob=1.0, + sigma_range=(5, 8), + magnitude_range=(100, 200), + translate_range=(10, 10, 10), + rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36), + scale_range=(0.2, 0.2, 0.2), + padding_mode="zeros", + ) + elastic.set_random_state(seed=random_seed) + + tumor_szie = torch.sum(real_l_volume_ == 2) + ########################### + # remove pred organ labels + volume[volume == 1] = 0 + volume[volume == 26] = 0 + # before move tumor maks, full the original location by organ labels + volume[real_l_volume_ == 1] = 1 + volume[real_l_volume_ == 2] = 1 + ########################### + while True: + real_l_volume = real_l_volume_ + # random distor mask + real_l_volume = elastic((real_l_volume == 2).cuda(), spatial_size=tuple(output_size)).as_tensor() + # get organ mask + organ_mask = (real_l_volume_ == 1).float() + (real_l_volume_ == 2).float() + + organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5) + organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0) + real_l_volume = real_l_volume * organ_mask + print(torch.sum(real_l_volume), "|", tumor_szie * 0.80) + if torch.sum(real_l_volume) >= tumor_szie * 0.80: + real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5) + real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0) + break + + volume[real_l_volume == 1] = 26 + + pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation_tumor_lung(pt_nda, output_size, random_seed=None): + volume = pt_nda.squeeze(0) + real_l_volume_ = torch.zeros_like(volume) + real_l_volume_[volume == 23] = 1 + real_l_volume_ = real_l_volume_.to(torch.uint8) + + elastic = Rand3DElastic( + mode="nearest", + prob=1.0, + sigma_range=(5, 8), + magnitude_range=(100, 200), + translate_range=(20, 20, 20), + rotate_range=(np.pi / 36, np.pi / 36, np.pi), + scale_range=(0.15, 0.15, 0.15), + padding_mode="zeros", + ) + elastic.set_random_state(seed=random_seed) + + tumor_szie = torch.sum(real_l_volume_) + # before move lung tumor maks, full the original location by lung labels + new_real_l_volume_ = dilate3d(real_l_volume_.squeeze(0), erosion=3) + new_real_l_volume_ = new_real_l_volume_.unsqueeze(0) + new_real_l_volume_[real_l_volume_ > 0] = 0 + new_real_l_volume_[volume < 28] = 0 + new_real_l_volume_[volume > 32] = 0 + tmp = volume[(volume * new_real_l_volume_).nonzero(as_tuple=True)].view(-1) + + mode = torch.mode(tmp, 0)[0].item() + print(mode) + assert 28 <= mode <= 32 + volume[real_l_volume_.bool()] = mode + ########################### + if tumor_szie > 0: + # aug + while True: + real_l_volume = real_l_volume_ + # random distor mask + real_l_volume = elastic(real_l_volume, spatial_size=tuple(output_size)).as_tensor() + # get lung mask v2 (133 order) + lung_mask = ( + (volume == 28).float() + + (volume == 29).float() + + (volume == 30).float() + + (volume == 31).float() + + (volume == 32).float() + ) + + lung_mask = dilate3d(lung_mask.squeeze(0), erosion=5) + lung_mask = erode3d(lung_mask, erosion=5).unsqueeze(0) + real_l_volume = real_l_volume * lung_mask + print(torch.sum(real_l_volume), "|", tumor_szie * 0.85) + if torch.sum(real_l_volume) >= tumor_szie * 0.85: + real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5) + real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8) + break + else: + real_l_volume = real_l_volume_ + + volume[real_l_volume == 1] = 23 + + pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation_tumor_pancreas(pt_nda, output_size, random_seed=None): + volume = 
pt_nda.squeeze(0) + real_l_volume_ = torch.zeros_like(volume) + real_l_volume_[volume == 4] = 1 + real_l_volume_[volume == 24] = 2 + real_l_volume_ = real_l_volume_.to(torch.uint8) + + elastic = Rand3DElastic( + mode="nearest", + prob=1.0, + sigma_range=(5, 8), + magnitude_range=(100, 200), + translate_range=(15, 15, 15), + rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36), + scale_range=(0.1, 0.1, 0.1), + padding_mode="zeros", + ) + elastic.set_random_state(seed=random_seed) + + tumor_szie = torch.sum(real_l_volume_ == 2) + ########################### + # remove pred organ labels + volume[volume == 24] = 0 + volume[volume == 4] = 0 + # before move tumor maks, full the original location by organ labels + volume[real_l_volume_ == 1] = 4 + volume[real_l_volume_ == 2] = 4 + ########################### + while True: + real_l_volume = real_l_volume_ + # random distor mask + real_l_volume = elastic((real_l_volume == 2).cuda(), spatial_size=tuple(output_size)).as_tensor() + # get organ mask + organ_mask = (real_l_volume_ == 1).float() + (real_l_volume_ == 2).float() + + organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5) + organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0) + real_l_volume = real_l_volume * organ_mask + print(torch.sum(real_l_volume), "|", tumor_szie * 0.80) + if torch.sum(real_l_volume) >= tumor_szie * 0.80: + real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5) + real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0) + break + + volume[real_l_volume == 1] = 24 + + pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation_tumor_colon(pt_nda, output_size, random_seed=None): + volume = pt_nda.squeeze(0) + real_l_volume_ = torch.zeros_like(volume) + real_l_volume_[volume == 27] = 1 + real_l_volume_ = real_l_volume_.to(torch.uint8) + + elastic = Rand3DElastic( + mode="nearest", + prob=1.0, + sigma_range=(5, 8), + magnitude_range=(100, 200), + translate_range=(5, 5, 5), + rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36), + scale_range=(0.1, 0.1, 0.1), + padding_mode="zeros", + ) + elastic.set_random_state(seed=random_seed) + + tumor_szie = torch.sum(real_l_volume_) + ########################### + # before move tumor maks, full the original location by organ labels + volume[real_l_volume_.bool()] = 62 + ########################### + if tumor_szie > 0: + # get organ mask + organ_mask = (volume == 62).float() + organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5) + organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0) + # cnt = 0 + cnt = 0 + while True: + threshold = 0.8 + real_l_volume = real_l_volume_ + if cnt < 20: + # random distor mask + distored_mask = elastic((real_l_volume == 1).cuda(), spatial_size=tuple(output_size)).as_tensor() + real_l_volume = distored_mask * organ_mask + elif 20 <= cnt < 40: + threshold = 0.75 + else: + break + + real_l_volume = real_l_volume * organ_mask + print(torch.sum(real_l_volume), "|", tumor_szie * threshold) + cnt += 1 + if torch.sum(real_l_volume) >= tumor_szie * threshold: + real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5) + real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8) + break + else: + real_l_volume = real_l_volume_ + # break + volume[real_l_volume == 1] = 27 + + pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation_body(pt_nda, random_seed=None): + volume = pt_nda.squeeze(0) + + zoom = RandZoom(min_zoom=0.99, max_zoom=1.01, mode="nearest", align_corners=None, prob=1.0) + zoom.set_random_state(seed=random_seed) + + volume = zoom(volume) + + 
pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation(pt_nda, output_size, random_seed=None): + label_list = torch.unique(pt_nda) + label_list = list(label_list.cpu().numpy()) + + if 128 in label_list: + print("augmenting bone lesion/tumor") + pt_nda = augmentation_tumor_bone(pt_nda, output_size, random_seed) + elif 26 in label_list: + print("augmenting liver tumor") + pt_nda = augmentation_tumor_liver(pt_nda, output_size, random_seed) + elif 23 in label_list: + print("augmenting lung tumor") + pt_nda = augmentation_tumor_lung(pt_nda, output_size, random_seed) + elif 24 in label_list: + print("augmenting pancreas tumor") + pt_nda = augmentation_tumor_pancreas(pt_nda, output_size, random_seed) + elif 27 in label_list: + print("augmenting colon tumor") + pt_nda = augmentation_tumor_colon(pt_nda, output_size, random_seed) + else: + print("augmenting body") + pt_nda = augmentation_body(pt_nda, random_seed) + + return pt_nda diff --git a/scripts/compute_fid_2-5d_ct.py b/scripts/compute_fid_2-5d_ct.py new file mode 100644 index 0000000000000000000000000000000000000000..0442a3e927e68f395cdf0308f84b310f401c20a0 --- /dev/null +++ b/scripts/compute_fid_2-5d_ct.py @@ -0,0 +1,747 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +# either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +""" +Compute 2.5D FID using distributed GPU processing. + +SHELL Usage Example: +------------------- + #!/bin/bash + + export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 + NUM_GPUS=7 + + torchrun --nproc_per_node=${NUM_GPUS} compute_fid_2-5d_ct.py \ + --model_name "radimagenet_resnet50" \ + --real_dataset_root "path/to/datasetA" \ + --real_filelist "path/to/filelistA.txt" \ + --real_features_dir "datasetA" \ + --synth_dataset_root "path/to/datasetB" \ + --synth_filelist "path/to/filelistB.txt" \ + --synth_features_dir "datasetB" \ + --enable_center_slices_ratio 0.4 \ + --enable_padding True \ + --enable_center_cropping True \ + --enable_resampling_spacing "1.0x1.0x1.0" \ + --ignore_existing True \ + --num_images 100 \ + --output_root "./features/features-512x512x512" \ + --target_shape "512x512x512" + +This script loads two datasets (real vs. synthetic) in 3D medical format (NIfTI) +and extracts feature maps via a 2.5D approach. It then computes the Frechet +Inception Distance (FID) across three orthogonal planes. Data parallelism +is implemented using torch.distributed with an NCCL backend. + +Function Arguments (main): +-------------------------- + real_dataset_root (str): + Root folder for the real dataset. + + real_filelist (str): + Text file listing 3D images for the real dataset. + + real_features_dir (str): + Subdirectory (under `output_root`) in which to store feature files + extracted from the real dataset. + + synth_dataset_root (str): + Root folder for the synthetic dataset. + + synth_filelist (str): + Text file listing 3D images for the synthetic dataset. + + synth_features_dir (str): + Subdirectory (under `output_root`) in which to store feature files + extracted from the synthetic dataset. 
+ + enable_center_slices_ratio (float or None): + - If not None, only slices around the specified center ratio will be used + (analogous to "enable_center_slices=True" with that ratio). + - If None, no center-slice selection is performed + (analogous to "enable_center_slices=False"). + + enable_padding (bool): + Whether to pad images to `target_shape`. + + enable_center_cropping (bool): + Whether to center-crop images to `target_shape`. + + enable_resampling_spacing (str or None): + - If not None, resample images to the specified voxel spacing (e.g. "1.0x1.0x1.0") + (analogous to "enable_resampling=True" with that spacing). + - If None, resampling is skipped + (analogous to "enable_resampling=False"). + + ignore_existing (bool): + If True, ignore any existing .pt feature files and force re-extraction. + + model_name (str): + Model identifier. Typically "radimagenet_resnet50" or "squeezenet1_1". + + num_images (int): + Max number of images to process from each dataset (truncate if more are present). + + output_root (str): + Folder where extracted .pt feature files, logs, and results are saved. + + target_shape (str): + Target shape as "XxYxZ" for padding, cropping, or resampling operations. +""" + + +from __future__ import annotations + +import os +import sys +import torch +import fire +import monai +import re +import torch.distributed as dist +import torch.nn.functional as F + +from datetime import timedelta +from pathlib import Path +from monai.metrics.fid import FIDMetric +from monai.transforms import Compose + +import logging + +# ------------------------------------------------------------------------------ +# Create logger +# ------------------------------------------------------------------------------ +logger = logging.getLogger("fid_2-5d_ct") +if not logger.handlers: + # Configure logger only if it has no handlers (avoid reconfiguring in multi-rank scenarios) + logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logger.setLevel(logging.INFO) + + +def drop_empty_slice(slices, empty_threshold: float): + """ + Decide which 2D slices to keep by checking if their maximum intensity + is below a certain threshold. + + Args: + slices (tuple or list of Tensors): Each element is (B, C, H, W). + empty_threshold (float): If the slice's maximum value is below this threshold, + it is considered "empty". + + Returns: + list[bool]: A list of booleans indicating for each slice whether to keep it. + """ + outputs = [] + n_drop = 0 + for s in slices: + largest_unique = torch.max(torch.unique(s)) + if largest_unique < empty_threshold: + outputs.append(False) + n_drop += 1 + else: + outputs.append(True) + + logger.info(f"Empty slice drop rate {round((n_drop/len(slices))*100,1)}%") + return outputs + + +def subtract_mean(x: torch.Tensor) -> torch.Tensor: + """ + Subtract per-channel means (ImageNet-like: [0.406, 0.456, 0.485]) + from the input 4D or 5D tensor. Expects channels in the first dimension + after the batch dimension: (B, C, H, W) or (B, C, H, W, D). + """ + mean = [0.406, 0.456, 0.485] + x[:, 0, ...] -= mean[0] + x[:, 1, ...] -= mean[1] + x[:, 2, ...] -= mean[2] + return x + + +def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor: + """ + Average out the spatial dimensions of a tensor, preserving or removing them + according to `keepdim`. This is used to produce a 1D feature vector + out of a feature map. + + Args: + x (torch.Tensor): Input tensor (B, C, H, W, ...) or (B, C, H, W). + keepdim (bool): Whether to keep dimension or not after averaging. 
+ + Returns: + torch.Tensor: Tensor with reduced spatial dimensions. + """ + dim = len(x.shape) + # 2D -> no average + if dim == 2: + return x + # 3D -> average over last dim + if dim == 3: + return x.mean([2], keepdim=keepdim) + # 4D -> average over H,W + if dim == 4: + return x.mean([2, 3], keepdim=keepdim) + # 5D -> average over H,W,D + if dim == 5: + return x.mean([2, 3, 4], keepdim=keepdim) + return x + + +def medicalnet_intensity_normalisation(volume: torch.Tensor) -> torch.Tensor: + """ + Intensity normalization approach from MedicalNet: + (volume - mean) / (std + 1e-5) across spatial dims. + Expects (B, C, H, W) or (B, C, H, W, D). + """ + dim = len(volume.shape) + if dim == 4: + mean = volume.mean([2, 3], keepdim=True) + std = volume.std([2, 3], keepdim=True) + elif dim == 5: + mean = volume.mean([2, 3, 4], keepdim=True) + std = volume.std([2, 3, 4], keepdim=True) + else: + return volume + return (volume - mean) / (std + 1e-5) + + +def radimagenet_intensity_normalisation(volume: torch.Tensor, norm2d: bool = False) -> torch.Tensor: + """ + Intensity normalization for radimagenet_resnet. Optionally normalizes each 2D slice individually. + + Args: + volume (torch.Tensor): Input (B, C, H, W) or (B, C, H, W, D). + norm2d (bool): If True, normalizes each (H,W) slice to [0,1], then subtracts the ImageNet mean. + """ + logger.info(f"norm2d: {norm2d}") + dim = len(volume.shape) + # If norm2d is True, only meaningful for 4D data (B, C, H, W): + if dim == 4 and norm2d: + max2d, _ = torch.max(volume, dim=2, keepdim=True) + max2d, _ = torch.max(max2d, dim=3, keepdim=True) + min2d, _ = torch.min(volume, dim=2, keepdim=True) + min2d, _ = torch.min(min2d, dim=3, keepdim=True) + # Scale each slice to 0..1 + volume = (volume - min2d) / (max2d - min2d + 1e-10) + # Subtract channel mean + return subtract_mean(volume) + elif dim == 4: + # 4D but no per-slice normalization + max3d = torch.max(volume) + min3d = torch.min(volume) + volume = (volume - min3d) / (max3d - min3d + 1e-10) + return subtract_mean(volume) + # Fallback for e.g. 5D data is simply a min-max over entire volume + if dim == 5: + maxval = torch.max(volume) + minval = torch.min(volume) + volume = (volume - minval) / (maxval - minval + 1e-10) + return subtract_mean(volume) + return volume + + +def get_features_2p5d( + image: torch.Tensor, + feature_network: torch.nn.Module, + center_slices: bool = False, + center_slices_ratio: float = 1.0, + sample_every_k: int = 1, + xy_only: bool = True, + drop_empty: bool = False, + empty_threshold: float = -700, +) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + """ + Extract 2.5D features from a 3D image by slicing it along XY, YZ, ZX planes. + + Args: + image (torch.Tensor): Input 5D tensor in shape (B, C, H, W, D). + feature_network (torch.nn.Module): Model that processes 2D slices (C,H,W). + center_slices (bool): Whether to slice only the center portion of each axis. + center_slices_ratio (float): Ratio of slices to keep in the center if `center_slices` is True. + sample_every_k (int): Downsampling factor along each axis when slicing. + xy_only (bool): If True, return only the XY-plane features. + drop_empty (bool): Drop slices that are deemed "empty" below `empty_threshold`. + empty_threshold (float): Threshold to decide emptiness of slices. + + Returns: + tuple of torch.Tensor or None: (XY_features, YZ_features, ZX_features). 
+ """ + logger.info(f"center_slices: {center_slices}, ratio: {center_slices_ratio}") + + # If there's only 1 channel, replicate to 3 channels + if image.shape[1] == 1: + image = image.repeat(1, 3, 1, 1, 1) + + # Convert from 'RGB'→(R,G,B) to (B,G,R) + image = image[:, [2, 1, 0], ...] + + B, C, H, W, D = image.size() + with torch.no_grad(): + # ---------------------- XY-plane slicing along D ---------------------- + if center_slices: + start_d = int((1.0 - center_slices_ratio) / 2.0 * D) + end_d = int((1.0 + center_slices_ratio) / 2.0 * D) + slices = torch.unbind(image[:, :, :, :, start_d:end_d:sample_every_k], dim=-1) + else: + slices = torch.unbind(image, dim=-1) + + if drop_empty: + mapping_index = drop_empty_slice(slices, empty_threshold) + else: + mapping_index = [True for _ in range(len(slices))] + + images_2d = torch.cat(slices, dim=0) + images_2d = radimagenet_intensity_normalisation(images_2d) + images_2d = images_2d[mapping_index] + + feature_image_xy = feature_network.forward(images_2d) + feature_image_xy = spatial_average(feature_image_xy, keepdim=False) + if xy_only: + return feature_image_xy, None, None + + # ---------------------- YZ-plane slicing along H ---------------------- + if center_slices: + start_h = int((1.0 - center_slices_ratio) / 2.0 * H) + end_h = int((1.0 + center_slices_ratio) / 2.0 * H) + slices = torch.unbind(image[:, :, start_h:end_h:sample_every_k, :, :], dim=2) + else: + slices = torch.unbind(image, dim=2) + + if drop_empty: + mapping_index = drop_empty_slice(slices, empty_threshold) + else: + mapping_index = [True for _ in range(len(slices))] + + images_2d = torch.cat(slices, dim=0) + images_2d = radimagenet_intensity_normalisation(images_2d) + images_2d = images_2d[mapping_index] + + feature_image_yz = feature_network.forward(images_2d) + feature_image_yz = spatial_average(feature_image_yz, keepdim=False) + + # ---------------------- ZX-plane slicing along W ---------------------- + if center_slices: + start_w = int((1.0 - center_slices_ratio) / 2.0 * W) + end_w = int((1.0 + center_slices_ratio) / 2.0 * W) + slices = torch.unbind(image[:, :, :, start_w:end_w:sample_every_k, :], dim=3) + else: + slices = torch.unbind(image, dim=3) + + if drop_empty: + mapping_index = drop_empty_slice(slices, empty_threshold) + else: + mapping_index = [True for _ in range(len(slices))] + + images_2d = torch.cat(slices, dim=0) + images_2d = radimagenet_intensity_normalisation(images_2d) + images_2d = images_2d[mapping_index] + + feature_image_zx = feature_network.forward(images_2d) + feature_image_zx = spatial_average(feature_image_zx, keepdim=False) + + return feature_image_xy, feature_image_yz, feature_image_zx + + +def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float = 0.0) -> torch.Tensor: + """ + Zero-pad a 2D feature map or other tensor along the first dimension to match a specified size. + + Args: + tensor (torch.Tensor): The feature tensor to pad. + max_size (int): Desired size along the first dimension. + padding_value (float): Value to fill during padding. + + Returns: + torch.Tensor: Padded tensor matching `max_size` along dim=0. 
+ """ + pad_size = [0, 0] * (len(tensor.shape) - 1) + [0, max_size - tensor.shape[0]] + return F.pad(tensor, pad_size, "constant", padding_value) + + +def main( + real_dataset_root: str = "path/to/datasetA", + real_filelist: str = "path/to/filelistA.txt", + real_features_dir: str = "datasetA", + synth_dataset_root: str = "path/to/datasetB", + synth_filelist: str = "path/to/filelistB.txt", + synth_features_dir: str = "datasetB", + enable_center_slices_ratio: float = None, + enable_padding: bool = True, + enable_center_cropping: bool = True, + enable_resampling_spacing: str = None, + ignore_existing: bool = False, + model_name: str = "radimagenet_resnet50", + num_images: int = 100, + output_root: str = "./features/features-512x512x512", + target_shape: str = "512x512x512", +): + """ + Compute 2.5D FID using distributed GPU processing. + + This function loads two datasets (real vs. synthetic) in 3D medical format (NIfTI) + and extracts feature maps via a 2.5D approach, then computes the Frechet Inception + Distance (FID) across three orthogonal planes. Data parallelism is implemented + using torch.distributed with an NCCL backend. + + Args: + real_dataset_root (str): + Root folder for the real dataset. + real_filelist (str): + Path to a text file listing 3D images (e.g., NIfTI files) for the real dataset. + Each line in this file should contain a relative path (or filename) to a NIfTI file. + For example, your "real_filelist.txt" could look like: + case001.nii.gz + case002.nii.gz + case003.nii.gz + ... + These entries will be appended to `real_dataset_root`. + real_features_dir (str): + Name of the directory under `output_root` in which to store + extracted features for the real dataset. + + synth_dataset_root (str): + Root folder for the synthetic dataset. + synth_filelist (str): + Path to a text file listing 3D images (e.g., NIfTI files) for the synthetic dataset. + The format is the same as the real dataset file list, for example: + synth_case001.nii.gz + synth_case002.nii.gz + synth_case003.nii.gz + ... + These entries will be appended to `synth_dataset_root`. + synth_features_dir (str): + Name of the directory under `output_root` in which to store + extracted features for the synthetic dataset. + + enable_center_slices_ratio (float or None): + - If not None, only slices around the specified center ratio are used. + (similar to "enable_center_slices=True" with that ratio in an earlier script). + - If None, no center-slice selection is performed + (similar to "enable_center_slices=False"). + + enable_padding (bool): + Whether to pad images to `target_shape`. + + enable_center_cropping (bool): + Whether to center-crop images to `target_shape`. + + enable_resampling_spacing (str or None): + - If not None, resample images to this voxel spacing (e.g. "1.0x1.0x1.0") + (similar to "enable_resampling=True" with that spacing). + - If None, skip resampling (similar to "enable_resampling=False"). + + ignore_existing (bool): + If True, ignore any existing .pt feature files and force re-computation. + + model_name (str): + Model identifier. Typically "radimagenet_resnet50" or "squeezenet1_1". + + num_images (int): + Maximum number of images to load from each dataset (truncate if more are present). + + output_root (str): + Parent folder where extracted .pt files and logs will be saved. + + target_shape (str): + Target shape, e.g. "512x512x512", for padding, cropping, or resampling operations. 
+ + Returns: + None + """ + # ------------------------------------------------------------------------- + # Initialize Process Group (Distributed) + # ------------------------------------------------------------------------- + dist.init_process_group(backend="nccl", init_method="env://", timeout=timedelta(seconds=7200)) + + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(dist.get_world_size()) + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + logger.info(f"[INFO] Running process on {device} of total {world_size} ranks.") + + # Convert potential string bools to actual bools (if using Fire or similar) + if not isinstance(enable_padding, bool): + enable_padding = enable_padding.lower() == "true" + if not isinstance(enable_center_cropping, bool): + enable_center_cropping = enable_center_cropping.lower() == "true" + if not isinstance(ignore_existing, bool): + ignore_existing = ignore_existing.lower() == "true" + + # Merge logic for center slices + enable_center_slices = enable_center_slices_ratio is not None + + # Merge logic for resampling + enable_resampling = enable_resampling_spacing is not None + + # Print out some flags on rank 0 + if local_rank == 0: + logger.info(f"Real dataset root: {real_dataset_root}") + logger.info(f"Synth dataset root: {synth_dataset_root}") + logger.info(f"enable_center_slices_ratio: {enable_center_slices_ratio}") + logger.info(f"enable_center_slices: {enable_center_slices}") + logger.info(f"enable_padding: {enable_padding}") + logger.info(f"enable_center_cropping: {enable_center_cropping}") + logger.info(f"enable_resampling_spacing: {enable_resampling_spacing}") + logger.info(f"enable_resampling: {enable_resampling}") + logger.info(f"ignore_existing: {ignore_existing}") + + # ------------------------------------------------------------------------- + # Load feature extraction model + # ------------------------------------------------------------------------- + if model_name == "radimagenet_resnet50": + feature_network = torch.hub.load( + "Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True, trust_repo=True + ) + suffix = "radimagenet_resnet50" + else: + import torchvision + + feature_network = torchvision.models.squeezenet1_1(pretrained=True) + suffix = "squeezenet1_1" + + feature_network.to(device) + feature_network.eval() + + # ------------------------------------------------------------------------- + # Parse shape/spacings + # ------------------------------------------------------------------------- + t_shape = [int(x) for x in target_shape.split("x")] + target_shape_tuple = tuple(t_shape) + + # If not None, parse the resampling spacing + if enable_resampling: + rs_spacing = [float(x) for x in enable_resampling_spacing.split("x")] + rs_spacing_tuple = tuple(rs_spacing) + if local_rank == 0: + logger.info(f"Resampling spacing: {rs_spacing_tuple}") + else: + rs_spacing_tuple = (1.0, 1.0, 1.0) + + # Use the ratio if provided, otherwise 1.0 + center_slices_ratio_final = enable_center_slices_ratio if enable_center_slices else 1.0 + if local_rank == 0: + logger.info(f"center_slices_ratio: {center_slices_ratio_final}") + + # ------------------------------------------------------------------------- + # Prepare Real Dataset + # ------------------------------------------------------------------------- + output_root_real = os.path.join(output_root, real_features_dir) + with open(real_filelist, "r") as rf: + real_lines = [l.strip() for l in rf.readlines()] + real_lines.sort() + real_lines = 
real_lines[:num_images] + + real_filenames = [{"image": os.path.join(real_dataset_root, f)} for f in real_lines] + real_filenames = monai.data.partition_dataset( + data=real_filenames, shuffle=False, num_partitions=world_size, even_divisible=False + )[local_rank] + + # ------------------------------------------------------------------------- + # Prepare Synthetic Dataset + # ------------------------------------------------------------------------- + output_root_synth = os.path.join(output_root, synth_features_dir) + with open(synth_filelist, "r") as sf: + synth_lines = [l.strip() for l in sf.readlines()] + synth_lines.sort() + synth_lines = synth_lines[:num_images] + + synth_filenames = [{"image": os.path.join(synth_dataset_root, f)} for f in synth_lines] + synth_filenames = monai.data.partition_dataset( + data=synth_filenames, shuffle=False, num_partitions=world_size, even_divisible=False + )[local_rank] + + # ------------------------------------------------------------------------- + # Build MONAI transforms + # ------------------------------------------------------------------------- + transform_list = [ + monai.transforms.LoadImaged(keys=["image"]), + monai.transforms.EnsureChannelFirstd(keys=["image"]), + monai.transforms.Orientationd(keys=["image"], axcodes="RAS"), + ] + + if enable_resampling: + transform_list.append(monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"])) + + if enable_padding: + transform_list.append( + monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000) + ) + + if enable_center_cropping: + transform_list.append(monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple)) + + transform_list.append( + monai.transforms.ScaleIntensityRanged( + keys=["image"], a_min=-1000, a_max=1000, b_min=-1000, b_max=1000, clip=True + ) + ) + transforms = Compose(transform_list) + + # ------------------------------------------------------------------------- + # Create DataLoaders + # ------------------------------------------------------------------------- + real_ds = monai.data.Dataset(data=real_filenames, transform=transforms) + real_loader = monai.data.DataLoader(real_ds, num_workers=6, batch_size=1, shuffle=False) + + synth_ds = monai.data.Dataset(data=synth_filenames, transform=transforms) + synth_loader = monai.data.DataLoader(synth_ds, num_workers=6, batch_size=1, shuffle=False) + + # ------------------------------------------------------------------------- + # Extract features for Real Dataset + # ------------------------------------------------------------------------- + real_features_xy, real_features_yz, real_features_zx = [], [], [] + for idx, batch_data in enumerate(real_loader, start=1): + img = batch_data["image"].to(device) + fn = img.meta["filename_or_obj"][0] + logger.info(f"[Rank {local_rank}] Real data {idx}/{len(real_filenames)}: {fn}") + + out_fp = fn.replace(real_dataset_root, output_root_real).replace(".nii.gz", ".pt") + out_fp = Path(out_fp) + out_fp.parent.mkdir(parents=True, exist_ok=True) + + if (not ignore_existing) and os.path.isfile(out_fp): + feats = torch.load(out_fp, weights_only=True) + else: + img_t = img.as_tensor() + logger.info(f"image shape: {tuple(img_t.shape)}") + + feats = get_features_2p5d( + img_t, + feature_network, + center_slices=enable_center_slices, + center_slices_ratio=center_slices_ratio_final, + xy_only=False, + ) + logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}") + torch.save(feats, 
out_fp) + + real_features_xy.append(feats[0]) + real_features_yz.append(feats[1]) + real_features_zx.append(feats[2]) + + real_features_xy = torch.vstack(real_features_xy) + real_features_yz = torch.vstack(real_features_yz) + real_features_zx = torch.vstack(real_features_zx) + logger.info( + f"Real feature shapes: {real_features_xy.shape}, " f"{real_features_yz.shape}, {real_features_zx.shape}" + ) + + # ------------------------------------------------------------------------- + # Extract features for Synthetic Dataset + # ------------------------------------------------------------------------- + synth_features_xy, synth_features_yz, synth_features_zx = [], [], [] + for idx, batch_data in enumerate(synth_loader, start=1): + img = batch_data["image"].to(device) + fn = img.meta["filename_or_obj"][0] + logger.info(f"[Rank {local_rank}] Synth data {idx}/{len(synth_filenames)}: {fn}") + + out_fp = fn.replace(synth_dataset_root, output_root_synth).replace(".nii.gz", ".pt") + out_fp = Path(out_fp) + out_fp.parent.mkdir(parents=True, exist_ok=True) + + if (not ignore_existing) and os.path.isfile(out_fp): + feats = torch.load(out_fp, weights_only=True) + else: + img_t = img.as_tensor() + logger.info(f"image shape: {tuple(img_t.shape)}") + + feats = get_features_2p5d( + img_t, + feature_network, + center_slices=enable_center_slices, + center_slices_ratio=center_slices_ratio_final, + xy_only=False, + ) + logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}") + torch.save(feats, out_fp) + + synth_features_xy.append(feats[0]) + synth_features_yz.append(feats[1]) + synth_features_zx.append(feats[2]) + + synth_features_xy = torch.vstack(synth_features_xy) + synth_features_yz = torch.vstack(synth_features_yz) + synth_features_zx = torch.vstack(synth_features_zx) + logger.info( + f"Synth feature shapes: {synth_features_xy.shape}, " f"{synth_features_yz.shape}, {synth_features_zx.shape}" + ) + + # ------------------------------------------------------------------------- + # All-reduce / gather features across ranks + # ------------------------------------------------------------------------- + features = [ + real_features_xy, + real_features_yz, + real_features_zx, + synth_features_xy, + synth_features_yz, + synth_features_zx, + ] + + # 1) Gather local feature sizes across ranks + local_sizes = [] + for ft_idx in range(len(features)): + local_size = torch.tensor([features[ft_idx].shape[0]], dtype=torch.int64, device=device) + local_sizes.append(local_size) + + all_sizes = [] + for ft_idx in range(len(features)): + rank_sizes = [torch.tensor([0], dtype=torch.int64, device=device) for _ in range(world_size)] + dist.all_gather(rank_sizes, local_sizes[ft_idx]) + all_sizes.append(rank_sizes) + + # 2) Pad and gather all features + all_tensors_list = [] + for ft_idx, ft in enumerate(features): + max_size = max(all_sizes[ft_idx]).item() + ft_padded = pad_to_max_size(ft, max_size) + + gather_list = [torch.empty_like(ft_padded) for _ in range(world_size)] + dist.all_gather(gather_list, ft_padded) + + # Trim each gather back to the real size + for rk in range(world_size): + gather_list[rk] = gather_list[rk][: all_sizes[ft_idx][rk], :] + + all_tensors_list.append(gather_list) + + # On rank 0, compute FID + if local_rank == 0: + real_xy = torch.vstack(all_tensors_list[0]) + real_yz = torch.vstack(all_tensors_list[1]) + real_zx = torch.vstack(all_tensors_list[2]) + + synth_xy = torch.vstack(all_tensors_list[3]) + synth_yz = torch.vstack(all_tensors_list[4]) + synth_zx = 
torch.vstack(all_tensors_list[5]) + + logger.info(f"Final Real shapes: {real_xy.shape}, {real_yz.shape}, {real_zx.shape}") + logger.info(f"Final Synth shapes: {synth_xy.shape}, {synth_yz.shape}, {synth_zx.shape}") + + fid = FIDMetric() + logger.info(f"Computing FID for: {output_root_real} | {output_root_synth}") + fid_res_xy = fid(synth_xy, real_xy) + fid_res_yz = fid(synth_yz, real_yz) + fid_res_zx = fid(synth_zx, real_zx) + + logger.info(f"FID XY: {fid_res_xy}") + logger.info(f"FID YZ: {fid_res_yz}") + logger.info(f"FID ZX: {fid_res_zx}") + fid_avg = (fid_res_xy + fid_res_yz + fid_res_zx) / 3.0 + logger.info(f"FID Avg: {fid_avg}") + + dist.destroy_process_group() + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/scripts/diff_model_create_training_data.py b/scripts/diff_model_create_training_data.py new file mode 100644 index 0000000000000000000000000000000000000000..f70d3859ac32c1dfd60265b31f8eaeaabc7b2e5f --- /dev/null +++ b/scripts/diff_model_create_training_data.py @@ -0,0 +1,231 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import json +import logging +import os +from pathlib import Path + +import monai +import nibabel as nib +import numpy as np +import torch +import torch.distributed as dist +from monai.transforms import Compose +from monai.utils import set_determinism + +from .diff_model_setting import initialize_distributed, load_config, setup_logging +from .utils import define_instance + +# Set the random seed for reproducibility +set_determinism(seed=0) + + +def create_transforms(dim: tuple = None) -> Compose: + """ + Create a set of MONAI transforms for preprocessing. + + Args: + dim (tuple, optional): New dimensions for resizing. Defaults to None. + + Returns: + Compose: Composed MONAI transforms. + """ + if dim: + return Compose( + [ + monai.transforms.LoadImaged(keys="image"), + monai.transforms.EnsureChannelFirstd(keys="image"), + monai.transforms.Orientationd(keys="image", axcodes="RAS"), + monai.transforms.EnsureTyped(keys="image", dtype=torch.float32), + monai.transforms.ScaleIntensityRanged( + keys="image", a_min=-1000, a_max=1000, b_min=0, b_max=1, clip=True + ), + monai.transforms.Resized(keys="image", spatial_size=dim, mode="trilinear"), + ] + ) + else: + return Compose( + [ + monai.transforms.LoadImaged(keys="image"), + monai.transforms.EnsureChannelFirstd(keys="image"), + monai.transforms.Orientationd(keys="image", axcodes="RAS"), + ] + ) + + +def round_number(number: int, base_number: int = 128) -> int: + """ + Round the number to the nearest multiple of the base number, with a minimum value of the base number. + + Args: + number (int): Number to be rounded. + base_number (int): Number to be common divisor. + + Returns: + int: Rounded number. + """ + new_number = max(round(float(number) / float(base_number)), 1.0) * float(base_number) + return int(new_number) + + +def load_filenames(data_list_path: str) -> list: + """ + Load filenames from the JSON data list. 
+ + Args: + data_list_path (str): Path to the JSON data list file. + + Returns: + list: List of filenames. + """ + with open(data_list_path, "r") as file: + json_data = json.load(file) + filenames_raw = json_data["training"] + return [_item["image"] for _item in filenames_raw] + + +def process_file( + filepath: str, + args: argparse.Namespace, + autoencoder: torch.nn.Module, + device: torch.device, + plain_transforms: Compose, + new_transforms: Compose, + logger: logging.Logger, +) -> None: + """ + Process a single file to create training data. + + Args: + filepath (str): Path to the file to be processed. + args (argparse.Namespace): Configuration arguments. + autoencoder (torch.nn.Module): Autoencoder model. + device (torch.device): Device to process the file on. + plain_transforms (Compose): Plain transforms. + new_transforms (Compose): New transforms. + logger (logging.Logger): Logger for logging information. + """ + out_filename_base = filepath.replace(".gz", "").replace(".nii", "") + out_filename_base = os.path.join(args.embedding_base_dir, out_filename_base) + out_filename = out_filename_base + "_emb.nii.gz" + + if os.path.isfile(out_filename): + return + + test_data = {"image": os.path.join(args.data_base_dir, filepath)} + transformed_data = plain_transforms(test_data) + nda = transformed_data["image"] + + dim = [int(nda.meta["dim"][_i]) for _i in range(1, 4)] + spacing = [float(nda.meta["pixdim"][_i]) for _i in range(1, 4)] + + logger.info(f"old dim: {dim}, old spacing: {spacing}") + + new_data = new_transforms(test_data) + nda_image = new_data["image"] + + new_affine = nda_image.meta["affine"].numpy() + nda_image = nda_image.numpy().squeeze() + + logger.info(f"new dim: {nda_image.shape}, new affine: {new_affine}") + + try: + out_path = Path(out_filename) + out_path.parent.mkdir(parents=True, exist_ok=True) + logger.info(f"out_filename: {out_filename}") + + with torch.amp.autocast("cuda"): + pt_nda = torch.from_numpy(nda_image).float().to(device).unsqueeze(0).unsqueeze(0) + z = autoencoder.encode_stage_2_inputs(pt_nda) + logger.info(f"z: {z.size()}, {z.dtype}") + + out_nda = z.squeeze().cpu().detach().numpy().transpose(1, 2, 3, 0) + out_img = nib.Nifti1Image(np.float32(out_nda), affine=new_affine) + nib.save(out_img, out_filename) + except Exception as e: + logger.error(f"Error processing {filepath}: {e}") + + +@torch.inference_mode() +def diff_model_create_training_data( + env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int +) -> None: + """ + Create training data for the diffusion model. + + Args: + env_config_path (str): Path to the environment configuration file. + model_config_path (str): Path to the model configuration file. + model_def_path (str): Path to the model definition file. 
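+        num_gpus (int): Number of GPUs to use for distributed processing.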
+ """ + args = load_config(env_config_path, model_config_path, model_def_path) + local_rank, world_size, device = initialize_distributed(num_gpus=num_gpus) + logger = setup_logging("creating training data") + logger.info(f"Using device {device}") + + autoencoder = define_instance(args, "autoencoder_def").to(device) + try: + checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True) + autoencoder.load_state_dict(checkpoint_autoencoder) + except Exception: + logger.error("The trained_autoencoder_path does not exist!") + + Path(args.embedding_base_dir).mkdir(parents=True, exist_ok=True) + + filenames_raw = load_filenames(args.json_data_list) + logger.info(f"filenames_raw: {filenames_raw}") + + plain_transforms = create_transforms(dim=None) + + for _iter in range(len(filenames_raw)): + if _iter % world_size != local_rank: + continue + + filepath = filenames_raw[_iter] + new_dim = tuple( + round_number( + int(plain_transforms({"image": os.path.join(args.data_base_dir, filepath)})["image"].meta["dim"][_i]) + ) + for _i in range(1, 4) + ) + new_transforms = create_transforms(new_dim) + + process_file(filepath, args, autoencoder, device, plain_transforms, new_transforms, logger) + + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Diffusion Model Training Data Creation") + parser.add_argument( + "--env_config", + type=str, + default="./configs/environment_maisi_diff_model_train.json", + help="Path to environment configuration file", + ) + parser.add_argument( + "--model_config", + type=str, + default="./configs/config_maisi_diff_model_train.json", + help="Path to model training/inference configuration", + ) + parser.add_argument( + "--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file" + ) + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for distributed training") + + args = parser.parse_args() + diff_model_create_training_data(args.env_config, args.model_config, args.model_def, args.num_gpus) diff --git a/scripts/diff_model_infer.py b/scripts/diff_model_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..298e9c52112dab0aef795b195c8628dc60210ca1 --- /dev/null +++ b/scripts/diff_model_infer.py @@ -0,0 +1,358 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import argparse +import logging +import os +import random +from datetime import datetime + +import nibabel as nib +import numpy as np +import torch +import torch.distributed as dist +from monai.inferers import sliding_window_inference +from monai.inferers.inferer import SlidingWindowInferer +from monai.networks.schedulers import RFlowScheduler +from monai.utils import set_determinism +from tqdm import tqdm + +from .diff_model_setting import initialize_distributed, load_config, setup_logging +from .sample import ReconModel, check_input +from .utils import define_instance, dynamic_infer + + +def set_random_seed(seed: int) -> int: + """ + Set random seed for reproducibility. + + Args: + seed (int): Random seed. + + Returns: + int: Set random seed. + """ + random_seed = random.randint(0, 99999) if seed is None else seed + set_determinism(random_seed) + return random_seed + + +def load_models(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> tuple: + """ + Load the autoencoder and UNet models. + + Args: + args (argparse.Namespace): Configuration arguments. + device (torch.device): Device to load models on. + logger (logging.Logger): Logger for logging information. + + Returns: + tuple: Loaded autoencoder, UNet model, and scale factor. + """ + autoencoder = define_instance(args, "autoencoder_def").to(device) + try: + checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True) + autoencoder.load_state_dict(checkpoint_autoencoder) + except Exception: + logger.error("The trained_autoencoder_path does not exist!") + + unet = define_instance(args, "diffusion_unet_def").to(device) + checkpoint = torch.load(f"{args.model_dir}/{args.model_filename}", map_location=device, weights_only=False) + unet.load_state_dict(checkpoint["unet_state_dict"], strict=True) + logger.info(f"checkpoints {args.model_dir}/{args.model_filename} loaded.") + + scale_factor = checkpoint["scale_factor"] + logger.info(f"scale_factor -> {scale_factor}.") + + return autoencoder, unet, scale_factor + + +def prepare_tensors(args: argparse.Namespace, device: torch.device) -> tuple: + """ + Prepare necessary tensors for inference. + + Args: + args (argparse.Namespace): Configuration arguments. + device (torch.device): Device to load tensors on. + + Returns: + tuple: Prepared top_region_index_tensor, bottom_region_index_tensor, and spacing_tensor. 
+ """ + top_region_index_tensor = np.array(args.diffusion_unet_inference["top_region_index"]).astype(float) * 1e2 + bottom_region_index_tensor = np.array(args.diffusion_unet_inference["bottom_region_index"]).astype(float) * 1e2 + spacing_tensor = np.array(args.diffusion_unet_inference["spacing"]).astype(float) * 1e2 + + top_region_index_tensor = torch.from_numpy(top_region_index_tensor[np.newaxis, :]).half().to(device) + bottom_region_index_tensor = torch.from_numpy(bottom_region_index_tensor[np.newaxis, :]).half().to(device) + spacing_tensor = torch.from_numpy(spacing_tensor[np.newaxis, :]).half().to(device) + modality_tensor = args.diffusion_unet_inference["modality"] * torch.ones( + (len(spacing_tensor)), dtype=torch.long + ).to(device) + + return top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor + + +def run_inference( + args: argparse.Namespace, + device: torch.device, + autoencoder: torch.nn.Module, + unet: torch.nn.Module, + scale_factor: float, + top_region_index_tensor: torch.Tensor, + bottom_region_index_tensor: torch.Tensor, + spacing_tensor: torch.Tensor, + modality_tensor: torch.Tensor, + output_size: tuple, + divisor: int, + logger: logging.Logger, +) -> np.ndarray: + """ + Run the inference to generate synthetic images. + + Args: + args (argparse.Namespace): Configuration arguments. + device (torch.device): Device to run inference on. + autoencoder (torch.nn.Module): Autoencoder model. + unet (torch.nn.Module): UNet model. + scale_factor (float): Scale factor for the model. + top_region_index_tensor (torch.Tensor): Top region index tensor. + bottom_region_index_tensor (torch.Tensor): Bottom region index tensor. + spacing_tensor (torch.Tensor): Spacing tensor. + modality_tensor (torch.Tensor): Modality tensor. + output_size (tuple): Output size of the synthetic image. + divisor (int): Divisor for downsample level. + logger (logging.Logger): Logger for logging information. + + Returns: + np.ndarray: Generated synthetic image data. 
+ """ + include_body_region = unet.include_top_region_index_input + include_modality = unet.num_class_embeds is not None + + noise = torch.randn( + ( + 1, + args.latent_channels, + output_size[0] // divisor, + output_size[1] // divisor, + output_size[2] // divisor, + ), + device=device, + ) + logger.info(f"noise: {noise.device}, {noise.dtype}, {type(noise)}") + + image = noise + noise_scheduler = define_instance(args, "noise_scheduler") + if isinstance(noise_scheduler, RFlowScheduler): + noise_scheduler.set_timesteps( + num_inference_steps=args.diffusion_unet_inference["num_inference_steps"], + input_img_size_numel=torch.prod(torch.tensor(noise.shape[2:])), + ) + else: + noise_scheduler.set_timesteps(num_inference_steps=args.diffusion_unet_inference["num_inference_steps"]) + + recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device) + autoencoder.eval() + unet.eval() + + all_timesteps = noise_scheduler.timesteps + all_next_timesteps = torch.cat((all_timesteps[1:], torch.tensor([0], dtype=all_timesteps.dtype))) + progress_bar = tqdm( + zip(all_timesteps, all_next_timesteps), + total=min(len(all_timesteps), len(all_next_timesteps)), + ) + with torch.amp.autocast("cuda", enabled=True): + for t, next_t in progress_bar: + # Create a dictionary to store the inputs + unet_inputs = { + "x": image, + "timesteps": torch.Tensor((t,)).to(device), + "spacing_tensor": spacing_tensor, + } + + # Add extra arguments if include_body_region is True + if include_body_region: + unet_inputs.update( + { + "top_region_index_tensor": top_region_index_tensor, + "bottom_region_index_tensor": bottom_region_index_tensor, + } + ) + + if include_modality: + unet_inputs.update( + { + "class_labels": modality_tensor, + } + ) + model_output = unet(**unet_inputs) + if not isinstance(noise_scheduler, RFlowScheduler): + image, _ = noise_scheduler.step(model_output, t, image) # type: ignore + else: + image, _ = noise_scheduler.step(model_output, t, image, next_t) # type: ignore + + inferer = SlidingWindowInferer( + roi_size=[80, 80, 80], + sw_batch_size=1, + progress=True, + mode="gaussian", + overlap=0.4, + sw_device=device, + device=device, + ) + synthetic_images = dynamic_infer(inferer, recon_model, image) + data = synthetic_images.squeeze().cpu().detach().numpy() + a_min, a_max, b_min, b_max = -1000, 1000, 0, 1 + data = (data - b_min) / (b_max - b_min) * (a_max - a_min) + a_min + data = np.clip(data, a_min, a_max) + return np.int16(data) + + +def save_image( + data: np.ndarray, + output_size: tuple, + out_spacing: tuple, + output_path: str, + logger: logging.Logger, +) -> None: + """ + Save the generated synthetic image to a file. + + Args: + data (np.ndarray): Synthetic image data. + output_size (tuple): Output size of the image. + out_spacing (tuple): Spacing of the output image. + output_path (str): Path to save the output image. + logger (logging.Logger): Logger for logging information. + """ + out_affine = np.eye(4) + for i in range(3): + out_affine[i, i] = out_spacing[i] + + new_image = nib.Nifti1Image(data, affine=out_affine) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + nib.save(new_image, output_path) + logger.info(f"Saved {output_path}.") + + +@torch.inference_mode() +def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None: + """ + Main function to run the diffusion model inference. + + Args: + env_config_path (str): Path to the environment configuration file. 
+ model_config_path (str): Path to the model configuration file. + model_def_path (str): Path to the model definition file. + """ + args = load_config(env_config_path, model_config_path, model_def_path) + local_rank, world_size, device = initialize_distributed(num_gpus) + logger = setup_logging("inference") + random_seed = set_random_seed( + args.diffusion_unet_inference["random_seed"] + local_rank + if args.diffusion_unet_inference["random_seed"] + else None + ) + logger.info(f"Using {device} of {world_size} with random seed: {random_seed}") + + output_size = tuple(args.diffusion_unet_inference["dim"]) + out_spacing = tuple(args.diffusion_unet_inference["spacing"]) + output_prefix = args.output_prefix + ckpt_filepath = f"{args.model_dir}/{args.model_filename}" + + if local_rank == 0: + logger.info(f"[config] ckpt_filepath -> {ckpt_filepath}.") + logger.info(f"[config] random_seed -> {random_seed}.") + logger.info(f"[config] output_prefix -> {output_prefix}.") + logger.info(f"[config] output_size -> {output_size}.") + logger.info(f"[config] out_spacing -> {out_spacing}.") + + check_input(None, None, None, output_size, out_spacing, None) + + autoencoder, unet, scale_factor = load_models(args, device, logger) + num_downsample_level = max( + 1, + ( + len(args.diffusion_unet_def["num_channels"]) + if isinstance(args.diffusion_unet_def["num_channels"], list) + else len(args.diffusion_unet_def["attention_levels"]) + ), + ) + divisor = 2 ** (num_downsample_level - 2) + logger.info(f"num_downsample_level -> {num_downsample_level}, divisor -> {divisor}.") + + top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor = prepare_tensors(args, device) + data = run_inference( + args, + device, + autoencoder, + unet, + scale_factor, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, + modality_tensor, + output_size, + divisor, + logger, + ) + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + output_path = "{0}/{1}_seed{2}_size{3:d}x{4:d}x{5:d}_spacing{6:.2f}x{7:.2f}x{8:.2f}_{9}_rank{10}.nii.gz".format( + args.output_dir, + output_prefix, + random_seed, + output_size[0], + output_size[1], + output_size[2], + out_spacing[0], + out_spacing[1], + out_spacing[2], + timestamp, + local_rank, + ) + save_image(data, output_size, out_spacing, output_path, logger) + + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Diffusion Model Inference") + parser.add_argument( + "--env_config", + type=str, + default="./configs/environment_maisi_diff_model_train.json", + help="Path to environment configuration file", + ) + parser.add_argument( + "--model_config", + type=str, + default="./configs/config_maisi_diff_model_train.json", + help="Path to model training/inference configuration", + ) + parser.add_argument( + "--model_def", + type=str, + default="./configs/config_maisi.json", + help="Path to model definition file", + ) + parser.add_argument( + "--num_gpus", + type=int, + default=1, + help="Number of GPUs to use for distributed inference", + ) + + args = parser.parse_args() + diff_model_infer(args.env_config, args.model_config, args.model_def, args.num_gpus) diff --git a/scripts/diff_model_setting.py b/scripts/diff_model_setting.py new file mode 100644 index 0000000000000000000000000000000000000000..cee91965938cae01bc71b9aba1771944130712aa --- /dev/null +++ b/scripts/diff_model_setting.py @@ -0,0 +1,92 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import json +import logging + +import torch +import torch.distributed as dist +from monai.utils import RankFilter + + +def setup_logging(logger_name: str = "") -> logging.Logger: + """ + Setup the logging configuration. + + Args: + logger_name (str): logger name. + + Returns: + logging.Logger: Configured logger. + """ + logger = logging.getLogger(logger_name) + if dist.is_initialized(): + logger.addFilter(RankFilter()) + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + return logger + + +def load_config(env_config_path: str, model_config_path: str, model_def_path: str) -> argparse.Namespace: + """ + Load configuration from JSON files. + + Args: + env_config_path (str): Path to the environment configuration file. + model_config_path (str): Path to the model configuration file. + model_def_path (str): Path to the model definition file. + + Returns: + argparse.Namespace: Loaded configuration. + """ + args = argparse.Namespace() + + with open(env_config_path, "r") as f: + env_config = json.load(f) + for k, v in env_config.items(): + setattr(args, k, v) + + with open(model_config_path, "r") as f: + model_config = json.load(f) + for k, v in model_config.items(): + setattr(args, k, v) + + with open(model_def_path, "r") as f: + model_def = json.load(f) + for k, v in model_def.items(): + setattr(args, k, v) + + return args + + +def initialize_distributed(num_gpus: int) -> tuple: + """ + Initialize distributed training. + + Returns: + tuple: local_rank, world_size, and device. + """ + if torch.cuda.is_available() and num_gpus > 1: + dist.init_process_group(backend="nccl", init_method="env://") + local_rank = dist.get_rank() + world_size = dist.get_world_size() + else: + local_rank = 0 + world_size = 1 + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + return local_rank, world_size, device diff --git a/scripts/diff_model_train.py b/scripts/diff_model_train.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b20aed939835e8104d326b0c4addea82c7946f --- /dev/null +++ b/scripts/diff_model_train.py @@ -0,0 +1,499 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
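+
+# Example launch (a sketch, not the only supported invocation; the paths are this script's
+# argparse defaults, and multi-GPU runs rely on torchrun to provide the env:// rendezvous):
+#   torchrun --nproc_per_node=4 -m scripts.diff_model_train \
+#       --env_config ./configs/environment_maisi_diff_model.json \
+#       --model_config ./configs/config_maisi_diff_model.json \
+#       --model_def ./configs/config_maisi.json \
+#       --num_gpus 4
+# Add --no_amp to disable automatic mixed precision.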
+ +from __future__ import annotations + +import argparse +import json +import logging +import os +from datetime import datetime +from pathlib import Path + +import monai +import torch +import torch.distributed as dist +from monai.data import DataLoader, partition_dataset +from monai.networks.schedulers import RFlowScheduler +from monai.networks.schedulers.ddpm import DDPMPredictionType +from monai.transforms import Compose +from monai.utils import first +from torch.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel + +from .diff_model_setting import initialize_distributed, load_config, setup_logging +from .utils import define_instance + + +def load_filenames(data_list_path: str) -> list: + """ + Load filenames from the JSON data list. + + Args: + data_list_path (str): Path to the JSON data list file. + + Returns: + list: List of filenames. + """ + with open(data_list_path, "r") as file: + json_data = json.load(file) + filenames_train = json_data["training"] + return [_item["image"].replace(".nii.gz", "_emb.nii.gz") for _item in filenames_train] + + +def prepare_data( + train_files: list, + device: torch.device, + cache_rate: float, + num_workers: int = 2, + batch_size: int = 1, + include_body_region: bool = False, +) -> DataLoader: + """ + Prepare training data. + + Args: + train_files (list): List of training files. + device (torch.device): Device to use for training. + cache_rate (float): Cache rate for dataset. + num_workers (int): Number of workers for data loading. + batch_size (int): Mini-batch size. + include_body_region (bool): Whether to include body region in data + + Returns: + DataLoader: Data loader for training. + """ + + def _load_data_from_file(file_path, key): + with open(file_path) as f: + return torch.FloatTensor(json.load(f)[key]) + + train_transforms_list = [ + monai.transforms.LoadImaged(keys=["image"]), + monai.transforms.EnsureChannelFirstd(keys=["image"]), + monai.transforms.Lambdad(keys="spacing", func=lambda x: _load_data_from_file(x, "spacing")), + monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2), + ] + if include_body_region: + train_transforms_list += [ + monai.transforms.Lambdad( + keys="top_region_index", func=lambda x: _load_data_from_file(x, "top_region_index") + ), + monai.transforms.Lambdad( + keys="bottom_region_index", func=lambda x: _load_data_from_file(x, "bottom_region_index") + ), + monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2), + monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2), + ] + train_transforms = Compose(train_transforms_list) + + train_ds = monai.data.CacheDataset( + data=train_files, transform=train_transforms, cache_rate=cache_rate, num_workers=num_workers + ) + + return DataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True) + + +def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> torch.nn.Module: + """ + Load the UNet model. + + Args: + args (argparse.Namespace): Configuration arguments. + device (torch.device): Device to load the model on. + logger (logging.Logger): Logger for logging information. + + Returns: + torch.nn.Module: Loaded UNet model. 
+ """ + unet = define_instance(args, "diffusion_unet_def").to(device) + unet = torch.nn.SyncBatchNorm.convert_sync_batchnorm(unet) + + if dist.is_initialized(): + unet = DistributedDataParallel(unet, device_ids=[device], find_unused_parameters=True) + + if args.existing_ckpt_filepath is None: + logger.info("Training from scratch.") + else: + checkpoint_unet = torch.load(f"{args.existing_ckpt_filepath}", map_location=device, weights_only=False) + if dist.is_initialized(): + unet.module.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True) + else: + unet.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True) + logger.info(f"Pretrained checkpoint {args.existing_ckpt_filepath} loaded.") + + return unet + + +def calculate_scale_factor(train_loader: DataLoader, device: torch.device, logger: logging.Logger) -> torch.Tensor: + """ + Calculate the scaling factor for the dataset. + + Args: + train_loader (DataLoader): Data loader for training. + device (torch.device): Device to use for calculation. + logger (logging.Logger): Logger for logging information. + + Returns: + torch.Tensor: Calculated scaling factor. + """ + check_data = first(train_loader) + z = check_data["image"].to(device) + scale_factor = 1 / torch.std(z) + logger.info(f"Scaling factor set to {scale_factor}.") + + if dist.is_initialized(): + dist.barrier() + dist.all_reduce(scale_factor, op=torch.distributed.ReduceOp.AVG) + logger.info(f"scale_factor -> {scale_factor}.") + return scale_factor + + +def create_optimizer(model: torch.nn.Module, lr: float) -> torch.optim.Optimizer: + """ + Create optimizer for training. + + Args: + model (torch.nn.Module): Model to optimize. + lr (float): Learning rate. + + Returns: + torch.optim.Optimizer: Created optimizer. + """ + return torch.optim.Adam(params=model.parameters(), lr=lr) + + +def create_lr_scheduler(optimizer: torch.optim.Optimizer, total_steps: int) -> torch.optim.lr_scheduler.PolynomialLR: + """ + Create learning rate scheduler. + + Args: + optimizer (torch.optim.Optimizer): Optimizer to schedule. + total_steps (int): Total number of training steps. + + Returns: + torch.optim.lr_scheduler.PolynomialLR: Created learning rate scheduler. + """ + return torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0) + + +def train_one_epoch( + epoch: int, + unet: torch.nn.Module, + train_loader: DataLoader, + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler.PolynomialLR, + loss_pt: torch.nn.L1Loss, + scaler: GradScaler, + scale_factor: torch.Tensor, + noise_scheduler: torch.nn.Module, + num_images_per_batch: int, + num_train_timesteps: int, + device: torch.device, + logger: logging.Logger, + local_rank: int, + amp: bool = True, +) -> torch.Tensor: + """ + Train the model for one epoch. + + Args: + epoch (int): Current epoch number. + unet (torch.nn.Module): UNet model. + train_loader (DataLoader): Data loader for training. + optimizer (torch.optim.Optimizer): Optimizer. + lr_scheduler (torch.optim.lr_scheduler.PolynomialLR): Learning rate scheduler. + loss_pt (torch.nn.L1Loss): Loss function. + scaler (GradScaler): Gradient scaler for mixed precision training. + scale_factor (torch.Tensor): Scaling factor. + noise_scheduler (torch.nn.Module): Noise scheduler. + num_images_per_batch (int): Number of images per batch. + num_train_timesteps (int): Number of training timesteps. + device (torch.device): Device to use for training. + logger (logging.Logger): Logger for logging information. 
+ local_rank (int): Local rank for distributed training. + amp (bool): Use automatic mixed precision training. + + Returns: + torch.Tensor: Training loss for the epoch. + """ + include_body_region = unet.include_top_region_index_input + include_modality = unet.num_class_embeds is not None + + if local_rank == 0: + current_lr = optimizer.param_groups[0]["lr"] + logger.info(f"Epoch {epoch + 1}, lr {current_lr}.") + + _iter = 0 + loss_torch = torch.zeros(2, dtype=torch.float, device=device) + + unet.train() + for train_data in train_loader: + current_lr = optimizer.param_groups[0]["lr"] + + _iter += 1 + images = train_data["image"].to(device) + images = images * scale_factor + + if include_body_region: + top_region_index_tensor = train_data["top_region_index"].to(device) + bottom_region_index_tensor = train_data["bottom_region_index"].to(device) + # We trained with only CT in this version + if include_modality: + modality_tensor = torch.ones((len(images),), dtype=torch.long).to(device) + spacing_tensor = train_data["spacing"].to(device) + + optimizer.zero_grad(set_to_none=True) + + with autocast("cuda", enabled=amp): + noise = torch.randn_like(images) + + if isinstance(noise_scheduler, RFlowScheduler): + timesteps = noise_scheduler.sample_timesteps(images) + else: + timesteps = torch.randint(0, num_train_timesteps, (images.shape[0],), device=images.device).long() + + noisy_latent = noise_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps) + + # Create a dictionary to store the inputs + unet_inputs = { + "x": noisy_latent, + "timesteps": timesteps, + "spacing_tensor": spacing_tensor, + } + # Add extra arguments if include_body_region is True + if include_body_region: + unet_inputs.update( + { + "top_region_index_tensor": top_region_index_tensor, + "bottom_region_index_tensor": bottom_region_index_tensor, + } + ) + if include_modality: + unet_inputs.update( + { + "class_labels": modality_tensor, + } + ) + model_output = unet(**unet_inputs) + + if noise_scheduler.prediction_type == DDPMPredictionType.EPSILON: + # predict noise + model_gt = noise + elif noise_scheduler.prediction_type == DDPMPredictionType.SAMPLE: + # predict sample + model_gt = images + elif noise_scheduler.prediction_type == DDPMPredictionType.V_PREDICTION: + # predict velocity + model_gt = images - noise + else: + raise ValueError( + "noise scheduler prediction type has to be chosen from ", + f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]", + ) + + loss = loss_pt(model_output.float(), model_gt.float()) + + if amp: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + optimizer.step() + + lr_scheduler.step() + + loss_torch[0] += loss.item() + loss_torch[1] += 1.0 + + if local_rank == 0: + logger.info( + "[{0}] epoch {1}, iter {2}/{3}, loss: {4:.4f}, lr: {5:.12f}.".format( + str(datetime.now())[:19], epoch + 1, _iter, len(train_loader), loss.item(), current_lr + ) + ) + + if dist.is_initialized(): + dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM) + + return loss_torch + + +def save_checkpoint( + epoch: int, + unet: torch.nn.Module, + loss_torch_epoch: float, + num_train_timesteps: int, + scale_factor: torch.Tensor, + ckpt_folder: str, + args: argparse.Namespace, +) -> None: + """ + Save checkpoint. + + Args: + epoch (int): Current epoch number. + unet (torch.nn.Module): UNet model. + loss_torch_epoch (float): Training loss for the epoch. 
+ num_train_timesteps (int): Number of training timesteps. + scale_factor (torch.Tensor): Scaling factor. + ckpt_folder (str): Checkpoint folder path. + args (argparse.Namespace): Configuration arguments. + """ + unet_state_dict = unet.module.state_dict() if dist.is_initialized() else unet.state_dict() + torch.save( + { + "epoch": epoch + 1, + "loss": loss_torch_epoch, + "num_train_timesteps": num_train_timesteps, + "scale_factor": scale_factor, + "unet_state_dict": unet_state_dict, + }, + f"{ckpt_folder}/{args.model_filename}", + ) + + +def diff_model_train( + env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True +) -> None: + """ + Main function to train a diffusion model. + + Args: + env_config_path (str): Path to the environment configuration file. + model_config_path (str): Path to the model configuration file. + model_def_path (str): Path to the model definition file. + num_gpus (int): Number of GPUs to use for training. + amp (bool): Use automatic mixed precision training. + """ + args = load_config(env_config_path, model_config_path, model_def_path) + local_rank, world_size, device = initialize_distributed(num_gpus) + logger = setup_logging("training") + + logger.info(f"Using {device} of {world_size}") + + if local_rank == 0: + logger.info(f"[config] ckpt_folder -> {args.model_dir}.") + logger.info(f"[config] data_root -> {args.embedding_base_dir}.") + logger.info(f"[config] data_list -> {args.json_data_list}.") + logger.info(f"[config] lr -> {args.diffusion_unet_train['lr']}.") + logger.info(f"[config] num_epochs -> {args.diffusion_unet_train['n_epochs']}.") + logger.info(f"[config] num_train_timesteps -> {args.noise_scheduler['num_train_timesteps']}.") + + Path(args.model_dir).mkdir(parents=True, exist_ok=True) + + unet = load_unet(args, device, logger) + noise_scheduler = define_instance(args, "noise_scheduler") + include_body_region = unet.include_top_region_index_input + + filenames_train = load_filenames(args.json_data_list) + if local_rank == 0: + logger.info(f"num_files_train: {len(filenames_train)}") + + train_files = [] + for _i in range(len(filenames_train)): + str_img = os.path.join(args.embedding_base_dir, filenames_train[_i]) + if not os.path.exists(str_img): + continue + + str_info = os.path.join(args.embedding_base_dir, filenames_train[_i]) + ".json" + train_files_i = {"image": str_img, "spacing": str_info} + if include_body_region: + train_files_i["top_region_index"] = str_info + train_files_i["bottom_region_index"] = str_info + train_files.append(train_files_i) + if dist.is_initialized(): + train_files = partition_dataset( + data=train_files, shuffle=True, num_partitions=dist.get_world_size(), even_divisible=True + )[local_rank] + + train_loader = prepare_data( + train_files, + device, + args.diffusion_unet_train["cache_rate"], + batch_size=args.diffusion_unet_train["batch_size"], + include_body_region=include_body_region, + ) + + scale_factor = calculate_scale_factor(train_loader, device, logger) + optimizer = create_optimizer(unet, args.diffusion_unet_train["lr"]) + + total_steps = (args.diffusion_unet_train["n_epochs"] * len(train_loader.dataset)) / args.diffusion_unet_train[ + "batch_size" + ] + lr_scheduler = create_lr_scheduler(optimizer, total_steps) + loss_pt = torch.nn.L1Loss() + scaler = GradScaler("cuda") + + torch.set_float32_matmul_precision("highest") + logger.info("torch.set_float32_matmul_precision -> highest.") + + for epoch in range(args.diffusion_unet_train["n_epochs"]): + loss_torch = 
train_one_epoch( + epoch, + unet, + train_loader, + optimizer, + lr_scheduler, + loss_pt, + scaler, + scale_factor, + noise_scheduler, + args.diffusion_unet_train["batch_size"], + args.noise_scheduler["num_train_timesteps"], + device, + logger, + local_rank, + amp=amp, + ) + + loss_torch = loss_torch.tolist() + if torch.cuda.device_count() == 1 or local_rank == 0: + loss_torch_epoch = loss_torch[0] / loss_torch[1] + logger.info(f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}.") + + save_checkpoint( + epoch, + unet, + loss_torch_epoch, + args.noise_scheduler["num_train_timesteps"], + scale_factor, + args.model_dir, + args, + ) + + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Diffusion Model Training") + parser.add_argument( + "--env_config", + type=str, + default="./configs/environment_maisi_diff_model.json", + help="Path to environment configuration file", + ) + parser.add_argument( + "--model_config", + type=str, + default="./configs/config_maisi_diff_model.json", + help="Path to model training/inference configuration", + ) + parser.add_argument( + "--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file" + ) + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training") + parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training") + + args = parser.parse_args() + diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp) diff --git a/scripts/find_masks.py b/scripts/find_masks.py new file mode 100644 index 0000000000000000000000000000000000000000..cd33a257f79c0d30120634768048cd1f957ca71d --- /dev/null +++ b/scripts/find_masks.py @@ -0,0 +1,157 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import os +from typing import Sequence + +from monai.apps.utils import extractall +from monai.utils import ensure_tuple_rep + + +def convert_body_region(body_region: str | Sequence[str]) -> Sequence[int]: + """ + Convert body region string to body region index. + Args: + body_region: list of input body region string. If single str, will be converted to list of str. + Return: + body_region_indices, list of input body region index. 
+ """ + if type(body_region) is str: + body_region = [body_region] + + # body region mapping for maisi + region_mapping_maisi = { + "head": 0, + "chest": 1, + "thorax": 1, + "chest/thorax": 1, + "abdomen": 2, + "pelvis": 3, + "lower": 3, + "pelvis/lower": 3, + } + + # perform mapping + body_region_indices = [] + for region in body_region: + normalized_region = region.lower() # norm str to lower case + if normalized_region not in region_mapping_maisi: + raise ValueError(f"Invalid region: {normalized_region}") + body_region_indices.append(region_mapping_maisi[normalized_region]) + + return body_region_indices + + +def find_masks( + body_region: str | Sequence[str], + anatomy_list: int | Sequence[int], + spacing: Sequence[float] | float = 1.0, + output_size: Sequence[int] = [512, 512, 512], + check_spacing_and_output_size: bool = False, + database_filepath: str = "./configs/database.json", + mask_foldername: str = "./datasets/masks/", +): + """ + Find candidate masks that fullfills all the requirements. + They shoud contain all the body region in `body_region`, all the anatomies in `anatomy_list`. + If there is no tumor specified in `anatomy_list`, we also expect the candidate masks to be tumor free. + If check_spacing_and_output_size is True, the candidate masks need to have the expected `spacing` and `output_size`. + Args: + body_region: list of input body region string. If single str, will be converted to list of str. + The found candidate mask will include these body regions. + anatomy_list: list of input anatomy. The found candidate mask will include these anatomies. + spacing: list of three floats, voxel spacing. If providing a single number, will use it for all the three dimensions. + output_size: list of three int, expected candidate mask spatial size. + check_spacing_and_output_size: whether we expect candidate mask to have spatial size of `output_size` and voxel size of `spacing`. + database_filepath: path for the json file that stores the information of all the candidate masks. + mask_foldername: directory that saves all the candidate masks. + Return: + candidate_masks, list of dict, each dict contains information of one candidate mask that fullfills all the requirements. + """ + # check and preprocess input + body_region = convert_body_region(body_region) + + if isinstance(anatomy_list, int): + anatomy_list = [anatomy_list] + + spacing = ensure_tuple_rep(spacing, 3) + + if not os.path.exists(mask_foldername): + zip_file_path = mask_foldername + ".zip" + + if not os.path.isfile(zip_file_path): + raise ValueError(f"Please download {zip_file_path} following the instruction in ./datasets/README.md.") + + print(f"Extracting {zip_file_path} to {os.path.dirname(zip_file_path)}") + extractall(filepath=zip_file_path, output_dir=os.path.dirname(zip_file_path), file_type="zip") + print(f"Unzipped {zip_file_path} to {mask_foldername}.") + + if not os.path.isfile(database_filepath): + raise ValueError(f"Please download {database_filepath} following the instruction in ./datasets/README.md.") + with open(database_filepath, "r") as f: + db = json.load(f) + + # select candidate_masks + candidate_masks = [] + for _item in db: + if not set(anatomy_list).issubset(_item["label_list"]): + continue + + # whether to keep this mask, default to be True. 
+ keep_mask = True + + # extract region indice (top_index and bottom_index) for candidate mask + include_body_region = "top_region_index" in _item.keys() + if include_body_region: + top_index = [index for index, element in enumerate(_item["top_region_index"]) if element != 0] + top_index = top_index[0] + bottom_index = [index for index, element in enumerate(_item["bottom_region_index"]) if element != 0] + bottom_index = bottom_index[0] + + # if candiate mask does not contain all the body_region, skip it + for _idx in body_region: + if _idx > bottom_index or _idx < top_index: + keep_mask = False + + for tumor_label in [23, 24, 26, 27, 128]: + # we skip those mask with tumors if users do not provide tumor label in anatomy_list + if tumor_label not in anatomy_list and tumor_label in _item["label_list"]: + keep_mask = False + + if check_spacing_and_output_size: + # if the output_size and spacing are different with user's input, skip it + for axis in range(3): + if _item["dim"][axis] != output_size[axis] or _item["spacing"][axis] != spacing[axis]: + keep_mask = False + + if keep_mask: + # if decide to keep this mask, we pack the information of this mask and add to final output. + candidate = { + "pseudo_label": os.path.join(mask_foldername, _item["pseudo_label_filename"]), + "spacing": _item["spacing"], + "dim": _item["dim"], + } + if include_body_region: + candidate["top_region_index"] = _item["top_region_index"] + candidate["bottom_region_index"] = _item["bottom_region_index"] + + # Conditionally add the label to the candidate dictionary + if "label_filename" in _item: + candidate["label"] = os.path.join(mask_foldername, _item["label_filename"]) + + candidate_masks.append(candidate) + + if len(candidate_masks) == 0 and not check_spacing_and_output_size: + raise ValueError("Cannot find body region with given anatomy list.") + + return candidate_masks diff --git a/scripts/infer_controlnet.py b/scripts/infer_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f85c41b180afc5a5888d00bb49183e06e3e2e384 --- /dev/null +++ b/scripts/infer_controlnet.py @@ -0,0 +1,222 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
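+
+# Example launch (illustrative; the paths are simply this script's argparse defaults, and
+# multi-GPU runs expect LOCAL_RANK/WORLD_SIZE to be set, e.g. by torchrun):
+#   torchrun --nproc_per_node=2 -m scripts.infer_controlnet \
+#       -e ./configs/environment_maisi_controlnet_train.json \
+#       -c ./configs/config_maisi.json \
+#       -t ./configs/config_maisi_controlnet_train.json \
+#       -g 2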
+ +import argparse +import json +import logging +import os +import sys +from datetime import datetime + +import torch +import torch.distributed as dist +from monai.data import MetaTensor, decollate_batch +from monai.networks.utils import copy_model_state +from monai.transforms import SaveImage +from monai.utils import RankFilter + +from .sample import check_input, ldm_conditional_sample_one_image +from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp + + +@torch.inference_mode() +def main(): + parser = argparse.ArgumentParser(description="maisi.controlnet.infer") + parser.add_argument( + "-e", + "--environment-file", + default="./configs/environment_maisi_controlnet_train.json", + help="environment json file that stores environment path", + ) + parser.add_argument( + "-c", + "--config-file", + default="./configs/config_maisi.json", + help="config json file that stores network hyper-parameters", + ) + parser.add_argument( + "-t", + "--training-config", + default="./configs/config_maisi_controlnet_train.json", + help="config json file that stores training hyper-parameters", + ) + parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node") + + args = parser.parse_args() + + # Step 0: configuration + logger = logging.getLogger("maisi.controlnet.infer") + # whether to use distributed data parallel + use_ddp = args.gpus > 1 + if use_ddp: + rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + device = setup_ddp(rank, world_size) + logger.addFilter(RankFilter()) + else: + rank = 0 + world_size = 1 + device = torch.device(f"cuda:{rank}") + + torch.cuda.set_device(device) + logger.info(f"Number of GPUs: {torch.cuda.device_count()}") + logger.info(f"World_size: {world_size}") + + with open(args.environment_file, "r") as env_file: + env_dict = json.load(env_file) + with open(args.config_file, "r") as config_file: + config_dict = json.load(config_file) + with open(args.training_config, "r") as training_config_file: + training_config_dict = json.load(training_config_file) + + for k, v in env_dict.items(): + setattr(args, k, v) + for k, v in config_dict.items(): + setattr(args, k, v) + for k, v in training_config_dict.items(): + setattr(args, k, v) + + # Step 1: set data loader + _, val_loader = prepare_maisi_controlnet_json_dataloader( + json_data_list=args.json_data_list, + data_base_dir=args.data_base_dir, + rank=rank, + world_size=world_size, + batch_size=args.controlnet_train["batch_size"], + cache_rate=args.controlnet_train["cache_rate"], + fold=args.controlnet_train["fold"], + ) + + # Step 2: define AE, diffusion model and controlnet + # define AE + autoencoder = define_instance(args, "autoencoder_def").to(device) + # load trained autoencoder model + if args.trained_autoencoder_path is not None: + if not os.path.exists(args.trained_autoencoder_path): + raise ValueError("Please download the autoencoder checkpoint.") + autoencoder_ckpt = torch.load(args.trained_autoencoder_path, weights_only=True) + autoencoder.load_state_dict(autoencoder_ckpt) + logger.info(f"Load trained diffusion model from {args.trained_autoencoder_path}.") + else: + logger.info("trained autoencoder model is not loaded.") + + # define diffusion Model + unet = define_instance(args, "diffusion_unet_def").to(device) + include_body_region = unet.include_top_region_index_input + include_modality = unet.num_class_embeds is not None + + # load trained diffusion model + if args.trained_diffusion_path is not None: + if not 
os.path.exists(args.trained_diffusion_path): + raise ValueError("Please download the trained diffusion unet checkpoint.") + diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device, weights_only=False) + unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"]) + # load scale factor from diffusion model checkpoint + scale_factor = diffusion_model_ckpt["scale_factor"] + logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.") + logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.") + else: + logger.info("trained diffusion model is not loaded.") + scale_factor = 1.0 + logger.info(f"set scale_factor -> {scale_factor}.") + + # define ControlNet + controlnet = define_instance(args, "controlnet_def").to(device) + # copy weights from the DM to the controlnet + copy_model_state(controlnet, unet.state_dict()) + # load trained controlnet model if it is provided + if args.trained_controlnet_path is not None: + if not os.path.exists(args.trained_controlnet_path): + raise ValueError("Please download the trained ControlNet checkpoint.") + controlnet.load_state_dict( + torch.load(args.trained_controlnet_path, map_location=device, weights_only=False)["controlnet_state_dict"] + ) + logger.info(f"load trained controlnet model from {args.trained_controlnet_path}") + else: + logger.info("trained controlnet is not loaded.") + + noise_scheduler = define_instance(args, "noise_scheduler") + + # Step 3: inference + autoencoder.eval() + controlnet.eval() + unet.eval() + + for batch in val_loader: + + # get label mask + labels = batch["label"].to(device) + # get corresponding conditions + if include_body_region: + top_region_index_tensor = batch["top_region_index"].to(device) + bottom_region_index_tensor = batch["bottom_region_index"].to(device) + else: + top_region_index_tensor = None + bottom_region_index_tensor = None + spacing_tensor = batch["spacing"].to(device) + modality_tensor = args.controlnet_infer["modality"] * torch.ones((len(labels),), dtype=torch.long).to(device) + out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist()) + # get target dimension + dim = batch["dim"] + output_size = (dim[0].item(), dim[1].item(), dim[2].item()) + latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4) + # check if output_size and out_spacing are valid. + check_input(None, None, None, output_size, out_spacing, None) + # generate a single synthetic image using a latent diffusion model with controlnet. 
+ synthetic_images, _ = ldm_conditional_sample_one_image( + autoencoder=autoencoder, + diffusion_unet=unet, + controlnet=controlnet, + noise_scheduler=noise_scheduler, + scale_factor=scale_factor, + device=device, + combine_label_or=labels, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + modality_tensor=modality_tensor, + latent_shape=latent_shape, + output_size=output_size, + noise_factor=1.0, + num_inference_steps=args.controlnet_infer["num_inference_steps"], + autoencoder_sliding_window_infer_size=args.controlnet_infer["autoencoder_sliding_window_infer_size"], + autoencoder_sliding_window_infer_overlap=args.controlnet_infer["autoencoder_sliding_window_infer_overlap"], + ) + # save image/label pairs + labels = decollate_batch(batch)[0]["label"] + real_object_name = labels.meta.get("filename_or_obj", "default_name.nii.gz") + labels.meta["filename_or_obj"] = real_object_name + output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + synthetic_images = MetaTensor(synthetic_images.squeeze(0), meta=labels.meta) + img_saver = SaveImage( + output_dir=args.output_dir, + output_postfix="image", + separate_folder=False, + ) + img_saver(synthetic_images) + label_saver = SaveImage( + output_dir=args.output_dir, + output_postfix="label", + separate_folder=False, + ) + label_saver(labels) + if use_ddp: + dist.destroy_process_group() + + +if __name__ == "__main__": + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + main() diff --git a/scripts/infer_testV2_controlnet.py b/scripts/infer_testV2_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..610a7013050bfe2fe9f154934b265ddbaa01c7e2 --- /dev/null +++ b/scripts/infer_testV2_controlnet.py @@ -0,0 +1,220 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
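+
+# Example single-GPU launch (illustrative; same arguments as infer_controlnet.py, but this
+# variant reads its data through prepare_maisi_controlnet_test_dataloader):
+#   python -m scripts.infer_testV2_controlnet \
+#       -e ./configs/environment_maisi_controlnet_train.json \
+#       -c ./configs/config_maisi.json \
+#       -t ./configs/config_maisi_controlnet_train.json \
+#       -g 1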
+ +import argparse +import json +import logging +import os +import sys +from datetime import datetime + +import torch +import torch.distributed as dist +from monai.data import MetaTensor, decollate_batch +from monai.networks.utils import copy_model_state +from monai.transforms import SaveImage +from monai.utils import RankFilter + +from .sample import check_input, ldm_conditional_sample_one_image +from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp, prepare_maisi_controlnet_test_dataloader + + +@torch.inference_mode() +def main(): + parser = argparse.ArgumentParser(description="maisi.controlnet.infer") + parser.add_argument( + "-e", + "--environment-file", + default="./configs/environment_maisi_controlnet_train.json", + help="environment json file that stores environment path", + ) + parser.add_argument( + "-c", + "--config-file", + default="./configs/config_maisi.json", + help="config json file that stores network hyper-parameters", + ) + parser.add_argument( + "-t", + "--training-config", + default="./configs/config_maisi_controlnet_train.json", + help="config json file that stores training hyper-parameters", + ) + parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node") + + args = parser.parse_args() + + # Step 0: configuration + logger = logging.getLogger("maisi.controlnet.infer") + # whether to use distributed data parallel + use_ddp = args.gpus > 1 + if use_ddp: + rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + device = setup_ddp(rank, world_size) + logger.addFilter(RankFilter()) + else: + rank = 0 + world_size = 1 + device = torch.device(f"cuda:{rank}") + + torch.cuda.set_device(device) + logger.info(f"Number of GPUs: {torch.cuda.device_count()}") + logger.info(f"World_size: {world_size}") + + with open(args.environment_file, "r") as env_file: + env_dict = json.load(env_file) + with open(args.config_file, "r") as config_file: + config_dict = json.load(config_file) + with open(args.training_config, "r") as training_config_file: + training_config_dict = json.load(training_config_file) + + for k, v in env_dict.items(): + setattr(args, k, v) + for k, v in config_dict.items(): + setattr(args, k, v) + for k, v in training_config_dict.items(): + setattr(args, k, v) + + # Step 1: set data loader + val_loader = prepare_maisi_controlnet_test_dataloader( + json_data_list=args.json_data_list, + data_base_dir=args.data_base_dir, + batch_size=args.controlnet_train["batch_size"], + cache_rate=args.controlnet_train["cache_rate"], + rank=rank, + world_size=world_size,) + + # Step 2: define AE, diffusion model and controlnet + # define AE + autoencoder = define_instance(args, "autoencoder_def").to(device) + # load trained autoencoder model + if args.trained_autoencoder_path is not None: + if not os.path.exists(args.trained_autoencoder_path): + raise ValueError("Please download the autoencoder checkpoint.") + autoencoder_ckpt = torch.load(args.trained_autoencoder_path, weights_only=True) + autoencoder.load_state_dict(autoencoder_ckpt) + logger.info(f"Load trained diffusion model from {args.trained_autoencoder_path}.") + else: + logger.info("trained autoencoder model is not loaded.") + + # define diffusion Model + unet = define_instance(args, "diffusion_unet_def").to(device) + include_body_region = unet.include_top_region_index_input + include_modality = unet.num_class_embeds is not None + + # load trained diffusion model + if args.trained_diffusion_path is not None: + if not 
os.path.exists(args.trained_diffusion_path): + raise ValueError("Please download the trained diffusion unet checkpoint.") + diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device, weights_only=False) + unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"]) + # load scale factor from diffusion model checkpoint + scale_factor = diffusion_model_ckpt["scale_factor"] + logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.") + logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.") + else: + logger.info("trained diffusion model is not loaded.") + scale_factor = 1.0 + logger.info(f"set scale_factor -> {scale_factor}.") + + # define ControlNet + controlnet = define_instance(args, "controlnet_def").to(device) + # copy weights from the DM to the controlnet + copy_model_state(controlnet, unet.state_dict()) + # load trained controlnet model if it is provided + if args.trained_controlnet_path is not None: + if not os.path.exists(args.trained_controlnet_path): + raise ValueError("Please download the trained ControlNet checkpoint.") + controlnet.load_state_dict( + torch.load(args.trained_controlnet_path, map_location=device, weights_only=False)["controlnet_state_dict"] + ) + logger.info(f"load trained controlnet model from {args.trained_controlnet_path}") + else: + logger.info("trained controlnet is not loaded.") + + noise_scheduler = define_instance(args, "noise_scheduler") + + # Step 3: inference + autoencoder.eval() + controlnet.eval() + unet.eval() + + for batch in val_loader: + + # get label mask + labels = batch["label"].to(device) + # get corresponding conditions + if include_body_region: + top_region_index_tensor = batch["top_region_index"].to(device) + bottom_region_index_tensor = batch["bottom_region_index"].to(device) + else: + top_region_index_tensor = None + bottom_region_index_tensor = None + spacing_tensor = batch["spacing"].to(device) + modality_tensor = args.controlnet_infer["modality"] * torch.ones((len(labels),), dtype=torch.long).to(device) + out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist()) + # get target dimension + dim = batch["dim"] + output_size = (dim[0].item(), dim[1].item(), dim[2].item()) + latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4) + # check if output_size and out_spacing are valid. + check_input(None, None, None, output_size, out_spacing, None) + # generate a single synthetic image using a latent diffusion model with controlnet. 
+ synthetic_images, _ = ldm_conditional_sample_one_image( + autoencoder=autoencoder, + diffusion_unet=unet, + controlnet=controlnet, + noise_scheduler=noise_scheduler, + scale_factor=scale_factor, + device=device, + combine_label_or=labels, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + modality_tensor=modality_tensor, + latent_shape=latent_shape, + output_size=output_size, + noise_factor=1.0, + num_inference_steps=args.controlnet_infer["num_inference_steps"], + autoencoder_sliding_window_infer_size=args.controlnet_infer["autoencoder_sliding_window_infer_size"], + autoencoder_sliding_window_infer_overlap=args.controlnet_infer["autoencoder_sliding_window_infer_overlap"], + ) + # save image/label pairs + labels = decollate_batch(batch)[0]["label"] + real_object_name = labels.meta.get("filename_or_obj", "default_name.nii.gz") + labels.meta["filename_or_obj"] = real_object_name + output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + synthetic_images = MetaTensor(synthetic_images.squeeze(0), meta=labels.meta) + img_saver = SaveImage( + output_dir=args.output_dir, + output_postfix="image", + separate_folder=False, + ) + img_saver(synthetic_images) + label_saver = SaveImage( + output_dir=args.output_dir, + output_postfix="label", + separate_folder=False, + ) + label_saver(labels) + if use_ddp: + dist.destroy_process_group() + + +if __name__ == "__main__": + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + main() diff --git a/scripts/infer_test_controlnet.py b/scripts/infer_test_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..738669431ea9599d4d16f6bb6427442f1688976a --- /dev/null +++ b/scripts/infer_test_controlnet.py @@ -0,0 +1,220 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
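+
+# Example single-GPU launch (illustrative; mirrors infer_controlnet.py but builds its
+# loader with prepare_maisi_controlnet_infer_dataloader):
+#   python -m scripts.infer_test_controlnet \
+#       -e ./configs/environment_maisi_controlnet_train.json \
+#       -c ./configs/config_maisi.json \
+#       -t ./configs/config_maisi_controlnet_train.json \
+#       -g 1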
+
+import argparse
+import json
+import logging
+import os
+import sys
+from datetime import datetime
+
+import torch
+import torch.distributed as dist
+from monai.data import MetaTensor, decollate_batch
+from monai.networks.utils import copy_model_state
+from monai.transforms import SaveImage
+from monai.utils import RankFilter
+
+from .sample import check_input, ldm_conditional_sample_one_image
+from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp, prepare_maisi_controlnet_infer_dataloader
+
+
+@torch.inference_mode()
+def main():
+    parser = argparse.ArgumentParser(description="maisi.controlnet.infer")
+    parser.add_argument(
+        "-e",
+        "--environment-file",
+        default="./configs/environment_maisi_controlnet_train.json",
+        help="environment json file that stores environment path",
+    )
+    parser.add_argument(
+        "-c",
+        "--config-file",
+        default="./configs/config_maisi.json",
+        help="config json file that stores network hyper-parameters",
+    )
+    parser.add_argument(
+        "-t",
+        "--training-config",
+        default="./configs/config_maisi_controlnet_train.json",
+        help="config json file that stores training hyper-parameters",
+    )
+    parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
+
+    args = parser.parse_args()
+
+    # Step 0: configuration
+    logger = logging.getLogger("maisi.controlnet.infer")
+    # whether to use distributed data parallel
+    use_ddp = args.gpus > 1
+    if use_ddp:
+        rank = int(os.environ["LOCAL_RANK"])
+        world_size = int(os.environ["WORLD_SIZE"])
+        device = setup_ddp(rank, world_size)
+        logger.addFilter(RankFilter())
+    else:
+        rank = 0
+        world_size = 1
+        device = torch.device(f"cuda:{rank}")
+
+    torch.cuda.set_device(device)
+    logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
+    logger.info(f"World_size: {world_size}")
+
+    with open(args.environment_file, "r") as env_file:
+        env_dict = json.load(env_file)
+    with open(args.config_file, "r") as config_file:
+        config_dict = json.load(config_file)
+    with open(args.training_config, "r") as training_config_file:
+        training_config_dict = json.load(training_config_file)
+
+    for k, v in env_dict.items():
+        setattr(args, k, v)
+    for k, v in config_dict.items():
+        setattr(args, k, v)
+    for k, v in training_config_dict.items():
+        setattr(args, k, v)
+
+    # Step 1: set data loader
+    val_loader = prepare_maisi_controlnet_infer_dataloader(
+        json_data_list=args.json_data_list,
+        data_base_dir=args.data_base_dir,
+        batch_size=args.controlnet_train["batch_size"],
+        cache_rate=args.controlnet_train["cache_rate"],
+        rank=rank,
+        world_size=world_size,
+    )
+
+    # Step 2: define AE, diffusion model and controlnet
+    # define AE
+    autoencoder = define_instance(args, "autoencoder_def").to(device)
+    # load trained autoencoder model
+    if args.trained_autoencoder_path is not None:
+        if not os.path.exists(args.trained_autoencoder_path):
+            raise ValueError("Please download the autoencoder checkpoint.")
+        autoencoder_ckpt = torch.load(args.trained_autoencoder_path, weights_only=True)
+        autoencoder.load_state_dict(autoencoder_ckpt)
+        logger.info(f"Load trained autoencoder model from {args.trained_autoencoder_path}.")
+    else:
+        logger.info("trained autoencoder model is not loaded.")
+
+    # define diffusion model
+    unet = define_instance(args, "diffusion_unet_def").to(device)
+    include_body_region = unet.include_top_region_index_input
+    include_modality = unet.num_class_embeds is not None
+
+    # load trained diffusion model
+    if args.trained_diffusion_path is not None:
+        if not
os.path.exists(args.trained_diffusion_path): + raise ValueError("Please download the trained diffusion unet checkpoint.") + diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device, weights_only=False) + unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"]) + # load scale factor from diffusion model checkpoint + scale_factor = diffusion_model_ckpt["scale_factor"] + logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.") + logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.") + else: + logger.info("trained diffusion model is not loaded.") + scale_factor = 1.0 + logger.info(f"set scale_factor -> {scale_factor}.") + + # define ControlNet + controlnet = define_instance(args, "controlnet_def").to(device) + # copy weights from the DM to the controlnet + copy_model_state(controlnet, unet.state_dict()) + # load trained controlnet model if it is provided + if args.trained_controlnet_path is not None: + if not os.path.exists(args.trained_controlnet_path): + raise ValueError("Please download the trained ControlNet checkpoint.") + controlnet.load_state_dict( + torch.load(args.trained_controlnet_path, map_location=device, weights_only=False)["controlnet_state_dict"] + ) + logger.info(f"load trained controlnet model from {args.trained_controlnet_path}") + else: + logger.info("trained controlnet is not loaded.") + + noise_scheduler = define_instance(args, "noise_scheduler") + + # Step 3: inference + autoencoder.eval() + controlnet.eval() + unet.eval() + + for batch in val_loader: + + # get label mask + labels = batch["label"].to(device) + # get corresponding conditions + if include_body_region: + top_region_index_tensor = batch["top_region_index"].to(device) + bottom_region_index_tensor = batch["bottom_region_index"].to(device) + else: + top_region_index_tensor = None + bottom_region_index_tensor = None + spacing_tensor = batch["spacing"].to(device) + modality_tensor = args.controlnet_infer["modality"] * torch.ones((len(labels),), dtype=torch.long).to(device) + out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist()) + # get target dimension + dim = batch["dim"] + output_size = (dim[0].item(), dim[1].item(), dim[2].item()) + latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4) + # check if output_size and out_spacing are valid. + check_input(None, None, None, output_size, out_spacing, None) + # generate a single synthetic image using a latent diffusion model with controlnet. 
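+        # Illustration (assumed numbers, not taken from any config): with latent_channels = 4 and
+        # a target volume of 512 x 512 x 256 voxels, latent_shape above becomes (4, 128, 128, 64),
+        # i.e. the 4x-downsampled volume the diffusion UNet denoises; the autoencoder then decodes
+        # it back to full resolution via sliding-window inference.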
+ synthetic_images, _ = ldm_conditional_sample_one_image( + autoencoder=autoencoder, + diffusion_unet=unet, + controlnet=controlnet, + noise_scheduler=noise_scheduler, + scale_factor=scale_factor, + device=device, + combine_label_or=labels, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + modality_tensor=modality_tensor, + latent_shape=latent_shape, + output_size=output_size, + noise_factor=1.0, + num_inference_steps=args.controlnet_infer["num_inference_steps"], + autoencoder_sliding_window_infer_size=args.controlnet_infer["autoencoder_sliding_window_infer_size"], + autoencoder_sliding_window_infer_overlap=args.controlnet_infer["autoencoder_sliding_window_infer_overlap"], + ) + # save image/label pairs + labels = decollate_batch(batch)[0]["label"] + real_object_name = labels.meta.get("filename_or_obj", "default_name.nii.gz") + labels.meta["filename_or_obj"] = real_object_name + output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + synthetic_images = MetaTensor(synthetic_images.squeeze(0), meta=labels.meta) + img_saver = SaveImage( + output_dir=args.output_dir, + output_postfix="image", + separate_folder=False, + ) + img_saver(synthetic_images) + label_saver = SaveImage( + output_dir=args.output_dir, + output_postfix="label", + separate_folder=False, + ) + label_saver(labels) + if use_ddp: + dist.destroy_process_group() + + +if __name__ == "__main__": + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + main() diff --git a/scripts/inference.py b/scripts/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..6d12a7058c3c10d69b4e4b9fe46794d57ff076c4 --- /dev/null +++ b/scripts/inference.py @@ -0,0 +1,299 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +#     http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
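+#
+# Example invocation (illustrative only; the argparse defaults below are used for any flag
+# that is omitted). Run from the repository root so that the `scripts` package resolves:
+#   python -m scripts.inference -e ./configs/environment.json -c ./configs/config_maisi.json \
+#       -i ./configs/config_infer.json --version maisi3d-rflow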
+# +# # MAISI Inference Script +import argparse +import json +import logging +import os +import sys +import tempfile + +import monai +import torch +from monai.apps import download_url +from monai.config import print_config +from monai.transforms import LoadImage, Orientation +from monai.utils import set_determinism + +from scripts.sample import LDMSampler, check_input +from scripts.utils import define_instance +from scripts.utils_plot import find_label_center_loc, get_xyz_plot, show_image + + +def main(): + parser = argparse.ArgumentParser(description="maisi.controlnet.training") + parser.add_argument( + "-e", + "--environment-file", + default="./configs/environment.json", + help="environment json file that stores environment path", + ) + parser.add_argument( + "-c", + "--config-file", + default="./configs/config_maisi.json", + help="config json file that stores network hyper-parameters", + ) + parser.add_argument( + "-i", + "--inference-file", + default="./configs/config_infer.json", + help="config json file that stores inference hyper-parameters", + ) + parser.add_argument( + "-x", + "--extra-config-file", + default=None, + help="config json file that stores inference extra parameters", + ) + parser.add_argument( + "-s", + "--random-seed", + default=None, + help="random seed, can be None or int", + ) + parser.add_argument( + "--version", + default="maisi3d-rflow", + type=str, + help="maisi_version, choose from ['maisi3d-ddpm', 'maisi3d-rflow']", + ) + args = parser.parse_args() + # Step 0: configuration + logger = logging.getLogger("maisi.inference") + + maisi_version = args.version + + # ## Set deterministic training for reproducibility + if args.random_seed is not None: + set_determinism(seed=args.random_seed) + + # ## Setup data directory + # You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable. + # This allows you to save results and reuse downloads. + # If not specified a temporary directory will be used. 
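+    # For example (illustrative only), launching the script as
+    #   MONAI_DATA_DIRECTORY=/data/maisi python -m scripts.inference ...
+    # keeps the downloaded checkpoints and mask archives under /data/maisi between runs;
+    # when the variable is unset, a fresh tempfile.mkdtemp() directory is used instead.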
+ + directory = os.environ.get("MONAI_DATA_DIRECTORY") + if directory is not None: + os.makedirs(directory, exist_ok=True) + root_dir = tempfile.mkdtemp() if directory is None else directory + print(root_dir) + + # TODO: remove the `files` after the files are uploaded to the NGC + files = [ + { + "path": "models/autoencoder_epoch273.pt", + "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials" + "/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt", + }, + { + "path": "models/mask_generation_autoencoder.pt", + "url": "https://developer.download.nvidia.com/assets/Clara/monai" + "/tutorials/mask_generation_autoencoder.pt", + }, + { + "path": "models/mask_generation_diffusion_unet.pt", + "url": "https://developer.download.nvidia.com/assets/Clara/monai" + "/tutorials/model_zoo/model_maisi_mask_generation_diffusion_unet_v2.pt", + }, + { + "path": "configs/all_anatomy_size_condtions.json", + "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/all_anatomy_size_condtions.json", + }, + { + "path": "datasets/all_masks_flexible_size_and_spacing_4000.zip", + "url": "https://developer.download.nvidia.com/assets/Clara/monai" + "/tutorials/all_masks_flexible_size_and_spacing_4000.zip", + }, + ] + + if maisi_version == "maisi3d-ddpm": + files += [ + { + "path": "models/diff_unet_3d_ddpm.pt", + "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo" + "/model_maisi_input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1_alternative.pt", + }, + { + "path": "models/controlnet_3d_ddpm.pt", + "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo" + "/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt", + }, + { + "path": "configs/candidate_masks_flexible_size_and_spacing_3000.json", + "url": "https://developer.download.nvidia.com/assets/Clara/monai" + "/tutorials/candidate_masks_flexible_size_and_spacing_3000.json", + }, + ] + elif maisi_version == "maisi3d-rflow": + files += [ + { + "path": "models/diff_unet_3d_rflow.pt", + "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/" + "diff_unet_ckpt_rflow_epoch19350.pt", + }, + { + "path": "models/controlnet_3d_rflow.pt", + "url": "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/" + "controlnet_rflow_epoch60.pt", + }, + { + "path": "configs/candidate_masks_flexible_size_and_spacing_4000.json", + "url": "https://developer.download.nvidia.com/assets/Clara/monai" + "/tutorials/candidate_masks_flexible_size_and_spacing_4000.json", + }, + ] + else: + raise ValueError( + f"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}." 
+ ) + + for file in files: + file["path"] = file["path"] if "datasets/" not in file["path"] else os.path.join(root_dir, file["path"]) + download_url(url=file["url"], filepath=file["path"]) + + # ## Read in environment setting, including data directory, model directory, and output directory + # The information for data directory, model directory, and output directory are saved in ./configs/environment.json + env_dict = json.load(open(args.environment_file, "r")) + for k, v in env_dict.items(): + # Update the path to the downloaded dataset in MONAI_DATA_DIRECTORY + val = v if "datasets/" not in v else os.path.join(root_dir, v) + setattr(args, k, val) + print(f"{k}: {val}") + print("Global config variables have been loaded.") + + # ## Read in configuration setting, including network definition, body region and anatomy to generate, etc. + # + # The information for the inference input, like body region and anatomy to generate, is stored in "./configs/config_infer.json". + # Please refer to README.md for the details. + config_dict = json.load(open(args.config_file, "r")) + for k, v in config_dict.items(): + setattr(args, k, v) + + # check the format of inference inputs + config_infer_dict = json.load(open(args.inference_file, "r")) + # override num_split if asked + if "autoencoder_tp_num_splits" in config_infer_dict: + args.autoencoder_def["num_splits"] = config_infer_dict["autoencoder_tp_num_splits"] + args.mask_generation_autoencoder_def["num_splits"] = config_infer_dict["autoencoder_tp_num_splits"] + for k, v in config_infer_dict.items(): + setattr(args, k, v) + print(f"{k}: {v}") + + # + # ## Read in optional extra configuration setting - typically acceleration options (TRT) + # + # + if args.extra_config_file is not None: + extra_config_dict = json.load(open(args.extra_config_file, "r")) + for k, v in extra_config_dict.items(): + setattr(args, k, v) + print(f"{k}: {v}") + + check_input( + args.body_region, + args.anatomy_list, + args.label_dict_json, + args.output_size, + args.spacing, + args.controllable_anatomy_size, + ) + latent_shape = [args.latent_channels, args.output_size[0] // 4, args.output_size[1] // 4, args.output_size[2] // 4] + print("Network definition and inference inputs have been loaded.") + + # ## Initialize networks and noise scheduler, then load the trained model weights. + # The networks and noise scheduler are defined in `config_file`. We will read them in and load the model weights. 
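+    # define_instance(args, "<name>") is assumed to instantiate the "_target_" class of the
+    # matching config block with the remaining keys as keyword arguments, e.g. a block like
+    #   {"_target_": "pkg.mod.SomeScheduler", "num_train_timesteps": 1000}
+    # becomes pkg.mod.SomeScheduler(num_train_timesteps=1000); see scripts/utils.py for the
+    # actual resolution logic.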
+ noise_scheduler = define_instance(args, "noise_scheduler") + mask_generation_noise_scheduler = define_instance(args, "mask_generation_noise_scheduler") + + device = torch.device("cuda") + + autoencoder = define_instance(args, "autoencoder").to(device) + checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True) + autoencoder.load_state_dict(checkpoint_autoencoder) + + diffusion_unet = define_instance(args, "diffusion_unet").to(device) + checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path, weights_only=False) + diffusion_unet.load_state_dict(checkpoint_diffusion_unet["unet_state_dict"], strict=True) + scale_factor = checkpoint_diffusion_unet["scale_factor"].to(device) + + controlnet = define_instance(args, "controlnet").to(device) + checkpoint_controlnet = torch.load(args.trained_controlnet_path, weights_only=False) + monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict()) + controlnet.load_state_dict(checkpoint_controlnet["controlnet_state_dict"], strict=True) + + mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder").to(device) + checkpoint_mask_generation_autoencoder = torch.load( + args.trained_mask_generation_autoencoder_path, weights_only=True + ) + mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder) + + mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion").to(device) + checkpoint_mask_generation_diffusion_unet = torch.load( + args.trained_mask_generation_diffusion_path, weights_only=False + ) + mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet["unet_state_dict"]) + mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet["scale_factor"] + + print("All the trained model weights have been loaded.") + + # ## Define the LDM Sampler, which contains functions that will perform the inference. 
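+    # The sampler below bundles the image- and mask-generation networks with their schedulers;
+    # after construction the only call made in this script is
+    #   ldm_sampler.sample_multiple_images(args.num_output_samples)
+    # which writes the generated image/label pairs into args.output_dir.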
+ ldm_sampler = LDMSampler( + args.body_region, + args.anatomy_list, + args.all_mask_files_json, + args.all_anatomy_size_conditions_json, + args.all_mask_files_base_dir, + args.label_dict_json, + args.label_dict_remap_json, + autoencoder, + diffusion_unet, + controlnet, + noise_scheduler, + scale_factor, + mask_generation_autoencoder, + mask_generation_diffusion_unet, + mask_generation_scale_factor, + mask_generation_noise_scheduler, + device, + latent_shape, + args.mask_generation_latent_shape, + args.output_size, + args.output_dir, + args.controllable_anatomy_size, + image_output_ext=args.image_output_ext, + label_output_ext=args.label_output_ext, + spacing=args.spacing, + modality=args.modality, + num_inference_steps=args.num_inference_steps, + mask_generation_num_inference_steps=args.mask_generation_num_inference_steps, + random_seed=args.random_seed, + autoencoder_sliding_window_infer_size=args.autoencoder_sliding_window_infer_size, + autoencoder_sliding_window_infer_overlap=args.autoencoder_sliding_window_infer_overlap, + ) + + print(f"The generated image/mask pairs will be saved in {args.output_dir}.") + output_filenames = ldm_sampler.sample_multiple_images(args.num_output_samples) + print("MAISI image/mask generation finished") + + +if __name__ == "__main__": + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + torch.cuda.reset_peak_memory_stats() + main() + peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3) # Convert to GB + print(f"Peak GPU memory usage: {peak_memory_gb:.2f} GB") diff --git a/scripts/quality_check.py b/scripts/quality_check.py new file mode 100644 index 0000000000000000000000000000000000000000..742084aa8ff29e6f5e5d963dd1cff77a8db457ad --- /dev/null +++ b/scripts/quality_check.py @@ -0,0 +1,149 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + + +def get_masked_data(label_data, image_data, labels): + """ + Extracts and returns the image data corresponding to specified labels within a 3D volume. + + This function efficiently masks the `image_data` array based on the provided `labels` in the `label_data` array. + The function handles cases with both a large and small number of labels, optimizing performance accordingly. + + Args: + label_data (np.ndarray): A NumPy array containing label data, representing different anatomical + regions or classes in a 3D medical image. + image_data (np.ndarray): A NumPy array containing the image data from which the relevant regions + will be extracted. + labels (list of int): A list of integers representing the label values to be used for masking. + + Returns: + np.ndarray: A NumPy array containing the elements of `image_data` that correspond to the specified + labels in `label_data`. If no labels are provided, an empty array is returned. + + Raises: + ValueError: If `image_data` and `label_data` do not have the same shape. 
+ + Example: + label_int_dict = {"liver": [1], "kidney": [5, 14]} + masked_data = get_masked_data(label_data, image_data, label_int_dict["kidney"]) + """ + + # Check if the shapes of image_data and label_data match + if image_data.shape != label_data.shape: + raise ValueError( + f"Shape mismatch: image_data has shape {image_data.shape}, " + f"but label_data has shape {label_data.shape}. They must be the same." + ) + + if not labels: + return np.array([]) # Return an empty array if no labels are provided + + labels = list(set(labels)) # remove duplicate items + + # Optimize performance based on the number of labels + num_label_acceleration_thresh = 3 + if len(labels) >= num_label_acceleration_thresh: + # if many labels, np.isin is faster + mask = np.isin(label_data, labels) + else: + # Use logical OR to combine masks if the number of labels is small + mask = np.zeros_like(label_data, dtype=bool) + for label in labels: + mask = np.logical_or(mask, label_data == label) + + # Retrieve the masked data + masked_data = image_data[mask.astype(bool)] + + return masked_data + + +def is_outlier(statistics, image_data, label_data, label_int_dict): + """ + Perform a quality check on the generated image by comparing its statistics with precomputed thresholds. + + Args: + statistics (dict): Dictionary containing precomputed statistics including mean +/- 3sigma ranges. + image_data (np.ndarray): The image data to be checked, typically a 3D NumPy array. + label_data (np.ndarray): The label data corresponding to the image, used for masking regions of interest. + label_int_dict (dict): Dictionary mapping label names to their corresponding integer lists. + e.g., label_int_dict = {"liver": [1], "kidney": [5, 14]} + + Returns: + dict: A dictionary with labels as keys, each containing the quality check result, + including whether it's an outlier, the median value, and the thresholds used. + If no data is found for a label, the median value will be `None` and `is_outlier` will be `False`. 
+ + Example: + # Example input data + statistics = { + "liver": { + "sigma_6_low": -21.596463547885904, + "sigma_6_high": 156.27881534763367 + }, + "kidney": { + "sigma_6_low": -15.0, + "sigma_6_high": 120.0 + } + } + label_int_dict = { + "liver": [1], + "kidney": [5, 14] + } + image_data = np.random.rand(100, 100, 100) # Replace with actual image data + label_data = np.zeros((100, 100, 100)) # Replace with actual label data + label_data[40:60, 40:60, 40:60] = 1 # Example region for liver + label_data[70:90, 70:90, 70:90] = 5 # Example region for kidney + result = is_outlier(statistics, image_data, label_data, label_int_dict) + """ + outlier_results = {} + + for label_name, stats in statistics.items(): + # Get the thresholds from the statistics + low_thresh = min(stats["sigma_6_low"], stats["percentile_0_5"]) # or "sigma_12_low" depending on your needs + high_thresh = max(stats["sigma_6_high"], stats["percentile_99_5"]) # or "sigma_12_high" depending on your needs + + if label_name == "bone": + high_thresh = 1000.0 + + # Retrieve the corresponding label integers + labels = label_int_dict.get(label_name, []) + masked_data = get_masked_data(label_data, image_data, labels) + masked_data = masked_data[~np.isnan(masked_data)] + + if len(masked_data) == 0 or masked_data.size == 0: + outlier_results[label_name] = { + "is_outlier": False, + "median_value": None, + "low_thresh": low_thresh, + "high_thresh": high_thresh, + } + continue + + # Compute the median of the masked region + median_value = np.nanmedian(masked_data) + + if np.isnan(median_value): + median_value = None + is_outlier = False + else: + # Determine if the median value is an outlier + is_outlier = median_value < low_thresh or median_value > high_thresh + + outlier_results[label_name] = { + "is_outlier": is_outlier, + "median_value": median_value, + "low_thresh": low_thresh, + "high_thresh": high_thresh, + } + + return outlier_results diff --git a/scripts/rectified_flow.py b/scripts/rectified_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..e660a1abb6734c93a6b8e312a49bfdc6e86b7591 --- /dev/null +++ b/scripts/rectified_flow.py @@ -0,0 +1,322 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py +# which has the following license: +# https://github.com/hpcaitech/Open-Sora/blob/main/LICENSE +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from typing import Union + +import numpy as np +import torch +from torch.distributions import LogisticNormal + +from monai.utils import StrEnum + +from .ddpm import DDPMPredictionType +from .scheduler import Scheduler + + +class RFlowPredictionType(StrEnum): + """ + Set of valid prediction type names for the RFlow scheduler's `prediction_type` argument. + + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + V_PREDICTION = DDPMPredictionType.V_PREDICTION + + +def timestep_transform( + t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3 +): + """ + Applies a transformation to the timestep based on image resolution scaling. + + Args: + t (torch.Tensor): The original timestep(s). + input_img_size_numel (torch.Tensor): The input image's size (H * W * D). + base_img_size_numel (int): reference H*W*D size, usually smaller than input_img_size_numel. + scale (float): Scaling factor for the transformation. + num_train_timesteps (int): Total number of training timesteps. + spatial_dim (int): Number of spatial dimensions in the image. + + Returns: + torch.Tensor: Transformed timestep(s). + """ + t = t / num_train_timesteps + ratio_space = (input_img_size_numel / base_img_size_numel) ** (1.0 / spatial_dim) + + ratio = ratio_space * scale + new_t = ratio * t / (1 + (ratio - 1) * t) + + new_t = new_t * num_train_timesteps + return new_t + + +class RFlowScheduler(Scheduler): + """ + A rectified flow scheduler for guiding the diffusion process in a generative model. + + Supports uniform and logit-normal sampling methods, timestep transformation for + different resolutions, and noise addition during diffusion. + + Args: + num_train_timesteps (int): Total number of training timesteps. + use_discrete_timesteps (bool): Whether to use discrete timesteps. + sample_method (str): Training time step sampling method ('uniform' or 'logit-normal'). + loc (float): Location parameter for logit-normal distribution, used only if sample_method='logit-normal'. + scale (float): Scale parameter for logit-normal distribution, used only if sample_method='logit-normal'. + use_timestep_transform (bool): Whether to apply timestep transformation. + If true, there will be more inference timesteps at early(noisy) stages for larger image volumes. + transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True. + steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True. + base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True. + spatial_dim (int): 2 or 3, incidcating 2D or 3D images, used only if use_timestep_transform=True. + + Example: + + .. 
code-block:: python
+
+            # define a scheduler
+            noise_scheduler = RFlowScheduler(
+                num_train_timesteps = 1000,
+                use_discrete_timesteps = True,
+                sample_method = 'logit-normal',
+                use_timestep_transform = True,
+                base_img_size_numel = 32 * 32 * 32,
+                spatial_dim = 3
+            )
+
+            # during training
+            inputs = torch.ones(2,4,64,64,32)
+            noise = torch.randn_like(inputs)
+            timesteps = noise_scheduler.sample_timesteps(inputs)
+            noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+            predicted_velocity = diffusion_unet(
+                x=noisy_inputs,
+                timesteps=timesteps
+            )
+            loss = loss_l1(predicted_velocity, (inputs - noise))
+
+            # during inference
+            noisy_inputs = torch.randn(2,4,64,64,32)
+            input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:]))
+            noise_scheduler.set_timesteps(
+                num_inference_steps=30, input_img_size_numel=input_img_size_numel
+            )
+            all_next_timesteps = torch.cat(
+                (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))
+            )
+            for t, next_t in tqdm(
+                zip(noise_scheduler.timesteps, all_next_timesteps),
+                total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
+            ):
+                predicted_velocity = diffusion_unet(
+                    x=noisy_inputs,
+                    timesteps=t
+                )
+                noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t)
+            final_output = noisy_inputs
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        use_discrete_timesteps: bool = True,
+        sample_method: str = "uniform",
+        loc: float = 0.0,
+        scale: float = 1.0,
+        use_timestep_transform: bool = False,
+        transform_scale: float = 1.0,
+        steps_offset: int = 0,
+        base_img_size_numel: int = 32 * 32 * 32,
+        spatial_dim: int = 3,
+    ):
+        # rectified flow only accepts velocity prediction
+        self.prediction_type = RFlowPredictionType.V_PREDICTION
+
+        self.num_train_timesteps = num_train_timesteps
+        self.use_discrete_timesteps = use_discrete_timesteps
+        self.base_img_size_numel = base_img_size_numel
+        self.spatial_dim = spatial_dim
+
+        # sample method
+        if sample_method not in ["uniform", "logit-normal"]:
+            raise ValueError(
+                f"sample_method = {sample_method}, which has to be chosen from ['uniform', 'logit-normal']."
+            )
+        self.sample_method = sample_method
+        if sample_method == "logit-normal":
+            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
+            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)
+
+        # timestep transform
+        self.use_timestep_transform = use_timestep_transform
+        self.transform_scale = transform_scale
+        self.steps_offset = steps_offset
+
+    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+        """
+        Add noise to the original samples.
+
+        Args:
+            original_samples: original samples
+            noise: noise to add to samples
+            timesteps: timesteps tensor with shape of (N,), indicating the timestep to be computed for each sample.
+ + Returns: + noisy_samples: sample with added noise + """ + timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps + timepoints = 1 - timepoints # [1,1/1000] + + # expand timepoint to noise shape + if noise.ndim == 5: + timepoints = timepoints[..., None, None, None, None].expand(-1, *noise.shape[1:]) + elif noise.ndim == 4: + timepoints = timepoints[..., None, None, None].expand(-1, *noise.shape[1:]) + else: + raise ValueError(f"noise tensor has to be 4D or 5D tensor, yet got shape of {noise.shape}") + + noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise + + return noisy_samples + + def set_timesteps( + self, + num_inference_steps: int, + device: str | torch.device | None = None, + input_img_size_numel: int | None = None, + ) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + input_img_size_numel: int, H*W*D of the image, used with self.use_timestep_transform is True. + """ + if num_inference_steps > self.num_train_timesteps or num_inference_steps < 1: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} should be at least 1, " + "and cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + # prepare timesteps + timesteps = [ + (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps) + ] + if self.use_discrete_timesteps: + timesteps = [int(round(t)) for t in timesteps] + if self.use_timestep_transform: + timesteps = [ + timestep_transform( + t, + input_img_size_numel=input_img_size_numel, + base_img_size_numel=self.base_img_size_numel, + num_train_timesteps=self.num_train_timesteps, + spatial_dim=self.spatial_dim, + ) + for t in timesteps + ] + timesteps_np = np.array(timesteps).astype(np.float16) + if self.use_discrete_timesteps: + timesteps_np = timesteps_np.astype(np.int64) + self.timesteps = torch.from_numpy(timesteps_np).to(device) + self.timesteps += self.steps_offset + + def sample_timesteps(self, x_start): + """ + Randomly samples training timesteps using the chosen sampling method. + + Args: + x_start (torch.Tensor): The input tensor for sampling. + + Returns: + torch.Tensor: Sampled timesteps. + """ + if self.sample_method == "uniform": + t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps + elif self.sample_method == "logit-normal": + t = self.sample_t(x_start) * self.num_train_timesteps + + if self.use_discrete_timesteps: + t = t.long() + + if self.use_timestep_transform: + input_img_size_numel = torch.prod(torch.tensor(x_start.shape[2:])) + t = timestep_transform( + t, + input_img_size_numel=input_img_size_numel, + base_img_size_numel=self.base_img_size_numel, + num_train_timesteps=self.num_train_timesteps, + spatial_dim=len(x_start.shape) - 2, + ) + + return t + + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predicts the next sample in the diffusion process. + + Args: + model_output (torch.Tensor): Output from the trained diffusion model. 
+ timestep (int): Current timestep in the diffusion chain. + sample (torch.Tensor): Current sample in the process. + next_timestep (Union[int, None]): Optional next timestep. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Predicted sample at the next step and additional info. + """ + # Ensure num_inference_steps exists and is a valid integer + if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int): + raise AttributeError( + "num_inference_steps is missing or not an integer in the class." + "Please run self.set_timesteps(num_inference_steps,device,input_img_size_numel) to set it." + ) + + v_pred = model_output + + if next_timestep is not None: + next_timestep = int(next_timestep) + dt: float = ( + float(timestep - next_timestep) / self.num_train_timesteps + ) # Now next_timestep is guaranteed to be int + else: + dt = ( + 1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0 + ) # Avoid division by zero + + pred_post_sample = sample + v_pred * dt + pred_original_sample = sample + v_pred * timestep / self.num_train_timesteps + + return pred_post_sample, pred_original_sample diff --git a/scripts/sample.py b/scripts/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..18d091d46f6034acae056c7e1e9e56465fc885bc --- /dev/null +++ b/scripts/sample.py @@ -0,0 +1,1177 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import math +import os +import random +import time +from datetime import datetime +import warnings +import gc + +import monai +import torch +from monai.data import MetaTensor +from monai.inferers.inferer import DiffusionInferer +from monai.transforms import Compose, SaveImage +from monai.utils import set_determinism +from tqdm import tqdm +from monai.inferers.inferer import SlidingWindowInferer +from monai.networks.schedulers import RFlowScheduler, DDPMScheduler + +from .augmentation import augmentation +from .find_masks import find_masks +from .quality_check import is_outlier +from .utils import ( + binarize_labels, + general_mask_generation_post_process, + get_body_region_index_from_mask, + remap_labels, + dynamic_infer, +) + + +class ReconModel(torch.nn.Module): + """ + A PyTorch module for reconstructing images from latent representations. + + Attributes: + autoencoder: The autoencoder model used for decoding. + scale_factor: Scaling factor applied to the input before decoding. + """ + + def __init__(self, autoencoder, scale_factor): + super().__init__() + self.autoencoder = autoencoder + self.scale_factor = scale_factor + + def forward(self, z): + """ + Decode the input latent representation to an image. + + Args: + z (torch.Tensor): The input latent representation. + + Returns: + torch.Tensor: The reconstructed image. 
+ """ + recon_pt_nda = self.autoencoder.decode_stage_2_outputs(z / self.scale_factor) + return recon_pt_nda + + +def initialize_noise_latents(latent_shape, device): + """ + Initialize random noise latents for image generation with float16. + + Args: + latent_shape (tuple): The shape of the latent space. + device (torch.device): The device to create the tensor on. + + Returns: + torch.Tensor: Initialized noise latents. + """ + return ( + torch.randn( + [ + 1, + ] + + list(latent_shape) + ) + .half() + .to(device) + ) + + +def ldm_conditional_sample_one_mask( + autoencoder, + diffusion_unet, + noise_scheduler, + scale_factor, + anatomy_size, + device, + latent_shape, + label_dict_remap_json, + num_inference_steps=1000, + autoencoder_sliding_window_infer_size=[96, 96, 96], + autoencoder_sliding_window_infer_overlap=0.6667, +): + """ + Generate a single synthetic mask using a latent diffusion model. + + Args: + autoencoder (nn.Module): The autoencoder model. + diffusion_unet (nn.Module): The diffusion U-Net model. + noise_scheduler: The noise scheduler for the diffusion process. + scale_factor (float): Scaling factor for the latent space. + anatomy_size (torch.Tensor): Tensor specifying the desired anatomy sizes. + device (torch.device): The device to run the computation on. + latent_shape (tuple): The shape of the latent space. + label_dict_remap_json (str): Path to the JSON file for label remapping. + num_inference_steps (int): Number of inference steps for the diffusion process. + autoencoder_sliding_window_infer_size (list, optional): Size of the sliding window for inference. Defaults to [96, 96, 96]. + autoencoder_sliding_window_infer_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.6667. + + Returns: + torch.Tensor: The generated synthetic mask. 
+ """ + recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device) + + with torch.no_grad(), torch.amp.autocast("cuda"): + # Generate random noise + latents = initialize_noise_latents(latent_shape, device) + anatomy_size = torch.FloatTensor(anatomy_size).unsqueeze(0).unsqueeze(0).half().to(device) + # synthesize latents + if isinstance(noise_scheduler, DDPMScheduler) and num_inference_steps < noise_scheduler.num_train_timesteps: + warnings.warn( + "**************************************************************\n" + "* WARNING: Mask noise_scheduler is a DDPMScheduler.\n" + "* We expect num_inference_steps = noise_scheduler.num_train_timesteps" + f" = {noise_scheduler.num_train_timesteps}.\n" + f"* Yet got num_inference_steps = {num_inference_steps}.\n" + "* The generated image quality is not guaranteed.\n" + "**************************************************************" + ) + + noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps) + # mask generator is DDPM + inferer_ddpm = DiffusionInferer(noise_scheduler) + latents = inferer_ddpm.sample( + input_noise=latents, + diffusion_model=diffusion_unet, + scheduler=noise_scheduler, + verbose=True, + conditioning=anatomy_size.to(device), + ) + + inferer = SlidingWindowInferer( + roi_size=autoencoder_sliding_window_infer_size, + sw_batch_size=1, + progress=True, + mode="gaussian", + overlap=autoencoder_sliding_window_infer_overlap, + sw_device=device, + device=torch.device("cpu"), + ) + synthetic_mask = dynamic_infer(inferer, recon_model, latents) + synthetic_mask = torch.softmax(synthetic_mask, dim=1) + synthetic_mask = torch.argmax(synthetic_mask, dim=1, keepdim=True) + # mapping raw index to 132 labels + synthetic_mask = remap_labels(synthetic_mask, label_dict_remap_json) + + ###### post process ##### + data = synthetic_mask.squeeze().cpu().detach().numpy() + + labels = [23, 24, 26, 27, 128] + target_tumor_label = None + for index, size in enumerate(anatomy_size[0, 0, 5:10]): + if size.item() != -1.0: + target_tumor_label = labels[index] + + logging.info(f"target_tumor_label for postprocess:{target_tumor_label}") + data = general_mask_generation_post_process(data, target_tumor_label=target_tumor_label, device=device) + synthetic_mask = torch.from_numpy(data).unsqueeze(0).unsqueeze(0).to(device) + + return synthetic_mask + + +def ldm_conditional_sample_one_image( + autoencoder, + diffusion_unet, + controlnet, + noise_scheduler, + scale_factor, + device, + combine_label_or, + spacing_tensor, + latent_shape, + output_size, + noise_factor, + top_region_index_tensor=None, + bottom_region_index_tensor=None, + modality_tensor=None, + num_inference_steps=1000, + autoencoder_sliding_window_infer_size=[96, 96, 96], + autoencoder_sliding_window_infer_overlap=0.6667, +): + """ + Generate a single synthetic image using a latent diffusion model with controlnet. + + Args: + autoencoder (nn.Module): The autoencoder model. + diffusion_unet (nn.Module): The diffusion U-Net model. + controlnet (nn.Module): The controlnet model. + noise_scheduler: The noise scheduler for the diffusion process. + scale_factor (float): Scaling factor for the latent space. + device (torch.device): The device to run the computation on. + combine_label_or (torch.Tensor): The combined label tensor. + spacing_tensor (torch.Tensor): Tensor specifying the spacing. + latent_shape (tuple): The shape of the latent space. + output_size (tuple): The desired output size of the image. + noise_factor (float): Factor to scale the initial noise. 
+ top_region_index_tensor (torch.Tensor): Tensor specifying the top region index. Defaults to None. + bottom_region_index_tensor (torch.Tensor): Tensor specifying the bottom region index. Defaults to None. + modality_tensor (torch.Tensor): Int Tensor specifying the modality. + num_inference_steps (int): Number of inference steps for the diffusion process. + autoencoder_sliding_window_infer_size (list, optional): Size of the sliding window for inference. Defaults to [96, 96, 96]. + autoencoder_sliding_window_infer_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.6667. + + Returns: + tuple: A tuple containing the synthetic image and its corresponding label. + """ + # CT image intensity range + a_min = -1000 + a_max = 1000 + # autoencoder output intensity range + b_min = 0.0 + b_max = 1 + + include_body_region = diffusion_unet.include_top_region_index_input + include_modality = diffusion_unet.num_class_embeds is not None + + recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device) + + with torch.no_grad(), torch.amp.autocast("cuda"): + logging.info("---- Start generating latent features... ----") + start_time = time.time() + # generate segmentation mask + combine_label = combine_label_or.to(device) + if ( + output_size[0] != combine_label.shape[2] + or output_size[1] != combine_label.shape[3] + or output_size[2] != combine_label.shape[4] + ): + logging.info( + "output_size is not a desired value. Need to interpolate the mask to match with output_size. The result image will be very low quality." + ) + combine_label = torch.nn.functional.interpolate(combine_label, size=output_size, mode="nearest") + + controlnet_cond_vis = binarize_labels(combine_label.as_tensor().long()).half() + + # Generate random noise + latents = initialize_noise_latents(latent_shape, device) * noise_factor + + # synthesize latents + if isinstance(noise_scheduler, RFlowScheduler): + noise_scheduler.set_timesteps( + num_inference_steps=num_inference_steps, + input_img_size_numel=torch.prod(torch.tensor(latents.shape[2:])), + ) + else: + noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps) + + if isinstance(noise_scheduler, DDPMScheduler) and num_inference_steps < noise_scheduler.num_train_timesteps: + warnings.warn( + "**************************************************************\n" + "* WARNING: Image noise_scheduler is a DDPMScheduler.\n" + "* We expect num_inference_steps = noise_scheduler.num_train_timesteps" + f" = {noise_scheduler.num_train_timesteps}.\n" + f"* Yet got num_inference_steps = {num_inference_steps}.\n" + "* The generated image quality is not guaranteed.\n" + "**************************************************************" + ) + + all_timesteps = noise_scheduler.timesteps + all_next_timesteps = torch.cat((all_timesteps[1:], torch.tensor([0], dtype=all_timesteps.dtype))) + progress_bar = tqdm( + zip(all_timesteps, all_next_timesteps), + total=min(len(all_timesteps), len(all_next_timesteps)), + ) + for t, next_t in progress_bar: + # get controlnet output + # Create a dictionary to store the inputs + controlnet_inputs = { + "x": latents, + "timesteps": torch.Tensor((t,)).to(device), + "controlnet_cond": controlnet_cond_vis, + } + if include_modality: + controlnet_inputs.update( + { + "class_labels": modality_tensor, + } + ) + down_block_res_samples, mid_block_res_sample = controlnet(**controlnet_inputs) + + # get diffusion network output + # Create a dictionary to store the inputs + unet_inputs = { + "x": latents, + 
"timesteps": torch.Tensor((t,)).to(device), + "spacing_tensor": spacing_tensor, + "down_block_additional_residuals": down_block_res_samples, + "mid_block_additional_residual": mid_block_res_sample, + } + # Add extra arguments if include_body_region is True + if include_body_region: + unet_inputs.update( + { + "top_region_index_tensor": top_region_index_tensor, + "bottom_region_index_tensor": bottom_region_index_tensor, + } + ) + if include_modality: + unet_inputs.update( + { + "class_labels": modality_tensor, + } + ) + model_output = diffusion_unet(**unet_inputs) + + if not isinstance(noise_scheduler, RFlowScheduler): + latents, _ = noise_scheduler.step(model_output, t, latents) # type: ignore + else: + latents, _ = noise_scheduler.step(model_output, t, latents, next_t) # type: ignore + end_time = time.time() + logging.info(f"---- DM/ControlNet Latent features generation time: {end_time - start_time} seconds ----") + del ( + unet_inputs, + controlnet_inputs, + model_output, + controlnet_cond_vis, + down_block_res_samples, + mid_block_res_sample, + ) + gc.collect() + torch.cuda.empty_cache() + + # decode latents to synthesized images + logging.info("---- Start decoding latent features into images... ----") + start_time = time.time() + + inferer = SlidingWindowInferer( + roi_size=autoencoder_sliding_window_infer_size, + sw_batch_size=1, + progress=True, + mode="gaussian", + overlap=autoencoder_sliding_window_infer_overlap, + sw_device=device, + device=torch.device("cpu"), + ) + synthetic_images = dynamic_infer(inferer, recon_model, latents) + synthetic_images = torch.clip(synthetic_images, b_min, b_max).cpu() + end_time = time.time() + logging.info(f"---- Image VAE decoding time: {end_time - start_time} seconds ----") + + ## post processing: + # project output to [0, 1] + synthetic_images = (synthetic_images - b_min) / (b_max - b_min) + # project output to [-1000, 1000] + synthetic_images = synthetic_images * (a_max - a_min) + a_min + # regularize background intensities + synthetic_images = crop_img_body_mask(synthetic_images, combine_label) + torch.cuda.empty_cache() + + return synthetic_images, combine_label + + +def filter_mask_with_organs(combine_label, anatomy_list): + """ + Filter a mask to only include specified organs. + + Args: + combine_label (torch.Tensor): The input mask. + anatomy_list (list): List of organ labels to keep. + + Returns: + torch.Tensor: The filtered mask. + """ + # final output mask file has shape of output_size, contains labels in anatomy_list + # it is already interpolated to target size + combine_label = combine_label.long() + # filter out the organs that are not in anatomy_list + for i in range(len(anatomy_list)): + organ = anatomy_list[i] + # replace it with a negative value so it will get mixed + combine_label[combine_label == organ] = -(i + 1) + # zero-out voxels with value not in anatomy_list + combine_label[combine_label > 0] = 0 + # output positive values + combine_label = -combine_label + return combine_label + + +def crop_img_body_mask(synthetic_images, combine_label): + """ + Crop the synthetic image using a body mask. + + Args: + synthetic_images (torch.Tensor): The synthetic images. + combine_label (torch.Tensor): The body mask. + + Returns: + torch.Tensor: The cropped synthetic images. 
+ """ + synthetic_images[combine_label == 0] = -1000 + return synthetic_images + + +def check_input( + body_region, + anatomy_list, + label_dict_json, + output_size, + spacing, + controllable_anatomy_size=[("pancreas", 0.5)], +): + """ + Validate input parameters for image generation. + + Args: + body_region (list): List of body regions. + anatomy_list (list): List of anatomical structures. + label_dict_json (str): Path to the label dictionary JSON file. + output_size (tuple): Desired output size of the image. + spacing (tuple): Desired voxel spacing. + controllable_anatomy_size (list): List of tuples specifying controllable anatomy sizes. + + Raises: + ValueError: If any input parameter is invalid. + """ + # check output_size and spacing format + if output_size[0] != output_size[1]: + raise ValueError(f"The first two components of output_size need to be equal, yet got {output_size}.") + if (output_size[0] not in [256, 384, 512]) or (output_size[2] not in [128, 256, 384, 512, 640, 768]): + raise ValueError( + f"The output_size[0] have to be chosen from [256, 384, 512], and output_size[2] have to be chosen from [128, 256, 384, 512, 640, 768], yet got {output_size}." + ) + + if spacing[0] != spacing[1]: + raise ValueError(f"The first two components of spacing need to be equal, yet got {spacing}.") + if spacing[0] < 0.5 or spacing[0] > 3.0 or spacing[2] < 0.5 or spacing[2] > 5.0: + raise ValueError( + f"spacing[0] have to be between 0.5 and 3.0 mm, spacing[2] have to be between 0.5 and 5.0 mm, yet got {spacing}." + ) + + if output_size[0] * spacing[0] < 256: + FOV = [output_size[axis] * spacing[axis] for axis in range(3)] + raise ValueError( + f"`'spacing'({spacing}mm) and 'output_size'({output_size}) together decide the output field of view (FOV). The FOV will be {FOV}mm. We recommend the FOV in x and y axis to be at least 256mm for head, and at least 384mm for other body regions like abdomen. There is no such restriction for z-axis." + ) + + if controllable_anatomy_size == None: + logging.info(f"`controllable_anatomy_size` is not provided.") + return + + # check controllable_anatomy_size format + if len(controllable_anatomy_size) > 10: + raise ValueError( + f"The length of list controllable_anatomy_size has to be less than 10. Yet got length equal to {len(controllable_anatomy_size)}." + ) + available_controllable_organ = [ + "liver", + "gallbladder", + "stomach", + "pancreas", + "colon", + ] + available_controllable_tumor = [ + "hepatic tumor", + "bone lesion", + "lung tumor", + "colon cancer primaries", + "pancreatic tumor", + ] + available_controllable_anatomy = available_controllable_organ + available_controllable_tumor + controllable_tumor = [] + controllable_organ = [] + for controllable_anatomy_size_pair in controllable_anatomy_size: + if controllable_anatomy_size_pair[0] not in available_controllable_anatomy: + raise ValueError( + f"The controllable_anatomy have to be chosen from {available_controllable_anatomy}, yet got {controllable_anatomy_size_pair[0]}." 
+ ) + if controllable_anatomy_size_pair[0] in available_controllable_tumor: + controllable_tumor += [controllable_anatomy_size_pair[0]] + if controllable_anatomy_size_pair[0] in available_controllable_organ: + controllable_organ += [controllable_anatomy_size_pair[0]] + if controllable_anatomy_size_pair[1] == -1: + continue + if controllable_anatomy_size_pair[1] < 0 or controllable_anatomy_size_pair[1] > 1.0: + raise ValueError( + f"The controllable size scale have to be between 0 and 1,0, or equal to -1, yet got {controllable_anatomy_size_pair[1]}." + ) + if len(controllable_tumor + controllable_organ) != len(list(set(controllable_tumor + controllable_organ))): + raise ValueError(f"Please do not repeat controllable_anatomy. Got {controllable_tumor + controllable_organ}.") + if len(controllable_tumor) > 1: + raise ValueError(f"Only one controllable tumor is supported. Yet got {controllable_tumor}.") + + if len(controllable_anatomy_size) > 0: + logging.info( + f"`controllable_anatomy_size` is not empty.\nWe will ignore `body_region` and `anatomy_list` and synthesize based on `controllable_anatomy_size`: ({controllable_anatomy_size})." + ) + else: + logging.info( + f"`controllable_anatomy_size` is empty.\nWe will synthesize based on `body_region`: ({body_region}) and `anatomy_list`: ({anatomy_list})." + ) + # check body_region format + available_body_region = [ + "head", + "chest", + "thorax", + "abdomen", + "pelvis", + "lower", + ] + for region in body_region: + if region not in available_body_region: + raise ValueError( + f"The components in body_region have to be chosen from {available_body_region}, yet got {region}." + ) + + # check anatomy_list format + with open(label_dict_json) as f: + label_dict = json.load(f) + for anatomy in anatomy_list: + if anatomy not in label_dict.keys(): + raise ValueError( + f"The components in anatomy_list have to be chosen from {label_dict.keys()}, yet got {anatomy}." + ) + logging.info(f"The generate results will have voxel size to be {spacing}mm, volume size to be {output_size}.") + + return + + +class LDMSampler: + """ + A sampler class for generating synthetic medical images and masks using latent diffusion models. + + Attributes: + Various attributes related to model configuration, input parameters, and generation settings. + """ + + def __init__( + self, + body_region, + anatomy_list, + all_mask_files_json, + all_anatomy_size_condtions_json, + all_mask_files_base_dir, + label_dict_json, + label_dict_remap_json, + autoencoder, + diffusion_unet, + controlnet, + noise_scheduler, + scale_factor, + mask_generation_autoencoder, + mask_generation_diffusion_unet, + mask_generation_scale_factor, + mask_generation_noise_scheduler, + device, + latent_shape, + mask_generation_latent_shape, + output_size, + output_dir, + controllable_anatomy_size, + image_output_ext=".nii.gz", + label_output_ext=".nii.gz", + real_img_median_statistics="./configs/image_median_statistics.json", + spacing=[1, 1, 1], + modality=1, + num_inference_steps=None, + mask_generation_num_inference_steps=None, + random_seed=None, + autoencoder_sliding_window_infer_size=[96, 96, 96], + autoencoder_sliding_window_infer_overlap=0.6667, + ) -> None: + """ + Initialize the LDMSampler with various parameters and models. + + Args: + Various parameters related to model configuration, input settings, and output specifications. 
+ """ + self.random_seed = random_seed + if random_seed is not None: + set_determinism(seed=random_seed) + + with open(label_dict_json, "r") as f: + label_dict = json.load(f) + self.all_anatomy_size_condtions_json = all_anatomy_size_condtions_json + + # intialize variables + self.body_region = body_region + self.anatomy_list = [label_dict[organ] for organ in anatomy_list] + self.all_mask_files_json = all_mask_files_json + self.data_root = all_mask_files_base_dir + self.label_dict_remap_json = label_dict_remap_json + self.autoencoder = autoencoder + self.diffusion_unet = diffusion_unet + self.controlnet = controlnet + self.noise_scheduler = noise_scheduler + self.scale_factor = scale_factor + self.mask_generation_autoencoder = mask_generation_autoencoder + self.mask_generation_diffusion_unet = mask_generation_diffusion_unet + self.mask_generation_scale_factor = mask_generation_scale_factor + self.mask_generation_noise_scheduler = mask_generation_noise_scheduler + self.device = device + self.latent_shape = latent_shape + self.mask_generation_latent_shape = mask_generation_latent_shape + self.output_size = output_size + self.output_dir = output_dir + self.noise_factor = 1.0 + self.controllable_anatomy_size = controllable_anatomy_size + if len(self.controllable_anatomy_size): + logging.info("controllable_anatomy_size is given, mask generation is triggered!") + # overwrite the anatomy_list by given organs in self.controllable_anatomy_size + self.anatomy_list = [label_dict[organ_and_size[0]] for organ_and_size in self.controllable_anatomy_size] + self.image_output_ext = image_output_ext + self.label_output_ext = label_output_ext + # Set the default value for number of inference steps to 1000 + self.num_inference_steps = num_inference_steps if num_inference_steps is not None else 1000 + self.mask_generation_num_inference_steps = ( + mask_generation_num_inference_steps if mask_generation_num_inference_steps is not None else 1000 + ) + + if any(size % 16 != 0 for size in autoencoder_sliding_window_infer_size): + raise ValueError( + f"autoencoder_sliding_window_infer_size must be divisible by 16.\n Got {autoencoder_sliding_window_infer_size}" + ) + if not (0 <= autoencoder_sliding_window_infer_overlap <= 1): + raise ValueError( + f"Value of autoencoder_sliding_window_infer_overlap must be between 0 and 1.\n Got {autoencoder_sliding_window_infer_overlap}" + ) + self.autoencoder_sliding_window_infer_size = autoencoder_sliding_window_infer_size + self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap + + # quality check args + self.max_try_time = 2 # if not pass quality check, will try self.max_try_time times + with open(real_img_median_statistics, "r") as json_file: + self.median_statistics = json.load(json_file) + self.label_int_dict = { + "liver": [1], + "spleen": [3], + "pancreas": [4], + "kidney": [5, 14], + "lung": [28, 29, 30, 31, 31], + "brain": [22], + "hepatic tumor": [26], + "bone lesion": [128], + "lung tumor": [23], + "colon cancer primaries": [27], + "pancreatic tumor": [24], + "bone": list(range(33, 57)) + list(range(63, 98)) + [120, 122, 127], + } + + # networks + self.autoencoder.eval() + self.diffusion_unet.eval() + self.controlnet.eval() + self.mask_generation_autoencoder.eval() + self.mask_generation_diffusion_unet.eval() + + self.spacing = spacing + self.modality_tensor = modality * torch.ones((1,), dtype=torch.long).to(device) + self.include_body_region = self.diffusion_unet.include_top_region_index_input + self.include_modality = 
self.diffusion_unet.num_class_embeds is not None + + val_transforms_list = [ + monai.transforms.LoadImaged(keys=["pseudo_label"]), + monai.transforms.EnsureChannelFirstd(keys=["pseudo_label"]), + monai.transforms.Orientationd(keys=["pseudo_label"], axcodes="RAS"), + monai.transforms.EnsureTyped(keys=["pseudo_label"], dtype=torch.uint8), + monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)), + monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2), + ] + if self.include_body_region: + val_transforms_list += [ + monai.transforms.Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)), + monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)), + monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2), + monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2), + ] + + self.val_transforms = Compose(val_transforms_list) + logging.info("LDM sampler initialized.") + + def sample_multiple_images(self, num_img): + """ + Generate multiple synthetic images and masks. + + Args: + num_img (int): Number of images to generate. + """ + modality_tensor = self.modality_tensor + output_filenames = [] + if len(self.controllable_anatomy_size) > 0: + # we will use mask generation instead of finding candidate masks + # create a dummy selected_mask_files for placeholder + selected_mask_files = list(range(num_img)) + # prerpare organ size conditions + anatomy_size_condtion = self.prepare_anatomy_size_condtion(self.controllable_anatomy_size) + else: + need_resample = False + # find candidate mask and save to candidate_mask_files + candidate_mask_files = find_masks( + self.body_region, + self.anatomy_list, + self.spacing, + self.output_size, + True, + self.all_mask_files_json, + self.data_root, + ) + if len(candidate_mask_files) < num_img: + # if we cannot find enough masks based on the exact match of anatomy list, spacing, and output size, + # then we will try to find the closest mask in terms of spacing, and output size. + logging.info("Resample mask file to get desired output size and spacing") + candidate_mask_files = self.find_closest_masks(num_img) + need_resample = True + + selected_mask_files = self.select_mask(candidate_mask_files, num_img) + logging.info(f"Images will be generated based on {selected_mask_files}.") + if len(selected_mask_files) < num_img: + raise ValueError( + ( + f"len(selected_mask_files) ({len(selected_mask_files)}) < num_img ({num_img}). " + "This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)." + ) + ) + + num_generated_img = 0 + for index_s in range(len(selected_mask_files)): + item = selected_mask_files[index_s] + if num_generated_img >= num_img: + break + logging.info("---- Start preparing masks... 
----") + start_time = time.time() + if len(self.controllable_anatomy_size) > 0: + # generate a synthetic mask + ( + combine_label_or, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, + ) = self.prepare_one_mask_and_meta_info(anatomy_size_condtion) + else: + # read in mask file + mask_file = item["mask_file"] + if_aug = item["if_aug"] + ( + combine_label_or, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, + ) = self.read_mask_information(mask_file) + if need_resample: + combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) + # mask augmentation + if if_aug: + combine_label_or = augmentation(combine_label_or, self.output_size, self.random_seed) + end_time = time.time() + logging.info(f"---- Mask preparation time: {end_time - start_time} seconds ----") + torch.cuda.empty_cache() + # generate image/label pairs + to_generate = True + try_time = 0 + # start generation + synthetic_images, synthetic_labels = self.sample_one_pair( + combine_label_or, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, + modality_tensor, + ) + # synthetic image quality check + pass_quality_check = self.quality_check( + synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy() + ) + print(num_img - num_generated_img, (len(selected_mask_files) - index_s)) + if pass_quality_check or (num_img - num_generated_img) >= (len(selected_mask_files) - index_s): + if not pass_quality_check: + logging.info( + "Generated image/label pair did not pass quality check, but will still save them. " + "Please consider changing spacing and output_size to facilitate a more realistic setting." + ) + num_generated_img = num_generated_img + 1 + # save image/label pairs + output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz" + synthetic_images = MetaTensor(synthetic_images, meta=synthetic_labels.meta) + img_saver = SaveImage( + output_dir=self.output_dir, + output_postfix=output_postfix + "_image", + output_ext=self.image_output_ext, + separate_folder=False, + ) + img_saver(synthetic_images[0]) + synthetic_images_filename = os.path.join( + self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext + ) + # filter out the organs that are not in anatomy_list + synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list) + label_saver = SaveImage( + output_dir=self.output_dir, + output_postfix=output_postfix + "_label", + output_ext=self.label_output_ext, + separate_folder=False, + ) + label_saver(synthetic_labels[0]) + synthetic_labels_filename = os.path.join( + self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext + ) + output_filenames.append([synthetic_images_filename, synthetic_labels_filename]) + to_generate = False + else: + logging.info("Generated image/label pair did not pass quality check, will re-generate another pair.") + return output_filenames + + def select_mask(self, candidate_mask_files, num_img): + """ + Select mask files for image generation. + + Args: + candidate_mask_files (list): List of candidate mask files. + num_img (int): Number of images to generate. + + Returns: + list: Selected mask files with augmentation flags. 
+ """ + selected_mask_files = [] + random.shuffle(candidate_mask_files) + + for n in range(len(candidate_mask_files)): + mask_file = candidate_mask_files[n % len(candidate_mask_files)] + selected_mask_files.append({"mask_file": mask_file, "if_aug": True}) + return selected_mask_files + + def sample_one_pair( + self, + combine_label_or_aug, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, + modality_tensor, + ): + """ + Generate a single pair of synthetic image and mask. + + Args: + combine_label_or_aug (torch.Tensor): Combined label tensor or augmented label. + top_region_index_tensor (torch.Tensor): Tensor specifying the top region index. + bottom_region_index_tensor (torch.Tensor): Tensor specifying the bottom region index. + spacing_tensor (torch.Tensor): Tensor specifying the spacing. + modality_tensor (torch.Tensor): Int Tensor specifying the modality. + + Returns: + tuple: A tuple containing the synthetic image and its corresponding label. + """ + # generate image/label pairs + synthetic_images, synthetic_labels = ldm_conditional_sample_one_image( + autoencoder=self.autoencoder, + diffusion_unet=self.diffusion_unet, + controlnet=self.controlnet, + noise_scheduler=self.noise_scheduler, + scale_factor=self.scale_factor, + device=self.device, + combine_label_or=combine_label_or_aug, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + modality_tensor=modality_tensor, + latent_shape=self.latent_shape, + output_size=self.output_size, + noise_factor=self.noise_factor, + num_inference_steps=self.num_inference_steps, + autoencoder_sliding_window_infer_size=self.autoencoder_sliding_window_infer_size, + autoencoder_sliding_window_infer_overlap=self.autoencoder_sliding_window_infer_overlap, + ) + return synthetic_images, synthetic_labels + + def prepare_anatomy_size_condtion( + self, + controllable_anatomy_size, + ): + """ + Prepare anatomy size conditions for mask generation. + + Args: + controllable_anatomy_size (list): List of tuples specifying controllable anatomy sizes. + + Returns: + list: Prepared anatomy size conditions. 
+ """ + anatomy_size_idx = { + "gallbladder": 0, + "liver": 1, + "stomach": 2, + "pancreas": 3, + "colon": 4, + "lung tumor": 5, + "pancreatic tumor": 6, + "hepatic tumor": 7, + "colon cancer primaries": 8, + "bone lesion": 9, + } + provide_anatomy_size = [None for _ in range(10)] + logging.info(f"controllable_anatomy_size: {controllable_anatomy_size}") + for element in controllable_anatomy_size: + anatomy_name, anatomy_size = element + provide_anatomy_size[anatomy_size_idx[anatomy_name]] = anatomy_size + + with open(self.all_anatomy_size_condtions_json, "r") as f: + all_anatomy_size_condtions = json.load(f) + + # loop through the database and find closest combinations + candidate_list = [] + for anatomy_size in all_anatomy_size_condtions: + size = anatomy_size["organ_size"] + diff = 0 + for db_size, provide_size in zip(size, provide_anatomy_size): + if provide_size is None: + continue + diff += abs(provide_size - db_size) + candidate_list.append((size, diff)) + candidate_condition = sorted(candidate_list, key=lambda x: x[1])[0][0] + + # overwrite the anatomy size provided by users + for element in controllable_anatomy_size: + anatomy_name, anatomy_size = element + candidate_condition[anatomy_size_idx[anatomy_name]] = anatomy_size + + return candidate_condition + + def prepare_one_mask_and_meta_info(self, anatomy_size_condtion): + """ + Prepare a single mask and its associated meta information. + + Args: + anatomy_size_condtion (list): Anatomy size conditions. + + Returns: + tuple: A tuple containing the prepared mask and associated tensors. + """ + combine_label_or = self.sample_one_mask(anatomy_size=anatomy_size_condtion) + # TODO: current mask generation model only can generate 256^3 volumes with 1.5 mm spacing. + affine = torch.zeros((4, 4)) + affine[0, 0] = 1.5 + affine[1, 1] = 1.5 + affine[2, 2] = 1.5 + affine[3, 3] = 1.0 # dummy + combine_label_or = MetaTensor(combine_label_or, affine=affine) + combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) + + top_region_index, bottom_region_index = get_body_region_index_from_mask(combine_label_or) + + spacing_tensor = torch.FloatTensor(self.spacing).unsqueeze(0).half().to(self.device) * 1e2 + top_region_index_tensor = torch.FloatTensor(top_region_index).unsqueeze(0).half().to(self.device) * 1e2 + bottom_region_index_tensor = torch.FloatTensor(bottom_region_index).unsqueeze(0).half().to(self.device) * 1e2 + + return combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + + def sample_one_mask(self, anatomy_size): + """ + Generate a single synthetic mask. + + Args: + anatomy_size (list): Anatomy size specifications. + + Returns: + torch.Tensor: The generated synthetic mask. + """ + # generate one synthetic mask + synthetic_mask = ldm_conditional_sample_one_mask( + self.mask_generation_autoencoder, + self.mask_generation_diffusion_unet, + self.mask_generation_noise_scheduler, + self.mask_generation_scale_factor, + anatomy_size, + self.device, + self.mask_generation_latent_shape, + label_dict_remap_json=self.label_dict_remap_json, + num_inference_steps=self.mask_generation_num_inference_steps, + autoencoder_sliding_window_infer_size=self.autoencoder_sliding_window_infer_size, + autoencoder_sliding_window_infer_overlap=self.autoencoder_sliding_window_infer_overlap, + ) + return synthetic_mask + + def ensure_output_size_and_spacing(self, labels, check_contains_target_labels=True): + """ + Ensure the output mask has the correct size and spacing. 
+ + Args: + labels (torch.Tensor): Input label tensor. + check_contains_target_labels (bool): Whether to check if the resampled mask contains target labels. + + Returns: + torch.Tensor: Resampled label tensor. + + Raises: + ValueError: If the resampled mask doesn't contain required class labels. + """ + current_spacing = [labels.affine[0, 0], labels.affine[1, 1], labels.affine[2, 2]] + current_shape = list(labels.squeeze().shape) + + need_resample = False + # check spacing + for i, j in zip(current_spacing, self.spacing): + if i != j: + need_resample = True + # check output size + for i, j in zip(current_shape, self.output_size): + if i != j: + need_resample = True + # resample to target size and spacing + if need_resample: + logging.info("Resampling mask to target shape and spacing") + logging.info(f"Resize Spacing: {current_spacing} -> {self.spacing}") + logging.info(f"Output size: {current_shape} -> {self.output_size}") + spacing = monai.transforms.Spacing(pixdim=tuple(self.spacing), mode="nearest") + pad_crop = monai.transforms.ResizeWithPadOrCrop(spatial_size=tuple(self.output_size)) + labels = pad_crop(spacing(labels.squeeze(0))).unsqueeze(0).to(labels.dtype) + + contained_labels = torch.unique(labels) + if check_contains_target_labels: + # check if the resampled mask still contains those target labels + for anatomy_label in self.anatomy_list: + if anatomy_label not in contained_labels: + raise ValueError( + f"Resampled mask does not contain required class labels {anatomy_label}. Please tune spacing and output size." + ) + return labels + + def read_mask_information(self, mask_file): + """ + Read mask information from a file. + + Args: + mask_file (str): Path to the mask file. + + Returns: + tuple: A tuple containing the mask tensor and associated information. + """ + val_data = self.val_transforms(mask_file) + + for key in ["pseudo_label", "spacing", "top_region_index", "bottom_region_index"]: + if isinstance(val_data[key], torch.Tensor): + val_data[key] = val_data[key].unsqueeze(0).to(self.device) + else: + val_data[key] = None + + return ( + val_data["pseudo_label"], + val_data["top_region_index"], + val_data["bottom_region_index"], + val_data["spacing"], + ) + + def find_closest_masks(self, num_img): + """ + Find the closest matching masks from the database. + + Args: + num_img (int): Number of images to generate. + + Returns: + list: List of closest matching mask candidates. + + Raises: + ValueError: If suitable candidates cannot be found. 
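+
+        Example:
+            A hedged sketch of the ranking score computed below for a single candidate
+            (toy numbers; a smaller score means a closer match):
+
+            .. code-block:: python
+
+                target_size, target_spacing = [512, 512, 256], [1.0, 1.0, 1.25]
+                cand = {"dim": [512, 512, 203], "spacing": [0.8, 0.8, 1.5]}
+                diff = 0.0
+                for axis in range(3):
+                    diff += abs(abs(cand["dim"][axis] * cand["spacing"][axis]) - target_size[axis] * target_spacing[axis]) / 10  # FOV mismatch
+                    diff += abs(abs(cand["dim"][axis]) - target_size[axis]) / 100  # matrix-size mismatch
+                    diff += abs(abs(cand["spacing"][axis]) - target_spacing[axis])  # spacing mismatch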
+ """ + # first check the database based on anatomy list + candidates = find_masks( + self.body_region, + self.anatomy_list, + self.spacing, + self.output_size, + False, + self.all_mask_files_json, + self.data_root, + ) + + if len(candidates) < num_img: + raise ValueError(f"candidate masks are less than {num_img}).") + + # loop through the database and find closest combinations + new_candidates = [] + for c in candidates: + diff = 0 + include_c = True + for axis in range(3): + if abs(c["dim"][axis]) < self.output_size[axis] - 64: + # we cannot upsample the mask too much + include_c = False + break + # check diff in FOV, major metric + diff += abs( + (abs(c["dim"][axis] * c["spacing"][axis]) - self.output_size[axis] * self.spacing[axis]) / 10 + ) + # check diff in dim + diff += abs((abs(c["dim"][axis]) - self.output_size[axis]) / 100) + # check diff in spacing + diff += abs(abs(c["spacing"][axis]) - self.spacing[axis]) + if include_c: + new_candidates.append((c, diff)) + + # choose top-2*num_img candidates (at least 5) + num_candidates = max(self.max_try_time * num_img, 5) + new_candidates = sorted(new_candidates, key=lambda x: x[1]) + + final_candidates = [] + # check top-2*num_img candidates and update spacing after resampling + for c, _ in new_candidates: + c = self.resample_mask_check_organ_list(c) + if c is not None: + final_candidates.append(c) + if len(final_candidates) >= num_candidates: + break + if len(final_candidates) == 0: + raise ValueError("Cannot find body region with given organ list.") + return final_candidates + + def resample_mask_check_organ_list(self, mask): + """ + Resample mask and check if the resampled mask contains the required organ list. + + Args: + mask (dict): input mask. + + Returns: + dict: resampled mask. If None, means the resampled mask does not contain the required organ list + + Raises: + ValueError: If suitable candidates cannot be found. + """ + + image_loader = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True) + label = image_loader(mask["pseudo_label"]) + try: + label = self.ensure_output_size_and_spacing(label.unsqueeze(0)) + except ValueError as e: + if "Resampled mask does not contain required class labels" in str(e): + return None + else: + raise e + # get region_index after resample + top_region_index, bottom_region_index = get_body_region_index_from_mask(label) + mask["top_region_index"] = top_region_index + mask["bottom_region_index"] = bottom_region_index + mask["spacing"] = self.spacing + mask["dim"] = self.output_size + return mask + + def quality_check(self, image_data, label_data): + """ + Perform a quality check on the generated image. + Args: + image_data (np.ndarray): The generated image. + label_data (np.ndarray): The corresponding whole body mask. + Returns: + bool: True if the image passes the quality check, False otherwise. + """ + outlier_results = is_outlier(self.median_statistics, image_data, label_data, self.label_int_dict) + for label, result in outlier_results.items(): + if result.get("is_outlier", False): + logging.info( + f"Generated image quality check for label '{label}' failed: median value {result['median_value']} is outside the acceptable range ({result['low_thresh']} - {result['high_thresh']})." 
+ ) + return False + return True diff --git a/scripts/scheduler.py b/scripts/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..71f4f082c0312c80871befb9651c58823318003f --- /dev/null +++ b/scripts/scheduler.py @@ -0,0 +1,207 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +from __future__ import annotations + +import torch +import torch.nn as nn + +from monai.utils import ComponentStore, unsqueeze_right + +NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") + + +@NoiseSchedules.add_def("linear_beta", "Linear beta schedule") +def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + +@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") +def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Scaled linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + +@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") +def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6): + """ + Sigmoid beta noise schedule function. 
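+
+    A hedged usage sketch; in practice the schedule is looked up through the
+    ``NoiseSchedules`` store by a ``Scheduler`` subclass rather than called directly:
+
+    .. code-block:: python
+
+        betas = NoiseSchedules["sigmoid_beta"](num_train_timesteps=10)
+        # betas ramps from roughly beta_start to roughly beta_end along a sigmoid curve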
+ + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + sig_range: pos/neg range of sigmoid input, default 6 + + Returns: + betas: beta schedule tensor + """ + betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + + +@NoiseSchedules.add_def("cosine", "Cosine schedule") +def _cosine_beta(num_train_timesteps: int, s: float = 8e-3): + """ + Cosine noise schedule, see https://arxiv.org/abs/2102.09672 + + Args: + num_train_timesteps: number of timesteps + s: smoothing factor, default 8e-3 (see referenced paper) + + Returns: + (betas, alphas, alpha_cumprod) values + """ + x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) + alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod /= alphas_cumprod[0].item() + betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + betas = torch.clip(betas, 0.0, 0.999) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + return betas, alphas, alphas_cumprod + + +class Scheduler(nn.Module): + """ + Base class for other schedulers based on a noise schedule function. + + This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here + the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`, + which is the name of a component in NoiseSchedules. These components must all be callables which return either + the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions + can be provided by using the NoiseSchedules.add_def, for example: + + .. code-block:: python + + from monai.networks.schedulers import NoiseSchedules, DDPMScheduler + + @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function") + def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule") + + All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of + timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through + the constructor's `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules + to get a listing of stored objects with their docstring descriptions. + + Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule + type, this now replaced with `schedule` and most names used with the previous argument now have "_beta" appended + to them, eg. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are + still used for some schedules but these are provided as keyword arguments now. + + Args: + num_train_timesteps: number of diffusion steps used to train the model. 
+ schedule: member of NoiseSchedules, + a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple + schedule_args: arguments to pass to the schedule function + """ + + def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None: + super().__init__() + schedule_args["num_train_timesteps"] = num_train_timesteps + noise_sched = NoiseSchedules[schedule](**schedule_args) + + # set betas, alphas, alphas_cumprod based off return value from noise function + if isinstance(noise_sched, tuple): + self.betas, self.alphas, self.alphas_cumprod = noise_sched + else: + self.betas = noise_sched + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + self.num_train_timesteps = num_train_timesteps + self.one = torch.tensor(1.0) + + # settable values + self.num_inference_steps: int | None = None + self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) + + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """ + Add noise to the original samples. + + Args: + original_samples: original samples + noise: noise to add to samples + timesteps: timesteps tensor indicating the timestep to be computed for each sample. + + Returns: + noisy_samples: sample with added noise + """ + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) + sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( + (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim + ) + + noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) + sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( + (1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim + ) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/scripts/schedulers/__pycache__/ddpm.cpython-310.pyc b/scripts/schedulers/__pycache__/ddpm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eacc1c5226cbb11764da67d59e5f4faa8d205a6 Binary files /dev/null and b/scripts/schedulers/__pycache__/ddpm.cpython-310.pyc differ diff --git a/scripts/schedulers/__pycache__/rectified_flow.cpython-310.pyc b/scripts/schedulers/__pycache__/rectified_flow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81a0a2e8e65c4764a5a50f78970d6fb654f71100 Binary files /dev/null and b/scripts/schedulers/__pycache__/rectified_flow.cpython-310.pyc differ diff --git a/scripts/schedulers/__pycache__/scheduler.cpython-310.pyc b/scripts/schedulers/__pycache__/scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..066d1746dd94751a6228535b1d16238f4c6847ba Binary files /dev/null and 
b/scripts/schedulers/__pycache__/scheduler.cpython-310.pyc differ diff --git a/scripts/schedulers/ddim.py b/scripts/schedulers/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..50a680336d52013f31457c08a4f02306899de8b0 --- /dev/null +++ b/scripts/schedulers/ddim.py @@ -0,0 +1,294 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from .ddpm import DDPMPredictionType +from .scheduler import Scheduler + +DDIMPredictionType = DDPMPredictionType + + +class DDIMScheduler(Scheduler): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion + Implicit Models" https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, name of noise schedule function in component store + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one. + For the final step there is no previous alpha. When this option is `True` the previous alpha product is + fixed to `1`, otherwise it uses the value of alpha at step 0. + steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. 
+ prediction_type: member of DDPMPredictionType + clip_sample_min: minimum clipping value when clip_sample equals True + clip_sample_max: maximum clipping value when clip_sample equals True + schedule_args: arguments to pass to the schedule function + + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = DDIMPredictionType.EPSILON, + clip_sample_min: float = -1.0, + clip_sample_max: float = 1.0, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if prediction_type not in DDIMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType") + + self.prediction_type = prediction_type + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64)) + + self.clip_sample = clip_sample + self.clip_sample_values = [clip_sample_min, clip_sample_max] + self.steps_offset = steps_offset + + # default the number of inference timesteps to the number of train steps + self.num_inference_steps: int + self.set_timesteps(self.num_train_timesteps) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + if self.steps_offset >= step_ratio: + raise ValueError( + f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to " + f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed" + f" the max train timestep." 
+ ) + + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.steps_offset + + def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + generator: torch.Generator | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + eta: weight of noise for added noise in diffusion step. + generator: random number generator. + + Returns: + pred_prev_sample: Predicted previous sample + pred_original_sample: Predicted original sample + """ + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - model_output -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # predefinitions satisfy pylint/mypy, these values won't be ultimately used + pred_original_sample = sample + pred_epsilon = model_output + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5) + pred_epsilon = model_output + elif self.prediction_type == DDIMPredictionType.SAMPLE: + pred_original_sample = model_output + pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp( + pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1] + ) + + # 5. 
compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance**0.5 + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon + + # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu") + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator, device=device) + variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, pred_original_sample + + def reversed_step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + pred_original_sample: Predicted original sample + """ + # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf + + # Notation ( -> + # - model_output -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_post_sample -> "x_t+1" + + # 1. get previous step value (=t+1) + prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas at timestep t+1 + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # predefinitions satisfy pylint/mypy, these values won't be ultimately used + pred_original_sample = sample + pred_epsilon = model_output + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + + if self.prediction_type == DDIMPredictionType.EPSILON: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.prediction_type == DDIMPredictionType.SAMPLE: + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.prediction_type == DDIMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp( + pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1] + ) + + # 5. 
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + + # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + return pred_post_sample, pred_original_sample diff --git a/scripts/schedulers/ddpm.py b/scripts/schedulers/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b7ab55f516f3cbe101df466a7e6c13a887be66 --- /dev/null +++ b/scripts/schedulers/ddpm.py @@ -0,0 +1,254 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class DDPMVarianceType(StrEnum): + """ + Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise + to the denoised sample. + """ + + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED = "learned" + LEARNED_RANGE = "learned_range" + + +class DDPMPredictionType(StrEnum): + """ + Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + sample: directly predicting the noisy sample + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + SAMPLE = "sample" + V_PREDICTION = "v_prediction" + + +class DDPMScheduler(Scheduler): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models" + https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. 
+ schedule: member of NoiseSchedules, name of noise schedule function in component store + variance_type: member of DDPMVarianceType + clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. + prediction_type: member of DDPMPredictionType + clip_sample_min: minimum clipping value when clip_sample equals True + clip_sample_max: maximum clipping value when clip_sample equals True + schedule_args: arguments to pass to the schedule function + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + variance_type: str = DDPMVarianceType.FIXED_SMALL, + clip_sample: bool = True, + prediction_type: str = DDPMPredictionType.EPSILON, + clip_sample_min: float = -1.0, + clip_sample_max: float = 1.0, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if variance_type not in DDPMVarianceType.__members__.values(): + raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") + + if prediction_type not in DDPMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`") + + self.clip_sample = clip_sample + self.clip_sample_values = [clip_sample_min, clip_sample_max] + self.variance_type = variance_type + self.prediction_type = prediction_type + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + + def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: + """ + Compute the mean of the posterior at timestep t. + + Args: + timestep: current timestep. + x0: the noise-free input. + x_t: the input noised to timestep t. + + Returns: + Returns the mean + """ + # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0), + # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf) + alpha_t = self.alphas[timestep] + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t) + x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) + + mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t + + return mean + + def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor: + """ + Compute the variance of the posterior at timestep t. + + Args: + timestep: current timestep. 
+ predicted_variance: variance predicted by the model. + + Returns: + Returns the variance + """ + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] + # hacks - were probably added for training stability + if self.variance_type == DDPMVarianceType.FIXED_SMALL: + variance = torch.clamp(variance, min=1e-20) + elif self.variance_type == DDPMVarianceType.FIXED_LARGE: + variance = self.betas[timestep] + elif self.variance_type == DDPMVarianceType.LEARNED and predicted_variance is not None: + return predicted_variance + elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None: + min_log = variance + max_log = self.betas[timestep] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + generator: random number generator. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.prediction_type == DDPMPredictionType.EPSILON: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.prediction_type == DDPMPredictionType.SAMPLE: + pred_original_sample = model_output + elif self.prediction_type == DDPMPredictionType.V_PREDICTION: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + + # 3. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp( + pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1] + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t + current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. 
Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance: torch.Tensor = torch.tensor(0) + if timestep > 0: + noise = torch.randn( + model_output.size(), + dtype=model_output.dtype, + layout=model_output.layout, + generator=generator, + device=model_output.device, + ) + variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + return pred_prev_sample, pred_original_sample diff --git a/scripts/schedulers/pndm.py b/scripts/schedulers/pndm.py new file mode 100644 index 0000000000000000000000000000000000000000..c0728bbdff7ed627bb4ee9f766217d85ef8756d8 --- /dev/null +++ b/scripts/schedulers/pndm.py @@ -0,0 +1,316 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch + +from monai.utils import StrEnum + +from .scheduler import Scheduler + + +class PNDMPredictionType(StrEnum): + """ + Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument. + + epsilon: predicting the noise of the diffusion process + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + EPSILON = "epsilon" + V_PREDICTION = "v_prediction" + + +class PNDMScheduler(Scheduler): + """ + Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, + namely Runge-Kutta method and a linear multi-step method. Based on: Liu et al., + "Pseudo Numerical Methods for Diffusion Models on Manifolds" https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps: number of diffusion steps used to train the model. 
+ schedule: member of NoiseSchedules, name of noise schedule function in component store + skip_prk_steps: + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms step. + set_alpha_to_one: + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + prediction_type: member of DDPMPredictionType + steps_offset: + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + schedule_args: arguments to pass to the schedule function + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + schedule: str = "linear_beta", + skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + prediction_type: str = PNDMPredictionType.EPSILON, + steps_offset: int = 0, + **schedule_args, + ) -> None: + super().__init__(num_train_timesteps, schedule, **schedule_args) + + if prediction_type not in PNDMPredictionType.__members__.values(): + raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType") + + self.prediction_type = prediction_type + + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + self.skip_prk_steps = skip_prk_steps + self.steps_offset = steps_offset + + # running values + self.cur_model_output = torch.Tensor() + self.counter = 0 + self.cur_sample = torch.Tensor() + self.ets: list = [] + + # default the number of inference timesteps to the number of train steps + self.set_timesteps(num_train_timesteps) + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64) + self._timesteps += self.steps_offset + + if self.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. 
When using PNDM with `self.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = np.array([]) + self.plms_timesteps = self._timesteps[::-1] + + else: + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy + + timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + # update num_inference_steps - necessary if we use prk steps + self.num_inference_steps = len(self.timesteps) + + self.ets = [] + self.counter = 0 + + def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> tuple[torch.Tensor, Any]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + Returns: + pred_prev_sample: Predicted previous sample + """ + # return a tuple for consistency with samplers that return (previous pred, original sample pred) + + if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps: + return self.step_prk(model_output=model_output, timestep=timestep, sample=sample), None + else: + return self.step_plms(model_output=model_output, timestep=timestep, sample=sample), None + + def step_prk(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> torch.Tensor: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. 
+ + Returns: + pred_prev_sample: Predicted previous sample + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if self.counter % 2 else self.num_train_timesteps // self.num_inference_steps // 2 + prev_timestep = timestep - diff_to_prev + timestep = self.prk_timesteps[self.counter // 4 * 4] + + if self.counter % 4 == 0: + self.cur_model_output = 1 / 6 * model_output + self.ets.append(model_output) + self.cur_sample = sample + elif (self.counter - 1) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 2) % 4 == 0: + self.cur_model_output += 1 / 3 * model_output + elif (self.counter - 3) % 4 == 0: + model_output = self.cur_model_output + 1 / 6 * model_output + self.cur_model_output = torch.Tensor() + + # cur_sample should not be an empty torch.Tensor() + cur_sample = self.cur_sample if self.cur_sample.numel() != 0 else sample + + prev_sample: torch.Tensor = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + self.counter += 1 + + return prev_sample + + def step_plms(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> Any: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + + Returns: + pred_prev_sample: Predicted previous sample + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if not self.skip_prk_steps and len(self.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + ) + + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + + if self.counter != 1: + self.ets = self.ets[-3:] + self.ets.append(model_output) + else: + prev_timestep = timestep + timestep = timestep + self.num_train_timesteps // self.num_inference_steps + + if len(self.ets) == 1 and self.counter == 0: + model_output = model_output + self.cur_sample = sample + elif len(self.ets) == 1 and self.counter == 1: + model_output = (model_output + self.ets[-1]) / 2 + sample = self.cur_sample + self.cur_sample = torch.Tensor() + elif len(self.ets) == 2: + model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) + self.counter += 1 + + return prev_sample + + def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + 
alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + if self.prediction_type == PNDMPredictionType.V_PREDICTION: + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample diff --git a/scripts/schedulers/rectified_flow.py b/scripts/schedulers/rectified_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..e660a1abb6734c93a6b8e312a49bfdc6e86b7591 --- /dev/null +++ b/scripts/schedulers/rectified_flow.py @@ -0,0 +1,322 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py +# which has the following license: +# https://github.com/hpcaitech/Open-Sora/blob/main/LICENSE +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from typing import Union + +import numpy as np +import torch +from torch.distributions import LogisticNormal + +from monai.utils import StrEnum + +from .ddpm import DDPMPredictionType +from .scheduler import Scheduler + + +class RFlowPredictionType(StrEnum): + """ + Set of valid prediction type names for the RFlow scheduler's `prediction_type` argument. + + v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf + """ + + V_PREDICTION = DDPMPredictionType.V_PREDICTION + + +def timestep_transform( + t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3 +): + """ + Applies a transformation to the timestep based on image resolution scaling. 
+
+    Args:
+        t (torch.Tensor): The original timestep(s).
+        input_img_size_numel (torch.Tensor): The input image's size (H * W * D).
+        base_img_size_numel (int): reference H*W*D size, usually smaller than input_img_size_numel.
+        scale (float): Scaling factor for the transformation.
+        num_train_timesteps (int): Total number of training timesteps.
+        spatial_dim (int): Number of spatial dimensions in the image.
+
+    Returns:
+        torch.Tensor: Transformed timestep(s).
+    """
+    t = t / num_train_timesteps
+    ratio_space = (input_img_size_numel / base_img_size_numel) ** (1.0 / spatial_dim)
+
+    ratio = ratio_space * scale
+    new_t = ratio * t / (1 + (ratio - 1) * t)
+
+    new_t = new_t * num_train_timesteps
+    return new_t
+
+
+class RFlowScheduler(Scheduler):
+    """
+    A rectified flow scheduler for guiding the diffusion process in a generative model.
+
+    Supports uniform and logit-normal sampling methods, timestep transformation for
+    different resolutions, and noise addition during diffusion.
+
+    Args:
+        num_train_timesteps (int): Total number of training timesteps.
+        use_discrete_timesteps (bool): Whether to use discrete timesteps.
+        sample_method (str): Training time step sampling method ('uniform' or 'logit-normal').
+        loc (float): Location parameter for logit-normal distribution, used only if sample_method='logit-normal'.
+        scale (float): Scale parameter for logit-normal distribution, used only if sample_method='logit-normal'.
+        use_timestep_transform (bool): Whether to apply timestep transformation.
+            If true, there will be more inference timesteps at early (noisy) stages for larger image volumes.
+        transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.
+        steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True.
+        base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.
+        spatial_dim (int): 2 or 3, indicating 2D or 3D images, used only if use_timestep_transform=True.
+
+    Example:
+
+        .. code-block:: python
+
+            # define a scheduler
+            noise_scheduler = RFlowScheduler(
+                num_train_timesteps = 1000,
+                use_discrete_timesteps = True,
+                sample_method = 'logit-normal',
+                use_timestep_transform = True,
+                base_img_size_numel = 32 * 32 * 32,
+                spatial_dim = 3
+            )
+
+            # during training
+            inputs = torch.ones(2,4,64,64,32)
+            noise = torch.randn_like(inputs)
+            timesteps = noise_scheduler.sample_timesteps(inputs)
+            noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+            predicted_velocity = diffusion_unet(
+                x=noisy_inputs,
+                timesteps=timesteps
+            )
+            loss = loss_l1(predicted_velocity, (inputs - noise))
+
+            # during inference
+            noisy_inputs = torch.randn(2,4,64,64,32)
+            input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:]))
+            noise_scheduler.set_timesteps(
+                num_inference_steps=30, input_img_size_numel=input_img_size_numel
+            )
+            all_next_timesteps = torch.cat(
+                (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))
+            )
+            for t, next_t in tqdm(
+                zip(noise_scheduler.timesteps, all_next_timesteps),
+                total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
+            ):
+                predicted_velocity = diffusion_unet(
+                    x=noisy_inputs,
+                    timesteps=t
+                )
+                noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t)
+            final_output = noisy_inputs
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        use_discrete_timesteps: bool = True,
+        sample_method: str = "uniform",
+        loc: float = 0.0,
+        scale: float = 1.0,
+        use_timestep_transform: bool = False,
+        transform_scale: float = 1.0,
+        steps_offset: int = 0,
+        base_img_size_numel: int = 32 * 32 * 32,
+        spatial_dim: int = 3,
+    ):
+        # rectified flow only accepts velocity prediction
+        self.prediction_type = RFlowPredictionType.V_PREDICTION
+
+        self.num_train_timesteps = num_train_timesteps
+        self.use_discrete_timesteps = use_discrete_timesteps
+        self.base_img_size_numel = base_img_size_numel
+        self.spatial_dim = spatial_dim
+
+        # sample method
+        if sample_method not in ["uniform", "logit-normal"]:
+            raise ValueError(
+                f"sample_method = {sample_method}, which has to be chosen from ['uniform', 'logit-normal']."
+            )
+        self.sample_method = sample_method
+        if sample_method == "logit-normal":
+            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
+            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)
+
+        # timestep transform
+        self.use_timestep_transform = use_timestep_transform
+        self.transform_scale = transform_scale
+        self.steps_offset = steps_offset
+
+    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+        """
+        Add noise to the original samples.
+
+        Args:
+            original_samples: original samples
+            noise: noise to add to samples
+            timesteps: timesteps tensor with shape of (N,), indicating the timestep to be computed for each sample.
+ + Returns: + noisy_samples: sample with added noise + """ + timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps + timepoints = 1 - timepoints # [1,1/1000] + + # expand timepoint to noise shape + if noise.ndim == 5: + timepoints = timepoints[..., None, None, None, None].expand(-1, *noise.shape[1:]) + elif noise.ndim == 4: + timepoints = timepoints[..., None, None, None].expand(-1, *noise.shape[1:]) + else: + raise ValueError(f"noise tensor has to be 4D or 5D tensor, yet got shape of {noise.shape}") + + noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise + + return noisy_samples + + def set_timesteps( + self, + num_inference_steps: int, + device: str | torch.device | None = None, + input_img_size_numel: int | None = None, + ) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + input_img_size_numel: int, H*W*D of the image, used with self.use_timestep_transform is True. + """ + if num_inference_steps > self.num_train_timesteps or num_inference_steps < 1: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} should be at least 1, " + "and cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + # prepare timesteps + timesteps = [ + (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps) + ] + if self.use_discrete_timesteps: + timesteps = [int(round(t)) for t in timesteps] + if self.use_timestep_transform: + timesteps = [ + timestep_transform( + t, + input_img_size_numel=input_img_size_numel, + base_img_size_numel=self.base_img_size_numel, + num_train_timesteps=self.num_train_timesteps, + spatial_dim=self.spatial_dim, + ) + for t in timesteps + ] + timesteps_np = np.array(timesteps).astype(np.float16) + if self.use_discrete_timesteps: + timesteps_np = timesteps_np.astype(np.int64) + self.timesteps = torch.from_numpy(timesteps_np).to(device) + self.timesteps += self.steps_offset + + def sample_timesteps(self, x_start): + """ + Randomly samples training timesteps using the chosen sampling method. + + Args: + x_start (torch.Tensor): The input tensor for sampling. + + Returns: + torch.Tensor: Sampled timesteps. + """ + if self.sample_method == "uniform": + t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps + elif self.sample_method == "logit-normal": + t = self.sample_t(x_start) * self.num_train_timesteps + + if self.use_discrete_timesteps: + t = t.long() + + if self.use_timestep_transform: + input_img_size_numel = torch.prod(torch.tensor(x_start.shape[2:])) + t = timestep_transform( + t, + input_img_size_numel=input_img_size_numel, + base_img_size_numel=self.base_img_size_numel, + num_train_timesteps=self.num_train_timesteps, + spatial_dim=len(x_start.shape) - 2, + ) + + return t + + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Predicts the next sample in the diffusion process. + + Args: + model_output (torch.Tensor): Output from the trained diffusion model. 
+ timestep (int): Current timestep in the diffusion chain. + sample (torch.Tensor): Current sample in the process. + next_timestep (Union[int, None]): Optional next timestep. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Predicted sample at the next step and additional info. + """ + # Ensure num_inference_steps exists and is a valid integer + if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int): + raise AttributeError( + "num_inference_steps is missing or not an integer in the class." + "Please run self.set_timesteps(num_inference_steps,device,input_img_size_numel) to set it." + ) + + v_pred = model_output + + if next_timestep is not None: + next_timestep = int(next_timestep) + dt: float = ( + float(timestep - next_timestep) / self.num_train_timesteps + ) # Now next_timestep is guaranteed to be int + else: + dt = ( + 1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0 + ) # Avoid division by zero + + pred_post_sample = sample + v_pred * dt + pred_original_sample = sample + v_pred * timestep / self.num_train_timesteps + + return pred_post_sample, pred_original_sample diff --git a/scripts/schedulers/scheduler.py b/scripts/schedulers/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..71f4f082c0312c80871befb9651c58823318003f --- /dev/null +++ b/scripts/schedulers/scheduler.py @@ -0,0 +1,207 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + + +from __future__ import annotations + +import torch +import torch.nn as nn + +from monai.utils import ComponentStore, unsqueeze_right + +NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") + + +@NoiseSchedules.add_def("linear_beta", "Linear beta schedule") +def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Linear beta noise schedule function. 
+ + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + +@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") +def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): + """ + Scaled linear beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + + Returns: + betas: beta schedule tensor + """ + return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + +@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") +def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6): + """ + Sigmoid beta noise schedule function. + + Args: + num_train_timesteps: number of timesteps + beta_start: start of beta range, default 1e-4 + beta_end: end of beta range, default 2e-2 + sig_range: pos/neg range of sigmoid input, default 6 + + Returns: + betas: beta schedule tensor + """ + betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) + return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start + + +@NoiseSchedules.add_def("cosine", "Cosine schedule") +def _cosine_beta(num_train_timesteps: int, s: float = 8e-3): + """ + Cosine noise schedule, see https://arxiv.org/abs/2102.09672 + + Args: + num_train_timesteps: number of timesteps + s: smoothing factor, default 8e-3 (see referenced paper) + + Returns: + (betas, alphas, alpha_cumprod) values + """ + x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) + alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod /= alphas_cumprod[0].item() + betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + betas = torch.clip(betas, 0.0, 0.999) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + return betas, alphas, alphas_cumprod + + +class Scheduler(nn.Module): + """ + Base class for other schedulers based on a noise schedule function. + + This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here + the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`, + which is the name of a component in NoiseSchedules. These components must all be callables which return either + the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions + can be provided by using the NoiseSchedules.add_def, for example: + + .. code-block:: python + + from monai.networks.schedulers import NoiseSchedules, DDPMScheduler + + @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function") + def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): + return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + + scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule") + + All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of + timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through + the constructor's `schedule_args` value. 
To see what noise functions are available, print the object NoiseSchedules + to get a listing of stored objects with their docstring descriptions. + + Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule + type, this now replaced with `schedule` and most names used with the previous argument now have "_beta" appended + to them, eg. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are + still used for some schedules but these are provided as keyword arguments now. + + Args: + num_train_timesteps: number of diffusion steps used to train the model. + schedule: member of NoiseSchedules, + a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple + schedule_args: arguments to pass to the schedule function + """ + + def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None: + super().__init__() + schedule_args["num_train_timesteps"] = num_train_timesteps + noise_sched = NoiseSchedules[schedule](**schedule_args) + + # set betas, alphas, alphas_cumprod based off return value from noise function + if isinstance(noise_sched, tuple): + self.betas, self.alphas, self.alphas_cumprod = noise_sched + else: + self.betas = noise_sched + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + self.num_train_timesteps = num_train_timesteps + self.one = torch.tensor(1.0) + + # settable values + self.num_inference_steps: int | None = None + self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) + + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + """ + Add noise to the original samples. + + Args: + original_samples: original samples + noise: noise to add to samples + timesteps: timesteps tensor indicating the timestep to be computed for each sample. 
+ + Returns: + noisy_samples: sample with added noise + """ + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) + sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( + (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim + ) + + noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) + sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right( + (1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim + ) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity diff --git a/scripts/train_controlnet.py b/scripts/train_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2cce1ef5f925d65a0269a8aee094b9e7c23c5e --- /dev/null +++ b/scripts/train_controlnet.py @@ -0,0 +1,352 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
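+
+# NOTE (illustrative usage sketch, not part of the original script): this module parses
+# -e/--environment-file, -c/--config-file, -t/--training-config and -g/--gpus, and reads
+# LOCAL_RANK / WORLD_SIZE from the environment when more than one GPU is requested.
+# Assuming it is run as a module from the repository root (paths below are placeholders),
+# a launch might look like:
+#
+#   # single GPU
+#   python -m scripts.train_controlnet -e <env.json> -c <model_config.json> -t <train_config.json> -g 1
+#
+#   # multi GPU via torchrun, which sets LOCAL_RANK and WORLD_SIZE for each worker process
+#   torchrun --nproc_per_node=2 -m scripts.train_controlnet -e <env.json> -c <model_config.json> -t <train_config.json> -g 2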
+ +import argparse +import json +import logging +import os +import sys +import time +from datetime import timedelta +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from monai.networks.utils import copy_model_state +from monai.utils import RankFilter +from monai.networks.schedulers import RFlowScheduler +#from .schedulers.rectified_flow import RFlowScheduler +from monai.networks.schedulers.ddpm import DDPMPredictionType +from torch.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from .utils import binarize_labels, define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp + + +def main(): + parser = argparse.ArgumentParser(description="maisi.controlnet.training") + parser.add_argument( + "-e", + "--environment-file", + default="./configs/environment_maisi_controlnet_train.json", + help="environment json file that stores environment path", + ) + parser.add_argument( + "-c", + "--config-file", + default="./configs/config_maisi-ddpm.json", + help="config json file that stores network hyper-parameters", + ) + parser.add_argument( + "-t", + "--training-config", + default="./configs/config_maisi_controlnet_train.json", + help="config json file that stores training hyper-parameters", + ) + parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node") + + args = parser.parse_args() + + # Step 0: configuration + logger = logging.getLogger("maisi.controlnet.training") + # whether to use distributed data parallel + use_ddp = args.gpus > 1 + if use_ddp: + rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + device = setup_ddp(rank, world_size) + logger.addFilter(RankFilter()) + else: + rank = 0 + world_size = 1 + device = torch.device(f"cuda:{rank}") + + torch.cuda.set_device(device) + logger.info(f"Number of GPUs: {torch.cuda.device_count()}") + logger.info(f"World_size: {world_size}") + + with open(args.environment_file, "r") as env_file: + env_dict = json.load(env_file) + with open(args.config_file, "r") as config_file: + config_dict = json.load(config_file) + with open(args.training_config, "r") as training_config_file: + training_config_dict = json.load(training_config_file) + + for k, v in env_dict.items(): + setattr(args, k, v) + for k, v in config_dict.items(): + setattr(args, k, v) + for k, v in training_config_dict.items(): + setattr(args, k, v) + + # initialize tensorboard writer + if rank == 0: + tensorboard_path = os.path.join(args.tfevent_path, args.exp_name) + Path(tensorboard_path).mkdir(parents=True, exist_ok=True) + tensorboard_writer = SummaryWriter(tensorboard_path) + + # Step 1: set data loader + train_loader, _ = prepare_maisi_controlnet_json_dataloader( + json_data_list=args.json_data_list, + data_base_dir=args.data_base_dir, + rank=rank, + world_size=world_size, + batch_size=args.controlnet_train["batch_size"], + cache_rate=args.controlnet_train["cache_rate"], + fold=args.controlnet_train["fold"], + ) + + # Step 2: define diffusion model and controlnet + # define diffusion Model + unet = define_instance(args, "diffusion_unet_def").to(device) + include_body_region = unet.include_top_region_index_input + include_modality = unet.num_class_embeds is not None + + # load trained diffusion model + if args.trained_diffusion_path is not None: + if not os.path.exists(args.trained_diffusion_path): + raise ValueError("Please download the trained diffusion unet 
checkpoint.") + diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device, weights_only=False) + unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"]) + # load scale factor from diffusion model checkpoint + scale_factor = diffusion_model_ckpt["scale_factor"] + logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.") + logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.") + else: + logger.info("trained diffusion model is not loaded.") + scale_factor = 1.0 + logger.info(f"set scale_factor -> {scale_factor}.") + + # define ControlNet + controlnet = define_instance(args, "controlnet_def").to(device) + # copy weights from the DM to the controlnet + copy_model_state(controlnet, unet.state_dict()) + # load trained controlnet model if it is provided + if args.trained_controlnet_path is not None: + if not os.path.exists(args.trained_controlnet_path): + raise ValueError("Please download the trained ControlNet checkpoint.") + controlnet.load_state_dict( + torch.load(args.trained_controlnet_path, map_location=device, weights_only=False)["controlnet_state_dict"] + ) + logger.info(f"load trained controlnet model from {args.trained_controlnet_path}") + else: + logger.info("train controlnet model from scratch.") + # we freeze the parameters of the diffusion model. + for p in unet.parameters(): + p.requires_grad = False + + noise_scheduler = define_instance(args, "noise_scheduler") + + if use_ddp: + controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True) + + # Step 3: training config + weighted_loss = args.controlnet_train["weighted_loss"] + weighted_loss_label = args.controlnet_train["weighted_loss_label"] + optimizer = torch.optim.AdamW(params=controlnet.parameters(), lr=args.controlnet_train["lr"]) + total_steps = (args.controlnet_train["n_epochs"] * len(train_loader.dataset)) / args.controlnet_train["batch_size"] + logger.info(f"total number of training steps: {total_steps}.") + + lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0) + + # Step 4: training + n_epochs = args.controlnet_train["n_epochs"] + scaler = GradScaler("cuda") + total_step = 0 + best_loss = 1e4 + + if weighted_loss > 1.0: + logger.info(f"apply weighted loss = {weighted_loss} on labels: {weighted_loss_label}") + + controlnet.train() + unet.eval() + prev_time = time.time() + for epoch in range(n_epochs): + epoch_loss_ = 0 + for step, batch in enumerate(train_loader): + # logger.info(f"Reading image: {batch['image']}") + # logger.info(f"Reading image: {batch['image'].meta['filename_or_obj']}") + # logger.info(f"[Data Load] Step {step+1}: Reading sample {batch.get('image', 'N/A')}") + # get image embedding and label mask and scale image embedding by the provided scale_factor + images = batch["image"].to(device) * scale_factor + labels = batch["label"].to(device) + # get corresponding conditions + if include_body_region: + top_region_index_tensor = batch["top_region_index"].to(device) + bottom_region_index_tensor = batch["bottom_region_index"].to(device) + # We trained with only CT in this version + if include_modality: + modality_tensor = torch.ones((len(images),), dtype=torch.long).to(device) + spacing_tensor = batch["spacing"].to(device) + + optimizer.zero_grad(set_to_none=True) + + with autocast("cuda", enabled=True): + # generate random noise + noise_shape = list(images.shape) + noise = torch.randn(noise_shape, dtype=images.dtype).to(device) + + # use binary 
encoding to encode segmentation mask + controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float() + + # create timesteps + if isinstance(noise_scheduler, RFlowScheduler): + timesteps = noise_scheduler.sample_timesteps(images) + else: + timesteps = torch.randint( + 0, noise_scheduler.num_train_timesteps, (images.shape[0],), device=images.device + ).long() + + # create noisy latent + noisy_latent = noise_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps) + + # get controlnet output + # Create a dictionary to store the inputs + controlnet_inputs = { + "x": noisy_latent, + "timesteps": timesteps, + "controlnet_cond": controlnet_cond, + } + if include_modality: + controlnet_inputs.update( + { + "class_labels": modality_tensor, + } + ) + down_block_res_samples, mid_block_res_sample = controlnet(**controlnet_inputs) + + # get diffusion network output + # Create a dictionary to store the inputs + unet_inputs = { + "x": noisy_latent, + "timesteps": timesteps, + "spacing_tensor": spacing_tensor, + "down_block_additional_residuals": down_block_res_samples, + "mid_block_additional_residual": mid_block_res_sample, + } + # Add extra arguments if include_body_region is True + if include_body_region: + unet_inputs.update( + { + "top_region_index_tensor": top_region_index_tensor, + "bottom_region_index_tensor": bottom_region_index_tensor, + } + ) + if include_modality: + unet_inputs.update( + { + "class_labels": modality_tensor, + } + ) + model_output = unet(**unet_inputs) + + if noise_scheduler.prediction_type == DDPMPredictionType.EPSILON: + # predict noise + model_gt = noise + elif noise_scheduler.prediction_type == DDPMPredictionType.SAMPLE: + # predict sample + model_gt = images + elif noise_scheduler.prediction_type == DDPMPredictionType.V_PREDICTION: + # predict velocity + model_gt = images - noise + else: + raise ValueError( + "noise scheduler prediction type has to be chosen from ", + f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},{DDPMPredictionType.V_PREDICTION}]", + ) + + if weighted_loss > 1.0: + weights = torch.ones_like(images).to(images.device) + roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(images.device) + interpolate_label = F.interpolate(labels, size=images.shape[2:], mode="nearest") + # assign larger weights for ROI (tumor) + for label in weighted_loss_label: + roi[interpolate_label == label] = 1 + weights[roi.repeat(1, images.shape[1], 1, 1, 1) == 1] = weighted_loss + loss = (F.l1_loss(model_output.float(), model_gt.float(), reduction="none") * weights).mean() + else: + loss = F.l1_loss(model_output.float(), model_gt.float()) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + lr_scheduler.step() + total_step += 1 + + if rank == 0: + # write train loss for each batch into tensorboard + tensorboard_writer.add_scalar( + "train/train_controlnet_loss_iter", loss.detach().cpu().item(), total_step + ) + batches_done = step + 1 + batches_left = len(train_loader) - batches_done + time_left = timedelta(seconds=batches_left * (time.time() - prev_time)) + prev_time = time.time() + logger.info( + "\r[Epoch %d/%d] [Batch %d/%d] [LR: %.8f] [loss: %.4f] ETA: %s " + % ( + epoch + 1, + n_epochs, + step + 1, + len(train_loader), + lr_scheduler.get_last_lr()[0], + loss.detach().cpu().item(), + time_left, + ) + ) + epoch_loss_ += loss.detach() + + epoch_loss = epoch_loss_ / (step + 1) + + if use_ddp: + dist.barrier() + dist.all_reduce(epoch_loss, op=torch.distributed.ReduceOp.AVG) + + if rank == 0: + 
tensorboard_writer.add_scalar("train/train_controlnet_loss_epoch", epoch_loss.cpu().item(), total_step) + # save controlnet only on master GPU (rank 0) + controlnet_state_dict = controlnet.module.state_dict() if world_size > 1 else controlnet.state_dict() + torch.save( + { + "epoch": epoch + 1, + "loss": epoch_loss, + "controlnet_state_dict": controlnet_state_dict, + }, + f"{args.model_dir}/{args.exp_name}_current.pt", + ) + + if epoch_loss < best_loss: + best_loss = epoch_loss + logger.info(f"best loss -> {best_loss}.") + torch.save( + { + "epoch": epoch + 1, + "loss": best_loss, + "controlnet_state_dict": controlnet_state_dict, + }, + f"{args.model_dir}/{args.exp_name}_best.pt", + ) + + torch.cuda.empty_cache() + if use_ddp: + dist.destroy_process_group() + + +if __name__ == "__main__": + logging.basicConfig( + stream=sys.stdout, + level=logging.INFO, + format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + main() \ No newline at end of file diff --git a/scripts/transforms.py b/scripts/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..8a8237d765a765ec4ac3e170a9ecb348fceb0261 --- /dev/null +++ b/scripts/transforms.py @@ -0,0 +1,324 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional + +import torch +from monai.transforms import ( + Compose, + DivisiblePadd, + EnsureChannelFirstd, + EnsureTyped, + Lambdad, + LoadImaged, + Orientationd, + RandAdjustContrastd, + RandBiasFieldd, + RandFlipd, + RandGibbsNoised, + RandHistogramShiftd, + RandRotate90d, + RandRotated, + RandScaleIntensityd, + RandShiftIntensityd, + RandSpatialCropd, + RandZoomd, + ResizeWithPadOrCropd, + ScaleIntensityRanged, + ScaleIntensityRangePercentilesd, + SelectItemsd, + Spacingd, + SpatialPadd, +) + +SUPPORT_MODALITIES = ["ct", "mri"] + + +def define_fixed_intensity_transform(modality: str, image_keys: List[str] = ["image"]) -> List: + """ + Define fixed intensity transform based on the modality. + + Args: + modality (str): The imaging modality, either 'ct' or 'mri'. + image_keys (List[str], optional): List of image keys. Defaults to ["image"]. + + Returns: + List: A list of intensity transforms. + """ + if modality not in SUPPORT_MODALITIES: + warnings.warn( + f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities." 
+ ) + + modality = modality.lower() # Normalize modality to lowercase + + intensity_transforms = { + "mri": [ + ScaleIntensityRangePercentilesd(keys=image_keys, lower=0.0, upper=99.5, b_min=0.0, b_max=1, clip=False) + ], + "ct": [ScaleIntensityRanged(keys=image_keys, a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True)], + } + + if modality not in intensity_transforms: + return [] + + return intensity_transforms[modality] + + +def define_random_intensity_transform(modality: str, image_keys: List[str] = ["image"]) -> List: + """ + Define random intensity transform based on the modality. + + Args: + modality (str): The imaging modality, either 'ct' or 'mri'. + image_keys (List[str], optional): List of image keys. Defaults to ["image"]. + + Returns: + List: A list of random intensity transforms. + """ + modality = modality.lower() # Normalize modality to lowercase + if modality not in SUPPORT_MODALITIES: + warnings.warn( + f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities." + ) + + if modality == "ct": + return [] # CT HU intensity is stable across different datasets + elif modality == "mri": + return [ + RandBiasFieldd(keys=image_keys, prob=0.3, coeff_range=(0.0, 0.3)), + RandGibbsNoised(keys=image_keys, prob=0.3, alpha=(0.5, 1.0)), + RandAdjustContrastd(keys=image_keys, prob=0.3, gamma=(0.5, 2.0)), + RandHistogramShiftd(keys=image_keys, prob=0.05, num_control_points=10), + ] + else: + return [] + + +def define_vae_transform( + is_train: bool, + modality: str, + random_aug: bool, + k: int = 4, + patch_size: List[int] = [128, 128, 128], + val_patch_size: Optional[List[int]] = None, + output_dtype: torch.dtype = torch.float32, + spacing_type: str = "original", + spacing: Optional[List[float]] = None, + image_keys: List[str] = ["image"], + label_keys: List[str] = [], + additional_keys: List[str] = [], + select_channel: int = 0, +) -> tuple: + """ + Define the MAISI VAE transform pipeline for training or validation. + + Args: + is_train (bool): Whether it's for training or not. If True, the output transform will consider random_aug, the cropping will use "patch_size" for random cropping. If False, the output transform will alwasy treat "random_aug" as False, will use "val_patch_size" for central cropping. + modality (str): The imaging modality, either 'ct' or 'mri'. + random_aug (bool): Whether to apply random data augmentation. + k (int, optional): Patches should be divisible by k. Defaults to 4. + patch_size (List[int], optional): Size of the patches. Defaults to [128, 128, 128]. Will random crop patch for training. + val_patch_size (Optional[List[int]], optional): Size of validation patches. Defaults to None. If None, will use the whole volume for validation. If given, will central crop a patch for validation. + output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32. + spacing_type (str, optional): Type of spacing. Defaults to "original". Choose from ["original", "fixed", "rand_zoom"]. + spacing (Optional[List[float]], optional): Spacing values. Defaults to None. + image_keys (List[str], optional): List of image keys. Defaults to ["image"]. + label_keys (List[str], optional): List of label keys. Defaults to []. + additional_keys (List[str], optional): List of additional keys. Defaults to []. + select_channel (int, optional): Channel to select for multi-channel MRI. Defaults to 0. 
+ + Returns: + tuple: A tuple containing Composed Transform train_transforms or val_transforms depending on 'is_train'. + """ + modality = modality.lower() # Normalize modality to lowercase + if modality not in SUPPORT_MODALITIES: + warnings.warn( + f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities." + ) + + if spacing_type not in ["original", "fixed", "rand_zoom"]: + raise ValueError(f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}.") + + keys = image_keys + label_keys + additional_keys + interp_mode = ["bilinear"] * len(image_keys) + ["nearest"] * len(label_keys) + + common_transform = [ + SelectItemsd(keys=keys, allow_missing_keys=True), + LoadImaged(keys=keys, allow_missing_keys=True), + EnsureChannelFirstd(keys=keys, allow_missing_keys=True), + Orientationd(keys=keys, axcodes="RAS", allow_missing_keys=True), + ] + + if modality == "mri": + common_transform.append(Lambdad(keys=image_keys, func=lambda x: x[select_channel : select_channel + 1, ...])) + + common_transform.extend(define_fixed_intensity_transform(modality, image_keys=image_keys)) + + if spacing_type == "fixed": + common_transform.append( + Spacingd(keys=image_keys + label_keys, allow_missing_keys=True, pixdim=spacing, mode=interp_mode) + ) + + random_transform = [] + if is_train and random_aug: + random_transform.extend(define_random_intensity_transform(modality, image_keys=image_keys)) + random_transform.extend( + [RandFlipd(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axis=axis) for axis in range(3)] + + [ + RandRotate90d(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axes=axes) + for axes in [(0, 1), (1, 2), (0, 2)] + ] + + [ + RandScaleIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, factors=(0.9, 1.1)), + RandShiftIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, offsets=0.05), + ] + ) + + if spacing_type == "rand_zoom": + random_transform.extend( + [ + RandZoomd( + keys=image_keys + label_keys, + allow_missing_keys=True, + prob=0.3, + min_zoom=0.5, + max_zoom=1.5, + keep_size=False, + mode=interp_mode, + ), + RandRotated( + keys=image_keys + label_keys, + allow_missing_keys=True, + prob=0.3, + range_x=0.1, + range_y=0.1, + range_z=0.1, + keep_size=True, + mode=interp_mode, + ), + ] + ) + + if is_train: + train_crop = [ + SpatialPadd(keys=keys, spatial_size=patch_size, allow_missing_keys=True), + RandSpatialCropd( + keys=keys, roi_size=patch_size, allow_missing_keys=True, random_size=False, random_center=True + ), + ] + else: + val_crop = ( + [DivisiblePadd(keys=keys, allow_missing_keys=True, k=k)] + if val_patch_size is None + else [ResizeWithPadOrCropd(keys=keys, allow_missing_keys=True, spatial_size=val_patch_size)] + ) + + final_transform = [EnsureTyped(keys=keys, dtype=output_dtype, allow_missing_keys=True)] + + if is_train: + train_transforms = Compose( + common_transform + random_transform + train_crop + final_transform + if random_aug + else common_transform + train_crop + final_transform + ) + return train_transforms + else: + val_transforms = Compose(common_transform + val_crop + final_transform) + return val_transforms + + +class VAE_Transform: + """ + A class to handle MAISI VAE transformations for different modalities. 
+ """ + + def __init__( + self, + is_train: bool, + random_aug: bool, + k: int = 4, + patch_size: List[int] = [128, 128, 128], + val_patch_size: Optional[List[int]] = None, + output_dtype: torch.dtype = torch.float32, + spacing_type: str = "original", + spacing: Optional[List[float]] = None, + image_keys: List[str] = ["image"], + label_keys: List[str] = [], + additional_keys: List[str] = [], + select_channel: int = 0, + ): + """ + Initialize the VAE_Transform. + + Args: + is_train (bool): Whether it's for training or not. If True, the output transform will consider random_aug, the cropping will use "patch_size" for random cropping. If False, the output transform will alwasy treat "random_aug" as False, will use "val_patch_size" for central cropping. + random_aug (bool): Whether to apply random data augmentation for training. + k (int, optional): Patches should be divisible by k. Defaults to 4. + patch_size (List[int], optional): Size of the patches. Defaults to [128, 128, 128]. Will random crop patch for training. + val_patch_size (Optional[List[int]], optional): Size of validation patches. Defaults to None. If None, will use the whole volume for validation. If given, will central crop a patch for validation. + output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32. + spacing_type (str, optional): Type of spacing. Defaults to "original". Choose from ["original", "fixed", "rand_zoom"]. + spacing (Optional[List[float]], optional): Spacing values. Defaults to None. + image_keys (List[str], optional): List of image keys. Defaults to ["image"]. + label_keys (List[str], optional): List of label keys. Defaults to []. + additional_keys (List[str], optional): List of additional keys. Defaults to []. + select_channel (int, optional): Channel to select for multi-channel MRI. Defaults to 0. + """ + if spacing_type not in ["original", "fixed", "rand_zoom"]: + raise ValueError( + f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}." + ) + + self.is_train = is_train + self.transform_dict = {} + + for modality in ["ct", "mri"]: + self.transform_dict[modality] = define_vae_transform( + is_train=is_train, + modality=modality, + random_aug=random_aug, + k=k, + patch_size=patch_size, + val_patch_size=val_patch_size, + output_dtype=output_dtype, + spacing_type=spacing_type, + spacing=spacing, + image_keys=image_keys, + label_keys=label_keys, + additional_keys=additional_keys, + select_channel=select_channel, + ) + + def __call__(self, img: dict, fixed_modality: Optional[str] = None) -> dict: + """ + Apply the appropriate transform to the input image. + + Args: + img (dict): Input image dictionary. + fixed_modality (Optional[str], optional): Fixed modality to use. Defaults to None. + + Returns: + Composed Transform + + Raises: + ValueError: If the modality is not 'ct' or 'mri'. + """ + modality = fixed_modality or img["class"] + modality = modality.lower() # Normalize modality to lowercase + if modality not in ["ct", "mri"]: + warnings.warn( + f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities." 
+        )
+
+        transform = self.transform_dict[modality]
+        return transform(img)
diff --git a/scripts/utils.py b/scripts/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0bccffd9d0f4d442691f4b7ced32b7a62675390
--- /dev/null
+++ b/scripts/utils.py
@@ -0,0 +1,884 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import json
+import logging
+import math
+import os
+from argparse import Namespace
+from datetime import timedelta
+from typing import Any, Sequence
+
+import numpy as np
+import skimage
+import torch
+import torch.distributed as dist
+from monai.bundle import ConfigParser
+from monai.config import DtypeLike, NdarrayOrTensor
+from monai.data import CacheDataset, DataLoader, partition_dataset
+from monai.transforms import Compose, EnsureTyped, Lambdad, LoadImaged, MapLabelValue, Orientationd
+from monai.transforms.utils_morphological_ops import dilate, erode
+from monai.utils import TransformBackends, convert_data_type, convert_to_dst_type, get_equivalent_dtype
+from scipy import stats
+from torch import Tensor
+
+
+def remap_labels(mask, label_dict_remap_json):
+    """
+    Remap labels in the mask according to the provided label dictionary.
+
+    This function reads a JSON file containing label mapping information and applies
+    the mapping to the input mask.
+
+    Args:
+        mask (Tensor): The input mask tensor to be remapped.
+        label_dict_remap_json (str): Path to the JSON file containing the label mapping dictionary.
+
+    Returns:
+        Tensor: The remapped mask tensor.
+    """
+    with open(label_dict_remap_json, "r") as f:
+        mapping_dict = json.load(f)
+    mapper = MapLabelValue(
+        orig_labels=[pair[0] for pair in mapping_dict.values()],
+        target_labels=[pair[1] for pair in mapping_dict.values()],
+        dtype=torch.uint8,
+    )
+    return mapper(mask[0, ...])[None, ...].to(mask.device)
+
+
+def get_index_arr(img):
+    """
+    Generate an index array for the given image.
+
+    This function creates a 3D array of indices corresponding to the dimensions of the input image.
+
+    Args:
+        img (ndarray): The input image array.
+
+    Returns:
+        ndarray: A 3D array containing the indices for each dimension of the input image.
+    """
+    return np.moveaxis(
+        np.moveaxis(
+            np.stack(np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]), np.arange(img.shape[2]))), 0, 3
+        ),
+        0,
+        1,
+    )
+
+
+def supress_non_largest_components(img, target_label, default_val=0):
+    """
+    Suppress all components except the largest one(s) for specified target labels.
+
+    This function identifies the largest component(s) for each target label and
+    suppresses all other smaller components.
+
+    Args:
+        img (ndarray): The input image array.
+        target_label (list): List of label values to process.
+        default_val (int, optional): Value to assign to suppressed voxels. Defaults to 0.
+
+    Returns:
+        tuple: A tuple containing:
+            - ndarray: Modified image with non-largest components suppressed.
+            - int: Number of voxels that were changed.
+ """ + index_arr = get_index_arr(img) + img_mod = copy.deepcopy(img) + new_background = np.zeros(img.shape, dtype=np.bool_) + for label in target_label: + label_cc = skimage.measure.label(img == label, connectivity=3) + uv, uc = np.unique(label_cc, return_counts=True) + dominant_vals = uv[np.argsort(uc)[::-1][:2]] + if len(dominant_vals) >= 2: # Case: no predictions + new_background = np.logical_or( + new_background, + np.logical_not(np.logical_or(label_cc == dominant_vals[0], label_cc == dominant_vals[1])), + ) + + for voxel in index_arr[new_background]: + img_mod[tuple(voxel)] = default_val + diff = np.sum((img - img_mod) > 0) + + return img_mod, diff + + +def erode_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor: + """ + Erode 2D/3D binary mask with data type as torch tensor. + + Args: + mask_t: input 2D/3D binary mask, [M,N] or [M,N,P] torch tensor. + filter_size: erosion filter size, has to be odd numbers, default to be 3. + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value + and not changed. + + Return: + Tensor: eroded mask, same shape as input. + """ + return ( + erode( + mask_t.float() + .unsqueeze(0) + .unsqueeze( + 0, + ), + filter_size, + pad_value=pad_value, + ) + .squeeze(0) + .squeeze(0) + ) + + +def dilate_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor: + """ + Dilate 2D/3D binary mask with data type as torch tensor. + + Args: + mask_t: input 2D/3D binary mask, [M,N] or [M,N,P] torch tensor. + filter_size: dilation filter size, has to be odd numbers, default to be 3. + pad_value: the filled value for padding. We need to pad the input before filtering + to keep the output with the same size as input. Usually use default value + and not changed. + + Return: + Tensor: dilated mask, same shape as input. + """ + return ( + dilate( + mask_t.float() + .unsqueeze(0) + .unsqueeze( + 0, + ), + filter_size, + pad_value=pad_value, + ) + .squeeze(0) + .squeeze(0) + ) + + +def binarize_labels(x: Tensor, bits: int = 8) -> Tensor: + """ + Convert input tensor to binary representation. + + This function takes an input tensor and converts it to a binary representation + using the specified number of bits. + + Args: + x (Tensor): Input tensor with shape (B, 1, H, W, D). + bits (int, optional): Number of bits to use for binary representation. Defaults to 8. + + Returns: + Tensor: Binary representation of the input tensor with shape (B, bits, H, W, D). + """ + mask = 2 ** torch.arange(bits).to(x.device, x.dtype) + return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte().squeeze(1).permute(0, 4, 1, 2, 3) + + +def setup_ddp(rank: int, world_size: int) -> torch.device: + """ + Initialize the distributed process group. + + Args: + rank (int): rank of the current process. + world_size (int): number of processes participating in the job. + + Returns: + torch.device: device of the current process. + """ + dist.init_process_group( + backend="nccl", init_method="env://", timeout=timedelta(seconds=36000), rank=rank, world_size=world_size + ) + dist.barrier() + device = torch.device(f"cuda:{rank}") + return device + + +def define_instance(args: Namespace, instance_def_key: str) -> Any: + """ + Define and instantiate an object based on the provided arguments and instance definition key. + + This function uses a ConfigParser to parse the arguments and instantiate an object + defined by the instance_def_key. 
+ + Args: + args: An object containing the arguments to be parsed. + instance_def_key (str): The key used to retrieve the instance definition from the parsed content. + + Returns: + The instantiated object as defined by the instance_def_key in the parsed configuration. + """ + parser = ConfigParser(vars(args)) + parser.parse(True) + return parser.get_parsed_content(instance_def_key, instantiate=True) + + +def add_data_dir2path(list_files: list, data_dir: str, fold: int = None) -> tuple[list, list]: + """ + Read a list of data dictionary. + + Args: + list_files (list): input data to load and transform to generate dataset for model. + data_dir (str): directory of files. + fold (int, optional): fold index for cross validation. Defaults to None. + + Returns: + tuple[list, list]: A tuple of two arrays (training, validation). + """ + new_list_files = copy.deepcopy(list_files) + if fold is not None: + new_list_files_train = [] + new_list_files_val = [] + for d in new_list_files: + d["image"] = os.path.join(data_dir, d["image"]) + + if "label" in d: + d["label"] = os.path.join(data_dir, d["label"]) + + if fold is not None: + if d["fold"] == fold: + new_list_files_val.append(copy.deepcopy(d)) + else: + new_list_files_train.append(copy.deepcopy(d)) + + if fold is not None: + return new_list_files_train, new_list_files_val + else: + return new_list_files, [] + + +def prepare_maisi_controlnet_json_dataloader( + json_data_list: list | str, + data_base_dir: list | str, + batch_size: int = 1, + fold: int = 0, + cache_rate: float = 0.0, + rank: int = 0, + world_size: int = 1, +) -> tuple[DataLoader, DataLoader]: + """ + Prepare dataloaders for training and validation. + + Args: + json_data_list (list | str): the name of JSON files listing the data. + data_base_dir (list | str): directory of files. + batch_size (int, optional): how many samples per batch to load . Defaults to 1. + fold (int, optional): fold index for cross validation. Defaults to 0. + cache_rate (float, optional): percentage of cached data in total. Defaults to 0.0. + rank (int, optional): rank of the current process. Defaults to 0. + world_size (int, optional): number of processes participating in the job. Defaults to 1. + + Returns: + tuple[DataLoader, DataLoader]: A tuple of two dataloaders (training, validation). 
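+
+    Example (illustrative sketch; file names are hypothetical, and each "training" entry is
+    expected to provide "image", "label", "spacing", and "fold" fields):
+
+        train_loader, val_loader = prepare_maisi_controlnet_json_dataloader(
+            json_data_list="datalist.json", data_base_dir="./data", batch_size=1, fold=0
+        )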
+ """ + use_ddp = world_size > 1 + if isinstance(json_data_list, list): + assert isinstance(data_base_dir, list) + list_train = [] + list_valid = [] + for data_list, data_root in zip(json_data_list, data_base_dir): + with open(data_list, "r") as f: + json_data = json.load(f)["training"] + train, val = add_data_dir2path(json_data, data_root, fold) + list_train += train + list_valid += val + else: + with open(json_data_list, "r") as f: + json_data = json.load(f)["training"] + list_train, list_valid = add_data_dir2path(json_data, data_base_dir, fold) + + common_transform = [ + LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True), + Orientationd(keys=["label"], axcodes="RAS"), + EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True), + Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x), allow_missing_keys=True), + Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x), allow_missing_keys=True), + Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)), + Lambdad( + keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2, allow_missing_keys=True + ), + ] + train_transforms, val_transforms = Compose(common_transform), Compose(common_transform) + + train_loader = None + + if use_ddp: + list_train = partition_dataset( + data=list_train, + shuffle=True, + num_partitions=world_size, + even_divisible=True, + )[rank] + train_ds = CacheDataset(data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8) + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True) + if use_ddp: + list_valid = partition_dataset( + data=list_valid, + shuffle=True, + num_partitions=world_size, + even_divisible=False, + )[rank] + val_ds = CacheDataset( + data=list_valid, + transform=val_transforms, + cache_rate=cache_rate, + num_workers=8, + ) + val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False) + return train_loader, val_loader + + +def prepare_maisi_controlnet_infer_dataloader( + json_data_list: list | str, + data_base_dir: list | str, + batch_size: int = 1, + cache_rate: float = 0.0, + rank: int = 0, + world_size: int = 1, +) -> DataLoader: + """ + Prepare dataloader for inference (test set). + + Args: + json_data_list (list | str): the name of JSON files listing the data. + data_base_dir (list | str): directory of files. + batch_size (int, optional): how many samples per batch to load. Defaults to 1. + cache_rate (float, optional): percentage of cached data in total. Defaults to 0.0. + rank (int, optional): rank of the current process. Defaults to 0. + world_size (int, optional): number of processes participating in the job. Defaults to 1. + + Returns: + DataLoader: Dataloader for inference. 
+ """ + use_ddp = world_size > 1 + if isinstance(json_data_list, list): + assert isinstance(data_base_dir, list) + list_infer = [] + for data_list, data_root in zip(json_data_list, data_base_dir): + with open(data_list, "r") as f: + # Use "testing" key for test set + json_data = json.load(f)["testing"] + infer, _ = add_data_dir2path(json_data, data_root, fold=None) + list_infer += infer + else: + with open(json_data_list, "r") as f: + json_data = json.load(f)["testing"] + list_infer, _ = add_data_dir2path(json_data, data_base_dir, fold=None) + + common_transform = [ + LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True), + Orientationd(keys=["label"], axcodes="RAS"), + EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True), + Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x), allow_missing_keys=True), + Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x), allow_missing_keys=True), + Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)), + Lambdad( + keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2, allow_missing_keys=True + ), + ] + infer_transforms = Compose(common_transform) + + if use_ddp: + list_infer = partition_dataset( + data=list_infer, + shuffle=False, + num_partitions=world_size, + even_divisible=False, + )[rank] + infer_ds = CacheDataset(data=list_infer, transform=infer_transforms, cache_rate=cache_rate, num_workers=8) + infer_loader = DataLoader(infer_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False) + return infer_loader + + + + +def add_data_dir2path_test(list_files: list, data_dir: str, fold: int = None) -> tuple[list, list]: + """ + Read a list of data dictionary. + + Args: + list_files (list): input data to load and transform to generate dataset for model. + data_dir (str): directory of files. + fold (int, optional): fold index for cross validation. Defaults to None. + + Returns: + tuple[list, list]: A tuple of two arrays (training, validation). + """ + new_list_files = copy.deepcopy(list_files) + if fold is not None: + new_list_files_train = [] + new_list_files_val = [] + for d in new_list_files: + d["label"] = os.path.join(data_dir, d["label"]) + + if fold is not None: + if d["fold"] == fold: + new_list_files_val.append(copy.deepcopy(d)) + else: + new_list_files_train.append(copy.deepcopy(d)) + + if fold is not None: + return new_list_files_train, new_list_files_val + else: + return new_list_files, [] + + + +def prepare_maisi_controlnet_test_dataloader( + json_data_list: list | str, + data_base_dir: list | str, + batch_size: int = 1, + cache_rate: float = 0.0, + rank: int = 0, + world_size: int = 1, +) -> DataLoader: + """ + Prepare dataloader for inference (test set). + + Args: + json_data_list (list | str): the name of JSON files listing the data. + data_base_dir (list | str): directory of files. + batch_size (int, optional): how many samples per batch to load. Defaults to 1. + cache_rate (float, optional): percentage of cached data in total. Defaults to 0.0. + rank (int, optional): rank of the current process. Defaults to 0. + world_size (int, optional): number of processes participating in the job. Defaults to 1. + + Returns: + DataLoader: Dataloader for inference. 
+ """ + use_ddp = world_size > 1 + if isinstance(json_data_list, list): + assert isinstance(data_base_dir, list) + list_infer = [] + for data_list, data_root in zip(json_data_list, data_base_dir): + with open(data_list, "r") as f: + # Use "testing" key for test set + json_data = json.load(f)["testing"] + infer, _ = add_data_dir2path_test(json_data, data_root, fold=None) + list_infer += infer + else: + with open(json_data_list, "r") as f: + json_data = json.load(f)["testing"] + list_infer, _ = add_data_dir2path_test(json_data, data_base_dir, fold=None) + + common_transform = [ + LoadImaged(keys=["label"], image_only=True, ensure_channel_first=True), + Orientationd(keys=["label"], axcodes="RAS"), + EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True), + Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x), allow_missing_keys=True), + Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x), allow_missing_keys=True), + Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)), + Lambdad( + keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2, allow_missing_keys=True + ), + ] + infer_transforms = Compose(common_transform) + + if use_ddp: + list_infer = partition_dataset( + data=list_infer, + shuffle=False, + num_partitions=world_size, + even_divisible=False, + )[rank] + infer_ds = CacheDataset(data=list_infer, transform=infer_transforms, cache_rate=cache_rate, num_workers=8) + infer_loader = DataLoader(infer_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False) + return infer_loader + +def organ_fill_by_closing(data, target_label, device, close_times=2, filter_size=3, pad_value=0.0): + """ + Fill holes in an organ mask using morphological closing operations. + + This function performs a series of dilation and erosion operations to fill holes + in the organ mask identified by the target label. + + Args: + data (ndarray): The input data containing organ labels. + target_label (int): The label of the organ to be processed. + device (str): The device to perform the operations on (e.g., 'cuda:0'). + close_times (int, optional): Number of times to perform the closing operation. Defaults to 2. + filter_size (int, optional): Size of the filter for dilation and erosion. Defaults to 3. + pad_value (float, optional): Value used for padding in dilation and erosion. Defaults to 0.0. + + Returns: + ndarray: Boolean mask of the filled organ. + """ + mask = (data == target_label).astype(np.uint8) + mask = torch.from_numpy(mask).to(device) + for _ in range(close_times): + mask = dilate_one_img(mask, filter_size=filter_size, pad_value=pad_value) + mask = erode_one_img(mask, filter_size=filter_size, pad_value=pad_value) + return mask.cpu().numpy().astype(np.bool_) + + +def organ_fill_by_removed_mask(data, target_label, remove_mask, device): + """ + Fill an organ mask in regions where it was previously removed. + + Args: + data (ndarray): The input data containing organ labels. + target_label (int): The label of the organ to be processed. + remove_mask (ndarray): Boolean mask indicating regions where the organ was removed. + device (str): The device to perform the operations on (e.g., 'cuda:0'). + + Returns: + ndarray: Boolean mask of the filled organ in previously removed regions. 
+ """ + mask = (data == target_label).astype(np.uint8) + mask = dilate_one_img(torch.from_numpy(mask).to(device), filter_size=3, pad_value=0.0) + mask = dilate_one_img(mask, filter_size=3, pad_value=0.0) + roi_oragn_mask = dilate_one_img(mask, filter_size=3, pad_value=0.0).cpu().numpy() + return (roi_oragn_mask * remove_mask).astype(np.bool_) + + +def get_body_region_index_from_mask(input_mask): + """ + Determine the top and bottom body region indices from an input mask. + + Args: + input_mask (Tensor): Input mask tensor containing body region labels. + + Returns: + tuple: Two lists representing the top and bottom region indices. + """ + region_indices = {} + # head and neck + region_indices["region_0"] = [22, 120] + # thorax + region_indices["region_1"] = [28, 29, 30, 31, 32] + # abdomen + region_indices["region_2"] = [1, 2, 3, 4, 5, 14] + # pelvis and lower + region_indices["region_3"] = [93, 94] + + nda = input_mask.cpu().numpy().squeeze() + unique_elements = np.lib.arraysetops.unique(nda) + unique_elements = list(unique_elements) + # print(f"nda: {nda.shape} {unique_elements}.") + overlap_array = np.zeros(len(region_indices), dtype=np.uint8) + for _j in range(len(region_indices)): + overlap = any(element in region_indices[f"region_{_j}"] for element in unique_elements) + overlap_array[_j] = np.uint8(overlap) + overlap_array_indices = np.nonzero(overlap_array)[0] + top_region_index = np.eye(len(region_indices), dtype=np.uint8)[np.amin(overlap_array_indices), ...] + top_region_index = list(top_region_index) + top_region_index = [int(_k) for _k in top_region_index] + bottom_region_index = np.eye(len(region_indices), dtype=np.uint8)[np.amax(overlap_array_indices), ...] + bottom_region_index = list(bottom_region_index) + bottom_region_index = [int(_k) for _k in bottom_region_index] + # print(f"{top_region_index} {bottom_region_index}") + return top_region_index, bottom_region_index + + +def general_mask_generation_post_process(volume_t, target_tumor_label=None, device="cuda:0"): + """ + Perform post-processing on a generated mask volume. + + This function applies various refinement steps to improve the quality of the generated mask, + including body mask refinement, tumor prediction refinement, and organ-specific processing. + + Args: + volume_t (ndarray): Input volume containing organ and tumor labels. + target_tumor_label (int, optional): Label of the target tumor. Defaults to None. + device (str, optional): Device to perform operations on. Defaults to "cuda:0". + + Returns: + ndarray: Post-processed volume with refined organ and tumor labels. 
+ """ + # assume volume_t is np array with shape (H,W,D) + hepatic_vessel = volume_t == 25 + airway = volume_t == 132 + + # ------------ refine body mask pred + body_region_mask = ( + erode_one_img(torch.from_numpy((volume_t > 0)).to(device), filter_size=3, pad_value=0.0).cpu().numpy() + ) + body_region_mask, _ = supress_non_largest_components(body_region_mask, [1]) + body_region_mask = ( + dilate_one_img(torch.from_numpy(body_region_mask).to(device), filter_size=3, pad_value=0.0) + .cpu() + .numpy() + .astype(np.uint8) + ) + volume_t = volume_t * body_region_mask + + # ------------ refine tumor pred + tumor_organ_dict = {23: 28, 24: 4, 26: 1, 27: 62, 128: 200} + for t in [23, 24, 26, 27, 128]: + if t != target_tumor_label: + volume_t[volume_t == t] = tumor_organ_dict[t] + else: + volume_t[organ_fill_by_closing(volume_t, target_label=t, device=device)] = t + volume_t[organ_fill_by_closing(volume_t, target_label=t, device=device)] = t + # we only keep the largest connected componet for tumors except hepatic tumor and bone lesion + if target_tumor_label != 26 and target_tumor_label != 128: + volume_t, _ = supress_non_largest_components(volume_t, [target_tumor_label], default_val=200) + target_tumor = volume_t == target_tumor_label + + # ------------ remove undesired organ pred + # general post-process non-largest components suppression + # process 4 ROI organs + spleen + 2 kidney + 5 lung lobes + duodenum + inferior vena cava + oran_list = [1, 4, 10, 12, 3, 28, 29, 30, 31, 32, 5, 14, 13, 6, 7, 8, 9, 10] + if target_tumor_label != 128: + oran_list += list(range(33, 60)) # + list(range(63,87)) + data, _ = supress_non_largest_components(volume_t, oran_list, default_val=200) # 200 is body region + organ_remove_mask = (volume_t - data).astype(np.bool_) + # process intestinal system (stomach 12, duodenum 13, small bowel 19, colon 62) + intestinal_mask_ = ( + (data == 12).astype(np.uint8) + + (data == 13).astype(np.uint8) + + (data == 19).astype(np.uint8) + + (data == 62).astype(np.uint8) + ) + intestinal_mask, _ = supress_non_largest_components(intestinal_mask_, [1], default_val=0) + # process small bowel 19 + small_bowel_remove_mask = (data == 19).astype(np.uint8) - (data == 19).astype(np.uint8) * intestinal_mask + # process colon 62 + colon_remove_mask = (data == 62).astype(np.uint8) - (data == 62).astype(np.uint8) * intestinal_mask + intestinal_remove_mask = (small_bowel_remove_mask + colon_remove_mask).astype(np.bool_) + data[intestinal_remove_mask] = 200 + + # ------------ full correponding organ in removed regions + for organ_label in oran_list: + data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label + + if target_tumor_label == 23 and np.sum(target_tumor) > 0: + # speical process for cases with lung tumor + dia_lung_tumor_mask = ( + dilate_one_img(torch.from_numpy((data == 23)).to(device), filter_size=3, pad_value=0.0).cpu().numpy() + ) + tmp = ( + (data * (dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8))).astype(np.float32).flatten() + ) + tmp[tmp == 0] = float("nan") + mode = int(stats.mode(tmp.flatten(), nan_policy="omit")[0]) + if mode in [28, 29, 30, 31, 32]: + dia_lung_tumor_mask = ( + dilate_one_img(torch.from_numpy(dia_lung_tumor_mask).to(device), filter_size=3, pad_value=0.0) + .cpu() + .numpy() + ) + lung_remove_mask = dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8).astype(np.uint8) + data[organ_fill_by_removed_mask(data, target_label=mode, remove_mask=lung_remove_mask, device=device)] = ( + mode + ) + 
+        dia_lung_tumor_mask = (
+            dilate_one_img(torch.from_numpy(dia_lung_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
+        )
+        data[
+            organ_fill_by_removed_mask(
+                data, target_label=23, remove_mask=dia_lung_tumor_mask * organ_remove_mask, device=device
+            )
+        ] = 23
+        for organ_label in [28, 29, 30, 31, 32]:
+            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+
+    if target_tumor_label == 26 and np.sum(target_tumor) > 0:
+        # special process for cases with hepatic tumor
+        # process liver 1
+        data[organ_fill_by_removed_mask(data, target_label=1, remove_mask=intestinal_remove_mask, device=device)] = 1
+        data[organ_fill_by_removed_mask(data, target_label=1, remove_mask=intestinal_remove_mask, device=device)] = 1
+        # process spleen 3
+        data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3
+        data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3
+        dia_tumor_mask = (
+            dilate_one_img(torch.from_numpy((data == target_tumor_label)).to(device), filter_size=3, pad_value=0.0)
+            .cpu()
+            .numpy()
+        )
+        dia_tumor_mask = (
+            dilate_one_img(torch.from_numpy(dia_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
+        )
+        data[
+            organ_fill_by_removed_mask(
+                data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device
+            )
+        ] = target_tumor_label
+        # refine hepatic tumor
+        hepatic_tumor_vessel_liver_mask_ = (
+            (data == 26).astype(np.uint8) + (data == 25).astype(np.uint8) + (data == 1).astype(np.uint8)
+        )
+        hepatic_tumor_vessel_liver_mask_ = (hepatic_tumor_vessel_liver_mask_ > 0).astype(np.uint8)
+        hepatic_tumor_vessel_liver_mask, _ = supress_non_largest_components(
+            hepatic_tumor_vessel_liver_mask_, [1], default_val=0
+        )
+        removed_region = (hepatic_tumor_vessel_liver_mask_ - hepatic_tumor_vessel_liver_mask).astype(np.bool_)
+        data[removed_region] = 200
+        target_tumor = (target_tumor * hepatic_tumor_vessel_liver_mask).astype(np.bool_)
+        # refine liver
+        data[organ_fill_by_closing(data, target_label=1, device=device)] = 1
+        data[organ_fill_by_closing(data, target_label=1, device=device)] = 1
+        data[organ_fill_by_closing(data, target_label=1, device=device)] = 1
+
+    if target_tumor_label == 27 and np.sum(target_tumor) > 0:
+        # special process for cases with colon tumor
+        dia_tumor_mask = (
+            dilate_one_img(torch.from_numpy((data == target_tumor_label)).to(device), filter_size=3, pad_value=0.0)
+            .cpu()
+            .numpy()
+        )
+        dia_tumor_mask = (
+            dilate_one_img(torch.from_numpy(dia_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
+        )
+        data[
+            organ_fill_by_removed_mask(
+                data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device
+            )
+        ] = target_tumor_label
+
+    if target_tumor_label == 129 and np.sum(target_tumor) > 0:
+        # special process for cases with kidney tumor
+        for organ_label in [5, 14]:
+            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+            data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+    # TODO: current model does not support hepatic vessel by size control.
+ # we treat it as liver for better visiaulization + print( + "Current model does not support hepatic vessel by size control, " + "so we treat generated hepatic vessel as part of liver for better visiaulization." + ) + data[hepatic_vessel] = 1 + data[airway] = 132 + if target_tumor_label is not None: + data[target_tumor] = target_tumor_label + + return data + + +class MapLabelValue: + """ + Utility to map label values to another set of values. + For example, map [3, 2, 1] to [0, 1, 2], [1, 2, 3] -> [0.5, 1.5, 2.5], ["label3", "label2", "label1"] -> [0, 1, 2], + [3.5, 2.5, 1.5] -> ["label0", "label1", "label2"], etc. + The label data must be numpy array or array-like data and the output data will be numpy array. + + """ + + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + + def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None: + """ + Args: + orig_labels: original labels that map to others. + target_labels: expected label values, 1: 1 map to the `orig_labels`. + dtype: convert the output data to dtype, default to float32. + if dtype is from PyTorch, the transform will use the pytorch backend, else with numpy backend. + + """ + if len(orig_labels) != len(target_labels): + raise ValueError("orig_labels and target_labels must have the same length.") + + self.orig_labels = orig_labels + self.target_labels = target_labels + self.pair = tuple((o, t) for o, t in zip(self.orig_labels, self.target_labels) if o != t) + type_dtype = type(dtype) + if getattr(type_dtype, "__module__", "") == "torch": + self.use_numpy = False + self.dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor) + else: + self.use_numpy = True + self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray) + + def __call__(self, img: NdarrayOrTensor): + """ + Apply the label mapping to the input image. + + Args: + img (NdarrayOrTensor): Input image to be remapped. + + Returns: + NdarrayOrTensor: Remapped image. + """ + if self.use_numpy: + img_np, *_ = convert_data_type(img, np.ndarray) + _out_shape = img_np.shape + img_flat = img_np.flatten() + try: + out_flat = img_flat.astype(self.dtype) + except ValueError: + # can't copy unchanged labels as the expected dtype is not supported, must map all the label values + out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype) + for o, t in self.pair: + out_flat[img_flat == o] = t + out_t = out_flat.reshape(_out_shape) + else: + img_t, *_ = convert_data_type(img, torch.Tensor) + out_t = img_t.detach().clone().to(self.dtype) # type: ignore + for o, t in self.pair: + out_t[img_t == o] = t + out, *_ = convert_to_dst_type(src=out_t, dst=img, dtype=self.dtype) + return out + + +def KL_loss(z_mu, z_sigma): + """ + Compute the Kullback-Leibler (KL) divergence loss for a variational autoencoder (VAE). + + The KL divergence measures how one probability distribution diverges from a second, expected probability distribution. + In the context of VAEs, this loss term ensures that the learned latent space distribution is close to a standard normal distribution. + + Args: + z_mu (torch.Tensor): Mean of the latent variable distribution, shape [N,C,H,W,D] or [N,C,H,W]. + z_sigma (torch.Tensor): Standard deviation of the latent variable distribution, same shape as 'z_mu'. + + Returns: + torch.Tensor: The computed KL divergence loss, averaged over the batch. 
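+
+    Example (illustrative sketch):
+
+        z_mu = torch.zeros(2, 4, 8, 8, 8)
+        z_sigma = torch.ones(2, 4, 8, 8, 8)
+        loss = KL_loss(z_mu, z_sigma)  # close to 0 for a standard normal posterior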
+ """ + eps = 1e-10 + kl_loss = 0.5 * torch.sum( + z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2) + eps) - 1, + dim=list(range(1, len(z_sigma.shape))), + ) + return torch.sum(kl_loss) / kl_loss.shape[0] + + +def dynamic_infer(inferer, model, images): + """ + Perform dynamic inference using a model and an inferer, typically a monai SlidingWindowInferer. + + This function determines whether to use the model directly or to use the provided inferer + (such as a sliding window inferer) based on the size of the input images. + + Args: + inferer: An inference object, typically a monai SlidingWindowInferer, which handles patch-based inference. + model (torch.nn.Module): The model used for inference. + images (torch.Tensor): The input images for inference, shape [N,C,H,W,D] or [N,C,H,W]. + + Returns: + torch.Tensor: The output from the model or the inferer, depending on the input size. + """ + if torch.numel(images[0:1, 0:1, ...]) <= math.prod(inferer.roi_size): + return model(images) + else: + # Extract the spatial dimensions from the images tensor (H, W, D) + spatial_dims = images.shape[2:] + orig_roi = inferer.roi_size + + # Check that roi has the same number of dimensions as spatial_dims + if len(orig_roi) != len(spatial_dims): + raise ValueError(f"ROI length ({len(orig_roi)}) does not match spatial dimensions ({len(spatial_dims)}).") + + # Iterate and adjust each ROI dimension + adjusted_roi = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, spatial_dims)] + inferer.roi_size = adjusted_roi + output = inferer(network=model, inputs=images) + inferer.roi_size = orig_roi + return output diff --git a/scripts/utils_plot.py b/scripts/utils_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..7ecb79c60338f35d369a17541773024155ec138c --- /dev/null +++ b/scripts/utils_plot.py @@ -0,0 +1,195 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +from monai.transforms import AsDiscrete + + +def find_label_center_loc(x): + """ + Find the center location of non-zero elements in a binary mask. + + Args: + x (torch.Tensor): Binary mask tensor. Expected shape: [H, W, D] or [C, H, W, D]. + + Returns: + list: Center locations for each dimension. Each element is either + the middle index of non-zero locations or None if no non-zero elements exist. + """ + label_loc = torch.where(x != 0) + center_loc = [] + for loc in label_loc: + unique_loc = torch.unique(loc) + if len(unique_loc) == 0: + center_loc.append(None) + else: + center_loc.append(unique_loc[len(unique_loc) // 2]) + + return center_loc + + +def normalize_label_to_uint8(colorize, label, n_label): + """ + Normalize and colorize a label tensor to a uint8 image. + + Args: + colorize (torch.Tensor): Weight tensor for colorization. Expected shape: [3, n_label, 1, 1]. + label (torch.Tensor): Input label tensor. Expected shape: [1, H, W]. + n_label (int): Number of unique labels. 
+ + Returns: + numpy.ndarray: Normalized and colorized image as uint8 numpy array. Shape: [H, W, 3]. + """ + with torch.no_grad(): + post_label = AsDiscrete(to_onehot=n_label) + label = post_label(label).permute(1, 0, 2, 3) + label = F.conv2d(label, weight=colorize) + label = torch.clip(label, 0, 1).squeeze().permute(1, 2, 0).cpu().numpy() + + draw_img = (label * 255).astype(np.uint8) + + return draw_img + + +def visualize_one_slice_in_3d(image, axis: int = 2, center=None, mask_bool=True, n_label=105, colorize=None): + """ + Extract and visualize a 2D slice from a 3D image or label tensor. + + Args: + image (torch.Tensor): Input 3D image or label tensor. Expected shape: [1, H, W, D]. + axis (int, optional): Axis along which to extract the slice (0, 1, or 2). Defaults to 2. + center (int, optional): Index of the slice to extract. If None, the middle slice is used. + mask_bool (bool, optional): If True, treat the input as a label mask and normalize it. Defaults to True. + n_label (int, optional): Number of labels in the mask. Used only if mask_bool is True. Defaults to 105. + colorize (torch.Tensor, optional): Colorization weights for label normalization. + Expected shape: [3, n_label, 1, 1] if provided. + + Returns: + numpy.ndarray: 2D slice of the input. If mask_bool is True, returns a normalized uint8 array + with shape [3, H, W]. Otherwise, returns a float32 array with shape [3, H, W]. + + Raises: + ValueError: If the specified axis is not 0, 1, or 2. + """ + # draw image + if center is None: + center = image.shape[2:][axis] // 2 + if axis == 0: + draw_img = image[..., center, :, :] + elif axis == 1: + draw_img = image[..., :, center, :] + elif axis == 2: + draw_img = image[..., :, :, center] + else: + raise ValueError("axis should be in [0,1,2]") + if mask_bool: + draw_img = normalize_label_to_uint8(colorize, draw_img, n_label) + else: + draw_img = draw_img.squeeze().cpu().numpy().astype(np.float32) + draw_img = np.stack((draw_img,) * 3, axis=-1) + return draw_img + + +def show_image(image, title="mask"): + """ + Plot and display an input image. + + Args: + image (numpy.ndarray): Image to be displayed. Expected shape: [H, W] for grayscale or [H, W, 3] for RGB. + title (str, optional): Title for the plot. Defaults to "mask". + """ + plt.figure("check", (24, 12)) + plt.subplot(1, 2, 1) + plt.title(title) + plt.imshow(image) + plt.show() + + +def to_shape(a, shape): + """ + Pad an image to a desired shape. + + This function pads a 3D numpy array (image) with zeros to reach the specified shape. + The padding is added equally on both sides of each dimension, with any odd padding + added to the end. + + Args: + a (numpy.ndarray): Input 3D array to be padded. Expected shape: [X, Y, Z]. + shape (tuple): Desired output shape as (x_, y_, z_). + + Returns: + numpy.ndarray: Padded array with the desired shape [x_, y_, z_]. + + Note: + If the input shape is larger than the desired shape in any dimension, + no padding is removed; the original size is maintained for that dimension. + Padding is done using numpy's pad function with 'constant' mode (zero-padding). 
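+
+    Example (illustrative sketch):
+
+        a = np.zeros((3, 4, 5))
+        padded = to_shape(a, (8, 8, 8))  # padded.shape == (8, 8, 8)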
+ """ + x_, y_, z_ = shape + x, y, z = a.shape + x_pad = x_ - x + y_pad = y_ - y + z_pad = z_ - z + return np.pad( + a, + ( + (x_pad // 2, x_pad // 2 + x_pad % 2), + (y_pad // 2, y_pad // 2 + y_pad % 2), + (z_pad // 2, z_pad // 2 + z_pad % 2), + ), + mode="constant", + ) + + +def get_xyz_plot(image, center_loc_axis, mask_bool=True, n_label=105, colorize=None, target_class_index=0): + """ + Generate a concatenated XYZ plot of 2D slices from a 3D image. + + This function creates visualizations of three orthogonal slices (XY, XZ, YZ) from a 3D image + and concatenates them into a single 2D image. + + Args: + image (torch.Tensor): Input 3D image tensor. Expected shape: [1, H, W, D]. + center_loc_axis (list): List of three integers specifying the center locations for each axis. + mask_bool (bool, optional): Whether to apply masking. Defaults to True. + n_label (int, optional): Number of labels for visualization. Defaults to 105. + colorize (torch.Tensor, optional): Colorization weights. Expected shape: [3, n_label, 1, 1] if provided. + target_class_index (int, optional): Index of the target class. Defaults to 0. + + Returns: + numpy.ndarray: Concatenated 2D image of the three orthogonal slices. Shape: [max(H,W,D), 3*max(H,W,D), 3]. + + Note: + The output image is padded to ensure all slices have the same dimensions. + """ + target_shape = list(image.shape[1:]) # [1,H,W,D] + img_list = [] + + for axis in range(3): + center = center_loc_axis[axis] + + img = visualize_one_slice_in_3d( + torch.flip(image.unsqueeze(0), [-3, -2, -1]), + axis, + center=center, + mask_bool=mask_bool, + n_label=n_label, + colorize=colorize, + ) + img = img.transpose([2, 1, 0]) + + img = to_shape(img, (3, max(target_shape), max(target_shape))) + img_list.append(img) + img = np.concatenate(img_list, axis=2).transpose([1, 2, 0]) + return img