Spaces:
Running
Running
| import os | |
| import sys | |
| # Critical Performance Tuning: | |
| # Each Python process handles 1 game. If we don't pin Rayon threads to 1, | |
| # every process will try to use ALL CPU cores for its MCTS simulations, | |
| # causing massive thread contention and slowing down generation by 5-10x. | |
| os.environ["RAYON_NUM_THREADS"] = "1" | |
| import argparse | |
| import concurrent.futures | |
| import glob | |
| import json | |
| import multiprocessing | |
| import random | |
| import time | |
| import numpy as np | |
| from tqdm import tqdm | |
| # Add project root to path | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | |
| import engine_rust | |
| from ai.models.training_config import POLICY_SIZE | |
| from ai.utils.benchmark_decks import parse_deck | |
# Per-process card database cache. Both names stay ``None`` until
# ``worker_init`` has run in the current process.
_WORKER_DB = None
_WORKER_DB_JSON = None


def worker_init(db_content):
    """Populate this process's card database cache.

    Passed as the ``initializer`` of the process pool so that each worker
    builds its database once instead of once per game.

    Args:
        db_content: Raw JSON text of the compiled card database. It is handed
            to ``engine_rust.PyCardDatabase`` and also parsed with
            ``json.loads``.
    """
    global _WORKER_DB, _WORKER_DB_JSON

    # Engine-side database object, used by every game played in this process.
    _WORKER_DB = engine_rust.PyCardDatabase(db_content)
    # Plain-Python view of the same data.
    _WORKER_DB_JSON = json.loads(db_content)
def _policy_from_suggestions(suggestions):
    """Turn MCTS suggestions into a visit-fraction policy and a greedy move.

    Args:
        suggestions: Iterable of ``(action, score, visits)`` tuples as
            returned by ``get_mcts_suggestions``. ``score`` is not used.

    Returns:
        A ``(policy, best_action)`` pair. ``policy`` is a float32 vector of
        length ``POLICY_SIZE`` holding each in-range action's share of the
        visits (all zeros when no in-range action was visited).
        ``best_action`` is the first action with the highest visit count over
        *all* suggestions, or ``0`` when ``suggestions`` is empty.
    """
    policy = np.zeros(POLICY_SIZE, dtype=np.float32)
    total_visits = 0
    best_action = 0
    most_visits = -1

    for action, _score, visits in suggestions:
        action = int(action)
        # The lower bound matters: a negative id would otherwise wrap around
        # and silently write to the tail of the policy vector.
        if 0 <= action < POLICY_SIZE:
            policy[action] = visits
            total_visits += visits
        # The move actually played is the search's favourite, even if it has
        # no slot in the policy vector, so the game always follows the search.
        if visits > most_visits:
            most_visits = visits
            best_action = action

    if total_visits > 0:
        policy /= total_visits
    return policy, best_action


