Anirudh Balaraman commited on
Commit
1baebae
·
1 Parent(s): 2a68513
.gitignore CHANGED
@@ -5,4 +5,5 @@ temp_data/
5
  temp.ipynb
6
  __pycache__/
7
  **/__pycache__/
8
- *.pyc
 
 
5
  temp.ipynb
6
  __pycache__/
7
  **/__pycache__/
8
+ *.pyc
9
+ .ruff_cache
config/config_cspca_test.yaml CHANGED
@@ -1,6 +1,5 @@
1
- project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
2
  data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
3
- dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PICAI_cspca.json
4
  num_classes: !!int 4
5
  mil_mode: att_trans
6
  tile_count: !!int 24
@@ -8,10 +7,7 @@ tile_size: !!int 64
8
  depth: !!int 3
9
  use_heatmap: !!bool True
10
  workers: !!int 6
11
- checkpoint_cspca: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/cspca_model.pth
12
- num_seeds: !!int 2
13
- batch_size: !!int 1
14
-
15
-
16
 
17
 
 
 
1
  data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
2
+ dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PICAI_cspca.json
3
  num_classes: !!int 4
4
  mil_mode: att_trans
5
  tile_count: !!int 24
 
7
  depth: !!int 3
8
  use_heatmap: !!bool True
9
  workers: !!int 6
10
+ checkpoint_cspca: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/models/cspca_model.pth
11
+ batch_size: !!int 8
 
 
 
12
 
13
 
config/config_cspca_train.yaml CHANGED
@@ -1,6 +1,5 @@
1
- project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
2
  data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
3
- dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PICAI_cspca.json
4
  num_classes: !!int 4
5
  mil_mode: att_trans
6
  tile_count: !!int 24
@@ -8,12 +7,11 @@ tile_size: !!int 64
8
  depth: !!int 3
9
  use_heatmap: !!bool True
10
  workers: !!int 6
11
- checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/pirads.pt
12
- epochs: !!int 1
13
  batch_size: !!int 8
14
  optim_lr: !!float 2e-4
15
 
16
 
17
 
18
 
19
-
 
 
1
  data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
2
+ dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PICAI_cspca.json
3
  num_classes: !!int 4
4
  mil_mode: att_trans
5
  tile_count: !!int 24
 
7
  depth: !!int 3
8
  use_heatmap: !!bool True
9
  workers: !!int 6
10
+ checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/models/pirads.pt
11
+ epochs: !!int 80
12
  batch_size: !!int 8
13
  optim_lr: !!float 2e-4
14
 
15
 
16
 
17
 
 
config/config_pirads_test.yaml CHANGED
@@ -1,18 +1,15 @@
1
  run_name: pirads_test_run
2
- project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
3
  data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
4
- dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PI-RADS_data.json
5
  num_classes: !!int 4
6
  mil_mode: att_trans
7
  tile_count: !!int 24
8
  tile_size: !!int 64
9
  depth: !!int 3
10
  use_heatmap: !!bool True
11
- workers: !!int 0
12
- checkpoint: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/pirads.pt
13
  amp: !!bool True
14
- dry_run: !!bool True
15
-
16
 
17
 
18
 
 
1
  run_name: pirads_test_run
 
2
  data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
3
+ dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PI-RADS_data.json
4
  num_classes: !!int 4
5
  mil_mode: att_trans
6
  tile_count: !!int 24
7
  tile_size: !!int 64
8
  depth: !!int 3
9
  use_heatmap: !!bool True
10
+ workers: !!int 8
11
+ checkpoint: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/models/pirads.pt
12
  amp: !!bool True
 
 
13
 
14
 
15
 
config/config_pirads_train.yaml CHANGED
@@ -1,21 +1,18 @@
1
- project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
2
  data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
3
- dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PI-RADS_data.json
4
  num_classes: !!int 4
5
  mil_mode: att_trans
6
  tile_count: !!int 24
7
  tile_size: !!int 64
8
  depth: !!int 3
9
  use_heatmap: !!bool True
10
- workers: !!int 0
11
- checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/pirads.pt
12
- epochs: !!int 2
13
  batch_size: !!int 8
14
  optim_lr: !!float 2e-4
15
  weight_decay: !!float 1e-5
16
  amp: !!bool True
17
  wandb: !!bool True
18
- dry_run: !!bool True
19
 
20
 
21
 
 
 
1
  data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
2
+ dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PI-RADS_data.json
3
  num_classes: !!int 4
4
  mil_mode: att_trans
5
  tile_count: !!int 24
6
  tile_size: !!int 64
7
  depth: !!int 3
8
  use_heatmap: !!bool True
9
+ workers: !!int 4
10
+ epochs: !!int 100
 
11
  batch_size: !!int 8
12
  optim_lr: !!float 2e-4
13
  weight_decay: !!float 1e-5
14
  amp: !!bool True
15
  wandb: !!bool True
 
16
 
17
 
18
 
job_scripts/train_pirads.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH --job-name=pirads_training # Specify job name
3
+ #SBATCH --partition=gpu # Specify partition name
4
+ #SBATCH --mem=128G
5
+ #SBATCH --gres=gpu:1
6
+ #SBATCH --time=48:00:00 # Set a limit on the total run time
7
+ #SBATCH --output=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/%x/log.o%j # File name for standard output
8
+ #SBATCH --error=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/%x/log.e%j # File name for standard error output
9
+ #SBATCH --mail-user=anirudh.balaraman@charite.de
10
+ #SBATCH --mail-type=END,FAIL
11
+
12
+
13
+ source /etc/profile.d/conda.sh
14
+ conda activate foundation
15
+
16
+ RUNDIR="/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate"
17
+
18
+
19
+ srun python -u $RUNDIR/run_pirads.py --mode train --config $RUNDIR/config/config_pirads_train.yaml
preprocess_main.py CHANGED
@@ -1,68 +1,61 @@
1
- import SimpleITK as sitk
2
  import os
3
- import numpy as np
4
- import nrrd
5
- from AIAH_utility.viewer import BasicViewer
6
- from tqdm import tqdm
7
- import pandas as pd
8
- from picai_prep.preprocessing import PreprocessingSettings, Sample
9
- import multiprocessing
10
- import sys
11
  from src.preprocessing.register_and_crop import register_files
12
  from src.preprocessing.prostate_mask import get_segmask
13
  from src.preprocessing.histogram_match import histmatch
14
  from src.preprocessing.generate_heatmap import get_heatmap
15
  import logging
16
- from pathlib import Path
17
  from src.utils import setup_logging
18
  from src.utils import validate_steps
19
  import argparse
20
- import yaml
 
21
 
22
  def parse_args():
23
- FUNCTIONS = {
24
- "register_and_crop": register_files,
25
- "histogram_match": histmatch,
26
- "get_segmentation_mask": get_segmask,
27
- "get_heatmap": get_heatmap,
28
- }
29
  parser = argparse.ArgumentParser(description="File preprocessing")
30
  parser.add_argument("--config", type=str, help="Path to YAML config file")
31
  parser.add_argument(
32
  "--steps",
33
- nargs="+", # ← list of strings
34
- choices=FUNCTIONS.keys(), # ← restrict allowed values
 
 
 
 
 
35
  required=True,
36
- help="Steps to execute (one or more)"
37
  )
38
  parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
39
  parser.add_argument("--dwi_dir", default=None, help="Path to DWI files")
40
  parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
41
  parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
42
  parser.add_argument("--output_dir", default=None, help="Path to output folder")
43
- parser.add_argument("--margin", default=0.2, type=float, help="Margin to center crop the images")
 
 
44
  parser.add_argument("--project_dir", default=None, help="Project directory")
45
-
46
  args = parser.parse_args()
47
  if args.config:
48
- with open(args.config, 'r') as config_file:
49
  config = yaml.safe_load(config_file)
50
  args.__dict__.update(config)
51
  return args
52
 
 
53
  if __name__ == "__main__":
54
  args = parse_args()
55
  FUNCTIONS = {
56
- "register_and_crop": register_files,
57
- "histogram_match": histmatch,
58
- "get_segmentation_mask": get_segmask,
59
- "get_heatmap": get_heatmap,
60
  }
61
 
62
- args.logfile = os.path.join(args.output_dir, f"preprocessing.log")
63
  setup_logging(args.logfile)
64
  logging.info("Starting preprocessing")
65
  validate_steps(args.steps)
66
  for step in args.steps:
67
  func = FUNCTIONS[step]
68
- args = func(args)
 
 
1
  import os
 
 
 
 
 
 
 
 
2
  from src.preprocessing.register_and_crop import register_files
3
  from src.preprocessing.prostate_mask import get_segmask
4
  from src.preprocessing.histogram_match import histmatch
5
  from src.preprocessing.generate_heatmap import get_heatmap
6
  import logging
 
7
  from src.utils import setup_logging
8
  from src.utils import validate_steps
9
  import argparse
10
+ import yaml
11
+
12
 
13
  def parse_args():
 
 
 
 
 
 
14
  parser = argparse.ArgumentParser(description="File preprocessing")
15
  parser.add_argument("--config", type=str, help="Path to YAML config file")
16
  parser.add_argument(
17
  "--steps",
18
+ nargs="+", # ← list of strings
19
+ choices=[
20
+ "register_and_crop",
21
+ "histogram_match",
22
+ "get_segmentation_mask",
23
+ "get_heatmap",
24
+ ], # ← restrict allowed values
25
  required=True,
26
+ help="Steps to execute (one or more)",
27
  )
28
  parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
29
  parser.add_argument("--dwi_dir", default=None, help="Path to DWI files")
30
  parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
31
  parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
32
  parser.add_argument("--output_dir", default=None, help="Path to output folder")
33
+ parser.add_argument(
34
+ "--margin", default=0.2, type=float, help="Margin to center crop the images"
35
+ )
36
  parser.add_argument("--project_dir", default=None, help="Project directory")
37
+
38
  args = parser.parse_args()
39
  if args.config:
40
+ with open(args.config, "r") as config_file:
41
  config = yaml.safe_load(config_file)
42
  args.__dict__.update(config)
43
  return args
44
 
45
+
46
  if __name__ == "__main__":
47
  args = parse_args()
48
  FUNCTIONS = {
49
+ "register_and_crop": register_files,
50
+ "histogram_match": histmatch,
51
+ "get_segmentation_mask": get_segmask,
52
+ "get_heatmap": get_heatmap,
53
  }
54
 
55
+ args.logfile = os.path.join(args.output_dir, "preprocessing.log")
56
  setup_logging(args.logfile)
57
  logging.info("Starting preprocessing")
58
  validate_steps(args.steps)
59
  for step in args.steps:
60
  func = FUNCTIONS[step]
61
+ args = func(args)
pyproject.toml CHANGED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ [tool.ruff]
2
+ line-length = 100
3
+
4
+ [tool.ruff.lint]
5
+ select = ["E", "W"]
6
+ ignore = ["E501"]
7
+
8
+ [tool.ruff.format]
9
+ quote-style = "double"
run_cspca.py CHANGED
@@ -1,120 +1,102 @@
1
  import argparse
2
  import os
3
  import shutil
4
- import time
5
  import yaml
6
  import sys
7
- import gdown
8
- import numpy as np
9
  import torch
10
- import torch.distributed as dist
11
- import torch.multiprocessing as mp
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from monai.config import KeysCollection
15
- from monai.metrics import Cumulative, CumulativeAverage
16
- from monai.networks.nets import milmodel, resnet, MILModel
17
-
18
- from sklearn.metrics import cohen_kappa_score
19
- from torch.cuda.amp import GradScaler, autocast
20
- from torch.utils.data.dataloader import default_collate
21
- from torchvision.models.resnet import ResNet50_Weights
22
- import shutil
23
  from pathlib import Path
24
- from torch.utils.data.distributed import DistributedSampler
25
- from torch.utils.tensorboard import SummaryWriter
26
  from monai.utils import set_determinism
27
-
28
- import matplotlib.pyplot as plt
29
-
30
- import wandb
31
- import math
32
  import logging
33
- from pathlib import Path
34
-
35
-
36
  from src.model.MIL import MILModel_3D
37
  from src.model.csPCa_model import csPCa_Model
38
  from src.data.data_loader import get_dataloader
39
  from src.utils import save_cspca_checkpoint, get_metrics, setup_logging
40
  from src.train.train_cspca import train_epoch, val_epoch
 
41
 
42
- def main_worker(args):
43
 
44
- mil_model = MILModel_3D(
45
- num_classes=args.num_classes,
46
- mil_mode=args.mil_mode
47
- ).to(args.device)
48
  cache_dir_path = Path(os.path.join(args.logdir, "cache"))
49
 
50
- if args.mode == 'train':
51
-
52
  checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
53
  mil_model.load_state_dict(checkpoint["state_dict"])
54
  mil_model = mil_model.to(args.device)
55
-
56
- model_dir = os.path.join(args.project_dir,'models')
57
- metrics_dict = {'auc':[], 'sensitivity':[], 'specificity':[]}
 
 
58
  for st in list(range(args.num_seeds)):
59
  set_determinism(seed=st)
60
-
61
  train_loader = get_dataloader(args, split="train")
62
  valid_loader = get_dataloader(args, split="test")
63
  cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
64
- for submodule in [cspca_model.backbone.net,
65
- cspca_model.backbone.myfc,
66
- cspca_model.backbone.transformer]:
 
 
67
  for param in submodule.parameters():
68
  param.requires_grad = False
69
 
70
- optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, cspca_model.parameters()), lr=args.optim_lr)
 
 
71
 
72
- old_loss = float('inf')
73
  old_auc = 0.0
74
  for epoch in range(args.epochs):
75
- train_loss, train_auc = train_epoch(cspca_model, train_loader, optimizer, epoch=epoch, args=args)
76
- logging.info(f"STATE {st} EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}")
 
 
 
 
77
  val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
78
- logging.info(f"STATE {st} EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}")
79
- val_metric['state'] = st
80
- if val_metric['loss'] < old_loss:
81
- old_loss = val_metric['loss']
82
- old_auc = val_metric['auc']
83
- sensitivity = val_metric['sensitivity']
84
- specificity = val_metric['specificity']
85
- if len(metrics_dict['auc']) == 0:
86
- save_cspca_checkpoint(cspca_model, val_metric, model_dir)
87
- elif val_metric['auc'] >= max(metrics_dict['auc']):
88
  save_cspca_checkpoint(cspca_model, val_metric, model_dir)
89
 
90
- metrics_dict['auc'].append(old_auc)
91
- metrics_dict['sensitivity'].append(sensitivity)
92
- metrics_dict['specificity'].append(specificity)
93
  if cache_dir_path.exists() and cache_dir_path.is_dir():
94
  shutil.rmtree(cache_dir_path)
95
 
96
  get_metrics(metrics_dict)
97
 
98
- elif args.mode == 'test':
99
-
100
  cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
101
  checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
102
- cspca_model.load_state_dict(checkpt['state_dict'])
103
  cspca_model = cspca_model.to(args.device)
104
- if 'auc' in checkpt and 'sensitivity' in checkpt and 'specificity' in checkpt:
105
- auc, sens, spec = checkpt['auc'], checkpt['sensitivity'], checkpt['specificity']
106
- logging.info(f"csPCa Model loaded from {args.checkpoint_cspca} with AUC: {auc}, Sensitivity: {sens}, Specificity: {spec} on the test set.")
 
 
107
  else:
108
  logging.info(f"csPCa Model loaded from {args.checkpoint_cspca}.")
109
-
110
- metrics_dict = {'auc':[], 'sensitivity':[], 'specificity':[]}
111
  for st in list(range(args.num_seeds)):
112
  set_determinism(seed=st)
113
  test_loader = get_dataloader(args, split="test")
114
  test_metric = val_epoch(cspca_model, test_loader, epoch=0, args=args)
115
- metrics_dict['auc'].append(test_metric['auc'])
116
- metrics_dict['sensitivity'].append(test_metric['sensitivity'])
117
- metrics_dict['specificity'].append(test_metric['specificity'])
118
 
119
  if cache_dir_path.exists() and cache_dir_path.is_dir():
120
  shutil.rmtree(cache_dir_path)
@@ -122,43 +104,58 @@ def main_worker(args):
122
  get_metrics(metrics_dict)
123
 
124
 
125
-
126
-
127
  def parse_args():
128
- parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) for csPCa risk prediction.")
129
- parser.add_argument('--mode', type=str, choices=['train', 'test'], required=True, help='Operation mode: train or infer')
130
- parser.add_argument('--run_name', type=str, default='train_cspca', help='run name for log file')
131
- parser.add_argument('--config', type=str, help='Path to YAML config file')
132
- parser.add_argument(
133
- "--project_dir", default=None, help="path to project firectory"
134
  )
135
  parser.add_argument(
136
- "--data_root", default=None, help="path to root folder of images"
 
 
 
 
137
  )
 
 
 
 
138
  parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
139
  parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
140
- parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm: choose either att_trans or att_pyramid")
141
  parser.add_argument(
142
- "--tile_count", default=24, type=int, help="number of patches (instances) to extract from MRI input"
 
 
143
  )
144
- parser.add_argument("--tile_size", default=64, type=int, help="size of square patch (instance) in pixels")
145
- parser.add_argument("--depth", default=3, type=int, help="number of slices in each 3D patch (instance)")
146
  parser.add_argument(
147
- "--use_heatmap", action="store_true",
148
- help="enable weak attention heatmap guided patch generation"
 
 
149
  )
150
  parser.add_argument(
151
- "--no_heatmap", dest="use_heatmap", action="store_false",
152
- help="disable heatmap"
 
 
 
 
 
 
 
 
 
 
153
  )
154
  parser.set_defaults(use_heatmap=True)
155
  parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
156
- #parser.add_argument("--dry-run", action="store_true")
157
  parser.add_argument("--checkpoint_pirads", default=None, help="Load PI-RADS model")
158
- parser.add_argument("--epochs", "--max_epochs", default=30, type=int, help="number of training epochs")
 
 
159
  parser.add_argument("--batch_size", default=32, type=int, help="number of MRI scans per batch")
160
  parser.add_argument("--optim_lr", default=2e-4, type=float, help="initial learning rate")
161
- #parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
162
  parser.add_argument(
163
  "--val_every",
164
  "--val_interval",
@@ -166,42 +163,50 @@ def parse_args():
166
  type=int,
167
  help="run validation after this number of epochs, default 1 to run every epoch",
168
  )
169
- parser.add_argument("--dry_run", action="store_true", help="Run the script in dry-run mode (default: False)")
 
 
170
  parser.add_argument("--checkpoint_cspca", default=None, help="load existing checkpoint")
171
- parser.add_argument("--num_seeds", default=20, type=int, help="number of seeds to be run to build CI")
 
 
172
  args = parser.parse_args()
173
  if args.config:
174
- with open(args.config, 'r') as config_file:
175
  config = yaml.safe_load(config_file)
176
  args.__dict__.update(config)
177
 
178
-
179
-
180
  return args
181
 
182
 
183
-
184
  if __name__ == "__main__":
 
185
  args = parse_args()
 
 
 
 
 
 
 
186
  args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
187
  os.makedirs(args.logdir, exist_ok=True)
188
  args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
189
  setup_logging(args.logfile)
190
 
191
-
192
  logging.info("Argument values:")
193
  for k, v in vars(args).items():
194
  logging.info(f"{k} => {v}")
195
  logging.info("-----------------")
196
 
197
  if args.dataset_json is None:
198
- logging.error('Dataset path not provided. Quitting.')
199
  sys.exit(1)
200
- if args.checkpoint_pirads is None and args.mode == 'train':
201
- logging.error('PI-RADS checkpoint path not provided. Quitting.')
202
  sys.exit(1)
203
- elif args.checkpoint_cspca is None and args.mode == 'test':
204
- logging.error('csPCa checkpoint path not provided. Quitting.')
205
  sys.exit(1)
206
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
207
 
@@ -216,5 +221,4 @@ if __name__ == "__main__":
216
  args.num_seeds = 2
217
  args.wandb = False
218
 
219
-
220
  main_worker(args)
 
1
  import argparse
2
  import os
3
  import shutil
 
4
  import yaml
5
  import sys
 
 
6
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from pathlib import Path
 
 
8
  from monai.utils import set_determinism
 
 
 
 
 
9
  import logging
 
 
 
10
  from src.model.MIL import MILModel_3D
11
  from src.model.csPCa_model import csPCa_Model
12
  from src.data.data_loader import get_dataloader
13
  from src.utils import save_cspca_checkpoint, get_metrics, setup_logging
14
  from src.train.train_cspca import train_epoch, val_epoch
15
+ import random
16
 
 
17
 
18
+ def main_worker(args):
19
+ mil_model = MILModel_3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
 
 
20
  cache_dir_path = Path(os.path.join(args.logdir, "cache"))
21
 
22
+ if args.mode == "train":
 
23
  checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
24
  mil_model.load_state_dict(checkpoint["state_dict"])
25
  mil_model = mil_model.to(args.device)
26
+
27
+ model_dir = os.path.join(args.logdir, "models")
28
+ os.makedirs(model_dir, exist_ok=True)
29
+
30
+ metrics_dict = {"auc": [], "sensitivity": [], "specificity": []}
31
  for st in list(range(args.num_seeds)):
32
  set_determinism(seed=st)
33
+
34
  train_loader = get_dataloader(args, split="train")
35
  valid_loader = get_dataloader(args, split="test")
36
  cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
37
+ for submodule in [
38
+ cspca_model.backbone.net,
39
+ cspca_model.backbone.myfc,
40
+ cspca_model.backbone.transformer,
41
+ ]:
42
  for param in submodule.parameters():
43
  param.requires_grad = False
44
 
45
+ optimizer = torch.optim.AdamW(
46
+ filter(lambda p: p.requires_grad, cspca_model.parameters()), lr=args.optim_lr
47
+ )
48
 
49
+ old_loss = float("inf")
50
  old_auc = 0.0
