trioskosmos committed on
Commit
ae55ffb
·
verified ·
1 Parent(s): 77b2fc5

Upload ai/utils/battle_benchmark.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/utils/battle_benchmark.py +177 -0
ai/utils/battle_benchmark.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import engine_rust
5
+
6
+
7
def _clone_game(game, db):
    """Return a scratch copy of ``game`` for look-ahead simulation.

    Prefers the state's own ``copy()`` when the binding exposes one.
    Otherwise a fresh ``PyGameState`` is built and the turn bookkeeping and
    both player objects are copied over by hand.

    NOTE(review): the manual fallback copies only the fields listed below;
    confirm that this is the complete game state.
    """
    if hasattr(game, "copy"):
        return game.copy()

    clone = engine_rust.PyGameState(db)
    clone.current_player = game.current_player
    clone.first_player = game.first_player
    clone.phase = game.phase
    clone.turn = game.turn
    clone.set_player(0, game.get_player(0))
    clone.set_player(1, game.get_player(1))
    return clone


def _trace_best_path(game, db, stats, depth_limit=3, lookahead_sims=100):
    """Print the predicted line of play for the next ``depth_limit`` steps.

    The first step uses the top entry of ``stats`` (already sorted by visits,
    descending); later steps run a short ``lookahead_sims``-simulation MCTS
    search on a scratch copy, so the real ``game`` is never advanced here.
    """
    print(" Predicted Best Path (Simulated):")
    try:
        sim = _clone_game(game, db)
        path = []
        for depth in range(depth_limit):
            player = sim.current_player
            phase = sim.phase
            turn = sim.turn
            if depth == 0:
                best_act = stats[0][0]
            else:
                lookahead = sim.search_mcts(lookahead_sims, 0.0, "original")
                if not lookahead:
                    break
                # max() returns the first maximal row, which matches taking
                # element 0 after a stable descending sort on visits.
                best_act = max(lookahead, key=lambda row: row[2])[0]

            path.append(f"[T{turn} P{player} Ph{phase} Act{best_act}]")
            sim.step(best_act)
            if sim.is_terminal():
                break
        print(f" {' -> '.join(path)}")
    except Exception as e:
        # The trace is diagnostic output only; a failure here is reported
        # and must not abort the game that is being benchmarked.
        print(f" (Path trace failed: {e})")


def _play_game(db, names, agents, decks, lives, max_steps=1000):
    """Play one game with verbose per-step logging and return the winner.

    Args:
        db: The ``engine_rust.PyCardDatabase`` used to build the game state.
        names: ``(p0_name, p1_name)`` display names.
        agents: ``(p0_agent, p1_agent)``; each is either a
            ``engine_rust.PyHybridMCTS`` instance or a value handed straight
            to ``search_mcts`` (strings in this script).
        decks: ``(deck0, deck1)`` card-id lists.
        lives: ``(lives0, lives1)`` card-id lists.
        max_steps: Hard cap on steps so a non-terminating game still ends.

    Returns:
        Whatever ``game.get_winner()`` reports once the loop stops.
    """
    game = engine_rust.PyGameState(db)
    game.initialize_game(decks[0], decks[1], [], [], lives[0], lives[1])

    step_count = 0
    while not game.is_terminal() and step_count < max_steps:
        curr_p = game.current_player
        phase = game.phase
        turn = game.turn
        side = 0 if curr_p == 0 else 1
        agent_name = names[side]
        agent = agents[side]

        print(f"\n{'=' * 40}")
        print(f"Step {step_count} | Turn {turn} | Player {curr_p} ({agent_name}) | Phase {phase}")

        # Show legal actions
        legal_ids = game.get_legal_action_ids()
        print(f" Legal Actions ({len(legal_ids)}): {legal_ids}")

        # presumably -1 and 0 are the mulligan phases -- TODO confirm
        if phase in [-1, 0]:
            sel = game.get_player(curr_p).mulligan_selection
            print(f" Mulligan Selection Mask: {sel:016b}")

        # NOTE(review): the arguments look like (simulation count, time
        # budget, ...) -- confirm against the engine_rust bindings.
        if isinstance(agent, engine_rust.PyHybridMCTS):
            stats = agent.get_suggestions(game, 0, 0.5)
        else:
            stats = game.search_mcts(0, 0.5, agent)

        # Rows unpack as (action, score, visits); rank by visits, descending.
        stats.sort(key=lambda row: row[2], reverse=True)
        print(" Top Actions (MCTS Visits/Score):")
        for act, score, visits in stats[:5]:
            print(f" - Action {act:<4}: Score {score:.4f}, Visits {visits}")

        # Trace the predicted path only every 10th step to limit the cost.
        if stats and step_count % 10 == 0:
            _trace_best_path(game, db, stats)

        # Fall back to action 0 when the search returned nothing.
        action = stats[0][0] if stats else 0
        game.step(action)
        step_count += 1

        # Periodic progress summary
        if step_count % 10 == 0:
            p0 = game.get_player(0)
            p1 = game.get_player(1)
            print(
                f"--- P0: Lives {len(p0.success_lives)}, Hand {len(p0.hand)} | P1: Lives {len(p1.success_lives)}, Hand {len(p1.hand)}"
            )

    return game.get_winner()


