bytefight-policy / model.py
Broyojo's picture
Upload model.py with huggingface_hub
5a4391b verified
"""
ByteFight Policy Model: CNN board encoder + Transformer action decoder.
Trains on full game trajectories:
[board₁, act₁₀, act₁₁, ..., board₂, act₂₀, act₂₁, ..., board_N, ...]
CNN compresses each board to 1 vector. Transformer sees interleaved sequence.
Only action positions have prediction targets (board positions masked with -100).
Board channels (8): my_paint, opp_paint, my_beacon, opp_beacon, wall, hill, powerup, valid
Scalars (10): my_stam, my_max_stam, opp_stam, opp_max_stam, my_r, my_c, opp_r, opp_c, turn, area
Actions (21): 20 action types + EOS
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
# Action vocabulary: 20 concrete action types followed by one end-of-turn token.
NUM_ACTIONS = 21
EOS_ACTION = NUM_ACTIONS - 1  # the EOS token is the last id in the vocabulary
# Number of per-cell feature planes fed to the board CNN (see module docstring).
BOARD_CHANNELS = 8
# NOTE(review): presumably the largest board side length; not referenced elsewhere in this file.
MAX_BOARD = 31
@dataclass
class Config:
    """Hyper-parameters shared by the board encoder and the transformer decoder.

    Raises:
        ValueError: if ``n_head`` is not positive or does not divide ``d_model``
            (attention splits ``d_model`` evenly across heads).
    """
    d_model: int = 256        # width of every token embedding
    n_layer: int = 6          # number of transformer blocks
    n_head: int = 8           # attention heads; must divide d_model
    max_seq: int = 6000       # size of the learned positional-embedding table
    n_actions: int = NUM_ACTIONS
    board_channels: int = BOARD_CHANNELS
    n_scalars: int = 10       # per-board scalar features fed to the encoder
    dropout: float = 0.1      # attention dropout probability (training only)

    def __post_init__(self):
        # Fail at construction time instead of with an opaque reshape error
        # inside Attention.forward (or a ZeroDivisionError for n_head == 0).
        if self.n_head <= 0 or self.d_model % self.n_head != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by a positive "
                f"n_head ({self.n_head})")
class ResBlock(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm layers with an identity skip.

    Padding keeps the spatial size, and the channel count is unchanged, so the
    skip connection is a plain addition. ReLU is applied after the sum.
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute names/order define state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        skip = x
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + skip)
class BoardEncoder(nn.Module):
    """Compress each board grid plus its scalar features into one d_model vector.

    The conv trunk downsamples once (stride 2) and finishes with a global
    average pool, so boards of different heights/widths map to the same size.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        base_ch, deep_ch = 64, 128
        # Sequential layouts fix the state_dict keys (stem.0, blocks.2, ...);
        # keep the layer order exactly as is.
        self.stem = nn.Sequential(
            nn.Conv2d(cfg.board_channels, base_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(base_ch),
            nn.ReLU(),
        )
        self.blocks = nn.Sequential(
            ResBlock(base_ch),
            ResBlock(base_ch),
            nn.Conv2d(base_ch, deep_ch, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(deep_ch),
            nn.ReLU(),
            ResBlock(deep_ch),
            ResBlock(deep_ch),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(deep_ch, cfg.d_model)
        self.scalar_proj = nn.Linear(cfg.n_scalars, cfg.d_model)

    def forward(self, board_grid, scalars):
        """Return (N, d_model): projected pooled CNN features + projected scalars."""
        features = self.blocks(self.stem(board_grid))
        pooled = torch.flatten(self.pool(features), start_dim=1)
        return self.proj(pooled) + self.scalar_proj(scalars)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learned gain.

    Unlike LayerNorm there is no mean subtraction and no bias.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.w = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.w
class Attention(nn.Module):
    """Multi-head causal self-attention with an optional key/value cache.

    ``forward(x, kv_cache)`` returns ``(output, (k, v))``. Here ``k``/``v`` hold
    keys/values for every position seen so far (cached + new), shaped
    ``(B, n_head, S, head_dim)``. They can be passed back as ``kv_cache`` to
    continue decoding. New tokens may be appended one at a time or as a chunk.
    In both cases each new token attends only to itself and earlier positions.
    """

    def __init__(self, cfg):
        super().__init__()
        self.n_head = cfg.n_head
        self.head_dim = cfg.d_model // cfg.n_head
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.out = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout = cfg.dropout

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if kv_cache is not None:
            k_prev, v_prev = kv_cache
            k = torch.cat([k_prev, k], dim=2)
            v = torch.cat([v_prev, v], dim=2)
        new_cache = (k, v)

        attn_mask = None
        is_causal = False
        if kv_cache is None:
            # No history: standard lower-triangular causal mask over x itself.
            is_causal = True
        elif T > 1:
            # A chunk of T tokens appended after S - T cached ones: query i sits
            # at absolute position S - T + i and may see keys 0 .. S - T + i.
            # SDPA's is_causal aligns its mask top-left (wrong when S != T), so
            # build the bottom-right-aligned boolean mask explicitly.
            S = k.size(2)
            attn_mask = torch.ones(
                T, S, dtype=torch.bool, device=x.device).tril(diagonal=S - T)
        # else: one new token may attend to the whole cache -> no mask needed.

        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal)
        return self.out(x.transpose(1, 2).reshape(B, T, C)), new_cache
class FFN(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1 x) * w3 x)``.

    The hidden width is int(8/3 * d_model), rounded up to a multiple of 16.
    """

    def __init__(self, cfg):
        super().__init__()
        raw_width = int(cfg.d_model * 8 / 3)
        hidden = -(-raw_width // 16) * 16  # ceiling to the next multiple of 16
        # Registration order (w1, w2, w3) fixes state_dict order; keep it.
        self.w1 = nn.Linear(cfg.d_model, hidden, bias=False)
        self.w2 = nn.Linear(hidden, cfg.d_model, bias=False)
        self.w3 = nn.Linear(cfg.d_model, hidden, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class Block(nn.Module):
    """Pre-norm transformer layer.

    The data flow is x + attn(norm1(x)), followed by + ffn(norm2(.)).
    """

    def __init__(self, cfg):
        super().__init__()
        self.norm1 = RMSNorm(cfg.d_model)
        self.attn = Attention(cfg)
        self.norm2 = RMSNorm(cfg.d_model)
        self.ffn = FFN(cfg)

    def forward(self, x, kv_cache=None):
        """Return (updated hidden states, this layer's key/value cache)."""
        attended, cache = self.attn(self.norm1(x), kv_cache)
        hidden = x + attended
        result = hidden + self.ffn(self.norm2(hidden))
        return result, cache
class PolicyModel(nn.Module):
    """CNN encoder + Transformer decoder for full-game action generation.

    Each board state is compressed by ``BoardEncoder`` into a single token, and
    action tokens come from ``action_embed``. Both kinds of token share one
    learned absolute positional table of ``cfg.max_seq`` entries and are
    processed by a causal transformer.
    """

    def __init__(self, cfg: Config = None):
        super().__init__()
        if cfg is None:
            cfg = Config()
        self.cfg = cfg
        self.encoder = BoardEncoder(cfg)
        self.action_embed = nn.Embedding(cfg.n_actions, cfg.d_model)
        self.pos_embed = nn.Embedding(cfg.max_seq, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.norm = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.n_actions, bias=False)

    def _positions(self, start, length, device):
        """Positional embeddings for positions start .. start+length-1, shape (length, d)."""
        end = start + length
        if end > self.cfg.max_seq:
            # Readable error instead of an opaque embedding index error
            # (a device-side assert on CUDA).
            raise IndexError(
                f"sequence position {end - 1} is out of range for "
                f"max_seq={self.cfg.max_seq}")
        return self.pos_embed(torch.arange(start, end, device=device))

    def forward(self, boards, scalars, seq_actions, seq_targets, seq_is_board,
                board_counts):
        """
        Full-game trajectory training.
        Args:
            boards: (total_boards, C, H, W) — all boards from all games in batch,
                concatenated game by game
            scalars: (total_boards, n_scalars)
            seq_actions: (B, T) int64 — action tokens at each position (0 at board positions)
            seq_targets: (B, T) int64 — targets (-100 at board/padding positions)
            seq_is_board: (B, T) bool — True at board embedding positions
            board_counts: (B,) int — number of boards per game (for splitting);
                a tensor or any sequence accepted by ``torch.as_tensor``
        Returns:
            logits: (B, T, n_actions)
            loss: scalar
        Raises:
            IndexError: if T exceeds ``cfg.max_seq``.
        """
        T = seq_actions.shape[1]
        device = seq_actions.device

        # CNN-encode all boards at once.
        board_embs = self.encoder(boards, scalars)  # (total_boards, d_model)

        # Start from the action embeddings, then overwrite the board slots.
        act_emb = self.action_embed(seq_actions)  # (B, T, d_model)
        seq_emb = act_emb.clone()

        is_board = seq_is_board.to(torch.bool)
        counts = torch.as_tensor(board_counts, dtype=torch.long, device=device)
        # rank[b, t]: number of board slots before (b, t) in row b, i.e. which of
        # game b's boards belongs at a board slot.
        rank = is_board.long().cumsum(dim=1) - 1
        # offsets[b]: index in `boards` of game b's first board.
        offsets = counts.cumsum(0) - counts
        # Board slots beyond board_counts[b] keep their action embedding, and
        # boards without a slot are skipped without shifting later games.
        fill = is_board & (rank < counts.unsqueeze(1))
        src = (offsets.unsqueeze(1) + rank)[fill]
        # Mask assignment (index_put) requires matching dtypes, e.g. under autocast.
        seq_emb[fill] = board_embs[src].to(seq_emb.dtype)

        # Add positional embeddings.
        x = seq_emb + self._positions(0, T, device)
        for block in self.blocks:
            x, _ = block(x)
        logits = self.head(self.norm(x))
        loss = F.cross_entropy(
            logits.reshape(-1, self.cfg.n_actions),
            seq_targets.reshape(-1),
            ignore_index=-100)
        return logits, loss

    def _decode_step(self, x, kv_caches):
        """Run x (1, t, d) through all blocks; return (last-position logits, new caches)."""
        new_caches = []
        for block, cache in zip(self.blocks, kv_caches):
            x, cache = block(x, cache)
            new_caches.append(cache)
        return self.head(self.norm(x))[0, -1], new_caches

    @staticmethod
    def _sample(logits, temperature):
        """Greedy pick when temperature <= 0, otherwise sample from softmax(logits / T)."""
        if temperature <= 0:
            return logits.argmax().item()
        return torch.multinomial(
            F.softmax(logits / temperature, dim=-1), 1).item()

    @torch.no_grad()
    def generate(self, board_grid, scalars, kv_caches=None, seq_pos=0,
                 max_actions=10, temperature=0.0):
        """
        Generate actions for one turn with KV cache.
        Args:
            board_grid: (1, C, H, W)
            scalars: (1, n_scalars)
            kv_caches: list of (k, v) per layer from previous turns, or None
            seq_pos: current position in the sequence (for positional embeddings)
            max_actions: maximum number of actions to emit (0 emits none)
            temperature: <= 0 for greedy decoding, otherwise softmax temperature
        Returns:
            actions: list of action token IDs (EOS is never included)
            kv_caches: updated KV caches, containing the board token and every
                emitted action
            seq_pos: updated position
        Raises:
            IndexError: if the sequence would exceed ``cfg.max_seq``.
        """
        # Decode in eval mode, then restore each submodule's previous mode so a
        # call during training does not leave dropout/BatchNorm switched off.
        saved_modes = [(m, m.training) for m in self.modules()]
        self.eval()
        try:
            device = board_grid.device
            if kv_caches is None:
                kv_caches = [None] * len(self.blocks)

            # Board token: its output predicts the first action.
            board_emb = self.encoder(board_grid, scalars).unsqueeze(1)  # (1, 1, d)
            x = board_emb + self._positions(seq_pos, 1, device).unsqueeze(0)
            logits, kv_caches = self._decode_step(x, kv_caches)
            seq_pos += 1

            actions = []
            while len(actions) < max_actions:
                action = self._sample(logits, temperature)
                if action == EOS_ACTION:
                    # NOTE(review): EOS is not fed into the cache; confirm this
                    # matches how training sequences are built.
                    break
                actions.append(action)
                # Feed every emitted action (including the last one) so the
                # cache holds the full turn history for the next call.
                token = torch.tensor([[action]], device=device)
                x = self.action_embed(token) + self._positions(
                    seq_pos, 1, device).unsqueeze(0)
                logits, kv_caches = self._decode_step(x, kv_caches)
                seq_pos += 1
            return actions, kv_caches, seq_pos
        finally:
            for module, mode in saved_modes:
                module.training = mode

    def count_params(self):
        """Total number of parameter elements (trainable or not)."""
        return sum(p.numel() for p in self.parameters())
if __name__ == '__main__':
    # Smoke test: parameter count, encoder shapes, one training forward pass,
    # and multi-turn KV-cached generation.
    torch.manual_seed(0)
    cfg = Config()
    model = PolicyModel(cfg)
    print(f"Total params: {model.count_params():,}")

    # Encoder shape check on several board sizes. Run in eval mode without grad
    # so the random inputs do not update BatchNorm running statistics.
    model.eval()
    with torch.no_grad():
        for h, w in [(8, 8), (15, 15), (31, 31)]:
            board = torch.randn(2, BOARD_CHANNELS, h, w)
            s = torch.randn(2, cfg.n_scalars)
            emb = model.encoder(board, s)
            print(f" Board {h}×{w} → {emb.shape}")

    # Training forward on a tiny two-game batch:
    #   game 0 = [board, act, act, board, act]
    #   game 1 = [board, act, pad, pad, pad]
    model.train()
    seq_is_board = torch.tensor([[True, False, False, True, False],
                                 [True, False, False, False, False]])
    board_counts = torch.tensor([2, 1])
    boards = torch.randn(int(board_counts.sum()), BOARD_CHANNELS, 15, 15)
    board_scalars = torch.randn(boards.shape[0], cfg.n_scalars)
    seq_actions = torch.randint(0, NUM_ACTIONS, seq_is_board.shape)
    seq_actions = seq_actions.masked_fill(seq_is_board, 0)
    seq_targets = torch.randint(0, NUM_ACTIONS, seq_is_board.shape)
    seq_targets = seq_targets.masked_fill(seq_is_board, -100)
    seq_actions[1, 2:] = 0      # padding
    seq_targets[1, 2:] = -100   # padding
    logits, loss = model(boards, board_scalars, seq_actions, seq_targets,
                         seq_is_board, board_counts)
    print(f" Forward: logits={tuple(logits.shape)}, loss={loss.item():.4f}")

    # Test KV cache generation across multiple turns
    kv = None
    pos = 0
    for turn in range(5):
        board = torch.randn(1, BOARD_CHANNELS, 15, 15)
        s = torch.randn(1, cfg.n_scalars)
        acts, kv, pos = model.generate(board, s, kv_caches=kv, seq_pos=pos)
        cache_size = kv[0][0].shape[2] if kv else 0
        print(f" Turn {turn}: actions={acts}, seq_pos={pos}, cache_len={cache_size}")