# -*- coding: utf-8 -*-
import logging
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import requests

from constant import Constants

logger = logging.getLogger("Reranker")


@dataclass
class Model:
    """
    Abstract base class representing a model served by the HiveGPT Model Router.

    Attributes:
        name (str): The HuggingFace repository path for the model,
            e.g., "meta-llama/Meta-Llama-3.1-8B".
        alias (str): A shorter, more user-friendly alias or identifier for the model.
        openai_endpoint (str): The base OpenAI endpoint through which the model can be accessed.
    """
    name: str
    alias: str
    openai_endpoint: str


@dataclass
class LLMModel(Model):
    """
    Represents an LLM served by the HiveGPT Model Router.

    Attributes:
        name (str): The HuggingFace repository path for the model,
            e.g., "meta-llama/Meta-Llama-3.1-8B".
        alias (str): A shorter, more user-friendly alias or identifier for the model.
        openai_endpoint (str): The base OpenAI endpoint through which the model can be accessed.
        max_len (int): The maximum sequence length that the model can handle.
    """
    max_len: int


class ModelRouter:
    """A wrapper class that fetches model information from the HiveGPT Model Router."""

    def __init__(self, host: str = Constants.MODEL_ROUTER_HOST, port: str = Constants.MODEL_ROUTER_PORT):
        """
        Initializes the ModelRouter.

        Args:
            host (str): The hostname of the Model Router server.
            port (str): The port number of the Model Router server.

        Note:
            The ModelRouter will automatically refresh the map of served models upon initialization.
        """
        self.host = host
        self.port = port
        self.models_health_endpoint = f"http://{self.host}:{self.port}/v1/models"
        self.served_models: Dict[str, LLMModel] = {}
        self.logger = logging.getLogger("HiveGPT Model Router")
        self.refresh()

    def _generate_openai_base(self, alias: str, base_endpoint: str = "/v1") -> str:
        """
        Generates the base OpenAI endpoint URL for a given alias.

        Args:
            alias (str): The alias of the model.
            base_endpoint (str): The base endpoint for the OpenAI API.

        Returns:
            str: The base OpenAI endpoint URL for the given alias.
        """
        return f"http://{self.host}:{self.port}/{alias}{base_endpoint}"

    def _sort_language_models(self):
        """Sorts the served models by name in ascending order, always keeping the default LLM on top."""
        default_model_key = Constants.DEFAULT_LLM_NAME

        # Get the default model, if it is currently served
        default_model = (
            {default_model_key: self.served_models[default_model_key]}
            if default_model_key in self.served_models
            else None
        )

        # Sort the remaining models in ascending order
        other_models = {k: v for k, v in self.served_models.items() if k != default_model_key}
        sorted_other_models = OrderedDict(sorted(other_models.items(), key=lambda item: item[0]))

        # Combine the default model and the sorted models
        sorted_llms = sorted_other_models
        if default_model is not None:
            sorted_llms = OrderedDict(**default_model, **sorted_other_models)

        # Update the served_models dictionary
        self.served_models = sorted_llms

    def refresh(self):
        """Refreshes the map of served models."""
        try:
            response = requests.get(self.models_health_endpoint)
            response.raise_for_status()
            models_json = response.json()

            models = {}
            for model in models_json:
                alias = model["model_alias"]
                name = model["model_name"]
                max_len = model["max_model_len"]
                openai_endpoint = self._generate_openai_base(alias=alias)
                models[name] = LLMModel(name=name, alias=alias, openai_endpoint=openai_endpoint, max_len=max_len)

            self.served_models = models
            self._sort_language_models()
            self.logger.info("Models map successfully refreshed.")
        except requests.RequestException as e:
            self.logger.error(f"Failed to refresh models map: {e}")
            self.served_models = {}

    def get_llm_model(self, name: str) -> Optional[LLMModel]:
        """Gets the LLMModel object for the specified model name.

        Args:
            name (str): The HuggingFace repository path for the model,
                for example, "meta-llama/Meta-Llama-3.1-8B".

        Returns:
            Optional[LLMModel]: The LLMModel object, or None if the model name is not found.
        """
        return self.served_models.get(name)

    def get_all_llm_models(self) -> Dict[str, LLMModel]:
        """Returns a map of all served LLMs.

        Returns:
            Dict[str, LLMModel]: A dictionary where keys are LLM names and values are LLMModel objects.
        """
        self._sort_language_models()
        return self.served_models