Spaces:
Sleeping
Sleeping
| """ | |
| Input Validation Utilities | |
| =========================== | |
| Validation functions for user inputs in Nano Banana Streamlit. | |
| Ensures data integrity and provides clear error messages. | |
| """ | |
| from typing import Optional, List, Tuple | |
| from pathlib import Path | |
| from PIL import Image | |
| from config.settings import Settings | |
| from utils.logging_utils import get_logger | |
# Module-level logger from utils.logging_utils.get_logger, keyed by this
# module's __name__; raise_if_invalid and log_validation_error write to it.
logger = get_logger(__name__)
| # ============================================================================= | |
| # PARAMETER VALIDATION | |
| # ============================================================================= | |
def validate_temperature(temperature: float) -> Tuple[bool, Optional[str]]:
    """
    Validate the temperature parameter.

    Accepts ints and floats within the inclusive range
    [Settings.MIN_TEMPERATURE, Settings.MAX_TEMPERATURE]. Booleans and NaN
    are rejected as "not a number".

    Args:
        temperature: Temperature value to validate

    Returns:
        Tuple of (is_valid, error_message)
        error_message is None if valid
    """
    import math  # local import: only needed for the NaN check below

    # bool is a subclass of int, so True/False would otherwise be accepted
    # as 1/0; they are never a meaningful temperature.
    if isinstance(temperature, bool) or not isinstance(temperature, (int, float)):
        return False, "Temperature must be a number"

    # NaN compares False against every bound, so a plain range check lets it
    # through. The isinstance guard avoids OverflowError from math.isnan on
    # very large ints.
    if isinstance(temperature, float) and math.isnan(temperature):
        return False, "Temperature must be a number"

    if not (Settings.MIN_TEMPERATURE <= temperature <= Settings.MAX_TEMPERATURE):
        return False, f"Temperature must be between {Settings.MIN_TEMPERATURE} and {Settings.MAX_TEMPERATURE}"

    return True, None
def validate_aspect_ratio(aspect_ratio: str) -> Tuple[bool, Optional[str]]:
    """
    Check that an aspect ratio is one the app knows about.

    Either form is accepted: a display name (a key of
    Settings.ASPECT_RATIOS) or a ratio value (one of its values). Keys are
    checked before values.

    Args:
        aspect_ratio: Display name or ratio value to check

    Returns:
        Tuple of (is_valid, error_message)
    """
    if isinstance(aspect_ratio, str):
        known_ratios = Settings.ASPECT_RATIOS
        recognised = aspect_ratio in known_ratios or aspect_ratio in known_ratios.values()
        if recognised:
            return True, None
        return False, f"Invalid aspect ratio: {aspect_ratio}"
    return False, "Aspect ratio must be a string"
def validate_backend(backend: str) -> Tuple[bool, Optional[str]]:
    """
    Check that a backend name is listed in Settings.AVAILABLE_BACKENDS.

    Args:
        backend: Backend name to check

    Returns:
        Tuple of (is_valid, error_message)
    """
    if isinstance(backend, str):
        allowed_backends = Settings.AVAILABLE_BACKENDS
        if backend in allowed_backends:
            return True, None
        return False, f"Invalid backend: {backend}. Must be one of {allowed_backends}"
    return False, "Backend must be a string"
def validate_prompt(prompt: str, min_length: int = 1, max_length: int = 5000) -> Tuple[bool, Optional[str]]:
    """
    Check that a text prompt is a string of acceptable length.

    Length is measured after stripping leading/trailing whitespace, so a
    whitespace-only prompt counts as empty. The minimum is checked before
    the maximum.

    Args:
        prompt: Text prompt
        min_length: Minimum required length after stripping (default: 1)
        max_length: Maximum allowed length after stripping (default: 5000)

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not isinstance(prompt, str):
        return False, "Prompt must be a string"

    stripped_length = len(prompt.strip())
    if stripped_length < min_length:
        problem = f"Prompt must be at least {min_length} character(s)"
    elif stripped_length > max_length:
        problem = f"Prompt must be at most {max_length} characters"
    else:
        problem = None
    return problem is None, problem
def validate_character_name(name: str) -> Tuple[bool, Optional[str]]:
    """
    Check that a character name is a non-blank string of at most 100 chars.

    Surrounding whitespace is ignored for both the emptiness and the length
    check.

    Args:
        name: Character name

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not isinstance(name, str):
        return False, "Character name must be a string"

    trimmed = name.strip()
    if not trimmed:
        return False, "Character name cannot be empty"
    if len(trimmed) > 100:
        return False, "Character name must be at most 100 characters"
    return True, None
