fix(ui): prediction overlay invisible, race condition, thread safety (#23) (#23)
Primary fix: Prediction mask probability values (0.0-0.3) rendered nearly-white
in "Reds" colormap. Now binarized at 0.5 threshold for visible red overlay.
Additional fixes discovered during audit:
- Race condition: replaced global _previous_results_dir with gr.State
- compute_volume_ml: added threshold=0.5 for consistent binarization
- render_3panel_view: wired into UI with Tabs layout (Interactive 3D / Static Report)
- Matplotlib thread safety: refactored from pyplot to OO API (Figure())
All 136 tests pass. Lint and type checks clean.
- docs/specs/23-slice-comparison-overlay-bug.md +287 -0
- src/stroke_deepisles_demo/metrics.py +9 -1
- src/stroke_deepisles_demo/ui/app.py +66 -27
- src/stroke_deepisles_demo/ui/components.py +16 -11
- src/stroke_deepisles_demo/ui/viewer.py +25 -11
- tests/conftest.py +41 -0
- tests/ui/test_app.py +8 -2
- tests/ui/test_viewer.py +91 -0
docs/specs/23-slice-comparison-overlay-bug.md
ADDED
|
@@ -0,0 +1,287 @@
# Bug Investigation: Slice Comparison Prediction Overlay Not Visible

**Issue**: Prediction overlay is invisible in slice comparison while ground truth overlay is visible

**Date**: 2025-12-09
**Branch**: `debug/slice-comparison-prediction-overlay`

---

## Observed Behavior

In the Gradio UI "Slice Comparison" tab:
- **DWI Input** (left panel): Shows grayscale brain scan ✓
- **Prediction** (middle panel): Shows grayscale brain scan **without any visible overlay** ✗
- **Ground Truth** (right panel): Shows grayscale brain scan **with green overlay** ✓

## Expected Behavior

The Prediction panel should show a **red overlay** on the predicted lesion area, similar to how Ground Truth shows a green overlay.

---

## Code Analysis

### Visualization Code (`viewer.py:261-268`)

```python
# Prediction panel
axes[1].imshow(d_slice, cmap="gray")
axes[1].imshow(
    np.ma.masked_where(p_slice == 0, p_slice),
    cmap="Reds",
    alpha=0.5,
    vmin=0,
    vmax=1,
)
```

### Ground Truth Code (`viewer.py:273-280`)

```python
# Ground Truth panel
axes[2].imshow(d_slice, cmap="gray")
axes[2].imshow(
    np.ma.masked_where(g_slice == 0, g_slice),
    cmap="Greens",
    alpha=0.5,
    vmin=0,
    vmax=1,
)
```

The code is **structurally identical**. The only difference is:
- Prediction: `cmap="Reds"`
- Ground Truth: `cmap="Greens"`

---

## Hypothesis

### Primary Hypothesis: Probability vs Binary Mask Values

| Mask Type | Typical Values | Colormap Rendering | Visibility |
|-----------|----------------|-------------------|------------|
| Ground Truth | Binary (0 or 1) | 1.0 → **Dark Green** | High ✓ |
| Prediction | Probabilities (0.0-0.3) | 0.1 → **Nearly White** | None ✗ |

**Why this matters:**

1. Matplotlib's **"Reds" colormap** goes from white (0) → red (1)
2. With `vmin=0, vmax=1`:
   - A value of `0.05` maps to 5% of the colormap = nearly white
   - A value of `1.0` maps to 100% of the colormap = red
3. With `alpha=0.5` over a grayscale background, nearly-white overlays are **invisible**

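The colormap behavior described in the list above can be checked directly. This is a minimal sketch, assuming matplotlib 3.5 or newer (for the `matplotlib.colormaps` registry); the sample values `0.05` and `1.0` are the ones used in the list, and the variable names are illustrative only.

```python
import matplotlib

# Look up the same colormap the prediction overlay uses.
reds = matplotlib.colormaps["Reds"]

# With vmin=0 and vmax=1, a data value maps straight onto the colormap position.
low_rgba = reds(0.05)   # typical low-probability voxel
high_rgba = reds(1.0)   # binarized lesion voxel

# All three RGB channels of the low value sit close to 1.0 (near white),
# while the high value has almost no green or blue left (dark red).
print("0.05 ->", tuple(round(c, 3) for c in low_rgba[:3]))
print("1.00 ->", tuple(round(c, 3) for c in high_rgba[:3]))
```

Comparing the two printed tuples shows why a probability map with small values produces a pale overlay while a binary mask produces a saturated one.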
**Evidence:**
- DeepISLES SEALS model may output probability maps, not binary masks
- The `compute_dice` function in `metrics.py` applies a `threshold=0.5` to binarize predictions
- The visualization does **not** apply any thresholding before display

### Alternative Hypotheses

1. **Empty slice**: Prediction mask is all zeros at the selected slice (unlikely given the slice selection logic uses `get_slice_at_max_lesion(prediction_path)`)

2. **Data type issue**: Float comparison `p_slice == 0` may fail for float32 arrays (unlikely - works for ground truth)

3. **File path mismatch**: Wrong file being loaded as prediction (need to verify)

---

## Diagnostic Steps

### 1. Check Prediction Mask Values

```python
import nibabel as nib
import numpy as np

# Load a prediction mask from a recent run
pred = nib.load("/path/to/prediction.nii.gz").get_fdata()
print(f"Shape: {pred.shape}")
print(f"Dtype: {pred.dtype}")
print(f"Min: {pred.min()}, Max: {pred.max()}")
print(f"Unique values: {np.unique(pred)[:20]}")  # First 20 unique values
print(f"Non-zero count: {np.count_nonzero(pred)}")
print(f"Values > 0.5: {np.count_nonzero(pred > 0.5)}")
```

### 2. Check Ground Truth Mask Values

```python
gt = nib.load("/path/to/ground_truth.nii.gz").get_fdata()
print(f"Shape: {gt.shape}")
print(f"Dtype: {gt.dtype}")
print(f"Min: {gt.min()}, Max: {gt.max()}")
print(f"Unique values: {np.unique(gt)}")
```

### 3. Visual Comparison

```python
# Plot histogram of values
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2)
axes[0].hist(pred[pred > 0].flatten(), bins=50)
axes[0].set_title("Prediction non-zero values")
axes[1].hist(gt[gt > 0].flatten(), bins=50)
axes[1].set_title("Ground Truth non-zero values")
plt.savefig("mask_histograms.png")
```

---

## Proposed Fix

### Option A: Binarize Prediction Before Display (Recommended)

```python
# In render_slice_comparison, before creating overlay:
p_slice_binary = (p_slice > 0.5).astype(float)

axes[1].imshow(
    np.ma.masked_where(p_slice_binary == 0, p_slice_binary),
    cmap="Reds",
    alpha=0.5,
    vmin=0,
    vmax=1,
)
```

**Pros:**
- Consistent with how `compute_dice` treats predictions
- Clear visualization of model decision boundary
- Matches clinical interpretation (lesion vs not-lesion)

**Cons:**
- Loses probability information in visualization

### Option B: Dynamic Normalization

```python
# Normalize to actual value range instead of fixed 0-1
p_max = p_slice.max() if p_slice.max() > 0 else 1.0
axes[1].imshow(
    np.ma.masked_where(p_slice == 0, p_slice),
    cmap="Reds",
    alpha=0.5,
    vmin=0,
    vmax=p_max,
)
```

**Pros:**
- Shows probability information
- Works regardless of value range

**Cons:**
- Inconsistent intensity across cases
- Low-confidence predictions still appear bright (misleading)

### Option C: Threshold-Based Masking

```python
# Only show values above a threshold
threshold = 0.5
axes[1].imshow(
    np.ma.masked_where(p_slice < threshold, p_slice),
    cmap="Reds",
    alpha=0.5,
    vmin=threshold,
    vmax=1.0,
)
```

**Pros:**
- Only shows confident predictions
- Good dynamic range for visible values

**Cons:**
- May hide uncertain but potentially relevant areas

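To make the difference between the three options concrete, the sketch below applies each masking rule to one small synthetic slice and counts how many pixels would remain visible. The array values are made up for illustration and are not taken from any real prediction; only NumPy is assumed.

```python
import numpy as np

# Synthetic 3x3 "probability slice" (hypothetical values, illustration only).
p_slice = np.array(
    [
        [0.0, 0.1, 0.2],
        [0.3, 0.6, 0.9],
        [0.0, 0.05, 0.7],
    ]
)

# Option A: binarize at 0.5, then hide the background.
p_slice_binary = (p_slice > 0.5).astype(float)
option_a = np.ma.masked_where(p_slice_binary == 0, p_slice_binary)

# Option B: keep raw probabilities, hide only exact zeros, scale to the data maximum.
option_b = np.ma.masked_where(p_slice == 0, p_slice)
p_max = p_slice.max() if p_slice.max() > 0 else 1.0

# Option C: keep raw probabilities, hide everything below the threshold.
threshold = 0.5
option_c = np.ma.masked_where(p_slice < threshold, p_slice)

# MaskedArray.count() returns the number of unmasked (visible) elements.
print("Option A visible pixels:", option_a.count())
print("Option B visible pixels:", option_b.count())
print("Option C visible pixels:", option_c.count())
print("Option B vmax:", p_max)
```

On this slice, Options A and C show the same three pixels, while Option B shows every non-zero pixel. One subtle difference as the snippets are written: a value of exactly 0.5 is hidden by Option A (`> 0.5`) but shown by Option C (masked only when `< threshold`).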
---

## Recommendation

**Implement Option A (Binarize)** because:

1. It matches the clinical use case (segmentation → binary decision)
2. It's consistent with `compute_dice` threshold behavior
3. It provides clear, interpretable visualization
4. The raw probability map can still be viewed in NiiVue if needed

---

## Dependencies

| Package | Version | Relevant |
|---------|---------|----------|
| gradio | >=6.0.0 | Unlikely cause (renders matplotlib figure correctly) |
| matplotlib | >=3.8.0 | Colormap behavior is standard |
| numpy | >=1.26.0,<2.0.0 | Float comparison works correctly |
| nibabel | >=5.2.0 | Loads data correctly |

---

## Resolution

**Status**: FIXED (2025-12-09)
**Branch**: `debug/slice-comparison-prediction-overlay`

### Changes Made

**Primary Fix (Issue #23):**

1. **`viewer.py:270-275`**: Added binarization of prediction mask in `render_slice_comparison`:
   ```python
   # Binarize prediction at threshold 0.5 for visible overlay (Issue #23)
   p_slice_binary = (p_slice > 0.5).astype(float)
   ```

2. **`viewer.py:156-164`**: Added binarization in `render_3panel_view` for consistency

3. **`tests/conftest.py`**: Added `synthetic_probability_mask` and `synthetic_binary_mask` fixtures

4. **`tests/ui/test_viewer.py`**: Added `TestRenderSliceComparisonProbabilityMask` test class

**Additional Fixes (Found During Audit):**

5. **Race Condition (P2)**: Replaced global `_previous_results_dir` with `gr.State` for per-session thread-safe cleanup tracking

6. **Inconsistent Threshold in compute_volume_ml**: Added `threshold=0.5` parameter for consistent binarization

7. **render_3panel_view Wired Into UI**:
   - Added `gr.Tabs` layout with "Interactive 3D" and "Static Report" tabs
   - `render_3panel_view` now displayed in "Static Report" alongside slice comparison
   - Provides WebGL2 fallback via static matplotlib figures

8. **Thread-Safe Matplotlib**: Refactored from `pyplot` API to Object-Oriented API (`Figure()`) for multi-user safety

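The thread-safety change in item 8 comes down to creating each figure as a standalone object instead of going through `pyplot`'s shared global figure registry. A minimal sketch of that pattern follows, assuming only matplotlib and NumPy; the function name `render_png` and the random image data are hypothetical and are not part of `viewer.py`.

```python
from io import BytesIO

import numpy as np
from matplotlib.figure import Figure


def render_png(seed: int) -> bytes:
    """Render three random grayscale panels without touching pyplot state."""
    rng = np.random.default_rng(seed)

    # Figure() is not registered with pyplot, so concurrent requests
    # cannot pick up or close each other's "current figure".
    fig = Figure(figsize=(6, 2))
    fig.patch.set_facecolor("black")
    axes = fig.subplots(1, 3)

    for ax in axes:
        ax.imshow(rng.random((16, 16)), cmap="gray")
        ax.axis("off")

    # Saving works directly on the Figure object.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    return buf.getvalue()


png = render_png(0)
print(len(png), "bytes")
```

Because nothing is registered globally, there is also no `plt.close()` call to forget; the figure is garbage-collected once the caller drops its reference.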
### Verification

- All 136 tests pass
- Lint (ruff) passes
- Type check (mypy) passes

## Files Modified

| File | Changes |
|------|---------|
| `src/stroke_deepisles_demo/ui/viewer.py` | OO matplotlib API, binarization in both render functions |
| `src/stroke_deepisles_demo/ui/app.py` | gr.State, render_3panel_view integration, volume_ml |
| `src/stroke_deepisles_demo/ui/components.py` | Tabs layout (Interactive 3D / Static Report) |
| `src/stroke_deepisles_demo/metrics.py` | threshold parameter for compute_volume_ml |
| `tests/conftest.py` | New probability/binary mask fixtures |
| `tests/ui/test_viewer.py` | Probability mask tests |
| `tests/ui/test_app.py` | Updated for new return signature |

## Next Steps

1. [x] Run diagnostic script to confirm hypothesis
2. [x] Implement fix (Option A - binarize)
3. [x] Add test case for probability-valued masks
4. [x] Wire render_3panel_view into UI with tabs
5. [x] Fix race condition with gr.State
6. [x] Make matplotlib thread-safe with OO API
7. [ ] Verify fix in local Gradio app (manual testing recommended)
8. [ ] Create PR and merge to main
src/stroke_deepisles_demo/metrics.py
CHANGED
|
@@ -91,6 +91,8 @@ def compute_dice(
|
|
| 91 |
def compute_volume_ml(
|
| 92 |
mask: Path | NDArray[np.floating[Any]],
|
| 93 |
voxel_size_mm: tuple[float, float, float] | None = None,
|
|
|
|
|
|
|
| 94 |
) -> float:
|
| 95 |
"""
|
| 96 |
Compute lesion volume in milliliters.
|
|
@@ -98,9 +100,14 @@ def compute_volume_ml(
|
|
| 98 |
Args:
|
| 99 |
mask: Path to NIfTI file or numpy array
|
| 100 |
voxel_size_mm: Voxel dimensions in mm (read from NIfTI if None)
|
|
|
|
| 101 |
|
| 102 |
Returns:
|
| 103 |
Volume in milliliters (mL)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
"""
|
| 105 |
if isinstance(mask, Path):
|
| 106 |
data, loaded_zooms = load_nifti_as_array(mask)
|
|
@@ -110,7 +117,8 @@ def compute_volume_ml(
|
|
| 110 |
# Default to 1mm isotropic if not provided for array
|
| 111 |
voxel_dims = voxel_size_mm if voxel_size_mm is not None else (1.0, 1.0, 1.0)
|
| 112 |
|
| 113 |
-
|
|
|
|
| 114 |
voxel_vol_mm3 = math.prod(voxel_dims)
|
| 115 |
|
| 116 |
return float(volume_voxels * voxel_vol_mm3 / 1000.0) # mm3 -> mL
|
|
|
|
| 91 |
def compute_volume_ml(
|
| 92 |
mask: Path | NDArray[np.floating[Any]],
|
| 93 |
voxel_size_mm: tuple[float, float, float] | None = None,
|
| 94 |
+
*,
|
| 95 |
+
threshold: float = 0.5,
|
| 96 |
) -> float:
|
| 97 |
"""
|
| 98 |
Compute lesion volume in milliliters.
|
|
|
|
| 100 |
Args:
|
| 101 |
mask: Path to NIfTI file or numpy array
|
| 102 |
voxel_size_mm: Voxel dimensions in mm (read from NIfTI if None)
|
| 103 |
+
threshold: Threshold for binarization (default 0.5 for consistency with compute_dice)
|
| 104 |
|
| 105 |
Returns:
|
| 106 |
Volume in milliliters (mL)
|
| 107 |
+
|
| 108 |
+
Note:
|
| 109 |
+
Uses the same default threshold (0.5) as compute_dice for consistency.
|
| 110 |
+
This ensures the volume measurement matches the clinical segmentation decision boundary.
|
| 111 |
"""
|
| 112 |
if isinstance(mask, Path):
|
| 113 |
data, loaded_zooms = load_nifti_as_array(mask)
|
|
|
|
| 117 |
# Default to 1mm isotropic if not provided for array
|
| 118 |
voxel_dims = voxel_size_mm if voxel_size_mm is not None else (1.0, 1.0, 1.0)
|
| 119 |
|
| 120 |
+
# Binarize at threshold for consistent measurement with compute_dice
|
| 121 |
+
volume_voxels = np.sum(data > threshold)
|
| 122 |
voxel_vol_mm3 = math.prod(voxel_dims)
|
| 123 |
|
| 124 |
return float(volume_voxels * voxel_vol_mm3 / 1000.0) # mm3 -> mL
|
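The effect of the new `threshold` parameter can be illustrated on a synthetic array. This is a simplified sketch of the arithmetic only (count voxels above the threshold, multiply by voxel volume, convert mm³ to mL); the helper name `volume_ml_sketch` is hypothetical, it accepts arrays only, and the non-zero count at the end is shown purely as a contrast, not as a claim about what the removed line did.

```python
import math

import numpy as np


def volume_ml_sketch(
    data: np.ndarray,
    voxel_size_mm: tuple[float, float, float] = (1.0, 1.0, 1.0),
    *,
    threshold: float = 0.5,
) -> float:
    """Volume in mL of all voxels whose value exceeds the threshold."""
    volume_voxels = int(np.sum(data > threshold))
    voxel_vol_mm3 = math.prod(voxel_size_mm)
    return float(volume_voxels * voxel_vol_mm3 / 1000.0)  # mm3 -> mL


# 10x10x10 probability map: half confident (0.9), half low-probability (0.3).
prob = np.full((10, 10, 10), 0.3)
prob[:5] = 0.9

thresholded = volume_ml_sketch(prob, (2.0, 2.0, 2.0))
# Contrast: counting every non-zero voxel would include the 0.3 region too.
nonzero = float(np.count_nonzero(prob) * 8.0 / 1000.0)

print(thresholded, nonzero)
```

With 2 mm isotropic voxels, the 500 confident voxels give 4.0 mL, whereas counting all non-zero voxels would give 8.0 mL, which is why the volume and Dice computations need to agree on one binarization rule.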
src/stroke_deepisles_demo/ui/app.py
CHANGED
|
@@ -3,13 +3,15 @@
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import shutil
|
| 6 |
-
from
|
|
|
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
from matplotlib.figure import Figure # noqa: TC002
|
| 10 |
|
| 11 |
from stroke_deepisles_demo.core.logging import get_logger
|
| 12 |
from stroke_deepisles_demo.data import list_case_ids
|
|
|
|
| 13 |
from stroke_deepisles_demo.pipeline import run_pipeline_on_case
|
| 14 |
from stroke_deepisles_demo.ui.components import (
|
| 15 |
create_case_selector,
|
|
@@ -20,17 +22,12 @@ from stroke_deepisles_demo.ui.viewer import (
|
|
| 20 |
NIIVUE_UPDATE_JS,
|
| 21 |
create_niivue_html,
|
| 22 |
nifti_to_gradio_url,
|
|
|
|
| 23 |
render_slice_comparison,
|
| 24 |
)
|
| 25 |
|
| 26 |
-
if TYPE_CHECKING:
|
| 27 |
-
from pathlib import Path
|
| 28 |
-
|
| 29 |
logger = get_logger(__name__)
|
| 30 |
|
| 31 |
-
# Shared output directory for UI results (cleaned up between runs to prevent disk accumulation)
|
| 32 |
-
_previous_results_dir: Path | None = None
|
| 33 |
-
|
| 34 |
|
| 35 |
def initialize_case_selector() -> gr.Dropdown:
|
| 36 |
"""
|
|
@@ -57,9 +54,26 @@ def initialize_case_selector() -> gr.Dropdown:
|
|
| 57 |
return gr.Dropdown(choices=[], info=f"Error loading data: {e!s}")
|
| 58 |
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
def run_segmentation(
|
| 61 |
-
case_id: str,
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
| 63 |
"""
|
| 64 |
Run segmentation and return results for display.
|
| 65 |
|
|
@@ -67,30 +81,26 @@ def run_segmentation(
|
|
| 67 |
case_id: Selected case identifier
|
| 68 |
fast_mode: Whether to use fast mode (SEALS)
|
| 69 |
show_ground_truth: Whether to show ground truth in plots
|
|
|
|
| 70 |
|
| 71 |
Returns:
|
| 72 |
-
Tuple of (niivue_html, slice_fig, metrics_dict, download_path, status_msg)
|
|
|
|
| 73 |
"""
|
| 74 |
if not case_id:
|
| 75 |
return (
|
| 76 |
"",
|
| 77 |
None,
|
|
|
|
| 78 |
{},
|
| 79 |
None,
|
| 80 |
"Please select a case first.",
|
|
|
|
| 81 |
)
|
| 82 |
|
| 83 |
try:
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
# Clean up previous results to prevent disk accumulation on HF Spaces
|
| 87 |
-
if _previous_results_dir is not None and _previous_results_dir.exists():
|
| 88 |
-
try:
|
| 89 |
-
shutil.rmtree(_previous_results_dir)
|
| 90 |
-
logger.debug("Cleaned up previous results: %s", _previous_results_dir)
|
| 91 |
-
except OSError as e:
|
| 92 |
-
# Log but don't fail - cleanup is best-effort
|
| 93 |
-
logger.warning("Failed to cleanup %s: %s", _previous_results_dir, e)
|
| 94 |
|
| 95 |
logger.info("Running segmentation for %s", case_id)
|
| 96 |
result = run_pipeline_on_case(
|
|
@@ -100,9 +110,6 @@ def run_segmentation(
|
|
| 100 |
cleanup_staging=True,
|
| 101 |
)
|
| 102 |
|
| 103 |
-
# Track results_dir for cleanup on next run
|
| 104 |
-
_previous_results_dir = result.results_dir
|
| 105 |
-
|
| 106 |
# 1. NiiVue Visualization
|
| 107 |
# Use Gradio's file serving (Issue #19 optimization)
|
| 108 |
# This eliminates ~65MB base64 payloads, improving load times and browser memory
|
|
@@ -122,8 +129,10 @@ def run_segmentation(
|
|
| 122 |
height=500,
|
| 123 |
)
|
| 124 |
|
| 125 |
-
# 2.
|
| 126 |
gt_path = result.ground_truth if show_ground_truth else None
|
|
|
|
|
|
|
| 127 |
slice_fig = render_slice_comparison(
|
| 128 |
dwi_path=dwi_path,
|
| 129 |
prediction_path=result.prediction_mask,
|
|
@@ -131,10 +140,24 @@ def run_segmentation(
|
|
| 131 |
orientation="axial",
|
| 132 |
)
|
| 133 |
|
| 134 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
metrics = {
|
| 136 |
"case_id": result.case_id,
|
| 137 |
"dice_score": result.dice_score,
|
|
|
|
| 138 |
"elapsed_seconds": round(result.elapsed_seconds, 2),
|
| 139 |
"model": "SEALS (Fast)" if fast_mode else "Ensemble",
|
| 140 |
}
|
|
@@ -148,11 +171,20 @@ def run_segmentation(
|
|
| 148 |
else "Success!"
|
| 149 |
)
|
| 150 |
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
except Exception as e:
|
| 154 |
logger.exception("Error running segmentation")
|
| 155 |
-
return "", None, {}, None, f"Error: {e!s}"
|
| 156 |
|
| 157 |
|
| 158 |
def create_app() -> gr.Blocks:
|
|
@@ -165,6 +197,10 @@ def create_app() -> gr.Blocks:
|
|
| 165 |
with gr.Blocks(
|
| 166 |
title="Stroke Lesion Segmentation Demo",
|
| 167 |
) as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
# Header
|
| 169 |
gr.Markdown("""
|
| 170 |
# Stroke Lesion Segmentation Demo
|
|
@@ -197,13 +233,16 @@ def create_app() -> gr.Blocks:
|
|
| 197 |
case_selector,
|
| 198 |
settings["fast_mode"],
|
| 199 |
settings["show_ground_truth"],
|
|
|
|
| 200 |
],
|
| 201 |
outputs=[
|
| 202 |
results["niivue_viewer"],
|
| 203 |
results["slice_plot"],
|
|
|
|
| 204 |
results["metrics"],
|
| 205 |
results["download"],
|
| 206 |
status,
|
|
|
|
| 207 |
],
|
| 208 |
).then(
|
| 209 |
fn=None, # Explicitly None to run JS only
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import shutil
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
|
| 9 |
import gradio as gr
|
| 10 |
from matplotlib.figure import Figure # noqa: TC002
|
| 11 |
|
| 12 |
from stroke_deepisles_demo.core.logging import get_logger
|
| 13 |
from stroke_deepisles_demo.data import list_case_ids
|
| 14 |
+
from stroke_deepisles_demo.metrics import compute_volume_ml
|
| 15 |
from stroke_deepisles_demo.pipeline import run_pipeline_on_case
|
| 16 |
from stroke_deepisles_demo.ui.components import (
|
| 17 |
create_case_selector,
|
|
|
|
| 22 |
NIIVUE_UPDATE_JS,
|
| 23 |
create_niivue_html,
|
| 24 |
nifti_to_gradio_url,
|
| 25 |
+
render_3panel_view,
|
| 26 |
render_slice_comparison,
|
| 27 |
)
|
| 28 |
|
|
|
|
|
|
|
|
|
|
| 29 |
logger = get_logger(__name__)
|
| 30 |
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
def initialize_case_selector() -> gr.Dropdown:
|
| 33 |
"""
|
|
|
|
| 54 |
return gr.Dropdown(choices=[], info=f"Error loading data: {e!s}")
|
| 55 |
|
| 56 |
|
| 57 |
+
def _cleanup_previous_results(previous_results_dir: str | None) -> None:
|
| 58 |
+
"""Clean up previous results directory (per-session, thread-safe)."""
|
| 59 |
+
if previous_results_dir is None:
|
| 60 |
+
return
|
| 61 |
+
prev_path = Path(previous_results_dir)
|
| 62 |
+
if prev_path.exists():
|
| 63 |
+
try:
|
| 64 |
+
shutil.rmtree(prev_path)
|
| 65 |
+
logger.debug("Cleaned up previous results: %s", prev_path)
|
| 66 |
+
except OSError as e:
|
| 67 |
+
# Log but don't fail - cleanup is best-effort
|
| 68 |
+
logger.warning("Failed to cleanup %s: %s", prev_path, e)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
def run_segmentation(
|
| 72 |
+
case_id: str,
|
| 73 |
+
fast_mode: bool,
|
| 74 |
+
show_ground_truth: bool,
|
| 75 |
+
previous_results_dir: str | None,
|
| 76 |
+
) -> tuple[str, Figure | None, Figure | None, dict[str, Any], str | None, str, str | None]:
|
| 77 |
"""
|
| 78 |
Run segmentation and return results for display.
|
| 79 |
|
|
|
|
| 81 |
case_id: Selected case identifier
|
| 82 |
fast_mode: Whether to use fast mode (SEALS)
|
| 83 |
show_ground_truth: Whether to show ground truth in plots
|
| 84 |
+
previous_results_dir: Path to previous results (from gr.State, for cleanup)
|
| 85 |
|
| 86 |
Returns:
|
| 87 |
+
Tuple of (niivue_html, slice_fig, ortho_fig, metrics_dict, download_path, status_msg, new_results_dir)
|
| 88 |
+
The new_results_dir is returned to update the gr.State for next cleanup.
|
| 89 |
"""
|
| 90 |
if not case_id:
|
| 91 |
return (
|
| 92 |
"",
|
| 93 |
None,
|
| 94 |
+
None,
|
| 95 |
{},
|
| 96 |
None,
|
| 97 |
"Please select a case first.",
|
| 98 |
+
previous_results_dir, # Keep existing state
|
| 99 |
)
|
| 100 |
|
| 101 |
try:
|
| 102 |
+
# Clean up previous results (per-session, thread-safe via gr.State)
|
| 103 |
+
_cleanup_previous_results(previous_results_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
logger.info("Running segmentation for %s", case_id)
|
| 106 |
result = run_pipeline_on_case(
|
|
|
|
| 110 |
cleanup_staging=True,
|
| 111 |
)
|
| 112 |
|
|
|
|
|
|
|
|
|
|
| 113 |
# 1. NiiVue Visualization
|
| 114 |
# Use Gradio's file serving (Issue #19 optimization)
|
| 115 |
# This eliminates ~65MB base64 payloads, improving load times and browser memory
|
|
|
|
| 129 |
height=500,
|
| 130 |
)
|
| 131 |
|
| 132 |
+
# 2. Static Visualizations (Matplotlib)
|
| 133 |
gt_path = result.ground_truth if show_ground_truth else None
|
| 134 |
+
|
| 135 |
+
# 2a. Slice Comparison
|
| 136 |
slice_fig = render_slice_comparison(
|
| 137 |
dwi_path=dwi_path,
|
| 138 |
prediction_path=result.prediction_mask,
|
|
|
|
| 140 |
orientation="axial",
|
| 141 |
)
|
| 142 |
|
| 143 |
+
# 2b. Orthogonal 3-Panel View
|
| 144 |
+
ortho_fig = render_3panel_view(
|
| 145 |
+
nifti_path=dwi_path,
|
| 146 |
+
mask_path=result.prediction_mask,
|
| 147 |
+
mask_alpha=0.5,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# 3. Metrics (including volume with consistent 0.5 threshold)
|
| 151 |
+
volume_ml: float | None = None
|
| 152 |
+
try:
|
| 153 |
+
volume_ml = round(compute_volume_ml(result.prediction_mask, threshold=0.5), 2)
|
| 154 |
+
except Exception:
|
| 155 |
+
logger.warning("Failed to compute volume for %s", case_id, exc_info=True)
|
| 156 |
+
|
| 157 |
metrics = {
|
| 158 |
"case_id": result.case_id,
|
| 159 |
"dice_score": result.dice_score,
|
| 160 |
+
"volume_ml": volume_ml,
|
| 161 |
"elapsed_seconds": round(result.elapsed_seconds, 2),
|
| 162 |
"model": "SEALS (Fast)" if fast_mode else "Ensemble",
|
| 163 |
}
|
|
|
|
| 171 |
else "Success!"
|
| 172 |
)
|
| 173 |
|
| 174 |
+
# Return new results_dir to update gr.State for next cleanup
|
| 175 |
+
return (
|
| 176 |
+
niivue_html,
|
| 177 |
+
slice_fig,
|
| 178 |
+
ortho_fig,
|
| 179 |
+
metrics,
|
| 180 |
+
download_path,
|
| 181 |
+
status_msg,
|
| 182 |
+
str(result.results_dir),
|
| 183 |
+
)
|
| 184 |
|
| 185 |
except Exception as e:
|
| 186 |
logger.exception("Error running segmentation")
|
| 187 |
+
return "", None, None, {}, None, f"Error: {e!s}", previous_results_dir
|
| 188 |
|
| 189 |
|
| 190 |
def create_app() -> gr.Blocks:
|
|
|
|
| 197 |
with gr.Blocks(
|
| 198 |
title="Stroke Lesion Segmentation Demo",
|
| 199 |
) as demo:
|
| 200 |
+
# Per-session state for cleanup tracking (fixes race condition in multi-user env)
|
| 201 |
+
# This replaces the previous global _previous_results_dir variable
|
| 202 |
+
previous_results_state = gr.State(value=None)
|
| 203 |
+
|
| 204 |
# Header
|
| 205 |
gr.Markdown("""
|
| 206 |
# Stroke Lesion Segmentation Demo
|
|
|
|
| 233 |
case_selector,
|
| 234 |
settings["fast_mode"],
|
| 235 |
settings["show_ground_truth"],
|
| 236 |
+
previous_results_state, # Pass per-session state for cleanup
|
| 237 |
],
|
| 238 |
outputs=[
|
| 239 |
results["niivue_viewer"],
|
| 240 |
results["slice_plot"],
|
| 241 |
+
results["ortho_plot"],
|
| 242 |
results["metrics"],
|
| 243 |
results["download"],
|
| 244 |
status,
|
| 245 |
+
previous_results_state, # Update state with new results_dir
|
| 246 |
],
|
| 247 |
).then(
|
| 248 |
fn=None, # Explicitly None to run JS only
|
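The state-in, state-out shape that `run_segmentation` now follows can be sketched without Gradio. In the sketch below, each "session" carries its own previous-results path and passes it back on the next call, which is the role `gr.State` plays in the app; the function `run_once`, the `results_` prefix, and the placeholder file are all hypothetical and use only the standard library.

```python
from __future__ import annotations

import shutil
import tempfile
from pathlib import Path


def run_once(previous_results_dir: str | None) -> str:
    """Remove this session's previous results, write new ones, return the new path."""
    if previous_results_dir is not None:
        # Best-effort cleanup, scoped to the path this caller handed in.
        shutil.rmtree(previous_results_dir, ignore_errors=True)

    new_dir = tempfile.mkdtemp(prefix="results_")
    (Path(new_dir) / "prediction.txt").write_text("placeholder")
    return new_dir


# Two independent sessions, each holding its own state value.
session_a = run_once(None)
session_b = run_once(None)

# Session A runs again: only A's old directory is removed.
session_a_next = run_once(session_a)

print(Path(session_a).exists(), Path(session_b).exists(), Path(session_a_next).exists())
```

With a single module-level variable, the second run in one session could delete a directory that another session was still displaying; threading the path through the call avoids sharing it at all.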
src/stroke_deepisles_demo/ui/components.py
CHANGED
|
@@ -39,17 +39,21 @@ def create_results_display() -> dict[str, gr.components.Component]:
|
|
| 39 |
"""
|
| 40 |
# Using gr.Group to group them visually
|
| 41 |
with gr.Group():
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
metrics = gr.JSON(label="Metrics")
|
| 55 |
download = gr.File(label="Download Prediction")
|
|
@@ -57,6 +61,7 @@ def create_results_display() -> dict[str, gr.components.Component]:
|
|
| 57 |
return {
|
| 58 |
"niivue_viewer": niivue_viewer,
|
| 59 |
"slice_plot": slice_plot,
|
|
|
|
| 60 |
"metrics": metrics,
|
| 61 |
"download": download,
|
| 62 |
}
|
|
|
|
| 39 |
"""
|
| 40 |
# Using gr.Group to group them visually
|
| 41 |
with gr.Group():
|
| 42 |
+
with gr.Tabs():
|
| 43 |
+
with gr.Tab("Interactive 3D"):
|
| 44 |
+
# NiiVue visualization uses HTML with js_on_load for JavaScript execution
|
| 45 |
+
# Note: Gradio strips <script> tags from HTML value for security,
|
| 46 |
+
# so we must use js_on_load to run our NiiVue initialization code.
|
| 47 |
+
# The HTML value contains data-* attributes with volume URLs.
|
| 48 |
+
niivue_viewer = gr.HTML(
|
| 49 |
+
label="Interactive 3D Viewer",
|
| 50 |
+
js_on_load=NIIVUE_ON_LOAD_JS,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
with gr.Tab("Static Report"):
|
| 54 |
+
# Slice comparisons (Matplotlib)
|
| 55 |
+
slice_plot = gr.Plot(label="Slice Comparison (Validation)")
|
| 56 |
+
ortho_plot = gr.Plot(label="Orthogonal Views (Anatomy)")
|
| 57 |
|
| 58 |
metrics = gr.JSON(label="Metrics")
|
| 59 |
download = gr.File(label="Download Prediction")
|
|
|
|
| 61 |
return {
|
| 62 |
"niivue_viewer": niivue_viewer,
|
| 63 |
"slice_plot": slice_plot,
|
| 64 |
+
"ortho_plot": ortho_plot,
|
| 65 |
"metrics": metrics,
|
| 66 |
"download": download,
|
| 67 |
}
|
src/stroke_deepisles_demo/ui/viewer.py
CHANGED
|
@@ -16,15 +16,14 @@ import json
|
|
| 16 |
import uuid
|
| 17 |
from typing import TYPE_CHECKING
|
| 18 |
|
| 19 |
-
import matplotlib.pyplot as plt
|
| 20 |
import numpy as np
|
|
|
|
| 21 |
|
| 22 |
from stroke_deepisles_demo.metrics import load_nifti_as_array
|
| 23 |
|
| 24 |
if TYPE_CHECKING:
|
| 25 |
from pathlib import Path
|
| 26 |
|
| 27 |
-
from matplotlib.figure import Figure
|
| 28 |
|
| 29 |
# NiiVue version - updated to latest stable (Dec 2025)
|
| 30 |
NIIVUE_VERSION = "0.65.0"
|
|
@@ -141,9 +140,10 @@ def render_3panel_view(
|
|
| 141 |
center = coords.mean(axis=0).astype(int)
|
| 142 |
mid_x, mid_y, mid_z = center[0], center[1], center[2]
|
| 143 |
|
| 144 |
-
# Create figure
|
| 145 |
-
fig
|
| 146 |
fig.patch.set_facecolor("black")
|
|
|
|
| 147 |
|
| 148 |
# Axial (XY plane, Z fixed) - often needs rotation 90 deg
|
| 149 |
# NIfTI data[x, y, z]. To display standard axial:
|
|
@@ -153,8 +153,10 @@ def render_3panel_view(
|
|
| 153 |
axes[0].set_title(f"Axial (z={mid_z})", color="white")
|
| 154 |
if mask_data is not None:
|
| 155 |
m_slice = np.rot90(mask_data[:, :, mid_z])
|
|
|
|
|
|
|
| 156 |
axes[0].imshow(
|
| 157 |
-
np.ma.masked_where(
|
| 158 |
cmap="Reds",
|
| 159 |
alpha=mask_alpha,
|
| 160 |
vmin=0,
|
|
@@ -167,8 +169,10 @@ def render_3panel_view(
|
|
| 167 |
axes[1].set_title(f"Coronal (y={mid_y})", color="white")
|
| 168 |
if mask_data is not None:
|
| 169 |
m_slice = np.rot90(mask_data[:, mid_y, :])
|
|
|
|
|
|
|
| 170 |
axes[1].imshow(
|
| 171 |
-
np.ma.masked_where(
|
| 172 |
cmap="Reds",
|
| 173 |
alpha=mask_alpha,
|
| 174 |
vmin=0,
|
|
@@ -181,8 +185,10 @@ def render_3panel_view(
|
|
| 181 |
axes[2].set_title(f"Sagittal (x={mid_x})", color="white")
|
| 182 |
if mask_data is not None:
|
| 183 |
m_slice = np.rot90(mask_data[mid_x, :, :])
|
|
|
|
|
|
|
| 184 |
axes[2].imshow(
|
| 185 |
-
np.ma.masked_where(
|
| 186 |
cmap="Reds",
|
| 187 |
alpha=mask_alpha,
|
| 188 |
vmin=0,
|
|
@@ -192,7 +198,7 @@ def render_3panel_view(
|
|
| 192 |
for ax in axes:
|
| 193 |
ax.axis("off")
|
| 194 |
|
| 195 |
-
|
| 196 |
return fig
|
| 197 |
|
| 198 |
|
|
@@ -248,8 +254,11 @@ def render_slice_comparison(
|
|
| 248 |
|
| 249 |
# Plotting
|
| 250 |
num_plots = 3 if gt_data is not None else 2
|
| 251 |
-
|
|
|
|
| 252 |
fig.patch.set_facecolor("black")
|
|
|
|
|
|
|
| 253 |
if num_plots == 2:
|
| 254 |
axes = np.array(axes) # handle single case if needed, but subplots(1,2) returns array
|
| 255 |
|
|
@@ -258,9 +267,14 @@ def render_slice_comparison(
|
|
| 258 |
axes[0].set_title("DWI Input", color="white")
|
| 259 |
|
| 260 |
# 2. Prediction
|
| 261 |
axes[1].imshow(d_slice, cmap="gray")
|
| 262 |
axes[1].imshow(
|
| 263 |
-
np.ma.masked_where(
|
| 264 |
cmap="Reds",
|
| 265 |
alpha=0.5,
|
| 266 |
vmin=0,
|
|
@@ -283,7 +297,7 @@ def render_slice_comparison(
|
|
| 283 |
for ax in axes:
|
| 284 |
ax.axis("off")
|
| 285 |
|
| 286 |
-
|
| 287 |
return fig
|
| 288 |
|
| 289 |
|
|
|
|
| 16 |
import uuid
|
| 17 |
from typing import TYPE_CHECKING
|
| 18 |
|
|
|
|
| 19 |
import numpy as np
|
| 20 |
+
from matplotlib.figure import Figure
|
| 21 |
|
| 22 |
from stroke_deepisles_demo.metrics import load_nifti_as_array
|
| 23 |
|
| 24 |
if TYPE_CHECKING:
|
| 25 |
from pathlib import Path
|
| 26 |
|
|
|
|
| 27 |
|
| 28 |
# NiiVue version - updated to latest stable (Dec 2025)
|
| 29 |
NIIVUE_VERSION = "0.65.0"
|
|
|
|
| 140 |
center = coords.mean(axis=0).astype(int)
|
| 141 |
mid_x, mid_y, mid_z = center[0], center[1], center[2]
|
| 142 |
|
| 143 |
+
# Create figure using OO API for thread safety
|
| 144 |
+
fig = Figure(figsize=(15, 5))
|
| 145 |
fig.patch.set_facecolor("black")
|
| 146 |
+
axes = fig.subplots(1, 3)
|
| 147 |
|
| 148 |
# Axial (XY plane, Z fixed) - often needs rotation 90 deg
|
| 149 |
# NIfTI data[x, y, z]. To display standard axial:
|
|
|
|
| 153 |
axes[0].set_title(f"Axial (z={mid_z})", color="white")
|
| 154 |
if mask_data is not None:
|
| 155 |
m_slice = np.rot90(mask_data[:, :, mid_z])
|
| 156 |
+
# Binarize at 0.5 threshold for visible overlay (consistent with compute_dice)
|
| 157 |
+
m_slice_binary = (m_slice > 0.5).astype(float)
|
| 158 |
axes[0].imshow(
|
| 159 |
+
np.ma.masked_where(m_slice_binary == 0, m_slice_binary), # type: ignore[no-untyped-call]
|
| 160 |
cmap="Reds",
|
| 161 |
alpha=mask_alpha,
|
| 162 |
vmin=0,
|
|
|
|
| 169 |
axes[1].set_title(f"Coronal (y={mid_y})", color="white")
|
| 170 |
if mask_data is not None:
|
| 171 |
m_slice = np.rot90(mask_data[:, mid_y, :])
|
| 172 |
+
# Binarize at 0.5 threshold for visible overlay (consistent with compute_dice)
|
| 173 |
+
m_slice_binary = (m_slice > 0.5).astype(float)
|
| 174 |
axes[1].imshow(
|
| 175 |
+
np.ma.masked_where(m_slice_binary == 0, m_slice_binary), # type: ignore[no-untyped-call]
|
| 176 |
cmap="Reds",
|
| 177 |
alpha=mask_alpha,
|
| 178 |
vmin=0,
|
|
|
|
| 185 |
axes[2].set_title(f"Sagittal (x={mid_x})", color="white")
|
| 186 |
if mask_data is not None:
|
| 187 |
m_slice = np.rot90(mask_data[mid_x, :, :])
|
| 188 |
+
# Binarize at 0.5 threshold for visible overlay (consistent with compute_dice)
|
| 189 |
+
m_slice_binary = (m_slice > 0.5).astype(float)
|
| 190 |
axes[2].imshow(
|
| 191 |
+
np.ma.masked_where(m_slice_binary == 0, m_slice_binary), # type: ignore[no-untyped-call]
|
| 192 |
cmap="Reds",
|
| 193 |
alpha=mask_alpha,
|
| 194 |
vmin=0,
|
|
|
|
| 198 |
for ax in axes:
|
| 199 |
ax.axis("off")
|
| 200 |
|
| 201 |
+
fig.tight_layout()
|
| 202 |
return fig
|
| 203 |
|
| 204 |
|
|
|
|
| 254 |
|
| 255 |
# Plotting
|
| 256 |
num_plots = 3 if gt_data is not None else 2
|
| 257 |
+
# Create figure using OO API for thread safety
|
| 258 |
+
fig = Figure(figsize=(5 * num_plots, 5))
|
| 259 |
fig.patch.set_facecolor("black")
|
| 260 |
+
axes = fig.subplots(1, num_plots)
|
| 261 |
+
|
| 262 |
if num_plots == 2:
|
| 263 |
axes = np.array(axes) # handle single case if needed, but subplots(1,2) returns array
|
| 264 |
|
|
|
|
| 267 |
axes[0].set_title("DWI Input", color="white")
|
| 268 |
|
| 269 |
# 2. Prediction
|
| 270 |
+
# Binarize prediction at threshold 0.5 for visible overlay (Issue #23)
|
| 271 |
+
# Model output may contain probability values (0.0-1.0); low values render as
|
| 272 |
+
# nearly-white in the "Reds" colormap. Binarizing ensures consistent
|
| 273 |
+
# visualization matching how compute_dice() evaluates predictions.
|
| 274 |
+
p_slice_binary = (p_slice > 0.5).astype(float)
|
| 275 |
axes[1].imshow(d_slice, cmap="gray")
|
| 276 |
axes[1].imshow(
|
| 277 |
+
np.ma.masked_where(p_slice_binary == 0, p_slice_binary), # type: ignore[no-untyped-call]
|
| 278 |
cmap="Reds",
|
| 279 |
alpha=0.5,
|
| 280 |
vmin=0,
|
|
|
|
| 297 |
for ax in axes:
|
| 298 |
ax.axis("off")
|
| 299 |
|
| 300 |
+
fig.tight_layout()
|
| 301 |
return fig
|
| 302 |
|
| 303 |
|
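The binarize-then-mask pattern that the viewer.py changes repeat for each panel can be sketched in isolation. This is a minimal sketch rather than the project's code: `binarize_overlay` and `overlay_figure` are hypothetical helper names, and `vmax=1` is an assumption, since the diff elides the arguments that follow `vmin=0`.

```python
import numpy as np
from matplotlib.figure import Figure


def binarize_overlay(mask_slice: np.ndarray, threshold: float = 0.5) -> np.ma.MaskedArray:
    """Return 1.0 where the mask exceeds the threshold; mask out everything else."""
    binary = (mask_slice > threshold).astype(float)
    return np.ma.masked_where(binary == 0, binary)


def overlay_figure(background: np.ndarray, mask_slice: np.ndarray) -> Figure:
    """Draw a grayscale slice with a red overlay without using pyplot's global state."""
    fig = Figure(figsize=(5, 5))  # OO API: no shared "current figure"
    ax = fig.subplots(1, 1)
    ax.imshow(background, cmap="gray")
    # vmax=1 is an assumption; the diff only shows vmin=0.
    ax.imshow(binarize_overlay(mask_slice), cmap="Reds", alpha=0.5, vmin=0, vmax=1)
    ax.axis("off")
    return fig


# Hypothetical slice: a 0.3 outer region around a 0.8 core.
prob = np.zeros((10, 10), dtype=np.float32)
prob[2:8, 2:8] = 0.3
prob[3:7, 3:7] = 0.8

overlay = binarize_overlay(prob)
visible_pixels = int((~np.ma.getmaskarray(overlay)).sum())  # only the 4x4 core survives
fig = overlay_figure(np.zeros((10, 10)), prob)
```

Factoring the pattern into one helper would also remove the three near-identical blocks in `render_3panel_view`, though the commit keeps them inline.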
tests/conftest.py
CHANGED
|
@@ -61,6 +61,47 @@ def synthetic_case_files(temp_dir: Path) -> CaseFiles:
|
|
| 61 |
)
|
| 62 |
|
| 63 |
|
| 64 |
@pytest.fixture
|
| 65 |
def synthetic_isles_dir(temp_dir: Path) -> Path:
|
| 66 |
"""
|
|
|
|
| 61 |
)
|
| 62 |
|
| 63 |
|
| 64 |
+
@pytest.fixture
|
| 65 |
+
def synthetic_probability_mask(temp_dir: Path) -> Path:
|
| 66 |
+
"""
|
| 67 |
+
Create a synthetic probability mask (float values 0.0-1.0).
|
| 68 |
+
|
| 69 |
+
This simulates model output that may contain probability values
|
| 70 |
+
rather than binary 0/1 masks. Used to test visualization handling
|
| 71 |
+
of probability-valued segmentation masks.
|
| 72 |
+
|
| 73 |
+
The mask has values ONLY at slice 5 to ensure get_slice_at_max_lesion selects it:
|
| 74 |
+
- Outer region with low probability (0.3) - below 0.5 threshold
|
| 75 |
+
- Inner region with high probability (0.8) - above 0.5 threshold
|
| 76 |
+
|
| 77 |
+
See: docs/specs/23-slice-comparison-overlay-bug.md
|
| 78 |
+
"""
|
| 79 |
+
mask_data = np.zeros((10, 10, 10), dtype=np.float32)
|
| 80 |
+
|
| 81 |
+
# Only populate slice 5 to ensure it's selected as max lesion slice
|
| 82 |
+
# Outer region: low confidence (below 0.5 threshold)
|
| 83 |
+
mask_data[2:8, 2:8, 5] = 0.3
|
| 84 |
+
# Inner region: high confidence (above 0.5 threshold) - this should be visible
|
| 85 |
+
mask_data[3:7, 3:7, 5] = 0.8
|
| 86 |
+
|
| 87 |
+
img = nib.Nifti1Image(mask_data, affine=np.eye(4)) # type: ignore
|
| 88 |
+
path = temp_dir / "probability_mask.nii.gz"
|
| 89 |
+
nib.save(img, path) # type: ignore
|
| 90 |
+
return path
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@pytest.fixture
|
| 94 |
+
def synthetic_binary_mask(temp_dir: Path) -> Path:
|
| 95 |
+
"""Create a synthetic binary mask (0 or 1 values only)."""
|
| 96 |
+
mask_data = np.zeros((10, 10, 10), dtype=np.uint8)
|
| 97 |
+
mask_data[3:7, 3:7, 4:6] = 1 # Binary lesion region
|
| 98 |
+
|
| 99 |
+
img = nib.Nifti1Image(mask_data, affine=np.eye(4)) # type: ignore
|
| 100 |
+
path = temp_dir / "binary_mask.nii.gz"
|
| 101 |
+
nib.save(img, path) # type: ignore
|
| 102 |
+
return path
|
| 103 |
+
|
| 104 |
+
|
| 105 |
@pytest.fixture
|
| 106 |
def synthetic_isles_dir(temp_dir: Path) -> Path:
|
| 107 |
"""
|
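The fixture docstring says the mask is populated only at slice 5 so that `get_slice_at_max_lesion` selects it. The implementation of that function is not shown in this diff; the sketch below uses a hypothetical stand-in, `slice_with_most_lesion`, which assumes the slice is chosen by counting non-zero voxels along the z axis.

```python
import numpy as np


def slice_with_most_lesion(mask: np.ndarray) -> int:
    """Hypothetical stand-in: index of the z slice with the most non-zero voxels."""
    per_slice = (mask > 0).sum(axis=(0, 1))
    return int(per_slice.argmax())


# Same layout as the synthetic_probability_mask fixture.
mask = np.zeros((10, 10, 10), dtype=np.float32)
mask[2:8, 2:8, 5] = 0.3  # outer region, below the 0.5 threshold
mask[3:7, 3:7, 5] = 0.8  # inner region, above the 0.5 threshold

selected = slice_with_most_lesion(mask)
above_threshold = int((mask[:, :, selected] > 0.5).sum())  # 4x4 inner region
below_threshold = int(((mask[:, :, selected] > 0) & (mask[:, :, selected] <= 0.5)).sum())
```

Under this assumption, only the inner 4x4 region should remain visible after binarization, while the surrounding ring of 0.3 values is dropped.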
tests/ui/test_app.py
CHANGED
|
@@ -67,12 +67,18 @@ def test_run_segmentation_logic() -> None:
|
|
| 67 |
),
|
| 68 |
patch("stroke_deepisles_demo.ui.app.create_niivue_html", return_value="<div></div>"),
|
| 69 |
patch("stroke_deepisles_demo.ui.app.render_slice_comparison", return_value=MagicMock()),
|
|
|
|
|
|
|
| 70 |
):
|
| 71 |
-
html, _fig, metrics, _dl_path, status = run_segmentation(
|
| 72 |
-
"sub-001",
|
|
|
|
|
|
|
|
|
|
| 73 |
)
|
| 74 |
|
| 75 |
assert html == "<div></div>"
|
| 76 |
assert metrics["case_id"] == "sub-001"
|
| 77 |
assert metrics["dice_score"] == 0.85
|
|
|
|
| 78 |
assert "Success" in status
|
|
|
|
| 67 |
),
|
| 68 |
patch("stroke_deepisles_demo.ui.app.create_niivue_html", return_value="<div></div>"),
|
| 69 |
patch("stroke_deepisles_demo.ui.app.render_slice_comparison", return_value=MagicMock()),
|
| 70 |
+
patch("stroke_deepisles_demo.ui.app.render_3panel_view", return_value=MagicMock()),
|
| 71 |
+
patch("stroke_deepisles_demo.ui.app.compute_volume_ml", return_value=15.5),
|
| 72 |
):
|
| 73 |
+
html, _fig, _ortho, metrics, _dl_path, status, _new_results_dir = run_segmentation(
|
| 74 |
+
"sub-001",
|
| 75 |
+
fast_mode=True,
|
| 76 |
+
show_ground_truth=True,
|
| 77 |
+
previous_results_dir=None, # No previous results in test
|
| 78 |
)
|
| 79 |
|
| 80 |
assert html == "<div></div>"
|
| 81 |
assert metrics["case_id"] == "sub-001"
|
| 82 |
assert metrics["dice_score"] == 0.85
|
| 83 |
+
assert "volume_ml" in metrics # New metric added
|
| 84 |
assert "Success" in status
|
tests/ui/test_viewer.py
CHANGED
|
@@ -152,6 +152,97 @@ class TestNiftiToGradioUrl:
|
|
| 152 |
assert ";base64," not in url
|
| 153 |
|
| 154 |
|
| 155 |
class TestCreateNiivueHtml:
|
| 156 |
"""Tests for create_niivue_html."""
|
| 157 |
|
|
|
|
| 152 |
assert ";base64," not in url
|
| 153 |
|
| 154 |
|
| 155 |
+
class TestRenderSliceComparisonProbabilityMask:
|
| 156 |
+
"""Tests for render_slice_comparison with probability masks (Issue #23).
|
| 157 |
+
|
| 158 |
+
This test class verifies that probability-valued prediction masks
|
| 159 |
+
are rendered visibly. The bug occurs when:
|
| 160 |
+
- Ground truth is binary (0 or 1) → renders as visible green
|
| 161 |
+
- Prediction is probability (0.1-0.5) → renders as nearly-invisible white
|
| 162 |
+
|
| 163 |
+
See: docs/specs/23-slice-comparison-overlay-bug.md
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
def test_probability_mask_has_visible_overlay(
|
| 167 |
+
self,
|
| 168 |
+
synthetic_nifti_3d: Path,
|
| 169 |
+
synthetic_probability_mask: Path,
|
| 170 |
+
) -> None:
|
| 171 |
+
"""
|
| 172 |
+
Probability mask should produce visible overlay in rendering.
|
| 173 |
+
|
| 174 |
+
This test exposes the bug where low probability values (e.g., 0.3)
|
| 175 |
+
render as nearly-white in the "Reds" colormap and are invisible.
|
| 176 |
+
"""
|
| 177 |
+
fig = render_slice_comparison(
|
| 178 |
+
synthetic_nifti_3d,
|
| 179 |
+
synthetic_probability_mask, # Probability values 0.3, 0.8
|
| 180 |
+
ground_truth_path=None,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
# Get the prediction axis (index 1)
|
| 184 |
+
ax = fig.axes[1]
|
| 185 |
+
|
| 186 |
+
# The axis should have at least 2 images (DWI background + overlay)
|
| 187 |
+
images = ax.get_images()
|
| 188 |
+
assert len(images) >= 2, "Prediction panel should have overlay image"
|
| 189 |
+
|
| 190 |
+
# The overlay should have non-zero alpha (visible)
|
| 191 |
+
overlay = images[1]
|
| 192 |
+
alpha = overlay.get_alpha()
|
| 193 |
+
assert alpha is None or alpha > 0 # None means default alpha (1.0)
|
| 194 |
+
|
| 195 |
+
plt.close(fig)
|
| 196 |
+
|
| 197 |
+
def test_binary_vs_probability_mask_comparison(
|
| 198 |
+
self,
|
| 199 |
+
synthetic_nifti_3d: Path,
|
| 200 |
+
synthetic_binary_mask: Path,
|
| 201 |
+
synthetic_probability_mask: Path,
|
| 202 |
+
) -> None:
|
| 203 |
+
"""
|
| 204 |
+
Both binary and probability masks should render visible overlays.
|
| 205 |
+
|
| 206 |
+
This is the core test for Issue #23. If the probability mask renders
|
| 207 |
+
invisibly while the binary mask renders visibly, the bug is confirmed.
|
| 208 |
+
"""
|
| 209 |
+
# Render with binary mask (expected to work)
|
| 210 |
+
fig_binary = render_slice_comparison(
|
| 211 |
+
synthetic_nifti_3d,
|
| 212 |
+
synthetic_binary_mask,
|
| 213 |
+
ground_truth_path=None,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Render with probability mask (may be invisible - the bug)
|
| 217 |
+
fig_prob = render_slice_comparison(
|
| 218 |
+
synthetic_nifti_3d,
|
| 219 |
+
synthetic_probability_mask,
|
| 220 |
+
ground_truth_path=None,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Get overlay data from both
|
| 224 |
+
binary_overlay = fig_binary.axes[1].get_images()[1].get_array()
|
| 225 |
+
prob_overlay = fig_prob.axes[1].get_images()[1].get_array()
|
| 226 |
+
|
| 227 |
+
# Both should have non-masked (visible) pixels
|
| 228 |
+
binary_visible = (
|
| 229 |
+
not binary_overlay.mask.all() # type: ignore[union-attr]
|
| 230 |
+
if hasattr(binary_overlay, "mask")
|
| 231 |
+
else True
|
| 232 |
+
)
|
| 233 |
+
prob_visible = (
|
| 234 |
+
not prob_overlay.mask.all() # type: ignore[union-attr]
|
| 235 |
+
if hasattr(prob_overlay, "mask")
|
| 236 |
+
else True
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
assert binary_visible, "Binary mask overlay should have visible pixels"
|
| 240 |
+
assert prob_visible, "Probability mask overlay should have visible pixels"
|
| 241 |
+
|
| 242 |
+
plt.close(fig_binary)
|
| 243 |
+
plt.close(fig_prob)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
class TestCreateNiivueHtml:
|
| 247 |
"""Tests for create_niivue_html."""
|
| 248 |
|
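The test docstrings state that low probability values render as nearly white in the "Reds" colormap. A quick way to check that claim is to compare the colormap's output at 0.3 and at 1.0. This is an illustrative sketch only: `brightness` is a hypothetical helper, and pyplot is imported here purely for the colormap lookup, not for figure creation.

```python
import matplotlib.pyplot as plt

cmap = plt.get_cmap("Reds")


def brightness(value: float) -> float:
    """Mean of the RGB channels that the colormap returns for a value in [0, 1]."""
    r, g, b, _alpha = cmap(value)
    return (r + g + b) / 3.0


low = brightness(0.3)   # a typical low-confidence prediction
full = brightness(1.0)  # a binarized mask value

# A sequential colormap like "Reds" runs from light to dark, so the
# low-probability color is much closer to white than the binarized one.
lighter_when_low = low > full
```

The existing assertions only check that an overlay exists and has unmasked pixels; a color comparison along these lines is one possible way to assert that the overlay is actually distinguishable.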