# agent_test/agent.py — uploaded by blazingbunny ("Upload 6 files", commit df1a0e5, 2.84 kB)
from typing import TypedDict, Annotated, List
import operator
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.graph import StateGraph, END, START
from langgraph.prebuilt import ToolNode
from langchain_tavily import TavilySearch
import google.auth
from dotenv import load_dotenv
load_dotenv()

# Set up Google credentials so the Gemini client is routed through Vertex AI.
try:
    _, project_id = google.auth.default()
except google.auth.exceptions.DefaultCredentialsError:
    # Best effort: continue without Vertex AI configuration. Model calls will
    # likely fail later unless another auth method (e.g. an API key) is set.
    print("Google Cloud credentials not found. Please configure your credentials.")
else:
    # google.auth.default() may return None for the project ID (credentials
    # found, but no project configured). os.environ only accepts strings, and
    # Vertex AI needs a project, so only enable Vertex mode when one is known.
    if project_id:
        os.environ["GOOGLE_CLOUD_PROJECT"] = project_id
        os.environ["GOOGLE_CLOUD_LOCATION"] = "global"
        os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "True"
    else:
        print(
            "Google Cloud credentials found but no project ID is configured; "
            "set GOOGLE_CLOUD_PROJECT to use Vertex AI."
        )
# 1. Define the state
# The graph state holds one key, "messages". The operator.add reducer means a
# node's returned "messages" list is concatenated onto the existing history
# rather than replacing it.
AgentState = TypedDict(
    "AgentState",
    {"messages": Annotated[List[BaseMessage], operator.add]},
)
# 2. Define the tools
# One Tavily web-search tool, configured to return a single result per query.
_search_tool = TavilySearch(max_results=1)
tools = [_search_tool]
# Graph node that executes whatever tool calls the model requests.
tool_node = ToolNode(tools)
# 3. Define the model
# The Gemini 1.5 family (including "gemini-1.5-flash") has been retired by
# Google, so default to a currently served model. Override the model with the
# GEMINI_MODEL environment variable without editing this file.
LLM = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash")
# temperature=0 minimises sampling randomness in the model's answers.
model = ChatGoogleGenerativeAI(model=LLM, temperature=0)
# Attach the tool schemas so the model can emit tool calls for them.
model = model.bind_tools(tools)
# 4. Define the routing function and the agent node
def should_continue(state):
    """Decide where to go after the agent node.

    Returns "continue" when the most recent message carries tool calls (so the
    tools must run next), and "end" when it carries none.
    """
    latest = state["messages"][-1]
    return "continue" if latest.tool_calls else "end"
def call_model(state):
    """Agent node: send the conversation so far to the tool-bound model.

    Returns a partial state update containing the model's single reply. The
    reply is wrapped in a list so the operator.add reducer on "messages"
    appends it to the history instead of replacing the history.
    """
    history = state["messages"]
    reply = model.invoke(history)
    return {"messages": [reply]}
# 5. Create the graph
workflow = StateGraph(AgentState)

# The loop alternates between two nodes: "agent" queries the model and
# "action" runs the tools the model asked for.
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)

# Every run begins at the "agent" node.
workflow.add_edge(START, "agent")

# After "agent", should_continue picks the next hop: "action" when the model
# requested tool calls, otherwise the graph finishes.
route_map = {"continue": "action", "end": END}
workflow.add_conditional_edges("agent", should_continue, route_map)

# Tool output always flows back from "action" to "agent" so the model can use it.
workflow.add_edge("action", "agent")

# Compile into a runnable; callers drive it with .invoke().
app = workflow.compile()
class LangGraphAgent:
    """Callable wrapper that answers a question by running the compiled agent graph."""

    def __init__(self):
        # Reuse the module-level compiled graph.
        self.app = app

    @staticmethod
    def _content_to_text(content):
        """Flatten a message's ``content`` into a plain string.

        Message content may be either a string or a list of content blocks
        (plain strings or dicts such as ``{"type": "text", "text": ...}``).
        Strings are returned unchanged; for lists, only the textual parts are
        kept and concatenated. Other block types (images, tool use, ...) are
        dropped.
        """
        if isinstance(content, str):
            return content
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                pieces.append(str(part.get("text", "")))
        return "".join(pieces)

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return the final message's text."""
        inputs = {"messages": [HumanMessage(content=question)]}
        final_state = self.app.invoke(inputs)
        # The last message is the model's final answer; its content may be a
        # list of blocks, so normalise it to honour the ``-> str`` contract.
        return self._content_to_text(final_state["messages"][-1].content)