"""Input validation utilities."""

import math
import re
from typing import Optional

from writing_studio.core.config import settings
from writing_studio.core.exceptions import ValidationError
from writing_studio.utils.logging import logger

# Matches any run of whitespace (spaces, tabs, newlines, ...).
_WHITESPACE_RE = re.compile(r"\s+")

# HuggingFace-style model id: "model-name" or "organization/model-name".
# Must start with an alphanumeric character, followed by word characters,
# "-", ".", or "/".
_MODEL_NAME_RE = re.compile(r"^[a-zA-Z0-9][\w\-./]*$")

# Upper bound on how many sequences a single generation request may ask for.
MAX_NUM_SEQUENCES = 5


def _is_strict_int(value: object) -> bool:
    """Return True if *value* is an ``int`` but not a ``bool``.

    ``bool`` is a subclass of ``int``, so a bare ``isinstance(value, int)``
    would accept ``True``/``False`` as lengths or counts.
    """
    return isinstance(value, int) and not isinstance(value, bool)


def sanitize_text(text: str) -> str:
    """
    Sanitize input text by removing NUL characters and normalizing whitespace.

    Every run of whitespace (including newlines and tabs) is collapsed into a
    single space, and leading/trailing whitespace is stripped.

    Args:
        text: Input text to sanitize

    Returns:
        Sanitized text; an empty string for empty (falsy) input
    """
    if not text:
        return ""

    # Remove null bytes
    text = text.replace("\x00", "")

    # Normalize whitespace, then strip leading/trailing whitespace
    text = _WHITESPACE_RE.sub(" ", text)
    return text.strip()


def validate_text_input(
    text: str, max_length: Optional[int] = None, min_length: int = 1
) -> str:
    """
    Validate and sanitize text input.

    The text is sanitized first (see ``sanitize_text``), so the length limits
    apply to the sanitized text, not to the raw input.

    Args:
        text: Input text to validate
        max_length: Maximum allowed length; ``None`` means use
            ``settings.max_text_length``
        min_length: Minimum allowed length

    Returns:
        Validated and sanitized text

    Raises:
        ValidationError: If ``text`` is not a string, or its sanitized length
            is below ``min_length`` or above the maximum length
    """
    if not isinstance(text, str):
        raise ValidationError("Input must be a string", {"type": type(text).__name__})

    # Sanitize
    text = sanitize_text(text)
    length = len(text)

    # Check minimum length
    if length < min_length:
        raise ValidationError(
            f"Text must be at least {min_length} characters",
            {"length": length, "min_length": min_length},
        )

    # Check maximum length. Compare against None explicitly so that an
    # explicit max_length of 0 is honored instead of silently falling back
    # to the configured default.
    max_len = settings.max_text_length if max_length is None else max_length
    if length > max_len:
        logger.warning(f"Text exceeds maximum length: {length} > {max_len}")
        raise ValidationError(
            f"Text exceeds maximum length of {max_len} characters",
            {"length": length, "max_length": max_len},
        )

    return text


def validate_model_name(model_name: str) -> str:
    """
    Validate HuggingFace model name.

    Args:
        model_name: Model identifier, e.g. ``"gpt2"`` or ``"org/model-name"``

    Returns:
        Validated model name with surrounding whitespace stripped

    Raises:
        ValidationError: If the name is not a string, is empty after
            stripping, does not match the allowed character pattern, or
            contains ``".."`` or a leading ``"/"``
    """
    if not isinstance(model_name, str):
        raise ValidationError("Model name must be a string", {"type": type(model_name).__name__})

    model_name = model_name.strip()

    if not model_name:
        raise ValidationError("Model name cannot be empty")

    # Basic validation for HuggingFace model names
    # Format: organization/model-name or just model-name
    if not _MODEL_NAME_RE.match(model_name):
        raise ValidationError(
            "Invalid model name format", {"model_name": model_name}
        )

    # Check for path traversal attempts. The pattern above already forbids a
    # leading "/", so that test is defense in depth; ".." is not excluded by
    # the pattern and must be rejected here.
    if ".." in model_name or model_name.startswith("/"):
        raise ValidationError(
            "Model name contains invalid characters", {"model_name": model_name}
        )

    return model_name


def validate_generation_params(
    max_length: int, num_sequences: int, temperature: float = 1.0
) -> dict:
    """
    Validate text generation parameters.

    All parameters are checked and every failure is reported together.

    Args:
        max_length: Maximum generation length; a positive int not exceeding
            ``settings.max_model_length``
        num_sequences: Number of sequences to generate; a positive int not
            exceeding ``MAX_NUM_SEQUENCES``
        temperature: Sampling temperature; a finite positive int or float

    Returns:
        Validated parameters as a dict with keys ``max_length``,
        ``num_sequences`` and ``temperature``

    Raises:
        ValidationError: If any parameter is invalid; the error details map
            each offending parameter name to a message
    """
    errors = {}

    # The range check is an `elif` so it only runs on values already known to
    # be ints; otherwise a non-numeric value would raise TypeError on `>`.
    if not _is_strict_int(max_length) or max_length < 1:
        errors["max_length"] = "Must be a positive integer"
    elif max_length > settings.max_model_length:
        errors["max_length"] = f"Exceeds maximum of {settings.max_model_length}"

    if not _is_strict_int(num_sequences) or num_sequences < 1:
        errors["num_sequences"] = "Must be a positive integer"
    elif num_sequences > MAX_NUM_SEQUENCES:
        errors["num_sequences"] = f"Cannot exceed {MAX_NUM_SEQUENCES} sequences"

    # bool is rejected explicitly because it is an int subclass. NaN slips
    # past `<= 0` (all NaN comparisons are False), so non-finite floats are
    # rejected separately. Ints are always finite, and calling math.isfinite
    # on a huge int would raise OverflowError, so only floats are checked.
    if (
        isinstance(temperature, bool)
        or not isinstance(temperature, (int, float))
        or temperature <= 0
        or (isinstance(temperature, float) and not math.isfinite(temperature))
    ):
        errors["temperature"] = "Must be a positive number"

    if errors:
        raise ValidationError("Invalid generation parameters", errors)

    return {
        "max_length": max_length,
        "num_sequences": num_sequences,
        "temperature": temperature,
    }