Anirudh Balaraman committed on
Commit
906fcb9
·
1 Parent(s): 54bae7b

add scripts

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +2 -0
  2. config/config_cspca_test.yaml +17 -0
  3. config/config_cspca_train.yaml +19 -0
  4. config/config_pirads_test.yaml +18 -0
  5. config/config_pirads_train.yaml +22 -0
  6. config/config_preprocess.yaml +11 -0
  7. config/inference.json +153 -0
  8. dataset/PI-RADS_data.json +0 -0
  9. dataset/PICAI_cspca.json +0 -0
  10. dataset/adc_reference.nrrd +3 -0
  11. dataset/dwi_reference.nrrd +3 -0
  12. dataset/prostate_segmentation_reference.nrrd +3 -0
  13. dataset/t2_reference.nrrd +3 -0
  14. job_scripts/train_cspca.sh +19 -0
  15. preprocess_main.py +68 -0
  16. pyproject.toml +0 -0
  17. run_cspca.py +220 -0
  18. run_pirads.py +283 -0
  19. src/__init__.py +0 -0
  20. src/__pycache__/__init__.cpython-39.pyc +0 -0
  21. src/__pycache__/utils.cpython-39.pyc +0 -0
  22. src/data/__init__.py +0 -0
  23. src/data/__pycache__/__init__.cpython-39.pyc +0 -0
  24. src/data/__pycache__/custom_transforms.cpython-39.pyc +0 -0
  25. src/data/__pycache__/data_loader.cpython-39.pyc +0 -0
  26. src/data/custom_transforms.py +350 -0
  27. src/data/data_loader.py +125 -0
  28. src/model/MIL.py +248 -0
  29. src/model/__init__.py +0 -0
  30. src/model/__pycache__/MIL.cpython-39.pyc +0 -0
  31. src/model/__pycache__/__init__.cpython-39.pyc +0 -0
  32. src/model/__pycache__/csPCa_model.cpython-39.pyc +0 -0
  33. src/model/csPCa_model.py +50 -0
  34. src/preprocessing/__init__.py +0 -0
  35. src/preprocessing/__pycache__/__init__.cpython-39.pyc +0 -0
  36. src/preprocessing/__pycache__/center_crop.cpython-39.pyc +0 -0
  37. src/preprocessing/__pycache__/generate_heatmap.cpython-39.pyc +0 -0
  38. src/preprocessing/__pycache__/histogram_match.cpython-39.pyc +0 -0
  39. src/preprocessing/__pycache__/prostate_mask.cpython-39.pyc +0 -0
  40. src/preprocessing/__pycache__/register_and_crop.cpython-39.pyc +0 -0
  41. src/preprocessing/center_crop.py +64 -0
  42. src/preprocessing/generate_heatmap.py +76 -0
  43. src/preprocessing/histogram_match.py +62 -0
  44. src/preprocessing/prostate_mask.py +128 -0
  45. src/preprocessing/register_and_crop.py +67 -0
  46. src/train/__init__.py +0 -0
  47. src/train/__pycache__/__init__.cpython-39.pyc +0 -0
  48. src/train/__pycache__/train_cspca.cpython-39.pyc +0 -0
  49. src/train/__pycache__/train_pirads.cpython-39.pyc +0 -0
  50. src/train/train_cspca.py +141 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ logs/
2
+ models/
config/config_cspca_test.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
2
+ data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
3
+ dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PICAI_cspca.json
4
+ num_classes: !!int 4
5
+ mil_mode: att_trans
6
+ tile_count: !!int 24
7
+ tile_size: !!int 64
8
+ depth: !!int 3
9
+ use_heatmap: !!bool True
10
+ workers: !!int 6
11
+ checkpoint_cspca: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/cspca_model.pth
12
+ num_seeds: !!int 2
13
+ batch_size: !!int 1
14
+
15
+
16
+
17
+
config/config_cspca_train.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
2
+ data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
3
+ dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PICAI_cspca.json
4
+ num_classes: !!int 4
5
+ mil_mode: att_trans
6
+ tile_count: !!int 24
7
+ tile_size: !!int 64
8
+ depth: !!int 3
9
+ use_heatmap: !!bool True
10
+ workers: !!int 6
11
+ checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/pirads.pt
12
+ epochs: !!int 1
13
+ batch_size: !!int 8
14
+ optim_lr: !!float 2e-4
15
+
16
+
17
+
18
+
19
+
config/config_pirads_test.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run_name: pirads_test_run
2
+ project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
3
+ data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
4
+ dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PI-RADS_data.json
5
+ num_classes: !!int 4
6
+ mil_mode: att_trans
7
+ tile_count: !!int 24
8
+ tile_size: !!int 64
9
+ depth: !!int 3
10
+ use_heatmap: !!bool True
11
+ workers: !!int 0
12
+ checkpoint: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/pirads.pt
13
+ amp: !!bool True
14
+ dry_run: !!bool True
15
+
16
+
17
+
18
+
config/config_pirads_train.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
2
+ data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
3
+ dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PI-RADS_data.json
4
+ num_classes: !!int 4
5
+ mil_mode: att_trans
6
+ tile_count: !!int 24
7
+ tile_size: !!int 64
8
+ depth: !!int 3
9
+ use_heatmap: !!bool True
10
+ workers: !!int 0
11
+ checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/pirads.pt
12
+ epochs: !!int 2
13
+ batch_size: !!int 8
14
+ optim_lr: !!float 2e-4
15
+ weight_decay: !!float 1e-5
16
+ amp: !!bool True
17
+ wandb: !!bool True
18
+ dry_run: !!bool True
19
+
20
+
21
+
22
+
config/config_preprocess.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ t2_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/t2
2
+ dwi_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/dwi
3
+ adc_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/adc
4
+ output_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/processed
5
+ project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder
6
+
7
+
8
+
9
+
10
+
11
+
config/inference.json ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "imports": [
3
+ "$import pandas as pd",
4
+ "$import os"
5
+ ],
6
+ "bundle_root": "/workspace/data/prostate_mri_anatomy",
7
+ "output_dir": "$@bundle_root + '/eval'",
8
+ "dataset_dir": "/workspace/data/prostate158/prostate158_train/",
9
+ "datalist": "$list(@dataset_dir + pd.read_csv(@dataset_dir + 'valid.csv').t2)",
10
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
11
+ "network_def": {
12
+ "_target_": "UNet",
13
+ "spatial_dims": 3,
14
+ "in_channels": 1,
15
+ "out_channels": 3,
16
+ "channels": [
17
+ 16,
18
+ 32,
19
+ 64,
20
+ 128,
21
+ 256,
22
+ 512
23
+ ],
24
+ "strides": [
25
+ 2,
26
+ 2,
27
+ 2,
28
+ 2,
29
+ 2
30
+ ],
31
+ "num_res_units": 4,
32
+ "norm": "batch",
33
+ "act": "prelu",
34
+ "dropout": 0.15
35
+ },
36
+ "network": "$@network_def.to(@device)",
37
+ "preprocessing": {
38
+ "_target_": "Compose",
39
+ "transforms": [
40
+ {
41
+ "_target_": "LoadImaged",
42
+ "keys": "image"
43
+ },
44
+ {
45
+ "_target_": "EnsureChannelFirstd",
46
+ "keys": "image"
47
+ },
48
+ {
49
+ "_target_": "Orientationd",
50
+ "keys": "image",
51
+ "axcodes": "RAS"
52
+ },
53
+ {
54
+ "_target_": "Spacingd",
55
+ "keys": "image",
56
+ "pixdim": [
57
+ 0.5,
58
+ 0.5,
59
+ 0.5
60
+ ],
61
+ "mode": "bilinear"
62
+ },
63
+ {
64
+ "_target_": "ScaleIntensityd",
65
+ "keys": "image",
66
+ "minv": 0,
67
+ "maxv": 1
68
+ },
69
+ {
70
+ "_target_": "NormalizeIntensityd",
71
+ "keys": "image"
72
+ },
73
+ {
74
+ "_target_": "EnsureTyped",
75
+ "keys": "image"
76
+ }
77
+ ]
78
+ },
79
+ "dataset": {
80
+ "_target_": "Dataset",
81
+ "data": "$[{'image': i} for i in @datalist]",
82
+ "transform": "@preprocessing"
83
+ },
84
+ "dataloader": {
85
+ "_target_": "DataLoader",
86
+ "dataset": "@dataset",
87
+ "batch_size": 1,
88
+ "shuffle": false,
89
+ "num_workers": 4
90
+ },
91
+ "inferer": {
92
+ "_target_": "SlidingWindowInferer",
93
+ "roi_size": [
94
+ 96,
95
+ 96,
96
+ 96
97
+ ],
98
+ "sw_batch_size": 4,
99
+ "overlap": 0.5
100
+ },
101
+ "postprocessing": {
102
+ "_target_": "Compose",
103
+ "transforms": [
104
+ {
105
+ "_target_": "AsDiscreted",
106
+ "keys": "pred",
107
+ "argmax": true
108
+ },
109
+ {
110
+ "_target_": "KeepLargestConnectedComponentd",
111
+ "keys": "pred",
112
+ "applied_labels": [
113
+ 1,
114
+ 2
115
+ ]
116
+ },
117
+ {
118
+ "_target_": "SaveImaged",
119
+ "keys": "pred",
120
+ "resample": false,
121
+ "meta_keys": "pred_meta_dict",
122
+ "output_dir": "@output_dir"
123
+ }
124
+ ]
125
+ },
126
+ "handlers": [
127
+ {
128
+ "_target_": "CheckpointLoader",
129
+ "load_path": "$@bundle_root + '/models/model.pt'",
130
+ "load_dict": {
131
+ "model": "@network"
132
+ }
133
+ },
134
+ {
135
+ "_target_": "StatsHandler",
136
+ "iteration_log": false
137
+ }
138
+ ],
139
+ "evaluator": {
140
+ "_target_": "SupervisedEvaluator",
141
+ "device": "@device",
142
+ "val_data_loader": "@dataloader",
143
+ "network": "@network",
144
+ "inferer": "@inferer",
145
+ "postprocessing": "@postprocessing",
146
+ "val_handlers": "@handlers",
147
+ "amp": true
148
+ },
149
+ "evaluating": [
150
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
151
+ "$@evaluator.run()"
152
+ ]
153
+ }
dataset/PI-RADS_data.json ADDED
The diff for this file is too large to render. See raw diff
 
dataset/PICAI_cspca.json ADDED
The diff for this file is too large to render. See raw diff
 
