VibecoderMcSwaggins's picture
feat(phase-4): Gradio UI with NiiVue visualization (#5)
d77e99f unverified
raw
history blame
4.6 kB
"""Tests for viewer module."""
from __future__ import annotations
from typing import TYPE_CHECKING
import matplotlib
# Non-interactive backend for tests - must be before pyplot import
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from stroke_deepisles_demo.ui.viewer import (
create_niivue_html,
get_slice_at_max_lesion,
render_3panel_view,
render_slice_comparison,
)
if TYPE_CHECKING:
from pathlib import Path
class TestRender3PanelView:
    """Tests for render_3panel_view."""

    def test_returns_matplotlib_figure(self, synthetic_nifti_3d: Path) -> None:
        """Returns a matplotlib Figure object."""
        fig = render_3panel_view(synthetic_nifti_3d)
        try:
            assert isinstance(fig, Figure)
        finally:
            # Close even if the assertion fails so figures don't pile up
            # across the test session.
            plt.close(fig)

    def test_has_three_axes(self, synthetic_nifti_3d: Path) -> None:
        """Figure has 3 subplots (axial, coronal, sagittal)."""
        fig = render_3panel_view(synthetic_nifti_3d)
        try:
            assert len(fig.axes) == 3
        finally:
            plt.close(fig)

    def test_overlay_mask_when_provided(self, synthetic_nifti_3d: Path, temp_dir: Path) -> None:
        """Overlays mask when mask_path provided."""
        import nibabel as nib

        # Build the mask on the same voxel grid (shape + affine) as the volume,
        # instead of assuming the fixture is a 10x10x10 identity-affine image.
        volume = nib.load(synthetic_nifti_3d)  # type: ignore
        mask_data = np.zeros(volume.shape[:3], dtype=np.uint8)
        # Mark a small cube around the volume centre (voxels 4:6 on a 10-voxel axis).
        centre = tuple(slice(max(n // 2 - 1, 0), n // 2 + 1) for n in mask_data.shape)
        mask_data[centre] = 1
        mask_img = nib.Nifti1Image(mask_data, volume.affine)  # type: ignore
        mask_path = temp_dir / "mask.nii.gz"
        nib.save(mask_img, mask_path)  # type: ignore

        fig = render_3panel_view(synthetic_nifti_3d, mask_path=mask_path)
        try:
            # Rendering with an overlay must not raise and must still yield a Figure.
            assert isinstance(fig, Figure)
        finally:
            plt.close(fig)
class TestRenderSliceComparison:
    """Tests for render_slice_comparison."""

    def test_comparison_without_ground_truth(self, synthetic_nifti_3d: Path) -> None:
        """Works when ground truth is None."""
        fig = render_slice_comparison(
            synthetic_nifti_3d,
            synthetic_nifti_3d,  # Use same as prediction for test
            ground_truth_path=None,
        )
        try:
            assert isinstance(fig, Figure)
        finally:
            # Close even if the assertion fails so figures don't leak between tests.
            plt.close(fig)

    def test_comparison_with_ground_truth(self, synthetic_nifti_3d: Path) -> None:
        """Works when ground truth is provided."""
        fig = render_slice_comparison(
            synthetic_nifti_3d,
            synthetic_nifti_3d,
            ground_truth_path=synthetic_nifti_3d,
        )
        try:
            assert isinstance(fig, Figure)
        finally:
            plt.close(fig)
class TestGetSliceAtMaxLesion:
    """Tests for get_slice_at_max_lesion."""

    @staticmethod
    def _save_mask(mask_data: np.ndarray, mask_path: Path) -> Path:
        """Write ``mask_data`` as a NIfTI file with an identity affine; return its path."""
        import nibabel as nib

        mask_img = nib.Nifti1Image(mask_data, np.eye(4))  # type: ignore
        nib.save(mask_img, mask_path)  # type: ignore
        return mask_path

    def test_finds_slice_with_lesion(self, temp_dir: Path) -> None:
        """Returns slice index where lesion is largest."""
        mask_data = np.zeros((10, 10, 10), dtype=np.uint8)
        # Smaller lesions before AND after the target slice, so an implementation
        # returning the first/last lesion slice (rather than the largest) fails.
        mask_data[0:2, 0:2, 2] = 1  # 4 voxels
        mask_data[:, :, 7] = 1  # 100 voxels: full slice 7, the largest
        mask_data[0:3, 0:3, 9] = 1  # 9 voxels
        mask_path = self._save_mask(mask_data, temp_dir / "mask.nii.gz")

        slice_idx = get_slice_at_max_lesion(mask_path, orientation="axial")
        assert slice_idx == 7

    def test_returns_middle_for_empty_mask(self, temp_dir: Path) -> None:
        """Returns middle slice when mask is empty."""
        mask_data = np.zeros((10, 10, 20), dtype=np.uint8)
        mask_path = self._save_mask(mask_data, temp_dir / "mask.nii.gz")

        slice_idx = get_slice_at_max_lesion(mask_path, orientation="axial")
        assert slice_idx == 10  # Middle of 20
class TestCreateNiivueHtml:
    """Tests for create_niivue_html."""

    VOLUME_URL = "http://example.com/brain.nii.gz"
    MASK_URL = "http://example.com/mask.nii.gz"

    def test_includes_volume_url(self) -> None:
        """Generated HTML includes the volume URL."""
        html = create_niivue_html(self.VOLUME_URL)
        assert self.VOLUME_URL in html

    def test_includes_mask_when_provided(self) -> None:
        """Generated HTML includes mask URL when provided."""
        html = create_niivue_html(
            self.VOLUME_URL,
            mask_url=self.MASK_URL,
        )
        assert self.MASK_URL in html
        # Adding a mask overlay must not drop the base volume.
        assert self.VOLUME_URL in html

    def test_sets_height(self) -> None:
        """Generated HTML respects height parameter."""
        import re

        html = create_niivue_html(
            self.VOLUME_URL,
            height=600,
        )
        # Allow optional whitespace after the colon so the test does not
        # depend on the exact CSS formatting of the template.
        assert re.search(r"height:\s*600px", html) is not None