Update multi_agent.py
Browse files- multi_agent.py +9 -21
multi_agent.py
CHANGED
|
@@ -20,7 +20,7 @@ board_svgs = None
|
|
| 20 |
|
| 21 |
num_moves = 0
|
| 22 |
move_num = 0
|
| 23 |
-
|
| 24 |
|
| 25 |
class AgentState(TypedDict):
|
| 26 |
messages: Annotated[Sequence[BaseMessage], operator.add]
|
|
@@ -45,7 +45,7 @@ def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
|
|
| 45 |
handle_parsing_errors=True,
|
| 46 |
return_intermediate_steps=True,
|
| 47 |
verbose=True,
|
| 48 |
-
max_iterations=
|
| 49 |
|
| 50 |
def agent_node(state, agent, name):
|
| 51 |
try:
|
|
@@ -64,9 +64,9 @@ def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format"]:
|
|
| 64 |
and this function will always return legal moves in UCI format."""
|
| 65 |
try:
|
| 66 |
print("## get_legal_moves")
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
except Exception as e:
|
| 71 |
print(f"An error occurred in get_legal_moves: {e}")
|
| 72 |
return "Error: unable to get legal moves"
|
|
@@ -80,8 +80,6 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
|
|
| 80 |
print("## make_move")
|
| 81 |
move = chess.Move.from_uci(move)
|
| 82 |
board.push_uci(str(move))
|
| 83 |
-
global made_move
|
| 84 |
-
made_move = True
|
| 85 |
|
| 86 |
board_svgs.append(chess.svg.board(
|
| 87 |
board,
|
|
@@ -108,16 +106,6 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
|
|
| 108 |
except Exception as e:
|
| 109 |
print(f"An error occurred in make_move: {e}")
|
| 110 |
return f"Error: unable to make move {move}"
|
| 111 |
-
|
| 112 |
-
def check_made_move(msg):
|
| 113 |
-
print("## check_made_move")
|
| 114 |
-
global made_move
|
| 115 |
-
|
| 116 |
-
if made_move:
|
| 117 |
-
made_move = False
|
| 118 |
-
return True
|
| 119 |
-
else:
|
| 120 |
-
return False
|
| 121 |
|
| 122 |
def get_num_turns(num_moves):
|
| 123 |
# Each turn includes two moves (one by each player)
|
|
@@ -214,9 +202,6 @@ def create_graph():
|
|
| 214 |
graph.add_node("player_black", player_black_node)
|
| 215 |
graph.add_node("manager", supervisor_chain)
|
| 216 |
|
| 217 |
-
# for player in players:
|
| 218 |
-
# graph.add_edge(player, "manager")
|
| 219 |
-
|
| 220 |
graph.add_conditional_edges(
|
| 221 |
"player_white",
|
| 222 |
should_continue,
|
|
@@ -239,10 +224,13 @@ def create_graph():
|
|
| 239 |
|
| 240 |
def should_continue(state):
|
| 241 |
print("#### should_continue")
|
| 242 |
-
global move_num, num_moves
|
| 243 |
if move_num == num_moves:
|
| 244 |
print("False")
|
| 245 |
return END # max moves reached
|
|
|
|
|
|
|
|
|
|
| 246 |
print("True")
|
| 247 |
return "manager"
|
| 248 |
|
|
|
|
| 20 |
|
| 21 |
num_moves = 0
|
| 22 |
move_num = 0
|
| 23 |
+
legal_moves = ""
|
| 24 |
|
| 25 |
class AgentState(TypedDict):
|
| 26 |
messages: Annotated[Sequence[BaseMessage], operator.add]
|
|
|
|
| 45 |
handle_parsing_errors=True,
|
| 46 |
return_intermediate_steps=True,
|
| 47 |
verbose=True,
|
| 48 |
+
max_iterations=get_num_turns(num_moves))
|
| 49 |
|
| 50 |
def agent_node(state, agent, name):
|
| 51 |
try:
|
|
|
|
| 64 |
and this function will always return legal moves in UCI format."""
|
| 65 |
try:
|
| 66 |
print("## get_legal_moves")
|
| 67 |
+
global legal_moves
|
| 68 |
+
legal_moves = ",".join([str(move) for move in board.legal_moves])
|
| 69 |
+
return legal_moves
|
| 70 |
except Exception as e:
|
| 71 |
print(f"An error occurred in get_legal_moves: {e}")
|
| 72 |
return "Error: unable to get legal moves"
|
|
|
|
| 80 |
print("## make_move")
|
| 81 |
move = chess.Move.from_uci(move)
|
| 82 |
board.push_uci(str(move))
|
|
|
|
|
|
|
| 83 |
|
| 84 |
board_svgs.append(chess.svg.board(
|
| 85 |
board,
|
|
|
|
| 106 |
except Exception as e:
|
| 107 |
print(f"An error occurred in make_move: {e}")
|
| 108 |
return f"Error: unable to make move {move}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
def get_num_turns(num_moves):
|
| 111 |
# Each turn includes two moves (one by each player)
|
|
|
|
| 202 |
graph.add_node("player_black", player_black_node)
|
| 203 |
graph.add_node("manager", supervisor_chain)
|
| 204 |
|
|
|
|
|
|
|
|
|
|
| 205 |
graph.add_conditional_edges(
|
| 206 |
"player_white",
|
| 207 |
should_continue,
|
|
|
|
| 224 |
|
| 225 |
def should_continue(state):
|
| 226 |
print("#### should_continue")
|
| 227 |
+
global move_num, num_moves, legal_moves
|
| 228 |
if move_num == num_moves:
|
| 229 |
print("False")
|
| 230 |
return END # max moves reached
|
| 231 |
+
if not legal_moves:
|
| 232 |
+
print("False")
|
| 233 |
+
return END # checkmate or stalemate
|
| 234 |
print("True")
|
| 235 |
return "manager"
|
| 236 |
|