File size: 5,368 Bytes
c5185b5
ac0db5b
60e7a59
ac0db5b
 
c5185b5
 
20bd124
 
 
ac0db5b
 
 
 
d2552ae
ac0db5b
 
d2552ae
 
c5185b5
 
ac0db5b
 
 
 
fdfc130
b660c22
fdfc130
ac0db5b
 
 
 
 
c5185b5
 
 
 
 
 
ac0db5b
20bd124
de1d9c7
20bd124
 
 
de1d9c7
 
20bd124
 
 
 
 
 
de1d9c7
20bd124
 
 
 
 
 
 
 
 
de1d9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2552ae
 
99e6c99
20bd124
 
 
 
d2552ae
20bd124
d2552ae
20bd124
47545b1
20bd124
47545b1
20bd124
 
 
 
 
 
 
 
47545b1
 
20bd124
 
 
 
 
 
 
47545b1
20bd124
 
d2552ae
60e7a59
 
 
20bd124
 
 
 
 
 
 
 
 
 
47545b1
d2552ae
20bd124
47545b1
20bd124
 
 
60e7a59
20bd124
47545b1
20bd124
 
 
47545b1
20bd124
 
 
 
ac0db5b
 
 
20bd124
 
c5185b5
20bd124
 
ac0db5b
 
 
20bd124
 
 
 
 
ac0db5b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import operator
import os
import time
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, AnyMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import add_messages, START, END, StateGraph
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode


from typing_extensions import TypedDict, Annotated



class State(TypedDict):
    """Shared state passed between the nodes of the agent graph."""

    # Chat history; node updates are merged into it by LangGraph's add_messages reducer.
    messages: Annotated[list, add_messages]
    # Printed by the Plan node; presumably describes the kind of input supplied by the caller — TODO confirm.
    content_type: str
    # Not read by any node visible in this file; presumably input content supplied by the caller — TODO confirm.
    content: str
    # Node labels ("Plan", "Agent", "Answer") concatenated via operator.add;
    # its length is used as a round counter when deciding whether to keep calling tools.
    aggregate: Annotated[list, operator.add]


def get_llm(model: str = "gemini-2.0-flash", model_provider: str = "google_genai"):
    """Create the chat model used by the agent graph.

    Args:
        model: Model identifier passed to ``init_chat_model``. Defaults to
            ``"gemini-2.0-flash"``.
        model_provider: LangChain provider name for ``model``. Defaults to
            ``"google_genai"``. (This module previously also referenced
            ``"llama-3.3-70b-versatile"`` on the ``"groq"`` provider.)

    Returns:
        The chat model instance returned by ``init_chat_model``.

    No credentials are passed explicitly; the provider integration is
    presumably expected to read its API key from the environment
    (e.g. ``GROQ_API_KEY`` for groq) — confirm for the chosen provider.
    """
    # The former bare ``os.getenv("GROQ_API_KEY")`` call discarded its result
    # and had no effect, so it has been removed.
    return init_chat_model(model, model_provider=model_provider)

