---
# MONAI bundle configuration for VISTA cell segmentation.
# Syntax notes: `$expr` is an evaluated Python expression, `@key` references
# another config entry, and `#` separates nested keys inside a reference.
imports:
- $import os
# seed: 28022024 # uncomment for deterministic results (but slower)
seed: null # null -> non-deterministic run
bundle_root: "."
ckpt_path: $os.path.join(@bundle_root, "models") # location to save checkpoints
output_dir: $os.path.join(@bundle_root, "eval") # location to save events and logs
log_output_file: $os.path.join(@output_dir, "vista_cell.log")
mlflow_tracking_uri: null # enable mlflow logging, e.g. $@ckpt_path + '/mlruns/' or "http://127.0.0.1:8080" or a remote url
mlflow_log_system_metrics: true # log system metrics to mlflow (requires: pip install psutil pynvml)
mlflow_run_name: null # optional name of the current run
ckpt_save: true # save checkpoints periodically
amp: true # enable automatic mixed precision
amp_dtype: "float16" # float16 or bfloat16 (Ampere or newer)
channels_last: true
compile: false # compile the model for faster processing
start_epoch: 0
run_final_testing: true
use_weighted_sampler: false # only applicable when using several dataset jsons for data_list_files
pretrained_ckpt_name: null
pretrained_ckpt_path: null
# for commandline setting of a single dataset
datalist: datalists/cellpose_datalist.json
basedir: /cellpose_dataset
data_list_files:
- {datalist: "@datalist", basedir: "@basedir"}
fold: 0
learning_rate: 0.01 # try 1.0e-4 if using AdamW
quick: false # whether to use a small subset of data for quick testing
roi_size: [256, 256] # shared by train#dataset#preprocessing and the sliding-window inferer
# training-stage configuration
train:
  skip: false
  handlers: []
  trainer:
    num_warmup_epochs: 3 # consumed by lr_scheduler#warmup_steps
    max_epochs: 200 # consumed by lr_scheduler#t_total
    num_epochs_per_saving: 1
    num_epochs_per_validation: null
  num_workers: 4
  batch_size: 1
  dataset:
    preprocessing:
      roi_size: "@roi_size"
    data:
      key: null # set to 'testing' to use this subset in periodic validations, instead of the validation set
      data_list_files: "@data_list_files"
# dataset shared by validate/infer stages (referenced as @dataset#data)
dataset:
  data:
    key: "testing"
    data_list_files: "@data_list_files"
# standalone validation-stage configuration
validate:
  grouping: true
  evaluator:
    postprocessing: "@postprocessing"
  dataset:
    data: "@dataset#data"
  batch_size: 1
  num_workers: 4
  # null entries fall back to the defaults of the consuming script
  preprocessing: null
  postprocessing: null
  inferer: null
  handlers: null
  key_metric: null
# inference-stage configuration
infer:
  evaluator:
    postprocessing: "@postprocessing"
  dataset:
    data: "@dataset#data"
# pick the GPU matching LOCAL_RANK (multi-GPU launchers set it); fall back to CPU
device: "$torch.device(('cuda:' + os.environ.get('LOCAL_RANK', '0')) if torch.cuda.is_available() else 'cpu')"
network_def:
  _target_: monai.networks.nets.cell_sam_wrapper.CellSamWrapper
  checkpoint: $os.path.join(@ckpt_path, "sam_vit_b_01ec64.pth") # SAM ViT-B weights expected under @ckpt_path
network: $@network_def.to(@device) # instantiated network moved to the selected device
loss_function:
  _target_: scripts.components.CellLoss
key_metric:
  _target_: scripts.components.CellAcc
# alternative AdamW optimizer — uncomment (and remove the SGD block below);
# see the learning_rate note above (try 1.0e-4 with AdamW)
# optimizer:
# _target_: torch.optim.AdamW
# params: $@network.parameters()
# lr: "@learning_rate"
# weight_decay: 1.0e-5
optimizer:
  _target_: torch.optim.SGD
  params: $@network.parameters()
  momentum: 0.9
  lr: "@learning_rate"
  weight_decay: 1.0e-5
# linear warmup followed by cosine decay over max_epochs
lr_scheduler:
  _target_: monai.optimizers.lr_scheduler.WarmupCosineSchedule
  optimizer: "@optimizer"
  warmup_steps: "@train#trainer#num_warmup_epochs"
  warmup_multiplier: 0.1
  t_total: "@train#trainer#max_epochs"
inferer:
  sliding_inferer:
    _target_: monai.inferers.SlidingWindowInfererAdapt
    roi_size: "@roi_size"
    sw_batch_size: 1
    overlap: 0.625
    mode: "gaussian" # gaussian weighting when blending overlapping windows
    cache_roi_weight_map: true
    progress: false
# writes the 'seg' prediction as TIFF files into output_dir
image_saver:
  _target_: scripts.components.SaveTiffd
  keys: "seg"
  output_dir: "@output_dir"
  nested_folder: false
# postprocessing pipeline applied by the evaluators (currently only saves outputs)
postprocessing:
  _target_: monai.transforms.Compose
  transforms:
  - "@image_saver"