# jmisak's picture
# Upload 41 files
# aeb3f7c verified
"""Input validation utilities."""
import re
from typing import Optional
from writing_studio.core.config import settings
from writing_studio.core.exceptions import ValidationError
from writing_studio.utils.logging import logger
def sanitize_text(text: str) -> str:
    """
    Clean raw input text.

    Null bytes are dropped, every run of whitespace (spaces, tabs,
    newlines, ...) is collapsed into a single space, and whitespace at
    either end is trimmed.

    Args:
        text: Raw text to clean

    Returns:
        The cleaned text, or an empty string when the input is falsy
    """
    if not text:
        return ""

    # Drop null bytes first so they never reach the later steps
    without_nulls = text.replace("\x00", "")

    # Collapse each whitespace run to one space, then trim both ends
    collapsed = re.sub(r"\s+", " ", without_nulls)
    return collapsed.strip()
def validate_text_input(
    text: str, max_length: Optional[int] = None, min_length: int = 1
) -> str:
    """
    Validate and sanitize text input.

    The text is sanitized before it is measured, so both length limits
    apply to the sanitized text rather than to the raw input.

    Args:
        text: Input text to validate
        max_length: Maximum allowed length. When None, the limit is taken
            from settings.max_text_length; an explicit 0 is honoured.
        min_length: Minimum allowed length

    Returns:
        Validated and sanitized text

    Raises:
        ValidationError: If text is not a string, or if its sanitized
            length is below min_length or above the maximum
    """
    if not isinstance(text, str):
        raise ValidationError("Input must be a string", {"type": type(text).__name__})

    # Sanitize
    text = sanitize_text(text)
    length = len(text)

    # Check minimum length
    if length < min_length:
        raise ValidationError(
            f"Text must be at least {min_length} characters",
            {"length": length, "min_length": min_length},
        )

    # Check maximum length.
    # Compare against None explicitly: `max_length or default` would treat a
    # caller-supplied 0 as "not provided" and silently use the global limit.
    max_len = settings.max_text_length if max_length is None else max_length
    if length > max_len:
        logger.warning(f"Text exceeds maximum length: {length} > {max_len}")
        raise ValidationError(
            f"Text exceeds maximum length of {max_len} characters",
            {"length": length, "max_length": max_len},
        )

    return text
def validate_model_name(model_name: str) -> str:
    """
    Check that a HuggingFace model identifier is well-formed.

    Args:
        model_name: Model identifier, either "organization/model-name"
            or just "model-name"

    Returns:
        The identifier with surrounding whitespace removed

    Raises:
        ValidationError: If the value is not a string, is empty after
            stripping, does not match the allowed format, or looks like
            a path traversal attempt
    """
    if not isinstance(model_name, str):
        raise ValidationError("Model name must be a string", {"type": type(model_name).__name__})

    cleaned = model_name.strip()
    if cleaned == "":
        raise ValidationError("Model name cannot be empty")

    # First character must be an ASCII letter or digit; the remainder may
    # contain word characters, "-", "." and "/".
    well_formed = re.match(r"^[a-zA-Z0-9][\w\-./]*$", cleaned) is not None
    if not well_formed:
        raise ValidationError(
            "Invalid model name format", {"model_name": cleaned}
        )

    # Reject path traversal attempts. The format check above already rules
    # out a leading "/", so that test is kept purely as defence in depth.
    is_absolute = cleaned.startswith("/")
    if is_absolute or ".." in cleaned:
        raise ValidationError(
            "Model name contains invalid characters", {"model_name": cleaned}
        )

    return cleaned
def validate_generation_params(
    max_length: int, num_sequences: int, temperature: float = 1.0
) -> dict:
    """
    Validate text generation parameters.

    All three parameters are checked before raising, so a single
    ValidationError reports every invalid field at once.

    Args:
        max_length: Maximum generation length; a positive integer no
            greater than settings.max_model_length
        num_sequences: Number of sequences to generate; an integer from
            1 to 5
        temperature: Sampling temperature; a number greater than zero
            (NaN is rejected)

    Returns:
        Dict with the keys "max_length", "num_sequences" and
        "temperature" holding the validated values

    Raises:
        ValidationError: If any parameter is invalid; the error details
            map each offending parameter name to a message
    """
    errors = {}

    # The upper-bound checks are chained with elif so they only run on
    # values already known to be ints. Comparing a str or None against an
    # int would raise TypeError instead of reporting a validation error.
    if not isinstance(max_length, int) or max_length < 1:
        errors["max_length"] = "Must be a positive integer"
    elif max_length > settings.max_model_length:
        errors["max_length"] = f"Exceeds maximum of {settings.max_model_length}"

    if not isinstance(num_sequences, int) or num_sequences < 1:
        errors["num_sequences"] = "Must be a positive integer"
    elif num_sequences > 5:
        errors["num_sequences"] = "Cannot exceed 5 sequences"

    # `not temperature > 0` rather than `temperature <= 0`: every ordered
    # comparison with NaN is False, so the latter would let NaN through.
    if not isinstance(temperature, (int, float)) or not temperature > 0:
        errors["temperature"] = "Must be a positive number"

    if errors:
        raise ValidationError("Invalid generation parameters", errors)

    return {
        "max_length": max_length,
        "num_sequences": num_sequences,
        "temperature": temperature,
    }