Spaces:
Paused
Paused
| import asyncio | |
| from typing import Dict, List, Optional | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.proxy._types import ( | |
| DBSpendUpdateTransactions, | |
| Litellm_EntityType, | |
| SpendUpdateQueueItem, | |
| ) | |
| from litellm.proxy.db.db_transaction_queue.base_update_queue import ( | |
| BaseUpdateQueue, | |
| service_logger_obj, | |
| ) | |
| from litellm.types.services import ServiceTypes | |
class SpendUpdateQueue(BaseUpdateQueue):
    """
    In memory buffer for spend updates that should be committed to the database
    """

    def __init__(self):
        super().__init__()
        self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue()
        # Strong references to in-flight fire-and-forget tasks. The event loop
        # only keeps weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._pending_event_tasks: set = set()

    async def flush_and_get_aggregated_db_spend_update_transactions(
        self,
    ) -> DBSpendUpdateTransactions:
        """Flush all updates from the queue and return all updates aggregated by entity type."""
        updates = await self.flush_all_updates_from_in_memory_queue()
        verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
        return self.get_aggregated_db_spend_update_transactions(updates)

    async def add_update(self, update: SpendUpdateQueueItem):
        """
        Enqueue an update to the spend update queue.

        After enqueueing, if the queue size has reached
        `self.MAX_SIZE_IN_MEMORY_QUEUE`, the queued entries are compacted via
        `aggregate_queue_updates`.
        """
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)

        # if the queue is full, aggregate the updates
        if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
            verbose_proxy_logger.warning(
                "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
            )
            await self.aggregate_queue_updates()

    async def aggregate_queue_updates(self):
        """Concatenate all updates in the queue to reduce the size of in-memory queue"""
        updates: List[
            SpendUpdateQueueItem
        ] = await self.flush_all_updates_from_in_memory_queue()
        # Re-enqueue one combined entry per entity_type + entity_id.
        for aggregated_update in self._get_aggregated_spend_update_queue_item(updates):
            await self.update_queue.put(aggregated_update)

    def _get_aggregated_spend_update_queue_item(
        self, updates: List[SpendUpdateQueueItem]
    ) -> List[SpendUpdateQueueItem]:
        """
        This is used to reduce the size of the in-memory queue by aggregating updates by entity type + id

        Aggregate updates by entity type + id. The items in `updates` are not
        modified; the returned items are new (shallow-copied) objects.

        eg.

        ```
        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 100
            },
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 200
            }
        ]
        ```

        becomes

        ```
        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 300
            }
        ]
        ```

        Args:
            updates: The spend updates to combine.

        Returns:
            One item per distinct entity_type + entity_id, in first-seen order,
            whose `response_cost` is the sum over that group (a missing or
            None cost counts as 0).
        """
        verbose_proxy_logger.debug(
            "Aggregating spend updates, current queue size: %s",
            self.update_queue.qsize(),
        )

        # Used for combining several updates into a single update
        # Key=entity_type:entity_id
        # Value=SpendUpdateQueueItem
        _in_memory_map: Dict[str, SpendUpdateQueueItem] = {}

        for update in updates:
            _key = f"{update.get('entity_type')}:{update.get('entity_id')}"
            existing = _in_memory_map.get(_key)
            if existing is None:
                # Copy so that accumulating costs below never mutates the
                # caller's original update object.
                _in_memory_map[_key] = update.copy()
                continue
            current_cost = existing.get("response_cost", 0) or 0
            update_cost = update.get("response_cost", 0) or 0
            existing["response_cost"] = current_cost + update_cost

        aggregated_spend_updates: List[SpendUpdateQueueItem] = list(
            _in_memory_map.values()
        )
        verbose_proxy_logger.debug(
            "Aggregated spend updates: %s", aggregated_spend_updates
        )
        return aggregated_spend_updates

    def get_aggregated_db_spend_update_transactions(
        self, updates: List[SpendUpdateQueueItem]
    ) -> DBSpendUpdateTransactions:
        """
        Aggregate updates by entity type.

        Args:
            updates: The spend updates to aggregate.

        Returns:
            A `DBSpendUpdateTransactions` with one `{entity_id: total_cost}`
            dict per supported entity type. Updates whose `entity_type` is
            None or not one of the supported types are skipped (logged at
            debug level). A missing `entity_id` is recorded under `""` and a
            missing or None `response_cost` counts as 0.
        """
        # One accumulator per entity type; these exact dict objects are placed
        # in the returned DBSpendUpdateTransactions.
        user_transactions: Dict[str, float] = {}
        end_user_transactions: Dict[str, float] = {}
        key_transactions: Dict[str, float] = {}
        team_transactions: Dict[str, float] = {}
        team_member_transactions: Dict[str, float] = {}
        org_transactions: Dict[str, float] = {}

        # Map entity types directly to their accumulator, so no string key has
        # to be re-resolved against the result object.
        entity_type_to_transactions = {
            Litellm_EntityType.USER: user_transactions,
            Litellm_EntityType.END_USER: end_user_transactions,
            Litellm_EntityType.KEY: key_transactions,
            Litellm_EntityType.TEAM: team_transactions,
            Litellm_EntityType.TEAM_MEMBER: team_member_transactions,
            Litellm_EntityType.ORGANIZATION: org_transactions,
        }

        for update in updates:
            entity_type = update.get("entity_type")
            if entity_type is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is None",
                    update,
                )
                continue

            transactions = entity_type_to_transactions.get(entity_type)
            if transactions is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
                    update,
                )
                continue  # Skip unknown entity types

            entity_id = update.get("entity_id") or ""
            response_cost = update.get("response_cost") or 0
            transactions[entity_id] = transactions.get(entity_id, 0) + response_cost

        return DBSpendUpdateTransactions(
            user_list_transactions=user_transactions,
            end_user_list_transactions=end_user_transactions,
            key_list_transactions=key_transactions,
            team_list_transactions=team_transactions,
            team_member_list_transactions=team_member_transactions,
            org_list_transactions=org_transactions,
        )

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        """
        Schedule a service-success event reporting `queue_size` as a gauge value.

        The hook runs as a background task and is not awaited here.

        Args:
            queue_size: Value passed as `gauge_value` in the event metadata.
        """
        task = asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                    "gauge_value": queue_size,
                },
            )
        )
        # Hold a strong reference until the task completes so it cannot be
        # garbage-collected mid-execution; drop it once done.
        self._pending_event_tasks.add(task)
        task.add_done_callback(self._pending_event_tasks.discard)