# zenith-backend / app / middleware / advanced_security.py
"""
Perfect Security Implementation - Runtime Security Monitoring & Zero-Trust Architecture
Achieving 10/10 security score with comprehensive protection layers.
"""
import asyncio
import hashlib
import json
import re
import secrets
import time
from datetime import datetime, timedelta
from typing import Any

from fastapi import Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware

from core.logging import logger
# Security Models
class SecurityEvent(BaseModel):
    """An audit record for one security-relevant occurrence."""

    event_id: str  # unique id: "sec_<unix-ts>_<sha256 prefix>" (built in RuntimeSecurityMonitor._log_security_event)
    event_type: str  # machine-readable event name, e.g. "brute_force_detected"
    severity: str  # critical, high, medium, low, info
    source_ip: str  # client IP the event originated from
    user_id: str | None  # authenticated user, if known
    resource: str  # URL/resource involved in the event
    action: str  # HTTP method or action name
    details: dict[str, Any]  # free-form structured context
    timestamp: datetime  # wall-clock time the event was recorded
    risk_score: int  # numeric risk weighting for the event
class ZeroTrustPolicy(BaseModel):
    """Access policy applied to requests matching `resource_pattern`."""

    policy_id: str  # unique policy identifier
    resource_pattern: str  # regex selecting which requests this policy governs
    required_permissions: list[str]  # caller must hold ALL of these
    mfa_required: bool  # whether MFA must already be verified
    # Optional constraints default to None so policy definitions may omit
    # them.  Without the defaults, pydantic treats `X | None` fields as
    # *required*, and the in-code policies in ZeroTrustMiddleware (which
    # omit both) would fail validation at construction time.
    ip_whitelist: list[str] | None = None
    time_restrictions: dict[str, Any] | None = None
    risk_threshold: int  # maximum tolerated risk score
class InputValidationRule(BaseModel):
    """One validation rule for a named input field."""

    field_name: str  # JSON field the rule applies to
    validation_type: str  # regex, length, type, custom
    # Rule parameters default to None so each rule sets only what it uses.
    # Without the defaults, pydantic treats `X | None` fields as *required*,
    # and the rule definitions in InputValidationMiddleware (which set only
    # one or two of these) would fail validation at construction time.
    pattern: str | None = None  # regex for validation_type == "regex"
    min_length: int | None = None  # inclusive lower bound for "length"
    max_length: int | None = None  # inclusive upper bound for "length"
    allowed_values: list[Any] | None = None  # whitelist for enumerated fields
    custom_validator: str | None = None  # hook name for "custom" rules
