Spaces:
Sleeping
Sleeping
File size: 5,140 Bytes
697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 6a3de9e 697c967 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 | from fastapi import Request, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response, JSONResponse
from config.settings import settings
from auth.jwt_handler import verify_token
import logging
from typing import Callable, Awaitable
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Authentication middleware that handles both internal service-to-service
    and external user authentication.

    Every request starts with ``request.state.is_internal = False`` and
    ``request.state.user = None``. The ``Authorization`` header then decides
    the outcome:

    - No header: the request passes through unauthenticated; route-level
      dependencies decide whether that is acceptable.
    - ``Bearer <token>`` where ``<token>`` equals ``settings.jwt_secret``: the
      request is marked internal and passed through.
    - ``Bearer <token>`` accepted by ``verify_token``: the returned payload is
      stored in ``request.state.user`` and the request is passed through.
    - Anything else: a 401 JSON response is returned and the downstream app
      is never called.
    """

    def __init__(self, app):
        super().__init__(app)
        # Kept for backward compatibility; dispatch itself does not use it.
        self.jwt_bearer = JWTBearer(auto_error=False)

    @staticmethod
    def _unauthorized(detail: str) -> JSONResponse:
        """Build the 401 JSON response used for every authentication failure."""
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED, content={"detail": detail}
        )

    @staticmethod
    def _is_service_secret(token: str) -> bool:
        """Return True if *token* equals the configured internal service secret.

        The comparison is constant-time, so response timing does not reveal
        how much of a guessed secret was correct. Both sides are compared as
        UTF-8 bytes because ``hmac.compare_digest`` rejects non-ASCII ``str``.
        """
        import hmac  # local import keeps this fix self-contained

        secret = settings.jwt_secret
        # Only a non-empty plain str can match, exactly as with the former
        # ``token == settings.jwt_secret``. Never stringify other types
        # (a masked secret wrapper would stringify to a guessable value).
        if not isinstance(secret, str) or not secret:
            return False
        # NOTE(review): settings.jwt_secret is presumably also the JWT signing
        # key. If so, every internal caller holds the key and can mint user
        # tokens. Consider a dedicated service-credential setting.
        return hmac.compare_digest(token.encode("utf-8"), secret.encode("utf-8"))

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """
        Authenticate the request, then hand it to the downstream app.

        - If the Authorization header carries the internal service secret,
          the request is marked internal and allowed to proceed.
        - Otherwise, the bearer token is validated as a user JWT.

        Only the authentication step is guarded. Exceptions raised by
        ``call_next`` propagate unchanged rather than being reported as a 401.
        """
        request.state.is_internal = False
        request.state.user = None

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            # Let unprotected routes pass through
            return await call_next(request)

        parts = auth_header.split()
        if len(parts) != 2:
            # A malformed header (e.g. "Bearer" with no token) is a client
            # error, not a server fault, so no traceback is logged.
            logger.warning("Malformed Authorization header.")
            return self._unauthorized("Could not validate credentials")
        scheme, token = parts

        if scheme.lower() != "bearer":
            return self._unauthorized("Invalid authentication scheme")

        # Check for internal service secret
        if self._is_service_secret(token):
            request.state.is_internal = True
            logger.debug("Internal service request authenticated.")
            return await call_next(request)

        # If not the internal secret, try to validate as a user JWT
        try:
            token_payload = verify_token(token)
        except Exception as exc:
            # verify_token's failure modes are not visible here. Treat any
            # error as a failed authentication rather than a 500.
            logger.exception("Authentication error: %s", exc)
            return self._unauthorized("Could not validate credentials")

        if not token_payload:
            return self._unauthorized("Invalid or expired token")

        request.state.user = token_payload
        logger.debug("User request authenticated: %s", token_payload)
        return await call_next(request)
class JWTBearer(HTTPBearer):
    """
    Custom JWT Bearer authentication scheme for user-facing routes.

    It does not parse the Authorization header itself. Instead it returns the
    payload already stored in ``request.state.user`` by the authentication
    middleware.
    """

    def __init__(self, auto_error: bool = True):
        super().__init__(auto_error=auto_error)

    async def __call__(self, request: Request):
        """
        Return the authenticated user's token payload from ``request.state``.

        Returns:
            The truthy payload in ``request.state.user``, or ``None`` when no
            user is authenticated and ``auto_error`` is False.

        Raises:
            HTTPException: 401 when no user is authenticated and
                ``auto_error`` is True.
        """
        # ``request.state`` has no ``user`` attribute if the middleware never
        # ran for this request. Treat that as "not authenticated" instead of
        # failing with AttributeError (a 500).
        user = getattr(request.state, "user", None)
        if user:
            return user
        if self.auto_error:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Not authenticated",
            )
        return None
def get_current_user_id(request: Request) -> int:
    """
    Dependency to get the current user ID from the request state.

    - Internal requests (``request.state.is_internal`` truthy): the
      ``user_id`` path parameter is trusted and returned as an int.
    - Otherwise: the ``sub`` claim of ``request.state.user`` is returned as
      an int.

    Raises:
        HTTPException: 400 if an internal request lacks an integer
            ``user_id`` path parameter; 401 if there is no authenticated
            user or its ``sub`` claim is not an integer.
    """
    state = request.state
    # getattr: the state attributes are missing if the middleware never ran.
    if getattr(state, "is_internal", False):
        # For internal requests, trust the user_id from the URL path
        try:
            return int(request.path_params["user_id"])
        except (KeyError, TypeError, ValueError):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="user_id not found in URL for internal request",
            ) from None

    user = getattr(state, "user", None)
    if user and "sub" in user:
        try:
            return int(user["sub"])
        except (TypeError, ValueError):
            # A non-integer subject is an authentication failure, not a
            # server error. Fall through to the 401 below.
            pass

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated",
    )
def validate_user_id_from_token(request: Request, url_user_id: int) -> bool:
    """
    Validates that the user_id from the token matches the one in the URL,
    or bypasses the check for internal requests.

    Returns:
        True when the check passes (always True for internal requests).

    Raises:
        HTTPException: 403 when the caller's user id differs from
            ``url_user_id``. Errors from ``get_current_user_id`` (e.g. 401
            for unauthenticated requests) propagate unchanged.
    """
    # getattr: a request the middleware never processed is not internal.
    if getattr(request.state, "is_internal", False):
        return True

    token_user_id = get_current_user_id(request)
    if token_user_id != url_user_id:
        logger.warning(
            "User ID mismatch - Token: %s, URL: %s, Path: %s",
            token_user_id,
            url_user_id,
            request.url.path,
        )
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User ID in token does not match user ID in URL",
        )
    return True