diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..bceafdf2f24771ce70ec25d65db22f31a3ff09b7 Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/BrainIAC/.DS_Store b/src/BrainIAC/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..7fa5d5ac21d1293a5178d86b0ba8df92d22b5b96 Binary files /dev/null and b/src/BrainIAC/.DS_Store differ diff --git a/src/BrainIAC/Brainage/README.md b/src/BrainIAC/Brainage/README.md new file mode 100644 index 0000000000000000000000000000000000000000..44da0a58db45468ee4a26d679d34f19b1214b4b6 --- /dev/null +++ b/src/BrainIAC/Brainage/README.md @@ -0,0 +1,55 @@ +# Brain Age Prediction + +
+
+
+
+
MCI Probability
") + probability_gauge = gr.Image(label="Probability Gauge", type="numpy", show_label=False, elem_classes=["probability-gauge"]) + + with gr.Group(): + gr.Markdown("### Saliency Map Viewer (Axial Slice)") + slice_slider = gr.Slider(label="Select Slice", minimum=0, maximum=0, step=1, value=0, visible=False) + with gr.Row(): + with gr.Column(): + gr.Markdown("Input Slice
") + input_slice_img = gr.Image(label="Input Slice", type="numpy", show_label=False) + with gr.Column(): + gr.Markdown("Saliency Heatmap
") + heatmap_slice_img = gr.Image(label="Saliency Heatmap", type="numpy", show_label=False) + with gr.Column(): + gr.Markdown("Overlay
") + overlay_slice_img = gr.Image(label="Overlay", type="numpy", show_label=False) + + # --- Wire Components --- + submit_btn.click( + fn=process_scan, + inputs=[file_type, scan_file, run_preprocess, generate_saliency_checkbox], + outputs=[prediction_output, input_slice_img, heatmap_slice_img, overlay_slice_img, probability_gauge, slice_slider, saliency_state] + ) + + slice_slider.change( + fn=update_slice_viewer, + inputs=[slice_slider, saliency_state], + outputs=[input_slice_img, heatmap_slice_img, overlay_slice_img] + ) + +# --- Launch the App --- +if __name__ == "__main__": + if model is None: + print("ERROR: Model failed to load. Gradio app cannot start.") + else: + print("Launching Gradio Interface...") + demo.launch(server_name="0.0.0.0", server_port=7860, debug=False, share=False) diff --git a/src/BrainIAC/checkpoints/__init__.py b/src/BrainIAC/checkpoints/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/BrainIAC/checkpoints/mci_model.pt b/src/BrainIAC/checkpoints/mci_model.pt new file mode 100644 index 0000000000000000000000000000000000000000..6df5c762fa2e0621f22c7587bfb0855e23259dbb --- /dev/null +++ b/src/BrainIAC/checkpoints/mci_model.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:152490617007f608a96896e014e3321d22f50426273d928f07ce3fceabf4794d +size 184972165 diff --git a/src/BrainIAC/config.yml b/src/BrainIAC/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..5b0ee43626ea9cdcf96516319bc37639f220736d --- /dev/null +++ b/src/BrainIAC/config.yml @@ -0,0 +1,28 @@ +data: + batch_size: 16 + collate: 1 + num_workers: 4 + root_dir: ./data/sample/processed + test_csv: ./data/csvs/input_scans.csv + train_csv: ./data/csvs/train_set_100.csv + val_csv: ./data/csvs/val_set_100.csv +gpu: + device: cpu + visible_device: '' +infer: + checkpoints: ./checkpoints/mci_model.pt +logger: + project_name: brainage + run_name: 
ExperimentName_trainconfigs + save_dir: ./Checkpoints + save_name: ExperimentName_trainconfigs_checkpoint-{epoch:02d}-{loss:.2f}-{metric:.2f} +optim: + clr: 'no' + lr: 0.0001 + max_epochs: 200 + momentum: 0.9 + weight_decay: 1.0e-05 +train: + finetune: 'yes' + freeze: 'no' + weights: path/to/brainiac/weights diff --git a/src/BrainIAC/data/csvs/brainage.csv b/src/BrainIAC/data/csvs/brainage.csv new file mode 100644 index 0000000000000000000000000000000000000000..0142508025e12c48db6f1915debcf2e4f77158c8 --- /dev/null +++ b/src/BrainIAC/data/csvs/brainage.csv @@ -0,0 +1,2 @@ +pat_id,scandate,label +subpixar009,T1w,42.0 diff --git a/src/BrainIAC/data/csvs/input_scans.csv b/src/BrainIAC/data/csvs/input_scans.csv new file mode 100644 index 0000000000000000000000000000000000000000..5e22b0fb5b082233f8cfe5f4c449f4b288c15807 --- /dev/null +++ b/src/BrainIAC/data/csvs/input_scans.csv @@ -0,0 +1,2 @@ +pat_id,scandate,label +00001,t2f,1 \ No newline at end of file diff --git a/src/BrainIAC/data/csvs/mci.csv b/src/BrainIAC/data/csvs/mci.csv new file mode 100644 index 0000000000000000000000000000000000000000..8e8534800c70904f9fabb53319553eb318b6b43d --- /dev/null +++ b/src/BrainIAC/data/csvs/mci.csv @@ -0,0 +1,2 @@ +pat_id,scandate,label +subpixar009,T1w,0 \ No newline at end of file diff --git a/src/BrainIAC/data/csvs/sequenceclass.csv b/src/BrainIAC/data/csvs/sequenceclass.csv new file mode 100644 index 0000000000000000000000000000000000000000..5e22b0fb5b082233f8cfe5f4c449f4b288c15807 --- /dev/null +++ b/src/BrainIAC/data/csvs/sequenceclass.csv @@ -0,0 +1,2 @@ +pat_id,scandate,label +00001,t2f,1 \ No newline at end of file diff --git a/src/BrainIAC/data/csvs/stroke.csv b/src/BrainIAC/data/csvs/stroke.csv new file mode 100644 index 0000000000000000000000000000000000000000..562dec7da4571578721ffbd8585b8ee3e30f739b --- /dev/null +++ b/src/BrainIAC/data/csvs/stroke.csv @@ -0,0 +1,2 @@ +pat_id,scandate,label +subpixar009,T1w,0.0 diff --git a/src/BrainIAC/data/output/__init__.py 
b/src/BrainIAC/data/output/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/BrainIAC/data/sample/processed/00001_t1c.nii.gz b/src/BrainIAC/data/sample/processed/00001_t1c.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..ab00e12326e51bd3a33fd2480edd410c47b9989c --- /dev/null +++ b/src/BrainIAC/data/sample/processed/00001_t1c.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4399faadcc45c8a4541313cdf88aced7d835ed59ac3078d950e0eac293d603f5 +size 2901768 diff --git a/src/BrainIAC/data/sample/processed/00001_t1n.nii.gz b/src/BrainIAC/data/sample/processed/00001_t1n.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..cdefdc498322d0e737c8427445f7ea14fa043886 --- /dev/null +++ b/src/BrainIAC/data/sample/processed/00001_t1n.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e860924b936e301ddeba20409fbb59dde322475cb49328f1b46a9235c792e73e +size 2614637 diff --git a/src/BrainIAC/data/sample/processed/00001_t2f.nii.gz b/src/BrainIAC/data/sample/processed/00001_t2f.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..fcc44846b1045df8dfb8127531d54bee07d2df30 --- /dev/null +++ b/src/BrainIAC/data/sample/processed/00001_t2f.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:82aed8546af5e6d8d94fd91c56227abdcf6120130390d1556c4342a208980604 +size 2802807 diff --git a/src/BrainIAC/data/sample/processed/00001_t2w.nii.gz b/src/BrainIAC/data/sample/processed/00001_t2w.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..e15aaa1564ff91ecc8497b08a5243489abb71380 --- /dev/null +++ b/src/BrainIAC/data/sample/processed/00001_t2w.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4cd389cc57d12134a30a898c66532228126b9a7d0600ee578e82a32144528b51 +size 2714287 diff --git 
a/src/BrainIAC/data/sample/processed/subpixar009_T1w.nii.gz b/src/BrainIAC/data/sample/processed/subpixar009_T1w.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..8f1b9054c5f5f99941fcc12094cf23e1bc966ccf --- /dev/null +++ b/src/BrainIAC/data/sample/processed/subpixar009_T1w.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e46cc233952a663f6e5f89e07e5f5f436744d190772ee2a5ca1c639fff1f15a +size 1465387 diff --git a/src/BrainIAC/dataset2.py b/src/BrainIAC/dataset2.py new file mode 100644 index 0000000000000000000000000000000000000000..701f9afddfaaffda74283693d81fe1c7a99d48d9 --- /dev/null +++ b/src/BrainIAC/dataset2.py @@ -0,0 +1,168 @@ +import os +import torch +import pandas as pd +from torch.utils.data import Dataset +import nibabel as nib +from monai.transforms import Affined, RandGaussianNoised, Rand3DElasticd, AdjustContrastd, ScaleIntensityd, ToTensord, Resized, RandRotate90d, Resize, RandGaussianSmoothd, GaussianSmoothd, Rotate90d, StdShiftIntensityd, RandAdjustContrastd, Flipd +import random +import numpy as np + + +####################################### +## 3D SYNC TRANSFORM +####################################### + +class NormalSynchronizedTransform3D: + """ Vanilla Validation Transforms""" + + def __init__(self, image_size=(128,128,128), max_rotation=40, translate_range=0.2, scale_range=(0.9, 1.3), apply_prob=0.5): + self.image_size = image_size + self.max_rotation = max_rotation + self.translate_range = translate_range + self.scale_range = scale_range + self.apply_prob = apply_prob + + def __call__(self, scan_list): + transformed_scans = [] + resize_transform = Resized(spatial_size=(128,128,128), keys=["image"]) + scale_transform = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0) # Intensity scaling + tensor_transform = ToTensord(keys=["image"]) # Convert to tensor + + for scan in scan_list: + sample = {"image": scan} + sample = resize_transform(sample) + sample = scale_transform(sample) + sample 
= tensor_transform(sample) + transformed_scans.append(sample["image"].squeeze()) + + return torch.stack(transformed_scans) + +class MedicalImageDatasetBalancedIntensity3D(Dataset): + """ Validation Dataset class """ + + def __init__(self, csv_path, root_dir, transform=None): + self.dataframe = pd.read_csv(csv_path, dtype={"pat_id":str, "scandate":str}) + self.root_dir = root_dir + self.transform = NormalSynchronizedTransform3D() + + def __len__(self): + return len(self.dataframe) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + + ## load the niftis from csv + pat_id = str(self.dataframe.loc[idx, 'pat_id']) + scan_dates = str(self.dataframe.loc[idx, 'scandate']) + label = self.dataframe.loc[idx, 'label'] + scandates = scan_dates.split('-') + scan_list = [] + + + for scandate in scandates: + img_name = os.path.join(self.root_dir , f"{pat_id}_{scandate}.nii.gz") + scan = nib.load(img_name).get_fdata() + scan_list.append(torch.tensor(scan, dtype=torch.float32).unsqueeze(0)) + + ## package into a dictionary for val loader + transformed_scans = self.transform(scan_list) + sample = {"image": transformed_scans, "label": torch.tensor(label, dtype=torch.float32), "pat_id": pat_id} + return sample + + +class SynchronizedTransform3D: + """ Trainign Augmentation method """ + + def __init__(self, image_size=(128,128,128), max_rotation=0.34, translate_range=15, scale_range=(0.9, 1.3), apply_prob=0.5, gaussian_sigma_range=(0.25, 1.5), gaussian_noise_std_range=(0.05, 0.09)): + self.image_size = image_size + self.max_rotation = max_rotation + self.translate_range = translate_range + self.scale_range = scale_range + self.apply_prob = apply_prob + self.gaussian_sigma_range = gaussian_sigma_range + self.gaussian_noise_std_range = gaussian_noise_std_range + + def __call__(self, scan_list): + transformed_scans = [] + rotate_params = (random.uniform(-self.max_rotation, self.max_rotation),) * 3 if random.random() < self.apply_prob else (0, 0, 0) + 
translate_params = tuple([random.uniform(-self.translate_range, self.translate_range) for _ in range(3)]) if random.random() < self.apply_prob else (0, 0, 0) + scale_params = tuple([random.uniform(self.scale_range[0], self.scale_range[1]) for _ in range(3)]) if random.random() < self.apply_prob else (1, 1, 1) + gaussian_sigma = tuple([random.uniform(self.gaussian_sigma_range[0], self.gaussian_sigma_range[1]) for _ in range(3)]) if random.random() < self.apply_prob else None + gaussian_noise_std = random.uniform(self.gaussian_noise_std_range[0], self.gaussian_noise_std_range[1]) if random.random() < self.apply_prob else None + flip_axes = (0,1) if random.random() < self.apply_prob else None # Determine if and along which axes to flip + flip_x = 0 if random.random() < self.apply_prob else None + flip_y = 1 if random.random() < self.apply_prob else None + flip_z = 2 if random.random() < self.apply_prob else None + offset = random.randint(50,100) if random.random() < self.apply_prob else None + gammafactor = random.uniform(0.5,2.0) if random.random() < self.apply_prob else 1 + + affine_transform = Affined(keys=["image"], rotate_params=rotate_params, translate_params=translate_params, scale_params=scale_params, padding_mode='zeros') + gaussian_blur_transform = GaussianSmoothd(keys=["image"], sigma=gaussian_sigma) if gaussian_sigma else None + gaussian_noise_transform = RandGaussianNoised(keys=["image"], std=gaussian_noise_std, prob=1.0, mean=0.0, sample_std=False) if gaussian_noise_std else None + #flip_transform = Rotate90d(keys=["image"], k=1, spatial_axes=flip_axes) if flip_axes else None + flip_x_transform = Flipd(keys=["image"], spatial_axis=flip_x) if flip_x else None + flip_y_transform = Flipd(keys=["image"], spatial_axis=flip_y) if flip_y else None + flip_z_transform = Flipd(keys=["image"], spatial_axis=flip_z) if flip_z else None + resize_transform = Resized(spatial_size=(128,128,128), keys=["image"]) + scale_transform = ScaleIntensityd(keys=["image"], 
minv=0.0, maxv=1.0) # Intensity scaling + tensor_transform = ToTensord(keys=["image"]) # Convert to tensor + shift_intensity = StdShiftIntensityd(keys = ["image"], factor = offset, nonzero=True) + adjust_contrast = AdjustContrastd(keys = ["image"], gamma = gammafactor) + + for scan in scan_list: + sample = {"image": scan} + sample = resize_transform(sample) + sample = affine_transform(sample) + if flip_x_transform: + sample = flip_x_transform(sample) + if flip_y_transform: + sample = flip_y_transform(sample) + if flip_z_transform: + sample = flip_z_transform(sample) + if gaussian_blur_transform: + sample = gaussian_blur_transform(sample) + if offset: + sample = shift_intensity(sample) + sample = scale_transform(sample) + sample = adjust_contrast(sample) + if gaussian_noise_transform: + sample = gaussian_noise_transform(sample) + sample = tensor_transform(sample) + transformed_scans.append(sample["image"].squeeze()) + + return torch.stack(transformed_scans) + + +class TransformationMedicalImageDatasetBalancedIntensity3D(Dataset): + """ Training Dataset class """ + + def __init__(self, csv_path, root_dir, transform=None): + self.dataframe = pd.read_csv(csv_path, dtype={"pat_id":str, "scandate":str}) + self.root_dir = root_dir + self.transform = SynchronizedTransform3D() # calls training augmentations + + def __len__(self): + return len(self.dataframe) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + + ## load the niftis from csv + pat_id = str(self.dataframe.loc[idx, 'pat_id']) + scan_dates = str(self.dataframe.loc[idx, 'scandate']) + label = self.dataframe.loc[idx, 'label'] + scandates = scan_dates.split('-') + scan_list = [] + + + for scandate in scandates: + img_name = os.path.join(self.root_dir , f"{pat_id}_{scandate}.nii.gz") #f"{pat_id}_{scandate}.nii.gz") + scan = nib.load(img_name).get_fdata() + scan_list.append(torch.tensor(scan, dtype=torch.float32).unsqueeze(0)) + + # package into a monai type dictionary + transformed_scans 
= self.transform(scan_list) + sample = {"image": transformed_scans, "label": torch.tensor(label, dtype=torch.float32), "pat_id": pat_id} + return sample diff --git a/src/BrainIAC/get_brainiac_features.py b/src/BrainIAC/get_brainiac_features.py new file mode 100644 index 0000000000000000000000000000000000000000..dda6e2f39d2974bb6478309551216addd00440ce --- /dev/null +++ b/src/BrainIAC/get_brainiac_features.py @@ -0,0 +1,119 @@ +import torch +import numpy as np +import pandas as pd +import random +import yaml +import os +import argparse +from tqdm import tqdm +from torch.utils.data import DataLoader +from dataset2 import MedicalImageDatasetBalancedIntensity3D +from load_brainiac import load_brainiac + +# fix random seed +seed = 42 +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) + + +# Set GPU +os.environ['CUDA_VISIBLE_DEVICES'] = "0" +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + +# Define custom collate function for data loading +def custom_collate(batch): + images = [item['image'] for item in batch] + labels = [item['label'] for item in batch] + + max_len = 1 + padded_images = [] + + for img in images: + pad_size = max_len - img.shape[0] + if pad_size > 0: + padding = torch.zeros((pad_size,) + img.shape[1:]) + img_padded = torch.cat([img, padding], dim=0) + padded_images.append(img_padded) + else: + padded_images.append(img) + + return {"image": torch.stack(padded_images, dim=0), "label": torch.stack(labels)} + +#========================= +# Inference function +#========================= + +def infer(model, test_loader): + features_df = None # Placeholder for feature DataFrame + model.eval() + + with torch.no_grad(): + for sample in tqdm(test_loader, desc="Inference", unit="batch"): + inputs = sample['image'].to(device) + class_labels = sample['label'].float().to(device) + + # Get features from the model + features = model(inputs) + features_numpy = 
features.cpu().numpy() + + # Expand features into separate columns + feature_columns = [f'Feature_{i}' for i in range(features_numpy.shape[1])] + batch_features = pd.DataFrame( + features_numpy, + columns=feature_columns + ) + batch_features['GroundTruthClassLabel'] = class_labels.cpu().numpy().flatten() + + # Append batch features to features_df + if features_df is None: + features_df = batch_features + else: + features_df = pd.concat([features_df, batch_features], ignore_index=True) + + return features_df + +#========================= +# Main inference pipeline +#========================= + +def main(): + # argparse + parser = argparse.ArgumentParser(description='Extract BrainIAC features from images') + parser.add_argument('--checkpoint', type=str, required=True, + help='Path to the BrainIAC model checkpoint') + parser.add_argument('--input_csv', type=str, required=True, + help='Path to the input CSV file containing image paths') + parser.add_argument('--output_csv', type=str, required=True, + help='Path to save the output features CSV') + parser.add_argument('--root_dir', type=str, required=True, + help='Root directory containing the image data') + args = parser.parse_args() + + # spinup the dataloader + test_dataset = MedicalImageDatasetBalancedIntensity3D( + csv_path=args.input_csv, + root_dir=args.root_dir + ) + test_loader = DataLoader( + test_dataset, + batch_size=1, + shuffle=False, + collate_fn=custom_collate, + num_workers=1 + ) + + # Load brainiac + model = load_brainiac(args.checkpoint, device) + model = model.to(device) + # infer + features_df = infer(model, test_loader) + + # Save features + features_df.to_csv(args.output_csv, index=False) + print(f"Features saved to {args.output_csv}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/BrainIAC/get_brainiac_saliencymap.py b/src/BrainIAC/get_brainiac_saliencymap.py new file mode 100644 index 
0000000000000000000000000000000000000000..5c3ce921c2802afe552642f2862ce2af9b0a3ac4 --- /dev/null +++ b/src/BrainIAC/get_brainiac_saliencymap.py @@ -0,0 +1,111 @@ +import torch +import numpy as np +import random +import yaml +import os +import argparse +from tqdm import tqdm +from torch.utils.data import DataLoader +import nibabel as nib +from monai.visualize.gradient_based import SmoothGrad, GuidedBackpropSmoothGrad +from dataset2 import MedicalImageDatasetBalancedIntensity3D +from load_brainiac import load_brainiac + +# Fix random seed +seed = 42 +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) + +# collate funcntion (unneccerary for single timpoint input) +def custom_collate(batch): + """Handles variable size of the scans and pads the sequence dimension.""" + images = [item['image'] for item in batch] + labels = [item['label'] for item in batch] + patids = [item['pat_id'] for item in batch] + + max_len = 1 # singlescan input + padded_images = [] + + for img in images: + pad_size = max_len - img.shape[0] + if pad_size > 0: + padding = torch.zeros((pad_size,) + img.shape[1:]) + img_padded = torch.cat([img, padding], dim=0) + padded_images.append(img_padded) + else: + padded_images.append(img) + + return {"image": torch.stack(padded_images, dim=0), "label": labels, "pat_id": patids} + + +def generate_saliency_maps(model, data_loader, output_dir, device): + """Generate saliency maps using guided backprop method""" + model.eval() + visualizer = GuidedBackpropSmoothGrad(model=model.backbone, stdev_spread=0.15, n_samples=10, magnitude=True) + + for sample in tqdm(data_loader, desc="Generating saliency maps"): + inputs = sample['image'].requires_grad_(True) + patids = sample["pat_id"] + imagename = patids[0] + + input_tensor = inputs.to(device) + + with torch.enable_grad(): + saliency_map = visualizer(input_tensor) + + # Save input image and saliency map + inputs_np = input_tensor.squeeze().cpu().detach().numpy() + saliency_np = 
saliency_map.squeeze().cpu().detach().numpy() + + input_nifti = nib.Nifti1Image(inputs_np, np.eye(4)) + saliency_nifti = nib.Nifti1Image(saliency_np, np.eye(4)) + + # Save files + nib.save(input_nifti, os.path.join(output_dir, f"{imagename}_image.nii.gz")) + nib.save(saliency_nifti, os.path.join(output_dir, f"{imagename}_saliencymap.nii.gz")) + +def main(): + parser = argparse.ArgumentParser(description='Generate saliency maps for medical images') + parser.add_argument('--checkpoint', type=str, required=True, + help='Path to the model checkpoint') + parser.add_argument('--input_csv', type=str, required=True, + help='Path to the input CSV file containing image paths') + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save saliency maps') + parser.add_argument('--root_dir', type=str, required=True, + help='Root directory containing the image data') + + args = parser.parse_args() + device = torch.device("cpu") + + # Create output directory if it doesn't exist + os.makedirs(args.output_dir, exist_ok=True) + + # Initialize dataset and dataloader + dataset = MedicalImageDatasetBalancedIntensity3D( + csv_path=args.input_csv, + root_dir=args.root_dir + ) + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + collate_fn=custom_collate, + num_workers=1 + ) + + # Load brainiac and ensure it's on CPU + model = load_brainiac(args.checkpoint, device) + model = model.to(device) + + # Make sure model weights are on CPU + model.backbone = model.backbone.to(device) + + # Generate saliency maps + generate_saliency_maps(model, dataloader, args.output_dir, device) + + print(f"Saliency maps generated and saved to {args.output_dir}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/BrainIAC/golden_image/.DS_Store b/src/BrainIAC/golden_image/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..6cf36feab32f697bfed66e5ec110a0e87f914476 Binary files /dev/null and 
b/src/BrainIAC/golden_image/.DS_Store differ diff --git a/src/BrainIAC/golden_image/mni_templates/.DS_Store b/src/BrainIAC/golden_image/mni_templates/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/src/BrainIAC/golden_image/mni_templates/.DS_Store differ diff --git a/src/BrainIAC/golden_image/mni_templates/Parameters_Rigid.txt b/src/BrainIAC/golden_image/mni_templates/Parameters_Rigid.txt new file mode 100644 index 0000000000000000000000000000000000000000..19d729f7e970a3683fe06b2b59a24f27916a516a --- /dev/null +++ b/src/BrainIAC/golden_image/mni_templates/Parameters_Rigid.txt @@ -0,0 +1,141 @@ +// Example parameter file for rotation registration +// C-style comments: // + +// The internal pixel type, used for internal computations +// Leave to float in general. +// NB: this is not the type of the input images! The pixel +// type of the input images is automatically read from the +// images themselves. +// This setting can be changed to "short" to save some memory +// in case of very large 3D images. +(FixedInternalImagePixelType "float") +(MovingInternalImagePixelType "float") + +// **************** Main Components ************************** + +// The following components should usually be left as they are: +(Registration "MultiResolutionRegistration") +(Interpolator "BSplineInterpolator") +(ResampleInterpolator "FinalBSplineInterpolator") +(Resampler "DefaultResampler") + +// These may be changed to Fixed/MovingSmoothingImagePyramid. +// See the manual. +(FixedImagePyramid "FixedRecursiveImagePyramid") +(MovingImagePyramid "MovingRecursiveImagePyramid") + +// The following components are most important: +// The optimizer AdaptiveStochasticGradientDescent (ASGD) works +// quite ok in general. The Transform and Metric are important +// and need to be chosen careful for each application. See manual. 
+(Optimizer "AdaptiveStochasticGradientDescent") +(Transform "EulerTransform") +(Metric "AdvancedMattesMutualInformation") + +// ***************** Transformation ************************** + +// Scales the rotations compared to the translations, to make +// sure they are in the same range. In general, it's best to +// use automatic scales estimation: +(AutomaticScalesEstimation "true") + +// Automatically guess an initial translation by aligning the +// geometric centers of the fixed and moving. +(AutomaticTransformInitialization "true") + +// Whether transforms are combined by composition or by addition. +// In generally, Compose is the best option in most cases. +// It does not influence the results very much. +(HowToCombineTransforms "Compose") + +// ******************* Similarity measure ********************* + +// Number of grey level bins in each resolution level, +// for the mutual information. 16 or 32 usually works fine. +// You could also employ a hierarchical strategy: +//(NumberOfHistogramBins 16 32 64) +(NumberOfHistogramBins 32) + +// If you use a mask, this option is important. +// If the mask serves as region of interest, set it to false. +// If the mask indicates which pixels are valid, then set it to true. +// If you do not use a mask, the option doesn't matter. +(ErodeMask "false") + +// ******************** Multiresolution ********************** + +// The number of resolutions. 1 Is only enough if the expected +// deformations are small. 3 or 4 mostly works fine. For large +// images and large deformations, 5 or 6 may even be useful. +(NumberOfResolutions 4) + +// The downsampling/blurring factors for the image pyramids. +// By default, the images are downsampled by a factor of 2 +// compared to the next resolution. 
+// So, in 2D, with 4 resolutions, the following schedule is used: +//(ImagePyramidSchedule 8 8 4 4 2 2 1 1 ) +// And in 3D: +//(ImagePyramidSchedule 8 8 8 4 4 4 2 2 2 1 1 1 ) +// You can specify any schedule, for example: +//(ImagePyramidSchedule 4 4 4 3 2 1 1 1 ) +// Make sure that the number of elements equals the number +// of resolutions times the image dimension. + +// ******************* Optimizer **************************** + +// Maximum number of iterations in each resolution level: +// 200-500 works usually fine for rigid registration. +// For more robustness, you may increase this to 1000-2000. +(MaximumNumberOfIterations 250) + +// The step size of the optimizer, in mm. By default the voxel size is used. +// which usually works well. In case of unusual high-resolution images +// (eg histology) it is necessary to increase this value a bit, to the size +// of the "smallest visible structure" in the image: +//(MaximumStepLength 1.0) + +// **************** Image sampling ********************** + +// Number of spatial samples used to compute the mutual +// information (and its derivative) in each iteration. +// With an AdaptiveStochasticGradientDescent optimizer, +// in combination with the two options below, around 2000 +// samples may already suffice. +(NumberOfSpatialSamples 2048) + +// Refresh these spatial samples in every iteration, and select +// them randomly. See the manual for information on other sampling +// strategies. +(NewSamplesEveryIteration "true") +(ImageSampler "Random") + +// ************* Interpolation and Resampling **************** + +// Order of B-Spline interpolation used during registration/optimisation. +// It may improve accuracy if you set this to 3. Never use 0. +// An order of 1 gives linear interpolation. This is in most +// applications a good choice. +(BSplineInterpolationOrder 1) + +// Order of B-Spline interpolation used for applying the final +// deformation. +// 3 gives good accuracy; recommended in most cases. 
+// 1 gives worse accuracy (linear interpolation) +// 0 gives worst accuracy, but is appropriate for binary images +// (masks, segmentations); equivalent to nearest neighbor interpolation. +(FinalBSplineInterpolationOrder 3) + +//Default pixel value for pixels that come from outside the picture: +(DefaultPixelValue 0) + +// Choose whether to generate the deformed moving image. +// You can save some time by setting this to false, if you are +// only interested in the final (nonrigidly) deformed moving image +// for example. +(WriteResultImage "true") + +// The pixel type and format of the resulting deformed moving image +(ResultImagePixelType "short") +(ResultImageFormat "mhd") + + diff --git a/src/BrainIAC/golden_image/mni_templates/nihpd_asym_04.5-08.5_t1w.nii b/src/BrainIAC/golden_image/mni_templates/nihpd_asym_04.5-08.5_t1w.nii new file mode 100644 index 0000000000000000000000000000000000000000..9a9d07d6300675c02166186d9a980a7f8d86cf57 --- /dev/null +++ b/src/BrainIAC/golden_image/mni_templates/nihpd_asym_04.5-08.5_t1w.nii @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66273629365bd2cde6b2e2e7a74c0de52b703d1f92c12d2e4746aa9b47684e27 +size 17350930 diff --git a/src/BrainIAC/golden_image/mni_templates/nihpd_asym_07.5-13.5_t1w.nii b/src/BrainIAC/golden_image/mni_templates/nihpd_asym_07.5-13.5_t1w.nii new file mode 100644 index 0000000000000000000000000000000000000000..bfb6577cdd506e176d9e8d8522e67d88ac0ee797 --- /dev/null +++ b/src/BrainIAC/golden_image/mni_templates/nihpd_asym_07.5-13.5_t1w.nii @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07d05690e2d2635b28673cd4667617c97ae14f648a8923ba7de2fb6578ae61ad +size 17350930 diff --git a/src/BrainIAC/golden_image/mni_templates/nihpd_asym_13.0-18.5_t1w.nii b/src/BrainIAC/golden_image/mni_templates/nihpd_asym_13.0-18.5_t1w.nii new file mode 100644 index 0000000000000000000000000000000000000000..1c3f42145f63f2ef3a7b26e9ac44958491fc8aab --- /dev/null +++ 
b/src/BrainIAC/golden_image/mni_templates/nihpd_asym_13.0-18.5_t1w.nii @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f10804664000688f0ddc124b39ce3ae27f2c339a40583bd6ff916727e97b77d0 +size 17350930 diff --git a/src/BrainIAC/hdbet_model/0.model b/src/BrainIAC/hdbet_model/0.model new file mode 100644 index 0000000000000000000000000000000000000000..23d2336bed49651cb402e47ec75d81588aa5ce8d --- /dev/null +++ b/src/BrainIAC/hdbet_model/0.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f75233753c4750672815e2b7a86db754995ae44b8f1cd77bccfc37becd2d83c +size 65443735 diff --git a/src/BrainIAC/healthy_brain_preprocess.py b/src/BrainIAC/healthy_brain_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..041e42375b21e94ae51716df77f9c4d31cc85d25 --- /dev/null +++ b/src/BrainIAC/healthy_brain_preprocess.py @@ -0,0 +1,149 @@ +from __future__ import generators + +import logging +import glob, os, functools +import sys +sys.path.append('../') + + +import SimpleITK as sitk +import numpy as np +import scipy +import nibabel as nib +import skimage +import matplotlib.pyplot as plt +import scipy.misc +from scipy import ndimage +from skimage.transform import resize,rescale +import cv2 +import itk +import subprocess +from tqdm import tqdm +import pandas as pd +import warnings +import statistics +import torch +import csv +import os +import yaml + +from HD_BET.run import run_hd_bet # git clone HDBET repo +from dataset.preprocess_utils import enhance, enhance_noN4 +from dataset.preprocess_datasets_T1_to_2d import create_quantile_from_brain + +warnings.filterwarnings('ignore') +cuda_device = '1' +os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +#torch.cuda.set_device(1) # Set CUDA device to 0 (first GPU) +#net = net.cuda() + +def select_template_based_on_age(age): + for golden_file_path, age_values in age_ranges.items(): + if 
# [diff residue] tail of select_template_based_on_age(), whose header sits on the
# previous line of this collapsed diff dump:
#   age_values['min_age'] <= int(age) and int(age) <= age_values['max_age']:
#       print(golden_file_path)
#       return golden_file_path