51
  for epoch in range(args.epochs):
52
+ train_loss, train_auc = train_epoch(
53
+ cspca_model, train_loader, optimizer, epoch=epoch, args=args
54
+ )
55
+ logging.info(
56
+ f"STATE {st} EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}"
57
+ )
58
  val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
59
+ logging.info(
60
+ f"STATE {st} EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
61
+ )
62
+ val_metric["state"] = st
63
+ if val_metric["loss"] < old_loss:
64
+ old_loss = val_metric["loss"]
65
+ old_auc = val_metric["auc"]
66
+ sensitivity = val_metric["sensitivity"]
67
+ specificity = val_metric["specificity"]
68
+ if not metrics_dict["auc"] or val_metric["auc"] >= max(metrics_dict["auc"]):
69
  save_cspca_checkpoint(cspca_model, val_metric, model_dir)
70
 
71
+ metrics_dict["auc"].append(old_auc)
72
+ metrics_dict["sensitivity"].append(sensitivity)
73
+ metrics_dict["specificity"].append(specificity)
74
  if cache_dir_path.exists() and cache_dir_path.is_dir():
75
  shutil.rmtree(cache_dir_path)
76
 
77
  get_metrics(metrics_dict)
78
 
79
+ elif args.mode == "test":
 
80
  cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
81
  checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
82
+ cspca_model.load_state_dict(checkpt["state_dict"])
83
  cspca_model = cspca_model.to(args.device)
84
+ if "auc" in checkpt and "sensitivity" in checkpt and "specificity" in checkpt:
85
+ auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
86
+ logging.info(
87
+ f"csPCa Model loaded from {args.checkpoint_cspca} with AUC: {auc}, Sensitivity: {sens}, Specificity: {spec} on the test set."
88
+ )
89
  else:
90
  logging.info(f"csPCa Model loaded from {args.checkpoint_cspca}.")
91
+
92
+ metrics_dict = {"auc": [], "sensitivity": [], "specificity": []}
93
  for st in list(range(args.num_seeds)):
94
  set_determinism(seed=st)
95
  test_loader = get_dataloader(args, split="test")
96
  test_metric = val_epoch(cspca_model, test_loader, epoch=0, args=args)
97
+ metrics_dict["auc"].append(test_metric["auc"])
98
+ metrics_dict["sensitivity"].append(test_metric["sensitivity"])
99
+ metrics_dict["specificity"].append(test_metric["specificity"])
100
 
101
  if cache_dir_path.exists() and cache_dir_path.is_dir():
102
  shutil.rmtree(cache_dir_path)
 
104
  get_metrics(metrics_dict)
105
 
106
 
 
 
107
  def parse_args():
108
+ parser = argparse.ArgumentParser(
109
+ description="Multiple Instance Learning (MIL) for csPCa risk prediction."
 
 
 
 
110
  )
111
  parser.add_argument(
112
+ "--mode",
113
+ type=str,
114
+ choices=["train", "test"],
115
+ required=True,
116
+ help="Operation mode: train or infer",
117
  )
118
+ parser.add_argument("--run_name", type=str, default="train_cspca", help="run name for log file")
119
+ parser.add_argument("--config", type=str, help="Path to YAML config file")
120
+ parser.add_argument("--project_dir", default=None, help="path to project firectory")
121
+ parser.add_argument("--data_root", default=None, help="path to root folder of images")
122
  parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
123
  parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
 
124
  parser.add_argument(
125
+ "--mil_mode",
126
+ default="att_trans",
127
+ help="MIL algorithm: choose either att_trans or att_pyramid",
128
  )
 
 
129
  parser.add_argument(
130
+ "--tile_count",
131
+ default=24,
132
+ type=int,
133
+ help="number of patches (instances) to extract from MRI input",
134
  )
135
  parser.add_argument(
136
+ "--tile_size", default=64, type=int, help="size of square patch (instance) in pixels"
137
+ )
138
+ parser.add_argument(
139
+ "--depth", default=3, type=int, help="number of slices in each 3D patch (instance)"
140
+ )
141
+ parser.add_argument(
142
+ "--use_heatmap",
143
+ action="store_true",
144
+ help="enable weak attention heatmap guided patch generation",
145
+ )
146
+ parser.add_argument(
147
+ "--no_heatmap", dest="use_heatmap", action="store_false", help="disable heatmap"
148
  )
149
  parser.set_defaults(use_heatmap=True)
150
  parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
151
+ # parser.add_argument("--dry-run", action="store_true")
152
  parser.add_argument("--checkpoint_pirads", default=None, help="Load PI-RADS model")
153
+ parser.add_argument(
154
+ "--epochs", "--max_epochs", default=30, type=int, help="number of training epochs"
155
+ )
156
  parser.add_argument("--batch_size", default=32, type=int, help="number of MRI scans per batch")
157
  parser.add_argument("--optim_lr", default=2e-4, type=float, help="initial learning rate")
158
+ # parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
159
  parser.add_argument(
160
  "--val_every",
161
  "--val_interval",
 
163
  type=int,
164
  help="run validation after this number of epochs, default 1 to run every epoch",
165
  )
166
+ parser.add_argument(
167
+ "--dry_run", action="store_true", help="Run the script in dry-run mode (default: False)"
168
+ )
169
  parser.add_argument("--checkpoint_cspca", default=None, help="load existing checkpoint")
170
+ parser.add_argument(
171
+ "--num_seeds", default=20, type=int, help="number of seeds to be run to build CI"
172
+ )
173
  args = parser.parse_args()
174
  if args.config:
175
+ with open(args.config, "r") as config_file:
176
  config = yaml.safe_load(config_file)
177
  args.__dict__.update(config)
178
 
 
 
179
  return args
180
 
181
 
 
182
  if __name__ == "__main__":
183
+
184
  args = parse_args()
185
+ if args.project_dir is None:
186
+ args.project_dir = Path(__file__).resolve().parent # Set project directory
187
+
188
+ slurm_job_name = os.getenv('SLURM_JOB_NAME') # If the script is submitted via slurm, job name is the run name
189
+ if slurm_job_name:
190
+ args.run_name = slurm_job_name
191
+
192
  args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
193
  os.makedirs(args.logdir, exist_ok=True)
194
  args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
195
  setup_logging(args.logfile)
196
 
 
197
  logging.info("Argument values:")
198
  for k, v in vars(args).items():
199
  logging.info(f"{k} => {v}")
200
  logging.info("-----------------")
201
 
202
  if args.dataset_json is None:
203
+ logging.error("Dataset path not provided. Quitting.")
204
  sys.exit(1)
205
+ if args.checkpoint_pirads is None and args.mode == "train":
206
+ logging.error("PI-RADS checkpoint path not provided. Quitting.")
207
  sys.exit(1)
208
+ elif args.checkpoint_cspca is None and args.mode == "test":
209
+ logging.error("csPCa checkpoint path not provided. Quitting.")
210
  sys.exit(1)
211
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
212
 
 
221
  args.num_seeds = 2
222
  args.wandb = False
223
 
 
224
  main_worker(args)
run_inference.py CHANGED
@@ -1,66 +1,21 @@
1
  import argparse
2
  import os
3
- import shutil
4
- import time
5
  import yaml
6
- import sys
7
- import gdown
8
- import numpy as np
9
  import torch
10
- import torch.distributed as dist
11
- import torch.multiprocessing as mp
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from monai.config import KeysCollection
15
- from monai.metrics import Cumulative, CumulativeAverage
16
- from monai.networks.nets import milmodel, resnet, MILModel
17
-
18
- from sklearn.metrics import cohen_kappa_score
19
- from torch.cuda.amp import GradScaler, autocast
20
- from torch.utils.data.dataloader import default_collate
21
- from torchvision.models.resnet import ResNet50_Weights
22
- import shutil
23
- from pathlib import Path
24
- from torch.utils.data.distributed import DistributedSampler
25
- from torch.utils.tensorboard import SummaryWriter
26
- from monai.utils import set_determinism
27
- import matplotlib.pyplot as plt
28
- import wandb
29
- import math
30
  import logging
31
- from pathlib import Path
32
-
33
-
34
  from src.model.MIL import MILModel_3D
35
  from src.model.csPCa_model import csPCa_Model
36
- from src.data.data_loader import get_dataloader
37
- from src.utils import save_cspca_checkpoint, get_metrics, setup_logging, save_pirads_checkpoint, get_parent_image, get_patch_coordinate
38
- from src.train import train_cspca, train_pirads
39
- import SimpleITK as sitk
40
-
41
- import nrrd
42
-
43
- from tqdm import tqdm
44
- import pandas as pd
45
- from picai_prep.preprocessing import PreprocessingSettings, Sample
46
- import multiprocessing
47
- import sys
48
  from src.preprocessing.register_and_crop import register_files
49
  from src.preprocessing.prostate_mask import get_segmask
50
  from src.preprocessing.histogram_match import histmatch
51
  from src.preprocessing.generate_heatmap import get_heatmap
52
- import logging
53
- from pathlib import Path
54
- from src.utils import setup_logging
55
- from src.utils import validate_steps
56
- import argparse
57
- import yaml
58
  from src.data.data_loader import data_transform, list_data_collate
59
- from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
60
  import json
61
 
62
- def parse_args():
63
 
 
64
  parser = argparse.ArgumentParser(description="File preprocessing")
65
  parser.add_argument("--config", type=str, help="Path to YAML config file")
66
  parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
@@ -68,7 +23,9 @@ def parse_args():
68
  parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
69
  parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
70
  parser.add_argument("--output_dir", default=None, help="Path to output folder")
71
- parser.add_argument("--margin", default=0.2, type=float, help="Margin to center crop the images")
 
 
72
  parser.add_argument("--num_classes", default=4, type=int)
73
  parser.add_argument("--mil_mode", default="att_trans", type=str)
74
  parser.add_argument("--use_heatmap", default=True, type=bool)
@@ -76,47 +33,49 @@ def parse_args():
76
  parser.add_argument("--tile_count", default=24, type=int)
77
  parser.add_argument("--depth", default=3, type=int)
78
  parser.add_argument("--project_dir", default=None, help="Project directory")
79
-
80
  args = parser.parse_args()
81
  if args.config:
82
- with open(args.config, 'r') as config_file:
83
  config = yaml.safe_load(config_file)
84
  args.__dict__.update(config)
85
  return args
86
 
 
87
  if __name__ == "__main__":
88
  args = parse_args()
89
  FUNCTIONS = {
90
- "register_and_crop": register_files,
91
- "histogram_match": histmatch,
92
- "get_segmentation_mask": get_segmask,
93
- "get_heatmap": get_heatmap,
94
  }
95
 
96
- args.logfile = os.path.join(args.output_dir, f"inference.log")
97
  setup_logging(args.logfile)
98
  logging.info("Starting preprocessing")
99
  steps = ["register_and_crop", "get_segmentation_mask", "histogram_match", "get_heatmap"]
100
  for step in steps:
101
  func = FUNCTIONS[step]
102
- args = func(args)
103
 
104
  logging.info("Preprocessing completed.")
105
 
106
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
107
 
108
  logging.info("Loading PIRADS model")
109
- pirads_model = MILModel_3D(
110
- num_classes=args.num_classes,
111
- mil_mode=args.mil_mode
112
  )
113
- pirads_checkpoint = torch.load(os.path.join(args.project_dir, 'models', 'pirads.pt'), map_location="cpu")
114
  pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
115
  pirads_model.to(args.device)
116
  logging.info("Loading csPCa model")
117
  cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)
118
- checkpt = torch.load(os.path.join(args.project_dir, 'models', 'cspca_model.pth'), map_location="cpu")
119
- cspca_model.load_state_dict(checkpt['state_dict'])
 
 
120
  cspca_model = cspca_model.to(args.device)
121
 
122
  transform = data_transform(args)
@@ -124,12 +83,12 @@ if __name__ == "__main__":
124
  args.data_list = []
125
  for file in files:
126
  temp = {}
127
- temp['image'] = os.path.join(args.t2_dir, file)
128
- temp['dwi'] = os.path.join(args.dwi_dir, file)
129
- temp['adc'] = os.path.join(args.adc_dir, file)
130
- temp['heatmap'] = os.path.join(args.heatmapdir, file)
131
- temp['mask'] = os.path.join(args.seg_dir, file)
132
- temp['label'] = 0 # dummy label
133
  args.data_list.append(temp)
134
 
135
  dataset = Dataset(data=args.data_list, transform=transform)
@@ -139,7 +98,7 @@ if __name__ == "__main__":
139
  shuffle=False,
140
  num_workers=0,
141
  pin_memory=True,
142
- multiprocessing_context= None,
143
  sampler=None,
144
  collate_fn=list_data_collate,
145
  )
@@ -153,7 +112,7 @@ if __name__ == "__main__":
153
  for idx, batch_data in enumerate(loader):
154
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
155
  logits = pirads_model(data)
156
- pirads_score= torch.argmax(logits, dim=1)
157
  pirads_list.append(pirads_score.item())
158
 
159
  output = cspca_model(data)
@@ -181,17 +140,19 @@ if __name__ == "__main__":
181
  for i in args.data_list:
182
  parent_image = get_parent_image([i], args)
183
 
184
- coords = get_patch_coordinate(patches_top_5, parent_image, args)
185
  coords_list.append(coords)
186
  output_dict = {}
187
 
188
- for i,j in enumerate(files):
189
- logging.info(f"File: {j}, PIRADS score: {pirads_list[i]}, csPCa risk score: {cspca_risk_list[i]:.4f}")
190
-
 
191
  output_dict[j] = {
192
- 'Predicted PIRAD Score': pirads_list[i] + 2.0,
193
- 'csPCa risk': cspca_risk_list[i],
194
- 'Top left coordinate of top 5 patches(x,y,z)': coords_list[i],
195
  }
196
- with open(os.path.join(args.output_dir, "results.json"), 'w') as f:
197
- json.dump(output_dict, f, indent=4)
 
 
1
  import argparse
2
  import os
 
 
3
  import yaml
 
 
 
4
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import logging
 
 
 
6
  from src.model.MIL import MILModel_3D
7
  from src.model.csPCa_model import csPCa_Model
8
+ from src.utils import setup_logging, get_parent_image, get_patch_coordinate
 
 
 
 
 
 
 
 
 
 
 
9
  from src.preprocessing.register_and_crop import register_files
10
  from src.preprocessing.prostate_mask import get_segmask
11
  from src.preprocessing.histogram_match import histmatch
12
  from src.preprocessing.generate_heatmap import get_heatmap
 
 
 
 
 
 
13
  from src.data.data_loader import data_transform, list_data_collate
14
+ from monai.data import Dataset
15
  import json
16
 
 
17
 
18
+ def parse_args():
19
  parser = argparse.ArgumentParser(description="File preprocessing")
20
  parser.add_argument("--config", type=str, help="Path to YAML config file")
21
  parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
 
23
  parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
24
  parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
25
  parser.add_argument("--output_dir", default=None, help="Path to output folder")
26
+ parser.add_argument(
27
+ "--margin", default=0.2, type=float, help="Margin to center crop the images"
28
+ )
29
  parser.add_argument("--num_classes", default=4, type=int)
30
  parser.add_argument("--mil_mode", default="att_trans", type=str)
31
  parser.add_argument("--use_heatmap", default=True, type=bool)
 
33
  parser.add_argument("--tile_count", default=24, type=int)
34
  parser.add_argument("--depth", default=3, type=int)
35
  parser.add_argument("--project_dir", default=None, help="Project directory")
36
+
37
  args = parser.parse_args()
38
  if args.config:
39
+ with open(args.config, "r") as config_file:
40
  config = yaml.safe_load(config_file)
41
  args.__dict__.update(config)
42
  return args
43
 
44
+
45
  if __name__ == "__main__":
46
  args = parse_args()
47
  FUNCTIONS = {
48
+ "register_and_crop": register_files,
49
+ "histogram_match": histmatch,
50
+ "get_segmentation_mask": get_segmask,
51
+ "get_heatmap": get_heatmap,
52
  }
53
 
54
+ args.logfile = os.path.join(args.output_dir, "inference.log")
55
  setup_logging(args.logfile)
56
  logging.info("Starting preprocessing")
57
  steps = ["register_and_crop", "get_segmentation_mask", "histogram_match", "get_heatmap"]
58
  for step in steps:
59
  func = FUNCTIONS[step]
60
+ args = func(args)
61
 
62
  logging.info("Preprocessing completed.")
63
 
64
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
 
66
  logging.info("Loading PIRADS model")
67
+ pirads_model = MILModel_3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
68
+ pirads_checkpoint = torch.load(
69
+ os.path.join(args.project_dir, "models", "pirads.pt"), map_location="cpu"
70
  )
 
71
  pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
72
  pirads_model.to(args.device)
73
  logging.info("Loading csPCa model")
74
  cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)
75
+ checkpt = torch.load(
76
+ os.path.join(args.project_dir, "models", "cspca_model.pth"), map_location="cpu"
77
+ )
78
+ cspca_model.load_state_dict(checkpt["state_dict"])
79
  cspca_model = cspca_model.to(args.device)
80
 
81
  transform = data_transform(args)
 
83
  args.data_list = []
84
  for file in files:
85
  temp = {}
86
+ temp["image"] = os.path.join(args.t2_dir, file)
87
+ temp["dwi"] = os.path.join(args.dwi_dir, file)
88
+ temp["adc"] = os.path.join(args.adc_dir, file)
89
+ temp["heatmap"] = os.path.join(args.heatmapdir, file)
90
+ temp["mask"] = os.path.join(args.seg_dir, file)
91
+ temp["label"] = 0 # dummy label
92
  args.data_list.append(temp)
93
 
94
  dataset = Dataset(data=args.data_list, transform=transform)
 
98
  shuffle=False,
99
  num_workers=0,
100
  pin_memory=True,
101
+ multiprocessing_context=None,
102
  sampler=None,
103
  collate_fn=list_data_collate,
104
  )
 
112
  for idx, batch_data in enumerate(loader):
113
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
114
  logits = pirads_model(data)
115
+ pirads_score = torch.argmax(logits, dim=1)
116
  pirads_list.append(pirads_score.item())
117
 
118
  output = cspca_model(data)
 
140
  for i in args.data_list:
141
  parent_image = get_parent_image([i], args)
142
 
143
+ coords = get_patch_coordinate(patches_top_5, parent_image)
144
  coords_list.append(coords)
145
  output_dict = {}
146
 
147
+ for i, j in enumerate(files):
148
+ logging.info(
149
+ f"File: {j}, PIRADS score: {pirads_list[i] + 2.0}, csPCa risk score: {cspca_risk_list[i]:.4f}"
150
+ )
151
  output_dict[j] = {
152
+ "Predicted PIRAD Score": pirads_list[i] + 2.0,
153
+ "csPCa risk": cspca_risk_list[i],
154
+ "Top left coordinate of top 5 patches(x,y,z)": coords_list[i],
155
  }
156
+
157
+ with open(os.path.join(args.output_dir, "results.json"), "w") as f:
158
+ json.dump(output_dict, f, indent=4)
run_pirads.py CHANGED
@@ -1,32 +1,14 @@
1
  import argparse
2
- import collections.abc
3
  import os
4
  import shutil
5
  import time
6
  import yaml
7
  import sys
8
- import gdown
9
  import numpy as np
10
  import torch
11
- import torch.distributed as dist
12
- import torch.multiprocessing as mp
13
- import torch.nn as nn
14
- import torch.nn.functional as F
15
- from monai.config import KeysCollection
16
- from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
17
- from monai.data.wsi_reader import WSIReader
18
- from monai.metrics import Cumulative, CumulativeAverage
19
- from monai.networks.nets import milmodel, resnet, MILModel
20
-
21
- from sklearn.metrics import cohen_kappa_score
22
- from torch.cuda.amp import GradScaler, autocast
23
- from torch.utils.data.dataloader import default_collate
24
  from torch.utils.tensorboard import SummaryWriter
25
  from monai.utils import set_determinism
26
- import matplotlib.pyplot as plt
27
-
28
  import wandb
29
- import math
30
  import logging
31
  from pathlib import Path
32
  from src.data.data_loader import get_dataloader
@@ -37,13 +19,9 @@ from src.utils import save_pirads_checkpoint, setup_logging
37
 
38
  def main_worker(args):
39
  if args.device == torch.device("cuda"):
40
- torch.cuda.set_device(args.gpu) # use this default device (same as args.device if not distributed)
41
  torch.backends.cudnn.benchmark = True
42
 
43
- model = MILModel_3D(
44
- num_classes=args.num_classes,
45
- mil_mode=args.mil_mode
46
- )
47
  start_epoch = 0
48
  best_acc = 0.0
49
  if args.checkpoint is not None:
@@ -54,41 +32,54 @@ def main_worker(args):
54
  start_epoch = checkpoint["epoch"]
55
  if "best_acc" in checkpoint:
56
  best_acc = checkpoint["best_acc"]
57
- logging.info("=> loaded checkpoint %s (epoch %d) (bestacc %f)",args.checkpoint, start_epoch, best_acc)
 
 
 
 
 
58
  cache_dir_ = os.path.join(args.logdir, "cache")
59
  model.to(args.device)
60
  params = model.parameters()
61
- if args.mode == 'train':
62
- train_loader = get_dataloader(args, split=args.mode)
63
  valid_loader = get_dataloader(args, split="test")
64
- logging.info("Dataset training:", str(len(train_loader.dataset)), "test:", str(len(valid_loader.dataset)))
65
-
 
 
66
  if args.mil_mode in ["att_trans", "att_trans_pyramid"]:
67
  params = [
68
- {"params": list(model.attention.parameters()) + list(model.myfc.parameters()) + list(model.net.parameters())},
 
 
 
 
69
  {"params": list(model.transformer.parameters()), "lr": 6e-5, "weight_decay": 0.1},
70
  ]
71
 
72
  optimizer = torch.optim.AdamW(params, lr=args.optim_lr, weight_decay=args.weight_decay)
