Spaces:
Runtime error
Runtime error
Anirudh Balaraman commited on
Commit ·
906fcb9
1
Parent(s): 54bae7b
add scripts
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +2 -0
- config/config_cspca_test.yaml +17 -0
- config/config_cspca_train.yaml +19 -0
- config/config_pirads_test.yaml +18 -0
- config/config_pirads_train.yaml +22 -0
- config/config_preprocess.yaml +11 -0
- config/inference.json +153 -0
- dataset/PI-RADS_data.json +0 -0
- dataset/PICAI_cspca.json +0 -0
- dataset/adc_reference.nrrd +3 -0
- dataset/dwi_reference.nrrd +3 -0
- dataset/prostate_segmentation_reference.nrrd +3 -0
- dataset/t2_reference.nrrd +3 -0
- job_scripts/train_cspca.sh +19 -0
- preprocess_main.py +68 -0
- pyproject.toml +0 -0
- run_cspca.py +220 -0
- run_pirads.py +283 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-39.pyc +0 -0
- src/__pycache__/utils.cpython-39.pyc +0 -0
- src/data/__init__.py +0 -0
- src/data/__pycache__/__init__.cpython-39.pyc +0 -0
- src/data/__pycache__/custom_transforms.cpython-39.pyc +0 -0
- src/data/__pycache__/data_loader.cpython-39.pyc +0 -0
- src/data/custom_transforms.py +350 -0
- src/data/data_loader.py +125 -0
- src/model/MIL.py +248 -0
- src/model/__init__.py +0 -0
- src/model/__pycache__/MIL.cpython-39.pyc +0 -0
- src/model/__pycache__/__init__.cpython-39.pyc +0 -0
- src/model/__pycache__/csPCa_model.cpython-39.pyc +0 -0
- src/model/csPCa_model.py +50 -0
- src/preprocessing/__init__.py +0 -0
- src/preprocessing/__pycache__/__init__.cpython-39.pyc +0 -0
- src/preprocessing/__pycache__/center_crop.cpython-39.pyc +0 -0
- src/preprocessing/__pycache__/generate_heatmap.cpython-39.pyc +0 -0
- src/preprocessing/__pycache__/histogram_match.cpython-39.pyc +0 -0
- src/preprocessing/__pycache__/prostate_mask.cpython-39.pyc +0 -0
- src/preprocessing/__pycache__/register_and_crop.cpython-39.pyc +0 -0
- src/preprocessing/center_crop.py +64 -0
- src/preprocessing/generate_heatmap.py +76 -0
- src/preprocessing/histogram_match.py +62 -0
- src/preprocessing/prostate_mask.py +128 -0
- src/preprocessing/register_and_crop.py +67 -0
- src/train/__init__.py +0 -0
- src/train/__pycache__/__init__.cpython-39.pyc +0 -0
- src/train/__pycache__/train_cspca.cpython-39.pyc +0 -0
- src/train/__pycache__/train_pirads.cpython-39.pyc +0 -0
- src/train/train_cspca.py +141 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
logs/
|
| 2 |
+
models/
|
config/config_cspca_test.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
|
| 2 |
+
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
|
| 3 |
+
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PICAI_cspca.json
|
| 4 |
+
num_classes: !!int 4
|
| 5 |
+
mil_mode: att_trans
|
| 6 |
+
tile_count: !!int 24
|
| 7 |
+
tile_size: !!int 64
|
| 8 |
+
depth: !!int 3
|
| 9 |
+
use_heatmap: !!bool True
|
| 10 |
+
workers: !!int 6
|
| 11 |
+
checkpoint_cspca: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/cspca_model.pth
|
| 12 |
+
num_seeds: !!int 2
|
| 13 |
+
batch_size: !!int 1
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
config/config_cspca_train.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
|
| 2 |
+
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
|
| 3 |
+
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PICAI_cspca.json
|
| 4 |
+
num_classes: !!int 4
|
| 5 |
+
mil_mode: att_trans
|
| 6 |
+
tile_count: !!int 24
|
| 7 |
+
tile_size: !!int 64
|
| 8 |
+
depth: !!int 3
|
| 9 |
+
use_heatmap: !!bool True
|
| 10 |
+
workers: !!int 6
|
| 11 |
+
checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/pirads.pt
|
| 12 |
+
epochs: !!int 1
|
| 13 |
+
batch_size: !!int 8
|
| 14 |
+
optim_lr: !!float 2e-4
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
config/config_pirads_test.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run_name: pirads_test_run
|
| 2 |
+
project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
|
| 3 |
+
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
|
| 4 |
+
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PI-RADS_data.json
|
| 5 |
+
num_classes: !!int 4
|
| 6 |
+
mil_mode: att_trans
|
| 7 |
+
tile_count: !!int 24
|
| 8 |
+
tile_size: !!int 64
|
| 9 |
+
depth: !!int 3
|
| 10 |
+
use_heatmap: !!bool True
|
| 11 |
+
workers: !!int 0
|
| 12 |
+
checkpoint: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/pirads.pt
|
| 13 |
+
amp: !!bool True
|
| 14 |
+
dry_run: !!bool True
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
config/config_pirads_train.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
|
| 2 |
+
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
|
| 3 |
+
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/dataset/PI-RADS_data.json
|
| 4 |
+
num_classes: !!int 4
|
| 5 |
+
mil_mode: att_trans
|
| 6 |
+
tile_count: !!int 24
|
| 7 |
+
tile_size: !!int 64
|
| 8 |
+
depth: !!int 3
|
| 9 |
+
use_heatmap: !!bool True
|
| 10 |
+
workers: !!int 0
|
| 11 |
+
checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/models/pirads.pt
|
| 12 |
+
epochs: !!int 2
|
| 13 |
+
batch_size: !!int 8
|
| 14 |
+
optim_lr: !!float 2e-4
|
| 15 |
+
weight_decay: !!float 1e-5
|
| 16 |
+
amp: !!bool True
|
| 17 |
+
wandb: !!bool True
|
| 18 |
+
dry_run: !!bool True
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
config/config_preprocess.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
t2_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/t2
|
| 2 |
+
dwi_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/dwi
|
| 3 |
+
adc_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/adc
|
| 4 |
+
output_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/processed
|
| 5 |
+
project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
config/inference.json
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"imports": [
|
| 3 |
+
"$import pandas as pd",
|
| 4 |
+
"$import os"
|
| 5 |
+
],
|
| 6 |
+
"bundle_root": "/workspace/data/prostate_mri_anatomy",
|
| 7 |
+
"output_dir": "$@bundle_root + '/eval'",
|
| 8 |
+
"dataset_dir": "/workspace/data/prostate158/prostate158_train/",
|
| 9 |
+
"datalist": "$list(@dataset_dir + pd.read_csv(@dataset_dir + 'valid.csv').t2)",
|
| 10 |
+
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
|
| 11 |
+
"network_def": {
|
| 12 |
+
"_target_": "UNet",
|
| 13 |
+
"spatial_dims": 3,
|
| 14 |
+
"in_channels": 1,
|
| 15 |
+
"out_channels": 3,
|
| 16 |
+
"channels": [
|
| 17 |
+
16,
|
| 18 |
+
32,
|
| 19 |
+
64,
|
| 20 |
+
128,
|
| 21 |
+
256,
|
| 22 |
+
512
|
| 23 |
+
],
|
| 24 |
+
"strides": [
|
| 25 |
+
2,
|
| 26 |
+
2,
|
| 27 |
+
2,
|
| 28 |
+
2,
|
| 29 |
+
2
|
| 30 |
+
],
|
| 31 |
+
"num_res_units": 4,
|
| 32 |
+
"norm": "batch",
|
| 33 |
+
"act": "prelu",
|
| 34 |
+
"dropout": 0.15
|
| 35 |
+
},
|
| 36 |
+
"network": "$@network_def.to(@device)",
|
| 37 |
+
"preprocessing": {
|
| 38 |
+
"_target_": "Compose",
|
| 39 |
+
"transforms": [
|
| 40 |
+
{
|
| 41 |
+
"_target_": "LoadImaged",
|
| 42 |
+
"keys": "image"
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"_target_": "EnsureChannelFirstd",
|
| 46 |
+
"keys": "image"
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"_target_": "Orientationd",
|
| 50 |
+
"keys": "image",
|
| 51 |
+
"axcodes": "RAS"
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"_target_": "Spacingd",
|
| 55 |
+
"keys": "image",
|
| 56 |
+
"pixdim": [
|
| 57 |
+
0.5,
|
| 58 |
+
0.5,
|
| 59 |
+
0.5
|
| 60 |
+
],
|
| 61 |
+
"mode": "bilinear"
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"_target_": "ScaleIntensityd",
|
| 65 |
+
"keys": "image",
|
| 66 |
+
"minv": 0,
|
| 67 |
+
"maxv": 1
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"_target_": "NormalizeIntensityd",
|
| 71 |
+
"keys": "image"
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"_target_": "EnsureTyped",
|
| 75 |
+
"keys": "image"
|
| 76 |
+
}
|
| 77 |
+
]
|
| 78 |
+
},
|
| 79 |
+
"dataset": {
|
| 80 |
+
"_target_": "Dataset",
|
| 81 |
+
"data": "$[{'image': i} for i in @datalist]",
|
| 82 |
+
"transform": "@preprocessing"
|
| 83 |
+
},
|
| 84 |
+
"dataloader": {
|
| 85 |
+
"_target_": "DataLoader",
|
| 86 |
+
"dataset": "@dataset",
|
| 87 |
+
"batch_size": 1,
|
| 88 |
+
"shuffle": false,
|
| 89 |
+
"num_workers": 4
|
| 90 |
+
},
|
| 91 |
+
"inferer": {
|
| 92 |
+
"_target_": "SlidingWindowInferer",
|
| 93 |
+
"roi_size": [
|
| 94 |
+
96,
|
| 95 |
+
96,
|
| 96 |
+
96
|
| 97 |
+
],
|
| 98 |
+
"sw_batch_size": 4,
|
| 99 |
+
"overlap": 0.5
|
| 100 |
+
},
|
| 101 |
+
"postprocessing": {
|
| 102 |
+
"_target_": "Compose",
|
| 103 |
+
"transforms": [
|
| 104 |
+
{
|
| 105 |
+
"_target_": "AsDiscreted",
|
| 106 |
+
"keys": "pred",
|
| 107 |
+
"argmax": true
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"_target_": "KeepLargestConnectedComponentd",
|
| 111 |
+
"keys": "pred",
|
| 112 |
+
"applied_labels": [
|
| 113 |
+
1,
|
| 114 |
+
2
|
| 115 |
+
]
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"_target_": "SaveImaged",
|
| 119 |
+
"keys": "pred",
|
| 120 |
+
"resample": false,
|
| 121 |
+
"meta_keys": "pred_meta_dict",
|
| 122 |
+
"output_dir": "@output_dir"
|
| 123 |
+
}
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
"handlers": [
|
| 127 |
+
{
|
| 128 |
+
"_target_": "CheckpointLoader",
|
| 129 |
+
"load_path": "$@bundle_root + '/models/model.pt'",
|
| 130 |
+
"load_dict": {
|
| 131 |
+
"model": "@network"
|
| 132 |
+
}
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"_target_": "StatsHandler",
|
| 136 |
+
"iteration_log": false
|
| 137 |
+
}
|
| 138 |
+
],
|
| 139 |
+
"evaluator": {
|
| 140 |
+
"_target_": "SupervisedEvaluator",
|
| 141 |
+
"device": "@device",
|
| 142 |
+
"val_data_loader": "@dataloader",
|
| 143 |
+
"network": "@network",
|
| 144 |
+
"inferer": "@inferer",
|
| 145 |
+
"postprocessing": "@postprocessing",
|
| 146 |
+
"val_handlers": "@handlers",
|
| 147 |
+
"amp": true
|
| 148 |
+
},
|
| 149 |
+
"evaluating": [
|
| 150 |
+
"$setattr(torch.backends.cudnn, 'benchmark', True)",
|
| 151 |
+
"$@evaluator.run()"
|
| 152 |
+
]
|
| 153 |
+
}
|
dataset/PI-RADS_data.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/PICAI_cspca.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/adc_reference.nrrd
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:46821ddb4373198d98b5877ec69a702718702cbaae6b9ef00b6b5ad235cf3f3e
|
| 3 |
+
size 3815961
|
dataset/dwi_reference.nrrd
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a964835a0c8c016b162a2eaa56456f9704290b3a4128aafeb9b01805166ca7b9
|
| 3 |
+
size 3815961
|
dataset/prostate_segmentation_reference.nrrd
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:95984795649662ae405077d334744ac2e9d9fb8db7864c4ea99791a706ccee19
|
| 3 |
+
size 13434
|
dataset/t2_reference.nrrd
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:561af922d7461870fcc1e84134d659f7611a8ae73d1a0681fde79808f1cd99a9
|
| 3 |
+
size 3815961
|
job_scripts/train_cspca.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=cspca_training # Specify job name
|
| 3 |
+
#SBATCH --partition=gpu # Specify partition name
|
| 4 |
+
#SBATCH --mem=128G
|
| 5 |
+
#SBATCH --gres=gpu:1
|
| 6 |
+
#SBATCH --time=48:00:00 # Set a limit on the total run time
|
| 7 |
+
#SBATCH --output=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/logs/%x/log.o%j # File name for standard output
|
| 8 |
+
#SBATCH --error=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/logs/%x/log.e%j # File name for standard error output
|
| 9 |
+
#SBATCH --mail-user=anirudh.balaraman@charite.de
|
| 10 |
+
#SBATCH --mail-type=END,FAIL
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
source /etc/profile.d/conda.sh
|
| 14 |
+
conda activate foundation
|
| 15 |
+
|
| 16 |
+
RUNDIR="/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
srun python -u $RUNDIR/MIL/new_folder/run_cspca.py --mode train --config $RUNDIR/MIL/new_folder/config/config_cspca_train.yaml
|
preprocess_main.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import SimpleITK as sitk
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import nrrd
|
| 5 |
+
from AIAH_utility.viewer import BasicViewer
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from picai_prep.preprocessing import PreprocessingSettings, Sample
|
| 9 |
+
import multiprocessing
|
| 10 |
+
import sys
|
| 11 |
+
from src.preprocessing.register_and_crop import register_files
|
| 12 |
+
from src.preprocessing.prostate_mask import get_segmask
|
| 13 |
+
from src.preprocessing.histogram_match import histmatch
|
| 14 |
+
from src.preprocessing.generate_heatmap import get_heatmap
|
| 15 |
+
import logging
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from src.utils import setup_logging
|
| 18 |
+
from src.utils import validate_steps
|
| 19 |
+
import argparse
|
| 20 |
+
import yaml
|
| 21 |
+
|
| 22 |
+
def parse_args():
|
| 23 |
+
FUNCTIONS = {
|
| 24 |
+
"register_and_crop": register_files,
|
| 25 |
+
"histogram_match": histmatch,
|
| 26 |
+
"get_segmentation_mask": get_segmask,
|
| 27 |
+
"get_heatmap": get_heatmap,
|
| 28 |
+
}
|
| 29 |
+
parser = argparse.ArgumentParser(description="File preprocessing")
|
| 30 |
+
parser.add_argument("--config", type=str, help="Path to YAML config file")
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--steps",
|
| 33 |
+
nargs="+", # ← list of strings
|
| 34 |
+
choices=FUNCTIONS.keys(), # ← restrict allowed values
|
| 35 |
+
required=True,
|
| 36 |
+
help="Steps to execute (one or more)"
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
|
| 39 |
+
parser.add_argument("--dwi_dir", default=None, help="Path to DWI files")
|
| 40 |
+
parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
|
| 41 |
+
parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
|
| 42 |
+
parser.add_argument("--output_dir", default=None, help="Path to output folder")
|
| 43 |
+
parser.add_argument("--margin", default=0.2, type=float, help="Margin to center crop the images")
|
| 44 |
+
parser.add_argument("--project_dir", default=None, help="Project directory")
|
| 45 |
+
|
| 46 |
+
args = parser.parse_args()
|
| 47 |
+
if args.config:
|
| 48 |
+
with open(args.config, 'r') as config_file:
|
| 49 |
+
config = yaml.safe_load(config_file)
|
| 50 |
+
args.__dict__.update(config)
|
| 51 |
+
return args
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
args = parse_args()
|
| 55 |
+
FUNCTIONS = {
|
| 56 |
+
"register_and_crop": register_files,
|
| 57 |
+
"histogram_match": histmatch,
|
| 58 |
+
"get_segmentation_mask": get_segmask,
|
| 59 |
+
"get_heatmap": get_heatmap,
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
args.logfile = os.path.join(args.output_dir, f"preprocessing.log")
|
| 63 |
+
setup_logging(args.logfile)
|
| 64 |
+
logging.info("Starting preprocessing")
|
| 65 |
+
validate_steps(args.steps)
|
| 66 |
+
for step in args.steps:
|
| 67 |
+
func = FUNCTIONS[step]
|
| 68 |
+
args = func(args)
|
pyproject.toml
ADDED
|
File without changes
|
run_cspca.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import time
|
| 5 |
+
import yaml
|
| 6 |
+
import sys
|
| 7 |
+
import gdown
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
import torch.multiprocessing as mp
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from monai.config import KeysCollection
|
| 15 |
+
from monai.metrics import Cumulative, CumulativeAverage
|
| 16 |
+
from monai.networks.nets import milmodel, resnet, MILModel
|
| 17 |
+
|
| 18 |
+
from sklearn.metrics import cohen_kappa_score
|
| 19 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 20 |
+
from torch.utils.data.dataloader import default_collate
|
| 21 |
+
from torchvision.models.resnet import ResNet50_Weights
|
| 22 |
+
import shutil
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 25 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 26 |
+
from monai.utils import set_determinism
|
| 27 |
+
|
| 28 |
+
import matplotlib.pyplot as plt
|
| 29 |
+
|
| 30 |
+
import wandb
|
| 31 |
+
import math
|
| 32 |
+
import logging
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
from src.model.MIL import MILModel_3D
|
| 37 |
+
from src.model.csPCa_model import csPCa_Model
|
| 38 |
+
from src.data.data_loader import get_dataloader
|
| 39 |
+
from src.utils import save_cspca_checkpoint, get_metrics, setup_logging
|
| 40 |
+
from src.train.train_cspca import train_epoch, val_epoch
|
| 41 |
+
|
| 42 |
+
def main_worker(args):
|
| 43 |
+
|
| 44 |
+
mil_model = MILModel_3D(
|
| 45 |
+
num_classes=args.num_classes,
|
| 46 |
+
mil_mode=args.mil_mode
|
| 47 |
+
).to(args.device)
|
| 48 |
+
cache_dir_path = Path(os.path.join(args.logdir, "cache"))
|
| 49 |
+
|
| 50 |
+
if args.mode == 'train':
|
| 51 |
+
|
| 52 |
+
checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
|
| 53 |
+
mil_model.load_state_dict(checkpoint["state_dict"])
|
| 54 |
+
mil_model = mil_model.to(args.device)
|
| 55 |
+
|
| 56 |
+
model_dir = os.path.join(args.project_dir,'models')
|
| 57 |
+
metrics_dict = {'auc':[], 'sensitivity':[], 'specificity':[]}
|
| 58 |
+
for st in list(range(args.num_seeds)):
|
| 59 |
+
set_determinism(seed=st)
|
| 60 |
+
|
| 61 |
+
train_loader = get_dataloader(args, split="train")
|
| 62 |
+
valid_loader = get_dataloader(args, split="test")
|
| 63 |
+
cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
|
| 64 |
+
for submodule in [cspca_model.backbone.net,
|
| 65 |
+
cspca_model.backbone.myfc,
|
| 66 |
+
cspca_model.backbone.transformer]:
|
| 67 |
+
for param in submodule.parameters():
|
| 68 |
+
param.requires_grad = False
|
| 69 |
+
|
| 70 |
+
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, cspca_model.parameters()), lr=args.optim_lr)
|
| 71 |
+
|
| 72 |
+
old_loss = float('inf')
|
| 73 |
+
old_auc = 0.0
|
| 74 |
+
for epoch in range(args.epochs):
|
| 75 |
+
train_loss, train_auc = train_epoch(cspca_model, train_loader, optimizer, epoch=epoch, args=args)
|
| 76 |
+
logging.info(f"STATE {st} EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}")
|
| 77 |
+
val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
|
| 78 |
+
logging.info(f"STATE {st} EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}")
|
| 79 |
+
val_metric['state'] = st
|
| 80 |
+
if val_metric['loss'] < old_loss:
|
| 81 |
+
old_loss = val_metric['loss']
|
| 82 |
+
old_auc = val_metric['auc']
|
| 83 |
+
sensitivity = val_metric['sensitivity']
|
| 84 |
+
specificity = val_metric['specificity']
|
| 85 |
+
if len(metrics_dict['auc']) == 0:
|
| 86 |
+
save_cspca_checkpoint(cspca_model, val_metric, model_dir)
|
| 87 |
+
elif val_metric['auc'] >= max(metrics_dict['auc']):
|
| 88 |
+
save_cspca_checkpoint(cspca_model, val_metric, model_dir)
|
| 89 |
+
|
| 90 |
+
metrics_dict['auc'].append(old_auc)
|
| 91 |
+
metrics_dict['sensitivity'].append(sensitivity)
|
| 92 |
+
metrics_dict['specificity'].append(specificity)
|
| 93 |
+
if cache_dir_path.exists() and cache_dir_path.is_dir():
|
| 94 |
+
shutil.rmtree(cache_dir_path)
|
| 95 |
+
|
| 96 |
+
get_metrics(metrics_dict)
|
| 97 |
+
|
| 98 |
+
elif args.mode == 'test':
|
| 99 |
+
|
| 100 |
+
cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
|
| 101 |
+
checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
|
| 102 |
+
cspca_model.load_state_dict(checkpt['state_dict'])
|
| 103 |
+
cspca_model = cspca_model.to(args.device)
|
| 104 |
+
if 'auc' in checkpt and 'sensitivity' in checkpt and 'specificity' in checkpt:
|
| 105 |
+
auc, sens, spec = checkpt['auc'], checkpt['sensitivity'], checkpt['specificity']
|
| 106 |
+
logging.info(f"csPCa Model loaded from {args.checkpoint_cspca} with AUC: {auc}, Sensitivity: {sens}, Specificity: {spec} on the test set.")
|
| 107 |
+
else:
|
| 108 |
+
logging.info(f"csPCa Model loaded from {args.checkpoint_cspca}.")
|
| 109 |
+
|
| 110 |
+
metrics_dict = {'auc':[], 'sensitivity':[], 'specificity':[]}
|
| 111 |
+
for st in list(range(args.num_seeds)):
|
| 112 |
+
set_determinism(seed=st)
|
| 113 |
+
test_loader = get_dataloader(args, split="test")
|
| 114 |
+
test_metric = val_epoch(cspca_model, test_loader, epoch=0, args=args)
|
| 115 |
+
metrics_dict['auc'].append(test_metric['auc'])
|
| 116 |
+
metrics_dict['sensitivity'].append(test_metric['sensitivity'])
|
| 117 |
+
metrics_dict['specificity'].append(test_metric['specificity'])
|
| 118 |
+
|
| 119 |
+
if cache_dir_path.exists() and cache_dir_path.is_dir():
|
| 120 |
+
shutil.rmtree(cache_dir_path)
|
| 121 |
+
|
| 122 |
+
get_metrics(metrics_dict)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def parse_args():
|
| 128 |
+
parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) for csPCa risk prediction.")
|
| 129 |
+
parser.add_argument('--mode', type=str, choices=['train', 'test'], required=True, help='Operation mode: train or infer')
|
| 130 |
+
parser.add_argument('--run_name', type=str, default='train_cspca', help='run name for log file')
|
| 131 |
+
parser.add_argument('--config', type=str, help='Path to YAML config file')
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--project_dir", default=None, help="path to project firectory"
|
| 134 |
+
)
|
| 135 |
+
parser.add_argument(
|
| 136 |
+
"--data_root", default=None, help="path to root folder of images"
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
|
| 139 |
+
parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
|
| 140 |
+
parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm: choose either att_trans or att_pyramid")
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--tile_count", default=24, type=int, help="number of patches (instances) to extract from MRI input"
|
| 143 |
+
)
|
| 144 |
+
parser.add_argument("--tile_size", default=64, type=int, help="size of square patch (instance) in pixels")
|
| 145 |
+
parser.add_argument("--depth", default=3, type=int, help="number of slices in each 3D patch (instance)")
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--use_heatmap", action="store_true",
|
| 148 |
+
help="enable weak attention heatmap guided patch generation"
|
| 149 |
+
)
|
| 150 |
+
parser.add_argument(
|
| 151 |
+
"--no_heatmap", dest="use_heatmap", action="store_false",
|
| 152 |
+
help="disable heatmap"
|
| 153 |
+
)
|
| 154 |
+
parser.set_defaults(use_heatmap=True)
|
| 155 |
+
parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
|
| 156 |
+
#parser.add_argument("--dry-run", action="store_true")
|
| 157 |
+
parser.add_argument("--checkpoint_pirads", default=None, help="Load PI-RADS model")
|
| 158 |
+
parser.add_argument("--epochs", "--max_epochs", default=30, type=int, help="number of training epochs")
|
| 159 |
+
parser.add_argument("--batch_size", default=32, type=int, help="number of MRI scans per batch")
|
| 160 |
+
parser.add_argument("--optim_lr", default=2e-4, type=float, help="initial learning rate")
|
| 161 |
+
#parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
|
| 162 |
+
parser.add_argument(
|
| 163 |
+
"--val_every",
|
| 164 |
+
"--val_interval",
|
| 165 |
+
default=1,
|
| 166 |
+
type=int,
|
| 167 |
+
help="run validation after this number of epochs, default 1 to run every epoch",
|
| 168 |
+
)
|
| 169 |
+
parser.add_argument("--dry_run", action="store_true", help="Run the script in dry-run mode (default: False)")
|
| 170 |
+
parser.add_argument("--checkpoint_cspca", default=None, help="load existing checkpoint")
|
| 171 |
+
parser.add_argument("--num_seeds", default=20, type=int, help="number of seeds to be run to build CI")
|
| 172 |
+
args = parser.parse_args()
|
| 173 |
+
if args.config:
|
| 174 |
+
with open(args.config, 'r') as config_file:
|
| 175 |
+
config = yaml.safe_load(config_file)
|
| 176 |
+
args.__dict__.update(config)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
return args
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
if __name__ == "__main__":
|
| 185 |
+
args = parse_args()
|
| 186 |
+
args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
|
| 187 |
+
os.makedirs(args.logdir, exist_ok=True)
|
| 188 |
+
args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
|
| 189 |
+
setup_logging(args.logfile)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
logging.info("Argument values:")
|
| 193 |
+
for k, v in vars(args).items():
|
| 194 |
+
logging.info(f"{k} => {v}")
|
| 195 |
+
logging.info("-----------------")
|
| 196 |
+
|
| 197 |
+
if args.dataset_json is None:
|
| 198 |
+
logging.error('Dataset path not provided. Quitting.')
|
| 199 |
+
sys.exit(1)
|
| 200 |
+
if args.checkpoint_pirads is None and args.mode == 'train':
|
| 201 |
+
logging.error('PI-RADS checkpoint path not provided. Quitting.')
|
| 202 |
+
sys.exit(1)
|
| 203 |
+
elif args.checkpoint_cspca is None and args.mode == 'test':
|
| 204 |
+
logging.error('csPCa checkpoint path not provided. Quitting.')
|
| 205 |
+
sys.exit(1)
|
| 206 |
+
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 207 |
+
|
| 208 |
+
if args.device == torch.device("cuda"):
|
| 209 |
+
torch.backends.cudnn.benchmark = True
|
| 210 |
+
|
| 211 |
+
if args.dry_run:
|
| 212 |
+
logging.info("Dry run mode enabled.")
|
| 213 |
+
args.epochs = 2
|
| 214 |
+
args.batch_size = 2
|
| 215 |
+
args.workers = 0
|
| 216 |
+
args.num_seeds = 2
|
| 217 |
+
args.wandb = False
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
main_worker(args)
|
run_pirads.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import collections.abc
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import time
|
| 6 |
+
import yaml
|
| 7 |
+
import sys
|
| 8 |
+
import gdown
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
import torch.multiprocessing as mp
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from monai.config import KeysCollection
|
| 16 |
+
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
|
| 17 |
+
from monai.data.wsi_reader import WSIReader
|
| 18 |
+
from monai.metrics import Cumulative, CumulativeAverage
|
| 19 |
+
from monai.networks.nets import milmodel, resnet, MILModel
|
| 20 |
+
|
| 21 |
+
from sklearn.metrics import cohen_kappa_score
|
| 22 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 23 |
+
from torch.utils.data.dataloader import default_collate
|
| 24 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 25 |
+
from monai.utils import set_determinism
|
| 26 |
+
import matplotlib.pyplot as plt
|
| 27 |
+
|
| 28 |
+
import wandb
|
| 29 |
+
import math
|
| 30 |
+
import logging
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from src.data.data_loader import get_dataloader
|
| 33 |
+
from src.train.train_pirads import train_epoch, val_epoch
|
| 34 |
+
from src.model.MIL import MILModel_3D
|
| 35 |
+
from src.utils import save_pirads_checkpoint, setup_logging
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main_worker(args):
    """Train or evaluate the 3D MIL PI-RADS classifier.

    In ``train`` mode: builds the model, data loaders, optimizer, scheduler and
    AMP scaler, then runs the epoch loop with periodic validation, best-model
    checkpointing and early stopping. In ``test`` mode: evaluates the loaded
    checkpoint over ``args.num_seeds`` seeds and reports the mean QWK.

    Args:
        args: fully-populated namespace from ``parse_args`` plus the fields set
            in ``__main__`` (``device``, ``logdir``, ``num_seeds``, ...).
    """
    if args.device == torch.device("cuda"):
        torch.cuda.set_device(args.gpu)  # use this default device (same as args.device if not distributed)
        torch.backends.cudnn.benchmark = True

    model = MILModel_3D(
        num_classes=args.num_classes,
        mil_mode=args.mil_mode
    )
    start_epoch = 0
    best_acc = 0.0
    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])

        if "epoch" in checkpoint:
            start_epoch = checkpoint["epoch"]
        if "best_acc" in checkpoint:
            best_acc = checkpoint["best_acc"]
        logging.info("=> loaded checkpoint %s (epoch %d) (bestacc %f)", args.checkpoint, start_epoch, best_acc)
    cache_dir_ = os.path.join(args.logdir, "cache")
    model.to(args.device)
    params = model.parameters()
    if args.mode == 'train':
        train_loader = get_dataloader(args, split=args.mode)
        valid_loader = get_dataloader(args, split="test")
        # BUG FIX: logging.info treats its first argument as a %-format string;
        # the original passed extra positional args with no placeholders, which
        # produced "--- Logging error ---" and dropped the dataset sizes.
        logging.info("Dataset training: %d test: %d", len(train_loader.dataset), len(valid_loader.dataset))

        if args.mil_mode in ["att_trans", "att_trans_pyramid"]:
            # Transformer parameters get their own (smaller) LR and a weight decay.
            params = [
                {"params": list(model.attention.parameters()) + list(model.myfc.parameters()) + list(model.net.parameters())},
                {"params": list(model.transformer.parameters()), "lr": 6e-5, "weight_decay": 0.1},
            ]

        optimizer = torch.optim.AdamW(params, lr=args.optim_lr, weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0)
        scaler = torch.amp.GradScaler(device=str(args.device), enabled=args.amp)

        if args.logdir is not None:
            writer = SummaryWriter(log_dir=args.logdir)
            # BUG FIX: placeholder added (see note above on logging formatting).
            logging.info("Writing Tensorboard logs to %s", writer.log_dir)
        else:
            writer = None

        # RUN TRAINING
        n_epochs = args.epochs
        val_loss_min = float("inf")
        epochs_no_improve = 0
        for epoch in range(start_epoch, n_epochs):

            # BUG FIX: placeholders added (time.ctime() was being used as the format string).
            logging.info("%s Epoch: %d", time.ctime(), epoch)
            epoch_time = time.time()
            train_loss, train_acc, train_att_loss, batch_norm = train_epoch(model, train_loader, optimizer, scaler=scaler, epoch=epoch, args=args)
            logging.info(
                "Final training %d/%d loss: %.4f attention loss: %.4f acc: %.4f time %.2fs",
                epoch,
                n_epochs - 1,
                train_loss,
                train_att_loss,
                train_acc,
                time.time() - epoch_time,
            )

            if writer is not None:
                writer.add_scalar("train_loss", train_loss, epoch)
                writer.add_scalar("train_attention_loss", train_att_loss, epoch)
                writer.add_scalar("train_acc", train_acc, epoch)
            wandb.log({"Train Loss": train_loss, "Train Accuracy": train_acc, "Train Attention Loss": train_att_loss, "Batch Norm": batch_norm}, step=epoch)

            model_new_best = False
            val_acc = 0
            if (epoch + 1) % args.val_every == 0:
                epoch_time = time.time()
                val_loss, val_acc, qwk = val_epoch(model, valid_loader, epoch=epoch, args=args)

                logging.info(
                    "Final test %d/%d loss: %.4f acc: %.4f qwk: %.4f time %.2fs",
                    epoch,
                    n_epochs - 1,
                    val_loss,
                    val_acc,
                    qwk,
                    time.time() - epoch_time,
                )
                if writer is not None:
                    writer.add_scalar("test_loss", val_loss, epoch)
                    writer.add_scalar("test_acc", val_acc, epoch)
                    writer.add_scalar("test_qwk", qwk, epoch)

                #val_acc = qwk
                wandb.log({"Test Loss": val_loss, "Test Accuracy": val_acc, "Cohen Kappa": qwk}, step=epoch)
                # Model selection is driven by validation loss, not accuracy/QWK.
                if val_loss < val_loss_min:
                    logging.info("Loss (%.6f --> %.6f)", val_loss_min, val_loss)
                    val_loss_min = val_loss
                    model_new_best = True

                if args.logdir is not None:
                    save_pirads_checkpoint(model, epoch, args, best_acc=val_acc, filename=f"model_{epoch}.pt")
                    if model_new_best:
                        logging.info("Copying to model.pt new best model!!!!")
                        shutil.copyfile(os.path.join(args.logdir, f"model_{epoch}.pt"), os.path.join(args.logdir, "model.pt"))
                        epochs_no_improve = 0

                    else:
                        epochs_no_improve += 1
                        if epochs_no_improve == args.early_stop:
                            logging.info('Early stopping!')
                            break

            scheduler.step()

        logging.info("ALL DONE")

    elif args.mode == 'test':

        kappa_list = []
        for seed in list(range(args.num_seeds)):
            set_determinism(seed=seed)
            valid_loader = get_dataloader(args, split=args.mode)
            # BUG FIX: placeholder added (see note above on logging formatting).
            logging.info("test: %d", len(valid_loader.dataset))
            val_loss, val_acc, qwk = val_epoch(model, valid_loader, epoch=0, args=args)
            kappa_list.append(qwk)
            logging.info(f"Seed {seed}, QWK: {qwk}")
            # Clear the PersistentDataset cache between seeds so each run re-samples.
            if os.path.exists(cache_dir_):
                logging.info("Removing cache directory %s", cache_dir_)
                shutil.rmtree(cache_dir_)

        logging.info(f"Mean QWK over {args.num_seeds} seeds: {np.mean(kappa_list)}")

    # Final cleanup of any leftover cache, regardless of mode.
    if os.path.exists(cache_dir_):
        logging.info("Removing cache directory %s", cache_dir_)
        shutil.rmtree(cache_dir_)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def parse_args():
    """Parse command-line options for PI-RADS MIL training/evaluation.

    Returns:
        argparse.Namespace: parsed options. If ``--config`` is given, the
        YAML file's values are merged in afterwards and override any value
        supplied on the command line.
    """
    parser = argparse.ArgumentParser(description="Multiple Instance Learning (MIL) for PIRADS Classification.")
    # BUG FIX: help said "train or infer" but the actual choices are train/test.
    parser.add_argument('--mode', type=str, choices=['train', 'test'], required=True, help='operation mode: train or test')
    parser.add_argument('--wandb', action='store_true', help='Add this flag to enable WandB logging')
    parser.add_argument('--project_name', type=str, default='Classification_prostate', help='WandB project name')
    parser.add_argument('--run_name', type=str, default='train_pirads', help='run name for WandB logging')
    parser.add_argument('--config', type=str, help='path to YAML config file')
    # BUG FIX: "firectory" typo in the help string.
    parser.add_argument(
        "--project_dir", default=None, help="path to project directory"
    )
    parser.add_argument(
        "--data_root", default=None, help="path to root folder of images"
    )
    parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
    parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
    # BUG FIX: help named a non-existent mode "att_pyramid"; the training code
    # checks for "att_trans" / "att_trans_pyramid".
    parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm: choose either att_trans or att_trans_pyramid")
    parser.add_argument(
        "--tile_count", default=24, type=int, help="number of patches (instances) to extract from MRI input"
    )
    parser.add_argument("--tile_size", default=64, type=int, help="size of square patch (instance) in pixels")
    parser.add_argument("--depth", default=3, type=int, help="number of slices in each 3D patch (instance)")
    # Paired flags sharing one destination: --use_heatmap / --no_heatmap.
    parser.add_argument(
        "--use_heatmap", action="store_true",
        help="enable weak attention heatmap guided patch generation"
    )
    parser.add_argument(
        "--no_heatmap", dest="use_heatmap", action="store_false",
        help="disable heatmap"
    )
    parser.set_defaults(use_heatmap=True)
    parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")

    parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
    parser.add_argument("--epochs", "--max_epochs", default=50, type=int, help="number of training epochs")
    parser.add_argument("--early_stop", default=40, type=int, help="early stopping criteria")
    parser.add_argument("--batch_size", default=4, type=int, help="number of MRI scans per batch")
    parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
    parser.add_argument("--weight_decay", default=0, type=float, help="optimizer weight decay")
    parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
    parser.add_argument(
        "--val_every",
        "--val_interval",
        default=1,
        type=int,
        help="run validation after this number of epochs, default 1 to run every epoch",
    )
    parser.add_argument("--dry_run", action="store_true", help="Run the script in dry-run mode (default: False)")
    args = parser.parse_args()
    if args.config:
        # Values from the YAML config override the parsed CLI values.
        with open(args.config, 'r') as config_file:
            config = yaml.safe_load(config_file)
        args.__dict__.update(config)
    return args
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
if __name__ == "__main__":
    # Script entry point: parse args, set up logging/run directories, configure
    # WandB, then hand off to main_worker for training or evaluation.
    args = parse_args()
    # All run artifacts (logfile, checkpoints, wandb dir) live under
    # <project_dir>/logs/<run_name>.
    args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
    os.makedirs(args.logdir, exist_ok=True)
    args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
    setup_logging(args.logfile)

    # Record the full configuration at the top of the log for reproducibility.
    logging.info("Argument values:")
    for k, v in vars(args).items():
        logging.info(f"{k} => {v}")
    logging.info("-----------------")

    # Number of evaluation repetitions in 'test' mode (not a CLI option).
    args.num_seeds = 10
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.device == torch.device("cpu"):
        # AMP is only meaningful on GPU.
        args.amp = False
    # Fail fast on configurations that cannot possibly run.
    if args.dataset_json is None:
        logging.error('Dataset JSON file not provided. Quitting.')
        sys.exit(1)
    if args.checkpoint is None and args.mode == 'test':
        logging.error('Model checkpoint path not provided. Quitting.')
        sys.exit(1)

    # Dry-run: shrink everything so a full train/test cycle finishes quickly.
    if args.dry_run:
        logging.info("Dry run mode enabled.")
        args.epochs = 2
        args.batch_size = 2
        args.workers = 0
        args.num_seeds = 2
        args.wandb = False

    # wandb.init is always called so wandb.log is safe to call unconditionally
    # in main_worker; 'disabled' makes those calls no-ops.
    mode_wandb = "online" if args.wandb else "disabled"

    config_wandb = {
        "learning_rate": args.optim_lr,
        "batch_size": args.batch_size,
        "epochs": args.epochs,
        "patch size": args.tile_size,
        "patch count": args.tile_count,
    }
    wandb.init(project=args.project_name,
               name=args.run_name,
               dir=os.path.join(args.logdir, "wandb"),
               config=config_wandb,
               mode=mode_wandb)

    main_worker(args)

    wandb.finish()
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (218 Bytes). View file
|
|
|
src/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (3.49 kB). View file
|
|
|
src/data/__init__.py
ADDED
|
File without changes
|
src/data/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (223 Bytes). View file
|
|
|
src/data/__pycache__/custom_transforms.cpython-39.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
src/data/__pycache__/data_loader.cpython-39.pyc
ADDED
|
Binary file (4.24 kB). View file
|
|
|
src/data/custom_transforms.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from typing import Union, Optional
|
| 4 |
+
|
| 5 |
+
from monai.transforms import MapTransform
|
| 6 |
+
from monai.config import DtypeLike, KeysCollection
|
| 7 |
+
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
|
| 8 |
+
from monai.data.meta_obj import get_track_meta
|
| 9 |
+
from monai.transforms.transform import Transform
|
| 10 |
+
from monai.transforms.utils import soft_clip
|
| 11 |
+
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
|
| 12 |
+
from monai.utils.enums import TransformBackends
|
| 13 |
+
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype
|
| 14 |
+
from scipy.ndimage import binary_dilation
|
| 15 |
+
import cv2
|
| 16 |
+
from typing import Union, Sequence
|
| 17 |
+
from collections.abc import Hashable, Mapping, Sequence
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DilateAndSaveMaskd(MapTransform):
    """
    Custom transform to dilate binary mask and save a copy.

    For every key in ``keys``, the original (pre-dilation) mask is stored
    under ``copy_key`` and the entry itself is replaced by its binary
    dilation (``dilation_size`` iterations of scipy's default structuring
    element). Note: with multiple keys, ``copy_key`` holds the copy of the
    last key processed.
    """

    def __init__(self, keys, dilation_size=10, copy_key="original_mask"):
        super().__init__(keys)
        self.dilation_size = dilation_size
        self.copy_key = copy_key

    def __call__(self, data):
        out = dict(data)

        for key in self.keys:
            raw = out[key]
            arr = raw.numpy() if isinstance(raw, torch.Tensor) else raw
            # Drop the leading channel dimension before dilating.
            arr = arr.squeeze(0)

            # Keep an untouched copy of the mask under copy_key.
            out[self.copy_key] = torch.tensor(arr, dtype=torch.float32).unsqueeze(0)

            # Grow the mask, then restore the channel dimension.
            grown = binary_dilation(arr, iterations=self.dilation_size).astype(np.uint8)
            out[key] = torch.tensor(grown, dtype=torch.float32).unsqueeze(0)

        return out
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ClipMaskIntensityPercentiles(Transform):
    """
    Clip image intensities to percentile bounds computed from the masked image.

    Percentile thresholds are computed over ``img * (mask_data > 0)``; the clip
    (hard, or soft when ``sharpness_factor`` is given) is then applied to the
    whole image.

    Args:
        lower: lower percentile in [0, 100], or None.
        upper: upper percentile in [0, 100], or None (lower and upper cannot
            both be None).
        sharpness_factor: if given (> 0), use MONAI's ``soft_clip`` instead of
            a hard clip.
        channel_wise: if True, clip each channel with its own mask channel.
        dtype: output dtype passed to ``soft_clip``.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(
        self,
        lower: Union[float, None],
        upper: Union[float, None],
        sharpness_factor : Union[float, None] = None,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
    ) -> None:
        # Validate the percentile configuration up front so misuse fails fast.
        if lower is None and upper is None:
            raise ValueError("lower or upper percentiles must be provided")
        if lower is not None and (lower < 0.0 or lower > 100.0):
            raise ValueError("Percentiles must be in the range [0, 100]")
        if upper is not None and (upper < 0.0 or upper > 100.0):
            raise ValueError("Percentiles must be in the range [0, 100]")
        if upper is not None and lower is not None and upper < lower:
            raise ValueError("upper must be greater than or equal to lower")
        if sharpness_factor is not None and sharpness_factor <= 0:
            raise ValueError("sharpness_factor must be greater than 0")

        self.lower = lower
        self.upper = upper
        self.sharpness_factor = sharpness_factor
        self.channel_wise = channel_wise
        self.dtype = dtype

    def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
        # NOTE(review): percentiles are computed over img * (mask > 0), so voxels
        # outside the mask contribute zeros to the percentile estimate rather than
        # being excluded — confirm this is intended.
        masked_img = img * (mask_data > 0)
        if self.sharpness_factor is not None:
            # Soft clip: smooth saturation toward the percentile bounds.
            lower_percentile = percentile(masked_img, self.lower) if self.lower is not None else None
            upper_percentile = percentile(masked_img, self.upper) if self.upper is not None else None
            img = soft_clip(img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype)
        else:
            # Hard clip; a missing bound falls back to the 0th/100th percentile.
            lower_percentile = percentile(masked_img, self.lower) if self.lower is not None else percentile(masked_img, 0)
            upper_percentile = percentile(masked_img, self.upper) if self.upper is not None else percentile(masked_img, 100)
            img = clip(img, lower_percentile, upper_percentile)

        img = convert_to_tensor(img, track_meta=False)
        return img

    def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Apply the transform to `img`.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        img_t = convert_to_tensor(img, track_meta=False)
        mask_t = convert_to_tensor(mask_data, track_meta=False)
        if self.channel_wise:
            # Channel-wise: each image channel is clipped with its own mask channel.
            img_t = torch.stack([self._clip(img=d, mask_data=mask_t[e]) for e,d in enumerate(img_t)])  # type: ignore
        else:
            img_t = self._clip(img=img_t, mask_data=mask_t)

        # Convert back to the caller's original container/meta type.
        img = convert_to_dst_type(img_t, dst=img)[0]

        return img
|
| 112 |
+
|
| 113 |
+
class ClipMaskIntensityPercentilesd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`ClipMaskIntensityPercentiles`:
    each selected entry is clipped to percentile bounds computed under the
    mask stored at ``mask_key``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        mask_key: str,
        lower: Union[float, None],
        upper: Union[float, None],
        sharpness_factor: Union[float, None] = None,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mask_key = mask_key
        # Delegate the actual clipping to the array-level transform.
        self.scaler = ClipMaskIntensityPercentiles(
            lower=lower,
            upper=upper,
            sharpness_factor=sharpness_factor,
            channel_wise=channel_wise,
            dtype=dtype,
        )

    def __call__(self, data: dict) -> dict:
        result = dict(data)
        mask = result[self.mask_key]
        for key in self.key_iterator(result):
            result[key] = self.scaler(result[key], mask)
        return result
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class ElementwiseProductd(MapTransform):
    """
    Store the elementwise product of two dictionary entries under a new key.

    Args:
        keys: exactly two keys; their values are multiplied elementwise.
        output_key: key under which the product is stored.

    Raises:
        ValueError: if the number of keys is not exactly two.
    """

    def __init__(self, keys: KeysCollection, output_key: str) -> None:
        super().__init__(keys)
        # Robustness: the product below uses exactly keys[0] and keys[1];
        # any other count was previously a silent misuse or an IndexError.
        if len(self.keys) != 2:
            raise ValueError(f"ElementwiseProductd expects exactly 2 keys, got {len(self.keys)}.")
        self.output_key = output_key

    # BUG FIX: the return annotation claimed NdarrayOrTensor, but a dict is returned.
    def __call__(self, data) -> dict:
        d = dict(data)
        d[self.output_key] = d[self.keys[0]] * d[self.keys[1]]
        return d
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class CLAHEd(MapTransform):
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to images in a data dictionary.
    Works on 2D images or 3D volumes (applied slice-by-slice).

    Args:
        keys (KeysCollection): Keys of the items to be transformed.
        clip_limit (float): Threshold for contrast limiting. Default is 2.0.
        tile_grid_size (Union[tuple, Sequence[int]]): Size of grid for histogram equalization (default: (8,8)).
    """
    def __init__(
        self,
        keys: KeysCollection,
        clip_limit: float = 2.0,
        tile_grid_size: Union[tuple, Sequence[int]] = (8, 8),
    ) -> None:
        super().__init__(keys)
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            image_ = d[key]

            # cv2 works on numpy arrays on CPU; remember the source device.
            image = image_.cpu().numpy()

            # NOTE(review): astype(np.uint8) truncates — this assumes the input is
            # already integer-valued in [0, 255]. A normalized float image in
            # [0, 1] would collapse to zeros here. TODO confirm upstream scaling.
            if image.dtype != np.uint8:
                image = image.astype(np.uint8)

            clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
            # Handle 2D images or process 3D images slice-by-slice.
            # NOTE(review): only the channel-first 3D path is implemented here
            # (image[0] iterated slice-by-slice); a pure 2D input without a
            # channel dimension would be iterated row-by-row instead — verify.
            image_clahe = np.stack([clahe.apply(slice) for slice in image[0]])


            # Convert back to float in [0,1]
            processed_img = image_clahe.astype(np.float32) / 255.0
            # Restore the channel dimension and the original device.
            reshaped_ = processed_img.reshape(1, *processed_img.shape)
            d[key] = torch.from_numpy(reshaped_).to(image_.device)
        return d
