| """ |
| ByteFight Policy Model v3: ViT board encoder + Transformer action decoder. |
| |
| Board encoder: 970 discrete tokens (9 scalars + 961 cells) |
| → shared embedding → 2-layer bidirectional self-attention → mean pool → 1 vector |
| |
| Uses same tokenization as original alphabyte (tokenizer.py): |
| Vocab 2462: CLS=0, stamina 1-381, position 382-412, turn 413-2413, cells 2414-2461 |
| |
| Actions (21): 20 action types + EOS |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass |
|
|
# Decoder action vocabulary: 20 action types followed by the end-of-turn token.
NUM_ACTIONS = 21
EOS_ACTION = NUM_ACTIONS - 1  # EOS is the last action id (20)

# Board encoder token vocabulary and sequence length (see module docstring).
MAX_BOARD = 31
BOARD_VOCAB = 2462
BOARD_SEQ_LEN = 9 + MAX_BOARD * MAX_BOARD  # 9 scalar tokens + 31*31 = 961 cells -> 970
|
|
|
|
@dataclass
class Config:
    """Hyper-parameters shared by the board encoder and the action decoder."""

    d_model: int = 256            # width of every embedding and hidden state
    n_layer: int = 6              # number of decoder Blocks
    n_head: int = 8               # attention heads, used by encoder and decoder
    max_seq: int = 6000           # rows in the decoder's learned position table
    n_actions: int = NUM_ACTIONS  # decoder output classes (action types + EOS)
    board_vocab: int = BOARD_VOCAB
    board_seq_len: int = BOARD_SEQ_LEN
    dropout: float = 0.1          # attention dropout in the decoder (training mode only)

    def __post_init__(self):
        # Both attention layers split d_model evenly across heads via
        # reshape(..., n_head, d_model // n_head). Reject configs where that
        # split is impossible here, instead of with a reshape error at the
        # first forward pass.
        if self.n_head <= 0 or self.d_model % self.n_head != 0:
            raise ValueError(
                f"d_model={self.d_model} must be divisible by a positive n_head={self.n_head}")
|
|
|
|
class BoardAttnBlock(nn.Module):
    """Pre-norm, unmasked (bidirectional) self-attention block for the board encoder.

    Layout: LayerNorm -> multi-head attention -> residual add, then
    LayerNorm -> Linear(d, 4d) -> GELU -> Linear(4d, d) -> residual add.
    """

    def __init__(self, d, n_head):
        super().__init__()
        # Submodule names and creation order are kept fixed so state_dict keys
        # and seeded parameter initialisation stay stable.
        self.norm1 = nn.LayerNorm(d)
        self.n_head = n_head
        self.head_dim = d // n_head
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.out = nn.Linear(d, d, bias=False)
        self.norm2 = nn.LayerNorm(d)
        hidden = 4 * d
        self.ffn = nn.Sequential(
            nn.Linear(d, hidden),
            nn.GELU(),
            nn.Linear(hidden, d),
        )

    def forward(self, x):
        """Map (B, T, d) -> (B, T, d); every token may attend to every other token."""
        batch, length, width = x.shape
        projected = self.qkv(self.norm1(x))
        heads = projected.reshape(batch, length, 3, self.n_head, self.head_dim)
        # Each of q / k / v becomes (B, n_head, T, head_dim).
        q, k, v = (part.transpose(1, 2) for part in heads.unbind(2))
        context = F.scaled_dot_product_attention(q, k, v)
        context = context.transpose(1, 2).reshape(batch, length, width)
        residual = x + self.out(context)
        return residual + self.ffn(self.norm2(residual))
|
|
|
|
class BoardEncoder(nn.Module):
    """ViT-style board encoder: discrete board tokens -> one pooled vector per board.

    Token embedding plus a learned positional embedding, two unmasked
    BoardAttnBlock layers, a final LayerNorm, then a mean over the token axis.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        width = cfg.d_model
        self.embed = nn.Embedding(cfg.board_vocab, width)
        self.pos_embed = nn.Parameter(0.02 * torch.randn(1, cfg.board_seq_len, width))
        self.blocks = nn.ModuleList(
            BoardAttnBlock(width, cfg.n_head) for _ in range(2)
        )
        self.norm = nn.LayerNorm(width)

    def forward(self, board_tokens):
        """board_tokens: (B, 970) int64 -> (B, d_model) float embeddings."""
        hidden = self.embed(board_tokens)
        # Slice so shorter token sequences reuse the leading positions.
        hidden = hidden + self.pos_embed[:, :hidden.shape[1]]
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.norm(hidden).mean(dim=1)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned per-channel gain.

    Unlike LayerNorm there is no mean-centering and no bias.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.w = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.w