73
- scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0)
 
 
74
  scaler = torch.amp.GradScaler(device=str(args.device), enabled=args.amp)
75
 
76
  if args.logdir is not None:
77
  writer = SummaryWriter(log_dir=args.logdir)
78
- logging.info("Writing Tensorboard logs to ", writer.log_dir)
79
  else:
80
  writer = None
81
 
82
-
83
  # RUN TRAINING
84
  n_epochs = args.epochs
85
  val_loss_min = float("inf")
86
  epochs_no_improve = 0
87
  for epoch in range(start_epoch, n_epochs):
88
-
89
  logging.info(time.ctime(), "Epoch:", epoch)
90
  epoch_time = time.time()
91
- train_loss, train_acc, train_att_loss, batch_norm = train_epoch(model, train_loader, optimizer, scaler=scaler, epoch=epoch, args=args)
 
 
92
  logging.info(
93
  "Final training %d/%d loss: %.4f attention loss: %.4f acc: %.4f time %.2fs",
94
  epoch,
@@ -98,13 +89,21 @@ def main_worker(args):
98
  train_acc,
99
  time.time() - epoch_time,
100
  )
101
-
102
  if writer is not None:
103
  writer.add_scalar("train_loss", train_loss, epoch)
104
  writer.add_scalar("train_attention_loss", train_att_loss, epoch)
105
  writer.add_scalar("train_acc", train_acc, epoch)
106
- wandb.log({"Train Loss": train_loss, "Train Accuracy": train_acc, "Train Attention Loss": train_att_loss, "Batch Norm": batch_norm}, step=epoch)
107
-
 
 
 
 
 
 
 
 
108
  model_new_best = False
109
  val_acc = 0
110
  if (epoch + 1) % args.val_every == 0:
@@ -125,35 +124,39 @@ def main_worker(args):
125
  writer.add_scalar("test_acc", val_acc, epoch)
126
  writer.add_scalar("test_qwk", qwk, epoch)
127
 
128
- #val_acc = qwk
129
- wandb.log({"Test Loss": val_loss, "Test Accuracy": val_acc,"Cohen Kappa": qwk}, step=epoch)
 
 
 
130
  if val_loss < val_loss_min:
131
  logging.info("Loss (%.6f --> %.6f)", val_loss_min, val_loss)
132
  val_loss_min = val_loss
133
  model_new_best = True
134
 
135
  if args.logdir is not None:
136
- save_pirads_checkpoint(model, epoch, args, best_acc=val_acc, filename=f"model_{epoch}.pt")
 
 
137
  if model_new_best:
138
- logging.info("Copying to model.pt new best model!!!!")
139
- shutil.copyfile(os.path.join(args.logdir, f"model_{epoch}.pt"), os.path.join(args.logdir, "model.pt"))
 
 
 
140
  epochs_no_improve = 0
141
-
142
  else:
143
  epochs_no_improve += 1
144
  if epochs_no_improve == args.early_stop:
145
- logging.info('Early stopping!')
146
  break
147
-
148
-
149
 
150
  scheduler.step()
151
 
152
  logging.info("ALL DONE")
153
-
154
- elif args.mode == 'test':
155
 
156
-
157
  kappa_list = []
158
  for seed in list(range(args.num_seeds)):
159
  set_determinism(seed=seed)
@@ -163,51 +166,73 @@ def main_worker(args):
163
  kappa_list.append(qwk)
164
  logging.info(f"Seed {seed}, QWK: {qwk}")
165
  if os.path.exists(cache_dir_):
166
- logging.info("Removing cache directory ", cache_dir_)
167
  shutil.rmtree(cache_dir_)
168
-
169
  logging.info(f"Mean QWK over {args.num_seeds} seeds: {np.mean(kappa_list)}")
170
 
171
-
172
  if os.path.exists(cache_dir_):
173
- logging.info("Removing cache directory ", cache_dir_)
174
  shutil.rmtree(cache_dir_)
175
 
176
 
177
  def parse_args():
178
- parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) for PIRADS Classification.")
179
- parser.add_argument('--mode', type=str, choices=['train', 'test'], required=True, help='operation mode: train or infer')
180
- parser.add_argument('--wandb', action='store_true', help='Add this flag to enable WandB logging')
181
- parser.add_argument('--project_name', type=str, default='Classification_prostate', help='WandB project name')
182
- parser.add_argument('--run_name', type=str, default='train_pirads', help='run name for WandB logging')
183
- parser.add_argument('--config', type=str, help='path to YAML config file')
 
 
 
 
 
 
 
184
  parser.add_argument(
185
- "--project_dir", default=None, help="path to project firectory"
186
  )
187
  parser.add_argument(
188
- "--data_root", default=None, help="path to root folder of images"
189
  )
 
 
 
190
  parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
191
  parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
192
- parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm: choose either att_trans or att_pyramid")
193
  parser.add_argument(
194
- "--tile_count", default=24, type=int, help="number of patches (instances) to extract from MRI input"
 
 
195
  )
196
- parser.add_argument("--tile_size", default=64, type=int, help="size of square patch (instance) in pixels")
197
- parser.add_argument("--depth", default=3, type=int, help="number of slices in each 3D patch (instance)")
198
  parser.add_argument(
199
- "--use_heatmap", action="store_true",
200
- help="enable weak attention heatmap guided patch generation"
 
 
 
 
 
 
 
 
201
  )
202
  parser.add_argument(
203
- "--no_heatmap", dest="use_heatmap", action="store_false",
204
- help="disable heatmap"
 
 
 
 
205
  )
206
  parser.set_defaults(use_heatmap=True)
207
  parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
208
 
209
  parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
210
- parser.add_argument("--epochs", "--max_epochs", default=50, type=int, help="number of training epochs")
 
 
211
  parser.add_argument("--early_stop", default=40, type=int, help="early stopping criteria")
212
  parser.add_argument("--batch_size", default=4, type=int, help="number of MRI scans per batch")
213
  parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
@@ -220,19 +245,27 @@ def parse_args():
220
  type=int,
221
  help="run validation after this number of epochs, default 1 to run every epoch",
222
  )
223
- parser.add_argument("--dry_run", action="store_true", help="Run the script in dry-run mode (default: False)")
 
 
224
  args = parser.parse_args()
225
  if args.config:
226
- with open(args.config, 'r') as config_file:
227
  config = yaml.safe_load(config_file)
228
  args.__dict__.update(config)
229
  return args
230
-
231
-
232
 
233
 
234
  if __name__ == "__main__":
 
235
  args = parse_args()
 
 
 
 
 
 
 
236
  args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
237
  os.makedirs(args.logdir, exist_ok=True)
238
  args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
@@ -243,15 +276,14 @@ if __name__ == "__main__":
243
  logging.info(f"{k} => {v}")
244
  logging.info("-----------------")
245
 
246
- args.num_seeds = 10
247
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
248
  if args.device == torch.device("cpu"):
249
  args.amp = False
250
  if args.dataset_json is None:
251
- logging.error('Dataset JSON file not provided. Quitting.')
252
  sys.exit(1)
253
- if args.checkpoint is None and args.mode == 'test':
254
- logging.error('Model checkpoint path not provided. Quitting.')
255
  sys.exit(1)
256
 
257
  if args.dry_run:
@@ -261,8 +293,8 @@ if __name__ == "__main__":
261
  args.workers = 0
262
  args.num_seeds = 2
263
  args.wandb = False
264
-
265
- mode_wandb = "online" if args.wandb else "disabled"
266
 
267
  config_wandb = {
268
  "learning_rate": args.optim_lr,
@@ -271,13 +303,14 @@ if __name__ == "__main__":
271
  "patch size": args.tile_size,
272
  "patch count": args.tile_count,
273
  }
274
- wandb.init(project=args.project_name,
275
- name=args.run_name,
276
- dir=os.path.join(args.logdir, "wandb"),
277
- config=config_wandb,
278
- mode=mode_wandb)
279
-
 
280
 
281
  main_worker(args)
282
-
283
  wandb.finish()
 
1
  import argparse
 
2
  import os
3
  import shutil
4
  import time
5
  import yaml
6
  import sys
 
7
  import numpy as np
8
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from torch.utils.tensorboard import SummaryWriter
10
  from monai.utils import set_determinism
 
 
11
  import wandb
 
12
  import logging
13
  from pathlib import Path
14
  from src.data.data_loader import get_dataloader
 
19
 
20
  def main_worker(args):
21
  if args.device == torch.device("cuda"):
 
22
  torch.backends.cudnn.benchmark = True
23
 
24
+ model = MILModel_3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
 
 
 
25
  start_epoch = 0
26
  best_acc = 0.0
27
  if args.checkpoint is not None:
 
32
  start_epoch = checkpoint["epoch"]
33
  if "best_acc" in checkpoint:
34
  best_acc = checkpoint["best_acc"]
35
+ logging.info(
36
+ "=> loaded checkpoint %s (epoch %d) (bestacc %f)",
37
+ args.checkpoint,
38
+ start_epoch,
39
+ best_acc,
40
+ )
41
  cache_dir_ = os.path.join(args.logdir, "cache")
42
  model.to(args.device)
43
  params = model.parameters()
44
+ if args.mode == "train":
45
+ train_loader = get_dataloader(args, split="train")
46
  valid_loader = get_dataloader(args, split="test")
47
+ logging.info(
48
+ f"Dataset training: {len(train_loader.dataset)}, test: {len(valid_loader.dataset)}"
49
+ )
50
+
51
  if args.mil_mode in ["att_trans", "att_trans_pyramid"]:
52
  params = [
53
+ {
54
+ "params": list(model.attention.parameters())
55
+ + list(model.myfc.parameters())
56
+ + list(model.net.parameters())
57
+ },
58
  {"params": list(model.transformer.parameters()), "lr": 6e-5, "weight_decay": 0.1},
59
  ]
60
 
61
  optimizer = torch.optim.AdamW(params, lr=args.optim_lr, weight_decay=args.weight_decay)
62
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
63
+ optimizer, T_max=args.epochs, eta_min=0
64
+ )
65
  scaler = torch.amp.GradScaler(device=str(args.device), enabled=args.amp)
66
 
67
  if args.logdir is not None:
68
  writer = SummaryWriter(log_dir=args.logdir)
69
+ logging.info(f"Writing Tensorboard logs to {writer.log_dir}")
70
  else:
71
  writer = None
72
 
 
73
  # RUN TRAINING
74
  n_epochs = args.epochs
75
  val_loss_min = float("inf")
76
  epochs_no_improve = 0
77
  for epoch in range(start_epoch, n_epochs):
 
78
  logging.info(time.ctime(), "Epoch:", epoch)
79
  epoch_time = time.time()
80
+ train_loss, train_acc, train_att_loss, batch_norm = train_epoch(
81
+ model, train_loader, optimizer, scaler=scaler, epoch=epoch, args=args
82
+ )
83
  logging.info(
84
  "Final training %d/%d loss: %.4f attention loss: %.4f acc: %.4f time %.2fs",
85
  epoch,
 
89
  train_acc,
90
  time.time() - epoch_time,
91
  )
92
+
93
  if writer is not None:
94
  writer.add_scalar("train_loss", train_loss, epoch)
95
  writer.add_scalar("train_attention_loss", train_att_loss, epoch)
96
  writer.add_scalar("train_acc", train_acc, epoch)
97
+ wandb.log(
98
+ {
99
+ "Train Loss": train_loss,
100
+ "Train Accuracy": train_acc,
101
+ "Train Attention Loss": train_att_loss,
102
+ "Batch Norm": batch_norm,
103
+ },
104
+ step=epoch,
105
+ )
106
+
107
  model_new_best = False
108
  val_acc = 0
109
  if (epoch + 1) % args.val_every == 0:
 
124
  writer.add_scalar("test_acc", val_acc, epoch)
125
  writer.add_scalar("test_qwk", qwk, epoch)
126
 
127
+ # val_acc = qwk
128
+ wandb.log(
129
+ {"Test Loss": val_loss, "Test Accuracy": val_acc, "Cohen Kappa": qwk},
130
+ step=epoch,
131
+ )
132
  if val_loss < val_loss_min:
133
  logging.info("Loss (%.6f --> %.6f)", val_loss_min, val_loss)
134
  val_loss_min = val_loss
135
  model_new_best = True
136
 
137
  if args.logdir is not None:
138
+ save_pirads_checkpoint(
139
+ model, epoch, args, best_acc=val_acc, filename=f"model_{epoch}.pt"
140
+ )
141
  if model_new_best:
142
+ logging.info("Copying to model.pt new best model")
143
+ shutil.copyfile(
144
+ os.path.join(args.logdir, f"model_{epoch}.pt"),
145
+ os.path.join(args.logdir, "model.pt"),
146
+ )
147
  epochs_no_improve = 0
148
+
149
  else:
150
  epochs_no_improve += 1
151
  if epochs_no_improve == args.early_stop:
152
+ logging.info("Early stopping!")
153
  break
 
 
154
 
155
  scheduler.step()
156
 
157
  logging.info("ALL DONE")
 
 
158
 
159
+ elif args.mode == "test":
160
  kappa_list = []
161
  for seed in list(range(args.num_seeds)):
162
  set_determinism(seed=seed)
 
166
  kappa_list.append(qwk)
167
  logging.info(f"Seed {seed}, QWK: {qwk}")
168
  if os.path.exists(cache_dir_):
169
+ logging.info(f"Removing cache directory {cache_dir_}")
170
  shutil.rmtree(cache_dir_)
171
+
172
  logging.info(f"Mean QWK over {args.num_seeds} seeds: {np.mean(kappa_list)}")
173
 
 
174
  if os.path.exists(cache_dir_):
175
+ logging.info(f"Removing cache directory {cache_dir_}")
176
  shutil.rmtree(cache_dir_)
177
 
178
 
179
  def parse_args():
180
+ parser = argparse.ArgumentParser(
181
+ description="Multiple Instance Learning (MIL) for PIRADS Classification."
182
+ )
183
+ parser.add_argument(
184
+ "--mode",
185
+ type=str,
186
+ choices=["train", "test"],
187
+ required=True,
188
+ help="operation mode: train or infer",
189
+ )
190
+ parser.add_argument(
191
+ "--wandb", action="store_true", help="Add this flag to enable WandB logging"
192
+ )
193
  parser.add_argument(
194
+ "--project_name", type=str, default="Classification_prostate", help="WandB project name"
195
  )
196
  parser.add_argument(
197
+ "--run_name", type=str, default="train_pirads", help="run name for WandB logging"
198
  )
199
+ parser.add_argument("--config", type=str, help="path to YAML config file")
200
+ parser.add_argument("--project_dir", default=None, help="path to project firectory")
201
+ parser.add_argument("--data_root", default=None, help="path to root folder of images")
202
  parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
203
  parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
 
204
  parser.add_argument(
205
+ "--mil_mode",
206
+ default="att_trans",
207
+ help="MIL algorithm: choose either att_trans or att_pyramid",
208
  )
 
 
209
  parser.add_argument(
210
+ "--tile_count",
211
+ default=24,
212
+ type=int,
213
+ help="number of patches (instances) to extract from MRI input",
214
+ )
215
+ parser.add_argument(
216
+ "--tile_size", default=64, type=int, help="size of square patch (instance) in pixels"
217
+ )
218
+ parser.add_argument(
219
+ "--depth", default=3, type=int, help="number of slices in each 3D patch (instance)"
220
  )
221
  parser.add_argument(
222
+ "--use_heatmap",
223
+ action="store_true",
224
+ help="enable weak attention heatmap guided patch generation",
225
+ )
226
+ parser.add_argument(
227
+ "--no_heatmap", dest="use_heatmap", action="store_false", help="disable heatmap"
228
  )
229
  parser.set_defaults(use_heatmap=True)
230
  parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
231
 
232
  parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
233
+ parser.add_argument(
234
+ "--epochs", "--max_epochs", default=50, type=int, help="number of training epochs"
235
+ )
236
  parser.add_argument("--early_stop", default=40, type=int, help="early stopping criteria")
237
  parser.add_argument("--batch_size", default=4, type=int, help="number of MRI scans per batch")
238
  parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
 
245
  type=int,
246
  help="run validation after this number of epochs, default 1 to run every epoch",
247
  )
248
+ parser.add_argument(
249
+ "--dry_run", action="store_true", help="Run the script in dry-run mode (default: False)"
250
+ )
251
  args = parser.parse_args()
252
  if args.config:
253
+ with open(args.config, "r") as config_file:
254
  config = yaml.safe_load(config_file)
255
  args.__dict__.update(config)
256
  return args
 
 
257
 
258
 
259
  if __name__ == "__main__":
260
+
261
  args = parse_args()
262
+ if args.project_dir is None:
263
+ args.project_dir = Path(__file__).resolve().parent # Set project directory
264
+
265
+ slurm_job_name = os.getenv('SLURM_JOB_NAME') # If the script is submitted via slurm, job name is the run name
266
+ if slurm_job_name:
267
+ args.run_name = slurm_job_name
268
+
269
  args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
270
  os.makedirs(args.logdir, exist_ok=True)
271
  args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
 
276
  logging.info(f"{k} => {v}")
277
  logging.info("-----------------")
278
 
 
279
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
280
  if args.device == torch.device("cpu"):
281
  args.amp = False
282
  if args.dataset_json is None:
283
+ logging.error("Dataset JSON file not provided. Quitting.")
284
  sys.exit(1)
285
+ if args.checkpoint is None and args.mode == "test":
286
+ logging.error("Model checkpoint path not provided. Quitting.")
287
  sys.exit(1)
288
 
289
  if args.dry_run:
 
293
  args.workers = 0
294
  args.num_seeds = 2
295
  args.wandb = False
296
+
297
+ mode_wandb = "online" if args.wandb and args.mode != "test" else "disabled"
298
 
299
  config_wandb = {
300
  "learning_rate": args.optim_lr,
 
303
  "patch size": args.tile_size,
304
  "patch count": args.tile_count,
305
  }
306
+ wandb.init(
307
+ project=args.project_name,
308
+ name=args.run_name,
309
+ dir=os.path.join(args.logdir, "wandb"),
310
+ config=config_wandb,
311
+ mode=mode_wandb,
312
+ )
313
 
314
  main_worker(args)
315
+
316
  wandb.finish()
src/data/custom_transforms.py CHANGED
@@ -1,27 +1,25 @@
1
  import numpy as np
2
  import torch
3
- from typing import Union, Optional
4
-
5
  from monai.transforms import MapTransform
6
  from monai.config import DtypeLike, KeysCollection
7
- from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
8
  from monai.data.meta_obj import get_track_meta
9
  from monai.transforms.transform import Transform
10
  from monai.transforms.utils import soft_clip
11
- from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
12
  from monai.utils.enums import TransformBackends
13
- from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype
14
  from scipy.ndimage import binary_dilation
15
  import cv2
16
- from typing import Union, Sequence
17
  from collections.abc import Hashable, Mapping, Sequence
18
 
19
 
20
-
21
  class DilateAndSaveMaskd(MapTransform):
22
  """
23
  Custom transform to dilate binary mask and save a copy.
24
  """
 
25
  def __init__(self, keys, dilation_size=10, copy_key="original_mask"):
26
  super().__init__(keys)
27
  self.dilation_size = dilation_size
@@ -29,37 +27,61 @@ class DilateAndSaveMaskd(MapTransform):
29
 
30
  def __call__(self, data):
31
  d = dict(data)
32
-
33
  for key in self.keys:
34
  mask = d[key].numpy() if isinstance(d[key], torch.Tensor) else d[key]
35
  mask = mask.squeeze(0) # Remove channel dimension if present
36
 
37
  # Save a copy of the original mask
38
- d[self.copy_key] = torch.tensor(mask, dtype=torch.float32).unsqueeze(0) # Save to a new key
 
 
39
 
40
  # Apply binary dilation to the mask
41
  dilated_mask = binary_dilation(mask, iterations=self.dilation_size).astype(np.uint8)
42
 
43
  # Store the dilated mask
44
- d[key] = torch.tensor(dilated_mask, dtype=torch.float32).unsqueeze(0) # Add channel dimension back
 
 
45
 
46
  return d
47
 
48
 
49
  class ClipMaskIntensityPercentiles(Transform):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
 
53
 
54
  def __init__(
55
  self,
56
  lower: Union[float, None],
57
  upper: Union[float, None],
58
- sharpness_factor : Union[float, None] = None,
59
  channel_wise: bool = False,
60
  dtype: DtypeLike = np.float32,
61
  ) -> None:
62
-
63
  if lower is None and upper is None:
64
  raise ValueError("lower or upper percentiles must be provided")
65
  if lower is not None and (lower < 0.0 or lower > 100.0):
@@ -71,7 +93,7 @@ class ClipMaskIntensityPercentiles(Transform):
71
  if sharpness_factor is not None and sharpness_factor <= 0:
72
  raise ValueError("sharpness_factor must be greater than 0")
73
 
74
- #self.mask_data = mask_data
75
  self.lower = lower
76
  self.upper = upper
77
  self.sharpness_factor = sharpness_factor
@@ -81,14 +103,26 @@ class ClipMaskIntensityPercentiles(Transform):
81
  def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
82
  masked_img = img * (mask_data > 0)
83
  if self.sharpness_factor is not None:
84
-
85
- lower_percentile = percentile(masked_img, self.lower) if self.lower is not None else None
86
- upper_percentile = percentile(masked_img, self.upper) if self.upper is not None else None
87
- img = soft_clip(img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype)
 
 
 
 
 
88
  else:
89
-
90
- lower_percentile = percentile(masked_img, self.lower) if self.lower is not None else percentile(masked_img, 0)
91
- upper_percentile = percentile(masked_img, self.upper) if self.upper is not None else percentile(masked_img, 100)
 
 
 
 
 
 
 
92
  img = clip(img, lower_percentile, upper_percentile)
93
 
94
  img = convert_to_tensor(img, track_meta=False)