|
| 193 |
+
|
| 194 |
+
class NormalizeIntensity_custom(Transform):
    """
    Normalize input as ``(img - subtrahend) / divisor``.

    When `subtrahend`/`divisor` are not provided, the mean/std are computed
    only over voxels inside the mask (``mask_data > 0``), but the resulting
    normalization is applied to the ENTIRE image.
    When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
    be the number of image channels if they are not None.
    If the input is not of floating point type, it will be converted to float32.

    Args:
        subtrahend: the amount to subtract by (usually the mean).
        divisor: the amount to divide by (usually the standard deviation).
        nonzero: retained for interface compatibility; masked statistics are
            always used, so this flag currently has no effect.
        channel_wise: if True, calculate on each channel separately, otherwise, calculate on
            the entire image directly. default to False.
        dtype: output data type, if None, same as input image. defaults to float32.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(
        self,
        subtrahend: Union[Sequence, NdarrayOrTensor, None] = None,
        divisor: Union[Sequence, NdarrayOrTensor, None] = None,
        nonzero: bool = False,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
    ) -> None:
        self.subtrahend = subtrahend
        self.divisor = divisor
        self.nonzero = nonzero
        self.channel_wise = channel_wise
        self.dtype = dtype

    @staticmethod
    def _mean(x):
        # Mean of x; scalar tensors are unwrapped to a Python float.
        if isinstance(x, np.ndarray):
            return np.mean(x)
        x = torch.mean(x.float())
        return x.item() if x.numel() == 1 else x

    @staticmethod
    def _std(x):
        # Population (uncorrected) std of x; scalar tensors unwrapped to float.
        if isinstance(x, np.ndarray):
            return np.std(x)
        x = torch.std(x.float(), unbiased=False)
        return x.item() if x.numel() == 1 else x

    def _normalize(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTensor:
        img, *_ = convert_data_type(img, dtype=torch.float32)
        # Statistics are computed only over voxels inside the mask; the
        # normalization itself is applied to every voxel of the image.
        # (A commented-out `nonzero` branch made `slices` always None in the
        # original; the resulting dead branches have been removed — behavior
        # is unchanged.)
        mask_data = mask_data.squeeze(0)
        slices_mask = mask_data > 0
        masked_img = img[slices_mask]

        _sub = sub if sub is not None else self._mean(masked_img)
        if isinstance(_sub, (torch.Tensor, np.ndarray)):
            _sub, *_ = convert_to_dst_type(_sub, img)

        _div = div if div is not None else self._std(masked_img)
        if np.isscalar(_div):
            if _div == 0.0:
                _div = 1.0  # avoid division by zero for constant regions
        elif isinstance(_div, (torch.Tensor, np.ndarray)):
            _div, *_ = convert_to_dst_type(_div, img)
            _div[_div == 0.0] = 1.0

        img = (img - _sub) / _div
        return img

    def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True,
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        mask_data = convert_to_tensor(mask_data, track_meta=get_track_meta())
        dtype = self.dtype or img.dtype
        if self.channel_wise:
            if self.subtrahend is not None and len(self.subtrahend) != len(img):
                raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.")
            if self.divisor is not None and len(self.divisor) != len(img):
                raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")

            if not img.dtype.is_floating_point:
                img, *_ = convert_data_type(img, dtype=torch.float32)

            for i, d in enumerate(img):
                # NOTE(review): the full mask (not a per-channel slice) is passed
                # for every channel — confirm this is intended.
                img[i] = self._normalize(  # type: ignore
                    d,
                    mask_data,
                    sub=self.subtrahend[i] if self.subtrahend is not None else None,
                    div=self.divisor[i] if self.divisor is not None else None,
                )
        else:
            img = self._normalize(img, mask_data, self.subtrahend, self.divisor)

        out = convert_to_dst_type(img, img, dtype=dtype)[0]
        return out
|
| 310 |
+
|
| 311 |
+
class NormalizeIntensity_customd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`NormalizeIntensity_custom`.
    This transform can normalize only non-zero values or entire image, and can also calculate
    mean and std on each channel separately.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        mask_key: key of the mask entry; statistics are computed under it.
        subtrahend: the amount to subtract by (usually the mean)
        divisor: the amount to divide by (usually the standard deviation)
        nonzero: whether only normalize non-zero values.
        channel_wise: if True, calculate on each channel separately, otherwise, calculate on
            the entire image directly. default to False.
        dtype: output data type, if None, same as input image. defaults to float32.
        allow_missing_keys: don't raise exception if key is missing.
    """

    backend = NormalizeIntensity_custom.backend

    def __init__(
        self,
        keys: KeysCollection,
        mask_key: str,
        subtrahend:Union[ NdarrayOrTensor, None] = None,
        divisor: Union[ NdarrayOrTensor, None] = None,
        nonzero: bool = False,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mask_key = mask_key
        # Delegate the numeric work to the array-level transform.
        self.normalizer = NormalizeIntensity_custom(subtrahend, divisor, nonzero, channel_wise, dtype)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        out = dict(data)
        mask = out[self.mask_key]
        for key in self.key_iterator(out):
            out[key] = self.normalizer(out[key], mask)
        return out
|
src/data/data_loader.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NOTE: `ScaleIntensityd` was previously imported twice; the duplicate is removed
# and the imports are grouped stdlib / third-party / local per PEP 8.
import argparse
import collections.abc
import os
from typing import Literal

import matplotlib.pyplot as plt
import monai
import numpy as np
import torch
from monai.config import KeysCollection
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
from monai.transforms import (
    Compose,
    LoadImaged,
    MapTransform,
    ScaleIntensityRanged,
    SplitDimd,
    ToTensord,
    ConcatItemsd,
    SelectItemsd,
    EnsureChannelFirstd,
    RepeatChanneld,
    DeleteItemsd,
    EnsureTyped,
    ClipIntensityPercentilesd,
    MaskIntensityd,
    RandCropByPosNegLabeld,
    NormalizeIntensityd,
    SqueezeDimd,
    ScaleIntensityd,
    Transposed,
    RandWeightedCropd,
)
from torch.utils.data.dataloader import default_collate

from .custom_transforms import (
    NormalizeIntensity_customd,
    ClipMaskIntensityPercentilesd,
    ElementwiseProductd,
)
| 40 |
+
|
| 41 |
+
def list_data_collate(batch: collections.abc.Sequence):
    """
    Combine instances from a list of dicts into a single dict, by stacking them along
    the first dim: ``[{'image': CxDxHxW}, ...]`` (N per-case patch dicts) becomes
    ``{'image': NxCxDxHxW}``, followed by the default collate which forms a batch
    ``BxNxCxDxHxW``.

    BUGFIX: the original wrote the merged dict back with ``batch[i] = data``,
    mutating the caller's sequence in place and crashing with a ``TypeError`` when
    ``batch`` is an immutable ``Sequence`` (e.g. a tuple), despite the type hint.
    A fresh list is built instead.
    """
    merged = []
    for item in batch:
        # `item` is the list of patch dicts produced for one case; reuse the
        # first dict as the carrier for the stacked tensors.
        data = item[0]
        data["image"] = torch.stack([ix["image"] for ix in item], dim=0)
        # `final_heatmap` exists only when the heatmap-guided pipeline is used.
        if all("final_heatmap" in ix for ix in item):
            data["final_heatmap"] = torch.stack([ix["final_heatmap"] for ix in item], dim=0)
        merged.append(data)
    return default_collate(merged)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def data_transform(args):
    """
    Build the MONAI preprocessing pipeline for one sample dict.

    Loads the T2 volume ("image"), prostate mask, DWI and ADC volumes (plus the
    sampling heatmap when ``args.use_heatmap`` is set), clips and normalizes
    intensities inside the prostate, stacks the three modalities channel-wise,
    and extracts ``args.tile_count`` random 3D patches of size
    ``(tile_size, tile_size, depth)`` per case.
    """
    patch_size = (args.tile_size, args.tile_size, args.depth)

    if args.use_heatmap:
        steps = [
            LoadImaged(
                keys=["image", "mask", "dwi", "adc", "heatmap"],
                reader=ITKReader(),
                ensure_channel_first=True,
                dtype=np.float32,
            ),
            ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
            # Stack T2/DWI/ADC into a single 3-channel volume under "image".
            ConcatItemsd(keys=["image", "dwi", "adc"], name="image", dim=0),
            NormalizeIntensity_customd(keys=["image"], channel_wise=True, mask_key="mask"),
            # Restrict the sampling heatmap to voxels inside the prostate mask.
            ElementwiseProductd(keys=["mask", "heatmap"], output_key="final_heatmap"),
            RandWeightedCropd(
                keys=["image", "final_heatmap"],
                w_key="final_heatmap",
                spatial_size=patch_size,
                num_samples=args.tile_count,
            ),
            EnsureTyped(keys=["label"], dtype=torch.float32),
            # Move axis 3 in front of the spatial axes: (0, 1, 2, 3) -> (0, 3, 1, 2).
            Transposed(keys=["image"], indices=(0, 3, 1, 2)),
            DeleteItemsd(keys=["mask", "dwi", "adc", "heatmap"]),
            ToTensord(keys=["image", "label", "final_heatmap"]),
        ]
    else:
        steps = [
            LoadImaged(
                keys=["image", "mask", "dwi", "adc"],
                reader=ITKReader(),
                ensure_channel_first=True,
                dtype=np.float32,
            ),
            ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
            ConcatItemsd(keys=["image", "dwi", "adc"], name="image", dim=0),
            NormalizeIntensityd(keys=["image"], channel_wise=True),
            # pos=1, neg=0: sample patches centered on mask voxels only.
            RandCropByPosNegLabeld(
                keys=["image"],
                label_key="mask",
                spatial_size=patch_size,
                pos=1,
                neg=0,
                num_samples=args.tile_count,
            ),
            EnsureTyped(keys=["label"], dtype=torch.float32),
            Transposed(keys=["image"], indices=(0, 3, 1, 2)),
            DeleteItemsd(keys=["mask", "dwi", "adc"]),
            ToTensord(keys=["image", "label"]),
        ]
    return Compose(steps)
|
| 100 |
+
|
| 101 |
+
def get_dataloader(args, split: Literal["train", "test"]):
    """
    Create a DataLoader over the decathlon-style datalist for the given split.

    Uses a ``PersistentDataset`` so preprocessed samples are cached on disk under
    ``args.logdir/cache/<split>``. Shuffling is enabled only for "train".
    """
    records = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key=split,
        base_dir=args.data_root,
    )
    if args.dry_run:
        records = records[:8]  # Use only 8 samples for dry run

    split_cache = os.path.join(args.logdir, "cache", split)
    os.makedirs(split_cache, exist_ok=True)

    dataset = PersistentDataset(data=records, transform=data_transform(args), cache_dir=split_cache)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=(split == "train"),
        num_workers=args.workers,
        pin_memory=True,
        # "spawn" avoids fork-related issues with ITK/torch worker processes.
        multiprocessing_context="spawn" if args.workers > 0 else None,
        sampler=None,
        collate_fn=list_data_collate,
    )
