Spaces:
Paused
Paused
| import asyncio | |
| from typing import TYPE_CHECKING, Any | |
| from litellm.utils import calculate_max_parallel_requests | |
| if TYPE_CHECKING: | |
| from litellm.router import Router as _Router | |
| LitellmRouter = _Router | |
| else: | |
| LitellmRouter = Any | |
| class InitalizeCachedClient: | |
| def set_max_parallel_requests_client( | |
| litellm_router_instance: LitellmRouter, model: dict | |
| ): | |
| litellm_params = model.get("litellm_params", {}) | |
| model_id = model["model_info"]["id"] | |
| rpm = litellm_params.get("rpm", None) | |
| tpm = litellm_params.get("tpm", None) | |
| max_parallel_requests = litellm_params.get("max_parallel_requests", None) | |
| calculated_max_parallel_requests = calculate_max_parallel_requests( | |
| rpm=rpm, | |
| max_parallel_requests=max_parallel_requests, | |
| tpm=tpm, | |
| default_max_parallel_requests=litellm_router_instance.default_max_parallel_requests, | |
| ) | |
| if calculated_max_parallel_requests: | |
| semaphore = asyncio.Semaphore(calculated_max_parallel_requests) | |
| cache_key = f"{model_id}_max_parallel_requests_client" | |
| litellm_router_instance.cache.set_cache( | |
| key=cache_key, | |
| value=semaphore, | |
| local_only=True, | |
| ) | |