import argparse import json import os import random import sys import torch from tqdm import tqdm # Add project root to path sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) import engine_rust from ai.agents.neural_mcts import HybridMCTSAgent from ai.models.training_config import POLICY_SIZE from ai.training.train import AlphaNet from ai.utils.benchmark_decks import parse_deck class Agent: def get_action(self, game, db): pass class RandomAgent(Agent): def get_action(self, game, db): actions = game.get_legal_action_ids() if not actions: return 0 return random.choice(actions) class MCTSAgent(Agent): def __init__(self, sims=100): self.sims = sims def get_action(self, game, db): suggestions = game.get_mcts_suggestions(self.sims, engine_rust.SearchHorizon.TurnEnd) if not suggestions: return 0 return suggestions[0][0] class ResNetAgent(Agent): def __init__(self, model_path): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") checkpoint = torch.load(model_path, map_location=self.device) # Handle new dictionary checkpoint format if isinstance(checkpoint, dict) and "model_state" in checkpoint: state_dict = checkpoint["model_state"] else: state_dict = checkpoint # Detect policy size from weights p_fc_bias = state_dict.get("policy_head_fc.bias") detected_policy_size = p_fc_bias.shape[0] if p_fc_bias is not None else POLICY_SIZE print(f"ResNetAgent: Detected Policy Size {detected_policy_size}") self.model = AlphaNet(policy_size=detected_policy_size).to(self.device) self.model.load_state_dict(state_dict) self.model.eval() self.policy_size = detected_policy_size def get_action(self, game, db): # 1. Encode state encoded = game.encode_state(db) state_tensor = torch.FloatTensor(encoded).unsqueeze(0).to(self.device) # 2. Get policy logits with torch.no_grad(): logits, _ = self.model(state_tensor) # 3. Mask illegal actions legal_ids = game.get_legal_action_ids() mask = torch.full((self.policy_size,), -1e9).to(self.device) for aid in legal_ids: if aid < self.policy_size: mask[int(aid)] = 0.0 masked_logits = logits.squeeze(0) + mask # 4. Argmax return int(torch.argmax(masked_logits).item()) def play_match(agent0, agent1, db_content, decks, game_id): db = engine_rust.PyCardDatabase(db_content) game = engine_rust.PyGameState(db) # Select random decks p0_deck, p0_lives, p0_energy = random.choice(decks) p1_deck, p1_lives, p1_energy = random.choice(decks) game.initialize_game(p0_deck, p1_deck, p0_energy, p1_energy, p0_lives, p1_lives) agents = [agent0, agent1] step = 0 while not game.is_terminal() and step < 1000: cp = game.current_player phase = game.phase is_interactive = phase in [-1, 0, 4, 5] if is_interactive: action = agents[cp].get_action(game, game.db) try: game.step(action) except Exception: # print(f"Action {action} failed: {e}") # Fallback to random if model fails legal = game.get_legal_action_ids() if legal: game.step(int(legal[0])) else: break else: game.step(0) step += 1 return game.get_winner(), game.get_player(0).score, game.get_player(1).score, game.turn def run_tournament(num_games=10): with open("engine/data/cards_compiled.json", "r", encoding="utf-8") as f: db_content = f.read() db_json = json.loads(db_content) # Load Decks deck_paths = [ "ai/decks/aqours_cup.txt", "ai/decks/hasunosora_cup.txt", "ai/decks/liella_cup.txt", "ai/decks/muse_cup.txt", "ai/decks/nijigaku_cup.txt", ] decks = [] for dp in deck_paths: if os.path.exists(dp): decks.append(parse_deck(dp, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {}))) # Agents # Agents random_agent = RandomAgent() mcts_agent = MCTSAgent(sims=100) # resnet_agent = ResNetAgent("ai/models/alphanet_best.pt") competitors = { "Random": random_agent, "MCTS-100": mcts_agent, # "ResNet-Standalone": resnet_agent, # "Neural-Hybrid (Py)": NeuralHeuristicAgent("ai/models/alphanet_best.pt", sims=100), # "Neural-Rust (Full)": NeuralMCTSFullAgent("ai/models/alphanet.onnx", sims=100), "Neural-Rust (Hybrid)": HybridMCTSAgent("ai/models/alphanet_best.onnx", sims=100, neural_weight=0.3), } results = {name: {"wins": 0, "draws": 0, "losses": 0, "total_score": 0, "turns": []} for name in competitors} matchups = [("Neural-Rust (Hybrid)", "MCTS-100"), ("Neural-Rust (Hybrid)", "Random")] print(f"Starting Tournament: {num_games} rounds per matchup...") for p0_name, p1_name in matchups: print(f"Matchup: {p0_name} vs {p1_name}") for i in tqdm(range(num_games)): # Swap sides every game if i % 2 == 0: winner, s0, s1, t = play_match(competitors[p0_name], competitors[p1_name], db_content, decks, i) results[p0_name]["total_score"] += s0 results[p1_name]["total_score"] += s1 results[p0_name]["turns"].append(t) results[p1_name]["turns"].append(t) if winner == 0: results[p0_name]["wins"] += 1 results[p1_name]["losses"] += 1 elif winner == 1: results[p1_name]["wins"] += 1 results[p0_name]["losses"] += 1 else: results[p0_name]["draws"] += 1 results[p1_name]["draws"] += 1 else: winner, s1, s0, t = play_match(competitors[p1_name], competitors[p0_name], db_content, decks, i) results[p0_name]["total_score"] += s0 results[p1_name]["total_score"] += s1 results[p0_name]["turns"].append(t) results[p1_name]["turns"].append(t) if winner == 0: results[p1_name]["wins"] += 1 results[p0_name]["losses"] += 1 elif winner == 1: results[p0_name]["wins"] += 1 results[p1_name]["losses"] += 1 else: results[p0_name]["draws"] += 1 results[p1_name]["draws"] += 1 print("\nTournament Results:") print(f"{'Agent':<18} | {'Wins':<5} | {'Draws':<5} | {'Losses':<5} | {'Avg Score':<10} | {'Avg Turns':<10}") print("-" * 75) for name, stat in results.items(): total_games = stat["wins"] + stat["draws"] + stat["losses"] avg_score = stat["total_score"] / total_games if total_games > 0 else 0 avg_turns = sum(stat["turns"]) / len(stat["turns"]) if stat["turns"] else 0 print( f"{name:<18} | {stat['wins']:<5} | {stat['draws']:<5} | {stat['losses']:<5} | {avg_score:<10.2f} | {avg_turns:<10.2f}" ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--rounds", type=int, default=10) args = parser.parse_args() run_tournament(num_games=args.rounds)