# Hugging Face file-viewer header captured with this file (kept as a comment so
# the module parses): author skyliulu, commit 4d8ca01 "default groq", 3.02 kB.
import os
from typing import TypedDict, Annotated
from dotenv import load_dotenv
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.prebuilt import ToolNode
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
import requests
from tools import *
# Load API keys (and any other settings) from a .env file into the process
# environment at import time, before buildAgent() constructs any model client.
load_dotenv()
def buildAgent(provider: str = "groq"):
    """Build and compile the tool-calling LangGraph agent.

    Graph layout::

        START -> retriever -> assistant -> (tools -> assistant)* -> finish

    ``retriever`` injects the system prompt, ``assistant`` calls the chat
    model with the tools bound, and ``tools_condition`` routes to the
    ``tools`` node whenever the model's latest reply requests a tool call;
    otherwise the run finishes with the model's direct answer.

    Args:
        provider: Chat-model backend: ``"groq"`` (default, model
            ``qwen-qwq-32b``) or ``"huggingface"`` (endpoint
            ``Qwen/Qwen2.5-Coder-32B-Instruct``).

    Returns:
        The compiled graph; run it with ``.invoke({"messages": [...]})``.

    Raises:
        OSError: If ``system_prompt.txt`` cannot be read (the path is
            relative to the current working directory).
        ValueError: If ``provider`` is not a supported value.
    """
    # Load the system prompt (path is relative to the current working dir).
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        system_prompt = f.read()
    print(system_prompt)
    sys_msg = SystemMessage(content=system_prompt)

    # Pick the chat model for the requested provider.
    if provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
        )
    elif provider == "groq":
        llm = ChatGroq(model="qwen-qwq-32b")
    else:
        raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")

    # Tool functions come from the local `tools` module (star-imported at the
    # top of the file).
    agent_tools = [
        multiply,
        add,
        subtract,
        divide,
        modulus,
        power,
        square_root,
        web_search,
        wiki_search,
        arxiv_search,
        download_file,
    ]
    # The same list is bound to the model (so it can emit tool calls) and
    # handed to ToolNode below (so those calls can be executed).
    chat_with_tools = llm.bind_tools(agent_tools)

    def assistant(state: MessagesState):
        """Call the tool-bound model on the conversation so far."""
        history = state["messages"]
        # The retriever's system message ends up *after* the user's question
        # in the state: MessagesState's add_messages reducer keeps messages
        # whose id already exists in place and appends new ones at the end.
        # Chat models expect the system prompt first, so send system messages
        # first; the stored state itself is left untouched.
        ordered = [m for m in history if isinstance(m, SystemMessage)] + [
            m for m in history if not isinstance(m, SystemMessage)
        ]
        return {"messages": [chat_with_tools.invoke(ordered)]}

    # TODO: add RAG retrieval here.
    def retriever(state: MessagesState):
        """Retriever node: for now it only injects the system prompt."""
        # NOTE: after the add_messages merge, sys_msg lands at the END of the
        # state (existing messages keep their position), hence the reordering
        # done in `assistant`.
        return {"messages": [sys_msg] + state["messages"]}

    # --- Graph wiring ---
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(agent_tools))

    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" if the latest message requests a tool call; otherwise
    # the assistant's reply is the final answer.
    builder.add_conditional_edges("assistant", tools_condition)
    # Tool results always go back to the assistant for the next step.
    builder.add_edge("tools", "assistant")
    return builder.compile()
if __name__ == "__main__":
    # Smoke test: fetch one random question from the course scoring API, run
    # the agent on it, and print the full message trace.
    random_question_url = "https://agents-course-unit4-scoring.hf.space/random-question"
    response = requests.get(random_question_url, timeout=15)
    # Fail loudly on HTTP errors instead of trying to parse an error body.
    response.raise_for_status()
    questions_data = response.json()
    question = questions_data.get("question")
    if not question:
        # Without this, HumanMessage(content=None) fails with an opaque
        # validation error.
        raise SystemExit(f"No question in API response: {questions_data!r}")
    graph = buildAgent(provider="groq")
    messages = [HumanMessage(content=question)]
    print(messages)
    result = graph.invoke({"messages": messages})
    for m in result["messages"]:
        m.pretty_print()