Anirudh Balaraman committed on
Commit
80a9c91
·
1 Parent(s): 1baebae

fix finetuning

Browse files
config/config_cspca_train.yaml CHANGED
@@ -7,10 +7,10 @@ tile_size: !!int 64
7
  depth: !!int 3
8
  use_heatmap: !!bool True
9
  workers: !!int 6
10
- checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/models/pirads.pt
11
  epochs: !!int 80
12
  batch_size: !!int 8
13
- optim_lr: !!float 2e-4
14
 
15
 
16
 
 
7
  depth: !!int 3
8
  use_heatmap: !!bool True
9
  workers: !!int 6
10
+ checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/pirads_training/model_47.pt
11
  epochs: !!int 80
12
  batch_size: !!int 8
13
+ optim_lr: !!float 2e-5
14
 
15
 
16
 
job_scripts/train_cspca.sh CHANGED
@@ -1,11 +1,11 @@
1
  #!/bin/bash
2
- #SBATCH --job-name=cspca_training # Specify job name
3
  #SBATCH --partition=gpu # Specify partition name
4
  #SBATCH --mem=128G
5
  #SBATCH --gres=gpu:1
6
  #SBATCH --time=48:00:00 # Set a limit on the total run time
7
- #SBATCH --output=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/logs/%x/log.o%j # File name for standard output
8
- #SBATCH --error=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/logs/%x/log.e%j # File name for standard error output
9
  #SBATCH --mail-user=anirudh.balaraman@charite.de
10
  #SBATCH --mail-type=END,FAIL
11
 
@@ -13,7 +13,7 @@
13
  source /etc/profile.d/conda.sh
14
  conda activate foundation
15
 
16
- RUNDIR="/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation"
17
 
18
 
19
- srun python -u $RUNDIR/MIL/new_folder/run_cspca.py --mode train --config $RUNDIR/MIL/new_folder/config/config_cspca_train.yaml
 
1
  #!/bin/bash
2
+ #SBATCH --job-name=cspca_train_47 # Specify job name
3
  #SBATCH --partition=gpu # Specify partition name
4
  #SBATCH --mem=128G
5
  #SBATCH --gres=gpu:1
6
  #SBATCH --time=48:00:00 # Set a limit on the total run time
7
+ #SBATCH --output=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/%x/log.o%j # File name for standard output
8
+ #SBATCH --error=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/%x/log.e%j # File name for standard error output
9
  #SBATCH --mail-user=anirudh.balaraman@charite.de
10
  #SBATCH --mail-type=END,FAIL
11
 
 
13
  source /etc/profile.d/conda.sh
14
  conda activate foundation
15
 
16
+ RUNDIR="/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate"
17
 
18
 
19
+ srun python -u $RUNDIR/run_cspca.py --mode train --config $RUNDIR/config/config_cspca_train.yaml
requirements.txt CHANGED
@@ -14,6 +14,7 @@ triton==3.1.0
14
 
15
  # ---- MONAI / medical imaging ----
16
  monai==1.4.0
 
17
  SimpleITK==2.4.0
18
  pynrrd==1.1.1
19
  nibabel==5.3.2
@@ -35,7 +36,6 @@ tensorboard==2.18.0
35
 
36
  # ---- Utilities ----
37
  tqdm==4.67.1
38
- gdown==5.2.0
39
  requests
40
  filelock
41
  packaging
@@ -44,7 +44,4 @@ packaging
44
  streamlit==1.50.0
45
 
46
  # ---- Grad-CAM ----
47
- grad-cam @ git+https://github.com/jacobgil/pytorch-grad-cam.git@781dbc0d16ffa95b6d18b96b7b829840a82d93d1
48
-
49
- # ---- Your external dependency ----
50
- -e git+https://github.com/ai-assisted-healthcare/AIAH_utility.git@368233822b057b6bfef88f9e4b23c2967ae7bb35#egg=AIAH_utility
 
14
 
15
  # ---- MONAI / medical imaging ----
16
  monai==1.4.0
17
+ itk>=5.3.0
18
  SimpleITK==2.4.0
19
  pynrrd==1.1.1
20
  nibabel==5.3.2
 
36
 
37
  # ---- Utilities ----
38
  tqdm==4.67.1
 
39
  requests
40
  filelock
41
  packaging
 
44
  streamlit==1.50.0
45
 
46
  # ---- Grad-CAM ----
47
+ grad-cam @ git+https://github.com/jacobgil/pytorch-grad-cam.git@781dbc0d16ffa95b6d18b96b7b829840a82d93d1
 
 
 
run_pirads.py CHANGED
@@ -75,7 +75,7 @@ def main_worker(args):
75
  val_loss_min = float("inf")
76
  epochs_no_improve = 0
77
  for epoch in range(start_epoch, n_epochs):
78
- logging.info(time.ctime(), "Epoch:", epoch)
79
  epoch_time = time.time()
80
  train_loss, train_acc, train_att_loss, batch_norm = train_epoch(
81
  model, train_loader, optimizer, scaler=scaler, epoch=epoch, args=args
 
75
  val_loss_min = float("inf")
76
  epochs_no_improve = 0
77
  for epoch in range(start_epoch, n_epochs):
78
+ logging.info(f"{time.ctime()} | Epoch: {epoch}")
79
  epoch_time = time.time()
80
  train_loss, train_acc, train_att_loss, batch_norm = train_epoch(
81
  model, train_loader, optimizer, scaler=scaler, epoch=epoch, args=args
src/utils.py CHANGED
@@ -35,7 +35,6 @@ def save_cspca_checkpoint(model, val_metric, model_dir):
35
  "auc": val_metric["auc"],
36
  "sensitivity": val_metric["sensitivity"],
37
  "specificity": val_metric["specificity"],
38
- "state": val_metric["state"],
39
  "state_dict": state_dict,
40
  }
41
  torch.save(save_dict, os.path.join(model_dir, "cspca_model.pth"))
 
35
  "auc": val_metric["auc"],
36
  "sensitivity": val_metric["sensitivity"],
37
  "specificity": val_metric["specificity"],
 
38
  "state_dict": state_dict,
39
  }
40
  torch.save(save_dict, os.path.join(model_dir, "cspca_model.pth"))