File size: 13,899 Bytes
b2c1dad | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 | """
model.py β Liquid Chess Model (LCM) architecture.
Hybrid transformer with 6 GQA attention blocks and 10 LIV convolution blocks,
distributed evenly via Bresenham algorithm. Trained with dual NTP + TOP objectives.
Architecture highlights:
- GQA (Grouped Query Attention) with RoPE positional embeddings
- LIV (Local Input-dependent Value) causal convolution blocks
- LRM (Learnable Rate Multipliers) on every block
- Weight tying between embedding and NTP head
- PyTorch SDPA for efficient attention
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import ChessModelConfig
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# SHARED COMPONENTS
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization."""
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
return (x / rms) * self.weight
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# LIV CONVOLUTION BLOCK
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
class LIVBlock(nn.Module):
"""
Local Input-dependent Value convolution block.
Each token attends to itself and its nearest neighbors (kernel_size=4)
using double gating. Efficient for capturing local sequential patterns.
Structure:
input β RMSNorm β project to 3Γ β split (B, C, x)
β B gates x β causal conv β C gates result β project back
β LRM scale β residual add
"""
def __init__(self, config: ChessModelConfig):
super().__init__()
d = config.d_model
k = config.conv_kernel_size
self.norm = RMSNorm(d)
self.input_proj = nn.Linear(d, 3 * d, bias=False)
self.conv = nn.Conv1d(
in_channels=d, out_channels=d, kernel_size=k,
padding=k - 1, groups=d, bias=False,
)
self.output_proj = nn.Linear(d, d, bias=False)
self.dropout = nn.Dropout(config.dropout)
self.lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.norm(x)
B, C, x = self.input_proj(x).chunk(3, dim=-1)
x = B * x
x = self.conv(x.transpose(1, 2))
x = x[:, :, :residual.shape[1]] # trim for causality
x = C * x.transpose(1, 2)
x = self.dropout(self.output_proj(x))
if self.lrm is not None:
x = x * self.lrm
return residual + x
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# GQA ATTENTION BLOCK
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def build_rope_cache(
seq_len: int, head_dim: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""Precompute RoPE cosine and sine tables."""
theta = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
positions = torch.arange(seq_len, device=device).float()
freqs = torch.outer(positions, theta)
return torch.cos(freqs), torch.sin(freqs)
def apply_rope(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
"""Apply rotary position embeddings to a query or key tensor."""
x1, x2 = x[..., ::2], x[..., 1::2]
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)
class SwiGLU(nn.Module):
"""SwiGLU feed-forward network."""
def __init__(self, config: ChessModelConfig):
super().__init__()
d, h = config.d_model, config.ffn_hidden_size
self.gate_proj = nn.Linear(d, h, bias=False)
self.up_proj = nn.Linear(d, h, bias=False)
self.down_proj = nn.Linear(h, d, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))
class GQABlock(nn.Module):
"""
Grouped Query Attention block with SwiGLU FFN and RoPE.
Uses PyTorch's scaled_dot_product_attention for efficiency.
"""
def __init__(self, config: ChessModelConfig):
super().__init__()
d = config.d_model
self.n_heads = config.n_heads
self.n_kv_heads = config.n_kv_heads
self.head_dim = config.head_dim
self.repeats = config.n_heads // config.n_kv_heads
self.attn_norm = RMSNorm(d)
self.ffn_norm = RMSNorm(d)
self.q_proj = nn.Linear(d, self.n_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(d, self.n_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(d, d, bias=False)
self.ffn = SwiGLU(config)
self.dropout = nn.Dropout(config.dropout)
self.attn_lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None
self.ffn_lrm = nn.Parameter(torch.ones(d)) if config.use_lrm else None
def forward(
self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
B, T, _ = x.shape
# ββ Attention βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
residual = x
x_norm = self.attn_norm(x)
q = self.q_proj(x_norm).view(B, T, self.n_heads, self.head_dim)
k = self.k_proj(x_norm).view(B, T, self.n_kv_heads, self.head_dim)
v = self.v_proj(x_norm).view(B, T, self.n_kv_heads, self.head_dim)
q = apply_rope(q.transpose(1, 2), freqs_cos, freqs_sin).transpose(1, 2)
k = apply_rope(k.transpose(1, 2), freqs_cos, freqs_sin).transpose(1, 2)
# Expand KV heads to match query heads
k = k.repeat_interleave(self.repeats, dim=2).transpose(1, 2)
v = v.repeat_interleave(self.repeats, dim=2).transpose(1, 2)
q = q.transpose(1, 2)
attn_out = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.dropout.p if self.training else 0.0,
is_causal=True,
).transpose(1, 2).reshape(B, T, -1)
attn_out = self.o_proj(attn_out)
if self.attn_lrm is not None:
attn_out = attn_out * self.attn_lrm
x = residual + attn_out
# ββ FFN βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
residual = x
ffn_out = self.ffn(self.ffn_norm(x))
if self.ffn_lrm is not None:
ffn_out = ffn_out * self.ffn_lrm
return residual + ffn_out
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# LAYER DISTRIBUTION
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def get_layer_types(n_layers: int, n_gqa: int) -> list[str]:
"""
Distribute GQA layers evenly through the network using a Bresenham-style
integer accumulator. Avoids floating-point rounding collisions.
Always places a GQA block first.
Example (16 layers, 6 GQA):
GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA LIV LIV GQA
"""
if n_gqa == 0:
return ["liv"] * n_layers
if n_gqa >= n_layers:
return ["gqa"] * n_layers
layer_types = ["liv"] * n_layers
layer_types[0] = "gqa"
gqa_placed = 1
remaining = n_gqa - 1
slots = n_layers - 1
accumulator = 0
for i in range(1, n_layers):
accumulator += remaining
if accumulator >= slots:
layer_types[i] = "gqa"
accumulator -= slots
gqa_placed += 1
if gqa_placed == n_gqa:
break
return layer_types
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# FULL MODEL
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
class ChessModel(nn.Module):
"""
Liquid Chess Model (LCM).
Input: token IDs (batch_size, seq_len)
Output: ntp_logits (batch_size, seq_len, vocab_size) β move generation
top_logits (batch_size, seq_len, vocab_size) β auxiliary training only
"""
def __init__(self, config: ChessModelConfig):
super().__init__()
self.config = config
self.embedding = nn.Embedding(
config.vocab_size, config.d_model, padding_idx=config.pad_id
)
layer_types = get_layer_types(config.n_layers, config.n_gqa_layers)
self.blocks = nn.ModuleList([
GQABlock(config) if lt == "gqa" else LIVBlock(config)
for lt in layer_types
])
self.layer_types = layer_types
self.norm = RMSNorm(config.d_model)
self.ntp_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.top_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
# Weight tying: embedding and NTP head are inverse operations
self.ntp_head.weight = self.embedding.weight
freqs_cos, freqs_sin = build_rope_cache(
config.max_seq_len, config.head_dim, device=torch.device("cpu")
)
self.register_buffer("freqs_cos", freqs_cos)
self.register_buffer("freqs_sin", freqs_sin)
self._init_weights()
def _init_weights(self):
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
# Scale down output projections to stabilize residual stream
for name, param in self.named_parameters():
if "o_proj" in name or "down_proj" in name:
nn.init.normal_(param, mean=0.0,
std=0.02 / math.sqrt(2 * self.config.n_layers))
def forward(
self, token_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
B, T = token_ids.shape
assert T <= self.config.max_seq_len, \
f"Sequence length {T} exceeds maximum {self.config.max_seq_len}"
x = self.embedding(token_ids)
freqs_cos = self.freqs_cos[:T]
freqs_sin = self.freqs_sin[:T]
for block, lt in zip(self.blocks, self.layer_types):
x = block(x, freqs_cos, freqs_sin) if lt == "gqa" else block(x)
x = self.norm(x)
return self.ntp_head(x), self.top_head(x)
def count_parameters(self) -> int:
return sum(p.numel() for p in self.parameters() if p.requires_grad)
if __name__ == "__main__":
from model.config import ChessModelConfig
config = ChessModelConfig()
model = ChessModel(config)
params = model.count_parameters()
print(f"Parameters: {params:,} ({params/1e6:.1f}M)")
x = torch.randint(0, config.vocab_size, (2, 255))
ntp, top = model(x)
assert ntp.shape == (2, 255, config.vocab_size)
assert top.shape == (2, 255, config.vocab_size)
print(f"Forward pass: {ntp.shape} β") |