trioskosmos committed on
Commit
996ab02
·
verified ·
1 Parent(s): ed62625

Upload ai/data_generation/generate_data.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/data_generation/generate_data.py +310 -0
ai/data_generation/generate_data.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ # Critical Performance Tuning:
5
+ # Each Python process handles 1 game. If we don't pin Rayon threads to 1,
6
+ # every process will try to use ALL CPU cores for its MCTS simulations,
7
+ # causing massive thread contention and slowing down generation by 5-10x.
8
+ os.environ["RAYON_NUM_THREADS"] = "1"
9
+
10
+ import argparse
11
+ import concurrent.futures
12
+ import glob
13
+ import json
14
+ import multiprocessing
15
+ import random
16
+ import time
17
+
18
+ import numpy as np
19
+ from tqdm import tqdm
20
+
21
+ # Add project root to path
22
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
23
+
24
+ import engine_rust
25
+
26
+ from ai.models.training_config import POLICY_SIZE
27
+ from ai.utils.benchmark_decks import parse_deck
28
+
29
+ # Global database cache for workers
30
+ _WORKER_DB = None
31
+ _WORKER_DB_JSON = None
32
+
33
+
34
+ def worker_init(db_content):
35
+ global _WORKER_DB, _WORKER_DB_JSON
36
+ _WORKER_DB = engine_rust.PyCardDatabase(db_content)
37
+ _WORKER_DB_JSON = json.loads(db_content)
38
+
39
+
40
+ def run_single_game(g_idx, sims, p0_deck_info, p1_deck_info):
41
+ if _WORKER_DB is None:
42
+ return None
43
+
44
+ game = engine_rust.PyGameState(_WORKER_DB)
45
+ game.silent = True
46
+ p0_deck, p0_lives, p0_energy = p0_deck_info
47
+ p1_deck, p1_lives, p1_energy = p1_deck_info
48
+
49
+ game.initialize_game(p0_deck, p1_deck, p0_energy, p1_energy, p0_lives, p1_lives)
50
+
51
+ game_states = []
52
+ game_policies = []
53
+ game_player_turn = []
54
+
55
+ step = 0
56
+ while not game.is_terminal() and step < 1500: # Slightly reduced limit for safety
57
+ cp = game.current_player
58
+ phase = game.phase
59
+
60
+ is_interactive = phase in [-1, 0, 4, 5]
61
+
62
+ if is_interactive:
63
+ encoded = game.encode_state(_WORKER_DB)
64
+ suggestions = game.get_mcts_suggestions(sims, engine_rust.SearchHorizon.TurnEnd)
65
+
66
+ policy = np.zeros(POLICY_SIZE, dtype=np.float32)
67
+ total_visits = 0
68
+ best_action = 0
69
+ most_visits = -1
70
+
71
+ for action, score, visits in suggestions:
72
+ if action < POLICY_SIZE:
73
+ policy[int(action)] = visits
74
+ total_visits += visits
75
+ if visits > most_visits:
76
+ most_visits = visits
77
+ best_action = int(action)
78
+
79
+ if total_visits > 0:
80
+ policy /= total_visits
81
+
82
+ game_states.append(encoded)
83
+ game_policies.append(policy)
84
+ game_player_turn.append(cp)
85
+
86
+ try:
87
+ game.step(best_action)
88
+ except:
89
+ break
90
+ else:
91
+ try:
92
+ game.step(0)
93
+ except:
94
+ break
95
+ step += 1
96
+
97
+ if not game.is_terminal():
98
+ return None
99
+
100
+ winner = game.get_winner()
101
+ s0 = game.get_player(0).score
102
+ s1 = game.get_player(1).score
103
+
104
+ game_winners = []
105
+ for cp in game_player_turn:
106
+ if winner == 2: # Draw
107
+ game_winners.append(0.0)
108
+ elif cp == winner:
109
+ game_winners.append(1.0)
110
+ else:
111
+ game_winners.append(-1.0)
112
+
113
+ # Game end summary for logging
114
+ outcome = {"winner": winner, "p0_score": s0, "p1_score": s1, "turns": game.turn}
115
+
116
+ # tqdm will handle the progress bar, but a periodic print is helpful
117
+ if g_idx % 100 == 0:
118
+ win_str = "P0" if winner == 0 else "P1" if winner == 1 else "Tie"
119
+ print(
120
+ f" [Game {g_idx}] Winner: {win_str} | Final Score: {s0}-{s1} | Turns: {game.turn} | States: {len(game_states)}"
121
+ )
122
+
123
+ return {"states": game_states, "policies": game_policies, "winners": game_winners, "outcome": outcome}
124
+
125
+
126
+ def generate_dataset(num_games=100, output_file="ai/data/data_batch_0.npz", sims=200, resume=False, chunk_size=5000):
127
+ db_path = "data/cards_compiled.json"
128
+ if not os.path.exists(db_path):
129
+ print(f"Error: Database not found at {db_path}")
130
+ return
131
+
132
+ with open(db_path, "r", encoding="utf-8") as f:
133
+ db_content = f.read()
134
+ db_json = json.loads(db_content)
135
+
136
+ deck_config = [
137
+ ("Aqours", "ai/decks/aqours_cup.txt"),
138
+ ("Hasunosora", "ai/decks/hasunosora_cup.txt"),
139
+ ("Liella", "ai/decks/liella_cup.txt"),
140
+ ("Muse", "ai/decks/muse_cup.txt"),
141
+ ("Nijigasaki", "ai/decks/nijigaku_cup.txt"),
142
+ ]
143
+ decks = []
144
+ deck_names = []
145
+ print("Loading curriculum decks...")
146
+ for name, dp in deck_config:
147
+ if os.path.exists(dp):
148
+ decks.append(parse_deck(dp, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {})))
149
+ deck_names.append(name)
150
+
151
+ if not decks:
152
+ p_deck = [124, 127, 130, 132] * 12
153
+ p_lives = [1024, 1025, 1027]
154
+ p_energy = [20000] * 10
155
+ decks = [(p_deck, p_lives, p_energy)]
156
+ deck_names = ["Starter-SD1"]
157
+
158
+ total_completed = 0
159
+ total_samples = 0
160
+ stats = {}
161
+ for i in range(len(decks)):
162
+ for j in range(len(decks)):
163
+ stats[(i, j)] = {"games": 0, "p0_wins": 0, "p0_total": 0, "p1_total": 0, "turns_total": 0}
164
+
165
+ all_states, all_policies, all_winners = [], [], []
166
+
167
+ def print_stats_table():
168
+ n = len(deck_names)
169
+ print("\n" + "=" * 95)
170
+ print(f" DECK VS DECK STATISTICS (Progress: {total_completed}/{num_games} | Samples: {total_samples})")
171
+ print("=" * 95)
172
+ header = f"{'P0 \\ P1':<12} | " + " | ".join([f"{name[:10]:^14}" for name in deck_names])
173
+ print(header)
174
+ print("-" * len(header))
175
+ for i in range(n):
176
+ row = f"{deck_names[i]:<12} | "
177
+ cols = []
178
+ for j in range(n):
179
+ s = stats[(i, j)]
180
+ if s["games"] > 0:
181
+ wr = (s["p0_wins"] / s["games"]) * 100
182
+ avg0 = s["p0_total"] / s["games"]
183
+ avg1 = s["p1_total"] / s["games"]
184
+ avg_t = s["turns_total"] / s["games"]
185
+ cols.append(f"{wr:>3.0f}%/{avg0:^3.1f}/T{avg_t:<2.1f}")
186
+ else:
187
+ cols.append(f"{'-':^14}")
188
+ print(row + " | ".join(cols))
189
+ print("=" * 95 + "\n")
190
+
191
+ def save_current_chunk(is_final=False):
192
+ nonlocal all_states, all_policies, all_winners
193
+ if not all_states:
194
+ return
195
+
196
+ # Unique timestamped or indexed chunks to prevent overwriting during write
197
+ chunk_idx = total_completed // chunk_size
198
+ path = output_file.replace(".npz", f"_chunk_{chunk_idx}_{int(time.time())}.npz")
199
+
200
+ print(f"\n[Disk] Attempting to save {len(all_states)} samples to {path}...")
201
+
202
+ try:
203
+ # Step 1: Save UNCOMPRESSED (Fast, less likely to fail mid-write)
204
+ np.savez(
205
+ path,
206
+ states=np.array(all_states, dtype=np.float32),
207
+ policies=np.array(all_policies, dtype=np.float32),
208
+ winners=np.array(all_winners, dtype=np.float32),
209
+ )
210
+
211
+ # Step 2: VERIFY immediately
212
+ with np.load(path) as data:
213
+ if "states" in data.keys() and len(data["states"]) == len(all_states):
214
+ print(f" -> VERIFIED: {path} is healthy.")
215
+ else:
216
+ raise IOError("Verification failed: File is truncated or keys missing.")
217
+
218
+ # Reset buffers only after successful verification
219
+ if not is_final:
220
+ all_states, all_policies, all_winners = [], [], []
221
+
222
+ except Exception as e:
223
+ print(f" !!! CRITICAL SAVE ERROR: {e}")
224
+ print(" !!! Data is still in memory, will retry next chunk.")
225
+
226
+ if resume:
227
+ existing = sorted(glob.glob(output_file.replace(".npz", "_chunk_*.npz")))
228
+ if existing:
229
+ total_completed = len(existing) * chunk_size
230
+ print(f"Resuming from game {total_completed} ({len(existing)} chunks found)")
231
+
232
+ max_workers = min(multiprocessing.cpu_count(), 16)
233
+ print(f"Starting generation using {max_workers} workers...")
234
+
235
+ try:
236
+ with concurrent.futures.ProcessPoolExecutor(
237
+ max_workers=max_workers, initializer=worker_init, initargs=(db_content,)
238
+ ) as executor:
239
+ pending = {}
240
+ batch_cap = max_workers * 2
241
+ games_submitted = total_completed
242
+
243
+ pbar = tqdm(total=num_games, initial=total_completed)
244
+ last_save_time = time.time()
245
+
246
+ while games_submitted < num_games or pending:
247
+ current_time = time.time()
248
+ # Autosave every 30 minutes
249
+ if current_time - last_save_time > 1800:
250
+ print("\n[Timer] 30 minutes passed. Autosaving...")
251
+ save_current_chunk()
252
+ last_save_time = current_time
253
+
254
+ while len(pending) < batch_cap and games_submitted < num_games:
255
+ p0, p1 = random.randint(0, len(decks) - 1), random.randint(0, len(decks) - 1)
256
+ f = executor.submit(run_single_game, games_submitted, sims, decks[p0], decks[p1])
257
+ pending[f] = (p0, p1)
258
+ games_submitted += 1
259
+
260
+ done, _ = concurrent.futures.wait(pending.keys(), return_when=concurrent.futures.FIRST_COMPLETED)
261
+ for f in done:
262
+ p0, p1 = pending.pop(f)
263
+ try:
264
+ res = f.result()
265
+ if res:
266
+ all_states.extend(res["states"])
267
+ all_policies.extend(res["policies"])
268
+ all_winners.extend(res["winners"])
269
+ total_completed += 1
270
+ total_samples += len(res["states"])
271
+ pbar.update(1)
272
+
273
+ o = res["outcome"]
274
+ s = stats[(p0, p1)]
275
+ s["games"] += 1
276
+ if o["winner"] == 0:
277
+ s["p0_wins"] += 1
278
+ s["p0_total"] += o["p0_score"]
279
+ s["p1_total"] += o["p1_score"]
280
+ s["turns_total"] += o["turns"]
281
+
282
+ if total_completed % chunk_size == 0:
283
+ save_current_chunk()
284
+ print_stats_table()
285
+ # REMOVED: dangerous 100-game re-compression checkpoints
286
+ except Exception:
287
+ pass
288
+ pbar.close()
289
+ except KeyboardInterrupt:
290
+ print("\nStopping...")
291
+
292
+ save_current_chunk(is_final=True)
293
+ print_stats_table()
294
+
295
+
296
+ if __name__ == "__main__":
297
+ parser = argparse.ArgumentParser()
298
+ parser.add_argument("--num-games", type=int, default=100)
299
+ parser.add_argument("--output-file", type=str, default="ai/data/data_batch_0.npz")
300
+ parser.add_argument("--sims", type=int, default=400)
301
+ parser.add_argument("--resume", action="store_true")
302
+ parser.add_argument("--chunk-size", type=int, default=1000)
303
+ args = parser.parse_args()
304
+ generate_dataset(
305
+ num_games=args.num_games,
306
+ output_file=args.output_file,
307
+ sims=args.sims,
308
+ resume=args.resume,
309
+ chunk_size=args.chunk_size,
310
+ )