# Last commit d29a5a0 (verified) by teoat:
#   fix(backend): fix port and health check robustness
"""
Web Application Firewall (WAF) middleware implementation
"""
import json
import logging
import re
import time
from typing import Any, Optional
from fastapi import Request
from fastapi.responses import JSONResponse
# Module-level logger named after this module; no handlers are attached here.
logger = logging.getLogger(name=__name__)
class ThreatSignature:
    """Threat signature patterns for WAF.

    Every entry is a regular-expression string meant to be matched with
    ``re.search``; the WAF's checks pass ``re.IGNORECASE``, so keywords are
    written in whichever case reads best.
    """

    # SQL Injection patterns
    # Keep every group balanced: an unclosed "(" makes re.search raise
    # re.error at match time, which would fail the request instead of
    # scanning it.
    SQL_INJECTION_PATTERNS = [
        r"(\b(union|select|insert|update|delete|drop|create|alter|exec|execute)\b)",
        r"(\b(OR|AND)\s*\d+\s*=\s*\d+)",
        r"(\b(SELECT\s+.+\s+FROM\s+\w+))",
        r"(\b(INSERT\s+INTO\s+\w+))",
        r"(\b(UPDATE\s+\w+\s+SET\s+))",
        r"(--|;|/\*|\*/)",
        # Optional "(" so the function-call form, e.g. SLEEP(5), also matches.
        r"(\b(SLEEP|WAITFOR|BENCHMARK)\s*\(?\s*\d+)",
        r"(\b(LOAD_FILE|INTO\s+OUTFILE)\s+)",
    ]

    # XSS patterns
    XSS_PATTERNS = [
        r"<script[^>]*>.*?</script>",
        r"javascript:",
        r"vbscript:",
        r"onload\s*=",
        r"onerror\s*=",
        r"onclick\s*=",
        r"<iframe[^>]*>",
        r"<object[^>]*>",
        r"<embed[^>]*>",
    ]

    # Path Traversal patterns
    PATH_TRAVERSAL_PATTERNS = [
        r"(\.\./|\.\.\\)",
        r"(/etc/|/var/|/bin/)",
        r"(\.+)\\.\\",
        r"(\w+)\\(\w+)\\",
    ]

    # Command Injection patterns
    COMMAND_INJECTION_PATTERNS = [
        r"[;&|`$(){}]",
        r"\b(command|cmd|exec|eval)\b",
        r"\b(powershell|bash|sh|ksh|csh|tcsh|zsh)\b",
    ]

    # File Upload patterns
    MALICIOUS_FILE_PATTERNS = [
        r"\.(php|jsp|asp|aspx|sh|bat|exe|cmd|com|scr)$",
        r"(php|python|perl|ruby|node)",
        r"(eval|exec|system)",
    ]

    # LFI/Remote File Inclusion
    LFI_PATTERNS = [
        r"(php://|file://|http://|https://)",
        r"(include|require|include_once|file_get_contents)",
        # Null-byte injection: the escaped text "\x00" as well as an actual
        # NUL character (regex escape \x00 inside this raw string).
        r"(\\x00|\x00)",
    ]

    # HTTP Parameter Pollution
    HTTP_PP_PATTERNS = [
        r"^[^\?&]+=",
        r"[^\?&]*[;&\|]",
    ]
class SecurityHeaders:
    """Security headers manager."""

    @staticmethod
    def get_security_headers() -> dict[str, str]:
        """Return a new dict of hardening response headers.

        A fresh dict is built on every call, so callers may modify the result
        without affecting later calls.
        """
        return {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            # NOTE(review): legacy header; current OWASP guidance is "0" or
            # omitting it — confirm before changing.
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Content-Security-Policy": "default-src 'self'; script-src 'self'; object-src 'self'; style-src 'self' 'unsafe-inline';",
            "X-Download-Options": "noopen",
            # Valid values are none / master-only / by-content-type /
            # by-ftp-filename / all; "none" disallows cross-domain policy files.
            "X-Permitted-Cross-Domain-Policies": "none",
        }
