# LovecaSim — ai/utils/tournament.py
# Uploaded by trioskosmos via huggingface_hub (revision a384afe).
import argparse
import json
import os
import random
import sys
import torch
from tqdm import tqdm
# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import engine_rust
from ai.agents.neural_mcts import HybridMCTSAgent
from ai.models.training_config import POLICY_SIZE
from ai.training.train import AlphaNet
from ai.utils.benchmark_decks import parse_deck
class Agent:
    """Common interface for tournament competitors."""

    def get_action(self, game, db):
        """Choose an action id for the current state; subclasses override this."""
        return None
class RandomAgent(Agent):
    """Baseline competitor that plays a uniformly random legal action."""

    def get_action(self, game, db):
        """Return a random legal action id, or 0 when no action is legal."""
        legal = game.get_legal_action_ids()
        return random.choice(legal) if legal else 0
class MCTSAgent(Agent):
    """Competitor that asks the Rust engine for MCTS move suggestions."""

    def __init__(self, sims=100):
        # Number of MCTS simulations per decision.
        self.sims = sims

    def get_action(self, game, db):
        """Return the engine's top-ranked suggestion, or 0 when it has none."""
        ranked = game.get_mcts_suggestions(self.sims, engine_rust.SearchHorizon.TurnEnd)
        return ranked[0][0] if ranked else 0
