# NOTE: removed "Spaces: Running" page header — Hugging Face Spaces scraper residue, not source code.
| """ | |
| model.py — Liquid Chess Model (LCM) architecture. | |
| Hybrid transformer with 6 GQA attention blocks and 10 LIV convolution blocks, | |
| distributed evenly via Bresenham algorithm. Trained with dual NTP + TOP objectives. | |
| Architecture highlights: | |
| - GQA (Grouped Query Attention) with RoPE positional embeddings | |
| - LIV (Local Input-dependent Value) causal convolution blocks | |
| - LRM (Learnable Rate Multipliers) on every block | |
| - Weight tying between embedding and NTP head | |
| - PyTorch SDPA for efficient attention | |
| """ | |
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from config import ChessModelConfig
# ══════════════════════════════════════════════════════════════════════════════
# SHARED COMPONENTS
# ══════════════════════════════════════════════════════════════════════════════
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (scale-only, no mean-centering).

    Normalizes each feature vector by its root-mean-square magnitude and
    applies a learned per-channel gain.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        # Learned gain, initialized to identity.
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps  # added inside the sqrt for numerical stability

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / norm)
# ══════════════════════════════════════════════════════════════════════════════
# LIV CONVOLUTION BLOCK
# ══════════════════════════════════════════════════════════════════════════════
class LIVBlock(nn.Module):
    """
    Local Input-dependent Value convolution block.

    Each position mixes information from itself and its nearest predecessors
    (depthwise causal convolution, kernel_size from config) under two learned,
    input-dependent gates. Efficient for capturing local sequential patterns.

    Pipeline:
        input → RMSNorm → Linear to 3×d → split into (pre_gate, post_gate, h)
              → pre_gate ⊙ h → causal depthwise conv → post_gate ⊙ result
              → Linear back to d → dropout → optional LRM scale → residual add
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        dim = config.d_model
        kernel = config.conv_kernel_size
        self.norm = RMSNorm(dim)
        # One projection produces both gates and the value stream at once.
        self.input_proj = nn.Linear(dim, 3 * dim, bias=False)
        # Depthwise (groups=dim) conv, padded by kernel-1; the extra tail
        # positions are trimmed in forward() so no token sees the future.
        self.conv = nn.Conv1d(
            in_channels=dim, out_channels=dim, kernel_size=kernel,
            padding=kernel - 1, groups=dim, bias=False,
        )
        self.output_proj = nn.Linear(dim, dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        # Learnable Rate Multiplier: per-channel scale on the branch output.
        self.lrm = nn.Parameter(torch.ones(dim)) if config.use_lrm else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        seq_len = x.shape[1]
        pre_gate, post_gate, h = self.input_proj(self.norm(x)).chunk(3, dim=-1)
        gated = pre_gate * h
        # Conv1d expects (batch, channels, time); trim back to length T for causality.
        conv_out = self.conv(gated.transpose(1, 2))[:, :, :seq_len]
        out = post_gate * conv_out.transpose(1, 2)
        out = self.dropout(self.output_proj(out))
        if self.lrm is not None:
            out = out * self.lrm
        return residual + out
# ══════════════════════════════════════════════════════════════════════════════
# GQA ATTENTION BLOCK
# ══════════════════════════════════════════════════════════════════════════════
def build_rope_cache(
    seq_len: int, head_dim: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute RoPE cosine and sine tables.

    Returns two tensors of shape (seq_len, head_dim // 2): the cos and sin of
    position * inverse-frequency, with the standard 10000-base frequency
    schedule over channel pairs.
    """
    pair_indices = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (10000 ** (pair_indices / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()
def apply_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """
    Apply rotary position embeddings to a query or key tensor.

    Even/odd channel pairs are treated as 2-D points and rotated by the
    position-dependent angle encoded in cos/sin (shape (seq, head_dim // 2)).
    x is broadcast against cos/sin over its two leading dimensions — assumes
    layout (batch, heads, seq, head_dim); TODO confirm against callers.
    """
    even, odd = x[..., 0::2], x[..., 1::2]
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]
    rotated_even = even * cos_b - odd * sin_b
    rotated_odd = even * sin_b + odd * cos_b
    # Re-interleave the rotated pairs back into the original channel order.
    return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)
