Syndria98 committed on
Commit
52365e5
·
verified ·
1 Parent(s): c2cac70

Upload run.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run.py +384 -0
run.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from game import UltimateTicTacToe
8
+ from mcts import MCTS
9
+ from model import UltimateTicTacToeModel
10
+ from trainer import Trainer
11
+
12
+
13
# Baseline hyperparameters shared by every subcommand; individual CLI
# flags override these entries per invocation.
DEFAULT_ARGS = {
    # Self-play / MCTS search
    "num_simulations": 100,
    "numIters": 50,
    "numEps": 20,
    # Optimization
    "epochs": 5,
    "batch_size": 64,
    "lr": 5e-4,
    "weight_decay": 1e-4,
    "replay_buffer_size": 50000,
    "value_loss_weight": 1.0,
    "grad_clip_norm": 5.0,
    "checkpoint_path": "latest.pth",
    # Exploration controls for self-play
    "temperature_threshold": 10,
    "root_dirichlet_alpha": 0.3,
    "root_exploration_fraction": 0.25,
    # Candidate-model acceptance arena
    "arena_compare_games": 6,
    "arena_accept_threshold": 0.55,
    "arena_compare_simulations": 8,
}
32
+
33
+
34
def get_device(device_arg):
    """Resolve the torch device string, preferring an explicit user choice."""
    if device_arg:
        return device_arg
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
38
+
39
+
40
def build_model(game, device):
    """Construct a fresh UltimateTicTacToeModel sized for *game* on *device*."""
    board_size = game.get_board_size()
    action_size = game.get_action_size()
    return UltimateTicTacToeModel(board_size, action_size, device)
46
+
47
+
48
def load_checkpoint(model, checkpoint_path, device, optimizer=None, required=True):
    """Load model (and optionally optimizer) state from *checkpoint_path*.

    Returns True when the checkpoint was loaded. When the file is missing,
    raises FileNotFoundError if *required*, otherwise returns False.
    Leaves the model in eval mode after a successful load.
    """
    checkpoint = Path(checkpoint_path)
    if not checkpoint.exists():
        if not required:
            return False
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")

    state = torch.load(checkpoint, map_location=device)
    model.load_state_dict(state["state_dict"])
    # Optimizer state is optional in older checkpoints; restore when present.
    if optimizer is not None and "optimizer_state_dict" in state:
        optimizer.load_state_dict(state["optimizer_state_dict"])
    model.eval()
    return True
61
+
62
+
63
def canonical_state(game, state, player):
    """Return *state* with its board flipped to *player*'s perspective."""
    board_data, active_board = state
    canonical_board = game.get_canonical_board_data(board_data, player)
    return (canonical_board, active_board)
66
+
67
+
68
def apply_moves(game, moves):
    """Replay *moves* from the initial position.

    Returns (state, side_to_move). Raises ValueError as soon as a move in
    the sequence is rejected by the game's move verification.
    """
    state, player = game.get_init_board(), 1
    for action in moves:
        outcome = game.get_next_state(state, player, action, verify_move=True)
        if outcome is False:
            raise ValueError(f"Illegal move in sequence: {action}")
        state, player = outcome
    return state, player
77
+
78
+
79
def format_board(board_data):
    """Render an 81-cell board as a 9x9 ASCII grid with 3x3 band separators.

    *board_data* is a flat, row-major, length-81 sequence whose cells are
    1 (X), -1 (O) or 0 (empty).
    """
    symbols = {1: "X", -1: "O", 0: "."}
    rows = []
    for row in range(9):
        cells = [symbols[int(board_data[row * 9 + col])] for col in range(9)]
        groups = [" ".join(cells[idx:idx + 3]) for idx in (0, 3, 6)]
        rows.append(" | ".join(groups))
        # Fix: each rendered row is 21 characters wide (3 groups of 5 chars
        # joined by " | "), so the divider must be 21 dashes, not 23.
        if row in (2, 5):
            rows.append("-" * 21)
    return "\n".join(rows)
