# vista2d / configs/hyper_parameters.yaml
# Uploaded by project-monai: vista2d version 0.3.1 (commit fd4ffa6)
imports:
- $import os
# seed: 28022024  # uncomment for deterministic results (but slower)
seed: null

bundle_root: "."
ckpt_path: $os.path.join(@bundle_root, "models")  # location to save checkpoints
output_dir: $os.path.join(@bundle_root, "eval")  # location to save events and logs
log_output_file: $os.path.join(@output_dir, "vista_cell.log")

# Set to enable MLflow logging, e.g. $@ckpt_path + '/mlruns/',
# "http://127.0.0.1:8080" or a remote URL.
mlflow_tracking_uri: null
mlflow_log_system_metrics: true  # log system metrics to mlflow (requires: pip install psutil pynvml)
mlflow_run_name: null  # optional name of the current run
ckpt_save: true  # save checkpoints periodically
amp: true
amp_dtype: "float16"  # "float16" or "bfloat16" (bfloat16 requires Ampere or newer)
channels_last: true
compile: false  # compile the model for faster processing

start_epoch: 0
run_final_testing: true
use_weighted_sampler: false  # only applicable when using several dataset jsons for data_list_files

pretrained_ckpt_name: null
pretrained_ckpt_path: null
# Single-dataset shortcut: set `datalist` and `basedir` from the command line;
# `data_list_files` wraps them into a one-entry list.
datalist: "datalists/cellpose_datalist.json"
basedir: "/cellpose_dataset"
data_list_files:
  - datalist: "@datalist"
    basedir: "@basedir"

fold: 0
learning_rate: 0.01  # try 1.0e-4 if using AdamW
quick: false  # whether to use a small subset of data for quick testing
roi_size: [256, 256]  # passed to train preprocessing and the sliding-window inferer
# Training-stage settings.
train:
  skip: false
  handlers: []
  trainer:
    num_warmup_epochs: 3  # also read by lr_scheduler as warmup_steps
    max_epochs: 200  # also read by lr_scheduler as t_total
    num_epochs_per_saving: 1
    num_epochs_per_validation: null
  num_workers: 4
  batch_size: 1
  dataset:
    preprocessing:
      roi_size: "@roi_size"
    data:
      key: null  # set to 'testing' to use that subset in periodic validations, instead of the validation set
      data_list_files: "@data_list_files"
# Evaluation data: the "testing" subset of data_list_files.
# Referenced as "@dataset#data" by validate and infer.
dataset:
  data:
    key: "testing"
    data_list_files: "@data_list_files"
# Validation-stage settings. The null entries are left unset here
# (presumably defaulted by the workflow script — confirm).
validate:
  grouping: true
  evaluator:
    postprocessing: "@postprocessing"
  dataset:
    data: "@dataset#data"
  batch_size: 1
  num_workers: 4
  preprocessing: null
  postprocessing: null
  inferer: null
  handlers: null
  key_metric: null
# Inference-stage settings: same postprocessing chain and data as validate.
infer:
  evaluator:
    postprocessing: "@postprocessing"
  dataset:
    data: "@dataset#data"
# CUDA device indexed by LOCAL_RANK (default '0') when CUDA is available, otherwise CPU.
device: "$torch.device(('cuda:' + os.environ.get('LOCAL_RANK', '0')) if torch.cuda.is_available() else 'cpu')"
network_def:
  _target_: monai.networks.nets.cell_sam_wrapper.CellSamWrapper
  checkpoint: $os.path.join(@ckpt_path, "sam_vit_b_01ec64.pth")  # weights file expected under ckpt_path
network: $@network_def.to(@device)
# Training loss and the key metric, both implemented in scripts/components.
loss_function:
  _target_: scripts.components.CellLoss

key_metric:
  _target_: scripts.components.CellAcc
# Alternative optimizer: to use it, replace the SGD `optimizer` entry below
# (a second `optimizer` key would be a duplicate) and lower learning_rate (~1.0e-4).
# optimizer:
#   _target_: torch.optim.AdamW
#   params: $@network.parameters()
#   lr: "@learning_rate"
#   weight_decay: 1.0e-5

optimizer:
  _target_: torch.optim.SGD
  params: $@network.parameters()
  momentum: 0.9
  lr: "@learning_rate"
  weight_decay: 1.0e-5

# Warmup/total lengths come from train#trainer epoch counts
# (presumably the scheduler is stepped once per epoch — confirm).
lr_scheduler:
  _target_: monai.optimizers.lr_scheduler.WarmupCosineSchedule
  optimizer: "@optimizer"
  warmup_steps: "@train#trainer#num_warmup_epochs"
  warmup_multiplier: 0.1
  t_total: "@train#trainer#max_epochs"
# Sliding-window inference over roi_size patches.
inferer:
  sliding_inferer:
    _target_: monai.inferers.SlidingWindowInfererAdapt
    roi_size: "@roi_size"
    sw_batch_size: 1
    overlap: 0.625
    mode: "gaussian"
    cache_roi_weight_map: true
    progress: false
# Saves the "seg" key under output_dir (SaveTiffd; presumably TIFF output — confirm).
image_saver:
  _target_: scripts.components.SaveTiffd
  keys: "seg"
  output_dir: "@output_dir"
  nested_folder: false

# Postprocessing chain referenced by the validate and infer evaluators.
postprocessing:
  _target_: monai.transforms.Compose
  transforms:
    - "@image_saver"