Spaces:
Running
Running
File size: 13,248 Bytes
bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 75fb504 bcc8074 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 |
"""
Modular JWT Service
A self-contained, plug-and-play service for creating and verifying JWT tokens.
Can be used in any Python application with minimal configuration.
Usage:
from services.jwt_service import JWTService, TokenPayload
# Initialize with secret key
jwt_service = JWTService(secret_key="your-secret-key")
# Or use environment variable JWT_SECRET
jwt_service = JWTService()
# Create a token
token = jwt_service.create_token(user_id="user123", email="user@example.com")
# Verify a token
payload = jwt_service.verify_token(token)
print(payload.user_id, payload.email)
Environment Variables:
JWT_SECRET: Your secret key for signing tokens (required)
JWT_ACCESS_EXPIRY_MINUTES: Access token expiry in minutes (default: 15)
JWT_REFRESH_EXPIRY_DAYS: Refresh token expiry in days (default: 7)
JWT_ALGORITHM: Algorithm to use (default: HS256)
Dependencies:
PyJWT>=2.8.0
Generate a secure secret:
python -c "import secrets; print(secrets.token_urlsafe(64))"
"""
import os
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
import jwt
# Module-level logger named after this module; JWTService and the helpers log through it.
logger = logging.getLogger(__name__)
@dataclass
class TokenPayload:
    """
    Payload extracted from a verified JWT token.

    Attributes:
        user_id: The user's unique identifier (``sub`` claim).
        email: The user's email address (``email`` claim).
        issued_at: When the token was issued (``iat`` claim).
        expires_at: When the token expires (``exp`` claim).
        token_version: Version number for token invalidation (``tv`` claim).
        token_type: ``"access"`` or ``"refresh"`` (``type`` claim).
        extra: Any additional claims in the token; never None after init.

    ``is_expired`` and ``time_until_expiry`` compare against
    ``datetime.utcnow()``, so ``expires_at`` must be a naive UTC datetime
    (an aware datetime would make the comparison raise ``TypeError``).
    """

    user_id: str
    email: str
    issued_at: datetime
    expires_at: datetime
    token_version: int = 1
    token_type: str = "access"  # "access" or "refresh"
    extra: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Default is None (not {}) so instances never share one dict object.
        if self.extra is None:
            self.extra = {}

    @property
    def is_expired(self) -> bool:
        """Check if the token has expired."""
        return datetime.utcnow() > self.expires_at

    @property
    def time_until_expiry(self) -> timedelta:
        """Get time remaining until expiry (negative once expired)."""
        return self.expires_at - datetime.utcnow()
class JWTError(Exception):
    """Common base class for every error raised by this JWT module."""


class TokenExpiredError(JWTError):
    """Signals that a token's expiry time has already passed."""


class InvalidTokenError(JWTError):
    """Signals a token that is invalid or malformed and cannot be trusted."""


class ConfigurationError(JWTError):
    """Signals that the service lacks required configuration (e.g. a secret key)."""
