""" ByteFight Policy Model: CNN board encoder + Transformer action decoder. Trains on full game trajectories: [board₁, act₁₀, act₁₁, ..., board₂, act₂₀, act₂₁, ..., board_N, ...] CNN compresses each board to 1 vector. Transformer sees interleaved sequence. Only action positions have prediction targets (board positions masked with -100). Board channels (8): my_paint, opp_paint, my_beacon, opp_beacon, wall, hill, powerup, valid Scalars (10): my_stam, my_max_stam, opp_stam, opp_max_stam, my_r, my_c, opp_r, opp_c, turn, area Actions (21): 20 action types + EOS """ import torch import torch.nn as nn import torch.nn.functional as F from dataclasses import dataclass NUM_ACTIONS = 21 EOS_ACTION = 20 BOARD_CHANNELS = 8 MAX_BOARD = 31 @dataclass class Config: d_model: int = 256 n_layer: int = 6 n_head: int = 8 max_seq: int = 6000 n_actions: int = NUM_ACTIONS board_channels: int = BOARD_CHANNELS n_scalars: int = 10 dropout: float = 0.1 class ResBlock(nn.Module): def __init__(self, channels): super().__init__() self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(channels) self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(channels) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) return F.relu(out + x) class BoardEncoder(nn.Module): def __init__(self, cfg: Config): super().__init__() self.stem = nn.Sequential( nn.Conv2d(cfg.board_channels, 64, 3, padding=1, bias=False), nn.BatchNorm2d(64), nn.ReLU()) self.blocks = nn.Sequential( ResBlock(64), ResBlock(64), nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(128), nn.ReLU(), ResBlock(128), ResBlock(128)) self.pool = nn.AdaptiveAvgPool2d(1) self.proj = nn.Linear(128, cfg.d_model) self.scalar_proj = nn.Linear(cfg.n_scalars, cfg.d_model) def forward(self, board_grid, scalars): x = self.stem(board_grid) x = self.blocks(x) x = self.pool(x).flatten(1) return self.proj(x) + self.scalar_proj(scalars) class RMSNorm(nn.Module): def __init__(self, d, eps=1e-6): super().__init__() self.w = nn.Parameter(torch.ones(d)) self.eps = eps def forward(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.w class Attention(nn.Module): def __init__(self, cfg): super().__init__() self.n_head = cfg.n_head self.head_dim = cfg.d_model // cfg.n_head self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False) self.out = nn.Linear(cfg.d_model, cfg.d_model, bias=False) self.dropout = cfg.dropout def forward(self, x, kv_cache=None): B, T, C = x.shape qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.head_dim) q, k, v = qkv.unbind(2) q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) if kv_cache is not None: k_prev, v_prev = kv_cache k = torch.cat([k_prev, k], dim=2) v = torch.cat([v_prev, v], dim=2) new_cache = (k, v) x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=(kv_cache is None)) # causal only for full sequence return self.out(x.transpose(1, 2).reshape(B, T, C)), new_cache class FFN(nn.Module): def __init__(self, cfg): super().__init__() h = (int(cfg.d_model * 8 / 3) + 15) // 16 * 16 self.w1 = nn.Linear(cfg.d_model, h, bias=False) self.w2 = nn.Linear(h, cfg.d_model, bias=False) self.w3 = nn.Linear(cfg.d_model, h, bias=False) def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) class Block(nn.Module): def __init__(self, cfg): super().__init__() self.norm1 = RMSNorm(cfg.d_model) self.attn = Attention(cfg) self.norm2 = RMSNorm(cfg.d_model) self.ffn = FFN(cfg) def forward(self, x, kv_cache=None): attn_out, new_cache = self.attn(self.norm1(x), kv_cache) x = x + attn_out x = x + self.ffn(self.norm2(x)) return x, new_cache class PolicyModel(nn.Module): """CNN encoder + Transformer decoder for full-game action generation.""" def __init__(self, cfg: Config = None): super().__init__() if cfg is None: cfg = Config() self.cfg = cfg self.encoder = BoardEncoder(cfg) self.action_embed = nn.Embedding(cfg.n_actions, cfg.d_model) self.pos_embed = nn.Embedding(cfg.max_seq, cfg.d_model) self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)]) self.norm = RMSNorm(cfg.d_model) self.head = nn.Linear(cfg.d_model, cfg.n_actions, bias=False) def forward(self, boards, scalars, seq_actions, seq_targets, seq_is_board, board_counts): """ Full-game trajectory training. Args: boards: (total_boards, C, H, W) — all boards from all games in batch scalars: (total_boards, n_scalars) seq_actions: (B, T) int64 — action tokens at each position (0 at board positions) seq_targets: (B, T) int64 — targets (-100 at board/padding positions) seq_is_board: (B, T) bool — True at board embedding positions board_counts: (B,) int — number of boards per game (for splitting) Returns: logits: (B, T, n_actions) loss: scalar """ B, T = seq_actions.shape device = seq_actions.device # CNN encode all boards at once board_embs = self.encoder(boards, scalars) # (total_boards, d_model) # Build embedding sequence act_emb = self.action_embed(seq_actions) # (B, T, d_model) seq_emb = act_emb.clone() # Place board embeddings at the right positions board_idx = 0 for b in range(B): positions = seq_is_board[b].nonzero(as_tuple=True)[0] n_boards = board_counts[b].item() for i in range(n_boards): if i < len(positions): seq_emb[b, positions[i]] = board_embs[board_idx] board_idx += 1 # Add positional embeddings pos = self.pos_embed(torch.arange(T, device=device)) x = seq_emb + pos for block in self.blocks: x, _ = block(x) logits = self.head(self.norm(x)) loss = F.cross_entropy( logits.reshape(-1, self.cfg.n_actions), seq_targets.reshape(-1), ignore_index=-100) return logits, loss @torch.no_grad() def generate(self, board_grid, scalars, kv_caches=None, seq_pos=0, max_actions=10, temperature=0.0): """ Generate actions for one turn with KV cache. Args: board_grid: (1, C, H, W) scalars: (1, n_scalars) kv_caches: list of (k, v) per layer from previous turns, or None seq_pos: current position in the sequence (for positional embeddings) Returns: actions: list of action token IDs kv_caches: updated KV caches seq_pos: updated position """ self.eval() device = board_grid.device # CNN encode board board_emb = self.encoder(board_grid, scalars).unsqueeze(1) # (1, 1, d) # Process board embedding token pos = self.pos_embed(torch.tensor([seq_pos], device=device)).unsqueeze(0) # (1, 1, d) x = board_emb + pos if kv_caches is None: kv_caches = [None] * len(self.blocks) new_caches = [] for block, cache in zip(self.blocks, kv_caches): x, new_cache = block(x, cache) new_caches.append(new_cache) kv_caches = new_caches logits = self.head(self.norm(x)) next_logits = logits[0, -1] seq_pos += 1 # First action predicted from board embedding if temperature <= 0: action = next_logits.argmax().item() else: action = torch.multinomial( F.softmax(next_logits / temperature, dim=-1), 1).item() actions = [] if action == EOS_ACTION: return actions, kv_caches, seq_pos actions.append(action) # Generate remaining actions autoregressively for _ in range(max_actions - 1): act_emb = self.action_embed(torch.tensor([[action]], device=device)) pos = self.pos_embed(torch.tensor([seq_pos], device=device)).unsqueeze(0) x = act_emb + pos new_caches = [] for block, cache in zip(self.blocks, kv_caches): x, new_cache = block(x, cache) new_caches.append(new_cache) kv_caches = new_caches logits = self.head(self.norm(x)) next_logits = logits[0, -1] seq_pos += 1 if temperature <= 0: action = next_logits.argmax().item() else: action = torch.multinomial( F.softmax(next_logits / temperature, dim=-1), 1).item() if action == EOS_ACTION: break actions.append(action) return actions, kv_caches, seq_pos def count_params(self): return sum(p.numel() for p in self.parameters()) if __name__ == '__main__': cfg = Config() model = PolicyModel(cfg) print(f"Total params: {model.count_params():,}") for h, w in [(8, 8), (15, 15), (31, 31)]: board = torch.randn(2, BOARD_CHANNELS, h, w) s = torch.randn(2, 10) emb = model.encoder(board, s) print(f" Board {h}×{w} → {emb.shape}") # Test KV cache generation across multiple turns kv = None pos = 0 for turn in range(5): board = torch.randn(1, BOARD_CHANNELS, 15, 15) s = torch.randn(1, 10) acts, kv, pos = model.generate(board, s, kv_caches=kv, seq_pos=pos) cache_size = kv[0][0].shape[2] if kv else 0 print(f" Turn {turn}: actions={acts}, seq_pos={pos}, cache_len={cache_size}")