Spaces:
Sleeping
Sleeping
| """Input validation utilities.""" | |
| import re | |
| from typing import Optional | |
| from writing_studio.core.config import settings | |
| from writing_studio.core.exceptions import ValidationError | |
| from writing_studio.utils.logging import logger | |
def sanitize_text(text: str) -> str:
    """
    Clean up raw input text.

    Null bytes are dropped, every run of whitespace (including newlines and
    tabs) is collapsed into a single space, and leading/trailing whitespace
    is trimmed.

    Args:
        text: Input text to sanitize

    Returns:
        Sanitized text; an empty string when the input is falsy
    """
    if not text:
        return ""

    # Drop null bytes first so they cannot survive between whitespace runs.
    without_nulls = text.replace("\x00", "")

    # Collapse each whitespace run into one space, then trim both ends.
    collapsed = re.sub(r"\s+", " ", without_nulls)
    return collapsed.strip()
def validate_text_input(
    text: str, max_length: Optional[int] = None, min_length: int = 1
) -> str:
    """
    Validate and sanitize text input.

    The text is passed through ``sanitize_text`` first; both length checks
    are applied to the sanitized result, not to the raw input.

    Args:
        text: Input text to validate
        max_length: Maximum allowed length. ``None`` (the default) means
            "use ``settings.max_text_length``"; any explicit integer,
            including 0, is honoured as given.
        min_length: Minimum allowed length

    Returns:
        Validated and sanitized text

    Raises:
        ValidationError: If the input is not a string, or the sanitized
            text is shorter than ``min_length`` or longer than the
            effective maximum
    """
    if not isinstance(text, str):
        raise ValidationError("Input must be a string", {"type": type(text).__name__})

    text = sanitize_text(text)
    length = len(text)

    # Check minimum length
    if length < min_length:
        raise ValidationError(
            f"Text must be at least {min_length} characters",
            {"length": length, "min_length": min_length},
        )

    # Compare against None explicitly: `max_length or default` would treat an
    # explicit 0 as "not provided" and silently apply the configured limit.
    max_len = settings.max_text_length if max_length is None else max_length
    if length > max_len:
        logger.warning(f"Text exceeds maximum length: {length} > {max_len}")
        raise ValidationError(
            f"Text exceeds maximum length of {max_len} characters",
            {"length": length, "max_length": max_len},
        )

    return text
def validate_model_name(model_name: str) -> str:
    """
    Validate HuggingFace model name.

    Surrounding whitespace is stripped before validation. The name must
    start with an ASCII letter or digit and may then contain only ASCII
    letters, digits, ``_``, ``-``, ``.`` and ``/``. Names containing ``..``
    are rejected as path traversal attempts.

    Args:
        model_name: Model identifier

    Returns:
        Validated model name with surrounding whitespace removed

    Raises:
        ValidationError: If the name is not a string, is empty, has an
            invalid format, or contains ``..``
    """
    if not isinstance(model_name, str):
        raise ValidationError("Model name must be a string", {"type": type(model_name).__name__})

    model_name = model_name.strip()
    if not model_name:
        raise ValidationError("Model name cannot be empty")

    # Basic validation for HuggingFace model names
    # Format: organization/model-name or just model-name
    # re.ASCII keeps \w to [a-zA-Z0-9_]; without it \w also matches non-ASCII
    # letters and digits, letting look-alike Unicode names through the check.
    if not re.fullmatch(r"[a-zA-Z0-9][\w\-./]*", model_name, flags=re.ASCII):
        raise ValidationError(
            "Invalid model name format", {"model_name": model_name}
        )

    # Check for path traversal attempts. A leading "/" needs no separate test:
    # the pattern above already requires an alphanumeric first character.
    if ".." in model_name:
        raise ValidationError(
            "Model name contains invalid characters", {"model_name": model_name}
        )

    return model_name
def validate_generation_params(
    max_length: int, num_sequences: int, temperature: float = 1.0
) -> dict:
    """
    Validate text generation parameters.

    All three parameters are checked before anything is raised, so a single
    ValidationError reports every invalid parameter at once.

    Args:
        max_length: Maximum generation length; a positive integer no larger
            than ``settings.max_model_length``
        num_sequences: Number of sequences to generate; an integer from 1 to 5
        temperature: Sampling temperature; a number greater than 0

    Returns:
        Dict with the validated ``max_length``, ``num_sequences`` and
        ``temperature`` values

    Raises:
        ValidationError: If any parameter is invalid; the error details map
            each offending parameter name to a message
    """
    errors = {}

    # The upper-bound checks are `elif` branches: they only make sense once
    # the value is known to be a positive int. Run unconditionally, a
    # non-numeric value (None, "10") would raise TypeError from `>` instead of
    # being reported as a ValidationError.
    if not isinstance(max_length, int) or max_length < 1:
        errors["max_length"] = "Must be a positive integer"
    elif max_length > settings.max_model_length:
        errors["max_length"] = f"Exceeds maximum of {settings.max_model_length}"

    if not isinstance(num_sequences, int) or num_sequences < 1:
        errors["num_sequences"] = "Must be a positive integer"
    elif num_sequences > 5:
        errors["num_sequences"] = "Cannot exceed 5 sequences"

    # `not (t > 0)` rather than `t <= 0`: both comparisons are False for NaN,
    # so only this form rejects a NaN temperature.
    if not isinstance(temperature, (int, float)) or not (temperature > 0):
        errors["temperature"] = "Must be a positive number"

    if errors:
        raise ValidationError("Invalid generation parameters", errors)

    return {
        "max_length": max_length,
        "num_sequences": num_sequences,
        "temperature": temperature,
    }