File size: 5,980 Bytes
ea8a378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import constants as cte

from dotenv import load_dotenv
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.lib.units import inch
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
from smolagents import CodeAgent, LiteLLMModel, Tool, MessageRole

load_dotenv()  # take environment variables from .env.

# Azure OpenAI chat-completion deployment settings (read from the environment).
LLM_BASE = os.getenv("AZURE_OPENAI_BASE")
LLM_VERSION = os.getenv("AZURE_OPENAI_VERSION")
# NOTE(review): "LLMI_API_KEY" looks like a typo for LLM_API_KEY — confirm no
# external module imports it before renaming.
LLMI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
LLM_NAME = os.getenv("AZURE_OPENAI_MODEL")
# Azure OpenAI embedding deployment settings (consumed by ChromaRetrieverTool).
EMBEDDING_BASE = os.getenv("AZURE_OPENAI_EMBEDDING_BASE")
EMBEDDING_VERSION = os.getenv("AZURE_OPENAI_EMBEDDING_VERSION")
EMBEDDING_API_KEY = os.getenv("AZURE_OPENAI_EMBEDDING_API_KEY")
EMBEDDING_NAME = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL")

# Local persistent Chroma database holding the document chunks.
CHROMA_PATH = "/teamspace/studios/this_studio/AgenticRAG/chroma_db"
# Hugging Face dataset id; defined but not referenced in this file — TODO confirm
# it is used elsewhere or remove.
BM25_PATH = "philschmid/markdown-documentation-transformers"

class ChromaRetrieverTool(Tool):
    """smolagents Tool that performs vector search against a persistent
    Chroma collection ("RENTA_2023_LARGE") embedded with Azure OpenAI."""

    name = "chroma_retriever"
    description = """Uses vector search to retrieve chunks of information from the “Manual de la Renta” document that might be more relevant to answering your query. 

    Use the affirmative form rather than a question. If the age or residence is provided, be sure to include it in query to find the necessary information.

    For better results, searches must be for one specific data, for multiple concepts or different information, use multiple calls to the tool"""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be vector space close to your target documents.",
        }
    }
    output_type = "string"

    def __init__(self, path_to_database: str, top_k_results: int = 5, **kwargs):
        """Open (or create) the Chroma collection at *path_to_database*.

        Args:
            path_to_database: Filesystem path of the persistent Chroma store.
            top_k_results: Number of chunks returned per query.
        """
        super().__init__(**kwargs)

        # Imported lazily so the module can be loaded without chromadb installed.
        import chromadb

        self.top_k_results = top_k_results
        # Azure OpenAI embedding function; model_name must be the bare
        # deployment name, so strip any "prefix/" path component.
        self.openai_embedding = (
            chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
                api_key=EMBEDDING_API_KEY,
                api_base=EMBEDDING_BASE,
                api_type="azure",
                api_version=EMBEDDING_VERSION,
                model_name=str(EMBEDDING_NAME).split("/")[-1],
            )
        )
        self.retriever_client = chromadb.PersistentClient(path=path_to_database)
        self.collection = self.retriever_client.get_or_create_collection(
            name="RENTA_2023_LARGE",
            embedding_function=self.openai_embedding,
        )

    def forward(self, query: str) -> str:
        """Run the vector query and return the retrieved chunks as one string.

        Raises:
            TypeError: If *query* is not a string.
        """
        # `assert` is stripped under `python -O`; validate explicitly instead.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        results = self.collection.query(
            query_texts=[query],
            n_results=self.top_k_results,  # how many results to return
        )

        documents = results.get("documents") or []
        metadatas = results.get("metadatas") or []

        # Collect every retrieved chunk, tagging it with its metadata so the
        # agent can cite the source. Guard the metadata lookup: Chroma returns
        # parallel lists, but do not crash the agent loop if they are ragged.
        chunks = []
        for j, document in enumerate(documents):
            for i, doc in enumerate(document):
                if j < len(metadatas) and i < len(metadatas[j]):
                    meta = metadatas[j][i]
                else:
                    meta = {}
                chunks.append(f"\n\n===== Document {meta} =====\n{doc}")

        # Previously an empty result list still produced the "Top Retrieved
        # documents" header; report "No information found" in that case too.
        if not chunks:
            return "No information found"
        return "\nTop Retrieved documents:\n" + "".join(chunks)

class GeneratePDFTool(Tool):
    """smolagents Tool that renders the agent's final answer into a PDF."""

    name = "generate_pdf"
    description = "Generates a PDF document from the final answer."
    inputs = {
        "text": {
            "type": "string",
            "description": "The final answer to be included too in the PDF document.",
        }
    }
    output_type = "string"

    def forward(self, text: str) -> str:
        """Write *text* to 'final_answer.pdf' and return a status message."""
        try:
            stylesheet = getSampleStyleSheet()
            flowables = [
                Paragraph(text, stylesheet["Normal"]),
                Spacer(1, 0.2 * inch),  # small gap below the content
            ]
            pdf = SimpleDocTemplate("final_answer.pdf", pagesize=letter)
            pdf.build(flowables)
            return "PDF document 'final_answer.pdf' has been generated successfully."
        except Exception as e:
            # Report failures as a string so the agent loop can recover.
            return f"Error generating PDF: {str(e)}"


class SmolAgent:
    """CodeAgent wired with the Chroma retriever and PDF-generation tools."""

    def __init__(self):
        # Use the module-level Azure OpenAI constants instead of re-reading
        # the environment here — keeps configuration in one place and
        # consistent with the embedding settings above.
        model = LiteLLMModel(
            model_id=str(LLM_NAME),
            api_base=str(LLM_BASE),
            api_key=str(LLMI_API_KEY),
            temperature=0,  # deterministic answers
        )

        # Create retriever over the pre-processed document database.
        chroma_tool = ChromaRetrieverTool(path_to_database=CHROMA_PATH)
        generate_pdf_tool = GeneratePDFTool()

        self.agent = CodeAgent(
            tools=[chroma_tool, generate_pdf_tool],
            model=model,
            max_steps=8,
            add_base_tools=True,
            additional_authorized_imports=["chromadb"],
        )

    def __call__(self, query):
        """Run the agent on *query*; reset=False keeps memory across calls."""
        return self.agent.run(query, reset=False)


if __name__ == "__main__":
    agent = SmolAgent()

    # Demo run: ask one question in Spanish and print the agent's answer.
    question1 = "¿Tengo que declarar los bízums recibidos a lo largo del año pasado?"
    print(f"\nQuestion 1: {question1}")
    agent_output = agent(question1)

    separator = "======================================================================================================================================"
    print(separator)
    print(separator)
    # Debug aid: dump the agent memory between the separators if needed.
    #print(agent.agent.write_memory_to_messages())
    print(separator)
    print("Agent response:\n", agent_output)