|
|
|
|
class Attention(nn.Module):
    """Multi-head causal self-attention with an optional key/value cache.

    Without a cache, a standard causal mask is applied over the T input tokens.
    With a cache ``(k_prev, v_prev)``, the new tokens are appended after the
    cached ones. Each new query may attend to every cached key plus the new
    keys at or before its own position. This makes chunked/cached processing
    match a single full causal pass.
    """

    def __init__(self, cfg):
        super().__init__()
        self.n_head = cfg.n_head
        self.head_dim = cfg.d_model // cfg.n_head
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.out = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout = cfg.dropout

    def forward(self, x, kv_cache=None):
        """Attend over x (B, T, d_model), optionally continuing from kv_cache.

        Args:
            x: (B, T, d_model) input tokens.
            kv_cache: None, or a ``(k, v)`` pair previously returned by this
                layer, each (B, n_head, past_len, head_dim).

        Returns:
            (out, new_cache): out is (B, T, d_model). new_cache holds the keys
            and values for all past_len + T positions.
        """
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.head_dim)
        q, k, v = qkv.unbind(2)
        # -> (B, n_head, T, head_dim)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        attn_mask = None
        if kv_cache is not None:
            k_prev, v_prev = kv_cache
            past = k_prev.shape[2]
            k = torch.cat([k_prev, k], dim=2)
            v = torch.cat([v_prev, v], dim=2)
            if T > 1:
                # New query i sits at absolute position past + i, so it may see
                # keys 0..past+i. SDPA's is_causal would align a non-square mask
                # to the top-left corner and hide the cache, so build it here.
                attn_mask = torch.ones(
                    T, past + T, dtype=torch.bool, device=x.device
                ).tril(diagonal=past)

        new_cache = (k, v)
        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=(kv_cache is None))
        return self.out(x.transpose(1, 2).reshape(B, T, C)), new_cache
