trioskosmos committed on
Commit
e95dcfe
·
verified ·
1 Parent(s): d98decc

Upload ai/utils/mcts_battle.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/utils/mcts_battle.py +273 -0
ai/utils/mcts_battle.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import concurrent.futures
3
+ import json
4
+ import multiprocessing
5
+ import os
6
+ import random
7
+ import sys
8
+
9
+ from tqdm import tqdm
10
+
11
+ # Add project root to path
12
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
13
+
14
+ import engine_rust
15
+
16
+ from ai.utils.benchmark_decks import parse_deck
17
+
18
+
19
def run_single_battle(
    g_idx, sims0, sims1, time0, time1, h0, h1, m0, m1, hr0, hr1, db_content, decks, alternate=False, verbose=False
):
    """Play one MCTS-vs-MCTS game and return a summary of the outcome.

    "P0" and "P1" below always mean the *logical* players (the owners of the
    ``*0`` / ``*1`` settings). When ``alternate`` is set and ``g_idx`` is odd,
    logical P1 is seated as engine player 0; winner, scores and simulation
    statistics are all mapped back to logical players before being returned.

    Args:
        g_idx: Game index; used for seat alternation and the draw-log name.
        sims0, sims1: Simulation budget per search for P0 / P1 (used only
            when that player's time limit is not positive).
        time0, time1: Per-search time limit for P0 / P1; a positive value
            replaces the simulation budget (0 sims are requested).
        h0, h1: Heuristic names passed through to ``search_mcts``.
        m0, m1: Evaluation mode names ("blind" / "normal"); any other value
            falls back to ``EvalMode.Blind``.
        hr0, hr1: Search horizon names ("game" / "turn"); any other value
            falls back to ``SearchHorizon.GameEnd``.
        db_content: Card database text handed to ``PyCardDatabase``.
        decks: Non-empty sequence of ``(deck, lives, energy)`` tuples; one is
            drawn at random for each player.
        alternate: Swap the seats on odd ``g_idx``.
        verbose: Un-silence the engine and log every move via ``tqdm.write``.

    Returns:
        dict with keys ``game_id``, ``winner``, ``p0_score``, ``p1_score``,
        ``turns``, ``steps``, ``avg_sims_p0`` and ``avg_sims_p1``.
    """
    db = engine_rust.PyCardDatabase(db_content)
    game = engine_rust.PyGameState(db)
    game.silent = not verbose

    # Convert strings to Rust enums
    modes = {
        "blind": engine_rust.EvalMode.Blind,
        "normal": engine_rust.EvalMode.Normal,
    }
    horizons = {
        "game": engine_rust.SearchHorizon.GameEnd,
        "turn": engine_rust.SearchHorizon.TurnEnd,
    }

    # Per-player search settings, indexed by LOGICAL player (0 or 1).
    sims = [sims0, sims1]
    time_limits = [time0, time1]
    heuristics = [h0, h1]
    m_enums = [
        modes.get(m0, engine_rust.EvalMode.Blind),
        modes.get(m1, engine_rust.EvalMode.Blind),
    ]
    hr_enums = [
        horizons.get(hr0, engine_rust.SearchHorizon.GameEnd),
        horizons.get(hr1, engine_rust.SearchHorizon.GameEnd),
    ]

    # Select random decks
    p0_deck, p0_lives, p0_energy = random.choice(decks)
    p1_deck, p1_lives, p1_energy = random.choice(decks)

    # Alternate who goes first
    swapped = bool(alternate and g_idx % 2 == 1)
    if swapped:
        game.initialize_game(p1_deck, p0_deck, p1_energy, p0_energy, p1_lives, p0_lives)
    else:
        game.initialize_game(p0_deck, p1_deck, p0_energy, p1_energy, p0_lives, p1_lives)
    # seat_to_player[engine seat] -> logical player index.
    seat_to_player = (1, 0) if swapped else (0, 1)

    max_steps = 1000
    # -1: MulliganP1, 0: MulliganP2, 2: Energy, 4: Main, 5: LiveSet, 8: LiveResult
    decision_phases = (-1, 0, 2, 4, 5, 8)

    step = 0
    # Simulation counters are kept per LOGICAL player so that the reported
    # averages line up with the swapped-back scores and winner.
    sims_total = [0, 0]
    moves = [0, 0]

    while not game.is_terminal() and step < max_steps:
        cp = game.current_player
        phase = game.phase
        if phase in decision_phases:
            player = seat_to_player[cp]
            # Use time limit if provided, otherwise generic sims
            t_limit = time_limits[player]
            n_sims = sims[player] if t_limit <= 0.0 else 0
            suggestions = game.search_mcts(n_sims, t_limit, heuristics[player], hr_enums[player], m_enums[player])
            # Each suggestion is indexed as [0] = action and [2] = simulation
            # count; the first entry is taken as the move to play. Action 0 is
            # played when the search returns nothing.
            best_action = suggestions[0][0] if suggestions else 0
            try:
                # Track sims
                total_sims = sum(s[2] for s in suggestions)
                sims_total[player] += total_sims
                moves[player] += 1

                if verbose:
                    tqdm.write(f" P{cp} {phase}: Action {best_action} ({total_sims} sims)")

                game.step(best_action)
            except Exception as e:
                # Best effort: abandon this game and report whatever state it reached.
                if verbose:
                    tqdm.write(f" Step Error (Phase {phase}): {e}")
                break
        else:
            # Phases with no decision are advanced with action 0.
            try:
                game.step(0)
            except Exception as e:
                if verbose:
                    tqdm.write(f" Auto-Step Error (Phase {phase}): {e}")
                break
        step += 1

    if step >= max_steps and verbose:
        tqdm.write(f" Action Limit Reached ({max_steps} steps) for Game {g_idx}")

    winner = game.get_winner()
    if swapped:
        # Translate engine seats back to logical players.
        if winner == 0:
            actual_winner = 1
        elif winner == 1:
            actual_winner = 0
        else:
            actual_winner = 2
        p0_score = game.get_player(1).score
        p1_score = game.get_player(0).score
    else:
        # NOTE(review): the engine's winner value is passed through unchanged
        # here, whereas the swapped branch folds every non-0/1 value into 2.
        # Confirm what get_winner() returns for unfinished games.
        actual_winner = winner
        p0_score = game.get_player(0).score
        p1_score = game.get_player(1).score

    # Save rule log for draws
    if actual_winner == 2:
        log_dir = "ai/data/draw_logs"
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, f"draw_game_{g_idx}.txt")
        with open(log_path, "w", encoding="utf-8") as f_log:
            f_log.write("\n".join(game.rule_log))

    return {
        "game_id": g_idx,
        "winner": actual_winner,
        "p0_score": p0_score,
        "p1_score": p1_score,
        "turns": game.turn,
        "steps": step,
        "avg_sims_p0": sims_total[0] / moves[0] if moves[0] else 0,
        "avg_sims_p1": sims_total[1] / moves[1] if moves[1] else 0,
    }
