""" model.py — Liquid Chess Model (LCM) architecture. Hybrid transformer with 6 GQA attention blocks and 10 LIV convolution blocks, distributed evenly via Bresenham algorithm. Trained with dual NTP + TOP objectives. Architecture highlights: - GQA (Grouped Query Attention) with RoPE positional embeddings - LIV (Local Input-dependent Value) causal convolution blocks - LRM (Learnable Rate Multipliers) on every block - Weight tying between embedding and NTP head - PyTorch SDPA for efficient attention """ import math import torch import torch.nn as nn import torch.nn.functional as F from config import ChessModelConfig # ══════════════════════════════════════════════════════════════════════════════ # SHARED COMPONENTS # ══════════════════════════════════════════════════════════════════════════════ class RMSNorm(nn.Module): """Root Mean Square Layer Normalization.""" def __init__(self, d_model: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(d_model)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt() return (x / rms) * self.weight # ══════════════════════════════════════════════════════════════════════════════ # LIV CONVOLUTION BLOCK # ══════════════════════════════════════════════════════════════════════════════ class LIVBlock(nn.Module): """ Local Input-dependent Value convolution block. Each token attends to itself and its nearest neighbors (kernel_size=4) using double gating. Efficient for capturing local sequential patterns. Structure: input → RMSNorm → project to 3× → split (B, C, x) → B gates x → causal conv → C gates result → project back → LRM scale → residual add """ def __init__(self, config: ChessModelConfig): super().__init__() d = config.d_model k = config.conv_kernel_size self.norm = RMSNorm(d) self.input_proj = nn.Linear(d, 3 * d, bias=False) self.conv = nn.Conv1d( in_channels=d, out_channels=d, kernel_size=k, padding=k - 1, groups=d, bias=False, ) self.output_proj = nn.Linear(d, d, bias=False) self.dropout = nn.Dropout(config.dropout) self.lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x x = self.norm(x) B, C, x = self.input_proj(x).chunk(3, dim=-1) x = B * x x = self.conv(x.transpose(1, 2)) x = x[:, :, :residual.shape[1]] # trim for causality x = C * x.transpose(1, 2) x = self.dropout(self.output_proj(x)) if self.lrm is not None: x = x * self.lrm return residual + x # ══════════════════════════════════════════════════════════════════════════════ # GQA ATTENTION BLOCK # ══════════════════════════════════════════════════════════════════════════════ def build_rope_cache( seq_len: int, head_dim: int, device: torch.device ) -> tuple[torch.Tensor, torch.Tensor]: """Precompute RoPE cosine and sine tables.""" theta = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)) positions = torch.arange(seq_len, device=device).float() freqs = torch.outer(positions, theta) return torch.cos(freqs), torch.sin(freqs) def apply_rope( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> torch.Tensor: """Apply rotary position embeddings to a query or key tensor.""" x1, x2 = x[..., ::2], x[..., 1::2] cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2) class SwiGLU(nn.Module): """SwiGLU feed-forward network.""" def __init__(self, config: ChessModelConfig): super().__init__() d, h = config.d_model, config.ffn_hidden_size self.gate_proj = nn.Linear(d, h, bias=False) self.up_proj = nn.Linear(d, h, bias=False) self.down_proj = nn.Linear(h, d, bias=False) self.dropout = nn.Dropout(config.dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))) class GQABlock(nn.Module): """ Grouped Query Attention block with SwiGLU FFN and RoPE. Uses PyTorch's scaled_dot_product_attention for efficiency. """ def __init__(self, config: ChessModelConfig): super().__init__() d = config.d_model self.n_heads = config.n_heads self.n_kv_heads = config.n_kv_heads self.head_dim = config.head_dim self.repeats = config.n_heads // config.n_kv_heads self.attn_norm = RMSNorm(d) self.ffn_norm = RMSNorm(d) self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(d, d, bias=False) self.ffn = SwiGLU(config) self.dropout = nn.Dropout(config.dropout) self.attn_lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None self.ffn_lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None def forward( self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor ) -> torch.Tensor: B, T, _ = x.shape # ── Attention ───────────────────────────────────────────────────────── residual = x x_norm = self.attn_norm(x) q = self.q_proj(x_norm).view(B, T, self.n_heads, self.head_dim) k = self.k_proj(x_norm).view(B, T, self.n_kv_heads, self.head_dim) v = self.v_proj(x_norm).view(B, T, self.n_kv_heads, self.head_dim) q = apply_rope(q.transpose(1, 2), freqs_cos, freqs_sin).transpose(1, 2) k = apply_rope(k.transpose(1, 2), freqs_cos, freqs_sin).transpose(1, 2) # Expand KV heads to match query heads k = k.repeat_interleave(self.repeats, dim=2).transpose(1, 2) v = v.repeat_interleave(self.repeats, dim=2).transpose(1, 2) q = q.transpose(1, 2) attn_out = F.scaled_dot_product_attention( q, k, v, dropout_p=self.dropout.p if self.training else 0.0, is_causal=True, ).transpose(1, 2).reshape(B, T, -1) attn_out = self.o_proj(attn_out) if self.attn_lrm is not None: attn_out = attn_out * self.attn_lrm x = residual + attn_out # ── FFN ─────────────────────────────────────────────────────────────── residual = x ffn_out = self.ffn(self.ffn_norm(x)) if self.ffn_lrm is not None: ffn_out = ffn_out * self.ffn_lrm return residual + ffn_out # ══════════════════════════════════════════════════════════════════════════════ # LAYER DISTRIBUTION # ══════════════════════════════════════════════════════════════════════════════ def get_layer_types(n_layers: int, n_gqa: int) -> list[str]: """ Distribute GQA layers evenly through the network using a Bresenham-style integer accumulator. Avoids floating-point rounding collisions. Always places a GQA block first. Example (16 layers, 6 GQA): GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA """ if n_gqa == 0: return ["liv"] * n_layers if n_gqa >= n_layers: return ["gqa"] * n_layers layer_types = ["liv"] * n_layers layer_types[0] = "gqa" gqa_placed = 1 remaining = n_gqa - 1 slots = n_layers - 1 accumulator = 0 for i in range(1, n_layers): accumulator += remaining if accumulator >= slots: layer_types[i] = "gqa" accumulator -= slots gqa_placed += 1 if gqa_placed == n_gqa: break return layer_types # ══════════════════════════════════════════════════════════════════════════════ # FULL MODEL # ══════════════════════════════════════════════════════════════════════════════ class ChessModel(nn.Module): """ Liquid Chess Model (LCM). Input: token IDs (batch_size, seq_len) Output: ntp_logits (batch_size, seq_len, vocab_size) — move generation top_logits (batch_size, seq_len, vocab_size) — auxiliary training only """ def __init__(self, config: ChessModelConfig): super().__init__() self.config = config self.embedding = nn.Embedding( config.vocab_size, config.d_model, padding_idx=config.pad_id ) layer_types = get_layer_types(config.n_layers, config.n_gqa_layers) self.blocks = nn.ModuleList([ GQABlock(config) if lt == "gqa" else LIVBlock(config) for lt in layer_types ]) self.layer_types = layer_types self.norm = RMSNorm(config.d_model) self.ntp_head = nn.Linear(config.d_model, config.vocab_size, bias=False) self.top_head = nn.Linear(config.d_model, config.vocab_size, bias=False) # Weight tying: embedding and NTP head are inverse operations self.ntp_head.weight = self.embedding.weight freqs_cos, freqs_sin = build_rope_cache( config.max_seq_len, config.head_dim, device=torch.device("cpu") ) self.register_buffer("freqs_cos", freqs_cos) self.register_buffer("freqs_sin", freqs_sin) self._init_weights() def _init_weights(self): for module in self.modules(): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() # Scale down output projections to stabilize residual stream for name, param in self.named_parameters(): if "o_proj" in name or "down_proj" in name: nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layers)) def forward( self, token_ids: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: B, T = token_ids.shape assert T <= self.config.max_seq_len, \ f"Sequence length {T} exceeds maximum {self.config.max_seq_len}" x = self.embedding(token_ids) freqs_cos = self.freqs_cos[:T] freqs_sin = self.freqs_sin[:T] for block, lt in zip(self.blocks, self.layer_types): x = block(x, freqs_cos, freqs_sin) if lt == "gqa" else block(x) x = self.norm(x) return self.ntp_head(x), self.top_head(x) def count_parameters(self) -> int: return sum(p.numel() for p in self.parameters() if p.requires_grad) if __name__ == "__main__": from model.config import ChessModelConfig config = ChessModelConfig() model = ChessModel(config) params = model.count_parameters() print(f"Parameters: {params:,} ({params/1e6:.1f}M)") x = torch.randint(0, config.vocab_size, (2, 255)) ntp, top = model(x) assert ntp.shape == (2, 255, config.vocab_size) assert top.shape == (2, 255, config.vocab_size) print(f"Forward pass: {ntp.shape} ✓")