Spaces:
Sleeping
Sleeping
refactor
Browse files- app.py +1 -1
- dependencies.py +1 -1
- routers/auth.py +4 -7
- services/credit_service.py +0 -257
- services/google_auth_service.py +0 -232
- services/jwt_service.py +0 -386
- services/priority_worker_pool.py +2 -2
app.py
CHANGED
|
@@ -108,7 +108,7 @@ async def lifespan(app: FastAPI):
|
|
| 108 |
logger.info("✅ Database initialized")
|
| 109 |
|
| 110 |
# Start background job worker
|
| 111 |
-
from services.
|
| 112 |
await start_worker()
|
| 113 |
logger.info("Background job worker started")
|
| 114 |
|
|
|
|
| 108 |
logger.info("✅ Database initialized")
|
| 109 |
|
| 110 |
# Start background job worker
|
| 111 |
+
from services.gemini_service import start_worker, stop_worker
|
| 112 |
await start_worker()
|
| 113 |
logger.info("Background job worker started")
|
| 114 |
|
dependencies.py
CHANGED
|
@@ -9,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|
| 9 |
|
| 10 |
from core.database import get_db
|
| 11 |
from core.models import User, RateLimit
|
| 12 |
-
from services.
|
| 13 |
verify_access_token,
|
| 14 |
TokenExpiredError,
|
| 15 |
InvalidTokenError,
|
|
|
|
| 9 |
|
| 10 |
from core.database import get_db
|
| 11 |
from core.models import User, RateLimit
|
| 12 |
+
from services.auth_service.jwt_provider import (
|
| 13 |
verify_access_token,
|
| 14 |
TokenExpiredError,
|
| 15 |
InvalidTokenError,
|
routers/auth.py
CHANGED
|
@@ -22,17 +22,14 @@ from core.schemas import (
|
|
| 22 |
TokenRefreshRequest,
|
| 23 |
TokenRefreshResponse
|
| 24 |
)
|
| 25 |
-
from services.
|
| 26 |
GoogleAuthService,
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
get_google_auth_service
|
| 30 |
)
|
| 31 |
-
from services.
|
| 32 |
JWTService,
|
| 33 |
create_access_token,
|
| 34 |
-
get_jwt_service,
|
| 35 |
-
InvalidTokenError as JWTInvalidTokenError
|
| 36 |
)
|
| 37 |
from dependencies import check_rate_limit, get_current_user
|
| 38 |
from services.drive_service import DriveService
|
|
|
|
| 22 |
TokenRefreshRequest,
|
| 23 |
TokenRefreshResponse
|
| 24 |
)
|
| 25 |
+
from services.auth_service.google_provider import (
|
| 26 |
GoogleAuthService,
|
| 27 |
+
GoogleUserInfo,
|
| 28 |
+
InvalidTokenError,
|
|
|
|
| 29 |
)
|
| 30 |
+
from services.auth_service.jwt_provider import (
|
| 31 |
JWTService,
|
| 32 |
create_access_token,
|
|
|
|
|
|
|
| 33 |
)
|
| 34 |
from dependencies import check_rate_limit, get_current_user
|
| 35 |
from services.drive_service import DriveService
|
services/credit_service.py
DELETED
|
@@ -1,257 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Credit Service - Manages credit reservation, confirmation, and refunding.
|
| 3 |
-
|
| 4 |
-
Implements the Credit Reservation Pattern:
|
| 5 |
-
1. Reserve credits when job is created (deduct from user, track in job)
|
| 6 |
-
2. Confirm credits only on successful completion
|
| 7 |
-
3. Refund credits on refundable errors (server-side issues)
|
| 8 |
-
4. Keep credits on non-refundable errors (user-caused issues)
|
| 9 |
-
"""
|
| 10 |
-
import logging
|
| 11 |
-
from typing import Optional
|
| 12 |
-
from sqlalchemy.ext.asyncio import AsyncSession
|
| 13 |
-
from sqlalchemy import select
|
| 14 |
-
|
| 15 |
-
logger = logging.getLogger(__name__)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
# =============================================================================
|
| 19 |
-
# Error Categories for Refund Decisions
|
| 20 |
-
# =============================================================================
|
| 21 |
-
|
| 22 |
-
# Refundable errors - User gets credits back (server/API issues)
|
| 23 |
-
REFUNDABLE_ERROR_PATTERNS = [
|
| 24 |
-
"API_KEY_INVALID",
|
| 25 |
-
"QUOTA_EXCEEDED",
|
| 26 |
-
"INTERNAL_ERROR",
|
| 27 |
-
"CONNECTION_FAILED",
|
| 28 |
-
"SERVER_SHUTDOWN",
|
| 29 |
-
"TIMEOUT",
|
| 30 |
-
"Server Authentication Error",
|
| 31 |
-
"Network error",
|
| 32 |
-
"Connection refused",
|
| 33 |
-
"Connection reset",
|
| 34 |
-
"Service unavailable",
|
| 35 |
-
"503",
|
| 36 |
-
"500",
|
| 37 |
-
"429", # Rate limit (our quota, not user's fault)
|
| 38 |
-
]
|
| 39 |
-
|
| 40 |
-
# Non-refundable error patterns - User's input/content issue
|
| 41 |
-
NON_REFUNDABLE_ERROR_PATTERNS = [
|
| 42 |
-
"safety",
|
| 43 |
-
"blocked",
|
| 44 |
-
"SAFETY_FILTER",
|
| 45 |
-
"INVALID_INPUT",
|
| 46 |
-
"Invalid image",
|
| 47 |
-
"Bad request",
|
| 48 |
-
"400",
|
| 49 |
-
"cancelled",
|
| 50 |
-
"User cancelled",
|
| 51 |
-
]
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def is_refundable_error(error_message: Optional[str]) -> bool:
|
| 55 |
-
"""
|
| 56 |
-
Determine if an error should result in a credit refund.
|
| 57 |
-
|
| 58 |
-
Args:
|
| 59 |
-
error_message: The error message from the failed job
|
| 60 |
-
|
| 61 |
-
Returns:
|
| 62 |
-
True if the error is refundable (server/API issue)
|
| 63 |
-
False if non-refundable (user's fault) or no error message
|
| 64 |
-
"""
|
| 65 |
-
if not error_message:
|
| 66 |
-
return False
|
| 67 |
-
|
| 68 |
-
error_lower = error_message.lower()
|
| 69 |
-
|
| 70 |
-
# Check for REFUNDABLE patterns FIRST (specific server errors take precedence)
|
| 71 |
-
# This ensures API_KEY_INVALID is caught before generic "400" matcher
|
| 72 |
-
for pattern in REFUNDABLE_ERROR_PATTERNS:
|
| 73 |
-
if pattern.lower() in error_lower:
|
| 74 |
-
logger.debug(f"Error matched refundable pattern '{pattern}': {error_message[:100]}")
|
| 75 |
-
return True
|
| 76 |
-
|
| 77 |
-
# Check for non-refundable patterns (user-caused issues)
|
| 78 |
-
for pattern in NON_REFUNDABLE_ERROR_PATTERNS:
|
| 79 |
-
if pattern.lower() in error_lower:
|
| 80 |
-
logger.debug(f"Error matched non-refundable pattern '{pattern}': {error_message[:100]}")
|
| 81 |
-
return False
|
| 82 |
-
|
| 83 |
-
# Default: Max retries exceeded is refundable (we consumed API resources trying)
|
| 84 |
-
if "max retries" in error_lower:
|
| 85 |
-
return True
|
| 86 |
-
|
| 87 |
-
# Default: Unknown errors are NOT refundable to prevent abuse
|
| 88 |
-
# If it's an unknown error, it's more likely user-caused
|
| 89 |
-
logger.debug(f"Unknown error (not refundable): {error_message[:100]}")
|
| 90 |
-
return False
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
async def reserve_credit(session: AsyncSession, user, amount: int = 1) -> bool:
|
| 94 |
-
"""
|
| 95 |
-
Reserve credits for a job (deduct from user's balance).
|
| 96 |
-
|
| 97 |
-
The credits are deducted but tracked in the job's credits_reserved field.
|
| 98 |
-
If the job fails with a refundable error, they can be restored.
|
| 99 |
-
|
| 100 |
-
Args:
|
| 101 |
-
session: Database session
|
| 102 |
-
user: User model instance
|
| 103 |
-
amount: Number of credits to reserve (default: 1)
|
| 104 |
-
|
| 105 |
-
Returns:
|
| 106 |
-
True if credits were successfully reserved
|
| 107 |
-
False if user has insufficient credits
|
| 108 |
-
"""
|
| 109 |
-
if user.credits < amount:
|
| 110 |
-
logger.warning(f"User {user.user_id} has insufficient credits ({user.credits}) to reserve {amount}")
|
| 111 |
-
return False
|
| 112 |
-
|
| 113 |
-
user.credits -= amount
|
| 114 |
-
logger.info(f"Reserved {amount} credit(s) for user {user.user_id}. Remaining: {user.credits}")
|
| 115 |
-
# Note: Don't commit here - let caller handle transaction
|
| 116 |
-
return True
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
async def confirm_credit(session: AsyncSession, job) -> None:
|
| 120 |
-
"""
|
| 121 |
-
Confirm that credits were legitimately used for a completed job.
|
| 122 |
-
|
| 123 |
-
This is called when a job completes successfully. The credits stay
|
| 124 |
-
deducted (they were already deducted during reservation).
|
| 125 |
-
|
| 126 |
-
Args:
|
| 127 |
-
session: Database session
|
| 128 |
-
job: GeminiJob model instance
|
| 129 |
-
"""
|
| 130 |
-
if job.credits_reserved > 0:
|
| 131 |
-
# Credits were used - clear the reservation tracking
|
| 132 |
-
credits_used = job.credits_reserved
|
| 133 |
-
job.credits_reserved = 0
|
| 134 |
-
logger.info(f"Confirmed {credits_used} credit(s) used for job {job.job_id}")
|
| 135 |
-
# Note: Don't commit here - let caller handle transaction
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
async def refund_credit(session: AsyncSession, job, reason: str) -> bool:
|
| 139 |
-
"""
|
| 140 |
-
Refund reserved credits back to the user.
|
| 141 |
-
|
| 142 |
-
Called when a job fails due to a refundable error (server-side issue).
|
| 143 |
-
|
| 144 |
-
Args:
|
| 145 |
-
session: Database session
|
| 146 |
-
job: GeminiJob model instance
|
| 147 |
-
reason: Reason for the refund (for logging)
|
| 148 |
-
|
| 149 |
-
Returns:
|
| 150 |
-
True if credits were refunded
|
| 151 |
-
False if no credits to refund or already refunded
|
| 152 |
-
"""
|
| 153 |
-
if job.credits_reserved <= 0:
|
| 154 |
-
logger.debug(f"Job {job.job_id} has no credits to refund")
|
| 155 |
-
return False
|
| 156 |
-
|
| 157 |
-
if job.credits_refunded:
|
| 158 |
-
logger.warning(f"Job {job.job_id} was already refunded")
|
| 159 |
-
return False
|
| 160 |
-
|
| 161 |
-
# Get the user to restore credits
|
| 162 |
-
from core.models import User
|
| 163 |
-
|
| 164 |
-
result = await session.execute(
|
| 165 |
-
select(User).where(User.id == job.user_id)
|
| 166 |
-
)
|
| 167 |
-
user = result.scalar_one_or_none()
|
| 168 |
-
|
| 169 |
-
if not user:
|
| 170 |
-
logger.error(f"Cannot refund job {job.job_id}: User {job.user_id} not found")
|
| 171 |
-
return False
|
| 172 |
-
|
| 173 |
-
# Restore credits
|
| 174 |
-
credits_to_refund = job.credits_reserved
|
| 175 |
-
user.credits += credits_to_refund
|
| 176 |
-
job.credits_reserved = 0
|
| 177 |
-
job.credits_refunded = True
|
| 178 |
-
|
| 179 |
-
logger.info(
|
| 180 |
-
f"Refunded {credits_to_refund} credit(s) to user {user.user_id} for job {job.job_id}. "
|
| 181 |
-
f"Reason: {reason[:100]}. New balance: {user.credits}"
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
# Note: Don't commit here - let caller handle transaction
|
| 185 |
-
return True
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
async def handle_job_completion(session: AsyncSession, job) -> None:
|
| 189 |
-
"""
|
| 190 |
-
Handle credit finalization when a job completes or fails.
|
| 191 |
-
|
| 192 |
-
This is the main entry point called by the job worker.
|
| 193 |
-
|
| 194 |
-
Args:
|
| 195 |
-
session: Database session
|
| 196 |
-
job: GeminiJob model instance with final status
|
| 197 |
-
"""
|
| 198 |
-
if job.status == "completed":
|
| 199 |
-
# Success - confirm credits were used
|
| 200 |
-
await confirm_credit(session, job)
|
| 201 |
-
|
| 202 |
-
elif job.status == "failed":
|
| 203 |
-
# Failure - check if refundable
|
| 204 |
-
if is_refundable_error(job.error_message):
|
| 205 |
-
await refund_credit(session, job, job.error_message or "Unknown error")
|
| 206 |
-
else:
|
| 207 |
-
# Non-refundable - confirm credits were used (user's fault)
|
| 208 |
-
await confirm_credit(session, job)
|
| 209 |
-
logger.info(f"Job {job.job_id} failed with non-refundable error, credits kept")
|
| 210 |
-
|
| 211 |
-
elif job.status == "cancelled":
|
| 212 |
-
# Cancelled jobs get refunds only if they were never started
|
| 213 |
-
if job.started_at is None:
|
| 214 |
-
await refund_credit(session, job, "Job cancelled before processing")
|
| 215 |
-
else:
|
| 216 |
-
# Was processing - keep credits (API may have been consumed)
|
| 217 |
-
await confirm_credit(session, job)
|
| 218 |
-
logger.info(f"Job {job.job_id} cancelled during processing, credits kept")
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
async def refund_orphaned_jobs(session: AsyncSession) -> int:
|
| 222 |
-
"""
|
| 223 |
-
Refund credits for jobs that were abandoned due to server shutdown.
|
| 224 |
-
|
| 225 |
-
Called during graceful shutdown to ensure no credits are lost.
|
| 226 |
-
|
| 227 |
-
Args:
|
| 228 |
-
session: Database session
|
| 229 |
-
|
| 230 |
-
Returns:
|
| 231 |
-
Number of jobs that were refunded
|
| 232 |
-
"""
|
| 233 |
-
from core.models import GeminiJob
|
| 234 |
-
|
| 235 |
-
# Find jobs that are still processing with reserved credits
|
| 236 |
-
result = await session.execute(
|
| 237 |
-
select(GeminiJob).where(
|
| 238 |
-
GeminiJob.status == "processing",
|
| 239 |
-
GeminiJob.credits_reserved > 0,
|
| 240 |
-
GeminiJob.credits_refunded == False
|
| 241 |
-
)
|
| 242 |
-
)
|
| 243 |
-
orphaned_jobs = result.scalars().all()
|
| 244 |
-
|
| 245 |
-
refund_count = 0
|
| 246 |
-
for job in orphaned_jobs:
|
| 247 |
-
if await refund_credit(session, job, "SERVER_SHUTDOWN: Job orphaned during server shutdown"):
|
| 248 |
-
# Mark job as failed
|
| 249 |
-
job.status = "failed"
|
| 250 |
-
job.error_message = "Server shutdown during processing. Credits refunded."
|
| 251 |
-
refund_count += 1
|
| 252 |
-
|
| 253 |
-
if refund_count > 0:
|
| 254 |
-
await session.commit()
|
| 255 |
-
logger.info(f"Refunded {refund_count} orphaned job(s) during shutdown")
|
| 256 |
-
|
| 257 |
-
return refund_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
services/google_auth_service.py
DELETED
|
@@ -1,232 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Modular Google OAuth Service
|
| 3 |
-
|
| 4 |
-
A self-contained, plug-and-play service for verifying Google ID tokens.
|
| 5 |
-
Can be used in any Python application with minimal configuration.
|
| 6 |
-
|
| 7 |
-
Usage:
|
| 8 |
-
from services.google_auth_service import GoogleAuthService, GoogleUserInfo
|
| 9 |
-
|
| 10 |
-
# Initialize with client ID
|
| 11 |
-
auth_service = GoogleAuthService(client_id="your-google-client-id")
|
| 12 |
-
|
| 13 |
-
# Or use environment variable GOOGLE_CLIENT_ID
|
| 14 |
-
auth_service = GoogleAuthService()
|
| 15 |
-
|
| 16 |
-
# Verify a Google ID token
|
| 17 |
-
user_info = auth_service.verify_token(id_token)
|
| 18 |
-
print(user_info.email, user_info.google_id, user_info.name)
|
| 19 |
-
|
| 20 |
-
Environment Variables:
|
| 21 |
-
GOOGLE_CLIENT_ID: Your Google OAuth 2.0 Client ID
|
| 22 |
-
|
| 23 |
-
Dependencies:
|
| 24 |
-
google-auth>=2.0.0
|
| 25 |
-
google-auth-oauthlib>=1.0.0
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
import os
|
| 29 |
-
import logging
|
| 30 |
-
from dataclasses import dataclass
|
| 31 |
-
from typing import Optional
|
| 32 |
-
from google.oauth2 import id_token as google_id_token
|
| 33 |
-
from google.auth.transport import requests as google_requests
|
| 34 |
-
|
| 35 |
-
logger = logging.getLogger(__name__)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
@dataclass
|
| 39 |
-
class GoogleUserInfo:
|
| 40 |
-
"""
|
| 41 |
-
User information extracted from a verified Google ID token.
|
| 42 |
-
|
| 43 |
-
Attributes:
|
| 44 |
-
google_id: Unique Google user identifier (sub claim)
|
| 45 |
-
email: User's email address
|
| 46 |
-
email_verified: Whether Google has verified the email
|
| 47 |
-
name: User's display name (may be None)
|
| 48 |
-
picture: URL to user's profile picture (may be None)
|
| 49 |
-
given_name: User's first name (may be None)
|
| 50 |
-
family_name: User's last name (may be None)
|
| 51 |
-
locale: User's locale preference (may be None)
|
| 52 |
-
"""
|
| 53 |
-
google_id: str
|
| 54 |
-
email: str
|
| 55 |
-
email_verified: bool = True
|
| 56 |
-
name: Optional[str] = None
|
| 57 |
-
picture: Optional[str] = None
|
| 58 |
-
given_name: Optional[str] = None
|
| 59 |
-
family_name: Optional[str] = None
|
| 60 |
-
locale: Optional[str] = None
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class GoogleAuthError(Exception):
|
| 64 |
-
"""Base exception for Google Auth errors."""
|
| 65 |
-
pass
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
class InvalidTokenError(GoogleAuthError):
|
| 69 |
-
"""Raised when the token is invalid or expired."""
|
| 70 |
-
pass
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
class ConfigurationError(GoogleAuthError):
|
| 74 |
-
"""Raised when the service is not properly configured."""
|
| 75 |
-
pass
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
class GoogleAuthService:
|
| 79 |
-
"""
|
| 80 |
-
Service for verifying Google OAuth ID tokens.
|
| 81 |
-
|
| 82 |
-
This service validates ID tokens issued by Google Sign-In and extracts
|
| 83 |
-
user information. It's designed to be modular and reusable across
|
| 84 |
-
different applications.
|
| 85 |
-
|
| 86 |
-
Example:
|
| 87 |
-
service = GoogleAuthService()
|
| 88 |
-
try:
|
| 89 |
-
user_info = service.verify_token(token_from_frontend)
|
| 90 |
-
print(f"Welcome {user_info.name}!")
|
| 91 |
-
except InvalidTokenError:
|
| 92 |
-
print("Invalid or expired token")
|
| 93 |
-
"""
|
| 94 |
-
|
| 95 |
-
def __init__(
|
| 96 |
-
self,
|
| 97 |
-
client_id: Optional[str] = None,
|
| 98 |
-
clock_skew_seconds: int = 0
|
| 99 |
-
):
|
| 100 |
-
"""
|
| 101 |
-
Initialize the Google Auth Service.
|
| 102 |
-
|
| 103 |
-
Args:
|
| 104 |
-
client_id: Google OAuth 2.0 Client ID. If not provided,
|
| 105 |
-
falls back to GOOGLE_CLIENT_ID environment variable.
|
| 106 |
-
clock_skew_seconds: Allowed clock skew in seconds for token
|
| 107 |
-
validation (default: 0).
|
| 108 |
-
|
| 109 |
-
Raises:
|
| 110 |
-
ConfigurationError: If no client_id is provided or found.
|
| 111 |
-
"""
|
| 112 |
-
self.client_id = client_id or os.getenv("AUTH_SIGN_IN_GOOGLE_CLIENT_ID")
|
| 113 |
-
self.clock_skew_seconds = clock_skew_seconds
|
| 114 |
-
|
| 115 |
-
if not self.client_id:
|
| 116 |
-
raise ConfigurationError(
|
| 117 |
-
"Google Client ID is required. Either pass client_id parameter "
|
| 118 |
-
"or set GOOGLE_CLIENT_ID environment variable."
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
logger.info(f"GoogleAuthService initialized with client_id: {self.client_id[:20]}...")
|
| 122 |
-
|
| 123 |
-
def verify_token(self, id_token: str) -> GoogleUserInfo:
|
| 124 |
-
"""
|
| 125 |
-
Verify a Google ID token and extract user information.
|
| 126 |
-
|
| 127 |
-
Args:
|
| 128 |
-
id_token: The ID token received from the frontend after
|
| 129 |
-
Google Sign-In.
|
| 130 |
-
|
| 131 |
-
Returns:
|
| 132 |
-
GoogleUserInfo: Dataclass containing user's Google profile info.
|
| 133 |
-
|
| 134 |
-
Raises:
|
| 135 |
-
InvalidTokenError: If the token is invalid, expired, or
|
| 136 |
-
doesn't match the expected client ID.
|
| 137 |
-
"""
|
| 138 |
-
if not id_token:
|
| 139 |
-
raise InvalidTokenError("Token cannot be empty")
|
| 140 |
-
|
| 141 |
-
try:
|
| 142 |
-
# Verify the token with Google
|
| 143 |
-
idinfo = google_id_token.verify_oauth2_token(
|
| 144 |
-
id_token,
|
| 145 |
-
google_requests.Request(),
|
| 146 |
-
self.client_id,
|
| 147 |
-
clock_skew_in_seconds=self.clock_skew_seconds
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
# Validate issuer
|
| 151 |
-
if idinfo.get("iss") not in ["accounts.google.com", "https://accounts.google.com"]:
|
| 152 |
-
raise InvalidTokenError("Invalid token issuer")
|
| 153 |
-
|
| 154 |
-
# Validate audience
|
| 155 |
-
if idinfo.get("aud") != self.client_id:
|
| 156 |
-
raise InvalidTokenError("Token was not issued for this application")
|
| 157 |
-
|
| 158 |
-
# Extract user info
|
| 159 |
-
return GoogleUserInfo(
|
| 160 |
-
google_id=idinfo["sub"],
|
| 161 |
-
email=idinfo["email"],
|
| 162 |
-
email_verified=idinfo.get("email_verified", False),
|
| 163 |
-
name=idinfo.get("name"),
|
| 164 |
-
picture=idinfo.get("picture"),
|
| 165 |
-
given_name=idinfo.get("given_name"),
|
| 166 |
-
family_name=idinfo.get("family_name"),
|
| 167 |
-
locale=idinfo.get("locale")
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
except ValueError as e:
|
| 171 |
-
logger.warning(f"Token verification failed: {e}")
|
| 172 |
-
raise InvalidTokenError(f"Token verification failed: {str(e)}")
|
| 173 |
-
except Exception as e:
|
| 174 |
-
logger.error(f"Unexpected error during token verification: {e}")
|
| 175 |
-
raise InvalidTokenError(f"Token verification error: {str(e)}")
|
| 176 |
-
|
| 177 |
-
def verify_token_safe(self, id_token: str) -> Optional[GoogleUserInfo]:
|
| 178 |
-
"""
|
| 179 |
-
Verify a Google ID token without raising exceptions.
|
| 180 |
-
|
| 181 |
-
Useful for cases where you want to check validity without
|
| 182 |
-
exception handling.
|
| 183 |
-
|
| 184 |
-
Args:
|
| 185 |
-
id_token: The ID token to verify.
|
| 186 |
-
|
| 187 |
-
Returns:
|
| 188 |
-
GoogleUserInfo if valid, None if invalid.
|
| 189 |
-
"""
|
| 190 |
-
try:
|
| 191 |
-
return self.verify_token(id_token)
|
| 192 |
-
except GoogleAuthError:
|
| 193 |
-
return None
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
# Singleton instance for convenience (initialized on first use)
|
| 197 |
-
_default_service: Optional[GoogleAuthService] = None
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
def get_google_auth_service() -> GoogleAuthService:
|
| 201 |
-
"""
|
| 202 |
-
Get the default GoogleAuthService instance.
|
| 203 |
-
|
| 204 |
-
Creates a singleton instance using environment variables.
|
| 205 |
-
|
| 206 |
-
Returns:
|
| 207 |
-
GoogleAuthService: The default service instance.
|
| 208 |
-
|
| 209 |
-
Raises:
|
| 210 |
-
ConfigurationError: If GOOGLE_CLIENT_ID is not set.
|
| 211 |
-
"""
|
| 212 |
-
global _default_service
|
| 213 |
-
if _default_service is None:
|
| 214 |
-
_default_service = GoogleAuthService()
|
| 215 |
-
return _default_service
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
def verify_google_token(id_token: str) -> GoogleUserInfo:
|
| 219 |
-
"""
|
| 220 |
-
Convenience function to verify a token using the default service.
|
| 221 |
-
|
| 222 |
-
Args:
|
| 223 |
-
id_token: The Google ID token to verify.
|
| 224 |
-
|
| 225 |
-
Returns:
|
| 226 |
-
GoogleUserInfo: Verified user information.
|
| 227 |
-
|
| 228 |
-
Raises:
|
| 229 |
-
InvalidTokenError: If verification fails.
|
| 230 |
-
ConfigurationError: If service is not configured.
|
| 231 |
-
"""
|
| 232 |
-
return get_google_auth_service().verify_token(id_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
services/jwt_service.py
DELETED
|
@@ -1,386 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Modular JWT Service
|
| 3 |
-
|
| 4 |
-
A self-contained, plug-and-play service for creating and verifying JWT tokens.
|
| 5 |
-
Can be used in any Python application with minimal configuration.
|
| 6 |
-
|
| 7 |
-
Usage:
|
| 8 |
-
from services.jwt_service import JWTService, TokenPayload
|
| 9 |
-
|
| 10 |
-
# Initialize with secret key
|
| 11 |
-
jwt_service = JWTService(secret_key="your-secret-key")
|
| 12 |
-
|
| 13 |
-
# Or use environment variable JWT_SECRET
|
| 14 |
-
jwt_service = JWTService()
|
| 15 |
-
|
| 16 |
-
# Create a token
|
| 17 |
-
token = jwt_service.create_token(user_id="user123", email="user@example.com")
|
| 18 |
-
|
| 19 |
-
# Verify a token
|
| 20 |
-
payload = jwt_service.verify_token(token)
|
| 21 |
-
print(payload.user_id, payload.email)
|
| 22 |
-
|
| 23 |
-
Environment Variables:
|
| 24 |
-
JWT_SECRET: Your secret key for signing tokens (required)
|
| 25 |
-
JWT_EXPIRY_HOURS: Token expiry in hours (default: 168 = 7 days)
|
| 26 |
-
JWT_ALGORITHM: Algorithm to use (default: HS256)
|
| 27 |
-
|
| 28 |
-
Dependencies:
|
| 29 |
-
PyJWT>=2.8.0
|
| 30 |
-
|
| 31 |
-
Generate a secure secret:
|
| 32 |
-
python -c "import secrets; print(secrets.token_urlsafe(64))"
|
| 33 |
-
"""
|
| 34 |
-
|
| 35 |
-
import os
|
| 36 |
-
import logging
|
| 37 |
-
from dataclasses import dataclass
|
| 38 |
-
from datetime import datetime, timedelta
|
| 39 |
-
from typing import Optional, Dict, Any
|
| 40 |
-
import jwt
|
| 41 |
-
|
| 42 |
-
logger = logging.getLogger(__name__)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
@dataclass
|
| 46 |
-
class TokenPayload:
|
| 47 |
-
"""
|
| 48 |
-
Payload extracted from a verified JWT token.
|
| 49 |
-
|
| 50 |
-
Attributes:
|
| 51 |
-
user_id: The user's unique identifier (sub claim)
|
| 52 |
-
email: The user's email address
|
| 53 |
-
issued_at: When the token was issued
|
| 54 |
-
expires_at: When the token expires
|
| 55 |
-
token_version: Version number for token invalidation
|
| 56 |
-
extra: Any additional claims in the token
|
| 57 |
-
"""
|
| 58 |
-
user_id: str
|
| 59 |
-
email: str
|
| 60 |
-
issued_at: datetime
|
| 61 |
-
expires_at: datetime
|
| 62 |
-
token_version: int = 1
|
| 63 |
-
extra: Dict[str, Any] = None
|
| 64 |
-
|
| 65 |
-
def __post_init__(self):
|
| 66 |
-
if self.extra is None:
|
| 67 |
-
self.extra = {}
|
| 68 |
-
|
| 69 |
-
@property
|
| 70 |
-
def is_expired(self) -> bool:
|
| 71 |
-
"""Check if the token has expired."""
|
| 72 |
-
return datetime.utcnow() > self.expires_at
|
| 73 |
-
|
| 74 |
-
@property
|
| 75 |
-
def time_until_expiry(self) -> timedelta:
|
| 76 |
-
"""Get time remaining until expiry."""
|
| 77 |
-
return self.expires_at - datetime.utcnow()
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class JWTError(Exception):
|
| 81 |
-
"""Base exception for JWT errors."""
|
| 82 |
-
pass
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
class TokenExpiredError(JWTError):
|
| 86 |
-
"""Raised when the token has expired."""
|
| 87 |
-
pass
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
class InvalidTokenError(JWTError):
|
| 91 |
-
"""Raised when the token is invalid."""
|
| 92 |
-
pass
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
class ConfigurationError(JWTError):
|
| 96 |
-
"""Raised when the service is not properly configured."""
|
| 97 |
-
pass
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
class JWTService:
|
| 101 |
-
"""
|
| 102 |
-
Service for creating and verifying JWT tokens.
|
| 103 |
-
|
| 104 |
-
This service handles JWT token lifecycle for authentication.
|
| 105 |
-
It's designed to be modular and reusable across different applications.
|
| 106 |
-
|
| 107 |
-
Example:
|
| 108 |
-
service = JWTService(secret_key="my-secret")
|
| 109 |
-
|
| 110 |
-
# Create token
|
| 111 |
-
token = service.create_token(user_id="u123", email="a@b.com")
|
| 112 |
-
|
| 113 |
-
# Verify token
|
| 114 |
-
try:
|
| 115 |
-
payload = service.verify_token(token)
|
| 116 |
-
print(f"User: {payload.user_id}")
|
| 117 |
-
except TokenExpiredError:
|
| 118 |
-
print("Token expired, please login again")
|
| 119 |
-
except InvalidTokenError:
|
| 120 |
-
print("Invalid token")
|
| 121 |
-
"""
|
| 122 |
-
|
| 123 |
-
# Default configuration
|
| 124 |
-
DEFAULT_ALGORITHM = "HS256"
|
| 125 |
-
DEFAULT_EXPIRY_HOURS = 168 # 7 days
|
| 126 |
-
|
| 127 |
-
def __init__(
|
| 128 |
-
self,
|
| 129 |
-
secret_key: Optional[str] = None,
|
| 130 |
-
algorithm: Optional[str] = None,
|
| 131 |
-
expiry_hours: Optional[int] = None
|
| 132 |
-
):
|
| 133 |
-
"""
|
| 134 |
-
Initialize the JWT Service.
|
| 135 |
-
|
| 136 |
-
Args:
|
| 137 |
-
secret_key: Secret key for signing tokens. If not provided,
|
| 138 |
-
falls back to JWT_SECRET environment variable.
|
| 139 |
-
algorithm: JWT algorithm (default: HS256).
|
| 140 |
-
expiry_hours: Token expiry in hours (default: 168 = 7 days).
|
| 141 |
-
|
| 142 |
-
Raises:
|
| 143 |
-
ConfigurationError: If no secret_key is provided or found.
|
| 144 |
-
"""
|
| 145 |
-
self.secret_key = secret_key or os.getenv("JWT_SECRET")
|
| 146 |
-
self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
|
| 147 |
-
self.expiry_hours = expiry_hours or int(
|
| 148 |
-
os.getenv("JWT_EXPIRY_HOURS", str(self.DEFAULT_EXPIRY_HOURS))
|
| 149 |
-
)
|
| 150 |
-
|
| 151 |
-
if not self.secret_key:
|
| 152 |
-
raise ConfigurationError(
|
| 153 |
-
"JWT secret key is required. Either pass secret_key parameter "
|
| 154 |
-
"or set JWT_SECRET environment variable. "
|
| 155 |
-
"Generate one with: python -c \"import secrets; print(secrets.token_urlsafe(64))\""
|
| 156 |
-
)
|
| 157 |
-
|
| 158 |
-
# Warn if secret is too short
|
| 159 |
-
if len(self.secret_key) < 32:
|
| 160 |
-
logger.warning(
|
| 161 |
-
"JWT secret key is short (< 32 chars). "
|
| 162 |
-
"Consider using a longer secret for better security."
|
| 163 |
-
)
|
| 164 |
-
|
| 165 |
-
logger.info(
|
| 166 |
-
f"JWTService initialized (algorithm={self.algorithm}, "
|
| 167 |
-
f"expiry={self.expiry_hours}h)"
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
def create_token(
|
| 171 |
-
self,
|
| 172 |
-
user_id: str,
|
| 173 |
-
email: str,
|
| 174 |
-
token_version: int = 1,
|
| 175 |
-
extra_claims: Optional[Dict[str, Any]] = None,
|
| 176 |
-
expiry_hours: Optional[int] = None
|
| 177 |
-
) -> str:
|
| 178 |
-
"""
|
| 179 |
-
Create a JWT token for a user.
|
| 180 |
-
|
| 181 |
-
Args:
|
| 182 |
-
user_id: The user's unique identifier.
|
| 183 |
-
email: The user's email address.
|
| 184 |
-
token_version: User's current token version for invalidation.
|
| 185 |
-
extra_claims: Additional claims to include in the token.
|
| 186 |
-
expiry_hours: Custom expiry for this token (overrides default).
|
| 187 |
-
|
| 188 |
-
Returns:
|
| 189 |
-
str: The encoded JWT token.
|
| 190 |
-
"""
|
| 191 |
-
now = datetime.utcnow()
|
| 192 |
-
expiry = expiry_hours or self.expiry_hours
|
| 193 |
-
|
| 194 |
-
payload = {
|
| 195 |
-
"sub": user_id,
|
| 196 |
-
"email": email,
|
| 197 |
-
"tv": token_version, # Token version for invalidation
|
| 198 |
-
"iat": now,
|
| 199 |
-
"exp": now + timedelta(hours=expiry),
|
| 200 |
-
}
|
| 201 |
-
|
| 202 |
-
if extra_claims:
|
| 203 |
-
payload.update(extra_claims)
|
| 204 |
-
|
| 205 |
-
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
| 206 |
-
|
| 207 |
-
logger.debug(f"Created token for user_id={user_id} (version={token_version})")
|
| 208 |
-
return token
|
| 209 |
-
|
| 210 |
-
def verify_token(self, token: str) -> TokenPayload:
|
| 211 |
-
"""
|
| 212 |
-
Verify a JWT token and extract the payload.
|
| 213 |
-
|
| 214 |
-
Args:
|
| 215 |
-
token: The JWT token to verify.
|
| 216 |
-
|
| 217 |
-
Returns:
|
| 218 |
-
TokenPayload: Dataclass containing the verified payload.
|
| 219 |
-
|
| 220 |
-
Raises:
|
| 221 |
-
TokenExpiredError: If the token has expired.
|
| 222 |
-
InvalidTokenError: If the token is invalid or malformed.
|
| 223 |
-
"""
|
| 224 |
-
if not token:
|
| 225 |
-
raise InvalidTokenError("Token cannot be empty")
|
| 226 |
-
|
| 227 |
-
try:
|
| 228 |
-
payload = jwt.decode(
|
| 229 |
-
token,
|
| 230 |
-
self.secret_key,
|
| 231 |
-
algorithms=[self.algorithm]
|
| 232 |
-
)
|
| 233 |
-
|
| 234 |
-
# Extract standard claims
|
| 235 |
-
user_id = payload.get("sub")
|
| 236 |
-
email = payload.get("email")
|
| 237 |
-
token_version = payload.get("tv", 1) # Default to 1 for backward compatibility
|
| 238 |
-
iat = payload.get("iat")
|
| 239 |
-
exp = payload.get("exp")
|
| 240 |
-
|
| 241 |
-
if not user_id or not email:
|
| 242 |
-
raise InvalidTokenError("Token missing required claims (sub, email)")
|
| 243 |
-
|
| 244 |
-
# Convert timestamps to datetime
|
| 245 |
-
issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
|
| 246 |
-
expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
|
| 247 |
-
|
| 248 |
-
# Extract extra claims
|
| 249 |
-
standard_claims = {"sub", "email", "tv", "iat", "exp"}
|
| 250 |
-
extra = {k: v for k, v in payload.items() if k not in standard_claims}
|
| 251 |
-
|
| 252 |
-
return TokenPayload(
|
| 253 |
-
user_id=user_id,
|
| 254 |
-
email=email,
|
| 255 |
-
issued_at=issued_at,
|
| 256 |
-
expires_at=expires_at,
|
| 257 |
-
token_version=token_version,
|
| 258 |
-
extra=extra
|
| 259 |
-
)
|
| 260 |
-
|
| 261 |
-
except jwt.ExpiredSignatureError:
|
| 262 |
-
logger.debug("Token verification failed: expired")
|
| 263 |
-
raise TokenExpiredError("Token has expired")
|
| 264 |
-
except jwt.InvalidTokenError as e:
|
| 265 |
-
logger.debug(f"Token verification failed: {e}")
|
| 266 |
-
raise InvalidTokenError(f"Invalid token: {str(e)}")
|
| 267 |
-
except Exception as e:
|
| 268 |
-
logger.error(f"Unexpected error during token verification: {e}")
|
| 269 |
-
raise InvalidTokenError(f"Token verification error: {str(e)}")
|
| 270 |
-
|
| 271 |
-
def verify_token_safe(self, token: str) -> Optional[TokenPayload]:
|
| 272 |
-
"""
|
| 273 |
-
Verify a JWT token without raising exceptions.
|
| 274 |
-
|
| 275 |
-
Args:
|
| 276 |
-
token: The JWT token to verify.
|
| 277 |
-
|
| 278 |
-
Returns:
|
| 279 |
-
TokenPayload if valid, None if invalid or expired.
|
| 280 |
-
"""
|
| 281 |
-
try:
|
| 282 |
-
return self.verify_token(token)
|
| 283 |
-
except JWTError:
|
| 284 |
-
return None
|
| 285 |
-
|
| 286 |
-
def refresh_token(
|
| 287 |
-
self,
|
| 288 |
-
token: str,
|
| 289 |
-
expiry_hours: Optional[int] = None
|
| 290 |
-
) -> str:
|
| 291 |
-
"""
|
| 292 |
-
Refresh a token by creating a new one with the same claims.
|
| 293 |
-
|
| 294 |
-
Args:
|
| 295 |
-
token: The current (possibly expired) token.
|
| 296 |
-
expiry_hours: Custom expiry for the new token.
|
| 297 |
-
|
| 298 |
-
Returns:
|
| 299 |
-
str: A new JWT token with updated expiry.
|
| 300 |
-
|
| 301 |
-
Raises:
|
| 302 |
-
InvalidTokenError: If the token is malformed.
|
| 303 |
-
"""
|
| 304 |
-
try:
|
| 305 |
-
# Decode without verifying expiry
|
| 306 |
-
payload = jwt.decode(
|
| 307 |
-
token,
|
| 308 |
-
self.secret_key,
|
| 309 |
-
algorithms=[self.algorithm],
|
| 310 |
-
options={"verify_exp": False}
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
user_id = payload.get("sub")
|
| 314 |
-
email = payload.get("email")
|
| 315 |
-
|
| 316 |
-
if not user_id or not email:
|
| 317 |
-
raise InvalidTokenError("Token missing required claims")
|
| 318 |
-
|
| 319 |
-
# Preserve extra claims
|
| 320 |
-
standard_claims = {"sub", "email", "iat", "exp"}
|
| 321 |
-
extra = {k: v for k, v in payload.items() if k not in standard_claims}
|
| 322 |
-
|
| 323 |
-
return self.create_token(
|
| 324 |
-
user_id=user_id,
|
| 325 |
-
email=email,
|
| 326 |
-
extra_claims=extra,
|
| 327 |
-
expiry_hours=expiry_hours
|
| 328 |
-
)
|
| 329 |
-
|
| 330 |
-
except jwt.InvalidTokenError as e:
|
| 331 |
-
raise InvalidTokenError(f"Cannot refresh invalid token: {str(e)}")
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
# Singleton instance for convenience
|
| 335 |
-
_default_service: Optional[JWTService] = None
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
def get_jwt_service() -> JWTService:
|
| 339 |
-
"""
|
| 340 |
-
Get the default JWTService instance.
|
| 341 |
-
|
| 342 |
-
Creates a singleton instance using environment variables.
|
| 343 |
-
|
| 344 |
-
Returns:
|
| 345 |
-
JWTService: The default service instance.
|
| 346 |
-
|
| 347 |
-
Raises:
|
| 348 |
-
ConfigurationError: If JWT_SECRET is not set.
|
| 349 |
-
"""
|
| 350 |
-
global _default_service
|
| 351 |
-
if _default_service is None:
|
| 352 |
-
_default_service = JWTService()
|
| 353 |
-
return _default_service
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
|
| 357 |
-
"""
|
| 358 |
-
Convenience function to create a token using the default service.
|
| 359 |
-
|
| 360 |
-
Args:
|
| 361 |
-
user_id: The user's unique identifier.
|
| 362 |
-
email: The user's email address.
|
| 363 |
-
token_version: User's current token version for invalidation.
|
| 364 |
-
**kwargs: Additional arguments passed to create_token.
|
| 365 |
-
|
| 366 |
-
Returns:
|
| 367 |
-
str: The encoded JWT token.
|
| 368 |
-
"""
|
| 369 |
-
return get_jwt_service().create_token(user_id, email, token_version, **kwargs)
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
def verify_access_token(token: str) -> TokenPayload:
|
| 373 |
-
"""
|
| 374 |
-
Convenience function to verify a token using the default service.
|
| 375 |
-
|
| 376 |
-
Args:
|
| 377 |
-
token: The JWT token to verify.
|
| 378 |
-
|
| 379 |
-
Returns:
|
| 380 |
-
TokenPayload: Verified token payload.
|
| 381 |
-
|
| 382 |
-
Raises:
|
| 383 |
-
TokenExpiredError: If the token has expired.
|
| 384 |
-
InvalidTokenError: If the token is invalid.
|
| 385 |
-
"""
|
| 386 |
-
return get_jwt_service().verify_token(token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
services/priority_worker_pool.py
CHANGED
|
@@ -378,7 +378,7 @@ class PriorityWorker(Generic[JobType]):
|
|
| 378 |
return
|
| 379 |
|
| 380 |
try:
|
| 381 |
-
from services.credit_service import handle_job_completion
|
| 382 |
await handle_job_completion(session, job)
|
| 383 |
except ImportError:
|
| 384 |
# Credit service not available - skip
|
|
@@ -519,7 +519,7 @@ class PriorityWorkerPool(Generic[JobType]):
|
|
| 519 |
async def _refund_orphaned_jobs(self):
|
| 520 |
"""Refund credits for jobs abandoned during shutdown."""
|
| 521 |
try:
|
| 522 |
-
from services.credit_service import refund_orphaned_jobs
|
| 523 |
async with self.session_maker() as session:
|
| 524 |
refund_count = await refund_orphaned_jobs(session)
|
| 525 |
if refund_count > 0:
|
|
|
|
| 378 |
return
|
| 379 |
|
| 380 |
try:
|
| 381 |
+
from services.credit_service.credit_manager import handle_job_completion
|
| 382 |
await handle_job_completion(session, job)
|
| 383 |
except ImportError:
|
| 384 |
# Credit service not available - skip
|
|
|
|
| 519 |
async def _refund_orphaned_jobs(self):
|
| 520 |
"""Refund credits for jobs abandoned during shutdown."""
|
| 521 |
try:
|
| 522 |
+
from services.credit_service.credit_manager import refund_orphaned_jobs
|
| 523 |
async with self.session_maker() as session:
|
| 524 |
refund_count = await refund_orphaned_jobs(session)
|
| 525 |
if refund_count > 0:
|