"""
Hyperparameter Search for AlphaZero Network.
Finds optimal Depth/Width by training candidates and evaluating performance.
"""
import time
import numpy as np
import torch
from game.game_state import GameState, create_sample_cards, initialize_game
from mcts import MCTS, MCTSConfig
from network import NetworkConfig # Reuse config structure
from network_torch import TorchNetworkWrapper
# Network size presets compared by the search, keyed by display name.
# "hidden_size" is passed to NetworkConfig as hidden_size and "num_layers"
# as num_hidden_layers.
CANDIDATES = {
    "Small": dict(hidden_size=128, num_layers=3),
    "Medium": dict(hidden_size=256, num_layers=5),
    "Large": dict(hidden_size=512, num_layers=10),
}
def _generate_self_play_data(mcts_agent, num_games, max_moves=150):
    """Play ``num_games`` self-play games and return (state, policy, value) samples.

    The value target of each sample is +1.0 if the player who was to move in
    that state went on to win, -1.0 if the other player won, and 0.0 when the
    game has no winner (winner code 2, also used when the move cap is hit).
    """
    no_winner = 2
    samples = []
    for _ in range(num_games):
        g = initialize_game()
        states, policies, movers = [], [], []
        move_count = 0
        while not g.is_terminal() and move_count < max_moves:
            pol = mcts_agent.search(g)
            states.append(g.get_observation())
            policies.append(pol)
            # Record who actually acts here. Deriving the actor from the move
            # index (i % 2) is only right if players alternate on every single
            # move with player 0 first.
            movers.append(g.current_player)
            action = np.random.choice(len(pol), p=pol)
            g = g.step(action)
            move_count += 1
        winner = g.get_winner() if g.is_terminal() else no_winner
        for state, policy, mover in zip(states, policies, movers):
            if winner == no_winner:
                value = 0.0
            else:
                value = 1.0 if mover == winner else -1.0
            samples.append((state, policy, value))
    return samples


def _count_wins_vs_baseline(mcts_agent, eval_games, max_turns=100):
    """Return how many of ``eval_games`` the network-guided agent wins.

    The opponent is a plain MCTS with 10 simulations. The network side
    alternates between player 0 and player 1 from game to game. Games that
    are not terminal once ``turn_number`` reaches ``max_turns`` count as
    not won.
    """
    no_winner = 2
    wins = 0
    baseline = MCTS(MCTSConfig(num_simulations=10))
    for i in range(eval_games):
        g = initialize_game()
        net_player = i % 2
        while not g.is_terminal() and g.turn_number < max_turns:
            if g.current_player == net_player:
                pol = mcts_agent.search(g)  # Uses the freshly trained net
                act = np.argmax(pol)  # Deterministic for eval
            else:
                act = baseline.select_action(g)
            g = g.step(act)
        winner = g.get_winner() if g.is_terminal() else no_winner
        if winner == net_player:
            wins += 1
    return wins


def run_search(games_per_candidate=20, eval_games=10):
    """Train and evaluate every entry of ``CANDIDATES`` and print a comparison.

    For each candidate: build the network, generate self-play games with a
    25-simulation neural MCTS, train for 5 full-batch steps on the collected
    samples, then play evaluation games against a baseline MCTS.

    Args:
        games_per_candidate: Number of self-play games generated per candidate.
        eval_games: Number of evaluation games played per candidate.

    Returns:
        Dict mapping candidate name to ``{"loss", "wins", "time"}``: the last
        training loss (NaN if no samples were collected), the number of
        evaluation wins, and the self-play generation time in seconds.
        Candidates whose network could not be created are omitted.
    """
    print("Initializing Search...")
    # Init game data
    member_db, live_db = create_sample_cards()
    GameState.member_db = member_db
    GameState.live_db = live_db

    # Kept function-local, as in the original; hoisted out of the loop.
    from network import NeuralMCTS

    # The observation size does not depend on the candidate, so probe it once.
    obs_size = len(initialize_game().get_observation())

    results = {}
    for name, params in CANDIDATES.items():
        print(f"\nEvaluating Candidate: {name} {params}")
        # 1. Setup Network
        config = NetworkConfig(
            hidden_size=params["hidden_size"],
            num_hidden_layers=params["num_layers"],
            input_size=obs_size,
            # NOTE(review): 200 is hard-coded; presumably the game's action
            # space size - confirm against the game implementation.
            action_size=200,
            learning_rate=0.001,
        )
        try:
            wrapper = TorchNetworkWrapper(config)
        except Exception as e:
            print(f"Failed to create network for {name}: {e}")
            continue

        # 2. Generate Data (Self-Play)
        print(f" Generating {games_per_candidate} self-play games...")
        start_t = time.time()
        # Simulations reduced for speed in search; the torch wrapper is passed
        # where NeuralMCTS expects a network object.
        mcts_agent = NeuralMCTS(network=wrapper, num_simulations=25)
        training_data = _generate_self_play_data(mcts_agent, games_per_candidate)
        gen_time = time.time() - start_t
        print(f" Gen Time: {gen_time:.1f}s")

        # 3. Train
        print(" Training...")
        if training_data:
            all_s = np.array([sample[0] for sample in training_data])
            all_p = np.array([sample[1] for sample in training_data])
            all_v = np.array([sample[2] for sample in training_data])
            final_loss = 0
            for _ep in range(5):
                # Full batch for simplicity in search
                loss, _policy_loss, _value_loss = wrapper.train_step(all_s, all_p, all_v)
                final_loss = loss
        else:
            # Nothing to fit (e.g. games_per_candidate == 0); do not feed
            # empty arrays to the trainer.
            print(" No training samples generated; skipping training.")
            final_loss = float("nan")
        print(f" Final Loss: {final_loss:.4f}")

        # 4. Evaluation vs Random
        print(f" Evaluating vs Random ({eval_games} games)...")
        wins = _count_wins_vs_baseline(mcts_agent, eval_games)
        print(f" Wins: {wins}/{eval_games}")

        results[name] = {"loss": final_loss, "wins": wins, "time": gen_time}

        # The MCTS agent also holds a reference to the network, so both must
        # be dropped before the cached GPU memory can be released.
        del mcts_agent
        del wrapper
        torch.cuda.empty_cache()

    # Report
    print("\nSearch Results:")
    print(f"{'Name':<10} | {'Loss':<8} | {'Wins':<5} | {'Time':<6}")
    print("-" * 35)
    for name, r in results.items():
        print(f"{name:<10} | {r['loss']:.4f} | {r['wins']:<5} | {r['time']:.1f}")
    return results
# Run the search with its default settings when executed as a script.
if __name__ == "__main__":
    run_search()