Spaces:
Sleeping
Sleeping
| """ | |
| Credit Transaction Manager - Core credit operations service. | |
| All credit operations flow through this service: | |
| - Reserve credits (deduct from balance) | |
| - Refund credits (add back to balance) | |
| - Confirm credits (mark as used) | |
| - Add credits (purchases, bonuses) | |
| Provides complete audit trail with request/response context. | |
| """ | |
| import uuid | |
| import logging | |
| from datetime import datetime | |
| from typing import Optional, List | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from sqlalchemy import select, desc, func | |
| from fastapi import Request | |
| from core.models import User, CreditTransaction | |
# Module-level logger, named after this module so handlers/levels can target it.
logger = logging.getLogger(name=__name__)
# =============================================================================
# Custom Exceptions
# =============================================================================
class InsufficientCreditsError(Exception):
    """Signals that a user's balance cannot cover the requested credit amount."""
class TransactionNotFoundError(Exception):
    """Signals that no credit transaction exists with the requested ID."""
class UserNotFoundError(Exception):
    """Signals that the referenced user does not exist in the database."""
# =============================================================================
# Credit Transaction Manager
# =============================================================================
class CreditTransactionManager:
    """
    Centralized credit transaction management.

    All credit operations flow through this service.

    Every method is a ``staticmethod`` that works on the ``AsyncSession`` it
    is given. Methods only ``flush``; committing or rolling back is left to
    the caller.

    Balance-changing operations (reserve, refund, add) load the user row with
    ``SELECT ... FOR UPDATE``. On backends that support row locks, this
    serializes concurrent operations on the same user. Without it, they could
    read the same balance and overwrite each other's result.
    """

    @staticmethod
    async def _get_user_for_update(
        session: AsyncSession,
        user_id: int
    ) -> Optional[User]:
        """
        Load a user with a row lock and fresh column values.

        ``populate_existing`` makes the ORM overwrite any ``User`` instance
        already held by this session. The balance therefore reflects the
        locked row rather than a copy loaded earlier.

        Args:
            session: Database session
            user_id: Primary key of the user

        Returns:
            The ``User`` instance, or ``None`` if no such row exists
        """
        return await session.get(
            User,
            user_id,
            with_for_update=True,
            populate_existing=True,
        )

    @staticmethod
    async def reserve_credits(
        session: AsyncSession,
        user: User,
        amount: int,
        source: str,
        reference_type: Optional[str] = None,
        reference_id: Optional[str] = None,
        reason: Optional[str] = None,
        metadata: Optional[dict] = None,
        request: Optional[Request] = None
    ) -> CreditTransaction:
        """
        Reserve credits (deduct from balance).

        Called by middleware when request arrives.

        Args:
            session: Database session
            user: User model instance (only ``user.id`` is used)
            amount: Number of credits to reserve (must be >= 0)
            source: Source of transaction (e.g., "middleware")
            reference_type: Type of reference (e.g., "request", "job")
            reference_id: Reference identifier (e.g., "POST:/gemini/video")
            reason: Human-readable reason
            metadata: Additional context, merged over the request metadata
            request: FastAPI request object (for extracting metadata)

        Returns:
            The new ``reserve`` CreditTransaction (flushed, not committed)

        Raises:
            ValueError: If ``amount`` is negative (it would add credits)
            InsufficientCreditsError: If user doesn't have enough credits
            UserNotFoundError: If user not found in database
        """
        if amount < 0:
            raise ValueError(f"Cannot reserve a negative amount of credits: {amount}")

        # Fetch a fresh, locked copy of the user. The ``user`` argument may be
        # detached from AuthMiddleware's closed session. The lock stops a
        # concurrent request from spending the same balance.
        db_user = await CreditTransactionManager._get_user_for_update(session, user.id)
        if db_user is None:
            raise UserNotFoundError(f"User {user.id} not found")

        if db_user.credits < amount:
            raise InsufficientCreditsError(
                f"User has {db_user.credits} credits, needs {amount}"
            )

        transaction_id = f"ctx_{uuid.uuid4().hex[:16]}"

        # Capture request metadata for the audit trail.
        request_metadata = {}
        request_path = None
        request_method = None
        if request is not None:
            request_path = request.url.path
            request_method = request.method
            request_metadata = {
                "path": request_path,
                "method": request_method,
                "ip": request.client.host if request.client else None,
                "user_agent": request.headers.get("user-agent")
            }

        balance_before = db_user.credits
        balance_after = balance_before - amount

        transaction = CreditTransaction(
            transaction_id=transaction_id,
            user_id=db_user.id,
            transaction_type="reserve",
            amount=-amount,  # Negative for deduction
            balance_before=balance_before,
            balance_after=balance_after,
            source=source,
            reference_type=reference_type,
            reference_id=reference_id,
            request_path=request_path,
            request_method=request_method,
            reason=reason or f"Reserved for {request_path or 'operation'}",
            extra_data={**request_metadata, **(metadata or {})}
        )

        db_user.credits = balance_after
        db_user.last_used_at = datetime.utcnow()

        session.add(transaction)
        await session.flush()  # Don't commit yet - let caller handle

        logger.info(
            "Reserved %s credits for user %s, transaction: %s, new balance: %s",
            amount, db_user.user_id, transaction_id, db_user.credits,
        )
        return transaction

    @staticmethod
    async def confirm_credits(
        session: AsyncSession,
        transaction_id: str,
        metadata: Optional[dict] = None
    ) -> CreditTransaction:
        """
        Confirm credits were legitimately used.

        Called by middleware after successful response. Creates a
        confirmation record (amount = 0). The user's balance is not touched.
        The record's balance fields copy the original transaction's
        ``balance_after``.

        Args:
            session: Database session
            transaction_id: Original reservation transaction ID
            metadata: Additional context (response_status, etc.)

        Returns:
            The new ``confirm`` CreditTransaction (flushed, not committed)

        Raises:
            TransactionNotFoundError: If original transaction not found
        """
        result = await session.execute(
            select(CreditTransaction).where(
                CreditTransaction.transaction_id == transaction_id
            )
        )
        original = result.scalar_one_or_none()
        if original is None:
            raise TransactionNotFoundError(f"Transaction {transaction_id} not found")

        confirm_tx = CreditTransaction(
            transaction_id=f"cfm_{uuid.uuid4().hex[:16]}",
            user_id=original.user_id,
            transaction_type="confirm",
            amount=0,  # No balance change
            balance_before=original.balance_after,
            balance_after=original.balance_after,
            source="middleware",
            reference_type=original.reference_type,
            reference_id=original.reference_id,
            request_path=original.request_path,
            request_method=original.request_method,
            reason=f"Confirmed usage for {original.transaction_id}",
            extra_data={
                "original_transaction_id": transaction_id,
                **(metadata or {})
            }
        )

        session.add(confirm_tx)
        await session.flush()

        logger.info("Confirmed credits for transaction %s", transaction_id)
        return confirm_tx

    @staticmethod
    async def refund_credits(
        session: AsyncSession,
        transaction_id: str,
        reason: str,
        metadata: Optional[dict] = None
    ) -> CreditTransaction:
        """
        Refund reserved credits.

        Called by middleware after failure or refundable error. Adds the
        reserved amount back to the user's balance.

        Args:
            session: Database session
            transaction_id: Original reservation transaction ID
            reason: Reason for refund
            metadata: Additional context

        Returns:
            The new ``refund`` CreditTransaction (flushed, not committed)

        Raises:
            TransactionNotFoundError: If original transaction not found
            ValueError: If the original transaction is not a reservation
            UserNotFoundError: If user not found
        """
        result = await session.execute(
            select(CreditTransaction).where(
                CreditTransaction.transaction_id == transaction_id
            )
        )
        original = result.scalar_one_or_none()
        if original is None:
            raise TransactionNotFoundError(f"Transaction {transaction_id} not found")

        # Only reservations can be refunded. Applying abs() to any other
        # record's amount would mint credits, e.g. "refunding" a purchase or
        # a previous refund.
        if original.transaction_type != "reserve":
            raise ValueError(
                f"Transaction {transaction_id} is a "
                f"'{original.transaction_type}' transaction; "
                "only 'reserve' transactions can be refunded"
            )

        # NOTE(review): nothing here stops the same reservation from being
        # refunded twice. Callers must ensure refund runs at most once per
        # reservation.

        # Lock the user row so a concurrent reserve/refund cannot overwrite
        # this balance update.
        user = await CreditTransactionManager._get_user_for_update(
            session, original.user_id
        )
        if user is None:
            raise UserNotFoundError(f"User {original.user_id} not found")

        # Reservations store a negative amount; the refund adds it back.
        refund_amount = abs(original.amount)

        refund_tx = CreditTransaction(
            transaction_id=f"ref_{uuid.uuid4().hex[:16]}",
            user_id=user.id,
            transaction_type="refund",
            amount=refund_amount,  # Positive for addition
            balance_before=user.credits,
            balance_after=user.credits + refund_amount,
            source="middleware",
            reference_type=original.reference_type,
            reference_id=original.reference_id,
            request_path=original.request_path,
            request_method=original.request_method,
            reason=reason,
            extra_data={
                "original_transaction_id": transaction_id,
                **(metadata or {})
            }
        )

        user.credits += refund_amount

        session.add(refund_tx)
        await session.flush()

        logger.info(
            "Refunded %s credits for transaction %s, reason: %s, new balance: %s",
            refund_amount, transaction_id, reason[:100], user.credits,
        )
        return refund_tx

    @staticmethod
    async def add_credits(
        session: AsyncSession,
        user: User,
        amount: int,
        source: str,
        reference_type: Optional[str] = None,
        reference_id: Optional[str] = None,
        reason: Optional[str] = None,
        metadata: Optional[dict] = None,
        transaction_type: str = "purchase"
    ) -> CreditTransaction:
        """
        Add credits (purchase, bonus, etc).

        Used by payment router only.

        Args:
            session: Database session
            user: User model instance (only ``user.id`` is used)
            amount: Number of credits to add (must be >= 0)
            source: Source of transaction (e.g., "payment")
            reference_type: Type of reference (e.g., "payment")
            reference_id: Reference identifier (e.g., transaction_id)
            reason: Human-readable reason
            metadata: Additional context, stored as-is in ``extra_data``
            transaction_type: Type recorded on the audit row. Defaults to
                "purchase"; pass e.g. "bonus" for non-purchase grants.

        Returns:
            The new CreditTransaction (flushed, not committed)

        Raises:
            ValueError: If ``amount`` is negative
            UserNotFoundError: If user not found in database
        """
        if amount < 0:
            raise ValueError(f"Cannot add a negative amount of credits: {amount}")

        # Fetch a fresh, locked copy of the user. The ``user`` argument may be
        # detached from AuthMiddleware's closed session.
        db_user = await CreditTransactionManager._get_user_for_update(session, user.id)
        if db_user is None:
            raise UserNotFoundError(f"User {user.id} not found")

        transaction_id = f"add_{uuid.uuid4().hex[:16]}"

        transaction = CreditTransaction(
            transaction_id=transaction_id,
            user_id=db_user.id,
            transaction_type=transaction_type,
            amount=amount,
            balance_before=db_user.credits,
            balance_after=db_user.credits + amount,
            source=source,
            reference_type=reference_type,
            reference_id=reference_id,
            reason=reason,
            extra_data=metadata
        )

        db_user.credits += amount

        session.add(transaction)
        await session.flush()

        logger.info(
            "Added %s credits to user %s, source: %s, new balance: %s",
            amount, db_user.user_id, source, db_user.credits,
        )
        return transaction

    @staticmethod
    async def get_balance(
        session: AsyncSession,
        user_id: int,
        verify: bool = False
    ) -> int:
        """
        Get current balance, optionally verify against transactions.

        With ``verify=True`` the sum of all of the user's transaction amounts
        is compared with the stored balance. A mismatch is only logged as a
        warning; the stored balance is always what is returned.

        Args:
            session: Database session
            user_id: User ID
            verify: If True, calculate balance from transactions and compare

        Returns:
            Current stored balance

        Raises:
            UserNotFoundError: If user not found
        """
        result = await session.execute(
            select(User).where(User.id == user_id)
        )
        user = result.scalar_one_or_none()
        if user is None:
            raise UserNotFoundError(f"User {user_id} not found")

        if not verify:
            return user.credits

        tx_result = await session.execute(
            select(func.sum(CreditTransaction.amount)).where(
                CreditTransaction.user_id == user_id
            )
        )
        # SUM over zero rows is NULL, hence the fallback to 0.
        calculated_balance = tx_result.scalar() or 0

        if calculated_balance != user.credits:
            logger.warning(
                "Balance mismatch for user %s: stored=%s, calculated=%s",
                user_id, user.credits, calculated_balance,
            )

        return user.credits

    @staticmethod
    async def get_transaction_history(
        session: AsyncSession,
        user_id: int,
        transaction_type: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[CreditTransaction]:
        """
        Get transaction history with filters, newest first.

        Args:
            session: Database session
            user_id: User ID
            transaction_type: Filter by transaction type (falsy = no filter)
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            List of CreditTransaction records ordered by ``created_at`` desc
        """
        query = select(CreditTransaction).where(
            CreditTransaction.user_id == user_id
        )

        if transaction_type:
            query = query.where(CreditTransaction.transaction_type == transaction_type)

        query = query.order_by(desc(CreditTransaction.created_at)).offset(offset).limit(limit)

        result = await session.execute(query)
        return list(result.scalars().all())
# Public API of this module.
__all__ = [
    "CreditTransactionManager",
    "InsufficientCreditsError",
    "TransactionNotFoundError",
    "UserNotFoundError",
]