#!/usr/bin/env python3
"""
Guard Rails System for RAG

This module provides comprehensive guard rails for the RAG system to ensure:
- Input validation and sanitization
- Output safety and content filtering
- Model safety and prompt injection protection
- Data privacy and PII detection
- Rate limiting and abuse prevention

## Guard Rail Categories
1. **Input Guards**: Validate and sanitize user inputs
2. **Output Guards**: Filter and validate generated responses
3. **Model Guards**: Protect against prompt injection and jailbreaks
4. **Data Guards**: Detect and handle sensitive information
5. **System Guards**: Rate limiting and resource protection
"""
import hashlib
import logging
import re
import time
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from loguru import logger
| # ============================================================================= | |
| # DATA STRUCTURES | |
| # ============================================================================= | |
@dataclass
class GuardRailResult:
    """
    Result from a guard rail check.

    The ``@dataclass`` decorator is required: the rest of the module
    instantiates this class with keyword arguments
    (``GuardRailResult(passed=..., blocked=..., ...)``), which only works
    with the generated ``__init__``.

    Attributes:
        passed: Whether the check passed
        blocked: Whether the input/output should be blocked
        reason: Reason for blocking or warning
        confidence: Confidence score for the decision
        metadata: Additional information about the check
    """

    passed: bool
    blocked: bool
    reason: str
    confidence: float
    metadata: Dict[str, Any]
@dataclass
class GuardRailConfig:
    """
    Configuration for the guard rail system.

    The ``@dataclass`` decorator is required: ``GuardRailSystem`` creates a
    default instance with ``GuardRailConfig()`` and callers may override
    individual fields by keyword, which relies on the generated ``__init__``.

    Attributes:
        max_query_length: Maximum allowed query length
        max_response_length: Maximum allowed response length
        min_confidence_threshold: Minimum confidence for responses
        rate_limit_requests: Maximum requests per time window
        rate_limit_window: Time window for rate limiting (seconds)
        enable_pii_detection: Whether to detect PII in documents
        enable_content_filtering: Whether to filter harmful content
        enable_prompt_injection_detection: Whether to detect prompt injection
    """

    max_query_length: int = 1000
    max_response_length: int = 5000
    min_confidence_threshold: float = 0.3
    rate_limit_requests: int = 100
    rate_limit_window: int = 3600  # 1 hour
    enable_pii_detection: bool = True
    enable_content_filtering: bool = True
    enable_prompt_injection_detection: bool = True