@@ -102,7 +136,9 @@ class ClipMaskIntensityPercentiles(Transform):
102
  img_t = convert_to_tensor(img, track_meta=False)
103
  mask_t = convert_to_tensor(mask_data, track_meta=False)
104
  if self.channel_wise:
105
- img_t = torch.stack([self._clip(img=d, mask_data=mask_t[e]) for e,d in enumerate(img_t)]) # type: ignore
 
 
106
  else:
107
  img_t = self._clip(img=img_t, mask_data=mask_t)
108
 
@@ -110,7 +146,28 @@ class ClipMaskIntensityPercentiles(Transform):
110
 
111
  return img
112
 
 
113
  class ClipMaskIntensityPercentilesd(MapTransform):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  def __init__(
116
  self,
@@ -125,7 +182,11 @@ class ClipMaskIntensityPercentilesd(MapTransform):
125
  ) -> None:
126
  super().__init__(keys, allow_missing_keys)
127
  self.scaler = ClipMaskIntensityPercentiles(
128
- lower=lower, upper=upper, sharpness_factor=sharpness_factor, channel_wise=channel_wise, dtype=dtype
 
 
 
 
129
  )
130
  self.mask_key = mask_key
131
 
@@ -134,16 +195,32 @@ class ClipMaskIntensityPercentilesd(MapTransform):
134
  for key in self.key_iterator(d):
135
  d[key] = self.scaler(d[key], d[self.mask_key])
136
  return d
137
-
138
-
139
 
140
 
141
  class ElementwiseProductd(MapTransform):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def __init__(self, keys: KeysCollection, output_key: str) -> None:
143
  super().__init__(keys)
144
  self.output_key = output_key
145
 
146
- def __call__(self, data) -> NdarrayOrTensor:
147
  d = dict(data)
148
  d[self.output_key] = d[self.keys[0]] * d[self.keys[1]]
149
  return d
@@ -159,6 +236,7 @@ class CLAHEd(MapTransform):
159
  clip_limit (float): Threshold for contrast limiting. Default is 2.0.
160
  tile_grid_size (Union[tuple, Sequence[int]]): Size of grid for histogram equalization (default: (8,8)).
161
  """
 
162
  def __init__(
163
  self,
164
  keys: KeysCollection,
@@ -184,13 +262,13 @@ class CLAHEd(MapTransform):
184
 
185
  image_clahe = np.stack([clahe.apply(slice) for slice in image[0]])
186
 
187
-
188
  # Convert back to float in [0,1]
189
  processed_img = image_clahe.astype(np.float32) / 255.0
190
  reshaped_ = processed_img.reshape(1, *processed_img.shape)
191
  d[key] = torch.from_numpy(reshaped_).to(image_.device)
192
  return d
193
-
 
194
  class NormalizeIntensity_custom(Transform):
195
  """
196
  Normalize input based on the `subtrahend` and `divisor`: `(img - subtrahend) / divisor`.
@@ -240,9 +318,11 @@ class NormalizeIntensity_custom(Transform):
240
  x = torch.std(x.float(), unbiased=False)
241
  return x.item() if x.numel() == 1 else x
242
 
243
- def _normalize(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTensor:
 
 
244
  img, *_ = convert_data_type(img, dtype=torch.float32)
245
- '''
246
  if self.nonzero:
247
  slices = img != 0
248
  masked_img = img[slices]
@@ -251,7 +331,7 @@ class NormalizeIntensity_custom(Transform):
251
  else:
252
  slices = None
253
  masked_img = img
254
- '''
255
  slices = None
256
  mask_data = mask_data.squeeze(0)
257
  slices_mask = mask_data > 0
@@ -288,9 +368,13 @@ class NormalizeIntensity_custom(Transform):
288
  dtype = self.dtype or img.dtype
289
  if self.channel_wise:
290
  if self.subtrahend is not None and len(self.subtrahend) != len(img):
291
- raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.")
 
 
292
  if self.divisor is not None and len(self.divisor) != len(img):
293
- raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")
 
 
294
 
295
  if not img.dtype.is_floating_point:
296
  img, *_ = convert_data_type(img, dtype=torch.float32)
@@ -308,21 +392,27 @@ class NormalizeIntensity_custom(Transform):
308
  out = convert_to_dst_type(img, img, dtype=dtype)[0]
309
  return out
310
 
 
311
  class NormalizeIntensity_customd(MapTransform):
312
  """
313
- Dictionary-based wrapper of :py:class:`monai.transforms.NormalizeIntensity`.
314
- This transform can normalize only non-zero values or entire image, and can also calculate
315
- mean and std on each channel separately.
 
316
 
317
  Args:
318
  keys: keys of the corresponding items to be transformed.
319
- See also: monai.transforms.MapTransform
320
- subtrahend: the amount to subtract by (usually the mean)
321
- divisor: the amount to divide by (usually the standard deviation)
 
 
 
 
322
  nonzero: whether only normalize non-zero values.
323
  channel_wise: if True, calculate on each channel separately, otherwise, calculate on
324
- the entire image directly. default to False.
325
- dtype: output data type, if None, same as input image. defaults to float32.
326
  allow_missing_keys: don't raise exception if key is missing.
327
  """
328
 
@@ -332,19 +422,21 @@ class NormalizeIntensity_customd(MapTransform):
332
  self,
333
  keys: KeysCollection,
334
  mask_key: str,
335
- subtrahend:Union[ NdarrayOrTensor, None] = None,
336
- divisor: Union[ NdarrayOrTensor, None] = None,
337
  nonzero: bool = False,
338
  channel_wise: bool = False,
339
  dtype: DtypeLike = np.float32,
340
  allow_missing_keys: bool = False,
341
  ) -> None:
342
  super().__init__(keys, allow_missing_keys)
343
- self.normalizer = NormalizeIntensity_custom(subtrahend, divisor, nonzero, channel_wise, dtype)
 
 
344
  self.mask_key = mask_key
345
 
346
  def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
347
  d = dict(data)
348
  for key in self.key_iterator(d):
349
  d[key] = self.normalizer(d[key], d[self.mask_key])
350
- return d
 
1
  import numpy as np
2
  import torch
3
+ from typing import Union
 
4
  from monai.transforms import MapTransform
5
  from monai.config import DtypeLike, KeysCollection
6
+ from monai.config.type_definitions import NdarrayOrTensor
7
  from monai.data.meta_obj import get_track_meta
8
  from monai.transforms.transform import Transform
9
  from monai.transforms.utils import soft_clip
10
+ from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
11
  from monai.utils.enums import TransformBackends
12
+ from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor
13
  from scipy.ndimage import binary_dilation
14
  import cv2
 
15
  from collections.abc import Hashable, Mapping, Sequence
16
 
17
 
 
18
  class DilateAndSaveMaskd(MapTransform):
19
  """
20
  Custom transform to dilate binary mask and save a copy.
21
  """
22
+
23
  def __init__(self, keys, dilation_size=10, copy_key="original_mask"):
24
  super().__init__(keys)
25
  self.dilation_size = dilation_size
 
27
 
28
  def __call__(self, data):
29
  d = dict(data)
30
+
31
  for key in self.keys:
32
  mask = d[key].numpy() if isinstance(d[key], torch.Tensor) else d[key]
33
  mask = mask.squeeze(0) # Remove channel dimension if present
34
 
35
  # Save a copy of the original mask
36
+ d[self.copy_key] = torch.tensor(mask, dtype=torch.float32).unsqueeze(
37
+ 0
38
+ ) # Save to a new key
39
 
40
  # Apply binary dilation to the mask
41
  dilated_mask = binary_dilation(mask, iterations=self.dilation_size).astype(np.uint8)
42
 
43
  # Store the dilated mask
44
+ d[key] = torch.tensor(dilated_mask, dtype=torch.float32).unsqueeze(
45
+ 0
46
+ ) # Add channel dimension back
47
 
48
  return d
49
 
50
 
51
  class ClipMaskIntensityPercentiles(Transform):
52
+ """
53
+ Clip image intensity values based on percentiles computed from a masked region.
54
+ This transform clips the intensity range of an image to values between lower and upper
55
+ percentiles calculated only from voxels where the mask is positive. It supports both
56
+ hard clipping and soft (smooth) clipping via a sharpness factor.
57
+ Args:
58
+ lower: Lower percentile threshold in range [0, 100]. If None, no lower clipping applied.
59
+ upper: Upper percentile threshold in range [0, 100]. If None, no upper clipping applied.
60
+ sharpness_factor: If provided, applies soft clipping with this sharpness parameter.
61
+ Must be greater than 0. If None, applies hard clipping instead.
62
+ channel_wise: If True, applies clipping independently to each channel using the
63
+ corresponding channel's mask. If False, uses the same mask for all channels.
64
+ dtype: Output data type for the clipped image. Defaults to np.float32.
65
+ Raises:
66
+ ValueError: If both lower and upper are None, if percentiles are outside [0, 100],
67
+ if upper < lower, or if sharpness_factor <= 0.
68
+ Returns:
69
+ Clipped image with intensities adjusted based on masked percentiles.
70
+ Note:
71
+ Supports both torch.Tensor and numpy.ndarray inputs.
72
 
73
 
74
  backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
75
+ """
76
 
77
  def __init__(
78
  self,
79
  lower: Union[float, None],
80
  upper: Union[float, None],
81
+ sharpness_factor: Union[float, None] = None,
82
  channel_wise: bool = False,
83
  dtype: DtypeLike = np.float32,
84
  ) -> None:
 
85
  if lower is None and upper is None:
86
  raise ValueError("lower or upper percentiles must be provided")
87
  if lower is not None and (lower < 0.0 or lower > 100.0):
 
93
  if sharpness_factor is not None and sharpness_factor <= 0:
94
  raise ValueError("sharpness_factor must be greater than 0")
95
 
96
+ # self.mask_data = mask_data
97
  self.lower = lower
98
  self.upper = upper
99
  self.sharpness_factor = sharpness_factor
 
103
  def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
104
  masked_img = img * (mask_data > 0)
105
  if self.sharpness_factor is not None:
106
+ lower_percentile = (
107
+ percentile(masked_img, self.lower) if self.lower is not None else None
108
+ )
109
+ upper_percentile = (
110
+ percentile(masked_img, self.upper) if self.upper is not None else None
111
+ )
112
+ img = soft_clip(
113
+ img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype
114
+ )
115
  else:
116
+ lower_percentile = (
117
+ percentile(masked_img, self.lower)
118
+ if self.lower is not None
119
+ else percentile(masked_img, 0)
120
+ )
121
+ upper_percentile = (
122
+ percentile(masked_img, self.upper)
123
+ if self.upper is not None
124
+ else percentile(masked_img, 100)
125
+ )
126
  img = clip(img, lower_percentile, upper_percentile)
127
 
128
  img = convert_to_tensor(img, track_meta=False)
 
136
  img_t = convert_to_tensor(img, track_meta=False)
137
  mask_t = convert_to_tensor(mask_data, track_meta=False)
138
  if self.channel_wise:
139
+ img_t = torch.stack(
140
+ [self._clip(img=d, mask_data=mask_t[e]) for e, d in enumerate(img_t)]
141
+ ) # type: ignore
142
  else:
143
  img_t = self._clip(img=img_t, mask_data=mask_t)
144
 
 
146
 
147
  return img
148
 
149
+
150
  class ClipMaskIntensityPercentilesd(MapTransform):
151
+ """
152
+ Dictionary wrapper for ClipMaskIntensityPercentiles.
153
+ Args:
154
+ keys: Keys of the corresponding items to be transformed.
155
+ mask_key: Key to the mask data in the input dictionary used to compute percentiles. Only intensity values where the mask is positive will be considered.
156
+ lower: Lower percentile value (0-100) for clipping. If None, no lower clipping is applied.
157
+ upper: Upper percentile value (0-100) for clipping. If None, no upper clipping is applied.
158
+ sharpness_factor: Optional factor to enhance contrast after clipping. If None, no sharpness enhancement is applied.
159
+ channel_wise: If True, compute percentiles separately for each channel. If False, compute globally.
160
+ dtype: Data type of the output. Defaults to np.float32.
161
+ allow_missing_keys: If True, missing keys will not raise an error. Defaults to False.
162
+ Example:
163
+ >>> transform = ClipMaskIntensityPercentilesd(
164
+ ... keys=["image"],
165
+ ... mask_key="mask",
166
+ ... lower=2,
167
+ ... upper=98,
168
+ ... sharpness_factor=1.0
169
+ ... )
170
+ """
171
 
172
  def __init__(
173
  self,
 
182
  ) -> None:
183
  super().__init__(keys, allow_missing_keys)
184
  self.scaler = ClipMaskIntensityPercentiles(
185
+ lower=lower,
186
+ upper=upper,
187
+ sharpness_factor=sharpness_factor,
188
+ channel_wise=channel_wise,
189
+ dtype=dtype,
190
  )
191
  self.mask_key = mask_key
192
 
 
195
  for key in self.key_iterator(d):
196
  d[key] = self.scaler(d[key], d[self.mask_key])
197
  return d
 
 
198
 
199
 
200
  class ElementwiseProductd(MapTransform):
201
+ """
202
+ A dictionary-based transform that computes the elementwise product of two arrays.
203
+ This transform multiplies two input arrays element-by-element and stores the result
204
+ in a specified output key.
205
+ Args:
206
+ keys: Collection of keys to select from the input dictionary. Must contain exactly
207
+ two keys whose corresponding values will be multiplied together.
208
+ output_key: Key in the output dictionary where the product result will be stored.
209
+ Returns:
210
+ Dictionary with the elementwise product stored at the output_key.
211
+ Example:
212
+ >>> transform = ElementwiseProductd(keys=["image1", "image2"], output_key="product")
213
+ >>> data = {"image1": np.array([1, 2, 3]), "image2": np.array([2, 3, 4])}
214
+ >>> result = transform(data)
215
+ >>> result["product"]
216
+ array([ 2, 6, 12])
217
+ """
218
+
219
  def __init__(self, keys: KeysCollection, output_key: str) -> None:
220
  super().__init__(keys)
221
  self.output_key = output_key
222
 
223
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
224
  d = dict(data)
225
  d[self.output_key] = d[self.keys[0]] * d[self.keys[1]]
226
  return d
 
236
  clip_limit (float): Threshold for contrast limiting. Default is 2.0.
237
  tile_grid_size (Union[tuple, Sequence[int]]): Size of grid for histogram equalization (default: (8,8)).
238
  """
239
+
240
  def __init__(
241
  self,
242
  keys: KeysCollection,
 
262
 
263
  image_clahe = np.stack([clahe.apply(slice) for slice in image[0]])
264
 
 
265
  # Convert back to float in [0,1]
266
  processed_img = image_clahe.astype(np.float32) / 255.0
267
  reshaped_ = processed_img.reshape(1, *processed_img.shape)
268
  d[key] = torch.from_numpy(reshaped_).to(image_.device)
269
  return d
270
+
271
+
272
  class NormalizeIntensity_custom(Transform):
273
  """
274
  Normalize input based on the `subtrahend` and `divisor`: `(img - subtrahend) / divisor`.
 
318
  x = torch.std(x.float(), unbiased=False)
319
  return x.item() if x.numel() == 1 else x
320
 
321
+ def _normalize(
322
+ self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor, sub=None, div=None
323
+ ) -> NdarrayOrTensor:
324
  img, *_ = convert_data_type(img, dtype=torch.float32)
325
+ """
326
  if self.nonzero:
327
  slices = img != 0
328
  masked_img = img[slices]
 
331
  else:
332
  slices = None
333
  masked_img = img
334
+ """
335
  slices = None
336
  mask_data = mask_data.squeeze(0)
337
  slices_mask = mask_data > 0
 
368
  dtype = self.dtype or img.dtype
369
  if self.channel_wise:
370
  if self.subtrahend is not None and len(self.subtrahend) != len(img):
371
+ raise ValueError(
372
+ f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components."
373
+ )
374
  if self.divisor is not None and len(self.divisor) != len(img):
375
+ raise ValueError(
376
+ f"img has {len(img)} channels, but divisor has {len(self.divisor)} components."
377
+ )
378
 
379
  if not img.dtype.is_floating_point:
380
  img, *_ = convert_data_type(img, dtype=torch.float32)
 
392
  out = convert_to_dst_type(img, img, dtype=dtype)[0]
393
  return out
394
 
395
+
396
  class NormalizeIntensity_customd(MapTransform):
397
  """
398
+ Dictionary-based wrapper of :class:`NormalizeIntensity_custom`.
399
+
400
+ The mean and standard deviation are calculated only from intensities which are
401
+ defined in the mask provided through ``mask_key``.
402
 
403
  Args:
404
  keys: keys of the corresponding items to be transformed.
405
+ See also: :py:class:`monai.transforms.MapTransform`
406
+ mask_key: key of the corresponding mask item to be used for calculating
407
+ statistics (mean and std).
408
+ subtrahend: the amount to subtract by (usually the mean). If None,
409
+ the mean is calculated from the masked region of the input image.
410
+ divisor: the amount to divide by (usually the standard deviation). If None,
411
+ the std is calculated from the masked region of the input image.
412
  nonzero: whether only normalize non-zero values.
413
  channel_wise: if True, calculate on each channel separately, otherwise, calculate on
414
+ the entire image directly. Defaults to False.
415
+ dtype: output data type, if None, same as input image. Defaults to float32.
416
  allow_missing_keys: don't raise exception if key is missing.
417
  """
418
 
 
422
  self,
423
  keys: KeysCollection,
424
  mask_key: str,
425
+ subtrahend: Union[NdarrayOrTensor, None] = None,
426
+ divisor: Union[NdarrayOrTensor, None] = None,
427
  nonzero: bool = False,
428
  channel_wise: bool = False,
429
  dtype: DtypeLike = np.float32,
430
  allow_missing_keys: bool = False,
431
  ) -> None:
432
  super().__init__(keys, allow_missing_keys)
433
+ self.normalizer = NormalizeIntensity_custom(
434
+ subtrahend, divisor, nonzero, channel_wise, dtype
435
+ )
436
  self.mask_key = mask_key
437
 
438
  def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
439
  d = dict(data)
440
  for key in self.key_iterator(d):
441
  d[key] = self.normalizer(d[key], d[self.mask_key])
442
+ return d
src/data/data_loader.py CHANGED
@@ -1,43 +1,29 @@
1
- import argparse
2
  import os
3
  import numpy as np
4
- from monai.config import KeysCollection
5
- from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
6
  from monai.transforms import (
7
  Compose,
8
  LoadImaged,
9
- MapTransform,
10
- ScaleIntensityRanged,
11
- SplitDimd,
12
  ToTensord,
13
- ConcatItemsd,
14
- SelectItemsd,
15
- EnsureChannelFirstd,
16
- RepeatChanneld,
17
  DeleteItemsd,
18
  EnsureTyped,
19
- ClipIntensityPercentilesd,
20
- MaskIntensityd,
21
  RandCropByPosNegLabeld,
22
  NormalizeIntensityd,
23
- SqueezeDimd,
24
- ScaleIntensityd,
25
- ScaleIntensityd,
26
  Transposed,
27
  RandWeightedCropd,
28
  )
29
  from .custom_transforms import (
30
- NormalizeIntensity_customd,
31
- ClipMaskIntensityPercentilesd,
32
  ElementwiseProductd,
33
  )
34
  import torch
35
  from torch.utils.data.dataloader import default_collate
36
- import matplotlib.pyplot as plt
37
  from typing import Literal
38
- import monai
39
  import collections.abc
40
 
 
41
  def list_data_collate(batch: collections.abc.Sequence):
42
  """
43
  Combine instances from a list of dicts into a single dict, by stacking them along first dim
@@ -51,55 +37,71 @@ def list_data_collate(batch: collections.abc.Sequence):
51
 
52
  if all("final_heatmap" in ix for ix in item):
53
  data["final_heatmap"] = torch.stack([ix["final_heatmap"] for ix in item], dim=0)
54
-
55
  batch[i] = data
56
  return default_collate(batch)
57
 
58
 
59
-
60
  def data_transform(args):
61
-
62
  if args.use_heatmap:
63
  transform = Compose(
64
  [
65
- LoadImaged(keys=["image", "mask","dwi", "adc", "heatmap"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
 
 
 
 
 
66
  ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
67
- ConcatItemsd(keys=["image", "dwi", "adc"], name="image", dim=0), # stacks to (3, H, W)
 
 
68
  NormalizeIntensity_customd(keys=["image"], channel_wise=True, mask_key="mask"),
69
  ElementwiseProductd(keys=["mask", "heatmap"], output_key="final_heatmap"),
70
- RandWeightedCropd(keys=["image", "final_heatmap"],
71
- w_key="final_heatmap",
72
- spatial_size=(args.tile_size,args.tile_size,args.depth),
73
- num_samples=args.tile_count),
 
 
74
  EnsureTyped(keys=["label"], dtype=torch.float32),
75
  Transposed(keys=["image"], indices=(0, 3, 1, 2)),
76
- DeleteItemsd(keys=['mask', 'dwi', 'adc', 'heatmap']),
77
  ToTensord(keys=["image", "label", "final_heatmap"]),
78
  ]
79
  )
80
  else:
81
  transform = Compose(
82
  [
83
- LoadImaged(keys=["image", "mask","dwi", "adc"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
 
 
 
 
 
84
  ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
85
- ConcatItemsd(keys=["image", "dwi", "adc"], name="image", dim=0), # stacks to (3, H, W)
 
 
86
  NormalizeIntensityd(keys=["image"], channel_wise=True),
87
- RandCropByPosNegLabeld(keys=["image"],
88
- label_key="mask",
89
- spatial_size=(args.tile_size,args.tile_size,args.depth),
90
- pos=1,
91
- neg=0,
92
- num_samples=args.tile_count),
 
 
93
  EnsureTyped(keys=["label"], dtype=torch.float32),
94
  Transposed(keys=["image"], indices=(0, 3, 1, 2)),
95
- DeleteItemsd(keys=['mask', 'dwi', 'adc']),
96
  ToTensord(keys=["image", "label"]),
97
  ]
98
  )
99
  return transform
100
 
101
- def get_dataloader(args, split: Literal["train", "test"]):
102
 
 
103
  data_list = load_decathlon_datalist(
104
  data_list_file_path=args.dataset_json,
105
  data_list_key=split,
@@ -110,16 +112,17 @@ def get_dataloader(args, split: Literal["train", "test"]):
110
  cache_dir_ = os.path.join(args.logdir, "cache")
111
  os.makedirs(os.path.join(cache_dir_, split), exist_ok=True)
112
  transform = data_transform(args)
113
- dataset = PersistentDataset(data=data_list, transform=transform, cache_dir= os.path.join(cache_dir_, split))
 
 
114
  loader = torch.utils.data.DataLoader(
115
  dataset,
116
  batch_size=args.batch_size,
117
  shuffle=(split == "train"),
118
  num_workers=args.workers,
119
  pin_memory=True,
120
- multiprocessing_context="spawn" if args.workers > 0 else None,
121
  sampler=None,
122
  collate_fn=list_data_collate,
123
  )
124
  return loader
125
-
 
 
1
  import os
2
  import numpy as np
3
+ from monai.data import load_decathlon_datalist, ITKReader, PersistentDataset
 
4
  from monai.transforms import (
5
  Compose,
6
  LoadImaged,
 
 
 
7
  ToTensord,
8
+ ConcatItemsd,
 
 
 
9
  DeleteItemsd,
10
  EnsureTyped,
 
 
11
  RandCropByPosNegLabeld,
12
  NormalizeIntensityd,
 
 
 
13
  Transposed,
14
  RandWeightedCropd,
15
  )
16
  from .custom_transforms import (
17
+ NormalizeIntensity_customd,
18
+ ClipMaskIntensityPercentilesd,
19
  ElementwiseProductd,
20
  )
21
  import torch
22
  from torch.utils.data.dataloader import default_collate
 
23
  from typing import Literal
 
24
  import collections.abc
25
 
26
+
27
  def list_data_collate(batch: collections.abc.Sequence):
28
  """
