File size: 4,732 Bytes
db7c1e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
Authentication module for the AI Backend with RAG + Authentication
Implements JWT-based authentication with password hashing
"""
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional, Union

import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
from pydantic import BaseModel

from ..config.settings import settings

# Logger named after this module.
logger = logging.getLogger(__name__)

# Passlib hashing context shared by every password hash/verify call in this
# module; bcrypt is the only scheme listed.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

# Bearer-token security scheme; its result is injected into get_current_user
# through Depends(security).
security = HTTPBearer()

class TokenData(BaseModel):
    """
    Identity information recovered from a decoded access token.

    Both fields are optional and default to None.
    """

    # Filled from the token's "sub" claim by decode_access_token.
    username: Union[str, None] = None
    # Filled from the token's "user_id" claim by decode_access_token.
    user_id: Union[str, None] = None


class AuthHandler:
    """
    Hash/verify passwords and issue/validate JWT access tokens.

    The signing key, JWT algorithm and default token lifetime are read once
    from ``settings`` when the handler is constructed.
    """

    def __init__(self):
        self.secret_key = settings.secret_key
        self.algorithm = settings.jwt_algorithm
        # settings.jwt_expires_in is expressed in seconds. Build the lifetime
        # from seconds directly: converting with ``// 60`` truncated any
        # sub-minute remainder and turned values below 60 into a zero
        # lifetime (tokens that are already expired when issued).
        self.access_token_expires = timedelta(seconds=settings.jwt_expires_in)

    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """
        Verify a plain password against a hashed password.

        Returns:
            True if ``plain_password`` matches ``hashed_password`` according to
            the module's passlib context, False otherwise.
        """
        return pwd_context.verify(plain_password, hashed_password)

    def get_password_hash(self, password: str) -> str:
        """
        Generate a hash for a plain password using the module's passlib context.
        """
        return pwd_context.hash(password)

    def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """
        Create a signed JWT access token.

        Args:
            data: Claims to embed (e.g. ``sub``, ``user_id``). The dict is
                copied, so the caller's object is not modified.
            expires_delta: Token lifetime. When ``None`` the handler's default
                lifetime (``self.access_token_expires``) is used; any other
                value, including ``timedelta(0)``, is honored as given.

        Returns:
            The encoded token produced by ``jwt.encode``, with ``exp`` and
            ``iat`` claims added.
        """
        to_encode = data.copy()

        # One timezone-aware timestamp so "iat" and "exp" derive from the same
        # instant (datetime.utcnow() is naive and deprecated).
        now = datetime.now(timezone.utc)
        lifetime = self.access_token_expires if expires_delta is None else expires_delta

        to_encode.update({"exp": now + lifetime, "iat": now})
        return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)

    def decode_access_token(self, token: str) -> Optional[TokenData]:
        """
        Decode and validate a JWT.

        The token is decoded with the handler's key, accepting only the
        handler's configured algorithm.

        Returns:
            ``TokenData`` built from the ``sub`` and ``user_id`` claims, or
            ``None`` if the token is expired, invalid, or has no ``sub`` claim.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.exceptions.ExpiredSignatureError:
            logger.warning("Expired token attempted to be decoded")
            return None
        except jwt.exceptions.InvalidTokenError:
            logger.warning("Invalid token attempted to be decoded")
            return None

        username: Optional[str] = payload.get("sub")
        if username is None:
            return None

        user_id: Optional[str] = payload.get("user_id")
        return TokenData(username=username, user_id=user_id)

    async def get_current_user(
        self, token: HTTPAuthorizationCredentials = Depends(security)
    ) -> TokenData:
        """
        Get the current user from the provided bearer credentials.

        Intended for use as a FastAPI dependency in route handlers.

        Args:
            token: Credentials object injected via ``Depends(security)``; the
                raw JWT is read from its ``credentials`` attribute.

        Returns:
            The decoded ``TokenData``.

        Raises:
            HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
                the token cannot be validated.
        """
        credentials_exception = HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

        # Only unexpected failures land here; the 401 for a rejected token is
        # raised below, outside the try, so it is no longer caught and logged
        # as an error on every bad request.
        try:
            token_data = self.decode_access_token(token.credentials)
        except Exception as exc:
            logger.error("Error getting current user: %s", exc)
            raise credentials_exception from exc

        if token_data is None:
            # Expected rejection. Expired/invalid tokens were already logged
            # at warning level by decode_access_token.
            raise credentials_exception
        return token_data


# Single module-wide handler; the convenience functions below delegate to it.
auth_handler: AuthHandler = AuthHandler()

# Convenience functions for use in other modules
def get_password_hash(password: str) -> str:
    """Return the hash of *password*, computed by the shared ``auth_handler``."""
    hashed = auth_handler.get_password_hash(password)
    return hashed

def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against *hashed_password* via the shared ``auth_handler``."""
    is_match = auth_handler.verify_password(plain_password, hashed_password)
    return is_match

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Encode *data* into a signed JWT using the shared ``auth_handler``.

    ``expires_delta`` is forwarded unchanged; ``AuthHandler.create_access_token``
    decides what lifetime to use when it is omitted.
    """
    return auth_handler.create_access_token(data, expires_delta=expires_delta)

def decode_access_token(token: str) -> Optional[TokenData]:
    """Decode *token* via the shared ``auth_handler``; ``None`` means it was rejected."""
    decoded = auth_handler.decode_access_token(token)
    return decoded

async def get_current_user(
    token: HTTPAuthorizationCredentials = Depends(security),
) -> TokenData:
    """
    FastAPI dependency that returns the authenticated user's token data.

    ``token`` is the credentials object injected by the ``security`` scheme
    (not a raw string; the previous ``str`` annotation was wrong). Validation
    is delegated to the shared ``auth_handler``, which raises an HTTP 401
    ``HTTPException`` when the token cannot be validated.
    """
    return await auth_handler.get_current_user(token)

def create_user_token(user_id: str, username: str) -> str:
    """
    Issue an access token for a user with the default lifetime.

    The username is stored in the ``sub`` claim and the id in a ``user_id``
    claim, which are the two claims ``decode_access_token`` reads back.
    """
    return create_access_token({"sub": username, "user_id": user_id})