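---
# Bundle-style configuration for sampling images from a diffusion model.
# Values beginning with `$` are Python expressions; `@name` refers to another key.
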
# Python modules made available to the `$` expressions used throughout
# this configuration.
imports:
  - $import os
  - $import datetime
  - $import torch
  - $import scripts
  - $import monai
  - $import torch.distributed as dist
  - $import operator

# Dictionary key names taken from monai.utils.CommonKeys; `@image` is the
# key the dictionary transforms in this file operate on.
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

# Distributed-run detection: rank falls back to 0 when torch.distributed
# has not been initialised.
is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'

# Selects the CUDA device whose index equals the process rank, or the CPU
# when CUDA is unavailable.
# NOTE(review): this uses the value of dist.get_rank() as the device index,
# which presumably assumes a single-node job - confirm before multi-node use.
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'

# Definition of the denoising network: a 2-D DiffusionModelUNet with one
# input and one output channel. `_target_` names the class to instantiate;
# the remaining keys are passed as constructor arguments.
network_def:
  _target_: monai.networks.nets.DiffusionModelUNet
  spatial_dims: 2
  in_channels: 1
  out_channels: 1
  channels: [64, 128, 128]
  # One flag per entry of `channels`.
  attention_levels: [false, true, true]
  num_res_blocks: 1
  num_head_channels: 128

# The instantiated network moved onto the selected device.
network: '$@network_def.to(@device)'

# Root directory of the bundle; the checkpoint path is built relative to it.
bundle_root: '.'
ckpt_path: "$@bundle_root + '/models/model.pt'"

use_amp: true

# Edge length of the square images; `image_size` is [1, dim, dim].
image_dim: 64
image_size: [1, '@image_dim', '@image_dim']

# Number of timesteps handed to the scheduler.
num_train_timesteps: 1000

# Transforms applied to the `@image` entry: load the image, ensure a
# channel-first layout, then map intensities from [0, 255] to [0, 1] with
# clipping.
base_transforms:
  - _target_: LoadImaged
    keys: '@image'
    image_only: true
  - _target_: EnsureChannelFirstd
    keys: '@image'
  - _target_: ScaleIntensityRanged
    keys: '@image'
    a_min: 0.0
    a_max: 255.0
    b_min: 0.0
    b_max: 1.0
    clip: true

# DDPM noise scheduler configured with `@num_train_timesteps` steps.
scheduler:
  _target_: monai.networks.schedulers.DDPMScheduler
  num_train_timesteps: '@num_train_timesteps'

# Inferer that drives the network through the scheduler; its `sample`
# method is what the `sample` key calls.
inferer:
  _target_: monai.inferers.DiffusionInferer
  scheduler: '@scheduler'

# Loader settings. Not referenced by any key visible in this file;
# presumably consumed by a data loader defined elsewhere - TODO confirm.
batch_size: 1
num_workers: 0

# Initial sample for the reverse diffusion process; a fresh tensor of shape
# (1, 1, @image_dim, @image_dim) is drawn whenever this expression is
# evaluated. It is drawn from a standard normal (torch.randn) rather than a
# uniform distribution (torch.rand), because DDPM sampling is defined as
# starting from Gaussian noise.
noise: $torch.randn(1,1,@image_dim,@image_dim)

# Output name, empty by default and expected to be overridden at run time.
# It is the torch.save destination in `testing` and the SaveImage
# `output_postfix` in `save_trans`.
out_file: ''

# Sampling function: takes an input noise tensor `x` and returns the result
# of running the inferer with the network and scheduler defined above.
sample: '$lambda x: @inferer.sample(input_noise=x, diffusion_model=@network, scheduler=@scheduler)'

# Loads the saved weights from `@ckpt_path` into the network.
# `map_location=@device` remaps tensors saved on another device, so a
# checkpoint written on a GPU can still be loaded on a CPU-only host.
# `weights_only=True` restricts unpickling to tensor data.
load_state: '$@network.load_state_dict(torch.load(@ckpt_path, map_location=@device, weights_only=True))'

# Pipeline that writes a sampled image to a JPEG file: rescale intensities
# to [0, 255], convert to a plain tensor without metadata tracking, then
# save as uint8 using `@out_file` as the filename postfix.
save_trans:
  _target_: Compose
  transforms:
    - _target_: ScaleIntensity
      minv: 0.0
      maxv: 255.0
    - _target_: ToTensor
      track_meta: false
    - _target_: SaveImage
      output_ext: 'jpg'
      resample: false
      output_dtype: '$torch.uint8'
      separate_folder: false
      output_postfix: '@out_file'

# Program: load the checkpoint, sample one image from `@noise` on the
# selected device, and write the resulting tensor to `@out_file` with
# torch.save. `@out_file` must be overridden with a real path, since its
# default is the empty string.
testing:
  - '@load_state'
  - '$torch.save(@sample(@noise.to(@device)), @out_file)'

# Program: load the checkpoint, sample from `@noise`, and pass the first
# element of the sampled batch through `@save_trans` to write a JPEG.
testing_jpg:
  - '@load_state'
  - '$@save_trans(@sample(@noise.to(@device))[0])'