| | |
| | import logging |
| | from constant import Constants |
| | import requests |
| | from dataclasses import dataclass |
| | from typing import Dict, List, Optional, Union |
| |
|
| | from collections import OrderedDict |
| |
|
# Module-level logger.
# NOTE(review): the logger name "Reranker" looks copy-pasted from another
# module — this file implements the HiveGPT Model Router client (and
# ModelRouter builds its own "HiveGPT Model Router" logger). Confirm the
# intended name before changing it, since log routing/filters may rely on it.
logger = logging.getLogger("Reranker")
| |
|
| |
|
@dataclass
class Model:
    """Base record describing a model served by the HiveGPT Model Router.

    Attributes:
        name (str): HuggingFace repository path of the model,
            e.g. "meta-llama/Meta-Llama-3.1-8B".
        alias (str): Short, user-friendly identifier for the model.
        openai_endpoint (str): Base OpenAI-compatible endpoint through which
            the model can be accessed.
    """

    name: str
    alias: str
    openai_endpoint: str
| |
|
| |
|
@dataclass
class LLMModel(Model):
    """A large language model served by the HiveGPT Model Router.

    Extends :class:`Model` with the model's maximum context length.

    Attributes:
        name (str): HuggingFace repository path of the model,
            e.g. "meta-llama/Meta-Llama-3.1-8B".
        alias (str): Short, user-friendly identifier for the model.
        openai_endpoint (str): Base OpenAI-compatible endpoint through which
            the model can be accessed.
        max_len (int): Maximum sequence length the model can handle.
    """

    max_len: int
| |
|
| |
|
| |
|
| |
|
| |
|
class ModelRouter:
    """
    A wrapper class that fetches info from the HiveGPT Model Router.

    Maintains a mapping from model name (HuggingFace repo path) to LLMModel,
    refreshed from the router's ``/v1/models`` endpoint.
    """

    def __init__(
        self,
        host: str = Constants.MODEL_ROUTER_HOST,
        port: str = Constants.MODEL_ROUTER_PORT,
        request_timeout: float = 10.0,
    ):
        """
        Initializes the ModelRouter.

        Args:
            host (str): The hostname of the Model Router server.
            port (str): The port number of the Model Router server.
            request_timeout (float): Timeout in seconds for HTTP requests to
                the router. Without it, refresh() could block indefinitely
                when the router is unreachable.

        Note: The ModelRouter will automatically refresh the map of served models upon initialization.
        """
        self.host = host
        self.port = port
        self.request_timeout = request_timeout
        self.models_health_endpoint = f"http://{self.host}:{self.port}/v1/models"
        self.served_models: Dict[str, LLMModel] = {}
        self.logger = logging.getLogger("HiveGPT Model Router")
        # Populate the model map immediately so the instance is usable on return.
        self.refresh()

    def _generate_openai_base(self, alias: str, base_endpoint: str = "/v1") -> str:
        """
        Generates the base OpenAI endpoint URL for a given alias.

        Args:
            alias (str): The alias of the model.
            base_endpoint (str): The base endpoint for the OpenAI API.

        Returns:
            str: The base OpenAI endpoint URL for the given alias.
        """
        return f"http://{self.host}:{self.port}/{alias}{base_endpoint}"

    def _sort_language_models(self):
        """
        Sort served models by key in ascending order, always placing the
        default LLM (Constants.DEFAULT_LLM_NAME) first when it is served.
        """
        default_key = Constants.DEFAULT_LLM_NAME

        ordered: "OrderedDict[str, LLMModel]" = OrderedDict()
        # Default model goes on top, if present.
        if default_key in self.served_models:
            ordered[default_key] = self.served_models[default_key]
        # Remaining models follow in ascending key order.
        for key in sorted(k for k in self.served_models if k != default_key):
            ordered[key] = self.served_models[key]

        self.served_models = ordered

    def refresh(self):
        """Refreshes the map of served models.

        On request failure or a malformed router response the map is cleared
        (rather than left stale) and the error is logged.
        """
        try:
            # Always pass a timeout: requests.get has none by default and
            # would otherwise hang forever if the router is down.
            response = requests.get(self.models_health_endpoint, timeout=self.request_timeout)
            response.raise_for_status()
            models_json = response.json()

            models: Dict[str, LLMModel] = {}
            for model in models_json:
                alias = model["model_alias"]
                name = model["model_name"]
                max_len = model["max_model_len"]
                openai_endpoint = self._generate_openai_base(alias=alias)
                models[name] = LLMModel(name=name, alias=alias, openai_endpoint=openai_endpoint, max_len=max_len)

            self.served_models = models
            self._sort_language_models()
            self.logger.info("Models map successfully refreshed.")

        except (requests.RequestException, KeyError, TypeError, ValueError) as e:
            # KeyError/TypeError/ValueError cover a malformed JSON payload
            # (missing fields, non-list body); previously these escaped and
            # left the instance half-initialized.
            self.logger.error(f"Failed to refresh models map: {e}")
            # Clear rather than keep a possibly-stale map.
            self.served_models = {}

    def get_llm_model(self, name: str) -> Optional[LLMModel]:
        """Gets the LLMModel object for the specified model name.

        Args:
            name (str): The HuggingFace repository path for the model. for example, "meta-llama/Meta-Llama-3.1-8B"

        Returns:
            Optional[LLMModel]: The LLMModel object.
                Returns None if the model name is not found.
        """
        return self.served_models.get(name)

    def get_all_llm_models(self) -> Dict[str, LLMModel]:
        """Returns a map of all served LLMs.

        Returns:
            Dict[str, LLMModel]: A dictionary where keys are LLM names and values are LLMModel objects,
                sorted with the default LLM first.
        """
        self._sort_language_models()
        return self.served_models