89
+
90
+
91
def top_policy_moves(policy, limit):
    """Return the *limit* highest-probability (action, prob) pairs, best first."""
    order = np.argsort(policy)
    best = order[::-1][:limit]
    return [(int(move), float(policy[move])) for move in best]
94
+
95
+
96
def parse_moves(text):
    """Parse a comma-separated move string into a list of ints; empty -> []."""
    if not text:
        return []
    tokens = (piece.strip() for piece in text.split(","))
    return [int(token) for token in tokens if token]
100
+
101
+
102
def parse_action(text):
    """Parse user input as a flat action index or a 'row col' pair.

    Accepts one integer in [0, 80], or two integers (row, col) each in
    [0, 8], separated by spaces and/or commas. Raises ValueError for any
    other shape or out-of-range value.
    """
    tokens = text.strip().replace(",", " ").split()
    if len(tokens) == 1:
        action = int(tokens[0])
    elif len(tokens) == 2:
        row = int(tokens[0])
        col = int(tokens[1])
        if not (0 <= row < 9) or not (0 <= col < 9):
            raise ValueError("Row and column must be in [0, 8].")
        action = row * 9 + col
    else:
        raise ValueError("Enter either a flat move index or 'row col'.")
    if action < 0 or action >= 81:
        raise ValueError("Move index must be in [0, 80].")
    return action
116
+
117
+
118
def scalar_value(value):
    """Collapse a scalar-like array/tensor/number into a plain Python float."""
    flat = np.asarray(value).reshape(-1)
    return float(flat[0])
120
+
121
+
122
def train_command(args):
    """CLI entry: configure a Trainer from parsed args and run self-play training."""
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)

    # Start from the shared defaults, then apply every CLI override.
    train_args = dict(DEFAULT_ARGS)
    overrides = {
        "num_simulations": args.num_simulations,
        "numIters": args.num_iters,
        "numEps": args.num_eps,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "weight_decay": args.weight_decay,
        "replay_buffer_size": args.replay_buffer_size,
        "value_loss_weight": args.value_loss_weight,
        "grad_clip_norm": args.grad_clip_norm,
        "checkpoint_path": args.checkpoint,
        "temperature_threshold": args.temperature_threshold,
        "root_dirichlet_alpha": args.root_dirichlet_alpha,
        "root_exploration_fraction": args.root_exploration_fraction,
        "arena_compare_games": args.arena_compare_games,
        "arena_accept_threshold": args.arena_accept_threshold,
        "arena_compare_simulations": args.arena_compare_simulations,
    }
    train_args.update(overrides)

    trainer = Trainer(game, model, train_args)
    if args.resume:
        # Restore both model weights and optimizer state before continuing.
        load_checkpoint(model, args.checkpoint, device, optimizer=trainer.optimizer)
    trainer.learn()
154
+
155
+
156
def eval_command(args):
    """CLI entry: print the model's policy/value (and optional MCTS pick) for a position."""
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)

    state, player = apply_moves(game, parse_moves(args.moves))
    current_state = canonical_state(game, state, player)
    policy, value = model.predict(game.encode_state(current_state))

    # Mask out illegal actions, then renormalize if any mass remains.
    legal_mask = np.array(game.get_valid_moves(state), dtype=np.float32)
    policy = policy * legal_mask
    total = policy.sum()
    if total > 0:
        policy = policy / total

    print("Board:")
    print(format_board(state[0]))
    print()
    print(f"Side to move: {'X' if player == 1 else 'O'}")
    print(f"Active small board: {state[1]}")
    print(f"Model value: {scalar_value(value):.4f}")
    print("Top policy moves:")
    for action, prob in top_policy_moves(policy, args.top_k):
        print(f" {action:2d} -> {prob:.4f}")

    if args.with_mcts:
        # Disable root exploration noise for deterministic evaluation.
        mcts_args = dict(DEFAULT_ARGS)
        mcts_args["num_simulations"] = args.num_simulations
        mcts_args["root_dirichlet_alpha"] = None
        mcts_args["root_exploration_fraction"] = None
        root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1)
        action = root.select_action(temperature=0)
        print(f"MCTS best move: {action}")
