trioskosmos's picture
Upload folder using huggingface_hub
463f868 verified
"""
Hyperparameter Search for AlphaZero Network.
Finds optimal Depth/Width by training candidates and evaluating performance.
"""
import time
import numpy as np
import torch
from game.game_state import GameState, create_sample_cards, initialize_game
from mcts import MCTS, MCTSConfig
from network import NetworkConfig # Reuse config structure
from network_torch import TorchNetworkWrapper
# Candidates to evaluate
CANDIDATES = {
"Small": {"hidden_size": 128, "num_layers": 3},
"Medium": {"hidden_size": 256, "num_layers": 5},
"Large": {"hidden_size": 512, "num_layers": 10},
}
def run_search(games_per_candidate=20, eval_games=10):
print("Initializing Search...")
# Init game data
m, l = create_sample_cards()
GameState.member_db = m
GameState.live_db = l
results = {}
for name, params in CANDIDATES.items():
print(f"\nEvaluating Candidate: {name} {params}")
# 1. Setup Network
dummy_game = initialize_game()
obs_size = len(dummy_game.get_observation())
config = NetworkConfig(
hidden_size=params["hidden_size"],
num_hidden_layers=params["num_layers"],
input_size=obs_size,
action_size=200,
learning_rate=0.001,
)
try:
wrapper = TorchNetworkWrapper(config)
except Exception as e:
print(f"Failed to create network for {name}: {e}")
continue
# 2. Generate Data (Self-Play)
print(f" Generating {games_per_candidate} self-play games...")
start_t = time.time()
training_data = [] # (state, policy, value)
# Use simple Neural MCTS (simulations reduced for speed in search)
from network import NeuralMCTS # We can reuse the MCTS class logic but pass torch wrapper
# We need to monkey-patch or adapter because NeuralMCTS expects `network.predict()`
# My TorchNetworkWrapper has `predict()` matching the signature.
mcts_agent = NeuralMCTS(network=wrapper, num_simulations=25)
for _ in range(games_per_candidate):
g = initialize_game()
states, policies = [], []
move_count = 0
while not g.is_terminal() and move_count < 150:
pol = mcts_agent.search(g)
states.append(g.get_observation())
policies.append(pol)
action = np.random.choice(len(pol), p=pol)
g = g.step(action)
move_count += 1
winner = g.get_winner() if g.is_terminal() else 2
# Process game data
for i, (s, p) in enumerate(zip(states, policies, strict=False)):
val = 0.0
if winner != 2:
val = 1.0 if (i % 2 == winner) else -1.0 # WRONG logic for player idx?
# winner is 0 or 1.
# if i%2 == 0 (Player 0 acted), and winner==0 -> +1.
# Correct.
training_data.append((s, p, val))
gen_time = time.time() - start_t
print(f" Gen Time: {gen_time:.1f}s")
# 3. Train
print(" Training...")
# Unpack data
all_s = np.array([x[0] for x in training_data])
all_p = np.array([x[1] for x in training_data])
all_v = np.array([x[2] for x in training_data])
# 5 epochs
final_loss = 0
for _ep in range(5):
# Full batch for simplicity in search
l, pl, vl = wrapper.train_step(all_s, all_p, all_v)
final_loss = l
print(f" Final Loss: {final_loss:.4f}")
# 4. Evaluation vs Random
print(f" Evaluating vs Random ({eval_games} games)...")
wins = 0
rand_mcts = MCTS(MCTSConfig(num_simulations=10))
for i in range(eval_games):
g = initialize_game()
net_player = i % 2
while not g.is_terminal() and g.turn_number < 100:
if g.current_player == net_player:
pol = mcts_agent.search(g) # Uses updated net
act = np.argmax(pol) # Deterministic for eval
else:
act = rand_mcts.select_action(g)
g = g.step(act)
w = g.get_winner() if g.is_terminal() else 2
if w == net_player:
wins += 1
print(f" Wins: {wins}/{eval_games}")
results[name] = {"loss": final_loss, "wins": wins, "time": gen_time}
del wrapper # Free GPU memory
torch.cuda.empty_cache()
# Report
print("\nSearch Results:")
print(f"{'Name':<10} | {'Loss':<8} | {'Wins':<5} | {'Time':<6}")
print("-" * 35)
for name, r in results.items():
print(f"{name:<10} | {r['loss']:.4f} | {r['wins']:<5} | {r['time']:.1f}")
if __name__ == "__main__":
run_search()