class JWTService:
    """
    Service for creating and verifying JWT tokens.

    This service handles the JWT token lifecycle for authentication.
    It's designed to be modular and reusable across different applications.

    Example:
        service = JWTService(secret_key="my-secret")

        # Create token
        token = service.create_token(user_id="u123", email="a@b.com")

        # Verify token
        try:
            payload = service.verify_token(token)
            print(f"User: {payload.user_id}")
        except TokenExpiredError:
            print("Token expired, please login again")
        except InvalidTokenError:
            print("Invalid token")
    """

    # Default configuration
    DEFAULT_ALGORITHM = "HS256"
    DEFAULT_ACCESS_EXPIRY_MINUTES = 15  # 15 minutes
    DEFAULT_REFRESH_EXPIRY_DAYS = 7  # 7 days

    # Claims written by create_token itself. Anything else found in a decoded
    # payload is treated as an application-specific "extra" claim.
    _STANDARD_CLAIMS = frozenset({"sub", "email", "type", "tv", "iat", "exp"})

    def __init__(
        self,
        secret_key: Optional[str] = None,
        algorithm: Optional[str] = None,
        access_expiry_minutes: Optional[int] = None,
        refresh_expiry_days: Optional[int] = None
    ):
        """
        Initialize the JWT Service.

        Arguments that are not given (or are falsy) fall back to environment
        variables; the algorithm and expiry values then fall back to the
        class defaults.

        Args:
            secret_key: Secret key for signing tokens (env: JWT_SECRET).
            algorithm: JWT algorithm (env: JWT_ALGORITHM, default: HS256).
            access_expiry_minutes: Access token expiry
                (env: JWT_ACCESS_EXPIRY_MINUTES, default: 15 min).
            refresh_expiry_days: Refresh token expiry
                (env: JWT_REFRESH_EXPIRY_DAYS, default: 7 days).

        Raises:
            ConfigurationError: If no secret key is available.
            ValueError: If an expiry environment variable is not an integer.
        """
        self.secret_key = secret_key or os.getenv("JWT_SECRET")
        self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
        self.access_expiry_minutes = access_expiry_minutes or int(
            os.getenv("JWT_ACCESS_EXPIRY_MINUTES", str(self.DEFAULT_ACCESS_EXPIRY_MINUTES))
        )
        self.refresh_expiry_days = refresh_expiry_days or int(
            os.getenv("JWT_REFRESH_EXPIRY_DAYS", str(self.DEFAULT_REFRESH_EXPIRY_DAYS))
        )

        if not self.secret_key:
            raise ConfigurationError(
                "JWT secret key is required. Either pass secret_key parameter "
                "or set JWT_SECRET environment variable. "
                "Generate one with: python -c \"import secrets; print(secrets.token_urlsafe(64))\""
            )

        # Warn if secret is too short
        if len(self.secret_key) < 32:
            logger.warning(
                "JWT secret key is short (< 32 chars). "
                "Consider using a longer secret for better security."
            )

        logger.info(
            f"JWTService initialized (alg={self.algorithm}, "
            f"access={self.access_expiry_minutes}m, refresh={self.refresh_expiry_days}d)"
        )

    def create_token(
        self,
        user_id: str,
        email: str,
        token_type: str = "access",
        token_version: int = 1,
        extra_claims: Optional[Dict[str, Any]] = None,
        expiry_delta: Optional[timedelta] = None
    ) -> str:
        """
        Create a signed JWT token.

        Args:
            user_id: Stored as the ``sub`` claim.
            email: Stored as the ``email`` claim.
            token_type: Stored as the ``type`` claim. ``"refresh"`` selects the
                refresh expiry; any other value uses the access expiry.
            token_version: Stored as the ``tv`` claim.
            extra_claims: Additional claims added to the payload. Keys that
                collide with the standard claims (sub, email, type, tv, iat,
                exp) are ignored with a warning, so they cannot override the
                token's identity, type, or expiry.
            expiry_delta: Explicit lifetime; when truthy it overrides the
                type-based default expiry.

        Returns:
            str: The encoded JWT.
        """
        now = datetime.utcnow()

        if expiry_delta:
            expires_at = now + expiry_delta
        elif token_type == "refresh":
            expires_at = now + timedelta(days=self.refresh_expiry_days)
        else:
            expires_at = now + timedelta(minutes=self.access_expiry_minutes)

        payload: Dict[str, Any] = {
            "sub": user_id,
            "email": email,
            "type": token_type,
            "tv": token_version,
            "iat": now,
            "exp": expires_at,
        }

        if extra_claims:
            clashing = self._STANDARD_CLAIMS.intersection(extra_claims)
            if clashing:
                logger.warning(
                    f"Ignoring extra_claims that would override standard claims: {sorted(clashing)}"
                )
            payload.update(
                {k: v for k, v in extra_claims.items() if k not in self._STANDARD_CLAIMS}
            )

        token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

        logger.debug(f"Created {token_type} token for {user_id}")
        return token

    def create_access_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
        """Create a short-lived access token (kwargs are forwarded to create_token)."""
        return self.create_token(user_id, email, "access", token_version, **kwargs)

    def create_refresh_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
        """Create a long-lived refresh token (kwargs are forwarded to create_token)."""
        return self.create_token(user_id, email, "refresh", token_version, **kwargs)

    def verify_token(self, token: str) -> TokenPayload:
        """
        Verify a JWT token and extract the payload.

        The token must carry an ``exp`` claim; tokens without one are rejected
        rather than being treated as never-expiring.

        Args:
            token: The JWT token to verify.

        Returns:
            TokenPayload: Dataclass containing the verified payload. Timestamps
            are naive UTC datetimes.

        Raises:
            TokenExpiredError: If the token has expired.
            InvalidTokenError: If the token is invalid, malformed, or lacks
                required claims (exp, sub, email).
        """
        if not token:
            raise InvalidTokenError("Token cannot be empty")

        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                # Without "exp" the token would never expire and expires_at
                # would be None, which breaks TokenPayload.is_expired.
                options={"require": ["exp"]},
            )

            # Extract standard claims
            user_id = payload.get("sub")
            email = payload.get("email")
            token_type = payload.get("type", "access")  # Default to access for backward compat
            token_version = payload.get("tv", 1)
            iat = payload.get("iat")
            exp = payload.get("exp")

            if not user_id or not email:
                raise InvalidTokenError("Token missing required claims (sub, email)")

            # Convert numeric timestamps to naive UTC datetimes
            issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
            expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp

            # Everything that is not a standard claim is passed through as extra
            extra = {k: v for k, v in payload.items() if k not in self._STANDARD_CLAIMS}

            return TokenPayload(
                user_id=user_id,
                email=email,
                issued_at=issued_at,
                expires_at=expires_at,
                token_version=token_version,
                token_type=token_type,
                extra=extra
            )
        except jwt.ExpiredSignatureError as e:
            logger.debug("Token verification failed: expired")
            raise TokenExpiredError("Token has expired") from e
        except jwt.InvalidTokenError as e:
            logger.debug(f"Token verification failed: {e}")
            raise InvalidTokenError(f"Invalid token: {str(e)}") from e
        except JWTError:
            # Our own errors (e.g. missing sub/email) already have the right
            # type and message; don't re-wrap them as "unexpected" below.
            raise
        except Exception as e:
            logger.error(f"Unexpected error during token verification: {e}")
            raise InvalidTokenError(f"Token verification error: {str(e)}") from e

    def verify_token_safe(self, token: str) -> Optional[TokenPayload]:
        """
        Verify a JWT token without raising exceptions.

        Args:
            token: The JWT token to verify.

        Returns:
            TokenPayload if valid, None if invalid or expired.
        """
        try:
            return self.verify_token(token)
        except JWTError:
            return None

    def refresh_token(
        self,
        token: str,
        expiry_hours: Optional[int] = None
    ) -> str:
        """
        Refresh a token by creating a new one with the same claims.

        The new token keeps the original ``sub``, ``email``, ``type``, ``tv``
        and extra claims; only ``iat``/``exp`` are regenerated. The expiry of
        the input token is NOT checked.

        NOTE(review): because expired tokens are accepted without any age
        limit, anyone holding an old token can keep renewing it. Confirm this
        is intended, or check ``tv`` against the user's current version
        before calling this.

        Args:
            token: The current (possibly expired) token.
            expiry_hours: Custom lifetime of the new token in hours. When
                omitted, the default expiry for the token's type is used.

        Returns:
            str: A new JWT token with updated expiry.

        Raises:
            InvalidTokenError: If the token is malformed, its signature is
                invalid, or it lacks the sub/email claims.
        """
        try:
            # Decode without verifying expiry
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                options={"verify_exp": False}
            )
        except jwt.InvalidTokenError as e:
            raise InvalidTokenError(f"Cannot refresh invalid token: {str(e)}") from e

        user_id = payload.get("sub")
        email = payload.get("email")

        if not user_id or not email:
            raise InvalidTokenError("Token missing required claims")

        # Preserve non-standard claims; type and tv are passed explicitly so the
        # new token's default expiry matches its type.
        extra = {k: v for k, v in payload.items() if k not in self._STANDARD_CLAIMS}
        expiry_delta = timedelta(hours=expiry_hours) if expiry_hours is not None else None

        return self.create_token(
            user_id=user_id,
            email=email,
            token_type=payload.get("type", "access"),
            token_version=payload.get("tv", 1),
            extra_claims=extra,
            expiry_delta=expiry_delta
        )
