"""Gradio UI for PanCancerSeg single-case CT tumour segmentation."""

import os
import shutil
import tempfile
from pathlib import Path

import gradio as gr

from predict import (
    CANCER_CONFIGS,
    install_custom_trainer,
    resolve_case_id,
    resolve_model_folder,
    run_nnunet_prediction_single,
    summarize_segmentation,
)
from visualize import generate_outputs

# ── Constants ──────────────────────────────────────────────────────────────────

# UI label -> key into ``CANCER_CONFIGS``.
CANCER_TYPE_CHOICES = {
    "Kidney Cancer": "kidney_cancer",
    "Liver Cancer": "liver_cancer",
    "Pancreatic Cancer": "pancreatic_cancer",
    "Lung Cancer": "lung_cancer",
}

DEFAULT_MODEL_DIR = str(Path(__file__).parent / "PanCancerSeg-Specialized-weights")
DEFAULT_DEVICE = "cuda"

# Hugging Face Hub repo that hosts the trained nnUNet weights. On Spaces (where the
# local weights folder is absent) we download them on first use.
MODEL_REPO_ID = "KS987/PanCancerSeg-Specialized-weights"

# Resolved once per process; subsequent inferences reuse it (no re-download).
_WEIGHTS_DIR: Path | None = None


def resolve_weights_dir() -> Path:
    """Return a directory containing the DatasetXXX_* model folders.

    Prefer a local checkout (fast local dev); otherwise download the weights
    from the Hugging Face Hub once and cache the resolved path in-process so we
    never hit the Hub again on later inferences.

    Returns:
        The local weights folder if it exists and holds at least one
        ``Dataset*`` entry, otherwise the snapshot directory returned by
        ``snapshot_download``.
    """
    global _WEIGHTS_DIR
    if _WEIGHTS_DIR is not None:
        return _WEIGHTS_DIR

    local_dir = Path(DEFAULT_MODEL_DIR).expanduser().resolve()
    if local_dir.exists() and any(local_dir.glob("Dataset*")):
        _WEIGHTS_DIR = local_dir
        return _WEIGHTS_DIR

    # Imported lazily so the Hub client is only needed when the local weights
    # folder is not usable.
    from huggingface_hub import snapshot_download

    downloaded = snapshot_download(
        repo_id=MODEL_REPO_ID,
        repo_type="model",
        allow_patterns=["Dataset*/**"],
    )
    _WEIGHTS_DIR = Path(downloaded)
    return _WEIGHTS_DIR


# ── ZeroGPU support ──────────────────────────────────────────────────────────
# On Hugging Face ZeroGPU Spaces the `spaces` package is available, and any GPU
# work must run inside a function decorated with `@spaces.GPU`. Locally (or on a
# dedicated GPU Space) the package is absent, so we fall back to a no-op so the
# same code keeps working everywhere.
try:
    import spaces  # type: ignore

    _HAS_ZEROGPU = True
except ImportError:
    spaces = None
    _HAS_ZEROGPU = False


def gpu_task(duration: int = 180):
    """Return a decorator that marks a function as GPU work.

    Args:
        duration: Value forwarded as ``spaces.GPU(duration=...)`` when the
            ``spaces`` package could be imported.

    Returns:
        ``spaces.GPU(duration=duration)`` when ``spaces`` is importable,
        otherwise an identity decorator that returns the function unchanged.
    """
    if _HAS_ZEROGPU:
        return spaces.GPU(duration=duration)

    def _identity(fn):
        return fn

    return _identity


@gpu_task(duration=180)
def run_gpu_segmentation(model_folder_str: str, input_file_str: str, output_file_str: str) -> None:
    """Run nnUNet inference on GPU. Executed inside the ZeroGPU worker process.

    Uses the single-case, no-multiprocessing path because ZeroGPU runs this in
    a daemon process that is not allowed to spawn child processes.

    Args:
        model_folder_str: Trained model folder, forwarded as ``model_folder``.
        input_file_str: Input image path, forwarded as ``input_file``.
        output_file_str: Destination path, forwarded as ``output_file``.
    """
    # The custom trainer must be registered inside the GPU worker process so that
    # nnUNet can discover it when initialising from the trained model folder.
    install_custom_trainer()
    run_nnunet_prediction_single(
        model_folder=model_folder_str,
        input_file=input_file_str,
        output_file=output_file_str,
        device=DEFAULT_DEVICE,
    )


