| """Search tool for agent.""" |
|
|
| from langchain_core.tools import tool |
| from src.rag.retriever import retriever |
| from sqlalchemy.ext.asyncio import AsyncSession |
| from src.middlewares.logging import get_logger |
|
|
# Module-level logger shared by the tool in this file; call sites pass extra
# key/value fields (e.g. ``error=...``) alongside the message.
_LOGGER_NAME = "search_tool"
logger = get_logger(_LOGGER_NAME)
|
|
|
|
@tool
async def search_documents(
    query: str,
    user_id: str,
    db: AsyncSession,
    num_results: int = 5
) -> str:
    """Search user's uploaded documents for relevant information.

    Args:
        query: The search query or question
        user_id: The user's ID
        db: Database session
        num_results: Number of results to return (default: 5)

    Returns:
        Relevant document excerpts with source and page information
    """
    # NOTE: ``@tool`` uses the docstring above as the tool description shown
    # to the model, so implementation notes are kept in ``#`` comments.
    #
    # NOTE(review): ``user_id`` and ``db`` are part of the model-facing tool
    # schema as written. A model cannot produce an ``AsyncSession``, and
    # letting it choose ``user_id`` could expose another user's documents.
    # Consider ``Annotated[..., InjectedToolArg]`` for both once the calling
    # agent is confirmed to inject them itself.
    try:
        results = await retriever.retrieve(query, user_id, db, num_results)

        if not results:
            return "No relevant information found in the documents."

        formatted_results = []
        for result in results:
            # Tolerate a missing or ``None`` metadata entry instead of letting
            # one malformed hit fail the whole search via the handler below.
            metadata = result.get("metadata") or {}
            filename = metadata.get("filename", "Unknown")
            page = metadata.get("page_label")
            # Explicit "no value" check so a numeric page label of 0 is still
            # reported; only ``None`` / "" omit the page suffix.
            has_page = page is not None and page != ""
            source_label = f"{filename}, p.{page}" if has_page else filename
            formatted_results.append(f"[Source: {source_label}]\n{result['content']}\n")

        # Each entry already ends with "\n", so joining on "\n" leaves a blank
        # line between excerpts.
        return "\n".join(formatted_results)

    except Exception as e:
        # Tool boundary: any failure is logged and reported to the agent as
        # text rather than raised into the agent loop.
        logger.error("Search failed", error=str(e))
        return "Sorry, I encountered an error while searching the documents."
|
|