# rabukasim / scripts / evaluate_vanilla.py
# Uploaded by trioskosmos using huggingface_hub (commit 463f868, verified)
import os
import sys
import json
import torch
import numpy as np
import random
# engine_rust is expected to live in the working directory; make sure that
# directory is importable before the `import engine_rust` statement below.
pwd = os.getcwd()
if pwd not in sys.path:
    sys.path.append(pwd)
import engine_rust
from alphazero.vanilla_net import HighFidelityAlphaNet, VanillaTransformerConfig
from alphazero.training.vanilla_action_codec import (
ACTION_SPACE,
policy_id_to_engine_action,
build_legal_policy_mask
)
def load_deck_txt(path, db, *, energy_id_min=10000):
    """Parse a deck-list text file into main-deck and energy card-id lists.

    Each usable line has the form ``<card_no> x <qty>``.  Blank lines,
    ``#`` comment lines, lines without an `` x `` separator, lines with a
    non-integer quantity, and card numbers unknown to *db* are skipped.

    Args:
        path: Path to the deck list file (UTF-8 text).
        db: Card database providing ``id_by_no(card_no) -> int | None``.
        energy_id_min: Card ids at or above this value are classified as
            energy cards (default 10000, matching the engine's convention).

    Returns:
        dict with keys ``"initial_deck"`` (main-deck card ids, repeated per
        quantity) and ``"energy"`` (energy card ids, repeated per quantity).
    """
    m_list = []
    e_list = []
    with open(path, "r", encoding="utf-8") as f:
        # Stream the file instead of materializing readlines().
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            if " x " not in line:
                continue
            # rpartition splits on the LAST " x ", so card numbers that
            # themselves contain " x " no longer break the quantity parse.
            card_no, _, qty_str = line.rpartition(" x ")
            try:
                qty = int(qty_str.strip())
            except ValueError:
                # Malformed quantity — skip the line rather than crash
                # (the original int() call raised here).
                continue
            cid = db.id_by_no(card_no.strip())
            if cid is None:
                continue
            target = e_list if cid >= energy_id_min else m_list
            target.extend([cid] * qty)
    return {"initial_deck": m_list, "energy": e_list}
def main():
    """Evaluate the trained vanilla model against a uniform-random player.

    Plays 50 games on the Rust engine with both sides using the same deck:
    the model controls player 0 for games 0-24 and player 1 for games 25-49,
    while the opponent picks uniformly-random legal actions.  Prints win,
    draw, and average-turn statistics at the end.
    """
    root = os.getcwd()
    # Hard-coded paths relative to the working directory — the script must be
    # launched from the repository root for these to resolve.
    db_path = os.path.join(root, "data", "cards_vanilla.json")
    ckpt_path = os.path.join(root, "checkpoints", "vanilla_overnight", "best.pt")
    deck_path = os.path.join(root, "ai/decks/muse_cup.txt")
    print(f"Loading DB...")
    with open(db_path, "r", encoding="utf-8") as f:
        db_json = f.read()
    rust_db = engine_rust.PyCardDatabase(db_json)
    deck_data = load_deck_txt(deck_path, rust_db)
    print(f"Loading model (preset: tiny)...")
    config = VanillaTransformerConfig.from_preset("tiny")
    model = HighFidelityAlphaNet(config)
    # weights_only=True restricts torch.load to tensor/state data (no arbitrary
    # pickled objects); the checkpoint stores the state dict under "model".
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(checkpoint["model"])
    model.eval()
    num_games = 50
    model_wins = 0
    draws = 0
    total_turns = 0
    print(f"Starting comparison: Model vs Random ({num_games} games)...")
    for i in range(num_games):
        if i % 10 == 0:
            print(f" Playing game {i}...")
        # Initialize Game — both players receive the same deck and energy
        # lists; a per-game seed keeps runs reproducible.
        state = engine_rust.PyGameState(rust_db)
        seed = 42 + i
        state.initialize_game_with_seed(
            deck_data["initial_deck"], deck_data["initial_deck"],
            deck_data["energy"], deck_data["energy"],
            [], [], seed
        )
        state.silent = True
        # Determine roles
        # Game 0-24: P0 is Model, P1 is Random
        # Game 25-49: P0 is Random, P1 is Model
        model_player = 0 if i < (num_games // 2) else 1
        # Turn cap bounds runtime; a capped game falls through to
        # get_winner() below without being terminal.
        while not state.is_terminal() and state.turn < 25:
            legal_engine_ids = state.get_legal_action_ids()
            if not legal_engine_ids:
                # No explicit choices available: let the engine advance itself.
                state.auto_step(rust_db)
                continue
            curr_player = state.current_player
            if curr_player == model_player:
                # Model Turn: encode the state, mask illegal policy ids,
                # then pick the highest-probability legal action (greedy).
                obs = state.to_vanilla_tensor()
                obs_t = torch.from_numpy(obs).unsqueeze(0)
                mask = build_legal_policy_mask(state, curr_player, deck_data["initial_deck"], state.phase, legal_engine_ids)
                mask_t = torch.from_numpy(mask).unsqueeze(0)
                with torch.no_grad():
                    logits, _ = model(obs_t, mask=mask_t)
                probs = torch.softmax(logits, dim=1).squeeze(0).numpy()
                # Filter strictly by legal engine actions mapping: a policy id
                # counts only if it maps back to an id the engine lists legal.
                legal_probs = []
                legal_actions = []
                for pid in np.where(mask > 0)[0]:
                    eng_id = policy_id_to_engine_action(state, curr_player, pid, state.phase, deck_data["initial_deck"])
                    if eng_id is not None and eng_id in legal_engine_ids:
                        legal_probs.append(probs[pid])
                        legal_actions.append(eng_id)
                if not legal_actions:
                    # Fallback when no policy id mapped to a legal action.
                    action = random.choice(list(legal_engine_ids))
                else:
                    # Argmax selection — no temperature sampling.
                    action = legal_actions[np.argmax(legal_probs)]
            else:
                # Random Turn
                action = random.choice(list(legal_engine_ids))
            state.step(int(action))
            state.auto_step(rust_db)
        total_turns += state.turn
        # NOTE(review): get_winner() appears to return the winning player
        # index, with -1 meaning draw/unfinished — confirm against engine API.
        winner = state.get_winner()
        if winner == model_player:
            model_wins += 1
        elif winner == -1:
            draws += 1  # Or handle as tie/loss
    # Any game that is neither a model win nor a draw is counted as a
    # random-player win in the summary below.
    print(f"\nResults over {num_games} games:")
    print(f" Model Wins : {model_wins} ({model_wins/num_games*100:.1f}%)")
    print(f" Random Wins: {num_games - model_wins - draws}")
    print(f" Draws : {draws}")
    print(f" Avg Turns : {total_turns/num_games:.1f}")
# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    main()