Spaces:
Runtime error
Runtime error
| import json | |
| import logging | |
| import threading | |
| from typing import Any, Dict, List, Mapping, Optional | |
| import requests | |
| from langchain_core.messages import ( | |
| AIMessage, | |
| BaseMessage, | |
| ChatMessage, | |
| HumanMessage, | |
| ) | |
| from langchain_core.outputs import ChatGeneration, ChatResult | |
| from langchain_core.pydantic_v1 import root_validator | |
| from langchain.callbacks.manager import CallbackManagerForLLMRun | |
| from langchain.chat_models.base import BaseChatModel | |
| from langchain.utils import get_from_dict_or_env | |
| logger = logging.getLogger(__name__) | |
| def _convert_message_to_dict(message: BaseMessage) -> dict: | |
| if isinstance(message, ChatMessage): | |
| message_dict = {"role": message.role, "content": message.content} | |
| elif isinstance(message, HumanMessage): | |
| message_dict = {"role": "user", "content": message.content} | |
| elif isinstance(message, AIMessage): | |
| message_dict = {"role": "assistant", "content": message.content} | |
| else: | |
| raise ValueError(f"Got unknown type {message}") | |
| return message_dict | |
| class ErnieBotChat(BaseChatModel): | |
| """`ERNIE-Bot` large language model. | |
| ERNIE-Bot is a large language model developed by Baidu, | |
| covering a huge amount of Chinese data. | |
| To use, you should have the `ernie_client_id` and `ernie_client_secret` set, | |
| or set the environment variable `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET`. | |
| Note: | |
| access_token will be automatically generated based on client_id and client_secret, | |
| and will be regenerated after expiration (30 days). | |
| Default model is `ERNIE-Bot-turbo`, | |
| currently supported models are `ERNIE-Bot-turbo`, `ERNIE-Bot` | |
| Example: | |
| .. code-block:: python | |
| from langchain.chat_models import ErnieBotChat | |
| chat = ErnieBotChat(model_name='ERNIE-Bot') | |
| """ | |
| ernie_api_base: Optional[str] = None | |
| """Baidu application custom endpoints""" | |
| ernie_client_id: Optional[str] = None | |
| """Baidu application client id""" | |
| ernie_client_secret: Optional[str] = None | |
| """Baidu application client secret""" | |
| access_token: Optional[str] = None | |
| """access token is generated by client id and client secret, | |
| setting this value directly will cause an error""" | |
| model_name: str = "ERNIE-Bot-turbo" | |
| """model name of ernie, default is `ERNIE-Bot-turbo`. | |
| Currently supported `ERNIE-Bot-turbo`, `ERNIE-Bot`""" | |
| request_timeout: Optional[int] = 60 | |
| """request timeout for chat http requests""" | |
| streaming: Optional[bool] = False | |
| """streaming mode. not supported yet.""" | |
| top_p: Optional[float] = 0.8 | |
| temperature: Optional[float] = 0.95 | |
| penalty_score: Optional[float] = 1 | |
| _lock = threading.Lock() | |
| def validate_environment(cls, values: Dict) -> Dict: | |
| values["ernie_api_base"] = get_from_dict_or_env( | |
| values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com" | |
| ) | |
| values["ernie_client_id"] = get_from_dict_or_env( | |
| values, | |
| "ernie_client_id", | |
| "ERNIE_CLIENT_ID", | |
| ) | |
| values["ernie_client_secret"] = get_from_dict_or_env( | |
| values, | |
| "ernie_client_secret", | |
| "ERNIE_CLIENT_SECRET", | |
| ) | |
| return values | |
| def _chat(self, payload: object) -> dict: | |
| base_url = f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/chat" | |
| model_paths = { | |
| "ERNIE-Bot-turbo": "eb-instant", | |
| "ERNIE-Bot": "completions", | |
| "ERNIE-Bot-4": "completions_pro", | |
| "BLOOMZ-7B": "bloomz_7b1", | |
| "Llama-2-7b-chat": "llama_2_7b", | |
| "Llama-2-13b-chat": "llama_2_13b", | |
| "Llama-2-70b-chat": "llama_2_70b", | |
| } | |
| if self.model_name in model_paths: | |
| url = f"{base_url}/{model_paths[self.model_name]}" | |
| else: | |
| raise ValueError(f"Got unknown model_name {self.model_name}") | |
| resp = requests.post( | |
| url, | |
| timeout=self.request_timeout, | |
| headers={ | |
| "Content-Type": "application/json", | |
| }, | |
| params={"access_token": self.access_token}, | |
| json=payload, | |
| ) | |
| return resp.json() | |
| def _refresh_access_token_with_lock(self) -> None: | |
| with self._lock: | |
| logger.debug("Refreshing access token") | |
| base_url: str = f"{self.ernie_api_base}/oauth/2.0/token" | |
| resp = requests.post( | |
| base_url, | |
| timeout=10, | |
| headers={ | |
| "Content-Type": "application/json", | |
| "Accept": "application/json", | |
| }, | |
| params={ | |
| "grant_type": "client_credentials", | |
| "client_id": self.ernie_client_id, | |
| "client_secret": self.ernie_client_secret, | |
| }, | |
| ) | |
| self.access_token = str(resp.json().get("access_token")) | |
| def _generate( | |
| self, | |
| messages: List[BaseMessage], | |
| stop: Optional[List[str]] = None, | |
| run_manager: Optional[CallbackManagerForLLMRun] = None, | |
| **kwargs: Any, | |
| ) -> ChatResult: | |
| if self.streaming: | |
| raise ValueError("`streaming` option currently unsupported.") | |
| if not self.access_token: | |
| self._refresh_access_token_with_lock() | |
| payload = { | |
| "messages": [_convert_message_to_dict(m) for m in messages], | |
| "top_p": self.top_p, | |
| "temperature": self.temperature, | |
| "penalty_score": self.penalty_score, | |
| **kwargs, | |
| } | |
| logger.debug(f"Payload for ernie api is {payload}") | |
| resp = self._chat(payload) | |
| if resp.get("error_code"): | |
| if resp.get("error_code") == 111: | |
| logger.debug("access_token expired, refresh it") | |
| self._refresh_access_token_with_lock() | |
| resp = self._chat(payload) | |
| else: | |
| raise ValueError(f"Error from ErnieChat api response: {resp}") | |
| return self._create_chat_result(resp) | |
| def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: | |
| if "function_call" in response: | |
| fc_str = '{{"function_call": {}}}'.format( | |
| json.dumps(response.get("function_call")) | |
| ) | |
| generations = [ChatGeneration(message=AIMessage(content=fc_str))] | |
| else: | |
| generations = [ | |
| ChatGeneration(message=AIMessage(content=response.get("result"))) | |
| ] | |
| token_usage = response.get("usage", {}) | |
| llm_output = {"token_usage": token_usage, "model_name": self.model_name} | |
| return ChatResult(generations=generations, llm_output=llm_output) | |
| def _llm_type(self) -> str: | |
| return "ernie-bot-chat" | |