Spaces:
Sleeping
Sleeping
| """ | |
| Modular JWT Service | |
| A self-contained, plug-and-play service for creating and verifying JWT tokens. | |
| Can be used in any Python application with minimal configuration. | |
| Usage: | |
| from services.jwt_service import JWTService, TokenPayload | |
| # Initialize with secret key | |
| jwt_service = JWTService(secret_key="your-secret-key") | |
| # Or use environment variable JWT_SECRET | |
| jwt_service = JWTService() | |
| # Create a token | |
| token = jwt_service.create_token(user_id="user123", email="user@example.com") | |
| # Verify a token | |
| payload = jwt_service.verify_token(token) | |
| print(payload.user_id, payload.email) | |
| Environment Variables: | |
| JWT_SECRET: Your secret key for signing tokens (required) | |
| JWT_ACCESS_EXPIRY_MINUTES: Access token expiry in minutes (default: 15) | |
| JWT_REFRESH_EXPIRY_DAYS: Refresh token expiry in days (default: 7) | |
| JWT_ALGORITHM: Algorithm to use (default: HS256) | |
| Dependencies: | |
| PyJWT>=2.8.0 | |
| Generate a secure secret: | |
| python -c "import secrets; print(secrets.token_urlsafe(64))" | |
| """ | |
| import os | |
| import logging | |
| from dataclasses import dataclass | |
| from datetime import datetime, timedelta | |
| from typing import Optional, Dict, Any | |
| import jwt | |
| logger = logging.getLogger(__name__) | |
@dataclass
class TokenPayload:
    """
    Payload extracted from a verified JWT token.

    Declared as a dataclass so it can be built with keyword arguments and so
    that ``__post_init__`` runs after construction.

    Attributes:
        user_id: The user's unique identifier (sub claim)
        email: The user's email address
        issued_at: When the token was issued
        expires_at: When the token expires
        token_version: Version number for token invalidation
        token_type: Kind of token, "access" or "refresh"
        extra: Any additional claims in the token
    """

    user_id: str
    email: str
    issued_at: datetime
    expires_at: datetime
    token_version: int = 1
    token_type: str = "access"  # "access" or "refresh"
    # None is only a placeholder default; __post_init__ swaps it for a fresh
    # dict so instances never share one mutable mapping.
    extra: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Replace a missing ``extra`` mapping with a new empty dict."""
        if self.extra is None:
            self.extra = {}

    def is_expired(self) -> bool:
        """Check if the token has expired.

        Compares against ``datetime.utcnow()``, i.e. a naive datetime, so
        ``expires_at`` must be naive as well.
        """
        return datetime.utcnow() > self.expires_at

    def time_until_expiry(self) -> timedelta:
        """Get time remaining until expiry (negative once expired)."""
        return self.expires_at - datetime.utcnow()
class JWTError(Exception):
    """Common base class for every error raised by this JWT module."""
class TokenExpiredError(JWTError):
    """Signals that a token's expiry time has already passed."""
class InvalidTokenError(JWTError):
    """Signals that a token is malformed or otherwise fails validation."""
class ConfigurationError(JWTError):
    """Signals that the service lacks configuration it needs to operate."""
