LovecaSim / ai /utils /mcts_battle.py
trioskosmos's picture
Upload ai/utils/mcts_battle.py with huggingface_hub
e95dcfe verified
import argparse
import concurrent.futures
import json
import multiprocessing
import os
import random
import sys
from tqdm import tqdm
# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import engine_rust
from ai.utils.benchmark_decks import parse_deck
def run_single_battle(
    g_idx,
    sims0,
    sims1,
    time0,
    time1,
    h0,
    h1,
    m0,
    m1,
    hr0,
    hr1,
    db_content,
    decks,
    alternate=False,
    verbose=False,
    max_steps=1000,
):
    """Play one MCTS-vs-MCTS game in the Rust engine and return its summary.

    "P0"/"P1" always mean the two agents configured by the ``*0`` / ``*1``
    arguments.  When ``alternate`` is true and ``g_idx`` is odd, the agents
    swap engine seats (P1 occupies engine seat 0); every per-agent value in the
    returned dict (winner, scores, sim statistics) is mapped back to the
    logical agent.

    Args:
        g_idx: Game index; drives seat alternation and the draw-log filename.
        sims0, sims1: MCTS simulation budget per decision, used only when the
            matching time limit is <= 0.
        time0, time1: Per-decision time limit in seconds; > 0 takes precedence
            over the simulation budget.
        h0, h1: Heuristic name forwarded to ``search_mcts``.
        m0, m1: Evaluation mode ("blind"/"normal"); unknown values fall back
            to ``EvalMode.Blind``.
        hr0, hr1: Search horizon ("game"/"turn"); unknown values fall back to
            ``SearchHorizon.GameEnd``.
        db_content: Raw JSON text of the compiled card database.
        decks: Non-empty sequence of 3-tuples unpacked as
            ``(deck, lives, energy)``; each agent draws one independently
            via ``random.choice``.
        alternate: Swap seats on odd ``g_idx``.
        verbose: Log decisions and engine errors via ``tqdm.write`` and keep
            engine output enabled (``game.silent = False``).
        max_steps: Safety cap on engine steps before the game is abandoned.

    Returns:
        dict with ``game_id``, ``winner`` (0 = P0, 1 = P1, 2 = any other
        engine result, i.e. draw / no decision), ``p0_score``, ``p1_score``,
        ``turns``, ``steps``, ``avg_sims_p0`` and ``avg_sims_p1``.
    """
    db = engine_rust.PyCardDatabase(db_content)
    game = engine_rust.PyGameState(db)
    game.silent = not verbose

    # Convert CLI strings to Rust enums.
    modes = {
        "blind": engine_rust.EvalMode.Blind,
        "normal": engine_rust.EvalMode.Normal,
    }
    horizons = {
        "game": engine_rust.SearchHorizon.GameEnd,
        "turn": engine_rust.SearchHorizon.TurnEnd,
    }
    m_enum0 = modes.get(m0, engine_rust.EvalMode.Blind)
    m_enum1 = modes.get(m1, engine_rust.EvalMode.Blind)
    hr_enum0 = horizons.get(hr0, engine_rust.SearchHorizon.GameEnd)
    hr_enum1 = horizons.get(hr1, engine_rust.SearchHorizon.GameEnd)

    # Select random decks
    p0_deck, p0_lives, p0_energy = random.choice(decks)
    p1_deck, p1_lives, p1_energy = random.choice(decks)

    # Alternate who goes first: on swapped games logical P1 sits in seat 0.
    swapped = alternate and g_idx % 2 == 1
    if swapped:
        game.initialize_game(p1_deck, p0_deck, p1_energy, p0_energy, p1_lives, p0_lives)
        sims = [sims1, sims0]
        time_limits = [time1, time0]
        heuristics = [h1, h0]
        m_enums = [m_enum1, m_enum0]
        hr_enums = [hr_enum1, hr_enum0]
    else:
        game.initialize_game(p0_deck, p1_deck, p0_energy, p1_energy, p0_lives, p1_lives)
        sims = [sims0, sims1]
        time_limits = [time0, time1]
        heuristics = [h0, h1]
        m_enums = [m_enum0, m_enum1]
        hr_enums = [hr_enum0, hr_enum1]

    def flip(idx):
        """Map an engine seat (0/1) to its logical agent, or back; self-inverse."""
        return 1 - idx if swapped else idx

    step = 0
    # Indexed by LOGICAL agent (0 = P0, 1 = P1), not by engine seat.
    sims_total = [0, 0]
    moves = [0, 0]

    while not game.is_terminal() and step < max_steps:
        cp = game.current_player
        phase = game.phase
        # Decision phases: -1: MulliganP1, 0: MulliganP2, 2: Energy, 4: Main,
        # 5: LiveSet, 8: LiveResult.  All other phases are auto-advanced.
        if phase in (-1, 0, 2, 4, 5, 8):
            # A positive time limit wins; otherwise use the simulation budget.
            t_limit = time_limits[cp]
            n_sims = sims[cp] if t_limit <= 0.0 else 0
            suggestions = game.search_mcts(n_sims, t_limit, heuristics[cp], hr_enums[cp], m_enums[cp])
            best_action = suggestions[0][0] if suggestions else 0
            try:
                # Track sims (third field of each suggestion) per logical agent.
                total_sims = sum(s[2] for s in suggestions)
                agent = flip(cp)
                sims_total[agent] += total_sims
                moves[agent] += 1
                if verbose:
                    tqdm.write(f" P{cp} {phase}: Action {best_action} ({total_sims} sims)")
                game.step(best_action)
            except Exception as e:
                if verbose:
                    tqdm.write(f" Step Error (Phase {phase}): {e}")
                break
        else:
            try:
                game.step(0)
            except Exception as e:
                if verbose:
                    tqdm.write(f" Auto-Step Error (Phase {phase}): {e}")
                break
        step += 1

    if step >= max_steps and verbose:
        tqdm.write(f" Action Limit Reached ({max_steps} steps) for Game {g_idx}")

    # Report from the logical agents' point of view in both seat orders; any
    # engine winner value other than 0/1 is normalised to 2.
    winner = game.get_winner()
    actual_winner = flip(winner) if winner in (0, 1) else 2
    p0_score = game.get_player(flip(0)).score
    p1_score = game.get_player(flip(1)).score

    # Save rule log for draws
    if actual_winner == 2:
        log_dir = "ai/data/draw_logs"
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, f"draw_game_{g_idx}.txt")
        with open(log_path, "w", encoding="utf-8") as f_log:
            f_log.write("\n".join(game.rule_log))

    return {
        "game_id": g_idx,
        "winner": actual_winner,
        "p0_score": p0_score,
        "p1_score": p1_score,
        "turns": game.turn,
        "steps": step,
        "avg_sims_p0": sims_total[0] / moves[0] if moves[0] > 0 else 0,
        "avg_sims_p1": sims_total[1] / moves[1] if moves[1] > 0 else 0,
    }
