"""Authentication middleware: requires a Bearer token on non-exempt routes."""

import logging
import re
from typing import Mapping, Optional

from fastapi.responses import JSONResponse
from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi.middleware.cors import CORSMiddleware

from src.utils import JWTUtil
from src.services import SessionService

logger = logging.getLogger(__name__)

# Exact request paths that may be accessed without an Authorization header.
EXEMPT_ROUTES = [
    "/",
    "/docs",
    "/openapi.json",
    "/api/v1/auth/login",
    "/api/v1/auth/logout",
    "/api/v1/auth/register",
]
# Additional unauthenticated paths, as regex patterns applied with re.match
# (anchored at the start of the path, not at the end).
EXEMPT_ROUTE_PATTERNS = []

# Challenge header attached to every 401 (RFC 7235 §3.1 requires one).
_BEARER_CHALLENGE = {"WWW-Authenticate": "Bearer"}


def _error_response(
    request: Request,
    status_code: int,
    detail,
    headers: Optional[Mapping[str, str]] = None,
) -> JSONResponse:
    """Build a JSON error response carrying CORS headers.

    The middleware short-circuits before the route runs, so the response is
    given CORS headers here to keep the error readable by browser clients.

    Args:
        request: The incoming request; its ``Origin`` header is echoed back.
        status_code: HTTP status code of the response.
        detail: Value placed under the ``"detail"`` key of the JSON body.
        headers: Optional extra headers (e.g. ones carried by an HTTPException).

    Returns:
        The JSONResponse to send to the client.
    """
    response = JSONResponse(
        status_code=status_code, content={"detail": detail}, headers=headers
    )
    # NOTE(review): echoing any Origin together with Allow-Credentials lets
    # every site read these error bodies with the user's credentials. The
    # bodies are only error details, but consider restricting this to the
    # app's CORS allow-list. (Browsers ignore "*" when credentials are used;
    # the "*" fallback only applies to non-CORS requests that send no Origin.)
    origin = request.headers.get("origin", "*")
    response.headers["Access-Control-Allow-Origin"] = origin
    response.headers["Access-Control-Allow-Credentials"] = "true"
    response.headers["Access-Control-Allow-Methods"] = "*"
    response.headers["Access-Control-Allow-Headers"] = "*"
    # The allowed origin varies per request, so shared caches must key on it.
    response.headers["Vary"] = "Origin"
    return response


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Require a valid Bearer token on every request to a non-exempt path.

    On success, the payload returned by ``JWTUtil.validate_jwt`` is stored on
    ``request.state.user`` for downstream handlers. On failure, a JSON error
    response with CORS headers is returned and the route is never called.
    """

    def __init__(self, app):
        super().__init__(app)
        self._jwt = JWTUtil()
        self._session_service = SessionService()

    async def dispatch(self, request: Request, call_next):
        """Authenticate the request (unless exempt) and forward it downstream."""
        # Preflight (OPTIONS) requests are passed through unauthenticated;
        # browsers do not attach credentials to them.
        if request.method == "OPTIONS" or not self._require_auth(request.url.path):
            return await call_next(request)

        try:
            request.state.user = await self._validate_api_key(
                request.headers.get("Authorization")
            )
        except HTTPException as exc:
            return _error_response(
                request, exc.status_code, exc.detail, getattr(exc, "headers", None)
            )
        except Exception:
            # Don't leak internals to the client, but don't lose the error.
            logger.exception(
                "Unexpected error while authenticating %s %s",
                request.method,
                request.url.path,
            )
            return _error_response(request, 500, "Internal server error")

        return await call_next(request)

    def _require_auth(self, path: str) -> bool:
        """Return True unless *path* is in EXEMPT_ROUTES or matches an exempt pattern."""
        is_exempt = path in EXEMPT_ROUTES or any(
            re.match(pattern, path) for pattern in EXEMPT_ROUTE_PATTERNS
        )
        return not is_exempt

    async def _validate_api_key(self, api_key: Optional[str]):
        """Validate an ``Authorization`` header value of the form ``Bearer <token>``.

        Args:
            api_key: Raw Authorization header value, or None if it was absent.

        Returns:
            The (truthy) payload returned by ``JWTUtil.validate_jwt``.

        Raises:
            HTTPException: 401 if the header is missing or malformed, the session
                service reports the token's session as expired, or the JWT
                fails validation.
        """
        if not api_key:
            raise HTTPException(
                status_code=401,
                detail="No API key provided",
                headers=_BEARER_CHALLENGE,
            )

        # Expect exactly "<scheme> <token>". split() tolerates extra whitespace
        # and rejects a blank token, which the old regex let through as "".
        # The scheme is compared case-insensitively (RFC 7235 §2.1).
        parts = api_key.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            raise HTTPException(
                status_code=401, detail="Invalid API key", headers=_BEARER_CHALLENGE
            )
        token = parts[1]

        # NOTE(review): the session lookup runs before signature validation, so
        # unverified tokens reach the session service; order kept as-is because
        # it decides which 401 detail a client sees.
        is_expired = await self._session_service.is_session_expired(token)
        if is_expired:
            raise HTTPException(
                status_code=401, detail="Token expired", headers=_BEARER_CHALLENGE
            )

        jwt_payload = self._jwt.validate_jwt(token)
        if not jwt_payload:
            raise HTTPException(
                status_code=401, detail="Invalid token", headers=_BEARER_CHALLENGE
            )
        return jwt_payload