29
  Combine instances from a list of dicts into a single dict, by stacking them along first dim
 
37
 
38
  if all("final_heatmap" in ix for ix in item):
39
  data["final_heatmap"] = torch.stack([ix["final_heatmap"] for ix in item], dim=0)
40
+
41
  batch[i] = data
42
  return default_collate(batch)
43
 
44
 
 
45
  def data_transform(args):
 
46
  if args.use_heatmap:
47
  transform = Compose(
48
  [
49
+ LoadImaged(
50
+ keys=["image", "mask", "dwi", "adc", "heatmap"],
51
+ reader=ITKReader(),
52
+ ensure_channel_first=True,
53
+ dtype=np.float32,
54
+ ),
55
  ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
56
+ ConcatItemsd(
57
+ keys=["image", "dwi", "adc"], name="image", dim=0
58
+ ), # stacks to (3, H, W)
59
  NormalizeIntensity_customd(keys=["image"], channel_wise=True, mask_key="mask"),
60
  ElementwiseProductd(keys=["mask", "heatmap"], output_key="final_heatmap"),
61
+ RandWeightedCropd(
62
+ keys=["image", "final_heatmap"],
63
+ w_key="final_heatmap",
64
+ spatial_size=(args.tile_size, args.tile_size, args.depth),
65
+ num_samples=args.tile_count,
66
+ ),
67
  EnsureTyped(keys=["label"], dtype=torch.float32),
68
  Transposed(keys=["image"], indices=(0, 3, 1, 2)),
69
+ DeleteItemsd(keys=["mask", "dwi", "adc", "heatmap"]),
70
  ToTensord(keys=["image", "label", "final_heatmap"]),
71
  ]
72
  )
73
  else:
74
  transform = Compose(
75
  [
76
+ LoadImaged(
77
+ keys=["image", "mask", "dwi", "adc"],
78
+ reader=ITKReader(),
79
+ ensure_channel_first=True,
80
+ dtype=np.float32,
81
+ ),
82
  ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
83
+ ConcatItemsd(
84
+ keys=["image", "dwi", "adc"], name="image", dim=0
85
+ ), # stacks to (3, H, W)
86
  NormalizeIntensityd(keys=["image"], channel_wise=True),
87
+ RandCropByPosNegLabeld(
88
+ keys=["image"],
89
+ label_key="mask",
90
+ spatial_size=(args.tile_size, args.tile_size, args.depth),
91
+ pos=1,
92
+ neg=0,
93
+ num_samples=args.tile_count,
94
+ ),
95
  EnsureTyped(keys=["label"], dtype=torch.float32),
96
  Transposed(keys=["image"], indices=(0, 3, 1, 2)),
97
+ DeleteItemsd(keys=["mask", "dwi", "adc"]),
98
  ToTensord(keys=["image", "label"]),
99
  ]
100
  )
101
  return transform
102
 
 
103
 
104
+ def get_dataloader(args, split: Literal["train", "test"]):
105
  data_list = load_decathlon_datalist(
106
  data_list_file_path=args.dataset_json,
107
  data_list_key=split,
 
112
  cache_dir_ = os.path.join(args.logdir, "cache")
113
  os.makedirs(os.path.join(cache_dir_, split), exist_ok=True)
114
  transform = data_transform(args)
115
+ dataset = PersistentDataset(
116
+ data=data_list, transform=transform, cache_dir=os.path.join(cache_dir_, split)
117
+ )
118
  loader = torch.utils.data.DataLoader(
119
  dataset,
120
  batch_size=args.batch_size,
121
  shuffle=(split == "train"),
122
  num_workers=args.workers,
123
  pin_memory=True,
124
+ multiprocessing_context="fork" if args.workers > 0 else None,
125
  sampler=None,
126
  collate_fn=list_data_collate,
127
  )
128
  return loader
 
src/model/MIL.py CHANGED
@@ -1,13 +1,10 @@
1
-
2
  from __future__ import annotations
3
-
4
  from typing import cast
5
-
6
  import torch
7
  import torch.nn as nn
8
-
9
  from monai.utils.module import optional_import
10
  from monai.networks.nets import resnet
 
11
  models, _ = optional_import("torchvision.models")
12
 
13
 
@@ -16,8 +13,7 @@ class MILModel_3D(nn.Module):
16
  Multiple Instance Learning (MIL) model, with a backbone classification model.
17
  Adapted from MONAI, modified for 3D images. The expected shape of input data is `[B, N, C, D, H, W]`,
18
  where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances
19
- extracted from every original image in the batch. A tutorial example is available at:
20
- https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning.
21
 
22
  Args:
23
  num_classes: number of output classes.
@@ -29,10 +25,9 @@ class MILModel_3D(nn.Module):
29
  - ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556.
30
  - ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556.
31
 
32
- pretrained: init backbone with pretrained weights, defaults to ``True``.
33
  backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features,
34
  or a string name of a torchvision model).
35
- Defaults to ``None``, in which case ResNet50 is used.
36
  backbone_num_features: Number of output features of the backbone CNN
37
  Defaults to ``None`` (necessary only when using a custom backbone)
38
  trans_blocks: number of the blocks in `TransformEncoder` layer.
@@ -63,7 +58,11 @@ class MILModel_3D(nn.Module):
63
  self.transformer: nn.Module | None = None
64
 
65
  if backbone is None:
66
- net = resnet.resnet18(spatial_dims=3, n_input_channels=3, num_classes=5, )
 
 
 
 
67
  nfc = net.fc.in_features # save the number of final features
68
  net.fc = torch.nn.Identity() # remove final linear layer
69
 
@@ -72,7 +71,6 @@ class MILModel_3D(nn.Module):
72
  if mil_mode == "att_trans_pyramid":
73
  # register hooks to capture outputs of intermediate layers
74
  def forward_hook(layer_name):
75
-
76
  def hook(module, input, output):
77
  self.extra_outputs[layer_name] = output
78
 
@@ -105,31 +103,23 @@ class MILModel_3D(nn.Module):
105
  nfc = backbone_num_features
106
  net.fc = torch.nn.Identity() # remove final linear layer
107
 
108
- self.extra_outputs: dict[str, torch.Tensor] = {}
109
-
110
  if mil_mode == "att_trans_pyramid":
111
  # register hooks to capture outputs of intermediate layers
112
- def forward_hook(layer_name):
113
-
114
- def hook(module, input, output):
115
- self.extra_outputs[layer_name] = output
116
-
117
- return hook
118
-
119
- net.layer1.register_forward_hook(forward_hook("layer1"))
120
- net.layer2.register_forward_hook(forward_hook("layer2"))
121
- net.layer3.register_forward_hook(forward_hook("layer3"))
122
- net.layer4.register_forward_hook(forward_hook("layer4"))
123
 
124
  if backbone_num_features is None:
125
- raise ValueError("Number of endencoder features must be provided for a custom backbone model")
 
 
126
 
127
  else:
128
  raise ValueError("Unsupported backbone")
129
-
130
  if backbone is not None and mil_mode not in ["mean", "max", "att", "att_trans"]:
131
  raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode))
132
-
133
  if self.mil_mode in ["mean", "max"]:
134
  pass
135
  elif self.mil_mode == "att":
@@ -144,7 +134,8 @@ class MILModel_3D(nn.Module):
144
  transformer_list = nn.ModuleList(
145
  [
146
  nn.TransformerEncoder(
147
- nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout), num_layers=trans_blocks
 
148
  ),
149
  nn.Sequential(
150
  nn.Linear(192, 64),
@@ -206,10 +197,26 @@ class MILModel_3D(nn.Module):
206
  x = self.myfc(x)
207
 
208
  elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None:
209
- l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
210
- l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
211
- l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
212
- l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
  transformer_list = cast(nn.ModuleList, self.transformer)
215
 
@@ -242,7 +249,3 @@ class MILModel_3D(nn.Module):
242
  x = self.calc_head(x)
243
 
244
  return x
245
-
246
-
247
-
248
-
 
 
1
  from __future__ import annotations
 
2
  from typing import cast
 
3
  import torch
4
  import torch.nn as nn
 
5
  from monai.utils.module import optional_import
6
  from monai.networks.nets import resnet
7
+
8
  models, _ = optional_import("torchvision.models")
9
 
10
 
 
13
  Multiple Instance Learning (MIL) model, with a backbone classification model.
14
  Adapted from MONAI, modified for 3D images. The expected shape of input data is `[B, N, C, D, H, W]`,
15
  where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances
16
+ extracted from every original image in the batch.
 
17
 
18
  Args:
19
  num_classes: number of output classes.
 
25
  - ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556.
26
  - ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556.
27
 
 
28
  backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features,
29
  or a string name of a torchvision model).
30
+ Defaults to ``None``, in which case ResNet18 is used.
31
  backbone_num_features: Number of output features of the backbone CNN
32
  Defaults to ``None`` (necessary only when using a custom backbone)
33
  trans_blocks: number of the blocks in `TransformEncoder` layer.
 
58
  self.transformer: nn.Module | None = None
59
 
60
  if backbone is None:
61
+ net = resnet.resnet18(
62
+ spatial_dims=3,
63
+ n_input_channels=3,
64
+ num_classes=5,
65
+ )
66
  nfc = net.fc.in_features # save the number of final features
67
  net.fc = torch.nn.Identity() # remove final linear layer
68
 
 
71
  if mil_mode == "att_trans_pyramid":
72
  # register hooks to capture outputs of intermediate layers
73
  def forward_hook(layer_name):
 
74
  def hook(module, input, output):
75
  self.extra_outputs[layer_name] = output
76
 
 
103
  nfc = backbone_num_features
104
  net.fc = torch.nn.Identity() # remove final linear layer
105
 
 
 
106
  if mil_mode == "att_trans_pyramid":
107
  # register hooks to capture outputs of intermediate layers
108
+ raise ValueError(
109
+ "Cannot use att_trans_pyramid with custom backbone. Have to use the default ResNet 18 backbone."
110
+ )
 
 
 
 
 
 
 
 
111
 
112
  if backbone_num_features is None:
113
+ raise ValueError(
114
+ "Number of endencoder features must be provided for a custom backbone model"
115
+ )
116
 
117
  else:
118
  raise ValueError("Unsupported backbone")
119
+
120
  if backbone is not None and mil_mode not in ["mean", "max", "att", "att_trans"]:
121
  raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode))
122
+
123
  if self.mil_mode in ["mean", "max"]:
124
  pass
125
  elif self.mil_mode == "att":
 
