"""Multi-agent chess: a LangGraph supervisor routes turns between two LLM players.

Game state (the board, the rendered SVG frames and the move counters) lives in
module-level globals that ``initialize()`` resets and the tools mutate.
"""

import functools
import math
import operator
from datetime import date
from typing import Annotated, Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union

import chess
import chess.svg
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END

# Shared game state; reset by initialize() at the start of every run.
board = None        # chess.Board of the game in progress
board_svgs = None   # list of SVG strings, one appended per successful make_move
num_moves = 0       # move limit requested by the caller of run_multi_agent
move_num = 0        # number of moves successfully made so far
legal_moves = ""    # comma-separated UCI moves from the last get_legal_moves call


class AgentState(TypedDict):
    """State passed between graph nodes."""

    # operator.add is the reducer: messages returned by a node are appended.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the node the supervisor selected to run next.
    next: str


def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build an ``AgentExecutor`` for an OpenAI-tools agent.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call.
        system_prompt: System message placed ahead of the conversation.

    Returns:
        An ``AgentExecutor`` limited to 5 iterations that returns its
        intermediate steps and tolerates parsing errors.
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_tools_agent(llm, tools, prompt)
    return AgentExecutor(
        agent=agent,
        tools=tools,
        handle_parsing_errors=True,
        return_intermediate_steps=True,
        verbose=True,
        max_iterations=5,
    )


def agent_node(state, agent, name):
    """Run ``agent`` on ``state`` and wrap its output as a message named ``name``.

    Failures are reported as an "Error: ..." message instead of being raised,
    so one failing turn does not abort the whole graph run.
    """
    try:
        result = agent.invoke(state)
        return {"messages": [HumanMessage(content=result["output"], name=name)]}
    except Exception as e:
        print(f"An error occurred in agent_node: {e}")
        return {"messages": [HumanMessage(content=f"Error: {e}", name=name)]}


# NOTE: the docstrings of the two @tool functions below are used as the tool
# descriptions given to the model, so they are runtime text - edit with care.
@tool
def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format"]:
    """Returns a list of legal moves in UCI format.
    The input should always be an empty string, and this function will always return legal moves in UCI format."""
    global legal_moves
    try:
        # Cache the result in the module global as well as returning it.
        legal_moves = ",".join([str(move) for move in board.legal_moves])
        return legal_moves
    except Exception as e:
        print(f"An error occurred in get_legal_moves: {e}")
        return "Error: unable to get legal moves"


@tool
def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "Result of the move."]:
    """Makes a move. The input should always be a move in UCI format, and this function will always return the result of the move."""
    global move_num
    try:
        # push_uci parses the string, rejects illegal moves and returns the Move.
        played = board.push_uci(move)
        move_num += 1
        print(f"move_num: {str(move_num)}")

        board_svgs.append(
            chess.svg.board(
                board,
                arrows=[(played.from_square, played.to_square)],
                fill={played.from_square: "gray"},
                size=600,
            )
        )

        piece = board.piece_at(played.to_square)
        piece_symbol = piece.unicode_symbol()
        piece_name = chess.piece_name(piece.piece_type)
        # Capitalize white pieces. The Unicode chess glyphs are uncased, so the
        # colour has to come from the piece itself, not from str.isupper().
        if piece.color == chess.WHITE:
            piece_name = piece_name.capitalize()

        return (
            f"Moved {piece_name} ({piece_symbol}) from "
            f"{chess.SQUARE_NAMES[played.from_square]} to "
            f"{chess.SQUARE_NAMES[played.to_square]}."
        )
    except Exception as e:
        print(f"An error occurred in make_move: {e}")
        return f"Error: unable to make move {move}"


