File size: 7,914 Bytes
69066c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
"""
Tests for validators module.
"""
import unittest
import os
import tempfile
import numpy as np
from validators import (
validate_file_path,
validate_file_size,
validate_file_extension,
validate_image_file,
validate_threshold,
validate_mask_threshold,
validate_coordinates,
validate_bounding_box,
validate_num_masks,
validate_prompt_text,
validate_modality,
validate_transparency,
validate_brightness_contrast,
ValidationError,
)
class TestValidators(unittest.TestCase):
    """Test cases for input validation functions.

    Each ``validate_*`` function under test (except ``validate_prompt_text``)
    is expected to return an ``(is_valid, error)`` pair: ``(True, None)`` on
    success and ``(False, <message>)`` on failure. ``validate_prompt_text``
    returns an extra third element, the (possibly defaulted) prompt.
    """

    def setUp(self):
        """Create a small temporary ``.png`` file that exists on disk."""
        self.temp_file = self._make_temp_file('.png', b'test content')
        self.temp_path = self.temp_file.name

    def _make_temp_file(self, suffix, content=b''):
        """Create a closed temporary file with *suffix* and schedule its removal.

        Removal is registered with ``addCleanup`` before anything is written.
        The file is therefore deleted even if the write, ``setUp``, or a later
        assertion fails. ``tearDown`` would not run when ``setUp`` raises, and
        an inline ``os.unlink`` after the assertions would be skipped when an
        assertion fails.

        Returns the (already closed) ``NamedTemporaryFile`` object.
        """
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        self.addCleanup(self._remove_if_exists, temp_file.name)
        with temp_file:
            if content:
                temp_file.write(content)
        return temp_file

    @staticmethod
    def _remove_if_exists(path):
        """Delete *path* if it is still present (a test may have removed it)."""
        if os.path.exists(path):
            os.unlink(path)

    def test_validate_file_path_valid(self):
        """Test file path validation with valid file."""
        is_valid, error = validate_file_path(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_file_path_none(self):
        """Test file path validation with None."""
        is_valid, error = validate_file_path(None)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)

    def test_validate_file_path_not_exists(self):
        """Test file path validation with non-existent file."""
        is_valid, error = validate_file_path("/nonexistent/file.png")
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)

    def test_validate_file_size_valid(self):
        """Test file size validation with valid file."""
        is_valid, error = validate_file_size(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_file_extension_valid(self):
        """Test file extension validation with valid extension."""
        is_valid, error = validate_file_extension(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_file_extension_invalid(self):
        """Test file extension validation with invalid extension."""
        # The helper registers cleanup up front, so the file is removed
        # even when one of the assertions below fails.
        temp_file = self._make_temp_file('.txt')
        is_valid, error = validate_file_extension(temp_file.name)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)

    def test_validate_threshold_valid(self):
        """Test threshold validation with valid values."""
        for threshold in [0.0, 0.1, 0.5, 1.0]:
            with self.subTest(threshold=threshold):
                is_valid, error = validate_threshold(threshold)
                self.assertTrue(is_valid, f"Threshold {threshold} should be valid")
                self.assertIsNone(error)

    def test_validate_threshold_invalid(self):
        """Test threshold validation with invalid values."""
        for threshold in [-0.1, 1.1, "invalid"]:
            with self.subTest(threshold=threshold):
                is_valid, error = validate_threshold(threshold)
                self.assertFalse(is_valid, f"Threshold {threshold} should be invalid")
                self.assertIsNotNone(error)

    def test_validate_coordinates_valid(self):
        """Test coordinate validation with valid values."""
        is_valid, error = validate_coordinates(100, 200)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_coordinates_invalid(self):
        """Test coordinate validation with invalid values."""
        cases = [
            (-1, 100),     # Negative coordinates
            (20000, 100),  # Too large coordinates (presumably above the validator's max — confirm)
        ]
        for x, y in cases:
            with self.subTest(x=x, y=y):
                is_valid, error = validate_coordinates(x, y)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_bounding_box_valid(self):
        """Test bounding box validation with valid values."""
        is_valid, error = validate_bounding_box(10, 20, 100, 200)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_bounding_box_invalid(self):
        """Test bounding box validation with invalid values."""
        cases = [
            (100, 20, 50, 200),  # x2 <= x1
            (10, 200, 100, 50),  # y2 <= y1
        ]
        for box in cases:
            with self.subTest(box=box):
                is_valid, error = validate_bounding_box(*box)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_num_masks_valid(self):
        """Test num masks validation with valid values."""
        for num in [1, 3, 5]:
            with self.subTest(num=num):
                is_valid, error = validate_num_masks(num)
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_num_masks_invalid(self):
        """Test num masks validation with invalid values."""
        for num in [0, 6, -1]:
            with self.subTest(num=num):
                is_valid, error = validate_num_masks(num)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_prompt_text_valid(self):
        """Test prompt text validation with valid values."""
        is_valid, error, prompt = validate_prompt_text("brain")
        self.assertTrue(is_valid)
        self.assertIsNone(error)
        self.assertEqual(prompt, "brain")

    def test_validate_prompt_text_none(self):
        """Test prompt text validation with None (should use default)."""
        is_valid, error, prompt = validate_prompt_text(None)
        self.assertTrue(is_valid)
        self.assertEqual(prompt, "brain")  # Default

    def test_validate_prompt_text_empty(self):
        """Test prompt text validation with empty string (should use default)."""
        is_valid, error, prompt = validate_prompt_text(" ")
        self.assertTrue(is_valid)
        self.assertEqual(prompt, "brain")  # Default

    def test_validate_modality_valid(self):
        """Test modality validation with valid values."""
        for modality in ["CT", "MRI", "ct", "mri"]:
            with self.subTest(modality=modality):
                is_valid, error = validate_modality(modality)
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_modality_invalid(self):
        """Test modality validation with invalid values."""
        for modality in [None, "invalid", "XRAY"]:
            with self.subTest(modality=modality):
                is_valid, error = validate_modality(modality)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_transparency_valid(self):
        """Test transparency validation with valid values."""
        for trans in [0.0, 0.5, 1.0]:
            with self.subTest(trans=trans):
                is_valid, error = validate_transparency(trans)
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_transparency_invalid(self):
        """Test transparency validation with invalid values."""
        for trans in [-0.1, 1.1, "invalid"]:
            with self.subTest(trans=trans):
                is_valid, error = validate_transparency(trans)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_brightness_contrast_valid(self):
        """Test brightness/contrast validation with valid values."""
        for val in [0.0, 1.0, 2.0, 3.0]:
            with self.subTest(val=val):
                is_valid, error = validate_brightness_contrast(val, "test")
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_brightness_contrast_invalid(self):
        """Test brightness/contrast validation with invalid values."""
        for val in [-0.1, 3.1, "invalid"]:
            with self.subTest(val=val):
                is_valid, error = validate_brightness_contrast(val, "test")
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)
if __name__ == "__main__":
    # Executed directly (not imported): hand control to unittest's CLI runner,
    # which discovers the TestCase classes defined in this module.
    unittest.main()
|