Spaces:
Running
Running
| """Input validation and sanitization utilities. | |
| Provides query validation, length limits, character filtering, and | |
| injection pattern detection as a defense-in-depth layer beyond the | |
| security agent. | |
| """ | |
from __future__ import annotations

import re
from dataclasses import dataclass, field

from utils.logging import get_logger
# Module-level logger. The validators in this module call it as
# logger.warning("<event_name>", key=value, ...), so get_logger presumably
# returns a structured (keyword-context) logger -- confirm in utils.logging.
logger = get_logger(__name__)
| # Maximum query length to prevent abuse | |
| MAX_QUERY_LENGTH: int = 4000 | |
| MAX_QUERY_WORDS: int = 500 | |
| # Dangerous characters/patterns to flag | |
| _DANGEROUS_PATTERNS: list[re.Pattern] = [ | |
| # SQL injection attempts | |
| re.compile( | |
| r"(\b(union|select|insert|update|delete|drop|create|alter)\b.*\b(from|into|table|database)\b)", | |
| re.IGNORECASE, | |
| ), | |
| # Command injection | |
| re.compile(r"[;&|`$]\s*\w+", re.IGNORECASE), | |
| # Path traversal | |
| re.compile(r"\.\.[\/\\]"), | |
| # Null bytes | |
| re.compile(r"\x00"), | |
| # Excessive repetition (DoS pattern) | |
| re.compile(r"(.)\1{50,}"), | |
| ] | |
| # Allowed characters for basic sanitization | |
| _SAFE_QUERY_PATTERN = re.compile( | |
| r"^[\w\s\-.,;:!?'\"()[\]{}@#$%&*+/=<>|~^`\u0600-\u06FF]+$", re.UNICODE | |
| ) | |
@dataclass
class ValidationResult:
    """Result of input validation.

    Attributes:
        valid: True when the input was accepted (validate_query also returns
            True for inputs accepted "with warnings").
        sanitized_query: Cleaned-up copy of the input. validate_file_path
            reuses this field to carry the stripped file path.
        message: Human-readable summary of the outcome.
        violations: Violation tags such as "empty_query"; empty when none.
    """

    valid: bool
    sanitized_query: str
    message: str = ""
    # default_factory gives every instance its own list instead of a shared one.
    violations: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Callers may still pass violations=None explicitly; normalise it to [].
        if self.violations is None:
            self.violations = []
def validate_query(
    query: str,
    max_length: int | None = None,
    max_words: int | None = None,
    allow_dangerous: bool = False,
) -> ValidationResult:
    """Validate and sanitize a user query.

    Checks the query against character-length and word-count limits and
    against the module-level ``_DANGEROUS_PATTERNS``. It also builds a lightly
    cleaned copy: NUL bytes are removed, surrounding whitespace is trimmed,
    CRLF/CR line endings become LF, and runs of spaces collapse to one space.

    Args:
        query: The raw user query.
        max_length: Maximum character length. ``None`` (the default) means
            ``MAX_QUERY_LENGTH``. An explicit value, including 0, is honoured.
        max_words: Maximum whitespace-separated word count. ``None`` (the
            default) means ``MAX_QUERY_WORDS``.
        allow_dangerous: If True, dangerous-pattern matches are only reported
            as warnings instead of causing rejection. Length and word-count
            violations are always rejected. Defaults to False.

    Returns:
        ValidationResult with the valid flag, the sanitized query, a summary
        message, and the list of violation tags.
    """
    # Check for None or empty / whitespace-only input.
    if not query or not query.strip():
        return ValidationResult(
            valid=False,
            sanitized_query="",
            message="Query cannot be empty.",
            violations=["empty_query"],
        )

    # Use `is None` rather than `or`, so an explicit 0 is not silently
    # replaced by the module default.
    max_len = MAX_QUERY_LENGTH if max_length is None else max_length
    max_w = MAX_QUERY_WORDS if max_words is None else max_words

    # Size limits: always fatal, regardless of allow_dangerous.
    size_violations: list[str] = []
    if len(query) > max_len:
        size_violations.append(f"query_too_long: {len(query)} > {max_len}")
    word_count = len(query.split())
    if word_count > max_w:
        size_violations.append(f"query_too_many_words: {word_count} > {max_w}")

    # Heuristic pattern matches: allow_dangerous may downgrade these to warnings.
    pattern_violations = [
        f"dangerous_pattern: {pattern.pattern[:40]}..."
        for pattern in _DANGEROUS_PATTERNS
        if pattern.search(query)
    ]
    violations = size_violations + pattern_violations

    # Sanitize. Drop NUL bytes first so that strip() can reach whitespace
    # hidden behind them, then normalize newlines and collapse spaces.
    sanitized = query.replace("\x00", "").strip()
    sanitized = re.sub(r"\r\n|\r", "\n", sanitized)
    sanitized = re.sub(r" {2,}", " ", sanitized)

    if not violations:
        return ValidationResult(
            valid=True,
            sanitized_query=sanitized,
            message="Query validation passed.",
        )

    if allow_dangerous and not size_violations:
        logger.warning(
            "query_validation_warnings",
            violations=violations,
            query_len=len(query),
        )
        return ValidationResult(
            valid=True,
            sanitized_query=sanitized,
            message="Query passed with warnings.",
            violations=violations,
        )

    logger.warning(
        "query_validation_failed",
        violations=violations,
        query_len=len(query),
    )
    return ValidationResult(
        valid=False,
        sanitized_query=sanitized,
        message=f"Query validation failed: {'; '.join(violations)}",
        violations=violations,
    )
