Divyanshu Tak commited on
Commit ·
5a169ab
1
Parent(s): 50a5e7b
V0-commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- src/.DS_Store +0 -0
- src/BrainIAC/.DS_Store +0 -0
- src/BrainIAC/Brainage/README.md +55 -0
- src/BrainIAC/Brainage/__init__.py +0 -0
- src/BrainIAC/Brainage/__pycache__/__init__.cpython-39.pyc +0 -0
- src/BrainIAC/Brainage/__pycache__/infer_brainage.cpython-39.pyc +0 -0
- src/BrainIAC/Brainage/brainage.jpeg +3 -0
- src/BrainIAC/Brainage/infer_brainage.py +85 -0
- src/BrainIAC/Brainage/train_brainage.py +230 -0
- src/BrainIAC/HD_BET/__pycache__/config.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/config.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/config.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/paths.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/paths.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/paths.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/run.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/run.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/run.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/utils.cpython-310.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/utils.cpython-38.pyc +0 -0
- src/BrainIAC/HD_BET/__pycache__/utils.cpython-39.pyc +0 -0
- src/BrainIAC/HD_BET/config.py +121 -0
- src/BrainIAC/HD_BET/data_loading.py +121 -0
- src/BrainIAC/HD_BET/hd_bet.py +119 -0
- src/BrainIAC/HD_BET/network_architecture.py +213 -0
- src/BrainIAC/HD_BET/paths.py +6 -0
- src/BrainIAC/HD_BET/predict_case.py +126 -0
- src/BrainIAC/HD_BET/run.py +117 -0
- src/BrainIAC/HD_BET/utils.py +115 -0
- src/BrainIAC/MCIclassification/README.md +52 -0
- src/BrainIAC/MCIclassification/__init__.py +0 -0
- src/BrainIAC/MCIclassification/__pycache__/__init__.cpython-39.pyc +0 -0
- src/BrainIAC/MCIclassification/__pycache__/infer_mci.cpython-39.pyc +0 -0
- src/BrainIAC/MCIclassification/infer_mci.py +142 -0
- src/BrainIAC/MCIclassification/mci.jpeg +3 -0
- src/BrainIAC/MCIclassification/train_mci.py +265 -0
- src/BrainIAC/__init__.py +0 -0
- src/BrainIAC/__pycache__/dataset2.cpython-39.pyc +0 -0
- src/BrainIAC/__pycache__/load_brainiac.cpython-39.pyc +0 -0
src/.DS_Store
ADDED
|
Binary file (8.2 kB). View file
|
|
|
src/BrainIAC/.DS_Store
ADDED
|
Binary file (8.2 kB). View file
|
|
|
src/BrainIAC/Brainage/README.md
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Brain Age Prediction
|
| 2 |
+
|
| 3 |
+
<p align="left">
|
| 4 |
+
<img src="brainage.jpeg" width="200" alt="Brain Age Prediction Example"/>
|
| 5 |
+
</p>
|
| 6 |
+
|
| 7 |
+
## Overview
|
| 8 |
+
|
| 9 |
+
We present the brainage prediction training and inference code for BrainIAC as a downstream task. The pipeline is trained and infered on T1 scans, with MAE as evaluation metric.
|
| 10 |
+
|
| 11 |
+
## Data Requirements
|
| 12 |
+
|
| 13 |
+
- **Input**: T1-weighted MRI scans
|
| 14 |
+
- **Format**: NIFTI (.nii.gz)
|
| 15 |
+
- **Preprocessing**: Bias field corrected, registered to standard space, skull stripped
|
| 16 |
+
- **CSV Structure**:
|
| 17 |
+
```
|
| 18 |
+
pat_id,scandate,label
|
| 19 |
+
subject001,20240101,65 # brain age in years
|
| 20 |
+
```
|
| 21 |
+
refer to [ quickstart.ipynb](../quickstart.ipynb) to find how to preprocess data and generate csv file.
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
## Setup
|
| 25 |
+
|
| 26 |
+
1. **Configuration**:
|
| 27 |
+
change the [config.yml](../config.yml) file accordingly.
|
| 28 |
+
```yaml
|
| 29 |
+
# config.yml
|
| 30 |
+
data:
|
| 31 |
+
train_csv: "path/to/train.csv"
|
| 32 |
+
val_csv: "path/to/val.csv"
|
| 33 |
+
test_csv: "path/to/test.csv"
|
| 34 |
+
root_dir: "../data/sample/processed"
|
| 35 |
+
collate: 1 # single scan framework
|
| 36 |
+
|
| 37 |
+
checkpoints: "./checkpoints/brainage_model.00" # for inference/testing
|
| 38 |
+
|
| 39 |
+
train:
|
| 40 |
+
finetune: 'yes' # yes to finetune the entire model
|
| 41 |
+
freeze: 'no' # yes to freeze the resnet backbone
|
| 42 |
+
weights: ./checkpoints/brainiac.ckpt # path to brainiac weights
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
2. **Training**:
|
| 47 |
+
```bash
|
| 48 |
+
python -m Brainage.train_brainage
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
3. **Inference**:
|
| 52 |
+
```bash
|
| 53 |
+
python -m Brainage.infer_brainage
|
| 54 |
+
```
|
| 55 |
+
|
src/BrainIAC/Brainage/__init__.py
ADDED
|
File without changes
|
src/BrainIAC/Brainage/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (145 Bytes). View file
|
|
|
src/BrainIAC/Brainage/__pycache__/infer_brainage.cpython-39.pyc
ADDED
|
Binary file (3 kB). View file
|
|
|
src/BrainIAC/Brainage/brainage.jpeg
ADDED
|
Git LFS Details
|
src/BrainIAC/Brainage/infer_brainage.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import os
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torch.cuda.amp import autocast
|
| 7 |
+
from sklearn.metrics import mean_absolute_error
|
| 8 |
+
import sys
|
| 9 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 10 |
+
from dataset2 import MedicalImageDatasetBalancedIntensity3D
|
| 11 |
+
from model import Backbone, SingleScanModel, Classifier
|
| 12 |
+
from utils import BaseConfig
|
| 13 |
+
|
| 14 |
+
class BrainAgeInference(BaseConfig):
|
| 15 |
+
"""
|
| 16 |
+
Inference class for brain age prediction model.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self):
|
| 20 |
+
"""Initialize the inference setup with model and data."""
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.setup_model()
|
| 23 |
+
self.setup_data()
|
| 24 |
+
|
| 25 |
+
def setup_model(self):
|
| 26 |
+
config = self.get_config()
|
| 27 |
+
self.backbone = Backbone()
|
| 28 |
+
self.classifier = Classifier(d_model=2048)
|
| 29 |
+
self.model = SingleScanModel(self.backbone, self.classifier)
|
| 30 |
+
|
| 31 |
+
# Load weights
|
| 32 |
+
checkpoint = torch.load(config["infer"]["checkpoints"], map_location=self.device)
|
| 33 |
+
self.model.load_state_dict(checkpoint["model_state_dict"])
|
| 34 |
+
self.model = self.model.to(self.device)
|
| 35 |
+
self.model.eval()
|
| 36 |
+
print("Model and checkpoint loaded!")
|
| 37 |
+
|
| 38 |
+
## spinup dataloaders
|
| 39 |
+
def setup_data(self):
|
| 40 |
+
config = self.get_config()
|
| 41 |
+
self.test_dataset = MedicalImageDatasetBalancedIntensity3D(
|
| 42 |
+
csv_path=config["data"]["test_csv"],
|
| 43 |
+
root_dir=config["data"]["root_dir"]
|
| 44 |
+
)
|
| 45 |
+
self.test_loader = DataLoader(
|
| 46 |
+
self.test_dataset,
|
| 47 |
+
batch_size=1,
|
| 48 |
+
shuffle=False,
|
| 49 |
+
collate_fn=self.custom_collate,
|
| 50 |
+
num_workers=1
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
def infer(self):
|
| 54 |
+
""" Infer pass """
|
| 55 |
+
results_df = pd.DataFrame(columns=['PredictedAge', 'TrueAge'])
|
| 56 |
+
all_labels = []
|
| 57 |
+
all_predictions = []
|
| 58 |
+
|
| 59 |
+
with torch.no_grad():
|
| 60 |
+
for sample in tqdm(self.test_loader, desc="Inference", unit="batch"):
|
| 61 |
+
inputs = sample['image'].to(self.device)
|
| 62 |
+
labels = sample['label'].float().to(self.device)
|
| 63 |
+
|
| 64 |
+
with autocast():
|
| 65 |
+
outputs = self.model(inputs)
|
| 66 |
+
|
| 67 |
+
predictions = outputs.cpu().numpy().flatten()
|
| 68 |
+
all_labels.extend(labels.cpu().numpy().flatten())
|
| 69 |
+
all_predictions.extend(predictions)
|
| 70 |
+
|
| 71 |
+
result = pd.DataFrame({
|
| 72 |
+
'PredictedAge': predictions,
|
| 73 |
+
'TrueAge': labels.cpu().numpy().flatten()
|
| 74 |
+
})
|
| 75 |
+
results_df = pd.concat([results_df, result], ignore_index=True)
|
| 76 |
+
|
| 77 |
+
mae = mean_absolute_error(all_labels, all_predictions)
|
| 78 |
+
print(f"Mean Absolute Error (MAE): {mae:.4f} months")
|
| 79 |
+
results_df.to_csv('./data/output/brainage_output.csv', index=False)
|
| 80 |
+
|
| 81 |
+
return mae
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
inferencer = BrainAgeInference()
|
| 85 |
+
mae = inferencer.infer()
|
src/BrainIAC/Brainage/train_brainage.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import wandb
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
| 8 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 9 |
+
from sklearn.metrics import mean_absolute_error
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 13 |
+
from dataset2 import MedicalImageDatasetBalancedIntensity3D, TransformationMedicalImageDatasetBalancedIntensity3D
|
| 14 |
+
from model import Backbone, SingleScanModel, Classifier
|
| 15 |
+
from utils import BaseConfig
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class BrainAgeTrainer(BaseConfig):
|
| 19 |
+
"""
|
| 20 |
+
A trainer class for brain age prediction models.
|
| 21 |
+
|
| 22 |
+
This class handles the complete training pipeline including model setup,
|
| 23 |
+
data loading, training loop, and validation.
|
| 24 |
+
Inherits from BaseConfig for configuration management.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self):
|
| 28 |
+
"""Initialize the trainer with model, data, and training setup."""
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.setup_wandb()
|
| 31 |
+
self.setup_model()
|
| 32 |
+
self.setup_data()
|
| 33 |
+
self.setup_training()
|
| 34 |
+
|
| 35 |
+
## setup wandb logger
|
| 36 |
+
def setup_wandb(self):
|
| 37 |
+
config = self.get_config()
|
| 38 |
+
wandb.init(
|
| 39 |
+
project=config['logger']['project_name'],
|
| 40 |
+
name=config['logger']['run_name'],
|
| 41 |
+
config=config
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
def setup_model(self):
|
| 45 |
+
"""
|
| 46 |
+
Set up the model architecture.
|
| 47 |
+
|
| 48 |
+
Initializes the backbone and classifier blocks, and loads
|
| 49 |
+
checkpoints
|
| 50 |
+
"""
|
| 51 |
+
self.backbone = Backbone()
|
| 52 |
+
self.classifier = Classifier(d_model=2048)
|
| 53 |
+
self.model = SingleScanModel(self.backbone, self.classifier)
|
| 54 |
+
|
| 55 |
+
# Load BrainIACs weights
|
| 56 |
+
config = self.get_config()
|
| 57 |
+
if config["train"]["finetune"] == "yes":
|
| 58 |
+
checkpoint = torch.load(config["train"]["weights"], map_location=self.device)
|
| 59 |
+
state_dict = checkpoint["state_dict"]
|
| 60 |
+
filtered_state_dict = {}
|
| 61 |
+
for key, value in state_dict.items():
|
| 62 |
+
new_key = key.replace("module.", "backbone.") if key.startswith("module.") else key
|
| 63 |
+
filtered_state_dict[new_key] = value
|
| 64 |
+
self.model.backbone.load_state_dict(filtered_state_dict, strict=False)
|
| 65 |
+
print("Pretrained weights loaded!")
|
| 66 |
+
|
| 67 |
+
# Freeze backbone if specified
|
| 68 |
+
if config["train"]["freeze"] == "yes":
|
| 69 |
+
for param in self.model.backbone.parameters():
|
| 70 |
+
param.requires_grad = False
|
| 71 |
+
print("Backbone weights frozen!")
|
| 72 |
+
|
| 73 |
+
self.model = self.model.to(self.device)
|
| 74 |
+
|
| 75 |
+
def setup_data(self):
|
| 76 |
+
"""
|
| 77 |
+
Set up data loaders for training and validation.
|
| 78 |
+
Inherit configuration from the base config
|
| 79 |
+
"""
|
| 80 |
+
config = self.get_config()
|
| 81 |
+
self.train_dataset = TransformationMedicalImageDatasetBalancedIntensity3D(
|
| 82 |
+
csv_path=config['data']['train_csv'],
|
| 83 |
+
root_dir=config["data"]["root_dir"]
|
| 84 |
+
)
|
| 85 |
+
self.val_dataset = MedicalImageDatasetBalancedIntensity3D(
|
| 86 |
+
csv_path=config['data']['val_csv'],
|
| 87 |
+
root_dir=config["data"]["root_dir"]
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
self.train_loader = DataLoader(
|
| 91 |
+
self.train_dataset,
|
| 92 |
+
batch_size=config["data"]["batch_size"],
|
| 93 |
+
shuffle=True,
|
| 94 |
+
collate_fn=self.custom_collate,
|
| 95 |
+
num_workers=config["data"]["num_workers"]
|
| 96 |
+
)
|
| 97 |
+
self.val_loader = DataLoader(
|
| 98 |
+
self.val_dataset,
|
| 99 |
+
batch_size=1,
|
| 100 |
+
shuffle=False,
|
| 101 |
+
collate_fn=self.custom_collate,
|
| 102 |
+
num_workers=1
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def setup_training(self):
|
| 106 |
+
"""
|
| 107 |
+
Set up training config with loss, scheduler, optimizer.
|
| 108 |
+
"""
|
| 109 |
+
config = self.get_config()
|
| 110 |
+
self.criterion = nn.MSELoss()
|
| 111 |
+
self.optimizer = optim.Adam(
|
| 112 |
+
self.model.parameters(),
|
| 113 |
+
lr=config['optim']['lr'],
|
| 114 |
+
weight_decay=config["optim"]["weight_decay"]
|
| 115 |
+
)
|
| 116 |
+
self.scheduler = CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
|
| 117 |
+
self.scaler = GradScaler()
|
| 118 |
+
|
| 119 |
+
def train(self):
|
| 120 |
+
"""
|
| 121 |
+
main training loop
|
| 122 |
+
"""
|
| 123 |
+
config = self.get_config()
|
| 124 |
+
max_epochs = config['optim']['max_epochs']
|
| 125 |
+
best_val_loss = float('inf')
|
| 126 |
+
best_val_mae = float('inf')
|
| 127 |
+
|
| 128 |
+
for epoch in range(max_epochs):
|
| 129 |
+
train_loss = self.train_epoch(epoch, max_epochs)
|
| 130 |
+
val_loss, mae = self.validate_epoch(epoch, max_epochs)
|
| 131 |
+
|
| 132 |
+
# Save best model
|
| 133 |
+
if (val_loss <= best_val_loss) and (mae <= best_val_mae):
|
| 134 |
+
print(f"Improved Val Loss from {best_val_loss:.4f} to {val_loss:.4f}")
|
| 135 |
+
print(f"Improved Val MAE from {best_val_mae:.4f} to {mae:.4f}")
|
| 136 |
+
best_val_loss = val_loss
|
| 137 |
+
best_val_mae = mae
|
| 138 |
+
self.save_checkpoint(epoch, val_loss, mae)
|
| 139 |
+
|
| 140 |
+
wandb.finish()
|
| 141 |
+
|
| 142 |
+
def train_epoch(self, epoch, max_epochs):
|
| 143 |
+
"""
|
| 144 |
+
Train pass.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
epoch (int): Current epoch number
|
| 148 |
+
max_epochs (int): Total number of epochs
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
float: Average training loss for the epoch
|
| 152 |
+
"""
|
| 153 |
+
self.model.train()
|
| 154 |
+
train_loss = 0.0
|
| 155 |
+
|
| 156 |
+
for sample in tqdm(self.train_loader, desc=f"Training Epoch {epoch}/{max_epochs-1}"):
|
| 157 |
+
inputs = sample['image'].to(self.device)
|
| 158 |
+
labels = sample['label'].float().to(self.device)
|
| 159 |
+
|
| 160 |
+
self.optimizer.zero_grad()
|
| 161 |
+
with autocast():
|
| 162 |
+
outputs = self.model(inputs)
|
| 163 |
+
loss = self.criterion(outputs, labels.unsqueeze(1))
|
| 164 |
+
|
| 165 |
+
self.scaler.scale(loss).backward()
|
| 166 |
+
self.scaler.step(self.optimizer)
|
| 167 |
+
self.scaler.update()
|
| 168 |
+
|
| 169 |
+
train_loss += loss.item() * inputs.size(0)
|
| 170 |
+
|
| 171 |
+
train_loss = train_loss / len(self.train_loader.dataset)
|
| 172 |
+
wandb.log({"Train Loss": train_loss})
|
| 173 |
+
return train_loss
|
| 174 |
+
|
| 175 |
+
def validate_epoch(self, epoch, max_epochs):
|
| 176 |
+
"""
|
| 177 |
+
Validation pass.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
epoch (int): Current epoch number
|
| 181 |
+
max_epochs (int): Total number of epochs
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
tuple: (validation_loss, mean_absolute_error)
|
| 185 |
+
"""
|
| 186 |
+
self.model.eval()
|
| 187 |
+
val_loss = 0.0
|
| 188 |
+
all_labels = []
|
| 189 |
+
all_preds = []
|
| 190 |
+
|
| 191 |
+
with torch.no_grad():
|
| 192 |
+
for sample in tqdm(self.val_loader, desc=f"Validation Epoch {epoch}/{max_epochs-1}"):
|
| 193 |
+
inputs = sample['image'].to(self.device)
|
| 194 |
+
labels = sample['label'].float().to(self.device)
|
| 195 |
+
|
| 196 |
+
outputs = self.model(inputs)
|
| 197 |
+
loss = self.criterion(outputs, labels.unsqueeze(1))
|
| 198 |
+
|
| 199 |
+
val_loss += loss.item() * inputs.size(0)
|
| 200 |
+
all_labels.extend(labels.cpu().numpy().flatten())
|
| 201 |
+
all_preds.extend(outputs.cpu().numpy().flatten())
|
| 202 |
+
|
| 203 |
+
val_loss = val_loss / len(self.val_loader.dataset)
|
| 204 |
+
mae = mean_absolute_error(all_labels, all_preds)
|
| 205 |
+
|
| 206 |
+
wandb.log({"Val Loss": val_loss, "MAE": mae})
|
| 207 |
+
self.scheduler.step(val_loss)
|
| 208 |
+
|
| 209 |
+
print(f"Epoch {epoch}/{max_epochs-1} Val Loss: {val_loss:.4f} MAE: {mae:.4f}")
|
| 210 |
+
return val_loss, mae
|
| 211 |
+
|
| 212 |
+
def save_checkpoint(self, epoch, loss, mae):
|
| 213 |
+
"""
|
| 214 |
+
Save model checkpoint.
|
| 215 |
+
"""
|
| 216 |
+
config = self.get_config()
|
| 217 |
+
checkpoint = {
|
| 218 |
+
'model_state_dict': self.model.state_dict(),
|
| 219 |
+
'loss': loss,
|
| 220 |
+
'epoch': epoch,
|
| 221 |
+
}
|
| 222 |
+
save_path = os.path.join(
|
| 223 |
+
config['logger']['save_dir'],
|
| 224 |
+
config['logger']['save_name'].format(epoch=epoch, loss=loss, metric=mae)
|
| 225 |
+
)
|
| 226 |
+
torch.save(checkpoint, save_path)
|
| 227 |
+
|
| 228 |
+
if __name__ == "__main__":
|
| 229 |
+
trainer = BrainAgeTrainer()
|
| 230 |
+
trainer.train()
|
src/BrainIAC/HD_BET/__pycache__/config.cpython-310.pyc
ADDED
|
Binary file (4.15 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/config.cpython-38.pyc
ADDED
|
Binary file (4.13 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/config.cpython-39.pyc
ADDED
|
Binary file (4.19 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-310.pyc
ADDED
|
Binary file (4.47 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-38.pyc
ADDED
|
Binary file (4.48 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-39.pyc
ADDED
|
Binary file (4.46 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-310.pyc
ADDED
|
Binary file (4.21 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-38.pyc
ADDED
|
Binary file (4.27 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-310.pyc
ADDED
|
Binary file (6.78 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-38.pyc
ADDED
|
Binary file (6.89 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-39.pyc
ADDED
|
Binary file (6.84 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/paths.cpython-310.pyc
ADDED
|
Binary file (324 Bytes). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/paths.cpython-38.pyc
ADDED
|
Binary file (335 Bytes). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/paths.cpython-39.pyc
ADDED
|
Binary file (322 Bytes). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-310.pyc
ADDED
|
Binary file (3.68 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-38.pyc
ADDED
|
Binary file (3.67 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-39.pyc
ADDED
|
Binary file (3.68 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/run.cpython-310.pyc
ADDED
|
Binary file (3.83 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/run.cpython-38.pyc
ADDED
|
Binary file (3.88 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/run.cpython-39.pyc
ADDED
|
Binary file (3.85 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (4.68 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/utils.cpython-38.pyc
ADDED
|
Binary file (4.85 kB). View file
|
|
|
src/BrainIAC/HD_BET/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (4.81 kB). View file
|
|
|
src/BrainIAC/HD_BET/config.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from HD_BET.utils import SetNetworkToVal, softmax_helper
|
| 4 |
+
from abc import abstractmethod
|
| 5 |
+
from HD_BET.network_architecture import Network
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BaseConfig(object):
|
| 9 |
+
def __init__(self):
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
@abstractmethod
|
| 13 |
+
def get_split(self, fold, random_state=12345):
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
@abstractmethod
|
| 17 |
+
def get_network(self, mode="train"):
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
@abstractmethod
|
| 21 |
+
def get_basic_generators(self, fold):
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
@abstractmethod
|
| 25 |
+
def get_data_generators(self, fold):
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
def preprocess(self, data):
|
| 29 |
+
return data
|
| 30 |
+
|
| 31 |
+
def __repr__(self):
|
| 32 |
+
res = ""
|
| 33 |
+
for v in vars(self):
|
| 34 |
+
if not v.startswith("__") and not v.startswith("_") and v != 'dataset':
|
| 35 |
+
res += (v + ": " + str(self.__getattribute__(v)) + "\n")
|
| 36 |
+
return res
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class HD_BET_Config(BaseConfig):
|
| 40 |
+
def __init__(self):
|
| 41 |
+
super(HD_BET_Config, self).__init__()
|
| 42 |
+
|
| 43 |
+
self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name
|
| 44 |
+
|
| 45 |
+
# network parameters
|
| 46 |
+
self.net_base_num_layers = 21
|
| 47 |
+
self.BATCH_SIZE = 2
|
| 48 |
+
self.net_do_DS = True
|
| 49 |
+
self.net_dropout_p = 0.0
|
| 50 |
+
self.net_use_inst_norm = True
|
| 51 |
+
self.net_conv_use_bias = True
|
| 52 |
+
self.net_norm_use_affine = True
|
| 53 |
+
self.net_leaky_relu_slope = 1e-1
|
| 54 |
+
|
| 55 |
+
# hyperparameters
|
| 56 |
+
self.INPUT_PATCH_SIZE = (128, 128, 128)
|
| 57 |
+
self.num_classes = 2
|
| 58 |
+
self.selected_data_channels = range(1)
|
| 59 |
+
|
| 60 |
+
# data augmentation
|
| 61 |
+
self.da_mirror_axes = (2, 3, 4)
|
| 62 |
+
|
| 63 |
+
# validation
|
| 64 |
+
self.val_use_DO = False
|
| 65 |
+
self.val_use_train_mode = False # for dropout sampling
|
| 66 |
+
self.val_num_repeats = 1 # only useful if dropout sampling
|
| 67 |
+
self.val_batch_size = 1 # only useful if dropout sampling
|
| 68 |
+
self.val_save_npz = True
|
| 69 |
+
self.val_do_mirroring = True # test time data augmentation via mirroring
|
| 70 |
+
self.val_write_images = True
|
| 71 |
+
self.net_input_must_be_divisible_by = 16 # we could make a network class that has this as a property
|
| 72 |
+
self.val_min_size = self.INPUT_PATCH_SIZE
|
| 73 |
+
self.val_fn = None
|
| 74 |
+
|
| 75 |
+
# CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting keep_runnings_
|
| 76 |
+
# stats=True but not using them in validation. keep_runnings_stats was True before 0.3 but unused and defaults
|
| 77 |
+
# to false in 0.4)
|
| 78 |
+
self.val_use_moving_averages = False
|
| 79 |
+
|
| 80 |
+
def get_network(self, train=True, pretrained_weights=None):
|
| 81 |
+
net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
|
| 82 |
+
self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
|
| 83 |
+
self.net_norm_use_affine, True, self.net_do_DS)
|
| 84 |
+
|
| 85 |
+
if pretrained_weights is not None:
|
| 86 |
+
net.load_state_dict(
|
| 87 |
+
torch.load(pretrained_weights, map_location=lambda storage, loc: storage))
|
| 88 |
+
|
| 89 |
+
if train:
|
| 90 |
+
net.train(True)
|
| 91 |
+
else:
|
| 92 |
+
net.train(False)
|
| 93 |
+
net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
|
| 94 |
+
net.do_ds = False
|
| 95 |
+
|
| 96 |
+
optimizer = None
|
| 97 |
+
self.lr_scheduler = None
|
| 98 |
+
return net, optimizer
|
| 99 |
+
|
| 100 |
+
def get_data_generators(self, fold):
|
| 101 |
+
pass
|
| 102 |
+
|
| 103 |
+
def get_split(self, fold, random_state=12345):
|
| 104 |
+
pass
|
| 105 |
+
|
| 106 |
+
def get_basic_generators(self, fold):
|
| 107 |
+
pass
|
| 108 |
+
|
| 109 |
+
def on_epoch_end(self, epoch):
|
| 110 |
+
pass
|
| 111 |
+
|
| 112 |
+
def preprocess(self, data):
|
| 113 |
+
data = np.copy(data)
|
| 114 |
+
for c in range(data.shape[0]):
|
| 115 |
+
data[c] -= data[c].mean()
|
| 116 |
+
data[c] /= data[c].std()
|
| 117 |
+
return data
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
config = HD_BET_Config
|
| 121 |
+
|
src/BrainIAC/HD_BET/data_loading.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import SimpleITK as sitk
|
| 2 |
+
import numpy as np
|
| 3 |
+
from skimage.transform import resize
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def resize_image(image, old_spacing, new_spacing, order=3):
|
| 7 |
+
new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))),
|
| 8 |
+
int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))),
|
| 9 |
+
int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2]))))
|
| 10 |
+
return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)):
|
| 14 |
+
spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]]
|
| 15 |
+
image = sitk.GetArrayFromImage(itk_image).astype(float)
|
| 16 |
+
|
| 17 |
+
assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed"
|
| 18 |
+
|
| 19 |
+
if not is_seg:
|
| 20 |
+
if np.any([[i != j] for i, j in zip(spacing, spacing_target)]):
|
| 21 |
+
image = resize_image(image, spacing, spacing_target).astype(np.float32)
|
| 22 |
+
|
| 23 |
+
image -= image.mean()
|
| 24 |
+
image /= image.std()
|
| 25 |
+
else:
|
| 26 |
+
new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))),
|
| 27 |
+
int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))),
|
| 28 |
+
int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))))
|
| 29 |
+
image = resize_segmentation(image, new_shape, 1)
|
| 30 |
+
return image
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_and_preprocess(mri_file):
|
| 34 |
+
images = {}
|
| 35 |
+
# t1
|
| 36 |
+
images["T1"] = sitk.ReadImage(mri_file)
|
| 37 |
+
|
| 38 |
+
properties_dict = {
|
| 39 |
+
"spacing": images["T1"].GetSpacing(),
|
| 40 |
+
"direction": images["T1"].GetDirection(),
|
| 41 |
+
"size": images["T1"].GetSize(),
|
| 42 |
+
"origin": images["T1"].GetOrigin()
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
for k in images.keys():
|
| 46 |
+
images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5))
|
| 47 |
+
|
| 48 |
+
properties_dict['size_before_cropping'] = images["T1"].shape
|
| 49 |
+
|
| 50 |
+
imgs = []
|
| 51 |
+
for seq in ['T1']:
|
| 52 |
+
imgs.append(images[seq][None])
|
| 53 |
+
all_data = np.vstack(imgs)
|
| 54 |
+
print("image shape after preprocessing: ", str(all_data[0].shape))
|
| 55 |
+
return all_data, properties_dict
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def save_segmentation_nifti(segmentation, dct, out_fname, order=1):
|
| 59 |
+
'''
|
| 60 |
+
segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out
|
| 61 |
+
of the original image
|
| 62 |
+
|
| 63 |
+
dct:
|
| 64 |
+
size_before_cropping
|
| 65 |
+
brain_bbox
|
| 66 |
+
size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping
|
| 67 |
+
spacing
|
| 68 |
+
origin
|
| 69 |
+
direction
|
| 70 |
+
|
| 71 |
+
:param segmentation:
|
| 72 |
+
:param dct:
|
| 73 |
+
:param out_fname:
|
| 74 |
+
:return:
|
| 75 |
+
'''
|
| 76 |
+
old_size = dct.get('size_before_cropping')
|
| 77 |
+
bbox = dct.get('brain_bbox')
|
| 78 |
+
if bbox is not None:
|
| 79 |
+
seg_old_size = np.zeros(old_size)
|
| 80 |
+
for c in range(3):
|
| 81 |
+
bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c]))
|
| 82 |
+
seg_old_size[bbox[0][0]:bbox[0][1],
|
| 83 |
+
bbox[1][0]:bbox[1][1],
|
| 84 |
+
bbox[2][0]:bbox[2][1]] = segmentation
|
| 85 |
+
else:
|
| 86 |
+
seg_old_size = segmentation
|
| 87 |
+
if np.any(np.array(seg_old_size) != np.array(dct['size'])[[2, 1, 0]]):
|
| 88 |
+
seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order)
|
| 89 |
+
else:
|
| 90 |
+
seg_old_spacing = seg_old_size
|
| 91 |
+
seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32))
|
| 92 |
+
seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]])
|
| 93 |
+
seg_resized_itk.SetOrigin(dct['origin'])
|
| 94 |
+
seg_resized_itk.SetDirection(dct['direction'])
|
| 95 |
+
sitk.WriteImage(seg_resized_itk, out_fname)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def resize_segmentation(segmentation, new_shape, order=3, cval=0):
|
| 99 |
+
'''
|
| 100 |
+
Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency
|
| 101 |
+
|
| 102 |
+
Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one
|
| 103 |
+
hot encoding which is resized and transformed back to a segmentation map.
|
| 104 |
+
This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
|
| 105 |
+
:param segmentation:
|
| 106 |
+
:param new_shape:
|
| 107 |
+
:param order:
|
| 108 |
+
:return:
|
| 109 |
+
'''
|
| 110 |
+
tpe = segmentation.dtype
|
| 111 |
+
unique_labels = np.unique(segmentation)
|
| 112 |
+
assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
|
| 113 |
+
if order == 0:
|
| 114 |
+
return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe)
|
| 115 |
+
else:
|
| 116 |
+
reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
|
| 117 |
+
|
| 118 |
+
for i, c in enumerate(unique_labels):
|
| 119 |
+
reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
|
| 120 |
+
reshaped[reshaped_multihot >= 0.5] = c
|
| 121 |
+
return reshaped
|
src/BrainIAC/HD_BET/hd_bet.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/")
|
| 6 |
+
from HD_BET.run import run_hd_bet
|
| 7 |
+
from HD_BET.utils import maybe_mkdir_p, subfiles
|
| 8 |
+
import HD_BET
|
| 9 |
+
|
| 10 |
+
def hd_bet(input_file_or_dir,output_file_or_dir,mode,device,tta,pp=1,save_mask=0,overwrite_existing=1):
|
| 11 |
+
|
| 12 |
+
if output_file_or_dir is None:
|
| 13 |
+
output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir),
|
| 14 |
+
os.path.basename(input_file_or_dir).split(".")[0] + "_bet")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
params_file = os.path.join(HD_BET.__path__[0], "model_final.py")
|
| 18 |
+
config_file = os.path.join(HD_BET.__path__[0], "config.py")
|
| 19 |
+
|
| 20 |
+
assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
|
| 21 |
+
|
| 22 |
+
if device == 'cpu':
|
| 23 |
+
pass
|
| 24 |
+
else:
|
| 25 |
+
device = int(device)
|
| 26 |
+
|
| 27 |
+
if os.path.isdir(input_file_or_dir):
|
| 28 |
+
maybe_mkdir_p(output_file_or_dir)
|
| 29 |
+
input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False)
|
| 30 |
+
|
| 31 |
+
if len(input_files) == 0:
|
| 32 |
+
raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here")
|
| 33 |
+
|
| 34 |
+
output_files = [os.path.join(output_file_or_dir, i) for i in input_files]
|
| 35 |
+
input_files = [os.path.join(input_file_or_dir, i) for i in input_files]
|
| 36 |
+
else:
|
| 37 |
+
if not output_file_or_dir.endswith('.nii.gz'):
|
| 38 |
+
output_file_or_dir += '.nii.gz'
|
| 39 |
+
assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
|
| 40 |
+
|
| 41 |
+
output_files = [output_file_or_dir]
|
| 42 |
+
input_files = [input_file_or_dir]
|
| 43 |
+
|
| 44 |
+
if tta == 0:
|
| 45 |
+
tta = False
|
| 46 |
+
elif tta == 1:
|
| 47 |
+
tta = True
|
| 48 |
+
else:
|
| 49 |
+
raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta))
|
| 50 |
+
|
| 51 |
+
if overwrite_existing == 0:
|
| 52 |
+
overwrite_existing = False
|
| 53 |
+
elif overwrite_existing == 1:
|
| 54 |
+
overwrite_existing = True
|
| 55 |
+
else:
|
| 56 |
+
raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing))
|
| 57 |
+
|
| 58 |
+
if pp == 0:
|
| 59 |
+
pp = False
|
| 60 |
+
elif pp == 1:
|
| 61 |
+
pp = True
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
|
| 64 |
+
|
| 65 |
+
if save_mask == 0:
|
| 66 |
+
save_mask = False
|
| 67 |
+
elif save_mask == 1:
|
| 68 |
+
save_mask = True
|
| 69 |
+
else:
|
| 70 |
+
raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
|
| 71 |
+
|
| 72 |
+
run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
|
| 76 |
+
print("\n########################")
|
| 77 |
+
print("If you are using hd-bet, please cite the following paper:")
|
| 78 |
+
print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W,"
|
| 79 |
+
"Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial"
|
| 80 |
+
"neural networks. arXiv preprint arXiv:1901.11341, 2019.")
|
| 81 |
+
print("########################\n")
|
| 82 |
+
|
| 83 |
+
import argparse
|
| 84 |
+
parser = argparse.ArgumentParser()
|
| 85 |
+
parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be '
|
| 86 |
+
'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to '
|
| 87 |
+
'split 4d sequences into 3d images. If folder: all files ending with .nii.gz '
|
| 88 |
+
'within that folder will be brain extracted.', required=True, type=str)
|
| 89 |
+
parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder'
|
| 90 |
+
' will be created', required=False, type=str)
|
| 91 |
+
parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will '
|
| 92 |
+
'use only one set of parameters whereas accurate will '
|
| 93 |
+
'use the five sets of parameters that resulted from '
|
| 94 |
+
'our cross-validation as an ensemble. Default: '
|
| 95 |
+
'accurate',
|
| 96 |
+
required=False)
|
| 97 |
+
parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. '
|
| 98 |
+
'Must be either int or str. Use int for GPU id or '
|
| 99 |
+
'\'cpu\' to run on CPU. When using CPU you should '
|
| 100 |
+
'consider disabling tta. Default for -device is: 0',
|
| 101 |
+
required=False)
|
| 102 |
+
parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation '
|
| 103 |
+
'(mirroring). 1= True, 0=False. Disable this '
|
| 104 |
+
'if you are using CPU to speed things up! '
|
| 105 |
+
'Default: 1')
|
| 106 |
+
parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disabe postprocessing (remove all'
|
| 107 |
+
' but the largest connected component in '
|
| 108 |
+
'the prediction. Default: 1')
|
| 109 |
+
parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation '
|
| 110 |
+
'mask will not be '
|
| 111 |
+
'saved')
|
| 112 |
+
parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't "
|
| 113 |
+
"want to overwrite existing "
|
| 114 |
+
"predictions")
|
| 115 |
+
|
| 116 |
+
args = parser.parse_args()
|
| 117 |
+
|
| 118 |
+
hd_bet(args.input,args.output,args.mode,args.device,args.tta,args.pp,args.save_mask,args.overwrite_existing)
|
| 119 |
+
|
src/BrainIAC/HD_BET/network_architecture.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from HD_BET.utils import softmax_helper
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class EncodingModule(nn.Module):
|
| 8 |
+
def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
|
| 9 |
+
inst_norm_affine=True, lrelu_inplace=True):
|
| 10 |
+
nn.Module.__init__(self)
|
| 11 |
+
self.dropout_p = dropout_p
|
| 12 |
+
self.lrelu_inplace = lrelu_inplace
|
| 13 |
+
self.inst_norm_affine = inst_norm_affine
|
| 14 |
+
self.conv_bias = conv_bias
|
| 15 |
+
self.leakiness = leakiness
|
| 16 |
+
self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
|
| 17 |
+
self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
|
| 18 |
+
self.dropout = nn.Dropout3d(dropout_p)
|
| 19 |
+
self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
|
| 20 |
+
self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
skip = x
|
| 24 |
+
x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
|
| 25 |
+
x = self.conv1(x)
|
| 26 |
+
if self.dropout_p is not None and self.dropout_p > 0:
|
| 27 |
+
x = self.dropout(x)
|
| 28 |
+
x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
|
| 29 |
+
x = self.conv2(x)
|
| 30 |
+
x = x + skip
|
| 31 |
+
return x
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Upsample(nn.Module):
|
| 35 |
+
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
|
| 36 |
+
super(Upsample, self).__init__()
|
| 37 |
+
self.align_corners = align_corners
|
| 38 |
+
self.mode = mode
|
| 39 |
+
self.scale_factor = scale_factor
|
| 40 |
+
self.size = size
|
| 41 |
+
|
| 42 |
+
def forward(self, x):
|
| 43 |
+
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
|
| 44 |
+
align_corners=self.align_corners)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class LocalizationModule(nn.Module):
|
| 48 |
+
def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
|
| 49 |
+
lrelu_inplace=True):
|
| 50 |
+
nn.Module.__init__(self)
|
| 51 |
+
self.lrelu_inplace = lrelu_inplace
|
| 52 |
+
self.inst_norm_affine = inst_norm_affine
|
| 53 |
+
self.conv_bias = conv_bias
|
| 54 |
+
self.leakiness = leakiness
|
| 55 |
+
self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
|
| 56 |
+
self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
|
| 57 |
+
self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
|
| 58 |
+
self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
|
| 59 |
+
|
| 60 |
+
def forward(self, x):
|
| 61 |
+
x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
|
| 62 |
+
x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
|
| 63 |
+
return x
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class UpsamplingModule(nn.Module):
|
| 67 |
+
def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
|
| 68 |
+
lrelu_inplace=True):
|
| 69 |
+
nn.Module.__init__(self)
|
| 70 |
+
self.lrelu_inplace = lrelu_inplace
|
| 71 |
+
self.inst_norm_affine = inst_norm_affine
|
| 72 |
+
self.conv_bias = conv_bias
|
| 73 |
+
self.leakiness = leakiness
|
| 74 |
+
self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
|
| 75 |
+
self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias)
|
| 76 |
+
self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
|
| 77 |
+
|
| 78 |
+
def forward(self, x):
|
| 79 |
+
x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness,
|
| 80 |
+
inplace=self.lrelu_inplace)
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class DownsamplingModule(nn.Module):
|
| 85 |
+
def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
|
| 86 |
+
lrelu_inplace=True):
|
| 87 |
+
nn.Module.__init__(self)
|
| 88 |
+
self.lrelu_inplace = lrelu_inplace
|
| 89 |
+
self.inst_norm_affine = inst_norm_affine
|
| 90 |
+
self.conv_bias = conv_bias
|
| 91 |
+
self.leakiness = leakiness
|
| 92 |
+
self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
|
| 93 |
+
self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)
|
| 94 |
+
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
|
| 97 |
+
b = self.downsample(x)
|
| 98 |
+
return x, b
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class Network(nn.Module):
|
| 102 |
+
def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
|
| 103 |
+
final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
|
| 104 |
+
lrelu_inplace=True, do_ds=True):
|
| 105 |
+
super(Network, self).__init__()
|
| 106 |
+
|
| 107 |
+
self.do_ds = do_ds
|
| 108 |
+
self.lrelu_inplace = lrelu_inplace
|
| 109 |
+
self.inst_norm_affine = inst_norm_affine
|
| 110 |
+
self.conv_bias = conv_bias
|
| 111 |
+
self.leakiness = leakiness
|
| 112 |
+
self.final_nonlin = final_nonlin
|
| 113 |
+
self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)
|
| 114 |
+
|
| 115 |
+
self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
|
| 116 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 117 |
+
self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True,
|
| 118 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 119 |
+
|
| 120 |
+
self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
|
| 121 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 122 |
+
self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True,
|
| 123 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 124 |
+
|
| 125 |
+
self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
|
| 126 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 127 |
+
self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True,
|
| 128 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 129 |
+
|
| 130 |
+
self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
|
| 131 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 132 |
+
self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True,
|
| 133 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 134 |
+
|
| 135 |
+
self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2,
|
| 136 |
+
conv_bias=True, inst_norm_affine=True, lrelu_inplace=True)
|
| 137 |
+
|
| 138 |
+
self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
|
| 139 |
+
self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
|
| 140 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 141 |
+
|
| 142 |
+
self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
|
| 143 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 144 |
+
self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
|
| 145 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 146 |
+
|
| 147 |
+
self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
|
| 148 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 149 |
+
self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
|
| 150 |
+
self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
|
| 151 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 152 |
+
|
| 153 |
+
self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
|
| 154 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 155 |
+
self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
|
| 156 |
+
self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True,
|
| 157 |
+
inst_norm_affine=True, lrelu_inplace=True)
|
| 158 |
+
|
| 159 |
+
self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
|
| 160 |
+
self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
|
| 161 |
+
self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
|
| 162 |
+
self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
|
| 163 |
+
self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
|
| 164 |
+
|
| 165 |
+
def forward(self, x):
|
| 166 |
+
seg_outputs = []
|
| 167 |
+
|
| 168 |
+
x = self.init_conv(x)
|
| 169 |
+
x = self.context1(x)
|
| 170 |
+
|
| 171 |
+
skip1, x = self.down1(x)
|
| 172 |
+
x = self.context2(x)
|
| 173 |
+
|
| 174 |
+
skip2, x = self.down2(x)
|
| 175 |
+
x = self.context3(x)
|
| 176 |
+
|
| 177 |
+
skip3, x = self.down3(x)
|
| 178 |
+
x = self.context4(x)
|
| 179 |
+
|
| 180 |
+
skip4, x = self.down4(x)
|
| 181 |
+
x = self.context5(x)
|
| 182 |
+
|
| 183 |
+
x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
|
| 184 |
+
x = self.up1(x)
|
| 185 |
+
|
| 186 |
+
x = torch.cat((skip4, x), dim=1)
|
| 187 |
+
x = self.loc1(x)
|
| 188 |
+
x = self.up2(x)
|
| 189 |
+
|
| 190 |
+
x = torch.cat((skip3, x), dim=1)
|
| 191 |
+
x = self.loc2(x)
|
| 192 |
+
loc2_seg = self.final_nonlin(self.loc2_seg(x))
|
| 193 |
+
seg_outputs.append(loc2_seg)
|
| 194 |
+
x = self.up3(x)
|
| 195 |
+
|
| 196 |
+
x = torch.cat((skip2, x), dim=1)
|
| 197 |
+
x = self.loc3(x)
|
| 198 |
+
loc3_seg = self.final_nonlin(self.loc3_seg(x))
|
| 199 |
+
seg_outputs.append(loc3_seg)
|
| 200 |
+
x = self.up4(x)
|
| 201 |
+
|
| 202 |
+
x = torch.cat((skip1, x), dim=1)
|
| 203 |
+
x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
|
| 204 |
+
inplace=self.lrelu_inplace)
|
| 205 |
+
x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
|
| 206 |
+
inplace=self.lrelu_inplace)
|
| 207 |
+
x = self.final_nonlin(self.seg_layer(x))
|
| 208 |
+
seg_outputs.append(x)
|
| 209 |
+
|
| 210 |
+
if self.do_ds:
|
| 211 |
+
return seg_outputs[::-1]
|
| 212 |
+
else:
|
| 213 |
+
return seg_outputs[-1]
|
src/BrainIAC/HD_BET/paths.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
# please refer to the readme on where to get the parameters. Save them in this folder:
|
| 4 |
+
# Original Path: "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params"
|
| 5 |
+
# Updated path for Docker container:
|
| 6 |
+
folder_with_parameter_files = "/app/BrainIAC/hdbet_model"
|
src/BrainIAC/HD_BET/predict_case.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
|
| 6 |
+
if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)):
|
| 7 |
+
shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3
|
| 8 |
+
shp = patient.shape
|
| 9 |
+
new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0],
|
| 10 |
+
shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1],
|
| 11 |
+
shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]]
|
| 12 |
+
for i in range(len(shp)):
|
| 13 |
+
if shp[i] % shape_must_be_divisible_by[i] == 0:
|
| 14 |
+
new_shp[i] -= shape_must_be_divisible_by[i]
|
| 15 |
+
if min_size is not None:
|
| 16 |
+
new_shp = np.max(np.vstack((np.array(new_shp), np.array(min_size))), 0)
|
| 17 |
+
return reshape_by_padding_upper_coords(patient, new_shp, 0), shp
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
|
| 21 |
+
shape = tuple(list(image.shape))
|
| 22 |
+
new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2,len(shape))), axis=0))
|
| 23 |
+
if pad_value is None:
|
| 24 |
+
if len(shape) == 2:
|
| 25 |
+
pad_value = image[0,0]
|
| 26 |
+
elif len(shape) == 3:
|
| 27 |
+
pad_value = image[0, 0, 0]
|
| 28 |
+
else:
|
| 29 |
+
raise ValueError("Image must be either 2 or 3 dimensional")
|
| 30 |
+
res = np.ones(list(new_shape), dtype=image.dtype) * pad_value
|
| 31 |
+
if len(shape) == 2:
|
| 32 |
+
res[0:0+int(shape[0]), 0:0+int(shape[1])] = image
|
| 33 |
+
elif len(shape) == 3:
|
| 34 |
+
res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image
|
| 35 |
+
return res
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
|
| 39 |
+
new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)):
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
pad_res = []
|
| 42 |
+
for i in range(patient_data.shape[0]):
|
| 43 |
+
t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size)
|
| 44 |
+
pad_res.append(t[None])
|
| 45 |
+
|
| 46 |
+
patient_data = np.vstack(pad_res)
|
| 47 |
+
|
| 48 |
+
new_shp = patient_data.shape
|
| 49 |
+
|
| 50 |
+
data = np.zeros(tuple([1] + list(new_shp)), dtype=np.float32)
|
| 51 |
+
|
| 52 |
+
data[0] = patient_data
|
| 53 |
+
|
| 54 |
+
if BATCH_SIZE is not None:
|
| 55 |
+
data = np.vstack([data] * BATCH_SIZE)
|
| 56 |
+
|
| 57 |
+
a = torch.rand(data.shape).float()
|
| 58 |
+
|
| 59 |
+
if main_device == 'cpu':
|
| 60 |
+
pass
|
| 61 |
+
else:
|
| 62 |
+
a = a.cuda(main_device)
|
| 63 |
+
|
| 64 |
+
if do_mirroring:
|
| 65 |
+
x = 8
|
| 66 |
+
else:
|
| 67 |
+
x = 1
|
| 68 |
+
all_preds = []
|
| 69 |
+
for i in range(num_repeats):
|
| 70 |
+
for m in range(x):
|
| 71 |
+
data_for_net = np.array(data)
|
| 72 |
+
do_stuff = False
|
| 73 |
+
if m == 0:
|
| 74 |
+
do_stuff = True
|
| 75 |
+
pass
|
| 76 |
+
if m == 1 and (4 in mirror_axes):
|
| 77 |
+
do_stuff = True
|
| 78 |
+
data_for_net = data_for_net[:, :, :, :, ::-1]
|
| 79 |
+
if m == 2 and (3 in mirror_axes):
|
| 80 |
+
do_stuff = True
|
| 81 |
+
data_for_net = data_for_net[:, :, :, ::-1, :]
|
| 82 |
+
if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
|
| 83 |
+
do_stuff = True
|
| 84 |
+
data_for_net = data_for_net[:, :, :, ::-1, ::-1]
|
| 85 |
+
if m == 4 and (2 in mirror_axes):
|
| 86 |
+
do_stuff = True
|
| 87 |
+
data_for_net = data_for_net[:, :, ::-1, :, :]
|
| 88 |
+
if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
|
| 89 |
+
do_stuff = True
|
| 90 |
+
data_for_net = data_for_net[:, :, ::-1, :, ::-1]
|
| 91 |
+
if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
|
| 92 |
+
do_stuff = True
|
| 93 |
+
data_for_net = data_for_net[:, :, ::-1, ::-1, :]
|
| 94 |
+
if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
|
| 95 |
+
do_stuff = True
|
| 96 |
+
data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1]
|
| 97 |
+
|
| 98 |
+
if do_stuff:
|
| 99 |
+
_ = a.data.copy_(torch.from_numpy(np.copy(data_for_net)))
|
| 100 |
+
p = net(a) # np.copy is necessary because ::-1 creates just a view i think
|
| 101 |
+
p = p.data.cpu().numpy()
|
| 102 |
+
|
| 103 |
+
if m == 0:
|
| 104 |
+
pass
|
| 105 |
+
if m == 1 and (4 in mirror_axes):
|
| 106 |
+
p = p[:, :, :, :, ::-1]
|
| 107 |
+
if m == 2 and (3 in mirror_axes):
|
| 108 |
+
p = p[:, :, :, ::-1, :]
|
| 109 |
+
if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
|
| 110 |
+
p = p[:, :, :, ::-1, ::-1]
|
| 111 |
+
if m == 4 and (2 in mirror_axes):
|
| 112 |
+
p = p[:, :, ::-1, :, :]
|
| 113 |
+
if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
|
| 114 |
+
p = p[:, :, ::-1, :, ::-1]
|
| 115 |
+
if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
|
| 116 |
+
p = p[:, :, ::-1, ::-1, :]
|
| 117 |
+
if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
|
| 118 |
+
p = p[:, :, ::-1, ::-1, ::-1]
|
| 119 |
+
all_preds.append(p)
|
| 120 |
+
|
| 121 |
+
stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
|
| 122 |
+
predicted_segmentation = stacked.mean(0).argmax(0)
|
| 123 |
+
uncertainty = stacked.var(0)
|
| 124 |
+
bayesian_predictions = stacked
|
| 125 |
+
softmax_pred = stacked.mean(0)
|
| 126 |
+
return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
|
src/BrainIAC/HD_BET/run.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import SimpleITK as sitk
|
| 4 |
+
from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
|
| 5 |
+
from HD_BET.predict_case import predict_case_3D_net
|
| 6 |
+
import imp
|
| 7 |
+
from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
|
| 8 |
+
import os
|
| 9 |
+
import HD_BET
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def apply_bet(img, bet, out_fname):
|
| 13 |
+
img_itk = sitk.ReadImage(img)
|
| 14 |
+
img_npy = sitk.GetArrayFromImage(img_itk)
|
| 15 |
+
img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
|
| 16 |
+
img_npy[img_bet == 0] = 0
|
| 17 |
+
out = sitk.GetImageFromArray(img_npy)
|
| 18 |
+
out.CopyInformation(img_itk)
|
| 19 |
+
sitk.WriteImage(out, out_fname)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
|
| 23 |
+
postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
:param mri_fnames: str or list/tuple of str
|
| 27 |
+
:param output_fnames: str or list/tuple of str. If list: must have the same length as output_fnames
|
| 28 |
+
:param mode: fast or accurate
|
| 29 |
+
:param config_file: config.py
|
| 30 |
+
:param device: either int (for device id) or 'cpu'
|
| 31 |
+
:param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
|
| 32 |
+
but the largest predicted connected component. Default False
|
| 33 |
+
:param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
|
| 34 |
+
CPU you may want to turn that off to speed things up
|
| 35 |
+
:return:
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
list_of_param_files = []
|
| 39 |
+
|
| 40 |
+
if mode == 'fast':
|
| 41 |
+
params_file = get_params_fname(0)
|
| 42 |
+
maybe_download_parameters(0)
|
| 43 |
+
|
| 44 |
+
list_of_param_files.append(params_file)
|
| 45 |
+
elif mode == 'accurate':
|
| 46 |
+
for i in range(5):
|
| 47 |
+
params_file = get_params_fname(i)
|
| 48 |
+
maybe_download_parameters(i)
|
| 49 |
+
|
| 50 |
+
list_of_param_files.append(params_file)
|
| 51 |
+
else:
|
| 52 |
+
raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)
|
| 53 |
+
|
| 54 |
+
assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"
|
| 55 |
+
|
| 56 |
+
cf = imp.load_source('cf', config_file)
|
| 57 |
+
cf = cf.config()
|
| 58 |
+
|
| 59 |
+
net, _ = cf.get_network(cf.val_use_train_mode, None)
|
| 60 |
+
if device == "cpu":
|
| 61 |
+
net = net.cpu()
|
| 62 |
+
else:
|
| 63 |
+
net.cuda(device)
|
| 64 |
+
|
| 65 |
+
if not isinstance(mri_fnames, (list, tuple)):
|
| 66 |
+
mri_fnames = [mri_fnames]
|
| 67 |
+
|
| 68 |
+
if not isinstance(output_fnames, (list, tuple)):
|
| 69 |
+
output_fnames = [output_fnames]
|
| 70 |
+
|
| 71 |
+
assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"
|
| 72 |
+
|
| 73 |
+
params = []
|
| 74 |
+
for p in list_of_param_files:
|
| 75 |
+
params.append(torch.load(p, map_location=lambda storage, loc: storage))
|
| 76 |
+
|
| 77 |
+
for in_fname, out_fname in zip(mri_fnames, output_fnames):
|
| 78 |
+
mask_fname = out_fname[:-7] + "_mask.nii.gz"
|
| 79 |
+
if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
|
| 80 |
+
print("File:", in_fname)
|
| 81 |
+
print("preprocessing...")
|
| 82 |
+
try:
|
| 83 |
+
data, data_dict = load_and_preprocess(in_fname)
|
| 84 |
+
except RuntimeError:
|
| 85 |
+
print("\nERROR\nCould not read file", in_fname, "\n")
|
| 86 |
+
continue
|
| 87 |
+
except AssertionError as e:
|
| 88 |
+
print(e)
|
| 89 |
+
continue
|
| 90 |
+
|
| 91 |
+
softmax_preds = []
|
| 92 |
+
|
| 93 |
+
print("prediction (CNN id)...")
|
| 94 |
+
for i, p in enumerate(params):
|
| 95 |
+
print(i)
|
| 96 |
+
net.load_state_dict(p)
|
| 97 |
+
net.eval()
|
| 98 |
+
net.apply(SetNetworkToVal(False, False))
|
| 99 |
+
_, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
|
| 100 |
+
cf.val_batch_size, cf.net_input_must_be_divisible_by,
|
| 101 |
+
cf.val_min_size, device, cf.da_mirror_axes)
|
| 102 |
+
softmax_preds.append(softmax_pred[None])
|
| 103 |
+
|
| 104 |
+
seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)
|
| 105 |
+
|
| 106 |
+
if postprocess:
|
| 107 |
+
seg = postprocess_prediction(seg)
|
| 108 |
+
|
| 109 |
+
print("exporting segmentation...")
|
| 110 |
+
save_segmentation_nifti(seg, data_dict, mask_fname)
|
| 111 |
+
|
| 112 |
+
apply_bet(in_fname, mask_fname, out_fname)
|
| 113 |
+
|
| 114 |
+
if not keep_mask:
|
| 115 |
+
os.remove(mask_fname)
|
| 116 |
+
|
| 117 |
+
|
src/BrainIAC/HD_BET/utils.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from urllib.request import urlopen
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from skimage.morphology import label
|
| 6 |
+
import os
|
| 7 |
+
from HD_BET.paths import folder_with_parameter_files
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_params_fname(fold):
|
| 11 |
+
return os.path.join(folder_with_parameter_files, "%d.model" % fold)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def maybe_download_parameters(fold=0, force_overwrite=False):
|
| 15 |
+
"""
|
| 16 |
+
Downloads the parameters for some fold if it is not present yet.
|
| 17 |
+
:param fold:
|
| 18 |
+
:param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
|
| 19 |
+
:return:
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
assert 0 <= fold <= 4, "fold must be between 0 and 4"
|
| 23 |
+
|
| 24 |
+
if not os.path.isdir(folder_with_parameter_files):
|
| 25 |
+
maybe_mkdir_p(folder_with_parameter_files)
|
| 26 |
+
|
| 27 |
+
out_filename = get_params_fname(fold)
|
| 28 |
+
|
| 29 |
+
if force_overwrite and os.path.isfile(out_filename):
|
| 30 |
+
os.remove(out_filename)
|
| 31 |
+
|
| 32 |
+
if not os.path.isfile(out_filename):
|
| 33 |
+
url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
|
| 34 |
+
print("Downloading", url, "...")
|
| 35 |
+
data = urlopen(url).read()
|
| 36 |
+
#out_filename = "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params/0.model"
|
| 37 |
+
with open(out_filename, 'wb') as f:
|
| 38 |
+
f.write(data)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def init_weights(module):
|
| 42 |
+
if isinstance(module, nn.Conv3d):
|
| 43 |
+
module.weight = nn.init.kaiming_normal(module.weight, a=1e-2)
|
| 44 |
+
if module.bias is not None:
|
| 45 |
+
module.bias = nn.init.constant(module.bias, 0)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def softmax_helper(x):
|
| 49 |
+
rpt = [1 for _ in range(len(x.size()))]
|
| 50 |
+
rpt[1] = x.size(1)
|
| 51 |
+
x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
|
| 52 |
+
e_x = torch.exp(x - x_max)
|
| 53 |
+
return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class SetNetworkToVal(object):
|
| 57 |
+
def __init__(self, use_dropout_sampling=False, norm_use_average=True):
|
| 58 |
+
self.norm_use_average = norm_use_average
|
| 59 |
+
self.use_dropout_sampling = use_dropout_sampling
|
| 60 |
+
|
| 61 |
+
def __call__(self, module):
|
| 62 |
+
if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout):
|
| 63 |
+
module.train(self.use_dropout_sampling)
|
| 64 |
+
elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \
|
| 65 |
+
isinstance(module, nn.InstanceNorm1d) \
|
| 66 |
+
or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \
|
| 67 |
+
isinstance(module, nn.BatchNorm1d):
|
| 68 |
+
module.train(not self.norm_use_average)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def postprocess_prediction(seg):
|
| 72 |
+
# basically look for connected components and choose the largest one, delete everything else
|
| 73 |
+
print("running postprocessing... ")
|
| 74 |
+
mask = seg != 0
|
| 75 |
+
lbls = label(mask, connectivity=mask.ndim)
|
| 76 |
+
lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
|
| 77 |
+
largest_region = np.argmax(lbls_sizes[1:]) + 1
|
| 78 |
+
seg[lbls != largest_region] = 0
|
| 79 |
+
return seg
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
|
| 83 |
+
if join:
|
| 84 |
+
l = os.path.join
|
| 85 |
+
else:
|
| 86 |
+
l = lambda x, y: y
|
| 87 |
+
res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
|
| 88 |
+
and (prefix is None or i.startswith(prefix))
|
| 89 |
+
and (suffix is None or i.endswith(suffix))]
|
| 90 |
+
if sort:
|
| 91 |
+
res.sort()
|
| 92 |
+
return res
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
|
| 96 |
+
if join:
|
| 97 |
+
l = os.path.join
|
| 98 |
+
else:
|
| 99 |
+
l = lambda x, y: y
|
| 100 |
+
res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
|
| 101 |
+
and (prefix is None or i.startswith(prefix))
|
| 102 |
+
and (suffix is None or i.endswith(suffix))]
|
| 103 |
+
if sort:
|
| 104 |
+
res.sort()
|
| 105 |
+
return res
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
subfolders = subdirs # I am tired of confusing those
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def maybe_mkdir_p(directory):
|
| 112 |
+
splits = directory.split("/")[1:]
|
| 113 |
+
for i in range(0, len(splits)):
|
| 114 |
+
if not os.path.isdir(os.path.join("", *splits[:i+1])):
|
| 115 |
+
os.mkdir(os.path.join("", *splits[:i+1]))
|
src/BrainIAC/MCIclassification/README.md
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MCI Classification
|
| 2 |
+
|
| 3 |
+
<p align="left">
|
| 4 |
+
<img src="mci.jpeg" width="200" alt="MCI Classification Example"/>
|
| 5 |
+
</p>
|
| 6 |
+
|
| 7 |
+
## Overview
|
| 8 |
+
|
| 9 |
+
We present the MCI classification training and inference code for BrainIAC as a downstream task. The pipeline is trained and infered on T1 scans, with AUC and F1 as evaluation metric.
|
| 10 |
+
|
| 11 |
+
## Data Requirements
|
| 12 |
+
|
| 13 |
+
- **Input**: T1-weighted MR scans
|
| 14 |
+
- **Format**: NIFTI (.nii.gz)
|
| 15 |
+
- **Preprocessing**: Bias field corrected, registered to standard space, skull stripped, histogram normalized (optional)
|
| 16 |
+
- **CSV Structure**:
|
| 17 |
+
```
|
| 18 |
+
pat_id,scandate,label
|
| 19 |
+
subject001,20240101,1 # 1 for MCI, 0 for healthy control
|
| 20 |
+
```
|
| 21 |
+
refer to [ quickstart.ipynb](../quickstart.ipynb) to find how to preprocess data and generate csv file.
|
| 22 |
+
|
| 23 |
+
## Setup
|
| 24 |
+
|
| 25 |
+
1. **Configuration**:
|
| 26 |
+
change the [config.yml](../config.yml) file accordingly.
|
| 27 |
+
```yaml
|
| 28 |
+
# config.yml
|
| 29 |
+
data:
|
| 30 |
+
train_csv: "path/to/train.csv"
|
| 31 |
+
val_csv: "path/to/val.csv"
|
| 32 |
+
test_csv: "path/to/test.csv"
|
| 33 |
+
root_dir: "../data/sample/processed"
|
| 34 |
+
collate: 1 # single scan framework
|
| 35 |
+
|
| 36 |
+
checkpoints: "./checkpoints/mci_model.00" # for inference/testing
|
| 37 |
+
|
| 38 |
+
train:
|
| 39 |
+
finetune: 'yes' # yes to finetune the entire model
|
| 40 |
+
freeze: 'no' # yes to freeze the resnet backbone
|
| 41 |
+
weights: ./checkpoints/brainiac.ckpt # path to brainiac weights
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
2. **Training**:
|
| 45 |
+
```bash
|
| 46 |
+
python -m MCIclassification.train_mci
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
3. **Inference**:
|
| 50 |
+
```bash
|
| 51 |
+
python -m MCIclassification.infer_mci
|
| 52 |
+
```
|
src/BrainIAC/MCIclassification/__init__.py
ADDED
|
File without changes
|
src/BrainIAC/MCIclassification/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (154 Bytes). View file
|
|
|
src/BrainIAC/MCIclassification/__pycache__/infer_mci.cpython-39.pyc
ADDED
|
Binary file (3.96 kB). View file
|
|
|
src/BrainIAC/MCIclassification/infer_mci.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import os
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torch.cuda.amp import autocast
|
| 7 |
+
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
|
| 8 |
+
import numpy as np
|
| 9 |
+
import sys
|
| 10 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 11 |
+
from dataset2 import MedicalImageDatasetBalancedIntensity3D
|
| 12 |
+
from model import Backbone, SingleScanModelBP, Classifier
|
| 13 |
+
from utils import BaseConfig
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def calculate_metrics(pred_probs, pred_labels, true_labels):
|
| 18 |
+
"""
|
| 19 |
+
classification metrics.
|
| 20 |
+
Args:
|
| 21 |
+
pred_probs (numpy.ndarray): Predicted probabilities
|
| 22 |
+
pred_labels (numpy.ndarray): Predicted labels
|
| 23 |
+
true_labels (numpy.ndarray): Ground truth labels
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
dict: Dictionary containing accuracy, precision, recall, F1, and AUC metrics
|
| 27 |
+
"""
|
| 28 |
+
accuracy = accuracy_score(true_labels, pred_labels)
|
| 29 |
+
precision = precision_score(true_labels, pred_labels)
|
| 30 |
+
recall = recall_score(true_labels, pred_labels)
|
| 31 |
+
f1 = f1_score(true_labels, pred_labels)
|
| 32 |
+
auc = roc_auc_score(true_labels, pred_probs)
|
| 33 |
+
|
| 34 |
+
return {
|
| 35 |
+
'accuracy': accuracy,
|
| 36 |
+
'precision': precision,
|
| 37 |
+
'recall': recall,
|
| 38 |
+
'f1': f1,
|
| 39 |
+
'auc': auc
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
#============================
|
| 44 |
+
# INFERENCE CLASS
|
| 45 |
+
#============================
|
| 46 |
+
class MCIInference(BaseConfig):
|
| 47 |
+
"""
|
| 48 |
+
Inference class for MCI classification model.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def __init__(self):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.setup_model()
|
| 54 |
+
self.setup_data()
|
| 55 |
+
|
| 56 |
+
def setup_model(self):
|
| 57 |
+
config = self.get_config()
|
| 58 |
+
self.backbone = Backbone()
|
| 59 |
+
self.classifier = Classifier(d_model=2048, num_classes=1) # Binary classification
|
| 60 |
+
self.model = SingleScanModelBP(self.backbone, self.classifier)
|
| 61 |
+
|
| 62 |
+
# Load weights
|
| 63 |
+
checkpoint = torch.load(config["infer"]["checkpoints"], map_location=self.device, weights_only=False)
|
| 64 |
+
self.model.load_state_dict(checkpoint["model_state_dict"], strict=False)
|
| 65 |
+
self.model = self.model.to(self.device)
|
| 66 |
+
self.model.eval()
|
| 67 |
+
print("Model and checkpoint loaded!")
|
| 68 |
+
|
| 69 |
+
## spin up data loaders
|
| 70 |
+
def setup_data(self):
|
| 71 |
+
config = self.get_config()
|
| 72 |
+
self.test_dataset = MedicalImageDatasetBalancedIntensity3D(
|
| 73 |
+
csv_path=config["data"]["test_csv"],
|
| 74 |
+
root_dir=config["data"]["root_dir"]
|
| 75 |
+
)
|
| 76 |
+
self.test_loader = DataLoader(
|
| 77 |
+
self.test_dataset,
|
| 78 |
+
batch_size=1,
|
| 79 |
+
shuffle=False,
|
| 80 |
+
collate_fn=self.custom_collate,
|
| 81 |
+
num_workers=1
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
def infer(self):
|
| 85 |
+
"""
|
| 86 |
+
Run inference pass
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
dict: Dictionary with evaluation metrics
|
| 90 |
+
"""
|
| 91 |
+
results_df = pd.DataFrame(columns=['PredictedProb', 'PredictedLabel', 'TrueLabel'])
|
| 92 |
+
all_labels = []
|
| 93 |
+
all_predictions = []
|
| 94 |
+
all_probs = []
|
| 95 |
+
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
for sample in tqdm(self.test_loader, desc="Inference", unit="batch"):
|
| 98 |
+
inputs = sample['image'].to(self.device)
|
| 99 |
+
labels = sample['label'].float().to(self.device)
|
| 100 |
+
|
| 101 |
+
with autocast():
|
| 102 |
+
outputs = self.model(inputs)
|
| 103 |
+
|
| 104 |
+
probs = torch.sigmoid(outputs).cpu().numpy().flatten()
|
| 105 |
+
preds = (probs > 0.5).astype(int)
|
| 106 |
+
|
| 107 |
+
all_labels.extend(labels.cpu().numpy().flatten())
|
| 108 |
+
all_predictions.extend(preds)
|
| 109 |
+
all_probs.extend(probs)
|
| 110 |
+
|
| 111 |
+
result = pd.DataFrame({
|
| 112 |
+
'PredictedProb': probs,
|
| 113 |
+
'PredictedLabel': preds,
|
| 114 |
+
'TrueLabel': labels.cpu().numpy().flatten()
|
| 115 |
+
})
|
| 116 |
+
|
| 117 |
+
results_df = pd.concat([results_df, result], ignore_index=True)
|
| 118 |
+
|
| 119 |
+
# log metrics
|
| 120 |
+
"""metrics = calculate_metrics(
|
| 121 |
+
np.array(all_probs),
|
| 122 |
+
np.array(all_predictions),
|
| 123 |
+
np.array(all_labels)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
print("\nTest Set Metrics:")
|
| 128 |
+
print(f"Accuracy: {metrics['accuracy']:.4f}")
|
| 129 |
+
print(f"Precision: {metrics['precision']:.4f}")
|
| 130 |
+
print(f"Recall: {metrics['recall']:.4f}")
|
| 131 |
+
print(f"F1 Score: {metrics['f1']:.4f}")
|
| 132 |
+
print(f"AUC: {metrics['auc']:.4f}")"""
|
| 133 |
+
|
| 134 |
+
# Save results
|
| 135 |
+
print("PredictedLabel", results_df["PredictedLabel"][0])
|
| 136 |
+
results_df.to_csv('./data/output/mci_classification_predictions.csv', index=False)
|
| 137 |
+
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
if __name__ == "__main__":
|
| 141 |
+
inferencer = MCIInference()
|
| 142 |
+
metrics = inferencer.infer()
|
src/BrainIAC/MCIclassification/mci.jpeg
ADDED
|
Git LFS Details
|
src/BrainIAC/MCIclassification/train_mci.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import wandb
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from torch.optim.lr_scheduler import OneCycleLR
|
| 8 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 12 |
+
from dataset2 import MedicalImageDatasetBalancedIntensity3D, TransformationMedicalImageDatasetBalancedIntensity3D
|
| 13 |
+
from model import Backbone, SingleScanModel, Classifier
|
| 14 |
+
from utils import BaseConfig
|
| 15 |
+
import numpy as np
|
| 16 |
+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def calculate_metrics(pred_probs, pred_labels, true_labels):
|
| 20 |
+
"""
|
| 21 |
+
classification metrics.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
pred_probs (numpy.ndarray): Predicted probabilities
|
| 25 |
+
pred_labels (numpy.ndarray): Predicted labels
|
| 26 |
+
true_labels (numpy.ndarray): Ground truth labels
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
dict: Dictionary containing accuracy, precision, recall, F1, and AUC
|
| 30 |
+
"""
|
| 31 |
+
accuracy = accuracy_score(true_labels, pred_labels)
|
| 32 |
+
precision = precision_score(true_labels, pred_labels)
|
| 33 |
+
recall = recall_score(true_labels, pred_labels)
|
| 34 |
+
f1 = f1_score(true_labels, pred_labels)
|
| 35 |
+
auc = roc_auc_score(true_labels, pred_probs)
|
| 36 |
+
|
| 37 |
+
return {
|
| 38 |
+
'accuracy': accuracy,
|
| 39 |
+
'precision': precision,
|
| 40 |
+
'recall': recall,
|
| 41 |
+
'f1': f1,
|
| 42 |
+
'auc': auc
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
#============================
|
| 46 |
+
# TRAINER CLASS
|
| 47 |
+
#============================
|
| 48 |
+
|
| 49 |
+
class MCITrainer(BaseConfig):
|
| 50 |
+
"""
|
| 51 |
+
trainer class for MCI classification
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(self):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.setup_wandb()
|
| 57 |
+
self.setup_model()
|
| 58 |
+
self.setup_data()
|
| 59 |
+
self.setup_training()
|
| 60 |
+
|
| 61 |
+
def setup_wandb(self):
|
| 62 |
+
config = self.get_config()
|
| 63 |
+
wandb.init(
|
| 64 |
+
project=config['logger']['project_name'],
|
| 65 |
+
name=config['logger']['run_name'],
|
| 66 |
+
config=config
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
def setup_model(self):
|
| 70 |
+
self.backbone = Backbone()
|
| 71 |
+
# Change classifier to output 1 value for binary classification
|
| 72 |
+
self.classifier = Classifier(d_model=2048, num_classes=1)
|
| 73 |
+
self.model = SingleScanModel(self.backbone, self.classifier)
|
| 74 |
+
|
| 75 |
+
# Load weights from brainiac
|
| 76 |
+
config = self.get_config()
|
| 77 |
+
if config["train"]["finetune"] == "yes":
|
| 78 |
+
checkpoint = torch.load(config["train"]["weights"], map_location=self.device)
|
| 79 |
+
state_dict = checkpoint["state_dict"]
|
| 80 |
+
filtered_state_dict = {}
|
| 81 |
+
for key, value in state_dict.items():
|
| 82 |
+
new_key = key.replace("module.", "backbone.") if key.startswith("module.") else key
|
| 83 |
+
filtered_state_dict[new_key] = value
|
| 84 |
+
self.model.backbone.load_state_dict(filtered_state_dict, strict=False)
|
| 85 |
+
print("Pretrained weights loaded!")
|
| 86 |
+
|
| 87 |
+
if config["train"]["freeze"] == "yes":
|
| 88 |
+
for param in self.model.backbone.parameters():
|
| 89 |
+
param.requires_grad = False
|
| 90 |
+
print("Backbone weights frozen!")
|
| 91 |
+
|
| 92 |
+
self.model = self.model.to(self.device)
|
| 93 |
+
|
| 94 |
+
## spinup dataloaders
|
| 95 |
+
def setup_data(self):
|
| 96 |
+
config = self.get_config()
|
| 97 |
+
self.train_dataset = TransformationMedicalImageDatasetBalancedIntensity3D(
|
| 98 |
+
csv_path=config['data']['train_csv'],
|
| 99 |
+
root_dir=config["data"]["root_dir"]
|
| 100 |
+
)
|
| 101 |
+
self.val_dataset = MedicalImageDatasetBalancedIntensity3D(
|
| 102 |
+
csv_path=config['data']['val_csv'],
|
| 103 |
+
root_dir=config["data"]["root_dir"]
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
self.train_loader = DataLoader(
|
| 107 |
+
self.train_dataset,
|
| 108 |
+
batch_size=config["data"]["batch_size"],
|
| 109 |
+
shuffle=True,
|
| 110 |
+
collate_fn=self.custom_collate,
|
| 111 |
+
num_workers=config["data"]["num_workers"]
|
| 112 |
+
)
|
| 113 |
+
self.val_loader = DataLoader(
|
| 114 |
+
self.val_dataset,
|
| 115 |
+
batch_size=1,
|
| 116 |
+
shuffle=False,
|
| 117 |
+
collate_fn=self.custom_collate,
|
| 118 |
+
num_workers=1
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
def setup_training(self):
|
| 122 |
+
"""
|
| 123 |
+
training setup
|
| 124 |
+
"""
|
| 125 |
+
config = self.get_config()
|
| 126 |
+
# BCE loss
|
| 127 |
+
self.criterion = nn.BCEWithLogitsLoss().to(self.device)
|
| 128 |
+
self.optimizer = optim.AdamW(
|
| 129 |
+
self.model.parameters(),
|
| 130 |
+
lr=config['optim']['lr'],
|
| 131 |
+
weight_decay=config["optim"]["weight_decay"]
|
| 132 |
+
)
|
| 133 |
+
self.scheduler = OneCycleLR(
|
| 134 |
+
self.optimizer,
|
| 135 |
+
max_lr=config['optim']['lr'],
|
| 136 |
+
epochs=config['optim']['max_epochs'],
|
| 137 |
+
steps_per_epoch=len(self.train_loader)
|
| 138 |
+
)
|
| 139 |
+
self.scaler = GradScaler()
|
| 140 |
+
|
| 141 |
+
## main training loop
|
| 142 |
+
def train(self):
|
| 143 |
+
config = self.get_config()
|
| 144 |
+
max_epochs = config['optim']['max_epochs']
|
| 145 |
+
best_metrics = {
|
| 146 |
+
'val_loss': float('inf'),
|
| 147 |
+
'accuracy': 0,
|
| 148 |
+
'precision': 0,
|
| 149 |
+
'recall': 0,
|
| 150 |
+
'f1': 0,
|
| 151 |
+
'auc': 0
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
for epoch in range(max_epochs):
|
| 155 |
+
train_loss = self.train_epoch(epoch, max_epochs)
|
| 156 |
+
val_loss, metrics = self.validate_epoch(epoch, max_epochs)
|
| 157 |
+
|
| 158 |
+
# Save best model based on validation loss and F1 score
|
| 159 |
+
if metrics['auc'] > best_metrics['auc']:
|
| 160 |
+
print(f"New best model found!")
|
| 161 |
+
print(f"Improved Val Loss from {best_metrics['val_loss']:.4f} to {val_loss:.4f}")
|
| 162 |
+
print(f"Improved F1 from {best_metrics['f1']:.4f} to {metrics['f1']:.4f}")
|
| 163 |
+
best_metrics.update(metrics)
|
| 164 |
+
best_metrics['val_loss'] = val_loss
|
| 165 |
+
self.save_checkpoint(epoch, val_loss, metrics)
|
| 166 |
+
|
| 167 |
+
wandb.finish()
|
| 168 |
+
|
| 169 |
+
## training pass
|
| 170 |
+
def train_epoch(self, epoch, max_epochs):
|
| 171 |
+
self.model.train()
|
| 172 |
+
train_loss = 0.0
|
| 173 |
+
|
| 174 |
+
for sample in tqdm(self.train_loader, desc=f"Training Epoch {epoch}/{max_epochs-1}"):
|
| 175 |
+
inputs = sample['image'].to(self.device)
|
| 176 |
+
labels = sample['label'].float().to(self.device)
|
| 177 |
+
|
| 178 |
+
self.optimizer.zero_grad(set_to_none=True)
|
| 179 |
+
with autocast():
|
| 180 |
+
outputs = self.model(inputs)
|
| 181 |
+
loss = self.criterion(outputs, labels.unsqueeze(1))
|
| 182 |
+
|
| 183 |
+
self.scaler.scale(loss).backward()
|
| 184 |
+
|
| 185 |
+
self.scaler.unscale_(self.optimizer)
|
| 186 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
| 187 |
+
|
| 188 |
+
self.scaler.step(self.optimizer)
|
| 189 |
+
self.scaler.update()
|
| 190 |
+
self.scheduler.step()
|
| 191 |
+
|
| 192 |
+
train_loss += loss.item() * inputs.size(0)
|
| 193 |
+
|
| 194 |
+
train_loss = train_loss / len(self.train_loader.dataset)
|
| 195 |
+
wandb.log({"Train Loss": train_loss})
|
| 196 |
+
return train_loss
|
| 197 |
+
|
| 198 |
+
## validation pass
|
| 199 |
+
def validate_epoch(self, epoch, max_epochs):
|
| 200 |
+
self.model.eval()
|
| 201 |
+
val_loss = 0.0
|
| 202 |
+
all_labels = []
|
| 203 |
+
all_preds = []
|
| 204 |
+
all_probs = []
|
| 205 |
+
|
| 206 |
+
with torch.no_grad():
|
| 207 |
+
for sample in tqdm(self.val_loader, desc=f"Validation Epoch {epoch}/{max_epochs-1}"):
|
| 208 |
+
inputs = sample['image'].to(self.device)
|
| 209 |
+
labels = sample['label'].float().to(self.device)
|
| 210 |
+
|
| 211 |
+
outputs = self.model(inputs)
|
| 212 |
+
loss = self.criterion(outputs, labels.unsqueeze(1))
|
| 213 |
+
|
| 214 |
+
# Get probabilities and predictions
|
| 215 |
+
probs = torch.sigmoid(outputs).cpu().numpy()
|
| 216 |
+
preds = (probs > 0.5).astype(int)
|
| 217 |
+
|
| 218 |
+
val_loss += loss.item() * inputs.size(0)
|
| 219 |
+
all_labels.extend(labels.cpu().numpy().flatten())
|
| 220 |
+
all_preds.extend(preds.flatten())
|
| 221 |
+
all_probs.extend(probs.flatten())
|
| 222 |
+
|
| 223 |
+
val_loss = val_loss / len(self.val_loader.dataset)
|
| 224 |
+
metrics = calculate_metrics(
|
| 225 |
+
np.array(all_probs),
|
| 226 |
+
np.array(all_preds),
|
| 227 |
+
np.array(all_labels)
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
wandb.log({
|
| 231 |
+
"Val Loss": val_loss,
|
| 232 |
+
"Accuracy": metrics['accuracy'],
|
| 233 |
+
"Precision": metrics['precision'],
|
| 234 |
+
"Recall": metrics['recall'],
|
| 235 |
+
"F1 Score": metrics['f1'],
|
| 236 |
+
"AUC": metrics['auc']
|
| 237 |
+
})
|
| 238 |
+
|
| 239 |
+
print(f"Epoch {epoch}/{max_epochs-1}")
|
| 240 |
+
print(f"Val Loss: {val_loss:.4f}")
|
| 241 |
+
print(f"Accuracy: {metrics['accuracy']:.4f}")
|
| 242 |
+
print(f"Precision: {metrics['precision']:.4f}")
|
| 243 |
+
print(f"Recall: {metrics['recall']:.4f}")
|
| 244 |
+
print(f"F1 Score: {metrics['f1']:.4f}")
|
| 245 |
+
print(f"AUC: {metrics['auc']:.4f}")
|
| 246 |
+
|
| 247 |
+
return val_loss, metrics
|
| 248 |
+
|
| 249 |
+
## save best model
|
| 250 |
+
def save_checkpoint(self, epoch, loss, metrics):
|
| 251 |
+
config = self.get_config()
|
| 252 |
+
checkpoint = {
|
| 253 |
+
'epoch': epoch,
|
| 254 |
+
'model_state_dict': self.model.state_dict(),
|
| 255 |
+
'metrics': metrics
|
| 256 |
+
}
|
| 257 |
+
save_path = os.path.join(
|
| 258 |
+
config['logger']['save_dir'],
|
| 259 |
+
config['logger']['save_name'].format(epoch=epoch, loss=loss, metric=metrics['f1'])
|
| 260 |
+
)
|
| 261 |
+
torch.save(checkpoint, save_path)
|
| 262 |
+
|
| 263 |
+
if __name__ == "__main__":
|
| 264 |
+
trainer = MCITrainer()
|
| 265 |
+
trainer.train()
|
src/BrainIAC/__init__.py
ADDED
|
File without changes
|
src/BrainIAC/__pycache__/dataset2.cpython-39.pyc
ADDED
|
Binary file (6.71 kB). View file
|
|
|
src/BrainIAC/__pycache__/load_brainiac.cpython-39.pyc
ADDED
|
Binary file (1.44 kB). View file
|
|
|