|
|
|
|
|
# Import statements that make these modules available to the `$` expressions |
# used throughout this config (e.g. `$os.path.exists`, `$dist.get_rank()`). |
# NOTE(review): `scripts` is presumably a bundle-local package - confirm it is |
# importable from the directory the bundle is run from. |
imports: |
|
|
- $import os |
|
|
- $import datetime |
|
|
- $import torch |
|
|
- $import scripts |
|
|
- $import monai |
|
|
- $import torch.distributed as dist |
|
|
- $import operator |
|
|
|
|
|
|
|
|
# NOTE(review): this list item has no value; it looks like a leftover or a |
# line whose content was lost - confirm against the original file. |
- |
|
|
# Dictionary key names taken from monai.utils.CommonKeys. Other entries refer |
# to them as @image, @label and @pred (transform `keys`, metric |
# `output_transform`) instead of repeating the literal strings. |
image: $monai.utils.CommonKeys.IMAGE |
|
|
label: $monai.utils.CommonKeys.LABEL |
|
|
pred: $monai.utils.CommonKeys.PRED |
|
|
|
|
|
# Distributed-run helpers. |
# is_dist:      whether torch.distributed has been initialized. |
# rank:         dist.get_rank() when distributed, otherwise 0. |
# is_not_rank0: true for every rank above 0; used as `_disabled_` on handlers |
#               that should only be active on rank 0. |
# device:       "cuda:<rank>" when CUDA is available, otherwise "cpu". |
# NOTE(review): the CUDA index is the value of get_rank(), not a per-node |
# local rank - confirm this is intended for multi-node runs. |
is_dist: '$dist.is_initialized()' |
|
|
rank: '$dist.get_rank() if @is_dist else 0' |
|
|
is_not_rank0: '$@rank > 0' |
|
|
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")' |
|
|
|
|
|
# Definition of the model: a monai.networks.nets.DiffusionModelUNet with |
# spatial_dims 2, one input channel and one output channel. `channels` and |
# `attention_levels` each list three entries (64/128/128, with the attention |
# flag set on the last two). This entry is only the definition; `network` is |
# the same object after `.to(@device)`. |
network_def: |
|
|
_target_: monai.networks.nets.DiffusionModelUNet |
|
|
spatial_dims: 2 |
|
|
in_channels: 1 |
|
|
out_channels: 1 |
|
|
channels: [64, 128, 128] |
|
|
attention_levels: [false, true, true] |
|
|
num_res_blocks: 1 |
|
|
num_head_channels: 128 |
|
|
|
|
|
# network:             @network_def moved to @device; this is what the |
#                      optimizer, trainer, evaluator and checkpoint handlers use. |
# bundle_root:         base directory of the bundle (current directory by default). |
# ckpt_path:           "<bundle_root>/models/model.pt"; the CheckpointLoader |
#                      handler is disabled when this path does not exist. |
# use_amp:             passed as `amp` to both the trainer and the evaluator. |
# image_dim:           edge length used to build image_size. |
# image_size:          [1, image_dim, image_dim]; not referenced by any other |
#                      entry visible in this file. |
# num_train_timesteps: shared by the scheduler and by prepare_batch. |
network: $@network_def.to(@device) |
|
|
bundle_root: . |
|
|
ckpt_path: $@bundle_root + '/models/model.pt' |
|
|
use_amp: true |
|
|
image_dim: 64 |
|
|
image_size: [1, '@image_dim', '@image_dim'] |
|
|
num_train_timesteps: 1000 |
|
|
|
|
|
# Transforms shared by training and validation (train_ds appends |
# train_transforms to this list, val_ds uses it unchanged). All three operate |
# on the @image key: |
#   1. LoadImaged with image_only: true |
#   2. EnsureChannelFirstd |
#   3. ScaleIntensityRanged mapping [0, 255] to [0, 1] with clip: true |
base_transforms: |
|
|
- _target_: LoadImaged |
|
|
keys: '@image' |
|
|
image_only: true |
|
|
- _target_: EnsureChannelFirstd |
|
|
keys: '@image' |
|
|
- _target_: ScaleIntensityRanged |
|
|
keys: '@image' |
|
|
a_min: 0.0 |
|
|
a_max: 255.0 |
|
|
b_min: 0.0 |
|
|
b_max: 1.0 |
|
|
clip: true |
|
|
|
|
|
# DDPM scheduler configured with @num_train_timesteps; handed to the inferer |
# as its `scheduler`. |
scheduler: |
|
|
_target_: monai.networks.schedulers.DDPMScheduler |
|
|
num_train_timesteps: '@num_train_timesteps' |
|
|
|
|
|
# DiffusionInferer built around @scheduler; the same inferer instance is |
# referenced by both the trainer and the evaluator. |
inferer: |
|
|
_target_: monai.inferers.DiffusionInferer |
|
|
scheduler: '@scheduler' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# output_dir:  "./results/output_<yymmdd>_<HHMMSS>", stamped with the time at |
#              which this expression is evaluated; used as the CheckpointSaver |
#              save_dir. |
# dataset_dir: root_dir passed to both MedNISTDataset entries. |
output_dir: $datetime.datetime.now().strftime('./results/output_%y%m%d_%H%M%S') |
|
|
dataset_dir: ./data |
|
|
|
|
|
# MedNISTDataset for the "training" section, rooted at @dataset_dir, with |
# download enabled, the progress display off and seed 0. Within this file only |
# its `.data` attribute is read (see train_datalist). |
train_data: |
|
|
_target_ : MedNISTDataset |
|
|
root_dir: '@dataset_dir' |
|
|
section: training |
|
|
download: true |
|
|
progress: false |
|
|
seed: 0 |
|
|
|
|
|
# MedNISTDataset for the "validation" section; same root, download, progress |
# and seed settings as train_data. Within this file only its `.data` attribute |
# is read (see val_datalist). |
val_data: |
|
|
_target_ : MedNISTDataset |
|
|
root_dir: '@dataset_dir' |
|
|
section: validation |
|
|
download: true |
|
|
progress: false |
|
|
seed: 0 |
|
|
|
|
|
# Data lists built from the `.data` of the MedNIST datasets: keep only items |
# whose "class_name" is "Hand", and keep only the "image" entry of each item. |
train_datalist: '$[{"image": item["image"]} for item in @train_data.data if item["class_name"] == "Hand"]' |
|
|
val_datalist: '$[{"image": item["image"]} for item in @val_data.data if item["class_name"] == "Hand"]' |
|
|
|
|
|
# Data-loader settings. |
# batch_size:         used by both train_loader and val_loader. |
# num_substeps:       passed to train_loader as `repeats`. |
# num_workers:        used by both loaders; persistent_workers is switched on |
#                     when this is greater than 0. |
# use_thread_workers: passed to train_loader only. |
batch_size: 8 |
|
|
num_substeps: 1 |
|
|
num_workers: 4 |
|
|
use_thread_workers: false |
|
|
|
|
|
# Training-schedule settings. |
# lr:            learning rate given to the Adam optimizer. |
# rand_prob:     `prob` of the RandAffined training augmentation. |
# num_epochs:    `max_epochs` of the trainer. |
# val_interval:  `interval` of the ValidationHandler (epoch level). |
# save_interval: `save_interval` of the CheckpointSaver (epoch level). |
lr: 0.000025 |
|
|
rand_prob: 0.5 |
|
|
num_epochs: 75 |
|
|
val_interval: 5 |
|
|
save_interval: 5 |
|
|
|
|
|
# Training-only augmentation, appended to base_transforms in train_ds: a single |
# RandAffined on @image, applied with probability @rand_prob. Ranges are given |
# once per spatial axis: rotation within +/- pi/36, translation within +/- 1, |
# scale offset within +/- 0.05; output padded with zeros. |
# NOTE(review): `np` is not among the `imports` of this file; presumably it is |
# supplied by the config parser's default globals - confirm. |
# NOTE(review): spatial_size is hard-coded to [64, 64]; it equals the current |
# @image_dim but does not reference it, so the two can drift apart. |
train_transforms: |
|
|
- _target_: RandAffined |
|
|
keys: '@image' |
|
|
rotate_range: |
|
|
- ['$-np.pi / 36', '$np.pi / 36'] |
|
|
- ['$-np.pi / 36', '$np.pi / 36'] |
|
|
translate_range: |
|
|
- [-1, 1] |
|
|
- [-1, 1] |
|
|
scale_range: |
|
|
- [-0.05, 0.05] |
|
|
- [-0.05, 0.05] |
|
|
spatial_size: [64, 64] |
|
|
padding_mode: "zeros" |
|
|
prob: '@rand_prob' |
|
|
|
|
|
# Training dataset: a Dataset over @train_datalist whose transform is a |
# Compose of base_transforms followed by train_transforms (the two lists are |
# concatenated with `+`). |
train_ds: |
|
|
_target_: Dataset |
|
|
data: $@train_datalist |
|
|
transform: |
|
|
_target_: Compose |
|
|
transforms: '$@base_transforms + @train_transforms' |
|
|
|
|
|
# Training loader: a ThreadDataLoader over @train_ds with shuffling enabled. |
# `repeats` comes from @num_substeps, and persistent_workers is enabled only |
# when @num_workers is greater than 0. |
train_loader: |
|
|
_target_: ThreadDataLoader |
|
|
dataset: '@train_ds' |
|
|
batch_size: '@batch_size' |
|
|
repeats: '@num_substeps' |
|
|
num_workers: '@num_workers' |
|
|
use_thread_workers: '@use_thread_workers' |
|
|
persistent_workers: '$@num_workers > 0' |
|
|
shuffle: true |
|
|
|
|
|
# Validation dataset: a Dataset over @val_datalist using only base_transforms |
# (no random augmentation). |
val_ds: |
|
|
_target_: Dataset |
|
|
data: $@val_datalist |
|
|
transform: |
|
|
_target_: Compose |
|
|
transforms: '@base_transforms' |
|
|
|
|
|
# Validation loader: a DataLoader over @val_ds without shuffling, using the |
# same batch size and worker count as training; persistent_workers is enabled |
# only when @num_workers is greater than 0. |
val_loader: |
|
|
_target_: DataLoader |
|
|
dataset: '@val_ds' |
|
|
batch_size: '@batch_size' |
|
|
num_workers: '@num_workers' |
|
|
persistent_workers: '$@num_workers > 0' |
|
|
shuffle: false |
|
|
|
|
|
# Loss passed to the trainer as `loss_function`: torch.nn.MSELoss with its |
# default arguments. |
lossfn: |
|
|
_target_: torch.nn.MSELoss |
|
|
|
|
|
# Adam optimizer over the parameters of @network (the instance already moved |
# to @device), with learning rate @lr. |
optimizer: |
|
|
_target_: torch.optim.Adam |
|
|
params: $@network.parameters() |
|
|
lr: '@lr' |
|
|
|
|
|
# DiffusionPrepareBatch configured with @num_train_timesteps (the same value |
# the scheduler receives); used as `prepare_batch` by both the trainer and |
# the evaluator. |
prepare_batch: |
|
|
_target_: monai.engines.DiffusionPrepareBatch |
|
|
num_train_timesteps: '@num_train_timesteps' |
|
|
|
|
|
# Handlers attached to the evaluator: a single StatsHandler whose |
# output_transform always returns None, disabled on ranks above 0. |
# NOTE(review): the logger name is "train_log" although this is the validation |
# handler list - confirm that sharing the training logger name is intended. |
val_handlers: |
|
|
- _target_: StatsHandler |
|
|
name: train_log |
|
|
output_transform: '$lambda x: None' |
|
|
_disabled_: '@is_not_rank0' |
|
|
|
|
|
# Validation engine: a SupervisedEvaluator running @network over @val_loader |
# on @device, with the shared inferer, prepare_batch and AMP setting. |
# The key metric "val_mean_abs_error" is a MeanAbsoluteError computed from the |
# (@pred, @label) entries of the engine output. |
# metric_cmp_fn is operator.lt, presumably so that a lower error counts as the |
# better result - confirm against the engine documentation. |
# val_handlers is passed through filter(bool, ...), which drops falsy entries |
# (presumably the ones switched off via `_disabled_` - confirm). |
evaluator: |
|
|
_target_: SupervisedEvaluator |
|
|
device: '@device' |
|
|
val_data_loader: '@val_loader' |
|
|
network: '@network' |
|
|
amp: '@use_amp' |
|
|
inferer: '@inferer' |
|
|
prepare_batch: '@prepare_batch' |
|
|
key_val_metric: |
|
|
val_mean_abs_error: |
|
|
_target_: MeanAbsoluteError |
|
|
output_transform: $monai.handlers.from_engine([@pred, @label]) |
|
|
metric_cmp_fn: '$operator.lt' |
|
|
val_handlers: '$list(filter(bool, @val_handlers))' |
|
|
|
|
|
# Handlers attached to the trainer, in order: |
#   1. CheckpointLoader - loads @ckpt_path into @network under the key |
#      "model"; disabled when @ckpt_path does not exist. |
#   2. ValidationHandler - runs @evaluator every @val_interval epochs. |
#   3. CheckpointSaver - writes @network (key "model") to @output_dir every |
#      @save_interval epochs and at the end of the run; disabled on ranks |
#      above 0. |
handlers: |
|
|
- _target_: CheckpointLoader |
|
|
_disabled_: $not os.path.exists(@ckpt_path) |
|
|
load_path: '@ckpt_path' |
|
|
load_dict: |
|
|
model: '@network' |
|
|
- _target_: ValidationHandler |
|
|
validator: '@evaluator' |
|
|
epoch_level: true |
|
|
interval: '@val_interval' |
|
|
- _target_: CheckpointSaver |
|
|
save_dir: '@output_dir' |
|
|
save_dict: |
|
|
model: '@network' |
|
|
save_interval: '@save_interval' |
|
|
save_final: true |
|
|
epoch_level: true |
|
|
_disabled_: '@is_not_rank0' |
|
|
|
|
|
# Training engine: a SupervisedTrainer running @network over @train_loader on |
# @device for @num_epochs epochs, with @lossfn, @optimizer, the shared inferer, |
# prepare_batch and AMP setting. |
# train_handlers is @handlers passed through filter(bool, ...), which drops |
# falsy entries (presumably the ones switched off via `_disabled_` - confirm). |
# NOTE(review): the key metric is named "train_acc" but is computed with |
# MeanSquaredError on (@pred, @label), so it is an error value, not an |
# accuracy; metric_cmp_fn is operator.lt - confirm the name is intended. |
trainer: |
|
|
_target_: SupervisedTrainer |
|
|
max_epochs: '@num_epochs' |
|
|
device: '@device' |
|
|
train_data_loader: '@train_loader' |
|
|
network: '@network' |
|
|
loss_function: '@lossfn' |
|
|
optimizer: '@optimizer' |
|
|
inferer: '@inferer' |
|
|
prepare_batch: '@prepare_batch' |
|
|
key_train_metric: |
|
|
train_acc: |
|
|
_target_: MeanSquaredError |
|
|
output_transform: $monai.handlers.from_engine([@pred, @label]) |
|
|
metric_cmp_fn: '$operator.lt' |
|
|
train_handlers: '$list(filter(bool, @handlers))' |
|
|
amp: '@use_amp' |
|
|
|
|
|
# Entry point of the training run: the expressions are listed in order - |
# first monai.utils.set_determinism is called with seed 0, then the trainer is |
# started. |
training: |
|
|
- '$monai.utils.set_determinism(0)' |
|
|
- '$@trainer.run()' |
|
|
|