| """ |
| API Middleware & Security Layer |
| - Global exception handling |
| - CORS configuration |
| - Security headers |
| - Request validation |
| - Logging middleware |
| """ |
|
|
| from fastapi import Request, Response |
| from fastapi.responses import JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.middleware.trustedhost import TrustedHostMiddleware |
| from fastapi.middleware.gzip import GZipMiddleware |
| import time |
| import logging |
| from typing import Callable |
| import json |
|
|
| from server.error_handler import log_error, log_audit, log_performance |
| from server.rate_limiter import check_rate_limit_middleware |
|
|
| |
# Configure root logging at INFO and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
class SecurityHeadersMiddleware:
    """Attach a fixed set of security headers to every HTTP response.

    Instances follow the ``(request, call_next)`` dispatch signature used by
    ``app.middleware("http")``: the downstream app is awaited first, then the
    headers below are written onto its response (overwriting any value the
    endpoint may already have set).
    """

    # Header name -> value, applied in insertion order.
    SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        # The legacy XSS auditor has been removed from modern browsers, and its
        # "1; mode=block" mode could itself be abused in older ones. OWASP and
        # MDN recommend explicitly disabling it ("0") and relying on CSP.
        "X-XSS-Protection": "0",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
        # NOTE(review): a 'self'-only policy may block externally hosted assets
        # (e.g. CDN-served API docs pages) — confirm this is intended.
        "Content-Security-Policy": "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
    }

    def __init__(self, app):
        # Kept for interface compatibility; dispatch does not use it.
        self.app = app

    async def __call__(self, request: Request, call_next: Callable) -> Response:
        """Run the downstream app and stamp the security headers on its response."""
        response = await call_next(request)
        for header_name, header_value in self.SECURITY_HEADERS.items():
            response.headers[header_name] = header_value
        return response
|
|
class RequestLoggingMiddleware:
    """Log each request/response pair and record its latency.

    Instances follow the ``(request, call_next)`` dispatch signature used by
    ``app.middleware("http")``. An exception escaping the downstream app is
    recorded via ``log_error`` and converted into a generic JSON 500 whose
    body carries the ``error_id`` instead of the exception text.
    """

    def __init__(self, app):
        # Kept for interface compatibility; dispatch does not use it.
        self.app = app

    @staticmethod
    def _resolve_user_id(request: Request):
        """Return the user id for a ``Bearer`` token, or None if it can't be resolved."""
        try:
            auth = request.headers.get('authorization', '')
            if not auth.startswith('Bearer '):
                return None
            # Imported lazily, as originally — presumably to avoid an import cycle.
            from server import users
            token = auth.split(' ', 1)[1].strip()
            return users.verify_token(token)
        except Exception:
            # Best effort only: a bad token must not break logging or the
            # request. `except Exception` (not a bare except) lets
            # CancelledError / KeyboardInterrupt / SystemExit propagate.
            logger.debug("Could not resolve user for request logging", exc_info=True)
            return None

    async def __call__(self, request: Request, call_next: Callable) -> Response:
        """Log the request, await the downstream app, then log status and latency."""
        # perf_counter is monotonic, so durations are immune to wall-clock jumps.
        start_time = time.perf_counter()
        user_id = self._resolve_user_id(request)
        method, path = request.method, request.url.path

        logger.info("β %s %s | User: %s", method, path, user_id)

        try:
            response = await call_next(request)
        except Exception as e:
            error_id = log_error(
                'REQUEST_ERROR',
                str(e),
                endpoint=path,
                user_id=user_id,
                status_code=500,
            )
            # logger.exception keeps the traceback in the server log.
            logger.exception("β %s %s | Error: %s | ID: %s", method, path, e, error_id)
            # Do not echo str(e) to the client (it can leak internals); the
            # error_id lets operators correlate with the logged details.
            return JSONResponse(
                {'ok': False, 'error': 'Internal server error', 'error_id': error_id},
                status_code=500,
            )

        response_time = (time.perf_counter() - start_time) * 1000
        log_performance(path, method, response_time, response.status_code, user_id)
        logger.info("β %s %s | %s | %.2fms", method, path, response.status_code, response_time)

        return response
|
|
class RateLimitMiddleware:
    """Reject requests that exceed the rate limit with an HTTP 429 JSON body.

    Requests to ``/health`` bypass the limiter entirely. Instances follow the
    ``(request, call_next)`` dispatch signature used by ``app.middleware("http")``.
    """

    def __init__(self, app):
        # Kept for interface compatibility; dispatch does not use it.
        self.app = app

    async def __call__(self, request: Request, call_next: Callable) -> Response:
        """Consult the rate limiter (except for /health) before forwarding."""
        is_exempt = request.url.path == '/health'
        if not is_exempt:
            allowed, error = check_rate_limit_middleware(request)
            if not allowed:
                # Clients are told to retry after a fixed 60 seconds.
                return JSONResponse(
                    {'ok': False, 'error': error},
                    status_code=429,
                    headers={'Retry-After': '60'},
                )
        return await call_next(request)