class SwiGLU(nn.Module):
    """SwiGLU feed-forward network: down( silu(gate(x)) ⊙ up(x) ), then dropout."""

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        d_model, hidden = config.d_model, config.ffn_hidden_size
        self.gate_proj = nn.Linear(d_model, hidden, bias=False)
        self.up_proj = nn.Linear(d_model, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.silu(self.gate_proj(x))
        hidden = activated * self.up_proj(x)
        return self.dropout(self.down_proj(hidden))
class GQABlock(nn.Module):
    """
    Grouped Query Attention block with SwiGLU FFN and RoPE.

    Pre-norm residual layout:
        x = x + LRM ⊙ Attn(RMSNorm(x))
        x = x + LRM ⊙ FFN(RMSNorm(x))

    K/V use n_kv_heads heads that are repeated to match the n_heads query
    heads (GQA). Attention runs through PyTorch's
    scaled_dot_product_attention with a causal mask.

    Change vs. original: tensors are moved to (batch, heads, seq, head_dim)
    once after projection and kept there; the original transposed back and
    forth around apply_rope and again before SDPA for no benefit.
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        d = config.d_model
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = config.head_dim
        # How many query heads share each KV head.
        self.repeats = config.n_heads // config.n_kv_heads
        self.attn_norm = RMSNorm(d)
        self.ffn_norm = RMSNorm(d)
        self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(d, d, bias=False)
        self.ffn = SwiGLU(config)
        self.dropout = nn.Dropout(config.dropout)
        # Learnable Rate Multipliers on each residual branch.
        self.attn_lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None
        self.ffn_lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None

    def forward(
        self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            x: (batch, seq, d_model) hidden states.
            freqs_cos / freqs_sin: (seq, head_dim // 2) RoPE tables.
        Returns:
            (batch, seq, d_model) hidden states.
        """
        B, T, _ = x.shape
        # ── Attention ─────────────────────────────────────────────────────────
        residual = x
        x_norm = self.attn_norm(x)
        # Project and move to (batch, heads, seq, head_dim) once.
        q = self.q_proj(x_norm).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x_norm).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x_norm).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        q = apply_rope(q, freqs_cos, freqs_sin)
        k = apply_rope(k, freqs_cos, freqs_sin)
        # Repeat each KV head so every query head has a matching K/V (GQA).
        k = k.repeat_interleave(self.repeats, dim=1)
        v = v.repeat_interleave(self.repeats, dim=1)
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            # Attention dropout is active only in training mode.
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=True,
        ).transpose(1, 2).reshape(B, T, -1)
        attn_out = self.o_proj(attn_out)
        if self.attn_lrm is not None:
            attn_out = attn_out * self.attn_lrm
        x = residual + attn_out
        # ── FFN ───────────────────────────────────────────────────────────────
        residual = x
        ffn_out = self.ffn(self.ffn_norm(x))
        if self.ffn_lrm is not None:
            ffn_out = ffn_out * self.ffn_lrm
        return residual + ffn_out
# ══════════════════════════════════════════════════════════════════════════════
# LAYER DISTRIBUTION
# ══════════════════════════════════════════════════════════════════════════════
def get_layer_types(n_layers: int, n_gqa: int) -> list[str]:
    """
    Spread n_gqa "gqa" entries evenly across n_layers slots, filling the rest
    with "liv". Uses a Bresenham-style integer error accumulator, so there are
    no floating-point rounding collisions. Layer 0 is always "gqa" whenever at
    least one GQA block is requested.

    Example (16 layers, 6 GQA):
        GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA
    """
    # Degenerate cases: all-LIV or all-GQA stacks.
    if n_gqa == 0:
        return ["liv"] * n_layers
    if n_gqa >= n_layers:
        return ["gqa"] * n_layers

    types = ["liv"] * n_layers
    types[0] = "gqa"
    placed = 1
    to_place = n_gqa - 1
    slots = n_layers - 1
    error = 0
    position = 1
    # Distribute the remaining GQA layers over the remaining slots.
    while position < n_layers and placed < n_gqa:
        error += to_place
        if error >= slots:
            types[position] = "gqa"
            error -= slots
            placed += 1
        position += 1
    return types