def create_graph(llm_board, llm_white, llm_black):
    """Assemble and compile the supervisor/player graph.

    Args:
        llm_board: Model name for the routing supervisor ("chess_board_proxy").
        llm_white: Model name for the white player.
        llm_black: Model name for the black player.

    Returns:
        The compiled graph, entered at the supervisor node.
    """
    players = ["player_white", "player_black"]
    options = players

    llm_board_proxy = ChatOpenAI(model=llm_board)
    llm_player_white = ChatOpenAI(model=llm_white)
    llm_player_black = ChatOpenAI(model=llm_black)

    system_prompt = (
        "You are a Chess Board Proxy tasked with managing a game of chess "
        "between player_white and player_black. player_white makes the first move, "
        "then the players take turns."
    )

    # Function schema that forces the supervisor to answer with {"next": <player>}.
    function_def = {
        "name": "route",
        "description": "Select the next player.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                }
            },
            "required": ["next"],
        },
    }

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                "If player_white made a move, player_black must make the next move. "
                "If player_black made a move, player_white must make the next move. "
                "Select one of: {options}.",
            ),
        ]
    ).partial(options=str(options), members=", ".join(players), verbose=True)

    supervisor_chain = (
        prompt
        | llm_board_proxy.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )

    # Both players get the same instructions apart from the colour they play.
    player_prompt = (
        "You are a chess Grandmaster and you play as {color}. "
        "First call get_legal_moves() to get a list of legal moves. "
        "Then study the returned moves and call make_move(move) to make the best move. "
        "Finally analyze the move: **Analysis:** move in UCI format, emoji of piece, unordered list of 3 items."
    )
    player_tools = [get_legal_moves, make_move]

    player_white_agent = create_agent(
        llm_player_white, player_tools, system_prompt=player_prompt.format(color="white")
    )
    player_white_node = functools.partial(agent_node, agent=player_white_agent, name="player_white")

    player_black_agent = create_agent(
        llm_player_black, player_tools, system_prompt=player_prompt.format(color="black")
    )
    player_black_node = functools.partial(agent_node, agent=player_black_agent, name="player_black")

    graph = StateGraph(AgentState)
    graph.add_node("chess_board_proxy", supervisor_chain)
    graph.add_node("player_white", player_white_node)
    graph.add_node("player_black", player_black_node)

    # The supervisor's "next" value names the player node to run.
    conditional_map = {k: k for k in players}
    graph.add_conditional_edges("chess_board_proxy", lambda x: x["next"], conditional_map)

    # After each player turn, either hand back to the supervisor or stop.
    after_turn = {"chess_board_proxy": "chess_board_proxy", END: END}
    graph.add_conditional_edges("player_white", should_continue, after_turn)
    graph.add_conditional_edges("player_black", should_continue, after_turn)

    graph.set_entry_point("chess_board_proxy")
    return graph.compile()


def should_continue(state):
    """Decide whether play returns to the supervisor or the graph ends.

    Returns ``END`` once the move limit is reached or the side to move has no
    legal moves; otherwise returns ``"chess_board_proxy"``.
    """
    # ">=" rather than "==": an agent may call make_move more than once in a
    # turn, which would otherwise step over the limit and never stop.
    if move_num >= num_moves:
        return END  # max moves reached
    # Look at the live board; the cached `legal_moves` string predates the
    # move that was just played and is empty if the agent never requested it.
    if board is not None and not any(board.legal_moves):
        return END  # checkmate or stalemate
    return "chess_board_proxy"


def initialize():
    """Reset all module-level game state for a fresh game."""
    global board, board_svgs, num_moves, move_num, legal_moves
    board = chess.Board()
    board_svgs = []
    num_moves = 0
    move_num = 0
    legal_moves = ""


def run_multi_agent(llm_board, llm_white, llm_black, moves_num):
    """Play a game and return a Markdown transcript with one section per turn.

    Args:
        llm_board: Model name for the supervisor.
        llm_white: Model name for the white player.
        llm_black: Model name for the black player.
        moves_num: Maximum number of moves to play.

    Returns:
        Markdown text; each section holds the player's message followed by the
        board SVG for that turn. Empty if the graph run failed.
    """
    global num_moves
    initialize()
    num_moves = moves_num

    graph = create_graph(llm_board, llm_white, llm_black)

    result = {}
    try:
        config = {"recursion_limit": 100}
        result = graph.invoke(
            {
                "messages": [
                    HumanMessage(content="Let's play chess, player_white starts.")
                ]
            },
            config=config,
        )
    except Exception as e:
        print(f"An error occurred: {e}")

    sections = []
    messages = result["messages"] if "messages" in result else []
    for turn, message in enumerate(messages):
        # Message 0 is the opening human prompt, not a player turn.
        if turn > 0:
            # White moves first, so odd turns are white and even turns black.
            player = "Player Black" if turn % 2 == 0 else "Player White"
            # A turn that ended in an error has no rendered board; show the
            # message without one instead of raising IndexError.
            svg = board_svgs[turn - 1] if turn - 1 < len(board_svgs) else ""
            sections.append(f"**{player}, Move {turn}**\n{message.content}\n{svg}\n\n")
        # NOTE(review): kept from the original - the transcript is cut off after
        # `num_moves` turns only when the limit is even; confirm this is intended.
        if num_moves % 2 == 0 and turn == num_moves:
            break

    return "".join(sections)