def _sanitize_list(items: list) -> list:
    """Sanitize a list value: keep at most 50 items, clean strings, recurse into containers."""
    cleaned: list = []
    for item in items[:50]:  # Limit list size
        if isinstance(item, str):
            cleaned.append(item.strip()[:500].replace("\x00", ""))
        elif isinstance(item, dict):
            # Dicts inside lists were previously stored unsanitized.
            cleaned.append(sanitize_metadata(item))
        elif isinstance(item, list):
            cleaned.append(_sanitize_list(item))
        else:
            cleaned.append(item)
    return cleaned


def sanitize_metadata(metadata: dict) -> dict:
    """Sanitize metadata values to prevent injection in stored payloads.

    Rules by value type:
      * keys: any character outside ``[A-Za-z0-9_.-]`` (Unicode word chars
        are kept) becomes ``_``, and the key is truncated to 128 characters;
      * str: stripped, truncated to 2000 characters, NUL bytes removed;
      * int / float / bool: kept as-is;
      * list: at most 50 items. Strings are stripped, truncated to 500
        characters, and have NUL bytes removed. Nested dicts and lists are
        sanitized recursively. Other items are kept as-is;
      * dict: sanitized recursively;
      * anything else: ``str(value)`` truncated to 500 characters.

    NOTE(review): distinct keys that sanitize to the same string (e.g.
    "a b" and "a!b") collide, and the later one wins.

    Args:
        metadata: Raw metadata dict.

    Returns:
        A new dict with sanitized keys and values.
    """
    sanitized: dict = {}
    for key, value in metadata.items():
        # Sanitize keys (\w already covers "_").
        safe_key = re.sub(r"[^\w.\-]", "_", str(key))[:128]
        if isinstance(value, str):
            sanitized[safe_key] = value.strip()[:2000].replace("\x00", "")
        elif isinstance(value, (int, float, bool)):
            sanitized[safe_key] = value
        elif isinstance(value, list):
            sanitized[safe_key] = _sanitize_list(value)
        elif isinstance(value, dict):
            sanitized[safe_key] = sanitize_metadata(value)
        else:
            sanitized[safe_key] = str(value)[:500]
    return sanitized
def validate_file_path(file_path: str) -> ValidationResult:
    """Check whether a file path is acceptable for upload.

    A path is rejected when it is empty or whitespace-only, contains ".." or
    "~" anywhere, contains a NUL byte, or is longer than 512 characters.
    All non-empty checks run on the raw (unstripped) path, and every failing
    check is reported.

    NOTE(review): the ".." / "~" test is a plain substring check, so names
    such as "v1..2.txt" or "file~1.txt" are rejected too. Absolute paths are
    not rejected by this function.

    Args:
        file_path: The file path to validate.

    Returns:
        ValidationResult whose ``sanitized_query`` holds the stripped path.
    """
    if not (file_path and file_path.strip()):
        return ValidationResult(
            valid=False,
            sanitized_query="",
            message="File path cannot be empty.",
            violations=["empty_path"],
        )

    # (failed?, tag) pairs, in the order the tags are reported.
    checks = (
        (".." in file_path or "~" in file_path, "path_traversal_attempt"),
        ("\x00" in file_path, "null_byte_injection"),
        (len(file_path) > 512, "path_too_long"),
    )
    problems = [tag for failed, tag in checks if failed]
    cleaned = file_path.strip()

    if not problems:
        return ValidationResult(
            valid=True,
            sanitized_query=cleaned,
            message="File path validation passed.",
        )

    # Only the first 100 characters of the path go to the log.
    logger.warning(
        "file_path_validation_failed",
        violations=problems,
        path=file_path[:100],
    )
    return ValidationResult(
        valid=False,
        sanitized_query=cleaned,
        message=f"File path validation failed: {'; '.join(problems)}",
        violations=problems,
    )