# ══════════════════════════════════════════════════════════════════════════════
# FULL MODEL
# ══════════════════════════════════════════════════════════════════════════════
class ChessModel(nn.Module):
    """
    Liquid Chess Model (LCM).

    A hybrid stack of GQA attention blocks and LIV convolution blocks
    (placement decided by get_layer_types), with two output heads.

    Input:  token IDs (batch_size, seq_len)
    Output: ntp_logits (batch_size, seq_len, vocab_size) — move generation
            top_logits (batch_size, seq_len, vocab_size) — auxiliary training only
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(
            config.vocab_size, config.d_model, padding_idx=config.pad_id
        )
        layer_types = get_layer_types(config.n_layers, config.n_gqa_layers)
        self.blocks = nn.ModuleList([
            GQABlock(config) if lt == "gqa" else LIVBlock(config)
            for lt in layer_types
        ])
        self.layer_types = layer_types
        self.norm = RMSNorm(config.d_model)
        self.ntp_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.top_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: embedding and NTP head share one parameter tensor.
        self.ntp_head.weight = self.embedding.weight
        # RoPE tables are built once on CPU; register_buffer makes them follow
        # the model through .to()/.cuda().
        freqs_cos, freqs_sin = build_rope_cache(
            config.max_seq_len, config.head_dim, device=torch.device("cpu")
        )
        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)
        self._init_weights()

    def _init_weights(self):
        """Initialize all weights (normal, std=0.02) with residual-scaled outputs."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
        # Scale down output projections to stabilize the residual stream
        # (GPT-2-style 1/sqrt(2 * n_layers) scaling).
        for name, param in self.named_parameters():
            if "o_proj" in name or "down_proj" in name:
                nn.init.normal_(param, mean=0.0,
                                std=0.02 / math.sqrt(2 * self.config.n_layers))
        # BUG FIX: zero the padding row LAST. The tied ntp_head is also an
        # nn.Linear, so the modules() loop above re-randomized the shared
        # embedding tensor *after* the Embedding branch had zeroed the
        # padding row, leaving the pad embedding nonzero.
        if self.embedding.padding_idx is not None:
            with torch.no_grad():
                self.embedding.weight[self.embedding.padding_idx].zero_()

    def forward(
        self, token_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the full stack; returns (ntp_logits, top_logits)."""
        B, T = token_ids.shape
        assert T <= self.config.max_seq_len, \
            f"Sequence length {T} exceeds maximum {self.config.max_seq_len}"
        x = self.embedding(token_ids)
        # Slice the precomputed RoPE tables to the current sequence length.
        freqs_cos = self.freqs_cos[:T]
        freqs_sin = self.freqs_sin[:T]
        for block, lt in zip(self.blocks, self.layer_types):
            x = block(x, freqs_cos, freqs_sin) if lt == "gqa" else block(x)
        x = self.norm(x)
        return self.ntp_head(x), self.top_head(x)

    def count_parameters(self) -> int:
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
| if __name__ == "__main__": | |
| from model.config import ChessModelConfig | |
| config = ChessModelConfig() | |
| model = ChessModel(config) | |
| params = model.count_parameters() | |
| print(f"Parameters: {params:,} ({params/1e6:.1f}M)") | |
| x = torch.randint(0, config.vocab_size, (2, 255)) | |
| ntp, top = model(x) | |
| assert ntp.shape == (2, 255, config.vocab_size) | |
| assert top.shape == (2, 255, config.vocab_size) | |
| print(f"Forward pass: {ntp.shape} ✓") |