naveensharma16's picture
Upload 8 files
178f14f verified
from smolagents import Tool
from typing import Dict, Any, List
from retriever import retrieve_query
import io
from contextlib import redirect_stdout
class DocumentRetrievalTool(Tool):
    """
    Agent tool that performs a semantic search over a document knowledge base.

    The query is delegated to ``retrieve_query`` together with the collection
    given to the constructor, and the returned chunks are rendered into a
    single text block for the agent.
    """

    name = "document_retrieval"
    description = (
        "Use this tool to search for specific, relevant information within the loaded document set. "
        "Always use this when the user's question relates to the content of the documents."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query, which must be a specific question or topic related to the document's content."
        }
    }
    output_type = "string"

    def __init__(self, collection: Any, top_k: int = 5):
        """
        Store the retrieval configuration and initialise the base ``Tool``.

        Args:
            collection: Document collection handle, passed unchanged to
                ``retrieve_query`` (presumably a ChromaDB collection --
                confirm against the caller).
            top_k: Maximum number of chunks requested per query. Defaults to
                5, the value that used to be hard-coded in ``forward``.

        Raises:
            ValueError: If ``top_k`` is smaller than 1.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be at least 1, got {top_k}")
        self.collection = collection
        self.top_k = top_k
        super().__init__()

    def forward(self, query: str) -> str:
        """
        Run the query through ``retrieve_query`` and format the hits.

        Args:
            query: The search text supplied by the agent.

        Returns:
            A text block in which each hit is rendered as
            ``Source (<source>): <text>`` and hits are separated by ``---``
            lines, or a short notice when nothing was retrieved.
        """
        retrieved_results: List[Dict[str, Any]] = retrieve_query(
            query, self.collection, top_k=self.top_k
        )

        # Without this guard the agent would receive a "Retrieved context"
        # header followed by nothing, which reads as if retrieval succeeded.
        if not retrieved_results:
            return "No relevant context was found in the documents for this query."

        context = "\n---\n".join(
            f"Source ({result['source']}): {result['text']}"
            for result in retrieved_results
        )
        return f"Retrieved context from document:\n\n{context}\n\n"
class DocumentSummarizationTool(Tool):
    """
    Agent tool that condenses document text using a caller-supplied
    summarization pipeline.
    """

    name = "document_summarization"
    description = "Use this tool to summarize a loaded document set."
    inputs = {
        "document_text": {
            "type": "string",
            "description": "The document text to be summarized.",
        },
    }
    output_type = "string"

    def __init__(self, summarization_pipeline: Any):
        """
        Keep a reference to the summarization callable, then run the base
        ``Tool`` initialisation.

        Args:
            summarization_pipeline: Callable invoked with the document text
                in ``forward``.
        """
        self.summarization_pipeline = summarization_pipeline
        super().__init__()

    def forward(self, document_text: str) -> str:
        """
        Summarize ``document_text``.

        The pipeline is called with the text as its only argument; the
        ``"summary_text"`` entry of the first element of its result is
        returned.
        """
        first_result = self.summarization_pipeline(document_text)[0]
        return first_result["summary_text"]
class CodeExecutionTool(Tool):
    """
    A tool for executing short Python snippets (e.g. calculations, data
    manipulation, or simple logic puzzles that the LLM might struggle with).

    SECURITY NOTE(review): the snippet is run with ``exec`` inside this
    process using a reduced ``__builtins__`` mapping. That removes ``import``
    and most built-in names, but it is NOT a security boundary: restricted
    namespaces of this kind can be escaped through object introspection, and
    nothing here limits run time or memory. Only feed it code you would be
    willing to run directly, or isolate the process externally.
    """

    name: str = "python_interpreter"
    description: str = (
        "A Python interpreter used to run short snippets of code for calculations "
        "or logic. The input must be valid Python code. The tool captures and "
        "returns any printed output."
    )
    inputs = {
        "code": {
            "type": "string",
            "description": "The Python code to execute. Must be a single code block."
        }
    }
    output_type = "string"

    def forward(self, code: str) -> str:
        """
        Execute ``code`` with a reduced set of built-ins and report what it printed.

        Args:
            code: Python source to run.

        Returns:
            ``"Code Output:\\n<stdout>"`` when something was printed, a hint
            to use ``print()`` when nothing was, or a
            ``"Code Execution Error: ..."`` line when the snippet raised.
            In the error case, anything printed before the failure is
            appended so it is not lost.
        """
        # Both mappings are rebuilt on every call, so names defined (or
        # built-ins overwritten) by one snippet never leak into the next.
        safe_builtins = {'print': print, 'len': len, 'sum': sum, 'min': min, 'max': max, 'range': range, 'str': str, 'int': int, 'float': float}
        safe_globals = {'__builtins__': safe_builtins}
        output_buffer = io.StringIO()

        try:
            # redirect_stdout swaps sys.stdout for the whole process, so
            # output from other threads printing at the same time would be
            # captured here as well.
            with redirect_stdout(output_buffer):
                exec(code, safe_globals)
        except Exception as e:
            error_message = f"Code Execution Error: {type(e).__name__}: {str(e)}"
            # Keep whatever was printed before the failure; it is often the
            # most useful clue for correcting the snippet.
            partial_output = output_buffer.getvalue().strip()
            if partial_output:
                return f"{error_message}\nOutput before error:\n{partial_output}"
            return error_message

        output = output_buffer.getvalue()
        return f"Code Output:\n{output.strip()}" if output else "No output. Did you forget print()?"