Venn's picture
Upload maisi_ct_generative version 1.0.0
02aa18d verified
{
  "imports": [
    "$import glob",
    "$import os",
    "$import scripts",
    "$import ignite"
  ],
  "bundle_root": ".",
  "ckpt_dir": "$@bundle_root + '/models'",
  "output_dir": "$@bundle_root + '/output'",
  "data_list_file_path": "$@bundle_root + '/datasets/C4KC-KiTS_subset.json'",
  "dataset_dir": "$@bundle_root + '/datasets/C4KC-KiTS_subset'",
  "trained_diffusion_path": "$@ckpt_dir + '/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt'",
  "trained_controlnet_path": "$@ckpt_dir + '/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt'",
  "use_tensorboard": true,
  "fold": 0,
  "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
  "epochs": 100,
  "batch_size": 1,
  "val_at_start": false,
  "learning_rate": 0.0001,
  "weighted_loss_label": [129],
  "weighted_loss": 100,
  "amp": true,
  "train_datalist": "$scripts.utils.maisi_datafold_read(json_list=@data_list_file_path, data_base_dir=@dataset_dir, fold=@fold)[0]",
  "spatial_dims": 3,
  "image_channels": 1,
  "latent_channels": 4,
  "diffusion_unet_def": {
    "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
    "spatial_dims": "@spatial_dims",
    "in_channels": "@latent_channels",
    "out_channels": "@latent_channels",
    "num_channels": [64, 128, 256, 512],
    "attention_levels": [false, false, true, true],
    "num_head_channels": [0, 0, 32, 32],
    "num_res_blocks": 2,
    "use_flash_attention": true,
    "include_top_region_index_input": true,
    "include_bottom_region_index_input": true,
    "include_spacing_input": true
  },
  "controlnet_def": {
    "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
    "spatial_dims": "@spatial_dims",
    "in_channels": "@latent_channels",
    "num_channels": [64, 128, 256, 512],
    "attention_levels": [false, false, true, true],
    "num_head_channels": [0, 0, 32, 32],
    "num_res_blocks": 2,
    "use_flash_attention": true,
    "conditioning_embedding_in_channels": 8,
    "conditioning_embedding_num_channels": [8, 32, 64]
  },
  "noise_scheduler": {
    "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
    "num_train_timesteps": 1000,
    "beta_start": 0.0015,
    "beta_end": 0.0195,
    "schedule": "scaled_linear_beta",
    "clip_sample": false
  },
  "unzip_dataset": "$scripts.utils.unzip_dataset(@dataset_dir)",
  "diffusion_unet": "$@diffusion_unet_def.to(@device)",
  "checkpoint_diffusion_unet": "$torch.load(@trained_diffusion_path, map_location=@device, weights_only=False)",
  "load_diffusion": "$@diffusion_unet.load_state_dict(@checkpoint_diffusion_unet['unet_state_dict'])",
  "controlnet": "$@controlnet_def.to(@device)",
  "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @diffusion_unet.state_dict())",
  "checkpoint_controlnet": "$torch.load(@trained_controlnet_path, map_location=@device, weights_only=False)",
  "load_controlnet": "$@controlnet.load_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)",
  "scale_factor": "$@checkpoint_diffusion_unet['scale_factor'].to(@device)",
  "loss": {
    "_target_": "torch.nn.L1Loss",
    "reduction": "none"
  },
  "optimizer": {
    "_target_": "torch.optim.AdamW",
    "params": "$@controlnet.parameters()",
    "lr": "@learning_rate",
    "weight_decay": 1e-05
  },
  "lr_schedule": {
    "activate": true,
    "lr_scheduler": {
      "_target_": "torch.optim.lr_scheduler.PolynomialLR",
      "optimizer": "@optimizer",
      "total_iters": "$@epochs * len(@train#dataloader)",
      "power": 2.0
    }
  },
  "train": {
    "deterministic_transforms": [
      {
        "_target_": "LoadImaged",
        "keys": ["image", "label"],
        "image_only": true,
        "ensure_channel_first": true
      },
      {
        "_target_": "Orientationd",
        "keys": ["label"],
        "axcodes": "RAS"
      },
      {
        "_target_": "EnsureTyped",
        "keys": ["label"],
        "dtype": "$torch.uint8",
        "track_meta": true
      },
      {
        "_target_": "Lambdad",
        "keys": "top_region_index",
        "func": "$lambda x: torch.FloatTensor(x)"
      },
      {
        "_target_": "Lambdad",
        "keys": "bottom_region_index",
        "func": "$lambda x: torch.FloatTensor(x)"
      },
      {
        "_target_": "Lambdad",
        "keys": "spacing",
        "func": "$lambda x: torch.FloatTensor(x)"
      },
      {
        "_target_": "Lambdad",
        "keys": "top_region_index",
        "func": "$lambda x: x * 1e2"
      },
      {
        "_target_": "Lambdad",
        "keys": "bottom_region_index",
        "func": "$lambda x: x * 1e2"
      },
      {
        "_target_": "Lambdad",
        "keys": "spacing",
        "func": "$lambda x: x * 1e2"
      }
    ],
    "inferer": {
      "_target_": "SimpleInferer"
    },
    "preprocessing": {
      "_target_": "Compose",
      "transforms": "$@train#deterministic_transforms"
    },
    "dataset": {
      "_target_": "Dataset",
      "data": "@train_datalist",
      "transform": "@train#preprocessing"
    },
    "dataloader": {
      "_target_": "DataLoader",
      "dataset": "@train#dataset",
      "batch_size": "@batch_size",
      "shuffle": true,
      "num_workers": 4,
      "pin_memory": true,
      "persistent_workers": true
    },
    "handlers": [
      {
        "_target_": "LrScheduleHandler",
        "_disabled_": "$not @lr_schedule#activate",
        "lr_scheduler": "@lr_schedule#lr_scheduler",
        "epoch_level": false,
        "print_lr": true
      },
      {
        "_target_": "CheckpointSaver",
        "save_dir": "@ckpt_dir",
        "save_dict": {
          "controlnet_state_dict": "@controlnet",
          "optimizer": "@optimizer"
        },
        "save_interval": 1,
        "n_saved": 5
      },
      {
        "_target_": "TensorBoardStatsHandler",
        "_disabled_": "$not @use_tensorboard",
        "log_dir": "@output_dir",
        "tag_name": "train_loss",
        "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
      },
      {
        "_target_": "StatsHandler",
        "tag_name": "train_loss",
        "name": "StatsHandler",
        "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
      }
    ],
    "trainer": {
      "_target_": "scripts.trainer.MAISIControlNetTrainer",
      "_requires_": [
        "@load_diffusion",
        "@copy_controlnet_state",
        "@load_controlnet",
        "@unzip_dataset"
      ],
      "max_epochs": "@epochs",
      "device": "@device",
      "train_data_loader": "@train#dataloader",
      "diffusion_unet": "@diffusion_unet",
      "controlnet": "@controlnet",
      "noise_scheduler": "@noise_scheduler",
      "loss_function": "@loss",
      "optimizer": "@optimizer",
      "inferer": "@train#inferer",
      "key_train_metric": null,
      "train_handlers": "@train#handlers",
      "amp": "@amp",
      "hyper_kwargs": {
        "weighted_loss": "@weighted_loss",
        "weighted_loss_label": "@weighted_loss_label",
        "scale_factor": "@scale_factor"
      }
    }
  },
  "initialize": [
    "$monai.utils.set_determinism(seed=0)"
  ],
  "run": [
    "$@train#trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, ignite.handlers.TerminateOnNan())",
    "$@train#trainer.run()"
  ]
}