# NeuroSAM3 / validators.py
# Author: mmrech
# Refactor codebase: Add modular structure, logging, validation, and comprehensive improvements
# Commit: 69066c5
"""
Input validation utilities for NeuroSAM 3 application.
Provides validation functions for user inputs, files, and parameters.
"""
import os
from typing import Optional, Tuple
from pathlib import Path
from logger_config import logger
from config import (
MAX_FILE_SIZE_BYTES,
ALLOWED_IMAGE_EXTENSIONS,
ALLOWED_ANNOTATION_EXTENSIONS,
MIN_THRESHOLD,
MAX_THRESHOLD,
MIN_MASK_THRESHOLD,
MAX_MASK_THRESHOLD,
MAX_COORDINATE_VALUE,
MIN_NUM_MASKS,
MAX_NUM_MASKS,
)
class ValidationError(Exception):
    """Exception type for input-validation failures.

    The ``validate_*`` functions in this module do not raise it; they report
    problems through ``(is_valid, error_message)`` return values instead.
    """
def validate_file_path(file_path: Optional[str]) -> Tuple[bool, Optional[str]]:
    """
    Check that ``file_path`` names an existing regular file.

    Both ``str`` and ``pathlib.Path`` values are accepted. Only existence and
    "is a regular file" are checked; read permission is not verified.

    Args:
        file_path: Path to validate (``str`` or ``Path``).

    Returns:
        Tuple of (is_valid, error_message); error_message is None on success.
    """
    if file_path is None:
        return False, "File path is None"

    if isinstance(file_path, (str, Path)):
        path_str = str(file_path)
    else:
        return False, f"Invalid file path type: {type(file_path)}"

    if not os.path.exists(path_str):
        return False, f"File not found: {path_str}"
    if os.path.isfile(path_str):
        return True, None
    # Exists but is something else (e.g. a directory).
    return False, f"Path is not a file: {path_str}"
def validate_file_size(file_path: str) -> Tuple[bool, Optional[str]]:
    """
    Check that a file is no larger than ``MAX_FILE_SIZE_BYTES``.

    Args:
        file_path: Path of the file whose size is checked.

    Returns:
        Tuple of (is_valid, error_message). An ``OSError`` while reading the
        size (e.g. a missing file) is reported as a failure, not raised.
    """
    bytes_per_mb = 1024 * 1024
    try:
        size_bytes = os.path.getsize(file_path)
    except OSError as exc:
        return False, f"Could not check file size: {exc}"

    if size_bytes > MAX_FILE_SIZE_BYTES:
        actual_mb = size_bytes / bytes_per_mb
        limit_mb = MAX_FILE_SIZE_BYTES / bytes_per_mb
        return False, f"File size ({actual_mb:.2f} MB) exceeds maximum ({limit_mb} MB)"
    return True, None
def validate_file_extension(file_path: str, allowed_extensions: tuple = ALLOWED_IMAGE_EXTENSIONS) -> Tuple[bool, Optional[str]]:
    """
    Validate that a file name ends with one of the allowed extensions.

    Matching is case-insensitive on both the file name and the allowed
    extensions. Compound extensions in ``allowed_extensions`` (for example
    ``".nii.gz"``) are matched against the end of the file name.
    ``os.path.splitext`` alone only reports the last suffix (``".gz"``), so a
    plain lookup could never match them.

    Args:
        file_path: Path to file
        allowed_extensions: Allowed extensions, each including the leading dot
            (default: ALLOWED_IMAGE_EXTENSIONS)

    Returns:
        Tuple of (is_valid, error_message)
    """
    extensions = tuple(allowed_extensions)
    allowed = tuple(e.lower() for e in extensions)

    ext = os.path.splitext(file_path)[1].lower()
    if ext in allowed:
        return True, None

    file_name = os.path.basename(file_path).lower()
    for candidate in allowed:
        # Single-dot suffixes were already handled by the lookup above.
        # The length check requires a non-empty stem, so a file literally
        # named ".nii.gz" is not accepted.
        if (
            candidate.count(".") > 1
            and len(file_name) > len(candidate)
            and file_name.endswith(candidate)
        ):
            return True, None

    return False, f"File extension '{ext}' not allowed. Allowed: {', '.join(extensions)}"
