jebin2 commited on
Commit
bcc8074
·
1 Parent(s): 3fb20ed
app.py CHANGED
@@ -39,6 +39,52 @@ async def lifespan(app: FastAPI):
39
  register_db_service_config()
40
  logger.info("✅ DB Service configured")
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  # Check for RESET_DB environment variable
43
  if os.getenv("RESET_DB", "").lower() == "true":
44
  logger.warning(f"RESET_DB is set to true. Skipping download and clearing local database ({DB_FILENAME}).")
@@ -95,6 +141,14 @@ app.add_middleware(
95
  allow_headers=["*"],
96
  )
97
 
 
 
 
 
 
 
 
 
98
  # Include Routers
99
  app.include_router(general.router)
100
  app.include_router(auth.router)
 
39
  register_db_service_config()
40
  logger.info("✅ DB Service configured")
41
 
42
+ # Register Auth Service configuration
43
+ from services.auth_service import register_auth_service
44
+ register_auth_service(
45
+ required_urls=[
46
+ "/blink",
47
+ "/api/*", # All admin blink API endpoints
48
+ "/contact",
49
+ "/gemini/*",
50
+ "/credits/balance",
51
+ "/credits/history",
52
+ "/payments/create-order",
53
+ "/payments/verify/*",
54
+ ],
55
+ optional_urls=[
56
+ "/", # Home page works with or without auth
57
+ ],
58
+ public_urls=[
59
+ "/health",
60
+ "/auth/*",
61
+ "/payments/packages", # Public pricing info
62
+ "/payments/webhook/*", # Webhooks from payment gateway
63
+ "/docs",
64
+ "/openapi.json",
65
+ "/redoc",
66
+ ],
67
+ jwt_secret=os.getenv("JWT_SECRET"),
68
+ jwt_algorithm="HS256",
69
+ jwt_expiry_hours=24,
70
+ google_client_id=os.getenv("AUTH_SIGN_IN_GOOGLE_CLIENT_ID"),
71
+ admin_emails=os.getenv("ADMIN_EMAILS", "").split(",") if os.getenv("ADMIN_EMAILS") else [],
72
+ )
73
+ logger.info("✅ Auth Service configured")
74
+
75
+ # Register Credit Service configuration
76
+ from services.credit_service import register_credit_service
77
+ register_credit_service(
78
+ route_costs={
79
+ "/gemini/generate-animation-prompt": 1,
80
+ "/gemini/edit-image": 1,
81
+ "/gemini/generate-video": 10,
82
+ "/gemini/generate-text": 1,
83
+ "/gemini/analyze-image": 1,
84
+ }
85
+ )
86
+ logger.info("✅ Credit Service configured")
87
+
88
  # Check for RESET_DB environment variable
89
  if os.getenv("RESET_DB", "").lower() == "true":
90
  logger.warning(f"RESET_DB is set to true. Skipping download and clearing local database ({DB_FILENAME}).")
 
141
  allow_headers=["*"],
142
  )
143
 
144
+ # Add Credit Middleware first (executes second - after auth)
145
+ from services.credit_service import CreditMiddleware
146
+ app.add_middleware(CreditMiddleware)
147
+
148
+ # Add Auth Middleware second (executes first - sets user)
149
+ from services.auth_service import AuthMiddleware
150
+ app.add_middleware(AuthMiddleware)
151
+
152
  # Include Routers
153
  app.include_router(general.router)
154
  app.include_router(auth.router)
routers/blink.py CHANGED
@@ -11,7 +11,7 @@ import logging
11
  from core.database import get_db
12
  from core.models import User, AuditLog, GeminiJob, Contact, ClientUser
13
  from services.encryption_service import decrypt_multiple_blocks
14
- from dependencies import get_geolocation, get_optional_user, get_current_user
15
 
16
  logger = logging.getLogger(__name__)
17
 
@@ -27,16 +27,18 @@ USER_ID_LENGTH = 20
27
 
28
  @router.get("/api/data")
29
  async def get_data(
 
30
  page: int = Query(1, ge=1, description="Page number"),
31
  limit: int = Query(100, ge=1, le=500, description="Items per page"),
32
  log_type: str = Query(None, description="Filter by log type: client, server"),
33
- user: User = Depends(get_current_user), # Auth required
34
  db: AsyncSession = Depends(get_db)
35
  ):
36
  """
37
  Get paginated audit log data for the authenticated user.
38
  Admins see all logs from all users.
 
39
  """
 
40
  from services.db_service import QueryService
41
 
42
  try:
