# stroke-deepisles-demo / tests/ui/test_viewer.py
# Committed by VibecoderMcSwaggins
# Commit 227ab66 (unverified): feat: Gradio Custom Component for NiiVue (#29)
"""Tests for viewer module."""
from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib

# Non-interactive backend for tests - must be before pyplot import.
# "Agg" is a non-GUI backend, so figures can be created without a display;
# selecting it only takes effect reliably if pyplot has not been imported yet.
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure

from stroke_deepisles_demo.ui.viewer import (
    get_slice_at_max_lesion,
    nifti_to_gradio_url,
    render_3panel_view,
    render_slice_comparison,
)

if TYPE_CHECKING:
    # Path is only referenced in annotations; `from __future__ import annotations`
    # keeps annotations unevaluated, so no runtime import is needed.
    from pathlib import Path
class TestRender3PanelView:
    """Tests for render_3panel_view."""

    def teardown_method(self) -> None:
        """Close all pyplot figures after each test.

        pytest runs this even when the test fails, so figures never leak into
        later tests (a trailing ``plt.close(fig)`` is skipped by a failing
        assertion).
        """
        plt.close("all")

    def test_returns_matplotlib_figure(self, synthetic_nifti_3d: Path) -> None:
        """Returns a matplotlib Figure object."""
        fig = render_3panel_view(synthetic_nifti_3d)
        assert isinstance(fig, Figure)

    def test_has_three_axes(self, synthetic_nifti_3d: Path) -> None:
        """Figure has 3 subplots (axial, coronal, sagittal)."""
        fig = render_3panel_view(synthetic_nifti_3d)
        assert len(fig.axes) == 3

    def test_overlay_mask_when_provided(self, synthetic_nifti_3d: Path, temp_dir: Path) -> None:
        """Overlays mask when mask_path provided."""
        import nibabel as nib

        # Build the mask on the image's own grid (shape and affine) so it always
        # lines up with the fixture volume instead of assuming a 10x10x10 image.
        image = nib.load(synthetic_nifti_3d)  # type: ignore
        shape = image.shape[:3]  # type: ignore
        mask_data = np.zeros(shape, dtype=np.uint8)
        # Small cube (up to 2 voxels per side) around the centre of the volume.
        centre = tuple(slice(max(n // 2 - 1, 0), n // 2 + 1) for n in shape)
        mask_data[centre] = 1
        mask_img = nib.Nifti1Image(mask_data, image.affine)  # type: ignore
        mask_path = temp_dir / "mask.nii.gz"
        nib.save(mask_img, mask_path)  # type: ignore

        fig = render_3panel_view(synthetic_nifti_3d, mask_path=mask_path)
        # Rendering with an overlay must succeed and still yield a Figure.
        assert isinstance(fig, Figure)
class TestRenderSliceComparison:
    """Tests for render_slice_comparison."""

    def teardown_method(self) -> None:
        """Close all pyplot figures after each test, including failing ones."""
        plt.close("all")

    def test_comparison_without_ground_truth(self, synthetic_nifti_3d: Path) -> None:
        """Works when ground truth is None."""
        fig = render_slice_comparison(
            synthetic_nifti_3d,
            synthetic_nifti_3d,  # Use same as prediction for test
            ground_truth_path=None,
        )
        assert isinstance(fig, Figure)

    def test_comparison_with_ground_truth(self, synthetic_nifti_3d: Path) -> None:
        """Works when ground truth is provided."""
        fig = render_slice_comparison(
            synthetic_nifti_3d,
            synthetic_nifti_3d,
            ground_truth_path=synthetic_nifti_3d,
        )
        assert isinstance(fig, Figure)
class TestGetSliceAtMaxLesion:
    """Tests for get_slice_at_max_lesion."""

    def test_finds_slice_with_lesion(self, temp_dir: Path) -> None:
        """Returns slice index where lesion is largest."""
        import nibabel as nib

        mask_data = np.zeros((10, 10, 10), dtype=np.uint8)
        mask_data[:, :, 7] = 1  # Full slice 7 is lesion (100 voxels)
        # Single-voxel lesions before and after slice 7: the expected index is
        # only produced by choosing the slice with the most lesion voxels, not
        # merely the first or last non-empty slice.
        mask_data[0, 0, 2] = 1
        mask_data[0, 0, 8] = 1
        mask_img = nib.Nifti1Image(mask_data, np.eye(4))  # type: ignore
        mask_path = temp_dir / "mask.nii.gz"
        nib.save(mask_img, mask_path)  # type: ignore

        slice_idx = get_slice_at_max_lesion(mask_path, orientation="axial")
        assert slice_idx == 7

    def test_returns_middle_for_empty_mask(self, temp_dir: Path) -> None:
        """Returns middle slice when mask is empty."""
        import nibabel as nib

        mask_data = np.zeros((10, 10, 20), dtype=np.uint8)
        mask_img = nib.Nifti1Image(mask_data, np.eye(4))  # type: ignore
        mask_path = temp_dir / "mask.nii.gz"
        nib.save(mask_img, mask_path)  # type: ignore

        slice_idx = get_slice_at_max_lesion(mask_path, orientation="axial")
        assert slice_idx == 10  # Middle of 20
class TestNiftiToGradioUrl:
    """Tests for nifti_to_gradio_url (Issue #19 optimization)."""

    # Prefix that every URL returned by nifti_to_gradio_url is expected to carry.
    _PREFIX = "/gradio_api/file="

    def test_returns_gradio_api_format(self, synthetic_nifti_3d: Path) -> None:
        """Returns URL in Gradio API format."""
        url = nifti_to_gradio_url(synthetic_nifti_3d)
        assert url.startswith(self._PREFIX)

    def test_uses_absolute_path(self, synthetic_nifti_3d: Path) -> None:
        """URL contains absolute path to file."""
        url = nifti_to_gradio_url(synthetic_nifti_3d)
        assert url.startswith(self._PREFIX)
        # Slice off only the leading prefix; str.replace() would also drop any
        # later occurrence of the same text inside the file path itself.
        path_part = url[len(self._PREFIX) :]
        assert path_part.startswith("/")  # Absolute path
        assert "synthetic.nii.gz" in path_part

    def test_preserves_file_extension(self, synthetic_nifti_3d: Path) -> None:
        """URL preserves .nii.gz extension."""
        url = nifti_to_gradio_url(synthetic_nifti_3d)
        assert url.endswith(".nii.gz")

    def test_no_base64_encoding(self, synthetic_nifti_3d: Path) -> None:
        """URL does not contain base64-encoded data (Issue #19 requirement)."""
        url = nifti_to_gradio_url(synthetic_nifti_3d)
        # Base64 data URLs start with "data:" and contain ";base64,"
        assert not url.startswith("data:")
        assert ";base64," not in url
class TestRenderSliceComparisonProbabilityMask:
    """Tests for render_slice_comparison with probability masks (Issue #23).

    This test class verifies that probability-valued prediction masks
    are rendered visibly. The bug occurs when:
    - Ground truth is binary (0 or 1) → renders as visible green
    - Prediction is probability (0.1-0.5) → renders as nearly-invisible white

    See: docs/specs/23-slice-comparison-overlay-bug.md
    """

    def teardown_method(self) -> None:
        """Close all pyplot figures after each test, including failing ones."""
        plt.close("all")

    @staticmethod
    def _has_unmasked_pixels(image_data: np.ndarray | None) -> bool:
        """Return True if ``image_data`` has at least one unmasked pixel.

        ``np.ma.getmaskarray`` yields an all-False mask for plain ndarrays and
        for masked arrays without a mask, so those count as visible unless they
        are empty. ``None`` (no image data at all) counts as not visible.
        """
        if image_data is None:
            return False
        return not bool(np.ma.getmaskarray(image_data).all())

    def test_probability_mask_has_visible_overlay(
        self,
        synthetic_nifti_3d: Path,
        synthetic_probability_mask: Path,
    ) -> None:
        """
        Probability mask should produce visible overlay in rendering.

        Guards against Issue #23, where low probability values (e.g., 0.3)
        rendered as nearly-white in the "Reds" colormap and were invisible.
        This test checks that the prediction panel draws an overlay image with
        non-zero alpha and at least one unmasked pixel; it does not inspect
        the colormap itself.
        """
        fig = render_slice_comparison(
            synthetic_nifti_3d,
            synthetic_probability_mask,  # Probability values 0.3, 0.7
            ground_truth_path=None,
        )

        # Get the prediction axis (index 1)
        images = fig.axes[1].get_images()
        # The axis should have at least 2 images (DWI background + overlay)
        assert len(images) >= 2, "Prediction panel should have overlay image"

        overlay = images[1]
        alpha = overlay.get_alpha()
        assert alpha is None or alpha > 0  # None means default alpha (1.0)
        assert self._has_unmasked_pixels(
            overlay.get_array()
        ), "Probability mask overlay should have visible pixels"

    def test_binary_vs_probability_mask_comparison(
        self,
        synthetic_nifti_3d: Path,
        synthetic_binary_mask: Path,
        synthetic_probability_mask: Path,
    ) -> None:
        """
        Both binary and probability masks should render visible overlays.

        This is the core test for Issue #23. If the probability mask renders
        invisibly while the binary mask renders visibly, the bug is confirmed.
        """
        # Render with binary mask (expected to work)
        fig_binary = render_slice_comparison(
            synthetic_nifti_3d,
            synthetic_binary_mask,
            ground_truth_path=None,
        )
        # Render with probability mask (may be invisible - the bug)
        fig_prob = render_slice_comparison(
            synthetic_nifti_3d,
            synthetic_probability_mask,
            ground_truth_path=None,
        )

        # Prediction panel (index 1) of each figure: background + overlay.
        binary_images = fig_binary.axes[1].get_images()
        prob_images = fig_prob.axes[1].get_images()
        assert len(binary_images) >= 2, "Binary prediction panel should have overlay image"
        assert len(prob_images) >= 2, "Probability prediction panel should have overlay image"

        # Both should have non-masked (visible) pixels
        assert self._has_unmasked_pixels(
            binary_images[1].get_array()
        ), "Binary mask overlay should have visible pixels"
        assert self._has_unmasked_pixels(
            prob_images[1].get_array()
        ), "Probability mask overlay should have visible pixels"