Spaces:
Paused
Paused
| # This file runs a health check for the LLM, used on litellm/proxy | |
| import asyncio | |
| import logging | |
| import random | |
| from typing import List, Optional | |
| import litellm | |
| logger = logging.getLogger(__name__) | |
| from litellm.constants import HEALTH_CHECK_TIMEOUT_SECONDS | |
# Sensitive/verbose params that must never be echoed back to users in
# health-check responses.
ILLEGAL_DISPLAY_PARAMS = [
    "messages",
    "api_key",
    "prompt",
    "input",
    "vertex_credentials",
    "aws_access_key_id",
    "aws_secret_access_key",
]

# The only params shown when detailed output is disabled.
MINIMAL_DISPLAY_PARAMS = ["model", "mode_error"]
| def _get_random_llm_message(): | |
| """ | |
| Get a random message from the LLM. | |
| """ | |
| messages = ["Hey how's it going?", "What's 1 + 1?"] | |
| return [{"role": "user", "content": random.choice(messages)}] | |
def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
    """Strip sensitive params from an endpoint dict before showing it to users.

    With details enabled (any value except False), only the sensitive
    ILLEGAL_DISPLAY_PARAMS keys are removed; when details is False,
    everything except MINIMAL_DISPLAY_PARAMS is dropped.
    """
    # internal logging object is never user-facing
    endpoint_data.pop("litellm_logging_obj", None)
    if details is False:
        return {k: v for k, v in endpoint_data.items() if k in MINIMAL_DISPLAY_PARAMS}
    return {k: v for k, v in endpoint_data.items() if k not in ILLEGAL_DISPLAY_PARAMS}
def filter_deployments_by_id(
    model_list: List,
) -> List:
    """Drop duplicate deployments, keeping the first occurrence of each id.

    Deployments whose `model_info` carries no truthy `id` are skipped
    entirely.
    """
    seen: set = set()
    unique: List = []
    for deployment in model_list:
        info = deployment.get("model_info") or {}
        deployment_id = info.get("id")
        if not deployment_id:
            # no usable id -> cannot dedupe, so it is excluded
            continue
        if deployment_id in seen:
            continue
        seen.add(deployment_id)
        unique.append(deployment)
    return unique
async def run_with_timeout(task, timeout):
    """Await *task* with a timeout, returning an error dict instead of raising.

    Args:
        task: A coroutine or Future to await.
        timeout: Seconds to wait before giving up.

    Returns:
        The awaited result, or {"error": "Timeout exceeded"} on timeout.
    """
    # Wrap in a real Task so .cancel() exists even when a bare coroutine is
    # passed in — coroutine objects have no cancel(), so the original code
    # raised AttributeError on the timeout path.
    task = asyncio.ensure_future(task)
    try:
        return await asyncio.wait_for(task, timeout)
    except asyncio.TimeoutError:
        task.cancel()
        # NOTE(review): this cancels every *other* task in the running loop,
        # not just children of the current task — aggressive, but preserved
        # as existing behavior.
        current_task = asyncio.current_task()
        for t in asyncio.all_tasks():
            if t is not current_task:
                t.cancel()
        try:
            await asyncio.wait_for(task, 0.1)  # give 100ms for cleanup
        # CancelledError is a BaseException (3.8+), so it must be listed
        # alongside Exception; asyncio.TimeoutError is already an Exception.
        except (asyncio.CancelledError, Exception):
            pass
        return {"error": "Timeout exceeded"}
async def _perform_health_check(model_list: list, details: Optional[bool] = True):
    """Health-check every deployment concurrently and bucket the results.

    Returns:
        (healthy_endpoints, unhealthy_endpoints): two lists of cleaned
        endpoint dicts.
    """
    tasks = []
    for deployment in model_list:
        params = deployment["litellm_params"]
        info = deployment.get("model_info", {})
        # injects health-check messages / model override in place
        params = _update_litellm_params_for_health_check(info, params)
        per_model_timeout = (
            info.get("health_check_timeout") or HEALTH_CHECK_TIMEOUT_SECONDS
        )
        check_coro = litellm.ahealth_check(
            deployment["litellm_params"],
            mode=info.get("mode", None),
            prompt="test from litellm",
            input=["test from litellm"],
        )
        tasks.append(run_with_timeout(check_coro, per_model_timeout))

    results = await asyncio.gather(*tasks, return_exceptions=True)

    healthy_endpoints = []
    unhealthy_endpoints = []
    for outcome, deployment in zip(results, model_list):
        params = deployment["litellm_params"]
        if not isinstance(outcome, dict):
            # gather returned an exception object -> unhealthy
            unhealthy_endpoints.append(_clean_endpoint_data(params, details))
        elif "error" in outcome:
            unhealthy_endpoints.append(
                _clean_endpoint_data({**params, **outcome}, details)
            )
        else:
            healthy_endpoints.append(
                _clean_endpoint_data({**params, **outcome}, details)
            )
    return healthy_endpoints, unhealthy_endpoints
def _update_litellm_params_for_health_check(
    model_info: dict, litellm_params: dict
) -> dict:
    """Prepare litellm params for a health-check call (mutates in place).

    - injects a short random `messages` payload
    - swaps in `health_check_model` as the `model` when one is configured
      Doc: https://docs.litellm.ai/docs/proxy/health#wildcard-routes
    """
    litellm_params["messages"] = _get_random_llm_message()
    override_model = model_info.get("health_check_model", None)
    if override_model is not None:
        litellm_params["model"] = override_model
    return litellm_params
async def perform_health_check(
    model_list: list,
    model: Optional[str] = None,
    cli_model: Optional[str] = None,
    details: Optional[bool] = True,
):
    """
    Perform a health check on the proxy's deployments.

    Args:
        model_list: Router deployments to check.
        model: Optional filter — matched first against the underlying
            litellm `model` param, then against the public `model_name`.
        cli_model: Fallback single model to check when model_list is empty.
        details: When False, results are trimmed to minimal display params.

    Returns:
        tuple[list, list]: (healthy_endpoints, unhealthy_endpoints) lists of
        cleaned endpoint dicts. (The previous docstring's claim of a bool
        return was wrong.)
    """
    if not model_list:
        if cli_model:
            model_list = [
                {"model_name": cli_model, "litellm_params": {"model": cli_model}}
            ]
        else:
            return [], []
    if model is not None:
        # Prefer the underlying litellm model id; fall back to the alias.
        _new_model_list = [
            x for x in model_list if x["litellm_params"]["model"] == model
        ]
        if _new_model_list == []:
            _new_model_list = [x for x in model_list if x["model_name"] == model]
        model_list = _new_model_list
    model_list = filter_deployments_by_id(
        model_list=model_list
    )  # filter duplicate deployments (e.g. when model alias'es are used)
    healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
        model_list, details
    )
    return healthy_endpoints, unhealthy_endpoints