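"""Gradio demo app for Mask-HybridGNet: graph-based medical image segmentation
with landmark outputs across several imaging modalities (cardiac ultrasound,
prenatal ultrasound, cardiac MRI and chest X-ray)."""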
import numpy as np
import gradio as gr
import cv2
import torch
from zipfile import ZipFile
import json
import os
import re
from models.utils import load_config
from models.hybridgnet_se_resnext_dual import HybridDual

import warnings
warnings.filterwarnings("ignore")

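
# pad_to_square zero-pads the shorter image dimension so the image becomes square.
# The padding is split evenly between the two sides (any extra pixel goes to the
# bottom/right), and the amounts are returned so removePreprocess can undo them later.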
def pad_to_square(img):
    h, w = img.shape[:2]

    if h > w:
        padw = (h - w)
        auxw = padw % 2
        img = np.pad(img, ((0, 0), (padw // 2, padw // 2 + auxw)), 'constant')
        padh = 0
        auxh = 0
    else:
        padh = (w - h)
        auxh = padh % 2
        img = np.pad(img, ((padh // 2, padh // 2 + auxh), (0, 0)), 'constant')
        padw = 0
        auxw = 0

    return img, (padh, padw, auxh, auxw)

def preprocess(input_img, image_size=512):
    """ Preprocess the input image to fit the model requirements.

    Args:
        input_img (numpy array): The input image to preprocess.
        image_size (int): The desired size of the output image.

    Returns:
        img (numpy array): The preprocessed (square, resized) image.
        info (tuple): The padded square size (h, w) and the padding applied,
            needed by removePreprocess to undo the preprocessing.
    """
    img, padding = pad_to_square(input_img)

    h, w = img.shape[:2]
    if h != image_size or w != image_size:
        img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_CUBIC)

    return img, (h, w, padding)

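
# removePreprocess maps the predicted landmarks back to the original image frame.
# It assumes the model outputs coordinates normalized to [0, 1] relative to its input:
# they are rescaled to the padded square size and the padding offsets are subtracted.
# Illustrative example: a 600x400 image is padded to 600x600 (100 px on each side of
# the width), so a normalized prediction (0.5, 0.5) becomes (300, 300) and, after
# removing the left padding, (200, 300) in the original image.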
def removePreprocess(output, info, image_size=512):
    """ Remove the preprocessing applied to the output.

    Args:
        output (numpy array): The output from the model.
        info (tuple): Information about the original image size and padding.
        image_size (int): The size of the image after preprocessing.

    Returns:
        output (numpy array): The output adjusted to the original image size.
    """
    h, w, padding = info

    if h != image_size or w != image_size:
        output = output * h
    else:
        output = output * image_size

    padh, padw, auxh, auxw = padding
    output[:, 0] = output[:, 0] - padw // 2
    output[:, 1] = output[:, 1] - padh // 2

    return output

def zip_files(files):
    with ZipFile("tmp/complete_results.zip", "w") as zipObj:
        for file in files:
            zipObj.write(file, arcname=file.split("/")[-1])
    return "tmp/complete_results.zip"

def drawOnTop(nodes, input_img, organ_ids, circ_organ_order, organ_names):
    import matplotlib.pyplot as plt
    from matplotlib import cm

    colors = cm.get_cmap('tab20', len(organ_ids))

    plt.figure(figsize=(10, 10))
    plt.imshow(input_img, cmap='gray')

    for k, organ_id in enumerate(organ_ids):
        node_ids = circ_organ_order[organ_id]
        x_nodes = nodes[node_ids, 0]
        y_nodes = nodes[node_ids, 1]

        # Give each organ a distinct color from the tab20 colormap
        plt.scatter(x_nodes, y_nodes, color=colors(k), label=f'{organ_names[organ_id].title()}')
        plt.fill(x_nodes, y_nodes, color=colors(k), alpha=0.3)  # Fill the contour enclosed by the points

    plt.axis('off')

    # Put the legend outside the image
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='medium', title='Organs')
    plt.tight_layout()
    plt.savefig("tmp/overlap_segmentation.png", bbox_inches='tight', pad_inches=0)
    plt.close()

    # OpenCV reads BGR; convert to RGB so the Gradio image component shows the colors correctly
    return cv2.cvtColor(cv2.imread("tmp/overlap_segmentation.png"), cv2.COLOR_BGR2RGB)

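
# segment runs the full pipeline for one image: it picks the checkpoint and
# hyperparameters for the selected modality, builds the organ-to-node index mapping,
# preprocesses the image, runs the model, maps the landmarks back to the original
# coordinates, and writes the overlay, the landmarks JSON and a zip of both to tmp/.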
def segment(input_img, dataset_display_name):
    # Map nice names to internal codes
    name_mapping = {
        "Cardiac Ultrasound Images": "CAMUS",
        "Prenatal Ultrasound": "FETAL",
        "Cardiac MRI": "MRI",
        "Chest X-Ray": "PAX-RAY++"
    }
    dataset = name_mapping.get(dataset_display_name, dataset_display_name)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    image_size = 512

    if dataset == "CAMUS":
        model_checkpoint = "weights/CAMUS/CAMUS_NN_dual.pth"
        parameters = json.load(open("weights/CAMUS/hyperparameters.json"))
        parameters["naive"] = False
        parameters["use_dual"] = True
        parameters["dual"] = True
        dataset_path = "weights/CAMUS/Dataset"
    elif dataset == "FETAL":
        model_checkpoint = "weights/FETAL/FETAL_dual.pth"
        parameters = json.load(open("weights/FETAL/hyperparameters.json"))
        parameters["naive"] = True
        parameters["use_dual"] = True
        parameters["dual"] = True
        dataset_path = "weights/FETAL/Dataset"
    elif dataset == "MRI":
        model_checkpoint = "weights/MRI/MRI_dual.pth"
        parameters = json.load(open("weights/MRI/hyperparameters.json"))
        parameters["naive"] = True
        parameters["use_dual"] = True
        parameters["dual"] = True
        dataset_path = "weights/MRI/Dataset"
        image_size = 256
    elif dataset == "PAX-RAY++":
        model_checkpoint = "weights/PAXRAY/PAXRAY_NN_dual.pth"
        parameters = json.load(open("weights/PAXRAY/hyperparameters.json"))
        parameters["naive"] = False
        parameters["use_dual"] = True
        parameters["dual"] = True
        dataset_path = "weights/PAXRAY/Dataset"
    else:
        return None, None

    config, D_t, U_t, A_t = load_config(dataset_path, parameters)
    model = HybridDual(config, D_t, U_t, A_t).to(device)
    model.load_checkpoint(model_checkpoint, device)
    model.eval()

    if parameters["naive"]:
        organ_id = np.load("%s/Naive/adj_full_organ_id.npy" % dataset_path)[:, 0]
        organ_order = np.unique(organ_id).astype(int)
        organ_order = [str(int(org)) for org in organ_order]

        circ_organ_order = {}
        for org in organ_order:
            # Collect all node indices that belong to this organ
            circ_organ_order[str(int(org))] = np.where(organ_id == int(org))[0].tolist()
    else:
        # Load organ IDs
        organ_id = np.load("%s/NonNaive/adj_full_organ_id.npy" % dataset_path)[:, 0]

        unique_organs = set()
        for org_str in organ_id:
            for org in str(org_str).split('-'):
                if org:  # Skip empty strings
                    unique_organs.add(int(org))

        organ_order = sorted(unique_organs)
        organ_order = [str(int(org)) for org in organ_order]

        with open(f"{dataset_path}/NonNaive/organ_order_full.json", "r") as f:
            circ_organ_order = json.load(f)

    # Map organ IDs to display names: config["organs"] holds the IDs and
    # config["organ_names"] holds the matching names
    organ_names = config["organ_names"]
    organs = config["organs"]
    organs = [str(int(org)) for org in organs]  # Ensure organ IDs are strings to match the keys above
    organ_names = {organs[i]: organ_names[i] for i in range(len(organs))}

    # For the FETAL dataset, only keep the fetal head, which is organ "2"
    if dataset == "FETAL":
        organ_order = ["2"]
        circ_organ_order = {"2": circ_organ_order["2"]}
        organ_names = {"2": organ_names["2"]}

    # Read the image as grayscale, normalize to [0, 1] and make it square for the network
    input_img = cv2.imread(input_img, 0) / 255.0
    img, (h, w, padding) = preprocess(input_img, image_size=image_size)
    data = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device).float()

    with torch.no_grad():
        output = model(data)[0].cpu().numpy().reshape(-1, 2)

    # Map the predicted landmarks back to the original image coordinates
    output = removePreprocess(output, (h, w, padding), image_size=image_size)
    output = output.astype('int')
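
    # Group the landmarks by organ and save them as JSON. The file maps each organ
    # name to a list of [x, y] coordinates in the original image frame, e.g.
    # {"heart": [[x0, y0], [x1, y1], ...]} (key shown for illustration only; the
    # actual names come from config["organ_names"]).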
    landmarks_dict = {}
    for organ in circ_organ_order.keys():
        organ_name = organ_names[organ]
        index_train = circ_organ_order[organ]
        landmarks_dict[organ_name] = output[index_train].tolist()

    out_path = "tmp/landmarks.json"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'w') as f:
        json.dump(landmarks_dict, f)

    outseg = drawOnTop(output, input_img, organ_order, circ_organ_order, organ_names)

    zip_path = zip_files(["tmp/overlap_segmentation.png", "tmp/landmarks.json"])
    return outseg, ["tmp/overlap_segmentation.png", "tmp/landmarks.json", zip_path]

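
# get_examples returns the example image paths bundled with the demo for the chosen
# modality, sorted by the numbers that appear in their filenames.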
def get_examples(dataset_name):
    # Map nice names to internal codes
    name_mapping = {
        "Cardiac Ultrasound Images": "CAMUS",
        "Prenatal Ultrasound": "FETAL",
        "Cardiac MRI": "MRI",
        "Chest X-Ray": "PAX-RAY++"
    }
    # Get the internal code, fallback to dataset_name if not found
    internal_name = name_mapping.get(dataset_name, dataset_name)

    folder_map = {
        "CAMUS": "examples/heart",
        "CXRAY": "examples/chest",
        "FETAL": "examples/fetal",
        "MRI": "examples/mri",
        "PAX-RAY++": "examples/chest"
    }
    folder = folder_map.get(internal_name, "")

    if folder and os.path.exists(folder):
        files = os.listdir(folder)
        examples = [os.path.join(folder, f) for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        # Sort them by natural key using re
        examples.sort(key=lambda x: [int(i) for i in re.findall(r'\d+', x)])
        return examples
    else:
        return []

def update_examples_and_clear(dataset_name):
    """Update example choices and clear both the example selector and the input image."""
    examples = get_examples(dataset_name)
    return (
        gr.update(choices=examples, value=None),  # Update example selector and clear selection
        None  # Clear input image
    )


def load_example(example_path):
    """Load the selected example into the input image."""
    if example_path:
        return example_path
    return None

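
# Note: segment can also be called outside the Gradio UI, e.g.
#   overlay, files = segment("examples/heart/1.png", "Cardiac Ultrasound Images")
# (sketch only: the example filename is hypothetical, and the weights/ and tmp/
# directories must exist relative to the working directory).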
if __name__ == "__main__":
    theme = gr.themes.Default()

    with gr.Blocks(theme=theme) as demo:
        gr.Markdown("""
# Mask-HybridGNet: Graph-based segmentation with emergent anatomical correspondence from pixel-level supervision

This demo showcases the HybridGNet model for graph-based medical image segmentation across multiple imaging modalities.

**Instructions:**
1. Choose the imaging modality you want to work with
2. Upload an image in PNG or JPEG format, or select an example from the dropdown
3. Click "Segment Image" to perform automated segmentation

**Note:** Image preprocessing is handled automatically and will be reversed after segmentation to provide results in the original image coordinates.
""")

        with gr.Tab("Segment Image"):
            # Dataset selector at the top
            dataset_selector = gr.Dropdown(
                label="Select Image Modality",
                choices=["Cardiac Ultrasound Images", "Prenatal Ultrasound", "Cardiac MRI", "Chest X-Ray"],
                value="Cardiac Ultrasound Images"
            )

            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type="filepath", height=750)

                    # Example selector dropdown
                    example_selector = gr.Dropdown(
                        label="Example Images",
                        choices=get_examples("Cardiac Ultrasound Images"),
                        value=None,
                        interactive=True
                    )

                    with gr.Row():
                        clear_button = gr.Button("Clear")
                        image_button = gr.Button("Segment Image")

                with gr.Column():
                    image_output = gr.Image(type="filepath", height=750)
                    results = gr.File()

            # Clear button functionality
            clear_button.click(lambda: None, None, image_input, queue=False)
            clear_button.click(lambda: None, None, image_output, queue=False)
            clear_button.click(lambda: None, None, results, queue=False)
            clear_button.click(lambda: None, None, example_selector, queue=False)

            # When the dataset changes, update example choices and clear selections
            dataset_selector.change(
                fn=update_examples_and_clear,
                inputs=[dataset_selector],
                outputs=[example_selector, image_input]
            )

            # When an example is selected, load it into the input image
            example_selector.change(
                fn=load_example,
                inputs=[example_selector],
                outputs=[image_input]
            )

            # Segment button
            image_button.click(
                fn=segment,
                inputs=[image_input, dataset_selector],
                outputs=[image_output, results],
                queue=False
            )

        gr.Markdown("""
**Example Image Sources:**

All example images were obtained from Wikimedia Commons under open licenses.
None of them come from the training or testing datasets of the models.

**Chest X-Ray Images:**
- Creative Commons Attribution-Share Alike 4.0 International. Source: https://commons.wikimedia.org/wiki/File:Chest_X-ray.jpg
- Creative Commons Attribution-Share Alike 3.0 Unported license. Source: https://commons.wikimedia.org/wiki/File:Medical_X-Ray_imaging_AMS02_nevit.jpg
- Creative Commons Attribution-Share Alike 3.0 Unported license. Source: https://commons.wikimedia.org/wiki/File:Medical_X-Ray_imaging_BHB02_nevit.jpg
- Creative Commons Attribution-Share Alike 3.0 Unported license. Source: https://commons.wikimedia.org/wiki/File:Medical_X-Ray_imaging_WQC07_nevit.jpg

**Cardiac Ultrasound Images:**
Creative Commons Attribution-Share Alike 4.0 International. Source: https://commons.wikimedia.org/wiki/File:ProapsZCMiCh_.gif
Images were extracted from the frames of the GIF animation.

**Cardiac MRI Images:**
Public Domain. Source: https://commons.wikimedia.org/wiki/File:Multslice_short_axis.gif
Images were extracted from the frames of the GIF animation.

**Prenatal Ultrasound Images:**
Creative Commons Attribution-Share Alike 3.0 Unported license.
Sources:
- https://commons.wikimedia.org/wiki/File:Pregnancy_ultrasound_110328105247_1105070.jpg
- https://commons.wikimedia.org/wiki/File:Pregnancy_ultrasound_110325112752_1129380.jpg
- https://commons.wikimedia.org/wiki/File:Pregnancy_ultrasound_110322143039_1441530.jpg
- https://commons.wikimedia.org/wiki/File:Pregnancy_ultrasound_110328105247_1054210.jpg
- https://commons.wikimedia.org/wiki/File:Pregnancy_ultrasound_110322143039_1440550.jpg
- https://commons.wikimedia.org/wiki/File:Pregnancy_ultrasound_110322143039_1442060.jpg

**Author:** Nicolás Gaggion
**Website:** [ngaggion.github.io](https://ngaggion.github.io/)
""")

    demo.launch()