File size: 4,129 Bytes
2310db1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
JWT Token Validator
===================

Validates JWT tokens issued by Supabase Auth.
"""

import logging
from typing import Dict, Any, Optional, Tuple
from datetime import datetime, timezone

try:
    import jwt
except ImportError:
    jwt = None  # PyJWT ๋ฏธ์„ค์น˜ ์‹œ graceful ์ฒ˜๋ฆฌ

from .config import SUPABASE_JWT_SECRET

# Module logger; the dotted name places it under the "eodi.auth" logger hierarchy.
logger = logging.getLogger("eodi.auth.token_validator")


class TokenValidator:
    """
    Validator for JWTs issued by Supabase Auth.

    Checks the HS256 signature and the expiry of an access_token on the
    server side. Audience (``aud``) verification is deliberately disabled.

    WARNING: when no JWT secret is configured, or PyJWT is not installed, the
    validator is disabled and ``validate()`` accepts *every* token, returning
    the placeholder payload ``{"sub": "unknown"}``. This fallback is meant for
    development only.
    """

    def __init__(self, jwt_secret: Optional[str] = None):
        """
        Args:
            jwt_secret: Supabase JWT secret. When falsy, falls back to
                ``SUPABASE_JWT_SECRET`` from the config module.
        """
        self.jwt_secret = jwt_secret or SUPABASE_JWT_SECRET
        # Verification needs both a secret and the PyJWT library.
        self._enabled = bool(self.jwt_secret) and jwt is not None

        if jwt is None:
            logger.warning("PyJWT๊ฐ€ ์„ค์น˜๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. pip install PyJWT")
        elif not self._enabled:
            logger.warning("JWT Secret ๋ฏธ์„ค์ • - ํ† ํฐ ๊ฒ€์ฆ์ด ๋น„ํ™œ์„ฑํ™”๋ฉ๋‹ˆ๋‹ค")

    @property
    def is_enabled(self) -> bool:
        """Whether signature/expiry verification is active."""
        return self._enabled

    def _decode(self, token: str) -> Dict[str, Any]:
        """Verify *token* and return its claims; raises PyJWT errors on failure."""
        return jwt.decode(
            token,
            self.jwt_secret,
            # Pin the accepted algorithm instead of trusting the token header.
            algorithms=["HS256"],
            # Supabase tokens carry varying "aud" claims, so audience is not checked.
            options={"verify_aud": False},
        )

    def validate(self, token: str) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
        """
        Validate a JWT.

        Args:
            token: JWT access_token.

        Returns:
            ``(is_valid, payload, error_message)``. On success ``payload`` is
            the decoded claims dict and ``error_message`` is None; on failure
            ``payload`` is None and ``error_message`` describes the problem.
            When the validator is disabled, returns
            ``(True, {"sub": "unknown"}, None)`` without inspecting the token.
        """
        if not self._enabled:
            # Fail-open fallback (no secret or no PyJWT), intended for development.
            # NOTE(review): this accepts any token - confirm it cannot occur in production.
            logger.warning("JWT ๊ฒ€์ฆ ์Šคํ‚ต (Secret ๋ฏธ์„ค์ • ๋˜๋Š” PyJWT ๋ฏธ์„ค์น˜)")
            return True, {"sub": "unknown"}, None

        try:
            payload = self._decode(token)
        # ExpiredSignatureError subclasses InvalidTokenError, so it must come first.
        except jwt.ExpiredSignatureError:
            return False, None, "ํ† ํฐ์ด ๋งŒ๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค"
        except jwt.InvalidTokenError as e:
            logger.warning(f"JWT ๊ฒ€์ฆ ์‹คํŒจ: {e}")
            return False, None, f"์œ ํšจํ•˜์ง€ ์•Š์€ ํ† ํฐ์ž…๋‹ˆ๋‹ค: {e}"
        return True, payload, None

    def extract_user_id(self, token: str) -> Optional[str]:
        """
        Return the user id (the ``sub`` claim) of a valid token.

        Args:
            token: JWT access_token.

        Returns:
            The ``sub`` claim, or None if the token is invalid or has no ``sub``.
            While the validator is disabled this is always ``"unknown"``.
        """
        is_valid, payload, _ = self.validate(token)
        if is_valid and payload:
            return payload.get("sub")
        return None

    def is_expired(self, token: str) -> bool:
        """
        Report whether *token* is correctly signed but past its ``exp``.

        Args:
            token: JWT access_token.

        Returns:
            True only for a correctly signed, expired token. False for valid
            tokens, tokens invalid for any other reason, and always False while
            the validator is disabled (a disabled validator accepts every token).
        """
        if not self._enabled:
            return False
        try:
            self._decode(token)
        # Check the exception type rather than matching the localized message text.
        except jwt.ExpiredSignatureError:
            return True
        except jwt.InvalidTokenError:
            return False
        return False

    def get_expiry_time(self, token: str) -> Optional[datetime]:
        """
        Read the ``exp`` claim of *token* as an aware UTC datetime.

        The signature is NOT verified, so the result is informational only and
        must not be used for authorization decisions.

        Args:
            token: JWT access_token.

        Returns:
            The expiry time, or None if PyJWT is unavailable, the token cannot
            be decoded, or the ``exp`` claim is missing or unusable.
        """
        if jwt is None:
            return None

        try:
            # Decode the payload only; with verify_signature=False no claims are checked.
            payload = jwt.decode(token, options={"verify_signature": False})
        except jwt.PyJWTError:
            return None
        return self._exp_to_datetime(payload.get("exp"))

    @staticmethod
    def _exp_to_datetime(exp: Any) -> Optional[datetime]:
        """Convert a NumericDate ``exp`` claim to an aware UTC datetime, or None if unusable."""
        # bool is an int subclass but is never a meaningful timestamp; strings etc. are rejected too.
        if isinstance(exp, bool) or not isinstance(exp, (int, float)):
            return None
        try:
            return datetime.fromtimestamp(exp, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            # Out of the platform's representable range (or NaN).
            return None


# ์ „์—ญ ์ธ์Šคํ„ด์Šค (Lazy loading)
_token_validator = None


def get_token_validator() -> TokenValidator:
    """TokenValidator ์‹ฑ๊ธ€ํ†ค ์ธ์Šคํ„ด์Šค ๋ฐ˜ํ™˜"""
    global _token_validator
    if _token_validator is None:
        _token_validator = TokenValidator()
    return _token_validator