| from opensearchpy import OpenSearch |
| from typing import Optional |
|
|
| from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult |
| from open_webui.config import ( |
| OPENSEARCH_URI, |
| OPENSEARCH_SSL, |
| OPENSEARCH_CERT_VERIFY, |
| OPENSEARCH_USERNAME, |
| OPENSEARCH_PASSWORD, |
| ) |
|
|
|
|
class OpenSearchClient:
    """Vector-store client that keeps each collection in its own OpenSearch index.

    Every physical index is named ``f"{index_prefix}_{index_name}"`` so that all
    indices owned by this client can be matched with a single wildcard
    (see ``reset``).
    """

    def __init__(self):
        self.index_prefix = "open_webui"
        self.client = OpenSearch(
            hosts=[OPENSEARCH_URI],
            use_ssl=OPENSEARCH_SSL,
            verify_certs=OPENSEARCH_CERT_VERIFY,
            http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
        )

    def _get_index_name(self, index_name: str) -> str:
        """Return the prefixed OpenSearch index name for a collection name."""
        return f"{self.index_prefix}_{index_name}"

    def _hits_to_get_result(self, hits) -> GetResult:
        """Collect ids, texts and metadata from an iterable of search hits."""
        ids = []
        documents = []
        metadatas = []

        for hit in hits:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        # NOTE(review): these are flat per-hit lists; confirm they match the
        # field shapes GetResult declares in open_webui.retrieval.vector.main.
        return GetResult(ids=ids, documents=documents, metadatas=metadatas)

    def _result_to_get_result(self, result) -> GetResult:
        """Convert a raw ``search`` response into a GetResult."""
        return self._hits_to_get_result(result["hits"]["hits"])

    def _result_to_search_result(self, result) -> SearchResult:
        """Convert a raw ``search`` response into a SearchResult.

        Each hit's ``_score`` is reported in ``distances``.
        """
        ids = []
        distances = []
        documents = []
        metadatas = []

        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            distances.append(hit["_score"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        return SearchResult(
            ids=ids, distances=distances, documents=documents, metadatas=metadatas
        )

    def _create_index(self, index_name: str, dimension: int):
        """Create the k-NN index for ``index_name``.

        Args:
            index_name: Collection name (unprefixed).
            dimension: Length of the vectors that will be stored.
        """
        body = {
            # knn_vector fields are only indexed for ANN search when k-NN is
            # enabled on the index.
            "settings": {"index": {"knn": True}},
            "mappings": {
                "properties": {
                    "id": {"type": "keyword"},
                    "vector": {
                        "type": "knn_vector",
                        "dimension": dimension,
                        "method": {
                            "name": "hnsw",
                            "space_type": "innerproduct",
                            "engine": "faiss",
                            # HNSW tuning knobs must be nested under "parameters".
                            "parameters": {"ef_construction": 128, "m": 16},
                        },
                    },
                    "text": {"type": "text"},
                    "metadata": {"type": "object"},
                }
            },
        }
        self.client.indices.create(index=self._get_index_name(index_name), body=body)

    def _create_batches(self, items: list[VectorItem], batch_size=100):
        """Yield consecutive slices of ``items`` of at most ``batch_size`` elements."""
        for i in range(0, len(items), batch_size):
            yield items[i : i + batch_size]

    def _ensure_index(self, index_name: str, items: list[VectorItem]):
        """Create the index if missing, sizing the vector field from ``items[0]``."""
        if not self.has_index(index_name):
            self._create_index(index_name, dimension=len(items[0]["vector"]))

    def _item_source(self, item) -> dict:
        """Return the stored document body for one vector item."""
        return {
            "vector": item["vector"],
            "text": item["text"],
            "metadata": item["metadata"],
        }

    def _bulk(self, actions: list[dict]):
        """Send a bulk request and raise if any individual operation failed.

        The bulk API reports per-item failures in the response body (top-level
        ``errors`` flag) instead of failing the HTTP call, so check explicitly.

        Raises:
            RuntimeError: if the response reports at least one failed item.
        """
        response = self.client.bulk(body=actions)
        if response.get("errors"):
            failures = [
                outcome
                for entry in response.get("items", [])
                for outcome in entry.values()
                if outcome.get("error")
            ]
            raise RuntimeError(
                f"OpenSearch bulk request had {len(failures)} failed item(s); "
                f"first: {failures[:1]}"
            )

    def has_index(self, index_name: str) -> bool:
        """Return True if the index for ``index_name`` exists."""
        return self.client.indices.exists(index=self._get_index_name(index_name))

    def has_collection(self, index_name: str) -> bool:
        """Alias of ``has_index`` (collection == index in this client)."""
        return self.has_index(index_name)

    def delete_collection(self, index_name: str):
        """Delete the index backing ``index_name``."""
        self.client.indices.delete(index=self._get_index_name(index_name))

    def delete_colleciton(self, index_name: str):
        """Misspelled legacy name kept for existing callers; use ``delete_collection``."""
        self.delete_collection(index_name)

    def search(
        self, index_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        """Return up to ``limit`` documents ranked by cosine similarity to ``vectors[0]``.

        Only the first query vector is used. Each score is
        ``1.0 + cosine_similarity`` and is returned in ``distances``.
        Returns None when no query vector is given.
        """
        if not vectors:
            return None

        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            "query": {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        # OpenSearch k-NN Painless extension: the vector field
                        # is passed as doc[...], not as a field-name string.
                        "source": "1.0 + cosineSimilarity(params.query_value, doc[params.field])",
                        "params": {"field": "vector", "query_value": vectors[0]},
                    },
                }
            },
        }

        result = self.client.search(index=self._get_index_name(index_name), body=query)
        return self._result_to_search_result(result)

    def get_or_create_index(self, index_name: str, dimension: int):
        """Create the index for ``index_name`` unless it already exists."""
        if not self.has_index(index_name):
            self._create_index(index_name, dimension)

    def get(self, index_name: str) -> Optional[GetResult]:
        """Return every document stored in ``index_name``.

        A plain search returns only the first page of hits (10 by default),
        so the scroll-based ``helpers.scan`` is used to walk the whole index.
        """
        from opensearchpy.helpers import scan

        hits = scan(
            self.client,
            index=self._get_index_name(index_name),
            query={"query": {"match_all": {}}, "_source": ["text", "metadata"]},
        )
        return self._hits_to_get_result(hits)

    def insert(self, index_name: str, items: list[VectorItem]):
        """Index ``items`` (overwriting documents with the same id).

        Creates the index on first use. Does nothing for an empty list.
        """
        if not items:
            return
        self._ensure_index(index_name, items)
        index = self._get_index_name(index_name)

        for batch in self._create_batches(items):
            actions = []
            for item in batch:
                # Bulk bodies alternate an action line and a document line.
                actions.append({"index": {"_index": index, "_id": item["id"]}})
                actions.append(self._item_source(item))
            self._bulk(actions)

    def upsert(self, index_name: str, items: list[VectorItem]):
        """Update documents by id, creating those that do not exist yet.

        Creates the index on first use. Does nothing for an empty list.
        """
        if not items:
            return
        self._ensure_index(index_name, items)
        index = self._get_index_name(index_name)

        for batch in self._create_batches(items):
            actions = []
            for item in batch:
                actions.append({"update": {"_index": index, "_id": item["id"]}})
                actions.append({"doc": self._item_source(item), "doc_as_upsert": True})
            self._bulk(actions)

    def delete(self, index_name: str, ids: list[str]):
        """Delete the documents with the given ids; no-op for an empty list."""
        if not ids:
            return
        index = self._get_index_name(index_name)
        actions = [{"delete": {"_index": index, "_id": doc_id}} for doc_id in ids]
        self._bulk(actions)

    def reset(self):
        """Delete every index whose name starts with this client's prefix."""
        indices = self.client.indices.get(index=f"{self.index_prefix}_*")
        for index in indices:
            self.client.indices.delete(index=index)
|
|