Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import shutil | |
| import string | |
| import zipfile | |
| from functools import partial | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import nibabel as nib | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from tqdm import tqdm as std_tqdm | |
| tqdm = partial(std_tqdm, dynamic_ncols=True) | |
| # Import required modules from our project | |
| from utils.cropping import cropping | |
| from utils.hemisphere import hemisphere | |
| from utils.load_model import load_model | |
| from utils.make_csv import make_csv | |
| from utils.make_level import create_parcellated_images | |
| from utils.parcellation import parcellation | |
| from utils.postprocessing import postprocessing | |
| from utils.preprocessing import preprocessing | |
| from utils.stripping import stripping | |
def _load_canonical(path):
    """Load a NIfTI file in closest-canonical orientation with trailing length-1 axes dropped."""
    return nib.squeeze_image(nib.as_closest_canonical(nib.load(path)))


def _save_slice_png(image, path, title, cmap=None):
    """Save one 2D slice as a titled, axis-free PNG at *path* and release its figure."""
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
        fig.savefig(path, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # pyplot keeps every figure alive until closed; without this a
        # long-running server accumulates one figure per rendered image.
        plt.close(fig)


def nii_to_image(voxel_path, label_path, output_dir, basename):
    """
    Convert two NIfTI files into PNG previews for visualization.

    Both volumes are reoriented to the closest canonical orientation and have
    trailing length-1 axes removed. The middle slice along the first axis is
    taken from each and rotated by 90 degrees for display. The input MRI is
    drawn in grayscale; the label volume uses matplotlib's default colormap.

    Args:
        voxel_path: Path to the input MRI NIfTI file.
        label_path: Path to the NIfTI file shown as "Parcellation Result"
            (a label map, or an intermediate volume in the early-exit modes).
        output_dir: Per-subject output directory. The PNGs are written to its
            parent directory, not inside it.
        basename: File-name prefix for the two PNGs.

    Returns:
        Tuple ``(input_png_path, parcellation_png_path)``.
    """
    png_dir = os.path.dirname(output_dir)

    voxel = _load_canonical(voxel_path).get_fdata().astype("float32")
    label = _load_canonical(label_path).get_fdata().astype("int16")

    # Each volume uses its own middle index. When the shapes match this is the
    # same slice as before; when they differ (e.g. a cropped intermediate
    # volume) it avoids indexing past the end of the smaller one.
    slice_voxel = np.rot90(voxel[voxel.shape[0] // 2, :, :])
    slice_label = np.rot90(label[label.shape[0] // 2, :, :])

    input_png_path = os.path.join(png_dir, f"{basename}_input.png")
    _save_slice_png(slice_voxel, input_png_path, "Input Image", cmap="gray")

    parcellation_png_path = os.path.join(png_dir, f"{basename}_parcellation.png")
    _save_slice_png(slice_label, parcellation_png_path, "Parcellation Result")

    return input_png_path, parcellation_png_path
def _select_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _nifti_basename(path):
    """Return the file name of *path* with a ``.nii`` or ``.nii.gz`` extension removed."""
    basename = os.path.splitext(os.path.basename(path))[0]
    if basename.endswith(".nii"):
        basename = os.path.splitext(basename)[0]
    return basename


def _zip_directory(src_dir, zip_path):
    """Write every file under *src_dir* into a deflate-compressed ZIP, paths relative to *src_dir*."""
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
        for root, _, files in os.walk(src_dir):
            for file in files:
                file_path = os.path.join(root, file)
                zipf.write(file_path, os.path.relpath(file_path, start=src_dir))


def _load_png(path):
    """Read an image fully into memory so no file handle remains open afterwards."""
    with Image.open(path) as img:
        return img.copy()


def _parcellate(stripped, shift, odata, data, output_dir, basename, pnet, hnet, device):
    """Run parcellation, hemisphere separation, postprocessing and exports; return the label NIfTI path."""
    # Explicit import: `import nibabel` does not load the `processing` submodule,
    # so `nib.processing` only worked if another module had imported it first.
    from nibabel.processing import conform

    # 5. Parcellation
    print("Starting parcellation...")
    parcellated = parcellation(stripped, pnet, device)
    print("Parcellation completed.")
    # 6. Separate into hemispheres
    print("Separating hemispheres...")
    separated = hemisphere(stripped, hnet, device)
    print("Hemispheres separated.")
    # 7. Postprocessing
    print("Postprocessing the parcellated data...")
    output = postprocessing(parcellated, separated, shift, device)
    print("Postprocessing completed.")
    # 8. Create CSV with volume information
    print("Creating CSV with volume information...")
    make_csv(output, output_dir, basename)
    print("CSV created successfully.")

    # 9. Conform the uint16 label map to the shape and voxel size recorded in
    # odata's header; order=0 (nearest neighbour) keeps label values intact.
    nii_out = nib.Nifti1Image(output.astype(np.uint16), affine=data.affine)
    header = odata.header
    nii_out = conform(
        nii_out,
        out_shape=(header["dim"][1], header["dim"][2], header["dim"][3]),
        voxel_size=(header["pixdim"][1], header["pixdim"][2], header["pixdim"][3]),
        order=0,
    )
    out_parcellated_dir = os.path.join(output_dir, "parcellated")
    os.makedirs(out_parcellated_dir, exist_ok=True)
    out_filename = os.path.join(out_parcellated_dir, f"{basename}_Type1_Level5.nii")
    nib.save(nii_out, out_filename)
    create_parcellated_images(output, output_dir, basename, odata, data)
    print(f"Parcellation result saved to: {out_filename}")
    return out_filename


def run_inference(input_file, only_face_cropping, only_skull_stripping):
    """
    Run the OpenMAP-T1 pipeline on one uploaded NIfTI file.

    Stages: save a canonical float32 copy of the input, preprocessing, face
    cropping, skull stripping, parcellation, hemisphere separation,
    postprocessing, the volume CSV and the Type1_Level5 label NIfTI.
    ``only_face_cropping`` stops after cropping and ``only_skull_stripping``
    stops after skull stripping. In every mode the per-subject working
    directory is zipped, two PNG previews are rendered, and the working
    directory is then removed. The ZIP and PNGs live in its parent directory.

    Args:
        input_file: The Gradio upload (an object exposing ``.name``) or a path string.
        only_face_cropping: Stop after the cropping stage.
        only_skull_stripping: Stop after the skull-stripping stage.

    Returns:
        ``(zip_path, input_preview, result_preview)``: the results ZIP path
        and two in-memory PIL images.

    Raises:
        gr.Error: If no file was uploaded.
    """
    import secrets

    if input_file is None:
        raise gr.Error("Please upload a NIfTI file (.nii or .nii.gz).")
    # Depending on the Gradio version the upload may be a tempfile-like object
    # with `.name` or a plain path string; accept both.
    input_path = getattr(input_file, "name", input_file)
    basename = _nifti_basename(input_path)

    # Unique per-request directory. `secrets` is unaffected by any seeding of the
    # global `random` generator, which would otherwise make run ids repeat.
    alphabet = string.ascii_letters + string.digits
    run_id = "".join(secrets.choice(alphabet) for _ in range(10))
    output_dir = f"output/{run_id}/{basename}"

    device = _select_device()
    print(f"Using device: {device}")

    # Load the pre-trained models from the fixed "model/" folder
    print("Loading models...")
    cnet, ssnet, pnet, hnet = load_model("model/", device=device)
    print("Models loaded successfully.")

    try:
        # 1. Save the input, reoriented and squeezed, as a float32 NIfTI
        print("Loading and preprocessing the input image...")
        canonical = nib.squeeze_image(nib.as_closest_canonical(nib.load(input_path)))
        nii = nib.Nifti1Image(canonical.get_fdata().astype(np.float32), affine=canonical.affine)
        os.makedirs(os.path.join(output_dir, "original"), exist_ok=True)
        original_nii_path = os.path.join(output_dir, f"original/{basename}.nii")
        nib.save(nii, original_nii_path)
        print(f"Input image saved to: {original_nii_path}")

        # 2. Preprocess the image
        print("Preprocessing the input image...")
        odata, data = preprocessing(input_path, output_dir, basename)
        print("Preprocessing completed.")

        # 3. Cropping
        print("Cropping the input image...")
        cropped, shift, out_filename = cropping(output_dir, basename, odata, data, cnet, device)
        print("Cropping completed.")

        if not only_face_cropping:
            # 4. Skull stripping
            print("Performing skull stripping...")
            stripped, out_filename = stripping(output_dir, basename, cropped, odata, data, ssnet, shift, device)
            print("Skull stripping completed.")
            if not only_skull_stripping:
                out_filename = _parcellate(
                    stripped, shift, odata, data, output_dir, basename, pnet, hnet, device
                )

        # The ZIP sits in the parent directory, so os.walk never archives it into itself.
        zip_path = os.path.join(os.path.dirname(output_dir), f"{basename}_results.zip")
        _zip_directory(output_dir, zip_path)

        # out_filename is the most recent stage's output (cropped, stripped or parcellated).
        input_png_path, result_png_path = nii_to_image(input_path, out_filename, output_dir, basename)
        input_preview = _load_png(input_png_path)
        result_preview = _load_png(result_png_path)
    finally:
        # On success the working files are already in the ZIP; on failure they are
        # unusable. Either way, do not let each request leak a directory on disk.
        shutil.rmtree(output_dir, ignore_errors=True)

    return zip_path, input_preview, result_preview
# Gradio UI wiring: one NIfTI upload plus two early-exit switches go in; the
# results ZIP and two preview images come out. No model-folder input is
# exposed because run_inference loads models from a fixed folder.
_input_components = [
    gr.File(label="Input NIfTI File (.nii or .nii.gz)"),
    gr.Checkbox(label="Only Face Cropping", value=False),
    gr.Checkbox(label="Only Skull Stripping", value=False),
]
_output_components = [
    gr.File(label="Output Results ZIP File"),
    gr.Image(label="MRI Image (Original)"),
    gr.Image(label="Parcellation Map (Type1_Level5)"),
]
_description = (
    "The uploaded MRI image will be processed using OpenMAP-T1, and the parcellation "
    "results will be returned as a ZIP file along with visualization images."
)

iface = gr.Interface(
    fn=run_inference,
    inputs=_input_components,
    outputs=_output_components,
    title="OpenMAP-T1 Inference",
    description=_description,
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()