| | """
|
| | model.py β Liquid Chess Model (LCM) architecture.
|
| |
|
| | Hybrid transformer with 6 GQA attention blocks and 10 LIV convolution blocks,
|
| | distributed evenly via Bresenham algorithm. Trained with dual NTP + TOP objectives.
|
| |
|
| | Architecture highlights:
|
| | - GQA (Grouped Query Attention) with RoPE positional embeddings
|
| | - LIV (Local Input-dependent Value) causal convolution blocks
|
| | - LRM (Learnable Rate Multipliers) on every block
|
| | - Weight tying between embedding and NTP head
|
| | - PyTorch SDPA for efficient attention
|
| | """
|
| |
|
| | import math
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| |
|
| | from config import ChessModelConfig
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
class RMSNorm(nn.Module):
    """RMSNorm: rescale each feature vector by its root-mean-square, then
    apply a learned per-channel gain. No mean subtraction, unlike LayerNorm."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        # Per-channel gain, initialized to identity.
        self.weight = nn.Parameter(torch.ones(d_model))
        # Small constant added under the square root for numerical stability.
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x / (mean_sq + self.eps).sqrt()
        return normed * self.weight
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
class LIVBlock(nn.Module):
    """
    Local Input-dependent Value convolution block.

    Each position mixes information from itself and its kernel_size - 1
    predecessors via a causal depthwise convolution, with input-dependent
    gates applied both before and after the convolution (double gating).

    Flow:
        input -> RMSNorm -> 3x projection -> split (pre-gate, post-gate, value)
              -> pre-gate * value -> causal conv -> post-gate * result
              -> output projection -> dropout -> optional LRM scale -> residual
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        dim = config.d_model
        kernel = config.conv_kernel_size

        self.norm = RMSNorm(dim)
        self.input_proj = nn.Linear(dim, 3 * dim, bias=False)
        # Depthwise (groups == channels) conv, padded by kernel-1 on both
        # sides; the extra right-side outputs are trimmed in forward() so
        # each output only depends on current and past positions.
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel,
            padding=kernel - 1,
            groups=dim,
            bias=False,
        )
        self.output_proj = nn.Linear(dim, dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        # Learnable Rate Multiplier: per-channel scale on the residual branch.
        self.lrm = nn.Parameter(torch.ones(dim)) if config.use_lrm else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = x
        pre_gate, post_gate, value = self.input_proj(self.norm(x)).chunk(3, dim=-1)
        gated = pre_gate * value
        # Conv1d wants (batch, channels, time); trim to original length to
        # discard the non-causal right-side padding.
        conv_out = self.conv(gated.transpose(1, 2))[:, :, :skip.shape[1]]
        out = post_gate * conv_out.transpose(1, 2)
        out = self.dropout(self.output_proj(out))
        if self.lrm is not None:
            out = out * self.lrm
        return skip + out
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def build_rope_cache(
    seq_len: int, head_dim: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """Precompute RoPE cos/sin tables, each of shape (seq_len, head_dim // 2)."""
    # Inverse frequencies for each even feature index (standard RoPE base 10000).
    half_dims = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (10000 ** (half_dims / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    # Outer product: angle per (position, frequency) pair.
    angles = positions[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()
|
| |
|
| |
|
def apply_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Rotate each (even, odd) feature pair of x by its position's RoPE angle.

    x is expected as (batch, heads, seq, head_dim); cos/sin broadcast over
    the leading batch and head dimensions.
    """
    even = x[..., 0::2]
    odd = x[..., 1::2]
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]
    rot_even = even * cos_b - odd * sin_b
    rot_odd = even * sin_b + odd * cos_b
    # Stack + flatten re-interleaves the rotated pairs back to head_dim.
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)
|
| |
|
| |
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x)),
    followed by dropout. All projections are bias-free."""

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        d_in, d_hidden = config.d_model, config.ffn_hidden_size
        self.gate_proj = nn.Linear(d_in, d_hidden, bias=False)
        self.up_proj = nn.Linear(d_in, d_hidden, bias=False)
        self.down_proj = nn.Linear(d_hidden, d_in, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
|
| |
|
| |
|
class GQABlock(nn.Module):
    """
    Pre-norm transformer block: grouped-query attention (with RoPE) followed
    by a SwiGLU feed-forward, each with a residual connection and an optional
    learnable per-channel scale (LRM).

    Attention uses torch's scaled_dot_product_attention with a causal mask;
    KV heads are repeated to match the query-head count.
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        dim = config.d_model
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = config.head_dim
        # How many query heads share each KV head.
        self.repeats = config.n_heads // config.n_kv_heads

        self.attn_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)

        self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)

        self.ffn = SwiGLU(config)
        self.dropout = nn.Dropout(config.dropout)

        self.attn_lrm = nn.Parameter(torch.ones(dim)) if config.use_lrm else None
        self.ffn_lrm = nn.Parameter(torch.ones(dim)) if config.use_lrm else None

    def forward(
        self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
    ) -> torch.Tensor:
        batch, seq_len, _ = x.shape

        # --- attention sub-block ------------------------------------------
        normed = self.attn_norm(x)
        # Project and reshape straight to (batch, heads, seq, head_dim).
        q = self.q_proj(normed).view(
            batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(normed).view(
            batch, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(normed).view(
            batch, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

        q = apply_rope(q, freqs_cos, freqs_sin)
        k = apply_rope(k, freqs_cos, freqs_sin)

        # Expand KV heads along the head axis so every query head has a
        # matching key/value head.
        k = k.repeat_interleave(self.repeats, dim=1)
        v = v.repeat_interleave(self.repeats, dim=1)

        attn = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=True,
        )
        attn = self.o_proj(attn.transpose(1, 2).reshape(batch, seq_len, -1))
        if self.attn_lrm is not None:
            attn = attn * self.attn_lrm
        x = x + attn

        # --- feed-forward sub-block ---------------------------------------
        ffn_out = self.ffn(self.ffn_norm(x))
        if self.ffn_lrm is not None:
            ffn_out = ffn_out * self.ffn_lrm
        return x + ffn_out
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def get_layer_types(n_layers: int, n_gqa: int) -> list[str]:
    """
    Spread n_gqa "gqa" layers evenly among "liv" layers using an integer
    error accumulator (Bresenham-style), avoiding floating-point rounding
    collisions. The first layer is always "gqa".

    Example (16 layers, 6 GQA):
        GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA
    """
    if n_gqa == 0:
        return ["liv"] * n_layers
    if n_gqa >= n_layers:
        return ["gqa"] * n_layers

    # One GQA pinned at position 0; distribute the rest over the remainder.
    types = ["gqa"] + ["liv"] * (n_layers - 1)
    to_place = n_gqa - 1
    span = n_layers - 1
    error = 0
    placed = 1

    idx = 1
    while idx < n_layers and placed < n_gqa:
        error += to_place
        if error >= span:
            # Accumulated error crossed a full slot: emit a GQA layer here.
            error -= span
            types[idx] = "gqa"
            placed += 1
        idx += 1

    return types
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
class ChessModel(nn.Module):
    """
    Liquid Chess Model (LCM).

    Input:  token IDs (batch_size, seq_len)
    Output: ntp_logits (batch_size, seq_len, vocab_size) - move generation
            top_logits (batch_size, seq_len, vocab_size) - auxiliary training only
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(
            config.vocab_size, config.d_model, padding_idx=config.pad_id
        )

        # Interleave GQA and LIV blocks per the Bresenham layout.
        layer_types = get_layer_types(config.n_layers, config.n_gqa_layers)
        self.blocks = nn.ModuleList([
            GQABlock(config) if lt == "gqa" else LIVBlock(config)
            for lt in layer_types
        ])
        self.layer_types = layer_types

        self.norm = RMSNorm(config.d_model)
        self.ntp_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.top_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: the NTP head shares the embedding matrix.
        self.ntp_head.weight = self.embedding.weight

        # RoPE tables built on CPU; buffers move with the model on .to(device).
        freqs_cos, freqs_sin = build_rope_cache(
            config.max_seq_len, config.head_dim, device=torch.device("cpu")
        )
        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        self._init_weights()

    def _init_weights(self):
        """Initialize all weights; scale residual-output projections down."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)

        # GPT-2-style scaled init on projections that feed residual sums,
        # to keep the residual-stream variance roughly constant with depth.
        for name, param in self.named_parameters():
            if "o_proj" in name or "down_proj" in name:
                nn.init.normal_(param, mean=0.0,
                                std=0.02 / math.sqrt(2 * self.config.n_layers))

        # BUG FIX: zero the padding row LAST. ntp_head.weight is tied to
        # embedding.weight, so the nn.Linear branch above re-randomizes the
        # shared matrix AFTER nn.Embedding's constructor zeroed the padding
        # row; padding_idx only zeroes the gradient, not the value, so the
        # padding embedding would otherwise stay non-zero forever.
        if self.embedding.padding_idx is not None:
            with torch.no_grad():
                self.embedding.weight[self.embedding.padding_idx].zero_()

    def forward(
        self, token_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the full stack; returns (ntp_logits, top_logits)."""
        B, T = token_ids.shape
        assert T <= self.config.max_seq_len, \
            f"Sequence length {T} exceeds maximum {self.config.max_seq_len}"

        x = self.embedding(token_ids)
        # Slice the precomputed RoPE tables to the current sequence length.
        freqs_cos = self.freqs_cos[:T]
        freqs_sin = self.freqs_sin[:T]

        for block, lt in zip(self.blocks, self.layer_types):
            # Only GQA blocks take the RoPE tables; LIV blocks are purely local.
            x = block(x, freqs_cos, freqs_sin) if lt == "gqa" else block(x)

        x = self.norm(x)
        return self.ntp_head(x), self.top_head(x)

    def count_parameters(self) -> int:
        """Number of trainable parameters (tied weights counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| |
|
| |
|
| | if __name__ == "__main__":
|
| | from model.config import ChessModelConfig
|
| |
|
| | config = ChessModelConfig()
|
| | model = ChessModel(config)
|
| | params = model.count_parameters()
|
| | print(f"Parameters: {params:,} ({params/1e6:.1f}M)")
|
| |
|
| | x = torch.randint(0, config.vocab_size, (2, 255))
|
| | ntp, top = model(x)
|
| | assert ntp.shape == (2, 255, config.vocab_size)
|
| | assert top.shape == (2, 255, config.vocab_size)
|
| | print(f"Forward pass: {ntp.shape} β") |