Spaces:
Sleeping
Sleeping
| """ | |
| Authentication module for the AI Backend with RAG + Authentication | |
| Implements JWT-based authentication with password hashing | |
| """ | |
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional, Union

import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
from pydantic import BaseModel

from ..config.settings import settings
# HTTP Bearer scheme: the FastAPI dependency that extracts the
# "Authorization: Bearer <token>" header for get_current_user.
security = HTTPBearer()

# Passlib hashing context. bcrypt is the only listed scheme; deprecated="auto"
# marks any non-default scheme as deprecated.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

logger = logging.getLogger(__name__)
class TokenData(BaseModel):
    # Identity claims recovered from a decoded JWT; both may be absent.
    username: Union[str, None] = None  # populated from the "sub" claim
    user_id: Union[str, None] = None  # populated from the "user_id" claim
class AuthHandler:
    """Password hashing plus JWT issuing/validation driven by application settings.

    ``settings.secret_key``, ``settings.jwt_algorithm`` and
    ``settings.jwt_expires_in`` are read once, when the handler is constructed.
    """

    def __init__(self):
        self.secret_key = settings.secret_key
        self.algorithm = settings.jwt_algorithm
        # ``jwt_expires_in`` is a number of seconds. Build the lifetime from
        # seconds directly: converting with ``// 60`` truncated sub-minute
        # remainders and turned any value below 60 into a zero-length lifetime,
        # so every token would already be expired when issued.
        self.access_token_expires = timedelta(seconds=settings.jwt_expires_in)

    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """
        Verify a plain password against a hashed password.

        Returns True on a match and False otherwise. passlib may raise instead
        if ``hashed_password`` is not a recognised hash format.
        """
        return pwd_context.verify(plain_password, hashed_password)

    def get_password_hash(self, password: str) -> str:
        """
        Generate a hash for a plain password using the module's passlib context.
        """
        return pwd_context.hash(password)

    def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """
        Create a signed JWT carrying ``data`` plus ``exp`` and ``iat`` claims.

        Args:
            data: Claims to embed. The dict is shallow-copied, so the caller's
                object is not mutated. Any ``exp``/``iat`` keys it contains are
                overwritten.
            expires_delta: Token lifetime. ``None`` selects the configured
                default (``self.access_token_expires``). An explicit
                ``timedelta(0)`` is honoured rather than replaced by the default.

        Returns:
            The encoded token produced by ``jwt.encode``.
        """
        to_encode = data.copy()
        # A single aware UTC timestamp for both claims keeps ``exp - iat`` equal
        # to the lifetime. It also avoids the deprecated, naive datetime.utcnow().
        issued_at = datetime.now(timezone.utc)
        lifetime = self.access_token_expires if expires_delta is None else expires_delta
        to_encode.update({"exp": issued_at + lifetime, "iat": issued_at})
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def decode_access_token(self, token: str) -> Optional[TokenData]:
        """
        Verify ``token`` and extract its identity claims.

        Returns:
            ``TokenData`` built from the ``sub`` and ``user_id`` claims, or
            ``None`` when the token is expired, otherwise invalid, or has no
            ``sub`` claim.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.exceptions.ExpiredSignatureError:
            # Must precede InvalidTokenError, of which it is a subclass.
            logger.warning("Expired token attempted to be decoded")
            return None
        except jwt.exceptions.InvalidTokenError:
            logger.warning("Invalid token attempted to be decoded")
            return None

        username = payload.get("sub")
        if username is None:
            return None
        return TokenData(username=username, user_id=payload.get("user_id"))

    async def get_current_user(
        self, token: HTTPAuthorizationCredentials = Depends(security)
    ) -> TokenData:
        """
        Resolve the current user from a Bearer credential.

        Intended for use as a FastAPI dependency. ``token`` is the
        ``HTTPAuthorizationCredentials`` supplied by the ``security`` scheme.
        Its ``credentials`` attribute holds the raw JWT.

        Raises:
            HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the
                token cannot be turned into a user.
        """
        credentials_exception = HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
        # Keep the try narrow. Previously the 401 raised for an ordinary bad
        # token was caught here and re-logged at ERROR for every request.
        try:
            token_data = self.decode_access_token(token.credentials)
        except Exception as e:
            # decode_access_token already maps JWT errors to None. Anything
            # arriving here is unexpected (possibly TokenData rejecting a claim
            # value), so record the traceback.
            logger.exception("Error getting current user: %s", e)
            raise credentials_exception from e
        if token_data is None:
            raise credentials_exception
        return token_data
# Module-level AuthHandler shared by the convenience functions that follow.
# Its settings (secret, algorithm, lifetime) are read once, at import time.
auth_handler = AuthHandler()
# Module-level convenience wrappers that delegate to the shared auth_handler.
def get_password_hash(password: str) -> str:
    """Return the passlib hash of ``password`` computed by the shared handler."""
    hasher = auth_handler.get_password_hash
    return hasher(password=password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check ``plain_password`` against a stored hash via the shared handler.

    Returns True on a match and False otherwise. passlib may raise instead if
    the stored hash is unrecognised.
    """
    return auth_handler.verify_password(
        plain_password=plain_password,
        hashed_password=hashed_password,
    )
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Sign ``data`` as a JWT using the shared handler.

    ``expires_delta`` is forwarded unchanged as the optional lifetime override.
    """
    return auth_handler.create_access_token(data=data, expires_delta=expires_delta)
def decode_access_token(token: str) -> Optional[TokenData]:
    """Decode ``token`` with the shared handler.

    Returns the extracted TokenData, or None when the handler rejects the token.
    """
    decoded = auth_handler.decode_access_token(token=token)
    return decoded
async def get_current_user(
    token: HTTPAuthorizationCredentials = Depends(security),
) -> TokenData:
    """FastAPI dependency that resolves the authenticated user from the Bearer header.

    ``token`` is the HTTPAuthorizationCredentials object produced by the
    ``security`` scheme, not a ``str``: the shared handler reads its
    ``.credentials`` attribute to obtain the raw JWT.

    Raises:
        HTTPException: 401 when the credentials cannot be validated.
    """
    return await auth_handler.get_current_user(token)
def create_user_token(user_id: str, username: str) -> str:
    """Issue an access token for a user with the default lifetime.

    The username goes into the ``sub`` claim and the id into ``user_id``.
    These are the two claims decode_access_token reads back.
    """
    claims = dict(sub=username, user_id=user_id)
    return create_access_token(claims)