trioskosmos committed
Commit f9b81b4 · verified · 1 Parent(s): 5c0c5f6

Upload ai/models/network_torch.py with huggingface_hub
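For context, a commit like this is typically produced by a call along these lines (a minimal sketch; the repo_id is a placeholder, not taken from this page):

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="ai/models/network_torch.py",  # local file to upload
    path_in_repo="ai/models/network_torch.py",     # destination path in the repo
    repo_id="trioskosmos/<repo>",                  # placeholder; the actual repo is not shown here
    repo_type="model",
)

upload_file's default commit message, "Upload {path_in_repo} with huggingface_hub", matches the message above.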

Files changed (1)
  1. ai/models/network_torch.py +244 -0
ai/models/network_torch.py ADDED
@@ -0,0 +1,244 @@
+ """
+ PyTorch implementation of a Transformer-based AlphaZero network.
+ Processes the game state as a set of interacting card tokens rather than as a flat vector.
+ """
+
+ import sys
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+
+ # Import config constants
+ from .training_config import DROPOUT, HIDDEN_SIZE, N_HEADS, NUM_LAYERS
+
+
+ class Tokenizer(nn.Module):
+     """
+     Slices the 1200-float input vector into semantic tokens:
+     - 1 Global Token (144 features: 20 basic + 124 heuristics/misc)
+     - 22 Card Tokens (6 Stage, 6 Live, 10 Hand), 48 features each
+     """
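+     # For reference, the observation layout implied by the slicing in forward()
+     # (offsets into the 1200-float vector, derived from this file, not a separate spec):
+     #   [0, 20)      global scalars
+     #   [20, 164)    P0 Stage, 3 cards x 48
+     #   [164, 308)   P1 Stage, 3 cards x 48
+     #   [308, 452)   P0 Live, 3 cards x 48
+     #   [452, 596)   P1 Live, 3 cards x 48
+     #   [596, 1076)  P0 Hand, 10 cards x 48
+     #   [1076, 1200) tail heuristics, folded into the global token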
+
+     def __init__(self, d_model: int):
+         super().__init__()
+         self.d_model = d_model
+
+         self.card_size = 48
+         # Global (20) + tail (1076:1200 = 124) = 144 features
+         self.global_size = 144
+
+         # Projections
+         self.global_proj = nn.Linear(self.global_size, d_model)
+         self.card_proj = nn.Linear(self.card_size, d_model)
+
+         # Zone embeddings: 0=Global, 1=P0_Stage, 2=P1_Stage, 3=P0_Live, 4=P1_Live, 5=P0_Hand
+         self.zone_embedding = nn.Embedding(8, d_model)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: (B, 1200)
+         batch_size = x.shape[0]
+
+         tokens = []
+
+         # 1. Global token
+         # Basic globals (0:20) + tail heuristics (1076:1200)
+         global_feat = torch.cat([x[:, 0:20], x[:, 1076:1200]], dim=1)
+
+         t_global = self.global_proj(global_feat)  # (B, d_model)
+         t_global = t_global + self.zone_embedding(torch.zeros(batch_size, dtype=torch.long, device=x.device))
+         tokens.append(t_global.unsqueeze(1))
+
+         # 2. Card-token helper
+         def make_cards(start_idx, count, zone_id):
+             card_tokens = []
+             for i in range(count):
+                 s = start_idx + i * 48
+                 e = s + 48
+                 c_vec = x[:, s:e]
+                 c_emb = self.card_proj(c_vec)
+                 c_emb = c_emb + self.zone_embedding(
+                     torch.full((batch_size,), zone_id, dtype=torch.long, device=x.device)
+                 )
+                 card_tokens.append(c_emb.unsqueeze(1))
+             return card_tokens
+
+         # P0 Stage (zone 1) - starts at 20
+         tokens.extend(make_cards(20, 3, 1))
+         # P1 Stage (zone 2) - starts at 164
+         tokens.extend(make_cards(164, 3, 2))
+         # P0 Live (zone 3) - starts at 308
+         tokens.extend(make_cards(308, 3, 3))
+         # P1 Live (zone 4) - starts at 452
+         tokens.extend(make_cards(452, 3, 4))
+         # P0 Hand (zone 5) - starts at 596
+         tokens.extend(make_cards(596, 10, 5))
+
+         # SeqLen = 1 + 3 + 3 + 3 + 3 + 10 = 23
+         return torch.cat(tokens, dim=1)
+
+
+ class TransformerCardNet(nn.Module):
+     def __init__(self, input_size=1200, action_size=2000):
+         super().__init__()
+
+         self.d_model = HIDDEN_SIZE
+         self.input_size = input_size
+         self.action_size = action_size
+
+         # 1. Tokenizer
+         self.tokenizer = Tokenizer(self.d_model)
+
+         # 2. Transformer encoder
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=self.d_model, nhead=N_HEADS, dim_feedforward=self.d_model * 4, dropout=DROPOUT, batch_first=True
+         )
+         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=NUM_LAYERS)
+
+         # 3. Policy heads
+         self.hand_action_proj = nn.Linear(self.d_model, 6)  # [Play0, Play1, Play2, Energy, Mull, LiveSet]
+         self.stage_action_proj = nn.Linear(self.d_model, 10)  # [Ability0..9]
+         self.live_action_proj = nn.Linear(self.d_model, 1)  # [SelectSuccess]
+         self.global_action_proj = nn.Linear(self.d_model, 10)  # [0:Pass, 1..6:Colors, ...]
+
+         # Value heads
+         # Win-rate head (sigmoid, 0..1)
+         self.value_win_head = nn.Sequential(nn.Linear(self.d_model, 128), nn.ReLU(), nn.Linear(128, 1), nn.Sigmoid())
+         # Score-differential head (tanh, -1..1)
+         self.value_score_head = nn.Sequential(nn.Linear(self.d_model, 128), nn.ReLU(), nn.Linear(128, 1), nn.Tanh())
+         # Auxiliary pacing head (game progress, 0..1)
+         self.turns_head = nn.Sequential(nn.Linear(self.d_model, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())
+
+     def forward(self, x):
+         batch_size = x.size(0)
+         tokens = self.tokenizer(x)
+         encoded = self.transformer(tokens)  # (B, 23, d_model)
+
+         # --- Policy reconstruction ---
+         logits = torch.zeros(batch_size, self.action_size, device=x.device)
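+         # Flat action-space map implied by the assignments below (derived from
+         # this forward pass, not an external spec):
+         #   0         Pass
+         #   1..30     Play hand card i to stage slot j (1 + 3*i + j)
+         #   100..109  Attach energy from hand card i
+         #   200..229  Stage card i, ability j (200 + 10*i + j)
+         #   300..309  Mulligan hand card i
+         #   400..409  Set hand card i to the Live zone
+         #   580..585  Color choices
+         #   600..602  Select Live-zone card i
+         # Remaining indices stay at logit 0 and are masked out by the caller.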
+
+         # Global actions
+         global_tok = encoded[:, 0, :]
+         g_logits = self.global_action_proj(global_tok)
+         logits[:, 0] = g_logits[:, 0]  # Pass
+         logits[:, 580:586] = g_logits[:, 1:7]  # Colors
+
+         # Hand actions (tokens 13-22)
+         hand_toks = encoded[:, 13:23, :]
+         h_logits = self.hand_action_proj(hand_toks)  # (B, 10, 6)
+         for i in range(10):
+             logits[:, 1 + 3 * i : 1 + 3 * i + 3] = h_logits[:, i, 0:3]
+             logits[:, 100 + i] = h_logits[:, i, 3]  # Energy
+             logits[:, 300 + i] = h_logits[:, i, 4]  # Mull
+             logits[:, 400 + i] = h_logits[:, i, 5]  # LiveSet
+
+         # Stage actions (tokens 1-3)
+         stage_toks = encoded[:, 1:4, :]
+         s_logits = self.stage_action_proj(stage_toks)  # (B, 3, 10)
+         for i in range(3):
+             logits[:, 200 + 10 * i : 200 + 10 * i + 10] = s_logits[:, i, :]
+
+         # Live-zone actions (tokens 7-9)
+         live_toks = encoded[:, 7:10, :]
+         l_logits = self.live_action_proj(live_toks).squeeze(-1)  # (B, 3)
+         logits[:, 600:603] = l_logits
+
+         # --- Value heads ---
+         cls_token = encoded[:, 0, :]
+         val_win = self.value_win_head(cls_token)  # (B, 1)
+         val_score = self.value_score_head(cls_token)  # (B, 1)
+         turns_pred = self.turns_head(cls_token)  # (B, 1)
+
+         return F.softmax(logits, dim=1), val_win, val_score, turns_pred
+
+
+ class TorchNetworkWrapper:
+     """Wrapper to interface with the MCTS/training loop."""
+
+     def __init__(self, config=None, device=None, enable_compile=True):
+         self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         print(f"Using device: {self.device}")
+
+         self.net = TransformerCardNet().to(self.device)
+
+         # Skip torch.compile on Windows, where Inductor support is limited
+         if enable_compile and hasattr(torch, "compile") and not sys.platform.startswith("win"):
+             try:
+                 print("Compiling Transformer with torch.compile...")
+                 self.net = torch.compile(self.net, mode="reduce-overhead")
+             except Exception as e:
+                 print(f"Compile failed: {e}")
+
+         lr = 0.0003
+         self.optimizer = optim.AdamW(self.net.parameters(), lr=lr, weight_decay=1e-4)
+
+     def predict(self, state) -> Tuple[np.ndarray, float]:
+         self.net.eval()
+         obs = state.get_observation()
+         # Pad or truncate the observation to exactly 1200 floats
+         if len(obs) != 1200:
+             if len(obs) < 1200:
+                 obs = list(obs) + [0.0] * (1200 - len(obs))
+             else:
+                 obs = obs[:1200]
+
+         x = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             p_soft, v_win, v_score, t_pred = self.net(x)
+
+         p = p_soft.cpu().numpy()[0]
+         v = v_win.item()  # win probability in [0, 1] from the sigmoid head
+
+         # Mask illegal actions and renormalize
+         legal = state.get_legal_actions()
+         masked = p * legal
+         sum_p = masked.sum()
+         if sum_p > 0:
+             masked /= sum_p
+         else:
+             # Fall back to a uniform distribution over legal actions
+             masked = legal.astype(np.float32) / legal.sum()
+
+         return masked, v
+
+     def train_step(self, obs, target_p, target_v_win, target_v_score, target_turns):
+         """
+         obs: (B, 1200)
+         target_p: (B, 2000)
+         target_v_win: (B, 1)
+         target_v_score: (B, 1)
+         target_turns: (B, 1)
+         """
+         self.net.train()
+         self.optimizer.zero_grad()
+
+         x = torch.tensor(obs, dtype=torch.float32).to(self.device)
+         t_p = torch.tensor(target_p, dtype=torch.float32).to(self.device)
+         t_w = torch.tensor(target_v_win, dtype=torch.float32).to(self.device)
+         t_s = torch.tensor(target_v_score, dtype=torch.float32).to(self.device)
+         t_t = torch.tensor(target_turns, dtype=torch.float32).to(self.device)
+
+         p, w, s, t = self.net(x)
+
+         # Policy cross-entropy, win-rate BCE, score/turns MSE
+         loss_p = -torch.sum(t_p * torch.log(p + 1e-8)) / x.size(0)
+         loss_w = F.binary_cross_entropy(w, t_w)
+         loss_s = F.mse_loss(s, t_s)
+         loss_t = F.mse_loss(t, t_t)
+
+         total_loss = loss_p + loss_w + loss_s + loss_t
+         total_loss.backward()
+         self.optimizer.step()
+
+         return total_loss.item(), loss_p.item(), loss_w.item(), loss_s.item()
+
+     def save(self, path):
+         # torch.compile wraps the model; save the underlying module's weights
+         if hasattr(self.net, "_orig_mod"):
+             torch.save(self.net._orig_mod.state_dict(), path)
+         else:
+             torch.save(self.net.state_dict(), path)
+
+     def load(self, path):
+         sd = torch.load(path, map_location=self.device)
+         # Strip the torch.compile prefix so checkpoints load either way
+         sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
+         if hasattr(self.net, "_orig_mod"):
+             self.net._orig_mod.load_state_dict(sd)
+         else:
+             self.net.load_state_dict(sd)
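
A quick smoke test of the shapes documented above (a sketch, assuming it runs inside the package so the relative training_config import resolves and HIDDEN_SIZE is divisible by N_HEADS; the random inputs stand in for real observations and MCTS targets):

import numpy as np
import torch

net = TransformerCardNet()
x = torch.randn(4, 1200)  # batch of 4 dummy observations
policy, v_win, v_score, turns = net(x)
assert policy.shape == (4, 2000)
assert torch.allclose(policy.sum(dim=1), torch.ones(4), atol=1e-5)  # softmax rows sum to 1
assert v_win.shape == v_score.shape == turns.shape == (4, 1)

wrapper = TorchNetworkWrapper(enable_compile=False)
obs = np.random.rand(4, 1200).astype(np.float32)
target_p = np.full((4, 2000), 1.0 / 2000, dtype=np.float32)  # dummy uniform policy target
target_w = np.random.rand(4, 1).astype(np.float32)           # win-rate targets in [0, 1]
target_s = np.random.rand(4, 1).astype(np.float32) * 2 - 1   # score targets in [-1, 1]
target_t = np.random.rand(4, 1).astype(np.float32)           # pacing targets in [0, 1]
total, lp, lw, ls = wrapper.train_step(obs, target_p, target_w, target_s, target_t)
print(f"total={total:.4f} policy={lp:.4f} win={lw:.4f} score={ls:.4f}")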