import operator
import os
import time
from typing import Optional
from langchain.chat_models import init_chat_model
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader, YoutubeLoader
from langchain_community.tools import TavilySearchResults
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import add_messages, START, END, StateGraph
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from pydantic import SecretStr
from langchain_custom import WikipediaTableLoader
from typing_extensions import TypedDict, Annotated
class State(TypedDict):
    """Graph state shared by every node of the workflow built in ``get_graph``."""

    # Conversation history; node updates are merged by LangGraph's add_messages reducer.
    messages: Annotated[list, add_messages]
    # Optional description of ``content`` — TODO confirm expected values with callers.
    content_type: Optional[str]
    # Optional raw content accompanying the question; the planning node logs it.
    content: Optional[str]
    # One label per node run; list updates are concatenated with operator.add,
    # and the list length serves as the step counter.
    aggregate: Annotated[list, operator.add]
def get_llm(model: str = "gemini-2.0-flash", model_provider: str = "google_genai"):
    """Create the chat model used by the agent graph.

    Args:
        model: Model name passed to ``init_chat_model``. Defaults to
            ``"gemini-2.0-flash"``.
        model_provider: Provider key passed to ``init_chat_model``. Defaults to
            ``"google_genai"``; for example, ``model="llama-3.3-70b-versatile"``
            with ``model_provider="groq"`` selects a Groq-hosted model.

    Returns:
        The chat model instance returned by ``init_chat_model``.
    """
    return init_chat_model(model, model_provider=model_provider)
