File size: 3,853 Bytes
178f14f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from smolagents import Tool
from typing import Dict, Any, List
from retriever import retrieve_query
import io
from contextlib import redirect_stdout

class DocumentRetrievalTool(Tool):
    """
    A tool for performing semantic search across a document knowledge base stored in ChromaDB.
    It accepts a query and retrieves the most relevant document chunks.

    The query is delegated to ``retrieve_query`` together with the collection
    given at construction time; the returned chunks are formatted into a single
    text block for the agent.
    """
    name = "document_retrieval"
    description = (
        "Use this tool to search for specific, relevant information within the loaded document set. "
        "Always use this when the user's question relates to the content of the documents."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query, which must be a specific question or topic related to the document's content."
        }
    }
    output_type = "string"

    def __init__(self, collection: Any, top_k: int = 5):
        """
        Store the retrieval settings and initialise the base tool.

        Args:
            collection: Collection handle passed through unchanged to
                ``retrieve_query`` on every call.
            top_k: Number of results requested from ``retrieve_query`` per
                query. Defaults to 5, the value that was previously hard-coded.
        """
        self.collection = collection
        self.top_k = top_k
        super().__init__()

    def forward(self, query: str) -> str:
        """
        Run the query through ``retrieve_query`` and format the results for the agent.

        Args:
            query: The search query supplied by the agent.

        Returns:
            A text block with one ``Source (<source>): <text>`` entry per
            result, separated by ``---`` lines, or an explicit "nothing found"
            message when the retriever returns no results.

        Raises:
            KeyError: If a result lacks the ``source`` or ``text`` key.
        """
        retrieved_results: List[Dict[str, Any]] = retrieve_query(
            query, self.collection, top_k=self.top_k
        )

        # Without this guard the agent would receive the "Retrieved context"
        # header followed by nothing, which reads as if context had been found.
        if not retrieved_results:
            return "No relevant content was found in the documents for this query.\n\n"

        context = "\n---\n".join(
            f"Source ({result['source']}): {result['text']}"
            for result in retrieved_results
        )
        return f"Retrieved context from document:\n\n{context}\n\n"
    
class DocumentSummarizationTool(Tool):
    """
    Tool that condenses document text by delegating to a summarization
    pipeline supplied by the caller.
    """
    name = "document_summarization"
    description = "Use this tool to summarize a loaded document set."
    inputs = {
        "document_text": {
            "type": "string",
            "description": "The document text to be summarized.",
        },
    }
    output_type = "string"

    def __init__(self, summarization_pipeline: Any):
        """
        Keep a reference to the pipeline, then initialise the base tool.

        Args:
            summarization_pipeline: Callable invoked with the document text;
                its result is indexed as ``result[0]["summary_text"]``.
        """
        self.summarization_pipeline = summarization_pipeline
        super().__init__()

    def forward(self, document_text: str) -> str:
        """
        Summarize ``document_text`` with the stored pipeline.

        Args:
            document_text: The text to summarize.

        Returns:
            The ``summary_text`` field of the first item the pipeline returns.
        """
        pipeline_results = self.summarization_pipeline(document_text)
        first_result = pipeline_results[0]
        return first_result["summary_text"]
    

class CodeExecutionTool(Tool):
    """
    A restricted tool for executing short Python snippets (e.g., for calculations,
    data manipulation, or simple logic puzzles that the LLM might struggle with).

    SECURITY NOTE: the snippet runs in-process via ``exec`` with a reduced
    ``__builtins__`` mapping. This limits accidental misuse but is NOT a real
    sandbox: restricted builtins can be bypassed through object introspection,
    and there is no time or memory limit. Do not expose this tool to untrusted
    input without process-level isolation.
    """
    name: str = "python_interpreter"
    description: str = (
        "A Python interpreter used to run short snippets of code for calculations "
        "or logic. The input must be valid Python code. The tool captures and "
        "returns any printed output."
    )
    inputs = {
        "code": {
            "type": "string",
            "description": "The Python code to execute. Must be a single code block."
        }
    }
    output_type = "string"

    def forward(self, code: str) -> str:
        """
        Execute ``code`` with a small whitelist of builtins and capture stdout.

        Args:
            code: Python source to execute.

        Returns:
            ``"Code Output:\\n<stdout>"`` when the snippet printed something
            visible, a hint to use ``print()`` when it printed nothing (or only
            whitespace), or ``"Code Execution Error: <Type>: <message>"`` when
            it raised. Any text printed before the error is appended to the
            error message.
        """
        safe_builtins = {'print': print, 'len': len, 'sum': sum, 'min': min, 'max': max, 'range': range, 'str': str, 'int': int, 'float': float}
        safe_globals = {'__builtins__': safe_builtins}
        output_buffer = io.StringIO()

        try:
            with redirect_stdout(output_buffer):
                # SECURITY: exec of model-generated code; see the class docstring.
                exec(code, safe_globals)
        except Exception as e:
            error_message = f"Code Execution Error: {type(e).__name__}: {str(e)}"
            # Keep whatever was printed before the failure so it is not lost.
            partial_output = output_buffer.getvalue().strip()
            if partial_output:
                return f"{error_message}\nOutput before error:\n{partial_output}"
            return error_message

        # Strip before testing so whitespace-only output (e.g. a bare print())
        # is reported as "no output" instead of an empty "Code Output:" body.
        output = output_buffer.getvalue().strip()
        if not output:
            return "No output. Did you forget print()?"
        return f"Code Output:\n{output}"