File size: 13,248 Bytes
bcc8074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75fb504
bcc8074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75fb504
 
bcc8074
 
 
 
 
75fb504
 
bcc8074
 
 
 
 
75fb504
bcc8074
75fb504
 
bcc8074
 
 
75fb504
 
 
 
 
 
bcc8074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75fb504
 
bcc8074
 
 
 
 
 
75fb504
bcc8074
 
75fb504
bcc8074
 
75fb504
bcc8074
 
75fb504
 
 
 
 
 
 
bcc8074
 
 
 
75fb504
 
bcc8074
75fb504
bcc8074
 
 
 
 
 
 
75fb504
 
 
bcc8074
75fb504
 
 
 
 
 
 
 
bcc8074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75fb504
 
bcc8074
 
 
 
 
 
75fb504
bcc8074
 
 
 
75fb504
bcc8074
 
 
 
 
 
 
 
75fb504
bcc8074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75fb504
 
 
 
 
 
 
bcc8074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
"""
Modular JWT Service

A self-contained, plug-and-play service for creating and verifying JWT tokens.
Can be used in any Python application with minimal configuration.

Usage:
    from services.jwt_service import JWTService, TokenPayload
    
    # Initialize with secret key
    jwt_service = JWTService(secret_key="your-secret-key")
    
    # Or use environment variable JWT_SECRET
    jwt_service = JWTService()
    
    # Create a token
    token = jwt_service.create_token(user_id="user123", email="user@example.com")
    
    # Verify a token
    payload = jwt_service.verify_token(token)
    print(payload.user_id, payload.email)

Environment Variables:
    JWT_SECRET: Your secret key for signing tokens (required)
    JWT_ACCESS_EXPIRY_MINUTES: Access token expiry in minutes (default: 15)
    JWT_REFRESH_EXPIRY_DAYS: Refresh token expiry in days (default: 7)
    JWT_ALGORITHM: Algorithm to use (default: HS256)

Dependencies:
    PyJWT>=2.8.0

Generate a secure secret:
    python -c "import secrets; print(secrets.token_urlsafe(64))"
"""

import os
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
import jwt

logger = logging.getLogger(__name__)


@dataclass
class TokenPayload:
    """
    Payload extracted from a verified JWT token.

    Attributes:
        user_id: The user's unique identifier (``sub`` claim).
        email: The user's email address.
        issued_at: When the token was issued.
        expires_at: When the token expires. Must be a naive datetime in UTC,
            because ``is_expired`` and ``time_until_expiry`` compare it with
            ``datetime.utcnow()``.
        token_version: Version number for token invalidation (``tv`` claim).
        token_type: ``"access"`` or ``"refresh"`` (``type`` claim).
        extra: Any additional, non-standard claims in the token. ``None`` is
            normalised to an empty dict.
    """
    user_id: str
    email: str
    issued_at: datetime
    expires_at: datetime
    token_version: int = 1
    token_type: str = "access"  # "access" or "refresh"
    extra: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Give every instance its own dict instead of a shared mutable default.
        if self.extra is None:
            self.extra = {}

    @property
    def is_expired(self) -> bool:
        """Return True once the current UTC time is past ``expires_at``."""
        return datetime.utcnow() > self.expires_at

    @property
    def time_until_expiry(self) -> timedelta:
        """Return the time left until ``expires_at`` (negative once expired)."""
        return self.expires_at - datetime.utcnow()


class JWTError(Exception):
    """Root exception for every JWT-related failure raised by this module."""


class TokenExpiredError(JWTError):
    """Signals that a token's expiry time has already passed."""


class InvalidTokenError(JWTError):
    """Signals a token that is malformed, badly signed, or missing claims."""


class ConfigurationError(JWTError):
    """Signals that the service lacks required configuration (e.g. a secret)."""