|
|
class InputValidationMiddleware:
    """Reject oversized request bodies and unsupported body content types.

    Instances follow the ``(request, call_next)`` dispatch signature used by
    ``app.middleware("http")``.

    NOTE: only the *declared* ``Content-Length`` is checked; a body sent
    without that header (e.g. chunked transfer encoding) is not size-limited
    here.
    """

    # Default maximum declared body size: 10 MiB.
    MAX_BODY_SIZE = 10 * 1024 * 1024

    # Accepted media types for POST/PUT/PATCH bodies. Matched as substrings of
    # the (lowercased) Content-Type header, so parameters such as
    # "; charset=utf-8" or "; boundary=..." are tolerated.
    ALLOWED_CONTENT_TYPES = (
        'application/json',
        'multipart/form-data',
        'application/x-www-form-urlencoded',
    )

    def __init__(self, app, max_body_size: int = MAX_BODY_SIZE):
        """Store the app and the byte limit enforced on Content-Length."""
        self.app = app
        self.max_body_size = max_body_size

    async def __call__(self, request: Request, call_next: Callable) -> Response:
        """Return 400/413/415 for invalid requests, otherwise forward them."""
        content_length = request.headers.get('content-length')
        if content_length:
            try:
                declared_size = int(content_length)
            except ValueError:
                # A malformed header is a client error, not a server crash.
                return JSONResponse(
                    {'ok': False, 'error': 'Invalid Content-Length header'},
                    status_code=400
                )
            if declared_size > self.max_body_size:
                return JSONResponse(
                    {'ok': False, 'error': 'Request body too large'},
                    status_code=413
                )

        if request.method in ('POST', 'PUT', 'PATCH'):
            # Media types are case-insensitive, so compare in lowercase.
            content_type = request.headers.get('content-type', '').lower()
            if content_type and not any(ct in content_type for ct in self.ALLOWED_CONTENT_TYPES):
                return JSONResponse(
                    {'ok': False, 'error': 'Invalid content type'},
                    status_code=415
                )

        return await call_next(request)
|
|
def setup_middleware(app, allowed_origins=None, allowed_hosts=None):
    """Register CORS, trusted-host, gzip and the custom HTTP middleware on *app*.

    Starlette wraps middleware so that the one registered LAST is OUTERMOST.
    The resulting stack, outermost to innermost, is:
    CORS -> RequestLogging -> RateLimit -> InputValidation -> SecurityHeaders
    -> GZip -> TrustedHost -> router.

    Args:
        app: The FastAPI application.
        allowed_origins: CORS allow-list; defaults to ``["*"]``.
        allowed_hosts: Host-header allow-list; defaults to
            ``["localhost", "127.0.0.1", "*.moharek.com"]``.
    """
    if allowed_origins is None:
        allowed_origins = ["*"]
    if allowed_hosts is None:
        allowed_hosts = ["localhost", "127.0.0.1", "*.moharek.com"]

    app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed_hosts)
    app.add_middleware(GZipMiddleware, minimum_size=1000)

    # The custom classes implement the ``(request, call_next)`` dispatch
    # signature, not the raw ASGI ``(scope, receive, send)`` interface.
    # Passing them to ``add_middleware`` makes Starlette call them with three
    # ASGI arguments, which raises TypeError on every request. Registering an
    # instance via ``app.middleware("http")`` wraps it in BaseHTTPMiddleware.
    # Listed innermost first.
    for middleware_cls in (
        SecurityHeadersMiddleware,
        InputValidationMiddleware,
        RateLimitMiddleware,
        RequestLoggingMiddleware,
    ):
        app.middleware("http")(middleware_cls(app))

    # CORS is registered last so it is outermost. Error responses produced by
    # the middleware above (429 / 413 / 415) then also carry CORS headers, so
    # browsers can read them and the exposed Retry-After / X-RateLimit-* headers.
    # NOTE(review): a wildcard origin combined with allow_credentials=True lets
    # any site make credentialed requests — pass explicit allowed_origins in production.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining", "Retry-After"],
    )
