ktluege commited on
Commit
885ae0e
Β·
verified Β·
1 Parent(s): a023ff4

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +8 -5
agent.py CHANGED
@@ -41,6 +41,13 @@ def build_graph(provider: str = "openai"):
41
  raise ValueError("Invalid provider.")
42
 
43
  llm_with_tools = llm.bind_tools(tools)
 
 
 
 
 
 
 
44
 
45
  def retriever(state: MessagesState):
46
  query = state["messages"][-1].content
@@ -55,11 +62,6 @@ def build_graph(provider: str = "openai"):
55
  else:
56
  return {"messages": [AIMessage(content=content.strip())]}
57
 
58
- def assistant(state: MessagesState):
59
- user_message = state["messages"][-1]
60
- # Make sure you send both system and user message
61
- result = llm_with_tools.invoke([sys_msg, user_message])
62
- return {"messages": [result]}
63
 
64
  builder = StateGraph(MessagesState)
65
  builder.add_node("retriever", retriever)
@@ -67,5 +69,6 @@ def build_graph(provider: str = "openai"):
67
  builder.add_edge(START, "retriever")
68
  builder.add_edge("retriever", "assistant")
69
  builder.set_finish_point("assistant")
 
70
 
71
  return builder.compile()
 
41
  raise ValueError("Invalid provider.")
42
 
43
  llm_with_tools = llm.bind_tools(tools)
44
+ def assistant(state: MessagesState):
45
+ user_message = state["messages"][-1]
46
+ # You must have llm_with_tools defined earlier in build_graph
47
+ result = llm_with_tools.invoke([sys_msg, user_message])
48
+ return {"messages": [result]}
49
+
50
+
51
 
52
  def retriever(state: MessagesState):
53
  query = state["messages"][-1].content
 
62
  else:
63
  return {"messages": [AIMessage(content=content.strip())]}
64
 
 
 
 
 
 
65
 
66
  builder = StateGraph(MessagesState)
67
  builder.add_node("retriever", retriever)
 
69
  builder.add_edge(START, "retriever")
70
  builder.add_edge("retriever", "assistant")
71
  builder.set_finish_point("assistant")
72
+ return builder.compile()
73
 
74
  return builder.compile()