Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import json | |
| import torch | |
| import numpy as np | |
| import random | |
| # Ensure engine_rust is importable | |
| pwd = os.getcwd() | |
| if pwd not in sys.path: | |
| sys.path.append(pwd) | |
| import engine_rust | |
| from alphazero.vanilla_net import HighFidelityAlphaNet, VanillaTransformerConfig | |
| from alphazero.training.vanilla_action_codec import ( | |
| ACTION_SPACE, | |
| policy_id_to_engine_action, | |
| build_legal_policy_mask | |
| ) | |
def load_deck_txt(path, db, energy_id_threshold=10000):
    """Parse a plain-text deck list into main-deck and energy card-ID lists.

    Each meaningful line has the form ``<card_no> x <quantity>``. Blank lines,
    lines starting with ``#``, and lines without the `` x `` separator are
    ignored. Card numbers are resolved with ``db.id_by_no``; numbers the
    database does not know are skipped with a warning on stderr so that a
    silently shrunken deck does not go unnoticed.

    Args:
        path: Path to the deck text file (UTF-8, with or without a BOM).
        db: Object exposing ``id_by_no(card_no)``, returning an int card ID
            or ``None`` for unknown card numbers.
        energy_id_threshold: Card IDs greater than or equal to this value are
            treated as energy cards; smaller IDs go to the main deck.

    Returns:
        ``{"initial_deck": [...], "energy": [...]}`` where each list repeats
        a card ID once per copy, in file order.

    Raises:
        ValueError: If a quantity field is not an integer.
    """
    main_ids = []
    energy_ids = []
    # "utf-8-sig" strips a leading BOM (common for files saved on Windows).
    # With plain "utf-8" the BOM stays glued to the first card number
    # (str.strip does not remove U+FEFF), so that card fails the lookup.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line_no, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line or line.startswith("#") or " x " not in line:
                continue
            parts = line.split(" x ")
            card_no = parts[0].strip()
            try:
                qty = int(parts[1].strip())
            except ValueError as err:
                raise ValueError(
                    f"{path}:{line_no}: invalid quantity in {line!r}"
                ) from err
            cid = db.id_by_no(card_no)
            if cid is None:
                print(
                    f"Warning: unknown card number {card_no!r} "
                    f"({path}:{line_no}); skipped",
                    file=sys.stderr,
                )
                continue
            target = energy_ids if cid >= energy_id_threshold else main_ids
            target.extend([cid] * qty)
    return {"initial_deck": main_ids, "energy": energy_ids}
def _load_model(ckpt_path, preset="tiny"):
    """Build a HighFidelityAlphaNet for ``preset`` and load weights from ``ckpt_path``.

    The checkpoint's state dict is read from its ``"model"`` entry; the model
    is returned in eval mode on CPU.
    """
    config = VanillaTransformerConfig.from_preset(preset)
    model = HighFidelityAlphaNet(config)
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model


def _choose_model_action(model, state, curr_player, deck, legal_engine_ids, rng):
    """Return the currently-legal engine action with the highest model probability.

    Policy ids allowed by the mask are mapped back to engine actions and only
    those the engine reports as legal are considered. If none survive the
    mapping, fall back to a uniformly random legal action drawn from ``rng``.
    """
    obs = state.to_vanilla_tensor()
    obs_t = torch.from_numpy(obs).unsqueeze(0)
    mask = build_legal_policy_mask(state, curr_player, deck, state.phase, legal_engine_ids)
    mask_t = torch.from_numpy(mask).unsqueeze(0)
    with torch.no_grad():
        logits, _ = model(obs_t, mask=mask_t)
    probs = torch.softmax(logits, dim=1).squeeze(0).numpy()

    # Build the membership set once instead of scanning the list per policy id.
    legal_set = set(legal_engine_ids)
    legal_probs = []
    legal_actions = []
    for pid in np.where(mask > 0)[0]:
        eng_id = policy_id_to_engine_action(state, curr_player, pid, state.phase, deck)
        if eng_id is not None and eng_id in legal_set:
            legal_probs.append(probs[pid])
            legal_actions.append(eng_id)
    if not legal_actions:
        return rng.choice(list(legal_engine_ids))
    return legal_actions[int(np.argmax(legal_probs))]


