Spaces:
Sleeping
Sleeping
| from fastapi import Request, HTTPException, status | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from starlette.responses import Response, JSONResponse | |
| from config.settings import settings | |
| from auth.jwt_handler import verify_token | |
| import logging | |
| from typing import Callable, Awaitable | |
| logger = logging.getLogger(__name__) | |
| class AuthMiddleware(BaseHTTPMiddleware): | |
| """ | |
| Authentication middleware that handles both internal service-to-service | |
| and external user authentication. | |
| """ | |
| def __init__(self, app): | |
| super().__init__(app) | |
| self.jwt_bearer = JWTBearer(auto_error=False) | |
| async def dispatch( | |
| self, request: Request, call_next: Callable[[Request], Awaitable[Response]] | |
| ) -> Response: | |
| """ | |
| Dispatch the request, performing authentication. | |
| - If the Authorization header contains the internal service secret, | |
| the request is marked as internal and allowed to proceed. | |
| - Otherwise, it attempts to validate a user JWT. | |
| """ | |
| request.state.is_internal = False | |
| request.state.user = None | |
| auth_header = request.headers.get("Authorization") | |
| if auth_header: | |
| try: | |
| scheme, token = auth_header.split() | |
| if scheme.lower() == "bearer": | |
| # Check for internal service secret | |
| if token == settings.jwt_secret: | |
| request.state.is_internal = True | |
| logger.debug("Internal service request authenticated.") | |
| return await call_next(request) | |
| # If not the internal secret, try to validate as a user JWT | |
| token_payload = verify_token(token) | |
| if token_payload: | |
| request.state.user = token_payload | |
| logger.debug(f"User request authenticated: {token_payload}") | |
| else: | |
| # If token is invalid (but not the service secret) | |
| raise HTTPException( | |
| status_code=status.HTTP_401_UNAUTHORIZED, | |
| detail="Invalid or expired token", | |
| ) | |
| else: | |
| raise HTTPException( | |
| status_code=status.HTTP_401_UNAUTHORIZED, | |
| detail="Invalid authentication scheme", | |
| ) | |
| except HTTPException as e: | |
| return JSONResponse( | |
| status_code=e.status_code, content={"detail": e.detail} | |
| ) | |
| except Exception as e: | |
| logger.error(f"Authentication error: {e}", exc_info=True) | |
| return JSONResponse( | |
| status_code=status.HTTP_401_UNAUTHORIZED, | |
| content={"detail": "Could not validate credentials"}, | |
| ) | |
| # Let unprotected routes pass through | |
| return await call_next(request) | |
| class JWTBearer(HTTPBearer): | |
| """ | |
| Custom JWT Bearer authentication scheme for user-facing routes. | |
| """ | |
| def __init__(self, auto_error: bool = True): | |
| super(JWTBearer, self).__init__(auto_error=auto_error) | |
| async def __call__(self, request: Request): | |
| """ | |
| Validate token from request.state if already processed by middleware. | |
| """ | |
| if request.state.user: | |
| return request.state.user | |
| if self.auto_error: | |
| raise HTTPException( | |
| status_code=status.HTTP_401_UNAUTHORIZED, | |
| detail="Not authenticated", | |
| ) | |
| return None | |
| def get_current_user_id(request: Request) -> int: | |
| """ | |
| Dependency to get the current user ID from the request state. | |
| """ | |
| if request.state.is_internal: | |
| # For internal requests, trust the user_id from the URL path | |
| try: | |
| return int(request.path_params["user_id"]) | |
| except (KeyError, ValueError): | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail="user_id not found in URL for internal request" | |
| ) | |
| if request.state.user and "sub" in request.state.user: | |
| return int(request.state.user["sub"]) | |
| raise HTTPException( | |
| status_code=status.HTTP_401_UNAUTHORIZED, | |
| detail="Not authenticated" | |
| ) | |
| def validate_user_id_from_token(request: Request, url_user_id: int) -> bool: | |
| """ | |
| Validates that the user_id from the token matches the one in the URL, | |
| or bypasses the check for internal requests. | |
| """ | |
| if request.state.is_internal: | |
| return True | |
| token_user_id = get_current_user_id(request) | |
| if token_user_id != url_user_id: | |
| logger.warning( | |
| f"User ID mismatch - Token: {token_user_id}, URL: {url_user_id}, Path: {request.url.path}" | |
| ) | |
| raise HTTPException( | |
| status_code=status.HTTP_403_FORBIDDEN, | |
| detail="User ID in token does not match user ID in URL", | |
| ) | |
| return True |