Spaces:
Runtime error
Runtime error
| import os | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Callable, Dict, List, Optional | |
| import requests | |
| from langchain_core.pydantic_v1 import ( | |
| BaseModel, | |
| Extra, | |
| Field, | |
| PrivateAttr, | |
| root_validator, | |
| validator, | |
| ) | |
| from langchain.callbacks.manager import CallbackManagerForLLMRun | |
| from langchain.llms.base import LLM | |
| __all__ = ["Databricks"] | |
class _DatabricksClientBase(BaseModel, ABC):
    """A base JSON API client that talks to Databricks."""

    api_url: str
    api_token: str

    def post_raw(self, request: Any) -> Any:
        """POST ``request`` as JSON to ``api_url`` and return the parsed JSON body.

        Raises:
            ValueError: if the HTTP response status is not a 2xx code.
        """
        headers = {"Authorization": f"Bearer {self.api_token}"}
        response = requests.post(self.api_url, headers=headers, json=request)
        # TODO: error handling and automatic retries
        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
        return response.json()

    @abstractmethod
    def post(self, request: Any) -> Any:
        """Send ``request`` in the wire format this endpoint type expects."""
class _DatabricksServingEndpointClient(_DatabricksClientBase):
    """An API client that talks to a Databricks serving endpoint."""

    host: str
    endpoint_name: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Derive ``api_url`` from ``host`` and ``endpoint_name`` unless given."""
        if "api_url" not in values:
            host = values["host"]
            endpoint_name = values["endpoint_name"]
            api_url = f"https://{host}/serving-endpoints/{endpoint_name}/invocations"
            values["api_url"] = api_url
        return values

    def post(self, request: Any) -> Any:
        # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
        wrapped_request = {"dataframe_records": [request]}
        response = self.post_raw(wrapped_request)["predictions"]
        # For a single-record query, the result is not a list.
        if isinstance(response, list):
            response = response[0]
        return response
class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
    """An API client that talks to a Databricks cluster driver proxy app."""

    host: str
    cluster_id: str
    cluster_driver_port: str

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Derive ``api_url`` from host, cluster ID, and driver port unless given."""
        if "api_url" not in values:
            host = values["host"]
            cluster_id = values["cluster_id"]
            port = values["cluster_driver_port"]
            api_url = f"https://{host}/driver-proxy-api/o/0/{cluster_id}/{port}"
            values["api_url"] = api_url
        return values

    def post(self, request: Any) -> Any:
        # The driver proxy app accepts the JSON payload as-is.
        return self.post_raw(request)
def get_repl_context() -> Any:
    """Get the notebook REPL context if running inside a Databricks notebook.

    Raises:
        ImportError: if not running inside a Databricks notebook, where the
            ``dbruntime`` module is unavailable.
    """
    try:
        from dbruntime.databricks_repl_context import get_context

        return get_context()
    except ImportError as e:
        raise ImportError(
            "Cannot access dbruntime, not running inside a Databricks notebook."
        ) from e
def get_default_host() -> str:
    """Get the default Databricks workspace hostname.

    Resolution order: the ``DATABRICKS_HOST`` environment variable, then the
    browser hostname of the current notebook REPL context.

    Raises:
        ValueError: if the hostname cannot be automatically determined.
    """
    host = os.getenv("DATABRICKS_HOST")
    if not host:
        try:
            host = get_repl_context().browserHostName
            if not host:
                raise ValueError("context doesn't contain browserHostName.")
        except Exception as e:
            raise ValueError(
                "host was not set and cannot be automatically inferred. Set "
                f"environment variable 'DATABRICKS_HOST'. Received error: {e}"
            )
    # TODO: support Databricks CLI profile
    # str.lstrip strips a *character set*, not a prefix, which would corrupt
    # hosts starting with any of "htps:/" -- remove the URL scheme explicitly.
    for scheme in ("https://", "http://"):
        if host.startswith(scheme):
            host = host[len(scheme):]
            break
    return host.rstrip("/")
def get_default_api_token() -> str:
    """Get the default Databricks personal access token.

    Resolution order: the ``DATABRICKS_TOKEN`` environment variable, then the
    API token of the current notebook REPL context.

    Raises:
        ValueError: if the token cannot be automatically determined.
    """
    env_token = os.getenv("DATABRICKS_TOKEN")
    if env_token:
        return env_token
    try:
        api_token = get_repl_context().apiToken
        if not api_token:
            raise ValueError("context doesn't contain apiToken.")
    except Exception as e:
        raise ValueError(
            "api_token was not set and cannot be automatically inferred. Set "
            f"environment variable 'DATABRICKS_TOKEN'. Received error: {e}"
        )
    # TODO: support Databricks CLI profile
    return api_token
