"""PAWN: Causal Language Model for chess move prediction. Decoder-only transformer (`Vaswani et al., 2017 `_) with next-token prediction over the move vocabulary. Key architectural choices drawn from subsequent work: * **RMSNorm** -- `Zhang & Sennrich, 2019 `_ * **SwiGLU** FFN -- `Shazeer, 2020 `_ * **Rotary Position Embeddings (RoPE)** -- `Su et al., 2021 `_ """ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.attention import sdpa_kernel, SDPBackend from pawn.config import CLMConfig, OUTCOME_TOKEN_BASE from chess_engine import export_move_vocabulary # ROCm flash attention backward has stride mismatches with torch.compile. # Set this to use MATH backend instead (enables compile + AMP on ROCm). SDPA_BACKEND: SDPBackend | None = None def _build_decomposition_table() -> torch.Tensor: """Build static token -> (src, dst, promo_type) lookup table. Returns int16[4278, 3]. PAD (0) and outcome tokens (4273-4277) map to (0, 0, 0) — handled by standalone embeddings. """ vocab = export_move_vocabulary() table = torch.zeros(4278, 3, dtype=torch.int16) for token_idx, uci_str in vocab["token_to_move"].items(): if token_idx >= OUTCOME_TOKEN_BASE: continue # Outcome tokens use standalone embeddings src_name = uci_str[:2] dst_name = uci_str[2:4] promo_suffix = uci_str[4:] if len(uci_str) > 4 else "" sq_names = vocab["square_names"] src_sq = sq_names.index(src_name) dst_sq = sq_names.index(dst_name) promo_type = 0 if promo_suffix: promo_map = {"q": 1, "r": 2, "b": 3, "n": 4} promo_type = promo_map[promo_suffix] table[token_idx] = torch.tensor([src_sq, dst_sq, promo_type], dtype=torch.int16) return table class RMSNorm(nn.Module): """Root Mean Square Layer Normalization (`Zhang & Sennrich, 2019 `_).""" def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: norm = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) return (x.float() * norm).to(x.dtype) * self.weight def _apply_rope( x: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor ) -> torch.Tensor: """Apply Rotary Position Embeddings (`Su et al., 2021 `_). x: (B, n_heads, T, head_dim) rope_cos, rope_sin: (1, 1, T, head_dim // 2) """ x_r = x.float().reshape(*x.shape[:-1], -1, 2) x0, x1 = x_r.unbind(-1) out0 = x0 * rope_cos - x1 * rope_sin out1 = x0 * rope_sin + x1 * rope_cos out = torch.stack([out0, out1], dim=-1).reshape(x.shape) return out.to(x.dtype) def _precompute_rope_freqs(dim: int, max_len: int, base: float = 10000.0) -> torch.Tensor: """Precompute RoPE frequency tensor. Returns (max_len, dim // 2).""" freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) t = torch.arange(max_len).float() freqs = torch.outer(t, freqs) return freqs class Attention(nn.Module): def __init__(self, cfg: CLMConfig): super().__init__() self.n_heads = cfg.n_heads self.head_dim = cfg.d_model // cfg.n_heads self.wq = nn.Linear(cfg.d_model, cfg.d_model, bias=False) self.wk = nn.Linear(cfg.d_model, cfg.d_model, bias=False) self.wv = nn.Linear(cfg.d_model, cfg.d_model, bias=False) self.wo = nn.Linear(cfg.d_model, cfg.d_model, bias=False) def forward( self, x: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor, mask: torch.Tensor | None = None, ) -> torch.Tensor: B, T, _ = x.shape q = self.wq(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = self.wk(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) v = self.wv(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) q = _apply_rope(q, rope_cos, rope_sin) k = _apply_rope(k, rope_cos, rope_sin) if SDPA_BACKEND is not None: with sdpa_kernel(SDPA_BACKEND): attn_out = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, is_causal=(mask is None) ) else: attn_out = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, is_causal=(mask is None) ) attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, -1) return self.wo(attn_out) def forward_kv( self, x: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor, kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """Forward with KV-cache for autoregressive generation. Args: x: (B, T_new, d_model) — full sequence for prefill, single token for decode. rope_cos/sin: (1, 1, T_new, head_dim//2) — RoPE for the new positions only. kv_cache: optional (K, V) each (B, n_heads, T_cached, head_dim). Returns: out: (B, T_new, d_model) new_cache: (K, V) each (B, n_heads, T_total, head_dim) """ B, T_new, _ = x.shape q = self.wq(x).view(B, T_new, self.n_heads, self.head_dim).transpose(1, 2) k = self.wk(x).view(B, T_new, self.n_heads, self.head_dim).transpose(1, 2) v = self.wv(x).view(B, T_new, self.n_heads, self.head_dim).transpose(1, 2) q = _apply_rope(q, rope_cos, rope_sin) k = _apply_rope(k, rope_cos, rope_sin) if kv_cache is not None: k = torch.cat([kv_cache[0], k], dim=2) v = torch.cat([kv_cache[1], v], dim=2) # Prefill (no cache): causal mask. Decode (with cache): single query # attends to all cached keys — no mask needed. attn_out = F.scaled_dot_product_attention( q, k, v, is_causal=(kv_cache is None) ) attn_out = attn_out.transpose(1, 2).contiguous().view(B, T_new, -1) return self.wo(attn_out), (k, v) class SwiGLUFFN(nn.Module): """SwiGLU feed-forward network (`Shazeer, 2020 `_).""" def __init__(self, cfg: CLMConfig): super().__init__() self.w_gate = nn.Linear(cfg.d_model, cfg.d_ff, bias=False) self.w_up = nn.Linear(cfg.d_model, cfg.d_ff, bias=False) self.w_down = nn.Linear(cfg.d_ff, cfg.d_model, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)) class TransformerBlock(nn.Module): attn_norm: RMSNorm attn: Attention ffn_norm: RMSNorm ffn: SwiGLUFFN def __init__(self, cfg: CLMConfig): super().__init__() self.attn_norm = RMSNorm(cfg.d_model) self.attn = Attention(cfg) self.ffn_norm = RMSNorm(cfg.d_model) self.ffn = SwiGLUFFN(cfg) def forward( self, x: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor, mask: torch.Tensor | None = None, ) -> torch.Tensor: x = x + self.attn(self.attn_norm(x), rope_cos, rope_sin, mask) x = x + self.ffn(self.ffn_norm(x)) return x def forward_kv( self, x: torch.Tensor, rope_cos: torch.Tensor, rope_sin: torch.Tensor, kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """Forward with KV-cache.""" attn_out, new_cache = self.attn.forward_kv( self.attn_norm(x), rope_cos, rope_sin, kv_cache ) x = x + attn_out x = x + self.ffn(self.ffn_norm(x)) return x, new_cache class CLMEmbedding(nn.Module): """Factored input embeddings for CLM. Move tokens use factored embedding: src_embed[s] + dst_embed[d] + promo_embed[p]. PAD and outcome tokens use standalone embeddings. """ decomp_table: torch.Tensor def __init__(self, cfg: CLMConfig): super().__init__() self.d_model = cfg.d_model # Factored move components self.src_embed = nn.Embedding(64, cfg.d_model) self.dst_embed = nn.Embedding(64, cfg.d_model) self.promo_embed = nn.Embedding(5, cfg.d_model) # 0=none, 1=q, 2=r, 3=b, 4=n # Standalone embeddings self.pad_embed = nn.Parameter(torch.zeros(cfg.d_model)) self.outcome_embed = nn.Embedding(cfg.n_outcomes, cfg.d_model) # Static decomposition table: token_idx -> (src, dst, promo_type) self.register_buffer("decomp_table", _build_decomposition_table(), persistent=False) def forward(self, input_ids: torch.Tensor) -> torch.Tensor: """ input_ids: (B, T) int tensor of token indices [0..4277] Returns: (B, T, d_model) """ # Decompose all tokens (PAD and outcomes get (0,0,0) from the table — # their factored embeddings are garbage but will be overridden below) flat = input_ids.long().clamp(0, 4277) decomp = self.decomp_table[flat] # (B, T, 3) src_idx = decomp[..., 0].long() dst_idx = decomp[..., 1].long() promo_idx = decomp[..., 2].long() emb = self.src_embed(src_idx) + self.dst_embed(dst_idx) + self.promo_embed(promo_idx) # Override PAD positions (branchless for torch.compile) pad_mask = (input_ids == 0).unsqueeze(-1) # (B, T, 1) emb = torch.where(pad_mask, self.pad_embed, emb) # Override outcome token positions (branchless) # Compute outcome embeddings for ALL positions (clamp makes non-outcome # indices safe); torch.where selects only at actual outcome positions. outcome_idx = (input_ids - OUTCOME_TOKEN_BASE).clamp(0, self.outcome_embed.num_embeddings - 1) outcome_embs = self.outcome_embed(outcome_idx) outcome_mask = (input_ids >= OUTCOME_TOKEN_BASE).unsqueeze(-1) # (B, T, 1) emb = torch.where(outcome_mask, outcome_embs, emb) return emb class PAWNCLM(nn.Module): """PAWN: Causal Language Model for chess. Predicts the next token (move or padding) via softmax over the full vocabulary. No factored output head, no grid, no BCE. """ rope_cos: torch.Tensor rope_sin: torch.Tensor causal_mask: torch.Tensor embed: CLMEmbedding layers: nn.ModuleList final_norm: RMSNorm lm_head: nn.Linear def __init__(self, cfg: CLMConfig): super().__init__() self.cfg = cfg self.embed = CLMEmbedding(cfg) self.layers = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)]) self.final_norm = RMSNorm(cfg.d_model) self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False) # Static buffers rope_freqs = _precompute_rope_freqs( cfg.d_model // cfg.n_heads, cfg.max_seq_len, cfg.rope_base ) self.register_buffer( "rope_cos", rope_freqs.cos().unsqueeze(0).unsqueeze(0), persistent=False ) self.register_buffer( "rope_sin", rope_freqs.sin().unsqueeze(0).unsqueeze(0), persistent=False ) self.register_buffer( "causal_mask", torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool).tril(), persistent=False, ) self._init_weights() def get_block(self, i: int) -> TransformerBlock: """Typed accessor for transformer layers (avoids ModuleList type erasure).""" return self.layers[i] # type: ignore[return-value] def _init_weights(self): for p in self.parameters(): if p.dim() > 1: nn.init.normal_(p, mean=0.0, std=0.02) def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, hidden_only: bool = False, ) -> tuple[torch.Tensor, list[torch.Tensor]]: """ input_ids: (B, T) token indices attention_mask: (B, T) bool — True for real tokens (outcome + moves) hidden_only: if True, skip intermediate layer collection and return only the final hidden state in layer_outputs. Returns: logits: (B, T, vocab_size) layer_outputs: list of (B, T, d_model) from each layer """ x = self.embed(input_ids) T = input_ids.shape[1] if T > self.rope_cos.shape[2]: raise ValueError( f"Sequence length {T} exceeds max {self.rope_cos.shape[2]}" ) causal = self.causal_mask[:T, :T] padding = attention_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, T) mask = causal.unsqueeze(0) & padding # (B, 1, T, T) rope_cos = self.rope_cos[:, :, :T, :] rope_sin = self.rope_sin[:, :, :T, :] if hidden_only: for layer in self.layers: x = layer(x, rope_cos, rope_sin, mask) layer_outputs = [x] else: layer_outputs = [x] # embedding output for layer in self.layers: x = layer(x, rope_cos, rope_sin, mask) layer_outputs.append(x) x = self.final_norm(x) logits = self.lm_head(x) return logits, layer_outputs def forward_train( self, input_ids: torch.Tensor, loss_mask: torch.Tensor, targets: torch.Tensor, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """Training-optimized forward: computes lm_head only at non-padding positions to avoid materializing the full (B, T, vocab_size) logits tensor. Returns loss and metrics directly. Metrics are returned as raw GPU tensors to avoid CUDA synchronization. Call .item() on them only when you need to log (e.g. every N steps). Args: input_ids: (B, T) token indices loss_mask: (B, T) bool — True for positions included in loss (outcome + moves, not padding). Also used as the attention padding mask for SDPA. targets: (B, T) target token indices (padding positions ignored) Returns: loss: scalar tensor (for backward) metrics: dict with loss and accuracy as GPU tensors (no .item()) """ x = self.embed(input_ids) T = input_ids.shape[1] if T > self.rope_cos.shape[2]: raise ValueError( f"Sequence length {T} exceeds max {self.rope_cos.shape[2]}" ) causal = self.causal_mask[:T, :T] padding = loss_mask.unsqueeze(1).unsqueeze(2) mask = causal.unsqueeze(0) & padding rope_cos = self.rope_cos[:, :, :T, :] rope_sin = self.rope_sin[:, :, :T, :] for layer in self.layers: x = layer(x, rope_cos, rope_sin, mask) x = self.final_norm(x) # Project only valid positions through lm_head to save ~25% memory valid_x = x[loss_mask] # (N_valid, d_model) valid_logits = self.lm_head(valid_x) # (N_valid, vocab_size) valid_targets = targets[loss_mask] # (N_valid,) loss = F.cross_entropy(valid_logits, valid_targets) with torch.no_grad(): preds = valid_logits.argmax(dim=-1) accuracy = (preds == valid_targets).float().mean() return loss, {"loss": loss.detach(), "accuracy": accuracy} def forward_generate( self, input_ids: torch.Tensor, kv_cache: list[tuple[torch.Tensor, torch.Tensor]] | None = None, ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]: """Forward pass with KV-cache for autoregressive generation. Prefill (kv_cache=None): processes full input, builds cache. Decode (kv_cache provided): processes single new token, extends cache. Always returns logits for the last position only to save memory. Args: input_ids: (B, T) for prefill, (B, 1) for decode. kv_cache: None for prefill, list of (K, V) per layer for decode. Returns: logits: (B, 1, vocab_size) new_kv_cache: list of (K, V) per layer. """ x = self.embed(input_ids) T_new = input_ids.shape[1] T_total = T_new if kv_cache is not None: T_cached = kv_cache[0][0].shape[2] T_total = T_cached + T_new rope_cos = self.rope_cos[:, :, T_cached:T_total, :] rope_sin = self.rope_sin[:, :, T_cached:T_total, :] else: rope_cos = self.rope_cos[:, :, :T_new, :] rope_sin = self.rope_sin[:, :, :T_new, :] if T_total > self.rope_cos.shape[2]: raise ValueError( f"Sequence length {T_total} exceeds max {self.rope_cos.shape[2]}" ) new_kv_cache = [] for i in range(len(self.layers)): layer_cache = kv_cache[i] if kv_cache is not None else None x, new_cache = self.get_block(i).forward_kv(x, rope_cos, rope_sin, layer_cache) new_kv_cache.append(new_cache) x = self.final_norm(x[:, -1:, :]) logits = self.lm_head(x) return logits, new_kv_cache _IGNORE_INDEX = -100 def clm_loss( logits: torch.Tensor, targets: torch.Tensor, loss_mask: torch.Tensor, ) -> tuple[torch.Tensor, dict[str, float]]: """Compute CLM cross-entropy loss on non-padding positions. Uses ignore_index on a flat view to avoid materializing a copy of all valid-position logits (which would be ~50K × 4278 floats). Args: logits: (B, T, vocab_size) targets: (B, T) target token indices loss_mask: (B, T) bool — True for positions included in loss Returns: loss: scalar metrics: dict with loss value and accuracy """ B, T, V = logits.shape # Flat views — no copy logits_flat = logits.view(-1, V) # Set padding targets to ignore_index so cross_entropy skips them targets_flat = torch.where(loss_mask.view(-1), targets.view(-1), _IGNORE_INDEX) loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=_IGNORE_INDEX) # Top-1 accuracy (only at valid positions) with torch.no_grad(): preds = logits_flat.argmax(dim=-1) valid = targets_flat != _IGNORE_INDEX accuracy = (preds[valid] == targets_flat[valid]).float().mean().item() metrics = { "loss": loss.item(), "accuracy": accuracy, } return loss, metrics