|
|
""" |
|
|
Input validation utilities for NeuroSAM 3 application. |
|
|
Provides validation functions for user inputs, files, and parameters. |
|
|
""" |
|
|
|
|
|
import numbers
import os
from pathlib import Path
from typing import Optional, Tuple

from logger_config import logger

from config import (
    MAX_FILE_SIZE_BYTES,
    ALLOWED_IMAGE_EXTENSIONS,
    ALLOWED_ANNOTATION_EXTENSIONS,
    MIN_THRESHOLD,
    MAX_THRESHOLD,
    MIN_MASK_THRESHOLD,
    MAX_MASK_THRESHOLD,
    MAX_COORDINATE_VALUE,
    MIN_NUM_MASKS,
    MAX_NUM_MASKS,
)
|
|
|
|
|
|
|
|
class ValidationError(Exception):
    """
    Exception type for input-validation failures.

    The validator functions defined alongside it report problems by returning
    ``(is_valid, error_message)`` tuples rather than raising this exception.
    """
|
|
|
|
|
|
|
|
def validate_file_path(file_path: Optional[str]) -> Tuple[bool, Optional[str]]:
    """
    Check that ``file_path`` names an existing regular file.

    Args:
        file_path: Path to validate, given as a ``str`` or ``pathlib.Path``.

    Returns:
        Tuple of (is_valid, error_message); error_message is None on success.
    """
    # Reject missing or wrongly-typed input before touching the filesystem.
    if file_path is None:
        return False, "File path is None"
    if not isinstance(file_path, (str, Path)):
        return False, f"Invalid file path type: {type(file_path)}"

    path_str = str(file_path)
    if not os.path.exists(path_str):
        failure = f"File not found: {path_str}"
    elif not os.path.isfile(path_str):
        # Exists, but is a directory or other non-regular entry.
        failure = f"Path is not a file: {path_str}"
    else:
        return True, None
    return False, failure
|
|
|
|
|
|
|
|
def validate_file_size(file_path: str) -> Tuple[bool, Optional[str]]:
    """
    Check that the file at ``file_path`` is no larger than MAX_FILE_SIZE_BYTES.

    Args:
        file_path: Path of the file whose size is checked.

    Returns:
        Tuple of (is_valid, error_message). An OSError while reading the size
        (e.g. the file is missing) is reported as a failure, not raised.
    """
    try:
        size_bytes = os.path.getsize(file_path)
    except OSError as exc:
        return False, f"Could not check file size: {exc}"

    if size_bytes <= MAX_FILE_SIZE_BYTES:
        return True, None

    bytes_per_mb = 1024 * 1024
    return False, (
        f"File size ({size_bytes / bytes_per_mb:.2f} MB) exceeds maximum "
        f"({MAX_FILE_SIZE_BYTES / bytes_per_mb} MB)"
    )
|
|
|
|
|
|
|
|
def validate_file_extension(file_path: str, allowed_extensions: tuple = ALLOWED_IMAGE_EXTENSIONS) -> Tuple[bool, Optional[str]]:
    """
    Check that the file's extension is one of ``allowed_extensions``.

    The extension is taken with ``os.path.splitext``, so only the final suffix
    is compared (e.g. ``'.gz'`` for ``'scan.nii.gz'``). It is lower-cased
    before the membership test.

    Args:
        file_path: Path to file
        allowed_extensions: Allowed extensions (default: image extensions)

    Returns:
        Tuple of (is_valid, error_message)
    """
    _, suffix = os.path.splitext(file_path)
    suffix = suffix.lower()
    if suffix in allowed_extensions:
        return True, None

    allowed_list = ", ".join(allowed_extensions)
    return False, f"File extension '{suffix}' not allowed. Allowed: {allowed_list}"
|
|
|
|
|
|
|
|
def validate_image_file(file_path: Optional[str]) -> Tuple[bool, Optional[str]]:
    """
    Run the full set of image-file checks, stopping at the first failure.

    The checks run in order: existence (validate_file_path), extension against
    ALLOWED_IMAGE_EXTENSIONS (validate_file_extension), then size
    (validate_file_size).

    Args:
        file_path: Path to image file

    Returns:
        Tuple of (is_valid, error_message) from the first failing check, or
        (True, None) when all checks pass.
    """
    ok, err = validate_file_path(file_path)
    if not ok:
        return False, err

    path_str = str(file_path)
    ok, err = validate_file_extension(path_str, ALLOWED_IMAGE_EXTENSIONS)
    if ok:
        # The size is only checked once the extension is acceptable.
        ok, err = validate_file_size(path_str)

    return (True, None) if ok else (False, err)
