bstraehle commited on
Commit
e623936
·
verified ·
1 Parent(s): a4a5257

Update multi_agent.py

Browse files
Files changed (1) hide show
  1. multi_agent.py +27 -4
multi_agent.py CHANGED
@@ -6,8 +6,10 @@ from langchain_openai import ChatOpenAI
6
  from langgraph.graph import StateGraph, END
7
  from typing import TypedDict, Annotated, List
8
 
 
 
 
9
  class AgentState(TypedDict):
10
- board: chess.Board
11
  player: str
12
  made_move: bool
13
  move_num: int
@@ -15,11 +17,11 @@ class AgentState(TypedDict):
15
  legal_moves: list[chess.Move]
16
  white_moves: Annotated[list[chess.Move], operator.add]
17
  black_moves: Annotated[list[chess.Move], operator.add]
18
- board_svgs: Annotated[list[str], operator.add]
19
 
20
  def create_graph():
 
21
  graph = StateGraph(AgentState)
22
-
23
  graph.add_node("chess_board", board_node)
24
  graph.add_node("player_white", white_node)
25
  graph.add_node("player_black", black_node)
@@ -44,6 +46,8 @@ def create_graph():
44
  return graph.compile()
45
 
46
  def should_continue_white(state):
 
 
47
  print("## move_num=" + str(state["move_num"]))
48
  print("## max_moves=" + str(state["max_moves"]))
49
  if state["move_num"] > state["max_moves"]:
@@ -53,6 +57,8 @@ def should_continue_white(state):
53
  return "player_white"
54
 
55
  def should_continue_black(state):
 
 
56
  print("## move_num=" + str(state["move_num"]))
57
  print("## max_moves=" + str(state["max_moves"]))
58
  if state["move_num"] > state["max_moves"]:
@@ -62,6 +68,8 @@ def should_continue_black(state):
62
  return "player_black"
63
 
64
  def board_node(state: AgentState):
 
 
65
  player = "player_black" if state["player"] == "player_white" else "player_white"
66
  print("## player=" + player)
67
  messages = [
@@ -76,6 +84,8 @@ def board_node(state: AgentState):
76
  }
77
 
78
  def white_node(state: AgentState):
 
 
79
  legal_moves = str(state["legal_moves"][-1])
80
  print("## legal_moves=" + legal_moves)
81
  messages = [
@@ -90,6 +100,8 @@ def white_node(state: AgentState):
90
  }
91
 
92
  def black_node(state: AgentState):
 
 
93
  legal_moves = str(state["legal_moves"][-1])
94
  print("## legal_moves=" + legal_moves)
95
  messages = [
@@ -109,6 +121,7 @@ def get_legal_moves() -> Annotated[str, "A list of legal moves in UCI format"]:
109
  The input should always be an empty string,
110
  and this function will always return legal moves in UCI format."""
111
  try:
 
112
  return "Possible moves are: " + ",".join(
113
  [str(move) for move in board.legal_moves]
114
  )
@@ -122,6 +135,7 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
122
  The input should always be a move in UCI format,
123
  and this function will always return the result of the move in UCI format."""
124
  try:
 
125
  move = chess.Move.from_uci(move)
126
  board.push_uci(str(move))
127
  piece = board.piece_at(move.to_square)
@@ -145,10 +159,19 @@ def make_move(move: Annotated[str, "A move in UCI format."]) -> Annotated[str, "
145
  except Exception as e:
146
  print(f"An error occurred in make_move: {e}")
147
  return f"Error: unable to make move {move}"
148
-
 
 
 
 
 
149
  def run_multi_agent(max_moves):
 
 
150
  graph = create_graph()
151
 
 
 
152
  for s in graph.stream({
153
  "board": chess.Board(),
154
  "move_num": 0,
 
6
  from langgraph.graph import StateGraph, END
7
  from typing import TypedDict, Annotated, List
8
 
9
+ board = None
10
+ board_svgs = None
11
+
12
  class AgentState(TypedDict):
 
13
  player: str
14
  made_move: bool
15
  move_num: int
 
17
  legal_moves: list[chess.Move]
18
  white_moves: Annotated[list[chess.Move], operator.add]
19
  black_moves: Annotated[list[chess.Move], operator.add]
 
20
 
21
  def create_graph():
22
+ print("#### create_graph")
23
  graph = StateGraph(AgentState)
24
+
25
  graph.add_node("chess_board", board_node)
26
  graph.add_node("player_white", white_node)
27
  graph.add_node("player_black", black_node)
 
46
  return graph.compile()
47
 
48
  def should_continue_white(state):
49
+ print("#### should_continue_white")
50
+ print("#### state=" + str(state))
51
  print("## move_num=" + str(state["move_num"]))
52
  print("## max_moves=" + str(state["max_moves"]))
53
  if state["move_num"] > state["max_moves"]:
 
57
  return "player_white"
58
 
59
  def should_continue_black(state):
60
+ print("#### should_continue_black")
61
+ print("#### state=" + str(state))
62
  print("## move_num=" + str(state["move_num"]))
63
  print("## max_moves=" + str(state["max_moves"]))
64
  if state["move_num"] > state["max_moves"]:
 
68
  return "player_black"
69
 
70
  def board_node(state: AgentState):
71
+ print("#### board_node")
72
+ print("#### state=" + str(state))
73
  player = "player_black" if state["player"] == "player_white" else "player_white"
74
  print("## player=" + player)
75
  messages = [
 
84
  }
85
 
86
  def white_node(state: AgentState):
87
+ print("#### white_node")
88
+ print("#### state=" + str(state))
89
  legal_moves = str(state["legal_moves"][-1])
90
  print("## legal_moves=" + legal_moves)
91
  messages = [
 
100
  }
101
 
102
  def black_node(state: AgentState):
103
+ print("#### black_node")
104
+ print("#### state=" + str(state))
105
  legal_moves = str(state["legal_moves"][-1])
106
  print("## legal_moves=" + legal_moves)
107
  messages = [
 
121
  The input should always be an empty string,
122
  and this function will always return legal moves in UCI format."""
123
  try:
124
+ print("#### get_legal_moves")
125
  return "Possible moves are: " + ",".join(
126
  [str(move) for move in board.legal_moves]
127
  )
 
135
  The input should always be a move in UCI format,
136
  and this function will always return the result of the move in UCI format."""
137
  try:
138
+ print("#### make_move")
139
  move = chess.Move.from_uci(move)
140
  board.push_uci(str(move))
141
  piece = board.piece_at(move.to_square)
 
159
  except Exception as e:
160
  print(f"An error occurred in make_move: {e}")
161
  return f"Error: unable to make move {move}"
162
+
163
+ def initialize():
164
+ global board, board_svgs
165
+ board = chess.Board()
166
+ board_svgs = []
167
+
168
  def run_multi_agent(max_moves):
169
+ initialize()
170
+
171
  graph = create_graph()
172
 
173
+ print("#### Start")
174
+
175
  for s in graph.stream({
176
  "board": chess.Board(),
177
  "move_num": 0,