import operator
import os
import time
from typing import Optional

from langchain.chat_models import init_chat_model
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader, YoutubeLoader
from langchain_community.tools import TavilySearchResults
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langgraph.graph import add_messages, START, END, StateGraph
from langgraph.prebuilt import ToolNode
from pydantic import SecretStr
from typing_extensions import TypedDict, Annotated

from langchain_custom import WikipediaTableLoader


class State(TypedDict):
    """State shared by all nodes of the graph built in ``get_graph``."""

    # Conversation history; updates are merged by the ``add_messages`` reducer.
    messages: Annotated[list, add_messages]
    content_type: Optional[str]
    content: Optional[str]
    # One tag per executed Plan/Agent/Answer node; lists are concatenated.
    aggregate: Annotated[list, operator.add]
    # graph_state: str


def get_llm():
    """Return the chat model used by the graph (Gemini 2.0 Flash via google_genai)."""
    # return init_chat_model("llama-3.3-70b-versatile", model_provider="groq")
    return init_chat_model("gemini-2.0-flash", model_provider="google_genai")
    # return AzureChatOpenAI(
    #     api_key=SecretStr(os.environ["AZURE_OPENAI_API_KEY"]),
    #     azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    #     azure_deployment="gpt-4o-mini",
    #     api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    # )


def _announce_tool(name: str) -> None:
    """Print the console banner that marks a tool invocation."""
    print(f"\n-------------------- Tool ({name}) has been called --------------------\n")


