| from smolagents import Tool | |
| from typing import Dict, Any, List | |
| from retriever import retrieve_query | |
| import io | |
| from contextlib import redirect_stdout | |
class DocumentRetrievalTool(Tool):
    """
    A tool for performing semantic search across a document knowledge base stored in ChromaDB.

    It accepts a query, delegates the lookup to ``retrieve_query`` together with
    the collection supplied at construction time, and formats the returned
    chunks into a single text block for the agent.
    """

    name = "document_retrieval"
    description = (
        "Use this tool to search for specific, relevant information within the loaded document set. "
        "Always use this when the user's question relates to the content of the documents."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query, which must be a specific question or topic related to the document's content."
        }
    }
    output_type = "string"

    def __init__(self, collection: Any, top_k: int = 5):
        """
        Store the retrieval settings and initialise the base ``Tool``.

        Args:
            collection: Collection handle passed through unchanged to
                ``retrieve_query`` (presumably a ChromaDB collection --
                confirm against the caller).
            top_k: Number of results requested from ``retrieve_query`` per
                query. Defaults to 5, the value that was previously hard-coded.
        """
        self.collection = collection
        self.top_k = top_k
        super().__init__()

    def forward(self, query: str) -> str:
        """
        Run the query through ``retrieve_query`` and format the hits for the agent.

        Args:
            query: The search text forwarded to ``retrieve_query``.

        Returns:
            One ``Source (<source>): <text>`` entry per result, separated by
            ``---`` lines and prefixed with a header, or an explicit
            "nothing found" message when no results come back.

        Raises:
            KeyError: If a result lacks the ``source`` or ``text`` key.
        """
        retrieved_results: List[Dict[str, Any]] = retrieve_query(
            query, self.collection, top_k=self.top_k
        )

        # Without this guard the agent would receive the header followed by an
        # empty body, which reads as if context had been retrieved.
        if not retrieved_results:
            return "No relevant context was found in the documents for this query."

        context = "\n---\n".join(
            f"Source ({result['source']}): {result['text']}"
            for result in retrieved_results
        )
        return f"Retrieved context from document:\n\n{context}\n\n"
class DocumentSummarizationTool(Tool):
    """
    A tool that summarizes a given document text using a pre-trained summarization model.

    The model is supplied by the caller as a callable pipeline; this class only
    forwards the text to it and extracts the summary string from the result.
    """

    name = "document_summarization"
    description = (
        "Use this tool to summarize a loaded document set."
    )
    inputs = {
        "document_text": {
            "type": "string",
            "description": "The document text to be summarized."
        }
    }
    output_type = "string"

    def __init__(self, summarization_pipeline: Any):
        """
        Store the summarization callable and initialise the base ``Tool``.

        Args:
            summarization_pipeline: Callable invoked with the document text.
                Its result is indexed as ``result[0]["summary_text"]``
                (presumably a Hugging Face summarization pipeline -- confirm
                against the caller).
        """
        self.summarization_pipeline = summarization_pipeline
        super().__init__()

    def forward(self, document_text: str) -> str:
        """
        Summarize ``document_text`` with the configured pipeline.

        Args:
            document_text: The text to summarize.

        Returns:
            The ``summary_text`` field of the first pipeline result, or an
            explanatory message when there is no text to summarize or the
            pipeline returns nothing.
        """
        # Skip the model call entirely when there is nothing to summarize.
        if not document_text or not document_text.strip():
            return "No document text was provided to summarize."

        summary_output = self.summarization_pipeline(document_text)

        # An empty result would otherwise surface as an opaque IndexError.
        if not summary_output:
            return "The summarization pipeline returned no summary."

        return summary_output[0]["summary_text"]
class CodeExecutionTool(Tool):
    """
    A tool for executing short Python snippets (e.g., for calculations,
    data manipulation, or simple logic puzzles that the LLM might struggle with).

    The code runs via ``exec`` with a reduced set of builtins and its printed
    output is captured and returned.

    SECURITY NOTE: restricting ``__builtins__`` is NOT a real sandbox. It can be
    bypassed through object introspection, and nothing here limits CPU time or
    memory. Do not expose this tool to untrusted input without process-level
    isolation.
    """

    name: str = "python_interpreter"
    description: str = (
        "A Python interpreter used to run short snippets of code for calculations "
        "or logic. The input must be valid Python code. The tool captures and "
        "returns any printed output."
    )
    inputs = {
        "code": {
            "type": "string",
            "description": "The Python code to execute. Must be a single code block."
        }
    }
    output_type = "string"

    def forward(self, code: str) -> str:
        """
        Execute the given Python code with a reduced set of builtins.

        Args:
            code: Python source to run. Only the builtins listed in
                ``safe_builtins`` below are available to it.

        Returns:
            ``Code Output:`` followed by the stripped captured stdout, a hint
            when nothing was printed, or a ``Code Execution Error:`` message
            (followed by any output printed before the failure) when the code
            raises an ``Exception``.
        """
        # Built fresh on every call: the executed code can reach and mutate
        # this dict through ``__builtins__``, so it must not be shared.
        safe_builtins = {
            'print': print,
            'len': len,
            'sum': sum,
            'min': min,
            'max': max,
            'range': range,
            'str': str,
            'int': int,
            'float': float,
        }
        safe_globals = {'__builtins__': safe_builtins}
        output_buffer = io.StringIO()

        try:
            # NOTE(security): exec on model-generated code; see class docstring.
            with redirect_stdout(output_buffer):
                exec(code, safe_globals)
        except Exception as e:
            error_message = f"Code Execution Error: {type(e).__name__}: {str(e)}"
            # Keep whatever was printed before the failure; it is often the
            # most useful clue for the agent when fixing its code.
            partial_output = output_buffer.getvalue().strip()
            if partial_output:
                return f"{error_message}\nOutput before error:\n{partial_output}"
            return error_message

        output = output_buffer.getvalue()
        if output:
            return f"Code Output:\n{output.strip()}"
        return "No output. Did you forget print()?"