NeuroSAM3 / tests /test_segmentation.py
mmrech's picture
Refactor codebase: Add modular structure, logging, validation, and comprehensive improvements
69066c5
"""
Tests for segmentation module.
"""
import unittest
import numpy as np
import tempfile
from PIL import Image
from segmentation import (
calculate_dice_score,
calculate_iou_score,
generate_grid_points,
format_roi_statistics,
)
class TestSegmentation(unittest.TestCase):
    """Test cases for segmentation functions.

    Dice and IoU scores are floating-point values, so they are compared with
    ``assertAlmostEqual`` rather than exact ``assertEqual``.
    """

    @staticmethod
    def _square_mask(start, stop, size=10):
        """Return a ``size`` x ``size`` bool mask with ``[start:stop, start:stop]`` set."""
        mask = np.zeros((size, size), dtype=bool)
        mask[start:stop, start:stop] = True
        return mask

    def test_calculate_dice_score_perfect_match(self):
        """Test Dice score calculation with perfect match."""
        mask1 = np.ones((10, 10), dtype=bool)
        mask2 = np.ones((10, 10), dtype=bool)
        dice = calculate_dice_score(mask1, mask2)
        self.assertAlmostEqual(dice, 1.0)

    def test_calculate_dice_score_no_overlap(self):
        """Test Dice score calculation with no overlap."""
        mask1 = self._square_mask(0, 5)
        mask2 = self._square_mask(5, 10)
        dice = calculate_dice_score(mask1, mask2)
        self.assertAlmostEqual(dice, 0.0)

    def test_calculate_dice_score_partial_overlap(self):
        """Test Dice score calculation with partial overlap."""
        # Two 7x7 squares (49 px each) sharing a 4x4 corner (16 px), so
        # Dice = 2 * |A & B| / (|A| + |B|) = 2 * 16 / (49 + 49) = 32 / 98.
        mask1 = self._square_mask(0, 7)
        mask2 = self._square_mask(3, 10)
        dice = calculate_dice_score(mask1, mask2)
        self.assertGreater(dice, 0.0)
        self.assertLess(dice, 1.0)
        # Pin the exact value; the range check alone accepts wrong formulas.
        self.assertAlmostEqual(dice, 32 / 98, places=6)

    def test_calculate_iou_score_perfect_match(self):
        """Test IoU score calculation with perfect match."""
        mask1 = np.ones((10, 10), dtype=bool)
        mask2 = np.ones((10, 10), dtype=bool)
        iou = calculate_iou_score(mask1, mask2)
        self.assertAlmostEqual(iou, 1.0)

    def test_calculate_iou_score_no_overlap(self):
        """Test IoU score calculation with no overlap."""
        mask1 = self._square_mask(0, 5)
        mask2 = self._square_mask(5, 10)
        iou = calculate_iou_score(mask1, mask2)
        self.assertAlmostEqual(iou, 0.0)

    def test_calculate_iou_score_partial_overlap(self):
        """Test IoU score calculation with partial overlap."""
        # Same squares as the Dice partial-overlap test:
        # IoU = |A & B| / |A | B| = 16 / (49 + 49 - 16) = 16 / 82.
        mask1 = self._square_mask(0, 7)
        mask2 = self._square_mask(3, 10)
        iou = calculate_iou_score(mask1, mask2)
        self.assertGreater(iou, 0.0)
        self.assertLess(iou, 1.0)
        self.assertAlmostEqual(iou, 16 / 82, places=6)

    def test_generate_grid_points(self):
        """Test grid point generation."""
        # The bounds checks below encode the convention that image_size is
        # (height, width) and each point is (x, y). A non-square size makes
        # an accidental axis swap detectable.
        image_size = (100, 200)
        points_per_side = 10
        points = generate_grid_points(image_size, points_per_side)
        self.assertEqual(points.shape[0], points_per_side * points_per_side)
        self.assertEqual(points.shape[1], 2)
        # Check that points are within image bounds
        self.assertTrue(np.all(points[:, 0] >= 0))
        self.assertTrue(np.all(points[:, 0] < image_size[1]))
        self.assertTrue(np.all(points[:, 1] >= 0))
        self.assertTrue(np.all(points[:, 1] < image_size[0]))

    def test_format_roi_statistics_valid(self):
        """Test ROI statistics formatting with valid stats."""
        stats = {
            "area_pixels": 1000,
            "area_percentage": 10.5,
            "mean_intensity": 128.5,
            "std_intensity": 25.3,
            "min_intensity": 50.0,
            "max_intensity": 200.0,
            "centroid": (100.5, 150.2),
            "bounding_box": (50, 75, 150, 225),
        }
        formatted = format_roi_statistics(stats)
        self.assertIsInstance(formatted, str)
        self.assertIn("1000", formatted)
        self.assertIn("10.5", formatted)

    def test_format_roi_statistics_error(self):
        """Test ROI statistics formatting with error."""
        stats = {
            "error": "No valid mask available",
            "area_pixels": 0,
        }
        formatted = format_roi_statistics(stats)
        self.assertIsInstance(formatted, str)
        self.assertIn("Error", formatted)
if __name__ == '__main__':
    # Run this module's test cases when the file is executed as a script.
    unittest.main()