Spaces:
Paused
Paused
| """ | |
| Web Application Firewall (WAF) middleware implementation | |
| """ | |
| import json | |
| import logging | |
| import re | |
| import time | |
| from typing import Any, Optional | |
| from fastapi import Request | |
| from fastapi.responses import JSONResponse | |
| logger = logging.getLogger(__name__) | |
class ThreatSignature:
    """Threat signature patterns for WAF.

    Each attribute is a list of regular-expression source strings. They are
    plain strings (not compiled patterns); the consumer decides the flags.
    Every pattern in this class must be a valid ``re`` expression.
    """

    # SQL Injection patterns
    SQL_INJECTION_PATTERNS = [
        r"(\b(union|select|insert|update|delete|drop|create|alter|exec|execute)\b)",
        r"(\b(OR|AND)\s*\d+\s*=\s*\d+)",
        # The next three patterns previously had an unclosed group, which made
        # ``re`` raise "missing ), unterminated subpattern" when they were used.
        r"(\bSELECT\s+.+\s+FROM\s+\w+)",
        r"(\bINSERT\s+INTO\s+\w+)",
        r"(\bUPDATE\s+\w+\s+SET\s+)",
        r"(--|;|/\*|\*/)",
        r"(\b(SLEEP|WAITFOR|BENCHMARK)\s*\d+)",
        r"(\b(LOAD_FILE|INTO\s+OUTFILE)\s+)",
    ]

    # XSS patterns
    XSS_PATTERNS = [
        r"<script[^>]*>.*?</script>",
        r"javascript:",
        r"vbscript:",
        r"onload\s*=",
        r"onerror\s*=",
        r"onclick\s*=",
        r"<iframe[^>]*>",
        r"<object[^>]*>",
        r"<embed[^>]*>",
    ]

    # Path Traversal patterns
    PATH_TRAVERSAL_PATTERNS = [
        r"(\.\./|\.\.\\)",
        r"(/etc/|/var/|/bin/)",
        r"(\.+)\\.\\",
        r"(\w+)\\(\w+)\\",
    ]

    # Command Injection patterns
    COMMAND_INJECTION_PATTERNS = [
        r"[;&|`$(){}]",
        r"\b(command|cmd|exec|eval)\b",
        r"\b(powershell|bash|sh|ksh|csh|tcsh|zsh)\b",
    ]

    # File Upload patterns
    MALICIOUS_FILE_PATTERNS = [
        r"\.(php|jsp|asp|aspx|sh|bat|exe|cmd|com|scr)$",
        r"(php|python|perl|ruby|node)",
        r"(eval|exec|system)",
    ]

    # LFI/Remote File Inclusion
    LFI_PATTERNS = [
        r"(php://|file://|http://|https://)",
        r"(include|require|include_once|file_get_contents)",
        # Matches the literal four-character text "\x00" (backslash, x, 0, 0),
        # not an actual NUL byte.
        r"(\\x00)",
    ]

    # HTTP Parameter Pollution
    HTTP_PP_PATTERNS = [
        r"^[^\?&]+=",
        r"[^\?&]*[;&\|]",
    ]
class SecurityHeaders:
    """Security headers manager."""

    @staticmethod
    def get_security_headers() -> dict[str, str]:
        """Return the response headers that harden browser behaviour.

        Declared as a static method so it can be called on the class
        (``SecurityHeaders.get_security_headers()``) as well as on an
        instance; without the decorator an instance call raised ``TypeError``.

        Returns:
            A new dict on every call, mapping header name to header value,
            so callers may mutate the result freely.
        """
        return {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Content-Security-Policy": "default-src 'self'; script-src 'self'; object-src 'self'; style-src 'self' 'unsafe-inline';",
            "X-Download-Options": "noopen",
            # "cross-origin" is not a defined value for this header; "none"
            # is the most restrictive of the defined values.
            "X-Permitted-Cross-Domain-Policies": "none",
        }
