# LovecaSim / ai/data_generation/generate_data.py
# Uploaded by trioskosmos with huggingface_hub (commit 996ab02, verified).
import os
import sys
# Critical Performance Tuning:
# Each Python process handles 1 game. If we don't pin Rayon threads to 1,
# every process will try to use ALL CPU cores for its MCTS simulations,
# causing massive thread contention and slowing down generation by 5-10x.
os.environ["RAYON_NUM_THREADS"] = "1"
import argparse
import concurrent.futures
import glob
import json
import multiprocessing
import random
import time
import numpy as np
from tqdm import tqdm
# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import engine_rust
from ai.models.training_config import POLICY_SIZE
from ai.utils.benchmark_decks import parse_deck
# Per-process database handles. ``worker_init`` fills them once in every pool
# worker (it is the executor's initializer), so games reuse them.
_WORKER_DB = _WORKER_DB_JSON = None


def worker_init(db_content):
    """Pool initializer: build this worker's card databases from JSON text.

    ``db_content`` is the raw compiled-card JSON string. It is loaded into the
    native engine database (read by ``run_single_game``) and is also kept as a
    parsed Python object in ``_WORKER_DB_JSON``.
    """
    global _WORKER_DB, _WORKER_DB_JSON
    native_db = engine_rust.PyCardDatabase(db_content)
    _WORKER_DB = native_db
    _WORKER_DB_JSON = json.loads(db_content)
# Engine phase codes at which a real decision is made and recorded as a
# training sample. Every other phase is auto-advanced with action 0.
# NOTE(review): the meaning of each code is defined in engine_rust; confirm it there.
_INTERACTIVE_PHASES = (-1, 0, 4, 5)


def run_single_game(g_idx, sims, p0_deck_info, p1_deck_info, max_steps=1500):
    """Play one MCTS self-play game and return its training samples.

    Runs inside a pool worker prepared by ``worker_init``. At each interactive
    phase, MCTS is queried for the current player. The normalised visit counts
    become the policy target, and the most-visited action is played.

    Args:
        g_idx: Global game index, used only for periodic progress printing.
        sims: MCTS simulation budget passed to ``get_mcts_suggestions``.
        p0_deck_info: ``(deck, lives, energy)`` for player 0.
        p1_deck_info: ``(deck, lives, energy)`` for player 1.
        max_steps: Safety cap on engine steps. A game that is not terminal
            after this many steps is discarded.

    Returns:
        ``None`` if the worker database is not initialised, or if the game is
        not terminal when play stops (step cap reached, or the engine raised
        while stepping). Otherwise a dict with parallel lists ``states``,
        ``policies`` and ``winners`` (one entry per recorded decision). Each
        winner is +1.0/-1.0/0.0 from the perspective of the player who made
        that decision. The dict also holds an ``outcome`` summary
        (winner, scores, turns).
    """
    if _WORKER_DB is None:
        return None

    game = engine_rust.PyGameState(_WORKER_DB)
    game.silent = True
    p0_deck, p0_lives, p0_energy = p0_deck_info
    p1_deck, p1_lives, p1_energy = p1_deck_info
    game.initialize_game(p0_deck, p1_deck, p0_energy, p1_energy, p0_lives, p1_lives)

    game_states = []
    game_policies = []
    game_player_turn = []

    for _ in range(max_steps):
        if game.is_terminal():
            break
        cp = game.current_player
        action = 0  # non-interactive phases are advanced with the default action
        if game.phase in _INTERACTIVE_PHASES:
            encoded = game.encode_state(_WORKER_DB)
            suggestions = game.get_mcts_suggestions(sims, engine_rust.SearchHorizon.TurnEnd)
            # Policy target = visit distribution over in-range actions;
            # the move actually played is the most-visited suggestion.
            policy = np.zeros(POLICY_SIZE, dtype=np.float32)
            total_visits = 0
            most_visits = -1
            for suggested, _score, visits in suggestions:
                if suggested < POLICY_SIZE:
                    policy[int(suggested)] = visits
                    total_visits += visits
                if visits > most_visits:
                    most_visits = visits
                    action = int(suggested)
            if total_visits > 0:
                policy /= total_visits
            game_states.append(encoded)
            game_policies.append(policy)
            game_player_turn.append(cp)
        try:
            game.step(action)
        except (KeyboardInterrupt, SystemExit):
            # Never swallow Ctrl+C / interpreter shutdown (the old bare except did).
            raise
        except BaseException:
            # Deliberate best effort: stop playing this game on an engine error.
            # The catch is BaseException because a Rust panic surfaces via pyo3
            # as PanicException, which does not derive from Exception. The game
            # is discarded below unless the engine already reports it terminal.
            break

    if not game.is_terminal():
        return None

    winner = game.get_winner()
    s0 = game.get_player(0).score
    s1 = game.get_player(1).score
    # Value targets from the mover's perspective; winner code 2 means a draw.
    game_winners = [
        0.0 if winner == 2 else (1.0 if player == winner else -1.0)
        for player in game_player_turn
    ]

    # Game end summary for logging / deck statistics.
    outcome = {"winner": winner, "p0_score": s0, "p1_score": s1, "turns": game.turn}

    # tqdm shows overall progress; an occasional per-game line is still helpful.
    if g_idx % 100 == 0:
        win_str = "P0" if winner == 0 else "P1" if winner == 1 else "Tie"
        print(
            f" [Game {g_idx}] Winner: {win_str} | Final Score: {s0}-{s1} | Turns: {game.turn} | States: {len(game_states)}"
        )

    return {"states": game_states, "policies": game_policies, "winners": game_winners, "outcome": outcome}
