|
|
""" |
|
|
Tests for segmentation module. |
|
|
""" |
|
|
|
|
|
import unittest |
|
|
import numpy as np |
|
|
import tempfile |
|
|
from PIL import Image |
|
|
from segmentation import ( |
|
|
calculate_dice_score, |
|
|
calculate_iou_score, |
|
|
generate_grid_points, |
|
|
format_roi_statistics, |
|
|
) |
|
|
|
|
|
|
|
|
class TestSegmentation(unittest.TestCase):
    """Unit tests for the helpers imported from the ``segmentation`` module.

    Covers the mask-overlap metrics (``calculate_dice_score`` and
    ``calculate_iou_score``), grid point generation
    (``generate_grid_points``) and the text rendering of ROI statistics
    (``format_roi_statistics``).
    """

    @staticmethod
    def _square_mask(start, stop):
        """Return a 10x10 boolean mask that is True on ``[start:stop, start:stop]``.

        Args:
            start: First row/column index (inclusive) of the filled square.
            stop: Last row/column index (exclusive) of the filled square.

        Returns:
            A ``(10, 10)`` ``bool`` array, False everywhere outside the square.
        """
        mask = np.zeros((10, 10), dtype=bool)
        mask[start:stop, start:stop] = True
        return mask

    def test_calculate_dice_score_perfect_match(self):
        """Dice score of two identical, fully-set masks is 1."""
        mask1 = np.ones((10, 10), dtype=bool)
        mask2 = np.ones((10, 10), dtype=bool)
        dice = calculate_dice_score(mask1, mask2)
        # The score is a float: compare with a tolerance, not bit-for-bit.
        self.assertAlmostEqual(dice, 1.0)

    def test_calculate_dice_score_no_overlap(self):
        """Dice score of two disjoint masks is 0."""
        # Opposite 5x5 corners of the image share no pixel.
        mask1 = self._square_mask(0, 5)
        mask2 = self._square_mask(5, 10)
        dice = calculate_dice_score(mask1, mask2)
        self.assertAlmostEqual(dice, 0.0)

    def test_calculate_dice_score_partial_overlap(self):
        """Dice score of partially overlapping masks lies strictly in (0, 1)."""
        # The two 7x7 squares share only the 4x4 region [3:7, 3:7].
        mask1 = self._square_mask(0, 7)
        mask2 = self._square_mask(3, 10)
        dice = calculate_dice_score(mask1, mask2)
        self.assertGreater(dice, 0.0)
        self.assertLess(dice, 1.0)

    def test_calculate_iou_score_perfect_match(self):
        """IoU score of two identical, fully-set masks is 1."""
        mask1 = np.ones((10, 10), dtype=bool)
        mask2 = np.ones((10, 10), dtype=bool)
        iou = calculate_iou_score(mask1, mask2)
        self.assertAlmostEqual(iou, 1.0)

    def test_calculate_iou_score_no_overlap(self):
        """IoU score of two disjoint masks is 0."""
        mask1 = self._square_mask(0, 5)
        mask2 = self._square_mask(5, 10)
        iou = calculate_iou_score(mask1, mask2)
        self.assertAlmostEqual(iou, 0.0)

    def test_generate_grid_points(self):
        """Grid generation yields ``points_per_side ** 2`` in-bounds 2-D points."""
        image_size = (100, 200)
        points_per_side = 10
        points = generate_grid_points(image_size, points_per_side)

        # One row per grid point, two coordinates per row.
        self.assertEqual(points.shape[0], points_per_side * points_per_side)
        self.assertEqual(points.shape[1], 2)

        # Column 0 is bounded by image_size[1] and column 1 by image_size[0].
        # NOTE(review): presumably image_size is (height, width) and points
        # are (x, y) -- confirm against generate_grid_points.
        self.assertTrue(np.all(points[:, 0] >= 0))
        self.assertTrue(np.all(points[:, 0] < image_size[1]))
        self.assertTrue(np.all(points[:, 1] >= 0))
        self.assertTrue(np.all(points[:, 1] < image_size[0]))

    def test_format_roi_statistics_valid(self):
        """A complete statistics dict is rendered as text containing its values."""
        stats = {
            "area_pixels": 1000,
            "area_percentage": 10.5,
            "mean_intensity": 128.5,
            "std_intensity": 25.3,
            "min_intensity": 50.0,
            "max_intensity": 200.0,
            "centroid": (100.5, 150.2),
            "bounding_box": (50, 75, 150, 225),
        }
        formatted = format_roi_statistics(stats)
        self.assertIsInstance(formatted, str)
        # The pixel area and the area percentage must appear in the output.
        self.assertIn("1000", formatted)
        self.assertIn("10.5", formatted)

    def test_format_roi_statistics_error(self):
        """A statistics dict carrying an ``error`` key is rendered as an error text."""
        stats = {
            "error": "No valid mask available",
            "area_pixels": 0,
        }
        formatted = format_roi_statistics(stats)
        self.assertIsInstance(formatted, str)
        self.assertIn("Error", formatted)
|
|
|
|
|
|
|
|
# Run the whole test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|
|
|
|
|