niikun commited on
Commit
9826bc1
·
1 Parent(s): fcdebb0

update tools

Browse files
Files changed (1) hide show
  1. app.py +19 -25
app.py CHANGED
@@ -28,42 +28,36 @@ class State(TypedDict):
28
  class BasicAgent:
29
  def __init__(self):
30
  print("BasicAgent initialized.")
31
- graph_builder = StateGraph(State)
32
 
33
- # ツールを定義
34
  tool = TavilySearchResults(max_results=5)
35
- tools = [tool]
36
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
37
- llm_with_tools = llm.bind_tools(tools)
38
 
39
- # ノード登録:まず「チャット」ノード
40
  def chatbot_node(state: State):
41
- # LLM+ツールを呼び出し、メッセージ履歴を更新
42
- response_msg = llm_with_tools.invoke(state["messages"])
43
- return {"messages": state["messages"] + [response_msg]}
44
 
45
- graph_builder.add_node("chatbot", chatbot_node)
46
 
47
- # ツールノード登録
48
- tool_node = ToolNode(tools=tools)
49
- graph_builder.add_node("tools", tool_node)
50
 
51
- # 条件付きエッジ:チャットノード→(必要なら)ツール→チャットノード
52
- graph_builder.add_conditional_edges("chatbot", tools_condition)
53
- # スタートからチャットへ
54
- graph_builder.add_edge(START, "chatbot")
55
- # ツールノードからチャットノードへ戻る
56
- graph_builder.add_edge("tools", "chatbot")
57
 
58
- # グラフをコンパイルして保持
59
- self.graph = graph_builder.compile()
60
 
61
  def __call__(self, question: str) -> str:
62
- init_state = {"messages": [question]}
63
- # グラフを実行
64
- final_state = self.graph(init_state)
65
- # 最後のメッセージを返す
66
- return final_state["messages"][-1]
67
 
68
  def run_and_submit_all( profile: gr.OAuthProfile | None):
69
  """
 
28
  class BasicAgent:
29
  def __init__(self):
30
  print("BasicAgent initialized.")
31
+ gb = StateGraph(State)
32
 
33
+ # 1) define LLM + tool
34
  tool = TavilySearchResults(max_results=5)
 
35
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
36
+ llm_with_tools = llm.bind_tools([tool])
37
 
38
+ # 2) chatbot node
39
  def chatbot_node(state: State):
40
+ resp = llm_with_tools.invoke(state["messages"])
41
+ return {"messages": state["messages"] + [resp]}
 
42
 
43
+ gb.add_node("chatbot", chatbot_node)
44
 
45
+ # 3) tool node must be named "tools"
46
+ tools_node = ToolNode(tools=[tool])
47
+ gb.add_node("tools", tools_node)
48
 
49
+ # 4) hook up the conditional edges
50
+ gb.add_conditional_edges("chatbot", tools_condition)
51
+ gb.add_edge(START, "chatbot")
52
+ gb.add_edge("tools", "chatbot")
 
 
53
 
54
+ # compile once
55
+ self.graph = gb.compile()
56
 
57
  def __call__(self, question: str) -> str:
58
+ state = {"messages": [question]}
59
+ out = self.graph(state)
60
+ return out["messages"][-1]
 
 
61
 
62
  def run_and_submit_all( profile: gr.OAuthProfile | None):
63
  """