Spaces:
Sleeping
Sleeping
| """ | |
| Module for creating and configuring the LangGraph agent workflow. | |
| """ | |
from typing import Annotated, Any, Dict, List

from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langsmith import traceable
from pydantic import BaseModel
| _model = None # Initialize at module level | |
| class AgentState(BaseModel): | |
| """ | |
| State model for the LangGraph agent. | |
| """ | |
| messages: Annotated[list, add_messages] | |
| def call_model(state: AgentState) -> AgentState: | |
| """ | |
| Node function that calls the LLM to generate a response. | |
| Args: | |
| state: Current state containing messages | |
| Returns: | |
| Updated state with model response | |
| """ | |
| try: | |
| # Access the model from the global context | |
| model = call_model.__globals__["_model"] | |
| messages = state.messages | |
| response = model.invoke(messages) | |
| return {"messages": [response]} | |
| except Exception as e: | |
| # Handle the error gracefully | |
| error_message = f"Error calling model: {str(e)}" | |
| print(f"ERROR: {error_message}") | |
| # Return a message that indicates there was an error | |
| from langchain_core.messages import AIMessage | |
| return {"messages": [AIMessage(content=f"I encountered an error: {error_message}. Please check your API keys and try again.")]} | |
| def should_continue(state: AgentState) -> str: | |
| """ | |
| Conditional edge function to determine the next node. | |
| Args: | |
| state: Current agent state | |
| Returns: | |
| String indicating the next node or END | |
| """ | |
| last_message = state.messages[-1] | |
| if last_message.tool_calls: | |
| return "action" | |
| return END | |
| def create_agent_graph(tools: List, model: ChatOpenAI) -> StateGraph: | |
| """ | |
| Create the LangGraph agent workflow. | |
| Args: | |
| tools: List of LangChain tools | |
| model: ChatOpenAI model with tools bound | |
| Returns: | |
| Compiled StateGraph | |
| """ | |
| # Create tool node to execute tools | |
| tool_node = ToolNode(tools) | |
| # Set model in the global context for call_model function | |
| # This avoids issues with serialization when the graph is compiled | |
| global _model | |
| _model = model | |
| # Create the graph | |
| graph = StateGraph(AgentState) | |
| # Add nodes | |
| graph.add_node("action", tool_node) | |
| graph.add_node("agent", call_model) | |
| # Add edges | |
| graph.set_entry_point("agent") | |
| graph.add_edge("action", "agent") | |
| graph.add_conditional_edges( | |
| "agent", | |
| should_continue | |
| ) | |
| # Compile the graph | |
| return graph.compile() | |