class JWTService:
    """
    Service for creating and verifying JWT tokens.

    This service handles JWT token lifecycle for authentication.
    It's designed to be modular and reusable across different applications.

    Example:
        service = JWTService(secret_key="my-secret")

        # Create token
        token = service.create_token(user_id="u123", email="a@b.com")

        # Verify token
        try:
            payload = service.verify_token(token)
            print(f"User: {payload.user_id}")
        except TokenExpiredError:
            print("Token expired, please login again")
        except InvalidTokenError:
            print("Invalid token")
    """

    # Default configuration
    DEFAULT_ALGORITHM = "HS256"
    DEFAULT_ACCESS_EXPIRY_MINUTES = 15  # 15 minutes
    DEFAULT_REFRESH_EXPIRY_DAYS = 7  # 7 days

    # Claims written by create_token itself; any other claim is "extra".
    _STANDARD_CLAIMS = frozenset({"sub", "email", "type", "tv", "iat", "exp"})

    def __init__(
        self,
        secret_key: Optional[str] = None,
        algorithm: Optional[str] = None,
        access_expiry_minutes: Optional[int] = None,
        refresh_expiry_days: Optional[int] = None
    ):
        """
        Initialize the JWT Service.

        Each argument falls back to an environment variable, then to the
        class default, when it is omitted (or falsy).

        Args:
            secret_key: Secret key for signing tokens (env: JWT_SECRET).
            algorithm: JWT algorithm (env: JWT_ALGORITHM, default: HS256).
            access_expiry_minutes: Access token expiry
                (env: JWT_ACCESS_EXPIRY_MINUTES, default: 15 min).
            refresh_expiry_days: Refresh token expiry
                (env: JWT_REFRESH_EXPIRY_DAYS, default: 7 days).

        Raises:
            ConfigurationError: If no secret key is supplied or configured.
            ValueError: If an expiry environment variable is not an integer.
        """
        self.secret_key = secret_key or os.getenv("JWT_SECRET")
        self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
        self.access_expiry_minutes = access_expiry_minutes or int(
            os.getenv("JWT_ACCESS_EXPIRY_MINUTES", str(self.DEFAULT_ACCESS_EXPIRY_MINUTES))
        )
        self.refresh_expiry_days = refresh_expiry_days or int(
            os.getenv("JWT_REFRESH_EXPIRY_DAYS", str(self.DEFAULT_REFRESH_EXPIRY_DAYS))
        )

        if not self.secret_key:
            raise ConfigurationError(
                "JWT secret key is required. Either pass secret_key parameter "
                "or set JWT_SECRET environment variable. "
                "Generate one with: python -c \"import secrets; print(secrets.token_urlsafe(64))\""
            )

        # Warn if secret is too short
        if len(self.secret_key) < 32:
            logger.warning(
                "JWT secret key is short (< 32 chars). "
                "Consider using a longer secret for better security."
            )

        logger.info(
            "JWTService initialized (alg=%s, access=%sm, refresh=%sd)",
            self.algorithm,
            self.access_expiry_minutes,
            self.refresh_expiry_days,
        )

    def create_token(
        self,
        user_id: str,
        email: str,
        token_type: str = "access",
        token_version: int = 1,
        extra_claims: Optional[Dict[str, Any]] = None,
        expiry_delta: Optional[timedelta] = None
    ) -> str:
        """
        Create a JWT token.

        Args:
            user_id: Stored in the "sub" claim.
            email: Stored in the "email" claim.
            token_type: Stored in the "type" claim; "refresh" selects the
                refresh lifetime, any other value the access lifetime.
            token_version: Stored in the "tv" claim.
            extra_claims: Additional claims merged into the payload. They are
                applied last, so they can overwrite the standard claims.
            expiry_delta: Explicit lifetime; overrides the per-type default
                whenever it is not None (including a zero timedelta).

        Returns:
            str: The encoded JWT token.
        """
        now = datetime.utcnow()

        if expiry_delta is not None:
            expires_at = now + expiry_delta
        elif token_type == "refresh":
            expires_at = now + timedelta(days=self.refresh_expiry_days)
        else:
            expires_at = now + timedelta(minutes=self.access_expiry_minutes)

        payload = {
            "sub": user_id,
            "email": email,
            "type": token_type,
            "tv": token_version,
            "iat": now,
            "exp": expires_at,
        }
        if extra_claims:
            payload.update(extra_claims)

        # Encode exactly once (the original signed the same payload twice).
        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
        logger.debug("Created %s token for %s", token_type, user_id)
        return token

    def create_access_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
        """Create a short-lived access token."""
        return self.create_token(user_id, email, "access", token_version, **kwargs)

    def create_refresh_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
        """Create a long-lived refresh token."""
        return self.create_token(user_id, email, "refresh", token_version, **kwargs)

    def verify_token(self, token: str) -> TokenPayload:
        """
        Verify a JWT token and extract the payload.

        Args:
            token: The JWT token to verify.

        Returns:
            TokenPayload: Dataclass containing the verified payload.

        Raises:
            TokenExpiredError: If the token has expired.
            InvalidTokenError: If the token is empty, invalid, malformed, or
                missing the "sub" / "email" claims.
        """
        if not token:
            raise InvalidTokenError("Token cannot be empty")

        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm]
            )

            # Extract standard claims
            user_id = payload.get("sub")
            email = payload.get("email")
            token_type = payload.get("type", "access")  # Default to access for backward compat
            token_version = payload.get("tv", 1)
            iat = payload.get("iat")
            exp = payload.get("exp")

            if not user_id or not email:
                raise InvalidTokenError("Token missing required claims (sub, email)")

            # Convert numeric timestamps to naive UTC datetimes
            issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
            expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp

            # Extract extra claims
            extra = {k: v for k, v in payload.items() if k not in self._STANDARD_CLAIMS}

            return TokenPayload(
                user_id=user_id,
                email=email,
                issued_at=issued_at,
                expires_at=expires_at,
                token_version=token_version,
                token_type=token_type,
                extra=extra
            )
        except JWTError:
            # Our own errors (e.g. missing claims) are already in their final
            # form; without this clause the generic handler below would log
            # them as "unexpected" and re-wrap the message.
            raise
        except jwt.ExpiredSignatureError as e:
            logger.debug("Token verification failed: expired")
            raise TokenExpiredError("Token has expired") from e
        except jwt.InvalidTokenError as e:
            logger.debug("Token verification failed: %s", e)
            raise InvalidTokenError(f"Invalid token: {str(e)}") from e
        except Exception as e:
            logger.error("Unexpected error during token verification: %s", e)
            raise InvalidTokenError(f"Token verification error: {str(e)}") from e

    def verify_token_safe(self, token: str) -> Optional[TokenPayload]:
        """
        Verify a JWT token without raising exceptions.

        Args:
            token: The JWT token to verify.

        Returns:
            TokenPayload if valid, None if invalid or expired.
        """
        try:
            return self.verify_token(token)
        except JWTError:
            return None

    def refresh_token(
        self,
        token: str,
        expiry_hours: Optional[int] = None
    ) -> str:
        """
        Refresh a token by creating a new one with the same claims.

        The signature is still verified; only the expiry check is skipped.
        The new token keeps the original "type" and "tv" claims, so without
        ``expiry_hours`` its lifetime is the default for that token type.

        Args:
            token: The current (possibly expired) token.
            expiry_hours: Custom expiry for the new token, in hours.

        Returns:
            str: A new JWT token with updated expiry.

        Raises:
            InvalidTokenError: If the token is malformed or lacks the
                "sub" / "email" claims.
        """
        # NOTE(review): an expired token of any age and any type can be
        # exchanged for a fresh one here - confirm callers restrict this.
        try:
            # Decode without verifying expiry
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                options={"verify_exp": False}
            )
        except jwt.InvalidTokenError as e:
            raise InvalidTokenError(f"Cannot refresh invalid token: {str(e)}") from e

        user_id = payload.get("sub")
        email = payload.get("email")
        if not user_id or not email:
            raise InvalidTokenError("Token missing required claims")

        # Preserve extra claims; type/tv are passed explicitly below so the
        # expiry default matches the token type.
        extra = {k: v for k, v in payload.items() if k not in self._STANDARD_CLAIMS}

        # create_token takes a timedelta, not hours (the original passed an
        # unsupported expiry_hours keyword and always raised TypeError).
        expiry_delta = timedelta(hours=expiry_hours) if expiry_hours is not None else None

        return self.create_token(
            user_id=user_id,
            email=email,
            token_type=payload.get("type", "access"),
            token_version=payload.get("tv", 1),
            extra_claims=extra,
            expiry_delta=expiry_delta
        )
