"""
PyTorch implementation of Transformer-based AlphaZero network.
Processes the game state as a set of interacting cards (Tokens) rather than a flat vector.
"""
import sys
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Import config constants
from .training_config import DROPOUT, HIDDEN_SIZE, N_HEADS, NUM_LAYERS


class Tokenizer(nn.Module):
"""
Slices the 1200-float input vector into semantic tokens:
- 1 Global Token (144 features: 20 basic + 124 heuristics/misc)
- 22 Card Tokens (6 Stage, 6 Live, 10 Hand) - 48 features each
"""
def __init__(self, d_model: int):
super().__init__()
self.d_model = d_model
self.card_size = 48
# Global (20) + Tail (1076:1200 = 124) = 144 features
self.global_size = 144
# Projections
self.global_proj = nn.Linear(self.global_size, d_model)
self.card_proj = nn.Linear(self.card_size, d_model)
        # Zone embeddings: 0=Global, 1=P0_Stage, 2=P1_Stage, 3=P0_Live, 4=P1_Live, 5=P0_Hand
        # (slots 6-7 are unused headroom)
        self.zone_embedding = nn.Embedding(8, d_model)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, 1200)
batch_size = x.shape[0]
tokens = []
# 1. Global Token
# Basic Globals (0-20) + Tail Heuristics (1076-1200)
global_feat = torch.cat([x[:, 0:20], x[:, 1076:1200]], dim=1)
t_global = self.global_proj(global_feat) # (B, d_model)
t_global = t_global + self.zone_embedding(torch.zeros(batch_size, dtype=torch.long, device=x.device))
tokens.append(t_global.unsqueeze(1))
# 2. Card Tokens helper
def make_cards(start_idx, count, zone_id):
card_tokens = []
for i in range(count):
                s = start_idx + i * self.card_size
                e = s + self.card_size
c_vec = x[:, s:e]
c_emb = self.card_proj(c_vec)
c_emb = c_emb + self.zone_embedding(
torch.full((batch_size,), zone_id, dtype=torch.long, device=x.device)
)
card_tokens.append(c_emb.unsqueeze(1))
return card_tokens
# P0 Stage (Zone 1) - starts at 20
tokens.extend(make_cards(20, 3, 1))
# P1 Stage (Zone 2) - starts at 164
tokens.extend(make_cards(164, 3, 2))
# P0 Live (Zone 3) - starts at 308
tokens.extend(make_cards(308, 3, 3))
# P1 Live (Zone 4) - starts at 452
tokens.extend(make_cards(452, 3, 4))
# P0 Hand (Zone 5) - starts at 596
tokens.extend(make_cards(596, 10, 5))
# SeqLen = 1 + 3 + 3 + 3 + 3 + 10 = 23
        return torch.cat(tokens, dim=1)

class TransformerCardNet(nn.Module):
    def __init__(self, input_size=1200, action_size=2000):
        super().__init__()
        self.input_size = input_size
        self.action_size = action_size
        self.d_model = HIDDEN_SIZE
# 1. Tokenizer
self.tokenizer = Tokenizer(self.d_model)
# 2. Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=self.d_model, nhead=N_HEADS, dim_feedforward=self.d_model * 4, dropout=DROPOUT, batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=NUM_LAYERS)
# 3. Policy Heads
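        # The per-token heads below are scattered into one flat 2000-way action
        # vector in forward(); the reconstructed index ranges are:
        #   0        Pass                  1..30    Play (hand slot x target 0-2)
        #   100-109  Energy (hand)         200-229  Ability (stage slot x 10)
        #   300-309  Mulligan (hand)       400-409  LiveSet (hand)
        #   580-585  Color choice          600-602  Live success select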
self.hand_action_proj = nn.Linear(self.d_model, 6) # [Play0, Play1, Play2, Energy, Mull, LiveSet]
self.stage_action_proj = nn.Linear(self.d_model, 10) # [Ability0..9]
self.live_action_proj = nn.Linear(self.d_model, 1) # [SelectSuccess]
        self.global_action_proj = nn.Linear(self.d_model, 10)  # [0:Pass, 1..6:Colors, 7..9 unused headroom]
# Value Heads
# Win-rate head (Sigmoid)
self.value_win_head = nn.Sequential(nn.Linear(self.d_model, 128), nn.ReLU(), nn.Linear(128, 1), nn.Sigmoid())
# Score differential head (Tanh -1..1)
self.value_score_head = nn.Sequential(nn.Linear(self.d_model, 128), nn.ReLU(), nn.Linear(128, 1), nn.Tanh())
# Auxiliary Pacing Head (Progress 0..1)
self.turns_head = nn.Sequential(nn.Linear(self.d_model, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())
def forward(self, x):
batch_size = x.size(0)
tokens = self.tokenizer(x)
encoded = self.transformer(tokens) # (B, 23, d_model)
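        # Token layout from the Tokenizer: 0 = Global/CLS, 1-3 = P0 Stage,
        # 4-6 = P1 Stage, 7-9 = P0 Live, 10-12 = P1 Live, 13-22 = P0 Hand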
# --- Policy Reconstruction ---
        logits = torch.zeros(batch_size, self.action_size, device=x.device)
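        # Action slots never written below keep a logit of 0; they carry no
        # learned signal and are removed by legal-action masking in predict()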
# Global Actions
global_tok = encoded[:, 0, :]
g_logits = self.global_action_proj(global_tok)
logits[:, 0] = g_logits[:, 0] # Pass
logits[:, 580:586] = g_logits[:, 1:7] # Colors
# Hand Actions (Tokens 13-22)
hand_toks = encoded[:, 13:23, :]
h_logits = self.hand_action_proj(hand_toks) # (B, 10, 6)
for i in range(10):
logits[:, 1 + 3 * i : 1 + 3 * i + 3] = h_logits[:, i, 0:3]
logits[:, 100 + i] = h_logits[:, i, 3] # Energy
logits[:, 300 + i] = h_logits[:, i, 4] # Mull
logits[:, 400 + i] = h_logits[:, i, 5] # LiveSet
# Stage Actions (Tokens 1-3)
stage_toks = encoded[:, 1:4, :]
s_logits = self.stage_action_proj(stage_toks) # (B, 3, 10)
for i in range(3):
logits[:, 200 + 10 * i : 200 + 10 * i + 10] = s_logits[:, i, :]
# Live Zone Actions (Tokens 7-9)
live_toks = encoded[:, 7:10, :]
l_logits = self.live_action_proj(live_toks).squeeze(-1) # (B, 3)
logits[:, 600:603] = l_logits
# --- Value Heads ---
cls_token = encoded[:, 0, :]
val_win = self.value_win_head(cls_token) # (B, 1)
val_score = self.value_score_head(cls_token) # (B, 1)
turns_pred = self.turns_head(cls_token) # (B, 1)
        return F.softmax(logits, dim=1), val_win, val_score, turns_pred

class TorchNetworkWrapper:
"""Wrapper to interface with MCTS/Training loop"""
    def __init__(self, config=None, device=None, enable_compile=True):
        # NOTE: `config` is accepted for interface compatibility but is currently unused
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
self.net = TransformerCardNet().to(self.device)
        if enable_compile and hasattr(torch, "compile") and "win" not in sys.platform:
try:
print("Compiling Transformer with torch.compile...")
self.net = torch.compile(self.net, mode="reduce-overhead")
except Exception as e:
print(f"Compile failed: {e}")
lr = 0.0003
self.optimizer = optim.AdamW(self.net.parameters(), lr=lr, weight_decay=1e-4)
def predict(self, state) -> Tuple[np.ndarray, float]:
self.net.eval()
obs = state.get_observation()
        # Pad or truncate the observation to the fixed 1200-float input size
        if len(obs) != 1200:
            if len(obs) < 1200:
                obs = list(obs) + [0.0] * (1200 - len(obs))
            else:
                obs = obs[:1200]
x = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device)
with torch.no_grad():
p_soft, v_win, v_score, t_pred = self.net(x)
p = p_soft.cpu().numpy()[0]
        v = v_win.item()  # win probability in [0, 1] from the sigmoid head; this is what MCTS consumes
# Mask illegal
legal = state.get_legal_actions()
masked = p * legal
sum_p = masked.sum()
        if sum_p > 0:
            masked /= sum_p
        else:
            # Network assigned no mass to any legal action: fall back to a
            # uniform distribution over the legal set (guarding against an
            # empty mask, which would otherwise divide by zero)
            masked = legal.astype(np.float32) / max(legal.sum(), 1)
return masked, v
def train_step(self, obs, target_p, target_v_win, target_v_score, target_turns):
"""
obs: (B, 1200)
target_p: (B, 2000)
target_v_win: (B, 1)
target_v_score: (B, 1)
target_turns: (B, 1)
"""
self.net.train()
self.optimizer.zero_grad()
x = torch.tensor(obs, dtype=torch.float32).to(self.device)
t_p = torch.tensor(target_p, dtype=torch.float32).to(self.device)
t_w = torch.tensor(target_v_win, dtype=torch.float32).to(self.device)
t_s = torch.tensor(target_v_score, dtype=torch.float32).to(self.device)
t_t = torch.tensor(target_turns, dtype=torch.float32).to(self.device)
p, w, s, t = self.net(x)
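        # Policy loss: cross-entropy between the policy targets (the MCTS visit
        # distribution in AlphaZero-style training) and the softmax output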
loss_p = -torch.sum(t_p * torch.log(p + 1e-8)) / x.size(0)
loss_w = F.binary_cross_entropy(w, t_w)
loss_s = F.mse_loss(s, t_s)
loss_t = F.mse_loss(t, t_t)
total_loss = loss_p + loss_w + loss_s + loss_t
total_loss.backward()
self.optimizer.step()
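        # NOTE: the pacing loss (loss_t) contributes to the gradient but is
        # not included in the returned tuple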
return total_loss.item(), loss_p.item(), loss_w.item(), loss_s.item()
def save(self, path):
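        # torch.compile wraps the module and prefixes parameter names with
        # "_orig_mod."; save the underlying module so the checkpoint also
        # loads into an uncompiled network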
if hasattr(self.net, "_orig_mod"):
torch.save(self.net._orig_mod.state_dict(), path)
else:
torch.save(self.net.state_dict(), path)
def load(self, path):
sd = torch.load(path, map_location=self.device)
sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
if hasattr(self.net, "_orig_mod"):
self.net._orig_mod.load_state_dict(sd)
else:
self.net.load_state_dict(sd)
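

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline): push a
    # random batch through the raw network and verify output shapes. Run as a
    # module so the relative config import resolves, e.g.:
    #   python -m ai.models.network_torch
    net = TransformerCardNet()
    net.eval()
    dummy = torch.randn(2, 1200)
    with torch.no_grad():
        policy, v_win, v_score, turns = net(dummy)
    assert policy.shape == (2, 2000)
    assert v_win.shape == v_score.shape == turns.shape == (2, 1)
    print("policy:", policy.shape, "win:", v_win.shape, "score:", v_score.shape, "turns:", turns.shape)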