Spaces:
Sleeping
Sleeping
| import operator | |
| import os | |
| import time | |
| from langchain.chat_models import init_chat_model | |
| from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, AnyMessage | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langgraph.graph import add_messages, START, END, StateGraph | |
| from langchain_core.tools import tool | |
| from langgraph.prebuilt import ToolNode | |
| from typing_extensions import TypedDict, Annotated | |
class State(TypedDict):
    """Shared state passed between the nodes of the agent graph."""

    # Conversation history; the add_messages reducer merges the messages each
    # node returns into this list.
    messages: Annotated[list, add_messages]
    # Caller-supplied content type; only printed by the "Plan" node.
    content_type: str
    # Caller-supplied content; not read by the nodes defined in get_graph.
    content: str
    # Names of the nodes that have run ("Plan", "Agent", "Answer"), concatenated
    # via operator.add; its length is the step counter used when routing.
    aggregate: Annotated[list[str], operator.add]
def get_llm(model="gemini-2.0-flash", model_provider="google_genai"):
    """Create the chat model used by the agent graph.

    Args:
        model: Model name passed to ``init_chat_model``.
            Defaults to ``"gemini-2.0-flash"``.
        model_provider: LangChain provider key passed to ``init_chat_model``.
            Defaults to ``"google_genai"``.

    Returns:
        The chat model instance created by ``init_chat_model``.
    """
    # API credentials are read from the environment by the provider integration.
    # Another provider can be selected without editing this function, e.g.:
    #   get_llm("llama-3.3-70b-versatile", model_provider="groq")  # uses GROQ_API_KEY
    return init_chat_model(model, model_provider=model_provider)
def get_graph(llm, max_steps: int = 8):
    """Build and compile the question-answering agent graph.

    Flow: START -> "Plan" -> "Agent" -> ("tools" -> "Agent")* -> "Answer" -> END.
    "Agent" may call two retrieval tools (Wikipedia and Tavily web search). The
    router sends it through "tools" while it keeps requesting tool calls and the
    step budget is not exhausted; otherwise it goes on to "Answer".

    Args:
        llm: A LangChain chat model (e.g. from ``get_llm()``). It must support
            ``invoke`` and ``bind_tools``.
        max_steps: Routing budget. Once ``state["aggregate"]`` (one entry per
            "Plan"/"Agent" run) reaches this length, "Agent" is routed to
            "Answer" even if it requested more tools. Defaults to 8.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.

    Raises:
        OSError: If ``prompts/system_prompt.md`` cannot be read.
    """
    # NOTE: this relative path is resolved against the current working directory.
    with open('prompts/system_prompt.md', 'r', encoding='utf-8') as markdown_file:
        system_prompt = markdown_file.read()

    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    # Imported here, so they are only needed once a graph is actually built.
    from langchain_community.retrievers import WikipediaRetriever
    from langchain_community.retrievers import TavilySearchAPIRetriever

    # Wikipedia retriever
    wiki_retriever = WikipediaRetriever(load_max_docs=20)
    # Tavily web-search retriever
    tavily_retriever = TavilySearchAPIRetriever(k=3)

    def _serialize_docs(docs):
        # Join retrieved documents into one text block returned to the model.
        return "\n\n".join(f"\nContent:\n{doc.page_content}" for doc in docs)

    # bind_tools / ToolNode derive each tool's schema from the function name,
    # signature and docstring, so the names and docstrings below are effectively
    # prompt text seen by the model. Do not rename them casually.
    def retrieve(query: str):
        """
        This function retrieves Wikipedia entries based on the query.
        """
        print("\n-------------------- Tool (Wikipedia) has been called --------------------\n")
        print("The query is: ", query)
        return _serialize_docs(wiki_retriever.invoke(query))

    def online_search(query: str):
        """
        This function does a web search based on the query.
        """
        print("\n-------------------- Tool (Tavily) has been called --------------------\n")
        print("The query is: ", query)
        return _serialize_docs(tavily_retriever.invoke(query))

    tools = [retrieve, online_search]
    tool_node = ToolNode(tools)
    llm_with_tools = llm.bind_tools(tools)

    def _build_prompt(state, instruction):
        # Prompt = existing history + one extra instruction. state["messages"] is
        # NOT mutated: LangGraph state must only change through the returned
        # update so the add_messages reducer applies.
        return prompt_template.invoke({"messages": [*state["messages"], instruction]})

    def make_plan(state: State):
        print("\n-------------------- Starting to create a plan --------------------\n")
        print("Content is: ", state.get("content_type"))
        instruction = HumanMessage(content="Write a plan how to solve this question?")
        response = llm.invoke(_build_prompt(state, instruction))
        print("The plan is: ", response.content)
        # Persist the instruction with the reply so later turns know what the
        # plan was answering.
        return {"messages": [instruction, response], "aggregate": ["Plan"]}

    def call_model(state: State):
        print("\n-------------------- Agent has been called -----------------------------------\n")
        instruction = HumanMessage(content="Please provide me the answer to the question in detail.")
        response = llm_with_tools.invoke(_build_prompt(state, instruction))
        print("Agent has made a decision:\n", response.content, response.tool_calls)
        # Fixed pause between agent turns (presumably to respect provider rate
        # limits - confirm).
        print("Waiting for 4 seconds...")
        time.sleep(4)
        return {"messages": [instruction, response], "aggregate": ["Agent"]}

    def get_answer(state: State):
        instruction = HumanMessage(content="Please provide me just the plain answer to the question")
        response = llm.invoke(_build_prompt(state, instruction))
        print("The final answer is: ", response.content)
        return {"messages": [instruction, response], "aggregate": ["Answer"]}

    def should_continue(state: State):
        print("\n-------------------- Decision of forwarding has been made --------------------\n")
        messages = state["messages"]
        steps = len(state["aggregate"])
        print("This is round: ", steps)
        print("The last message is: ", messages[-1])
        # Go to the tools only while under budget and the agent asked for tools.
        if steps < max_steps and getattr(messages[-1], "tool_calls", None):
            return "tools"
        # NOTE(review): if the budget runs out while tool_calls are pending,
        # "Answer" sees an AIMessage whose tool calls were never answered; some
        # providers reject such histories - confirm with the chosen model.
        return "Answer"

    # Build graph
    builder = StateGraph(State)
    builder.add_node("tools", tool_node)
    builder.add_node("Plan", make_plan)
    builder.add_node("Agent", call_model)
    builder.add_node("Answer", get_answer)

    # Logic
    builder.add_edge(START, "Plan")
    builder.add_edge("Plan", "Agent")
    builder.add_conditional_edges("Agent", should_continue, ["tools", "Answer"])
    builder.add_edge("tools", "Agent")
    builder.add_edge("Answer", END)
    return builder.compile()