class JWTService:
    """
    Service for creating and verifying JWT tokens.

    This service handles JWT token lifecycle for authentication.
    It's designed to be modular and reusable across different applications.

    Example:
        service = JWTService(secret_key="my-secret")

        # Create token
        token = service.create_token(user_id="u123", email="a@b.com")

        # Verify token
        try:
            payload = service.verify_token(token)
            print(f"User: {payload.user_id}")
        except TokenExpiredError:
            print("Token expired, please login again")
        except InvalidTokenError:
            print("Invalid token")
    """

    # Default configuration
    DEFAULT_ALGORITHM = "HS256"
    DEFAULT_ACCESS_EXPIRY_MINUTES = 15  # 15 minutes
    DEFAULT_REFRESH_EXPIRY_DAYS = 7     # 7 days

    # Claims that create_token writes itself. Any other key found in a decoded
    # payload is treated as an "extra" claim.
    _STANDARD_CLAIMS = frozenset({"sub", "email", "type", "tv", "iat", "exp"})

    def __init__(
        self,
        secret_key: Optional[str] = None,
        algorithm: Optional[str] = None,
        access_expiry_minutes: Optional[int] = None,
        refresh_expiry_days: Optional[int] = None
    ):
        """
        Initialize the JWT Service.

        Each setting falls back to an environment variable and then to the
        class default. The fallback is triggered by any falsy argument, so
        passing ``0`` behaves like passing ``None``.

        Args:
            secret_key: Secret key for signing tokens (env: JWT_SECRET).
            algorithm: JWT algorithm (env: JWT_ALGORITHM, default: HS256).
            access_expiry_minutes: Access token expiry
                (env: JWT_ACCESS_EXPIRY_MINUTES, default: 15 min).
            refresh_expiry_days: Refresh token expiry
                (env: JWT_REFRESH_EXPIRY_DAYS, default: 7 days).

        Raises:
            ConfigurationError: If no secret key is available.
            ValueError: If an expiry environment variable is not an integer.
        """
        self.secret_key = secret_key or os.getenv("JWT_SECRET")
        self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)

        self.access_expiry_minutes = access_expiry_minutes or int(
            os.getenv("JWT_ACCESS_EXPIRY_MINUTES", str(self.DEFAULT_ACCESS_EXPIRY_MINUTES))
        )
        self.refresh_expiry_days = refresh_expiry_days or int(
            os.getenv("JWT_REFRESH_EXPIRY_DAYS", str(self.DEFAULT_REFRESH_EXPIRY_DAYS))
        )

        if not self.secret_key:
            raise ConfigurationError(
                "JWT secret key is required. Either pass secret_key parameter "
                "or set JWT_SECRET environment variable. "
                "Generate one with: python -c \"import secrets; print(secrets.token_urlsafe(64))\""
            )

        # Warn if secret is too short
        if len(self.secret_key) < 32:
            logger.warning(
                "JWT secret key is short (< 32 chars). "
                "Consider using a longer secret for better security."
            )

        logger.info(
            "JWTService initialized (alg=%s, access=%sm, refresh=%sd)",
            self.algorithm,
            self.access_expiry_minutes,
            self.refresh_expiry_days,
        )

    def create_token(
        self,
        user_id: str,
        email: str,
        token_type: str = "access",
        token_version: int = 1,
        extra_claims: Optional[Dict[str, Any]] = None,
        expiry_delta: Optional[timedelta] = None
    ) -> str:
        """
        Create a signed JWT token.

        Args:
            user_id: Stored as the ``sub`` claim.
            email: Stored as the ``email`` claim.
            token_type: ``"access"`` or ``"refresh"``. It selects the default
                lifetime and is stored as the ``type`` claim.
            token_version: Stored as the ``tv`` claim.
            extra_claims: Additional claims merged into the payload. They are
                merged last, so they can override the standard claims.
            expiry_delta: Explicit token lifetime. When given (even
                ``timedelta(0)``), it overrides the type-based default.

        Returns:
            str: The encoded JWT.
        """
        now = datetime.utcnow()

        if expiry_delta is not None:
            expires_at = now + expiry_delta
        elif token_type == "refresh":
            expires_at = now + timedelta(days=self.refresh_expiry_days)
        else:
            expires_at = now + timedelta(minutes=self.access_expiry_minutes)

        payload = {
            "sub": user_id,
            "email": email,
            "type": token_type,
            "tv": token_version,
            "iat": now,
            "exp": expires_at,
        }

        if extra_claims:
            payload.update(extra_claims)

        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

        logger.debug("Created %s token for %s", token_type, user_id)
        return token

    def create_access_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
        """Create a short-lived access token."""
        return self.create_token(user_id, email, "access", token_version, **kwargs)

    def create_refresh_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
        """Create a long-lived refresh token."""
        return self.create_token(user_id, email, "refresh", token_version, **kwargs)

    def verify_token(self, token: str, *, expected_type: Optional[str] = None) -> TokenPayload:
        """
        Verify a JWT token and extract the payload.

        Args:
            token: The JWT token to verify.
            expected_type: If given, reject tokens whose ``type`` claim differs
                (e.g. pass ``"access"`` to refuse refresh tokens).

        Returns:
            TokenPayload: Dataclass containing the verified payload.

        Raises:
            TokenExpiredError: If the token has expired.
            InvalidTokenError: If the token is invalid or malformed, is missing
                required claims, or has the wrong type.
        """
        if not token:
            raise InvalidTokenError("Token cannot be empty")

        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm]
            )

            # Extract standard claims
            user_id = payload.get("sub")
            email = payload.get("email")
            token_type = payload.get("type", "access")  # Default to access for backward compat
            token_version = payload.get("tv", 1)
            iat = payload.get("iat")
            exp = payload.get("exp")

            if not user_id or not email:
                raise InvalidTokenError("Token missing required claims (sub, email)")

            if expected_type is not None and token_type != expected_type:
                raise InvalidTokenError(
                    f"Expected {expected_type} token, got {token_type} token"
                )

            # Numeric timestamps become naive UTC datetimes, matching
            # TokenPayload's comparisons against datetime.utcnow().
            issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
            expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp

            extra = {k: v for k, v in payload.items() if k not in self._STANDARD_CLAIMS}

            return TokenPayload(
                user_id=user_id,
                email=email,
                issued_at=issued_at,
                expires_at=expires_at,
                token_version=token_version,
                token_type=token_type,
                extra=extra
            )

        except JWTError:
            # Our own validation errors (raised above) already carry a precise
            # message. Re-raise them unchanged rather than letting the generic
            # handler below re-wrap and log them as "unexpected".
            raise
        except jwt.ExpiredSignatureError as e:
            logger.debug("Token verification failed: expired")
            raise TokenExpiredError("Token has expired") from e
        except jwt.InvalidTokenError as e:
            logger.debug("Token verification failed: %s", e)
            raise InvalidTokenError(f"Invalid token: {str(e)}") from e
        except Exception as e:
            logger.error("Unexpected error during token verification: %s", e)
            raise InvalidTokenError(f"Token verification error: {str(e)}") from e

    def verify_token_safe(
        self,
        token: str,
        *,
        expected_type: Optional[str] = None
    ) -> Optional[TokenPayload]:
        """
        Verify a JWT token without raising exceptions.

        Args:
            token: The JWT token to verify.
            expected_type: Optional required ``type`` claim (see verify_token).

        Returns:
            TokenPayload if valid, None if invalid or expired.
        """
        try:
            return self.verify_token(token, expected_type=expected_type)
        except JWTError:
            return None

    def refresh_token(
        self,
        token: str,
        expiry_hours: Optional[int] = None
    ) -> str:
        """
        Refresh a token by creating a new one with the same claims.

        The signature is verified but expiry is not, so expired tokens can be
        refreshed. The new token keeps the original ``type`` and ``tv`` claims
        and any extra claims. It gets a fresh ``iat`` and ``exp``.

        Args:
            token: The current (possibly expired) token.
            expiry_hours: Custom lifetime for the new token, in hours. When
                omitted, the default for the token's type is used.

        Returns:
            str: A new JWT token with updated expiry.

        Raises:
            InvalidTokenError: If the token is malformed or missing claims.
        """
        try:
            # Decode without verifying expiry.
            # NOTE(review): there is no upper bound on how old a refreshable
            # token may be. Confirm this is acceptable for your threat model.
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                options={"verify_exp": False}
            )
        except jwt.InvalidTokenError as e:
            raise InvalidTokenError(f"Cannot refresh invalid token: {str(e)}") from e

        user_id = payload.get("sub")
        email = payload.get("email")

        if not user_id or not email:
            raise InvalidTokenError("Token missing required claims")

        # Preserve non-standard claims. type/tv are forwarded explicitly below
        # so the new token's default lifetime follows its real type.
        extra = {k: v for k, v in payload.items() if k not in self._STANDARD_CLAIMS}
        expiry_delta = timedelta(hours=expiry_hours) if expiry_hours is not None else None

        return self.create_token(
            user_id=user_id,
            email=email,
            token_type=payload.get("type", "access"),
            token_version=payload.get("tv", 1),
            extra_claims=extra,
            expiry_delta=expiry_delta
        )


