Spaces:
Paused
Paused
"""
Module responsible for:

1. Writing spend increments either to an in-memory list of transactions or to Redis
2. Reading increments from Redis or the in-memory list of transactions and committing them to the DB
"""
| import asyncio | |
| import json | |
| import os | |
| import time | |
| import traceback | |
| from datetime import datetime, timedelta | |
| from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast, overload | |
| import litellm | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.caching import DualCache, RedisCache | |
| from litellm.constants import DB_SPEND_UPDATE_JOB_NAME | |
| from litellm.proxy._types import ( | |
| DB_CONNECTION_ERROR_TYPES, | |
| BaseDailySpendTransaction, | |
| DailyTagSpendTransaction, | |
| DailyTeamSpendTransaction, | |
| DailyUserSpendTransaction, | |
| DBSpendUpdateTransactions, | |
| Litellm_EntityType, | |
| LiteLLM_UserTable, | |
| SpendLogsMetadata, | |
| SpendLogsPayload, | |
| SpendUpdateQueueItem, | |
| ) | |
| from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import ( | |
| DailySpendUpdateQueue, | |
| ) | |
| from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager | |
| from litellm.proxy.db.db_transaction_queue.redis_update_buffer import RedisUpdateBuffer | |
| from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue | |
| if TYPE_CHECKING: | |
| from litellm.proxy.utils import PrismaClient, ProxyLogging | |
| else: | |
| PrismaClient = Any | |
| ProxyLogging = Any | |
class DBSpendUpdateWriter:
    """
    Module responsible for
    1. Writing spend increments to either in memory list of transactions or to redis
    2. Reading increments from redis or in memory list of transactions and committing them to db
    """

    def __init__(
        self,
        redis_cache: Optional[RedisCache] = None,
    ):
        # Optional Redis backend; when None, spend updates are buffered in-memory only.
        self.redis_cache = redis_cache
        self.redis_update_buffer = RedisUpdateBuffer(redis_cache=self.redis_cache)
        # Used to elect a single "leader" pod to flush Redis buffers to the DB.
        self.pod_lock_manager = PodLockManager()
        # One queue per aggregation dimension: per-entity totals, plus
        # daily rollups for users, teams, and tags respectively.
        self.spend_update_queue = SpendUpdateQueue()
        self.daily_spend_update_queue = DailySpendUpdateQueue()
        self.daily_team_spend_update_queue = DailySpendUpdateQueue()
        self.daily_tag_spend_update_queue = DailySpendUpdateQueue()
    async def update_database(
        # LiteLLM management object fields
        self,
        token: Optional[str],
        user_id: Optional[str],
        end_user_id: Optional[str],
        team_id: Optional[str],
        org_id: Optional[str],
        # Completion object fields
        kwargs: Optional[dict],
        completion_response: Optional[Union[litellm.ModelResponse, Any, Exception]],
        start_time: Optional[datetime],
        end_time: Optional[datetime],
        response_cost: Optional[float],
    ):
        """
        Fan out one request's spend to all tracking tables/queues.

        Builds the spend-log payload, then fires independent asyncio tasks to
        queue per-entity (user/key/team/org) spend increments; the spend log
        itself is buffered synchronously unless `disable_spend_logs` is set.
        All exceptions are swallowed and logged at debug level so spend
        tracking never fails the request path.
        """
        from litellm.proxy.proxy_server import (
            disable_spend_logs,
            litellm_proxy_budget_name,
            prisma_client,
            user_api_key_cache,
        )
        from litellm.proxy.utils import ProxyUpdateSpend, hash_token

        try:
            verbose_proxy_logger.debug(
                f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
            )
            if ProxyUpdateSpend.disable_spend_updates() is True:
                return
            # Only hash raw API keys ("sk-..."); anything else is assumed to be
            # a pre-hashed token and passed through unchanged.
            if token is not None and isinstance(token, str) and token.startswith("sk-"):
                hashed_token = hash_token(token=token)
            else:
                hashed_token = token

            ## CREATE SPEND LOG PAYLOAD ##
            from litellm.proxy.spend_tracking.spend_tracking_utils import (
                get_logging_payload,
            )

            payload = get_logging_payload(
                kwargs=kwargs,
                response_obj=completion_response,
                start_time=start_time,
                end_time=end_time,
            )
            payload["spend"] = response_cost or 0.0
            # Normalize datetimes to ISO strings so the payload serializes cleanly.
            if isinstance(payload["startTime"], datetime):
                payload["startTime"] = payload["startTime"].isoformat()
            if isinstance(payload["endTime"], datetime):
                payload["endTime"] = payload["endTime"].isoformat()

            # Fire-and-forget: per-entity spend increments run concurrently and
            # are not awaited here.
            asyncio.create_task(
                self._update_user_db(
                    response_cost=response_cost,
                    user_id=user_id,
                    prisma_client=prisma_client,
                    user_api_key_cache=user_api_key_cache,
                    litellm_proxy_budget_name=litellm_proxy_budget_name,
                    end_user_id=end_user_id,
                )
            )
            asyncio.create_task(
                self._update_key_db(
                    response_cost=response_cost,
                    hashed_token=hashed_token,
                    prisma_client=prisma_client,
                )
            )
            asyncio.create_task(
                self._update_team_db(
                    response_cost=response_cost,
                    team_id=team_id,
                    user_id=user_id,
                    prisma_client=prisma_client,
                )
            )
            asyncio.create_task(
                self._update_org_db(
                    response_cost=response_cost,
                    org_id=org_id,
                    prisma_client=prisma_client,
                )
            )
            if disable_spend_logs is False:
                await self._insert_spend_log_to_db(
                    payload=payload,
                    prisma_client=prisma_client,
                )
            else:
                verbose_proxy_logger.info(
                    "disable_spend_logs=True. Skipping writing spend logs to db. Other spend updates - Key/User/Team table will still occur."
                )

            # Daily (per-day aggregated) rollups for user / team / tag spend.
            asyncio.create_task(
                self.add_spend_log_transaction_to_daily_user_transaction(
                    payload=payload,
                    prisma_client=prisma_client,
                )
            )
            asyncio.create_task(
                self.add_spend_log_transaction_to_daily_team_transaction(
                    payload=payload,
                    prisma_client=prisma_client,
                )
            )
            asyncio.create_task(
                self.add_spend_log_transaction_to_daily_tag_transaction(
                    payload=payload,
                    prisma_client=prisma_client,
                )
            )

            verbose_proxy_logger.debug("Runs spend update on all tables")
        except Exception:
            verbose_proxy_logger.debug(
                f"Error updating Prisma database: {traceback.format_exc()}"
            )