def _load_decks(db_json):
    """Parse whichever benchmark deck files exist, with a hard-coded fallback.

    Each entry is the value returned by ``parse_deck`` (unpacked downstream as
    ``(deck, lives, energy)``).  If none of the files exist, a single fallback
    deck of 48 deck cards, 3 lives and 10 energy is returned instead.
    """
    deck_paths = [
        "ai/decks/aqours_cup.txt",
        "ai/decks/hasunosora_cup.txt",
        "ai/decks/liella_cup.txt",
        "ai/decks/muse_cup.txt",
        "ai/decks/nijigaku_cup.txt",
    ]
    decks = [
        parse_deck(dp, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {}))
        for dp in deck_paths
        if os.path.exists(dp)
    ]
    if not decks:
        p_deck = [124, 127, 130, 132] * 12
        p_lives = [30000, 30001, 30002]
        p_energy = [40000] * 10
        decks = [(p_deck, p_lives, p_energy)]
    return decks


def _save_results(results_list, path, args):
    """Write a JSON summary plus per-game records to ``path``; no-op if empty.

    Draws are counted as any ``winner`` value other than 0 or 1.
    """
    if not results_list:
        return
    n_games = len(results_list)
    summary = {
        "num_games": n_games,
        "p0": {"sims": args.sims0, "time": args.time0, "h": args.h0, "m": args.mode0, "hr": args.horizon0},
        "p1": {"sims": args.sims1, "time": args.time1, "h": args.h1, "m": args.mode1, "hr": args.horizon1},
        "p0_wins": sum(1 for r in results_list if r["winner"] == 0),
        "p1_wins": sum(1 for r in results_list if r["winner"] == 1),
        "draws": sum(1 for r in results_list if r["winner"] not in [0, 1]),
        "avg_p0_score": sum(r["p0_score"] for r in results_list) / n_games,
        "avg_p1_score": sum(r["p1_score"] for r in results_list) / n_games,
        "games": results_list,
    }
    out_dir = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the output path actually has one (e.g. not for "results.json").
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f_out:
        json.dump(summary, f_out, indent=2)