class Databricks(LLM):
    """Databricks serving endpoint or a cluster driver proxy app for LLM.

    It supports two endpoint types:

    * **Serving endpoint** (recommended for both production and development).
      We assume that an LLM was registered and deployed to a serving endpoint.
      To wrap it as an LLM you must have "Can Query" permission to the endpoint.
      Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
      ``cluster_driver_port``.
      The expected model signature is:

      * inputs::

          [{"name": "prompt", "type": "string"},
           {"name": "stop", "type": "list[string]"}]

      * outputs: ``[{"type": "string"}]``

    * **Cluster driver proxy app** (recommended for interactive development).
      One can load an LLM on a Databricks interactive cluster and start a local HTTP
      server on the driver node to serve the model at ``/`` using HTTP POST method
      with JSON input/output.
      Please use a port number between ``[3000, 8000]`` and let the server listen to
      the driver IP address or simply ``0.0.0.0`` instead of localhost only.
      To wrap it as an LLM you must have "Can Attach To" permission to the cluster.
      Set ``cluster_id`` and ``cluster_driver_port`` and do not set ``endpoint_name``.
      The expected server schema (using JSON schema) is:

      * inputs::

          {"type": "object",
           "properties": {
              "prompt": {"type": "string"},
              "stop": {"type": "array", "items": {"type": "string"}}},
           "required": ["prompt"]}

      * outputs: ``{"type": "string"}``

    If the endpoint model signature is different or you want to set extra params,
    you can use `transform_input_fn` and `transform_output_fn` to apply necessary
    transformations before and after the query.
    """

    host: str = Field(default_factory=get_default_host)
    """Databricks workspace hostname.
    If not provided, the default value is determined by

    * the ``DATABRICKS_HOST`` environment variable if present, or
    * the hostname of the current Databricks workspace if running inside
      a Databricks notebook attached to an interactive cluster in "single user"
      or "no isolation shared" mode.
    """

    api_token: str = Field(default_factory=get_default_api_token)
    """Databricks personal access token.
    If not provided, the default value is determined by

    * the ``DATABRICKS_TOKEN`` environment variable if present, or
    * an automatically generated temporary token if running inside a Databricks
      notebook attached to an interactive cluster in "single user" or
      "no isolation shared" mode.
    """

    endpoint_name: Optional[str] = None
    """Name of the model serving endpoint.
    You must specify the endpoint name to connect to a model serving endpoint.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_id: Optional[str] = None
    """ID of the cluster if connecting to a cluster driver proxy app.
    If neither ``endpoint_name`` nor ``cluster_id`` is not provided and the code runs
    inside a Databricks notebook attached to an interactive cluster in "single user"
    or "no isolation shared" mode, the current cluster ID is used as default.
    You must not set both ``endpoint_name`` and ``cluster_id``.
    """

    cluster_driver_port: Optional[str] = None
    """The port number used by the HTTP server running on the cluster driver node.
    The server should listen on the driver IP address or simply ``0.0.0.0`` to connect.
    We recommend the server using a port number between ``[3000, 8000]``.
    """

    model_kwargs: Optional[Dict[str, Any]] = None
    """Extra parameters to pass to the endpoint."""

    transform_input_fn: Optional[Callable] = None
    """A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
    request object that the endpoint accepts.
    For example, you can apply a prompt template to the input prompt.
    """

    transform_output_fn: Optional[Callable[..., str]] = None
    """A function that transforms the output from the endpoint to the generated text.
    """

    # The API client; chosen in __init__ based on which endpoint type is configured.
    _client: _DatabricksClientBase = PrivateAttr()

    class Config:
        extra = Extra.forbid
        underscore_attrs_are_private = True

    @validator("cluster_id", always=True)
    def set_cluster_id(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        """Validate ``cluster_id``, defaulting it from the REPL context if needed."""
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_id.")
        elif values["endpoint_name"]:
            # Serving-endpoint mode: cluster_id is irrelevant.
            return None
        elif v:
            return v
        else:
            try:
                if v := get_repl_context().clusterId:
                    return v
                raise ValueError("Context doesn't contain clusterId.")
            except Exception as e:
                raise ValueError(
                    "Neither endpoint_name nor cluster_id was set. "
                    "And the cluster_id cannot be automatically determined. Received"
                    f" error: {e}"
                )

    @validator("cluster_driver_port", always=True)
    def set_cluster_driver_port(cls, v: Any, values: Dict[str, Any]) -> Optional[str]:
        """Validate ``cluster_driver_port`` (required and positive in proxy mode)."""
        if v and values["endpoint_name"]:
            raise ValueError("Cannot set both endpoint_name and cluster_driver_port.")
        elif values["endpoint_name"]:
            return None
        elif v is None:
            raise ValueError(
                "Must set cluster_driver_port to connect to a cluster driver."
            )
        elif int(v) <= 0:
            raise ValueError(f"Invalid cluster_driver_port: {v}")
        else:
            return v

    @validator("model_kwargs", always=True)
    def set_model_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Reject keys that would collide with the prompt/stop request fields.

        Raises ValueError (not assert, which is stripped under ``-O``); pydantic
        reports it as a ValidationError either way.
        """
        if v:
            if "prompt" in v:
                raise ValueError("model_kwargs must not contain key 'prompt'")
            if "stop" in v:
                raise ValueError("model_kwargs must not contain key 'stop'")
        return v

    def __init__(self, **data: Any):
        """Validate the config and build the API client for the chosen endpoint."""
        super().__init__(**data)
        if self.endpoint_name:
            self._client = _DatabricksServingEndpointClient(
                host=self.host,
                api_token=self.api_token,
                endpoint_name=self.endpoint_name,
            )
        elif self.cluster_id and self.cluster_driver_port:
            self._client = _DatabricksClusterDriverProxyClient(
                host=self.host,
                api_token=self.api_token,
                cluster_id=self.cluster_id,
                cluster_driver_port=self.cluster_driver_port,
            )
        else:
            raise ValueError(
                "Must specify either endpoint_name or cluster_id/cluster_driver_port."
            )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "databricks"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Queries the LLM endpoint with the given prompt and stop sequence."""
        # TODO: support callbacks

        request: Dict[str, Any] = {"prompt": prompt, "stop": stop}
        request.update(kwargs)
        if self.model_kwargs:
            request.update(self.model_kwargs)

        if self.transform_input_fn:
            request = self.transform_input_fn(**request)

        response = self._client.post(request)

        if self.transform_output_fn:
            response = self.transform_output_fn(response)

        return response