import asyncio
import hashlib
import logging
import operator
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Awaitable, Callable, Optional, Sequence, Union
from urllib.parse import quote

import aiohttp
import requests
from huggingface_hub import snapshot_download
from langchain_classic.retrievers import (
    ContextualCompressionRetriever,
    EnsembleRetriever,
)
from langchain_community.retrievers import BM25Retriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.retrievers import BaseRetriever

from open_webui.config import (
    RAG_EMBEDDING_CONTENT_PREFIX,
    RAG_EMBEDDING_PREFIX_FIELD_NAME,
    RAG_EMBEDDING_QUERY_PREFIX,
    VECTOR_DB,
)
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    ENABLE_FORWARD_USER_INFO_HEADERS,
    OFFLINE_MODE,
)
from open_webui.models.access_grants import AccessGrants
from open_webui.models.chats import Chats
from open_webui.models.files import Files
from open_webui.models.knowledge import Knowledges
from open_webui.models.notes import Notes
from open_webui.models.users import UserModel
from open_webui.retrieval.loaders.youtube import YoutubeLoader
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.retrieval.vector.main import GetResult
from open_webui.retrieval.web.utils import get_web_loader
from open_webui.utils.headers import include_user_info_headers
from open_webui.utils.misc import get_message_list

log = logging.getLogger(__name__)


def is_youtube_url(url: str) -> bool:
    youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
    return re.match(youtube_regex, url) is not None


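# A few illustrative inputs for the matcher above (the video IDs are arbitrary):
#   is_youtube_url("https://www.youtube.com/watch?v=dQw4w9WgXcQ")  # True
#   is_youtube_url("https://youtu.be/dQw4w9WgXcQ")                 # True
#   is_youtube_url("https://example.com/article")                  # False

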
def get_loader(request, url: str):
    if is_youtube_url(url):
        return YoutubeLoader(
            url,
            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
        )
    else:
        return get_web_loader(
            url,
            verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
            requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
            trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
        )


def get_content_from_url(request, url: str) -> tuple[str, list[Document]]:
    loader = get_loader(request, url)
    docs = loader.load()
    content = " ".join([doc.page_content for doc in docs])
    return content, docs


class VectorSearchRetriever(BaseRetriever):
    collection_name: Any
    embedding_function: Any
    top_k: int

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            run_manager: The callback handler to use.

        Returns:
            List of relevant documents.
        """
        # The sync path returns no results; retrieval goes through the async
        # method below.
        return []

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
        result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[embedding],
            limit=self.top_k,
        )

        ids = result.ids[0]
        metadatas = result.metadatas[0]
        documents = result.documents[0]

        results = []
        for idx in range(len(ids)):
            results.append(
                Document(
                    metadata=metadatas[idx],
                    page_content=documents[idx],
                )
            )
        return results


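# Minimal usage sketch for the retriever above (illustrative: the collection
# name is hypothetical, and embedding_function is an async callable such as the
# one returned by get_embedding_function below):
#
#   retriever = VectorSearchRetriever(
#       collection_name="file-abc123",
#       embedding_function=embedding_function,
#       top_k=5,
#   )
#   docs = await retriever.ainvoke("What does the contract say about renewal?")

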
def query_doc(
    collection_name: str,
    query_embedding: list[float],
    k: int,
    user: Optional[UserModel] = None,
):
    try:
        log.debug(f"query_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.search(
            collection_name=collection_name,
            vectors=[query_embedding],
            limit=k,
        )

        if result:
            log.info(f"query_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
        raise e


def get_doc(collection_name: str, user: Optional[UserModel] = None):
    try:
        log.debug(f"get_doc:doc {collection_name}")
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)

        if result:
            log.info(f"get_doc:result {result.ids} {result.metadatas}")

        return result
    except Exception as e:
        log.exception(f"Error getting doc {collection_name}: {e}")
        raise e


def get_enriched_texts(collection_result: GetResult) -> list[str]:
    enriched_texts = []
    for idx, text in enumerate(collection_result.documents[0]):
        metadata = collection_result.metadatas[0][idx]
        metadata_parts = [text]

        if metadata.get("name"):
            filename = metadata["name"]
            filename_tokens = (
                filename.replace("_", " ").replace("-", " ").replace(".", " ")
            )
            # The tokenized filename is repeated, which raises its term
            # frequency in BM25 scoring.
            metadata_parts.append(
                f"Filename: {filename} {filename_tokens} {filename_tokens}"
            )

        if metadata.get("title"):
            metadata_parts.append(f"Title: {metadata['title']}")

        if metadata.get("headings") and isinstance(metadata["headings"], list):
            headings = " > ".join(str(h) for h in metadata["headings"])
            metadata_parts.append(f"Section: {headings}")

        if metadata.get("source"):
            metadata_parts.append(f"Source: {metadata['source']}")

        if metadata.get("snippet"):
            metadata_parts.append(f"Snippet: {metadata['snippet']}")

        enriched_texts.append(" ".join(metadata_parts))

    return enriched_texts


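# Enrichment example (illustrative): given chunk text "Revenue grew 12%" with
# metadata {"name": "q3_report.pdf", "title": "Q3 Report"}, the BM25 text becomes
#   "Revenue grew 12% Filename: q3_report.pdf q3 report pdf q3 report pdf Title: Q3 Report"
# so keyword queries can match filename and title tokens as well as the content.

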
async def query_doc_with_hybrid_search(
    collection_name: str,
    collection_result: GetResult,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
    enable_enriched_texts: bool = False,
) -> dict:
    try:
        if (
            not collection_result
            or not hasattr(collection_result, "documents")
            or not hasattr(collection_result, "metadatas")
        ):
            log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
            return {"documents": [], "metadatas": [], "distances": []}

        if (
            not collection_result.documents
            or len(collection_result.documents) == 0
            or not collection_result.documents[0]
        ):
            log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}")
            return {"documents": [], "metadatas": [], "distances": []}

        log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")

        bm25_texts = (
            get_enriched_texts(collection_result)
            if enable_enriched_texts
            else collection_result.documents[0]
        )

        bm25_retriever = BM25Retriever.from_texts(
            texts=bm25_texts,
            metadatas=collection_result.metadatas[0],
        )
        bm25_retriever.k = k

        vector_search_retriever = VectorSearchRetriever(
            collection_name=collection_name,
            embedding_function=embedding_function,
            top_k=k,
        )

        if hybrid_bm25_weight <= 0:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[vector_search_retriever], weights=[1.0]
            )
        elif hybrid_bm25_weight >= 1:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever], weights=[1.0]
            )
        else:
            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever, vector_search_retriever],
                weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
            )

        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k_reranker,
            reranking_function=reranking_function,
            r_score=r,
        )

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = await compression_retriever.ainvoke(query)

        distances = [d.metadata.get("score") for d in result]
        documents = [d.page_content for d in result]
        metadatas = [d.metadata for d in result]

        # The reranker may return up to k_reranker results; keep only the k
        # highest-scoring items.
        if k < k_reranker:
            sorted_items = sorted(
                zip(distances, documents, metadatas), key=lambda x: x[0], reverse=True
            )
            sorted_items = sorted_items[:k]

            if sorted_items:
                distances, documents, metadatas = map(list, zip(*sorted_items))
            else:
                distances, documents, metadatas = [], [], []

        result = {
            "distances": [distances],
            "documents": [documents],
            "metadatas": [metadatas],
        }

        log.info(
            "query_doc_with_hybrid_search:result "
            + f'{result["metadatas"]} {result["distances"]}'
        )
        return result
    except Exception as e:
        log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
        raise e


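# Weighting sketch for the function above: hybrid_bm25_weight selects the
# retriever mix.
#   0.0 -> vector search only
#   1.0 -> BM25 only
#   0.3 -> EnsembleRetriever over [BM25, vector] with weights [0.3, 0.7]
# A hypothetical call, reusing a previously fetched GetResult:
#
#   result = await query_doc_with_hybrid_search(
#       collection_name="file-abc123",
#       collection_result=VECTOR_DB_CLIENT.get(collection_name="file-abc123"),
#       query="renewal terms",
#       embedding_function=embedding_function,
#       k=5,
#       reranking_function=reranking_function,
#       k_reranker=10,
#       r=0.0,
#       hybrid_bm25_weight=0.5,
#   )

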
def merge_get_results(get_results: list[dict]) -> dict:
    combined_documents = []
    combined_metadatas = []
    combined_ids = []

    for data in get_results:
        combined_documents.extend(data["documents"][0])
        combined_metadatas.extend(data["metadatas"][0])
        combined_ids.extend(data["ids"][0])

    result = {
        "documents": [combined_documents],
        "metadatas": [combined_metadatas],
        "ids": [combined_ids],
    }

    return result


def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
    combined = {}

    for data in query_results:
        if (
            len(data.get("distances", [])) == 0
            or len(data.get("documents", [])) == 0
            or len(data.get("metadatas", [])) == 0
        ):
            continue

        distances = data["distances"][0]
        documents = data["documents"][0]
        metadatas = data["metadatas"][0]

        for distance, document, metadata in zip(distances, documents, metadatas):
            if isinstance(document, str):
                doc_hash = hashlib.sha256(document.encode()).hexdigest()

                if doc_hash not in combined:
                    combined[doc_hash] = (distance, document, metadata)
                    continue

                # Keep the higher-scoring copy of a duplicate document.
                if distance > combined[doc_hash][0]:
                    combined[doc_hash] = (distance, document, metadata)

    combined = list(combined.values())
    combined.sort(key=lambda x: x[0], reverse=True)

    sorted_distances, sorted_documents, sorted_metadatas = (
        zip(*combined[:k]) if combined else ([], [], [])
    )

    return {
        "distances": [list(sorted_distances)],
        "documents": [list(sorted_documents)],
        "metadatas": [list(sorted_metadatas)],
    }


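# Self-contained sketch (illustrative data, not called anywhere) showing how
# merge_and_sort_query_results deduplicates by content hash, keeps the
# higher-scoring copy, and truncates to k:
def _example_merge_and_sort() -> dict:
    a = {
        "distances": [[0.9, 0.4]],
        "documents": [["alpha", "beta"]],
        "metadatas": [[{"id": 1}, {"id": 2}]],
    }
    b = {
        "distances": [[0.7]],
        "documents": [["alpha"]],  # duplicate of the 0.9-scored chunk, dropped
        "metadatas": [[{"id": 3}]],
    }
    # Returns "alpha" (0.9) first, then "beta" (0.4).
    return merge_and_sort_query_results([a, b], k=2)

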
def get_all_items_from_collections(collection_names: list[str]) -> dict:
    results = []

    for collection_name in collection_names:
        if collection_name:
            try:
                result = get_doc(collection_name=collection_name)
                if result is not None:
                    results.append(result.model_dump())
            except Exception as e:
                log.exception(f"Error when querying the collection: {e}")

    return merge_get_results(results)


async def query_collection(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
) -> dict:
    results = []
    error = False

    def process_query_collection(collection_name, query_embedding):
        try:
            if collection_name:
                result = query_doc(
                    collection_name=collection_name,
                    k=k,
                    query_embedding=query_embedding,
                )
                if result is not None:
                    return result.model_dump(), None
            return None, None
        except Exception as e:
            log.exception(f"Error when querying the collection: {e}")
            return None, e

    # Embed all queries in one batch, then fan out one search per
    # (query, collection) pair on the thread pool.
    query_embeddings = await embedding_function(
        queries, prefix=RAG_EMBEDDING_QUERY_PREFIX
    )
    log.debug(
        f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
    )

    with ThreadPoolExecutor() as executor:
        future_results = []
        for query_embedding in query_embeddings:
            for collection_name in collection_names:
                result = executor.submit(
                    process_query_collection, collection_name, query_embedding
                )
                future_results.append(result)
        task_results = [future.result() for future in future_results]

    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        log.warning("All collection queries failed. No results returned.")

    return merge_and_sort_query_results(results, k=k)


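# Call sketch (illustrative names): fan-out is one vector search per
# (query, collection) pair, so two queries over three collections run six
# searches whose results are merged and truncated to k:
#
#   results = await query_collection(
#       collection_names=["file-a", "file-b", "file-c"],
#       queries=["renewal terms", "termination clause"],
#       embedding_function=embedding_function,
#       k=5,
#   )

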
async def query_collection_with_hybrid_search(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
    enable_enriched_texts: bool = False,
) -> dict:
    results = []
    error = False

    # Fetch collection data once per collection so BM25 indexing is not
    # repeated for every query.
    collection_results = {}
    for collection_name in collection_names:
        try:
            log.debug(
                f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
            )
            collection_results[collection_name] = VECTOR_DB_CLIENT.get(
                collection_name=collection_name
            )
        except Exception as e:
            log.exception(f"Failed to fetch collection {collection_name}: {e}")
            collection_results[collection_name] = None

    log.info(
        f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
    )

    async def process_query(collection_name, query):
        try:
            result = await query_doc_with_hybrid_search(
                collection_name=collection_name,
                collection_result=collection_results[collection_name],
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                k_reranker=k_reranker,
                r=r,
                hybrid_bm25_weight=hybrid_bm25_weight,
                enable_enriched_texts=enable_enriched_texts,
            )
            return result, None
        except Exception as e:
            log.exception(f"Error when querying the collection with hybrid_search: {e}")
            return None, e

    # Prepare all (collection, query) pairs, skipping collections that failed
    # to load.
    tasks = [
        (collection_name, query)
        for collection_name in collection_names
        if collection_results[collection_name] is not None
        for query in queries
    ]

    task_results = await asyncio.gather(
        *[process_query(collection_name, query) for collection_name, query in tasks]
    )

    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        raise Exception(
            "Hybrid search failed for all collections. Using non-hybrid search as fallback."
        )

    return merge_and_sort_query_results(results, k=k)


def generate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = "https://api.openai.com/v1",
    key: str = "",
    prefix: Optional[str] = None,
    user: Optional[UserModel] = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        r = requests.post(
            f"{url}/embeddings",
            headers=headers,
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return [elem["embedding"] for elem in data["data"]]
        else:
            raise Exception("Embeddings response is missing the 'data' field")
    except Exception as e:
        log.exception(f"Error generating openai batch embeddings: {e}")
        return None


async def agenerate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = "https://api.openai.com/v1",
    key: str = "",
    prefix: Optional[str] = None,
    user: Optional[UserModel] = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"agenerate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        form_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        async with aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        ) as session:
            async with session.post(
                f"{url}/embeddings",
                headers=headers,
                json=form_data,
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                r.raise_for_status()
                data = await r.json()
                if "data" in data:
                    return [item["embedding"] for item in data["data"]]
                else:
                    raise Exception("Embeddings response is missing the 'data' field")
    except Exception as e:
        log.exception(f"Error generating openai batch embeddings: {e}")
        return None


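# Wire-format sketch for the two OpenAI-compatible helpers above (field names
# per the /embeddings API; the model name is illustrative):
#   POST {url}/embeddings
#   body:     {"input": ["chunk one", "chunk two"], "model": "text-embedding-3-small"}
#   response: {"data": [{"embedding": [...]}, {"embedding": [...]}]}
# Both helpers return the list of vectors, or None after logging any error.

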
def generate_azure_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    version: str = "",
    prefix: Optional[str] = None,
    user: Optional[UserModel] = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"

        # Retry up to 5 times on HTTP 429, honoring the Retry-After header.
        for _ in range(5):
            headers = {
                "Content-Type": "application/json",
                "api-key": key,
            }
            if ENABLE_FORWARD_USER_INFO_HEADERS and user:
                headers = include_user_info_headers(headers, user)

            r = requests.post(
                url,
                headers=headers,
                json=json_data,
            )
            if r.status_code == 429:
                retry = float(r.headers.get("Retry-After", "1"))
                time.sleep(retry)
                continue
            r.raise_for_status()
            data = r.json()
            if "data" in data:
                return [elem["embedding"] for elem in data["data"]]
            else:
                raise Exception("Embeddings response is missing the 'data' field")
        return None
    except Exception as e:
        log.exception(f"Error generating azure openai batch embeddings: {e}")
        return None


async def agenerate_azure_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    version: str = "",
    prefix: Optional[str] = None,
    user: Optional[UserModel] = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"agenerate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
        )
        form_data = {"input": texts}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        full_url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"

        headers = {
            "Content-Type": "application/json",
            "api-key": key,
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        async with aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        ) as session:
            async with session.post(
                full_url,
                headers=headers,
                json=form_data,
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                r.raise_for_status()
                data = await r.json()
                if "data" in data:
                    return [item["embedding"] for item in data["data"]]
                else:
                    raise Exception("Embeddings response is missing the 'data' field")
    except Exception as e:
        log.exception(f"Error generating azure openai batch embeddings: {e}")
        return None


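# Azure endpoint sketch: unlike the OpenAI path, the deployment name is part of
# the URL, the key travels in an "api-key" header, and an api-version query
# parameter is required (the value below is illustrative):
#   POST {url}/openai/deployments/{model}/embeddings?api-version=2024-02-01
# Only the synchronous variant above retries on HTTP 429.

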
def generate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    prefix: Optional[str] = None,
    user: Optional[UserModel] = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        json_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        r = requests.post(
            f"{url}/api/embed",
            headers=headers,
            json=json_data,
        )
        r.raise_for_status()
        data = r.json()

        if "embeddings" in data:
            return data["embeddings"]
        else:
            raise Exception("Embeddings response is missing the 'embeddings' field")
    except Exception as e:
        log.exception(f"Error generating ollama batch embeddings: {e}")
        return None


async def agenerate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = "",
    prefix: Optional[str] = None,
    user: Optional[UserModel] = None,
) -> Optional[list[list[float]]]:
    try:
        log.debug(
            f"agenerate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
        )
        form_data = {"input": texts, "model": model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}",
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        async with aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        ) as session:
            async with session.post(
                f"{url}/api/embed",
                headers=headers,
                json=form_data,
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                r.raise_for_status()
                data = await r.json()
                if "embeddings" in data:
                    return data["embeddings"]
                else:
                    raise Exception(
                        "Embeddings response is missing the 'embeddings' field"
                    )
    except Exception as e:
        log.exception(f"Error generating ollama batch embeddings: {e}")
        return None


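# Ollama wire-format sketch for the two helpers above (per Ollama's /api/embed
# endpoint; the model name is illustrative):
#   POST {url}/api/embed
#   body:     {"input": ["chunk one", "chunk two"], "model": "nomic-embed-text"}
#   response: {"embeddings": [[...], [...]]}

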
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    url,
    key,
    embedding_batch_size,
    azure_api_version=None,
    enable_async=True,
) -> Callable[..., Awaitable[Any]]:
    if embedding_engine == "":
        # Local sentence-transformers model: run the blocking encode() call in
        # a worker thread.
        async def async_embedding_function(query, prefix=None, user=None):
            return await asyncio.to_thread(
                (
                    lambda query, prefix=None: embedding_function.encode(
                        query,
                        batch_size=int(embedding_batch_size),
                        **({"prompt": prefix} if prefix else {}),
                    ).tolist()
                ),
                query,
                prefix,
            )

        return async_embedding_function
    elif embedding_engine in ["ollama", "openai", "azure_openai"]:
        embedding_function = lambda query, prefix=None, user=None: generate_embeddings(
            engine=embedding_engine,
            model=embedding_model,
            text=query,
            prefix=prefix,
            url=url,
            key=key,
            user=user,
            azure_api_version=azure_api_version,
        )

        async def async_embedding_function(query, prefix=None, user=None):
            if isinstance(query, list):
                # Split the input into batches of embedding_batch_size.
                batches = [
                    query[i : i + embedding_batch_size]
                    for i in range(0, len(query), embedding_batch_size)
                ]

                if enable_async:
                    log.debug(
                        f"generate_multiple_async: Processing {len(batches)} batches in parallel"
                    )
                    tasks = [
                        embedding_function(batch, prefix=prefix, user=user)
                        for batch in batches
                    ]
                    batch_results = await asyncio.gather(*tasks)
                else:
                    log.debug(
                        f"generate_multiple_async: Processing {len(batches)} batches sequentially"
                    )
                    batch_results = []
                    for batch in batches:
                        batch_results.append(
                            await embedding_function(batch, prefix=prefix, user=user)
                        )

                # Flatten the per-batch results, skipping failed (None) batches.
                embeddings = []
                for batch_embeddings in batch_results:
                    if isinstance(batch_embeddings, list):
                        embeddings.extend(batch_embeddings)

                log.debug(
                    f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} batches"
                )
                return embeddings
            else:
                return await embedding_function(query, prefix, user)

        return async_embedding_function
    else:
        raise ValueError(f"Unknown embedding engine: {embedding_engine}")


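# Usage sketch (illustrative key and model): the returned coroutine function
# accepts a single string or a list, batching lists by embedding_batch_size:
#
#   embed = get_embedding_function(
#       embedding_engine="openai",
#       embedding_model="text-embedding-3-small",
#       embedding_function=None,  # only used by the local ("") engine
#       url="https://api.openai.com/v1",
#       key="sk-...",
#       embedding_batch_size=32,
#   )
#   vector = await embed("hello world", prefix=RAG_EMBEDDING_QUERY_PREFIX)
#   vectors = await embed(["chunk one", "chunk two", "chunk three"])

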
async def generate_embeddings(
    engine: str,
    model: str,
    text: Union[str, list[str]],
    prefix: Union[str, None] = None,
    **kwargs,
):
    url = kwargs.get("url", "")
    key = kwargs.get("key", "")
    user = kwargs.get("user")

    # If the backend has no dedicated prefix field, prepend the prefix to the
    # text itself.
    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
        if isinstance(text, list):
            text = [f"{prefix}{text_element}" for text_element in text]
        else:
            text = f"{prefix}{text}"

    if engine == "ollama":
        embeddings = await agenerate_ollama_batch_embeddings(
            **{
                "model": model,
                "texts": text if isinstance(text, list) else [text],
                "url": url,
                "key": key,
                "prefix": prefix,
                "user": user,
            }
        )
    elif engine == "openai":
        embeddings = await agenerate_openai_batch_embeddings(
            model, text if isinstance(text, list) else [text], url, key, prefix, user
        )
    elif engine == "azure_openai":
        azure_api_version = kwargs.get("azure_api_version", "")
        embeddings = await agenerate_azure_openai_batch_embeddings(
            model,
            text if isinstance(text, list) else [text],
            url,
            key,
            azure_api_version,
            prefix,
            user,
        )
    else:
        return None

    # The batch helpers return None on failure; avoid indexing into None.
    if embeddings is None:
        return None
    return embeddings[0] if isinstance(text, str) else embeddings


def get_reranking_function(reranking_engine, reranking_model, reranking_function):
    if reranking_function is None:
        return None
    if reranking_engine == "external":
        return lambda query, documents, user=None: reranking_function.predict(
            [(query, doc.page_content) for doc in documents], user=user
        )
    else:
        return lambda query, documents, user=None: reranking_function.predict(
            [(query, doc.page_content) for doc in documents]
        )


async def get_sources_from_items(
    request,
    items,
    queries,
    embedding_function,
    k,
    reranking_function,
    k_reranker,
    r,
    hybrid_bm25_weight,
    hybrid_search,
    full_context=False,
    user: Optional[UserModel] = None,
):
    log.debug(
        f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}"
    )

    extracted_collections = []
    query_results = []

    for item in items:
        query_result = None
        collection_names = []

        if item.get("type") == "text":
            # Raw text content, e.g. temporary chat uploads or pasted text.
            if item.get("context") == "full":
                if item.get("file"):
                    # Use the attached file data directly.
                    query_result = {
                        "documents": [
                            [item.get("file", {}).get("data", {}).get("content")]
                        ],
                        "metadatas": [[item.get("file", {}).get("meta", {})]],
                    }

            if query_result is None:
                if item.get("collection_name"):
                    # The text has an indexed collection; query it below.
                    collection_names.append(item.get("collection_name"))
                elif item.get("file"):
                    # Fall back to the attached file data.
                    query_result = {
                        "documents": [
                            [item.get("file", {}).get("data", {}).get("content")]
                        ],
                        "metadatas": [[item.get("file", {}).get("meta", {})]],
                    }
                else:
                    # Fall back to the item content itself.
                    query_result = {
                        "documents": [[item.get("content")]],
                        "metadatas": [
                            [{"file_id": item.get("id"), "name": item.get("name")}]
                        ],
                    }

        elif item.get("type") == "note":
            # Attached note: include its full markdown content if the user has
            # read access.
            note = Notes.get_note_by_id(item.get("id"))

            if note and (
                user.role == "admin"
                or note.user_id == user.id
                or AccessGrants.has_access(
                    user_id=user.id,
                    resource_type="note",
                    resource_id=note.id,
                    permission="read",
                )
            ):
                query_result = {
                    "documents": [[note.data.get("content", {}).get("md", "")]],
                    "metadatas": [[{"file_id": note.id, "name": note.title}]],
                }

        elif item.get("type") == "chat":
            # Attached chat: flatten its message history into one document.
            chat = Chats.get_chat_by_id(item.get("id"))

            if chat and (user.role == "admin" or chat.user_id == user.id):
                messages_map = chat.chat.get("history", {}).get("messages", {})
                message_id = chat.chat.get("history", {}).get("currentId")

                if messages_map and message_id:
                    message_list = get_message_list(messages_map, message_id)
                    message_history = "\n".join(
                        [
                            f"#### {m.get('role', 'user').capitalize()}\n{m.get('content')}\n"
                            for m in message_list
                        ]
                    )

                    query_result = {
                        "documents": [[message_history]],
                        "metadatas": [[{"file_id": chat.id, "name": chat.title}]],
                    }

        elif item.get("type") == "url":
            content, docs = get_content_from_url(request, item.get("url"))
            if docs:
                query_result = {
                    "documents": [[content]],
                    "metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
                }

        elif item.get("type") == "file":
            if (
                item.get("context") == "full"
                or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
            ):
                # Full-context mode: use the file content directly instead of
                # retrieval.
                if item.get("file", {}).get("data", {}).get("content", ""):
                    query_result = {
                        "documents": [
                            [item.get("file", {}).get("data", {}).get("content", "")]
                        ],
                        "metadatas": [
                            [
                                {
                                    "file_id": item.get("id"),
                                    "name": item.get("name"),
                                    **item.get("file")
                                    .get("data", {})
                                    .get("metadata", {}),
                                }
                            ]
                        ],
                    }
                elif item.get("id"):
                    file_object = Files.get_file_by_id(item.get("id"))
                    if file_object:
                        query_result = {
                            "documents": [[file_object.data.get("content", "")]],
                            "metadatas": [
                                [
                                    {
                                        "file_id": item.get("id"),
                                        "name": file_object.filename,
                                        "source": file_object.filename,
                                    }
                                ]
                            ],
                        }
            else:
                # Retrieval mode: resolve the file's collection name.
                if item.get("legacy"):
                    collection_names.append(f"{item['id']}")
                else:
                    collection_names.append(f"file-{item['id']}")

        elif item.get("type") == "collection":
            # Knowledge base: check read access once, then branch on mode.
            knowledge_base = Knowledges.get_knowledge_by_id(item.get("id"))

            if knowledge_base and (
                user.role == "admin"
                or knowledge_base.user_id == user.id
                or AccessGrants.has_access(
                    user_id=user.id,
                    resource_type="knowledge",
                    resource_id=knowledge_base.id,
                    permission="read",
                )
            ):
                if (
                    item.get("context") == "full"
                    or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
                ):
                    # Full-context mode: concatenate every file in the
                    # knowledge base.
                    files = Knowledges.get_files_by_id(knowledge_base.id)

                    documents = []
                    metadatas = []
                    for file in files:
                        documents.append(file.data.get("content", ""))
                        metadatas.append(
                            {
                                "file_id": file.id,
                                "name": file.filename,
                                "source": file.filename,
                            }
                        )

                    query_result = {
                        "documents": [documents],
                        "metadatas": [metadatas],
                    }
                else:
                    # Retrieval mode: resolve the knowledge base's collection
                    # name(s).
                    if item.get("legacy"):
                        collection_names = item.get("collection_names", [])
                    else:
                        collection_names.append(item["id"])

        elif item.get("docs"):
            # Pre-loaded documents attached to the item.
            query_result = {
                "documents": [[doc.get("content") for doc in item.get("docs")]],
                "metadatas": [[doc.get("metadata") for doc in item.get("docs")]],
            }
        elif item.get("collection_name"):
            collection_names.append(item["collection_name"])
        elif item.get("collection_names"):
            collection_names.extend(item["collection_names"])

        # If no direct content was resolved, query the collected collection
        # names, skipping ones that were already processed.
        if query_result is None and collection_names:
            collection_names = set(collection_names).difference(extracted_collections)
            if not collection_names:
                log.debug(f"skipping {item} as it has already been extracted")
                continue

            try:
                if full_context:
                    query_result = get_all_items_from_collections(collection_names)
                else:
                    query_result = None
                    if hybrid_search:
                        try:
                            query_result = await query_collection_with_hybrid_search(
                                collection_names=collection_names,
                                queries=queries,
                                embedding_function=embedding_function,
                                k=k,
                                reranking_function=reranking_function,
                                k_reranker=k_reranker,
                                r=r,
                                hybrid_bm25_weight=hybrid_bm25_weight,
                                enable_enriched_texts=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
                            )
                        except Exception as e:
                            log.debug(
                                f"Error when using hybrid search: {e}; using non-hybrid search as fallback."
                            )

                    # Fall back to plain vector search when hybrid search is
                    # disabled or failed.
                    if not hybrid_search or query_result is None:
                        query_result = await query_collection(
                            collection_names=collection_names,
                            queries=queries,
                            embedding_function=embedding_function,
                            k=k,
                        )
            except Exception as e:
                log.exception(e)

        extracted_collections.extend(collection_names)

        if query_result:
            if "data" in item:
                del item["data"]
            query_results.append({**query_result, "file": item})

    sources = []
    for query_result in query_results:
        try:
            if "documents" in query_result:
                if "metadatas" in query_result:
                    source = {
                        "source": query_result["file"],
                        "document": query_result["documents"][0],
                        "metadata": query_result["metadatas"][0],
                    }
                    if "distances" in query_result and query_result["distances"]:
                        source["distances"] = query_result["distances"][0]

                    sources.append(source)
        except Exception as e:
            log.exception(e)
    return sources


def get_model_path(model: str, update_model: bool = False):
    # Construct huggingface_hub kwargs with local_files_only to resolve the
    # snapshot path.
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")

    local_files_only = not update_model

    if OFFLINE_MODE:
        local_files_only = True

    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # An existing filesystem path, or a path-like name while in local-only
    # mode, is returned unchanged.
    if os.path.exists(model) or (
        ("\\" in model or model.count("/") > 1) and local_files_only
    ):
        return model
    elif "/" not in model:
        # Bare model names default to the sentence-transformers namespace.
        model = "sentence-transformers" + "/" + model

    snapshot_kwargs["repo_id"] = model

    # Query huggingface_hub for the local snapshot path, downloading or
    # updating it as needed.
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model


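# Resolution sketch: a bare name like "all-MiniLM-L6-v2" becomes
# "sentence-transformers/all-MiniLM-L6-v2" and is fetched via snapshot_download
# into SENTENCE_TRANSFORMERS_HOME, while an existing filesystem path (or a
# path-like name in local-only mode) is returned unchanged.

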
class RerankCompressor(BaseDocumentCompressor):
    embedding_function: Any
    top_n: int
    reranking_function: Any
    r_score: float

    class Config:
        extra = "forbid"
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Compress retrieved documents given the query context.

        Args:
            documents: The retrieved documents.
            query: The query context.
            callbacks: Optional callbacks to run during compression.

        Returns:
            The compressed documents.
        """
        # The sync path returns no results; compression goes through the async
        # method below.
        return []

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        reranking = self.reranking_function is not None

        scores = None
        if reranking:
            scores = await asyncio.to_thread(self.reranking_function, query, documents)
        else:
            # No reranker configured: fall back to cosine similarity between
            # the query and document embeddings.
            from sentence_transformers import util

            query_embedding = await self.embedding_function(
                query, RAG_EMBEDDING_QUERY_PREFIX
            )
            document_embedding = await self.embedding_function(
                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        if scores is not None:
            docs_with_scores = list(
                zip(
                    documents,
                    scores.tolist() if not isinstance(scores, list) else scores,
                )
            )
            if self.r_score:
                # Drop documents scoring below the relevance threshold.
                docs_with_scores = [
                    (d, s) for d, s in docs_with_scores if s >= self.r_score
                ]

            result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
            final_results = []
            for doc, doc_score in result[: self.top_n]:
                metadata = doc.metadata
                metadata["score"] = doc_score
                doc = Document(
                    page_content=doc.page_content,
                    metadata=metadata,
                )
                final_results.append(doc)
            return final_results
        else:
            log.warning(
                "No valid scores found, check your reranking function. Returning original documents."
            )
            return documents


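# Score flow sketch: RerankCompressor writes each document's rerank (or cosine)
# score into metadata["score"]; query_doc_with_hybrid_search later reads that
# field back out as the "distances" column of its result dict. The r_score
# threshold therefore filters on the same scale the configured reranking
# function produces.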