Update multi_agent.py
Browse files- multi_agent.py +39 -64
multi_agent.py
CHANGED
|
@@ -28,9 +28,6 @@ class AgentState(TypedDict):
|
|
| 28 |
next: str
|
| 29 |
|
| 30 |
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
|
| 31 |
-
print("## create_agent")
|
| 32 |
-
global num_moves
|
| 33 |
-
|
| 34 |
prompt = ChatPromptTemplate.from_messages(
|
| 35 |
[
|
| 36 |
("system", system_prompt),
|
|
@@ -46,12 +43,11 @@ def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
|
|
| 46 |
handle_parsing_errors=True,
|
| 47 |
return_intermediate_steps=True,
|
| 48 |
verbose=True,
|
| 49 |
-
max_iterations=
|
| 50 |
|
| 51 |
def agent_node(state, agent, name):
|
| 52 |
try:
|
| 53 |
-
print("##
|
| 54 |
-
print("## state=" + str(state))
|
| 55 |
result = agent.invoke(state)
|
| 56 |
return {"messages": [HumanMessage(content=result["output"], name=name)]}
|
| 57 |
except Exception as e:
|
|
@@ -64,7 +60,7 @@ def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format"]:
|
|
| 64 |
The input should always be an empty string,
|
| 65 |
and this function will always return legal moves in UCI format."""
|
| 66 |
try:
|
| 67 |
-
print("## get_legal_moves")
|
| 68 |
global legal_moves
|
| 69 |
legal_moves = ",".join([str(move) for move in board.legal_moves])
|
| 70 |
return legal_moves
|
|
@@ -78,9 +74,13 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
|
|
| 78 |
The input should always be a move in UCI format,
|
| 79 |
and this function will always return the result of the move in UCI format."""
|
| 80 |
try:
|
| 81 |
-
print("## make_move")
|
| 82 |
move = chess.Move.from_uci(move)
|
| 83 |
board.push_uci(str(move))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
board_svgs.append(chess.svg.board(
|
| 86 |
board,
|
|
@@ -96,10 +96,6 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
|
|
| 96 |
if piece_symbol.isupper()
|
| 97 |
else chess.piece_name(piece.piece_type)
|
| 98 |
)
|
| 99 |
-
|
| 100 |
-
global move_num
|
| 101 |
-
move_num += 1
|
| 102 |
-
print("## move_num=" + str(move_num))
|
| 103 |
|
| 104 |
return f"Moved {piece_name} ({piece_symbol}) from "\
|
| 105 |
f"{chess.SQUARE_NAMES[move.from_square]} to "\
|
|
@@ -109,9 +105,12 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
|
|
| 109 |
return f"Error: unable to make move {move}"
|
| 110 |
|
| 111 |
def create_graph():
|
| 112 |
-
print("## create_graph")
|
| 113 |
-
|
| 114 |
players = ["player_white", "player_black"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
system_prompt = (
|
| 117 |
"You are a Chess Board Proxy tasked with managing a game of chess "
|
|
@@ -119,9 +118,6 @@ def create_graph():
|
|
| 119 |
"then the players take turns."
|
| 120 |
)
|
| 121 |
|
| 122 |
-
#options = ["FINISH"] + players
|
| 123 |
-
options = players
|
| 124 |
-
|
| 125 |
function_def = {
|
| 126 |
"name": "route",
|
| 127 |
"description": "Select the next player.",
|
|
@@ -153,34 +149,35 @@ def create_graph():
|
|
| 153 |
]
|
| 154 |
).partial(options=str(options), members=", ".join(players), verbose=True)
|
| 155 |
|
| 156 |
-
llm_1 = ChatOpenAI(model="gpt-4o")
|
| 157 |
-
llm_2 = ChatOpenAI(model="gpt-4o")
|
| 158 |
-
llm_3 = ChatOpenAI(model="gpt-4o")
|
| 159 |
-
|
| 160 |
supervisor_chain = (
|
| 161 |
prompt
|
| 162 |
-
|
|
| 163 |
| JsonOutputFunctionsParser()
|
| 164 |
)
|
| 165 |
|
| 166 |
-
player_white_agent = create_agent(
|
| 167 |
"You are a chess Grandmaster and you play as white. "
|
| 168 |
-
"
|
| 169 |
-
"
|
| 170 |
-
|
| 171 |
player_white_node = functools.partial(agent_node, agent=player_white_agent, name="player_white")
|
| 172 |
|
| 173 |
-
player_black_agent = create_agent(
|
| 174 |
"You are a chess Grandmaster and you play as black. "
|
| 175 |
-
"
|
| 176 |
-
"
|
| 177 |
-
|
| 178 |
player_black_node = functools.partial(agent_node, agent=player_black_agent, name="player_black")
|
| 179 |
|
| 180 |
graph = StateGraph(AgentState)
|
|
|
|
|
|
|
| 181 |
graph.add_node("player_white", player_white_node)
|
| 182 |
graph.add_node("player_black", player_black_node)
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
graph.add_conditional_edges(
|
| 186 |
"player_white",
|
|
@@ -194,10 +191,6 @@ def create_graph():
|
|
| 194 |
{"chess_board_proxy": "chess_board_proxy", END: END}
|
| 195 |
)
|
| 196 |
|
| 197 |
-
conditional_map = {k: k for k in players}
|
| 198 |
-
conditional_map["END"] = END
|
| 199 |
-
|
| 200 |
-
graph.add_conditional_edges("chess_board_proxy", lambda x: x["next"], conditional_map)
|
| 201 |
graph.set_entry_point("chess_board_proxy")
|
| 202 |
|
| 203 |
return graph.compile()
|
|
@@ -205,13 +198,13 @@ def create_graph():
|
|
| 205 |
def should_continue(state):
|
| 206 |
print("#### should_continue")
|
| 207 |
global move_num, num_moves, legal_moves
|
|
|
|
| 208 |
if move_num == num_moves:
|
| 209 |
-
print("False (move_num == num_moves)")
|
| 210 |
return END # max moves reached
|
|
|
|
| 211 |
if not legal_moves:
|
| 212 |
-
print("False (not legal_moves)")
|
| 213 |
return END # checkmate or stalemate
|
| 214 |
-
|
| 215 |
return "chess_board_proxy"
|
| 216 |
|
| 217 |
def initialize():
|
|
@@ -226,21 +219,17 @@ def initialize():
|
|
| 226 |
legal_moves = ""
|
| 227 |
|
| 228 |
def run_multi_agent(moves_num):
|
| 229 |
-
initialize()
|
| 230 |
-
|
| 231 |
global num_moves
|
| 232 |
-
|
| 233 |
num_moves = moves_num
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
graph = create_graph()
|
| 239 |
|
| 240 |
result = ""
|
| 241 |
|
| 242 |
try:
|
| 243 |
-
config = {"recursion_limit":
|
| 244 |
|
| 245 |
result = graph.invoke({
|
| 246 |
"messages": [
|
|
@@ -250,22 +239,10 @@ def run_multi_agent(moves_num):
|
|
| 250 |
except Exception as e:
|
| 251 |
print(f"An error occurred: {e}")
|
| 252 |
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
result2 = ""
|
| 256 |
num_move = 0
|
| 257 |
|
| 258 |
-
"""
|
| 259 |
-
for message in result["messages"]:
|
| 260 |
-
if message.name:
|
| 261 |
-
print(f"{message.name}: {message.content}")
|
| 262 |
-
else:
|
| 263 |
-
print(message.content)
|
| 264 |
-
"""
|
| 265 |
|
| 266 |
-
print("### "+ str(type(result)))
|
| 267 |
-
print("### "+ str(len(result["messages"])))
|
| 268 |
-
|
| 269 |
if "messages" in result:
|
| 270 |
for message in result["messages"]:
|
| 271 |
player = ""
|
|
@@ -276,7 +253,7 @@ def run_multi_agent(moves_num):
|
|
| 276 |
player = "Player White"
|
| 277 |
|
| 278 |
if num_move > 0:
|
| 279 |
-
|
| 280 |
|
| 281 |
num_move += 1
|
| 282 |
|
|
@@ -284,9 +261,7 @@ def run_multi_agent(moves_num):
|
|
| 284 |
break
|
| 285 |
|
| 286 |
print("===")
|
| 287 |
-
print(
|
| 288 |
-
print("===")
|
| 289 |
-
print(str(result2))
|
| 290 |
print("===")
|
| 291 |
|
| 292 |
-
return
|
|
|
|
| 28 |
next: str
|
| 29 |
|
| 30 |
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
|
|
|
|
|
|
|
|
|
|
| 31 |
prompt = ChatPromptTemplate.from_messages(
|
| 32 |
[
|
| 33 |
("system", system_prompt),
|
|
|
|
| 43 |
handle_parsing_errors=True,
|
| 44 |
return_intermediate_steps=True,
|
| 45 |
verbose=True,
|
| 46 |
+
max_iterations=2) # get_legal_moves & make_move
|
| 47 |
|
| 48 |
def agent_node(state, agent, name):
|
| 49 |
try:
|
| 50 |
+
print("### node=" + name)
|
|
|
|
| 51 |
result = agent.invoke(state)
|
| 52 |
return {"messages": [HumanMessage(content=result["output"], name=name)]}
|
| 53 |
except Exception as e:
|
|
|
|
| 60 |
The input should always be an empty string,
|
| 61 |
and this function will always return legal moves in UCI format."""
|
| 62 |
try:
|
| 63 |
+
print("### get_legal_moves")
|
| 64 |
global legal_moves
|
| 65 |
legal_moves = ",".join([str(move) for move in board.legal_moves])
|
| 66 |
return legal_moves
|
|
|
|
| 74 |
The input should always be a move in UCI format,
|
| 75 |
and this function will always return the result of the move in UCI format."""
|
| 76 |
try:
|
| 77 |
+
print("### make_move")
|
| 78 |
move = chess.Move.from_uci(move)
|
| 79 |
board.push_uci(str(move))
|
| 80 |
+
|
| 81 |
+
global move_num
|
| 82 |
+
move_num += 1
|
| 83 |
+
print("### move_num=" + str(move_num))
|
| 84 |
|
| 85 |
board_svgs.append(chess.svg.board(
|
| 86 |
board,
|
|
|
|
| 96 |
if piece_symbol.isupper()
|
| 97 |
else chess.piece_name(piece.piece_type)
|
| 98 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
return f"Moved {piece_name} ({piece_symbol}) from "\
|
| 101 |
f"{chess.SQUARE_NAMES[move.from_square]} to "\
|
|
|
|
| 105 |
return f"Error: unable to make move {move}"
|
| 106 |
|
| 107 |
def create_graph():
|
|
|
|
|
|
|
| 108 |
players = ["player_white", "player_black"]
|
| 109 |
+
options = players
|
| 110 |
+
|
| 111 |
+
llm_board_proxy = ChatOpenAI(model="gpt-4o")
|
| 112 |
+
llm_player_white = ChatOpenAI(model="gpt-4o")
|
| 113 |
+
llm_player_black = ChatOpenAI(model="gpt-4o")
|
| 114 |
|
| 115 |
system_prompt = (
|
| 116 |
"You are a Chess Board Proxy tasked with managing a game of chess "
|
|
|
|
| 118 |
"then the players take turns."
|
| 119 |
)
|
| 120 |
|
|
|
|
|
|
|
|
|
|
| 121 |
function_def = {
|
| 122 |
"name": "route",
|
| 123 |
"description": "Select the next player.",
|
|
|
|
| 149 |
]
|
| 150 |
).partial(options=str(options), members=", ".join(players), verbose=True)
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
supervisor_chain = (
|
| 153 |
prompt
|
| 154 |
+
| llm_board_proxy.bind_functions(functions=[function_def], function_call="route")
|
| 155 |
| JsonOutputFunctionsParser()
|
| 156 |
)
|
| 157 |
|
| 158 |
+
player_white_agent = create_agent(llm_player_white, [get_legal_moves, make_move], system_prompt=
|
| 159 |
"You are a chess Grandmaster and you play as white. "
|
| 160 |
+
"First call get_legal_moves() to get a list of legal moves. "
|
| 161 |
+
"Then study the returned moves and call make_move(move) to make the best move. "
|
| 162 |
+
"Finally analyze the move in format: **Analysis:** move in UCI format, emoji of piece, unordered list of 3 items.")
|
| 163 |
player_white_node = functools.partial(agent_node, agent=player_white_agent, name="player_white")
|
| 164 |
|
| 165 |
+
player_black_agent = create_agent(llm_player_black, [get_legal_moves, make_move], system_prompt=
|
| 166 |
"You are a chess Grandmaster and you play as black. "
|
| 167 |
+
"First call get_legal_moves() to get a list of legal moves. "
|
| 168 |
+
"Then study the returned moves and call make_move(move) to make the best move. "
|
| 169 |
+
"Finally analyze the move in format: **Analysis:** move in UCI format, emoji of piece, unordered list of 3 items.")
|
| 170 |
player_black_node = functools.partial(agent_node, agent=player_black_agent, name="player_black")
|
| 171 |
|
| 172 |
graph = StateGraph(AgentState)
|
| 173 |
+
|
| 174 |
+
graph.add_node("chess_board_proxy", supervisor_chain)
|
| 175 |
graph.add_node("player_white", player_white_node)
|
| 176 |
graph.add_node("player_black", player_black_node)
|
| 177 |
+
|
| 178 |
+
conditional_map = {k: k for k in players}
|
| 179 |
+
conditional_map["END"] = END
|
| 180 |
+
graph.add_conditional_edges("chess_board_proxy", lambda x: x["next"], conditional_map)
|
| 181 |
|
| 182 |
graph.add_conditional_edges(
|
| 183 |
"player_white",
|
|
|
|
| 191 |
{"chess_board_proxy": "chess_board_proxy", END: END}
|
| 192 |
)
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
graph.set_entry_point("chess_board_proxy")
|
| 195 |
|
| 196 |
return graph.compile()
|
|
|
|
| 198 |
def should_continue(state):
|
| 199 |
print("#### should_continue")
|
| 200 |
global move_num, num_moves, legal_moves
|
| 201 |
+
|
| 202 |
if move_num == num_moves:
|
|
|
|
| 203 |
return END # max moves reached
|
| 204 |
+
|
| 205 |
if not legal_moves:
|
|
|
|
| 206 |
return END # checkmate or stalemate
|
| 207 |
+
|
| 208 |
return "chess_board_proxy"
|
| 209 |
|
| 210 |
def initialize():
|
|
|
|
| 219 |
legal_moves = ""
|
| 220 |
|
| 221 |
def run_multi_agent(moves_num):
|
|
|
|
|
|
|
| 222 |
global num_moves
|
|
|
|
| 223 |
num_moves = moves_num
|
| 224 |
+
|
| 225 |
+
initialize()
|
| 226 |
+
|
|
|
|
| 227 |
graph = create_graph()
|
| 228 |
|
| 229 |
result = ""
|
| 230 |
|
| 231 |
try:
|
| 232 |
+
config = {"recursion_limit": 100}
|
| 233 |
|
| 234 |
result = graph.invoke({
|
| 235 |
"messages": [
|
|
|
|
| 239 |
except Exception as e:
|
| 240 |
print(f"An error occurred: {e}")
|
| 241 |
|
| 242 |
+
result_md = ""
|
|
|
|
|
|
|
| 243 |
num_move = 0
|
| 244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
|
|
|
|
|
|
|
|
|
| 246 |
if "messages" in result:
|
| 247 |
for message in result["messages"]:
|
| 248 |
player = ""
|
|
|
|
| 253 |
player = "Player White"
|
| 254 |
|
| 255 |
if num_move > 0:
|
| 256 |
+
result_md += f"**{player}, Move {num_move}**\n{message.content}\n{board_svgs[num_move - 1]}\n\n"
|
| 257 |
|
| 258 |
num_move += 1
|
| 259 |
|
|
|
|
| 261 |
break
|
| 262 |
|
| 263 |
print("===")
|
| 264 |
+
print(result_md)
|
|
|
|
|
|
|
| 265 |
print("===")
|
| 266 |
|
| 267 |
+
return result_md
|