134
  transformer_list = nn.ModuleList(
135
  [
136
  nn.TransformerEncoder(
137
+ nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout),
138
+ num_layers=trans_blocks,
139
  ),
140
  nn.Sequential(
141
  nn.Linear(192, 64),
 
197
  x = self.myfc(x)
198
 
199
  elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None:
200
+ l1 = (
201
+ torch.mean(self.extra_outputs["layer1"], dim=(2, 3, 4))
202
+ .reshape(sh[0], sh[1], -1)
203
+ .permute(1, 0, 2)
204
+ )
205
+ l2 = (
206
+ torch.mean(self.extra_outputs["layer2"], dim=(2, 3, 4))
207
+ .reshape(sh[0], sh[1], -1)
208
+ .permute(1, 0, 2)
209
+ )
210
+ l3 = (
211
+ torch.mean(self.extra_outputs["layer3"], dim=(2, 3, 4))
212
+ .reshape(sh[0], sh[1], -1)
213
+ .permute(1, 0, 2)
214
+ )
215
+ l4 = (
216
+ torch.mean(self.extra_outputs["layer4"], dim=(2, 3, 4))
217
+ .reshape(sh[0], sh[1], -1)
218
+ .permute(1, 0, 2)
219
+ )
220
 
221
  transformer_list = cast(nn.ModuleList, self.transformer)
222
 
 
249
  x = self.calc_head(x)
250
 
251
  return x
 
 
 
 
src/model/csPCa_model.py CHANGED
@@ -1,32 +1,72 @@
1
-
2
  from __future__ import annotations
3
-
4
- from typing import cast
5
-
6
  import torch
7
  import torch.nn as nn
8
-
9
  from monai.utils.module import optional_import
10
- models, _ = optional_import("torchvision.models")
11
 
 
12
 
13
 
14
  class SimpleNN(nn.Module):
 
 
 
 
 
 
 
 
 
 
15
  def __init__(self, input_dim):
16
  super(SimpleNN, self).__init__()
17
  self.net = nn.Sequential(
18
  nn.Linear(input_dim, 256),
19
  nn.ReLU(),
20
- nn.Linear( 256,128),
21
  nn.ReLU(),
22
- nn.Dropout(p=0.3),
23
  nn.Linear(128, 1),
24
- nn.Sigmoid() # since binary classification
25
  )
 
26
  def forward(self, x):
 
 
 
 
 
 
 
 
 
27
  return self.net(x)
28
 
 
29
  class csPCa_Model(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def __init__(self, backbone):
31
  super().__init__()
32
  self.backbone = backbone
@@ -47,4 +87,3 @@ class csPCa_Model(nn.Module):
47
 
48
  x = self.fc_cspca(x)
49
  return x
50
-
 
 
1
  from __future__ import annotations
 
 
 
2
  import torch
3
  import torch.nn as nn
 
4
  from monai.utils.module import optional_import
 
5
 
6
+ models, _ = optional_import("torchvision.models")
7
 
8
 
9
  class SimpleNN(nn.Module):
10
+ """
11
+ A simple Multi-Layer Perceptron (MLP) for binary classification.
12
+
13
+ This network consists of two hidden layers with ReLU activation and a dropout layer,
14
+ followed by a final sigmoid activation for probability output.
15
+
16
+ Args:
17
+ input_dim (int): The number of input features.
18
+ """
19
+
20
  def __init__(self, input_dim):
21
  super(SimpleNN, self).__init__()
22
  self.net = nn.Sequential(
23
  nn.Linear(input_dim, 256),
24
  nn.ReLU(),
25
+ nn.Linear(256, 128),
26
  nn.ReLU(),
27
+ nn.Dropout(p=0.3),
28
  nn.Linear(128, 1),
29
+ nn.Sigmoid(), # since binary classification
30
  )
31
+
32
  def forward(self, x):
33
+ """
34
+ Forward pass of the classifier.
35
+
36
+ Args:
37
+ x (torch.Tensor): Input tensor of shape (Batch, input_dim).
38
+
39
+ Returns:
40
+ torch.Tensor: Output probabilities of shape (Batch, 1).
41
+ """
42
  return self.net(x)
43
 
44
+
45
  class csPCa_Model(nn.Module):
46
+ """
47
+ Clinically Significant Prostate Cancer (csPCa) risk prediction model using a MIL backbone.
48
+
49
+ This model repurposes a pre-trained Multiple Instance Learning (MIL) backbone (originally
50
+ designed for PI-RADS prediction) for binary csPCa risk assessment. It utilizes the
51
+ backbone's feature extractor, transformer, and attention mechanism to aggregate instance-level
52
+ features into a bag-level embedding.
53
+
54
+ The original fully connected classification head of the backbone is replaced by a
55
+ custom :class:`SimpleNN` head for the new task.
56
+
57
+ Args:
58
+ backbone (nn.Module): A pre-trained MIL model. The backbone must possess the
59
+ following attributes/sub-modules:
60
+ - ``net``: The CNN feature extractor.
61
+ - ``transformer``: A sequence modeling module.
62
+ - ``attention``: An attention mechanism for pooling.
63
+ - ``myfc``: The original fully connected layer (used to determine feature dimensions).
64
+
65
+ Attributes:
66
+ fc_cspca (SimpleNN): The new classification head for csPCa prediction.
67
+ backbone: The MIL based PI-RADS classifier.
68
+ """
69
+
70
  def __init__(self, backbone):
71
  super().__init__()
72
  self.backbone = backbone
 
87
 
88
  x = self.fc_cspca(x)
89
  return x
 
src/preprocessing/center_crop.py CHANGED
@@ -8,10 +8,10 @@
8
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
  # See the License for the specific language governing permissions and
10
  # limitations under the License.
11
- #python scripts/center_crop.py --file_name path/to/t2_image --out_name cropped_t2
12
 
13
 
14
- #import argparse
15
  from typing import Union
16
 
17
  import SimpleITK as sitk # noqa N813
@@ -41,7 +41,9 @@ def crop(image: sitk.Image, margin: Union[int, float], interpolator=sitk.sitkLin
41
 
42
  # calculate new origin and new image size
43
  if all([isinstance(m, float) for m in _flatten(margin)]):
44
- assert all([m >= 0 and m < 0.5 for m in _flatten(margin)]), "margins must be between 0 and 0.5"
 
 
45
  to_crop = [[int(sz * _m) for _m in m] for sz, m in zip(old_size, margin)]
46
  elif all([isinstance(m, int) for m in _flatten(margin)]):
47
  to_crop = margin
@@ -61,4 +63,3 @@ def crop(image: sitk.Image, margin: Union[int, float], interpolator=sitk.sitkLin
61
  ref_image.SetDirection(image.GetDirection())
62
 
63
  return sitk.Resample(image, ref_image, interpolator=interpolator)
64
-
 
8
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
  # See the License for the specific language governing permissions and
10
  # limitations under the License.
11
+ # python scripts/center_crop.py --file_name path/to/t2_image --out_name cropped_t2
12
 
13
 
14
+ # import argparse
15
  from typing import Union
16
 
17
  import SimpleITK as sitk # noqa N813
 
41
 
42
  # calculate new origin and new image size
43
  if all([isinstance(m, float) for m in _flatten(margin)]):
44
+ assert all([m >= 0 and m < 0.5 for m in _flatten(margin)]), (
45
+ "margins must be between 0 and 0.5"
46
+ )
47
  to_crop = [[int(sz * _m) for _m in m] for sz, m in zip(old_size, margin)]
48
  elif all([isinstance(m, int) for m in _flatten(margin)]):
49
  to_crop = margin
 
63
  ref_image.SetDirection(image.GetDirection())
64
 
65
  return sitk.Resample(image, ref_image, interpolator=interpolator)
 
src/preprocessing/generate_heatmap.py CHANGED
@@ -1,76 +1,93 @@
1
-
2
  import os
3
  import numpy as np
4
  import nrrd
5
- import json
6
- import pandas as pd
7
- import json
8
- import SimpleITK as sitk
9
- import multiprocessing
10
-
11
  import logging
12
 
13
 
14
  def get_heatmap(args):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  files = os.listdir(args.t2_dir)
17
- args.heatmapdir = os.path.join(args.output_dir, 'heatmaps/')
18
  os.makedirs(args.heatmapdir, exist_ok=True)
19
- for file in files:
20
-
21
  bool_dwi = False
22
  bool_adc = False
23
  mask, _ = nrrd.read(os.path.join(args.seg_dir, file))
24
- dwi, _ = nrrd.read(os.path.join(args.dwi_dir, file))
25
- adc, _ = nrrd.read(os.path.join(args.adc_dir, file))
26
-
27
  nonzero_vals_dwi = dwi[mask > 0]
 
28
  if len(nonzero_vals_dwi) > 0:
29
  min_val = nonzero_vals_dwi.min()
30
  max_val = nonzero_vals_dwi.max()
31
  heatmap_dwi = np.zeros_like(dwi, dtype=np.float32)
32
-
33
  if min_val != max_val:
34
  heatmap_dwi = (dwi - min_val) / (max_val - min_val)
35
- masked_heatmap_dwi = np.where(mask > 0, heatmap_dwi, heatmap_dwi[mask>0].min())
36
  else:
37
  bool_dwi = True
38
-
39
  else:
40
  bool_dwi = True
41
 
42
  nonzero_vals_adc = adc[mask > 0]
 
43
  if len(nonzero_vals_adc) > 0:
44
  min_val = nonzero_vals_adc.min()
45
  max_val = nonzero_vals_adc.max()
46
  heatmap_adc = np.zeros_like(adc, dtype=np.float32)
47
-
48
  if min_val != max_val:
49
  heatmap_adc = (max_val - adc) / (max_val - min_val)
50
- masked_heatmap_adc = np.where(mask > 0, heatmap_adc, heatmap_adc[mask>0].min())
51
  else:
52
  bool_adc = True
53
 
54
  else:
55
  bool_adc = True
56
-
57
 
58
- if bool_dwi:
59
- mix_mask = masked_heatmap_adc
60
- if bool_adc:
61
- mix_mask = masked_heatmap_dwi
62
  if not bool_dwi and not bool_adc:
63
  mix_mask = masked_heatmap_dwi * masked_heatmap_adc
 
 
 
 
 
 
 
64
  else:
65
  mix_mask = np.ones_like(adc, dtype=np.float32)
 
66
 
67
  mix_mask = (mix_mask - mix_mask.min()) / (mix_mask.max() - mix_mask.min())
 
68
 
69
-
70
- nrrd.write(os.path.join(args.heatmapdir, file), mix_mask)
71
-
72
  return args
73
-
74
-
75
-
76
-
 
 
1
  import os
2
  import numpy as np
3
  import nrrd
4
+ from tqdm import tqdm
 
 
 
 
 
5
  import logging
6
 
7
 
8
  def get_heatmap(args):
9
+ """
10
+ Generate heatmaps from DWI (Diffusion Weighted Imaging) and ADC (Apparent Diffusion Coefficient) medical imaging data.
11
+ This function processes medical imaging files (DWI and ADC) along with their corresponding
12
+ segmentation masks to create normalized heatmaps. It combines the DWI
13
+ and ADC heatmaps through element-wise multiplication.
14
+ Args:
15
+ args: An object containing the following attributes:
16
+ - t2_dir (str): Directory path containing T2 image files.
17
+ - dwi_dir (str): Directory path containing DWI image files.
18
+ - adc_dir (str): Directory path containing ADC image files.
19
+ - seg_dir (str): Directory path containing segmentation mask files.
20
+ - output_dir (str): Base output directory where 'heatmaps/' subdirectory will be created.
21
+ - heatmapdir (str): Output directory for heatmap files (created by function).
22
+ Returns:
23
+ args: The modified args object with heatmapdir attribute set.
24
+ Raises:
25
+ FileNotFoundError: If input directories or files do not exist.
26
+ ValueError: If NRRD files cannot be read properly.
27
+ Notes:
28
+ - DWI heatmap is normalized as (dwi - min) / (max - min)
29
+ - ADC heatmap is normalized as (max - adc) / (max - min) (inverted)
30
+ - Final heatmap is re-normalized to [0, 1] range
31
+ - If all values in a mask region are identical, the heatmap is skipped for that modality
32
+ - Output files are written in NRRD format with the same header as the input DWI file
33
+ """
34
 
35
  files = os.listdir(args.t2_dir)
36
+ args.heatmapdir = os.path.join(args.output_dir, "heatmaps/")
37
  os.makedirs(args.heatmapdir, exist_ok=True)
38
+ logging.info("Starting heatmap generation")
39
+ for file in tqdm(files):
40
  bool_dwi = False
41
  bool_adc = False
42
  mask, _ = nrrd.read(os.path.join(args.seg_dir, file))
43
+ dwi, header_dwi = nrrd.read(os.path.join(args.dwi_dir, file))
44
+ adc, header_adc = nrrd.read(os.path.join(args.adc_dir, file))
 
45
  nonzero_vals_dwi = dwi[mask > 0]
46
+
47
  if len(nonzero_vals_dwi) > 0:
48
  min_val = nonzero_vals_dwi.min()
49
  max_val = nonzero_vals_dwi.max()
50
  heatmap_dwi = np.zeros_like(dwi, dtype=np.float32)
51
+
52
  if min_val != max_val:
53
  heatmap_dwi = (dwi - min_val) / (max_val - min_val)
54
+ masked_heatmap_dwi = np.where(mask > 0, heatmap_dwi, heatmap_dwi[mask > 0].min())
55
  else:
56
  bool_dwi = True
57
+
58
  else:
59
  bool_dwi = True
60
 
61
  nonzero_vals_adc = adc[mask > 0]
62
+
63
  if len(nonzero_vals_adc) > 0:
64
  min_val = nonzero_vals_adc.min()
65
  max_val = nonzero_vals_adc.max()
66
  heatmap_adc = np.zeros_like(adc, dtype=np.float32)
67
+
68
  if min_val != max_val:
69
  heatmap_adc = (max_val - adc) / (max_val - min_val)
70
+ masked_heatmap_adc = np.where(mask > 0, heatmap_adc, heatmap_adc[mask > 0].min())
71
  else:
72
  bool_adc = True
73
 
74
  else:
75
  bool_adc = True
 
76
 
 
 
 
 
77
  if not bool_dwi and not bool_adc:
78
  mix_mask = masked_heatmap_dwi * masked_heatmap_adc
79
+ write_header = header_dwi
80
+ elif bool_dwi:
81
+ mix_mask = masked_heatmap_adc
82
+ write_header = header_adc
83
+ elif bool_adc:
84
+ mix_mask = masked_heatmap_dwi
85
+ write_header = header_dwi
86
  else:
87
  mix_mask = np.ones_like(adc, dtype=np.float32)
88
+ write_header = header_dwi
89
 
90
  mix_mask = (mix_mask - mix_mask.min()) / (mix_mask.max() - mix_mask.min())
91
+ nrrd.write(os.path.join(args.heatmapdir, file), mix_mask, write_header)
92
 
 
 
 
93
  return args
 
 
 
 
src/preprocessing/histogram_match.py CHANGED
@@ -1,16 +1,30 @@
1
- import SimpleITK as sitk
2
  import os
3
- import numpy as np
4
  import nrrd
5
- from tqdm import tqdm
6
- import pandas as pd
7
- import random
8
- import json
9
  from skimage import exposure
10
- import multiprocessing
11
  import logging
12
- def get_histmatched(data, ref_data, mask, ref_mask):
 
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  source_pixels = data[mask > 0]
15
  ref_pixels = ref_data[ref_mask > 0]
16
  matched_pixels = exposure.match_histograms(source_pixels, ref_pixels)
@@ -19,36 +33,35 @@ def get_histmatched(data, ref_data, mask, ref_mask):
19
 
20
  return matched_img
21
 
22
- def histmatch(args):
23
 
 
24
  files = os.listdir(args.t2_dir)
25
-
26
- t2_histmatched_dir = os.path.join(args.output_dir, 't2_histmatched')
27
- dwi_histmatched_dir = os.path.join(args.output_dir, 'DWI_histmatched')
28
- adc_histmatched_dir = os.path.join(args.output_dir, 'ADC_histmatched')
29
  os.makedirs(t2_histmatched_dir, exist_ok=True)
30
  os.makedirs(dwi_histmatched_dir, exist_ok=True)
31
  os.makedirs(adc_histmatched_dir, exist_ok=True)
32
  logging.info("Starting histogram matching")
33
- for file in files:
34
 
 
35
  t2_image, header_t2 = nrrd.read(os.path.join(args.t2_dir, file))
36
  dwi_image, header_dwi = nrrd.read(os.path.join(args.dwi_dir, file))
37
  adc_image, header_adc = nrrd.read(os.path.join(args.adc_dir, file))
38
 
39
- ref_t2, _ = nrrd.read(os.path.join(args.project_dir, 'dataset', 't2_reference.nrrd'))
40
- ref_dwi, _ = nrrd.read(os.path.join(args.project_dir, 'dataset', 'dwi_reference.nrrd'))
41
- ref_adc , _ = nrrd.read(os.path.join(args.project_dir, 'dataset', 'adc_reference.nrrd'))
42
-
43
  prostate_mask, _ = nrrd.read(os.path.join(args.seg_dir, file))
44
- ref_prostate_mask, _ = nrrd.read(os.path.join(args.project_dir, 'dataset', 'prostate_segmentation_reference.nrrd'))
 
 
45
 
46
  histmatched_t2 = get_histmatched(t2_image, ref_t2, prostate_mask, ref_prostate_mask)
47
  histmatched_dwi = get_histmatched(dwi_image, ref_dwi, prostate_mask, ref_prostate_mask)
48
  histmatched_adc = get_histmatched(adc_image, ref_adc, prostate_mask, ref_prostate_mask)
49
 
50
-
51
-
52
  nrrd.write(os.path.join(t2_histmatched_dir, file), histmatched_t2, header_t2)
53
  nrrd.write(os.path.join(dwi_histmatched_dir, file), histmatched_dwi, header_dwi)
54
  nrrd.write(os.path.join(adc_histmatched_dir, file), histmatched_adc, header_adc)
@@ -58,5 +71,3 @@ def histmatch(args):
58
  args.adc_dir = adc_histmatched_dir
59
 
60
  return args
61
-
62
-
 
 
1
  import os
 
2
  import nrrd
 
 
 
 
3
  from skimage import exposure
 
4
  import logging
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
 
9
+ def get_histmatched(
10
+ data: np.ndarray, ref_data: np.ndarray, mask: np.ndarray, ref_mask: np.ndarray
11
+ ) -> np.ndarray:
12
+ """
13
+ Perform histogram matching on source data using a reference image.
14
+ This function adjusts the histogram of the source image to match the
15
+ histogram of the reference image within masked regions. Only pixels
16
+ where the mask is greater than 0 are considered for matching.
17
+ Args:
18
+ data: Source image array to be histogram matched.
19
+ ref_data: Reference image array whose histogram will be used as target.
20
+ mask: Binary mask for source image indicating valid pixels (values > 0).
21
+ ref_mask: Binary mask for reference image indicating valid pixels (values > 0).
22
+ Returns:
23
+ Histogram-matched image with the same shape as input data.
24
+ Only pixels in masked regions are modified; unmasked pixels remain unchanged.
25
+ Example:
26
+ >>> matched = get_histmatched(source_img, reference_img, source_mask, ref_mask)
27
+ """
28
  source_pixels = data[mask > 0]
29
  ref_pixels = ref_data[ref_mask > 0]
30
  matched_pixels = exposure.match_histograms(source_pixels, ref_pixels)
 
33
 
34
  return matched_img
35
 
 
36
 
37
+ def histmatch(args):
38
  files = os.listdir(args.t2_dir)
39
+
40
+ t2_histmatched_dir = os.path.join(args.output_dir, "t2_histmatched")
41
+ dwi_histmatched_dir = os.path.join(args.output_dir, "DWI_histmatched")
42
+ adc_histmatched_dir = os.path.join(args.output_dir, "ADC_histmatched")
43
  os.makedirs(t2_histmatched_dir, exist_ok=True)
44
  os.makedirs(dwi_histmatched_dir, exist_ok=True)
45
  os.makedirs(adc_histmatched_dir, exist_ok=True)
46
  logging.info("Starting histogram matching")
 
47
 
48
+ for file in tqdm(files):
49
  t2_image, header_t2 = nrrd.read(os.path.join(args.t2_dir, file))
50
  dwi_image, header_dwi = nrrd.read(os.path.join(args.dwi_dir, file))
51
  adc_image, header_adc = nrrd.read(os.path.join(args.adc_dir, file))
52
 
53
+ ref_t2, _ = nrrd.read(os.path.join(args.project_dir, "dataset", "t2_reference.nrrd"))
54
+ ref_dwi, _ = nrrd.read(os.path.join(args.project_dir, "dataset", "dwi_reference.nrrd"))
55
+ ref_adc, _ = nrrd.read(os.path.join(args.project_dir, "dataset", "adc_reference.nrrd"))
 
56
  prostate_mask, _ = nrrd.read(os.path.join(args.seg_dir, file))
57
+ ref_prostate_mask, _ = nrrd.read(
58
+ os.path.join(args.project_dir, "dataset", "prostate_segmentation_reference.nrrd")
59
+ )
60
 
61
  histmatched_t2 = get_histmatched(t2_image, ref_t2, prostate_mask, ref_prostate_mask)
62
  histmatched_dwi = get_histmatched(dwi_image, ref_dwi, prostate_mask, ref_prostate_mask)
63
  histmatched_adc = get_histmatched(adc_image, ref_adc, prostate_mask, ref_prostate_mask)
64
 
 
 
65
  nrrd.write(os.path.join(t2_histmatched_dir, file), histmatched_t2, header_t2)
66
  nrrd.write(os.path.join(dwi_histmatched_dir, file), histmatched_dwi, header_dwi)
67
  nrrd.write(os.path.join(adc_histmatched_dir, file), histmatched_adc, header_adc)
 
71
  args.adc_dir = adc_histmatched_dir
72
 
73
  return args
 
 
src/preprocessing/prostate_mask.py CHANGED
@@ -1,65 +1,63 @@
1
  import os
2
- from typing import Union
3
- import SimpleITK as sitk
4
  import numpy as np
5
  import nrrd
6
- import matplotlib.pyplot as plt
7
  from tqdm import tqdm
8
- from AIAH_utility.viewer import BasicViewer, ListViewer
9
- from PIL import Image
10
- import monai
11
  from monai.bundle import ConfigParser
12
- from monai.config import print_config
13
  import torch
14
- import sys
15
- import os
16
- import nibabel as nib
17
- import shutil
18
 
19
- from tqdm import trange, tqdm
20
 
21
- from monai.data import DataLoader, Dataset, TestTimeAugmentation, create_test_image_2d
22
- from monai.losses import DiceLoss
23
- from monai.metrics import DiceMetric
24
- from monai.networks.nets import UNet
25
  from monai.transforms import (
26
- Activationsd,
27
- AsDiscreted,
28
  Compose,
29
- CropForegroundd,
30
- DivisiblePadd,
31
- Invertd,
32
  LoadImaged,
33
  ScaleIntensityd,
34
- RandRotated,
35
- RandRotate,
36
- InvertibleTransform,
37
- RandFlipd,
38
- Activations,
39
- AsDiscrete,
40
  NormalizeIntensityd,
41
  )
42
  from monai.utils import set_determinism
43
  from monai.transforms import (
44
- Resize,
45
  EnsureChannelFirstd,
46
  Orientationd,
47
  Spacingd,
48
  EnsureTyped,
49
  )
50
- import nrrd
51
-
52
- set_determinism(43)
53
  from monai.data import MetaTensor
54
- import SimpleITK as sitk
55
- import pandas as pd
56
  import logging
 
 
 
 
57
  def get_segmask(args):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  args.seg_dir = os.path.join(args.output_dir, "prostate_mask")
60
  os.makedirs(args.seg_dir, exist_ok=True)
61
 
62
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
63
  model_config_file = os.path.join(args.project_dir, "config", "inference.json")
64
  model_config = ConfigParser()
65
  model_config.read_config(model_config_file)
@@ -67,24 +65,16 @@ def get_segmask(args):
67
  model_config["dataset_dir"] = args.t2_dir
68
  files = os.listdir(args.t2_dir)
69
  model_config["datalist"] = [os.path.join(args.t2_dir, f) for f in files]
70
-
71
-
72
  checkpoint = os.path.join(
73
  args.project_dir,
74
  "models",
75
  "prostate_segmentation_model.pt",
76
  )
77
- preprocessing = model_config.get_parsed_content("preprocessing")
78
  model = model_config.get_parsed_content("network_def").to(device)
79
  inferer = model_config.get_parsed_content("inferer")
80
- postprocessing = model_config.get_parsed_content("postprocessing")
81
- dataloader = model_config.get_parsed_content("dataloader")
82
  model.load_state_dict(torch.load(checkpoint, map_location=device))
83
  model.eval()
84
 
85
- torch.cuda.empty_cache()
86
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
87
-
88
  keys = "image"
89
  transform = Compose(
90
  [
@@ -99,8 +89,8 @@ def get_segmask(args):
99
  )
100
  logging.info("Starting prostate segmentation")
101
  for file in tqdm(files):
102
-
103
  data = {"image": os.path.join(args.t2_dir, file)}
 
104
  transformed_data = transform(data)
105
  a = transformed_data
106
  with torch.no_grad():
@@ -114,15 +104,12 @@ def get_segmask(args):
114
  temp = transform.inverse(transformed_data)
115
  pred_temp = temp["image"][0].numpy()
116
  pred_nrrd = np.round(pred_temp)
117
-
118
- nonzero_counts = np.count_nonzero(pred_nrrd, axis=(0,1))
119
  top_slices = np.argsort(nonzero_counts)[-10:]
120
  output_ = np.zeros_like(pred_nrrd)
121
- output_[:,:,top_slices] = pred_nrrd[:,:,top_slices]
122
-
123
- nrrd.write(os.path.join(args.seg_dir, file), output_)
124
-
125
- return args
126
 
 
127
 
128
-
 
1
  import os
 
 
2
  import numpy as np
3
  import nrrd
 
4
  from tqdm import tqdm
 
 
 
5
  from monai.bundle import ConfigParser
 
6
  import torch
 
 
 
 
7
 
 
8
 
 
 
 
 
9
  from monai.transforms import (
 
 
10
  Compose,
 
 
 
11
  LoadImaged,
12
  ScaleIntensityd,
 
 
 
 
 
 
13
  NormalizeIntensityd,
14
  )
15
  from monai.utils import set_determinism
16
  from monai.transforms import (
 
17
  EnsureChannelFirstd,
18
  Orientationd,
19
  Spacingd,
20
  EnsureTyped,
21
  )
 
 
 
22
  from monai.data import MetaTensor
 
 
23
  import logging
24
+
25
+ set_determinism(43)
26
+
27
+
28
  def get_segmask(args):
29
+ """
30
+ Generate prostate segmentation masks using a pre-trained deep learning model.
31
+ This function performs inference on T2-weighted MRI images to segment the prostate gland.
32
+ It applies preprocessing transformations, runs the segmentation model, and saves the
33
+ predicted masks. Post-processing is applied to retain only the top 10 slices with
34
+ the highest non-zero voxel counts.
35
+ Args:
36
+ args: An arguments object containing:
37
+ - output_dir (str): Base output directory where segmentation masks will be saved
38
+ - project_dir (str): Root project directory containing model config and checkpoint
39
+ - t2_dir (str): Directory containing input T2-weighted MRI images in NRRD format
40
+ Returns:
41
+ args: The updated arguments object with seg_dir added, pointing to the
42
+ prostate_mask subdirectory within output_dir
43
+ Raises:
44
+ FileNotFoundError: If the model checkpoint or config file is not found
45
+ RuntimeError: If CUDA operations fail on GPU
46
+ Notes:
47
+ - Automatically selects GPU (CUDA) if available, otherwise uses CPU
48
+ - Applies MONAI transformations: loading, orientation (RAS), spacing (0.5mm isotropic),
49
+ intensity scaling and normalization
50
+ - Post-processing filters predictions to top 10 slices by non-zero voxel density
51
+ - Output masks are saved in NRRD format preserving original image headers
52
+ """
53
 
54
  args.seg_dir = os.path.join(args.output_dir, "prostate_mask")
55
  os.makedirs(args.seg_dir, exist_ok=True)
56
 
57
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
+ torch.cuda.empty_cache()
59
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
60
+
61
  model_config_file = os.path.join(args.project_dir, "config", "inference.json")
62
  model_config = ConfigParser()
63
  model_config.read_config(model_config_file)
 
65
  model_config["dataset_dir"] = args.t2_dir
66
  files = os.listdir(args.t2_dir)
67
  model_config["datalist"] = [os.path.join(args.t2_dir, f) for f in files]
 
 
68
  checkpoint = os.path.join(
69
  args.project_dir,
70
  "models",
71
  "prostate_segmentation_model.pt",
72
  )
 
73
  model = model_config.get_parsed_content("network_def").to(device)
74
  inferer = model_config.get_parsed_content("inferer")
 
 
75
  model.load_state_dict(torch.load(checkpoint, map_location=device))
76
  model.eval()
77
 
 
 
 
78
  keys = "image"
79
  transform = Compose(
80
  [
 
89
  )
90
  logging.info("Starting prostate segmentation")
91
  for file in tqdm(files):
 
92
  data = {"image": os.path.join(args.t2_dir, file)}
93
+ _, header_t2 = nrrd.read(data["image"])
94
  transformed_data = transform(data)
95
  a = transformed_data
96
  with torch.no_grad():
 
104
  temp = transform.inverse(transformed_data)
105
  pred_temp = temp["image"][0].numpy()
106
  pred_nrrd = np.round(pred_temp)
107
+
108
+ nonzero_counts = np.count_nonzero(pred_nrrd, axis=(0, 1))
109
  top_slices = np.argsort(nonzero_counts)[-10:]
110
  output_ = np.zeros_like(pred_nrrd)
111
+ output_[:, :, top_slices] = pred_nrrd[:, :, top_slices]
 
 
 
 
112
 
113
+ nrrd.write(os.path.join(args.seg_dir, file), output_, header_t2)
114
 
115
+ return args
src/preprocessing/register_and_crop.py CHANGED
@@ -1,25 +1,42 @@
1
- import SimpleITK as sitk
2
  import os
3
- import numpy as np
4
- import nrrd
5
  from tqdm import tqdm
6
- import pandas as pd
7
  from picai_prep.preprocessing import PreprocessingSettings, Sample
8
- import multiprocessing
9
  from .center_crop import crop
10
  import logging
 
 
11
  def register_files(args):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  files = os.listdir(args.t2_dir)
13
- new_spacing = (0.4, 0.4, 3.0)
14
- t2_registered_dir = os.path.join(args.output_dir, 't2_registered')
15
- dwi_registered_dir = os.path.join(args.output_dir, 'DWI_registered')
16
- adc_registered_dir = os.path.join(args.output_dir, 'ADC_registered')
17
  os.makedirs(t2_registered_dir, exist_ok=True)
18
  os.makedirs(dwi_registered_dir, exist_ok=True)
19
  os.makedirs(adc_registered_dir, exist_ok=True)
20
  logging.info("Starting registration and cropping")
21
  for file in tqdm(files):
22
-
23
  t2_image = sitk.ReadImage(os.path.join(args.t2_dir, file))
24
  dwi_image = sitk.ReadImage(os.path.join(args.dwi_dir, file))
25
  adc_image = sitk.ReadImage(os.path.join(args.adc_dir, file))
@@ -30,38 +47,37 @@ def register_files(args):
30
  int(round(osz * ospc / nspc))
31
  for osz, ospc, nspc in zip(original_size, original_spacing, new_spacing)
32
  ]
33
-
34
  images_to_preprocess = {}
35
- images_to_preprocess['t2'] = t2_image
36
- images_to_preprocess['hbv'] = dwi_image
37
- images_to_preprocess['adc'] = adc_image
38
 
39
  pat_case = Sample(
40
  scans=[
41
- images_to_preprocess.get('t2'),
42
- images_to_preprocess.get('hbv'),
43
- images_to_preprocess.get('adc'),
44
  ],
45
- settings=PreprocessingSettings(spacing=[3.0,0.4,0.4], matrix_size=[new_size[2],new_size[1],new_size[0]]),
 
 
46
  )
47
  pat_case.preprocess()
48
-
49
- t2_post = pat_case.__dict__['scans'][0]
50
- dwi_post = pat_case.__dict__['scans'][1]
51
- adc_post = pat_case.__dict__['scans'][2]
52
  cropped_t2 = crop(t2_post, [args.margin, args.margin, 0.0])
53
  cropped_dwi = crop(dwi_post, [args.margin, args.margin, 0.0])
54
  cropped_adc = crop(adc_post, [args.margin, args.margin, 0.0])
55
 
56
-
57
-
58
  sitk.WriteImage(cropped_t2, os.path.join(t2_registered_dir, file))
59
  sitk.WriteImage(cropped_dwi, os.path.join(dwi_registered_dir, file))
60
  sitk.WriteImage(cropped_adc, os.path.join(adc_registered_dir, file))
61
 
62
- args.t2_dir = t2_registered_dir
63
- args.dwi_dir = dwi_registered_dir
64
- args.adc_dir = adc_registered_dir
65
-
66
- return args
67
 
 
 
1
+ import SimpleITK as sitk
2
  import os
 
 
3
  from tqdm import tqdm
 
4
  from picai_prep.preprocessing import PreprocessingSettings, Sample
 
5
  from .center_crop import crop
6
  import logging
7
+
8
+
9
  def register_files(args):
10
+ """
11
+ Register and crop medical images (T2, DWI, and ADC) to a standardized spacing and size.
12
+ This function reads medical images from specified directories, resamples them to a
13
+ new spacing of (0.4, 0.4, 3.0) mm, preprocesses them using the Sample class, and crops
14
+ them with specified margins. The processed images are saved to new output directories.
15
+ Args:
16
+ args: An argument object containing:
17
+ - t2_dir (str): Directory path containing T2 weighted images
18
+ - dwi_dir (str): Directory path containing DWI (Diffusion Weighted Imaging) images
19
+ - adc_dir (str): Directory path containing ADC (Apparent Diffusion Coefficient) images
20
+ - output_dir (str): Directory path where registered images will be saved
21
+ - margin (float): Margin in mm to crop from x and y dimensions
22
+ Returns:
23
+ args: Updated argument object with modified directory paths pointing to the
24
+ registered image directories (t2_registered, DWI_registered, ADC_registered)
25
+ Raises:
26
+ FileNotFoundError: If input directories do not exist or files cannot be read
27
+ RuntimeError: If image preprocessing or cropping fails
28
+ """
29
+
30
  files = os.listdir(args.t2_dir)
31
+ new_spacing = (0.4, 0.4, 3.0)
32
+ t2_registered_dir = os.path.join(args.output_dir, "t2_registered")
33
+ dwi_registered_dir = os.path.join(args.output_dir, "DWI_registered")
34
+ adc_registered_dir = os.path.join(args.output_dir, "ADC_registered")
35
  os.makedirs(t2_registered_dir, exist_ok=True)
36
  os.makedirs(dwi_registered_dir, exist_ok=True)
37
  os.makedirs(adc_registered_dir, exist_ok=True)
38
  logging.info("Starting registration and cropping")
39
  for file in tqdm(files):
 
40
  t2_image = sitk.ReadImage(os.path.join(args.t2_dir, file))
41
  dwi_image = sitk.ReadImage(os.path.join(args.dwi_dir, file))
42
  adc_image = sitk.ReadImage(os.path.join(args.adc_dir, file))
 
47
  int(round(osz * ospc / nspc))
48
  for osz, ospc, nspc in zip(original_size, original_spacing, new_spacing)
49
  ]
50
+
51
  images_to_preprocess = {}
52
+ images_to_preprocess["t2"] = t2_image
53
+ images_to_preprocess["hbv"] = dwi_image
54
+ images_to_preprocess["adc"] = adc_image
55
 
56
  pat_case = Sample(
57
  scans=[
58
+ images_to_preprocess.get("t2"),
59
+ images_to_preprocess.get("hbv"),
60
+ images_to_preprocess.get("adc"),
61
  ],
62
+ settings=PreprocessingSettings(
63
+ spacing=[3.0, 0.4, 0.4], matrix_size=[new_size[2], new_size[1], new_size[0]]
64
+ ),
65
  )
66
  pat_case.preprocess()
67
+
68
+ t2_post = pat_case.__dict__["scans"][0]
69
+ dwi_post = pat_case.__dict__["scans"][1]
70
+ adc_post = pat_case.__dict__["scans"][2]
71
  cropped_t2 = crop(t2_post, [args.margin, args.margin, 0.0])
72
  cropped_dwi = crop(dwi_post, [args.margin, args.margin, 0.0])
73
  cropped_adc = crop(adc_post, [args.margin, args.margin, 0.0])
74
 
 
 
75
  sitk.WriteImage(cropped_t2, os.path.join(t2_registered_dir, file))
76
  sitk.WriteImage(cropped_dwi, os.path.join(dwi_registered_dir, file))
77
  sitk.WriteImage(cropped_adc, os.path.join(adc_registered_dir, file))
78
 
79
+ args.t2_dir = t2_registered_dir
80
+ args.dwi_dir = dwi_registered_dir
81
+ args.adc_dir = adc_registered_dir
 
 
82
 
83
+ return args
src/train/train_cspca.py CHANGED
@@ -1,91 +1,18 @@
1
- import argparse
2
- import collections.abc
3
- import os
4
- import shutil
5
- import time
6
- import yaml
7
- from scipy.stats import pearsonr
8
- import gdown
9
- import numpy as np
10
  import torch
11
- import torch.distributed as dist
12
- import torch.multiprocessing as mp
13
  import torch.nn as nn
14
- import torch.nn.functional as F
15
- from monai.config import KeysCollection
16
- from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
17
- from monai.data.wsi_reader import WSIReader
18
  from monai.metrics import Cumulative, CumulativeAverage
19
- from monai.networks.nets import milmodel, resnet, MILModel
20
- from monai.transforms import (
21
- Compose,
22
- GridPatchd,
23
- LoadImaged,
24
- MapTransform,
25
- RandFlipd,
26
- RandGridPatchd,
27
- RandRotate90d,
28
- ScaleIntensityRanged,
29
- SplitDimd,
30
- ToTensord,
31
- ConcatItemsd,
32
- SelectItemsd,
33
- EnsureChannelFirstd,
34
- RepeatChanneld,
35
- DeleteItemsd,
36
- EnsureTyped,
37
- ClipIntensityPercentilesd,
38
- MaskIntensityd,
39
- HistogramNormalized,
40
- RandBiasFieldd,
41
- RandCropByPosNegLabeld,
42
- NormalizeIntensityd,
43
- SqueezeDimd,
44
- CropForegroundd,
45
- ScaleIntensityd,
46
- SpatialPadd,
47
- CenterSpatialCropd,
48
- ScaleIntensityd,
49
- Transposed,
50
- RandWeightedCropd,
51
- )
52
- from sklearn.metrics import cohen_kappa_score, roc_curve, confusion_matrix
53
- from torch.cuda.amp import GradScaler, autocast
54
- from torch.utils.data.dataloader import default_collate
55
- from torchvision.models.resnet import ResNet50_Weights
56
- import torch.optim as optim
57
- from torch.utils.data.distributed import DistributedSampler
58
- from torch.utils.tensorboard import SummaryWriter
59
- import matplotlib.pyplot as plt
60
- import matplotlib.patches as patches
61
- from tqdm import tqdm
62
- from sklearn.metrics import confusion_matrix, roc_auc_score
63
  from sklearn.metrics import roc_auc_score
64
- from sklearn.preprocessing import label_binarize
65
- import numpy as np
66
- from AIAH_utility.viewer import BasicViewer
67
- from scipy.special import expit
68
- import nrrd
69
- import random
70
- from sklearn.metrics import roc_auc_score
71
- import SimpleITK as sitk
72
- from AIAH_utility.viewer import BasicViewer
73
- import pandas as pd
74
- import json
75
- from sklearn.preprocessing import StandardScaler
76
- from torch.utils.data import DataLoader, TensorDataset, Dataset
77
- from sklearn.linear_model import LogisticRegression
78
- from sklearn.utils import resample
79
- import monai
80
 
81
  def train_epoch(cspca_model, loader, optimizer, epoch, args):
82
  cspca_model.train()
83
- criterion = nn.BCELoss()
84
  loss = 0.0
85
  run_loss = CumulativeAverage()
86
  TARGETS = Cumulative()
87
  PREDS = Cumulative()
88
-
89
  for idx, batch_data in enumerate(loader):
90
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
91
  target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
@@ -108,9 +35,10 @@ def train_epoch(cspca_model, loader, optimizer, epoch, args):
108
 
109
  return loss_epoch, auc_epoch
110
 
 
111
  def val_epoch(cspca_model, loader, epoch, args):
112
  cspca_model.eval()
113
- criterion = nn.BCELoss()
114
  loss = 0.0
115
  run_loss = CumulativeAverage()
116
  TARGETS = Cumulative()
@@ -132,10 +60,15 @@ def val_epoch(cspca_model, loader, epoch, args):
132
  target_list = TARGETS.get_buffer().cpu().numpy()
133
  pred_list = PREDS.get_buffer().cpu().numpy()
134
  auc_epoch = roc_auc_score(target_list, pred_list)
135
- y_pred_categoric = (pred_list >= 0.5)
136
  tn, fp, fn, tp = confusion_matrix(target_list, y_pred_categoric).ravel()
137
  sens_epoch = tp / (tp + fn)
138
  spec_epoch = tn / (tn + fp)
139
- val_epoch_metric = {'epoch': epoch, 'loss': loss_epoch, 'auc': auc_epoch, 'sensitivity': sens_epoch, 'specificity': spec_epoch}
 
 
 
 
 
 
140
  return val_epoch_metric
141
-
 
 
 
 
 
 
 
 
 
 
1
  import torch
 
 
2
  import torch.nn as nn
 
 
 
 
3
  from monai.metrics import Cumulative, CumulativeAverage
4
+ from sklearn.metrics import confusion_matrix
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from sklearn.metrics import roc_auc_score
6
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  def train_epoch(cspca_model, loader, optimizer, epoch, args):
9
  cspca_model.train()
10
+ criterion = nn.BCELoss()
11
  loss = 0.0
12
  run_loss = CumulativeAverage()
13
  TARGETS = Cumulative()
14
  PREDS = Cumulative()
15
+
16
  for idx, batch_data in enumerate(loader):
17
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
18
  target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
 
35
 
36
  return loss_epoch, auc_epoch
37
 
38
+
39
  def val_epoch(cspca_model, loader, epoch, args):
40
  cspca_model.eval()
41
+ criterion = nn.BCELoss()
42
  loss = 0.0
43
  run_loss = CumulativeAverage()
44
  TARGETS = Cumulative()
 
60
  target_list = TARGETS.get_buffer().cpu().numpy()
61
  pred_list = PREDS.get_buffer().cpu().numpy()
62
  auc_epoch = roc_auc_score(target_list, pred_list)
63
+ y_pred_categoric = pred_list >= 0.5
64
  tn, fp, fn, tp = confusion_matrix(target_list, y_pred_categoric).ravel()
65
  sens_epoch = tp / (tp + fn)
66
  spec_epoch = tn / (tn + fp)
67
+ val_epoch_metric = {
68
+ "epoch": epoch,
69
+ "loss": loss_epoch,
70
+ "auc": auc_epoch,
71
+ "sensitivity": sens_epoch,
72
+ "specificity": spec_epoch,
73
+ }
74
  return val_epoch_metric
 
src/train/train_pirads.py CHANGED
@@ -1,65 +1,11 @@
1
- import argparse
2
- import collections.abc
3
- import os
4
- import shutil
5
  import time
6
- import yaml
7
- import sys
8
- import gdown
9
  import numpy as np
10
  import torch
11
- import torch.distributed as dist
12
- import torch.multiprocessing as mp
13
  import torch.nn as nn
14
- import torch.nn.functional as F
15
- from monai.config import KeysCollection
16
- from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
17
- from monai.data.wsi_reader import WSIReader
18
  from monai.metrics import Cumulative, CumulativeAverage
19
- from monai.networks.nets import milmodel, resnet, MILModel
20
- from monai.transforms import (
21
- Compose,
22
- GridPatchd,
23
- LoadImaged,
24
- MapTransform,
25
- RandFlipd,
26
- RandGridPatchd,
27
- RandRotate90d,
28
- ScaleIntensityRanged,
29
- SplitDimd,
30
- ToTensord,
31
- ConcatItemsd,
32
- SelectItemsd,
33
- EnsureChannelFirstd,
34
- RepeatChanneld,
35
- DeleteItemsd,
36
- EnsureTyped,
37
- ClipIntensityPercentilesd,
38
- MaskIntensityd,
39
- HistogramNormalized,
40
- RandBiasFieldd,
41
- RandCropByPosNegLabeld,
42
- NormalizeIntensityd,
43
- SqueezeDimd,
44
- CropForegroundd,
45
- ScaleIntensityd,
46
- SpatialPadd,
47
- CenterSpatialCropd,
48
- ScaleIntensityd,
49
- Transposed,
50
- RandWeightedCropd,
51
- )
52
  from sklearn.metrics import cohen_kappa_score
53
- from torch.cuda.amp import GradScaler, autocast
54
- from torch.utils.data.dataloader import default_collate
55
- from torch.utils.tensorboard import SummaryWriter
56
- import matplotlib.pyplot as plt
57
- import wandb
58
- import math
59
  import logging
60
- from pathlib import Path
61
- from src.data.data_loader import get_dataloader
62
- from src.utils import save_pirads_checkpoint, setup_logging
63
 
64
  def get_lambda_att(epoch, max_lambda=2.0, warmup_epochs=10):
65
  if epoch < warmup_epochs:
@@ -67,26 +13,53 @@ def get_lambda_att(epoch, max_lambda=2.0, warmup_epochs=10):
67
  else:
68
  return max_lambda
69
 
 
70
  def get_attention_scores(data, target, heatmap, args):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  attention_score = torch.zeros((data.shape[0], data.shape[1]))
72
  for i in range(data.shape[0]):
73
  sample = heatmap[i]
74
  heatmap_patches = sample.squeeze(1)
75
- raw_scores = heatmap_patches.view(len(heatmap_patches), -1).sum(dim=1)
76
- attention_score[i] = raw_scores / raw_scores.sum()
77
  shuffled_images = torch.empty_like(data).to(args.device)
78
  att_labels = torch.empty_like(attention_score).to(args.device)
79
- for i in range(data.shape[0]):
80
- perm = torch.randperm(data.shape[1])
81
  shuffled_images[i] = data[i, perm]
82
- att_labels[i] = attention_score[i, perm]
83
-
84
- att_labels[torch.argwhere(target < 1)] = torch.ones_like(att_labels[0]) / len(att_labels[0])# Setting attention scores for cases
85
- att_labels = att_labels ** 2 # Sharpening
 
 
86
  att_labels = att_labels / att_labels.sum(dim=1, keepdim=True)
87
-
88
  return att_labels, shuffled_images
89
 
 
90
  def train_epoch(model, loader, optimizer, scaler, epoch, args):
91
  """One train epoch over the dataset"""
92
  lambda_att = get_lambda_att(epoch, warmup_epochs=25)
@@ -104,13 +77,14 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
104
  loss, acc = 0.0, 0.0
105
 
106
  for idx, batch_data in enumerate(loader):
107
-
108
  eps = 1e-8
109
  data = batch_data["image"].as_subclass(torch.Tensor)
110
  target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
111
  target = target.long()
112
  if args.use_heatmap:
113
- att_labels, shuffled_images = get_attention_scores(data, target, batch_data['final_heatmap'], args)
 
 
114
  att_labels = att_labels + eps
115
  else:
116
  shuffled_images = data.to(args.device)
@@ -139,7 +113,7 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
139
  b = b + eps
140
  att_preds = torch.softmax(b, dim=1)
141
  attn_loss = 1 - att_criterion(att_preds, att_labels).mean()
142
- loss = class_loss + (lambda_att*attn_loss)
143
  else:
144
  loss = class_loss
145
  attn_loss = torch.tensor(0.0)
@@ -150,14 +124,10 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
150
  if not torch.isfinite(total_norm):
151
  logging.warning("Non-finite gradient norm detected, skipping batch.")
152
  optimizer.zero_grad()
 
153
  else:
154
  scaler.step(optimizer)
155
  scaler.update()
156
- shuffled_images = shuffled_images.to('cpu')
157
- logits = logits.to('cpu')
158
- logits_attn = logits_attn.to('cpu')
159
- target = target.to('cpu')
160
-
161
  batch_norm.append(total_norm)
162
  pred = torch.argmax(logits, dim=1)
163
  acc = (pred == target).sum() / len(pred)
@@ -171,15 +141,15 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
171
  args.epochs,
172
  idx,
173
  len(loader),
174
- loss,
175
- attn_loss,
176
  acc,
177
  total_norm,
178
- time.time() - start_time
179
  )
180
  )
181
  start_time = time.time()
182
-
183
  del data, target, shuffled_images, logits, logits_attn
184
  torch.cuda.empty_cache()
185
  batch_norm_epoch = batch_norm.aggregate()
@@ -189,9 +159,7 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
189
  return loss_epoch, acc_epoch, attn_loss_epoch, batch_norm_epoch
190
 
191
 
192
-
193
  def val_epoch(model, loader, epoch, args):
194
-
195
  criterion = nn.CrossEntropyLoss()
196
 
197
  run_loss = CumulativeAverage()
@@ -204,17 +172,17 @@ def val_epoch(model, loader, epoch, args):
204
  model.eval()
205
  with torch.no_grad():
206
  for idx, batch_data in enumerate(loader):
207
-
208
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
209
  target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
210
  target = target.long()
211
- with torch.cuda.amp.autocast(enabled=args.amp):
 
212
  logits = model(data)
213
  loss = criterion(logits, target)
214
 
215
- data = data.to('cpu')
216
- target = target.to('cpu')
217
- logits = logits.to('cpu')
218
  pred = torch.argmax(logits, dim=1)
219
  acc = (pred == target).sum() / len(target)
220
 
@@ -228,7 +196,7 @@ def val_epoch(model, loader, epoch, args):
228
  )
229
  )
230
  start_time = time.time()
231
-
232
  del data, target, logits
233
  torch.cuda.empty_cache()
234
 
@@ -237,10 +205,5 @@ def val_epoch(model, loader, epoch, args):
237
  TARGETS = TARGETS.get_buffer().cpu().numpy()
238
  loss_epoch = run_loss.aggregate()
239
  acc_epoch = run_acc.aggregate()
240
- qwk = cohen_kappa_score(TARGETS.astype(np.float64),PREDS.astype(np.float64))
241
  return loss_epoch, acc_epoch, qwk
242
-
243
-
244
-
245
-
246
-
 
 
 
 
 
1
  import time
 
 
 
2
  import numpy as np
3
  import torch
 
 
4
  import torch.nn as nn
 
 
 
 
5
  from monai.metrics import Cumulative, CumulativeAverage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from sklearn.metrics import cohen_kappa_score
 
 
 
 
 
 
7
  import logging
8
+
 
 
9
 
10
  def get_lambda_att(epoch, max_lambda=2.0, warmup_epochs=10):
11
  if epoch < warmup_epochs:
 
13
  else:
14
  return max_lambda
15
 
16
+
17
  def get_attention_scores(data, target, heatmap, args):
18
+ """
19
+ Compute attention scores from heatmaps and shuffle data accordingly.
20
+ This function generates attention scores based on spatial heatmaps, applies
21
+ sharpening, and creates shuffled versions of the input data and attention
22
+ labels. For PI-RADS 2 (target < 1), uniform attention scores are assigned.
23
+ Args:
24
+ data (torch.Tensor): Input data tensor of shape (batch_size, num_patches, ...).
25
+ target (torch.Tensor): Target labels tensor of shape (batch_size,).
26
+ heatmap (torch.Tensor): Attention heatmap tensor corresponding to input patches.
27
+ args: Arguments object containing device specification.
28
+ Returns:
29
+ tuple: A tuple containing:
30
+ - att_labels (torch.Tensor): Sharpened and normalized attention scores
31
+ of shape (batch_size, num_patches), moved to args.device.
32
+ - shuffled_images (torch.Tensor): Randomly permuted data samples
33
+ of shape (batch_size, num_patches, ...), moved to args.device.
34
+ Note:
35
+ - Attention scores are computed by summing heatmap values across spatial dimensions.
36
+ - Data and attention labels are shuffled with the same permutation per sample.
37
+ - PI-RADS 2 samples receive uniform attention distribution.
38
+ - Attention scores are squared for sharpening and then normalized.
39
+ """
40
+
41
  attention_score = torch.zeros((data.shape[0], data.shape[1]))
42
  for i in range(data.shape[0]):
43
  sample = heatmap[i]
44
  heatmap_patches = sample.squeeze(1)
45
+ raw_scores = heatmap_patches.view(len(heatmap_patches), -1).sum(dim=1)
46
+ attention_score[i] = raw_scores / raw_scores.sum()
47
  shuffled_images = torch.empty_like(data).to(args.device)
48
  att_labels = torch.empty_like(attention_score).to(args.device)
49
+ for i in range(data.shape[0]):
50
+ perm = torch.randperm(data.shape[1])
51
  shuffled_images[i] = data[i, perm]
52
+ att_labels[i] = attention_score[i, perm]
53
+
54
+ att_labels[torch.argwhere(target < 1)] = torch.ones_like(att_labels[0]) / len(
55
+ att_labels[0]
56
+ ) # For PI-RADS 2, uniform scores across patches
57
+ att_labels = att_labels**2 # Sharpening
58
  att_labels = att_labels / att_labels.sum(dim=1, keepdim=True)
59
+
60
  return att_labels, shuffled_images
61
 
62
+
63
  def train_epoch(model, loader, optimizer, scaler, epoch, args):
64
  """One train epoch over the dataset"""
65
  lambda_att = get_lambda_att(epoch, warmup_epochs=25)
 
77
  loss, acc = 0.0, 0.0
78
 
79
  for idx, batch_data in enumerate(loader):
 
80
  eps = 1e-8
81
  data = batch_data["image"].as_subclass(torch.Tensor)
82
  target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
83
  target = target.long()
84
  if args.use_heatmap:
85
+ att_labels, shuffled_images = get_attention_scores(
86
+ data, target, batch_data["final_heatmap"], args
87
+ )
88
  att_labels = att_labels + eps
89
  else:
90
  shuffled_images = data.to(args.device)
 
113
  b = b + eps
114
  att_preds = torch.softmax(b, dim=1)
115
  attn_loss = 1 - att_criterion(att_preds, att_labels).mean()
116
+ loss = class_loss + (lambda_att * attn_loss)
117
  else:
118
  loss = class_loss
119
  attn_loss = torch.tensor(0.0)
 
124
  if not torch.isfinite(total_norm):
125
  logging.warning("Non-finite gradient norm detected, skipping batch.")
126
  optimizer.zero_grad()
127
+ scaler.update()
128
  else:
129
  scaler.step(optimizer)
130
  scaler.update()
 
 
 
 
 
131
  batch_norm.append(total_norm)
132
  pred = torch.argmax(logits, dim=1)
133
  acc = (pred == target).sum() / len(pred)
 
141
  args.epochs,
142
  idx,
143
  len(loader),
144
+ loss.item(),
145
+ attn_loss.item(),
146
  acc,
147
  total_norm,
148
+ time.time() - start_time,
149
  )
