# chessgpt / modeling_chessgpt.py
# Converted to HuggingFace format (safetensors + trust_remote_code).
"""
ChessGPT -- LLaMA-style decoder-only transformer for UCI move prediction.
Architecture: RMSNorm, RoPE, SwiGLU, QK-Norm, no bias, scaled residual init.
HuggingFace-compatible implementation.
"""
from __future__ import annotations
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_chessgpt import ChessGPTConfig
# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide each feature vector by its RMS, then apply the learned gain.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
def precompute_rope_freqs(
    head_dim: int, max_seq_len: int, theta: float = 10000.0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Precompute rotary-embedding angle tables.

    Returns (cos, sin), each of shape (max_seq_len, head_dim // 2),
    where entry [t, i] is the angle t * theta^(-2i / head_dim).
    """
    # Inverse frequency per channel pair: 1 / theta^(2i / head_dim).
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len)
    # Outer product positions x inv_freq via broadcasting.
    angles = positions[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
):
    """Rotate query/key channel pairs by position-dependent angles (RoPE).

    xq, xk: (B, n_heads, T, head_dim); freqs_*: (>= T, head_dim // 2).
    Rotation is done in float32 and cast back to the inputs' dtypes.
    """
    seq_len = xq.shape[2]
    # Slice the tables to the current length; prepend broadcast dims
    # for batch and heads -> (1, 1, T, head_dim // 2).
    cos = freqs_cos[:seq_len][None, None, :, :]
    sin = freqs_sin[:seq_len][None, None, :, :]

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # View the last dim as head_dim // 2 (even, odd) pairs.
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        even, odd = pairs[..., 0], pairs[..., 1]
        rotated = torch.stack(
            (even * cos - odd * sin, even * sin + odd * cos), dim=-1
        )
        return rotated.flatten(-2).type_as(x)

    return _rotate(xq), _rotate(xk)
class SwiGLU(nn.Module):
    """Gated feed-forward: w2(silu(w1 x) * w3 x); all projections bias-free."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.silu(self.w1(x))
        gate = self.w3(x)
        return self.w2(hidden * gate)
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with RoPE and QK-Norm, via SDPA."""

    def __init__(self, config: ChessGPTConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        # Fused projection producing q, k and v in a single matmul.
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model, bias=False)
        # Per-head RMS norms applied to q/k before the rotary embedding.
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.proj = nn.Linear(config.d_model, config.d_model, bias=False)

    def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
        batch, seq_len, d_model = x.shape
        # (B, T, 3, nh, hd) -> three tensors, each moved to (B, nh, T, hd).
        projected = self.qkv(x).view(batch, seq_len, 3, self.n_heads, self.head_dim)
        q, k, v = (t.transpose(1, 2) for t in projected.unbind(dim=2))
        # QK-Norm first, then rotate; order matters and matches training.
        q, k = apply_rotary_emb(self.q_norm(q), self.k_norm(k), freqs_cos, freqs_sin)
        # SDPA dispatches to a fused (flash-attention) kernel when available.
        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
        )
        merged = attn_out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.proj(merged)
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: attention sublayer, then SwiGLU feed-forward."""

    def __init__(self, config: ChessGPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.attn = CausalSelfAttention(config)
        self.ln2 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.ffn = SwiGLU(config.d_model, config.d_ff)

    def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
        # Residual connection around each pre-normalized sublayer.
        h = x + self.attn(self.ln1(x), freqs_cos, freqs_sin)
        return h + self.ffn(self.ln2(h))
# ---------------------------------------------------------------------------
# HuggingFace-compatible model classes
# ---------------------------------------------------------------------------
class ChessGPTPreTrainedModel(PreTrainedModel):
    """Base class wiring ChessGPT modules into the HuggingFace ecosystem."""

    config_class = ChessGPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TransformerBlock"]

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, std); zero biases."""
        std = self.config.weight_init_std
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
class ChessGPTModel(ChessGPTPreTrainedModel):
    """The bare ChessGPT transformer outputting raw hidden-states."""

    def __init__(self, config: ChessGPTConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        # RoPE tables are deterministic in the config; they are persistent
        # buffers, matching the serialized checkpoint's state_dict layout.
        head_dim = config.d_model // config.n_heads
        freqs_cos, freqs_sin = precompute_rope_freqs(
            head_dim, config.max_seq_len, config.rope_theta
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=True)
        self.register_buffer("freqs_sin", freqs_sin, persistent=True)
        self.blocks = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layers)
        )
        self.ln_f = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Embed tokens and run all decoder blocks; returns (B, T, d_model).

        NOTE(review): attention_mask is accepted but never applied — attention
        is purely causal, so right-padding is not masked out. Confirm callers
        only pass unpadded batches.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} > max_seq_len {self.config.max_seq_len}"
            )
        hidden = self.embed_tokens(input_ids)
        use_ckpt = self.gradient_checkpointing and self.training
        for block in self.blocks:
            if use_ckpt:
                # Recompute activations in backward: memory for compute.
                hidden = torch_checkpoint(
                    block, hidden, self.freqs_cos, self.freqs_sin,
                    use_reentrant=False,
                )
            else:
                hidden = block(hidden, self.freqs_cos, self.freqs_sin)
        return self.ln_f(hidden)
class ChessGPTForCausalLM(ChessGPTPreTrainedModel):
    """ChessGPT with a tied language-modeling head over the UCI vocabulary."""

    # Maps the tied head weight to its source embedding for HF weight tying.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: ChessGPTConfig):
        super().__init__(config)
        self.model = ChessGPTModel(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute LM logits and, when labels are given, the shifted CE loss.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: forwarded to the backbone (currently unused there).
            labels: (B, T) targets; positions matching the ignore index are
                excluded from the loss.
        """
        hidden_states = self.model(input_ids, attention_mask=attention_mask)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Standard causal shift: position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # BUGFIX: the original passed ignore_index=self.config.pad_token_id
            # unconditionally; when pad_token_id is None (no padding token),
            # F.cross_entropy raises a TypeError. Fall back to the HF-standard
            # -100 sentinel in that case.
            pad_id = self.config.pad_token_id
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=pad_id if pad_id is not None else -100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)