Divyanshu Tak commited on
Commit
5a169ab
·
1 Parent(s): 50a5e7b

V0-commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. src/.DS_Store +0 -0
  2. src/BrainIAC/.DS_Store +0 -0
  3. src/BrainIAC/Brainage/README.md +55 -0
  4. src/BrainIAC/Brainage/__init__.py +0 -0
  5. src/BrainIAC/Brainage/__pycache__/__init__.cpython-39.pyc +0 -0
  6. src/BrainIAC/Brainage/__pycache__/infer_brainage.cpython-39.pyc +0 -0
  7. src/BrainIAC/Brainage/brainage.jpeg +3 -0
  8. src/BrainIAC/Brainage/infer_brainage.py +85 -0
  9. src/BrainIAC/Brainage/train_brainage.py +230 -0
  10. src/BrainIAC/HD_BET/__pycache__/config.cpython-310.pyc +0 -0
  11. src/BrainIAC/HD_BET/__pycache__/config.cpython-38.pyc +0 -0
  12. src/BrainIAC/HD_BET/__pycache__/config.cpython-39.pyc +0 -0
  13. src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-310.pyc +0 -0
  14. src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-38.pyc +0 -0
  15. src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-39.pyc +0 -0
  16. src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-310.pyc +0 -0
  17. src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-38.pyc +0 -0
  18. src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-310.pyc +0 -0
  19. src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-38.pyc +0 -0
  20. src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-39.pyc +0 -0
  21. src/BrainIAC/HD_BET/__pycache__/paths.cpython-310.pyc +0 -0
  22. src/BrainIAC/HD_BET/__pycache__/paths.cpython-38.pyc +0 -0
  23. src/BrainIAC/HD_BET/__pycache__/paths.cpython-39.pyc +0 -0
  24. src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-310.pyc +0 -0
  25. src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-38.pyc +0 -0
  26. src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-39.pyc +0 -0
  27. src/BrainIAC/HD_BET/__pycache__/run.cpython-310.pyc +0 -0
  28. src/BrainIAC/HD_BET/__pycache__/run.cpython-38.pyc +0 -0
  29. src/BrainIAC/HD_BET/__pycache__/run.cpython-39.pyc +0 -0
  30. src/BrainIAC/HD_BET/__pycache__/utils.cpython-310.pyc +0 -0
  31. src/BrainIAC/HD_BET/__pycache__/utils.cpython-38.pyc +0 -0
  32. src/BrainIAC/HD_BET/__pycache__/utils.cpython-39.pyc +0 -0
  33. src/BrainIAC/HD_BET/config.py +121 -0
  34. src/BrainIAC/HD_BET/data_loading.py +121 -0
  35. src/BrainIAC/HD_BET/hd_bet.py +119 -0
  36. src/BrainIAC/HD_BET/network_architecture.py +213 -0
  37. src/BrainIAC/HD_BET/paths.py +6 -0
  38. src/BrainIAC/HD_BET/predict_case.py +126 -0
  39. src/BrainIAC/HD_BET/run.py +117 -0
  40. src/BrainIAC/HD_BET/utils.py +115 -0
  41. src/BrainIAC/MCIclassification/README.md +52 -0
  42. src/BrainIAC/MCIclassification/__init__.py +0 -0
  43. src/BrainIAC/MCIclassification/__pycache__/__init__.cpython-39.pyc +0 -0
  44. src/BrainIAC/MCIclassification/__pycache__/infer_mci.cpython-39.pyc +0 -0
  45. src/BrainIAC/MCIclassification/infer_mci.py +142 -0
  46. src/BrainIAC/MCIclassification/mci.jpeg +3 -0
  47. src/BrainIAC/MCIclassification/train_mci.py +265 -0
  48. src/BrainIAC/__init__.py +0 -0
  49. src/BrainIAC/__pycache__/dataset2.cpython-39.pyc +0 -0
  50. src/BrainIAC/__pycache__/load_brainiac.cpython-39.pyc +0 -0
src/.DS_Store ADDED
Binary file (8.2 kB). View file
 
src/BrainIAC/.DS_Store ADDED
Binary file (8.2 kB). View file
 
src/BrainIAC/Brainage/README.md ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Brain Age Prediction
2
+
3
+ <p align="left">
4
+ <img src="brainage.jpeg" width="200" alt="Brain Age Prediction Example"/>
5
+ </p>
6
+
7
+ ## Overview
8
+
9
+ We present the brainage prediction training and inference code for BrainIAC as a downstream task. The pipeline is trained and infered on T1 scans, with MAE as evaluation metric.
10
+
11
+ ## Data Requirements
12
+
13
+ - **Input**: T1-weighted MRI scans
14
+ - **Format**: NIFTI (.nii.gz)
15
+ - **Preprocessing**: Bias field corrected, registered to standard space, skull stripped
16
+ - **CSV Structure**:
17
+ ```
18
+ pat_id,scandate,label
19
+ subject001,20240101,65 # brain age in years
20
+ ```
21
+ refer to [ quickstart.ipynb](../quickstart.ipynb) to find how to preprocess data and generate csv file.
22
+
23
+
24
+ ## Setup
25
+
26
+ 1. **Configuration**:
27
+ change the [config.yml](../config.yml) file accordingly.
28
+ ```yaml
29
+ # config.yml
30
+ data:
31
+ train_csv: "path/to/train.csv"
32
+ val_csv: "path/to/val.csv"
33
+ test_csv: "path/to/test.csv"
34
+ root_dir: "../data/sample/processed"
35
+ collate: 1 # single scan framework
36
+
37
+ checkpoints: "./checkpoints/brainage_model.00" # for inference/testing
38
+
39
+ train:
40
+ finetune: 'yes' # yes to finetune the entire model
41
+ freeze: 'no' # yes to freeze the resnet backbone
42
+ weights: ./checkpoints/brainiac.ckpt # path to brainiac weights
43
+
44
+ ```
45
+
46
+ 2. **Training**:
47
+ ```bash
48
+ python -m Brainage.train_brainage
49
+ ```
50
+
51
+ 3. **Inference**:
52
+ ```bash
53
+ python -m Brainage.infer_brainage
54
+ ```
55
+
src/BrainIAC/Brainage/__init__.py ADDED
File without changes
src/BrainIAC/Brainage/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (145 Bytes). View file
 
src/BrainIAC/Brainage/__pycache__/infer_brainage.cpython-39.pyc ADDED
Binary file (3 kB). View file
 
src/BrainIAC/Brainage/brainage.jpeg ADDED

Git LFS Details

  • SHA256: 4b844af61b1dea2e772edfddcd8f8adb0453721f7972684b7b580e85ae2addf5
  • Pointer size: 130 Bytes
  • Size of remote file: 33.5 kB
