import operator
import os
import time

from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, AnyMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import add_messages, START, END, StateGraph
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from typing_extensions import TypedDict, Annotated


class State(TypedDict):
    """Shared LangGraph state passed between the nodes built in ``get_graph``."""

    # Conversation history; the ``add_messages`` reducer merges the messages
    # each node returns into this list.
    messages: Annotated[list, add_messages]
    # Only logged by the planning node in ``get_graph``.
    content_type: str
    # Not read by the nodes built in ``get_graph``.
    content: str
    # Names of the nodes that have run ("Plan", "Agent", "Answer"). The
    # ``operator.add`` reducer concatenates, so len() counts completed steps.
    aggregate: Annotated[list, operator.add]


def get_llm(model="gemini-2.0-flash", model_provider="google_genai"):
    """Create the chat model used by the graph.

    Defaults to Gemini 2.0 Flash via Google GenAI. The Groq alternative
    previously used is ``get_llm("llama-3.3-70b-versatile", "groq")``
    (presumably needs ``GROQ_API_KEY`` set in the environment — confirm).

    Args:
        model: Model name forwarded to ``init_chat_model``.
        model_provider: Provider name forwarded to ``init_chat_model``.

    Returns:
        The chat model returned by ``init_chat_model``.
    """
    return init_chat_model(model, model_provider=model_provider)


def get_graph(
    llm,
    *,
    system_prompt_path="prompts/system_prompt.md",
    max_steps=8,
    agent_delay_seconds=4,
):
    """Build and compile the Plan -> Agent <-> tools -> Answer graph.

    ``Plan`` asks the model for a plan. ``Agent`` answers with access to a
    Wikipedia tool and a Tavily web-search tool. ``tools`` runs any requested
    tool calls and loops back to ``Agent``. ``Answer`` asks for the plain
    final answer.

    Args:
        llm: Chat model; must support ``bind_tools``.
        system_prompt_path: UTF-8 markdown file whose contents become the
            system message of every prompt.
        max_steps: Once ``state["aggregate"]`` holds this many entries, the
            agent loop stops and routes to ``Answer``.
        agent_delay_seconds: Pause after every agent call (presumably to stay
            under provider rate limits).

    Returns:
        The compiled ``StateGraph``.
    """
    with open(system_prompt_path, "r", encoding="utf-8") as markdown_file:
        system_prompt = markdown_file.read()

    # A SystemMessage is used verbatim. A ("system", text) tuple would be
    # parsed as a template, so any "{" / "}" in the markdown would turn into
    # an input variable and break prompt rendering.
    prompt_template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    from langchain_community.retrievers import WikipediaRetriever
    from langchain_community.retrievers import TavilySearchAPIRetriever

    wiki_retriever = WikipediaRetriever(load_max_docs=20)
    tavily_retriever = TavilySearchAPIRetriever(k=3)

    def _serialize(docs):
        """Join retrieved documents into the single string a tool returns."""
        return "\n\n".join(f"\nContent:\n{doc.page_content}" for doc in docs)

    # NOTE: the tool docstrings below are sent to the model as the tool
    # descriptions by ``@tool``; edit them as prompt text, not as code docs.
    @tool
    def retrieve(query: str):
        """
        This function retrieves Wikipedia entries based on the query.
        """
        print("\n-------------------- Tool (Wikipedia) has been called --------------------\n")
        print("The query is: ", query)
        return _serialize(wiki_retriever.invoke(query))

    @tool
    def online_search(query: str):
        """
        This function does a web search based on the query.
        """
        print("\n-------------------- Tool (Tavily) has been called --------------------\n")
        print("The query is: ", query)
        return _serialize(tavily_retriever.invoke(query))

    tools = [retrieve, online_search]
    tool_node = ToolNode(tools)
    llm_with_tools = llm.bind_tools(tools)

    def _build_prompt(history, instruction):
        """Render system message + ``history`` + a one-off ``instruction``.

        Builds a new list instead of appending to ``state["messages"]``.
        Mutating that list in place would bypass the ``add_messages``
        reducer; only each node's returned response should reach the state.
        """
        messages = [*history, HumanMessage(content=instruction)]
        return prompt_template.invoke({"messages": messages})

    def make_plan(state: State):
        """Ask the plain model (no tools) to write a plan for the question."""
        print("\n-------------------- Starting to create a plan --------------------\n")
        # .get: content_type is only logged here, so a missing key must not crash.
        print("Content is: ", state.get("content_type"))
        prompt = _build_prompt(state["messages"], "Write a plan how to solve this question?")
        response = llm.invoke(prompt)
        print("The plan is: ", response.content)
        return {"messages": [response], "aggregate": ["Plan"]}

    def call_model(state: State):
        """Ask the tool-enabled model for a detailed answer or tool calls."""
        print("\n-------------------- Agent has been called -----------------------------------\n")
        prompt_answer = _build_prompt(
            state["messages"], "Please provide me the answer to the question in detail."
        )
        response = llm_with_tools.invoke(prompt_answer)
        print("Agent has made a decision:\n", response.content, response.tool_calls)
        print(f"Waiting for {agent_delay_seconds} seconds...")
        time.sleep(agent_delay_seconds)
        return {"messages": [response], "aggregate": ["Agent"]}

    def get_answer(state: State):
        """Ask the plain model (no tools) for just the final answer."""
        history = list(state["messages"])
        if history and getattr(history[-1], "tool_calls", None):
            # Reached when should_continue hit max_steps while the agent still
            # requested tools. Those calls never ran. Chat APIs such as
            # OpenAI's reject a tool-call turn with no matching tool results,
            # so keep only the text of that turn for this final prompt.
            history[-1] = AIMessage(content=history[-1].content)
        prompt_answer = _build_prompt(
            history, "Please provide me just the plain answer to the question"
        )
        response = llm.invoke(prompt_answer)
        print("The final answer is: ", response.content)
        return {"messages": [response], "aggregate": ["Answer"]}

    def should_continue(state: State):
        """Route to "tools" if the agent requested tools and steps remain, else "Answer"."""
        print("\n-------------------- Decision of forwarding has been made --------------------\n")
        messages = state["messages"]
        steps = len(state["aggregate"])
        print("This is round: ", steps)
        print("The last message is: ", messages[-1])
        if steps < max_steps and messages[-1].tool_calls:
            return "tools"
        return "Answer"

    builder = StateGraph(State)
    builder.add_node("tools", tool_node)
    builder.add_node("Plan", make_plan)
    builder.add_node("Agent", call_model)
    builder.add_node("Answer", get_answer)

    builder.add_edge(START, "Plan")
    builder.add_edge("Plan", "Agent")
    builder.add_conditional_edges("Agent", should_continue, ["tools", "Answer"])
    builder.add_edge("tools", "Agent")
    builder.add_edge("Answer", END)

    return builder.compile()