| """ |
| ByteFight Policy Model: CNN board encoder + Transformer action decoder. |
| |
| Trains on full game trajectories: |
| [board₁, act₁₀, act₁₁, ..., board₂, act₂₀, act₂₁, ..., board_N, ...] |
| |
| CNN compresses each board to 1 vector. Transformer sees interleaved sequence. |
| Only action positions have prediction targets (board positions masked with -100). |
| |
| Board channels (8): my_paint, opp_paint, my_beacon, opp_beacon, wall, hill, powerup, valid |
| Scalars (10): my_stam, my_max_stam, opp_stam, opp_max_stam, my_r, my_c, opp_r, opp_c, turn, area |
| Actions (21): 20 action types + EOS |
| """ |
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional
|
|
NUM_ACTIONS = 21       # 20 action types + EOS
EOS_ACTION = 20
BOARD_CHANNELS = 8
MAX_BOARD = 31         # maximum board side length (the encoder pools adaptively)
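

# Illustrative only — how one game's turns become the interleaved tensors that
# PolicyModel.forward() expects. This helper is a sketch of the assumed data
# layout, not the project's actual pipeline.
def build_training_sequence(turn_actions, pad_to=None):
    """turn_actions: list of per-turn action lists, e.g. [[3, 7], [12]].

    Targets are next-token: each board slot targets the turn's first action
    (or EOS for an empty turn), each action targets the following one, and
    the last action of a turn targets EOS. Padding is masked with -100.
    """
    seq_actions, seq_targets, seq_is_board = [], [], []
    for actions in turn_actions:
        targets = actions + [EOS_ACTION]   # EOS is a target only, never an input
        seq_actions.append(0)              # placeholder token at the board slot
        seq_is_board.append(True)
        seq_targets.append(targets[0])
        for act, nxt in zip(actions, targets[1:]):
            seq_actions.append(act)
            seq_is_board.append(False)
            seq_targets.append(nxt)
    if pad_to is not None:
        pad = pad_to - len(seq_actions)
        seq_actions += [0] * pad
        seq_targets += [-100] * pad
        seq_is_board += [False] * pad
    return (torch.tensor(seq_actions), torch.tensor(seq_targets),
            torch.tensor(seq_is_board))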
|
|
|
|
@dataclass
class Config:
    d_model: int = 256
    n_layer: int = 6
    n_head: int = 8
    max_seq: int = 6000
    n_actions: int = NUM_ACTIONS
    board_channels: int = BOARD_CHANNELS
    n_scalars: int = 10
    dropout: float = 0.1
|
|
|
|
class ResBlock(nn.Module):
    """Two 3×3 convs with BatchNorm and a residual connection."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)
|
|
|
|
class BoardEncoder(nn.Module):
    """Encodes a board grid plus scalar features into a single d_model vector."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(cfg.board_channels, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64), nn.ReLU())
        self.blocks = nn.Sequential(
            ResBlock(64), ResBlock(64),
            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128), nn.ReLU(),
            ResBlock(128), ResBlock(128))
        # Adaptive pooling collapses any H×W to 1×1, so board size can vary.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(128, cfg.d_model)
        self.scalar_proj = nn.Linear(cfg.n_scalars, cfg.d_model)

    def forward(self, board_grid, scalars):
        x = self.stem(board_grid)
        x = self.blocks(x)
        x = self.pool(x).flatten(1)
        return self.proj(x) + self.scalar_proj(scalars)
|
|
|
|
class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.w = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.w
|
|
|
|
class Attention(nn.Module):
    """Causal multi-head self-attention with an optional KV cache for decoding."""

    def __init__(self, cfg):
        super().__init__()
        self.n_head = cfg.n_head
        self.head_dim = cfg.d_model // cfg.n_head
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.out = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout = cfg.dropout

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if kv_cache is not None:
            k_prev, v_prev = kv_cache
            k = torch.cat([k_prev, k], dim=2)
            v = torch.cat([v_prev, v], dim=2)

        new_cache = (k, v)
        # Without a cache (training) the full causal mask applies. With a cache,
        # generation feeds one token at a time, and that single query may attend
        # to every cached key, so no mask is needed.
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=(kv_cache is None))
        return self.out(x.transpose(1, 2).reshape(B, T, C)), new_cache
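

# Illustrative only — a quick sanity check (not part of the model) that cached
# one-token-at-a-time decoding matches full-sequence causal attention.
def _check_kv_cache_equivalence():
    cfg = Config(dropout=0.0)
    attn = Attention(cfg).eval()
    x = torch.randn(1, 5, cfg.d_model)
    full, _ = attn(x)                        # all positions at once, causal mask
    cache, outs = None, []
    for t in range(5):                       # one position per step via the cache
        y, cache = attn(x[:, t:t + 1], cache)
        outs.append(y)
    assert torch.allclose(full, torch.cat(outs, dim=1), atol=1e-5)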
|
|
|
|
class FFN(nn.Module):
    """SwiGLU feed-forward network."""

    def __init__(self, cfg):
        super().__init__()
        # Hidden size ≈ 8/3 · d_model (the SwiGLU convention), rounded to a
        # multiple of 16.
        h = (int(cfg.d_model * 8 / 3) + 15) // 16 * 16
        self.w1 = nn.Linear(cfg.d_model, h, bias=False)
        self.w2 = nn.Linear(h, cfg.d_model, bias=False)
        self.w3 = nn.Linear(cfg.d_model, h, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
|
|
|
class Block(nn.Module):
    """Pre-norm Transformer block: attention then SwiGLU FFN, both residual."""

    def __init__(self, cfg):
        super().__init__()
        self.norm1 = RMSNorm(cfg.d_model)
        self.attn = Attention(cfg)
        self.norm2 = RMSNorm(cfg.d_model)
        self.ffn = FFN(cfg)

    def forward(self, x, kv_cache=None):
        attn_out, new_cache = self.attn(self.norm1(x), kv_cache)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x, new_cache
|
|
|
|
class PolicyModel(nn.Module):
    """CNN encoder + Transformer decoder for full-game action generation."""

    def __init__(self, cfg: Optional[Config] = None):
        super().__init__()
        if cfg is None:
            cfg = Config()
        self.cfg = cfg
        self.encoder = BoardEncoder(cfg)
        self.action_embed = nn.Embedding(cfg.n_actions, cfg.d_model)
        self.pos_embed = nn.Embedding(cfg.max_seq, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.norm = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.n_actions, bias=False)
|
|
    def forward(self, boards, scalars, seq_actions, seq_targets, seq_is_board,
                board_counts):
        """
        Full-game trajectory training.

        Args:
            boards: (total_boards, C, H, W) — all boards from all games in batch
            scalars: (total_boards, n_scalars)
            seq_actions: (B, T) int64 — input tokens (0 placeholder at board positions)
            seq_targets: (B, T) int64 — next-token targets (-100 at padding positions)
            seq_is_board: (B, T) bool — True at board embedding positions
            board_counts: (B,) int — number of boards per game (for splitting)
        Returns:
            logits: (B, T, n_actions)
            loss: scalar
        """
        B, T = seq_actions.shape
        device = seq_actions.device

        # Encode every board in the batch with one CNN pass.
        board_embs = self.encoder(boards, scalars)

        # Start from action embeddings, then overwrite the board positions.
        act_emb = self.action_embed(seq_actions)
        seq_emb = act_emb.clone()

        # Scatter board embeddings into each game's sequence. board_idx must
        # advance even when a board falls beyond a truncated sequence so that
        # later games stay aligned with board_embs.
        board_idx = 0
        for b in range(B):
            positions = seq_is_board[b].nonzero(as_tuple=True)[0]
            n_boards = board_counts[b].item()
            for i in range(n_boards):
                if i < len(positions):
                    seq_emb[b, positions[i]] = board_embs[board_idx]
                board_idx += 1

        # Learned absolute positions over the interleaved sequence.
        pos = self.pos_embed(torch.arange(T, device=device))
        x = seq_emb + pos

        for block in self.blocks:
            x, _ = block(x)

        logits = self.head(self.norm(x))

        loss = F.cross_entropy(
            logits.reshape(-1, self.cfg.n_actions),
            seq_targets.reshape(-1),
            ignore_index=-100)

        return logits, loss
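
    # Illustrative only — a minimal training-step sketch. `loader` and the
    # optimizer settings are assumptions, not the project's actual setup:
    #
    #     opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    #     for batch in loader:
    #         _, loss = model(batch['boards'], batch['scalars'],
    #                         batch['seq_actions'], batch['seq_targets'],
    #                         batch['seq_is_board'], batch['board_counts'])
    #         opt.zero_grad()
    #         loss.backward()
    #         opt.step()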
|
|
    @torch.no_grad()
    def generate(self, board_grid, scalars, kv_caches=None, seq_pos=0,
                 max_actions=10, temperature=0.0):
        """
        Generate actions for one turn, reusing the KV cache across turns.

        Args:
            board_grid: (1, C, H, W)
            scalars: (1, n_scalars)
            kv_caches: list of (k, v) per layer from previous turns, or None
            seq_pos: current position in the sequence (for positional embeddings)
            max_actions: cap on the number of actions generated this turn
            temperature: 0 for greedy decoding, >0 for temperature sampling
        Returns:
            actions: list of action token IDs
            kv_caches: updated KV caches
            seq_pos: updated position
        """
        self.eval()
        device = board_grid.device

        if kv_caches is None:
            kv_caches = [None] * len(self.blocks)

        def decode(x, caches):
            # Run one embedding through all blocks, growing each layer's cache.
            new_caches = []
            for block, cache in zip(self.blocks, caches):
                x, new_cache = block(x, cache)
                new_caches.append(new_cache)
            return self.head(self.norm(x))[0, -1], new_caches

        def sample(logits):
            if temperature <= 0:
                return logits.argmax().item()
            return torch.multinomial(
                F.softmax(logits / temperature, dim=-1), 1).item()

        # The board embedding's output distribution gives the turn's first token.
        board_emb = self.encoder(board_grid, scalars).unsqueeze(1)
        pos = self.pos_embed(torch.tensor([seq_pos], device=device)).unsqueeze(0)
        next_logits, kv_caches = decode(board_emb + pos, kv_caches)
        seq_pos += 1
        action = sample(next_logits)

        actions = []
        if action == EOS_ACTION:
            return actions, kv_caches, seq_pos
        actions.append(action)

        # Feed each action back autoregressively until EOS or the turn cap.
        # EOS is only ever a prediction target; it is never fed back as input.
        for _ in range(max_actions - 1):
            act_emb = self.action_embed(torch.tensor([[action]], device=device))
            pos = self.pos_embed(torch.tensor([seq_pos], device=device)).unsqueeze(0)
            next_logits, kv_caches = decode(act_emb + pos, kv_caches)
            seq_pos += 1
            action = sample(next_logits)
            if action == EOS_ACTION:
                break
            actions.append(action)

        return actions, kv_caches, seq_pos
|
|
    def count_params(self):
        return sum(p.numel() for p in self.parameters())
|
|
|
|
if __name__ == '__main__':
    cfg = Config()
    model = PolicyModel(cfg)
    print(f"Total params: {model.count_params():,}")

    # The encoder handles any board size thanks to adaptive pooling.
    for h, w in [(8, 8), (15, 15), (31, 31)]:
        board = torch.randn(2, BOARD_CHANNELS, h, w)
        s = torch.randn(2, 10)
        emb = model.encoder(board, s)
        print(f"  Board {h}×{w} → {emb.shape}")

    # Multi-turn generation: the KV cache carries context across turns.
    kv = None
    pos = 0
    for turn in range(5):
        board = torch.randn(1, BOARD_CHANNELS, 15, 15)
        s = torch.randn(1, 10)
        acts, kv, pos = model.generate(board, s, kv_caches=kv, seq_pos=pos)
        cache_len = kv[0][0].shape[2] if kv else 0
        print(f"  Turn {turn}: actions={acts}, seq_pos={pos}, cache_len={cache_len}")
|
|