File size: 11,712 Bytes
c5185b5
ac0db5b
60e7a59
91b39e8
 
ac0db5b
061b557
d4ba1d5
db29fc7
c542bf3
c5185b5
 
20bd124
 
061b557
ac0db5b
707cf08
 
ac0db5b
 
 
d2552ae
ac0db5b
 
91b39e8
 
c5185b5
 
ac0db5b
 
 
 
1e9c645
b660c22
fdfc130
061b557
 
 
 
 
 
 
ac0db5b
 
 
 
 
c5185b5
 
 
 
 
 
ac0db5b
20bd124
de1d9c7
20bd124
 
c542bf3
de1d9c7
 
20bd124
db29fc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c542bf3
db29fc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20bd124
 
 
 
 
de1d9c7
20bd124
 
 
c542bf3
20bd124
 
 
 
 
db29fc7
 
 
 
 
 
689735c
db29fc7
1691123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fbf8f6
707cf08
 
 
 
 
 
 
 
 
 
db29fc7
 
5fbf8f6
db29fc7
 
c542bf3
db29fc7
de1d9c7
 
 
 
 
 
 
1e9c645
 
de1d9c7
c542bf3
de1d9c7
 
 
 
 
db29fc7
 
 
 
 
 
1e9c645
689735c
db29fc7
 
1e9c645
db29fc7
 
c542bf3
d2552ae
db29fc7
 
 
 
 
 
 
 
 
 
 
 
 
c542bf3
db29fc7
d4ba1d5
 
 
 
 
 
 
 
 
 
 
 
c542bf3
d4ba1d5
 
c542bf3
d4ba1d5
1be50f1
 
d4ba1d5
20bd124
 
 
 
 
d4ba1d5
 
91b39e8
 
20bd124
47545b1
20bd124
47545b1
20bd124
 
 
 
 
 
 
 
47545b1
 
20bd124
 
061b557
 
20bd124
 
 
 
 
47545b1
20bd124
 
d2552ae
061b557
60e7a59
20bd124
 
 
061b557
 
 
20bd124
 
 
 
 
 
 
47545b1
d2552ae
20bd124
47545b1
20bd124
 
d4ba1d5
 
20bd124
60e7a59
20bd124
47545b1
20bd124
 
 
47545b1
20bd124
 
 
 
ac0db5b
 
 
20bd124
 
c5185b5
20bd124
 
ac0db5b
 
 
20bd124
 
 
 
 
ac0db5b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
import operator
import os
import time
from typing import Optional

from langchain.chat_models import init_chat_model

from langchain_community.document_loaders import WikipediaLoader, ArxivLoader, YoutubeLoader
from langchain_community.tools import TavilySearchResults
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import add_messages, START, END, StateGraph
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from pydantic import SecretStr

from langchain_custom import WikipediaTableLoader

from typing_extensions import TypedDict, Annotated



class State(TypedDict):
    """Graph state shared by the nodes built in ``get_graph``.

    Each key is a LangGraph channel. Keys annotated with a reducer
    (``add_messages`` / ``operator.add``) are merged with that reducer when a
    node returns an update for them instead of being overwritten.
    """

    # Conversation history; node updates are merged via ``add_messages``.
    messages: Annotated[list, add_messages]
    # Optional label for extra input content. NOTE(review): presumably the type
    # of an attached file/document — confirm against whoever invokes the graph.
    content_type: Optional[str]
    # Optional extra input content; the planning node only prints it.
    content: Optional[str]
    # Names of the nodes that have run ("Plan", "Agent", "Answer"); updates are
    # concatenated with ``operator.add`` and the list length is used as a round
    # counter when deciding whether the agent may call more tools.
    aggregate: Annotated[list, operator.add]


def get_llm(model: str = "gemini-2.0-flash", model_provider: str = "google_genai"):
    """Create the chat model used by the agent graph.

    Args:
        model: Model name handed to ``init_chat_model``. Defaults to the
            Gemini model this project uses.
        model_provider: Provider key understood by ``init_chat_model``
            (for example ``"google_genai"`` or ``"groq"``).

    Returns:
        The chat model instance returned by ``init_chat_model``.
    """
    # Credentials are not handled here. NOTE(review): presumably the provider
    # integration reads its API key from the environment (e.g. GOOGLE_API_KEY
    # for google_genai) — confirm for each provider you switch to.
    return init_chat_model(model, model_provider=model_provider)