|
|
|
|
class FFN(nn.Module):
    """SwiGLU feed-forward layer: w2(silu(w1 x) * w3 x), without biases.

    The hidden width is int(8/3 * d_model), rounded up to a multiple of 16.
    """

    def __init__(self, cfg):
        super().__init__()
        d = cfg.d_model
        raw_width = int(d * 8 / 3)
        hidden = -(-raw_width // 16) * 16  # ceiling to the next multiple of 16
        # Creation order w1, w2, w3 is kept for stable seeded initialisation.
        self.w1 = nn.Linear(d, hidden, bias=False)
        self.w2 = nn.Linear(hidden, d, bias=False)
        self.w3 = nn.Linear(d, hidden, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
|
|
|
class Block(nn.Module):
    """Pre-norm decoder layer.

    RMSNorm -> self-attention (optionally KV-cached) -> residual add, then
    RMSNorm -> SwiGLU FFN -> residual add.
    """

    def __init__(self, cfg):
        super().__init__()
        self.norm1 = RMSNorm(cfg.d_model)
        self.attn = Attention(cfg)
        self.norm2 = RMSNorm(cfg.d_model)
        self.ffn = FFN(cfg)

    def forward(self, x, kv_cache=None):
        """Return (output, updated_cache); kv_cache is passed straight to the attention layer."""
        mixed, updated_cache = self.attn(self.norm1(x), kv_cache)
        hidden = x + mixed
        hidden = hidden + self.ffn(self.norm2(hidden))
        return hidden, updated_cache
|
|
|
|
class PolicyModel(nn.Module):
    """ViT board encoder + causal Transformer action decoder.

    The decoder sequence interleaves board slots and action tokens. Each board
    slot holds the pooled ``BoardEncoder`` vector for that board, and every
    other slot holds an action embedding. A learned absolute position
    embedding is added, the sequence runs through ``cfg.n_layer`` ``Block``s,
    and a linear head scores the ``cfg.n_actions`` next-action classes
    (including EOS).
    """

    def __init__(self, cfg: "Config" = None):
        super().__init__()
        if cfg is None:
            cfg = Config()
        self.cfg = cfg
        self.encoder = BoardEncoder(cfg)
        self.action_embed = nn.Embedding(cfg.n_actions, cfg.d_model)
        self.pos_embed = nn.Embedding(cfg.max_seq, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.norm = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.n_actions, bias=False)

    @staticmethod
    def _place_board_embeddings(act_emb, board_embs, seq_is_board, board_counts):
        """Return a copy of ``act_emb`` with board slots replaced by encoder outputs.

        Boards are consumed in batch order: row b takes the next
        ``board_counts[b]`` rows of ``board_embs`` and writes them, in order,
        into its first ``board_counts[b]`` slots where ``seq_is_board`` is
        nonzero. Board slots beyond the count keep their action embedding.
        Boards beyond the available slots are skipped but still consumed.
        Vectorised, so there is no per-row Python loop or ``.item()`` sync.
        """
        batch = act_emb.shape[0]
        device = act_emb.device
        is_board = seq_is_board.to(device) != 0
        counts = torch.as_tensor(board_counts, device=device).long().reshape(batch)
        # 0-based rank of each slot among the board slots of its own row.
        rank = is_board.long().cumsum(dim=1) - 1
        use = is_board & (rank < counts.unsqueeze(1))
        # Index of the first board owned by each row (exclusive prefix sum).
        first = torch.cumsum(counts, dim=0) - counts
        # Boolean-mask indexing is row-major, so sources line up with targets.
        src = (first.unsqueeze(1) + rank)[use]
        seq_emb = act_emb.clone()
        seq_emb[use] = board_embs[src]
        return seq_emb

    def forward(self, board_tokens, seq_actions, seq_targets, seq_is_board, board_counts):
        """Teacher-forced training pass.

        Args:
            board_tokens: (total_boards, 970) int64 tokens for every board in
                the batch, concatenated in batch order.
            seq_actions: (B, T) action ids; entries at board slots are
                overwritten by board embeddings.
            seq_targets: (B, T) next-action targets; -100 entries are ignored.
            seq_is_board: (B, T), nonzero where the slot holds a board.
            board_counts: one board count per sequence (B entries).

        Returns:
            (logits, loss): logits is (B, T, n_actions); loss is the mean
            cross-entropy over the non-ignored targets.
        """
        T = seq_actions.shape[1]
        device = seq_actions.device

        board_embs = self.encoder(board_tokens)
        seq_emb = self._place_board_embeddings(
            self.action_embed(seq_actions), board_embs, seq_is_board, board_counts)

        x = seq_emb + self.pos_embed(torch.arange(T, device=device))
        for block in self.blocks:
            x, _ = block(x)

        logits = self.head(self.norm(x))
        loss = F.cross_entropy(
            logits.reshape(-1, self.cfg.n_actions),
            seq_targets.reshape(-1),
            ignore_index=-100)
        return logits, loss

    def _position(self, seq_pos, device):
        """(1, 1, d_model) position embedding for a single absolute index."""
        return self.pos_embed(torch.tensor([seq_pos], device=device)).unsqueeze(0)

    def _decode_step(self, x, kv_caches):
        """Run one (1, 1, d_model) token through every Block.

        Returns (next-action logits, new per-layer caches).
        """
        if kv_caches is None:
            kv_caches = [None] * len(self.blocks)
        new_caches = []
        for block, cache in zip(self.blocks, kv_caches):
            x, cache = block(x, cache)
            new_caches.append(cache)
        return self.head(self.norm(x))[0, -1], new_caches

    @staticmethod
    def _choose_action(logits, temperature):
        """Greedy argmax when temperature <= 0, else sample from softmax(logits / temperature)."""
        if temperature <= 0:
            return logits.argmax().item()
        probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, 1).item()

    @torch.no_grad()
    def generate(self, board_tokens, kv_caches=None, seq_pos=0,
                 max_actions=10, temperature=0.0):
        """Decode one turn of actions for a new board, continuing the KV cache.

        Args:
            board_tokens: (1, 970) int64 tokens of the current board.
            kv_caches: per-layer caches returned by a previous call, or None.
            seq_pos: absolute position to assign to the board token.
            max_actions: maximum number of actions to return.
            temperature: <= 0 for greedy decoding, otherwise sampling temperature.

        Returns:
            (actions, kv_caches, seq_pos): the decoded action ids (EOS is not
            included), the updated caches, and the next free position.
        """
        was_training = self.training
        self.eval()
        try:
            device = board_tokens.device
            x = self.encoder(board_tokens).unsqueeze(1) + self._position(seq_pos, device)
            logits, kv_caches = self._decode_step(x, kv_caches)
            seq_pos += 1
            action = self._choose_action(logits, temperature)

            actions = []
            while action != EOS_ACTION and len(actions) < max_actions:
                actions.append(action)
                if len(actions) == max_actions:
                    # NOTE(review): on truncation the last returned action is
                    # never fed through the decoder, so it is absent from
                    # kv_caches / seq_pos for the next call. Confirm this
                    # matches the training sequence layout.
                    break
                x = (self.action_embed(torch.tensor([[action]], device=device))
                     + self._position(seq_pos, device))
                logits, kv_caches = self._decode_step(x, kv_caches)
                seq_pos += 1
                action = self._choose_action(logits, temperature)

            return actions, kv_caches, seq_pos
        finally:
            # Restore the caller's train/eval mode instead of leaving eval on.
            self.train(was_training)

    def count_params(self):
        """Total number of parameter elements (encoder + decoder)."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
if __name__ == '__main__':
    cfg = Config()
    model = PolicyModel(cfg)
    n_total = model.count_params()
    n_encoder = sum(p.numel() for p in model.encoder.parameters())
    print(f"Total params: {n_total:,}")
    print(f" Encoder: {n_encoder:,}")
    print(f" Decoder: {n_total - n_encoder:,}")

    # Encode a random batch of two boards.
    tokens = torch.randint(0, BOARD_VOCAB, (2, BOARD_SEQ_LEN))
    emb = model.encoder(tokens)
    print(f" Board tokens {tokens.shape} -> {emb.shape}")

    # Roll out five turns, threading the KV cache and position between calls.
    kv, pos = None, 0
    for turn in range(5):
        tokens = torch.randint(0, BOARD_VOCAB, (1, BOARD_SEQ_LEN))
        acts, kv, pos = model.generate(tokens, kv_caches=kv, seq_pos=pos)
        cache_size = kv[0][0].shape[2]
        print(f" Turn {turn}: actions={acts}, seq_pos={pos}, cache_len={cache_size}")
|
|