File size: 3,608 Bytes
fd4ffa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
imports:
  - $import os

# Random seed; null leaves runs non-deterministic.
# seed: 28022024  # uncomment for deterministic results (but slower)
seed: null

# Bundle root directory; ckpt_path and output_dir are joined onto it.
bundle_root: '.'
# location to save checkpoints
ckpt_path: '$os.path.join(@bundle_root, "models")'
# location to save events and logs
output_dir: '$os.path.join(@bundle_root, "eval")'
# log file written inside output_dir
log_output_file: '$os.path.join(@output_dir, "vista_cell.log")'

# MLflow logging: set a tracking URI to enable it, e.g.
# "$@ckpt_path + '/mlruns/'", "http://127.0.0.1:8080", or a remote URL.
mlflow_tracking_uri: null
# log system metrics to mlflow (requires: pip install psutil pynvml)
mlflow_log_system_metrics: true
# optional name of the current run
mlflow_run_name: null

ckpt_save: true  # save checkpoints periodically
amp: true
amp_dtype: "float16"  # "float16" or "bfloat16" (bfloat16 needs an Ampere or newer GPU)
channels_last: true
compile: false  # compile the model for faster processing

start_epoch: 0
run_final_testing: true
use_weighted_sampler: false  # only applicable when using several dataset jsons for data_list_files

pretrained_ckpt_name: null
pretrained_ckpt_path: null

# Single-dataset defaults, intended to be set from the command line;
# data_list_files wraps them into a one-entry list.
datalist: datalists/cellpose_datalist.json
basedir: /cellpose_dataset
data_list_files:
  - datalist: '@datalist'
    basedir: '@basedir'

fold: 0
# try 1.0e-4 if using AdamW
learning_rate: 0.01
# use only a small subset of the data, for quick testing
quick: false
# spatial size shared by the training preprocessing and the sliding-window inferer
roi_size: [256, 256]

train:
  skip: false
  handlers: []
  trainer:
    # num_warmup_epochs and max_epochs also feed lr_scheduler (warmup_steps / t_total)
    num_warmup_epochs: 3
    max_epochs: 200
    num_epochs_per_saving: 1
    num_epochs_per_validation: null
  num_workers: 4
  batch_size: 1
  dataset:
    preprocessing:
      roi_size: "@roi_size"
    data:
      key: null  # set to 'testing' to use this subset in periodic validations, instead of the validation set
      data_list_files: "@data_list_files"

# Evaluation data: the "testing" subset of data_list_files.
# Referenced as "@dataset#data" by the validate and infer sections.
dataset:
  data:
    key: testing
    data_list_files: '@data_list_files'

# Validation: reuses the top-level dataset#data and post-processes evaluator
# outputs with the top-level postprocessing pipeline.
validate:
  grouping: true
  evaluator: {postprocessing: '@postprocessing'}
  dataset: {data: '@dataset#data'}
  batch_size: 1
  num_workers: 4
  # NOTE(review): these nulls presumably select defaults inside the training
  # script -- confirm there before relying on them.
  preprocessing: null
  postprocessing: null
  inferer: null
  handlers: null
  key_metric: null

# Inference: same data and post-processing references as validation.
infer:
  evaluator: {postprocessing: '@postprocessing'}
  dataset: {data: '@dataset#data'}


# CUDA device indexed by the LOCAL_RANK env var (default 0) if CUDA is available, else CPU.
device: "$torch.device(('cuda:' + os.environ.get('LOCAL_RANK', '0')) if torch.cuda.is_available() else 'cpu')"
# Network definition; the checkpoint file is expected inside ckpt_path.
network_def:
  _target_: monai.networks.nets.cell_sam_wrapper.CellSamWrapper
  checkpoint: '$os.path.join(@ckpt_path, "sam_vit_b_01ec64.pth")'
# instantiated network moved to the selected device
network: '$@network_def.to(@device)'

# Loss and key metric come from project-local components
# (presumably scripts/components.py in the bundle).
loss_function:
  _target_: 'scripts.components.CellLoss'

key_metric:
  _target_: 'scripts.components.CellAcc'

# Alternative optimizer (consider learning_rate: 1.0e-4 with it):
# optimizer:
#   _target_: torch.optim.AdamW
#   params: $@network.parameters()
#   lr: "@learning_rate"
#   weight_decay: 1.0e-5

# Active optimizer: SGD with momentum over all network parameters.
optimizer:
  _target_: torch.optim.SGD
  params: '$@network.parameters()'
  lr: '@learning_rate'
  momentum: 0.9
  weight_decay: 1.0e-5

# Linear warmup followed by cosine decay. Lengths are taken from the trainer's
# epoch counts, so the scheduler is presumably stepped once per epoch.
lr_scheduler:
  _target_: monai.optimizers.lr_scheduler.WarmupCosineSchedule
  optimizer: '@optimizer'
  warmup_steps: '@train#trainer#num_warmup_epochs'
  # warmup starts from this fraction of the initial learning rate
  warmup_multiplier: 0.1
  t_total: '@train#trainer#max_epochs'

inferer:
  # Sliding-window inference using the same roi_size as training preprocessing.
  sliding_inferer:
    _target_: monai.inferers.SlidingWindowInfererAdapt
    roi_size: '@roi_size'
    sw_batch_size: 1
    overlap: 0.625
    mode: gaussian
    cache_roi_weight_map: true
    progress: false

# Writes the "seg" key into output_dir (presumably as TIFF, per SaveTiffd).
image_saver:
  _target_: scripts.components.SaveTiffd
  keys: seg
  output_dir: '@output_dir'
  nested_folder: false

# Post-processing pipeline used by the validate and infer evaluators.
postprocessing:
  _target_: monai.transforms.Compose
  transforms: ['@image_saver']