"""JWT authentication middleware and dependencies.""" from typing import Optional from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from sqlalchemy.orm import Session from models.user import User from services.auth_service import decode_access_token from database import get_db # HTTP Bearer token security scheme security = HTTPBearer() class CredentialsError(Exception): """Custom exception for authentication errors.""" def __init__(self, detail: str): self.detail = detail def get_current_user( credentials: HTTPAuthorizationCredentials = Depends(security), db: Session = Depends(get_db) ) -> User: """ Dependency to get the current authenticated user from JWT token. Args: credentials: HTTP Bearer token credentials db: Database session Returns: The authenticated User object Raises: HTTPException: 401 if token is invalid or expired """ credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) token = credentials.credentials payload = decode_access_token(token) if payload is None: raise credentials_exception user_id_str: Optional[str] = payload.get("sub") if user_id_str is None: raise credentials_exception # Convert string sub to int try: user_id = int(user_id_str) except (ValueError, TypeError): raise credentials_exception # Fetch user from database user = db.get(User, user_id) if user is None: raise credentials_exception return user def get_optional_user( credentials: Optional[HTTPAuthorizationCredentials] = Depends( HTTPBearer(auto_error=False) ), db: Session = Depends(get_db) ) -> Optional[User]: """ Optional authentication dependency. Returns None if no valid token is provided, rather than raising an exception. """ if credentials is None: return None token = credentials.credentials payload = decode_access_token(token) if payload is None: return None user_id_str: Optional[str] = payload.get("sub") if user_id_str is None: return None # Convert string sub to int try: user_id = int(user_id_str) except (ValueError, TypeError): return None user = db.get(User, user_id) return user