# LangGraph agent demo: Gemini (via Vertex AI) + Tavily web search.
from typing import TypedDict, Annotated, List
import operator
import os

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.graph import StateGraph, END, START
from langgraph.prebuilt import ToolNode
from langchain_tavily import TavilySearch
import google.auth
from dotenv import load_dotenv

# Load environment variables (e.g. TAVILY_API_KEY) from a local .env file.
load_dotenv()

# Set up Google credentials so the Gemini client routes through Vertex AI.
try:
    _, project_id = google.auth.default()
    # google.auth.default() may resolve credentials with no associated
    # project (project_id is None); assigning None to os.environ raises
    # TypeError, which the handler below would not catch — so guard it.
    if project_id:
        os.environ["GOOGLE_CLOUD_PROJECT"] = project_id
    os.environ["GOOGLE_CLOUD_LOCATION"] = "global"
    os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "True"
except google.auth.exceptions.DefaultCredentialsError:
    print("Google Cloud credentials not found. Please configure your credentials.")
    # You might want to fall back to an API key or raise an exception here.
    # Deliberately best-effort: the module still imports, but model calls
    # will fail later if credentials are not configured.
    pass
# 1. Define the state
class AgentState(TypedDict):
    """Graph state shared by all nodes: the running conversation history."""

    # operator.add is the reducer: lists of messages returned by nodes are
    # concatenated onto the existing history instead of replacing it.
    messages: Annotated[List[BaseMessage], operator.add]
# 2. Define the tools
# Tavily web search, capped at one result per query.
tools = [TavilySearch(max_results=1)]
# Prebuilt node that executes any tool calls found in the last AI message.
tool_node = ToolNode(tools)

# 3. Define the model
LLM = "gemini-1.5-flash"
# temperature=0 for deterministic, tool-driven behavior.
model = ChatGoogleGenerativeAI(model=LLM, temperature=0)
# bind_tools exposes the tool schemas to the model so it can emit tool calls.
model = model.bind_tools(tools)
# 4. Define the routing function and the agent node
def should_continue(state):
    """Conditional-edge router: 'continue' to run tools, 'end' to finish.

    Looks only at the newest message; a non-empty ``tool_calls`` attribute
    means the model requested a tool invocation.
    """
    newest = state["messages"][-1]
    return "continue" if newest.tool_calls else "end"
def call_model(state):
    """Agent node: invoke the LLM on the conversation accumulated so far."""
    response = model.invoke(state["messages"])
    # Return a single-element list: the state's reducer (operator.add)
    # appends it to the existing message history.
    return {"messages": [response]}
# 5. Create the graph
workflow = StateGraph(AgentState)
# Define the two nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.add_edge(START, "agent")
# We now add a conditional edge
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {
        # Tool calls pending -> execute them in the "action" node.
        "continue": "action",
        # No tool calls -> the agent produced its final answer.
        "end": END,
    },
)
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge("action", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()
class LangGraphAgent:
    """Thin callable wrapper around the module-level compiled graph."""

    def __init__(self):
        # Reuse the already-compiled workflow rather than rebuilding it.
        self.app = app

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return the final answer text."""
        final_state = self.app.invoke(
            {"messages": [HumanMessage(content=question)]}
        )
        return final_state["messages"][-1].content