dataset/adc_reference.nrrd ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46821ddb4373198d98b5877ec69a702718702cbaae6b9ef00b6b5ad235cf3f3e
3
+ size 3815961
dataset/dwi_reference.nrrd ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a964835a0c8c016b162a2eaa56456f9704290b3a4128aafeb9b01805166ca7b9
3
+ size 3815961
dataset/prostate_segmentation_reference.nrrd ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95984795649662ae405077d334744ac2e9d9fb8db7864c4ea99791a706ccee19
3
+ size 13434
dataset/t2_reference.nrrd ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:561af922d7461870fcc1e84134d659f7611a8ae73d1a0681fde79808f1cd99a9
3
+ size 3815961
job_scripts/train_cspca.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH --job-name=cspca_training # Specify job name
3
+ #SBATCH --partition=gpu # Specify partition name
4
+ #SBATCH --mem=128G
5
+ #SBATCH --gres=gpu:1
6
+ #SBATCH --time=48:00:00 # Set a limit on the total run time
7
+ #SBATCH --output=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/logs/%x/log.o%j # File name for standard output
8
+ #SBATCH --error=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/logs/%x/log.e%j # File name for standard error output
9
+ #SBATCH --mail-user=anirudh.balaraman@charite.de
10
+ #SBATCH --mail-type=END,FAIL
11
+
12
+
13
+ source /etc/profile.d/conda.sh
14
+ conda activate foundation
15
+
16
+ RUNDIR="/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation"
17
+
18
+
19
+ srun python -u $RUNDIR/MIL/new_folder/run_cspca.py --mode train --config $RUNDIR/MIL/new_folder/config/config_cspca_train.yaml
preprocess_main.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import SimpleITK as sitk
2
+ import os
3
+ import numpy as np
4
+ import nrrd
5
+ from AIAH_utility.viewer import BasicViewer
6
+ from tqdm import tqdm
7
+ import pandas as pd
8
+ from picai_prep.preprocessing import PreprocessingSettings, Sample
9
+ import multiprocessing
10
+ import sys
11
+ from src.preprocessing.register_and_crop import register_files
12
+ from src.preprocessing.prostate_mask import get_segmask
13
+ from src.preprocessing.histogram_match import histmatch
14
+ from src.preprocessing.generate_heatmap import get_heatmap
15
+ import logging
16
+ from pathlib import Path
17
+ from src.utils import setup_logging
18
+ from src.utils import validate_steps
19
+ import argparse
20
+ import yaml
21
+
22
def parse_args():
    """Parse preprocessing CLI arguments; values from an optional YAML config
    file override whatever was given on the command line.

    Returns:
        argparse.Namespace with step selection and directory paths.
    """
    # Names of the available preprocessing steps.  The callables themselves
    # are resolved in the __main__ block; keeping only the names here removes
    # the dict that previously duplicated that mapping (its values were never
    # used — argparse only needs the keys for `choices`).
    step_choices = (
        "register_and_crop",
        "histogram_match",
        "get_segmentation_mask",
        "get_heatmap",
    )
    parser = argparse.ArgumentParser(description="File preprocessing")
    parser.add_argument("--config", type=str, help="Path to YAML config file")
    parser.add_argument(
        "--steps",
        nargs="+",              # accept one or more step names
        choices=step_choices,   # restrict allowed values
        required=True,
        help="Steps to execute (one or more)"
    )
    parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
    parser.add_argument("--dwi_dir", default=None, help="Path to DWI files")
    parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
    parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
    parser.add_argument("--output_dir", default=None, help="Path to output folder")
    parser.add_argument("--margin", default=0.2, type=float, help="Margin to center crop the images")
    parser.add_argument("--project_dir", default=None, help="Project directory")

    args = parser.parse_args()
    if args.config:
        # YAML values take precedence over CLI flags (blind dict update).
        with open(args.config, 'r') as config_file:
            config = yaml.safe_load(config_file)
        args.__dict__.update(config)
    return args
52
+
53
if __name__ == "__main__":
    args = parse_args()
    # Map step names to their implementations; steps run in the order given
    # on the command line.
    FUNCTIONS = {
        "register_and_crop": register_files,
        "histogram_match": histmatch,
        "get_segmentation_mask": get_segmask,
        "get_heatmap": get_heatmap,
    }

    # All preprocessing output, including the log file, goes to --output_dir.
    args.logfile = os.path.join(args.output_dir, f"preprocessing.log")
    setup_logging(args.logfile)
    logging.info("Starting preprocessing")
    # Presumably checks step ordering/compatibility beyond argparse choices —
    # TODO confirm against src.utils.validate_steps.
    validate_steps(args.steps)
    for step in args.steps:
        func = FUNCTIONS[step]
        # Each step receives the argument namespace and returns a (possibly
        # updated) namespace that is threaded into the next step.
        args = func(args)
pyproject.toml ADDED
File without changes
run_cspca.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import shutil
4
+ import time
5
+ import yaml
6
+ import sys
7
+ import gdown
8
+ import numpy as np
9
+ import torch
10
+ import torch.distributed as dist
11
+ import torch.multiprocessing as mp
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from monai.config import KeysCollection
15
+ from monai.metrics import Cumulative, CumulativeAverage
16
+ from monai.networks.nets import milmodel, resnet, MILModel
17
+
18
+ from sklearn.metrics import cohen_kappa_score
19
+ from torch.cuda.amp import GradScaler, autocast
20
+ from torch.utils.data.dataloader import default_collate
21
+ from torchvision.models.resnet import ResNet50_Weights
22
+ import shutil
23
+ from pathlib import Path
24
+ from torch.utils.data.distributed import DistributedSampler
25
+ from torch.utils.tensorboard import SummaryWriter
26
+ from monai.utils import set_determinism
27
+
28
+ import matplotlib.pyplot as plt
29
+
30
+ import wandb
31
+ import math
32
+ import logging
33
+ from pathlib import Path
34
+
35
+
36
+ from src.model.MIL import MILModel_3D
37
+ from src.model.csPCa_model import csPCa_Model
38
+ from src.data.data_loader import get_dataloader
39
+ from src.utils import save_cspca_checkpoint, get_metrics, setup_logging
40
+ from src.train.train_cspca import train_epoch, val_epoch
41
+
42
def main_worker(args):
    """Train or evaluate the csPCa risk model on top of a pretrained PI-RADS MIL backbone.

    In 'train' mode the PI-RADS weights are loaded and frozen and a csPCa head
    is trained once per random seed; in 'test' mode a saved csPCa checkpoint is
    evaluated once per seed.  Per-seed metrics are aggregated via get_metrics().

    NOTE(review): indentation below was reconstructed from a diff rendering —
    confirm the nesting of the checkpoint-save and cache-cleanup blocks.
    """
    mil_model = MILModel_3D(
        num_classes=args.num_classes,
        mil_mode=args.mil_mode
    ).to(args.device)
    # Dataset cache directory written by the data loaders; removed when done.
    cache_dir_path = Path(os.path.join(args.logdir, "cache"))

    if args.mode == 'train':
        # weights_only=False allows non-tensor objects in the checkpoint.
        # NOTE(review): only safe for trusted checkpoint files.
        checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
        mil_model.load_state_dict(checkpoint["state_dict"])
        mil_model = mil_model.to(args.device)

        model_dir = os.path.join(args.project_dir,'models')
        metrics_dict = {'auc':[], 'sensitivity':[], 'specificity':[]}
        # One full training run per seed to build a confidence interval.
        for st in list(range(args.num_seeds)):
            set_determinism(seed=st)

            train_loader = get_dataloader(args, split="train")
            valid_loader = get_dataloader(args, split="test")
            cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
            # Freeze the pretrained MIL backbone; only the csPCa head trains.
            for submodule in [cspca_model.backbone.net,
                              cspca_model.backbone.myfc,
                              cspca_model.backbone.transformer]:
                for param in submodule.parameters():
                    param.requires_grad = False

            optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, cspca_model.parameters()), lr=args.optim_lr)

            # Track the best (lowest validation loss) epoch for this seed.
            old_loss = float('inf')
            old_auc = 0.0
            for epoch in range(args.epochs):
                train_loss, train_auc = train_epoch(cspca_model, train_loader, optimizer, epoch=epoch, args=args)
                logging.info(f"STATE {st} EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}")
                val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
                logging.info(f"STATE {st} EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}")
                val_metric['state'] = st
                if val_metric['loss'] < old_loss:
                    # New best epoch for this seed: remember its metrics and
                    # checkpoint if it also beats the best AUC of earlier seeds.
                    old_loss = val_metric['loss']
                    old_auc = val_metric['auc']
                    sensitivity = val_metric['sensitivity']
                    specificity = val_metric['specificity']
                    if len(metrics_dict['auc']) == 0:
                        save_cspca_checkpoint(cspca_model, val_metric, model_dir)
                    elif val_metric['auc'] >= max(metrics_dict['auc']):
                        save_cspca_checkpoint(cspca_model, val_metric, model_dir)

            # Record this seed's best-epoch metrics.
            metrics_dict['auc'].append(old_auc)
            metrics_dict['sensitivity'].append(sensitivity)
            metrics_dict['specificity'].append(specificity)
            # Clear the dataset cache so the next seed starts fresh.
            if cache_dir_path.exists() and cache_dir_path.is_dir():
                shutil.rmtree(cache_dir_path)

        # Aggregate/report metrics across seeds.
        get_metrics(metrics_dict)

    elif args.mode == 'test':

        cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
        checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
        cspca_model.load_state_dict(checkpt['state_dict'])
        cspca_model = cspca_model.to(args.device)
        # Older checkpoints may lack stored metrics; log either way.
        if 'auc' in checkpt and 'sensitivity' in checkpt and 'specificity' in checkpt:
            auc, sens, spec = checkpt['auc'], checkpt['sensitivity'], checkpt['specificity']
            logging.info(f"csPCa Model loaded from {args.checkpoint_cspca} with AUC: {auc}, Sensitivity: {sens}, Specificity: {spec} on the test set.")
        else:
            logging.info(f"csPCa Model loaded from {args.checkpoint_cspca}.")

        metrics_dict = {'auc':[], 'sensitivity':[], 'specificity':[]}
        # Evaluate once per seed (seeding affects loader-side randomness).
        for st in list(range(args.num_seeds)):
            set_determinism(seed=st)
            test_loader = get_dataloader(args, split="test")
            test_metric = val_epoch(cspca_model, test_loader, epoch=0, args=args)
            metrics_dict['auc'].append(test_metric['auc'])
            metrics_dict['sensitivity'].append(test_metric['sensitivity'])
            metrics_dict['specificity'].append(test_metric['specificity'])

        if cache_dir_path.exists() and cache_dir_path.is_dir():
            shutil.rmtree(cache_dir_path)

        get_metrics(metrics_dict)
127
def parse_args():
    """Parse CLI arguments for csPCa training/testing.

    Values from an optional YAML config file override CLI flags.

    Returns:
        argparse.Namespace with mode, paths, and hyperparameters.
    """
    parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) for csPCa risk prediction.")
    parser.add_argument('--mode', type=str, choices=['train', 'test'], required=True, help='Operation mode: train or infer')
    parser.add_argument('--run_name', type=str, default='train_cspca', help='run name for log file')
    parser.add_argument('--config', type=str, help='Path to YAML config file')
    # Typo fix: help text previously read "firectory".
    parser.add_argument(
        "--project_dir", default=None, help="path to project directory"
    )
    parser.add_argument(
        "--data_root", default=None, help="path to root folder of images"
    )
    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
    parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
    parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm: choose either att_trans or att_pyramid")
    parser.add_argument(
        "--tile_count", default=24, type=int, help="number of patches (instances) to extract from MRI input"
    )
    parser.add_argument("--tile_size", default=64, type=int, help="size of square patch (instance) in pixels")
    parser.add_argument("--depth", default=3, type=int, help="number of slices in each 3D patch (instance)")
    # Paired flags: --use_heatmap / --no_heatmap toggle the same dest;
    # default is enabled.
    parser.add_argument(
        "--use_heatmap", action="store_true",
        help="enable weak attention heatmap guided patch generation"
    )
    parser.add_argument(
        "--no_heatmap", dest="use_heatmap", action="store_false",
        help="disable heatmap"
    )
    parser.set_defaults(use_heatmap=True)
    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
    parser.add_argument("--checkpoint_pirads", default=None, help="Load PI-RADS model")
    parser.add_argument("--epochs", "--max_epochs", default=30, type=int, help="number of training epochs")
    parser.add_argument("--batch_size", default=32, type=int, help="number of MRI scans per batch")
    parser.add_argument("--optim_lr", default=2e-4, type=float, help="initial learning rate")
    parser.add_argument(
        "--val_every",
        "--val_interval",
        default=1,
        type=int,
        help="run validation after this number of epochs, default 1 to run every epoch",
    )
    parser.add_argument("--dry_run", action="store_true", help="Run the script in dry-run mode (default: False)")
    parser.add_argument("--checkpoint_cspca", default=None, help="load existing checkpoint")
    parser.add_argument("--num_seeds", default=20, type=int, help="number of seeds to be run to build CI")
    args = parser.parse_args()
    if args.config:
        # YAML values take precedence over CLI flags (blind dict update).
        with open(args.config, 'r') as config_file:
            config = yaml.safe_load(config_file)
        args.__dict__.update(config)

    return args
