maisi_ct_generative / configs /inference.json
Venn's picture
Upload maisi_ct_generative version 1.0.0
02aa18d verified
{
"imports": [
"$import torch",
"$from pathlib import Path",
"$import scripts"
],
"bundle_root": ".",
"model_dir": "$@bundle_root + '/models'",
"output_dir": "$@bundle_root + '/output'",
"create_output_dir": "$Path(@output_dir).mkdir(parents=True, exist_ok=True)",
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"trained_autoencoder_path": "$@model_dir + '/autoencoder.pt'",
"trained_diffusion_path": "$@model_dir + '/diffusion_unet.pt'",
"trained_controlnet_path": "$@model_dir + '/controlnet.pt'",
"trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'",
"trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'",
"all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_3000'",
"all_mask_files_json": "$@bundle_root + '/configs/candidate_masks_flexible_size_and_spacing_3000.json'",
"all_anatomy_size_condtions_json": "$@bundle_root + '/configs/all_anatomy_size_condtions.json'",
"label_dict_json": "$@bundle_root + '/configs/label_dict.json'",
"label_dict_remap_json": "$@bundle_root + '/configs/label_dict_124_to_132.json'",
"real_img_median_statistics_file": "$@bundle_root + '/configs/image_median_statistics.json'",
"num_output_samples": 1,
"body_region": [],
"anatomy_list": [
"liver"
],
"modality": "ct",
"controllable_anatomy_size": [],
"num_inference_steps": 30,
"mask_generation_num_inference_steps": 1000,
"random_seed": null,
"spatial_dims": 3,
"image_channels": 1,
"latent_channels": 4,
"output_size_xy": 512,
"output_size_z": 512,
"output_size": [
"@output_size_xy",
"@output_size_xy",
"@output_size_z"
],
"image_output_ext": ".nii.gz",
"label_output_ext": ".nii.gz",
"spacing_xy": 1.0,
"spacing_z": 1.0,
"spacing": [
"@spacing_xy",
"@spacing_xy",
"@spacing_z"
],
"latent_shape": [
"@latent_channels",
"$@output_size[0]//4",
"$@output_size[1]//4",
"$@output_size[2]//4"
],
"mask_generation_latent_shape": [
4,
64,
64,
64
],
"autoencoder_sliding_window_infer_size": [
80,
80,
80
],
"autoencoder_sliding_window_infer_overlap": 0.4,
"autoencoder_def": {
"_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
"spatial_dims": "@spatial_dims",
"in_channels": "@image_channels",
"out_channels": "@image_channels",
"latent_channels": "@latent_channels",
"num_channels": [
64,
128,
256
],
"num_res_blocks": [
2,
2,
2
],
"norm_num_groups": 32,
"norm_eps": 1e-06,
"attention_levels": [
false,
false,
false
],
"with_encoder_nonlocal_attn": false,
"with_decoder_nonlocal_attn": false,
"use_checkpointing": false,
"use_convtranspose": false,
"norm_float16": true,
"num_splits": 2,
"dim_split": 1
},
"diffusion_unet_def": {
"_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"out_channels": "@latent_channels",
"num_channels": [
64,
128,
256,
512
],
"attention_levels": [
false,
false,
true,
true
],
"num_head_channels": [
0,
0,
32,
32
],
"num_res_blocks": 2,
"use_flash_attention": true,
"include_top_region_index_input": false,
"include_bottom_region_index_input": false,
"include_spacing_input": true,
"num_class_embeds": 128,
"resblock_updown": true,
"include_fc": true
},
"controlnet_def": {
"_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"num_channels": [
64,
128,
256,
512
],
"attention_levels": [
false,
false,
true,
true
],
"num_head_channels": [
0,
0,
32,
32
],
"num_res_blocks": 2,
"use_flash_attention": true,
"conditioning_embedding_in_channels": 8,
"conditioning_embedding_num_channels": [
8,
32,
64
],
"num_class_embeds": 128,
"resblock_updown": true,
"include_fc": true
},
"mask_generation_autoencoder_def": {
"_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
"spatial_dims": "@spatial_dims",
"in_channels": 8,
"out_channels": 125,
"latent_channels": "@latent_channels",
"num_channels": [
32,
64,
128
],
"num_res_blocks": [
1,
2,
2
],
"norm_num_groups": 32,
"norm_eps": 1e-06,
"attention_levels": [
false,
false,
false
],
"with_encoder_nonlocal_attn": false,
"with_decoder_nonlocal_attn": false,
"use_flash_attention": false,
"use_checkpointing": true,
"use_convtranspose": true,
"norm_float16": true,
"num_splits": 8,
"dim_split": 1
},
"mask_generation_diffusion_def": {
"_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"out_channels": "@latent_channels",
"channels": [
64,
128,
256,
512
],
"attention_levels": [
false,
false,
true,
true
],
"num_head_channels": [
0,
0,
32,
32
],
"num_res_blocks": 2,
"use_flash_attention": true,
"with_conditioning": true,
"upcast_attention": true,
"cross_attention_dim": 10
},
"autoencoder": "$@autoencoder_def.to(@device)",
"checkpoint_autoencoder": "$torch.load(@trained_autoencoder_path, map_location=@device, weights_only=True)",
"load_autoencoder": "$@autoencoder.load_state_dict(@checkpoint_autoencoder)",
"diffusion_unet": "$@diffusion_unet_def.to(@device)",
"checkpoint_diffusion_unet": "$torch.load(@trained_diffusion_path, map_location=@device, weights_only=False)",
"load_diffusion": "$@diffusion_unet.load_state_dict(@checkpoint_diffusion_unet['unet_state_dict'])",
"controlnet": "$@controlnet_def.to(@device)",
"copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @diffusion_unet.state_dict())",
"checkpoint_controlnet": "$torch.load(@trained_controlnet_path, map_location=@device, weights_only=False)",
"load_controlnet": "$@controlnet.load_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)",
"scale_factor": "$@checkpoint_diffusion_unet['scale_factor'].to(@device)",
"mask_generation_autoencoder": "$@mask_generation_autoencoder_def.to(@device)",
"checkpoint_mask_generation_autoencoder": "$torch.load(@trained_mask_generation_autoencoder_path, map_location=@device, weights_only=True)",
"load_mask_generation_autoencoder": "$@mask_generation_autoencoder.load_state_dict(@checkpoint_mask_generation_autoencoder, strict=True)",
"mask_generation_diffusion_unet": "$@mask_generation_diffusion_def.to(@device)",
"checkpoint_mask_generation_diffusion_unet": "$torch.load(@trained_mask_generation_diffusion_path, map_location=@device, weights_only=True)",
"load_mask_generation_diffusion": "$@mask_generation_diffusion_unet.load_state_dict(@checkpoint_mask_generation_diffusion_unet['unet_state_dict'], strict=True)",
"mask_generation_scale_factor": "$@checkpoint_mask_generation_diffusion_unet['scale_factor']",
"noise_scheduler": {
"_target_": "scripts.rectified_flow.RFlowScheduler",
"num_train_timesteps": 1000,
"use_discrete_timesteps": false,
"use_timestep_transform": true,
"sample_method": "uniform"
},
"mask_generation_noise_scheduler": {
"_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
"num_train_timesteps": 1000,
"beta_start": 0.0015,
"beta_end": 0.0195,
"schedule": "scaled_linear_beta",
"clip_sample": false
},
"check_input": "$scripts.sample.check_input(@body_region,@anatomy_list,@label_dict_json,@output_size,@spacing,@controllable_anatomy_size)",
"ldm_sampler": {
"_target_": "scripts.sample.LDMSampler",
"_requires_": [
"@create_output_dir",
"@load_diffusion",
"@load_autoencoder",
"@copy_controlnet_state",
"@load_controlnet",
"@load_mask_generation_autoencoder",
"@load_mask_generation_diffusion",
"@check_input"
],
"body_region": "@body_region",
"anatomy_list": "@anatomy_list",
"modality": "@modality",
"all_mask_files_json": "@all_mask_files_json",
"all_anatomy_size_condtions_json": "@all_anatomy_size_condtions_json",
"all_mask_files_base_dir": "@all_mask_files_base_dir",
"label_dict_json": "@label_dict_json",
"label_dict_remap_json": "@label_dict_remap_json",
"autoencoder": "@autoencoder",
"diffusion_unet": "@diffusion_unet",
"controlnet": "@controlnet",
"scale_factor": "@scale_factor",
"noise_scheduler": "@noise_scheduler",
"mask_generation_autoencoder": "@mask_generation_autoencoder",
"mask_generation_diffusion_unet": "@mask_generation_diffusion_unet",
"mask_generation_scale_factor": "@mask_generation_scale_factor",
"mask_generation_noise_scheduler": "@mask_generation_noise_scheduler",
"controllable_anatomy_size": "@controllable_anatomy_size",
"image_output_ext": "@image_output_ext",
"label_output_ext": "@label_output_ext",
"real_img_median_statistics": "@real_img_median_statistics_file",
"device": "@device",
"latent_shape": "@latent_shape",
"mask_generation_latent_shape": "@mask_generation_latent_shape",
"output_size": "@output_size",
"spacing": "@spacing",
"output_dir": "@output_dir",
"num_inference_steps": "@num_inference_steps",
"mask_generation_num_inference_steps": "@mask_generation_num_inference_steps",
"random_seed": "@random_seed",
"autoencoder_sliding_window_infer_size": "@autoencoder_sliding_window_infer_size",
"autoencoder_sliding_window_infer_overlap": "@autoencoder_sliding_window_infer_overlap"
},
"run": [
"$monai.utils.set_determinism(seed=@random_seed)",
"$@ldm_sampler.sample_multiple_images(@num_output_samples)"
],
"evaluator": null
}