Spaces:
Runtime error
Runtime error
| """Neuroimaging visualization for Gradio. | |
| This module provides visualization components for neuroimaging data: | |
| - Matplotlib-based 2D slice comparisons | |
| - NIfTI URL helper for Custom Component | |
| See: | |
| - docs/specs/07-hf-spaces-deployment.md | |
| - docs/specs/19-perf-base64-to-file-urls.md (Issue #19 optimization) | |
| """ | |
| from __future__ import annotations | |
| from typing import TYPE_CHECKING | |
| import numpy as np | |
| if TYPE_CHECKING: | |
| from pathlib import Path | |
| from matplotlib.figure import Figure | |
| from stroke_deepisles_demo.core.logging import get_logger | |
| from stroke_deepisles_demo.metrics import load_nifti_as_array | |
| logger = get_logger(__name__) | |
| def nifti_to_gradio_url(nifti_path: Path) -> str: | |
| """ | |
| Get Gradio file URL for a NIfTI file. | |
| Uses Gradio's built-in file serving instead of base64 encoding. | |
| This reduces payload size by ~33% and improves browser performance | |
| by avoiding large base64 strings in the DOM. | |
| Args: | |
| nifti_path: Path to NIfTI file. Must be in an allowed path: | |
| - tempfile.gettempdir() (default for pipeline results) | |
| - Current working directory | |
| - Paths specified in allowed_paths during launch() | |
| Returns: | |
| Gradio file URL (e.g., /gradio_api/file=/tmp/.../dwi.nii.gz) | |
| Note: | |
| This replaces the deprecated nifti_to_data_url() function. | |
| See Issue #19 for performance analysis and benchmarks. | |
| References: | |
| - https://www.gradio.app/guides/file-access | |
| - https://niivue.com/docs/loading/ | |
| """ | |
| # Ensure we use absolute path for Gradio's file serving | |
| abs_path = nifti_path.resolve() | |
| # Gradio file URL format (standard since Gradio 4.x) | |
| # Files in tempfile.gettempdir() are allowed by default | |
| return f"/gradio_api/file={abs_path}" | |
| def get_slice_at_max_lesion( | |
| mask_path: Path, | |
| orientation: str = "axial", | |
| ) -> int: | |
| """ | |
| Find slice index with maximum lesion area. | |
| Useful for displaying the most informative slice. | |
| Args: | |
| mask_path: Path to lesion mask NIfTI | |
| orientation: Slice orientation ("axial", "coronal", "sagittal") | |
| Returns: | |
| Slice index with maximum lesion area | |
| """ | |
| data, _ = load_nifti_as_array(mask_path) | |
| # Determine axes to sum over | |
| # Default NIfTI (RAS+): x=sagittal, y=coronal, z=axial | |
| # array indices: [x, y, z] | |
| if orientation == "sagittal": | |
| # Sum over y and z (axes 1, 2) -> result shape [x] | |
| lesion_counts = np.sum(data > 0, axis=(1, 2)) | |
| elif orientation == "coronal": | |
| # Sum over x and z (axes 0, 2) -> result shape [y] | |
| lesion_counts = np.sum(data > 0, axis=(0, 2)) | |
| else: # axial | |
| # Sum over x and y (axes 0, 1) -> result shape [z] | |
| lesion_counts = np.sum(data > 0, axis=(0, 1)) | |
| max_slice = int(np.argmax(lesion_counts)) | |
| # If mask is empty, return middle slice | |
| if np.max(lesion_counts) == 0: | |
| if orientation == "sagittal": | |
| return int(data.shape[0] // 2) | |
| elif orientation == "coronal": | |
| return int(data.shape[1] // 2) | |
| else: | |
| return int(data.shape[2] // 2) | |
| return max_slice | |
| def render_3panel_view( | |
| nifti_path: Path, | |
| mask_path: Path | None = None, | |
| *, | |
| mask_alpha: float = 0.5, | |
| ) -> Figure: | |
| """ | |
| Render axial/coronal/sagittal slices with optional mask overlay. | |
| Args: | |
| nifti_path: Path to base NIfTI volume | |
| mask_path: Optional path to mask for overlay | |
| mask_alpha: Transparency of mask overlay | |
| Returns: | |
| Matplotlib figure with 3-panel view | |
| """ | |
| data, _ = load_nifti_as_array(nifti_path) | |
| mask_data = None | |
| if mask_path: | |
| mask_data, _ = load_nifti_as_array(mask_path) | |
| # Get slices (middle by default, or max lesion if mask exists) | |
| mid_x, mid_y, mid_z = data.shape[0] // 2, data.shape[1] // 2, data.shape[2] // 2 | |
| if mask_data is not None and np.any(mask_data > 0): | |
| # Try to find a slice that intersects the lesion best | |
| # Simplified: use center of mass of lesion | |
| coords = np.argwhere(mask_data > 0) | |
| center = coords.mean(axis=0).astype(int) | |
| mid_x, mid_y, mid_z = center[0], center[1], center[2] | |
| # Create figure using OO API for thread safety | |
| fig = Figure(figsize=(15, 5)) | |
| fig.patch.set_facecolor("black") | |
| axes = fig.subplots(1, 3) | |
| # Axial (XY plane, Z fixed) - often needs rotation 90 deg | |
| # NIfTI data[x, y, z]. To display standard axial: | |
| # usually imshow(data[:, :, z].T, origin='lower') | |
| ax_slice = np.rot90(data[:, :, mid_z]) | |
| axes[0].imshow(ax_slice, cmap="gray") | |
| axes[0].set_title(f"Axial (z={mid_z})", color="white") | |
| if mask_data is not None: | |
| m_slice = np.rot90(mask_data[:, :, mid_z]) | |
| # Binarize at 0.5 threshold for visible overlay (consistent with compute_dice) | |
| m_slice_binary = (m_slice > 0.5).astype(float) | |
| axes[0].imshow( | |
| np.ma.masked_where(m_slice_binary == 0, m_slice_binary), # type: ignore[no-untyped-call] | |
| cmap="Reds", | |
| alpha=mask_alpha, | |
| vmin=0, | |
| vmax=1, | |
| ) | |
| # Coronal (XZ plane, Y fixed) | |
| cor_slice = np.rot90(data[:, mid_y, :]) | |
| axes[1].imshow(cor_slice, cmap="gray") | |
| axes[1].set_title(f"Coronal (y={mid_y})", color="white") | |
| if mask_data is not None: | |
| m_slice = np.rot90(mask_data[:, mid_y, :]) | |
| # Binarize at 0.5 threshold for visible overlay (consistent with compute_dice) | |
| m_slice_binary = (m_slice > 0.5).astype(float) | |
| axes[1].imshow( | |
| np.ma.masked_where(m_slice_binary == 0, m_slice_binary), # type: ignore[no-untyped-call] | |
| cmap="Reds", | |
| alpha=mask_alpha, | |
| vmin=0, | |
| vmax=1, | |
| ) | |
| # Sagittal (YZ plane, X fixed) | |
| sag_slice = np.rot90(data[mid_x, :, :]) | |
| axes[2].imshow(sag_slice, cmap="gray") | |
| axes[2].set_title(f"Sagittal (x={mid_x})", color="white") | |
| if mask_data is not None: | |
| m_slice = np.rot90(mask_data[mid_x, :, :]) | |
| # Binarize at 0.5 threshold for visible overlay (consistent with compute_dice) | |
| m_slice_binary = (m_slice > 0.5).astype(float) | |
| axes[2].imshow( | |
| np.ma.masked_where(m_slice_binary == 0, m_slice_binary), # type: ignore[no-untyped-call] | |
| cmap="Reds", | |
| alpha=mask_alpha, | |
| vmin=0, | |
| vmax=1, | |
| ) | |
| for ax in axes: | |
| ax.axis("off") | |
| fig.tight_layout() | |
| return fig | |
| def render_slice_comparison( | |
| dwi_path: Path, | |
| prediction_path: Path, | |
| ground_truth_path: Path | None = None, | |
| *, | |
| slice_idx: int | None = None, | |
| orientation: str = "axial", | |
| ) -> Figure: | |
| """ | |
| Render side-by-side comparison of DWI, prediction, and ground truth. | |
| Args: | |
| dwi_path: Path to DWI NIfTI | |
| prediction_path: Path to predicted mask NIfTI | |
| ground_truth_path: Optional path to ground truth mask | |
| slice_idx: Slice index (default: max lesion or middle) | |
| orientation: One of "axial", "coronal", "sagittal" | |
| Returns: | |
| Matplotlib figure with comparison view | |
| """ | |
| dwi_data, _ = load_nifti_as_array(dwi_path) | |
| pred_data, _ = load_nifti_as_array(prediction_path) | |
| gt_data = None | |
| if ground_truth_path: | |
| gt_data, _ = load_nifti_as_array(ground_truth_path) | |
| # Determine slice index | |
| if slice_idx is None: | |
| # Use prediction to find best slice | |
| slice_idx = get_slice_at_max_lesion(prediction_path, orientation) | |
| # Extract slices based on orientation | |
| # Assuming data[x, y, z] | |
| if orientation == "sagittal": | |
| # X fixed | |
| d_slice = np.rot90(dwi_data[slice_idx, :, :]) | |
| p_slice = np.rot90(pred_data[slice_idx, :, :]) | |
| g_slice = np.rot90(gt_data[slice_idx, :, :]) if gt_data is not None else None | |
| elif orientation == "coronal": | |
| # Y fixed | |
| d_slice = np.rot90(dwi_data[:, slice_idx, :]) | |
| p_slice = np.rot90(pred_data[:, slice_idx, :]) | |
| g_slice = np.rot90(gt_data[:, slice_idx, :]) if gt_data is not None else None | |
| else: | |
| # Z fixed (axial) | |
| d_slice = np.rot90(dwi_data[:, :, slice_idx]) | |
| p_slice = np.rot90(pred_data[:, :, slice_idx]) | |
| g_slice = np.rot90(gt_data[:, :, slice_idx]) if gt_data is not None else None | |
| # Plotting | |
| num_plots = 3 if gt_data is not None else 2 | |
| # Create figure using OO API for thread safety | |
| fig = Figure(figsize=(5 * num_plots, 5)) | |
| fig.patch.set_facecolor("black") | |
| axes = fig.subplots(1, num_plots) | |
| if num_plots == 2: | |
| axes = np.array(axes) # handle single case if needed, but subplots(1,2) returns array | |
| # 1. DWI | |
| axes[0].imshow(d_slice, cmap="gray") | |
| axes[0].set_title("DWI Input", color="white") | |
| # 2. Prediction | |
| # Binarize prediction at threshold 0.5 for visible overlay (Issue #23) | |
| # Model output may contain probability values (0.0-1.0) which render as | |
| # nearly-white in the "Reds" colormap. Binarizing ensures consistent | |
| # visualization matching how compute_dice() evaluates predictions. | |
| p_slice_binary = (p_slice > 0.5).astype(float) | |
| axes[1].imshow(d_slice, cmap="gray") | |
| axes[1].imshow( | |
| np.ma.masked_where(p_slice_binary == 0, p_slice_binary), # type: ignore[no-untyped-call] | |
| cmap="Reds", | |
| alpha=0.5, | |
| vmin=0, | |
| vmax=1, | |
| ) | |
| axes[1].set_title("Prediction", color="white") | |
| # 3. GT (if available) | |
| if gt_data is not None: | |
| axes[2].imshow(d_slice, cmap="gray") | |
| axes[2].imshow( | |
| np.ma.masked_where(g_slice == 0, g_slice), # type: ignore[no-untyped-call] | |
| cmap="Greens", | |
| alpha=0.5, | |
| vmin=0, | |
| vmax=1, | |
| ) | |
| axes[2].set_title("Ground Truth", color="white") | |
| for ax in axes: | |
| ax.axis("off") | |
| fig.tight_layout() | |
| return fig | |