jebin2 commited on
Commit
19e4a8c
·
1 Parent(s): 23af55d

tokrn version

Browse files
Files changed (4) hide show
  1. core/models.py +4 -0
  2. dependencies.py +11 -0
  3. routers/auth.py +47 -9
  4. services/jwt_service.py +12 -4
core/models.py CHANGED
@@ -54,6 +54,10 @@ class User(Base):
54
  # Legacy field (kept for migration, nullable now)
55
  secret_key_hash = Column(String(255), nullable=True)
56
 
 
 
 
 
57
  # Credits and status
58
  credits = Column(Integer, default=100)
59
  created_at = Column(DateTime(timezone=True), server_default=func.now())
 
54
  # Legacy field (kept for migration, nullable now)
55
  secret_key_hash = Column(String(255), nullable=True)
56
 
57
+ # Token versioning for JWT invalidation
58
+ # Incrementing this invalidates all existing tokens for this user
59
+ token_version = Column(Integer, default=1, nullable=False)
60
+
61
  # Credits and status
62
  credits = Column(Integer, default=100)
63
  created_at = Column(DateTime(timezone=True), server_default=func.now())
dependencies.py CHANGED
@@ -77,6 +77,8 @@ async def get_current_user(
77
  Extract and verify JWT from Authorization header.
78
  Returns the authenticated user.
79
 
 
 
80
  Usage:
81
  @router.get("/protected")
82
  async def protected_route(user: User = Depends(get_current_user)):
@@ -135,6 +137,15 @@ async def get_current_user(
135
  detail="User not found or inactive"
136
  )
137
 
 
 
 
 
 
 
 
 
 
138
  return user
139
 
140
 
 
77
  Extract and verify JWT from Authorization header.
78
  Returns the authenticated user.
79
 
80
+ Also validates token_version to support instant logout/invalidation.
81
+
82
  Usage:
83
  @router.get("/protected")
84
  async def protected_route(user: User = Depends(get_current_user)):
 
137
  detail="User not found or inactive"
138
  )
139
 
140
+ # Validate token version - if user's version is higher, token is invalidated
141
+ if payload.token_version < user.token_version:
142
+ logger.info(f"Token invalidated for user {user.user_id}: token_version {payload.token_version} < {user.token_version}")
143
+ raise HTTPException(
144
+ status_code=status.HTTP_401_UNAUTHORIZED,
145
+ detail="Token has been invalidated. Please sign in again.",
146
+ headers={"WWW-Authenticate": "Bearer"}
147
+ )
148
+
149
  return user
150
 
151
 
