# Uploaded by MostLime ("Upload 5 files"), commit 676fbdb (verified).
"""
model.py — Liquid Chess Model (LCM) architecture.
Hybrid transformer with 6 GQA attention blocks and 10 LIV convolution blocks,
distributed evenly via Bresenham algorithm. Trained with dual NTP + TOP objectives.
Architecture highlights:
- GQA (Grouped Query Attention) with RoPE positional embeddings
- LIV (Local Input-dependent Value) causal convolution blocks
- LRM (Learnable Rate Multipliers) on every block
- Weight tying between embedding and NTP head
- PyTorch SDPA for efficient attention
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import ChessModelConfig
# ══════════════════════════════════════════════════════════════════════════════
# SHARED COMPONENTS
# ══════════════════════════════════════════════════════════════════════════════
class RMSNorm(nn.Module):
    """Normalize by the root mean square over the last dimension.

    Computes ``x / sqrt(mean(x ** 2, dim=-1) + eps) * weight``. There is no
    mean-centering and no bias. ``weight`` is a learnable per-feature gain
    initialised to ones.

    Args:
        d_model: size of the last dimension of the input (length of ``weight``).
        eps: constant added inside the square root for numerical safety.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
        denominator = torch.sqrt(mean_square + self.eps)
        return x / denominator * self.weight
# ══════════════════════════════════════════════════════════════════════════════
# LIV CONVOLUTION BLOCK
# ══════════════════════════════════════════════════════════════════════════════
class LIVBlock(nn.Module):
    """
    Local Input-dependent Value convolution block.

    Mixes each position with itself and the ``conv_kernel_size - 1`` positions
    immediately BEFORE it. The mixing is a causal, depthwise 1-D convolution
    wrapped in two input-dependent multiplicative gates, and is much cheaper
    than attention.

    Structure:
        input → RMSNorm → project to 3·d → split (b_gate, c_gate, h)
              → b_gate ⊙ h → causal depthwise conv → c_gate ⊙ result
              → output projection → dropout → LRM scale → residual add

    Args:
        config: reads ``d_model``, ``conv_kernel_size``, ``dropout`` and
            ``use_lrm``.
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        d = config.d_model
        k = config.conv_kernel_size
        self.norm = RMSNorm(d)
        self.input_proj = nn.Linear(d, 3 * d, bias=False)
        # groups=d: one length-k filter per channel (depthwise). The input is
        # padded by k - 1 on both sides; forward() keeps only the first T
        # outputs, so output t depends on inputs t-k+1 .. t only.
        self.conv = nn.Conv1d(
            in_channels=d, out_channels=d, kernel_size=k,
            padding=k - 1, groups=d, bias=False,
        )
        self.output_proj = nn.Linear(d, d, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        # Per-channel learnable output scale (LRM), or None when disabled.
        self.lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` and add the residual.

        ``x`` must be ``(batch, seq_len, d_model)``: its last dimension is split
        into three ``d_model`` chunks, and dim 2 is moved to the channel axis
        for ``Conv1d``. Returns a tensor of the same shape.
        """
        residual = x
        seq_len = x.shape[1]
        b_gate, c_gate, h = self.input_proj(self.norm(x)).chunk(3, dim=-1)
        h = b_gate * h
        # Drop the trailing k - 1 outputs produced by the right-hand padding,
        # which would otherwise look at future positions.
        h = self.conv(h.transpose(1, 2))[:, :, :seq_len]
        h = c_gate * h.transpose(1, 2)
        h = self.dropout(self.output_proj(h))
        if self.lrm is not None:
            h = h * self.lrm
        return residual + h
