# apigateway/services/credit_service/transaction_manager.py
# Last change: commit 2dbfc89 by jebin2 ("credit issue fix")
"""
Credit Transaction Manager - Core credit operations service.
All credit operations flow through this service:
- Reserve credits (deduct from balance)
- Refund credits (add back to balance)
- Confirm credits (mark as used)
- Add credits (purchases, bonuses)
Provides complete audit trail with request/response context.
"""
import uuid
import logging
from datetime import datetime
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc, func
from fastapi import Request
from core.models import User, CreditTransaction
logger = logging.getLogger(__name__)
# =============================================================================
# Custom Exceptions
# =============================================================================
class InsufficientCreditsError(Exception):
    """Signals that a user's balance is too low to cover a credit reservation."""
class TransactionNotFoundError(Exception):
    """Signals that no credit transaction matches the requested transaction ID."""
class UserNotFoundError(Exception):
    """Signals that the referenced user row does not exist in the database."""
# =============================================================================
# Credit Transaction Manager
# =============================================================================
class CreditTransactionManager:
    """
    Centralized credit transaction management.

    All credit operations flow through this service. Every method only
    flushes its changes; committing (or rolling back) is left to the caller.

    The balance-changing methods (reserve, refund, add) load the user row with
    ``SELECT ... FOR UPDATE`` before reading ``credits``. Two concurrent
    operations on the same user are then serialized by the database. Without
    the lock, both would compute a new balance from the same stale value and
    one write would be lost.
    """

    # ------------------------------------------------------------------
    # Balance-changing operations
    # ------------------------------------------------------------------

    @staticmethod
    async def reserve_credits(
        session: AsyncSession,
        user: User,
        amount: int,
        source: str,
        reference_type: Optional[str] = None,
        reference_id: Optional[str] = None,
        reason: Optional[str] = None,
        metadata: Optional[dict] = None,
        request: Optional[Request] = None
    ) -> CreditTransaction:
        """
        Reserve credits (deduct from balance).

        Called by middleware when a request arrives. Writes a "reserve" ledger
        row with a negative amount and decrements the user's balance.

        Args:
            session: Database session (flushed, not committed)
            user: User model instance. Only ``user.id`` is used; the row is
                re-read from ``session`` because this instance may be detached
                from AuthMiddleware's closed session.
            amount: Number of credits to reserve (must not be negative)
            source: Source of transaction (e.g., "middleware")
            reference_type: Type of reference (e.g., "request", "job")
            reference_id: Reference identifier (e.g., "POST:/gemini/video")
            reason: Human-readable reason. Defaults to
                "Reserved for <request path>", or "Reserved for operation"
                when there is no request.
            metadata: Additional context, merged over the request metadata
            request: FastAPI request object. Its path, method, client IP and
                user agent are recorded.

        Returns:
            The new "reserve" CreditTransaction record

        Raises:
            ValueError: If ``amount`` is negative. A negative reservation would
                pass the balance check and add credits instead of deducting.
            UserNotFoundError: If user not found in database
            InsufficientCreditsError: If user doesn't have enough credits
        """
        if amount < 0:
            raise ValueError(f"Cannot reserve a negative amount of credits: {amount}")

        db_user = await CreditTransactionManager._get_user_for_update(session, user.id)
        balance_before = db_user.credits
        if balance_before < amount:
            raise InsufficientCreditsError(
                f"User has {balance_before} credits, needs {amount}"
            )

        transaction_id = f"ctx_{uuid.uuid4().hex[:16]}"
        request_path, request_method, request_metadata = (
            CreditTransactionManager._request_context(request)
        )

        transaction = CreditTransaction(
            transaction_id=transaction_id,
            user_id=db_user.id,
            transaction_type="reserve",
            amount=-amount,  # Negative for deduction
            balance_before=balance_before,
            balance_after=balance_before - amount,
            source=source,
            reference_type=reference_type,
            reference_id=reference_id,
            request_path=request_path,
            request_method=request_method,
            reason=reason or f"Reserved for {request_path or 'operation'}",
            # Caller-supplied metadata wins on key collisions.
            extra_data={**request_metadata, **(metadata or {})},
        )

        db_user.credits = balance_before - amount
        db_user.last_used_at = datetime.utcnow()
        session.add(transaction)
        await session.flush()  # Don't commit yet - let caller handle

        logger.info(
            "Reserved %s credits for user %s, transaction: %s, new balance: %s",
            amount, db_user.user_id, transaction_id, db_user.credits,
        )
        return transaction

    @staticmethod
    async def confirm_credits(
        session: AsyncSession,
        transaction_id: str,
        metadata: Optional[dict] = None
    ) -> CreditTransaction:
        """
        Confirm credits were legitimately used.

        Called by middleware after a successful response. Writes a "confirm"
        ledger row with amount 0; the user's balance is not touched.

        Args:
            session: Database session (flushed, not committed)
            transaction_id: Original reservation transaction ID
            metadata: Additional context (response_status, etc.)

        Returns:
            The new "confirm" CreditTransaction record

        Raises:
            TransactionNotFoundError: If original transaction not found
        """
        original = await CreditTransactionManager._get_transaction(session, transaction_id)

        confirm_tx = CreditTransaction(
            transaction_id=f"cfm_{uuid.uuid4().hex[:16]}",
            user_id=original.user_id,
            transaction_type="confirm",
            amount=0,  # No balance change
            # NOTE: both balance fields copy the reservation's post-deduction
            # balance. The user's current balance is not re-read here.
            balance_before=original.balance_after,
            balance_after=original.balance_after,
            source="middleware",
            reference_type=original.reference_type,
            reference_id=original.reference_id,
            request_path=original.request_path,
            request_method=original.request_method,
            reason=f"Confirmed usage for {original.transaction_id}",
            extra_data={
                "original_transaction_id": transaction_id,
                **(metadata or {}),
            },
        )

        session.add(confirm_tx)
        await session.flush()

        logger.info("Confirmed credits for transaction %s", transaction_id)
        return confirm_tx

    @staticmethod
    async def refund_credits(
        session: AsyncSession,
        transaction_id: str,
        reason: str,
        metadata: Optional[dict] = None
    ) -> CreditTransaction:
        """
        Refund reserved credits.

        Called by middleware after a failure or refundable error. Only
        "reserve" transactions can be refunded, and each one at most once. A
        repeated call returns the refund that was already recorded instead of
        crediting the user again.

        Args:
            session: Database session (flushed, not committed)
            transaction_id: Original reservation transaction ID
            reason: Reason for refund
            metadata: Additional context

        Returns:
            The new "refund" CreditTransaction record, or the existing one if
            this reservation was already refunded

        Raises:
            TransactionNotFoundError: If original transaction not found
            ValueError: If the transaction is not a reservation. Refunding a
                purchase or a refund would mint credits, since the refund
                amount is ``abs(original.amount)``.
            UserNotFoundError: If user not found
        """
        original = await CreditTransactionManager._get_transaction(session, transaction_id)
        if original.transaction_type != "reserve":
            raise ValueError(
                f"Transaction {transaction_id} has type "
                f"'{original.transaction_type}'; only reservations can be refunded"
            )

        # Lock the user row BEFORE the duplicate check. Concurrent refunds of
        # the same reservation then queue here, and the later one can see the
        # earlier one's refund once that lock is released (depends on the
        # isolation level, e.g. READ COMMITTED).
        user = await CreditTransactionManager._get_user_for_update(
            session, original.user_id
        )

        existing = await CreditTransactionManager._find_refund(session, original)
        if existing is not None:
            logger.warning(
                "Transaction %s was already refunded by %s; not refunding again",
                transaction_id, existing.transaction_id,
            )
            return existing

        # Reverse of the original deduction (reserve amounts are stored negative).
        refund_amount = abs(original.amount)
        balance_before = user.credits

        refund_tx = CreditTransaction(
            transaction_id=f"ref_{uuid.uuid4().hex[:16]}",
            user_id=user.id,
            transaction_type="refund",
            amount=refund_amount,  # Positive for addition
            balance_before=balance_before,
            balance_after=balance_before + refund_amount,
            source="middleware",
            reference_type=original.reference_type,
            reference_id=original.reference_id,
            request_path=original.request_path,
            request_method=original.request_method,
            reason=reason,
            extra_data={
                "original_transaction_id": transaction_id,
                **(metadata or {}),
            },
        )

        user.credits = balance_before + refund_amount
        session.add(refund_tx)
        await session.flush()

        logger.info(
            "Refunded %s credits for transaction %s, reason: %s, new balance: %s",
            refund_amount, transaction_id, reason[:100], user.credits,
        )
        return refund_tx

    @staticmethod
    async def add_credits(
        session: AsyncSession,
        user: User,
        amount: int,
        source: str,
        reference_type: Optional[str] = None,
        reference_id: Optional[str] = None,
        reason: Optional[str] = None,
        metadata: Optional[dict] = None
    ) -> CreditTransaction:
        """
        Add credits (purchase, bonus, etc).

        Used by payment router only. Writes a "purchase" ledger row and
        increments the user's balance.

        Args:
            session: Database session (flushed, not committed)
            user: User model instance. Only ``user.id`` is used; the row is
                re-read from ``session`` because this instance may be detached
                from AuthMiddleware's closed session.
            amount: Number of credits to add (must not be negative)
            source: Source of transaction (e.g., "payment")
            reference_type: Type of reference (e.g., "payment")
            reference_id: Reference identifier (e.g., transaction_id)
            reason: Human-readable reason
            metadata: Additional context, stored as-is in ``extra_data``

        Returns:
            The new "purchase" CreditTransaction record

        Raises:
            ValueError: If ``amount`` is negative. That would silently deduct
                credits under a "purchase" record.
            UserNotFoundError: If user not found in database
        """
        if amount < 0:
            raise ValueError(f"Cannot add a negative amount of credits: {amount}")

        db_user = await CreditTransactionManager._get_user_for_update(session, user.id)
        balance_before = db_user.credits
        transaction_id = f"add_{uuid.uuid4().hex[:16]}"

        transaction = CreditTransaction(
            transaction_id=transaction_id,
            user_id=db_user.id,
            transaction_type="purchase",
            amount=amount,
            balance_before=balance_before,
            balance_after=balance_before + amount,
            source=source,
            reference_type=reference_type,
            reference_id=reference_id,
            reason=reason,
            extra_data=metadata,
        )

        db_user.credits = balance_before + amount
        session.add(transaction)
        await session.flush()

        logger.info(
            "Added %s credits to user %s, source: %s, new balance: %s",
            amount, db_user.user_id, source, db_user.credits,
        )
        return transaction

    # ------------------------------------------------------------------
    # Read-only queries
    # ------------------------------------------------------------------

    @staticmethod
    async def get_balance(
        session: AsyncSession,
        user_id: int,
        verify: bool = False
    ) -> int:
        """
        Get the stored balance, optionally cross-checking it against the ledger.

        Args:
            session: Database session
            user_id: User ID
            verify: If True, sum all of the user's transaction amounts and log
                a warning when the sum differs from the stored balance. The
                stored balance is returned either way.

        Returns:
            Current stored balance (``User.credits``)

        Raises:
            UserNotFoundError: If user not found
        """
        result = await session.execute(
            select(User).where(User.id == user_id)
        )
        user = result.scalar_one_or_none()
        if not user:
            raise UserNotFoundError(f"User {user_id} not found")

        if not verify:
            return user.credits

        # NOTE(review): this sum only equals the balance if every balance
        # change (including any initial/sign-up grant) is recorded as a
        # transaction -- confirm, otherwise mismatches are expected.
        tx_result = await session.execute(
            select(func.sum(CreditTransaction.amount)).where(
                CreditTransaction.user_id == user_id
            )
        )
        calculated_balance = tx_result.scalar() or 0

        if calculated_balance != user.credits:
            logger.warning(
                "Balance mismatch for user %s: stored=%s, calculated=%s",
                user_id, user.credits, calculated_balance,
            )

        return user.credits

    @staticmethod
    async def get_transaction_history(
        session: AsyncSession,
        user_id: int,
        transaction_type: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[CreditTransaction]:
        """
        Get a user's transaction history, newest first.

        Args:
            session: Database session
            user_id: User ID
            transaction_type: Filter by transaction type (falsy = no filter)
            limit: Maximum number of results
            offset: Offset for pagination

        Returns:
            List of CreditTransaction records ordered by ``created_at`` descending
        """
        query = select(CreditTransaction).where(
            CreditTransaction.user_id == user_id
        )
        if transaction_type:
            query = query.where(CreditTransaction.transaction_type == transaction_type)
        query = query.order_by(desc(CreditTransaction.created_at)).offset(offset).limit(limit)

        result = await session.execute(query)
        return list(result.scalars().all())

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    async def _get_user_for_update(session: AsyncSession, user_id: int) -> User:
        """Load the user row locked FOR UPDATE with fresh attributes, or raise."""
        # with_for_update skips the identity-map shortcut and emits
        # SELECT ... FOR UPDATE (backends without row locks, such as SQLite,
        # omit the clause). populate_existing overwrites an instance already
        # loaded in this session, so the balance we check is the one we locked.
        db_user = await session.get(
            User, user_id, with_for_update=True, populate_existing=True
        )
        if db_user is None:
            raise UserNotFoundError(f"User {user_id} not found")
        return db_user

    @staticmethod
    async def _get_transaction(
        session: AsyncSession, transaction_id: str
    ) -> CreditTransaction:
        """Return the transaction with ``transaction_id``, or raise TransactionNotFoundError."""
        result = await session.execute(
            select(CreditTransaction).where(
                CreditTransaction.transaction_id == transaction_id
            )
        )
        original = result.scalar_one_or_none()
        if original is None:
            raise TransactionNotFoundError(f"Transaction {transaction_id} not found")
        return original

    @staticmethod
    async def _find_refund(
        session: AsyncSession, original: CreditTransaction
    ) -> Optional[CreditTransaction]:
        """Return a refund already recorded against ``original``, if any."""
        # A refund records its reservation in extra_data["original_transaction_id"].
        # Narrow with plain column filters in SQL, then match the JSON field in
        # Python so the check does not depend on backend-specific JSON operators.
        result = await session.execute(
            select(CreditTransaction).where(
                CreditTransaction.user_id == original.user_id,
                CreditTransaction.transaction_type == "refund",
                CreditTransaction.reference_id == original.reference_id,
            )
        )
        for candidate in result.scalars():
            data = candidate.extra_data
            if (
                isinstance(data, dict)
                and data.get("original_transaction_id") == original.transaction_id
            ):
                return candidate
        return None

    @staticmethod
    def _request_context(request: Optional[Request]) -> tuple:
        """Return ``(path, method, metadata)`` for ``request``; ``(None, None, {})`` if absent."""
        if request is None:
            return None, None, {}
        path = request.url.path
        method = request.method
        request_metadata = {
            "path": path,
            "method": method,
            "ip": request.client.host if request.client else None,
            "user_agent": request.headers.get("user-agent"),
        }
        return path, method, request_metadata
# Public API of this module: the manager and the exceptions callers catch.
__all__ = [
    "CreditTransactionManager",
    "InsufficientCreditsError",
    "TransactionNotFoundError",
    "UserNotFoundError",
]