# NOTE(review): the three lines that preceded this file's code were scrape residue
# from a Hugging Face Spaces status page ("Spaces:" / "Runtime error") and were not
# part of the module; they are preserved here as a comment so the file stays valid Python.
| import json | |
| from typing import Any, Dict, List, Optional, cast | |
| from langchain_core.messages import ( | |
| AIMessage, | |
| BaseMessage, | |
| ChatMessage, | |
| HumanMessage, | |
| SystemMessage, | |
| ) | |
| from langchain_core.pydantic_v1 import SecretStr, validator | |
| from langchain_core.utils import convert_to_secret_str | |
| from langchain.callbacks.manager import CallbackManagerForLLMRun | |
| from langchain.chat_models.base import SimpleChatModel | |
| from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase | |
| from langchain.utils import get_from_dict_or_env | |
class LlamaContentFormatter(ContentFormatterBase):
    """Content formatter for `LLaMA` chat models served on AzureML endpoints."""

    # Chat roles accepted by the LLaMA foundation model.
    SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> Dict:
        """Convert a LangChain message to a ``{"role": ..., "content": ...}`` dict.

        Args:
            message: The message to convert. ``ChatMessage`` instances are only
                accepted when their role is one of ``SUPPORTED_ROLES``.

        Raises:
            ValueError: If the message type/role is not supported by LLaMA.
        """
        content = cast(str, message.content)
        # Escaping is pure string sanitization, so it is safe to hoist it
        # ahead of the role dispatch below.
        escaped = ContentFormatterBase.escape_special_characters(content)
        if isinstance(message, HumanMessage):
            return {"role": "user", "content": escaped}
        elif isinstance(message, AIMessage):
            return {"role": "assistant", "content": escaped}
        elif isinstance(message, SystemMessage):
            return {"role": "system", "content": escaped}
        elif (
            isinstance(message, ChatMessage)
            and message.role in LlamaContentFormatter.SUPPORTED_ROLES
        ):
            return {"role": message.role, "content": escaped}
        else:
            supported = ",".join(LlamaContentFormatter.SUPPORTED_ROLES)
            raise ValueError(
                f"""Received unsupported role. 
Supported roles for the LLaMa Foundation Model: {supported}"""
            )

    def _format_request_payload(
        self, messages: List[BaseMessage], model_kwargs: Dict
    ) -> bytes:
        """Serialize the conversation + model parameters into the endpoint's
        ``input_data`` JSON schema, then encode it via ``format_request_payload``."""
        chat_messages = [
            LlamaContentFormatter._convert_message_to_dict(message)
            for message in messages
        ]
        prompt = json.dumps(
            {"input_data": {"input_string": chat_messages, "parameters": model_kwargs}}
        )
        return self.format_request_payload(prompt=prompt, model_kwargs=model_kwargs)

    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Format the request according to the chosen api: the JSON prompt
        encoded to bytes. ``model_kwargs`` is unused here (already embedded
        in the prompt by ``_format_request_payload``)."""
        return str.encode(prompt)

    def format_response_payload(self, output: bytes) -> str:
        """Extract the generated text from the endpoint's JSON response body."""
        return json.loads(output)["output"]
class AzureMLChatOnlineEndpoint(SimpleChatModel):
    """`AzureML` Chat models API.

    Example:
        .. code-block:: python

            azure_chat = AzureMLChatOnlineEndpoint(
                endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
                endpoint_api_key="my-api-key",
                content_formatter=content_formatter,
            )
    """

    endpoint_url: str = ""
    """URL of pre-existing Endpoint. Should be passed to constructor or specified as
    env var `AZUREML_ENDPOINT_URL`."""

    endpoint_api_key: SecretStr = convert_to_secret_str("")
    """Authentication Key for Endpoint. Should be passed to constructor or specified as
    env var `AZUREML_ENDPOINT_API_KEY`."""

    http_client: Any = None  #: :meta private:

    content_formatter: Any = None
    """The content formatter that provides an input and output
    transform function to handle formats between the LLM and
    the endpoint"""

    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""

    # FIX: without this decorator the method is never invoked by pydantic,
    # so `http_client` stays None and `_call` fails with AttributeError.
    @validator("http_client", always=True, allow_reuse=True)
    def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient:
        """Validate that api key and python package exist in environment."""
        values["endpoint_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY")
        )
        endpoint_url = get_from_dict_or_env(
            values, "endpoint_url", "AZUREML_ENDPOINT_URL"
        )
        http_client = AzureMLEndpointClient(
            endpoint_url, values["endpoint_api_key"].get_secret_value()
        )
        return http_client

    # FIX: the base chat model declares this as a property; overriding it with
    # a plain method would make callers receive the bound method instead of a dict.
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    # FIX: `_llm_type` is likewise a property on the base class.
    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azureml_chat_endpoint"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to an AzureML Managed Online endpoint.

        Args:
            messages: The messages in the conversation with the chat model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = azureml_model("Tell me a joke.")
        """
        _model_kwargs = self.model_kwargs or {}
        request_payload = self.content_formatter._format_request_payload(
            messages, _model_kwargs
        )
        response_payload = self.http_client.call(request_payload, **kwargs)
        generated_text = self.content_formatter.format_response_payload(
            response_payload
        )
        return generated_text