File size: 3,696 Bytes
6b5a8ab
 
 
 
8233fc5
 
6b5a8ab
 
 
 
8233fc5
6b5a8ab
8233fc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b5a8ab
 
 
 
1769d8d
6b5a8ab
8233fc5
6b5a8ab
 
 
 
 
 
8233fc5
 
6b5a8ab
 
 
8233fc5
 
6b5a8ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_mistralai import ChatMistralAI
from langchain_groq import ChatGroq
from langgraph.prebuilt import create_react_agent
from custom_tools import custom_tools

class ReActAgent:
    """Question-answering agent built on LangGraph's prebuilt ReAct agent.

    The agent is backed by a chat model from one of the supported providers
    ("Google", "Mistral", "Groq") and is given the tools in ``custom_tools``.
    Calling an instance with a question runs the agent and returns the text
    after the last ``FINAL ANSWER:`` marker in the model's final reply (or the
    whole reply when no marker is present).
    """

    # provider name -> (variable read from .env / the environment,
    #                   variable the LangChain chat client reads its key from)
    _API_KEY_VARS = {
        "Google": ("GOOGLE", "GOOGLE_API_KEY"),
        "Mistral": ("MISTRAL", "MISTRAL_API_KEY"),
        "Groq": ("GROQ", "GROQ_API_KEY"),
    }

    # Marker the system prompt asks the model to put before its answer.
    _ANSWER_MARKER = "FINAL ANSWER:"

    def __init__(self, provider: str="Google", model: str="gemini-2.5-flash"):
        """Create the chat model for ``provider`` and build the ReAct agent.

        Args:
            provider: One of "Google", "Mistral" or "Groq".
            model: Model name passed through to the provider's chat client.

        Raises:
            ValueError: If ``provider`` is not a supported provider.
        """
        # Validate first: an unknown provider used to fall through every
        # branch and fail later with an opaque UnboundLocalError on ``llm``.
        if provider not in self._API_KEY_VARS:
            raise ValueError(
                f"Unsupported provider {provider!r}; expected one of "
                f"{sorted(self._API_KEY_VARS)}."
            )

        load_dotenv()
        self._export_api_key(*self._API_KEY_VARS[provider])

        # Generation settings shared by every provider's client.
        llm_kwargs = {"model": model, "temperature": 0, "max_retries": 5}
        if provider == "Google":
            llm = ChatGoogleGenerativeAI(**llm_kwargs)
        elif provider == "Mistral":
            llm = ChatMistralAI(**llm_kwargs)
        else:  # "Groq" -- the only remaining supported provider
            llm = ChatGroq(**llm_kwargs)

        # NOTE: the backslash continuations below are inside the string
        # literal, so each continuation line's leading indentation is part of
        # the prompt text. Kept as-is to avoid changing the prompt.
        sys_prompt = "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, DON'T use comma to write your number NEITHER use units such as $ or percent sign unless specified otherwise. If you are asked for a string, DON'T use articles, NEITHER abbreviations (e.g. for cities) capitalize the first letter, and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending, unless the first letter capitalization, whether the element to be put in the list is a number or a string.\n\n\n \
        \n \
        You will be provided with tools to help you answer questions.\n \
        If you are asked to make a calculation, absolutely use the tools provided to you. You should AVOID calculating by yourself and ABSOLUTELY use appropriate tools.\n \
        If you need to search for information, use the web_search tool rather than wiki_search, unless the question specifies searching on wikipedia. After using the web_search tool, look for the first URL provided with the url_search tool and ask yourself if the answer is in the tool response. If it is, answer the question. If not, search on other links.\n \
        \n \
        If needed, use one tool first, then use the output of that tool as an input to another thinking then to the use of another tool."
        # Build the ReAct agent
        self.agent = create_react_agent(
            model=llm,
            tools=custom_tools,
            prompt=sys_prompt,
        )
        print(f"ReActAgent initialized with {provider} - {model}.")

    @staticmethod
    def _export_api_key(source_var: str, target_var: str) -> None:
        """Copy environment variable ``source_var`` into ``target_var``.

        Does nothing when ``source_var`` is unset. Assigning ``None`` into
        ``os.environ`` would raise ``TypeError``. Leaving ``target_var``
        untouched also lets a key already exported under the client's own
        variable name keep working.
        """
        value = os.getenv(source_var)
        if value is not None:
            os.environ[target_var] = value

    @staticmethod
    def _extract_answer(content) -> str:
        """Return the answer text from a message's ``content``.

        ``content`` may be a plain string or a list of content parts (strings,
        or dicts carrying a ``"text"`` entry); text parts are concatenated.
        If the text contains ``FINAL ANSWER:``, only the part after its last
        occurrence is returned, stripped of surrounding whitespace. Otherwise
        the text is returned unchanged.
        """
        if isinstance(content, list):
            content = "".join(
                part if isinstance(part, str) else str(part.get("text", ""))
                for part in content
                if isinstance(part, (str, dict))
            )
        _, marker, tail = content.rpartition(ReActAgent._ANSWER_MARKER)
        return tail.strip() if marker else content

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its final answer.

        Every message of the resulting conversation is printed for debugging.
        """
        # The prebuilt agent takes a state dict with a "messages" list and
        # returns the final state (not a stream).
        result = self.agent.invoke({"messages": [HumanMessage(content=question)]})
        messages = result["messages"]
        for message in messages:
            print(message)
        # The last message in the final state is the agent's reply.
        return self._extract_answer(messages[-1].content)