jebin2 commited on
Commit
75fb504
·
1 Parent(s): 5e3877e
core/schemas.py CHANGED
@@ -12,12 +12,14 @@ class GoogleAuthRequest(BaseModel):
12
  """Request with Google ID token from frontend Sign-In."""
13
  id_token: str = Field(..., min_length=1, description="Google ID token from Sign-In")
14
  temp_user_id: Optional[str] = Field(None, description="Optional temp user ID for linking")
 
15
 
16
 
17
  class AuthResponse(BaseModel):
18
  """Response after successful Google authentication."""
19
  success: bool
20
  access_token: str
 
21
  user_id: str
22
  email: str
23
  name: Optional[str] = None
@@ -36,11 +38,13 @@ class UserInfoResponse(BaseModel):
36
 
37
  class TokenRefreshRequest(BaseModel):
38
  """Request to refresh an access token."""
39
- token: str = Field(..., description="Current access token to refresh")
 
40
 
41
 
42
  class TokenRefreshResponse(BaseModel):
43
  """Response with refreshed access token."""
44
  success: bool
45
  access_token: str
 
46
 
 
12
  """Request with Google ID token from frontend Sign-In."""
13
  id_token: str = Field(..., min_length=1, description="Google ID token from Sign-In")
14
  temp_user_id: Optional[str] = Field(None, description="Optional temp user ID for linking")
15
+ client_type: str = Field("web", description="Client type: 'web' (cookies) or 'mobile' (body)")
16
 
17
 
18
  class AuthResponse(BaseModel):
19
  """Response after successful Google authentication."""
20
  success: bool
21
  access_token: str
22
+ refresh_token: str
23
  user_id: str
24
  email: str
25
  name: Optional[str] = None
 
38
 
39
  class TokenRefreshRequest(BaseModel):
40
  """Request to refresh an access token."""
41
+ """Request to refresh an access token."""
42
+ token: Optional[str] = Field(None, description="Current refresh token (optional if in cookie)")
43
 
44
 
45
  class TokenRefreshResponse(BaseModel):
46
  """Response with refreshed access token."""
47
  success: bool
48
  access_token: str
49
+ refresh_token: str
50
 