def _chunk_base(output_file):
    """Return the path prefix used for chunk files derived from *output_file*.

    Only a trailing ``.npz`` extension is stripped, so chunks are written as
    ``<prefix>_chunk_<idx>_<unix_time>.npz``. Unlike ``str.replace``, this
    never touches ``.npz`` elsewhere in the path. It also still yields distinct
    ``.npz`` chunk names when *output_file* has another extension or none.
    """
    root, ext = os.path.splitext(output_file)
    return root if ext == ".npz" else output_file


def _count_saved_games(chunk_paths, chunk_size):
    """Return the number of games stored in the given chunk files.

    Chunks written by ``generate_dataset`` record their game count under the
    ``games`` key. Older chunks without that key are assumed to hold
    ``chunk_size`` games (the previous resume heuristic). Unreadable chunks
    are skipped with a warning, since their samples cannot be used anyway.
    """
    total = 0
    for path in chunk_paths:
        try:
            with np.load(path) as data:
                total += int(data["games"]) if "games" in data.files else chunk_size
        except Exception as e:  # truncated/corrupt archives raise assorted types
            print(f"Warning: skipping unreadable chunk {path}: {e}")
    return total


def _format_stats_table(deck_names, stats, completed, num_games, samples):
    """Render the deck-vs-deck statistics matrix as a printable string.

    Row ``i`` / column ``j`` summarises games where ``deck_names[i]`` was P0
    and ``deck_names[j]`` was P1, as ``<P0 win %>/<avg P0 score>/T<avg turns>``.
    Every cell is padded to the header's 14-character width so columns align.
    """
    rule = "=" * 95
    # Kept out of the f-string: a backslash inside an f-string expression is a
    # SyntaxError before Python 3.12.
    corner = "P0 \\ P1"
    header = f"{corner:<12} | " + " | ".join(f"{name[:10]:^14}" for name in deck_names)
    lines = [
        "\n" + rule,
        f" DECK VS DECK STATISTICS (Progress: {completed}/{num_games} | Samples: {samples})",
        rule,
        header,
        "-" * len(header),
    ]
    for i, row_name in enumerate(deck_names):
        cells = []
        for j in range(len(deck_names)):
            s = stats[(i, j)]
            games = s["games"]
            if games > 0:
                wr = s["p0_wins"] / games * 100
                avg0 = s["p0_total"] / games
                avg_t = s["turns_total"] / games
                cell = f"{wr:>3.0f}%/{avg0:^3.1f}/T{avg_t:<2.1f}"
            else:
                cell = "-"
            cells.append(f"{cell:^14}")
        lines.append(f"{row_name:<12} | " + " | ".join(cells))
    lines.append(rule + "\n")
    return "\n".join(lines)


