{
    "imports": [
        "$import torch",
        "$from pathlib import Path",
        "$import scripts"
    ],
    "bundle_root": ".",
    "model_dir": "$@bundle_root + '/models'",
    "output_dir": "$@bundle_root + '/output'",
    "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)",
    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
    "trained_autoencoder_path": "$@model_dir + '/autoencoder.pt'",
    "trained_diffusion_path": "$@model_dir + '/diffusion_unet.pt'",
    "trained_controlnet_path": "$@model_dir + '/controlnet.pt'",
    "trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'",
    "trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'",
    "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_3000'",
    "all_mask_files_json": "$@bundle_root + '/configs/candidate_masks_flexible_size_and_spacing_3000.json'",
    "all_anatomy_size_condtions_json": "$@bundle_root + '/configs/all_anatomy_size_condtions.json'",
    "label_dict_json": "$@bundle_root + '/configs/label_dict.json'",
    "label_dict_remap_json": "$@bundle_root + '/configs/label_dict_124_to_132.json'",
    "real_img_median_statistics_file": "$@bundle_root + '/configs/image_median_statistics.json'",
    "num_output_samples": 1,
    "body_region": [],
    "anatomy_list": [
        "liver"
    ],
    "modality": "ct",
    "controllable_anatomy_size": [],
    "num_inference_steps": 30,
    "mask_generation_num_inference_steps": 1000,
    "random_seed": null,
    "spatial_dims": 3,
    "image_channels": 1,
    "latent_channels": 4,
    "output_size_xy": 512,
    "output_size_z": 512,
    "output_size": [
        "@output_size_xy",
        "@output_size_xy",
        "@output_size_z"
    ],
    "image_output_ext": ".nii.gz",
    "label_output_ext": ".nii.gz",
    "spacing_xy": 1.0,
    "spacing_z": 1.0,
    "spacing": [
        "@spacing_xy",
        "@spacing_xy",
        "@spacing_z"
    ],
    "latent_shape": [
        "@latent_channels",
        "$@output_size[0]//4",
        "$@output_size[1]//4",
        "$@output_size[2]//4"
    ],
    "mask_generation_latent_shape": [
        4,
        64,
        64,
        64
    ],
    "autoencoder_sliding_window_infer_size": [
        80,
        80,
        80
    ],
    "autoencoder_sliding_window_infer_overlap": 0.4,
    "autoencoder_def": {
        "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@image_channels",
        "out_channels": "@image_channels",
        "latent_channels": "@latent_channels",
        "num_channels": [
            64,
            128,
            256
        ],
        "num_res_blocks": [
            2,
            2,
            2
        ],
        "norm_num_groups": 32,
        "norm_eps": 1e-06,
        "attention_levels": [
            false,
            false,
            false
        ],
        "with_encoder_nonlocal_attn": false,
        "with_decoder_nonlocal_attn": false,
        "use_checkpointing": false,
        "use_convtranspose": false,
        "norm_float16": true,
        "num_splits": 2,
        "dim_split": 1
    },
    "diffusion_unet_def": {
        "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@latent_channels",
        "out_channels": "@latent_channels",
        "num_channels": [
            64,
            128,
            256,
            512
        ],
        "attention_levels": [
            false,
            false,
            true,
            true
        ],
        "num_head_channels": [
            0,
            0,
            32,
            32
        ],
        "num_res_blocks": 2,
        "use_flash_attention": true,
        "include_top_region_index_input": false,
        "include_bottom_region_index_input": false,
        "include_spacing_input": true,
        "num_class_embeds": 128,
        "resblock_updown": true,
        "include_fc": true
    },
    "controlnet_def": {
        "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@latent_channels",
        "num_channels": [
            64,
            128,
            256,
            512
        ],
        "attention_levels": [
            false,
            false,
            true,
            true
        ],
        "num_head_channels": [
            0,
            0,
            32,
            32
        ],
        "num_res_blocks": 2,
        "use_flash_attention": true,
        "conditioning_embedding_in_channels": 8,
        "conditioning_embedding_num_channels": [
            8,
            32,
            64
        ],
        "num_class_embeds": 128,
        "resblock_updown": true,
        "include_fc": true
    },
    "mask_generation_autoencoder_def": {
        "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
        "spatial_dims": "@spatial_dims",
        "in_channels": 8,
        "out_channels": 125,
        "latent_channels": "@latent_channels",
        "num_channels": [
            32,
            64,
            128
        ],
        "num_res_blocks": [
            1,
            2,
            2
        ],
        "norm_num_groups": 32,
        "norm_eps": 1e-06,
        "attention_levels": [
            false,
            false,
            false
        ],
        "with_encoder_nonlocal_attn": false,
        "with_decoder_nonlocal_attn": false,
        "use_flash_attention": false,
        "use_checkpointing": true,
        "use_convtranspose": true,
        "norm_float16": true,
        "num_splits": 8,
        "dim_split": 1
    },
    "mask_generation_diffusion_def": {
        "_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@latent_channels",
        "out_channels": "@latent_channels",
        "channels": [
            64,
            128,
            256,
            512
        ],
        "attention_levels": [
            false,
            false,
            true,
            true
        ],
        "num_head_channels": [
            0,
            0,
            32,
            32
        ],
        "num_res_blocks": 2,
        "use_flash_attention": true,
        "with_conditioning": true,
        "upcast_attention": true,
        "cross_attention_dim": 10
    },
    "autoencoder": "$@autoencoder_def.to(@device)",
    "checkpoint_autoencoder": "$torch.load(@trained_autoencoder_path, weights_only=True)",
    "load_autoencoder": "$@autoencoder.load_state_dict(@checkpoint_autoencoder)",
    "diffusion_unet": "$@diffusion_unet_def.to(@device)",
    "checkpoint_diffusion_unet": "$torch.load(@trained_diffusion_path, weights_only=False)",
    "load_diffusion": "$@diffusion_unet.load_state_dict(@checkpoint_diffusion_unet['unet_state_dict'])",
    "controlnet": "$@controlnet_def.to(@device)",
    "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @diffusion_unet.state_dict())",
    "checkpoint_controlnet": "$torch.load(@trained_controlnet_path, weights_only=False)",
    "load_controlnet": "$@controlnet.load_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)",
    "scale_factor": "$@checkpoint_diffusion_unet['scale_factor'].to(@device)",
    "mask_generation_autoencoder": "$@mask_generation_autoencoder_def.to(@device)",
    "checkpoint_mask_generation_autoencoder": "$torch.load(@trained_mask_generation_autoencoder_path, weights_only=True)",
    "load_mask_generation_autoencoder": "$@mask_generation_autoencoder.load_state_dict(@checkpoint_mask_generation_autoencoder, strict=True)",
    "mask_generation_diffusion_unet": "$@mask_generation_diffusion_def.to(@device)",
    "checkpoint_mask_generation_diffusion_unet": "$torch.load(@trained_mask_generation_diffusion_path, weights_only=True)",
    "load_mask_generation_diffusion": "$@mask_generation_diffusion_unet.load_state_dict(@checkpoint_mask_generation_diffusion_unet['unet_state_dict'], strict=True)",
    "mask_generation_scale_factor": "$@checkpoint_mask_generation_diffusion_unet['scale_factor']",
    "noise_scheduler": {
        "_target_": "scripts.rectified_flow.RFlowScheduler",
        "num_train_timesteps": 1000,
        "use_discrete_timesteps": false,
        "use_timestep_transform": true,
        "sample_method": "uniform"
    },
    "mask_generation_noise_scheduler": {
        "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
        "num_train_timesteps": 1000,
        "beta_start": 0.0015,
        "beta_end": 0.0195,
        "schedule": "scaled_linear_beta",
        "clip_sample": false
    },
    "check_input": "$scripts.sample.check_input(@body_region,@anatomy_list,@label_dict_json,@output_size,@spacing,@controllable_anatomy_size)",
    "ldm_sampler": {
        "_target_": "scripts.sample.LDMSampler",
        "_requires_": [
            "@create_output_dir",
            "@load_diffusion",
            "@load_autoencoder",
            "@copy_controlnet_state",
            "@load_controlnet",
            "@load_mask_generation_autoencoder",
            "@load_mask_generation_diffusion",
            "@check_input"
        ],
        "body_region": "@body_region",
        "anatomy_list": "@anatomy_list",
        "modality": "@modality",
        "all_mask_files_json": "@all_mask_files_json",
        "all_anatomy_size_condtions_json": "@all_anatomy_size_condtions_json",
        "all_mask_files_base_dir": "@all_mask_files_base_dir",
        "label_dict_json": "@label_dict_json",
        "label_dict_remap_json": "@label_dict_remap_json",
        "autoencoder": "@autoencoder",
        "diffusion_unet": "@diffusion_unet",
        "controlnet": "@controlnet",
        "scale_factor": "@scale_factor",
        "noise_scheduler": "@noise_scheduler",
        "mask_generation_autoencoder": "@mask_generation_autoencoder",
        "mask_generation_diffusion_unet": "@mask_generation_diffusion_unet",
        "mask_generation_scale_factor": "@mask_generation_scale_factor",
        "mask_generation_noise_scheduler": "@mask_generation_noise_scheduler",
        "controllable_anatomy_size": "@controllable_anatomy_size",
        "image_output_ext": "@image_output_ext",
        "label_output_ext": "@label_output_ext",
        "real_img_median_statistics": "@real_img_median_statistics_file",
        "device": "@device",
        "latent_shape": "@latent_shape",
        "mask_generation_latent_shape": "@mask_generation_latent_shape",
        "output_size": "@output_size",
        "spacing": "@spacing",
        "output_dir": "@output_dir",
        "num_inference_steps": "@num_inference_steps",
        "mask_generation_num_inference_steps": "@mask_generation_num_inference_steps",
        "random_seed": "@random_seed",
        "autoencoder_sliding_window_infer_size": "@autoencoder_sliding_window_infer_size",
        "autoencoder_sliding_window_infer_overlap": "@autoencoder_sliding_window_infer_overlap"
    },
    "run": [
        "$monai.utils.set_determinism(seed=@random_seed)",
        "$@ldm_sampler.sample_multiple_images(@num_output_samples)"
    ],
    "evaluator": null
}