181
+
182
+
183
+
184
if __name__ == "__main__":
    args = parse_args()
    # Per-run log directory: <project_dir>/logs/<run_name>/
    args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
    os.makedirs(args.logdir, exist_ok=True)
    args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
    setup_logging(args.logfile)

    # Dump the effective configuration (CLI merged with YAML) into the log.
    logging.info("Argument values:")
    for k, v in vars(args).items():
        logging.info(f"{k} => {v}")
    logging.info("-----------------")

    # Fail fast on missing mandatory paths for the selected mode.
    if args.dataset_json is None:
        logging.error('Dataset path not provided. Quitting.')
        sys.exit(1)
    if args.checkpoint_pirads is None and args.mode == 'train':
        logging.error('PI-RADS checkpoint path not provided. Quitting.')
        sys.exit(1)
    elif args.checkpoint_cspca is None and args.mode == 'test':
        logging.error('csPCa checkpoint path not provided. Quitting.')
        sys.exit(1)
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.device == torch.device("cuda"):
        torch.backends.cudnn.benchmark = True

    if args.dry_run:
        # Shrink the run so the full pipeline can be smoke-tested quickly.
        logging.info("Dry run mode enabled.")
        args.epochs = 2
        args.batch_size = 2
        args.workers = 0
        args.num_seeds = 2
        # NOTE(review): --wandb is not declared by this script's parser; it is
        # presumably supplied via the YAML config — confirm.
        args.wandb = False

    main_worker(args)
run_pirads.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import collections.abc
3
+ import os
4
+ import shutil
5
+ import time
6
+ import yaml
7
+ import sys
8
+ import gdown
9
+ import numpy as np
10
+ import torch
11
+ import torch.distributed as dist
12
+ import torch.multiprocessing as mp
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from monai.config import KeysCollection
16
+ from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
17
+ from monai.data.wsi_reader import WSIReader
18
+ from monai.metrics import Cumulative, CumulativeAverage
19
+ from monai.networks.nets import milmodel, resnet, MILModel
20
+
21
+ from sklearn.metrics import cohen_kappa_score
22
+ from torch.cuda.amp import GradScaler, autocast
23
+ from torch.utils.data.dataloader import default_collate
24
+ from torch.utils.tensorboard import SummaryWriter
25
+ from monai.utils import set_determinism
26
+ import matplotlib.pyplot as plt
27
+
28
+ import wandb
29
+ import math
30
+ import logging
31
+ from pathlib import Path
32
+ from src.data.data_loader import get_dataloader
33
+ from src.train.train_pirads import train_epoch, val_epoch
34
+ from src.model.MIL import MILModel_3D
35
+ from src.utils import save_pirads_checkpoint, setup_logging
36
+
37
+
38
def main_worker(args):
    """Train or evaluate the 3D MIL PI-RADS classifier.

    In 'train' mode: optionally resume from args.checkpoint, train with AdamW +
    cosine LR schedule, validate every args.val_every epochs, keep the lowest
    validation-loss model as model.pt, and early-stop after args.early_stop
    epochs without improvement.  In 'test' mode: evaluate over args.num_seeds
    seeds and report the mean quadratic-weighted kappa.

    Fix applied: several logging.info calls were written print-style with
    multiple positional strings; logging interprets the first argument as a
    %-format string, so those calls raised formatting errors.  They now use
    lazy %-style formatting.

    NOTE(review): indentation was reconstructed from a diff rendering —
    confirm the nesting of the checkpointing/early-stop blocks.
    """
    if args.device == torch.device("cuda"):
        torch.cuda.set_device(args.gpu)  # default device (same as args.device when not distributed)
        torch.backends.cudnn.benchmark = True

    model = MILModel_3D(
        num_classes=args.num_classes,
        mil_mode=args.mil_mode
    )
    start_epoch = 0
    best_acc = 0.0
    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])

        if "epoch" in checkpoint:
            start_epoch = checkpoint["epoch"]
        if "best_acc" in checkpoint:
            best_acc = checkpoint["best_acc"]
        logging.info("=> loaded checkpoint %s (epoch %d) (bestacc %f)", args.checkpoint, start_epoch, best_acc)
    cache_dir_ = os.path.join(args.logdir, "cache")
    model.to(args.device)
    params = model.parameters()
    if args.mode == 'train':
        train_loader = get_dataloader(args, split=args.mode)
        valid_loader = get_dataloader(args, split="test")
        # BUG FIX: was logging.info("Dataset training:", str(...), ...) — print-style args.
        logging.info("Dataset training: %d test: %d", len(train_loader.dataset), len(valid_loader.dataset))

        if args.mil_mode in ["att_trans", "att_trans_pyramid"]:
            # Transformer parameters get their own, smaller LR and stronger weight decay.
            params = [
                {"params": list(model.attention.parameters()) + list(model.myfc.parameters()) + list(model.net.parameters())},
                {"params": list(model.transformer.parameters()), "lr": 6e-5, "weight_decay": 0.1},
            ]

        optimizer = torch.optim.AdamW(params, lr=args.optim_lr, weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0)
        scaler = torch.amp.GradScaler(device=str(args.device), enabled=args.amp)

        if args.logdir is not None:
            writer = SummaryWriter(log_dir=args.logdir)
            # BUG FIX: was print-style ("... to ", writer.log_dir).
            logging.info("Writing Tensorboard logs to %s", writer.log_dir)
        else:
            writer = None

        # RUN TRAINING
        n_epochs = args.epochs
        val_loss_min = float("inf")
        epochs_no_improve = 0
        for epoch in range(start_epoch, n_epochs):

            # BUG FIX: was logging.info(time.ctime(), "Epoch:", epoch) —
            # time.ctime() was treated as the format string.
            logging.info("%s Epoch: %d", time.ctime(), epoch)
            epoch_time = time.time()
            train_loss, train_acc, train_att_loss, batch_norm = train_epoch(model, train_loader, optimizer, scaler=scaler, epoch=epoch, args=args)
            logging.info(
                "Final training %d/%d loss: %.4f attention loss: %.4f acc: %.4f time %.2fs",
                epoch,
                n_epochs - 1,
                train_loss,
                train_att_loss,
                train_acc,
                time.time() - epoch_time,
            )

            if writer is not None:
                writer.add_scalar("train_loss", train_loss, epoch)
                writer.add_scalar("train_attention_loss", train_att_loss, epoch)
                writer.add_scalar("train_acc", train_acc, epoch)
                # NOTE(review): wandb.log is called regardless of args.wandb —
                # confirm wandb.init is performed (or disabled) in all paths.
                wandb.log({"Train Loss": train_loss, "Train Accuracy": train_acc, "Train Attention Loss": train_att_loss, "Batch Norm": batch_norm}, step=epoch)

            model_new_best = False
            val_acc = 0
            if (epoch + 1) % args.val_every == 0:
                epoch_time = time.time()
                val_loss, val_acc, qwk = val_epoch(model, valid_loader, epoch=epoch, args=args)

                logging.info(
                    "Final test %d/%d loss: %.4f acc: %.4f qwk: %.4f time %.2fs",
                    epoch,
                    n_epochs - 1,
                    val_loss,
                    val_acc,
                    qwk,
                    time.time() - epoch_time,
                )
                if writer is not None:
                    writer.add_scalar("test_loss", val_loss, epoch)
                    writer.add_scalar("test_acc", val_acc, epoch)
                    writer.add_scalar("test_qwk", qwk, epoch)

                wandb.log({"Test Loss": val_loss, "Test Accuracy": val_acc, "Cohen Kappa": qwk}, step=epoch)
                if val_loss < val_loss_min:
                    logging.info("Loss (%.6f --> %.6f)", val_loss_min, val_loss)
                    val_loss_min = val_loss
                    model_new_best = True

                if args.logdir is not None:
                    # Keep a per-epoch checkpoint; copy the best one to model.pt.
                    save_pirads_checkpoint(model, epoch, args, best_acc=val_acc, filename=f"model_{epoch}.pt")
                    if model_new_best:
                        logging.info("Copying to model.pt new best model!!!!")
                        shutil.copyfile(os.path.join(args.logdir, f"model_{epoch}.pt"), os.path.join(args.logdir, "model.pt"))
                        epochs_no_improve = 0
                    else:
                        epochs_no_improve += 1
                        # NOTE(review): args.early_stop is not defined by the
                        # visible parser or configs — confirm it is always set.
                        if epochs_no_improve == args.early_stop:
                            logging.info('Early stopping!')
                            break

            scheduler.step()

        logging.info("ALL DONE")

    elif args.mode == 'test':

        kappa_list = []
        for seed in list(range(args.num_seeds)):
            set_determinism(seed=seed)
            valid_loader = get_dataloader(args, split=args.mode)
            # BUG FIX: was logging.info("test:", str(len(...))) — print-style args.
            logging.info("test: %d", len(valid_loader.dataset))
            val_loss, val_acc, qwk = val_epoch(model, valid_loader, epoch=0, args=args)
            kappa_list.append(qwk)
            logging.info(f"Seed {seed}, QWK: {qwk}")
            # Clear the dataset cache between seeds so each run starts fresh.
            if os.path.exists(cache_dir_):
                # BUG FIX: was print-style ("... directory ", cache_dir_).
                logging.info("Removing cache directory %s", cache_dir_)
                shutil.rmtree(cache_dir_)

        logging.info(f"Mean QWK over {args.num_seeds} seeds: {np.mean(kappa_list)}")

    # Final cleanup, regardless of mode.
    if os.path.exists(cache_dir_):
        logging.info("Removing cache directory %s", cache_dir_)
        shutil.rmtree(cache_dir_)
177
def parse_args(argv=None):
    """Parse command-line options for PIRADS MIL training/testing.

    Args:
        argv: optional list of argument strings. Defaults to ``None`` so
            argparse falls back to ``sys.argv[1:]``; passing a list makes the
            function usable in tests without touching the process arguments.

    Returns:
        argparse.Namespace with all options. If ``--config`` points to a YAML
        file, its keys override the parsed command-line values.
    """
    parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) for PIRADS Classification.")
    # Fixed: help text previously said "train or infer" although the accepted choices are train/test.
    parser.add_argument('--mode', type=str, choices=['train', 'test'], required=True, help='operation mode: train or test')
    parser.add_argument('--wandb', action='store_true', help='Add this flag to enable WandB logging')
    parser.add_argument('--project_name', type=str, default='Classification_prostate', help='WandB project name')
    parser.add_argument('--run_name', type=str, default='train_pirads', help='run name for WandB logging')
    parser.add_argument('--config', type=str, help='path to YAML config file')
    # Fixed typo: "firectory" -> "directory".
    parser.add_argument("--project_dir", default=None, help="path to project directory")
    parser.add_argument("--data_root", default=None, help="path to root folder of images")
    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
    parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
    parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm: choose either att_trans or att_pyramid")
    parser.add_argument(
        "--tile_count", default=24, type=int, help="number of patches (instances) to extract from MRI input"
    )
    parser.add_argument("--tile_size", default=64, type=int, help="size of square patch (instance) in pixels")
    parser.add_argument("--depth", default=3, type=int, help="number of slices in each 3D patch (instance)")
    # Paired flags controlling heatmap-guided patch sampling; default is enabled.
    parser.add_argument(
        "--use_heatmap", action="store_true",
        help="enable weak attention heatmap guided patch generation"
    )
    parser.add_argument(
        "--no_heatmap", dest="use_heatmap", action="store_false",
        help="disable heatmap"
    )
    parser.set_defaults(use_heatmap=True)
    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")

    parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
    parser.add_argument("--epochs", "--max_epochs", default=50, type=int, help="number of training epochs")
    parser.add_argument("--early_stop", default=40, type=int, help="early stopping criteria")
    parser.add_argument("--batch_size", default=4, type=int, help="number of MRI scans per batch")
    parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
    parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
    parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
    parser.add_argument(
        "--val_every",
        "--val_interval",
        default=1,
        type=int,
        help="run validation after this number of epochs, default 1 to run every epoch",
    )
    parser.add_argument("--dry_run", action="store_true", help="Run the script in dry-run mode (default: False)")
    args = parser.parse_args(argv)
    # YAML config values (if any) take precedence over CLI flags/defaults.
    if args.config:
        with open(args.config, 'r') as config_file:
            config = yaml.safe_load(config_file)
        args.__dict__.update(config)
    return args
