jebin2's picture
new auth
75fb504
"""
Modular JWT Service
A self-contained, plug-and-play service for creating and verifying JWT tokens.
Can be used in any Python application with minimal configuration.
Usage:
from services.jwt_service import JWTService, TokenPayload
# Initialize with secret key
jwt_service = JWTService(secret_key="your-secret-key")
# Or use environment variable JWT_SECRET
jwt_service = JWTService()
# Create a token
token = jwt_service.create_token(user_id="user123", email="user@example.com")
# Verify a token
payload = jwt_service.verify_token(token)
print(payload.user_id, payload.email)
Environment Variables:
JWT_SECRET: Your secret key for signing tokens (required)
JWT_EXPIRY_HOURS: Token expiry in hours (default: 168 = 7 days)
JWT_ALGORITHM: Algorithm to use (default: HS256)
Dependencies:
PyJWT>=2.8.0
Generate a secure secret:
python -c "import secrets; print(secrets.token_urlsafe(64))"
"""
import os
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
import jwt
logger = logging.getLogger(__name__)
@dataclass
class TokenPayload:
"""
Payload extracted from a verified JWT token.
Attributes:
user_id: The user's unique identifier (sub claim)
email: The user's email address
issued_at: When the token was issued
expires_at: When the token expires
token_version: Version number for token invalidation
extra: Any additional claims in the token
"""
user_id: str
email: str
issued_at: datetime
expires_at: datetime
token_version: int = 1
token_type: str = "access" # "access" or "refresh"
extra: Dict[str, Any] = None
def __post_init__(self):
if self.extra is None:
self.extra = {}
@property
def is_expired(self) -> bool:
"""Check if the token has expired."""
return datetime.utcnow() > self.expires_at
@property
def time_until_expiry(self) -> timedelta:
"""Get time remaining until expiry."""
return self.expires_at - datetime.utcnow()
class JWTError(Exception):
"""Base exception for JWT errors."""
pass
class TokenExpiredError(JWTError):
"""Raised when the token has expired."""
pass
class InvalidTokenError(JWTError):
"""Raised when the token is invalid."""
pass
class ConfigurationError(JWTError):
"""Raised when the service is not properly configured."""
pass
class JWTService:
"""
Service for creating and verifying JWT tokens.
This service handles JWT token lifecycle for authentication.
It's designed to be modular and reusable across different applications.
Example:
service = JWTService(secret_key="my-secret")
# Create token
token = service.create_token(user_id="u123", email="a@b.com")
# Verify token
try:
payload = service.verify_token(token)
print(f"User: {payload.user_id}")
except TokenExpiredError:
print("Token expired, please login again")
except InvalidTokenError:
print("Invalid token")
"""
# Default configuration
DEFAULT_ALGORITHM = "HS256"
DEFAULT_ACCESS_EXPIRY_MINUTES = 15 # 15 minutes
DEFAULT_REFRESH_EXPIRY_DAYS = 7 # 7 days
def __init__(
self,
secret_key: Optional[str] = None,
algorithm: Optional[str] = None,
access_expiry_minutes: Optional[int] = None,
refresh_expiry_days: Optional[int] = None
):
"""
Initialize the JWT Service.
Args:
secret_key: Secret key for signing tokens.
algorithm: JWT algorithm (default: HS256).
access_expiry_minutes: Access token expiry (default: 15 min).
refresh_expiry_days: Refresh token expiry (default: 7 days).
"""
self.secret_key = secret_key or os.getenv("JWT_SECRET")
self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
self.access_expiry_minutes = access_expiry_minutes or int(
os.getenv("JWT_ACCESS_EXPIRY_MINUTES", str(self.DEFAULT_ACCESS_EXPIRY_MINUTES))
)
self.refresh_expiry_days = refresh_expiry_days or int(
os.getenv("JWT_REFRESH_EXPIRY_DAYS", str(self.DEFAULT_REFRESH_EXPIRY_DAYS))
)
if not self.secret_key:
raise ConfigurationError(
"JWT secret key is required. Either pass secret_key parameter "
"or set JWT_SECRET environment variable. "
"Generate one with: python -c \"import secrets; print(secrets.token_urlsafe(64))\""
)
# Warn if secret is too short
if len(self.secret_key) < 32:
logger.warning(
"JWT secret key is short (< 32 chars). "
"Consider using a longer secret for better security."
)
logger.info(
f"JWTService initialized (alg={self.algorithm}, "
f"access={self.access_expiry_minutes}m, refresh={self.refresh_expiry_days}d)"
)
def create_token(
self,
user_id: str,
email: str,
token_type: str = "access",
token_version: int = 1,
extra_claims: Optional[Dict[str, Any]] = None,
expiry_delta: Optional[timedelta] = None
) -> str:
"""
Create a JWT token.
"""
now = datetime.utcnow()
if expiry_delta:
expires_at = now + expiry_delta
elif token_type == "refresh":
expires_at = now + timedelta(days=self.refresh_expiry_days)
else:
expires_at = now + timedelta(minutes=self.access_expiry_minutes)
payload = {
"sub": user_id,
"email": email,
"type": token_type,
"tv": token_version,
"iat": now,
"exp": expires_at,
}
if extra_claims:
payload.update(extra_claims)
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
logger.debug(f"Created {token_type} token for {user_id}")
return token
def create_access_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
"""Create a short-lived access token."""
return self.create_token(user_id, email, "access", token_version, **kwargs)
def create_refresh_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
"""Create a long-lived refresh token."""
return self.create_token(user_id, email, "refresh", token_version, **kwargs)
def verify_token(self, token: str) -> TokenPayload:
"""
Verify a JWT token and extract the payload.
Args:
token: The JWT token to verify.
Returns:
TokenPayload: Dataclass containing the verified payload.
Raises:
TokenExpiredError: If the token has expired.
InvalidTokenError: If the token is invalid or malformed.
"""
if not token:
raise InvalidTokenError("Token cannot be empty")
try:
payload = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm]
)
# Extract standard claims
user_id = payload.get("sub")
email = payload.get("email")
token_type = payload.get("type", "access") # Default to access for backward compat
token_version = payload.get("tv", 1)
iat = payload.get("iat")
exp = payload.get("exp")
if not user_id or not email:
raise InvalidTokenError("Token missing required claims (sub, email)")
# Convert timestamps
issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
# Extract extra claims
standard_claims = {"sub", "email", "type", "tv", "iat", "exp"}
extra = {k: v for k, v in payload.items() if k not in standard_claims}
return TokenPayload(
user_id=user_id,
email=email,
issued_at=issued_at,
expires_at=expires_at,
token_version=token_version,
token_type=token_type,
extra=extra
)
except jwt.ExpiredSignatureError:
logger.debug("Token verification failed: expired")
raise TokenExpiredError("Token has expired")
except jwt.InvalidTokenError as e:
logger.debug(f"Token verification failed: {e}")
raise InvalidTokenError(f"Invalid token: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error during token verification: {e}")
raise InvalidTokenError(f"Token verification error: {str(e)}")
def verify_token_safe(self, token: str) -> Optional[TokenPayload]:
"""
Verify a JWT token without raising exceptions.
Args:
token: The JWT token to verify.
Returns:
TokenPayload if valid, None if invalid or expired.
"""
try:
return self.verify_token(token)
except JWTError:
return None
def refresh_token(
self,
token: str,
expiry_hours: Optional[int] = None
) -> str:
"""
Refresh a token by creating a new one with the same claims.
Args:
token: The current (possibly expired) token.
expiry_hours: Custom expiry for the new token.
Returns:
str: A new JWT token with updated expiry.
Raises:
InvalidTokenError: If the token is malformed.
"""
try:
# Decode without verifying expiry
payload = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm],
options={"verify_exp": False}
)
user_id = payload.get("sub")
email = payload.get("email")
if not user_id or not email:
raise InvalidTokenError("Token missing required claims")
# Preserve extra claims
standard_claims = {"sub", "email", "iat", "exp"}
extra = {k: v for k, v in payload.items() if k not in standard_claims}
return self.create_token(
user_id=user_id,
email=email,
extra_claims=extra,
expiry_hours=expiry_hours
)
except jwt.InvalidTokenError as e:
raise InvalidTokenError(f"Cannot refresh invalid token: {str(e)}")
# Singleton instance for convenience
_default_service: Optional[JWTService] = None
def get_jwt_service() -> JWTService:
"""
Get the default JWTService instance.
Creates a singleton instance using environment variables.
Returns:
JWTService: The default service instance.
Raises:
ConfigurationError: If JWT_SECRET is not set.
"""
global _default_service
if _default_service is None:
_default_service = JWTService()
return _default_service
def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
"""
Convenience function to create a token using the default service.
Args:
user_id: The user's unique identifier.
email: The user's email address.
token_version: User's current token version for invalidation.
**kwargs: Additional arguments passed to create_token.
Returns:
str: The encoded JWT token.
"""
def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
"""Convenience function to create an access token."""
return get_jwt_service().create_access_token(user_id, email, token_version, **kwargs)
def create_refresh_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
"""Convenience function to create a refresh token."""
return get_jwt_service().create_refresh_token(user_id, email, token_version, **kwargs)
def verify_access_token(token: str) -> TokenPayload:
"""
Convenience function to verify a token using the default service.
Args:
token: The JWT token to verify.
Returns:
TokenPayload: Verified token payload.
Raises:
TokenExpiredError: If the token has expired.
InvalidTokenError: If the token is invalid.
"""
return get_jwt_service().verify_token(token)