Spaces:
Running
Running
| """ | |
| PyTorch implementation of Transformer-based AlphaZero network. | |
| Processes the game state as a set of interacting cards (Tokens) rather than a flat vector. | |
| """ | |
| from typing import Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| # Import config constants | |
| from .training_config import DROPOUT, HIDDEN_SIZE, N_HEADS, NUM_LAYERS | |
| class Tokenizer(nn.Module): | |
| """ | |
| Slices the 1200-float input vector into semantic tokens: | |
| - 1 Global Token (144 features: 20 basic + 124 heuristics/misc) | |
| - 22 Card Tokens (6 Stage, 6 Live, 10 Hand) - 48 features each | |
| """ | |
| def __init__(self, d_model: int): | |
| super().__init__() | |
| self.d_model = d_model | |
| self.card_size = 48 | |
| # Global (20) + Tail (1076:1200 = 124) = 144 features | |
| self.global_size = 144 | |
| # Projections | |
| self.global_proj = nn.Linear(self.global_size, d_model) | |
| self.card_proj = nn.Linear(self.card_size, d_model) | |
| # Zone Embeddings: 0=Global, 1=P0_Stage, 2=P1_Stage, 3=P0_Live, 4=P1_Live, 5=P0_Hand | |
| self.zone_embedding = nn.Embedding(8, d_model) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| # x: (B, 1200) | |
| batch_size = x.shape[0] | |
| tokens = [] | |
| # 1. Global Token | |
| # Basic Globals (0-20) + Tail Heuristics (1076-1200) | |
| global_feat = torch.cat([x[:, 0:20], x[:, 1076:1200]], dim=1) | |
| t_global = self.global_proj(global_feat) # (B, d_model) | |
| t_global = t_global + self.zone_embedding(torch.zeros(batch_size, dtype=torch.long, device=x.device)) | |
| tokens.append(t_global.unsqueeze(1)) | |
| # 2. Card Tokens helper | |
| def make_cards(start_idx, count, zone_id): | |
| card_tokens = [] | |
| for i in range(count): | |
| s = start_idx + i * 48 | |
| e = s + 48 | |
| c_vec = x[:, s:e] | |
| c_emb = self.card_proj(c_vec) | |
| c_emb = c_emb + self.zone_embedding( | |
| torch.full((batch_size,), zone_id, dtype=torch.long, device=x.device) | |
| ) | |
| card_tokens.append(c_emb.unsqueeze(1)) | |
| return card_tokens | |
| # P0 Stage (Zone 1) - starts at 20 | |
| tokens.extend(make_cards(20, 3, 1)) | |
| # P1 Stage (Zone 2) - starts at 164 | |
| tokens.extend(make_cards(164, 3, 2)) | |
| # P0 Live (Zone 3) - starts at 308 | |
| tokens.extend(make_cards(308, 3, 3)) | |
| # P1 Live (Zone 4) - starts at 452 | |
| tokens.extend(make_cards(452, 3, 4)) | |
| # P0 Hand (Zone 5) - starts at 596 | |
| tokens.extend(make_cards(596, 10, 5)) | |
| # SeqLen = 1 + 3 + 3 + 3 + 3 + 10 = 23 | |
| return torch.cat(tokens, dim=1) | |
| class TransformerCardNet(nn.Module): | |
| def __init__(self, input_size=1200, action_size=2000): | |
| super().__init__() | |
| self.d_model = HIDDEN_SIZE | |
| # 1. Tokenizer | |
| self.tokenizer = Tokenizer(self.d_model) | |
| # 2. Transformer Encoder | |
| encoder_layer = nn.TransformerEncoderLayer( | |
| d_model=self.d_model, nhead=N_HEADS, dim_feedforward=self.d_model * 4, dropout=DROPOUT, batch_first=True | |
| ) | |
| self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=NUM_LAYERS) | |
| # 3. Policy Heads | |
| self.hand_action_proj = nn.Linear(self.d_model, 6) # [Play0, Play1, Play2, Energy, Mull, LiveSet] | |
| self.stage_action_proj = nn.Linear(self.d_model, 10) # [Ability0..9] | |
| self.live_action_proj = nn.Linear(self.d_model, 1) # [SelectSuccess] | |
| self.global_action_proj = nn.Linear(self.d_model, 10) # [0:Pass, 1..6:Colors, ... ] | |
| # Value Heads | |
| # Win-rate head (Sigmoid) | |
| self.value_win_head = nn.Sequential(nn.Linear(self.d_model, 128), nn.ReLU(), nn.Linear(128, 1), nn.Sigmoid()) | |
| # Score differential head (Tanh -1..1) | |
| self.value_score_head = nn.Sequential(nn.Linear(self.d_model, 128), nn.ReLU(), nn.Linear(128, 1), nn.Tanh()) | |
| # Auxiliary Pacing Head (Progress 0..1) | |
| self.turns_head = nn.Sequential(nn.Linear(self.d_model, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()) | |
| def forward(self, x): | |
| batch_size = x.size(0) | |
| tokens = self.tokenizer(x) | |
| encoded = self.transformer(tokens) # (B, 23, d_model) | |
| # --- Policy Reconstruction --- | |
| logits = torch.zeros(batch_size, 2000, device=x.device) | |
| # Global Actions | |
| global_tok = encoded[:, 0, :] | |
| g_logits = self.global_action_proj(global_tok) | |
| logits[:, 0] = g_logits[:, 0] # Pass | |
| logits[:, 580:586] = g_logits[:, 1:7] # Colors | |
| # Hand Actions (Tokens 13-22) | |
| hand_toks = encoded[:, 13:23, :] | |
| h_logits = self.hand_action_proj(hand_toks) # (B, 10, 6) | |
| for i in range(10): | |
| logits[:, 1 + 3 * i : 1 + 3 * i + 3] = h_logits[:, i, 0:3] | |
| logits[:, 100 + i] = h_logits[:, i, 3] # Energy | |
| logits[:, 300 + i] = h_logits[:, i, 4] # Mull | |
| logits[:, 400 + i] = h_logits[:, i, 5] # LiveSet | |
| # Stage Actions (Tokens 1-3) | |
| stage_toks = encoded[:, 1:4, :] | |
| s_logits = self.stage_action_proj(stage_toks) # (B, 3, 10) | |
| for i in range(3): | |
| logits[:, 200 + 10 * i : 200 + 10 * i + 10] = s_logits[:, i, :] | |
| # Live Zone Actions (Tokens 7-9) | |
| live_toks = encoded[:, 7:10, :] | |
| l_logits = self.live_action_proj(live_toks).squeeze(-1) # (B, 3) | |
| logits[:, 600:603] = l_logits | |
| # --- Value Heads --- | |
| cls_token = encoded[:, 0, :] | |
| val_win = self.value_win_head(cls_token) # (B, 1) | |
| val_score = self.value_score_head(cls_token) # (B, 1) | |
| turns_pred = self.turns_head(cls_token) # (B, 1) | |
| return F.softmax(logits, dim=1), val_win, val_score, turns_pred | |
| class TorchNetworkWrapper: | |
| """Wrapper to interface with MCTS/Training loop""" | |
| def __init__(self, config=None, device=None, enable_compile=True): | |
| self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"Using device: {self.device}") | |
| self.net = TransformerCardNet().to(self.device) | |
| if enable_compile and hasattr(torch, "compile") and "win" not in torch.sys.platform: | |
| try: | |
| print("Compiling Transformer with torch.compile...") | |
| self.net = torch.compile(self.net, mode="reduce-overhead") | |
| except Exception as e: | |
| print(f"Compile failed: {e}") | |
| lr = 0.0003 | |
| self.optimizer = optim.AdamW(self.net.parameters(), lr=lr, weight_decay=1e-4) | |
| def predict(self, state) -> Tuple[np.ndarray, float]: | |
| self.net.eval() | |
| obs = state.get_observation() | |
| if len(obs) != 1200: | |
| if len(obs) < 1200: | |
| obs = obs + [0.0] * (1200 - len(obs)) | |
| else: | |
| obs = obs[:1200] | |
| x = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device) | |
| with torch.no_grad(): | |
| p_soft, v_win, v_score, t_pred = self.net(x) | |
| p = p_soft.cpu().numpy()[0] | |
| v = v_win.item() # MCTS typically uses win probability [0,1] or [-1,1] | |
| # Mask illegal | |
| legal = state.get_legal_actions() | |
| masked = p * legal | |
| sum_p = masked.sum() | |
| if sum_p > 0: | |
| masked /= sum_p | |
| else: | |
| masked = legal.astype(np.float32) / legal.sum() | |
| return masked, v | |
| def train_step(self, obs, target_p, target_v_win, target_v_score, target_turns): | |
| """ | |
| obs: (B, 1200) | |
| target_p: (B, 2000) | |
| target_v_win: (B, 1) | |
| target_v_score: (B, 1) | |
| target_turns: (B, 1) | |
| """ | |
| self.net.train() | |
| self.optimizer.zero_grad() | |
| x = torch.tensor(obs, dtype=torch.float32).to(self.device) | |
| t_p = torch.tensor(target_p, dtype=torch.float32).to(self.device) | |
| t_w = torch.tensor(target_v_win, dtype=torch.float32).to(self.device) | |
| t_s = torch.tensor(target_v_score, dtype=torch.float32).to(self.device) | |
| t_t = torch.tensor(target_turns, dtype=torch.float32).to(self.device) | |
| p, w, s, t = self.net(x) | |
| loss_p = -torch.sum(t_p * torch.log(p + 1e-8)) / x.size(0) | |
| loss_w = F.binary_cross_entropy(w, t_w) | |
| loss_s = F.mse_loss(s, t_s) | |
| loss_t = F.mse_loss(t, t_t) | |
| total_loss = loss_p + loss_w + loss_s + loss_t | |
| total_loss.backward() | |
| self.optimizer.step() | |
| return total_loss.item(), loss_p.item(), loss_w.item(), loss_s.item() | |
| def save(self, path): | |
| if hasattr(self.net, "_orig_mod"): | |
| torch.save(self.net._orig_mod.state_dict(), path) | |
| else: | |
| torch.save(self.net.state_dict(), path) | |
| def load(self, path): | |
| sd = torch.load(path, map_location=self.device) | |
| sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()} | |
| if hasattr(self.net, "_orig_mod"): | |
| self.net._orig_mod.load_state_dict(sd) | |
| else: | |
| self.net.load_state_dict(sd) | |