muhammadshaheryar's picture
Add application file
dd1b74d
"""JWT authentication middleware and dependencies."""
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from models.user import User
from services.auth_service import decode_access_token
from database import get_db
# Shared "Authorization: Bearer <token>" scheme used by get_current_user.
# auto_error=True (FastAPI's default, spelled out here) makes FastAPI reject
# requests without a usable bearer header before the dependency body runs.
security = HTTPBearer(auto_error=True)
class CredentialsError(Exception):
    """Custom exception for authentication errors.

    Attributes:
        detail: Human-readable description of why authentication failed.
    """

    def __init__(self, detail: str):
        # Forward to Exception.__init__ so ``args``/``str(exc)`` are populated
        # even when constructed as CredentialsError(detail=...), and so the
        # exception pickles (pickling rebuilds it via cls(*args)).
        super().__init__(detail)
        self.detail = detail
def _user_from_token(token: str, db: Session) -> Optional[User]:
    """
    Resolve a raw bearer token to a User, or return None if it cannot be resolved.

    Shared by get_current_user and get_optional_user so both dependencies
    apply exactly the same validation rules.

    Args:
        token: Raw bearer token string taken from the Authorization header
        db: Database session used to look up the user

    Returns:
        The matching User, or None when decode_access_token returns None
        (presumably invalid signature / expired token — confirm in
        services.auth_service), the payload has no "sub" claim, "sub" is not
        convertible to an int, or no user exists with that id.
    """
    payload = decode_access_token(token)
    if payload is None:
        return None

    # "sub" is carried as a string in the token (RFC 7519 defines it as a
    # StringOrURI), so convert it back to the integer primary key.
    user_id_str: Optional[str] = payload.get("sub")
    if user_id_str is None:
        return None
    try:
        user_id = int(user_id_str)
    except (ValueError, TypeError):
        return None

    return db.get(User, user_id)


def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db),
) -> User:
    """
    Dependency to get the current authenticated user from JWT token.

    Args:
        credentials: HTTP Bearer token credentials from the module-level
            ``security`` scheme
        db: Database session

    Returns:
        The authenticated User object

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) if the token
            cannot be decoded, lacks an integer "sub" claim, or refers to a
            user that does not exist.
    """
    user = _user_from_token(credentials.credentials, db)
    if user is None:
        # Raised outside any except block so no unrelated ValueError/TypeError
        # is chained onto the 401 as implicit context.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            # Tells the client which auth scheme to retry with (RFC 6750).
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user


def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(
        HTTPBearer(auto_error=False)
    ),
    db: Session = Depends(get_db),
) -> Optional[User]:
    """
    Optional authentication dependency.

    Returns None if no valid token is provided, rather than raising an
    exception. With ``auto_error=False`` the bearer scheme yields None instead
    of rejecting the request when the Authorization header is absent.

    Args:
        credentials: Bearer credentials, or None when none were supplied
        db: Database session

    Returns:
        The authenticated User, or None when unauthenticated or the token
        does not resolve to an existing user.
    """
    if credentials is None:
        return None
    return _user_from_token(credentials.credentials, db)