File size: 8,895 Bytes
30b2934
c1295c9
68c2dd2
30b2934
 
c1295c9
 
 
30b2934
c1295c9
 
 
4f26eb7
9c70149
68c2dd2
c1295c9
 
e623936
 
 
a9ce147
30b2934
25c46e7
4b7945b
c1295c9
30b2934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb9aede
30b2934
 
 
07eb3c8
30b2934
 
 
 
 
 
a664616
 
 
 
 
 
4b7945b
9f94153
 
a664616
 
 
 
 
 
 
 
07eb3c8
a664616
 
 
1dbbcfc
 
 
07eb3c8
30b2934
c1295c9
 
 
 
ecfc813
c1295c9
 
a664616
 
 
 
 
 
 
 
 
 
 
 
 
 
4cff6a9
07eb3c8
1af08bd
1dbbcfc
 
07eb3c8
 
 
66bea43
c1295c9
 
 
f6e6c81
c1295c9
 
 
 
 
 
 
 
 
 
 
 
c9f3bc9
c1295c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65207b1
c1295c9
 
c9f3bc9
30b2934
c1295c9
 
1dbbcfc
c1295c9
 
 
1dbbcfc
2769590
1dbbcfc
 
07eb3c8
2769590
6a9a096
1dbbcfc
2769590
1dbbcfc
 
07eb3c8
2769590
c1295c9
4cff6a9
1dbbcfc
 
4cff6a9
 
1dbbcfc
 
 
c1295c9
b3a85bf
 
 
f6e6c81
b3a85bf
 
 
 
 
f6e6c81
b3a85bf
 
f6e6c81
716d63d
4cff6a9
c1295c9
b3a85bf
9f94153
40b8162
89659e8
b3a85bf
1dbbcfc
9f94153
 
d20cc76
f6e6c81
c1295c9
25c46e7
 
 
 
 
 
 
1e1c635
25c46e7
 
c1295c9
07eb3c8
9d5c872
 
7c01511
1af08bd
1dbbcfc
07eb3c8
c1295c9
1360816
68c2dd2
1360816
1dbbcfc
0690d4a
e3495d6
3991e5f
 
5a1290f
0690d4a
1360816
 
c1295c9
1dbbcfc
2394d16
 
9adbf53
 
 
 
 
 
 
 
 
 
1dbbcfc
9adbf53
 
 
 
 
68c2dd2
1dbbcfc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import chess, chess.svg, math
import functools, operator

from datetime import date

from typing import Annotated, Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union

from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

from langgraph.graph import StateGraph, END
        
# Module-level game state shared by the tools and graph helpers below;
# initialize() resets all of it at the start of each game.
board = None  # becomes a chess.Board in initialize()
board_svgs = None  # becomes a list; make_move appends one SVG per move played

num_moves, move_num = 0, 0  # move budget / number of make_move calls so far

legal_moves = ""  # last comma-separated UCI list produced by get_legal_moves

class AgentState(TypedDict):
    """State schema for the chess game graph (passed to StateGraph)."""

    # Conversation so far. The operator.add metadata is presumably used by
    # LangGraph as a reducer, so node outputs get appended rather than
    # replacing the list — confirm against the LangGraph docs.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Player chosen by the router node ("player_white" / "player_black");
    # read by the conditional edge out of "chess_board_proxy".
    next: str

