Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| import traceback | |
| from typing import Annotated,Sequence, TypedDict | |
| from langchain_core.messages import BaseMessage | |
| from langgraph.graph.message import add_messages # helper function to add messages to the state | |
| from langchain_core.messages import ToolMessage | |
| from langchain_core.runnables import RunnableConfig | |
| from langgraph.graph import StateGraph, END | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| # Local imports | |
| from tools import get_search_tool, get_wikipedia_tool | |
# Note: local test runs read these settings from the .env file;
# on Hugging Face they come from the variables defined under
# Settings / "Variables and secrets".
load_dotenv()

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_MODEL = os.getenv("GEMINI_MODEL")
GEMINI_BASE_URL = os.getenv("GEMINI_BASE_URL")
#
# Initialise the model and attach the tools to it
#
# ChatGoogleGenerativeAI is LangChain's official package for talking to Gemini models
# https://python.langchain.com/docs/integrations/chat/google_generative_ai/
chat = ChatGoogleGenerativeAI(model=GEMINI_MODEL, google_api_key=GEMINI_API_KEY)

# Tools offered to the model
search_tool = get_search_tool()
wikipedia_tool = get_wikipedia_tool()
tools = [search_tool, wikipedia_tool]

# Model with the tools bound to it, plus a name -> tool lookup table
model = chat.bind_tools(tools)
tools_by_name = {candidate.name: candidate for candidate in tools}
| # | |
# Defines the graph
| # | |
class AgentState(TypedDict):
    """State dictionary handed from node to node in the agent graph.

    ``messages`` is annotated with the ``add_messages`` reducer, so the
    message lists returned by a node are combined with the existing
    history by that helper. ``number_of_steps`` is a plain integer field;
    none of the nodes visible in this file update it.
    """

    messages: Annotated[Sequence[BaseMessage], add_messages]
    number_of_steps: int
# Tool node: runs every tool requested by the latest model message
def call_tool(state: AgentState):
    """Execute the tool calls carried by the last message in the state.

    Args:
        state: Agent state. The final entry of ``state["messages"]`` must
            expose a ``tool_calls`` list whose items provide the ``name``,
            ``args`` and ``id`` keys.

    Returns:
        A dict with a ``messages`` list holding one ``ToolMessage`` per tool
        call, in the order in which the calls were requested.
    """
    outputs = []
    for tool_call in state["messages"][-1].tool_calls:
        requested_name = tool_call["name"]
        tool = tools_by_name.get(requested_name)
        if tool is None:
            # The model asked for a tool that is not registered. Report that
            # back as the tool result so the model can pick a valid tool,
            # instead of aborting the whole run with a KeyError.
            available = ", ".join(tools_by_name)
            content = (
                f"Error: {requested_name} is not a valid tool, "
                f"try one of [{available}]."
            )
        else:
            content = tool.invoke(tool_call["args"])
        outputs.append(
            ToolMessage(
                content=content,
                name=requested_name,
                tool_call_id=tool_call["id"],
            )
        )
    return {"messages": outputs}
def call_model(state: AgentState, config: RunnableConfig):
    """Run the tool-bound model on the conversation so far.

    Args:
        state: Agent state whose ``messages`` entry is sent to the model
            as-is (no extra prompt is prepended here).
        config: Runnable configuration forwarded to ``model.invoke``.

    Returns:
        A dict whose ``messages`` key holds a one-element list containing
        the model's reply.
    """
    reply = model.invoke(state["messages"], config)
    # Wrapped in a list because the add_messages reducer merges it into
    # the existing messages state.
    return {"messages": [reply]}
# Conditional edge: decides whether the graph keeps going or stops
def should_continue(state: AgentState):
    """Choose the next step based on the most recent message.

    Args:
        state: Agent state; only the last entry of ``state["messages"]``
            is inspected.

    Returns:
        ``"continue"`` when the last message carries tool calls, otherwise
        ``"end"``.
    """
    last_message = state["messages"][-1]
    return "continue" if last_message.tool_calls else "end"
def get_agent():
    """Build and compile the agent graph.

    The graph has two nodes, ``llm`` (``call_model``) and ``tools``
    (``call_tool``). Execution starts at ``llm``; after each ``llm`` step
    ``should_continue`` routes either to ``tools`` or to ``END``, and
    ``tools`` always hands control back to ``llm``.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    graph_builder = StateGraph(AgentState)

    # Nodes
    graph_builder.add_node("llm", call_model)
    graph_builder.add_node("tools", call_tool)

    # `llm` is the first node to run
    graph_builder.set_entry_point("llm")

    # After `llm`: keys are the strings returned by should_continue,
    # values are the node to visit next (END marks the graph as finished).
    routes = {
        "continue": "tools",
        "end": END,
    }
    graph_builder.add_conditional_edges("llm", should_continue, routes)

    # After `tools`, always go back to `llm`
    graph_builder.add_edge("tools", "llm")

    return graph_builder.compile()
# References
| # | |
| # https://ai.google.dev/gemini-api/docs/langgraph-example |