class RateLimiter:
    """Sliding-window rate limiter that keeps one Redis list per client.

    Each list entry is a JSON object ``{"timestamp": <unix seconds>}``.
    Without a Redis client every request is allowed.
    """

    def __init__(self, redis_client=None):
        self.redis = redis_client
        # Not used by this class; kept so existing references keep working.
        self.requests = {}

    def _get_client_identifier(self, request: Request) -> str:
        """Return the key identifying the client for rate limiting.

        Uses the first (originating) entry of ``X-Forwarded-For`` when the
        header is present, otherwise the socket peer address, and finally
        ``"unknown"`` when no client information is available.

        NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        proxy overwrites it — confirm the deployment's proxy setup.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            # The header may hold a comma-separated proxy chain.
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop
        if request.client is not None and request.client.host:
            return request.client.host
        return "unknown"

    async def is_rate_limited(self, request: Request, limit: int = 100, window: int = 60) -> bool:
        """Return True when the client exceeds ``limit`` requests per ``window`` seconds.

        Every call records the current request (even one that ends up
        limited). The count includes this request, so the ``limit``-th request
        inside the window is allowed and the ``(limit + 1)``-th is refused.

        Redis errors are logged and treated as "not limited" (fail open).
        """
        if not self.redis:
            return False

        client_id = self._get_client_identifier(request)
        key = f"rate_limit:{client_id}"
        now = time.time()
        window_start = now - window
        # Only the newest limit+1 timestamps can influence the decision.
        keep = max(limit, 0) + 1

        try:
            raw_entries = await self.redis.lrange(key, 0, -1)
            # Decode each stored entry once, then keep in-window timestamps
            # oldest-first; sorting makes the result independent of the order
            # the list was written in.
            timestamps = [json.loads(raw).get("timestamp", 0) for raw in raw_entries if raw]
            recent = sorted(ts for ts in timestamps if ts > window_start)
            recent.append(now)

            await self.redis.delete(key)
            # rpush preserves oldest-first order, matching the slice that keeps
            # the newest entries.
            await self.redis.rpush(key, *(json.dumps({"timestamp": ts}) for ts in recent[-keep:]))
            # An idle client's key disappears once all its entries are stale.
            await self.redis.expire(key, max(int(window), 1))

            return len(recent) > limit
        except Exception as e:
            logger.error(f"Rate limiting error: {e}")
            return False
class WAF(ThreatSignature):
    """Web Application Firewall implementation.

    Inherits the regex signature lists from ``ThreatSignature``; the checks
    read them as ``self.*_PATTERNS``. When a Redis client is supplied it backs
    rate limiting, cross-process temporary IP blocks and the threat log.
    Without one, those features fall back to in-process state or are skipped.
    """

    # Maximum number of entries kept in the in-memory and Redis threat logs.
    THREAT_LOG_MAX = 1000

    # Request headers whose values are redacted before a request is logged,
    # so credentials never reach logs or Redis.
    SENSITIVE_HEADERS = frozenset({"authorization", "proxy-authorization", "cookie", "x-api-key"})

    def __init__(self, redis_client=None):
        self.redis = redis_client
        self.rate_limiter = RateLimiter(redis_client)
        # Permanently blocked addresses (no expiry).
        self.blocked_ips = set()
        self.whitelist = set()
        self.threat_log = []
        # ip -> unix time at which a temporary block set by block_ip expires.
        self._temp_blocks: dict[str, float] = {}

    def _check_threat(self, value: str, patterns: list[str]) -> Optional[str]:
        """Return the first pattern in ``patterns`` matching ``value``, else None.

        Matching is case-insensitive. A pattern that fails to compile is
        logged and skipped, so one bad signature cannot turn every request
        into a server error.
        """
        for pattern in patterns:
            try:
                if re.search(pattern, value, re.IGNORECASE):
                    return pattern
            except re.error as exc:
                logger.error("Invalid WAF pattern %r: %s", pattern, exc)
        return None

    @staticmethod
    def _json_fields(request: Request) -> dict[str, Any]:
        """Return the request's cached JSON body if it is an object, else ``{}``.

        Only reads ``request._json`` (presumably Starlette's cache, filled once
        something has awaited ``request.json()`` — confirm); the body itself is
        never read here. Lists, scalars and None are ignored instead of
        crashing on ``.items()``.
        """
        body = getattr(request, "_json", None)
        return body if isinstance(body, dict) else {}

    def _scan_query_and_json(
        self, request: Request, patterns: list[str], label: str
    ) -> tuple[bool, Optional[str]]:
        """Match ``patterns`` against query parameters and top-level JSON fields.

        Returns ``(True, "; "-joined findings)`` if anything matched, else
        ``(False, None)``. ``label`` prefixes each finding.
        """
        findings = []
        for param, value in request.query_params.items():
            pattern = self._check_threat(value, patterns)
            if pattern:
                findings.append(f"{label} in parameter '{param}': {pattern}")
        for key, value in self._json_fields(request).items():
            pattern = self._check_threat(str(value), patterns)
            if pattern:
                findings.append(f"{label} in JSON field '{key}': {pattern}")
        if findings:
            return True, "; ".join(findings)
        return False, None

    def _check_sql_injection(self, request: Request) -> tuple[bool, Optional[str]]:
        """Check query parameters and JSON fields for SQL injection attempts."""
        return self._scan_query_and_json(request, self.SQL_INJECTION_PATTERNS, "SQL pattern")

    def _check_xss(self, request: Request) -> tuple[bool, Optional[str]]:
        """Check query parameters and JSON fields for XSS attempts."""
        return self._scan_query_and_json(request, self.XSS_PATTERNS, "XSS pattern")

    def _check_path_traversal(self, request: Request) -> tuple[bool, Optional[str]]:
        """Check query parameters and the URL path for path traversal attempts."""
        suspicious_patterns = []
        for param, value in request.query_params.items():
            pattern = self._check_threat(value, self.PATH_TRAVERSAL_PATTERNS)
            if pattern:
                suspicious_patterns.append(f"Path traversal in parameter '{param}': {pattern}")
        pattern = self._check_threat(request.url.path, self.PATH_TRAVERSAL_PATTERNS)
        if pattern:
            suspicious_patterns.append(f"Path traversal in URL: {pattern}")
        if suspicious_patterns:
            return True, "; ".join(suspicious_patterns)
        return False, None

    def _check_file_upload(self, request: Request, file_data: Any) -> tuple[bool, Optional[str]]:
        """Check an uploaded file's name and first 1024 bytes for malicious patterns.

        ``file_data`` must expose ``.filename`` and a seekable ``.file``
        (presumably a FastAPI ``UploadFile`` — confirm callers). The read
        position is restored afterwards, so the route still sees the whole file.
        """
        malicious_patterns = []
        filename = file_data.filename or ""
        for pattern in self.MALICIOUS_FILE_PATTERNS:
            if re.search(pattern, filename, re.IGNORECASE):
                malicious_patterns.append(f"Malicious file pattern: {pattern}")

        file_obj = file_data.file
        start = file_obj.tell()
        try:
            head = file_obj.read(1024)
        finally:
            # Rewind so downstream handlers are not handed a truncated upload.
            file_obj.seek(start)
        content_str = head.decode("utf-8", errors="ignore") if isinstance(head, bytes) else str(head)
        for pattern in self.MALICIOUS_FILE_PATTERNS:
            if re.search(pattern, content_str, re.IGNORECASE):
                malicious_patterns.append(f"Malicious code pattern: {pattern}")

        if malicious_patterns:
            return True, "; ".join(malicious_patterns)
        return False, None

    def get_client_ip(self, request: Request) -> str:
        """Return the client address used for allow/block lists and logging.

        Takes the first entry of ``X-Forwarded-For`` (the header may carry a
        comma-separated proxy chain), then the socket peer address, then
        ``"unknown"`` when no client information is available.

        NOTE(review): X-Forwarded-For is client-supplied unless a trusted
        proxy overwrites it, so IP blocks keyed on it may be evaded — confirm
        the deployment's proxy setup.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop
        if request.client is not None and request.client.host:
            return request.client.host
        return "unknown"

    async def is_ip_blocked(self, request: Request) -> bool:
        """Return True if the client IP is blocked.

        The whitelist always wins. After that, the check looks at permanent
        in-process blocks, unexpired in-process temporary blocks, and (with
        Redis) the ``waf:temp_blocked:{ip}`` key written by ``block_ip``.
        """
        client_ip = self.get_client_ip(request)
        if client_ip in self.whitelist:
            return False
        if client_ip in self.blocked_ips:
            return True

        expires_at = self._temp_blocks.get(client_ip)
        if expires_at is not None:
            if expires_at > time.time():
                return True
            # Expired: drop it so the dict does not grow without bound.
            self._temp_blocks.pop(client_ip, None)

        if self.redis:
            try:
                temp_blocked = await self.redis.smembers(f"waf:temp_blocked:{client_ip}")
                if temp_blocked:
                    return True
            except Exception as e:
                logger.error(f"Error checking temporary IP blocks: {e}")
        return False

    async def block_ip(self, ip: str, duration: int = 3600, reason: str = "Threat detected") -> bool:
        """Block ``ip`` for ``duration`` seconds.

        The block is always recorded in-process. With Redis, it is also stored
        under ``waf:temp_blocked:{ip}`` with a TTL — the key ``is_ip_blocked``
        reads — so other workers honour it. The ``waf:blocked_*`` keys are
        still written as before.

        Returns False if the Redis writes fail; the in-process block still
        applies in that case.
        """
        self._temp_blocks[ip] = time.time() + duration
        if self.redis:
            try:
                temp_key = f"waf:temp_blocked:{ip}"
                await self.redis.sadd(temp_key, reason)
                await self.redis.expire(temp_key, duration)
                await self.redis.sadd("waf:blocked_ips", ip)
                await self.redis.sadd(f"waf:blocked_reason:{ip}", reason)
                await self.redis.sadd(f"waf:blocked_time:{ip}", time.time())
                await self.redis.expire(f"waf:blocked_reason:{ip}", duration)
                await self.redis.expire(f"waf:blocked_time:{ip}", duration)
            except Exception as e:
                logger.error(f"Error blocking IP {ip}: {e}")
                return False
        logger.warning(f"IP {ip} blocked for {duration}s. Reason: {reason}")
        return True

    def create_threat_response(self, threat_type: str, details: str) -> JSONResponse:
        """Build the 403 JSON response returned for a detected threat."""
        return JSONResponse(
            status_code=403,
            content={
                "error": "Threat detected",
                "threat_type": threat_type,
                "details": details,
                "timestamp": time.time(),
            },
            headers=SecurityHeaders.get_security_headers(),
        )

    async def check_request(self, request: Request) -> dict[str, Any]:
        """Run every WAF check against ``request``.

        Returns a dict with ``blocked`` (bool) and ``threats`` (list of str).
        A blocked IP short-circuits with an extra ``reason`` key. Otherwise
        ``risk_score`` is 10 per threat, capped at 100, and any threat
        (including an exceeded rate limit) sets ``blocked``.
        """
        if await self.is_ip_blocked(request):
            return {"blocked": True, "reason": "IP address blocked", "threats": ["blocked_ip"]}

        threats_detected = []

        if await self.rate_limiter.is_rate_limited(request):
            threats_detected.append("rate_limit_exceeded")

        is_sql_injection, sql_details = self._check_sql_injection(request)
        if is_sql_injection:
            threats_detected.append(f"sql_injection: {sql_details}")

        is_xss, xss_details = self._check_xss(request)
        if is_xss:
            threats_detected.append(f"xss: {xss_details}")

        is_path_traversal, path_details = self._check_path_traversal(request)
        if is_path_traversal:
            threats_detected.append(f"path_traversal: {path_details}")

        # Query parameters plus top-level JSON fields (JSON wins on name clashes).
        all_params = dict(request.query_params)
        all_params.update(self._json_fields(request))
        for param, value in all_params.items():
            text = str(value)
            pattern = self._check_threat(text, self.COMMAND_INJECTION_PATTERNS)
            if pattern:
                threats_detected.append(f"command_injection in '{param}': {pattern}")
            pattern = self._check_threat(text, self.LFI_PATTERNS)
            if pattern:
                threats_detected.append(f"file_inclusion in '{param}': {pattern}")

        return {
            "blocked": len(threats_detected) > 0,
            "threats": threats_detected,
            "risk_score": min(len(threats_detected) * 10, 100),
        }

    async def log_threat(self, request: Request, check_result: dict[str, Any]) -> None:
        """Record a request and its check result in memory and, if configured, Redis.

        Sensitive headers are redacted. Both logs keep at most
        ``THREAT_LOG_MAX`` entries. Redis failures are logged and do not
        propagate.
        """
        headers = {
            name: ("[REDACTED]" if name.lower() in self.SENSITIVE_HEADERS else value)
            for name, value in request.headers.items()
        }
        threat_data = {
            "timestamp": time.time(),
            "ip": self.get_client_ip(request),
            "method": request.method,
            "url": str(request.url),
            "user_agent": request.headers.get("user-agent", ""),
            "blocked": check_result.get("blocked", False),
            "threats": check_result.get("threats", []),
            "risk_score": check_result.get("risk_score", 0),
            "request_data": {"query_params": dict(request.query_params), "headers": headers},
        }

        self.threat_log.append(threat_data)
        overflow = len(self.threat_log) - self.THREAT_LOG_MAX
        if overflow > 0:
            # Drop the oldest entries so memory stays bounded.
            del self.threat_log[:overflow]

        if self.redis:
            try:
                await self.redis.lpush("waf:threat_log", json.dumps(threat_data))
                # LTRIM's stop index is inclusive: 0..MAX-1 keeps MAX entries.
                await self.redis.ltrim("waf:threat_log", 0, self.THREAT_LOG_MAX - 1)
            except Exception as e:
                logger.error(f"Error storing WAF threat log: {e}")
# Shared process-wide WAF consulted by the middleware below. It is built
# without a Redis client, so rate limiting always reports "not limited" and
# the Redis-backed parts of blocking and threat logging are not used.
waf = WAF(redis_client=None)
# Middleware function
async def waf_middleware(request: Request, call_next):
    """WAF middleware for FastAPI.

    Runs ``waf.check_request`` on every request and records the outcome with
    ``waf.log_threat``; blocked requests are recorded too. Flagged requests
    get the WAF's 403 response. Clean ones are passed to ``call_next``, and
    the security headers are set on the resulting response.
    """
    check_result = await waf.check_request(request)

    # Log before the block decision so rejected requests are recorded as well.
    try:
        await waf.log_threat(request, check_result)
    except Exception:
        # Logging is best-effort; its failure must not fail the request.
        logger.exception("WAF threat logging failed")

    if check_result["blocked"]:
        return waf.create_threat_response("Multiple threats detected", f"Threats: {', '.join(check_result['threats'])}")

    # Continue to next middleware if not blocked
    response = await call_next(request)

    # Add security headers to all responses
    for header, value in SecurityHeaders.get_security_headers().items():
        response.headers[header] = value
    return response