# Inference configuration for a MONAI bundle: loads the saved diffusion model
# weights, samples one image from random noise, and stores the result either as
# a PyTorch tensor file (`testing`) or as a JPEG image (`testing_jpg`).

imports:
  - $import os
  - $import datetime
  - $import torch
  - $import scripts
  - $import monai
  - $import torch.distributed as dist
  - $import operator

# ---------------------------------------------------------------------------
# Common elements to all yaml files
# ---------------------------------------------------------------------------

# Dictionary keys used by the dictionary-based transforms.
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

# Distributed-run helpers: rank is 0 when torch.distributed is not initialised.
is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'
# One CUDA device per rank when CUDA is available, otherwise the CPU.
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'

# Network definition; `network` is the same object moved onto `device`.
network_def:
  _target_: monai.networks.nets.DiffusionModelUNet
  spatial_dims: 2
  in_channels: 1
  out_channels: 1
  channels: [64, 128, 128]
  attention_levels: [false, true, true]
  num_res_blocks: 1
  num_head_channels: 128

network: $@network_def.to(@device)

bundle_root: .
ckpt_path: $@bundle_root + '/models/model.pt'

# NOTE(review): not referenced anywhere in this file; presumably consumed by
# other configs that share these common elements - confirm before removing.
use_amp: true

image_dim: 64
image_size: [1, '@image_dim', '@image_dim']
num_train_timesteps: 1000

# Image loading/normalisation pipeline: maps intensities from [0, 255] to
# [0, 1] with clipping. Not referenced by the programs defined in this file.
base_transforms:
  - _target_: LoadImaged
    keys: '@image'
    image_only: true
  - _target_: EnsureChannelFirstd
    keys: '@image'
  - _target_: ScaleIntensityRanged
    keys: '@image'
    a_min: 0.0
    a_max: 255.0
    b_min: 0.0
    b_max: 1.0
    clip: true

scheduler:
  _target_: monai.networks.schedulers.DDPMScheduler
  num_train_timesteps: '@num_train_timesteps'

inferer:
  _target_: monai.inferers.DiffusionInferer
  scheduler: '@scheduler'

# ---------------------------------------------------------------------------
# Inference-specific
# ---------------------------------------------------------------------------

batch_size: 1
num_workers: 0

# Starting noise, created anew every time this program is run. DDPM sampling
# starts from standard-normal noise, so this draws with `randn` rather than
# the uniform `rand`.
noise: '$torch.randn(1, 1, @image_dim, @image_dim)'

# Where to save the result. Must be overridden on the command line: `testing`
# passes it to torch.save as the destination, `testing_jpg` passes it to
# SaveImage as `output_postfix`.
out_file: ""

# Simple sampling function used below: takes the input noise tensor and runs
# the inferer's sampling with the network and scheduler defined above.
sample: '$lambda x: @inferer.sample(input_noise=x, diffusion_model=@network, scheduler=@scheduler)'

# Command to load the saved model weights. `map_location` remaps the stored
# tensors onto `device`, so a checkpoint saved from a CUDA device still loads
# when inference falls back to the CPU.
load_state: '$@network.load_state_dict(torch.load(@ckpt_path, map_location=@device, weights_only=True))'

# Post-processing for the jpg variant: rescale intensities to [0, 255], convert
# to a plain tensor without metadata tracking, then write a uint8 jpg.
save_trans:
  _target_: Compose
  transforms:
    - _target_: ScaleIntensity
      minv: 0.0
      maxv: 255.0
    - _target_: ToTensor
      track_meta: false
    - _target_: SaveImage
      output_ext: "jpg"
      resample: false
      output_dtype: '$torch.uint8'
      separate_folder: false
      output_postfix: '@out_file'

# Program to load the model weights, run `sample`, and store results to `out_file`.
testing:
  - '@load_state'
  - '$torch.save(@sample(@noise.to(@device)), @out_file)'

# Alternative version which saves to a jpg file; `[0]` selects the first item
# of the sampled batch before it is passed to `save_trans`.
testing_jpg:
  - '@load_state'
  - '$@save_trans(@sample(@noise.to(@device))[0])'