import operator
import os
import time
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import add_messages, START, END, StateGraph
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from typing_extensions import TypedDict, Annotated
class State(TypedDict):
    # Running chat history; add_messages merges new messages into the list.
    messages: Annotated[list, add_messages]
    content_type: str
    content: str
    # Ordered log of visited nodes; operator.add concatenates updates.
    aggregate: Annotated[list, operator.add]
    # graph_state: str
def get_llm():
    # The Gemini backend below needs GOOGLE_API_KEY in the environment;
    # the commented-out Groq model would need GROQ_API_KEY instead.
    assert os.getenv("GOOGLE_API_KEY"), "GOOGLE_API_KEY is not set"
    # return init_chat_model("llama-3.3-70b-versatile", model_provider="groq")
    return init_chat_model("gemini-2.0-flash", model_provider="google_genai")
def get_graph(llm):
    # Load the system prompt and build a chat template with a slot for the
    # running message history.
    with open('prompts/system_prompt.md', 'r', encoding='utf-8') as markdown_file:
        system_prompt = markdown_file.read()
    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    from langchain_community.retrievers import TavilySearchAPIRetriever, WikipediaRetriever

    # Wikipedia retriever
    wiki_retriever = WikipediaRetriever(load_max_docs=20)
    # Tavily retriever (requires TAVILY_API_KEY in the environment)
    tavily_retriever = TavilySearchAPIRetriever(k=3)
    @tool
    def retrieve(query: str):
        """Retrieve Wikipedia entries relevant to the query."""
        print("\n-------------------- Tool (Wikipedia) has been called --------------------\n")
        print("The query is: ", query)
        docs = wiki_retriever.invoke(query)
        # Serialize the retrieved documents into one string for the model.
        return "\n\n".join(f"\nContent:\n{doc.page_content}" for doc in docs)
    @tool
    def online_search(query: str):
        """Run a web search for the query."""
        print("\n-------------------- Tool (Tavily) has been called --------------------\n")
        print("The query is: ", query)
        docs = tavily_retriever.invoke(query)
        return "\n\n".join(f"\nContent:\n{doc.page_content}" for doc in docs)
    tools = [retrieve, online_search]
    tool_node = ToolNode(tools)
    llm_with_tools = llm.bind_tools(tools)
    def make_plan(state: State):
        print("\n-------------------- Starting to create a plan --------------------\n")
        print("Content type is: ", state["content_type"])
        # Extend a copy of the history rather than mutating the state in place.
        messages = state["messages"] + [
            HumanMessage(content="Write a plan for how to solve this question.")
        ]
        prompt = prompt_template.invoke({"messages": messages})
        response = llm.invoke(prompt)
        print("The plan is: ", response.content)
        return {"messages": [response], "aggregate": ["Plan"]}
    def call_model(state: State):
        print("\n-------------------- Agent has been called -----------------------------------\n")
        messages = state["messages"] + [
            HumanMessage(content="Please provide the answer to the question in detail.")
        ]
        prompt_answer = prompt_template.invoke({"messages": messages})
        # The tool-bound model may answer directly or request a tool call.
        response = llm_with_tools.invoke(prompt_answer)
        print("Agent has made a decision:\n", response.content, response.tool_calls)
        # Crude rate limiting between model calls.
        print("Waiting for 4 seconds...")
        time.sleep(4)
        return {"messages": [response], "aggregate": ["Agent"]}
    def get_answer(state: State):
        print("\n-------------------- Producing the final answer --------------------\n")
        messages = state["messages"] + [
            HumanMessage(content="Please provide just the plain answer to the question.")
        ]
        prompt_answer = prompt_template.invoke({"messages": messages})
        response = llm.invoke(prompt_answer)
        print("The final answer is: ", response.content)
        return {"messages": [response], "aggregate": ["Answer"]}
    def should_continue(state: State):
        print("\n-------------------- Decision of forwarding has been made --------------------\n")
        messages = state["messages"]
        print("This is round: ", len(state["aggregate"]))
        print("The last message is: ", messages[-1])
        last_message = messages[-1]
        # Route to the tool node only while the model keeps requesting tool
        # calls, capped at 8 rounds so the loop is guaranteed to terminate.
        if len(state["aggregate"]) < 8 and last_message.tool_calls:
            return "tools"
        return "Answer"
    # Build graph
    builder = StateGraph(State)
    builder.add_node("tools", tool_node)
    builder.add_node("Plan", make_plan)
    builder.add_node("Agent", call_model)
    builder.add_node("Answer", get_answer)
    # Logic
    builder.add_edge(START, "Plan")
    builder.add_edge("Plan", "Agent")
    builder.add_conditional_edges("Agent", should_continue, ["tools", "Answer"])
    builder.add_edge("tools", "Agent")
    builder.add_edge("Answer", END)
    return builder.compile()
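
# Minimal usage sketch (not part of the original module): the entry-point
# guard, sample question, and initial-state values are illustrative
# assumptions, shown only to demonstrate how the compiled graph is driven.
if __name__ == "__main__":
    graph = get_graph(get_llm())
    initial_state = {
        "messages": [HumanMessage(content="What is the capital of France?")],
        "content_type": "question",
        "content": "",
        "aggregate": [],
    }
    final_state = graph.invoke(initial_state)
    print(final_state["messages"][-1].content)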