VibecoderMcSwaggins's picture
feat: Gradio Custom Component for NiiVue (#29)
227ab66 unverified
"""Neuroimaging visualization for Gradio.
This module provides visualization components for neuroimaging data:
- Matplotlib-based 2D slice comparisons
- NIfTI URL helper for Custom Component
See:
- docs/specs/07-hf-spaces-deployment.md
- docs/specs/19-perf-base64-to-file-urls.md (Issue #19 optimization)
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from pathlib import Path
from matplotlib.figure import Figure
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.metrics import load_nifti_as_array
# Module-level logger named after this module, obtained from the project's
# get_logger helper. NOTE(review): not referenced by the functions visible
# here — confirm it is still needed.
logger = get_logger(__name__)
def nifti_to_gradio_url(nifti_path: Path) -> str:
    """
    Get Gradio file URL for a NIfTI file.

    Uses Gradio's built-in file serving instead of base64 encoding.
    This reduces payload size by ~33% and improves browser performance
    by avoiding large base64 strings in the DOM.

    The path is made absolute (via ``Path.resolve()``) and percent-encoded.
    File names containing spaces or URL-significant characters such as
    ``#``, ``?`` or ``%`` therefore reach the server intact, instead of
    being cut off as a fragment/query or mis-decoded. ``/`` is kept
    literal, so paths made only of letters, digits, ``/``, ``_``, ``.``,
    ``-`` and ``~`` produce the same URL as plain interpolation would.

    Args:
        nifti_path: Path to NIfTI file. Must be in an allowed path:
            - tempfile.gettempdir() (default for pipeline results)
            - Current working directory
            - Paths specified in allowed_paths during launch()

    Returns:
        Gradio file URL (e.g., /gradio_api/file=/tmp/.../dwi.nii.gz)

    Note:
        This replaces the deprecated nifti_to_data_url() function.
        See Issue #19 for performance analysis and benchmarks.

    References:
        - https://www.gradio.app/guides/file-access
        - https://niivue.com/docs/loading/
    """
    from urllib.parse import quote

    # Ensure we use absolute path for Gradio's file serving
    abs_path = nifti_path.resolve()
    # Gradio file URL format (standard since Gradio 4.x).
    # Files in tempfile.gettempdir() are allowed by default.
    # The server percent-decodes the path, so encoding here is lossless.
    return f"/gradio_api/file={quote(str(abs_path), safe='/')}"
def get_slice_at_max_lesion(
    mask_path: Path,
    orientation: str = "axial",
    *,
    threshold: float = 0.5,
) -> int:
    """
    Find the slice index with the largest lesion area.

    Useful for choosing the most informative slice to display.

    Args:
        mask_path: Path to lesion mask NIfTI.
        orientation: Slice orientation ("axial", "coronal", "sagittal").
            Any other value is treated as "axial".
        threshold: Voxels strictly greater than this value count as lesion.
            The default of 0.5 matches the binarization this module applies
            to mask overlays when rendering. The chosen slice is therefore
            the one where the drawn overlay is largest, even for probability
            maps. Pass 0.0 to reproduce the previous ``> 0`` behavior.

    Returns:
        Index of the slice with the most lesion voxels (the first such slice
        on ties), or the middle slice along that axis if no voxel exceeds
        ``threshold``.
    """
    data, _ = load_nifti_as_array(mask_path)

    # Array axis held fixed by each orientation, assuming the array is
    # indexed [x, y, z] with x=sagittal, y=coronal, z=axial (RAS+ style).
    # NOTE(review): nothing here reorients the data; confirm that
    # load_nifti_as_array returns this axis order.
    axis = {"sagittal": 0, "coronal": 1}.get(orientation, 2)

    # Lesion voxels per slice: sum the binarized mask over every other axis.
    other_axes = tuple(a for a in range(data.ndim) if a != axis)
    lesion_counts = np.sum(data > threshold, axis=other_axes)

    if not np.any(lesion_counts):
        # Empty mask (at this threshold): fall back to the middle slice.
        return int(data.shape[axis] // 2)
    return int(np.argmax(lesion_counts))
def render_3panel_view(
    nifti_path: Path,
    mask_path: Path | None = None,
    *,
    mask_alpha: float = 0.5,
) -> Figure:
    """
    Render axial/coronal/sagittal slices with optional mask overlay.

    Slice positions default to the middle of the volume along each axis.
    When a mask is given and has at least one voxel above 0.5, they are
    instead the (truncated) centre of mass of those voxels. The overlay is
    binarized at the same 0.5 threshold, so the view centre is computed
    from exactly the voxels that are drawn.

    Args:
        nifti_path: Path to base NIfTI volume (indexed as [x, y, z]).
        mask_path: Optional path to mask for overlay. Presumably expected to
            share the base volume's grid; this is not checked here.
        mask_alpha: Transparency of the red mask overlay.

    Returns:
        Matplotlib figure (black background) with 3-panel view.
    """
    # Runtime import: at module level this name is imported inside an
    # `if TYPE_CHECKING:` block, which does not bind it at runtime.
    from matplotlib.figure import Figure

    data, _ = load_nifti_as_array(nifti_path)
    mask_data = None
    if mask_path is not None:
        mask_data, _ = load_nifti_as_array(mask_path)

    # Slice index per array axis [x, y, z]; middle of the volume by default.
    centre = [size // 2 for size in data.shape[:3]]
    if mask_data is not None:
        lesion_voxels = np.argwhere(mask_data > 0.5)
        if lesion_voxels.size:
            # Centre of mass of the lesion, truncated to integer indices.
            # NOTE(review): for multi-focal or curved lesions this point can
            # fall on slices with no lesion; the largest-area slice (as in
            # get_slice_at_max_lesion) would avoid that.
            centre = [int(c) for c in lesion_voxels.mean(axis=0)[:3]]

    # (panel title, coordinate label, array axis held fixed) in display order.
    views = (
        ("Axial", "z", 2),
        ("Coronal", "y", 1),
        ("Sagittal", "x", 0),
    )

    # Create figure using OO API for thread safety
    fig = Figure(figsize=(15, 5))
    fig.patch.set_facecolor("black")
    axes = fig.subplots(1, 3)

    for ax, (label, coord, axis) in zip(axes, views):
        idx = centre[axis]
        # rot90 turns the [first, second] array plane into display orientation.
        ax.imshow(np.rot90(np.take(data, idx, axis=axis)), cmap="gray")
        ax.set_title(f"{label} ({coord}={idx})", color="white")
        if mask_data is not None:
            # Binarize at 0.5 threshold for visible overlay (consistent with compute_dice)
            m_slice = np.rot90(np.take(mask_data, idx, axis=axis))
            m_binary = (m_slice > 0.5).astype(float)
            ax.imshow(
                np.ma.masked_where(m_binary == 0, m_binary),  # type: ignore[no-untyped-call]
                cmap="Reds",
                alpha=mask_alpha,
                vmin=0,
                vmax=1,
            )
        ax.axis("off")

    fig.tight_layout()
    return fig
def render_slice_comparison(
    dwi_path: Path,
    prediction_path: Path,
    ground_truth_path: Path | None = None,
    *,
    slice_idx: int | None = None,
    orientation: str = "axial",
) -> Figure:
    """
    Render side-by-side comparison of DWI, prediction, and ground truth.

    Both masks are binarized at 0.5 before drawing (Issue #23). Model output
    may contain probability values (0.0-1.0), which would otherwise render
    as nearly-white in the colormap. Binarizing keeps the overlays
    consistent with how compute_dice() evaluates predictions, per the
    original comments. When ``slice_idx`` is None, the slice is the one
    with the most predicted voxels above 0.5, or the middle slice if there
    are none.

    Args:
        dwi_path: Path to DWI NIfTI
        prediction_path: Path to predicted mask NIfTI
        ground_truth_path: Optional path to ground truth mask
        slice_idx: Slice index (default: max lesion or middle)
        orientation: One of "axial", "coronal", "sagittal"; any other value
            is treated as "axial".

    Returns:
        Matplotlib figure (black background) with 2 panels, or 3 when
        ground truth is given.
    """
    # Runtime import: at module level this name is imported inside an
    # `if TYPE_CHECKING:` block, which does not bind it at runtime.
    from matplotlib.figure import Figure

    dwi_data, _ = load_nifti_as_array(dwi_path)
    pred_data, _ = load_nifti_as_array(prediction_path)
    gt_data = None
    if ground_truth_path is not None:
        gt_data, _ = load_nifti_as_array(ground_truth_path)

    # Array axis held fixed for the orientation, assuming data[x, y, z].
    axis = {"sagittal": 0, "coronal": 1}.get(orientation, 2)

    if slice_idx is None:
        # Pick the slice from the prediction already in memory, rather than
        # re-reading the file. Use the same 0.5 threshold as the overlay, so
        # the chosen slice is where the drawn prediction is largest.
        other_axes = tuple(a for a in range(pred_data.ndim) if a != axis)
        lesion_counts = np.sum(pred_data > 0.5, axis=other_axes)
        if np.any(lesion_counts):
            slice_idx = int(np.argmax(lesion_counts))
        else:
            slice_idx = int(pred_data.shape[axis] // 2)

    # np.take with an integer index drops `axis`, like data[idx, :, :] etc.
    d_slice = np.rot90(np.take(dwi_data, slice_idx, axis=axis))

    # (title, binarized overlay or None, colormap) for each panel, in order.
    panels: list[tuple[str, np.ndarray | None, str | None]] = [
        ("DWI Input", None, None),
        (
            "Prediction",
            (np.rot90(np.take(pred_data, slice_idx, axis=axis)) > 0.5).astype(float),
            "Reds",
        ),
    ]
    if gt_data is not None:
        panels.append(
            (
                "Ground Truth",
                (np.rot90(np.take(gt_data, slice_idx, axis=axis)) > 0.5).astype(float),
                "Greens",
            )
        )

    # Create figure using OO API for thread safety
    fig = Figure(figsize=(5 * len(panels), 5))
    fig.patch.set_facecolor("black")
    axes = fig.subplots(1, len(panels))

    for ax, (title, overlay, cmap) in zip(axes, panels):
        ax.imshow(d_slice, cmap="gray")
        if overlay is not None:
            ax.imshow(
                np.ma.masked_where(overlay == 0, overlay),  # type: ignore[no-untyped-call]
                cmap=cmap,
                alpha=0.5,
                vmin=0,
                vmax=1,
            )
        ax.set_title(title, color="white")
        ax.axis("off")

    fig.tight_layout()
    return fig