150
  )
151
  start_time = time.time()
152
+
153
  del data, target, shuffled_images, logits, logits_attn
154
  torch.cuda.empty_cache()
155
  batch_norm_epoch = batch_norm.aggregate()
 
159
  return loss_epoch, acc_epoch, attn_loss_epoch, batch_norm_epoch
160
 
161
 
 
162
  def val_epoch(model, loader, epoch, args):
 
163
  criterion = nn.CrossEntropyLoss()
164
 
165
  run_loss = CumulativeAverage()
 
172
  model.eval()
173
  with torch.no_grad():
174
  for idx, batch_data in enumerate(loader):
 
175
  data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
176
  target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
177
  target = target.long()
178
+
179
+ with torch.amp.autocast(device_type=str(args.device), enabled=args.amp):
180
  logits = model(data)
181
  loss = criterion(logits, target)
182
 
183
+ data = data.to("cpu")
184
+ target = target.to("cpu")
185
+ logits = logits.to("cpu")
186
  pred = torch.argmax(logits, dim=1)
187
  acc = (pred == target).sum() / len(target)
188
 
 
196
  )
197
  )
198
  start_time = time.time()
199
+
200
  del data, target, logits
201
  torch.cuda.empty_cache()
202
 
 
205
  TARGETS = TARGETS.get_buffer().cpu().numpy()
