# Spaces:
# Sleeping
# Sleeping
# File size: 4,964 Bytes
463f868 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | import os
import sys
import json
import torch
import numpy as np
import random
# Make modules that live in the launch directory (notably the compiled
# ``engine_rust`` extension) importable: append the current working directory
# to ``sys.path`` unless it is already listed there.
pwd = os.getcwd()
if not any(entry == pwd for entry in sys.path):
    sys.path.append(pwd)
import engine_rust
from alphazero.vanilla_net import HighFidelityAlphaNet, VanillaTransformerConfig
from alphazero.training.vanilla_action_codec import (
ACTION_SPACE,
policy_id_to_engine_action,
build_legal_policy_mask
)
def load_deck_txt(path, db, *, energy_id_min=10000):
    """Parse a plain-text deck list into main-deck and energy card-id lists.

    Each non-blank line that does not start with ``#`` is expected to look like
    ``<card_no> x <qty>``. Lines without the `` x `` separator are skipped, and
    so are card numbers for which ``db.id_by_no`` returns ``None``.

    Args:
        path: Path to the UTF-8 encoded deck text file.
        db: Card database exposing ``id_by_no(card_no)``; it must return a
            comparable id or ``None`` for unknown card numbers.
        energy_id_min: Ids greater than or equal to this value are routed to
            the energy list, all others to the main deck. Defaults to 10000.

    Returns:
        ``{"initial_deck": [...], "energy": [...]}`` where each resolved id is
        repeated ``qty`` times, in file order.

    Raises:
        ValueError: If a quantity field is not an integer. The message names
            the offending file and 1-based line number.
    """
    main_ids = []
    energy_ids = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line or line.startswith("#") or " x " not in line:
                continue
            parts = line.split(" x ")
            card_no = parts[0].strip()
            qty_text = parts[1].strip()
            try:
                qty = int(qty_text)
            except ValueError as err:
                # Surface where the bad entry is instead of a context-free int() error.
                raise ValueError(
                    f"{path}:{lineno}: invalid quantity {qty_text!r} in {line!r}"
                ) from err
            cid = db.id_by_no(card_no)
            if cid is None:
                # Unknown card numbers are deliberately ignored (best effort).
                continue
            target = energy_ids if cid >= energy_id_min else main_ids
            target.extend([cid] * qty)
    return {"initial_deck": main_ids, "energy": energy_ids}
def _select_model_action(model, state, player, deck, legal_engine_ids, rng):
    """Return the legal engine action with the highest model probability.

    Builds the observation and legal-policy mask from ``state``. Every policy
    id allowed by the mask is mapped back to an engine action, and only
    mappings that appear in ``legal_engine_ids`` are candidates. If no
    candidate survives, a uniformly random legal action is drawn from ``rng``.
    """
    obs_t = torch.from_numpy(state.to_vanilla_tensor()).unsqueeze(0)
    mask = build_legal_policy_mask(state, player, deck, state.phase, legal_engine_ids)
    mask_t = torch.from_numpy(mask).unsqueeze(0)
    with torch.no_grad():
        logits, _ = model(obs_t, mask=mask_t)
    probs = torch.softmax(logits, dim=1).squeeze(0).numpy()

    # Build the membership set once instead of rescanning the id list per policy id.
    legal_set = set(legal_engine_ids)
    legal_probs = []
    legal_actions = []
    for pid in np.where(mask > 0)[0]:
        eng_id = policy_id_to_engine_action(state, player, pid, state.phase, deck)
        # The policy->engine mapping can yield ids the engine does not
        # currently allow, so filter strictly against the engine's own list.
        if eng_id is not None and eng_id in legal_set:
            legal_probs.append(probs[pid])
            legal_actions.append(eng_id)
    if not legal_actions:
        return rng.choice(list(legal_engine_ids))
    return legal_actions[int(np.argmax(legal_probs))]


def _play_game(model, rust_db, deck_data, model_player, seed, max_turns):
    """Play one model-vs-random game and return ``(winner, final_turn)``.

    Both seats use the same deck. The loop stops when the engine reports a
    terminal state or ``state.turn`` reaches ``max_turns``. A
    ``random.Random(seed)`` instance drives the random seat and the model's
    fallback, so the Python-side choices are reproducible for a given seed.
    """
    deck = deck_data["initial_deck"]
    energy = deck_data["energy"]
    state = engine_rust.PyGameState(rust_db)
    state.initialize_game_with_seed(deck, deck, energy, energy, [], [], seed)
    state.silent = True
    rng = random.Random(seed)

    while not state.is_terminal() and state.turn < max_turns:
        legal_engine_ids = state.get_legal_action_ids()
        if not legal_engine_ids:
            # No decision to make: hand control to the engine's auto_step
            # (assumed to advance the game — TODO confirm it cannot stall here).
            state.auto_step(rust_db)
            continue
        player = state.current_player
        if player == model_player:
            action = _select_model_action(model, state, player, deck, legal_engine_ids, rng)
        else:
            action = rng.choice(list(legal_engine_ids))
        state.step(int(action))
        state.auto_step(rust_db)

    final_turn = state.turn
    return state.get_winner(), final_turn


def main(*, num_games=50, max_turns=25, base_seed=42, preset="tiny"):
    """Evaluate the vanilla AlphaZero checkpoint against a uniformly random player.

    Seats alternate: during the first ``num_games // 2`` games the model plays
    seat 0, afterwards seat 1. A summary is printed to stdout.

    Keyword Args:
        num_games: Number of games to play; must be positive.
        max_turns: Turn cap; a game stops once ``state.turn`` reaches it.
        base_seed: Game ``i`` is initialised with seed ``base_seed + i``.
        preset: ``VanillaTransformerConfig`` preset name; presumably must match
            the architecture stored in the checkpoint.

    Raises:
        ValueError: If ``num_games`` is not positive.
    """
    if num_games <= 0:
        raise ValueError("num_games must be positive")

    root = os.getcwd()
    db_path = os.path.join(root, "data", "cards_vanilla.json")
    ckpt_path = os.path.join(root, "checkpoints", "vanilla_overnight", "best.pt")
    deck_path = os.path.join(root, "ai", "decks", "muse_cup.txt")

    print("Loading DB...")
    with open(db_path, "r", encoding="utf-8") as f:
        rust_db = engine_rust.PyCardDatabase(f.read())
    deck_data = load_deck_txt(deck_path, rust_db)

    print(f"Loading model (preset: {preset})...")
    config = VanillaTransformerConfig.from_preset(preset)
    model = HighFidelityAlphaNet(config)
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    model_wins = 0
    draws = 0
    total_turns = 0
    print(f"Starting comparison: Model vs Random ({num_games} games)...")
    for i in range(num_games):
        if i % 10 == 0:
            print(f" Playing game {i}...")
        # First half of the games: model is seat 0; second half: seat 1.
        model_player = 0 if i < num_games // 2 else 1
        winner, turns = _play_game(
            model, rust_db, deck_data, model_player, base_seed + i, max_turns
        )
        total_turns += turns
        if winner == model_player:
            model_wins += 1
        elif winner == -1:
            # NOTE(review): -1 is tallied as a draw; games cut off by max_turns
            # presumably report -1 too — confirm get_winner() semantics.
            draws += 1

    random_wins = num_games - model_wins - draws
    print(f"\nResults over {num_games} games:")
    print(f" Model Wins : {model_wins} ({model_wins/num_games*100:.1f}%)")
    print(f" Random Wins: {random_wins}")
    print(f" Draws : {draws}")
    print(f" Avg Turns : {total_turns/num_games:.1f}")


if __name__ == "__main__":
    main()
|