@@ -93,15 +95,17 @@ async def get_data(
93
 
94
  @router.get("/api/users")
95
  async def get_users(
 
96
  page: int = Query(1, ge=1, description="Page number"),
97
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
98
- user: User = Depends(get_current_user), # Auth required
99
  db: AsyncSession = Depends(get_db)
100
  ):
101
  """
102
  Get current user's profile data.
103
  Admins see paginated list of all users.
 
104
  """
 
105
  from services.db_service import QueryService
106
 
107
  try:
@@ -148,16 +152,18 @@ async def get_users(
148
 
149
  @router.get("/api/client-users")
150
  async def get_client_users(
 
151
  page: int = Query(1, ge=1, description="Page number"),
152
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
153
  user_id: str = Query(None, description="Filter by server user_id"),
154
- user: User = Depends(get_current_user), # Auth required
155
  db: AsyncSession = Depends(get_db)
156
  ):
157
  """
158
  Get current user's client mappings.
159
  Admins see all client mappings from all users.
 
160
  """
 
161
  from services.db_service import QueryService
162
 
163
  try:
@@ -197,17 +203,19 @@ async def get_client_users(
197
 
198
  @router.get("/api/audit-logs")
199
  async def get_audit_logs(
 
200
  page: int = Query(1, ge=1, description="Page number"),
201
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
202
  log_type: str = Query(None, description="Filter by log type: client, server"),
203
  action: str = Query(None, description="Filter by action"),
204
- user: User = Depends(get_current_user), # Auth required
205
  db: AsyncSession = Depends(get_db)
206
  ):
207
  """
208
  Get current user's audit logs with optional filters.
209
  Admins see all logs from all users.
 
210
  """
 
211
  from services.db_service import QueryService
212
 
213
  try:
@@ -263,15 +271,17 @@ async def get_audit_logs(
263
 
264
  @router.get("/api/gemini-jobs")
265
  async def get_gemini_jobs(
 
266
  page: int = Query(1, ge=1, description="Page number"),
267
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
268
- user: User = Depends(get_current_user), # Auth required
269
  db: AsyncSession = Depends(get_db)
270
  ):
271
  """
272
  Get current user's Gemini jobs.
273
  Admins see all jobs from all users.
 
274
  """
 
275
  from services.db_service import QueryService
276
 
277
  try:
@@ -313,15 +323,17 @@ async def get_gemini_jobs(
313
 
314
  @router.get("/api/payment-transactions")
315
  async def get_payment_transactions(
 
316
  page: int = Query(1, ge=1, description="Page number"),
317
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
318
- user: User = Depends(get_current_user), # Auth required
319
  db: AsyncSession = Depends(get_db)
320
  ):
321
  """
322
  Get current user's payment transactions.
323
  Admins see all transactions from all users.
 
324
  """
 
325
  from core.models import PaymentTransaction
326
  from services.db_service import QueryService
327
 
@@ -383,15 +395,17 @@ async def get_payment_transactions(
383
 
384
  @router.get("/api/contacts")
385
  async def get_contacts(
 
386
  page: int = Query(1, ge=1, description="Page number"),
387
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
388
- user: User = Depends(get_current_user), # Auth required
389
  db: AsyncSession = Depends(get_db)
390
  ):
391
  """
392
  Get current user's contact form submissions.
393
  Admins see all contact submissions from all users.
 
394
  """
 
395
  from services.db_service import QueryService
396
 
397
  try:
@@ -436,13 +450,16 @@ async def get_contacts(
436
  async def blink(
437
  request: Request,
438
  userid: str = Query(..., description="User ID (20 chars) + encrypted data"),
439
- db: AsyncSession = Depends(get_db),
440
- current_user: User = Depends(get_optional_user)
441
  ):
442
  """
443
  Process blink request with encrypted user data.
444
  Logs to AuditLog with log_type='client'.
445
 
 
 
 
 
446
  If authenticated via JWT:
447
  - Creates a new ClientUser entry linking client_user_id to server user_id
448
  - Sets user_id in AuditLog entries
@@ -450,6 +467,8 @@ async def blink(
450
  If not authenticated:
451
  - Creates AuditLog entries with user_id=None (anonymous)
452
  """
 
 
453
  try:
454
  # Validate minimum length
455
  if len(userid) < USER_ID_LENGTH:
 
11
  from core.database import get_db
12
  from core.models import User, AuditLog, GeminiJob, Contact, ClientUser
13
  from services.encryption_service import decrypt_multiple_blocks
14
+ from dependencies import get_geolocation
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
27
 
28
  @router.get("/api/data")
29
  async def get_data(
30
+ request: Request,
31
  page: int = Query(1, ge=1, description="Page number"),
32
  limit: int = Query(100, ge=1, le=500, description="Items per page"),
33
  log_type: str = Query(None, description="Filter by log type: client, server"),
 
34
  db: AsyncSession = Depends(get_db)
35
  ):
36
  """
37
  Get paginated audit log data for the authenticated user.
38
  Admins see all logs from all users.
39
+ Auth handled by AuthMiddleware - user in request.state.user
40
  """
41
+ user = request.state.user
42
  from services.db_service import QueryService
43
 
44
  try:
 
95
 
96
  @router.get("/api/users")
97
  async def get_users(
98
+ request: Request,
99
  page: int = Query(1, ge=1, description="Page number"),
100
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
 
101
  db: AsyncSession = Depends(get_db)
102
  ):
103
  """
104
  Get current user's profile data.
105
  Admins see paginated list of all users.
106
+ Auth handled by AuthMiddleware - user in request.state.user
107
  """
108
+ user = request.state.user
109
  from services.db_service import QueryService
110
 
111
  try:
 
152
 
153
  @router.get("/api/client-users")
154
  async def get_client_users(
155
+ request: Request,
156
  page: int = Query(1, ge=1, description="Page number"),
157
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
158
  user_id: str = Query(None, description="Filter by server user_id"),
 
159
  db: AsyncSession = Depends(get_db)
160
  ):
161
  """
162
  Get current user's client mappings.
163
  Admins see all client mappings from all users.
164
+ Auth handled by AuthMiddleware - user in request.state.user
165
  """
166
+ user = request.state.user
167
  from services.db_service import QueryService
168
 
169
  try:
 
203
 
204
  @router.get("/api/audit-logs")
205
  async def get_audit_logs(
206
+ request: Request,
207
  page: int = Query(1, ge=1, description="Page number"),
208
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
209
  log_type: str = Query(None, description="Filter by log type: client, server"),
210
  action: str = Query(None, description="Filter by action"),
 
211
  db: AsyncSession = Depends(get_db)
212
  ):
213
  """
214
  Get current user's audit logs with optional filters.
215
  Admins see all logs from all users.
216
+ Auth handled by AuthMiddleware - user in request.state.user
217
  """
218
+ user = request.state.user
219
  from services.db_service import QueryService
220
 
221
  try:
 
271
 
272
  @router.get("/api/gemini-jobs")
273
  async def get_gemini_jobs(
274
+ request: Request,
275
  page: int = Query(1, ge=1, description="Page number"),
276
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
 
277
  db: AsyncSession = Depends(get_db)
278
  ):
279
  """
280
  Get current user's Gemini jobs.
281
  Admins see all jobs from all users.
282
+ Auth handled by AuthMiddleware - user in request.state.user
283
  """
284
+ user = request.state.user
285
  from services.db_service import QueryService
286
 
287
  try:
 
323
 
324
  @router.get("/api/payment-transactions")
325
  async def get_payment_transactions(
326
+ request: Request,
327
  page: int = Query(1, ge=1, description="Page number"),
328
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
 
329
  db: AsyncSession = Depends(get_db)
330
  ):
331
  """
332
  Get current user's payment transactions.
333
  Admins see all transactions from all users.
334
+ Auth handled by AuthMiddleware - user in request.state.user
335
  """
336
+ user = request.state.user
337
  from core.models import PaymentTransaction
338
  from services.db_service import QueryService
339
 
 
395
 
396
  @router.get("/api/contacts")
397
  async def get_contacts(
398
+ request: Request,
399
  page: int = Query(1, ge=1, description="Page number"),
400
  limit: int = Query(50, ge=1, le=500, description="Items per page"),
 
401
  db: AsyncSession = Depends(get_db)
402
  ):
403
  """
404
  Get current user's contact form submissions.
405
  Admins see all contact submissions from all users.
406
+ Auth handled by AuthMiddleware - user in request.state.user
407
  """
408
+ user = request.state.user
409
  from services.db_service import QueryService
410
 
411
  try:
 
450
  async def blink(
451
  request: Request,
452
  userid: str = Query(..., description="User ID (20 chars) + encrypted data"),
453
+ db: AsyncSession = Depends(get_db)
 
454
  ):
455
  """
456
  Process blink request with encrypted user data.
457
  Logs to AuditLog with log_type='client'.
458
 
459
+ Auth is optional (handled by AuthMiddleware):
460
+ - If authenticated: user in request.state.user
461
+ - If not authenticated: request.state.user is None
462
+
463
  If authenticated via JWT:
464
  - Creates a new ClientUser entry linking client_user_id to server user_id
465
  - Sets user_id in AuditLog entries
 
467
  If not authenticated:
468
  - Creates AuditLog entries with user_id=None (anonymous)
469
  """
470
+ # Optional auth - may be None
471
+ current_user = request.state.user
472
  try:
473
  # Validate minimum length
474
  if len(userid) < USER_ID_LENGTH:
routers/contact.py CHANGED
@@ -14,7 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
14
 
15
  from core.database import get_db
16
  from core.models import User, Contact
17
- from dependencies import get_current_user
18
 
19
  logger = logging.getLogger(__name__)
20
 
@@ -45,14 +44,17 @@ class ContactResponse(BaseModel):
45
  async def submit_contact(
46
  request_body: ContactRequest,
47
  request: Request,
48
- user: User = Depends(get_current_user),
49
  db: AsyncSession = Depends(get_db)
50
  ):
51
  """
52
  Submit a contact form for customer support.
53
 
54
- Requires authentication - user must be logged in.
 
55
  """
 
 
 
56
  # Validate message
57
  if not request_body.message or not request_body.message.strip():
58
  raise HTTPException(
 
14
 
15
  from core.database import get_db
16
  from core.models import User, Contact
 
17
 
18
  logger = logging.getLogger(__name__)
19
 
 
44
  async def submit_contact(
45
  request_body: ContactRequest,
46
  request: Request,
 
47
  db: AsyncSession = Depends(get_db)
48
  ):
49
  """
50
  Submit a contact form for customer support.
51
 
52
+ Requires authentication - user is authenticated by AuthMiddleware.
53
+ User is available in request.state.user
54
  """
55
+ # Get authenticated user from middleware
56
+ user = request.state.user
57
+
58
  # Validate message
59
  if not request_body.message or not request_body.message.strip():
60
  raise HTTPException(
routers/credits.py CHANGED
@@ -3,7 +3,7 @@ Credits Router - API endpoints for credit management.
3
 
4
  Provides endpoints for checking credit balance and viewing credit history.
5
  """
6
- from fastapi import APIRouter, Depends, Query
7
  from pydantic import BaseModel
8
  from typing import List, Optional
9
  from datetime import datetime
@@ -12,7 +12,6 @@ from sqlalchemy import select, desc
12
 
13
  from core.database import get_db
14
  from core.models import User, GeminiJob
15
- from dependencies import get_current_user
16
 
17
  router = APIRouter(prefix="/credits", tags=["credits"])
18
 
@@ -47,15 +46,17 @@ class CreditHistoryResponse(BaseModel):
47
  limit: int
48
 
49
 
50
- @router.get("", response_model=CreditBalanceResponse)
51
  async def get_credits(
52
- user: User = Depends(get_current_user)
53
  ):
54
  """
55
  Get current credit balance.
56
 
57
  Returns the user's current credit balance and last usage time.
 
58
  """
 
59
  return CreditBalanceResponse(
60
  user_id=user.user_id,
61
  credits=user.credits,
@@ -65,7 +66,7 @@ async def get_credits(
65
 
66
  @router.get("/history", response_model=CreditHistoryResponse)
67
  async def get_credit_history(
68
- user: User = Depends(get_current_user),
69
  db: AsyncSession = Depends(get_db),
70
  page: int = Query(1, ge=1, description="Page number"),
71
  limit: int = Query(20, ge=1, le=100, description="Items per page")
@@ -77,7 +78,9 @@ async def get_credit_history(
77
  showing which jobs used credits and which were refunded.
78
 
79
  Only includes jobs where credits were reserved (credits_reserved > 0).
 
80
  """
 
81
  offset = (page - 1) * limit
82
 
83
  # Query jobs with credit transactions
 
3
 
4
  Provides endpoints for checking credit balance and viewing credit history.
5
  """
6
+ from fastapi import APIRouter, Depends, Query, Request
7
  from pydantic import BaseModel
8
  from typing import List, Optional
9
  from datetime import datetime
 
12
 
13
  from core.database import get_db
14
  from core.models import User, GeminiJob
 
15
 
16
  router = APIRouter(prefix="/credits", tags=["credits"])
17
 
 
46
  limit: int
47
 
48
 
49
+ @router.get("/balance", response_model=CreditBalanceResponse)
50
  async def get_credits(
51
+ request: Request
52
  ):
53
  """
54
  Get current credit balance.
55
 
56
  Returns the user's current credit balance and last usage time.
57
+ Auth is handled by AuthMiddleware - user is in request.state.user
58
  """
59
+ user = request.state.user
60
  return CreditBalanceResponse(
61
  user_id=user.user_id,
62
  credits=user.credits,
 
66
 
67
  @router.get("/history", response_model=CreditHistoryResponse)
68
  async def get_credit_history(
69
+ request: Request,
70
  db: AsyncSession = Depends(get_db),
71
  page: int = Query(1, ge=1, description="Page number"),
72
  limit: int = Query(20, ge=1, le=100, description="Items per page")
 
78
  showing which jobs used credits and which were refunded.
79
 
80
  Only includes jobs where credits were reserved (credits_reserved > 0).
81
+ Auth is handled by AuthMiddleware - user is in request.state.user
82
  """
83
+ user = request.state.user
84
  offset = (page - 1) * limit
85
 
86
  # Query jobs with credit transactions
routers/gemini.py CHANGED
@@ -5,7 +5,7 @@ Authentication via JWT (Authorization: Bearer <token>).
5
  """
6
  import os
7
  import uuid
8
- from fastapi import APIRouter, Depends, HTTPException, status
9
  from fastapi.responses import FileResponse
10
  from pydantic import BaseModel, Field
11
  from typing import Optional, Literal
@@ -15,7 +15,6 @@ from sqlalchemy import select, func
15
  from core.database import get_db
16
  from core.models import User, GeminiJob
17
  from services.gemini_service import MODELS, DOWNLOADS_DIR
18
- from dependencies import verify_credits, verify_video_credits, get_current_user
19
  from datetime import datetime
20
 
21
  router = APIRouter(prefix="/gemini", tags=["gemini"])
@@ -102,13 +101,16 @@ async def create_job(
102
 
103
  @router.post("/generate-animation-prompt")
104
  async def generate_animation_prompt(
 
105
  request: GenerateAnimationPromptRequest,
106
- user: User = Depends(verify_credits),
107
  db: AsyncSession = Depends(get_db)
108
  ):
109
  """
110
  Queue an animation prompt generation job.
 
111
  """
 
 
112
  job = await create_job(
113
  db=db,
114
  user=user,
@@ -118,7 +120,7 @@ async def generate_animation_prompt(
118
  "mime_type": request.mime_type,
119
  "custom_prompt": request.custom_prompt
120
  },
121
- credits_reserved=1
122
  )
123
 
124
  position = await get_queue_position(db, job.job_id)
@@ -134,13 +136,16 @@ async def generate_animation_prompt(
134
 
135
  @router.post("/edit-image")
136
  async def edit_image(
 
137
  request: EditImageRequest,
138
- user: User = Depends(verify_credits),
139
  db: AsyncSession = Depends(get_db)
140
  ):
141
  """
142
  Queue an image edit job.
 
143
  """
 
 
144
  job = await create_job(
145
  db=db,
146
  user=user,
@@ -150,7 +155,7 @@ async def edit_image(
150
  "mime_type": request.mime_type,
151
  "prompt": request.prompt
152
  },
153
- credits_reserved=1
154
  )
155
 
156
  position = await get_queue_position(db, job.job_id)
@@ -166,13 +171,16 @@ async def edit_image(
166
 
167
  @router.post("/generate-video")
168
  async def generate_video(
 
169
  request: GenerateVideoRequest,
170
- user: User = Depends(verify_video_credits),
171
  db: AsyncSession = Depends(get_db)
172
  ):
173
  """
174
  Queue a video generation job.
 
175
  """
 
 
176
  job = await create_job(
177
  db=db,
178
  user=user,
@@ -185,7 +193,7 @@ async def generate_video(
185
  "resolution": request.resolution,
186
  "number_of_videos": request.number_of_videos
187
  },
188
- credits_reserved=10 # Video jobs cost 10 credits
189
  )
190
 
191
  position = await get_queue_position(db, job.job_id)
@@ -201,13 +209,16 @@ async def generate_video(
201
 
202
  @router.post("/generate-text")
203
  async def generate_text(
 
204
  request: GenerateTextRequest,
205
- user: User = Depends(verify_credits),
206
  db: AsyncSession = Depends(get_db)
207
  ):
208
  """
209
  Queue a text generation job.
 
210
  """
 
 
211
  job = await create_job(
212
  db=db,
213
  user=user,
@@ -216,7 +227,7 @@ async def generate_text(
216
  "prompt": request.prompt,
217
  "model": request.model
218
  },
219
- credits_reserved=1
220
  )
221
 
222
  position = await get_queue_position(db, job.job_id)
@@ -232,13 +243,16 @@ async def generate_text(
232
 
233
  @router.post("/analyze-image")
234
  async def analyze_image(
 
235
  request: AnalyzeImageRequest,
236
- user: User = Depends(verify_credits),
237
  db: AsyncSession = Depends(get_db)
238
  ):
239
  """
240
  Queue an image analysis job.
 
241
  """
 
 
242
  job = await create_job(
243
  db=db,
244
  user=user,
@@ -248,7 +262,7 @@ async def analyze_image(
248
  "mime_type": request.mime_type,
249
  "prompt": request.prompt
250
  },
251
- credits_reserved=1
252
  )
253
 
254
  position = await get_queue_position(db, job.job_id)
@@ -264,7 +278,7 @@ async def analyze_image(
264
 
265
  @router.get("/jobs")
266
  async def get_jobs(
267
- user: User = Depends(get_current_user),
268
  db: AsyncSession = Depends(get_db),
269
  page: int = 1,
270
  limit: int = 20
@@ -272,7 +286,9 @@ async def get_jobs(
272
  """
273
  Get all jobs created by the current user.
274
  Returns a paginated list of jobs with status, type, and prompt (for video jobs).
 
275
  """
 
276
  offset = (page - 1) * limit
277
 
278
  # Query jobs for the current user
@@ -326,14 +342,16 @@ async def get_jobs(
326
  @router.get("/job/{job_id}")
327
  async def get_job_status(
328
  job_id: str,
329
- user: User = Depends(get_current_user),
330
  db: AsyncSession = Depends(get_db)
331
  ):
332
  """
333
  Get the status of a job.
334
  Poll this endpoint until status is 'completed' or 'failed'.
335
  For processing video jobs, this will check the Gemini API status and update the job.
 
336
  """
 
337
  query = select(GeminiJob).where(
338
  GeminiJob.job_id == job_id,
339
  GeminiJob.user_id == user.id # Integer FK comparison
@@ -391,14 +409,16 @@ async def get_job_status(
391
  @router.get("/download/{job_id}")
392
  async def download_video(
393
  job_id: str,
394
- user: User = Depends(get_current_user),
395
  db: AsyncSession = Depends(get_db)
396
  ):
397
  """
398
  Download a generated video.
399
  Downloads from Gemini URL, streams to client, then deletes local file.
400
  No permanent storage on server.
 
401
  """
 
402
  from fastapi.responses import StreamingResponse
403
  import httpx
404
 
@@ -470,14 +490,16 @@ async def download_video(
470
  @router.post("/job/{job_id}/cancel")
471
  async def cancel_job(
472
  job_id: str,
473
- user: User = Depends(get_current_user),
474
  db: AsyncSession = Depends(get_db)
475
  ):
476
  """
477
  Cancel a queued job.
478
  Only jobs with status 'queued' can be cancelled.
479
  Processing/completed/failed jobs cannot be cancelled.
 
480
  """
 
481
  query = select(GeminiJob).where(
482
  GeminiJob.job_id == job_id,
483
  GeminiJob.user_id == user.id # Integer FK comparison
@@ -512,7 +534,7 @@ async def cancel_job(
512
  @router.delete("/job/{job_id}")
513
  async def delete_job(
514
  job_id: str,
515
- user: User = Depends(get_current_user),
516
  db: AsyncSession = Depends(get_db)
517
  ):
518
  """
@@ -521,7 +543,9 @@ async def delete_job(
521
  Refund policy:
522
  - If queued: Refund 8 credits (10 cost - 2 penalty), soft delete job.
523
  - If processing/completed/failed: Soft delete job (no refund).
 
524
  """
 
525
  from services.db_service import QueryService
526
 
527
  qs = QueryService(user, db)
 
5
  """
6
  import os
7
  import uuid
8
+ from fastapi import APIRouter, Depends, HTTPException, status, Request
9
  from fastapi.responses import FileResponse
10
  from pydantic import BaseModel, Field
11
  from typing import Optional, Literal
 
15
  from core.database import get_db
16
  from core.models import User, GeminiJob
17
  from services.gemini_service import MODELS, DOWNLOADS_DIR
 
18
  from datetime import datetime
19
 
20
  router = APIRouter(prefix="/gemini", tags=["gemini"])
 
101
 
102
  @router.post("/generate-animation-prompt")
103
  async def generate_animation_prompt(
104
+ req: Request,
105
  request: GenerateAnimationPromptRequest,
 
106
  db: AsyncSession = Depends(get_db)
107
  ):
108
  """
109
  Queue an animation prompt generation job.
110
+ Auth and credit validation handled by middleware.
111
  """
112
+ user = req.state.user
113
+ credits_reserved = req.state.credits_reserved
114
  job = await create_job(
115
  db=db,
116
  user=user,
 
120
  "mime_type": request.mime_type,
121
  "custom_prompt": request.custom_prompt
122
  },
123
+ credits_reserved=credits_reserved
124
  )
125
 
126
  position = await get_queue_position(db, job.job_id)
 
136
 
137
  @router.post("/edit-image")
138
  async def edit_image(
139
+ req: Request,
140
  request: EditImageRequest,
 
141
  db: AsyncSession = Depends(get_db)
142
  ):
143
  """
144
  Queue an image edit job.
145
+ Auth and credit validation handled by middleware.
146
  """
147
+ user = req.state.user
148
+ credits_reserved = req.state.credits_reserved
149
  job = await create_job(
150
  db=db,
151
  user=user,
 
155
  "mime_type": request.mime_type,
156
  "prompt": request.prompt
157
  },
158
+ credits_reserved=credits_reserved
159
  )
160
 
161
  position = await get_queue_position(db, job.job_id)
 
171
 
172
  @router.post("/generate-video")
173
  async def generate_video(
174
+ req: Request,
175
  request: GenerateVideoRequest,
 
176
  db: AsyncSession = Depends(get_db)
177
  ):
178
  """
179
  Queue a video generation job.
180
+ Auth and credit validation handled by middleware.
181
  """
182
+ user = req.state.user
183
+ credits_reserved = req.state.credits_reserved
184
  job = await create_job(
185
  db=db,
186
  user=user,
 
193
  "resolution": request.resolution,
194
  "number_of_videos": request.number_of_videos
195
  },
196
+ credits_reserved=credits_reserved # 10 credits for video
197
  )
198
 
199
  position = await get_queue_position(db, job.job_id)
 
209
 
210
  @router.post("/generate-text")
211
  async def generate_text(
212
+ req: Request,
213
  request: GenerateTextRequest,
 
214
  db: AsyncSession = Depends(get_db)
215
  ):
216
  """
217
  Queue a text generation job.
218
+ Auth and credit validation handled by middleware.
219
  """
220
+ user = req.state.user
221
+ credits_reserved = req.state.credits_reserved
222
  job = await create_job(
223
  db=db,
224
  user=user,
 
227
  "prompt": request.prompt,
228
  "model": request.model
229
  },
230
+ credits_reserved=credits_reserved
231
  )
232
 
233
  position = await get_queue_position(db, job.job_id)
 
243
 
244
  @router.post("/analyze-image")
245
  async def analyze_image(
246
+ req: Request,
247
  request: AnalyzeImageRequest,
 
248
  db: AsyncSession = Depends(get_db)
249
  ):
250
  """
251
  Queue an image analysis job.
252
+ Auth and credit validation handled by middleware.
253
  """
254
+ user = req.state.user
255
+ credits_reserved = req.state.credits_reserved
256
  job = await create_job(
257
  db=db,
258
  user=user,
 
262
  "mime_type": request.mime_type,
263
  "prompt": request.prompt
264
  },
265
+ credits_reserved=credits_reserved
266
  )
267
 
268
  position = await get_queue_position(db, job.job_id)
 
278
 
279
  @router.get("/jobs")
280
  async def get_jobs(
281
+ req: Request,
282
  db: AsyncSession = Depends(get_db),
283
  page: int = 1,
284
  limit: int = 20
 
286
  """
287
  Get all jobs created by the current user.
288
  Returns a paginated list of jobs with status, type, and prompt (for video jobs).
289
+ Auth handled by AuthMiddleware - user in request.state.user
290
  """
291
+ user = req.state.user
292
  offset = (page - 1) * limit
293
 
294
  # Query jobs for the current user
 
342
  @router.get("/job/{job_id}")
343
  async def get_job_status(
344
  job_id: str,
345
+ req: Request,
346
  db: AsyncSession = Depends(get_db)
347
  ):
348
  """
349
  Get the status of a job.
350
  Poll this endpoint until status is 'completed' or 'failed'.
351
  For processing video jobs, this will check the Gemini API status and update the job.
352
+ Auth handled by AuthMiddleware - user in request.state.user
353
  """
354
+ user = req.state.user
355
  query = select(GeminiJob).where(
356
  GeminiJob.job_id == job_id,
357
  GeminiJob.user_id == user.id # Integer FK comparison
 
409
  @router.get("/download/{job_id}")
410
  async def download_video(
411
  job_id: str,
412
+ req: Request,
413
  db: AsyncSession = Depends(get_db)
414
  ):
415
  """
416
  Download a generated video.
417
  Downloads from Gemini URL, streams to client, then deletes local file.
418
  No permanent storage on server.
419
+ Auth handled by AuthMiddleware - user in request.state.user
420
  """
421
+ user = req.state.user
422
  from fastapi.responses import StreamingResponse
423
  import httpx
424
 
 
490
  @router.post("/job/{job_id}/cancel")
491
  async def cancel_job(
492
  job_id: str,
493
+ req: Request,
494
  db: AsyncSession = Depends(get_db)
495
  ):
496
  """
497
  Cancel a queued job.
498
  Only jobs with status 'queued' can be cancelled.
499
  Processing/completed/failed jobs cannot be cancelled.
500
+ Auth handled by AuthMiddleware - user in request.state.user
501
  """
502
+ user = req.state.user
503
  query = select(GeminiJob).where(
504
  GeminiJob.job_id == job_id,
505
  GeminiJob.user_id == user.id # Integer FK comparison
 
534
  @router.delete("/job/{job_id}")
535
  async def delete_job(
536
  job_id: str,
537
+ req: Request,
538
  db: AsyncSession = Depends(get_db)
539
  ):
540
  """
 
543
  Refund policy:
544
  - If queued: Refund 8 credits (10 cost - 2 penalty), soft delete job.
545
  - If processing/completed/failed: Soft delete job (no refund).
546
+ Auth handled by AuthMiddleware - user in request.state.user
547
  """
548
+ user = req.state.user
549
  from services.db_service import QueryService
550
 
551
  qs = QueryService(user, db)
routers/payments.py CHANGED
@@ -21,7 +21,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
21
 
22
  from core.database import get_db
23
  from core.models import User, PaymentTransaction
24
- from dependencies import get_current_user
25
  from services.drive_service import DriveService
26
  from services.razorpay_service import (
27
  RazorpayService,
@@ -223,8 +222,8 @@ async def get_packages():
223
 
224
  @router.post("/create-order", response_model=CreateOrderResponse)
225
  async def create_order(
 
226
  request: CreateOrderRequest,
227
- user: User = Depends(get_current_user),
228
  db: AsyncSession = Depends(get_db)
229
  ):
230
  """
@@ -232,7 +231,9 @@ async def create_order(
232
 
233
  The client should use the returned order_id to open
234
  Razorpay checkout. After payment, call /verify endpoint.
 
235
  """
 
236
  # Check if Razorpay is configured
237
  if not is_razorpay_configured():
238
  raise HTTPException(
@@ -315,9 +316,9 @@ async def create_order(
315
 
316
  @router.post("/verify", response_model=VerifyPaymentResponse)
317
  async def verify_payment(
 
318
  request: VerifyPaymentRequest,
319
  background_tasks: BackgroundTasks,
320
- user: User = Depends(get_current_user),
321
  db: AsyncSession = Depends(get_db)
322
  ):
323
  """
@@ -325,7 +326,9 @@ async def verify_payment(
325
 
326
  Called after successful Razorpay checkout.
327
  Verifies the payment signature and credits the user.
 
328
  """
 
329
  try:
330
  razorpay_service = get_razorpay_service()
331
 
@@ -557,16 +560,18 @@ async def razorpay_webhook(
557
 
558
  @router.get("/history", response_model=PaymentHistoryResponse)
559
  async def get_payment_history(
 
560
  page: int = Query(1, ge=1, description="Page number"),
561
  limit: int = Query(20, ge=1, le=100, description="Items per page"),
562
- user: User = Depends(get_current_user),
563
  db: AsyncSession = Depends(get_db)
564
  ):
565
  """
566
  Get user's payment history with pagination.
567
 
568
  Returns payment transactions ordered by newest first.
 
569
  """
 
570
  # Get total count
571
  count_result = await db.execute(
572
  select(func.count(PaymentTransaction.id))
 
21
 
22
  from core.database import get_db
23
  from core.models import User, PaymentTransaction
 
24
  from services.drive_service import DriveService
25
  from services.razorpay_service import (
26
  RazorpayService,
 
222
 
223
  @router.post("/create-order", response_model=CreateOrderResponse)
224
  async def create_order(
225
+ req: Request,
226
  request: CreateOrderRequest,
 
227
  db: AsyncSession = Depends(get_db)
228
  ):
229
  """
 
231
 
232
  The client should use the returned order_id to open
233
  Razorpay checkout. After payment, call /verify endpoint.
234
+ Auth handled by AuthMiddleware - user in request.state.user
235
  """
236
+ user = req.state.user
237
  # Check if Razorpay is configured
238
  if not is_razorpay_configured():
239
  raise HTTPException(
 
316
 
317
  @router.post("/verify", response_model=VerifyPaymentResponse)
318
  async def verify_payment(
319
+ req: Request,
320
  request: VerifyPaymentRequest,
321
  background_tasks: BackgroundTasks,
 
322
  db: AsyncSession = Depends(get_db)
323
  ):
324
  """
 
326
 
327
  Called after successful Razorpay checkout.
328
  Verifies the payment signature and credits the user.
329
+ Auth handled by AuthMiddleware - user in request.state.user
330
  """
331
+ user = req.state.user
332
  try:
333
  razorpay_service = get_razorpay_service()
334
 
 
560
 
561
  @router.get("/history", response_model=PaymentHistoryResponse)
562
  async def get_payment_history(
563
+ req: Request,
564
  page: int = Query(1, ge=1, description="Page number"),
565
  limit: int = Query(20, ge=1, le=100, description="Items per page"),
 
566
  db: AsyncSession = Depends(get_db)
567
  ):
568
  """
569
  Get user's payment history with pagination.
570
 
571
  Returns payment transactions ordered by newest first.
572
+ Auth handled by AuthMiddleware - user in request.state.user
573
  """
574
+ user = req.state.user
575
  # Get total count
576
  count_result = await db.execute(
577
  select(func.count(PaymentTransaction.id))
services/auth_service/__init__.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Auth Service - Authentication layer for API Gateway
3
+
4
+ Provides plug-and-play authentication with:
5
+ - Google OAuth integration
6
+ - JWT token management
7
+ - Request middleware for auth validation
8
+ - URL-based route configuration
9
+
10
+ Usage:
11
+ # In app.py startup
12
+ from services.auth_service import register_auth_service
13
+
14
+ register_auth_service(
15
+ required_urls=["/api/*", "/admin/*"],
16
+ public_urls=["/", "/health", "/auth/*"],
17
+ jwt_secret=os.getenv("JWT_SECRET"),
18
+ google_client_id=os.getenv("GOOGLE_CLIENT_ID")
19
+ )
20
+
21
+ # In routers
22
+ from fastapi import Request
23
+
24
+ @router.get("/protected")
25
+ async def protected_route(request: Request):
26
+ user = request.state.user # Populated by AuthMiddleware
27
+ return {"user_id": user.id}
28
+ """
29
+
30
+ from services.auth_service.config import AuthServiceConfig
31
+ from services.auth_service.middleware import AuthMiddleware
32
+ from services.auth_service.google_provider import (
33
+ GoogleAuthService,
34
+ GoogleUserInfo,
35
+ verify_google_token,
36
+ GoogleAuthError,
37
+ InvalidTokenError as GoogleInvalidTokenError,
38
+ )
39
+ from services.auth_service.jwt_provider import (
40
+ JWTService,
41
+ TokenPayload,
42
+ create_access_token,
43
+ verify_access_token,
44
+ JWTError,
45
+ TokenExpiredError,
46
+ InvalidTokenError,
47
+ )
48
+
49
+
50
+ def register_auth_service(
51
+ required_urls: list = None,
52
+ optional_urls: list = None,
53
+ public_urls: list = None,
54
+ jwt_secret: str = None,
55
+ jwt_algorithm: str = "HS256",
56
+ jwt_expiry_hours: int = 24,
57
+ google_client_id: str = None,
58
+ admin_emails: list = None,
59
+ ) -> None:
60
+ """
61
+ Register the auth service with application configuration.
62
+
63
+ Args:
64
+ required_urls: URLs that REQUIRE authentication
65
+ optional_urls: URLs where authentication is optional
66
+ public_urls: URLs that don't need authentication
67
+ jwt_secret: Secret key for JWT signing
68
+ jwt_algorithm: JWT algorithm (default: HS256)
69
+ jwt_expiry_hours: Token expiry in hours (default: 24)
70
+ google_client_id: Google OAuth Client ID
71
+ admin_emails: List of admin email addresses
72
+ """
73
+ AuthServiceConfig.register(
74
+ required_urls=required_urls or [],
75
+ optional_urls=optional_urls or [],
76
+ public_urls=public_urls or [],
77
+ jwt_secret=jwt_secret,
78
+ jwt_algorithm=jwt_algorithm,
79
+ jwt_expiry_hours=jwt_expiry_hours,
80
+ google_client_id=google_client_id,
81
+ admin_emails=admin_emails or [],
82
+ )
83
+
84
+
85
+ __all__ = [
86
+ # Registration
87
+ 'register_auth_service',
88
+ 'AuthServiceConfig',
89
+ 'AuthMiddleware',
90
+
91
+ # Google OAuth
92
+ 'GoogleAuthService',
93
+ 'GoogleUserInfo',
94
+ 'verify_google_token',
95
+ 'GoogleAuthError',
96
+ 'GoogleInvalidTokenError',
97
+
98
+ # JWT
99
+ 'JWTService',
100
+ 'TokenPayload',
101
+ 'create_access_token',
102
+ 'verify_access_token',
103
+ 'JWTError',
104
+ 'TokenExpiredError',
105
+ 'InvalidTokenError',
106
+ ]
services/auth_service/config.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Auth Service Configuration
3
+
4
+ Manages authentication configuration and route matching for the auth service.
5
+ """
6
+
7
+ import logging
8
+ from typing import List
9
+ from services.base_service import BaseService, ServiceConfig
10
+ from services.base_service.route_matcher import RouteConfig
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class AuthServiceConfig(BaseService):
16
+ """
17
+ Configuration for the auth service.
18
+
19
+ Controls which routes require authentication, which are optional,
20
+ and which are public (no auth needed).
21
+ """
22
+
23
+ SERVICE_NAME = "auth_service"
24
+
25
+ # Route configuration
26
+ _route_config: RouteConfig = None
27
+
28
+ # JWT configuration
29
+ _jwt_secret: str = None
30
+ _jwt_algorithm: str = "HS256"
31
+ _jwt_expiry_hours: int = 24
32
+
33
+ # Google OAuth configuration
34
+ _google_client_id: str = None
35
+
36
+ # Admin configuration
37
+ _admin_emails: List[str] = []
38
+
39
+ @classmethod
40
+ def register(
41
+ cls,
42
+ required_urls: List[str] = None,
43
+ optional_urls: List[str] = None,
44
+ public_urls: List[str] = None,
45
+ jwt_secret: str = None,
46
+ jwt_algorithm: str = "HS256",
47
+ jwt_expiry_hours: int = 24,
48
+ google_client_id: str = None,
49
+ admin_emails: List[str] = None,
50
+ ) -> None:
51
+ """
52
+ Register auth service configuration.
53
+
54
+ Args:
55
+ required_urls: URLs that REQUIRE authentication
56
+ optional_urls: URLs where authentication is optional
57
+ public_urls: URLs that don't need authentication
58
+ jwt_secret: Secret key for JWT signing
59
+ jwt_algorithm: JWT algorithm (default: HS256)
60
+ jwt_expiry_hours: Token expiry in hours (default: 24)
61
+ google_client_id: Google OAuth Client ID
62
+ admin_emails: List of admin email addresses
63
+
64
+ Raises:
65
+ RuntimeError: If service is already registered
66
+ ValueError: If jwt_secret is not provided
67
+ """
68
+ if cls._registered:
69
+ raise RuntimeError(f"{cls.SERVICE_NAME} is already registered")
70
+
71
+ # Validate JWT secret
72
+ if not jwt_secret:
73
+ raise ValueError("jwt_secret is required for auth service")
74
+
75
+ # Store route configuration
76
+ cls._route_config = RouteConfig(
77
+ required=required_urls or [],
78
+ optional=optional_urls or [],
79
+ public=public_urls or [],
80
+ )
81
+
82
+ # Store JWT configuration
83
+ cls._jwt_secret = jwt_secret
84
+ cls._jwt_algorithm = jwt_algorithm
85
+ cls._jwt_expiry_hours = jwt_expiry_hours
86
+
87
+ # Store Google OAuth configuration
88
+ cls._google_client_id = google_client_id
89
+
90
+ # Store admin configuration
91
+ cls._admin_emails = admin_emails or []
92
+
93
+ cls._registered = True
94
+
95
+ logger.info(f"✅ {cls.SERVICE_NAME} registered successfully")
96
+ logger.info(f" JWT algorithm: {cls._jwt_algorithm}")
97
+ logger.info(f" JWT expiry: {cls._jwt_expiry_hours} hours")
98
+ logger.info(f" Required URLs: {len(required_urls or [])}")
99
+ logger.info(f" Optional URLs: {len(optional_urls or [])}")
100
+ logger.info(f" Public URLs: {len(public_urls or [])}")
101
+ logger.info(f" Admin emails: {len(cls._admin_emails)}")
102
+
103
+ @classmethod
104
+ def get_middleware(cls):
105
+ """Return AuthMiddleware instance."""
106
+ from services.auth_service.middleware import AuthMiddleware
107
+ return AuthMiddleware
108
+
109
+ @classmethod
110
+ def requires_auth(cls, path: str) -> bool:
111
+ """Check if a URL path requires authentication."""
112
+ cls.assert_registered()
113
+ return cls._route_config.is_required(path)
114
+
115
+ @classmethod
116
+ def allows_optional_auth(cls, path: str) -> bool:
117
+ """Check if a URL path allows optional authentication."""
118
+ cls.assert_registered()
119
+ return cls._route_config.is_optional(path)
120
+
121
+ @classmethod
122
+ def is_public(cls, path: str) -> bool:
123
+ """Check if a URL path is public (no auth needed)."""
124
+ cls.assert_registered()
125
+ return cls._route_config.is_public(path)
126
+
127
+ @classmethod
128
+ def get_jwt_secret(cls) -> str:
129
+ """Get JWT secret key."""
130
+ cls.assert_registered()
131
+ return cls._jwt_secret
132
+
133
+ @classmethod
134
+ def get_jwt_algorithm(cls) -> str:
135
+ """Get JWT algorithm."""
136
+ cls.assert_registered()
137
+ return cls._jwt_algorithm
138
+
139
+ @classmethod
140
+ def get_jwt_expiry_hours(cls) -> int:
141
+ """Get JWT expiry hours."""
142
+ cls.assert_registered()
143
+ return cls._jwt_expiry_hours
144
+
145
+ @classmethod
146
+ def get_google_client_id(cls) -> str:
147
+ """Get Google OAuth Client ID."""
148
+ cls.assert_registered()
149
+ return cls._google_client_id
150
+
151
+ @classmethod
152
+ def is_admin(cls, email: str) -> bool:
153
+ """Check if an email is an admin."""
154
+ cls.assert_registered()
155
+ return email in cls._admin_emails
156
+
157
+ @classmethod
158
+ def get_admin_emails(cls) -> List[str]:
159
+ """Get list of admin emails."""
160
+ cls.assert_registered()
161
+ return cls._admin_emails.copy()
162
+
163
+
164
+ __all__ = ['AuthServiceConfig']
services/auth_service/google_provider.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modular Google OAuth Service
3
+
4
+ A self-contained, plug-and-play service for verifying Google ID tokens.
5
+ Can be used in any Python application with minimal configuration.
6
+
7
+ Usage:
8
+ from services.google_auth_service import GoogleAuthService, GoogleUserInfo
9
+
10
+ # Initialize with client ID
11
+ auth_service = GoogleAuthService(client_id="your-google-client-id")
12
+
13
+ # Or use environment variable GOOGLE_CLIENT_ID
14
+ auth_service = GoogleAuthService()
15
+
16
+ # Verify a Google ID token
17
+ user_info = auth_service.verify_token(id_token)
18
+ print(user_info.email, user_info.google_id, user_info.name)
19
+
20
+ Environment Variables:
21
+ GOOGLE_CLIENT_ID: Your Google OAuth 2.0 Client ID
22
+
23
+ Dependencies:
24
+ google-auth>=2.0.0
25
+ google-auth-oauthlib>=1.0.0
26
+ """
27
+
28
+ import os
29
+ import logging
30
+ from dataclasses import dataclass
31
+ from typing import Optional
32
+ from google.oauth2 import id_token as google_id_token
33
+ from google.auth.transport import requests as google_requests
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ @dataclass
39
+ class GoogleUserInfo:
40
+ """
41
+ User information extracted from a verified Google ID token.
42
+
43
+ Attributes:
44
+ google_id: Unique Google user identifier (sub claim)
45
+ email: User's email address
46
+ email_verified: Whether Google has verified the email
47
+ name: User's display name (may be None)
48
+ picture: URL to user's profile picture (may be None)
49
+ given_name: User's first name (may be None)
50
+ family_name: User's last name (may be None)
51
+ locale: User's locale preference (may be None)
52
+ """
53
+ google_id: str
54
+ email: str
55
+ email_verified: bool = True
56
+ name: Optional[str] = None
57
+ picture: Optional[str] = None
58
+ given_name: Optional[str] = None
59
+ family_name: Optional[str] = None
60
+ locale: Optional[str] = None
61
+
62
+
63
+ class GoogleAuthError(Exception):
64
+ """Base exception for Google Auth errors."""
65
+ pass
66
+
67
+
68
+ class InvalidTokenError(GoogleAuthError):
69
+ """Raised when the token is invalid or expired."""
70
+ pass
71
+
72
+
73
+ class ConfigurationError(GoogleAuthError):
74
+ """Raised when the service is not properly configured."""
75
+ pass
76
+
77
+
78
+ class GoogleAuthService:
79
+ """
80
+ Service for verifying Google OAuth ID tokens.
81
+
82
+ This service validates ID tokens issued by Google Sign-In and extracts
83
+ user information. It's designed to be modular and reusable across
84
+ different applications.
85
+
86
+ Example:
87
+ service = GoogleAuthService()
88
+ try:
89
+ user_info = service.verify_token(token_from_frontend)
90
+ print(f"Welcome {user_info.name}!")
91
+ except InvalidTokenError:
92
+ print("Invalid or expired token")
93
+ """
94
+
95
+ def __init__(
96
+ self,
97
+ client_id: Optional[str] = None,
98
+ clock_skew_seconds: int = 0
99
+ ):
100
+ """
101
+ Initialize the Google Auth Service.
102
+
103
+ Args:
104
+ client_id: Google OAuth 2.0 Client ID. If not provided,
105
+ falls back to GOOGLE_CLIENT_ID environment variable.
106
+ clock_skew_seconds: Allowed clock skew in seconds for token
107
+ validation (default: 0).
108
+
109
+ Raises:
110
+ ConfigurationError: If no client_id is provided or found.
111
+ """
112
+ self.client_id = client_id or os.getenv("AUTH_SIGN_IN_GOOGLE_CLIENT_ID")
113
+ self.clock_skew_seconds = clock_skew_seconds
114
+
115
+ if not self.client_id:
116
+ raise ConfigurationError(
117
+ "Google Client ID is required. Either pass client_id parameter "
118
+ "or set GOOGLE_CLIENT_ID environment variable."
119
+ )
120
+
121
+ logger.info(f"GoogleAuthService initialized with client_id: {self.client_id[:20]}...")
122
+
123
+ def verify_token(self, id_token: str) -> GoogleUserInfo:
124
+ """
125
+ Verify a Google ID token and extract user information.
126
+
127
+ Args:
128
+ id_token: The ID token received from the frontend after
129
+ Google Sign-In.
130
+
131
+ Returns:
132
+ GoogleUserInfo: Dataclass containing user's Google profile info.
133
+
134
+ Raises:
135
+ InvalidTokenError: If the token is invalid, expired, or
136
+ doesn't match the expected client ID.
137
+ """
138
+ if not id_token:
139
+ raise InvalidTokenError("Token cannot be empty")
140
+
141
+ try:
142
+ # Verify the token with Google
143
+ idinfo = google_id_token.verify_oauth2_token(
144
+ id_token,
145
+ google_requests.Request(),
146
+ self.client_id,
147
+ clock_skew_in_seconds=self.clock_skew_seconds
148
+ )
149
+
150
+ # Validate issuer
151
+ if idinfo.get("iss") not in ["accounts.google.com", "https://accounts.google.com"]:
152
+ raise InvalidTokenError("Invalid token issuer")
153
+
154
+ # Validate audience
155
+ if idinfo.get("aud") != self.client_id:
156
+ raise InvalidTokenError("Token was not issued for this application")
157
+
158
+ # Extract user info
159
+ return GoogleUserInfo(
160
+ google_id=idinfo["sub"],
161
+ email=idinfo["email"],
162
+ email_verified=idinfo.get("email_verified", False),
163
+ name=idinfo.get("name"),
164
+ picture=idinfo.get("picture"),
165
+ given_name=idinfo.get("given_name"),
166
+ family_name=idinfo.get("family_name"),
167
+ locale=idinfo.get("locale")
168
+ )
169
+
170
+ except ValueError as e:
171
+ logger.warning(f"Token verification failed: {e}")
172
+ raise InvalidTokenError(f"Token verification failed: {str(e)}")
173
+ except Exception as e:
174
+ logger.error(f"Unexpected error during token verification: {e}")
175
+ raise InvalidTokenError(f"Token verification error: {str(e)}")
176
+
177
+ def verify_token_safe(self, id_token: str) -> Optional[GoogleUserInfo]:
178
+ """
179
+ Verify a Google ID token without raising exceptions.
180
+
181
+ Useful for cases where you want to check validity without
182
+ exception handling.
183
+
184
+ Args:
185
+ id_token: The ID token to verify.
186
+
187
+ Returns:
188
+ GoogleUserInfo if valid, None if invalid.
189
+ """
190
+ try:
191
+ return self.verify_token(id_token)
192
+ except GoogleAuthError:
193
+ return None
194
+
195
+
196
+ # Singleton instance for convenience (initialized on first use)
197
+ _default_service: Optional[GoogleAuthService] = None
198
+
199
+
200
+ def get_google_auth_service() -> GoogleAuthService:
201
+ """
202
+ Get the default GoogleAuthService instance.
203
+
204
+ Creates a singleton instance using environment variables.
205
+
206
+ Returns:
207
+ GoogleAuthService: The default service instance.
208
+
209
+ Raises:
210
+ ConfigurationError: If GOOGLE_CLIENT_ID is not set.
211
+ """
212
+ global _default_service
213
+ if _default_service is None:
214
+ _default_service = GoogleAuthService()
215
+ return _default_service
216
+
217
+
218
+ def verify_google_token(id_token: str) -> GoogleUserInfo:
219
+ """
220
+ Convenience function to verify a token using the default service.
221
+
222
+ Args:
223
+ id_token: The Google ID token to verify.
224
+
225
+ Returns:
226
+ GoogleUserInfo: Verified user information.
227
+
228
+ Raises:
229
+ InvalidTokenError: If verification fails.
230
+ ConfigurationError: If service is not configured.
231
+ """
232
+ return get_google_auth_service().verify_token(id_token)
services/auth_service/jwt_provider.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modular JWT Service
3
+
4
+ A self-contained, plug-and-play service for creating and verifying JWT tokens.
5
+ Can be used in any Python application with minimal configuration.
6
+
7
+ Usage:
8
+ from services.jwt_service import JWTService, TokenPayload
9
+
10
+ # Initialize with secret key
11
+ jwt_service = JWTService(secret_key="your-secret-key")
12
+
13
+ # Or use environment variable JWT_SECRET
14
+ jwt_service = JWTService()
15
+
16
+ # Create a token
17
+ token = jwt_service.create_token(user_id="user123", email="user@example.com")
18
+
19
+ # Verify a token
20
+ payload = jwt_service.verify_token(token)
21
+ print(payload.user_id, payload.email)
22
+
23
+ Environment Variables:
24
+ JWT_SECRET: Your secret key for signing tokens (required)
25
+ JWT_EXPIRY_HOURS: Token expiry in hours (default: 168 = 7 days)
26
+ JWT_ALGORITHM: Algorithm to use (default: HS256)
27
+
28
+ Dependencies:
29
+ PyJWT>=2.8.0
30
+
31
+ Generate a secure secret:
32
+ python -c "import secrets; print(secrets.token_urlsafe(64))"
33
+ """
34
+
35
+ import os
36
+ import logging
37
+ from dataclasses import dataclass
38
+ from datetime import datetime, timedelta
39
+ from typing import Optional, Dict, Any
40
+ import jwt
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ @dataclass
46
+ class TokenPayload:
47
+ """
48
+ Payload extracted from a verified JWT token.
49
+
50
+ Attributes:
51
+ user_id: The user's unique identifier (sub claim)
52
+ email: The user's email address
53
+ issued_at: When the token was issued
54
+ expires_at: When the token expires
55
+ token_version: Version number for token invalidation
56
+ extra: Any additional claims in the token
57
+ """
58
+ user_id: str
59
+ email: str
60
+ issued_at: datetime
61
+ expires_at: datetime
62
+ token_version: int = 1
63
+ extra: Dict[str, Any] = None
64
+
65
+ def __post_init__(self):
66
+ if self.extra is None:
67
+ self.extra = {}
68
+
69
+ @property
70
+ def is_expired(self) -> bool:
71
+ """Check if the token has expired."""
72
+ return datetime.utcnow() > self.expires_at
73
+
74
+ @property
75
+ def time_until_expiry(self) -> timedelta:
76
+ """Get time remaining until expiry."""
77
+ return self.expires_at - datetime.utcnow()
78
+
79
+
80
+ class JWTError(Exception):
81
+ """Base exception for JWT errors."""
82
+ pass
83
+
84
+
85
+ class TokenExpiredError(JWTError):
86
+ """Raised when the token has expired."""
87
+ pass
88
+
89
+
90
+ class InvalidTokenError(JWTError):
91
+ """Raised when the token is invalid."""
92
+ pass
93
+
94
+
95
+ class ConfigurationError(JWTError):
96
+ """Raised when the service is not properly configured."""
97
+ pass
98
+
99
+
100
+ class JWTService:
101
+ """
102
+ Service for creating and verifying JWT tokens.
103
+
104
+ This service handles JWT token lifecycle for authentication.
105
+ It's designed to be modular and reusable across different applications.
106
+
107
+ Example:
108
+ service = JWTService(secret_key="my-secret")
109
+
110
+ # Create token
111
+ token = service.create_token(user_id="u123", email="a@b.com")
112
+
113
+ # Verify token
114
+ try:
115
+ payload = service.verify_token(token)
116
+ print(f"User: {payload.user_id}")
117
+ except TokenExpiredError:
118
+ print("Token expired, please login again")
119
+ except InvalidTokenError:
120
+ print("Invalid token")
121
+ """
122
+
123
+ # Default configuration
124
+ DEFAULT_ALGORITHM = "HS256"
125
+ DEFAULT_EXPIRY_HOURS = 168 # 7 days
126
+
127
+ def __init__(
128
+ self,
129
+ secret_key: Optional[str] = None,
130
+ algorithm: Optional[str] = None,
131
+ expiry_hours: Optional[int] = None
132
+ ):
133
+ """
134
+ Initialize the JWT Service.
135
+
136
+ Args:
137
+ secret_key: Secret key for signing tokens. If not provided,
138
+ falls back to JWT_SECRET environment variable.
139
+ algorithm: JWT algorithm (default: HS256).
140
+ expiry_hours: Token expiry in hours (default: 168 = 7 days).
141
+
142
+ Raises:
143
+ ConfigurationError: If no secret_key is provided or found.
144
+ """
145
+ self.secret_key = secret_key or os.getenv("JWT_SECRET")
146
+ self.algorithm = algorithm or os.getenv("JWT_ALGORITHM", self.DEFAULT_ALGORITHM)
147
+ self.expiry_hours = expiry_hours or int(
148
+ os.getenv("JWT_EXPIRY_HOURS", str(self.DEFAULT_EXPIRY_HOURS))
149
+ )
150
+
151
+ if not self.secret_key:
152
+ raise ConfigurationError(
153
+ "JWT secret key is required. Either pass secret_key parameter "
154
+ "or set JWT_SECRET environment variable. "
155
+ "Generate one with: python -c \"import secrets; print(secrets.token_urlsafe(64))\""
156
+ )
157
+
158
+ # Warn if secret is too short
159
+ if len(self.secret_key) < 32:
160
+ logger.warning(
161
+ "JWT secret key is short (< 32 chars). "
162
+ "Consider using a longer secret for better security."
163
+ )
164
+
165
+ logger.info(
166
+ f"JWTService initialized (algorithm={self.algorithm}, "
167
+ f"expiry={self.expiry_hours}h)"
168
+ )
169
+
170
+ def create_token(
171
+ self,
172
+ user_id: str,
173
+ email: str,
174
+ token_version: int = 1,
175
+ extra_claims: Optional[Dict[str, Any]] = None,
176
+ expiry_hours: Optional[int] = None
177
+ ) -> str:
178
+ """
179
+ Create a JWT token for a user.
180
+
181
+ Args:
182
+ user_id: The user's unique identifier.
183
+ email: The user's email address.
184
+ token_version: User's current token version for invalidation.
185
+ extra_claims: Additional claims to include in the token.
186
+ expiry_hours: Custom expiry for this token (overrides default).
187
+
188
+ Returns:
189
+ str: The encoded JWT token.
190
+ """
191
+ now = datetime.utcnow()
192
+ expiry = expiry_hours or self.expiry_hours
193
+
194
+ payload = {
195
+ "sub": user_id,
196
+ "email": email,
197
+ "tv": token_version, # Token version for invalidation
198
+ "iat": now,
199
+ "exp": now + timedelta(hours=expiry),
200
+ }
201
+
202
+ if extra_claims:
203
+ payload.update(extra_claims)
204
+
205
+ token = jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
206
+
207
+ logger.debug(f"Created token for user_id={user_id} (version={token_version})")
208
+ return token
209
+
210
+ def verify_token(self, token: str) -> TokenPayload:
211
+ """
212
+ Verify a JWT token and extract the payload.
213
+
214
+ Args:
215
+ token: The JWT token to verify.
216
+
217
+ Returns:
218
+ TokenPayload: Dataclass containing the verified payload.
219
+
220
+ Raises:
221
+ TokenExpiredError: If the token has expired.
222
+ InvalidTokenError: If the token is invalid or malformed.
223
+ """
224
+ if not token:
225
+ raise InvalidTokenError("Token cannot be empty")
226
+
227
+ try:
228
+ payload = jwt.decode(
229
+ token,
230
+ self.secret_key,
231
+ algorithms=[self.algorithm]
232
+ )
233
+
234
+ # Extract standard claims
235
+ user_id = payload.get("sub")
236
+ email = payload.get("email")
237
+ token_version = payload.get("tv", 1) # Default to 1 for backward compatibility
238
+ iat = payload.get("iat")
239
+ exp = payload.get("exp")
240
+
241
+ if not user_id or not email:
242
+ raise InvalidTokenError("Token missing required claims (sub, email)")
243
+
244
+ # Convert timestamps to datetime
245
+ issued_at = datetime.utcfromtimestamp(iat) if isinstance(iat, (int, float)) else iat
246
+ expires_at = datetime.utcfromtimestamp(exp) if isinstance(exp, (int, float)) else exp
247
+
248
+ # Extract extra claims
249
+ standard_claims = {"sub", "email", "tv", "iat", "exp"}
250
+ extra = {k: v for k, v in payload.items() if k not in standard_claims}
251
+
252
+ return TokenPayload(
253
+ user_id=user_id,
254
+ email=email,
255
+ issued_at=issued_at,
256
+ expires_at=expires_at,
257
+ token_version=token_version,
258
+ extra=extra
259
+ )
260
+
261
+ except jwt.ExpiredSignatureError:
262
+ logger.debug("Token verification failed: expired")
263
+ raise TokenExpiredError("Token has expired")
264
+ except jwt.InvalidTokenError as e:
265
+ logger.debug(f"Token verification failed: {e}")
266
+ raise InvalidTokenError(f"Invalid token: {str(e)}")
267
+ except Exception as e:
268
+ logger.error(f"Unexpected error during token verification: {e}")
269
+ raise InvalidTokenError(f"Token verification error: {str(e)}")
270
+
271
+ def verify_token_safe(self, token: str) -> Optional[TokenPayload]:
272
+ """
273
+ Verify a JWT token without raising exceptions.
274
+
275
+ Args:
276
+ token: The JWT token to verify.
277
+
278
+ Returns:
279
+ TokenPayload if valid, None if invalid or expired.
280
+ """
281
+ try:
282
+ return self.verify_token(token)
283
+ except JWTError:
284
+ return None
285
+
286
+ def refresh_token(
287
+ self,
288
+ token: str,
289
+ expiry_hours: Optional[int] = None
290
+ ) -> str:
291
+ """
292
+ Refresh a token by creating a new one with the same claims.
293
+
294
+ Args:
295
+ token: The current (possibly expired) token.
296
+ expiry_hours: Custom expiry for the new token.
297
+
298
+ Returns:
299
+ str: A new JWT token with updated expiry.
300
+
301
+ Raises:
302
+ InvalidTokenError: If the token is malformed.
303
+ """
304
+ try:
305
+ # Decode without verifying expiry
306
+ payload = jwt.decode(
307
+ token,
308
+ self.secret_key,
309
+ algorithms=[self.algorithm],
310
+ options={"verify_exp": False}
311
+ )
312
+
313
+ user_id = payload.get("sub")
314
+ email = payload.get("email")
315
+
316
+ if not user_id or not email:
317
+ raise InvalidTokenError("Token missing required claims")
318
+
319
+ # Preserve extra claims
320
+ standard_claims = {"sub", "email", "iat", "exp"}
321
+ extra = {k: v for k, v in payload.items() if k not in standard_claims}
322
+
323
+ return self.create_token(
324
+ user_id=user_id,
325
+ email=email,
326
+ extra_claims=extra,
327
+ expiry_hours=expiry_hours
328
+ )
329
+
330
+ except jwt.InvalidTokenError as e:
331
+ raise InvalidTokenError(f"Cannot refresh invalid token: {str(e)}")
332
+
333
+
334
+ # Singleton instance for convenience
335
+ _default_service: Optional[JWTService] = None
336
+
337
+
338
+ def get_jwt_service() -> JWTService:
339
+ """
340
+ Get the default JWTService instance.
341
+
342
+ Creates a singleton instance using environment variables.
343
+
344
+ Returns:
345
+ JWTService: The default service instance.
346
+
347
+ Raises:
348
+ ConfigurationError: If JWT_SECRET is not set.
349
+ """
350
+ global _default_service
351
+ if _default_service is None:
352
+ _default_service = JWTService()
353
+ return _default_service
354
+
355
+
356
+ def create_access_token(user_id: str, email: str, token_version: int = 1, **kwargs) -> str:
357
+ """
358
+ Convenience function to create a token using the default service.
359
+
360
+ Args:
361
+ user_id: The user's unique identifier.
362
+ email: The user's email address.
363
+ token_version: User's current token version for invalidation.
364
+ **kwargs: Additional arguments passed to create_token.
365
+
366
+ Returns:
367
+ str: The encoded JWT token.
368
+ """
369
+ return get_jwt_service().create_token(user_id, email, token_version, **kwargs)
370
+
371
+
372
+ def verify_access_token(token: str) -> TokenPayload:
373
+ """
374
+ Convenience function to verify a token using the default service.
375
+
376
+ Args:
377
+ token: The JWT token to verify.
378
+
379
+ Returns:
380
+ TokenPayload: Verified token payload.
381
+
382
+ Raises:
383
+ TokenExpiredError: If the token has expired.
384
+ InvalidTokenError: If the token is invalid.
385
+ """
386
+ return get_jwt_service().verify_token(token)
services/auth_service/middleware.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Auth Middleware - Request authentication layer
3
+
4
+ Intercepts requests to validate JWT tokens and attach authenticated
5
+ user to request.state for use in route handlers.
6
+ """
7
+
8
+ import logging
9
+ from fastapi import Request, HTTPException, status
10
+ from fastapi.responses import JSONResponse
11
+ from sqlalchemy import select
12
+ from sqlalchemy.ext.asyncio import AsyncSession
13
+ from starlette.middleware.base import BaseHTTPMiddleware
14
+
15
+ from core.database import async_session_maker
16
+ from core.models import User
17
+ from services.auth_service.config import AuthServiceConfig
18
+ from services.auth_service.jwt_provider import (
19
+ verify_access_token,
20
+ TokenExpiredError,
21
+ InvalidTokenError,
22
+ JWTError,
23
+ )
24
+ from services.base_service.middleware_chain import (
25
+ BaseServiceMiddleware,
26
+ get_request_context,
27
+ )
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ class AuthMiddleware(BaseServiceMiddleware):
33
+ """
34
+ Authentication middleware for request validation.
35
+
36
+ Flow:
37
+ 1. Check if route requires/allows auth based on URL
38
+ 2. Extract Authorization header
39
+ 3. Verify JWT token
40
+ 4. Load user from database
41
+ 5. Attach user to request.state.user
42
+ 6. Continue to next middleware/route
43
+
44
+ Public routes skip all auth checks.
45
+ Required routes must have valid auth or return 401.
46
+ Optional routes attach user if auth is provided, but don't fail if missing.
47
+ """
48
+
49
+ SERVICE_NAME = "auth"
50
+
51
+ async def dispatch(self, request: Request, call_next):
52
+ """Process request through auth middleware."""
53
+ # Skip OPTIONS requests (CORS preflight)
54
+ if request.method == "OPTIONS":
55
+ return await call_next(request)
56
+
57
+ # Initialize request context
58
+ ctx = get_request_context(request)
59
+
60
+ # Get path and method from request
61
+ path = request.url.path
62
+
63
+ # Check if route is public (skip all auth)
64
+ if AuthServiceConfig.is_public(path):
65
+ self.log_request(request, "Public route, skipping auth")
66
+ request.state.user = None
67
+ ctx.user = None
68
+ ctx.is_authenticated = False
69
+ response = await call_next(request)
70
+ return response
71
+
72
+ # Check if route requires auth or allows optional auth
73
+ requires_auth = AuthServiceConfig.requires_auth(path)
74
+ allows_optional = AuthServiceConfig.allows_optional_auth(path)
75
+
76
+ # If route doesn't require auth and doesn't allow optional, skip
77
+ if not requires_auth and not allows_optional:
78
+ self.log_request(request, "Route not configured for auth, skipping")
79
+ request.state.user = None
80
+ ctx.user = None
81
+ ctx.is_authenticated = False
82
+ response = await call_next(request)
83
+ return response
84
+
85
+ # Extract Authorization header
86
+ auth_header = request.headers.get("Authorization")
87
+
88
+ # If no auth header
89
+ if not auth_header:
90
+ if requires_auth:
91
+ self.log_request(request, "Missing Authorization header (required)")
92
+ return JSONResponse(
93
+ status_code=status.HTTP_401_UNAUTHORIZED,
94
+ content={"detail": "Missing Authorization header"},
95
+ headers={"WWW-Authenticate": "Bearer"},
96
+ )
97
+ else:
98
+ # Optional auth, no header provided
99
+ self.log_request(request, "No auth header (optional route)")
100
+ request.state.user = None
101
+ ctx.user = None
102
+ ctx.is_authenticated = False
103
+ response = await call_next(request)
104
+ return response
105
+
106
+ # Validate Authorization header format
107
+ if not auth_header.startswith("Bearer "):
108
+ if requires_auth:
109
+ self.log_request(request, "Invalid Authorization header format")
110
+ return JSONResponse(
111
+ status_code=status.HTTP_401_UNAUTHORIZED,
112
+ content={"detail": "Invalid Authorization header format. Use: Bearer <token>"},
113
+ headers={"WWW-Authenticate": "Bearer"},
114
+ )
115
+ else:
116
+ # Optional auth, invalid format
117
+ request.state.user = None
118
+ ctx.user = None
119
+ ctx.is_authenticated = False
120
+ response = await call_next(request)
121
+ return response
122
+
123
+ # Extract token
124
+ token = auth_header.split(" ", 1)[1]
125
+
126
+ # Verify token
127
+ try:
128
+ payload = verify_access_token(token)
129
+ except TokenExpiredError:
130
+ if requires_auth:
131
+ self.log_request(request, "Token expired")
132
+ return JSONResponse(
133
+ status_code=status.HTTP_401_UNAUTHORIZED,
134
+ content={"detail": "Token has expired. Please sign in again."},
135
+ headers={"WWW-Authenticate": "Bearer"},
136
+ )
137
+ else:
138
+ # Optional auth, expired token
139
+ request.state.user = None
140
+ ctx.user = None
141
+ ctx.is_authenticated = False
142
+ response = await call_next(request)
143
+ return response
144
+ except (InvalidTokenError, JWTError) as e:
145
+ if requires_auth:
146
+ self.log_error(request, f"Token verification failed: {e}")
147
+ return JSONResponse(
148
+ status_code=status.HTTP_401_UNAUTHORIZED,
149
+ content={"detail": f"Invalid token: {str(e)}"},
150
+ headers={"WWW-Authenticate": "Bearer"},
151
+ )
152
+ else:
153
+ # Optional auth, invalid token
154
+ request.state.user = None
155
+ ctx.user = None
156
+ ctx.is_authenticated = False
157
+ response = await call_next(request)
158
+ return response
159
+
160
+ # Get database session
161
+ async with async_session_maker() as db:
162
+ try:
163
+ # Load user from database
164
+ query = select(User).where(
165
+ User.user_id == payload.user_id,
166
+ User.is_active == True
167
+ )
168
+ result = await db.execute(query)
169
+ user = result.scalar_one_or_none()
170
+
171
+ if not user:
172
+ if requires_auth:
173
+ self.log_request(request, "User not found or inactive")
174
+ return JSONResponse(
175
+ status_code=status.HTTP_401_UNAUTHORIZED,
176
+ content={"detail": "User not found or inactive"},
177
+ )
178
+ else:
179
+ # Optional auth, user not found
180
+ request.state.user = None
181
+ ctx.user = None
182
+ ctx.is_authenticated = False
183
+ response = await call_next(request)
184
+ return response
185
+
186
+ # Validate token version
187
+ if payload.token_version < user.token_version:
188
+ if requires_auth:
189
+ self.log_request(
190
+ request,
191
+ f"Token invalidated (version {payload.token_version} < {user.token_version})"
192
+ )
193
+ return JSONResponse(
194
+ status_code=status.HTTP_401_UNAUTHORIZED,
195
+ content={"detail": "Token has been invalidated. Please sign in again."},
196
+ headers={"WWW-Authenticate": "Bearer"},
197
+ )
198
+ else:
199
+ # Optional auth, invalidated token
200
+ request.state.user = None
201
+ ctx.user = None
202
+ ctx.is_authenticated = False
203
+ response = await call_next(request)
204
+ return response
205
+
206
+ # Attach user to request state
207
+ request.state.user = user
208
+ ctx.set_user(user)
209
+
210
+ # Check if user is admin
211
+ is_admin = AuthServiceConfig.is_admin(user.email)
212
+ request.state.is_admin = is_admin
213
+ ctx.set_flag('is_admin', is_admin)
214
+
215
+ self.log_request(request, f"Authenticated user: {user.user_id}")
216
+
217
+ # Continue to next middleware/route
218
+ response = await call_next(request)
219
+ return response
220
+
221
+ finally:
222
+ await db.close()
223
+
224
+
225
+ __all__ = ['AuthMiddleware']
services/base_service/__init__.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base Service Infrastructure
3
+
4
+ Provides the foundation for plug-and-play services in the API gateway.
5
+ All services (auth, credit, gemini, etc.) extend this base infrastructure.
6
+
7
+ Core Components:
8
+ - BaseService: Abstract base class for all services
9
+ - ServiceConfig: Configuration container
10
+ - ServiceRegistry: Global registry for service discovery
11
+ - MiddlewareProtocol: Type definition for middleware functions
12
+
13
+ Usage:
14
+ class MyService(BaseService):
15
+ @classmethod
16
+ def register(cls, **config):
17
+ # Service-specific registration
18
+ pass
19
+
20
+ @classmethod
21
+ def get_middleware(cls):
22
+ # Return middleware function if needed
23
+ return MyMiddleware()
24
+ """
25
+
26
+ import logging
27
+ from abc import ABC, abstractmethod
28
+ from typing import Dict, Type, Optional, Callable, Any
29
+ from starlette.middleware.base import BaseHTTPMiddleware
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ class ServiceConfig:
35
+ """
36
+ Base configuration container for services.
37
+
38
+ Services can extend this to add their specific configuration.
39
+ """
40
+
41
+ def __init__(self, **kwargs):
42
+ """Initialize configuration with arbitrary key-value pairs."""
43
+ self._config = kwargs
44
+
45
+ def get(self, key: str, default: Any = None) -> Any:
46
+ """Get configuration value."""
47
+ return self._config.get(key, default)
48
+
49
+ def set(self, key: str, value: Any) -> None:
50
+ """Set configuration value."""
51
+ self._config[key] = value
52
+
53
+ def __getitem__(self, key: str) -> Any:
54
+ """Dictionary-style access."""
55
+ return self._config[key]
56
+
57
+ def __setitem__(self, key: str, value: Any) -> None:
58
+ """Dictionary-style assignment."""
59
+ self._config[key] = value
60
+
61
+ def __contains__(self, key: str) -> bool:
62
+ """Check if key exists."""
63
+ return key in self._config
64
+
65
+
66
+ class BaseService(ABC):
67
+ """
68
+ Abstract base class for all plug-and-play services.
69
+
70
+ Services must implement:
71
+ - register(): Register service configuration at startup
72
+ - get_middleware(): Return middleware if service needs request interception
73
+ - on_shutdown(): Cleanup on app shutdown
74
+ """
75
+
76
+ # Service name (override in subclass)
77
+ SERVICE_NAME: str = "base_service"
78
+
79
+ # Service configuration
80
+ _config: Optional[ServiceConfig] = None
81
+
82
+ # Registration state
83
+ _registered: bool = False
84
+
85
+ @classmethod
86
+ @abstractmethod
87
+ def register(cls, **config) -> None:
88
+ """
89
+ Register service configuration at application startup.
90
+
91
+ Args:
92
+ **config: Service-specific configuration parameters
93
+
94
+ Raises:
95
+ RuntimeError: If service is already registered
96
+ """
97
+ if cls._registered:
98
+ raise RuntimeError(f"{cls.SERVICE_NAME} is already registered")
99
+
100
+ cls._config = ServiceConfig(**config)
101
+ cls._registered = True
102
+
103
+ logger.info(f"✅ {cls.SERVICE_NAME} registered successfully")
104
+
105
+ @classmethod
106
+ def get_middleware(cls) -> Optional[BaseHTTPMiddleware]:
107
+ """
108
+ Return middleware instance if service needs request interception.
109
+
110
+ Returns:
111
+ Middleware instance or None if service doesn't need middleware
112
+ """
113
+ return None
114
+
115
+ @classmethod
116
+ def on_shutdown(cls) -> None:
117
+ """
118
+ Cleanup hook called during application shutdown.
119
+
120
+ Override this to perform cleanup (close connections, save state, etc.)
121
+ """
122
+ pass
123
+
124
+ @classmethod
125
+ def is_registered(cls) -> bool:
126
+ """Check if service has been registered."""
127
+ return cls._registered
128
+
129
+ @classmethod
130
+ def assert_registered(cls) -> None:
131
+ """
132
+ Assert that service has been registered.
133
+
134
+ Raises:
135
+ RuntimeError: If service is not registered
136
+ """
137
+ if not cls._registered:
138
+ raise RuntimeError(
139
+ f"{cls.SERVICE_NAME} is not registered. "
140
+ f"Call {cls.SERVICE_NAME}.register() at application startup."
141
+ )
142
+
143
+ @classmethod
144
+ def get_config(cls) -> ServiceConfig:
145
+ """
146
+ Get service configuration.
147
+
148
+ Returns:
149
+ ServiceConfig instance
150
+
151
+ Raises:
152
+ RuntimeError: If service is not registered
153
+ """
154
+ cls.assert_registered()
155
+ return cls._config
156
+
157
+
158
+ class ServiceRegistry:
159
+ """
160
+ Global registry for service discovery and management.
161
+
162
+ Tracks all registered services and provides lookup functionality.
163
+ """
164
+
165
+ _services: Dict[str, Type[BaseService]] = {}
166
+
167
+ @classmethod
168
+ def register_service(cls, service_class: Type[BaseService]) -> None:
169
+ """
170
+ Register a service class in the global registry.
171
+
172
+ Args:
173
+ service_class: Service class to register
174
+ """
175
+ service_name = service_class.SERVICE_NAME
176
+
177
+ if service_name in cls._services:
178
+ logger.warning(f"Service '{service_name}' already registered, overwriting")
179
+
180
+ cls._services[service_name] = service_class
181
+ logger.debug(f"Registered service: {service_name}")
182
+
183
+ @classmethod
184
+ def get_service(cls, service_name: str) -> Optional[Type[BaseService]]:
185
+ """
186
+ Get a service class by name.
187
+
188
+ Args:
189
+ service_name: Name of the service to retrieve
190
+
191
+ Returns:
192
+ Service class or None if not found
193
+ """
194
+ return cls._services.get(service_name)
195
+
196
+ @classmethod
197
+ def get_all_services(cls) -> Dict[str, Type[BaseService]]:
198
+ """
199
+ Get all registered services.
200
+
201
+ Returns:
202
+ Dictionary mapping service names to service classes
203
+ """
204
+ return cls._services.copy()
205
+
206
+ @classmethod
207
+ def get_all_middleware(cls) -> list:
208
+ """
209
+ Get middleware from all registered services.
210
+
211
+ Returns:
212
+ List of middleware instances in registration order
213
+ """
214
+ middleware_list = []
215
+
216
+ for service_name, service_class in cls._services.items():
217
+ if service_class.is_registered():
218
+ middleware = service_class.get_middleware()
219
+ if middleware:
220
+ middleware_list.append(middleware)
221
+ logger.debug(f"Added middleware from service: {service_name}")
222
+
223
+ return middleware_list
224
+
225
+ @classmethod
226
+ def shutdown_all(cls) -> None:
227
+ """
228
+ Call shutdown hooks for all registered services.
229
+ """
230
+ logger.info("Shutting down all services...")
231
+
232
+ for service_name, service_class in cls._services.items():
233
+ try:
234
+ service_class.on_shutdown()
235
+ logger.debug(f"Shutdown complete: {service_name}")
236
+ except Exception as e:
237
+ logger.error(f"Error shutting down {service_name}: {e}")
238
+
239
+ logger.info("All services shut down")
240
+
241
+
242
+ __all__ = [
243
+ 'BaseService',
244
+ 'ServiceConfig',
245
+ 'ServiceRegistry',
246
+ ]
services/base_service/middleware_chain.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Middleware Chain - Orchestration of multiple middleware layers.
3
+
4
+ Provides utilities for managing and coordinating multiple middleware
5
+ components in the request/response flow.
6
+
7
+ Usage:
8
+ # In app.py
9
+ from services.base_service import MiddlewareChain
10
+
11
+ # Add middleware in reverse order (last added = first executed)
12
+ app.add_middleware(CreditMiddleware)
13
+ app.add_middleware(AuthMiddleware)
14
+
15
+ # Or use the chain helper
16
+ chain = MiddlewareChain()
17
+ chain.add(AuthMiddleware)
18
+ chain.add(CreditMiddleware)
19
+ chain.apply_to_app(app)
20
+ """
21
+
22
+ import logging
23
+ from typing import List, Type, Callable
24
+ from starlette.middleware.base import BaseHTTPMiddleware
25
+ from fastapi import FastAPI, Request, Response
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ class RequestContext:
31
+ """
32
+ Shared context for passing data between middleware layers.
33
+
34
+ Attached to request.state for access across middleware and routers.
35
+ """
36
+
37
+ def __init__(self):
38
+ """Initialize empty context."""
39
+ # Auth layer
40
+ self.user = None
41
+ self.is_authenticated = False
42
+
43
+ # Credit layer
44
+ self.credits_reserved = 0
45
+ self.credit_cost = 0
46
+
47
+ # General
48
+ self.start_time = None
49
+ self.service_flags = {}
50
+
51
+ def set_user(self, user) -> None:
52
+ """Set authenticated user."""
53
+ self.user = user
54
+ self.is_authenticated = True
55
+
56
+ def set_credits(self, reserved: int, cost: int) -> None:
57
+ """Set credit information."""
58
+ self.credits_reserved = reserved
59
+ self.credit_cost = cost
60
+
61
+ def set_flag(self, key: str, value: any) -> None:
62
+ """Set a service-specific flag."""
63
+ self.service_flags[key] = value
64
+
65
+ def get_flag(self, key: str, default=None) -> any:
66
+ """Get a service-specific flag."""
67
+ return self.service_flags.get(key, default)
68
+
69
+
70
+ class MiddlewareChain:
71
+ """
72
+ Helper for managing middleware registration order.
73
+
74
+ FastAPI/Starlette middleware executes in REVERSE order of registration,
75
+ so the LAST middleware added is the FIRST to execute.
76
+
77
+ This class helps manage the order explicitly.
78
+ """
79
+
80
+ def __init__(self):
81
+ """Initialize empty middleware chain."""
82
+ self._middleware: List[Type[BaseHTTPMiddleware]] = []
83
+
84
+ def add(self, middleware_class: Type[BaseHTTPMiddleware], **kwargs) -> 'MiddlewareChain':
85
+ """
86
+ Add middleware to the chain.
87
+
88
+ Middleware is added to the END of the list, but will be registered
89
+ in REVERSE order (so first added = first executed).
90
+
91
+ Args:
92
+ middleware_class: Middleware class to add
93
+ **kwargs: Arguments to pass to middleware constructor
94
+
95
+ Returns:
96
+ Self for chaining
97
+ """
98
+ self._middleware.append((middleware_class, kwargs))
99
+ logger.debug(f"Added middleware to chain: {middleware_class.__name__}")
100
+ return self
101
+
102
+ def apply_to_app(self, app: FastAPI) -> None:
103
+ """
104
+ Apply all middleware to the FastAPI app in correct order.
105
+
106
+ Middleware is registered in REVERSE order so that the first
107
+ middleware added to the chain is the first to execute.
108
+
109
+ Args:
110
+ app: FastAPI application instance
111
+ """
112
+ # Reverse the list so first added = first executed
113
+ for middleware_class, kwargs in reversed(self._middleware):
114
+ app.add_middleware(middleware_class, **kwargs)
115
+ logger.info(f"Registered middleware: {middleware_class.__name__}")
116
+
117
+ def get_middleware_list(self) -> List[Type[BaseHTTPMiddleware]]:
118
+ """
119
+ Get the list of middleware in execution order.
120
+
121
+ Returns:
122
+ List of middleware classes in the order they will execute
123
+ """
124
+ return [m[0] for m in self._middleware]
125
+
126
+ def __len__(self) -> int:
127
+ """Get number of middleware in chain."""
128
+ return len(self._middleware)
129
+
130
+ def __repr__(self) -> str:
131
+ """String representation for debugging."""
132
+ middleware_names = [m[0].__name__ for m in self._middleware]
133
+ return f"MiddlewareChain({middleware_names})"
134
+
135
+
136
+ async def initialize_request_context(request: Request) -> None:
137
+ """
138
+ Initialize request context for middleware to use.
139
+
140
+ This should be called early in the middleware chain to ensure
141
+ request.state.ctx is available.
142
+
143
+ Usage:
144
+ class MyMiddleware(BaseHTTPMiddleware):
145
+ async def dispatch(self, request: Request, call_next):
146
+ await initialize_request_context(request)
147
+ # Now request.state.ctx is available
148
+ ...
149
+ """
150
+ if not hasattr(request.state, "ctx"):
151
+ request.state.ctx = RequestContext()
152
+
153
+
154
+ def get_request_context(request: Request) -> RequestContext:
155
+ """
156
+ Get request context from request.state.
157
+
158
+ Creates context if it doesn't exist.
159
+
160
+ Args:
161
+ request: FastAPI request object
162
+
163
+ Returns:
164
+ RequestContext instance
165
+ """
166
+ if not hasattr(request.state, "ctx"):
167
+ request.state.ctx = RequestContext()
168
+ return request.state.ctx
169
+
170
+
171
+ class BaseServiceMiddleware(BaseHTTPMiddleware):
172
+ """
173
+ Base class for service middleware.
174
+
175
+ Provides common functionality for all service middleware:
176
+ - Request context initialization
177
+ - Error handling
178
+ - Logging
179
+ """
180
+
181
+ SERVICE_NAME = "base"
182
+
183
+ async def dispatch(self, request: Request, call_next: Callable) -> Response:
184
+ """
185
+ Process request through middleware.
186
+
187
+ Override this in subclasses to implement service-specific logic.
188
+ """
189
+ # Initialize context
190
+ await initialize_request_context(request)
191
+
192
+ # Call next middleware/route
193
+ response = await call_next(request)
194
+
195
+ return response
196
+
197
+ def log_request(self, request: Request, message: str) -> None:
198
+ """Log request with service context."""
199
+ logger.info(f"[{self.SERVICE_NAME}] {request.method} {request.url.path} - {message}")
200
+
201
+ def log_error(self, request: Request, error: str) -> None:
202
+ """Log error with service context."""
203
+ logger.error(f"[{self.SERVICE_NAME}] {request.method} {request.url.path} - ERROR: {error}")
204
+
205
+
206
+ __all__ = [
207
+ 'MiddlewareChain',
208
+ 'RequestContext',
209
+ 'BaseServiceMiddleware',
210
+ 'initialize_request_context',
211
+ 'get_request_context',
212
+ ]
services/base_service/route_matcher.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Route Matcher - URL pattern matching for service configuration.
3
+
4
+ Provides flexible URL matching capabilities for services to define
5
+ which routes require auth, credits, etc.
6
+
7
+ Supported patterns:
8
+ - Exact match: "/api/users"
9
+ - Prefix match: "/api/*" (matches /api/anything)
10
+ - Wildcard match: "/api/users/*/posts" (matches /api/users/123/posts)
11
+ - Deep wildcard: "/api/**" (matches /api/users/123/posts/456)
12
+ - Regex match: "^/api/v[0-9]+/.*$"
13
+
14
+ Usage:
15
+ matcher = RouteMatcher(["/api/*", "/admin/**"])
16
+
17
+ if matcher.matches("/api/users"):
18
+ # Route requires auth
19
+ pass
20
+ """
21
+
22
+ import re
23
+ import logging
24
+ from typing import List, Set, Optional, Pattern
25
+ from fnmatch import fnmatch
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ class RouteMatcher:
31
+ """
32
+ Flexible URL pattern matcher for route configuration.
33
+
34
+ Supports exact matches, glob patterns, and regex patterns.
35
+ """
36
+
37
+ def __init__(self, patterns: List[str]):
38
+ """
39
+ Initialize route matcher with patterns.
40
+
41
+ Args:
42
+ patterns: List of URL patterns to match
43
+ """
44
+ self.patterns = patterns
45
+ self._exact_matches: Set[str] = set()
46
+ self._prefix_patterns: List[str] = []
47
+ self._glob_patterns: List[str] = []
48
+ self._regex_patterns: List[Pattern] = []
49
+
50
+ # Classify patterns for performance
51
+ self._classify_patterns()
52
+
53
+ def _classify_patterns(self) -> None:
54
+ """
55
+ Classify patterns by type for optimal matching performance.
56
+
57
+ Order of matching:
58
+ 1. Exact matches (fastest - O(1))
59
+ 2. Prefix patterns (fast - string startswith)
60
+ 3. Glob patterns (medium - fnmatch)
61
+ 4. Regex patterns (slowest - regex matching)
62
+ """
63
+ for pattern in self.patterns:
64
+ # Empty pattern
65
+ if not pattern:
66
+ continue
67
+
68
+ # Regex pattern (starts with ^)
69
+ if pattern.startswith("^"):
70
+ try:
71
+ compiled = re.compile(pattern)
72
+ self._regex_patterns.append(compiled)
73
+ logger.debug(f"Classified as regex: {pattern}")
74
+ except re.error as e:
75
+ logger.warning(f"Invalid regex pattern '{pattern}': {e}")
76
+ continue
77
+
78
+ # Glob pattern (contains * or ?)
79
+ if "*" in pattern or "?" in pattern:
80
+ # Simple prefix wildcard: /api/*
81
+ if pattern.endswith("/*") and "*" not in pattern[:-2]:
82
+ prefix = pattern[:-2] # Remove /*
83
+ self._prefix_patterns.append(prefix)
84
+ logger.debug(f"Classified as prefix: {prefix}")
85
+ else:
86
+ # Complex glob: /api/*/users or /api/**
87
+ self._glob_patterns.append(pattern)
88
+ logger.debug(f"Classified as glob: {pattern}")
89
+ continue
90
+
91
+ # Exact match
92
+ self._exact_matches.add(pattern)
93
+ logger.debug(f"Classified as exact: {pattern}")
94
+
95
+ def matches(self, path: str) -> bool:
96
+ """
97
+ Check if a URL path matches any configured pattern.
98
+
99
+ Args:
100
+ path: URL path to check (e.g., "/api/users/123")
101
+
102
+ Returns:
103
+ True if path matches any pattern, False otherwise
104
+ """
105
+ # Strip query parameters and fragments
106
+ path = path.split("?")[0].split("#")[0]
107
+
108
+ # Normalize path (remove trailing slash unless it's just "/")
109
+ if path != "/" and path.endswith("/"):
110
+ path = path.rstrip("/")
111
+
112
+ # 1. Exact match (O(1))
113
+ if path in self._exact_matches:
114
+ return True
115
+
116
+ # 2. Prefix match (O(n) but fast)
117
+ for prefix in self._prefix_patterns:
118
+ if path.startswith(prefix + "/") or path == prefix:
119
+ return True
120
+
121
+ # 3. Glob match (O(n))
122
+ for pattern in self._glob_patterns:
123
+ if fnmatch(path, pattern):
124
+ return True
125
+
126
+ # 4. Regex match (O(n) but slower)
127
+ for regex in self._regex_patterns:
128
+ if regex.match(path):
129
+ return True
130
+
131
+ return False
132
+
133
+ def get_matching_pattern(self, path: str) -> Optional[str]:
134
+ """
135
+ Get the first pattern that matches the given path.
136
+
137
+ Useful for debugging or determining which rule matched.
138
+
139
+ Args:
140
+ path: URL path to check
141
+
142
+ Returns:
143
+ Matching pattern string or None
144
+ """
145
+ # Strip query parameters and fragments
146
+ path = path.split("?")[0].split("#")[0]
147
+
148
+ # Normalize path
149
+ if path != "/" and path.endswith("/"):
150
+ path = path.rstrip("/")
151
+
152
+ # Exact match
153
+ if path in self._exact_matches:
154
+ return path
155
+
156
+ # Prefix match
157
+ for prefix in self._prefix_patterns:
158
+ if path.startswith(prefix + "/") or path == prefix:
159
+ return prefix + "/*"
160
+
161
+ # Glob match
162
+ for pattern in self._glob_patterns:
163
+ if fnmatch(path, pattern):
164
+ return pattern
165
+
166
+ # Regex match
167
+ for regex in self._regex_patterns:
168
+ if regex.match(path):
169
+ return regex.pattern
170
+
171
+ return None
172
+
173
+ def __repr__(self) -> str:
174
+ """String representation for debugging."""
175
+ return (
176
+ f"RouteMatcher("
177
+ f"exact={len(self._exact_matches)}, "
178
+ f"prefix={len(self._prefix_patterns)}, "
179
+ f"glob={len(self._glob_patterns)}, "
180
+ f"regex={len(self._regex_patterns)})"
181
+ )
182
+
183
+
184
+ class RouteConfig:
185
+ """
186
+ Route configuration helper for services.
187
+
188
+ Manages multiple route lists (required, optional, public) with
189
+ precedence and exclusion logic.
190
+ """
191
+
192
+ def __init__(
193
+ self,
194
+ required: List[str] = None,
195
+ optional: List[str] = None,
196
+ public: List[str] = None,
197
+ ):
198
+ """
199
+ Initialize route configuration.
200
+
201
+ Args:
202
+ required: Routes that REQUIRE the service (e.g., auth required)
203
+ optional: Routes where service is OPTIONAL (e.g., auth optional)
204
+ public: Routes that are PUBLIC (e.g., no auth needed)
205
+
206
+ Precedence: public > required > optional (for conflict resolution)
207
+ """
208
+ self.required_matcher = RouteMatcher(required or [])
209
+ self.optional_matcher = RouteMatcher(optional or [])
210
+ self.public_matcher = RouteMatcher(public or [])
211
+
212
+ def is_required(self, path: str) -> bool:
213
+ """
214
+ Check if service is REQUIRED for this path.
215
+
216
+ Returns False if path is public (public takes precedence).
217
+ """
218
+ if self.is_public(path):
219
+ return False
220
+ return self.required_matcher.matches(path)
221
+
222
+ def is_optional(self, path: str) -> bool:
223
+ """
224
+ Check if service is OPTIONAL for this path.
225
+
226
+ Returns False if path is public or required.
227
+ """
228
+ if self.is_public(path):
229
+ return False
230
+ if self.required_matcher.matches(path):
231
+ return False
232
+ return self.optional_matcher.matches(path)
233
+
234
+ def is_public(self, path: str) -> bool:
235
+ """
236
+ Check if path is PUBLIC (service not needed).
237
+
238
+ Public takes highest precedence.
239
+ """
240
+ return self.public_matcher.matches(path)
241
+
242
+ def requires_service(self, path: str) -> bool:
243
+ """
244
+ Check if service is needed (required OR optional) for this path.
245
+
246
+ Returns False if path is not matched by any configuration.
247
+ """
248
+ return self.is_required(path) or self.is_optional(path)
249
+
250
+
251
+ __all__ = [
252
+ 'RouteMatcher',
253
+ 'RouteConfig',
254
+ ]
services/credit_service/__init__.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Credit Service - Credit validation middleware for API Gateway
3
+
4
+ Provides plug-and-play credit management with:
5
+ - Per-route cost configuration
6
+ - Credit reservation and validation
7
+ - Request middleware for credit checks
8
+ - Automatic refund on errors
9
+
10
+ Usage:
11
+ # In app.py startup
12
+ from services.credit_service import register_credit_service
13
+
14
+ register_credit_service(
15
+ route_costs={
16
+ "/gemini/generate-animation-prompt": 1,
17
+ "/gemini/edit-image": 1,
18
+ "/gemini/generate-video": 10,
19
+ "/gemini/generate-text": 1,
20
+ "/gemini/analyze-image": 1,
21
+ }
22
+ )
23
+
24
+ # In routers
25
+ from fastapi import Request
26
+
27
+ @router.post("/api/endpoint")
28
+ async def endpoint(request: Request):
29
+ user = request.state.user # From AuthMiddleware
30
+ credits_reserved = request.state.credits_reserved # From CreditMiddleware
31
+ return {"credits_remaining": user.credits}
32
+ """
33
+
34
+ from services.credit_service.config import CreditServiceConfig
35
+ from services.credit_service.middleware import CreditMiddleware
36
+ from services.credit_service.credit_manager import (
37
+ reserve_credit,
38
+ confirm_credit,
39
+ refund_credit,
40
+ handle_job_completion,
41
+ is_refundable_error,
42
+ REFUNDABLE_ERROR_PATTERNS,
43
+ NON_REFUNDABLE_ERROR_PATTERNS,
44
+ )
45
+
46
+
47
+ def register_credit_service(
48
+ route_costs: dict = None,
49
+ ) -> None:
50
+ """
51
+ Register the credit service with application configuration.
52
+
53
+ Args:
54
+ route_costs: Dictionary mapping route paths to credit costs
55
+ Example: {"/gemini/generate-video": 10, "/gemini/edit-image": 1}
56
+ """
57
+ CreditServiceConfig.register(
58
+ route_costs=route_costs or {},
59
+ )
60
+
61
+
62
+ __all__ = [
63
+ # Registration
64
+ 'register_credit_service',
65
+ 'CreditServiceConfig',
66
+ 'CreditMiddleware',
67
+
68
+ # Credit Management
69
+ 'reserve_credit',
70
+ 'confirm_credit',
71
+ 'refund_credit',
72
+ 'handle_job_completion',
73
+ 'is_refundable_error',
74
+ 'REFUNDABLE_ERROR_PATTERNS',
75
+ 'NON_REFUNDABLE_ERROR_PATTERNS',
76
+ ]
services/credit_service/config.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Credit Service Configuration
3
+
4
+ Manages credit cost configuration for API routes.
5
+ """
6
+
7
+ import logging
8
+ from typing import Dict
9
+ from services.base_service import BaseService
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class CreditServiceConfig(BaseService):
15
+ """
16
+ Configuration for the credit service.
17
+
18
+ Controls which routes require credits and how much they cost.
19
+ """
20
+
21
+ SERVICE_NAME = "credit_service"
22
+
23
+ # Route cost configuration
24
+ _route_costs: Dict[str, int] = {}
25
+
26
+ @classmethod
27
+ def register(
28
+ cls,
29
+ route_costs: Dict[str, int] = None,
30
+ ) -> None:
31
+ """
32
+ Register credit service configuration.
33
+
34
+ Args:
35
+ route_costs: Dictionary mapping route paths to credit costs
36
+ Example: {"/gemini/generate-video": 10, "/gemini/edit-image": 1}
37
+
38
+ Raises:
39
+ RuntimeError: If service is already registered
40
+ """
41
+ if cls._registered:
42
+ raise RuntimeError(f"{cls.SERVICE_NAME} is already registered")
43
+
44
+ # Store route costs
45
+ cls._route_costs = route_costs or {}
46
+
47
+ cls._registered = True
48
+
49
+ logger.info(f"✅ {cls.SERVICE_NAME} registered successfully")
50
+ logger.info(f" Routes with credit costs: {len(cls._route_costs)}")
51
+ for route, cost in cls._route_costs.items():
52
+ logger.info(f" {route}: {cost} credits")
53
+
54
+ @classmethod
55
+ def get_middleware(cls):
56
+ """Return CreditMiddleware instance."""
57
+ from services.credit_service.middleware import CreditMiddleware
58
+ return CreditMiddleware
59
+
60
+ @classmethod
61
+ def get_cost(cls, path: str) -> int:
62
+ """
63
+ Get the credit cost for a given path.
64
+
65
+ Args:
66
+ path: URL path to check
67
+
68
+ Returns:
69
+ Credit cost (0 if route doesn't require credits)
70
+ """
71
+ cls.assert_registered()
72
+ return cls._route_costs.get(path, 0)
73
+
74
+ @classmethod
75
+ def requires_credits(cls, path: str) -> bool:
76
+ """Check if a URL path requires credits."""
77
+ cls.assert_registered()
78
+ return cls.get_cost(path) > 0
79
+
80
+ @classmethod
81
+ def get_all_costs(cls) -> Dict[str, int]:
82
+ """Get all route costs."""
83
+ cls.assert_registered()
84
+ return cls._route_costs.copy()
85
+
86
+
87
+ __all__ = ['CreditServiceConfig']
services/credit_service/credit_manager.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Credit Service - Manages credit reservation, confirmation, and refunding.
3
+
4
+ Implements the Credit Reservation Pattern:
5
+ 1. Reserve credits when job is created (deduct from user, track in job)
6
+ 2. Confirm credits only on successful completion
7
+ 3. Refund credits on refundable errors (server-side issues)
8
+ 4. Keep credits on non-refundable errors (user-caused issues)
9
+ """
10
+ import logging
11
+ from typing import Optional
12
+ from sqlalchemy.ext.asyncio import AsyncSession
13
+ from sqlalchemy import select
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ # =============================================================================
19
+ # Error Categories for Refund Decisions
20
+ # =============================================================================
21
+
22
+ # Refundable errors - User gets credits back (server/API issues)
23
+ REFUNDABLE_ERROR_PATTERNS = [
24
+ "API_KEY_INVALID",
25
+ "QUOTA_EXCEEDED",
26
+ "INTERNAL_ERROR",
27
+ "CONNECTION_FAILED",
28
+ "SERVER_SHUTDOWN",
29
+ "TIMEOUT",
30
+ "Server Authentication Error",
31
+ "Network error",
32
+ "Connection refused",
33
+ "Connection reset",
34
+ "Service unavailable",
35
+ "503",
36
+ "500",
37
+ "429", # Rate limit (our quota, not user's fault)
38
+ ]
39
+
40
+ # Non-refundable error patterns - User's input/content issue
41
+ NON_REFUNDABLE_ERROR_PATTERNS = [
42
+ "safety",
43
+ "blocked",
44
+ "SAFETY_FILTER",
45
+ "INVALID_INPUT",
46
+ "Invalid image",
47
+ "Bad request",
48
+ "400",
49
+ "cancelled",
50
+ "User cancelled",
51
+ ]
52
+
53
+
54
+ def is_refundable_error(error_message: Optional[str]) -> bool:
55
+ """
56
+ Determine if an error should result in a credit refund.
57
+
58
+ Args:
59
+ error_message: The error message from the failed job
60
+
61
+ Returns:
62
+ True if the error is refundable (server/API issue)
63
+ False if non-refundable (user's fault) or no error message
64
+ """
65
+ if not error_message:
66
+ return False
67
+
68
+ error_lower = error_message.lower()
69
+
70
+ # Check for REFUNDABLE patterns FIRST (specific server errors take precedence)
71
+ # This ensures API_KEY_INVALID is caught before generic "400" matcher
72
+ for pattern in REFUNDABLE_ERROR_PATTERNS:
73
+ if pattern.lower() in error_lower:
74
+ logger.debug(f"Error matched refundable pattern '{pattern}': {error_message[:100]}")
75
+ return True
76
+
77
+ # Check for non-refundable patterns (user-caused issues)
78
+ for pattern in NON_REFUNDABLE_ERROR_PATTERNS:
79
+ if pattern.lower() in error_lower:
80
+ logger.debug(f"Error matched non-refundable pattern '{pattern}': {error_message[:100]}")
81
+ return False
82
+
83
+ # Default: Max retries exceeded is refundable (we consumed API resources trying)
84
+ if "max retries" in error_lower:
85
+ return True
86
+
87
+ # Default: Unknown errors are NOT refundable to prevent abuse
88
+ # If it's an unknown error, it's more likely user-caused
89
+ logger.debug(f"Unknown error (not refundable): {error_message[:100]}")
90
+ return False
91
+
92
+
93
+ async def reserve_credit(session: AsyncSession, user, amount: int = 1) -> bool:
94
+ """
95
+ Reserve credits for a job (deduct from user's balance).
96
+
97
+ The credits are deducted but tracked in the job's credits_reserved field.
98
+ If the job fails with a refundable error, they can be restored.
99
+
100
+ Args:
101
+ session: Database session
102
+ user: User model instance
103
+ amount: Number of credits to reserve (default: 1)
104
+
105
+ Returns:
106
+ True if credits were successfully reserved
107
+ False if user has insufficient credits
108
+ """
109
+ if user.credits < amount:
110
+ logger.warning(f"User {user.user_id} has insufficient credits ({user.credits}) to reserve {amount}")
111
+ return False
112
+
113
+ user.credits -= amount
114
+ logger.info(f"Reserved {amount} credit(s) for user {user.user_id}. Remaining: {user.credits}")
115
+ # Note: Don't commit here - let caller handle transaction
116
+ return True
117
+
118
+
119
+ async def confirm_credit(session: AsyncSession, job) -> None:
120
+ """
121
+ Confirm that credits were legitimately used for a completed job.
122
+
123
+ This is called when a job completes successfully. The credits stay
124
+ deducted (they were already deducted during reservation).
125
+
126
+ Args:
127
+ session: Database session
128
+ job: GeminiJob model instance
129
+ """
130
+ if job.credits_reserved > 0:
131
+ # Credits were used - clear the reservation tracking
132
+ credits_used = job.credits_reserved
133
+ job.credits_reserved = 0
134
+ logger.info(f"Confirmed {credits_used} credit(s) used for job {job.job_id}")
135
+ # Note: Don't commit here - let caller handle transaction
136
+
137
+
138
+ async def refund_credit(session: AsyncSession, job, reason: str) -> bool:
139
+ """
140
+ Refund reserved credits back to the user.
141
+
142
+ Called when a job fails due to a refundable error (server-side issue).
143
+
144
+ Args:
145
+ session: Database session
146
+ job: GeminiJob model instance
147
+ reason: Reason for the refund (for logging)
148
+
149
+ Returns:
150
+ True if credits were refunded
151
+ False if no credits to refund or already refunded
152
+ """
153
+ if job.credits_reserved <= 0:
154
+ logger.debug(f"Job {job.job_id} has no credits to refund")
155
+ return False
156
+
157
+ if job.credits_refunded:
158
+ logger.warning(f"Job {job.job_id} was already refunded")
159
+ return False
160
+
161
+ # Get the user to restore credits
162
+ from core.models import User
163
+
164
+ result = await session.execute(
165
+ select(User).where(User.id == job.user_id)
166
+ )
167
+ user = result.scalar_one_or_none()
168
+
169
+ if not user:
170
+ logger.error(f"Cannot refund job {job.job_id}: User {job.user_id} not found")
171
+ return False
172
+
173
+ # Restore credits
174
+ credits_to_refund = job.credits_reserved
175
+ user.credits += credits_to_refund
176
+ job.credits_reserved = 0
177
+ job.credits_refunded = True
178
+
179
+ logger.info(
180
+ f"Refunded {credits_to_refund} credit(s) to user {user.user_id} for job {job.job_id}. "
181
+ f"Reason: {reason[:100]}. New balance: {user.credits}"
182
+ )
183
+
184
+ # Note: Don't commit here - let caller handle transaction
185
+ return True
186
+
187
+
188
+ async def handle_job_completion(session: AsyncSession, job) -> None:
189
+ """
190
+ Handle credit finalization when a job completes or fails.
191
+
192
+ This is the main entry point called by the job worker.
193
+
194
+ Args:
195
+ session: Database session
196
+ job: GeminiJob model instance with final status
197
+ """
198
+ if job.status == "completed":
199
+ # Success - confirm credits were used
200
+ await confirm_credit(session, job)
201
+
202
+ elif job.status == "failed":
203
+ # Failure - check if refundable
204
+ if is_refundable_error(job.error_message):
205
+ await refund_credit(session, job, job.error_message or "Unknown error")
206
+ else:
207
+ # Non-refundable - confirm credits were used (user's fault)
208
+ await confirm_credit(session, job)
209
+ logger.info(f"Job {job.job_id} failed with non-refundable error, credits kept")
210
+
211
+ elif job.status == "cancelled":
212
+ # Cancelled jobs get refunds only if they were never started
213
+ if job.started_at is None:
214
+ await refund_credit(session, job, "Job cancelled before processing")
215
+ else:
216
+ # Was processing - keep credits (API may have been consumed)
217
+ await confirm_credit(session, job)
218
+ logger.info(f"Job {job.job_id} cancelled during processing, credits kept")
219
+
220
+
221
+ async def refund_orphaned_jobs(session: AsyncSession) -> int:
222
+ """
223
+ Refund credits for jobs that were abandoned due to server shutdown.
224
+
225
+ Called during graceful shutdown to ensure no credits are lost.
226
+
227
+ Args:
228
+ session: Database session
229
+
230
+ Returns:
231
+ Number of jobs that were refunded
232
+ """
233
+ from core.models import GeminiJob
234
+
235
+ # Find jobs that are still processing with reserved credits
236
+ result = await session.execute(
237
+ select(GeminiJob).where(
238
+ GeminiJob.status == "processing",
239
+ GeminiJob.credits_reserved > 0,
240
+ GeminiJob.credits_refunded == False
241
+ )
242
+ )
243
+ orphaned_jobs = result.scalars().all()
244
+
245
+ refund_count = 0
246
+ for job in orphaned_jobs:
247
+ if await refund_credit(session, job, "SERVER_SHUTDOWN: Job orphaned during server shutdown"):
248
+ # Mark job as failed
249
+ job.status = "failed"
250
+ job.error_message = "Server shutdown during processing. Credits refunded."
251
+ refund_count += 1
252
+
253
+ if refund_count > 0:
254
+ await session.commit()
255
+ logger.info(f"Refunded {refund_count} orphaned job(s) during shutdown")
256
+
257
+ return refund_count
services/credit_service/middleware.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Credit Middleware - Request credit validation layer
3
+
4
+ Intercepts requests to validate and reserve credits for paid endpoints.
5
+ """
6
+
7
+ import logging
8
+ from datetime import datetime
9
+ from fastapi import Request, HTTPException, status
10
+ from fastapi.responses import JSONResponse
11
+ from sqlalchemy.ext.asyncio import AsyncSession
12
+
13
+ from core.database import async_session_maker
14
+ from services.credit_service.config import CreditServiceConfig
15
+ from services.base_service.middleware_chain import (
16
+ BaseServiceMiddleware,
17
+ get_request_context,
18
+ )
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class CreditMiddleware(BaseServiceMiddleware):
24
+ """
25
+ Credit validation middleware for request validation.
26
+
27
+ Flow:
28
+ 1. Check if route requires credits based on URL
29
+ 2. Get authenticated user from request.state (set by AuthMiddleware)
30
+ 3. Check if user has sufficient credits
31
+ 4. Reserve credits (deduct from balance)
32
+ 5. Attach credit info to request.state
33
+ 6. Continue to next middleware/route
34
+
35
+ Credits are reserved but tracked - they can be refunded if job fails
36
+ with a server-side error.
37
+
38
+ NOTE: This middleware MUST run AFTER AuthMiddleware since it needs
39
+ the authenticated user from request.state.user
40
+ """
41
+
42
+ SERVICE_NAME = "credit"
43
+
44
+ async def dispatch(self, request: Request, call_next):
45
+ """Process request through credit middleware."""
46
+ # Skip OPTIONS requests (CORS preflight)
47
+ if request.method == "OPTIONS":
48
+ return await call_next(request)
49
+
50
+ # Initialize request context
51
+ ctx = get_request_context(request)
52
+
53
+ # Get path from request
54
+ path = request.url.path
55
+
56
+ # Check if route requires credits
57
+ credit_cost = CreditServiceConfig.get_cost(path)
58
+
59
+ if credit_cost == 0:
60
+ # Route doesn't require credits, skip
61
+ self.log_request(request, f"Route doesn't require credits")
62
+ ctx.set_credits(0, 0)
63
+ response = await call_next(request)
64
+ return response
65
+
66
+ # Route requires credits - user MUST be authenticated
67
+ # (AuthMiddleware should have already validated this)
68
+ user = request.state.user if hasattr(request.state, 'user') else None
69
+
70
+ if not user:
71
+ # This shouldn't happen if auth is configured correctly
72
+ self.log_error(request, "Credit-required route accessed without authentication")
73
+ return JSONResponse(
74
+ status_code=status.HTTP_401_UNAUTHORIZED,
75
+ content={"detail": "Authentication required for this endpoint"},
76
+ )
77
+
78
+ # Check if user has sufficient credits
79
+ if user.credits < credit_cost:
80
+ self.log_request(
81
+ request,
82
+ f"Insufficient credits: has {user.credits}, needs {credit_cost}"
83
+ )
84
+ return JSONResponse(
85
+ status_code=status.HTTP_402_PAYMENT_REQUIRED,
86
+ content={
87
+ "detail": f"Insufficient credits. This operation requires {credit_cost} credits. You have {user.credits}.",
88
+ "credits_required": credit_cost,
89
+ "credits_available": user.credits,
90
+ },
91
+ )
92
+
93
+ # Reserve credits (deduct from user balance)
94
+ async with async_session_maker() as db:
95
+ try:
96
+ # Deduct credits
97
+ user.credits -= credit_cost
98
+ user.last_used_at = datetime.utcnow()
99
+
100
+ # Update in database
101
+ db.add(user)
102
+ await db.commit()
103
+ await db.refresh(user)
104
+
105
+ # Attach credit info to request state
106
+ ctx.set_credits(credit_cost, user.credits)
107
+ request.state.credits_reserved = credit_cost
108
+ request.state.credits_remaining = user.credits
109
+
110
+ self.log_request(
111
+ request,
112
+ f"Reserved {credit_cost} credits for {user.user_id}, remaining: {user.credits}"
113
+ )
114
+
115
+ # Continue to next middleware/route
116
+ response = await call_next(request)
117
+ return response
118
+
119
+ except Exception as e:
120
+ await db.rollback()
121
+ self.log_error(request, f"Error reserving credits: {e}")
122
+ return JSONResponse(
123
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
124
+ content={"detail": "Failed to reserve credits. Please try again."},
125
+ )
126
+ finally:
127
+ await db.close()
128
+
129
+
130
+ __all__ = ['CreditMiddleware']
tests/test_base_service.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit tests for base service infrastructure.
3
+
4
+ Tests:
5
+ - BaseService registration and configuration
6
+ - ServiceRegistry service management
7
+ - ServiceConfig operations
8
+ """
9
+
10
+ import pytest
11
+ from services.base_service import BaseService, ServiceConfig, ServiceRegistry
12
+
13
+
14
+ class TestServiceConfig:
15
+ """Test ServiceConfig container."""
16
+
17
+ def test_initialization(self):
18
+ """Test config initialization with kwargs."""
19
+ config = ServiceConfig(key1="value1", key2=42)
20
+
21
+ assert config.get("key1") == "value1"
22
+ assert config.get("key2") == 42
23
+
24
+ def test_get_with_default(self):
25
+ """Test get with default value."""
26
+ config = ServiceConfig(key1="value1")
27
+
28
+ assert config.get("key1") == "value1"
29
+ assert config.get("missing", "default") == "default"
30
+ assert config.get("missing") is None
31
+
32
+ def test_set_value(self):
33
+ """Test setting values."""
34
+ config = ServiceConfig()
35
+
36
+ config.set("key1", "value1")
37
+ assert config.get("key1") == "value1"
38
+
39
+ def test_dictionary_access(self):
40
+ """Test dictionary-style access."""
41
+ config = ServiceConfig(key1="value1")
42
+
43
+ assert config["key1"] == "value1"
44
+
45
+ config["key2"] = "value2"
46
+ assert config["key2"] == "value2"
47
+
48
+ def test_contains(self):
49
+ """Test 'in' operator."""
50
+ config = ServiceConfig(key1="value1")
51
+
52
+ assert "key1" in config
53
+ assert "missing" not in config
54
+
55
+
56
+ class TestBaseService:
57
+ """Test BaseService abstract class."""
58
+
59
+ def setup_method(self):
60
+ """Reset service state before each test."""
61
+ # Create concrete test service
62
+ class TestService(BaseService):
63
+ SERVICE_NAME = "test_service"
64
+
65
+ @classmethod
66
+ def register(cls, **config):
67
+ super().register(**config)
68
+
69
+ self.TestService = TestService
70
+
71
+ # Reset state
72
+ self.TestService._registered = False
73
+ self.TestService._config = None
74
+
75
+ def test_registration(self):
76
+ """Test service registration."""
77
+ assert not self.TestService.is_registered()
78
+
79
+ self.TestService.register(key1="value1", key2=42)
80
+
81
+ assert self.TestService.is_registered()
82
+ assert self.TestService.get_config().get("key1") == "value1"
83
+ assert self.TestService.get_config().get("key2") == 42
84
+
85
+ def test_double_registration_fails(self):
86
+ """Test that double registration raises error."""
87
+ self.TestService.register(key1="value1")
88
+
89
+ with pytest.raises(RuntimeError, match="already registered"):
90
+ self.TestService.register(key2="value2")
91
+
92
+ def test_assert_registered(self):
93
+ """Test assert_registered raises when not registered."""
94
+ with pytest.raises(RuntimeError, match="not registered"):
95
+ self.TestService.assert_registered()
96
+
97
+ self.TestService.register()
98
+ self.TestService.assert_registered() # Should not raise
99
+
100
+ def test_get_config_before_registration(self):
101
+ """Test get_config raises before registration."""
102
+ with pytest.raises(RuntimeError, match="not registered"):
103
+ self.TestService.get_config()
104
+
105
+ def test_get_middleware_default(self):
106
+ """Test default get_middleware returns None."""
107
+ assert self.TestService.get_middleware() is None
108
+
109
+ def test_on_shutdown_default(self):
110
+ """Test default on_shutdown does nothing."""
111
+ self.TestService.on_shutdown() # Should not raise
112
+
113
+
114
+ class TestServiceRegistry:
115
+ """Test ServiceRegistry global registry."""
116
+
117
+ def setup_method(self):
118
+ """Reset registry before each test."""
119
+ ServiceRegistry._services = {}
120
+
121
+ def test_register_service(self):
122
+ """Test registering a service."""
123
+ class TestService(BaseService):
124
+ SERVICE_NAME = "test_service"
125
+
126
+ @classmethod
127
+ def register(cls, **config):
128
+ super().register(**config)
129
+
130
+ ServiceRegistry.register_service(TestService)
131
+
132
+ assert ServiceRegistry.get_service("test_service") == TestService
133
+
134
+ def test_register_multiple_services(self):
135
+ """Test registering multiple services."""
136
+ class Service1(BaseService):
137
+ SERVICE_NAME = "service1"
138
+
139
+ @classmethod
140
+ def register(cls, **config):
141
+ super().register(**config)
142
+
143
+ class Service2(BaseService):
144
+ SERVICE_NAME = "service2"
145
+
146
+ @classmethod
147
+ def register(cls, **config):
148
+ super().register(**config)
149
+
150
+ ServiceRegistry.register_service(Service1)
151
+ ServiceRegistry.register_service(Service2)
152
+
153
+ assert len(ServiceRegistry.get_all_services()) == 2
154
+ assert ServiceRegistry.get_service("service1") == Service1
155
+ assert ServiceRegistry.get_service("service2") == Service2
156
+
157
+ def test_get_nonexistent_service(self):
158
+ """Test getting service that doesn't exist."""
159
+ assert ServiceRegistry.get_service("nonexistent") is None
160
+
161
+ def test_overwrite_service(self):
162
+ """Test registering service with same name overwrites."""
163
+ class Service1(BaseService):
164
+ SERVICE_NAME = "test"
165
+ version = 1
166
+
167
+ @classmethod
168
+ def register(cls, **config):
169
+ super().register(**config)
170
+
171
+ class Service2(BaseService):
172
+ SERVICE_NAME = "test"
173
+ version = 2
174
+
175
+ @classmethod
176
+ def register(cls, **config):
177
+ super().register(**config)
178
+
179
+ ServiceRegistry.register_service(Service1)
180
+ ServiceRegistry.register_service(Service2)
181
+
182
+ service = ServiceRegistry.get_service("test")
183
+ assert service.version == 2
184
+
185
+ def test_get_all_middleware(self):
186
+ """Test getting middleware from all services."""
187
+ class MockMiddleware:
188
+ pass
189
+
190
+ class ServiceWithMiddleware(BaseService):
191
+ SERVICE_NAME = "with_middleware"
192
+
193
+ @classmethod
194
+ def register(cls, **config):
195
+ super().register(**config)
196
+
197
+ @classmethod
198
+ def get_middleware(cls):
199
+ return MockMiddleware()
200
+
201
+ class ServiceWithoutMiddleware(BaseService):
202
+ SERVICE_NAME = "without_middleware"
203
+
204
+ @classmethod
205
+ def register(cls, **config):
206
+ super().register(**config)
207
+
208
+ # Register services
209
+ ServiceWithMiddleware.register()
210
+ ServiceWithoutMiddleware.register()
211
+
212
+ ServiceRegistry.register_service(ServiceWithMiddleware)
213
+ ServiceRegistry.register_service(ServiceWithoutMiddleware)
214
+
215
+ middleware_list = ServiceRegistry.get_all_middleware()
216
+
217
+ assert len(middleware_list) == 1
218
+ assert isinstance(middleware_list[0], MockMiddleware)
219
+
220
+ def test_shutdown_all(self):
221
+ """Test calling shutdown on all services."""
222
+ shutdown_called = []
223
+
224
+ class Service1(BaseService):
225
+ SERVICE_NAME = "service1"
226
+
227
+ @classmethod
228
+ def register(cls, **config):
229
+ super().register(**config)
230
+
231
+ @classmethod
232
+ def on_shutdown(cls):
233
+ shutdown_called.append("service1")
234
+
235
+ class Service2(BaseService):
236
+ SERVICE_NAME = "service2"
237
+
238
+ @classmethod
239
+ def register(cls, **config):
240
+ super().register(**config)
241
+
242
+ @classmethod
243
+ def on_shutdown(cls):
244
+ shutdown_called.append("service2")
245
+
246
+ ServiceRegistry.register_service(Service1)
247
+ ServiceRegistry.register_service(Service2)
248
+
249
+ ServiceRegistry.shutdown_all()
250
+
251
+ assert "service1" in shutdown_called
252
+ assert "service2" in shutdown_called
tests/test_route_matcher.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit tests for route matcher.
3
+
4
+ Tests:
5
+ - Exact path matching
6
+ - Prefix pattern matching
7
+ - Glob pattern matching
8
+ - Regex pattern matching
9
+ - RouteConfig precedence logic
10
+ """
11
+
12
+ import pytest
13
+ from services.base_service.route_matcher import RouteMatcher, RouteConfig
14
+
15
+
16
+ class TestRouteMatcher:
17
+ """Test RouteMatcher pattern matching."""
18
+
19
+ def test_exact_match(self):
20
+ """Test exact path matching."""
21
+ matcher = RouteMatcher(["/api/users", "/api/posts"])
22
+
23
+ assert matcher.matches("/api/users")
24
+ assert matcher.matches("/api/posts")
25
+ assert not matcher.matches("/api/comments")
26
+ assert not matcher.matches("/api/users/123")
27
+
28
+ def test_prefix_match(self):
29
+ """Test prefix wildcard matching."""
30
+ matcher = RouteMatcher(["/api/*", "/admin/*"])
31
+
32
+ assert matcher.matches("/api/users")
33
+ assert matcher.matches("/api/posts")
34
+ assert matcher.matches("/admin/dashboard")
35
+ assert not matcher.matches("/public/page")
36
+
37
+ def test_complex_glob_match(self):
38
+ """Test complex glob patterns."""
39
+ matcher = RouteMatcher(["/api/users/*/posts", "/api/**/comments"])
40
+
41
+ assert matcher.matches("/api/users/123/posts")
42
+ assert matcher.matches("/api/users/456/posts")
43
+ assert matcher.matches("/api/v1/users/comments")
44
+ assert matcher.matches("/api/deep/nested/path/comments")
45
+ assert not matcher.matches("/api/users/posts")
46
+
47
+ def test_regex_match(self):
48
+ """Test regex pattern matching."""
49
+ matcher = RouteMatcher(["^/api/v[0-9]+/.*$", "^/users/[0-9]+$"])
50
+
51
+ assert matcher.matches("/api/v1/users")
52
+ assert matcher.matches("/api/v2/posts")
53
+ assert matcher.matches("/users/123")
54
+ assert not matcher.matches("/api/v/users")
55
+ assert not matcher.matches("/users/abc")
56
+
57
+ def test_query_parameters_stripped(self):
58
+ """Test that query parameters are ignored."""
59
+ matcher = RouteMatcher(["/api/users"])
60
+
61
+ assert matcher.matches("/api/users?page=1")
62
+ assert matcher.matches("/api/users?page=1&limit=10")
63
+
64
+ def test_fragments_stripped(self):
65
+ """Test that URL fragments are ignored."""
66
+ matcher = RouteMatcher(["/api/users"])
67
+
68
+ assert matcher.matches("/api/users#section")
69
+
70
+ def test_trailing_slash_normalized(self):
71
+ """Test trailing slash normalization."""
72
+ matcher = RouteMatcher(["/api/users"])
73
+
74
+ assert matcher.matches("/api/users/")
75
+
76
+ # Root path keeps trailing slash
77
+ root_matcher = RouteMatcher(["/"])
78
+ assert root_matcher.matches("/")
79
+
80
+ def test_empty_patterns(self):
81
+ """Test with empty pattern list."""
82
+ matcher = RouteMatcher([])
83
+
84
+ assert not matcher.matches("/any/path")
85
+
86
+ def test_get_matching_pattern(self):
87
+ """Test getting the matched pattern."""
88
+ matcher = RouteMatcher([
89
+ "/api/users",
90
+ "/api/*",
91
+ "/admin/**"
92
+ ])
93
+
94
+ assert matcher.get_matching_pattern("/api/users") == "/api/users"
95
+ assert matcher.get_matching_pattern("/api/posts") == "/api/*"
96
+ assert matcher.get_matching_pattern("/admin/deep/path") == "/admin/**"
97
+ assert matcher.get_matching_pattern("/public") is None
98
+
99
+ def test_mixed_patterns(self):
100
+ """Test combination of all pattern types."""
101
+ matcher = RouteMatcher([
102
+ "/exact",
103
+ "/prefix/*",
104
+ "/glob/*/nested",
105
+ "^/regex/[0-9]+$"
106
+ ])
107
+
108
+ assert matcher.matches("/exact")
109
+ assert matcher.matches("/prefix/anything")
110
+ assert matcher.matches("/glob/123/nested")
111
+ assert matcher.matches("/regex/456")
112
+ assert not matcher.matches("/other")
113
+
114
+ def test_invalid_regex_pattern(self):
115
+ """Test that invalid regex is handled gracefully."""
116
+ # Should not raise, just log warning and skip pattern
117
+ matcher = RouteMatcher(["^[invalid(regex$"])
118
+
119
+ assert not matcher.matches("/anything")
120
+
121
+
122
+ class TestRouteConfig:
123
+ """Test RouteConfig precedence logic."""
124
+
125
+ def test_required_routes(self):
126
+ """Test required route checking."""
127
+ config = RouteConfig(
128
+ required=["/api/users", "/api/posts"],
129
+ )
130
+
131
+ assert config.is_required("/api/users")
132
+ assert config.is_required("/api/posts")
133
+ assert not config.is_required("/public")
134
+
135
+ def test_optional_routes(self):
136
+ """Test optional route checking."""
137
+ config = RouteConfig(
138
+ optional=["/", "/home"],
139
+ )
140
+
141
+ assert config.is_optional("/")
142
+ assert config.is_optional("/home")
143
+ assert not config.is_optional("/api/users")
144
+
145
+ def test_public_routes(self):
146
+ """Test public route checking."""
147
+ config = RouteConfig(
148
+ public=["/health", "/docs"],
149
+ )
150
+
151
+ assert config.is_public("/health")
152
+ assert config.is_public("/docs")
153
+ assert not config.is_public("/api/users")
154
+
155
+ def test_public_overrides_required(self):
156
+ """Test that public takes precedence over required."""
157
+ config = RouteConfig(
158
+ required=["/api/*"],
159
+ public=["/api/health"],
160
+ )
161
+
162
+ # /api/health is public, so not required
163
+ assert config.is_public("/api/health")
164
+ assert not config.is_required("/api/health")
165
+
166
+ # Other /api routes are required
167
+ assert config.is_required("/api/users")
168
+ assert not config.is_public("/api/users")
169
+
170
+ def test_public_overrides_optional(self):
171
+ """Test that public takes precedence over optional."""
172
+ config = RouteConfig(
173
+ optional=["/api/*"],
174
+ public=["/api/health"],
175
+ )
176
+
177
+ # /api/health is public, so not optional
178
+ assert config.is_public("/api/health")
179
+ assert not config.is_optional("/api/health")
180
+
181
+ # Other /api routes are optional
182
+ assert config.is_optional("/api/users")
183
+
184
+ def test_required_overrides_optional(self):
185
+ """Test that required takes precedence over optional."""
186
+ config = RouteConfig(
187
+ required=["/api/users"],
188
+ optional=["/api/*"],
189
+ )
190
+
191
+ # /api/users is required, so not optional
192
+ assert config.is_required("/api/users")
193
+ assert not config.is_optional("/api/users")
194
+
195
+ # Other /api routes are optional
196
+ assert config.is_optional("/api/posts")
197
+
198
+ def test_requires_service(self):
199
+ """Test requires_service helper."""
200
+ config = RouteConfig(
201
+ required=["/api/users"],
202
+ optional=["/api/posts"],
203
+ public=["/health"],
204
+ )
205
+
206
+ # Service required
207
+ assert config.requires_service("/api/users")
208
+
209
+ # Service optional (still requires service)
210
+ assert config.requires_service("/api/posts")
211
+
212
+ # Public (does not require service)
213
+ assert not config.requires_service("/health")
214
+
215
+ def test_empty_config(self):
216
+ """Test with empty configuration."""
217
+ config = RouteConfig()
218
+
219
+ assert not config.is_required("/any")
220
+ assert not config.is_optional("/any")
221
+ assert not config.is_public("/any")
222
+ assert not config.requires_service("/any")
223
+
224
+ def test_complex_precedence(self):
225
+ """Test complex precedence scenarios."""
226
+ config = RouteConfig(
227
+ required=["/api/users"], # Specific required path
228
+ optional=["/api/*"], # Broader optional pattern
229
+ public=["/api/health"],
230
+ )
231
+
232
+ # Public overrides everything
233
+ assert config.is_public("/api/health")
234
+ assert not config.is_required("/api/health")
235
+ assert not config.is_optional("/api/health")
236
+
237
+ # Required path
238
+ assert config.is_required("/api/users")
239
+ assert not config.is_optional("/api/users")
240
+
241
+ # Optional for other paths under /api
242
+ assert config.is_optional("/api/posts")
243
+ assert not config.is_required("/api/posts")