File size: 5,018 Bytes
463f868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""

Hyperparameter Search for AlphaZero Network.

Finds optimal Depth/Width by training candidates and evaluating performance.

"""

import time

import numpy as np
import torch
from game.game_state import GameState, create_sample_cards, initialize_game
from mcts import MCTS, MCTSConfig
from network import NetworkConfig  # Reuse config structure
from network_torch import TorchNetworkWrapper

# Candidates to evaluate
CANDIDATES = {
    "Small": {"hidden_size": 128, "num_layers": 3},
    "Medium": {"hidden_size": 256, "num_layers": 5},
    "Large": {"hidden_size": 512, "num_layers": 10},
}


def run_search(games_per_candidate=20, eval_games=10):
    """Run a depth/width hyperparameter search over ``CANDIDATES``.

    For each candidate architecture: build a network, generate self-play
    games with neural MCTS, train for a few full-batch epochs on the
    generated data, then evaluate the trained network against a
    random-rollout MCTS baseline.

    Args:
        games_per_candidate: Number of self-play games used to build the
            training set for each candidate.
        eval_games: Number of evaluation games against the random agent.

    Returns:
        Dict mapping candidate name -> ``{"loss", "wins", "time"}`` so the
        results can be consumed programmatically as well as printed.
    """
    print("Initializing Search...")

    # Init game data: GameState reads card databases from class attributes.
    member_db, live_db = create_sample_cards()
    GameState.member_db = member_db
    GameState.live_db = live_db

    # Hoisted out of the candidate loop (was re-imported each iteration).
    # NeuralMCTS expects an object exposing `predict()`; TorchNetworkWrapper
    # provides a matching signature, so no adapter is needed.
    from network import NeuralMCTS

    results = {}

    for name, params in CANDIDATES.items():
        print(f"\nEvaluating Candidate: {name} {params}")

        # 1. Setup Network -- derive input size from a fresh game's observation.
        dummy_game = initialize_game()
        obs_size = len(dummy_game.get_observation())

        config = NetworkConfig(
            hidden_size=params["hidden_size"],
            num_hidden_layers=params["num_layers"],
            input_size=obs_size,
            action_size=200,  # NOTE(review): hard-coded action space -- confirm vs. game
            learning_rate=0.001,
        )

        try:
            wrapper = TorchNetworkWrapper(config)
        except Exception as e:
            # Best-effort: skip candidates that fail to construct (e.g. OOM).
            print(f"Failed to create network for {name}: {e}")
            continue

        # 2. Generate Data (Self-Play)
        print(f"  Generating {games_per_candidate} self-play games...")
        start_t = time.time()

        training_data = []  # (state, policy, value) triples

        # Reduced simulation count for speed during the search.
        mcts_agent = NeuralMCTS(network=wrapper, num_simulations=25)

        for _ in range(games_per_candidate):
            g = initialize_game()
            states, policies, actors = [], [], []
            move_count = 0
            while not g.is_terminal() and move_count < 150:
                pol = mcts_agent.search(g)
                states.append(g.get_observation())
                policies.append(pol)
                # Record who actually acted. The previous `i % 2` heuristic
                # assumed strict player alternation, which need not hold.
                actors.append(g.current_player)
                action = np.random.choice(len(pol), p=pol)
                g = g.step(action)
                move_count += 1

            # 2 == no winner (draw, or the 150-move cap was reached).
            winner = g.get_winner() if g.is_terminal() else 2

            # Label each state: +1 if the acting player won, -1 if they
            # lost, 0.0 for draws / truncated games.
            for s, p, actor in zip(states, policies, actors, strict=False):
                if winner == 2:
                    val = 0.0
                else:
                    val = 1.0 if actor == winner else -1.0
                training_data.append((s, p, val))

        gen_time = time.time() - start_t
        print(f"  Gen Time: {gen_time:.1f}s")

        # 3. Train -- full batch for simplicity during the search, 5 epochs.
        print("  Training...")
        all_s = np.array([x[0] for x in training_data])
        all_p = np.array([x[1] for x in training_data])
        all_v = np.array([x[2] for x in training_data])

        final_loss = 0
        for _ep in range(5):
            # train_step returns (total_loss, policy_loss, value_loss);
            # avoid the old ambiguous name `l`, which clobbered live_db.
            total_loss, _policy_loss, _value_loss = wrapper.train_step(
                all_s, all_p, all_v
            )
            final_loss = total_loss

        print(f"  Final Loss: {final_loss:.4f}")

        # 4. Evaluation vs Random
        print(f"  Evaluating vs Random ({eval_games} games)...")
        wins = 0
        rand_mcts = MCTS(MCTSConfig(num_simulations=10))

        for i in range(eval_games):
            g = initialize_game()
            net_player = i % 2  # alternate sides to reduce first-move bias

            while not g.is_terminal() and g.turn_number < 100:
                if g.current_player == net_player:
                    pol = mcts_agent.search(g)  # uses the freshly trained net
                    act = np.argmax(pol)  # deterministic action for eval
                else:
                    act = rand_mcts.select_action(g)
                g = g.step(act)

            w = g.get_winner() if g.is_terminal() else 2
            if w == net_player:
                wins += 1

        print(f"  Wins: {wins}/{eval_games}")
        results[name] = {"loss": final_loss, "wins": wins, "time": gen_time}
        del wrapper  # release network before the next candidate
        torch.cuda.empty_cache()  # free cached GPU memory (no-op on CPU)

    # Report
    print("\nSearch Results:")
    print(f"{'Name':<10} | {'Loss':<8} | {'Wins':<5} | {'Time':<6}")
    print("-" * 35)
    for name, r in results.items():
        print(f"{name:<10} | {r['loss']:.4f}   | {r['wins']:<5} | {r['time']:.1f}")

    return results


if __name__ == "__main__":
    run_search()