"""Tests for the segmentation module."""

import unittest
import numpy as np
import tempfile
from PIL import Image

from segmentation import (
    calculate_dice_score,
    calculate_iou_score,
    generate_grid_points,
    format_roi_statistics,
)


class TestSegmentation(unittest.TestCase):
    """Test cases for the overlap metrics, grid sampling and ROI formatting helpers."""

    @staticmethod
    def _square_mask(row_range, col_range, shape=(10, 10)):
        """Return a boolean mask of ``shape`` that is True only inside the given ranges.

        Args:
            row_range: ``(start, stop)`` half-open row interval to set to True.
            col_range: ``(start, stop)`` half-open column interval to set to True.
            shape: Shape of the mask to allocate.

        Returns:
            A ``bool`` ndarray with a single filled rectangle.
        """
        mask = np.zeros(shape, dtype=bool)
        mask[row_range[0]:row_range[1], col_range[0]:col_range[1]] = True
        return mask

    # ------------------------------------------------------------------
    # Dice score
    # ------------------------------------------------------------------

    def test_calculate_dice_score_perfect_match(self):
        """Two identical, fully-set masks must give a Dice score of 1."""
        mask1 = np.ones((10, 10), dtype=bool)
        mask2 = np.ones((10, 10), dtype=bool)

        dice = calculate_dice_score(mask1, mask2)

        # Scores are floats: compare with a tolerance rather than exact equality.
        self.assertAlmostEqual(dice, 1.0)

    def test_calculate_dice_score_no_overlap(self):
        """Two disjoint masks must give a Dice score of 0."""
        mask1 = self._square_mask((0, 5), (0, 5))
        mask2 = self._square_mask((5, 10), (5, 10))

        dice = calculate_dice_score(mask1, mask2)

        self.assertAlmostEqual(dice, 0.0)

    def test_calculate_dice_score_partial_overlap(self):
        """Partially overlapping masks must give a Dice score strictly between 0 and 1."""
        # Two 7x7 squares sharing the 4x4 region rows 3..6, cols 3..6.
        mask1 = self._square_mask((0, 7), (0, 7))
        mask2 = self._square_mask((3, 10), (3, 10))

        dice = calculate_dice_score(mask1, mask2)

        self.assertGreater(dice, 0.0)
        self.assertLess(dice, 1.0)

    # ------------------------------------------------------------------
    # IoU score
    # ------------------------------------------------------------------

    def test_calculate_iou_score_perfect_match(self):
        """Two identical, fully-set masks must give an IoU score of 1."""
        mask1 = np.ones((10, 10), dtype=bool)
        mask2 = np.ones((10, 10), dtype=bool)

        iou = calculate_iou_score(mask1, mask2)

        self.assertAlmostEqual(iou, 1.0)

    def test_calculate_iou_score_no_overlap(self):
        """Two disjoint masks must give an IoU score of 0."""
        mask1 = self._square_mask((0, 5), (0, 5))
        mask2 = self._square_mask((5, 10), (5, 10))

        iou = calculate_iou_score(mask1, mask2)

        self.assertAlmostEqual(iou, 0.0)

    def test_calculate_iou_score_partial_overlap(self):
        """Partially overlapping masks must give an IoU score strictly between 0 and 1."""
        # Same geometry as the Dice partial-overlap test, for parity.
        mask1 = self._square_mask((0, 7), (0, 7))
        mask2 = self._square_mask((3, 10), (3, 10))

        iou = calculate_iou_score(mask1, mask2)

        self.assertGreater(iou, 0.0)
        self.assertLess(iou, 1.0)

    # ------------------------------------------------------------------
    # Grid point generation
    # ------------------------------------------------------------------

    def test_generate_grid_points(self):
        """The grid must contain points_per_side**2 two-component points inside the image."""
        image_size = (100, 200)
        points_per_side = 10

        points = generate_grid_points(image_size, points_per_side)

        self.assertEqual(points.shape[0], points_per_side * points_per_side)
        self.assertEqual(points.shape[1], 2)

        # Check that points are within image bounds.
        # NOTE(review): the bounds below treat column 0 as bounded by
        # image_size[1] and column 1 by image_size[0], i.e. presumably
        # points are (x, y) and image_size is (height, width) -- confirm
        # against generate_grid_points.
        self.assertTrue(np.all(points[:, 0] >= 0))
        self.assertTrue(np.all(points[:, 0] < image_size[1]))
        self.assertTrue(np.all(points[:, 1] >= 0))
        self.assertTrue(np.all(points[:, 1] < image_size[0]))

    # ------------------------------------------------------------------
    # ROI statistics formatting
    # ------------------------------------------------------------------

    def test_format_roi_statistics_valid(self):
        """A complete stats dict must format to a string containing its key values."""
        stats = {
            "area_pixels": 1000,
            "area_percentage": 10.5,
            "mean_intensity": 128.5,
            "std_intensity": 25.3,
            "min_intensity": 50.0,
            "max_intensity": 200.0,
            "centroid": (100.5, 150.2),
            "bounding_box": (50, 75, 150, 225),
        }

        formatted = format_roi_statistics(stats)

        self.assertIsInstance(formatted, str)
        self.assertIn("1000", formatted)
        self.assertIn("10.5", formatted)

    def test_format_roi_statistics_error(self):
        """A stats dict carrying an "error" key must format to a string mentioning "Error"."""
        stats = {
            "error": "No valid mask available",
            "area_pixels": 0,
        }

        formatted = format_roi_statistics(stats)

        self.assertIsInstance(formatted, str)
        self.assertIn("Error", formatted)


if __name__ == '__main__':
    unittest.main()