230
+
231
+
232
+
233
+
234
if __name__ == "__main__":
    args = parse_args()
    # Per-run log directory: <project_dir>/logs/<run_name>; the log file is named after the run.
    args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
    os.makedirs(args.logdir, exist_ok=True)
    args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
    setup_logging(args.logfile)

    # Echo the full configuration into the log for reproducibility.
    logging.info("Argument values:")
    for k, v in vars(args).items():
        logging.info(f"{k} => {v}")
    logging.info("-----------------")

    # Number of random seeds averaged over in test mode (reduced in dry runs below).
    args.num_seeds = 10
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # AMP is only meaningful on GPU; force it off on CPU.
    if args.device == torch.device("cpu"):
        args.amp = False
    # Fail fast on missing required inputs.
    if args.dataset_json is None:
        logging.error('Dataset JSON file not provided. Quitting.')
        sys.exit(1)
    if args.checkpoint is None and args.mode == 'test':
        logging.error('Model checkpoint path not provided. Quitting.')
        sys.exit(1)

    # Dry-run mode shrinks everything for a quick smoke test and disables WandB.
    if args.dry_run:
        logging.info("Dry run mode enabled.")
        args.epochs = 2
        args.batch_size = 2
        args.workers = 0
        args.num_seeds = 2
        args.wandb = False

    mode_wandb = "online" if args.wandb else "disabled"

    # Hyperparameters surfaced in the WandB run config.
    config_wandb = {
        "learning_rate": args.optim_lr,
        "batch_size": args.batch_size,
        "epochs": args.epochs,
        "patch size": args.tile_size,
        "patch count": args.tile_count,
    }
    # wandb.init is always called; mode="disabled" turns it into a no-op logger.
    wandb.init(project=args.project_name,
               name=args.run_name,
               dir=os.path.join(args.logdir, "wandb"),
               config=config_wandb,
               mode=mode_wandb)

    main_worker(args)

    wandb.finish()
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (218 Bytes). View file
 
src/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.49 kB). View file
 
src/data/__init__.py ADDED
File without changes
src/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (223 Bytes). View file
 
src/data/__pycache__/custom_transforms.cpython-39.pyc ADDED
Binary file (12.6 kB). View file
 
src/data/__pycache__/data_loader.cpython-39.pyc ADDED
Binary file (4.24 kB). View file
 
src/data/custom_transforms.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from typing import Union, Optional
4
+
5
+ from monai.transforms import MapTransform
6
+ from monai.config import DtypeLike, KeysCollection
7
+ from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
8
+ from monai.data.meta_obj import get_track_meta
9
+ from monai.transforms.transform import Transform
10
+ from monai.transforms.utils import soft_clip
11
+ from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
12
+ from monai.utils.enums import TransformBackends
13
+ from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype
14
+ from scipy.ndimage import binary_dilation
15
+ import cv2
16
+ from typing import Union, Sequence
17
+ from collections.abc import Hashable, Mapping, Sequence
18
+
19
+
20
+
21
class DilateAndSaveMaskd(MapTransform):
    """
    Dictionary transform that grows each binary mask by morphological dilation
    while preserving an untouched copy.

    The pre-dilation mask is stored under ``copy_key`` so downstream transforms
    can still use the original segmentation after the key has been replaced by
    its dilated version.
    """

    def __init__(self, keys, dilation_size=10, copy_key="original_mask"):
        super().__init__(keys)
        self.dilation_size = dilation_size
        self.copy_key = copy_key

    def __call__(self, data):
        out = dict(data)

        for key in self.keys:
            raw = out[key]
            arr = raw.numpy() if isinstance(raw, torch.Tensor) else raw
            arr = arr.squeeze(0)  # drop the leading channel axis

            # Keep the untouched mask available under a separate key.
            out[self.copy_key] = torch.tensor(arr, dtype=torch.float32).unsqueeze(0)

            # `iterations` controls how far the mask expands (one voxel per iteration).
            grown = binary_dilation(arr, iterations=self.dilation_size).astype(np.uint8)

            # Replace the original entry with the dilated mask, channel axis restored.
            out[key] = torch.tensor(grown, dtype=torch.float32).unsqueeze(0)

        return out
47
+
48
+
49
class ClipMaskIntensityPercentiles(Transform):
    """Clip image intensities to percentiles computed from a masked copy of the image.

    The percentile bounds are derived from ``img * (mask_data > 0)`` (the image
    with everything outside the mask zeroed), then applied to the whole image —
    either as a hard clip or, when ``sharpness_factor`` is given, as a soft clip.

    NOTE(review): because percentiles are taken over the zero-filled product
    (not ``img[mask > 0]``), voxels outside the mask contribute zeros to the
    distribution, which pulls the lower percentile toward 0 — confirm this is
    the intended behavior.
    """


    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(
        self,
        lower: Union[float, None],
        upper: Union[float, None],
        sharpness_factor : Union[float, None] = None,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
    ) -> None:
        # Validate percentile bounds up front; at least one bound is required.
        if lower is None and upper is None:
            raise ValueError("lower or upper percentiles must be provided")
        if lower is not None and (lower < 0.0 or lower > 100.0):
            raise ValueError("Percentiles must be in the range [0, 100]")
        if upper is not None and (upper < 0.0 or upper > 100.0):
            raise ValueError("Percentiles must be in the range [0, 100]")
        if upper is not None and lower is not None and upper < lower:
            raise ValueError("upper must be greater than or equal to lower")
        if sharpness_factor is not None and sharpness_factor <= 0:
            raise ValueError("sharpness_factor must be greater than 0")

        #self.mask_data = mask_data
        self.lower = lower
        self.upper = upper
        self.sharpness_factor = sharpness_factor
        self.channel_wise = channel_wise
        self.dtype = dtype

    def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
        """Clip one image (or one channel) using percentiles of the masked image."""
        # Zero out everything outside the mask before computing percentiles.
        masked_img = img * (mask_data > 0)
        if self.sharpness_factor is not None:
            # Soft clip: missing bounds stay None and are left open.
            lower_percentile = percentile(masked_img, self.lower) if self.lower is not None else None
            upper_percentile = percentile(masked_img, self.upper) if self.upper is not None else None
            img = soft_clip(img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype)
        else:
            # Hard clip: missing bounds fall back to the 0th / 100th percentile (min/max).
            lower_percentile = percentile(masked_img, self.lower) if self.lower is not None else percentile(masked_img, 0)
            upper_percentile = percentile(masked_img, self.upper) if self.upper is not None else percentile(masked_img, 100)
            img = clip(img, lower_percentile, upper_percentile)

        img = convert_to_tensor(img, track_meta=False)
        return img

    def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Apply the transform to `img`.
        """
        # Keep a meta-tracked copy for type restoration; work on plain tensors.
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_t = convert_to_tensor(img, track_meta=False)
        mask_t = convert_to_tensor(mask_data, track_meta=False)
        if self.channel_wise:
            # NOTE(review): mask_t[e] assumes the mask has as many channels as
            # the image — a single-channel mask with a multi-channel image would
            # raise an IndexError here; confirm against callers.
            img_t = torch.stack([self._clip(img=d, mask_data=mask_t[e]) for e,d in enumerate(img_t)]) # type: ignore
        else:
            img_t = self._clip(img=img_t, mask_data=mask_t)

        # Restore the original container/meta type of `img`.
        img = convert_to_dst_type(img_t, dst=img)[0]

        return img
112
+
113
class ClipMaskIntensityPercentilesd(MapTransform):
    """Dictionary wrapper for :class:`ClipMaskIntensityPercentiles`.

    For every key in ``keys``, clips the corresponding image to percentiles
    computed relative to the mask stored under ``mask_key``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        mask_key: str,
        lower: Union[float, None],
        upper: Union[float, None],
        sharpness_factor: Union[float, None] = None,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        # Delegate the per-array work to the non-dictionary transform.
        self.scaler = ClipMaskIntensityPercentiles(
            lower=lower, upper=upper, sharpness_factor=sharpness_factor, channel_wise=channel_wise, dtype=dtype
        )
        self.mask_key = mask_key

    def __call__(self, data: dict) -> dict:
        out = dict(data)
        mask = out[self.mask_key]
        for key in self.key_iterator(out):
            out[key] = self.scaler(out[key], mask)
        return out
137
+
138
+
139
+
140
+
141
class ElementwiseProductd(MapTransform):
    """Store the element-wise product of the first two ``keys`` under ``output_key``."""

    def __init__(self, keys: KeysCollection, output_key: str) -> None:
        super().__init__(keys)
        self.output_key = output_key

    def __call__(self, data) -> NdarrayOrTensor:
        out = dict(data)
        lhs_key, rhs_key = self.keys[0], self.keys[1]
        # Voxel-wise product (e.g. mask * heatmap) written to a new entry.
        out[self.output_key] = out[lhs_key] * out[rhs_key]
        return out
