|
|
""" |
|
|
Tests for validators module. |
|
|
""" |
|
|
|
|
|
import unittest |
|
|
import os |
|
|
import tempfile |
|
|
import numpy as np |
|
|
from validators import ( |
|
|
validate_file_path, |
|
|
validate_file_size, |
|
|
validate_file_extension, |
|
|
validate_image_file, |
|
|
validate_threshold, |
|
|
validate_mask_threshold, |
|
|
validate_coordinates, |
|
|
validate_bounding_box, |
|
|
validate_num_masks, |
|
|
validate_prompt_text, |
|
|
validate_modality, |
|
|
validate_transparency, |
|
|
validate_brightness_contrast, |
|
|
ValidationError, |
|
|
) |
|
|
|
|
|
|
|
|
class TestValidators(unittest.TestCase):
    """Test cases for input validation functions.

    Most validators under test return an ``(is_valid, error)`` pair, where
    ``error`` is ``None`` on success. ``validate_prompt_text`` additionally
    returns the (possibly defaulted) prompt as a third element.

    Multi-input tests run each input inside ``self.subTest`` so a failure
    reports exactly which value was rejected or accepted unexpectedly.
    """

    def setUp(self):
        """Create a small temporary ``.png`` file for the file-based tests."""
        # delete=False keeps the file on disk after close(), so the validators
        # can open it by path; tearDown removes it.
        self.temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        self.temp_file.write(b'test content')
        self.temp_file.close()
        self.temp_path = self.temp_file.name

    def tearDown(self):
        """Remove the temporary file created in setUp, if it still exists."""
        if os.path.exists(self.temp_path):
            os.unlink(self.temp_path)

    def test_validate_file_path_valid(self):
        """Test file path validation with valid file."""
        is_valid, error = validate_file_path(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_file_path_none(self):
        """Test file path validation with None."""
        is_valid, error = validate_file_path(None)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)

    def test_validate_file_path_not_exists(self):
        """Test file path validation with non-existent file."""
        # Build the path inside a fresh, empty temporary directory so it is
        # guaranteed not to exist on whatever machine runs the suite.
        with tempfile.TemporaryDirectory() as temp_dir:
            missing_path = os.path.join(temp_dir, 'nonexistent', 'file.png')
            is_valid, error = validate_file_path(missing_path)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)

    def test_validate_file_size_valid(self):
        """Test file size validation with valid file."""
        is_valid, error = validate_file_size(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_file_extension_valid(self):
        """Test file extension validation with valid extension."""
        is_valid, error = validate_file_extension(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_file_extension_invalid(self):
        """Test file extension validation with invalid extension."""
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
        temp_file.close()
        # Register removal immediately so the file is deleted even if an
        # assertion below fails.
        self.addCleanup(os.unlink, temp_file.name)

        is_valid, error = validate_file_extension(temp_file.name)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)

    def test_validate_threshold_valid(self):
        """Test threshold validation with valid values."""
        for threshold in [0.0, 0.1, 0.5, 1.0]:
            with self.subTest(threshold=threshold):
                is_valid, error = validate_threshold(threshold)
                self.assertTrue(is_valid, f"Threshold {threshold} should be valid")
                self.assertIsNone(error)

    def test_validate_threshold_invalid(self):
        """Test threshold validation with invalid values."""
        # Out-of-range on both sides, plus a non-numeric value.
        for threshold in [-0.1, 1.1, "invalid"]:
            with self.subTest(threshold=threshold):
                is_valid, error = validate_threshold(threshold)
                self.assertFalse(is_valid, f"Threshold {threshold} should be invalid")
                self.assertIsNotNone(error)

    def test_validate_coordinates_valid(self):
        """Test coordinate validation with valid values."""
        is_valid, error = validate_coordinates(100, 200)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_coordinates_invalid(self):
        """Test coordinate validation with invalid values."""
        # A negative x, then an x assumed to exceed the validator's upper bound.
        for x, y in [(-1, 100), (20000, 100)]:
            with self.subTest(x=x, y=y):
                is_valid, error = validate_coordinates(x, y)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_bounding_box_valid(self):
        """Test bounding box validation with valid values."""
        is_valid, error = validate_bounding_box(10, 20, 100, 200)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_bounding_box_invalid(self):
        """Test bounding box validation with invalid values."""
        # Relative to the valid box (10, 20, 100, 200): the first case swaps the
        # x ordering, the second swaps the y ordering (assuming an
        # (x1, y1, x2, y2) argument order).
        for box in [(100, 20, 50, 200), (10, 200, 100, 50)]:
            with self.subTest(box=box):
                is_valid, error = validate_bounding_box(*box)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_num_masks_valid(self):
        """Test num masks validation with valid values."""
        for num in [1, 3, 5]:
            with self.subTest(num=num):
                is_valid, error = validate_num_masks(num)
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_num_masks_invalid(self):
        """Test num masks validation with invalid values."""
        for num in [0, 6, -1]:
            with self.subTest(num=num):
                is_valid, error = validate_num_masks(num)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_prompt_text_valid(self):
        """Test prompt text validation with valid values."""
        is_valid, error, prompt = validate_prompt_text("brain")
        self.assertTrue(is_valid)
        self.assertIsNone(error)
        self.assertEqual(prompt, "brain")

    def test_validate_prompt_text_none(self):
        """Test prompt text validation with None (should use default)."""
        is_valid, error, prompt = validate_prompt_text(None)
        self.assertTrue(is_valid)
        self.assertEqual(prompt, "brain")

    def test_validate_prompt_text_empty(self):
        """Test prompt text validation with empty string (should use default)."""
        is_valid, error, prompt = validate_prompt_text(" ")
        self.assertTrue(is_valid)
        self.assertEqual(prompt, "brain")

    def test_validate_modality_valid(self):
        """Test modality validation with valid values."""
        # Both upper- and lower-case spellings are expected to be accepted.
        for modality in ["CT", "MRI", "ct", "mri"]:
            with self.subTest(modality=modality):
                is_valid, error = validate_modality(modality)
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_modality_invalid(self):
        """Test modality validation with invalid values."""
        for modality in [None, "invalid", "XRAY"]:
            with self.subTest(modality=modality):
                is_valid, error = validate_modality(modality)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_transparency_valid(self):
        """Test transparency validation with valid values."""
        for trans in [0.0, 0.5, 1.0]:
            with self.subTest(transparency=trans):
                is_valid, error = validate_transparency(trans)
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_transparency_invalid(self):
        """Test transparency validation with invalid values."""
        for trans in [-0.1, 1.1, "invalid"]:
            with self.subTest(transparency=trans):
                is_valid, error = validate_transparency(trans)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_brightness_contrast_valid(self):
        """Test brightness/contrast validation with valid values."""
        # Inclusive bounds 0.0 and 3.0 are exercised alongside interior values.
        for val in [0.0, 1.0, 2.0, 3.0]:
            with self.subTest(value=val):
                is_valid, error = validate_brightness_contrast(val, "test")
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_brightness_contrast_invalid(self):
        """Test brightness/contrast validation with invalid values."""
        for val in [-0.1, 3.1, "invalid"]:
            with self.subTest(value=val):
                is_valid, error = validate_brightness_contrast(val, "test")
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)
|
|
|
|
|
|
|
|
if "__main__" == __name__:
    # Running this file directly executes every TestCase defined in this
    # module; module="__main__" is unittest.main's default, spelled out here.
    unittest.main(module="__main__")
|
|
|
|
|
|