| # ============================================================================= | |
| # IMAGE VALIDATION | |
| # ============================================================================= | |
def validate_image(image: Image.Image) -> Tuple[bool, Optional[str]]:
    """
    Sanity-check a PIL Image object.

    Checks, in order:
    - the object is a PIL ``Image.Image`` instance
    - width and height are each between 1 and 8192 pixels
    - the image mode is one of RGB, RGBA, L or P

    Args:
        image: PIL Image to validate

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not isinstance(image, Image.Image):
        return False, "Invalid image object"

    width, height = image.size
    if min(width, height) < 1:
        return False, "Image has invalid dimensions"
    if max(width, height) > 8192:
        return False, "Image is too large (max 8192x8192 pixels)"

    supported_modes = ('RGB', 'RGBA', 'L', 'P')
    if image.mode in supported_modes:
        return True, None
    return False, f"Unsupported image mode: {image.mode}"
def validate_image_file(file_path: Path) -> Tuple[bool, Optional[str]]:
    """
    Check that a path points to an image file this app can use.

    Checks, in order:
    - the argument can be turned into a Path
    - the path exists and is a regular file (not a directory)
    - the extension is .png, .jpg, .jpeg, .webp or .bmp (case-insensitive)
    - Pillow can open it, and the opened image passes validate_image

    Args:
        file_path: Path (or path-like / string) of the image file

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not isinstance(file_path, Path):
        try:
            file_path = Path(file_path)
        except Exception:
            return False, "Invalid file path"

    # Each check runs only if the previous one passed.
    filesystem_checks = (
        (file_path.exists, "File not found"),
        (file_path.is_file, "Not a file"),
    )
    for passes, label in filesystem_checks:
        if not passes():
            return False, f"{label}: {file_path}"

    suffix = file_path.suffix
    if suffix.lower() not in {'.png', '.jpg', '.jpeg', '.webp', '.bmp'}:
        return False, f"Unsupported file format: {suffix}"

    try:
        with Image.open(file_path) as opened:
            return validate_image(opened)
    except Exception as exc:
        return False, f"Cannot open as image: {exc}"
def validate_image_upload_size(file_size_bytes: int) -> Tuple[bool, Optional[str]]:
    """
    Check an uploaded file's size against Settings.MAX_IMAGE_UPLOAD_SIZE.

    The limit is configured in megabytes, converted here with
    1 MB = 1024 * 1024 bytes. Sizes equal to the limit are accepted.

    Args:
        file_size_bytes: File size in bytes

    Returns:
        Tuple of (is_valid, error_message)
    """
    bytes_per_mb = 1024 * 1024
    limit_mb = Settings.MAX_IMAGE_UPLOAD_SIZE

    if file_size_bytes > limit_mb * bytes_per_mb:
        size_mb = file_size_bytes / bytes_per_mb
        return False, f"File too large: {size_mb:.1f}MB (max: {limit_mb}MB)"
    return True, None
| # ============================================================================= | |
| # GENERATION REQUEST VALIDATION | |
| # ============================================================================= | |
def validate_generation_request(
    prompt: str,
    backend: str,
    aspect_ratio: str,
    temperature: float,
    input_images: Optional[List[Image.Image]] = None,
    max_input_images: int = 3,
) -> Tuple[bool, Optional[str]]:
    """
    Validate a complete generation request.

    Runs the individual validators in order: prompt, backend, aspect ratio,
    temperature, then input images. It stops at the first failure.

    Args:
        prompt: Text prompt
        backend: Backend name
        aspect_ratio: Aspect ratio (display name or ratio value)
        temperature: Temperature value
        input_images: Optional list of input images; skipped when None/empty
        max_input_images: Maximum number of input images accepted (default: 3)

    Returns:
        Tuple of (is_valid, error_message)
        error_message is None if valid
    """
    valid, error = validate_prompt(prompt)
    if not valid:
        return False, f"Invalid prompt: {error}"

    # validate_backend and validate_aspect_ratio already produce messages
    # such as "Invalid backend: ..." / "Invalid aspect ratio: ...".
    # Their errors are returned as-is so the prefix is not duplicated.
    valid, error = validate_backend(backend)
    if not valid:
        return False, error

    valid, error = validate_aspect_ratio(aspect_ratio)
    if not valid:
        return False, error

    valid, error = validate_temperature(temperature)
    if not valid:
        return False, f"Invalid temperature: {error}"

    if input_images:
        if not isinstance(input_images, list):
            return False, "Input images must be a list"
        if len(input_images) > max_input_images:
            return False, f"Maximum {max_input_images} input images allowed"
        # Number images from 1 so messages match what a user would count.
        for idx, img in enumerate(input_images, 1):
            valid, error = validate_image(img)
            if not valid:
                return False, f"Invalid input image {idx}: {error}"

    return True, None
