""" Credit Transaction Manager - Core credit operations service. All credit operations flow through this service: - Reserve credits (deduct from balance) - Refund credits (add back to balance) - Confirm credits (mark as used) - Add credits (purchases, bonuses) Provides complete audit trail with request/response context. """ import uuid import logging from datetime import datetime from typing import Optional, List from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, desc, func from fastapi import Request from core.models import User, CreditTransaction logger = logging.getLogger(__name__) # ============================================================================= # Custom Exceptions # ============================================================================= class InsufficientCreditsError(Exception): """Raised when user doesn't have enough credits.""" pass class TransactionNotFoundError(Exception): """Raised when a transaction cannot be found.""" pass class UserNotFoundError(Exception): """Raised when a user cannot be found.""" pass # ============================================================================= # Credit Transaction Manager # ============================================================================= class CreditTransactionManager: """ Centralized credit transaction management. All credit operations flow through this service. """ @staticmethod async def reserve_credits( session: AsyncSession, user: User, amount: int, source: str, reference_type: Optional[str] = None, reference_id: Optional[str] = None, reason: Optional[str] = None, metadata: Optional[dict] = None, request: Optional[Request] = None ) -> CreditTransaction: """ Reserve credits (deduct from balance). Called by middleware when request arrives. Args: session: Database session user: User model instance (used for user.id reference) amount: Number of credits to reserve source: Source of transaction (e.g., "middleware") reference_type: Type of reference (e.g., "request", "job") reference_id: Reference identifier (e.g., "POST:/gemini/video") reason: Human-readable reason metadata: Additional context request: FastAPI request object (for extracting metadata) Returns: CreditTransaction record Raises: InsufficientCreditsError: If user doesn't have enough credits UserNotFoundError: If user not found in database """ # Fetch fresh user from DB to ensure we have current state # (user parameter may be detached from AuthMiddleware's closed session) db_user = await session.get(User, user.id) if not db_user: raise UserNotFoundError(f"User {user.id} not found") if db_user.credits < amount: raise InsufficientCreditsError( f"User has {db_user.credits} credits, needs {amount}" ) # Generate transaction ID transaction_id = f"ctx_{uuid.uuid4().hex[:16]}" # Capture request metadata request_metadata = {} request_path = None request_method = None if request: request_path = request.url.path request_method = request.method request_metadata = { "path": request_path, "method": request_method, "ip": request.client.host if request.client else None, "user_agent": request.headers.get("user-agent") } # Create transaction record transaction = CreditTransaction( transaction_id=transaction_id, user_id=db_user.id, transaction_type="reserve", amount=-amount, # Negative for deduction balance_before=db_user.credits, balance_after=db_user.credits - amount, source=source, reference_type=reference_type, reference_id=reference_id, request_path=request_path, request_method=request_method, reason=reason or f"Reserved for {request_path or 'operation'}", extra_data={**request_metadata, **(metadata or {})} ) # Update user balance db_user.credits -= amount db_user.last_used_at = datetime.utcnow() session.add(transaction) await session.flush() # Don't commit yet - let caller handle logger.info( f"Reserved {amount} credits for user {db_user.user_id}, " f"transaction: {transaction_id}, new balance: {db_user.credits}" ) return transaction @staticmethod async def confirm_credits( session: AsyncSession, transaction_id: str, metadata: Optional[dict] = None ) -> CreditTransaction: """ Confirm credits were legitimately used. Called by middleware after successful response. Creates a confirmation record (amount = 0). Args: session: Database session transaction_id: Original reservation transaction ID metadata: Additional context (response_status, etc.) Returns: CreditTransaction record Raises: TransactionNotFoundError: If original transaction not found """ # Find original reservation result = await session.execute( select(CreditTransaction).where( CreditTransaction.transaction_id == transaction_id ) ) original = result.scalar_one_or_none() if not original: raise TransactionNotFoundError(f"Transaction {transaction_id} not found") # Create confirmation transaction confirm_tx = CreditTransaction( transaction_id=f"cfm_{uuid.uuid4().hex[:16]}", user_id=original.user_id, transaction_type="confirm", amount=0, # No balance change balance_before=original.balance_after, balance_after=original.balance_after, source="middleware", reference_type=original.reference_type, reference_id=original.reference_id, request_path=original.request_path, request_method=original.request_method, reason=f"Confirmed usage for {original.transaction_id}", extra_data={ "original_transaction_id": transaction_id, **(metadata or {}) } ) session.add(confirm_tx) await session.flush() logger.info(f"Confirmed credits for transaction {transaction_id}") return confirm_tx @staticmethod async def refund_credits( session: AsyncSession, transaction_id: str, reason: str, metadata: Optional[dict] = None ) -> CreditTransaction: """ Refund reserved credits. Called by middleware after failure or refundable error. Args: session: Database session transaction_id: Original reservation transaction ID reason: Reason for refund metadata: Additional context Returns: CreditTransaction record Raises: TransactionNotFoundError: If original transaction not found UserNotFoundError: If user not found """ # Find original reservation result = await session.execute( select(CreditTransaction).where( CreditTransaction.transaction_id == transaction_id ) ) original = result.scalar_one_or_none() if not original: raise TransactionNotFoundError(f"Transaction {transaction_id} not found") # Get user user_result = await session.execute( select(User).where(User.id == original.user_id) ) user = user_result.scalar_one_or_none() if not user: raise UserNotFoundError(f"User {original.user_id} not found") # Calculate refund amount (reverse of original deduction) refund_amount = abs(original.amount) # Create refund transaction refund_tx = CreditTransaction( transaction_id=f"ref_{uuid.uuid4().hex[:16]}", user_id=user.id, transaction_type="refund", amount=refund_amount, # Positive for addition balance_before=user.credits, balance_after=user.credits + refund_amount, source="middleware", reference_type=original.reference_type, reference_id=original.reference_id, request_path=original.request_path, request_method=original.request_method, reason=reason, extra_data={ "original_transaction_id": transaction_id, **(metadata or {}) } ) # Update user balance user.credits += refund_amount session.add(refund_tx) await session.flush() logger.info( f"Refunded {refund_amount} credits for transaction {transaction_id}, " f"reason: {reason[:100]}, new balance: {user.credits}" ) return refund_tx @staticmethod async def add_credits( session: AsyncSession, user: User, amount: int, source: str, reference_type: Optional[str] = None, reference_id: Optional[str] = None, reason: Optional[str] = None, metadata: Optional[dict] = None ) -> CreditTransaction: """ Add credits (purchase, bonus, etc). Used by payment router only. Args: session: Database session user: User model instance (used for user.id reference) amount: Number of credits to add source: Source of transaction (e.g., "payment") reference_type: Type of reference (e.g., "payment") reference_id: Reference identifier (e.g., transaction_id) reason: Human-readable reason metadata: Additional context Returns: CreditTransaction record Raises: UserNotFoundError: If user not found in database """ # Fetch fresh user from DB to ensure we have current state # (user parameter may be detached from AuthMiddleware's closed session) db_user = await session.get(User, user.id) if not db_user: raise UserNotFoundError(f"User {user.id} not found") transaction_id = f"add_{uuid.uuid4().hex[:16]}" transaction = CreditTransaction( transaction_id=transaction_id, user_id=db_user.id, transaction_type="purchase", amount=amount, balance_before=db_user.credits, balance_after=db_user.credits + amount, source=source, reference_type=reference_type, reference_id=reference_id, reason=reason, extra_data=metadata ) db_user.credits += amount session.add(transaction) await session.flush() logger.info( f"Added {amount} credits to user {db_user.user_id}, " f"source: {source}, new balance: {db_user.credits}" ) return transaction @staticmethod async def get_balance( session: AsyncSession, user_id: int, verify: bool = False ) -> int: """ Get current balance, optionally verify against transactions. Args: session: Database session user_id: User ID verify: If True, calculate balance from transactions and compare Returns: Current balance """ # Get user result = await session.execute( select(User).where(User.id == user_id) ) user = result.scalar_one_or_none() if not user: raise UserNotFoundError(f"User {user_id} not found") if not verify: return user.credits # Calculate balance from transactions tx_result = await session.execute( select(func.sum(CreditTransaction.amount)).where( CreditTransaction.user_id == user_id ) ) calculated_balance = tx_result.scalar() or 0 if calculated_balance != user.credits: logger.warning( f"Balance mismatch for user {user_id}: " f"stored={user.credits}, calculated={calculated_balance}" ) return user.credits @staticmethod async def get_transaction_history( session: AsyncSession, user_id: int, transaction_type: Optional[str] = None, limit: int = 50, offset: int = 0 ) -> List[CreditTransaction]: """ Get transaction history with filters. Args: session: Database session user_id: User ID transaction_type: Filter by transaction type limit: Maximum number of results offset: Offset for pagination Returns: List of CreditTransaction records """ query = select(CreditTransaction).where( CreditTransaction.user_id == user_id ) if transaction_type: query = query.where(CreditTransaction.transaction_type == transaction_type) query = query.order_by(desc(CreditTransaction.created_at)).offset(offset).limit(limit) result = await session.execute(query) return list(result.scalars().all()) __all__ = [ 'CreditTransactionManager', 'InsufficientCreditsError', 'TransactionNotFoundError', 'UserNotFoundError' ]