Spaces:
Sleeping
Sleeping
tokrn version
Browse files- core/models.py +4 -0
- dependencies.py +11 -0
- routers/auth.py +47 -9
- services/jwt_service.py +12 -4
core/models.py
CHANGED
|
@@ -54,6 +54,10 @@ class User(Base):
|
|
| 54 |
# Legacy field (kept for migration, nullable now)
|
| 55 |
secret_key_hash = Column(String(255), nullable=True)
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# Credits and status
|
| 58 |
credits = Column(Integer, default=100)
|
| 59 |
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
|
|
|
| 54 |
# Legacy field (kept for migration, nullable now)
|
| 55 |
secret_key_hash = Column(String(255), nullable=True)
|
| 56 |
|
| 57 |
+
# Token versioning for JWT invalidation
|
| 58 |
+
# Incrementing this invalidates all existing tokens for this user
|
| 59 |
+
token_version = Column(Integer, default=1, nullable=False)
|
| 60 |
+
|
| 61 |
# Credits and status
|
| 62 |
credits = Column(Integer, default=100)
|
| 63 |
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
dependencies.py
CHANGED
|
@@ -77,6 +77,8 @@ async def get_current_user(
|
|
| 77 |
Extract and verify JWT from Authorization header.
|
| 78 |
Returns the authenticated user.
|
| 79 |
|
|
|
|
|
|
|
| 80 |
Usage:
|
| 81 |
@router.get("/protected")
|
| 82 |
async def protected_route(user: User = Depends(get_current_user)):
|
|
@@ -135,6 +137,15 @@ async def get_current_user(
|
|
| 135 |
detail="User not found or inactive"
|
| 136 |
)
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
return user
|
| 139 |
|
| 140 |
|
|
|
|
| 77 |
Extract and verify JWT from Authorization header.
|
| 78 |
Returns the authenticated user.
|
| 79 |
|
| 80 |
+
Also validates token_version to support instant logout/invalidation.
|
| 81 |
+
|
| 82 |
Usage:
|
| 83 |
@router.get("/protected")
|
| 84 |
async def protected_route(user: User = Depends(get_current_user)):
|
|
|
|
| 137 |
detail="User not found or inactive"
|
| 138 |
)
|
| 139 |
|
| 140 |
+
# Validate token version - if user's version is higher, token is invalidated
|
| 141 |
+
if payload.token_version < user.token_version:
|
| 142 |
+
logger.info(f"Token invalidated for user {user.user_id}: token_version {payload.token_version} < {user.token_version}")
|
| 143 |
+
raise HTTPException(
|
| 144 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 145 |
+
detail="Token has been invalidated. Please sign in again.",
|
| 146 |
+
headers={"WWW-Authenticate": "Bearer"}
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
return user
|
| 150 |
|
| 151 |
|
routers/auth.py
CHANGED
|
@@ -167,8 +167,8 @@ async def google_auth(
|
|
| 167 |
db.add(audit_log)
|
| 168 |
await db.commit()
|
| 169 |
|
| 170 |
-
# Create our JWT access token
|
| 171 |
-
access_token = create_access_token(user.user_id, user.email)
|
| 172 |
|
| 173 |
# Sync DB to Drive (Async)
|
| 174 |
background_tasks.add_task(drive_service.upload_db)
|
|
@@ -214,6 +214,8 @@ async def refresh_token(
|
|
| 214 |
Use this when the current token is about to expire
|
| 215 |
(or has recently expired) to get a new one without
|
| 216 |
requiring the user to sign in again.
|
|
|
|
|
|
|
| 217 |
"""
|
| 218 |
ip = req.client.host
|
| 219 |
|
|
@@ -226,7 +228,42 @@ async def refresh_token(
|
|
| 226 |
|
| 227 |
try:
|
| 228 |
jwt_service = get_jwt_service()
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
return TokenRefreshResponse(
|
| 232 |
success=True,
|
|
@@ -249,14 +286,15 @@ async def logout(
|
|
| 249 |
"""
|
| 250 |
Logout current user.
|
| 251 |
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
For full session invalidation, consider implementing
|
| 256 |
-
a token blacklist or reducing token expiry times.
|
| 257 |
"""
|
| 258 |
ip = req.client.host
|
| 259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
# Log logout
|
| 261 |
audit_log = AuditLog(
|
| 262 |
user_id=user.user_id,
|
|
@@ -270,4 +308,4 @@ async def logout(
|
|
| 270 |
# Sync DB to Drive (Async)
|
| 271 |
background_tasks.add_task(drive_service.upload_db)
|
| 272 |
|
| 273 |
-
return {"success": True, "message": "Logged out successfully"}
|
|
|
|
| 167 |
db.add(audit_log)
|
| 168 |
await db.commit()
|
| 169 |
|
| 170 |
+
# Create our JWT access token with current token_version
|
| 171 |
+
access_token = create_access_token(user.user_id, user.email, user.token_version)
|
| 172 |
|
| 173 |
# Sync DB to Drive (Async)
|
| 174 |
background_tasks.add_task(drive_service.upload_db)
|
|
|
|
| 214 |
Use this when the current token is about to expire
|
| 215 |
(or has recently expired) to get a new one without
|
| 216 |
requiring the user to sign in again.
|
| 217 |
+
|
| 218 |
+
Validates that the token_version is still valid before refreshing.
|
| 219 |
"""
|
| 220 |
ip = req.client.host
|
| 221 |
|
|
|
|
| 228 |
|
| 229 |
try:
|
| 230 |
jwt_service = get_jwt_service()
|
| 231 |
+
|
| 232 |
+
# Decode the token (without verifying expiry) to get user info
|
| 233 |
+
import jwt as pyjwt
|
| 234 |
+
payload = pyjwt.decode(
|
| 235 |
+
request.token,
|
| 236 |
+
jwt_service.secret_key,
|
| 237 |
+
algorithms=[jwt_service.algorithm],
|
| 238 |
+
options={"verify_exp": False}
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
user_id = payload.get("sub")
|
| 242 |
+
token_version = payload.get("tv", 1)
|
| 243 |
+
|
| 244 |
+
if not user_id:
|
| 245 |
+
raise JWTInvalidTokenError("Token missing required claims")
|
| 246 |
+
|
| 247 |
+
# Check if user exists and token version is still valid
|
| 248 |
+
query = select(User).where(User.user_id == user_id, User.is_active == True)
|
| 249 |
+
result = await db.execute(query)
|
| 250 |
+
user = result.scalar_one_or_none()
|
| 251 |
+
|
| 252 |
+
if not user:
|
| 253 |
+
raise HTTPException(
|
| 254 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 255 |
+
detail="User not found or inactive"
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
# Validate token version
|
| 259 |
+
if token_version < user.token_version:
|
| 260 |
+
raise HTTPException(
|
| 261 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 262 |
+
detail="Token has been invalidated. Please sign in again."
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# Create new token with current token_version
|
| 266 |
+
new_token = create_access_token(user.user_id, user.email, user.token_version)
|
| 267 |
|
| 268 |
return TokenRefreshResponse(
|
| 269 |
success=True,
|
|
|
|
| 286 |
"""
|
| 287 |
Logout current user.
|
| 288 |
|
| 289 |
+
Increments the user's token_version which invalidates ALL existing
|
| 290 |
+
tokens for this user. This provides instant logout across all devices.
|
|
|
|
|
|
|
|
|
|
| 291 |
"""
|
| 292 |
ip = req.client.host
|
| 293 |
|
| 294 |
+
# Increment token version to invalidate all existing tokens
|
| 295 |
+
user.token_version += 1
|
| 296 |
+
logger.info(f"User {user.user_id} logged out. Token version incremented to {user.token_version}")
|
| 297 |
+
|
| 298 |
# Log logout
|
| 299 |
audit_log = AuditLog(
|
| 300 |
user_id=user.user_id,
|
|
|
|
| 308 |
# Sync DB to Drive (Async)
|
| 309 |
background_tasks.add_task(drive_service.upload_db)
|
| 310 |
|
| 311 |
+
return {"success": True, "message": "Logged out successfully. All sessions invalidated."}
|
services/jwt_service.py
CHANGED
|
@@ -52,12 +52,14 @@ class TokenPayload:
|
|
| 52 |
email: The user's email address
|
| 53 |
issued_at: When the token was issued
|
| 54 |
expires_at: When the token expires
|
|
|
|
| 55 |
extra: Any additional claims in the token
|
| 56 |
"""
|
| 57 |
user_id: str
|
| 58 |
email: str
|
| 59 |
issued_at: datetime
|
| 60 |
expires_at: datetime
|
|
|
|
| 61 |
extra: Dict[str, Any] = None
|
| 62 |
|
| 63 |
def __post_init__(self):
|
|
@@ -169,6 +171,7 @@ class JWTService:
|
|
| 169 |
self,
|
| 170 |
user_id: str,
|
| 171 |
email: str,
|
|
|
|
| 172 |
extra_claims: Optional[Dict[str, Any]] = None,
|
| 173 |
expiry_hours: Optional[int] = None
|
| 174 |
) -> str:
|
|
@@ -178,6 +181,7 @@ class JWTService:
|
|
| 178 |
Args:
|
| 179 |
user_id: The user's unique identifier.
|
| 180 |
email: The user's email address.
|
|
|
|
| 181 |
extra_claims: Additional claims to include in the token.
|
| 182 |
expiry_hours: Custom expiry for this token (overrides default).
|
| 183 |
|
|
@@ -190,6 +194,7 @@ class JWTService:
|
|
| 190 |
payload = {
|
| 191 |
"sub": user_id,
|
| 192 |
"email": email,
|
|
|
|
| 193 |
"iat": now,
|
| 194 |
"exp": now + timedelta(hours=expiry),
|
| 195 |
}
|
|
@@ -199,7 +204,7 @@ class JWTService:
|
|
| 199 |
|
| 200 |
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
| 201 |
|
| 202 |
-
logger.debug(f"Created token for user_id={user_id}")
|
| 203 |
return token
|
| 204 |
|
| 205 |
def verify_token(self, token: str) -> TokenPayload:
|
|
@@ -229,6 +234,7 @@ class JWTService:
|
|
| 229 |
# Extract standard claims
|
| 230 |
user_id = payload.get("sub")
|
| 231 |
email = payload.get("email")
|
|
|
|
| 232 |
iat = payload.get("iat")
|
| 233 |
exp = payload.get("exp")
|
| 234 |
|
|
@@ -240,7 +246,7 @@ class JWTService:
|
|
| 240 |
expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
|
| 241 |
|
| 242 |
# Extract extra claims
|
| 243 |
-
standard_claims = {"sub", "email", "iat", "exp"}
|
| 244 |
extra = {k: v for k, v in payload.items() if k not in standard_claims}
|
| 245 |
|
| 246 |
return TokenPayload(
|
|
@@ -248,6 +254,7 @@ class JWTService:
|
|
| 248 |
email=email,
|
| 249 |
issued_at=issued_at,
|
| 250 |
expires_at=expires_at,
|
|
|
|
| 251 |
extra=extra
|
| 252 |
)
|
| 253 |
|
|
@@ -346,19 +353,20 @@ def get_jwt_service() -> JWTService:
|
|
| 346 |
return _default_service
|
| 347 |
|
| 348 |
|
| 349 |
-
def create_access_token(user_id: str, email: str, **kwargs) -> str:
|
| 350 |
"""
|
| 351 |
Convenience function to create a token using the default service.
|
| 352 |
|
| 353 |
Args:
|
| 354 |
user_id: The user's unique identifier.
|
| 355 |
email: The user's email address.
|
|
|
|
| 356 |
**kwargs: Additional arguments passed to create_token.
|
| 357 |
|
| 358 |
Returns:
|
| 359 |
str: The encoded JWT token.
|
| 360 |
"""
|
| 361 |
-
return get_jwt_service().create_token(user_id, email, **kwargs)
|
| 362 |
|
| 363 |
|
| 364 |
def verify_access_token(token: str) -> TokenPayload:
|
|
|
|
| 52 |
email: The user's email address
|
| 53 |
issued_at: When the token was issued
|
| 54 |
expires_at: When the token expires
|
| 55 |
+
token_version: Version number for token invalidation
|
| 56 |
extra: Any additional claims in the token
|
| 57 |
"""
|
| 58 |
user_id: str
|
| 59 |
email: str
|
| 60 |
issued_at: datetime
|
| 61 |
expires_at: datetime
|
| 62 |
+
token_version: int = 1
|
| 63 |
extra: Dict[str, Any] = None
|
| 64 |
|
| 65 |
def __post_init__(self):
|
|
|
|
| 171 |
self,
|
| 172 |
user_id: str,
|
| 173 |
email: str,
|
| 174 |
+
token_version: int = 1,
|
| 175 |
extra_claims: Optional[Dict[str, Any]] = None,
|
| 176 |
expiry_hours: Optional[int] = None
|
| 177 |
) -> str:
|
|
|
|
| 181 |
Args:
|
| 182 |
user_id: The user's unique identifier.
|
| 183 |
email: The user's email address.
|
| 184 |
+
token_version: User's current token version for invalidation.
|
| 185 |
extra_claims: Additional claims to include in the token.
|
| 186 |
expiry_hours: Custom expiry for this token (overrides default).
|
| 187 |
|
|
|
|
| 194 |
payload = {
|
| 195 |
"sub": user_id,
|
| 196 |
"email": email,
|
| 197 |
+
"tv": token_version, # Token version for invalidation
|
| 198 |
"iat": now,
|
| 199 |
"exp": now + timedelta(hours=expiry),
|
| 200 |
}
|
|
|
|
| 204 |
|
| 205 |
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
| 206 |
|
| 207 |
+
logger.debug(f"Created token for user_id={user_id} (version={token_version})")
|
| 208 |
return token
|
| 209 |
|
| 210 |
def verify_token(self, token: str) -> TokenPayload:
|
|
|
|
| 234 |
# Extract standard claims
|
| 235 |
user_id = payload.get("sub")
|
| 236 |
email = payload.get("email")
|
| 237 |
+
token_version = payload.get("tv", 1) # Default to 1 for backward compatibility
|
| 238 |
iat = payload.get("iat")
|
| 239 |
exp = payload.get("exp")
|
| 240 |
|
|
|
|
| 246 |
expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
|
| 247 |
|
| 248 |
# Extract extra claims
|
| 249 |
+
standard_claims = {"sub", "email", "tv", "iat", "exp"}
|
| 250 |
extra = {k: v for k, v in payload.items() if k not in standard_claims}
|
| 251 |
|
| 252 |
return TokenPayload(
|
|
|
|
| 254 |
email=email,
|
| 255 |
issued_at=issued_at,
|
| 256 |
expires_at=expires_at,
|
| 257 |
+
token_version=token_version,
|
| 258 |
extra=extra
|
| 259 |
)
|
| 260 |
|
|
|
|
| 353 |
return _default_service
|
| 354 |
|
| 355 |
|
| 356 |
+
def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
|
| 357 |
"""
|
| 358 |
Convenience function to create a token using the default service.
|
| 359 |
|
| 360 |
Args:
|
| 361 |
user_id: The user's unique identifier.
|
| 362 |
email: The user's email address.
|
| 363 |
+
token_version: User's current token version for invalidation.
|
| 364 |
**kwargs: Additional arguments passed to create_token.
|
| 365 |
|
| 366 |
Returns:
|
| 367 |
str: The encoded JWT token.
|
| 368 |
"""
|
| 369 |
+
return get_jwt_service().create_token(user_id, email, token_version, **kwargs)
|
| 370 |
|
| 371 |
|
| 372 |
def verify_access_token(token: str) -> TokenPayload:
|