File size: 2,843 Bytes
df1a0e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import TypedDict, Annotated, List
import operator
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.graph import StateGraph, END, START
from langgraph.prebuilt import ToolNode
from langchain_tavily import TavilySearch
import google.auth
from dotenv import load_dotenv

load_dotenv()


# Set up Google credentials for Vertex AI.
# google.auth.default() returns a (credentials, project_id) tuple, and the
# project_id component can legitimately be None (e.g. user credentials with
# no quota project configured). os.environ rejects non-string values, so we
# must guard before exporting it — the original code would raise TypeError
# in that situation.
try:
    _, project_id = google.auth.default()
    if project_id:
        os.environ["GOOGLE_CLOUD_PROJECT"] = project_id
    os.environ["GOOGLE_CLOUD_LOCATION"] = "global"
    os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "True"
except google.auth.exceptions.DefaultCredentialsError:
    print("Google Cloud credentials not found. Please configure your credentials.")
    # You might want to fall back to an API key or raise an exception here.
    # For this example, we proceed, but later calls will likely fail if the
    # environment is not configured.


# 1. Define the state
class AgentState(TypedDict):
    # Conversation history for the agent loop. The Annotated reducer
    # (operator.add) tells LangGraph to CONCATENATE each node's returned
    # message list onto the existing one instead of overwriting it.
    messages: Annotated[List[BaseMessage], operator.add]

# 2. Define the tools
# TavilySearch gives the agent web-search capability; max_results=1 keeps
# tool output small. ToolNode is the graph node that executes whatever tool
# calls the model emits.
tools = [TavilySearch(max_results=1)]
tool_node = ToolNode(tools)

# 3. Define the model
LLM = "gemini-1.5-flash"
model = ChatGoogleGenerativeAI(model=LLM, temperature=0)
# bind_tools exposes the tool schemas to the model so it can emit tool calls.
model = model.bind_tools(tools)

# 4. Define the agent node
def should_continue(state):
    """Route after the agent node: run tools if the last message requests any.

    Returns "continue" when the most recent message carries pending tool
    calls (the model wants a tool executed), otherwise "end" to finish.
    """
    latest = state["messages"][-1]
    # An empty/absent tool_calls list means the model produced a final answer.
    return "continue" if latest.tool_calls else "end"

def call_model(state):
    """Invoke the LLM on the accumulated conversation history.

    Returns the response wrapped in a single-element list so the state's
    operator.add reducer appends it to the existing message list.
    """
    reply = model.invoke(state["messages"])
    return {"messages": [reply]}

# 5. Create the graph
workflow = StateGraph(AgentState)

# Define the two nodes we will cycle between:
# "agent" calls the LLM, "action" executes any requested tools.
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)

# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.add_edge(START, "agent")

# We now add a conditional edge.
# After "agent" runs, should_continue picks the next hop: "continue" routes
# to the tool executor, "end" terminates the graph.
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {
        "continue": "action",
        "end": END,
    },
)

# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next,
# closing the agent -> tools -> agent loop.
workflow.add_edge("action", "agent")

# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()


class LangGraphAgent:
    """Thin callable wrapper around the compiled LangGraph workflow."""

    def __init__(self):
        # Reuse the module-level compiled graph; compilation happens once
        # at import time, not per instance.
        self.app = app

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return the final answer text."""
        seed = {"messages": [HumanMessage(content=question)]}
        result = self.app.invoke(seed)
        # The last message in the final state is the model's answer.
        return result["messages"][-1].content