def get_graph(llm, *, max_rounds: int = 8):
    """Build and compile the agent graph.

    The flow is START -> Plan -> Agent -> (tools -> Agent)* -> Answer -> END.

    Args:
        llm: Chat model. It is invoked directly for planning and for the final
            answer, and with the tools bound for the Agent node.
        max_rounds: Once ``state["aggregate"]`` holds this many entries the
            Agent's tool calls are no longer executed and the graph moves on
            to the Answer node.

    Returns:
        The compiled graph.
    """
    with open('prompts/system_prompt.md', 'r', encoding='utf-8') as markdown_file:
        system_prompt = markdown_file.read()

    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    from langchain_community.retrievers import WikipediaRetriever

    # Wikipedia retriever
    wiki_retriever = WikipediaRetriever()

    def _tavily_results(query: str):
        """Run a Tavily search limited to 3 results and return the raw tool output."""
        return TavilySearchResults(max_results=3).invoke({'query': query})

    @tool
    def multiply(a: int, b: int) -> int:
        """Multiply two numbers.

        Args:
            a: first int
            b: second int
        """
        _announce_tool("Multiplication")
        return a * b

    @tool
    def add(a: int, b: int) -> int:
        """Add two numbers.

        Args:
            a: first int
            b: second int
        """
        _announce_tool("Addition")
        return a + b

    @tool
    def subtract(a: int, b: int) -> int:
        """Subtract two numbers.

        Args:
            a: first int
            b: second int
        """
        _announce_tool("Subtraction")
        return a - b

    @tool
    def divide(a: int, b: int) -> float:
        """Divide two numbers.

        Args:
            a: first int
            b: second int
        """
        _announce_tool("Division")
        if b == 0:
            raise ValueError("Cannot divide by zero.")
        return a / b

    @tool
    def modulus(a: int, b: int) -> int:
        """Get the modulus of two numbers.

        Args:
            a: first int
            b: second int
        """
        _announce_tool("Modulus")
        # Same explicit zero check as ``divide``.
        if b == 0:
            raise ValueError("Cannot take modulus by zero.")
        return a % b

    @tool
    def retrieve(query: str):
        """
        This function retrieves Wikipedia entries based on the query.
        """
        _announce_tool("Wikipedia")
        print("The query is: ", query)
        docs = wiki_retriever.invoke(query)
        return "\n\n".join(f"\nContent:\n{doc.page_content}" for doc in docs)

    @tool
    def wiki_search(query: str) -> str:
        """Search Wikipedia for a query and return maximum 2 results.

        Args:
            query: The search query."""
        _announce_tool("Wikipedia")
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        parts: list[str] = []
        for doc in search_docs:
            parts.append(f'\n{doc.page_content}\n')
            # Tables are a best-effort extra: a failure is reported on the
            # console and the page text collected so far is still returned.
            try:
                print("---------------------------------")
                print("Loading tables from: ", doc.metadata["source"])
                print("---------------------------------")
                tables = WikipediaTableLoader(
                    url=doc.metadata["source"], title=doc.metadata["title"]
                ).load()
                for table in tables:
                    parts.append(f'\n{table.page_content}\n')
            except Exception as exc:
                print("Could not load tables: ", exc)
        return "\n\n---\n\n".join(parts)

    @tool
    def wiki_table_search(url: str, title: str) -> str:
        """Get Wikipedia tables for a given URL and title.

        Args:
            url: The Wikipedia URL.
            title: The title of the Wikipedia page."""
        _announce_tool("Wikipedia-Table")
        search_docs = WikipediaTableLoader(url=url, title=title).load()
        return "\n\n---\n\n".join(f'\n{doc.page_content}\n' for doc in search_docs)

    @tool
    def online_search(query: str):
        """
        This function does a web search based on the query.
        """
        _announce_tool("Tavily")
        print("The query is: ", query)
        docs = _tavily_results(query)
        if isinstance(docs, str):
            # The Tavily tool reports a failed request as a plain string.
            return docs
        # Results are dicts, not Document objects, so read the "content" key.
        return "\n\n".join(f"\nContent:\n{doc['content']}" for doc in docs)

    @tool
    def web_search(query: str) -> str:
        """Search Tavily for a query and return maximum 3 results.

        Args:
            query: The search query."""
        _announce_tool("Tavily")
        search_docs = _tavily_results(query)
        if isinstance(search_docs, str):
            # The Tavily tool reports a failed request as a plain string.
            return search_docs
        # NOTE(review): some langchain_community releases appear to omit
        # "title" from the cleaned results, hence .get() - confirm.
        return "\n\n---\n\n".join(
            f'URL: {doc["url"]}\nTitle= {doc.get("title", "")}\nContent: {doc["content"]}'
            for doc in search_docs
        )

    @tool
    def arvix_search(query: str) -> str:
        """Search Arxiv for a query and return maximum 3 result.

        Args:
            query: The search query."""
        _announce_tool("Arxiv")
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        # Only the first 1000 characters of each paper are returned.
        return "\n\n---\n\n".join(f'\n{doc.page_content[:1000]}\n' for doc in search_docs)

    @tool
    def youtube_transcript(url: str) -> str:
        """Download a transcript of a YouTube video.

        Args:
            url: URL of the YouTube video."""
        _announce_tool("YouTube Transcript")
        loader = YoutubeLoader.from_youtube_url(url, add_video_info=False)
        docs = loader.load()
        return "\n\n".join(f"\nContent:\n{doc.page_content}" for doc in docs)

    # ``retrieve``, ``wiki_table_search`` and ``online_search`` are defined
    # above but not offered to the model.
    tools = [
        wiki_search,
        web_search,
        arvix_search,
        youtube_transcript,
        multiply,
        add,
        subtract,
        divide,
        modulus,
    ]
    tool_node = ToolNode(tools)
    llm_with_tools = llm.bind_tools(tools)

    def _ask(model, state: State, instruction: str):
        """Invoke ``model`` on the history followed by a one-off instruction.

        Returns the instruction message and the model response. The state's
        message list is copied, never mutated; callers return both messages
        so the ``add_messages`` reducer records them in the history.
        """
        request = HumanMessage(content=instruction)
        prompt = prompt_template.invoke({"messages": [*state["messages"], request]})
        return request, model.invoke(prompt)

    def make_plan(state: State):
        """Plan node: ask the plain LLM for a plan to solve the question."""
        print("\n-------------------- Starting to create a plan --------------------\n")
        print("Waiting for 5 seconds...")
        time.sleep(5)
        if "content_type" in state:
            # .get(): "content" may be missing even when "content_type" is set.
            print("Content is: ", state.get("content"))
        request, response = _ask(llm, state, "Write a plan how to solve this qustion?")
        print("The plan is: ", response.content)
        return {"messages": [request, response], "aggregate": ["Plan"]}

    def call_model(state: State):
        """Agent node: ask the tool-enabled LLM for a detailed answer."""
        print("\n-------------------- Agent has been called -----------------------------------\n")
        print("Waiting for 5 seconds...")
        time.sleep(5)
        request, response = _ask(
            llm_with_tools,
            state,
            "Please provide me the answer to the question in detail.",
        )
        print("Agent has made a decision:\n", response.content, response.tool_calls)
        return {"messages": [request, response], "aggregate": ["Agent"]}

    def get_answer(state: State):
        """Answer node: ask the plain LLM for just the final answer."""
        print("\n-------------------- Generating Answer -----------------------------------\n")
        print("Waiting for 5 seconds...")
        time.sleep(5)
        request, response = _ask(
            llm,
            state,
            "Please provide me just the plain answer to the question",
        )
        print("The final answer is: ", response.content)
        return {"messages": [request, response], "aggregate": ["Answer"]}

    def should_continue(state: State):
        """Route after the Agent node: "tools" or "Answer".

        Tool calls are executed only while fewer than ``max_rounds`` entries
        are in ``state["aggregate"]``; every other case goes to "Answer".
        """
        print("\n-------------------- Decision of forwarding has been made --------------------\n")
        print("Waiting for 2 seconds...")
        time.sleep(2)
        last_message = state["messages"][-1]
        rounds = len(state["aggregate"])
        print("This is round: ", rounds)
        print("The last message is: ", last_message)
        if rounds < max_rounds and last_message.tool_calls:
            return "tools"
        return "Answer"

    # Build graph
    builder = StateGraph(State)
    builder.add_node("tools", tool_node)
    builder.add_node("Plan", make_plan)
    builder.add_node("Agent", call_model)
    builder.add_node("Answer", get_answer)

    # Logic
    builder.add_edge(START, "Plan")
    builder.add_edge("Plan", "Agent")
    builder.add_conditional_edges("Agent", should_continue, ["tools", "Answer"])
    builder.add_edge("tools", "Agent")
    builder.add_edge("Answer", END)

    return builder.compile()