| # ============================================================================= | |
| # INPUT GUARD RAILS | |
| # ============================================================================= | |
class InputGuards:
    """Guard rails for input validation and sanitization."""

    def __init__(self, config: "GuardRailConfig"):
        """
        Args:
            config: Shared guard rail configuration (limits and feature toggles).
        """
        self.config = config
        # Compile regex patterns once for efficiency.
        # Heuristic prompt-injection indicators: chat role markers,
        # instruction-override phrases, script/JS payloads, and raw URLs.
        self.suspicious_patterns = [
            re.compile(r"system:|assistant:|user:", re.IGNORECASE),
            re.compile(r"ignore previous|forget everything", re.IGNORECASE),
            re.compile(r"you are now|act as|pretend to be", re.IGNORECASE),
            re.compile(r"<script|javascript:|eval\(", re.IGNORECASE),
            re.compile(
                r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
            ),
        ]
        # Coarse keyword filter for harmful content.
        self.harmful_patterns = [
            re.compile(r"\b(hack|crack|exploit|vulnerability)\b", re.IGNORECASE),
            re.compile(r"\b(bomb|weapon|explosive)\b", re.IGNORECASE),
            re.compile(r"\b(drug|illegal|contraband)\b", re.IGNORECASE),
        ]

    def validate_query(self, query: str, user_id: str = "anonymous") -> "GuardRailResult":
        """
        Validate a user query for safety and appropriateness.

        Args:
            query: User's query string.
            user_id: User identifier (unused here; rate limiting is handled
                by SystemGuards — kept for interface compatibility).

        Returns:
            GuardRailResult with the validation outcome.
        """
        # Reject over-long queries outright.
        if len(query) > self.config.max_query_length:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"Query too long ({len(query)} chars, max {self.config.max_query_length})",
                confidence=1.0,
                metadata={"query_length": len(query)},
            )
        # Reject empty or whitespace-only queries.
        if not query.strip():
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason="Empty or whitespace-only query",
                confidence=1.0,
                metadata={},
            )
        # Potential prompt injection: block on the first suspicious pattern.
        if self.config.enable_prompt_injection_detection:
            for pattern in self.suspicious_patterns:
                if pattern.search(query):
                    return GuardRailResult(
                        passed=False,
                        blocked=True,
                        reason="Suspicious pattern detected (potential prompt injection)",
                        confidence=0.8,
                        metadata={"pattern": pattern.pattern},
                    )
        # Harmful content: collect every matching pattern for the metadata.
        if self.config.enable_content_filtering:
            harmful_matches = [
                pattern.pattern
                for pattern in self.harmful_patterns
                if pattern.search(query)
            ]
            if harmful_matches:
                return GuardRailResult(
                    passed=False,
                    blocked=True,
                    reason="Harmful content detected",
                    confidence=0.7,
                    metadata={"harmful_patterns": harmful_matches},
                )
        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Query validated successfully",
            confidence=1.0,
            metadata={},
        )

    def sanitize_query(self, query: str) -> str:
        """
        Sanitize a query by stripping markup and normalizing whitespace.

        Args:
            query: Raw query string.

        Returns:
            Sanitized query string.
        """
        # FIX: remove <script> blocks (tags AND enclosed content) BEFORE
        # stripping generic tags. The previous order stripped the <script>
        # tags with the generic tag regex first, leaving the script body
        # behind in the sanitized query.
        query = re.sub(
            r"<script.*?</script>", "", query, flags=re.IGNORECASE | re.DOTALL
        )
        # Remove any remaining HTML tags.
        query = re.sub(r"<[^>]+>", "", query)
        # Collapse runs of whitespace.
        query = re.sub(r"\s+", " ", query).strip()
        return query
