""" Tests for validators module. """ import unittest import os import tempfile import numpy as np from validators import ( validate_file_path, validate_file_size, validate_file_extension, validate_image_file, validate_threshold, validate_mask_threshold, validate_coordinates, validate_bounding_box, validate_num_masks, validate_prompt_text, validate_modality, validate_transparency, validate_brightness_contrast, ValidationError, ) class TestValidators(unittest.TestCase): """Test cases for input validation functions.""" def setUp(self): """Set up test fixtures.""" self.temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') self.temp_file.write(b'test content') self.temp_file.close() self.temp_path = self.temp_file.name def tearDown(self): """Clean up test fixtures.""" if os.path.exists(self.temp_path): os.unlink(self.temp_path) def test_validate_file_path_valid(self): """Test file path validation with valid file.""" is_valid, error = validate_file_path(self.temp_path) self.assertTrue(is_valid) self.assertIsNone(error) def test_validate_file_path_none(self): """Test file path validation with None.""" is_valid, error = validate_file_path(None) self.assertFalse(is_valid) self.assertIsNotNone(error) def test_validate_file_path_not_exists(self): """Test file path validation with non-existent file.""" is_valid, error = validate_file_path("/nonexistent/file.png") self.assertFalse(is_valid) self.assertIsNotNone(error) def test_validate_file_size_valid(self): """Test file size validation with valid file.""" is_valid, error = validate_file_size(self.temp_path) self.assertTrue(is_valid) self.assertIsNone(error) def test_validate_file_extension_valid(self): """Test file extension validation with valid extension.""" is_valid, error = validate_file_extension(self.temp_path) self.assertTrue(is_valid) self.assertIsNone(error) def test_validate_file_extension_invalid(self): """Test file extension validation with invalid extension.""" temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') temp_file.close() is_valid, error = validate_file_extension(temp_file.name) self.assertFalse(is_valid) self.assertIsNotNone(error) os.unlink(temp_file.name) def test_validate_threshold_valid(self): """Test threshold validation with valid values.""" for threshold in [0.0, 0.1, 0.5, 1.0]: is_valid, error = validate_threshold(threshold) self.assertTrue(is_valid, f"Threshold {threshold} should be valid") self.assertIsNone(error) def test_validate_threshold_invalid(self): """Test threshold validation with invalid values.""" for threshold in [-0.1, 1.1, "invalid"]: is_valid, error = validate_threshold(threshold) self.assertFalse(is_valid, f"Threshold {threshold} should be invalid") self.assertIsNotNone(error) def test_validate_coordinates_valid(self): """Test coordinate validation with valid values.""" is_valid, error = validate_coordinates(100, 200) self.assertTrue(is_valid) self.assertIsNone(error) def test_validate_coordinates_invalid(self): """Test coordinate validation with invalid values.""" # Negative coordinates is_valid, error = validate_coordinates(-1, 100) self.assertFalse(is_valid) self.assertIsNotNone(error) # Too large coordinates is_valid, error = validate_coordinates(20000, 100) self.assertFalse(is_valid) self.assertIsNotNone(error) def test_validate_bounding_box_valid(self): """Test bounding box validation with valid values.""" is_valid, error = validate_bounding_box(10, 20, 100, 200) self.assertTrue(is_valid) self.assertIsNone(error) def test_validate_bounding_box_invalid(self): """Test bounding box validation with invalid values.""" # x2 <= x1 is_valid, error = validate_bounding_box(100, 20, 50, 200) self.assertFalse(is_valid) self.assertIsNotNone(error) # y2 <= y1 is_valid, error = validate_bounding_box(10, 200, 100, 50) self.assertFalse(is_valid) self.assertIsNotNone(error) def test_validate_num_masks_valid(self): """Test num masks validation with valid values.""" for num in [1, 3, 5]: is_valid, error = validate_num_masks(num) self.assertTrue(is_valid) self.assertIsNone(error) def test_validate_num_masks_invalid(self): """Test num masks validation with invalid values.""" for num in [0, 6, -1]: is_valid, error = validate_num_masks(num) self.assertFalse(is_valid) self.assertIsNotNone(error) def test_validate_prompt_text_valid(self): """Test prompt text validation with valid values.""" is_valid, error, prompt = validate_prompt_text("brain") self.assertTrue(is_valid) self.assertIsNone(error) self.assertEqual(prompt, "brain") def test_validate_prompt_text_none(self): """Test prompt text validation with None (should use default).""" is_valid, error, prompt = validate_prompt_text(None) self.assertTrue(is_valid) self.assertEqual(prompt, "brain") # Default def test_validate_prompt_text_empty(self): """Test prompt text validation with empty string (should use default).""" is_valid, error, prompt = validate_prompt_text(" ") self.assertTrue(is_valid) self.assertEqual(prompt, "brain") # Default def test_validate_modality_valid(self): """Test modality validation with valid values.""" for modality in ["CT", "MRI", "ct", "mri"]: is_valid, error = validate_modality(modality) self.assertTrue(is_valid) self.assertIsNone(error) def test_validate_modality_invalid(self): """Test modality validation with invalid values.""" for modality in [None, "invalid", "XRAY"]: is_valid, error = validate_modality(modality) self.assertFalse(is_valid) self.assertIsNotNone(error) def test_validate_transparency_valid(self): """Test transparency validation with valid values.""" for trans in [0.0, 0.5, 1.0]: is_valid, error = validate_transparency(trans) self.assertTrue(is_valid) self.assertIsNone(error) def test_validate_transparency_invalid(self): """Test transparency validation with invalid values.""" for trans in [-0.1, 1.1, "invalid"]: is_valid, error = validate_transparency(trans) self.assertFalse(is_valid) self.assertIsNotNone(error) def test_validate_brightness_contrast_valid(self): """Test brightness/contrast validation with valid values.""" for val in [0.0, 1.0, 2.0, 3.0]: is_valid, error = validate_brightness_contrast(val, "test") self.assertTrue(is_valid) self.assertIsNone(error) def test_validate_brightness_contrast_invalid(self): """Test brightness/contrast validation with invalid values.""" for val in [-0.1, 3.1, "invalid"]: is_valid, error = validate_brightness_contrast(val, "test") self.assertFalse(is_valid) self.assertIsNotNone(error) if __name__ == '__main__': unittest.main()