# Lazily created, module-wide default instance returned by get_jwt_service().
_default_service: Optional[JWTService] = None


def get_jwt_service() -> JWTService:
    """
    Return the shared default JWTService, constructing it on first call.

    The instance is built with no arguments, so its configuration comes
    entirely from environment variables; later calls reuse the same object.

    Returns:
        JWTService: The default service instance.

    Raises:
        ConfigurationError: If JWT_SECRET is not set.
    """
    global _default_service
    if _default_service is not None:
        return _default_service
    _default_service = JWTService()
    return _default_service
def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
    """
    Convenience function to create an access token using the default service.

    Args:
        user_id: The user's unique identifier.
        email: The user's email address.
        token_version: User's current token version for invalidation.
        **kwargs: Additional keyword arguments forwarded to
            JWTService.create_access_token (and from there to create_token,
            e.g. ``extra_claims`` or ``expiry_delta``).

    Returns:
        str: The encoded JWT token.

    Raises:
        ConfigurationError: If the default service cannot be created because
            JWT_SECRET is not set.
    """
    return get_jwt_service().create_access_token(user_id, email, token_version, **kwargs)
def create_refresh_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
    """
    Create a refresh token with the shared default JWTService.

    All arguments are forwarded unchanged to JWTService.create_refresh_token.
    """
    service = get_jwt_service()
    return service.create_refresh_token(user_id, email, token_version, **kwargs)
def verify_access_token(token: str) -> TokenPayload:
    """
    Verify an access token using the default service.

    In addition to the checks done by JWTService.verify_token, this rejects
    tokens whose ``type`` claim is not ``"access"``. Without that check a
    long-lived refresh token would be accepted wherever an access token is
    expected. Tokens with no ``type`` claim are reported as ``"access"`` by
    verify_token and are therefore still accepted.

    Args:
        token: The JWT token to verify.

    Returns:
        TokenPayload: Verified token payload.

    Raises:
        TokenExpiredError: If the token has expired.
        InvalidTokenError: If the token is invalid or is not an access token.
    """
    payload = get_jwt_service().verify_token(token)
    if payload.token_type != "access":
        raise InvalidTokenError(
            f"Expected an access token, got a {payload.token_type!r} token"
        )
    return payload
|