Spaces:
Sleeping
Sleeping
File size: 4,154 Bytes
aeb3f7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
"""Input validation utilities."""
import re
from typing import Optional
from writing_studio.core.config import settings
from writing_studio.core.exceptions import ValidationError
from writing_studio.utils.logging import logger
def sanitize_text(text: str) -> str:
    """
    Normalize raw text before it is validated or processed further.

    Three things happen, in order: NUL characters (``"\\x00"``) are dropped,
    every run of whitespace (spaces, tabs, newlines, ...) is collapsed into a
    single space, and the result is trimmed at both ends. Note that this also
    flattens line and paragraph breaks into single spaces.

    Args:
        text: Raw input text. Any falsy value (e.g. ``""`` or ``None``)
            yields an empty string.

    Returns:
        The normalized text.
    """
    # Falsy input short-circuits before any string method is called.
    if not text:
        return ""
    return re.sub(r"\s+", " ", text.replace("\x00", "")).strip()
def validate_text_input(
    text: str, max_length: Optional[int] = None, min_length: int = 1
) -> str:
    """
    Validate and sanitize text input.

    The text is first passed through ``sanitize_text`` (NUL characters
    removed, whitespace runs collapsed to one space, ends stripped); both
    length limits are applied to that sanitized result, which is also what
    gets returned.

    Args:
        text: Input text to validate
        max_length: Maximum allowed length. ``None`` (the default) means
            "use ``settings.max_text_length``"; any other value, including 0,
            is used exactly as given.
        min_length: Minimum allowed length

    Returns:
        Validated and sanitized text

    Raises:
        ValidationError: If ``text`` is not a string, or its sanitized length
            is below ``min_length`` or above the effective maximum
    """
    if not isinstance(text, str):
        raise ValidationError("Input must be a string", {"type": type(text).__name__})

    # Sanitize first so the limits are measured on the text actually returned.
    text = sanitize_text(text)
    length = len(text)

    # Check minimum length
    if length < min_length:
        raise ValidationError(
            f"Text must be at least {min_length} characters",
            {"length": length, "min_length": min_length},
        )

    # Fall back to the configured limit only when no limit was supplied;
    # `max_length or ...` would silently discard an explicit 0.
    max_len = settings.max_text_length if max_length is None else max_length
    if length > max_len:
        logger.warning(f"Text exceeds maximum length: {length} > {max_len}")
        raise ValidationError(
            f"Text exceeds maximum length of {max_len} characters",
            {"length": length, "max_length": max_len},
        )

    return text
def validate_model_name(model_name: str) -> str:
    """
    Check that a HuggingFace model identifier is well-formed.

    The identifier is intended to be ``model-name`` or
    ``organization/model-name``. After surrounding whitespace is trimmed it
    must start with an ASCII letter or digit and may then contain only word
    characters (Python ``\\w``), ``-``, ``.`` and ``/``. The pattern itself
    does not limit how many ``/`` separators appear. Any identifier containing
    ``..`` is rejected as a path-traversal attempt.

    Args:
        model_name: Model identifier to check.

    Returns:
        The identifier with surrounding whitespace removed.

    Raises:
        ValidationError: If the value is not a string, is blank after
            trimming, does not match the allowed format, or contains ``..``.
    """
    if not isinstance(model_name, str):
        raise ValidationError("Model name must be a string", {"type": type(model_name).__name__})

    candidate = model_name.strip()
    if candidate == "":
        raise ValidationError("Model name cannot be empty")

    details = {"model_name": candidate}

    well_formed = re.match(r"^[a-zA-Z0-9][\w\-./]*$", candidate) is not None
    if not well_formed:
        raise ValidationError("Invalid model name format", details)

    # Defence in depth: the pattern already forbids a leading "/", but ".."
    # can still appear between otherwise-allowed characters (e.g. "a/../b").
    traversal = ".." in candidate or candidate.startswith("/")
    if traversal:
        raise ValidationError("Model name contains invalid characters", details)

    return candidate
def validate_generation_params(
    max_length: int,
    num_sequences: int,
    temperature: float = 1.0,
    *,
    max_sequences: int = 5,
) -> dict:
    """
    Validate text generation parameters.

    Every parameter is checked, and all problems are reported together in a
    single ValidationError rather than stopping at the first one.

    Args:
        max_length: Maximum generation length; must be an int in
            ``[1, settings.max_model_length]`` (``bool`` is rejected)
        num_sequences: Number of sequences to generate; must be an int in
            ``[1, max_sequences]`` (``bool`` is rejected)
        temperature: Sampling temperature; must be a finite int or float
            greater than 0 (``bool`` is rejected)
        max_sequences: Upper bound for ``num_sequences``. Keyword-only;
            defaults to 5, the previously hard-coded limit.

    Returns:
        Dict with the validated ``max_length``, ``num_sequences`` and
        ``temperature`` values

    Raises:
        ValidationError: If any parameter is invalid; its details dict maps
            each offending parameter name to an error message
    """
    import math  # function-scope import keeps this block self-contained

    def _is_positive_int(value: object) -> bool:
        # bool subclasses int, but True/False are never meaningful counts.
        return isinstance(value, int) and not isinstance(value, bool) and value >= 1

    errors = {}

    # Range checks are in `elif` so they only run on values that passed the
    # type check. Otherwise comparing None or a str against an int would raise
    # TypeError instead of ValidationError, and a large non-int value would
    # have its type error overwritten by a range error.
    if not _is_positive_int(max_length):
        errors["max_length"] = "Must be a positive integer"
    elif max_length > settings.max_model_length:
        errors["max_length"] = f"Exceeds maximum of {settings.max_model_length}"

    if not _is_positive_int(num_sequences):
        errors["num_sequences"] = "Must be a positive integer"
    elif num_sequences > max_sequences:
        errors["num_sequences"] = f"Cannot exceed {max_sequences} sequences"

    if (
        isinstance(temperature, bool)
        or not isinstance(temperature, (int, float))
        or temperature <= 0
    ):
        errors["temperature"] = "Must be a positive number"
    elif isinstance(temperature, float) and not math.isfinite(temperature):
        # `nan <= 0` is False, so NaN slips past the check above; +inf is
        # "positive" too. Only floats can be non-finite, and testing ints
        # with math.isfinite could overflow on huge values.
        errors["temperature"] = "Must be a finite number"

    if errors:
        raise ValidationError("Invalid generation parameters", errors)

    return {
        "max_length": max_length,
        "num_sequences": num_sequences,
        "temperature": temperature,
    }
|