File size: 3,715 Bytes
69066c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Tests for segmentation module.
"""

import unittest
import numpy as np
import tempfile
from PIL import Image
from segmentation import (
    calculate_dice_score,
    calculate_iou_score,
    generate_grid_points,
    format_roi_statistics,
)


class TestSegmentation(unittest.TestCase):
    """Test cases for segmentation functions."""
    
    def test_calculate_dice_score_perfect_match(self):
        """Test Dice score calculation with perfect match."""
        mask1 = np.ones((10, 10), dtype=bool)
        mask2 = np.ones((10, 10), dtype=bool)
        dice = calculate_dice_score(mask1, mask2)
        self.assertEqual(dice, 1.0)
    
    def test_calculate_dice_score_no_overlap(self):
        """Test Dice score calculation with no overlap."""
        mask1 = np.zeros((10, 10), dtype=bool)
        mask1[0:5, 0:5] = True
        mask2 = np.zeros((10, 10), dtype=bool)
        mask2[5:10, 5:10] = True
        dice = calculate_dice_score(mask1, mask2)
        self.assertEqual(dice, 0.0)
    
    def test_calculate_dice_score_partial_overlap(self):
        """Test Dice score calculation with partial overlap."""
        mask1 = np.zeros((10, 10), dtype=bool)
        mask1[0:7, 0:7] = True
        mask2 = np.zeros((10, 10), dtype=bool)
        mask2[3:10, 3:10] = True
        dice = calculate_dice_score(mask1, mask2)
        self.assertGreater(dice, 0.0)
        self.assertLess(dice, 1.0)
    
    def test_calculate_iou_score_perfect_match(self):
        """Test IoU score calculation with perfect match."""
        mask1 = np.ones((10, 10), dtype=bool)
        mask2 = np.ones((10, 10), dtype=bool)
        iou = calculate_iou_score(mask1, mask2)
        self.assertEqual(iou, 1.0)
    
    def test_calculate_iou_score_no_overlap(self):
        """Test IoU score calculation with no overlap."""
        mask1 = np.zeros((10, 10), dtype=bool)
        mask1[0:5, 0:5] = True
        mask2 = np.zeros((10, 10), dtype=bool)
        mask2[5:10, 5:10] = True
        iou = calculate_iou_score(mask1, mask2)
        self.assertEqual(iou, 0.0)
    
    def test_generate_grid_points(self):
        """Test grid point generation."""
        image_size = (100, 200)
        points_per_side = 10
        points = generate_grid_points(image_size, points_per_side)
        
        self.assertEqual(points.shape[0], points_per_side * points_per_side)
        self.assertEqual(points.shape[1], 2)
        
        # Check that points are within image bounds
        self.assertTrue(np.all(points[:, 0] >= 0))
        self.assertTrue(np.all(points[:, 0] < image_size[1]))
        self.assertTrue(np.all(points[:, 1] >= 0))
        self.assertTrue(np.all(points[:, 1] < image_size[0])
)
    
    def test_format_roi_statistics_valid(self):
        """Test ROI statistics formatting with valid stats."""
        stats = {
            "area_pixels": 1000,
            "area_percentage": 10.5,
            "mean_intensity": 128.5,
            "std_intensity": 25.3,
            "min_intensity": 50.0,
            "max_intensity": 200.0,
            "centroid": (100.5, 150.2),
            "bounding_box": (50, 75, 150, 225)
        }
        formatted = format_roi_statistics(stats)
        self.assertIsInstance(formatted, str)
        self.assertIn("1000", formatted)
        self.assertIn("10.5", formatted)
    
    def test_format_roi_statistics_error(self):
        """Test ROI statistics formatting with error."""
        stats = {
            "error": "No valid mask available",
            "area_pixels": 0
        }
        formatted = format_roi_statistics(stats)
        self.assertIsInstance(formatted, str)
        self.assertIn("Error", formatted)


if __name__ == "__main__":
    # Script entry point: discover and run the TestCase classes in this module.
    unittest.main()