|
|
def setup_exception_handlers(app):
    """Register JSON exception handlers on *app*.

    Both handlers record the error with ``log_error`` and return a body of the
    form ``{'ok': False, 'error': ..., 'error_id': ...}``:

    * ``Exception`` -> 500 with a generic message (details stay server-side).
    * ``ValueError`` -> 400 with the exception message.
    """

    @app.exception_handler(Exception)
    async def global_exception_handler(request: Request, exc: Exception):
        """Return a generic 500 for any unhandled exception."""
        error_id = log_error(
            'UNHANDLED_EXCEPTION',
            str(exc),
            endpoint=request.url.path,
            status_code=500
        )

        # exc_info=exc preserves the traceback in the log even when this
        # handler is not running inside an active `except` block.
        logger.error("Unhandled exception: %s | Error ID: %s", exc, error_id, exc_info=exc)

        return JSONResponse(
            {
                'ok': False,
                'error': 'Internal server error',
                'error_id': error_id
            },
            status_code=500
        )

    @app.exception_handler(ValueError)
    async def value_error_handler(request: Request, exc: ValueError):
        """Return a 400 carrying the ValueError message."""
        # NOTE(review): every ValueError (including ones raised deep in
        # internal code) is reported to the client verbatim — confirm intended.
        error_id = log_error(
            'VALIDATION_ERROR',
            str(exc),
            endpoint=request.url.path,
            status_code=400
        )

        return JSONResponse(
            {
                'ok': False,
                'error': str(exc),
                'error_id': error_id
            },
            status_code=400
        )
|
|
def setup_startup_shutdown(app):
    """Register startup and shutdown event handlers on *app*.

    Each step is best-effort: a failure is logged (with traceback) and the
    remaining steps still run.

    NOTE(review): ``app.on_event`` is deprecated in recent FastAPI releases in
    favour of a ``lifespan`` handler — consider migrating.
    """

    @app.on_event("startup")
    async def startup():
        """Initialize databases, start monitoring, and report cache stats."""
        logger.info("π Starting GEO Platform API...")

        # Databases are initialized BEFORE the monitoring daemon starts, so the
        # daemon never runs against an uninitialized monitoring DB.
        try:
            from server.error_handler import init_error_db
            from server.monitoring import init_monitoring_db
            init_error_db()
            init_monitoring_db()
            logger.info("β Databases initialized")
        except Exception as e:
            logger.exception("β Failed to initialize databases: %s", e)

        try:
            from server.monitoring import start_monitoring
            start_monitoring()
            logger.info("β Monitoring daemon started")
        except Exception as e:
            logger.exception("β Failed to start monitoring: %s", e)

        try:
            from server.cache_manager import cache
            stats = cache.get_cache_stats()
            logger.info("β Cache initialized: %s", stats)
        except Exception as e:
            logger.exception("β Failed to initialize cache: %s", e)

        logger.info("β GEO Platform API is ready!")

    @app.on_event("shutdown")
    async def shutdown():
        """Stop the monitoring daemon and clear the cache."""
        logger.info("π Shutting down GEO Platform API...")

        try:
            from server.monitoring import stop_monitoring
            stop_monitoring()
            logger.info("β Monitoring daemon stopped")
        except Exception as e:
            logger.exception("β Failed to stop monitoring: %s", e)

        try:
            from server.cache_manager import cache
            cache.clear()
            logger.info("β Cache cleared")
        except Exception as e:
            logger.exception("β Failed to clear cache: %s", e)

        logger.info("β GEO Platform API shutdown complete")
|
|
| |
# Translation table that deletes every character sanitize_input strips.
_SANITIZE_DELETE_TABLE = str.maketrans("", "", "<>\"'&;")


def sanitize_input(data: str) -> str:
    """Remove ``< > " ' & ;`` from *data* and trim surrounding whitespace.

    Values that are not ``str`` are returned unchanged.

    NOTE: stripping characters is not a substitute for context-aware output
    escaping or parameterized queries.
    """
    if isinstance(data, str):
        # One C-level pass deletes all listed characters at once.
        return data.translate(_SANITIZE_DELETE_TABLE).strip()
    return data
|
|
def validate_url(url: str) -> bool:
    """Return True if *url* looks like an absolute ``http``/``https`` URL.

    The check is purely syntactic. It requires the lowercase scheme ``http``
    or ``https``, then ``://``, then at least two non-whitespace characters,
    the first of which is not ``/ $ . ? #``. Non-string input returns False.
    """
    import re
    if not isinstance(url, str):
        return False
    # fullmatch instead of match + '$': '$' also matches just before a
    # trailing newline, which let values like "http://a.com\n" validate.
    url_pattern = r'https?://[^\s/$.?#].[^\s]*'
    return re.fullmatch(url_pattern, url) is not None
|
|
def validate_email(email: str) -> bool:
    """Return True if *email* has a simple ``local@domain.tld`` shape.

    The local part and domain are restricted to common ASCII characters, and
    the TLD must be at least two letters. Non-string input returns False.
    """
    import re
    if not isinstance(email, str):
        return False
    # fullmatch instead of match + '$': '$' also matches just before a
    # trailing newline, which let values like "a@b.co\n" validate.
    email_pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    return re.fullmatch(email_pattern, email) is not None
|
|