def run_battle():
    """Run a verbose round-robin benchmark between the configured agents.

    Loads the compiled card database from ``data/cards_compiled.json``
    (relative to the current working directory), plays one game for every
    unordered pair of contestants with identical decks, prints a detailed
    per-step log plus one result row per game, and finishes with a
    win-rate scoreboard.

    Returns:
        None. If the database file is missing, an error line is printed and
        the function returns without playing any game.
    """
    # Load DB
    db_path = "data/cards_compiled.json"
    if not os.path.exists(db_path):
        print(f"Error: {db_path} not found")
        return

    with open(db_path, "r", encoding="utf-8") as f:
        db_json = f.read()
    db = engine_rust.PyCardDatabase(db_json)

    # Display name -> agent handed to the search.
    models = {
        "Original (New)": "original",
        "Simple (Old)": "simple",
    }

    # Battle setup
    contestants = list(models)
    results = {name: {"wins": 0, "games": 0} for name in contestants}

    # Deck setup (standard starter): cards 101..120 followed by the lives.
    lives0 = [30000, 30001, 30002]
    deck0 = list(range(101, 121)) + lives0
    deck1 = deck0.copy()
    lives1 = lives0.copy()

    print(f"{'Battle':<20} | {'Winner':<20} | {'Duration'}")
    print("-" * 60)

    # Round robin: every unordered pair plays once, earlier contestant as P0.
    for i, p0_name in enumerate(contestants):
        for p1_name in contestants[i + 1:]:
            # perf_counter is monotonic, so the measured duration cannot be
            # skewed by system clock adjustments the way time.time() can.
            start = time.perf_counter()
            winner_idx = _play_game(
                db,
                (p0_name, p1_name),
                (models[p0_name], models[p1_name]),
                (deck0, deck1),
                (lives0, lives1),
            )

            # Any winner index other than 0 or 1 is reported as a draw.
            winner_name = "Draw"
            if winner_idx == 0:
                winner_name = p0_name
                results[p0_name]["wins"] += 1
            elif winner_idx == 1:
                winner_name = p1_name
                results[p1_name]["wins"] += 1

            results[p0_name]["games"] += 1
            results[p1_name]["games"] += 1

            elapsed = time.perf_counter() - start
            print(f"{p0_name} vs {p1_name:<10} | {winner_name:<20} | {elapsed:.1f}s")

    print("\nFinal Scoreboard:")
    for name, record in results.items():
        wr = (record["wins"] / record["games"] * 100) if record["games"] > 0 else 0
        print(f"{name:<20}: {record['wins']}/{record['games']} ({wr:.1f}% Win Rate)")
174
+
175
+
176
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    run_battle()