# neuraldocs/src/agentic_rag/tools/custom_tool.py  (2.42 kB)
# Last commit 7e612d6 (Lawrence121): fix: reduce chunk size and limit to fit free tier
import os
from typing import Type, Optional, Any
from pydantic import BaseModel, Field
from markitdown import MarkItDown
from chonkie import SemanticChunker
from qdrant_client import QdrantClient
try:
from crewai.tools import BaseTool
except ImportError:
from langchain.tools import BaseTool
# Argument schema used by DocumentSearchTool (see its ``args_schema``): one
# required string field carrying the search query.  Documented with a comment
# rather than a docstring so the generated schema text stays unchanged.
class SearchInput(BaseModel):
    query: str = Field(
        ...,
        description="Search query",
    )
class DocumentSearchTool(BaseTool):
    """Index a single document in Qdrant and search it for passages.

    On construction the file at ``file_path`` is converted to text with
    MarkItDown, capped at ``MAX_CHARS`` characters, split into chunks with
    chonkie's ``SemanticChunker`` and added to an in-memory Qdrant
    collection.  ``_run`` then returns the best-matching chunks for a query.
    """

    name: str = "DocumentSearchTool"
    description: str = "Search uploaded document for relevant passages."
    args_schema: Type[BaseModel] = SearchInput
    file_path: Optional[str] = None
    client: Optional[Any] = None

    # Tunable settings, kept together so they can be adjusted in one place.
    COLLECTION: str = "neuraldocs_collection"    # Qdrant collection name
    EMBED_MODEL: str = "minishlab/potion-base-8M"  # passed to SemanticChunker
    CHUNK_SIZE: int = 128                        # SemanticChunker chunk_size
    SIMILARITY_T: float = 0.5                    # SemanticChunker threshold
    TOP_K: int = 2                               # max hits requested per query
    SEPARATOR: str = "\n---\n"                   # joins returned passages
    MAX_CHARS: int = 5000                        # cap on extracted text length

    def __init__(self, file_path: str):
        """Create the tool and index ``file_path`` immediately.

        Args:
            file_path: Location of the document to convert and index.

        Raises:
            ValueError: If no text can be extracted from the document.
        """
        # ":memory:" gives this instance its own in-memory Qdrant store.
        super().__init__(file_path=file_path, client=QdrantClient(":memory:"))
        self._build_index()

    def _to_text(self) -> str:
        """Convert the source document to plain text.

        Returns:
            The stripped text, truncated to at most ``MAX_CHARS`` characters.

        Raises:
            ValueError: If the conversion yields no text at all.
        """
        result = MarkItDown().convert(self.file_path)
        # Treat a missing text_content like an empty one so the caller gets
        # the descriptive ValueError below instead of an AttributeError.
        text = (result.text_content or "").strip()
        if not text:
            raise ValueError(f"Could not extract text from '{self.file_path}'.")
        return text[: self.MAX_CHARS]

    def _chunk(self, text: str) -> list:
        """Split ``text`` into chunks and return their non-blank texts.

        Args:
            text: The document text to split.

        Returns:
            A list of chunk strings; whitespace-only chunks are dropped.
        """
        chunker = SemanticChunker(
            embedding_model=self.EMBED_MODEL,
            threshold=self.SIMILARITY_T,
            chunk_size=self.CHUNK_SIZE,
            min_sentences=1,
        )
        return [c.text for c in chunker.chunk(text) if c.text.strip()]

    def _build_index(self) -> None:
        """Extract, chunk and store the document in the Qdrant collection.

        Each chunk is stored with its source file name and its position in
        the chunk list; that position is also used as the point id.
        """
        chunks = self._chunk(self._to_text())
        source_name = os.path.basename(self.file_path)
        chunk_ids = list(range(len(chunks)))
        self.client.add(
            collection_name=self.COLLECTION,
            documents=chunks,
            metadata=[
                {"source": source_name, "chunk_id": chunk_id}
                for chunk_id in chunk_ids
            ],
            ids=chunk_ids,
        )

    def _run(self, query: str) -> str:
        """Return the passages that best match ``query``.

        Args:
            query: Search text.

        Returns:
            Up to ``TOP_K`` passages joined by ``SEPARATOR``, or a fixed
            message when the query is blank or nothing usable is found.
        """
        not_found = "No relevant passages found in the document."
        # A blank query carries nothing to match on; skip the search.
        if not query or not query.strip():
            return not_found
        hits = self.client.query(
            collection_name=self.COLLECTION,
            query_text=query,
            limit=self.TOP_K,
        )
        passages = [h.document for h in hits if h.document and h.document.strip()]
        if not passages:
            return not_found
        return self.SEPARATOR.join(passages)