| async def _update_key_db( | |
| self, | |
| response_cost: Optional[float], | |
| hashed_token: Optional[str], | |
| prisma_client: Optional[PrismaClient], | |
| ): | |
| try: | |
| if hashed_token is None or prisma_client is None: | |
| return | |
| await self.spend_update_queue.add_update( | |
| update=SpendUpdateQueueItem( | |
| entity_type=Litellm_EntityType.KEY, | |
| entity_id=hashed_token, | |
| response_cost=response_cost, | |
| ) | |
| ) | |
| except Exception as e: | |
| verbose_proxy_logger.exception( | |
| f"Update Key DB Call failed to execute - {str(e)}" | |
| ) | |
| raise e | |
| async def _update_user_db( | |
| self, | |
| response_cost: Optional[float], | |
| user_id: Optional[str], | |
| prisma_client: Optional[PrismaClient], | |
| user_api_key_cache: DualCache, | |
| litellm_proxy_budget_name: Optional[str], | |
| end_user_id: Optional[str] = None, | |
| ): | |
| """ | |
| - Update that user's row | |
| - Update litellm-proxy-budget row (global proxy spend) | |
| """ | |
| ## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db | |
| existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id) | |
| if existing_user_obj is not None and isinstance(existing_user_obj, dict): | |
| existing_user_obj = LiteLLM_UserTable(**existing_user_obj) | |
| try: | |
| if prisma_client is not None: # update | |
| user_ids = [user_id] | |
| if ( | |
| litellm.max_budget > 0 | |
| ): # track global proxy budget, if user set max budget | |
| user_ids.append(litellm_proxy_budget_name) | |
| for _id in user_ids: | |
| if _id is not None: | |
| await self.spend_update_queue.add_update( | |
| update=SpendUpdateQueueItem( | |
| entity_type=Litellm_EntityType.USER, | |
| entity_id=_id, | |
| response_cost=response_cost, | |
| ) | |
| ) | |
| if end_user_id is not None: | |
| await self.spend_update_queue.add_update( | |
| update=SpendUpdateQueueItem( | |
| entity_type=Litellm_EntityType.END_USER, | |
| entity_id=end_user_id, | |
| response_cost=response_cost, | |
| ) | |
| ) | |
| except Exception as e: | |
| verbose_proxy_logger.info( | |
| "\033[91m" | |
| + f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}" | |
| ) | |
| async def _update_team_db( | |
| self, | |
| response_cost: Optional[float], | |
| team_id: Optional[str], | |
| user_id: Optional[str], | |
| prisma_client: Optional[PrismaClient], | |
| ): | |
| try: | |
| if team_id is None or prisma_client is None: | |
| verbose_proxy_logger.debug( | |
| "track_cost_callback: team_id is None or prisma_client is None. Not tracking spend for team" | |
| ) | |
| return | |
| await self.spend_update_queue.add_update( | |
| update=SpendUpdateQueueItem( | |
| entity_type=Litellm_EntityType.TEAM, | |
| entity_id=team_id, | |
| response_cost=response_cost, | |
| ) | |
| ) | |
| try: | |
| # Track spend of the team member within this team | |
| if user_id is not None: | |
| # key is "team_id::<value>::user_id::<value>" | |
| team_member_key = f"team_id::{team_id}::user_id::{user_id}" | |
| await self.spend_update_queue.add_update( | |
| update=SpendUpdateQueueItem( | |
| entity_type=Litellm_EntityType.TEAM_MEMBER, | |
| entity_id=team_member_key, | |
| response_cost=response_cost, | |
| ) | |
| ) | |
| except Exception: | |
| pass | |
| except Exception as e: | |
| verbose_proxy_logger.info( | |
| f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}" | |
| ) | |
| raise e | |
| async def _update_org_db( | |
| self, | |
| response_cost: Optional[float], | |
| org_id: Optional[str], | |
| prisma_client: Optional[PrismaClient], | |
| ): | |
| try: | |
| if org_id is None or prisma_client is None: | |
| verbose_proxy_logger.debug( | |
| "track_cost_callback: org_id is None or prisma_client is None. Not tracking spend for org" | |
| ) | |
| return | |
| await self.spend_update_queue.add_update( | |
| update=SpendUpdateQueueItem( | |
| entity_type=Litellm_EntityType.ORGANIZATION, | |
| entity_id=org_id, | |
| response_cost=response_cost, | |
| ) | |
| ) | |
| except Exception as e: | |
| verbose_proxy_logger.info( | |
| f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}" | |
| ) | |
| raise e | |
| async def _insert_spend_log_to_db( | |
| self, | |
| payload: Union[dict, SpendLogsPayload], | |
| prisma_client: Optional[PrismaClient] = None, | |
| spend_logs_url: Optional[str] = os.getenv("SPEND_LOGS_URL"), | |
| ) -> Optional[PrismaClient]: | |
| verbose_proxy_logger.info( | |
| "Writing spend log to db - request_id: {}, spend: {}".format( | |
| payload.get("request_id"), payload.get("spend") | |
| ) | |
| ) | |
| if prisma_client is not None and spend_logs_url is not None: | |
| prisma_client.spend_log_transactions.append(payload) | |
| elif prisma_client is not None: | |
| prisma_client.spend_log_transactions.append(payload) | |
| else: | |
| verbose_proxy_logger.debug( | |
| "prisma_client is None. Skipping writing spend logs to db." | |
| ) | |
| return prisma_client | |
    async def db_update_spend_transaction_handler(
        self,
        prisma_client: PrismaClient,
        n_retry_times: int,
        proxy_logging_obj: ProxyLogging,
    ):
        """
        Handles committing update spend transactions to db

        `UPDATES` can lead to deadlocks, hence we handle them separately

        Args:
            prisma_client: PrismaClient object
            n_retry_times: int, number of retry times
            proxy_logging_obj: ProxyLogging object

        How this works:
        - Check `general_settings.use_redis_transaction_buffer`
            - If enabled, write in-memory transactions to Redis
            - Check if this Pod should read from the DB
        else:
            - Regular flow of this method
        """
        if RedisUpdateBuffer._should_commit_spend_updates_to_redis():
            await self._commit_spend_updates_to_db_with_redis(
                prisma_client=prisma_client,
                n_retry_times=n_retry_times,
                proxy_logging_obj=proxy_logging_obj,
            )
        else:
            await self._commit_spend_updates_to_db_without_redis_buffer(
                prisma_client=prisma_client,
                n_retry_times=n_retry_times,
                proxy_logging_obj=proxy_logging_obj,
            )
    async def _commit_spend_updates_to_db_with_redis(
        self,
        prisma_client: PrismaClient,
        n_retry_times: int,
        proxy_logging_obj: ProxyLogging,
    ):
        """
        Handler to commit spend updates to Redis and attempt to acquire lock to commit to db

        This is a v2 scalable approach to first commit spend updates to redis, then commit to db

        This minimizes DB Deadlocks since
        - All pods only need to write their spend updates to redis
        - Only 1 pod will commit to db at a time (based on if it can acquire the lock over writing to DB)
        """
        # Every pod flushes its in-memory queues into the shared Redis buffer.
        await self.redis_update_buffer.store_in_memory_spend_updates_in_redis(
            spend_update_queue=self.spend_update_queue,
            daily_spend_update_queue=self.daily_spend_update_queue,
            daily_team_spend_update_queue=self.daily_team_spend_update_queue,
            daily_tag_spend_update_queue=self.daily_tag_spend_update_queue,
        )

        # Only commit from redis to db if this pod is the leader
        if await self.pod_lock_manager.acquire_lock(
            cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
        ):
            verbose_proxy_logger.debug("acquired lock for spend updates")

            try:
                # Drain each Redis buffer in turn and commit it to the DB.
                db_spend_update_transactions = (
                    await self.redis_update_buffer.get_all_update_transactions_from_redis_buffer()
                )
                if db_spend_update_transactions is not None:
                    await self._commit_spend_updates_to_db(
                        prisma_client=prisma_client,
                        n_retry_times=n_retry_times,
                        proxy_logging_obj=proxy_logging_obj,
                        db_spend_update_transactions=db_spend_update_transactions,
                    )

                daily_spend_update_transactions = (
                    await self.redis_update_buffer.get_all_daily_spend_update_transactions_from_redis_buffer()
                )
                if daily_spend_update_transactions is not None:
                    await DBSpendUpdateWriter.update_daily_user_spend(
                        n_retry_times=n_retry_times,
                        prisma_client=prisma_client,
                        proxy_logging_obj=proxy_logging_obj,
                        daily_spend_transactions=daily_spend_update_transactions,
                    )

                daily_team_spend_update_transactions = (
                    await self.redis_update_buffer.get_all_daily_team_spend_update_transactions_from_redis_buffer()
                )
                if daily_team_spend_update_transactions is not None:
                    await DBSpendUpdateWriter.update_daily_team_spend(
                        n_retry_times=n_retry_times,
                        prisma_client=prisma_client,
                        proxy_logging_obj=proxy_logging_obj,
                        daily_spend_transactions=daily_team_spend_update_transactions,
                    )

                daily_tag_spend_update_transactions = (
                    await self.redis_update_buffer.get_all_daily_tag_spend_update_transactions_from_redis_buffer()
                )
                if daily_tag_spend_update_transactions is not None:
                    await DBSpendUpdateWriter.update_daily_tag_spend(
                        n_retry_times=n_retry_times,
                        prisma_client=prisma_client,
                        proxy_logging_obj=proxy_logging_obj,
                        daily_spend_transactions=daily_tag_spend_update_transactions,
                    )
            except Exception as e:
                # Errors are logged but not re-raised — the lock is still
                # released in `finally` so another pod can take over.
                verbose_proxy_logger.error(f"Error committing spend updates: {e}")
            finally:
                await self.pod_lock_manager.release_lock(
                    cronjob_id=DB_SPEND_UPDATE_JOB_NAME,
                )
    async def _commit_spend_updates_to_db_without_redis_buffer(
        self,
        prisma_client: PrismaClient,
        n_retry_times: int,
        proxy_logging_obj: ProxyLogging,
    ):
        """
        Commits all the spend `UPDATE` transactions to the Database

        This is the regular flow of committing to db without using a redis buffer

        Note: This flow causes Deadlocks in production (1K RPS+). Use self._commit_spend_updates_to_db_with_redis() instead if you expect 1K+ RPS.
        """

        # Aggregate all in memory spend updates (key, user, end_user, team, team_member, org) and commit to db
        ################## Spend Update Transactions ##################
        db_spend_update_transactions = (
            await self.spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
        )
        await self._commit_spend_updates_to_db(
            prisma_client=prisma_client,
            n_retry_times=n_retry_times,
            proxy_logging_obj=proxy_logging_obj,
            db_spend_update_transactions=db_spend_update_transactions,
        )

        ################## Daily Spend Update Transactions ##################
        # Aggregate all in memory daily spend transactions and commit to db
        daily_spend_update_transactions = cast(
            Dict[str, DailyUserSpendTransaction],
            await self.daily_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions(),
        )
        await DBSpendUpdateWriter.update_daily_user_spend(
            n_retry_times=n_retry_times,
            prisma_client=prisma_client,
            proxy_logging_obj=proxy_logging_obj,
            daily_spend_transactions=daily_spend_update_transactions,
        )

        ################## Daily Team Spend Update Transactions ##################
        # Aggregate all in memory daily team spend transactions and commit to db
        daily_team_spend_update_transactions = cast(
            Dict[str, DailyTeamSpendTransaction],
            await self.daily_team_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions(),
        )
        await DBSpendUpdateWriter.update_daily_team_spend(
            n_retry_times=n_retry_times,
            prisma_client=prisma_client,
            proxy_logging_obj=proxy_logging_obj,
            daily_spend_transactions=daily_team_spend_update_transactions,
        )

        ################## Daily Tag Spend Update Transactions ##################
        # Aggregate all in memory daily tag spend transactions and commit to db
        daily_tag_spend_update_transactions = cast(
            Dict[str, DailyTagSpendTransaction],
            await self.daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions(),
        )
        await DBSpendUpdateWriter.update_daily_tag_spend(
            n_retry_times=n_retry_times,
            prisma_client=prisma_client,
            proxy_logging_obj=proxy_logging_obj,
            daily_spend_transactions=daily_tag_spend_update_transactions,
        )
