import hashlib
import logging
import math
import re
import time
from collections import Counter
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple


class AdvancedSecuritySystem:
    """
    Advanced security system for input validation, rate limiting, and threat detection.
    Protects the AI system from abuse and malicious inputs.

    All state is held in memory on the instance:

    * ``rate_limits`` maps ``"user_<id>"`` / ``"ip_<addr>"`` keys to the
      timestamps of requests accepted within the sliding one-minute window.
    * ``suspicious_ips`` maps the same style of keys to the number of times the
      user/IP was flagged via :meth:`mark_suspicious` (despite its name it holds
      user keys as well as IP keys).
    * ``security_log`` holds at most the 1000 most recent audit events.
    """

    def __init__(self):
        self.rate_limits = {}
        self.suspicious_ips = {}
        self.security_log = []

        # Suspicious patterns for input validation; each match adds 20 to the risk score.
        self.suspicious_patterns = [
            # SQL Injection patterns
            r"(?i)(union.*select|select.*from|insert.*into|delete.*from|drop.*table)",
            r"(?i)(or.*1=1|and.*1=1|exec.*\(|xp_cmdshell)",
            r"(\b)(DROP|DELETE|INSERT|UPDATE|ALTER)(\b)",
            # XSS patterns
            r"(?i)(script|javascript|onload|onerror|onclick|alert\(|document\.cookie)",
            r"<.*>.*",  # HTML tags
            # Command injection
            r"[;&|`]\s*\w+",
            r"\$\(.*\)",
            # Path traversal
            r"\.\./|\.\.\\",
            # Sensitive data patterns
            r"(?i)(password|token|key|secret|auth|credential)",
            r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b",  # IP addresses
            # Excessive length or repetition
            r".{10000,}",  # Very long inputs
            r"(.)\1{50,}",  # Repeated characters
            # Admin/privilege patterns
            r"(?i)(admin|root|sudo|su -|chmod|chown)",
        ]

        # Rate limiting configuration.
        # NOTE(review): "burst_capacity" is not read anywhere in this class; only
        # "requests_per_minute" is enforced.
        self.rate_limit_config = {
            "default": {"requests_per_minute": 60, "burst_capacity": 10},
            "anonymous": {"requests_per_minute": 30, "burst_capacity": 5},
            "suspicious": {"requests_per_minute": 10, "burst_capacity": 2},
        }

        self.setup_logging()

    def setup_logging(self):
        """Set up the security logger (module-named logger at INFO level)."""
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

    def check_request(self, query: str, user_id: str, ip_address: Optional[str] = None) -> Dict[str, Any]:
        """
        Comprehensive security check for incoming requests.

        Args:
            query: User's query text
            user_id: User identifier
            ip_address: Optional IP address for IP-based checks

        Returns:
            Security assessment dict with keys ``is_suspicious``, ``alerts``,
            ``risk_score``, ``allowed`` and ``rate_limit_info``. A request is
            marked suspicious at risk >= 50 and rejected at risk >= 80 (or
            immediately when the rate limit is exceeded, with risk 100).
        """
        result = {
            "is_suspicious": False,
            "alerts": [],
            "risk_score": 0,
            "allowed": True,
            "rate_limit_info": {},
        }

        # Rate limiting check
        rate_limit_result = self.check_rate_limit(user_id, ip_address)
        result["rate_limit_info"] = rate_limit_result
        if not rate_limit_result["allowed"]:
            result["is_suspicious"] = True
            result["allowed"] = False
            result["alerts"].append("Rate limit exceeded")
            result["risk_score"] = 100
            # Audit the rejection as well: without this, rate-limited requests never
            # reach security_log and get_security_stats() under-reports blocks.
            self.log_security_event(user_id, ip_address, query, result)
            return result

        # Input validation and pattern matching
        validation_result = self.validate_input(query, user_id)
        result["alerts"].extend(validation_result["alerts"])
        result["risk_score"] += validation_result["risk_score"]

        # IP reputation check (if IP provided)
        if ip_address:
            ip_result = self.check_ip_reputation(ip_address)
            result["alerts"].extend(ip_result["alerts"])
            result["risk_score"] += ip_result["risk_score"]

        # Determine overall suspicion
        if result["risk_score"] >= 50:
            result["is_suspicious"] = True
        if result["risk_score"] >= 80:
            result["allowed"] = False

        # Log security event
        self.log_security_event(user_id, ip_address, query, result)

        return result

    def check_rate_limit(self, user_id: str, ip_address: Optional[str] = None) -> Dict[str, Any]:
        """
        Check (and, if allowed, record) a request against the per-minute limits.

        The user and the IP (when given) each get a sliding one-minute window;
        both are checked against the same limit, chosen from the "default",
        "anonymous" (user_id == "anonymous") or "suspicious" configuration.

        Returns:
            Dict with ``allowed``, ``user_requests`` / ``ip_requests`` (requests
            already in the window *before* this one), ``user_limit``,
            ``ip_limit`` ("N/A" without an IP) and ``retry_after`` seconds.
        """
        current_time = datetime.now()
        window = timedelta(minutes=1)
        user_key = f"user_{user_id}"
        ip_key = f"ip_{ip_address}" if ip_address else None

        # Get rate limit configuration
        user_config = self.rate_limit_config.get("default")
        if user_id == "anonymous":
            user_config = self.rate_limit_config.get("anonymous", user_config)

        # Check if user is marked as suspicious
        if self.is_suspicious_user(user_id) or (ip_address and self.is_suspicious_ip(ip_address)):
            user_config = self.rate_limit_config.get("suspicious", user_config)
        limit = user_config["requests_per_minute"]

        # Slide the window: drop timestamps older than one minute.
        # NOTE(review): keys are never removed, so entries for clients that stop
        # sending requests linger with stale timestamps; consider a periodic sweep.
        self.rate_limits[user_key] = [
            t for t in self.rate_limits.get(user_key, []) if current_time - t < window
        ]
        user_requests = len(self.rate_limits[user_key])
        user_allowed = user_requests < limit

        # Check IP rate limit (if IP provided)
        ip_requests = 0
        ip_allowed = True
        if ip_key:
            self.rate_limits[ip_key] = [
                t for t in self.rate_limits.get(ip_key, []) if current_time - t < window
            ]
            ip_requests = len(self.rate_limits[ip_key])
            ip_allowed = ip_requests < limit

        allowed = user_allowed and ip_allowed

        # Add current request to counters only if allowed, so rejected requests
        # do not extend the lockout.
        if allowed:
            self.rate_limits[user_key].append(current_time)
            if ip_key:
                self.rate_limits[ip_key].append(current_time)

        return {
            "allowed": allowed,
            "user_requests": user_requests,
            "user_limit": limit,
            # Counted before appending, consistent with user_requests.
            "ip_requests": ip_requests,
            "ip_limit": limit if ip_key else "N/A",
            "retry_after": 60 if not allowed else 0,
        }

    def validate_input(self, query: str, user_id: str) -> Dict[str, Any]:
        """
        Score a query against heuristic pattern, length, special-character and
        entropy checks.

        Args:
            query: User's query text
            user_id: User identifier (currently unused)

        Returns:
            Dict with ``alerts`` (list of messages) and ``risk_score`` (int).
        """
        result = {
            "alerts": [],
            "risk_score": 0,
        }

        # Pattern matching: only the presence of a match matters, so re.search
        # (which stops at the first hit) is enough.
        # NOTE(review): several patterns contain unanchored ".*" and can take
        # time quadratic in len(query) on adversarial input that almost matches;
        # consider capping the scanned length.
        for pattern in self.suspicious_patterns:
            if re.search(pattern, query):
                alert_msg = f"Suspicious pattern detected: {pattern[:50]}..."
                result["alerts"].append(alert_msg)
                result["risk_score"] += 20

        # Query length analysis
        query_length = len(query)
        if query_length > 10000:
            result["alerts"].append("Excessively long query detected")
            result["risk_score"] += 30
        elif query_length > 5000:
            result["alerts"].append("Very long query detected")
            result["risk_score"] += 15

        # Special character analysis (anything other than word chars, whitespace, . ? !)
        special_chars = len(re.findall(r'[^\w\s\.\?\!]', query))
        special_char_ratio = special_chars / max(query_length, 1)

        if special_char_ratio > 0.3:
            result["alerts"].append("High percentage of special characters")
            result["risk_score"] += 25
        elif special_char_ratio > 0.2:
            result["alerts"].append("Elevated special character usage")
            result["risk_score"] += 10

        # Entropy analysis (for encrypted/encoded content). Entropy is at most
        # log2(distinct characters), so > 6.0 needs more than 64 distinct symbols.
        entropy = self.calculate_entropy(query)
        if entropy > 6.0:
            result["alerts"].append("High entropy content detected")
            result["risk_score"] += 20

        return result

    def check_ip_reputation(self, ip_address: str) -> Dict[str, Any]:
        """
        Basic IP reputation check.

        Adds risk for IPs flagged via mark_suspicious, for local addresses, and
        for more than 50 requests currently in the IP's rate-limit window.
        """
        result = {
            "alerts": [],
            "risk_score": 0,
        }

        # Check if IP is in suspicious list
        if self.is_suspicious_ip(ip_address):
            result["alerts"].append("IP address has suspicious history")
            result["risk_score"] += 40

        # Simple local-address check
        if ip_address in ["127.0.0.1", "localhost", "0.0.0.0"]:
            result["alerts"].append("Local IP address detected")
            result["risk_score"] += 10

        # Check for rapid requests from this IP (window maintained by check_rate_limit)
        ip_key = f"ip_{ip_address}"
        recent_requests = len(self.rate_limits.get(ip_key, []))
        if recent_requests > 50:  # High volume from single IP
            result["alerts"].append("High request volume from IP")
            result["risk_score"] += 15

        return result

    def calculate_entropy(self, text: str) -> float:
        """Return the Shannon entropy of *text* in bits per character (0.0 for empty text)."""
        if not text:
            return 0.0

        text_length = len(text)
        entropy = 0.0
        # One counting pass instead of a text.count() scan per distinct character.
        for count in Counter(text).values():
            p_x = count / text_length
            entropy -= p_x * math.log2(p_x)

        return entropy

    def is_suspicious_user(self, user_id: str) -> bool:
        """Return True if the user has been flagged more than 5 times."""
        # In a real implementation, this would check a database.
        # For now, use simple in-memory tracking (shared with IP counters).
        user_key = f"user_{user_id}"
        return self.suspicious_ips.get(user_key, 0) > 5

    def is_suspicious_ip(self, ip_address: str) -> bool:
        """Return True if the IP has been flagged more than 3 times."""
        ip_key = f"ip_{ip_address}"
        return self.suspicious_ips.get(ip_key, 0) > 3

    def mark_suspicious(self, user_id: str, ip_address: Optional[str] = None, reason: str = ""):
        """Increment the suspicion counter for the user and/or IP and log a warning."""
        if user_id:
            user_key = f"user_{user_id}"
            self.suspicious_ips[user_key] = self.suspicious_ips.get(user_key, 0) + 1

        if ip_address:
            ip_key = f"ip_{ip_address}"
            self.suspicious_ips[ip_key] = self.suspicious_ips.get(ip_key, 0) + 1

        self.logger.warning(f"Marked as suspicious - User: {user_id}, IP: {ip_address}, Reason: {reason}")

    def log_security_event(self, user_id: str, ip_address: Optional[str], query: str, result: Dict):
        """
        Append an audit event to ``security_log`` (keeping the last 1000) and
        emit a logger warning when the risk score is >= 50.
        """
        event = {
            "timestamp": datetime.now().isoformat(),
            "user_id": user_id,
            "ip_address": ip_address,
            "query_preview": query[:100] + "..." if len(query) > 100 else query,
            "query_length": len(query),
            "risk_score": result["risk_score"],
            "alerts": result["alerts"],
            "allowed": result["allowed"],
            "is_suspicious": result["is_suspicious"],
        }

        self.security_log.append(event)

        # Keep only last 1000 events (trim in place; same list object).
        if len(self.security_log) > 1000:
            del self.security_log[:-1000]

        # Log to security logger if high risk
        if result["risk_score"] >= 50:
            self.logger.warning(f"Security alert: User {user_id} - Score: {result['risk_score']} - Alerts: {result['alerts']}")

    def get_security_stats(self) -> Dict[str, Any]:
        """Summarize the last 24 hours of audit events and current tracking state."""
        cutoff = datetime.now() - timedelta(hours=24)
        recent_events = [
            e for e in self.security_log
            if datetime.fromisoformat(e["timestamp"]) > cutoff
        ]

        blocked_events = [e for e in recent_events if not e["allowed"]]
        suspicious_events = [e for e in recent_events if e["is_suspicious"]]

        return {
            "total_events_24h": len(recent_events),
            "blocked_requests_24h": len(blocked_events),
            "suspicious_requests_24h": len(suspicious_events),
            "current_suspicious_users": len([k for k, v in self.suspicious_ips.items() if k.startswith("user_") and v > 0]),
            "current_suspicious_ips": len([k for k, v in self.suspicious_ips.items() if k.startswith("ip_") and v > 0]),
            "rate_limits_tracked": len(self.rate_limits),
        }

    def sanitize_input(self, text: str) -> str:
        """
        Strip characters and keywords commonly used in injection attacks.

        The removals are repeated until the text stops changing, so nested
        payloads such as ``"scrscriptipt"`` cannot reassemble a removed keyword.
        Every pass only deletes characters, so the loop always terminates.
        """
        if not text:
            return ""

        sanitized = text
        while True:
            previous = sanitized

            # Remove potentially dangerous characters
            sanitized = re.sub(r'[<>"\']', '', sanitized)

            # Remove SQL injection patterns
            sanitized = re.sub(r'(\b)(DROP|DELETE|INSERT|UPDATE|ALTER|EXEC)(\b)', '', sanitized, flags=re.IGNORECASE)

            # Remove JavaScript and HTML patterns
            sanitized = re.sub(r'(javascript|script|onload|onerror|onclick)', '', sanitized, flags=re.IGNORECASE)

            # Remove command injection patterns
            sanitized = re.sub(r'[;&|`]\s*\w+', '', sanitized)

            if sanitized == previous:
                break

        return sanitized.strip()

    def reset_rate_limits(self, user_id: Optional[str] = None, ip_address: Optional[str] = None):
        """Forget the rate-limit window for a specific user and/or IP."""
        if user_id:
            self.rate_limits.pop(f"user_{user_id}", None)

        if ip_address:
            self.rate_limits.pop(f"ip_{ip_address}", None)