|
| 125 |
+
|
src/model/MIL.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import cast
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from monai.utils.module import optional_import
|
| 10 |
+
from monai.networks.nets import resnet
|
| 11 |
+
models, _ = optional_import("torchvision.models")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MILModel_3D(nn.Module):
    """
    Multiple Instance Learning (MIL) model, with a backbone classification model.

    Adapted from MONAI's ``MILModel`` and modified for 3D images. The expected shape
    of input data is ``[B, N, C, D, H, W]``, where ``B`` is the batch size of the
    PyTorch DataLoader and ``N`` is the number of instances extracted from every
    original image in the batch. A tutorial example is available at:
    https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning.

    Args:
        num_classes: number of output classes.
        mil_mode: MIL algorithm, available values (defaults to ``"att"``):

            - ``"mean"`` - average features from all instances, equivalent to pure CNN (non MIL).
            - ``"max"`` - retain only the instance with the max probability for loss calculation.
            - ``"att"`` - attention based MIL https://arxiv.org/abs/1802.04712.
            - ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556.
            - ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556.

        pretrained: init backbone with pretrained weights. NOTE(review): only honoured
            for torchvision (string) backbones; the default MONAI 3D ResNet18 is always
            randomly initialized here.
        backbone: backbone classifier CNN (either ``None``, an ``nn.Module`` that
            returns features, or a string name of a torchvision model).
            Defaults to ``None``, in which case a 3D ResNet18 is used.
        backbone_num_features: number of output features of the backbone CNN.
            Defaults to ``None`` (necessary only when using a custom backbone).
        trans_blocks: number of the blocks in the ``TransformerEncoder`` layer.
        trans_dropout: dropout rate in the ``TransformerEncoder`` layer.
    """

    def __init__(
        self,
        num_classes: int,
        mil_mode: str = "att",
        pretrained: bool = True,
        backbone: str | nn.Module | None = None,
        backbone_num_features: int | None = None,
        trans_blocks: int = 4,
        trans_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        if num_classes <= 0:
            raise ValueError("Number of classes must be positive: " + str(num_classes))

        if mil_mode.lower() not in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]:
            raise ValueError("Unsupported mil_mode: " + str(mil_mode))

        self.mil_mode = mil_mode.lower()
        self.attention = nn.Sequential()
        self.transformer: nn.Module | None = None
        # Intermediate backbone layer outputs captured by forward hooks;
        # only populated/used in "att_trans_pyramid" mode.
        self.extra_outputs: dict[str, torch.Tensor] = {}

        if backbone is None:
            # Default backbone: MONAI 3D ResNet18. Its classification head is
            # discarded below, so the `num_classes` given here is irrelevant.
            net = resnet.resnet18(spatial_dims=3, n_input_channels=3, num_classes=5)
            nfc = net.fc.in_features  # save the number of final features
            net.fc = torch.nn.Identity()  # remove final linear layer
            if self.mil_mode == "att_trans_pyramid":
                self._register_pyramid_hooks(net)

        elif isinstance(backbone, str):
            # Assume a torchvision model name was provided.
            torch_model = getattr(models, backbone, None)
            if torch_model is None:
                raise ValueError("Unknown torch vision model" + str(backbone))
            net = torch_model(pretrained=pretrained)

            if getattr(net, "fc", None) is None:
                # BUGFIX: the original passed two arguments to ValueError, which
                # rendered the message as a tuple; build a single string instead.
                raise ValueError(
                    "Unable to detect FC layer for the torchvision model "
                    + str(backbone)
                    + ". Please initialize the backbone model manually."
                )
            nfc = net.fc.in_features  # save the number of final features
            net.fc = torch.nn.Identity()  # remove final linear layer

        elif isinstance(backbone, nn.Module):
            # Use a custom backbone. BUGFIX: validate `backbone_num_features`
            # *before* mutating the caller's module (original checked it last).
            if backbone_num_features is None:
                raise ValueError("Number of encoder features must be provided for a custom backbone model")
            net = backbone
            nfc = backbone_num_features
            net.fc = torch.nn.Identity()  # replace/remove any final linear layer
            if self.mil_mode == "att_trans_pyramid":
                self._register_pyramid_hooks(net)

        else:
            raise ValueError("Unsupported backbone")

        if backbone is not None and self.mil_mode not in ["mean", "max", "att", "att_trans"]:
            raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode))

        if self.mil_mode in ["mean", "max"]:
            pass

        elif self.mil_mode == "att":
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        elif self.mil_mode == "att_trans":
            transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout)
            self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks)
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        elif self.mil_mode == "att_trans_pyramid":
            # One transformer per captured backbone stage; later stages are
            # concatenated with the previous output and projected back to 64
            # features before the next encoder (dims 192/320/576 follow the
            # ResNet18 stage widths 64/128/256/512).
            transformer_list = nn.ModuleList(
                [
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout),
                        num_layers=trans_blocks,
                    ),
                    nn.Sequential(
                        nn.Linear(192, 64),
                        nn.TransformerEncoder(
                            nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout),
                            num_layers=trans_blocks,
                        ),
                    ),
                    nn.Sequential(
                        nn.Linear(320, 64),
                        nn.TransformerEncoder(
                            nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout),
                            num_layers=trans_blocks,
                        ),
                    ),
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=576, nhead=8, dropout=trans_dropout),
                        num_layers=trans_blocks,
                    ),
                ]
            )
            self.transformer = transformer_list
            nfc = nfc + 64  # pyramid output concatenates a 64-dim stream
            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))

        else:
            raise ValueError("Unsupported mil_mode: " + str(mil_mode))

        self.myfc = nn.Linear(nfc, num_classes)
        self.net = net

    def _register_pyramid_hooks(self, net: nn.Module) -> None:
        """Register forward hooks on ``layer1``..``layer4`` of `net` that store each
        layer's output in ``self.extra_outputs`` (needed by "att_trans_pyramid").

        This replaces two identical copies of the hook-registration code that the
        original carried in the ``None``-backbone and custom-backbone branches.
        """

        def forward_hook(layer_name):
            def hook(module, inputs, output):
                self.extra_outputs[layer_name] = output

            return hook

        for layer_name in ("layer1", "layer2", "layer3", "layer4"):
            getattr(net, layer_name).register_forward_hook(forward_hook(layer_name))

    def calc_head(self, x: torch.Tensor) -> torch.Tensor:
        """Aggregate per-instance features ``x`` of shape ``[B, N, F]`` into
        per-image logits of shape ``[B, num_classes]`` according to ``mil_mode``."""
        sh = x.shape

        if self.mil_mode == "mean":
            x = self.myfc(x)
            x = torch.mean(x, dim=1)

        elif self.mil_mode == "max":
            x = self.myfc(x)
            x, _ = torch.max(x, dim=1)

        elif self.mil_mode == "att":
            a = self.attention(x)
            a = torch.softmax(a, dim=1)  # attention weights over the N instances
            x = torch.sum(x * a, dim=1)
            x = self.myfc(x)

        elif self.mil_mode == "att_trans" and self.transformer is not None:
            x = x.permute(1, 0, 2)  # TransformerEncoder expects [N, B, F]
            x = self.transformer(x)
            x = x.permute(1, 0, 2)

            a = self.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)
            x = self.myfc(x)

        elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None:
            # Spatially average each hooked stage output, then reshape to [N, B, F].
            l1 = torch.mean(self.extra_outputs["layer1"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l2 = torch.mean(self.extra_outputs["layer2"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l3 = torch.mean(self.extra_outputs["layer3"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)
            l4 = torch.mean(self.extra_outputs["layer4"], dim=(2, 3, 4)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)

            transformer_list = cast(nn.ModuleList, self.transformer)

            x = transformer_list[0](l1)
            x = transformer_list[1](torch.cat((x, l2), dim=2))
            x = transformer_list[2](torch.cat((x, l3), dim=2))
            x = transformer_list[3](torch.cat((x, l4), dim=2))

            x = x.permute(1, 0, 2)

            a = self.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)
            x = self.myfc(x)

        else:
            raise ValueError("Wrong model mode" + str(self.mil_mode))

        return x

    def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor:
        """Run the backbone on every instance and (unless ``no_head``) aggregate.

        Args:
            x: input of shape ``[B, N, C, D, H, W]``.
            no_head: if True, return per-instance features ``[B, N, F]`` instead
                of aggregated logits.
        """
        sh = x.shape
        # Fold the instance axis into the batch axis for the backbone pass.
        x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5])

        x = self.net(x)
        x = x.reshape(sh[0], sh[1], -1)

        if not no_head:
            x = self.calc_head(x)

        return x
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
|
src/model/__init__.py
ADDED
|
File without changes
|
src/model/__pycache__/MIL.cpython-39.pyc
ADDED
|
Binary file (6.85 kB). View file
|
|
|
src/model/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (224 Bytes). View file
|
|
|
src/model/__pycache__/csPCa_model.cpython-39.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
src/model/csPCa_model.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import cast
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from monai.utils.module import optional_import
|
| 10 |
+
models, _ = optional_import("torchvision.models")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SimpleNN(nn.Module):
    """Small MLP head for binary classification.

    Maps ``input_dim`` features through two hidden layers (256 -> 128, ReLU,
    dropout 0.3) to a single sigmoid probability in (0, 1).
    """

    def __init__(self, input_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # since binary classification
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
| 28 |
+
|
| 29 |
+
class csPCa_Model(nn.Module):
    """Binary csPCa classifier built on top of a trained MIL backbone.

    Reuses the backbone's feature extractor (``net``), ``transformer`` and
    ``attention`` modules, then feeds the attention-pooled feature vector into a
    fresh ``SimpleNN`` head. Input shape: ``[B, N, C, D, H, W]``; output:
    ``[B, 1]`` sigmoid probabilities.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        # Feature width is taken from the backbone's (unused) final classifier.
        self.fc_dim = backbone.myfc.in_features
        self.fc_cspca = SimpleNN(input_dim=self.fc_dim)

    def forward(self, x):
        sh = x.shape
        # Fold instances into the batch axis for the per-patch feature extractor.
        feats = self.backbone.net(x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5]))
        feats = feats.reshape(sh[0], sh[1], -1)

        # Transformer operates on [N, B, F]; restore [B, N, F] afterwards.
        seq = self.backbone.transformer(feats.permute(1, 0, 2)).permute(1, 0, 2)

        # Attention-weighted pooling over the instance axis.
        weights = torch.softmax(self.backbone.attention(seq), dim=1)
        pooled = torch.sum(seq * weights, dim=1)

        return self.fc_cspca(pooled)
|
| 50 |
+
|
src/preprocessing/__init__.py
ADDED
|
File without changes
|
src/preprocessing/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (232 Bytes). View file
|
|
|
src/preprocessing/__pycache__/center_crop.cpython-39.pyc
ADDED
|
Binary file (2.73 kB). View file
|
|
|
src/preprocessing/__pycache__/generate_heatmap.cpython-39.pyc
ADDED
|
Binary file (1.49 kB). View file
|
|
|
src/preprocessing/__pycache__/histogram_match.cpython-39.pyc
ADDED
|
Binary file (1.99 kB). View file
|
|
|
src/preprocessing/__pycache__/prostate_mask.cpython-39.pyc
ADDED
|
Binary file (3.86 kB). View file
|
|
|
src/preprocessing/__pycache__/register_and_crop.cpython-39.pyc
ADDED
|
Binary file (2.13 kB). View file
|
|
|
src/preprocessing/center_crop.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 - 2022 MONAI Consortium
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
#python scripts/center_crop.py --file_name path/to/t2_image --out_name cropped_t2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
#import argparse
|
| 15 |
+
from typing import Union
|
| 16 |
+
|
| 17 |
+
import SimpleITK as sitk # noqa N813
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _flatten(t):
|
| 21 |
+
return [item for sublist in t for item in sublist]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def crop(image: sitk.Image, margin: Union[int, float], interpolator=sitk.sitkLinear):
    """
    Crops a sitk.Image while retaining correct spacing. Negative margins will lead
    to zero padding.

    Args:
        image: a sitk.Image.
        margin: margins to crop. A single integer (voxels) or float (percentage
            crop), a list of three int/float values, or nested [low, high] pairs
            per axis.
        interpolator: SimpleITK interpolator used for the final resample.
    """
    if isinstance(margin, (list, tuple)):
        assert len(margin) == 3, "expected margin to be of length 3"
    else:
        assert isinstance(margin, (int, float)), "expected margin to be a float value"
        margin = [margin] * 3

    # Normalize to one [low, high] pair per axis.
    margin = [m if isinstance(m, (tuple, list)) else [m, m] for m in margin]
    old_size = image.GetSize()

    # Compute voxels to crop per side: floats are fractions of the axis size,
    # ints are used verbatim.
    flat_margins = _flatten(margin)
    if all(isinstance(m, float) for m in flat_margins):
        assert all(m >= 0 and m < 0.5 for m in flat_margins), "margins must be between 0 and 0.5"
        to_crop = [[int(sz * _m) for _m in m] for sz, m in zip(old_size, margin)]
    elif all(isinstance(m, int) for m in flat_margins):
        to_crop = margin
    else:
        raise ValueError("Wrong format of margins.")

    new_size = [sz - sum(c) for sz, c in zip(old_size, to_crop)]

    # The original origin has index (0, 0, 0); the new origin sits at index
    # (to_crop[0][0], to_crop[1][0], to_crop[2][0]) of the old image.
    new_origin = image.TransformIndexToPhysicalPoint([c[0] for c in to_crop])

    # Reference geometry defining the cropped extent for resampling.
    ref_image = sitk.Image(new_size, image.GetPixelIDValue())
    ref_image.SetSpacing(image.GetSpacing())
    ref_image.SetOrigin(new_origin)
    ref_image.SetDirection(image.GetDirection())

    return sitk.Resample(image, ref_image, interpolator=interpolator)
|
| 64 |
+
|
src/preprocessing/generate_heatmap.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import nrrd
|
| 5 |
+
import json
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import json
|
| 8 |
+
import SimpleITK as sitk
|
| 9 |
+
import multiprocessing
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_heatmap(args):
    """
    Build a per-voxel sampling heatmap for every case in ``args.t2_dir``.

    For each case the DWI volume is min-max normalized inside the prostate mask
    (bright DWI = hot) and the ADC volume is inversely normalized (dark ADC = hot);
    the two are multiplied. If only one modality yields a usable (non-degenerate)
    map, that one is used alone; if neither does, a uniform heatmap is written.
    Results are saved as NRRD files under ``args.output_dir/heatmaps`` and
    ``args.heatmapdir`` is set on ``args``.

    BUGFIX: the original's trailing ``else`` attached to the last ``if``, so any
    case with one degenerate modality had its fallback heatmap immediately
    overwritten by an all-ones array, after which the unconditional min-max
    rescale divided by zero and produced an all-NaN heatmap. The modality
    combination is now an explicit if/elif chain and the rescale is guarded.
    """
    files = os.listdir(args.t2_dir)
    args.heatmapdir = os.path.join(args.output_dir, 'heatmaps/')
    os.makedirs(args.heatmapdir, exist_ok=True)

    for file in files:
        mask, _ = nrrd.read(os.path.join(args.seg_dir, file))
        dwi, _ = nrrd.read(os.path.join(args.dwi_dir, file))
        adc, _ = nrrd.read(os.path.join(args.adc_dir, file))

        # High DWI signal is suspicious -> direct normalization;
        # low ADC signal is suspicious -> inverted normalization.
        heatmap_dwi = _masked_minmax_heatmap(dwi, mask, invert=False)
        heatmap_adc = _masked_minmax_heatmap(adc, mask, invert=True)

        if heatmap_dwi is not None and heatmap_adc is not None:
            mix_mask = heatmap_dwi * heatmap_adc
        elif heatmap_adc is not None:
            mix_mask = heatmap_adc
        elif heatmap_dwi is not None:
            mix_mask = heatmap_dwi
        else:
            # Neither modality usable: fall back to uniform sampling weights.
            mix_mask = np.ones_like(adc, dtype=np.float32)

        # Rescale to [0, 1]; skip for constant maps to avoid division by zero.
        value_range = mix_mask.max() - mix_mask.min()
        if value_range > 0:
            mix_mask = (mix_mask - mix_mask.min()) / value_range

        nrrd.write(os.path.join(args.heatmapdir, file), mix_mask)

    return args


def _masked_minmax_heatmap(image, mask, invert=False):
    """Min-max normalize ``image`` over the voxels where ``mask > 0``.

    Voxels outside the mask are set to the minimum in-mask heat value.
    Returns ``None`` when the mask is empty or the in-mask intensities are
    constant (no usable contrast).
    """
    in_mask = image[mask > 0]
    if in_mask.size == 0:
        return None
    lo, hi = in_mask.min(), in_mask.max()
    if lo == hi:
        return None
    if invert:
        heat = (hi - image) / (hi - lo)
    else:
        heat = (image - lo) / (hi - lo)
    return np.where(mask > 0, heat, heat[mask > 0].min()).astype(np.float32)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
src/preprocessing/histogram_match.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import SimpleITK as sitk
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import nrrd
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import random
|
| 8 |
+
import json
|
| 9 |
+
from skimage import exposure
|
| 10 |
+
import multiprocessing
|
| 11 |
+
import logging
|
| 12 |
+
def get_histmatched(data, ref_data, mask, ref_mask):
    """Histogram-match ``data`` to ``ref_data`` using only masked voxels.

    Only intensities where ``mask > 0`` are matched against the reference
    intensities where ``ref_mask > 0``; everything outside the mask is left
    untouched.
    """
    matched_img = data.copy()
    inside = mask > 0
    matched_img[inside] = exposure.match_histograms(data[inside], ref_data[ref_mask > 0])
    return matched_img
|
| 21 |
+
|
| 22 |
+
def histmatch(args):
    """
    Histogram-match every case's T2/DWI/ADC volume (within the prostate mask)
    to the fixed reference volumes shipped in ``args.project_dir/dataset``.

    Matched volumes are written to ``<output_dir>/{t2,DWI,ADC}_histmatched`` with
    their original NRRD headers, and ``args.t2_dir/dwi_dir/adc_dir`` are updated
    to point at the matched directories.

    PERF: the four reference volumes are loop-invariant; the original re-read
    them from disk for every case, which is now hoisted out of the loop.
    """
    files = os.listdir(args.t2_dir)

    t2_histmatched_dir = os.path.join(args.output_dir, 't2_histmatched')
    dwi_histmatched_dir = os.path.join(args.output_dir, 'DWI_histmatched')
    adc_histmatched_dir = os.path.join(args.output_dir, 'ADC_histmatched')
    for out_dir in (t2_histmatched_dir, dwi_histmatched_dir, adc_histmatched_dir):
        os.makedirs(out_dir, exist_ok=True)

    # References are identical for every case: load them once.
    ref_dir = os.path.join(args.project_dir, 'dataset')
    ref_t2, _ = nrrd.read(os.path.join(ref_dir, 't2_reference.nrrd'))
    ref_dwi, _ = nrrd.read(os.path.join(ref_dir, 'dwi_reference.nrrd'))
    ref_adc, _ = nrrd.read(os.path.join(ref_dir, 'adc_reference.nrrd'))
    ref_prostate_mask, _ = nrrd.read(os.path.join(ref_dir, 'prostate_segmentation_reference.nrrd'))

    logging.info("Starting histogram matching")
    for file in files:
        t2_image, header_t2 = nrrd.read(os.path.join(args.t2_dir, file))
        dwi_image, header_dwi = nrrd.read(os.path.join(args.dwi_dir, file))
        adc_image, header_adc = nrrd.read(os.path.join(args.adc_dir, file))
        prostate_mask, _ = nrrd.read(os.path.join(args.seg_dir, file))

        histmatched_t2 = get_histmatched(t2_image, ref_t2, prostate_mask, ref_prostate_mask)
        histmatched_dwi = get_histmatched(dwi_image, ref_dwi, prostate_mask, ref_prostate_mask)
        histmatched_adc = get_histmatched(adc_image, ref_adc, prostate_mask, ref_prostate_mask)

        nrrd.write(os.path.join(t2_histmatched_dir, file), histmatched_t2, header_t2)
        nrrd.write(os.path.join(dwi_histmatched_dir, file), histmatched_dwi, header_dwi)
        nrrd.write(os.path.join(adc_histmatched_dir, file), histmatched_adc, header_adc)

    # Downstream preprocessing stages read from the matched directories.
    args.t2_dir = t2_histmatched_dir
    args.dwi_dir = dwi_histmatched_dir
    args.adc_dir = adc_histmatched_dir

    return args
|
| 61 |
+
|
| 62 |
+
|
src/preprocessing/prostate_mask.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Union
|
| 3 |
+
import SimpleITK as sitk
|
| 4 |
+
import numpy as np
|
| 5 |
+
import nrrd
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from AIAH_utility.viewer import BasicViewer, ListViewer
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import monai
|
| 11 |
+
from monai.bundle import ConfigParser
|
| 12 |
+
from monai.config import print_config
|
| 13 |
+
import torch
|
| 14 |
+
import sys
|
| 15 |
+
import os
|
| 16 |
+
import nibabel as nib
|
| 17 |
+
import shutil
|
| 18 |
+
|
| 19 |
+
from tqdm import trange, tqdm
|
| 20 |
+
|
| 21 |
+
from monai.data import DataLoader, Dataset, TestTimeAugmentation, create_test_image_2d
|
| 22 |
+
from monai.losses import DiceLoss
|
| 23 |
+
from monai.metrics import DiceMetric
|
| 24 |
+
from monai.networks.nets import UNet
|
| 25 |
+
from monai.transforms import (
|
| 26 |
+
Activationsd,
|
| 27 |
+
AsDiscreted,
|
| 28 |
+
Compose,
|
| 29 |
+
CropForegroundd,
|
| 30 |
+
DivisiblePadd,
|
| 31 |
+
Invertd,
|
| 32 |
+
LoadImaged,
|
| 33 |
+
ScaleIntensityd,
|
| 34 |
+
RandRotated,
|
| 35 |
+
RandRotate,
|
| 36 |
+
InvertibleTransform,
|
| 37 |
+
RandFlipd,
|
| 38 |
+
Activations,
|
| 39 |
+
AsDiscrete,
|
| 40 |
+
NormalizeIntensityd,
|
| 41 |
+
)
|
| 42 |
+
from monai.utils import set_determinism
|
| 43 |
+
from monai.transforms import (
|
| 44 |
+
Resize,
|
| 45 |
+
EnsureChannelFirstd,
|
| 46 |
+
Orientationd,
|
| 47 |
+
Spacingd,
|
| 48 |
+
EnsureTyped,
|
| 49 |
+
)
|
| 50 |
+
import nrrd
|
| 51 |
+
|
| 52 |
+
set_determinism(43)
|
| 53 |
+
from monai.data import MetaTensor
|
| 54 |
+
import SimpleITK as sitk
|
| 55 |
+
import pandas as pd
|
| 56 |
+
import logging
|
| 57 |
+
def get_segmask(args):
    """Segment the prostate on every T2 volume in ``args.t2_dir``.

    Loads the MONAI bundle described by ``config/inference.json`` together
    with the bundled checkpoint, runs inference on each T2 image, inverts the
    preprocessing transforms so the prediction is back on the original image
    grid, keeps only the 10 axial slices with the largest mask area, and
    writes one NRRD mask per input file to ``<output_dir>/prostate_mask``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``output_dir``, ``project_dir`` and ``t2_dir``.

    Returns
    -------
    argparse.Namespace
        The same ``args`` object with ``args.seg_dir`` set to the directory
        that now contains the masks, so downstream steps can find them.
    """
    args.seg_dir = os.path.join(args.output_dir, "prostate_mask")
    os.makedirs(args.seg_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The bundle config describes the network and the sliding-window inferer.
    model_config_file = os.path.join(args.project_dir, "config", "inference.json")
    model_config = ConfigParser()
    model_config.read_config(model_config_file)
    model_config["output_dir"] = args.seg_dir
    model_config["dataset_dir"] = args.t2_dir
    files = os.listdir(args.t2_dir)
    model_config["datalist"] = [os.path.join(args.t2_dir, f) for f in files]

    checkpoint = os.path.join(
        args.project_dir,
        "models",
        "prostate_segmentation_model.pt",
    )
    # Only the network and the inferer are needed here; the bundle's own
    # preprocessing/postprocessing/dataloader components are not used because
    # a custom, invertible transform pipeline is built below.
    model = model_config.get_parsed_content("network_def").to(device)
    inferer = model_config.get_parsed_content("inferer")
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()

    # Set the allocator hint before inference allocates large CUDA buffers.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    torch.cuda.empty_cache()

    keys = "image"
    # Invertible pipeline: Orientationd/Spacingd are undone after inference so
    # the mask is saved in the original image geometry.
    transform = Compose(
        [
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
            Orientationd(keys=keys, axcodes="RAS"),
            Spacingd(keys=keys, pixdim=[0.5, 0.5, 0.5], mode="bilinear"),
            ScaleIntensityd(keys=keys, minv=0, maxv=1),
            NormalizeIntensityd(keys=keys),
            EnsureTyped(keys=keys),
        ]
    )

    logging.info("Starting prostate segmentation")
    for file in tqdm(files):
        data = {"image": os.path.join(args.t2_dir, file)}
        transformed_data = transform(data)

        with torch.no_grad():
            # Add the batch dimension expected by the inferer.
            images = transformed_data["image"].reshape(
                1, *(transformed_data["image"].shape)
            ).to(device)
            pred = inferer(images, network=model)
        pred_img = pred.argmax(1).cpu()  # collapse class channel to a label map

        # Re-attach the metadata of the preprocessed image so the pipeline's
        # inverse() can map the prediction back onto the original grid.
        transformed_data["image"].data = MetaTensor(
            pred_img, meta=transformed_data["image"].meta
        ).data
        inverted = transform.inverse(transformed_data)
        pred_nrrd = np.round(inverted["image"][0].numpy())

        # Keep only the 10 axial slices with the largest mask area; this
        # suppresses spurious detections far from the gland.
        nonzero_counts = np.count_nonzero(pred_nrrd, axis=(0, 1))
        top_slices = np.argsort(nonzero_counts)[-10:]
        output_ = np.zeros_like(pred_nrrd)
        output_[:, :, top_slices] = pred_nrrd[:, :, top_slices]

        # NOTE(review): no NRRD header is written here, so spacing/origin
        # metadata is lost in the saved mask — confirm downstream readers rely
        # only on voxel indices.
        nrrd.write(os.path.join(args.seg_dir, file), output_)

    return args
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
|
src/preprocessing/register_and_crop.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import SimpleITK as sitk
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import nrrd
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from picai_prep.preprocessing import PreprocessingSettings, Sample
|
| 8 |
+
import multiprocessing
|
| 9 |
+
from .center_crop import crop
|
| 10 |
+
import logging
|
| 11 |
+
def register_files(args):
    """Co-register DWI and ADC onto the T2 grid and centre-crop all three.

    For every file name present in ``args.t2_dir`` (the same name is expected
    in ``args.dwi_dir`` and ``args.adc_dir``), the three scans are resampled
    to a common 0.4 x 0.4 x 3.0 mm grid via ``picai_prep``'s ``Sample``
    preprocessing, then centre-cropped with ``args.margin`` in-plane, and
    written to ``t2_registered`` / ``DWI_registered`` / ``ADC_registered``
    under ``args.output_dir``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``t2_dir``, ``dwi_dir``, ``adc_dir``, ``output_dir``
        and ``margin``.

    Returns
    -------
    argparse.Namespace
        The same ``args`` object with ``t2_dir``/``dwi_dir``/``adc_dir``
        redirected to the registered output directories.
    """
    files = os.listdir(args.t2_dir)
    new_spacing = (0.4, 0.4, 3.0)  # target (x, y, z) spacing in mm

    t2_registered_dir = os.path.join(args.output_dir, 't2_registered')
    dwi_registered_dir = os.path.join(args.output_dir, 'DWI_registered')
    adc_registered_dir = os.path.join(args.output_dir, 'ADC_registered')
    for out_dir in (t2_registered_dir, dwi_registered_dir, adc_registered_dir):
        os.makedirs(out_dir, exist_ok=True)

    logging.info("Starting registration and cropping")
    for file in tqdm(files):
        t2_image = sitk.ReadImage(os.path.join(args.t2_dir, file))
        dwi_image = sitk.ReadImage(os.path.join(args.dwi_dir, file))
        adc_image = sitk.ReadImage(os.path.join(args.adc_dir, file))

        # Size that preserves the T2 field of view at the new spacing.
        original_spacing = t2_image.GetSpacing()
        original_size = t2_image.GetSize()
        new_size = [
            int(round(osz * ospc / nspc))
            for osz, ospc, nspc in zip(original_size, original_spacing, new_spacing)
        ]

        # picai_prep expects spacing/matrix_size in (z, y, x) order, hence the
        # reversed values relative to SimpleITK's (x, y, z).
        pat_case = Sample(
            scans=[t2_image, dwi_image, adc_image],
            settings=PreprocessingSettings(
                spacing=[3.0, 0.4, 0.4],
                matrix_size=[new_size[2], new_size[1], new_size[0]],
            ),
        )
        pat_case.preprocess()

        # Access the resampled scans via the public attribute (not __dict__).
        t2_post, dwi_post, adc_post = pat_case.scans
        cropped_t2 = crop(t2_post, [args.margin, args.margin, 0.0])
        cropped_dwi = crop(dwi_post, [args.margin, args.margin, 0.0])
        cropped_adc = crop(adc_post, [args.margin, args.margin, 0.0])

        sitk.WriteImage(cropped_t2, os.path.join(t2_registered_dir, file))
        sitk.WriteImage(cropped_dwi, os.path.join(dwi_registered_dir, file))
        sitk.WriteImage(cropped_adc, os.path.join(adc_registered_dir, file))

    args.t2_dir = t2_registered_dir
    args.dwi_dir = dwi_registered_dir
    args.adc_dir = adc_registered_dir

    return args
|
| 67 |
+
|
src/train/__init__.py
ADDED
|
File without changes
|
src/train/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (224 Bytes). View file
|
|
|
src/train/__pycache__/train_cspca.cpython-39.pyc
ADDED
|
Binary file (4.66 kB). View file
|
|
|
src/train/__pycache__/train_pirads.cpython-39.pyc
ADDED
|
Binary file (6.63 kB). View file
|
|
|
src/train/train_cspca.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import collections.abc
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import time
|
| 6 |
+
import yaml
|
| 7 |
+
from scipy.stats import pearsonr
|
| 8 |
+
import gdown
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
import torch.multiprocessing as mp
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from monai.config import KeysCollection
|
| 16 |
+
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
|
| 17 |
+
from monai.data.wsi_reader import WSIReader
|
| 18 |
+
from monai.metrics import Cumulative, CumulativeAverage
|
| 19 |
+
from monai.networks.nets import milmodel, resnet, MILModel
|
| 20 |
+
from monai.transforms import (
|
| 21 |
+
Compose,
|
| 22 |
+
GridPatchd,
|
| 23 |
+
LoadImaged,
|
| 24 |
+
MapTransform,
|
| 25 |
+
RandFlipd,
|
| 26 |
+
RandGridPatchd,
|
| 27 |
+
RandRotate90d,
|
| 28 |
+
ScaleIntensityRanged,
|
| 29 |
+
SplitDimd,
|
| 30 |
+
ToTensord,
|
| 31 |
+
ConcatItemsd,
|
| 32 |
+
SelectItemsd,
|
| 33 |
+
EnsureChannelFirstd,
|
| 34 |
+
RepeatChanneld,
|
| 35 |
+
DeleteItemsd,
|
| 36 |
+
EnsureTyped,
|
| 37 |
+
ClipIntensityPercentilesd,
|
| 38 |
+
MaskIntensityd,
|
| 39 |
+
HistogramNormalized,
|
| 40 |
+
RandBiasFieldd,
|
| 41 |
+
RandCropByPosNegLabeld,
|
| 42 |
+
NormalizeIntensityd,
|
| 43 |
+
SqueezeDimd,
|
| 44 |
+
CropForegroundd,
|
| 45 |
+
ScaleIntensityd,
|
| 46 |
+
SpatialPadd,
|
| 47 |
+
CenterSpatialCropd,
|
| 48 |
+
ScaleIntensityd,
|
| 49 |
+
Transposed,
|
| 50 |
+
RandWeightedCropd,
|
| 51 |
+
)
|
| 52 |
+
from sklearn.metrics import cohen_kappa_score, roc_curve, confusion_matrix
|
| 53 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 54 |
+
from torch.utils.data.dataloader import default_collate
|
| 55 |
+
from torchvision.models.resnet import ResNet50_Weights
|
| 56 |
+
import torch.optim as optim
|
| 57 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 58 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 59 |
+
import matplotlib.pyplot as plt
|
| 60 |
+
import matplotlib.patches as patches
|
| 61 |
+
from tqdm import tqdm
|
| 62 |
+
from sklearn.metrics import confusion_matrix, roc_auc_score
|
| 63 |
+
from sklearn.metrics import roc_auc_score
|
| 64 |
+
from sklearn.preprocessing import label_binarize
|
| 65 |
+
import numpy as np
|
| 66 |
+
from AIAH_utility.viewer import BasicViewer
|
| 67 |
+
from scipy.special import expit
|
| 68 |
+
import nrrd
|
| 69 |
+
import random
|
| 70 |
+
from sklearn.metrics import roc_auc_score
|
| 71 |
+
import SimpleITK as sitk
|
| 72 |
+
from AIAH_utility.viewer import BasicViewer
|
| 73 |
+
import pandas as pd
|
| 74 |
+
import json
|
| 75 |
+
from sklearn.preprocessing import StandardScaler
|
| 76 |
+
from torch.utils.data import DataLoader, TensorDataset, Dataset
|
| 77 |
+
from sklearn.linear_model import LogisticRegression
|
| 78 |
+
from sklearn.utils import resample
|
| 79 |
+
import monai
|
| 80 |
+
|
| 81 |
+
def train_epoch(cspca_model, loader, optimizer, epoch, args):
    """Train ``cspca_model`` for one epoch.

    Iterates ``loader``, optimising binary cross-entropy on the model's
    sigmoid outputs, and accumulates targets/predictions across batches.

    Returns
    -------
    tuple
        ``(mean_loss, auc)`` — the running-average BCE loss and the ROC AUC
        computed over every sample seen this epoch.
    """
    cspca_model.train()
    criterion = nn.BCELoss()
    running_loss = CumulativeAverage()
    epoch_targets = Cumulative()
    epoch_preds = Cumulative()

    for batch in loader:
        inputs = batch["image"].as_subclass(torch.Tensor).to(args.device)
        labels = batch["label"].as_subclass(torch.Tensor).to(args.device)

        optimizer.zero_grad()
        probs = cspca_model(inputs).squeeze(1)
        batch_loss = criterion(probs, labels)
        batch_loss.backward()
        optimizer.step()

        epoch_targets.extend(labels.detach().cpu())
        epoch_preds.extend(probs.detach().cpu())
        running_loss.append(batch_loss.item())

    mean_loss = running_loss.aggregate()
    auc = roc_auc_score(
        epoch_targets.get_buffer().cpu().numpy(),
        epoch_preds.get_buffer().cpu().numpy(),
    )
    return mean_loss, auc
|
| 110 |
+
|
| 111 |
+
def val_epoch(cspca_model, loader, epoch, args):
    """Evaluate ``cspca_model`` for one epoch without gradient tracking.

    Accumulates BCE loss and binary predictions over ``loader``, thresholds
    the sigmoid outputs at 0.5 and derives threshold-based metrics from the
    confusion matrix.

    Returns
    -------
    dict
        ``{'epoch', 'loss', 'auc', 'sensitivity', 'specificity'}`` for this
        validation epoch.
    """
    cspca_model.eval()
    criterion = nn.BCELoss()
    run_loss = CumulativeAverage()
    TARGETS = Cumulative()
    PREDS = Cumulative()

    with torch.no_grad():
        for idx, batch_data in enumerate(loader):
            data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
            target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)

            output = cspca_model(data).squeeze(1)
            loss = criterion(output, target)

            TARGETS.extend(target.detach().cpu())
            PREDS.extend(output.detach().cpu())
            run_loss.append(loss.item())

    loss_epoch = run_loss.aggregate()
    target_list = TARGETS.get_buffer().cpu().numpy()
    pred_list = PREDS.get_buffer().cpu().numpy()
    auc_epoch = roc_auc_score(target_list, pred_list)

    y_pred_categoric = (pred_list >= 0.5)
    # labels=[0, 1] guarantees a full 2x2 matrix; without it, ravel() yields
    # fewer than four values (and the unpack crashes) whenever the validation
    # set happens to contain a single class.
    tn, fp, fn, tp = confusion_matrix(
        target_list, y_pred_categoric, labels=[0, 1]
    ).ravel()
    sens_epoch = tp / (tp + fn)
    spec_epoch = tn / (tn + fp)

    val_epoch_metric = {
        'epoch': epoch,
        'loss': loss_epoch,
        'auc': auc_epoch,
        'sensitivity': sens_epoch,
        'specificity': spec_epoch,
    }
    return val_epoch_metric
|
| 141 |
+
|