def register_to_template(input_image_path, output_path, fixed_image_path, rename_id, create_subfolder=True):
    """Rigidly register one NIfTI image to an MNI template using elastix (ITK).

    Args:
        input_image_path: path to the moving image (.nii / .nii.gz).
        output_path: directory the registered image is written into.
        fixed_image_path: path to the age-matched MNI template.
        rename_id: basename (without extension) for the output file.
        create_subfolder: unused here; kept for call-site compatibility.
    """
    fixed_image = itk.imread(fixed_image_path, itk.F)

    # Import the rigid-registration parameter map.
    # TODO(review): hard-coded absolute path — should come from config.
    parameter_object = itk.ParameterObject.New()
    parameter_object.AddParameterFile('/media/sdb/divyanshu/divyanshu/aidan_segmentation/pediatric-brain-age-main/dataset/golden_image/mni_templates/Parameters_Rigid.txt')

    # "._" filters out macOS AppleDouble sidecar files.
    if "nii" in input_image_path and "._" not in input_image_path:
        print(input_image_path)

        # Best-effort: a failed registration is reported but does not abort the batch.
        try:
            moving_image = itk.imread(input_image_path, itk.F)
            result_image, result_transform_parameters = itk.elastix_registration_method(
                fixed_image, moving_image,
                parameter_object=parameter_object,
                log_to_console=False)
            image_id = input_image_path.split("/")[-1]

            itk.imwrite(result_image, output_path + "/" + rename_id + ".nii.gz")

            print("Registered ", rename_id)
        except Exception as exc:  # BUG FIX: was a bare `except:` that hid the cause
            print("Cannot transform", rename_id, "-", exc)