class RateLimiter:
    """Sliding-window rate limiter that keeps one Redis list per client.

    Each list element is a JSON object ``{"timestamp": <float seconds>}``.
    Without a Redis client the limiter is disabled and never limits.
    """

    def __init__(self, redis_client=None):
        """Store the (optional) async Redis client used for bookkeeping."""
        self.redis = redis_client
        # Not read by any method of this class; kept so existing attribute
        # access on instances keeps working.
        self.requests = {}

    def _get_client_identifier(self, request: Request) -> str:
        """Return the identifier used as the rate-limit bucket for *request*.

        Prefers the first address in ``X-Forwarded-For`` (the header may carry
        a comma-separated chain), then the connection peer address, and falls
        back to ``"unknown"`` when neither is available (``request.client``
        can be ``None``).

        NOTE: ``X-Forwarded-For`` is supplied by the client unless a trusted
        proxy overwrites it, so a caller can pick its own bucket.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop
        client = request.client
        if client is not None and client.host:
            return client.host
        return "unknown"

    async def is_rate_limited(self, request: Request, limit: int = 100, window: int = 60) -> bool:
        """Record the current request and report whether the client is over its limit.

        Args:
            request: Incoming request; only its headers and peer address are read.
            limit: Number of requests allowed inside one window. The request
                that makes the count exceed ``limit`` is the first one reported
                as limited.
            window: Window length in seconds.

        Returns:
            ``True`` when more than ``limit`` requests (including this one)
            fall inside the last ``window`` seconds. ``False`` when Redis is
            not configured or a Redis error occurs (fail-open, logged).
        """
        if not self.redis:
            return False

        client_id = self._get_client_identifier(request)
        key = f"rate_limit:{client_id}"
        current_time = time.time()
        window_start = current_time - window

        try:
            raw_entries = await self.redis.lrange(key, 0, -1)

            # Decode every stored entry exactly once; skip corrupt ones rather
            # than letting a single bad element disable limiting for the key.
            timestamps = []
            for raw in raw_entries:
                if not raw:
                    continue
                try:
                    timestamps.append(float(json.loads(raw).get("timestamp", 0)))
                except (ValueError, TypeError, AttributeError):
                    continue

            # Keep only requests inside the window, oldest first, then add
            # the current one.
            recent = sorted(stamp for stamp in timestamps if stamp > window_start)
            recent.append(current_time)

            # limit + 1 entries are enough to decide "over the limit"; trimming
            # from the front discards the oldest entries, never the newest.
            keep = max(limit, 0) + 1
            recent = recent[-keep:]

            # Rewrite the list in chronological order. The delete/rpush pair is
            # not atomic; concurrent requests from one client may interleave.
            await self.redis.delete(key)
            await self.redis.rpush(key, *[json.dumps({"timestamp": stamp}) for stamp in recent])
            # Let idle buckets disappear instead of living in Redis forever.
            await self.redis.expire(key, max(1, int(window)))

            return len(recent) > limit
        except Exception as e:
            logger.error(f"Rate limiting error: {e}")
            return False
class WAF:
    """Web Application Firewall implementation.

    Scans query parameters, the URL path and (when available) the parsed JSON
    body of a request against the pattern lists on ``ThreatSignature``, keeps
    an IP block list / whitelist, and records what it saw in ``threat_log``
    (and in Redis when a client is configured).
    """

    # Header names whose values are credentials; masked before logging.
    _SENSITIVE_HEADERS = frozenset(
        {"authorization", "proxy-authorization", "cookie", "set-cookie", "x-api-key"}
    )
    # Maximum number of entries kept in ``threat_log`` and in the Redis list.
    _MAX_LOGGED_THREATS = 1000

    def __init__(self, redis_client=None):
        """Create a WAF; ``redis_client`` is an optional async Redis client."""
        self.redis = redis_client
        self.rate_limiter = RateLimiter(redis_client)
        self.blocked_ips = set()
        self.whitelist = set()
        self.threat_log = []

    # ------------------------------------------------------------------ helpers

    def _check_threat(self, value: str, patterns: list[str]) -> Optional[str]:
        """Return the first pattern in *patterns* found in *value*, else ``None``.

        Matching is case-insensitive. A pattern that is not a valid regular
        expression is logged and skipped so one bad signature cannot abort the
        whole scan.
        """
        for pattern in patterns:
            try:
                matched = re.search(pattern, value, re.IGNORECASE)
            except re.error as exc:
                logger.error(f"Invalid WAF pattern {pattern!r}: {exc}")
                continue
            if matched:
                return pattern
        return None

    @staticmethod
    def _query_items(request: Request) -> list[tuple[str, Any]]:
        """Return every (name, value) query pair, including repeated names.

        Uses ``multi_items()`` when the query mapping provides it, so that a
        repeated parameter (``?q=bad&q=ok``) cannot hide a value from the
        scan; falls back to ``items()`` for plain mappings.
        """
        params = request.query_params
        multi_items = getattr(params, "multi_items", None)
        if callable(multi_items):
            return list(multi_items())
        return list(params.items())

    @staticmethod
    def _json_fields(request: Request) -> dict[str, Any]:
        """Return the top-level fields of the request's cached JSON body.

        NOTE(review): ``_json`` looks like the framework's private cache that
        is only present after something has awaited ``request.json()``; when
        it is absent, ``None`` or a scalar, no JSON fields are scanned —
        confirm the body is parsed before the WAF runs.

        A JSON array is exposed as fields named by index ("0", "1", ...).
        """
        body = getattr(request, "_json", None)
        if isinstance(body, dict):
            return body
        if isinstance(body, list):
            return {str(index): item for index, item in enumerate(body)}
        return {}

    def _collect_matches(self, items, patterns: list[str], label: str) -> list[str]:
        """Scan (name, value) pairs and describe each value that matches."""
        findings = []
        for name, value in items:
            pattern = self._check_threat(str(value), patterns)
            if pattern:
                findings.append(f"{label} '{name}': {pattern}")
        return findings

    @staticmethod
    def _verdict(findings: list[str]) -> tuple[bool, Optional[str]]:
        """Fold a list of findings into the ``(detected, details)`` pair."""
        if findings:
            return True, "; ".join(findings)
        return False, None

    # ------------------------------------------------------------------- checks

    def _check_sql_injection(self, request: Request) -> tuple[bool, Optional[str]]:
        """Check query parameters and JSON fields for SQL injection patterns."""
        patterns = ThreatSignature.SQL_INJECTION_PATTERNS
        findings = self._collect_matches(
            self._query_items(request), patterns, "SQL pattern in parameter"
        )
        findings += self._collect_matches(
            self._json_fields(request).items(), patterns, "SQL pattern in JSON field"
        )
        return self._verdict(findings)

    def _check_xss(self, request: Request) -> tuple[bool, Optional[str]]:
        """Check query parameters and JSON fields for XSS patterns."""
        patterns = ThreatSignature.XSS_PATTERNS
        findings = self._collect_matches(
            self._query_items(request), patterns, "XSS pattern in parameter"
        )
        findings += self._collect_matches(
            self._json_fields(request).items(), patterns, "XSS pattern in JSON field"
        )
        return self._verdict(findings)

    def _check_path_traversal(self, request: Request) -> tuple[bool, Optional[str]]:
        """Check query parameters and the URL path for path traversal patterns."""
        patterns = ThreatSignature.PATH_TRAVERSAL_PATTERNS
        findings = self._collect_matches(
            self._query_items(request), patterns, "Path traversal in parameter"
        )
        pattern = self._check_threat(request.url.path, patterns)
        if pattern:
            findings.append(f"Path traversal in URL: {pattern}")
        return self._verdict(findings)

    def _check_file_upload(self, request: Request, file_data: Any) -> tuple[bool, Optional[str]]:
        """Check an uploaded file's name and first 1024 bytes for malicious patterns.

        Args:
            request: Unused; kept for interface compatibility.
            file_data: Upload object exposing ``filename`` and a binary
                ``file`` stream (the code reads ``file_data.filename`` and
                ``file_data.file.read``), not raw ``bytes``.

        Returns:
            ``(True, details)`` when any pattern matches, else ``(False, None)``.
        """
        patterns = ThreatSignature.MALICIOUS_FILE_PATTERNS
        findings = []

        filename = file_data.filename or ""
        for pattern in patterns:
            if re.search(pattern, filename, re.IGNORECASE):
                findings.append(f"Malicious file pattern: {pattern}")

        # Peek at the head of the file, then restore the stream position so
        # the handler that stores the upload still sees the whole content.
        stream = file_data.file
        try:
            start = stream.tell()
        except (AttributeError, OSError, ValueError):
            start = None
        content_str = stream.read(1024).decode("utf-8", errors="ignore")
        if start is not None:
            try:
                stream.seek(start)
            except (AttributeError, OSError, ValueError):
                pass

        for pattern in patterns:
            if re.search(pattern, content_str, re.IGNORECASE):
                findings.append(f"Malicious code pattern: {pattern}")

        return self._verdict(findings)

    # ------------------------------------------------------------- IP handling

    def get_client_ip(self, request: Request) -> str:
        """Return the client address for *request*.

        Prefers the first address in ``X-Forwarded-For`` (which may hold a
        comma-separated chain), then the connection peer, then ``"unknown"``
        (``request.client`` can be ``None``).

        NOTE: ``X-Forwarded-For`` is client-controlled unless a trusted proxy
        overwrites it; block-list and whitelist decisions inherit that trust.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop
        client = request.client
        if client is not None and client.host:
            return client.host
        return "unknown"

    async def is_ip_blocked(self, request: Request) -> bool:
        """Return whether the request's client IP is currently blocked.

        The whitelist wins over every block. Otherwise the in-memory block set
        is consulted, then the Redis set ``waf:temp_blocked:<ip>``. A Redis
        error is logged and treated as "not blocked".
        """
        client_ip = self.get_client_ip(request)
        if client_ip in self.whitelist:
            return False
        if client_ip in self.blocked_ips:
            return True

        # Check temporary blocks from Redis
        if self.redis:
            try:
                if await self.redis.smembers(f"waf:temp_blocked:{client_ip}"):
                    return True
            except Exception as e:
                logger.error(f"Error checking temporary IP blocks: {e}")
        return False

    async def block_ip(self, ip: str, duration: int = 3600, reason: str = "Threat detected") -> bool:
        """Block an IP address.

        With Redis configured, the block is recorded there, including the
        ``waf:temp_blocked:<ip>`` key that ``is_ip_blocked`` reads, and the
        per-IP keys expire after ``duration`` seconds. The IP is also added
        to the in-memory set, which has no expiry and lasts for the lifetime
        of this object. Without Redis only the in-memory block is applied.

        Returns:
            ``True`` when the block was recorded, ``False`` when Redis failed
            (the error is logged and nothing is added in memory).
        """
        if self.redis:
            try:
                await self.redis.sadd("waf:blocked_ips", ip)
                await self.redis.sadd(f"waf:blocked_reason:{ip}", reason)
                await self.redis.sadd(f"waf:blocked_time:{ip}", time.time())
                await self.redis.sadd(f"waf:temp_blocked:{ip}", reason)
                for key in (
                    f"waf:blocked_reason:{ip}",
                    f"waf:blocked_time:{ip}",
                    f"waf:temp_blocked:{ip}",
                ):
                    await self.redis.expire(key, duration)
            except Exception as e:
                logger.error(f"Error blocking IP {ip}: {e}")
                return False

        self.blocked_ips.add(ip)
        # Log the blocking
        logger.warning(f"IP {ip} blocked for {duration}s. Reason: {reason}")
        return True

    # ---------------------------------------------------------------- responses

    def create_threat_response(self, threat_type: str, details: str) -> JSONResponse:
        """Build the 403 JSON response returned for a blocked request.

        NOTE: ``details`` is echoed to the client; callers that pass matched
        patterns reveal the WAF's signatures to whoever triggered them.
        """
        return JSONResponse(
            status_code=403,
            content={
                "error": "Threat detected",
                "threat_type": threat_type,
                "details": details,
                "timestamp": time.time(),
            },
            headers={
                **SecurityHeaders.get_security_headers(),
                "X-Content-Type-Options": "nosniff",
            },
        )

    async def check_request(self, request: Request) -> dict[str, Any]:
        """Perform comprehensive WAF checks.

        Returns:
            For a blocked IP: ``{"blocked": True, "reason": ..., "threats":
            ["blocked_ip"]}``. Otherwise ``{"blocked", "threats",
            "risk_score"}`` where ``blocked`` is true when any threat was
            found and ``risk_score`` is 10 per threat, capped at 100.
        """
        # Check IP blocking
        if await self.is_ip_blocked(request):
            return {"blocked": True, "reason": "IP address blocked", "threats": ["blocked_ip"]}

        threats_detected = []

        # Check rate limiting
        if await self.rate_limiter.is_rate_limited(request):
            threats_detected.append("rate_limit_exceeded")

        # Signature checks, in a fixed order so the report is stable.
        for label, check in (
            ("sql_injection", self._check_sql_injection),
            ("xss", self._check_xss),
            ("path_traversal", self._check_path_traversal),
        ):
            detected, details = check(request)
            if detected:
                threats_detected.append(f"{label}: {details}")

        # Command-injection and file-inclusion checks over every parameter.
        # Query and JSON fields are scanned separately (not merged into one
        # dict) so a JSON key cannot shadow a query parameter of the same name.
        all_params = self._query_items(request) + list(self._json_fields(request).items())
        for param, value in all_params:
            text = str(value)
            pattern = self._check_threat(text, ThreatSignature.COMMAND_INJECTION_PATTERNS)
            if pattern:
                threats_detected.append(f"command_injection in '{param}': {pattern}")
            pattern = self._check_threat(text, ThreatSignature.LFI_PATTERNS)
            if pattern:
                threats_detected.append(f"file_inclusion in '{param}': {pattern}")

        return {
            "blocked": len(threats_detected) > 0,
            "threats": threats_detected,
            "risk_score": min(len(threats_detected) * 10, 100),
        }

    async def log_threat(self, request: Request, check_result: dict[str, Any]):
        """Record *request* and its check result in memory and in Redis.

        Credential-bearing headers are masked, the in-memory log is capped at
        ``_MAX_LOGGED_THREATS`` entries (oldest dropped), and a Redis failure
        is logged instead of being raised.
        """
        headers = {
            name: ("[REDACTED]" if name.lower() in self._SENSITIVE_HEADERS else value)
            for name, value in dict(request.headers).items()
        }
        threat_data = {
            "timestamp": time.time(),
            "ip": self.get_client_ip(request),
            "method": request.method,
            "url": str(request.url),
            "user_agent": request.headers.get("user-agent", ""),
            "blocked": check_result.get("blocked", False),
            "threats": check_result.get("threats", []),
            "risk_score": check_result.get("risk_score", 0),
            "request_data": {"query_params": dict(request.query_params), "headers": headers},
        }

        self.threat_log.append(threat_data)
        overflow = len(self.threat_log) - self._MAX_LOGGED_THREATS
        if overflow > 0:
            del self.threat_log[:overflow]

        # Store in Redis for analysis
        if self.redis:
            try:
                await self.redis.lpush("waf:threat_log", json.dumps(threat_data))
                # LTRIM bounds are inclusive: 0..999 keeps exactly 1000 entries.
                await self.redis.ltrim("waf:threat_log", 0, self._MAX_LOGGED_THREATS - 1)
            except Exception as e:
                logger.error(f"Error storing threat log entry: {e}")