194
+
195
+
196
def ai_action(game, model, state, player, num_simulations):
    """Pick the greedy MCTS move for *player* from the canonical position."""
    current_state = canonical_state(game, state, player)
    # Pure-exploitation search: no root Dirichlet noise during play.
    mcts_args = dict(DEFAULT_ARGS)
    mcts_args["num_simulations"] = num_simulations
    mcts_args["root_dirichlet_alpha"] = None
    mcts_args["root_exploration_fraction"] = None
    search = MCTS(game, model, mcts_args)
    root = search.run(model, current_state, to_play=1)
    return root.select_action(temperature=0)
208
+
209
+
210
def random_action(game, state):
    """Pick a uniformly random legal action for *state*.

    Raises ValueError when the position has no legal moves.
    """
    valid = game.get_valid_moves(state)
    legal_actions = [action for action, allowed in enumerate(valid) if allowed]
    if not legal_actions:
        raise ValueError("No legal actions available.")
    return int(np.random.choice(legal_actions))
215
+
216
+
217
def load_player_model(game, checkpoint, device):
    """Build a model for *game* and load its weights from *checkpoint*."""
    player_model = build_model(game, device)
    load_checkpoint(player_model, checkpoint, device)
    return player_model
221
+
222
+
223
def choose_action(game, player_kind, model, state, player, num_simulations):
    """Dispatch to a random or MCTS-backed move depending on *player_kind*."""
    if player_kind != "random":
        return ai_action(game, model, state, player, num_simulations)
    return random_action(game, state)
227
+
228
+
229
def play_match(game, x_kind, x_model, o_kind, o_model, num_simulations):
    """Play one full game; return 1 (X wins), -1 (O wins) or 0 (draw)."""
    state = game.get_init_board()
    player = 1

    while True:
        reward = game.get_reward_for_player(state, player)
        if reward is not None:
            if reward == 0:
                return 0
            # reward is from the current mover's perspective; map back to X/O.
            return player if reward == 1 else -player

        if player == 1:
            kind, model = x_kind, x_model
        else:
            kind, model = o_kind, o_model
        action = choose_action(game, kind, model, state, player, num_simulations)
        state, player = game.get_next_state(state, player, action)
245
+
246
+
247
def arena_command(args):
    """CLI entry: pit two agents against each other and report the tally."""
    device = get_device(args.device)
    game = UltimateTicTacToe()

    # Only checkpoint-backed players need a model; random players pass None.
    x_model = None
    if args.x_player == "checkpoint":
        x_model = load_player_model(game, args.x_checkpoint, device)
    o_model = None
    if args.o_player == "checkpoint":
        o_model = load_player_model(game, args.o_checkpoint, device)

    results = {1: 0, -1: 0, 0: 0}
    for _ in range(args.games):
        winner = play_match(
            game,
            args.x_player,
            x_model,
            args.o_player,
            o_model,
            args.num_simulations,
        )
        results[winner] += 1

    print(f"Games: {args.games}")
    print(f"X ({args.x_player}) wins: {results[1]}")
    print(f"O ({args.o_player}) wins: {results[-1]}")
    print(f"Draws: {results[0]}")