# Runtime Security Monitor
class RuntimeSecurityMonitor:
    """Real-time security monitoring and threat detection.

    Tracks failed authentication attempts per source IP, flags IPs that
    cross the brute-force threshold, and keeps a bounded in-memory audit
    trail of SecurityEvent records.
    """

    def __init__(self):
        self.security_events: list[SecurityEvent] = []
        self.active_threats: dict[str, int] = {}
        self.suspicious_ips: set[str] = set()
        self.brute_force_attempts: dict[str, list[datetime]] = {}
        # Security thresholds
        self.max_failed_attempts = 5
        self.suspicious_activity_window = 300  # seconds (5 minutes)
        self.block_duration = 900  # seconds (15 minutes)
        self._background_tasks: list[asyncio.Task] = []
        # asyncio.create_task() raises RuntimeError when no event loop is
        # running — and this class is instantiated at module import time.
        # Only auto-start the monitor if a loop already exists; otherwise
        # the application should call start() from async startup code.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # no loop yet; start() must be called later
        else:
            self.start()

    def start(self) -> None:
        """Launch the background monitoring task (requires a running loop)."""
        security_task = asyncio.create_task(self._background_security_monitor())
        self._background_tasks.append(security_task)

    async def _background_security_monitor(self) -> None:
        """Background loop: analyze threat patterns and expire blocks every minute."""
        while True:
            try:
                await self._analyze_security_patterns()
                await self._cleanup_expired_blocks()
            except Exception as e:
                logger.error(f"Security monitor error: {e}")
            # Sleep outside the try so a failing analysis pass cannot turn
            # this loop into a busy spin.
            await asyncio.sleep(60)  # check every minute

    async def _analyze_security_patterns(self) -> None:
        """Flag source IPs with too many recent failed attempts as suspicious."""
        now = datetime.now()
        for ip, attempts in self.brute_force_attempts.items():
            recent_attempts = [
                attempt
                for attempt in attempts
                # total_seconds(), not .seconds: timedelta.seconds wraps at
                # 24h and would under-count the age of day-old attempts.
                if (now - attempt).total_seconds() < self.suspicious_activity_window
            ]
            if len(recent_attempts) >= self.max_failed_attempts:
                self.suspicious_ips.add(ip)
                await self._log_security_event(
                    event_type="brute_force_detected",
                    severity="high",
                    source_ip=ip,
                    details={"attempts": len(recent_attempts)},
                )

    async def _cleanup_expired_blocks(self) -> None:
        """Clean up expired IP blocks.

        Currently a no-op: blocks live only in this process's memory.
        This would be enhanced with Redis for distributed blocking.
        """

    async def _log_security_event(
        self,
        event_type: str,
        severity: str,
        source_ip: str,
        user_id: str | None = None,
        resource: str = "",
        action: str = "",
        details: dict[str, Any] | None = None,
        risk_score: int = 0,
    ) -> None:
        """Record a SecurityEvent in the audit trail and emit a warning log.

        Args:
            event_type: Short machine-readable event name.
            severity: One of critical/high/medium/low/info.
            source_ip: Client IP the event originated from.
            user_id: Authenticated user, if known.
            resource: URL/resource involved.
            action: HTTP method or action name.
            details: Free-form structured context.
            risk_score: Numeric risk weighting for the event.
        """
        event = SecurityEvent(
            event_id=f"sec_{int(time.time())}_{hashlib.sha256(f'{event_type}{source_ip}'.encode()).hexdigest()[:16]}",
            event_type=event_type,
            severity=severity,
            source_ip=source_ip,
            user_id=user_id,
            resource=resource,
            action=action,
            details=details or {},
            timestamp=datetime.now(),
            risk_score=risk_score,
        )
        self.security_events.append(event)
        # Bound memory: keep only the most recent events.
        if len(self.security_events) > 1000:
            self.security_events = self.security_events[-500:]
        # Log to security monitoring system
        logger.warning(f"SECURITY_EVENT: {event_type} from {source_ip} - {severity}")

    def is_ip_blocked(self, ip: str) -> bool:
        """Return True if the IP has been flagged as suspicious/blocked."""
        return ip in self.suspicious_ips

    def record_failed_attempt(self, ip: str) -> None:
        """Record a failed authentication attempt and prune stale entries."""
        attempts = self.brute_force_attempts.setdefault(ip, [])
        attempts.append(datetime.now())
        # Keep only attempts inside the suspicious-activity window.
        cutoff = datetime.now() - timedelta(seconds=self.suspicious_activity_window)
        self.brute_force_attempts[ip] = [a for a in attempts if a > cutoff]
