Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import json | |
| import logging | |
| from hashlib import sha1 | |
| from threading import Thread | |
| from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | |
| from langchain_core.embeddings import Embeddings | |
| from langchain_core.pydantic_v1 import BaseSettings | |
| from langchain_core.vectorstores import VectorStore | |
| from langchain.docstore.document import Document | |
| logger = logging.getLogger() | |
def has_mul_sub_str(s: str, *args: Any) -> bool:
    """
    Tell whether every given substring occurs in a string.

    Args:
        s: the string to search in.
        *args: the substrings that must all be present.

    Returns:
        True when each substring is found in ``s`` (also when no substrings
        are given), False as soon as one is missing.
    """
    return all(needle in s for needle in args)
class ClickhouseSettings(BaseSettings):
    """`ClickHouse` client configuration.

    Attribute:
        host (str) : Host of the ClickHouse backend to connect to.
                     Defaults to 'localhost'.
        port (int) : URL port to connect with HTTP. Defaults to 8123.
        username (str) : Username to login. Defaults to None.
        password (str) : Password to login. Defaults to None.
        index_type (str): index type string. Defaults to 'annoy'.
        index_param (list or dict): index build parameter.
                     Defaults to ["'L2Distance'", 100].
        index_query_params (dict): index query parameters.
                     Defaults to an empty dict.
        database (str) : Database name to find the table. Defaults to 'default'.
        table (str) : Table name to operate on.
                      Defaults to 'langchain'.
        metric (str) : Metric to compute distance,
                       supported are ('angular', 'euclidean', 'manhattan', 'hamming',
                       'dot'). Defaults to 'angular'.
                       https://github.com/spotify/annoy/blob/main/src/annoymodule.cc#L149-L169

        column_map (Dict) : Column type map to project column name onto langchain
                            semantics. Must have keys: `id`, `uuid`, `document`,
                            `embedding`, `metadata`. For example:

                            .. code-block:: python

                                {
                                    'id': 'text_id',
                                    'uuid': 'global_unique_id',
                                    'embedding': 'text_embedding',
                                    'document': 'text_plain',
                                    'metadata': 'metadata_dictionary_in_json',
                                }

                            Defaults to identity map.
    """

    host: str = "localhost"
    port: int = 8123

    username: Optional[str] = None
    password: Optional[str] = None

    index_type: str = "annoy"
    # Annoy supports L2Distance and cosineDistance.
    index_param: Optional[Union[List, Dict]] = ["'L2Distance'", 100]
    index_query_params: Dict[str, str] = {}

    # Identity mapping: each logical column keeps its own name in the table.
    column_map: Dict[str, str] = {
        "id": "id",
        "uuid": "uuid",
        "document": "document",
        "embedding": "embedding",
        "metadata": "metadata",
    }

    database: str = "default"
    table: str = "langchain"
    metric: str = "angular"

    def __getitem__(self, item: str) -> Any:
        """Allow dict-style access, e.g. ``settings["host"]``."""
        return getattr(self, item)

    class Config:
        # Settings-source options: values may come from a `.env` file and from
        # environment variables carrying the `clickhouse_` prefix.
        env_file = ".env"
        env_prefix = "clickhouse_"
        env_file_encoding = "utf-8"
