Update multi_agent.py
Browse files- multi_agent.py +27 -4
multi_agent.py
CHANGED
|
@@ -6,8 +6,10 @@ from langchain_openai import ChatOpenAI
|
|
| 6 |
from langgraph.graph import StateGraph, END
|
| 7 |
from typing import TypedDict, Annotated, List
|
| 8 |
|
|
|
|
|
|
|
|
|
|
| 9 |
class AgentState(TypedDict):
|
| 10 |
-
board: chess.Board
|
| 11 |
player: str
|
| 12 |
made_move: bool
|
| 13 |
move_num: int
|
|
@@ -15,11 +17,11 @@ class AgentState(TypedDict):
|
|
| 15 |
legal_moves: list[chess.Move]
|
| 16 |
white_moves: Annotated[list[chess.Move], operator.add]
|
| 17 |
black_moves: Annotated[list[chess.Move], operator.add]
|
| 18 |
-
board_svgs: Annotated[list[str], operator.add]
|
| 19 |
|
| 20 |
def create_graph():
|
|
|
|
| 21 |
graph = StateGraph(AgentState)
|
| 22 |
-
|
| 23 |
graph.add_node("chess_board", board_node)
|
| 24 |
graph.add_node("player_white", white_node)
|
| 25 |
graph.add_node("player_black", black_node)
|
|
@@ -44,6 +46,8 @@ def create_graph():
|
|
| 44 |
return graph.compile()
|
| 45 |
|
| 46 |
def should_continue_white(state):
|
|
|
|
|
|
|
| 47 |
print("## move_num=" + str(state["move_num"]))
|
| 48 |
print("## max_moves=" + str(state["max_moves"]))
|
| 49 |
if state["move_num"] > state["max_moves"]:
|
|
@@ -53,6 +57,8 @@ def should_continue_white(state):
|
|
| 53 |
return "player_white"
|
| 54 |
|
| 55 |
def should_continue_black(state):
|
|
|
|
|
|
|
| 56 |
print("## move_num=" + str(state["move_num"]))
|
| 57 |
print("## max_moves=" + str(state["max_moves"]))
|
| 58 |
if state["move_num"] > state["max_moves"]:
|
|
@@ -62,6 +68,8 @@ def should_continue_black(state):
|
|
| 62 |
return "player_black"
|
| 63 |
|
| 64 |
def board_node(state: AgentState):
|
|
|
|
|
|
|
| 65 |
player = "player_black" if state["player"] == "player_white" else "player_white"
|
| 66 |
print("## player=" + player)
|
| 67 |
messages = [
|
|
@@ -76,6 +84,8 @@ def board_node(state: AgentState):
|
|
| 76 |
}
|
| 77 |
|
| 78 |
def white_node(state: AgentState):
|
|
|
|
|
|
|
| 79 |
legal_moves = str(state["legal_moves"][-1])
|
| 80 |
print("## legal_moves=" + legal_moves)
|
| 81 |
messages = [
|
|
@@ -90,6 +100,8 @@ def white_node(state: AgentState):
|
|
| 90 |
}
|
| 91 |
|
| 92 |
def black_node(state: AgentState):
|
|
|
|
|
|
|
| 93 |
legal_moves = str(state["legal_moves"][-1])
|
| 94 |
print("## legal_moves=" + legal_moves)
|
| 95 |
messages = [
|
|
@@ -109,6 +121,7 @@ def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format"]:
|
|
| 109 |
The input should always be an empty string,
|
| 110 |
and this function will always return legal moves in UCI format."""
|
| 111 |
try:
|
|
|
|
| 112 |
return "Possible moves are: " + ",".join(
|
| 113 |
[str(move) for move in board.legal_moves]
|
| 114 |
)
|
|
@@ -122,6 +135,7 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
|
|
| 122 |
The input should always be a move in UCI format,
|
| 123 |
and this function will always return the result of the move in UCI format."""
|
| 124 |
try:
|
|
|
|
| 125 |
move = chess.Move.from_uci(move)
|
| 126 |
board.push_uci(str(move))
|
| 127 |
piece = board.piece_at(move.to_square)
|
|
@@ -145,10 +159,19 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
|
|
| 145 |
except Exception as e:
|
| 146 |
print(f"An error occurred in make_move: {e}")
|
| 147 |
return f"Error: unable to make move {move}"
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
def run_multi_agent(max_moves):
|
|
|
|
|
|
|
| 150 |
graph = create_graph()
|
| 151 |
|
|
|
|
|
|
|
| 152 |
for s in graph.stream({
|
| 153 |
"board": chess.Board(),
|
| 154 |
"move_num": 0,
|
|
|
|
| 6 |
from langgraph.graph import StateGraph, END
|
| 7 |
from typing import TypedDict, Annotated, List
|
| 8 |
|
| 9 |
+
board = None
|
| 10 |
+
board_svgs = None
|
| 11 |
+
|
| 12 |
class AgentState(TypedDict):
|
|
|
|
| 13 |
player: str
|
| 14 |
made_move: bool
|
| 15 |
move_num: int
|
|
|
|
| 17 |
legal_moves: list[chess.Move]
|
| 18 |
white_moves: Annotated[list[chess.Move], operator.add]
|
| 19 |
black_moves: Annotated[list[chess.Move], operator.add]
|
|
|
|
| 20 |
|
| 21 |
def create_graph():
|
| 22 |
+
print("#### create_graph")
|
| 23 |
graph = StateGraph(AgentState)
|
| 24 |
+
|
| 25 |
graph.add_node("chess_board", board_node)
|
| 26 |
graph.add_node("player_white", white_node)
|
| 27 |
graph.add_node("player_black", black_node)
|
|
|
|
| 46 |
return graph.compile()
|
| 47 |
|
| 48 |
def should_continue_white(state):
|
| 49 |
+
print("#### should_continue_white")
|
| 50 |
+
print("#### state=" + str(state))
|
| 51 |
print("## move_num=" + str(state["move_num"]))
|
| 52 |
print("## max_moves=" + str(state["max_moves"]))
|
| 53 |
if state["move_num"] > state["max_moves"]:
|
|
|
|
| 57 |
return "player_white"
|
| 58 |
|
| 59 |
def should_continue_black(state):
|
| 60 |
+
print("#### should_continue_black")
|
| 61 |
+
print("#### state=" + str(state))
|
| 62 |
print("## move_num=" + str(state["move_num"]))
|
| 63 |
print("## max_moves=" + str(state["max_moves"]))
|
| 64 |
if state["move_num"] > state["max_moves"]:
|
|
|
|
| 68 |
return "player_black"
|
| 69 |
|
| 70 |
def board_node(state: AgentState):
|
| 71 |
+
print("#### board_node")
|
| 72 |
+
print("#### state=" + str(state))
|
| 73 |
player = "player_black" if state["player"] == "player_white" else "player_white"
|
| 74 |
print("## player=" + player)
|
| 75 |
messages = [
|
|
|
|
| 84 |
}
|
| 85 |
|
| 86 |
def white_node(state: AgentState):
|
| 87 |
+
print("#### white_node")
|
| 88 |
+
print("#### state=" + str(state))
|
| 89 |
legal_moves = str(state["legal_moves"][-1])
|
| 90 |
print("## legal_moves=" + legal_moves)
|
| 91 |
messages = [
|
|
|
|
| 100 |
}
|
| 101 |
|
| 102 |
def black_node(state: AgentState):
|
| 103 |
+
print("#### black_node")
|
| 104 |
+
print("#### state=" + str(state))
|
| 105 |
legal_moves = str(state["legal_moves"][-1])
|
| 106 |
print("## legal_moves=" + legal_moves)
|
| 107 |
messages = [
|
|
|
|
| 121 |
The input should always be an empty string,
|
| 122 |
and this function will always return legal moves in UCI format."""
|
| 123 |
try:
|
| 124 |
+
print("#### get_legal_moves")
|
| 125 |
return "Possible moves are: " + ",".join(
|
| 126 |
[str(move) for move in board.legal_moves]
|
| 127 |
)
|
|
|
|
| 135 |
The input should always be a move in UCI format,
|
| 136 |
and this function will always return the result of the move in UCI format."""
|
| 137 |
try:
|
| 138 |
+
print("#### make_move")
|
| 139 |
move = chess.Move.from_uci(move)
|
| 140 |
board.push_uci(str(move))
|
| 141 |
piece = board.piece_at(move.to_square)
|
|
|
|
| 159 |
except Exception as e:
|
| 160 |
print(f"An error occurred in make_move: {e}")
|
| 161 |
return f"Error: unable to make move {move}"
|
| 162 |
+
|
| 163 |
+
def initialize():
|
| 164 |
+
global board, board_svgs
|
| 165 |
+
board = chess.Board()
|
| 166 |
+
board_svgs = []
|
| 167 |
+
|
| 168 |
def run_multi_agent(max_moves):
|
| 169 |
+
initialize()
|
| 170 |
+
|
| 171 |
graph = create_graph()
|
| 172 |
|
| 173 |
+
print("#### Start")
|
| 174 |
+
|
| 175 |
for s in graph.stream({
|
| 176 |
"board": chess.Board(),
|
| 177 |
"move_num": 0,
|