_SAMPLE_DIR = Path(__file__).parent / "sample_input"

# UI label -> sub-folder of ``_SAMPLE_DIR`` holding that type's example volumes.
_CANCER_TYPE_TO_FOLDER = {
    "Kidney Cancer": "kidney",
    "Liver Cancer": "liver",
    "Pancreatic Cancer": "pancreas",
    "Lung Cancer": "lung",
}

_EXAMPLE_PATTERN = "*_0000.nii.gz"


def _example_folder(cancer_type_label) -> Path | None:
    """Return the sample folder for a cancer label, or ``None`` if the label is unknown."""
    subfolder = _CANCER_TYPE_TO_FOLDER.get(cancer_type_label)
    return None if subfolder is None else _SAMPLE_DIR / subfolder


def load_example(cancer_type_label: str, index: int) -> str:
    """Return the index-th (1-based) example _0000.nii.gz for the given cancer type.

    Args:
        cancer_type_label: A key of ``_CANCER_TYPE_TO_FOLDER``.
        index: 1-based position in the name-sorted list of example files.

    Returns:
        The selected example file path as a string.

    Raises:
        gr.Error: If the label is unknown or ``index`` is outside
            ``1..number of examples``.
    """
    folder = _example_folder(cancer_type_label)
    if folder is None:
        raise gr.Error(f"Unknown cancer type: {cancer_type_label!r}")

    files = sorted(folder.glob(_EXAMPLE_PATTERN))
    # Reject index <= 0 as well: a plain ``files[index - 1]`` would silently
    # wrap around and return a file from the end of the list.
    if not 1 <= index <= len(files):
        raise gr.Error(f"Example {index} not found for {cancer_type_label} in {folder}")
    return str(files[index - 1])


def count_examples(cancer_type_label: str) -> int:
    """Number of bundled example CT volumes for a cancer type.

    Returns 0 when the label is unknown or its sample folder does not exist.
    """
    folder = _example_folder(cancer_type_label)
    if folder is None or not folder.exists():
        return 0
    return sum(1 for _ in folder.glob(_EXAMPLE_PATTERN))


def available_cancer_labels(weights_dir: str | Path) -> list[str]:
    """Cancer labels whose DatasetXXX folder is present under ``weights_dir``.

    A single-cancer Space bundles exactly one DatasetXXX folder, so this returns
    a single label and the UI locks to it. A full checkout with all four
    datasets returns every label and the UI shows the selector.

    If no dataset folder is found at all, every known label is returned.
    """
    weights_dir = Path(weights_dir)
    found = [
        label
        for label, key in CANCER_TYPE_CHOICES.items()
        if (weights_dir / CANCER_CONFIGS[key]["dataset_name"]).exists()
    ]
    return found or list(CANCER_TYPE_CHOICES)


# ── Inference ──────────────────────────────────────────────────────────────────


def _stage_input(source: Path, destination: Path) -> None:
    """Expose ``source`` at ``destination``, by symlink if possible, else by copy."""
    try:
        destination.symlink_to(source.resolve())
    except (OSError, NotImplementedError):
        shutil.copy2(source, destination)


def _path_or_none(path) -> str | None:
    """Return ``str(path)``, or ``None`` when no path was produced.

    Avoids handing the literal string ``"None"`` to an output component as if
    it were a file path.
    """
    return None if path is None else str(path)


