trioskosmos commited on
Commit
ae8dcb4
·
verified ·
1 Parent(s): 2d88649

Upload ai/utils/profile_self_play.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/utils/profile_self_play.py +144 -0
ai/utils/profile_self_play.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Profile a single self-play game to identify bottlenecks."""
2
+
3
+ import json
4
+ import os
5
+ import sys
6
+ import time
7
+
8
+ import numpy as np
9
+
10
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
11
+
12
+ import engine_rust
13
+
14
+ from ai.models.training_config import POLICY_SIZE
15
+ from ai.utils.benchmark_decks import parse_deck
16
+
17
+
18
+ def profile_game(sims=100, neural_weight=0.3):
19
+ db_path = "engine/data/cards_compiled.json"
20
+ model_path = "ai/models/alphanet_best.onnx"
21
+
22
+ with open(db_path, "r", encoding="utf-8") as f:
23
+ db_content = f.read()
24
+ db_json = json.loads(db_content)
25
+
26
+ db = engine_rust.PyCardDatabase(db_content)
27
+ mcts = engine_rust.PyHybridMCTS(model_path, neural_weight)
28
+
29
+ deck_file = "ai/decks/liella_cup.txt"
30
+ main_deck, lives_deck, energy_deck = parse_deck(
31
+ deck_file, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {})
32
+ )
33
+ test_deck = (main_deck * 10)[:48]
34
+ test_lives = (lives_deck * 10)[:12]
35
+ test_energy = (energy_deck * 10)[:12]
36
+
37
+ game = engine_rust.PyGameState(db)
38
+ game.silent = True
39
+ game.initialize_game(test_deck, test_deck, test_energy, test_energy, test_lives, test_lives)
40
+
41
+ # Timing accumulators
42
+ times = {
43
+ "encode_state": 0.0,
44
+ "mcts_suggestions": 0.0,
45
+ "policy_build": 0.0,
46
+ "dirichlet_noise": 0.0,
47
+ "action_selection": 0.0,
48
+ "game_step": 0.0,
49
+ "other": 0.0,
50
+ }
51
+ counts = {"interactive": 0, "non_interactive": 0}
52
+
53
+ step = 0
54
+ t_game_start = time.perf_counter()
55
+
56
+ while not game.is_terminal() and step < 500:
57
+ phase = game.phase
58
+ is_interactive = phase in [-1, 0, 4, 5]
59
+
60
+ if is_interactive:
61
+ counts["interactive"] += 1
62
+
63
+ # 1. Encode State
64
+ t0 = time.perf_counter()
65
+ encoded = game.encode_state(db)
66
+ times["encode_state"] += time.perf_counter() - t0
67
+
68
+ # 2. MCTS Suggestions
69
+ t0 = time.perf_counter()
70
+ suggestions = mcts.get_suggestions(game, sims)
71
+ times["mcts_suggestions"] += time.perf_counter() - t0
72
+
73
+ # 3. Build Policy
74
+ t0 = time.perf_counter()
75
+ action_ids = []
76
+ visit_counts = []
77
+ total_visits = 0
78
+ for action, score, visits in suggestions:
79
+ if action < POLICY_SIZE:
80
+ action_ids.append(int(action))
81
+ visit_counts.append(visits)
82
+ total_visits += visits
83
+ if total_visits == 0:
84
+ legal = list(game.get_legal_action_ids())
85
+ action_ids = [int(a) for a in legal if a < POLICY_SIZE]
86
+ visit_counts = [1.0] * len(action_ids)
87
+ total_visits = len(action_ids)
88
+ probs = np.array(visit_counts, dtype=np.float32) / total_visits
89
+ times["policy_build"] += time.perf_counter() - t0
90
+
91
+ # 4. Dirichlet Noise
92
+ t0 = time.perf_counter()
93
+ noise = np.random.dirichlet([1.0] * len(probs))
94
+ probs = 0.5 * probs + 0.5 * noise
95
+ probs /= probs.sum()
96
+ times["dirichlet_noise"] += time.perf_counter() - t0
97
+
98
+ # 5. Action Selection
99
+ t0 = time.perf_counter()
100
+ if step < 60:
101
+ action = np.random.choice(action_ids, p=probs)
102
+ else:
103
+ action = action_ids[np.argmax(probs)]
104
+ times["action_selection"] += time.perf_counter() - t0
105
+
106
+ # 6. Game Step
107
+ t0 = time.perf_counter()
108
+ game.step(int(action))
109
+ times["game_step"] += time.perf_counter() - t0
110
+ else:
111
+ counts["non_interactive"] += 1
112
+ t0 = time.perf_counter()
113
+ game.step(0)
114
+ times["game_step"] += time.perf_counter() - t0
115
+
116
+ step += 1
117
+
118
+ t_game_total = time.perf_counter() - t_game_start
119
+
120
+ print(f"\n{'=' * 50}")
121
+ print(f"PROFILE RESULTS ({sims} sims, weight={neural_weight})")
122
+ print(f"{'=' * 50}")
123
+ print(f"Total Game Time: {t_game_total:.3f}s")
124
+ print(f"Steps: {step} ({counts['interactive']} interactive, {counts['non_interactive']} auto)")
125
+ print(f"\n{'Operation':<25} {'Time (s)':<10} {'% Total':<10} {'Per Call (ms)':<15}")
126
+ print("-" * 60)
127
+
128
+ for op, t in sorted(times.items(), key=lambda x: -x[1]):
129
+ pct = 100 * t / t_game_total if t_game_total > 0 else 0
130
+ calls = counts["interactive"] if op != "game_step" else step
131
+ per_call_ms = 1000 * t / calls if calls > 0 else 0
132
+ print(f"{op:<25} {t:<10.4f} {pct:<10.1f} {per_call_ms:<15.3f}")
133
+
134
+ print(f"\nTerminal: {game.is_terminal()}, Winner: {game.get_winner()}")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ import argparse
139
+
140
+ parser = argparse.ArgumentParser()
141
+ parser.add_argument("--sims", type=int, default=100)
142
+ parser.add_argument("--weight", type=float, default=0.3)
143
+ args = parser.parse_args()
144
+ profile_game(sims=args.sims, neural_weight=args.weight)