# bookmyservice-ams/app/auth/middleware.py
# Author: MukeshKapoor25
# feat: add authentication, security middleware, and optimize JSON handling
# Commit: fd2ce9d
import asyncio
import time
from collections import defaultdict
from typing import Dict, List, Optional

from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from app.core.config import settings
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Stamp a fixed set of browser-security headers onto every response."""

    async def dispatch(self, request: Request, call_next):
        """Run the downstream handler and decorate the response it returns.

        Args:
            request: The incoming request; passed through untouched.
            call_next: Callable that forwards the request down the stack and
                resolves to the response.

        Returns:
            The downstream response with the security headers set. Any
            existing header with the same name is overwritten.

        Note:
            Headers are only applied when ``call_next`` returns; if it raises,
            the exception propagates and no headers are added here.
        """
        outgoing = await call_next(request)

        # (name, value) pairs, applied in this order.
        hardening = (
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("X-XSS-Protection", "1; mode=block"),
            ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
            (
                "Content-Security-Policy",
                "default-src 'self'; "
                "img-src 'self' data: https:; "
                "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; "
                "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; "
                "font-src 'self' data: https:",
            ),
        )
        for header_name, header_value in hardening:
            outgoing.headers[header_name] = header_value

        return outgoing
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiter keyed by client address, with memory cleanup.

    Each client key maps to the timestamps of its requests seen within the
    last ``period`` seconds. A request is rejected with HTTP 429 once a
    client already has ``calls`` timestamps in that window.

    State lives in a plain dict on this middleware instance, so limits are
    tracked per instance only.

    NOTE(review): the key is ``request.client.host``. Behind a reverse proxy
    this is presumably the proxy's address unless the server is configured to
    honour forwarded headers -- confirm against the deployment.
    """

    def __init__(
        self,
        app,
        calls: Optional[int] = None,
        period: Optional[int] = None,
        cleanup_interval: int = 300,
    ):
        """Configure the limiter.

        Args:
            app: The ASGI application being wrapped.
            calls: Maximum requests per client per window. Falsy values
                (``None`` or ``0``) fall back to ``settings.RATE_LIMIT_CALLS``.
            period: Window length in seconds. Falsy values fall back to
                ``settings.RATE_LIMIT_PERIOD``.
            cleanup_interval: Minimum number of seconds between sweeps that
                drop idle clients from memory (default 300, i.e. 5 minutes).
        """
        super().__init__(app)
        self.calls = calls or settings.RATE_LIMIT_CALLS
        self.period = period or settings.RATE_LIMIT_PERIOD
        self.cleanup_interval = cleanup_interval
        self.clients: Dict[str, List[float]] = defaultdict(list)
        self.last_cleanup = time.time()

    def _cleanup_inactive_clients(self, now: float):
        """Drop clients whose most recent request is older than the idle cutoff.

        The cutoff is the larger of ``cleanup_interval`` and ``period``: a
        client must never be evicted while it still has timestamps inside the
        rate window, otherwise the sweep would silently reset its quota.

        Args:
            now: Current time in seconds, as returned by ``time.time()``.
        """
        idle_cutoff = max(self.cleanup_interval, self.period)
        # Timestamps are appended in arrival order by dispatch(), so the last
        # element is the most recent request. Collect keys first: the dict
        # must not change size while it is being iterated.
        stale = [
            client_ip
            for client_ip, request_times in self.clients.items()
            if not request_times or now - request_times[-1] > idle_cutoff
        ]
        for client_ip in stale:
            del self.clients[client_ip]

    async def dispatch(self, request: Request, call_next):
        """Admit or reject the request based on the client's recent activity.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, or a JSON 429 response with a
            ``{"detail": ...}`` body when the client is over its limit.
        """
        # request.client is None when the server supplies no peer address;
        # bucket those requests together instead of raising AttributeError.
        client_ip = request.client.host if request.client else "unknown"
        now = time.time()

        # Periodic sweep of idle clients so the dict does not grow forever.
        if now - self.last_cleanup > self.cleanup_interval:
            self._cleanup_inactive_clients(now)
            self.last_cleanup = now

        # Keep only this client's timestamps that are still inside the window.
        recent = [
            req_time for req_time in self.clients[client_ip]
            if now - req_time < self.period
        ]
        self.clients[client_ip] = recent

        if len(recent) >= self.calls:
            # Return the response directly. An HTTPException raised from a
            # BaseHTTPMiddleware is not seen by the application's exception
            # handlers (they sit further inside the stack), so raising here
            # would surface to the client as a 500 rather than a 429.
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
            )

        # Rejected requests are intentionally not recorded, matching the
        # original accounting: only admitted requests consume quota.
        recent.append(now)
        return await call_next(request)
def add_security_middleware(app: FastAPI):
    """Register all security middleware on the FastAPI app.

    ``add_middleware`` places each new middleware *outside* the ones already
    registered, so the call order below is innermost first. The resulting
    request path is::

        CORS -> SecurityHeaders -> TrustedHost -> RateLimit -> application

    The rate limiter is registered first (innermost) so that any response it
    produces still travels back out through the security-header and CORS
    layers, and so that CORS preflight requests answered by the CORS layer
    never reach the limiter and do not consume a client's quota.

    Args:
        app: The application to configure; modified in place.
    """
    # Rate limiting middleware (innermost)
    app.add_middleware(RateLimitMiddleware)

    # Trusted host middleware
    app.add_middleware(
        TrustedHostMiddleware,
        allowed_hosts=settings.ALLOWED_HOSTS,
    )

    # Security headers middleware
    app.add_middleware(SecurityHeadersMiddleware)

    # CORS middleware (outermost, so every response passes through it)
    # NOTE(review): allow_credentials=True is combined with whatever
    # settings.CORS_ORIGINS contains -- confirm it is an explicit origin list
    # rather than a wildcard.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
        allow_headers=["*"],
    )