274
+
275
+
276
def play_command(args):
    """CLI entry: interactive game between a human and the checkpoint model."""
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)

    state = game.get_init_board()
    player = 1
    human_player = args.human_player

    while True:
        print()
        print(format_board(state[0]))
        print(f"Turn: {'X' if player == 1 else 'O'}")
        print(f"Active small board: {state[1]}")

        reward = game.get_reward_for_player(state, player)
        if reward is not None:
            if reward == 0:
                print("Result: draw")
            else:
                winner = player if reward == 1 else -player
                print(f"Winner: {'X' if winner == 1 else 'O'}")
            return

        valid_moves = game.get_valid_moves(state)
        legal_actions = [idx for idx, allowed in enumerate(valid_moves) if allowed]
        print(f"Legal moves: {legal_actions}")

        if player != human_player:
            action = ai_action(game, model, state, player, args.num_simulations)
            print(f"AI move: {action}")
            state, player = game.get_next_state(state, player, action)
            continue

        # Keep prompting until the human enters a legal, well-formed move.
        while True:
            try:
                action = parse_action(input("Your move (index or 'row col'): "))
                outcome = game.get_next_state(state, player, action, verify_move=True)
                if outcome is False:
                    raise ValueError(f"Illegal move: {action}")
                state, player = outcome
                break
            except ValueError as exc:
                print(exc)
320
+
321
+
322
def build_parser():
    """Build the CLI parser with train / eval / play / arena subcommands."""
    parser = argparse.ArgumentParser(description="Ultimate Tic-Tac-Toe Runner")
    subparsers = parser.add_subparsers(dest="command", required=True)

    train = subparsers.add_parser("train", help="Train the model with self-play")
    train.add_argument("--device")
    train.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    train.add_argument("--resume", action="store_true")
    # Numeric hyperparameters, each defaulting to its DEFAULT_ARGS entry.
    for flag, value_type, key in (
        ("--num-simulations", int, "num_simulations"),
        ("--num-iters", int, "numIters"),
        ("--num-eps", int, "numEps"),
        ("--epochs", int, "epochs"),
        ("--batch-size", int, "batch_size"),
        ("--lr", float, "lr"),
        ("--weight-decay", float, "weight_decay"),
        ("--replay-buffer-size", int, "replay_buffer_size"),
        ("--value-loss-weight", float, "value_loss_weight"),
        ("--grad-clip-norm", float, "grad_clip_norm"),
        ("--temperature-threshold", int, "temperature_threshold"),
        ("--root-dirichlet-alpha", float, "root_dirichlet_alpha"),
        ("--root-exploration-fraction", float, "root_exploration_fraction"),
        ("--arena-compare-games", int, "arena_compare_games"),
        ("--arena-accept-threshold", float, "arena_accept_threshold"),
        ("--arena-compare-simulations", int, "arena_compare_simulations"),
    ):
        train.add_argument(flag, type=value_type, default=DEFAULT_ARGS[key])
    train.set_defaults(func=train_command)

    evaluate = subparsers.add_parser("eval", help="Inspect a checkpoint on a position")
    evaluate.add_argument("--device")
    evaluate.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    evaluate.add_argument("--moves", default="", help="Comma-separated move sequence")
    evaluate.add_argument("--top-k", type=int, default=10)
    evaluate.add_argument("--with-mcts", action="store_true")
    evaluate.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    evaluate.set_defaults(func=eval_command)

    play = subparsers.add_parser("play", help="Play against the checkpoint")
    play.add_argument("--device")
    play.add_argument("--checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    play.add_argument("--human-player", type=int, choices=[1, -1], default=1)
    play.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    play.set_defaults(func=play_command)

    arena = subparsers.add_parser("arena", help="Run repeated matches between agents")
    arena.add_argument("--device")
    arena.add_argument("--games", type=int, default=20)
    arena.add_argument("--num-simulations", type=int, default=DEFAULT_ARGS["num_simulations"])
    arena.add_argument("--x-player", choices=["checkpoint", "random"], default="checkpoint")
    arena.add_argument("--o-player", choices=["checkpoint", "random"], default="random")
    arena.add_argument("--x-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    arena.add_argument("--o-checkpoint", default=DEFAULT_ARGS["checkpoint_path"])
    arena.set_defaults(func=arena_command)

    return parser
375
+
376
+
377
def main():
    """Parse CLI arguments and dispatch to the selected subcommand handler."""
    cli_args = build_parser().parse_args()
    cli_args.func(cli_args)


if __name__ == "__main__":
    main()