# ══════════════════════════════════════════════════════════════════════════════
# GQA ATTENTION BLOCK
# ══════════════════════════════════════════════════════════════════════════════
def build_rope_cache(
    seq_len: int, head_dim: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """Precompute the RoPE cosine and sine tables.

    Channel pair ``i`` (even index ``2i``) rotates at frequency
    ``10000 ** (-2i / head_dim)``. Entry ``[p, i]`` of each table is the cosine
    or sine of the angle ``p * freq_i``.

    Args:
        seq_len: number of positions (rows).
        head_dim: per-head feature size; one column per even index in
            ``range(0, head_dim, 2)``.
        device: device on which the tables are created.

    Returns:
        ``(cos, sin)``, two float32 tensors of shape ``(seq_len, head_dim // 2)``
        for even ``head_dim``.
    """
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (10000 ** exponents)
    pos = torch.arange(seq_len, device=device).float()
    # Broadcast outer product: angles[p, i] = pos[p] * inv_freq[i].
    angles = pos[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()
def apply_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Rotate interleaved channel pairs of ``x`` by the given RoPE angles.

    Channels ``(2i, 2i + 1)`` form a pair. Each pair is rotated by the angle
    whose cosine and sine are ``cos[..., i]`` and ``sin[..., i]``.

    Args:
        x: 4-D tensor whose last dimension holds the paired channels. Assumed
            to be ``(batch, heads, seq_len, head_dim)`` as in ``GQABlock``.
        cos, sin: 2-D tables broadcast against the last two dims of ``x``
            (``(seq_len, head_dim // 2)``).

    Returns:
        Tensor with the same shape as ``x``.
    """
    even = x[..., 0::2]
    odd = x[..., 1::2]
    # Add two leading singleton dims so the tables broadcast over batch and heads.
    c = cos[None, None]
    s = sin[None, None]
    rotated_even = even * c - odd * s
    rotated_odd = even * s + odd * c
    # Re-interleave: stack pairs on a new last dim, then merge it back.
    return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(start_dim=-2)
class SwiGLU(nn.Module):
    """SwiGLU feed-forward network.

    Computes ``dropout(down_proj(silu(gate_proj(x)) * up_proj(x)))``. All
    three projections are bias-free.

    Args:
        config: reads ``d_model``, ``ffn_hidden_size`` and ``dropout``.
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        model_dim = config.d_model
        hidden_dim = config.ffn_hidden_size
        self.gate_proj = nn.Linear(model_dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, model_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated_gate = F.silu(self.gate_proj(x))
        hidden = activated_gate * self.up_proj(x)
        return self.dropout(self.down_proj(hidden))
class GQABlock(nn.Module):
    """
    Grouped Query Attention block with SwiGLU FFN and RoPE.

    The block is a pre-norm layer with two sub-layers:
      * causal self-attention, in which ``n_heads`` query heads share
        ``n_kv_heads`` key/value heads (query head ``h`` reads KV head
        ``h // (n_heads // n_kv_heads)``);
      * a SwiGLU feed-forward.

    Each sub-layer output is optionally scaled by a per-channel learnable
    multiplier (LRM) before its residual add. Attention uses PyTorch's
    ``scaled_dot_product_attention``.

    Args:
        config: reads ``d_model``, ``n_heads``, ``n_kv_heads``, ``head_dim``,
            ``dropout`` and ``use_lrm`` (plus the fields ``SwiGLU`` reads).

    Raises:
        ValueError: if ``n_kv_heads`` is not positive or does not divide
            ``n_heads``.
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        if config.n_kv_heads <= 0 or config.n_heads % config.n_kv_heads != 0:
            raise ValueError(
                "n_kv_heads must be positive and divide n_heads "
                f"(got n_heads={config.n_heads}, n_kv_heads={config.n_kv_heads})"
            )
        d = config.d_model
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = config.head_dim
        # Number of query heads sharing each key/value head.
        self.repeats = config.n_heads // config.n_kv_heads
        q_dim = self.n_heads * self.head_dim
        kv_dim = self.n_kv_heads * self.head_dim
        self.attn_norm = RMSNorm(d)
        self.ffn_norm = RMSNorm(d)
        self.q_proj = nn.Linear(d, q_dim, bias=False)
        self.k_proj = nn.Linear(d, kv_dim, bias=False)
        self.v_proj = nn.Linear(d, kv_dim, bias=False)
        # The input is the concatenation of all head outputs (n_heads * head_dim).
        # This only equals d_model when head_dim == d_model // n_heads.
        self.o_proj = nn.Linear(q_dim, d, bias=False)
        self.ffn = SwiGLU(config)
        # Used only for its rate, passed to SDPA as dropout_p while training.
        self.dropout = nn.Dropout(config.dropout)
        self.attn_lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None
        self.ffn_lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None

    def forward(
        self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
    ) -> torch.Tensor:
        """Run attention and FFN sub-layers, each with a residual connection.

        Args:
            x: ``(batch, seq_len, d_model)``.
            freqs_cos, freqs_sin: RoPE tables for the first ``seq_len``
                positions, shaped ``(seq_len, head_dim // 2)`` to broadcast in
                ``apply_rope``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, T, _ = x.shape

        # ── Attention ─────────────────────────────────────────────────────────
        residual = x
        x_norm = self.attn_norm(x)
        # (B, T, heads * head_dim) → (B, heads, T, head_dim). This is the
        # layout that both apply_rope and scaled_dot_product_attention use.
        q = self.q_proj(x_norm).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x_norm).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x_norm).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        q = apply_rope(q, freqs_cos, freqs_sin)
        k = apply_rope(k, freqs_cos, freqs_sin)
        # Expand KV heads to match the query heads (consecutive query heads share one).
        if self.repeats > 1:
            k = k.repeat_interleave(self.repeats, dim=1)
            v = v.repeat_interleave(self.repeats, dim=1)
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=True,
        )
        attn_out = attn_out.transpose(1, 2).reshape(B, T, self.n_heads * self.head_dim)
        attn_out = self.o_proj(attn_out)
        if self.attn_lrm is not None:
            attn_out = attn_out * self.attn_lrm
        x = residual + attn_out

        # ── FFN ───────────────────────────────────────────────────────────────
        residual = x
        ffn_out = self.ffn(self.ffn_norm(x))
        if self.ffn_lrm is not None:
            ffn_out = ffn_out * self.ffn_lrm
        return residual + ffn_out
# ══════════════════════════════════════════════════════════════════════════════
# LAYER DISTRIBUTION
# ══════════════════════════════════════════════════════════════════════════════
def get_layer_types(n_layers: int, n_gqa: int) -> list[str]:
    """
    Distribute GQA layers evenly through the network.

    Uses a Bresenham-style integer accumulator, which avoids floating-point
    rounding collisions. Layer 0 is always GQA. The remaining ``n_gqa - 1``
    GQA layers are spread over layers ``1 .. n_layers - 1``; when
    ``n_gqa >= 2`` the last layer is also GQA. Every other layer is LIV.

    Example (16 layers, 6 GQA):
        GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA

    Args:
        n_layers: total number of blocks.
        n_gqa: number of GQA blocks; values >= ``n_layers`` make every block GQA.

    Returns:
        A list of length ``n_layers`` whose entries are ``"gqa"`` or ``"liv"``.

    Raises:
        ValueError: if ``n_layers`` or ``n_gqa`` is negative.
    """
    if n_layers < 0 or n_gqa < 0:
        raise ValueError(
            f"n_layers and n_gqa must be non-negative, got {n_layers} and {n_gqa}"
        )
    if n_gqa == 0:
        return ["liv"] * n_layers
    if n_gqa >= n_layers:
        return ["gqa"] * n_layers

    layer_types = ["liv"] * n_layers
    layer_types[0] = "gqa"
    remaining = n_gqa - 1   # GQA blocks still to place after layer 0
    slots = n_layers - 1    # layers available to hold them
    accumulator = 0
    # remaining < slots here, so each step places at most one block.
    # After the final step exactly `remaining` blocks have been placed.
    for i in range(1, n_layers):
        accumulator += remaining
        if accumulator >= slots:
            layer_types[i] = "gqa"
            accumulator -= slots
    return layer_types
# ══════════════════════════════════════════════════════════════════════════════
# FULL MODEL
# ══════════════════════════════════════════════════════════════════════════════
class ChessModel(nn.Module):
    """
    Liquid Chess Model (LCM).

    The model is:
      * a token embedding;
      * a stack of GQA / LIV blocks, laid out by ``get_layer_types``;
      * a final RMSNorm;
      * two vocabulary-sized linear heads. The NTP head shares its weight
        matrix with the embedding.

    Input:  token IDs (batch_size, seq_len), with seq_len <= config.max_seq_len
    Output: ntp_logits (batch_size, seq_len, vocab_size) — move generation
            top_logits (batch_size, seq_len, vocab_size) — auxiliary training only
    """

    def __init__(self, config: ChessModelConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(
            config.vocab_size, config.d_model, padding_idx=config.pad_id
        )
        layer_types = get_layer_types(config.n_layers, config.n_gqa_layers)
        self.blocks = nn.ModuleList([
            GQABlock(config) if lt == "gqa" else LIVBlock(config)
            for lt in layer_types
        ])
        self.layer_types = layer_types
        self.norm = RMSNorm(config.d_model)
        self.ntp_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.top_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: embedding and NTP head are inverse operations
        self.ntp_head.weight = self.embedding.weight
        # RoPE tables for every position up to max_seq_len, built once and
        # registered as buffers so they follow the model across devices.
        freqs_cos, freqs_sin = build_rope_cache(
            config.max_seq_len, config.head_dim, device=torch.device("cpu")
        )
        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)
        self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialise all weights, then shrink the residual output projections.

        Linear and embedding weights are drawn from N(0, 0.02) and biases are
        zeroed. Parameters named ``o_proj`` / ``down_proj`` are then redrawn
        with std ``0.02 / sqrt(2 * n_layers)``. Embedding padding rows are
        zeroed last.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # Scale down output projections to stabilize residual stream.
        # NOTE(review): LIVBlock.output_proj matches neither substring, so it
        # keeps std=0.02. Confirm that this is intended.
        for name, param in self.named_parameters():
            if "o_proj" in name or "down_proj" in name:
                nn.init.normal_(param, mean=0.0,
                                std=0.02 / math.sqrt(2 * self.config.n_layers))
        # Zero the padding rows only after all re-initialisation. ntp_head
        # shares embedding.weight and is visited after the embedding in
        # self.modules(), so zeroing inside the loop above would be overwritten.
        # (padding_idx only masks the gradient from the embedding lookup, so
        # the tied head can still update this row during training.)
        with torch.no_grad():
            for module in self.modules():
                if isinstance(module, nn.Embedding) and module.padding_idx is not None:
                    module.weight[module.padding_idx].zero_()

    def forward(
        self, token_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(ntp_logits, top_logits)`` for a ``(batch, seq_len)`` batch of token IDs."""
        _, T = token_ids.shape
        assert T <= self.config.max_seq_len, \
            f"Sequence length {T} exceeds maximum {self.config.max_seq_len}"
        x = self.embedding(token_ids)
        freqs_cos = self.freqs_cos[:T]
        freqs_sin = self.freqs_sin[:T]
        for block, lt in zip(self.blocks, self.layer_types):
            # Only GQA blocks consume the RoPE tables.
            x = block(x, freqs_cos, freqs_sin) if lt == "gqa" else block(x)
        x = self.norm(x)
        return self.ntp_head(x), self.top_head(x)

    def count_parameters(self) -> int:
        """Number of trainable scalars; the tied embedding/NTP matrix is counted once."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
if __name__ == "__main__":
    # Smoke test: build the default config, count parameters, and run one
    # forward pass. ChessModelConfig is provided by the module-level import.
    config = ChessModelConfig()
    model = ChessModel(config)
    params = model.count_parameters()
    print(f"Parameters: {params:,} ({params/1e6:.1f}M)")
    # Stay within max_seq_len so forward()'s length assert cannot fire.
    seq_len = min(255, config.max_seq_len)
    x = torch.randint(0, config.vocab_size, (2, seq_len))
    with torch.no_grad():  # shape check only; no autograd graph needed
        ntp, top = model(x)
    assert ntp.shape == (2, seq_len, config.vocab_size)
    assert top.shape == (2, seq_len, config.vocab_size)
    print(f"Forward pass: {ntp.shape} ✓")