class Clickhouse(VectorStore):
    """`ClickHouse VectorSearch` vector store.

    You need a `clickhouse-connect` python package, and a valid account
    to connect to ClickHouse.

    ClickHouse can not only search with simple vector indexes,
    it also supports complex query with multiple conditions,
    constraints and even sub-queries.

    For more information, please visit
        [ClickHouse official site](https://clickhouse.com/clickhouse)
    """

    def __init__(
        self,
        embedding: Embeddings,
        config: Optional[ClickhouseSettings] = None,
        **kwargs: Any,
    ) -> None:
        """ClickHouse Wrapper to LangChain

        Connects to ClickHouse and creates the target table if it does not
        exist yet.

        Args:
            embedding (Embeddings): embedding model; it is called once with a
                probe query to discover the vector dimension.
            config (ClickHouseSettings): Configuration to ClickHouse Client.
                A default ``ClickhouseSettings()`` is used when omitted.
            Other keyword arguments will pass into
                [clickhouse-connect](https://docs.clickhouse.com/)

        Raises:
            ImportError: if `clickhouse-connect` is not installed.
            AssertionError: if the configuration is incomplete or the metric
                is not supported.
        """
        try:
            from clickhouse_connect import get_client
        except ImportError as exc:
            raise ImportError(
                "Could not import clickhouse connect python package. "
                "Please install it with `pip install clickhouse-connect`."
            ) from exc
        try:
            from tqdm import tqdm

            self.pgbar = tqdm
        except ImportError:
            # Just in case if tqdm is not installed
            self.pgbar = lambda x, **kwargs: x
        super().__init__()
        self.config = config if config is not None else ClickhouseSettings()
        assert self.config
        assert self.config.host and self.config.port
        assert (
            self.config.column_map
            and self.config.database
            and self.config.table
            and self.config.metric
        )
        for k in ["id", "embedding", "document", "metadata", "uuid"]:
            assert k in self.config.column_map
        assert self.config.metric in [
            "angular",
            "euclidean",
            "manhattan",
            "hamming",
            "dot",
        ]

        # initialize the schema
        dim = len(embedding.embed_query("test"))

        # Render the index build parameters as the argument list of the index
        # type. A missing parameter renders as an empty list instead of the
        # literal text "None".
        index_param = self.config.index_param
        if isinstance(index_param, dict):
            index_params = ",".join(f"'{k}={v}'" for k, v in index_param.items())
        elif isinstance(index_param, list):
            index_params = ",".join(str(p) for p in index_param)
        else:
            index_params = "" if index_param is None else index_param

        cols = self.config.column_map
        # The sorting key uses the mapped uuid column so that a customised
        # column_map still yields a valid table definition.
        self.schema = (
            f"CREATE TABLE IF NOT EXISTS "
            f"{self.config.database}.{self.config.table}(\n"
            f"    {cols['id']} Nullable(String),\n"
            f"    {cols['document']} Nullable(String),\n"
            f"    {cols['embedding']} Array(Float32),\n"
            f"    {cols['metadata']} JSON,\n"
            f"    {cols['uuid']} UUID DEFAULT generateUUIDv4(),\n"
            f"    CONSTRAINT cons_vec_len CHECK "
            f"length({cols['embedding']}) = {dim},\n"
            f"    INDEX vec_idx {cols['embedding']} TYPE "
            f"{self.config.index_type}({index_params}) GRANULARITY 1000\n"
            f") ENGINE = MergeTree ORDER BY {cols['uuid']} "
            f"SETTINGS index_granularity = 8192"
        )

        self.dim = dim
        self.BS = "\\"
        self.must_escape = ("\\", "'")
        self.embedding_function = embedding
        self.dist_order = "ASC"  # Only support cosineDistance and L2Distance

        # Create a connection to clickhouse
        self.client = get_client(
            host=self.config.host,
            port=self.config.port,
            username=self.config.username,
            password=self.config.password,
            **kwargs,
        )
        # Enable JSON type
        self.client.command("SET allow_experimental_object_type=1")
        # Enable Annoy index
        self.client.command("SET allow_experimental_annoy_index=1")
        self.client.command(self.schema)

    @property
    def embeddings(self) -> Embeddings:
        """The embedding model this store was created with."""
        return self.embedding_function

    def escape_str(self, value: str) -> str:
        """Backslash-escape backslashes and single quotes for a SQL literal."""
        return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)

    def _build_insert_sql(self, transac: Iterable, column_names: Iterable[str]) -> str:
        """Build one multi-row INSERT statement.

        Args:
            transac: rows to insert; every value is stringified, escaped and
                sent as a quoted literal.
            column_names: target column names, in the order of the row values.

        Returns:
            The INSERT statement as a string.
        """
        ks = ",".join(column_names)
        _data = []
        for n in transac:
            n = ",".join([f"'{self.escape_str(str(_n))}'" for _n in n])
            _data.append(f"({n})")
        i_str = f"""
                INSERT INTO TABLE
                    {self.config.database}.{self.config.table}({ks})
                VALUES
                {','.join(_data)}
                """
        return i_str

    def _insert(self, transac: Iterable, column_names: Iterable[str]) -> None:
        """Insert a batch of rows with a single command."""
        _insert_query = self._build_insert_sql(transac, column_names)
        self.client.command(_insert_query)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        batch_size: int = 32,
        ids: Optional[Iterable[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Insert more texts through the embeddings and add to the VectorStore.

        Args:
            texts: Iterable of strings to add to the VectorStore.
            ids: Optional list of ids to associate with the texts. When
                omitted or empty, the SHA-1 hex digest of each text is used.
            batch_size: Batch size of insertion
            metadatas: Optional column data to be inserted

        Returns:
            List of ids from adding the texts into the VectorStore, or an
            empty list when an error occurred (the error is logged).
        """
        # Materialise the inputs once: they are traversed several times below,
        # which would silently exhaust a generator.
        texts = list(texts)
        id_list = list(ids) if ids is not None else []
        if not id_list:
            id_list = [sha1(t.encode("utf-8")).hexdigest() for t in texts]

        colmap_ = self.config.column_map
        metadatas = metadatas or [{} for _ in texts]
        column_names = {
            colmap_["id"]: id_list,
            colmap_["document"]: texts,
            colmap_["embedding"]: self.embedding_function.embed_documents(texts),
            colmap_["metadata"]: map(json.dumps, metadatas),
        }
        keys, values = zip(*column_names.items())
        embedding_pos = keys.index(colmap_["embedding"])

        transac = []
        t = None
        try:
            for v in self.pgbar(
                zip(*values), desc="Inserting data...", total=len(metadatas)
            ):
                assert len(v[embedding_pos]) == self.dim
                transac.append(v)
                if len(transac) == batch_size:
                    # At most one background insert is in flight at a time.
                    if t:
                        t.join()
                    t = Thread(target=self._insert, args=[transac, keys])
                    t.start()
                    transac = []
            if t:
                t.join()
            if transac:
                self._insert(transac, keys)
            return list(id_list)
        except Exception as e:
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []
        finally:
            # Never return while a background insert is still running, even
            # when the last batch was exactly full or an error interrupted the
            # loop.
            if t is not None and t.is_alive():
                t.join()

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[Dict[Any, Any]]] = None,
        config: Optional[ClickhouseSettings] = None,
        text_ids: Optional[Iterable[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> Clickhouse:
        """Create ClickHouse wrapper with existing texts

        Args:
            texts (Iterable[str]): List or tuple of strings to be added
            embedding (Embeddings): Function to extract text embedding
            metadatas (List[dict], optional): metadata to texts. Defaults to None.
            config (ClickHouseSettings, Optional): ClickHouse configuration
            text_ids (Optional[Iterable], optional): IDs for the texts.
                                                     Defaults to None.
            batch_size (int, optional): Batchsize when transmitting data to ClickHouse.
                                        Defaults to 32.
            Other keyword arguments will pass into
                [clickhouse-connect](https://clickhouse.com/docs/en/integrations/python#clickhouse-connect-driver-api)

        Returns:
            ClickHouse Index
        """
        ctx = cls(embedding, config, **kwargs)
        ctx.add_texts(texts, ids=text_ids, batch_size=batch_size, metadatas=metadatas)
        return ctx

    def __repr__(self) -> str:
        """Text representation for ClickHouse Vector Store, prints backends, username
            and schemas. Easy to use with `str(ClickHouse())`

        Note that this queries the server (`DESC` on the table).

        Returns:
            repr: string to show connection info and data schema
        """
        _repr = f"\033[92m\033[1m{self.config.database}.{self.config.table} @ "
        _repr += f"{self.config.host}:{self.config.port}\033[0m\n\n"
        _repr += f"\033[1musername: {self.config.username}\033[0m\n\nTable Schema:\n"
        _repr += "-" * 51 + "\n"
        for r in self.client.query(
            f"DESC {self.config.database}.{self.config.table}"
        ).named_results():
            _repr += (
                f"|\033[94m{r['name']:24s}\033[0m|\033[96m{r['type']:24s}\033[0m|\n"
            )
        _repr += "-" * 51 + "\n"
        return _repr

    def _build_query_sql(
        self, q_emb: List[float], topk: int, where_str: Optional[str] = None
    ) -> str:
        """Build the nearest-neighbour SELECT statement.

        Args:
            q_emb: query embedding.
            topk: maximum number of rows to return.
            where_str: optional raw SQL condition, emitted as a PREWHERE
                clause. It is NOT escaped; never pass untrusted input.

        Returns:
            The query string, ordered by L2 distance to ``q_emb``.
        """
        q_emb_str = ",".join(map(str, q_emb))
        if where_str:
            where_str = f"PREWHERE {where_str}"
        else:
            where_str = ""

        # ClickHouse takes a single `SETTINGS k1=v1, k2=v2` clause.
        settings_str = ""
        if self.config.index_query_params:
            pairs = ", ".join(
                f"{k}={v}" for k, v in self.config.index_query_params.items()
            )
            settings_str = f"SETTINGS {pairs}"

        q_str = f"""
            SELECT {self.config.column_map['document']},
                {self.config.column_map['metadata']}, dist
            FROM {self.config.database}.{self.config.table}
            {where_str}
            ORDER BY L2Distance({self.config.column_map['embedding']}, [{q_emb_str}])
                AS dist {self.dist_order}
            LIMIT {topk} {settings_str}
            """
        return q_str

    def similarity_search(
        self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
    ) -> List[Document]:
        """Perform a similarity search with ClickHouse

        Args:
            query (str): query string
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): where condition string.
                                                 Defaults to None.

            NOTE: Please do not let end-user to fill this and always be aware
                  of SQL injection. When dealing with metadatas, remember to
                  use `{self.metadata_column}.attribute` instead of `attribute`
                  alone. The default name for it is `metadata`.

        Returns:
            List[Document]: List of Documents
        """
        return self.similarity_search_by_vector(
            self.embedding_function.embed_query(query), k, where_str, **kwargs
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Perform a similarity search with ClickHouse by vectors

        Args:
            embedding (List[float]): query embedding
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): where condition string.
                                                 Defaults to None.

            NOTE: Please do not let end-user to fill this and always be aware
                  of SQL injection. When dealing with metadatas, remember to
                  use `{self.metadata_column}.attribute` instead of `attribute`
                  alone. The default name for it is `metadata`.

        Returns:
            List[Document]: List of documents; an empty list when the query
            fails (the error is logged).
        """
        q_str = self._build_query_sql(embedding, k, where_str)
        try:
            return [
                Document(
                    page_content=r[self.config.column_map["document"]],
                    metadata=r[self.config.column_map["metadata"]],
                )
                for r in self.client.query(q_str).named_results()
            ]
        except Exception as e:
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []

    def similarity_search_with_relevance_scores(
        self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Perform a similarity search with ClickHouse

        Args:
            query (str): query string
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): where condition string.
                                                 Defaults to None.

            NOTE: Please do not let end-user to fill this and always be aware
                  of SQL injection. When dealing with metadatas, remember to
                  use `{self.metadata_column}.attribute` instead of `attribute`
                  alone. The default name for it is `metadata`.

        Returns:
            List[Tuple[Document, float]]: List of (Document, dist) pairs, where
            `dist` is the L2 distance computed by the query (smaller means
            closer); an empty list when the query fails (the error is logged).
        """
        q_str = self._build_query_sql(
            self.embedding_function.embed_query(query), k, where_str
        )
        try:
            return [
                (
                    Document(
                        page_content=r[self.config.column_map["document"]],
                        metadata=r[self.config.column_map["metadata"]],
                    ),
                    r["dist"],
                )
                for r in self.client.query(q_str).named_results()
            ]
        except Exception as e:
            logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []

    def drop(self) -> None:
        """
        Helper function: Drop data
        """
        self.client.command(
            f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}"
        )

    @property
    def metadata_column(self) -> str:
        """Name of the table column that stores the metadata."""
        return self.config.column_map["metadata"]