Anirudh Balaraman commited on
Commit
c67c387
·
1 Parent(s): 16c0de3

fix cspca inference

Browse files
config/config_cspca_train_2.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
2
+ dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PICAI_cspca.json
3
+ num_classes: !!int 4
4
+ mil_mode: att_trans
5
+ tile_count: !!int 24
6
+ tile_size: !!int 64
7
+ depth: !!int 3
8
+ use_heatmap: !!bool True
9
+ workers: !!int 6
10
+ checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/pirads_training/model_70.pt
11
+ epochs: !!int 80
12
+ batch_size: !!int 8
13
+ optim_lr: !!float 2e-5
14
+
15
+
16
+
17
+
job_scripts/train_cspca_2.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH --job-name=cspca_training_70 # Specify job name
3
+ #SBATCH --partition=gpu # Specify partition name
4
+ #SBATCH --mem=128G
5
+ #SBATCH --gres=gpu:1
6
+ #SBATCH --time=48:00:00 # Set a limit on the total run time
7
+ #SBATCH --output=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/%x/log.o%j # File name for standard output
8
+ #SBATCH --error=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/%x/log.e%j # File name for standard error output
9
+ #SBATCH --mail-user=anirudh.balaraman@charite.de
10
+ #SBATCH --mail-type=END,FAIL
11
+
12
+
13
+ source /etc/profile.d/conda.sh
14
+ conda activate foundation
15
+
16
+ RUNDIR="/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate"
17
+
18
+
19
+ srun python -u $RUNDIR/run_cspca.py --mode train --config $RUNDIR/config/config_cspca_train_2.yaml
run_cspca.py CHANGED
@@ -51,11 +51,11 @@ def main_worker(args):
51
  cspca_model, train_loader, optimizer, epoch=epoch, args=args
52
  )
53
  logging.info(
54
- f"STATE {st} EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}"
55
  )
56
  val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
57
  logging.info(
58
- f"STATE {st} EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
59
  )
60
  if val_metric["loss"] < old_loss:
61
  old_loss = val_metric["loss"]
 
51
  cspca_model, train_loader, optimizer, epoch=epoch, args=args
52
  )
53
  logging.info(
54
+ f"EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}"
55
  )
56
  val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
57
  logging.info(
58
+ f"EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
59
  )
60
  if val_metric["loss"] < old_loss:
61
  old_loss = val_metric["loss"]