Spaces:
Running
Running
| """ | |
| Credit Middleware - Automatic credit management. | |
| Reserves credits on request, confirms/refunds based on response. | |
| """ | |
| import logging | |
| import json | |
| from datetime import datetime | |
| from fastapi import Request, status | |
| from fastapi.responses import JSONResponse, Response | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from core.database import async_session_maker | |
| from core.api_response import error_response, ErrorCode | |
| from services.credit_service.config import CreditServiceConfig | |
| from services.credit_service.transaction_manager import ( | |
| CreditTransactionManager, | |
| InsufficientCreditsError | |
| ) | |
| from services.credit_service.response_inspector import ResponseInspector | |
| from services.base_service.middleware_chain import ( | |
| BaseServiceMiddleware, | |
| get_request_context, | |
| ) | |
| logger = logging.getLogger(__name__) | |
class CreditMiddleware(BaseServiceMiddleware):
    """
    Credit middleware with automatic credit management.

    For every paid endpoint the middleware reserves the configured number of
    credits before the request is handled, then confirms or refunds that
    reservation depending on what ``ResponseInspector`` reports about the
    response. Reads ``request.state.user``, so it must run after
    AuthMiddleware.
    """

    SERVICE_NAME = "credit"

    async def dispatch(self, request: Request, call_next):
        """
        Process a request through the credit middleware.

        Flow:
            1. ``OPTIONS`` requests and endpoints whose configured cost is 0
               are passed straight through.
            2. Paid endpoints without ``request.state.user`` get a 401.
            3. Credits are reserved (402 on insufficient credits, 500 on any
               other reservation failure).
            4. The downstream handler runs and its body is buffered.
            5. The reservation is confirmed, refunded, or left pending.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request down the chain.

        Returns:
            The downstream response (rebuilt from the buffered body when
            credits were reserved), or a JSON error response.

        Raises:
            Exception: Anything raised by the downstream handler is re-raised
                after a best-effort refund of the reservation.
        """
        if request.method == "OPTIONS":
            return await call_next(request)

        path = request.url.path
        config = CreditServiceConfig.get_config(path)
        credit_cost = config.get("cost", 0)
        endpoint_type = config.get("type", "free")

        # Skip free endpoints
        if credit_cost == 0:
            return await call_next(request)

        # Require authentication for paid endpoints
        user = getattr(request.state, 'user', None)
        if not user:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content=error_response(
                    ErrorCode.UNAUTHORIZED,
                    "Authentication required for this endpoint"
                )
            )

        rejection = await self._reserve_for_request(
            request, user, credit_cost, endpoint_type
        )
        if rejection is not None:
            return rejection

        try:
            response = await call_next(request)

            # Read after the handler ran, so a handler that cleared the id
            # on request.state is still honoured.
            transaction_id = getattr(request.state, 'credit_transaction_id', None)
            if not transaction_id:
                return response

            # Buffer the body for inspection; join once instead of repeated
            # bytes concatenation, which is quadratic in the chunk count.
            response_body = b"".join(
                [chunk async for chunk in response.body_iterator]
            )
        except Exception as exc:
            # The reservation is already committed. Without this refund an
            # unhandled downstream error would leave the credits reserved.
            await self._refund_after_failure(request, endpoint_type, exc)
            raise

        await self._settle_reservation(
            request, response, response_body, transaction_id, endpoint_type
        )
        return self._rebuild_response(response, response_body)

    async def _reserve_for_request(self, request: Request, user, credit_cost, endpoint_type):
        """
        Reserve ``credit_cost`` credits for ``user`` and record it on the request.

        On success the transaction id, endpoint type and cost are stored on
        ``request.state`` and ``None`` is returned. On failure the session is
        rolled back and a ``JSONResponse`` (402 or 500) is returned for the
        caller to send as-is.
        """
        path = request.url.path
        async with async_session_maker() as db:
            try:
                transaction = await CreditTransactionManager.reserve_credits(
                    session=db,
                    user=user,
                    amount=credit_cost,
                    source="middleware",
                    reference_type="request",
                    reference_id=f"{request.method}:{path}",
                    reason=f"{request.method} {path}",
                    metadata={
                        "endpoint_type": endpoint_type,
                        "cost": credit_cost
                    },
                    request=request
                )
                await db.commit()

                # Store transaction info in request state
                request.state.credit_transaction_id = transaction.transaction_id
                request.state.endpoint_type = endpoint_type
                request.state.credit_cost = credit_cost
                request.state.credits_reserved = credit_cost  # For gemini router compatibility

                self.log_request(
                    request,
                    f"Reserved {credit_cost} credits ({endpoint_type}), txn: {transaction.transaction_id}"
                )
            except InsufficientCreditsError:
                await db.rollback()
                return JSONResponse(
                    status_code=status.HTTP_402_PAYMENT_REQUIRED,
                    content=error_response(
                        ErrorCode.INSUFFICIENT_CREDITS,
                        f"Insufficient credits. Required: {credit_cost}, Available: {user.credits}",
                        {
                            "credits_required": credit_cost,
                            "credits_available": user.credits
                        }
                    )
                )
            except Exception as e:
                await db.rollback()
                logger.error("Credit reservation failed: %s", e, exc_info=True)
                return JSONResponse(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    content=error_response(
                        ErrorCode.SERVER_ERROR,
                        "Failed to reserve credits"
                    )
                )
        return None

    async def _settle_reservation(self, request: Request, response, response_body: bytes,
                                  transaction_id, endpoint_type) -> None:
        """
        Confirm, refund, or leave pending the reservation for this request.

        The decision is delegated to ``ResponseInspector``. Errors while
        confirming/refunding are logged and swallowed so the client still
        receives the downstream response (deliberate best-effort).
        """
        response_data = ResponseInspector.parse_response_body(response_body)
        inspector = ResponseInspector()

        async with async_session_maker() as db:
            try:
                should_confirm = inspector.should_confirm(
                    response, endpoint_type, response_data
                )
                should_refund = inspector.should_refund(
                    response, endpoint_type, response_data
                )

                if should_confirm:
                    await CreditTransactionManager.confirm_credits(
                        session=db,
                        transaction_id=transaction_id,
                        metadata={
                            "response_status": response.status_code,
                            "endpoint_type": endpoint_type
                        }
                    )
                    await db.commit()
                    self.log_request(
                        request,
                        f"Credits confirmed for {transaction_id} (success)"
                    )
                elif should_refund:
                    reason = inspector.get_refund_reason(response, response_data)
                    await CreditTransactionManager.refund_credits(
                        session=db,
                        transaction_id=transaction_id,
                        reason=reason,
                        metadata={
                            "response_status": response.status_code,
                            "endpoint_type": endpoint_type,
                            "error": response_data.get("detail") if response_data else None
                        }
                    )
                    await db.commit()
                    self.log_request(
                        request,
                        f"Credits refunded for {transaction_id}: {reason}"
                    )
                else:
                    # Neither confirm nor refund: the reservation stays open.
                    self.log_request(
                        request,
                        f"Credits kept reserved for {transaction_id} (async pending)"
                    )
            except Exception as e:
                logger.error(
                    "Error processing credit confirmation/refund: %s", e, exc_info=True
                )

    async def _refund_after_failure(self, request: Request, endpoint_type, error) -> None:
        """
        Best-effort refund of the reservation after downstream processing raised.

        Never raises: a failure to refund is logged so the caller can re-raise
        the original exception unchanged.

        NOTE(review): assumes ``refund_credits`` is safe to call for a
        transaction that is still in the reserved state and that handlers do
        not refund on their own before raising -- confirm against
        CreditTransactionManager.
        """
        transaction_id = getattr(request.state, 'credit_transaction_id', None)
        if not transaction_id:
            return

        reason = "Unhandled exception while processing request"
        try:
            async with async_session_maker() as db:
                await CreditTransactionManager.refund_credits(
                    session=db,
                    transaction_id=transaction_id,
                    reason=reason,
                    metadata={
                        "response_status": status.HTTP_500_INTERNAL_SERVER_ERROR,
                        "endpoint_type": endpoint_type,
                        # Only the exception type: the message may hold internals.
                        "error": type(error).__name__
                    }
                )
                await db.commit()
            self.log_request(
                request,
                f"Credits refunded for {transaction_id}: {reason}"
            )
        except Exception as refund_error:
            logger.error(
                "Failed to refund credits for %s after request failure: %s",
                transaction_id, refund_error, exc_info=True
            )

    @staticmethod
    def _rebuild_response(response, response_body: bytes) -> Response:
        """
        Build a plain ``Response`` carrying the buffered body and original headers.

        ``dict(response.headers)`` keeps only one value per header name, so
        repeated headers such as multiple ``Set-Cookie`` lines would be lost.
        The original raw header list is therefore restored verbatim, and any
        header that ``Response`` generated itself (e.g. ``content-length``)
        is kept only when the original response did not already carry it.
        """
        rebuilt = Response(
            content=response_body,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type
        )
        original_headers = [(name, value) for name, value in response.raw_headers]
        present = {name.lower() for name, _ in original_headers}
        generated = [
            (name, value)
            for name, value in rebuilt.raw_headers
            if name.lower() not in present
        ]
        rebuilt.raw_headers = original_headers + generated
        return rebuilt
| __all__ = ['CreditMiddleware'] | |