| # ============================================================================= | |
| # OUTPUT GUARD RAILS | |
| # ============================================================================= | |
class OutputGuards:
    """Guard rails for output validation and filtering."""

    def __init__(self, config: "GuardRailConfig"):
        """
        Args:
            config: Shared guard rail configuration.
        """
        self.config = config
        # Phrases indicating an unhelpful / low-information answer.
        self.low_quality_patterns = [
            re.compile(r"\b(i don\'t know|i cannot|i am unable)\b", re.IGNORECASE),
            re.compile(r"\b(no information|not found|not available)\b", re.IGNORECASE),
        ]
        # Source-citation phrases. NOTE(review): currently unused —
        # _detect_hallucination works on plain strings instead; kept for a
        # future, stricter hallucination check.
        self.hallucination_patterns = [
            re.compile(
                r"\b(according to the document|as mentioned in|the document states)\b",
                re.IGNORECASE,
            ),
            re.compile(
                r"\b(based on the provided|in the given|from the text)\b", re.IGNORECASE
            ),
        ]

    def validate_response(
        self, response: str, confidence: float, context: str = ""
    ) -> "GuardRailResult":
        """
        Validate a generated response for safety and quality.

        Only the length check blocks the response; confidence, quality and
        hallucination failures are flagged (passed=False, blocked=False) so
        callers can decide how to handle them.

        Args:
            response: Generated response text.
            confidence: Confidence score from the RAG system.
            context: Retrieved context used for hallucination checking.

        Returns:
            GuardRailResult with the validation outcome.
        """
        # Hard limit: block over-long responses.
        if len(response) > self.config.max_response_length:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"Response too long ({len(response)} chars, max {self.config.max_response_length})",
                confidence=1.0,
                metadata={"response_length": len(response)},
            )
        # Flag (not block) responses below the confidence threshold.
        if confidence < self.config.min_confidence_threshold:
            return GuardRailResult(
                passed=False,
                blocked=False,
                reason=f"Low confidence response ({confidence:.2f} < {self.config.min_confidence_threshold})",
                confidence=confidence,
                metadata={"confidence": confidence},
            )
        # Flag responses matching two or more low-quality phrase groups.
        low_quality_count = 0
        for pattern in self.low_quality_patterns:
            if pattern.search(response):
                low_quality_count += 1
        if low_quality_count >= 2:
            return GuardRailResult(
                passed=False,
                blocked=False,
                reason="Low quality response detected",
                confidence=0.6,
                metadata={"low_quality_indicators": low_quality_count},
            )
        # Flag potential hallucinations (only meaningful with context).
        if context and self._detect_hallucination(response, context):
            return GuardRailResult(
                passed=False,
                blocked=False,
                reason="Potential hallucination detected",
                confidence=0.7,
                metadata={"hallucination_risk": "high"},
            )
        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Response validated successfully",
            confidence=confidence,
            metadata={},
        )

    def _detect_hallucination(self, response: str, context: str) -> bool:
        """
        Detect potential hallucinations in the response.

        NOTE(review): this is a deliberately conservative stub — every code
        path returns False today, so no response is ever flagged. The scan
        below only locates citation phrases; verifying the cited claim
        against the context is left as future work.

        Args:
            response: Generated response.
            context: Retrieved context.

        Returns:
            True if a hallucination is likely detected (currently never).
        """
        response_lower = response.lower()
        context_lower = context.lower()
        # Phrases that claim support from the source material.
        claim_indicators = [
            "the document states",
            "according to the text",
            "as mentioned in",
            "the information shows",
        ]
        for indicator in claim_indicators:
            if indicator in response_lower:
                # A sophisticated check would verify the claimed content
                # actually appears in context_lower; be conservative for now.
                return False
        return False

    def filter_response(self, response: str) -> str:
        """
        Filter a response to remove potentially harmful markup.

        Args:
            response: Raw response string.

        Returns:
            Filtered response string.
        """
        # FIX: remove <script> blocks (tags AND content) BEFORE stripping
        # generic tags; the previous order consumed the <script> tags with
        # the generic regex and left the script body in the output.
        response = re.sub(
            r"<script.*?</script>", "", response, flags=re.IGNORECASE | re.DOTALL
        )
        # Remove any remaining HTML tags.
        response = re.sub(r"<[^>]+>", "", response)
        # Collapse three-or-more newlines down to a blank line.
        response = re.sub(r"\n\s*\n\s*\n+", "\n\n", response)
        return response.strip()