def run_inference(
    input_file,
    cancer_type_label,
    fps,
    progress=gr.Progress(track_tqdm=True),
):
    """Segment one CT volume and build the summary, slice images and overlay video.

    Args:
        input_file: Path of the uploaded ``.nii.gz`` CT image.
        cancer_type_label: One of the keys of ``CANCER_TYPE_CHOICES``.
        fps: Frame rate for the overlay video; converted with ``int()``.
        progress: Gradio progress tracker used to report pipeline stages.

    Returns:
        A 7-tuple matching the UI outputs: summary text, segmentation mask
        path, centroid / max-area / extent-25 / extent-75 slice image paths
        (``None`` for any slice that was not produced), and the overlay video
        path (``None`` if the video file is missing or empty).

    Raises:
        gr.Error: On invalid input, on failure to obtain the weights, or on any
            failure during inference or visualisation.
    """
    if input_file is None:
        raise gr.Error("Please upload a .nii.gz CT image first.")

    input_path = Path(input_file)
    if not input_path.name.endswith(".nii.gz"):
        raise gr.Error(f"File must be .nii.gz format. Got: {input_path.name}")

    if cancer_type_label not in CANCER_TYPE_CHOICES:
        raise gr.Error(f"Unknown cancer type: {cancer_type_label!r}")

    progress(0.02, desc="Resolving model weights...")
    try:
        model_dir_path = resolve_weights_dir()
    except Exception as e:
        raise gr.Error(f"Failed to obtain model weights from '{MODEL_REPO_ID}': {e}") from e

    # Not removed on success: the returned paths point into this directory.
    output_dir = Path(tempfile.mkdtemp(prefix="pancancerseg_out_"))

    try:
        config = CANCER_CONFIGS[CANCER_TYPE_CHOICES[cancer_type_label]]
        case_id = resolve_case_id(input_path)

        progress(0.10, desc="Loading model weights...")
        model_folder = resolve_model_folder(model_dir_path, config["dataset_name"])

        seg_path = output_dir / f"{case_id}_seg.nii.gz"

        with tempfile.TemporaryDirectory(prefix="pancancerseg_in_") as tmp:
            tmp_path = Path(tmp)
            tmp_input_dir = tmp_path / "input"
            tmp_output_dir = tmp_path / "prediction"
            tmp_input_dir.mkdir()
            tmp_output_dir.mkdir()

            nnunet_input = tmp_input_dir / f"{case_id}_0000.nii.gz"
            _stage_input(input_path, nnunet_input)

            raw_seg = tmp_output_dir / f"{case_id}.nii.gz"

            progress(0.20, desc="Running nnUNet inference on GPU (this may take a few minutes)...")
            run_gpu_segmentation(
                str(model_folder),
                str(nnunet_input),
                str(raw_seg),
            )

            if not raw_seg.exists():
                produced = [p.name for p in tmp_output_dir.glob("*.nii.gz")]
                raise RuntimeError(
                    f"nnUNet did not produce the expected segmentation. Found: {produced}"
                )

            # Copy out before the temporary working directory is deleted.
            shutil.copy2(raw_seg, seg_path)

        progress(0.80, desc="Generating slice images and overlay video...")
        viz = generate_outputs(
            image_path=input_path,
            mask_path=seg_path,
            output_dir=output_dir,
            case_name=case_id,
            cancer_type=config["display_name"],
            wl=config["wl"],
            ww=config["ww"],
            color=config["color"],
            alpha=0.5,
            fps=int(fps),
        )

        progress(0.95, desc="Computing tumour volume...")
        positive_voxels, tumor_volume_ml = summarize_segmentation(seg_path)

        stats = (
            f"Case ID : {case_id}\n"
            f"Cancer type : {config['display_name']}\n"
            f"Positive voxels: {positive_voxels:,}\n"
            f"Tumour volume : {tumor_volume_ml:.3f} mL"
        )

        slices = viz["slices"]
        video_path = viz["video"]
        # Only offer the video if a non-empty file was actually written.
        video_out = (
            str(video_path)
            if video_path.exists() and video_path.stat().st_size > 0
            else None
        )

        progress(1.0, desc="Done!")
        return (
            stats,
            str(seg_path),
            _path_or_none(slices.get("centroid")),
            _path_or_none(slices.get("max_area")),
            _path_or_none(slices.get("extent25")),
            _path_or_none(slices.get("extent75")),
            video_out,
        )

    except Exception as e:
        # Nothing will be returned to the UI, so the outputs can be discarded.
        shutil.rmtree(output_dir, ignore_errors=True)
        raise gr.Error(str(e)) from e


# ── UI ─────────────────────────────────────────────────────────────────────────


