bstraehle commited on
Commit
1dbbcfc
·
verified ·
1 Parent(s): 30b2934

Update multi_agent.py

Browse files
Files changed (1) hide show
  1. multi_agent.py +39 -64
multi_agent.py CHANGED
@@ -28,9 +28,6 @@ class AgentState(TypedDict):
28
  next: str
29
 
30
  def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
31
- print("## create_agent")
32
- global num_moves
33
-
34
  prompt = ChatPromptTemplate.from_messages(
35
  [
36
  ("system", system_prompt),
@@ -46,12 +43,11 @@ def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
46
  handle_parsing_errors=True,
47
  return_intermediate_steps=True,
48
  verbose=True,
49
- max_iterations=num_moves)
50
 
51
  def agent_node(state, agent, name):
52
  try:
53
- print("## agent_node=" + name)
54
- print("## state=" + str(state))
55
  result = agent.invoke(state)
56
  return {"messages": [HumanMessage(content=result["output"], name=name)]}
57
  except Exception as e:
@@ -64,7 +60,7 @@ def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format"]:
64
  The input should always be an empty string,
65
  and this function will always return legal moves in UCI format."""
66
  try:
67
- print("## get_legal_moves")
68
  global legal_moves
69
  legal_moves = ",".join([str(move) for move in board.legal_moves])
70
  return legal_moves
@@ -78,9 +74,13 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
78
  The input should always be a move in UCI format,
79
  and this function will always return the result of the move in UCI format."""
80
  try:
81
- print("## make_move")
82
  move = chess.Move.from_uci(move)
83
  board.push_uci(str(move))
 
 
 
 
84
 
85
  board_svgs.append(chess.svg.board(
86
  board,
@@ -96,10 +96,6 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
96
  if piece_symbol.isupper()
97
  else chess.piece_name(piece.piece_type)
98
  )
99
-
100
- global move_num
101
- move_num += 1
102
- print("## move_num=" + str(move_num))
103
 
104
  return f"Moved {piece_name} ({piece_symbol}) from "\
105
  f"{chess.SQUARE_NAMES[move.from_square]} to "\
@@ -109,9 +105,12 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
109
  return f"Error: unable to make move {move}"
110
 
111
  def create_graph():
112
- print("## create_graph")
113
-
114
  players = ["player_white", "player_black"]
 
 
 
 
 
115
 
116
  system_prompt = (
117
  "You are a Chess Board Proxy tasked with managing a game of chess "
@@ -119,9 +118,6 @@ def create_graph():
119
  "then the players take turns."
120
  )
121
 
122
- #options = ["FINISH"] + players
123
- options = players
124
-
125
  function_def = {
126
  "name": "route",
127
  "description": "Select the next player.",
@@ -153,34 +149,35 @@ def create_graph():
153
  ]
154
  ).partial(options=str(options), members=", ".join(players), verbose=True)
155
 
156
- llm_1 = ChatOpenAI(model="gpt-4o")
157
- llm_2 = ChatOpenAI(model="gpt-4o")
158
- llm_3 = ChatOpenAI(model="gpt-4o")
159
-
160
  supervisor_chain = (
161
  prompt
162
- | llm_1.bind_functions(functions=[function_def], function_call="route")
163
  | JsonOutputFunctionsParser()
164
  )
165
 
166
- player_white_agent = create_agent(llm_2, [get_legal_moves, make_move], system_prompt=
167
  "You are a chess Grandmaster and you play as white. "
168
- "1. First call get_legal_moves(), to get a list of legal moves. "
169
- "2. Then call make_move(move) to make a move. ONLY make a move in the list returned by step 1.")
170
- #"3. Finally analyze the move in format: **Analysis:** move in UCI format, emoji of piece emoji, unordered list.")
171
  player_white_node = functools.partial(agent_node, agent=player_white_agent, name="player_white")
172
 
173
- player_black_agent = create_agent(llm_3, [get_legal_moves, make_move], system_prompt=
174
  "You are a chess Grandmaster and you play as black. "
175
- "1. First call get_legal_moves(), to get a list of legal moves. "
176
- "2. Then call make_move(move) to make a move. ONLY make a move in the list returned by step 1.")
177
- #"3. Finally analyze the move in format: **Analysis:** move in UCI format, emoji of piece emoji, unordered list.")
178
  player_black_node = functools.partial(agent_node, agent=player_black_agent, name="player_black")
179
 
180
  graph = StateGraph(AgentState)
 
 
181
  graph.add_node("player_white", player_white_node)
182
  graph.add_node("player_black", player_black_node)
183
- graph.add_node("chess_board_proxy", supervisor_chain)
 
 
 
184
 
185
  graph.add_conditional_edges(
186
  "player_white",
@@ -194,10 +191,6 @@ def create_graph():
194
  {"chess_board_proxy": "chess_board_proxy", END: END}
195
  )
196
 
197
- conditional_map = {k: k for k in players}
198
- conditional_map["END"] = END
199
-
200
- graph.add_conditional_edges("chess_board_proxy", lambda x: x["next"], conditional_map)
201
  graph.set_entry_point("chess_board_proxy")
202
 
203
  return graph.compile()
@@ -205,13 +198,13 @@ def create_graph():
205
  def should_continue(state):
206
  print("#### should_continue")
207
  global move_num, num_moves, legal_moves
 
208
  if move_num == num_moves:
209
- print("False (move_num == num_moves)")
210
  return END # max moves reached
 
211
  if not legal_moves:
212
- print("False (not legal_moves)")
213
  return END # checkmate or stalemate
214
- print("True")
215
  return "chess_board_proxy"
216
 
217
  def initialize():
@@ -226,21 +219,17 @@ def initialize():
226
  legal_moves = ""
227
 
228
  def run_multi_agent(moves_num):
229
- initialize()
230
-
231
  global num_moves
232
-
233
  num_moves = moves_num
234
-
235
- print("## START")
236
- print("## num_moves=" + str(num_moves))
237
-
238
  graph = create_graph()
239
 
240
  result = ""
241
 
242
  try:
243
- config = {"recursion_limit": 500}
244
 
245
  result = graph.invoke({
246
  "messages": [
@@ -250,22 +239,10 @@ def run_multi_agent(moves_num):
250
  except Exception as e:
251
  print(f"An error occurred: {e}")
252
 
253
- ###
254
-
255
- result2 = ""
256
  num_move = 0
257
 
258
- """
259
- for message in result["messages"]:
260
- if message.name:
261
- print(f"{message.name}: {message.content}")
262
- else:
263
- print(message.content)
264
- """
265
 
266
- print("### "+ str(type(result)))
267
- print("### "+ str(len(result["messages"])))
268
-
269
  if "messages" in result:
270
  for message in result["messages"]:
271
  player = ""
@@ -276,7 +253,7 @@ def run_multi_agent(moves_num):
276
  player = "Player White"
277
 
278
  if num_move > 0:
279
- result2 += f"**{player}, Move {num_move}**\n{message.content}\n{board_svgs[num_move - 1]}\n\n"
280
 
281
  num_move += 1
282
 
@@ -284,9 +261,7 @@ def run_multi_agent(moves_num):
284
  break
285
 
286
  print("===")
287
- print(str(result))
288
- print("===")
289
- print(str(result2))
290
  print("===")
291
 
292
- return str(result2)
 
28
  next: str
29
 
30
  def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
 
 
 
31
  prompt = ChatPromptTemplate.from_messages(
32
  [
33
  ("system", system_prompt),
 
43
  handle_parsing_errors=True,
44
  return_intermediate_steps=True,
45
  verbose=True,
46
+ max_iterations=2) # get_legal_moves & make_move
47
 
48
  def agent_node(state, agent, name):
49
  try:
50
+ print("### node=" + name)
 
51
  result = agent.invoke(state)
52
  return {"messages": [HumanMessage(content=result["output"], name=name)]}
53
  except Exception as e:
 
60
  The input should always be an empty string,
61
  and this function will always return legal moves in UCI format."""
62
  try:
63
+ print("### get_legal_moves")
64
  global legal_moves
65
  legal_moves = ",".join([str(move) for move in board.legal_moves])
66
  return legal_moves
 
74
  The input should always be a move in UCI format,
75
  and this function will always return the result of the move in UCI format."""
76
  try:
77
+ print("### make_move")
78
  move = chess.Move.from_uci(move)
79
  board.push_uci(str(move))
80
+
81
+ global move_num
82
+ move_num += 1
83
+ print("### move_num=" + str(move_num))
84
 
85
  board_svgs.append(chess.svg.board(
86
  board,
 
96
  if piece_symbol.isupper()
97
  else chess.piece_name(piece.piece_type)
98
  )
 
 
 
 
99
 
100
  return f"Moved {piece_name} ({piece_symbol}) from "\
101
  f"{chess.SQUARE_NAMES[move.from_square]} to "\
 
105
  return f"Error: unable to make move {move}"
106
 
107
  def create_graph():
 
 
108
  players = ["player_white", "player_black"]
109
+ options = players
110
+
111
+ llm_board_proxy = ChatOpenAI(model="gpt-4o")
112
+ llm_player_white = ChatOpenAI(model="gpt-4o")
113
+ llm_player_black = ChatOpenAI(model="gpt-4o")
114
 
115
  system_prompt = (
116
  "You are a Chess Board Proxy tasked with managing a game of chess "
 
118
  "then the players take turns."
119
  )
120
 
 
 
 
121
  function_def = {
122
  "name": "route",
123
  "description": "Select the next player.",
 
149
  ]
150
  ).partial(options=str(options), members=", ".join(players), verbose=True)
151
 
 
 
 
 
152
  supervisor_chain = (
153
  prompt
154
+ | llm_board_proxy.bind_functions(functions=[function_def], function_call="route")
155
  | JsonOutputFunctionsParser()
156
  )
157
 
158
+ player_white_agent = create_agent(llm_player_white, [get_legal_moves, make_move], system_prompt=
159
  "You are a chess Grandmaster and you play as white. "
160
+ "First call get_legal_moves() to get a list of legal moves. "
161
+ "Then study the returned moves and call make_move(move) to make the best move. "
162
+ "Finally analyze the move in format: **Analysis:** move in UCI format, emoji of piece, unordered list of 3 items.")
163
  player_white_node = functools.partial(agent_node, agent=player_white_agent, name="player_white")
164
 
165
+ player_black_agent = create_agent(llm_player_black, [get_legal_moves, make_move], system_prompt=
166
  "You are a chess Grandmaster and you play as black. "
167
+ "First call get_legal_moves() to get a list of legal moves. "
168
+ "Then study the returned moves and call make_move(move) to make the best move. "
169
+ "Finally analyze the move in format: **Analysis:** move in UCI format, emoji of piece, unordered list of 3 items.")
170
  player_black_node = functools.partial(agent_node, agent=player_black_agent, name="player_black")
171
 
172
  graph = StateGraph(AgentState)
173
+
174
+ graph.add_node("chess_board_proxy", supervisor_chain)
175
  graph.add_node("player_white", player_white_node)
176
  graph.add_node("player_black", player_black_node)
177
+
178
+ conditional_map = {k: k for k in players}
179
+ conditional_map["END"] = END
180
+ graph.add_conditional_edges("chess_board_proxy", lambda x: x["next"], conditional_map)
181
 
182
  graph.add_conditional_edges(
183
  "player_white",
 
191
  {"chess_board_proxy": "chess_board_proxy", END: END}
192
  )
193
 
 
 
 
 
194
  graph.set_entry_point("chess_board_proxy")
195
 
196
  return graph.compile()
 
198
  def should_continue(state):
199
  print("#### should_continue")
200
  global move_num, num_moves, legal_moves
201
+
202
  if move_num == num_moves:
 
203
  return END # max moves reached
204
+
205
  if not legal_moves:
 
206
  return END # checkmate or stalemate
207
+
208
  return "chess_board_proxy"
209
 
210
  def initialize():
 
219
  legal_moves = ""
220
 
221
  def run_multi_agent(moves_num):
 
 
222
  global num_moves
 
223
  num_moves = moves_num
224
+
225
+ initialize()
226
+
 
227
  graph = create_graph()
228
 
229
  result = ""
230
 
231
  try:
232
+ config = {"recursion_limit": 100}
233
 
234
  result = graph.invoke({
235
  "messages": [
 
239
  except Exception as e:
240
  print(f"An error occurred: {e}")
241
 
242
+ result_md = ""
 
 
243
  num_move = 0
244
 
 
 
 
 
 
 
 
245
 
 
 
 
246
  if "messages" in result:
247
  for message in result["messages"]:
248
  player = ""
 
253
  player = "Player White"
254
 
255
  if num_move > 0:
256
+ result_md += f"**{player}, Move {num_move}**\n{message.content}\n{board_svgs[num_move - 1]}\n\n"
257
 
258
  num_move += 1
259
 
 
261
  break
262
 
263
  print("===")
264
+ print(result_md)
 
 
265
  print("===")
266
 
267
+ return result_md