NeuroSAM3 / tests /test_validators.py
mmrech's picture
Refactor codebase: Add modular structure, logging, validation, and comprehensive improvements
69066c5
raw
history blame
7.91 kB
"""
Tests for validators module.
"""
import unittest
import os
import tempfile
import numpy as np
from validators import (
validate_file_path,
validate_file_size,
validate_file_extension,
validate_image_file,
validate_threshold,
validate_mask_threshold,
validate_coordinates,
validate_bounding_box,
validate_num_masks,
validate_prompt_text,
validate_modality,
validate_transparency,
validate_brightness_contrast,
ValidationError,
)
class TestValidators(unittest.TestCase):
    """Test cases for input validation functions.

    Most ``validate_*`` helpers under test return an ``(is_valid, error)``
    pair; ``validate_prompt_text`` additionally returns the resolved prompt.
    Multi-input tests use ``subTest`` so every failing input is reported,
    not just the first one.
    """

    def setUp(self):
        """Create a small temporary ``.png`` file used by the file tests."""
        # delete=False keeps the file on disk after close() so validators can
        # reopen it by path; tearDown removes it.
        self.temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        self.temp_file.write(b'test content')
        self.temp_file.close()
        self.temp_path = self.temp_file.name

    def tearDown(self):
        """Remove the temporary file created in ``setUp``."""
        if os.path.exists(self.temp_path):
            os.unlink(self.temp_path)

    def test_validate_file_path_valid(self):
        """Test file path validation with valid file."""
        is_valid, error = validate_file_path(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_file_path_none(self):
        """Test file path validation with None."""
        is_valid, error = validate_file_path(None)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)

    def test_validate_file_path_not_exists(self):
        """Test file path validation with non-existent file."""
        is_valid, error = validate_file_path("/nonexistent/file.png")
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)

    def test_validate_file_size_valid(self):
        """Test file size validation with valid file."""
        is_valid, error = validate_file_size(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_file_extension_valid(self):
        """Test file extension validation with valid extension."""
        is_valid, error = validate_file_extension(self.temp_path)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_file_extension_invalid(self):
        """Test file extension validation with invalid extension."""
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
        temp_file.close()
        # Register removal before asserting so the file is deleted even when
        # one of the assertions below fails.
        self.addCleanup(os.unlink, temp_file.name)
        is_valid, error = validate_file_extension(temp_file.name)
        self.assertFalse(is_valid)
        self.assertIsNotNone(error)

    def test_validate_threshold_valid(self):
        """Test threshold validation with valid values."""
        for threshold in [0.0, 0.1, 0.5, 1.0]:
            with self.subTest(threshold=threshold):
                is_valid, error = validate_threshold(threshold)
                self.assertTrue(is_valid, f"Threshold {threshold} should be valid")
                self.assertIsNone(error)

    def test_validate_threshold_invalid(self):
        """Test threshold validation with invalid values."""
        for threshold in [-0.1, 1.1, "invalid"]:
            with self.subTest(threshold=threshold):
                is_valid, error = validate_threshold(threshold)
                self.assertFalse(is_valid, f"Threshold {threshold} should be invalid")
                self.assertIsNotNone(error)

    def test_validate_coordinates_valid(self):
        """Test coordinate validation with valid values."""
        is_valid, error = validate_coordinates(100, 200)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_coordinates_invalid(self):
        """Test coordinate validation with invalid values."""
        cases = {
            "negative coordinates": (-1, 100),
            "too large coordinates": (20000, 100),
        }
        for label, (x, y) in cases.items():
            with self.subTest(label, x=x, y=y):
                is_valid, error = validate_coordinates(x, y)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_bounding_box_valid(self):
        """Test bounding box validation with valid values."""
        is_valid, error = validate_bounding_box(10, 20, 100, 200)
        self.assertTrue(is_valid)
        self.assertIsNone(error)

    def test_validate_bounding_box_invalid(self):
        """Test bounding box validation with invalid values."""
        cases = {
            "x2 <= x1": (100, 20, 50, 200),
            "y2 <= y1": (10, 200, 100, 50),
        }
        for label, box in cases.items():
            with self.subTest(label, box=box):
                is_valid, error = validate_bounding_box(*box)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_num_masks_valid(self):
        """Test num masks validation with valid values."""
        for num in [1, 3, 5]:
            with self.subTest(num=num):
                is_valid, error = validate_num_masks(num)
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_num_masks_invalid(self):
        """Test num masks validation with invalid values."""
        for num in [0, 6, -1]:
            with self.subTest(num=num):
                is_valid, error = validate_num_masks(num)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_prompt_text_valid(self):
        """Test prompt text validation with valid values."""
        is_valid, error, prompt = validate_prompt_text("brain")
        self.assertTrue(is_valid)
        self.assertIsNone(error)
        self.assertEqual(prompt, "brain")

    def test_validate_prompt_text_none(self):
        """Test prompt text validation with None (should use default)."""
        is_valid, error, prompt = validate_prompt_text(None)
        self.assertTrue(is_valid)
        self.assertEqual(prompt, "brain")  # Default

    def test_validate_prompt_text_empty(self):
        """Test prompt text validation with empty or whitespace-only strings.

        Both should fall back to the default prompt.
        """
        for text in ["", " "]:
            with self.subTest(text=text):
                is_valid, error, prompt = validate_prompt_text(text)
                self.assertTrue(is_valid)
                self.assertEqual(prompt, "brain")  # Default

    def test_validate_modality_valid(self):
        """Test modality validation with valid values."""
        for modality in ["CT", "MRI", "ct", "mri"]:
            with self.subTest(modality=modality):
                is_valid, error = validate_modality(modality)
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_modality_invalid(self):
        """Test modality validation with invalid values."""
        for modality in [None, "invalid", "XRAY"]:
            with self.subTest(modality=modality):
                is_valid, error = validate_modality(modality)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_transparency_valid(self):
        """Test transparency validation with valid values."""
        for trans in [0.0, 0.5, 1.0]:
            with self.subTest(transparency=trans):
                is_valid, error = validate_transparency(trans)
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_transparency_invalid(self):
        """Test transparency validation with invalid values."""
        for trans in [-0.1, 1.1, "invalid"]:
            with self.subTest(transparency=trans):
                is_valid, error = validate_transparency(trans)
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)

    def test_validate_brightness_contrast_valid(self):
        """Test brightness/contrast validation with valid values."""
        for val in [0.0, 1.0, 2.0, 3.0]:
            with self.subTest(value=val):
                is_valid, error = validate_brightness_contrast(val, "test")
                self.assertTrue(is_valid)
                self.assertIsNone(error)

    def test_validate_brightness_contrast_invalid(self):
        """Test brightness/contrast validation with invalid values."""
        for val in [-0.1, 3.1, "invalid"]:
            with self.subTest(value=val):
                is_valid, error = validate_brightness_contrast(val, "test")
                self.assertFalse(is_valid)
                self.assertIsNotNone(error)
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly instead of
    # being collected by a test runner.
    unittest.main()