#!/usr/bin/env python3
"""
# Guard Rails System for RAG

This module provides comprehensive guard rails for the RAG system to ensure:
- Input validation and sanitization
- Output safety and content filtering
- Model safety and prompt injection protection
- Data privacy and PII detection
- Rate limiting and abuse prevention

## Guard Rail Categories

1. **Input Guards**: Validate and sanitize user inputs
2. **Output Guards**: Filter and validate generated responses
3. **Model Guards**: Protect against prompt injection and jailbreaks
4. **Data Guards**: Detect and handle sensitive information
5. **System Guards**: Rate limiting and resource protection
"""

import re
import time
import hashlib  # NOTE(review): unused here; kept in case other modules import it from this namespace
from typing import List, Dict, Optional, Tuple, Any
from dataclasses import dataclass
from collections import defaultdict, deque
import logging

# Prefer loguru (the original dependency) but fall back to the already-imported
# stdlib logging module so this file remains importable without third-party packages.
try:
    from loguru import logger
except ImportError:
    logger = logging.getLogger(__name__)


# =============================================================================
# DATA STRUCTURES
# =============================================================================


@dataclass
class GuardRailResult:
    """
    Result from a guard rail check

    Attributes:
        passed: Whether the check passed
        blocked: Whether the input/output should be blocked
        reason: Reason for blocking or warning
        confidence: Confidence score for the decision
        metadata: Additional information about the check
    """

    passed: bool
    blocked: bool
    reason: str
    confidence: float
    metadata: Dict[str, Any]


@dataclass
class GuardRailConfig:
    """
    Configuration for guard rail system

    Attributes:
        max_query_length: Maximum allowed query length
        max_response_length: Maximum allowed response length
        min_confidence_threshold: Minimum confidence for responses
        rate_limit_requests: Maximum requests per time window
        rate_limit_window: Time window for rate limiting (seconds)
        enable_pii_detection: Whether to detect PII in documents
        enable_content_filtering: Whether to filter harmful content
        enable_prompt_injection_detection: Whether to detect prompt injection
    """

    max_query_length: int = 1000
    max_response_length: int = 5000
    min_confidence_threshold: float = 0.3
    rate_limit_requests: int = 100
    rate_limit_window: int = 3600  # 1 hour
    enable_pii_detection: bool = True
    enable_content_filtering: bool = True
    enable_prompt_injection_detection: bool = True


# =============================================================================
# INPUT GUARD RAILS
# =============================================================================


class InputGuards:
    """Guard rails for input validation and sanitization"""

    def __init__(self, config: GuardRailConfig):
        self.config = config

        # Compile regex patterns once for efficiency (reused on every query).
        self.suspicious_patterns = [
            re.compile(r"system:|assistant:|user:", re.IGNORECASE),
            re.compile(r"ignore previous|forget everything", re.IGNORECASE),
            re.compile(r"you are now|act as|pretend to be", re.IGNORECASE),
            # NOTE(review): the original 4th pattern was lost when angle-bracketed
            # text was stripped from this file; reconstructed as a chat-template
            # special-token detector — confirm against the original source.
            re.compile(r"<\|?\s*(system|im_start|im_end)\s*\|?>", re.IGNORECASE),
        ]

        # NOTE(review): the original harmful-content pattern list was also lost
        # in the same mangling; these are reconstructed placeholders — confirm
        # against the original source before relying on them.
        self.harmful_patterns = [
            re.compile(
                r"\b(how to (make|build) (a )?(bomb|weapon|explosive))\b",
                re.IGNORECASE,
            ),
            re.compile(r"\b(kill|harm|hurt) (someone|people|myself)\b", re.IGNORECASE),
        ]

    def validate_query(self, query: str, user_id: str = "anonymous") -> GuardRailResult:
        """
        Validate user query for safety and appropriateness

        Args:
            query: User's query string
            user_id: User identifier for rate limiting

        Returns:
            GuardRailResult with validation outcome
        """
        # Check query length
        if len(query) > self.config.max_query_length:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"Query too long ({len(query)} chars, max {self.config.max_query_length})",
                confidence=1.0,
                metadata={"query_length": len(query)},
            )

        # Check for empty or whitespace-only queries
        if not query.strip():
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason="Empty or whitespace-only query",
                confidence=1.0,
                metadata={},
            )

        # Check for suspicious patterns (potential prompt injection)
        if self.config.enable_prompt_injection_detection:
            for pattern in self.suspicious_patterns:
                if pattern.search(query):
                    return GuardRailResult(
                        passed=False,
                        blocked=True,
                        reason="Suspicious pattern detected (potential prompt injection)",
                        confidence=0.8,
                        metadata={"pattern": pattern.pattern},
                    )

        # Check for harmful content
        if self.config.enable_content_filtering:
            harmful_matches = []
            for pattern in self.harmful_patterns:
                if pattern.search(query):
                    harmful_matches.append(pattern.pattern)

            if harmful_matches:
                return GuardRailResult(
                    passed=False,
                    blocked=True,
                    reason="Harmful content detected",
                    confidence=0.7,
                    metadata={"harmful_patterns": harmful_matches},
                )

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Query validated successfully",
            confidence=1.0,
            metadata={},
        )

    def sanitize_query(self, query: str) -> str:
        """
        Sanitize query to remove potentially harmful content

        Args:
            query: Raw query string

        Returns:
            Sanitized query string
        """
        # Remove script tags and their content first, so the generic tag
        # stripper below cannot leave script bodies behind.
        # NOTE(review): this regex was stripped from the mangled original and
        # has been reconstructed — confirm against the original source.
        query = re.sub(
            r"<script[^>]*>.*?</script>", "", query, flags=re.IGNORECASE | re.DOTALL
        )

        # Remove remaining HTML tags
        query = re.sub(r"<[^>]+>", "", query)

        # Remove excessive whitespace
        query = re.sub(r"\s+", " ", query).strip()

        return query