150
+
151
+
152
class CLAHEd(MapTransform):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to images in a data dictionary.
    Works on 2D images or 3D volumes (applied slice-by-slice).

    Args:
        keys (KeysCollection): Keys of the items to be transformed.
        clip_limit (float): Threshold for contrast limiting. Default is 2.0.
        tile_grid_size (Union[tuple, Sequence[int]]): Size of grid for histogram equalization (default: (8,8)).
    """
    def __init__(
        self,
        keys: KeysCollection,
        clip_limit: float = 2.0,
        tile_grid_size: Union[tuple, Sequence[int]] = (8, 8),
    ) -> None:
        super().__init__(keys)
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            image_ = d[key]

            # OpenCV operates on host memory; move off GPU if needed.
            image = image_.cpu().numpy()

            # cv2 CLAHE requires uint8 input.
            # NOTE(review): float images are cast straight to uint8 with no
            # rescaling — inputs normalized to [0, 1] would collapse to {0, 1}
            # here. Confirm inputs are already in the 0-255 range.
            if image.dtype != np.uint8:
                image = image.astype(np.uint8)

            clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
            # Handle 2D images or process 3D images slice-by-slice.
            # image[0] drops the channel axis; each slice along the first
            # spatial axis is equalized independently. (Loop variable `slice`
            # shadows the builtin of the same name.)
            image_clahe = np.stack([clahe.apply(slice) for slice in image[0]])


            # Convert back to float in [0,1]
            processed_img = image_clahe.astype(np.float32) / 255.0
            # Restore the leading channel axis and the original tensor device.
            reshaped_ = processed_img.reshape(1, *processed_img.shape)
            d[key] = torch.from_numpy(reshaped_).to(image_.device)
        return d
193
+
194
class NormalizeIntensity_custom(Transform):
    """
    Normalize input based on the `subtrahend` and `divisor`: `(img - subtrahend) / divisor`.
    Use calculated mean or std value of the input image if no `subtrahend` or `divisor` provided.
    Unlike the stock MONAI transform, the mean/std statistics are computed only
    over voxels where ``mask_data > 0``, while the normalization itself is
    applied to the entire image.
    When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
    be the number of image channels if they are not None.
    If the input is not of floating point type, it will be converted to float32.

    Args:
        subtrahend: the amount to subtract by (usually the mean).
        divisor: the amount to divide by (usually the standard deviation).
        nonzero: whether only normalize non-zero values.
            NOTE(review): this flag is currently ignored — the original
            nonzero handling is commented out below and statistics always come
            from the mask instead.
        channel_wise: if True, calculate on each channel separately, otherwise, calculate on
            the entire image directly. default to False.
        dtype: output data type, if None, same as input image. defaults to float32.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(
        self,
        subtrahend: Union[Sequence, NdarrayOrTensor, None] = None,
        divisor: Union[Sequence, NdarrayOrTensor, None] = None,
        nonzero: bool = False,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
    ) -> None:
        self.subtrahend = subtrahend
        self.divisor = divisor
        self.nonzero = nonzero
        self.channel_wise = channel_wise
        self.dtype = dtype

    @staticmethod
    def _mean(x):
        # Mean that works for both numpy arrays and torch tensors,
        # returning a Python scalar for single-element torch results.
        if isinstance(x, np.ndarray):
            return np.mean(x)
        x = torch.mean(x.float())
        return x.item() if x.numel() == 1 else x

    @staticmethod
    def _std(x):
        # Population std (unbiased=False) to mirror np.std's default.
        if isinstance(x, np.ndarray):
            return np.std(x)
        x = torch.std(x.float(), unbiased=False)
        return x.item() if x.numel() == 1 else x

    def _normalize(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTensor:
        """Normalize one image/channel using statistics from the masked voxels."""
        img, *_ = convert_data_type(img, dtype=torch.float32)
        '''
        if self.nonzero:
            slices = img != 0
            masked_img = img[slices]
            if not slices.any():
                return img
        else:
            slices = None
            masked_img = img
        '''
        # Dead `nonzero` branch above is intentionally kept disabled:
        # statistics are always computed from the mask instead.
        slices = None
        # NOTE(review): squeeze(0) assumes a single-channel mask, and boolean
        # indexing assumes img and squeezed mask shapes align (true for the
        # channel_wise path used by the data loader) — confirm for other callers.
        mask_data = mask_data.squeeze(0)
        slices_mask = mask_data > 0
        masked_img = img[slices_mask]

        # Fall back to the masked mean/std when explicit values are not given.
        _sub = sub if sub is not None else self._mean(masked_img)
        if isinstance(_sub, (torch.Tensor, np.ndarray)):
            _sub, *_ = convert_to_dst_type(_sub, img)
            if slices is not None:
                _sub = _sub[slices]

        _div = div if div is not None else self._std(masked_img)
        if np.isscalar(_div):
            # Guard against division by zero for constant regions.
            if _div == 0.0:
                _div = 1.0
        elif isinstance(_div, (torch.Tensor, np.ndarray)):
            _div, *_ = convert_to_dst_type(_div, img)
            if slices is not None:
                _div = _div[slices]
            _div[_div == 0.0] = 1.0

        # `slices` is always None here, so the whole image is normalized.
        if slices is not None:
            img[slices] = (masked_img - _sub) / _div
        else:
            img = (img - _sub) / _div
        return img

    def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True,
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        mask_data = convert_to_tensor(mask_data, track_meta=get_track_meta())
        dtype = self.dtype or img.dtype
        if self.channel_wise:
            # Per-channel subtrahend/divisor must match the channel count.
            if self.subtrahend is not None and len(self.subtrahend) != len(img):
                raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.")
            if self.divisor is not None and len(self.divisor) != len(img):
                raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")

            if not img.dtype.is_floating_point:
                img, *_ = convert_data_type(img, dtype=torch.float32)

            # Normalize each channel independently, sharing the same mask.
            for i, d in enumerate(img):
                img[i] = self._normalize(  # type: ignore
                    d,
                    mask_data,
                    sub=self.subtrahend[i] if self.subtrahend is not None else None,
                    div=self.divisor[i] if self.divisor is not None else None,
                )
        else:
            img = self._normalize(img, mask_data, self.subtrahend, self.divisor)

        out = convert_to_dst_type(img, img, dtype=dtype)[0]
        return out
310
+
311
class NormalizeIntensity_customd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`NormalizeIntensity_custom`.

    Normalizes every image under ``keys`` using mean/std statistics computed
    from the voxels selected by the mask stored under ``mask_key``.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        mask_key: key of the mask used to select voxels for the statistics.
        subtrahend: the amount to subtract by (usually the mean)
        divisor: the amount to divide by (usually the standard deviation)
        nonzero: whether only normalize non-zero values.
        channel_wise: if True, calculate on each channel separately, otherwise, calculate on
            the entire image directly. default to False.
        dtype: output data type, if None, same as input image. defaults to float32.
        allow_missing_keys: don't raise exception if key is missing.
    """

    backend = NormalizeIntensity_custom.backend

    def __init__(
        self,
        keys: KeysCollection,
        mask_key: str,
        subtrahend:Union[ NdarrayOrTensor, None] = None,
        divisor: Union[ NdarrayOrTensor, None] = None,
        nonzero: bool = False,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        # All per-array work is delegated to the array-level transform.
        self.normalizer = NormalizeIntensity_custom(subtrahend, divisor, nonzero, channel_wise, dtype)
        self.mask_key = mask_key

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        out = dict(data)
        mask = out[self.mask_key]
        for key in self.key_iterator(out):
            out[key] = self.normalizer(out[key], mask)
        return out
src/data/data_loader.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import numpy as np
4
+ from monai.config import KeysCollection
5
+ from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
6
+ from monai.transforms import (
7
+ Compose,
8
+ LoadImaged,
9
+ MapTransform,
10
+ ScaleIntensityRanged,
11
+ SplitDimd,
12
+ ToTensord,
13
+ ConcatItemsd,
14
+ SelectItemsd,
15
+ EnsureChannelFirstd,
16
+ RepeatChanneld,
17
+ DeleteItemsd,
18
+ EnsureTyped,
19
+ ClipIntensityPercentilesd,
20
+ MaskIntensityd,
21
+ RandCropByPosNegLabeld,
22
+ NormalizeIntensityd,
23
+ SqueezeDimd,
24
+ ScaleIntensityd,
25
+ ScaleIntensityd,
26
+ Transposed,
27
+ RandWeightedCropd,
28
+ )
29
+ from .custom_transforms import (
30
+ NormalizeIntensity_customd,
31
+ ClipMaskIntensityPercentilesd,
32
+ ElementwiseProductd,
33
+ )
34
+ import torch
35
+ from torch.utils.data.dataloader import default_collate
36
+ import matplotlib.pyplot as plt
37
+ from typing import Literal
38
+ import monai
39
+ import collections.abc
40
+
41
def list_data_collate(batch: collections.abc.Sequence):
    """
    Combine instances from a list of dicts into a single dict, by stacking them along first dim
    [{'image' : 3xHxW}, {'image' : 3xHxW}, {'image' : 3xHxW}...] - > {'image' : Nx3xHxW}
    followed by the default collate which will form a batch BxNx3xHxW.

    Each element of ``batch`` is the list of per-patch dicts produced by a
    cropping transform with ``num_samples`` > 1. ``final_heatmap`` is stacked
    too, but only when every patch provides it.

    Fixes over the previous version: the input sequence and the first patch
    dict are no longer mutated in place (mutation corrupted samples cached by
    PersistentDataset and failed on immutable sequences such as tuples).
    """
    merged = []
    for item in batch:
        # Copy the first patch's dict so the cached sample is left untouched;
        # non-stacked keys (e.g. 'label') are carried over as-is.
        data = dict(item[0])
        data["image"] = torch.stack([ix["image"] for ix in item], dim=0)

        if all("final_heatmap" in ix for ix in item):
            data["final_heatmap"] = torch.stack([ix["final_heatmap"] for ix in item], dim=0)

        merged.append(data)
    return default_collate(merged)
57
+
58
+
59
+
60
def data_transform(args):
    """Build the MONAI preprocessing pipeline for one sample dict.

    Each sample provides 't2' ("image"), 'dwi', 'adc', a prostate 'mask', a
    'label', and (when ``args.use_heatmap``) an attention 'heatmap'. The
    pipeline clips/normalizes intensities relative to the prostate mask, stacks
    the three modalities as channels, and extracts ``args.tile_count`` 3D
    patches of size (tile_size, tile_size, depth) — heatmap-weighted sampling
    when enabled, mask-positive sampling otherwise. Images end up channel-first
    with depth moved before the in-plane axes.
    """
    if args.use_heatmap:
        transform = Compose(
            [
                LoadImaged(keys=["image", "mask","dwi", "adc", "heatmap"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
                # Clip T2 intensities to [0, 99.5] percentiles of the masked image.
                ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
                # Stack T2/DWI/ADC into a single 3-channel image.
                ConcatItemsd(keys=["image", "dwi", "adc"], name="image", dim=0),
                # Normalize each channel using mean/std of the masked voxels.
                NormalizeIntensity_customd(keys=["image"], channel_wise=True, mask_key="mask"),
                # Restrict the heatmap to the prostate: sampling weights = mask * heatmap.
                ElementwiseProductd(keys=["mask", "heatmap"], output_key="final_heatmap"),
                RandWeightedCropd(keys=["image", "final_heatmap"],
                                  w_key="final_heatmap",
                                  spatial_size=(args.tile_size,args.tile_size,args.depth),
                                  num_samples=args.tile_count),
                EnsureTyped(keys=["label"], dtype=torch.float32),
                # Move depth in front of H, W: (C, H, W, D) -> (C, D, H, W).
                Transposed(keys=["image"], indices=(0, 3, 1, 2)),
                DeleteItemsd(keys=['mask', 'dwi', 'adc', 'heatmap']),
                ToTensord(keys=["image", "label", "final_heatmap"]),
            ]
        )
    else:
        transform = Compose(
            [
                LoadImaged(keys=["image", "mask","dwi", "adc"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),
                ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
                # Stack T2/DWI/ADC into a single 3-channel image.
                ConcatItemsd(keys=["image", "dwi", "adc"], name="image", dim=0),
                # Stock per-channel normalization (no mask restriction in this path).
                NormalizeIntensityd(keys=["image"], channel_wise=True),
                # Patches sampled entirely from mask-positive voxels (pos=1, neg=0).
                RandCropByPosNegLabeld(keys=["image"],
                                       label_key="mask",
                                       spatial_size=(args.tile_size,args.tile_size,args.depth),
                                       pos=1,
                                       neg=0,
                                       num_samples=args.tile_count),
                EnsureTyped(keys=["label"], dtype=torch.float32),
                # Move depth in front of H, W: (C, H, W, D) -> (C, D, H, W).
                Transposed(keys=["image"], indices=(0, 3, 1, 2)),
                DeleteItemsd(keys=['mask', 'dwi', 'adc']),
                ToTensord(keys=["image", "label"]),
            ]
        )
    return transform
100
+
101
def get_dataloader(args, split: Literal["train", "test"]):
    """Create a DataLoader for one split of the decathlon-style dataset JSON.

    Samples are preprocessed through :func:`data_transform` and cached on disk
    under ``<logdir>/cache/<split>`` via PersistentDataset; training data is
    shuffled, test data is not.
    """
    data_list = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key=split,
        base_dir=args.data_root,
    )
    if args.dry_run:
        # Tiny subset for smoke tests.
        data_list = data_list[:8]

    split_cache = os.path.join(args.logdir, "cache", split)
    os.makedirs(split_cache, exist_ok=True)

    dataset = PersistentDataset(
        data=data_list,
        transform=data_transform(args),
        cache_dir=split_cache,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=(split == "train"),
        num_workers=args.workers,
        pin_memory=True,
        multiprocessing_context="spawn" if args.workers > 0 else None,
        sampler=None,
        collate_fn=list_data_collate,
    )
125
+
src/model/MIL.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from typing import cast
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from monai.utils.module import optional_import
10
+ from monai.networks.nets import resnet
11
+ models, _ = optional_import("torchvision.models")
12
+
13
+
14
class MILModel_3D(nn.Module):
    """
    Multiple Instance Learning (MIL) model, with a backbone classification model.
    Adapted from MONAI, modified for 3D images. The expected shape of input data is `[B, N, C, D, H, W]`,
    where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances
    extracted from every original image in the batch. A tutorial example is available at:
    https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning.

    Args:
        num_classes: number of output classes.
        mil_mode: MIL algorithm, available values (Defaults to ``"att"``):

            - ``"mean"`` - average features from all instances, equivalent to pure CNN (non MIL).
            - ``"max"`` - retain only the instance with the max probability for loss calculation.
            - ``"att"`` - attention based MIL https://arxiv.org/abs/1802.04712.
            - ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556.
            - ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556.

        pretrained: init backbone with pretrained weights, defaults to ``True``.
        backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features,
            or a string name of a torchvision model).
            Defaults to ``None``, in which case ResNet50 is used.
        backbone_num_features: Number of output features of the backbone CNN
            Defaults to ``None`` (necessary only when using a custom backbone)
        trans_blocks: number of the blocks in `TransformEncoder` layer.
        trans_dropout: dropout rate in `TransformEncoder` layer.

    """

    def __init__(
        self,
        num_classes: int,
        mil_mode: str = "att",
        pretrained: bool = True,
        backbone: str | nn.Module | None = None,
        backbone_num_features: int | None = None,
        trans_blocks: int = 4,
        trans_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        if num_classes <= 0:
            raise ValueError("Number of classes must be positive: " + str(num_classes))

        if mil_mode.lower() not in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]:
            raise ValueError("Unsupported mil_mode: " + str(mil_mode))

        self.mil_mode = mil_mode.lower()
        self.attention = nn.Sequential()
        self.transformer: nn.Module | None = None

        if backbone is None:
            # Default backbone: MONAI 3D ResNet-18 on 3-channel volumes.
            # NOTE(review): num_classes=5 passed here is irrelevant — the fc
            # layer is replaced with Identity immediately below and the real
            # head is `self.myfc`.
            net = resnet.resnet18(spatial_dims=3, n_input_channels=3, num_classes=5, )
            nfc = net.fc.in_features # save the number of final features
            net.fc = torch.nn.Identity() # remove final linear layer

            self.extra_outputs: dict[str, torch.Tensor] = {}

            # NOTE(review): this branch compares the raw `mil_mode` while the
            # constructor validated `mil_mode.lower()`; mixed-case input would
            # silently skip hook registration — confirm.
            if mil_mode == "att_trans_pyramid":
                # register hooks to capture outputs of intermediate layers
                def forward_hook(layer_name):

                    def hook(module, input, output):
                        self.extra_outputs[layer_name] = output

                    return hook

                net.layer1.register_forward_hook(forward_hook("layer1"))
                net.layer2.register_forward_hook(forward_hook("layer2"))
                net.layer3.register_forward_hook(forward_hook("layer3"))
                net.layer4.register_forward_hook(forward_hook("layer4"))

        elif isinstance(backbone, str):
            # assume torchvision model string is provided
            torch_model = getattr(models, backbone, None)
            if torch_model is None:
                raise ValueError("Unknown torch vision model" + str(backbone))
            net = torch_model(pretrained=pretrained)

            if getattr(net, "fc", None) is not None:
                nfc = net.fc.in_features # save the number of final features
                net.fc = torch.nn.Identity() # remove final linear layer
            else:
                # NOTE(review): the trailing comma makes the second string a
                # separate ValueError argument rather than part of the message.
                raise ValueError(
                    "Unable to detect FC layer for the torchvision model " + str(backbone),
                    ". Please initialize the backbone model manually.",
                )

        elif isinstance(backbone, nn.Module):
            # use a custom backbone
            net = backbone
            nfc = backbone_num_features
            net.fc = torch.nn.Identity() # remove final linear layer

            # NOTE(review): duplicates the assignment from the backbone-is-None
            # branch (harmless, only one branch runs).
            self.extra_outputs: dict[str, torch.Tensor] = {}

            if mil_mode == "att_trans_pyramid":
                # register hooks to capture outputs of intermediate layers
                def forward_hook(layer_name):

                    def hook(module, input, output):
                        self.extra_outputs[layer_name] = output

                    return hook

                net.layer1.register_forward_hook(forward_hook("layer1"))
                net.layer2.register_forward_hook(forward_hook("layer2"))
                net.layer3.register_forward_hook(forward_hook("layer3"))
                net.layer4.register_forward_hook(forward_hook("layer4"))

            if backbone_num_features is None:
                raise ValueError("Number of endencoder features must be provided for a custom backbone model")

        else:
            raise ValueError("Unsupported backbone")

        # Pyramid mode requires layer hooks that only the default backbone guarantees.
        if backbone is not None and mil_mode not in ["mean", "max", "att", "att_trans"]:
            raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode))

        if self.mil_mode in ["mean", "max"]:
            pass
        elif self.mil_mode == "att":
            # Gated-style attention scoring each instance embedding.
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        elif self.mil_mode == "att_trans":
            # Transformer over the instance sequence, then attention pooling.
            transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout)
            self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks)
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        elif self.mil_mode == "att_trans_pyramid":
            # One transformer per backbone stage; each later stage consumes the
            # previous transformer output concatenated with its layer features.
            transformer_list = nn.ModuleList(
                [
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout), num_layers=trans_blocks
                    ),
                    nn.Sequential(
                        nn.Linear(192, 64),
                        nn.TransformerEncoder(
                            nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout),
                            num_layers=trans_blocks,
                        ),
                    ),
                    nn.Sequential(
                        nn.Linear(320, 64),
                        nn.TransformerEncoder(
                            nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout),
                            num_layers=trans_blocks,
                        ),
                    ),
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=576, nhead=8, dropout=trans_dropout),
                        num_layers=trans_blocks,
                    ),
                ]
            )
            self.transformer = transformer_list
            # Final embedding is backbone features + 64 pyramid features.
            nfc = nfc + 64
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        else:
            raise ValueError("Unsupported mil_mode: " + str(mil_mode))

        self.myfc = nn.Linear(nfc, num_classes)
        self.net = net

    def calc_head(self, x: torch.Tensor) -> torch.Tensor:
        """Pool per-instance embeddings `[B, N, nfc]` into per-scan logits `[B, num_classes]`."""
        sh = x.shape

        if self.mil_mode == "mean":
            # Classify each instance, then average logits across instances.
            x = self.myfc(x)
            x = torch.mean(x, dim=1)

        elif self.mil_mode == "max":
            # Keep the maximum logit per class across instances.
            x = self.myfc(x)
            x, _ = torch.max(x, dim=1)

        elif self.mil_mode == "att":
            # Attention-weighted sum of embeddings, then classify.
            a = self.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)

            x = self.myfc(x)

        elif self.mil_mode == "att_trans" and self.transformer is not None:
            # Transformer expects (sequence, batch, feature).
            x = x.permute(1, 0, 2)
            x = self.transformer(x)
            x = x.permute(1, 0, 2)

            a = self.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)

            x = self.myfc(x)

        elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None:
            # Spatially average each hooked stage output, reshape to
            # (N, B*?, features) per-instance sequences for the transformers.
            l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)

            transformer_list = cast(nn.ModuleList, self.transformer)

            # Cascade: each stage consumes the previous output concatenated
            # with the next layer's pooled features.
            x = transformer_list[0](l1)
            x = transformer_list[1](torch.cat((x, l2), dim=2))
            x = transformer_list[2](torch.cat((x, l3), dim=2))
            x = transformer_list[3](torch.cat((x, l4), dim=2))

            x = x.permute(1, 0, 2)

            a = self.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)

            x = self.myfc(x)

        else:
            raise ValueError("Wrong model mode" + str(self.mil_mode))

        return x

    def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor:
        """Run the backbone on `[B, N, C, D, H, W]` input; pool with `calc_head` unless `no_head`."""
        sh = x.shape
        # Fold instances into the batch axis so the CNN sees plain 3D volumes.
        x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5])

        x = self.net(x)
        # Unfold back to (B, N, features).
        x = x.reshape(sh[0], sh[1], -1)

        if not no_head:
            x = self.calc_head(x)

        return x
245
+
246
+
247
+
248
+
src/model/__init__.py ADDED
File without changes
src/model/__pycache__/MIL.cpython-39.pyc ADDED
Binary file (6.85 kB). View file
 
src/model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (224 Bytes). View file
 
src/model/__pycache__/csPCa_model.cpython-39.pyc ADDED
Binary file (1.92 kB). View file
 
src/model/csPCa_model.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from typing import cast
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from monai.utils.module import optional_import
10
+ models, _ = optional_import("torchvision.models")
11
+
12
+
13
+
14
class SimpleNN(nn.Module):
    """Small MLP head emitting a sigmoid probability for binary classification.

    Architecture: input_dim -> 256 -> ReLU -> 128 -> ReLU -> Dropout(0.3) -> 1 -> Sigmoid.
    """

    def __init__(self, input_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # squash to (0, 1) since this is binary classification
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Single pass through the sequential stack.
        return self.net(x)


class csPCa_Model(nn.Module):
    """csPCa classifier built on a pretrained MIL backbone plus a fresh binary head.

    Reuses the backbone's per-instance CNN (``backbone.net``), its bag
    transformer and its attention pooling; only ``fc_cspca`` is newly trained.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        # Width of the pooled bag embedding, taken from the backbone's own head.
        self.fc_dim = backbone.myfc.in_features
        self.fc_cspca = SimpleNN(input_dim=self.fc_dim)

    def forward(self, x):
        batch, n_inst = x.shape[0], x.shape[1]
        # (B, N, C, D, H, W) -> (B*N, C, D, H, W): run the CNN once per instance.
        feats = self.backbone.net(x.reshape(batch * n_inst, *x.shape[2:]))
        feats = feats.reshape(batch, n_inst, -1)
        # The transformer expects (sequence, batch, feat); permute back afterwards.
        feats = self.backbone.transformer(feats.permute(1, 0, 2)).permute(1, 0, 2)
        # Attention-weighted pooling over the instance dimension.
        weights = torch.softmax(self.backbone.attention(feats), dim=1)
        pooled = (feats * weights).sum(dim=1)
        return self.fc_cspca(pooled)