|
|
|
|
|
|
|
|
def validate_threshold(threshold: float) -> Tuple[bool, Optional[str]]:
    """
    Validate a threshold value against [MIN_THRESHOLD, MAX_THRESHOLD].

    Any ``numbers.Real`` is accepted: int, float, fractions.Fraction, and the
    NumPy integer/floating scalars that NumPy registers with the ``numbers``
    ABCs. ``bool`` is rejected even though it subclasses ``int``. NaN is
    rejected explicitly because it compares False against both bounds and
    would otherwise pass the range check.

    Args:
        threshold: Threshold value to validate

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if isinstance(threshold, bool) or not isinstance(threshold, numbers.Real):
        return False, f"Threshold must be a number, got {type(threshold)}"

    # NaN is the only value that is not equal to itself.
    if threshold != threshold:
        return False, f"Threshold must not be NaN, got {threshold}"

    if threshold < MIN_THRESHOLD or threshold > MAX_THRESHOLD:
        return False, f"Threshold must be between {MIN_THRESHOLD} and {MAX_THRESHOLD}, got {threshold}"

    return True, None
|
|
|
|
|
|
|
|
def validate_mask_threshold(mask_threshold: float) -> Tuple[bool, Optional[str]]:
    """
    Validate a mask threshold against [MIN_MASK_THRESHOLD, MAX_MASK_THRESHOLD].

    Any ``numbers.Real`` is accepted, including NumPy integer/floating scalars.
    ``bool`` is rejected even though it subclasses ``int``. NaN is rejected
    explicitly because it compares False against both bounds and would
    otherwise pass the range check.

    Args:
        mask_threshold: Mask threshold value to validate

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if isinstance(mask_threshold, bool) or not isinstance(mask_threshold, numbers.Real):
        return False, f"Mask threshold must be a number, got {type(mask_threshold)}"

    # NaN is the only value that is not equal to itself.
    if mask_threshold != mask_threshold:
        return False, f"Mask threshold must not be NaN, got {mask_threshold}"

    if mask_threshold < MIN_MASK_THRESHOLD or mask_threshold > MAX_MASK_THRESHOLD:
        return False, f"Mask threshold must be between {MIN_MASK_THRESHOLD} and {MAX_MASK_THRESHOLD}, got {mask_threshold}"

    return True, None
|
|
|
|
|
|
|
|
def validate_coordinates(x: float, y: float, max_value: int = MAX_COORDINATE_VALUE) -> Tuple[bool, Optional[str]]:
    """
    Validate a point's coordinates.

    Both coordinates must be ``numbers.Real`` values (``bool`` excluded;
    NumPy integer/floating scalars accepted). Neither may be NaN, and both
    must lie in the closed range ``[0, max_value]``.

    Args:
        x: X coordinate
        y: Y coordinate
        max_value: Maximum allowed coordinate value (inclusive)

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if not all(
        isinstance(v, numbers.Real) and not isinstance(v, bool) for v in (x, y)
    ):
        return False, f"Coordinates must be numbers, got x={type(x)}, y={type(y)}"

    # NaN is the only value unequal to itself. It compares False against both
    # bounds below and would otherwise be accepted.
    if x != x or y != y:
        return False, f"Coordinates must not be NaN, got x={x}, y={y}"

    if x < 0 or y < 0:
        return False, f"Coordinates must be non-negative, got x={x}, y={y}"

    if x > max_value or y > max_value:
        return False, f"Coordinates exceed maximum value ({max_value}), got x={x}, y={y}"

    return True, None
|
|
|
|
|
|
|
|
def validate_bounding_box(x1: float, y1: float, x2: float, y2: float) -> Tuple[bool, Optional[str]]:
    """
    Validate bounding box coordinates.

    Each corner coordinate must be a ``numbers.Real`` (``bool`` excluded;
    NumPy integer/floating scalars accepted). It must not be NaN and must lie
    in ``[0, MAX_COORDINATE_VALUE]``. The box must also have positive width
    and height (``x2 > x1`` and ``y2 > y1``).

    Args:
        x1, y1: Top-left corner coordinates
        x2, y2: Bottom-right corner coordinates

    Returns:
        Tuple of (is_valid, error_message) for the first failing check.
    """
    for coord, name in ((x1, "x1"), (y1, "y1"), (x2, "x2"), (y2, "y2")):
        if isinstance(coord, bool) or not isinstance(coord, numbers.Real):
            return False, f"{name} must be a number, got {type(coord)}"
        # A NaN corner would pass every comparison below, including the
        # x2 <= x1 / y2 <= y1 ordering checks, so reject it explicitly.
        if coord != coord:
            return False, f"{name} must not be NaN, got {coord}"
        if coord < 0:
            return False, f"{name} must be non-negative, got {coord}"
        if coord > MAX_COORDINATE_VALUE:
            return False, f"{name} exceeds maximum ({MAX_COORDINATE_VALUE}), got {coord}"

    if x2 <= x1:
        return False, f"x2 ({x2}) must be greater than x1 ({x1})"

    if y2 <= y1:
        return False, f"y2 ({y2}) must be greater than y1 ({y1})"

    return True, None
|
|
|
|
|
|
|
|
def validate_num_masks(num_masks: int) -> Tuple[bool, Optional[str]]:
    """
    Validate the number-of-masks parameter against [MIN_NUM_MASKS, MAX_NUM_MASKS].

    Any ``numbers.Integral`` is accepted, including NumPy integer scalars.
    ``bool`` is rejected even though it subclasses ``int``, so ``True`` is
    not silently treated as 1.

    Args:
        num_masks: Number of masks to generate

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if isinstance(num_masks, bool) or not isinstance(num_masks, numbers.Integral):
        return False, f"Number of masks must be an integer, got {type(num_masks)}"

    if num_masks < MIN_NUM_MASKS or num_masks > MAX_NUM_MASKS:
        return False, f"Number of masks must be between {MIN_NUM_MASKS} and {MAX_NUM_MASKS}, got {num_masks}"

    return True, None
