"""
Credit Middleware - Automatic credit management.

Reserves credits on request, confirms/refunds based on response.
"""

import logging
import json
from datetime import datetime

from fastapi import Request, status
from fastapi.responses import JSONResponse, Response
from sqlalchemy.ext.asyncio import AsyncSession

from core.database import async_session_maker
from core.api_response import error_response, ErrorCode
from services.credit_service.config import CreditServiceConfig
from services.credit_service.transaction_manager import (
    CreditTransactionManager,
    InsufficientCreditsError,
)
from services.credit_service.response_inspector import ResponseInspector
from services.base_service.middleware_chain import (
    BaseServiceMiddleware,
    get_request_context,
)

logger = logging.getLogger(__name__)


class CreditMiddleware(BaseServiceMiddleware):
    """
    Credit middleware with automatic credit management.

    Reserves credits on request, confirms/refunds based on response.
    Must run after AuthMiddleware.

    Flow for an endpoint whose configured ``cost`` is non-zero:

    1. Reserve ``cost`` credits for ``request.state.user`` and commit.
    2. Call the downstream app and buffer the full response body
       (NOTE(review): this defeats streaming for paid endpoints).
    3. Ask ``ResponseInspector`` whether to confirm, refund, or leave the
       reservation pending.
    4. Return a copy of the downstream response.

    If the downstream app raises, the reservation is refunded before the
    exception is re-raised, so credits are not left stuck as "reserved".
    """

    SERVICE_NAME = "credit"

    async def dispatch(self, request: Request, call_next):
        """Process request through credit middleware.

        Args:
            request: Incoming request. For paid endpoints,
                ``request.state.user`` must have been set upstream
                (AuthMiddleware).
            call_next: Callable that invokes the downstream application.

        Returns:
            For free endpoints and OPTIONS requests, the downstream response
            unchanged. For paid endpoints, a copy of the downstream response
            with a buffered body, or an error ``JSONResponse``: 401 (no
            user), 402 (insufficient credits) or 500 (reservation failed).

        Raises:
            Any exception raised downstream, re-raised after the reservation
            has been refunded.
        """
        # Preflight requests are never billed.
        if request.method == "OPTIONS":
            return await call_next(request)

        path = request.url.path
        config = CreditServiceConfig.get_config(path)
        credit_cost = config.get("cost", 0)
        endpoint_type = config.get("type", "free")

        # Skip free endpoints
        if credit_cost == 0:
            return await call_next(request)

        # Require authentication for paid endpoints
        user = getattr(request.state, "user", None)
        if not user:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content=error_response(
                    ErrorCode.UNAUTHORIZED,
                    "Authentication required for this endpoint",
                ),
            )

        error_reply = await self._reserve(request, user, credit_cost, endpoint_type)
        if error_reply is not None:
            return error_reply

        # Process request; never leak a committed reservation on failure.
        try:
            response = await call_next(request)
        except Exception as exc:
            await self._refund_after_exception(request, endpoint_type, exc)
            raise

        # Read back from request.state: a downstream handler may have
        # cleared the transaction id to signal it manages the credits itself.
        transaction_id = getattr(request.state, "credit_transaction_id", None)
        if not transaction_id:
            return response

        # Read response body for inspection
        try:
            response_body = await self._read_body(response)
        except Exception as exc:
            await self._refund_after_exception(request, endpoint_type, exc)
            raise

        await self._settle(
            request, response, response_body, transaction_id, endpoint_type
        )

        return Response(
            content=response_body,
            status_code=response.status_code,
            # Pass the Headers mapping itself: dict(response.headers) keeps
            # only the first value of repeated headers such as Set-Cookie.
            headers=response.headers,
            media_type=response.media_type,
        )

    async def _reserve(self, request: Request, user, credit_cost, endpoint_type):
        """Reserve ``credit_cost`` credits and record the txn on request.state.

        Returns:
            ``None`` on success, otherwise the error ``JSONResponse`` (402 or
            500) that ``dispatch`` should send back.
        """
        path = request.url.path
        async with async_session_maker() as db:
            try:
                transaction = await CreditTransactionManager.reserve_credits(
                    session=db,
                    user=user,
                    amount=credit_cost,
                    source="middleware",
                    reference_type="request",
                    reference_id=f"{request.method}:{path}",
                    reason=f"{request.method} {path}",
                    metadata={
                        "endpoint_type": endpoint_type,
                        "cost": credit_cost,
                    },
                    request=request,
                )
                await db.commit()
            except InsufficientCreditsError:
                await db.rollback()
                return JSONResponse(
                    status_code=status.HTTP_402_PAYMENT_REQUIRED,
                    content=error_response(
                        ErrorCode.INSUFFICIENT_CREDITS,
                        f"Insufficient credits. Required: {credit_cost}, Available: {user.credits}",
                        {
                            "credits_required": credit_cost,
                            "credits_available": user.credits,
                        },
                    ),
                )
            except Exception as e:
                await db.rollback()
                logger.exception("Credit reservation failed: %s", e)
                return JSONResponse(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    content=error_response(
                        ErrorCode.SERVER_ERROR,
                        "Failed to reserve credits",
                    ),
                )

        # Recorded only after a successful commit, so a bookkeeping error
        # here can never turn into a 500 for an already-reserved charge.
        request.state.credit_transaction_id = transaction.transaction_id
        request.state.endpoint_type = endpoint_type
        request.state.credit_cost = credit_cost
        request.state.credits_reserved = credit_cost  # For gemini router compatibility

        self.log_request(
            request,
            f"Reserved {credit_cost} credits ({endpoint_type}), txn: {transaction.transaction_id}",
        )
        return None

    @staticmethod
    async def _read_body(response) -> bytes:
        """Drain ``response.body_iterator`` into a single bytes object."""
        chunks = [chunk async for chunk in response.body_iterator]
        return b"".join(chunks)

    async def _settle(
        self, request: Request, response, response_body: bytes, transaction_id, endpoint_type
    ) -> None:
        """Confirm, refund, or keep the reservation based on the response.

        Best-effort: any failure is logged and swallowed so the client still
        receives the downstream response. Uncommitted work is discarded when
        the session context exits.
        """
        try:
            # Inspect response to determine credit handling
            response_data = ResponseInspector.parse_response_body(response_body)
            inspector = ResponseInspector()

            async with async_session_maker() as db:
                if inspector.should_confirm(response, endpoint_type, response_data):
                    await CreditTransactionManager.confirm_credits(
                        session=db,
                        transaction_id=transaction_id,
                        metadata={
                            "response_status": response.status_code,
                            "endpoint_type": endpoint_type,
                        },
                    )
                    await db.commit()
                    self.log_request(
                        request,
                        f"Credits confirmed for {transaction_id} (success)",
                    )
                elif inspector.should_refund(response, endpoint_type, response_data):
                    reason = inspector.get_refund_reason(response, response_data)
                    await CreditTransactionManager.refund_credits(
                        session=db,
                        transaction_id=transaction_id,
                        reason=reason,
                        metadata={
                            "response_status": response.status_code,
                            "endpoint_type": endpoint_type,
                            "error": response_data.get("detail") if response_data else None,
                        },
                    )
                    await db.commit()
                    self.log_request(
                        request,
                        f"Credits refunded for {transaction_id}: {reason}",
                    )
                else:
                    # Neither confirmed nor refunded: left for whoever
                    # completes the (presumably async) work to settle.
                    self.log_request(
                        request,
                        f"Credits kept reserved for {transaction_id} (async pending)",
                    )
        except Exception as e:
            logger.exception("Error processing credit confirmation/refund: %s", e)

    async def _refund_after_exception(
        self, request: Request, endpoint_type, exc: Exception
    ) -> None:
        """Refund the current reservation after the downstream app raised.

        Best-effort: a refund failure is logged (not raised) so the caller
        can re-raise the original downstream exception.
        """
        transaction_id = getattr(request.state, "credit_transaction_id", None)
        if not transaction_id:
            return

        try:
            async with async_session_maker() as db:
                await CreditTransactionManager.refund_credits(
                    session=db,
                    transaction_id=transaction_id,
                    reason="Request failed with an unhandled exception",
                    metadata={
                        "endpoint_type": endpoint_type,
                        "error": type(exc).__name__,
                    },
                )
                await db.commit()
        except Exception:
            logger.exception(
                "Failed to refund credits for %s after downstream error",
                transaction_id,
            )
            return

        self.log_request(
            request,
            f"Credits refunded for {transaction_id}: downstream exception",
        )


__all__ = ["CreditMiddleware"]