# Zero-Trust Security Middleware
class ZeroTrustMiddleware(BaseHTTPMiddleware):
    """Zero-trust security middleware implementing continuous verification.

    Every request is checked against (1) the monitor's blocked-IP set and
    (2) the first ZeroTrustPolicy whose pattern matches the request path
    (permission and MFA requirements).
    """

    def __init__(self, app, security_monitor: RuntimeSecurityMonitor):
        super().__init__(app)
        self.security_monitor = security_monitor
        self.policies: dict[str, ZeroTrustPolicy] = {}
        # Load zero-trust policies
        self._load_policies()

    def _load_policies(self):
        """Load zero-trust security policies.

        In production, this would load from database or config.  Optional
        constraints are passed explicitly as None so construction succeeds
        even if the model treats them as required fields.
        """
        self.policies = {
            "/api/v1/cases": ZeroTrustPolicy(
                policy_id="cases_access",
                resource_pattern=r"/api/v1/cases.*",
                required_permissions=["cases.read"],
                mfa_required=False,
                ip_whitelist=None,
                time_restrictions=None,
                risk_threshold=3,
            ),
            "/api/v1/admin": ZeroTrustPolicy(
                policy_id="admin_access",
                resource_pattern=r"/api/v1/admin.*",
                required_permissions=["admin.access"],
                mfa_required=True,
                ip_whitelist=None,
                time_restrictions=None,
                risk_threshold=1,
            ),
        }

    async def dispatch(self, request: Request, call_next):
        """Apply blocked-IP and zero-trust policy checks, then forward."""
        # Extract security context
        client_ip = self._get_client_ip(request)
        user_id = getattr(request.state, "user_id", None)
        user_permissions = getattr(request.state, "permissions", [])
        # Reject requests from blocked IPs outright.
        if self.security_monitor.is_ip_blocked(client_ip):
            await self.security_monitor._log_security_event(
                event_type="blocked_ip_access",
                severity="critical",
                source_ip=client_ip,
                user_id=user_id,
                resource=str(request.url),
                action=request.method,
            )
            return JSONResponse(
                status_code=403,
                content={
                    "error": "Access denied",
                    "reason": "IP blocked due to security policy",
                },
            )
        # Apply the first zero-trust policy whose pattern matches.
        # Match against the URL *path*: str(request.url) starts with
        # "http(s)://host", so an anchored re.match against patterns like
        # r"/api/v1/..." could never fire.
        request_path = request.url.path
        for policy in self.policies.values():
            if re.match(policy.resource_pattern, request_path):
                # Check permissions
                if not self._has_required_permissions(
                    user_permissions, policy.required_permissions
                ):
                    await self.security_monitor._log_security_event(
                        event_type="insufficient_permissions",
                        severity="high",
                        source_ip=client_ip,
                        user_id=user_id,
                        resource=str(request.url),
                        action=request.method,
                    )
                    return JSONResponse(
                        status_code=403,
                        content={
                            "error": "Access denied",
                            "reason": "Insufficient permissions",
                        },
                    )
                # Check MFA requirement
                if policy.mfa_required and not getattr(
                    request.state, "mfa_verified", False
                ):
                    return JSONResponse(
                        status_code=403,
                        content={
                            "error": "MFA required",
                            "reason": "Multi-factor authentication required",
                        },
                    )
                # Additional zero-trust checks could be added here
                break
        # Continue with request
        response = await call_next(request)
        # Log successful access for audit
        if response.status_code < 400:
            await self.security_monitor._log_security_event(
                event_type="successful_access",
                severity="info",
                source_ip=client_ip,
                user_id=user_id,
                resource=str(request.url),
                action=request.method,
            )
        return response

    def _get_client_ip(self, request: Request) -> str:
        """Extract the real client IP from the request.

        NOTE(review): X-Forwarded-For is client-controllable unless a
        trusted proxy strips/overwrites it — confirm deployment topology
        before relying on this value for blocking decisions.
        """
        # Check X-Forwarded-For header first (for proxies/load balancers)
        x_forwarded_for = request.headers.get("X-Forwarded-For")
        if x_forwarded_for:
            # Take the first IP in the chain (original client)
            return x_forwarded_for.split(",")[0].strip()
        # Fall back to direct connection
        return request.client.host if request.client else "unknown"

    def _has_required_permissions(
        self, user_permissions: list[str], required_permissions: list[str]
    ) -> bool:
        """Check that the user holds every required permission."""
        return all(perm in user_permissions for perm in required_permissions)