141
+
142
+
143
def main():
    """Command-line entry point: run a batch of MCTS battles in parallel.

    Parses the CLI options, loads the compiled card database and the cup
    decks (falling back to a small built-in deck when none of the deck files
    exist), plays ``--num-games`` games in a process pool, prints each result
    as it completes, and writes a JSON summary to ``--output-file``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-games", type=int, default=10)
    parser.add_argument("--sims0", type=int, default=100)
    parser.add_argument("--sims1", type=int, default=100)
    parser.add_argument("--time0", type=float, default=0.0, help="Time limit per action for P0 (seconds)")
    parser.add_argument("--time1", type=float, default=0.0, help="Time limit per action for P1 (seconds)")
    parser.add_argument("--h0", type=str, default="original", choices=["original", "simple", "resnet", "hybrid"])
    parser.add_argument("--h1", type=str, default="original", choices=["original", "simple", "resnet", "hybrid"])
    parser.add_argument("--mode0", type=str, default="blind", choices=["blind", "normal"])
    parser.add_argument("--mode1", type=str, default="blind", choices=["blind", "normal"])
    parser.add_argument("--horizon0", type=str, default="game", choices=["game", "turn"])
    parser.add_argument("--horizon1", type=str, default="game", choices=["game", "turn"])
    parser.add_argument("--alternate", action="store_true")
    parser.add_argument("--verbose", action="store_true", help="Enable move-by-move logging")
    parser.add_argument("--workers", type=int, default=min(multiprocessing.cpu_count(), 4))
    parser.add_argument("--output-file", type=str, default="ai/data/mcts_battle_results.json")
    args = parser.parse_args()

    print("Loading data and decks...")
    # The raw text is kept as well: it is what gets shipped to worker processes.
    with open("engine/data/cards_compiled.json", "r", encoding="utf-8") as f:
        db_content = f.read()
    db_json = json.loads(db_content)

    deck_paths = [
        "ai/decks/aqours_cup.txt",
        "ai/decks/hasunosora_cup.txt",
        "ai/decks/liella_cup.txt",
        "ai/decks/muse_cup.txt",
        "ai/decks/nijigaku_cup.txt",
    ]
    # Missing deck files are skipped silently.
    decks = [
        parse_deck(dp, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {}))
        for dp in deck_paths
        if os.path.exists(dp)
    ]

    if not decks:
        # Built-in fallback so the script still runs without any deck file.
        p_deck = [124, 127, 130, 132] * 12
        p_lives = [30000, 30001, 30002]
        p_energy = [40000] * 10
        decks = [(p_deck, p_lives, p_energy)]

    print(
        f"Starting Battle: P0({args.h0}, {args.mode0}, {args.horizon0}) vs P1({args.h1}, {args.mode1}, {args.horizon1})"
    )

    results = []
    p0_wins, p1_wins, draws = 0, 0, 0

    def save_results(results_list, path):
        """Write the aggregate summary plus per-game records to ``path`` as JSON."""
        if not results_list:
            return
        n_games = len(results_list)
        summary = {
            "num_games": n_games,
            # "time" is recorded because a positive time limit overrides "sims".
            "p0": {"sims": args.sims0, "time": args.time0, "h": args.h0, "m": args.mode0, "hr": args.horizon0},
            "p1": {"sims": args.sims1, "time": args.time1, "h": args.h1, "m": args.mode1, "hr": args.horizon1},
            "p0_wins": sum(1 for r in results_list if r["winner"] == 0),
            "p1_wins": sum(1 for r in results_list if r["winner"] == 1),
            "draws": sum(1 for r in results_list if r["winner"] not in [0, 1]),
            "avg_p0_score": sum(r["p0_score"] for r in results_list) / n_games,
            "avg_p1_score": sum(r["p1_score"] for r in results_list) / n_games,
            "games": results_list,
        }
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f_out:
            json.dump(summary, f_out, indent=2)

    futures = []
    try:
        with concurrent.futures.ProcessPoolExecutor(max_workers=args.workers) as executor:
            try:
                for i in range(args.num_games):
                    futures.append(
                        executor.submit(
                            run_single_battle,
                            i,
                            args.sims0,
                            args.sims1,
                            args.time0,
                            args.time1,
                            args.h0,
                            args.h1,
                            args.mode0,
                            args.mode1,
                            args.horizon0,
                            args.horizon1,
                            db_content,
                            decks,
                            args.alternate,
                            args.verbose,
                        )
                    )

                for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Playing"):
                    try:
                        res = future.result()
                        results.append(res)

                        # Determine winner string
                        winner_str = "P0 Wins" if res["winner"] == 0 else ("P1 Wins" if res["winner"] == 1 else "Draw")

                        # Print result immediately
                        tqdm.write(
                            f"Game {res['game_id']:2}: {winner_str:8} | P0: {res['p0_score']:2} - P1: {res['p1_score']:2} | Turns: {res['turns']:2}"
                        )

                        if res["winner"] == 0:
                            p0_wins += 1
                        elif res["winner"] == 1:
                            p1_wins += 1
                        else:
                            draws += 1
                    except Exception as e:
                        # One failed game must not abort the whole batch.
                        tqdm.write(f"Game failed: {e}")
            except KeyboardInterrupt:
                # Drop games that have not started yet; otherwise leaving the
                # executor context would wait for every queued game to finish.
                for pending in futures:
                    pending.cancel()
                raise

    except KeyboardInterrupt:
        print("\nInterrupted!")

    # Partial results are still saved and summarised after an interrupt.
    if results:
        save_results(results, args.output_file)
        n_results = len(results)
        print(f"\nResults Summary: P0 Wins: {p0_wins}, P1 Wins: {p1_wins}, Draws: {draws}")
        print(
            f"Avg Score: P0={sum(r['p0_score'] for r in results) / n_results:.2f}, P1={sum(r['p1_score'] for r in results) / n_results:.2f}"
        )
        print(f"Avg Turns: {sum(r['turns'] for r in results) / n_results:.2f}")
        print(
            f"Avg Sims/Action: P0={sum(r.get('avg_sims_p0', 0) for r in results) / n_results:.0f}, P1={sum(r.get('avg_sims_p1', 0) for r in results) / n_results:.0f}"
        )
270
+
271
+
272
+ if __name__ == "__main__":
273
+ main()