{
    "ckpt_dir": "$@bundle_root + '/models'",
    "train_batch_size": 4,
    "lr": 1e-05,
    "train_patch_size": [
        144,
        176,
        112
    ],
    "latent_shape": [
        "@latent_channels",
        36,
        44,
        28
    ],
    "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
    "load_autoencoder": "$@autoencoder_def.load_old_state_dict(torch.load(@load_autoencoder_path))",
    "autoencoder": "$@autoencoder_def.to(@device)",
    "network_def": {
        "_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@latent_channels",
        "out_channels": "@latent_channels",
        "channels": [
            256,
            256,
            512
        ],
        "attention_levels": [
            false,
            true,
            true
        ],
        "num_head_channels": [
            0,
            64,
            64
        ],
        "num_res_blocks": 2,
        "include_fc": false,
        "use_combined_linear": false
    },
    "diffusion": "$@network_def.to(@device)",
    "optimizer": {
        "_target_": "torch.optim.Adam",
        "params": "$@diffusion.parameters()",
        "lr": "@lr"
    },
    "lr_scheduler": {
        "_target_": "torch.optim.lr_scheduler.MultiStepLR",
        "optimizer": "@optimizer",
        "milestones": [
            100,
            1000
        ],
        "gamma": 0.1
    },
    "scale_factor": "$scripts.utils.compute_scale_factor(@autoencoder,@train#dataloader,@device)",
    "noise_scheduler": {
        "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
        "_requires_": [
            "@load_autoencoder"
        ],
        "schedule": "scaled_linear_beta",
        "num_train_timesteps": 1000,
        "beta_start": 0.0015,
        "beta_end": 0.0195
    },
    "loss": {
        "_target_": "torch.nn.MSELoss"
    },
    "train": {
        "inferer": {
            "_target_": "monai.inferers.LatentDiffusionInferer",
            "scheduler": "@noise_scheduler",
            "scale_factor": "@scale_factor"
        },
        "crop_transforms": [
            {
                "_target_": "CenterSpatialCropd",
                "keys": "image",
                "roi_size": "@train_patch_size"
            }
        ],
        "preprocessing": {
            "_target_": "Compose",
            "transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms"
        },
        "dataset": {
            "_target_": "monai.apps.DecathlonDataset",
            "root_dir": "@dataset_dir",
            "task": "Task01_BrainTumour",
            "section": "training",
            "cache_rate": 1.0,
            "num_workers": 8,
            "download": false,
            "transform": "@train#preprocessing"
        },
        "dataloader": {
            "_target_": "DataLoader",
            "dataset": "@train#dataset",
            "batch_size": "@train_batch_size",
            "shuffle": true,
            "num_workers": 0
        },
        "handlers": [
            {
                "_target_": "LrScheduleHandler",
                "lr_scheduler": "@lr_scheduler",
                "print_lr": true
            },
            {
                "_target_": "CheckpointSaver",
                "save_dir": "@ckpt_dir",
                "save_dict": {
                    "model": "@diffusion"
                },
                "save_interval": 0,
                "save_final": true,
                "epoch_level": true,
                "final_filename": "model.pt"
            },
            {
                "_target_": "StatsHandler",
                "tag_name": "train_diffusion_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
            },
            {
                "_target_": "TensorBoardStatsHandler",
                "log_dir": "@tf_dir",
                "tag_name": "train_diffusion_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
            }
        ],
        "trainer": {
            "_target_": "scripts.ldm_trainer.LDMTrainer",
            "device": "@device",
            "max_epochs": 5000,
            "train_data_loader": "@train#dataloader",
            "network": "@diffusion",
            "autoencoder_model": "@autoencoder",
            "optimizer": "@optimizer",
            "loss_function": "@loss",
            "latent_shape": "@latent_shape",
            "inferer": "@train#inferer",
            "key_train_metric": "$None",
            "train_handlers": "@train#handlers"
        }
    },
    "initialize": [
        "$monai.utils.set_determinism(seed=0)"
    ],
    "run": [
        "@load_autoencoder",
        "$@autoencoder.eval()",
        "$print('scale factor:',@scale_factor)",
        "$@train#trainer.run()"
    ]
}