def outlier_voting(numbers):
    """Average `numbers` after dropping values more than one stdev from the mean.

    Args:
        numbers: sequence of numeric predictions.

    Returns:
        Mean of the values within one sample standard deviation of the overall
        mean. With fewer than two values there is no spread to filter on, so
        the plain mean is returned (statistics.stdev would raise).
    """
    if len(numbers) < 2:
        # robustness: stdev() needs at least two data points
        return np.average(numbers)

    mean = statistics.mean(numbers)
    stdev = statistics.stdev(numbers)

    threshold = stdev  # *2 #*3

    good_nums_avg = []
    for n in numbers:
        if n > mean + threshold or n < mean - threshold:
            continue
        else:
            good_nums_avg.append(n)

    #if len(good_nums_avg)<=3:
    #    print(len(good_nums_avg))
    return np.average(good_nums_avg)


#Data https://openneuro.org/datasets/ds000228/versions/1.1.0
img_path = '/media/sdb/divyanshu/divyanshu/aidan_segmentation/dummy_t1_preprocess/OAS2_0001_MR1.nii.gz'
data_path = "/media/sdb/divyanshu/divyanshu/longitudinal_fm/datasets/abide/data"
gt_age = 86  # age of subject
gender = "M"  # gender
path_to = "/media/sdb/divyanshu/divyanshu/longitudinal_fm/datasets/abide/preprocessed_data"  # save to

