Luigi D'Addona
aggiunto file agent.py con la definizione dell'agent
7a786af
raw
history blame
3.94 kB
import os
from dotenv import load_dotenv
import traceback
from typing import Annotated,Sequence, TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages # helper function to add messages to the state
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph, END
from langchain_google_genai import ChatGoogleGenerativeAI
# Local imports
from tools import get_search_tool, get_wikipedia_tool
# Note: for local testing a .env file is used;
# on HuggingFace the variables defined under Settings/"Variables and secrets" are used instead.
load_dotenv()
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL")
# NOTE(review): GEMINI_BASE_URL is read here but never used in this file — confirm it is still needed.
GEMINI_BASE_URL = os.environ.get("GEMINI_BASE_URL")
#
# Initialize the model and bind the tools to it
#
# ChatGoogleGenerativeAI is the official LangChain package for interacting with Gemini models
# https://python.langchain.com/docs/integrations/chat/google_generative_ai/
chat = ChatGoogleGenerativeAI(
    model=GEMINI_MODEL,
    google_api_key=GEMINI_API_KEY)
# Set up the tools
search_tool = get_search_tool()
wikipedia_tool = get_wikipedia_tool()
tools = [search_tool, wikipedia_tool]
# Bind tools to the model so it can emit tool calls for them
model = chat.bind_tools(tools)
# Lookup table used by the tool node to resolve a tool call by its name
tools_by_name = {tool.name: tool for tool in tools}
#
# Graph definition
#
class AgentState(TypedDict):
    """The state of the agent, carried between graph nodes."""

    # Conversation history; the add_messages reducer appends each node's
    # returned messages instead of replacing the whole list.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # NOTE(review): number_of_steps is never read or written by the nodes
    # visible in this file — confirm whether it is actually needed.
    number_of_steps: int
# Tool node: executes the tool calls requested by the model's last message.
def call_tool(state: AgentState):
    """Run every tool call found in the last message and collect the results.

    Returns:
        A dict whose ``"messages"`` list of ``ToolMessage`` objects is merged
        into the state by the ``add_messages`` reducer.
    """
    outputs = []
    # Iterate over the tool calls in the last message
    for tool_call in state["messages"][-1].tool_calls:
        tool = tools_by_name.get(tool_call["name"])
        if tool is None:
            # The model requested a tool that does not exist (hallucinated
            # name): report the error back to the model instead of crashing
            # with a KeyError, so it can correct itself on the next turn.
            tool_result = f"Error: unknown tool '{tool_call['name']}'."
        else:
            tool_result = tool.invoke(tool_call["args"])
        outputs.append(
            ToolMessage(
                content=tool_result,
                name=tool_call["name"],
                tool_call_id=tool_call["id"],
            )
        )
    return {"messages": outputs}
def call_model(state: AgentState, config: RunnableConfig):
    """LLM node: run the tool-bound model on the conversation so far."""
    # Wrapping the response in a one-element list lets the add_messages
    # reducer append it to the existing conversation state.
    return {"messages": [model.invoke(state["messages"], config)]}
# Conditional edge: decides whether the agent loop keeps going.
def should_continue(state: AgentState):
    """Return "continue" while the last message requests tool calls, else "end"."""
    last_message = state["messages"][-1]
    return "continue" if last_message.tool_calls else "end"
def get_agent():
    """Build and compile the ReAct-style agent graph.

    The graph alternates between the "llm" node and the "tools" node until
    the model stops requesting tool calls, at which point it ends.
    """
    builder = StateGraph(AgentState)

    # Nodes: the model invocation and the tool executor.
    builder.add_node("llm", call_model)
    builder.add_node("tools", call_tool)

    # The model is always invoked first.
    builder.set_entry_point("llm")

    # After each model turn, either run the requested tools or finish.
    # Keys are the strings returned by should_continue; values are the
    # next node to run (END is the special terminal node).
    builder.add_conditional_edges(
        "llm",
        should_continue,
        {
            "continue": "tools",  # the model asked for tool calls
            "end": END,           # no tool calls: the answer is final
        },
    )

    # Tool results always flow back into the model.
    builder.add_edge("tools", "llm")

    return builder.compile()
# Riferimenti
#
# https://ai.google.dev/gemini-api/docs/langgraph-example