206
  loss_epoch = run_loss.aggregate()
207
  acc_epoch = run_acc.aggregate()
208
+ qwk = cohen_kappa_score(TARGETS.astype(np.float64), PREDS.astype(np.float64))
209
  return loss_epoch, acc_epoch, qwk
 
 
 
 
 
src/utils.py CHANGED
@@ -1,97 +1,46 @@
1
- import argparse
2
  import os
3
- import shutil
4
- import time
5
- import yaml
6
  import sys
7
- import gdown
8
  import numpy as np
9
  import torch
10
- import torch.distributed as dist
11
- import torch.multiprocessing as mp
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from monai.config import KeysCollection
15
- from monai.metrics import Cumulative, CumulativeAverage
16
- from monai.networks.nets import milmodel, resnet, MILModel
17
  from monai.transforms import (
18
  Compose,
19
- GridPatchd,
20
  LoadImaged,
21
- MapTransform,
22
- RandFlipd,
23
- RandGridPatchd,
24
- RandRotate90d,
25
- ScaleIntensityRanged,
26
- SplitDimd,
27
  ToTensord,
28
- ConcatItemsd,
29
- SelectItemsd,
30
- EnsureChannelFirstd,
31
- RepeatChanneld,
32
- DeleteItemsd,
33
  EnsureTyped,
34
- ClipIntensityPercentilesd,
35
- MaskIntensityd,
36
- HistogramNormalized,
37
- RandBiasFieldd,
38
- RandCropByPosNegLabeld,
39
- NormalizeIntensityd,
40
- SqueezeDimd,
41
- CropForegroundd,
42
- ScaleIntensityd,
43
- SpatialPadd,
44
- CenterSpatialCropd,
45
- ScaleIntensityd,
46
- Transposed,
47
- RandWeightedCropd,
48
  )
49
- from sklearn.metrics import cohen_kappa_score
50
- from torch.cuda.amp import GradScaler, autocast
51
- from torch.utils.data.dataloader import default_collate
52
- from torchvision.models.resnet import ResNet50_Weights
53
  from .data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
54
- from torch.utils.data.distributed import DistributedSampler
55
- from torch.utils.tensorboard import SummaryWriter
56
- import matplotlib.patches as patches
57
-
58
- import matplotlib.pyplot as plt
59
-
60
- import wandb
61
- import math
62
- from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
63
-
64
- from src.model.MIL import MILModel_3D
65
- from src.model.csPCa_model import csPCa_Model
66
-
67
  import logging
68
  from pathlib import Path
 
69
 
70
- def save_pirads_checkpoint(model, epoch, args, filename="model.pth", best_acc=0):
71
 
72
- """Save checkpoint"""
 
73
 
74
  state_dict = model.state_dict()
75
-
76
  save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict}
77
-
78
  filename = os.path.join(args.logdir, filename)
79
  torch.save(save_dict, filename)
80
- logging.info("Saving checkpoint", filename)
 
81
 
82
  def save_cspca_checkpoint(model, val_metric, model_dir):
 
 
83
  state_dict = model.state_dict()
84
  save_dict = {
85
- 'epoch' : val_metric['epoch'],
86
- 'loss' : val_metric['loss'],
87
- 'auc' : val_metric['auc'],
88
- 'sensitivity' : val_metric['sensitivity'],
89
- 'specificity' : val_metric['specificity'],
90
- 'state' : val_metric['state'],
91
- 'state_dict' : state_dict,
92
  }
93
- torch.save(save_dict, os.path.join(model_dir,f"cspca_model.pth"))
94
- logging.info('Saving model with auc: ', str(val_metric['auc']))
 
95
 
96
  def get_metrics(metric_dict: dict):
97
  for metric_name, metric_list in metric_dict.items():
@@ -102,6 +51,7 @@ def get_metrics(metric_dict: dict):
102
  logging.info(f"Mean {metric_name}: {mean_metric:.3f}")
103
  logging.info(f"95% CI: ({lower:.3f}, {upper:.3f})")
104
 
 
105
  def setup_logging(log_file):
106
  log_file = Path(log_file)
107
  log_file.parent.mkdir(parents=True, exist_ok=True)
@@ -115,6 +65,7 @@ def setup_logging(log_file):
115
  ],
116
  )
117
 
 
118
  def validate_steps(steps):
119
  REQUIRES = {
120
  "get_segmentation_mask": ["register_and_crop"],
@@ -126,38 +77,63 @@ def validate_steps(steps):
126
  for req in required:
127
  if req not in steps[:i]:
128
  logging.error(
129
- f"Step '{step}' requires '{req}' to be executed before it. "
130
- f"Given order: {steps}"
131
  )
132
  sys.exit(1)
133
 
134
- def get_patch_coordinate(patches_top_5, parent_image, args):
135
 
136
- sample = np.array([i.transpose(1,2,0) for i in patches_top_5])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  coords = []
138
  rows, h, w, slices = sample.shape
139
 
140
  for i in range(rows):
141
- for j in range(slices):
142
- if j == 0:
143
- for k in range(parent_image.shape[2]):
144
- img_temp = parent_image[:, :, k]
145
- H, W = img_temp.shape
146
- h, w = sample[i, :, :, j].shape
147
- a,b = 0, 0 # Initialize a and b
148
- bool1 = False
149
- for l in range(H - h + 1):
150
- for m in range(W - w + 1):
151
- if np.array_equal(img_temp[l:l+h, m:m+w], sample[i, :, :, j]):
152
- a,b = l, m # top-left corner
153
- coords.append((a,b,k))
154
- bool1 = True
155
- break
156
- if bool1:
157
- break
158
-
159
- if bool1:
160
- break
 
161
 
162
  return coords
163
 
@@ -165,7 +141,12 @@ def get_patch_coordinate(patches_top_5, parent_image, args):
165
  def get_parent_image(temp_data_list, args):
166
  transform_image = Compose(
167
  [
168
- LoadImaged(keys=["image", "mask"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
 
 
 
 
 
169
  ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
170
  NormalizeIntensity_customd(keys=["image"], mask_key="mask", channel_wise=True),
171
  EnsureTyped(keys=["label"], dtype=torch.float32),
@@ -173,9 +154,10 @@ def get_parent_image(temp_data_list, args):
173
  ]
174
  )
175
  dataset_image = Dataset(data=temp_data_list, transform=transform_image)
176
- return dataset_image[0]['image'][0].numpy()
 
177
 
178
- '''
179
  def visualise_patches():
180
  sample = np.array([i.transpose(1,2,0) for i in patches_top_5])
181
  rows = len(patches_top_5)
@@ -223,4 +205,4 @@ def visualise_patches():
223
  plt.tight_layout()
224
  plt.show()
225
  a=1
226
- '''
 
 
1
  import os
 
 
 
2
  import sys
 
3
  import numpy as np
4
  import torch
 
 
 
 
 
 
 
5
  from monai.transforms import (
6
  Compose,
 
7
  LoadImaged,
 
 
 
 
 
 
8
  ToTensord,
 
 
 
 
 
9
  EnsureTyped,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  )
 
 
 
 
11
  from .data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
12
+ from monai.data import Dataset, ITKReader
 
 
 
 
 
 
 
 
 
 
 
 
13
  import logging
14
  from pathlib import Path
15
+ import cv2
16
 
 
17
 
18
+ def save_pirads_checkpoint(model, epoch, args, filename="model.pth", best_acc=0):
19
+ """Save checkpoint for the PI-RADS model"""
20
 
21
  state_dict = model.state_dict()
 
22
  save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict}
 
23
  filename = os.path.join(args.logdir, filename)
24
  torch.save(save_dict, filename)
25
+ logging.info(f"Saving checkpoint {filename}")
26
+
27
 
28
  def save_cspca_checkpoint(model, val_metric, model_dir):
29
+ """Save checkpoint for the csPCa model"""
30
+
31
  state_dict = model.state_dict()
32
  save_dict = {
33
+ "epoch": val_metric["epoch"],
34
+ "loss": val_metric["loss"],
35
+ "auc": val_metric["auc"],
36
+ "sensitivity": val_metric["sensitivity"],
37
+ "specificity": val_metric["specificity"],
38
+ "state": val_metric["state"],
39
+ "state_dict": state_dict,
40
  }
41
+ torch.save(save_dict, os.path.join(model_dir, "cspca_model.pth"))
42
+ logging.info(f"Saving model with auc: {val_metric['auc']}")
43
+
44
 
45
  def get_metrics(metric_dict: dict):
46
  for metric_name, metric_list in metric_dict.items():
 
51
  logging.info(f"Mean {metric_name}: {mean_metric:.3f}")
52
  logging.info(f"95% CI: ({lower:.3f}, {upper:.3f})")
53
 
54
+
55
  def setup_logging(log_file):
56
  log_file = Path(log_file)
57
  log_file.parent.mkdir(parents=True, exist_ok=True)
 
65
  ],
66
  )
67
 
68
+
69
  def validate_steps(steps):
70
  REQUIRES = {
71
  "get_segmentation_mask": ["register_and_crop"],
 
77
  for req in required:
78
  if req not in steps[:i]:
79
  logging.error(
80
+ f"Step '{step}' requires '{req}' to be executed before it. Given order: {steps}"
 
81
  )
82
  sys.exit(1)
83
 
 
84
 
85
+ def get_patch_coordinate(patches_top_5, parent_image):
86
+ """
87
+ Locate the coordinates of top-5 patches within a parent image.
88
+
89
+ This function searches for the spatial location of the first slice (j=0) of each
90
+ top-5 patch within the parent 3D image volume. It returns the top-left corner
91
+ coordinates (row, column) and the slice index where each patch is found.
92
+
93
+ Args:
94
+ patches_top_5 (list): List of top-5 patch tensors, each with shape (C, H, W)
95
+ where C is channels, H is height, W is width.
96
+ parent_image (np.ndarray): 3D image volume with shape (height, width, slices)
97
+ to search within.
98
+ args: Configuration arguments (currently unused in the function).
99
+
100
+ Returns:
101
+ list: List of tuples (row, col, slice_idx) representing the top-left corner
102
+ coordinates of each found patch in the parent image. Returns empty list
103
+ if no patches are found.
104
+
105
+ Note:
106
+ - Only searches for the first slice (j=0) of each patch.
107
+ - Uses exhaustive 2D spatial matching within each slice of the parent image.
108
+ - Returns coordinates of the first match found for each patch.
109
+ """
110
+
111
+ sample = np.array([i.transpose(1, 2, 0) for i in patches_top_5])
112
  coords = []
113
  rows, h, w, slices = sample.shape
114
 
115
  for i in range(rows):
116
+ template = sample[i, :, :, 0].astype(np.float32)
117
+ found = False
118
+ for k in list(range(parent_image.shape[2])):
119
+ img_slice = parent_image[:, :, k].astype(np.float32)
120
+ res = cv2.matchTemplate(img_slice, template, cv2.TM_CCOEFF_NORMED)
121
+ min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
122
+
123
+ if max_val >= 0.99:
124
+ x, y = max_loc # OpenCV returns (col, row) -> (x, y)
125
+
126
+ # 2. Verification Step: Check if it's actually the correct patch
127
+ # This mimics your original np.array_equal strictness
128
+ candidate_patch = img_slice[y : y + h, x : x + w]
129
+
130
+ if np.allclose(candidate_patch, template, atol=1e-5):
131
+ coords.append((y, x, k)) # Original code stored (row, col, slice)
132
+ found = True
133
+ break
134
+
135
+ if not found:
136
+ print("Patch not found")
137
 
138
  return coords
139
 
 
141
  def get_parent_image(temp_data_list, args):
142
  transform_image = Compose(
143
  [
144
+ LoadImaged(
145
+ keys=["image", "mask"],
146
+ reader=ITKReader(),
147
+ ensure_channel_first=True,
148
+ dtype=np.float32,
149
+ ),
150
  ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
151
  NormalizeIntensity_customd(keys=["image"], mask_key="mask", channel_wise=True),
152
  EnsureTyped(keys=["label"], dtype=torch.float32),
 
154
  ]
155
  )
156
  dataset_image = Dataset(data=temp_data_list, transform=transform_image)
157
+ return dataset_image[0]["image"][0].numpy()
158
+
159
 
160
+ """
161
  def visualise_patches():
162
  sample = np.array([i.transpose(1,2,0) for i in patches_top_5])
163
  rows = len(patches_top_5)
 
205
  plt.tight_layout()
206
  plt.show()
207
  a=1
208
+ """
temp.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
tests/test_run.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import sys
3
+ from pathlib import Path
4
+
5
+
6
+ def test_run_pirads_training():
7
+ """
8
+ Test that run_cspca.py runs without crashing using an existing YAML config.
9
+ """
10
+
11
+ # Path to your run_pirads.py script
12
+ repo_root = Path(__file__).parent.parent
13
+ script_path = repo_root / "run_pirads.py"
14
+
15
+ # Path to your existing config.yaml
16
+ config_path = repo_root / "config" / "config_pirads_train.yaml" # adjust this path
17
+
18
+ # Make sure the file exists
19
+ assert config_path.exists(), f"Config file not found: {config_path}"
20
+
21
+ # Run the script with the config
22
+ result = subprocess.run(
23
+ [sys.executable, str(script_path), "--mode", "train", "--config", str(config_path), "--dry_run", "True" ],
24
+ capture_output=True,
25
+ text=True,
26
+ )
27
+
28
+ # Check that it ran without errors
29
+ assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
30
+
31
+ def test_run_pirads_inference():
32
+ """
33
+ Test that run_cspca.py runs without crashing using an existing YAML config.
34
+ """
35
+
36
+ # Path to your run_pirads.py script
37
+ repo_root = Path(__file__).parent.parent
38
+ script_path = repo_root / "run_pirads.py"
39
+
40
+ # Path to your existing config.yaml
41
+ config_path = repo_root / "config" / "config_pirads_test.yaml" # adjust this path
42
+
43
+ # Make sure the file exists
44
+ assert config_path.exists(), f"Config file not found: {config_path}"
45
+
46
+ # Run the script with the config
47
+ result = subprocess.run(
48
+ [sys.executable, str(script_path), "--mode", "test", "--config", str(config_path), "--dry_run", "True" ],
49
+ capture_output=True,
50
+ text=True,
51
+ )
52
+
53
+ # Check that it ran without errors
54
+ assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
55
+
56
+ def test_run_cspca_training():
57
+ """
58
+ Test that run_cspca.py runs without crashing using an existing YAML config.
59
+ """
60
+
61
+ # Path to your run_cspca.py script
62
+ repo_root = Path(__file__).parent.parent
63
+ script_path = repo_root / "run_cspca.py"
64
+
65
+ # Path to your existing config.yaml
66
+ config_path = repo_root / "config" / "config_cspca_train.yaml" # adjust this path
67
+
68
+ # Make sure the file exists
69
+ assert config_path.exists(), f"Config file not found: {config_path}"
70
+
71
+ # Run the script with the config
72
+ result = subprocess.run(
73
+ [sys.executable, str(script_path), "--mode", "train", "--config", str(config_path), "--dry_run", "True" ],
74
+ capture_output=True,
75
+ text=True,
76
+ )
77
+
78
+ # Check that it ran without errors
79
+ assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
80
+
81
+ def test_run_cspca_inference():
82
+ """
83
+ Test that run_cspca.py runs without crashing using an existing YAML config.
84
+ """
85
+
86
+ # Path to your run_cspca.py script
87
+ repo_root = Path(__file__).parent.parent
88
+ script_path = repo_root / "run_cspca.py"
89
+
90
+ # Path to your existing config.yaml
91
+ config_path = repo_root / "config" / "config_cspca_test.yaml" # adjust this path
92
+
93
+ # Make sure the file exists
94
+ assert config_path.exists(), f"Config file not found: {config_path}"
95
+
96
+ # Run the script with the config
97
+ result = subprocess.run(
98
+ [sys.executable, str(script_path), "--mode", "test", "--config", str(config_path), "--dry_run", "True" ],
99
+ capture_output=True,
100
+ text=True,
101
+ )
102
+
103
+ # Check that it ran without errors
104
+ assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
105
+
106
+
tests/test_run_cspca.py DELETED
@@ -1,28 +0,0 @@
1
- import subprocess
2
- import sys
3
- from pathlib import Path
4
-
5
- def test_run_cspca_with_existing_config():
6
- """
7
- Test that run_cspca.py runs without crashing using an existing YAML config.
8
- """
9
-
10
- # Path to your run_cspca.py script
11
- repo_root = Path(__file__).parent.parent
12
- script_path = repo_root / "run_cspca.py"
13
-
14
- # Path to your existing config.yaml
15
- config_path = repo_root / "config" / "config_cspca_test.yaml" # adjust this path
16
-
17
- # Make sure the file exists
18
- assert config_path.exists(), f"Config file not found: {config_path}"
19
-
20
- # Run the script with the config
21
- result = subprocess.run(
22
- [sys.executable, str(script_path), "--mode","test", "--config", str(config_path)],
23
- capture_output=True,
24
- text=True,
25
- )
26
-
27
- # Check that it ran without errors
28
- assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_run_pirads.py DELETED
@@ -1,28 +0,0 @@
1
- import subprocess
2
- import sys
3
- from pathlib import Path
4
-
5
- def test_run_cspca_with_existing_config():
6
- """
7
- Test that run_cspca.py runs without crashing using an existing YAML config.
8
- """
9
-
10
- # Path to your run_pirads.py script
11
- repo_root = Path(__file__).parent.parent
12
- script_path = repo_root / "run_pirads.py"
13
-
14
- # Path to your existing config.yaml
15
- config_path = repo_root / "config" / "config_pirads_test.yaml" # adjust this path
16
-
17
- # Make sure the file exists
18
- assert config_path.exists(), f"Config file not found: {config_path}"
19
-
20
- # Run the script with the config
21
- result = subprocess.run(
22
- [sys.executable, str(script_path), "--mode","test", "--config", str(config_path)],
23
- capture_output=True,
24
- text=True,
25
- )
26
-
27
- # Check that it ran without errors
28
- assert result.returncode == 0, f"Script failed with:\n{result.stderr}"