bstraehle commited on
Commit
9f94153
·
verified ·
1 Parent(s): 2453971

Update multi_agent.py

Browse files
Files changed (1) hide show
  1. multi_agent.py +9 -21
multi_agent.py CHANGED
@@ -20,7 +20,7 @@ board_svgs = None
20
 
21
  num_moves = 0
22
  move_num = 0
23
- made_move = None
24
 
25
  class AgentState(TypedDict):
26
  messages: Annotated[Sequence[BaseMessage], operator.add]
@@ -45,7 +45,7 @@ def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
45
  handle_parsing_errors=True,
46
  return_intermediate_steps=True,
47
  verbose=True,
48
- max_iterations=3) #get_num_turns(num_moves))
49
 
50
  def agent_node(state, agent, name):
51
  try:
@@ -64,9 +64,9 @@ def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format"]:
64
  and this function will always return legal moves in UCI format."""
65
  try:
66
  print("## get_legal_moves")
67
- return "Possible moves are: " + ",".join(
68
- [str(move) for move in board.legal_moves]
69
- )
70
  except Exception as e:
71
  print(f"An error occurred in get_legal_moves: {e}")
72
  return "Error: unable to get legal moves"
@@ -80,8 +80,6 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
80
  print("## make_move")
81
  move = chess.Move.from_uci(move)
82
  board.push_uci(str(move))
83
- global made_move
84
- made_move = True
85
 
86
  board_svgs.append(chess.svg.board(
87
  board,
@@ -108,16 +106,6 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
108
  except Exception as e:
109
  print(f"An error occurred in make_move: {e}")
110
  return f"Error: unable to make move {move}"
111
-
112
- def check_made_move(msg):
113
- print("## check_made_move")
114
- global made_move
115
-
116
- if made_move:
117
- made_move = False
118
- return True
119
- else:
120
- return False
121
 
122
  def get_num_turns(num_moves):
123
  # Each turn includes two moves (one by each player)
@@ -214,9 +202,6 @@ def create_graph():
214
  graph.add_node("player_black", player_black_node)
215
  graph.add_node("manager", supervisor_chain)
216
 
217
- # for player in players:
218
- # graph.add_edge(player, "manager")
219
-
220
  graph.add_conditional_edges(
221
  "player_white",
222
  should_continue,
@@ -239,10 +224,13 @@ def create_graph():
239
 
240
  def should_continue(state):
241
  print("#### should_continue")
242
- global move_num, num_moves
243
  if move_num == num_moves:
244
  print("False")
245
  return END # max moves reached
 
 
 
246
  print("True")
247
  return "manager"
248
 
 
20
 
21
  num_moves = 0
22
  move_num = 0
23
+ legal_moves = ""
24
 
25
  class AgentState(TypedDict):
26
  messages: Annotated[Sequence[BaseMessage], operator.add]
 
45
  handle_parsing_errors=True,
46
  return_intermediate_steps=True,
47
  verbose=True,
48
+ max_iterations=get_num_turns(num_moves))
49
 
50
  def agent_node(state, agent, name):
51
  try:
 
64
  and this function will always return legal moves in UCI format."""
65
  try:
66
  print("## get_legal_moves")
67
+ global legal_moves
68
+ legal_moves = ",".join([str(move) for move in board.legal_moves])
69
+ return legal_moves
70
  except Exception as e:
71
  print(f"An error occurred in get_legal_moves: {e}")
72
  return "Error: unable to get legal moves"
 
80
  print("## make_move")
81
  move = chess.Move.from_uci(move)
82
  board.push_uci(str(move))
 
 
83
 
84
  board_svgs.append(chess.svg.board(
85
  board,
 
106
  except Exception as e:
107
  print(f"An error occurred in make_move: {e}")
108
  return f"Error: unable to make move {move}"
 
 
 
 
 
 
 
 
 
 
109
 
110
  def get_num_turns(num_moves):
111
  # Each turn includes two moves (one by each player)
 
202
  graph.add_node("player_black", player_black_node)
203
  graph.add_node("manager", supervisor_chain)
204
 
 
 
 
205
  graph.add_conditional_edges(
206
  "player_white",
207
  should_continue,
 
224
 
225
  def should_continue(state):
226
  print("#### should_continue")
227
+ global move_num, num_moves, legal_moves
228
  if move_num == num_moves:
229
  print("False")
230
  return END # max moves reached
231
+ if not legal_moves:
232
+ print("False")
233
+ return END # checkmate or stalemate
234
  print("True")
235
  return "manager"
236