import os
import re
import time
import uuid

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
from starlette.middleware.base import BaseHTTPMiddleware

from app.constants import DEPRECATED_ENDPOINTS
from app.exceptions import setup_exception_handlers
from app.middleware.deprecated_monitor import DeprecatedEndpointMonitor
from app.middleware.security import ZeroTrustMiddleware
from app.services.infrastructure.apm_service import APMMiddleware
from app.services.infrastructure.security.audit_service import audit_service
from core.csrf_protection import CSRFProtectionMiddleware
from core.logging import log_error, log_request, logger
from core.performance import PerformanceMonitoringMiddleware
from core.rate_limiting import RateLimitingMiddleware

# Middleware imports
from core.validation import InputValidationMiddleware
from middleware.request_id import RequestIDMiddleware

# Prefix shared by versioned API routes; stripped to derive the audited resource type.
_API_PREFIX = "/api/v1/"

# Route prefixes whose requests are audited as data access / data modification.
_DATA_ROUTE_PREFIXES = ("/api/v1/cases", "/api/v1/transactions", "/api/v1/alerts")

# Regex fragment substituted for "*" in wildcard CORS origins: exactly one DNS
# label, so "https://*.hf.space" matches "https://foo.hf.space" but not
# "https://a.b.hf.space" or "https://evil.com/.hf.space".
_ORIGIN_WILDCARD_LABEL = r"[A-Za-z0-9-]+"

# Hosts accepted by TrustedHostMiddleware in non-development environments
# unless overridden by the TRUSTED_HOSTS environment variable.
_DEFAULT_TRUSTED_HOSTS = (
    "api.zenith.com",
    "localhost",
    "testserver",
    "testclient",
)


