Spaces:
Runtime error
Runtime error
File size: 2,188 Bytes
dd1b74d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | """FastAPI dependencies for authentication and database."""
from typing import Annotated

from fastapi import Depends, Header, HTTPException, status
from sqlalchemy.orm import Session

from database import get_db
from models.user import User
from services.auth_service import decode_access_token
def get_current_user(
    credentials: Annotated[str | None, Header(alias="Authorization")] = None,
    db: Session = Depends(get_db),
) -> User:
    """
    Dependency to get the current authenticated user from a JWT bearer token.

    Args:
        credentials: Raw ``Authorization`` request header value, expected in
            the form ``"Bearer <token>"`` (scheme matched case-insensitively).
        db: Database session provided by ``get_db``.

    Returns:
        The authenticated User object.

    Raises:
        HTTPException: 401 if the header is missing or malformed, the token
            cannot be decoded, its ``sub`` claim is missing or not an integer
            id, or no user with that id exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    if credentials is None:
        raise credentials_exception

    # Split "<scheme> <token>" on the first space only, so a token that itself
    # contains "Bearer " is not mangled. HTTP auth schemes are case-insensitive.
    scheme, _, token = credentials.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise credentials_exception

    payload = decode_access_token(token)
    if payload is None:
        raise credentials_exception

    # RFC 7519 defines "sub" as a string; normalize it to the integer id used
    # for the lookup and treat anything non-numeric as an invalid token.
    raw_subject = payload.get("sub")
    if raw_subject is None:
        raise credentials_exception
    try:
        user_id = int(raw_subject)
    except (TypeError, ValueError):
        raise credentials_exception from None

    # Fetch user from database
    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise credentials_exception
    return user
def get_authorization_header(
    authorization: Annotated[str | None, Header()] = None,
) -> str | None:
    """
    Return the raw ``Authorization`` request header, or ``None`` if absent.

    The parameter is declared with ``Header()`` so FastAPI reads it from the
    request headers. A plain default value would make FastAPI treat it as a
    query parameter instead. This is a separate function to allow dependency
    injection.
    """
    return authorization
async def get_current_user_async():
    """
    Async version of get_current_user for future async implementation.

    Not implemented yet. Raises instead of returning ``None``, so a route
    wired to this dependency fails loudly rather than proceeding without an
    authenticated user.

    Raises:
        NotImplementedError: Always, until an async implementation exists.
    """
    # TODO: Implement async version with asyncpg/SQLModel async
    raise NotImplementedError(
        "get_current_user_async is not implemented; use get_current_user"
    )
class AuthError(Exception):
    """
    Custom authentication error.

    Args:
        detail: Human-readable description of the failure. It is also passed
            to ``Exception.__init__`` so ``str(err)`` and ``err.args`` carry it.
    """

    def __init__(self, detail: str):
        super().__init__(detail)
        self.detail = detail
|