50
+
src/preprocessing/__init__.py ADDED
File without changes
src/preprocessing/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (232 Bytes). View file
 
src/preprocessing/__pycache__/center_crop.cpython-39.pyc ADDED
Binary file (2.73 kB). View file
 
src/preprocessing/__pycache__/generate_heatmap.cpython-39.pyc ADDED
Binary file (1.49 kB). View file
 
src/preprocessing/__pycache__/histogram_match.cpython-39.pyc ADDED
Binary file (1.99 kB). View file
 
src/preprocessing/__pycache__/prostate_mask.cpython-39.pyc ADDED
Binary file (3.86 kB). View file
 
src/preprocessing/__pycache__/register_and_crop.cpython-39.pyc ADDED
Binary file (2.13 kB). View file
 
src/preprocessing/center_crop.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 - 2022 MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+ #python scripts/center_crop.py --file_name path/to/t2_image --out_name cropped_t2
12
+
13
+
14
+ #import argparse
15
+ from typing import Union
16
+
17
+ import SimpleITK as sitk # noqa N813
18
+
19
+
20
+ def _flatten(t):
21
+ return [item for sublist in t for item in sublist]
22
+
23
+
24
def crop(image: sitk.Image, margin: Union[int, float], interpolator=sitk.sitkLinear):
    """
    Crops a sitk.Image while retaining correct spacing. Negative margins will lead to zero padding.

    Args:
        image: a sitk.Image
        margin: margins to crop from EACH side of every axis. A single int
            (voxel counts) or float (fractional crop, must be in [0, 0.5)),
            a list of 3 int/float (one per axis), or nested lists
            [[low, high], ...] for asymmetric crops. Mixing ints and floats
            raises ValueError.
        interpolator: SimpleITK interpolator used for the final resample.

    Returns:
        The cropped sitk.Image, resampled onto the new (smaller) grid.
    """
    # Normalise `margin` to the nested form [[low, high], [low, high], [low, high]].
    if isinstance(margin, (list, tuple)):
        assert len(margin) == 3, "expected margin to be of length 3"
    else:
        assert isinstance(margin, (int, float)), "expected margin to be a float value"
        margin = [margin, margin, margin]

    margin = [m if isinstance(m, (tuple, list)) else [m, m] for m in margin]
    old_size = image.GetSize()

    # calculate new origin and new image size
    if all([isinstance(m, float) for m in _flatten(margin)]):
        # Fractional margins: convert each fraction to a voxel count per side.
        assert all([m >= 0 and m < 0.5 for m in _flatten(margin)]), "margins must be between 0 and 0.5"
        to_crop = [[int(sz * _m) for _m in m] for sz, m in zip(old_size, margin)]
    elif all([isinstance(m, int) for m in _flatten(margin)]):
        # Integer margins are already voxel counts.
        to_crop = margin
    else:
        raise ValueError("Wrong format of margins.")

    new_size = [sz - sum(c) for sz, c in zip(old_size, to_crop)]

    # origin has Index (0,0,0)
    # new origin has Index (to_crop[0][0], to_crop[1][0], to_crop[2][0])
    new_origin = image.TransformIndexToPhysicalPoint([c[0] for c in to_crop])

    # create reference plane to resample image; spacing/direction are preserved,
    # only origin and size change.
    ref_image = sitk.Image(new_size, image.GetPixelIDValue())
    ref_image.SetSpacing(image.GetSpacing())
    ref_image.SetOrigin(new_origin)
    ref_image.SetDirection(image.GetDirection())

    return sitk.Resample(image, ref_image, interpolator=interpolator)
