Spaces:
Paused
Paused
| import logging | |
| import os | |
| from collections.abc import Callable, Generator, Iterable, Sequence | |
| from typing import IO, Any, Optional, Union, cast | |
| from core.entities.embedding_type import EmbeddingInputType | |
| from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle | |
| from core.entities.provider_entities import ModelLoadBalancingConfiguration | |
| from core.errors.error import ProviderTokenNotInitError | |
| from core.model_runtime.callbacks.base_callback import Callback | |
| from core.model_runtime.entities.llm_entities import LLMResult | |
| from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool | |
| from core.model_runtime.entities.model_entities import ModelType | |
| from core.model_runtime.entities.rerank_entities import RerankResult | |
| from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult | |
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError | |
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |
| from core.model_runtime.model_providers.__base.moderation_model import ModerationModel | |
| from core.model_runtime.model_providers.__base.rerank_model import RerankModel | |
| from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel | |
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |
| from core.model_runtime.model_providers.__base.tts_model import TTSModel | |
| from core.provider_manager import ProviderManager | |
| from extensions.ext_redis import redis_client | |
| from models.provider import ProviderType | |
| logger = logging.getLogger(__name__) | |
class ModelInstance:
    """
    Model instance class

    Binds a model name to a provider model bundle, resolves the credentials
    for that model and routes invocations through the load balancing manager
    when one is configured for the model.
    """

    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
        """
        :param provider_model_bundle: provider model bundle
        :param model: model name
        :raises ProviderTokenNotInitError: if the model has no credentials
        """
        self.provider_model_bundle = provider_model_bundle
        self.model = model
        self.provider = provider_model_bundle.configuration.provider.provider
        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
        self.model_type_instance = self.provider_model_bundle.model_type_instance
        self.load_balancing_manager = self._get_load_balancing_manager(
            configuration=provider_model_bundle.configuration,
            model_type=provider_model_bundle.model_type_instance.model_type,
            model=model,
            credentials=self.credentials,
        )

    # Declared static: the signature has no ``self`` and __init__ calls it
    # through the instance with exactly (provider_model_bundle, model).
    @staticmethod
    def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict:
        """
        Fetch credentials from provider model bundle

        :param provider_model_bundle: provider model bundle
        :param model: model name
        :return: credentials of the model
        :raises ProviderTokenNotInitError: if the configuration returns no credentials
        """
        configuration = provider_model_bundle.configuration
        model_type = provider_model_bundle.model_type_instance.model_type
        credentials = configuration.get_current_credentials(model_type=model_type, model=model)

        if credentials is None:
            raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")

        return credentials

    # Declared static for the same reason as above: __init__ passes every
    # argument by keyword and the signature has no ``self``.
    @staticmethod
    def _get_load_balancing_manager(
        configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
    ) -> Optional["LBModelManager"]:
        """
        Get the load balancing manager for a model

        A manager is only created when the configuration uses the custom
        provider type and the matching model setting has load balancing
        configs; otherwise None is returned.

        :param configuration: provider configuration
        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return: load balancing manager, or None if load balancing does not apply
        """
        if not (configuration.model_settings and configuration.using_provider_type == ProviderType.CUSTOM):
            return None

        # find the setting that belongs to this exact model type and name
        current_model_setting = None
        for model_setting in configuration.model_settings:
            if model_setting.model_type == model_type and model_setting.model == model:
                current_model_setting = model_setting
                break

        # check if load balancing is enabled
        if current_model_setting and current_model_setting.load_balancing_configs:
            # use load balancing proxy to choose credentials
            return LBModelManager(
                tenant_id=configuration.tenant_id,
                provider=configuration.provider.provider,
                model_type=model_type,
                model=model,
                load_balancing_configs=current_model_setting.load_balancing_configs,
                managed_credentials=credentials if configuration.custom_configuration.provider else None,
            )

        return None

    def invoke_llm(
        self,
        prompt_messages: list[PromptMessage],
        model_parameters: Optional[dict] = None,
        tools: Sequence[PromptMessageTool] | None = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        :return: full response or stream response chunk generator result
        :raises Exception: if the model type instance is not a LargeLanguageModel
        """
        if not isinstance(self.model_type_instance, LargeLanguageModel):
            raise Exception("Model type instance is not LargeLanguageModel")

        return self._round_robin_invoke(
            function=self.model_type_instance.invoke,
            model=self.model,
            credentials=self.credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
            callbacks=callbacks,
        )

    def get_llm_num_tokens(
        self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
    ) -> int:
        """
        Get number of tokens for llm

        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: number of tokens
        :raises Exception: if the model type instance is not a LargeLanguageModel
        """
        if not isinstance(self.model_type_instance, LargeLanguageModel):
            raise Exception("Model type instance is not LargeLanguageModel")

        return self._round_robin_invoke(
            function=self.model_type_instance.get_num_tokens,
            model=self.model,
            credentials=self.credentials,
            prompt_messages=prompt_messages,
            tools=tools,
        )

    def invoke_text_embedding(
        self, texts: list[str], user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
    ) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param texts: texts to embed
        :param user: unique user id
        :param input_type: input type
        :return: embeddings result
        :raises Exception: if the model type instance is not a TextEmbeddingModel
        """
        if not isinstance(self.model_type_instance, TextEmbeddingModel):
            raise Exception("Model type instance is not TextEmbeddingModel")

        return self._round_robin_invoke(
            function=self.model_type_instance.invoke,
            model=self.model,
            credentials=self.credentials,
            texts=texts,
            user=user,
            input_type=input_type,
        )

    def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
        """
        Get number of tokens for text embedding

        :param texts: texts to embed
        :return: number of tokens
        :raises Exception: if the model type instance is not a TextEmbeddingModel
        """
        if not isinstance(self.model_type_instance, TextEmbeddingModel):
            raise Exception("Model type instance is not TextEmbeddingModel")

        return self._round_robin_invoke(
            function=self.model_type_instance.get_num_tokens,
            model=self.model,
            credentials=self.credentials,
            texts=texts,
        )

    def invoke_rerank(
        self,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        :raises Exception: if the model type instance is not a RerankModel
        """
        if not isinstance(self.model_type_instance, RerankModel):
            raise Exception("Model type instance is not RerankModel")

        return self._round_robin_invoke(
            function=self.model_type_instance.invoke,
            model=self.model,
            credentials=self.credentials,
            query=query,
            docs=docs,
            score_threshold=score_threshold,
            top_n=top_n,
            user=user,
        )

    def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
        """
        Invoke moderation model

        :param text: text to moderate
        :param user: unique user id
        :return: false if text is safe, true otherwise
        :raises Exception: if the model type instance is not a ModerationModel
        """
        if not isinstance(self.model_type_instance, ModerationModel):
            raise Exception("Model type instance is not ModerationModel")

        return self._round_robin_invoke(
            function=self.model_type_instance.invoke,
            model=self.model,
            credentials=self.credentials,
            text=text,
            user=user,
        )

    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke speech-to-text model

        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        :raises Exception: if the model type instance is not a Speech2TextModel
        """
        if not isinstance(self.model_type_instance, Speech2TextModel):
            raise Exception("Model type instance is not Speech2TextModel")

        return self._round_robin_invoke(
            function=self.model_type_instance.invoke,
            model=self.model,
            credentials=self.credentials,
            file=file,
            user=user,
        )

    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
        """
        Invoke text-to-speech model

        :param content_text: text content passed to the tts model
        :param tenant_id: user tenant id
        :param voice: model timbre
        :param user: unique user id
        :return: iterable of byte chunks produced by the tts model
        :raises Exception: if the model type instance is not a TTSModel
        """
        if not isinstance(self.model_type_instance, TTSModel):
            raise Exception("Model type instance is not TTSModel")

        return self._round_robin_invoke(
            function=self.model_type_instance.invoke,
            model=self.model,
            credentials=self.credentials,
            content_text=content_text,
            user=user,
            tenant_id=tenant_id,
            voice=voice,
        )

    def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
        """
        Round-robin invoke

        Without a load balancing manager the function is called as-is.
        Otherwise the ``credentials`` keyword is replaced by the credentials of
        the next load balancing config; on rate-limit, authorization or
        connection errors that config is put in cooldown and the next one is
        tried. Any other exception propagates unchanged.

        :param function: function to invoke
        :param args: function args
        :param kwargs: function kwargs
        :return: whatever ``function`` returns
        :raises ProviderTokenNotInitError: if no config is available and no call was attempted
        :raises Exception: the last invoke error once every config is unavailable
        """
        if not self.load_balancing_manager:
            return function(*args, **kwargs)

        # the credentials always come from the selected load balancing config
        kwargs.pop("credentials", None)

        last_exception = None
        while True:
            lb_config = self.load_balancing_manager.fetch_next()
            if not lb_config:
                if not last_exception:
                    raise ProviderTokenNotInitError("Model credentials is not initialized.")
                raise last_exception

            try:
                return function(*args, **kwargs, credentials=lb_config.credentials)
            except InvokeRateLimitError as e:
                # expire in 60 seconds
                self.load_balancing_manager.cooldown(lb_config, expire=60)
                last_exception = e
            except (InvokeAuthorizationError, InvokeConnectionError) as e:
                # expire in 10 seconds
                self.load_balancing_manager.cooldown(lb_config, expire=10)
                last_exception = e

    def get_tts_voices(self, language: Optional[str] = None) -> list:
        """
        Get the voices of the tts model

        Called directly with the instance credentials; this does not go
        through the round-robin load balancing path.

        :param language: tts language
        :return: tts model voices
        :raises Exception: if the model type instance is not a TTSModel
        """
        if not isinstance(self.model_type_instance, TTSModel):
            raise Exception("Model type instance is not TTSModel")

        return self.model_type_instance.get_tts_model_voices(
            model=self.model, credentials=self.credentials, language=language
        )
class ModelManager:
    """
    Resolves ModelInstance objects for a tenant through a ProviderManager.
    """

    def __init__(self) -> None:
        self._provider_manager = ProviderManager()

    def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
        """
        Build a model instance for the given provider and model.

        When ``provider`` is empty the tenant's default model for the model
        type is used instead.

        :param tenant_id: tenant id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :return: model instance
        """
        if provider:
            bundle = self._provider_manager.get_provider_model_bundle(
                tenant_id=tenant_id, provider=provider, model_type=model_type
            )
            return ModelInstance(bundle, model)

        return self.get_default_model_instance(tenant_id, model_type)

    def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
        """
        Look up the first provider and that provider's first model.

        :param tenant_id: tenant id
        :param model_type: model type
        :return: provider name, model name
        """
        first_pair = self._provider_manager.get_first_provider_first_model(tenant_id, model_type)
        return first_pair

    def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
        """
        Build a model instance for the tenant's default model of a model type.

        :param tenant_id: tenant id
        :param model_type: model type
        :return: model instance
        :raises ProviderTokenNotInitError: if no default model is found
        """
        default_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type)

        if default_entity:
            return self.get_model_instance(
                tenant_id=tenant_id,
                provider=default_entity.provider.provider,
                model_type=model_type,
                model=default_entity.model,
            )

        raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
class LBModelManager:
    """
    Round-robin selector over the load balancing configs of one model.

    The rotation counter and the per-config cooldown flags are kept in Redis
    under keys derived from tenant id, provider, model type and model name.
    """

    def __init__(
        self,
        tenant_id: str,
        provider: str,
        model_type: ModelType,
        model: str,
        load_balancing_configs: list[ModelLoadBalancingConfiguration],
        managed_credentials: Optional[dict] = None,
    ) -> None:
        """
        Load balancing model manager

        :param tenant_id: tenant_id
        :param provider: provider
        :param model_type: model_type
        :param model: model name
        :param load_balancing_configs: all load balancing configurations
        :param managed_credentials: credentials if load balancing configuration name is __inherit__
        """
        self._tenant_id = tenant_id
        self._provider = provider
        self._model_type = model_type
        self._model = model

        # Build our own list so that dropping "__inherit__" entries does not
        # alter the list object owned by the caller.
        self._load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
        for load_balancing_config in load_balancing_configs:
            if load_balancing_config.name == "__inherit__":
                if not managed_credentials:
                    # skip __inherit__ if managed credentials is not provided
                    continue
                load_balancing_config.credentials = managed_credentials
            self._load_balancing_configs.append(load_balancing_config)

    @staticmethod
    def _build_cooldown_cache_key(
        tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str
    ) -> str:
        """Build the Redis key that marks one load balancing config as in cooldown."""
        return "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
            tenant_id, provider, model_type.value, model, config_id
        )

    def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]:
        """
        Get next model load balancing config
        Strategy: Round Robin

        :return: the next config that is not in cooldown, or None when there
            are no configs or as many cooldown hits as configs were seen
        """
        max_index = len(self._load_balancing_configs)
        if max_index == 0:
            # nothing to rotate over; also avoids a modulo by zero below
            return None

        cache_key = "model_lb_index:{}:{}:{}:{}".format(
            self._tenant_id, self._provider, self._model_type.value, self._model
        )

        cooldown_hits = 0
        while True:
            current_index = cast(int, redis_client.incr(cache_key))
            if current_index >= 10000000:
                # restart the counter before it grows without bound
                current_index = 1
                redis_client.set(cache_key, current_index)

            redis_client.expire(cache_key, 3600)

            # the counter is 1-based; map it onto a valid 0-based list index
            real_index = (current_index - 1) % max_index
            config = self._load_balancing_configs[real_index]

            if self.in_cooldown(config):
                cooldown_hits += 1
                if cooldown_hits >= max_index:
                    # all configs are in cooldown
                    return None

                continue

            if os.environ.get("DEBUG", "False").lower() == "true":
                logger.info(
                    f"Model LB\nid: {config.id}\nname:{config.name}\n"
                    f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
                    f"model_type: {self._model_type.value}\nmodel: {self._model}"
                )

            return config

    def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
        """
        Cooldown model load balancing config

        :param config: model load balancing config
        :param expire: cooldown time, passed to Redis as the key expiry
        :return:
        """
        cooldown_cache_key = self._build_cooldown_cache_key(
            self._tenant_id, self._provider, self._model_type, self._model, config.id
        )

        redis_client.setex(cooldown_cache_key, expire, "true")

    def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
        """
        Check if model load balancing config is in cooldown

        :param config: model load balancing config
        :return: True if the cooldown key of the config exists
        """
        cooldown_cache_key = self._build_cooldown_cache_key(
            self._tenant_id, self._provider, self._model_type, self._model, config.id
        )

        return bool(redis_client.exists(cooldown_cache_key))

    # Declared static: the signature has no ``self``; calling it on the class
    # keeps working and calling it on an instance now works as well.
    @staticmethod
    def get_config_in_cooldown_and_ttl(
        tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str
    ) -> tuple[bool, int]:
        """
        Get model load balancing config is in cooldown and ttl

        :param tenant_id: workspace id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :param config_id: model load balancing config id
        :return: (False, 0) when Redis reports a ttl of -2 for the cooldown
            key, otherwise (True, ttl)
        """
        cooldown_cache_key = LBModelManager._build_cooldown_cache_key(
            tenant_id, provider, model_type, model, config_id
        )

        ttl = redis_client.ttl(cooldown_cache_key)
        if ttl == -2:
            return False, 0

        return True, cast(int, ttl)