File size: 3,825 Bytes
35e7795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""用户认证相关的工具函数和依赖项"""

import logging
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session

from qa_annotate.config import settings
from qa_annotate.database.base import get_db
from qa_annotate.database.crud import UserCRUD
from qa_annotate.schema.user import User

# HTTP Bearer authentication scheme
security = HTTPBearer()


def decode_token(token: str) -> Optional[dict]:
    """Decode a JWT and return its payload.

    The token is decoded with ``settings.SECRET_KEY`` and the single
    algorithm named by ``settings.ALGORITHM``.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded payload, or ``None`` if decoding failed for any reason.
        This function never raises.
    """
    try:
        return jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=[settings.ALGORITHM],
        )
    except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
        # Expired or otherwise invalid token: an expected outcome for
        # untrusted input, so it is reported to the caller quietly.
        return None
    except Exception:
        # Anything else is unexpected. Keep the best-effort contract
        # (return None) but leave a trace instead of hiding the failure.
        logging.getLogger(__name__).exception(
            "Unexpected error while decoding JWT"
        )
        return None


def get_token_from_request(request: Request) -> Optional[str]:
    """Extract an access token from the request (header first, then cookie).

    The ``Authorization`` header is used when its scheme is ``Bearer``
    (compared case-insensitively) and it carries a non-empty token.
    Otherwise the ``access_token`` cookie is used.

    Args:
        request: Incoming request; only ``headers`` and ``cookies`` are read.

    Returns:
        The token from the header, or the result of
        ``request.cookies.get("access_token")`` (``None`` when absent).
    """
    authorization = request.headers.get("Authorization")
    if authorization:
        # partition + strip tolerates repeated spaces between scheme and
        # token; indexing the result of split(" ") would yield "" there.
        scheme, _, credentials = authorization.strip().partition(" ")
        credentials = credentials.strip()
        # An empty token is not usable: fall through to the cookie instead
        # of returning "".
        if scheme.lower() == "bearer" and credentials:
            return credentials

    # No usable Authorization header: fall back to the cookie.
    return request.cookies.get("access_token")


def get_user_from_token(token: str, db: Session) -> Optional[User]:
    """Resolve a token to an active user (internal helper).

    Args:
        token: Encoded token, passed to ``decode_token``.
        db: Database session handed to ``UserCRUD.get_by_id``.

    Returns:
        The matching user, or ``None`` when the token cannot be decoded,
        carries no truthy ``user_id`` entry, no user is found, or the user
        is not active.
    """
    payload = decode_token(token)
    # An undecodable token and a missing/falsy "user_id" are treated alike.
    claimed_id = payload.get("user_id") if payload else None
    if not claimed_id:
        return None

    account = UserCRUD.get_by_id(db, user_id=claimed_id)
    # Only users that exist and are active count as authenticated.
    if account and account.is_active:
        return account
    return None


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db),
) -> User:
    """Return the currently authenticated user (dependency).

    Args:
        credentials: Bearer credentials supplied by the ``security`` scheme.
        db: Database session supplied by ``get_db``.

    Returns:
        The user resolved from the bearer token by ``get_user_from_token``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            no user can be resolved from the token.
    """
    authenticated = get_user_from_token(credentials.credentials, db)
    if authenticated:
        return authenticated

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_optional_user(
    request: Request, db: Session = Depends(get_db)
) -> Optional[User]:
    """Return the current user if one can be resolved, otherwise ``None``.

    The token is taken from the request by ``get_token_from_request``.
    Any exception raised while resolving the user is swallowed and
    reported as ``None``, so this dependency never raises.

    Args:
        request: Incoming request to read the token from.
        db: Database session supplied by ``get_db``.

    Returns:
        The resolved user, or ``None`` when there is no token, the token
        does not resolve to a user, or an error occurred.
    """
    try:
        candidate = get_token_from_request(request)
        return get_user_from_token(candidate, db) if candidate else None
    except Exception:
        # Deliberately best-effort: a failure here means "not logged in".
        return None


async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Return the current user, requiring the account to be active (dependency).

    Args:
        current_user: User supplied by ``get_current_user``.

    Returns:
        ``current_user`` unchanged when ``is_active`` is truthy.

    Raises:
        HTTPException: 403 when the user is not active.
    """
    if current_user.is_active:
        return current_user

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN, detail="用户已被禁用"
    )


async def get_current_superuser(current_user: User = Depends(get_current_user)) -> User:
    """Return the current user, requiring an active superuser (dependency).

    Args:
        current_user: User supplied by ``get_current_user``.

    Returns:
        ``current_user`` unchanged when it is both active and a superuser.

    Raises:
        HTTPException: 403 when the user is not active, or is active but
            not a superuser (the two cases carry different detail messages).
    """
    # The active check comes first, so an inactive user is reported as
    # inactive whatever its superuser flag says.
    denial = None
    if not current_user.is_active:
        denial = "用户已被禁用"
    elif not current_user.is_superuser:
        denial = "权限不足,需要超级用户权限"

    if denial is not None:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=denial)
    return current_user