"""

PyTorch implementation of Transformer-based AlphaZero network.

Processes the game state as a set of interacting cards (Tokens) rather than a flat vector.

"""

import sys
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Import config constants
from .training_config import DROPOUT, HIDDEN_SIZE, N_HEADS, NUM_LAYERS


class Tokenizer(nn.Module):
    """

    Slices the 1200-float input vector into semantic tokens:

    - 1 Global Token (144 features: 20 basic + 124 heuristics/misc)

    - 22 Card Tokens (6 Stage, 6 Live, 10 Hand) - 48 features each

    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model

        self.card_size = 48
        # Global (20) + Tail (1076:1200 = 124) = 144 features
        self.global_size = 144

        # Projections
        self.global_proj = nn.Linear(self.global_size, d_model)
        self.card_proj = nn.Linear(self.card_size, d_model)

        # Zone embeddings: 0=Global, 1=P0_Stage, 2=P1_Stage, 3=P0_Live, 4=P1_Live, 5=P0_Hand (8 slots allocated, 6 used)
        self.zone_embedding = nn.Embedding(8, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, 1200)
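        # Observation layout (offsets match the slices below):
        #   [0:20)      basic globals
        #   [20:164)    P0 Stage cards  (3 x 48)
        #   [164:308)   P1 Stage cards  (3 x 48)
        #   [308:452)   P0 Live cards   (3 x 48)
        #   [452:596)   P1 Live cards   (3 x 48)
        #   [596:1076)  P0 Hand cards   (10 x 48)
        #   [1076:1200) tail heuristics / misc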
        batch_size = x.shape[0]

        tokens = []

        # 1. Global Token
        # Basic Globals (0-20) + Tail Heuristics (1076-1200)
        global_feat = torch.cat([x[:, 0:20], x[:, 1076:1200]], dim=1)

        t_global = self.global_proj(global_feat)  # (B, d_model)
        t_global = t_global + self.zone_embedding(torch.zeros(batch_size, dtype=torch.long, device=x.device))
        tokens.append(t_global.unsqueeze(1))

        # 2. Card Tokens helper
        def make_cards(start_idx, count, zone_id):
            card_tokens = []
            for i in range(count):
                s = start_idx + i * 48
                e = s + 48
                c_vec = x[:, s:e]
                c_emb = self.card_proj(c_vec)
                c_emb = c_emb + self.zone_embedding(
                    torch.full((batch_size,), zone_id, dtype=torch.long, device=x.device)
                )
                card_tokens.append(c_emb.unsqueeze(1))
            return card_tokens

        # P0 Stage (Zone 1) - starts at 20
        tokens.extend(make_cards(20, 3, 1))
        # P1 Stage (Zone 2) - starts at 164
        tokens.extend(make_cards(164, 3, 2))
        # P0 Live (Zone 3) - starts at 308
        tokens.extend(make_cards(308, 3, 3))
        # P1 Live (Zone 4) - starts at 452
        tokens.extend(make_cards(452, 3, 4))
        # P0 Hand (Zone 5) - starts at 596
        tokens.extend(make_cards(596, 10, 5))

        # SeqLen = 1 + 3 + 3 + 3 + 3 + 10 = 23
        return torch.cat(tokens, dim=1)


class TransformerCardNet(nn.Module):
    def __init__(self, input_size=1200, action_size=2000):
        super().__init__()

        self.input_size = input_size
        self.action_size = action_size
        self.d_model = HIDDEN_SIZE

        # 1. Tokenizer
        self.tokenizer = Tokenizer(self.d_model)

        # 2. Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model, nhead=N_HEADS, dim_feedforward=self.d_model * 4, dropout=DROPOUT, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=NUM_LAYERS)

        # 3. Policy Heads
        self.hand_action_proj = nn.Linear(self.d_model, 6)  # [Play0, Play1, Play2, Energy, Mull, LiveSet]
        self.stage_action_proj = nn.Linear(self.d_model, 10)  # [Ability0..9]
        self.live_action_proj = nn.Linear(self.d_model, 1)  # [SelectSuccess]
        self.global_action_proj = nn.Linear(self.d_model, 10)  # [0:Pass, 1..6:Colors, ... ]

        # Value Heads
        # Win-rate head (Sigmoid)
        self.value_win_head = nn.Sequential(nn.Linear(self.d_model, 128), nn.ReLU(), nn.Linear(128, 1), nn.Sigmoid())
        # Score differential head (Tanh -1..1)
        self.value_score_head = nn.Sequential(nn.Linear(self.d_model, 128), nn.ReLU(), nn.Linear(128, 1), nn.Tanh())
        # Auxiliary Pacing Head (Progress 0..1)
        self.turns_head = nn.Sequential(nn.Linear(self.d_model, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())

    def forward(self, x):
        batch_size = x.size(0)
        tokens = self.tokenizer(x)
        encoded = self.transformer(tokens)  # (B, 23, d_model)

        # --- Policy Reconstruction ---
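        # Flat action-index layout (as written by the slices below):
        #   0                 Pass (global token)
        #   1 + 3*i .. +2     Play0/Play1/Play2 for hand card i
        #   100 + i           Energy for hand card i
        #   200 + 10*i .. +9  Ability0..9 for stage card i
        #   300 + i           Mull for hand card i
        #   400 + i           LiveSet for hand card i
        #   580 .. 585        Colors (global token)
        #   600 .. 602        SelectSuccess for live-zone cards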
        logits = torch.zeros(batch_size, self.action_size, device=x.device)

        # Global Actions
        global_tok = encoded[:, 0, :]
        g_logits = self.global_action_proj(global_tok)
        logits[:, 0] = g_logits[:, 0]  # Pass
        logits[:, 580:586] = g_logits[:, 1:7]  # Colors

        # Hand Actions (Tokens 13-22)
        hand_toks = encoded[:, 13:23, :]
        h_logits = self.hand_action_proj(hand_toks)  # (B, 10, 6)
        for i in range(10):
            logits[:, 1 + 3 * i : 1 + 3 * i + 3] = h_logits[:, i, 0:3]
            logits[:, 100 + i] = h_logits[:, i, 3]  # Energy
            logits[:, 300 + i] = h_logits[:, i, 4]  # Mull
            logits[:, 400 + i] = h_logits[:, i, 5]  # LiveSet

        # Stage Actions (Tokens 1-3)
        stage_toks = encoded[:, 1:4, :]
        s_logits = self.stage_action_proj(stage_toks)  # (B, 3, 10)
        for i in range(3):
            logits[:, 200 + 10 * i : 200 + 10 * i + 10] = s_logits[:, i, :]

        # Live Zone Actions (Tokens 7-9)
        live_toks = encoded[:, 7:10, :]
        l_logits = self.live_action_proj(live_toks).squeeze(-1)  # (B, 3)
        logits[:, 600:603] = l_logits

        # --- Value Heads ---
        cls_token = encoded[:, 0, :]
        val_win = self.value_win_head(cls_token)  # (B, 1)
        val_score = self.value_score_head(cls_token)  # (B, 1)
        turns_pred = self.turns_head(cls_token)  # (B, 1)

        return F.softmax(logits, dim=1), val_win, val_score, turns_pred


class TorchNetworkWrapper:
    """Wrapper to interface with MCTS/Training loop"""

    def __init__(self, config=None, device=None, enable_compile=True):
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.net = TransformerCardNet().to(self.device)

        # torch.compile is generally unsupported on Windows (no Triton backend)
        if enable_compile and hasattr(torch, "compile") and not sys.platform.startswith("win"):
            try:
                print("Compiling Transformer with torch.compile...")
                self.net = torch.compile(self.net, mode="reduce-overhead")
            except Exception as e:
                print(f"Compile failed: {e}")

        lr = 0.0003
        self.optimizer = optim.AdamW(self.net.parameters(), lr=lr, weight_decay=1e-4)

    def predict(self, state) -> Tuple[np.ndarray, float]:
        self.net.eval()
        obs = state.get_observation()
        if len(obs) != 1200:
            if len(obs) < 1200:
                obs = obs + [0.0] * (1200 - len(obs))
            else:
                obs = obs[:1200]

        x = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device)

        with torch.no_grad():
            p_soft, v_win, v_score, t_pred = self.net(x)

        p = p_soft.cpu().numpy()[0]
        v = v_win.item()  # sigmoid win probability in [0, 1]; rescale if the MCTS expects a [-1, 1] value

        # Mask illegal
        legal = state.get_legal_actions()
        masked = p * legal
        sum_p = masked.sum()
        if sum_p > 0:
            masked /= sum_p
        else:
            masked = legal.astype(np.float32) / legal.sum()

        return masked, v

    def train_step(self, obs, target_p, target_v_win, target_v_score, target_turns):
        """

        obs: (B, 1200)

        target_p: (B, 2000)

        target_v_win: (B, 1)

        target_v_score: (B, 1)

        target_turns: (B, 1)

        """
        self.net.train()
        self.optimizer.zero_grad()

        x = torch.tensor(obs, dtype=torch.float32).to(self.device)
        t_p = torch.tensor(target_p, dtype=torch.float32).to(self.device)
        t_w = torch.tensor(target_v_win, dtype=torch.float32).to(self.device)
        t_s = torch.tensor(target_v_score, dtype=torch.float32).to(self.device)
        t_t = torch.tensor(target_turns, dtype=torch.float32).to(self.device)

        p, w, s, t = self.net(x)

        loss_p = -torch.sum(t_p * torch.log(p + 1e-8)) / x.size(0)
        loss_w = F.binary_cross_entropy(w, t_w)
        loss_s = F.mse_loss(s, t_s)
        loss_t = F.mse_loss(t, t_t)

        total_loss = loss_p + loss_w + loss_s + loss_t
        total_loss.backward()
        self.optimizer.step()

        # loss_t (pacing) is folded into total_loss but not returned separately
        return total_loss.item(), loss_p.item(), loss_w.item(), loss_s.item()

    def save(self, path):
        if hasattr(self.net, "_orig_mod"):
            torch.save(self.net._orig_mod.state_dict(), path)
        else:
            torch.save(self.net.state_dict(), path)

    def load(self, path):
        sd = torch.load(path, map_location=self.device)
        sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
        if hasattr(self.net, "_orig_mod"):
            self.net._orig_mod.load_state_dict(sd)
        else:
            self.net.load_state_dict(sd)
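

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original interface): runs a
# random batch through the network and one optimizer step through the
# wrapper to check tensor shapes. Because of the relative import above,
# run it from the package root with ``python -m <package>.<module>``
# (package/module names are placeholders).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    wrapper = TorchNetworkWrapper(enable_compile=False)

    # Forward pass: (B, 1200) observations -> policy / value / pacing heads
    obs = rng.random((4, 1200), dtype=np.float32)
    with torch.no_grad():
        p, v_win, v_score, turns = wrapper.net(torch.tensor(obs, device=wrapper.device))
    print("policy", tuple(p.shape), "v_win", tuple(v_win.shape),
          "v_score", tuple(v_score.shape), "turns", tuple(turns.shape))

    # One training step with random (normalized) policy targets
    target_p = rng.random((4, 2000), dtype=np.float32)
    target_p /= target_p.sum(axis=1, keepdims=True)
    losses = wrapper.train_step(
        obs,
        target_p,
        rng.random((4, 1), dtype=np.float32),                # win targets in [0, 1]
        rng.uniform(-1.0, 1.0, (4, 1)).astype(np.float32),   # score targets in [-1, 1]
        rng.random((4, 1), dtype=np.float32),                # pacing targets in [0, 1]
    )
    print("total/policy/win/score losses:", losses)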