| """ | |
| Input validation and safety checks for prompts. | |
| This module provides comprehensive validation utilities for prompt safety, | |
| content filtering, and input sanitization to ensure secure and reliable | |
| prompt processing. | |
| """ | |
| import re | |
| import html | |
| from typing import List, Dict, Optional, Tuple, Set | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from config.logging import get_logger | |
| logger = get_logger(__name__) | |
| class ValidationSeverity(str, Enum): | |
| """Severity levels for validation issues.""" | |
| INFO = "info" | |
| WARNING = "warning" | |
| ERROR = "error" | |
| CRITICAL = "critical" | |
@dataclass
class ValidationIssue:
    """Represents a validation issue."""

    severity: ValidationSeverity
    code: str
    message: str
    location: Optional[str] = None
    suggestion: Optional[str] = None


@dataclass
class ValidationResult:
    """Result of a validation process."""

    is_valid: bool
    issues: List[ValidationIssue]
    sanitized_content: Optional[str] = None
    risk_score: float = 0.0

    def has_errors(self) -> bool:
        """Check if there are any error- or critical-level issues."""
        return any(
            issue.severity in (ValidationSeverity.ERROR, ValidationSeverity.CRITICAL)
            for issue in self.issues
        )

    def has_warnings(self) -> bool:
        """Check if there are any warning-level issues."""
        return any(issue.severity == ValidationSeverity.WARNING for issue in self.issues)


class PromptValidator:
    """
    Comprehensive prompt validation and safety checking.

    Provides multiple layers of validation including:
    - Content safety and injection detection
    - Format compliance checking
    - Length and structure validation
    - Business context validation
    """

    # Prompt injection patterns
    INJECTION_PATTERNS = {
        "role_manipulation": [
            r"(?i)\b(ignore|forget|disregard)\s+(previous|above|earlier|all)\s+(instructions?|prompts?|rules?)",
            r"(?i)\b(act\s+as|pretend\s+to\s+be|roleplay\s+as|simulate)\s+(?!a\s+business|an?\s+analyst)",
            r"(?i)\b(you\s+are\s+now|from\s+now\s+on|instead)\s+",
            r"(?i)\b(override|bypass|disable|turn\s+off)\s+"
        ],
        "system_commands": [
            r"(?i)\b(system|admin|root|sudo)\s*[\(\[]",
            r"(?i)\b(exec|eval|run|execute)\s*[\(\[]",
            r"(?i)<script|javascript:|data:|vbscript:",
            r"(?i)\$\{|\$\(|`[^`]*`"  # Variable expansion
        ],
        "format_breaking": [
            r"(?i)\b(don't\s+use|ignore|skip|avoid)\s+(json|format|structure)",
            r"(?i)\b(plain\s+text|free\s+form|unstructured|raw\s+output)",
            r"(?i)\b(no\s+format|without\s+format|format\s+free)"
        ],
        "data_extraction": [
            r"(?i)\b(show|reveal|display|print|output)\s+(all|your|the)\s+(data|information|content)",
            r"(?i)\b(what\s+is\s+your|tell\s+me\s+your)\s+(system|prompt|instructions?)",
            r"(?i)\b(dump|export|leak)\s+"
        ]
    }

    # Suspicious keywords by category
    SUSPICIOUS_KEYWORDS = {
        "high_risk": {
            "jailbreak", "prompt_injection", "ignore_instructions", "system_override",
            "admin_access", "root_privileges", "bypass_safety", "disable_filters"
        },
        "medium_risk": {
            "eval", "exec", "sudo", "admin", "root", "shell", "command",
            "injection", "exploit", "hack", "bypass"
        },
        "format_risk": {
            "plain_text", "no_json", "unstructured", "free_form", "raw_output",
            "ignore_format", "skip_structure", "format_free"
        }
    }

    # Required elements for topic extraction prompts
    REQUIRED_ELEMENTS = {
        "output_format": ["json", "format", "structure"],
        "topic_fields": ["topic_name", "topic_type", "confidence_score"],
        "business_context": ["business", "topic", "extract", "analyze"]
    }

    # Content length limits
    LENGTH_LIMITS = {
        "min_prompt_length": 20,
        "max_prompt_length": 15000,
        "max_line_length": 500,
        "max_word_length": 50
    }

    def __init__(self):
        """Initialize the prompt validator."""
        self.logger = get_logger(f"{__name__}.{self.__class__.__name__}")

        # Compile regex patterns once for performance
        self._compiled_patterns = {}
        for category, patterns in self.INJECTION_PATTERNS.items():
            self._compiled_patterns[category] = [re.compile(pattern) for pattern in patterns]

    def validate_prompt(self, prompt: str, context: Optional[Dict] = None) -> ValidationResult:
        """
        Comprehensive prompt validation.

        Args:
            prompt: The prompt text to validate
            context: Optional context for validation

        Returns:
            ValidationResult with issues and sanitized content
        """
        issues = []
        risk_score = 0.0

        try:
            # Basic validation
            basic_issues, basic_risk = self._validate_basic_structure(prompt)
            issues.extend(basic_issues)
            risk_score += basic_risk

            # Safety validation
            safety_issues, safety_risk = self._validate_safety(prompt)
            issues.extend(safety_issues)
            risk_score += safety_risk

            # Format validation
            format_issues, format_risk = self._validate_format_compliance(prompt)
            issues.extend(format_issues)
            risk_score += format_risk

            # Business context validation
            business_issues, business_risk = self._validate_business_context(prompt, context)
            issues.extend(business_issues)
            risk_score += business_risk

            # Content sanitization
            sanitized_content = self._sanitize_content(prompt)

            # Determine overall validity
            is_valid = not any(
                issue.severity in (ValidationSeverity.ERROR, ValidationSeverity.CRITICAL)
                for issue in issues
            )

            # Normalize risk score (0.0 to 1.0)
            risk_score = min(1.0, max(0.0, risk_score))

            result = ValidationResult(
                is_valid=is_valid,
                issues=issues,
                sanitized_content=sanitized_content,
                risk_score=risk_score
            )

            self.logger.debug(
                f"Prompt validation completed: valid={is_valid}, "
                f"issues={len(issues)}, risk_score={risk_score:.2f}"
            )

            return result

        except Exception as e:
            self.logger.error(f"Error during prompt validation: {str(e)}")
            return ValidationResult(
                is_valid=False,
                issues=[ValidationIssue(
                    severity=ValidationSeverity.CRITICAL,
                    code="VALIDATION_ERROR",
                    message=f"Validation failed: {str(e)}"
                )],
                risk_score=1.0
            )

    def _validate_basic_structure(self, prompt: str) -> Tuple[List[ValidationIssue], float]:
        """Validate basic prompt structure and length."""
        issues = []
        risk_score = 0.0

        # Length validation
        if len(prompt) < self.LENGTH_LIMITS["min_prompt_length"]:
            issues.append(ValidationIssue(
                severity=ValidationSeverity.ERROR,
                code="PROMPT_TOO_SHORT",
                message=f"Prompt is too short ({len(prompt)} chars). Minimum: {self.LENGTH_LIMITS['min_prompt_length']}",
                suggestion="Provide more detailed instructions for better results"
            ))
            risk_score += 0.3

        if len(prompt) > self.LENGTH_LIMITS["max_prompt_length"]:
            issues.append(ValidationIssue(
                severity=ValidationSeverity.WARNING,
                code="PROMPT_TOO_LONG",
                message=f"Prompt is very long ({len(prompt)} chars). May exceed token limits.",
                suggestion="Consider breaking into smaller, focused prompts"
            ))
            risk_score += 0.1

        # Line length validation
        lines = prompt.split('\n')
        long_lines = [i for i, line in enumerate(lines)
                      if len(line) > self.LENGTH_LIMITS["max_line_length"]]
        if long_lines:
            issues.append(ValidationIssue(
                severity=ValidationSeverity.INFO,
                code="LONG_LINES",
                message=f"Found {len(long_lines)} very long lines",
                location=f"Lines: {long_lines[:5]}",  # Show first 5
                suggestion="Consider breaking long lines for better readability"
            ))

        # Word length validation
        words = prompt.split()
        long_words = [word for word in words
                      if len(word) > self.LENGTH_LIMITS["max_word_length"]]
        if len(long_words) > 5:  # Allow some long words
            issues.append(ValidationIssue(
                severity=ValidationSeverity.INFO,
                code="LONG_WORDS",
                message=f"Found {len(long_words)} very long words",
                suggestion="Very long words may indicate encoded content or errors"
            ))

        # Character encoding validation
        try:
            prompt.encode('utf-8')
        except UnicodeEncodeError as e:
            issues.append(ValidationIssue(
                severity=ValidationSeverity.ERROR,
                code="ENCODING_ERROR",
                message=f"Invalid character encoding: {str(e)}",
                suggestion="Ensure prompt uses valid UTF-8 characters"
            ))
            risk_score += 0.2

        return issues, risk_score

    def _validate_safety(self, prompt: str) -> Tuple[List[ValidationIssue], float]:
        """Validate prompt safety and detect injection attempts."""
        issues = []
        risk_score = 0.0

        # Check for injection patterns
        for category, patterns in self._compiled_patterns.items():
            matches = []
            for pattern in patterns:
                found = pattern.findall(prompt)
                matches.extend(found)

            if matches:
                severity = ValidationSeverity.CRITICAL if category == "system_commands" else ValidationSeverity.WARNING
                risk_increase = 0.4 if category == "system_commands" else 0.2

                issues.append(ValidationIssue(
                    severity=severity,
                    code=f"INJECTION_{category.upper()}",
                    message=f"Detected potential {category.replace('_', ' ')} injection",
                    location=f"Matches: {matches[:3]}",  # Show first 3 matches
                    suggestion="Review prompt for unintended injection patterns"
                ))
                risk_score += risk_increase

        # Check for suspicious keywords
        prompt_lower = prompt.lower()
        for risk_level, keywords in self.SUSPICIOUS_KEYWORDS.items():
            found_keywords = [kw for kw in keywords if kw in prompt_lower]
            if found_keywords:
                if risk_level == "high_risk":
                    severity = ValidationSeverity.ERROR
                    risk_increase = 0.3
                elif risk_level == "medium_risk":
                    severity = ValidationSeverity.WARNING
                    risk_increase = 0.1
                else:  # format_risk
                    severity = ValidationSeverity.WARNING
                    risk_increase = 0.15

                issues.append(ValidationIssue(
                    severity=severity,
                    code=f"SUSPICIOUS_{risk_level.upper()}",
                    message=f"Found {len(found_keywords)} {risk_level.replace('_', ' ')} keywords",
                    location=f"Keywords: {found_keywords[:5]}",
                    suggestion="Review keywords for potential security risks"
                ))
                risk_score += risk_increase

        # Check for HTML/XML content
        if '<' in prompt and '>' in prompt:
            html_tags = re.findall(r'<[^>]+>', prompt)
            if html_tags:
                issues.append(ValidationIssue(
                    severity=ValidationSeverity.WARNING,
                    code="HTML_CONTENT",
                    message=f"Found {len(html_tags)} HTML-like tags",
                    location=f"Tags: {html_tags[:3]}",
                    suggestion="HTML content may indicate injection attempts"
                ))
                risk_score += 0.1

        # Check for excessive special characters
        special_chars = re.findall(r'[^\w\s\.\,\!\?\;\:\-\(\)\[\]\{\}\"\'\/\\]', prompt)
        if len(special_chars) > len(prompt) * 0.1:  # More than 10% special chars
            issues.append(ValidationIssue(
                severity=ValidationSeverity.INFO,
                code="EXCESSIVE_SPECIAL_CHARS",
                message=f"High ratio of special characters ({len(special_chars)}/{len(prompt)})",
                suggestion="Excessive special characters may indicate encoded content"
            ))

        return issues, risk_score

    def _validate_format_compliance(self, prompt: str) -> Tuple[List[ValidationIssue], float]:
        """Validate that the prompt enforces proper output format."""
        issues = []
        risk_score = 0.0
        prompt_lower = prompt.lower()

        # Check for required format elements
        missing_categories = []
        for category, required_words in self.REQUIRED_ELEMENTS.items():
            found_words = [word for word in required_words if word in prompt_lower]
            if not found_words:
                missing_categories.append(category)

        if missing_categories:
            issues.append(ValidationIssue(
                severity=ValidationSeverity.WARNING,
                code="MISSING_FORMAT_ELEMENTS",
                message=f"Missing format elements: {missing_categories}",
                suggestion="Ensure prompt requests proper JSON format and required fields"
            ))
            risk_score += 0.1 * len(missing_categories)

        # Check for format-breaking instructions
        format_breaking_patterns = [
            r"(?i)\b(don't|do\s+not|avoid|skip)\s+(use\s+)?json",
            r"(?i)\b(plain\s+text|free\s+form|unstructured)",
            r"(?i)\b(ignore|skip|avoid)\s+(format|structure)"
        ]

        for pattern in format_breaking_patterns:
            if re.search(pattern, prompt):
                issues.append(ValidationIssue(
                    severity=ValidationSeverity.ERROR,
                    code="FORMAT_BREAKING_INSTRUCTION",
                    message="Prompt contains instructions that may break output format",
                    suggestion="Remove instructions that discourage structured output"
                ))
                risk_score += 0.3
                break

        # Check for JSON format enforcement
        json_indicators = ["json", "format", "structure", "array", "object"]
        found_indicators = sum(1 for indicator in json_indicators if indicator in prompt_lower)

        if found_indicators < 2:
            issues.append(ValidationIssue(
                severity=ValidationSeverity.INFO,
                code="WEAK_FORMAT_ENFORCEMENT",
                message="Prompt may not strongly enforce JSON output format",
                suggestion="Add explicit JSON format requirements"
            ))

        return issues, risk_score

    def _validate_business_context(
        self,
        prompt: str,
        context: Optional[Dict] = None
    ) -> Tuple[List[ValidationIssue], float]:
        """Validate business context and topic extraction focus."""
        issues = []
        risk_score = 0.0
        prompt_lower = prompt.lower()

        # Check for business-relevant keywords
        business_keywords = [
            "business", "topic", "extract", "analyze", "insight", "category",
            "customer", "client", "requirement", "feedback", "solution"
        ]

        found_business_keywords = sum(1 for kw in business_keywords if kw in prompt_lower)
        if found_business_keywords < 3:
            issues.append(ValidationIssue(
                severity=ValidationSeverity.INFO,
                code="LIMITED_BUSINESS_CONTEXT",
                message="Prompt may lack sufficient business context",
                suggestion="Include more business-focused instructions for better results"
            ))

        # Check for topic extraction focus
        topic_keywords = ["topic", "theme", "subject", "category", "segment"]
        found_topic_keywords = sum(1 for kw in topic_keywords if kw in prompt_lower)

        if found_topic_keywords == 0:
            issues.append(ValidationIssue(
                severity=ValidationSeverity.WARNING,
                code="NO_TOPIC_FOCUS",
                message="Prompt doesn't explicitly mention topic extraction",
                suggestion="Add explicit topic extraction instructions"
            ))
            risk_score += 0.1

        # Validate context if provided; compare case-insensitively so a
        # capitalized context value doesn't trigger a false mismatch
        if context:
            language = context.get("language")
            if language and language.lower() not in prompt_lower:
                issues.append(ValidationIssue(
                    severity=ValidationSeverity.INFO,
                    code="LANGUAGE_MISMATCH",
                    message="Prompt language may not match specified context language",
                    suggestion="Ensure prompt language matches context requirements"
                ))

        return issues, risk_score

    def _sanitize_content(self, prompt: str) -> str:
        """Sanitize prompt content while preserving functionality."""
        try:
            # HTML-escape potentially dangerous characters
            sanitized = html.escape(prompt, quote=False)

            # Remove null bytes and control characters (except newlines and tabs)
            sanitized = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]', '', sanitized)

            # Normalize whitespace
            sanitized = re.sub(r'\s+', ' ', sanitized)
            sanitized = sanitized.strip()

            return sanitized

        except Exception as e:
            self.logger.error(f"Error sanitizing content: {str(e)}")
            return prompt  # Return original if sanitization fails

    def validate_template_variables(self, variables: Dict[str, str]) -> ValidationResult:
        """Validate template variables for safety."""
        issues = []
        risk_score = 0.0

        for var_name, var_value in variables.items():
            # Validate variable name
            if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', var_name):
                issues.append(ValidationIssue(
                    severity=ValidationSeverity.WARNING,
                    code="INVALID_VARIABLE_NAME",
                    message=f"Variable name '{var_name}' contains invalid characters",
                    suggestion="Use only alphanumeric characters and underscores"
                ))

            # Validate variable value
            if isinstance(var_value, str):
                var_validation = self.validate_prompt(var_value)
                if var_validation.risk_score > 0.5:
                    issues.append(ValidationIssue(
                        severity=ValidationSeverity.WARNING,
                        code="RISKY_VARIABLE_VALUE",
                        message=f"Variable '{var_name}' contains potentially risky content",
                        suggestion="Review variable content for safety"
                    ))
                    risk_score += 0.1

        return ValidationResult(
            is_valid=not any(
                issue.severity in (ValidationSeverity.ERROR, ValidationSeverity.CRITICAL)
                for issue in issues
            ),
            issues=issues,
            risk_score=min(1.0, risk_score)
        )

    def get_safety_recommendations(self, validation_result: ValidationResult) -> List[str]:
        """Get safety recommendations based on validation results."""
        recommendations = []

        if validation_result.risk_score > 0.7:
            recommendations.append("Consider rewriting the prompt to reduce security risks")

        if validation_result.has_errors():
            recommendations.append("Fix all error-level issues before using the prompt")

        if validation_result.has_warnings():
            recommendations.append("Review and address warning-level issues")

        # Specific recommendations based on issue codes
        issue_codes = {issue.code for issue in validation_result.issues}

        if "INJECTION_ROLE_MANIPULATION" in issue_codes:
            recommendations.append("Remove instructions that attempt to change the AI's role")

        if "FORMAT_BREAKING_INSTRUCTION" in issue_codes:
            recommendations.append("Ensure prompt enforces structured JSON output")

        if "PROMPT_TOO_SHORT" in issue_codes:
            recommendations.append("Provide more detailed instructions for better results")

        if "NO_TOPIC_FOCUS" in issue_codes:
            recommendations.append("Add explicit topic extraction instructions")

        # Add general security recommendation for high-risk prompts
        if any("INJECTION" in code or "SUSPICIOUS" in code for code in issue_codes):
            recommendations.append("Review prompt for potential security risks")

        return recommendations


# Global validator instance
_validator: Optional[PromptValidator] = None


def get_prompt_validator() -> PromptValidator:
    """Get or create global prompt validator instance."""
    global _validator
    if _validator is None:
        _validator = PromptValidator()
    return _validator


def validate_prompt_safety(prompt: str, context: Optional[Dict] = None) -> ValidationResult:
    """Convenience function for prompt safety validation."""
    validator = get_prompt_validator()
    return validator.validate_prompt(prompt, context)


def sanitize_prompt_content(prompt: str) -> str:
    """Convenience function for prompt content sanitization."""
    validator = get_prompt_validator()
    return validator._sanitize_content(prompt)
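

# Illustrative usage sketch, not part of the validation API itself: it exercises the
# public entry points defined above when the module is run directly. The sample prompt
# text is hypothetical, and running this block assumes `config.logging.get_logger`
# is importable in the surrounding project.
if __name__ == "__main__":
    sample_prompt = (
        "Analyze the customer feedback below and extract business topics. "
        "Return a JSON array of objects with topic_name, topic_type, and confidence_score."
    )

    result = validate_prompt_safety(sample_prompt)
    print(f"valid={result.is_valid}, risk_score={result.risk_score:.2f}")
    for issue in result.issues:
        print(f"- [{issue.severity.value}] {issue.code}: {issue.message}")

    for recommendation in get_prompt_validator().get_safety_recommendations(result):
        print(f"* {recommendation}")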