Spaces:
Sleeping
Sleeping
File size: 5,015 Bytes
a42ab7e 050d8f8 a42ab7e 050d8f8 a42ab7e 050d8f8 e39877e a42ab7e bc8ed4e 1bd7131 050d8f8 1bd7131 050d8f8 1bd7131 19e4a8c 1bd7131 050d8f8 1bd7131 050d8f8 1bd7131 050d8f8 1bd7131 050d8f8 1bd7131 050d8f8 1bd7131 75fb504 1bd7131 050d8f8 1bd7131 050d8f8 1bd7131 050d8f8 1bd7131 050d8f8 1bd7131 19e4a8c 1bd7131 7dfb3ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
"""
Authentication Dependencies
FastAPI dependencies for user authentication and authorization.
"""
import logging
from typing import Optional
from fastapi import Request, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from core.database import get_db
from core.models import User
from services.auth_service.jwt_provider import (
verify_access_token,
TokenExpiredError,
InvalidTokenError,
JWTError
)
# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)
async def get_current_user(
    req: Request,
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Extract and verify the JWT from the Authorization header.

    Returns the authenticated, active user. Also validates ``token_version``
    to support instant logout/invalidation: a token whose version is lower
    than the user's current version is rejected.

    Every failure is reported as a 401 carrying a ``WWW-Authenticate: Bearer``
    header, which RFC 7235 requires on 401 responses.

    Usage:
        @router.get("/protected")
        async def protected_route(user: User = Depends(get_current_user)):
            return {"user_id": user.user_id}

    Raises:
        HTTPException: 401 when the header is missing or not a Bearer header,
            the token is expired/invalid, a refresh token is presented, the
            user does not exist or is inactive, or the token version is stale.
    """
    # RFC 7235 §3.1: a 401 response must include a WWW-Authenticate challenge.
    bearer_challenge = {"WWW-Authenticate": "Bearer"}

    auth_header = req.headers.get("Authorization")
    if not auth_header:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing Authorization header",
            headers=bearer_challenge,
        )
    if not auth_header.startswith("Bearer "):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid Authorization header format. Use: Bearer <token>",
            headers=bearer_challenge,
        )

    token = auth_header.split(" ", 1)[1]

    # Keep the try body minimal: only token verification raises JWT errors.
    try:
        payload = verify_access_token(token)
    except TokenExpiredError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has expired. Please sign in again.",
            headers=bearer_challenge,
        ) from e
    except InvalidTokenError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {e}",
            headers=bearer_challenge,
        ) from e
    except JWTError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Authentication error: {e}",
            headers=bearer_challenge,
        ) from e

    # Ensure it's an access token, not a refresh token.
    if payload.extra.get("type") == "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Cannot use refresh token for API access",
            headers=bearer_challenge,
        )

    # Load the user; inactive accounts are treated as non-existent.
    query = select(User).where(
        User.user_id == payload.user_id,
        User.is_active == True,  # noqa: E712 -- SQLAlchemy needs `==` to build SQL
    )
    result = await db.execute(query)
    user = result.scalar_one_or_none()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
            headers=bearer_challenge,
        )

    # If the user's version is higher than the token's, the token was invalidated.
    if payload.token_version < user.token_version:
        logger.info(
            "Token invalidated for user %s: token_version %s < %s",
            user.user_id,
            payload.token_version,
            user.token_version,
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has been invalidated. Please sign in again.",
            headers=bearer_challenge,
        )

    return user
async def get_optional_user(
    req: Request,
    db: AsyncSession = Depends(get_db)
) -> Optional[User]:
    """
    Attempt to extract and verify the JWT from the Authorization header.

    Returns the authenticated, active user if the token is a valid access
    token, or None otherwise. Unlike get_current_user, this does NOT raise
    for missing/invalid tokens, which makes it suitable for endpoints that
    serve both authenticated and anonymous users.

    Refresh tokens are rejected (None is returned), so a refresh token can
    never stand in for an access token.

    Usage:
        @router.get("/optional-auth")
        async def optional_auth_route(user: Optional[User] = Depends(get_optional_user)):
            if user:
                return {"user_id": user.user_id}
            return {"message": "anonymous"}
    """
    auth_header = req.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        return None

    token = auth_header.split(" ", 1)[1]
    try:
        payload = verify_access_token(token)
    except (TokenExpiredError, InvalidTokenError, JWTError) as e:
        logger.debug("Optional auth failed: %s", e)
        return None

    # Refresh tokens must not authenticate API requests. get_current_user
    # rejects them as well; without this check an optional-auth endpoint
    # would accept a refresh token as proof of identity.
    if payload.extra.get("type") == "refresh":
        logger.debug("Optional auth rejected a refresh token")
        return None

    # Load the user; inactive accounts are treated as anonymous.
    query = select(User).where(
        User.user_id == payload.user_id,
        User.is_active == True,  # noqa: E712 -- SQLAlchemy needs `==` to build SQL
    )
    result = await db.execute(query)
    user = result.scalar_one_or_none()
    if user is None:
        return None

    # A token older than the user's current token_version has been invalidated.
    if payload.token_version < user.token_version:
        logger.debug("Token invalidated for user %s", user.user_id)
        return None

    return user
|