Spaces:
Runtime error
Runtime error
| """Main Gradio application for stroke-deepisles-demo.""" | |
| from __future__ import annotations | |
| import shutil | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| from matplotlib.figure import Figure # noqa: TC002 | |
| from stroke_deepisles_demo.core.logging import get_logger | |
| from stroke_deepisles_demo.data import list_case_ids | |
| from stroke_deepisles_demo.metrics import compute_volume_ml | |
| from stroke_deepisles_demo.pipeline import run_pipeline_on_case | |
| from stroke_deepisles_demo.ui.components import ( | |
| create_case_selector, | |
| create_results_display, | |
| create_settings_accordion, | |
| ) | |
| from stroke_deepisles_demo.ui.viewer import ( | |
| nifti_to_gradio_url, | |
| render_3panel_view, | |
| render_slice_comparison, | |
| ) | |
# Module-level logger for this file, obtained from the project's logging helper
# and named after the module.
logger = get_logger(__name__)
def initialize_case_selector() -> gr.Dropdown:
    """
    Build the populated case dropdown once the page has rendered.

    The case list is fetched here, not while the UI is being constructed, so
    that a slow dataset load cannot stall application startup. ``create_app``
    registers this function with ``demo.load()``.

    Returns:
        A ``gr.Dropdown`` update. On success it lists every case ID and
        preselects the first one. If no cases are found, or if loading raises,
        it has no choices and carries an explanatory ``info`` message.
    """
    try:
        logger.info("Initializing dataset for case selector...")
        available = list_case_ids()

        if available:
            dropdown = gr.Dropdown(
                choices=available,
                value=available[0],
                info="Choose a case from isles24-stroke dataset",
                interactive=True,
            )
        else:
            dropdown = gr.Dropdown(choices=[], info="No cases found in dataset.")
        return dropdown
    except Exception as exc:
        # UI boundary: report the failure in the dropdown instead of crashing
        # the page-load event. The traceback goes to the log.
        logger.exception("Failed to initialize dataset")
        return gr.Dropdown(choices=[], info=f"Error loading data: {exc!s}")
def _cleanup_previous_results(previous_results_dir: str | None) -> None:
    """Delete the results directory left by a session's previous run (best-effort).

    Security: the path arrives via Gradio state and is treated as untrusted.
    It is removed only when it resolves to a location strictly inside the
    configured results root. Paths outside the root, and the root itself,
    are refused and logged, so manipulated state can neither delete arbitrary
    files nor wipe every session's results at once.

    Args:
        previous_results_dir: Path recorded by the previous run, or ``None``
            (or an empty string) when there is nothing to clean up.

    Returns:
        None. Deletion failures are logged and never raised.
    """
    # Covers both None and "". An empty string would otherwise resolve to the
    # current working directory, which is never a legitimate cleanup target.
    if not previous_results_dir:
        return

    from stroke_deepisles_demo.core.config import get_settings

    prev_path = Path(previous_results_dir).resolve()
    allowed_root = get_settings().results_dir.resolve()

    # Security: ensure the path is under the allowed root (prevents traversal).
    try:
        prev_path.relative_to(allowed_root)
    except ValueError:
        logger.warning(
            "Refusing to cleanup path outside allowed root: %s (root: %s)",
            prev_path,
            allowed_root,
        )
        return

    # relative_to() also succeeds when both paths are equal, so the root
    # itself must be rejected explicitly.
    if prev_path == allowed_root:
        logger.warning("Refusing to cleanup the results root itself: %s", allowed_root)
        return

    if prev_path.exists():
        try:
            shutil.rmtree(prev_path)
            logger.debug("Cleaned up previous results: %s", prev_path)
        except OSError as e:
            # Log but don't fail - cleanup is best-effort.
            logger.warning("Failed to cleanup %s: %s", prev_path, e)