# MNI templates http://nist.mni.mcgill.ca/pediatric-atlases-4-5-18-5y/
# [diff residue] the age_ranges dict literal continues on the next dump line:
# age_ranges =
{"/media/data/BrainIAC/src/BrainIAC/golden_image/mni_templates/nihpd_asym_04.5-08.5_t1w.nii" : {"min_age":3, "max_age":7}, + "/media/data/BrainIAC/src/BrainIAC/golden_image/mni_templates/nihpd_asym_07.5-13.5_t1w.nii": {"min_age":8, "max_age":13}, + "/media/data/BrainIAC/src/BrainIAC/golden_image/mni_templates/nihpd_asym_13.0-18.5_t1w.nii": {"min_age":14, "max_age":100}} + + +for eachimage in tqdm(os.listdir(data_path), desc="Processing images", unit="image"): + if 1:#"sub" in eachimage: + ## load image + img_path = os.path.join(data_path, eachimage) + nii= nib.load(img_path) + image, affine = nii.get_fdata(), nii.affine + #plt.imshow(image[:,:,100]) + #print(nib.aff2axcodes(affine)) + + # path to store registered image in + new_path_to = path_to#path_to+"/"+img_path.split("/")[-1].split(".")[0] + if eachimage in os.listdir(new_path_to): + print("yay") + else: + if not os.path.exists(path_to): + os.mkdir(path_to) + if not os.path.exists(new_path_to): + os.mkdir(new_path_to) + + # register image to MNI template + golden_file_path = select_template_based_on_age(gt_age) + print("Registering to template:", golden_file_path) + #fun fact: the registering to the template pipeline is not deterministic + register_to_template(img_path, new_path_to, golden_file_path,eachimage.split(".")[0]+"_"+"registered.nii.gz", create_subfolder=False) + + + # enchance and normalize image + #if not os.path.exists(new_path_to+"/no_z"): + # os.mkdir(new_path_to+"/no_z") + + image_sitk = sitk.ReadImage(os.path.join(new_path_to, eachimage.split(".")[0]+"_"+"registered.nii.gz")) + image_array = sitk.GetArrayFromImage(image_sitk) + image_array = enhance(image_array) # or enhance_noN4(image_array) if no bias field correction is needed + image3 = sitk.GetImageFromArray(image_array) + sitk.WriteImage(image3,os.path.join(new_path_to, eachimage.split(".")[0]+"_"+"registered_no_z.nii.gz")) + + #skull strip ## when running this with rest of the preprocessing, change the src path to include the registered 
image path!!!! + new_path_to = path_to + run_hd_bet(os.path.join(new_path_to, eachimage.split(".")[0]+"_"+"registered_no_z.nii.gz"),os.path.join(new_path_to, eachimage), + mode="accurate", + config_file='/media/sdb/divyanshu/divyanshu/aidan_segmentation/pediatric-brain-age-main/HD_BET/config.py', + device=device, + postprocess=False, + do_tta=True, + keep_mask=True, + overwrite=True) + diff --git a/src/BrainIAC/load_brainiac.py b/src/BrainIAC/load_brainiac.py new file mode 100644 index 0000000000000000000000000000000000000000..230c316a89557f39d1ed376bddf261d44a8d8610 --- /dev/null +++ b/src/BrainIAC/load_brainiac.py @@ -0,0 +1,39 @@ +import torch +from model import ResNet50_3D +import argparse + +def load_brainiac(checkpoint_path, device='cuda'): + """ + Load the ResNet50 model and BrainIAC checkpoint. + + Args: + checkpoint_path (str): Path to the model checkpoint + device (str): Device to load the model on ('cuda' or 'cpu') + + Returns: + model: Loaded model with checkpoint weights + """ + # spinup the model + model = ResNet50_3D() + + # Load brainiac weights + checkpoint = torch.load(checkpoint_path, map_location=device) + state_dict = checkpoint["state_dict"] + filtered_state_dict = {key: value for key, value in state_dict.items() if 'backbone' in key} + model.load_state_dict(filtered_state_dict) + print("BrainIAC Loaded!!") + + return model + +if __name__ == "__main__": + # Parse args + parser = argparse.ArgumentParser(description='Load backbone model from checkpoint') + parser.add_argument('--checkpoint', type=str, required=True, + help='Path to the model checkpoint') + parser.add_argument('--device', type=str, default='cuda', + help='Device to load the model on (cuda or cpu)') + args = parser.parse_args() + + # Load model + model = load_brainiac(args.checkpoint, args.device) + print(f"Model loaded successfully from {args.checkpoint}!") \ No newline at end of file diff --git a/src/BrainIAC/model.py b/src/BrainIAC/model.py new file mode 100644 index 
# [diff residue] 0000000000000000000000000000000000000000..92bf8b87cb7e5678ae4f79a430ffc588a26ec271
# --- /dev/null +++ b/src/BrainIAC/model.py @@ -0,0 +1,85 @@
import torch.nn as nn
from monai.networks.nets import resnet101, resnet50, resnet18, ViT
import torch


## resnet50 architecture; the classification FC layer is replaced with Identity
## so the network emits pooled backbone features instead of class logits.
class ResNet50_3D(nn.Module):
    """3D ResNet-50 feature extractor for single-channel volumes."""

    def __init__(self):
        super(ResNet50_3D, self).__init__()

        resnet = resnet50(pretrained=False)  # weights come from a checkpoint, not monai
        # Re-stem for single-channel (grayscale MRI) input.
        resnet.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.backbone = resnet
        self.backbone.fc = nn.Identity()  # expose pooled features

    def forward(self, x):
        x = self.backbone(x)
        return x


class Classifier(nn.Module):
    """Classifier head: a single FC layer mapping features to `num_classes` outputs.

    Note: `hidden_dim` is accepted but unused (no hidden layer is built); the
    parameter is kept so existing call sites remain valid.
    """

    def __init__(self, d_model, hidden_dim=1024, num_classes=1):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = self.fc(x)
        return x


class Backbone(nn.Module):
    """ResNet 3D backbone.

    NOTE(review): structurally identical to ResNet50_3D; kept as a separate
    class name — presumably because saved checkpoints reference it. Confirm
    before merging the two.
    """

    def __init__(self):
        super(Backbone, self).__init__()

        resnet = resnet50(pretrained=False)  # weights come from a checkpoint, not monai
        resnet.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.backbone = resnet
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        x = self.backbone(x)
        return x


