File size: 7,914 Bytes
69066c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""
Tests for validators module.
"""

import unittest
import os
import tempfile
import numpy as np
from validators import (
    validate_file_path,
    validate_file_size,
    validate_file_extension,
    validate_image_file,
    validate_threshold,
    validate_mask_threshold,
    validate_coordinates,
    validate_bounding_box,
    validate_num_masks,
    validate_prompt_text,
    validate_modality,
    validate_transparency,
    validate_brightness_contrast,
    ValidationError,
)


class TestValidators(unittest.TestCase):
    """Test cases for input validation functions."""
    
    def setUp(self):
        """Set up test fixtures."""
        self.temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        self.temp_file.write(b'test content')
        self.temp_file.close()
        self.temp_path = self.temp_file.name
    
    def tearDown(self):
        """Clean up test fixtures."""
        if os.path.exists(self.temp_path):
            os.unlink(self.temp_path)
    
    def test_validate_file_path_valid(self):
        """Test file path validation with valid file."""
        is_valid, error = validate_file_path(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)
    
    def test_validate_file_path_none(self):
        """Test file path validation with None."""
        is_valid, error = validate_file_path(None)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)
    
    def test_validate_file_path_not_exists(self):
        """Test file path validation with non-existent file."""
        is_valid, error = validate_file_path("/nonexistent/file.png")
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)
    
    def test_validate_file_size_valid(self):
        """Test file size validation with valid file."""
        is_valid, error = validate_file_size(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)
    
    def test_validate_file_extension_valid(self):
        """Test file extension validation with valid extension."""
        is_valid, error = validate_file_extension(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)
    
    def test_validate_file_extension_invalid(self):
        """Test file extension validation with invalid extension."""
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
        temp_file.close()
        is_valid, error = validate_file_extension(temp_file.name)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)
        os.unlink(temp_file.name)
    
    def test_validate_threshold_valid(self):
        """Test threshold validation with valid values."""
        for threshold in [0.0, 0.1, 0.5, 1.0]:
            is_valid, error = validate_threshold(threshold)
            self.assertTrue(is_valid, f"Threshold {threshold} should be valid")
            self.assertIsNone(error)
    
    def test_validate_threshold_invalid(self):
        """Test threshold validation with invalid values."""
        for threshold in [-0.1, 1.1, "invalid"]:
            is_valid, error = validate_threshold(threshold)
            self.assertFalse(is_valid, f"Threshold {threshold} should be invalid")
            self.assertIsNotNone(error)
    
    def test_validate_coordinates_valid(self):
        """Test coordinate validation with valid values."""
        is_valid, error = validate_coordinates(100, 200)
        self.assertTrue(is_valid)
        self.assertIsNone(error)
    
    def test_validate_coordinates_invalid(self):
        """Test coordinate validation with invalid values."""
        # Negative coordinates
        is_valid, error = validate_coordinates(-1, 100)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)
        
        # Too large coordinates
        is_valid, error = validate_coordinates(20000, 100)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)
    
    def test_validate_bounding_box_valid(self):
        """Test bounding box validation with valid values."""
        is_valid, error = validate_bounding_box(10, 20, 100, 200)
        self.assertTrue(is_valid)
        self.assertIsNone(error)
    
    def test_validate_bounding_box_invalid(self):
        """Test bounding box validation with invalid values."""
        # x2 <= x1
        is_valid, error = validate_bounding_box(100, 20, 50, 200)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)
        
        # y2 <= y1
        is_valid, error = validate_bounding_box(10, 200, 100, 50)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)
    
    def test_validate_num_masks_valid(self):
        """Test num masks validation with valid values."""
        for num in [1, 3, 5]:
            is_valid, error = validate_num_masks(num)
            self.assertTrue(is_valid)
            self.assertIsNone(error)
    
    def test_validate_num_masks_invalid(self):
        """Test num masks validation with invalid values."""
        for num in [0, 6, -1]:
            is_valid, error = validate_num_masks(num)
            self.assertFalse(is_valid)
            self.assertIsNotNone(error)
    
    def test_validate_prompt_text_valid(self):
        """Test prompt text validation with valid values."""
        is_valid, error, prompt = validate_prompt_text("brain")
        self.assertTrue(is_valid)
        self.assertIsNone(error)
        self.assertEqual(prompt, "brain")
    
    def test_validate_prompt_text_none(self):
        """Test prompt text validation with None (should use default)."""
        is_valid, error, prompt = validate_prompt_text(None)
        self.assertTrue(is_valid)
        self.assertEqual(prompt, "brain")  # Default
    
    def test_validate_prompt_text_empty(self):
        """Test prompt text validation with empty string (should use default)."""
        is_valid, error, prompt = validate_prompt_text("   ")
        self.assertTrue(is_valid)
        self.assertEqual(prompt, "brain")  # Default
    
    def test_validate_modality_valid(self):
        """Test modality validation with valid values."""
        for modality in ["CT", "MRI", "ct", "mri"]:
            is_valid, error = validate_modality(modality)
            self.assertTrue(is_valid)
            self.assertIsNone(error)
    
    def test_validate_modality_invalid(self):
        """Test modality validation with invalid values."""
        for modality in [None, "invalid", "XRAY"]:
            is_valid, error = validate_modality(modality)
            self.assertFalse(is_valid)
            self.assertIsNotNone(error)
    
    def test_validate_transparency_valid(self):
        """Test transparency validation with valid values."""
        for trans in [0.0, 0.5, 1.0]:
            is_valid, error = validate_transparency(trans)
            self.assertTrue(is_valid)
            self.assertIsNone(error)
    
    def test_validate_transparency_invalid(self):
        """Test transparency validation with invalid values."""
        for trans in [-0.1, 1.1, "invalid"]:
            is_valid, error = validate_transparency(trans)
            self.assertFalse(is_valid)
            self.assertIsNotNone(error)
    
    def test_validate_brightness_contrast_valid(self):
        """Test brightness/contrast validation with valid values."""
        for val in [0.0, 1.0, 2.0, 3.0]:
            is_valid, error = validate_brightness_contrast(val, "test")
            self.assertTrue(is_valid)
            self.assertIsNone(error)
    
    def test_validate_brightness_contrast_invalid(self):
        """Test brightness/contrast validation with invalid values."""
        for val in [-0.1, 3.1, "invalid"]:
            is_valid, error = validate_brightness_contrast(val, "test")
            self.assertFalse(is_valid)
            self.assertIsNotNone(error)


if __name__ == "__main__":
    # Execute the test suite only when this module is run as a script,
    # not when it is imported by a test runner.
    unittest.main()