"""Generate AlphaZero-style self-play training data via MCTS in a worker pool.

Each worker process holds its own card database (built once in `worker_init`)
and plays full games with `run_self_play_game`; the parent process fans games
out with a ProcessPoolExecutor, aggregates (state, policy, value-target)
samples, and periodically flushes them to chunked ``.npz`` files.
"""

import argparse
import concurrent.futures
import json
import multiprocessing
import os
import random
import sys
import time

import numpy as np
from tqdm import tqdm

# Pin threads for performance
os.environ["RAYON_NUM_THREADS"] = "1"

# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import engine_rust
from ai.utils.benchmark_decks import parse_deck

# Global cache for workers (optional, for NN mode).
# Both globals are populated per-process by `worker_init`; they are initialized
# to None here so the `is None` guards below cannot raise NameError if the
# initializer was never run (previously only _WORKER_MODEL_PATH was defined).
_WORKER_DB = None
_WORKER_MODEL_PATH = None


def worker_init(db_content, model_path=None):
    """Process-pool initializer: build the card DB once per worker process.

    Args:
        db_content: Raw JSON text of the compiled card database.
        model_path: Optional path to an ONNX value/policy model; None means
            pure heuristic MCTS ("teacher" mode).
    """
    global _WORKER_DB, _WORKER_MODEL_PATH
    _WORKER_DB = engine_rust.PyCardDatabase(db_content)
    _WORKER_MODEL_PATH = model_path


def run_self_play_game(g_idx, sims, p0_deck_info, p1_deck_info):
    """Play one full self-play game and return its training samples.

    Args:
        g_idx: Game index (currently unused inside the game; kept for logging
            symmetry with the submitter).
        sims: Number of MCTS simulations per interactive decision.
        p0_deck_info: Tuple ``(deck, lives, energy)`` for player 0.
        p1_deck_info: Tuple ``(deck, lives, energy)`` for player 1.

    Returns:
        Dict with ``states``, ``policies``, ``winners``, ``scores``,
        ``turns_left`` numpy arrays plus an ``outcome`` summary dict, or
        ``None`` if the worker is uninitialized or the game did not finish.
    """
    if _WORKER_DB is None:
        return None
    game = engine_rust.PyGameState(_WORKER_DB)
    game.silent = True
    p0_deck, p0_lives, p0_energy = p0_deck_info
    p1_deck, p1_lives, p1_energy = p1_deck_info
    game.initialize_game(p0_deck, p1_deck, p0_energy, p1_energy, p0_lives, p1_lives)

    game_states = []
    game_policies = []
    game_turns_remaining = []
    game_player_turn = []
    # Target values (winners/scores/turns) are backfilled after the game ends.

    step = 0
    max_turns = 150  # Estimated max turns for normalization

    while not game.is_terminal() and step < 1000:
        cp = game.current_player
        phase = game.phase
        # Interactive Phases: Mulligan (-1, 0), Main (4), LiveSet (5)
        is_interactive = phase in [-1, 0, 4, 5]

        if is_interactive:
            # Observation (now 1200)
            encoded = game.get_observation()
            if len(encoded) != 1200:
                # Pad/truncate to 1200 if engine mismatch
                if len(encoded) < 1200:
                    encoded = encoded + [0.0] * (1200 - len(encoded))
                else:
                    encoded = encoded[:1200]

            # Use MCTS with Original Heuristic (Teacher Mode).
            # If _WORKER_MODEL_PATH is None, we use pure MCTS.
            h_type = "original" if _WORKER_MODEL_PATH is None else "hybrid"
            suggestions = game.search_mcts(
                num_sims=sims, seconds=0.0, heuristic_type=h_type, model_path=_WORKER_MODEL_PATH
            )

            # Build policy vector from visit counts (action ids >= 2000 are
            # outside the policy head and are dropped).
            policy = np.zeros(2000, dtype=np.float32)
            action_ids = []
            visit_counts = []
            total_visits = 0
            for action, _, visits in suggestions:
                if action < 2000:
                    action_ids.append(int(action))
                    visit_counts.append(visits)
                    total_visits += visits

            if total_visits == 0:
                # MCTS returned nothing usable: fall back to a uniform policy
                # over the legal actions.
                legal = list(game.get_legal_action_ids())
                action_ids = [int(a) for a in legal if a < 2000]
                visit_counts = [1.0] * len(action_ids)
                total_visits = len(action_ids)

            if total_visits == 0:
                # No representable action at all; previously this divided by
                # zero (NaN probs) and crashed np.random.choice. Abort cleanly.
                break

            probs = np.array(visit_counts, dtype=np.float32) / total_visits

            # Add Noise (Dirichlet) for exploration
            if len(probs) > 1:
                noise = np.random.dirichlet([0.3] * len(probs))
                probs = 0.75 * probs + 0.25 * noise
            # CRITICAL: Re-normalize for np.random.choice float precision
            probs = probs / np.sum(probs)

            for i, aid in enumerate(action_ids):
                policy[aid] = probs[i]

            game_states.append(encoded)
            game_policies.append(policy)
            game_player_turn.append(cp)
            game_turns_remaining.append(float(game.turn))  # Store current turn, normalize later

            # Action Selection
            if step < 40:
                # Explore in early game
                action = np.random.choice(action_ids, p=probs)
            else:
                # Exploit
                action = action_ids[np.argmax(probs)]

            try:
                game.step(int(action))
            except Exception:
                # Engine rejected the action: stop the game (best-effort).
                break
        else:
            # Auto-step through non-interactive phases
            try:
                game.step(0)
            except Exception:
                break
        step += 1

    if not game.is_terminal():
        return None

    winner = game.get_winner()
    s0 = float(game.get_player(0).score)
    s1 = float(game.get_player(1).score)
    final_turn = float(game.turn)

    # Backfill per-sample targets now that the outcome is known.
    winners = []
    scores = []
    turns_normalized = []
    for i in range(len(game_player_turn)):
        p_idx = game_player_turn[i]
        # Win Signal (1, 0, -1); winner == 2 is a draw.
        if winner == 2:
            winners.append(0.0)
        elif p_idx == winner:
            winners.append(1.0)
        else:
            winners.append(-1.0)
        # Score Diff (Normalized) from the acting player's perspective
        diff = (s0 - s1) if p_idx == 0 else (s1 - s0)
        score_norm = np.tanh(diff / 50.0)  # Scale roughly to [-1, 1]
        scores.append(score_norm)
        # Turns Remaining (Normalized 0..1): 1.0 at start, 0.0 at end
        rem = (final_turn - game_turns_remaining[i]) / max_turns
        turns_normalized.append(np.clip(rem, 0.0, 1.0))

    return {
        "states": np.array(game_states, dtype=np.float32),
        "policies": np.array(game_policies, dtype=np.float32),
        "winners": np.array(winners, dtype=np.float32),
        "scores": np.array(scores, dtype=np.float32),
        "turns_left": np.array(turns_normalized, dtype=np.float32),
        "outcome": {"winner": winner, "score": (s0, s1), "turns": game.turn},
    }


