Luigi D'Addona committed on
Commit
7a786af
·
1 Parent(s): 5191ddb

aggiunto file agent.py con la definizione dell'agent

Browse files
Files changed (1) hide show
  1. agent.py +125 -0
agent.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from dotenv import load_dotenv
import traceback  # NOTE(review): not used anywhere in this file — consider removing

from typing import Annotated, Sequence, TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages  # reducer that appends new messages to the state
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph, END
from langchain_google_genai import ChatGoogleGenerativeAI

# Local imports
from tools import get_search_tool, get_wikipedia_tool

# Note: for local testing the .env file is used; on Hugging Face the values
# come instead from the variables defined under Settings / "Variables and secrets".
load_dotenv()
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")    # may be None if unset
GEMINI_MODEL = os.environ.get("GEMINI_MODEL")        # Gemini model identifier passed to ChatGoogleGenerativeAI
GEMINI_BASE_URL = os.environ.get("GEMINI_BASE_URL")  # NOTE(review): read but never used in this file — confirm whether it is needed

#
# Initialize the model and bind the tools to it
#

# ChatGoogleGenerativeAI is the official LangChain package for interacting with Gemini models
# https://python.langchain.com/docs/integrations/chat/google_generative_ai/
chat = ChatGoogleGenerativeAI(
    model=GEMINI_MODEL,
    google_api_key=GEMINI_API_KEY)

# Instantiate the tools (defined in the local tools.py)
search_tool = get_search_tool()
wikipedia_tool = get_wikipedia_tool()

tools = [search_tool, wikipedia_tool]

# Bind tools to the model so it can emit tool calls
model = chat.bind_tools(tools)

# Lookup table used by the tool node to dispatch each tool call by name
tools_by_name = {tool.name: tool for tool in tools}
44
+
45
+
46
+ #
47
+ # Define the graph
48
+ #
49
+
50
class AgentState(TypedDict):
    """The state of the agent.

    Flows through the LangGraph nodes; ``messages`` uses the ``add_messages``
    reducer, so values returned by nodes are appended to the existing list
    rather than replacing it.
    """
    # Conversation history; add_messages merges node outputs into this list.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # NOTE(review): never read or written by the nodes in this file — confirm
    # whether this field is actually needed.
    number_of_steps: int
54
+
55
+
56
+ # Define our tool node
57
def call_tool(state: AgentState):
    """Tool node: execute every tool call requested by the last message.

    Each requested tool is looked up by name in ``tools_by_name`` and invoked
    with its arguments; the result is wrapped in a ToolMessage carrying the
    originating tool_call_id so the model can match results to requests.
    """
    last_message = state["messages"][-1]
    tool_messages = [
        ToolMessage(
            content=tools_by_name[call["name"]].invoke(call["args"]),
            name=call["name"],
            tool_call_id=call["id"],
        )
        for call in last_message.tool_calls
    ]
    # The add_messages reducer appends these to the state's message list.
    return {"messages": tool_messages}
71
+
72
+
73
def call_model(state: AgentState, config: RunnableConfig):
    """LLM node: ask the tool-bound model for the next message.

    Returns a single-element list because the add_messages reducer merges it
    into the existing ``messages`` state rather than replacing it.
    """
    ai_message = model.invoke(state["messages"], config)
    return {"messages": [ai_message]}
78
+
79
+
80
+ # Define the conditional edge that determines whether to continue or not
81
def should_continue(state: AgentState):
    """Routing predicate: 'continue' while the model keeps requesting tools.

    Returns 'end' when the last message carries no tool calls (the model has
    produced a final answer), 'continue' otherwise.
    """
    last_message = state["messages"][-1]
    return "continue" if last_message.tool_calls else "end"
88
+
89
+
90
def get_agent():
    """Build and compile the ReAct-style LangGraph agent.

    Graph shape: llm -> (tools -> llm)* -> END. The ``llm`` node runs the
    tool-bound model; whenever it emits tool calls, the ``tools`` node
    executes them and control loops back to ``llm``.
    """
    builder = StateGraph(AgentState)

    # Nodes: the model call and the tool executor.
    builder.add_node("llm", call_model)
    builder.add_node("tools", call_tool)

    # Every run starts with a model call.
    builder.set_entry_point("llm")

    # After each model call, either run the requested tools or finish.
    # END is the special node marking that the graph is finished.
    builder.add_conditional_edges(
        "llm",
        should_continue,
        {"continue": "tools", "end": END},
    )

    # Tool results always feed back into the model.
    builder.add_edge("tools", "llm")

    return builder.compile()
121
+
122
+
123
+ # References
124
+ #
125
+ # https://ai.google.dev/gemini-api/docs/langgraph-example