Spaces:
Sleeping
Sleeping
| from fastapi import Request, HTTPException, status | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from typing import Optional | |
| from src.utils.security import verify_token | |
| from src.models.user import User | |
| from sqlalchemy.orm import Session | |
| from src.database import get_db_session | |
class JWTBearer(HTTPBearer):
    """
    JWT Bearer token authentication dependency.

    Extends FastAPI's ``HTTPBearer``. After the base class has parsed the
    ``Authorization`` header, the token is checked with ``verify_token``.
    The token's ``sub`` claim is then stored on ``request.state.user_id``
    for later use.
    """

    def __init__(self, auto_error: bool = True):
        """
        Args:
            auto_error: Forwarded unchanged to ``HTTPBearer``. When False, the
                base class returns None instead of raising on a missing or
                non-bearer ``Authorization`` header. That is what makes the
                ``access_token`` cookie fallback in ``__call__`` reachable.
        """
        super().__init__(auto_error=auto_error)

    async def __call__(self, request: Request) -> Optional[str]:
        """
        Validate the JWT carried by the request and return it.

        The token is taken from the ``Authorization: Bearer <token>`` header.
        If the base class yields no credentials (only possible when
        ``auto_error`` is False), the ``access_token`` cookie is used instead.

        Args:
            request: FastAPI request object.

        Returns:
            str: The validated raw token. As a side effect,
            ``request.state.user_id`` is set to the token's ``sub`` claim.
            Although annotated ``Optional[str]``, this method never returns
            None: every failure below raises.

        Raises:
            HTTPException: 403 in the following cases:
                - the header scheme is not bearer;
                - no token is found in the header or cookie;
                - ``verify_token`` returns None;
                - the payload has no truthy ``sub`` claim.
                These are raised regardless of ``auto_error``. ``HTTPBearer``
                itself may also raise when ``auto_error`` is True.
        """
        # NOTE(review): with auto_error=True, HTTPBearer raises on a missing
        # header before reaching the cookie fallback below. So cookie auth only
        # works with auto_error=False. Making it unconditional would widen the
        # auth surface (cookie auth needs CSRF protection) — confirm intent first.
        credentials: Optional[HTTPAuthorizationCredentials] = await super().__call__(request)
        if credentials:
            # Auth schemes are case-insensitive (RFC 7235 §2.1), and HTTPBearer
            # already accepts any casing of "bearer". An exact "Bearer" match
            # would wrongly reject clients that send e.g. "bearer <token>".
            if credentials.scheme.lower() != "bearer":
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Invalid authentication scheme."
                )
            token = credentials.credentials
        else:
            # Try to get token from cookie as fallback
            token = request.cookies.get("access_token")
            if not token:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="No token provided."
                )

        # Verify the token; verify_token signals failure by returning None.
        payload = verify_token(token)
        if payload is None:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Invalid or expired token."
            )

        # Add user ID to request state for later use.
        user_id = payload.get("sub")
        if not user_id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Invalid token payload."
            )
        request.state.user_id = user_id
        return token
def verify_user_access(user_id: str, authenticated_user_id: str) -> bool:
    """
    Check whether a request targets the authenticated user's own resources.

    Args:
        user_id: The user ID taken from the request path/params.
        authenticated_user_id: The user ID taken from the JWT token.

    Returns:
        bool: The result of ``user_id == authenticated_user_id``. No type
        coercion is applied, so e.g. ``"1"`` and ``1`` do not match.
    """
    is_same_user = user_id == authenticated_user_id
    return is_same_user
def get_current_user_from_request(request: Request) -> str:
    """
    Return the authenticated user ID stored on ``request.state``.

    Expects ``request.state.user_id`` to have been set earlier while handling
    the same request, e.g. by ``JWTBearer.__call__``.

    Args:
        request: FastAPI request object.

    Returns:
        str: The authenticated user ID.

    Raises:
        HTTPException: 401 if ``request.state`` has no ``user_id`` attribute.
    """
    missing = object()
    # getattr with a default swallows only AttributeError, exactly like hasattr().
    user_id = getattr(request.state, "user_id", missing)
    if user_id is missing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not authenticated"
        )
    return user_id