File size: 2,726 Bytes
e521af9
 
 
 
 
 
 
 
 
 
 
 
 
 
fbb910a
 
 
e521af9
 
 
 
 
 
fbb910a
 
e521af9
 
 
 
 
 
 
 
 
 
026a398
 
89cb1de
026a398
 
 
 
 
 
 
 
 
 
e521af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Module for creating and configuring the LangGraph agent workflow.
"""
from typing import List, Dict, Any

from langsmith import traceable
from pydantic import BaseModel
from typing import Annotated

from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langchain_openai import ChatOpenAI


# Module-level slot for the chat model: populated by create_agent_graph()
# and read back by call_model() at run time, avoiding serialization of the
# model object when the graph is compiled.
_model = None  # Initialize at module level

class AgentState(BaseModel):
    """
    State model for the LangGraph agent.

    Instances carry the conversation between graph nodes; LangGraph merges
    each node's returned partial state into this model.
    """
    # Conversation history. The add_messages reducer accumulates messages
    # returned by nodes into this list instead of replacing it wholesale.
    messages: Annotated[list, add_messages]



def call_model(state: AgentState) -> Dict[str, Any]:
    """
    Node function that calls the LLM to generate a response.

    Args:
        state: Current state containing messages

    Returns:
        Partial state update ``{"messages": [...]}`` holding either the
        model's response or, on failure, an AIMessage describing the error.
        (The original annotation said AgentState, but a partial-state dict
        is what this node actually returns to LangGraph.)
    """
    try:
        # Read the module-level model installed by create_agent_graph();
        # a direct global reference replaces the original, more fragile
        # call_model.__globals__["_model"] lookup.
        model = _model
        if model is None:
            # Fail with a clear message instead of an opaque
            # AttributeError on None.invoke.
            raise RuntimeError(
                "Model is not configured; call create_agent_graph() first."
            )
        messages = state.messages
        response = model.invoke(messages)
        return {"messages": [response]}
    except Exception as e:
        # Handle the error gracefully
        error_message = f"Error calling model: {str(e)}"
        print(f"ERROR: {error_message}")
        # Return a message that indicates there was an error
        from langchain_core.messages import AIMessage
        return {"messages": [AIMessage(content=f"I encountered an error: {error_message}. Please check your API keys and try again.")]}

def should_continue(state: AgentState) -> str:
    """
    Conditional edge function to determine the next node.
    
    Args:
        state: Current agent state
        
    Returns:
        String indicating the next node or END
    """
    last_message = state.messages[-1]

    if last_message.tool_calls:
        return "action"

    return END

def create_agent_graph(tools: List, model: ChatOpenAI) -> StateGraph:
    """
    Build and compile the LangGraph agent workflow.

    Args:
        tools: List of LangChain tools
        model: ChatOpenAI model with tools bound

    Returns:
        Compiled StateGraph
    """
    # Stash the model at module level so call_model can reach it; this
    # sidesteps serialization issues when the graph is compiled.
    global _model
    _model = model

    workflow = StateGraph(AgentState)

    # Two nodes: the LLM step and the tool-execution step.
    workflow.add_node("agent", call_model)
    workflow.add_node("action", ToolNode(tools))

    # Wiring: start at the LLM, always return to it after running tools,
    # and let should_continue decide between tools and termination.
    workflow.set_entry_point("agent")
    workflow.add_edge("action", "agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue
    )

    return workflow.compile()