# Global WAF instance used by the middleware function below.
# It is created without a Redis client, so the Redis-dependent paths
# (rate limiting, temporary IP block lookup, persisting the threat log)
# are skipped for this instance.
waf = WAF()
# Middleware function
async def waf_middleware(request: Request, call_next):
    """WAF middleware for FastAPI.

    Runs the global ``waf`` checks on the request, records the outcome, and
    either short-circuits with the WAF's threat response or forwards the
    request and adds the security headers to the downstream response.

    Args:
        request: The incoming request.
        call_next: Awaitable callable that passes the request to the next
            handler and returns its response.

    Returns:
        The 403 threat response when the request is blocked, otherwise the
        downstream response with security headers applied.
    """
    check_result = await waf.check_request(request)

    # Log before branching so blocked requests (the ones that actually carry
    # threats) are recorded too. A logging failure must not stop the request
    # from being handled.
    try:
        await waf.log_threat(request, check_result)
    except Exception as exc:
        logger.error(f"WAF threat logging failed: {exc}")

    if check_result["blocked"]:
        return waf.create_threat_response(
            "Multiple threats detected",
            f"Threats: {', '.join(check_result['threats'])}",
        )

    # Continue to next middleware if not blocked
    response = await call_next(request)

    # Add security headers to all responses
    for header, value in SecurityHeaders.get_security_headers().items():
        response.headers[header] = value
    return response