def build_ui(available_labels=None):
    """Build the Gradio Blocks app.

    Args:
        available_labels: Cancer labels to offer. ``None`` or an empty list
            means every key of ``CANCER_TYPE_CHOICES``. With exactly one label
            the dropdown is non-interactive and the title/intro name that
            cancer type.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    labels = available_labels or list(CANCER_TYPE_CHOICES)
    single = len(labels) == 1
    default_label = labels[0]

    if single:
        title = f"# PanCancerSeg — {default_label} CT Segmentation"
        intro = (
            f"Upload a `.nii.gz` CT image and click **Run Inference** to segment "
            f"**{default_label.lower()}** and obtain a mask plus visualisations."
        )
    else:
        title = "# PanCancerSeg — Specialist CT Tumour Segmentation"
        intro = (
            "Upload a `.nii.gz` CT image, select the cancer type, and click "
            "**Run Inference** to obtain a segmentation mask and visualisations."
        )

    # In multi-cancer mode a fixed two example buttons are shown; load_example
    # reports an error if the selected type has fewer examples than that.
    n_examples = count_examples(default_label) if single else 2

    with gr.Blocks(title="PanCancerSeg Inference") as demo:
        gr.Markdown(f"{title}\n{intro}")

        with gr.Row():
            # ── Left panel: inputs ─────────────────────────────────────────────
            with gr.Column(scale=1, min_width=300):
                input_file = gr.File(
                    label="CT Image (.nii.gz)",
                    file_types=[".gz"],
                )
                cancer_type = gr.Dropdown(
                    choices=labels,
                    value=default_label,
                    label="Cancer Type",
                    interactive=not single,
                )
                fps = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Video FPS",
                )

                example_buttons = []
                if n_examples > 0:
                    with gr.Row():
                        for i in range(1, n_examples + 1):
                            label = "Load Example" if n_examples == 1 else f"Load Example {i}"
                            example_buttons.append(gr.Button(label, size="lg"))

                run_btn = gr.Button("Run Inference", variant="primary", size="lg")
                video_out = gr.Video(label="Overlay Video")

            # ── Right panel: outputs ───────────────────────────────────────────
            with gr.Column(scale=2):
                with gr.Row():
                    stats_box = gr.Textbox(
                        label="Inference Summary",
                        lines=4,
                        interactive=False,
                    )
                    seg_file = gr.File(label="Download Segmentation Mask (.nii.gz)")

                with gr.Row():
                    img_centroid = gr.Image(label="Centroid Slice", type="filepath")
                    img_max_area = gr.Image(label="Max Area Slice", type="filepath")

                with gr.Row():
                    img_ext25 = gr.Image(label="Extent 25% Slice", type="filepath")
                    img_ext75 = gr.Image(label="Extent 75% Slice", type="filepath")

        for idx, btn in enumerate(example_buttons, start=1):
            btn.click(
                # The outer lambda is called immediately so each handler keeps
                # its own index instead of the loop's final value.
                fn=(lambda i: lambda ct: load_example(ct, i))(idx),
                inputs=[cancer_type],
                outputs=[input_file],
            )

        run_btn.click(
            fn=run_inference,
            inputs=[input_file, cancer_type, fps],
            outputs=[
                stats_box,
                seg_file,
                img_centroid,
                img_max_area,
                img_ext25,
                img_ext75,
                video_out,
            ],
        )

    return demo


if __name__ == "__main__":
    # Warm the weights cache at startup so the very first inference (and every
    # later one) does not trigger a download. Failures are non-fatal: we fall
    # back to lazy download on the first request.
    labels = None
    try:
        weights_dir = resolve_weights_dir()
        labels = available_cancer_labels(weights_dir)
        print(f"[startup] available cancer models: {labels}")
    except Exception as e:
        print(f"[startup] weight pre-fetch skipped: {e}")

    demo = build_ui(labels)

    # Hugging Face Spaces expect the app on port 7860 (set via GRADIO_SERVER_PORT).
    # Locally this falls back to 7860 unless overridden.
    port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,
        theme=gr.themes.Soft(),
        ssr_mode=False,
        mcp_server=True,
    )