Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Type, Optional, Any | |
| from pydantic import BaseModel, Field | |
| from markitdown import MarkItDown | |
| from chonkie import SemanticChunker | |
| from qdrant_client import QdrantClient | |
| try: | |
| from crewai.tools import BaseTool | |
| except ImportError: | |
| from langchain.tools import BaseTool | |
class SearchInput(BaseModel):
    """Argument schema for ``DocumentSearchTool``: one required query string."""

    # ``...`` as the default marks the field as required.
    query: str = Field(
        ...,
        description="Search query",
    )
class DocumentSearchTool(BaseTool):
    """Search a single document for passages related to a query.

    On construction the file at ``file_path`` is converted to text with
    MarkItDown, split into chunks with ``SemanticChunker`` and added to an
    in-memory Qdrant collection. ``_run`` then queries that collection and
    returns the best-matching stored passages as one string.
    """

    name: str = "DocumentSearchTool"
    description: str = "Search uploaded document for relevant passages."
    args_schema: Type[BaseModel] = SearchInput
    file_path: Optional[str] = None
    client: Optional[Any] = None

    # Tuning settings read by the methods below.
    COLLECTION: str = "neuraldocs_collection"  # Qdrant collection name
    EMBED_MODEL: str = "minishlab/potion-base-8M"  # model passed to the chunker
    CHUNK_SIZE: int = 128  # ``chunk_size`` passed to the chunker
    SIMILARITY_T: float = 0.5  # ``threshold`` passed to the chunker
    TOP_K: int = 2  # maximum number of hits requested per query
    SEPARATOR: str = "\n---\n"  # placed between passages in the result
    MAX_CHARS: int = 5000  # only this many leading characters are indexed

    def __init__(self, file_path: str):
        """Create the tool and index ``file_path`` immediately.

        Args:
            file_path: Path of the document to convert and index.

        Raises:
            ValueError: If no text or no chunks can be obtained from the file.
        """
        # Each instance gets its own in-memory Qdrant client.
        super().__init__(file_path=file_path, client=QdrantClient(":memory:"))
        self._build_index()

    def _to_text(self) -> str:
        """Convert the document to text, truncated to ``MAX_CHARS`` characters.

        Raises:
            ValueError: If the converter yields no (non-whitespace) text.
        """
        converter = MarkItDown()
        result = converter.convert(self.file_path)
        # Defensive: treat a missing ``text_content`` like empty text so the
        # caller sees the ValueError below instead of an AttributeError.
        text = (result.text_content or "").strip()
        if not text:
            raise ValueError(f"Could not extract text from '{self.file_path}'.")
        return text[: self.MAX_CHARS]

    def _chunk(self, text: str) -> list:
        """Split ``text`` into chunks and return their non-blank texts.

        Returns:
            A list of chunk strings; chunks that are empty or whitespace-only
            are dropped.
        """
        chunker = SemanticChunker(
            embedding_model=self.EMBED_MODEL,
            threshold=self.SIMILARITY_T,
            chunk_size=self.CHUNK_SIZE,
            min_sentences=1,
        )
        return [c.text for c in chunker.chunk(text) if c.text.strip()]

    def _build_index(self) -> None:
        """Chunk the document and add the chunks to the Qdrant collection.

        Each chunk is stored with its position as id and a metadata dict
        holding the source file name and the chunk position.

        Raises:
            ValueError: If the document yields no text or no usable chunks.
        """
        chunks = self._chunk(self._to_text())
        if not chunks:
            # Fail loudly, like ``_to_text`` does, rather than building an
            # index that can never return a passage.
            raise ValueError(
                f"No indexable content found in '{self.file_path}'."
            )
        source_name = os.path.basename(self.file_path)
        chunk_ids = list(range(len(chunks)))
        self.client.add(
            collection_name=self.COLLECTION,
            documents=chunks,
            metadata=[
                {"source": source_name, "chunk_id": i} for i in chunk_ids
            ],
            ids=chunk_ids,
        )

    def _run(self, query: str) -> str:
        """Return up to ``TOP_K`` matching passages joined by ``SEPARATOR``.

        Args:
            query: Free-text search query.

        Returns:
            The matching passages as a single string, or a fixed message
            when the query is blank or nothing usable was found.
        """
        passages = []
        # A blank query carries nothing to match against; skip the search.
        if query and query.strip():
            hits = self.client.query(
                collection_name=self.COLLECTION,
                query_text=query,
                limit=self.TOP_K,
            )
            passages = [
                h.document for h in hits if h.document and h.document.strip()
            ]
        if not passages:
            return "No relevant passages found in the document."
        return self.SEPARATOR.join(passages)