trioskosmos committed on
Commit
f6fedc3
·
verified ·
1 Parent(s): 996ab02

Upload ai/data_generation/self_play.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/data_generation/self_play.py +318 -0
ai/data_generation/self_play.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import concurrent.futures
3
+ import json
4
+ import multiprocessing
5
+ import os
6
+ import random
7
+ import sys
8
+ import time
9
+
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+
13
+ # Pin threads for performance
14
+ os.environ["RAYON_NUM_THREADS"] = "1"
15
+
16
+ # Add project root to path
17
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
18
+
19
+ import engine_rust
20
+
21
+ from ai.utils.benchmark_decks import parse_deck
22
+
23
# Global cache for workers (optional, for NN mode).
# Both names default to None so run_self_play_game's `is None` guard works
# even in a process where worker_init was never called (previously _WORKER_DB
# was undefined there, which raised NameError instead of returning None).
_WORKER_DB = None
_WORKER_MODEL_PATH = None


def worker_init(db_content, model_path=None):
    """Process-pool initializer: build per-worker state exactly once.

    Args:
        db_content: JSON string of the compiled card database; parsed into a
            native PyCardDatabase so each worker avoids re-reading the file.
        model_path: Optional path to an ONNX value/policy model. When None,
            workers run pure heuristic MCTS ("teacher" mode).
    """
    global _WORKER_DB, _WORKER_MODEL_PATH
    _WORKER_DB = engine_rust.PyCardDatabase(db_content)
    _WORKER_MODEL_PATH = model_path
31
+
32
+
33
def run_self_play_game(g_idx, sims, p0_deck_info, p1_deck_info):
    """Play one full self-play game and return its training samples.

    Runs MCTS at every interactive decision point, records (state, policy)
    pairs, and backfills the value targets (win signal, score margin,
    normalized turns remaining) once the game has a terminal outcome.

    Args:
        g_idx: Game index from the submitter (unused here; kept for tracing).
        sims: Number of MCTS simulations per decision.
        p0_deck_info / p1_deck_info: (deck, lives, energy) tuples as produced
            by parse_deck — note the engine call takes energies before lives.

    Returns:
        Dict of numpy arrays ("states", "policies", "winners", "scores",
        "turns_left") plus an "outcome" summary, or None if the worker was
        not initialized or the game did not reach a terminal state.
    """
    if _WORKER_DB is None:
        return None

    game = engine_rust.PyGameState(_WORKER_DB)
    game.silent = True
    p0_deck, p0_lives, p0_energy = p0_deck_info
    p1_deck, p1_lives, p1_energy = p1_deck_info

    game.initialize_game(p0_deck, p1_deck, p0_energy, p1_energy, p0_lives, p1_lives)

    game_states = []
    game_policies = []
    game_turns_remaining = []
    game_player_turn = []

    # Target values will be backfilled after game ends.

    step = 0
    max_turns = 150  # Estimated max turns for normalization
    while not game.is_terminal() and step < 1000:  # hard cap guards runaway games
        cp = game.current_player
        phase = game.phase

        # Interactive Phases: Mulligan (-1, 0), Main (4), LiveSet (5)
        is_interactive = phase in [-1, 0, 4, 5]

        if is_interactive:
            # Observation (expected width 1200)
            encoded = game.get_observation()
            if len(encoded) != 1200:
                # Pad/truncate to 1200 if the engine's encoding drifts
                if len(encoded) < 1200:
                    encoded = encoded + [0.0] * (1200 - len(encoded))
                else:
                    encoded = encoded[:1200]

            # Use MCTS with the original heuristic (teacher mode) when no
            # model is available; otherwise hybrid NN+heuristic search.
            h_type = "original" if _WORKER_MODEL_PATH is None else "hybrid"
            suggestions = game.search_mcts(
                num_sims=sims, seconds=0.0, heuristic_type=h_type, model_path=_WORKER_MODEL_PATH
            )

            # Build policy vector over the first 2000 action ids
            policy = np.zeros(2000, dtype=np.float32)
            action_ids = []
            visit_counts = []
            total_visits = 0
            for action, _, visits in suggestions:
                if action < 2000:
                    action_ids.append(int(action))
                    visit_counts.append(visits)
                    total_visits += visits

            if total_visits == 0:
                # MCTS produced nothing usable — fall back to uniform legal
                legal = list(game.get_legal_action_ids())
                action_ids = [int(a) for a in legal if a < 2000]
                visit_counts = [1.0] * len(action_ids)
                total_visits = len(action_ids)

            if not action_ids:
                # No representable action at all: previously this divided by
                # zero and crashed np.random.choice. Abort the game cleanly.
                break

            probs = np.array(visit_counts, dtype=np.float32) / total_visits

            # Add Dirichlet noise for exploration
            if len(probs) > 1:
                noise = np.random.dirichlet([0.3] * len(probs))
                probs = 0.75 * probs + 0.25 * noise
                # CRITICAL: re-normalize for np.random.choice float precision
                probs = probs / np.sum(probs)

            for i, aid in enumerate(action_ids):
                policy[aid] = probs[i]

            game_states.append(encoded)
            game_policies.append(policy)
            game_player_turn.append(cp)
            game_turns_remaining.append(float(game.turn))  # normalized later

            # Action selection: sample in the early game, exploit afterwards
            if step < 40:
                action = np.random.choice(action_ids, p=probs)
            else:
                action = action_ids[np.argmax(probs)]

            try:
                game.step(int(action))
            except Exception:
                # Engine rejected the move; abandon this game rather than loop
                break
        else:
            # Non-interactive phase: auto-step the engine
            try:
                game.step(0)
            except Exception:
                break
        step += 1

    if not game.is_terminal():
        return None

    winner = game.get_winner()  # 0/1 = player index, 2 = draw (per usage below)
    s0 = float(game.get_player(0).score)
    s1 = float(game.get_player(1).score)
    final_turn = float(game.turn)

    # Backfill per-sample value targets now that the outcome is known
    winners = []
    scores = []
    turns_normalized = []

    for i in range(len(game_player_turn)):
        p_idx = game_player_turn[i]

        # Win signal in {1, 0, -1} from the acting player's perspective
        if winner == 2:
            winners.append(0.0)
        elif p_idx == winner:
            winners.append(1.0)
        else:
            winners.append(-1.0)

        # Score margin, squashed roughly into [-1, 1]
        diff = (s0 - s1) if p_idx == 0 else (s1 - s0)
        score_norm = np.tanh(diff / 50.0)
        scores.append(score_norm)

        # Turns remaining, normalized to [0, 1] (≈1.0 early, 0.0 at the end)
        rem = (final_turn - game_turns_remaining[i]) / max_turns
        turns_normalized.append(np.clip(rem, 0.0, 1.0))

    return {
        "states": np.array(game_states, dtype=np.float32),
        "policies": np.array(game_policies, dtype=np.float32),
        "winners": np.array(winners, dtype=np.float32),
        "scores": np.array(scores, dtype=np.float32),
        "turns_left": np.array(turns_normalized, dtype=np.float32),
        "outcome": {"winner": winner, "score": (s0, s1), "turns": game.turn},
    }