# =============================================================================
# OUTPUT GUARD RAILS
# =============================================================================


class OutputGuards:
    """Guard rails for output validation and filtering"""

    def __init__(self, config: GuardRailConfig):
        self.config = config

        # Response quality patterns: phrases indicating a non-answer.
        self.low_quality_patterns = [
            re.compile(r"\b(i don\'t know|i cannot|i am unable)\b", re.IGNORECASE),
            re.compile(r"\b(no information|not found|not available)\b", re.IGNORECASE),
        ]

        # Hallucination indicators: phrases that claim grounding in the context.
        self.hallucination_patterns = [
            re.compile(
                r"\b(according to the document|as mentioned in|the document states)\b",
                re.IGNORECASE,
            ),
            re.compile(
                r"\b(based on the provided|in the given|from the text)\b", re.IGNORECASE
            ),
        ]

    def validate_response(
        self, response: str, confidence: float, context: str = ""
    ) -> GuardRailResult:
        """
        Validate generated response for safety and quality

        Args:
            response: Generated response text
            confidence: Confidence score from RAG system
            context: Retrieved context for validation

        Returns:
            GuardRailResult with validation outcome
        """
        # Check response length
        if len(response) > self.config.max_response_length:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"Response too long ({len(response)} chars, max {self.config.max_response_length})",
                confidence=1.0,
                metadata={"response_length": len(response)},
            )

        # Check confidence threshold (warn, don't block: blocked=False)
        if confidence < self.config.min_confidence_threshold:
            return GuardRailResult(
                passed=False,
                blocked=False,
                reason=f"Low confidence response ({confidence:.2f} < {self.config.min_confidence_threshold})",
                confidence=confidence,
                metadata={"confidence": confidence},
            )

        # Check for low quality responses: flag only when two or more
        # non-answer phrases appear (a single one may be legitimate).
        low_quality_count = 0
        for pattern in self.low_quality_patterns:
            if pattern.search(response):
                low_quality_count += 1

        if low_quality_count >= 2:
            return GuardRailResult(
                passed=False,
                blocked=False,
                reason="Low quality response detected",
                confidence=0.6,
                metadata={"low_quality_indicators": low_quality_count},
            )

        # Check for potential hallucinations
        if context and self._detect_hallucination(response, context):
            return GuardRailResult(
                passed=False,
                blocked=False,
                reason="Potential hallucination detected",
                confidence=0.7,
                metadata={"hallucination_risk": "high"},
            )

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Response validated successfully",
            confidence=confidence,
            metadata={},
        )

    def _detect_hallucination(self, response: str, context: str) -> bool:
        """
        Detect potential hallucinations in response

        NOTE: this is a placeholder — it currently ALWAYS returns False
        (conservative: never flags), so the hallucination branch in
        validate_response can never trigger. A real implementation would
        verify claims against the context.

        Args:
            response: Generated response
            context: Retrieved context

        Returns:
            True if hallucination is likely detected
        """
        # Simple heuristic: check if response contains specific claims not in context
        response_lower = response.lower()
        context_lower = context.lower()

        # Check for specific claims that should be in context
        claim_indicators = [
            "the document states",
            "according to the text",
            "as mentioned in",
            "the information shows",
        ]

        for indicator in claim_indicators:
            if indicator in response_lower:
                # Check if the surrounding text is actually in context
                # This is a simplified check - more sophisticated methods would be needed
                return False

        # For now, we'll be conservative
        return False

    def filter_response(self, response: str) -> str:
        """
        Filter response to remove potentially harmful content

        Args:
            response: Raw response string

        Returns:
            Filtered response string
        """
        # Remove script tags and their content before stripping generic tags.
        # NOTE(review): this regex was stripped from the mangled original and
        # has been reconstructed — confirm against the original source.
        response = re.sub(
            r"<script[^>]*>.*?</script>",
            "",
            response,
            flags=re.IGNORECASE | re.DOTALL,
        )

        # Remove remaining HTML tags
        response = re.sub(r"<[^>]+>", "", response)

        # Collapse runs of 3+ newlines down to a single blank line
        response = re.sub(r"\n\s*\n\s*\n+", "\n\n", response)

        return response.strip()


