Spaces:
Sleeping
Sleeping
refactor
Browse files- app.py +54 -0
- routers/blink.py +29 -10
- routers/contact.py +5 -3
- routers/credits.py +8 -5
- routers/gemini.py +41 -17
- routers/payments.py +9 -4
- services/auth_service/__init__.py +106 -0
- services/auth_service/config.py +164 -0
- services/auth_service/google_provider.py +232 -0
- services/auth_service/jwt_provider.py +386 -0
- services/auth_service/middleware.py +225 -0
- services/base_service/__init__.py +246 -0
- services/base_service/middleware_chain.py +212 -0
- services/base_service/route_matcher.py +254 -0
- services/credit_service/__init__.py +76 -0
- services/credit_service/config.py +87 -0
- services/credit_service/credit_manager.py +257 -0
- services/credit_service/middleware.py +130 -0
- tests/test_base_service.py +252 -0
- tests/test_route_matcher.py +243 -0
app.py
CHANGED
|
@@ -39,6 +39,52 @@ async def lifespan(app: FastAPI):
|
|
| 39 |
register_db_service_config()
|
| 40 |
logger.info("✅ DB Service configured")
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
# Check for RESET_DB environment variable
|
| 43 |
if os.getenv("RESET_DB", "").lower() == "true":
|
| 44 |
logger.warning(f"RESET_DB is set to true. Skipping download and clearing local database ({DB_FILENAME}).")
|
|
@@ -95,6 +141,14 @@ app.add_middleware(
|
|
| 95 |
allow_headers=["*"],
|
| 96 |
)
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
# Include Routers
|
| 99 |
app.include_router(general.router)
|
| 100 |
app.include_router(auth.router)
|
|
|
|
| 39 |
register_db_service_config()
|
| 40 |
logger.info("✅ DB Service configured")
|
| 41 |
|
| 42 |
+
# Register Auth Service configuration
|
| 43 |
+
from services.auth_service import register_auth_service
|
| 44 |
+
register_auth_service(
|
| 45 |
+
required_urls=[
|
| 46 |
+
"/blink",
|
| 47 |
+
"/api/*", # All admin blink API endpoints
|
| 48 |
+
"/contact",
|
| 49 |
+
"/gemini/*",
|
| 50 |
+
"/credits/balance",
|
| 51 |
+
"/credits/history",
|
| 52 |
+
"/payments/create-order",
|
| 53 |
+
"/payments/verify/*",
|
| 54 |
+
],
|
| 55 |
+
optional_urls=[
|
| 56 |
+
"/", # Home page works with or without auth
|
| 57 |
+
],
|
| 58 |
+
public_urls=[
|
| 59 |
+
"/health",
|
| 60 |
+
"/auth/*",
|
| 61 |
+
"/payments/packages", # Public pricing info
|
| 62 |
+
"/payments/webhook/*", # Webhooks from payment gateway
|
| 63 |
+
"/docs",
|
| 64 |
+
"/openapi.json",
|
| 65 |
+
"/redoc",
|
| 66 |
+
],
|
| 67 |
+
jwt_secret=os.getenv("JWT_SECRET"),
|
| 68 |
+
jwt_algorithm="HS256",
|
| 69 |
+
jwt_expiry_hours=24,
|
| 70 |
+
google_client_id=os.getenv("AUTH_SIGN_IN_GOOGLE_CLIENT_ID"),
|
| 71 |
+
admin_emails=os.getenv("ADMIN_EMAILS", "").split(",") if os.getenv("ADMIN_EMAILS") else [],
|
| 72 |
+
)
|
| 73 |
+
logger.info("✅ Auth Service configured")
|
| 74 |
+
|
| 75 |
+
# Register Credit Service configuration
|
| 76 |
+
from services.credit_service import register_credit_service
|
| 77 |
+
register_credit_service(
|
| 78 |
+
route_costs={
|
| 79 |
+
"/gemini/generate-animation-prompt": 1,
|
| 80 |
+
"/gemini/edit-image": 1,
|
| 81 |
+
"/gemini/generate-video": 10,
|
| 82 |
+
"/gemini/generate-text": 1,
|
| 83 |
+
"/gemini/analyze-image": 1,
|
| 84 |
+
}
|
| 85 |
+
)
|
| 86 |
+
logger.info("✅ Credit Service configured")
|
| 87 |
+
|
| 88 |
# Check for RESET_DB environment variable
|
| 89 |
if os.getenv("RESET_DB", "").lower() == "true":
|
| 90 |
logger.warning(f"RESET_DB is set to true. Skipping download and clearing local database ({DB_FILENAME}).")
|
|
|
|
| 141 |
allow_headers=["*"],
|
| 142 |
)
|
| 143 |
|
| 144 |
+
# Add Credit Middleware first (executes second - after auth)
|
| 145 |
+
from services.credit_service import CreditMiddleware
|
| 146 |
+
app.add_middleware(CreditMiddleware)
|
| 147 |
+
|
| 148 |
+
# Add Auth Middleware second (executes first - sets user)
|
| 149 |
+
from services.auth_service import AuthMiddleware
|
| 150 |
+
app.add_middleware(AuthMiddleware)
|
| 151 |
+
|
| 152 |
# Include Routers
|
| 153 |
app.include_router(general.router)
|
| 154 |
app.include_router(auth.router)
|
routers/blink.py
CHANGED
|
@@ -11,7 +11,7 @@ import logging
|
|
| 11 |
from core.database import get_db
|
| 12 |
from core.models import User, AuditLog, GeminiJob, Contact, ClientUser
|
| 13 |
from services.encryption_service import decrypt_multiple_blocks
|
| 14 |
-
from dependencies import get_geolocation
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
|
@@ -27,16 +27,18 @@ USER_ID_LENGTH = 20
|
|
| 27 |
|
| 28 |
@router.get("/api/data")
|
| 29 |
async def get_data(
|
|
|
|
| 30 |
page: int = Query(1, ge=1, description="Page number"),
|
| 31 |
limit: int = Query(100, ge=1, le=500, description="Items per page"),
|
| 32 |
log_type: str = Query(None, description="Filter by log type: client, server"),
|
| 33 |
-
user: User = Depends(get_current_user), # Auth required
|
| 34 |
db: AsyncSession = Depends(get_db)
|
| 35 |
):
|
| 36 |
"""
|
| 37 |
Get paginated audit log data for the authenticated user.
|
| 38 |
Admins see all logs from all users.
|
|
|
|
| 39 |
"""
|
|
|
|
| 40 |
from services.db_service import QueryService
|
| 41 |
|
| 42 |
try:
|
|
@@ -93,15 +95,17 @@ async def get_data(
|
|
| 93 |
|
| 94 |
@router.get("/api/users")
|
| 95 |
async def get_users(
|
|
|
|
| 96 |
page: int = Query(1, ge=1, description="Page number"),
|
| 97 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 98 |
-
user: User = Depends(get_current_user), # Auth required
|
| 99 |
db: AsyncSession = Depends(get_db)
|
| 100 |
):
|
| 101 |
"""
|
| 102 |
Get current user's profile data.
|
| 103 |
Admins see paginated list of all users.
|
|
|
|
| 104 |
"""
|
|
|
|
| 105 |
from services.db_service import QueryService
|
| 106 |
|
| 107 |
try:
|
|
@@ -148,16 +152,18 @@ async def get_users(
|
|
| 148 |
|
| 149 |
@router.get("/api/client-users")
|
| 150 |
async def get_client_users(
|
|
|
|
| 151 |
page: int = Query(1, ge=1, description="Page number"),
|
| 152 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 153 |
user_id: str = Query(None, description="Filter by server user_id"),
|
| 154 |
-
user: User = Depends(get_current_user), # Auth required
|
| 155 |
db: AsyncSession = Depends(get_db)
|
| 156 |
):
|
| 157 |
"""
|
| 158 |
Get current user's client mappings.
|
| 159 |
Admins see all client mappings from all users.
|
|
|
|
| 160 |
"""
|
|
|
|
| 161 |
from services.db_service import QueryService
|
| 162 |
|
| 163 |
try:
|
|
@@ -197,17 +203,19 @@ async def get_client_users(
|
|
| 197 |
|
| 198 |
@router.get("/api/audit-logs")
|
| 199 |
async def get_audit_logs(
|
|
|
|
| 200 |
page: int = Query(1, ge=1, description="Page number"),
|
| 201 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 202 |
log_type: str = Query(None, description="Filter by log type: client, server"),
|
| 203 |
action: str = Query(None, description="Filter by action"),
|
| 204 |
-
user: User = Depends(get_current_user), # Auth required
|
| 205 |
db: AsyncSession = Depends(get_db)
|
| 206 |
):
|
| 207 |
"""
|
| 208 |
Get current user's audit logs with optional filters.
|
| 209 |
Admins see all logs from all users.
|
|
|
|
| 210 |
"""
|
|
|
|
| 211 |
from services.db_service import QueryService
|
| 212 |
|
| 213 |
try:
|
|
@@ -263,15 +271,17 @@ async def get_audit_logs(
|
|
| 263 |
|
| 264 |
@router.get("/api/gemini-jobs")
|
| 265 |
async def get_gemini_jobs(
|
|
|
|
| 266 |
page: int = Query(1, ge=1, description="Page number"),
|
| 267 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 268 |
-
user: User = Depends(get_current_user), # Auth required
|
| 269 |
db: AsyncSession = Depends(get_db)
|
| 270 |
):
|
| 271 |
"""
|
| 272 |
Get current user's Gemini jobs.
|
| 273 |
Admins see all jobs from all users.
|
|
|
|
| 274 |
"""
|
|
|
|
| 275 |
from services.db_service import QueryService
|
| 276 |
|
| 277 |
try:
|
|
@@ -313,15 +323,17 @@ async def get_gemini_jobs(
|
|
| 313 |
|
| 314 |
@router.get("/api/payment-transactions")
|
| 315 |
async def get_payment_transactions(
|
|
|
|
| 316 |
page: int = Query(1, ge=1, description="Page number"),
|
| 317 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 318 |
-
user: User = Depends(get_current_user), # Auth required
|
| 319 |
db: AsyncSession = Depends(get_db)
|
| 320 |
):
|
| 321 |
"""
|
| 322 |
Get current user's payment transactions.
|
| 323 |
Admins see all transactions from all users.
|
|
|
|
| 324 |
"""
|
|
|
|
| 325 |
from core.models import PaymentTransaction
|
| 326 |
from services.db_service import QueryService
|
| 327 |
|
|
@@ -383,15 +395,17 @@ async def get_payment_transactions(
|
|
| 383 |
|
| 384 |
@router.get("/api/contacts")
|
| 385 |
async def get_contacts(
|
|
|
|
| 386 |
page: int = Query(1, ge=1, description="Page number"),
|
| 387 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 388 |
-
user: User = Depends(get_current_user), # Auth required
|
| 389 |
db: AsyncSession = Depends(get_db)
|
| 390 |
):
|
| 391 |
"""
|
| 392 |
Get current user's contact form submissions.
|
| 393 |
Admins see all contact submissions from all users.
|
|
|
|
| 394 |
"""
|
|
|
|
| 395 |
from services.db_service import QueryService
|
| 396 |
|
| 397 |
try:
|
|
@@ -436,13 +450,16 @@ async def get_contacts(
|
|
| 436 |
async def blink(
|
| 437 |
request: Request,
|
| 438 |
userid: str = Query(..., description="User ID (20 chars) + encrypted data"),
|
| 439 |
-
db: AsyncSession = Depends(get_db)
|
| 440 |
-
current_user: User = Depends(get_optional_user)
|
| 441 |
):
|
| 442 |
"""
|
| 443 |
Process blink request with encrypted user data.
|
| 444 |
Logs to AuditLog with log_type='client'.
|
| 445 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
If authenticated via JWT:
|
| 447 |
- Creates a new ClientUser entry linking client_user_id to server user_id
|
| 448 |
- Sets user_id in AuditLog entries
|
|
@@ -450,6 +467,8 @@ async def blink(
|
|
| 450 |
If not authenticated:
|
| 451 |
- Creates AuditLog entries with user_id=None (anonymous)
|
| 452 |
"""
|
|
|
|
|
|
|
| 453 |
try:
|
| 454 |
# Validate minimum length
|
| 455 |
if len(userid) < USER_ID_LENGTH:
|
|
|
|
| 11 |
from core.database import get_db
|
| 12 |
from core.models import User, AuditLog, GeminiJob, Contact, ClientUser
|
| 13 |
from services.encryption_service import decrypt_multiple_blocks
|
| 14 |
+
from dependencies import get_geolocation
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
|
|
|
| 27 |
|
| 28 |
@router.get("/api/data")
|
| 29 |
async def get_data(
|
| 30 |
+
request: Request,
|
| 31 |
page: int = Query(1, ge=1, description="Page number"),
|
| 32 |
limit: int = Query(100, ge=1, le=500, description="Items per page"),
|
| 33 |
log_type: str = Query(None, description="Filter by log type: client, server"),
|
|
|
|
| 34 |
db: AsyncSession = Depends(get_db)
|
| 35 |
):
|
| 36 |
"""
|
| 37 |
Get paginated audit log data for the authenticated user.
|
| 38 |
Admins see all logs from all users.
|
| 39 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 40 |
"""
|
| 41 |
+
user = request.state.user
|
| 42 |
from services.db_service import QueryService
|
| 43 |
|
| 44 |
try:
|
|
|
|
| 95 |
|
| 96 |
@router.get("/api/users")
|
| 97 |
async def get_users(
|
| 98 |
+
request: Request,
|
| 99 |
page: int = Query(1, ge=1, description="Page number"),
|
| 100 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
|
|
|
| 101 |
db: AsyncSession = Depends(get_db)
|
| 102 |
):
|
| 103 |
"""
|
| 104 |
Get current user's profile data.
|
| 105 |
Admins see paginated list of all users.
|
| 106 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 107 |
"""
|
| 108 |
+
user = request.state.user
|
| 109 |
from services.db_service import QueryService
|
| 110 |
|
| 111 |
try:
|
|
|
|
| 152 |
|
| 153 |
@router.get("/api/client-users")
|
| 154 |
async def get_client_users(
|
| 155 |
+
request: Request,
|
| 156 |
page: int = Query(1, ge=1, description="Page number"),
|
| 157 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 158 |
user_id: str = Query(None, description="Filter by server user_id"),
|
|
|
|
| 159 |
db: AsyncSession = Depends(get_db)
|
| 160 |
):
|
| 161 |
"""
|
| 162 |
Get current user's client mappings.
|
| 163 |
Admins see all client mappings from all users.
|
| 164 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 165 |
"""
|
| 166 |
+
user = request.state.user
|
| 167 |
from services.db_service import QueryService
|
| 168 |
|
| 169 |
try:
|
|
|
|
| 203 |
|
| 204 |
@router.get("/api/audit-logs")
|
| 205 |
async def get_audit_logs(
|
| 206 |
+
request: Request,
|
| 207 |
page: int = Query(1, ge=1, description="Page number"),
|
| 208 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 209 |
log_type: str = Query(None, description="Filter by log type: client, server"),
|
| 210 |
action: str = Query(None, description="Filter by action"),
|
|
|
|
| 211 |
db: AsyncSession = Depends(get_db)
|
| 212 |
):
|
| 213 |
"""
|
| 214 |
Get current user's audit logs with optional filters.
|
| 215 |
Admins see all logs from all users.
|
| 216 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 217 |
"""
|
| 218 |
+
user = request.state.user
|
| 219 |
from services.db_service import QueryService
|
| 220 |
|
| 221 |
try:
|
|
|
|
| 271 |
|
| 272 |
@router.get("/api/gemini-jobs")
|
| 273 |
async def get_gemini_jobs(
|
| 274 |
+
request: Request,
|
| 275 |
page: int = Query(1, ge=1, description="Page number"),
|
| 276 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
|
|
|
| 277 |
db: AsyncSession = Depends(get_db)
|
| 278 |
):
|
| 279 |
"""
|
| 280 |
Get current user's Gemini jobs.
|
| 281 |
Admins see all jobs from all users.
|
| 282 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 283 |
"""
|
| 284 |
+
user = request.state.user
|
| 285 |
from services.db_service import QueryService
|
| 286 |
|
| 287 |
try:
|
|
|
|
| 323 |
|
| 324 |
@router.get("/api/payment-transactions")
|
| 325 |
async def get_payment_transactions(
|
| 326 |
+
request: Request,
|
| 327 |
page: int = Query(1, ge=1, description="Page number"),
|
| 328 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
|
|
|
| 329 |
db: AsyncSession = Depends(get_db)
|
| 330 |
):
|
| 331 |
"""
|
| 332 |
Get current user's payment transactions.
|
| 333 |
Admins see all transactions from all users.
|
| 334 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 335 |
"""
|
| 336 |
+
user = request.state.user
|
| 337 |
from core.models import PaymentTransaction
|
| 338 |
from services.db_service import QueryService
|
| 339 |
|
|
|
|
| 395 |
|
| 396 |
@router.get("/api/contacts")
|
| 397 |
async def get_contacts(
|
| 398 |
+
request: Request,
|
| 399 |
page: int = Query(1, ge=1, description="Page number"),
|
| 400 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
|
|
|
| 401 |
db: AsyncSession = Depends(get_db)
|
| 402 |
):
|
| 403 |
"""
|
| 404 |
Get current user's contact form submissions.
|
| 405 |
Admins see all contact submissions from all users.
|
| 406 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 407 |
"""
|
| 408 |
+
user = request.state.user
|
| 409 |
from services.db_service import QueryService
|
| 410 |
|
| 411 |
try:
|
|
|
|
| 450 |
async def blink(
|
| 451 |
request: Request,
|
| 452 |
userid: str = Query(..., description="User ID (20 chars) + encrypted data"),
|
| 453 |
+
db: AsyncSession = Depends(get_db)
|
|
|
|
| 454 |
):
|
| 455 |
"""
|
| 456 |
Process blink request with encrypted user data.
|
| 457 |
Logs to AuditLog with log_type='client'.
|
| 458 |
|
| 459 |
+
Auth is optional (handled by AuthMiddleware):
|
| 460 |
+
- If authenticated: user in request.state.user
|
| 461 |
+
- If not authenticated: request.state.user is None
|
| 462 |
+
|
| 463 |
If authenticated via JWT:
|
| 464 |
- Creates a new ClientUser entry linking client_user_id to server user_id
|
| 465 |
- Sets user_id in AuditLog entries
|
|
|
|
| 467 |
If not authenticated:
|
| 468 |
- Creates AuditLog entries with user_id=None (anonymous)
|
| 469 |
"""
|
| 470 |
+
# Optional auth - may be None
|
| 471 |
+
current_user = request.state.user
|
| 472 |
try:
|
| 473 |
# Validate minimum length
|
| 474 |
if len(userid) < USER_ID_LENGTH:
|
routers/contact.py
CHANGED
|
@@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|
| 14 |
|
| 15 |
from core.database import get_db
|
| 16 |
from core.models import User, Contact
|
| 17 |
-
from dependencies import get_current_user
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
|
@@ -45,14 +44,17 @@ class ContactResponse(BaseModel):
|
|
| 45 |
async def submit_contact(
|
| 46 |
request_body: ContactRequest,
|
| 47 |
request: Request,
|
| 48 |
-
user: User = Depends(get_current_user),
|
| 49 |
db: AsyncSession = Depends(get_db)
|
| 50 |
):
|
| 51 |
"""
|
| 52 |
Submit a contact form for customer support.
|
| 53 |
|
| 54 |
-
Requires authentication - user
|
|
|
|
| 55 |
"""
|
|
|
|
|
|
|
|
|
|
| 56 |
# Validate message
|
| 57 |
if not request_body.message or not request_body.message.strip():
|
| 58 |
raise HTTPException(
|
|
|
|
| 14 |
|
| 15 |
from core.database import get_db
|
| 16 |
from core.models import User, Contact
|
|
|
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
|
|
|
| 44 |
async def submit_contact(
|
| 45 |
request_body: ContactRequest,
|
| 46 |
request: Request,
|
|
|
|
| 47 |
db: AsyncSession = Depends(get_db)
|
| 48 |
):
|
| 49 |
"""
|
| 50 |
Submit a contact form for customer support.
|
| 51 |
|
| 52 |
+
Requires authentication - user is authenticated by AuthMiddleware.
|
| 53 |
+
User is available in request.state.user
|
| 54 |
"""
|
| 55 |
+
# Get authenticated user from middleware
|
| 56 |
+
user = request.state.user
|
| 57 |
+
|
| 58 |
# Validate message
|
| 59 |
if not request_body.message or not request_body.message.strip():
|
| 60 |
raise HTTPException(
|
routers/credits.py
CHANGED
|
@@ -3,7 +3,7 @@ Credits Router - API endpoints for credit management.
|
|
| 3 |
|
| 4 |
Provides endpoints for checking credit balance and viewing credit history.
|
| 5 |
"""
|
| 6 |
-
from fastapi import APIRouter, Depends, Query
|
| 7 |
from pydantic import BaseModel
|
| 8 |
from typing import List, Optional
|
| 9 |
from datetime import datetime
|
|
@@ -12,7 +12,6 @@ from sqlalchemy import select, desc
|
|
| 12 |
|
| 13 |
from core.database import get_db
|
| 14 |
from core.models import User, GeminiJob
|
| 15 |
-
from dependencies import get_current_user
|
| 16 |
|
| 17 |
router = APIRouter(prefix="/credits", tags=["credits"])
|
| 18 |
|
|
@@ -47,15 +46,17 @@ class CreditHistoryResponse(BaseModel):
|
|
| 47 |
limit: int
|
| 48 |
|
| 49 |
|
| 50 |
-
@router.get("", response_model=CreditBalanceResponse)
|
| 51 |
async def get_credits(
|
| 52 |
-
|
| 53 |
):
|
| 54 |
"""
|
| 55 |
Get current credit balance.
|
| 56 |
|
| 57 |
Returns the user's current credit balance and last usage time.
|
|
|
|
| 58 |
"""
|
|
|
|
| 59 |
return CreditBalanceResponse(
|
| 60 |
user_id=user.user_id,
|
| 61 |
credits=user.credits,
|
|
@@ -65,7 +66,7 @@ async def get_credits(
|
|
| 65 |
|
| 66 |
@router.get("/history", response_model=CreditHistoryResponse)
|
| 67 |
async def get_credit_history(
|
| 68 |
-
|
| 69 |
db: AsyncSession = Depends(get_db),
|
| 70 |
page: int = Query(1, ge=1, description="Page number"),
|
| 71 |
limit: int = Query(20, ge=1, le=100, description="Items per page")
|
|
@@ -77,7 +78,9 @@ async def get_credit_history(
|
|
| 77 |
showing which jobs used credits and which were refunded.
|
| 78 |
|
| 79 |
Only includes jobs where credits were reserved (credits_reserved > 0).
|
|
|
|
| 80 |
"""
|
|
|
|
| 81 |
offset = (page - 1) * limit
|
| 82 |
|
| 83 |
# Query jobs with credit transactions
|
|
|
|
| 3 |
|
| 4 |
Provides endpoints for checking credit balance and viewing credit history.
|
| 5 |
"""
|
| 6 |
+
from fastapi import APIRouter, Depends, Query, Request
|
| 7 |
from pydantic import BaseModel
|
| 8 |
from typing import List, Optional
|
| 9 |
from datetime import datetime
|
|
|
|
| 12 |
|
| 13 |
from core.database import get_db
|
| 14 |
from core.models import User, GeminiJob
|
|
|
|
| 15 |
|
| 16 |
router = APIRouter(prefix="/credits", tags=["credits"])
|
| 17 |
|
|
|
|
| 46 |
limit: int
|
| 47 |
|
| 48 |
|
| 49 |
+
@router.get("/balance", response_model=CreditBalanceResponse)
|
| 50 |
async def get_credits(
|
| 51 |
+
request: Request
|
| 52 |
):
|
| 53 |
"""
|
| 54 |
Get current credit balance.
|
| 55 |
|
| 56 |
Returns the user's current credit balance and last usage time.
|
| 57 |
+
Auth is handled by AuthMiddleware - user is in request.state.user
|
| 58 |
"""
|
| 59 |
+
user = request.state.user
|
| 60 |
return CreditBalanceResponse(
|
| 61 |
user_id=user.user_id,
|
| 62 |
credits=user.credits,
|
|
|
|
| 66 |
|
| 67 |
@router.get("/history", response_model=CreditHistoryResponse)
|
| 68 |
async def get_credit_history(
|
| 69 |
+
request: Request,
|
| 70 |
db: AsyncSession = Depends(get_db),
|
| 71 |
page: int = Query(1, ge=1, description="Page number"),
|
| 72 |
limit: int = Query(20, ge=1, le=100, description="Items per page")
|
|
|
|
| 78 |
showing which jobs used credits and which were refunded.
|
| 79 |
|
| 80 |
Only includes jobs where credits were reserved (credits_reserved > 0).
|
| 81 |
+
Auth is handled by AuthMiddleware - user is in request.state.user
|
| 82 |
"""
|
| 83 |
+
user = request.state.user
|
| 84 |
offset = (page - 1) * limit
|
| 85 |
|
| 86 |
# Query jobs with credit transactions
|
routers/gemini.py
CHANGED
|
@@ -5,7 +5,7 @@ Authentication via JWT (Authorization: Bearer <token>).
|
|
| 5 |
"""
|
| 6 |
import os
|
| 7 |
import uuid
|
| 8 |
-
from fastapi import APIRouter, Depends, HTTPException, status
|
| 9 |
from fastapi.responses import FileResponse
|
| 10 |
from pydantic import BaseModel, Field
|
| 11 |
from typing import Optional, Literal
|
|
@@ -15,7 +15,6 @@ from sqlalchemy import select, func
|
|
| 15 |
from core.database import get_db
|
| 16 |
from core.models import User, GeminiJob
|
| 17 |
from services.gemini_service import MODELS, DOWNLOADS_DIR
|
| 18 |
-
from dependencies import verify_credits, verify_video_credits, get_current_user
|
| 19 |
from datetime import datetime
|
| 20 |
|
| 21 |
router = APIRouter(prefix="/gemini", tags=["gemini"])
|
|
@@ -102,13 +101,16 @@ async def create_job(
|
|
| 102 |
|
| 103 |
@router.post("/generate-animation-prompt")
|
| 104 |
async def generate_animation_prompt(
|
|
|
|
| 105 |
request: GenerateAnimationPromptRequest,
|
| 106 |
-
user: User = Depends(verify_credits),
|
| 107 |
db: AsyncSession = Depends(get_db)
|
| 108 |
):
|
| 109 |
"""
|
| 110 |
Queue an animation prompt generation job.
|
|
|
|
| 111 |
"""
|
|
|
|
|
|
|
| 112 |
job = await create_job(
|
| 113 |
db=db,
|
| 114 |
user=user,
|
|
@@ -118,7 +120,7 @@ async def generate_animation_prompt(
|
|
| 118 |
"mime_type": request.mime_type,
|
| 119 |
"custom_prompt": request.custom_prompt
|
| 120 |
},
|
| 121 |
-
credits_reserved=
|
| 122 |
)
|
| 123 |
|
| 124 |
position = await get_queue_position(db, job.job_id)
|
|
@@ -134,13 +136,16 @@ async def generate_animation_prompt(
|
|
| 134 |
|
| 135 |
@router.post("/edit-image")
|
| 136 |
async def edit_image(
|
|
|
|
| 137 |
request: EditImageRequest,
|
| 138 |
-
user: User = Depends(verify_credits),
|
| 139 |
db: AsyncSession = Depends(get_db)
|
| 140 |
):
|
| 141 |
"""
|
| 142 |
Queue an image edit job.
|
|
|
|
| 143 |
"""
|
|
|
|
|
|
|
| 144 |
job = await create_job(
|
| 145 |
db=db,
|
| 146 |
user=user,
|
|
@@ -150,7 +155,7 @@ async def edit_image(
|
|
| 150 |
"mime_type": request.mime_type,
|
| 151 |
"prompt": request.prompt
|
| 152 |
},
|
| 153 |
-
credits_reserved=
|
| 154 |
)
|
| 155 |
|
| 156 |
position = await get_queue_position(db, job.job_id)
|
|
@@ -166,13 +171,16 @@ async def edit_image(
|
|
| 166 |
|
| 167 |
@router.post("/generate-video")
|
| 168 |
async def generate_video(
|
|
|
|
| 169 |
request: GenerateVideoRequest,
|
| 170 |
-
user: User = Depends(verify_video_credits),
|
| 171 |
db: AsyncSession = Depends(get_db)
|
| 172 |
):
|
| 173 |
"""
|
| 174 |
Queue a video generation job.
|
|
|
|
| 175 |
"""
|
|
|
|
|
|
|
| 176 |
job = await create_job(
|
| 177 |
db=db,
|
| 178 |
user=user,
|
|
@@ -185,7 +193,7 @@ async def generate_video(
|
|
| 185 |
"resolution": request.resolution,
|
| 186 |
"number_of_videos": request.number_of_videos
|
| 187 |
},
|
| 188 |
-
credits_reserved=
|
| 189 |
)
|
| 190 |
|
| 191 |
position = await get_queue_position(db, job.job_id)
|
|
@@ -201,13 +209,16 @@ async def generate_video(
|
|
| 201 |
|
| 202 |
@router.post("/generate-text")
|
| 203 |
async def generate_text(
|
|
|
|
| 204 |
request: GenerateTextRequest,
|
| 205 |
-
user: User = Depends(verify_credits),
|
| 206 |
db: AsyncSession = Depends(get_db)
|
| 207 |
):
|
| 208 |
"""
|
| 209 |
Queue a text generation job.
|
|
|
|
| 210 |
"""
|
|
|
|
|
|
|
| 211 |
job = await create_job(
|
| 212 |
db=db,
|
| 213 |
user=user,
|
|
@@ -216,7 +227,7 @@ async def generate_text(
|
|
| 216 |
"prompt": request.prompt,
|
| 217 |
"model": request.model
|
| 218 |
},
|
| 219 |
-
credits_reserved=
|
| 220 |
)
|
| 221 |
|
| 222 |
position = await get_queue_position(db, job.job_id)
|
|
@@ -232,13 +243,16 @@ async def generate_text(
|
|
| 232 |
|
| 233 |
@router.post("/analyze-image")
|
| 234 |
async def analyze_image(
|
|
|
|
| 235 |
request: AnalyzeImageRequest,
|
| 236 |
-
user: User = Depends(verify_credits),
|
| 237 |
db: AsyncSession = Depends(get_db)
|
| 238 |
):
|
| 239 |
"""
|
| 240 |
Queue an image analysis job.
|
|
|
|
| 241 |
"""
|
|
|
|
|
|
|
| 242 |
job = await create_job(
|
| 243 |
db=db,
|
| 244 |
user=user,
|
|
@@ -248,7 +262,7 @@ async def analyze_image(
|
|
| 248 |
"mime_type": request.mime_type,
|
| 249 |
"prompt": request.prompt
|
| 250 |
},
|
| 251 |
-
credits_reserved=
|
| 252 |
)
|
| 253 |
|
| 254 |
position = await get_queue_position(db, job.job_id)
|
|
@@ -264,7 +278,7 @@ async def analyze_image(
|
|
| 264 |
|
| 265 |
@router.get("/jobs")
|
| 266 |
async def get_jobs(
|
| 267 |
-
|
| 268 |
db: AsyncSession = Depends(get_db),
|
| 269 |
page: int = 1,
|
| 270 |
limit: int = 20
|
|
@@ -272,7 +286,9 @@ async def get_jobs(
|
|
| 272 |
"""
|
| 273 |
Get all jobs created by the current user.
|
| 274 |
Returns a paginated list of jobs with status, type, and prompt (for video jobs).
|
|
|
|
| 275 |
"""
|
|
|
|
| 276 |
offset = (page - 1) * limit
|
| 277 |
|
| 278 |
# Query jobs for the current user
|
|
@@ -326,14 +342,16 @@ async def get_jobs(
|
|
| 326 |
@router.get("/job/{job_id}")
|
| 327 |
async def get_job_status(
|
| 328 |
job_id: str,
|
| 329 |
-
|
| 330 |
db: AsyncSession = Depends(get_db)
|
| 331 |
):
|
| 332 |
"""
|
| 333 |
Get the status of a job.
|
| 334 |
Poll this endpoint until status is 'completed' or 'failed'.
|
| 335 |
For processing video jobs, this will check the Gemini API status and update the job.
|
|
|
|
| 336 |
"""
|
|
|
|
| 337 |
query = select(GeminiJob).where(
|
| 338 |
GeminiJob.job_id == job_id,
|
| 339 |
GeminiJob.user_id == user.id # Integer FK comparison
|
|
@@ -391,14 +409,16 @@ async def get_job_status(
|
|
| 391 |
@router.get("/download/{job_id}")
|
| 392 |
async def download_video(
|
| 393 |
job_id: str,
|
| 394 |
-
|
| 395 |
db: AsyncSession = Depends(get_db)
|
| 396 |
):
|
| 397 |
"""
|
| 398 |
Download a generated video.
|
| 399 |
Downloads from Gemini URL, streams to client, then deletes local file.
|
| 400 |
No permanent storage on server.
|
|
|
|
| 401 |
"""
|
|
|
|
| 402 |
from fastapi.responses import StreamingResponse
|
| 403 |
import httpx
|
| 404 |
|
|
@@ -470,14 +490,16 @@ async def download_video(
|
|
| 470 |
@router.post("/job/{job_id}/cancel")
|
| 471 |
async def cancel_job(
|
| 472 |
job_id: str,
|
| 473 |
-
|
| 474 |
db: AsyncSession = Depends(get_db)
|
| 475 |
):
|
| 476 |
"""
|
| 477 |
Cancel a queued job.
|
| 478 |
Only jobs with status 'queued' can be cancelled.
|
| 479 |
Processing/completed/failed jobs cannot be cancelled.
|
|
|
|
| 480 |
"""
|
|
|
|
| 481 |
query = select(GeminiJob).where(
|
| 482 |
GeminiJob.job_id == job_id,
|
| 483 |
GeminiJob.user_id == user.id # Integer FK comparison
|
|
@@ -512,7 +534,7 @@ async def cancel_job(
|
|
| 512 |
@router.delete("/job/{job_id}")
|
| 513 |
async def delete_job(
|
| 514 |
job_id: str,
|
| 515 |
-
|
| 516 |
db: AsyncSession = Depends(get_db)
|
| 517 |
):
|
| 518 |
"""
|
|
@@ -521,7 +543,9 @@ async def delete_job(
|
|
| 521 |
Refund policy:
|
| 522 |
- If queued: Refund 8 credits (10 cost - 2 penalty), soft delete job.
|
| 523 |
- If processing/completed/failed: Soft delete job (no refund).
|
|
|
|
| 524 |
"""
|
|
|
|
| 525 |
from services.db_service import QueryService
|
| 526 |
|
| 527 |
qs = QueryService(user, db)
|
|
|
|
| 5 |
"""
|
| 6 |
import os
|
| 7 |
import uuid
|
| 8 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
| 9 |
from fastapi.responses import FileResponse
|
| 10 |
from pydantic import BaseModel, Field
|
| 11 |
from typing import Optional, Literal
|
|
|
|
| 15 |
from core.database import get_db
|
| 16 |
from core.models import User, GeminiJob
|
| 17 |
from services.gemini_service import MODELS, DOWNLOADS_DIR
|
|
|
|
| 18 |
from datetime import datetime
|
| 19 |
|
| 20 |
router = APIRouter(prefix="/gemini", tags=["gemini"])
|
|
|
|
| 101 |
|
| 102 |
@router.post("/generate-animation-prompt")
|
| 103 |
async def generate_animation_prompt(
|
| 104 |
+
req: Request,
|
| 105 |
request: GenerateAnimationPromptRequest,
|
|
|
|
| 106 |
db: AsyncSession = Depends(get_db)
|
| 107 |
):
|
| 108 |
"""
|
| 109 |
Queue an animation prompt generation job.
|
| 110 |
+
Auth and credit validation handled by middleware.
|
| 111 |
"""
|
| 112 |
+
user = req.state.user
|
| 113 |
+
credits_reserved = req.state.credits_reserved
|
| 114 |
job = await create_job(
|
| 115 |
db=db,
|
| 116 |
user=user,
|
|
|
|
| 120 |
"mime_type": request.mime_type,
|
| 121 |
"custom_prompt": request.custom_prompt
|
| 122 |
},
|
| 123 |
+
credits_reserved=credits_reserved
|
| 124 |
)
|
| 125 |
|
| 126 |
position = await get_queue_position(db, job.job_id)
|
|
|
|
| 136 |
|
| 137 |
@router.post("/edit-image")
|
| 138 |
async def edit_image(
|
| 139 |
+
req: Request,
|
| 140 |
request: EditImageRequest,
|
|
|
|
| 141 |
db: AsyncSession = Depends(get_db)
|
| 142 |
):
|
| 143 |
"""
|
| 144 |
Queue an image edit job.
|
| 145 |
+
Auth and credit validation handled by middleware.
|
| 146 |
"""
|
| 147 |
+
user = req.state.user
|
| 148 |
+
credits_reserved = req.state.credits_reserved
|
| 149 |
job = await create_job(
|
| 150 |
db=db,
|
| 151 |
user=user,
|
|
|
|
| 155 |
"mime_type": request.mime_type,
|
| 156 |
"prompt": request.prompt
|
| 157 |
},
|
| 158 |
+
credits_reserved=credits_reserved
|
| 159 |
)
|
| 160 |
|
| 161 |
position = await get_queue_position(db, job.job_id)
|
|
|
|
| 171 |
|
| 172 |
@router.post("/generate-video")
|
| 173 |
async def generate_video(
|
| 174 |
+
req: Request,
|
| 175 |
request: GenerateVideoRequest,
|
|
|
|
| 176 |
db: AsyncSession = Depends(get_db)
|
| 177 |
):
|
| 178 |
"""
|
| 179 |
Queue a video generation job.
|
| 180 |
+
Auth and credit validation handled by middleware.
|
| 181 |
"""
|
| 182 |
+
user = req.state.user
|
| 183 |
+
credits_reserved = req.state.credits_reserved
|
| 184 |
job = await create_job(
|
| 185 |
db=db,
|
| 186 |
user=user,
|
|
|
|
| 193 |
"resolution": request.resolution,
|
| 194 |
"number_of_videos": request.number_of_videos
|
| 195 |
},
|
| 196 |
+
credits_reserved=credits_reserved # 10 credits for video
|
| 197 |
)
|
| 198 |
|
| 199 |
position = await get_queue_position(db, job.job_id)
|
|
|
|
| 209 |
|
| 210 |
@router.post("/generate-text")
|
| 211 |
async def generate_text(
|
| 212 |
+
req: Request,
|
| 213 |
request: GenerateTextRequest,
|
|
|
|
| 214 |
db: AsyncSession = Depends(get_db)
|
| 215 |
):
|
| 216 |
"""
|
| 217 |
Queue a text generation job.
|
| 218 |
+
Auth and credit validation handled by middleware.
|
| 219 |
"""
|
| 220 |
+
user = req.state.user
|
| 221 |
+
credits_reserved = req.state.credits_reserved
|
| 222 |
job = await create_job(
|
| 223 |
db=db,
|
| 224 |
user=user,
|
|
|
|
| 227 |
"prompt": request.prompt,
|
| 228 |
"model": request.model
|
| 229 |
},
|
| 230 |
+
credits_reserved=credits_reserved
|
| 231 |
)
|
| 232 |
|
| 233 |
position = await get_queue_position(db, job.job_id)
|
|
|
|
| 243 |
|
| 244 |
@router.post("/analyze-image")
|
| 245 |
async def analyze_image(
|
| 246 |
+
req: Request,
|
| 247 |
request: AnalyzeImageRequest,
|
|
|
|
| 248 |
db: AsyncSession = Depends(get_db)
|
| 249 |
):
|
| 250 |
"""
|
| 251 |
Queue an image analysis job.
|
| 252 |
+
Auth and credit validation handled by middleware.
|
| 253 |
"""
|
| 254 |
+
user = req.state.user
|
| 255 |
+
credits_reserved = req.state.credits_reserved
|
| 256 |
job = await create_job(
|
| 257 |
db=db,
|
| 258 |
user=user,
|
|
|
|
| 262 |
"mime_type": request.mime_type,
|
| 263 |
"prompt": request.prompt
|
| 264 |
},
|
| 265 |
+
credits_reserved=credits_reserved
|
| 266 |
)
|
| 267 |
|
| 268 |
position = await get_queue_position(db, job.job_id)
|
|
|
|
| 278 |
|
| 279 |
@router.get("/jobs")
|
| 280 |
async def get_jobs(
|
| 281 |
+
req: Request,
|
| 282 |
db: AsyncSession = Depends(get_db),
|
| 283 |
page: int = 1,
|
| 284 |
limit: int = 20
|
|
|
|
| 286 |
"""
|
| 287 |
Get all jobs created by the current user.
|
| 288 |
Returns a paginated list of jobs with status, type, and prompt (for video jobs).
|
| 289 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 290 |
"""
|
| 291 |
+
user = req.state.user
|
| 292 |
offset = (page - 1) * limit
|
| 293 |
|
| 294 |
# Query jobs for the current user
|
|
|
|
| 342 |
@router.get("/job/{job_id}")
|
| 343 |
async def get_job_status(
|
| 344 |
job_id: str,
|
| 345 |
+
req: Request,
|
| 346 |
db: AsyncSession = Depends(get_db)
|
| 347 |
):
|
| 348 |
"""
|
| 349 |
Get the status of a job.
|
| 350 |
Poll this endpoint until status is 'completed' or 'failed'.
|
| 351 |
For processing video jobs, this will check the Gemini API status and update the job.
|
| 352 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 353 |
"""
|
| 354 |
+
user = req.state.user
|
| 355 |
query = select(GeminiJob).where(
|
| 356 |
GeminiJob.job_id == job_id,
|
| 357 |
GeminiJob.user_id == user.id # Integer FK comparison
|
|
|
|
| 409 |
@router.get("/download/{job_id}")
|
| 410 |
async def download_video(
|
| 411 |
job_id: str,
|
| 412 |
+
req: Request,
|
| 413 |
db: AsyncSession = Depends(get_db)
|
| 414 |
):
|
| 415 |
"""
|
| 416 |
Download a generated video.
|
| 417 |
Downloads from Gemini URL, streams to client, then deletes local file.
|
| 418 |
No permanent storage on server.
|
| 419 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 420 |
"""
|
| 421 |
+
user = req.state.user
|
| 422 |
from fastapi.responses import StreamingResponse
|
| 423 |
import httpx
|
| 424 |
|
|
|
|
| 490 |
@router.post("/job/{job_id}/cancel")
|
| 491 |
async def cancel_job(
|
| 492 |
job_id: str,
|
| 493 |
+
req: Request,
|
| 494 |
db: AsyncSession = Depends(get_db)
|
| 495 |
):
|
| 496 |
"""
|
| 497 |
Cancel a queued job.
|
| 498 |
Only jobs with status 'queued' can be cancelled.
|
| 499 |
Processing/completed/failed jobs cannot be cancelled.
|
| 500 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 501 |
"""
|
| 502 |
+
user = req.state.user
|
| 503 |
query = select(GeminiJob).where(
|
| 504 |
GeminiJob.job_id == job_id,
|
| 505 |
GeminiJob.user_id == user.id # Integer FK comparison
|
|
|
|
| 534 |
@router.delete("/job/{job_id}")
|
| 535 |
async def delete_job(
|
| 536 |
job_id: str,
|
| 537 |
+
req: Request,
|
| 538 |
db: AsyncSession = Depends(get_db)
|
| 539 |
):
|
| 540 |
"""
|
|
|
|
| 543 |
Refund policy:
|
| 544 |
- If queued: Refund 8 credits (10 cost - 2 penalty), soft delete job.
|
| 545 |
- If processing/completed/failed: Soft delete job (no refund).
|
| 546 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 547 |
"""
|
| 548 |
+
user = req.state.user
|
| 549 |
from services.db_service import QueryService
|
| 550 |
|
| 551 |
qs = QueryService(user, db)
|
routers/payments.py
CHANGED
|
@@ -21,7 +21,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|
| 21 |
|
| 22 |
from core.database import get_db
|
| 23 |
from core.models import User, PaymentTransaction
|
| 24 |
-
from dependencies import get_current_user
|
| 25 |
from services.drive_service import DriveService
|
| 26 |
from services.razorpay_service import (
|
| 27 |
RazorpayService,
|
|
@@ -223,8 +222,8 @@ async def get_packages():
|
|
| 223 |
|
| 224 |
@router.post("/create-order", response_model=CreateOrderResponse)
|
| 225 |
async def create_order(
|
|
|
|
| 226 |
request: CreateOrderRequest,
|
| 227 |
-
user: User = Depends(get_current_user),
|
| 228 |
db: AsyncSession = Depends(get_db)
|
| 229 |
):
|
| 230 |
"""
|
|
@@ -232,7 +231,9 @@ async def create_order(
|
|
| 232 |
|
| 233 |
The client should use the returned order_id to open
|
| 234 |
Razorpay checkout. After payment, call /verify endpoint.
|
|
|
|
| 235 |
"""
|
|
|
|
| 236 |
# Check if Razorpay is configured
|
| 237 |
if not is_razorpay_configured():
|
| 238 |
raise HTTPException(
|
|
@@ -315,9 +316,9 @@ async def create_order(
|
|
| 315 |
|
| 316 |
@router.post("/verify", response_model=VerifyPaymentResponse)
|
| 317 |
async def verify_payment(
|
|
|
|
| 318 |
request: VerifyPaymentRequest,
|
| 319 |
background_tasks: BackgroundTasks,
|
| 320 |
-
user: User = Depends(get_current_user),
|
| 321 |
db: AsyncSession = Depends(get_db)
|
| 322 |
):
|
| 323 |
"""
|
|
@@ -325,7 +326,9 @@ async def verify_payment(
|
|
| 325 |
|
| 326 |
Called after successful Razorpay checkout.
|
| 327 |
Verifies the payment signature and credits the user.
|
|
|
|
| 328 |
"""
|
|
|
|
| 329 |
try:
|
| 330 |
razorpay_service = get_razorpay_service()
|
| 331 |
|
|
@@ -557,16 +560,18 @@ async def razorpay_webhook(
|
|
| 557 |
|
| 558 |
@router.get("/history", response_model=PaymentHistoryResponse)
|
| 559 |
async def get_payment_history(
|
|
|
|
| 560 |
page: int = Query(1, ge=1, description="Page number"),
|
| 561 |
limit: int = Query(20, ge=1, le=100, description="Items per page"),
|
| 562 |
-
user: User = Depends(get_current_user),
|
| 563 |
db: AsyncSession = Depends(get_db)
|
| 564 |
):
|
| 565 |
"""
|
| 566 |
Get user's payment history with pagination.
|
| 567 |
|
| 568 |
Returns payment transactions ordered by newest first.
|
|
|
|
| 569 |
"""
|
|
|
|
| 570 |
# Get total count
|
| 571 |
count_result = await db.execute(
|
| 572 |
select(func.count(PaymentTransaction.id))
|
|
|
|
| 21 |
|
| 22 |
from core.database import get_db
|
| 23 |
from core.models import User, PaymentTransaction
|
|
|
|
| 24 |
from services.drive_service import DriveService
|
| 25 |
from services.razorpay_service import (
|
| 26 |
RazorpayService,
|
|
|
|
| 222 |
|
| 223 |
@router.post("/create-order", response_model=CreateOrderResponse)
|
| 224 |
async def create_order(
|
| 225 |
+
req: Request,
|
| 226 |
request: CreateOrderRequest,
|
|
|
|
| 227 |
db: AsyncSession = Depends(get_db)
|
| 228 |
):
|
| 229 |
"""
|
|
|
|
| 231 |
|
| 232 |
The client should use the returned order_id to open
|
| 233 |
Razorpay checkout. After payment, call /verify endpoint.
|
| 234 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 235 |
"""
|
| 236 |
+
user = req.state.user
|
| 237 |
# Check if Razorpay is configured
|
| 238 |
if not is_razorpay_configured():
|
| 239 |
raise HTTPException(
|
|
|
|
| 316 |
|
| 317 |
@router.post("/verify", response_model=VerifyPaymentResponse)
|
| 318 |
async def verify_payment(
|
| 319 |
+
req: Request,
|
| 320 |
request: VerifyPaymentRequest,
|
| 321 |
background_tasks: BackgroundTasks,
|
|
|
|
| 322 |
db: AsyncSession = Depends(get_db)
|
| 323 |
):
|
| 324 |
"""
|
|
|
|
| 326 |
|
| 327 |
Called after successful Razorpay checkout.
|
| 328 |
Verifies the payment signature and credits the user.
|
| 329 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 330 |
"""
|
| 331 |
+
user = req.state.user
|
| 332 |
try:
|
| 333 |
razorpay_service = get_razorpay_service()
|
| 334 |
|
|
|
|
| 560 |
|
| 561 |
@router.get("/history", response_model=PaymentHistoryResponse)
|
| 562 |
async def get_payment_history(
|
| 563 |
+
req: Request,
|
| 564 |
page: int = Query(1, ge=1, description="Page number"),
|
| 565 |
limit: int = Query(20, ge=1, le=100, description="Items per page"),
|
|
|
|
| 566 |
db: AsyncSession = Depends(get_db)
|
| 567 |
):
|
| 568 |
"""
|
| 569 |
Get user's payment history with pagination.
|
| 570 |
|
| 571 |
Returns payment transactions ordered by newest first.
|
| 572 |
+
Auth handled by AuthMiddleware - user in request.state.user
|
| 573 |
"""
|
| 574 |
+
user = req.state.user
|
| 575 |
# Get total count
|
| 576 |
count_result = await db.execute(
|
| 577 |
select(func.count(PaymentTransaction.id))
|
services/auth_service/__init__.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Auth Service - Authentication layer for API Gateway
|
| 3 |
+
|
| 4 |
+
Provides plug-and-play authentication with:
|
| 5 |
+
- Google OAuth integration
|
| 6 |
+
- JWT token management
|
| 7 |
+
- Request middleware for auth validation
|
| 8 |
+
- URL-based route configuration
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
# In app.py startup
|
| 12 |
+
from services.auth_service import register_auth_service
|
| 13 |
+
|
| 14 |
+
register_auth_service(
|
| 15 |
+
required_urls=["/api/*", "/admin/*"],
|
| 16 |
+
public_urls=["/", "/health", "/auth/*"],
|
| 17 |
+
jwt_secret=os.getenv("JWT_SECRET"),
|
| 18 |
+
google_client_id=os.getenv("GOOGLE_CLIENT_ID")
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# In routers
|
| 22 |
+
from fastapi import Request
|
| 23 |
+
|
| 24 |
+
@router.get("/protected")
|
| 25 |
+
async def protected_route(request: Request):
|
| 26 |
+
user = request.state.user # Populated by AuthMiddleware
|
| 27 |
+
return {"user_id": user.id}
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
from services.auth_service.config import AuthServiceConfig
|
| 31 |
+
from services.auth_service.middleware import AuthMiddleware
|
| 32 |
+
from services.auth_service.google_provider import (
|
| 33 |
+
GoogleAuthService,
|
| 34 |
+
GoogleUserInfo,
|
| 35 |
+
verify_google_token,
|
| 36 |
+
GoogleAuthError,
|
| 37 |
+
InvalidTokenError as GoogleInvalidTokenError,
|
| 38 |
+
)
|
| 39 |
+
from services.auth_service.jwt_provider import (
|
| 40 |
+
JWTService,
|
| 41 |
+
TokenPayload,
|
| 42 |
+
create_access_token,
|
| 43 |
+
verify_access_token,
|
| 44 |
+
JWTError,
|
| 45 |
+
TokenExpiredError,
|
| 46 |
+
InvalidTokenError,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def register_auth_service(
|
| 51 |
+
required_urls: list = None,
|
| 52 |
+
optional_urls: list = None,
|
| 53 |
+
public_urls: list = None,
|
| 54 |
+
jwt_secret: str = None,
|
| 55 |
+
jwt_algorithm: str = "HS256",
|
| 56 |
+
jwt_expiry_hours: int = 24,
|
| 57 |
+
google_client_id: str = None,
|
| 58 |
+
admin_emails: list = None,
|
| 59 |
+
) -> None:
|
| 60 |
+
"""
|
| 61 |
+
Register the auth service with application configuration.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
required_urls: URLs that REQUIRE authentication
|
| 65 |
+
optional_urls: URLs where authentication is optional
|
| 66 |
+
public_urls: URLs that don't need authentication
|
| 67 |
+
jwt_secret: Secret key for JWT signing
|
| 68 |
+
jwt_algorithm: JWT algorithm (default: HS256)
|
| 69 |
+
jwt_expiry_hours: Token expiry in hours (default: 24)
|
| 70 |
+
google_client_id: Google OAuth Client ID
|
| 71 |
+
admin_emails: List of admin email addresses
|
| 72 |
+
"""
|
| 73 |
+
AuthServiceConfig.register(
|
| 74 |
+
required_urls=required_urls or [],
|
| 75 |
+
optional_urls=optional_urls or [],
|
| 76 |
+
public_urls=public_urls or [],
|
| 77 |
+
jwt_secret=jwt_secret,
|
| 78 |
+
jwt_algorithm=jwt_algorithm,
|
| 79 |
+
jwt_expiry_hours=jwt_expiry_hours,
|
| 80 |
+
google_client_id=google_client_id,
|
| 81 |
+
admin_emails=admin_emails or [],
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
__all__ = [
|
| 86 |
+
# Registration
|
| 87 |
+
'register_auth_service',
|
| 88 |
+
'AuthServiceConfig',
|
| 89 |
+
'AuthMiddleware',
|
| 90 |
+
|
| 91 |
+
# Google OAuth
|
| 92 |
+
'GoogleAuthService',
|
| 93 |
+
'GoogleUserInfo',
|
| 94 |
+
'verify_google_token',
|
| 95 |
+
'GoogleAuthError',
|
| 96 |
+
'GoogleInvalidTokenError',
|
| 97 |
+
|
| 98 |
+
# JWT
|
| 99 |
+
'JWTService',
|
| 100 |
+
'TokenPayload',
|
| 101 |
+
'create_access_token',
|
| 102 |
+
'verify_access_token',
|
| 103 |
+
'JWTError',
|
| 104 |
+
'TokenExpiredError',
|
| 105 |
+
'InvalidTokenError',
|
| 106 |
+
]
|
services/auth_service/config.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Auth Service Configuration
|
| 3 |
+
|
| 4 |
+
Manages authentication configuration and route matching for the auth service.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from typing import List
|
| 9 |
+
from services.base_service import BaseService, ServiceConfig
|
| 10 |
+
from services.base_service.route_matcher import RouteConfig
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class AuthServiceConfig(BaseService):
|
| 16 |
+
"""
|
| 17 |
+
Configuration for the auth service.
|
| 18 |
+
|
| 19 |
+
Controls which routes require authentication, which are optional,
|
| 20 |
+
and which are public (no auth needed).
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
SERVICE_NAME = "auth_service"
|
| 24 |
+
|
| 25 |
+
# Route configuration
|
| 26 |
+
_route_config: RouteConfig = None
|
| 27 |
+
|
| 28 |
+
# JWT configuration
|
| 29 |
+
_jwt_secret: str = None
|
| 30 |
+
_jwt_algorithm: str = "HS256"
|
| 31 |
+
_jwt_expiry_hours: int = 24
|
| 32 |
+
|
| 33 |
+
# Google OAuth configuration
|
| 34 |
+
_google_client_id: str = None
|
| 35 |
+
|
| 36 |
+
# Admin configuration
|
| 37 |
+
_admin_emails: List[str] = []
|
| 38 |
+
|
| 39 |
+
@classmethod
|
| 40 |
+
def register(
|
| 41 |
+
cls,
|
| 42 |
+
required_urls: List[str] = None,
|
| 43 |
+
optional_urls: List[str] = None,
|
| 44 |
+
public_urls: List[str] = None,
|
| 45 |
+
jwt_secret: str = None,
|
| 46 |
+
jwt_algorithm: str = "HS256",
|
| 47 |
+
jwt_expiry_hours: int = 24,
|
| 48 |
+
google_client_id: str = None,
|
| 49 |
+
admin_emails: List[str] = None,
|
| 50 |
+
) -> None:
|
| 51 |
+
"""
|
| 52 |
+
Register auth service configuration.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
required_urls: URLs that REQUIRE authentication
|
| 56 |
+
optional_urls: URLs where authentication is optional
|
| 57 |
+
public_urls: URLs that don't need authentication
|
| 58 |
+
jwt_secret: Secret key for JWT signing
|
| 59 |
+
jwt_algorithm: JWT algorithm (default: HS256)
|
| 60 |
+
jwt_expiry_hours: Token expiry in hours (default: 24)
|
| 61 |
+
google_client_id: Google OAuth Client ID
|
| 62 |
+
admin_emails: List of admin email addresses
|
| 63 |
+
|
| 64 |
+
Raises:
|
| 65 |
+
RuntimeError: If service is already registered
|
| 66 |
+
ValueError: If jwt_secret is not provided
|
| 67 |
+
"""
|
| 68 |
+
if cls._registered:
|
| 69 |
+
raise RuntimeError(f"{cls.SERVICE_NAME} is already registered")
|
| 70 |
+
|
| 71 |
+
# Validate JWT secret
|
| 72 |
+
if not jwt_secret:
|
| 73 |
+
raise ValueError("jwt_secret is required for auth service")
|
| 74 |
+
|
| 75 |
+
# Store route configuration
|
| 76 |
+
cls._route_config = RouteConfig(
|
| 77 |
+
required=required_urls or [],
|
| 78 |
+
optional=optional_urls or [],
|
| 79 |
+
public=public_urls or [],
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Store JWT configuration
|
| 83 |
+
cls._jwt_secret = jwt_secret
|
| 84 |
+
cls._jwt_algorithm = jwt_algorithm
|
| 85 |
+
cls._jwt_expiry_hours = jwt_expiry_hours
|
| 86 |
+
|
| 87 |
+
# Store Google OAuth configuration
|
| 88 |
+
cls._google_client_id = google_client_id
|
| 89 |
+
|
| 90 |
+
# Store admin configuration
|
| 91 |
+
cls._admin_emails = admin_emails or []
|
| 92 |
+
|
| 93 |
+
cls._registered = True
|
| 94 |
+
|
| 95 |
+
logger.info(f"✅ {cls.SERVICE_NAME} registered successfully")
|
| 96 |
+
logger.info(f" JWT algorithm: {cls._jwt_algorithm}")
|
| 97 |
+
logger.info(f" JWT expiry: {cls._jwt_expiry_hours} hours")
|
| 98 |
+
logger.info(f" Required URLs: {len(required_urls or [])}")
|
| 99 |
+
logger.info(f" Optional URLs: {len(optional_urls or [])}")
|
| 100 |
+
logger.info(f" Public URLs: {len(public_urls or [])}")
|
| 101 |
+
logger.info(f" Admin emails: {len(cls._admin_emails)}")
|
| 102 |
+
|
| 103 |
+
@classmethod
|
| 104 |
+
def get_middleware(cls):
|
| 105 |
+
"""Return AuthMiddleware instance."""
|
| 106 |
+
from services.auth_service.middleware import AuthMiddleware
|
| 107 |
+
return AuthMiddleware
|
| 108 |
+
|
| 109 |
+
@classmethod
|
| 110 |
+
def requires_auth(cls, path: str) -> bool:
|
| 111 |
+
"""Check if a URL path requires authentication."""
|
| 112 |
+
cls.assert_registered()
|
| 113 |
+
return cls._route_config.is_required(path)
|
| 114 |
+
|
| 115 |
+
@classmethod
|
| 116 |
+
def allows_optional_auth(cls, path: str) -> bool:
|
| 117 |
+
"""Check if a URL path allows optional authentication."""
|
| 118 |
+
cls.assert_registered()
|
| 119 |
+
return cls._route_config.is_optional(path)
|
| 120 |
+
|
| 121 |
+
@classmethod
|
| 122 |
+
def is_public(cls, path: str) -> bool:
|
| 123 |
+
"""Check if a URL path is public (no auth needed)."""
|
| 124 |
+
cls.assert_registered()
|
| 125 |
+
return cls._route_config.is_public(path)
|
| 126 |
+
|
| 127 |
+
@classmethod
|
| 128 |
+
def get_jwt_secret(cls) -> str:
|
| 129 |
+
"""Get JWT secret key."""
|
| 130 |
+
cls.assert_registered()
|
| 131 |
+
return cls._jwt_secret
|
| 132 |
+
|
| 133 |
+
@classmethod
|
| 134 |
+
def get_jwt_algorithm(cls) -> str:
|
| 135 |
+
"""Get JWT algorithm."""
|
| 136 |
+
cls.assert_registered()
|
| 137 |
+
return cls._jwt_algorithm
|
| 138 |
+
|
| 139 |
+
@classmethod
|
| 140 |
+
def get_jwt_expiry_hours(cls) -> int:
|
| 141 |
+
"""Get JWT expiry hours."""
|
| 142 |
+
cls.assert_registered()
|
| 143 |
+
return cls._jwt_expiry_hours
|
| 144 |
+
|
| 145 |
+
@classmethod
|
| 146 |
+
def get_google_client_id(cls) -> str:
|
| 147 |
+
"""Get Google OAuth Client ID."""
|
| 148 |
+
cls.assert_registered()
|
| 149 |
+
return cls._google_client_id
|
| 150 |
+
|
| 151 |
+
@classmethod
|
| 152 |
+
def is_admin(cls, email: str) -> bool:
|
| 153 |
+
"""Check if an email is an admin."""
|
| 154 |
+
cls.assert_registered()
|
| 155 |
+
return email in cls._admin_emails
|
| 156 |
+
|
| 157 |
+
@classmethod
|
| 158 |
+
def get_admin_emails(cls) -> List[str]:
|
| 159 |
+
"""Get list of admin emails."""
|
| 160 |
+
cls.assert_registered()
|
| 161 |
+
return cls._admin_emails.copy()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
__all__ = ['AuthServiceConfig']
|
services/auth_service/google_provider.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Modular Google OAuth Service
|
| 3 |
+
|
| 4 |
+
A self-contained, plug-and-play service for verifying Google ID tokens.
|
| 5 |
+
Can be used in any Python application with minimal configuration.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
from services.google_auth_service import GoogleAuthService, GoogleUserInfo
|
| 9 |
+
|
| 10 |
+
# Initialize with client ID
|
| 11 |
+
auth_service = GoogleAuthService(client_id="your-google-client-id")
|
| 12 |
+
|
| 13 |
+
# Or use environment variable GOOGLE_CLIENT_ID
|
| 14 |
+
auth_service = GoogleAuthService()
|
| 15 |
+
|
| 16 |
+
# Verify a Google ID token
|
| 17 |
+
user_info = auth_service.verify_token(id_token)
|
| 18 |
+
print(user_info.email, user_info.google_id, user_info.name)
|
| 19 |
+
|
| 20 |
+
Environment Variables:
|
| 21 |
+
GOOGLE_CLIENT_ID: Your Google OAuth 2.0 Client ID
|
| 22 |
+
|
| 23 |
+
Dependencies:
|
| 24 |
+
google-auth>=2.0.0
|
| 25 |
+
google-auth-oauthlib>=1.0.0
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import os
|
| 29 |
+
import logging
|
| 30 |
+
from dataclasses import dataclass
|
| 31 |
+
from typing import Optional
|
| 32 |
+
from google.oauth2 import id_token as google_id_token
|
| 33 |
+
from google.auth.transport import requests as google_requests
|
| 34 |
+
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class GoogleUserInfo:
|
| 40 |
+
"""
|
| 41 |
+
User information extracted from a verified Google ID token.
|
| 42 |
+
|
| 43 |
+
Attributes:
|
| 44 |
+
google_id: Unique Google user identifier (sub claim)
|
| 45 |
+
email: User's email address
|
| 46 |
+
email_verified: Whether Google has verified the email
|
| 47 |
+
name: User's display name (may be None)
|
| 48 |
+
picture: URL to user's profile picture (may be None)
|
| 49 |
+
given_name: User's first name (may be None)
|
| 50 |
+
family_name: User's last name (may be None)
|
| 51 |
+
locale: User's locale preference (may be None)
|
| 52 |
+
"""
|
| 53 |
+
google_id: str
|
| 54 |
+
email: str
|
| 55 |
+
email_verified: bool = True
|
| 56 |
+
name: Optional[str] = None
|
| 57 |
+
picture: Optional[str] = None
|
| 58 |
+
given_name: Optional[str] = None
|
| 59 |
+
family_name: Optional[str] = None
|
| 60 |
+
locale: Optional[str] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class GoogleAuthError(Exception):
|
| 64 |
+
"""Base exception for Google Auth errors."""
|
| 65 |
+
pass
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class InvalidTokenError(GoogleAuthError):
|
| 69 |
+
"""Raised when the token is invalid or expired."""
|
| 70 |
+
pass
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class ConfigurationError(GoogleAuthError):
|
| 74 |
+
"""Raised when the service is not properly configured."""
|
| 75 |
+
pass
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class GoogleAuthService:
|
| 79 |
+
"""
|
| 80 |
+
Service for verifying Google OAuth ID tokens.
|
| 81 |
+
|
| 82 |
+
This service validates ID tokens issued by Google Sign-In and extracts
|
| 83 |
+
user information. It's designed to be modular and reusable across
|
| 84 |
+
different applications.
|
| 85 |
+
|
| 86 |
+
Example:
|
| 87 |
+
service = GoogleAuthService()
|
| 88 |
+
try:
|
| 89 |
+
user_info = service.verify_token(token_from_frontend)
|
| 90 |
+
print(f"Welcome {user_info.name}!")
|
| 91 |
+
except InvalidTokenError:
|
| 92 |
+
print("Invalid or expired token")
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
def __init__(
|
| 96 |
+
self,
|
| 97 |
+
client_id: Optional[str] = None,
|
| 98 |
+
clock_skew_seconds: int = 0
|
| 99 |
+
):
|
| 100 |
+
"""
|
| 101 |
+
Initialize the Google Auth Service.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
client_id: Google OAuth 2.0 Client ID. If not provided,
|
| 105 |
+
falls back to GOOGLE_CLIENT_ID environment variable.
|
| 106 |
+
clock_skew_seconds: Allowed clock skew in seconds for token
|
| 107 |
+
validation (default: 0).
|
| 108 |
+
|
| 109 |
+
Raises:
|
| 110 |
+
ConfigurationError: If no client_id is provided or found.
|
| 111 |
+
"""
|
| 112 |
+
self.client_id = client_id or os.getenv("AUTH_SIGN_IN_GOOGLE_CLIENT_ID")
|
| 113 |
+
self.clock_skew_seconds = clock_skew_seconds
|
| 114 |
+
|
| 115 |
+
if not self.client_id:
|
| 116 |
+
raise ConfigurationError(
|
| 117 |
+
"Google Client ID is required. Either pass client_id parameter "
|
| 118 |
+
"or set GOOGLE_CLIENT_ID environment variable."
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
logger.info(f"GoogleAuthService initialized with client_id: {self.client_id[:20]}...")
|
| 122 |
+
|
| 123 |
+
def verify_token(self, id_token: str) -> GoogleUserInfo:
|
| 124 |
+
"""
|
| 125 |
+
Verify a Google ID token and extract user information.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
id_token: The ID token received from the frontend after
|
| 129 |
+
Google Sign-In.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
GoogleUserInfo: Dataclass containing user's Google profile info.
|
| 133 |
+
|
| 134 |
+
Raises:
|
| 135 |
+
InvalidTokenError: If the token is invalid, expired, or
|
| 136 |
+
doesn't match the expected client ID.
|
| 137 |
+
"""
|
| 138 |
+
if not id_token:
|
| 139 |
+
raise InvalidTokenError("Token cannot be empty")
|
| 140 |
+
|
| 141 |
+
try:
|
| 142 |
+
# Verify the token with Google
|
| 143 |
+
idinfo = google_id_token.verify_oauth2_token(
|
| 144 |
+
id_token,
|
| 145 |
+
google_requests.Request(),
|
| 146 |
+
self.client_id,
|
| 147 |
+
clock_skew_in_seconds=self.clock_skew_seconds
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# Validate issuer
|
| 151 |
+
if idinfo.get("iss") not in ["accounts.google.com", "https://accounts.google.com"]:
|
| 152 |
+
raise InvalidTokenError("Invalid token issuer")
|
| 153 |
+
|
| 154 |
+
# Validate audience
|
| 155 |
+
if idinfo.get("aud") != self.client_id:
|
| 156 |
+
raise InvalidTokenError("Token was not issued for this application")
|
| 157 |
+
|
| 158 |
+
# Extract user info
|
| 159 |
+
return GoogleUserInfo(
|
| 160 |
+
google_id=idinfo["sub"],
|
| 161 |
+
email=idinfo["email"],
|
| 162 |
+
email_verified=idinfo.get("email_verified", False),
|
| 163 |
+
name=idinfo.get("name"),
|
| 164 |
+
picture=idinfo.get("picture"),
|
| 165 |
+
given_name=idinfo.get("given_name"),
|
| 166 |
+
family_name=idinfo.get("family_name"),
|
| 167 |
+
locale=idinfo.get("locale")
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
except ValueError as e:
|
| 171 |
+
logger.warning(f"Token verification failed: {e}")
|
| 172 |
+
raise InvalidTokenError(f"Token verification failed: {str(e)}")
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.error(f"Unexpected error during token verification: {e}")
|
| 175 |
+
raise InvalidTokenError(f"Token verification error: {str(e)}")
|
| 176 |
+
|
| 177 |
+
def verify_token_safe(self, id_token: str) -> Optional[GoogleUserInfo]:
|
| 178 |
+
"""
|
| 179 |
+
Verify a Google ID token without raising exceptions.
|
| 180 |
+
|
| 181 |
+
Useful for cases where you want to check validity without
|
| 182 |
+
exception handling.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
id_token: The ID token to verify.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
GoogleUserInfo if valid, None if invalid.
|
| 189 |
+
"""
|
| 190 |
+
try:
|
| 191 |
+
return self.verify_token(id_token)
|
| 192 |
+
except GoogleAuthError:
|
| 193 |
+
return None
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# Singleton instance for convenience (initialized on first use)
|
| 197 |
+
_default_service: Optional[GoogleAuthService] = None
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_google_auth_service() -> GoogleAuthService:
|
| 201 |
+
"""
|
| 202 |
+
Get the default GoogleAuthService instance.
|
| 203 |
+
|
| 204 |
+
Creates a singleton instance using environment variables.
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
GoogleAuthService: The default service instance.
|
| 208 |
+
|
| 209 |
+
Raises:
|
| 210 |
+
ConfigurationError: If GOOGLE_CLIENT_ID is not set.
|
| 211 |
+
"""
|
| 212 |
+
global _default_service
|
| 213 |
+
if _default_service is None:
|
| 214 |
+
_default_service = GoogleAuthService()
|
| 215 |
+
return _default_service
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def verify_google_token(id_token: str) -> GoogleUserInfo:
|
| 219 |
+
"""
|
| 220 |
+
Convenience function to verify a token using the default service.
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
id_token: The Google ID token to verify.
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
GoogleUserInfo: Verified user information.
|
| 227 |
+
|
| 228 |
+
Raises:
|
| 229 |
+
InvalidTokenError: If verification fails.
|
| 230 |
+
ConfigurationError: If service is not configured.
|
| 231 |
+
"""
|
| 232 |
+
return get_google_auth_service().verify_token(id_token)
|
services/auth_service/jwt_provider.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Modular JWT Service
|
| 3 |
+
|
| 4 |
+
A self-contained, plug-and-play service for creating and verifying JWT tokens.
|
| 5 |
+
Can be used in any Python application with minimal configuration.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
from services.jwt_service import JWTService, TokenPayload
|
| 9 |
+
|
| 10 |
+
# Initialize with secret key
|
| 11 |
+
jwt_service = JWTService(secret_key="your-secret-key")
|
| 12 |
+
|
| 13 |
+
# Or use environment variable JWT_SECRET
|
| 14 |
+
jwt_service = JWTService()
|
| 15 |
+
|
| 16 |
+
# Create a token
|
| 17 |
+
token = jwt_service.create_token(user_id="user123", email="user@example.com")
|
| 18 |
+
|
| 19 |
+
# Verify a token
|
| 20 |
+
payload = jwt_service.verify_token(token)
|
| 21 |
+
print(payload.user_id, payload.email)
|
| 22 |
+
|
| 23 |
+
Environment Variables:
|
| 24 |
+
JWT_SECRET: Your secret key for signing tokens (required)
|
| 25 |
+
JWT_EXPIRY_HOURS: Token expiry in hours (default: 168 = 7 days)
|
| 26 |
+
JWT_ALGORITHM: Algorithm to use (default: HS256)
|
| 27 |
+
|
| 28 |
+
Dependencies:
|
| 29 |
+
PyJWT>=2.8.0
|
| 30 |
+
|
| 31 |
+
Generate a secure secret:
|
| 32 |
+
python -c "import secrets; print(secrets.token_urlsafe(64))"
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
import os
|
| 36 |
+
import logging
|
| 37 |
+
from dataclasses import dataclass
|
| 38 |
+
from datetime import datetime, timedelta
|
| 39 |
+
from typing import Optional, Dict, Any
|
| 40 |
+
import jwt
|
| 41 |
+
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
|
| 46 |
+
class TokenPayload:
|
| 47 |
+
"""
|
| 48 |
+
Payload extracted from a verified JWT token.
|
| 49 |
+
|
| 50 |
+
Attributes:
|
| 51 |
+
user_id: The user's unique identifier (sub claim)
|
| 52 |
+
email: The user's email address
|
| 53 |
+
issued_at: When the token was issued
|
| 54 |
+
expires_at: When the token expires
|
| 55 |
+
token_version: Version number for token invalidation
|
| 56 |
+
extra: Any additional claims in the token
|
| 57 |
+
"""
|
| 58 |
+
user_id: str
|
| 59 |
+
email: str
|
| 60 |
+
issued_at: datetime
|
| 61 |
+
expires_at: datetime
|
| 62 |
+
token_version: int = 1
|
| 63 |
+
extra: Dict[str, Any] = None
|
| 64 |
+
|
| 65 |
+
def __post_init__(self):
|
| 66 |
+
if self.extra is None:
|
| 67 |
+
self.extra = {}
|
| 68 |
+
|
| 69 |
+
@property
|
| 70 |
+
def is_expired(self) -> bool:
|
| 71 |
+
"""Check if the token has expired."""
|
| 72 |
+
return datetime.utcnow() > self.expires_at
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def time_until_expiry(self) -> timedelta:
|
| 76 |
+
"""Get time remaining until expiry."""
|
| 77 |
+
return self.expires_at - datetime.utcnow()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class JWTError(Exception):
|
| 81 |
+
"""Base exception for JWT errors."""
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class TokenExpiredError(JWTError):
|
| 86 |
+
"""Raised when the token has expired."""
|
| 87 |
+
pass
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class InvalidTokenError(JWTError):
|
| 91 |
+
"""Raised when the token is invalid."""
|
| 92 |
+
pass
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class ConfigurationError(JWTError):
|
| 96 |
+
"""Raised when the service is not properly configured."""
|
| 97 |
+
pass
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class JWTService:
|
| 101 |
+
"""
|
| 102 |
+
Service for creating and verifying JWT tokens.
|
| 103 |
+
|
| 104 |
+
This service handles JWT token lifecycle for authentication.
|
| 105 |
+
It's designed to be modular and reusable across different applications.
|
| 106 |
+
|
| 107 |
+
Example:
|
| 108 |
+
service = JWTService(secret_key="my-secret")
|
| 109 |
+
|
| 110 |
+
# Create token
|
| 111 |
+
token = service.create_token(user_id="u123", email="a@b.com")
|
| 112 |
+
|
| 113 |
+
# Verify token
|
| 114 |
+
try:
|
| 115 |
+
payload = service.verify_token(token)
|
| 116 |
+
print(f"User: {payload.user_id}")
|
| 117 |
+
except TokenExpiredError:
|
| 118 |
+
print("Token expired, please login again")
|
| 119 |
+
except InvalidTokenError:
|
| 120 |
+
print("Invalid token")
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
# Default configuration
|
| 124 |
+
DEFAULT_ALGORITHM = "HS256"
|
| 125 |
+
DEFAULT_EXPIRY_HOURS = 168 # 7 days
|
| 126 |
+
|
| 127 |
+
def __init__(
|
| 128 |
+
self,
|
| 129 |
+
secret_key: Optional[str] = None,
|
| 130 |
+
algorithm: Optional[str] = None,
|
| 131 |
+
expiry_hours: Optional[int] = None
|
| 132 |
+
):
|
| 133 |
+
"""
|
| 134 |
+
Initialize the JWT Service.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
secret_key: Secret key for signing tokens. If not provided,
|
| 138 |
+
falls back to JWT_SECRET environment variable.
|
| 139 |
+
algorithm: JWT algorithm (default: HS256).
|
| 140 |
+
expiry_hours: Token expiry in hours (default: 168 = 7 days).
|
| 141 |
+
|
| 142 |
+
Raises:
|
| 143 |
+
ConfigurationError: If no secret_key is provided or found.
|
| 144 |
+
"""
|
| 145 |
+
self.secret_key = secret_key or os.getenv("JWT_SECRET")
|
| 146 |
+
self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
|
| 147 |
+
self.expiry_hours = expiry_hours or int(
|
| 148 |
+
os.getenv("JWT_EXPIRY_HOURS", str(self.DEFAULT_EXPIRY_HOURS))
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
if not self.secret_key:
|
| 152 |
+
raise ConfigurationError(
|
| 153 |
+
"JWT secret key is required. Either pass secret_key parameter "
|
| 154 |
+
"or set JWT_SECRET environment variable. "
|
| 155 |
+
"Generate one with: python -c \"import secrets; print(secrets.token_urlsafe(64))\""
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# Warn if secret is too short
|
| 159 |
+
if len(self.secret_key) < 32:
|
| 160 |
+
logger.warning(
|
| 161 |
+
"JWT secret key is short (< 32 chars). "
|
| 162 |
+
"Consider using a longer secret for better security."
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
logger.info(
|
| 166 |
+
f"JWTService initialized (algorithm={self.algorithm}, "
|
| 167 |
+
f"expiry={self.expiry_hours}h)"
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
def create_token(
|
| 171 |
+
self,
|
| 172 |
+
user_id: str,
|
| 173 |
+
email: str,
|
| 174 |
+
token_version: int = 1,
|
| 175 |
+
extra_claims: Optional[Dict[str, Any]] = None,
|
| 176 |
+
expiry_hours: Optional[int] = None
|
| 177 |
+
) -> str:
|
| 178 |
+
"""
|
| 179 |
+
Create a JWT token for a user.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
user_id: The user's unique identifier.
|
| 183 |
+
email: The user's email address.
|
| 184 |
+
token_version: User's current token version for invalidation.
|
| 185 |
+
extra_claims: Additional claims to include in the token.
|
| 186 |
+
expiry_hours: Custom expiry for this token (overrides default).
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
str: The encoded JWT token.
|
| 190 |
+
"""
|
| 191 |
+
now = datetime.utcnow()
|
| 192 |
+
expiry = expiry_hours or self.expiry_hours
|
| 193 |
+
|
| 194 |
+
payload = {
|
| 195 |
+
"sub": user_id,
|
| 196 |
+
"email": email,
|
| 197 |
+
"tv": token_version, # Token version for invalidation
|
| 198 |
+
"iat": now,
|
| 199 |
+
"exp": now + timedelta(hours=expiry),
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
if extra_claims:
|
| 203 |
+
payload.update(extra_claims)
|
| 204 |
+
|
| 205 |
+
token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
| 206 |
+
|
| 207 |
+
logger.debug(f"Created token for user_id={user_id} (version={token_version})")
|
| 208 |
+
return token
|
| 209 |
+
|
| 210 |
+
def verify_token(self, token: str) -> TokenPayload:
|
| 211 |
+
"""
|
| 212 |
+
Verify a JWT token and extract the payload.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
token: The JWT token to verify.
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
TokenPayload: Dataclass containing the verified payload.
|
| 219 |
+
|
| 220 |
+
Raises:
|
| 221 |
+
TokenExpiredError: If the token has expired.
|
| 222 |
+
InvalidTokenError: If the token is invalid or malformed.
|
| 223 |
+
"""
|
| 224 |
+
if not token:
|
| 225 |
+
raise InvalidTokenError("Token cannot be empty")
|
| 226 |
+
|
| 227 |
+
try:
|
| 228 |
+
payload = jwt.decode(
|
| 229 |
+
token,
|
| 230 |
+
self.secret_key,
|
| 231 |
+
algorithms=[self.algorithm]
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# Extract standard claims
|
| 235 |
+
user_id = payload.get("sub")
|
| 236 |
+
email = payload.get("email")
|
| 237 |
+
token_version = payload.get("tv", 1) # Default to 1 for backward compatibility
|
| 238 |
+
iat = payload.get("iat")
|
| 239 |
+
exp = payload.get("exp")
|
| 240 |
+
|
| 241 |
+
if not user_id or not email:
|
| 242 |
+
raise InvalidTokenError("Token missing required claims (sub, email)")
|
| 243 |
+
|
| 244 |
+
# Convert timestamps to datetime
|
| 245 |
+
issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
|
| 246 |
+
expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
|
| 247 |
+
|
| 248 |
+
# Extract extra claims
|
| 249 |
+
standard_claims = {"sub", "email", "tv", "iat", "exp"}
|
| 250 |
+
extra = {k: v for k, v in payload.items() if k not in standard_claims}
|
| 251 |
+
|
| 252 |
+
return TokenPayload(
|
| 253 |
+
user_id=user_id,
|
| 254 |
+
email=email,
|
| 255 |
+
issued_at=issued_at,
|
| 256 |
+
expires_at=expires_at,
|
| 257 |
+
token_version=token_version,
|
| 258 |
+
extra=extra
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
except jwt.ExpiredSignatureError:
|
| 262 |
+
logger.debug("Token verification failed: expired")
|
| 263 |
+
raise TokenExpiredError("Token has expired")
|
| 264 |
+
except jwt.InvalidTokenError as e:
|
| 265 |
+
logger.debug(f"Token verification failed: {e}")
|
| 266 |
+
raise InvalidTokenError(f"Invalid token: {str(e)}")
|
| 267 |
+
except Exception as e:
|
| 268 |
+
logger.error(f"Unexpected error during token verification: {e}")
|
| 269 |
+
raise InvalidTokenError(f"Token verification error: {str(e)}")
|
| 270 |
+
|
| 271 |
+
def verify_token_safe(self, token: str) -> Optional[TokenPayload]:
|
| 272 |
+
"""
|
| 273 |
+
Verify a JWT token without raising exceptions.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
token: The JWT token to verify.
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
TokenPayload if valid, None if invalid or expired.
|
| 280 |
+
"""
|
| 281 |
+
try:
|
| 282 |
+
return self.verify_token(token)
|
| 283 |
+
except JWTError:
|
| 284 |
+
return None
|
| 285 |
+
|
| 286 |
+
def refresh_token(
|
| 287 |
+
self,
|
| 288 |
+
token: str,
|
| 289 |
+
expiry_hours: Optional[int] = None
|
| 290 |
+
) -> str:
|
| 291 |
+
"""
|
| 292 |
+
Refresh a token by creating a new one with the same claims.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
token: The current (possibly expired) token.
|
| 296 |
+
expiry_hours: Custom expiry for the new token.
|
| 297 |
+
|
| 298 |
+
Returns:
|
| 299 |
+
str: A new JWT token with updated expiry.
|
| 300 |
+
|
| 301 |
+
Raises:
|
| 302 |
+
InvalidTokenError: If the token is malformed.
|
| 303 |
+
"""
|
| 304 |
+
try:
|
| 305 |
+
# Decode without verifying expiry
|
| 306 |
+
payload = jwt.decode(
|
| 307 |
+
token,
|
| 308 |
+
self.secret_key,
|
| 309 |
+
algorithms=[self.algorithm],
|
| 310 |
+
options={"verify_exp": False}
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
user_id = payload.get("sub")
|
| 314 |
+
email = payload.get("email")
|
| 315 |
+
|
| 316 |
+
if not user_id or not email:
|
| 317 |
+
raise InvalidTokenError("Token missing required claims")
|
| 318 |
+
|
| 319 |
+
# Preserve extra claims
|
| 320 |
+
standard_claims = {"sub", "email", "iat", "exp"}
|
| 321 |
+
extra = {k: v for k, v in payload.items() if k not in standard_claims}
|
| 322 |
+
|
| 323 |
+
return self.create_token(
|
| 324 |
+
user_id=user_id,
|
| 325 |
+
email=email,
|
| 326 |
+
extra_claims=extra,
|
| 327 |
+
expiry_hours=expiry_hours
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
except jwt.InvalidTokenError as e:
|
| 331 |
+
raise InvalidTokenError(f"Cannot refresh invalid token: {str(e)}")
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# Singleton instance for convenience
|
| 335 |
+
_default_service: Optional[JWTService] = None
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def get_jwt_service() -> JWTService:
|
| 339 |
+
"""
|
| 340 |
+
Get the default JWTService instance.
|
| 341 |
+
|
| 342 |
+
Creates a singleton instance using environment variables.
|
| 343 |
+
|
| 344 |
+
Returns:
|
| 345 |
+
JWTService: The default service instance.
|
| 346 |
+
|
| 347 |
+
Raises:
|
| 348 |
+
ConfigurationError: If JWT_SECRET is not set.
|
| 349 |
+
"""
|
| 350 |
+
global _default_service
|
| 351 |
+
if _default_service is None:
|
| 352 |
+
_default_service = JWTService()
|
| 353 |
+
return _default_service
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
|
| 357 |
+
"""
|
| 358 |
+
Convenience function to create a token using the default service.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
user_id: The user's unique identifier.
|
| 362 |
+
email: The user's email address.
|
| 363 |
+
token_version: User's current token version for invalidation.
|
| 364 |
+
**kwargs: Additional arguments passed to create_token.
|
| 365 |
+
|
| 366 |
+
Returns:
|
| 367 |
+
str: The encoded JWT token.
|
| 368 |
+
"""
|
| 369 |
+
return get_jwt_service().create_token(user_id, email, token_version, **kwargs)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def verify_access_token(token: str) -> TokenPayload:
|
| 373 |
+
"""
|
| 374 |
+
Convenience function to verify a token using the default service.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
token: The JWT token to verify.
|
| 378 |
+
|
| 379 |
+
Returns:
|
| 380 |
+
TokenPayload: Verified token payload.
|
| 381 |
+
|
| 382 |
+
Raises:
|
| 383 |
+
TokenExpiredError: If the token has expired.
|
| 384 |
+
InvalidTokenError: If the token is invalid.
|
| 385 |
+
"""
|
| 386 |
+
return get_jwt_service().verify_token(token)
|
services/auth_service/middleware.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Auth Middleware - Request authentication layer
|
| 3 |
+
|
| 4 |
+
Intercepts requests to validate JWT tokens and attach authenticated
|
| 5 |
+
user to request.state for use in route handlers.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from fastapi import Request, HTTPException, status
|
| 10 |
+
from fastapi.responses import JSONResponse
|
| 11 |
+
from sqlalchemy import select
|
| 12 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 13 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 14 |
+
|
| 15 |
+
from core.database import async_session_maker
|
| 16 |
+
from core.models import User
|
| 17 |
+
from services.auth_service.config import AuthServiceConfig
|
| 18 |
+
from services.auth_service.jwt_provider import (
|
| 19 |
+
verify_access_token,
|
| 20 |
+
TokenExpiredError,
|
| 21 |
+
InvalidTokenError,
|
| 22 |
+
JWTError,
|
| 23 |
+
)
|
| 24 |
+
from services.base_service.middleware_chain import (
|
| 25 |
+
BaseServiceMiddleware,
|
| 26 |
+
get_request_context,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AuthMiddleware(BaseServiceMiddleware):
|
| 33 |
+
"""
|
| 34 |
+
Authentication middleware for request validation.
|
| 35 |
+
|
| 36 |
+
Flow:
|
| 37 |
+
1. Check if route requires/allows auth based on URL
|
| 38 |
+
2. Extract Authorization header
|
| 39 |
+
3. Verify JWT token
|
| 40 |
+
4. Load user from database
|
| 41 |
+
5. Attach user to request.state.user
|
| 42 |
+
6. Continue to next middleware/route
|
| 43 |
+
|
| 44 |
+
Public routes skip all auth checks.
|
| 45 |
+
Required routes must have valid auth or return 401.
|
| 46 |
+
Optional routes attach user if auth is provided, but don't fail if missing.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
SERVICE_NAME = "auth"
|
| 50 |
+
|
| 51 |
+
async def dispatch(self, request: Request, call_next):
|
| 52 |
+
"""Process request through auth middleware."""
|
| 53 |
+
# Skip OPTIONS requests (CORS preflight)
|
| 54 |
+
if request.method == "OPTIONS":
|
| 55 |
+
return await call_next(request)
|
| 56 |
+
|
| 57 |
+
# Initialize request context
|
| 58 |
+
ctx = get_request_context(request)
|
| 59 |
+
|
| 60 |
+
# Get path and method from request
|
| 61 |
+
path = request.url.path
|
| 62 |
+
|
| 63 |
+
# Check if route is public (skip all auth)
|
| 64 |
+
if AuthServiceConfig.is_public(path):
|
| 65 |
+
self.log_request(request, "Public route, skipping auth")
|
| 66 |
+
request.state.user = None
|
| 67 |
+
ctx.user = None
|
| 68 |
+
ctx.is_authenticated = False
|
| 69 |
+
response = await call_next(request)
|
| 70 |
+
return response
|
| 71 |
+
|
| 72 |
+
# Check if route requires auth or allows optional auth
|
| 73 |
+
requires_auth = AuthServiceConfig.requires_auth(path)
|
| 74 |
+
allows_optional = AuthServiceConfig.allows_optional_auth(path)
|
| 75 |
+
|
| 76 |
+
# If route doesn't require auth and doesn't allow optional, skip
|
| 77 |
+
if not requires_auth and not allows_optional:
|
| 78 |
+
self.log_request(request, "Route not configured for auth, skipping")
|
| 79 |
+
request.state.user = None
|
| 80 |
+
ctx.user = None
|
| 81 |
+
ctx.is_authenticated = False
|
| 82 |
+
response = await call_next(request)
|
| 83 |
+
return response
|
| 84 |
+
|
| 85 |
+
# Extract Authorization header
|
| 86 |
+
auth_header = request.headers.get("Authorization")
|
| 87 |
+
|
| 88 |
+
# If no auth header
|
| 89 |
+
if not auth_header:
|
| 90 |
+
if requires_auth:
|
| 91 |
+
self.log_request(request, "Missing Authorization header (required)")
|
| 92 |
+
return JSONResponse(
|
| 93 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 94 |
+
content={"detail": "Missing Authorization header"},
|
| 95 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
# Optional auth, no header provided
|
| 99 |
+
self.log_request(request, "No auth header (optional route)")
|
| 100 |
+
request.state.user = None
|
| 101 |
+
ctx.user = None
|
| 102 |
+
ctx.is_authenticated = False
|
| 103 |
+
response = await call_next(request)
|
| 104 |
+
return response
|
| 105 |
+
|
| 106 |
+
# Validate Authorization header format
|
| 107 |
+
if not auth_header.startswith("Bearer "):
|
| 108 |
+
if requires_auth:
|
| 109 |
+
self.log_request(request, "Invalid Authorization header format")
|
| 110 |
+
return JSONResponse(
|
| 111 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 112 |
+
content={"detail": "Invalid Authorization header format. Use: Bearer <token>"},
|
| 113 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 114 |
+
)
|
| 115 |
+
else:
|
| 116 |
+
# Optional auth, invalid format
|
| 117 |
+
request.state.user = None
|
| 118 |
+
ctx.user = None
|
| 119 |
+
ctx.is_authenticated = False
|
| 120 |
+
response = await call_next(request)
|
| 121 |
+
return response
|
| 122 |
+
|
| 123 |
+
# Extract token
|
| 124 |
+
token = auth_header.split(" ", 1)[1]
|
| 125 |
+
|
| 126 |
+
# Verify token
|
| 127 |
+
try:
|
| 128 |
+
payload = verify_access_token(token)
|
| 129 |
+
except TokenExpiredError:
|
| 130 |
+
if requires_auth:
|
| 131 |
+
self.log_request(request, "Token expired")
|
| 132 |
+
return JSONResponse(
|
| 133 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 134 |
+
content={"detail": "Token has expired. Please sign in again."},
|
| 135 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 136 |
+
)
|
| 137 |
+
else:
|
| 138 |
+
# Optional auth, expired token
|
| 139 |
+
request.state.user = None
|
| 140 |
+
ctx.user = None
|
| 141 |
+
ctx.is_authenticated = False
|
| 142 |
+
response = await call_next(request)
|
| 143 |
+
return response
|
| 144 |
+
except (InvalidTokenError, JWTError) as e:
|
| 145 |
+
if requires_auth:
|
| 146 |
+
self.log_error(request, f"Token verification failed: {e}")
|
| 147 |
+
return JSONResponse(
|
| 148 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 149 |
+
content={"detail": f"Invalid token: {str(e)}"},
|
| 150 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 151 |
+
)
|
| 152 |
+
else:
|
| 153 |
+
# Optional auth, invalid token
|
| 154 |
+
request.state.user = None
|
| 155 |
+
ctx.user = None
|
| 156 |
+
ctx.is_authenticated = False
|
| 157 |
+
response = await call_next(request)
|
| 158 |
+
return response
|
| 159 |
+
|
| 160 |
+
# Get database session
|
| 161 |
+
async with async_session_maker() as db:
|
| 162 |
+
try:
|
| 163 |
+
# Load user from database
|
| 164 |
+
query = select(User).where(
|
| 165 |
+
User.user_id == payload.user_id,
|
| 166 |
+
User.is_active == True
|
| 167 |
+
)
|
| 168 |
+
result = await db.execute(query)
|
| 169 |
+
user = result.scalar_one_or_none()
|
| 170 |
+
|
| 171 |
+
if not user:
|
| 172 |
+
if requires_auth:
|
| 173 |
+
self.log_request(request, "User not found or inactive")
|
| 174 |
+
return JSONResponse(
|
| 175 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 176 |
+
content={"detail": "User not found or inactive"},
|
| 177 |
+
)
|
| 178 |
+
else:
|
| 179 |
+
# Optional auth, user not found
|
| 180 |
+
request.state.user = None
|
| 181 |
+
ctx.user = None
|
| 182 |
+
ctx.is_authenticated = False
|
| 183 |
+
response = await call_next(request)
|
| 184 |
+
return response
|
| 185 |
+
|
| 186 |
+
# Validate token version
|
| 187 |
+
if payload.token_version < user.token_version:
|
| 188 |
+
if requires_auth:
|
| 189 |
+
self.log_request(
|
| 190 |
+
request,
|
| 191 |
+
f"Token invalidated (version {payload.token_version} < {user.token_version})"
|
| 192 |
+
)
|
| 193 |
+
return JSONResponse(
|
| 194 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 195 |
+
content={"detail": "Token has been invalidated. Please sign in again."},
|
| 196 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 197 |
+
)
|
| 198 |
+
else:
|
| 199 |
+
# Optional auth, invalidated token
|
| 200 |
+
request.state.user = None
|
| 201 |
+
ctx.user = None
|
| 202 |
+
ctx.is_authenticated = False
|
| 203 |
+
response = await call_next(request)
|
| 204 |
+
return response
|
| 205 |
+
|
| 206 |
+
# Attach user to request state
|
| 207 |
+
request.state.user = user
|
| 208 |
+
ctx.set_user(user)
|
| 209 |
+
|
| 210 |
+
# Check if user is admin
|
| 211 |
+
is_admin = AuthServiceConfig.is_admin(user.email)
|
| 212 |
+
request.state.is_admin = is_admin
|
| 213 |
+
ctx.set_flag('is_admin', is_admin)
|
| 214 |
+
|
| 215 |
+
self.log_request(request, f"Authenticated user: {user.user_id}")
|
| 216 |
+
|
| 217 |
+
# Continue to next middleware/route
|
| 218 |
+
response = await call_next(request)
|
| 219 |
+
return response
|
| 220 |
+
|
| 221 |
+
finally:
|
| 222 |
+
await db.close()
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
__all__ = ['AuthMiddleware']
|
services/base_service/__init__.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base Service Infrastructure
|
| 3 |
+
|
| 4 |
+
Provides the foundation for plug-and-play services in the API gateway.
|
| 5 |
+
All services (auth, credit, gemini, etc.) extend this base infrastructure.
|
| 6 |
+
|
| 7 |
+
Core Components:
|
| 8 |
+
- BaseService: Abstract base class for all services
|
| 9 |
+
- ServiceConfig: Configuration container
|
| 10 |
+
- ServiceRegistry: Global registry for service discovery
|
| 11 |
+
- MiddlewareProtocol: Type definition for middleware functions
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
class MyService(BaseService):
|
| 15 |
+
@classmethod
|
| 16 |
+
def register(cls, **config):
|
| 17 |
+
# Service-specific registration
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
@classmethod
|
| 21 |
+
def get_middleware(cls):
|
| 22 |
+
# Return middleware function if needed
|
| 23 |
+
return MyMiddleware()
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import logging
|
| 27 |
+
from abc import ABC, abstractmethod
|
| 28 |
+
from typing import Dict, Type, Optional, Callable, Any
|
| 29 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ServiceConfig:
|
| 35 |
+
"""
|
| 36 |
+
Base configuration container for services.
|
| 37 |
+
|
| 38 |
+
Services can extend this to add their specific configuration.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, **kwargs):
|
| 42 |
+
"""Initialize configuration with arbitrary key-value pairs."""
|
| 43 |
+
self._config = kwargs
|
| 44 |
+
|
| 45 |
+
def get(self, key: str, default: Any = None) -> Any:
|
| 46 |
+
"""Get configuration value."""
|
| 47 |
+
return self._config.get(key, default)
|
| 48 |
+
|
| 49 |
+
def set(self, key: str, value: Any) -> None:
|
| 50 |
+
"""Set configuration value."""
|
| 51 |
+
self._config[key] = value
|
| 52 |
+
|
| 53 |
+
def __getitem__(self, key: str) -> Any:
|
| 54 |
+
"""Dictionary-style access."""
|
| 55 |
+
return self._config[key]
|
| 56 |
+
|
| 57 |
+
def __setitem__(self, key: str, value: Any) -> None:
|
| 58 |
+
"""Dictionary-style assignment."""
|
| 59 |
+
self._config[key] = value
|
| 60 |
+
|
| 61 |
+
def __contains__(self, key: str) -> bool:
|
| 62 |
+
"""Check if key exists."""
|
| 63 |
+
return key in self._config
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class BaseService(ABC):
|
| 67 |
+
"""
|
| 68 |
+
Abstract base class for all plug-and-play services.
|
| 69 |
+
|
| 70 |
+
Services must implement:
|
| 71 |
+
- register(): Register service configuration at startup
|
| 72 |
+
- get_middleware(): Return middleware if service needs request interception
|
| 73 |
+
- on_shutdown(): Cleanup on app shutdown
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
# Service name (override in subclass)
|
| 77 |
+
SERVICE_NAME: str = "base_service"
|
| 78 |
+
|
| 79 |
+
# Service configuration
|
| 80 |
+
_config: Optional[ServiceConfig] = None
|
| 81 |
+
|
| 82 |
+
# Registration state
|
| 83 |
+
_registered: bool = False
|
| 84 |
+
|
| 85 |
+
@classmethod
|
| 86 |
+
@abstractmethod
|
| 87 |
+
def register(cls, **config) -> None:
|
| 88 |
+
"""
|
| 89 |
+
Register service configuration at application startup.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
**config: Service-specific configuration parameters
|
| 93 |
+
|
| 94 |
+
Raises:
|
| 95 |
+
RuntimeError: If service is already registered
|
| 96 |
+
"""
|
| 97 |
+
if cls._registered:
|
| 98 |
+
raise RuntimeError(f"{cls.SERVICE_NAME} is already registered")
|
| 99 |
+
|
| 100 |
+
cls._config = ServiceConfig(**config)
|
| 101 |
+
cls._registered = True
|
| 102 |
+
|
| 103 |
+
logger.info(f"✅ {cls.SERVICE_NAME} registered successfully")
|
| 104 |
+
|
| 105 |
+
@classmethod
|
| 106 |
+
def get_middleware(cls) -> Optional[BaseHTTPMiddleware]:
|
| 107 |
+
"""
|
| 108 |
+
Return middleware instance if service needs request interception.
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
Middleware instance or None if service doesn't need middleware
|
| 112 |
+
"""
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
@classmethod
|
| 116 |
+
def on_shutdown(cls) -> None:
|
| 117 |
+
"""
|
| 118 |
+
Cleanup hook called during application shutdown.
|
| 119 |
+
|
| 120 |
+
Override this to perform cleanup (close connections, save state, etc.)
|
| 121 |
+
"""
|
| 122 |
+
pass
|
| 123 |
+
|
| 124 |
+
@classmethod
|
| 125 |
+
def is_registered(cls) -> bool:
|
| 126 |
+
"""Check if service has been registered."""
|
| 127 |
+
return cls._registered
|
| 128 |
+
|
| 129 |
+
@classmethod
|
| 130 |
+
def assert_registered(cls) -> None:
|
| 131 |
+
"""
|
| 132 |
+
Assert that service has been registered.
|
| 133 |
+
|
| 134 |
+
Raises:
|
| 135 |
+
RuntimeError: If service is not registered
|
| 136 |
+
"""
|
| 137 |
+
if not cls._registered:
|
| 138 |
+
raise RuntimeError(
|
| 139 |
+
f"{cls.SERVICE_NAME} is not registered. "
|
| 140 |
+
f"Call {cls.SERVICE_NAME}.register() at application startup."
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
@classmethod
|
| 144 |
+
def get_config(cls) -> ServiceConfig:
|
| 145 |
+
"""
|
| 146 |
+
Get service configuration.
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
ServiceConfig instance
|
| 150 |
+
|
| 151 |
+
Raises:
|
| 152 |
+
RuntimeError: If service is not registered
|
| 153 |
+
"""
|
| 154 |
+
cls.assert_registered()
|
| 155 |
+
return cls._config
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class ServiceRegistry:
|
| 159 |
+
"""
|
| 160 |
+
Global registry for service discovery and management.
|
| 161 |
+
|
| 162 |
+
Tracks all registered services and provides lookup functionality.
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
_services: Dict[str, Type[BaseService]] = {}
|
| 166 |
+
|
| 167 |
+
@classmethod
|
| 168 |
+
def register_service(cls, service_class: Type[BaseService]) -> None:
|
| 169 |
+
"""
|
| 170 |
+
Register a service class in the global registry.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
service_class: Service class to register
|
| 174 |
+
"""
|
| 175 |
+
service_name = service_class.SERVICE_NAME
|
| 176 |
+
|
| 177 |
+
if service_name in cls._services:
|
| 178 |
+
logger.warning(f"Service '{service_name}' already registered, overwriting")
|
| 179 |
+
|
| 180 |
+
cls._services[service_name] = service_class
|
| 181 |
+
logger.debug(f"Registered service: {service_name}")
|
| 182 |
+
|
| 183 |
+
@classmethod
|
| 184 |
+
def get_service(cls, service_name: str) -> Optional[Type[BaseService]]:
|
| 185 |
+
"""
|
| 186 |
+
Get a service class by name.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
service_name: Name of the service to retrieve
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
Service class or None if not found
|
| 193 |
+
"""
|
| 194 |
+
return cls._services.get(service_name)
|
| 195 |
+
|
| 196 |
+
@classmethod
|
| 197 |
+
def get_all_services(cls) -> Dict[str, Type[BaseService]]:
|
| 198 |
+
"""
|
| 199 |
+
Get all registered services.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
Dictionary mapping service names to service classes
|
| 203 |
+
"""
|
| 204 |
+
return cls._services.copy()
|
| 205 |
+
|
| 206 |
+
@classmethod
|
| 207 |
+
def get_all_middleware(cls) -> list:
|
| 208 |
+
"""
|
| 209 |
+
Get middleware from all registered services.
|
| 210 |
+
|
| 211 |
+
Returns:
|
| 212 |
+
List of middleware instances in registration order
|
| 213 |
+
"""
|
| 214 |
+
middleware_list = []
|
| 215 |
+
|
| 216 |
+
for service_name, service_class in cls._services.items():
|
| 217 |
+
if service_class.is_registered():
|
| 218 |
+
middleware = service_class.get_middleware()
|
| 219 |
+
if middleware:
|
| 220 |
+
middleware_list.append(middleware)
|
| 221 |
+
logger.debug(f"Added middleware from service: {service_name}")
|
| 222 |
+
|
| 223 |
+
return middleware_list
|
| 224 |
+
|
| 225 |
+
@classmethod
|
| 226 |
+
def shutdown_all(cls) -> None:
|
| 227 |
+
"""
|
| 228 |
+
Call shutdown hooks for all registered services.
|
| 229 |
+
"""
|
| 230 |
+
logger.info("Shutting down all services...")
|
| 231 |
+
|
| 232 |
+
for service_name, service_class in cls._services.items():
|
| 233 |
+
try:
|
| 234 |
+
service_class.on_shutdown()
|
| 235 |
+
logger.debug(f"Shutdown complete: {service_name}")
|
| 236 |
+
except Exception as e:
|
| 237 |
+
logger.error(f"Error shutting down {service_name}: {e}")
|
| 238 |
+
|
| 239 |
+
logger.info("All services shut down")
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
__all__ = [
|
| 243 |
+
'BaseService',
|
| 244 |
+
'ServiceConfig',
|
| 245 |
+
'ServiceRegistry',
|
| 246 |
+
]
|
services/base_service/middleware_chain.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Middleware Chain - Orchestration of multiple middleware layers.
|
| 3 |
+
|
| 4 |
+
Provides utilities for managing and coordinating multiple middleware
|
| 5 |
+
components in the request/response flow.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
# In app.py
|
| 9 |
+
from services.base_service import MiddlewareChain
|
| 10 |
+
|
| 11 |
+
# Add middleware in reverse order (last added = first executed)
|
| 12 |
+
app.add_middleware(CreditMiddleware)
|
| 13 |
+
app.add_middleware(AuthMiddleware)
|
| 14 |
+
|
| 15 |
+
# Or use the chain helper
|
| 16 |
+
chain = MiddlewareChain()
|
| 17 |
+
chain.add(AuthMiddleware)
|
| 18 |
+
chain.add(CreditMiddleware)
|
| 19 |
+
chain.apply_to_app(app)
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import logging
|
| 23 |
+
from typing import List, Type, Callable
|
| 24 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 25 |
+
from fastapi import FastAPI, Request, Response
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class RequestContext:
|
| 31 |
+
"""
|
| 32 |
+
Shared context for passing data between middleware layers.
|
| 33 |
+
|
| 34 |
+
Attached to request.state for access across middleware and routers.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self):
|
| 38 |
+
"""Initialize empty context."""
|
| 39 |
+
# Auth layer
|
| 40 |
+
self.user = None
|
| 41 |
+
self.is_authenticated = False
|
| 42 |
+
|
| 43 |
+
# Credit layer
|
| 44 |
+
self.credits_reserved = 0
|
| 45 |
+
self.credit_cost = 0
|
| 46 |
+
|
| 47 |
+
# General
|
| 48 |
+
self.start_time = None
|
| 49 |
+
self.service_flags = {}
|
| 50 |
+
|
| 51 |
+
def set_user(self, user) -> None:
|
| 52 |
+
"""Set authenticated user."""
|
| 53 |
+
self.user = user
|
| 54 |
+
self.is_authenticated = True
|
| 55 |
+
|
| 56 |
+
def set_credits(self, reserved: int, cost: int) -> None:
|
| 57 |
+
"""Set credit information."""
|
| 58 |
+
self.credits_reserved = reserved
|
| 59 |
+
self.credit_cost = cost
|
| 60 |
+
|
| 61 |
+
def set_flag(self, key: str, value: any) -> None:
|
| 62 |
+
"""Set a service-specific flag."""
|
| 63 |
+
self.service_flags[key] = value
|
| 64 |
+
|
| 65 |
+
def get_flag(self, key: str, default=None) -> any:
|
| 66 |
+
"""Get a service-specific flag."""
|
| 67 |
+
return self.service_flags.get(key, default)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class MiddlewareChain:
|
| 71 |
+
"""
|
| 72 |
+
Helper for managing middleware registration order.
|
| 73 |
+
|
| 74 |
+
FastAPI/Starlette middleware executes in REVERSE order of registration,
|
| 75 |
+
so the LAST middleware added is the FIRST to execute.
|
| 76 |
+
|
| 77 |
+
This class helps manage the order explicitly.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(self):
|
| 81 |
+
"""Initialize empty middleware chain."""
|
| 82 |
+
self._middleware: List[Type[BaseHTTPMiddleware]] = []
|
| 83 |
+
|
| 84 |
+
def add(self, middleware_class: Type[BaseHTTPMiddleware], **kwargs) -> 'MiddlewareChain':
|
| 85 |
+
"""
|
| 86 |
+
Add middleware to the chain.
|
| 87 |
+
|
| 88 |
+
Middleware is added to the END of the list, but will be registered
|
| 89 |
+
in REVERSE order (so first added = first executed).
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
middleware_class: Middleware class to add
|
| 93 |
+
**kwargs: Arguments to pass to middleware constructor
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
Self for chaining
|
| 97 |
+
"""
|
| 98 |
+
self._middleware.append((middleware_class, kwargs))
|
| 99 |
+
logger.debug(f"Added middleware to chain: {middleware_class.__name__}")
|
| 100 |
+
return self
|
| 101 |
+
|
| 102 |
+
def apply_to_app(self, app: FastAPI) -> None:
|
| 103 |
+
"""
|
| 104 |
+
Apply all middleware to the FastAPI app in correct order.
|
| 105 |
+
|
| 106 |
+
Middleware is registered in REVERSE order so that the first
|
| 107 |
+
middleware added to the chain is the first to execute.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
app: FastAPI application instance
|
| 111 |
+
"""
|
| 112 |
+
# Reverse the list so first added = first executed
|
| 113 |
+
for middleware_class, kwargs in reversed(self._middleware):
|
| 114 |
+
app.add_middleware(middleware_class, **kwargs)
|
| 115 |
+
logger.info(f"Registered middleware: {middleware_class.__name__}")
|
| 116 |
+
|
| 117 |
+
def get_middleware_list(self) -> List[Type[BaseHTTPMiddleware]]:
|
| 118 |
+
"""
|
| 119 |
+
Get the list of middleware in execution order.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
List of middleware classes in the order they will execute
|
| 123 |
+
"""
|
| 124 |
+
return [m[0] for m in self._middleware]
|
| 125 |
+
|
| 126 |
+
def __len__(self) -> int:
|
| 127 |
+
"""Get number of middleware in chain."""
|
| 128 |
+
return len(self._middleware)
|
| 129 |
+
|
| 130 |
+
def __repr__(self) -> str:
|
| 131 |
+
"""String representation for debugging."""
|
| 132 |
+
middleware_names = [m[0].__name__ for m in self._middleware]
|
| 133 |
+
return f"MiddlewareChain({middleware_names})"
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
async def initialize_request_context(request: Request) -> None:
|
| 137 |
+
"""
|
| 138 |
+
Initialize request context for middleware to use.
|
| 139 |
+
|
| 140 |
+
This should be called early in the middleware chain to ensure
|
| 141 |
+
request.state.ctx is available.
|
| 142 |
+
|
| 143 |
+
Usage:
|
| 144 |
+
class MyMiddleware(BaseHTTPMiddleware):
|
| 145 |
+
async def dispatch(self, request: Request, call_next):
|
| 146 |
+
await initialize_request_context(request)
|
| 147 |
+
# Now request.state.ctx is available
|
| 148 |
+
...
|
| 149 |
+
"""
|
| 150 |
+
if not hasattr(request.state, "ctx"):
|
| 151 |
+
request.state.ctx = RequestContext()
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def get_request_context(request: Request) -> RequestContext:
|
| 155 |
+
"""
|
| 156 |
+
Get request context from request.state.
|
| 157 |
+
|
| 158 |
+
Creates context if it doesn't exist.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
request: FastAPI request object
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
RequestContext instance
|
| 165 |
+
"""
|
| 166 |
+
if not hasattr(request.state, "ctx"):
|
| 167 |
+
request.state.ctx = RequestContext()
|
| 168 |
+
return request.state.ctx
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class BaseServiceMiddleware(BaseHTTPMiddleware):
|
| 172 |
+
"""
|
| 173 |
+
Base class for service middleware.
|
| 174 |
+
|
| 175 |
+
Provides common functionality for all service middleware:
|
| 176 |
+
- Request context initialization
|
| 177 |
+
- Error handling
|
| 178 |
+
- Logging
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
SERVICE_NAME = "base"
|
| 182 |
+
|
| 183 |
+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
| 184 |
+
"""
|
| 185 |
+
Process request through middleware.
|
| 186 |
+
|
| 187 |
+
Override this in subclasses to implement service-specific logic.
|
| 188 |
+
"""
|
| 189 |
+
# Initialize context
|
| 190 |
+
await initialize_request_context(request)
|
| 191 |
+
|
| 192 |
+
# Call next middleware/route
|
| 193 |
+
response = await call_next(request)
|
| 194 |
+
|
| 195 |
+
return response
|
| 196 |
+
|
| 197 |
+
def log_request(self, request: Request, message: str) -> None:
|
| 198 |
+
"""Log request with service context."""
|
| 199 |
+
logger.info(f"[{self.SERVICE_NAME}] {request.method} {request.url.path} - {message}")
|
| 200 |
+
|
| 201 |
+
def log_error(self, request: Request, error: str) -> None:
|
| 202 |
+
"""Log error with service context."""
|
| 203 |
+
logger.error(f"[{self.SERVICE_NAME}] {request.method} {request.url.path} - ERROR: {error}")
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
__all__ = [
|
| 207 |
+
'MiddlewareChain',
|
| 208 |
+
'RequestContext',
|
| 209 |
+
'BaseServiceMiddleware',
|
| 210 |
+
'initialize_request_context',
|
| 211 |
+
'get_request_context',
|
| 212 |
+
]
|
services/base_service/route_matcher.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Route Matcher - URL pattern matching for service configuration.
|
| 3 |
+
|
| 4 |
+
Provides flexible URL matching capabilities for services to define
|
| 5 |
+
which routes require auth, credits, etc.
|
| 6 |
+
|
| 7 |
+
Supported patterns:
|
| 8 |
+
- Exact match: "/api/users"
|
| 9 |
+
- Prefix match: "/api/*" (matches /api/anything)
|
| 10 |
+
- Wildcard match: "/api/users/*/posts" (matches /api/users/123/posts)
|
| 11 |
+
- Deep wildcard: "/api/**" (matches /api/users/123/posts/456)
|
| 12 |
+
- Regex match: "^/api/v[0-9]+/.*$"
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
matcher = RouteMatcher(["/api/*", "/admin/**"])
|
| 16 |
+
|
| 17 |
+
if matcher.matches("/api/users"):
|
| 18 |
+
# Route requires auth
|
| 19 |
+
pass
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import re
|
| 23 |
+
import logging
|
| 24 |
+
from typing import List, Set, Optional, Pattern
|
| 25 |
+
from fnmatch import fnmatch
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class RouteMatcher:
|
| 31 |
+
"""
|
| 32 |
+
Flexible URL pattern matcher for route configuration.
|
| 33 |
+
|
| 34 |
+
Supports exact matches, glob patterns, and regex patterns.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self, patterns: List[str]):
|
| 38 |
+
"""
|
| 39 |
+
Initialize route matcher with patterns.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
patterns: List of URL patterns to match
|
| 43 |
+
"""
|
| 44 |
+
self.patterns = patterns
|
| 45 |
+
self._exact_matches: Set[str] = set()
|
| 46 |
+
self._prefix_patterns: List[str] = []
|
| 47 |
+
self._glob_patterns: List[str] = []
|
| 48 |
+
self._regex_patterns: List[Pattern] = []
|
| 49 |
+
|
| 50 |
+
# Classify patterns for performance
|
| 51 |
+
self._classify_patterns()
|
| 52 |
+
|
| 53 |
+
def _classify_patterns(self) -> None:
|
| 54 |
+
"""
|
| 55 |
+
Classify patterns by type for optimal matching performance.
|
| 56 |
+
|
| 57 |
+
Order of matching:
|
| 58 |
+
1. Exact matches (fastest - O(1))
|
| 59 |
+
2. Prefix patterns (fast - string startswith)
|
| 60 |
+
3. Glob patterns (medium - fnmatch)
|
| 61 |
+
4. Regex patterns (slowest - regex matching)
|
| 62 |
+
"""
|
| 63 |
+
for pattern in self.patterns:
|
| 64 |
+
# Empty pattern
|
| 65 |
+
if not pattern:
|
| 66 |
+
continue
|
| 67 |
+
|
| 68 |
+
# Regex pattern (starts with ^)
|
| 69 |
+
if pattern.startswith("^"):
|
| 70 |
+
try:
|
| 71 |
+
compiled = re.compile(pattern)
|
| 72 |
+
self._regex_patterns.append(compiled)
|
| 73 |
+
logger.debug(f"Classified as regex: {pattern}")
|
| 74 |
+
except re.error as e:
|
| 75 |
+
logger.warning(f"Invalid regex pattern '{pattern}': {e}")
|
| 76 |
+
continue
|
| 77 |
+
|
| 78 |
+
# Glob pattern (contains * or ?)
|
| 79 |
+
if "*" in pattern or "?" in pattern:
|
| 80 |
+
# Simple prefix wildcard: /api/*
|
| 81 |
+
if pattern.endswith("/*") and "*" not in pattern[:-2]:
|
| 82 |
+
prefix = pattern[:-2] # Remove /*
|
| 83 |
+
self._prefix_patterns.append(prefix)
|
| 84 |
+
logger.debug(f"Classified as prefix: {prefix}")
|
| 85 |
+
else:
|
| 86 |
+
# Complex glob: /api/*/users or /api/**
|
| 87 |
+
self._glob_patterns.append(pattern)
|
| 88 |
+
logger.debug(f"Classified as glob: {pattern}")
|
| 89 |
+
continue
|
| 90 |
+
|
| 91 |
+
# Exact match
|
| 92 |
+
self._exact_matches.add(pattern)
|
| 93 |
+
logger.debug(f"Classified as exact: {pattern}")
|
| 94 |
+
|
| 95 |
+
def matches(self, path: str) -> bool:
|
| 96 |
+
"""
|
| 97 |
+
Check if a URL path matches any configured pattern.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
path: URL path to check (e.g., "/api/users/123")
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
True if path matches any pattern, False otherwise
|
| 104 |
+
"""
|
| 105 |
+
# Strip query parameters and fragments
|
| 106 |
+
path = path.split("?")[0].split("#")[0]
|
| 107 |
+
|
| 108 |
+
# Normalize path (remove trailing slash unless it's just "/")
|
| 109 |
+
if path != "/" and path.endswith("/"):
|
| 110 |
+
path = path.rstrip("/")
|
| 111 |
+
|
| 112 |
+
# 1. Exact match (O(1))
|
| 113 |
+
if path in self._exact_matches:
|
| 114 |
+
return True
|
| 115 |
+
|
| 116 |
+
# 2. Prefix match (O(n) but fast)
|
| 117 |
+
for prefix in self._prefix_patterns:
|
| 118 |
+
if path.startswith(prefix + "/") or path == prefix:
|
| 119 |
+
return True
|
| 120 |
+
|
| 121 |
+
# 3. Glob match (O(n))
|
| 122 |
+
for pattern in self._glob_patterns:
|
| 123 |
+
if fnmatch(path, pattern):
|
| 124 |
+
return True
|
| 125 |
+
|
| 126 |
+
# 4. Regex match (O(n) but slower)
|
| 127 |
+
for regex in self._regex_patterns:
|
| 128 |
+
if regex.match(path):
|
| 129 |
+
return True
|
| 130 |
+
|
| 131 |
+
return False
|
| 132 |
+
|
| 133 |
+
def get_matching_pattern(self, path: str) -> Optional[str]:
|
| 134 |
+
"""
|
| 135 |
+
Get the first pattern that matches the given path.
|
| 136 |
+
|
| 137 |
+
Useful for debugging or determining which rule matched.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
path: URL path to check
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
Matching pattern string or None
|
| 144 |
+
"""
|
| 145 |
+
# Strip query parameters and fragments
|
| 146 |
+
path = path.split("?")[0].split("#")[0]
|
| 147 |
+
|
| 148 |
+
# Normalize path
|
| 149 |
+
if path != "/" and path.endswith("/"):
|
| 150 |
+
path = path.rstrip("/")
|
| 151 |
+
|
| 152 |
+
# Exact match
|
| 153 |
+
if path in self._exact_matches:
|
| 154 |
+
return path
|
| 155 |
+
|
| 156 |
+
# Prefix match
|
| 157 |
+
for prefix in self._prefix_patterns:
|
| 158 |
+
if path.startswith(prefix + "/") or path == prefix:
|
| 159 |
+
return prefix + "/*"
|
| 160 |
+
|
| 161 |
+
# Glob match
|
| 162 |
+
for pattern in self._glob_patterns:
|
| 163 |
+
if fnmatch(path, pattern):
|
| 164 |
+
return pattern
|
| 165 |
+
|
| 166 |
+
# Regex match
|
| 167 |
+
for regex in self._regex_patterns:
|
| 168 |
+
if regex.match(path):
|
| 169 |
+
return regex.pattern
|
| 170 |
+
|
| 171 |
+
return None
|
| 172 |
+
|
| 173 |
+
def __repr__(self) -> str:
|
| 174 |
+
"""String representation for debugging."""
|
| 175 |
+
return (
|
| 176 |
+
f"RouteMatcher("
|
| 177 |
+
f"exact={len(self._exact_matches)}, "
|
| 178 |
+
f"prefix={len(self._prefix_patterns)}, "
|
| 179 |
+
f"glob={len(self._glob_patterns)}, "
|
| 180 |
+
f"regex={len(self._regex_patterns)})"
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class RouteConfig:
|
| 185 |
+
"""
|
| 186 |
+
Route configuration helper for services.
|
| 187 |
+
|
| 188 |
+
Manages multiple route lists (required, optional, public) with
|
| 189 |
+
precedence and exclusion logic.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(
|
| 193 |
+
self,
|
| 194 |
+
required: List[str] = None,
|
| 195 |
+
optional: List[str] = None,
|
| 196 |
+
public: List[str] = None,
|
| 197 |
+
):
|
| 198 |
+
"""
|
| 199 |
+
Initialize route configuration.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
required: Routes that REQUIRE the service (e.g., auth required)
|
| 203 |
+
optional: Routes where service is OPTIONAL (e.g., auth optional)
|
| 204 |
+
public: Routes that are PUBLIC (e.g., no auth needed)
|
| 205 |
+
|
| 206 |
+
Precedence: public > required > optional (for conflict resolution)
|
| 207 |
+
"""
|
| 208 |
+
self.required_matcher = RouteMatcher(required or [])
|
| 209 |
+
self.optional_matcher = RouteMatcher(optional or [])
|
| 210 |
+
self.public_matcher = RouteMatcher(public or [])
|
| 211 |
+
|
| 212 |
+
def is_required(self, path: str) -> bool:
|
| 213 |
+
"""
|
| 214 |
+
Check if service is REQUIRED for this path.
|
| 215 |
+
|
| 216 |
+
Returns False if path is public (public takes precedence).
|
| 217 |
+
"""
|
| 218 |
+
if self.is_public(path):
|
| 219 |
+
return False
|
| 220 |
+
return self.required_matcher.matches(path)
|
| 221 |
+
|
| 222 |
+
def is_optional(self, path: str) -> bool:
|
| 223 |
+
"""
|
| 224 |
+
Check if service is OPTIONAL for this path.
|
| 225 |
+
|
| 226 |
+
Returns False if path is public or required.
|
| 227 |
+
"""
|
| 228 |
+
if self.is_public(path):
|
| 229 |
+
return False
|
| 230 |
+
if self.required_matcher.matches(path):
|
| 231 |
+
return False
|
| 232 |
+
return self.optional_matcher.matches(path)
|
| 233 |
+
|
| 234 |
+
def is_public(self, path: str) -> bool:
|
| 235 |
+
"""
|
| 236 |
+
Check if path is PUBLIC (service not needed).
|
| 237 |
+
|
| 238 |
+
Public takes highest precedence.
|
| 239 |
+
"""
|
| 240 |
+
return self.public_matcher.matches(path)
|
| 241 |
+
|
| 242 |
+
def requires_service(self, path: str) -> bool:
|
| 243 |
+
"""
|
| 244 |
+
Check if service is needed (required OR optional) for this path.
|
| 245 |
+
|
| 246 |
+
Returns False if path is not matched by any configuration.
|
| 247 |
+
"""
|
| 248 |
+
return self.is_required(path) or self.is_optional(path)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
__all__ = [
|
| 252 |
+
'RouteMatcher',
|
| 253 |
+
'RouteConfig',
|
| 254 |
+
]
|
services/credit_service/__init__.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Credit Service - Credit validation middleware for API Gateway
|
| 3 |
+
|
| 4 |
+
Provides plug-and-play credit management with:
|
| 5 |
+
- Per-route cost configuration
|
| 6 |
+
- Credit reservation and validation
|
| 7 |
+
- Request middleware for credit checks
|
| 8 |
+
- Automatic refund on errors
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
# In app.py startup
|
| 12 |
+
from services.credit_service import register_credit_service
|
| 13 |
+
|
| 14 |
+
register_credit_service(
|
| 15 |
+
route_costs={
|
| 16 |
+
"/gemini/generate-animation-prompt": 1,
|
| 17 |
+
"/gemini/edit-image": 1,
|
| 18 |
+
"/gemini/generate-video": 10,
|
| 19 |
+
"/gemini/generate-text": 1,
|
| 20 |
+
"/gemini/analyze-image": 1,
|
| 21 |
+
}
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# In routers
|
| 25 |
+
from fastapi import Request
|
| 26 |
+
|
| 27 |
+
@router.post("/api/endpoint")
|
| 28 |
+
async def endpoint(request: Request):
|
| 29 |
+
user = request.state.user # From AuthMiddleware
|
| 30 |
+
credits_reserved = request.state.credits_reserved # From CreditMiddleware
|
| 31 |
+
return {"credits_remaining": user.credits}
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
from services.credit_service.config import CreditServiceConfig
|
| 35 |
+
from services.credit_service.middleware import CreditMiddleware
|
| 36 |
+
from services.credit_service.credit_manager import (
|
| 37 |
+
reserve_credit,
|
| 38 |
+
confirm_credit,
|
| 39 |
+
refund_credit,
|
| 40 |
+
handle_job_completion,
|
| 41 |
+
is_refundable_error,
|
| 42 |
+
REFUNDABLE_ERROR_PATTERNS,
|
| 43 |
+
NON_REFUNDABLE_ERROR_PATTERNS,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def register_credit_service(
|
| 48 |
+
route_costs: dict = None,
|
| 49 |
+
) -> None:
|
| 50 |
+
"""
|
| 51 |
+
Register the credit service with application configuration.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
route_costs: Dictionary mapping route paths to credit costs
|
| 55 |
+
Example: {"/gemini/generate-video": 10, "/gemini/edit-image": 1}
|
| 56 |
+
"""
|
| 57 |
+
CreditServiceConfig.register(
|
| 58 |
+
route_costs=route_costs or {},
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
__all__ = [
|
| 63 |
+
# Registration
|
| 64 |
+
'register_credit_service',
|
| 65 |
+
'CreditServiceConfig',
|
| 66 |
+
'CreditMiddleware',
|
| 67 |
+
|
| 68 |
+
# Credit Management
|
| 69 |
+
'reserve_credit',
|
| 70 |
+
'confirm_credit',
|
| 71 |
+
'refund_credit',
|
| 72 |
+
'handle_job_completion',
|
| 73 |
+
'is_refundable_error',
|
| 74 |
+
'REFUNDABLE_ERROR_PATTERNS',
|
| 75 |
+
'NON_REFUNDABLE_ERROR_PATTERNS',
|
| 76 |
+
]
|
services/credit_service/config.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Credit Service Configuration
|
| 3 |
+
|
| 4 |
+
Manages credit cost configuration for API routes.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Dict
|
| 9 |
+
from services.base_service import BaseService
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CreditServiceConfig(BaseService):
|
| 15 |
+
"""
|
| 16 |
+
Configuration for the credit service.
|
| 17 |
+
|
| 18 |
+
Controls which routes require credits and how much they cost.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
SERVICE_NAME = "credit_service"
|
| 22 |
+
|
| 23 |
+
# Route cost configuration
|
| 24 |
+
_route_costs: Dict[str, int] = {}
|
| 25 |
+
|
| 26 |
+
@classmethod
|
| 27 |
+
def register(
|
| 28 |
+
cls,
|
| 29 |
+
route_costs: Dict[str, int] = None,
|
| 30 |
+
) -> None:
|
| 31 |
+
"""
|
| 32 |
+
Register credit service configuration.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
route_costs: Dictionary mapping route paths to credit costs
|
| 36 |
+
Example: {"/gemini/generate-video": 10, "/gemini/edit-image": 1}
|
| 37 |
+
|
| 38 |
+
Raises:
|
| 39 |
+
RuntimeError: If service is already registered
|
| 40 |
+
"""
|
| 41 |
+
if cls._registered:
|
| 42 |
+
raise RuntimeError(f"{cls.SERVICE_NAME} is already registered")
|
| 43 |
+
|
| 44 |
+
# Store route costs
|
| 45 |
+
cls._route_costs = route_costs or {}
|
| 46 |
+
|
| 47 |
+
cls._registered = True
|
| 48 |
+
|
| 49 |
+
logger.info(f"✅ {cls.SERVICE_NAME} registered successfully")
|
| 50 |
+
logger.info(f" Routes with credit costs: {len(cls._route_costs)}")
|
| 51 |
+
for route, cost in cls._route_costs.items():
|
| 52 |
+
logger.info(f" {route}: {cost} credits")
|
| 53 |
+
|
| 54 |
+
@classmethod
|
| 55 |
+
def get_middleware(cls):
|
| 56 |
+
"""Return CreditMiddleware instance."""
|
| 57 |
+
from services.credit_service.middleware import CreditMiddleware
|
| 58 |
+
return CreditMiddleware
|
| 59 |
+
|
| 60 |
+
@classmethod
|
| 61 |
+
def get_cost(cls, path: str) -> int:
|
| 62 |
+
"""
|
| 63 |
+
Get the credit cost for a given path.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
path: URL path to check
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
Credit cost (0 if route doesn't require credits)
|
| 70 |
+
"""
|
| 71 |
+
cls.assert_registered()
|
| 72 |
+
return cls._route_costs.get(path, 0)
|
| 73 |
+
|
| 74 |
+
@classmethod
|
| 75 |
+
def requires_credits(cls, path: str) -> bool:
|
| 76 |
+
"""Check if a URL path requires credits."""
|
| 77 |
+
cls.assert_registered()
|
| 78 |
+
return cls.get_cost(path) > 0
|
| 79 |
+
|
| 80 |
+
@classmethod
|
| 81 |
+
def get_all_costs(cls) -> Dict[str, int]:
|
| 82 |
+
"""Get all route costs."""
|
| 83 |
+
cls.assert_registered()
|
| 84 |
+
return cls._route_costs.copy()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
__all__ = ['CreditServiceConfig']
|
services/credit_service/credit_manager.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Credit Service - Manages credit reservation, confirmation, and refunding.
|
| 3 |
+
|
| 4 |
+
Implements the Credit Reservation Pattern:
|
| 5 |
+
1. Reserve credits when job is created (deduct from user, track in job)
|
| 6 |
+
2. Confirm credits only on successful completion
|
| 7 |
+
3. Refund credits on refundable errors (server-side issues)
|
| 8 |
+
4. Keep credits on non-refundable errors (user-caused issues)
|
| 9 |
+
"""
|
| 10 |
+
import logging
|
| 11 |
+
from typing import Optional
|
| 12 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 13 |
+
from sqlalchemy import select
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# =============================================================================
|
| 19 |
+
# Error Categories for Refund Decisions
|
| 20 |
+
# =============================================================================
|
| 21 |
+
|
| 22 |
+
# Refundable errors - User gets credits back (server/API issues)
|
| 23 |
+
REFUNDABLE_ERROR_PATTERNS = [
|
| 24 |
+
"API_KEY_INVALID",
|
| 25 |
+
"QUOTA_EXCEEDED",
|
| 26 |
+
"INTERNAL_ERROR",
|
| 27 |
+
"CONNECTION_FAILED",
|
| 28 |
+
"SERVER_SHUTDOWN",
|
| 29 |
+
"TIMEOUT",
|
| 30 |
+
"Server Authentication Error",
|
| 31 |
+
"Network error",
|
| 32 |
+
"Connection refused",
|
| 33 |
+
"Connection reset",
|
| 34 |
+
"Service unavailable",
|
| 35 |
+
"503",
|
| 36 |
+
"500",
|
| 37 |
+
"429", # Rate limit (our quota, not user's fault)
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
# Non-refundable error patterns - User's input/content issue
|
| 41 |
+
NON_REFUNDABLE_ERROR_PATTERNS = [
|
| 42 |
+
"safety",
|
| 43 |
+
"blocked",
|
| 44 |
+
"SAFETY_FILTER",
|
| 45 |
+
"INVALID_INPUT",
|
| 46 |
+
"Invalid image",
|
| 47 |
+
"Bad request",
|
| 48 |
+
"400",
|
| 49 |
+
"cancelled",
|
| 50 |
+
"User cancelled",
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def is_refundable_error(error_message: Optional[str]) -> bool:
|
| 55 |
+
"""
|
| 56 |
+
Determine if an error should result in a credit refund.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
error_message: The error message from the failed job
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
True if the error is refundable (server/API issue)
|
| 63 |
+
False if non-refundable (user's fault) or no error message
|
| 64 |
+
"""
|
| 65 |
+
if not error_message:
|
| 66 |
+
return False
|
| 67 |
+
|
| 68 |
+
error_lower = error_message.lower()
|
| 69 |
+
|
| 70 |
+
# Check for REFUNDABLE patterns FIRST (specific server errors take precedence)
|
| 71 |
+
# This ensures API_KEY_INVALID is caught before generic "400" matcher
|
| 72 |
+
for pattern in REFUNDABLE_ERROR_PATTERNS:
|
| 73 |
+
if pattern.lower() in error_lower:
|
| 74 |
+
logger.debug(f"Error matched refundable pattern '{pattern}': {error_message[:100]}")
|
| 75 |
+
return True
|
| 76 |
+
|
| 77 |
+
# Check for non-refundable patterns (user-caused issues)
|
| 78 |
+
for pattern in NON_REFUNDABLE_ERROR_PATTERNS:
|
| 79 |
+
if pattern.lower() in error_lower:
|
| 80 |
+
logger.debug(f"Error matched non-refundable pattern '{pattern}': {error_message[:100]}")
|
| 81 |
+
return False
|
| 82 |
+
|
| 83 |
+
# Default: Max retries exceeded is refundable (we consumed API resources trying)
|
| 84 |
+
if "max retries" in error_lower:
|
| 85 |
+
return True
|
| 86 |
+
|
| 87 |
+
# Default: Unknown errors are NOT refundable to prevent abuse
|
| 88 |
+
# If it's an unknown error, it's more likely user-caused
|
| 89 |
+
logger.debug(f"Unknown error (not refundable): {error_message[:100]}")
|
| 90 |
+
return False
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
async def reserve_credit(session: AsyncSession, user, amount: int = 1) -> bool:
|
| 94 |
+
"""
|
| 95 |
+
Reserve credits for a job (deduct from user's balance).
|
| 96 |
+
|
| 97 |
+
The credits are deducted but tracked in the job's credits_reserved field.
|
| 98 |
+
If the job fails with a refundable error, they can be restored.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
session: Database session
|
| 102 |
+
user: User model instance
|
| 103 |
+
amount: Number of credits to reserve (default: 1)
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
True if credits were successfully reserved
|
| 107 |
+
False if user has insufficient credits
|
| 108 |
+
"""
|
| 109 |
+
if user.credits < amount:
|
| 110 |
+
logger.warning(f"User {user.user_id} has insufficient credits ({user.credits}) to reserve {amount}")
|
| 111 |
+
return False
|
| 112 |
+
|
| 113 |
+
user.credits -= amount
|
| 114 |
+
logger.info(f"Reserved {amount} credit(s) for user {user.user_id}. Remaining: {user.credits}")
|
| 115 |
+
# Note: Don't commit here - let caller handle transaction
|
| 116 |
+
return True
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
async def confirm_credit(session: AsyncSession, job) -> None:
|
| 120 |
+
"""
|
| 121 |
+
Confirm that credits were legitimately used for a completed job.
|
| 122 |
+
|
| 123 |
+
This is called when a job completes successfully. The credits stay
|
| 124 |
+
deducted (they were already deducted during reservation).
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
session: Database session
|
| 128 |
+
job: GeminiJob model instance
|
| 129 |
+
"""
|
| 130 |
+
if job.credits_reserved > 0:
|
| 131 |
+
# Credits were used - clear the reservation tracking
|
| 132 |
+
credits_used = job.credits_reserved
|
| 133 |
+
job.credits_reserved = 0
|
| 134 |
+
logger.info(f"Confirmed {credits_used} credit(s) used for job {job.job_id}")
|
| 135 |
+
# Note: Don't commit here - let caller handle transaction
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
async def refund_credit(session: AsyncSession, job, reason: str) -> bool:
|
| 139 |
+
"""
|
| 140 |
+
Refund reserved credits back to the user.
|
| 141 |
+
|
| 142 |
+
Called when a job fails due to a refundable error (server-side issue).
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
session: Database session
|
| 146 |
+
job: GeminiJob model instance
|
| 147 |
+
reason: Reason for the refund (for logging)
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
True if credits were refunded
|
| 151 |
+
False if no credits to refund or already refunded
|
| 152 |
+
"""
|
| 153 |
+
if job.credits_reserved <= 0:
|
| 154 |
+
logger.debug(f"Job {job.job_id} has no credits to refund")
|
| 155 |
+
return False
|
| 156 |
+
|
| 157 |
+
if job.credits_refunded:
|
| 158 |
+
logger.warning(f"Job {job.job_id} was already refunded")
|
| 159 |
+
return False
|
| 160 |
+
|
| 161 |
+
# Get the user to restore credits
|
| 162 |
+
from core.models import User
|
| 163 |
+
|
| 164 |
+
result = await session.execute(
|
| 165 |
+
select(User).where(User.id == job.user_id)
|
| 166 |
+
)
|
| 167 |
+
user = result.scalar_one_or_none()
|
| 168 |
+
|
| 169 |
+
if not user:
|
| 170 |
+
logger.error(f"Cannot refund job {job.job_id}: User {job.user_id} not found")
|
| 171 |
+
return False
|
| 172 |
+
|
| 173 |
+
# Restore credits
|
| 174 |
+
credits_to_refund = job.credits_reserved
|
| 175 |
+
user.credits += credits_to_refund
|
| 176 |
+
job.credits_reserved = 0
|
| 177 |
+
job.credits_refunded = True
|
| 178 |
+
|
| 179 |
+
logger.info(
|
| 180 |
+
f"Refunded {credits_to_refund} credit(s) to user {user.user_id} for job {job.job_id}. "
|
| 181 |
+
f"Reason: {reason[:100]}. New balance: {user.credits}"
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Note: Don't commit here - let caller handle transaction
|
| 185 |
+
return True
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
async def handle_job_completion(session: AsyncSession, job) -> None:
|
| 189 |
+
"""
|
| 190 |
+
Handle credit finalization when a job completes or fails.
|
| 191 |
+
|
| 192 |
+
This is the main entry point called by the job worker.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
session: Database session
|
| 196 |
+
job: GeminiJob model instance with final status
|
| 197 |
+
"""
|
| 198 |
+
if job.status == "completed":
|
| 199 |
+
# Success - confirm credits were used
|
| 200 |
+
await confirm_credit(session, job)
|
| 201 |
+
|
| 202 |
+
elif job.status == "failed":
|
| 203 |
+
# Failure - check if refundable
|
| 204 |
+
if is_refundable_error(job.error_message):
|
| 205 |
+
await refund_credit(session, job, job.error_message or "Unknown error")
|
| 206 |
+
else:
|
| 207 |
+
# Non-refundable - confirm credits were used (user's fault)
|
| 208 |
+
await confirm_credit(session, job)
|
| 209 |
+
logger.info(f"Job {job.job_id} failed with non-refundable error, credits kept")
|
| 210 |
+
|
| 211 |
+
elif job.status == "cancelled":
|
| 212 |
+
# Cancelled jobs get refunds only if they were never started
|
| 213 |
+
if job.started_at is None:
|
| 214 |
+
await refund_credit(session, job, "Job cancelled before processing")
|
| 215 |
+
else:
|
| 216 |
+
# Was processing - keep credits (API may have been consumed)
|
| 217 |
+
await confirm_credit(session, job)
|
| 218 |
+
logger.info(f"Job {job.job_id} cancelled during processing, credits kept")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
async def refund_orphaned_jobs(session: AsyncSession) -> int:
|
| 222 |
+
"""
|
| 223 |
+
Refund credits for jobs that were abandoned due to server shutdown.
|
| 224 |
+
|
| 225 |
+
Called during graceful shutdown to ensure no credits are lost.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
session: Database session
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
Number of jobs that were refunded
|
| 232 |
+
"""
|
| 233 |
+
from core.models import GeminiJob
|
| 234 |
+
|
| 235 |
+
# Find jobs that are still processing with reserved credits
|
| 236 |
+
result = await session.execute(
|
| 237 |
+
select(GeminiJob).where(
|
| 238 |
+
GeminiJob.status == "processing",
|
| 239 |
+
GeminiJob.credits_reserved > 0,
|
| 240 |
+
GeminiJob.credits_refunded == False
|
| 241 |
+
)
|
| 242 |
+
)
|
| 243 |
+
orphaned_jobs = result.scalars().all()
|
| 244 |
+
|
| 245 |
+
refund_count = 0
|
| 246 |
+
for job in orphaned_jobs:
|
| 247 |
+
if await refund_credit(session, job, "SERVER_SHUTDOWN: Job orphaned during server shutdown"):
|
| 248 |
+
# Mark job as failed
|
| 249 |
+
job.status = "failed"
|
| 250 |
+
job.error_message = "Server shutdown during processing. Credits refunded."
|
| 251 |
+
refund_count += 1
|
| 252 |
+
|
| 253 |
+
if refund_count > 0:
|
| 254 |
+
await session.commit()
|
| 255 |
+
logger.info(f"Refunded {refund_count} orphaned job(s) during shutdown")
|
| 256 |
+
|
| 257 |
+
return refund_count
|
services/credit_service/middleware.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Credit Middleware - Request credit validation layer
|
| 3 |
+
|
| 4 |
+
Intercepts requests to validate and reserve credits for paid endpoints.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from fastapi import Request, HTTPException, status
|
| 10 |
+
from fastapi.responses import JSONResponse
|
| 11 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 12 |
+
|
| 13 |
+
from core.database import async_session_maker
|
| 14 |
+
from services.credit_service.config import CreditServiceConfig
|
| 15 |
+
from services.base_service.middleware_chain import (
|
| 16 |
+
BaseServiceMiddleware,
|
| 17 |
+
get_request_context,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CreditMiddleware(BaseServiceMiddleware):
|
| 24 |
+
"""
|
| 25 |
+
Credit validation middleware for request validation.
|
| 26 |
+
|
| 27 |
+
Flow:
|
| 28 |
+
1. Check if route requires credits based on URL
|
| 29 |
+
2. Get authenticated user from request.state (set by AuthMiddleware)
|
| 30 |
+
3. Check if user has sufficient credits
|
| 31 |
+
4. Reserve credits (deduct from balance)
|
| 32 |
+
5. Attach credit info to request.state
|
| 33 |
+
6. Continue to next middleware/route
|
| 34 |
+
|
| 35 |
+
Credits are reserved but tracked - they can be refunded if job fails
|
| 36 |
+
with a server-side error.
|
| 37 |
+
|
| 38 |
+
NOTE: This middleware MUST run AFTER AuthMiddleware since it needs
|
| 39 |
+
the authenticated user from request.state.user
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
SERVICE_NAME = "credit"
|
| 43 |
+
|
| 44 |
+
async def dispatch(self, request: Request, call_next):
|
| 45 |
+
"""Process request through credit middleware."""
|
| 46 |
+
# Skip OPTIONS requests (CORS preflight)
|
| 47 |
+
if request.method == "OPTIONS":
|
| 48 |
+
return await call_next(request)
|
| 49 |
+
|
| 50 |
+
# Initialize request context
|
| 51 |
+
ctx = get_request_context(request)
|
| 52 |
+
|
| 53 |
+
# Get path from request
|
| 54 |
+
path = request.url.path
|
| 55 |
+
|
| 56 |
+
# Check if route requires credits
|
| 57 |
+
credit_cost = CreditServiceConfig.get_cost(path)
|
| 58 |
+
|
| 59 |
+
if credit_cost == 0:
|
| 60 |
+
# Route doesn't require credits, skip
|
| 61 |
+
self.log_request(request, f"Route doesn't require credits")
|
| 62 |
+
ctx.set_credits(0, 0)
|
| 63 |
+
response = await call_next(request)
|
| 64 |
+
return response
|
| 65 |
+
|
| 66 |
+
# Route requires credits - user MUST be authenticated
|
| 67 |
+
# (AuthMiddleware should have already validated this)
|
| 68 |
+
user = request.state.user if hasattr(request.state, 'user') else None
|
| 69 |
+
|
| 70 |
+
if not user:
|
| 71 |
+
# This shouldn't happen if auth is configured correctly
|
| 72 |
+
self.log_error(request, "Credit-required route accessed without authentication")
|
| 73 |
+
return JSONResponse(
|
| 74 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 75 |
+
content={"detail": "Authentication required for this endpoint"},
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Check if user has sufficient credits
|
| 79 |
+
if user.credits < credit_cost:
|
| 80 |
+
self.log_request(
|
| 81 |
+
request,
|
| 82 |
+
f"Insufficient credits: has {user.credits}, needs {credit_cost}"
|
| 83 |
+
)
|
| 84 |
+
return JSONResponse(
|
| 85 |
+
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
| 86 |
+
content={
|
| 87 |
+
"detail": f"Insufficient credits. This operation requires {credit_cost} credits. You have {user.credits}.",
|
| 88 |
+
"credits_required": credit_cost,
|
| 89 |
+
"credits_available": user.credits,
|
| 90 |
+
},
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# Reserve credits (deduct from user balance)
|
| 94 |
+
async with async_session_maker() as db:
|
| 95 |
+
try:
|
| 96 |
+
# Deduct credits
|
| 97 |
+
user.credits -= credit_cost
|
| 98 |
+
user.last_used_at = datetime.utcnow()
|
| 99 |
+
|
| 100 |
+
# Update in database
|
| 101 |
+
db.add(user)
|
| 102 |
+
await db.commit()
|
| 103 |
+
await db.refresh(user)
|
| 104 |
+
|
| 105 |
+
# Attach credit info to request state
|
| 106 |
+
ctx.set_credits(credit_cost, user.credits)
|
| 107 |
+
request.state.credits_reserved = credit_cost
|
| 108 |
+
request.state.credits_remaining = user.credits
|
| 109 |
+
|
| 110 |
+
self.log_request(
|
| 111 |
+
request,
|
| 112 |
+
f"Reserved {credit_cost} credits for {user.user_id}, remaining: {user.credits}"
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# Continue to next middleware/route
|
| 116 |
+
response = await call_next(request)
|
| 117 |
+
return response
|
| 118 |
+
|
| 119 |
+
except Exception as e:
|
| 120 |
+
await db.rollback()
|
| 121 |
+
self.log_error(request, f"Error reserving credits: {e}")
|
| 122 |
+
return JSONResponse(
|
| 123 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 124 |
+
content={"detail": "Failed to reserve credits. Please try again."},
|
| 125 |
+
)
|
| 126 |
+
finally:
|
| 127 |
+
await db.close()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
__all__ = ['CreditMiddleware']
|
tests/test_base_service.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unit tests for base service infrastructure.
|
| 3 |
+
|
| 4 |
+
Tests:
|
| 5 |
+
- BaseService registration and configuration
|
| 6 |
+
- ServiceRegistry service management
|
| 7 |
+
- ServiceConfig operations
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import pytest
|
| 11 |
+
from services.base_service import BaseService, ServiceConfig, ServiceRegistry
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TestServiceConfig:
|
| 15 |
+
"""Test ServiceConfig container."""
|
| 16 |
+
|
| 17 |
+
def test_initialization(self):
|
| 18 |
+
"""Test config initialization with kwargs."""
|
| 19 |
+
config = ServiceConfig(key1="value1", key2=42)
|
| 20 |
+
|
| 21 |
+
assert config.get("key1") == "value1"
|
| 22 |
+
assert config.get("key2") == 42
|
| 23 |
+
|
| 24 |
+
def test_get_with_default(self):
|
| 25 |
+
"""Test get with default value."""
|
| 26 |
+
config = ServiceConfig(key1="value1")
|
| 27 |
+
|
| 28 |
+
assert config.get("key1") == "value1"
|
| 29 |
+
assert config.get("missing", "default") == "default"
|
| 30 |
+
assert config.get("missing") is None
|
| 31 |
+
|
| 32 |
+
def test_set_value(self):
|
| 33 |
+
"""Test setting values."""
|
| 34 |
+
config = ServiceConfig()
|
| 35 |
+
|
| 36 |
+
config.set("key1", "value1")
|
| 37 |
+
assert config.get("key1") == "value1"
|
| 38 |
+
|
| 39 |
+
def test_dictionary_access(self):
|
| 40 |
+
"""Test dictionary-style access."""
|
| 41 |
+
config = ServiceConfig(key1="value1")
|
| 42 |
+
|
| 43 |
+
assert config["key1"] == "value1"
|
| 44 |
+
|
| 45 |
+
config["key2"] = "value2"
|
| 46 |
+
assert config["key2"] == "value2"
|
| 47 |
+
|
| 48 |
+
def test_contains(self):
|
| 49 |
+
"""Test 'in' operator."""
|
| 50 |
+
config = ServiceConfig(key1="value1")
|
| 51 |
+
|
| 52 |
+
assert "key1" in config
|
| 53 |
+
assert "missing" not in config
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TestBaseService:
|
| 57 |
+
"""Test BaseService abstract class."""
|
| 58 |
+
|
| 59 |
+
def setup_method(self):
|
| 60 |
+
"""Reset service state before each test."""
|
| 61 |
+
# Create concrete test service
|
| 62 |
+
class TestService(BaseService):
|
| 63 |
+
SERVICE_NAME = "test_service"
|
| 64 |
+
|
| 65 |
+
@classmethod
|
| 66 |
+
def register(cls, **config):
|
| 67 |
+
super().register(**config)
|
| 68 |
+
|
| 69 |
+
self.TestService = TestService
|
| 70 |
+
|
| 71 |
+
# Reset state
|
| 72 |
+
self.TestService._registered = False
|
| 73 |
+
self.TestService._config = None
|
| 74 |
+
|
| 75 |
+
def test_registration(self):
|
| 76 |
+
"""Test service registration."""
|
| 77 |
+
assert not self.TestService.is_registered()
|
| 78 |
+
|
| 79 |
+
self.TestService.register(key1="value1", key2=42)
|
| 80 |
+
|
| 81 |
+
assert self.TestService.is_registered()
|
| 82 |
+
assert self.TestService.get_config().get("key1") == "value1"
|
| 83 |
+
assert self.TestService.get_config().get("key2") == 42
|
| 84 |
+
|
| 85 |
+
def test_double_registration_fails(self):
|
| 86 |
+
"""Test that double registration raises error."""
|
| 87 |
+
self.TestService.register(key1="value1")
|
| 88 |
+
|
| 89 |
+
with pytest.raises(RuntimeError, match="already registered"):
|
| 90 |
+
self.TestService.register(key2="value2")
|
| 91 |
+
|
| 92 |
+
def test_assert_registered(self):
|
| 93 |
+
"""Test assert_registered raises when not registered."""
|
| 94 |
+
with pytest.raises(RuntimeError, match="not registered"):
|
| 95 |
+
self.TestService.assert_registered()
|
| 96 |
+
|
| 97 |
+
self.TestService.register()
|
| 98 |
+
self.TestService.assert_registered() # Should not raise
|
| 99 |
+
|
| 100 |
+
def test_get_config_before_registration(self):
|
| 101 |
+
"""Test get_config raises before registration."""
|
| 102 |
+
with pytest.raises(RuntimeError, match="not registered"):
|
| 103 |
+
self.TestService.get_config()
|
| 104 |
+
|
| 105 |
+
def test_get_middleware_default(self):
|
| 106 |
+
"""Test default get_middleware returns None."""
|
| 107 |
+
assert self.TestService.get_middleware() is None
|
| 108 |
+
|
| 109 |
+
def test_on_shutdown_default(self):
|
| 110 |
+
"""Test default on_shutdown does nothing."""
|
| 111 |
+
self.TestService.on_shutdown() # Should not raise
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class TestServiceRegistry:
|
| 115 |
+
"""Test ServiceRegistry global registry."""
|
| 116 |
+
|
| 117 |
+
def setup_method(self):
|
| 118 |
+
"""Reset registry before each test."""
|
| 119 |
+
ServiceRegistry._services = {}
|
| 120 |
+
|
| 121 |
+
def test_register_service(self):
|
| 122 |
+
"""Test registering a service."""
|
| 123 |
+
class TestService(BaseService):
|
| 124 |
+
SERVICE_NAME = "test_service"
|
| 125 |
+
|
| 126 |
+
@classmethod
|
| 127 |
+
def register(cls, **config):
|
| 128 |
+
super().register(**config)
|
| 129 |
+
|
| 130 |
+
ServiceRegistry.register_service(TestService)
|
| 131 |
+
|
| 132 |
+
assert ServiceRegistry.get_service("test_service") == TestService
|
| 133 |
+
|
| 134 |
+
def test_register_multiple_services(self):
|
| 135 |
+
"""Test registering multiple services."""
|
| 136 |
+
class Service1(BaseService):
|
| 137 |
+
SERVICE_NAME = "service1"
|
| 138 |
+
|
| 139 |
+
@classmethod
|
| 140 |
+
def register(cls, **config):
|
| 141 |
+
super().register(**config)
|
| 142 |
+
|
| 143 |
+
class Service2(BaseService):
|
| 144 |
+
SERVICE_NAME = "service2"
|
| 145 |
+
|
| 146 |
+
@classmethod
|
| 147 |
+
def register(cls, **config):
|
| 148 |
+
super().register(**config)
|
| 149 |
+
|
| 150 |
+
ServiceRegistry.register_service(Service1)
|
| 151 |
+
ServiceRegistry.register_service(Service2)
|
| 152 |
+
|
| 153 |
+
assert len(ServiceRegistry.get_all_services()) == 2
|
| 154 |
+
assert ServiceRegistry.get_service("service1") == Service1
|
| 155 |
+
assert ServiceRegistry.get_service("service2") == Service2
|
| 156 |
+
|
| 157 |
+
def test_get_nonexistent_service(self):
|
| 158 |
+
"""Test getting service that doesn't exist."""
|
| 159 |
+
assert ServiceRegistry.get_service("nonexistent") is None
|
| 160 |
+
|
| 161 |
+
def test_overwrite_service(self):
|
| 162 |
+
"""Test registering service with same name overwrites."""
|
| 163 |
+
class Service1(BaseService):
|
| 164 |
+
SERVICE_NAME = "test"
|
| 165 |
+
version = 1
|
| 166 |
+
|
| 167 |
+
@classmethod
|
| 168 |
+
def register(cls, **config):
|
| 169 |
+
super().register(**config)
|
| 170 |
+
|
| 171 |
+
class Service2(BaseService):
|
| 172 |
+
SERVICE_NAME = "test"
|
| 173 |
+
version = 2
|
| 174 |
+
|
| 175 |
+
@classmethod
|
| 176 |
+
def register(cls, **config):
|
| 177 |
+
super().register(**config)
|
| 178 |
+
|
| 179 |
+
ServiceRegistry.register_service(Service1)
|
| 180 |
+
ServiceRegistry.register_service(Service2)
|
| 181 |
+
|
| 182 |
+
service = ServiceRegistry.get_service("test")
|
| 183 |
+
assert service.version == 2
|
| 184 |
+
|
| 185 |
+
def test_get_all_middleware(self):
|
| 186 |
+
"""Test getting middleware from all services."""
|
| 187 |
+
class MockMiddleware:
|
| 188 |
+
pass
|
| 189 |
+
|
| 190 |
+
class ServiceWithMiddleware(BaseService):
|
| 191 |
+
SERVICE_NAME = "with_middleware"
|
| 192 |
+
|
| 193 |
+
@classmethod
|
| 194 |
+
def register(cls, **config):
|
| 195 |
+
super().register(**config)
|
| 196 |
+
|
| 197 |
+
@classmethod
|
| 198 |
+
def get_middleware(cls):
|
| 199 |
+
return MockMiddleware()
|
| 200 |
+
|
| 201 |
+
class ServiceWithoutMiddleware(BaseService):
|
| 202 |
+
SERVICE_NAME = "without_middleware"
|
| 203 |
+
|
| 204 |
+
@classmethod
|
| 205 |
+
def register(cls, **config):
|
| 206 |
+
super().register(**config)
|
| 207 |
+
|
| 208 |
+
# Register services
|
| 209 |
+
ServiceWithMiddleware.register()
|
| 210 |
+
ServiceWithoutMiddleware.register()
|
| 211 |
+
|
| 212 |
+
ServiceRegistry.register_service(ServiceWithMiddleware)
|
| 213 |
+
ServiceRegistry.register_service(ServiceWithoutMiddleware)
|
| 214 |
+
|
| 215 |
+
middleware_list = ServiceRegistry.get_all_middleware()
|
| 216 |
+
|
| 217 |
+
assert len(middleware_list) == 1
|
| 218 |
+
assert isinstance(middleware_list[0], MockMiddleware)
|
| 219 |
+
|
| 220 |
+
def test_shutdown_all(self):
|
| 221 |
+
"""Test calling shutdown on all services."""
|
| 222 |
+
shutdown_called = []
|
| 223 |
+
|
| 224 |
+
class Service1(BaseService):
|
| 225 |
+
SERVICE_NAME = "service1"
|
| 226 |
+
|
| 227 |
+
@classmethod
|
| 228 |
+
def register(cls, **config):
|
| 229 |
+
super().register(**config)
|
| 230 |
+
|
| 231 |
+
@classmethod
|
| 232 |
+
def on_shutdown(cls):
|
| 233 |
+
shutdown_called.append("service1")
|
| 234 |
+
|
| 235 |
+
class Service2(BaseService):
|
| 236 |
+
SERVICE_NAME = "service2"
|
| 237 |
+
|
| 238 |
+
@classmethod
|
| 239 |
+
def register(cls, **config):
|
| 240 |
+
super().register(**config)
|
| 241 |
+
|
| 242 |
+
@classmethod
|
| 243 |
+
def on_shutdown(cls):
|
| 244 |
+
shutdown_called.append("service2")
|
| 245 |
+
|
| 246 |
+
ServiceRegistry.register_service(Service1)
|
| 247 |
+
ServiceRegistry.register_service(Service2)
|
| 248 |
+
|
| 249 |
+
ServiceRegistry.shutdown_all()
|
| 250 |
+
|
| 251 |
+
assert "service1" in shutdown_called
|
| 252 |
+
assert "service2" in shutdown_called
|
tests/test_route_matcher.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unit tests for route matcher.
|
| 3 |
+
|
| 4 |
+
Tests:
|
| 5 |
+
- Exact path matching
|
| 6 |
+
- Prefix pattern matching
|
| 7 |
+
- Glob pattern matching
|
| 8 |
+
- Regex pattern matching
|
| 9 |
+
- RouteConfig precedence logic
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import pytest
|
| 13 |
+
from services.base_service.route_matcher import RouteMatcher, RouteConfig
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestRouteMatcher:
|
| 17 |
+
"""Test RouteMatcher pattern matching."""
|
| 18 |
+
|
| 19 |
+
def test_exact_match(self):
|
| 20 |
+
"""Test exact path matching."""
|
| 21 |
+
matcher = RouteMatcher(["/api/users", "/api/posts"])
|
| 22 |
+
|
| 23 |
+
assert matcher.matches("/api/users")
|
| 24 |
+
assert matcher.matches("/api/posts")
|
| 25 |
+
assert not matcher.matches("/api/comments")
|
| 26 |
+
assert not matcher.matches("/api/users/123")
|
| 27 |
+
|
| 28 |
+
def test_prefix_match(self):
|
| 29 |
+
"""Test prefix wildcard matching."""
|
| 30 |
+
matcher = RouteMatcher(["/api/*", "/admin/*"])
|
| 31 |
+
|
| 32 |
+
assert matcher.matches("/api/users")
|
| 33 |
+
assert matcher.matches("/api/posts")
|
| 34 |
+
assert matcher.matches("/admin/dashboard")
|
| 35 |
+
assert not matcher.matches("/public/page")
|
| 36 |
+
|
| 37 |
+
def test_complex_glob_match(self):
|
| 38 |
+
"""Test complex glob patterns."""
|
| 39 |
+
matcher = RouteMatcher(["/api/users/*/posts", "/api/**/comments"])
|
| 40 |
+
|
| 41 |
+
assert matcher.matches("/api/users/123/posts")
|
| 42 |
+
assert matcher.matches("/api/users/456/posts")
|
| 43 |
+
assert matcher.matches("/api/v1/users/comments")
|
| 44 |
+
assert matcher.matches("/api/deep/nested/path/comments")
|
| 45 |
+
assert not matcher.matches("/api/users/posts")
|
| 46 |
+
|
| 47 |
+
def test_regex_match(self):
|
| 48 |
+
"""Test regex pattern matching."""
|
| 49 |
+
matcher = RouteMatcher(["^/api/v[0-9]+/.*$", "^/users/[0-9]+$"])
|
| 50 |
+
|
| 51 |
+
assert matcher.matches("/api/v1/users")
|
| 52 |
+
assert matcher.matches("/api/v2/posts")
|
| 53 |
+
assert matcher.matches("/users/123")
|
| 54 |
+
assert not matcher.matches("/api/v/users")
|
| 55 |
+
assert not matcher.matches("/users/abc")
|
| 56 |
+
|
| 57 |
+
def test_query_parameters_stripped(self):
|
| 58 |
+
"""Test that query parameters are ignored."""
|
| 59 |
+
matcher = RouteMatcher(["/api/users"])
|
| 60 |
+
|
| 61 |
+
assert matcher.matches("/api/users?page=1")
|
| 62 |
+
assert matcher.matches("/api/users?page=1&limit=10")
|
| 63 |
+
|
| 64 |
+
def test_fragments_stripped(self):
|
| 65 |
+
"""Test that URL fragments are ignored."""
|
| 66 |
+
matcher = RouteMatcher(["/api/users"])
|
| 67 |
+
|
| 68 |
+
assert matcher.matches("/api/users#section")
|
| 69 |
+
|
| 70 |
+
def test_trailing_slash_normalized(self):
|
| 71 |
+
"""Test trailing slash normalization."""
|
| 72 |
+
matcher = RouteMatcher(["/api/users"])
|
| 73 |
+
|
| 74 |
+
assert matcher.matches("/api/users/")
|
| 75 |
+
|
| 76 |
+
# Root path keeps trailing slash
|
| 77 |
+
root_matcher = RouteMatcher(["/"])
|
| 78 |
+
assert root_matcher.matches("/")
|
| 79 |
+
|
| 80 |
+
def test_empty_patterns(self):
|
| 81 |
+
"""Test with empty pattern list."""
|
| 82 |
+
matcher = RouteMatcher([])
|
| 83 |
+
|
| 84 |
+
assert not matcher.matches("/any/path")
|
| 85 |
+
|
| 86 |
+
def test_get_matching_pattern(self):
|
| 87 |
+
"""Test getting the matched pattern."""
|
| 88 |
+
matcher = RouteMatcher([
|
| 89 |
+
"/api/users",
|
| 90 |
+
"/api/*",
|
| 91 |
+
"/admin/**"
|
| 92 |
+
])
|
| 93 |
+
|
| 94 |
+
assert matcher.get_matching_pattern("/api/users") == "/api/users"
|
| 95 |
+
assert matcher.get_matching_pattern("/api/posts") == "/api/*"
|
| 96 |
+
assert matcher.get_matching_pattern("/admin/deep/path") == "/admin/**"
|
| 97 |
+
assert matcher.get_matching_pattern("/public") is None
|
| 98 |
+
|
| 99 |
+
def test_mixed_patterns(self):
|
| 100 |
+
"""Test combination of all pattern types."""
|
| 101 |
+
matcher = RouteMatcher([
|
| 102 |
+
"/exact",
|
| 103 |
+
"/prefix/*",
|
| 104 |
+
"/glob/*/nested",
|
| 105 |
+
"^/regex/[0-9]+$"
|
| 106 |
+
])
|
| 107 |
+
|
| 108 |
+
assert matcher.matches("/exact")
|
| 109 |
+
assert matcher.matches("/prefix/anything")
|
| 110 |
+
assert matcher.matches("/glob/123/nested")
|
| 111 |
+
assert matcher.matches("/regex/456")
|
| 112 |
+
assert not matcher.matches("/other")
|
| 113 |
+
|
| 114 |
+
def test_invalid_regex_pattern(self):
|
| 115 |
+
"""Test that invalid regex is handled gracefully."""
|
| 116 |
+
# Should not raise, just log warning and skip pattern
|
| 117 |
+
matcher = RouteMatcher(["^[invalid(regex$"])
|
| 118 |
+
|
| 119 |
+
assert not matcher.matches("/anything")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class TestRouteConfig:
|
| 123 |
+
"""Test RouteConfig precedence logic."""
|
| 124 |
+
|
| 125 |
+
def test_required_routes(self):
|
| 126 |
+
"""Test required route checking."""
|
| 127 |
+
config = RouteConfig(
|
| 128 |
+
required=["/api/users", "/api/posts"],
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
assert config.is_required("/api/users")
|
| 132 |
+
assert config.is_required("/api/posts")
|
| 133 |
+
assert not config.is_required("/public")
|
| 134 |
+
|
| 135 |
+
def test_optional_routes(self):
|
| 136 |
+
"""Test optional route checking."""
|
| 137 |
+
config = RouteConfig(
|
| 138 |
+
optional=["/", "/home"],
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
assert config.is_optional("/")
|
| 142 |
+
assert config.is_optional("/home")
|
| 143 |
+
assert not config.is_optional("/api/users")
|
| 144 |
+
|
| 145 |
+
def test_public_routes(self):
|
| 146 |
+
"""Test public route checking."""
|
| 147 |
+
config = RouteConfig(
|
| 148 |
+
public=["/health", "/docs"],
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
assert config.is_public("/health")
|
| 152 |
+
assert config.is_public("/docs")
|
| 153 |
+
assert not config.is_public("/api/users")
|
| 154 |
+
|
| 155 |
+
def test_public_overrides_required(self):
|
| 156 |
+
"""Test that public takes precedence over required."""
|
| 157 |
+
config = RouteConfig(
|
| 158 |
+
required=["/api/*"],
|
| 159 |
+
public=["/api/health"],
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# /api/health is public, so not required
|
| 163 |
+
assert config.is_public("/api/health")
|
| 164 |
+
assert not config.is_required("/api/health")
|
| 165 |
+
|
| 166 |
+
# Other /api routes are required
|
| 167 |
+
assert config.is_required("/api/users")
|
| 168 |
+
assert not config.is_public("/api/users")
|
| 169 |
+
|
| 170 |
+
def test_public_overrides_optional(self):
|
| 171 |
+
"""Test that public takes precedence over optional."""
|
| 172 |
+
config = RouteConfig(
|
| 173 |
+
optional=["/api/*"],
|
| 174 |
+
public=["/api/health"],
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# /api/health is public, so not optional
|
| 178 |
+
assert config.is_public("/api/health")
|
| 179 |
+
assert not config.is_optional("/api/health")
|
| 180 |
+
|
| 181 |
+
# Other /api routes are optional
|
| 182 |
+
assert config.is_optional("/api/users")
|
| 183 |
+
|
| 184 |
+
def test_required_overrides_optional(self):
|
| 185 |
+
"""Test that required takes precedence over optional."""
|
| 186 |
+
config = RouteConfig(
|
| 187 |
+
required=["/api/users"],
|
| 188 |
+
optional=["/api/*"],
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# /api/users is required, so not optional
|
| 192 |
+
assert config.is_required("/api/users")
|
| 193 |
+
assert not config.is_optional("/api/users")
|
| 194 |
+
|
| 195 |
+
# Other /api routes are optional
|
| 196 |
+
assert config.is_optional("/api/posts")
|
| 197 |
+
|
| 198 |
+
def test_requires_service(self):
|
| 199 |
+
"""Test requires_service helper."""
|
| 200 |
+
config = RouteConfig(
|
| 201 |
+
required=["/api/users"],
|
| 202 |
+
optional=["/api/posts"],
|
| 203 |
+
public=["/health"],
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Service required
|
| 207 |
+
assert config.requires_service("/api/users")
|
| 208 |
+
|
| 209 |
+
# Service optional (still requires service)
|
| 210 |
+
assert config.requires_service("/api/posts")
|
| 211 |
+
|
| 212 |
+
# Public (does not require service)
|
| 213 |
+
assert not config.requires_service("/health")
|
| 214 |
+
|
| 215 |
+
def test_empty_config(self):
|
| 216 |
+
"""Test with empty configuration."""
|
| 217 |
+
config = RouteConfig()
|
| 218 |
+
|
| 219 |
+
assert not config.is_required("/any")
|
| 220 |
+
assert not config.is_optional("/any")
|
| 221 |
+
assert not config.is_public("/any")
|
| 222 |
+
assert not config.requires_service("/any")
|
| 223 |
+
|
| 224 |
+
def test_complex_precedence(self):
|
| 225 |
+
"""Test complex precedence scenarios."""
|
| 226 |
+
config = RouteConfig(
|
| 227 |
+
required=["/api/users"], # Specific required path
|
| 228 |
+
optional=["/api/*"], # Broader optional pattern
|
| 229 |
+
public=["/api/health"],
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# Public overrides everything
|
| 233 |
+
assert config.is_public("/api/health")
|
| 234 |
+
assert not config.is_required("/api/health")
|
| 235 |
+
assert not config.is_optional("/api/health")
|
| 236 |
+
|
| 237 |
+
# Required path
|
| 238 |
+
assert config.is_required("/api/users")
|
| 239 |
+
assert not config.is_optional("/api/users")
|
| 240 |
+
|
| 241 |
+
# Optional for other paths under /api
|
| 242 |
+
assert config.is_optional("/api/posts")
|
| 243 |
+
assert not config.is_required("/api/posts")
|