def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Wrap an LLM in a tool-calling AgentExecutor.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call; also handed to the executor.
        system_prompt: System message placed before the conversation.

    Returns:
        An AgentExecutor that recovers from parsing errors, returns its
        intermediate steps, logs verbosely and stops after 5 iterations.
    """
    template_parts = [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        # Slot where the agent's own tool calls and results are threaded in.
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    executor_options = dict(
        handle_parsing_errors=True,
        return_intermediate_steps=True,
        verbose=True,
        max_iterations=5,
    )

    runnable = create_openai_tools_agent(
        llm, tools, ChatPromptTemplate.from_messages(template_parts)
    )
    return AgentExecutor(agent=runnable, tools=tools, **executor_options)

def agent_node(state, agent, name):
    """Run one agent turn and package its answer for the graph state.

    Args:
        state: Graph state handed straight to ``agent.invoke``.
        agent: Executor whose result dict carries an ``"output"`` entry.
        name: Name attached to the produced message (the player id).

    Returns:
        ``{"messages": [HumanMessage]}`` holding either the agent output or,
        if anything in the turn raised, an ``"Error: ..."`` message.
    """
    try:
        outcome = agent.invoke(state)
        reply = HumanMessage(content=outcome["output"], name=name)
    except Exception as e:
        # Best effort: report the failure as a message instead of aborting the graph.
        print(f"An error occurred in agent_node: {e}")
        reply = HumanMessage(content=f"Error: {e}", name=name)
    return {"messages": [reply]}

@tool
def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format"]:
    """Returns a list of legal moves in UCI format. 
       The input should always be an empty string, 
       and this function will always return legal moves in UCI format."""
    # The docstring above is presumably the tool description that @tool sends
    # to the LLM, so it is kept verbatim.
    global legal_moves
    try:
        # Also remembered in the module-level legal_moves global.
        legal_moves = ",".join(map(str, board.legal_moves))
    except Exception as e:
        print(f"An error occurred in get_legal_moves: {e}")
        return "Error: unable to get legal moves"
    return legal_moves

@tool
def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "Result of the move."]:
    """Makes a move. 
       The input should always be a move in UCI format, 
       and this function will always return the result of the move."""
    # The docstring above is presumably the tool description that @tool sends
    # to the LLM, so it is kept verbatim. On success this pushes the move on
    # the global board, bumps move_num and records an SVG snapshot; on any
    # failure it returns an "Error: ..." string instead of raising.
    global move_num
    try:
        move = chess.Move.from_uci(move)
        board.push_uci(str(move))

        move_num += 1
        print(f"move_num: {move_num}")

        # Snapshot of the position after the move, arrow from origin to target.
        board_svgs.append(chess.svg.board(
            board,
            arrows=[(move.from_square, move.to_square)],
            fill={move.from_square: "gray"},
            size=600
        ))

        piece = board.piece_at(move.to_square)
        piece_symbol = piece.unicode_symbol()
        # Capitalise the name of white pieces. Decide by colour: unicode chess
        # glyphs such as "♔" have no letter case, so str.isupper() on them is
        # always False and can't tell white from black.
        piece_name = chess.piece_name(piece.piece_type)
        if piece.color == chess.WHITE:
            piece_name = piece_name.capitalize()

        return f"Moved {piece_name} ({piece_symbol}) from "\
               f"{chess.SQUARE_NAMES[move.from_square]} to "\
               f"{chess.SQUARE_NAMES[move.to_square]}."
    except Exception as e:
        print(f"An error occurred in make_move: {e}")
        return f"Error: unable to make move {move}"
    
def _player_system_prompt(color):
    """Return the system prompt for the agent playing ``color`` ("white" or "black")."""
    return (
        f"You are a chess Grandmaster and you play as {color}. "
        "First call get_legal_moves() to get a list of legal moves. "
        "Then study the returned moves and call make_move(move) to make the best move. "
        "Finally analyze the move: **Analysis:** move in UCI format, emoji of piece, unordered list of 3 items."
    )


def create_graph(llm_board, llm_white, llm_black):
    """Build and compile the game graph.

    Nodes: a router ("chess_board_proxy") that asks an LLM which player moves
    next, and one tool-calling agent per player. After every player turn,
    should_continue either routes back to the router or ends the run.

    Args:
        llm_board: Model name for the router LLM.
        llm_white: Model name for the white player's LLM.
        llm_black: Model name for the black player's LLM.

    Returns:
        The compiled graph.
    """
    players = ["player_white", "player_black"]

    system_prompt = (
        "You are a Chess Board Proxy tasked with managing a game of chess "
        "between player_white and player_black. player_white makes the first move, "
        "then the players take turns."
    )

    # Function schema the router must call: {"next": <one of players>}.
    function_def = {
        "name": "route",
        "description": "Select the next player.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": players},
                    ],
                }
            },
            "required": ["next"],
        },
    }

    # {options} is the only template variable, so it is the only one bound.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                "If player_white made a move, player_black must make the next move. "
                "If player_black made a move, player_white must make the next move. "
                "Select one of: {options}.",
            ),
        ]
    ).partial(options=str(players))

    llm_board_proxy = ChatOpenAI(model=llm_board)
    supervisor_chain = (
        prompt
        | llm_board_proxy.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )

    graph = StateGraph(AgentState)
    graph.add_node("chess_board_proxy", supervisor_chain)

    # Both players are configured identically apart from colour and model.
    for player, color, model in (
        ("player_white", "white", llm_white),
        ("player_black", "black", llm_black),
    ):
        agent = create_agent(
            ChatOpenAI(model=model),
            [get_legal_moves, make_move],
            system_prompt=_player_system_prompt(color),
        )
        graph.add_node(player, functools.partial(agent_node, agent=agent, name=player))

    # Router -> the player named in its "next" output.
    graph.add_conditional_edges(
        "chess_board_proxy", lambda x: x["next"], {p: p for p in players}
    )

    # Player -> router again, or END when should_continue says so.
    for player in players:
        graph.add_conditional_edges(
            player,
            should_continue,
            {"chess_board_proxy": "chess_board_proxy", END: END},
        )

    graph.set_entry_point("chess_board_proxy")

    return graph.compile()

def should_continue(state):
    """Decide whether the game continues after a player's turn.

    Args:
        state: Current graph state (unused; the decision is based on the
            module-level game globals).

    Returns:
        ``END`` once ``move_num`` has reached the ``num_moves`` budget or the
        game on ``board`` is over; otherwise ``"chess_board_proxy"`` so the
        router picks the next player.
    """
    # ">=" rather than "==": one agent turn may call make_move more than once,
    # which would step past the budget and never match an exact comparison.
    if move_num >= num_moves:
        return END  # max moves reached

    # Ask the board itself. The cached legal_moves string is only refreshed
    # when an agent calls get_legal_moves, so it is stale right after a mating
    # move and is still "" if no agent has asked for legal moves yet.
    if board is not None:
        game_over = board.is_game_over()
    else:
        game_over = not legal_moves
    if game_over:
        return END  # checkmate, stalemate or another automatic draw

    return "chess_board_proxy"

def initialize():
    """Reset the module-level game state for a fresh game.

    Creates a new chess.Board in its starting position, an empty list of
    board SVGs, zeroes the move budget and move counter, and clears the
    cached legal-move string.
    """
    global board, board_svgs, num_moves, move_num, legal_moves

    board, board_svgs = chess.Board(), []
    num_moves = move_num = 0
    legal_moves = ""

def _messages_to_markdown(messages, svgs, max_moves):
    """Render the game's messages as a Markdown transcript.

    Args:
        messages: Graph messages; index 0 is the opening prompt and is not
            rendered, later entries are player turns (white first).
        svgs: Board SVG snapshots recorded by make_move, one per move played.
        max_moves: The move budget the game was run with.

    Returns:
        One Markdown section per rendered player message.
    """
    sections = []
    for index, message in enumerate(messages):
        if index > 0:
            # White moves first, so odd indices are white's turns.
            player = "Player White" if index % 2 else "Player Black"
            section = f"**{player}, Move {index}**\n{message.content}\n"
            # A failed turn (e.g. agent_node's "Error: ..." message) adds a
            # message but no SVG, so svgs can be shorter than messages.
            if index - 1 < len(svgs):
                section += f"{svgs[index - 1]}\n"
            sections.append(section + "\n")
        # NOTE(review): the original only truncated for even budgets; kept
        # as-is, the reason is not visible from this file.
        if max_moves % 2 == 0 and index == max_moves:
            break
    return "".join(sections)


def run_multi_agent(llm_board, llm_white, llm_black, moves_num):
    """Play a game between two LLM agents and return a Markdown transcript.

    Args:
        llm_board: Model name for the router LLM.
        llm_white: Model name for the white player's LLM.
        llm_black: Model name for the black player's LLM.
        moves_num: Move budget, compared against the number of make_move
            calls by should_continue.

    Returns:
        Markdown with a heading, the message text and (when available) the
        board SVG for each player turn; an empty string if the graph run
        raised before returning a result.
    """
    global num_moves

    initialize()
    num_moves = moves_num

    graph = create_graph(llm_board, llm_white, llm_black)

    # A dict, so the "messages" check below is a key lookup, not a substring test.
    result = {}

    try:
        config = {"recursion_limit": 100}

        result = graph.invoke({
            "messages": [
                HumanMessage(content="Let's play chess, player_white starts.")
            ]
        }, config=config)
    except Exception as e:
        print(f"An error occurred: {e}")

    messages = result["messages"] if "messages" in result else []
    return _messages_to_markdown(messages, board_svgs, num_moves)