def validate_image_file(file_path: Optional[str]) -> Tuple[bool, Optional[str]]:
    """
    Run every check an image file must pass, stopping at the first failure.

    The checks run in this order:
    1. The path names an existing regular file (``validate_file_path``).
    2. Its extension is in ``ALLOWED_IMAGE_EXTENSIONS``
       (``validate_file_extension``).
    3. Its size is within the configured limit (``validate_file_size``).

    Args:
        file_path: Path to the image file (``str`` or ``Path``).

    Returns:
        Tuple of (is_valid, error_message) from the first failing check, or
        (True, None) when all checks pass.
    """
    ok, message = validate_file_path(file_path)
    if not ok:
        return False, message

    # The remaining checks work on the string form of the path.
    path_str = str(file_path)
    follow_up_checks = (
        lambda p: validate_file_extension(p, ALLOWED_IMAGE_EXTENSIONS),
        validate_file_size,
    )
    for check in follow_up_checks:
        ok, message = check(path_str)
        if not ok:
            return False, message
    return True, None
def validate_threshold(threshold: float) -> Tuple[bool, Optional[str]]:
    """
    Validate threshold value.

    The value must be a real number (``int`` or ``float``, but not ``bool``)
    in the inclusive range [MIN_THRESHOLD, MAX_THRESHOLD]. NaN is rejected.

    Args:
        threshold: Threshold value to validate

    Returns:
        Tuple of (is_valid, error_message)
    """
    # bool subclasses int, but True/False are not meaningful thresholds.
    if isinstance(threshold, bool) or not isinstance(threshold, (int, float)):
        return False, f"Threshold must be a number, got {type(threshold)}"
    # Written as a positive range test: every comparison with NaN is False,
    # so NaN is rejected instead of slipping past two negated checks.
    if not (MIN_THRESHOLD <= threshold <= MAX_THRESHOLD):
        return False, f"Threshold must be between {MIN_THRESHOLD} and {MAX_THRESHOLD}, got {threshold}"
    return True, None
def validate_mask_threshold(mask_threshold: float) -> Tuple[bool, Optional[str]]:
    """
    Validate mask threshold value.

    The value must be a real number (``int`` or ``float``, but not ``bool``)
    in the inclusive range [MIN_MASK_THRESHOLD, MAX_MASK_THRESHOLD]. NaN is
    rejected.

    Args:
        mask_threshold: Mask threshold value to validate

    Returns:
        Tuple of (is_valid, error_message)
    """
    # bool subclasses int, but True/False are not meaningful thresholds.
    if isinstance(mask_threshold, bool) or not isinstance(mask_threshold, (int, float)):
        return False, f"Mask threshold must be a number, got {type(mask_threshold)}"
    # Positive range test so NaN (which fails every comparison) is rejected.
    if not (MIN_MASK_THRESHOLD <= mask_threshold <= MAX_MASK_THRESHOLD):
        return False, f"Mask threshold must be between {MIN_MASK_THRESHOLD} and {MAX_MASK_THRESHOLD}, got {mask_threshold}"
    return True, None
def validate_coordinates(x: float, y: float, max_value: int = MAX_COORDINATE_VALUE) -> Tuple[bool, Optional[str]]:
    """
    Validate coordinate values.

    Both coordinates must be real numbers (``int`` or ``float``, but not
    ``bool``) in the inclusive range [0, max_value]. NaN is rejected.

    Args:
        x: X coordinate
        y: Y coordinate
        max_value: Maximum allowed coordinate value (inclusive)

    Returns:
        Tuple of (is_valid, error_message)
    """
    def _is_number(value) -> bool:
        # bool subclasses int, but True/False are not coordinates.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if not (_is_number(x) and _is_number(y)):
        return False, f"Coordinates must be numbers, got x={type(x)}, y={type(y)}"
    # Phrased as positive range tests: every comparison with NaN is False,
    # so NaN coordinates are rejected rather than passing both checks.
    if not (x >= 0 and y >= 0):
        return False, f"Coordinates must be non-negative, got x={x}, y={y}"
    if not (x <= max_value and y <= max_value):
        return False, f"Coordinates exceed maximum value ({max_value}), got x={x}, y={y}"
    return True, None