def validate_character_forge_request(
    character_name: str,
    initial_image: Optional[Image.Image],
    face_image: Optional[Image.Image],
    body_image: Optional[Image.Image],
    image_type: str,
    backend: str
) -> Tuple[bool, Optional[str]]:
    """
    Validate a Character Forge generation request.

    The character name and the backend are checked first. The images are
    then checked according to image_type:
    - "Face + Body (Separate)": face_image and body_image are both required.
    - anything else (e.g. "Face Only", "Full Body"): initial_image is required.

    Args:
        character_name: Name for character
        initial_image: Single input image (non-separate modes)
        face_image: Face image (Face + Body (Separate) mode)
        body_image: Body image (Face + Body (Separate) mode)
        image_type: Input mode label
        backend: Backend to use

    Returns:
        Tuple of (is_valid, error_message)
    """
    name_ok, name_error = validate_character_name(character_name)
    if not name_ok:
        return False, name_error

    backend_ok, backend_error = validate_backend(backend)
    if not backend_ok:
        return False, backend_error

    if image_type != "Face + Body (Separate)":
        # Single-image modes ("Face Only", "Full Body", or any other label).
        if initial_image is None:
            return False, f"Initial image is required for {image_type} mode"
        image_ok, image_error = validate_image(initial_image)
        if image_ok:
            return True, None
        return False, f"Invalid initial image: {image_error}"

    # Separate mode: both presence checks happen before either image is inspected.
    if face_image is None:
        return False, "Face image is required for Face+Body Separate mode"
    if body_image is None:
        return False, "Body image is required for Face+Body Separate mode"

    for label, picture in (("face", face_image), ("body", body_image)):
        image_ok, image_error = validate_image(picture)
        if not image_ok:
            return False, f"Invalid {label} image: {image_error}"

    return True, None
| # ============================================================================= | |
| # BACKEND AVAILABILITY VALIDATION | |
| # ============================================================================= | |
def validate_backend_available(backend: str, api_key: Optional[str] = None) -> Tuple[bool, Optional[str]]:
    """
    Check if a backend is available and properly configured.

    - Gemini: considered available when a non-empty ``api_key`` is given.
      No network call is made.
    - OmniGen2: sends GET ``{Settings.OMNIGEN2_BASE_URL}/health`` with a
      2-second timeout. Requires a successful response (``response.ok``)
      whose JSON body is an object with ``status == 'healthy'``.

    Failures are reported through the returned tuple rather than raised.

    Args:
        backend: Backend name
        api_key: API key (only consulted for the Gemini backend)

    Returns:
        Tuple of (is_available, error_message)
    """
    valid, error = validate_backend(backend)
    if not valid:
        return False, error

    if backend == Settings.BACKEND_GEMINI:
        if not api_key:
            return False, "Gemini API key not configured. Please set GEMINI_API_KEY or enter it in settings."
        return True, None

    if backend == Settings.BACKEND_OMNIGEN2:
        return _check_omnigen2_health()

    # Reached only for names in AVAILABLE_BACKENDS that have no check here.
    return False, f"Unknown backend: {backend}"


def _check_omnigen2_health() -> Tuple[bool, Optional[str]]:
    """Probe the OmniGen2 server's /health endpoint and report availability."""
    try:
        import requests  # imported here so only this probe depends on requests
    except ImportError:
        # Previously reported as "server not responding", which was misleading.
        return False, "OmniGen2 health check requires the 'requests' package"

    try:
        response = requests.get(f"{Settings.OMNIGEN2_BASE_URL}/health", timeout=2)
    except Exception as exc:  # any transport failure: refused, timeout, bad URL, ...
        logger.debug("OmniGen2 health check failed: %s", exc)
        return False, "OmniGen2 server not responding. Start it with: omnigen2_plugin/server.bat start"

    if not response.ok:
        return False, f"OmniGen2 server returned error: {response.status_code}"

    try:
        data = response.json()
    except ValueError:  # body is not valid JSON (requests' JSON error subclasses ValueError)
        data = None

    # Guard against non-object JSON bodies (e.g. a list) before calling .get.
    if isinstance(data, dict) and data.get('status') == 'healthy':
        return True, None
    return False, "OmniGen2 server is not healthy. Check server.log for details."
| # ============================================================================= | |
| # HELPER FUNCTIONS | |
| # ============================================================================= | |
def raise_if_invalid(is_valid: bool, error_message: Optional[str], exception_type=ValueError):
    """
    Turn a (is_valid, error_message) validation result into an exception.

    Does nothing when is_valid is truthy. Otherwise logs the message at
    ERROR level and raises ``exception_type(error_message)``.

    Args:
        is_valid: Validation result
        error_message: Error message to log and raise with when invalid
        exception_type: Exception class to raise (default: ValueError)

    Raises:
        exception_type: If is_valid is falsy
    """
    if is_valid:
        return
    logger.error(f"Validation failed: {error_message}")
    raise exception_type(error_message)
def log_validation_error(validation_result: Tuple[bool, Optional[str]], context: str = ""):
    """
    Emit a WARNING log line when a validation result indicates failure.

    The message is "Validation failed: <error>", or
    "Validation failed [<context>]: <error>" when a non-empty context is given.
    Valid results are ignored.

    Args:
        validation_result: (is_valid, error_message) tuple from a validator
        context: Optional label included in the log message
    """
    ok, message = validation_result
    if ok:
        return
    context_tag = f" [{context}]" if context else ""
    logger.warning(f"Validation failed{context_tag}: {message}")