"""Gradio demo: preview an abdominal CT and show a (simulated) organ segmentation.

The "inference" is fake: ``predict`` sleeps and then returns a pre-computed
demo mask bundled next to this file, regardless of which CT was uploaded.
"""

import os
import time
import json

import gradio as gr
from gradio_niivueviewer import NiiVueViewer

DEMO_DIR = os.path.dirname(os.path.abspath(__file__))
EXAMPLE_CT = os.path.join(DEMO_DIR, "input/amos_0004-img.nii.gz")
EXAMPLE_MASK = os.path.join(DEMO_DIR, "output/amos_0004-mask.nii.gz")

# TotalSegmentator labels present in the demo mask
# (label ID in the mask volume -> display name). IDs are intentionally
# non-contiguous: only labels that occur in the demo mask are listed.
SEG_LABELS = {
    1: "Spleen",
    2: "Right Kidney",
    3: "Left Kidney",
    4: "Gallbladder",
    5: "Liver",
    6: "Stomach",
    7: "Pancreas",
    8: "Right Adrenal Gland",
    9: "Left Adrenal Gland",
    10: "Left Lung (Upper)",
    11: "Left Lung (Lower)",
    13: "Right Lung (Middle)",
    14: "Right Lung (Lower)",
    15: "Esophagus",
    18: "Small Bowel",
    19: "Duodenum",
    20: "Colon",
    21: "Urinary Bladder",
    22: "Prostate/Uterus",
    25: "Sacrum",
    26: "Vertebrae S1",
    27: "Vertebrae L5",
    28: "Vertebrae L4",
    29: "Vertebrae L3",
    30: "Vertebrae L2",
    31: "Vertebrae L1",
    32: "Vertebrae T12",
    33: "Vertebrae T11",
    34: "Vertebrae T10",
    35: "Vertebrae T9",
    36: "Vertebrae T8",
    51: "Left Hip",
    52: "Right Hip",
    63: "Gluteus Maximus (L)",
    64: "Gluteus Maximus (R)",
    65: "Gluteus Medius (L)",
    66: "Gluteus Medius (R)",
    67: "Gluteus Minimus (L)",
    68: "Gluteus Minimus (R)",
    75: "Autochthon (L)",
    76: "Autochthon (R)",
    77: "Iliopsoas (L)",
    78: "Iliopsoas (R)",
    79: "Rib 1 (L)",
    80: "Rib 1 (R)",
    81: "Rib 2 (L)",
    82: "Rib 2 (R)",
    83: "Rib 3 (L)",
    84: "Rib 3 (R)",
    85: "Rib 4 (L)",
    86: "Rib 4 (R)",
    87: "Rib 5 (L)",
    88: "Rib 5 (R)",
    89: "Rib 6 (L)",
    97: "Rib 7 (L)",
    98: "Rib 7 (R)",
    99: "Rib 8 (L)",
    100: "Rib 8 (R)",
    101: "Rib 9 (L)",
    102: "Rib 9 (R)",
    103: "Rib 10 (L)",
    109: "Sternum",
    110: "Costal Cartilage",
    111: "Aorta",
    112: "Pulmonary Vein",
    113: "Superior Vena Cava",
    114: "Inferior Vena Cava",
    115: "Portal/Splenic Vein",
    117: "Heart",
}


def preview_ct(ct_path: str | None):
    """Show CT only (no segmentation yet).

    Args:
        ct_path: Path of the uploaded/selected CT, or None/empty when the
            file input is cleared.

    Returns:
        ``[ct_path]`` as the viewer value, or None to clear the viewer.
    """
    if not ct_path:
        return None
    return [ct_path]


def _clear_labels():
    """Return None to empty the labels panel.

    Wired to the CT input's change event so labels from a previous run are
    not left on screen next to a different (not yet segmented) scan.
    """
    return None


def predict(ct_path: str | None):
    """Fake inference: sleep, then return the demo mask and the label map.

    The uploaded CT is only echoed back to the viewer; the mask is always
    the bundled ``EXAMPLE_MASK``.

    Args:
        ct_path: Path of the CT to "segment", or None/empty if none loaded.

    Returns:
        ``(viewer_value, labels_json)`` where ``viewer_value`` is
        ``[ct_path, EXAMPLE_MASK]`` and ``labels_json`` maps stringified
        label IDs to names; ``(None, None)`` when no CT is loaded.

    Raises:
        gr.Error: If the bundled demo mask is missing, so the user gets a
            readable message instead of the viewer being handed a
            nonexistent file (checked before the simulated wait).
    """
    if not ct_path:
        return None, None
    if not os.path.isfile(EXAMPLE_MASK):
        raise gr.Error(
            f"Demo mask not found: {os.path.basename(EXAMPLE_MASK)}"
        )
    # Simulate model inference
    time.sleep(4)
    # JSON object keys must be strings, so the integer label IDs are
    # converted for the gr.JSON panel.
    labels_json = {str(k): v for k, v in SEG_LABELS.items()}
    return [ct_path, EXAMPLE_MASK], labels_json


with gr.Blocks(title="Abdominal Organ Segmentation") as demo:
    gr.Markdown(
        "## Abdominal Organ Segmentation Demo\n"
        "Upload a CT scan or click the example below, then press **Run** to segment."
    )
    with gr.Row():
        ct_input = gr.File(
            label="Input CT (.nii / .nii.gz)",
            # NOTE(review): ".gz" rather than ".nii.gz" presumably because
            # extension filtering only looks at the last suffix; this also
            # admits any other .gz file — confirm against the Gradio version.
            file_types=[".nii", ".gz"],
            scale=3,
        )
        run_btn = gr.Button("▶ Run Segmentation", variant="primary", scale=1)
    viewer = NiiVueViewer(label="Viewer", seg_labels=SEG_LABELS)
    labels_out = gr.JSON(label="Segmentation Labels")

    gr.Examples(
        examples=[[EXAMPLE_CT]],
        inputs=[ct_input],
        fn=preview_ct,
        outputs=[viewer],
        run_on_click=True,
        label="Example",
    )

    # A new (or cleared) CT resets the viewer to CT-only; clear the labels
    # panel at the same time so it never describes a different scan.
    ct_input.change(preview_ct, inputs=ct_input, outputs=viewer)
    ct_input.change(_clear_labels, inputs=None, outputs=labels_out)
    run_btn.click(predict, inputs=ct_input, outputs=[viewer, labels_out])

if __name__ == "__main__":
    demo.launch(
        server_port=7870,
        # Lets Gradio serve the bundled example CT and mask from this dir.
        allowed_paths=[DEMO_DIR],
    )