Spaces:
Running
Running
| from langchain_community.chat_models.litellm import ChatLiteLLM, ChatLiteLLMException | |
| from langflow.base.constants import STREAM_INFO_TEXT | |
| from langflow.base.models.model import LCModelComponent | |
| from langflow.field_typing import LanguageModel | |
| from langflow.io import ( | |
| BoolInput, | |
| DictInput, | |
| DropdownInput, | |
| FloatInput, | |
| IntInput, | |
| MessageInput, | |
| SecretStrInput, | |
| StrInput, | |
| ) | |
class ChatLiteLLMModelComponent(LCModelComponent):
    """Langflow model component that builds a ``ChatLiteLLM`` chat model.

    The component exposes the model name, provider, API key and sampling
    settings as inputs and turns them into a configured ``ChatLiteLLM``
    instance in :meth:`build_model`.
    """

    display_name = "LiteLLM"
    description = "`LiteLLM` collection of large language models."
    documentation = "https://python.langchain.com/docs/integrations/chat/litellm"
    icon = "🚄"

    inputs = [
        MessageInput(name="input_value", display_name="Input"),
        StrInput(
            name="model",
            display_name="Model name",
            advanced=False,
            required=True,
            info="The name of the model to use. For example, `gpt-3.5-turbo`.",
        ),
        SecretStrInput(
            name="api_key",
            display_name="API Key",
            advanced=False,
            required=False,
        ),
        DropdownInput(
            name="provider",
            display_name="Provider",
            info="The provider of the API key.",
            options=[
                "OpenAI",
                "Azure",
                "Anthropic",
                "Replicate",
                "Cohere",
                "OpenRouter",
            ],
        ),
        FloatInput(
            name="temperature",
            display_name="Temperature",
            advanced=False,
            required=False,
            value=0.7,
        ),
        DictInput(
            name="kwargs",
            display_name="Kwargs",
            advanced=True,
            required=False,
            is_list=True,
            value={},
        ),
        DictInput(
            name="model_kwargs",
            display_name="Model kwargs",
            advanced=True,
            required=False,
            is_list=True,
            value={},
        ),
        FloatInput(name="top_p", display_name="Top p", advanced=True, required=False, value=0.5),
        IntInput(name="top_k", display_name="Top k", advanced=True, required=False, value=35),
        IntInput(
            name="n",
            display_name="N",
            advanced=True,
            required=False,
            info="Number of chat completions to generate for each prompt. "
            "Note that the API may not return the full n completions if duplicates are generated.",
            value=1,
        ),
        IntInput(
            name="max_tokens",
            display_name="Max tokens",
            advanced=False,
            value=256,
            info="The maximum number of tokens to generate for each chat completion.",
        ),
        IntInput(
            name="max_retries",
            display_name="Max retries",
            advanced=True,
            required=False,
            value=6,
        ),
        BoolInput(
            name="verbose",
            display_name="Verbose",
            advanced=True,
            required=False,
            value=False,
        ),
        BoolInput(
            name="stream",
            display_name="Stream",
            info=STREAM_INFO_TEXT,
            advanced=True,
        ),
        StrInput(
            name="system_message",
            display_name="System Message",
            info="System message to pass to the model.",
            advanced=True,
        ),
    ]

    @staticmethod
    def _without_empty_keys(mapping) -> dict:
        """Return a copy of *mapping* without the ``""`` key; ``None`` yields ``{}``."""
        cleaned = dict(mapping) if mapping else {}
        cleaned.pop("", None)
        return cleaned

    def build_model(self) -> LanguageModel:  # type: ignore[type-var]
        """Build a ``ChatLiteLLM`` instance from the component inputs.

        Returns:
            The configured ``ChatLiteLLM`` model, addressed as
            ``"<provider in lower case>/<model>"``.

        Raises:
            ChatLiteLLMException: If the ``litellm`` package is not installed.
            ValueError: If no provider is selected, or if the provider is
                ``"Azure"`` and ``api_base`` is missing from ``kwargs`` or
                ``api_version`` is missing from ``model_kwargs``.
        """
        # Keep the try body to the import itself so that only a genuinely
        # missing package is reported as such.
        try:
            import litellm
        except ImportError as e:
            msg = "Could not import litellm python package. Please install it with `pip install litellm`"
            raise ChatLiteLLMException(msg) from e

        # Module-level settings: they apply to every litellm user in the process.
        litellm.drop_params = True
        litellm.set_verbose = self.verbose

        # Work on copies so the component's stored inputs are not mutated, and
        # tolerate inputs that are None. Empty keys are removed.
        kwargs = self._without_empty_keys(self.kwargs)
        model_kwargs = self._without_empty_keys(self.model_kwargs)

        # The provider becomes the prefix of the model name, so it must be set.
        if not self.provider:
            msg = "Missing provider"
            raise ValueError(msg)

        # Report missing fields for Azure provider
        if self.provider == "Azure":
            if "api_base" not in kwargs:
                msg = "Missing api_base on kwargs"
                raise ValueError(msg)
            if "api_version" not in model_kwargs:
                msg = "Missing api_version on model_kwargs"
                raise ValueError(msg)

        output = ChatLiteLLM(
            model=f"{self.provider.lower()}/{self.model}",
            client=None,
            streaming=self.stream,
            temperature=self.temperature,
            model_kwargs=model_kwargs,
            top_p=self.top_p,
            top_k=self.top_k,
            n=self.n,
            max_tokens=self.max_tokens,
            max_retries=self.max_retries,
            **kwargs,
        )
        # NOTE(review): client is passed as None above, so this relies on
        # ChatLiteLLM populating `client` during initialisation — confirm.
        output.client.api_key = self.api_key

        return output