# Comprehensive Input Validation Middleware
class InputValidationMiddleware(BaseHTTPMiddleware):
    """Advanced input validation with security-focused rules.

    Validates JSON request bodies of mutating requests (POST/PUT/PATCH)
    against per-field rules, plus generic SQL-injection and XSS screening
    for string values.
    """

    def __init__(self, app):
        super().__init__(app)
        self.validation_rules: dict[str, list[InputValidationRule]] = {}
        # Load validation rules
        self._load_validation_rules()

    @staticmethod
    def _rule(
        field_name: str,
        validation_type: str,
        *,
        pattern: str | None = None,
        min_length: int | None = None,
        max_length: int | None = None,
        allowed_values: list[Any] | None = None,
        custom_validator: str | None = None,
    ) -> InputValidationRule:
        """Build an InputValidationRule with every optional field defaulted,
        so rule definitions stay terse even if the model requires all fields."""
        return InputValidationRule(
            field_name=field_name,
            validation_type=validation_type,
            pattern=pattern,
            min_length=min_length,
            max_length=max_length,
            allowed_values=allowed_values,
            custom_validator=custom_validator,
        )

    def _load_validation_rules(self):
        """Load comprehensive input validation rules."""
        self.validation_rules = {
            "user_id": [
                self._rule("user_id", "regex", pattern=r"^[a-zA-Z0-9_-]{1,50}$"),
                self._rule("user_id", "length", min_length=1, max_length=50),
            ],
            "email": [
                self._rule(
                    "email",
                    "regex",
                    pattern=r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$",
                ),
                # Fixed: this rule previously had field_name="length" and no
                # validation_type, so it could never apply.  254 chars is the
                # practical upper bound for an email address.
                self._rule("email", "length", max_length=254),
            ],
            "case_title": [
                self._rule("case_title", "length", min_length=1, max_length=200),
                self._rule(
                    "case_title", "regex", pattern=r"^[a-zA-Z0-9\s\-_.,!?()]+$"
                ),
            ],
        }

    async def dispatch(self, request: Request, call_next):
        """Validate JSON bodies of mutating requests before forwarding."""
        content_type = request.headers.get("content-type", "")
        # startswith, not ==: the header commonly carries parameters such
        # as "application/json; charset=utf-8".
        if request.method in ("POST", "PUT", "PATCH") and content_type.startswith(
            "application/json"
        ):
            try:
                # NOTE(review): reading the body inside BaseHTTPMiddleware —
                # confirm downstream handlers can still consume it in this
                # Starlette version.
                body = await request.json()
                # Validate against rules
                validation_errors = self._validate_input(body)
                if validation_errors:
                    return JSONResponse(
                        status_code=400,
                        content={
                            "error": "Input validation failed",
                            "validation_errors": validation_errors,
                        },
                    )
                # Store validated body for downstream use
                request.state.validated_body = body
            except json.JSONDecodeError:
                return JSONResponse(
                    status_code=400, content={"error": "Invalid JSON format"}
                )
            except Exception as e:
                logger.error(f"Input validation error: {e}")
                return JSONResponse(
                    status_code=400, content={"error": "Input validation failed"}
                )
        response = await call_next(request)
        return response

    def _validate_input(self, data: dict[str, Any]) -> list[str]:
        """Validate input data against security rules; return error strings."""
        errors: list[str] = []
        # Top-level JSON arrays/scalars have no named fields to validate;
        # previously they crashed into the generic except handler.
        if not isinstance(data, dict):
            return errors
        for field_name, value in data.items():
            if field_name in self.validation_rules:
                errors.extend(self._validate_field(field_name, value))
        return errors

    def _validate_field(self, field_name: str, value: Any) -> list[str]:
        """Validate a single field against its rules and generic attack patterns."""
        errors = []
        rules = self.validation_rules[field_name]
        for rule in rules:
            try:
                if rule.validation_type == "regex" and rule.pattern:
                    if not re.match(rule.pattern, str(value)):
                        errors.append(f"{field_name}: Invalid format")
                elif rule.validation_type == "length":
                    str_value = str(value)
                    # "is not None" so a configured min_length of 0 still applies.
                    if rule.min_length is not None and len(str_value) < rule.min_length:
                        errors.append(
                            f"{field_name}: Too short (minimum {rule.min_length})"
                        )
                    if rule.max_length is not None and len(str_value) > rule.max_length:
                        errors.append(
                            f"{field_name}: Too long (maximum {rule.max_length})"
                        )
                elif rule.validation_type == "type":
                    # Type validation not yet implemented.
                    pass
            except Exception as e:
                errors.append(f"{field_name}: Validation error - {e!s}")
        # Generic security screening, once per field (was inside the rule
        # loop, which duplicated these errors for fields with several rules).
        if isinstance(value, str):
            # Check for SQL injection patterns
            if re.search(
                r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER)\b)",
                value.upper(),
            ):
                errors.append(f"{field_name}: Potential SQL injection detected")
            # Check for XSS patterns
            if re.search(r"<script|<iframe|<object|<embed", value.lower()):
                errors.append(f"{field_name}: Potential XSS attack detected")
        return errors
