|
|
import os |
|
|
from typing import TypedDict, Annotated |
|
|
from langgraph.graph.message import add_messages |
|
|
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage |
|
|
from langchain_groq import ChatGroq |
|
|
from langgraph.prebuilt import ToolNode |
|
|
from langgraph.graph import START, StateGraph, END |
|
|
from langgraph.prebuilt import tools_condition |
|
|
from tools import (retriever, web_search, wiki_search, youtube_analysis, |
|
|
add_numbers, subtract_numbers, multiply_numbers, divide_numbers, modulus_numbers, |
|
|
detect_objects, run_python |
|
|
) |
|
|
from prompt import text_prompt |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load variables from a local .env file (if present) into the process
# environment. NOTE: load_dotenv was imported but never called before —
# without this call, keys defined only in .env are invisible to os.getenv.
load_dotenv()

# Mirror the key into os.environ explicitly and fail fast with a clear
# message; the original `os.environ[...] = os.getenv(...)` raised an opaque
# TypeError (assigning None) whenever the key was missing.
_groq_api_key = os.getenv("GROQ_API_KEY")
if _groq_api_key is None:
    raise RuntimeError(
        "GROQ_API_KEY is not set; define it in the environment or in a .env file."
    )
os.environ["GROQ_API_KEY"] = _groq_api_key
|
|
|
|
|
class State(TypedDict):
    """Shared graph state: the running conversation threaded through all nodes."""

    # add_messages is a LangGraph reducer: values returned by a node are
    # appended to the existing message list instead of replacing it, so the
    # conversation history accumulates across agent/tool steps.
    messages: Annotated[list[AnyMessage], add_messages]
|
|
|
|
|
# Every tool exposed to the model, grouped by purpose. The same list backs
# both bind_tools (so the LLM can request them) and the ToolNode executor
# (so every requested tool can actually be run).
_research_tools = [retriever, web_search, wiki_search, youtube_analysis]
_math_tools = [
    add_numbers,
    subtract_numbers,
    multiply_numbers,
    divide_numbers,
    modulus_numbers,
]
_misc_tools = [detect_objects, run_python]

tools = _research_tools + _math_tools + _misc_tools
|
|
|
|
|
# Groq-hosted DeepSeek R1 distill model used as the agent brain.
model = "deepseek-r1-distill-llama-70b"

# Deterministic output (temperature 0), no token cap, no client timeout,
# up to 2 retries on transient failures. reasoning_format="parsed" keeps
# the model's chain-of-thought separate from the final answer content.
_llm_settings = {
    "model": model,
    "temperature": 0.0,
    "max_tokens": None,
    "reasoning_format": "parsed",
    "timeout": None,
    "max_retries": 2,
}
llm = ChatGroq(**_llm_settings)

# Attach tool schemas so the model can emit structured tool calls.
llm_with_tools = llm.bind_tools(tools)
|
|
|
|
|
def ask_agent(agent_state: State):
    """Single LLM step: invoke the tool-bound model on the conversation.

    Args:
        agent_state: Current graph state; ``agent_state["messages"]`` holds
            the accumulated conversation (human, AI, and tool messages).

    Returns:
        A partial state update ``{"messages": [response]}``; the add_messages
        reducer appends the response to the history.
    """
    system_prompt = SystemMessage(content=text_prompt)
    # Pass the system prompt plus the FULL message history to the model.
    # The original code built a bare string from only the last message
    # (text_prompt + query.content), which (a) never used the SystemMessage
    # it constructed, (b) dropped all prior turns, and (c) discarded tool
    # results when re-entering this node after the "tools" edge.
    response = llm_with_tools.invoke([system_prompt] + agent_state["messages"])
    return {"messages": [response]}
|
|
|
|
|
# Wire the ReAct-style loop: the agent node runs the model; whenever the
# model's reply contains tool calls, control hops to the tool executor and
# then back to the agent, until a reply with no tool calls ends the run.
graph_builder = StateGraph(State)
graph_builder.add_node("agent", ask_agent)
graph_builder.add_node("tools", ToolNode(tools))

graph_builder.add_edge(START, "agent")
# tools_condition routes to "tools" when the last AI message requested a
# tool, and to END otherwise.
graph_builder.add_conditional_edges("agent", tools_condition)
graph_builder.add_edge("tools", "agent")

# Compiled runnable graph; invoke with {"messages": [...]}.
alfred = graph_builder.compile()