File size: 5,015 Bytes
a42ab7e
 
 
 
 
050d8f8
a42ab7e
050d8f8
a42ab7e
050d8f8
 
e39877e
a42ab7e
bc8ed4e
1bd7131
 
 
 
 
050d8f8
 
 
1bd7131
 
050d8f8
 
 
 
1bd7131
 
 
19e4a8c
 
1bd7131
 
 
 
050d8f8
1bd7131
 
 
050d8f8
 
1bd7131
 
050d8f8
 
1bd7131
050d8f8
 
1bd7131
 
050d8f8
1bd7131
 
 
 
 
75fb504
 
 
 
 
 
 
 
1bd7131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
050d8f8
1bd7131
050d8f8
1bd7131
050d8f8
 
1bd7131
050d8f8
1bd7131
19e4a8c
 
 
 
 
 
 
 
 
1bd7131
 
 
7dfb3ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
Authentication Dependencies

FastAPI dependencies for user authentication and authorization.
"""
import logging
from typing import Optional
from fastapi import Request, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from core.database import get_db
from core.models import User
from services.auth_service.jwt_provider import (
    verify_access_token,
    TokenExpiredError,
    InvalidTokenError,
    JWTError
)

# Module-level logger named after this module; used below for token-invalidation and optional-auth diagnostics.
logger = logging.getLogger(__name__)


async def get_current_user(
    req: Request,
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Extract and verify the JWT from the Authorization header and return its user.

    Checks performed, each of which raises HTTP 401 on failure:
      1. The ``Authorization`` header is present and uses the ``Bearer``
         scheme (scheme name matched case-insensitively, per RFC 7235) with a
         non-empty token.
      2. ``verify_access_token`` accepts the token.
      3. The token is not a refresh token (``payload.extra["type"]``).
      4. ``payload.user_id`` matches a ``User`` row whose ``is_active`` is true.
      5. ``payload.token_version`` is not lower than ``user.token_version``;
         bumping the stored version therefore invalidates previously issued
         tokens (instant logout).

    Every 401 carries ``WWW-Authenticate: Bearer``.

    Usage:
        @router.get("/protected")
        async def protected_route(user: User = Depends(get_current_user)):
            return {"user_id": user.user_id}

    Raises:
        HTTPException: 401 Unauthorized for any of the failures listed above.
    """
    # RFC 7235 §3.1: a 401 response MUST include a WWW-Authenticate challenge.
    challenge = {"WWW-Authenticate": "Bearer"}

    auth_header = req.headers.get("Authorization")
    if not auth_header:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing Authorization header",
            headers=challenge,
        )

    # The auth scheme is case-insensitive (RFC 7235 §2.1); tolerate extra
    # whitespace around the token and reject an empty token up front.
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid Authorization header format. Use: Bearer <token>",
            headers=challenge,
        )

    try:
        payload = verify_access_token(token)
    except TokenExpiredError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has expired. Please sign in again.",
            headers=challenge,
        ) from exc
    except InvalidTokenError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {exc}",
            headers=challenge,
        ) from exc
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Authentication error: {exc}",
            headers=challenge,
        ) from exc

    # Refresh tokens must never be accepted as API credentials.
    if payload.extra.get("type") == "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Cannot use refresh token for API access",
            headers=challenge,
        )

    # Look up the user; inactive accounts are treated as non-existent.
    query = select(User).where(
        User.user_id == payload.user_id,
        User.is_active == True,  # noqa: E712 - SQLAlchemy needs == to build the SQL clause
    )
    result = await db.execute(query)
    user = result.scalar_one_or_none()

    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
            headers=challenge,
        )

    # A token minted before the user's token_version was bumped is revoked.
    if payload.token_version < user.token_version:
        logger.info(
            "Token invalidated for user %s: token_version %s < %s",
            user.user_id,
            payload.token_version,
            user.token_version,
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has been invalidated. Please sign in again.",
            headers=challenge,
        )

    return user


async def get_optional_user(
    req: Request,
    db: AsyncSession = Depends(get_db)
) -> Optional[User]:
    """
    Authenticate the request if possible; otherwise return None.

    Never raises for authentication problems. It returns None when:
      - the ``Authorization`` header is missing, is not a ``Bearer`` header
        (scheme matched case-insensitively, per RFC 7235), or has an empty
        token;
      - ``verify_access_token`` rejects the token;
      - the token is a refresh token (refresh tokens are not API credentials);
      - no active ``User`` matches ``payload.user_id``;
      - ``payload.token_version`` is lower than ``user.token_version``.

    Useful for endpoints that serve both authenticated and anonymous users.

    Usage:
        @router.get("/optional-auth")
        async def optional_auth_route(user: Optional[User] = Depends(get_optional_user)):
            if user:
                return {"user_id": user.user_id}
            return {"message": "anonymous"}

    Returns:
        The authenticated ``User``, or None if the request is anonymous or
        its credentials are not acceptable.
    """
    auth_header = req.headers.get("Authorization")
    if not auth_header:
        return None

    # The auth scheme is case-insensitive (RFC 7235 §2.1); tolerate extra
    # whitespace around the token and treat an empty token as anonymous.
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return None

    try:
        payload = verify_access_token(token)
    except (TokenExpiredError, InvalidTokenError, JWTError) as exc:
        logger.debug("Optional auth failed: %s", exc)
        return None

    # Without this check a refresh token would authenticate the caller here,
    # even though refresh tokens are refused for API access elsewhere.
    if payload.extra.get("type") == "refresh":
        logger.debug("Optional auth rejected a refresh token")
        return None

    # Look up the user; inactive accounts are treated as non-existent.
    query = select(User).where(
        User.user_id == payload.user_id,
        User.is_active == True,  # noqa: E712 - SQLAlchemy needs == to build the SQL clause
    )
    result = await db.execute(query)
    user = result.scalar_one_or_none()

    if user is None:
        return None

    # A token minted before the user's token_version was bumped is revoked.
    if payload.token_version < user.token_version:
        logger.debug("Token invalidated for user %s", user.user_id)
        return None

    return user