| async def _commit_spend_updates_to_db( # noqa: PLR0915 | |
| self, | |
| prisma_client: PrismaClient, | |
| n_retry_times: int, | |
| proxy_logging_obj: ProxyLogging, | |
| db_spend_update_transactions: DBSpendUpdateTransactions, | |
| ): | |
| """ | |
| Commits all the spend `UPDATE` transactions to the Database | |
| """ | |
| from litellm.proxy.utils import ( | |
| ProxyUpdateSpend, | |
| _raise_failed_update_spend_exception, | |
| ) | |
| ### UPDATE USER TABLE ### | |
| user_list_transactions = db_spend_update_transactions["user_list_transactions"] | |
| verbose_proxy_logger.debug( | |
| "User Spend transactions: {}".format(user_list_transactions) | |
| ) | |
| if ( | |
| user_list_transactions is not None | |
| and len(user_list_transactions.keys()) > 0 | |
| ): | |
| for i in range(n_retry_times + 1): | |
| start_time = time.time() | |
| try: | |
| async with prisma_client.db.tx( | |
| timeout=timedelta(seconds=60) | |
| ) as transaction: | |
| async with transaction.batch_() as batcher: | |
| for ( | |
| user_id, | |
| response_cost, | |
| ) in user_list_transactions.items(): | |
| batcher.litellm_usertable.update_many( | |
| where={"user_id": user_id}, | |
| data={"spend": {"increment": response_cost}}, | |
| ) | |
| break | |
| except DB_CONNECTION_ERROR_TYPES as e: | |
| if ( | |
| i >= n_retry_times | |
| ): # If we've reached the maximum number of retries | |
| _raise_failed_update_spend_exception( | |
| e=e, | |
| start_time=start_time, | |
| proxy_logging_obj=proxy_logging_obj, | |
| ) | |
| # Optionally, sleep for a bit before retrying | |
| await asyncio.sleep(2**i) # Exponential backoff | |
| except Exception as e: | |
| _raise_failed_update_spend_exception( | |
| e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj | |
| ) | |
| ### UPDATE END-USER TABLE ### | |
| end_user_list_transactions = db_spend_update_transactions[ | |
| "end_user_list_transactions" | |
| ] | |
| verbose_proxy_logger.debug( | |
| "End-User Spend transactions: {}".format(end_user_list_transactions) | |
| ) | |
| if ( | |
| end_user_list_transactions is not None | |
| and len(end_user_list_transactions.keys()) > 0 | |
| ): | |
| await ProxyUpdateSpend.update_end_user_spend( | |
| n_retry_times=n_retry_times, | |
| prisma_client=prisma_client, | |
| proxy_logging_obj=proxy_logging_obj, | |
| end_user_list_transactions=end_user_list_transactions, | |
| ) | |
| ### UPDATE KEY TABLE ### | |
| key_list_transactions = db_spend_update_transactions["key_list_transactions"] | |
| verbose_proxy_logger.debug( | |
| "KEY Spend transactions: {}".format(key_list_transactions) | |
| ) | |
| if key_list_transactions is not None and len(key_list_transactions.keys()) > 0: | |
| for i in range(n_retry_times + 1): | |
| start_time = time.time() | |
| try: | |
| async with prisma_client.db.tx( | |
| timeout=timedelta(seconds=60) | |
| ) as transaction: | |
| async with transaction.batch_() as batcher: | |
| for ( | |
| token, | |
| response_cost, | |
| ) in key_list_transactions.items(): | |
| batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists | |
| where={"token": token}, | |
| data={"spend": {"increment": response_cost}}, | |
| ) | |
| break | |
| except DB_CONNECTION_ERROR_TYPES as e: | |
| if ( | |
| i >= n_retry_times | |
| ): # If we've reached the maximum number of retries | |
| _raise_failed_update_spend_exception( | |
| e=e, | |
| start_time=start_time, | |
| proxy_logging_obj=proxy_logging_obj, | |
| ) | |
| # Optionally, sleep for a bit before retrying | |
| await asyncio.sleep(2**i) # Exponential backoff | |
| except Exception as e: | |
| _raise_failed_update_spend_exception( | |
| e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj | |
| ) | |
| ### UPDATE TEAM TABLE ### | |
| team_list_transactions = db_spend_update_transactions["team_list_transactions"] | |
| verbose_proxy_logger.debug( | |
| "Team Spend transactions: {}".format(team_list_transactions) | |
| ) | |
| if ( | |
| team_list_transactions is not None | |
| and len(team_list_transactions.keys()) > 0 | |
| ): | |
| for i in range(n_retry_times + 1): | |
| start_time = time.time() | |
| try: | |
| async with prisma_client.db.tx( | |
| timeout=timedelta(seconds=60) | |
| ) as transaction: | |
| async with transaction.batch_() as batcher: | |
| for ( | |
| team_id, | |
| response_cost, | |
| ) in team_list_transactions.items(): | |
| verbose_proxy_logger.debug( | |
| "Updating spend for team id={} by {}".format( | |
| team_id, response_cost | |
| ) | |
| ) | |
| batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists | |
| where={"team_id": team_id}, | |
| data={"spend": {"increment": response_cost}}, | |
| ) | |
| break | |
| except DB_CONNECTION_ERROR_TYPES as e: | |
| if ( | |
| i >= n_retry_times | |
| ): # If we've reached the maximum number of retries | |
| _raise_failed_update_spend_exception( | |
| e=e, | |
| start_time=start_time, | |
| proxy_logging_obj=proxy_logging_obj, | |
| ) | |
| # Optionally, sleep for a bit before retrying | |
| await asyncio.sleep(2**i) # Exponential backoff | |
| except Exception as e: | |
| _raise_failed_update_spend_exception( | |
| e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj | |
| ) | |
| ### UPDATE TEAM Membership TABLE with spend ### | |
| team_member_list_transactions = db_spend_update_transactions[ | |
| "team_member_list_transactions" | |
| ] | |
| verbose_proxy_logger.debug( | |
| "Team Membership Spend transactions: {}".format( | |
| team_member_list_transactions | |
| ) | |
| ) | |
| if ( | |
| team_member_list_transactions is not None | |
| and len(team_member_list_transactions.keys()) > 0 | |
| ): | |
| for i in range(n_retry_times + 1): | |
| start_time = time.time() | |
| try: | |
| async with prisma_client.db.tx( | |
| timeout=timedelta(seconds=60) | |
| ) as transaction: | |
| async with transaction.batch_() as batcher: | |
| for ( | |
| key, | |
| response_cost, | |
| ) in team_member_list_transactions.items(): | |
| # key is "team_id::<value>::user_id::<value>" | |
| team_id = key.split("::")[1] | |
| user_id = key.split("::")[3] | |
| batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists | |
| where={"team_id": team_id, "user_id": user_id}, | |
| data={"spend": {"increment": response_cost}}, | |
| ) | |
| break | |
| except DB_CONNECTION_ERROR_TYPES as e: | |
| if ( | |
| i >= n_retry_times | |
| ): # If we've reached the maximum number of retries | |
| _raise_failed_update_spend_exception( | |
| e=e, | |
| start_time=start_time, | |
| proxy_logging_obj=proxy_logging_obj, | |
| ) | |
| # Optionally, sleep for a bit before retrying | |
| await asyncio.sleep(2**i) # Exponential backoff | |
| except Exception as e: | |
| _raise_failed_update_spend_exception( | |
| e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj | |
| ) | |
| ### UPDATE ORG TABLE ### | |
| org_list_transactions = db_spend_update_transactions["org_list_transactions"] | |
| verbose_proxy_logger.debug( | |
| "Org Spend transactions: {}".format(org_list_transactions) | |
| ) | |
| if org_list_transactions is not None and len(org_list_transactions.keys()) > 0: | |
| for i in range(n_retry_times + 1): | |
| start_time = time.time() | |
| try: | |
| async with prisma_client.db.tx( | |
| timeout=timedelta(seconds=60) | |
| ) as transaction: | |
| async with transaction.batch_() as batcher: | |
| for ( | |
| org_id, | |
| response_cost, | |
| ) in org_list_transactions.items(): | |
| batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists | |
| where={"organization_id": org_id}, | |
| data={"spend": {"increment": response_cost}}, | |
| ) | |
| break | |
| except DB_CONNECTION_ERROR_TYPES as e: | |
| if ( | |
| i >= n_retry_times | |
| ): # If we've reached the maximum number of retries | |
| _raise_failed_update_spend_exception( | |
| e=e, | |
| start_time=start_time, | |
| proxy_logging_obj=proxy_logging_obj, | |
| ) | |
| # Optionally, sleep for a bit before retrying | |
| await asyncio.sleep(2**i) # Exponential backoff | |
| except Exception as e: | |
| _raise_failed_update_spend_exception( | |
| e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj | |
| ) | |
| async def _update_daily_spend( | |
| n_retry_times: int, | |
| prisma_client: PrismaClient, | |
| proxy_logging_obj: ProxyLogging, | |
| daily_spend_transactions: Dict[str, DailyUserSpendTransaction], | |
| entity_type: Literal["user"], | |
| entity_id_field: str, | |
| table_name: str, | |
| unique_constraint_name: str, | |
| ) -> None: | |
| ... | |
| async def _update_daily_spend( | |
| n_retry_times: int, | |
| prisma_client: PrismaClient, | |
| proxy_logging_obj: ProxyLogging, | |
| daily_spend_transactions: Dict[str, DailyTeamSpendTransaction], | |
| entity_type: Literal["team"], | |
| entity_id_field: str, | |
| table_name: str, | |
| unique_constraint_name: str, | |
| ) -> None: | |
| ... | |
| async def _update_daily_spend( | |
| n_retry_times: int, | |
| prisma_client: PrismaClient, | |
| proxy_logging_obj: ProxyLogging, | |
| daily_spend_transactions: Dict[str, DailyTagSpendTransaction], | |
| entity_type: Literal["tag"], | |
| entity_id_field: str, | |
| table_name: str, | |
| unique_constraint_name: str, | |
| ) -> None: | |
| ... | |
| async def _update_daily_spend( | |
| n_retry_times: int, | |
| prisma_client: PrismaClient, | |
| proxy_logging_obj: ProxyLogging, | |
| daily_spend_transactions: Union[ | |
| Dict[str, DailyUserSpendTransaction], | |
| Dict[str, DailyTeamSpendTransaction], | |
| Dict[str, DailyTagSpendTransaction], | |
| ], | |
| entity_type: Literal["user", "team", "tag"], | |
| entity_id_field: str, | |
| table_name: str, | |
| unique_constraint_name: str, | |
| ) -> None: | |
| """ | |
| Generic function to update daily spend for any entity type (user, team, tag) | |
| """ | |
| from litellm.proxy.utils import _raise_failed_update_spend_exception | |
| verbose_proxy_logger.debug( | |
| f"Daily {entity_type.capitalize()} Spend transactions: {len(daily_spend_transactions)}" | |
| ) | |
| BATCH_SIZE = 100 | |
| start_time = time.time() | |
| try: | |
| for i in range(n_retry_times + 1): | |
| try: | |
| transactions_to_process = dict( | |
| list(daily_spend_transactions.items())[:BATCH_SIZE] | |
| ) | |
| if len(transactions_to_process) == 0: | |
| verbose_proxy_logger.debug( | |
| f"No new transactions to process for daily {entity_type} spend update" | |
| ) | |
| break | |
| async with prisma_client.db.batch_() as batcher: | |
| for _, transaction in transactions_to_process.items(): | |
| entity_id = transaction.get(entity_id_field) | |
| if not entity_id: | |
| continue | |
| # Construct the where clause dynamically | |
| where_clause = { | |
| unique_constraint_name: { | |
| entity_id_field: entity_id, | |
| "date": transaction["date"], | |
| "api_key": transaction["api_key"], | |
| "model": transaction["model"], | |
| "custom_llm_provider": transaction.get( | |
| "custom_llm_provider" | |
| ), | |
| } | |
| } | |
| # Get the table dynamically | |
| table = getattr(batcher, table_name) | |
| # Common data structure for both create and update | |
| common_data = { | |
| entity_id_field: entity_id, | |
| "date": transaction["date"], | |
| "api_key": transaction["api_key"], | |
| "model": transaction["model"], | |
| "model_group": transaction.get("model_group"), | |
| "custom_llm_provider": transaction.get( | |
| "custom_llm_provider" | |
| ), | |
| "prompt_tokens": transaction["prompt_tokens"], | |
| "completion_tokens": transaction["completion_tokens"], | |
| "spend": transaction["spend"], | |
| "api_requests": transaction["api_requests"], | |
| "successful_requests": transaction[ | |
| "successful_requests" | |
| ], | |
| "failed_requests": transaction["failed_requests"], | |
| } | |
| # Add cache-related fields if they exist | |
| if "cache_read_input_tokens" in transaction: | |
| common_data[ | |
| "cache_read_input_tokens" | |
| ] = transaction.get("cache_read_input_tokens", 0) | |
| if "cache_creation_input_tokens" in transaction: | |
| common_data[ | |
| "cache_creation_input_tokens" | |
| ] = transaction.get("cache_creation_input_tokens", 0) | |
| # Create update data structure | |
| update_data = { | |
| "prompt_tokens": { | |
| "increment": transaction["prompt_tokens"] | |
| }, | |
| "completion_tokens": { | |
| "increment": transaction["completion_tokens"] | |
| }, | |
| "spend": {"increment": transaction["spend"]}, | |
| "api_requests": { | |
| "increment": transaction["api_requests"] | |
| }, | |
| "successful_requests": { | |
| "increment": transaction["successful_requests"] | |
| }, | |
| "failed_requests": { | |
| "increment": transaction["failed_requests"] | |
| }, | |
| } | |
| # Add cache-related fields to update if they exist | |
| if "cache_read_input_tokens" in transaction: | |
| update_data["cache_read_input_tokens"] = { | |
| "increment": transaction.get( | |
| "cache_read_input_tokens", 0 | |
| ) | |
| } | |
| if "cache_creation_input_tokens" in transaction: | |
| update_data["cache_creation_input_tokens"] = { | |
| "increment": transaction.get( | |
| "cache_creation_input_tokens", 0 | |
| ) | |
| } | |
| table.upsert( | |
| where=where_clause, | |
| data={ | |
| "create": common_data, | |
| "update": update_data, | |
| }, | |
| ) | |
| verbose_proxy_logger.info( | |
| f"Processed {len(transactions_to_process)} daily {entity_type} transactions in {time.time() - start_time:.2f}s" | |
| ) | |
| # Remove processed transactions | |
| for key in transactions_to_process.keys(): | |
| daily_spend_transactions.pop(key, None) | |
| break | |
| except DB_CONNECTION_ERROR_TYPES as e: | |
| if i >= n_retry_times: | |
| _raise_failed_update_spend_exception( | |
| e=e, | |
| start_time=start_time, | |
| proxy_logging_obj=proxy_logging_obj, | |
| ) | |
| await asyncio.sleep(2**i) | |
| except Exception as e: | |
| if "transactions_to_process" in locals(): | |
| for key in transactions_to_process.keys(): # type: ignore | |
| daily_spend_transactions.pop(key, None) | |
| _raise_failed_update_spend_exception( | |
| e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj | |
| ) | |
| async def update_daily_user_spend( | |
| n_retry_times: int, | |
| prisma_client: PrismaClient, | |
| proxy_logging_obj: ProxyLogging, | |
| daily_spend_transactions: Dict[str, DailyUserSpendTransaction], | |
| ): | |
| """ | |
| Batch job to update LiteLLM_DailyUserSpend table using in-memory daily_spend_transactions | |
| """ | |
| await DBSpendUpdateWriter._update_daily_spend( | |
| n_retry_times=n_retry_times, | |
| prisma_client=prisma_client, | |
| proxy_logging_obj=proxy_logging_obj, | |
| daily_spend_transactions=daily_spend_transactions, | |
| entity_type="user", | |
| entity_id_field="user_id", | |
| table_name="litellm_dailyuserspend", | |
| unique_constraint_name="user_id_date_api_key_model_custom_llm_provider", | |
| ) | |
| async def update_daily_team_spend( | |
| n_retry_times: int, | |
| prisma_client: PrismaClient, | |
| proxy_logging_obj: ProxyLogging, | |
| daily_spend_transactions: Dict[str, DailyTeamSpendTransaction], | |
| ): | |
| """ | |
| Batch job to update LiteLLM_DailyTeamSpend table using in-memory daily_spend_transactions | |
| """ | |
| await DBSpendUpdateWriter._update_daily_spend( | |
| n_retry_times=n_retry_times, | |
| prisma_client=prisma_client, | |
| proxy_logging_obj=proxy_logging_obj, | |
| daily_spend_transactions=daily_spend_transactions, | |
| entity_type="team", | |
| entity_id_field="team_id", | |
| table_name="litellm_dailyteamspend", | |
| unique_constraint_name="team_id_date_api_key_model_custom_llm_provider", | |
| ) | |
| async def update_daily_tag_spend( | |
| n_retry_times: int, | |
| prisma_client: PrismaClient, | |
| proxy_logging_obj: ProxyLogging, | |
| daily_spend_transactions: Dict[str, DailyTagSpendTransaction], | |
| ): | |
| """ | |
| Batch job to update LiteLLM_DailyTagSpend table using in-memory daily_spend_transactions | |
| """ | |
| await DBSpendUpdateWriter._update_daily_spend( | |
| n_retry_times=n_retry_times, | |
| prisma_client=prisma_client, | |
| proxy_logging_obj=proxy_logging_obj, | |
| daily_spend_transactions=daily_spend_transactions, | |
| entity_type="tag", | |
| entity_id_field="tag", | |
| table_name="litellm_dailytagspend", | |
| unique_constraint_name="tag_date_api_key_model_custom_llm_provider", | |
| ) | |
| async def _common_add_spend_log_transaction_to_daily_transaction( | |
| self, | |
| payload: Union[dict, SpendLogsPayload], | |
| prisma_client: PrismaClient, | |
| type: Literal["user", "team", "request_tags"] = "user", | |
| ) -> Optional[BaseDailySpendTransaction]: | |
| common_expected_keys = ["startTime", "api_key", "model", "custom_llm_provider"] | |
| if type == "user": | |
| expected_keys = ["user", *common_expected_keys] | |
| elif type == "team": | |
| expected_keys = ["team_id", *common_expected_keys] | |
| elif type == "request_tags": | |
| expected_keys = ["request_tags", *common_expected_keys] | |
| else: | |
| raise ValueError(f"Invalid type: {type}") | |
| if not all(key in payload for key in expected_keys): | |
| verbose_proxy_logger.debug( | |
| f"Missing expected keys: {expected_keys}, in payload, skipping from daily_user_spend_transactions" | |
| ) | |
| return None | |
| request_status = prisma_client.get_request_status(payload) | |
| verbose_proxy_logger.info(f"Logged request status: {request_status}") | |
| _metadata: SpendLogsMetadata = json.loads(payload["metadata"]) | |
| usage_obj = _metadata.get("usage_object", {}) or {} | |
| if isinstance(payload["startTime"], datetime): | |
| start_time = payload["startTime"].isoformat() | |
| date = start_time.split("T")[0] | |
| elif isinstance(payload["startTime"], str): | |
| date = payload["startTime"].split("T")[0] | |
| else: | |
| verbose_proxy_logger.debug( | |
| f"Invalid start time: {payload['startTime']}, skipping from daily_user_spend_transactions" | |
| ) | |
| return None | |
| try: | |
| daily_transaction = BaseDailySpendTransaction( | |
| date=date, | |
| api_key=payload["api_key"], | |
| model=payload["model"], | |
| model_group=payload["model_group"], | |
| custom_llm_provider=payload["custom_llm_provider"], | |
| prompt_tokens=payload["prompt_tokens"], | |
| completion_tokens=payload["completion_tokens"], | |
| spend=payload["spend"], | |
| api_requests=1, | |
| successful_requests=1 if request_status == "success" else 0, | |
| failed_requests=1 if request_status != "success" else 0, | |
| cache_read_input_tokens=usage_obj.get("cache_read_input_tokens", 0) | |
| or 0, | |
| cache_creation_input_tokens=usage_obj.get( | |
| "cache_creation_input_tokens", 0 | |
| ) | |
| or 0, | |
| ) | |
| return daily_transaction | |
| except Exception as e: | |
| raise e | |
| async def add_spend_log_transaction_to_daily_user_transaction( | |
| self, | |
| payload: Union[dict, SpendLogsPayload], | |
| prisma_client: Optional[PrismaClient] = None, | |
| ): | |
| """ | |
| Add a spend log transaction to the `daily_spend_update_queue` | |
| Key = @@unique([user_id, date, api_key, model, custom_llm_provider]) ) | |
| If key exists, update the transaction with the new spend and usage | |
| """ | |
| if prisma_client is None: | |
| verbose_proxy_logger.debug( | |
| "prisma_client is None. Skipping writing spend logs to db." | |
| ) | |
| return | |
| base_daily_transaction = ( | |
| await self._common_add_spend_log_transaction_to_daily_transaction( | |
| payload, prisma_client, "user" | |
| ) | |
| ) | |
| if base_daily_transaction is None: | |
| return | |
| daily_transaction_key = f"{payload['user']}_{base_daily_transaction['date']}_{payload['api_key']}_{payload['model']}_{payload['custom_llm_provider']}" | |
| daily_transaction = DailyUserSpendTransaction( | |
| user_id=payload["user"], **base_daily_transaction | |
| ) | |
| await self.daily_spend_update_queue.add_update( | |
| update={daily_transaction_key: daily_transaction} | |
| ) | |
| async def add_spend_log_transaction_to_daily_team_transaction( | |
| self, | |
| payload: SpendLogsPayload, | |
| prisma_client: Optional[PrismaClient] = None, | |
| ) -> None: | |
| if prisma_client is None: | |
| verbose_proxy_logger.debug( | |
| "prisma_client is None. Skipping writing spend logs to db." | |
| ) | |
| return | |
| base_daily_transaction = ( | |
| await self._common_add_spend_log_transaction_to_daily_transaction( | |
| payload, prisma_client, "team" | |
| ) | |
| ) | |
| if base_daily_transaction is None: | |
| return | |
| if payload["team_id"] is None: | |
| verbose_proxy_logger.debug( | |
| "team_id is None for request. Skipping incrementing team spend." | |
| ) | |
| return | |
| daily_transaction_key = f"{payload['team_id']}_{base_daily_transaction['date']}_{payload['api_key']}_{payload['model']}_{payload['custom_llm_provider']}" | |
| daily_transaction = DailyTeamSpendTransaction( | |
| team_id=payload["team_id"], **base_daily_transaction | |
| ) | |
| await self.daily_team_spend_update_queue.add_update( | |
| update={daily_transaction_key: daily_transaction} | |
| ) | |
| async def add_spend_log_transaction_to_daily_tag_transaction( | |
| self, | |
| payload: SpendLogsPayload, | |
| prisma_client: Optional[PrismaClient] = None, | |
| ) -> None: | |
| if prisma_client is None: | |
| verbose_proxy_logger.debug( | |
| "prisma_client is None. Skipping writing spend logs to db." | |
| ) | |
| return | |
| base_daily_transaction = ( | |
| await self._common_add_spend_log_transaction_to_daily_transaction( | |
| payload, prisma_client, "request_tags" | |
| ) | |
| ) | |
| if base_daily_transaction is None: | |
| return | |
| if payload["request_tags"] is None: | |
| verbose_proxy_logger.debug( | |
| "request_tags is None for request. Skipping incrementing tag spend." | |
| ) | |
| return | |
| request_tags = [] | |
| if isinstance(payload["request_tags"], str): | |
| request_tags = json.loads(payload["request_tags"]) | |
| elif isinstance(payload["request_tags"], list): | |
| request_tags = payload["request_tags"] | |
| else: | |
| raise ValueError(f"Invalid request_tags: {payload['request_tags']}") | |
| for tag in request_tags: | |
| daily_transaction_key = f"{tag}_{base_daily_transaction['date']}_{payload['api_key']}_{payload['model']}_{payload['custom_llm_provider']}" | |
| daily_transaction = DailyTagSpendTransaction( | |
| tag=tag, **base_daily_transaction | |
| ) | |
| await self.daily_tag_spend_update_queue.add_update( | |
| update={daily_transaction_key: daily_transaction} | |
| ) | |