| import mimetypes |
| import os |
| import asyncio |
| import aiohttp |
| import json |
|
|
| from helpers.vector_db import VectorDB |
|
|
# NOTE(review): assigned before the langchain_unstructured import that follows,
# presumably so that library (or the HTTP clients it uses) picks up this
# User-Agent at import time — confirm before moving this line.
os.environ["USER_AGENT"] = "@mixedbread-ai/unstructured"
| from langchain_unstructured import UnstructuredLoader |
|
|
| from urllib.parse import urlparse |
| from typing import Callable, Sequence, List, Optional, Tuple |
| from datetime import datetime |
|
|
| from langchain_community.document_loaders import AsyncHtmlLoader |
| from langchain_community.document_loaders.text import TextLoader |
| from langchain_community.document_loaders.pdf import PyMuPDFLoader |
| from langchain_community.document_transformers import MarkdownifyTransformer |
| from langchain_community.document_loaders.parsers.images import TesseractBlobParser |
|
|
| from langchain_core.documents import Document |
| from langchain.schema import SystemMessage, HumanMessage |
|
|
| from helpers.print_style import PrintStyle |
| from helpers import files, errors |
| from agent import Agent |
|
|
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
|
|
# Minimum similarity score used when DocumentQueryHelper.document_qa searches chunks.
DEFAULT_SEARCH_THRESHOLD: float = 0.5
|
|
|
|
class DocumentQueryStore:
    """
    FAISS Store for document query results.
    Manages documents identified by URI for storage, retrieval, and searching.

    Documents are split into chunks before insertion; every chunk carries the
    normalized ``document_uri`` in its metadata, which is how chunks are found,
    reassembled and deleted.
    """

    # Chunking parameters handed to RecursiveCharacterTextSplitter in add_document.
    DEFAULT_CHUNK_SIZE = 1000
    DEFAULT_CHUNK_OVERLAP = 100

    # NOTE(review): not read anywhere in this class — get() always builds a new
    # instance. Kept for compatibility with any external users.
    _stores: dict[str, "DocumentQueryStore"] = {}

    @staticmethod
    def get(agent: Agent):
        """Create a DocumentQueryStore instance for the specified agent.

        Raises:
            ValueError: If ``agent`` is falsy or has no ``config``.
        """
        if not agent or not agent.config:
            raise ValueError("Agent and agent config must be provided")
        return DocumentQueryStore(agent)

    def __init__(
        self,
        agent: Agent,
    ):
        """Initialize a DocumentQueryStore instance.

        The vector database is not created here; add_document creates it lazily
        via init_vector_db, so read-only methods return empty results until then.
        """
        self.agent = agent
        self.vector_db: VectorDB | None = None

    @staticmethod
    def normalize_uri(uri: str) -> str:
        """
        Normalize a document URI to ensure consistent lookup.

        - Surrounding whitespace is stripped.
        - URIs without a scheme, or with ``file:``/``file://``, become
          ``file://<path>`` with the path passed through ``files.fix_dev_path``.
        - An ``http://`` scheme prefix is upgraded to ``https://``; any other
          occurrence of ``http://`` (e.g. inside a query string) is left intact.

        Args:
            uri: The URI to normalize

        Returns:
            Normalized URI
        """
        normalized = uri.strip()

        parsed = urlparse(normalized)
        # urlparse lower-cases the scheme; a bare path has no scheme at all.
        scheme = parsed.scheme or "file"

        if scheme == "file":
            path = files.fix_dev_path(
                normalized.removeprefix("file://").removeprefix("file:")
            )
            normalized = f"file://{path}"
        elif scheme == "http":
            # Only touch the scheme prefix (case-insensitively), never the rest of the URL.
            prefix = "http://"
            if normalized[: len(prefix)].lower() == prefix:
                normalized = "https://" + normalized[len(prefix):]

        return normalized

    def init_vector_db(self):
        """Create the backing VectorDB for this store's agent (with ``cache=True``)."""
        return VectorDB(self.agent, cache=True)

    async def add_document(
        self, text: str, document_uri: str, metadata: dict | None = None
    ) -> tuple[bool, list[str]]:
        """
        Add a document to the store with the given URI.

        Any previously stored chunks for the same (normalized) URI are deleted
        first, so this acts as an upsert.

        Args:
            text: The document text content
            document_uri: The URI that uniquely identifies this document
            metadata: Optional metadata for the document; it is copied, the
                caller's dict is never modified

        Returns:
            ``(True, ids)`` with the inserted chunk ids on success, or
            ``(False, [])`` if no chunks were produced or insertion failed.
        """
        document_uri = self.normalize_uri(document_uri)

        # Replace any earlier version of this document.
        await self.delete_document(document_uri)

        # Copy so the caller's metadata dict is not mutated.
        doc_metadata = dict(metadata or {})
        doc_metadata["document_uri"] = document_uri
        doc_metadata["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.DEFAULT_CHUNK_SIZE, chunk_overlap=self.DEFAULT_CHUNK_OVERLAP
        )
        chunks = text_splitter.split_text(text)

        docs = []
        for i, chunk in enumerate(chunks):
            chunk_metadata = doc_metadata.copy()
            chunk_metadata["chunk_index"] = i
            chunk_metadata["total_chunks"] = len(chunks)
            docs.append(Document(page_content=chunk, metadata=chunk_metadata))

        if not docs:
            PrintStyle.error(f"No chunks created for document: {document_uri}")
            return False, []

        try:
            if not self.vector_db:
                self.vector_db = self.init_vector_db()

            ids = await self.vector_db.insert_documents(docs)
            PrintStyle.standard(
                f"Added document '{document_uri}' with {len(docs)} chunks"
            )
            return True, ids
        except Exception as e:
            err_text = errors.format_error(e)
            PrintStyle.error(f"Error adding document '{document_uri}': {err_text}")
            return False, []

    async def get_document(self, document_uri: str) -> Optional[Document]:
        """
        Retrieve a document by its URI.

        Chunks are ordered by ``chunk_index`` and joined with newlines. Because
        chunks were split with DEFAULT_CHUNK_OVERLAP, the reassembled text may
        repeat some text at chunk boundaries.

        Args:
            document_uri: The URI of the document to retrieve

        Returns:
            The complete document if found, None otherwise
        """
        if not self.vector_db:
            return None

        document_uri = self.normalize_uri(document_uri)

        docs = await self._get_document_chunks(document_uri)
        if not docs:
            PrintStyle.error(f"Document not found: {document_uri}")
            return None

        chunks = sorted(docs, key=lambda x: x.metadata.get("chunk_index", 0))
        full_content = "\n".join(chunk.page_content for chunk in chunks)

        # Document-level metadata: take the first chunk's and drop per-chunk keys.
        metadata = chunks[0].metadata.copy()
        metadata.pop("chunk_index", None)
        metadata.pop("total_chunks", None)

        return Document(page_content=full_content, metadata=metadata)

    async def _get_document_chunks(self, document_uri: str) -> List[Document]:
        """
        Get all chunks for a document.

        Args:
            document_uri: The URI of the document

        Returns:
            List of document chunks (empty if the vector DB is not initialized)
        """
        if not self.vector_db:
            return []

        document_uri = self.normalize_uri(document_uri)

        chunks = await self.vector_db.search_by_metadata(
            filter=f"document_uri == '{document_uri}'",
        )

        PrintStyle.standard(f"Found {len(chunks)} chunks for document: {document_uri}")
        return chunks

    async def document_exists(self, document_uri: str) -> bool:
        """
        Check if a document exists in the store.

        Args:
            document_uri: The URI of the document to check

        Returns:
            True if the document exists, False otherwise
        """
        if not self.vector_db:
            return False

        document_uri = self.normalize_uri(document_uri)

        chunks = await self._get_document_chunks(document_uri)
        return len(chunks) > 0

    async def delete_document(self, document_uri: str) -> bool:
        """
        Delete a document from the store.

        Args:
            document_uri: The URI of the document to delete

        Returns:
            True if deleted, False if not found (or the vector DB is not initialized)
        """
        if not self.vector_db:
            return False

        document_uri = self.normalize_uri(document_uri)

        chunks = await self.vector_db.search_by_metadata(
            filter=f"document_uri == '{document_uri}'",
        )
        if not chunks:
            return False

        ids_to_delete = [chunk.metadata["id"] for chunk in chunks]

        if ids_to_delete:
            dels = await self.vector_db.delete_documents_by_ids(ids_to_delete)
            PrintStyle.standard(
                f"Deleted document '{document_uri}' with {len(dels)} chunks"
            )
            return True

        return False

    async def search_documents(
        self, query: str, limit: int = 10, threshold: float = 0.5, filter: str = ""
    ) -> List[Document]:
        """
        Search for documents similar to the query across the entire store.

        Args:
            query: The search query string
            limit: Maximum number of results to return
            threshold: Minimum similarity score threshold (0-1)
            filter: Optional metadata filter expression passed through to the
                vector DB (e.g. ``"document_uri == '...'"``); empty means no filter

        Returns:
            List of matching documents; empty on empty query, uninitialized
            vector DB, or search error (errors are logged, not raised)
        """
        if not self.vector_db:
            return []

        if not query:
            return []

        try:
            results = await self.vector_db.search_by_similarity_threshold(
                query=query, limit=limit, threshold=threshold, filter=filter
            )

            PrintStyle.standard(f"Search '{query}' returned {len(results)} results")
            return results
        except Exception as e:
            PrintStyle.error(f"Error searching documents: {str(e)}")
            return []

    async def search_document(
        self, document_uri: str, query: str, limit: int = 10, threshold: float = 0.5
    ) -> List[Document]:
        """
        Search for content within a specific document.

        Args:
            document_uri: The URI of the document to search within; normalized
                the same way as when the document was stored
            query: The search query string
            limit: Maximum number of results to return
            threshold: Minimum similarity score threshold (0-1)

        Returns:
            List of matching document chunks
        """
        # Stored chunks carry the normalized URI, so the filter must use it too.
        document_uri = self.normalize_uri(document_uri)
        return await self.search_documents(
            query, limit, threshold, f"document_uri == '{document_uri}'"
        )

    async def list_documents(self) -> List[str]:
        """
        Get a list of all document URIs in the store.

        Returns:
            Sorted list of distinct document URIs
        """
        if not self.vector_db:
            return []

        uris = set()
        for doc in self.vector_db.db.get_all_docs().values():
            if isinstance(doc.metadata, dict):
                uri = doc.metadata.get("document_uri")
                if uri:
                    uris.add(uri)

        return sorted(uris)