64
+
src/preprocessing/generate_heatmap.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import numpy as np
4
+ import nrrd
5
+ import json
6
+ import pandas as pd
7
+ import json
8
+ import SimpleITK as sitk
9
+ import multiprocessing
10
+
11
+ import logging
12
+
13
+
14
def get_heatmap(args):
    """Build a per-case suspicion heatmap inside the prostate mask and write it to disk.

    For every file name in ``args.t2_dir`` the matching prostate mask, DWI and
    ADC volumes are read from ``args.seg_dir`` / ``args.dwi_dir`` / ``args.adc_dir``.
    DWI is min-max normalised inside the mask (high signal = suspicious); ADC is
    inverted (low signal = suspicious). When both maps are usable their product
    is written; when only one is usable it is used alone; when neither is
    (empty mask or constant intensities) a uniform map of ones is written.

    Args:
        args: namespace with ``t2_dir``, ``seg_dir``, ``dwi_dir``, ``adc_dir``
            and ``output_dir`` attributes.

    Returns:
        The same ``args``, with ``args.heatmapdir`` pointing at the output folder.
    """
    files = os.listdir(args.t2_dir)
    args.heatmapdir = os.path.join(args.output_dir, 'heatmaps/')
    os.makedirs(args.heatmapdir, exist_ok=True)

    def _masked_heatmap(volume, mask, invert):
        # Return the masked, [0, 1]-normalised heatmap, or None when the mask is
        # empty or the in-mask intensities are constant (normalisation undefined).
        vals = volume[mask > 0]
        if len(vals) == 0:
            return None
        lo, hi = vals.min(), vals.max()
        if lo == hi:
            return None
        if invert:
            hm = (hi - volume) / (hi - lo)
        else:
            hm = (volume - lo) / (hi - lo)
        # Outside the mask, clamp to the minimum in-mask value.
        return np.where(mask > 0, hm, hm[mask > 0].min()).astype(np.float32)

    for file in files:
        mask, _ = nrrd.read(os.path.join(args.seg_dir, file))
        dwi, _ = nrrd.read(os.path.join(args.dwi_dir, file))
        adc, _ = nrrd.read(os.path.join(args.adc_dir, file))

        heat_dwi = _masked_heatmap(dwi, mask, invert=False)
        heat_adc = _masked_heatmap(adc, mask, invert=True)

        # BUGFIX: the old flag-based if-chain referenced an undefined heatmap
        # (NameError) when both modalities were degenerate, and its dangling
        # `else` overwrote single-modality results with a map of ones.
        if heat_dwi is not None and heat_adc is not None:
            mix_mask = heat_dwi * heat_adc
        elif heat_dwi is not None:
            mix_mask = heat_dwi
        elif heat_adc is not None:
            mix_mask = heat_adc
        else:
            mix_mask = np.ones_like(adc, dtype=np.float32)

        # Final rescale to [0, 1]; guard the constant case to avoid 0/0.
        rng = mix_mask.max() - mix_mask.min()
        if rng > 0:
            mix_mask = (mix_mask - mix_mask.min()) / rng

        nrrd.write(os.path.join(args.heatmapdir, file), mix_mask)

    return args
73
+
74
+
75
+
76
+
src/preprocessing/histogram_match.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import SimpleITK as sitk
2
+ import os
3
+ import numpy as np
4
+ import nrrd
5
+ from tqdm import tqdm
6
+ import pandas as pd
7
+ import random
8
+ import json
9
+ from skimage import exposure
10
+ import multiprocessing
11
+ import logging
12
def get_histmatched(data, ref_data, mask, ref_mask):
    """Histogram-match the in-mask voxels of ``data`` to those of ``ref_data``.

    Only voxels where ``mask > 0`` are rewritten; everything outside the mask
    is returned unchanged.
    """
    in_mask = mask > 0
    matched_vals = exposure.match_histograms(data[in_mask], ref_data[ref_mask > 0])
    result = data.copy()
    result[in_mask] = matched_vals
    return result
21
+
22
def histmatch(args):
    """Histogram-match every case's T2/DWI/ADC to the bundled reference case.

    Reads each modality from ``args.t2_dir`` / ``args.dwi_dir`` / ``args.adc_dir``,
    matches the in-prostate intensity histogram against the reference volumes
    shipped in ``<project_dir>/dataset``, and writes the results into new
    ``*_histmatched`` folders under ``args.output_dir``. The args directories
    are re-pointed to the matched outputs so later pipeline stages use them.

    Args:
        args: namespace with ``t2_dir``, ``dwi_dir``, ``adc_dir``, ``seg_dir``,
            ``output_dir`` and ``project_dir``.

    Returns:
        The mutated ``args``.
    """
    files = os.listdir(args.t2_dir)

    t2_histmatched_dir = os.path.join(args.output_dir, 't2_histmatched')
    dwi_histmatched_dir = os.path.join(args.output_dir, 'DWI_histmatched')
    adc_histmatched_dir = os.path.join(args.output_dir, 'ADC_histmatched')
    os.makedirs(t2_histmatched_dir, exist_ok=True)
    os.makedirs(dwi_histmatched_dir, exist_ok=True)
    os.makedirs(adc_histmatched_dir, exist_ok=True)

    # PERF: the reference volumes are identical for every case — read them once
    # up front instead of re-reading all four files on every loop iteration.
    dataset_dir = os.path.join(args.project_dir, 'dataset')
    ref_t2, _ = nrrd.read(os.path.join(dataset_dir, 't2_reference.nrrd'))
    ref_dwi, _ = nrrd.read(os.path.join(dataset_dir, 'dwi_reference.nrrd'))
    ref_adc, _ = nrrd.read(os.path.join(dataset_dir, 'adc_reference.nrrd'))
    ref_prostate_mask, _ = nrrd.read(
        os.path.join(dataset_dir, 'prostate_segmentation_reference.nrrd')
    )

    logging.info("Starting histogram matching")
    for file in files:
        t2_image, header_t2 = nrrd.read(os.path.join(args.t2_dir, file))
        dwi_image, header_dwi = nrrd.read(os.path.join(args.dwi_dir, file))
        adc_image, header_adc = nrrd.read(os.path.join(args.adc_dir, file))
        prostate_mask, _ = nrrd.read(os.path.join(args.seg_dir, file))

        histmatched_t2 = get_histmatched(t2_image, ref_t2, prostate_mask, ref_prostate_mask)
        histmatched_dwi = get_histmatched(dwi_image, ref_dwi, prostate_mask, ref_prostate_mask)
        histmatched_adc = get_histmatched(adc_image, ref_adc, prostate_mask, ref_prostate_mask)

        # Preserve each modality's original NRRD header on write.
        nrrd.write(os.path.join(t2_histmatched_dir, file), histmatched_t2, header_t2)
        nrrd.write(os.path.join(dwi_histmatched_dir, file), histmatched_dwi, header_dwi)
        nrrd.write(os.path.join(adc_histmatched_dir, file), histmatched_adc, header_adc)

    args.t2_dir = t2_histmatched_dir
    args.dwi_dir = dwi_histmatched_dir
    args.adc_dir = adc_histmatched_dir

    return args
