# """ # SOTA Chess Transformer (Llama/DeepSeek Style) # Updated for the 1M Parameter Challenge. # Improvements over baseline: # 1. RoPE (Rotary Positional Embeddings) - Saves ~32k params, better context. # 2. RMSNorm - More stable than LayerNorm. # 3. SwiGLU - Better activation function for reasoning. # 4. QK-Norm - (From OLMo 2) Stabilizes attention. # """ # from __future__ import annotations # import math # from dataclasses import dataclass # from typing import Optional, Tuple, Union # import torch # import torch.nn as nn # import torch.nn.functional as F # from transformers import PretrainedConfig, PreTrainedModel # from transformers.modeling_outputs import CausalLMOutputWithPast # from transformers import LogitsProcessor, LogitsProcessorList # from transformers.generation import GenerationMixin # class ChessConfig(PretrainedConfig): # model_type = "chess_transformer" # def __init__( # self, # vocab_size: int = 1200, # n_embd: int = 128, # n_layer: int = 8, # Increased default depth since RoPE saves params # n_head: int = 4, # n_ctx: int = 256, # n_inner: Optional[int] = None, # dropout: float = 0.0, # Modern LLMs often use 0 dropout # rms_norm_eps: float = 1e-6, # tie_weights: bool = True, # pad_token_id: int = 0, # bos_token_id: int = 1, # eos_token_id: int = 2, # **kwargs, # ): # super().__init__( # pad_token_id=pad_token_id, # bos_token_id=bos_token_id, # eos_token_id=eos_token_id, # is_decoder=True, # **kwargs, # ) # self.vocab_size = vocab_size # self.n_embd = n_embd # self.n_layer = n_layer # self.n_head = n_head # # Mapping for Hugging Face compatibility # self.num_hidden_layers = n_layer # self.hidden_size = n_embd # self.num_attention_heads = n_head # self.n_ctx = n_ctx # # SwiGLU needs a different inner dimension to match parameter count. # # Usually 2/3 of 4d, but we can tune this. # self.n_inner = n_inner if n_inner is not None else int(8/3 * n_embd) # self.dropout = dropout # self.rms_norm_eps = rms_norm_eps # self.tie_weights = tie_weights # self.tie_word_embeddings = bool(tie_weights) # class RMSNorm(nn.Module): # """Root Mean Square Layer Normalization (Llama style).""" # def __init__(self, dim: int, eps: float = 1e-6): # super().__init__() # self.eps = eps # self.weight = nn.Parameter(torch.ones(dim)) # def _norm(self, x): # return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) # def forward(self, x): # output = self._norm(x.float()).type_as(x) # return output * self.weight # def apply_rotary_pos_emb(q, k, cos, sin): # """Apply Rotary Positional Embeddings (RoPE).""" # # Reshape cos/sin to match q/k: [batch, 1, seq_len, head_dim] # # Note: This is a simplified implementation for the challenge # cos = cos.unsqueeze(1) # sin = sin.unsqueeze(1) # q_embed = (q * cos) + (rotate_half(q) * sin) # k_embed = (k * cos) + (rotate_half(k) * sin) # return q_embed, k_embed # def rotate_half(x): # """Rotates half the hidden dims of the input.""" # x1 = x[..., : x.shape[-1] // 2] # x2 = x[..., x.shape[-1] // 2 :] # return torch.cat((-x2, x1), dim=-1) # class SOTAMultiHeadAttention(nn.Module): # def __init__(self, config: ChessConfig): # super().__init__() # self.n_head = config.n_head # self.n_embd = config.n_embd # self.head_dim = config.n_embd // config.n_head # # QKV Projections # self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # # QK-Norm (from OLMo 2) - Stabilizes training # self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) # self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) # # RoPE cache # self.register_buffer("inv_freq", 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)), persistent=False) # def get_rope_embeddings(self, seq_len, device): # t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) # freqs = torch.einsum("i,j->ij", t, self.inv_freq) # emb = torch.cat((freqs, freqs), dim=-1) # return emb.cos(), emb.sin() # def forward(self, x, attention_mask=None): # batch_size, seq_len, _ = x.size() # # 1. Project # q = self.q_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim) # k = self.k_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim) # v = self.v_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim) # # 2. QK-Norm (OLMo style) - Normalize BEFORE RoPE # q = self.q_norm(q) # k = self.k_norm(k) # # 3. Apply RoPE # # Transpose to [batch, head, seq, dim] for easier math # q = q.transpose(1, 2) # k = k.transpose(1, 2) # v = v.transpose(1, 2) # cos, sin = self.get_rope_embeddings(seq_len, x.device) # # Match dimensions for broadcasting # cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, seq, dim] # sin = sin.unsqueeze(0).unsqueeze(0) # q = (q * cos) + (rotate_half(q) * sin) # k = (k * cos) + (rotate_half(k) * sin) # # 4. Attention # # Efficient Flash Attention if available (or standard) # attn_output = F.scaled_dot_product_attention( # q, k, v, # attn_mask=None, # dropout_p=0.0, # is_causal=True # ) # # 5. Output Projection # attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_embd) # return self.o_proj(attn_output) # class SwiGLUFeedForward(nn.Module): # """SwiGLU FFN (Llama/DeepSeek style).""" # def __init__(self, config: ChessConfig): # super().__init__() # # SwiGLU has 3 projections: Gate, Value, Output # self.gate_proj = nn.Linear(config.n_embd, config.n_inner, bias=False) # self.up_proj = nn.Linear(config.n_embd, config.n_inner, bias=False) # self.down_proj = nn.Linear(config.n_inner, config.n_embd, bias=False) # def forward(self, x): # # SwiGLU: (Swish(Gate) * Up) -> Down # return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) # class SOTATransformerBlock(nn.Module): # def __init__(self, config: ChessConfig): # super().__init__() # self.input_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps) # self.self_attn = SOTAMultiHeadAttention(config) # self.post_attention_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps) # self.mlp = SwiGLUFeedForward(config) # def forward(self, x, attention_mask=None): # # Pre-norm architecture # x = x + self.self_attn(self.input_layernorm(x), attention_mask) # x = x + self.mlp(self.post_attention_layernorm(x)) # return x # class FourStepConsistency(LogitsProcessor): # """ # Enforces the 4-step rhythm: [Piece] -> [From] -> [To] -> [Suffix] # """ # def __init__(self, tokenizer, start_len): # self.tokenizer = tokenizer # self.start_len = start_len # all_ids = set(range(tokenizer.vocab_size)) # # 1. Piece IDs # self.piece_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.PIECES if t in tokenizer.get_vocab()} # # 2. Square IDs (Used for both From and To) # self.square_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SQUARES if t in tokenizer.get_vocab()} # # 3. Suffix IDs # self.suffix_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SUFFIXES if t in tokenizer.get_vocab()} # def __call__(self, input_ids, scores): # cur_len = input_ids.shape[1] # relative_pos = (cur_len - self.start_len) % 4 # mask_ids = set() # if relative_pos == 0: # Step 1: Piece # mask_ids = self.piece_ids # elif relative_pos == 1: # Step 2: From Square # mask_ids = self.square_ids # elif relative_pos == 2: # Step 3: To Square # mask_ids = self.square_ids # else: # Step 4: Suffix # mask_ids = self.suffix_ids # # Mask out disallowed tokens # for i in range(scores.shape[1]): # if i not in mask_ids and i != self.tokenizer.eos_token_id: # scores[:, i] = float("-inf") # return scores # class ChessForCausalLM(PreTrainedModel, GenerationMixin): # config_class = ChessConfig # def __init__(self, config: ChessConfig): # super().__init__(config) # # 1. Embeddings (No Position Embeddings needed, RoPE handles it!) # self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd) # # 2. Layers # self.layers = nn.ModuleList([ # SOTATransformerBlock(config) for _ in range(config.n_layer) # ]) # # 3. Final Norm # self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps) # # 4. Head # self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # # Tie weights # if config.tie_weights: # self.lm_head.weight = self.embed_tokens.weight # self._tied_weights_keys = ["lm_head.weight"] # self.post_init() # def get_input_embeddings(self): # return self.embed_tokens # def set_input_embeddings(self, value): # self.embed_tokens = value # def get_output_embeddings(self): # return self.lm_head # def set_output_embeddings(self, new_embeddings): # self.lm_head = new_embeddings # def forward( # self, # input_ids: torch.LongTensor = None, # attention_mask: Optional[torch.Tensor] = None, # labels: Optional[torch.LongTensor] = None, # return_dict: Optional[bool] = None, # **kwargs, # ) -> Union[Tuple, CausalLMOutputWithPast]: # batch_size, seq_len = input_ids.shape # hidden_states = self.embed_tokens(input_ids) # for layer in self.layers: # hidden_states = layer(hidden_states, attention_mask) # hidden_states = self.norm(hidden_states) # logits = self.lm_head(hidden_states) # loss = None # if labels is not None: # shift_logits = logits[..., :-1, :].contiguous() # shift_labels = labels[..., 1:].contiguous() # loss_fct = nn.CrossEntropyLoss(ignore_index=-100) # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # return_dict = return_dict if return_dict is not None else self.config.use_return_dict # if not return_dict: # output = (logits,) # return ((loss,) + output) if loss is not None else output # return CausalLMOutputWithPast( # loss=loss, # logits=logits, # past_key_values=None, # hidden_states=None, # attentions=None, # ) # def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): # # 1. Handle Cache (Past Key Values) # # If we have a cache, we only need to process the very last token we generated # if past_key_values: # input_ids = input_ids[:, -1:] # # 2. Handle Position IDs # # If the user didn't provide position_ids, we might need to create them from the attention_mask # position_ids = kwargs.get("position_ids", None) # # FIX: Explicitly check 'is not None' to avoid the ambiguous Tensor error # attention_mask = kwargs.get("attention_mask", None) # if attention_mask is not None: # # Create position_ids based on the mask (0, 1, 2... ignoring padding) # if position_ids is None: # position_ids = attention_mask.long().cumsum(-1) - 1 # position_ids.masked_fill_(attention_mask == 0, 1) # # If using cache, we only need the position ID for the last token # if past_key_values: # position_ids = position_ids[:, -1].unsqueeze(-1) # return { # "input_ids": input_ids, # "past_key_values": past_key_values, # "use_cache": kwargs.get("use_cache"), # "position_ids": position_ids, # "attention_mask": attention_mask, # } # def generate(self, input_ids, **kwargs): # tokenizer = kwargs.pop("tokenizer", None) # if tokenizer is not None: # # Use the 4-step synthesizer # synthesizer = FourStepConsistency(tokenizer, input_ids.shape[1]) # logits_processor = kwargs.get("logits_processor", LogitsProcessorList()) # logits_processor.append(synthesizer) # kwargs["logits_processor"] = logits_processor # # Call GenerationMixin directly to bypass any PreTrainedModel ambiguity # return GenerationMixin.generate(self, input_ids, **kwargs) # # Register # from transformers import AutoConfig, AutoModelForCausalLM # AutoConfig.register("chess_transformer", ChessConfig) # AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM) """ SOTA Chess Transformer (Llama/DeepSeek Style) Updated for the 1M Parameter Challenge. """ from __future__ import annotations import math from dataclasses import dataclass from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import PretrainedConfig, PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast from transformers import LogitsProcessor, LogitsProcessorList from transformers.generation import GenerationMixin class ChessConfig(PretrainedConfig): model_type = "chess_transformer" def __init__( self, vocab_size: int = 1200, n_embd: int = 128, n_layer: int = 8, n_head: int = 4, n_ctx: int = 256, n_inner: Optional[int] = None, dropout: float = 0.0, rms_norm_eps: float = 1e-6, tie_weights: bool = True, pad_token_id: int = 0, bos_token_id: int = 1, eos_token_id: int = 2, **kwargs, ): super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, is_decoder=True, **kwargs, ) self.vocab_size = vocab_size self.n_embd = n_embd self.n_layer = n_layer self.n_head = n_head self.n_ctx = n_ctx self.n_inner = n_inner if n_inner is not None else int(8/3 * n_embd) self.dropout = dropout self.rms_norm_eps = rms_norm_eps self.tie_weights = tie_weights self.tie_word_embeddings = bool(tie_weights) self.num_hidden_layers = n_layer self.hidden_size = n_embd self.num_attention_heads = n_head class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()).type_as(x) return output * self.weight def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) class SOTAMultiHeadAttention(nn.Module): def __init__(self, config: ChessConfig): super().__init__() self.n_head = config.n_head self.n_embd = config.n_embd self.head_dim = config.n_embd // config.n_head self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.register_buffer("inv_freq", 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)), persistent=False) def get_rope_embeddings(self, position_ids, device): # FIX: Use explicit position_ids instead of arange(seq_len) # position_ids: [batch, seq_len] inv_freq = self.inv_freq.to(device) # Outer product: [batch, seq_len, head_dim/2] # We need to flatten batch/seq to simplify, or use broadcasting # freqs = (pos * freq) # position_ids is [batch, seq], inv_freq is [dim] # Output should be [batch, seq, dim] freqs = torch.einsum("bs,d->bsd", position_ids.float(), inv_freq) emb = torch.cat((freqs, freqs), dim=-1) return emb.cos(), emb.sin() def forward(self, x, attention_mask=None, position_ids=None): batch_size, seq_len, _ = x.size() q = self.q_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim) k = self.k_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim) v = self.v_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim) q = self.q_norm(q) k = self.k_norm(k) # Transpose for RoPE [batch, head, seq, dim] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) if position_ids is None: position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1) cos, sin = self.get_rope_embeddings(position_ids, x.device) cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) q = (q * cos) + (rotate_half(q) * sin) k = (k * cos) + (rotate_half(k) * sin) # --- FIX: Resolve Conflict between attn_mask and is_causal --- if attention_mask is not None: # 1. Expand to 4D for broadcasting if needed if attention_mask.dim() == 2: attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # 2. Convert Int to Bool if needed (PyTorch SDPA prefers this) if attention_mask.dtype in [torch.long, torch.int64, torch.int32]: attention_mask = (attention_mask == 0) # True for masked, False for keep? # Wait, usually 1=Keep, 0=Mask. # If using bool mask in SDPA: True = Masked Out (Ignore). # So if input is 1 (Keep), we want False (Don't Mask). # If input is 0 (Pad), we want True (Mask). # So (mask == 0) gives us True for Padding. Correct. # 3. CRITICAL: If the mask is "Empty" (all False = keep everything), # drop it so we can use is_causal=True without error. # (Note: In boolean mask, 'False' means 'Keep') if not attention_mask.any(): attention_mask = None # ------------------------------------------------------------- attn_output = F.scaled_dot_product_attention( q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=True ) attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_embd) return self.o_proj(attn_output) class SwiGLUFeedForward(nn.Module): def __init__(self, config: ChessConfig): super().__init__() self.gate_proj = nn.Linear(config.n_embd, config.n_inner, bias=False) self.up_proj = nn.Linear(config.n_embd, config.n_inner, bias=False) self.down_proj = nn.Linear(config.n_inner, config.n_embd, bias=False) def forward(self, x): return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) class SOTATransformerBlock(nn.Module): def __init__(self, config: ChessConfig): super().__init__() self.input_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps) self.self_attn = SOTAMultiHeadAttention(config) self.post_attention_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps) self.mlp = SwiGLUFeedForward(config) def forward(self, x, attention_mask=None, position_ids=None): # FIX: Pass position_ids down x = x + self.self_attn(self.input_layernorm(x), attention_mask, position_ids) x = x + self.mlp(self.post_attention_layernorm(x)) return x class FourStepConsistency(LogitsProcessor): def __init__(self, tokenizer, start_len): self.tokenizer = tokenizer self.start_len = start_len self.piece_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.PIECES if t in tokenizer.get_vocab()} self.square_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SQUARES if t in tokenizer.get_vocab()} self.suffix_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SUFFIXES if t in tokenizer.get_vocab()} def __call__(self, input_ids, scores): cur_len = input_ids.shape[1] relative_pos = (cur_len - self.start_len) % 4 mask_ids = set() if relative_pos == 0: mask_ids = self.piece_ids elif relative_pos == 1: mask_ids = self.square_ids elif relative_pos == 2: mask_ids = self.square_ids else: mask_ids = self.suffix_ids for i in range(scores.shape[1]): if i not in mask_ids and i != self.tokenizer.eos_token_id: scores[:, i] = float("-inf") return scores class ChessForCausalLM(PreTrainedModel, GenerationMixin): config_class = ChessConfig def __init__(self, config: ChessConfig): super().__init__(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd) self.layers = nn.ModuleList([SOTATransformerBlock(config) for _ in range(config.n_layer)]) self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) if config.tie_weights: self.lm_head.weight = self.embed_tokens.weight self._tied_weights_keys = ["lm_head.weight"] self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def forward(self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, **kwargs): batch_size, seq_len = input_ids.shape hidden_states = self.embed_tokens(input_ids) # FIX: Ensure position_ids exist if position_ids is None: position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1) # FIX: Pass position_ids to layers for layer in self.layers: hidden_states = layer(hidden_states, attention_mask, position_ids) hidden_states = self.norm(hidden_states) logits = self.lm_head(hidden_states) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss(ignore_index=-100) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if not return_dict: output = (logits,) return ((loss,) + output) if loss is not None else output return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): # FORCE NO CACHE: Always process the full sequence. # This matches our SOTAMultiHeadAttention which handles the full history every time. position_ids = kwargs.get("position_ids", None) attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: # Create position_ids from the mask position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) return { "input_ids": input_ids, # Return FULL input_ids (do not slice) "past_key_values": None, # Force None so the model doesn't expect cache "use_cache": False, # Explicitly disable cache flag "position_ids": position_ids, "attention_mask": attention_mask, } def generate(self, input_ids, **kwargs): tokenizer = kwargs.pop("tokenizer", None) if tokenizer is not None: synthesizer = FourStepConsistency(tokenizer, input_ids.shape[1]) logits_processor = kwargs.get("logits_processor", LogitsProcessorList()) logits_processor.append(synthesizer) kwargs["logits_processor"] = logits_processor return GenerationMixin.generate(self, input_ids, **kwargs) # Register from transformers import AutoConfig, AutoModelForCausalLM AutoConfig.register("chess_transformer", ChessConfig) AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)