def run_single_game(g_idx, sims, p0_deck_info, p1_deck_info, max_steps=1500):
    """Play one self-play game and collect its training samples.

    Must run in a process where ``worker_init`` has already filled
    ``_WORKER_DB``.

    Args:
        g_idx: Index of the game; only used to decide when to print a summary
            line (every 100th game).
        sims: Number of MCTS simulations requested per decision.
        p0_deck_info: ``(deck, lives, energy)`` triple for player 0.
        p1_deck_info: ``(deck, lives, energy)`` triple for player 1.
        max_steps: Safety cap on the number of engine steps before the game
            is abandoned.

    Returns:
        ``None`` if the worker database is missing or the game did not reach
        a terminal state (step cap hit or the engine rejected a step).
        Otherwise a dict with:

        * ``"states"``: encoded states, one per recorded decision,
        * ``"policies"``: the matching visit-fraction policy vectors,
        * ``"winners"``: per-sample value target from the acting player's
          point of view (``1.0`` win, ``-1.0`` loss, ``0.0`` draw),
        * ``"outcome"``: summary dict with ``winner``, ``p0_score``,
          ``p1_score`` and ``turns``.
    """
    if _WORKER_DB is None:
        return None

    game = engine_rust.PyGameState(_WORKER_DB)
    game.silent = True
    p0_deck, p0_lives, p0_energy = p0_deck_info
    p1_deck, p1_lives, p1_energy = p1_deck_info
    game.initialize_game(p0_deck, p1_deck, p0_energy, p1_energy, p0_lives, p1_lives)

    game_states = []
    game_policies = []
    game_player_turn = []

    # Phases in which a decision is searched and recorded; every other phase
    # is simply advanced with action 0.
    interactive_phases = (-1, 0, 4, 5)

    step = 0
    while not game.is_terminal() and step < max_steps:
        cp = game.current_player
        if game.phase in interactive_phases:
            encoded = game.encode_state(_WORKER_DB)
            suggestions = game.get_mcts_suggestions(sims, engine_rust.SearchHorizon.TurnEnd)
            policy, action_to_play = _policy_from_suggestions(suggestions)
            game_states.append(encoded)
            game_policies.append(policy)
            game_player_turn.append(cp)
        else:
            action_to_play = 0

        try:
            game.step(action_to_play)
        except (KeyboardInterrupt, SystemExit):
            # Never swallow interpreter-exit signals inside a worker.
            raise
        except BaseException:
            # NOTE(review): the original used a bare ``except`` here. Failures
            # raised by a native extension may not derive from ``Exception``
            # (e.g. PyO3 panics -- assumes engine_rust is PyO3-based, TODO
            # confirm), so any other failure still means "abandon this game".
            break
        step += 1

    if not game.is_terminal():
        return None

    winner = game.get_winner()
    s0 = game.get_player(0).score
    s1 = game.get_player(1).score

    if winner == 2:  # Draw
        game_winners = [0.0] * len(game_player_turn)
    else:
        game_winners = [1.0 if cp == winner else -1.0 for cp in game_player_turn]

    # Game end summary for logging
    outcome = {"winner": winner, "p0_score": s0, "p1_score": s1, "turns": game.turn}

    # tqdm will handle the progress bar, but a periodic print is helpful
    if g_idx % 100 == 0:
        win_str = "P0" if winner == 0 else "P1" if winner == 1 else "Tie"
        print(
            f" [Game {g_idx}] Winner: {win_str} | Final Score: {s0}-{s1} | Turns: {game.turn} | States: {len(game_states)}"
        )

    return {"states": game_states, "policies": game_policies, "winners": game_winners, "outcome": outcome}