172
+
173
+
174
def generate_self_play(
    num_games=100,
    model_path="ai/models/alphanet.onnx",
    output_file="ai/data/self_play_0.npz",
    sims=100,
    weight=0.3,
    skip_rollout=False,
    workers=0,
):
    """Generate self-play training data with a pool of worker processes.

    Games are played in parallel via run_self_play_game; samples are
    accumulated and flushed to timestamped .npz chunks every `chunk_size`
    completed games.

    Args:
        num_games: Total games to play.
        model_path: ONNX model path, or the string "None" / None for pure
            heuristic ("teacher") MCTS.
        output_file: Base .npz path; chunk index and timestamp are appended.
        sims: MCTS simulations per decision.
        weight: Reserved — currently unused by this function.
        skip_rollout: Reserved — currently unused by this function.
        workers: Worker-process count; 0 picks min(cpu_count, 12).

    Raises:
        FileNotFoundError: if the compiled card database is missing.
        RuntimeError: if none of the standard deck files exist.
    """
    db_path = "engine/data/cards_compiled.json"
    with open(db_path, "r", encoding="utf-8") as f:
        db_content = f.read()
    db_json = json.loads(db_content)

    # Load decks (standard pool)
    deck_paths = [
        "ai/decks/aqours_cup.txt",
        "ai/decks/hasunosora_cup.txt",
        "ai/decks/liella_cup.txt",
        "ai/decks/muse_cup.txt",
        "ai/decks/nijigaku_cup.txt",
    ]
    decks = []
    for dp in deck_paths:
        if os.path.exists(dp):
            decks.append(parse_deck(dp, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {})))

    # Previously an empty pool crashed later with an opaque
    # ValueError from random.randint(0, -1); fail fast instead.
    if not decks:
        raise RuntimeError(f"No deck files found among: {deck_paths}")

    all_states, all_policies, all_winners = [], [], []
    all_scores, all_turns = [], []
    total_completed = 0
    total_samples = 0
    chunk_size = 100  # Save every 100 games

    stats = {"wins": 0, "losses": 0, "draws": 0}

    # CLI passes the literal string "None" to request teacher mode
    if model_path == "None":
        model_path = None

    max_workers = workers if workers > 0 else min(multiprocessing.cpu_count(), 12)
    mode_str = "Teacher (Heuristic MCTS)" if model_path is None else "Student (Hybrid MCTS)"
    print(f"Starting Self-Play: {num_games} games using {max_workers} workers... Mode: {mode_str}")

    def save_chunk():
        # Flush accumulated samples to a uniquely-named chunk and reset buffers
        nonlocal all_states, all_policies, all_winners, all_scores, all_turns
        if not all_states:
            return
        ts = int(time.time())
        path = output_file.replace(".npz", f"_chunk_{total_completed // chunk_size}_{ts}.npz")
        print(f"\n[Disk] Saving {len(all_states)} samples to {path}...")
        np.savez(
            path,
            states=np.array(all_states, dtype=np.float32),
            policies=np.array(all_policies, dtype=np.float32),
            winners=np.array(all_winners, dtype=np.float32),
            scores=np.array(all_scores, dtype=np.float32),
            turns_left=np.array(all_turns, dtype=np.float32),
        )
        all_states, all_policies, all_winners = [], [], []
        all_scores, all_turns = [], []

    with concurrent.futures.ProcessPoolExecutor(
        max_workers=max_workers, initializer=worker_init, initargs=(db_content, model_path)
    ) as executor:
        pending = {}
        batch_cap = max_workers * 2  # keep the pool fed without over-queuing
        games_submitted = 0

        pbar = tqdm(total=num_games)

        while total_completed < num_games or pending:
            # Top up the in-flight set with random deck matchups
            while len(pending) < batch_cap and games_submitted < num_games:
                p0, p1 = random.randint(0, len(decks) - 1), random.randint(0, len(decks) - 1)
                f = executor.submit(run_self_play_game, games_submitted, sims, decks[p0], decks[p1])
                pending[f] = games_submitted
                games_submitted += 1

            # All submitted and none in flight (e.g. every game failed): stop
            if not pending:
                break

            done, _ = concurrent.futures.wait(pending.keys(), return_when=concurrent.futures.FIRST_COMPLETED)
            for f in done:
                pending.pop(f)
                try:
                    res = f.result()
                    if res:
                        all_states.extend(res["states"])
                        all_policies.extend(res["policies"])
                        all_winners.extend(res["winners"])
                        all_scores.extend(res["scores"])
                        all_turns.extend(res["turns_left"])

                        total_completed += 1
                        total_samples += len(res["states"])

                        # Update stats (winner == 2 means draw)
                        outcome = res["outcome"]
                        w_idx = outcome["winner"]
                        turns = outcome["turns"]

                        win_str = "DRAW" if w_idx == 2 else f"P{w_idx} WIN"

                        if w_idx == 2:
                            stats["draws"] += 1
                        elif w_idx == 0:
                            stats["wins"] += 1
                        else:
                            stats["losses"] += 1

                        # Reduce log spam for large runs
                        if total_completed % 10 == 0 or total_completed < 10:
                            print(
                                f" [Game {total_completed}] {win_str} in {turns} turns | Samples: {len(res['states'])} | Total W/L/D: {stats['wins']}/{stats['losses']}/{stats['draws']}"
                            )

                        pbar.update(1)
                        if total_completed % chunk_size == 0:
                            save_chunk()
                except Exception as e:
                    # A single failed game should not abort the whole run
                    print(f"Game failed: {e}")

        pbar.close()

    if all_states:
        save_chunk()
    print(f"Self-play generation complete. Total samples: {total_samples}")
299
+
300
+
301
if __name__ == "__main__":
    # Command-line entry point: configure a run and kick off generation.
    cli = argparse.ArgumentParser()
    cli.add_argument("--games", type=int, default=100)
    cli.add_argument("--sims", type=int, default=100)
    cli.add_argument("--model", type=str, default="ai/models/alphanet_best.onnx")
    cli.add_argument("--weight", type=float, default=0.3)
    cli.add_argument("--workers", type=int, default=0, help="Number of workers (0 = auto)")
    cli.add_argument("--fast", action="store_true", help="Skip rollouts, use pure NN value (faster)")
    opts = cli.parse_args()

    generate_self_play(
        num_games=opts.games,
        model_path=opts.model,
        sims=opts.sims,
        weight=opts.weight,
        skip_rollout=opts.fast,
        workers=opts.workers,
    )