jebin2 commited on
Commit
bc8ed4e
·
1 Parent(s): 3c56e03
app.py CHANGED
@@ -108,7 +108,7 @@ async def lifespan(app: FastAPI):
108
  logger.info("✅ Database initialized")
109
 
110
  # Start background job worker
111
- from services.gemini_job_worker import start_worker, stop_worker
112
  await start_worker()
113
  logger.info("Background job worker started")
114
 
 
108
  logger.info("✅ Database initialized")
109
 
110
  # Start background job worker
111
+ from services.gemini_service import start_worker, stop_worker
112
  await start_worker()
113
  logger.info("Background job worker started")
114
 
dependencies.py CHANGED
@@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
9
 
10
  from core.database import get_db
11
  from core.models import User, RateLimit
12
- from services.jwt_service import (
13
  verify_access_token,
14
  TokenExpiredError,
15
  InvalidTokenError,
 
9
 
10
  from core.database import get_db
11
  from core.models import User, RateLimit
12
+ from services.auth_service.jwt_provider import (
13
  verify_access_token,
14
  TokenExpiredError,
15
  InvalidTokenError,
routers/auth.py CHANGED
@@ -22,17 +22,14 @@ from core.schemas import (
22
  TokenRefreshRequest,
23
  TokenRefreshResponse
24
  )
25
- from services.google_auth_service import (
26
  GoogleAuthService,
27
- InvalidTokenError as GoogleInvalidTokenError,
28
- ConfigurationError as GoogleConfigError,
29
- get_google_auth_service
30
  )
31
- from services.jwt_service import (
32
  JWTService,
33
  create_access_token,
34
- get_jwt_service,
35
- InvalidTokenError as JWTInvalidTokenError
36
  )
37
  from dependencies import check_rate_limit, get_current_user
38
  from services.drive_service import DriveService
 
22
  TokenRefreshRequest,
23
  TokenRefreshResponse
24
  )
25
+ from services.auth_service.google_provider import (
26
  GoogleAuthService,
27
+ GoogleUserInfo,
28
+ InvalidTokenError,
 
29
  )
30
+ from services.auth_service.jwt_provider import (
31
  JWTService,
32
  create_access_token,
 
 
33
  )
34
  from dependencies import check_rate_limit, get_current_user
35
  from services.drive_service import DriveService
services/credit_service.py DELETED
@@ -1,257 +0,0 @@
1
- """
2
- Credit Service - Manages credit reservation, confirmation, and refunding.
3
-
4
- Implements the Credit Reservation Pattern:
5
- 1. Reserve credits when job is created (deduct from user, track in job)
6
- 2. Confirm credits only on successful completion
7
- 3. Refund credits on refundable errors (server-side issues)
8
- 4. Keep credits on non-refundable errors (user-caused issues)
9
- """
10
- import logging
11
- from typing import Optional
12
- from sqlalchemy.ext.asyncio import AsyncSession
13
- from sqlalchemy import select
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
- # =============================================================================
19
- # Error Categories for Refund Decisions
20
- # =============================================================================
21
-
22
- # Refundable errors - User gets credits back (server/API issues)
23
- REFUNDABLE_ERROR_PATTERNS = [
24
- "API_KEY_INVALID",
25
- "QUOTA_EXCEEDED",
26
- "INTERNAL_ERROR",
27
- "CONNECTION_FAILED",
28
- "SERVER_SHUTDOWN",
29
- "TIMEOUT",
30
- "Server Authentication Error",
31
- "Network error",
32
- "Connection refused",
33
- "Connection reset",
34
- "Service unavailable",
35
- "503",
36
- "500",
37
- "429", # Rate limit (our quota, not user's fault)
38
- ]
39
-
40
- # Non-refundable error patterns - User's input/content issue
41
- NON_REFUNDABLE_ERROR_PATTERNS = [
42
- "safety",
43
- "blocked",
44
- "SAFETY_FILTER",
45
- "INVALID_INPUT",
46
- "Invalid image",
47
- "Bad request",
48
- "400",
49
- "cancelled",
50
- "User cancelled",
51
- ]
52
-
53
-
54
- def is_refundable_error(error_message: Optional[str]) -> bool:
55
- """
56
- Determine if an error should result in a credit refund.
57
-
58
- Args:
59
- error_message: The error message from the failed job
60
-
61
- Returns:
62
- True if the error is refundable (server/API issue)
63
- False if non-refundable (user's fault) or no error message
64
- """
65
- if not error_message:
66
- return False
67
-
68
- error_lower = error_message.lower()
69
-
70
- # Check for REFUNDABLE patterns FIRST (specific server errors take precedence)
71
- # This ensures API_KEY_INVALID is caught before generic "400" matcher
72
- for pattern in REFUNDABLE_ERROR_PATTERNS:
73
- if pattern.lower() in error_lower:
74
- logger.debug(f"Error matched refundable pattern '{pattern}': {error_message[:100]}")
75
- return True
76
-
77
- # Check for non-refundable patterns (user-caused issues)
78
- for pattern in NON_REFUNDABLE_ERROR_PATTERNS:
79
- if pattern.lower() in error_lower:
80
- logger.debug(f"Error matched non-refundable pattern '{pattern}': {error_message[:100]}")
81
- return False
82
-
83
- # Default: Max retries exceeded is refundable (we consumed API resources trying)
84
- if "max retries" in error_lower:
85
- return True
86
-
87
- # Default: Unknown errors are NOT refundable to prevent abuse
88
- # If it's an unknown error, it's more likely user-caused
89
- logger.debug(f"Unknown error (not refundable): {error_message[:100]}")
90
- return False
91
-
92
-
93
- async def reserve_credit(session: AsyncSession, user, amount: int = 1) -> bool:
94
- """
95
- Reserve credits for a job (deduct from user's balance).
96
-
97
- The credits are deducted but tracked in the job's credits_reserved field.
98
- If the job fails with a refundable error, they can be restored.
99
-
100
- Args:
101
- session: Database session
102
- user: User model instance
103
- amount: Number of credits to reserve (default: 1)
104
-
105
- Returns:
106
- True if credits were successfully reserved
107
- False if user has insufficient credits
108
- """
109
- if user.credits < amount:
110
- logger.warning(f"User {user.user_id} has insufficient credits ({user.credits}) to reserve {amount}")
111
- return False
112
-
113
- user.credits -= amount
114
- logger.info(f"Reserved {amount} credit(s) for user {user.user_id}. Remaining: {user.credits}")
115
- # Note: Don't commit here - let caller handle transaction
116
- return True
117
-
118
-
119
- async def confirm_credit(session: AsyncSession, job) -> None:
120
- """
121
- Confirm that credits were legitimately used for a completed job.
122
-
123
- This is called when a job completes successfully. The credits stay
124
- deducted (they were already deducted during reservation).
125
-
126
- Args:
127
- session: Database session
128
- job: GeminiJob model instance
129
- """
130
- if job.credits_reserved > 0:
131
- # Credits were used - clear the reservation tracking
132
- credits_used = job.credits_reserved
133
- job.credits_reserved = 0
134
- logger.info(f"Confirmed {credits_used} credit(s) used for job {job.job_id}")
135
- # Note: Don't commit here - let caller handle transaction
136
-
137
-
138
- async def refund_credit(session: AsyncSession, job, reason: str) -> bool:
139
- """
140
- Refund reserved credits back to the user.
141
-
142
- Called when a job fails due to a refundable error (server-side issue).
143
-
144
- Args:
145
- session: Database session
146
- job: GeminiJob model instance
147
- reason: Reason for the refund (for logging)
148
-
149
- Returns:
150
- True if credits were refunded
151
- False if no credits to refund or already refunded
152
- """
153
- if job.credits_reserved <= 0:
154
- logger.debug(f"Job {job.job_id} has no credits to refund")
155
- return False
156
-
157
- if job.credits_refunded:
158
- logger.warning(f"Job {job.job_id} was already refunded")
159
- return False
160
-
161
- # Get the user to restore credits
162
- from core.models import User
163
-
164
- result = await session.execute(
165
- select(User).where(User.id == job.user_id)
166
- )
167
- user = result.scalar_one_or_none()
168
-
169
- if not user:
170
- logger.error(f"Cannot refund job {job.job_id}: User {job.user_id} not found")
171
- return False
172
-
173
- # Restore credits
174
- credits_to_refund = job.credits_reserved
175
- user.credits += credits_to_refund
176
- job.credits_reserved = 0
177
- job.credits_refunded = True
178
-
179
- logger.info(
180
- f"Refunded {credits_to_refund} credit(s) to user {user.user_id} for job {job.job_id}. "
181
- f"Reason: {reason[:100]}. New balance: {user.credits}"
182
- )
183
-
184
- # Note: Don't commit here - let caller handle transaction
185
- return True
186
-
187
-
188
- async def handle_job_completion(session: AsyncSession, job) -> None:
189
- """
190
- Handle credit finalization when a job completes or fails.
191
-
192
- This is the main entry point called by the job worker.
193
-
194
- Args:
195
- session: Database session
196
- job: GeminiJob model instance with final status
197
- """
198
- if job.status == "completed":
199
- # Success - confirm credits were used
200
- await confirm_credit(session, job)
201
-
202
- elif job.status == "failed":
203
- # Failure - check if refundable
204
- if is_refundable_error(job.error_message):
205
- await refund_credit(session, job, job.error_message or "Unknown error")
206
- else:
207
- # Non-refundable - confirm credits were used (user's fault)
208
- await confirm_credit(session, job)
209
- logger.info(f"Job {job.job_id} failed with non-refundable error, credits kept")
210
-
211
- elif job.status == "cancelled":
212
- # Cancelled jobs get refunds only if they were never started
213
- if job.started_at is None:
214
- await refund_credit(session, job, "Job cancelled before processing")
215
- else:
216
- # Was processing - keep credits (API may have been consumed)
217
- await confirm_credit(session, job)
218
- logger.info(f"Job {job.job_id} cancelled during processing, credits kept")
219
-
220
-
221
- async def refund_orphaned_jobs(session: AsyncSession) -> int:
222
- """
223
- Refund credits for jobs that were abandoned due to server shutdown.
224
-
225
- Called during graceful shutdown to ensure no credits are lost.
226
-
227
- Args:
228
- session: Database session
229
-
230
- Returns:
231
- Number of jobs that were refunded
232
- """
233
- from core.models import GeminiJob
234
-
235
- # Find jobs that are still processing with reserved credits
236
- result = await session.execute(
237
- select(GeminiJob).where(
238
- GeminiJob.status == "processing",
239
- GeminiJob.credits_reserved > 0,
240
- GeminiJob.credits_refunded == False
241
- )
242
- )
243
- orphaned_jobs = result.scalars().all()
244
-
245
- refund_count = 0
246
- for job in orphaned_jobs:
247
- if await refund_credit(session, job, "SERVER_SHUTDOWN: Job orphaned during server shutdown"):
248
- # Mark job as failed
249
- job.status = "failed"
250
- job.error_message = "Server shutdown during processing. Credits refunded."
251
- refund_count += 1
252
-
253
- if refund_count > 0:
254
- await session.commit()
255
- logger.info(f"Refunded {refund_count} orphaned job(s) during shutdown")
256
-
257
- return refund_count
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
services/google_auth_service.py DELETED
@@ -1,232 +0,0 @@
1
- """
2
- Modular Google OAuth Service
3
-
4
- A self-contained, plug-and-play service for verifying Google ID tokens.
5
- Can be used in any Python application with minimal configuration.
6
-
7
- Usage:
8
- from services.google_auth_service import GoogleAuthService, GoogleUserInfo
9
-
10
- # Initialize with client ID
11
- auth_service = GoogleAuthService(client_id="your-google-client-id")
12
-
13
- # Or use environment variable GOOGLE_CLIENT_ID
14
- auth_service = GoogleAuthService()
15
-
16
- # Verify a Google ID token
17
- user_info = auth_service.verify_token(id_token)
18
- print(user_info.email, user_info.google_id, user_info.name)
19
-
20
- Environment Variables:
21
- GOOGLE_CLIENT_ID: Your Google OAuth 2.0 Client ID
22
-
23
- Dependencies:
24
- google-auth>=2.0.0
25
- google-auth-oauthlib>=1.0.0
26
- """
27
-
28
- import os
29
- import logging
30
- from dataclasses import dataclass
31
- from typing import Optional
32
- from google.oauth2 import id_token as google_id_token
33
- from google.auth.transport import requests as google_requests
34
-
35
- logger = logging.getLogger(__name__)
36
-
37
-
38
- @dataclass
39
- class GoogleUserInfo:
40
- """
41
- User information extracted from a verified Google ID token.
42
-
43
- Attributes:
44
- google_id: Unique Google user identifier (sub claim)
45
- email: User's email address
46
- email_verified: Whether Google has verified the email
47
- name: User's display name (may be None)
48
- picture: URL to user's profile picture (may be None)
49
- given_name: User's first name (may be None)
50
- family_name: User's last name (may be None)
51
- locale: User's locale preference (may be None)
52
- """
53
- google_id: str
54
- email: str
55
- email_verified: bool = True
56
- name: Optional[str] = None
57
- picture: Optional[str] = None
58
- given_name: Optional[str] = None
59
- family_name: Optional[str] = None
60
- locale: Optional[str] = None
61
-
62
-
63
- class GoogleAuthError(Exception):
64
- """Base exception for Google Auth errors."""
65
- pass
66
-
67
-
68
- class InvalidTokenError(GoogleAuthError):
69
- """Raised when the token is invalid or expired."""
70
- pass
71
-
72
-
73
- class ConfigurationError(GoogleAuthError):
74
- """Raised when the service is not properly configured."""
75
- pass
76
-
77
-
78
- class GoogleAuthService:
79
- """
80
- Service for verifying Google OAuth ID tokens.
81
-
82
- This service validates ID tokens issued by Google Sign-In and extracts
83
- user information. It's designed to be modular and reusable across
84
- different applications.
85
-
86
- Example:
87
- service = GoogleAuthService()
88
- try:
89
- user_info = service.verify_token(token_from_frontend)
90
- print(f"Welcome {user_info.name}!")
91
- except InvalidTokenError:
92
- print("Invalid or expired token")
93
- """
94
-
95
- def __init__(
96
- self,
97
- client_id: Optional[str] = None,
98
- clock_skew_seconds: int = 0
99
- ):
100
- """
101
- Initialize the Google Auth Service.
102
-
103
- Args:
104
- client_id: Google OAuth 2.0 Client ID. If not provided,
105
- falls back to GOOGLE_CLIENT_ID environment variable.
106
- clock_skew_seconds: Allowed clock skew in seconds for token
107
- validation (default: 0).
108
-
109
- Raises:
110
- ConfigurationError: If no client_id is provided or found.
111
- """
112
- self.client_id = client_id or os.getenv("AUTH_SIGN_IN_GOOGLE_CLIENT_ID")
113
- self.clock_skew_seconds = clock_skew_seconds
114
-
115
- if not self.client_id:
116
- raise ConfigurationError(
117
- "Google Client ID is required. Either pass client_id parameter "
118
- "or set GOOGLE_CLIENT_ID environment variable."
119
- )
120
-
121
- logger.info(f"GoogleAuthService initialized with client_id: {self.client_id[:20]}...")
122
-
123
- def verify_token(self, id_token: str) -> GoogleUserInfo:
124
- """
125
- Verify a Google ID token and extract user information.
126
-
127
- Args:
128
- id_token: The ID token received from the frontend after
129
- Google Sign-In.
130
-
131
- Returns:
132
- GoogleUserInfo: Dataclass containing user's Google profile info.
133
-
134
- Raises:
135
- InvalidTokenError: If the token is invalid, expired, or
136
- doesn't match the expected client ID.
137
- """
138
- if not id_token:
139
- raise InvalidTokenError("Token cannot be empty")
140
-
141
- try:
142
- # Verify the token with Google
143
- idinfo = google_id_token.verify_oauth2_token(
144
- id_token,
145
- google_requests.Request(),
146
- self.client_id,
147
- clock_skew_in_seconds=self.clock_skew_seconds
148
- )
149
-
150
- # Validate issuer
151
- if idinfo.get("iss") not in ["accounts.google.com", "https://accounts.google.com"]:
152
- raise InvalidTokenError("Invalid token issuer")
153
-
154
- # Validate audience
155
- if idinfo.get("aud") != self.client_id:
156
- raise InvalidTokenError("Token was not issued for this application")
157
-
158
- # Extract user info
159
- return GoogleUserInfo(
160
- google_id=idinfo["sub"],
161
- email=idinfo["email"],
162
- email_verified=idinfo.get("email_verified", False),
163
- name=idinfo.get("name"),
164
- picture=idinfo.get("picture"),
165
- given_name=idinfo.get("given_name"),
166
- family_name=idinfo.get("family_name"),
167
- locale=idinfo.get("locale")
168
- )
169
-
170
- except ValueError as e:
171
- logger.warning(f"Token verification failed: {e}")
172
- raise InvalidTokenError(f"Token verification failed: {str(e)}")
173
- except Exception as e:
174
- logger.error(f"Unexpected error during token verification: {e}")
175
- raise InvalidTokenError(f"Token verification error: {str(e)}")
176
-
177
- def verify_token_safe(self, id_token: str) -> Optional[GoogleUserInfo]:
178
- """
179
- Verify a Google ID token without raising exceptions.
180
-
181
- Useful for cases where you want to check validity without
182
- exception handling.
183
-
184
- Args:
185
- id_token: The ID token to verify.
186
-
187
- Returns:
188
- GoogleUserInfo if valid, None if invalid.
189
- """
190
- try:
191
- return self.verify_token(id_token)
192
- except GoogleAuthError:
193
- return None
194
-
195
-
196
- # Singleton instance for convenience (initialized on first use)
197
- _default_service: Optional[GoogleAuthService] = None
198
-
199
-
200
- def get_google_auth_service() -> GoogleAuthService:
201
- """
202
- Get the default GoogleAuthService instance.
203
-
204
- Creates a singleton instance using environment variables.
205
-
206
- Returns:
207
- GoogleAuthService: The default service instance.
208
-
209
- Raises:
210
- ConfigurationError: If GOOGLE_CLIENT_ID is not set.
211
- """
212
- global _default_service
213
- if _default_service is None:
214
- _default_service = GoogleAuthService()
215
- return _default_service
216
-
217
-
218
- def verify_google_token(id_token: str) -> GoogleUserInfo:
219
- """
220
- Convenience function to verify a token using the default service.
221
-
222
- Args:
223
- id_token: The Google ID token to verify.
224
-
225
- Returns:
226
- GoogleUserInfo: Verified user information.
227
-
228
- Raises:
229
- InvalidTokenError: If verification fails.
230
- ConfigurationError: If service is not configured.
231
- """
232
- return get_google_auth_service().verify_token(id_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
services/jwt_service.py DELETED
@@ -1,386 +0,0 @@
1
- """
2
- Modular JWT Service
3
-
4
- A self-contained, plug-and-play service for creating and verifying JWT tokens.
5
- Can be used in any Python application with minimal configuration.
6
-
7
- Usage:
8
- from services.jwt_service import JWTService, TokenPayload
9
-
10
- # Initialize with secret key
11
- jwt_service = JWTService(secret_key="your-secret-key")
12
-
13
- # Or use environment variable JWT_SECRET
14
- jwt_service = JWTService()
15
-
16
- # Create a token
17
- token = jwt_service.create_token(user_id="user123", email="user@example.com")
18
-
19
- # Verify a token
20
- payload = jwt_service.verify_token(token)
21
- print(payload.user_id, payload.email)
22
-
23
- Environment Variables:
24
- JWT_SECRET: Your secret key for signing tokens (required)
25
- JWT_EXPIRY_HOURS: Token expiry in hours (default: 168 = 7 days)
26
- JWT_ALGORITHM: Algorithm to use (default: HS256)
27
-
28
- Dependencies:
29
- PyJWT>=2.8.0
30
-
31
- Generate a secure secret:
32
- python -c "import secrets; print(secrets.token_urlsafe(64))"
33
- """
34
-
35
- import os
36
- import logging
37
- from dataclasses import dataclass
38
- from datetime import datetime, timedelta
39
- from typing import Optional, Dict, Any
40
- import jwt
41
-
42
- logger = logging.getLogger(__name__)
43
-
44
-
45
- @dataclass
46
- class TokenPayload:
47
- """
48
- Payload extracted from a verified JWT token.
49
-
50
- Attributes:
51
- user_id: The user's unique identifier (sub claim)
52
- email: The user's email address
53
- issued_at: When the token was issued
54
- expires_at: When the token expires
55
- token_version: Version number for token invalidation
56
- extra: Any additional claims in the token
57
- """
58
- user_id: str
59
- email: str
60
- issued_at: datetime
61
- expires_at: datetime
62
- token_version: int = 1
63
- extra: Dict[str, Any] = None
64
-
65
- def __post_init__(self):
66
- if self.extra is None:
67
- self.extra = {}
68
-
69
- @property
70
- def is_expired(self) -> bool:
71
- """Check if the token has expired."""
72
- return datetime.utcnow() > self.expires_at
73
-
74
- @property
75
- def time_until_expiry(self) -> timedelta:
76
- """Get time remaining until expiry."""
77
- return self.expires_at - datetime.utcnow()
78
-
79
-
80
- class JWTError(Exception):
81
- """Base exception for JWT errors."""
82
- pass
83
-
84
-
85
- class TokenExpiredError(JWTError):
86
- """Raised when the token has expired."""
87
- pass
88
-
89
-
90
- class InvalidTokenError(JWTError):
91
- """Raised when the token is invalid."""
92
- pass
93
-
94
-
95
- class ConfigurationError(JWTError):
96
- """Raised when the service is not properly configured."""
97
- pass
98
-
99
-
100
- class JWTService:
101
- """
102
- Service for creating and verifying JWT tokens.
103
-
104
- This service handles JWT token lifecycle for authentication.
105
- It's designed to be modular and reusable across different applications.
106
-
107
- Example:
108
- service = JWTService(secret_key="my-secret")
109
-
110
- # Create token
111
- token = service.create_token(user_id="u123", email="a@b.com")
112
-
113
- # Verify token
114
- try:
115
- payload = service.verify_token(token)
116
- print(f"User: {payload.user_id}")
117
- except TokenExpiredError:
118
- print("Token expired, please login again")
119
- except InvalidTokenError:
120
- print("Invalid token")
121
- """
122
-
123
- # Default configuration
124
- DEFAULT_ALGORITHM = "HS256"
125
- DEFAULT_EXPIRY_HOURS = 168 # 7 days
126
-
127
- def __init__(
128
- self,
129
- secret_key: Optional[str] = None,
130
- algorithm: Optional[str] = None,
131
- expiry_hours: Optional[int] = None
132
- ):
133
- """
134
- Initialize the JWT Service.
135
-
136
- Args:
137
- secret_key: Secret key for signing tokens. If not provided,
138
- falls back to JWT_SECRET environment variable.
139
- algorithm: JWT algorithm (default: HS256).
140
- expiry_hours: Token expiry in hours (default: 168 = 7 days).
141
-
142
- Raises:
143
- ConfigurationError: If no secret_key is provided or found.
144
- """
145
- self.secret_key = secret_key or os.getenv("JWT_SECRET")
146
- self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
147
- self.expiry_hours = expiry_hours or int(
148
- os.getenv("JWT_EXPIRY_HOURS", str(self.DEFAULT_EXPIRY_HOURS))
149
- )
150
-
151
- if not self.secret_key:
152
- raise ConfigurationError(
153
- "JWT secret key is required. Either pass secret_key parameter "
154
- "or set JWT_SECRET environment variable. "
155
- "Generate one with: python -c \"import secrets; print(secrets.token_urlsafe(64))\""
156
- )
157
-
158
- # Warn if secret is too short
159
- if len(self.secret_key) < 32:
160
- logger.warning(
161
- "JWT secret key is short (< 32 chars). "
162
- "Consider using a longer secret for better security."
163
- )
164
-
165
- logger.info(
166
- f"JWTService initialized (algorithm={self.algorithm}, "
167
- f"expiry={self.expiry_hours}h)"
168
- )
169
-
170
- def create_token(
171
- self,
172
- user_id: str,
173
- email: str,
174
- token_version: int = 1,
175
- extra_claims: Optional[Dict[str, Any]] = None,
176
- expiry_hours: Optional[int] = None
177
- ) -> str:
178
- """
179
- Create a JWT token for a user.
180
-
181
- Args:
182
- user_id: The user's unique identifier.
183
- email: The user's email address.
184
- token_version: User's current token version for invalidation.
185
- extra_claims: Additional claims to include in the token.
186
- expiry_hours: Custom expiry for this token (overrides default).
187
-
188
- Returns:
189
- str: The encoded JWT token.
190
- """
191
- now = datetime.utcnow()
192
- expiry = expiry_hours or self.expiry_hours
193
-
194
- payload = {
195
- "sub": user_id,
196
- "email": email,
197
- "tv": token_version, # Token version for invalidation
198
- "iat": now,
199
- "exp": now + timedelta(hours=expiry),
200
- }
201
-
202
- if extra_claims:
203
- payload.update(extra_claims)
204
-
205
- token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
206
-
207
- logger.debug(f"Created token for user_id={user_id} (version={token_version})")
208
- return token
209
-
210
- def verify_token(self, token: str) -> TokenPayload:
211
- """
212
- Verify a JWT token and extract the payload.
213
-
214
- Args:
215
- token: The JWT token to verify.
216
-
217
- Returns:
218
- TokenPayload: Dataclass containing the verified payload.
219
-
220
- Raises:
221
- TokenExpiredError: If the token has expired.
222
- InvalidTokenError: If the token is invalid or malformed.
223
- """
224
- if not token:
225
- raise InvalidTokenError("Token cannot be empty")
226
-
227
- try:
228
- payload = jwt.decode(
229
- token,
230
- self.secret_key,
231
- algorithms=[self.algorithm]
232
- )
233
-
234
- # Extract standard claims
235
- user_id = payload.get("sub")
236
- email = payload.get("email")
237
- token_version = payload.get("tv", 1) # Default to 1 for backward compatibility
238
- iat = payload.get("iat")
239
- exp = payload.get("exp")
240
-
241
- if not user_id or not email:
242
- raise InvalidTokenError("Token missing required claims (sub, email)")
243
-
244
- # Convert timestamps to datetime
245
- issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
246
- expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
247
-
248
- # Extract extra claims
249
- standard_claims = {"sub", "email", "tv", "iat", "exp"}
250
- extra = {k: v for k, v in payload.items() if k not in standard_claims}
251
-
252
- return TokenPayload(
253
- user_id=user_id,
254
- email=email,
255
- issued_at=issued_at,
256
- expires_at=expires_at,
257
- token_version=token_version,
258
- extra=extra
259
- )
260
-
261
- except jwt.ExpiredSignatureError:
262
- logger.debug("Token verification failed: expired")
263
- raise TokenExpiredError("Token has expired")
264
- except jwt.InvalidTokenError as e:
265
- logger.debug(f"Token verification failed: {e}")
266
- raise InvalidTokenError(f"Invalid token: {str(e)}")
267
- except Exception as e:
268
- logger.error(f"Unexpected error during token verification: {e}")
269
- raise InvalidTokenError(f"Token verification error: {str(e)}")
270
-
271
- def verify_token_safe(self, token: str) -> Optional[TokenPayload]:
272
- """
273
- Verify a JWT token without raising exceptions.
274
-
275
- Args:
276
- token: The JWT token to verify.
277
-
278
- Returns:
279
- TokenPayload if valid, None if invalid or expired.
280
- """
281
- try:
282
- return self.verify_token(token)
283
- except JWTError:
284
- return None
285
-
286
- def refresh_token(
287
- self,
288
- token: str,
289
- expiry_hours: Optional[int] = None
290
- ) -> str:
291
- """
292
- Refresh a token by creating a new one with the same claims.
293
-
294
- Args:
295
- token: The current (possibly expired) token.
296
- expiry_hours: Custom expiry for the new token.
297
-
298
- Returns:
299
- str: A new JWT token with updated expiry.
300
-
301
- Raises:
302
- InvalidTokenError: If the token is malformed.
303
- """
304
- try:
305
- # Decode without verifying expiry
306
- payload = jwt.decode(
307
- token,
308
- self.secret_key,
309
- algorithms=[self.algorithm],
310
- options={"verify_exp": False}
311
- )
312
-
313
- user_id = payload.get("sub")
314
- email = payload.get("email")
315
-
316
- if not user_id or not email:
317
- raise InvalidTokenError("Token missing required claims")
318
-
319
- # Preserve extra claims
320
- standard_claims = {"sub", "email", "iat", "exp"}
321
- extra = {k: v for k, v in payload.items() if k not in standard_claims}
322
-
323
- return self.create_token(
324
- user_id=user_id,
325
- email=email,
326
- extra_claims=extra,
327
- expiry_hours=expiry_hours
328
- )
329
-
330
- except jwt.InvalidTokenError as e:
331
- raise InvalidTokenError(f"Cannot refresh invalid token: {str(e)}")
332
-
333
-
334
- # Singleton instance for convenience
335
- _default_service: Optional[JWTService] = None
336
-
337
-
338
- def get_jwt_service() -> JWTService:
339
- """
340
- Get the default JWTService instance.
341
-
342
- Creates a singleton instance using environment variables.
343
-
344
- Returns:
345
- JWTService: The default service instance.
346
-
347
- Raises:
348
- ConfigurationError: If JWT_SECRET is not set.
349
- """
350
- global _default_service
351
- if _default_service is None:
352
- _default_service = JWTService()
353
- return _default_service
354
-
355
-
356
- def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
357
- """
358
- Convenience function to create a token using the default service.
359
-
360
- Args:
361
- user_id: The user's unique identifier.
362
- email: The user's email address.
363
- token_version: User's current token version for invalidation.
364
- **kwargs: Additional arguments passed to create_token.
365
-
366
- Returns:
367
- str: The encoded JWT token.
368
- """
369
- return get_jwt_service().create_token(user_id, email, token_version, **kwargs)
370
-
371
-
372
- def verify_access_token(token: str) -> TokenPayload:
373
- """
374
- Convenience function to verify a token using the default service.
375
-
376
- Args:
377
- token: The JWT token to verify.
378
-
379
- Returns:
380
- TokenPayload: Verified token payload.
381
-
382
- Raises:
383
- TokenExpiredError: If the token has expired.
384
- InvalidTokenError: If the token is invalid.
385
- """
386
- return get_jwt_service().verify_token(token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
services/priority_worker_pool.py CHANGED
@@ -378,7 +378,7 @@ class PriorityWorker(Generic[JobType]):
378
  return
379
 
380
  try:
381
- from services.credit_service import handle_job_completion
382
  await handle_job_completion(session, job)
383
  except ImportError:
384
  # Credit service not available - skip
@@ -519,7 +519,7 @@ class PriorityWorkerPool(Generic[JobType]):
519
  async def _refund_orphaned_jobs(self):
520
  """Refund credits for jobs abandoned during shutdown."""
521
  try:
522
- from services.credit_service import refund_orphaned_jobs
523
  async with self.session_maker() as session:
524
  refund_count = await refund_orphaned_jobs(session)
525
  if refund_count > 0:
 
378
  return
379
 
380
  try:
381
+ from services.credit_service.credit_manager import handle_job_completion
382
  await handle_job_completion(session, job)
383
  except ImportError:
384
  # Credit service not available - skip
 
519
  async def _refund_orphaned_jobs(self):
520
  """Refund credits for jobs abandoned during shutdown."""
521
  try:
522
+ from services.credit_service.credit_manager import refund_orphaned_jobs
523
  async with self.session_maker() as session:
524
  refund_count = await refund_orphaned_jobs(session)
525
  if refund_count > 0: