Spaces:
Sleeping
Sleeping
| """ | |
| Authentication Dependencies | |
| FastAPI dependencies for user authentication and authorization. | |
| """ | |
| import logging | |
| from typing import Optional | |
| from fastapi import Request, Depends, HTTPException, status | |
| from sqlalchemy import select | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from core.database import get_db | |
| from core.models import User | |
| from services.auth_service.jwt_provider import ( | |
| verify_access_token, | |
| TokenExpiredError, | |
| InvalidTokenError, | |
| JWTError | |
| ) | |
| logger = logging.getLogger(__name__) | |
async def get_current_user(
    req: Request,
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Extract and verify the JWT bearer token from the Authorization header.

    Returns the authenticated, active user. Also validates ``token_version``
    to support instant logout/invalidation: a token whose version is lower
    than the user's current version is rejected.

    Usage:
        @router.get("/protected")
        async def protected_route(user: User = Depends(get_current_user)):
            return {"user_id": user.user_id}

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            the header is missing or malformed, the token is expired or
            invalid, a refresh token is presented, the user does not exist or
            is inactive, or the token version is stale.
    """
    # RFC 7235 §3.1: every 401 response must carry a WWW-Authenticate challenge.
    challenge = {"WWW-Authenticate": "Bearer"}

    auth_header = req.headers.get("Authorization")
    if not auth_header:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing Authorization header",
            headers=challenge,
        )

    # The auth-scheme is case-insensitive (RFC 7235 §2.1) and may be followed by
    # one or more spaces (RFC 6750 §2.1), so "bearer  <token>" is also accepted.
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid Authorization header format. Use: Bearer <token>",
            headers=challenge,
        )

    try:
        payload = verify_access_token(token)
    except TokenExpiredError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has expired. Please sign in again.",
            headers=challenge,
        ) from exc
    except InvalidTokenError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {exc}",
            headers=challenge,
        ) from exc
    except JWTError as exc:
        # Catch-all for other token errors; must stay after the more specific
        # handlers above.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Authentication error: {exc}",
            headers=challenge,
        ) from exc

    # Ensure it's an access token, not a refresh token.
    if payload.extra.get("type") == "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Cannot use refresh token for API access",
            headers=challenge,
        )

    # Load the user named by the token; inactive accounts are treated as absent.
    query = select(User).where(
        User.user_id == payload.user_id,
        User.is_active == True,  # noqa: E712 - SQLAlchemy column expression, not a Python bool test
    )
    result = await db.execute(query)
    user = result.scalar_one_or_none()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
            headers=challenge,
        )

    # Validate token version - if user's version is higher, token is invalidated.
    if payload.token_version < user.token_version:
        logger.info(
            "Token invalidated for user %s: token_version %s < %s",
            user.user_id,
            payload.token_version,
            user.token_version,
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has been invalidated. Please sign in again.",
            headers=challenge,
        )

    return user
async def get_optional_user(
    req: Request,
    db: AsyncSession = Depends(get_db)
) -> Optional[User]:
    """
    Attempt to authenticate from the Authorization header; return None if not possible.

    Returns the authenticated user if the bearer token is valid, or ``None``
    otherwise. Unlike get_current_user, this does NOT raise for a missing,
    malformed, expired, or invalid token. It applies the same acceptance
    rules, though: refresh tokens, unknown or inactive users, and tokens whose
    ``token_version`` is older than the user's are all treated as anonymous.
    Useful for endpoints that work for both authenticated and anonymous users.

    Usage:
        @router.get("/optional-auth")
        async def optional_auth_route(user: Optional[User] = Depends(get_optional_user)):
            if user:
                return {"user_id": user.user_id}
            return {"message": "anonymous"}
    """
    auth_header = req.headers.get("Authorization")
    if not auth_header:
        return None

    # Auth-scheme comparison is case-insensitive (RFC 7235 §2.1); tolerate
    # extra spaces before the token (RFC 6750 §2.1).
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return None

    try:
        payload = verify_access_token(token)
    except (TokenExpiredError, InvalidTokenError, JWTError) as exc:
        logger.debug("Optional auth failed: %s", exc)
        return None

    # Same rule as get_current_user: a refresh token must not authenticate API
    # requests. Without this check it would be accepted here as a normal login.
    if payload.extra.get("type") == "refresh":
        logger.debug("Optional auth rejected a refresh token")
        return None

    # Load the user named by the token; inactive accounts are treated as absent.
    query = select(User).where(
        User.user_id == payload.user_id,
        User.is_active == True,  # noqa: E712 - SQLAlchemy column expression, not a Python bool test
    )
    result = await db.execute(query)
    user = result.scalar_one_or_none()
    if user is None:
        return None

    # Validate token version - a lower version means the token was invalidated.
    if payload.token_version < user.token_version:
        logger.debug("Token invalidated for user %s", user.user_id)
        return None

    return user