imports:
  - $import os

# seed: 28022024 # uncomment for deterministic results (but slower)
seed: null

bundle_root: "."
ckpt_path: $os.path.join(@bundle_root, "models") # location to save checkpoints
output_dir: $os.path.join(@bundle_root, "eval") # location to save events and logs
log_output_file: $os.path.join(@output_dir, "vista_cell.log")

mlflow_tracking_uri: null # enable mlflow logging, e.g. $@ckpt_path + '/mlruns/' or "http://127.0.0.1:8080" or a remote url
mlflow_log_system_metrics: true # log system metrics to mlflow (requires: pip install psutil pynvml)
mlflow_run_name: null # optional name of the current run

ckpt_save: true # save checkpoints periodically
amp: true
amp_dtype: "float16" # float16 or bfloat16 (bfloat16 requires Ampere or newer GPUs)
channels_last: true
compile: false # compile the model for faster processing
start_epoch: 0

run_final_testing: true
use_weighted_sampler: false # only applicable when several dataset JSONs are given in data_list_files

pretrained_ckpt_name: null
pretrained_ckpt_path: null

# for command-line setting of a single dataset
datalist: datalists/cellpose_datalist.json
basedir: /cellpose_dataset
data_list_files:
  - {datalist: "@datalist", basedir: "@basedir"}

fold: 0 # cross-validation fold used for validation
learning_rate: 0.01 # try 1.0e-4 if using AdamW
quick: false # whether to use a small subset of data for quick testing
roi_size: [256, 256] # patch size for training crops and sliding-window inference

train:
  skip: false
  handlers: []
  trainer:
    num_warmup_epochs: 3
    max_epochs: 200
    num_epochs_per_saving: 1
    num_epochs_per_validation: null
  num_workers: 4
  batch_size: 1
  dataset:
    preprocessing:
      roi_size: "@roi_size"
    data:
      key: null # set to 'testing' to use this subset in periodic validations, instead of the validation set
      data_list_files: "@data_list_files"

dataset:
  data:
    key: "testing"
    data_list_files: "@data_list_files"

validate:
  grouping: true
  evaluator:
    postprocessing: "@postprocessing"
  dataset:
    data: "@dataset#data"
  batch_size: 1
  num_workers: 4
  preprocessing: null
  postprocessing: null
  inferer: null
  handlers: null
  key_metric: null

infer:
  evaluator:
    postprocessing: "@postprocessing"
  dataset:
    data: "@dataset#data"

# pick the GPU matching this process's LOCAL_RANK under torchrun, else CPU
device: "$torch.device(('cuda:' + os.environ.get('LOCAL_RANK', '0')) if torch.cuda.is_available() else 'cpu')"
network_def:
  _target_: monai.networks.nets.cell_sam_wrapper.CellSamWrapper
  checkpoint: $os.path.join(@ckpt_path, "sam_vit_b_01ec64.pth") # pretrained SAM ViT-B weights
network: $@network_def.to(@device) # instantiate the network definition and move it to the target device

loss_function:
  _target_: scripts.components.CellLoss

key_metric:
  _target_: scripts.components.CellAcc

# optimizer:
#   _target_: torch.optim.AdamW
#   params: $@network.parameters()
#   lr: "@learning_rate"
#   weight_decay: 1.0e-5

optimizer:
  _target_: torch.optim.SGD
  params: $@network.parameters()
  momentum: 0.9
  lr: "@learning_rate"
  weight_decay: 1.0e-5

lr_scheduler:
  _target_: monai.optimizers.lr_scheduler.WarmupCosineSchedule
  optimizer: "@optimizer"
  warmup_steps: "@train#trainer#num_warmup_epochs"
  warmup_multiplier: 0.1
  t_total: "@train#trainer#max_epochs"

inferer:
  sliding_inferer:
    _target_: monai.inferers.SlidingWindowInfererAdapt
    roi_size: "@roi_size"
    sw_batch_size: 1
    overlap: 0.625 # fractional overlap between adjacent sliding windows
    mode: "gaussian" # gaussian weighting when blending overlapping windows
    cache_roi_weight_map: true
    progress: false

image_saver:
  _target_: scripts.components.SaveTiffd
  keys: "seg"
  output_dir: "@output_dir"
  nested_folder: false

postprocessing:
  _target_: monai.transforms.Compose
  transforms:
    - "@image_saver"
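
# Example usage (a sketch, not part of the config): this file follows the MONAI
# bundle convention, so any top-level key above can be overridden from the
# command line. The workflow class "scripts.workflow.VistaCell" and the config
# path below are assumptions based on the vista2d-style bundle layout; adjust
# them to match your checkout.
#
#   # single-GPU training, overriding the learning rate from the CLI
#   python -m monai.bundle run_workflow "scripts.workflow.VistaCell" \
#     --config_file configs/hyper_parameters.yaml --learning_rate 1.0e-4
#
#   # multi-GPU training via torchrun (@device reads LOCAL_RANK per process)
#   torchrun --nproc_per_node=2 -m monai.bundle run_workflow \
#     "scripts.workflow.VistaCell" --config_file configs/hyper_parameters.yaml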