def _print_summary(results, p0_wins, p1_wins, draws):
    """Print aggregate win counts and averages for a non-empty result list."""
    n_games = len(results)
    print(f"\nResults Summary: P0 Wins: {p0_wins}, P1 Wins: {p1_wins}, Draws: {draws}")
    print(
        f"Avg Score: P0={sum(r['p0_score'] for r in results) / n_games:.2f}, P1={sum(r['p1_score'] for r in results) / n_games:.2f}"
    )
    print(f"Avg Turns: {sum(r['turns'] for r in results) / n_games:.2f}")
    print(
        f"Avg Sims/Action: P0={sum(r.get('avg_sims_p0', 0) for r in results) / n_games:.0f}, P1={sum(r.get('avg_sims_p1', 0) for r in results) / n_games:.0f}"
    )


def main():
    """Command-line entry point: run a batch of MCTS-vs-MCTS games in parallel.

    Parses arguments, loads the compiled card database and decks, runs
    ``--num-games`` games across ``--workers`` processes (at least one),
    prints each result as it completes, then writes a JSON summary (games
    sorted by ``game_id``) to ``--output-file`` and prints aggregates.
    Ctrl-C cancels games that have not started yet; results collected so far
    are still saved and summarised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-games", type=int, default=10)
    parser.add_argument("--sims0", type=int, default=100)
    parser.add_argument("--sims1", type=int, default=100)
    parser.add_argument("--time0", type=float, default=0.0, help="Time limit per action for P0 (seconds)")
    parser.add_argument("--time1", type=float, default=0.0, help="Time limit per action for P1 (seconds)")
    parser.add_argument("--h0", type=str, default="original", choices=["original", "simple", "resnet", "hybrid"])
    parser.add_argument("--h1", type=str, default="original", choices=["original", "simple", "resnet", "hybrid"])
    parser.add_argument("--mode0", type=str, default="blind", choices=["blind", "normal"])
    parser.add_argument("--mode1", type=str, default="blind", choices=["blind", "normal"])
    parser.add_argument("--horizon0", type=str, default="game", choices=["game", "turn"])
    parser.add_argument("--horizon1", type=str, default="game", choices=["game", "turn"])
    parser.add_argument("--alternate", action="store_true")
    parser.add_argument("--verbose", action="store_true", help="Enable move-by-move logging")
    parser.add_argument("--workers", type=int, default=min(multiprocessing.cpu_count(), 4))
    parser.add_argument("--output-file", type=str, default="ai/data/mcts_battle_results.json")
    args = parser.parse_args()

    print("Loading data and decks...")
    with open("engine/data/cards_compiled.json", "r", encoding="utf-8") as f:
        db_content = f.read()
    db_json = json.loads(db_content)
    decks = _load_decks(db_json)

    print(
        f"Starting Battle: P0({args.h0}, {args.mode0}, {args.horizon0}) vs P1({args.h1}, {args.mode1}, {args.horizon1})"
    )

    results = []
    p0_wins, p1_wins, draws = 0, 0, 0

    try:
        with concurrent.futures.ProcessPoolExecutor(max_workers=max(1, args.workers)) as executor:
            futures = []
            try:
                for i in range(args.num_games):
                    futures.append(
                        executor.submit(
                            run_single_battle,
                            i,
                            args.sims0,
                            args.sims1,
                            args.time0,
                            args.time1,
                            args.h0,
                            args.h1,
                            args.mode0,
                            args.mode1,
                            args.horizon0,
                            args.horizon1,
                            db_content,
                            decks,
                            args.alternate,
                            args.verbose,
                        )
                    )
                for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Playing"):
                    try:
                        res = future.result()
                        results.append(res)
                        winner_str = "P0 Wins" if res["winner"] == 0 else ("P1 Wins" if res["winner"] == 1 else "Draw")
                        # Print result immediately
                        tqdm.write(
                            f"Game {res['game_id']:2}: {winner_str:8} | P0: {res['p0_score']:2} - P1: {res['p1_score']:2} | Turns: {res['turns']:2}"
                        )
                        if res["winner"] == 0:
                            p0_wins += 1
                        elif res["winner"] == 1:
                            p1_wins += 1
                        else:
                            draws += 1
                    except Exception as e:
                        tqdm.write(f"Game failed: {e}")
            except KeyboardInterrupt:
                # Leaving the `with` block calls shutdown(wait=True), which
                # would otherwise run every still-queued game to completion.
                # Cancel what has not started; running games still finish.
                for pending in futures:
                    pending.cancel()
                raise
    except KeyboardInterrupt:
        print("\nInterrupted!")

    if results:
        # as_completed yields in completion order; sort for stable output.
        results.sort(key=lambda r: r["game_id"])
        _save_results(results, args.output_file, args)
        _print_summary(results, p0_wins, p1_wins, draws)
# Script entry point. The guard keeps main() from running when the module is
# imported rather than executed, e.g. by ProcessPoolExecutor workers under the
# "spawn" start method, which re-import the main module.
if __name__ == "__main__":
    main()