# Security headers middleware
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a fixed set of security-related headers to every response."""

    async def dispatch(self, request, call_next):
        response = await call_next(request)
        response.headers["X-Content-Type-Options"] = "nosniff"
        # X-Frame-Options is removed to allow iframe embedding on Hugging Face
        # Spaces; framing is instead restricted by CSP frame-ancestors below.
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Strict-Transport-Security"] = (
            "max-age=31536000; includeSubDomains"
        )
        response.headers["Content-Security-Policy"] = (
            "default-src 'self' https://huggingface.co; "
            "script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
            "style-src 'self' 'unsafe-inline'; "
            "img-src 'self' data: https:; "
            "font-src 'self' data:; "
            "frame-ancestors 'self' https://*.hf.space https://huggingface.co;"
        )
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        return response


def _classify_action(method: str, path: str) -> str:
    """Map an HTTP method and path to the audit action category."""
    if path.startswith("/api/v1/auth"):
        return "login" if method == "POST" and "login" in path else "auth_access"
    if path.startswith(_DATA_ROUTE_PREFIXES):
        return "data_access" if method == "GET" else "data_modification"
    if path.startswith("/api/v1/admin"):
        return "admin_operation"
    return "api_access"


def _resource_type(path: str) -> str:
    """Return the first path segment after the API prefix, or "api" if there is none."""
    remainder = path[len(_API_PREFIX):] if path.startswith(_API_PREFIX) else path
    return remainder.split("/")[0] or "api"


def _deprecation_details(info) -> dict:
    """Summarise a DEPRECATED_ENDPOINTS entry for the audit log.

    Missing keys become None instead of raising KeyError, so an incomplete
    entry cannot break request handling.
    """
    return {
        "deprecated_since": info.get("deprecated_since"),
        "removal_version": info.get("removal_version"),
        "migration_guide": info.get("migration_guide"),
        "replacement": info.get("replacement"),
    }


def _add_deprecation_headers(response, info) -> None:
    """Add deprecation warning headers without altering the response status."""
    response.headers["X-Deprecated-Endpoint"] = "true"
    message = f"Deprecated since {info.get('deprecated_since')}."
    replacement = info.get("replacement")
    if replacement:
        message += f" Use {replacement} instead."
    response.headers["X-Deprecation-Info"] = message
    migration_guide = info.get("migration_guide")
    # Header values must be strings; skip the header rather than emit "None".
    if migration_guide:
        response.headers["X-Migration-Guide"] = str(migration_guide)


async def request_logging_middleware(request: Request, call_next):
    """
    Log every request to the application log and the persistent audit log,
    and flag calls to deprecated endpoints.

    Deprecated endpoints get warning headers (X-Deprecated-Endpoint,
    X-Deprecation-Info and, when available, X-Migration-Guide); the
    downstream status code is passed through unchanged.

    If the downstream handler raises, the failure is logged and audited as
    "failed_<action>" with status 500, and the original exception is re-raised.
    """
    # Prefer the ID set by RequestIDMiddleware (registered outermost in
    # setup_middleware); fall back to a short random ID.
    request_id = getattr(request.state, "request_id", None) or str(uuid.uuid4())[:8]
    start_time = time.perf_counter()

    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")
    method = request.method
    path = request.url.path
    query_params = str(request.query_params)

    deprecated_info = DEPRECATED_ENDPOINTS.get(path)
    if deprecated_info:
        logger.warning(
            f"Deprecated endpoint accessed: {path}",
            extra={
                "deprecated_endpoint": path,
                "deprecated_since": deprecated_info.get("deprecated_since"),
                "replacement": deprecated_info.get("replacement"),
                "request_id": request_id,
                "client_ip": client_ip,
            },
        )

    # NOTE(review): placeholder identity — the bearer token is neither decoded
    # nor verified, so every bearer request is attributed to the same user.
    user_id = None
    auth_header = request.headers.get("authorization")
    if auth_header and auth_header.startswith("Bearer "):
        user_id = "authenticated_user"

    # NOTE(review): this is generated per request, so it does not correlate
    # requests belonging to the same client session.
    session_id = str(uuid.uuid4())

    action = _classify_action(method, path)
    resource_type = _resource_type(path)
    resource_id = None

    details = {
        "method": method,
        "path": path,
        "query_params": query_params[:500] if query_params else None,  # Limit size
        "user_agent": user_agent[:200],
        "session_id": session_id,
        "request_id": request_id,
    }
    if deprecated_info:
        details["deprecated_endpoint"] = _deprecation_details(deprecated_info)

    def write_audit(status_code: int, duration: float, audit_action: str) -> None:
        # Persist one audit record; reads `details` at call time so response
        # fields added below are included.
        audit_service.log_request(
            user_id=user_id,
            session_id=session_id,
            method=method,
            endpoint=path,
            status_code=status_code,
            processing_time=duration,
            details={
                **details,
                "action": audit_action,
                "resource_type": resource_type,
                "resource_id": resource_id,
            },
            ip_address=client_ip,
            user_agent=user_agent,
        )

    # Only the downstream call is guarded: failures in the logging/audit code
    # below must not be misreported as a failed request.
    try:
        response = await call_next(request)
    except Exception as e:
        duration = time.perf_counter() - start_time
        log_error(
            "request_failed",
            f"Request failed: {e!s}",
            {
                "request_id": request_id,
                "method": method,
                "path": path,
                "duration": duration,
            },
        )
        try:
            write_audit(500, duration, f"failed_{action}")
        except Exception as audit_exc:
            # Keep the original exception as the one that propagates.
            logger.error(f"Failed to audit failed request {request_id}: {audit_exc!s}")
        raise

    duration = time.perf_counter() - start_time

    response.headers["X-Request-ID"] = request_id
    if deprecated_info:
        _add_deprecation_headers(response, deprecated_info)

    details.update(
        {
            "status_code": response.status_code,
            "response_time": round(duration, 3),
            "success": response.status_code < 400,
        }
    )

    log_request(
        request_id=request_id,
        method=method,
        path=path,
        status_code=response.status_code,
        duration=duration,
    )

    # Audit failures on the success path still propagate, matching the
    # original fail-closed behaviour for the audit trail.
    write_audit(response.status_code, duration, action)
    return response


def _trusted_hosts() -> list:
    """Return hosts for TrustedHostMiddleware: TRUSTED_HOSTS (comma-separated) or the defaults."""
    env_hosts = os.getenv("TRUSTED_HOSTS")
    if env_hosts:
        hosts = [host.strip() for host in env_hosts.split(",") if host.strip()]
        if hosts:
            return hosts
    return list(_DEFAULT_TRUSTED_HOSTS)


def _resolve_cors_origins(is_development: bool) -> list:
    """Return the configured CORS origins; entries may be "*" or contain "*" wildcards.

    CORS_ALLOWED_ORIGINS (comma-separated, or "*") takes precedence in every
    environment; otherwise local dev origins or known production origins are used.
    """
    env_origins = os.getenv("CORS_ALLOWED_ORIGINS")
    if env_origins:
        origins = [origin.strip() for origin in env_origins.split(",") if origin.strip()]
        if "*" in origins:
            # With allow_credentials=True, Starlette reflects the caller's
            # Origin for credentialed requests when "*" is configured.
            if not is_development:
                logger.warning(
                    "CORS: Using wildcard '*' in production is not recommended"
                )
            return ["*"]
        return origins

    if is_development:
        # Development fallback - local development origins
        return [
            "http://localhost:5173",
            "http://localhost:5174",
            "http://localhost:5175",
            "http://127.0.0.1:5173",
            "http://127.0.0.1:5174",
            "http://127.0.0.1:3000",
            "http://localhost:3000",
        ]

    # Production: restrict to known frontend domains (Hugging Face Spaces and
    # production deployments).
    production_origins = [
        "https://huggingface.co",
        "https://*.hf.space",
        "https://zenith.com",
        "https://www.zenith.com",
    ]
    hf_space = os.getenv("HF_SPACE")
    if hf_space:
        production_origins.append(f"https://{hf_space}.hf.space")
    logger.info(
        f"CORS: Production mode - restricting to allowed origins: {production_origins}"
    )
    return production_origins


def _split_origins(origins):
    """Split origins into (exact_origins, regex_or_None).

    Starlette's CORSMiddleware only matches ``allow_origins`` exactly (or a
    bare "*"), so entries like "https://*.hf.space" are translated into an
    anchored ``allow_origin_regex`` where each "*" matches one DNS label.
    """
    exact = []
    patterns = []
    for origin in origins:
        if origin != "*" and "*" in origin:
            patterns.append(re.escape(origin).replace(r"\*", _ORIGIN_WILDCARD_LABEL))
        else:
            exact.append(origin)
    regex = "^(?:" + "|".join(patterns) + ")$" if patterns else None
    return exact, regex


def setup_middleware(app: FastAPI):
    """Register exception handlers and the full middleware stack on ``app``.

    Starlette wraps middleware in reverse registration order: the LAST
    middleware added is the OUTERMOST and sees the request first.
    """
    setup_exception_handlers(app)

    environment = os.getenv("ENVIRONMENT", "development").lower()
    is_development = environment == "development"

    # Security middleware - only outside development
    if not is_development:
        app.add_middleware(HTTPSRedirectMiddleware)
        app.add_middleware(TrustedHostMiddleware, allowed_hosts=_trusted_hosts())

    allowed_origins, origin_regex = _split_origins(_resolve_cors_origins(is_development))
    if origin_regex:
        logger.info(f"CORS: wildcard origins matched by regex: {origin_regex}")

    # NOTE(review): CORS is registered early, so it sits INSIDE the
    # rate-limiting, zero-trust and CSRF middleware below; preflight OPTIONS
    # requests reach those first — confirm they exempt OPTIONS.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_origin_regex=origin_regex,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        allow_headers=[
            "Authorization",
            "Content-Type",
            "X-Requested-With",
            "Accept",
            "Accept-Encoding",
            "Accept-Language",
        ],
        max_age=86400,  # 24 hours
    )

    # Add security monitoring middleware
    app.add_middleware(RateLimitingMiddleware)

    # Rate limiting (legacy, keep for compatibility)
    limiter = Limiter(key_func=get_remote_address)
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
    app.add_middleware(SlowAPIMiddleware)

    # APM Monitoring
    app.add_middleware(APMMiddleware)

    # Performance monitoring middleware for Prometheus metrics
    app.add_middleware(PerformanceMonitoringMiddleware)

    # Security Headers
    app.add_middleware(SecurityHeadersMiddleware)

    # Response compression middleware
    app.add_middleware(GZipMiddleware, minimum_size=1000)

    # Input validation middleware
    app.add_middleware(InputValidationMiddleware)

    # Zero-Trust Implementation: Strict API Key validation
    app.add_middleware(ZeroTrustMiddleware)

    # CSRF protection middleware - RE-ENABLED for production security
    app.add_middleware(CSRFProtectionMiddleware)

    # Deprecated endpoint monitoring
    app.add_middleware(DeprecatedEndpointMonitor)

    # Request logging middleware
    app.middleware("http")(request_logging_middleware)

    # Request ID middleware - distributed tracing. Registered LAST so it is the
    # outermost layer and the request ID already exists when the logging
    # middleware above runs.
    app.add_middleware(RequestIDMiddleware)