"""
Input validation utilities for NeuroSAM 3 application.

Provides validation functions for user inputs, files, and parameters.
Each ``validate_*`` function reports problems through its return value
(an ``(is_valid, error_message)`` tuple, with ``error_message`` set to
``None`` on success) rather than by raising.
"""

import numbers
import os
from typing import Optional, Tuple, Union
from pathlib import Path

# NOTE(review): ``logger`` and ``ALLOWED_ANNOTATION_EXTENSIONS`` are not used in
# this module; presumably kept for callers or future use - confirm before removing.
from logger_config import logger
from config import (
    MAX_FILE_SIZE_BYTES,
    ALLOWED_IMAGE_EXTENSIONS,
    ALLOWED_ANNOTATION_EXTENSIONS,
    MIN_THRESHOLD,
    MAX_THRESHOLD,
    MIN_MASK_THRESHOLD,
    MAX_MASK_THRESHOLD,
    MAX_COORDINATE_VALUE,
    MIN_NUM_MASKS,
    MAX_NUM_MASKS,
)

# Prompt returned by validate_prompt_text when none (or only whitespace) is given.
_DEFAULT_PROMPT = "brain"

# Maximum accepted length of a prompt after surrounding whitespace is stripped.
_MAX_PROMPT_LENGTH = 500


class ValidationError(Exception):
    """Custom exception for validation errors."""


def _is_real_number(value) -> bool:
    """Return True if *value* is a real numeric scalar other than ``bool``.

    ``numbers.Real`` also admits NumPy floating scalars (e.g. ``float32``)
    that are not ``float`` subclasses. ``bool`` is excluded explicitly: it
    subclasses ``int``, but ``True``/``False`` for a numeric parameter is a
    caller mistake, not a value to accept silently.
    """
    return isinstance(value, numbers.Real) and not isinstance(value, bool)


def _is_integer(value) -> bool:
    """Return True if *value* is an integral scalar other than ``bool``."""
    return isinstance(value, numbers.Integral) and not isinstance(value, bool)


def _is_nan(value) -> bool:
    """Return True if *value* is NaN.

    NaN is the only value that compares unequal to itself. This avoids
    ``math.isnan``, whose float conversion overflows on very large ints.
    """
    return value != value


def _validate_number_in_range(value, label: str, low, high) -> Tuple[bool, Optional[str]]:
    """
    Check that *value* is a real number inside the inclusive range [low, high].

    Args:
        value: Value to check
        label: Parameter name used at the start of error messages
        low: Inclusive lower bound
        high: Inclusive upper bound

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not _is_real_number(value):
        return False, f"{label} must be a number, got {type(value)}"
    # Written as a chained comparison so NaN (which fails every comparison)
    # is rejected; the form ``value < low or value > high`` would accept it.
    if not (low <= value <= high):
        return False, f"{label} must be between {low} and {high}, got {value}"
    return True, None


def validate_file_path(file_path: Optional[Union[str, Path]]) -> Tuple[bool, Optional[str]]:
    """
    Validate that a file path exists and refers to a regular file.

    Args:
        file_path: Path to validate (``str`` or ``pathlib.Path``)

    Returns:
        Tuple of (is_valid, error_message)
    """
    if file_path is None:
        return False, "File path is None"
    if not isinstance(file_path, (str, Path)):
        return False, f"Invalid file path type: {type(file_path)}"

    file_path = str(file_path)
    if not os.path.exists(file_path):
        return False, f"File not found: {file_path}"
    if not os.path.isfile(file_path):
        return False, f"Path is not a file: {file_path}"
    return True, None


def validate_file_size(file_path: str) -> Tuple[bool, Optional[str]]:
    """
    Validate that a file's size does not exceed MAX_FILE_SIZE_BYTES.

    Args:
        file_path: Path to file to validate

    Returns:
        Tuple of (is_valid, error_message). An OSError while reading the
        size (e.g. missing file) is reported as a validation failure.
    """
    try:
        file_size = os.path.getsize(file_path)
    except OSError as e:
        return False, f"Could not check file size: {e}"

    if file_size > MAX_FILE_SIZE_BYTES:
        size_mb = file_size / (1024 * 1024)
        max_mb = MAX_FILE_SIZE_BYTES / (1024 * 1024)
        return False, f"File size ({size_mb:.2f} MB) exceeds maximum ({max_mb} MB)"
    return True, None


def validate_file_extension(file_path: str, allowed_extensions: tuple = ALLOWED_IMAGE_EXTENSIONS) -> Tuple[bool, Optional[str]]:
    """
    Validate a file's extension (case-insensitively) against an allow-list.

    Args:
        file_path: Path to file
        allowed_extensions: Tuple of allowed extensions including the leading
            dot (default: image extensions). Compound extensions such as
            ``".nii.gz"`` are supported.

    Returns:
        Tuple of (is_valid, error_message)
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext in allowed_extensions:
        return True, None

    # os.path.splitext only yields the final suffix ("scan.nii.gz" -> ".gz"),
    # so compound allowed extensions are matched against the full file name.
    compound_extensions = tuple(
        allowed.lower() for allowed in allowed_extensions if allowed.count(".") > 1
    )
    if compound_extensions:
        file_name = os.path.basename(file_path).lower()
        if file_name.endswith(compound_extensions):
            return True, None

    return False, f"File extension '{ext}' not allowed. Allowed: {', '.join(allowed_extensions)}"