# Singleton instance for convenience: stays None until get_jwt_service()
# first builds it from environment variables, then is reused by every call.
_default_service: Optional[JWTService] = None


def get_jwt_service() -> JWTService:
    """
    Return the shared, lazily-constructed JWTService.

    The first call builds the instance from environment variables and caches
    it in the module-level ``_default_service``. Later calls return that same
    object.

    NOTE(review): creation is not guarded by a lock. Concurrent first calls
    may each build an instance, and the last one assigned wins.

    Returns:
        JWTService: The shared service instance.

    Raises:
        ConfigurationError: If JWT_SECRET is not set.
    """
    global _default_service
    if _default_service is not None:
        return _default_service
    _default_service = JWTService()
    return _default_service


def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
    """
    Create a short-lived access token using the default service.

    Args:
        user_id: The user's unique identifier.
        email: The user's email address.
        token_version: User's current token version for invalidation.
        **kwargs: Additional keyword arguments forwarded to
            ``JWTService.create_token`` (e.g. ``extra_claims``,
            ``expiry_delta``).

    Returns:
        str: The encoded JWT token.

    Raises:
        ConfigurationError: If the default service cannot be created
            (JWT_SECRET not set).
    """
    return get_jwt_service().create_access_token(user_id, email, token_version, **kwargs)

def create_refresh_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
    """
    Issue a long-lived refresh token through the shared default service.

    Args:
        user_id: The user's unique identifier.
        email: The user's email address.
        token_version: User's current token version for invalidation.
        **kwargs: Forwarded unchanged to ``JWTService.create_refresh_token``.

    Returns:
        str: The encoded JWT token.
    """
    service = get_jwt_service()
    return service.create_refresh_token(user_id, email, token_version, **kwargs)


def verify_access_token(token: str) -> TokenPayload:
    """
    Check ``token`` with the shared default service and return its payload.

    NOTE(review): this delegates to ``verify_token`` without requiring the
    ``type`` claim to be "access", so a valid refresh token is accepted here
    too. Confirm no caller relies on that before tightening it.

    Args:
        token: The JWT token to verify.

    Returns:
        TokenPayload: Verified token payload.

    Raises:
        TokenExpiredError: If the token has expired.
        InvalidTokenError: If the token is invalid.
    """
    service = get_jwt_service()
    return service.verify_token(token)