|
|
|
|
class DocumentQueryHelper:
    """Fetch, index and answer questions about documents using DocumentQueryStore.

    Supports ``file``, ``http`` and ``https`` URIs. Content is extracted by a
    handler chosen from the document's mimetype: image, HTML, text/JSON, PDF,
    or a generic Unstructured fallback.
    """

    def __init__(
        self, agent: Agent, progress_callback: Callable[[str], None] | None = None
    ):
        """
        Args:
            agent: Agent used for prompts, model calls and intervention checks.
            progress_callback: Optional callable receiving human-readable progress
                messages; defaults to a no-op.
        """
        self.agent = agent
        self.store = DocumentQueryStore.get(agent)
        self.progress_callback = progress_callback or (lambda x: None)
        # Serializes store.add_document calls; document_qa loads documents concurrently.
        self.store_lock = asyncio.Lock()

    async def document_qa(
        self, document_uris: List[str], questions: Sequence[str]
    ) -> Tuple[bool, str]:
        """Answer ``questions`` using relevant chunks from ``document_uris``.

        All documents are loaded and indexed first. Each question is then
        rewritten into a search query by the utility model, and matching chunks
        from the given documents are collected (deduplicated by chunk id). The
        chat model answers all questions against the combined chunks.

        Returns:
            ``(True, answer)`` on success, or ``(False, message)`` when no
            relevant content was found.
        """
        self.progress_callback(
            f"Starting Q&A process for {len(document_uris)} documents"
        )
        await self.agent.handle_intervention()

        # Load (and index) all documents concurrently.
        await asyncio.gather(
            *[self.document_get_content(uri, True) for uri in document_uris]
        )
        await self.agent.handle_intervention()

        # The document filter does not depend on the question; build it once.
        normalized_uris = [self.store.normalize_uri(uri) for uri in document_uris]
        doc_filter = " or ".join(
            f"document_uri == '{uri}'" for uri in normalized_uris
        )

        selected_chunks = {}
        for question in questions:
            self.progress_callback(f"Optimizing query: {question}")
            await self.agent.handle_intervention()
            human_content = f'Search Query: "{question}"'
            system_content = self.agent.parse_prompt(
                "fw.document_query.optmimize_query.md"
            )

            optimized_query = (
                await self.agent.call_utility_model(
                    system=system_content, message=human_content
                )
            ).strip()

            await self.agent.handle_intervention()
            self.progress_callback(f"Searching documents with query: {optimized_query}")

            chunks = await self.store.search_documents(
                query=optimized_query,
                limit=100,
                threshold=DEFAULT_SEARCH_THRESHOLD,
                filter=doc_filter,
            )

            self.progress_callback(f"Found {len(chunks)} chunks")

            # A chunk matched by several questions is kept once, keyed by its id.
            for chunk in chunks:
                selected_chunks[chunk.metadata["id"]] = chunk

        if not selected_chunks:
            self.progress_callback("No relevant content found in the documents")
            content = f"!!! No content found for documents: {json.dumps(document_uris)} matching queries: {json.dumps(questions)}"
            return False, content

        self.progress_callback(
            f"Processing {len(questions)} questions in context of {len(selected_chunks)} chunks"
        )
        await self.agent.handle_intervention()

        questions_str = "\n".join([f" * {question}" for question in questions])
        content = "\n\n----\n\n".join(
            [chunk.page_content for chunk in selected_chunks.values()]
        )

        qa_system_message = self.agent.parse_prompt(
            "fw.document_query.system_prompt.md"
        )
        qa_user_message = f"# Document:\n{content}\n\n# Queries:\n{questions_str}"

        ai_response, _reasoning = await self.agent.call_chat_model(
            messages=[
                SystemMessage(content=qa_system_message),
                HumanMessage(content=qa_user_message),
            ],
            explicit_caching=False,
        )

        self.progress_callback("Q&A process completed")

        return True, str(ai_response)

    async def _fetch_remote_mimetype(self, document_uri: str) -> str:
        """Return the bare mimetype an http(s) server reports for ``document_uri``.

        Sends up to 3 HEAD requests (2 s timeout each, redirects followed). A
        status above 399 counts as a failed attempt and is retried.

        Raises:
            ValueError: If every attempt fails, or Content-Length exceeds 50MB.
        """
        max_attempts = 3
        headers = None
        last_error = ""
        for attempt in range(max_attempts):
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.head(
                        document_uri,
                        timeout=aiohttp.ClientTimeout(total=2.0),
                        allow_redirects=True,
                    ) as response:
                        if response.status > 399:
                            raise Exception(response.status)
                        # Copy: the response is released when this block exits.
                        headers = response.headers.copy()
                break
            except Exception as e:
                last_error = str(e)
                if attempt + 1 < max_attempts:
                    await asyncio.sleep(1)
            await self.agent.handle_intervention()

        if headers is None:
            raise ValueError(
                f"DocumentQueryHelper::document_get_content: Document fetch error: {document_uri} ({last_error})"
            )

        if "content-length" in headers:
            content_length = float(headers["content-length"]) / 1024 / 1024
            if content_length > 50.0:
                raise ValueError(
                    f"Document content length exceeds max. 50MB: {content_length} MB ({document_uri})"
                )

        # Drop parameters such as "; charset=utf-8" / ";charset=UTF-8"; mimetypes
        # are case-insensitive.
        content_type = headers.get("content-type", "")
        bare_type = content_type.split(";", 1)[0].strip().lower()
        return bare_type or "application/octet-stream"

    async def document_get_content(
        self, document_uri: str, add_to_db: bool = False
    ) -> str:
        """Return the text content of ``document_uri``, loading it if not yet stored.

        The mimetype is guessed from the URI; for http(s) URIs with no usable
        extension it is obtained via a HEAD request. If the document is already
        in the store, its stored content is returned; otherwise it is extracted
        by the matching handler and, when ``add_to_db`` is true, indexed.

        Raises:
            ValueError: For invalid paths, compressed or unsupported documents,
                fetch failures, or indexing failures.
        """
        self.progress_callback("Fetching document content")
        await self.agent.handle_intervention()
        url = urlparse(document_uri)
        scheme = url.scheme or "file"
        mimetype, encoding = mimetypes.guess_type(document_uri)
        mimetype = mimetype or "application/octet-stream"

        # The extension gave no hint; ask the remote server.
        if mimetype == "application/octet-stream" and url.scheme in ["http", "https"]:
            mimetype = await self._fetch_remote_mimetype(document_uri)

        if scheme == "file":
            try:
                document_uri = files.fix_dev_path(url.path)
            except Exception as e:
                raise ValueError(f"Invalid document path '{url.path}'") from e

        if encoding:
            raise ValueError(
                f"Compressed documents are unsupported '{encoding}' ({document_uri})"
            )

        if mimetype == "application/octet-stream":
            raise ValueError(
                f"Unsupported document mimetype '{mimetype}' ({document_uri})"
            )

        document_uri_norm = self.store.normalize_uri(document_uri)

        await self.agent.handle_intervention()
        exists = await self.store.document_exists(document_uri_norm)
        document_content = ""
        if not exists:
            await self.agent.handle_intervention()
            if mimetype.startswith("image/"):
                document_content = self.handle_image_document(document_uri, scheme)
            elif mimetype == "text/html":
                document_content = self.handle_html_document(document_uri, scheme)
            elif mimetype.startswith("text/") or mimetype == "application/json":
                document_content = self.handle_text_document(document_uri, scheme)
            elif mimetype == "application/pdf":
                document_content = self.handle_pdf_document(document_uri, scheme)
            else:
                document_content = self.handle_unstructured_document(
                    document_uri, scheme
                )
            if add_to_db:
                self.progress_callback("Indexing document")
                await self.agent.handle_intervention()
                async with self.store_lock:
                    success, ids = await self.store.add_document(
                        document_content, document_uri_norm
                    )
                if not success:
                    self.progress_callback("Failed to index document")
                    raise ValueError(
                        f"DocumentQueryHelper::document_get_content: Failed to index document: {document_uri_norm}"
                    )
                self.progress_callback(f"Indexed {len(ids)} chunks")
        else:
            await self.agent.handle_intervention()
            doc = await self.store.get_document(document_uri_norm)
            if doc:
                document_content = doc.page_content
            else:
                raise ValueError(
                    f"DocumentQueryHelper::document_get_content: Document not found: {document_uri_norm}"
                )
        return document_content

    def handle_image_document(self, document: str, scheme: str) -> str:
        """Extract text from an image by delegating to the Unstructured handler."""
        return self.handle_unstructured_document(document, scheme)

    def handle_html_document(self, document: str, scheme: str) -> str:
        """Load an HTML document (http(s) or local UTF-8 file) and convert it to Markdown."""
        if scheme in ["http", "https"]:
            loader = AsyncHtmlLoader(web_path=document)
            parts: list[Document] = loader.load()
        elif scheme == "file":
            file_content_bytes = files.read_file_bin(document)
            file_content = file_content_bytes.decode("utf-8")
            parts = [Document(page_content=file_content, metadata={"source": document})]
        else:
            raise ValueError(f"Unsupported scheme: {scheme}")

        return "\n".join(
            [
                element.page_content
                for element in MarkdownifyTransformer().transform_documents(parts)
            ]
        )

    def handle_text_document(self, document: str, scheme: str) -> str:
        """Load a plain-text/JSON document (http(s) or local UTF-8 file) as-is."""
        if scheme in ["http", "https"]:
            loader = AsyncHtmlLoader(web_path=document)
            elements: list[Document] = loader.load()
        elif scheme == "file":
            file_content_bytes = files.read_file_bin(document)
            file_content = file_content_bytes.decode("utf-8")
            elements = [
                Document(page_content=file_content, metadata={"source": document})
            ]
        else:
            raise ValueError(f"Unsupported scheme: {scheme}")

        return "\n".join([element.page_content for element in elements])

    @staticmethod
    def _write_temp_file(data: bytes, suffix: str) -> str:
        """Write ``data`` to a new temporary file and return its path; the caller deletes it."""
        import tempfile

        fd, path = tempfile.mkstemp(suffix=suffix)
        try:
            with os.fdopen(fd, "wb") as handle:
                handle.write(data)
        except BaseException:
            # Don't leave a partial file behind if the write fails.
            os.unlink(path)
            raise
        return path

    def handle_pdf_document(self, document: str, scheme: str) -> str:
        """Extract text from a PDF given as a local file or an http(s) URL.

        PyMuPDF is tried first (tables as Markdown, embedded images OCR'd via
        Tesseract). If that yields no non-whitespace text, pages are rendered
        with pdf2image and OCR'd with pytesseract. The temporary copy of the
        PDF is always removed.

        Raises:
            ValueError: For unsupported schemes or a non-200 download response.
        """
        if scheme == "file":
            pdf_bytes = files.read_file_bin(document)
        elif scheme in ["http", "https"]:
            import requests

            # Download before creating any temp file so a failed request leaves
            # nothing behind on disk.
            response = requests.get(document, timeout=10.0)
            if response.status_code != 200:
                raise ValueError(
                    f"DocumentQueryHelper::handle_pdf_document: Failed to download PDF from {document}: {response.status_code}"
                )
            pdf_bytes = response.content
        else:
            raise ValueError(f"Unsupported scheme: {scheme}")

        temp_file_path = self._write_temp_file(pdf_bytes, ".pdf")
        try:
            if not os.path.exists(temp_file_path):
                raise ValueError(
                    f"DocumentQueryHelper::handle_pdf_document: Temporary file not found: {temp_file_path}"
                )

            try:
                loader = PyMuPDFLoader(
                    temp_file_path,
                    mode="single",
                    extract_tables="markdown",
                    extract_images=True,
                    images_inner_format="text",
                    images_parser=TesseractBlobParser(),
                    pages_delimiter="\n",
                )
                elements: list[Document] = loader.load()
                contents = "\n".join([element.page_content for element in elements])
            except Exception as e:
                PrintStyle.error(
                    f"DocumentQueryHelper::handle_pdf_document: Error loading with PyMuPDF: {e}"
                )
                contents = ""

            # Whitespace-only output (e.g. scanned pages) also triggers the OCR fallback.
            if not contents.strip():
                import pdf2image
                import pytesseract

                PrintStyle.debug(
                    f"DocumentQueryHelper::handle_pdf_document: FALLBACK Converting PDF to images: {temp_file_path}"
                )

                pages = pdf2image.convert_from_path(temp_file_path)
                contents = "".join(
                    pytesseract.image_to_string(page) + "\n\n" for page in pages
                )

            return contents
        finally:
            if os.path.exists(temp_file_path):
                os.unlink(temp_file_path)

    def handle_unstructured_document(self, document: str, scheme: str) -> str:
        """Extract text with Unstructured (local ``hi_res`` partitioning) from a URL or file."""
        elements: list[Document] = []
        if scheme in ["http", "https"]:
            loader = UnstructuredLoader(
                web_url=document,
                mode="single",
                partition_via_api=False,
                strategy="hi_res",
            )
            elements = loader.load()
        elif scheme == "file":
            file_content_bytes = files.read_file_bin(document)

            # Keep the original extension on the temp copy; presumably Unstructured
            # uses it to pick a partitioner — confirm.
            _, ext = os.path.splitext(document)
            temp_file_path = self._write_temp_file(file_content_bytes, ext)
            try:
                loader = UnstructuredLoader(
                    file_path=temp_file_path,
                    mode="single",
                    partition_via_api=False,
                    strategy="hi_res",
                )
                elements = loader.load()
            finally:
                os.unlink(temp_file_path)
        else:
            raise ValueError(f"Unsupported scheme: {scheme}")

        return "\n".join([element.page_content for element in elements])
|
|