File size: 6,367 Bytes
cd46ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""
Security utilities for ClipboardHealthAI application.

This module provides authentication and authorization functionality, including:
- Password verification
- JWT token creation and validation
- Permission-based endpoint protection using FastAPI dependencies
"""

from datetime import datetime, timedelta, timezone

import anyio
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext

from cbh.api.account.dto import AccountType
from cbh.api.account.models import AccountModel
from cbh.core.config import settings


# Built once and shared: constructing a CryptContext parses its configuration,
# so re-creating it on every verification call is wasted work.
_PWD_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password: str | bytes, hashed_password: str | bytes) -> bool:
    """
    Verify a plain-text password against a bcrypt hash.

    Args:
        plain_password: The plain text password supplied by the user.
        hashed_password: The stored hash to check against.

    Returns:
        bool: True if the password matches the hash, False otherwise.

    Any error raised by passlib (for example, for a hash it cannot identify)
    is not caught here and propagates to the caller.
    """
    return _PWD_CONTEXT.verify(plain_password, hashed_password)


def create_access_token(
    email: str,
    account_id: str,
    account_type: AccountType,
    *,
    expires_delta: timedelta = timedelta(days=30),
) -> str:
    """
    Create a signed (HS256) JWT access token for an account.

    Args:
        email: Account email, stored in the xmlsoap ``.../claims/name`` claim.
        account_id: Account ID, stored in both the xmlsoap
            ``.../claims/nameidentifier`` claim and the ``accountId`` claim.
        account_type: Account type; its ``.value`` is stored in ``accountType``.
        expires_delta: Token lifetime measured from the current UTC time.
            Defaults to 30 days (the previously hard-coded lifetime).

    Returns:
        str: Encoded JWT signed with ``settings.SECRET_KEY``.
    """
    payload = {
        "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name": email,
        "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier": account_id,
        "accountId": account_id,
        "accountType": account_type.value,
        "iss": settings.Issuer,
        "aud": settings.Audience,
        # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated since
        # Python 3.12 and yields a naive datetime.
        "exp": datetime.now(timezone.utc) + expires_delta,
    }
    return jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256")


class PermissionDependency:
    """
    FastAPI dependency for protecting endpoints with JWT authentication.

    Each call validates the bearer token, loads the referenced account from the
    database and, when ``account_type`` was given, checks that the account's
    type is one of the allowed values.
    """

    def __init__(
        self, account_type: list[AccountType] | None = None, required: bool = True
    ):
        """
        Args:
            account_type: Account types allowed to access the endpoint. ``None``
                or an empty list allows any account type.
            required: When False, a request without credentials resolves to
                None instead of being rejected with 401.
        """
        self.account_types = account_type
        self.required = required

    def __call__(
        self,
        credentials: HTTPAuthorizationCredentials | None = Depends(
            HTTPBearer(auto_error=False)
        ),
    ) -> AccountModel | None:
        """
        Validate the bearer credentials and return the authenticated account.

        Args:
            credentials: Bearer credentials parsed from the Authorization header,
                or None when the header is absent.

        Returns:
            AccountModel | None: The authenticated account, or None when no
            credentials were sent and ``required`` is False.

        Raises:
            HTTPException: 401 when credentials are missing and required;
                403 for every other authentication/authorization failure.
        """
        if not credentials:
            if self.required:
                raise HTTPException(
                    status_code=401,
                    detail="Unauthorized",
                    # RFC 6750: a 401 for bearer auth must advertise the scheme.
                    headers={"WWW-Authenticate": "Bearer"},
                )
            return None

        try:
            account_id = self.authenticate_jwt_token(credentials.credentials)
            # anyio.from_thread.run hands the async lookup back to the event
            # loop; it only works when this sync dependency runs in an anyio
            # worker thread (presumably FastAPI's threadpool — confirm).
            account_data = anyio.from_thread.run(self.get_account_by_id, account_id)
            self.check_account_health(account_data)
            return AccountModel.from_mongo(account_data)
        except Exception as exc:
            # Every failure (bad/expired token, missing claims, unknown account,
            # disallowed type, malformed record, DB/bridge errors) becomes a
            # uniform 403; a 401 is passed through untouched.
            # NOTE(review): this also reports server-side failures as 403.
            if isinstance(exc, HTTPException) and exc.status_code == 401:
                raise
            raise HTTPException(status_code=403, detail="Permission denied") from exc

    @staticmethod
    async def get_account_by_id(account_id: str) -> dict:
        """
        Retrieve an account document from ``settings.DB_CLIENT.accounts`` by ID.

        Args:
            account_id: Value matched against the document's ``id`` field.

        Returns:
            dict: The matching account document.

        Raises:
            HTTPException: 403 if no document matches.
        """
        account = await settings.DB_CLIENT.accounts.find_one({"id": account_id})
        if not account:
            raise HTTPException(status_code=403, detail="Permission denied")
        return account

    def check_account_health(self, account: dict):
        """
        Verify the account exists and has an allowed account type.

        Args:
            account: Account document; its ``accountType`` field is read only
                when ``self.account_types`` is non-empty.

        Raises:
            HTTPException: 403 if the account is empty or its type is not allowed.
            KeyError / ValueError: if ``accountType`` is missing or not a valid
                ``AccountType`` (mapped to 403 by ``__call__``).
        """
        if not account:
            raise HTTPException(status_code=403, detail="Permission denied")
        if (
            self.account_types
            and AccountType(account["accountType"]) not in self.account_types
        ):
            raise HTTPException(status_code=403, detail="Permission denied")

    @staticmethod
    def authenticate_jwt_token(token: str) -> str:
        """
        Validate a JWT and extract the account ID.

        The signature (HS256 with ``settings.SECRET_KEY``), expiry and audience
        are checked by ``jwt.decode``. The issuer claim is not verified here.

        Args:
            token: Encoded JWT string.

        Returns:
            str: The ``accountId`` claim.

        Raises:
            JWTError: If decoding or claim validation fails.
            HTTPException: 403 if the name or ``accountId`` claim is missing.
        """
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=["HS256"], audience=settings.Audience
        )
        email: str | None = payload.get(
            "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name"
        )
        account_id: str | None = payload.get("accountId")

        if email is None or account_id is None:
            raise HTTPException(status_code=403, detail="Permission denied")

        return account_id


def check_account_token(token: str) -> dict | None:
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms="HS256", audience=settings.Audience
        )
        email: str | None = payload.get(
            "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name"
        )
        account_id: str | None = payload.get("accountId")
        if email is None or account_id is None:
            return None
        return {
            "email": email,
            "account_id": account_id,
            "account_type": payload.get("accountType"),
        }
    except Exception as _:
        return None