def generate_self_play(
    num_games=100,
    model_path="ai/models/alphanet.onnx",
    output_file="ai/data/self_play_0.npz",
    sims=100,
    weight=0.3,
    skip_rollout=False,
    workers=0,
):
    """Run `num_games` self-play games in parallel and save samples as .npz chunks.

    Args:
        num_games: Number of completed games to collect.
        model_path: ONNX model path; the string "None" (or None) selects pure
            heuristic MCTS ("teacher" mode).
        output_file: Base .npz path; chunk index and timestamp are appended.
        sims: MCTS simulations per decision.
        weight: Reserved for the hybrid heuristic (currently unused here).
        skip_rollout: Reserved fast-mode flag (currently unused here).
        workers: Worker process count; 0 selects min(cpu_count, 12).
    """
    db_path = "engine/data/cards_compiled.json"
    with open(db_path, "r", encoding="utf-8") as f:
        db_content = f.read()
    db_json = json.loads(db_content)

    # Load Decks (Standard Pool)
    deck_paths = [
        "ai/decks/aqours_cup.txt",
        "ai/decks/hasunosora_cup.txt",
        "ai/decks/liella_cup.txt",
        "ai/decks/muse_cup.txt",
        "ai/decks/nijigaku_cup.txt",
    ]
    decks = []
    for dp in deck_paths:
        if os.path.exists(dp):
            decks.append(parse_deck(dp, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {})))

    if not decks:
        # Previously random.randint(0, -1) would raise an opaque ValueError
        # on the first submission; fail fast with a clear message instead.
        raise RuntimeError("No deck files found in ai/decks/; cannot run self-play.")

    all_states, all_policies, all_winners = [], [], []
    all_scores, all_turns = [], []
    total_completed = 0
    total_samples = 0
    chunk_size = 100  # Save every 100 games
    stats = {"wins": 0, "losses": 0, "draws": 0}

    if model_path == "None":
        model_path = None

    max_workers = workers if workers > 0 else min(multiprocessing.cpu_count(), 12)
    mode_str = "Teacher (Heuristic MCTS)" if model_path is None else "Student (Hybrid MCTS)"
    print(f"Starting Self-Play: {num_games} games using {max_workers} workers... Mode: {mode_str}")

    def save_chunk():
        # Flush accumulated samples to a timestamped chunk file and reset
        # the in-memory buffers (rebinding requires the nonlocal declaration).
        nonlocal all_states, all_policies, all_winners, all_scores, all_turns
        if not all_states:
            return
        ts = int(time.time())
        path = output_file.replace(".npz", f"_chunk_{total_completed // chunk_size}_{ts}.npz")
        print(f"\n[Disk] Saving {len(all_states)} samples to {path}...")
        np.savez(
            path,
            states=np.array(all_states, dtype=np.float32),
            policies=np.array(all_policies, dtype=np.float32),
            winners=np.array(all_winners, dtype=np.float32),
            scores=np.array(all_scores, dtype=np.float32),
            turns_left=np.array(all_turns, dtype=np.float32),
        )
        all_states, all_policies, all_winners = [], [], []
        all_scores, all_turns = [], []

    with concurrent.futures.ProcessPoolExecutor(
        max_workers=max_workers, initializer=worker_init, initargs=(db_content, model_path)
    ) as executor:
        pending = {}
        batch_cap = max_workers * 2  # Keep the pool saturated without over-queueing
        games_submitted = 0
        pbar = tqdm(total=num_games)

        while total_completed < num_games or pending:
            # Top up the in-flight batch with random deck matchups.
            while len(pending) < batch_cap and games_submitted < num_games:
                p0, p1 = random.randrange(len(decks)), random.randrange(len(decks))
                f = executor.submit(run_self_play_game, games_submitted, sims, decks[p0], decks[p1])
                pending[f] = games_submitted
                games_submitted += 1

            if not pending:
                # All games submitted and resolved (failed games never
                # increment total_completed, so we may exit short).
                break

            done, _ = concurrent.futures.wait(
                pending.keys(), return_when=concurrent.futures.FIRST_COMPLETED
            )
            for f in done:
                pending.pop(f)
                try:
                    res = f.result()
                    if res:
                        all_states.extend(res["states"])
                        all_policies.extend(res["policies"])
                        all_winners.extend(res["winners"])
                        all_scores.extend(res["scores"])
                        all_turns.extend(res["turns_left"])
                        total_completed += 1
                        total_samples += len(res["states"])

                        # Update stats
                        outcome = res["outcome"]
                        w_idx = outcome["winner"]
                        turns = outcome["turns"]
                        win_str = "DRAW" if w_idx == 2 else f"P{w_idx} WIN"
                        if w_idx == 2:
                            stats["draws"] += 1
                        elif w_idx == 0:
                            stats["wins"] += 1
                        else:
                            stats["losses"] += 1

                        # Reduce log spam for large runs
                        if total_completed % 10 == 0 or total_completed < 10:
                            print(
                                f" [Game {total_completed}] {win_str} in {turns} turns | Samples: {len(res['states'])} | Total W/L/D: {stats['wins']}/{stats['losses']}/{stats['draws']}"
                            )

                        pbar.update(1)
                        if total_completed % chunk_size == 0:
                            save_chunk()
                except Exception as e:
                    print(f"Game failed: {e}")

        pbar.close()

    if all_states:
        save_chunk()
    print(f"Self-play generation complete. Total samples: {total_samples}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--games", type=int, default=100)
    parser.add_argument("--sims", type=int, default=100)
    parser.add_argument("--model", type=str, default="ai/models/alphanet_best.onnx")
    parser.add_argument("--weight", type=float, default=0.3)
    parser.add_argument("--workers", type=int, default=0, help="Number of workers (0 = auto)")
    parser.add_argument("--fast", action="store_true", help="Skip rollouts, use pure NN value (faster)")
    args = parser.parse_args()

    generate_self_play(
        num_games=args.games,
        model_path=args.model,
        sims=args.sims,
        weight=args.weight,
        skip_rollout=args.fast,
        workers=args.workers,
    )