61
+
62
+
src/preprocessing/prostate_mask.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Union
3
+ import SimpleITK as sitk
4
+ import numpy as np
5
+ import nrrd
6
+ import matplotlib.pyplot as plt
7
+ from tqdm import tqdm
8
+ from AIAH_utility.viewer import BasicViewer, ListViewer
9
+ from PIL import Image
10
+ import monai
11
+ from monai.bundle import ConfigParser
12
+ from monai.config import print_config
13
+ import torch
14
+ import sys
15
+ import os
16
+ import nibabel as nib
17
+ import shutil
18
+
19
+ from tqdm import trange, tqdm
20
+
21
+ from monai.data import DataLoader, Dataset, TestTimeAugmentation, create_test_image_2d
22
+ from monai.losses import DiceLoss
23
+ from monai.metrics import DiceMetric
24
+ from monai.networks.nets import UNet
25
+ from monai.transforms import (
26
+ Activationsd,
27
+ AsDiscreted,
28
+ Compose,
29
+ CropForegroundd,
30
+ DivisiblePadd,
31
+ Invertd,
32
+ LoadImaged,
33
+ ScaleIntensityd,
34
+ RandRotated,
35
+ RandRotate,
36
+ InvertibleTransform,
37
+ RandFlipd,
38
+ Activations,
39
+ AsDiscrete,
40
+ NormalizeIntensityd,
41
+ )
42
+ from monai.utils import set_determinism
43
+ from monai.transforms import (
44
+ Resize,
45
+ EnsureChannelFirstd,
46
+ Orientationd,
47
+ Spacingd,
48
+ EnsureTyped,
49
+ )
50
+ import nrrd
51
+
52
+ set_determinism(43)
53
+ from monai.data import MetaTensor
54
+ import SimpleITK as sitk
55
+ import pandas as pd
56
+ import logging
57
def get_segmask(args):
    """Segment the prostate on every T2 volume using a pretrained MONAI bundle.

    Loads the bundle config from ``<project_dir>/config/inference.json`` and the
    checkpoint from ``<project_dir>/models/prostate_segmentation_model.pt``, runs
    sliding-window inference on each file in ``args.t2_dir``, inverts the
    preprocessing transforms to map predictions back to the original grid, keeps
    only the 10 axial slices with the most foreground, and writes each mask to
    ``args.seg_dir`` (created under ``args.output_dir``).

    Args:
        args: namespace with ``project_dir``, ``t2_dir`` and ``output_dir``.

    Returns:
        The same ``args`` with ``args.seg_dir`` set to the mask folder.
    """
    args.seg_dir = os.path.join(args.output_dir, "prostate_mask")
    os.makedirs(args.seg_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_config_file = os.path.join(args.project_dir, "config", "inference.json")
    model_config = ConfigParser()
    model_config.read_config(model_config_file)
    # Point the bundle at this run's folders before parsing its components.
    model_config["output_dir"] = args.seg_dir
    model_config["dataset_dir"] = args.t2_dir
    files = os.listdir(args.t2_dir)
    model_config["datalist"] = [os.path.join(args.t2_dir, f) for f in files]


    checkpoint = os.path.join(
        args.project_dir,
        "models",
        "prostate_segmentation_model.pt",
    )
    # NOTE(review): `preprocessing`, `postprocessing` and `dataloader` are parsed
    # from the bundle but never used below — inference goes through the manual
    # `transform` pipeline instead. Confirm whether they can be dropped.
    preprocessing = model_config.get_parsed_content("preprocessing")
    model = model_config.get_parsed_content("network_def").to(device)
    inferer = model_config.get_parsed_content("inferer")
    postprocessing = model_config.get_parsed_content("postprocessing")
    dataloader = model_config.get_parsed_content("dataloader")
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()

    torch.cuda.empty_cache()
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    keys = "image"
    # Invertible preprocessing: every spatial op here is undone after inference
    # via `transform.inverse` so masks land back on the original T2 grid.
    transform = Compose(
        [
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
            Orientationd(keys=keys, axcodes="RAS"),
            Spacingd(keys=keys, pixdim=[0.5, 0.5, 0.5], mode="bilinear"),
            ScaleIntensityd(keys=keys, minv=0, maxv=1),
            NormalizeIntensityd(keys=keys),
            EnsureTyped(keys=keys),
        ]
    )
    logging.info("Starting prostate segmentation")
    for file in tqdm(files):

        data = {"image": os.path.join(args.t2_dir, file)}
        transformed_data = transform(data)
        a = transformed_data
        with torch.no_grad():
            # Add a batch dimension of 1 before running the sliding-window inferer.
            images = a["image"].reshape(1, *(a["image"].shape)).to(device)
            data["pred"] = inferer(images, network=model)
        # argmax over the channel axis collapses logits to a label map.
        pred_img = data["pred"].argmax(1).cpu()

        # Wrap the prediction with the input's metadata so `transform.inverse`
        # can resample it back to the original orientation/spacing.
        model_output = {}
        model_output["image"] = MetaTensor(pred_img, meta=transformed_data["image"].meta)
        transformed_data["image"].data = model_output["image"].data
        temp = transform.inverse(transformed_data)
        pred_temp = temp["image"][0].numpy()
        pred_nrrd = np.round(pred_temp)

        # Keep only the 10 axial slices with the most foreground voxels;
        # everything else is zeroed out.
        nonzero_counts = np.count_nonzero(pred_nrrd, axis=(0,1))
        top_slices = np.argsort(nonzero_counts)[-10:]
        output_ = np.zeros_like(pred_nrrd)
        output_[:,:,top_slices] = pred_nrrd[:,:,top_slices]

        nrrd.write(os.path.join(args.seg_dir, file), output_)

    return args
126
+
127
+
128
+
src/preprocessing/register_and_crop.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import SimpleITK as sitk
2
+ import os
3
+ import numpy as np
4
+ import nrrd
5
+ from tqdm import tqdm
6
+ import pandas as pd
7
+ from picai_prep.preprocessing import PreprocessingSettings, Sample
8
+ import multiprocessing
9
+ from .center_crop import crop
10
+ import logging
11
def register_files(args):
    """Co-register each case's T2/DWI/ADC onto a common grid and centre-crop them.

    Uses picai_prep to resample all three modalities to 0.4 x 0.4 x 3.0 mm on a
    grid sized to cover the original T2 field of view, then crops a margin of
    ``args.margin`` per side in-plane (none along the slice axis). Outputs go to
    ``*_registered`` folders under ``args.output_dir``, and the args directories
    are re-pointed there for downstream steps.

    Returns:
        The mutated ``args``.
    """
    files = os.listdir(args.t2_dir)
    new_spacing = (0.4, 0.4, 3.0)

    out_dirs = {
        't2': os.path.join(args.output_dir, 't2_registered'),
        'dwi': os.path.join(args.output_dir, 'DWI_registered'),
        'adc': os.path.join(args.output_dir, 'ADC_registered'),
    }
    for directory in out_dirs.values():
        os.makedirs(directory, exist_ok=True)

    logging.info("Starting registration and cropping")
    for file in tqdm(files):
        t2_image = sitk.ReadImage(os.path.join(args.t2_dir, file))
        dwi_image = sitk.ReadImage(os.path.join(args.dwi_dir, file))
        adc_image = sitk.ReadImage(os.path.join(args.adc_dir, file))

        # Grid size preserving the T2 field of view at the target spacing.
        new_size = [
            int(round(size * spacing / target))
            for size, spacing, target in zip(
                t2_image.GetSize(), t2_image.GetSpacing(), new_spacing
            )
        ]

        # picai_prep takes spacing/matrix_size in (z, y, x) order.
        pat_case = Sample(
            scans=[t2_image, dwi_image, adc_image],
            settings=PreprocessingSettings(
                spacing=[3.0, 0.4, 0.4],
                matrix_size=[new_size[2], new_size[1], new_size[0]],
            ),
        )
        pat_case.preprocess()

        # NOTE(review): instance-dict access kept from the original; presumably
        # equivalent to `pat_case.scans` — verify against picai_prep.Sample.
        t2_post, dwi_post, adc_post = pat_case.__dict__['scans']
        margins = [args.margin, args.margin, 0.0]

        sitk.WriteImage(crop(t2_post, margins), os.path.join(out_dirs['t2'], file))
        sitk.WriteImage(crop(dwi_post, margins), os.path.join(out_dirs['dwi'], file))
        sitk.WriteImage(crop(adc_post, margins), os.path.join(out_dirs['adc'], file))

    args.t2_dir = out_dirs['t2']
    args.dwi_dir = out_dirs['dwi']
    args.adc_dir = out_dirs['adc']

    return args
67
+
src/train/__init__.py ADDED
File without changes
src/train/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (224 Bytes). View file
 
src/train/__pycache__/train_cspca.cpython-39.pyc ADDED
Binary file (4.66 kB). View file
 
src/train/__pycache__/train_pirads.cpython-39.pyc ADDED
Binary file (6.63 kB). View file
 
src/train/train_cspca.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import collections.abc
3
+ import os
4
+ import shutil
5
+ import time
6
+ import yaml
7
+ from scipy.stats import pearsonr
8
+ import gdown
9
+ import numpy as np
10
+ import torch
11
+ import torch.distributed as dist
12
+ import torch.multiprocessing as mp
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from monai.config import KeysCollection
16
+ from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
17
+ from monai.data.wsi_reader import WSIReader
18
+ from monai.metrics import Cumulative, CumulativeAverage
19
+ from monai.networks.nets import milmodel, resnet, MILModel
20
+ from monai.transforms import (
21
+ Compose,
22
+ GridPatchd,
23
+ LoadImaged,
24
+ MapTransform,
25
+ RandFlipd,
26
+ RandGridPatchd,
27
+ RandRotate90d,
28
+ ScaleIntensityRanged,
29
+ SplitDimd,
30
+ ToTensord,
31
+ ConcatItemsd,
32
+ SelectItemsd,
33
+ EnsureChannelFirstd,
34
+ RepeatChanneld,
35
+ DeleteItemsd,
36
+ EnsureTyped,
37
+ ClipIntensityPercentilesd,
38
+ MaskIntensityd,
39
+ HistogramNormalized,
40
+ RandBiasFieldd,
41
+ RandCropByPosNegLabeld,
42
+ NormalizeIntensityd,
43
+ SqueezeDimd,
44
+ CropForegroundd,
45
+ ScaleIntensityd,
46
+ SpatialPadd,
47
+ CenterSpatialCropd,
48
+ ScaleIntensityd,
49
+ Transposed,
50
+ RandWeightedCropd,
51
+ )
52
+ from sklearn.metrics import cohen_kappa_score, roc_curve, confusion_matrix
53
+ from torch.cuda.amp import GradScaler, autocast
54
+ from torch.utils.data.dataloader import default_collate
55
+ from torchvision.models.resnet import ResNet50_Weights
56
+ import torch.optim as optim
57
+ from torch.utils.data.distributed import DistributedSampler
58
+ from torch.utils.tensorboard import SummaryWriter
59
+ import matplotlib.pyplot as plt
60
+ import matplotlib.patches as patches
61
+ from tqdm import tqdm
62
+ from sklearn.metrics import confusion_matrix, roc_auc_score
63
+ from sklearn.metrics import roc_auc_score
64
+ from sklearn.preprocessing import label_binarize
65
+ import numpy as np
66
+ from AIAH_utility.viewer import BasicViewer
67
+ from scipy.special import expit
68
+ import nrrd
69
+ import random
70
+ from sklearn.metrics import roc_auc_score
71
+ import SimpleITK as sitk
72
+ from AIAH_utility.viewer import BasicViewer
73
+ import pandas as pd
74
+ import json
75
+ from sklearn.preprocessing import StandardScaler
76
+ from torch.utils.data import DataLoader, TensorDataset, Dataset
77
+ from sklearn.linear_model import LogisticRegression
78
+ from sklearn.utils import resample
79
+ import monai
80
+
81
def train_epoch(cspca_model, loader, optimizer, epoch, args):
    """Run one training epoch.

    Args:
        cspca_model: model producing sigmoid probabilities of shape (B, 1).
        loader: dataloader yielding dicts with "image" and binary "label".
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (unused here; kept for a uniform signature).
        args: namespace providing the target ``device``.

    Returns:
        Tuple of (mean BCE loss over batches, epoch ROC-AUC).
    """
    cspca_model.train()
    criterion = nn.BCELoss()
    running_loss = CumulativeAverage()
    all_targets = Cumulative()
    all_preds = Cumulative()

    for batch in loader:
        images = batch["image"].as_subclass(torch.Tensor).to(args.device)
        labels = batch["label"].as_subclass(torch.Tensor).to(args.device)

        optimizer.zero_grad()
        # (B, 1) -> (B) so the probabilities line up with the label vector.
        probs = cspca_model(images).squeeze(1)
        batch_loss = criterion(probs, labels)
        batch_loss.backward()
        optimizer.step()

        all_targets.extend(labels.detach().cpu())
        all_preds.extend(probs.detach().cpu())
        running_loss.append(batch_loss.item())

    loss_epoch = running_loss.aggregate()
    auc_epoch = roc_auc_score(
        all_targets.get_buffer().cpu().numpy(),
        all_preds.get_buffer().cpu().numpy(),
    )
    return loss_epoch, auc_epoch
110
+
111
def val_epoch(cspca_model, loader, epoch, args):
    """Evaluate one epoch without gradients.

    Args:
        cspca_model: model producing sigmoid probabilities of shape (B, 1).
        loader: dataloader yielding dicts with "image" and binary "label".
        epoch: current epoch index, echoed into the returned metrics.
        args: namespace providing the target ``device``.

    Returns:
        Dict with keys 'epoch', 'loss', 'auc', 'sensitivity', 'specificity'.
    """
    cspca_model.eval()
    criterion = nn.BCELoss()
    running_loss = CumulativeAverage()
    all_targets = Cumulative()
    all_preds = Cumulative()

    with torch.no_grad():
        for batch in loader:
            images = batch["image"].as_subclass(torch.Tensor).to(args.device)
            labels = batch["label"].as_subclass(torch.Tensor).to(args.device)

            probs = cspca_model(images).squeeze(1)
            batch_loss = criterion(probs, labels)

            all_targets.extend(labels.detach().cpu())
            all_preds.extend(probs.detach().cpu())
            running_loss.append(batch_loss.item())

    loss_epoch = running_loss.aggregate()
    target_list = all_targets.get_buffer().cpu().numpy()
    pred_list = all_preds.get_buffer().cpu().numpy()
    auc_epoch = roc_auc_score(target_list, pred_list)

    # Threshold probabilities at 0.5, then derive sensitivity/specificity from
    # the binary confusion matrix.
    tn, fp, fn, tp = confusion_matrix(target_list, pred_list >= 0.5).ravel()
    return {
        'epoch': epoch,
        'loss': loss_epoch,
        'auc': auc_epoch,
        'sensitivity': tp / (tp + fn),
        'specificity': tn / (tn + fp),
    }
141
+