class ResNetAgent(Agent):
    """Competitor that greedily follows a trained AlphaNet policy head."""

    def __init__(self, model_path):
        """Load weights from *model_path*, auto-detecting the policy width."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint = torch.load(model_path, map_location=self.device)
        # Checkpoints are either a raw state dict or a wrapper dict that
        # stores the weights under "model_state".
        if isinstance(checkpoint, dict) and "model_state" in checkpoint:
            state_dict = checkpoint["model_state"]
        else:
            state_dict = checkpoint
        # Infer the policy width from the final FC bias; fall back to the
        # configured POLICY_SIZE when that tensor is absent.
        fc_bias = state_dict.get("policy_head_fc.bias")
        if fc_bias is None:
            policy_size = POLICY_SIZE
        else:
            policy_size = fc_bias.shape[0]
        print(f"ResNetAgent: Detected Policy Size {policy_size}")
        self.model = AlphaNet(policy_size=policy_size).to(self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.policy_size = policy_size

    def get_action(self, game, db):
        """Return the argmax policy action among the legal moves."""
        # Encode the current state as a single-sample batch.
        features = torch.FloatTensor(game.encode_state(db)).unsqueeze(0).to(self.device)
        # Forward pass without gradients; only the policy head is used here.
        with torch.no_grad():
            logits, _ = self.model(features)
        # Additive mask: 0 for legal ids, a large negative value elsewhere,
        # so argmax can only select a legal action.
        mask = torch.full((self.policy_size,), -1e9).to(self.device)
        for action_id in game.get_legal_action_ids():
            if action_id < self.policy_size:
                mask[int(action_id)] = 0.0
        return int(torch.argmax(logits.squeeze(0) + mask).item())
def play_match(agent0, agent1, db_content, decks, game_id):
    """Play one game between *agent0* (seat 0) and *agent1* (seat 1).

    Each seat receives an independently sampled deck configuration from
    *decks*.  Interactive phases (-1, 0, 4, 5) are delegated to the seated
    agent; any other phase is advanced with the default action 0.  When an
    agent's chosen action is rejected by the engine, the first legal action
    is played instead, or the game is abandoned if none exists.

    Returns (winner, player0_score, player1_score, turn_count).
    """
    db = engine_rust.PyCardDatabase(db_content)
    game = engine_rust.PyGameState(db)

    deck0, lives0, energy0 = random.choice(decks)
    deck1, lives1, energy1 = random.choice(decks)
    game.initialize_game(deck0, deck1, energy0, energy1, lives0, lives1)

    seats = [agent0, agent1]
    interactive_phases = (-1, 0, 4, 5)

    # Hard cap of 1000 steps guards against games that never terminate.
    for _ in range(1000):
        if game.is_terminal():
            break
        if game.phase in interactive_phases:
            chosen = seats[game.current_player].get_action(game, game.db)
            try:
                game.step(chosen)
            except Exception:
                # The agent produced an invalid action; fall back to the
                # first legal one, or give up when nothing is playable.
                fallback = game.get_legal_action_ids()
                if not fallback:
                    break
                game.step(int(fallback[0]))
        else:
            game.step(0)

    return game.get_winner(), game.get_player(0).score, game.get_player(1).score, game.turn
def run_tournament(num_games=10):
with open("engine/data/cards_compiled.json", "r", encoding="utf-8") as f:
db_content = f.read()
db_json = json.loads(db_content)
# Load Decks
deck_paths = [
"ai/decks/aqours_cup.txt",
"ai/decks/hasunosora_cup.txt",
"ai/decks/liella_cup.txt",
"ai/decks/muse_cup.txt",
"ai/decks/nijigaku_cup.txt",
]
decks = []
for dp in deck_paths:
if os.path.exists(dp):
decks.append(parse_deck(dp, db_json["member_db"], db_json["live_db"], db_json.get("energy_db", {})))
# Agents
# Agents
random_agent = RandomAgent()
mcts_agent = MCTSAgent(sims=100)
# resnet_agent = ResNetAgent("ai/models/alphanet_best.pt")
competitors = {
"Random": random_agent,
"MCTS-100": mcts_agent,
# "ResNet-Standalone": resnet_agent,
# "Neural-Hybrid (Py)": NeuralHeuristicAgent("ai/models/alphanet_best.pt", sims=100),
# "Neural-Rust (Full)": NeuralMCTSFullAgent("ai/models/alphanet.onnx", sims=100),
"Neural-Rust (Hybrid)": HybridMCTSAgent("ai/models/alphanet_best.onnx", sims=100, neural_weight=0.3),
}
results = {name: {"wins": 0, "draws": 0, "losses": 0, "total_score": 0, "turns": []} for name in competitors}
matchups = [("Neural-Rust (Hybrid)", "MCTS-100"), ("Neural-Rust (Hybrid)", "Random")]
print(f"Starting Tournament: {num_games} rounds per matchup...")
for p0_name, p1_name in matchups:
print(f"Matchup: {p0_name} vs {p1_name}")
for i in tqdm(range(num_games)):
# Swap sides every game
if i % 2 == 0:
winner, s0, s1, t = play_match(competitors[p0_name], competitors[p1_name], db_content, decks, i)
results[p0_name]["total_score"] += s0
results[p1_name]["total_score"] += s1
results[p0_name]["turns"].append(t)
results[p1_name]["turns"].append(t)
if winner == 0:
results[p0_name]["wins"] += 1
results[p1_name]["losses"] += 1
elif winner == 1:
results[p1_name]["wins"] += 1
results[p0_name]["losses"] += 1
else:
results[p0_name]["draws"] += 1
results[p1_name]["draws"] += 1
else:
winner, s1, s0, t = play_match(competitors[p1_name], competitors[p0_name], db_content, decks, i)
results[p0_name]["total_score"] += s0
results[p1_name]["total_score"] += s1
results[p0_name]["turns"].append(t)
results[p1_name]["turns"].append(t)
if winner == 0:
results[p1_name]["wins"] += 1
results[p0_name]["losses"] += 1
elif winner == 1:
results[p0_name]["wins"] += 1
results[p1_name]["losses"] += 1
else:
results[p0_name]["draws"] += 1
results[p1_name]["draws"] += 1
print("\nTournament Results:")
print(f"{'Agent':<18} | {'Wins':<5} | {'Draws':<5} | {'Losses':<5} | {'Avg Score':<10} | {'Avg Turns':<10}")
print("-" * 75)
for name, stat in results.items():
total_games = stat["wins"] + stat["draws"] + stat["losses"]
avg_score = stat["total_score"] / total_games if total_games > 0 else 0
avg_turns = sum(stat["turns"]) / len(stat["turns"]) if stat["turns"] else 0
print(
f"{name:<18} | {stat['wins']:<5} | {stat['draws']:<5} | {stat['losses']:<5} | {avg_score:<10.2f} | {avg_turns:<10.2f}"
)
if __name__ == "__main__":
    # Command-line entry point: --rounds sets the games played per matchup.
    cli = argparse.ArgumentParser()
    cli.add_argument("--rounds", type=int, default=10)
    opts = cli.parse_args()
    run_tournament(num_games=opts.rounds)