# Enhanced Security Headers
class AdvancedSecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach hardened security headers (CSP with nonce, HSTS, COOP/COEP,
    Permissions-Policy, ...) to every response."""

    def __init__(self, app):
        super().__init__(app)
        # NOTE(review): this cache is never read or written anywhere in
        # this class — kept for interface compatibility; candidate for
        # removal.
        self.nonce_cache: dict[str, str] = {}

    async def dispatch(self, request: Request, call_next):
        """Run the request, then stamp security headers onto the response."""
        response = await call_next(request)
        # Fresh nonce per response for the CSP script-src directive.
        # NOTE(review): generated *after* the handler ran, so downstream
        # template rendering cannot actually embed it — confirm intended
        # usage of the X-Nonce header below.
        nonce = self._generate_nonce()
        # Enhanced security headers
        headers = {
            # Content Security Policy with nonce
            "Content-Security-Policy": (
                f"default-src 'self'; script-src 'self' 'nonce-{nonce}'; "
                f"style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; "
                f"font-src 'self' data:; connect-src 'self' wss: https:; frame-ancestors 'none';"
            ),
            # Security headers
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            # Additional security headers
            "X-Permitted-Cross-Domain-Policies": "none",
            "Cross-Origin-Embedder-Policy": "require-corp",
            "Cross-Origin-Opener-Policy": "same-origin",
            "Cross-Origin-Resource-Policy": "same-origin",
            # Feature policy restrictions
            "Permissions-Policy": (
                "camera=(), microphone=(), geolocation=(), payment=(), usb=(), "
                "magnetometer=(), accelerometer=(), gyroscope=(), ambient-light-sensor=(), "
                "autoplay=(), encrypted-media=(), fullscreen=(self), picture-in-picture=()"
            ),
        }
        # Apply headers to response
        for header_name, header_value in headers.items():
            response.headers[header_name] = header_value
        # Expose the nonce so templates/clients can reference it.
        response.headers["X-Nonce"] = nonce
        return response

    def _generate_nonce(self) -> str:
        """Generate a cryptographically secure, URL-safe nonce."""
        # `secrets` (not `random`) for security-sensitive tokens; imported
        # at module level rather than on every call.
        return secrets.token_urlsafe(16)
# Initialize security components
# Module-level singleton shared by the middlewares above (wired in main.py).
# NOTE(review): constructed at import time, and RuntimeSecurityMonitor's
# __init__ calls asyncio.create_task(), which requires a running event loop —
# confirm this module is only imported from async startup code.
security_monitor = RuntimeSecurityMonitor()
# Export for use in main.py
__all__ = [
    "AdvancedSecurityHeadersMiddleware",
    "InputValidationMiddleware",
    "RuntimeSecurityMonitor",
    "ZeroTrustMiddleware",
    "security_monitor",
]