| """ |
| Tests for validators module. |
| """ |
|
|
| import unittest |
| import os |
| import tempfile |
| import numpy as np |
| from validators import ( |
| validate_file_path, |
| validate_file_size, |
| validate_file_extension, |
| validate_image_file, |
| validate_threshold, |
| validate_mask_threshold, |
| validate_coordinates, |
| validate_bounding_box, |
| validate_num_masks, |
| validate_prompt_text, |
| validate_modality, |
| validate_transparency, |
| validate_brightness_contrast, |
| ValidationError, |
| ) |
|
|
|
|
class TestValidators(unittest.TestCase):
    """Test cases for the input validation functions in ``validators``.

    Most validators under test return an ``(is_valid, error)`` pair, where
    ``error`` is ``None`` on success and non-``None`` on failure.
    ``validate_prompt_text`` additionally returns the resulting prompt.
    """

    def setUp(self):
        """Create a small temporary ``.png`` file for the file-level tests.

        The file only holds placeholder bytes; it is not a decodable image.
        """
        # delete=False keeps the file on disk after close() so the validators
        # can reopen it by path; tearDown removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            temp_file.write(b'test content')
        self.temp_file = temp_file
        self.temp_path = temp_file.name

    def tearDown(self):
        """Remove the temporary file created in setUp, if it still exists."""
        if os.path.exists(self.temp_path):
            os.unlink(self.temp_path)

    def _assert_accepted(self, result, msg=None):
        """Assert an ``(is_valid, error)`` result reports success."""
        is_valid, error = result
        self.assertTrue(is_valid, msg)
        self.assertIsNone(error, msg)

    def _assert_rejected(self, result, msg=None):
        """Assert an ``(is_valid, error)`` result reports failure with an error."""
        is_valid, error = result
        self.assertFalse(is_valid, msg)
        self.assertIsNotNone(error, msg)

    def test_validate_file_path_valid(self):
        """An existing file path is accepted."""
        self._assert_accepted(validate_file_path(self.temp_path))

    def test_validate_file_path_none(self):
        """A ``None`` path is rejected with an error."""
        self._assert_rejected(validate_file_path(None))

    def test_validate_file_path_not_exists(self):
        """A path to a non-existent file is rejected with an error."""
        self._assert_rejected(validate_file_path("/nonexistent/file.png"))

    def test_validate_file_size_valid(self):
        """A small (12-byte) file passes the size check."""
        self._assert_accepted(validate_file_size(self.temp_path))

    def test_validate_file_extension_valid(self):
        """A ``.png`` extension is accepted."""
        self._assert_accepted(validate_file_extension(self.temp_path))

    def test_validate_file_extension_invalid(self):
        """A ``.txt`` extension is rejected with an error."""
        with tempfile.NamedTemporaryFile(delete=False, suffix='.txt') as temp_file:
            pass
        # Register removal before asserting so the file is deleted even when
        # an assertion below fails.
        self.addCleanup(os.unlink, temp_file.name)
        self._assert_rejected(validate_file_extension(temp_file.name))

    def test_validate_threshold_valid(self):
        """Thresholds inside [0.0, 1.0], including both ends, are accepted."""
        for threshold in [0.0, 0.1, 0.5, 1.0]:
            with self.subTest(threshold=threshold):
                self._assert_accepted(
                    validate_threshold(threshold),
                    f"Threshold {threshold} should be valid",
                )

    def test_validate_threshold_invalid(self):
        """Out-of-range and non-numeric thresholds are rejected."""
        for threshold in [-0.1, 1.1, "invalid"]:
            with self.subTest(threshold=threshold):
                self._assert_rejected(
                    validate_threshold(threshold),
                    f"Threshold {threshold} should be invalid",
                )

    def test_validate_coordinates_valid(self):
        """Ordinary non-negative coordinates are accepted."""
        self._assert_accepted(validate_coordinates(100, 200))

    def test_validate_coordinates_invalid(self):
        """A negative coordinate and a very large one are rejected."""
        # NOTE(review): 20000 is assumed to exceed the validator's maximum
        # coordinate; confirm against validators.validate_coordinates.
        for x, y in [(-1, 100), (20000, 100)]:
            with self.subTest(x=x, y=y):
                self._assert_rejected(validate_coordinates(x, y))

    def test_validate_bounding_box_valid(self):
        """A box whose first corner precedes its second on both axes is accepted."""
        self._assert_accepted(validate_bounding_box(10, 20, 100, 200))

    def test_validate_bounding_box_invalid(self):
        """Boxes with inverted x or inverted y extents are rejected."""
        # Assumes argument order (x_min, y_min, x_max, y_max) -- TODO confirm.
        cases = [
            (100, 20, 50, 200),  # x_min > x_max
            (10, 200, 100, 50),  # y_min > y_max
        ]
        for box in cases:
            with self.subTest(box=box):
                self._assert_rejected(validate_bounding_box(*box))

    def test_validate_num_masks_valid(self):
        """Mask counts 1, 3 and 5 are accepted."""
        for num in [1, 3, 5]:
            with self.subTest(num=num):
                self._assert_accepted(validate_num_masks(num))

    def test_validate_num_masks_invalid(self):
        """Mask counts 0, 6 and -1 are rejected."""
        for num in [0, 6, -1]:
            with self.subTest(num=num):
                self._assert_rejected(validate_num_masks(num))

    def test_validate_prompt_text_valid(self):
        """A non-blank prompt is accepted and returned unchanged."""
        is_valid, error, prompt = validate_prompt_text("brain")
        self.assertTrue(is_valid)
        self.assertIsNone(error)
        self.assertEqual(prompt, "brain")

    def test_validate_prompt_text_none(self):
        """A ``None`` prompt is accepted and replaced by the default "brain"."""
        is_valid, error, prompt = validate_prompt_text(None)
        self.assertTrue(is_valid)
        self.assertEqual(prompt, "brain")

    def test_validate_prompt_text_empty(self):
        """A whitespace-only prompt is accepted and replaced by the default."""
        is_valid, error, prompt = validate_prompt_text(" ")
        self.assertTrue(is_valid)
        self.assertEqual(prompt, "brain")

    def test_validate_modality_valid(self):
        """CT and MRI are accepted in either letter case."""
        for modality in ["CT", "MRI", "ct", "mri"]:
            with self.subTest(modality=modality):
                self._assert_accepted(validate_modality(modality))

    def test_validate_modality_invalid(self):
        """``None`` and unsupported modality names are rejected."""
        for modality in [None, "invalid", "XRAY"]:
            with self.subTest(modality=modality):
                self._assert_rejected(validate_modality(modality))

    def test_validate_transparency_valid(self):
        """Transparency values in [0.0, 1.0], inclusive, are accepted."""
        for trans in [0.0, 0.5, 1.0]:
            with self.subTest(trans=trans):
                self._assert_accepted(validate_transparency(trans))

    def test_validate_transparency_invalid(self):
        """Out-of-range and non-numeric transparency values are rejected."""
        for trans in [-0.1, 1.1, "invalid"]:
            with self.subTest(trans=trans):
                self._assert_rejected(validate_transparency(trans))

    def test_validate_brightness_contrast_valid(self):
        """Brightness/contrast values in [0.0, 3.0], inclusive, are accepted."""
        for val in [0.0, 1.0, 2.0, 3.0]:
            with self.subTest(val=val):
                self._assert_accepted(validate_brightness_contrast(val, "test"))

    def test_validate_brightness_contrast_invalid(self):
        """Out-of-range and non-numeric brightness/contrast values are rejected."""
        for val in [-0.1, 3.1, "invalid"]:
            with self.subTest(val=val):
                self._assert_rejected(validate_brightness_contrast(val, "test"))
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not imported): hand control to the unittest CLI,
    # which discovers and runs every TestCase defined in this module.
    unittest.main()
|
|
|
|