src/BrainIAC/Brainage/infer_brainage.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ import os
4
+ from tqdm import tqdm
5
+ from torch.utils.data import DataLoader
6
+ from torch.cuda.amp import autocast
7
+ from sklearn.metrics import mean_absolute_error
8
+ import sys
9
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
+ from dataset2 import MedicalImageDatasetBalancedIntensity3D
11
+ from model import Backbone, SingleScanModel, Classifier
12
+ from utils import BaseConfig
13
+
14
+ class BrainAgeInference(BaseConfig):
15
+ """
16
+ Inference class for brain age prediction model.
17
+ """
18
+
19
+ def __init__(self):
20
+ """Initialize the inference setup with model and data."""
21
+ super().__init__()
22
+ self.setup_model()
23
+ self.setup_data()
24
+
25
+ def setup_model(self):
26
+ config = self.get_config()
27
+ self.backbone = Backbone()
28
+ self.classifier = Classifier(d_model=2048)
29
+ self.model = SingleScanModel(self.backbone, self.classifier)
30
+
31
+ # Load weights
32
+ checkpoint = torch.load(config["infer"]["checkpoints"], map_location=self.device)
33
+ self.model.load_state_dict(checkpoint["model_state_dict"])
34
+ self.model = self.model.to(self.device)
35
+ self.model.eval()
36
+ print("Model and checkpoint loaded!")
37
+
38
+ ## spinup dataloaders
39
+ def setup_data(self):
40
+ config = self.get_config()
41
+ self.test_dataset = MedicalImageDatasetBalancedIntensity3D(
42
+ csv_path=config["data"]["test_csv"],
43
+ root_dir=config["data"]["root_dir"]
44
+ )
45
+ self.test_loader = DataLoader(
46
+ self.test_dataset,
47
+ batch_size=1,
48
+ shuffle=False,
49
+ collate_fn=self.custom_collate,
50
+ num_workers=1
51
+ )
52
+
53
+ def infer(self):
54
+ """ Infer pass """
55
+ results_df = pd.DataFrame(columns=['PredictedAge', 'TrueAge'])
56
+ all_labels = []
57
+ all_predictions = []
58
+
59
+ with torch.no_grad():
60
+ for sample in tqdm(self.test_loader, desc="Inference", unit="batch"):
61
+ inputs = sample['image'].to(self.device)
62
+ labels = sample['label'].float().to(self.device)
63
+
64
+ with autocast():
65
+ outputs = self.model(inputs)
66
+
67
+ predictions = outputs.cpu().numpy().flatten()
68
+ all_labels.extend(labels.cpu().numpy().flatten())
69
+ all_predictions.extend(predictions)
70
+
71
+ result = pd.DataFrame({
72
+ 'PredictedAge': predictions,
73
+ 'TrueAge': labels.cpu().numpy().flatten()
74
+ })
75
+ results_df = pd.concat([results_df, result], ignore_index=True)
76
+
77
+ mae = mean_absolute_error(all_labels, all_predictions)
78
+ print(f"Mean Absolute Error (MAE): {mae:.4f} months")
79
+ results_df.to_csv('./data/output/brainage_output.csv', index=False)
80
+
81
+ return mae
82
+
83
+ if __name__ == "__main__":
84
+ inferencer = BrainAgeInference()
85
+ mae = inferencer.infer()
src/BrainIAC/Brainage/train_brainage.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ import wandb
6
+ from tqdm import tqdm
7
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
8
+ from torch.cuda.amp import GradScaler, autocast
9
+ from sklearn.metrics import mean_absolute_error
10
+ import os
11
+ import sys
12
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
+ from dataset2 import MedicalImageDatasetBalancedIntensity3D, TransformationMedicalImageDatasetBalancedIntensity3D
14
+ from model import Backbone, SingleScanModel, Classifier
15
+ from utils import BaseConfig
16
+
17
+
18
+ class BrainAgeTrainer(BaseConfig):
19
+ """
20
+ A trainer class for brain age prediction models.
21
+
22
+ This class handles the complete training pipeline including model setup,
23
+ data loading, training loop, and validation.
24
+ Inherits from BaseConfig for configuration management.
25
+ """
26
+
27
+ def __init__(self):
28
+ """Initialize the trainer with model, data, and training setup."""
29
+ super().__init__()
30
+ self.setup_wandb()
31
+ self.setup_model()
32
+ self.setup_data()
33
+ self.setup_training()
34
+
35
+ ## setup wandb logger
36
+ def setup_wandb(self):
37
+ config = self.get_config()
38
+ wandb.init(
39
+ project=config['logger']['project_name'],
40
+ name=config['logger']['run_name'],
41
+ config=config
42
+ )
43
+
44
+ def setup_model(self):
45
+ """
46
+ Set up the model architecture.
47
+
48
+ Initializes the backbone and classifier blocks, and loads
49
+ checkpoints
50
+ """
51
+ self.backbone = Backbone()
52
+ self.classifier = Classifier(d_model=2048)
53
+ self.model = SingleScanModel(self.backbone, self.classifier)
54
+
55
+ # Load BrainIACs weights
56
+ config = self.get_config()
57
+ if config["train"]["finetune"] == "yes":
58
+ checkpoint = torch.load(config["train"]["weights"], map_location=self.device)
59
+ state_dict = checkpoint["state_dict"]
60
+ filtered_state_dict = {}
61
+ for key, value in state_dict.items():
62
+ new_key = key.replace("module.", "backbone.") if key.startswith("module.") else key
63
+ filtered_state_dict[new_key] = value
64
+ self.model.backbone.load_state_dict(filtered_state_dict, strict=False)
65
+ print("Pretrained weights loaded!")
66
+
67
+ # Freeze backbone if specified
68
+ if config["train"]["freeze"] == "yes":
69
+ for param in self.model.backbone.parameters():
70
+ param.requires_grad = False
71
+ print("Backbone weights frozen!")
72
+
73
+ self.model = self.model.to(self.device)
74
+
75
+ def setup_data(self):
76
+ """
77
+ Set up data loaders for training and validation.
78
+ Inherit configuration from the base config
79
+ """
80
+ config = self.get_config()
81
+ self.train_dataset = TransformationMedicalImageDatasetBalancedIntensity3D(
82
+ csv_path=config['data']['train_csv'],
83
+ root_dir=config["data"]["root_dir"]
84
+ )
85
+ self.val_dataset = MedicalImageDatasetBalancedIntensity3D(
86
+ csv_path=config['data']['val_csv'],
87
+ root_dir=config["data"]["root_dir"]
88
+ )
89
+
90
+ self.train_loader = DataLoader(
91
+ self.train_dataset,
92
+ batch_size=config["data"]["batch_size"],
93
+ shuffle=True,
94
+ collate_fn=self.custom_collate,
95
+ num_workers=config["data"]["num_workers"]
96
+ )
97
+ self.val_loader = DataLoader(
98
+ self.val_dataset,
99
+ batch_size=1,
100
+ shuffle=False,
101
+ collate_fn=self.custom_collate,
102
+ num_workers=1
103
+ )
104
+
105
+ def setup_training(self):
106
+ """
107
+ Set up training config with loss, scheduler, optimizer.
108
+ """
109
+ config = self.get_config()
110
+ self.criterion = nn.MSELoss()
111
+ self.optimizer = optim.Adam(
112
+ self.model.parameters(),
113
+ lr=config['optim']['lr'],
114
+ weight_decay=config["optim"]["weight_decay"]
115
+ )
116
+ self.scheduler = CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=2)
117
+ self.scaler = GradScaler()
118
+
119
+ def train(self):
120
+ """
121
+ main training loop
122
+ """
123
+ config = self.get_config()
124
+ max_epochs = config['optim']['max_epochs']
125
+ best_val_loss = float('inf')
126
+ best_val_mae = float('inf')
127
+
128
+ for epoch in range(max_epochs):
129
+ train_loss = self.train_epoch(epoch, max_epochs)
130
+ val_loss, mae = self.validate_epoch(epoch, max_epochs)
131
+
132
+ # Save best model
133
+ if (val_loss <= best_val_loss) and (mae <= best_val_mae):
134
+ print(f"Improved Val Loss from {best_val_loss:.4f} to {val_loss:.4f}")
135
+ print(f"Improved Val MAE from {best_val_mae:.4f} to {mae:.4f}")
136
+ best_val_loss = val_loss
137
+ best_val_mae = mae
138
+ self.save_checkpoint(epoch, val_loss, mae)
139
+
140
+ wandb.finish()
141
+
142
+ def train_epoch(self, epoch, max_epochs):
143
+ """
144
+ Train pass.
145
+
146
+ Args:
147
+ epoch (int): Current epoch number
148
+ max_epochs (int): Total number of epochs
149
+
150
+ Returns:
151
+ float: Average training loss for the epoch
152
+ """
153
+ self.model.train()
154
+ train_loss = 0.0
155
+
156
+ for sample in tqdm(self.train_loader, desc=f"Training Epoch {epoch}/{max_epochs-1}"):
157
+ inputs = sample['image'].to(self.device)
158
+ labels = sample['label'].float().to(self.device)
159
+
160
+ self.optimizer.zero_grad()
161
+ with autocast():
162
+ outputs = self.model(inputs)
163
+ loss = self.criterion(outputs, labels.unsqueeze(1))
164
+
165
+ self.scaler.scale(loss).backward()
166
+ self.scaler.step(self.optimizer)
167
+ self.scaler.update()
168
+
169
+ train_loss += loss.item() * inputs.size(0)
170
+
171
+ train_loss = train_loss / len(self.train_loader.dataset)
172
+ wandb.log({"Train Loss": train_loss})
173
+ return train_loss
174
+
175
+ def validate_epoch(self, epoch, max_epochs):
176
+ """
177
+ Validation pass.
178
+
179
+ Args:
180
+ epoch (int): Current epoch number
181
+ max_epochs (int): Total number of epochs
182
+
183
+ Returns:
184
+ tuple: (validation_loss, mean_absolute_error)
185
+ """
186
+ self.model.eval()
187
+ val_loss = 0.0
188
+ all_labels = []
189
+ all_preds = []
190
+
191
+ with torch.no_grad():
192
+ for sample in tqdm(self.val_loader, desc=f"Validation Epoch {epoch}/{max_epochs-1}"):
193
+ inputs = sample['image'].to(self.device)
194
+ labels = sample['label'].float().to(self.device)
195
+
196
+ outputs = self.model(inputs)
197
+ loss = self.criterion(outputs, labels.unsqueeze(1))
198
+
199
+ val_loss += loss.item() * inputs.size(0)
200
+ all_labels.extend(labels.cpu().numpy().flatten())
201
+ all_preds.extend(outputs.cpu().numpy().flatten())
202
+
203
+ val_loss = val_loss / len(self.val_loader.dataset)
204
+ mae = mean_absolute_error(all_labels, all_preds)
205
+
206
+ wandb.log({"Val Loss": val_loss, "MAE": mae})
207
+ self.scheduler.step(val_loss)
208
+
209
+ print(f"Epoch {epoch}/{max_epochs-1} Val Loss: {val_loss:.4f} MAE: {mae:.4f}")
210
+ return val_loss, mae
211
+
212
+ def save_checkpoint(self, epoch, loss, mae):
213
+ """
214
+ Save model checkpoint.
215
+ """
216
+ config = self.get_config()
217
+ checkpoint = {
218
+ 'model_state_dict': self.model.state_dict(),
219
+ 'loss': loss,
220
+ 'epoch': epoch,
221
+ }
222
+ save_path = os.path.join(
223
+ config['logger']['save_dir'],
224
+ config['logger']['save_name'].format(epoch=epoch, loss=loss, metric=mae)
225
+ )
226
+ torch.save(checkpoint, save_path)
227
+
228
+ if __name__ == "__main__":
229
+ trainer = BrainAgeTrainer()
230
+ trainer.train()
src/BrainIAC/HD_BET/__pycache__/config.cpython-310.pyc ADDED
Binary file (4.15 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/config.cpython-38.pyc ADDED
Binary file (4.13 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/config.cpython-39.pyc ADDED
Binary file (4.19 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-310.pyc ADDED
Binary file (4.47 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-38.pyc ADDED
Binary file (4.48 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/data_loading.cpython-39.pyc ADDED
Binary file (4.46 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-310.pyc ADDED
Binary file (4.21 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/hd_bet.cpython-38.pyc ADDED
Binary file (4.27 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-310.pyc ADDED
Binary file (6.78 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-38.pyc ADDED
Binary file (6.89 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/network_architecture.cpython-39.pyc ADDED
Binary file (6.84 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/paths.cpython-310.pyc ADDED
Binary file (324 Bytes). View file
 
src/BrainIAC/HD_BET/__pycache__/paths.cpython-38.pyc ADDED
Binary file (335 Bytes). View file
 
src/BrainIAC/HD_BET/__pycache__/paths.cpython-39.pyc ADDED
Binary file (322 Bytes). View file
 
src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-310.pyc ADDED
Binary file (3.68 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-38.pyc ADDED
Binary file (3.67 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/predict_case.cpython-39.pyc ADDED
Binary file (3.68 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/run.cpython-310.pyc ADDED
Binary file (3.83 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/run.cpython-38.pyc ADDED
Binary file (3.88 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/run.cpython-39.pyc ADDED
Binary file (3.85 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.68 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.85 kB). View file
 
src/BrainIAC/HD_BET/__pycache__/utils.cpython-39.pyc ADDED
Binary file (4.81 kB). View file
 
src/BrainIAC/HD_BET/config.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from HD_BET.utils import SetNetworkToVal, softmax_helper
4
+ from abc import abstractmethod
5
+ from HD_BET.network_architecture import Network
6
+
7
+
8
+ class BaseConfig(object):
9
+ def __init__(self):
10
+ pass
11
+
12
+ @abstractmethod
13
+ def get_split(self, fold, random_state=12345):
14
+ pass
15
+
16
+ @abstractmethod
17
+ def get_network(self, mode="train"):
18
+ pass
19
+
20
+ @abstractmethod
21
+ def get_basic_generators(self, fold):
22
+ pass
23
+
24
+ @abstractmethod
25
+ def get_data_generators(self, fold):
26
+ pass
27
+
28
+ def preprocess(self, data):
29
+ return data
30
+
31
+ def __repr__(self):
32
+ res = ""
33
+ for v in vars(self):
34
+ if not v.startswith("__") and not v.startswith("_") and v != 'dataset':
35
+ res += (v + ": " + str(self.__getattribute__(v)) + "\n")
36
+ return res
37
+
38
+
39
+ class HD_BET_Config(BaseConfig):
40
+ def __init__(self):
41
+ super(HD_BET_Config, self).__init__()
42
+
43
+ self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name
44
+
45
+ # network parameters
46
+ self.net_base_num_layers = 21
47
+ self.BATCH_SIZE = 2
48
+ self.net_do_DS = True
49
+ self.net_dropout_p = 0.0
50
+ self.net_use_inst_norm = True
51
+ self.net_conv_use_bias = True
52
+ self.net_norm_use_affine = True
53
+ self.net_leaky_relu_slope = 1e-1
54
+
55
+ # hyperparameters
56
+ self.INPUT_PATCH_SIZE = (128, 128, 128)
57
+ self.num_classes = 2
58
+ self.selected_data_channels = range(1)
59
+
60
+ # data augmentation
61
+ self.da_mirror_axes = (2, 3, 4)
62
+
63
+ # validation
64
+ self.val_use_DO = False
65
+ self.val_use_train_mode = False # for dropout sampling
66
+ self.val_num_repeats = 1 # only useful if dropout sampling
67
+ self.val_batch_size = 1 # only useful if dropout sampling
68
+ self.val_save_npz = True
69
+ self.val_do_mirroring = True # test time data augmentation via mirroring
70
+ self.val_write_images = True
71
+ self.net_input_must_be_divisible_by = 16 # we could make a network class that has this as a property
72
+ self.val_min_size = self.INPUT_PATCH_SIZE
73
+ self.val_fn = None
74
+
75
+ # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting keep_runnings_
76
+ # stats=True but not using them in validation. keep_runnings_stats was True before 0.3 but unused and defaults
77
+ # to false in 0.4)
78
+ self.val_use_moving_averages = False
79
+
80
+ def get_network(self, train=True, pretrained_weights=None):
81
+ net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
82
+ self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
83
+ self.net_norm_use_affine, True, self.net_do_DS)
84
+
85
+ if pretrained_weights is not None:
86
+ net.load_state_dict(
87
+ torch.load(pretrained_weights, map_location=lambda storage, loc: storage))
88
+
89
+ if train:
90
+ net.train(True)
91
+ else:
92
+ net.train(False)
93
+ net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
94
+ net.do_ds = False
95
+
96
+ optimizer = None
97
+ self.lr_scheduler = None
98
+ return net, optimizer
99
+
100
+ def get_data_generators(self, fold):
101
+ pass
102
+
103
+ def get_split(self, fold, random_state=12345):
104
+ pass
105
+
106
+ def get_basic_generators(self, fold):
107
+ pass
108
+
109
+ def on_epoch_end(self, epoch):
110
+ pass
111
+
112
+ def preprocess(self, data):
113
+ data = np.copy(data)
114
+ for c in range(data.shape[0]):
115
+ data[c] -= data[c].mean()
116
+ data[c] /= data[c].std()
117
+ return data
118
+
119
+
120
+ config = HD_BET_Config
121
+
src/BrainIAC/HD_BET/data_loading.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import SimpleITK as sitk
2
+ import numpy as np
3
+ from skimage.transform import resize
4
+
5
+
6
+ def resize_image(image, old_spacing, new_spacing, order=3):
7
+ new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))),
8
+ int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))),
9
+ int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2]))))
10
+ return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False)
11
+
12
+
13
+ def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)):
14
+ spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]]
15
+ image = sitk.GetArrayFromImage(itk_image).astype(float)
16
+
17
+ assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed"
18
+
19
+ if not is_seg:
20
+ if np.any([[i != j] for i, j in zip(spacing, spacing_target)]):
21
+ image = resize_image(image, spacing, spacing_target).astype(np.float32)
22
+
23
+ image -= image.mean()
24
+ image /= image.std()
25
+ else:
26
+ new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))),
27
+ int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))),
28
+ int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))))
29
+ image = resize_segmentation(image, new_shape, 1)
30
+ return image
31
+
32
+
33
+ def load_and_preprocess(mri_file):
34
+ images = {}
35
+ # t1
36
+ images["T1"] = sitk.ReadImage(mri_file)
37
+
38
+ properties_dict = {
39
+ "spacing": images["T1"].GetSpacing(),
40
+ "direction": images["T1"].GetDirection(),
41
+ "size": images["T1"].GetSize(),
42
+ "origin": images["T1"].GetOrigin()
43
+ }
44
+
45
+ for k in images.keys():
46
+ images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5))
47
+
48
+ properties_dict['size_before_cropping'] = images["T1"].shape
49
+
50
+ imgs = []
51
+ for seq in ['T1']:
52
+ imgs.append(images[seq][None])
53
+ all_data = np.vstack(imgs)
54
+ print("image shape after preprocessing: ", str(all_data[0].shape))
55
+ return all_data, properties_dict
56
+
57
+
58
+ def save_segmentation_nifti(segmentation, dct, out_fname, order=1):
59
+ '''
60
+ segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out
61
+ of the original image
62
+
63
+ dct:
64
+ size_before_cropping
65
+ brain_bbox
66
+ size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping
67
+ spacing
68
+ origin
69
+ direction
70
+
71
+ :param segmentation:
72
+ :param dct:
73
+ :param out_fname:
74
+ :return:
75
+ '''
76
+ old_size = dct.get('size_before_cropping')
77
+ bbox = dct.get('brain_bbox')
78
+ if bbox is not None:
79
+ seg_old_size = np.zeros(old_size)
80
+ for c in range(3):
81
+ bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c]))
82
+ seg_old_size[bbox[0][0]:bbox[0][1],
83
+ bbox[1][0]:bbox[1][1],
84
+ bbox[2][0]:bbox[2][1]] = segmentation
85
+ else:
86
+ seg_old_size = segmentation
87
+ if np.any(np.array(seg_old_size) != np.array(dct['size'])[[2, 1, 0]]):
88
+ seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order)
89
+ else:
90
+ seg_old_spacing = seg_old_size
91
+ seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32))
92
+ seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]])
93
+ seg_resized_itk.SetOrigin(dct['origin'])
94
+ seg_resized_itk.SetDirection(dct['direction'])
95
+ sitk.WriteImage(seg_resized_itk, out_fname)
96
+
97
+
98
+ def resize_segmentation(segmentation, new_shape, order=3, cval=0):
99
+ '''
100
+ Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency
101
+
102
+ Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one
103
+ hot encoding which is resized and transformed back to a segmentation map.
104
+ This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
105
+ :param segmentation:
106
+ :param new_shape:
107
+ :param order:
108
+ :return:
109
+ '''
110
+ tpe = segmentation.dtype
111
+ unique_labels = np.unique(segmentation)
112
+ assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
113
+ if order == 0:
114
+ return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe)
115
+ else:
116
+ reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
117
+
118
+ for i, c in enumerate(unique_labels):
119
+ reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
120
+ reshaped[reshaped_multihot >= 0.5] = c
121
+ return reshaped
src/BrainIAC/HD_BET/hd_bet.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import sys
5
+ sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/")
6
+ from HD_BET.run import run_hd_bet
7
+ from HD_BET.utils import maybe_mkdir_p, subfiles
8
+ import HD_BET
9
+
10
+ def hd_bet(input_file_or_dir,output_file_or_dir,mode,device,tta,pp=1,save_mask=0,overwrite_existing=1):
11
+
12
+ if output_file_or_dir is None:
13
+ output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir),
14
+ os.path.basename(input_file_or_dir).split(".")[0] + "_bet")
15
+
16
+
17
+ params_file = os.path.join(HD_BET.__path__[0], "model_final.py")
18
+ config_file = os.path.join(HD_BET.__path__[0], "config.py")
19
+
20
+ assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
21
+
22
+ if device == 'cpu':
23
+ pass
24
+ else:
25
+ device = int(device)
26
+
27
+ if os.path.isdir(input_file_or_dir):
28
+ maybe_mkdir_p(output_file_or_dir)
29
+ input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False)
30
+
31
+ if len(input_files) == 0:
32
+ raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here")
33
+
34
+ output_files = [os.path.join(output_file_or_dir, i) for i in input_files]
35
+ input_files = [os.path.join(input_file_or_dir, i) for i in input_files]
36
+ else:
37
+ if not output_file_or_dir.endswith('.nii.gz'):
38
+ output_file_or_dir += '.nii.gz'
39
+ assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
40
+
41
+ output_files = [output_file_or_dir]
42
+ input_files = [input_file_or_dir]
43
+
44
+ if tta == 0:
45
+ tta = False
46
+ elif tta == 1:
47
+ tta = True
48
+ else:
49
+ raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta))
50
+
51
+ if overwrite_existing == 0:
52
+ overwrite_existing = False
53
+ elif overwrite_existing == 1:
54
+ overwrite_existing = True
55
+ else:
56
+ raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing))
57
+
58
+ if pp == 0:
59
+ pp = False
60
+ elif pp == 1:
61
+ pp = True
62
+ else:
63
+ raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
64
+
65
+ if save_mask == 0:
66
+ save_mask = False
67
+ elif save_mask == 1:
68
+ save_mask = True
69
+ else:
70
+ raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
71
+
72
+ run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing)
73
+
74
+
75
+ if __name__ == "__main__":
76
+ print("\n########################")
77
+ print("If you are using hd-bet, please cite the following paper:")
78
+ print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W,"
79
+ "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial"
80
+ "neural networks. arXiv preprint arXiv:1901.11341, 2019.")
81
+ print("########################\n")
82
+
83
+ import argparse
84
+ parser = argparse.ArgumentParser()
85
+ parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be '
86
+ 'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to '
87
+ 'split 4d sequences into 3d images. If folder: all files ending with .nii.gz '
88
+ 'within that folder will be brain extracted.', required=True, type=str)
89
+ parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder'
90
+ ' will be created', required=False, type=str)
91
+ parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will '
92
+ 'use only one set of parameters whereas accurate will '
93
+ 'use the five sets of parameters that resulted from '
94
+ 'our cross-validation as an ensemble. Default: '
95
+ 'accurate',
96
+ required=False)
97
+ parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. '
98
+ 'Must be either int or str. Use int for GPU id or '
99
+ '\'cpu\' to run on CPU. When using CPU you should '
100
+ 'consider disabling tta. Default for -device is: 0',
101
+ required=False)
102
+ parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation '
103
+ '(mirroring). 1= True, 0=False. Disable this '
104
+ 'if you are using CPU to speed things up! '
105
+ 'Default: 1')
106
+ parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disabe postprocessing (remove all'
107
+ ' but the largest connected component in '
108
+ 'the prediction. Default: 1')
109
+ parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation '
110
+ 'mask will not be '
111
+ 'saved')
112
+ parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't "
113
+ "want to overwrite existing "
114
+ "predictions")
115
+
116
+ args = parser.parse_args()
117
+
118
+ hd_bet(args.input,args.output,args.mode,args.device,args.tta,args.pp,args.save_mask,args.overwrite_existing)
119
+
src/BrainIAC/HD_BET/network_architecture.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from HD_BET.utils import softmax_helper
5
+
6
+
7
+ class EncodingModule(nn.Module):
8
+ def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
9
+ inst_norm_affine=True, lrelu_inplace=True):
10
+ nn.Module.__init__(self)
11
+ self.dropout_p = dropout_p
12
+ self.lrelu_inplace = lrelu_inplace
13
+ self.inst_norm_affine = inst_norm_affine
14
+ self.conv_bias = conv_bias
15
+ self.leakiness = leakiness
16
+ self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
17
+ self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
18
+ self.dropout = nn.Dropout3d(dropout_p)
19
+ self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
20
+ self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
21
+
22
+ def forward(self, x):
23
+ skip = x
24
+ x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
25
+ x = self.conv1(x)
26
+ if self.dropout_p is not None and self.dropout_p > 0:
27
+ x = self.dropout(x)
28
+ x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
29
+ x = self.conv2(x)
30
+ x = x + skip
31
+ return x
32
+
33
+
34
+ class Upsample(nn.Module):
35
+ def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
36
+ super(Upsample, self).__init__()
37
+ self.align_corners = align_corners
38
+ self.mode = mode
39
+ self.scale_factor = scale_factor
40
+ self.size = size
41
+
42
+ def forward(self, x):
43
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
44
+ align_corners=self.align_corners)
45
+
46
+
47
+ class LocalizationModule(nn.Module):
48
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
49
+ lrelu_inplace=True):
50
+ nn.Module.__init__(self)
51
+ self.lrelu_inplace = lrelu_inplace
52
+ self.inst_norm_affine = inst_norm_affine
53
+ self.conv_bias = conv_bias
54
+ self.leakiness = leakiness
55
+ self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
56
+ self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
57
+ self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
58
+ self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
59
+
60
+ def forward(self, x):
61
+ x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
62
+ x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
63
+ return x
64
+
65
+
66
+ class UpsamplingModule(nn.Module):
67
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
68
+ lrelu_inplace=True):
69
+ nn.Module.__init__(self)
70
+ self.lrelu_inplace = lrelu_inplace
71
+ self.inst_norm_affine = inst_norm_affine
72
+ self.conv_bias = conv_bias
73
+ self.leakiness = leakiness
74
+ self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
75
+ self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias)
76
+ self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
77
+
78
+ def forward(self, x):
79
+ x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness,
80
+ inplace=self.lrelu_inplace)
81
+ return x
82
+
83
+
84
+ class DownsamplingModule(nn.Module):
85
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
86
+ lrelu_inplace=True):
87
+ nn.Module.__init__(self)
88
+ self.lrelu_inplace = lrelu_inplace
89
+ self.inst_norm_affine = inst_norm_affine
90
+ self.conv_bias = conv_bias
91
+ self.leakiness = leakiness
92
+ self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
93
+ self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)
94
+
95
+ def forward(self, x):
96
+ x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
97
+ b = self.downsample(x)
98
+ return x, b
99
+
100
+
101
+ class Network(nn.Module):
102
+ def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
103
+ final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
104
+ lrelu_inplace=True, do_ds=True):
105
+ super(Network, self).__init__()
106
+
107
+ self.do_ds = do_ds
108
+ self.lrelu_inplace = lrelu_inplace
109
+ self.inst_norm_affine = inst_norm_affine
110
+ self.conv_bias = conv_bias
111
+ self.leakiness = leakiness
112
+ self.final_nonlin = final_nonlin
113
+ self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)
114
+
115
+ self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
116
+ inst_norm_affine=True, lrelu_inplace=True)
117
+ self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True,
118
+ inst_norm_affine=True, lrelu_inplace=True)
119
+
120
+ self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
121
+ inst_norm_affine=True, lrelu_inplace=True)
122
+ self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True,
123
+ inst_norm_affine=True, lrelu_inplace=True)
124
+
125
+ self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
126
+ inst_norm_affine=True, lrelu_inplace=True)
127
+ self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True,
128
+ inst_norm_affine=True, lrelu_inplace=True)
129
+
130
+ self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
131
+ inst_norm_affine=True, lrelu_inplace=True)
132
+ self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True,
133
+ inst_norm_affine=True, lrelu_inplace=True)
134
+
135
+ self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2,
136
+ conv_bias=True, inst_norm_affine=True, lrelu_inplace=True)
137
+
138
+ self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
139
+ self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
140
+ inst_norm_affine=True, lrelu_inplace=True)
141
+
142
+ self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
143
+ inst_norm_affine=True, lrelu_inplace=True)
144
+ self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
145
+ inst_norm_affine=True, lrelu_inplace=True)
146
+
147
+ self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
148
+ inst_norm_affine=True, lrelu_inplace=True)
149
+ self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
150
+ self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
151
+ inst_norm_affine=True, lrelu_inplace=True)
152
+
153
+ self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
154
+ inst_norm_affine=True, lrelu_inplace=True)
155
+ self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
156
+ self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True,
157
+ inst_norm_affine=True, lrelu_inplace=True)
158
+
159
+ self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
160
+ self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
161
+ self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
162
+ self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
163
+ self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
164
+
165
+ def forward(self, x):
166
+ seg_outputs = []
167
+
168
+ x = self.init_conv(x)
169
+ x = self.context1(x)
170
+
171
+ skip1, x = self.down1(x)
172
+ x = self.context2(x)
173
+
174
+ skip2, x = self.down2(x)
175
+ x = self.context3(x)
176
+
177
+ skip3, x = self.down3(x)
178
+ x = self.context4(x)
179
+
180
+ skip4, x = self.down4(x)
181
+ x = self.context5(x)
182
+
183
+ x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
184
+ x = self.up1(x)
185
+
186
+ x = torch.cat((skip4, x), dim=1)
187
+ x = self.loc1(x)
188
+ x = self.up2(x)
189
+
190
+ x = torch.cat((skip3, x), dim=1)
191
+ x = self.loc2(x)
192
+ loc2_seg = self.final_nonlin(self.loc2_seg(x))
193
+ seg_outputs.append(loc2_seg)
194
+ x = self.up3(x)
195
+
196
+ x = torch.cat((skip2, x), dim=1)
197
+ x = self.loc3(x)
198
+ loc3_seg = self.final_nonlin(self.loc3_seg(x))
199
+ seg_outputs.append(loc3_seg)
200
+ x = self.up4(x)
201
+
202
+ x = torch.cat((skip1, x), dim=1)
203
+ x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
204
+ inplace=self.lrelu_inplace)
205
+ x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
206
+ inplace=self.lrelu_inplace)
207
+ x = self.final_nonlin(self.seg_layer(x))
208
+ seg_outputs.append(x)
209
+
210
+ if self.do_ds:
211
+ return seg_outputs[::-1]
212
+ else:
213
+ return seg_outputs[-1]
src/BrainIAC/HD_BET/paths.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # please refer to the readme on where to get the parameters. Save them in this folder:
4
+ # Original Path: "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params"
5
+ # Updated path for Docker container:
6
+ folder_with_parameter_files = "/app/BrainIAC/hdbet_model"
src/BrainIAC/HD_BET/predict_case.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
6
+ if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)):
7
+ shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3
8
+ shp = patient.shape
9
+ new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0],
10
+ shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1],
11
+ shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]]
12
+ for i in range(len(shp)):
13
+ if shp[i] % shape_must_be_divisible_by[i] == 0:
14
+ new_shp[i] -= shape_must_be_divisible_by[i]
15
+ if min_size is not None:
16
+ new_shp = np.max(np.vstack((np.array(new_shp), np.array(min_size))), 0)
17
+ return reshape_by_padding_upper_coords(patient, new_shp, 0), shp
18
+
19
+
20
+ def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
21
+ shape = tuple(list(image.shape))
22
+ new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2,len(shape))), axis=0))
23
+ if pad_value is None:
24
+ if len(shape) == 2:
25
+ pad_value = image[0,0]
26
+ elif len(shape) == 3:
27
+ pad_value = image[0, 0, 0]
28
+ else:
29
+ raise ValueError("Image must be either 2 or 3 dimensional")
30
+ res = np.ones(list(new_shape), dtype=image.dtype) * pad_value
31
+ if len(shape) == 2:
32
+ res[0:0+int(shape[0]), 0:0+int(shape[1])] = image
33
+ elif len(shape) == 3:
34
+ res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image
35
+ return res
36
+
37
+
38
+ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
39
+ new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)):
40
+ with torch.no_grad():
41
+ pad_res = []
42
+ for i in range(patient_data.shape[0]):
43
+ t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size)
44
+ pad_res.append(t[None])
45
+
46
+ patient_data = np.vstack(pad_res)
47
+
48
+ new_shp = patient_data.shape
49
+
50
+ data = np.zeros(tuple([1] + list(new_shp)), dtype=np.float32)
51
+
52
+ data[0] = patient_data
53
+
54
+ if BATCH_SIZE is not None:
55
+ data = np.vstack([data] * BATCH_SIZE)
56
+
57
+ a = torch.rand(data.shape).float()
58
+
59
+ if main_device == 'cpu':
60
+ pass
61
+ else:
62
+ a = a.cuda(main_device)
63
+
64
+ if do_mirroring:
65
+ x = 8
66
+ else:
67
+ x = 1
68
+ all_preds = []
69
+ for i in range(num_repeats):
70
+ for m in range(x):
71
+ data_for_net = np.array(data)
72
+ do_stuff = False
73
+ if m == 0:
74
+ do_stuff = True
75
+ pass
76
+ if m == 1 and (4 in mirror_axes):
77
+ do_stuff = True
78
+ data_for_net = data_for_net[:, :, :, :, ::-1]
79
+ if m == 2 and (3 in mirror_axes):
80
+ do_stuff = True
81
+ data_for_net = data_for_net[:, :, :, ::-1, :]
82
+ if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
83
+ do_stuff = True
84
+ data_for_net = data_for_net[:, :, :, ::-1, ::-1]
85
+ if m == 4 and (2 in mirror_axes):
86
+ do_stuff = True
87
+ data_for_net = data_for_net[:, :, ::-1, :, :]
88
+ if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
89
+ do_stuff = True
90
+ data_for_net = data_for_net[:, :, ::-1, :, ::-1]
91
+ if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
92
+ do_stuff = True
93
+ data_for_net = data_for_net[:, :, ::-1, ::-1, :]
94
+ if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
95
+ do_stuff = True
96
+ data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1]
97
+
98
+ if do_stuff:
99
+ _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net)))
100
+ p = net(a) # np.copy is necessary because ::-1 creates just a view i think
101
+ p = p.data.cpu().numpy()
102
+
103
+ if m == 0:
104
+ pass
105
+ if m == 1 and (4 in mirror_axes):
106
+ p = p[:, :, :, :, ::-1]
107
+ if m == 2 and (3 in mirror_axes):
108
+ p = p[:, :, :, ::-1, :]
109
+ if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
110
+ p = p[:, :, :, ::-1, ::-1]
111
+ if m == 4 and (2 in mirror_axes):
112
+ p = p[:, :, ::-1, :, :]
113
+ if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
114
+ p = p[:, :, ::-1, :, ::-1]
115
+ if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
116
+ p = p[:, :, ::-1, ::-1, :]
117
+ if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
118
+ p = p[:, :, ::-1, ::-1, ::-1]
119
+ all_preds.append(p)
120
+
121
+ stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
122
+ predicted_segmentation = stacked.mean(0).argmax(0)
123
+ uncertainty = stacked.var(0)
124
+ bayesian_predictions = stacked
125
+ softmax_pred = stacked.mean(0)
126
+ return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
src/BrainIAC/HD_BET/run.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import SimpleITK as sitk
4
+ from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
5
+ from HD_BET.predict_case import predict_case_3D_net
6
+ import imp
7
+ from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
8
+ import os
9
+ import HD_BET
10
+
11
+
12
+ def apply_bet(img, bet, out_fname):
13
+ img_itk = sitk.ReadImage(img)
14
+ img_npy = sitk.GetArrayFromImage(img_itk)
15
+ img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
16
+ img_npy[img_bet == 0] = 0
17
+ out = sitk.GetImageFromArray(img_npy)
18
+ out.CopyInformation(img_itk)
19
+ sitk.WriteImage(out, out_fname)
20
+
21
+
22
+ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
23
+ postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
24
+ """
25
+
26
+ :param mri_fnames: str or list/tuple of str
27
+ :param output_fnames: str or list/tuple of str. If list: must have the same length as output_fnames
28
+ :param mode: fast or accurate
29
+ :param config_file: config.py
30
+ :param device: either int (for device id) or 'cpu'
31
+ :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
32
+ but the largest predicted connected component. Default False
33
+ :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
34
+ CPU you may want to turn that off to speed things up
35
+ :return:
36
+ """
37
+
38
+ list_of_param_files = []
39
+
40
+ if mode == 'fast':
41
+ params_file = get_params_fname(0)
42
+ maybe_download_parameters(0)
43
+
44
+ list_of_param_files.append(params_file)
45
+ elif mode == 'accurate':
46
+ for i in range(5):
47
+ params_file = get_params_fname(i)
48
+ maybe_download_parameters(i)
49
+
50
+ list_of_param_files.append(params_file)
51
+ else:
52
+ raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)
53
+
54
+ assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"
55
+
56
+ cf = imp.load_source('cf', config_file)
57
+ cf = cf.config()
58
+
59
+ net, _ = cf.get_network(cf.val_use_train_mode, None)
60
+ if device == "cpu":
61
+ net = net.cpu()
62
+ else:
63
+ net.cuda(device)
64
+
65
+ if not isinstance(mri_fnames, (list, tuple)):
66
+ mri_fnames = [mri_fnames]
67
+
68
+ if not isinstance(output_fnames, (list, tuple)):
69
+ output_fnames = [output_fnames]
70
+
71
+ assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"
72
+
73
+ params = []
74
+ for p in list_of_param_files:
75
+ params.append(torch.load(p, map_location=lambda storage, loc: storage))
76
+
77
+ for in_fname, out_fname in zip(mri_fnames, output_fnames):
78
+ mask_fname = out_fname[:-7] + "_mask.nii.gz"
79
+ if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
80
+ print("File:", in_fname)
81
+ print("preprocessing...")
82
+ try:
83
+ data, data_dict = load_and_preprocess(in_fname)
84
+ except RuntimeError:
85
+ print("\nERROR\nCould not read file", in_fname, "\n")
86
+ continue
87
+ except AssertionError as e:
88
+ print(e)
89
+ continue
90
+
91
+ softmax_preds = []
92
+
93
+ print("prediction (CNN id)...")
94
+ for i, p in enumerate(params):
95
+ print(i)
96
+ net.load_state_dict(p)
97
+ net.eval()
98
+ net.apply(SetNetworkToVal(False, False))
99
+ _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
100
+ cf.val_batch_size, cf.net_input_must_be_divisible_by,
101
+ cf.val_min_size, device, cf.da_mirror_axes)
102
+ softmax_preds.append(softmax_pred[None])
103
+
104
+ seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)
105
+
106
+ if postprocess:
107
+ seg = postprocess_prediction(seg)
108
+
109
+ print("exporting segmentation...")
110
+ save_segmentation_nifti(seg, data_dict, mask_fname)
111
+
112
+ apply_bet(in_fname, mask_fname, out_fname)
113
+
114
+ if not keep_mask:
115
+ os.remove(mask_fname)
116
+
117
+
src/BrainIAC/HD_BET/utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from urllib.request import urlopen
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+ from skimage.morphology import label
6
+ import os
7
+ from HD_BET.paths import folder_with_parameter_files
8
+
9
+
10
+ def get_params_fname(fold):
11
+ return os.path.join(folder_with_parameter_files, "%d.model" % fold)
12
+
13
+
14
+ def maybe_download_parameters(fold=0, force_overwrite=False):
15
+ """
16
+ Downloads the parameters for some fold if it is not present yet.
17
+ :param fold:
18
+ :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
19
+ :return:
20
+ """
21
+
22
+ assert 0 <= fold <= 4, "fold must be between 0 and 4"
23
+
24
+ if not os.path.isdir(folder_with_parameter_files):
25
+ maybe_mkdir_p(folder_with_parameter_files)
26
+
27
+ out_filename = get_params_fname(fold)
28
+
29
+ if force_overwrite and os.path.isfile(out_filename):
30
+ os.remove(out_filename)
31
+
32
+ if not os.path.isfile(out_filename):
33
+ url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
34
+ print("Downloading", url, "...")
35
+ data = urlopen(url).read()
36
+ #out_filename = "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params/0.model"
37
+ with open(out_filename, 'wb') as f:
38
+ f.write(data)
39
+
40
+
41
+ def init_weights(module):
42
+ if isinstance(module, nn.Conv3d):
43
+ module.weight = nn.init.kaiming_normal(module.weight, a=1e-2)
44
+ if module.bias is not None:
45
+ module.bias = nn.init.constant(module.bias, 0)
46
+
47
+
48
+ def softmax_helper(x):
49
+ rpt = [1 for _ in range(len(x.size()))]
50
+ rpt[1] = x.size(1)
51
+ x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
52
+ e_x = torch.exp(x - x_max)
53
+ return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
54
+
55
+
56
+ class SetNetworkToVal(object):
57
+ def __init__(self, use_dropout_sampling=False, norm_use_average=True):
58
+ self.norm_use_average = norm_use_average
59
+ self.use_dropout_sampling = use_dropout_sampling
60
+
61
+ def __call__(self, module):
62
+ if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout):
63
+ module.train(self.use_dropout_sampling)
64
+ elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \
65
+ isinstance(module, nn.InstanceNorm1d) \
66
+ or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \
67
+ isinstance(module, nn.BatchNorm1d):
68
+ module.train(not self.norm_use_average)
69
+
70
+
71
+ def postprocess_prediction(seg):
72
+ # basically look for connected components and choose the largest one, delete everything else
73
+ print("running postprocessing... ")
74
+ mask = seg != 0
75
+ lbls = label(mask, connectivity=mask.ndim)
76
+ lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
77
+ largest_region = np.argmax(lbls_sizes[1:]) + 1
78
+ seg[lbls != largest_region] = 0
79
+ return seg
80
+
81
+
82
+ def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
83
+ if join:
84
+ l = os.path.join
85
+ else:
86
+ l = lambda x, y: y
87
+ res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
88
+ and (prefix is None or i.startswith(prefix))
89
+ and (suffix is None or i.endswith(suffix))]
90
+ if sort:
91
+ res.sort()
92
+ return res
93
+
94
+
95
+ def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
96
+ if join:
97
+ l = os.path.join
98
+ else:
99
+ l = lambda x, y: y
100
+ res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
101
+ and (prefix is None or i.startswith(prefix))
102
+ and (suffix is None or i.endswith(suffix))]
103
+ if sort:
104
+ res.sort()
105
+ return res
106
+
107
+
108
+ subfolders = subdirs # I am tired of confusing those
109
+
110
+
111
+ def maybe_mkdir_p(directory):
112
+ splits = directory.split("/")[1:]
113
+ for i in range(0, len(splits)):
114
+ if not os.path.isdir(os.path.join("", *splits[:i+1])):
115
+ os.mkdir(os.path.join("", *splits[:i+1]))
src/BrainIAC/MCIclassification/README.md ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MCI Classification
2
+
3
+ <p align="left">
4
+ <img src="mci.jpeg" width="200" alt="MCI Classification Example"/>
5
+ </p>
6
+
7
+ ## Overview
8
+
9
+ We present the MCI classification training and inference code for BrainIAC as a downstream task. The pipeline is trained and infered on T1 scans, with AUC and F1 as evaluation metric.
10
+
11
+ ## Data Requirements
12
+
13
+ - **Input**: T1-weighted MR scans
14
+ - **Format**: NIFTI (.nii.gz)
15
+ - **Preprocessing**: Bias field corrected, registered to standard space, skull stripped, histogram normalized (optional)
16
+ - **CSV Structure**:
17
+ ```
18
+ pat_id,scandate,label
19
+ subject001,20240101,1 # 1 for MCI, 0 for healthy control
20
+ ```
21
+ refer to [ quickstart.ipynb](../quickstart.ipynb) to find how to preprocess data and generate csv file.
22
+
23
+ ## Setup
24
+
25
+ 1. **Configuration**:
26
+ change the [config.yml](../config.yml) file accordingly.
27
+ ```yaml
28
+ # config.yml
29
+ data:
30
+ train_csv: "path/to/train.csv"
31
+ val_csv: "path/to/val.csv"
32
+ test_csv: "path/to/test.csv"
33
+ root_dir: "../data/sample/processed"
34
+ collate: 1 # single scan framework
35
+
36
+ checkpoints: "./checkpoints/mci_model.00" # for inference/testing
37
+
38
+ train:
39
+ finetune: 'yes' # yes to finetune the entire model
40
+ freeze: 'no' # yes to freeze the resnet backbone
41
+ weights: ./checkpoints/brainiac.ckpt # path to brainiac weights
42
+ ```
43
+
44
+ 2. **Training**:
45
+ ```bash
46
+ python -m MCIclassification.train_mci
47
+ ```
48
+
49
+ 3. **Inference**:
50
+ ```bash
51
+ python -m MCIclassification.infer_mci
52
+ ```
src/BrainIAC/MCIclassification/__init__.py ADDED
File without changes
src/BrainIAC/MCIclassification/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (154 Bytes). View file
 