dependencies.py CHANGED
@@ -36,16 +36,16 @@ async def check_rate_limit(
36
  now = datetime.utcnow()
37
  window_start = now - timedelta(minutes=window_minutes)
38
 
39
- # Check existing limit
40
  query = select(RateLimit).where(
41
  and_(
42
  RateLimit.identifier == identifier,
43
  RateLimit.endpoint == endpoint,
44
  RateLimit.window_start >= window_start
45
  )
46
- )
47
  result = await db.execute(query)
48
- rate_limit = result.scalar_one_or_none()
49
 
50
  if rate_limit:
51
  if rate_limit.attempts >= limit:
@@ -104,6 +104,14 @@ async def get_current_user(
104
 
105
  try:
106
  payload = verify_access_token(token)
 
 
 
 
 
 
 
 
107
  except TokenExpiredError:
108
  raise HTTPException(
109
  status_code=status.HTTP_401_UNAUTHORIZED,
 
36
  now = datetime.utcnow()
37
  window_start = now - timedelta(minutes=window_minutes)
38
 
39
+ # Check existing limit (get most recent if multiple exist)
40
  query = select(RateLimit).where(
41
  and_(
42
  RateLimit.identifier == identifier,
43
  RateLimit.endpoint == endpoint,
44
  RateLimit.window_start >= window_start
45
  )
46
+ ).order_by(RateLimit.window_start.desc())
47
  result = await db.execute(query)
48
+ rate_limit = result.scalars().first()
49
 
50
  if rate_limit:
51
  if rate_limit.attempts >= limit:
 
104
 
105
  try:
106
  payload = verify_access_token(token)
107
+
108
+ # Ensure it's an access token, not a refresh token
109
+ if payload.extra.get("type") == "refresh":
110
+ raise HTTPException(
111
+ status_code=status.HTTP_401_UNAUTHORIZED,
112
+ detail="Cannot use refresh token for API access"
113
+ )
114
+
115
  except TokenExpiredError:
116
  raise HTTPException(
117
  status_code=status.HTTP_401_UNAUTHORIZED,
routers/auth.py CHANGED
@@ -32,6 +32,7 @@ from services.auth_service.google_provider import (
32
  from services.auth_service.jwt_provider import (
33
  JWTService,
34
  create_access_token,
 
35
  get_jwt_service,
36
  InvalidTokenError as JWTInvalidTokenError,
37
  )
@@ -78,15 +79,11 @@ async def google_auth(
78
  """
79
  Authenticate with Google ID token.
80
 
81
- Frontend flow:
82
- 1. User clicks "Sign in with Google" button
83
- 2. Google returns an ID token
84
- 3. Frontend sends that token to this endpoint
85
- 4. We verify it with Google and issue our own JWT
86
-
87
- Creates new user or returns existing user.
88
- Existing users matched by email.
89
  """
 
90
  ip = req.client.host
91
 
92
  # Rate Limit: 10 attempts per minute per IP
@@ -200,23 +197,44 @@ async def google_auth(
200
  )
201
  await db.commit()
202
 
203
- # Create our JWT access token with current token_version
204
  access_token = create_access_token(user.user_id, user.email, user.token_version)
 
205
 
206
  # Sync DB to Drive (Async)
207
  from services.backup_service import get_backup_service
208
  backup_service = get_backup_service()
209
  background_tasks.add_task(backup_service.backup_async)
210
 
211
- return AuthResponse(
212
- success=True,
213
- access_token=access_token,
214
- user_id=user.user_id,
215
- email=user.email,
216
- name=user.name,
217
- credits=user.credits,
218
- is_new_user=is_new_user
219
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
 
222
  @router.get("/me", response_model=UserInfoResponse)
@@ -254,8 +272,8 @@ async def refresh_token(
254
  """
255
  ip = req.client.host
256
 
257
- # Rate Limit: 5 refreshes per minute per IP
258
- if not await check_rate_limit(db, ip, "/auth/refresh", 5, 1):
259
  raise HTTPException(
260
  status_code=status.HTTP_429_TOO_MANY_REQUESTS,
261
  detail="Too many refresh attempts"
@@ -264,10 +282,24 @@ async def refresh_token(
264
  try:
265
  jwt_service = get_jwt_service()
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  # Decode the token (without verifying expiry) to get user info
268
  import jwt as pyjwt
269
  payload = pyjwt.decode(
270
- request.token,
271
  jwt_service.secret_key,
272
  algorithms=[jwt_service.algorithm],
273
  options={"verify_exp": False}
@@ -275,9 +307,17 @@ async def refresh_token(
275
 
276
  user_id = payload.get("sub")
277
  token_version = payload.get("tv", 1)
 
278
 
279
  if not user_id:
280
  raise JWTInvalidTokenError("Token missing required claims")
 
 
 
 
 
 
 
281
 
282
  # Check if user exists and token version is still valid
283
  query = select(User).where(User.user_id == user_id, User.is_active == True)
@@ -297,13 +337,33 @@ async def refresh_token(
297
  detail="Token has been invalidated. Please sign in again."
298
  )
299
 
300
- # Create new token with current token_version
301
- new_token = create_access_token(user.user_id, user.email, user.token_version)
302
 
303
- return TokenRefreshResponse(
304
- success=True,
305
- access_token=new_token
306
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  except JWTInvalidTokenError as e:
308
  raise HTTPException(
309
  status_code=status.HTTP_401_UNAUTHORIZED,
@@ -346,4 +406,6 @@ async def logout(
346
  backup_service = get_backup_service()
347
  background_tasks.add_task(backup_service.backup_async)
348
 
349
- return {"success": True, "message": "Logged out successfully. All sessions invalidated."}
 
 
 
32
  from services.auth_service.jwt_provider import (
33
  JWTService,
34
  create_access_token,
35
+ create_refresh_token,
36
  get_jwt_service,
37
  InvalidTokenError as JWTInvalidTokenError,
38
  )
 
79
  """
80
  Authenticate with Google ID token.
81
 
82
+ Supports two client types:
83
+ - "web": Sets refresh_token in HttpOnly cookie (secure)
84
+ - "mobile": Returns refresh_token in JSON body
 
 
 
 
 
85
  """
86
+ response = JSONResponse(content={}) # Placeholder, will be populated later
87
  ip = req.client.host
88
 
89
  # Rate Limit: 10 attempts per minute per IP
 
197
  )
198
  await db.commit()
199
 
200
+ # Create our JWT access token and refresh token
201
  access_token = create_access_token(user.user_id, user.email, user.token_version)
202
+ refresh_token = create_refresh_token(user.user_id, user.email, user.token_version)
203
 
204
  # Sync DB to Drive (Async)
205
  from services.backup_service import get_backup_service
206
  backup_service = get_backup_service()
207
  background_tasks.add_task(backup_service.backup_async)
208
 
209
+ # Prepare response data
210
+ response_data = {
211
+ "success": True,
212
+ "access_token": access_token,
213
+ "user_id": user.user_id,
214
+ "email": user.email,
215
+ "name": user.name,
216
+ "credits": user.credits,
217
+ "is_new_user": is_new_user
218
+ }
219
+
220
+ # Handle token delivery based on client type
221
+ if request.client_type == "web":
222
+ # Web: Set HttpOnly cookie for refresh token
223
+ response = JSONResponse(content=response_data)
224
+ response.set_cookie(
225
+ key="refresh_token",
226
+ value=refresh_token,
227
+ httponly=True,
228
+ secure=True, # Should be True in production
229
+ samesite="lax",
230
+ max_age=7 * 24 * 60 * 60 # 7 days
231
+ )
232
+ else:
233
+ # Mobile: Return refresh token in body
234
+ response_data["refresh_token"] = refresh_token
235
+ response = JSONResponse(content=response_data)
236
+
237
+ return response
238
 
239
 
240
  @router.get("/me", response_model=UserInfoResponse)
 
272
  """
273
  ip = req.client.host
274
 
275
+ # Rate Limit: 20 refreshes per minute per IP (increased for proactive refresh on page load)
276
+ if not await check_rate_limit(db, ip, "/auth/refresh", 20, 1):
277
  raise HTTPException(
278
  status_code=status.HTTP_429_TOO_MANY_REQUESTS,
279
  detail="Too many refresh attempts"
 
282
  try:
283
  jwt_service = get_jwt_service()
284
 
285
+ # Get token from body or cookie
286
+ token_to_refresh = request.token
287
+ using_cookie = False
288
+
289
+ if not token_to_refresh:
290
+ token_to_refresh = req.cookies.get("refresh_token")
291
+ using_cookie = True
292
+
293
+ if not token_to_refresh:
294
+ raise HTTPException(
295
+ status_code=status.HTTP_401_UNAUTHORIZED,
296
+ detail="Refresh token missing"
297
+ )
298
+
299
  # Decode the token (without verifying expiry) to get user info
300
  import jwt as pyjwt
301
  payload = pyjwt.decode(
302
+ token_to_refresh,
303
  jwt_service.secret_key,
304
  algorithms=[jwt_service.algorithm],
305
  options={"verify_exp": False}
 
307
 
308
  user_id = payload.get("sub")
309
  token_version = payload.get("tv", 1)
310
+ token_type = payload.get("type", "access")
311
 
312
  if not user_id:
313
  raise JWTInvalidTokenError("Token missing required claims")
314
+
315
+ # Verify it's a refresh token
316
+ if token_type != "refresh":
317
+ raise HTTPException(
318
+ status_code=status.HTTP_401_UNAUTHORIZED,
319
+ detail="Invalid token type. Expected refresh token."
320
+ )
321
 
322
  # Check if user exists and token version is still valid
323
  query = select(User).where(User.user_id == user_id, User.is_active == True)
 
337
  detail="Token has been invalidated. Please sign in again."
338
  )
339
 
340
+ # Create new access token
341
+ new_access_token = create_access_token(user.user_id, user.email, user.token_version)
342
 
343
+ # ROTATION: Issue new refresh token
344
+ new_refresh_token = create_refresh_token(user.user_id, user.email, user.token_version)
345
+
346
+ response_data = {
347
+ "success": True,
348
+ "access_token": new_access_token
349
+ }
350
+
351
+ if using_cookie:
352
+ # If came from cookie, rotate cookie
353
+ response = JSONResponse(content=response_data)
354
+ response.set_cookie(
355
+ key="refresh_token",
356
+ value=new_refresh_token,
357
+ httponly=True,
358
+ secure=True,
359
+ samesite="lax",
360
+ max_age=7 * 24 * 60 * 60
361
+ )
362
+ return response
363
+ else:
364
+ # If came from body, return in body
365
+ response_data["refresh_token"] = new_refresh_token
366
+ return TokenRefreshResponse(**response_data)
367
  except JWTInvalidTokenError as e:
368
  raise HTTPException(
369
  status_code=status.HTTP_401_UNAUTHORIZED,
 
406
  backup_service = get_backup_service()
407
  background_tasks.add_task(backup_service.backup_async)
408
 
409
+ response = JSONResponse(content={"success": True, "message": "Logged out successfully. All sessions invalidated."})
410
+ response.delete_cookie(key="refresh_token")
411
+ return response
services/auth_service/jwt_provider.py CHANGED
@@ -60,6 +60,7 @@ class TokenPayload:
60
  issued_at: datetime
61
  expires_at: datetime
62
  token_version: int = 1
 
63
  extra: Dict[str, Any] = None
64
 
65
  def __post_init__(self):
@@ -122,30 +123,33 @@ class JWTService:
122
 
123
  # Default configuration
124
  DEFAULT_ALGORITHM = "HS256"
125
- DEFAULT_EXPIRY_HOURS = 168 # 7 days
 
126
 
127
  def __init__(
128
  self,
129
  secret_key: Optional[str] = None,
130
  algorithm: Optional[str] = None,
131
- expiry_hours: Optional[int] = None
 
132
  ):
133
  """
134
  Initialize the JWT Service.
135
 
136
  Args:
137
- secret_key: Secret key for signing tokens. If not provided,
138
- falls back to JWT_SECRET environment variable.
139
  algorithm: JWT algorithm (default: HS256).
140
- expiry_hours: Token expiry in hours (default: 168 = 7 days).
141
-
142
- Raises:
143
- ConfigurationError: If no secret_key is provided or found.
144
  """
145
  self.secret_key = secret_key or os.getenv("JWT_SECRET")
146
  self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
147
- self.expiry_hours = expiry_hours or int(
148
- os.getenv("JWT_EXPIRY_HOURS", str(self.DEFAULT_EXPIRY_HOURS))
 
 
 
 
149
  )
150
 
151
  if not self.secret_key:
@@ -163,40 +167,38 @@ class JWTService:
163
  )
164
 
165
  logger.info(
166
- f"JWTService initialized (algorithm={self.algorithm}, "
167
- f"expiry={self.expiry_hours}h)"
168
  )
169
 
170
  def create_token(
171
  self,
172
  user_id: str,
173
  email: str,
 
174
  token_version: int = 1,
175
  extra_claims: Optional[Dict[str, Any]] = None,
176
- expiry_hours: Optional[int] = None
177
  ) -> str:
178
  """
179
- Create a JWT token for a user.
180
-
181
- Args:
182
- user_id: The user's unique identifier.
183
- email: The user's email address.
184
- token_version: User's current token version for invalidation.
185
- extra_claims: Additional claims to include in the token.
186
- expiry_hours: Custom expiry for this token (overrides default).
187
-
188
- Returns:
189
- str: The encoded JWT token.
190
  """
191
  now = datetime.utcnow()
192
- expiry = expiry_hours or self.expiry_hours
 
 
 
 
 
 
193
 
194
  payload = {
195
  "sub": user_id,
196
  "email": email,
197
- "tv": token_version, # Token version for invalidation
 
198
  "iat": now,
199
- "exp": now + timedelta(hours=expiry),
200
  }
201
 
202
  if extra_claims:
@@ -204,8 +206,18 @@ class JWTService:
204
 
205
  token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
206
 
207
- logger.debug(f"Created token for user_id={user_id} (version={token_version})")
 
 
208
  return token
 
 
 
 
 
 
 
 
209
 
210
  def verify_token(self, token: str) -> TokenPayload:
211
  """
@@ -234,19 +246,20 @@ class JWTService:
234
  # Extract standard claims
235
  user_id = payload.get("sub")
236
  email = payload.get("email")
237
- token_version = payload.get("tv", 1) # Default to 1 for backward compatibility
 
238
  iat = payload.get("iat")
239
  exp = payload.get("exp")
240
 
241
  if not user_id or not email:
242
  raise InvalidTokenError("Token missing required claims (sub, email)")
243
 
244
- # Convert timestamps to datetime
245
  issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
246
  expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
247
 
248
  # Extract extra claims
249
- standard_claims = {"sub", "email", "tv", "iat", "exp"}
250
  extra = {k: v for k, v in payload.items() if k not in standard_claims}
251
 
252
  return TokenPayload(
@@ -255,6 +268,7 @@ class JWTService:
255
  issued_at=issued_at,
256
  expires_at=expires_at,
257
  token_version=token_version,
 
258
  extra=extra
259
  )
260
 
@@ -366,7 +380,13 @@ def create_access_token(user_id: str, email: str, token_version: int = 1, **kwar
366
  Returns:
367
  str: The encoded JWT token.
368
  """
369
- return get_jwt_service().create_token(user_id, email, token_version, **kwargs)
 
 
 
 
 
 
370
 
371
 
372
  def verify_access_token(token: str) -> TokenPayload:
 
60
  issued_at: datetime
61
  expires_at: datetime
62
  token_version: int = 1
63
+ token_type: str = "access" # "access" or "refresh"
64
  extra: Dict[str, Any] = None
65
 
66
  def __post_init__(self):
 
123
 
124
  # Default configuration
125
  DEFAULT_ALGORITHM = "HS256"
126
+ DEFAULT_ACCESS_EXPIRY_MINUTES = 15 # 15 minutes
127
+ DEFAULT_REFRESH_EXPIRY_DAYS = 7 # 7 days
128
 
129
  def __init__(
130
  self,
131
  secret_key: Optional[str] = None,
132
  algorithm: Optional[str] = None,
133
+ access_expiry_minutes: Optional[int] = None,
134
+ refresh_expiry_days: Optional[int] = None
135
  ):
136
  """
137
  Initialize the JWT Service.
138
 
139
  Args:
140
+ secret_key: Secret key for signing tokens.
 
141
  algorithm: JWT algorithm (default: HS256).
142
+ access_expiry_minutes: Access token expiry (default: 15 min).
143
+ refresh_expiry_days: Refresh token expiry (default: 7 days).
 
 
144
  """
145
  self.secret_key = secret_key or os.getenv("JWT_SECRET")
146
  self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
147
+
148
+ self.access_expiry_minutes = access_expiry_minutes or int(
149
+ os.getenv("JWT_ACCESS_EXPIRY_MINUTES", str(self.DEFAULT_ACCESS_EXPIRY_MINUTES))
150
+ )
151
+ self.refresh_expiry_days = refresh_expiry_days or int(
152
+ os.getenv("JWT_REFRESH_EXPIRY_DAYS", str(self.DEFAULT_REFRESH_EXPIRY_DAYS))
153
  )
154
 
155
  if not self.secret_key:
 
167
  )
168
 
169
  logger.info(
170
+ f"JWTService initialized (alg={self.algorithm}, "
171
+ f"access={self.access_expiry_minutes}m, refresh={self.refresh_expiry_days}d)"
172
  )
173
 
174
  def create_token(
175
  self,
176
  user_id: str,
177
  email: str,
178
+ token_type: str = "access",
179
  token_version: int = 1,
180
  extra_claims: Optional[Dict[str, Any]] = None,
181
+ expiry_delta: Optional[timedelta] = None
182
  ) -> str:
183
  """
184
+ Create a JWT token.
 
 
 
 
 
 
 
 
 
 
185
  """
186
  now = datetime.utcnow()
187
+
188
+ if expiry_delta:
189
+ expires_at = now + expiry_delta
190
+ elif token_type == "refresh":
191
+ expires_at = now + timedelta(days=self.refresh_expiry_days)
192
+ else:
193
+ expires_at = now + timedelta(minutes=self.access_expiry_minutes)
194
 
195
  payload = {
196
  "sub": user_id,
197
  "email": email,
198
+ "type": token_type,
199
+ "tv": token_version,
200
  "iat": now,
201
+ "exp": expires_at,
202
  }
203
 
204
  if extra_claims:
 
206
 
207
  token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
208
 
209
+ token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
210
+
211
+ logger.debug(f"Created {token_type} token for {user_id}")
212
  return token
213
+
214
+ def create_access_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
215
+ """Create a short-lived access token."""
216
+ return self.create_token(user_id, email, "access", token_version, **kwargs)
217
+
218
+ def create_refresh_token(self, user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
219
+ """Create a long-lived refresh token."""
220
+ return self.create_token(user_id, email, "refresh", token_version, **kwargs)
221
 
222
  def verify_token(self, token: str) -> TokenPayload:
223
  """
 
246
  # Extract standard claims
247
  user_id = payload.get("sub")
248
  email = payload.get("email")
249
+ token_type = payload.get("type", "access") # Default to access for backward compat
250
+ token_version = payload.get("tv", 1)
251
  iat = payload.get("iat")
252
  exp = payload.get("exp")
253
 
254
  if not user_id or not email:
255
  raise InvalidTokenError("Token missing required claims (sub, email)")
256
 
257
+ # Convert timestamps
258
  issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
259
  expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
260
 
261
  # Extract extra claims
262
+ standard_claims = {"sub", "email", "type", "tv", "iat", "exp"}
263
  extra = {k: v for k, v in payload.items() if k not in standard_claims}
264
 
265
  return TokenPayload(
 
268
  issued_at=issued_at,
269
  expires_at=expires_at,
270
  token_version=token_version,
271
+ token_type=token_type,
272
  extra=extra
273
  )
274
 
 
380
  Returns:
381
  str: The encoded JWT token.
382
  """
383
+ def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
384
+ """Convenience function to create an access token."""
385
+ return get_jwt_service().create_access_token(user_id, email, token_version, **kwargs)
386
+
387
+ def create_refresh_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
388
+ """Convenience function to create a refresh token."""
389
+ return get_jwt_service().create_refresh_token(user_id, email, token_version, **kwargs)
390
 
391
 
392
  def verify_access_token(token: str) -> TokenPayload: