# Source: LangGraph_agent / graph.py
# Author: pratikmurali
# Commit 89cb1de: "Fixing model calling in graph.py"
"""
Module for creating and configuring the LangGraph agent workflow.
"""
from typing import List, Dict, Any
from langsmith import traceable
from pydantic import BaseModel
from typing import Annotated
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langchain_openai import ChatOpenAI
# Chat model read by ``call_model``; assigned by ``create_agent_graph``.
# Stays None until ``create_agent_graph`` has been called.
_model: "ChatOpenAI | None" = None
class AgentState(BaseModel):
    """Graph state passed between the nodes of the agent workflow.

    Attributes:
        messages: Conversation history. Node updates to this field are
            combined through LangGraph's ``add_messages`` reducer instead of
            replacing the list wholesale.
    """

    # The ``add_messages`` metadata is what tells LangGraph how to merge updates.
    messages: Annotated[list, add_messages]
def call_model(state: AgentState) -> Dict[str, Any]:
    """
    Node function that calls the LLM to generate a response.

    The model used is the module-level ``_model`` assigned by
    ``create_agent_graph``. Failures are reported back into the conversation
    as an ``AIMessage`` (and printed) instead of being raised.

    Args:
        state: Current state containing messages.

    Returns:
        A partial state update ``{"messages": [message]}`` where ``message``
        is the model response, or an ``AIMessage`` describing the error.
    """
    from langchain_core.messages import AIMessage

    model = _model
    if model is None:
        # Guard explicitly: otherwise ``None.invoke`` raises AttributeError and
        # the generic handler below would wrongly point at the API keys.
        error_message = "Model is not initialized; call create_agent_graph() first"
        print(f"ERROR: {error_message}")
        return {"messages": [AIMessage(content=f"I encountered an error: {error_message}.")]}

    try:
        response = model.invoke(state.messages)
    except Exception as e:  # node boundary: report the failure instead of crashing the run
        error_message = f"Error calling model: {e}"
        print(f"ERROR: {error_message}")
        return {"messages": [AIMessage(content=f"I encountered an error: {error_message}. Please check your API keys and try again.")]}
    return {"messages": [response]}
def should_continue(state: AgentState) -> str:
    """
    Decide where the graph goes next.

    Args:
        state: Current agent state; only its most recent message is inspected.

    Returns:
        ``"action"`` if the most recent message carries any tool calls,
        otherwise ``END``.
    """
    latest = state.messages[-1]
    requested_tools = bool(latest.tool_calls)
    return "action" if requested_tools else END
def create_agent_graph(tools: List, model: ChatOpenAI) -> StateGraph:
    """
    Build and compile the LangGraph agent workflow.

    The graph starts at the "agent" node (``call_model``). After each model
    reply, ``should_continue`` routes to the "action" node (a ``ToolNode``
    over *tools*) when tool calls were requested, or ends the run otherwise.
    Tool results always flow back to "agent".

    Args:
        tools: List of LangChain tools executed by the "action" node.
        model: ChatOpenAI model with tools bound.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``. Despite the
        ``StateGraph`` annotation, this is the compiled graph, not the builder.

    Note:
        The model is stored in the module-level ``_model`` variable that
        ``call_model`` reads. Calling this function again therefore replaces
        the model used by every graph previously built in this process.
    """
    # Create tool node to execute tools
    tool_node = ToolNode(tools)

    # ``call_model`` only receives the state, so the model reaches it through
    # this module-level variable.
    global _model
    _model = model

    graph = StateGraph(AgentState)

    graph.add_node("action", tool_node)
    graph.add_node("agent", call_model)

    graph.set_entry_point("agent")
    # Tool outputs are handed back to the model for the next turn.
    graph.add_edge("action", "agent")
    # Explicit path map: declares every target the router can return, so the
    # compiled graph's structure names exactly these edges.
    graph.add_conditional_edges(
        "agent",
        should_continue,
        {"action": "action", END: END},
    )

    return graph.compile()