# File size: 3,715 Bytes
# 69066c5
"""
Tests for segmentation module.
"""
import unittest
import numpy as np
import tempfile
from PIL import Image
from segmentation import (
calculate_dice_score,
calculate_iou_score,
generate_grid_points,
format_roi_statistics,
)
class TestSegmentation(unittest.TestCase):
"""Test cases for segmentation functions."""
def test_calculate_dice_score_perfect_match(self):
"""Test Dice score calculation with perfect match."""
mask1 = np.ones((10, 10), dtype=bool)
mask2 = np.ones((10, 10), dtype=bool)
dice = calculate_dice_score(mask1, mask2)
self.assertEqual(dice, 1.0)
def test_calculate_dice_score_no_overlap(self):
"""Test Dice score calculation with no overlap."""
mask1 = np.zeros((10, 10), dtype=bool)
mask1[0:5, 0:5] = True
mask2 = np.zeros((10, 10), dtype=bool)
mask2[5:10, 5:10] = True
dice = calculate_dice_score(mask1, mask2)
self.assertEqual(dice, 0.0)
def test_calculate_dice_score_partial_overlap(self):
"""Test Dice score calculation with partial overlap."""
mask1 = np.zeros((10, 10), dtype=bool)
mask1[0:7, 0:7] = True
mask2 = np.zeros((10, 10), dtype=bool)
mask2[3:10, 3:10] = True
dice = calculate_dice_score(mask1, mask2)
self.assertGreater(dice, 0.0)
self.assertLess(dice, 1.0)
def test_calculate_iou_score_perfect_match(self):
"""Test IoU score calculation with perfect match."""
mask1 = np.ones((10, 10), dtype=bool)
mask2 = np.ones((10, 10), dtype=bool)
iou = calculate_iou_score(mask1, mask2)
self.assertEqual(iou, 1.0)
def test_calculate_iou_score_no_overlap(self):
"""Test IoU score calculation with no overlap."""
mask1 = np.zeros((10, 10), dtype=bool)
mask1[0:5, 0:5] = True
mask2 = np.zeros((10, 10), dtype=bool)
mask2[5:10, 5:10] = True
iou = calculate_iou_score(mask1, mask2)
self.assertEqual(iou, 0.0)
def test_generate_grid_points(self):
"""Test grid point generation."""
image_size = (100, 200)
points_per_side = 10
points = generate_grid_points(image_size, points_per_side)
self.assertEqual(points.shape[0], points_per_side * points_per_side)
self.assertEqual(points.shape[1], 2)
# Check that points are within image bounds
self.assertTrue(np.all(points[:, 0] >= 0))
self.assertTrue(np.all(points[:, 0] < image_size[1]))
self.assertTrue(np.all(points[:, 1] >= 0))
self.assertTrue(np.all(points[:, 1] < image_size[0])
)
def test_format_roi_statistics_valid(self):
"""Test ROI statistics formatting with valid stats."""
stats = {
"area_pixels": 1000,
"area_percentage": 10.5,
"mean_intensity": 128.5,
"std_intensity": 25.3,
"min_intensity": 50.0,
"max_intensity": 200.0,
"centroid": (100.5, 150.2),
"bounding_box": (50, 75, 150, 225)
}
formatted = format_roi_statistics(stats)
self.assertIsInstance(formatted, str)
self.assertIn("1000", formatted)
self.assertIn("10.5", formatted)
def test_format_roi_statistics_error(self):
"""Test ROI statistics formatting with error."""
stats = {
"error": "No valid mask available",
"area_pixels": 0
}
formatted = format_roi_statistics(stats)
self.assertIsInstance(formatted, str)
self.assertIn("Error", formatted)
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()
|