routers/auth.py CHANGED
@@ -167,8 +167,8 @@ async def google_auth(
167
  db.add(audit_log)
168
  await db.commit()
169
 
170
- # Create our JWT access token
171
- access_token = create_access_token(user.user_id, user.email)
172
 
173
  # Sync DB to Drive (Async)
174
  background_tasks.add_task(drive_service.upload_db)
@@ -214,6 +214,8 @@ async def refresh_token(
214
  Use this when the current token is about to expire
215
  (or has recently expired) to get a new one without
216
  requiring the user to sign in again.
 
 
217
  """
218
  ip = req.client.host
219
 
@@ -226,7 +228,42 @@ async def refresh_token(
226
 
227
  try:
228
  jwt_service = get_jwt_service()
229
- new_token = jwt_service.refresh_token(request.token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  return TokenRefreshResponse(
232
  success=True,
@@ -249,14 +286,15 @@ async def logout(
249
  """
250
  Logout current user.
251
 
252
- Note: JWT tokens are stateless, so this endpoint mainly
253
- serves to log the action. Frontend should discard the token.
254
-
255
- For full session invalidation, consider implementing
256
- a token blacklist or reducing token expiry times.
257
  """
258
  ip = req.client.host
259
 
 
 
 
 
260
  # Log logout
261
  audit_log = AuditLog(
262
  user_id=user.user_id,
@@ -270,4 +308,4 @@ async def logout(
270
  # Sync DB to Drive (Async)
271
  background_tasks.add_task(drive_service.upload_db)
272
 
273
- return {"success": True, "message": "Logged out successfully"}
 
167
  db.add(audit_log)
168
  await db.commit()
169
 
170
+ # Create our JWT access token with current token_version
171
+ access_token = create_access_token(user.user_id, user.email, user.token_version)
172
 
173
  # Sync DB to Drive (Async)
174
  background_tasks.add_task(drive_service.upload_db)
 
214
  Use this when the current token is about to expire
215
  (or has recently expired) to get a new one without
216
  requiring the user to sign in again.
217
+
218
+ Validates that the token_version is still valid before refreshing.
219
  """
220
  ip = req.client.host
221
 
 
228
 
229
  try:
230
  jwt_service = get_jwt_service()
231
+
232
+ # Decode the token (without verifying expiry) to get user info
233
+ import jwt as pyjwt
234
+ payload = pyjwt.decode(
235
+ request.token,
236
+ jwt_service.secret_key,
237
+ algorithms=[jwt_service.algorithm],
238
+ options={"verify_exp": False}
239
+ )
240
+
241
+ user_id = payload.get("sub")
242
+ token_version = payload.get("tv", 1)
243
+
244
+ if not user_id:
245
+ raise JWTInvalidTokenError("Token missing required claims")
246
+
247
+ # Check if user exists and token version is still valid
248
+ query = select(User).where(User.user_id == user_id, User.is_active == True)
249
+ result = await db.execute(query)
250
+ user = result.scalar_one_or_none()
251
+
252
+ if not user:
253
+ raise HTTPException(
254
+ status_code=status.HTTP_401_UNAUTHORIZED,
255
+ detail="User not found or inactive"
256
+ )
257
+
258
+ # Validate token version
259
+ if token_version < user.token_version:
260
+ raise HTTPException(
261
+ status_code=status.HTTP_401_UNAUTHORIZED,
262
+ detail="Token has been invalidated. Please sign in again."
263
+ )
264
+
265
+ # Create new token with current token_version
266
+ new_token = create_access_token(user.user_id, user.email, user.token_version)
267
 
268
  return TokenRefreshResponse(
269
  success=True,
 
286
  """
287
  Logout current user.
288
 
289
+ Increments the user's token_version which invalidates ALL existing
290
+ tokens for this user. This provides instant logout across all devices.
 
 
 
291
  """
292
  ip = req.client.host
293
 
294
+ # Increment token version to invalidate all existing tokens
295
+ user.token_version += 1
296
+ logger.info(f"User {user.user_id} logged out. Token version incremented to {user.token_version}")
297
+
298
  # Log logout
299
  audit_log = AuditLog(
300
  user_id=user.user_id,
 
308
  # Sync DB to Drive (Async)
309
  background_tasks.add_task(drive_service.upload_db)
310
 
311
+ return {"success": True, "message": "Logged out successfully. All sessions invalidated."}
services/jwt_service.py CHANGED
@@ -52,12 +52,14 @@ class TokenPayload:
52
  email: The user's email address
53
  issued_at: When the token was issued
54
  expires_at: When the token expires
 
55
  extra: Any additional claims in the token
56
  """
57
  user_id: str
58
  email: str
59
  issued_at: datetime
60
  expires_at: datetime
 
61
  extra: Dict[str, Any] = None
62
 
63
  def __post_init__(self):
@@ -169,6 +171,7 @@ class JWTService:
169
  self,
170
  user_id: str,
171
  email: str,
 
172
  extra_claims: Optional[Dict[str, Any]] = None,
173
  expiry_hours: Optional[int] = None
174
  ) -> str:
@@ -178,6 +181,7 @@ class JWTService:
178
  Args:
179
  user_id: The user's unique identifier.
180
  email: The user's email address.
 
181
  extra_claims: Additional claims to include in the token.
182
  expiry_hours: Custom expiry for this token (overrides default).
183
 
@@ -190,6 +194,7 @@ class JWTService:
190
  payload = {
191
  "sub": user_id,
192
  "email": email,
 
193
  "iat": now,
194
  "exp": now + timedelta(hours=expiry),
195
  }
@@ -199,7 +204,7 @@ class JWTService:
199
 
200
  token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
201
 
202
- logger.debug(f"Created token for user_id={user_id}")
203
  return token
204
 
205
  def verify_token(self, token: str) -> TokenPayload:
@@ -229,6 +234,7 @@ class JWTService:
229
  # Extract standard claims
230
  user_id = payload.get("sub")
231
  email = payload.get("email")
 
232
  iat = payload.get("iat")
233
  exp = payload.get("exp")
234
 
@@ -240,7 +246,7 @@ class JWTService:
240
  expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
241
 
242
  # Extract extra claims
243
- standard_claims = {"sub", "email", "iat", "exp"}
244
  extra = {k: v for k, v in payload.items() if k not in standard_claims}
245
 
246
  return TokenPayload(
@@ -248,6 +254,7 @@ class JWTService:
248
  email=email,
249
  issued_at=issued_at,
250
  expires_at=expires_at,
 
251
  extra=extra
252
  )
253
 
@@ -346,19 +353,20 @@ def get_jwt_service() -> JWTService:
346
  return _default_service
347
 
348
 
349
- def create_access_token(user_id: str, email: str, **kwargs) -> str:
350
  """
351
  Convenience function to create a token using the default service.
352
 
353
  Args:
354
  user_id: The user's unique identifier.
355
  email: The user's email address.
 
356
  **kwargs: Additional arguments passed to create_token.
357
 
358
  Returns:
359
  str: The encoded JWT token.
360
  """
361
- return get_jwt_service().create_token(user_id, email, **kwargs)
362
 
363
 
364
  def verify_access_token(token: str) -> TokenPayload:
 
52
  email: The user's email address
53
  issued_at: When the token was issued
54
  expires_at: When the token expires
55
+ token_version: Version number for token invalidation
56
  extra: Any additional claims in the token
57
  """
58
  user_id: str
59
  email: str
60
  issued_at: datetime
61
  expires_at: datetime
62
+ token_version: int = 1
63
  extra: Dict[str, Any] = None
64
 
65
  def __post_init__(self):
 
171
  self,
172
  user_id: str,
173
  email: str,
174
+ token_version: int = 1,
175
  extra_claims: Optional[Dict[str, Any]] = None,
176
  expiry_hours: Optional[int] = None
177
  ) -> str:
 
181
  Args:
182
  user_id: The user's unique identifier.
183
  email: The user's email address.
184
+ token_version: User's current token version for invalidation.
185
  extra_claims: Additional claims to include in the token.
186
  expiry_hours: Custom expiry for this token (overrides default).
187
 
 
194
  payload = {
195
  "sub": user_id,
196
  "email": email,
197
+ "tv": token_version, # Token version for invalidation
198
  "iat": now,
199
  "exp": now + timedelta(hours=expiry),
200
  }
 
204
 
205
  token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
206
 
207
+ logger.debug(f"Created token for user_id={user_id} (version={token_version})")
208
  return token
209
 
210
  def verify_token(self, token: str) -> TokenPayload:
 
234
  # Extract standard claims
235
  user_id = payload.get("sub")
236
  email = payload.get("email")
237
+ token_version = payload.get("tv", 1) # Default to 1 for backward compatibility
238
  iat = payload.get("iat")
239
  exp = payload.get("exp")
240
 
 
246
  expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
247
 
248
  # Extract extra claims
249
+ standard_claims = {"sub", "email", "tv", "iat", "exp"}
250
  extra = {k: v for k, v in payload.items() if k not in standard_claims}
251
 
252
  return TokenPayload(
 
254
  email=email,
255
  issued_at=issued_at,
256
  expires_at=expires_at,
257
+ token_version=token_version,
258
  extra=extra
259
  )
260
 
 
353
  return _default_service
354
 
355
 
356
+ def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
357
  """
358
  Convenience function to create a token using the default service.
359
 
360
  Args:
361
  user_id: The user's unique identifier.
362
  email: The user's email address.
363
+ token_version: User's current token version for invalidation.
364
  **kwargs: Additional arguments passed to create_token.
365
 
366
  Returns:
367
  str: The encoded JWT token.
368
  """
369
+ return get_jwt_service().create_token(user_id, email, token_version, **kwargs)
370
 
371
 
372
  def verify_access_token(token: str) -> TokenPayload: