|
|
import chess, chess.svg, math |
|
|
import functools, operator |
|
|
|
|
|
from datetime import date |
|
|
|
|
|
from typing import Annotated, Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union |
|
|
|
|
|
from langchain.agents import AgentExecutor, create_openai_tools_agent |
|
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
|
from langchain_core.messages import BaseMessage, HumanMessage |
|
|
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser |
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
|
from langchain_core.tools import tool |
|
|
from langchain_openai import ChatOpenAI |
|
|
|
|
|
from langgraph.graph import StateGraph, END |
|
|
|
|
|
# Module-level game state shared by the tool functions and graph callbacks
# in this file. initialize() resets every value at the start of a run.

# Current chess board; None until initialize() creates a chess.Board.
board = None
# List of rendered board SVGs; make_move appends one entry per move made.
board_svgs = None

# Number of moves requested for the current run (set by run_multi_agent).
num_moves = 0
# Number of moves made so far (incremented by make_move).
move_num = 0

# Comma-separated UCI moves from the most recent get_legal_moves call.
legal_moves = ""
|
|
|
|
|
class AgentState(TypedDict):
    """State dictionary handed to the StateGraph built in create_graph."""

    # Conversation history. The operator.add annotation is presumably what
    # LangGraph uses as the reducer, so each node's returned messages are
    # appended rather than replacing the list -- confirm against LangGraph docs.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the node to run next; read by the routing lambda in create_graph.
    next: str