| # ============================================================================= | |
| # DATA GUARD RAILS | |
| # ============================================================================= | |
class DataGuards:
    """Guard rails for data privacy and PII detection."""

    def __init__(self, config: "GuardRailConfig"):
        """
        Args:
            config: Shared guard rail configuration.
        """
        self.config = config
        # Compiled PII patterns, shared by detection AND masking so the two
        # can never drift apart (previously sanitize_pii duplicated these
        # regexes as string literals).
        # FIX: email TLD class was [A-Z|a-z] — the stray '|' is a literal
        # inside a character class and wrongly matched pipe characters.
        self.pii_patterns = {
            "email": re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"),
            "phone": re.compile(r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b"),
            "ssn": re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),
            "credit_card": re.compile(r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b"),
            "ip_address": re.compile(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b"),
        }
        # Replacement token for each PII category, keyed like pii_patterns.
        self._pii_masks = {
            "email": "[EMAIL]",
            "phone": "[PHONE]",
            "ssn": "[SSN]",
            "credit_card": "[CREDIT_CARD]",
            "ip_address": "[IP_ADDRESS]",
        }

    def detect_pii(self, text: str) -> "GuardRailResult":
        """
        Detect personally identifiable information in text.

        Args:
            text: Text to analyze for PII.

        Returns:
            GuardRailResult: blocked with per-category match counts when any
            PII is found; passing otherwise (or when detection is disabled).
        """
        if not self.config.enable_pii_detection:
            return GuardRailResult(
                passed=True,
                blocked=False,
                reason="PII detection disabled",
                confidence=1.0,
                metadata={},
            )
        # Count matches per PII category.
        detected_pii = {}
        for pii_type, pattern in self.pii_patterns.items():
            matches = pattern.findall(text)
            if matches:
                detected_pii[pii_type] = len(matches)
        if detected_pii:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"PII detected: {', '.join(detected_pii.keys())}",
                confidence=0.9,
                metadata={"detected_pii": detected_pii},
            )
        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="No PII detected",
            confidence=1.0,
            metadata={},
        )

    def sanitize_pii(self, text: str) -> str:
        """
        Sanitize text by masking every known PII category.

        Args:
            text: Text containing potential PII.

        Returns:
            Sanitized text with each PII match replaced by its mask token.
        """
        # Iterate in insertion order (email, phone, ssn, credit_card,
        # ip_address) — the same order as the previous hand-written
        # substitutions, which matters when patterns could overlap.
        for pii_type, pattern in self.pii_patterns.items():
            text = pattern.sub(self._pii_masks[pii_type], text)
        return text
| # ============================================================================= | |
| # SYSTEM GUARD RAILS | |
| # ============================================================================= | |
class SystemGuards:
    """Guard rails for system-level protection (rate limiting, resources)."""

    def __init__(self, config: GuardRailConfig):
        """
        Args:
            config: Shared guard rail configuration.
        """
        self.config = config
        # Per-user deque of request timestamps; maxlen bounds memory per user.
        self.request_history = defaultdict(lambda: deque(maxlen=1000))
        # Users currently blocked for exceeding the rate limit.
        self.blocked_users = set()

    def check_rate_limit(self, user_id: str) -> GuardRailResult:
        """
        Check whether the user has exceeded the request rate limit.

        A block is temporary: it is lifted automatically once the user's
        recorded requests drain back under the limit. (FIX: the previous
        implementation documented the block as temporary but never removed
        anyone from blocked_users, making every block permanent.)

        Args:
            user_id: User identifier.

        Returns:
            GuardRailResult with the rate limit check outcome.
        """
        current_time = time.time()
        user_requests = self.request_history[user_id]
        # Evict timestamps that fell outside the rate-limit window.
        while (
            user_requests
            and current_time - user_requests[0] > self.config.rate_limit_window
        ):
            user_requests.popleft()
        # Blocked users: lift the block once the window has drained,
        # otherwise keep rejecting.
        if user_id in self.blocked_users:
            if len(user_requests) < self.config.rate_limit_requests:
                self.blocked_users.discard(user_id)
            else:
                return GuardRailResult(
                    passed=False,
                    blocked=True,
                    reason="User is blocked due to previous violations",
                    confidence=1.0,
                    metadata={"user_id": user_id},
                )
        # Over the limit: block temporarily and reject this request.
        if len(user_requests) >= self.config.rate_limit_requests:
            self.blocked_users.add(user_id)
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"Rate limit exceeded ({len(user_requests)} requests in {self.config.rate_limit_window}s)",
                confidence=1.0,
                metadata={"requests": len(user_requests)},
            )
        # Record this request and allow it.
        user_requests.append(current_time)
        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Rate limit check passed",
            confidence=1.0,
            metadata={"requests": len(user_requests)},
        )

    def check_resource_usage(
        self, memory_usage: float, cpu_usage: float
    ) -> GuardRailResult:
        """
        Check system resource usage against fixed thresholds.

        Args:
            memory_usage: Current memory usage percentage (0-100).
            cpu_usage: Current CPU usage percentage (0-100).

        Returns:
            GuardRailResult: blocked when either threshold is exceeded.
        """
        memory_threshold = 90.0  # block above 90% memory usage
        cpu_threshold = 95.0  # block above 95% CPU usage
        if memory_usage > memory_threshold:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"High memory usage ({memory_usage:.1f}%)",
                confidence=1.0,
                metadata={"memory_usage": memory_usage},
            )
        if cpu_usage > cpu_threshold:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"High CPU usage ({cpu_usage:.1f}%)",
                confidence=1.0,
                metadata={"cpu_usage": cpu_usage},
            )
        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Resource usage acceptable",
            confidence=1.0,
            metadata={"memory_usage": memory_usage, "cpu_usage": cpu_usage},
        )