# =============================================================================
# DATA GUARD RAILS
# =============================================================================


class DataGuards:
    """Guard rails for data privacy and PII detection"""

    def __init__(self, config: GuardRailConfig):
        self.config = config

        # PII patterns (US-centric formats for phone/SSN; confirm locale needs)
        self.pii_patterns = {
            "email": re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"),
            "phone": re.compile(r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b"),
            "ssn": re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),
            "credit_card": re.compile(r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b"),
            "ip_address": re.compile(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b"),
        }

    def detect_pii(self, text: str) -> GuardRailResult:
        """
        Detect personally identifiable information in text

        Args:
            text: Text to analyze for PII

        Returns:
            GuardRailResult with PII detection outcome
        """
        if not self.config.enable_pii_detection:
            return GuardRailResult(
                passed=True,
                blocked=False,
                reason="PII detection disabled",
                confidence=1.0,
                metadata={},
            )

        detected_pii = {}
        for pii_type, pattern in self.pii_patterns.items():
            matches = pattern.findall(text)
            if matches:
                detected_pii[pii_type] = len(matches)

        if detected_pii:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"PII detected: {', '.join(detected_pii.keys())}",
                confidence=0.9,
                metadata={"detected_pii": detected_pii},
            )

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="No PII detected",
            confidence=1.0,
            metadata={},
        )

    def sanitize_pii(self, text: str) -> str:
        """
        Sanitize text by removing or masking PII

        Args:
            text: Text containing potential PII

        Returns:
            Sanitized text with PII masked
        """
        # Mask email addresses
        text = re.sub(
            r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", "[EMAIL]", text
        )

        # Mask phone numbers
        text = re.sub(r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b", "[PHONE]", text)

        # Mask SSN
        text = re.sub(r"\b\d{3}-\d{2}-\d{4}\b", "[SSN]", text)

        # Mask credit card numbers
        text = re.sub(r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b", "[CREDIT_CARD]", text)

        # Mask IP addresses
        text = re.sub(r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b", "[IP_ADDRESS]", text)

        return text


# =============================================================================
# SYSTEM GUARD RAILS
# =============================================================================


class SystemGuards:
    """Guard rails for system-level protection"""

    def __init__(self, config: GuardRailConfig):
        self.config = config
        # Per-user timestamp history; deque bounds memory at 1000 entries/user.
        self.request_history = defaultdict(lambda: deque(maxlen=1000))
        self.blocked_users = set()

    def check_rate_limit(self, user_id: str) -> GuardRailResult:
        """
        Check if user has exceeded rate limits

        Args:
            user_id: User identifier

        Returns:
            GuardRailResult with rate limit check outcome
        """
        current_time = time.time()
        user_requests = self.request_history[user_id]

        # Remove old requests outside the window
        while (
            user_requests
            and current_time - user_requests[0] > self.config.rate_limit_window
        ):
            user_requests.popleft()

        # Check if user is blocked
        if user_id in self.blocked_users:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason="User is blocked due to previous violations",
                confidence=1.0,
                metadata={"user_id": user_id},
            )

        # Check rate limit
        if len(user_requests) >= self.config.rate_limit_requests:
            # Block the user. NOTE(review): blocked_users is never cleared, so
            # this block is effectively permanent for the process lifetime —
            # an expiry mechanism is needed if "temporary" blocking is intended.
            self.blocked_users.add(user_id)
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"Rate limit exceeded ({len(user_requests)} requests in {self.config.rate_limit_window}s)",
                confidence=1.0,
                metadata={"requests": len(user_requests)},
            )

        # Add current request
        user_requests.append(current_time)

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Rate limit check passed",
            confidence=1.0,
            metadata={"requests": len(user_requests)},
        )

    def check_resource_usage(
        self, memory_usage: float, cpu_usage: float
    ) -> GuardRailResult:
        """
        Check system resource usage

        Args:
            memory_usage: Current memory usage percentage
            cpu_usage: Current CPU usage percentage

        Returns:
            GuardRailResult with resource check outcome
        """
        # Define thresholds
        memory_threshold = 90.0  # 90% memory usage
        cpu_threshold = 95.0  # 95% CPU usage

        if memory_usage > memory_threshold:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"High memory usage ({memory_usage:.1f}%)",
                confidence=1.0,
                metadata={"memory_usage": memory_usage},
            )

        if cpu_usage > cpu_threshold:
            return GuardRailResult(
                passed=False,
                blocked=True,
                reason=f"High CPU usage ({cpu_usage:.1f}%)",
                confidence=1.0,
                metadata={"cpu_usage": cpu_usage},
            )

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Resource usage acceptable",
            confidence=1.0,
            metadata={"memory_usage": memory_usage, "cpu_usage": cpu_usage},
        )


