Spaces:
Runtime error
Runtime error
Anirudh Balaraman commited on
Commit ·
c67c387
1
Parent(s): 16c0de3
fix cspca inference
Browse files- config/config_cspca_train_2.yaml +17 -0
- job_scripts/train_cspca_2.sh +19 -0
- run_cspca.py +2 -2
config/config_cspca_train_2.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
|
| 2 |
+
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PICAI_cspca.json
|
| 3 |
+
num_classes: !!int 4
|
| 4 |
+
mil_mode: att_trans
|
| 5 |
+
tile_count: !!int 24
|
| 6 |
+
tile_size: !!int 64
|
| 7 |
+
depth: !!int 3
|
| 8 |
+
use_heatmap: !!bool True
|
| 9 |
+
workers: !!int 6
|
| 10 |
+
checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/pirads_training/model_70.pt
|
| 11 |
+
epochs: !!int 80
|
| 12 |
+
batch_size: !!int 8
|
| 13 |
+
optim_lr: !!float 2e-5
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
job_scripts/train_cspca_2.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=cspca_training_70 # Specify job name
|
| 3 |
+
#SBATCH --partition=gpu # Specify partition name
|
| 4 |
+
#SBATCH --mem=128G
|
| 5 |
+
#SBATCH --gres=gpu:1
|
| 6 |
+
#SBATCH --time=48:00:00 # Set a limit on the total run time
|
| 7 |
+
#SBATCH --output=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/%x/log.o%j # File name for standard output
|
| 8 |
+
#SBATCH --error=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/%x/log.e%j # File name for standard error output
|
| 9 |
+
#SBATCH --mail-user=anirudh.balaraman@charite.de
|
| 10 |
+
#SBATCH --mail-type=END,FAIL
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
source /etc/profile.d/conda.sh
|
| 14 |
+
conda activate foundation
|
| 15 |
+
|
| 16 |
+
RUNDIR="/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
srun python -u $RUNDIR/run_cspca.py --mode train --config $RUNDIR/config/config_cspca_train_2.yaml
|
run_cspca.py
CHANGED
|
@@ -51,11 +51,11 @@ def main_worker(args):
|
|
| 51 |
cspca_model, train_loader, optimizer, epoch=epoch, args=args
|
| 52 |
)
|
| 53 |
logging.info(
|
| 54 |
-
f"
|
| 55 |
)
|
| 56 |
val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
|
| 57 |
logging.info(
|
| 58 |
-
f"
|
| 59 |
)
|
| 60 |
if val_metric["loss"] < old_loss:
|
| 61 |
old_loss = val_metric["loss"]
|
|
|
|
| 51 |
cspca_model, train_loader, optimizer, epoch=epoch, args=args
|
| 52 |
)
|
| 53 |
logging.info(
|
| 54 |
+
f"EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}"
|
| 55 |
)
|
| 56 |
val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
|
| 57 |
logging.info(
|
| 58 |
+
f"EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
|
| 59 |
)
|
| 60 |
if val_metric["loss"] < old_loss:
|
| 61 |
old_loss = val_metric["loss"]
|