class SingleScanModel(nn.Module):
    """End-to-end model: backbone features -> dropout -> classifier head."""

    def __init__(self, backbone, classifier):
        super(SingleScanModel, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        x = self.backbone(x)
        x = self.dropout(x)
        x = self.classifier(x)
        return x


class SingleScanModelBP(nn.Module):
    """End-to-end model with backbone and classifier that takes 2 input scans at once.

    Note: `bilinear_pooling` is constructed but never used in forward(), which
    averages the per-scan backbone features instead. The layer is kept so the
    state_dict keys of existing checkpoints still match.
    """

    def __init__(self, backbone, classifier):
        super(SingleScanModelBP, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.dropout = nn.Dropout(p=0.2)
        self.bilinear_pooling = nn.Bilinear(in1_features=2048, in2_features=2048, out_features=512)

    def forward(self, x):
        # Run each scan (split along dim 1) through the shared backbone.
        x = [self.backbone(scan) for scan in x.split(1, dim=1)]
        features = torch.stack(x, dim=1).squeeze(2)
        merged_features = torch.mean(features, dim=1)  # Shape: (batch_size, feature_dim)
        merged_features = self.dropout(merged_features)
        output = self.classifier(merged_features)
        return output

# [diff residue — binary objects in the same patch, content not representable here]
# diff --git a/src/BrainIAC/preprocessing/HDBET_Code/.DS_Store (binary, index ..cde915ad)
# diff --git a/src/BrainIAC/preprocessing/HD_BET/__pycache__/config.cpython-39.pyc (binary, index ..3a995f59)
# diff --git a/src/BrainIAC/preprocessing/HD_BET/__pycache__/data_loading.cpython-39.pyc (binary, index ..90c38b57)
# diff --git a/src/BrainIAC/preprocessing/HD_BET/__pycache__/hd_bet.cpython-39.pyc (binary, index ..c0f828a3)
b/src/BrainIAC/preprocessing/HD_BET/__pycache__/hd_bet.cpython-39.pyc differ diff --git a/src/BrainIAC/preprocessing/HD_BET/__pycache__/network_architecture.cpython-39.pyc b/src/BrainIAC/preprocessing/HD_BET/__pycache__/network_architecture.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b844d1733cc804934064353aaa69edc523bab254 Binary files /dev/null and b/src/BrainIAC/preprocessing/HD_BET/__pycache__/network_architecture.cpython-39.pyc differ diff --git a/src/BrainIAC/preprocessing/HD_BET/__pycache__/paths.cpython-39.pyc b/src/BrainIAC/preprocessing/HD_BET/__pycache__/paths.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18dd04e17027dd282fc5ece3396f18c13eb6e85d Binary files /dev/null and b/src/BrainIAC/preprocessing/HD_BET/__pycache__/paths.cpython-39.pyc differ diff --git a/src/BrainIAC/preprocessing/HD_BET/__pycache__/predict_case.cpython-39.pyc b/src/BrainIAC/preprocessing/HD_BET/__pycache__/predict_case.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0594e090eac89fb59c9583820c6562556255637 Binary files /dev/null and b/src/BrainIAC/preprocessing/HD_BET/__pycache__/predict_case.cpython-39.pyc differ diff --git a/src/BrainIAC/preprocessing/HD_BET/__pycache__/run.cpython-39.pyc b/src/BrainIAC/preprocessing/HD_BET/__pycache__/run.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77abae02891cd3f7e1bd7673081f294f0db2054f Binary files /dev/null and b/src/BrainIAC/preprocessing/HD_BET/__pycache__/run.cpython-39.pyc differ diff --git a/src/BrainIAC/preprocessing/HD_BET/__pycache__/utils.cpython-39.pyc b/src/BrainIAC/preprocessing/HD_BET/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0e84846be112dcb4b7e3c7b2c6205656747a0b9 Binary files /dev/null and b/src/BrainIAC/preprocessing/HD_BET/__pycache__/utils.cpython-39.pyc differ diff --git 
a/src/BrainIAC/preprocessing/HD_BET/config.py b/src/BrainIAC/preprocessing/HD_BET/config.py new file mode 100644 index 0000000000000000000000000000000000000000..870951e5c9059fb9e20d6143e68266732f19234e --- /dev/null +++ b/src/BrainIAC/preprocessing/HD_BET/config.py @@ -0,0 +1,121 @@ +import numpy as np +import torch +from HD_BET.utils import SetNetworkToVal, softmax_helper +from abc import abstractmethod +from HD_BET.network_architecture import Network + + +class BaseConfig(object): + def __init__(self): + pass + + @abstractmethod + def get_split(self, fold, random_state=12345): + pass + + @abstractmethod + def get_network(self, mode="train"): + pass + + @abstractmethod + def get_basic_generators(self, fold): + pass + + @abstractmethod + def get_data_generators(self, fold): + pass + + def preprocess(self, data): + return data + + def __repr__(self): + res = "" + for v in vars(self): + if not v.startswith("__") and not v.startswith("_") and v != 'dataset': + res += (v + ": " + str(self.__getattribute__(v)) + "\n") + return res + + +class HD_BET_Config(BaseConfig): + def __init__(self): + super(HD_BET_Config, self).__init__() + + self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name + + # network parameters + self.net_base_num_layers = 21 + self.BATCH_SIZE = 2 + self.net_do_DS = True + self.net_dropout_p = 0.0 + self.net_use_inst_norm = True + self.net_conv_use_bias = True + self.net_norm_use_affine = True + self.net_leaky_relu_slope = 1e-1 + + # hyperparameters + self.INPUT_PATCH_SIZE = (128, 128, 128) + self.num_classes = 2 + self.selected_data_channels = range(1) + + # data augmentation + self.da_mirror_axes = (2, 3, 4) + + # validation + self.val_use_DO = False + self.val_use_train_mode = False # for dropout sampling + self.val_num_repeats = 1 # only useful if dropout sampling + self.val_batch_size = 1 # only useful if dropout sampling + self.val_save_npz = True + self.val_do_mirroring = True # test time data augmentation via mirroring + 
self.val_write_images = True + self.net_input_must_be_divisible_by = 16 # we could make a network class that has this as a property + self.val_min_size = self.INPUT_PATCH_SIZE + self.val_fn = None + + # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting keep_runnings_ + # stats=True but not using them in validation. keep_runnings_stats was True before 0.3 but unused and defaults + # to false in 0.4) + self.val_use_moving_averages = False + + def get_network(self, train=True, pretrained_weights=None): + net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers, + self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias, + self.net_norm_use_affine, True, self.net_do_DS) + + if pretrained_weights is not None: + net.load_state_dict( + torch.load(pretrained_weights, map_location=lambda storage, loc: storage)) + + if train: + net.train(True) + else: + net.train(False) + net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages)) + net.do_ds = False + + optimizer = None + self.lr_scheduler = None + return net, optimizer + + def get_data_generators(self, fold): + pass + + def get_split(self, fold, random_state=12345): + pass + + def get_basic_generators(self, fold): + pass + + def on_epoch_end(self, epoch): + pass + + def preprocess(self, data): + data = np.copy(data) + for c in range(data.shape[0]): + data[c] -= data[c].mean() + data[c] /= data[c].std() + return data + + +config = HD_BET_Config + diff --git a/src/BrainIAC/preprocessing/HD_BET/data_loading.py b/src/BrainIAC/preprocessing/HD_BET/data_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec4be63a8186b65bfb390770fefa6217b5dd2c5 --- /dev/null +++ b/src/BrainIAC/preprocessing/HD_BET/data_loading.py @@ -0,0 +1,121 @@ +import SimpleITK as sitk +import numpy as np +from skimage.transform import resize + + +def resize_image(image, old_spacing, new_spacing, order=3): + 
new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), + int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), + int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) + return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False) + + +def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)): + spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]] + image = sitk.GetArrayFromImage(itk_image).astype(float) + + assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed" + + if not is_seg: + if np.any([[i != j] for i, j in zip(spacing, spacing_target)]): + image = resize_image(image, spacing, spacing_target).astype(np.float32) + + image -= image.mean() + image /= image.std() + else: + new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))), + int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))), + int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2])))) + image = resize_segmentation(image, new_shape, 1) + return image + + +def load_and_preprocess(mri_file): + images = {} + # t1 + images["T1"] = sitk.ReadImage(mri_file) + + properties_dict = { + "spacing": images["T1"].GetSpacing(), + "direction": images["T1"].GetDirection(), + "size": images["T1"].GetSize(), + "origin": images["T1"].GetOrigin() + } + + for k in images.keys(): + images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5)) + + properties_dict['size_before_cropping'] = images["T1"].shape + + imgs = [] + for seq in ['T1']: + imgs.append(images[seq][None]) + all_data = np.vstack(imgs) + print("image shape after preprocessing: ", str(all_data[0].shape)) + return all_data, properties_dict + + +def save_segmentation_nifti(segmentation, dct, out_fname, order=1): + ''' + segmentation must have the same spacing as the original nifti (for now). 
segmentation may have been cropped out + of the original image + + dct: + size_before_cropping + brain_bbox + size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping + spacing + origin + direction + + :param segmentation: + :param dct: + :param out_fname: + :return: + ''' + old_size = dct.get('size_before_cropping') + bbox = dct.get('brain_bbox') + if bbox is not None: + seg_old_size = np.zeros(old_size) + for c in range(3): + bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c])) + seg_old_size[bbox[0][0]:bbox[0][1], + bbox[1][0]:bbox[1][1], + bbox[2][0]:bbox[2][1]] = segmentation + else: + seg_old_size = segmentation + if np.any(np.array(seg_old_size) != np.array(dct['size'])[[2, 1, 0]]): + seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order) + else: + seg_old_spacing = seg_old_size + seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32)) + seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]]) + seg_resized_itk.SetOrigin(dct['origin']) + seg_resized_itk.SetDirection(dct['direction']) + sitk.WriteImage(seg_resized_itk, out_fname) + + +def resize_segmentation(segmentation, new_shape, order=3, cval=0): + ''' + Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency + + Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one + hot encoding which is resized and transformed back to a segmentation map. 
+ This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2]) + :param segmentation: + :param new_shape: + :param order: + :return: + ''' + tpe = segmentation.dtype + unique_labels = np.unique(segmentation) + assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation" + if order == 0: + return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe) + else: + reshaped = np.zeros(new_shape, dtype=segmentation.dtype) + + for i, c in enumerate(unique_labels): + reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False) + reshaped[reshaped_multihot >= 0.5] = c + return reshaped diff --git a/src/BrainIAC/preprocessing/HD_BET/hd_bet.py b/src/BrainIAC/preprocessing/HD_BET/hd_bet.py new file mode 100644 index 0000000000000000000000000000000000000000..128575b6cfb4bdd98bf417ed598f905ef4896fd1 --- /dev/null +++ b/src/BrainIAC/preprocessing/HD_BET/hd_bet.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python + +import os +import sys +sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/") +from HD_BET.run import run_hd_bet +from HD_BET.utils import maybe_mkdir_p, subfiles +import HD_BET + +def hd_bet(input_file_or_dir,output_file_or_dir,mode,device,tta,pp=1,save_mask=0,overwrite_existing=1): + + if output_file_or_dir is None: + output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir), + os.path.basename(input_file_or_dir).split(".")[0] + "_bet") + + + params_file = os.path.join(HD_BET.__path__[0], "model_final.py") + config_file = os.path.join(HD_BET.__path__[0], "config.py") + + assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input" + + if device == 'cpu': + pass + else: + device = int(device) + + if os.path.isdir(input_file_or_dir): + maybe_mkdir_p(output_file_or_dir) + input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False) + + 
if len(input_files) == 0: + raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here") + + output_files = [os.path.join(output_file_or_dir, i) for i in input_files] + input_files = [os.path.join(input_file_or_dir, i) for i in input_files] + else: + if not output_file_or_dir.endswith('.nii.gz'): + output_file_or_dir += '.nii.gz' + assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input" + + output_files = [output_file_or_dir] + input_files = [input_file_or_dir] + + if tta == 0: + tta = False + elif tta == 1: + tta = True + else: + raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta)) + + if overwrite_existing == 0: + overwrite_existing = False + elif overwrite_existing == 1: + overwrite_existing = True + else: + raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing)) + + if pp == 0: + pp = False + elif pp == 1: + pp = True + else: + raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp)) + + if save_mask == 0: + save_mask = False + elif save_mask == 1: + save_mask = True + else: + raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp)) + + run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing) + + +if __name__ == "__main__": + print("\n########################") + print("If you are using hd-bet, please cite the following paper:") + print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W," + "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial" + "neural networks. arXiv preprint arXiv:1901.11341, 2019.") + print("########################\n") + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. 
If file: must be ' + 'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to ' + 'split 4d sequences into 3d images. If folder: all files ending with .nii.gz ' + 'within that folder will be brain extracted.', required=True, type=str) + parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder' + ' will be created', required=False, type=str) + parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will ' + 'use only one set of parameters whereas accurate will ' + 'use the five sets of parameters that resulted from ' + 'our cross-validation as an ensemble. Default: ' + 'accurate', + required=False) + parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. ' + 'Must be either int or str. Use int for GPU id or ' + '\'cpu\' to run on CPU. When using CPU you should ' + 'consider disabling tta. Default for -device is: 0', + required=False) + parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation ' + '(mirroring). 1= True, 0=False. Disable this ' + 'if you are using CPU to speed things up! ' + 'Default: 1') + parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disabe postprocessing (remove all' + ' but the largest connected component in ' + 'the prediction. 
Default: 1') + parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation ' + 'mask will not be ' + 'saved') + parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't " + "want to overwrite existing " + "predictions") + + args = parser.parse_args() + + hd_bet(args.input,args.output,args.mode,args.device,args.tta,args.pp,args.save_mask,args.overwrite_existing) + diff --git a/src/BrainIAC/preprocessing/HD_BET/network_architecture.py b/src/BrainIAC/preprocessing/HD_BET/network_architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..0824aa10839024368ad8ab38c637ce81aa9327e5 --- /dev/null +++ b/src/BrainIAC/preprocessing/HD_BET/network_architecture.py @@ -0,0 +1,213 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from HD_BET.utils import softmax_helper + + +class EncodingModule(nn.Module): + def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True): + nn.Module.__init__(self) + self.dropout_p = dropout_p + self.lrelu_inplace = lrelu_inplace + self.inst_norm_affine = inst_norm_affine + self.conv_bias = conv_bias + self.leakiness = leakiness + self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) + self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias) + self.dropout = nn.Dropout3d(dropout_p) + self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) + self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias) + + def forward(self, x): + skip = x + x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + x = self.conv1(x) + if self.dropout_p is not None and self.dropout_p > 0: + x = 
# [diff residue] tail of EncodingModule.forward(), begun on the previous dump line:
#           self.dropout(x)
#       x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
#       x = self.conv2(x)
#       x = x + skip
#       return x


class Upsample(nn.Module):
    """Thin module wrapper around ``nn.functional.interpolate``.

    Args:
        size: explicit output spatial size (mutually exclusive with scale_factor).
        scale_factor: multiplier for the spatial dimensions.
        mode: interpolation mode ('nearest', 'trilinear', ...).
        align_corners: only meaningful for linear-family modes; ignored otherwise.
    """

    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
        super(Upsample, self).__init__()
        self.align_corners = align_corners
        self.mode = mode
        self.scale_factor = scale_factor
        self.size = size

    def forward(self, x):
        # BUG FIX: F.interpolate raises ValueError when align_corners is passed
        # with non-linear modes, so the previous defaults (mode='nearest',
        # align_corners=True) crashed on the first call. Only forward
        # align_corners for the linear-family modes; the project's actual use
        # (mode='trilinear', align_corners=True) behaves exactly as before.
        align = self.align_corners if self.mode in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
                                         align_corners=align)


class LocalizationModule(nn.Module):
    """3x3x3 conv then 1x1x1 channel-reduction conv, each followed by InstanceNorm + LeakyReLU."""

    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True):
        nn.Module.__init__(self)
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
        self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
        self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)

    def forward(self, x):
        x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        return x


# [diff residue] head of UpsamplingModule, which continues on the next dump line:
#   class UpsamplingModule(nn.Module):
#       def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
#                    lrelu_inplace=True):
#           nn.Module.__init__(self)
#           self.lrelu_inplace = lrelu_inplace
#           self.inst_norm_affine = inst_norm_affine
#           self.conv_bias = conv_bias
#           self.leakiness = leakiness
#           self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
#           self.upsample_conv = nn.Conv3d(in_channels, out_channels,
class DownsamplingModule(nn.Module):
    """Pre-activation downsampling block.

    Applies InstanceNorm + LeakyReLU to the input, then a stride-2 3x3x3
    convolution. forward() returns BOTH tensors: the activated
    full-resolution features (used by Network as the skip connection) and
    the stride-2 downsampled map.
    """

    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True):
        nn.Module.__init__(self)
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        # named `bn` for historical reasons; this is instance norm (with running stats)
        self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
        # stride-2 conv halves every spatial dimension
        self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)

    def forward(self, x):
        # activated full-resolution features double as the encoder skip connection
        x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        b = self.downsample(x)
        return x, b


class Network(nn.Module):
    """3D U-Net-style segmentation network (HD-BET brain extraction model).

    Encoder: init conv + 5 EncodingModule stages with 4 downsamplings
    (base_filters -> 16*base_filters channels). Decoder: 4 UpsamplingModule +
    LocalizationModule stages consuming the encoder skips. Two intermediate
    heads (loc2_seg, loc3_seg) provide deep supervision.

    forward() returns a list of 3 segmentation maps ordered finest-first when
    do_ds=True, otherwise only the full-resolution map.

    NOTE(review): the constructor forwards hard-coded leakiness=1e-2,
    conv_bias=True, inst_norm_affine=True, lrelu_inplace=True to every
    submodule instead of the values received as arguments -- confirm whether
    those parameters were meant to propagate.
    """

    def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
                 final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True, do_ds=True):
        super(Network, self).__init__()

        self.do_ds = do_ds  # deep supervision on/off
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.final_nonlin = final_nonlin
        self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)

        # ---------------- encoder ----------------
        self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True,
                                        inst_norm_affine=True, lrelu_inplace=True)

        self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True,
                                        inst_norm_affine=True, lrelu_inplace=True)

        self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True,
                                        inst_norm_affine=True, lrelu_inplace=True)

        self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True,
                                        inst_norm_affine=True, lrelu_inplace=True)

        # ---------------- bottleneck ----------------
        self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2,
                                       conv_bias=True, inst_norm_affine=True, lrelu_inplace=True)

        self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)

        # ---------------- decoder ----------------
        self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
                                    inst_norm_affine=True, lrelu_inplace=True)

        # loc* modules take concat(skip, upsampled), i.e. doubled channel count
        self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
                                    inst_norm_affine=True, lrelu_inplace=True)

        self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        # deep-supervision head at 1/4 resolution
        self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
        self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
                                    inst_norm_affine=True, lrelu_inplace=True)

        self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
                                       inst_norm_affine=True, lrelu_inplace=True)
        # deep-supervision head at 1/2 resolution
        self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
        self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True,
                                    inst_norm_affine=True, lrelu_inplace=True)

        # final conv stack at full resolution (input is concat(skip1, up4) = 2*base_filters)
        self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
        self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
        self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
        self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
        self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)

    def forward(self, x):
        """Return deep-supervision outputs (finest first) or the final map only."""
        seg_outputs = []

        x = self.init_conv(x)
        x = self.context1(x)

        # each down* returns (skip connection, downsampled features)
        skip1, x = self.down1(x)
        x = self.context2(x)

        skip2, x = self.down2(x)
        x = self.context3(x)

        skip3, x = self.down3(x)
        x = self.context4(x)

        skip4, x = self.down4(x)
        x = self.context5(x)

        x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = self.up1(x)

        # decoder: concat skip along the channel axis, localize, upsample
        x = torch.cat((skip4, x), dim=1)
        x = self.loc1(x)
        x = self.up2(x)

        x = torch.cat((skip3, x), dim=1)
        x = self.loc2(x)
        loc2_seg = self.final_nonlin(self.loc2_seg(x))  # 1/4-res auxiliary output
        seg_outputs.append(loc2_seg)
        x = self.up3(x)

        x = torch.cat((skip2, x), dim=1)
        x = self.loc3(x)
        loc3_seg = self.final_nonlin(self.loc3_seg(x))  # 1/2-res auxiliary output
        seg_outputs.append(loc3_seg)
        x = self.up4(x)

        x = torch.cat((skip1, x), dim=1)
        x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = self.final_nonlin(self.seg_layer(x))  # full-resolution output
        seg_outputs.append(x)

        if self.do_ds:
            # reversed so the full-resolution map comes first
            return seg_outputs[::-1]
        else:
            return seg_outputs[-1]
def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
    """Pad a 3D volume (with zeros, at the upper coordinates) so that every
    spatial dimension is a multiple of `shape_must_be_divisible_by` and at
    least `min_size` (if given).

    Returns (padded volume, original shape).
    """
    if not isinstance(shape_must_be_divisible_by, (list, tuple)):
        shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3

    old_shape = patient.shape
    target_shape = []
    for dim, divisor in zip(old_shape, shape_must_be_divisible_by):
        remainder = dim % divisor
        # already divisible -> keep as is, otherwise round up to next multiple
        target_shape.append(dim if remainder == 0 else dim + divisor - remainder)

    if min_size is not None:
        # per-axis maximum of the divisibility target and the minimum size
        target_shape = np.max(np.vstack((np.array(target_shape), np.array(min_size))), 0)

    return reshape_by_padding_upper_coords(patient, target_shape, 0), old_shape


def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
    """Embed `image` at the origin of a (possibly) larger array.

    The result shape is the per-axis maximum of the old and requested shapes;
    new cells are filled with `pad_value` (defaults to the corner voxel).
    """
    old_shape = image.shape
    final_shape = tuple(np.max(np.concatenate((old_shape, new_shape)).reshape((2, len(old_shape))), axis=0))

    if pad_value is None:
        # default fill: the value at the image origin
        if image.ndim == 2:
            pad_value = image[0, 0]
        elif image.ndim == 3:
            pad_value = image[0, 0, 0]
        else:
            raise ValueError("Image must be either 2 or 3 dimensional")

    result = np.ones(list(final_shape), dtype=image.dtype) * pad_value
    result[tuple(slice(0, s) for s in old_shape)] = image
    return result
::-1, :] + if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes): + do_stuff = True + data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1] + + if do_stuff: + _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net))) + p = net(a) # np.copy is necessary because ::-1 creates just a view i think + p = p.data.cpu().numpy() + + if m == 0: + pass + if m == 1 and (4 in mirror_axes): + p = p[:, :, :, :, ::-1] + if m == 2 and (3 in mirror_axes): + p = p[:, :, :, ::-1, :] + if m == 3 and (4 in mirror_axes) and (3 in mirror_axes): + p = p[:, :, :, ::-1, ::-1] + if m == 4 and (2 in mirror_axes): + p = p[:, :, ::-1, :, :] + if m == 5 and (2 in mirror_axes) and (4 in mirror_axes): + p = p[:, :, ::-1, :, ::-1] + if m == 6 and (2 in mirror_axes) and (3 in mirror_axes): + p = p[:, :, ::-1, ::-1, :] + if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes): + p = p[:, :, ::-1, ::-1, ::-1] + all_preds.append(p) + + stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]] + predicted_segmentation = stacked.mean(0).argmax(0) + uncertainty = stacked.var(0) + bayesian_predictions = stacked + softmax_pred = stacked.mean(0) + return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty diff --git a/src/BrainIAC/preprocessing/HD_BET/run.py b/src/BrainIAC/preprocessing/HD_BET/run.py new file mode 100644 index 0000000000000000000000000000000000000000..858934d8f67175df508884e9030f8d38ba0d07cf --- /dev/null +++ b/src/BrainIAC/preprocessing/HD_BET/run.py @@ -0,0 +1,117 @@ +import torch +import numpy as np +import SimpleITK as sitk +from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti +from HD_BET.predict_case import predict_case_3D_net +import imp +from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters +import os +import HD_BET + + +def apply_bet(img, bet, out_fname): + img_itk = sitk.ReadImage(img) + img_npy = 
def apply_bet(img, bet, out_fname):
    """Apply a binary brain mask to an image and write the skull-stripped result.

    :param img: path of the image to mask
    :param bet: path of the brain-mask image (zero = background)
    :param out_fname: output path for the masked image
    """
    img_itk = sitk.ReadImage(img)
    img_npy = sitk.GetArrayFromImage(img_itk)
    img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
    img_npy[img_bet == 0] = 0  # zero out everything outside the brain mask
    out = sitk.GetImageFromArray(img_npy)
    out.CopyInformation(img_itk)  # preserve origin/spacing/direction of the input
    sitk.WriteImage(out, out_fname)


def _load_config_module(config_file):
    """Import a config .py file by path and return the module object.

    FIX: replaces the deprecated `imp.load_source` (the `imp` module was
    removed in Python 3.12) with the importlib.util equivalent.
    """
    import importlib.util
    spec = importlib.util.spec_from_file_location('cf', config_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
               postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
    """
    Run HD-BET brain extraction on one or more MRI files.

    :param mri_fnames: str or list/tuple of str
    :param output_fnames: str or list/tuple of str. If list: must have the same length as mri_fnames
    :param mode: fast or accurate ('fast' uses a single CV fold, 'accurate' ensembles all 5)
    :param config_file: path to the HD-BET config.py
    :param device: either int (for device id) or 'cpu'
    :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
    but the largest predicted connected component. Default False
    :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
    CPU you may want to turn that off to speed things up
    :param keep_mask: keep the "<out>_mask.nii.gz" file next to each output
    :param overwrite: re-run even if outputs already exist
    :return:
    """
    list_of_param_files = []

    # collect (and, if necessary, download) the fold checkpoints
    if mode == 'fast':
        params_file = get_params_fname(0)
        maybe_download_parameters(0)
        list_of_param_files.append(params_file)
    elif mode == 'accurate':
        for i in range(5):
            params_file = get_params_fname(i)
            maybe_download_parameters(i)
            list_of_param_files.append(params_file)
    else:
        raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)

    assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"

    cf = _load_config_module(config_file).config()

    net, _ = cf.get_network(cf.val_use_train_mode, None)
    if device == "cpu":
        net = net.cpu()
    else:
        net.cuda(device)

    if not isinstance(mri_fnames, (list, tuple)):
        mri_fnames = [mri_fnames]

    if not isinstance(output_fnames, (list, tuple)):
        output_fnames = [output_fnames]

    assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"

    # load all fold checkpoints up front (CPU mapping; moved to device via load_state_dict)
    params = []
    for p in list_of_param_files:
        params.append(torch.load(p, map_location=lambda storage, loc: storage))

    for in_fname, out_fname in zip(mri_fnames, output_fnames):
        mask_fname = out_fname[:-7] + "_mask.nii.gz"  # strip the ".nii.gz" suffix
        if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
            print("File:", in_fname)
            print("preprocessing...")
            try:
                data, data_dict = load_and_preprocess(in_fname)
            except RuntimeError:
                print("\nERROR\nCould not read file", in_fname, "\n")
                continue
            except AssertionError as e:
                print(e)
                continue

            softmax_preds = []

            print("prediction (CNN id)...")
            for i, p in enumerate(params):
                print(i)
                net.load_state_dict(p)
                net.eval()
                net.apply(SetNetworkToVal(False, False))
                _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
                                                            cf.val_batch_size, cf.net_input_must_be_divisible_by,
                                                            cf.val_min_size, device, cf.da_mirror_axes)
                softmax_preds.append(softmax_pred[None])

            # average the softmax over folds, then take the argmax as the mask
            seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)

            if postprocess:
                seg = postprocess_prediction(seg)

            print("exporting segmentation...")
            save_segmentation_nifti(seg, data_dict, mask_fname)

            apply_bet(in_fname, mask_fname, out_fname)

            if not keep_mask:
                os.remove(mask_fname)
def get_params_fname(fold):
    """Return the local path of the HD-BET parameter file for a given fold."""
    return os.path.join(folder_with_parameter_files, "%d.model" % fold)


def maybe_download_parameters(fold=0, force_overwrite=False):
    """
    Downloads the parameters for some fold if it is not present yet.
    :param fold: CV fold index (0..4)
    :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
    :return:
    """
    assert 0 <= fold <= 4, "fold must be between 0 and 4"

    if not os.path.isdir(folder_with_parameter_files):
        maybe_mkdir_p(folder_with_parameter_files)

    out_filename = get_params_fname(fold)

    if force_overwrite and os.path.isfile(out_filename):
        os.remove(out_filename)

    if not os.path.isfile(out_filename):
        url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
        print("Downloading", url, "...")
        data = urlopen(url).read()
        with open(out_filename, 'wb') as f:
            f.write(data)


def init_weights(module):
    """Kaiming-initialize Conv3d weights (LeakyReLU slope 1e-2) and zero biases.

    FIX: the original called the deprecated nn.init.kaiming_normal /
    nn.init.constant (removed in modern PyTorch) and re-assigned their return
    values to module.weight / module.bias, replacing the nn.Parameter objects
    with plain tensors. The trailing-underscore variants initialize in place.
    """
    if isinstance(module, nn.Conv3d):
        nn.init.kaiming_normal_(module.weight, a=1e-2)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)


def softmax_helper(x):
    """Numerically stable softmax over dim 1 (the channel dimension)."""
    rpt = [1 for _ in range(len(x.size()))]
    rpt[1] = x.size(1)
    # subtract the per-sample channel max before exponentiating for stability
    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)


class SetNetworkToVal(object):
    """Callable for net.apply(): set dropout and norm layers to the desired
    train/eval mode (e.g. MC-dropout sampling, or batch statistics).
    """

    def __init__(self, use_dropout_sampling=False, norm_use_average=True):
        self.norm_use_average = norm_use_average
        self.use_dropout_sampling = use_dropout_sampling

    def __call__(self, module):
        if isinstance(module, (nn.Dropout3d, nn.Dropout2d, nn.Dropout)):
            # train(True) keeps dropout active -> stochastic (MC) sampling
            module.train(self.use_dropout_sampling)
        elif isinstance(module, (nn.InstanceNorm3d, nn.InstanceNorm2d, nn.InstanceNorm1d,
                                 nn.BatchNorm2d, nn.BatchNorm3d, nn.BatchNorm1d)):
            # train(False) makes norm layers use their running averages
            module.train(not self.norm_use_average)


def postprocess_prediction(seg):
    """Keep only the largest connected component of the predicted mask;
    everything else is set to 0. Modifies and returns `seg`.
    """
    print("running postprocessing... ")
    mask = seg != 0
    lbls = label(mask, connectivity=mask.ndim)
    lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
    largest_region = np.argmax(lbls_sizes[1:]) + 1  # skip background label 0
    seg[lbls != largest_region] = 0
    return seg


def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
    """List immediate subdirectories of `folder`, optionally filtered by
    prefix/suffix; `join=False` returns bare names instead of joined paths."""
    if join:
        l = os.path.join
    else:
        l = lambda x, y: y
    res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
           and (prefix is None or i.startswith(prefix))
           and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res


def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
    """List immediate files of `folder`, optionally filtered by prefix/suffix;
    `join=False` returns bare names instead of joined paths."""
    if join:
        l = os.path.join
    else:
        l = lambda x, y: y
    res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
           and (prefix is None or i.startswith(prefix))
           and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res


subfolders = subdirs  # I am tired of confusing those


def maybe_mkdir_p(directory):
    """Create `directory` (and any missing parents); no-op if it exists.

    FIX: the original split on "/" and dropped the first component, which
    silently created the wrong directories for relative paths and produced
    relative paths for absolute inputs; it was also race-prone. os.makedirs
    with exist_ok=True is the portable equivalent.
    """
    os.makedirs(directory, exist_ok=True)
def convert_dicom_series_to_nifti(dicom_dir, output_file):
    """
    Convert a single DICOM series to NIFTI format
    Args:
        dicom_dir: Directory containing DICOM files for one scan
        output_file: Output NIFTI file path
    Returns:
        bool: True if conversion successful, False otherwise
    """
    try:
        reader = sitk.ImageSeriesReader()

        # get all the scans in the dir
        # NOTE(review): only files with a ".dcm" extension are picked up;
        # DICOM series stored as extensionless files would be missed --
        # confirm the expected input layout (sitk's GetGDCMSeriesFileNames
        # would also handle those).
        dicom_files = sorted(glob.glob(os.path.join(dicom_dir, "*.dcm")))

        if not dicom_files:
            print(f"No DICOM files found in: {dicom_dir}")
            return False

        reader.SetFileNames(dicom_files)

        # load dicom images
        image = reader.Execute()

        sitk.WriteImage(image, output_file)
        return True

    except Exception as e:
        # best-effort batch conversion: report and let the caller count failures
        print(f"Error converting {dicom_dir}: {str(e)}")
        return False

def convert_dicom_to_nifti(input_dir, output_dir):
    """
    Convert multiple DICOM series to NIFTI format
    Args:
        input_dir: Root directory containing subdirectories of DICOM series
        output_dir: Output directory for NIFTI files
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    input_dir = os.path.abspath(input_dir)
    output_dir = os.path.abspath(output_dir)

    print(f"Looking for DICOM series in: {input_dir}")

    # Check if directory exists
    if not os.path.isdir(input_dir):
        print(f"Error: {input_dir} is not a directory")
        return

    # one subdirectory per scan/series
    scan_dirs = [d for d in os.listdir(input_dir)
                 if os.path.isdir(os.path.join(input_dir, d))]

    if not scan_dirs:
        print("No subdirectories found in the input directory")
        return

    print(f"Found {len(scan_dirs)} potential scan directories")

    # Process each scan directory
    successful = 0
    failed = 0

    for scan_dir in tqdm(scan_dirs, desc="Converting scans"):
        input_path = os.path.join(input_dir, scan_dir)
        # output file is named after the scan directory
        output_file = os.path.join(output_dir, f"{scan_dir}.nii.gz")

        if convert_dicom_series_to_nifti(input_path, output_file):
            successful += 1
        else:
            failed += 1

    print("\nConversion Summary:")
    print(f"Successfully converted: {successful} scans")
    print(f"Failed conversions: {failed} scans")
    print(f"Output directory: {output_dir}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert DICOM series to NIFTI format")
    parser.add_argument("--input", "-i", required=True,
                        help="Input directory containing subdirectories of DICOM series")
    parser.add_argument("--output", "-o", required=True,
                        help="Output directory for NIFTI files")

    args = parser.parse_args()

    convert_dicom_to_nifti(args.input, args.output)
def registration(input_dir, output_dir, temp_img, interp_type='linear'):
    """
    MRI registration with SimpleITK.

    Rigidly (Euler3D) registers every *.nii.gz in `input_dir` to the template,
    after N4 bias-field correction, and writes each result to `output_dir`
    with an `_0000` suffix (the filename convention HD-BET expects).

    Args:
        input_dir {path} -- Directory containing input images
        output_dir {path} -- Directory to save registered images
        temp_img {str} -- Registration image template
        interp_type {str} -- 'linear', 'bspline' or 'nearest_neighbor'
    Returns:
        bool -- True if at least one image was registered successfully
    """
    # Read the template image
    fixed_img = sitk.ReadImage(temp_img, sitk.sitkFloat32)

    # Resolve the interpolator once up front.
    # FIX: the original overwrote the `interp_type` argument with a SimpleITK
    # constant inside the per-image loop, so the string comparison only worked
    # on the first iteration by accident.
    interpolators = {
        'linear': sitk.sitkLinear,
        'bspline': sitk.sitkBSpline,
        'nearest_neighbor': sitk.sitkNearestNeighbor,
    }
    interpolator = interpolators.get(interp_type, sitk.sitkLinear)

    # Resample the template to 1 mm isotropic ONCE.
    # FIX: the original re-assigned and re-resampled `fixed_img` on every loop
    # iteration; after the first pass this was a near-identity operation that
    # only wasted time.
    old_size = fixed_img.GetSize()
    old_spacing = fixed_img.GetSpacing()
    new_spacing = (1, 1, 1)
    new_size = [
        int(round((old_size[d] * old_spacing[d]) / float(new_spacing[d])))
        for d in range(3)
    ]
    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(new_spacing)
    resample.SetSize(new_size)
    resample.SetOutputOrigin(fixed_img.GetOrigin())
    resample.SetOutputDirection(fixed_img.GetDirection())
    resample.SetInterpolator(interpolator)
    # NOTE(review): the original used GetPixelIDValue() (a pixel *type* id)
    # as the background fill value; kept for bit-compatibility -- confirm
    # whether 0 was intended.
    resample.SetDefaultPixelValue(fixed_img.GetPixelIDValue())
    resample.SetOutputPixelType(sitk.sitkFloat32)
    fixed_img = resample.Execute(fixed_img)

    # Preload pass: record IDs of files SimpleITK cannot read so the main
    # loop can skip them.
    IDs = []
    print("Preloading step...")
    for img_dir in tqdm(sorted(glob.glob(input_dir + '/*.nii.gz'))):
        ID = img_dir.split('/')[-1].split('.')[0]
        try:
            sitk.ReadImage(img_dir, sitk.sitkFloat32)
        except Exception as e:
            IDs.append(ID)
            print(f"Error loading {ID}: {e}")

    count = 0
    print("Registering images...")
    list_of_files = sorted(glob.glob(input_dir + '/*.nii.gz'))

    for img_dir in tqdm(list_of_files):
        ID = img_dir.split('/')[-1].split('.')[0]
        if ID in IDs:
            print(f'Skipping problematic file: {ID}')
            continue

        if "_mask" in ID:  # HD-BET masks from a previous run are not inputs
            continue

        print(f"Processing image {count + 1}: {ID}")

        try:
            # Read and bias-correct the moving image
            moving_img = sitk.ReadImage(img_dir, sitk.sitkFloat32)
            moving_img = sitk.N4BiasFieldCorrection(moving_img)

            # Initialize a rigid transform from the image geometry
            transform = sitk.CenteredTransformInitializer(
                fixed_img,
                moving_img,
                sitk.Euler3DTransform(),
                sitk.CenteredTransformInitializerFilter.GEOMETRY)

            # Multi-resolution Mattes MI + gradient descent registration
            registration_method = sitk.ImageRegistrationMethod()
            registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
            registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
            registration_method.SetMetricSamplingPercentage(0.01)
            registration_method.SetInterpolator(sitk.sitkLinear)
            registration_method.SetOptimizerAsGradientDescent(
                learningRate=1.0,
                numberOfIterations=100,
                convergenceMinimumValue=1e-6,
                convergenceWindowSize=10)
            registration_method.SetOptimizerScalesFromPhysicalShift()
            registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
            registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
            registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
            registration_method.SetInitialTransform(transform)

            # Execute registration
            final_transform = registration_method.Execute(fixed_img, moving_img)

            # Apply the transform, resampling onto the template grid
            moving_img_resampled = sitk.Resample(
                moving_img,
                fixed_img,
                final_transform,
                sitk.sitkLinear,
                0.0,
                moving_img.GetPixelID())

            # Save with _0000 suffix as required by HD-BET
            output_filename = os.path.join(output_dir, f"{ID}_0000.nii.gz")
            sitk.WriteImage(moving_img_resampled, output_filename)
            print(f"Saved registered image to: {output_filename}")
            count += 1

        except Exception as e:
            # best-effort batch processing: log and move on
            print(f"Error processing {ID}: {e}")
            continue

    print(f"Successfully registered {count} images.")
    # Debug information
    print(f"Contents of output directory {output_dir}:")
    print(os.listdir(output_dir))
    return count > 0
def main(temp_img, input_dir, output_dir):
    """
    Main function to process brain MRI images: register everything to the
    template, then skull-strip with HD-BET, using a temporary directory for
    the intermediate registered images.
    Args:
        temp_img {str} -- Path to template image
        input_dir {str} -- Path to input directory containing images
        output_dir {str} -- Path to output directory for results
    """

    os.makedirs(output_dir, exist_ok=True)

    # set device ("0" = first CUDA device for HD-BET, else CPU)
    device = "0" if torch.cuda.is_available() else "cpu"

    # Create temporary directory for intermediate results
    temp_reg_dir = os.path.join(output_dir, 'temp_registered')
    os.makedirs(temp_reg_dir, exist_ok=True)

    print("Starting brain MRI preprocessing...")

    # REgistration
    print("\nStep 1: Image Registration")
    success = registration(
        input_dir=input_dir,
        output_dir=temp_reg_dir,
        temp_img=temp_img
    )

    if not success:
        # nothing registered -> nothing to skull-strip
        print("Registration failed! No images were processed successfully.")
        return

    print("\nChecking temporary directory contents:")
    print(os.listdir(temp_reg_dir))

    # skullstripping
    print("\nStep 2: Brain Extraction")
    brain_extraction(
        input_dir=temp_reg_dir,
        output_dir=output_dir,
        device=device
    )

    # Clean up temporary directory
    import shutil
    shutil.rmtree(temp_reg_dir)

    print("\nPreprocessing complete! Final results saved in:", output_dir)
    print("Final preprocessed files:")
    print(os.listdir(output_dir))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Process brain MRI registration and skull stripping.")
    parser.add_argument("--temp_img", type=str, required=True, help="Path to the atlas template image.")
    parser.add_argument("--input_dir", type=str, required=True, help="Path to the input images directory.")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save the processed images.")

    args = parser.parse_args()
    main(temp_img=args.temp_img, input_dir=args.input_dir, output_dir=args.output_dir)
# another helper function to get the id and path of the image and the mask, when the folder structure is different
def get_id_and_path_not_nested(row, image_dir, masks_dir):
    """Resolve (patient_id, image_path, tm_file) for the flat folder layout.

    Returns (0, 0, 0) when the row is marked as badly registered or no mask
    folder is found.

    FIX: the failure branches used to return a 4-tuple (0, 0, 0, 0) while the
    success path returns 3 values, which broke 3-variable unpacking at call
    sites; both now consistently return 3 values.
    """
    patient_id, image_path, tm_file = 0, 0, 0
    if row['Ok registered? Y/N'] == "N":
        print("skip - bad registration")
        return 0, 0, 0
    # NDAR-style filenames encode the patient id before the first underscore
    if "NDAR" in row['Filename']:
        patient_id = row['Filename'].split("_")[0]
    else:
        patient_id = row['Filename'].split(".")[0]

    path = find_file_in_path(patient_id, os.listdir(masks_dir))
    if len(path) < 3:
        return 0, 0, 0
    scan_folder_masks = masks_dir + path

    for file in os.listdir(scan_folder_masks):
        if "._" in file:  # skip hidden files
            continue
        if "TM" in file:
            tm_file = masks_dir + path + "/" + file
        elif ".nii" in file and "TM" not in file:
            image_path = image_dir + patient_id + ".nii"

    return patient_id, image_path, tm_file

# crop the image to the bounding box of the mask
def crop_center(img, cropx, cropy):
    """Return the central (cropy x cropx) window of a 2D array."""
    y, x = img.shape
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    return img[starty:starty + cropy, startx:startx + cropx]

# find the file in the path
def find_file_in_path(name, path):
    """Return the first entry of `path` that contains `name`, skipping macOS
    "._" resource files; returns "" when there is no usable match.

    FIX: previously the function fell off the end (implicitly returning None)
    when every match was a hidden file; it now consistently returns "".
    """
    result = list(filter(lambda x: name in x, path))
    for file in result:
        if "._" in file:  # skip hidden files
            continue
        return file
    return ""

# perform the bias field correction
def bias_field_correction(img):
    """N4 bias-field correction of a numpy volume; returns the corrected array."""
    image = sitk.GetImageFromArray(img)
    maskImage = sitk.OtsuThreshold(image, 0, 1, 200)
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    numberFittingLevels = 4

    corrector.SetMaximumNumberOfIterations([100] * numberFittingLevels)
    corrector.Execute(image, maskImage)
    # apply the estimated field at full resolution rather than using the
    # corrector's (possibly downsampled) direct output
    log_bias_field = corrector.GetLogBiasFieldAsImage(image)
    corrected_image_full_resolution = image / sitk.Exp(log_bias_field)
    return sitk.GetArrayFromImage(corrected_image_full_resolution)

def load_nii(path):
    """Load a NIfTI file; returns (data array, affine)."""
    nii = nib.load(path)
    return nii.get_fdata(), nii.affine

def save_nii(data, path, affine):
    """Save `data` to `path` as NIfTI with the given affine."""
    nib.save(nib.Nifti1Image(data, affine), path)

def denoise(volume, kernel_size=3):
    """Median-filter denoising of a volume."""
    return medfilt(volume, kernel_size)
400): + range_bottom = 149 #win_centre - win_width / 2 + scale = 256 / 256 #win_width + image = image - range_bottom + + image = image * scale + image[image < 0] = 0 + image[image > 255] = 255 + return image + +# rescale the intensity of the image and binning +def rescale_intensity(volume, percentils=[0.5, 99.5], bins_num=256): + #remove background pixels by the otsu filtering + t = skimage.filters.threshold_otsu(volume,nbins=6) + volume[volume < t] = 0 + + obj_volume = volume[np.where(volume > 0)] + min_value = np.percentile(obj_volume, percentils[0]) + max_value = np.percentile(obj_volume, percentils[1]) + if bins_num == 0: + obj_volume = (obj_volume - min_value) / (max_value - min_value).astype(np.float32) + else: + obj_volume = np.round((obj_volume - min_value) / (max_value - min_value) * (bins_num - 1)) + obj_volume[np.where(obj_volume < 1)] = 1 + obj_volume[np.where(obj_volume > (bins_num - 1))] = bins_num - 1 + + volume = volume.astype(obj_volume.dtype) + volume[np.where(volume > 0)] = obj_volume + return volume + +# equalize the histogram of the image +def equalize_hist(volume, bins_num=256): + obj_volume = volume[np.where(volume > 0)] + hist, bins = np.histogram(obj_volume, bins_num) + cdf = hist.cumsum() + cdf = (bins_num - 1) * cdf / cdf[-1] + + obj_volume = np.round(np.interp(obj_volume, bins[:-1], cdf)).astype(obj_volume.dtype) + volume[np.where(volume > 0)] = obj_volume + return volume + +# enhance the image +def enhance(volume, kernel_size=3, + percentils=[0.5, 99.5], bins_num=256, eh=True): + try: + volume = bias_field_correction(volume) + volume = denoise(volume, kernel_size) + volume = rescale_intensity(volume, percentils, bins_num) + if eh: + volume = equalize_hist(volume, bins_num) + return volume + except RuntimeError: + logging.warning('Failed enchancing') + +# enhance the image without bias field correction +def enhance_noN4(volume, kernel_size=3, + percentils=[0.5, 99.5], bins_num=256, eh=True): + try: + #volume = 
bias_field_correction(volume) + volume = denoise(volume, kernel_size) + #print(np.shape(volume)) + volume = rescale_intensity(volume, percentils, bins_num) + #print(np.shape(volume)) + if eh: + volume = equalize_hist(volume, bins_num) + return volume + except RuntimeError: + logging.warning('Failed enchancing') + +# get the resampled image +def get_resampled_sitk(data_sitk,target_spacing): + new_spacing = target_spacing + + orig_spacing = data_sitk.GetSpacing() + orig_size = data_sitk.GetSize() + + new_size = [int(orig_size[0] * orig_spacing[0] / new_spacing[0]), + int(orig_size[1] * orig_spacing[1] / new_spacing[1]), + int(orig_size[2] * orig_spacing[2] / new_spacing[2])] + + res_filter = sitk.ResampleImageFilter() + img_sitk = res_filter.Execute(data_sitk, + new_size, + sitk.Transform(), + sitk.sitkLinear, + data_sitk.GetOrigin(), + new_spacing, + data_sitk.GetDirection(), + 0, + data_sitk.GetPixelIDValue()) + + return img_sitk + +# convert the nrrd file to nifty file +def nrrd_to_nifty(nrrd_file): + _nrrd = nrrd.read(nrrd_file) + data_f = _nrrd[0] + header = _nrrd[1] + return np.asarray(data_f), header + +# crop the brain from the image +def crop_brain(var_img, mni_img): + # invert brain mask + inverted_mask = np.invert(mni_img.astype(bool)).astype(float) + mask_data = inverted_mask * var_img + return mask_data + +# normalize the image with the brain mask +def brain_norm_masked(mask_data, brain_data, to_save=False): + masked = crop_brain(brain_data, mask_data) + enhanced = enhance(masked) + return enhanced + +# enhance all the images in the path +def enhance_and_debias_all_in_path(image_dir='data/mni_templates_BK/',path_to='data/denoised_mris/',\ + input_annotation_file = 'data/all_metadata.csv'): + + df = pd.read_csv(input_annotation_file,header=0) + df=df[df['Ok registered? 
Y/N']=='Y'].reset_index() + #print(df.shape[0]) + for idx in range(0, 1): + print(idx) + row = df.iloc[idx] + patient_id, image_path, tm_file, _ = get_id_and_path(row, image_dir) + print(patient_id, image_path, tm_file) + image_sitk = sitk.ReadImage(image_path) + image_array = sitk.GetArrayFromImage(image_sitk) + image_array = enhance(image_array) + image3 = sitk.GetImageFromArray(image_array) + sitk.WriteImage(image3,path_to+patient_id+'.nii') + return + +# Z-enhance all the images in the path +def z_enhance_and_debias_all_in_path(image_dir='data/mni_templates_BK/',path_to='data/z_scored_mris/',\ + input_annotation_file = 'data/all_metadata.csv', for_training=True, annotations=True): + df = pd.read_csv(input_annotation_file,header=0) + + if for_training: + df=df[df['Ok registered? Y/N']=='Y'].reset_index() + print(df.shape[0]) + + for idx in range(0, df.shape[0]): + print(idx) + row = df.iloc[idx] + patient_id, image_path, tm_file, _ = get_id_and_path(row, image_dir, nested=False, no_tms=for_training) + print(patient_id, len(image_path), tm_file, path_to) + if not os.path.isdir(path_to+"no_z"): + os.mkdir(path_to+"no_z") + if not os.path.isdir(path_to+"z"): + os.mkdir(path_to+"z") + + if len(image_path)>3: + image_sitk = sitk.ReadImage(image_path) + image_array = sitk.GetArrayFromImage(image_sitk) + print(len(image_array)) + try: + image_array = enhance_noN4(image_array) + image3 = sitk.GetImageFromArray(image_array) + sitk.WriteImage(image3,path_to+"no_z/"+patient_id+'.nii') + os.mkdir(path_to+"z/"+patient_id) + if annotations: + shutil.copyfile(tm_file, path_to+"z/"+patient_id+"/TM.nii.gz") + duck_line = "zscore-normalize "+path_to+"no_z/"+patient_id+".nii -o "+path_to+"z/"+patient_id +"/"+patient_id+'.nii' + subprocess.getoutput(duck_line) + except: + continue + +# find the closest value in the list +def closest_value(input_list, input_value): + arr = np.asarray(input_list) + i = (np.abs(arr - input_value)).argmin() + return arr[i], i + +# find the centile of 
the input value +def find_centile(input_tmt, age, df): + #print("TMT:",input_tmt,"Age:", age) + val,i=closest_value(df['x'],age) + + centile = 'out of range' + if input_tmt
+
+
+
+
+ Predicted Brain Age: {{ prediction }}
+Processing... Please wait.
+