Spaces:
Runtime error
Runtime error
Anirudh Balaraman committed on
Commit ·
1baebae
1
Parent(s): 2a68513
cleanup
Browse files- .gitignore +2 -1
- config/config_cspca_test.yaml +3 -7
- config/config_cspca_train.yaml +3 -5
- config/config_pirads_test.yaml +3 -6
- config/config_pirads_train.yaml +3 -6
- job_scripts/train_pirads.sh +19 -0
- preprocess_main.py +22 -29
- pyproject.toml +9 -0
- run_cspca.py +105 -101
- run_inference.py +41 -80
- run_pirads.py +120 -87
- src/data/custom_transforms.py +136 -44
- src/data/data_loader.py +45 -42
- src/model/MIL.py +38 -35
- src/model/csPCa_model.py +49 -10
- src/preprocessing/center_crop.py +5 -4
- src/preprocessing/generate_heatmap.py +47 -30
- src/preprocessing/histogram_match.py +34 -23
- src/preprocessing/prostate_mask.py +37 -50
- src/preprocessing/register_and_crop.py +45 -29
- src/train/train_cspca.py +14 -81
- src/train/train_pirads.py +53 -90
- src/utils.py +79 -97
- temp.ipynb +0 -0
- tests/test_run.py +106 -0
- tests/test_run_cspca.py +0 -28
- tests/test_run_pirads.py +0 -28
.gitignore
CHANGED
|
@@ -5,4 +5,5 @@ temp_data/
|
|
| 5 |
temp.ipynb
|
| 6 |
__pycache__/
|
| 7 |
**/__pycache__/
|
| 8 |
-
*.pyc
|
|
|
|
|
|
| 5 |
temp.ipynb
|
| 6 |
__pycache__/
|
| 7 |
**/__pycache__/
|
| 8 |
+
*.pyc
|
| 9 |
+
.ruff_cache
|
config/config_cspca_test.yaml
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
-
project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
|
| 2 |
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
|
| 3 |
-
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/
|
| 4 |
num_classes: !!int 4
|
| 5 |
mil_mode: att_trans
|
| 6 |
tile_count: !!int 24
|
|
@@ -8,10 +7,7 @@ tile_size: !!int 64
|
|
| 8 |
depth: !!int 3
|
| 9 |
use_heatmap: !!bool True
|
| 10 |
workers: !!int 6
|
| 11 |
-
checkpoint_cspca: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/
|
| 12 |
-
|
| 13 |
-
batch_size: !!int 1
|
| 14 |
-
|
| 15 |
-
|
| 16 |
|
| 17 |
|
|
|
|
|
|
|
| 1 |
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
|
| 2 |
+
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PICAI_cspca.json
|
| 3 |
num_classes: !!int 4
|
| 4 |
mil_mode: att_trans
|
| 5 |
tile_count: !!int 24
|
|
|
|
| 7 |
depth: !!int 3
|
| 8 |
use_heatmap: !!bool True
|
| 9 |
workers: !!int 6
|
| 10 |
+
checkpoint_cspca: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/models/cspca_model.pth
|
| 11 |
+
batch_size: !!int 8
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
config/config_cspca_train.yaml
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
-
project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
|
| 2 |
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
|
| 3 |
-
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/
|
| 4 |
num_classes: !!int 4
|
| 5 |
mil_mode: att_trans
|
| 6 |
tile_count: !!int 24
|
|
@@ -8,12 +7,11 @@ tile_size: !!int 64
|
|
| 8 |
depth: !!int 3
|
| 9 |
use_heatmap: !!bool True
|
| 10 |
workers: !!int 6
|
| 11 |
-
checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/
|
| 12 |
-
epochs: !!int
|
| 13 |
batch_size: !!int 8
|
| 14 |
optim_lr: !!float 2e-4
|
| 15 |
|
| 16 |
|
| 17 |
|
| 18 |
|
| 19 |
-
|
|
|
|
|
|
|
| 1 |
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/t2_hist_matched/
|
| 2 |
+
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PICAI_cspca.json
|
| 3 |
num_classes: !!int 4
|
| 4 |
mil_mode: att_trans
|
| 5 |
tile_count: !!int 24
|
|
|
|
| 7 |
depth: !!int 3
|
| 8 |
use_heatmap: !!bool True
|
| 9 |
workers: !!int 6
|
| 10 |
+
checkpoint_pirads: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/models/pirads.pt
|
| 11 |
+
epochs: !!int 80
|
| 12 |
batch_size: !!int 8
|
| 13 |
optim_lr: !!float 2e-4
|
| 14 |
|
| 15 |
|
| 16 |
|
| 17 |
|
|
|
config/config_pirads_test.yaml
CHANGED
|
@@ -1,18 +1,15 @@
|
|
| 1 |
run_name: pirads_test_run
|
| 2 |
-
project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
|
| 3 |
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
|
| 4 |
-
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/
|
| 5 |
num_classes: !!int 4
|
| 6 |
mil_mode: att_trans
|
| 7 |
tile_count: !!int 24
|
| 8 |
tile_size: !!int 64
|
| 9 |
depth: !!int 3
|
| 10 |
use_heatmap: !!bool True
|
| 11 |
-
workers: !!int
|
| 12 |
-
checkpoint: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/
|
| 13 |
amp: !!bool True
|
| 14 |
-
dry_run: !!bool True
|
| 15 |
-
|
| 16 |
|
| 17 |
|
| 18 |
|
|
|
|
| 1 |
run_name: pirads_test_run
|
|
|
|
| 2 |
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
|
| 3 |
+
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PI-RADS_data.json
|
| 4 |
num_classes: !!int 4
|
| 5 |
mil_mode: att_trans
|
| 6 |
tile_count: !!int 24
|
| 7 |
tile_size: !!int 64
|
| 8 |
depth: !!int 3
|
| 9 |
use_heatmap: !!bool True
|
| 10 |
+
workers: !!int 8
|
| 11 |
+
checkpoint: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/models/pirads.pt
|
| 12 |
amp: !!bool True
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
|
config/config_pirads_train.yaml
CHANGED
|
@@ -1,21 +1,18 @@
|
|
| 1 |
-
project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder/
|
| 2 |
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
|
| 3 |
-
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/
|
| 4 |
num_classes: !!int 4
|
| 5 |
mil_mode: att_trans
|
| 6 |
tile_count: !!int 24
|
| 7 |
tile_size: !!int 64
|
| 8 |
depth: !!int 3
|
| 9 |
use_heatmap: !!bool True
|
| 10 |
-
workers: !!int
|
| 11 |
-
|
| 12 |
-
epochs: !!int 2
|
| 13 |
batch_size: !!int 8
|
| 14 |
optim_lr: !!float 2e-4
|
| 15 |
weight_decay: !!float 1e-5
|
| 16 |
amp: !!bool True
|
| 17 |
wandb: !!bool True
|
| 18 |
-
dry_run: !!bool True
|
| 19 |
|
| 20 |
|
| 21 |
|
|
|
|
|
|
|
| 1 |
data_root: /sc-projects/sc-proj-cc06-ag-ki-radiologie/prostate-foundation/PICAI_registered/t2_hist_matched
|
| 2 |
+
dataset_json: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PI-RADS_data.json
|
| 3 |
num_classes: !!int 4
|
| 4 |
mil_mode: att_trans
|
| 5 |
tile_count: !!int 24
|
| 6 |
tile_size: !!int 64
|
| 7 |
depth: !!int 3
|
| 8 |
use_heatmap: !!bool True
|
| 9 |
+
workers: !!int 4
|
| 10 |
+
epochs: !!int 100
|
|
|
|
| 11 |
batch_size: !!int 8
|
| 12 |
optim_lr: !!float 2e-4
|
| 13 |
weight_decay: !!float 1e-5
|
| 14 |
amp: !!bool True
|
| 15 |
wandb: !!bool True
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
|
job_scripts/train_pirads.sh
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=pirads_training # Specify job name
|
| 3 |
+
#SBATCH --partition=gpu # Specify partition name
|
| 4 |
+
#SBATCH --mem=128G
|
| 5 |
+
#SBATCH --gres=gpu:1
|
| 6 |
+
#SBATCH --time=48:00:00 # Set a limit on the total run time
|
| 7 |
+
#SBATCH --output=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/%x/log.o%j # File name for standard output
|
| 8 |
+
#SBATCH --error=/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/logs/%x/log.e%j # File name for standard error output
|
| 9 |
+
#SBATCH --mail-user=anirudh.balaraman@charite.de
|
| 10 |
+
#SBATCH --mail-type=END,FAIL
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
source /etc/profile.d/conda.sh
|
| 14 |
+
conda activate foundation
|
| 15 |
+
|
| 16 |
+
RUNDIR="/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
srun python -u $RUNDIR/run_pirads.py --mode train --config $RUNDIR/config/config_pirads_train.yaml
|
preprocess_main.py
CHANGED
|
@@ -1,68 +1,61 @@
|
|
| 1 |
-
import SimpleITK as sitk
|
| 2 |
import os
|
| 3 |
-
import numpy as np
|
| 4 |
-
import nrrd
|
| 5 |
-
from AIAH_utility.viewer import BasicViewer
|
| 6 |
-
from tqdm import tqdm
|
| 7 |
-
import pandas as pd
|
| 8 |
-
from picai_prep.preprocessing import PreprocessingSettings, Sample
|
| 9 |
-
import multiprocessing
|
| 10 |
-
import sys
|
| 11 |
from src.preprocessing.register_and_crop import register_files
|
| 12 |
from src.preprocessing.prostate_mask import get_segmask
|
| 13 |
from src.preprocessing.histogram_match import histmatch
|
| 14 |
from src.preprocessing.generate_heatmap import get_heatmap
|
| 15 |
import logging
|
| 16 |
-
from pathlib import Path
|
| 17 |
from src.utils import setup_logging
|
| 18 |
from src.utils import validate_steps
|
| 19 |
import argparse
|
| 20 |
-
import yaml
|
|
|
|
| 21 |
|
| 22 |
def parse_args():
|
| 23 |
-
FUNCTIONS = {
|
| 24 |
-
"register_and_crop": register_files,
|
| 25 |
-
"histogram_match": histmatch,
|
| 26 |
-
"get_segmentation_mask": get_segmask,
|
| 27 |
-
"get_heatmap": get_heatmap,
|
| 28 |
-
}
|
| 29 |
parser = argparse.ArgumentParser(description="File preprocessing")
|
| 30 |
parser.add_argument("--config", type=str, help="Path to YAML config file")
|
| 31 |
parser.add_argument(
|
| 32 |
"--steps",
|
| 33 |
-
nargs="+",
|
| 34 |
-
choices=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
required=True,
|
| 36 |
-
help="Steps to execute (one or more)"
|
| 37 |
)
|
| 38 |
parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
|
| 39 |
parser.add_argument("--dwi_dir", default=None, help="Path to DWI files")
|
| 40 |
parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
|
| 41 |
parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
|
| 42 |
parser.add_argument("--output_dir", default=None, help="Path to output folder")
|
| 43 |
-
parser.add_argument(
|
|
|
|
|
|
|
| 44 |
parser.add_argument("--project_dir", default=None, help="Project directory")
|
| 45 |
-
|
| 46 |
args = parser.parse_args()
|
| 47 |
if args.config:
|
| 48 |
-
with open(args.config,
|
| 49 |
config = yaml.safe_load(config_file)
|
| 50 |
args.__dict__.update(config)
|
| 51 |
return args
|
| 52 |
|
|
|
|
| 53 |
if __name__ == "__main__":
|
| 54 |
args = parse_args()
|
| 55 |
FUNCTIONS = {
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
}
|
| 61 |
|
| 62 |
-
args.logfile = os.path.join(args.output_dir,
|
| 63 |
setup_logging(args.logfile)
|
| 64 |
logging.info("Starting preprocessing")
|
| 65 |
validate_steps(args.steps)
|
| 66 |
for step in args.steps:
|
| 67 |
func = FUNCTIONS[step]
|
| 68 |
-
args = func(args)
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from src.preprocessing.register_and_crop import register_files
|
| 3 |
from src.preprocessing.prostate_mask import get_segmask
|
| 4 |
from src.preprocessing.histogram_match import histmatch
|
| 5 |
from src.preprocessing.generate_heatmap import get_heatmap
|
| 6 |
import logging
|
|
|
|
| 7 |
from src.utils import setup_logging
|
| 8 |
from src.utils import validate_steps
|
| 9 |
import argparse
|
| 10 |
+
import yaml
|
| 11 |
+
|
| 12 |
|
| 13 |
def parse_args():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
parser = argparse.ArgumentParser(description="File preprocessing")
|
| 15 |
parser.add_argument("--config", type=str, help="Path to YAML config file")
|
| 16 |
parser.add_argument(
|
| 17 |
"--steps",
|
| 18 |
+
nargs="+", # ← list of strings
|
| 19 |
+
choices=[
|
| 20 |
+
"register_and_crop",
|
| 21 |
+
"histogram_match",
|
| 22 |
+
"get_segmentation_mask",
|
| 23 |
+
"get_heatmap",
|
| 24 |
+
], # ← restrict allowed values
|
| 25 |
required=True,
|
| 26 |
+
help="Steps to execute (one or more)",
|
| 27 |
)
|
| 28 |
parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
|
| 29 |
parser.add_argument("--dwi_dir", default=None, help="Path to DWI files")
|
| 30 |
parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
|
| 31 |
parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
|
| 32 |
parser.add_argument("--output_dir", default=None, help="Path to output folder")
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
"--margin", default=0.2, type=float, help="Margin to center crop the images"
|
| 35 |
+
)
|
| 36 |
parser.add_argument("--project_dir", default=None, help="Project directory")
|
| 37 |
+
|
| 38 |
args = parser.parse_args()
|
| 39 |
if args.config:
|
| 40 |
+
with open(args.config, "r") as config_file:
|
| 41 |
config = yaml.safe_load(config_file)
|
| 42 |
args.__dict__.update(config)
|
| 43 |
return args
|
| 44 |
|
| 45 |
+
|
| 46 |
if __name__ == "__main__":
|
| 47 |
args = parse_args()
|
| 48 |
FUNCTIONS = {
|
| 49 |
+
"register_and_crop": register_files,
|
| 50 |
+
"histogram_match": histmatch,
|
| 51 |
+
"get_segmentation_mask": get_segmask,
|
| 52 |
+
"get_heatmap": get_heatmap,
|
| 53 |
}
|
| 54 |
|
| 55 |
+
args.logfile = os.path.join(args.output_dir, "preprocessing.log")
|
| 56 |
setup_logging(args.logfile)
|
| 57 |
logging.info("Starting preprocessing")
|
| 58 |
validate_steps(args.steps)
|
| 59 |
for step in args.steps:
|
| 60 |
func = FUNCTIONS[step]
|
| 61 |
+
args = func(args)
|
pyproject.toml
CHANGED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.ruff]
|
| 2 |
+
line-length = 100
|
| 3 |
+
|
| 4 |
+
[tool.ruff.lint]
|
| 5 |
+
select = ["E", "W"]
|
| 6 |
+
ignore = ["E501"]
|
| 7 |
+
|
| 8 |
+
[tool.ruff.format]
|
| 9 |
+
quote-style = "double"
|
run_cspca.py
CHANGED
|
@@ -1,120 +1,102 @@
|
|
| 1 |
import argparse
|
| 2 |
import os
|
| 3 |
import shutil
|
| 4 |
-
import time
|
| 5 |
import yaml
|
| 6 |
import sys
|
| 7 |
-
import gdown
|
| 8 |
-
import numpy as np
|
| 9 |
import torch
|
| 10 |
-
import torch.distributed as dist
|
| 11 |
-
import torch.multiprocessing as mp
|
| 12 |
-
import torch.nn as nn
|
| 13 |
-
import torch.nn.functional as F
|
| 14 |
-
from monai.config import KeysCollection
|
| 15 |
-
from monai.metrics import Cumulative, CumulativeAverage
|
| 16 |
-
from monai.networks.nets import milmodel, resnet, MILModel
|
| 17 |
-
|
| 18 |
-
from sklearn.metrics import cohen_kappa_score
|
| 19 |
-
from torch.cuda.amp import GradScaler, autocast
|
| 20 |
-
from torch.utils.data.dataloader import default_collate
|
| 21 |
-
from torchvision.models.resnet import ResNet50_Weights
|
| 22 |
-
import shutil
|
| 23 |
from pathlib import Path
|
| 24 |
-
from torch.utils.data.distributed import DistributedSampler
|
| 25 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 26 |
from monai.utils import set_determinism
|
| 27 |
-
|
| 28 |
-
import matplotlib.pyplot as plt
|
| 29 |
-
|
| 30 |
-
import wandb
|
| 31 |
-
import math
|
| 32 |
import logging
|
| 33 |
-
from pathlib import Path
|
| 34 |
-
|
| 35 |
-
|
| 36 |
from src.model.MIL import MILModel_3D
|
| 37 |
from src.model.csPCa_model import csPCa_Model
|
| 38 |
from src.data.data_loader import get_dataloader
|
| 39 |
from src.utils import save_cspca_checkpoint, get_metrics, setup_logging
|
| 40 |
from src.train.train_cspca import train_epoch, val_epoch
|
|
|
|
| 41 |
|
| 42 |
-
def main_worker(args):
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
mil_mode=args.mil_mode
|
| 47 |
-
).to(args.device)
|
| 48 |
cache_dir_path = Path(os.path.join(args.logdir, "cache"))
|
| 49 |
|
| 50 |
-
if args.mode ==
|
| 51 |
-
|
| 52 |
checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
|
| 53 |
mil_model.load_state_dict(checkpoint["state_dict"])
|
| 54 |
mil_model = mil_model.to(args.device)
|
| 55 |
-
|
| 56 |
-
model_dir = os.path.join(args.
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
for st in list(range(args.num_seeds)):
|
| 59 |
set_determinism(seed=st)
|
| 60 |
-
|
| 61 |
train_loader = get_dataloader(args, split="train")
|
| 62 |
valid_loader = get_dataloader(args, split="test")
|
| 63 |
cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
|
| 64 |
-
for submodule in [
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
for param in submodule.parameters():
|
| 68 |
param.requires_grad = False
|
| 69 |
|
| 70 |
-
optimizer = torch.optim.AdamW(
|
|
|
|
|
|
|
| 71 |
|
| 72 |
-
old_loss = float(
|
| 73 |
old_auc = 0.0
|
| 74 |
for epoch in range(args.epochs):
|
| 75 |
-
train_loss, train_auc = train_epoch(
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
|
| 78 |
-
logging.info(
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
save_cspca_checkpoint(cspca_model, val_metric, model_dir)
|
| 89 |
|
| 90 |
-
metrics_dict[
|
| 91 |
-
metrics_dict[
|
| 92 |
-
metrics_dict[
|
| 93 |
if cache_dir_path.exists() and cache_dir_path.is_dir():
|
| 94 |
shutil.rmtree(cache_dir_path)
|
| 95 |
|
| 96 |
get_metrics(metrics_dict)
|
| 97 |
|
| 98 |
-
elif args.mode ==
|
| 99 |
-
|
| 100 |
cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
|
| 101 |
checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
|
| 102 |
-
cspca_model.load_state_dict(checkpt[
|
| 103 |
cspca_model = cspca_model.to(args.device)
|
| 104 |
-
if
|
| 105 |
-
auc, sens, spec = checkpt[
|
| 106 |
-
logging.info(
|
|
|
|
|
|
|
| 107 |
else:
|
| 108 |
logging.info(f"csPCa Model loaded from {args.checkpoint_cspca}.")
|
| 109 |
-
|
| 110 |
-
metrics_dict = {
|
| 111 |
for st in list(range(args.num_seeds)):
|
| 112 |
set_determinism(seed=st)
|
| 113 |
test_loader = get_dataloader(args, split="test")
|
| 114 |
test_metric = val_epoch(cspca_model, test_loader, epoch=0, args=args)
|
| 115 |
-
metrics_dict[
|
| 116 |
-
metrics_dict[
|
| 117 |
-
metrics_dict[
|
| 118 |
|
| 119 |
if cache_dir_path.exists() and cache_dir_path.is_dir():
|
| 120 |
shutil.rmtree(cache_dir_path)
|
|
@@ -122,43 +104,58 @@ def main_worker(args):
|
|
| 122 |
get_metrics(metrics_dict)
|
| 123 |
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
| 127 |
def parse_args():
|
| 128 |
-
parser = argparse.ArgumentParser(
|
| 129 |
-
|
| 130 |
-
parser.add_argument('--run_name', type=str, default='train_cspca', help='run name for log file')
|
| 131 |
-
parser.add_argument('--config', type=str, help='Path to YAML config file')
|
| 132 |
-
parser.add_argument(
|
| 133 |
-
"--project_dir", default=None, help="path to project firectory"
|
| 134 |
)
|
| 135 |
parser.add_argument(
|
| 136 |
-
"--
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
|
| 139 |
parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
|
| 140 |
-
parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm: choose either att_trans or att_pyramid")
|
| 141 |
parser.add_argument(
|
| 142 |
-
"--
|
|
|
|
|
|
|
| 143 |
)
|
| 144 |
-
parser.add_argument("--tile_size", default=64, type=int, help="size of square patch (instance) in pixels")
|
| 145 |
-
parser.add_argument("--depth", default=3, type=int, help="number of slices in each 3D patch (instance)")
|
| 146 |
parser.add_argument(
|
| 147 |
-
"--
|
| 148 |
-
|
|
|
|
|
|
|
| 149 |
)
|
| 150 |
parser.add_argument(
|
| 151 |
-
"--
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
)
|
| 154 |
parser.set_defaults(use_heatmap=True)
|
| 155 |
parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
|
| 156 |
-
#parser.add_argument("--dry-run", action="store_true")
|
| 157 |
parser.add_argument("--checkpoint_pirads", default=None, help="Load PI-RADS model")
|
| 158 |
-
parser.add_argument(
|
|
|
|
|
|
|
| 159 |
parser.add_argument("--batch_size", default=32, type=int, help="number of MRI scans per batch")
|
| 160 |
parser.add_argument("--optim_lr", default=2e-4, type=float, help="initial learning rate")
|
| 161 |
-
#parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
|
| 162 |
parser.add_argument(
|
| 163 |
"--val_every",
|
| 164 |
"--val_interval",
|
|
@@ -166,42 +163,50 @@ def parse_args():
|
|
| 166 |
type=int,
|
| 167 |
help="run validation after this number of epochs, default 1 to run every epoch",
|
| 168 |
)
|
| 169 |
-
parser.add_argument(
|
|
|
|
|
|
|
| 170 |
parser.add_argument("--checkpoint_cspca", default=None, help="load existing checkpoint")
|
| 171 |
-
parser.add_argument(
|
|
|
|
|
|
|
| 172 |
args = parser.parse_args()
|
| 173 |
if args.config:
|
| 174 |
-
with open(args.config,
|
| 175 |
config = yaml.safe_load(config_file)
|
| 176 |
args.__dict__.update(config)
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
| 180 |
return args
|
| 181 |
|
| 182 |
|
| 183 |
-
|
| 184 |
if __name__ == "__main__":
|
|
|
|
| 185 |
args = parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
|
| 187 |
os.makedirs(args.logdir, exist_ok=True)
|
| 188 |
args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
|
| 189 |
setup_logging(args.logfile)
|
| 190 |
|
| 191 |
-
|
| 192 |
logging.info("Argument values:")
|
| 193 |
for k, v in vars(args).items():
|
| 194 |
logging.info(f"{k} => {v}")
|
| 195 |
logging.info("-----------------")
|
| 196 |
|
| 197 |
if args.dataset_json is None:
|
| 198 |
-
logging.error(
|
| 199 |
sys.exit(1)
|
| 200 |
-
if args.checkpoint_pirads is None and args.mode ==
|
| 201 |
-
logging.error(
|
| 202 |
sys.exit(1)
|
| 203 |
-
elif args.checkpoint_cspca is None and args.mode ==
|
| 204 |
-
logging.error(
|
| 205 |
sys.exit(1)
|
| 206 |
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 207 |
|
|
@@ -216,5 +221,4 @@ if __name__ == "__main__":
|
|
| 216 |
args.num_seeds = 2
|
| 217 |
args.wandb = False
|
| 218 |
|
| 219 |
-
|
| 220 |
main_worker(args)
|
|
|
|
| 1 |
import argparse
|
| 2 |
import os
|
| 3 |
import shutil
|
|
|
|
| 4 |
import yaml
|
| 5 |
import sys
|
|
|
|
|
|
|
| 6 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from pathlib import Path
|
|
|
|
|
|
|
| 8 |
from monai.utils import set_determinism
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import logging
|
|
|
|
|
|
|
|
|
|
| 10 |
from src.model.MIL import MILModel_3D
|
| 11 |
from src.model.csPCa_model import csPCa_Model
|
| 12 |
from src.data.data_loader import get_dataloader
|
| 13 |
from src.utils import save_cspca_checkpoint, get_metrics, setup_logging
|
| 14 |
from src.train.train_cspca import train_epoch, val_epoch
|
| 15 |
+
import random
|
| 16 |
|
|
|
|
| 17 |
|
| 18 |
+
def main_worker(args):
|
| 19 |
+
mil_model = MILModel_3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
|
|
|
|
|
|
|
| 20 |
cache_dir_path = Path(os.path.join(args.logdir, "cache"))
|
| 21 |
|
| 22 |
+
if args.mode == "train":
|
|
|
|
| 23 |
checkpoint = torch.load(args.checkpoint_pirads, weights_only=False, map_location="cpu")
|
| 24 |
mil_model.load_state_dict(checkpoint["state_dict"])
|
| 25 |
mil_model = mil_model.to(args.device)
|
| 26 |
+
|
| 27 |
+
model_dir = os.path.join(args.logdir, "models")
|
| 28 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 29 |
+
|
| 30 |
+
metrics_dict = {"auc": [], "sensitivity": [], "specificity": []}
|
| 31 |
for st in list(range(args.num_seeds)):
|
| 32 |
set_determinism(seed=st)
|
| 33 |
+
|
| 34 |
train_loader = get_dataloader(args, split="train")
|
| 35 |
valid_loader = get_dataloader(args, split="test")
|
| 36 |
cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
|
| 37 |
+
for submodule in [
|
| 38 |
+
cspca_model.backbone.net,
|
| 39 |
+
cspca_model.backbone.myfc,
|
| 40 |
+
cspca_model.backbone.transformer,
|
| 41 |
+
]:
|
| 42 |
for param in submodule.parameters():
|
| 43 |
param.requires_grad = False
|
| 44 |
|
| 45 |
+
optimizer = torch.optim.AdamW(
|
| 46 |
+
filter(lambda p: p.requires_grad, cspca_model.parameters()), lr=args.optim_lr
|
| 47 |
+
)
|
| 48 |
|
| 49 |
+
old_loss = float("inf")
|
| 50 |
old_auc = 0.0
|
| 51 |
for epoch in range(args.epochs):
|
| 52 |
+
train_loss, train_auc = train_epoch(
|
| 53 |
+
cspca_model, train_loader, optimizer, epoch=epoch, args=args
|
| 54 |
+
)
|
| 55 |
+
logging.info(
|
| 56 |
+
f"STATE {st} EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}"
|
| 57 |
+
)
|
| 58 |
val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
|
| 59 |
+
logging.info(
|
| 60 |
+
f"STATE {st} EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
|
| 61 |
+
)
|
| 62 |
+
val_metric["state"] = st
|
| 63 |
+
if val_metric["loss"] < old_loss:
|
| 64 |
+
old_loss = val_metric["loss"]
|
| 65 |
+
old_auc = val_metric["auc"]
|
| 66 |
+
sensitivity = val_metric["sensitivity"]
|
| 67 |
+
specificity = val_metric["specificity"]
|
| 68 |
+
if not metrics_dict["auc"] or val_metric["auc"] >= max(metrics_dict["auc"]):
|
| 69 |
save_cspca_checkpoint(cspca_model, val_metric, model_dir)
|
| 70 |
|
| 71 |
+
metrics_dict["auc"].append(old_auc)
|
| 72 |
+
metrics_dict["sensitivity"].append(sensitivity)
|
| 73 |
+
metrics_dict["specificity"].append(specificity)
|
| 74 |
if cache_dir_path.exists() and cache_dir_path.is_dir():
|
| 75 |
shutil.rmtree(cache_dir_path)
|
| 76 |
|
| 77 |
get_metrics(metrics_dict)
|
| 78 |
|
| 79 |
+
elif args.mode == "test":
|
|
|
|
| 80 |
cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
|
| 81 |
checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
|
| 82 |
+
cspca_model.load_state_dict(checkpt["state_dict"])
|
| 83 |
cspca_model = cspca_model.to(args.device)
|
| 84 |
+
if "auc" in checkpt and "sensitivity" in checkpt and "specificity" in checkpt:
|
| 85 |
+
auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
|
| 86 |
+
logging.info(
|
| 87 |
+
f"csPCa Model loaded from {args.checkpoint_cspca} with AUC: {auc}, Sensitivity: {sens}, Specificity: {spec} on the test set."
|
| 88 |
+
)
|
| 89 |
else:
|
| 90 |
logging.info(f"csPCa Model loaded from {args.checkpoint_cspca}.")
|
| 91 |
+
|
| 92 |
+
metrics_dict = {"auc": [], "sensitivity": [], "specificity": []}
|
| 93 |
for st in list(range(args.num_seeds)):
|
| 94 |
set_determinism(seed=st)
|
| 95 |
test_loader = get_dataloader(args, split="test")
|
| 96 |
test_metric = val_epoch(cspca_model, test_loader, epoch=0, args=args)
|
| 97 |
+
metrics_dict["auc"].append(test_metric["auc"])
|
| 98 |
+
metrics_dict["sensitivity"].append(test_metric["sensitivity"])
|
| 99 |
+
metrics_dict["specificity"].append(test_metric["specificity"])
|
| 100 |
|
| 101 |
if cache_dir_path.exists() and cache_dir_path.is_dir():
|
| 102 |
shutil.rmtree(cache_dir_path)
|
|
|
|
| 104 |
get_metrics(metrics_dict)
|
| 105 |
|
| 106 |
|
|
|
|
|
|
|
| 107 |
def parse_args():
|
| 108 |
+
parser = argparse.ArgumentParser(
|
| 109 |
+
description="Multiple Instance Learning (MIL) for csPCa risk prediction."
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
)
|
| 111 |
parser.add_argument(
|
| 112 |
+
"--mode",
|
| 113 |
+
type=str,
|
| 114 |
+
choices=["train", "test"],
|
| 115 |
+
required=True,
|
| 116 |
+
help="Operation mode: train or infer",
|
| 117 |
)
|
| 118 |
+
parser.add_argument("--run_name", type=str, default="train_cspca", help="run name for log file")
|
| 119 |
+
parser.add_argument("--config", type=str, help="Path to YAML config file")
|
| 120 |
+
parser.add_argument("--project_dir", default=None, help="path to project firectory")
|
| 121 |
+
parser.add_argument("--data_root", default=None, help="path to root folder of images")
|
| 122 |
parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
|
| 123 |
parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
|
|
|
|
| 124 |
parser.add_argument(
|
| 125 |
+
"--mil_mode",
|
| 126 |
+
default="att_trans",
|
| 127 |
+
help="MIL algorithm: choose either att_trans or att_pyramid",
|
| 128 |
)
|
|
|
|
|
|
|
| 129 |
parser.add_argument(
|
| 130 |
+
"--tile_count",
|
| 131 |
+
default=24,
|
| 132 |
+
type=int,
|
| 133 |
+
help="number of patches (instances) to extract from MRI input",
|
| 134 |
)
|
| 135 |
parser.add_argument(
|
| 136 |
+
"--tile_size", default=64, type=int, help="size of square patch (instance) in pixels"
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--depth", default=3, type=int, help="number of slices in each 3D patch (instance)"
|
| 140 |
+
)
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--use_heatmap",
|
| 143 |
+
action="store_true",
|
| 144 |
+
help="enable weak attention heatmap guided patch generation",
|
| 145 |
+
)
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--no_heatmap", dest="use_heatmap", action="store_false", help="disable heatmap"
|
| 148 |
)
|
| 149 |
parser.set_defaults(use_heatmap=True)
|
| 150 |
parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
|
| 151 |
+
# parser.add_argument("--dry-run", action="store_true")
|
| 152 |
parser.add_argument("--checkpoint_pirads", default=None, help="Load PI-RADS model")
|
| 153 |
+
parser.add_argument(
|
| 154 |
+
"--epochs", "--max_epochs", default=30, type=int, help="number of training epochs"
|
| 155 |
+
)
|
| 156 |
parser.add_argument("--batch_size", default=32, type=int, help="number of MRI scans per batch")
|
| 157 |
parser.add_argument("--optim_lr", default=2e-4, type=float, help="initial learning rate")
|
| 158 |
+
# parser.add_argument("--amp", action="store_true", help="use AMP, recommended")
|
| 159 |
parser.add_argument(
|
| 160 |
"--val_every",
|
| 161 |
"--val_interval",
|
|
|
|
| 163 |
type=int,
|
| 164 |
help="run validation after this number of epochs, default 1 to run every epoch",
|
| 165 |
)
|
| 166 |
+
parser.add_argument(
|
| 167 |
+
"--dry_run", action="store_true", help="Run the script in dry-run mode (default: False)"
|
| 168 |
+
)
|
| 169 |
parser.add_argument("--checkpoint_cspca", default=None, help="load existing checkpoint")
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--num_seeds", default=20, type=int, help="number of seeds to be run to build CI"
|
| 172 |
+
)
|
| 173 |
args = parser.parse_args()
|
| 174 |
if args.config:
|
| 175 |
+
with open(args.config, "r") as config_file:
|
| 176 |
config = yaml.safe_load(config_file)
|
| 177 |
args.__dict__.update(config)
|
| 178 |
|
|
|
|
|
|
|
| 179 |
return args
|
| 180 |
|
| 181 |
|
|
|
|
| 182 |
if __name__ == "__main__":
|
| 183 |
+
|
| 184 |
args = parse_args()
|
| 185 |
+
if args.project_dir is None:
|
| 186 |
+
args.project_dir = Path(__file__).resolve().parent # Set project directory
|
| 187 |
+
|
| 188 |
+
slurm_job_name = os.getenv('SLURM_JOB_NAME') # If the script is submitted via slurm, job name is the run name
|
| 189 |
+
if slurm_job_name:
|
| 190 |
+
args.run_name = slurm_job_name
|
| 191 |
+
|
| 192 |
args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
|
| 193 |
os.makedirs(args.logdir, exist_ok=True)
|
| 194 |
args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
|
| 195 |
setup_logging(args.logfile)
|
| 196 |
|
|
|
|
| 197 |
logging.info("Argument values:")
|
| 198 |
for k, v in vars(args).items():
|
| 199 |
logging.info(f"{k} => {v}")
|
| 200 |
logging.info("-----------------")
|
| 201 |
|
| 202 |
if args.dataset_json is None:
|
| 203 |
+
logging.error("Dataset path not provided. Quitting.")
|
| 204 |
sys.exit(1)
|
| 205 |
+
if args.checkpoint_pirads is None and args.mode == "train":
|
| 206 |
+
logging.error("PI-RADS checkpoint path not provided. Quitting.")
|
| 207 |
sys.exit(1)
|
| 208 |
+
elif args.checkpoint_cspca is None and args.mode == "test":
|
| 209 |
+
logging.error("csPCa checkpoint path not provided. Quitting.")
|
| 210 |
sys.exit(1)
|
| 211 |
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 212 |
|
|
|
|
| 221 |
args.num_seeds = 2
|
| 222 |
args.wandb = False
|
| 223 |
|
|
|
|
| 224 |
main_worker(args)
|
run_inference.py
CHANGED
|
@@ -1,66 +1,21 @@
|
|
| 1 |
import argparse
|
| 2 |
import os
|
| 3 |
-
import shutil
|
| 4 |
-
import time
|
| 5 |
import yaml
|
| 6 |
-
import sys
|
| 7 |
-
import gdown
|
| 8 |
-
import numpy as np
|
| 9 |
import torch
|
| 10 |
-
import torch.distributed as dist
|
| 11 |
-
import torch.multiprocessing as mp
|
| 12 |
-
import torch.nn as nn
|
| 13 |
-
import torch.nn.functional as F
|
| 14 |
-
from monai.config import KeysCollection
|
| 15 |
-
from monai.metrics import Cumulative, CumulativeAverage
|
| 16 |
-
from monai.networks.nets import milmodel, resnet, MILModel
|
| 17 |
-
|
| 18 |
-
from sklearn.metrics import cohen_kappa_score
|
| 19 |
-
from torch.cuda.amp import GradScaler, autocast
|
| 20 |
-
from torch.utils.data.dataloader import default_collate
|
| 21 |
-
from torchvision.models.resnet import ResNet50_Weights
|
| 22 |
-
import shutil
|
| 23 |
-
from pathlib import Path
|
| 24 |
-
from torch.utils.data.distributed import DistributedSampler
|
| 25 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 26 |
-
from monai.utils import set_determinism
|
| 27 |
-
import matplotlib.pyplot as plt
|
| 28 |
-
import wandb
|
| 29 |
-
import math
|
| 30 |
import logging
|
| 31 |
-
from pathlib import Path
|
| 32 |
-
|
| 33 |
-
|
| 34 |
from src.model.MIL import MILModel_3D
|
| 35 |
from src.model.csPCa_model import csPCa_Model
|
| 36 |
-
from src.
|
| 37 |
-
from src.utils import save_cspca_checkpoint, get_metrics, setup_logging, save_pirads_checkpoint, get_parent_image, get_patch_coordinate
|
| 38 |
-
from src.train import train_cspca, train_pirads
|
| 39 |
-
import SimpleITK as sitk
|
| 40 |
-
|
| 41 |
-
import nrrd
|
| 42 |
-
|
| 43 |
-
from tqdm import tqdm
|
| 44 |
-
import pandas as pd
|
| 45 |
-
from picai_prep.preprocessing import PreprocessingSettings, Sample
|
| 46 |
-
import multiprocessing
|
| 47 |
-
import sys
|
| 48 |
from src.preprocessing.register_and_crop import register_files
|
| 49 |
from src.preprocessing.prostate_mask import get_segmask
|
| 50 |
from src.preprocessing.histogram_match import histmatch
|
| 51 |
from src.preprocessing.generate_heatmap import get_heatmap
|
| 52 |
-
import logging
|
| 53 |
-
from pathlib import Path
|
| 54 |
-
from src.utils import setup_logging
|
| 55 |
-
from src.utils import validate_steps
|
| 56 |
-
import argparse
|
| 57 |
-
import yaml
|
| 58 |
from src.data.data_loader import data_transform, list_data_collate
|
| 59 |
-
from monai.data import Dataset
|
| 60 |
import json
|
| 61 |
|
| 62 |
-
def parse_args():
|
| 63 |
|
|
|
|
| 64 |
parser = argparse.ArgumentParser(description="File preprocessing")
|
| 65 |
parser.add_argument("--config", type=str, help="Path to YAML config file")
|
| 66 |
parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
|
|
@@ -68,7 +23,9 @@ def parse_args():
|
|
| 68 |
parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
|
| 69 |
parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
|
| 70 |
parser.add_argument("--output_dir", default=None, help="Path to output folder")
|
| 71 |
-
parser.add_argument(
|
|
|
|
|
|
|
| 72 |
parser.add_argument("--num_classes", default=4, type=int)
|
| 73 |
parser.add_argument("--mil_mode", default="att_trans", type=str)
|
| 74 |
parser.add_argument("--use_heatmap", default=True, type=bool)
|
|
@@ -76,47 +33,49 @@ def parse_args():
|
|
| 76 |
parser.add_argument("--tile_count", default=24, type=int)
|
| 77 |
parser.add_argument("--depth", default=3, type=int)
|
| 78 |
parser.add_argument("--project_dir", default=None, help="Project directory")
|
| 79 |
-
|
| 80 |
args = parser.parse_args()
|
| 81 |
if args.config:
|
| 82 |
-
with open(args.config,
|
| 83 |
config = yaml.safe_load(config_file)
|
| 84 |
args.__dict__.update(config)
|
| 85 |
return args
|
| 86 |
|
|
|
|
| 87 |
if __name__ == "__main__":
|
| 88 |
args = parse_args()
|
| 89 |
FUNCTIONS = {
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
}
|
| 95 |
|
| 96 |
-
args.logfile = os.path.join(args.output_dir,
|
| 97 |
setup_logging(args.logfile)
|
| 98 |
logging.info("Starting preprocessing")
|
| 99 |
steps = ["register_and_crop", "get_segmentation_mask", "histogram_match", "get_heatmap"]
|
| 100 |
for step in steps:
|
| 101 |
func = FUNCTIONS[step]
|
| 102 |
-
args = func(args)
|
| 103 |
|
| 104 |
logging.info("Preprocessing completed.")
|
| 105 |
|
| 106 |
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 107 |
|
| 108 |
logging.info("Loading PIRADS model")
|
| 109 |
-
pirads_model = MILModel_3D(
|
| 110 |
-
|
| 111 |
-
|
| 112 |
)
|
| 113 |
-
pirads_checkpoint = torch.load(os.path.join(args.project_dir, 'models', 'pirads.pt'), map_location="cpu")
|
| 114 |
pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
|
| 115 |
pirads_model.to(args.device)
|
| 116 |
logging.info("Loading csPCa model")
|
| 117 |
cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)
|
| 118 |
-
checkpt = torch.load(
|
| 119 |
-
|
|
|
|
|
|
|
| 120 |
cspca_model = cspca_model.to(args.device)
|
| 121 |
|
| 122 |
transform = data_transform(args)
|
|
@@ -124,12 +83,12 @@ if __name__ == "__main__":
|
|
| 124 |
args.data_list = []
|
| 125 |
for file in files:
|
| 126 |
temp = {}
|
| 127 |
-
temp[
|
| 128 |
-
temp[
|
| 129 |
-
temp[
|
| 130 |
-
temp[
|
| 131 |
-
temp[
|
| 132 |
-
temp[
|
| 133 |
args.data_list.append(temp)
|
| 134 |
|
| 135 |
dataset = Dataset(data=args.data_list, transform=transform)
|
|
@@ -139,7 +98,7 @@ if __name__ == "__main__":
|
|
| 139 |
shuffle=False,
|
| 140 |
num_workers=0,
|
| 141 |
pin_memory=True,
|
| 142 |
-
multiprocessing_context=
|
| 143 |
sampler=None,
|
| 144 |
collate_fn=list_data_collate,
|
| 145 |
)
|
|
@@ -153,7 +112,7 @@ if __name__ == "__main__":
|
|
| 153 |
for idx, batch_data in enumerate(loader):
|
| 154 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 155 |
logits = pirads_model(data)
|
| 156 |
-
pirads_score= torch.argmax(logits, dim=1)
|
| 157 |
pirads_list.append(pirads_score.item())
|
| 158 |
|
| 159 |
output = cspca_model(data)
|
|
@@ -181,17 +140,19 @@ if __name__ == "__main__":
|
|
| 181 |
for i in args.data_list:
|
| 182 |
parent_image = get_parent_image([i], args)
|
| 183 |
|
| 184 |
-
coords = get_patch_coordinate(patches_top_5, parent_image
|
| 185 |
coords_list.append(coords)
|
| 186 |
output_dict = {}
|
| 187 |
|
| 188 |
-
for i,j in enumerate(files):
|
| 189 |
-
logging.info(
|
| 190 |
-
|
|
|
|
| 191 |
output_dict[j] = {
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
}
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
import os
|
|
|
|
|
|
|
| 3 |
import yaml
|
|
|
|
|
|
|
|
|
|
| 4 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import logging
|
|
|
|
|
|
|
|
|
|
| 6 |
from src.model.MIL import MILModel_3D
|
| 7 |
from src.model.csPCa_model import csPCa_Model
|
| 8 |
+
from src.utils import setup_logging, get_parent_image, get_patch_coordinate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from src.preprocessing.register_and_crop import register_files
|
| 10 |
from src.preprocessing.prostate_mask import get_segmask
|
| 11 |
from src.preprocessing.histogram_match import histmatch
|
| 12 |
from src.preprocessing.generate_heatmap import get_heatmap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from src.data.data_loader import data_transform, list_data_collate
|
| 14 |
+
from monai.data import Dataset
|
| 15 |
import json
|
| 16 |
|
|
|
|
| 17 |
|
| 18 |
+
def parse_args():
|
| 19 |
parser = argparse.ArgumentParser(description="File preprocessing")
|
| 20 |
parser.add_argument("--config", type=str, help="Path to YAML config file")
|
| 21 |
parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
|
|
|
|
| 23 |
parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
|
| 24 |
parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
|
| 25 |
parser.add_argument("--output_dir", default=None, help="Path to output folder")
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--margin", default=0.2, type=float, help="Margin to center crop the images"
|
| 28 |
+
)
|
| 29 |
parser.add_argument("--num_classes", default=4, type=int)
|
| 30 |
parser.add_argument("--mil_mode", default="att_trans", type=str)
|
| 31 |
parser.add_argument("--use_heatmap", default=True, type=bool)
|
|
|
|
| 33 |
parser.add_argument("--tile_count", default=24, type=int)
|
| 34 |
parser.add_argument("--depth", default=3, type=int)
|
| 35 |
parser.add_argument("--project_dir", default=None, help="Project directory")
|
| 36 |
+
|
| 37 |
args = parser.parse_args()
|
| 38 |
if args.config:
|
| 39 |
+
with open(args.config, "r") as config_file:
|
| 40 |
config = yaml.safe_load(config_file)
|
| 41 |
args.__dict__.update(config)
|
| 42 |
return args
|
| 43 |
|
| 44 |
+
|
| 45 |
if __name__ == "__main__":
|
| 46 |
args = parse_args()
|
| 47 |
FUNCTIONS = {
|
| 48 |
+
"register_and_crop": register_files,
|
| 49 |
+
"histogram_match": histmatch,
|
| 50 |
+
"get_segmentation_mask": get_segmask,
|
| 51 |
+
"get_heatmap": get_heatmap,
|
| 52 |
}
|
| 53 |
|
| 54 |
+
args.logfile = os.path.join(args.output_dir, "inference.log")
|
| 55 |
setup_logging(args.logfile)
|
| 56 |
logging.info("Starting preprocessing")
|
| 57 |
steps = ["register_and_crop", "get_segmentation_mask", "histogram_match", "get_heatmap"]
|
| 58 |
for step in steps:
|
| 59 |
func = FUNCTIONS[step]
|
| 60 |
+
args = func(args)
|
| 61 |
|
| 62 |
logging.info("Preprocessing completed.")
|
| 63 |
|
| 64 |
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 65 |
|
| 66 |
logging.info("Loading PIRADS model")
|
| 67 |
+
pirads_model = MILModel_3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
|
| 68 |
+
pirads_checkpoint = torch.load(
|
| 69 |
+
os.path.join(args.project_dir, "models", "pirads.pt"), map_location="cpu"
|
| 70 |
)
|
|
|
|
| 71 |
pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
|
| 72 |
pirads_model.to(args.device)
|
| 73 |
logging.info("Loading csPCa model")
|
| 74 |
cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)
|
| 75 |
+
checkpt = torch.load(
|
| 76 |
+
os.path.join(args.project_dir, "models", "cspca_model.pth"), map_location="cpu"
|
| 77 |
+
)
|
| 78 |
+
cspca_model.load_state_dict(checkpt["state_dict"])
|
| 79 |
cspca_model = cspca_model.to(args.device)
|
| 80 |
|
| 81 |
transform = data_transform(args)
|
|
|
|
| 83 |
args.data_list = []
|
| 84 |
for file in files:
|
| 85 |
temp = {}
|
| 86 |
+
temp["image"] = os.path.join(args.t2_dir, file)
|
| 87 |
+
temp["dwi"] = os.path.join(args.dwi_dir, file)
|
| 88 |
+
temp["adc"] = os.path.join(args.adc_dir, file)
|
| 89 |
+
temp["heatmap"] = os.path.join(args.heatmapdir, file)
|
| 90 |
+
temp["mask"] = os.path.join(args.seg_dir, file)
|
| 91 |
+
temp["label"] = 0 # dummy label
|
| 92 |
args.data_list.append(temp)
|
| 93 |
|
| 94 |
dataset = Dataset(data=args.data_list, transform=transform)
|
|
|
|
| 98 |
shuffle=False,
|
| 99 |
num_workers=0,
|
| 100 |
pin_memory=True,
|
| 101 |
+
multiprocessing_context=None,
|
| 102 |
sampler=None,
|
| 103 |
collate_fn=list_data_collate,
|
| 104 |
)
|
|
|
|
| 112 |
for idx, batch_data in enumerate(loader):
|
| 113 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 114 |
logits = pirads_model(data)
|
| 115 |
+
pirads_score = torch.argmax(logits, dim=1)
|
| 116 |
pirads_list.append(pirads_score.item())
|
| 117 |
|
| 118 |
output = cspca_model(data)
|
|
|
|
| 140 |
for i in args.data_list:
|
| 141 |
parent_image = get_parent_image([i], args)
|
| 142 |
|
| 143 |
+
coords = get_patch_coordinate(patches_top_5, parent_image)
|
| 144 |
coords_list.append(coords)
|
| 145 |
output_dict = {}
|
| 146 |
|
| 147 |
+
for i, j in enumerate(files):
|
| 148 |
+
logging.info(
|
| 149 |
+
f"File: {j}, PIRADS score: {pirads_list[i] + 2.0}, csPCa risk score: {cspca_risk_list[i]:.4f}"
|
| 150 |
+
)
|
| 151 |
output_dict[j] = {
|
| 152 |
+
"Predicted PIRAD Score": pirads_list[i] + 2.0,
|
| 153 |
+
"csPCa risk": cspca_risk_list[i],
|
| 154 |
+
"Top left coordinate of top 5 patches(x,y,z)": coords_list[i],
|
| 155 |
}
|
| 156 |
+
|
| 157 |
+
with open(os.path.join(args.output_dir, "results.json"), "w") as f:
|
| 158 |
+
json.dump(output_dict, f, indent=4)
|
run_pirads.py
CHANGED
|
@@ -1,32 +1,14 @@
|
|
| 1 |
import argparse
|
| 2 |
-
import collections.abc
|
| 3 |
import os
|
| 4 |
import shutil
|
| 5 |
import time
|
| 6 |
import yaml
|
| 7 |
import sys
|
| 8 |
-
import gdown
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
-
import torch.distributed as dist
|
| 12 |
-
import torch.multiprocessing as mp
|
| 13 |
-
import torch.nn as nn
|
| 14 |
-
import torch.nn.functional as F
|
| 15 |
-
from monai.config import KeysCollection
|
| 16 |
-
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
|
| 17 |
-
from monai.data.wsi_reader import WSIReader
|
| 18 |
-
from monai.metrics import Cumulative, CumulativeAverage
|
| 19 |
-
from monai.networks.nets import milmodel, resnet, MILModel
|
| 20 |
-
|
| 21 |
-
from sklearn.metrics import cohen_kappa_score
|
| 22 |
-
from torch.cuda.amp import GradScaler, autocast
|
| 23 |
-
from torch.utils.data.dataloader import default_collate
|
| 24 |
from torch.utils.tensorboard import SummaryWriter
|
| 25 |
from monai.utils import set_determinism
|
| 26 |
-
import matplotlib.pyplot as plt
|
| 27 |
-
|
| 28 |
import wandb
|
| 29 |
-
import math
|
| 30 |
import logging
|
| 31 |
from pathlib import Path
|
| 32 |
from src.data.data_loader import get_dataloader
|
|
@@ -37,13 +19,9 @@ from src.utils import save_pirads_checkpoint, setup_logging
|
|
| 37 |
|
| 38 |
def main_worker(args):
|
| 39 |
if args.device == torch.device("cuda"):
|
| 40 |
-
torch.cuda.set_device(args.gpu) # use this default device (same as args.device if not distributed)
|
| 41 |
torch.backends.cudnn.benchmark = True
|
| 42 |
|
| 43 |
-
model = MILModel_3D(
|
| 44 |
-
num_classes=args.num_classes,
|
| 45 |
-
mil_mode=args.mil_mode
|
| 46 |
-
)
|
| 47 |
start_epoch = 0
|
| 48 |
best_acc = 0.0
|
| 49 |
if args.checkpoint is not None:
|
|
@@ -54,41 +32,54 @@ def main_worker(args):
|
|
| 54 |
start_epoch = checkpoint["epoch"]
|
| 55 |
if "best_acc" in checkpoint:
|
| 56 |
best_acc = checkpoint["best_acc"]
|
| 57 |
-
logging.info(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
cache_dir_ = os.path.join(args.logdir, "cache")
|
| 59 |
model.to(args.device)
|
| 60 |
params = model.parameters()
|
| 61 |
-
if args.mode ==
|
| 62 |
-
train_loader = get_dataloader(args, split=
|
| 63 |
valid_loader = get_dataloader(args, split="test")
|
| 64 |
-
logging.info(
|
| 65 |
-
|
|
|
|
|
|
|
| 66 |
if args.mil_mode in ["att_trans", "att_trans_pyramid"]:
|
| 67 |
params = [
|
| 68 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
{"params": list(model.transformer.parameters()), "lr": 6e-5, "weight_decay": 0.1},
|
| 70 |
]
|
| 71 |
|
| 72 |
optimizer = torch.optim.AdamW(params, lr=args.optim_lr, weight_decay=args.weight_decay)
|
| 73 |
-
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
|
|
|
|
|
|
| 74 |
scaler = torch.amp.GradScaler(device=str(args.device), enabled=args.amp)
|
| 75 |
|
| 76 |
if args.logdir is not None:
|
| 77 |
writer = SummaryWriter(log_dir=args.logdir)
|
| 78 |
-
logging.info("Writing Tensorboard logs to
|
| 79 |
else:
|
| 80 |
writer = None
|
| 81 |
|
| 82 |
-
|
| 83 |
# RUN TRAINING
|
| 84 |
n_epochs = args.epochs
|
| 85 |
val_loss_min = float("inf")
|
| 86 |
epochs_no_improve = 0
|
| 87 |
for epoch in range(start_epoch, n_epochs):
|
| 88 |
-
|
| 89 |
logging.info(time.ctime(), "Epoch:", epoch)
|
| 90 |
epoch_time = time.time()
|
| 91 |
-
train_loss, train_acc, train_att_loss, batch_norm = train_epoch(
|
|
|
|
|
|
|
| 92 |
logging.info(
|
| 93 |
"Final training %d/%d loss: %.4f attention loss: %.4f acc: %.4f time %.2fs",
|
| 94 |
epoch,
|
|
@@ -98,13 +89,21 @@ def main_worker(args):
|
|
| 98 |
train_acc,
|
| 99 |
time.time() - epoch_time,
|
| 100 |
)
|
| 101 |
-
|
| 102 |
if writer is not None:
|
| 103 |
writer.add_scalar("train_loss", train_loss, epoch)
|
| 104 |
writer.add_scalar("train_attention_loss", train_att_loss, epoch)
|
| 105 |
writer.add_scalar("train_acc", train_acc, epoch)
|
| 106 |
-
wandb.log(
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
model_new_best = False
|
| 109 |
val_acc = 0
|
| 110 |
if (epoch + 1) % args.val_every == 0:
|
|
@@ -125,35 +124,39 @@ def main_worker(args):
|
|
| 125 |
writer.add_scalar("test_acc", val_acc, epoch)
|
| 126 |
writer.add_scalar("test_qwk", qwk, epoch)
|
| 127 |
|
| 128 |
-
#val_acc = qwk
|
| 129 |
-
wandb.log(
|
|
|
|
|
|
|
|
|
|
| 130 |
if val_loss < val_loss_min:
|
| 131 |
logging.info("Loss (%.6f --> %.6f)", val_loss_min, val_loss)
|
| 132 |
val_loss_min = val_loss
|
| 133 |
model_new_best = True
|
| 134 |
|
| 135 |
if args.logdir is not None:
|
| 136 |
-
save_pirads_checkpoint(
|
|
|
|
|
|
|
| 137 |
if model_new_best:
|
| 138 |
-
logging.info("Copying to model.pt new best model
|
| 139 |
-
shutil.copyfile(
|
|
|
|
|
|
|
|
|
|
| 140 |
epochs_no_improve = 0
|
| 141 |
-
|
| 142 |
else:
|
| 143 |
epochs_no_improve += 1
|
| 144 |
if epochs_no_improve == args.early_stop:
|
| 145 |
-
logging.info(
|
| 146 |
break
|
| 147 |
-
|
| 148 |
-
|
| 149 |
|
| 150 |
scheduler.step()
|
| 151 |
|
| 152 |
logging.info("ALL DONE")
|
| 153 |
-
|
| 154 |
-
elif args.mode == 'test':
|
| 155 |
|
| 156 |
-
|
| 157 |
kappa_list = []
|
| 158 |
for seed in list(range(args.num_seeds)):
|
| 159 |
set_determinism(seed=seed)
|
|
@@ -163,51 +166,73 @@ def main_worker(args):
|
|
| 163 |
kappa_list.append(qwk)
|
| 164 |
logging.info(f"Seed {seed}, QWK: {qwk}")
|
| 165 |
if os.path.exists(cache_dir_):
|
| 166 |
-
logging.info("Removing cache directory
|
| 167 |
shutil.rmtree(cache_dir_)
|
| 168 |
-
|
| 169 |
logging.info(f"Mean QWK over {args.num_seeds} seeds: {np.mean(kappa_list)}")
|
| 170 |
|
| 171 |
-
|
| 172 |
if os.path.exists(cache_dir_):
|
| 173 |
-
logging.info("Removing cache directory
|
| 174 |
shutil.rmtree(cache_dir_)
|
| 175 |
|
| 176 |
|
| 177 |
def parse_args():
|
| 178 |
-
parser = argparse.ArgumentParser(
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
parser.add_argument(
|
| 182 |
-
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
parser.add_argument(
|
| 185 |
-
"--
|
| 186 |
)
|
| 187 |
parser.add_argument(
|
| 188 |
-
"--
|
| 189 |
)
|
|
|
|
|
|
|
|
|
|
| 190 |
parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
|
| 191 |
parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
|
| 192 |
-
parser.add_argument("--mil_mode", default="att_trans", help="MIL algorithm: choose either att_trans or att_pyramid")
|
| 193 |
parser.add_argument(
|
| 194 |
-
"--
|
|
|
|
|
|
|
| 195 |
)
|
| 196 |
-
parser.add_argument("--tile_size", default=64, type=int, help="size of square patch (instance) in pixels")
|
| 197 |
-
parser.add_argument("--depth", default=3, type=int, help="number of slices in each 3D patch (instance)")
|
| 198 |
parser.add_argument(
|
| 199 |
-
"--
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
)
|
| 202 |
parser.add_argument(
|
| 203 |
-
"--
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
)
|
| 206 |
parser.set_defaults(use_heatmap=True)
|
| 207 |
parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
|
| 208 |
|
| 209 |
parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
|
| 210 |
-
parser.add_argument(
|
|
|
|
|
|
|
| 211 |
parser.add_argument("--early_stop", default=40, type=int, help="early stopping criteria")
|
| 212 |
parser.add_argument("--batch_size", default=4, type=int, help="number of MRI scans per batch")
|
| 213 |
parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
|
|
@@ -220,19 +245,27 @@ def parse_args():
|
|
| 220 |
type=int,
|
| 221 |
help="run validation after this number of epochs, default 1 to run every epoch",
|
| 222 |
)
|
| 223 |
-
parser.add_argument(
|
|
|
|
|
|
|
| 224 |
args = parser.parse_args()
|
| 225 |
if args.config:
|
| 226 |
-
with open(args.config,
|
| 227 |
config = yaml.safe_load(config_file)
|
| 228 |
args.__dict__.update(config)
|
| 229 |
return args
|
| 230 |
-
|
| 231 |
-
|
| 232 |
|
| 233 |
|
| 234 |
if __name__ == "__main__":
|
|
|
|
| 235 |
args = parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
|
| 237 |
os.makedirs(args.logdir, exist_ok=True)
|
| 238 |
args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
|
|
@@ -243,15 +276,14 @@ if __name__ == "__main__":
|
|
| 243 |
logging.info(f"{k} => {v}")
|
| 244 |
logging.info("-----------------")
|
| 245 |
|
| 246 |
-
args.num_seeds = 10
|
| 247 |
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 248 |
if args.device == torch.device("cpu"):
|
| 249 |
args.amp = False
|
| 250 |
if args.dataset_json is None:
|
| 251 |
-
logging.error(
|
| 252 |
sys.exit(1)
|
| 253 |
-
if args.checkpoint is None and args.mode ==
|
| 254 |
-
logging.error(
|
| 255 |
sys.exit(1)
|
| 256 |
|
| 257 |
if args.dry_run:
|
|
@@ -261,8 +293,8 @@ if __name__ == "__main__":
|
|
| 261 |
args.workers = 0
|
| 262 |
args.num_seeds = 2
|
| 263 |
args.wandb = False
|
| 264 |
-
|
| 265 |
-
mode_wandb = "online" if args.wandb else "disabled"
|
| 266 |
|
| 267 |
config_wandb = {
|
| 268 |
"learning_rate": args.optim_lr,
|
|
@@ -271,13 +303,14 @@ if __name__ == "__main__":
|
|
| 271 |
"patch size": args.tile_size,
|
| 272 |
"patch count": args.tile_count,
|
| 273 |
}
|
| 274 |
-
wandb.init(
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
|
|
|
| 280 |
|
| 281 |
main_worker(args)
|
| 282 |
-
|
| 283 |
wandb.finish()
|
|
|
|
| 1 |
import argparse
|
|
|
|
| 2 |
import os
|
| 3 |
import shutil
|
| 4 |
import time
|
| 5 |
import yaml
|
| 6 |
import sys
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from torch.utils.tensorboard import SummaryWriter
|
| 10 |
from monai.utils import set_determinism
|
|
|
|
|
|
|
| 11 |
import wandb
|
|
|
|
| 12 |
import logging
|
| 13 |
from pathlib import Path
|
| 14 |
from src.data.data_loader import get_dataloader
|
|
|
|
| 19 |
|
| 20 |
def main_worker(args):
|
| 21 |
if args.device == torch.device("cuda"):
|
|
|
|
| 22 |
torch.backends.cudnn.benchmark = True
|
| 23 |
|
| 24 |
+
model = MILModel_3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
|
|
|
|
|
|
|
|
|
|
| 25 |
start_epoch = 0
|
| 26 |
best_acc = 0.0
|
| 27 |
if args.checkpoint is not None:
|
|
|
|
| 32 |
start_epoch = checkpoint["epoch"]
|
| 33 |
if "best_acc" in checkpoint:
|
| 34 |
best_acc = checkpoint["best_acc"]
|
| 35 |
+
logging.info(
|
| 36 |
+
"=> loaded checkpoint %s (epoch %d) (bestacc %f)",
|
| 37 |
+
args.checkpoint,
|
| 38 |
+
start_epoch,
|
| 39 |
+
best_acc,
|
| 40 |
+
)
|
| 41 |
cache_dir_ = os.path.join(args.logdir, "cache")
|
| 42 |
model.to(args.device)
|
| 43 |
params = model.parameters()
|
| 44 |
+
if args.mode == "train":
|
| 45 |
+
train_loader = get_dataloader(args, split="train")
|
| 46 |
valid_loader = get_dataloader(args, split="test")
|
| 47 |
+
logging.info(
|
| 48 |
+
f"Dataset training: {len(train_loader.dataset)}, test: {len(valid_loader.dataset)}"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
if args.mil_mode in ["att_trans", "att_trans_pyramid"]:
|
| 52 |
params = [
|
| 53 |
+
{
|
| 54 |
+
"params": list(model.attention.parameters())
|
| 55 |
+
+ list(model.myfc.parameters())
|
| 56 |
+
+ list(model.net.parameters())
|
| 57 |
+
},
|
| 58 |
{"params": list(model.transformer.parameters()), "lr": 6e-5, "weight_decay": 0.1},
|
| 59 |
]
|
| 60 |
|
| 61 |
optimizer = torch.optim.AdamW(params, lr=args.optim_lr, weight_decay=args.weight_decay)
|
| 62 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 63 |
+
optimizer, T_max=args.epochs, eta_min=0
|
| 64 |
+
)
|
| 65 |
scaler = torch.amp.GradScaler(device=str(args.device), enabled=args.amp)
|
| 66 |
|
| 67 |
if args.logdir is not None:
|
| 68 |
writer = SummaryWriter(log_dir=args.logdir)
|
| 69 |
+
logging.info(f"Writing Tensorboard logs to {writer.log_dir}")
|
| 70 |
else:
|
| 71 |
writer = None
|
| 72 |
|
|
|
|
| 73 |
# RUN TRAINING
|
| 74 |
n_epochs = args.epochs
|
| 75 |
val_loss_min = float("inf")
|
| 76 |
epochs_no_improve = 0
|
| 77 |
for epoch in range(start_epoch, n_epochs):
|
|
|
|
| 78 |
logging.info(time.ctime(), "Epoch:", epoch)
|
| 79 |
epoch_time = time.time()
|
| 80 |
+
train_loss, train_acc, train_att_loss, batch_norm = train_epoch(
|
| 81 |
+
model, train_loader, optimizer, scaler=scaler, epoch=epoch, args=args
|
| 82 |
+
)
|
| 83 |
logging.info(
|
| 84 |
"Final training %d/%d loss: %.4f attention loss: %.4f acc: %.4f time %.2fs",
|
| 85 |
epoch,
|
|
|
|
| 89 |
train_acc,
|
| 90 |
time.time() - epoch_time,
|
| 91 |
)
|
| 92 |
+
|
| 93 |
if writer is not None:
|
| 94 |
writer.add_scalar("train_loss", train_loss, epoch)
|
| 95 |
writer.add_scalar("train_attention_loss", train_att_loss, epoch)
|
| 96 |
writer.add_scalar("train_acc", train_acc, epoch)
|
| 97 |
+
wandb.log(
|
| 98 |
+
{
|
| 99 |
+
"Train Loss": train_loss,
|
| 100 |
+
"Train Accuracy": train_acc,
|
| 101 |
+
"Train Attention Loss": train_att_loss,
|
| 102 |
+
"Batch Norm": batch_norm,
|
| 103 |
+
},
|
| 104 |
+
step=epoch,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
model_new_best = False
|
| 108 |
val_acc = 0
|
| 109 |
if (epoch + 1) % args.val_every == 0:
|
|
|
|
| 124 |
writer.add_scalar("test_acc", val_acc, epoch)
|
| 125 |
writer.add_scalar("test_qwk", qwk, epoch)
|
| 126 |
|
| 127 |
+
# val_acc = qwk
|
| 128 |
+
wandb.log(
|
| 129 |
+
{"Test Loss": val_loss, "Test Accuracy": val_acc, "Cohen Kappa": qwk},
|
| 130 |
+
step=epoch,
|
| 131 |
+
)
|
| 132 |
if val_loss < val_loss_min:
|
| 133 |
logging.info("Loss (%.6f --> %.6f)", val_loss_min, val_loss)
|
| 134 |
val_loss_min = val_loss
|
| 135 |
model_new_best = True
|
| 136 |
|
| 137 |
if args.logdir is not None:
|
| 138 |
+
save_pirads_checkpoint(
|
| 139 |
+
model, epoch, args, best_acc=val_acc, filename=f"model_{epoch}.pt"
|
| 140 |
+
)
|
| 141 |
if model_new_best:
|
| 142 |
+
logging.info("Copying to model.pt new best model")
|
| 143 |
+
shutil.copyfile(
|
| 144 |
+
os.path.join(args.logdir, f"model_{epoch}.pt"),
|
| 145 |
+
os.path.join(args.logdir, "model.pt"),
|
| 146 |
+
)
|
| 147 |
epochs_no_improve = 0
|
| 148 |
+
|
| 149 |
else:
|
| 150 |
epochs_no_improve += 1
|
| 151 |
if epochs_no_improve == args.early_stop:
|
| 152 |
+
logging.info("Early stopping!")
|
| 153 |
break
|
|
|
|
|
|
|
| 154 |
|
| 155 |
scheduler.step()
|
| 156 |
|
| 157 |
logging.info("ALL DONE")
|
|
|
|
|
|
|
| 158 |
|
| 159 |
+
elif args.mode == "test":
|
| 160 |
kappa_list = []
|
| 161 |
for seed in list(range(args.num_seeds)):
|
| 162 |
set_determinism(seed=seed)
|
|
|
|
| 166 |
kappa_list.append(qwk)
|
| 167 |
logging.info(f"Seed {seed}, QWK: {qwk}")
|
| 168 |
if os.path.exists(cache_dir_):
|
| 169 |
+
logging.info(f"Removing cache directory {cache_dir_}")
|
| 170 |
shutil.rmtree(cache_dir_)
|
| 171 |
+
|
| 172 |
logging.info(f"Mean QWK over {args.num_seeds} seeds: {np.mean(kappa_list)}")
|
| 173 |
|
|
|
|
| 174 |
if os.path.exists(cache_dir_):
|
| 175 |
+
logging.info(f"Removing cache directory {cache_dir_}")
|
| 176 |
shutil.rmtree(cache_dir_)
|
| 177 |
|
| 178 |
|
| 179 |
def parse_args():
|
| 180 |
+
parser = argparse.ArgumentParser(
|
| 181 |
+
description="Multiple Instance Learning (MIL) for PIRADS Classification."
|
| 182 |
+
)
|
| 183 |
+
parser.add_argument(
|
| 184 |
+
"--mode",
|
| 185 |
+
type=str,
|
| 186 |
+
choices=["train", "test"],
|
| 187 |
+
required=True,
|
| 188 |
+
help="operation mode: train or infer",
|
| 189 |
+
)
|
| 190 |
+
parser.add_argument(
|
| 191 |
+
"--wandb", action="store_true", help="Add this flag to enable WandB logging"
|
| 192 |
+
)
|
| 193 |
parser.add_argument(
|
| 194 |
+
"--project_name", type=str, default="Classification_prostate", help="WandB project name"
|
| 195 |
)
|
| 196 |
parser.add_argument(
|
| 197 |
+
"--run_name", type=str, default="train_pirads", help="run name for WandB logging"
|
| 198 |
)
|
| 199 |
+
parser.add_argument("--config", type=str, help="path to YAML config file")
|
| 200 |
+
parser.add_argument("--project_dir", default=None, help="path to project firectory")
|
| 201 |
+
parser.add_argument("--data_root", default=None, help="path to root folder of images")
|
| 202 |
parser.add_argument("--dataset_json", default=None, type=str, help="path to dataset json file")
|
| 203 |
parser.add_argument("--num_classes", default=4, type=int, help="number of output classes")
|
|
|
|
| 204 |
parser.add_argument(
|
| 205 |
+
"--mil_mode",
|
| 206 |
+
default="att_trans",
|
| 207 |
+
help="MIL algorithm: choose either att_trans or att_pyramid",
|
| 208 |
)
|
|
|
|
|
|
|
| 209 |
parser.add_argument(
|
| 210 |
+
"--tile_count",
|
| 211 |
+
default=24,
|
| 212 |
+
type=int,
|
| 213 |
+
help="number of patches (instances) to extract from MRI input",
|
| 214 |
+
)
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--tile_size", default=64, type=int, help="size of square patch (instance) in pixels"
|
| 217 |
+
)
|
| 218 |
+
parser.add_argument(
|
| 219 |
+
"--depth", default=3, type=int, help="number of slices in each 3D patch (instance)"
|
| 220 |
)
|
| 221 |
parser.add_argument(
|
| 222 |
+
"--use_heatmap",
|
| 223 |
+
action="store_true",
|
| 224 |
+
help="enable weak attention heatmap guided patch generation",
|
| 225 |
+
)
|
| 226 |
+
parser.add_argument(
|
| 227 |
+
"--no_heatmap", dest="use_heatmap", action="store_false", help="disable heatmap"
|
| 228 |
)
|
| 229 |
parser.set_defaults(use_heatmap=True)
|
| 230 |
parser.add_argument("--workers", default=2, type=int, help="number of workers for data loading")
|
| 231 |
|
| 232 |
parser.add_argument("--checkpoint", default=None, help="load existing checkpoint")
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--epochs", "--max_epochs", default=50, type=int, help="number of training epochs"
|
| 235 |
+
)
|
| 236 |
parser.add_argument("--early_stop", default=40, type=int, help="early stopping criteria")
|
| 237 |
parser.add_argument("--batch_size", default=4, type=int, help="number of MRI scans per batch")
|
| 238 |
parser.add_argument("--optim_lr", default=3e-5, type=float, help="initial learning rate")
|
|
|
|
| 245 |
type=int,
|
| 246 |
help="run validation after this number of epochs, default 1 to run every epoch",
|
| 247 |
)
|
| 248 |
+
parser.add_argument(
|
| 249 |
+
"--dry_run", action="store_true", help="Run the script in dry-run mode (default: False)"
|
| 250 |
+
)
|
| 251 |
args = parser.parse_args()
|
| 252 |
if args.config:
|
| 253 |
+
with open(args.config, "r") as config_file:
|
| 254 |
config = yaml.safe_load(config_file)
|
| 255 |
args.__dict__.update(config)
|
| 256 |
return args
|
|
|
|
|
|
|
| 257 |
|
| 258 |
|
| 259 |
if __name__ == "__main__":
|
| 260 |
+
|
| 261 |
args = parse_args()
|
| 262 |
+
if args.project_dir is None:
|
| 263 |
+
args.project_dir = Path(__file__).resolve().parent # Set project directory
|
| 264 |
+
|
| 265 |
+
slurm_job_name = os.getenv('SLURM_JOB_NAME') # If the script is submitted via slurm, job name is the run name
|
| 266 |
+
if slurm_job_name:
|
| 267 |
+
args.run_name = slurm_job_name
|
| 268 |
+
|
| 269 |
args.logdir = os.path.join(args.project_dir, "logs", args.run_name)
|
| 270 |
os.makedirs(args.logdir, exist_ok=True)
|
| 271 |
args.logfile = os.path.join(args.logdir, f"{args.run_name}.log")
|
|
|
|
| 276 |
logging.info(f"{k} => {v}")
|
| 277 |
logging.info("-----------------")
|
| 278 |
|
|
|
|
| 279 |
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 280 |
if args.device == torch.device("cpu"):
|
| 281 |
args.amp = False
|
| 282 |
if args.dataset_json is None:
|
| 283 |
+
logging.error("Dataset JSON file not provided. Quitting.")
|
| 284 |
sys.exit(1)
|
| 285 |
+
if args.checkpoint is None and args.mode == "test":
|
| 286 |
+
logging.error("Model checkpoint path not provided. Quitting.")
|
| 287 |
sys.exit(1)
|
| 288 |
|
| 289 |
if args.dry_run:
|
|
|
|
| 293 |
args.workers = 0
|
| 294 |
args.num_seeds = 2
|
| 295 |
args.wandb = False
|
| 296 |
+
|
| 297 |
+
mode_wandb = "online" if args.wandb and args.mode != "test" else "disabled"
|
| 298 |
|
| 299 |
config_wandb = {
|
| 300 |
"learning_rate": args.optim_lr,
|
|
|
|
| 303 |
"patch size": args.tile_size,
|
| 304 |
"patch count": args.tile_count,
|
| 305 |
}
|
| 306 |
+
wandb.init(
|
| 307 |
+
project=args.project_name,
|
| 308 |
+
name=args.run_name,
|
| 309 |
+
dir=os.path.join(args.logdir, "wandb"),
|
| 310 |
+
config=config_wandb,
|
| 311 |
+
mode=mode_wandb,
|
| 312 |
+
)
|
| 313 |
|
| 314 |
main_worker(args)
|
| 315 |
+
|
| 316 |
wandb.finish()
|
src/data/custom_transforms.py
CHANGED
|
@@ -1,27 +1,25 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
-
from typing import Union
|
| 4 |
-
|
| 5 |
from monai.transforms import MapTransform
|
| 6 |
from monai.config import DtypeLike, KeysCollection
|
| 7 |
-
from monai.config.type_definitions import NdarrayOrTensor
|
| 8 |
from monai.data.meta_obj import get_track_meta
|
| 9 |
from monai.transforms.transform import Transform
|
| 10 |
from monai.transforms.utils import soft_clip
|
| 11 |
-
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
|
| 12 |
from monai.utils.enums import TransformBackends
|
| 13 |
-
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor
|
| 14 |
from scipy.ndimage import binary_dilation
|
| 15 |
import cv2
|
| 16 |
-
from typing import Union, Sequence
|
| 17 |
from collections.abc import Hashable, Mapping, Sequence
|
| 18 |
|
| 19 |
|
| 20 |
-
|
| 21 |
class DilateAndSaveMaskd(MapTransform):
|
| 22 |
"""
|
| 23 |
Custom transform to dilate binary mask and save a copy.
|
| 24 |
"""
|
|
|
|
| 25 |
def __init__(self, keys, dilation_size=10, copy_key="original_mask"):
|
| 26 |
super().__init__(keys)
|
| 27 |
self.dilation_size = dilation_size
|
|
@@ -29,37 +27,61 @@ class DilateAndSaveMaskd(MapTransform):
|
|
| 29 |
|
| 30 |
def __call__(self, data):
|
| 31 |
d = dict(data)
|
| 32 |
-
|
| 33 |
for key in self.keys:
|
| 34 |
mask = d[key].numpy() if isinstance(d[key], torch.Tensor) else d[key]
|
| 35 |
mask = mask.squeeze(0) # Remove channel dimension if present
|
| 36 |
|
| 37 |
# Save a copy of the original mask
|
| 38 |
-
d[self.copy_key] = torch.tensor(mask, dtype=torch.float32).unsqueeze(
|
|
|
|
|
|
|
| 39 |
|
| 40 |
# Apply binary dilation to the mask
|
| 41 |
dilated_mask = binary_dilation(mask, iterations=self.dilation_size).astype(np.uint8)
|
| 42 |
|
| 43 |
# Store the dilated mask
|
| 44 |
-
d[key] = torch.tensor(dilated_mask, dtype=torch.float32).unsqueeze(
|
|
|
|
|
|
|
| 45 |
|
| 46 |
return d
|
| 47 |
|
| 48 |
|
| 49 |
class ClipMaskIntensityPercentiles(Transform):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
|
|
|
|
| 53 |
|
| 54 |
def __init__(
|
| 55 |
self,
|
| 56 |
lower: Union[float, None],
|
| 57 |
upper: Union[float, None],
|
| 58 |
-
sharpness_factor
|
| 59 |
channel_wise: bool = False,
|
| 60 |
dtype: DtypeLike = np.float32,
|
| 61 |
) -> None:
|
| 62 |
-
|
| 63 |
if lower is None and upper is None:
|
| 64 |
raise ValueError("lower or upper percentiles must be provided")
|
| 65 |
if lower is not None and (lower < 0.0 or lower > 100.0):
|
|
@@ -71,7 +93,7 @@ class ClipMaskIntensityPercentiles(Transform):
|
|
| 71 |
if sharpness_factor is not None and sharpness_factor <= 0:
|
| 72 |
raise ValueError("sharpness_factor must be greater than 0")
|
| 73 |
|
| 74 |
-
#self.mask_data = mask_data
|
| 75 |
self.lower = lower
|
| 76 |
self.upper = upper
|
| 77 |
self.sharpness_factor = sharpness_factor
|
|
@@ -81,14 +103,26 @@ class ClipMaskIntensityPercentiles(Transform):
|
|
| 81 |
def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
|
| 82 |
masked_img = img * (mask_data > 0)
|
| 83 |
if self.sharpness_factor is not None:
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
else:
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
img = clip(img, lower_percentile, upper_percentile)
|
| 93 |
|
| 94 |
img = convert_to_tensor(img, track_meta=False)
|
|
@@ -102,7 +136,9 @@ class ClipMaskIntensityPercentiles(Transform):
|
|
| 102 |
img_t = convert_to_tensor(img, track_meta=False)
|
| 103 |
mask_t = convert_to_tensor(mask_data, track_meta=False)
|
| 104 |
if self.channel_wise:
|
| 105 |
-
img_t = torch.stack(
|
|
|
|
|
|
|
| 106 |
else:
|
| 107 |
img_t = self._clip(img=img_t, mask_data=mask_t)
|
| 108 |
|
|
@@ -110,7 +146,28 @@ class ClipMaskIntensityPercentiles(Transform):
|
|
| 110 |
|
| 111 |
return img
|
| 112 |
|
|
|
|
| 113 |
class ClipMaskIntensityPercentilesd(MapTransform):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
def __init__(
|
| 116 |
self,
|
|
@@ -125,7 +182,11 @@ class ClipMaskIntensityPercentilesd(MapTransform):
|
|
| 125 |
) -> None:
|
| 126 |
super().__init__(keys, allow_missing_keys)
|
| 127 |
self.scaler = ClipMaskIntensityPercentiles(
|
| 128 |
-
lower=lower,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
)
|
| 130 |
self.mask_key = mask_key
|
| 131 |
|
|
@@ -134,16 +195,32 @@ class ClipMaskIntensityPercentilesd(MapTransform):
|
|
| 134 |
for key in self.key_iterator(d):
|
| 135 |
d[key] = self.scaler(d[key], d[self.mask_key])
|
| 136 |
return d
|
| 137 |
-
|
| 138 |
-
|
| 139 |
|
| 140 |
|
| 141 |
class ElementwiseProductd(MapTransform):
    """
    Dictionary-based transform that stores the elementwise product of two items.

    The values found under ``keys[0]`` and ``keys[1]`` are multiplied with ``*`` and
    the result is written to ``output_key``. The input mapping is shallow-copied, so
    the caller's dictionary object is not modified.

    Args:
        keys: keys of the items to multiply; only the first two keys are read.
        output_key: key under which the product is stored in the returned dictionary.
    """

    def __init__(self, keys: KeysCollection, output_key: str) -> None:
        super().__init__(keys)
        self.output_key = output_key

    # The return annotation previously claimed ``NdarrayOrTensor`` although the whole
    # dictionary is returned; it now matches the other dictionary transforms in this file.
    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        d[self.output_key] = d[self.keys[0]] * d[self.keys[1]]
        return d
|
|
@@ -159,6 +236,7 @@ class CLAHEd(MapTransform):
|
|
| 159 |
clip_limit (float): Threshold for contrast limiting. Default is 2.0.
|
| 160 |
tile_grid_size (Union[tuple, Sequence[int]]): Size of grid for histogram equalization (default: (8,8)).
|
| 161 |
"""
|
|
|
|
| 162 |
def __init__(
|
| 163 |
self,
|
| 164 |
keys: KeysCollection,
|
|
@@ -184,13 +262,13 @@ class CLAHEd(MapTransform):
|
|
| 184 |
|
| 185 |
image_clahe = np.stack([clahe.apply(slice) for slice in image[0]])
|
| 186 |
|
| 187 |
-
|
| 188 |
# Convert back to float in [0,1]
|
| 189 |
processed_img = image_clahe.astype(np.float32) / 255.0
|
| 190 |
reshaped_ = processed_img.reshape(1, *processed_img.shape)
|
| 191 |
d[key] = torch.from_numpy(reshaped_).to(image_.device)
|
| 192 |
return d
|
| 193 |
-
|
|
|
|
| 194 |
class NormalizeIntensity_custom(Transform):
|
| 195 |
"""
|
| 196 |
Normalize input based on the `subtrahend` and `divisor`: `(img - subtrahend) / divisor`.
|
|
@@ -240,9 +318,11 @@ class NormalizeIntensity_custom(Transform):
|
|
| 240 |
x = torch.std(x.float(), unbiased=False)
|
| 241 |
return x.item() if x.numel() == 1 else x
|
| 242 |
|
| 243 |
-
def _normalize(
|
|
|
|
|
|
|
| 244 |
img, *_ = convert_data_type(img, dtype=torch.float32)
|
| 245 |
-
|
| 246 |
if self.nonzero:
|
| 247 |
slices = img != 0
|
| 248 |
masked_img = img[slices]
|
|
@@ -251,7 +331,7 @@ class NormalizeIntensity_custom(Transform):
|
|
| 251 |
else:
|
| 252 |
slices = None
|
| 253 |
masked_img = img
|
| 254 |
-
|
| 255 |
slices = None
|
| 256 |
mask_data = mask_data.squeeze(0)
|
| 257 |
slices_mask = mask_data > 0
|
|
@@ -288,9 +368,13 @@ class NormalizeIntensity_custom(Transform):
|
|
| 288 |
dtype = self.dtype or img.dtype
|
| 289 |
if self.channel_wise:
|
| 290 |
if self.subtrahend is not None and len(self.subtrahend) != len(img):
|
| 291 |
-
raise ValueError(
|
|
|
|
|
|
|
| 292 |
if self.divisor is not None and len(self.divisor) != len(img):
|
| 293 |
-
raise ValueError(
|
|
|
|
|
|
|
| 294 |
|
| 295 |
if not img.dtype.is_floating_point:
|
| 296 |
img, *_ = convert_data_type(img, dtype=torch.float32)
|
|
@@ -308,21 +392,27 @@ class NormalizeIntensity_custom(Transform):
|
|
| 308 |
out = convert_to_dst_type(img, img, dtype=dtype)[0]
|
| 309 |
return out
|
| 310 |
|
|
|
|
| 311 |
class NormalizeIntensity_customd(MapTransform):
|
| 312 |
"""
|
| 313 |
-
Dictionary-based wrapper of :
|
| 314 |
-
|
| 315 |
-
mean and
|
|
|
|
| 316 |
|
| 317 |
Args:
|
| 318 |
keys: keys of the corresponding items to be transformed.
|
| 319 |
-
See also: monai.transforms.MapTransform
|
| 320 |
-
|
| 321 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
nonzero: whether only normalize non-zero values.
|
| 323 |
channel_wise: if True, calculate on each channel separately, otherwise, calculate on
|
| 324 |
-
the entire image directly.
|
| 325 |
-
dtype: output data type, if None, same as input image.
|
| 326 |
allow_missing_keys: don't raise exception if key is missing.
|
| 327 |
"""
|
| 328 |
|
|
@@ -332,19 +422,21 @@ class NormalizeIntensity_customd(MapTransform):
|
|
| 332 |
self,
|
| 333 |
keys: KeysCollection,
|
| 334 |
mask_key: str,
|
| 335 |
-
subtrahend:Union[
|
| 336 |
-
divisor: Union[
|
| 337 |
nonzero: bool = False,
|
| 338 |
channel_wise: bool = False,
|
| 339 |
dtype: DtypeLike = np.float32,
|
| 340 |
allow_missing_keys: bool = False,
|
| 341 |
) -> None:
|
| 342 |
super().__init__(keys, allow_missing_keys)
|
| 343 |
-
self.normalizer = NormalizeIntensity_custom(
|
|
|
|
|
|
|
| 344 |
self.mask_key = mask_key
|
| 345 |
|
| 346 |
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
|
| 347 |
d = dict(data)
|
| 348 |
for key in self.key_iterator(d):
|
| 349 |
d[key] = self.normalizer(d[key], d[self.mask_key])
|
| 350 |
-
return d
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
+
from typing import Union
|
|
|
|
| 4 |
from monai.transforms import MapTransform
|
| 5 |
from monai.config import DtypeLike, KeysCollection
|
| 6 |
+
from monai.config.type_definitions import NdarrayOrTensor
|
| 7 |
from monai.data.meta_obj import get_track_meta
|
| 8 |
from monai.transforms.transform import Transform
|
| 9 |
from monai.transforms.utils import soft_clip
|
| 10 |
+
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
|
| 11 |
from monai.utils.enums import TransformBackends
|
| 12 |
+
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor
|
| 13 |
from scipy.ndimage import binary_dilation
|
| 14 |
import cv2
|
|
|
|
| 15 |
from collections.abc import Hashable, Mapping, Sequence
|
| 16 |
|
| 17 |
|
|
|
|
| 18 |
class DilateAndSaveMaskd(MapTransform):
|
| 19 |
"""
|
| 20 |
Custom transform to dilate binary mask and save a copy.
|
| 21 |
"""
|
| 22 |
+
|
| 23 |
def __init__(self, keys, dilation_size=10, copy_key="original_mask"):
|
| 24 |
super().__init__(keys)
|
| 25 |
self.dilation_size = dilation_size
|
|
|
|
| 27 |
|
| 28 |
def __call__(self, data):
|
| 29 |
d = dict(data)
|
| 30 |
+
|
| 31 |
for key in self.keys:
|
| 32 |
mask = d[key].numpy() if isinstance(d[key], torch.Tensor) else d[key]
|
| 33 |
mask = mask.squeeze(0) # Remove channel dimension if present
|
| 34 |
|
| 35 |
# Save a copy of the original mask
|
| 36 |
+
d[self.copy_key] = torch.tensor(mask, dtype=torch.float32).unsqueeze(
|
| 37 |
+
0
|
| 38 |
+
) # Save to a new key
|
| 39 |
|
| 40 |
# Apply binary dilation to the mask
|
| 41 |
dilated_mask = binary_dilation(mask, iterations=self.dilation_size).astype(np.uint8)
|
| 42 |
|
| 43 |
# Store the dilated mask
|
| 44 |
+
d[key] = torch.tensor(dilated_mask, dtype=torch.float32).unsqueeze(
|
| 45 |
+
0
|
| 46 |
+
) # Add channel dimension back
|
| 47 |
|
| 48 |
return d
|
| 49 |
|
| 50 |
|
| 51 |
class ClipMaskIntensityPercentiles(Transform):
|
| 52 |
+
"""
|
| 53 |
+
Clip image intensity values based on percentiles computed from a masked region.
|
| 54 |
+
This transform clips the intensity range of an image to values between lower and upper
|
| 55 |
+
percentiles calculated only from voxels where the mask is positive. It supports both
|
| 56 |
+
hard clipping and soft (smooth) clipping via a sharpness factor.
|
| 57 |
+
Args:
|
| 58 |
+
lower: Lower percentile threshold in range [0, 100]. If None, no lower clipping applied.
|
| 59 |
+
upper: Upper percentile threshold in range [0, 100]. If None, no upper clipping applied.
|
| 60 |
+
sharpness_factor: If provided, applies soft clipping with this sharpness parameter.
|
| 61 |
+
Must be greater than 0. If None, applies hard clipping instead.
|
| 62 |
+
channel_wise: If True, applies clipping independently to each channel using the
|
| 63 |
+
corresponding channel's mask. If False, uses the same mask for all channels.
|
| 64 |
+
dtype: Output data type for the clipped image. Defaults to np.float32.
|
| 65 |
+
Raises:
|
| 66 |
+
ValueError: If both lower and upper are None, if percentiles are outside [0, 100],
|
| 67 |
+
if upper < lower, or if sharpness_factor <= 0.
|
| 68 |
+
Returns:
|
| 69 |
+
Clipped image with intensities adjusted based on masked percentiles.
|
| 70 |
+
Note:
|
| 71 |
+
Supports both torch.Tensor and numpy.ndarray inputs.
|
| 72 |
|
| 73 |
|
| 74 |
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
|
| 75 |
+
"""
|
| 76 |
|
| 77 |
def __init__(
|
| 78 |
self,
|
| 79 |
lower: Union[float, None],
|
| 80 |
upper: Union[float, None],
|
| 81 |
+
sharpness_factor: Union[float, None] = None,
|
| 82 |
channel_wise: bool = False,
|
| 83 |
dtype: DtypeLike = np.float32,
|
| 84 |
) -> None:
|
|
|
|
| 85 |
if lower is None and upper is None:
|
| 86 |
raise ValueError("lower or upper percentiles must be provided")
|
| 87 |
if lower is not None and (lower < 0.0 or lower > 100.0):
|
|
|
|
| 93 |
if sharpness_factor is not None and sharpness_factor <= 0:
|
| 94 |
raise ValueError("sharpness_factor must be greater than 0")
|
| 95 |
|
| 96 |
+
# self.mask_data = mask_data
|
| 97 |
self.lower = lower
|
| 98 |
self.upper = upper
|
| 99 |
self.sharpness_factor = sharpness_factor
|
|
|
|
| 103 |
def _clip(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor) -> NdarrayOrTensor:
|
| 104 |
masked_img = img * (mask_data > 0)
|
| 105 |
if self.sharpness_factor is not None:
|
| 106 |
+
lower_percentile = (
|
| 107 |
+
percentile(masked_img, self.lower) if self.lower is not None else None
|
| 108 |
+
)
|
| 109 |
+
upper_percentile = (
|
| 110 |
+
percentile(masked_img, self.upper) if self.upper is not None else None
|
| 111 |
+
)
|
| 112 |
+
img = soft_clip(
|
| 113 |
+
img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype
|
| 114 |
+
)
|
| 115 |
else:
|
| 116 |
+
lower_percentile = (
|
| 117 |
+
percentile(masked_img, self.lower)
|
| 118 |
+
if self.lower is not None
|
| 119 |
+
else percentile(masked_img, 0)
|
| 120 |
+
)
|
| 121 |
+
upper_percentile = (
|
| 122 |
+
percentile(masked_img, self.upper)
|
| 123 |
+
if self.upper is not None
|
| 124 |
+
else percentile(masked_img, 100)
|
| 125 |
+
)
|
| 126 |
img = clip(img, lower_percentile, upper_percentile)
|
| 127 |
|
| 128 |
img = convert_to_tensor(img, track_meta=False)
|
|
|
|
| 136 |
img_t = convert_to_tensor(img, track_meta=False)
|
| 137 |
mask_t = convert_to_tensor(mask_data, track_meta=False)
|
| 138 |
if self.channel_wise:
|
| 139 |
+
img_t = torch.stack(
|
| 140 |
+
[self._clip(img=d, mask_data=mask_t[e]) for e, d in enumerate(img_t)]
|
| 141 |
+
) # type: ignore
|
| 142 |
else:
|
| 143 |
img_t = self._clip(img=img_t, mask_data=mask_t)
|
| 144 |
|
|
|
|
| 146 |
|
| 147 |
return img
|
| 148 |
|
| 149 |
+
|
| 150 |
class ClipMaskIntensityPercentilesd(MapTransform):
|
| 151 |
+
"""
|
| 152 |
+
Dictionary wrapper for ClipMaskIntensityPercentiles.
|
| 153 |
+
Args:
|
| 154 |
+
keys: Keys of the corresponding items to be transformed.
|
| 155 |
+
mask_key: Key to the mask data in the input dictionary used to compute percentiles. Only intensity values where the mask is positive will be considered.
|
| 156 |
+
lower: Lower percentile value (0-100) for clipping. If None, no lower clipping is applied.
|
| 157 |
+
upper: Upper percentile value (0-100) for clipping. If None, no upper clipping is applied.
|
| 158 |
+
sharpness_factor: If provided, soft (smooth) clipping is applied with this sharpness parameter; must be greater than 0. If None, hard clipping is applied instead.
|
| 159 |
+
channel_wise: If True, compute percentiles separately for each channel. If False, compute globally.
|
| 160 |
+
dtype: Data type of the output. Defaults to np.float32.
|
| 161 |
+
allow_missing_keys: If True, missing keys will not raise an error. Defaults to False.
|
| 162 |
+
Example:
|
| 163 |
+
>>> transform = ClipMaskIntensityPercentilesd(
|
| 164 |
+
... keys=["image"],
|
| 165 |
+
... mask_key="mask",
|
| 166 |
+
... lower=2,
|
| 167 |
+
... upper=98,
|
| 168 |
+
... sharpness_factor=1.0
|
| 169 |
+
... )
|
| 170 |
+
"""
|
| 171 |
|
| 172 |
def __init__(
|
| 173 |
self,
|
|
|
|
| 182 |
) -> None:
|
| 183 |
super().__init__(keys, allow_missing_keys)
|
| 184 |
self.scaler = ClipMaskIntensityPercentiles(
|
| 185 |
+
lower=lower,
|
| 186 |
+
upper=upper,
|
| 187 |
+
sharpness_factor=sharpness_factor,
|
| 188 |
+
channel_wise=channel_wise,
|
| 189 |
+
dtype=dtype,
|
| 190 |
)
|
| 191 |
self.mask_key = mask_key
|
| 192 |
|
|
|
|
| 195 |
for key in self.key_iterator(d):
|
| 196 |
d[key] = self.scaler(d[key], d[self.mask_key])
|
| 197 |
return d
|
|
|
|
|
|
|
| 198 |
|
| 199 |
|
| 200 |
class ElementwiseProductd(MapTransform):
    """
    Dictionary transform that multiplies two entries of the input elementwise.

    The values stored under the first two entries of ``keys`` are combined with ``*``
    and the product is written to ``output_key``. Every other item is passed through
    unchanged; the input mapping itself is shallow-copied rather than modified.

    Args:
        keys: Keys of the two items to multiply. Only ``keys[0]`` and ``keys[1]`` are read.
        output_key: Key under which the product is stored in the returned dictionary.

    Returns:
        A new dictionary containing the input items plus the product at ``output_key``.

    Example:
        >>> transform = ElementwiseProductd(keys=["image1", "image2"], output_key="product")
        >>> data = {"image1": np.array([1, 2, 3]), "image2": np.array([2, 3, 4])}
        >>> transform(data)["product"]
        array([ 2,  6, 12])
    """

    def __init__(self, keys: KeysCollection, output_key: str) -> None:
        super().__init__(keys)
        self.output_key = output_key

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        out = dict(data)
        # Look up the operands in key order, then store their product.
        lhs = out[self.keys[0]]
        rhs = out[self.keys[1]]
        out[self.output_key] = lhs * rhs
        return out
|
|
|
|
| 236 |
clip_limit (float): Threshold for contrast limiting. Default is 2.0.
|
| 237 |
tile_grid_size (Union[tuple, Sequence[int]]): Size of grid for histogram equalization (default: (8,8)).
|
| 238 |
"""
|
| 239 |
+
|
| 240 |
def __init__(
|
| 241 |
self,
|
| 242 |
keys: KeysCollection,
|
|
|
|
| 262 |
|
| 263 |
image_clahe = np.stack([clahe.apply(slice) for slice in image[0]])
|
| 264 |
|
|
|
|
| 265 |
# Convert back to float in [0,1]
|
| 266 |
processed_img = image_clahe.astype(np.float32) / 255.0
|
| 267 |
reshaped_ = processed_img.reshape(1, *processed_img.shape)
|
| 268 |
d[key] = torch.from_numpy(reshaped_).to(image_.device)
|
| 269 |
return d
|
| 270 |
+
|
| 271 |
+
|
| 272 |
class NormalizeIntensity_custom(Transform):
|
| 273 |
"""
|
| 274 |
Normalize input based on the `subtrahend` and `divisor`: `(img - subtrahend) / divisor`.
|
|
|
|
| 318 |
x = torch.std(x.float(), unbiased=False)
|
| 319 |
return x.item() if x.numel() == 1 else x
|
| 320 |
|
| 321 |
+
def _normalize(
|
| 322 |
+
self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor, sub=None, div=None
|
| 323 |
+
) -> NdarrayOrTensor:
|
| 324 |
img, *_ = convert_data_type(img, dtype=torch.float32)
|
| 325 |
+
"""
|
| 326 |
if self.nonzero:
|
| 327 |
slices = img != 0
|
| 328 |
masked_img = img[slices]
|
|
|
|
| 331 |
else:
|
| 332 |
slices = None
|
| 333 |
masked_img = img
|
| 334 |
+
"""
|
| 335 |
slices = None
|
| 336 |
mask_data = mask_data.squeeze(0)
|
| 337 |
slices_mask = mask_data > 0
|
|
|
|
| 368 |
dtype = self.dtype or img.dtype
|
| 369 |
if self.channel_wise:
|
| 370 |
if self.subtrahend is not None and len(self.subtrahend) != len(img):
|
| 371 |
+
raise ValueError(
|
| 372 |
+
f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components."
|
| 373 |
+
)
|
| 374 |
if self.divisor is not None and len(self.divisor) != len(img):
|
| 375 |
+
raise ValueError(
|
| 376 |
+
f"img has {len(img)} channels, but divisor has {len(self.divisor)} components."
|
| 377 |
+
)
|
| 378 |
|
| 379 |
if not img.dtype.is_floating_point:
|
| 380 |
img, *_ = convert_data_type(img, dtype=torch.float32)
|
|
|
|
| 392 |
out = convert_to_dst_type(img, img, dtype=dtype)[0]
|
| 393 |
return out
|
| 394 |
|
| 395 |
+
|
| 396 |
class NormalizeIntensity_customd(MapTransform):
|
| 397 |
"""
|
| 398 |
+
Dictionary-based wrapper of :class:`NormalizeIntensity_custom`.
|
| 399 |
+
|
| 400 |
+
The mean and standard deviation are calculated only from intensities which are
|
| 401 |
+
defined in the mask provided through ``mask_key``.
|
| 402 |
|
| 403 |
Args:
|
| 404 |
keys: keys of the corresponding items to be transformed.
|
| 405 |
+
See also: :py:class:`monai.transforms.MapTransform`
|
| 406 |
+
mask_key: key of the corresponding mask item to be used for calculating
|
| 407 |
+
statistics (mean and std).
|
| 408 |
+
subtrahend: the amount to subtract by (usually the mean). If None,
|
| 409 |
+
the mean is calculated from the masked region of the input image.
|
| 410 |
+
divisor: the amount to divide by (usually the standard deviation). If None,
|
| 411 |
+
the std is calculated from the masked region of the input image.
|
| 412 |
nonzero: whether only normalize non-zero values.
|
| 413 |
channel_wise: if True, calculate on each channel separately, otherwise, calculate on
|
| 414 |
+
the entire image directly. Defaults to False.
|
| 415 |
+
dtype: output data type, if None, same as input image. Defaults to float32.
|
| 416 |
allow_missing_keys: don't raise exception if key is missing.
|
| 417 |
"""
|
| 418 |
|
|
|
|
| 422 |
self,
|
| 423 |
keys: KeysCollection,
|
| 424 |
mask_key: str,
|
| 425 |
+
subtrahend: Union[NdarrayOrTensor, None] = None,
|
| 426 |
+
divisor: Union[NdarrayOrTensor, None] = None,
|
| 427 |
nonzero: bool = False,
|
| 428 |
channel_wise: bool = False,
|
| 429 |
dtype: DtypeLike = np.float32,
|
| 430 |
allow_missing_keys: bool = False,
|
| 431 |
) -> None:
|
| 432 |
super().__init__(keys, allow_missing_keys)
|
| 433 |
+
self.normalizer = NormalizeIntensity_custom(
|
| 434 |
+
subtrahend, divisor, nonzero, channel_wise, dtype
|
| 435 |
+
)
|
| 436 |
self.mask_key = mask_key
|
| 437 |
|
| 438 |
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
|
| 439 |
d = dict(data)
|
| 440 |
for key in self.key_iterator(d):
|
| 441 |
d[key] = self.normalizer(d[key], d[self.mask_key])
|
| 442 |
+
return d
|
src/data/data_loader.py
CHANGED
|
@@ -1,43 +1,29 @@
|
|
| 1 |
-
import argparse
|
| 2 |
import os
|
| 3 |
import numpy as np
|
| 4 |
-
from monai.
|
| 5 |
-
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
|
| 6 |
from monai.transforms import (
|
| 7 |
Compose,
|
| 8 |
LoadImaged,
|
| 9 |
-
MapTransform,
|
| 10 |
-
ScaleIntensityRanged,
|
| 11 |
-
SplitDimd,
|
| 12 |
ToTensord,
|
| 13 |
-
ConcatItemsd,
|
| 14 |
-
SelectItemsd,
|
| 15 |
-
EnsureChannelFirstd,
|
| 16 |
-
RepeatChanneld,
|
| 17 |
DeleteItemsd,
|
| 18 |
EnsureTyped,
|
| 19 |
-
ClipIntensityPercentilesd,
|
| 20 |
-
MaskIntensityd,
|
| 21 |
RandCropByPosNegLabeld,
|
| 22 |
NormalizeIntensityd,
|
| 23 |
-
SqueezeDimd,
|
| 24 |
-
ScaleIntensityd,
|
| 25 |
-
ScaleIntensityd,
|
| 26 |
Transposed,
|
| 27 |
RandWeightedCropd,
|
| 28 |
)
|
| 29 |
from .custom_transforms import (
|
| 30 |
-
NormalizeIntensity_customd,
|
| 31 |
-
ClipMaskIntensityPercentilesd,
|
| 32 |
ElementwiseProductd,
|
| 33 |
)
|
| 34 |
import torch
|
| 35 |
from torch.utils.data.dataloader import default_collate
|
| 36 |
-
import matplotlib.pyplot as plt
|
| 37 |
from typing import Literal
|
| 38 |
-
import monai
|
| 39 |
import collections.abc
|
| 40 |
|
|
|
|
| 41 |
def list_data_collate(batch: collections.abc.Sequence):
|
| 42 |
"""
|
| 43 |
Combine instances from a list of dicts into a single dict, by stacking them along first dim
|
|
@@ -51,55 +37,71 @@ def list_data_collate(batch: collections.abc.Sequence):
|
|
| 51 |
|
| 52 |
if all("final_heatmap" in ix for ix in item):
|
| 53 |
data["final_heatmap"] = torch.stack([ix["final_heatmap"] for ix in item], dim=0)
|
| 54 |
-
|
| 55 |
batch[i] = data
|
| 56 |
return default_collate(batch)
|
| 57 |
|
| 58 |
|
| 59 |
-
|
| 60 |
def data_transform(args):
|
| 61 |
-
|
| 62 |
if args.use_heatmap:
|
| 63 |
transform = Compose(
|
| 64 |
[
|
| 65 |
-
LoadImaged(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
|
| 67 |
-
ConcatItemsd(
|
|
|
|
|
|
|
| 68 |
NormalizeIntensity_customd(keys=["image"], channel_wise=True, mask_key="mask"),
|
| 69 |
ElementwiseProductd(keys=["mask", "heatmap"], output_key="final_heatmap"),
|
| 70 |
-
RandWeightedCropd(
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
| 74 |
EnsureTyped(keys=["label"], dtype=torch.float32),
|
| 75 |
Transposed(keys=["image"], indices=(0, 3, 1, 2)),
|
| 76 |
-
DeleteItemsd(keys=[
|
| 77 |
ToTensord(keys=["image", "label", "final_heatmap"]),
|
| 78 |
]
|
| 79 |
)
|
| 80 |
else:
|
| 81 |
transform = Compose(
|
| 82 |
[
|
| 83 |
-
LoadImaged(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
|
| 85 |
-
ConcatItemsd(
|
|
|
|
|
|
|
| 86 |
NormalizeIntensityd(keys=["image"], channel_wise=True),
|
| 87 |
-
RandCropByPosNegLabeld(
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
| 93 |
EnsureTyped(keys=["label"], dtype=torch.float32),
|
| 94 |
Transposed(keys=["image"], indices=(0, 3, 1, 2)),
|
| 95 |
-
DeleteItemsd(keys=[
|
| 96 |
ToTensord(keys=["image", "label"]),
|
| 97 |
]
|
| 98 |
)
|
| 99 |
return transform
|
| 100 |
|
| 101 |
-
def get_dataloader(args, split: Literal["train", "test"]):
|
| 102 |
|
|
|
|
| 103 |
data_list = load_decathlon_datalist(
|
| 104 |
data_list_file_path=args.dataset_json,
|
| 105 |
data_list_key=split,
|
|
@@ -110,16 +112,17 @@ def get_dataloader(args, split: Literal["train", "test"]):
|
|
| 110 |
cache_dir_ = os.path.join(args.logdir, "cache")
|
| 111 |
os.makedirs(os.path.join(cache_dir_, split), exist_ok=True)
|
| 112 |
transform = data_transform(args)
|
| 113 |
-
dataset = PersistentDataset(
|
|
|
|
|
|
|
| 114 |
loader = torch.utils.data.DataLoader(
|
| 115 |
dataset,
|
| 116 |
batch_size=args.batch_size,
|
| 117 |
shuffle=(split == "train"),
|
| 118 |
num_workers=args.workers,
|
| 119 |
pin_memory=True,
|
| 120 |
-
multiprocessing_context="
|
| 121 |
sampler=None,
|
| 122 |
collate_fn=list_data_collate,
|
| 123 |
)
|
| 124 |
return loader
|
| 125 |
-
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import numpy as np
|
| 3 |
+
from monai.data import load_decathlon_datalist, ITKReader, PersistentDataset
|
|
|
|
| 4 |
from monai.transforms import (
|
| 5 |
Compose,
|
| 6 |
LoadImaged,
|
|
|
|
|
|
|
|
|
|
| 7 |
ToTensord,
|
| 8 |
+
ConcatItemsd,
|
|
|
|
|
|
|
|
|
|
| 9 |
DeleteItemsd,
|
| 10 |
EnsureTyped,
|
|
|
|
|
|
|
| 11 |
RandCropByPosNegLabeld,
|
| 12 |
NormalizeIntensityd,
|
|
|
|
|
|
|
|
|
|
| 13 |
Transposed,
|
| 14 |
RandWeightedCropd,
|
| 15 |
)
|
| 16 |
from .custom_transforms import (
|
| 17 |
+
NormalizeIntensity_customd,
|
| 18 |
+
ClipMaskIntensityPercentilesd,
|
| 19 |
ElementwiseProductd,
|
| 20 |
)
|
| 21 |
import torch
|
| 22 |
from torch.utils.data.dataloader import default_collate
|
|
|
|
| 23 |
from typing import Literal
|
|
|
|
| 24 |
import collections.abc
|
| 25 |
|
| 26 |
+
|
| 27 |
def list_data_collate(batch: collections.abc.Sequence):
|
| 28 |
"""
|
| 29 |
Combine instances from a list of dicts into a single dict, by stacking them along first dim
|
|
|
|
| 37 |
|
| 38 |
if all("final_heatmap" in ix for ix in item):
|
| 39 |
data["final_heatmap"] = torch.stack([ix["final_heatmap"] for ix in item], dim=0)
|
| 40 |
+
|
| 41 |
batch[i] = data
|
| 42 |
return default_collate(batch)
|
| 43 |
|
| 44 |
|
|
|
|
| 45 |
def data_transform(args):
    """
    Assemble the per-case loading, normalisation and patch-sampling pipeline.

    Both variants load the modalities with ITK, clip ``image`` to the 0-99.5
    percentile range computed inside ``mask``, and concatenate ``image``, ``dwi``
    and ``adc`` along dim 0 into ``image``. They then differ in how patches are drawn:

    * ``args.use_heatmap`` truthy: mask-based normalisation, ``final_heatmap`` is
      computed as ``mask * heatmap`` and used as the sampling weight map.
    * otherwise: plain channel-wise normalisation and patches centred on positive
      ``mask`` voxels only (``pos=1, neg=0``).

    Args:
        args: namespace providing ``use_heatmap``, ``tile_size``, ``depth`` and
            ``tile_count``.

    Returns:
        The composed MONAI transform.
    """
    heatmap_keys = ["heatmap"] if args.use_heatmap else []
    patch_size = (args.tile_size, args.tile_size, args.depth)

    # Shared head of the pipeline: load, clip, and stack the three modalities.
    steps = [
        LoadImaged(
            keys=["image", "mask", "dwi", "adc", *heatmap_keys],
            reader=ITKReader(),
            ensure_channel_first=True,
            dtype=np.float32,
        ),
        ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
        # NOTE(review): presumably yields (3, H, W, D) if each modality loads as one
        # channel; the 4-index transpose below implies a channel + 3 spatial axes - confirm.
        ConcatItemsd(keys=["image", "dwi", "adc"], name="image", dim=0),
    ]

    if args.use_heatmap:
        steps += [
            NormalizeIntensity_customd(keys=["image"], channel_wise=True, mask_key="mask"),
            ElementwiseProductd(keys=["mask", "heatmap"], output_key="final_heatmap"),
            RandWeightedCropd(
                keys=["image", "final_heatmap"],
                w_key="final_heatmap",
                spatial_size=patch_size,
                num_samples=args.tile_count,
            ),
        ]
        tensor_keys = ["image", "label", "final_heatmap"]
    else:
        steps += [
            NormalizeIntensityd(keys=["image"], channel_wise=True),
            RandCropByPosNegLabeld(
                keys=["image"],
                label_key="mask",
                spatial_size=patch_size,
                pos=1,
                neg=0,
                num_samples=args.tile_count,
            ),
        ]
        tensor_keys = ["image", "label"]

    # Shared tail: cast the label, move the last image axis to position 1, drop the
    # intermediate items and convert what remains to tensors.
    steps += [
        EnsureTyped(keys=["label"], dtype=torch.float32),
        Transposed(keys=["image"], indices=(0, 3, 1, 2)),
        DeleteItemsd(keys=["mask", "dwi", "adc", *heatmap_keys]),
        ToTensord(keys=tensor_keys),
    ]
    return Compose(steps)
|
| 102 |
|
|
|
|
| 103 |
|
| 104 |
+
def get_dataloader(args, split: Literal["train", "test"]):
|
| 105 |
data_list = load_decathlon_datalist(
|
| 106 |
data_list_file_path=args.dataset_json,
|
| 107 |
data_list_key=split,
|
|
|
|
| 112 |
cache_dir_ = os.path.join(args.logdir, "cache")
|
| 113 |
os.makedirs(os.path.join(cache_dir_, split), exist_ok=True)
|
| 114 |
transform = data_transform(args)
|
| 115 |
+
dataset = PersistentDataset(
|
| 116 |
+
data=data_list, transform=transform, cache_dir=os.path.join(cache_dir_, split)
|
| 117 |
+
)
|
| 118 |
loader = torch.utils.data.DataLoader(
|
| 119 |
dataset,
|
| 120 |
batch_size=args.batch_size,
|
| 121 |
shuffle=(split == "train"),
|
| 122 |
num_workers=args.workers,
|
| 123 |
pin_memory=True,
|
| 124 |
+
multiprocessing_context="fork" if args.workers > 0 else None,
|
| 125 |
sampler=None,
|
| 126 |
collate_fn=list_data_collate,
|
| 127 |
)
|
| 128 |
return loader
|
|
|
src/model/MIL.py
CHANGED
|
@@ -1,13 +1,10 @@
|
|
| 1 |
-
|
| 2 |
from __future__ import annotations
|
| 3 |
-
|
| 4 |
from typing import cast
|
| 5 |
-
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
-
|
| 9 |
from monai.utils.module import optional_import
|
| 10 |
from monai.networks.nets import resnet
|
|
|
|
| 11 |
models, _ = optional_import("torchvision.models")
|
| 12 |
|
| 13 |
|
|
@@ -16,8 +13,7 @@ class MILModel_3D(nn.Module):
|
|
| 16 |
Multiple Instance Learning (MIL) model, with a backbone classification model.
|
| 17 |
Adapted from MONAI, modified for 3D images. The expected shape of input data is `[B, N, C, D, H, W]`,
|
| 18 |
where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances
|
| 19 |
-
extracted from every original image in the batch.
|
| 20 |
-
https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning.
|
| 21 |
|
| 22 |
Args:
|
| 23 |
num_classes: number of output classes.
|
|
@@ -29,10 +25,9 @@ class MILModel_3D(nn.Module):
|
|
| 29 |
- ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556.
|
| 30 |
- ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556.
|
| 31 |
|
| 32 |
-
pretrained: init backbone with pretrained weights, defaults to ``True``.
|
| 33 |
backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features,
|
| 34 |
or a string name of a torchvision model).
|
| 35 |
-
Defaults to ``None``, in which case
|
| 36 |
backbone_num_features: Number of output features of the backbone CNN
|
| 37 |
Defaults to ``None`` (necessary only when using a custom backbone)
|
| 38 |
trans_blocks: number of the blocks in `TransformEncoder` layer.
|
|
@@ -63,7 +58,11 @@ class MILModel_3D(nn.Module):
|
|
| 63 |
self.transformer: nn.Module | None = None
|
| 64 |
|
| 65 |
if backbone is None:
|
| 66 |
-
net = resnet.resnet18(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
nfc = net.fc.in_features # save the number of final features
|
| 68 |
net.fc = torch.nn.Identity() # remove final linear layer
|
| 69 |
|
|
@@ -72,7 +71,6 @@ class MILModel_3D(nn.Module):
|
|
| 72 |
if mil_mode == "att_trans_pyramid":
|
| 73 |
# register hooks to capture outputs of intermediate layers
|
| 74 |
def forward_hook(layer_name):
|
| 75 |
-
|
| 76 |
def hook(module, input, output):
|
| 77 |
self.extra_outputs[layer_name] = output
|
| 78 |
|
|
@@ -105,31 +103,23 @@ class MILModel_3D(nn.Module):
|
|
| 105 |
nfc = backbone_num_features
|
| 106 |
net.fc = torch.nn.Identity() # remove final linear layer
|
| 107 |
|
| 108 |
-
self.extra_outputs: dict[str, torch.Tensor] = {}
|
| 109 |
-
|
| 110 |
if mil_mode == "att_trans_pyramid":
|
| 111 |
# register hooks to capture outputs of intermediate layers
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
self.extra_outputs[layer_name] = output
|
| 116 |
-
|
| 117 |
-
return hook
|
| 118 |
-
|
| 119 |
-
net.layer1.register_forward_hook(forward_hook("layer1"))
|
| 120 |
-
net.layer2.register_forward_hook(forward_hook("layer2"))
|
| 121 |
-
net.layer3.register_forward_hook(forward_hook("layer3"))
|
| 122 |
-
net.layer4.register_forward_hook(forward_hook("layer4"))
|
| 123 |
|
| 124 |
if backbone_num_features is None:
|
| 125 |
-
raise ValueError(
|
|
|
|
|
|
|
| 126 |
|
| 127 |
else:
|
| 128 |
raise ValueError("Unsupported backbone")
|
| 129 |
-
|
| 130 |
if backbone is not None and mil_mode not in ["mean", "max", "att", "att_trans"]:
|
| 131 |
raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode))
|
| 132 |
-
|
| 133 |
if self.mil_mode in ["mean", "max"]:
|
| 134 |
pass
|
| 135 |
elif self.mil_mode == "att":
|
|
@@ -144,7 +134,8 @@ class MILModel_3D(nn.Module):
|
|
| 144 |
transformer_list = nn.ModuleList(
|
| 145 |
[
|
| 146 |
nn.TransformerEncoder(
|
| 147 |
-
nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout),
|
|
|
|
| 148 |
),
|
| 149 |
nn.Sequential(
|
| 150 |
nn.Linear(192, 64),
|
|
@@ -206,10 +197,26 @@ class MILModel_3D(nn.Module):
|
|
| 206 |
x = self.myfc(x)
|
| 207 |
|
| 208 |
elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None:
|
| 209 |
-
l1 =
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
transformer_list = cast(nn.ModuleList, self.transformer)
|
| 215 |
|
|
@@ -242,7 +249,3 @@ class MILModel_3D(nn.Module):
|
|
| 242 |
x = self.calc_head(x)
|
| 243 |
|
| 244 |
return x
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
|
|
|
| 2 |
from typing import cast
|
|
|
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
|
|
|
| 5 |
from monai.utils.module import optional_import
|
| 6 |
from monai.networks.nets import resnet
|
| 7 |
+
|
| 8 |
models, _ = optional_import("torchvision.models")
|
| 9 |
|
| 10 |
|
|
|
|
| 13 |
Multiple Instance Learning (MIL) model, with a backbone classification model.
|
| 14 |
Adapted from MONAI, modified for 3D images. The expected shape of input data is `[B, N, C, D, H, W]`,
|
| 15 |
where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances
|
| 16 |
+
extracted from every original image in the batch.
|
|
|
|
| 17 |
|
| 18 |
Args:
|
| 19 |
num_classes: number of output classes.
|
|
|
|
| 25 |
- ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556.
|
| 26 |
- ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556.
|
| 27 |
|
|
|
|
| 28 |
backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features,
|
| 29 |
or a string name of a torchvision model).
|
| 30 |
+
Defaults to ``None``, in which case ResNet18 is used.
|
| 31 |
backbone_num_features: Number of output features of the backbone CNN
|
| 32 |
Defaults to ``None`` (necessary only when using a custom backbone)
|
| 33 |
trans_blocks: number of the blocks in `TransformEncoder` layer.
|
|
|
|
| 58 |
self.transformer: nn.Module | None = None
|
| 59 |
|
| 60 |
if backbone is None:
|
| 61 |
+
net = resnet.resnet18(
|
| 62 |
+
spatial_dims=3,
|
| 63 |
+
n_input_channels=3,
|
| 64 |
+
num_classes=5,
|
| 65 |
+
)
|
| 66 |
nfc = net.fc.in_features # save the number of final features
|
| 67 |
net.fc = torch.nn.Identity() # remove final linear layer
|
| 68 |
|
|
|
|
| 71 |
if mil_mode == "att_trans_pyramid":
|
| 72 |
# register hooks to capture outputs of intermediate layers
|
| 73 |
def forward_hook(layer_name):
|
|
|
|
| 74 |
def hook(module, input, output):
|
| 75 |
self.extra_outputs[layer_name] = output
|
| 76 |
|
|
|
|
| 103 |
nfc = backbone_num_features
|
| 104 |
net.fc = torch.nn.Identity() # remove final linear layer
|
| 105 |
|
|
|
|
|
|
|
| 106 |
if mil_mode == "att_trans_pyramid":
|
| 107 |
# register hooks to capture outputs of intermediate layers
|
| 108 |
+
raise ValueError(
|
| 109 |
+
"Cannot use att_trans_pyramid with custom backbone. Have to use the default ResNet 18 backbone."
|
| 110 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
if backbone_num_features is None:
|
| 113 |
+
raise ValueError(
|
| 114 |
+
"Number of endencoder features must be provided for a custom backbone model"
|
| 115 |
+
)
|
| 116 |
|
| 117 |
else:
|
| 118 |
raise ValueError("Unsupported backbone")
|
| 119 |
+
|
| 120 |
if backbone is not None and mil_mode not in ["mean", "max", "att", "att_trans"]:
|
| 121 |
raise ValueError("Custom backbone is not supported for the mode:" + str(mil_mode))
|
| 122 |
+
|
| 123 |
if self.mil_mode in ["mean", "max"]:
|
| 124 |
pass
|
| 125 |
elif self.mil_mode == "att":
|
|
|
|
| 134 |
transformer_list = nn.ModuleList(
|
| 135 |
[
|
| 136 |
nn.TransformerEncoder(
|
| 137 |
+
nn.TransformerEncoderLayer(d_model=64, nhead=8, dropout=trans_dropout),
|
| 138 |
+
num_layers=trans_blocks,
|
| 139 |
),
|
| 140 |
nn.Sequential(
|
| 141 |
nn.Linear(192, 64),
|
|
|
|
| 197 |
x = self.myfc(x)
|
| 198 |
|
| 199 |
elif self.mil_mode == "att_trans_pyramid" and self.transformer is not None:
|
| 200 |
+
l1 = (
|
| 201 |
+
torch.mean(self.extra_outputs["layer1"], dim=(2, 3, 4))
|
| 202 |
+
.reshape(sh[0], sh[1], -1)
|
| 203 |
+
.permute(1, 0, 2)
|
| 204 |
+
)
|
| 205 |
+
l2 = (
|
| 206 |
+
torch.mean(self.extra_outputs["layer2"], dim=(2, 3, 4))
|
| 207 |
+
.reshape(sh[0], sh[1], -1)
|
| 208 |
+
.permute(1, 0, 2)
|
| 209 |
+
)
|
| 210 |
+
l3 = (
|
| 211 |
+
torch.mean(self.extra_outputs["layer3"], dim=(2, 3, 4))
|
| 212 |
+
.reshape(sh[0], sh[1], -1)
|
| 213 |
+
.permute(1, 0, 2)
|
| 214 |
+
)
|
| 215 |
+
l4 = (
|
| 216 |
+
torch.mean(self.extra_outputs["layer4"], dim=(2, 3, 4))
|
| 217 |
+
.reshape(sh[0], sh[1], -1)
|
| 218 |
+
.permute(1, 0, 2)
|
| 219 |
+
)
|
| 220 |
|
| 221 |
transformer_list = cast(nn.ModuleList, self.transformer)
|
| 222 |
|
|
|
|
| 249 |
x = self.calc_head(x)
|
| 250 |
|
| 251 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
src/model/csPCa_model.py
CHANGED
|
@@ -1,32 +1,72 @@
|
|
| 1 |
-
|
| 2 |
from __future__ import annotations
|
| 3 |
-
|
| 4 |
-
from typing import cast
|
| 5 |
-
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
-
|
| 9 |
from monai.utils.module import optional_import
|
| 10 |
-
models, _ = optional_import("torchvision.models")
|
| 11 |
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class SimpleNN(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def __init__(self, input_dim):
|
| 16 |
super(SimpleNN, self).__init__()
|
| 17 |
self.net = nn.Sequential(
|
| 18 |
nn.Linear(input_dim, 256),
|
| 19 |
nn.ReLU(),
|
| 20 |
-
nn.Linear(
|
| 21 |
nn.ReLU(),
|
| 22 |
-
nn.Dropout(p=0.3),
|
| 23 |
nn.Linear(128, 1),
|
| 24 |
-
nn.Sigmoid()
|
| 25 |
)
|
|
|
|
| 26 |
def forward(self, x):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
return self.net(x)
|
| 28 |
|
|
|
|
| 29 |
class csPCa_Model(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def __init__(self, backbone):
|
| 31 |
super().__init__()
|
| 32 |
self.backbone = backbone
|
|
@@ -47,4 +87,3 @@ class csPCa_Model(nn.Module):
|
|
| 47 |
|
| 48 |
x = self.fc_cspca(x)
|
| 49 |
return x
|
| 50 |
-
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
|
|
|
| 4 |
from monai.utils.module import optional_import
|
|
|
|
| 5 |
|
| 6 |
+
models, _ = optional_import("torchvision.models")
|
| 7 |
|
| 8 |
|
| 9 |
class SimpleNN(nn.Module):
    """
    Small fully connected head that maps a feature vector to a single probability.

    The layer stack, in order, is:
    ``Linear(input_dim, 256) -> ReLU -> Linear(256, 128) -> ReLU -> Dropout(p=0.3)
    -> Linear(128, 1) -> Sigmoid``.

    Args:
        input_dim (int): Size of the last dimension of the input tensor.
    """

    def __init__(self, input_dim):
        super().__init__()
        wide, narrow = 256, 128
        # The position of each layer inside the Sequential defines the
        # state-dict keys (net.0, net.2, net.5), so the order must not change.
        stack = [
            nn.Linear(input_dim, wide),
            nn.ReLU(),
            nn.Linear(wide, narrow),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(narrow, 1),
            nn.Sigmoid(),  # squashes the single logit into (0, 1)
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """
        Run the input through the layer stack.

        Args:
            x (torch.Tensor): Tensor whose last dimension has size ``input_dim``.

        Returns:
            torch.Tensor: Sigmoid outputs with last dimension of size 1.
        """
        probabilities = self.net(x)
        return probabilities
|
| 43 |
|
| 44 |
+
|
| 45 |
class csPCa_Model(nn.Module):
|
| 46 |
+
"""
|
| 47 |
+
Clinically Significant Prostate Cancer (csPCa) risk prediction model using a MIL backbone.
|
| 48 |
+
|
| 49 |
+
This model repurposes a pre-trained Multiple Instance Learning (MIL) backbone (originally
|
| 50 |
+
designed for PI-RADS prediction) for binary csPCa risk assessment. It utilizes the
|
| 51 |
+
backbone's feature extractor, transformer, and attention mechanism to aggregate instance-level
|
| 52 |
+
features into a bag-level embedding.
|
| 53 |
+
|
| 54 |
+
The original fully connected classification head of the backbone is replaced by a
|
| 55 |
+
custom :class:`SimpleNN` head for the new task.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
backbone (nn.Module): A pre-trained MIL model. The backbone must possess the
|
| 59 |
+
following attributes/sub-modules:
|
| 60 |
+
- ``net``: The CNN feature extractor.
|
| 61 |
+
- ``transformer``: A sequence modeling module.
|
| 62 |
+
- ``attention``: An attention mechanism for pooling.
|
| 63 |
+
- ``myfc``: The original fully connected layer (used to determine feature dimensions).
|
| 64 |
+
|
| 65 |
+
Attributes:
|
| 66 |
+
fc_cspca (SimpleNN): The new classification head for csPCa prediction.
|
| 67 |
+
backbone: The MIL based PI-RADS classifier.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
def __init__(self, backbone):
|
| 71 |
super().__init__()
|
| 72 |
self.backbone = backbone
|
|
|
|
| 87 |
|
| 88 |
x = self.fc_cspca(x)
|
| 89 |
return x
|
|
|
src/preprocessing/center_crop.py
CHANGED
|
@@ -8,10 +8,10 @@
|
|
| 8 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
# See the License for the specific language governing permissions and
|
| 10 |
# limitations under the License.
|
| 11 |
-
#python scripts/center_crop.py --file_name path/to/t2_image --out_name cropped_t2
|
| 12 |
|
| 13 |
|
| 14 |
-
#import argparse
|
| 15 |
from typing import Union
|
| 16 |
|
| 17 |
import SimpleITK as sitk # noqa N813
|
|
@@ -41,7 +41,9 @@ def crop(image: sitk.Image, margin: Union[int, float], interpolator=sitk.sitkLin
|
|
| 41 |
|
| 42 |
# calculate new origin and new image size
|
| 43 |
if all([isinstance(m, float) for m in _flatten(margin)]):
|
| 44 |
-
assert all([m >= 0 and m < 0.5 for m in _flatten(margin)]),
|
|
|
|
|
|
|
| 45 |
to_crop = [[int(sz * _m) for _m in m] for sz, m in zip(old_size, margin)]
|
| 46 |
elif all([isinstance(m, int) for m in _flatten(margin)]):
|
| 47 |
to_crop = margin
|
|
@@ -61,4 +63,3 @@ def crop(image: sitk.Image, margin: Union[int, float], interpolator=sitk.sitkLin
|
|
| 61 |
ref_image.SetDirection(image.GetDirection())
|
| 62 |
|
| 63 |
return sitk.Resample(image, ref_image, interpolator=interpolator)
|
| 64 |
-
|
|
|
|
| 8 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
# See the License for the specific language governing permissions and
|
| 10 |
# limitations under the License.
|
| 11 |
+
# python scripts/center_crop.py --file_name path/to/t2_image --out_name cropped_t2
|
| 12 |
|
| 13 |
|
| 14 |
+
# import argparse
|
| 15 |
from typing import Union
|
| 16 |
|
| 17 |
import SimpleITK as sitk # noqa N813
|
|
|
|
| 41 |
|
| 42 |
# calculate new origin and new image size
|
| 43 |
if all([isinstance(m, float) for m in _flatten(margin)]):
|
| 44 |
+
assert all([m >= 0 and m < 0.5 for m in _flatten(margin)]), (
|
| 45 |
+
"margins must be between 0 and 0.5"
|
| 46 |
+
)
|
| 47 |
to_crop = [[int(sz * _m) for _m in m] for sz, m in zip(old_size, margin)]
|
| 48 |
elif all([isinstance(m, int) for m in _flatten(margin)]):
|
| 49 |
to_crop = margin
|
|
|
|
| 63 |
ref_image.SetDirection(image.GetDirection())
|
| 64 |
|
| 65 |
return sitk.Resample(image, ref_image, interpolator=interpolator)
|
|
|
src/preprocessing/generate_heatmap.py
CHANGED
|
@@ -1,76 +1,93 @@
|
|
| 1 |
-
|
| 2 |
import os
|
| 3 |
import numpy as np
|
| 4 |
import nrrd
|
| 5 |
-
import
|
| 6 |
-
import pandas as pd
|
| 7 |
-
import json
|
| 8 |
-
import SimpleITK as sitk
|
| 9 |
-
import multiprocessing
|
| 10 |
-
|
| 11 |
import logging
|
| 12 |
|
| 13 |
|
| 14 |
def get_heatmap(args):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
files = os.listdir(args.t2_dir)
|
| 17 |
-
args.heatmapdir = os.path.join(args.output_dir,
|
| 18 |
os.makedirs(args.heatmapdir, exist_ok=True)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
bool_dwi = False
|
| 22 |
bool_adc = False
|
| 23 |
mask, _ = nrrd.read(os.path.join(args.seg_dir, file))
|
| 24 |
-
dwi,
|
| 25 |
-
adc,
|
| 26 |
-
|
| 27 |
nonzero_vals_dwi = dwi[mask > 0]
|
|
|
|
| 28 |
if len(nonzero_vals_dwi) > 0:
|
| 29 |
min_val = nonzero_vals_dwi.min()
|
| 30 |
max_val = nonzero_vals_dwi.max()
|
| 31 |
heatmap_dwi = np.zeros_like(dwi, dtype=np.float32)
|
| 32 |
-
|
| 33 |
if min_val != max_val:
|
| 34 |
heatmap_dwi = (dwi - min_val) / (max_val - min_val)
|
| 35 |
-
masked_heatmap_dwi = np.where(mask > 0, heatmap_dwi, heatmap_dwi[mask>0].min())
|
| 36 |
else:
|
| 37 |
bool_dwi = True
|
| 38 |
-
|
| 39 |
else:
|
| 40 |
bool_dwi = True
|
| 41 |
|
| 42 |
nonzero_vals_adc = adc[mask > 0]
|
|
|
|
| 43 |
if len(nonzero_vals_adc) > 0:
|
| 44 |
min_val = nonzero_vals_adc.min()
|
| 45 |
max_val = nonzero_vals_adc.max()
|
| 46 |
heatmap_adc = np.zeros_like(adc, dtype=np.float32)
|
| 47 |
-
|
| 48 |
if min_val != max_val:
|
| 49 |
heatmap_adc = (max_val - adc) / (max_val - min_val)
|
| 50 |
-
masked_heatmap_adc = np.where(mask > 0, heatmap_adc, heatmap_adc[mask>0].min())
|
| 51 |
else:
|
| 52 |
bool_adc = True
|
| 53 |
|
| 54 |
else:
|
| 55 |
bool_adc = True
|
| 56 |
-
|
| 57 |
|
| 58 |
-
if bool_dwi:
|
| 59 |
-
mix_mask = masked_heatmap_adc
|
| 60 |
-
if bool_adc:
|
| 61 |
-
mix_mask = masked_heatmap_dwi
|
| 62 |
if not bool_dwi and not bool_adc:
|
| 63 |
mix_mask = masked_heatmap_dwi * masked_heatmap_adc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
else:
|
| 65 |
mix_mask = np.ones_like(adc, dtype=np.float32)
|
|
|
|
| 66 |
|
| 67 |
mix_mask = (mix_mask - mix_mask.min()) / (mix_mask.max() - mix_mask.min())
|
|
|
|
| 68 |
|
| 69 |
-
|
| 70 |
-
nrrd.write(os.path.join(args.heatmapdir, file), mix_mask)
|
| 71 |
-
|
| 72 |
return args
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import numpy as np
|
| 3 |
import nrrd
|
| 4 |
+
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import logging
|
| 6 |
|
| 7 |
|
| 8 |
def _masked_heatmap(volume, mask, invert=False):
    """Min-max normalise ``volume`` using the intensity range found inside ``mask``.

    Args:
        volume: Image array.
        mask: Array indexable against ``volume``; voxels with ``mask > 0`` define
            the region whose min/max drive the normalisation.
        invert: When True use ``(max - v) / (max - min)`` (low intensities map to
            high heat); otherwise ``(v - min) / (max - min)``.

    Returns:
        The normalised array with every voxel outside the mask replaced by the
        minimum normalised value found inside the mask, or ``None`` when the
        mask selects no voxels or all selected voxels share a single value.
    """
    inside = mask > 0
    values = volume[inside]
    if len(values) == 0:
        return None

    min_val = values.min()
    max_val = values.max()
    if min_val == max_val:
        # A flat region carries no contrast and would divide by zero.
        return None

    if invert:
        heatmap = (max_val - volume) / (max_val - min_val)
    else:
        heatmap = (volume - min_val) / (max_val - min_val)
    return np.where(inside, heatmap, heatmap[inside].min())


def get_heatmap(args):
    """
    Build one heatmap per case from DWI and ADC volumes and a segmentation mask.

    For every file name listed in ``args.t2_dir`` the mask, DWI and ADC volumes of
    the same name are read from ``args.seg_dir``, ``args.dwi_dir`` and
    ``args.adc_dir``. Each modality is min-max normalised inside the mask (ADC is
    inverted), the usable maps are combined, and the result is written to
    ``<args.output_dir>/heatmaps/`` under the same file name.

    Args:
        args: An object providing the attributes ``t2_dir``, ``dwi_dir``,
            ``adc_dir``, ``seg_dir`` and ``output_dir``.

    Returns:
        args: The same object, with ``heatmapdir`` set to the output directory.

    Notes:
        - DWI heatmap: ``(dwi - min) / (max - min)``.
        - ADC heatmap: ``(max - adc) / (max - min)`` (inverted).
        - Both usable: element-wise product, written with the DWI header.
        - Only ADC usable: ADC heatmap, written with the ADC header.
        - Only DWI usable: DWI heatmap, written with the DWI header.
        - Neither usable (empty mask or flat intensities in both): an all-ones
          map, written with the DWI header.
        - The combined map is rescaled to [0, 1] unless it is constant, in which
          case it is written unchanged to avoid a division by zero.
        - Errors raised by ``os.listdir`` / ``nrrd.read`` / ``nrrd.write`` are not
          caught here.
    """

    files = os.listdir(args.t2_dir)
    args.heatmapdir = os.path.join(args.output_dir, "heatmaps/")
    os.makedirs(args.heatmapdir, exist_ok=True)
    logging.info("Starting heatmap generation")
    for file in tqdm(files):
        mask, _ = nrrd.read(os.path.join(args.seg_dir, file))
        dwi, header_dwi = nrrd.read(os.path.join(args.dwi_dir, file))
        adc, header_adc = nrrd.read(os.path.join(args.adc_dir, file))

        masked_heatmap_dwi = _masked_heatmap(dwi, mask)
        masked_heatmap_adc = _masked_heatmap(adc, mask, invert=True)

        # The "neither usable" case is decided first so that a missing map is
        # never read (previously this path hit an unbound or stale ADC map).
        if masked_heatmap_dwi is None and masked_heatmap_adc is None:
            logging.warning("No usable DWI/ADC signal inside the mask for %s; writing uniform heatmap", file)
            mix_mask = np.ones_like(adc, dtype=np.float32)
            write_header = header_dwi
        elif masked_heatmap_dwi is None:
            mix_mask = masked_heatmap_adc
            write_header = header_adc
        elif masked_heatmap_adc is None:
            mix_mask = masked_heatmap_dwi
            write_header = header_dwi
        else:
            mix_mask = masked_heatmap_dwi * masked_heatmap_adc
            write_header = header_dwi

        # Rescale to [0, 1]; a constant map has zero range and is left as is
        # instead of being turned into NaNs.
        low = mix_mask.min()
        span = mix_mask.max() - low
        if span > 0:
            mix_mask = (mix_mask - low) / span
        nrrd.write(os.path.join(args.heatmapdir, file), mix_mask, write_header)

    return args
|
|
|
|
|
|
|
|
|
|
|
|
src/preprocessing/histogram_match.py
CHANGED
|
@@ -1,16 +1,30 @@
|
|
| 1 |
-
import SimpleITK as sitk
|
| 2 |
import os
|
| 3 |
-
import numpy as np
|
| 4 |
import nrrd
|
| 5 |
-
from tqdm import tqdm
|
| 6 |
-
import pandas as pd
|
| 7 |
-
import random
|
| 8 |
-
import json
|
| 9 |
from skimage import exposure
|
| 10 |
-
import multiprocessing
|
| 11 |
import logging
|
| 12 |
-
|
|
|
|
|
|
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
source_pixels = data[mask > 0]
|
| 15 |
ref_pixels = ref_data[ref_mask > 0]
|
| 16 |
matched_pixels = exposure.match_histograms(source_pixels, ref_pixels)
|
|
@@ -19,36 +33,35 @@ def get_histmatched(data, ref_data, mask, ref_mask):
|
|
| 19 |
|
| 20 |
return matched_img
|
| 21 |
|
| 22 |
-
def histmatch(args):
|
| 23 |
|
|
|
|
| 24 |
files = os.listdir(args.t2_dir)
|
| 25 |
-
|
| 26 |
-
t2_histmatched_dir = os.path.join(args.output_dir,
|
| 27 |
-
dwi_histmatched_dir = os.path.join(args.output_dir,
|
| 28 |
-
adc_histmatched_dir = os.path.join(args.output_dir,
|
| 29 |
os.makedirs(t2_histmatched_dir, exist_ok=True)
|
| 30 |
os.makedirs(dwi_histmatched_dir, exist_ok=True)
|
| 31 |
os.makedirs(adc_histmatched_dir, exist_ok=True)
|
| 32 |
logging.info("Starting histogram matching")
|
| 33 |
-
for file in files:
|
| 34 |
|
|
|
|
| 35 |
t2_image, header_t2 = nrrd.read(os.path.join(args.t2_dir, file))
|
| 36 |
dwi_image, header_dwi = nrrd.read(os.path.join(args.dwi_dir, file))
|
| 37 |
adc_image, header_adc = nrrd.read(os.path.join(args.adc_dir, file))
|
| 38 |
|
| 39 |
-
ref_t2, _ = nrrd.read(os.path.join(args.project_dir,
|
| 40 |
-
ref_dwi, _ = nrrd.read(os.path.join(args.project_dir,
|
| 41 |
-
ref_adc
|
| 42 |
-
|
| 43 |
prostate_mask, _ = nrrd.read(os.path.join(args.seg_dir, file))
|
| 44 |
-
ref_prostate_mask, _ = nrrd.read(
|
|
|
|
|
|
|
| 45 |
|
| 46 |
histmatched_t2 = get_histmatched(t2_image, ref_t2, prostate_mask, ref_prostate_mask)
|
| 47 |
histmatched_dwi = get_histmatched(dwi_image, ref_dwi, prostate_mask, ref_prostate_mask)
|
| 48 |
histmatched_adc = get_histmatched(adc_image, ref_adc, prostate_mask, ref_prostate_mask)
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
nrrd.write(os.path.join(t2_histmatched_dir, file), histmatched_t2, header_t2)
|
| 53 |
nrrd.write(os.path.join(dwi_histmatched_dir, file), histmatched_dwi, header_dwi)
|
| 54 |
nrrd.write(os.path.join(adc_histmatched_dir, file), histmatched_adc, header_adc)
|
|
@@ -58,5 +71,3 @@ def histmatch(args):
|
|
| 58 |
args.adc_dir = adc_histmatched_dir
|
| 59 |
|
| 60 |
return args
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import nrrd
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from skimage import exposure
|
|
|
|
| 4 |
import logging
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
|
| 9 |
+
def get_histmatched(
|
| 10 |
+
data: np.ndarray, ref_data: np.ndarray, mask: np.ndarray, ref_mask: np.ndarray
|
| 11 |
+
) -> np.ndarray:
|
| 12 |
+
"""
|
| 13 |
+
Perform histogram matching on source data using a reference image.
|
| 14 |
+
This function adjusts the histogram of the source image to match the
|
| 15 |
+
histogram of the reference image within masked regions. Only pixels
|
| 16 |
+
where the mask is greater than 0 are considered for matching.
|
| 17 |
+
Args:
|
| 18 |
+
data: Source image array to be histogram matched.
|
| 19 |
+
ref_data: Reference image array whose histogram will be used as target.
|
| 20 |
+
mask: Binary mask for source image indicating valid pixels (values > 0).
|
| 21 |
+
ref_mask: Binary mask for reference image indicating valid pixels (values > 0).
|
| 22 |
+
Returns:
|
| 23 |
+
Histogram-matched image with the same shape as input data.
|
| 24 |
+
Only pixels in masked regions are modified; unmasked pixels remain unchanged.
|
| 25 |
+
Example:
|
| 26 |
+
>>> matched = get_histmatched(source_img, reference_img, source_mask, ref_mask)
|
| 27 |
+
"""
|
| 28 |
source_pixels = data[mask > 0]
|
| 29 |
ref_pixels = ref_data[ref_mask > 0]
|
| 30 |
matched_pixels = exposure.match_histograms(source_pixels, ref_pixels)
|
|
|
|
| 33 |
|
| 34 |
return matched_img
|
| 35 |
|
|
|
|
| 36 |
|
| 37 |
+
def histmatch(args):
|
| 38 |
files = os.listdir(args.t2_dir)
|
| 39 |
+
|
| 40 |
+
t2_histmatched_dir = os.path.join(args.output_dir, "t2_histmatched")
|
| 41 |
+
dwi_histmatched_dir = os.path.join(args.output_dir, "DWI_histmatched")
|
| 42 |
+
adc_histmatched_dir = os.path.join(args.output_dir, "ADC_histmatched")
|
| 43 |
os.makedirs(t2_histmatched_dir, exist_ok=True)
|
| 44 |
os.makedirs(dwi_histmatched_dir, exist_ok=True)
|
| 45 |
os.makedirs(adc_histmatched_dir, exist_ok=True)
|
| 46 |
logging.info("Starting histogram matching")
|
|
|
|
| 47 |
|
| 48 |
+
for file in tqdm(files):
|
| 49 |
t2_image, header_t2 = nrrd.read(os.path.join(args.t2_dir, file))
|
| 50 |
dwi_image, header_dwi = nrrd.read(os.path.join(args.dwi_dir, file))
|
| 51 |
adc_image, header_adc = nrrd.read(os.path.join(args.adc_dir, file))
|
| 52 |
|
| 53 |
+
ref_t2, _ = nrrd.read(os.path.join(args.project_dir, "dataset", "t2_reference.nrrd"))
|
| 54 |
+
ref_dwi, _ = nrrd.read(os.path.join(args.project_dir, "dataset", "dwi_reference.nrrd"))
|
| 55 |
+
ref_adc, _ = nrrd.read(os.path.join(args.project_dir, "dataset", "adc_reference.nrrd"))
|
|
|
|
| 56 |
prostate_mask, _ = nrrd.read(os.path.join(args.seg_dir, file))
|
| 57 |
+
ref_prostate_mask, _ = nrrd.read(
|
| 58 |
+
os.path.join(args.project_dir, "dataset", "prostate_segmentation_reference.nrrd")
|
| 59 |
+
)
|
| 60 |
|
| 61 |
histmatched_t2 = get_histmatched(t2_image, ref_t2, prostate_mask, ref_prostate_mask)
|
| 62 |
histmatched_dwi = get_histmatched(dwi_image, ref_dwi, prostate_mask, ref_prostate_mask)
|
| 63 |
histmatched_adc = get_histmatched(adc_image, ref_adc, prostate_mask, ref_prostate_mask)
|
| 64 |
|
|
|
|
|
|
|
| 65 |
nrrd.write(os.path.join(t2_histmatched_dir, file), histmatched_t2, header_t2)
|
| 66 |
nrrd.write(os.path.join(dwi_histmatched_dir, file), histmatched_dwi, header_dwi)
|
| 67 |
nrrd.write(os.path.join(adc_histmatched_dir, file), histmatched_adc, header_adc)
|
|
|
|
| 71 |
args.adc_dir = adc_histmatched_dir
|
| 72 |
|
| 73 |
return args
|
|
|
|
|
|
src/preprocessing/prostate_mask.py
CHANGED
|
@@ -1,65 +1,63 @@
|
|
| 1 |
import os
|
| 2 |
-
from typing import Union
|
| 3 |
-
import SimpleITK as sitk
|
| 4 |
import numpy as np
|
| 5 |
import nrrd
|
| 6 |
-
import matplotlib.pyplot as plt
|
| 7 |
from tqdm import tqdm
|
| 8 |
-
from AIAH_utility.viewer import BasicViewer, ListViewer
|
| 9 |
-
from PIL import Image
|
| 10 |
-
import monai
|
| 11 |
from monai.bundle import ConfigParser
|
| 12 |
-
from monai.config import print_config
|
| 13 |
import torch
|
| 14 |
-
import sys
|
| 15 |
-
import os
|
| 16 |
-
import nibabel as nib
|
| 17 |
-
import shutil
|
| 18 |
|
| 19 |
-
from tqdm import trange, tqdm
|
| 20 |
|
| 21 |
-
from monai.data import DataLoader, Dataset, TestTimeAugmentation, create_test_image_2d
|
| 22 |
-
from monai.losses import DiceLoss
|
| 23 |
-
from monai.metrics import DiceMetric
|
| 24 |
-
from monai.networks.nets import UNet
|
| 25 |
from monai.transforms import (
|
| 26 |
-
Activationsd,
|
| 27 |
-
AsDiscreted,
|
| 28 |
Compose,
|
| 29 |
-
CropForegroundd,
|
| 30 |
-
DivisiblePadd,
|
| 31 |
-
Invertd,
|
| 32 |
LoadImaged,
|
| 33 |
ScaleIntensityd,
|
| 34 |
-
RandRotated,
|
| 35 |
-
RandRotate,
|
| 36 |
-
InvertibleTransform,
|
| 37 |
-
RandFlipd,
|
| 38 |
-
Activations,
|
| 39 |
-
AsDiscrete,
|
| 40 |
NormalizeIntensityd,
|
| 41 |
)
|
| 42 |
from monai.utils import set_determinism
|
| 43 |
from monai.transforms import (
|
| 44 |
-
Resize,
|
| 45 |
EnsureChannelFirstd,
|
| 46 |
Orientationd,
|
| 47 |
Spacingd,
|
| 48 |
EnsureTyped,
|
| 49 |
)
|
| 50 |
-
import nrrd
|
| 51 |
-
|
| 52 |
-
set_determinism(43)
|
| 53 |
from monai.data import MetaTensor
|
| 54 |
-
import SimpleITK as sitk
|
| 55 |
-
import pandas as pd
|
| 56 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
def get_segmask(args):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
args.seg_dir = os.path.join(args.output_dir, "prostate_mask")
|
| 60 |
os.makedirs(args.seg_dir, exist_ok=True)
|
| 61 |
|
| 62 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
|
| 63 |
model_config_file = os.path.join(args.project_dir, "config", "inference.json")
|
| 64 |
model_config = ConfigParser()
|
| 65 |
model_config.read_config(model_config_file)
|
|
@@ -67,24 +65,16 @@ def get_segmask(args):
|
|
| 67 |
model_config["dataset_dir"] = args.t2_dir
|
| 68 |
files = os.listdir(args.t2_dir)
|
| 69 |
model_config["datalist"] = [os.path.join(args.t2_dir, f) for f in files]
|
| 70 |
-
|
| 71 |
-
|
| 72 |
checkpoint = os.path.join(
|
| 73 |
args.project_dir,
|
| 74 |
"models",
|
| 75 |
"prostate_segmentation_model.pt",
|
| 76 |
)
|
| 77 |
-
preprocessing = model_config.get_parsed_content("preprocessing")
|
| 78 |
model = model_config.get_parsed_content("network_def").to(device)
|
| 79 |
inferer = model_config.get_parsed_content("inferer")
|
| 80 |
-
postprocessing = model_config.get_parsed_content("postprocessing")
|
| 81 |
-
dataloader = model_config.get_parsed_content("dataloader")
|
| 82 |
model.load_state_dict(torch.load(checkpoint, map_location=device))
|
| 83 |
model.eval()
|
| 84 |
|
| 85 |
-
torch.cuda.empty_cache()
|
| 86 |
-
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 87 |
-
|
| 88 |
keys = "image"
|
| 89 |
transform = Compose(
|
| 90 |
[
|
|
@@ -99,8 +89,8 @@ def get_segmask(args):
|
|
| 99 |
)
|
| 100 |
logging.info("Starting prostate segmentation")
|
| 101 |
for file in tqdm(files):
|
| 102 |
-
|
| 103 |
data = {"image": os.path.join(args.t2_dir, file)}
|
|
|
|
| 104 |
transformed_data = transform(data)
|
| 105 |
a = transformed_data
|
| 106 |
with torch.no_grad():
|
|
@@ -114,15 +104,12 @@ def get_segmask(args):
|
|
| 114 |
temp = transform.inverse(transformed_data)
|
| 115 |
pred_temp = temp["image"][0].numpy()
|
| 116 |
pred_nrrd = np.round(pred_temp)
|
| 117 |
-
|
| 118 |
-
nonzero_counts = np.count_nonzero(pred_nrrd, axis=(0,1))
|
| 119 |
top_slices = np.argsort(nonzero_counts)[-10:]
|
| 120 |
output_ = np.zeros_like(pred_nrrd)
|
| 121 |
-
output_[:,:,top_slices] = pred_nrrd[:,:,top_slices]
|
| 122 |
-
|
| 123 |
-
nrrd.write(os.path.join(args.seg_dir, file), output_)
|
| 124 |
-
|
| 125 |
-
return args
|
| 126 |
|
|
|
|
| 127 |
|
| 128 |
-
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import nrrd
|
|
|
|
| 4 |
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
| 5 |
from monai.bundle import ConfigParser
|
|
|
|
| 6 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
|
|
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from monai.transforms import (
|
|
|
|
|
|
|
| 10 |
Compose,
|
|
|
|
|
|
|
|
|
|
| 11 |
LoadImaged,
|
| 12 |
ScaleIntensityd,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
NormalizeIntensityd,
|
| 14 |
)
|
| 15 |
from monai.utils import set_determinism
|
| 16 |
from monai.transforms import (
|
|
|
|
| 17 |
EnsureChannelFirstd,
|
| 18 |
Orientationd,
|
| 19 |
Spacingd,
|
| 20 |
EnsureTyped,
|
| 21 |
)
|
|
|
|
|
|
|
|
|
|
| 22 |
from monai.data import MetaTensor
|
|
|
|
|
|
|
| 23 |
import logging
|
| 24 |
+
|
| 25 |
+
set_determinism(43)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
def get_segmask(args):
|
| 29 |
+
"""
|
| 30 |
+
Generate prostate segmentation masks using a pre-trained deep learning model.
|
| 31 |
+
This function performs inference on T2-weighted MRI images to segment the prostate gland.
|
| 32 |
+
It applies preprocessing transformations, runs the segmentation model, and saves the
|
| 33 |
+
predicted masks. Post-processing is applied to retain only the top 10 slices with
|
| 34 |
+
the highest non-zero voxel counts.
|
| 35 |
+
Args:
|
| 36 |
+
args: An arguments object containing:
|
| 37 |
+
- output_dir (str): Base output directory where segmentation masks will be saved
|
| 38 |
+
- project_dir (str): Root project directory containing model config and checkpoint
|
| 39 |
+
- t2_dir (str): Directory containing input T2-weighted MRI images in NRRD format
|
| 40 |
+
Returns:
|
| 41 |
+
args: The updated arguments object with seg_dir added, pointing to the
|
| 42 |
+
prostate_mask subdirectory within output_dir
|
| 43 |
+
Raises:
|
| 44 |
+
FileNotFoundError: If the model checkpoint or config file is not found
|
| 45 |
+
RuntimeError: If CUDA operations fail on GPU
|
| 46 |
+
Notes:
|
| 47 |
+
- Automatically selects GPU (CUDA) if available, otherwise uses CPU
|
| 48 |
+
- Applies MONAI transformations: loading, orientation (RAS), spacing (0.5mm isotropic),
|
| 49 |
+
intensity scaling and normalization
|
| 50 |
+
- Post-processing filters predictions to top 10 slices by non-zero voxel density
|
| 51 |
+
- Output masks are saved in NRRD format preserving original image headers
|
| 52 |
+
"""
|
| 53 |
|
| 54 |
args.seg_dir = os.path.join(args.output_dir, "prostate_mask")
|
| 55 |
os.makedirs(args.seg_dir, exist_ok=True)
|
| 56 |
|
| 57 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 58 |
+
torch.cuda.empty_cache()
|
| 59 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 60 |
+
|
| 61 |
model_config_file = os.path.join(args.project_dir, "config", "inference.json")
|
| 62 |
model_config = ConfigParser()
|
| 63 |
model_config.read_config(model_config_file)
|
|
|
|
| 65 |
model_config["dataset_dir"] = args.t2_dir
|
| 66 |
files = os.listdir(args.t2_dir)
|
| 67 |
model_config["datalist"] = [os.path.join(args.t2_dir, f) for f in files]
|
|
|
|
|
|
|
| 68 |
checkpoint = os.path.join(
|
| 69 |
args.project_dir,
|
| 70 |
"models",
|
| 71 |
"prostate_segmentation_model.pt",
|
| 72 |
)
|
|
|
|
| 73 |
model = model_config.get_parsed_content("network_def").to(device)
|
| 74 |
inferer = model_config.get_parsed_content("inferer")
|
|
|
|
|
|
|
| 75 |
model.load_state_dict(torch.load(checkpoint, map_location=device))
|
| 76 |
model.eval()
|
| 77 |
|
|
|
|
|
|
|
|
|
|
| 78 |
keys = "image"
|
| 79 |
transform = Compose(
|
| 80 |
[
|
|
|
|
| 89 |
)
|
| 90 |
logging.info("Starting prostate segmentation")
|
| 91 |
for file in tqdm(files):
|
|
|
|
| 92 |
data = {"image": os.path.join(args.t2_dir, file)}
|
| 93 |
+
_, header_t2 = nrrd.read(data["image"])
|
| 94 |
transformed_data = transform(data)
|
| 95 |
a = transformed_data
|
| 96 |
with torch.no_grad():
|
|
|
|
| 104 |
temp = transform.inverse(transformed_data)
|
| 105 |
pred_temp = temp["image"][0].numpy()
|
| 106 |
pred_nrrd = np.round(pred_temp)
|
| 107 |
+
|
| 108 |
+
nonzero_counts = np.count_nonzero(pred_nrrd, axis=(0, 1))
|
| 109 |
top_slices = np.argsort(nonzero_counts)[-10:]
|
| 110 |
output_ = np.zeros_like(pred_nrrd)
|
| 111 |
+
output_[:, :, top_slices] = pred_nrrd[:, :, top_slices]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
+
nrrd.write(os.path.join(args.seg_dir, file), output_, header_t2)
|
| 114 |
|
| 115 |
+
return args
|
src/preprocessing/register_and_crop.py
CHANGED
|
@@ -1,25 +1,42 @@
|
|
| 1 |
-
import SimpleITK as sitk
|
| 2 |
import os
|
| 3 |
-
import numpy as np
|
| 4 |
-
import nrrd
|
| 5 |
from tqdm import tqdm
|
| 6 |
-
import pandas as pd
|
| 7 |
from picai_prep.preprocessing import PreprocessingSettings, Sample
|
| 8 |
-
import multiprocessing
|
| 9 |
from .center_crop import crop
|
| 10 |
import logging
|
|
|
|
|
|
|
| 11 |
def register_files(args):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
files = os.listdir(args.t2_dir)
|
| 13 |
-
new_spacing = (0.4, 0.4, 3.0)
|
| 14 |
-
t2_registered_dir = os.path.join(args.output_dir,
|
| 15 |
-
dwi_registered_dir = os.path.join(args.output_dir,
|
| 16 |
-
adc_registered_dir = os.path.join(args.output_dir,
|
| 17 |
os.makedirs(t2_registered_dir, exist_ok=True)
|
| 18 |
os.makedirs(dwi_registered_dir, exist_ok=True)
|
| 19 |
os.makedirs(adc_registered_dir, exist_ok=True)
|
| 20 |
logging.info("Starting registration and cropping")
|
| 21 |
for file in tqdm(files):
|
| 22 |
-
|
| 23 |
t2_image = sitk.ReadImage(os.path.join(args.t2_dir, file))
|
| 24 |
dwi_image = sitk.ReadImage(os.path.join(args.dwi_dir, file))
|
| 25 |
adc_image = sitk.ReadImage(os.path.join(args.adc_dir, file))
|
|
@@ -30,38 +47,37 @@ def register_files(args):
|
|
| 30 |
int(round(osz * ospc / nspc))
|
| 31 |
for osz, ospc, nspc in zip(original_size, original_spacing, new_spacing)
|
| 32 |
]
|
| 33 |
-
|
| 34 |
images_to_preprocess = {}
|
| 35 |
-
images_to_preprocess[
|
| 36 |
-
images_to_preprocess[
|
| 37 |
-
images_to_preprocess[
|
| 38 |
|
| 39 |
pat_case = Sample(
|
| 40 |
scans=[
|
| 41 |
-
images_to_preprocess.get(
|
| 42 |
-
images_to_preprocess.get(
|
| 43 |
-
images_to_preprocess.get(
|
| 44 |
],
|
| 45 |
-
settings=PreprocessingSettings(
|
|
|
|
|
|
|
| 46 |
)
|
| 47 |
pat_case.preprocess()
|
| 48 |
-
|
| 49 |
-
t2_post = pat_case.__dict__[
|
| 50 |
-
dwi_post = pat_case.__dict__[
|
| 51 |
-
adc_post = pat_case.__dict__[
|
| 52 |
cropped_t2 = crop(t2_post, [args.margin, args.margin, 0.0])
|
| 53 |
cropped_dwi = crop(dwi_post, [args.margin, args.margin, 0.0])
|
| 54 |
cropped_adc = crop(adc_post, [args.margin, args.margin, 0.0])
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
sitk.WriteImage(cropped_t2, os.path.join(t2_registered_dir, file))
|
| 59 |
sitk.WriteImage(cropped_dwi, os.path.join(dwi_registered_dir, file))
|
| 60 |
sitk.WriteImage(cropped_adc, os.path.join(adc_registered_dir, file))
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
return args
|
| 67 |
|
|
|
|
|
|
| 1 |
+
import SimpleITK as sitk
|
| 2 |
import os
|
|
|
|
|
|
|
| 3 |
from tqdm import tqdm
|
|
|
|
| 4 |
from picai_prep.preprocessing import PreprocessingSettings, Sample
|
|
|
|
| 5 |
from .center_crop import crop
|
| 6 |
import logging
|
| 7 |
+
|
| 8 |
+
|
| 9 |
def register_files(args):
|
| 10 |
+
"""
|
| 11 |
+
Register and crop medical images (T2, DWI, and ADC) to a standardized spacing and size.
|
| 12 |
+
This function reads medical images from specified directories, resamples them to a
|
| 13 |
+
new spacing of (0.4, 0.4, 3.0) mm, preprocesses them using the Sample class, and crops
|
| 14 |
+
them with specified margins. The processed images are saved to new output directories.
|
| 15 |
+
Args:
|
| 16 |
+
args: An argument object containing:
|
| 17 |
+
- t2_dir (str): Directory path containing T2 weighted images
|
| 18 |
+
- dwi_dir (str): Directory path containing DWI (Diffusion Weighted Imaging) images
|
| 19 |
+
- adc_dir (str): Directory path containing ADC (Apparent Diffusion Coefficient) images
|
| 20 |
+
- output_dir (str): Directory path where registered images will be saved
|
| 21 |
+
- margin (float): Margin in mm to crop from x and y dimensions
|
| 22 |
+
Returns:
|
| 23 |
+
args: Updated argument object with modified directory paths pointing to the
|
| 24 |
+
registered image directories (t2_registered, DWI_registered, ADC_registered)
|
| 25 |
+
Raises:
|
| 26 |
+
FileNotFoundError: If input directories do not exist or files cannot be read
|
| 27 |
+
RuntimeError: If image preprocessing or cropping fails
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
files = os.listdir(args.t2_dir)
|
| 31 |
+
new_spacing = (0.4, 0.4, 3.0)
|
| 32 |
+
t2_registered_dir = os.path.join(args.output_dir, "t2_registered")
|
| 33 |
+
dwi_registered_dir = os.path.join(args.output_dir, "DWI_registered")
|
| 34 |
+
adc_registered_dir = os.path.join(args.output_dir, "ADC_registered")
|
| 35 |
os.makedirs(t2_registered_dir, exist_ok=True)
|
| 36 |
os.makedirs(dwi_registered_dir, exist_ok=True)
|
| 37 |
os.makedirs(adc_registered_dir, exist_ok=True)
|
| 38 |
logging.info("Starting registration and cropping")
|
| 39 |
for file in tqdm(files):
|
|
|
|
| 40 |
t2_image = sitk.ReadImage(os.path.join(args.t2_dir, file))
|
| 41 |
dwi_image = sitk.ReadImage(os.path.join(args.dwi_dir, file))
|
| 42 |
adc_image = sitk.ReadImage(os.path.join(args.adc_dir, file))
|
|
|
|
| 47 |
int(round(osz * ospc / nspc))
|
| 48 |
for osz, ospc, nspc in zip(original_size, original_spacing, new_spacing)
|
| 49 |
]
|
| 50 |
+
|
| 51 |
images_to_preprocess = {}
|
| 52 |
+
images_to_preprocess["t2"] = t2_image
|
| 53 |
+
images_to_preprocess["hbv"] = dwi_image
|
| 54 |
+
images_to_preprocess["adc"] = adc_image
|
| 55 |
|
| 56 |
pat_case = Sample(
|
| 57 |
scans=[
|
| 58 |
+
images_to_preprocess.get("t2"),
|
| 59 |
+
images_to_preprocess.get("hbv"),
|
| 60 |
+
images_to_preprocess.get("adc"),
|
| 61 |
],
|
| 62 |
+
settings=PreprocessingSettings(
|
| 63 |
+
spacing=[3.0, 0.4, 0.4], matrix_size=[new_size[2], new_size[1], new_size[0]]
|
| 64 |
+
),
|
| 65 |
)
|
| 66 |
pat_case.preprocess()
|
| 67 |
+
|
| 68 |
+
t2_post = pat_case.__dict__["scans"][0]
|
| 69 |
+
dwi_post = pat_case.__dict__["scans"][1]
|
| 70 |
+
adc_post = pat_case.__dict__["scans"][2]
|
| 71 |
cropped_t2 = crop(t2_post, [args.margin, args.margin, 0.0])
|
| 72 |
cropped_dwi = crop(dwi_post, [args.margin, args.margin, 0.0])
|
| 73 |
cropped_adc = crop(adc_post, [args.margin, args.margin, 0.0])
|
| 74 |
|
|
|
|
|
|
|
| 75 |
sitk.WriteImage(cropped_t2, os.path.join(t2_registered_dir, file))
|
| 76 |
sitk.WriteImage(cropped_dwi, os.path.join(dwi_registered_dir, file))
|
| 77 |
sitk.WriteImage(cropped_adc, os.path.join(adc_registered_dir, file))
|
| 78 |
|
| 79 |
+
args.t2_dir = t2_registered_dir
|
| 80 |
+
args.dwi_dir = dwi_registered_dir
|
| 81 |
+
args.adc_dir = adc_registered_dir
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
return args
|
src/train/train_cspca.py
CHANGED
|
@@ -1,91 +1,18 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import collections.abc
|
| 3 |
-
import os
|
| 4 |
-
import shutil
|
| 5 |
-
import time
|
| 6 |
-
import yaml
|
| 7 |
-
from scipy.stats import pearsonr
|
| 8 |
-
import gdown
|
| 9 |
-
import numpy as np
|
| 10 |
import torch
|
| 11 |
-
import torch.distributed as dist
|
| 12 |
-
import torch.multiprocessing as mp
|
| 13 |
import torch.nn as nn
|
| 14 |
-
import torch.nn.functional as F
|
| 15 |
-
from monai.config import KeysCollection
|
| 16 |
-
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
|
| 17 |
-
from monai.data.wsi_reader import WSIReader
|
| 18 |
from monai.metrics import Cumulative, CumulativeAverage
|
| 19 |
-
from
|
| 20 |
-
from monai.transforms import (
|
| 21 |
-
Compose,
|
| 22 |
-
GridPatchd,
|
| 23 |
-
LoadImaged,
|
| 24 |
-
MapTransform,
|
| 25 |
-
RandFlipd,
|
| 26 |
-
RandGridPatchd,
|
| 27 |
-
RandRotate90d,
|
| 28 |
-
ScaleIntensityRanged,
|
| 29 |
-
SplitDimd,
|
| 30 |
-
ToTensord,
|
| 31 |
-
ConcatItemsd,
|
| 32 |
-
SelectItemsd,
|
| 33 |
-
EnsureChannelFirstd,
|
| 34 |
-
RepeatChanneld,
|
| 35 |
-
DeleteItemsd,
|
| 36 |
-
EnsureTyped,
|
| 37 |
-
ClipIntensityPercentilesd,
|
| 38 |
-
MaskIntensityd,
|
| 39 |
-
HistogramNormalized,
|
| 40 |
-
RandBiasFieldd,
|
| 41 |
-
RandCropByPosNegLabeld,
|
| 42 |
-
NormalizeIntensityd,
|
| 43 |
-
SqueezeDimd,
|
| 44 |
-
CropForegroundd,
|
| 45 |
-
ScaleIntensityd,
|
| 46 |
-
SpatialPadd,
|
| 47 |
-
CenterSpatialCropd,
|
| 48 |
-
ScaleIntensityd,
|
| 49 |
-
Transposed,
|
| 50 |
-
RandWeightedCropd,
|
| 51 |
-
)
|
| 52 |
-
from sklearn.metrics import cohen_kappa_score, roc_curve, confusion_matrix
|
| 53 |
-
from torch.cuda.amp import GradScaler, autocast
|
| 54 |
-
from torch.utils.data.dataloader import default_collate
|
| 55 |
-
from torchvision.models.resnet import ResNet50_Weights
|
| 56 |
-
import torch.optim as optim
|
| 57 |
-
from torch.utils.data.distributed import DistributedSampler
|
| 58 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 59 |
-
import matplotlib.pyplot as plt
|
| 60 |
-
import matplotlib.patches as patches
|
| 61 |
-
from tqdm import tqdm
|
| 62 |
-
from sklearn.metrics import confusion_matrix, roc_auc_score
|
| 63 |
from sklearn.metrics import roc_auc_score
|
| 64 |
-
|
| 65 |
-
import numpy as np
|
| 66 |
-
from AIAH_utility.viewer import BasicViewer
|
| 67 |
-
from scipy.special import expit
|
| 68 |
-
import nrrd
|
| 69 |
-
import random
|
| 70 |
-
from sklearn.metrics import roc_auc_score
|
| 71 |
-
import SimpleITK as sitk
|
| 72 |
-
from AIAH_utility.viewer import BasicViewer
|
| 73 |
-
import pandas as pd
|
| 74 |
-
import json
|
| 75 |
-
from sklearn.preprocessing import StandardScaler
|
| 76 |
-
from torch.utils.data import DataLoader, TensorDataset, Dataset
|
| 77 |
-
from sklearn.linear_model import LogisticRegression
|
| 78 |
-
from sklearn.utils import resample
|
| 79 |
-
import monai
|
| 80 |
|
| 81 |
def train_epoch(cspca_model, loader, optimizer, epoch, args):
|
| 82 |
cspca_model.train()
|
| 83 |
-
criterion = nn.BCELoss()
|
| 84 |
loss = 0.0
|
| 85 |
run_loss = CumulativeAverage()
|
| 86 |
TARGETS = Cumulative()
|
| 87 |
PREDS = Cumulative()
|
| 88 |
-
|
| 89 |
for idx, batch_data in enumerate(loader):
|
| 90 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 91 |
target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
|
|
@@ -108,9 +35,10 @@ def train_epoch(cspca_model, loader, optimizer, epoch, args):
|
|
| 108 |
|
| 109 |
return loss_epoch, auc_epoch
|
| 110 |
|
|
|
|
| 111 |
def val_epoch(cspca_model, loader, epoch, args):
|
| 112 |
cspca_model.eval()
|
| 113 |
-
criterion = nn.BCELoss()
|
| 114 |
loss = 0.0
|
| 115 |
run_loss = CumulativeAverage()
|
| 116 |
TARGETS = Cumulative()
|
|
@@ -132,10 +60,15 @@ def val_epoch(cspca_model, loader, epoch, args):
|
|
| 132 |
target_list = TARGETS.get_buffer().cpu().numpy()
|
| 133 |
pred_list = PREDS.get_buffer().cpu().numpy()
|
| 134 |
auc_epoch = roc_auc_score(target_list, pred_list)
|
| 135 |
-
y_pred_categoric =
|
| 136 |
tn, fp, fn, tp = confusion_matrix(target_list, y_pred_categoric).ravel()
|
| 137 |
sens_epoch = tp / (tp + fn)
|
| 138 |
spec_epoch = tn / (tn + fp)
|
| 139 |
-
val_epoch_metric = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
return val_epoch_metric
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
|
|
|
|
|
|
| 2 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from monai.metrics import Cumulative, CumulativeAverage
|
| 4 |
+
from sklearn.metrics import confusion_matrix
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from sklearn.metrics import roc_auc_score
|
| 6 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
def train_epoch(cspca_model, loader, optimizer, epoch, args):
|
| 9 |
cspca_model.train()
|
| 10 |
+
criterion = nn.BCELoss()
|
| 11 |
loss = 0.0
|
| 12 |
run_loss = CumulativeAverage()
|
| 13 |
TARGETS = Cumulative()
|
| 14 |
PREDS = Cumulative()
|
| 15 |
+
|
| 16 |
for idx, batch_data in enumerate(loader):
|
| 17 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 18 |
target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
|
|
|
|
| 35 |
|
| 36 |
return loss_epoch, auc_epoch
|
| 37 |
|
| 38 |
+
|
| 39 |
def val_epoch(cspca_model, loader, epoch, args):
|
| 40 |
cspca_model.eval()
|
| 41 |
+
criterion = nn.BCELoss()
|
| 42 |
loss = 0.0
|
| 43 |
run_loss = CumulativeAverage()
|
| 44 |
TARGETS = Cumulative()
|
|
|
|
| 60 |
target_list = TARGETS.get_buffer().cpu().numpy()
|
| 61 |
pred_list = PREDS.get_buffer().cpu().numpy()
|
| 62 |
auc_epoch = roc_auc_score(target_list, pred_list)
|
| 63 |
+
y_pred_categoric = pred_list >= 0.5
|
| 64 |
tn, fp, fn, tp = confusion_matrix(target_list, y_pred_categoric).ravel()
|
| 65 |
sens_epoch = tp / (tp + fn)
|
| 66 |
spec_epoch = tn / (tn + fp)
|
| 67 |
+
val_epoch_metric = {
|
| 68 |
+
"epoch": epoch,
|
| 69 |
+
"loss": loss_epoch,
|
| 70 |
+
"auc": auc_epoch,
|
| 71 |
+
"sensitivity": sens_epoch,
|
| 72 |
+
"specificity": spec_epoch,
|
| 73 |
+
}
|
| 74 |
return val_epoch_metric
|
|
|
src/train/train_pirads.py
CHANGED
|
@@ -1,65 +1,11 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import collections.abc
|
| 3 |
-
import os
|
| 4 |
-
import shutil
|
| 5 |
import time
|
| 6 |
-
import yaml
|
| 7 |
-
import sys
|
| 8 |
-
import gdown
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
-
import torch.distributed as dist
|
| 12 |
-
import torch.multiprocessing as mp
|
| 13 |
import torch.nn as nn
|
| 14 |
-
import torch.nn.functional as F
|
| 15 |
-
from monai.config import KeysCollection
|
| 16 |
-
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
|
| 17 |
-
from monai.data.wsi_reader import WSIReader
|
| 18 |
from monai.metrics import Cumulative, CumulativeAverage
|
| 19 |
-
from monai.networks.nets import milmodel, resnet, MILModel
|
| 20 |
-
from monai.transforms import (
|
| 21 |
-
Compose,
|
| 22 |
-
GridPatchd,
|
| 23 |
-
LoadImaged,
|
| 24 |
-
MapTransform,
|
| 25 |
-
RandFlipd,
|
| 26 |
-
RandGridPatchd,
|
| 27 |
-
RandRotate90d,
|
| 28 |
-
ScaleIntensityRanged,
|
| 29 |
-
SplitDimd,
|
| 30 |
-
ToTensord,
|
| 31 |
-
ConcatItemsd,
|
| 32 |
-
SelectItemsd,
|
| 33 |
-
EnsureChannelFirstd,
|
| 34 |
-
RepeatChanneld,
|
| 35 |
-
DeleteItemsd,
|
| 36 |
-
EnsureTyped,
|
| 37 |
-
ClipIntensityPercentilesd,
|
| 38 |
-
MaskIntensityd,
|
| 39 |
-
HistogramNormalized,
|
| 40 |
-
RandBiasFieldd,
|
| 41 |
-
RandCropByPosNegLabeld,
|
| 42 |
-
NormalizeIntensityd,
|
| 43 |
-
SqueezeDimd,
|
| 44 |
-
CropForegroundd,
|
| 45 |
-
ScaleIntensityd,
|
| 46 |
-
SpatialPadd,
|
| 47 |
-
CenterSpatialCropd,
|
| 48 |
-
ScaleIntensityd,
|
| 49 |
-
Transposed,
|
| 50 |
-
RandWeightedCropd,
|
| 51 |
-
)
|
| 52 |
from sklearn.metrics import cohen_kappa_score
|
| 53 |
-
from torch.cuda.amp import GradScaler, autocast
|
| 54 |
-
from torch.utils.data.dataloader import default_collate
|
| 55 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 56 |
-
import matplotlib.pyplot as plt
|
| 57 |
-
import wandb
|
| 58 |
-
import math
|
| 59 |
import logging
|
| 60 |
-
|
| 61 |
-
from src.data.data_loader import get_dataloader
|
| 62 |
-
from src.utils import save_pirads_checkpoint, setup_logging
|
| 63 |
|
| 64 |
def get_lambda_att(epoch, max_lambda=2.0, warmup_epochs=10):
|
| 65 |
if epoch < warmup_epochs:
|
|
@@ -67,26 +13,53 @@ def get_lambda_att(epoch, max_lambda=2.0, warmup_epochs=10):
|
|
| 67 |
else:
|
| 68 |
return max_lambda
|
| 69 |
|
|
|
|
| 70 |
def get_attention_scores(data, target, heatmap, args):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
attention_score = torch.zeros((data.shape[0], data.shape[1]))
|
| 72 |
for i in range(data.shape[0]):
|
| 73 |
sample = heatmap[i]
|
| 74 |
heatmap_patches = sample.squeeze(1)
|
| 75 |
-
raw_scores = heatmap_patches.view(len(heatmap_patches), -1).sum(dim=1)
|
| 76 |
-
attention_score[i] = raw_scores / raw_scores.sum()
|
| 77 |
shuffled_images = torch.empty_like(data).to(args.device)
|
| 78 |
att_labels = torch.empty_like(attention_score).to(args.device)
|
| 79 |
-
for i in range(data.shape[0]):
|
| 80 |
-
perm = torch.randperm(data.shape[1])
|
| 81 |
shuffled_images[i] = data[i, perm]
|
| 82 |
-
att_labels[i] = attention_score[i, perm]
|
| 83 |
-
|
| 84 |
-
att_labels[torch.argwhere(target < 1)] = torch.ones_like(att_labels[0]) / len(
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
att_labels = att_labels / att_labels.sum(dim=1, keepdim=True)
|
| 87 |
-
|
| 88 |
return att_labels, shuffled_images
|
| 89 |
|
|
|
|
| 90 |
def train_epoch(model, loader, optimizer, scaler, epoch, args):
|
| 91 |
"""One train epoch over the dataset"""
|
| 92 |
lambda_att = get_lambda_att(epoch, warmup_epochs=25)
|
|
@@ -104,13 +77,14 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
|
|
| 104 |
loss, acc = 0.0, 0.0
|
| 105 |
|
| 106 |
for idx, batch_data in enumerate(loader):
|
| 107 |
-
|
| 108 |
eps = 1e-8
|
| 109 |
data = batch_data["image"].as_subclass(torch.Tensor)
|
| 110 |
target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
|
| 111 |
target = target.long()
|
| 112 |
if args.use_heatmap:
|
| 113 |
-
att_labels, shuffled_images = get_attention_scores(
|
|
|
|
|
|
|
| 114 |
att_labels = att_labels + eps
|
| 115 |
else:
|
| 116 |
shuffled_images = data.to(args.device)
|
|
@@ -139,7 +113,7 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
|
|
| 139 |
b = b + eps
|
| 140 |
att_preds = torch.softmax(b, dim=1)
|
| 141 |
attn_loss = 1 - att_criterion(att_preds, att_labels).mean()
|
| 142 |
-
loss = class_loss + (lambda_att*attn_loss)
|
| 143 |
else:
|
| 144 |
loss = class_loss
|
| 145 |
attn_loss = torch.tensor(0.0)
|
|
@@ -150,14 +124,10 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
|
|
| 150 |
if not torch.isfinite(total_norm):
|
| 151 |
logging.warning("Non-finite gradient norm detected, skipping batch.")
|
| 152 |
optimizer.zero_grad()
|
|
|
|
| 153 |
else:
|
| 154 |
scaler.step(optimizer)
|
| 155 |
scaler.update()
|
| 156 |
-
shuffled_images = shuffled_images.to('cpu')
|
| 157 |
-
logits = logits.to('cpu')
|
| 158 |
-
logits_attn = logits_attn.to('cpu')
|
| 159 |
-
target = target.to('cpu')
|
| 160 |
-
|
| 161 |
batch_norm.append(total_norm)
|
| 162 |
pred = torch.argmax(logits, dim=1)
|
| 163 |
acc = (pred == target).sum() / len(pred)
|
|
@@ -171,15 +141,15 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
|
|
| 171 |
args.epochs,
|
| 172 |
idx,
|
| 173 |
len(loader),
|
| 174 |
-
loss,
|
| 175 |
-
attn_loss,
|
| 176 |
acc,
|
| 177 |
total_norm,
|
| 178 |
-
time.time() - start_time
|
| 179 |
)
|
| 180 |
)
|
| 181 |
start_time = time.time()
|
| 182 |
-
|
| 183 |
del data, target, shuffled_images, logits, logits_attn
|
| 184 |
torch.cuda.empty_cache()
|
| 185 |
batch_norm_epoch = batch_norm.aggregate()
|
|
@@ -189,9 +159,7 @@ def train_epoch(model, loader, optimizer, scaler, epoch, args):
|
|
| 189 |
return loss_epoch, acc_epoch, attn_loss_epoch, batch_norm_epoch
|
| 190 |
|
| 191 |
|
| 192 |
-
|
| 193 |
def val_epoch(model, loader, epoch, args):
|
| 194 |
-
|
| 195 |
criterion = nn.CrossEntropyLoss()
|
| 196 |
|
| 197 |
run_loss = CumulativeAverage()
|
|
@@ -204,17 +172,17 @@ def val_epoch(model, loader, epoch, args):
|
|
| 204 |
model.eval()
|
| 205 |
with torch.no_grad():
|
| 206 |
for idx, batch_data in enumerate(loader):
|
| 207 |
-
|
| 208 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 209 |
target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
|
| 210 |
target = target.long()
|
| 211 |
-
|
|
|
|
| 212 |
logits = model(data)
|
| 213 |
loss = criterion(logits, target)
|
| 214 |
|
| 215 |
-
data = data.to(
|
| 216 |
-
target = target.to(
|
| 217 |
-
logits = logits.to(
|
| 218 |
pred = torch.argmax(logits, dim=1)
|
| 219 |
acc = (pred == target).sum() / len(target)
|
| 220 |
|
|
@@ -228,7 +196,7 @@ def val_epoch(model, loader, epoch, args):
|
|
| 228 |
)
|
| 229 |
)
|
| 230 |
start_time = time.time()
|
| 231 |
-
|
| 232 |
del data, target, logits
|
| 233 |
torch.cuda.empty_cache()
|
| 234 |
|
|
@@ -237,10 +205,5 @@ def val_epoch(model, loader, epoch, args):
|
|
| 237 |
TARGETS = TARGETS.get_buffer().cpu().numpy()
|
| 238 |
loss_epoch = run_loss.aggregate()
|
| 239 |
acc_epoch = run_acc.aggregate()
|
| 240 |
-
qwk = cohen_kappa_score(TARGETS.astype(np.float64),PREDS.astype(np.float64))
|
| 241 |
return loss_epoch, acc_epoch, qwk
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import time
|
|
|
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import torch
|
|
|
|
|
|
|
| 4 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from monai.metrics import Cumulative, CumulativeAverage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from sklearn.metrics import cohen_kappa_score
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import logging
|
| 8 |
+
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def get_lambda_att(epoch, max_lambda=2.0, warmup_epochs=10):
|
| 11 |
if epoch < warmup_epochs:
|
|
|
|
| 13 |
else:
|
| 14 |
return max_lambda
|
| 15 |
|
| 16 |
+
|
| 17 |
def get_attention_scores(data, target, heatmap, args):
    """
    Build per-patch attention targets from heatmaps and shuffle the patches.

    For every sample the heatmap of each patch is summed to one score, the
    scores are normalised to sum to one, and the patches (together with their
    scores) are randomly permuted. Samples whose target is below 1 get a
    uniform target. Finally all targets are squared (sharpening) and
    renormalised per sample.

    Args:
        data (torch.Tensor): Patches; the first two dimensions are indexed as
            (batch, patch). The remaining dimensions are carried through.
        target (torch.Tensor): One label per sample. Samples with
            ``target < 1`` receive uniform attention (the caller's comment
            describes these as PI-RADS 2 -- confirm against the label mapping).
        heatmap (torch.Tensor): Per-sample heatmap patches; ``heatmap[i]`` is
            squeezed on dim 1 and flattened per patch, so it presumably has a
            singleton channel at that position -- confirm against the loader.
        args: Object exposing ``device``, where both outputs are allocated.

    Returns:
        tuple: ``(att_labels, shuffled_images)`` where ``att_labels`` has shape
        (batch, patch) and sums to one along dim 1, and ``shuffled_images`` has
        the shape of ``data`` with patches permuted independently per sample.
        Both live on ``args.device``.

    Note:
        A heatmap that sums to exactly zero cannot be normalised; such a sample
        falls back to a uniform target instead of producing NaN labels.
    """
    batch_size, num_patches = data.shape[0], data.shape[1]
    uniform = 1.0 / num_patches

    # Rows start out uniform so that samples with an all-zero heatmap keep a
    # valid distribution instead of the 0/0 = NaN the division would yield.
    attention_score = torch.full((batch_size, num_patches), uniform)

    # Allocate the outputs directly on the target device (no CPU round trip).
    shuffled_images = torch.empty_like(data, device=args.device)
    att_labels = torch.empty_like(attention_score, device=args.device)

    for i in range(batch_size):
        heatmap_patches = heatmap[i].squeeze(1)
        # reshape (not view) so non-contiguous heatmap slices are accepted too.
        raw_scores = heatmap_patches.reshape(len(heatmap_patches), -1).sum(dim=1)
        total = raw_scores.sum()
        if total != 0:
            attention_score[i] = raw_scores / total

        # One permutation per sample, applied to patches and scores alike so
        # every score stays attached to its patch.
        perm = torch.randperm(num_patches)
        shuffled_images[i] = data[i, perm]
        att_labels[i] = attention_score[i, perm]

    # Samples with target < 1 carry no localisation signal: uniform attention.
    uniform_rows = (target < 1).reshape(-1).to(att_labels.device)
    att_labels[uniform_rows] = uniform

    att_labels = att_labels**2  # Sharpening
    att_labels = att_labels / att_labels.sum(dim=1, keepdim=True)

    return att_labels, shuffled_images
|
| 61 |
|
| 62 |
+
|
| 63 |
def train_epoch(model, loader, optimizer, scaler, epoch, args):
|
| 64 |
"""One train epoch over the dataset"""
|
| 65 |
lambda_att = get_lambda_att(epoch, warmup_epochs=25)
|
|
|
|
| 77 |
loss, acc = 0.0, 0.0
|
| 78 |
|
| 79 |
for idx, batch_data in enumerate(loader):
|
|
|
|
| 80 |
eps = 1e-8
|
| 81 |
data = batch_data["image"].as_subclass(torch.Tensor)
|
| 82 |
target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
|
| 83 |
target = target.long()
|
| 84 |
if args.use_heatmap:
|
| 85 |
+
att_labels, shuffled_images = get_attention_scores(
|
| 86 |
+
data, target, batch_data["final_heatmap"], args
|
| 87 |
+
)
|
| 88 |
att_labels = att_labels + eps
|
| 89 |
else:
|
| 90 |
shuffled_images = data.to(args.device)
|
|
|
|
| 113 |
b = b + eps
|
| 114 |
att_preds = torch.softmax(b, dim=1)
|
| 115 |
attn_loss = 1 - att_criterion(att_preds, att_labels).mean()
|
| 116 |
+
loss = class_loss + (lambda_att * attn_loss)
|
| 117 |
else:
|
| 118 |
loss = class_loss
|
| 119 |
attn_loss = torch.tensor(0.0)
|
|
|
|
| 124 |
if not torch.isfinite(total_norm):
|
| 125 |
logging.warning("Non-finite gradient norm detected, skipping batch.")
|
| 126 |
optimizer.zero_grad()
|
| 127 |
+
scaler.update()
|
| 128 |
else:
|
| 129 |
scaler.step(optimizer)
|
| 130 |
scaler.update()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
batch_norm.append(total_norm)
|
| 132 |
pred = torch.argmax(logits, dim=1)
|
| 133 |
acc = (pred == target).sum() / len(pred)
|
|
|
|
| 141 |
args.epochs,
|
| 142 |
idx,
|
| 143 |
len(loader),
|
| 144 |
+
loss.item(),
|
| 145 |
+
attn_loss.item(),
|
| 146 |
acc,
|
| 147 |
total_norm,
|
| 148 |
+
time.time() - start_time,
|
| 149 |
)
|
| 150 |
)
|
| 151 |
start_time = time.time()
|
| 152 |
+
|
| 153 |
del data, target, shuffled_images, logits, logits_attn
|
| 154 |
torch.cuda.empty_cache()
|
| 155 |
batch_norm_epoch = batch_norm.aggregate()
|
|
|
|
| 159 |
return loss_epoch, acc_epoch, attn_loss_epoch, batch_norm_epoch
|
| 160 |
|
| 161 |
|
|
|
|
| 162 |
def val_epoch(model, loader, epoch, args):
|
|
|
|
| 163 |
criterion = nn.CrossEntropyLoss()
|
| 164 |
|
| 165 |
run_loss = CumulativeAverage()
|
|
|
|
| 172 |
model.eval()
|
| 173 |
with torch.no_grad():
|
| 174 |
for idx, batch_data in enumerate(loader):
|
|
|
|
| 175 |
data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
|
| 176 |
target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
|
| 177 |
target = target.long()
|
| 178 |
+
|
| 179 |
+
with torch.amp.autocast(device_type=str(args.device), enabled=args.amp):
|
| 180 |
logits = model(data)
|
| 181 |
loss = criterion(logits, target)
|
| 182 |
|
| 183 |
+
data = data.to("cpu")
|
| 184 |
+
target = target.to("cpu")
|
| 185 |
+
logits = logits.to("cpu")
|
| 186 |
pred = torch.argmax(logits, dim=1)
|
| 187 |
acc = (pred == target).sum() / len(target)
|
| 188 |
|
|
|
|
| 196 |
)
|
| 197 |
)
|
| 198 |
start_time = time.time()
|
| 199 |
+
|
| 200 |
del data, target, logits
|
| 201 |
torch.cuda.empty_cache()
|
| 202 |
|
|
|
|
| 205 |
TARGETS = TARGETS.get_buffer().cpu().numpy()
|
| 206 |
loss_epoch = run_loss.aggregate()
|
| 207 |
acc_epoch = run_acc.aggregate()
|
| 208 |
+
qwk = cohen_kappa_score(TARGETS.astype(np.float64), PREDS.astype(np.float64))
|
| 209 |
return loss_epoch, acc_epoch, qwk
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils.py
CHANGED
|
@@ -1,97 +1,46 @@
|
|
| 1 |
-
import argparse
|
| 2 |
import os
|
| 3 |
-
import shutil
|
| 4 |
-
import time
|
| 5 |
-
import yaml
|
| 6 |
import sys
|
| 7 |
-
import gdown
|
| 8 |
import numpy as np
|
| 9 |
import torch
|
| 10 |
-
import torch.distributed as dist
|
| 11 |
-
import torch.multiprocessing as mp
|
| 12 |
-
import torch.nn as nn
|
| 13 |
-
import torch.nn.functional as F
|
| 14 |
-
from monai.config import KeysCollection
|
| 15 |
-
from monai.metrics import Cumulative, CumulativeAverage
|
| 16 |
-
from monai.networks.nets import milmodel, resnet, MILModel
|
| 17 |
from monai.transforms import (
|
| 18 |
Compose,
|
| 19 |
-
GridPatchd,
|
| 20 |
LoadImaged,
|
| 21 |
-
MapTransform,
|
| 22 |
-
RandFlipd,
|
| 23 |
-
RandGridPatchd,
|
| 24 |
-
RandRotate90d,
|
| 25 |
-
ScaleIntensityRanged,
|
| 26 |
-
SplitDimd,
|
| 27 |
ToTensord,
|
| 28 |
-
ConcatItemsd,
|
| 29 |
-
SelectItemsd,
|
| 30 |
-
EnsureChannelFirstd,
|
| 31 |
-
RepeatChanneld,
|
| 32 |
-
DeleteItemsd,
|
| 33 |
EnsureTyped,
|
| 34 |
-
ClipIntensityPercentilesd,
|
| 35 |
-
MaskIntensityd,
|
| 36 |
-
HistogramNormalized,
|
| 37 |
-
RandBiasFieldd,
|
| 38 |
-
RandCropByPosNegLabeld,
|
| 39 |
-
NormalizeIntensityd,
|
| 40 |
-
SqueezeDimd,
|
| 41 |
-
CropForegroundd,
|
| 42 |
-
ScaleIntensityd,
|
| 43 |
-
SpatialPadd,
|
| 44 |
-
CenterSpatialCropd,
|
| 45 |
-
ScaleIntensityd,
|
| 46 |
-
Transposed,
|
| 47 |
-
RandWeightedCropd,
|
| 48 |
)
|
| 49 |
-
from sklearn.metrics import cohen_kappa_score
|
| 50 |
-
from torch.cuda.amp import GradScaler, autocast
|
| 51 |
-
from torch.utils.data.dataloader import default_collate
|
| 52 |
-
from torchvision.models.resnet import ResNet50_Weights
|
| 53 |
from .data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
|
| 54 |
-
from
|
| 55 |
-
from torch.utils.tensorboard import SummaryWriter
|
| 56 |
-
import matplotlib.patches as patches
|
| 57 |
-
|
| 58 |
-
import matplotlib.pyplot as plt
|
| 59 |
-
|
| 60 |
-
import wandb
|
| 61 |
-
import math
|
| 62 |
-
from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
|
| 63 |
-
|
| 64 |
-
from src.model.MIL import MILModel_3D
|
| 65 |
-
from src.model.csPCa_model import csPCa_Model
|
| 66 |
-
|
| 67 |
import logging
|
| 68 |
from pathlib import Path
|
|
|
|
| 69 |
|
| 70 |
-
def save_pirads_checkpoint(model, epoch, args, filename="model.pth", best_acc=0):
|
| 71 |
|
| 72 |
-
|
|
|
|
| 73 |
|
| 74 |
state_dict = model.state_dict()
|
| 75 |
-
|
| 76 |
save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict}
|
| 77 |
-
|
| 78 |
filename = os.path.join(args.logdir, filename)
|
| 79 |
torch.save(save_dict, filename)
|
| 80 |
-
logging.info("Saving checkpoint
|
|
|
|
| 81 |
|
| 82 |
def save_cspca_checkpoint(model, val_metric, model_dir):
|
|
|
|
|
|
|
| 83 |
state_dict = model.state_dict()
|
| 84 |
save_dict = {
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
}
|
| 93 |
-
torch.save(save_dict, os.path.join(model_dir,
|
| 94 |
-
logging.info(
|
|
|
|
| 95 |
|
| 96 |
def get_metrics(metric_dict: dict):
|
| 97 |
for metric_name, metric_list in metric_dict.items():
|
|
@@ -102,6 +51,7 @@ def get_metrics(metric_dict: dict):
|
|
| 102 |
logging.info(f"Mean {metric_name}: {mean_metric:.3f}")
|
| 103 |
logging.info(f"95% CI: ({lower:.3f}, {upper:.3f})")
|
| 104 |
|
|
|
|
| 105 |
def setup_logging(log_file):
|
| 106 |
log_file = Path(log_file)
|
| 107 |
log_file.parent.mkdir(parents=True, exist_ok=True)
|
|
@@ -115,6 +65,7 @@ def setup_logging(log_file):
|
|
| 115 |
],
|
| 116 |
)
|
| 117 |
|
|
|
|
| 118 |
def validate_steps(steps):
|
| 119 |
REQUIRES = {
|
| 120 |
"get_segmentation_mask": ["register_and_crop"],
|
|
@@ -126,38 +77,63 @@ def validate_steps(steps):
|
|
| 126 |
for req in required:
|
| 127 |
if req not in steps[:i]:
|
| 128 |
logging.error(
|
| 129 |
-
f"Step '{step}' requires '{req}' to be executed before it. "
|
| 130 |
-
f"Given order: {steps}"
|
| 131 |
)
|
| 132 |
sys.exit(1)
|
| 133 |
|
| 134 |
-
def get_patch_coordinate(patches_top_5, parent_image, args):
|
| 135 |
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
coords = []
|
| 138 |
rows, h, w, slices = sample.shape
|
| 139 |
|
| 140 |
for i in range(rows):
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
| 161 |
|
| 162 |
return coords
|
| 163 |
|
|
@@ -165,7 +141,12 @@ def get_patch_coordinate(patches_top_5, parent_image, args):
|
|
| 165 |
def get_parent_image(temp_data_list, args):
|
| 166 |
transform_image = Compose(
|
| 167 |
[
|
| 168 |
-
LoadImaged(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
|
| 170 |
NormalizeIntensity_customd(keys=["image"], mask_key="mask", channel_wise=True),
|
| 171 |
EnsureTyped(keys=["label"], dtype=torch.float32),
|
|
@@ -173,9 +154,10 @@ def get_parent_image(temp_data_list, args):
|
|
| 173 |
]
|
| 174 |
)
|
| 175 |
dataset_image = Dataset(data=temp_data_list, transform=transform_image)
|
| 176 |
-
return dataset_image[0][
|
|
|
|
| 177 |
|
| 178 |
-
|
| 179 |
def visualise_patches():
|
| 180 |
sample = np.array([i.transpose(1,2,0) for i in patches_top_5])
|
| 181 |
rows = len(patches_top_5)
|
|
@@ -223,4 +205,4 @@ def visualise_patches():
|
|
| 223 |
plt.tight_layout()
|
| 224 |
plt.show()
|
| 225 |
a=1
|
| 226 |
-
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
| 2 |
import sys
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from monai.transforms import (
|
| 6 |
Compose,
|
|
|
|
| 7 |
LoadImaged,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
ToTensord,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
EnsureTyped,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from .data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd
|
| 12 |
+
from monai.data import Dataset, ITKReader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
import logging
|
| 14 |
from pathlib import Path
|
| 15 |
+
import cv2
|
| 16 |
|
|
|
|
| 17 |
|
| 18 |
+
def save_pirads_checkpoint(model, epoch, args, filename="model.pth", best_acc=0):
    """Save a checkpoint for the PI-RADS model.

    The checkpoint is a dict with the keys ``"epoch"``, ``"best_acc"`` and
    ``"state_dict"`` and is written with ``torch.save`` to
    ``os.path.join(args.logdir, filename)``.

    Args:
        model: Object whose ``state_dict()`` is stored in the checkpoint.
        epoch: Value stored under ``"epoch"``.
        args: Object exposing a ``logdir`` attribute (target directory).
        filename: File name of the checkpoint inside ``args.logdir``.
        best_acc: Value stored under ``"best_acc"``.
    """
    save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": model.state_dict()}

    # torch.save does not create missing parent directories, so make sure the
    # log directory exists. An empty logdir means "current directory".
    if args.logdir:
        os.makedirs(args.logdir, exist_ok=True)

    filename = os.path.join(args.logdir, filename)
    torch.save(save_dict, filename)
    logging.info(f"Saving checkpoint {filename}")
|
| 26 |
+
|
| 27 |
|
| 28 |
def save_cspca_checkpoint(model, val_metric, model_dir):
    """Save a checkpoint for the csPCa model.

    The checkpoint is written with ``torch.save`` to
    ``<model_dir>/cspca_model.pth`` and contains the validation metrics copied
    from ``val_metric`` plus the model weights under ``"state_dict"``.

    Args:
        model: Object whose ``state_dict()`` is stored in the checkpoint.
        val_metric: Mapping that must provide the keys ``"epoch"``, ``"loss"``,
            ``"auc"``, ``"sensitivity"``, ``"specificity"`` and ``"state"``.
        model_dir: Directory the checkpoint file is written to.

    Raises:
        KeyError: If ``val_metric`` lacks one of the required keys.
    """
    metric_keys = ("epoch", "loss", "auc", "sensitivity", "specificity", "state")
    save_dict = {key: val_metric[key] for key in metric_keys}
    save_dict["state_dict"] = model.state_dict()

    # torch.save does not create missing parent directories, so make sure the
    # target directory exists. An empty model_dir means "current directory".
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    torch.save(save_dict, os.path.join(model_dir, "cspca_model.pth"))
    logging.info(f"Saving model with auc: {val_metric['auc']}")
|
| 43 |
+
|
| 44 |
|
| 45 |
def get_metrics(metric_dict: dict):
|
| 46 |
for metric_name, metric_list in metric_dict.items():
|
|
|
|
| 51 |
logging.info(f"Mean {metric_name}: {mean_metric:.3f}")
|
| 52 |
logging.info(f"95% CI: ({lower:.3f}, {upper:.3f})")
|
| 53 |
|
| 54 |
+
|
| 55 |
def setup_logging(log_file):
|
| 56 |
log_file = Path(log_file)
|
| 57 |
log_file.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 65 |
],
|
| 66 |
)
|
| 67 |
|
| 68 |
+
|
| 69 |
def validate_steps(steps):
|
| 70 |
REQUIRES = {
|
| 71 |
"get_segmentation_mask": ["register_and_crop"],
|
|
|
|
| 77 |
for req in required:
|
| 78 |
if req not in steps[:i]:
|
| 79 |
logging.error(
|
| 80 |
+
f"Step '{step}' requires '{req}' to be executed before it. Given order: {steps}"
|
|
|
|
| 81 |
)
|
| 82 |
sys.exit(1)
|
| 83 |
|
|
|
|
| 84 |
|
| 85 |
+
def get_patch_coordinate(patches_top_5, parent_image):
    """
    Locate the coordinates of the top-5 patches within a parent image.

    Each patch is transposed with ``transpose(1, 2, 0)`` and only its first
    plane along the resulting last axis is searched for. The parent volume is
    scanned plane by plane along its third axis with OpenCV template matching
    (``cv2.TM_CCOEFF_NORMED``). A plane is accepted when the best match score is
    at least 0.99 and the region at that location equals the template within an
    absolute tolerance of ``1e-5``.

    Args:
        patches_top_5 (list): 3-D arrays, one per patch. The original
            documentation describes them as (C, H, W) -- TODO confirm against
            the caller.
        parent_image (np.ndarray): 3-D volume indexed as
            ``parent_image[:, :, k]``; planes along the last axis are searched.

    Returns:
        list: Tuples ``(row, col, slice_idx)`` giving the top-left corner and
        the plane index of each patch that was found. Patches without a match
        are skipped (a warning is logged), so the list may be shorter than
        ``patches_top_5``. An empty input yields an empty list.

    Note:
        - Only the best-scoring location of each plane is verified.
        - The first matching plane wins for each patch.
    """
    # np.array([]) has no 4-D shape to unpack, so handle empty input up front.
    if len(patches_top_5) == 0:
        return []

    sample = np.array([i.transpose(1, 2, 0) for i in patches_top_5])
    coords = []
    rows, h, w, _ = sample.shape

    for i in range(rows):
        template = sample[i, :, :, 0].astype(np.float32)
        for k in range(parent_image.shape[2]):
            img_slice = parent_image[:, :, k].astype(np.float32)
            res = cv2.matchTemplate(img_slice, template, cv2.TM_CCOEFF_NORMED)
            _, max_val, _, max_loc = cv2.minMaxLoc(res)

            if max_val < 0.99:
                continue

            x, y = max_loc  # OpenCV returns (col, row) -> (x, y)

            # Verification step: a high correlation score is not enough, the
            # candidate region must also match the template value-for-value.
            candidate_patch = img_slice[y : y + h, x : x + w]
            if np.allclose(candidate_patch, template, atol=1e-5):
                coords.append((y, x, k))  # stored as (row, col, slice)
                break
        else:
            # No plane of the parent image contained this patch.
            logging.warning("Patch not found")

    return coords
|
| 139 |
|
|
|
|
| 141 |
def get_parent_image(temp_data_list, args):
|
| 142 |
transform_image = Compose(
|
| 143 |
[
|
| 144 |
+
LoadImaged(
|
| 145 |
+
keys=["image", "mask"],
|
| 146 |
+
reader=ITKReader(),
|
| 147 |
+
ensure_channel_first=True,
|
| 148 |
+
dtype=np.float32,
|
| 149 |
+
),
|
| 150 |
ClipMaskIntensityPercentilesd(keys=["image"], lower=0, upper=99.5, mask_key="mask"),
|
| 151 |
NormalizeIntensity_customd(keys=["image"], mask_key="mask", channel_wise=True),
|
| 152 |
EnsureTyped(keys=["label"], dtype=torch.float32),
|
|
|
|
| 154 |
]
|
| 155 |
)
|
| 156 |
dataset_image = Dataset(data=temp_data_list, transform=transform_image)
|
| 157 |
+
return dataset_image[0]["image"][0].numpy()
|
| 158 |
+
|
| 159 |
|
| 160 |
+
"""
|
| 161 |
def visualise_patches():
|
| 162 |
sample = np.array([i.transpose(1,2,0) for i in patches_top_5])
|
| 163 |
rows = len(patches_top_5)
|
|
|
|
| 205 |
plt.tight_layout()
|
| 206 |
plt.show()
|
| 207 |
a=1
|
| 208 |
+
"""
|
temp.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tests/test_run.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _run_script(script_name, config_name, mode):
    """Run a repository entry-point script with ``--dry_run True``.

    Args:
        script_name: File name of the script in the repository root.
        config_name: File name of the YAML config inside ``config/``.
        mode: Value passed to the script's ``--mode`` option.

    Returns:
        The ``subprocess.CompletedProcess`` of the run, with captured text output.
    """
    repo_root = Path(__file__).parent.parent
    script_path = repo_root / script_name
    config_path = repo_root / "config" / config_name

    # Fail with a clear message before launching the subprocess.
    assert config_path.exists(), f"Config file not found: {config_path}"

    return subprocess.run(
        [sys.executable, str(script_path), "--mode", mode, "--config", str(config_path), "--dry_run", "True"],
        capture_output=True,
        text=True,
    )


def test_run_pirads_training():
    """
    Test that run_pirads.py runs in train mode without crashing using an existing YAML config.
    """
    result = _run_script("run_pirads.py", "config_pirads_train.yaml", "train")

    # Check that it ran without errors
    assert result.returncode == 0, f"Script failed with:\n{result.stderr}"


def test_run_pirads_inference():
    """
    Test that run_pirads.py runs in test mode without crashing using an existing YAML config.
    """
    result = _run_script("run_pirads.py", "config_pirads_test.yaml", "test")

    # Check that it ran without errors
    assert result.returncode == 0, f"Script failed with:\n{result.stderr}"


def test_run_cspca_training():
    """
    Test that run_cspca.py runs in train mode without crashing using an existing YAML config.
    """
    result = _run_script("run_cspca.py", "config_cspca_train.yaml", "train")

    # Check that it ran without errors
    assert result.returncode == 0, f"Script failed with:\n{result.stderr}"


def test_run_cspca_inference():
    """
    Test that run_cspca.py runs in test mode without crashing using an existing YAML config.
    """
    result = _run_script("run_cspca.py", "config_cspca_test.yaml", "test")

    # Check that it ran without errors
    assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
|
| 105 |
+
|
| 106 |
+
|
tests/test_run_cspca.py
DELETED
|
@@ -1,28 +0,0 @@
|
|
| 1 |
-
import subprocess
|
| 2 |
-
import sys
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
|
| 5 |
-
def test_run_cspca_with_existing_config():
|
| 6 |
-
"""
|
| 7 |
-
Test that run_cspca.py runs without crashing using an existing YAML config.
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
# Path to your run_cspca.py script
|
| 11 |
-
repo_root = Path(__file__).parent.parent
|
| 12 |
-
script_path = repo_root / "run_cspca.py"
|
| 13 |
-
|
| 14 |
-
# Path to your existing config.yaml
|
| 15 |
-
config_path = repo_root / "config" / "config_cspca_test.yaml" # adjust this path
|
| 16 |
-
|
| 17 |
-
# Make sure the file exists
|
| 18 |
-
assert config_path.exists(), f"Config file not found: {config_path}"
|
| 19 |
-
|
| 20 |
-
# Run the script with the config
|
| 21 |
-
result = subprocess.run(
|
| 22 |
-
[sys.executable, str(script_path), "--mode","test", "--config", str(config_path)],
|
| 23 |
-
capture_output=True,
|
| 24 |
-
text=True,
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
-
# Check that it ran without errors
|
| 28 |
-
assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_run_pirads.py
DELETED
|
@@ -1,28 +0,0 @@
|
|
| 1 |
-
import subprocess
|
| 2 |
-
import sys
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
|
| 5 |
-
def test_run_cspca_with_existing_config():
|
| 6 |
-
"""
|
| 7 |
-
Test that run_cspca.py runs without crashing using an existing YAML config.
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
# Path to your run_pirads.py script
|
| 11 |
-
repo_root = Path(__file__).parent.parent
|
| 12 |
-
script_path = repo_root / "run_pirads.py"
|
| 13 |
-
|
| 14 |
-
# Path to your existing config.yaml
|
| 15 |
-
config_path = repo_root / "config" / "config_pirads_test.yaml" # adjust this path
|
| 16 |
-
|
| 17 |
-
# Make sure the file exists
|
| 18 |
-
assert config_path.exists(), f"Config file not found: {config_path}"
|
| 19 |
-
|
| 20 |
-
# Run the script with the config
|
| 21 |
-
result = subprocess.run(
|
| 22 |
-
[sys.executable, str(script_path), "--mode","test", "--config", str(config_path)],
|
| 23 |
-
capture_output=True,
|
| 24 |
-
text=True,
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
-
# Check that it ran without errors
|
| 28 |
-
assert result.returncode == 0, f"Script failed with:\n{result.stderr}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|