def generate_dataset(
    num_games=100,
    output_file="ai/data/data_batch_0.npz",
    sims=200,
    resume=False,
    chunk_size=5000,
    db_path="data/cards_compiled.json",
    autosave_interval=1800,
):
    """Generate MCTS self-play training data and save it as ``.npz`` chunks.

    Games between randomly paired curriculum decks run in a process pool.
    Samples are buffered in memory and flushed to
    ``<output_file without .npz>_chunk_<idx>_<unix_time>.npz`` at these points:

    - every ``chunk_size`` completed games;
    - every ``autosave_interval`` seconds (checked as games complete);
    - once more when generation stops, whether normally, on Ctrl+C, or on an
      unexpected error.

    Each chunk holds ``states``, ``policies`` and ``winners`` arrays, plus a
    scalar ``games`` count used by ``resume``.

    Args:
        num_games: Total games to submit, including those already counted from
            existing chunks when resuming. Games that crash or hit the step cap
            are dropped, so fewer may end up saved.
        output_file: Base path; its ``.npz`` extension is replaced by the
            chunk suffix.
        sims: MCTS simulations per decision, forwarded to ``run_single_game``.
        resume: Count the games already stored in existing chunks and only
            generate the remainder.
        chunk_size: Completed games per regular chunk; must be >= 1.
        db_path: Path of the compiled card database JSON.
        autosave_interval: Seconds between time-based autosaves.

    Returns:
        None. If ``db_path`` does not exist, prints an error and returns early.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if not os.path.exists(db_path):
        print(f"Error: Database not found at {db_path}")
        return
    with open(db_path, "r", encoding="utf-8") as f:
        db_content = f.read()
    db_json = json.loads(db_content)

    deck_config = [
        ("Aqours", "ai/decks/aqours_cup.txt"),
        ("Hasunosora", "ai/decks/hasunosora_cup.txt"),
        ("Liella", "ai/decks/liella_cup.txt"),
        ("Muse", "ai/decks/muse_cup.txt"),
        ("Nijigasaki", "ai/decks/nijigaku_cup.txt"),
    ]
    decks = []
    deck_names = []
    print("Loading curriculum decks...")
    for name, dp in deck_config:
        if os.path.exists(dp):
            decks.append(parse_deck(dp, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {})))
            deck_names.append(name)
    if not decks:
        # No curriculum deck files found: fall back to a single hard-coded deck.
        p_deck = [124, 127, 130, 132] * 12
        p_lives = [1024, 1025, 1027]
        p_energy = [20000] * 10
        decks = [(p_deck, p_lives, p_energy)]
        deck_names = ["Starter-SD1"]

    chunk_base = _chunk_base(output_file)
    out_dir = os.path.dirname(chunk_base)
    if out_dir:
        # np.savez does not create parent directories; without this every save fails.
        os.makedirs(out_dir, exist_ok=True)

    total_completed = 0
    total_samples = 0
    buffer_games = 0  # completed games whose samples are still only in memory
    n_decks = len(decks)
    stats = {
        (i, j): {"games": 0, "p0_wins": 0, "p0_total": 0, "p1_total": 0, "turns_total": 0}
        for i in range(n_decks)
        for j in range(n_decks)
    }
    all_states, all_policies, all_winners = [], [], []

    def print_stats_table():
        print(_format_stats_table(deck_names, stats, total_completed, num_games, total_samples))

    def save_current_chunk(is_final=False):
        """Write the in-memory buffers to a new, verified chunk file."""
        nonlocal all_states, all_policies, all_winners, buffer_games
        if not all_states:
            return
        # Unique timestamped + indexed chunks so earlier saves are never overwritten.
        chunk_idx = total_completed // chunk_size
        stamp = int(time.time())
        path = f"{chunk_base}_chunk_{chunk_idx}_{stamp}.npz"
        dup = 1
        while os.path.exists(path):
            path = f"{chunk_base}_chunk_{chunk_idx}_{stamp}_{dup}.npz"
            dup += 1
        print(f"\n[Disk] Attempting to save {len(all_states)} samples to {path}...")
        try:
            # Step 1: Save UNCOMPRESSED (fast, less likely to fail mid-write).
            np.savez(
                path,
                states=np.array(all_states, dtype=np.float32),
                policies=np.array(all_policies, dtype=np.float32),
                winners=np.array(all_winners, dtype=np.float32),
                games=np.array(buffer_games, dtype=np.int64),
            )
            # Step 2: VERIFY immediately by reading the file back.
            with np.load(path) as data:
                if "states" in data.files and len(data["states"]) == len(all_states):
                    print(f" -> VERIFIED: {path} is healthy.")
                else:
                    raise IOError("Verification failed: File is truncated or keys missing.")
            # Reset buffers only after successful verification.
            if not is_final:
                all_states, all_policies, all_winners = [], [], []
                buffer_games = 0
        except Exception as e:
            print(f" !!! CRITICAL SAVE ERROR: {e}")
            # Drop the broken file so a later resume does not count its games.
            if os.path.exists(path):
                try:
                    os.remove(path)
                except OSError as rm_err:
                    print(f" !!! Could not remove broken file {path}: {rm_err}")
            print(" !!! Data is still in memory, will retry next chunk.")

    if resume:
        existing = sorted(glob.glob(glob.escape(chunk_base) + "_chunk_*.npz"))
        if existing:
            total_completed = _count_saved_games(existing, chunk_size)
            print(f"Resuming from game {total_completed} ({len(existing)} chunks found)")

    max_workers = min(multiprocessing.cpu_count(), 16)
    print(f"Starting generation using {max_workers} workers...")
    try:
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=max_workers, initializer=worker_init, initargs=(db_content,)
        ) as executor:
            pending = {}
            batch_cap = max_workers * 2  # small queue so workers never sit idle
            games_submitted = total_completed
            pbar = tqdm(total=num_games, initial=total_completed)
            last_save_time = time.time()
            try:
                while games_submitted < num_games or pending:
                    current_time = time.time()
                    if current_time - last_save_time > autosave_interval:
                        print(f"\n[Timer] {autosave_interval / 60:g} minutes passed. Autosaving...")
                        save_current_chunk()
                        last_save_time = current_time

                    while len(pending) < batch_cap and games_submitted < num_games:
                        p0 = random.randint(0, n_decks - 1)
                        p1 = random.randint(0, n_decks - 1)
                        fut = executor.submit(run_single_game, games_submitted, sims, decks[p0], decks[p1])
                        pending[fut] = (p0, p1)
                        games_submitted += 1

                    done, _ = concurrent.futures.wait(pending, return_when=concurrent.futures.FIRST_COMPLETED)
                    for fut in done:
                        p0, p1 = pending.pop(fut)
                        try:
                            res = fut.result()
                        except (KeyboardInterrupt, SystemExit):
                            raise
                        except BaseException as e:
                            # Skip the crashed game, but report it instead of hiding it.
                            # BaseException also covers a pyo3 PanicException from the engine.
                            print(f"\n[Worker] Game failed ({deck_names[p0]} vs {deck_names[p1]}): {e!r}")
                            continue
                        if not res:
                            continue  # worker DB missing or game did not finish

                        all_states.extend(res["states"])
                        all_policies.extend(res["policies"])
                        all_winners.extend(res["winners"])
                        total_completed += 1
                        total_samples += len(res["states"])
                        buffer_games += 1
                        pbar.update(1)

                        o = res["outcome"]
                        s = stats[(p0, p1)]
                        s["games"] += 1
                        if o["winner"] == 0:
                            s["p0_wins"] += 1
                        s["p0_total"] += o["p0_score"]
                        s["p1_total"] += o["p1_score"]
                        s["turns_total"] += o["turns"]

                        if total_completed % chunk_size == 0:
                            save_current_chunk()
                            print_stats_table()
            except KeyboardInterrupt:
                # Drop queued games so the executor's shutdown only waits for running ones.
                for fut in pending:
                    fut.cancel()
                raise
            finally:
                pbar.close()
    except KeyboardInterrupt:
        print("\nStopping...")
    finally:
        # Always flush what is buffered, even if generation died unexpectedly.
        save_current_chunk(is_final=True)
        print_stats_table()
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    # (flag, add_argument options), registered in order. argparse derives
    # dest names from the flags (num_games, output_file, sims, resume,
    # chunk_size), which match generate_dataset's keyword parameters exactly.
    # NOTE(review): the CLI defaults for --sims/--chunk-size (400/1000) differ
    # from generate_dataset's own defaults (200/5000); confirm this is intended.
    for flag, options in (
        ("--num-games", {"type": int, "default": 100}),
        ("--output-file", {"type": str, "default": "ai/data/data_batch_0.npz"}),
        ("--sims", {"type": int, "default": 400}),
        ("--resume", {"action": "store_true"}),
        ("--chunk-size", {"type": int, "default": 1000}),
    ):
        cli.add_argument(flag, **options)
    generate_dataset(**vars(cli.parse_args()))