def get_graph(llm, *, max_rounds: int = 8):
    """Build and compile the plan -> agent <-> tools -> answer workflow.

    Graph flow:

    1. ``Plan``   - asks ``llm`` (without tools) to write a plan.
    2. ``Agent``  - asks the tool-bound model for a detailed answer; the
       reply may contain tool calls.
    3. ``tools``  - executes the requested tool calls (Wikipedia / Tavily
       retrievers) and routes back to ``Agent``.
    4. ``Answer`` - asks ``llm`` for the plain final answer, then ends.

    Args:
        llm: Chat model supporting ``invoke`` and ``bind_tools``.
        max_rounds: Once ``state["aggregate"]`` holds this many entries (one
            per Plan/Agent execution), ``Agent`` is routed to ``Answer`` even
            if it requested tools. Defaults to 8, the previously hard-coded
            limit.

    Returns:
        The compiled graph from ``StateGraph.compile()``.

    The system prompt is read from ``prompts/system_prompt.md``, resolved
    relative to the current working directory.
    """
    with open('prompts/system_prompt.md', 'r', encoding='utf-8') as markdown_file:
        system_prompt = markdown_file.read()

    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    from langchain_community.retrievers import WikipediaRetriever
    from langchain_community.retrievers import TavilySearchAPIRetriever

    wiki_retriever = WikipediaRetriever(load_max_docs=20)
    tavily_retriever = TavilySearchAPIRetriever(k=3)

    def _serialize_docs(docs):
        """Join retrieved documents' page content into one tool-result string."""
        return "\n\n".join(f"\nContent:\n{doc.page_content}" for doc in docs)

    # NOTE: the tool docstrings below are the descriptions the LLM sees for
    # each tool, so they are kept exactly as they were.
    @tool
    def retrieve(query: str):
        """
        This function retrieves Wikipedia entries based on the query.
        """
        print("\n-------------------- Tool (Wikipedia) has been called --------------------\n")
        print("The query is: ", query)
        return _serialize_docs(wiki_retriever.invoke(query))

    @tool
    def online_search(query: str):
        """
        This function does a web search based on the query.
        """
        print("\n-------------------- Tool (Tavily) has been called --------------------\n")
        print("The query is: ", query)
        return _serialize_docs(tavily_retriever.invoke(query))

    tools = [retrieve, online_search]
    tool_node = ToolNode(tools)
    llm_with_tools = llm.bind_tools(tools)

    def _prompt_with(messages, instruction):
        """Build a prompt from ``messages`` plus a new human ``instruction``.

        Returns ``(prompt, instruction_message)``. A new list is built rather
        than appending to ``state["messages"]``: that list may be the graph's
        stored value, so appending to it in place would silently edit the
        history outside the reducer.
        """
        instruction_message = HumanMessage(content=instruction)
        prompt = prompt_template.invoke({"messages": [*messages, instruction_message]})
        return prompt, instruction_message

    def make_plan(state: State):
        """Ask the tool-less model for a plan and record it in the history."""
        print("\n-------------------- Starting to create a plan --------------------\n")
        print("Content is: ", state["content_type"])
        prompt, instruction = _prompt_with(
            state["messages"], "Write a plan how to solve this question?"
        )
        response = llm.invoke(prompt)
        print("The plan is: ", response.content)
        # The instruction is returned alongside the reply so it is persisted
        # explicitly through the add_messages reducer.
        return {"messages": [instruction, response], "aggregate": ["Plan"]}

    def call_model(state: State):
        """Ask the tool-bound model for a detailed answer (may request tools)."""
        print("\n-------------------- Agent has been called -----------------------------------\n")
        prompt, instruction = _prompt_with(
            state["messages"], "Please provide me the answer to the question in detail."
        )
        response = llm_with_tools.invoke(prompt)
        print("Agent has made a decision:\n", response.content, response.tool_calls)
        # Fixed pause between agent calls (presumably to respect provider rate limits).
        print("Waiting for 4 seconds...")
        time.sleep(4)
        # ``response`` must stay last: ToolNode and should_continue read messages[-1].
        return {"messages": [instruction, response], "aggregate": ["Agent"]}

    def get_answer(state: State):
        """Ask the tool-less model for the plain final answer."""
        prompt, instruction = _prompt_with(
            state["messages"], "Please provide me just the plain answer to the question"
        )
        response = llm.invoke(prompt)
        print("The final answer is: ", response.content)
        return {"messages": [instruction, response], "aggregate": ["Answer"]}

    def should_continue(state: State):
        """Route to "tools" if the agent requested tools and rounds remain, else "Answer"."""
        print("\n-------------------- Decision of forwarding has been made --------------------\n")
        messages = state["messages"]
        rounds = len(state["aggregate"])
        print("This is round: ", rounds)
        print("The last message is: ", messages[-1])
        # getattr: only AI messages carry ``tool_calls``.
        if rounds < max_rounds and getattr(messages[-1], "tool_calls", None):
            return "tools"
        return "Answer"

    # Build graph
    builder = StateGraph(State)
    builder.add_node("tools", tool_node)
    builder.add_node("Plan", make_plan)
    builder.add_node("Agent", call_model)
    builder.add_node("Answer", get_answer)

    # Logic
    builder.add_edge(START, "Plan")
    builder.add_edge("Plan", "Agent")
    builder.add_conditional_edges("Agent", should_continue, ["tools", "Answer"])
    builder.add_edge("tools", "Agent")
    builder.add_edge("Answer", END)

    return builder.compile()