# NeuroSAM3 / tests/test_validators.py
# Last commit 69066c5 (mmrech): Refactor codebase: Add modular structure,
# logging, validation, and comprehensive improvements
"""
Tests for validators module.
"""
import unittest
import os
import tempfile
import numpy as np
from validators import (
validate_file_path,
validate_file_size,
validate_file_extension,
validate_image_file,
validate_threshold,
validate_mask_threshold,
validate_coordinates,
validate_bounding_box,
validate_num_masks,
validate_prompt_text,
validate_modality,
validate_transparency,
validate_brightness_contrast,
ValidationError,
)
class TestValidators(unittest.TestCase):
"""Test cases for input validation functions."""
def setUp(self):
"""Set up test fixtures."""
self.temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
self.temp_file.write(b'test content')
self.temp_file.close()
self.temp_path = self.temp_file.name
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.temp_path):
os.unlink(self.temp_path)
def test_validate_file_path_valid(self):
"""Test file path validation with valid file."""
is_valid, error = validate_file_path(self.temp_path)
self.assertTrue(is_valid)
self.assertIsNone(error)
def test_validate_file_path_none(self):
"""Test file path validation with None."""
is_valid, error = validate_file_path(None)
self.assertFalse(is_valid)
self.assertIsNotNone(error)
def test_validate_file_path_not_exists(self):
"""Test file path validation with non-existent file."""
is_valid, error = validate_file_path("/nonexistent/file.png")
self.assertFalse(is_valid)
self.assertIsNotNone(error)
def test_validate_file_size_valid(self):
"""Test file size validation with valid file."""
is_valid, error = validate_file_size(self.temp_path)
self.assertTrue(is_valid)
self.assertIsNone(error)
def test_validate_file_extension_valid(self):
"""Test file extension validation with valid extension."""
is_valid, error = validate_file_extension(self.temp_path)
self.assertTrue(is_valid)
self.assertIsNone(error)
def test_validate_file_extension_invalid(self):
"""Test file extension validation with invalid extension."""
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
temp_file.close()
is_valid, error = validate_file_extension(temp_file.name)
self.assertFalse(is_valid)
self.assertIsNotNone(error)
os.unlink(temp_file.name)
def test_validate_threshold_valid(self):
"""Test threshold validation with valid values."""
for threshold in [0.0, 0.1, 0.5, 1.0]:
is_valid, error = validate_threshold(threshold)
self.assertTrue(is_valid, f"Threshold {threshold} should be valid")
self.assertIsNone(error)
def test_validate_threshold_invalid(self):
"""Test threshold validation with invalid values."""
for threshold in [-0.1, 1.1, "invalid"]:
is_valid, error = validate_threshold(threshold)
self.assertFalse(is_valid, f"Threshold {threshold} should be invalid")
self.assertIsNotNone(error)
def test_validate_coordinates_valid(self):
"""Test coordinate validation with valid values."""
is_valid, error = validate_coordinates(100, 200)
self.assertTrue(is_valid)
self.assertIsNone(error)
def test_validate_coordinates_invalid(self):
"""Test coordinate validation with invalid values."""
# Negative coordinates
is_valid, error = validate_coordinates(-1, 100)
self.assertFalse(is_valid)
self.assertIsNotNone(error)
# Too large coordinates
is_valid, error = validate_coordinates(20000, 100)
self.assertFalse(is_valid)
self.assertIsNotNone(error)
def test_validate_bounding_box_valid(self):
"""Test bounding box validation with valid values."""
is_valid, error = validate_bounding_box(10, 20, 100, 200)
self.assertTrue(is_valid)
self.assertIsNone(error)
def test_validate_bounding_box_invalid(self):
"""Test bounding box validation with invalid values."""
# x2 <= x1
is_valid, error = validate_bounding_box(100, 20, 50, 200)
self.assertFalse(is_valid)
self.assertIsNotNone(error)
# y2 <= y1
is_valid, error = validate_bounding_box(10, 200, 100, 50)
self.assertFalse(is_valid)
self.assertIsNotNone(error)
def test_validate_num_masks_valid(self):
"""Test num masks validation with valid values."""
for num in [1, 3, 5]:
is_valid, error = validate_num_masks(num)
self.assertTrue(is_valid)
self.assertIsNone(error)
def test_validate_num_masks_invalid(self):
"""Test num masks validation with invalid values."""
for num in [0, 6, -1]:
is_valid, error = validate_num_masks(num)
self.assertFalse(is_valid)
self.assertIsNotNone(error)
def test_validate_prompt_text_valid(self):
"""Test prompt text validation with valid values."""
is_valid, error, prompt = validate_prompt_text("brain")
self.assertTrue(is_valid)
self.assertIsNone(error)
self.assertEqual(prompt, "brain")
def test_validate_prompt_text_none(self):
"""Test prompt text validation with None (should use default)."""
is_valid, error, prompt = validate_prompt_text(None)
self.assertTrue(is_valid)
self.assertEqual(prompt, "brain") # Default
def test_validate_prompt_text_empty(self):
"""Test prompt text validation with empty string (should use default)."""
is_valid, error, prompt = validate_prompt_text(" ")
self.assertTrue(is_valid)
self.assertEqual(prompt, "brain") # Default
def test_validate_modality_valid(self):
"""Test modality validation with valid values."""
for modality in ["CT", "MRI", "ct", "mri"]:
is_valid, error = validate_modality(modality)
self.assertTrue(is_valid)
self.assertIsNone(error)
def test_validate_modality_invalid(self):
"""Test modality validation with invalid values."""
for modality in [None, "invalid", "XRAY"]:
is_valid, error = validate_modality(modality)
self.assertFalse(is_valid)
self.assertIsNotNone(error)
def test_validate_transparency_valid(self):
"""Test transparency validation with valid values."""
for trans in [0.0, 0.5, 1.0]:
is_valid, error = validate_transparency(trans)
self.assertTrue(is_valid)
self.assertIsNone(error)
def test_validate_transparency_invalid(self):
"""Test transparency validation with invalid values."""
for trans in [-0.1, 1.1, "invalid"]:
is_valid, error = validate_transparency(trans)
self.assertFalse(is_valid)
self.assertIsNotNone(error)
def test_validate_brightness_contrast_valid(self):
"""Test brightness/contrast validation with valid values."""
for val in [0.0, 1.0, 2.0, 3.0]:
is_valid, error = validate_brightness_contrast(val, "test")
self.assertTrue(is_valid)
self.assertIsNone(error)
def test_validate_brightness_contrast_invalid(self):
"""Test brightness/contrast validation with invalid values."""
for val in [-0.1, 3.1, "invalid"]:
is_valid, error = validate_brightness_contrast(val, "test")
self.assertFalse(is_valid)
self.assertIsNotNone(error)
if __name__ == "__main__":
    # Run the suite when this module is executed directly as a script.
    unittest.main()