# Final_Assignment/agent.py
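"""LangGraph agent for the final assignment.

The agent first drafts a plan for the incoming question, then loops between a
tool-calling LLM node and two retrieval tools (a Wikipedia retriever and a
Tavily web search), and finally produces a plain answer once the model stops
requesting tools or the round limit is reached.
"""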
import operator
import os
import time
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, AnyMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import add_messages, START, END, StateGraph
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from typing_extensions import TypedDict, Annotated
class State(TypedDict):
    # Conversation history, merged with the add_messages reducer.
    messages: Annotated[list, add_messages]
    content_type: str
    content: str
    # Labels of the nodes visited so far; used to cap the number of rounds.
    aggregate: Annotated[list, operator.add]
    # graph_state: str
def get_llm():
    # The Groq variant is kept for reference; Gemini is the model in use.
    # os.getenv("GROQ_API_KEY")
    # return init_chat_model("llama-3.3-70b-versatile", model_provider="groq")
    # The google_genai provider reads GOOGLE_API_KEY from the environment.
    return init_chat_model("gemini-2.0-flash", model_provider="google_genai")
def get_graph(llm):
    # Load the system prompt and wrap it in a chat prompt template.
    with open('prompts/system_prompt.md', 'r', encoding='utf-8') as markdown_file:
        system_prompt = markdown_file.read()
    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    from langchain_community.retrievers import WikipediaRetriever
    from langchain_community.retrievers import TavilySearchAPIRetriever
    # Wikipedia retriever
    wiki_retriever = WikipediaRetriever(load_max_docs=20)
    # Tavily retriever
    tavily_retriever = TavilySearchAPIRetriever(k=3)
    @tool
    def retrieve(query: str):
        """
        This function retrieves Wikipedia entries based on the query.
        """
        print("\n-------------------- Tool (Wikipedia) has been called --------------------\n")
        print("The query is: ", query)
        docs = wiki_retriever.invoke(query)
        serialized = "\n\n".join(
            f"\nContent:\n{doc.page_content}"
            for doc in docs
        )
        return serialized
    @tool
    def online_search(query: str):
        """
        This function does a web search based on the query.
        """
        print("\n-------------------- Tool (Tavily) has been called --------------------\n")
        print("The query is: ", query)
        docs = tavily_retriever.invoke(query)
        serialized = "\n\n".join(
            f"\nContent:\n{doc.page_content}"
            for doc in docs
        )
        return serialized
    tools = [retrieve, online_search]
    tool_node = ToolNode(tools)
    llm_with_tools = llm.bind_tools(tools)
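    # The nodes below form the agent loop: "Plan" drafts an approach,
    # "Agent" repeatedly calls the tool-enabled model (tool requests are
    # routed through the ToolNode), and "Answer" extracts the final plain
    # answer once should_continue stops the loop.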
    def make_plan(state: State):
        print("\n-------------------- Starting to create a plan --------------------\n")
        print("Content type is: ", state["content_type"])
        # get all messages from the state
        messages = state["messages"]
        # append planning message
        messages.append(HumanMessage(content="Write a plan for how to solve this question."))
        # create prompt
        prompt = prompt_template.invoke({"messages": messages})
        # invoke LLM
        response = llm.invoke(prompt)
        print("The plan is: ", response.content)
        return {"messages": [response], "aggregate": ["Plan"]}
    def call_model(state: State):
        print("\n-------------------- Agent has been called -----------------------------------\n")
        # get all messages from the state
        messages = state["messages"]
        # append instruction message
        messages.append(HumanMessage(content="Please provide me the answer to the question in detail."))
        # create prompt
        prompt_answer = prompt_template.invoke({"messages": messages})
        # invoke LLM with the retrieval tools bound
        response = llm_with_tools.invoke(prompt_answer)
        print("Agent has made a decision:\n", response.content, response.tool_calls)
        # brief pause between model calls
        print("Waiting for 4 seconds...")
        time.sleep(4)
        return {"messages": [response], "aggregate": ["Agent"]}
    def get_answer(state: State):
        # get all messages from the state
        messages = state["messages"]
        # add prompt message
        messages.append(HumanMessage(content="Please provide me just the plain answer to the question"))
        # create prompt
        prompt_answer = prompt_template.invoke({"messages": messages})
        # invoke LLM
        response = llm.invoke(prompt_answer)
        print("The final answer is: ", response.content)
        return {"messages": [response], "aggregate": ["Answer"]}
    def should_continue(state: State):
        print("\n-------------------- Routing decision has been made --------------------\n")
        messages = state["messages"]
        print("This is round: ", len(state["aggregate"]))
        print("The last message is: ", messages[-1])
        # keep routing to the tools while the model requests them, up to 8 rounds
        if len(state["aggregate"]) < 8:
            last_message = messages[-1]
            if last_message.tool_calls:
                return "tools"
            return "Answer"
        else:
            return "Answer"
    # Build graph
    builder = StateGraph(State)
    builder.add_node("tools", tool_node)
    builder.add_node("Plan", make_plan)
    builder.add_node("Agent", call_model)
    builder.add_node("Answer", get_answer)
    # Logic
    builder.add_edge(START, "Plan")
    builder.add_edge("Plan", "Agent")
    builder.add_conditional_edges("Agent", should_continue, ["tools", "Answer"])
    builder.add_edge("tools", "Agent")
    builder.add_edge("Answer", END)
    return builder.compile()
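

# Minimal usage sketch: assumes GOOGLE_API_KEY and TAVILY_API_KEY are set in
# the environment and that prompts/system_prompt.md exists next to this file.
if __name__ == "__main__":
    llm = get_llm()
    graph = get_graph(llm)
    result = graph.invoke(
        {
            "messages": [HumanMessage(content="Who wrote 'On the Origin of Species'?")],
            "content_type": "text",
            "content": "",
            "aggregate": [],
        }
    )
    print("Final answer:", result["messages"][-1].content)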