src/BrainIAC/MCIclassification/__pycache__/infer_mci.cpython-39.pyc ADDED
Binary file (3.96 kB). View file
 
src/BrainIAC/MCIclassification/infer_mci.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+ import os
4
+ from tqdm import tqdm
5
+ from torch.utils.data import DataLoader
6
+ from torch.cuda.amp import autocast
7
+ from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
8
+ import numpy as np
9
+ import sys
10
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
+ from dataset2 import MedicalImageDatasetBalancedIntensity3D
12
+ from model import Backbone, SingleScanModelBP, Classifier
13
+ from utils import BaseConfig
14
+
15
+
16
+
17
+ def calculate_metrics(pred_probs, pred_labels, true_labels):
18
+ """
19
+ classification metrics.
20
+ Args:
21
+ pred_probs (numpy.ndarray): Predicted probabilities
22
+ pred_labels (numpy.ndarray): Predicted labels
23
+ true_labels (numpy.ndarray): Ground truth labels
24
+
25
+ Returns:
26
+ dict: Dictionary containing accuracy, precision, recall, F1, and AUC metrics
27
+ """
28
+ accuracy = accuracy_score(true_labels, pred_labels)
29
+ precision = precision_score(true_labels, pred_labels)
30
+ recall = recall_score(true_labels, pred_labels)
31
+ f1 = f1_score(true_labels, pred_labels)
32
+ auc = roc_auc_score(true_labels, pred_probs)
33
+
34
+ return {
35
+ 'accuracy': accuracy,
36
+ 'precision': precision,
37
+ 'recall': recall,
38
+ 'f1': f1,
39
+ 'auc': auc
40
+ }
41
+
42
+
43
+ #============================
44
+ # INFERENCE CLASS
45
+ #============================
46
+ class MCIInference(BaseConfig):
47
+ """
48
+ Inference class for MCI classification model.
49
+ """
50
+
51
+ def __init__(self):
52
+ super().__init__()
53
+ self.setup_model()
54
+ self.setup_data()
55
+
56
+ def setup_model(self):
57
+ config = self.get_config()
58
+ self.backbone = Backbone()
59
+ self.classifier = Classifier(d_model=2048, num_classes=1) # Binary classification
60
+ self.model = SingleScanModelBP(self.backbone, self.classifier)
61
+
62
+ # Load weights
63
+ checkpoint = torch.load(config["infer"]["checkpoints"], map_location=self.device, weights_only=False)
64
+ self.model.load_state_dict(checkpoint["model_state_dict"], strict=False)
65
+ self.model = self.model.to(self.device)
66
+ self.model.eval()
67
+ print("Model and checkpoint loaded!")
68
+
69
+ ## spin up data loaders
70
+ def setup_data(self):
71
+ config = self.get_config()
72
+ self.test_dataset = MedicalImageDatasetBalancedIntensity3D(
73
+ csv_path=config["data"]["test_csv"],
74
+ root_dir=config["data"]["root_dir"]
75
+ )
76
+ self.test_loader = DataLoader(
77
+ self.test_dataset,
78
+ batch_size=1,
79
+ shuffle=False,
80
+ collate_fn=self.custom_collate,
81
+ num_workers=1
82
+ )
83
+
84
+ def infer(self):
85
+ """
86
+ Run inference pass
87
+
88
+ Returns:
89
+ dict: Dictionary with evaluation metrics
90
+ """
91
+ results_df = pd.DataFrame(columns=['PredictedProb', 'PredictedLabel', 'TrueLabel'])
92
+ all_labels = []
93
+ all_predictions = []
94
+ all_probs = []
95
+
96
+ with torch.no_grad():
97
+ for sample in tqdm(self.test_loader, desc="Inference", unit="batch"):
98
+ inputs = sample['image'].to(self.device)
99
+ labels = sample['label'].float().to(self.device)
100
+
101
+ with autocast():
102
+ outputs = self.model(inputs)
103
+
104
+ probs = torch.sigmoid(outputs).cpu().numpy().flatten()
105
+ preds = (probs > 0.5).astype(int)
106
+
107
+ all_labels.extend(labels.cpu().numpy().flatten())
108
+ all_predictions.extend(preds)
109
+ all_probs.extend(probs)
110
+
111
+ result = pd.DataFrame({
112
+ 'PredictedProb': probs,
113
+ 'PredictedLabel': preds,
114
+ 'TrueLabel': labels.cpu().numpy().flatten()
115
+ })
116
+
117
+ results_df = pd.concat([results_df, result], ignore_index=True)
118
+
119
+ # log metrics
120
+ """metrics = calculate_metrics(
121
+ np.array(all_probs),
122
+ np.array(all_predictions),
123
+ np.array(all_labels)
124
+ )
125
+
126
+
127
+ print("\nTest Set Metrics:")
128
+ print(f"Accuracy: {metrics['accuracy']:.4f}")
129
+ print(f"Precision: {metrics['precision']:.4f}")
130
+ print(f"Recall: {metrics['recall']:.4f}")
131
+ print(f"F1 Score: {metrics['f1']:.4f}")
132
+ print(f"AUC: {metrics['auc']:.4f}")"""
133
+
134
+ # Save results
135
+ print("PredictedLabel", results_df["PredictedLabel"][0])
136
+ results_df.to_csv('./data/output/mci_classification_predictions.csv', index=False)
137
+
138
+ return None
139
+
140
+ if __name__ == "__main__":
141
+ inferencer = MCIInference()
142
+ metrics = inferencer.infer()
src/BrainIAC/MCIclassification/mci.jpeg ADDED