def generate_dataset(num_games=100, output_file="ai/data/data_batch_0.npz", sims=200, resume=False, chunk_size=5000):
    """Generate self-play training data in parallel and save it as ``.npz`` chunks.

    Games are played by a process pool (one game per task) between randomly
    paired curriculum decks. Samples are buffered in memory and flushed to a
    timestamped chunk file every ``chunk_size`` completed games, every 30
    minutes, and once more when the run ends for any reason.

    Args:
        num_games: Number of games to submit. Games that fail or do not
            finish are not replaced, so fewer games may be recorded.
        output_file: Base output path. Chunks are written next to it as
            ``<base>_chunk_<index>_<unix time>.npz``.
        sims: MCTS simulations per decision, forwarded to each game.
        resume: If true, start counting from an estimate of the games already
            on disk (number of existing chunk files times ``chunk_size``).
        chunk_size: Completed games per chunk file; must be positive.

    Returns:
        None. Problems with the arguments or the card database are reported
        on stdout and end the call early.
    """
    if chunk_size <= 0:
        # A non-positive size would raise ZeroDivisionError deep inside the
        # run (and in the final save), losing everything collected so far.
        print(f"Error: chunk_size must be a positive integer (got {chunk_size})")
        return

    db_path = "data/cards_compiled.json"
    if not os.path.exists(db_path):
        print(f"Error: Database not found at {db_path}")
        return
    with open(db_path, "r", encoding="utf-8") as f:
        db_content = f.read()
    db_json = json.loads(db_content)

    deck_config = [
        ("Aqours", "ai/decks/aqours_cup.txt"),
        ("Hasunosora", "ai/decks/hasunosora_cup.txt"),
        ("Liella", "ai/decks/liella_cup.txt"),
        ("Muse", "ai/decks/muse_cup.txt"),
        ("Nijigasaki", "ai/decks/nijigaku_cup.txt"),
    ]
    decks = []
    deck_names = []
    print("Loading curriculum decks...")
    for name, dp in deck_config:
        if os.path.exists(dp):
            decks.append(parse_deck(dp, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {})))
            deck_names.append(name)
    if not decks:
        # No curriculum deck file was found: fall back to one built-in deck.
        p_deck = [124, 127, 130, 132] * 12
        p_lives = [1024, 1025, 1027]
        p_energy = [20000] * 10
        decks = [(p_deck, p_lives, p_energy)]
        deck_names = ["Starter-SD1"]

    total_completed = 0
    total_samples = 0
    n_decks = len(decks)
    # Per-matchup counters keyed by (player 0 deck index, player 1 deck index).
    stats = {
        (i, j): {"games": 0, "p0_wins": 0, "p0_total": 0, "p1_total": 0, "turns_total": 0}
        for i in range(n_decks)
        for j in range(n_decks)
    }
    all_states, all_policies, all_winners = [], [], []

    # Everything before a trailing ".npz"; shared by saving and resuming so
    # both always agree on the chunk file names.
    chunk_base = output_file[: -len(".npz")] if output_file.endswith(".npz") else output_file

    def print_stats_table():
        """Print win rate / average P0 score / average turns per matchup."""
        print("\n" + "=" * 95)
        print(f" DECK VS DECK STATISTICS (Progress: {total_completed}/{num_games} | Samples: {total_samples})")
        print("=" * 95)
        # Kept outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12.
        corner = "P0 \\ P1"
        header = f"{corner:<12} | " + " | ".join([f"{name[:10]:^14}" for name in deck_names])
        print(header)
        print("-" * len(header))
        for i, row_name in enumerate(deck_names):
            cols = []
            for j in range(len(deck_names)):
                s = stats[(i, j)]
                if s["games"] > 0:
                    wr = (s["p0_wins"] / s["games"]) * 100
                    avg0 = s["p0_total"] / s["games"]
                    avg_t = s["turns_total"] / s["games"]
                    cols.append(f"{wr:>3.0f}%/{avg0:^3.1f}/T{avg_t:<2.1f}")
                else:
                    cols.append(f"{'-':^14}")
            print(f"{row_name:<12} | " + " | ".join(cols))
        print("=" * 95 + "\n")

    def save_current_chunk(is_final=False):
        """Write the buffered samples to a new chunk file and verify it.

        The buffers are cleared only after the file has been read back
        successfully, and never on the final save.
        """
        nonlocal all_states, all_policies, all_winners
        if not all_states:
            return
        # Unique timestamped or indexed chunks to prevent overwriting during write
        chunk_idx = total_completed // chunk_size
        path = f"{chunk_base}_chunk_{chunk_idx}_{int(time.time())}.npz"
        print(f"\n[Disk] Attempting to save {len(all_states)} samples to {path}...")
        try:
            out_dir = os.path.dirname(path)
            if out_dir:
                # A missing output directory would otherwise make every save fail.
                os.makedirs(out_dir, exist_ok=True)
            # Step 1: Save UNCOMPRESSED (Fast, less likely to fail mid-write)
            np.savez(
                path,
                states=np.array(all_states, dtype=np.float32),
                policies=np.array(all_policies, dtype=np.float32),
                winners=np.array(all_winners, dtype=np.float32),
            )
            # Step 2: VERIFY immediately
            with np.load(path) as data:
                if "states" in data.keys() and len(data["states"]) == len(all_states):
                    print(f" -> VERIFIED: {path} is healthy.")
                else:
                    raise IOError("Verification failed: File is truncated or keys missing.")
            # Reset buffers only after successful verification
            if not is_final:
                all_states, all_policies, all_winners = [], [], []
        except Exception as e:
            print(f" !!! CRITICAL SAVE ERROR: {e}")
            print(" !!! Data is still in memory, will retry next chunk.")

    if resume:
        existing = sorted(glob.glob(f"{chunk_base}_chunk_*.npz"))
        if existing:
            # NOTE(review): this is only an estimate. Timer autosaves and the
            # final partial chunk also create files, so it can overstate the
            # number of games actually on disk.
            total_completed = len(existing) * chunk_size
            print(f"Resuming from game {total_completed} ({len(existing)} chunks found)")

    max_workers = min(multiprocessing.cpu_count(), 16)
    print(f"Starting generation using {max_workers} workers...")
    try:
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=max_workers, initializer=worker_init, initargs=(db_content,)
        ) as executor, tqdm(total=num_games, initial=total_completed) as pbar:
            pending = {}
            # Keep a small backlog so workers never idle, without queueing
            # the whole run up front.
            batch_cap = max_workers * 2
            games_submitted = total_completed
            last_save_time = time.time()
            while games_submitted < num_games or pending:
                current_time = time.time()
                # Autosave every 30 minutes
                if current_time - last_save_time > 1800:
                    print("\n[Timer] 30 minutes passed. Autosaving...")
                    save_current_chunk()
                    last_save_time = current_time

                while len(pending) < batch_cap and games_submitted < num_games:
                    p0, p1 = random.randint(0, len(decks) - 1), random.randint(0, len(decks) - 1)
                    f = executor.submit(run_single_game, games_submitted, sims, decks[p0], decks[p1])
                    pending[f] = (p0, p1)
                    games_submitted += 1

                done, _ = concurrent.futures.wait(pending.keys(), return_when=concurrent.futures.FIRST_COMPLETED)
                for f in done:
                    p0, p1 = pending.pop(f)
                    try:
                        res = f.result()
                    except Exception as exc:
                        # One failed game must not stop the run, but it
                        # should not vanish without a trace either.
                        print(f"\n[Warn] Game worker failed: {exc!r}")
                        continue
                    if not res:
                        # Unfinished game (step cap or engine error): no samples.
                        continue

                    all_states.extend(res["states"])
                    all_policies.extend(res["policies"])
                    all_winners.extend(res["winners"])
                    total_completed += 1
                    total_samples += len(res["states"])
                    pbar.update(1)

                    o = res["outcome"]
                    s = stats[(p0, p1)]
                    s["games"] += 1
                    if o["winner"] == 0:
                        s["p0_wins"] += 1
                    s["p0_total"] += o["p0_score"]
                    s["p1_total"] += o["p1_score"]
                    s["turns_total"] += o["turns"]

                    if total_completed % chunk_size == 0:
                        save_current_chunk()
                        print_stats_table()
    except KeyboardInterrupt:
        print("\nStopping...")
    finally:
        # Runs on normal completion, Ctrl-C and unexpected errors alike
        # (e.g. a broken process pool), so buffered samples are never dropped.
        save_current_chunk(is_final=True)
        print_stats_table()
if __name__ == "__main__":
    # Command-line entry point. Each flag's destination name (dashes become
    # underscores) matches a keyword parameter of ``generate_dataset``, so
    # the parsed namespace can be forwarded as-is.
    cli = argparse.ArgumentParser()
    cli_options = (
        ("--num-games", {"type": int, "default": 100}),
        ("--output-file", {"type": str, "default": "ai/data/data_batch_0.npz"}),
        ("--sims", {"type": int, "default": 400}),
        ("--resume", {"action": "store_true"}),
        ("--chunk-size", {"type": int, "default": 1000}),
    )
    for flag, spec in cli_options:
        cli.add_argument(flag, **spec)

    generate_dataset(**vars(cli.parse_args()))