# =============================================================================
# MAIN GUARD RAIL SYSTEM
# =============================================================================


class GuardRailSystem:
    """
    Comprehensive guard rail system for RAG

    This class orchestrates all guard rail components to ensure
    safe and reliable operation of the RAG system.
    """

    def __init__(self, config: Optional[GuardRailConfig] = None):
        self.config = config or GuardRailConfig()

        # Initialize all guard rail components
        self.input_guards = InputGuards(self.config)
        self.output_guards = OutputGuards(self.config)
        self.data_guards = DataGuards(self.config)
        self.system_guards = SystemGuards(self.config)

        logger.info("Guard rail system initialized successfully")

    def validate_input(self, query: str, user_id: str = "anonymous") -> GuardRailResult:
        """
        Comprehensive input validation

        Args:
            query: User query
            user_id: User identifier

        Returns:
            GuardRailResult with validation outcome
        """
        # Check rate limits first (cheapest check; also records the request)
        rate_limit_result = self.system_guards.check_rate_limit(user_id)
        if not rate_limit_result.passed:
            return rate_limit_result

        # Validate query
        query_result = self.input_guards.validate_query(query, user_id)
        if not query_result.passed:
            return query_result

        # Check for PII in query
        pii_result = self.data_guards.detect_pii(query)
        if not pii_result.passed:
            return pii_result

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Input validation passed",
            confidence=1.0,
            metadata={},
        )

    def validate_output(
        self, response: str, confidence: float, context: str = ""
    ) -> GuardRailResult:
        """
        Comprehensive output validation

        Args:
            response: Generated response
            confidence: Confidence score
            context: Retrieved context

        Returns:
            GuardRailResult with validation outcome
        """
        # Validate response
        response_result = self.output_guards.validate_response(
            response, confidence, context
        )
        if not response_result.passed:
            return response_result

        # Check for PII in response
        pii_result = self.data_guards.detect_pii(response)
        if not pii_result.passed:
            return pii_result

        return GuardRailResult(
            passed=True,
            blocked=False,
            reason="Output validation passed",
            confidence=confidence,
            metadata={},
        )

    def sanitize_input(self, query: str) -> str:
        """Sanitize user input"""
        return self.input_guards.sanitize_query(query)

    def sanitize_output(self, response: str) -> str:
        """Sanitize generated output"""
        return self.output_guards.filter_response(response)

    def sanitize_data(self, text: str) -> str:
        """Sanitize data by removing PII"""
        return self.data_guards.sanitize_pii(text)

    def get_system_status(self) -> Dict[str, Any]:
        """
        Get current system status and statistics

        Returns:
            Dictionary with system status information
        """
        return {
            "total_users": len(self.system_guards.request_history),
            "blocked_users": len(self.system_guards.blocked_users),
            "config": {
                "max_query_length": self.config.max_query_length,
                "max_response_length": self.config.max_response_length,
                "min_confidence_threshold": self.config.min_confidence_threshold,
                "rate_limit_requests": self.config.rate_limit_requests,
                "rate_limit_window": self.config.rate_limit_window,
            },
        }