# Module-wide JWTService, built lazily on first request.
_default_service: Optional[JWTService] = None


def get_jwt_service() -> JWTService:
    """
    Return the shared default JWTService, creating it on first use.

    The instance is constructed with no arguments, so its settings come
    from environment variables and the class defaults.

    Returns:
        JWTService: The default service instance.

    Raises:
        ConfigurationError: If JWT_SECRET is not set.
    """
    global _default_service
    if _default_service is not None:
        return _default_service
    _default_service = JWTService()
    return _default_service
def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
    """
    Convenience function to create an access token using the default service.

    Args:
        user_id: The user's unique identifier.
        email: The user's email address.
        token_version: User's current token version for invalidation.
        **kwargs: Additional arguments passed to create_token.

    Returns:
        str: The encoded JWT token.
    """
    # Single definition: the original declared this function twice, the
    # first being a docstring-only stub that returned None.
    return get_jwt_service().create_access_token(user_id, email, token_version, **kwargs)
def create_refresh_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
    """
    Build a refresh token through the default service.

    Args:
        user_id: The user's unique identifier.
        email: The user's email address.
        token_version: User's current token version for invalidation.
        **kwargs: Forwarded unchanged to the service method.

    Returns:
        str: The encoded JWT token.
    """
    service = get_jwt_service()
    return service.create_refresh_token(user_id, email, token_version, **kwargs)
def verify_access_token(token: str) -> TokenPayload:
    """
    Verify a token through the default service and return its payload.

    NOTE(review): this delegates to the generic verify_token, which does not
    check the token's "type" claim - confirm whether refresh tokens should
    be rejected here.

    Args:
        token: The JWT token to verify.

    Returns:
        TokenPayload: Verified token payload.

    Raises:
        TokenExpiredError: If the token has expired.
        InvalidTokenError: If the token is invalid.
    """
    service = get_jwt_service()
    return service.verify_token(token)