def validate_image_file(file_path: Optional[Union[str, Path]]) -> Tuple[bool, Optional[str]]:
    """
    Comprehensive validation for image files: path, extension, then size.

    Args:
        file_path: Path to image file

    Returns:
        Tuple of (is_valid, error_message) from the first failing check
    """
    is_valid, error = validate_file_path(file_path)
    if not is_valid:
        return False, error

    file_path = str(file_path)

    is_valid, error = validate_file_extension(file_path, ALLOWED_IMAGE_EXTENSIONS)
    if not is_valid:
        return False, error

    is_valid, error = validate_file_size(file_path)
    if not is_valid:
        return False, error

    return True, None


def validate_threshold(threshold: float) -> Tuple[bool, Optional[str]]:
    """
    Validate that a threshold lies in [MIN_THRESHOLD, MAX_THRESHOLD].

    Args:
        threshold: Threshold value to validate

    Returns:
        Tuple of (is_valid, error_message)
    """
    return _validate_number_in_range(threshold, "Threshold", MIN_THRESHOLD, MAX_THRESHOLD)


def validate_mask_threshold(mask_threshold: float) -> Tuple[bool, Optional[str]]:
    """
    Validate that a mask threshold lies in [MIN_MASK_THRESHOLD, MAX_MASK_THRESHOLD].

    Args:
        mask_threshold: Mask threshold value to validate

    Returns:
        Tuple of (is_valid, error_message)
    """
    return _validate_number_in_range(
        mask_threshold, "Mask threshold", MIN_MASK_THRESHOLD, MAX_MASK_THRESHOLD
    )