def run_segmentation(
    case_id: str,
    fast_mode: bool,
    show_ground_truth: bool,
    previous_results_dir: str | None,
) -> tuple[
    dict[str, str | None] | None,
    Figure | None,
    Figure | None,
    dict[str, Any],
    str | None,
    str,
    str | None,
]:
    """
    Run the segmentation pipeline for one case and assemble every UI output.

    Args:
        case_id: Selected case identifier; a falsy value short-circuits with a prompt.
        fast_mode: Forwarded to the pipeline as ``fast``; also selects the model label.
        show_ground_truth: Whether the ground truth is passed to the slice comparison plot.
        previous_results_dir: Results path stored in ``gr.State`` by the previous run,
            removed before the new run starts.

    Returns:
        Tuple of (niivue_data, slice_fig, ortho_fig, metrics_dict, download_path,
        status_msg, new_results_dir). ``new_results_dir`` replaces the ``gr.State``
        value so the next run can clean it up. When no case is selected, or when
        any step raises, the first five entries are empty and the incoming
        ``previous_results_dir`` is handed back unchanged.
    """
    if not case_id:
        # Nothing selected: blank outputs, keep the existing session state.
        return None, None, None, {}, None, "Please select a case first.", previous_results_dir

    try:
        # Remove what this session produced last time (path tracked via gr.State).
        _cleanup_previous_results(previous_results_dir)

        logger.info("Running segmentation for %s", case_id)
        outcome = run_pipeline_on_case(
            case_id,
            fast=fast_mode,
            compute_dice=True,
            cleanup_staging=True,
        )

        # 1. NiiVue viewer payload.
        # URLs are served through Gradio's file route instead of inline base64
        # (Issue #19 optimization: avoids very large payloads in the browser).
        background_path = outcome.input_files["dwi"]
        background_url = nifti_to_gradio_url(background_path)

        # The pipeline always supplies a mask path; the existence check is
        # defense-in-depth only.
        mask_path = outcome.prediction_mask
        overlay_url = nifti_to_gradio_url(mask_path) if mask_path.exists() else None
        viewer_payload = {"background_url": background_url, "overlay_url": overlay_url}

        # 2. Static Matplotlib figures.
        if show_ground_truth:
            reference_path = outcome.ground_truth
        else:
            reference_path = None

        comparison_fig = render_slice_comparison(
            dwi_path=background_path,
            prediction_path=mask_path,
            ground_truth_path=reference_path,
            orientation="axial",
        )
        three_panel_fig = render_3panel_view(
            nifti_path=background_path,
            mask_path=mask_path,
            mask_alpha=0.5,
        )

        # 3. Metrics. Volume is optional: a failure is logged and reported as None.
        lesion_volume: float | None = None
        try:
            lesion_volume = round(compute_volume_ml(mask_path, threshold=0.5), 2)
        except Exception:
            logger.warning("Failed to compute volume for %s", case_id, exc_info=True)

        model_label = "SEALS (Fast)" if fast_mode else "Ensemble"
        summary = {
            "case_id": outcome.case_id,
            "dice_score": outcome.dice_score,
            "volume_ml": lesion_volume,
            "elapsed_seconds": round(outcome.elapsed_seconds, 2),
            "model": model_label,
        }

        # 4. Download target and status line.
        mask_download = str(mask_path)
        dice = outcome.dice_score
        if dice is None:
            message = "Success!"
        else:
            message = f"Success! Dice: {dice:.3f}"

        # The last element updates gr.State for the next cleanup.
        return (
            viewer_payload,
            comparison_fig,
            three_panel_fig,
            summary,
            mask_download,
            message,
            str(outcome.results_dir),
        )
    except Exception as exc:
        # UI boundary: show the error in the status box; traceback goes to the log.
        logger.exception("Error running segmentation")
        return None, None, None, {}, None, f"Error: {exc!s}", previous_results_dir
def create_app() -> gr.Blocks:
    """
    Assemble the Gradio Blocks UI and wire up its events.

    Layout: a header, then a row with a controls column (case selector,
    settings accordion, run button, status box) beside a wider results column.

    Returns:
        Configured gr.Blocks application
    """
    with gr.Blocks(title="Stroke Lesion Segmentation Demo") as demo:
        # Per-session holder for the last results directory, so each session
        # cleans up only its own output (replaces an earlier module-level global).
        previous_results_state = gr.State(value=None)

        # Header
        gr.Markdown("""
        # Stroke Lesion Segmentation Demo
        This demo runs [DeepISLES](https://github.com/ezequieldlrosa/DeepIsles)
        stroke segmentation on cases from
        [isles24-stroke](https://huggingface.co/datasets/hugging-science/isles24-stroke).
        **Model:** SEALS (ISLES'22 winner) - Fast, accurate ischemic stroke lesion segmentation.
        **Note:** First run may take a moment to load models and data.
        """)

        with gr.Row():
            # Left column: controls
            with gr.Column(scale=1):
                case_dropdown = create_case_selector()
                settings_controls = create_settings_accordion()
                run_button = gr.Button("Run Segmentation", variant="primary")
                status_box = gr.Textbox(label="Status", interactive=False)

            # Right column: results
            with gr.Column(scale=2):
                result_widgets = create_results_display()

        # The order of both lists must match run_segmentation's parameters and
        # return tuple; the state component appears in each so it is read and
        # then overwritten with the new results directory.
        segmentation_inputs = [
            case_dropdown,
            settings_controls["fast_mode"],
            settings_controls["show_ground_truth"],
            previous_results_state,
        ]
        segmentation_outputs = [
            result_widgets["niivue_viewer"],
            result_widgets["slice_plot"],
            result_widgets["ortho_plot"],
            result_widgets["metrics"],
            result_widgets["download"],
            status_box,
            previous_results_state,
        ]
        run_button.click(
            fn=run_segmentation,
            inputs=segmentation_inputs,
            outputs=segmentation_outputs,
        )
        # Note: no .then(js=...) step is needed; the custom viewer component
        # updates reactively.

        # Populate the case list after the UI renders (prevents startup timeout).
        demo.load(initialize_case_selector, outputs=[case_dropdown])

    return demo  # type: ignore[no-any-return]
# Lazily created singleton: the app is built on first request, not at import time.
_demo: gr.Blocks | None = None


def get_demo() -> gr.Blocks:
    """Return the shared demo instance, building it on the first call."""
    global _demo
    if _demo is not None:
        return _demo
    _demo = create_app()
    return _demo
if __name__ == "__main__":
    from stroke_deepisles_demo.core.config import get_settings
    from stroke_deepisles_demo.core.logging import setup_logging

    settings = get_settings()
    setup_logging(settings.log_level, format_style=settings.log_format)

    # Startup banner, useful when debugging HF Spaces issues from the logs.
    banner = "=" * 60
    for banner_line in (banner, "STARTUP: stroke-deepisles-demo", banner):
        logger.info(banner_line)

    app = get_demo()
    launch_options = {
        "server_name": settings.gradio_server_name,
        "server_port": settings.gradio_server_port,
        "share": settings.gradio_share,
        "theme": gr.themes.Soft(),
        "css": "footer {visibility: hidden}",
        # Default False for security
        "show_error": settings.gradio_show_error,
    }
    app.launch(**launch_options)