| # ============================================================================= | |
| # MAIN GUARD RAIL SYSTEM | |
| # ============================================================================= | |
class GuardRailSystem:
    """
    Comprehensive guard rail system for RAG.

    Orchestrates the input, output, data, and system guard components to
    ensure safe and reliable operation of the RAG system.
    """

    def __init__(self, config: GuardRailConfig = None):
        self.config = config if config is not None else GuardRailConfig()
        # Every component shares the same configuration object.
        self.input_guards = InputGuards(self.config)
        self.output_guards = OutputGuards(self.config)
        self.data_guards = DataGuards(self.config)
        self.system_guards = SystemGuards(self.config)
        logger.info("Guard rail system initialized successfully")

    def validate_input(self, query: str, user_id: str = "anonymous") -> GuardRailResult:
        """
        Run every input-side check in order: rate limit, query validation,
        PII scan.

        Args:
            query: User query.
            user_id: User identifier.

        Returns:
            The first failing GuardRailResult, or a passing one.
        """
        input_checks = (
            lambda: self.system_guards.check_rate_limit(user_id),
            lambda: self.input_guards.validate_query(query, user_id),
            lambda: self.data_guards.detect_pii(query),
        )
        for run_check in input_checks:
            outcome = run_check()
            if not outcome.passed:
                return outcome
        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Input validation passed",
            confidence=1.0,
            metadata={},
        )

    def validate_output(
        self, response: str, confidence: float, context: str = ""
    ) -> GuardRailResult:
        """
        Run every output-side check in order: response validation, PII scan.

        Args:
            response: Generated response.
            confidence: Confidence score.
            context: Retrieved context.

        Returns:
            The first failing GuardRailResult, or a passing one.
        """
        output_checks = (
            lambda: self.output_guards.validate_response(response, confidence, context),
            lambda: self.data_guards.detect_pii(response),
        )
        for run_check in output_checks:
            outcome = run_check()
            if not outcome.passed:
                return outcome
        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Output validation passed",
            confidence=confidence,
            metadata={},
        )

    def sanitize_input(self, query: str) -> str:
        """Delegate input sanitization to the input guards."""
        return self.input_guards.sanitize_query(query)

    def sanitize_output(self, response: str) -> str:
        """Delegate output filtering to the output guards."""
        return self.output_guards.filter_response(response)

    def sanitize_data(self, text: str) -> str:
        """Delegate PII masking to the data guards."""
        return self.data_guards.sanitize_pii(text)

    def get_system_status(self) -> Dict[str, Any]:
        """
        Return current system status and a configuration snapshot.

        Returns:
            Dictionary with user counts and the active configuration values.
        """
        cfg = self.config
        config_snapshot = {
            "max_query_length": cfg.max_query_length,
            "max_response_length": cfg.max_response_length,
            "min_confidence_threshold": cfg.min_confidence_threshold,
            "rate_limit_requests": cfg.rate_limit_requests,
            "rate_limit_window": cfg.rate_limit_window,
        }
        return {
            "total_users": len(self.system_guards.request_history),
            "blocked_users": len(self.system_guards.blocked_users),
            "config": config_snapshot,
        }