Git LFS Details

  • SHA256: d2b911a21ab5a985d43fbba6144e7718af86b835833553769ee731c0baf57cd4
  • Pointer size: 130 Bytes
  • Size of remote file: 24.1 kB
src/BrainIAC/MCIclassification/train_mci.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ import wandb
6
+ from tqdm import tqdm
7
+ from torch.optim.lr_scheduler import OneCycleLR
8
+ from torch.cuda.amp import GradScaler, autocast
9
+ import os
10
+ import sys
11
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
+ from dataset2 import MedicalImageDatasetBalancedIntensity3D, TransformationMedicalImageDatasetBalancedIntensity3D
13
+ from model import Backbone, SingleScanModel, Classifier
14
+ from utils import BaseConfig
15
+ import numpy as np
16
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
17
+
18
+
19
+ def calculate_metrics(pred_probs, pred_labels, true_labels):
20
+ """
21
+ classification metrics.
22
+
23
+ Args:
24
+ pred_probs (numpy.ndarray): Predicted probabilities
25
+ pred_labels (numpy.ndarray): Predicted labels
26
+ true_labels (numpy.ndarray): Ground truth labels
27
+
28
+ Returns:
29
+ dict: Dictionary containing accuracy, precision, recall, F1, and AUC
30
+ """
31
+ accuracy = accuracy_score(true_labels, pred_labels)
32
+ precision = precision_score(true_labels, pred_labels)
33
+ recall = recall_score(true_labels, pred_labels)
34
+ f1 = f1_score(true_labels, pred_labels)
35
+ auc = roc_auc_score(true_labels, pred_probs)
36
+
37
+ return {
38
+ 'accuracy': accuracy,
39
+ 'precision': precision,
40
+ 'recall': recall,
41
+ 'f1': f1,
42
+ 'auc': auc
43
+ }
44
+
45
+ #============================
46
+ # TRAINER CLASS
47
+ #============================
48
+
49
+ class MCITrainer(BaseConfig):
50
+ """
51
+ trainer class for MCI classification
52
+ """
53
+
54
+ def __init__(self):
55
+ super().__init__()
56
+ self.setup_wandb()
57
+ self.setup_model()
58
+ self.setup_data()
59
+ self.setup_training()
60
+
61
+ def setup_wandb(self):
62
+ config = self.get_config()
63
+ wandb.init(
64
+ project=config['logger']['project_name'],
65
+ name=config['logger']['run_name'],
66
+ config=config
67
+ )
68
+
69
+ def setup_model(self):
70
+ self.backbone = Backbone()
71
+ # Change classifier to output 1 value for binary classification
72
+ self.classifier = Classifier(d_model=2048, num_classes=1)
73
+ self.model = SingleScanModel(self.backbone, self.classifier)
74
+
75
+ # Load weights from brainiac
76
+ config = self.get_config()
77
+ if config["train"]["finetune"] == "yes":
78
+ checkpoint = torch.load(config["train"]["weights"], map_location=self.device)
79
+ state_dict = checkpoint["state_dict"]
80
+ filtered_state_dict = {}
81
+ for key, value in state_dict.items():
82
+ new_key = key.replace("module.", "backbone.") if key.startswith("module.") else key
83
+ filtered_state_dict[new_key] = value
84
+ self.model.backbone.load_state_dict(filtered_state_dict, strict=False)
85
+ print("Pretrained weights loaded!")
86
+
87
+ if config["train"]["freeze"] == "yes":
88
+ for param in self.model.backbone.parameters():
89
+ param.requires_grad = False
90
+ print("Backbone weights frozen!")
91
+
92
+ self.model = self.model.to(self.device)
93
+
94
+ ## spinup dataloaders
95
+ def setup_data(self):
96
+ config = self.get_config()
97
+ self.train_dataset = TransformationMedicalImageDatasetBalancedIntensity3D(
98
+ csv_path=config['data']['train_csv'],
99
+ root_dir=config["data"]["root_dir"]
100
+ )
101
+ self.val_dataset = MedicalImageDatasetBalancedIntensity3D(
102
+ csv_path=config['data']['val_csv'],
103
+ root_dir=config["data"]["root_dir"]
104
+ )
105
+
106
+ self.train_loader = DataLoader(
107
+ self.train_dataset,
108
+ batch_size=config["data"]["batch_size"],
109
+ shuffle=True,
110
+ collate_fn=self.custom_collate,
111
+ num_workers=config["data"]["num_workers"]
112
+ )
113
+ self.val_loader = DataLoader(
114
+ self.val_dataset,
115
+ batch_size=1,
116
+ shuffle=False,
117
+ collate_fn=self.custom_collate,
118
+ num_workers=1
119
+ )
120
+
121
+ def setup_training(self):
122
+ """
123
+ training setup
124
+ """
125
+ config = self.get_config()
126
+ # BCE loss
127
+ self.criterion = nn.BCEWithLogitsLoss().to(self.device)
128
+ self.optimizer = optim.AdamW(
129
+ self.model.parameters(),
130
+ lr=config['optim']['lr'],
131
+ weight_decay=config["optim"]["weight_decay"]
132
+ )
133
+ self.scheduler = OneCycleLR(
134
+ self.optimizer,
135
+ max_lr=config['optim']['lr'],
136
+ epochs=config['optim']['max_epochs'],
137
+ steps_per_epoch=len(self.train_loader)
138
+ )
139
+ self.scaler = GradScaler()
140
+
141
+ ## main training loop
142
+ def train(self):
143
+ config = self.get_config()
144
+ max_epochs = config['optim']['max_epochs']
145
+ best_metrics = {
146
+ 'val_loss': float('inf'),
147
+ 'accuracy': 0,
148
+ 'precision': 0,
149
+ 'recall': 0,
150
+ 'f1': 0,
151
+ 'auc': 0
152
+ }
153
+
154
+ for epoch in range(max_epochs):
155
+ train_loss = self.train_epoch(epoch, max_epochs)
156
+ val_loss, metrics = self.validate_epoch(epoch, max_epochs)
157
+
158
+ # Save best model based on validation loss and F1 score
159
+ if metrics['auc'] > best_metrics['auc']:
160
+ print(f"New best model found!")
161
+ print(f"Improved Val Loss from {best_metrics['val_loss']:.4f} to {val_loss:.4f}")
162
+ print(f"Improved F1 from {best_metrics['f1']:.4f} to {metrics['f1']:.4f}")
163
+ best_metrics.update(metrics)
164
+ best_metrics['val_loss'] = val_loss
165
+ self.save_checkpoint(epoch, val_loss, metrics)
166
+
167
+ wandb.finish()
168
+
169
+ ## training pass
170
+ def train_epoch(self, epoch, max_epochs):
171
+ self.model.train()
172
+ train_loss = 0.0
173
+
174
+ for sample in tqdm(self.train_loader, desc=f"Training Epoch {epoch}/{max_epochs-1}"):
175
+ inputs = sample['image'].to(self.device)
176
+ labels = sample['label'].float().to(self.device)
177
+
178
+ self.optimizer.zero_grad(set_to_none=True)
179
+ with autocast():
180
+ outputs = self.model(inputs)
181
+ loss = self.criterion(outputs, labels.unsqueeze(1))
182
+
183
+ self.scaler.scale(loss).backward()
184
+
185
+ self.scaler.unscale_(self.optimizer)
186
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
187
+
188
+ self.scaler.step(self.optimizer)
189
+ self.scaler.update()
190
+ self.scheduler.step()
191
+
192
+ train_loss += loss.item() * inputs.size(0)
193
+
194
+ train_loss = train_loss / len(self.train_loader.dataset)
195
+ wandb.log({"Train Loss": train_loss})
196
+ return train_loss
197
+
198
+ ## validation pass
199
+ def validate_epoch(self, epoch, max_epochs):
200
+ self.model.eval()
201
+ val_loss = 0.0
202
+ all_labels = []
203
+ all_preds = []
204
+ all_probs = []
205
+
206
+ with torch.no_grad():
207
+ for sample in tqdm(self.val_loader, desc=f"Validation Epoch {epoch}/{max_epochs-1}"):
208
+ inputs = sample['image'].to(self.device)
209
+ labels = sample['label'].float().to(self.device)
210
+
211
+ outputs = self.model(inputs)
212
+ loss = self.criterion(outputs, labels.unsqueeze(1))
213
+
214
+ # Get probabilities and predictions
215
+ probs = torch.sigmoid(outputs).cpu().numpy()
216
+ preds = (probs > 0.5).astype(int)
217
+
218
+ val_loss += loss.item() * inputs.size(0)
219
+ all_labels.extend(labels.cpu().numpy().flatten())
220
+ all_preds.extend(preds.flatten())
221
+ all_probs.extend(probs.flatten())
222
+
223
+ val_loss = val_loss / len(self.val_loader.dataset)
224
+ metrics = calculate_metrics(
225
+ np.array(all_probs),
226
+ np.array(all_preds),
227
+ np.array(all_labels)
228
+ )
229
+
230
+ wandb.log({
231
+ "Val Loss": val_loss,
232
+ "Accuracy": metrics['accuracy'],
233
+ "Precision": metrics['precision'],
234
+ "Recall": metrics['recall'],
235
+ "F1 Score": metrics['f1'],
236
+ "AUC": metrics['auc']
237
+ })
238
+
239
+ print(f"Epoch {epoch}/{max_epochs-1}")
240
+ print(f"Val Loss: {val_loss:.4f}")
241
+ print(f"Accuracy: {metrics['accuracy']:.4f}")
242
+ print(f"Precision: {metrics['precision']:.4f}")
243
+ print(f"Recall: {metrics['recall']:.4f}")
244
+ print(f"F1 Score: {metrics['f1']:.4f}")
245
+ print(f"AUC: {metrics['auc']:.4f}")
246
+
247
+ return val_loss, metrics
248
+
249
+ ## save best model
250
+ def save_checkpoint(self, epoch, loss, metrics):
251
+ config = self.get_config()
252
+ checkpoint = {
253
+ 'epoch': epoch,
254
+ 'model_state_dict': self.model.state_dict(),
255
+ 'metrics': metrics
256
+ }
257
+ save_path = os.path.join(
258
+ config['logger']['save_dir'],
259
+ config['logger']['save_name'].format(epoch=epoch, loss=loss, metric=metrics['f1'])
260
+ )
261
+ torch.save(checkpoint, save_path)
262
+
263
+ if __name__ == "__main__":
264
+ trainer = MCITrainer()
265
+ trainer.train()
src/BrainIAC/__init__.py ADDED
File without changes
src/BrainIAC/__pycache__/dataset2.cpython-39.pyc ADDED
Binary file (6.71 kB). View file
 
src/BrainIAC/__pycache__/load_brainiac.cpython-39.pyc ADDED
Binary file (1.44 kB). View file