def _play_game(model, rust_db, deck_data, seed, model_player, max_turns):
    """Play one Model-vs-Random game; return ``(winner, turns, finished)``.

    ``finished`` is False when the game was cut off by ``max_turns`` rather
    than reaching a terminal state.
    """
    # Seed the random opponent with the game seed so whole games are reproducible.
    rng = random.Random(seed)
    state = engine_rust.PyGameState(rust_db)
    state.initialize_game_with_seed(
        deck_data["initial_deck"], deck_data["initial_deck"],
        deck_data["energy"], deck_data["energy"],
        [], [], seed,
    )
    state.silent = True
    deck = deck_data["initial_deck"]

    while not state.is_terminal() and state.turn < max_turns:
        legal_engine_ids = state.get_legal_action_ids()
        if not legal_engine_ids:
            state.auto_step(rust_db)
            continue
        curr_player = state.current_player
        if curr_player == model_player:
            action = _choose_model_action(
                model, state, curr_player, deck, legal_engine_ids, rng
            )
        else:
            action = rng.choice(list(legal_engine_ids))
        state.step(int(action))
        state.auto_step(rust_db)

    return state.get_winner(), state.turn, state.is_terminal()


def main(num_games=50, max_turns=25, base_seed=42):
    """Evaluate the vanilla checkpoint against a uniformly random opponent.

    The model plays seat 0 for the first half of the games and seat 1 for the
    rest. Game ``i`` uses seed ``base_seed + i`` for both the engine and the
    random opponent, so runs are reproducible. Prints a results summary.

    Raises:
        ValueError: If ``num_games`` is less than 1.
    """
    if num_games < 1:
        raise ValueError("num_games must be at least 1")

    root = os.getcwd()
    db_path = os.path.join(root, "data", "cards_vanilla.json")
    ckpt_path = os.path.join(root, "checkpoints", "vanilla_overnight", "best.pt")
    deck_path = os.path.join(root, "ai", "decks", "muse_cup.txt")

    print("Loading DB...")
    with open(db_path, "r", encoding="utf-8") as f:
        db_json = f.read()
    rust_db = engine_rust.PyCardDatabase(db_json)
    deck_data = load_deck_txt(deck_path, rust_db)

    print("Loading model (preset: tiny)...")
    model = _load_model(ckpt_path, preset="tiny")

    model_wins = 0
    random_wins = 0
    draws = 0
    undecided = 0
    unfinished = 0
    total_turns = 0
    print(f"Starting comparison: Model vs Random ({num_games} games)...")
    for i in range(num_games):
        if i % 10 == 0:
            print(f" Playing game {i}...")
        # First half: model is P0, random is P1; second half: seats swapped.
        model_player = 0 if i < (num_games // 2) else 1
        winner, turns, finished = _play_game(
            model, rust_db, deck_data, base_seed + i, model_player, max_turns
        )
        total_turns += turns
        if not finished:
            unfinished += 1
        # Count each outcome explicitly; anything that is not a clear win for
        # either seat or a -1 draw must not be credited to the random player.
        if winner == model_player:
            model_wins += 1
        elif winner == 1 - model_player:
            random_wins += 1
        elif winner == -1:
            draws += 1
        else:
            undecided += 1

    print(f"\nResults over {num_games} games:")
    print(f" Model Wins : {model_wins} ({model_wins/num_games*100:.1f}%)")
    print(f" Random Wins: {random_wins}")
    print(f" Draws : {draws}")
    if undecided:
        print(f" Undecided : {undecided} (no winner reported)")
    if unfinished:
        print(f" Turn cap : {unfinished} game(s) stopped at turn {max_turns}")
    print(f" Avg Turns : {total_turns/num_games:.1f}")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()