Spaces:
Running
Running
| import warnings | |
| import requests | |
| from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage | |
| from pydantic.v1 import SecretStr | |
| from langflow.base.models.chat_result import get_chat_result | |
| from langflow.base.models.model_utils import get_model_name | |
| from langflow.custom.custom_component.component import Component | |
| from langflow.io import ( | |
| BoolInput, | |
| DropdownInput, | |
| HandleInput, | |
| MessageInput, | |
| MessageTextInput, | |
| Output, | |
| SecretStrInput, | |
| StrInput, | |
| ) | |
| from langflow.schema.message import Message | |
# Lookup table used by NotDiamondComponent.model_select: the key is a model name
# as returned by get_model_name() for a linked model, the value is the
# {"provider", "model"} pair sent to the Not Diamond API for that model.
# Several keys may point at the same provider/model pair.
ND_MODEL_MAPPING = {
    name: {"provider": provider, "model": model}
    for name, provider, model in (
        # OpenAI
        ("gpt-4o", "openai", "gpt-4o"),
        ("gpt-4o-mini", "openai", "gpt-4o-mini"),
        ("gpt-4-turbo", "openai", "gpt-4-turbo-2024-04-09"),
        # Anthropic (plain names and "anthropic."-prefixed names)
        ("claude-3-5-haiku-20241022", "anthropic", "claude-3-5-haiku-20241022"),
        ("claude-3-5-sonnet-20241022", "anthropic", "claude-3-5-sonnet-20241022"),
        ("anthropic.claude-3-5-sonnet-20241022-v2:0", "anthropic", "claude-3-5-sonnet-20241022"),
        ("anthropic.claude-3-5-haiku-20241022-v1:0", "anthropic", "claude-3-5-haiku-20241022"),
        # Google
        ("gemini-1.5-pro", "google", "gemini-1.5-pro-latest"),
        ("gemini-1.5-flash", "google", "gemini-1.5-flash-latest"),
        # Perplexity
        ("llama-3.1-sonar-large-128k-online", "perplexity", "llama-3.1-sonar-large-128k-online"),
        # Together AI
        ("meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", "togetherai", "Meta-Llama-3.1-70B-Instruct-Turbo"),
        ("meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", "togetherai", "Meta-Llama-3.1-405B-Instruct-Turbo"),
        # Mistral
        ("mistral-large-latest", "mistral", "mistral-large-2407"),
    )
}
class NotDiamondComponent(Component):
    """Component that routes a prompt to one of several linked language models.

    The conversation and the list of candidate models are sent to the
    Not Diamond ``modelSelect`` endpoint; the prompt is then run on the linked
    model that matches the router's first recommendation. Whenever no usable
    recommendation is available, the first linked model is used instead.
    """

    display_name = "Not Diamond Router"
    description = "Call the right model at the right time with the world's most powerful AI model router."
    documentation: str = "https://docs.notdiamond.ai/"
    icon = "NotDiamond"
    name = "NotDiamond"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Name of the model used by the most recent model_select() call;
        # stays None until model_select() has run.
        self._selected_model_name = None

    inputs = [
        MessageInput(name="input_value", display_name="Input"),
        MessageTextInput(
            name="system_message",
            display_name="System Message",
            info="System message to pass to the model.",
            advanced=False,
        ),
        HandleInput(
            name="models",
            display_name="Language Models",
            input_types=["LanguageModel"],
            required=True,
            is_list=True,
            info="Link the models you want to route between.",
        ),
        SecretStrInput(
            name="api_key",
            display_name="Not Diamond API Key",
            info="The Not Diamond API Key to use for routing.",
            advanced=False,
            value="NOTDIAMOND_API_KEY",
        ),
        StrInput(
            name="preference_id",
            display_name="Preference ID",
            info="The ID of the router preference that was configured via the Dashboard.",
            advanced=False,
        ),
        DropdownInput(
            name="tradeoff",
            display_name="Tradeoff",
            info="The tradeoff between cost and latency for the router to determine the best LLM for a given query.",
            advanced=False,
            options=["quality", "cost", "latency"],
            value="quality",
        ),
        BoolInput(
            name="hash_content",
            display_name="Hash Content",
            info="Whether to hash the content before being sent to the NotDiamond API.",
            advanced=False,
            value=False,
        ),
    ]

    outputs = [
        Output(display_name="Output", name="output", method="model_select"),
        Output(
            display_name="Selected Model",
            name="selected_model",
            method="get_selected_model",
            required_inputs=["output"],
        ),
    ]

    def get_selected_model(self) -> str:
        """Return the name of the model used by the latest ``model_select`` call.

        The value is ``None`` if ``model_select`` has not run yet.
        """
        return self._selected_model_name

    def model_select(self) -> Message:
        """Ask Not Diamond which linked model to use and run the prompt on it.

        Only linked models whose name appears in ``ND_MODEL_MAPPING`` are
        offered to the router. The first linked model is the fallback: it is
        used when no linked model is known to the router, when the routing
        request fails, or when the response carries no recommendation that
        matches a linked model.

        Returns:
            The result of ``get_chat_result`` for the chosen model.

        Raises:
            ValueError: If both the input and the system message are empty,
                or if no language model is linked.
        """
        api_key = SecretStr(self.api_key).get_secret_value() if self.api_key else None
        input_value = self.input_value
        system_message = self.system_message
        # Validates the input first: raises ValueError on an empty conversation.
        messages = self._format_input(input_value, system_message)

        if not self.models:
            msg = "At least one language model must be linked to the router."
            raise ValueError(msg)

        # Pair every routable linked model with its Not Diamond descriptor.
        candidates = []
        for model in self.models:
            nd_entry = ND_MODEL_MAPPING.get(get_model_name(model))
            if nd_entry is not None:
                candidates.append((model, nd_entry))

        # By default there is a fallback model: the first linked one.
        fallback_model = self.models[0]
        self._selected_model_name = get_model_name(fallback_model)
        if not candidates:
            # Nothing the router could pick would match a linked model,
            # so the network round trip is skipped.
            return self._call_get_chat_result(fallback_model, input_value, system_message)

        payload = {
            "messages": messages,
            "llm_providers": [nd_entry for _, nd_entry in candidates],
            "hash_content": self.hash_content,
        }
        # "quality" is never sent explicitly; only the other tradeoffs are.
        if self.tradeoff != "quality":
            payload["tradeoff"] = self.tradeoff
        if self.preference_id:
            payload["preference_id"] = self.preference_id

        header = {
            "Authorization": f"Bearer {api_key}",
            "accept": "application/json",
            "content-type": "application/json",
        }

        try:
            response = requests.post(
                "https://api.notdiamond.ai/v2/modelRouter/modelSelect",
                json=payload,
                headers=header,
                timeout=10,
            )
            result = response.json()
        except (requests.RequestException, ValueError) as exc:
            # A transport failure or an unparsable body is handled like any
            # other missing recommendation: fall back to the first model.
            warnings.warn(
                f"Not Diamond routing request failed ({exc}); falling back to the first linked model.",
                RuntimeWarning,
                stacklevel=2,
            )
            result = None

        target = self._recommended_target(result)
        chosen_model = fallback_model
        if target is not None:
            for model, nd_entry in candidates:
                if (nd_entry["provider"], nd_entry["model"]) == target:
                    chosen_model = model
                    break

        self._selected_model_name = get_model_name(chosen_model)
        return self._call_get_chat_result(chosen_model, input_value, system_message)

    def _recommended_target(self, result):
        """Extract ``(provider, model)`` of the first recommendation in a router response.

        Returns ``None`` when ``result`` is not a dict, has no non-empty
        ``"providers"`` list, or its first entry is not a dict.
        """
        if not isinstance(result, dict):
            return None
        providers = result.get("providers")
        if not isinstance(providers, list) or not providers:
            return None
        first = providers[0]
        if not isinstance(first, dict):
            return None
        return first.get("provider"), first.get("model")

    def _call_get_chat_result(self, chosen_model, input_value, system_message):
        """Run ``get_chat_result`` with ``chosen_model`` as the runnable."""
        return get_chat_result(
            runnable=chosen_model,
            input_value=input_value,
            system_message=system_message,
        )

    def _format_input(
        self,
        input_value: str | Message,
        system_message: str | None = None,
    ) -> list[dict]:
        """Build the OpenAI-style message list sent to the router.

        Args:
            input_value: The user input, either plain text or a ``Message``.
            system_message: Optional system prompt placed first in the list.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts with roles
            ``"user"``, ``"assistant"`` or ``"system"``. Messages that are not
            human, AI or system messages are left out.

        Raises:
            ValueError: If both ``input_value`` and ``system_message`` are empty.
        """
        messages: list[BaseMessage] = []
        if not input_value and not system_message:
            msg = "The message you want to send to the router is empty."
            raise ValueError(msg)

        # Set once the system message has been merged into a prompt, so that
        # it is not inserted a second time below.
        system_message_added = False
        if input_value:
            if isinstance(input_value, Message):
                with warnings.catch_warnings():
                    # Warnings raised while converting the Message are silenced.
                    warnings.simplefilter("ignore")
                    if "prompt" in input_value:
                        prompt = input_value.load_lc_prompt()
                        if system_message:
                            prompt.messages = [
                                SystemMessage(content=system_message),
                                *prompt.messages,  # type: ignore[has-type]
                            ]
                            system_message_added = True
                        messages.extend(prompt.messages)
                    else:
                        messages.append(input_value.to_lc_message())
            else:
                messages.append(HumanMessage(content=input_value))

        if system_message and not system_message_added:
            messages.insert(0, SystemMessage(content=system_message))

        # Convert Langchain messages to OpenAI format
        openai_messages = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                openai_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                openai_messages.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                openai_messages.append({"role": "system", "content": msg.content})
        return openai_messages