def validate_bounding_box(x1: float, y1: float, x2: float, y2: float) -> Tuple[bool, Optional[str]]:
"""
Validate bounding box coordinates.
Args:
x1, y1: Top-left corner coordinates
x2, y2: Bottom-right corner coordinates
Returns:
Tuple of (is_valid, error_message)
"""
# Validate individual coordinates
for coord, name in [(x1, 'x1'), (y1, 'y1'), (x2, 'x2'), (y2, 'y2')]:
if not isinstance(coord, (int, float)):
return False, f"{name} must be a number, got {type(coord)}"
if coord < 0:
return False, f"{name} must be non-negative, got {coord}"
if coord > MAX_COORDINATE_VALUE:
return False, f"{name} exceeds maximum ({MAX_COORDINATE_VALUE}), got {coord}"
# Validate box dimensions
if x2 <= x1:
return False, f"x2 ({x2}) must be greater than x1 ({x1})"
if y2 <= y1:
return False, f"y2 ({y2}) must be greater than y1 ({y1})"
return True, None
def validate_num_masks(num_masks: int) -> Tuple[bool, Optional[str]]:
    """
    Validate number of masks parameter.

    The value must be an ``int`` (``bool`` is rejected even though it
    subclasses ``int``) in the inclusive range [MIN_NUM_MASKS, MAX_NUM_MASKS].

    Args:
        num_masks: Number of masks to generate

    Returns:
        Tuple of (is_valid, error_message)
    """
    # Without the bool check, True/False would be accepted as 1/0 masks.
    if isinstance(num_masks, bool) or not isinstance(num_masks, int):
        return False, f"Number of masks must be an integer, got {type(num_masks)}"
    if not (MIN_NUM_MASKS <= num_masks <= MAX_NUM_MASKS):
        return False, f"Number of masks must be between {MIN_NUM_MASKS} and {MAX_NUM_MASKS}, got {num_masks}"
    return True, None
def validate_prompt_text(prompt_text: Optional[str]) -> Tuple[bool, Optional[str], str]:
    """
    Validate a text prompt and return a cleaned-up version of it.

    ``None`` and whitespace-only prompts fall back to the default prompt
    ``"brain"``. Leading and trailing whitespace is stripped before the
    500-character length limit is applied.

    Args:
        prompt_text: Text prompt to validate

    Returns:
        Tuple of (is_valid, error_message, sanitized_prompt). On failure the
        sanitized prompt is the empty string.
    """
    default_prompt = "brain"
    max_length = 500

    if prompt_text is None:
        return True, None, default_prompt
    if not isinstance(prompt_text, str):
        return False, f"Prompt must be a string, got {type(prompt_text)}", ""

    cleaned = prompt_text.strip()
    if len(cleaned) > max_length:
        return False, "Prompt text is too long (max 500 characters)", ""
    # An empty string after stripping means "use the default".
    return True, None, cleaned or default_prompt
def validate_modality(modality: Optional[str]) -> Tuple[bool, Optional[str]]:
    """
    Check that an imaging modality is ``"CT"`` or ``"MRI"`` (case-insensitive).

    Args:
        modality: Modality string, e.g. "CT", "mri"

    Returns:
        Tuple of (is_valid, error_message)
    """
    if modality is None:
        return False, "Modality is required"
    if not isinstance(modality, str):
        return False, f"Modality must be a string, got {type(modality)}"

    if modality.upper() in {"CT", "MRI"}:
        return True, None
    # Report the value exactly as the caller supplied it.
    return False, f"Modality must be 'CT' or 'MRI', got '{modality}'"
def validate_transparency(transparency: float) -> Tuple[bool, Optional[str]]:
    """
    Validate transparency value.

    The value must be a real number (``int`` or ``float``, but not ``bool``)
    in the inclusive range [0.0, 1.0]. NaN is rejected.

    Args:
        transparency: Transparency value (0.0-1.0)

    Returns:
        Tuple of (is_valid, error_message)
    """
    # bool subclasses int, so it would otherwise be accepted as 0 or 1.
    if isinstance(transparency, bool) or not isinstance(transparency, (int, float)):
        return False, f"Transparency must be a number, got {type(transparency)}"
    # Positive range test so NaN (which fails every comparison) is rejected.
    if not (0.0 <= transparency <= 1.0):
        return False, f"Transparency must be between 0.0 and 1.0, got {transparency}"
    return True, None
def validate_brightness_contrast(
    value: float,
    name: str = "value",
    min_value: float = 0.0,
    max_value: float = 3.0,
) -> Tuple[bool, Optional[str]]:
    """
    Validate brightness or contrast value.

    The value must be a real number (``int`` or ``float``, but not ``bool``)
    in the inclusive range [min_value, max_value]. NaN is rejected.

    Args:
        value: Brightness or contrast value
        name: Name of the parameter for error messages
        min_value: Lowest accepted value (default 0.0)
        max_value: Highest accepted value (default 3.0)

    Returns:
        Tuple of (is_valid, error_message)
    """
    # bool subclasses int, so it would otherwise be accepted as 0 or 1.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return False, f"{name} must be a number, got {type(value)}"
    # Positive range test so NaN (which fails every comparison) is rejected.
    if not (min_value <= value <= max_value):
        return False, f"{name} must be between {min_value} and {max_value}, got {value}"
    return True, None