Final_Assignment / agent.py
dennis111's picture
update
3a2c53a
raw
history blame
1.57 kB
import operator
import os
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, AnyMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import add_messages, START, END, StateGraph
from typing_extensions import TypedDict, Annotated
class State(TypedDict):
    """Graph state passed between LangGraph nodes.

    Each field is ``Annotated`` with a reducer that LangGraph uses to merge a
    node's partial update into the current value instead of overwriting it.
    """

    # Conversation history; merged with langgraph's ``add_messages`` reducer.
    messages: Annotated[list, add_messages]
    # Node-visit trace; ``operator.add`` concatenates lists, so every update
    # is appended to what is already there.
    aggregate: Annotated[list, operator.add]
def get_llm():
    """Return the chat model used by the agent (Llama 3.3 70B served by Groq).

    Returns:
        The chat model object produced by ``init_chat_model``.

    Raises:
        RuntimeError: if the ``GROQ_API_KEY`` environment variable is unset
            or empty.
    """
    # The original code called os.getenv(...) and discarded the result, which
    # checked nothing. The Groq integration reads the key from the environment
    # on its own; checking here makes a missing key fail fast with an
    # actionable message instead of a provider-side error.
    if not os.getenv("GROQ_API_KEY"):
        raise RuntimeError(
            "GROQ_API_KEY environment variable is not set; it is required "
            "to use the Groq chat model."
        )
    return init_chat_model("llama-3.3-70b-versatile", model_provider="groq")
def get_graph(llm, system_prompt_path="prompts/system_prompt.md"):
    """Build and compile a single-node LangGraph agent around ``llm``.

    The graph is ``START -> "Agent" -> END``. The "Agent" node prepends the
    system prompt to the conversation in ``state["messages"]``, calls ``llm``
    once, and returns the reply as a state update.

    Args:
        llm: Chat model exposing ``invoke`` (for example, from ``get_llm()``).
        system_prompt_path: Path to the markdown system prompt. A relative
            path is resolved against the current working directory, exactly
            as the previous hard-coded path was.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.

    Raises:
        OSError: if the system prompt file cannot be opened.
    """
    with open(system_prompt_path, "r", encoding="utf-8") as markdown_file:
        system_prompt = markdown_file.read()

    # Use a SystemMessage instead of a ("system", text) tuple. Tuples are
    # compiled as f-string templates, so any literal "{" or "}" in the
    # markdown would be parsed as a template variable and break invoke().
    # A message instance is inserted verbatim.
    prompt_template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    def call_model(state: State):
        """Graph node: run ``llm`` on the system prompt plus conversation."""
        print("\n-------------------- Agent has been called -----------------------------------\n")
        # Name the template variable explicitly rather than relying on the
        # template happening to have exactly one input variable.
        prompt = prompt_template.invoke({"messages": state["messages"]})
        print("\nThe Prompt is: ", prompt, "\n")
        response = llm.invoke(prompt)
        print("Agent has made a decision: ", response.content)
        # State declares "aggregate" with operator.add, so this appends
        # "Agent" to the visit trace rather than replacing it.
        return {"messages": [response], "aggregate": ["Agent"]}

    # Single-node graph: START -> Agent -> END.
    builder = StateGraph(State)
    builder.add_node("Agent", call_model)
    builder.add_edge(START, "Agent")
    builder.add_edge("Agent", END)
    return builder.compile()