# chess-trm-powerful / model.py
"""
TRM (Tiny Recursive Model) adapted for Causal Language Modeling (Chess).
Based on the official implementation: TinyRecursiveModels/models/recursive_reasoning/trm.py
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
class ChessConfig(PretrainedConfig):
model_type = "chess_transformer"
def __init__(
self,
vocab_size: int = 1200,
n_embd: int = 128,
n_head: int = 4,
n_ctx: int = 256,
h_cycles: int = 2, # Number of High-level reasoning cycles
l_cycles: int = 2, # Number of Low-level reasoning cycles per H-cycle
n_layers_per_block: int = 1, # Number of physical layers in the shared block
n_inner: Optional[int] = None,
        n_layer: Optional[int] = None,  # Not used directly; effective depth = h_cycles * (l_cycles + 1) * n_layers_per_block
dropout: float = 0.0, # TRM usually uses 0 dropout for reasoning
layer_norm_epsilon: float = 1e-5,
tie_weights: bool = True,
rope_theta: float = 10000.0,
        pad_token_id: int = 0,  # assumes token id 0 is the padding token
bos_token_id: int = 1,
eos_token_id: int = 2,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.n_embd = n_embd
self.n_head = n_head
self.n_ctx = n_ctx
self.h_cycles = h_cycles
self.l_cycles = l_cycles
self.n_layers_per_block = n_layers_per_block
self.n_layers = n_layer
self.n_inner = n_inner if n_inner is not None else int(n_embd * 8/3) # SwiGLU convention
self.dropout = dropout
self.layer_norm_epsilon = layer_norm_epsilon
self.tie_weights = tie_weights
self.rope_theta = rope_theta
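# A minimal usage sketch for the config (values are illustrative defaults, not
# necessarily the trained checkpoint's hyperparameters):
#
#   cfg = ChessConfig(vocab_size=1200, n_embd=128, n_head=4, n_ctx=256,
#                     h_cycles=2, l_cycles=2, n_layers_per_block=1)
#   # Effective depth: h_cycles * (l_cycles + 1) * n_layers_per_block = 6 block applications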
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
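        # y = weight * x / sqrt(mean(x^2) + eps) -- no mean-centering and no
        # bias, unlike LayerNorm, which makes it slightly cheaper.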
var = torch.mean(x**2, dim=-1, keepdim=True)
x = x * torch.rsqrt(var + self.eps)
return self.weight * x
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000.0, device=None):
super().__init__()
self.dim = dim
self.base = base
self.max_position_embeddings = max_position_embeddings
self.register_buffer("inv_freq", None, persistent=False)
self.register_buffer("cos_cached", None, persistent=False)
self.register_buffer("sin_cached", None, persistent=False)
def _update_cos_sin_tables(self, x, seq_len):
        # Rebuild the cache if it is missing, too short, or on the wrong device
        # (e.g. after the model is moved with .to(device)).
        if (self.cos_cached is None or self.cos_cached.shape[0] < seq_len
                or self.cos_cached.device != x.device):
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
t = torch.arange(max(seq_len, self.max_position_embeddings), device=x.device).float()
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()
def forward(self, x, seq_len=None):
if seq_len is None:
seq_len = x.shape[1]
self._update_cos_sin_tables(x, seq_len)
        # Cast to the activation dtype so fp16/bf16 inputs are not upcast downstream.
        return self.cos_cached[:seq_len, ...].to(x.dtype), self.sin_cached[:seq_len, ...].to(x.dtype)
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    # q, k: [batch, seq, n_heads, head_dim] (heads not yet moved to dim 1)
    # cos, sin: [seq, head_dim] -> unsqueeze to [1, seq, 1, head_dim] for broadcasting
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq, 1, dim]
sin = sin.unsqueeze(0).unsqueeze(2)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
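# A quick shape sanity check for the RoPE helpers above (illustrative only):
#
#   rot = RotaryEmbedding(dim=32, max_position_embeddings=256)
#   x = torch.zeros(2, 16, 128)                  # [batch, seq, n_embd]
#   cos, sin = rot(x, seq_len=16)                # each [16, 32]
#   q = torch.randn(2, 16, 4, 32)                # [batch, seq, n_heads, head_dim]
#   k = torch.randn(2, 16, 1, 32)                # single shared K head (MQA)
#   q, k = apply_rotary_pos_emb(q, k, cos, sin)  # shapes unchanged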
class MultiQueryAttention(nn.Module):
"""
Standard Attention with RoPE support.
Using Multi-Query (MQA) or standard MHA depending on config.
Adapted for Causal Masking.
"""
def __init__(self, config: ChessConfig):
super().__init__()
self.n_head = config.n_head
self.n_embd = config.n_embd
self.head_dim = config.n_embd // config.n_head
self.c_q = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.c_k = nn.Linear(config.n_embd, self.head_dim, bias=False)
self.c_v = nn.Linear(config.n_embd, self.head_dim, bias=False)
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x, cos, sin, attention_mask=None):
B, T, C = x.size()
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, 1, self.head_dim)
v = self.c_v(x).view(B, T, 1, self.head_dim)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
k = k.expand(-1, self.n_head, -1, -1)
v = v.expand(-1, self.n_head, -1, -1)
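        # Note: the attention_mask argument is accepted but ignored; masking is
        # purely causal via is_causal=True, so padding positions are still attended to.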
y = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=self.dropout.p if self.training else 0.0,
is_causal=True
)
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.c_proj(y)
return y
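# Shape walk-through for the attention above (using n_embd=128, n_head=4 as an
# illustrative example):
#   q: [B, T, 128] -> [B, T, 4, 32];  k, v: [B, T, 32] -> [B, T, 1, 32]
#   after transpose/expand: q is [B, 4, T, 32], k and v broadcast to [B, 4, T, 32]
#   SDPA applies the lower-triangular causal mask internally.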
class SwiGLU(nn.Module):
def __init__(self, config: ChessConfig):
super().__init__()
self.w1 = nn.Linear(config.n_embd, config.n_inner, bias=False)
self.w2 = nn.Linear(config.n_embd, config.n_inner, bias=False)
self.w3 = nn.Linear(config.n_inner, config.n_embd, bias=False)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
x1 = self.w1(x)
x2 = self.w2(x)
hidden = F.silu(x1) * x2
return self.dropout(self.w3(hidden))
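# SwiGLU MLP: out = W3(silu(W1 x) * W2 x). With n_inner ~ 8/3 * n_embd, the
# three-matrix gated MLP has the same parameter count as a 4x-expansion GELU MLP
# (3 * n_embd * 8/3 * n_embd = 8 * n_embd^2 = 2 * n_embd * 4 * n_embd).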
class TRMBlock(nn.Module):
def __init__(self, config: ChessConfig):
super().__init__()
self.self_attn = MultiQueryAttention(config)
self.mlp = SwiGLU(config)
self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
def forward(self, x, cos, sin):
attn_out = self.self_attn(x, cos, sin)
x = self.ln_1(x + attn_out)
mlp_out = self.mlp(x)
x = self.ln_2(x + mlp_out)
return x
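# Note the post-norm layout (RMSNorm after each residual add) rather than the
# pre-norm layout common in GPT-style stacks; this is consistent with the TRM
# reference implementation cited in the file header.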
class TRMReasoningModule(nn.Module):
"""
The reusable module containing shared layers.
Implements Input Injection: hidden_states = hidden_states + injection
"""
def __init__(self, config: ChessConfig):
super().__init__()
self.layers = nn.ModuleList([TRMBlock(config) for _ in range(config.n_layers_per_block)])
def forward(self, hidden_states, input_injection, cos, sin):
hidden_states = hidden_states + input_injection
for layer in self.layers:
hidden_states = layer(hidden_states, cos, sin)
return hidden_states
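# One reasoning step is z <- Block(z + injection). The same shared weights serve
# both updates: the L-update injects (z_H + x_emb), the H-update injects z_L.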
class ChessForCausalLM(PreTrainedModel):
config_class = ChessConfig
def __init__(self, config: ChessConfig):
super().__init__(config)
self.config = config
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.rotary = RotaryEmbedding(config.n_embd // config.n_head, max_position_embeddings=config.n_ctx, base=config.rope_theta)
self.reasoning_module = TRMReasoningModule(config)
self.z_H_init = nn.Parameter(torch.randn(1, 1, config.n_embd) * 0.02)
self.z_L_init = nn.Parameter(torch.randn(1, 1, config.n_embd) * 0.02)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
if config.tie_weights:
self.lm_head.weight = self.wte.weight
self.post_init()
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
input_ids: torch.LongTensor,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
B, T = input_ids.size()
x_emb = self.wte(input_ids)
cos, sin = self.rotary(x_emb, seq_len=T)
z_H = self.z_H_init.expand(B, T, -1).contiguous()
z_L = self.z_L_init.expand(B, T, -1).contiguous()
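        # TRM-style 1-step gradient: run the first (h_cycles - 1) cycles without
        # tracking gradients, then backpropagate only through the final L-loop
        # and H-step below, keeping activation memory constant in h_cycles.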
with torch.no_grad():
for _h in range(self.config.h_cycles - 1):
# L-loop (updates z_L)
for _l in range(self.config.l_cycles):
z_L = self.reasoning_module(
hidden_states=z_L,
input_injection=(z_H + x_emb),
cos=cos, sin=sin
)
# H-loop step (updates z_H)
z_H = self.reasoning_module(
hidden_states=z_H,
input_injection=z_L,
cos=cos, sin=sin
)
for _l in range(self.config.l_cycles):
z_L = self.reasoning_module(
hidden_states=z_L,
input_injection=(z_H + x_emb),
cos=cos, sin=sin
)
z_H = self.reasoning_module(
hidden_states=z_H,
input_injection=z_L,
cos=cos, sin=sin
)
logits = self.lm_head(z_H)
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=None
)
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
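# -----------------------------------------------------------------------------
# A minimal smoke test, runnable as `python model.py`. The config values below
# are illustrative defaults, not necessarily the trained checkpoint's settings.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = ChessConfig(vocab_size=1200, n_embd=128, n_head=4, n_ctx=256)
    model = ChessForCausalLM(cfg)
    ids = torch.randint(3, cfg.vocab_size, (2, 32))  # [batch, seq]; skip special ids 0-2
    out = model(input_ids=ids, labels=ids)
    print(out.logits.shape)  # expected: torch.Size([2, 32, 1200])
    print(float(out.loss))   # a finite cross-entropy value (~log(1200) at init)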