Spaces:
Sleeping
Sleeping
new auth
Browse files- core/schemas.py +5 -1
- dependencies.py +11 -3
- routers/auth.py +90 -28
- services/auth_service/jwt_provider.py +52 -32
core/schemas.py
CHANGED
|
@@ -12,12 +12,14 @@ class GoogleAuthRequest(BaseModel):
|
|
| 12 |
"""Request with Google ID token from frontend Sign-In."""
|
| 13 |
id_token: str = Field(..., min_length=1, description="Google ID token from Sign-In")
|
| 14 |
temp_user_id: Optional[str] = Field(None, description="Optional temp user ID for linking")
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class AuthResponse(BaseModel):
|
| 18 |
"""Response after successful Google authentication."""
|
| 19 |
success: bool
|
| 20 |
access_token: str
|
|
|
|
| 21 |
user_id: str
|
| 22 |
email: str
|
| 23 |
name: Optional[str] = None
|
|
@@ -36,11 +38,13 @@ class UserInfoResponse(BaseModel):
|
|
| 36 |
|
| 37 |
class TokenRefreshRequest(BaseModel):
|
| 38 |
"""Request to refresh an access token."""
|
| 39 |
-
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
class TokenRefreshResponse(BaseModel):
|
| 43 |
"""Response with refreshed access token."""
|
| 44 |
success: bool
|
| 45 |
access_token: str
|
|
|
|
| 46 |
|
|
|
|
| 12 |
"""Request with Google ID token from frontend Sign-In."""
|
| 13 |
id_token: str = Field(..., min_length=1, description="Google ID token from Sign-In")
|
| 14 |
temp_user_id: Optional[str] = Field(None, description="Optional temp user ID for linking")
|
| 15 |
+
client_type: str = Field("web", description="Client type: 'web' (cookies) or 'mobile' (body)")
|
| 16 |
|
| 17 |
|
| 18 |
class AuthResponse(BaseModel):
|
| 19 |
"""Response after successful Google authentication."""
|
| 20 |
success: bool
|
| 21 |
access_token: str
|
| 22 |
+
refresh_token: str
|
| 23 |
user_id: str
|
| 24 |
email: str
|
| 25 |
name: Optional[str] = None
|
|
|
|
| 38 |
|
| 39 |
class TokenRefreshRequest(BaseModel):
|
| 40 |
"""Request to refresh an access token."""
|
| 41 |
+
"""Request to refresh an access token."""
|
| 42 |
+
token: Optional[str] = Field(None, description="Current refresh token (optional if in cookie)")
|
| 43 |
|
| 44 |
|
| 45 |
class TokenRefreshResponse(BaseModel):
|
| 46 |
"""Response with refreshed access token."""
|
| 47 |
success: bool
|
| 48 |
access_token: str
|
| 49 |
+
refresh_token: str
|
| 50 |
|
dependencies.py
CHANGED
|
@@ -36,16 +36,16 @@ async def check_rate_limit(
|
|
| 36 |
now = datetime.utcnow()
|
| 37 |
window_start = now - timedelta(minutes=window_minutes)
|
| 38 |
|
| 39 |
-
# Check existing limit
|
| 40 |
query = select(RateLimit).where(
|
| 41 |
and_(
|
| 42 |
RateLimit.identifier == identifier,
|
| 43 |
RateLimit.endpoint == endpoint,
|
| 44 |
RateLimit.window_start >= window_start
|
| 45 |
)
|
| 46 |
-
)
|
| 47 |
result = await db.execute(query)
|
| 48 |
-
rate_limit = result.
|
| 49 |
|
| 50 |
if rate_limit:
|
| 51 |
if rate_limit.attempts >= limit:
|
|
@@ -104,6 +104,14 @@ async def get_current_user(
|
|
| 104 |
|
| 105 |
try:
|
| 106 |
payload = verify_access_token(token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
except TokenExpiredError:
|
| 108 |
raise HTTPException(
|
| 109 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
| 36 |
now = datetime.utcnow()
|
| 37 |
window_start = now - timedelta(minutes=window_minutes)
|
| 38 |
|
| 39 |
+
# Check existing limit (get most recent if multiple exist)
|
| 40 |
query = select(RateLimit).where(
|
| 41 |
and_(
|
| 42 |
RateLimit.identifier == identifier,
|
| 43 |
RateLimit.endpoint == endpoint,
|
| 44 |
RateLimit.window_start >= window_start
|
| 45 |
)
|
| 46 |
+
).order_by(RateLimit.window_start.desc())
|
| 47 |
result = await db.execute(query)
|
| 48 |
+
rate_limit = result.scalars().first()
|
| 49 |
|
| 50 |
if rate_limit:
|
| 51 |
if rate_limit.attempts >= limit:
|
|
|
|
| 104 |
|
| 105 |
try:
|
| 106 |
payload = verify_access_token(token)
|
| 107 |
+
|
| 108 |
+
# Ensure it's an access token, not a refresh token
|
| 109 |
+
if payload.extra.get("type") == "refresh":
|
| 110 |
+
raise HTTPException(
|
| 111 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 112 |
+
detail="Cannot use refresh token for API access"
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
except TokenExpiredError:
|
| 116 |
raise HTTPException(
|
| 117 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
routers/auth.py
CHANGED
|
@@ -32,6 +32,7 @@ from services.auth_service.google_provider import (
|
|
| 32 |
from services.auth_service.jwt_provider import (
|
| 33 |
JWTService,
|
| 34 |
create_access_token,
|
|
|
|
| 35 |
get_jwt_service,
|
| 36 |
InvalidTokenError as JWTInvalidTokenError,
|
| 37 |
)
|
|
@@ -78,15 +79,11 @@ async def google_auth(
|
|
| 78 |
"""
|
| 79 |
Authenticate with Google ID token.
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
3. Frontend sends that token to this endpoint
|
| 85 |
-
4. We verify it with Google and issue our own JWT
|
| 86 |
-
|
| 87 |
-
Creates new user or returns existing user.
|
| 88 |
-
Existing users matched by email.
|
| 89 |
"""
|
|
|
|
| 90 |
ip = req.client.host
|
| 91 |
|
| 92 |
# Rate Limit: 10 attempts per minute per IP
|
|
@@ -200,23 +197,44 @@ async def google_auth(
|
|
| 200 |
)
|
| 201 |
await db.commit()
|
| 202 |
|
| 203 |
-
# Create our JWT access token
|
| 204 |
access_token = create_access_token(user.user_id, user.email, user.token_version)
|
|
|
|
| 205 |
|
| 206 |
# Sync DB to Drive (Async)
|
| 207 |
from services.backup_service import get_backup_service
|
| 208 |
backup_service = get_backup_service()
|
| 209 |
background_tasks.add_task(backup_service.backup_async)
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
|
| 222 |
@router.get("/me", response_model=UserInfoResponse)
|
|
@@ -254,8 +272,8 @@ async def refresh_token(
|
|
| 254 |
"""
|
| 255 |
ip = req.client.host
|
| 256 |
|
| 257 |
-
# Rate Limit:
|
| 258 |
-
if not await check_rate_limit(db, ip, "/auth/refresh",
|
| 259 |
raise HTTPException(
|
| 260 |
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
| 261 |
detail="Too many refresh attempts"
|
|
@@ -264,10 +282,24 @@ async def refresh_token(
|
|
| 264 |
try:
|
| 265 |
jwt_service = get_jwt_service()
|
| 266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
# Decode the token (without verifying expiry) to get user info
|
| 268 |
import jwt as pyjwt
|
| 269 |
payload = pyjwt.decode(
|
| 270 |
-
|
| 271 |
jwt_service.secret_key,
|
| 272 |
algorithms=[jwt_service.algorithm],
|
| 273 |
options={"verify_exp": False}
|
|
@@ -275,9 +307,17 @@ async def refresh_token(
|
|
| 275 |
|
| 276 |
user_id = payload.get("sub")
|
| 277 |
token_version = payload.get("tv", 1)
|
|
|
|
| 278 |
|
| 279 |
if not user_id:
|
| 280 |
raise JWTInvalidTokenError("Token missing required claims")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
# Check if user exists and token version is still valid
|
| 283 |
query = select(User).where(User.user_id == user_id, User.is_active == True)
|
|
@@ -297,13 +337,33 @@ async def refresh_token(
|
|
| 297 |
detail="Token has been invalidated. Please sign in again."
|
| 298 |
)
|
| 299 |
|
| 300 |
-
# Create new token
|
| 301 |
-
|
| 302 |
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
except JWTInvalidTokenError as e:
|
| 308 |
raise HTTPException(
|
| 309 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
@@ -346,4 +406,6 @@ async def logout(
|
|
| 346 |
backup_service = get_backup_service()
|
| 347 |
background_tasks.add_task(backup_service.backup_async)
|
| 348 |
|
| 349 |
-
|
|
|
|
|
|
|
|
|
| 32 |
from services.auth_service.jwt_provider import (
|
| 33 |
JWTService,
|
| 34 |
create_access_token,
|
| 35 |
+
create_refresh_token,
|
| 36 |
get_jwt_service,
|
| 37 |
InvalidTokenError as JWTInvalidTokenError,
|
| 38 |
)
|
|
|
|
| 79 |
"""
|
| 80 |
Authenticate with Google ID token.
|
| 81 |
|
| 82 |
+
Supports two client types:
|
| 83 |
+
- "web": Sets refresh_token in HttpOnly cookie (secure)
|
| 84 |
+
- "mobile": Returns refresh_token in JSON body
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
"""
|
| 86 |
+
response = JSONResponse(content={}) # Placeholder, will be populated later
|
| 87 |
ip = req.client.host
|
| 88 |
|
| 89 |
# Rate Limit: 10 attempts per minute per IP
|
|
|
|
| 197 |
)
|
| 198 |
await db.commit()
|
| 199 |
|
| 200 |
+
# Create our JWT access token and refresh token
|
| 201 |
access_token = create_access_token(user.user_id, user.email, user.token_version)
|
| 202 |
+
refresh_token = create_refresh_token(user.user_id, user.email, user.token_version)
|
| 203 |
|
| 204 |
# Sync DB to Drive (Async)
|
| 205 |
from services.backup_service import get_backup_service
|
| 206 |
backup_service = get_backup_service()
|
| 207 |
background_tasks.add_task(backup_service.backup_async)
|
| 208 |
|
| 209 |
+
# Prepare response data
|
| 210 |
+
response_data = {
|
| 211 |
+
"success": True,
|
| 212 |
+
"access_token": access_token,
|
| 213 |
+
"user_id": user.user_id,
|
| 214 |
+
"email": user.email,
|
| 215 |
+
"name": user.name,
|
| 216 |
+
"credits": user.credits,
|
| 217 |
+
"is_new_user": is_new_user
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
# Handle token delivery based on client type
|
| 221 |
+
if request.client_type == "web":
|
| 222 |
+
# Web: Set HttpOnly cookie for refresh token
|
| 223 |
+
response = JSONResponse(content=response_data)
|
| 224 |
+
response.set_cookie(
|
| 225 |
+
key="refresh_token",
|
| 226 |
+
value=refresh_token,
|
| 227 |
+
httponly=True,
|
| 228 |
+
secure=True, # Should be True in production
|
| 229 |
+
samesite="lax",
|
| 230 |
+
max_age=7 * 24 * 60 * 60 # 7 days
|
| 231 |
+
)
|
| 232 |
+
else:
|
| 233 |
+
# Mobile: Return refresh token in body
|
| 234 |
+
response_data["refresh_token"] = refresh_token
|
| 235 |
+
response = JSONResponse(content=response_data)
|
| 236 |
+
|
| 237 |
+
return response
|
| 238 |
|
| 239 |
|
| 240 |
@router.get("/me", response_model=UserInfoResponse)
|
|
|
|
| 272 |
"""
|
| 273 |
ip = req.client.host
|
| 274 |
|
| 275 |
+
# Rate Limit: 20 refreshes per minute per IP (increased for proactive refresh on page load)
|
| 276 |
+
if not await check_rate_limit(db, ip, "/auth/refresh", 20, 1):
|
| 277 |
raise HTTPException(
|
| 278 |
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
| 279 |
detail="Too many refresh attempts"
|
|
|
|
| 282 |
try:
|
| 283 |
jwt_service = get_jwt_service()
|
| 284 |
|
| 285 |
+
# Get token from body or cookie
|
| 286 |
+
token_to_refresh = request.token
|
| 287 |
+
using_cookie = False
|
| 288 |
+
|
| 289 |
+
if not token_to_refresh:
|
| 290 |
+
token_to_refresh = req.cookies.get("refresh_token")
|
| 291 |
+
using_cookie = True
|
| 292 |
+
|
| 293 |
+
if not token_to_refresh:
|
| 294 |
+
raise HTTPException(
|
| 295 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 296 |
+
detail="Refresh token missing"
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
# Decode the token (without verifying expiry) to get user info
|
| 300 |
import jwt as pyjwt
|
| 301 |
payload = pyjwt.decode(
|
| 302 |
+
token_to_refresh,
|
| 303 |
jwt_service.secret_key,
|
| 304 |
algorithms=[jwt_service.algorithm],
|
| 305 |
options={"verify_exp": False}
|
|
|
|
| 307 |
|
| 308 |
user_id = payload.get("sub")
|
| 309 |
token_version = payload.get("tv", 1)
|
| 310 |
+
token_type = payload.get("type", "access")
|
| 311 |
|
| 312 |
if not user_id:
|
| 313 |
raise JWTInvalidTokenError("Token missing required claims")
|
| 314 |
+
|
| 315 |
+
# Verify it's a refresh token
|
| 316 |
+
if token_type != "refresh":
|
| 317 |
+
raise HTTPException(
|
| 318 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 319 |
+
detail="Invalid token type. Expected refresh token."
|
| 320 |
+
)
|
| 321 |
|
| 322 |
# Check if user exists and token version is still valid
|
| 323 |
query = select(User).where(User.user_id == user_id, User.is_active == True)
|
|
|
|
| 337 |
detail="Token has been invalidated. Please sign in again."
|
| 338 |
)
|
| 339 |
|
| 340 |
+
# Create new access token
|
| 341 |
+
new_access_token = create_access_token(user.user_id, user.email, user.token_version)
|
| 342 |
|
| 343 |
+
# ROTATION: Issue new refresh token
|
| 344 |
+
new_refresh_token = create_refresh_token(user.user_id, user.email, user.token_version)
|
| 345 |
+
|
| 346 |
+
response_data = {
|
| 347 |
+
"success": True,
|
| 348 |
+
"access_token": new_access_token
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
if using_cookie:
|
| 352 |
+
# If came from cookie, rotate cookie
|
| 353 |
+
response = JSONResponse(content=response_data)
|
| 354 |
+
response.set_cookie(
|
| 355 |
+
key="refresh_token",
|
| 356 |
+
value=new_refresh_token,
|
| 357 |
+
httponly=True,
|
| 358 |
+
secure=True,
|
| 359 |
+
samesite="lax",
|
| 360 |
+
max_age=7 * 24 * 60 * 60
|
| 361 |
+
)
|
| 362 |
+
return response
|
| 363 |
+
else:
|
| 364 |
+
# If came from body, return in body
|
| 365 |
+
response_data["refresh_token"] = new_refresh_token
|
| 366 |
+
return TokenRefreshResponse(**response_data)
|
| 367 |
except JWTInvalidTokenError as e:
|
| 368 |
raise HTTPException(
|
| 369 |
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
| 406 |
backup_service = get_backup_service()
|
| 407 |
background_tasks.add_task(backup_service.backup_async)
|
| 408 |
|
| 409 |
+
response = JSONResponse(content={"success": True, "message": "Logged out successfully. All sessions invalidated."})
|
| 410 |
+
response.delete_cookie(key="refresh_token")
|
| 411 |
+
return response
|
services/auth_service/jwt_provider.py
CHANGED
|
@@ -60,6 +60,7 @@ class TokenPayload:
|
|
| 60 |
issued_at: datetime
|
| 61 |
expires_at: datetime
|
| 62 |
token_version: int = 1
|
|
|
|
| 63 |
extra: Dict[str, Any] = None
|
| 64 |
|
| 65 |
def __post_init__(self):
|
|
@@ -122,30 +123,33 @@ class JWTService:
|
|
| 122 |
|
| 123 |
# Default configuration
|
| 124 |
DEFAULT_ALGORITHM = "HS256"
|
| 125 |
-
|
|
|
|
| 126 |
|
| 127 |
def __init__(
|
| 128 |
self,
|
| 129 |
secret_key: Optional[str] = None,
|
| 130 |
algorithm: Optional[str] = None,
|
| 131 |
-
|
|
|
|
| 132 |
):
|
| 133 |
"""
|
| 134 |
Initialize the JWT Service.
|
| 135 |
|
| 136 |
Args:
|
| 137 |
-
secret_key: Secret key for signing tokens.
|
| 138 |
-
falls back to JWT_SECRET environment variable.
|
| 139 |
algorithm: JWT algorithm (default: HS256).
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
Raises:
|
| 143 |
-
ConfigurationError: If no secret_key is provided or found.
|
| 144 |
"""
|
| 145 |
self.secret_key = secret_key or os.getenv("JWT_SECRET")
|
| 146 |
self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
)
|
| 150 |
|
| 151 |
if not self.secret_key:
|
|
@@ -163,40 +167,38 @@ class JWTService:
|
|
| 163 |
)
|
| 164 |
|
| 165 |
logger.info(
|
| 166 |
-
f"JWTService initialized (
|
| 167 |
-
f"
|
| 168 |
)
|
| 169 |
|
| 170 |
def create_token(
|
| 171 |
self,
|
| 172 |
user_id: str,
|
| 173 |
email: str,
|
|
|
|
| 174 |
token_version: int = 1,
|
| 175 |
extra_claims: Optional[Dict[str, Any]] = None,
|
| 176 |
-
|
| 177 |
) -> str:
|
| 178 |
"""
|
| 179 |
-
Create a JWT token
|
| 180 |
-
|
| 181 |
-
Args:
|
| 182 |
-
user_id: The user's unique identifier.
|
| 183 |
-
email: The user's email address.
|
| 184 |
-
token_version: User's current token version for invalidation.
|
| 185 |
-
extra_claims: Additional claims to include in the token.
|
| 186 |
-
expiry_hours: Custom expiry for this token (overrides default).
|
| 187 |
-
|
| 188 |
-
Returns:
|
| 189 |
-
str: The encoded JWT token.
|
| 190 |
"""
|
| 191 |
now = datetime.utcnow()
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
payload = {
|
| 195 |
"sub": user_id,
|
| 196 |
"email": email,
|
| 197 |
-
"
|
|
|
|
| 198 |
"iat": now,
|
| 199 |
-
"exp":
|
| 200 |
}
|
| 201 |
|
| 202 |
if extra_claims:
|
|
@@ -204,8 +206,18 @@ class JWTService:
|
|
| 204 |
|
| 205 |
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
| 206 |
|
| 207 |
-
|
|
|
|
|
|
|
| 208 |
return token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
def verify_token(self, token: str) -> TokenPayload:
|
| 211 |
"""
|
|
@@ -234,19 +246,20 @@ class JWTService:
|
|
| 234 |
# Extract standard claims
|
| 235 |
user_id = payload.get("sub")
|
| 236 |
email = payload.get("email")
|
| 237 |
-
|
|
|
|
| 238 |
iat = payload.get("iat")
|
| 239 |
exp = payload.get("exp")
|
| 240 |
|
| 241 |
if not user_id or not email:
|
| 242 |
raise InvalidTokenError("Token missing required claims (sub, email)")
|
| 243 |
|
| 244 |
-
# Convert timestamps
|
| 245 |
issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
|
| 246 |
expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
|
| 247 |
|
| 248 |
# Extract extra claims
|
| 249 |
-
standard_claims = {"sub", "email", "tv", "iat", "exp"}
|
| 250 |
extra = {k: v for k, v in payload.items() if k not in standard_claims}
|
| 251 |
|
| 252 |
return TokenPayload(
|
|
@@ -255,6 +268,7 @@ class JWTService:
|
|
| 255 |
issued_at=issued_at,
|
| 256 |
expires_at=expires_at,
|
| 257 |
token_version=token_version,
|
|
|
|
| 258 |
extra=extra
|
| 259 |
)
|
| 260 |
|
|
@@ -366,7 +380,13 @@ def create_access_token(user_id: str, email: str, token_version: int = 1, **kwar
|
|
| 366 |
Returns:
|
| 367 |
str: The encoded JWT token.
|
| 368 |
"""
|
| 369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
|
| 371 |
|
| 372 |
def verify_access_token(token: str) -> TokenPayload:
|
|
|
|
| 60 |
issued_at: datetime
|
| 61 |
expires_at: datetime
|
| 62 |
token_version: int = 1
|
| 63 |
+
token_type: str = "access" # "access" or "refresh"
|
| 64 |
extra: Dict[str, Any] = None
|
| 65 |
|
| 66 |
def __post_init__(self):
|
|
|
|
| 123 |
|
| 124 |
# Default configuration
|
| 125 |
DEFAULT_ALGORITHM = "HS256"
|
| 126 |
+
DEFAULT_ACCESS_EXPIRY_MINUTES = 15 # 15 minutes
|
| 127 |
+
DEFAULT_REFRESH_EXPIRY_DAYS = 7 # 7 days
|
| 128 |
|
| 129 |
def __init__(
|
| 130 |
self,
|
| 131 |
secret_key: Optional[str] = None,
|
| 132 |
algorithm: Optional[str] = None,
|
| 133 |
+
access_expiry_minutes: Optional[int] = None,
|
| 134 |
+
refresh_expiry_days: Optional[int] = None
|
| 135 |
):
|
| 136 |
"""
|
| 137 |
Initialize the JWT Service.
|
| 138 |
|
| 139 |
Args:
|
| 140 |
+
secret_key: Secret key for signing tokens.
|
|
|
|
| 141 |
algorithm: JWT algorithm (default: HS256).
|
| 142 |
+
access_expiry_minutes: Access token expiry (default: 15 min).
|
| 143 |
+
refresh_expiry_days: Refresh token expiry (default: 7 days).
|
|
|
|
|
|
|
| 144 |
"""
|
| 145 |
self.secret_key = secret_key or os.getenv("JWT_SECRET")
|
| 146 |
self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
|
| 147 |
+
|
| 148 |
+
self.access_expiry_minutes = access_expiry_minutes or int(
|
| 149 |
+
os.getenv("JWT_ACCESS_EXPIRY_MINUTES", str(self.DEFAULT_ACCESS_EXPIRY_MINUTES))
|
| 150 |
+
)
|
| 151 |
+
self.refresh_expiry_days = refresh_expiry_days or int(
|
| 152 |
+
os.getenv("JWT_REFRESH_EXPIRY_DAYS", str(self.DEFAULT_REFRESH_EXPIRY_DAYS))
|
| 153 |
)
|
| 154 |
|
| 155 |
if not self.secret_key:
|
|
|
|
| 167 |
)
|
| 168 |
|
| 169 |
logger.info(
|
| 170 |
+
f"JWTService initialized (alg={self.algorithm}, "
|
| 171 |
+
f"access={self.access_expiry_minutes}m, refresh={self.refresh_expiry_days}d)"
|
| 172 |
)
|
| 173 |
|
| 174 |
def create_token(
|
| 175 |
self,
|
| 176 |
user_id: str,
|
| 177 |
email: str,
|
| 178 |
+
token_type: str = "access",
|
| 179 |
token_version: int = 1,
|
| 180 |
extra_claims: Optional[Dict[str, Any]] = None,
|
| 181 |
+
expiry_delta: Optional[timedelta] = None
|
| 182 |
) -> str:
|
| 183 |
"""
|
| 184 |
+
Create a JWT token.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
"""
|
| 186 |
now = datetime.utcnow()
|
| 187 |
+
|
| 188 |
+
if expiry_delta:
|
| 189 |
+
expires_at = now + expiry_delta
|
| 190 |
+
elif token_type == "refresh":
|
| 191 |
+
expires_at = now + timedelta(days=self.refresh_expiry_days)
|
| 192 |
+
else:
|
| 193 |
+
expires_at = now + timedelta(minutes=self.access_expiry_minutes)
|
| 194 |
|
| 195 |
payload = {
|
| 196 |
"sub": user_id,
|
| 197 |
"email": email,
|
| 198 |
+
"type": token_type,
|
| 199 |
+
"tv": token_version,
|
| 200 |
"iat": now,
|
| 201 |
+
"exp": expires_at,
|
| 202 |
}
|
| 203 |
|
| 204 |
if extra_claims:
|
|
|
|
| 206 |
|
| 207 |
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
| 208 |
|
| 209 |
+
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
| 210 |
+
|
| 211 |
+
logger.debug(f"Created {token_type} token for {user_id}")
|
| 212 |
return token
|
| 213 |
+
|
| 214 |
+
def create_access_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
|
| 215 |
+
"""Create a short-lived access token."""
|
| 216 |
+
return self.create_token(user_id, email, "access", token_version, **kwargs)
|
| 217 |
+
|
| 218 |
+
def create_refresh_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
|
| 219 |
+
"""Create a long-lived refresh token."""
|
| 220 |
+
return self.create_token(user_id, email, "refresh", token_version, **kwargs)
|
| 221 |
|
| 222 |
def verify_token(self, token: str) -> TokenPayload:
|
| 223 |
"""
|
|
|
|
| 246 |
# Extract standard claims
|
| 247 |
user_id = payload.get("sub")
|
| 248 |
email = payload.get("email")
|
| 249 |
+
token_type = payload.get("type", "access") # Default to access for backward compat
|
| 250 |
+
token_version = payload.get("tv", 1)
|
| 251 |
iat = payload.get("iat")
|
| 252 |
exp = payload.get("exp")
|
| 253 |
|
| 254 |
if not user_id or not email:
|
| 255 |
raise InvalidTokenError("Token missing required claims (sub, email)")
|
| 256 |
|
| 257 |
+
# Convert timestamps
|
| 258 |
issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
|
| 259 |
expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
|
| 260 |
|
| 261 |
# Extract extra claims
|
| 262 |
+
standard_claims = {"sub", "email", "type", "tv", "iat", "exp"}
|
| 263 |
extra = {k: v for k, v in payload.items() if k not in standard_claims}
|
| 264 |
|
| 265 |
return TokenPayload(
|
|
|
|
| 268 |
issued_at=issued_at,
|
| 269 |
expires_at=expires_at,
|
| 270 |
token_version=token_version,
|
| 271 |
+
token_type=token_type,
|
| 272 |
extra=extra
|
| 273 |
)
|
| 274 |
|
|
|
|
| 380 |
Returns:
|
| 381 |
str: The encoded JWT token.
|
| 382 |
"""
|
| 383 |
+
def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
|
| 384 |
+
"""Convenience function to create an access token."""
|
| 385 |
+
return get_jwt_service().create_access_token(user_id, email, token_version, **kwargs)
|
| 386 |
+
|
| 387 |
+
def create_refresh_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
|
| 388 |
+
"""Convenience function to create a refresh token."""
|
| 389 |
+
return get_jwt_service().create_refresh_token(user_id, email, token_version, **kwargs)
|
| 390 |
|
| 391 |
|
| 392 |
def verify_access_token(token: str) -> TokenPayload:
|