# Source: mednist_ddpm/configs/inference.yaml
# Uploaded by project-monai: mednist_ddpm version 1.0.1 (commit 57decc6)
# Inference config: generates a random image with the trained diffusion model and saves it to a PyTorch file
imports:
- $import os
- $import datetime
- $import torch
- $import scripts
- $import monai
- $import torch.distributed as dist
- $import operator
# Common elements to all yaml files
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED
# Distributed context: rank falls back to 0 when torch.distributed is not initialized
is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'
# One CUDA device per rank when CUDA is available, otherwise the CPU
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'
# 2D denoising UNet with single-channel input and output; its arguments must be
# nested under network_def so that _target_ instantiates the class with them
network_def:
  _target_: monai.networks.nets.DiffusionModelUNet
  spatial_dims: 2
  in_channels: 1
  out_channels: 1
  channels: [64, 128, 128]
  attention_levels: [false, true, true]
  num_res_blocks: 1
  num_head_channels: 128
# The instantiated network, moved to the selected device
network: $@network_def.to(@device)
# Bundle location; the trained weights are read from <bundle_root>/models/model.pt
bundle_root: '.'
ckpt_path: "$@bundle_root + '/models/model.pt'"
use_amp: true
# Image geometry as (channels, height, width) and the diffusion chain length
image_dim: 64
image_size:
- 1
- '@image_dim'
- '@image_dim'
num_train_timesteps: 1000
# Load an image, make it channel-first, and map intensities from [0, 255] to [0, 1]
# (clipped). Not referenced by `testing`/`testing_jpg` in this file.
base_transforms:
- _target_: LoadImaged
  keys: '@image'
  image_only: true
- _target_: EnsureChannelFirstd
  keys: '@image'
- _target_: ScaleIntensityRanged
  keys: '@image'
  a_min: 0.0
  a_max: 255.0
  b_min: 0.0
  b_max: 1.0
  clip: true
# DDPM noise scheduler configured with num_train_timesteps steps
scheduler:
  _target_: monai.networks.schedulers.DDPMScheduler
  num_train_timesteps: '@num_train_timesteps'
# Inferer that drives sampling with the scheduler defined just above
inferer:
  _target_: monai.inferers.DiffusionInferer
  scheduler: '@scheduler'
# Inference-specific
batch_size: 1
num_workers: 0
# Starting point of reverse diffusion. DDPM sampling starts from standard Gaussian
# noise, hence torch.randn (torch.rand would draw from U[0, 1) instead).
# A fresh sample is drawn every time this program is run; its shape is
# (batch_size, *image_size).
noise: '$torch.randn(@batch_size, *@image_size)'
out_file: ""  # where to save the tensor to; empty by default, presumably overridden at run time
# using a lambda this defines a simple sampling function used below
sample: '$lambda x: @inferer.sample(input_noise=x, diffusion_model=@network, scheduler=@scheduler)'
# Load the saved model weights; map_location lets weights saved from a GPU load on a CPU-only host
load_state: '$@network.load_state_dict(torch.load(@ckpt_path, map_location=@device, weights_only=True))'
# Post-processing for the JPEG output: rescale intensities to [0, 255], convert to a
# tensor without metadata tracking (track_meta: false), and write an 8-bit .jpg without
# resampling. Note that out_file is used as the filename postfix here.
save_trans:
  _target_: Compose
  transforms:
  - _target_: ScaleIntensity
    minv: 0.0
    maxv: 255.0
  - _target_: ToTensor
    track_meta: false
  - _target_: SaveImage
    output_ext: "jpg"
    resample: false
    output_dtype: '$torch.uint8'
    separate_folder: false
    output_postfix: '@out_file'
# Default program: load the trained weights, run `sample` on `noise` (moved to the
# device), and write the resulting tensor to `out_file` with torch.save.
testing:
- "@load_state"
- "$torch.save(@sample(@noise.to(@device)), @out_file)"
# Alternative program: same sampling as `testing`, but the first item of the sampled
# batch is passed through `save_trans` and written out as a JPEG file.
testing_jpg:
- "@load_state"
- "$@save_trans(@sample(@noise.to(@device))[0])"