|
|
|
|
|
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build an AgentExecutor around an OpenAI tools agent.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call.
        system_prompt: Text used as the system message of the agent prompt.

    Returns:
        An AgentExecutor configured to handle parsing errors, return
        intermediate steps, log verbosely and stop after 5 iterations.
    """
    # System instructions first, then the running conversation, then the
    # scratchpad slot required by create_openai_tools_agent.
    prompt_messages = [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    tools_agent = create_openai_tools_agent(
        llm,
        tools,
        ChatPromptTemplate.from_messages(prompt_messages),
    )

    executor_options = {
        "handle_parsing_errors": True,
        "return_intermediate_steps": True,
        "verbose": True,
        "max_iterations": 5,
    }
    return AgentExecutor(agent=tools_agent, tools=tools, **executor_options)
|
|
|
|
|
def agent_node(state, agent, name):
    """Run one agent on the graph state and wrap its answer as a message.

    Args:
        state: Current graph state, passed unchanged to ``agent.invoke``.
        agent: Object exposing ``invoke(state)`` whose result has an
            ``"output"`` entry.
        name: Name attached to the produced message.

    Returns:
        A dict with a single-element ``"messages"`` list. On any exception the
        message content is ``"Error: <exception>"`` instead of the agent output.
    """
    try:
        outcome = agent.invoke(state)
        # Built inside the try so a failure while constructing the message is
        # reported the same way as a failure of the agent itself.
        reply = HumanMessage(content=outcome["output"], name=name)
    except Exception as exc:
        print(f"An error occurred in agent_node: {exc}")
        reply = HumanMessage(content=f"Error: {exc}", name=name)

    return {"messages": [reply]}
|
|
|
|
|
@tool
def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format"]:
    """Returns a list of legal moves in UCI format.
    The input should always be an empty string,
    and this function will always return legal moves in UCI format."""
    # NOTE: the docstring above is presumably exposed by @tool as the tool
    # description shown to the model, so its wording is left untouched.
    #
    # Side effect: the joined move string is also stored in the module-level
    # ``legal_moves`` variable, which should_continue reads.
    global legal_moves

    try:
        legal_moves = ",".join(map(str, board.legal_moves))
    except Exception as exc:
        print(f"An error occurred in get_legal_moves: {exc}")
        return "Error: unable to get legal moves"

    return legal_moves
|
|
|
|
|
@tool
def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "Result of the move."]:
    """Makes a move.
    The input should always be a move in UCI format,
    and this function will always return the result of the move."""
    # NOTE: the docstring above is presumably exposed by @tool as the tool
    # description shown to the model, so its wording is left untouched.
    #
    # Side effects on success: pushes the move onto the module-level ``board``,
    # increments ``move_num`` and appends a rendered SVG to ``board_svgs``.
    # Any failure is printed and reported as an "Error: ..." string.
    global move_num

    try:
        parsed = chess.Move.from_uci(move)
        if not parsed:
            # "0000" parses to the null move, which push_uci accepts without a
            # legality check; reject it so a player cannot silently pass.
            raise ValueError(f"null move is not allowed: {move}")

        # push_uci re-parses the move and raises if it is illegal here.
        board.push_uci(parsed.uci())

        move_num += 1
        print(f"move_num: {move_num}")

        board_svgs.append(chess.svg.board(
            board,
            arrows=[(parsed.from_square, parsed.to_square)],
            fill={parsed.from_square: "gray"},
            size=600,
        ))

        # After the push the moved piece sits on the destination square.
        piece = board.piece_at(parsed.to_square)
        piece_symbol = piece.unicode_symbol()

        # Capitalize the name for white pieces. The colour is read from the
        # piece itself: the Unicode chess glyphs are not cased characters, so
        # calling isupper() on them is always False.
        piece_name = chess.piece_name(piece.piece_type)
        if piece.color == chess.WHITE:
            piece_name = piece_name.capitalize()

        return (
            f"Moved {piece_name} ({piece_symbol}) from "
            f"{chess.SQUARE_NAMES[parsed.from_square]} to "
            f"{chess.SQUARE_NAMES[parsed.to_square]}."
        )
    except Exception as e:
        print(f"An error occurred in make_move: {e}")
        return f"Error: unable to make move {move}"
|
|
|
|
|
def create_graph(llm_board, llm_white, llm_black):
    """Assemble and compile the chess state graph.

    The graph has a routing node ("chess_board_proxy") that picks the next
    player through a forced "route" function call, and one agent node per
    player. After each player node, should_continue decides whether control
    goes back to the router or the graph ends.

    Args:
        llm_board: Model name for the routing (board proxy) LLM.
        llm_white: Model name for the white player's LLM.
        llm_black: Model name for the black player's LLM.

    Returns:
        The compiled graph.
    """
    players = ["player_white", "player_black"]
    options = players

    llm_board_proxy = ChatOpenAI(model=llm_board)
    llm_player_white = ChatOpenAI(model=llm_white)
    llm_player_black = ChatOpenAI(model=llm_black)

    # --- Router: prompt + forced function call + JSON parser -------------
    system_prompt = (
        "You are a Chess Board Proxy tasked with managing a game of chess "
        "between player_white and player_black. player_white makes the first move, "
        "then the players take turns."
    )
    routing_instruction = (
        "If player_white made a move, player_black must make the next move. "
        "If player_black made a move, player_white must make the next move. "
        "Select one of: {options}."
    )

    next_schema = {"title": "Next", "anyOf": [{"enum": options}]}
    function_def = {
        "name": "route",
        "description": "Select the next player.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {"next": next_schema},
            "required": ["next"],
        },
    }

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            ("system", routing_instruction),
        ]
    ).partial(options=str(options), members=", ".join(players), verbose=True)

    router_llm = llm_board_proxy.bind_functions(functions=[function_def], function_call="route")
    supervisor_chain = prompt | router_llm | JsonOutputFunctionsParser()

    # --- Players: identical wiring, differing only in colour and model ---
    player_llms = {"white": llm_player_white, "black": llm_player_black}
    player_nodes = {}
    for color, player_llm in player_llms.items():
        node_name = f"player_{color}"
        player_agent = create_agent(
            player_llm,
            [get_legal_moves, make_move],
            system_prompt=(
                f"You are a chess Grandmaster and you play as {color}. "
                "First call get_legal_moves() to get a list of legal moves. "
                "Then study the returned moves and call make_move(move) to make the best move. "
                "Finally analyze the move: **Analysis:** move in UCI format, emoji of piece, unordered list of 3 items."
            ),
        )
        player_nodes[node_name] = functools.partial(agent_node, agent=player_agent, name=node_name)

    # --- Graph wiring -----------------------------------------------------
    graph = StateGraph(AgentState)

    graph.add_node("chess_board_proxy", supervisor_chain)
    for node_name, node in player_nodes.items():
        graph.add_node(node_name, node)

    # The router's "next" value selects the player node of the same name.
    conditional_map = {player: player for player in players}
    graph.add_conditional_edges("chess_board_proxy", lambda x: x["next"], conditional_map)

    # Each player either hands control back to the router or ends the game.
    for node_name in player_nodes:
        graph.add_conditional_edges(
            node_name,
            should_continue,
            {"chess_board_proxy": "chess_board_proxy", END: END},
        )

    graph.set_entry_point("chess_board_proxy")

    return graph.compile()
|
|
|
|
|
def should_continue(state):
    """Decide whether the game goes on after a player node has run.

    Args:
        state: Current graph state (not inspected; the decision is based on
            the module-level game state).

    Returns:
        ``END`` when the requested number of moves has been reached, the board
        reports the game as over, or the last retrieved legal-move list is
        empty; otherwise ``"chess_board_proxy"`` to route back to the router.
    """
    # ">=" rather than "==": the counter can overshoot the target (e.g. an
    # agent calling make_move more than once in a turn, or a target <= 0), and
    # an equality test would then never fire.
    if move_num >= num_moves:
        return END

    # ``legal_moves`` is only refreshed when get_legal_moves is called, so it
    # is stale right after a mating move; ask the board directly as well.
    if board is not None and board.is_game_over():
        return END

    if not legal_moves:
        return END

    return "chess_board_proxy"
|
|
|
|
|
def initialize():
    """Reset the module-level game state for a fresh game.

    Creates a new chess.Board, empties the SVG list, zeroes both move
    counters and clears the cached legal-move string.
    """
    global board, board_svgs, num_moves, move_num, legal_moves

    board, board_svgs = chess.Board(), []
    num_moves = move_num = 0
    legal_moves = ""
|
|
|
|
|
def run_multi_agent(llm_board, llm_white, llm_black, moves_num):
    """Play a game between two LLM agents and return a Markdown transcript.

    Args:
        llm_board: Model name for the routing (board proxy) LLM.
        llm_white: Model name for the white player's LLM.
        llm_black: Model name for the black player's LLM.
        moves_num: Number of moves to play; stored in the module-level
            ``num_moves`` that should_continue checks.

    Returns:
        A Markdown string with one section per message after the opening
        prompt: a bold "Player ..., Move N" header, the message content and
        the board SVG rendered for that move (empty if none was rendered).
        Returns an empty string if the graph run failed.
    """
    global num_moves

    initialize()
    num_moves = moves_num

    graph = create_graph(llm_board, llm_white, llm_black)

    result = {}
    try:
        result = graph.invoke(
            {
                "messages": [
                    HumanMessage(content="Let's play chess, player_white starts.")
                ]
            },
            config={"recursion_limit": 100},
        )
    except Exception as e:
        print(f"An error occurred: {e}")

    sections = []

    if "messages" in result:
        for index, message in enumerate(result["messages"]):
            # Index 0 is the opening prompt, not a move.
            if index > 0:
                # Odd positions are labelled white, even positions black.
                player = "Player Black" if index % 2 == 0 else "Player White"

                # A turn can yield a message without a rendered board (e.g.
                # the agent errored before make_move succeeded); fall back to
                # an empty string instead of raising IndexError.
                svg_index = index - 1
                svg = board_svgs[svg_index] if svg_index < len(board_svgs) else ""

                sections.append(
                    f"**{player}, Move {index}**\n{message.content}\n{svg}\n\n"
                )

            # Existing cap, preserved as-is: for an even move target, stop
            # once the message at position ``num_moves`` has been handled.
            # NOTE(review): there is no cap for odd targets -- confirm intent.
            if num_moves % 2 == 0 and index == num_moves:
                break

    return "".join(sections)