# Commit 8816dfd ("Hello") by carraraig — GitHub page header residue kept as a comment
# -*- coding: utf-8 -*-
import logging
from constant import Constants
import requests
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from collections import OrderedDict
logger = logging.getLogger("Reranker")
@dataclass
class Model:
    """
    Base class representing a model served by the HiveGPT Model Router.

    Note: despite the original description calling this "abstract", it is a
    plain dataclass and can be instantiated directly; it simply carries the
    fields common to all served models.

    Attributes:
        name (str): The HuggingFace repository path for the model,
            e.g., "meta-llama/Meta-Llama-3.1-8B".
        alias (str): A shorter, more user-friendly alias or identifier
            for the model.
        openai_endpoint (str): The base OpenAI-compatible endpoint through
            which the model can be accessed.
    """
    name: str
    alias: str
    openai_endpoint: str
@dataclass
class LLMModel(Model):
    """
    Represents an LLM served by the HiveGPT Model Router.

    Extends :class:`Model` with the model's context-window size.

    Attributes:
        name (str): The HuggingFace repository path for the model,
            e.g., "meta-llama/Meta-Llama-3.1-8B".
        alias (str): A shorter, more user-friendly alias or identifier
            for the model.
        openai_endpoint (str): The base OpenAI-compatible endpoint through
            which the model can be accessed.
        max_len (int): The maximum sequence length that the model can handle.
    """
    max_len: int
class ModelRouter:
    """
    A wrapper class that fetches info from the HiveGPT Model Router.

    Queries the router's ``/v1/models`` health endpoint and maintains a
    mapping of model name -> :class:`LLMModel` in :attr:`served_models`.
    """

    # Seconds to wait on the router before giving up. Without a timeout,
    # requests.get() can block forever on an unresponsive host, which would
    # hang __init__ (refresh() is called from the constructor).
    _REQUEST_TIMEOUT: float = 10.0

    def __init__(self, host: str = Constants.MODEL_ROUTER_HOST, port: str = Constants.MODEL_ROUTER_PORT):
        """
        Initializes the ModelRouter.

        Args:
            host (str): The hostname of the Model Router server.
            port (str): The port number of the Model Router server.

        Note: The ModelRouter will automatically refresh the map of served
        models upon initialization.
        """
        self.host = host
        self.port = port
        self.models_health_endpoint = f"http://{self.host}:{self.port}/v1/models"
        # Maps model name (HF repo path) -> LLMModel; populated by refresh().
        self.served_models: Dict[str, LLMModel] = {}
        self.logger = logging.getLogger("HiveGPT Model Router")
        self.refresh()

    def _generate_openai_base(self, alias: str, base_endpoint: str = "/v1") -> str:
        """
        Generates the base OpenAI endpoint URL for a given alias.

        Args:
            alias (str): The alias of the model.
            base_endpoint (str): The path suffix for the OpenAI API.

        Returns:
            str: The base OpenAI endpoint URL for the given alias,
            e.g. "http://host:port/alias/v1".
        """
        return f"http://{self.host}:{self.port}/{alias}{base_endpoint}"

    def _sort_language_models(self):
        """
        Sort served models by model *name* (the dict key) in ascending order,
        always placing the default LLM (Constants.DEFAULT_LLM_NAME) first
        when it is present.
        """
        default_key = Constants.DEFAULT_LLM_NAME
        # Pull out the default model, if served, so it can be prepended.
        default_entry = (
            {default_key: self.served_models[default_key]}
            if default_key in self.served_models
            else None
        )
        # Sort the remaining models by name.
        others = {k: v for k, v in self.served_models.items() if k != default_key}
        ordered = OrderedDict(sorted(others.items(), key=lambda item: item[0]))
        if default_entry is not None:
            # Default model first, then the alphabetically sorted rest.
            ordered = OrderedDict(**default_entry, **ordered)
        self.served_models = ordered

    def refresh(self):
        """Refreshes the map of served models from the router's health endpoint.

        On any network/HTTP failure or malformed payload, the map is reset to
        empty rather than left stale or raising to the caller.
        """
        try:
            response = requests.get(
                self.models_health_endpoint, timeout=self._REQUEST_TIMEOUT
            )
            response.raise_for_status()
            models_json = response.json()
            models = {}
            for model in models_json:
                alias = model["model_alias"]
                name = model["model_name"]
                max_len = model["max_model_len"]
                openai_endpoint = self._generate_openai_base(alias=alias)
                models[name] = LLMModel(
                    name=name,
                    alias=alias,
                    openai_endpoint=openai_endpoint,
                    max_len=max_len,
                )
            self.served_models = models
            self._sort_language_models()
            self.logger.info("Models map successfully refreshed.")
        except (requests.RequestException, KeyError, TypeError, ValueError) as e:
            # RequestException: network/HTTP/JSON-decode failures.
            # KeyError/TypeError/ValueError: payload is not the expected
            # list of dicts with model_alias/model_name/max_model_len keys.
            self.logger.error("Failed to refresh models map: %s", e)
            self.served_models = {}

    def get_llm_model(self, name: str) -> Optional[LLMModel]:
        """Gets the LLMModel object for the specified model name.

        Args:
            name (str): The HuggingFace repository path for the model,
                for example, "meta-llama/Meta-Llama-3.1-8B".

        Returns:
            Optional[LLMModel]: The LLMModel object, or None if the model
            name is not found.
        """
        return self.served_models.get(name)

    def get_all_llm_models(self) -> Dict[str, LLMModel]:
        """Returns a map of all served LLMs, default model first.

        Returns:
            Dict[str, LLMModel]: A dictionary where keys are LLM names and
            values are LLMModel objects. Note this is the internal mapping,
            not a copy.
        """
        self._sort_language_models()
        return self.served_models