def _load_system_prompt(path):
    """Return the text of the UTF-8 markdown system prompt stored at *path*."""
    with open(path, 'r', encoding='utf-8') as markdown_file:
        return markdown_file.read()


def _build_tools():
    """Create the LangChain tools exposed to the agent, in binding order."""

    @tool
    def multiply(a: int, b: int) -> int:
        """Multiply two numbers.
        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Multiplication) has been called --------------------\n")
        return a * b

    @tool
    def add(a: int, b: int) -> int:
        """Add two numbers.

        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Addition) has been called --------------------\n")
        return a + b

    @tool
    def subtract(a: int, b: int) -> int:
        """Subtract two numbers.

        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Subtraction) has been called --------------------\n")
        return a - b

    @tool
    def divide(a: int, b: int) -> float:
        """Divide two numbers.

        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Division) has been called --------------------\n")
        if b == 0:
            raise ValueError("Cannot divide by zero.")
        return a / b

    @tool
    def modulus(a: int, b: int) -> int:
        """Get the modulus of two numbers.

        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Modulus) has been called --------------------\n")
        # Same guard as ``divide`` so the agent gets a clear error message.
        if b == 0:
            raise ValueError("Cannot take the modulus by zero.")
        return a % b

    @tool
    def wiki_search(query: str) -> str:
        """Search Wikipedia for a query and return maximum 2 results.

        Args:
            query: The search query."""
        print("\n-------------------- Tool (Wikipedia) has been called --------------------\n")
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()

        parts: list[str] = []
        for doc in search_docs:
            parts.append(
                f'<Document source="{doc.metadata["source"]}" '
                f'title="{doc.metadata["title"]}" '
                f'page="{doc.metadata.get("page", "")}">\n'
                f'{doc.page_content}\n</Document>'
            )

            # Tables are a best-effort extra: a failure must not discard the
            # article text already collected, but it should not be silent.
            try:
                print("---------------------------------")
                print("Loading tables from: ", doc.metadata["source"])
                print("---------------------------------")
                tables = WikipediaTableLoader(url=doc.metadata["source"], title=doc.metadata["title"]).load()
                for i, table in enumerate(tables):
                    parts.append(
                        f'<Document source="{table.metadata["source"]}" '
                        f'title="{table.metadata["title"]}" '
                        f'table_index="{i}">\n'
                        f'{table.page_content}\n</Document>'
                    )
            except Exception as err:
                print(f"Could not load tables from {doc.metadata.get('source')}: {err!r}")

        return "\n\n---\n\n".join(parts)

    @tool
    def web_search(query: str) -> str:
        """Search Tavily for a query and return maximum 3 results.

        Args:
            query: The search query."""
        print("\n-------------------- Tool (Tavily) has been called --------------------\n")
        search_docs = TavilySearchResults(max_results=3).invoke({'query': query})
        if isinstance(search_docs, str):
            # The Tavily tool returns an error description string instead of
            # a list of result dicts when the search fails; pass it through.
            return search_docs
        return "\n\n---\n\n".join(
            f'URL: {doc["url"]}\nTitle: {doc["title"]}\nContent: {doc["content"]}'
            for doc in search_docs
        )

    @tool
    def arvix_search(query: str) -> str:
        """Search Arxiv for a query and return maximum 3 result.

        Args:
            query: The search query."""
        print("\n-------------------- Tool (Arxiv) has been called --------------------\n")
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        parts: list[str] = []
        for doc in search_docs:
            # ArxivLoader metadata is not guaranteed to contain "source" (its
            # documented keys include "Title"/"Published"), so read defensively.
            source = doc.metadata.get("source", doc.metadata.get("entry_id", ""))
            title = doc.metadata.get("Title", "")
            parts.append(
                f'<Document source="{source}" title="{title}" '
                f'page="{doc.metadata.get("page", "")}">\n'
                f'{doc.page_content[:1000]}\n</Document>'
            )
        return "\n\n---\n\n".join(parts)

    @tool
    def youtube_transcript(url: str) -> str:
        """Download a transcript of a YouTube video.

        Args:
            url: URL of the YouTube video."""
        print("\n-------------------- Tool (YouTube Transcript) has been called --------------------\n")
        docs = YoutubeLoader.from_youtube_url(url, add_video_info=False).load()
        return "\n\n".join(f"\nContent:\n{doc.page_content}" for doc in docs)

    return [wiki_search, web_search, arvix_search, youtube_transcript, multiply, add, subtract, divide, modulus]


def get_graph(llm, system_prompt_path='prompts/system_prompt.md', max_rounds=8):
    """Build and compile the Plan -> Agent <-> tools -> Answer LangGraph.

    Args:
        llm: Chat model used for planning and answering; it must support
            ``bind_tools`` because the Agent node uses a tool-bound copy.
        system_prompt_path: Path (relative to the current working directory
            unless absolute) of the markdown system prompt.
        max_rounds: Once ``state["aggregate"]`` holds this many entries (one
            per Plan/Agent step), the agent is routed to the Answer node even
            if it requested more tool calls.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", _load_system_prompt(system_prompt_path)),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    tools = _build_tools()
    tool_node = ToolNode(tools)
    llm_with_tools = llm.bind_tools(tools)

    def _ask(model, state: State, instruction_text: str):
        """Invoke *model* on the history plus one instruction; return (instruction, response)."""
        instruction = HumanMessage(content=instruction_text)
        # Build a new list: graph state must never be mutated in place.
        prompt = prompt_template.invoke({"messages": [*state["messages"], instruction]})
        return instruction, model.invoke(prompt)

    # Every node returns its instruction together with the model response so
    # that ``add_messages`` records both in the conversation history.

    def make_plan(state: State):
        print("\n-------------------- Starting to create a plan --------------------\n")
        print("Waiting for 5 seconds...")
        # NOTE(review): fixed pause, presumably to respect provider rate limits.
        time.sleep(5)
        if "content_type" in state:
            # ``content`` may be missing even when ``content_type`` is present.
            print("Content is: ", state.get("content"))
        instruction, response = _ask(llm, state, "Write a plan for how to solve this question.")
        print("The plan is: ", response.content)
        return {"messages": [instruction, response], "aggregate": ["Plan"]}

    def call_model(state: State):
        print("\n-------------------- Agent has been called -----------------------------------\n")
        print("Waiting for 5 seconds...")
        time.sleep(5)
        instruction, response = _ask(
            llm_with_tools, state, "Please provide me the answer to the question in detail."
        )
        print("Agent has made a decision:\n", response.content, response.tool_calls)
        return {"messages": [instruction, response], "aggregate": ["Agent"]}

    def get_answer(state: State):
        print("\n-------------------- Generating Answer -----------------------------------\n")
        print("Waiting for 5 seconds...")
        time.sleep(5)
        instruction, response = _ask(
            llm, state, "Please provide me just the plain answer to the question"
        )
        print("The final answer is: ", response.content)
        return {"messages": [instruction, response], "aggregate": ["Answer"]}

    def should_continue(state: State):
        """Route to "tools" while the agent requests tools and rounds remain, else "Answer"."""
        print("\n-------------------- Decision of forwarding has been made --------------------\n")
        print("Waiting for 2 seconds...")
        time.sleep(2)
        rounds = len(state["aggregate"])
        last_message = state["messages"][-1]
        print("This is round: ", rounds)
        print("The last message is: ", last_message)
        if rounds < max_rounds and getattr(last_message, "tool_calls", None):
            return "tools"
        return "Answer"

    # Build graph
    builder = StateGraph(State)
    builder.add_node("tools", tool_node)
    builder.add_node("Plan", make_plan)
    builder.add_node("Agent", call_model)
    builder.add_node("Answer", get_answer)

    # Logic
    builder.add_edge(START, "Plan")
    builder.add_edge("Plan", "Agent")
    builder.add_conditional_edges("Agent", should_continue, ["tools", "Answer"])
    builder.add_edge("tools", "Agent")
    builder.add_edge("Answer", END)

    return builder.compile()