"""
Auth Middleware - Request authentication layer

Intercepts requests to validate JWT tokens and attach the authenticated user
to request.state for use in route handlers.
"""

import logging
from typing import Optional

from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware

from core.database import async_session_maker
from core.models import User
from core.api_response import error_response, ErrorCode
from services.auth_service.config import AuthServiceConfig
from services.auth_service.jwt_provider import (
    verify_access_token,
    TokenExpiredError,
    InvalidTokenError,
    JWTError,
)
from services.base_service.middleware_chain import (
    BaseServiceMiddleware,
    get_request_context,
)

logger = logging.getLogger(__name__)


class AuthMiddleware(BaseServiceMiddleware):
    """
    Authentication middleware for request validation.

    Flow:
    1. Check if route requires/allows auth based on URL
    2. Extract Authorization header
    3. Verify JWT token
    4. Load user from database
    5. Attach user to request.state.user
    6. Continue to next middleware/route

    Public routes skip all auth checks.
    Required routes must have valid auth or return 401.
    Optional routes attach user if auth is provided, but don't fail if
    missing or invalid; the request simply continues unauthenticated.
    """

    SERVICE_NAME = "auth"

    async def dispatch(self, request: Request, call_next):
        """
        Process request through auth middleware.

        Args:
            request: Incoming request. On return, ``request.state.user`` is
                either the loaded ``User`` or ``None``.
            call_next: Callable that forwards the request to the next
                middleware/route and returns its response.

        Returns:
            The downstream response, or a 401 ``JSONResponse`` when the route
            requires auth and the credentials are missing or rejected.
        """
        # Skip OPTIONS requests (CORS preflight)
        if request.method == "OPTIONS":
            return await call_next(request)

        ctx = get_request_context(request)
        path = request.url.path

        # Public routes skip all auth checks.
        if AuthServiceConfig.is_public(path):
            self.log_request(request, "Public route, skipping auth")
            return await self._proceed_unauthenticated(request, ctx, call_next)

        requires_auth = AuthServiceConfig.requires_auth(path)
        allows_optional = AuthServiceConfig.allows_optional_auth(path)

        # Route is neither auth-required nor auth-optional: nothing to do.
        if not requires_auth and not allows_optional:
            self.log_request(request, "Route not configured for auth, skipping")
            return await self._proceed_unauthenticated(request, ctx, call_next)

        # From here on, every failure follows the same rule: reject with 401
        # when auth is required, otherwise continue as an anonymous request.
        auth_header = request.headers.get("Authorization")

        if not auth_header:
            if requires_auth:
                self.log_request(request, "Missing Authorization header (required)")
                return self._reject(
                    ErrorCode.UNAUTHORIZED, "Missing Authorization header"
                )
            self.log_request(request, "No auth header (optional route)")
            return await self._proceed_unauthenticated(request, ctx, call_next)

        if not auth_header.startswith("Bearer "):
            if requires_auth:
                self.log_request(request, "Invalid Authorization header format")
                return self._reject(
                    ErrorCode.TOKEN_INVALID,
                    "Invalid Authorization header format. Use: Bearer <token>",
                )
            return await self._proceed_unauthenticated(request, ctx, call_next)

        # Everything after the first space is the token.
        token = auth_header.split(" ", 1)[1]

        try:
            payload = verify_access_token(token)
        except TokenExpiredError:
            if requires_auth:
                self.log_request(request, "Token expired")
                return self._reject(
                    ErrorCode.TOKEN_EXPIRED,
                    "Token has expired. Please sign in again.",
                )
            return await self._proceed_unauthenticated(request, ctx, call_next)
        except (InvalidTokenError, JWTError) as e:
            if requires_auth:
                self.log_error(request, f"Token verification failed: {e}")
                return self._reject(
                    ErrorCode.TOKEN_INVALID, f"Invalid token: {str(e)}"
                )
            return await self._proceed_unauthenticated(request, ctx, call_next)

        # The lookup uses its own short-lived session, which is closed before
        # the request continues downstream.
        user = await self._load_active_user(payload.user_id)

        if not user:
            if requires_auth:
                self.log_request(request, "User not found or inactive")
                return self._reject(
                    ErrorCode.USER_NOT_FOUND, "User not found or inactive"
                )
            return await self._proceed_unauthenticated(request, ctx, call_next)

        # A token carrying an older version than the user's current one has
        # been invalidated.
        if payload.token_version < user.token_version:
            if requires_auth:
                self.log_request(
                    request,
                    f"Token invalidated (version {payload.token_version} < {user.token_version})",
                )
                return self._reject(
                    ErrorCode.TOKEN_INVALID,
                    "Token has been invalidated. Please sign in again.",
                )
            return await self._proceed_unauthenticated(request, ctx, call_next)

        # Attach user to request state
        request.state.user = user
        ctx.set_user(user)

        # Expose the admin flag on both the request state and the context.
        is_admin = AuthServiceConfig.is_admin(user.email)
        request.state.is_admin = is_admin
        ctx.set_flag('is_admin', is_admin)

        self.log_request(request, f"Authenticated user: {user.user_id}")

        # Continue to next middleware/route
        return await call_next(request)

    async def _proceed_unauthenticated(self, request: Request, ctx, call_next):
        """Mark the request as anonymous and forward it downstream."""
        request.state.user = None
        ctx.user = None
        ctx.is_authenticated = False
        return await call_next(request)

    @staticmethod
    def _reject(code, message: str) -> JSONResponse:
        """Build a 401 response carrying a Bearer challenge header."""
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content=error_response(code, message),
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def _load_active_user(self, user_id) -> Optional[User]:
        """
        Fetch the active user with the given id, or None if there is none.

        The session is opened and closed inside this method, so the returned
        object is no longer bound to a session. Attributes loaded by the
        query stay readable.
        """
        async with async_session_maker() as db:
            query = select(User).where(
                User.user_id == user_id,
                # Builds a SQL expression, not a Python comparison.
                User.is_active == True,  # noqa: E712
            )
            result = await db.execute(query)
            return result.scalar_one_or_none()


__all__ = ['AuthMiddleware']