Spaces:
Running
Running
"""JWT token handling for session authentication.

This module creates and verifies JWT session tokens. These tokens are returned
to the frontend after OAuth and are used to authenticate all subsequent API calls.
"""
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from pydantic import BaseModel
class TokenPayload(BaseModel):
    """JWT token payload.

    Claims of a verified session token, as returned by
    ``JWTHandler.verify_token``. There, ``exp`` and ``iat`` are converted from
    Unix timestamps to timezone-aware UTC datetimes.
    """

    user_id: str  # HF username
    exp: datetime  # Expiry time (UTC)
    iat: datetime  # Issued-at time (UTC)
    jti: str  # Unique token ID; key used for revocation ("" if the token has none)
class JWTHandler:
    """Create, verify and revoke signed JWT session tokens.

    Each token carries ``user_id``, ``exp``, ``iat`` and a random ``jti``
    claim. Revoked ``jti`` values are kept in an in-memory dict on this
    instance, so revocations are lost on restart and are not shared with
    other processes.
    """

    def __init__(
        self,
        secret_key: Optional[str] = None,
        algorithm: str = "HS256",
        token_lifetime_hours: int = 8,
    ):
        """Configure the signing secret, algorithm and token lifetime.

        Args:
            secret_key: Signing secret. When falsy, the ``JWT_SECRET_KEY``
                environment variable is used; when that is also unset or
                empty, a random per-process secret is generated.
            algorithm: Signing algorithm passed to PyJWT.
            token_lifetime_hours: Hours until a newly created token expires.
        """
        resolved_secret = secret_key or os.environ.get("JWT_SECRET_KEY")
        if not resolved_secret:
            # A generated secret is not persisted: every restart (and every
            # worker process) gets a different one, which silently invalidates
            # tokens issued elsewhere. Make that visible to operators.
            logging.getLogger(__name__).warning(
                "JWT_SECRET_KEY is not set; using a random per-process secret. "
                "Session tokens will not survive restarts or work across "
                "multiple worker processes."
            )
            resolved_secret = secrets.token_urlsafe(32)
        self.secret_key = resolved_secret
        self.algorithm = algorithm
        self.token_lifetime = timedelta(hours=token_lifetime_hours)
        # Track revoked tokens (jti -> revocation time, UTC)
        self._revoked_tokens: dict[str, datetime] = {}

    def create_token(self, user_id: str) -> str:
        """Create a new JWT session token for a user.

        Args:
            user_id: The HF username.

        Returns:
            Encoded JWT token string, expiring ``token_lifetime`` from now.
        """
        now = datetime.now(timezone.utc)
        payload = {
            "user_id": user_id,
            "exp": now + self.token_lifetime,
            "iat": now,
            "jti": secrets.token_urlsafe(16),
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str) -> Optional[TokenPayload]:
        """Verify a JWT token and return its payload.

        Args:
            token: The JWT token string.

        Returns:
            TokenPayload if the signature is valid, the token is not expired,
            it carries well-formed ``user_id``/``exp``/``iat`` claims and it
            has not been revoked; otherwise None.
        """
        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                # The claims read below must be present; without this a
                # correctly signed but incomplete token would raise KeyError.
                options={"require": ["exp", "iat", "user_id"]},
            )
        except jwt.InvalidTokenError:
            # Base class of ExpiredSignatureError, InvalidSignatureError,
            # DecodeError, MissingRequiredClaimError, ...
            return None

        try:
            jti = payload.get("jti")
            if jti and jti in self._revoked_tokens:
                return None
            return TokenPayload(
                user_id=payload["user_id"],
                exp=datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
                iat=datetime.fromtimestamp(payload["iat"], tz=timezone.utc),
                jti=payload.get("jti", ""),
            )
        except (KeyError, TypeError, ValueError, OverflowError, OSError):
            # Wrongly typed claims or out-of-range timestamps mean the token is
            # invalid, not a server error. (pydantic's ValidationError is a
            # ValueError subclass.)
            return None

    def revoke_token(self, token: str) -> bool:
        """Revoke a token so it can no longer be used.

        Args:
            token: The JWT token to revoke.

        Returns:
            True if the token was valid and is now revoked; False if it was
            invalid, expired, already revoked, or has no ``jti`` claim.
        """
        payload = self.verify_token(token)
        if payload is None or not payload.jti:
            return False
        self._revoked_tokens[payload.jti] = datetime.now(timezone.utc)
        return True

    def cleanup_revoked(self) -> int:
        """Remove revocation entries for tokens that must already be expired.

        Tokens from ``create_token`` expire ``token_lifetime`` after issuance,
        and a token is revoked only after it was issued. So once
        ``token_lifetime`` has passed since revocation, the token is rejected
        as expired anyway and its entry is no longer needed.

        Returns:
            Number of entries removed.
        """
        cutoff = datetime.now(timezone.utc) - self.token_lifetime
        stale = [
            jti
            for jti, revoked_at in self._revoked_tokens.items()
            if revoked_at < cutoff
        ]
        for jti in stale:
            del self._revoked_tokens[jti]
        return len(stale)
# Shared, module-level handler. No explicit secret is passed, so the signing
# key comes from the JWT_SECRET_KEY environment variable (or a random
# per-process fallback when it is unset).
jwt_handler = JWTHandler(secret_key=None)