|
|
|
|
|
|
|
|
def validate_prompt_text(prompt_text: Optional[str]) -> Tuple[bool, Optional[str], str]:
    """
    Validate a text prompt and return a cleaned copy.

    Surrounding whitespace is stripped. A missing (None) or blank prompt falls
    back to ``"brain"``. A stripped prompt longer than 500 characters is
    rejected.

    Args:
        prompt_text: Text prompt to validate

    Returns:
        Tuple of (is_valid, error_message, sanitized_prompt). On failure the
        sanitized prompt is the empty string.
    """
    fallback_prompt = "brain"

    if prompt_text is None:
        return True, None, fallback_prompt
    if not isinstance(prompt_text, str):
        return False, f"Prompt must be a string, got {type(prompt_text)}", ""

    cleaned = prompt_text.strip()
    if len(cleaned) > 500:
        return False, "Prompt text is too long (max 500 characters)", ""

    return True, None, cleaned or fallback_prompt
|
|
|
|
|
|
|
|
def validate_modality(modality: Optional[str]) -> Tuple[bool, Optional[str]]:
    """
    Check that ``modality`` is ``'CT'`` or ``'MRI'`` (case-insensitive).

    Args:
        modality: Modality string (CT or MRI)

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if modality is None:
        return False, "Modality is required"
    if not isinstance(modality, str):
        return False, f"Modality must be a string, got {type(modality)}"

    if modality.upper() in {"CT", "MRI"}:
        return True, None
    return False, f"Modality must be 'CT' or 'MRI', got '{modality}'"
|
|
|
|
|
|
|
|
def validate_transparency(transparency: float) -> Tuple[bool, Optional[str]]:
    """
    Validate a transparency value in the closed range [0.0, 1.0].

    Any ``numbers.Real`` is accepted (int, float, fractions.Fraction, and
    NumPy integer/floating scalars). ``bool`` is rejected even though it
    subclasses ``int``. NaN is rejected explicitly because it compares False
    against both bounds and would otherwise pass the range check.

    Args:
        transparency: Transparency value (0.0-1.0)

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if isinstance(transparency, bool) or not isinstance(transparency, numbers.Real):
        return False, f"Transparency must be a number, got {type(transparency)}"

    # NaN is the only value that is not equal to itself.
    if transparency != transparency:
        return False, f"Transparency must not be NaN, got {transparency}"

    if transparency < 0.0 or transparency > 1.0:
        return False, f"Transparency must be between 0.0 and 1.0, got {transparency}"

    return True, None
|
|
|
|
|
|
|
|
def validate_brightness_contrast(
    value: float,
    name: str = "value",
    *,
    min_value: float = 0.0,
    max_value: float = 3.0,
) -> Tuple[bool, Optional[str]]:
    """
    Validate a brightness or contrast value against a closed range.

    Any ``numbers.Real`` is accepted (int, float, fractions.Fraction, and
    NumPy integer/floating scalars). ``bool`` is rejected even though it
    subclasses ``int``. NaN is rejected explicitly because it compares False
    against both bounds and would otherwise pass the range check.

    Args:
        value: Brightness or contrast value
        name: Name of the parameter for error messages
        min_value: Inclusive lower bound (keyword-only, default 0.0)
        max_value: Inclusive upper bound (keyword-only, default 3.0)

    Returns:
        Tuple of (is_valid, error_message); error_message is None when valid.
    """
    if isinstance(value, bool) or not isinstance(value, numbers.Real):
        return False, f"{name} must be a number, got {type(value)}"

    # NaN is the only value that is not equal to itself.
    if value != value:
        return False, f"{name} must not be NaN, got {value}"

    if value < min_value or value > max_value:
        return False, f"{name} must be between {min_value} and {max_value}, got {value}"

    return True, None
|
|
|
|
|
|