def get_graph(llm, max_steps: int = 8):
    """Build and compile the Plan -> Agent <-> tools -> Answer workflow.

    Flow: ``START -> Plan -> Agent``. From ``Agent`` the graph goes to
    ``tools`` (and back to ``Agent``) while the model keeps requesting tool
    calls and the step budget allows it; otherwise it goes to
    ``Answer -> END``.

    Args:
        llm: Chat model used for planning and for the final answer. It must
            support ``bind_tools``, because the Agent node uses a tool-bound copy.
        max_steps: The Agent is routed to ``tools`` only while
            ``len(state["aggregate"])`` is below this value. Plan, Agent and
            Answer each add one entry to ``aggregate``; tool runs add none.
            Defaults to 8.

    Returns:
        The compiled graph returned by ``StateGraph(State).compile()``.

    Raises:
        OSError: If ``prompts/system_prompt.md`` cannot be read. The path is
            relative, so it is resolved against the current working directory.
    """
    from langchain_community.retrievers import WikipediaRetriever

    with open("prompts/system_prompt.md", "r", encoding="utf-8") as markdown_file:
        system_prompt = markdown_file.read()

    # NOTE(review): the ("system", text) tuple is parsed as an f-string template
    # (ChatPromptTemplate's default), so literal braces in the markdown file must
    # be escaped as "{{" / "}}". Confirm the prompt file follows that.
    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    # Backs the ``retrieve`` tool below.
    wiki_retriever = WikipediaRetriever()

    def _wait(seconds: int) -> None:
        """Announce and perform a fixed pause before a model call or routing step."""
        # Presumably throttles requests to stay within provider rate limits (TODO confirm).
        print(f"Waiting for {seconds} seconds...")
        time.sleep(seconds)

    # ------------------------------------------------------------------ tools
    # Tool docstrings are the descriptions the model sees, so maintainer notes
    # are kept in comments rather than in the docstrings.

    @tool
    def multiply(a: int, b: int) -> int:
        """Multiply two numbers.
        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Multiplication) has been called --------------------\n")
        return a * b

    @tool
    def add(a: int, b: int) -> int:
        """Add two numbers.
        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Addition) has been called --------------------\n")
        return a + b

    @tool
    def subtract(a: int, b: int) -> int:
        """Subtract two numbers.
        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Subtraction) has been called --------------------\n")
        return a - b

    @tool
    def divide(a: int, b: int) -> float:
        """Divide two numbers.
        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Division) has been called --------------------\n")
        if b == 0:
            raise ValueError("Cannot divide by zero.")
        return a / b

    @tool
    def modulus(a: int, b: int) -> int:
        """Get the modulus of two numbers.
        Args:
            a: first int
            b: second int
        """
        print("\n-------------------- Tool (Modulus) has been called --------------------\n")
        # Same zero-divisor contract as ``divide`` instead of a bare ZeroDivisionError.
        if b == 0:
            raise ValueError("Cannot compute modulus with a zero divisor.")
        return a % b

    # Not included in ``tools`` below, so the model cannot call it.
    @tool
    def retrieve(query: str):
        """
        This function retrieves Wikipedia entries based on the query.
        """
        print("\n-------------------- Tool (Wikipedia) has been called --------------------\n")
        print("The query is: ", query)
        docs = wiki_retriever.invoke(query)
        return "\n\n".join(f"\nContent:\n{doc.page_content}" for doc in docs)

    @tool
    def wiki_search(query: str) -> str:
        """Search Wikipedia for a query and return maximum 2 results.
        Args:
            query: The search query."""
        print("\n-------------------- Tool (Wikipedia) has been called --------------------\n")
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        parts: list[str] = []
        for doc in search_docs:
            parts.append(f"\n{doc.page_content}\n")
            # Tables are a best-effort extra: a failure must not discard the page
            # text already collected, but it is reported instead of swallowed.
            try:
                print("---------------------------------")
                print("Loading tables from: ", doc.metadata["source"])
                print("---------------------------------")
                tables = WikipediaTableLoader(
                    url=doc.metadata["source"], title=doc.metadata["title"]
                ).load()
                parts.extend(f"\n{table.page_content}\n" for table in tables)
            except Exception as exc:  # best-effort, see comment above
                print(f"Could not load tables from {doc.metadata.get('source')!r}: {exc!r}")
        return "\n\n---\n\n".join(parts)

    # Not included in ``tools`` below, so the model cannot call it.
    @tool
    def wiki_table_search(url: str, title: str) -> str:
        """Get Wikipedia tables for a given URL and title.
        Args:
            url: The Wikipedia URL.
            title: The title of the Wikipedia page."""
        print("\n-------------------- Tool (Wikipedia-Table) has been called --------------------\n")
        search_docs = WikipediaTableLoader(url=url, title=title).load()
        return "\n\n---\n\n".join(f"\n{doc.page_content}\n" for doc in search_docs)

    # Not included in ``tools`` below, so the model cannot call it.
    @tool
    def online_search(query: str):
        """
        This function does a web search based on the query.
        """
        print("\n-------------------- Tool (Tavily) has been called --------------------\n")
        print("The query is: ", query)
        results = TavilySearchResults(max_results=3).invoke({"query": query})
        # TavilySearchResults can report API failures as a plain string instead
        # of a list of result dicts; pass that text through.
        if isinstance(results, str):
            return results
        # Results are dicts (as ``web_search`` reads them), not Documents.
        return "\n\n".join(f"\nContent:\n{result['content']}" for result in results)

    @tool
    def web_search(query: str) -> str:
        """Search Tavily for a query and return maximum 3 results.
        Args:
            query: The search query."""
        print("\n-------------------- Tool (Tavily) has been called --------------------\n")
        search_docs = TavilySearchResults(max_results=3).invoke({"query": query})
        # TavilySearchResults can report API failures as a plain string; indexing
        # into it would raise TypeError, so pass the text through.
        if isinstance(search_docs, str):
            return search_docs
        return "\n\n---\n\n".join(
            f'URL: {doc["url"]}\nTitle= {doc["title"]}\nContent: {doc["content"]}'
            for doc in search_docs
        )

    # The misspelled name is kept on purpose: it is the tool name the model sees.
    @tool
    def arvix_search(query: str) -> str:
        """Search Arxiv for a query and return maximum 3 results.
        Args:
            query: The search query."""
        print("\n-------------------- Tool (Arxiv) has been called --------------------\n")
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        # Each paper is truncated to its first 1000 characters.
        return "\n\n---\n\n".join(f"\n{doc.page_content[:1000]}\n" for doc in search_docs)

    @tool
    def youtube_transcript(url: str) -> str:
        """Download a transcript of a YouTube video.
        Args:
            url: URL of the YouTube video."""
        print("\n-------------------- Tool (YouTube Transcript) has been called --------------------\n")
        loader = YoutubeLoader.from_youtube_url(url, add_video_info=False)
        return "\n\n".join(f"\nContent:\n{doc.page_content}" for doc in loader.load())

    # Only these tools are exposed to the model.
    tools = [
        wiki_search, web_search, arvix_search, youtube_transcript,
        multiply, add, subtract, divide, modulus,
    ]
    tool_node = ToolNode(tools)
    llm_with_tools = llm.bind_tools(tools)

    # ------------------------------------------------------------------ nodes

    def _ask(model, state: State, instruction_text: str):
        """Invoke ``model`` on the conversation plus one trailing instruction.

        Returns ``(instruction, response)``. The caller returns both through the
        ``messages`` reducer rather than appending to ``state["messages"]`` in
        place, which would bypass the reducer.
        """
        instruction = HumanMessage(content=instruction_text)
        prompt = prompt_template.invoke({"messages": [*state["messages"], instruction]})
        return instruction, model.invoke(prompt)

    def make_plan(state: State):
        """Ask the plain model for a solution plan."""
        print("\n-------------------- Starting to create a plan --------------------\n")
        _wait(5)
        # Check the key that is actually read (previously ``content_type`` was
        # checked, so ``state["content"]`` could raise KeyError).
        if "content" in state:
            print("Content is: ", state["content"])
        instruction, response = _ask(
            llm, state, "Write a plan how to solve this question?"
        )
        print("The plan is: ", response.content)
        return {"messages": [instruction, response], "aggregate": ["Plan"]}

    def call_model(state: State):
        """Let the tool-bound model answer or request tool calls."""
        print("\n-------------------- Agent has been called -----------------------------------\n")
        _wait(5)
        instruction, response = _ask(
            llm_with_tools, state, "Please provide me the answer to the question in detail."
        )
        print("Agent has made a decision:\n", response.content, response.tool_calls)
        return {"messages": [instruction, response], "aggregate": ["Agent"]}

    def get_answer(state: State):
        """Ask the plain model for the bare final answer."""
        print("\n-------------------- Generating Answer -----------------------------------\n")
        _wait(5)
        instruction, response = _ask(
            llm, state, "Please provide me just the plain answer to the question"
        )
        print("The final answer is: ", response.content)
        return {"messages": [instruction, response], "aggregate": ["Answer"]}

    def should_continue(state: State):
        """Route to ``"tools"`` while tool calls are pending and the budget allows, else ``"Answer"``."""
        print("\n-------------------- Decision of forwarding has been made --------------------\n")
        _wait(2)
        messages = state["messages"]
        print("This is round: ", len(state["aggregate"]))
        print("The last message is: ", messages[-1])
        # NOTE(review): when the budget is exhausted, the last AI message may still
        # carry unanswered tool_calls when the Answer node runs; some providers
        # reject such histories — confirm with the configured model.
        if len(state["aggregate"]) < max_steps and getattr(messages[-1], "tool_calls", None):
            return "tools"
        return "Answer"

    # ------------------------------------------------------------------ graph
    builder = StateGraph(State)
    builder.add_node("tools", tool_node)
    builder.add_node("Plan", make_plan)
    builder.add_node("Agent", call_model)
    builder.add_node("Answer", get_answer)

    builder.add_edge(START, "Plan")
    builder.add_edge("Plan", "Agent")
    builder.add_conditional_edges("Agent", should_continue, ["tools", "Answer"])
    builder.add_edge("tools", "Agent")
    builder.add_edge("Answer", END)

    return builder.compile()