Spaces:
Runtime error
Runtime error
| """JWT authentication middleware and dependencies.""" | |
| from typing import Optional | |
| from fastapi import Depends, HTTPException, status | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from sqlalchemy.orm import Session | |
| from models.user import User | |
| from services.auth_service import decode_access_token | |
| from database import get_db | |
# Module-wide HTTP Bearer scheme instance; get_current_user receives its
# parsed credentials through Depends(security).
security: HTTPBearer = HTTPBearer()
class CredentialsError(Exception):
    """Custom exception for authentication errors.

    The failure description is available both as ``exc.detail`` and through
    the standard exception machinery (``str(exc)``, ``exc.args``).
    """

    def __init__(self, detail: str) -> None:
        """Create the error.

        Args:
            detail: Human-readable description of the authentication failure.
        """
        # Forward the message to Exception explicitly: Exception.__new__ only
        # records positional arguments, so CredentialsError(detail="...")
        # would otherwise end up with empty args and an empty str(exc).
        super().__init__(detail)
        self.detail = detail
def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db),
) -> User:
    """
    Resolve the bearer token of the request to a User row.

    Args:
        credentials: Bearer credentials supplied by the ``security`` scheme
        db: Database session supplied by ``get_db``

    Returns:
        The User whose primary key matches the token's "sub" claim

    Raises:
        HTTPException: 401 when ``decode_access_token`` returns None, the
            payload has no "sub" claim, the claim cannot be converted to an
            int, or no User with that id exists
    """
    # Every failure path raises this same response so the caller cannot tell
    # which validation step rejected the token.
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    claims = decode_access_token(credentials.credentials)

    # An undecodable token and a payload without "sub" are rejected alike.
    subject: Optional[str] = None if claims is None else claims.get("sub")
    if subject is None:
        raise unauthorized

    # "sub" is expected to hold the user id in string form.
    try:
        user_pk = int(subject)
    except (ValueError, TypeError):
        raise unauthorized

    # Primary-key lookup; a missing row is treated as an invalid token.
    account = db.get(User, user_pk)
    if account is None:
        raise unauthorized
    return account
def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(
        HTTPBearer(auto_error=False)
    ),
    db: Session = Depends(get_db),
) -> Optional[User]:
    """
    Best-effort variant of user resolution that never raises for bad tokens.

    Args:
        credentials: Bearer credentials, or None when the scheme supplied none
        db: Database session supplied by ``get_db``

    Returns:
        The matching User, or None when there are no credentials,
        ``decode_access_token`` returns None, the payload has no "sub" claim,
        the claim cannot be converted to an int, or no such User exists
    """
    # The scheme is built with auto_error=False, so absent credentials arrive
    # here as None instead of being rejected before this function runs.
    if credentials is None:
        return None

    claims = decode_access_token(credentials.credentials)

    # An undecodable token and a payload without "sub" both mean "anonymous".
    subject: Optional[str] = None if claims is None else claims.get("sub")
    if subject is None:
        return None

    # "sub" is expected to hold the user id in string form.
    try:
        user_pk = int(subject)
    except (ValueError, TypeError):
        return None

    # db.get yields None for an unknown id, which is passed straight through.
    return db.get(User, user_pk)