def validate_coordinates(x: float, y: float, max_value: int = MAX_COORDINATE_VALUE) -> Tuple[bool, Optional[str]]:
    """
    Validate a point: both coordinates must be numbers in [0, max_value].

    Args:
        x: X coordinate
        y: Y coordinate
        max_value: Maximum allowed coordinate value (inclusive)

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not (_is_real_number(x) and _is_real_number(y)):
        return False, f"Coordinates must be numbers, got x={type(x)}, y={type(y)}"
    # NaN fails every comparison below, so it must be rejected explicitly.
    if _is_nan(x) or _is_nan(y):
        return False, f"Coordinates must not be NaN, got x={x}, y={y}"
    if x < 0 or y < 0:
        return False, f"Coordinates must be non-negative, got x={x}, y={y}"
    if x > max_value or y > max_value:
        return False, f"Coordinates exceed maximum value ({max_value}), got x={x}, y={y}"
    return True, None


def validate_bounding_box(x1: float, y1: float, x2: float, y2: float, max_value: int = MAX_COORDINATE_VALUE) -> Tuple[bool, Optional[str]]:
    """
    Validate bounding box coordinates.

    Each corner coordinate must be a number in [0, max_value], and the box
    must have strictly positive width and height.

    Args:
        x1, y1: Top-left corner coordinates
        x2, y2: Bottom-right corner coordinates
        max_value: Maximum allowed coordinate value (inclusive); defaults to
            MAX_COORDINATE_VALUE, matching validate_coordinates

    Returns:
        Tuple of (is_valid, error_message)
    """
    for coord, name in ((x1, "x1"), (y1, "y1"), (x2, "x2"), (y2, "y2")):
        if not _is_real_number(coord):
            return False, f"{name} must be a number, got {type(coord)}"
        # NaN would otherwise pass both the range and the ordering checks.
        if _is_nan(coord):
            return False, f"{name} must not be NaN, got {coord}"
        if coord < 0:
            return False, f"{name} must be non-negative, got {coord}"
        if coord > max_value:
            return False, f"{name} exceeds maximum ({max_value}), got {coord}"

    if x2 <= x1:
        return False, f"x2 ({x2}) must be greater than x1 ({x1})"
    if y2 <= y1:
        return False, f"y2 ({y2}) must be greater than y1 ({y1})"
    return True, None


def validate_num_masks(num_masks: int) -> Tuple[bool, Optional[str]]:
    """
    Validate that the number of masks is an integer in [MIN_NUM_MASKS, MAX_NUM_MASKS].

    Args:
        num_masks: Number of masks to generate (``bool`` is rejected)

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not _is_integer(num_masks):
        return False, f"Number of masks must be an integer, got {type(num_masks)}"
    if not (MIN_NUM_MASKS <= num_masks <= MAX_NUM_MASKS):
        return False, f"Number of masks must be between {MIN_NUM_MASKS} and {MAX_NUM_MASKS}, got {num_masks}"
    return True, None


def validate_prompt_text(prompt_text: Optional[str]) -> Tuple[bool, Optional[str], str]:
    """
    Validate and sanitize prompt text.

    ``None`` or a whitespace-only prompt falls back to the default prompt
    ("brain"). Otherwise surrounding whitespace is stripped and the result
    must be at most 500 characters long.

    Args:
        prompt_text: Text prompt to validate

    Returns:
        Tuple of (is_valid, error_message, sanitized_prompt); the sanitized
        prompt is "" when validation fails.
    """
    if prompt_text is None:
        return True, None, _DEFAULT_PROMPT

    if not isinstance(prompt_text, str):
        return False, f"Prompt must be a string, got {type(prompt_text)}", ""

    sanitized = prompt_text.strip()
    if len(sanitized) > _MAX_PROMPT_LENGTH:
        return False, f"Prompt text is too long (max {_MAX_PROMPT_LENGTH} characters)", ""

    return True, None, sanitized or _DEFAULT_PROMPT


def validate_modality(modality: Optional[str]) -> Tuple[bool, Optional[str]]:
    """
    Validate imaging modality (case-insensitive "CT" or "MRI").

    Args:
        modality: Modality string (CT or MRI)

    Returns:
        Tuple of (is_valid, error_message)
    """
    if modality is None:
        return False, "Modality is required"
    if not isinstance(modality, str):
        return False, f"Modality must be a string, got {type(modality)}"
    if modality.upper() not in ("CT", "MRI"):
        return False, f"Modality must be 'CT' or 'MRI', got '{modality}'"
    return True, None


def validate_transparency(transparency: float) -> Tuple[bool, Optional[str]]:
    """
    Validate that a transparency value lies in [0.0, 1.0].

    Args:
        transparency: Transparency value (0.0-1.0)

    Returns:
        Tuple of (is_valid, error_message)
    """
    return _validate_number_in_range(transparency, "Transparency", 0.0, 1.0)


def validate_brightness_contrast(value: float, name: str = "value") -> Tuple[bool, Optional[str]]:
    """
    Validate that a brightness or contrast value lies in [0.0, 3.0].

    Args:
        value: Brightness or contrast value
        name: Name of the parameter for error messages

    Returns:
        Tuple of (is_valid, error_message)
    """
    return _validate_number_in_range(value, name, 0.0, 3.0)