| # """ | |
| # SOTA Chess Transformer (Llama/DeepSeek Style) | |
| # Updated for the 1M Parameter Challenge. | |
| # Improvements over baseline: | |
| # 1. RoPE (Rotary Positional Embeddings) - Saves ~32k params, better context. | |
| # 2. RMSNorm - More stable than LayerNorm. | |
| # 3. SwiGLU - Better activation function for reasoning. | |
| # 4. QK-Norm - (From OLMo 2) Stabilizes attention. | |
| # """ | |
| # from __future__ import annotations | |
| # import math | |
| # from dataclasses import dataclass | |
| # from typing import Optional, Tuple, Union | |
| # import torch | |
| # import torch.nn as nn | |
| # import torch.nn.functional as F | |
| # from transformers import PretrainedConfig, PreTrainedModel | |
| # from transformers.modeling_outputs import CausalLMOutputWithPast | |
| # from transformers import LogitsProcessor, LogitsProcessorList | |
| # from transformers.generation import GenerationMixin | |
| # class ChessConfig(PretrainedConfig): | |
| # model_type = "chess_transformer" | |
| # def __init__( | |
| # self, | |
| # vocab_size: int = 1200, | |
| # n_embd: int = 128, | |
| # n_layer: int = 8, # Increased default depth since RoPE saves params | |
| # n_head: int = 4, | |
| # n_ctx: int = 256, | |
| # n_inner: Optional[int] = None, | |
| # dropout: float = 0.0, # Modern LLMs often use 0 dropout | |
| # rms_norm_eps: float = 1e-6, | |
| # tie_weights: bool = True, | |
| # pad_token_id: int = 0, | |
| # bos_token_id: int = 1, | |
| # eos_token_id: int = 2, | |
| # **kwargs, | |
| # ): | |
| # super().__init__( | |
| # pad_token_id=pad_token_id, | |
| # bos_token_id=bos_token_id, | |
| # eos_token_id=eos_token_id, | |
| # is_decoder=True, | |
| # **kwargs, | |
| # ) | |
| # self.vocab_size = vocab_size | |
| # self.n_embd = n_embd | |
| # self.n_layer = n_layer | |
| # self.n_head = n_head | |
| # # Mapping for Hugging Face compatibility | |
| # self.num_hidden_layers = n_layer | |
| # self.hidden_size = n_embd | |
| # self.num_attention_heads = n_head | |
| # self.n_ctx = n_ctx | |
| # # SwiGLU needs a different inner dimension to match parameter count. | |
| # # Usually 2/3 of 4d, but we can tune this. | |
| # self.n_inner = n_inner if n_inner is not None else int(8/3 * n_embd) | |
| # self.dropout = dropout | |
| # self.rms_norm_eps = rms_norm_eps | |
| # self.tie_weights = tie_weights | |
| # self.tie_word_embeddings = bool(tie_weights) | |
| # class RMSNorm(nn.Module): | |
| # """Root Mean Square Layer Normalization (Llama style).""" | |
| # def __init__(self, dim: int, eps: float = 1e-6): | |
| # super().__init__() | |
| # self.eps = eps | |
| # self.weight = nn.Parameter(torch.ones(dim)) | |
| # def _norm(self, x): | |
| # return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| # def forward(self, x): | |
| # output = self._norm(x.float()).type_as(x) | |
| # return output * self.weight | |
| # def apply_rotary_pos_emb(q, k, cos, sin): | |
| # """Apply Rotary Positional Embeddings (RoPE).""" | |
| # # Reshape cos/sin to match q/k: [batch, 1, seq_len, head_dim] | |
| # # Note: This is a simplified implementation for the challenge | |
| # cos = cos.unsqueeze(1) | |
| # sin = sin.unsqueeze(1) | |
| # q_embed = (q * cos) + (rotate_half(q) * sin) | |
| # k_embed = (k * cos) + (rotate_half(k) * sin) | |
| # return q_embed, k_embed | |
| # def rotate_half(x): | |
| # """Rotates half the hidden dims of the input.""" | |
| # x1 = x[..., : x.shape[-1] // 2] | |
| # x2 = x[..., x.shape[-1] // 2 :] | |
| # return torch.cat((-x2, x1), dim=-1) | |
| # class SOTAMultiHeadAttention(nn.Module): | |
| # def __init__(self, config: ChessConfig): | |
| # super().__init__() | |
| # self.n_head = config.n_head | |
| # self.n_embd = config.n_embd | |
| # self.head_dim = config.n_embd // config.n_head | |
| # # QKV Projections | |
| # self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) | |
| # self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) | |
| # self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) | |
| # self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) | |
| # # QK-Norm (from OLMo 2) - Stabilizes training | |
| # self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) | |
| # self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) | |
| # # RoPE cache | |
| # self.register_buffer("inv_freq", 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)), persistent=False) | |
| # def get_rope_embeddings(self, seq_len, device): | |
| # t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) | |
| # freqs = torch.einsum("i,j->ij", t, self.inv_freq) | |
| # emb = torch.cat((freqs, freqs), dim=-1) | |
| # return emb.cos(), emb.sin() | |
| # def forward(self, x, attention_mask=None): | |
| # batch_size, seq_len, _ = x.size() | |
| # # 1. Project | |
| # q = self.q_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim) | |
| # k = self.k_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim) | |
| # v = self.v_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim) | |
| # # 2. QK-Norm (OLMo style) - Normalize BEFORE RoPE | |
| # q = self.q_norm(q) | |
| # k = self.k_norm(k) | |
| # # 3. Apply RoPE | |
| # # Transpose to [batch, head, seq, dim] for easier math | |
| # q = q.transpose(1, 2) | |
| # k = k.transpose(1, 2) | |
| # v = v.transpose(1, 2) | |
| # cos, sin = self.get_rope_embeddings(seq_len, x.device) | |
| # # Match dimensions for broadcasting | |
| # cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, seq, dim] | |
| # sin = sin.unsqueeze(0).unsqueeze(0) | |
| # q = (q * cos) + (rotate_half(q) * sin) | |
| # k = (k * cos) + (rotate_half(k) * sin) | |
| # # 4. Attention | |
| # # Efficient Flash Attention if available (or standard) | |
| # attn_output = F.scaled_dot_product_attention( | |
| # q, k, v, | |
| # attn_mask=None, | |
| # dropout_p=0.0, | |
| # is_causal=True | |
| # ) | |
| # # 5. Output Projection | |
| # attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_embd) | |
| # return self.o_proj(attn_output) | |
| # class SwiGLUFeedForward(nn.Module): | |
| # """SwiGLU FFN (Llama/DeepSeek style).""" | |
| # def __init__(self, config: ChessConfig): | |
| # super().__init__() | |
| # # SwiGLU has 3 projections: Gate, Value, Output | |
| # self.gate_proj = nn.Linear(config.n_embd, config.n_inner, bias=False) | |
| # self.up_proj = nn.Linear(config.n_embd, config.n_inner, bias=False) | |
| # self.down_proj = nn.Linear(config.n_inner, config.n_embd, bias=False) | |
| # def forward(self, x): | |
| # # SwiGLU: (Swish(Gate) * Up) -> Down | |
| # return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) | |
| # class SOTATransformerBlock(nn.Module): | |
| # def __init__(self, config: ChessConfig): | |
| # super().__init__() | |
| # self.input_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps) | |
| # self.self_attn = SOTAMultiHeadAttention(config) | |
| # self.post_attention_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps) | |
| # self.mlp = SwiGLUFeedForward(config) | |
| # def forward(self, x, attention_mask=None): | |
| # # Pre-norm architecture | |
| # x = x + self.self_attn(self.input_layernorm(x), attention_mask) | |
| # x = x + self.mlp(self.post_attention_layernorm(x)) | |
| # return x | |
| # class FourStepConsistency(LogitsProcessor): | |
| # """ | |
| # Enforces the 4-step rhythm: [Piece] -> [From] -> [To] -> [Suffix] | |
| # """ | |
| # def __init__(self, tokenizer, start_len): | |
| # self.tokenizer = tokenizer | |
| # self.start_len = start_len | |
| # all_ids = set(range(tokenizer.vocab_size)) | |
| # # 1. Piece IDs | |
| # self.piece_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.PIECES if t in tokenizer.get_vocab()} | |
| # # 2. Square IDs (Used for both From and To) | |
| # self.square_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SQUARES if t in tokenizer.get_vocab()} | |
| # # 3. Suffix IDs | |
| # self.suffix_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SUFFIXES if t in tokenizer.get_vocab()} | |
| # def __call__(self, input_ids, scores): | |
| # cur_len = input_ids.shape[1] | |
| # relative_pos = (cur_len - self.start_len) % 4 | |
| # mask_ids = set() | |
| # if relative_pos == 0: # Step 1: Piece | |
| # mask_ids = self.piece_ids | |
| # elif relative_pos == 1: # Step 2: From Square | |
| # mask_ids = self.square_ids | |
| # elif relative_pos == 2: # Step 3: To Square | |
| # mask_ids = self.square_ids | |
| # else: # Step 4: Suffix | |
| # mask_ids = self.suffix_ids | |
| # # Mask out disallowed tokens | |
| # for i in range(scores.shape[1]): | |
| # if i not in mask_ids and i != self.tokenizer.eos_token_id: | |
| # scores[:, i] = float("-inf") | |
| # return scores | |
| # class ChessForCausalLM(PreTrainedModel, GenerationMixin): | |
| # config_class = ChessConfig | |
| # def __init__(self, config: ChessConfig): | |
| # super().__init__(config) | |
| # # 1. Embeddings (No Position Embeddings needed, RoPE handles it!) | |
| # self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd) | |
| # # 2. Layers | |
| # self.layers = nn.ModuleList([ | |
| # SOTATransformerBlock(config) for _ in range(config.n_layer) | |
| # ]) | |
| # # 3. Final Norm | |
| # self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps) | |
| # # 4. Head | |
| # self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) | |
| # # Tie weights | |
| # if config.tie_weights: | |
| # self.lm_head.weight = self.embed_tokens.weight | |
| # self._tied_weights_keys = ["lm_head.weight"] | |
| # self.post_init() | |
| # def get_input_embeddings(self): | |
| # return self.embed_tokens | |
| # def set_input_embeddings(self, value): | |
| # self.embed_tokens = value | |
| # def get_output_embeddings(self): | |
| # return self.lm_head | |
| # def set_output_embeddings(self, new_embeddings): | |
| # self.lm_head = new_embeddings | |
| # def forward( | |
| # self, | |
| # input_ids: torch.LongTensor = None, | |
| # attention_mask: Optional[torch.Tensor] = None, | |
| # labels: Optional[torch.LongTensor] = None, | |
| # return_dict: Optional[bool] = None, | |
| # **kwargs, | |
| # ) -> Union[Tuple, CausalLMOutputWithPast]: | |
| # batch_size, seq_len = input_ids.shape | |
| # hidden_states = self.embed_tokens(input_ids) | |
| # for layer in self.layers: | |
| # hidden_states = layer(hidden_states, attention_mask) | |
| # hidden_states = self.norm(hidden_states) | |
| # logits = self.lm_head(hidden_states) | |
| # loss = None | |
| # if labels is not None: | |
| # shift_logits = logits[..., :-1, :].contiguous() | |
| # shift_labels = labels[..., 1:].contiguous() | |
| # loss_fct = nn.CrossEntropyLoss(ignore_index=-100) | |
| # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) | |
| # return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
| # if not return_dict: | |
| # output = (logits,) | |
| # return ((loss,) + output) if loss is not None else output | |
| # return CausalLMOutputWithPast( | |
| # loss=loss, | |
| # logits=logits, | |
| # past_key_values=None, | |
| # hidden_states=None, | |
| # attentions=None, | |
| # ) | |
| # def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): | |
| # # 1. Handle Cache (Past Key Values) | |
| # # If we have a cache, we only need to process the very last token we generated | |
| # if past_key_values: | |
| # input_ids = input_ids[:, -1:] | |
| # # 2. Handle Position IDs | |
| # # If the user didn't provide position_ids, we might need to create them from the attention_mask | |
| # position_ids = kwargs.get("position_ids", None) | |
| # # FIX: Explicitly check 'is not None' to avoid the ambiguous Tensor error | |
| # attention_mask = kwargs.get("attention_mask", None) | |
| # if attention_mask is not None: | |
| # # Create position_ids based on the mask (0, 1, 2... ignoring padding) | |
| # if position_ids is None: | |
| # position_ids = attention_mask.long().cumsum(-1) - 1 | |
| # position_ids.masked_fill_(attention_mask == 0, 1) | |
| # # If using cache, we only need the position ID for the last token | |
| # if past_key_values: | |
| # position_ids = position_ids[:, -1].unsqueeze(-1) | |
| # return { | |
| # "input_ids": input_ids, | |
| # "past_key_values": past_key_values, | |
| # "use_cache": kwargs.get("use_cache"), | |
| # "position_ids": position_ids, | |
| # "attention_mask": attention_mask, | |
| # } | |
| # def generate(self, input_ids, **kwargs): | |
| # tokenizer = kwargs.pop("tokenizer", None) | |
| # if tokenizer is not None: | |
| # # Use the 4-step synthesizer | |
| # synthesizer = FourStepConsistency(tokenizer, input_ids.shape[1]) | |
| # logits_processor = kwargs.get("logits_processor", LogitsProcessorList()) | |
| # logits_processor.append(synthesizer) | |
| # kwargs["logits_processor"] = logits_processor | |
| # # Call GenerationMixin directly to bypass any PreTrainedModel ambiguity | |
| # return GenerationMixin.generate(self, input_ids, **kwargs) | |
| # # Register | |
| # from transformers import AutoConfig, AutoModelForCausalLM | |
| # AutoConfig.register("chess_transformer", ChessConfig) | |
| # AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM) | |
| """ | |
| SOTA Chess Transformer (Llama/DeepSeek Style) | |
| Updated for the 1M Parameter Challenge. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PretrainedConfig, PreTrainedModel | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| from transformers import LogitsProcessor, LogitsProcessorList | |
| from transformers.generation import GenerationMixin | |
class ChessConfig(PretrainedConfig):
    """Configuration for the chess decoder-only transformer.

    Holds the architecture hyper-parameters (embedding width, depth, heads,
    context length, SwiGLU inner size) plus Hugging Face compatibility
    aliases (hidden_size, num_hidden_layers, num_attention_heads).
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size: int = 1200,
        n_embd: int = 128,
        n_layer: int = 8,
        n_head: int = 4,
        n_ctx: int = 256,
        n_inner: Optional[int] = None,
        dropout: float = 0.0,
        rms_norm_eps: float = 1e-6,
        tie_weights: bool = True,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_decoder=True,
            **kwargs,
        )
        # Core architecture dimensions.
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        # SwiGLU inner width: defaults to ~8/3 * d_model (Llama convention,
        # keeps parameter count comparable to a 4*d GELU FFN).
        self.n_inner = int(8 / 3 * n_embd) if n_inner is None else n_inner
        self.dropout = dropout
        self.rms_norm_eps = rms_norm_eps
        self.tie_weights = tie_weights
        self.tie_word_embeddings = bool(tie_weights)
        # Aliases expected by Hugging Face utilities.
        self.num_hidden_layers = n_layer
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (Llama style, no mean centering)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Scale each feature vector by the reciprocal of its RMS magnitude.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize in fp32 for numerical stability, then cast back to the
        # input dtype before applying the learned gain.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
def rotate_half(x):
    """Rotate the last dimension by half: (x1, x2) -> (-x2, x1).

    Used by RoPE to implement the 2D rotation on paired channels.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
class SOTAMultiHeadAttention(nn.Module):
    """Multi-head causal self-attention with RoPE and QK-Norm.

    QK-Norm (per-head RMSNorm on queries/keys, OLMo 2 style) is applied
    before the rotary embedding. No KV cache: the full sequence is
    attended to on every call.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        # Separate Q/K/V/O projections, no bias (Llama convention).
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # QK-Norm stabilizes attention logits during training.
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # RoPE inverse frequencies; excluded from the state_dict.
        self.register_buffer(
            "inv_freq",
            1.0 / (10000 ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)),
            persistent=False,
        )

    def get_rope_embeddings(self, position_ids, device):
        """Return (cos, sin) tables of shape [batch, seq, head_dim].

        position_ids: [batch, seq] explicit absolute positions, so padded
        prompts can use positions derived from the attention mask.
        """
        inv_freq = self.inv_freq.to(device)
        # Outer product per batch row: [batch, seq, head_dim/2].
        freqs = torch.einsum("bs,d->bsd", position_ids.float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

    def forward(self, x, attention_mask=None, position_ids=None):
        batch_size, seq_len, _ = x.size()
        # Project and split into heads.
        q = self.q_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim)
        # QK-Norm BEFORE RoPE (OLMo 2 ordering).
        q = self.q_norm(q)
        k = self.k_norm(k)
        # [batch, head, seq, dim] layout for RoPE and SDPA.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        cos, sin = self.get_rope_embeddings(position_ids, x.device)
        # Insert a head axis so cos/sin broadcast over heads.
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        q = (q * cos) + (rotate_half(q) * sin)
        k = (k * cos) + (rotate_half(k) * sin)
        # --- Build the SDPA mask. Two pitfalls fixed here:
        # 1. PyTorch raises if attn_mask is set together with is_causal=True,
        #    so when padding exists we fold causality into the mask instead.
        # 2. For a *boolean* attn_mask SDPA uses True = "participate in
        #    attention" — the opposite of a key_padding_mask. The HF
        #    convention is 1 = keep, 0 = pad, so `mask != 0` is the correct
        #    keep-polarity conversion.
        attn_mask = None
        use_causal = True
        if attention_mask is not None:
            keep = attention_mask
            if keep.dim() == 2:
                keep = keep[:, None, None, :]  # [batch, 1, 1, seq] for broadcasting
            if keep.dtype != torch.bool:
                keep = keep != 0  # True = attend, False = padding
            if not keep.all():
                # Padding present: combine padding and causal constraints
                # into one boolean mask and disable the is_causal fast path.
                causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).tril()
                attn_mask = keep & causal
                use_causal = False
            # else: no padding anywhere — keep is_causal=True with no mask.
        # NOTE(review): a fully-padded row would yield an all-False mask row
        # and NaN outputs; assumes callers never pass such rows — confirm.
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=0.0,
            is_causal=use_causal,
        )
        # Merge heads back and project out.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_embd)
        return self.o_proj(attn_output)
class SwiGLUFeedForward(nn.Module):
    """SwiGLU feed-forward network: down(silu(gate(x)) * up(x))."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        # Three bias-free projections: gate, up (value), down (output).
        self.gate_proj = nn.Linear(config.n_embd, config.n_inner, bias=False)
        self.up_proj = nn.Linear(config.n_embd, config.n_inner, bias=False)
        self.down_proj = nn.Linear(config.n_inner, config.n_embd, bias=False)

    def forward(self, x):
        # Swish-gated linear unit, then project back to model width.
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class SOTATransformerBlock(nn.Module):
    """Pre-norm decoder block: RMSNorm -> attention -> RMSNorm -> SwiGLU MLP."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        self.input_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.self_attn = SOTAMultiHeadAttention(config)
        self.post_attention_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.mlp = SwiGLUFeedForward(config)

    def forward(self, x, attention_mask=None, position_ids=None):
        # Residual connection around attention (position_ids forwarded for RoPE).
        attn_out = self.self_attn(self.input_layernorm(x), attention_mask, position_ids)
        x = x + attn_out
        # Residual connection around the feed-forward network.
        mlp_out = self.mlp(self.post_attention_layernorm(x))
        return x + mlp_out
class FourStepConsistency(LogitsProcessor):
    """Constrain generation to the 4-token move rhythm:
    [Piece] -> [From square] -> [To square] -> [Suffix].

    The EOS token is always allowed so generation can terminate at any step.

    Args:
        tokenizer: tokenizer exposing PIECES / SQUARES / SUFFIXES token lists,
            get_vocab(), convert_tokens_to_ids() and eos_token_id.
        start_len: prompt length; the rhythm is counted relative to it.
    """

    def __init__(self, tokenizer, start_len):
        self.tokenizer = tokenizer
        self.start_len = start_len
        vocab = tokenizer.get_vocab()
        # Allowed token-id sets per step; tokens missing from the vocab are skipped.
        self.piece_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.PIECES if t in vocab}
        self.square_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SQUARES if t in vocab}
        self.suffix_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SUFFIXES if t in vocab}

    def __call__(self, input_ids, scores):
        """Mask (in place) every logit outside the current step's allowed set."""
        step = (input_ids.shape[1] - self.start_len) % 4
        if step == 0:            # Step 1: piece
            allowed = self.piece_ids
        elif step in (1, 2):     # Steps 2-3: from / to squares
            allowed = self.square_ids
        else:                    # Step 4: suffix
            allowed = self.suffix_ids
        keep = set(allowed)
        eos = self.tokenizer.eos_token_id
        if eos is not None:
            keep.add(eos)
        vocab_size = scores.shape[1]
        # Vectorized masking: one boolean index assignment instead of a
        # Python-level loop over every vocabulary id per generated token.
        banned = torch.ones(vocab_size, dtype=torch.bool, device=scores.device)
        keep_ids = [i for i in keep if 0 <= i < vocab_size]
        if keep_ids:
            banned[torch.tensor(keep_ids, device=scores.device)] = False
        scores[:, banned] = float("-inf")
        return scores
class ChessForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal LM for chess move sequences.

    Token embeddings only (RoPE inside the attention layers supplies
    position information), a stack of pre-norm SOTATransformerBlocks, a
    final RMSNorm, and a (optionally weight-tied) LM head. No KV cache is
    used: every forward pass re-processes the full sequence.
    """
    config_class = ChessConfig
    def __init__(self, config: ChessConfig):
        super().__init__(config)
        # No positional embedding table — RoPE handles positions.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
        self.layers = nn.ModuleList([SOTATransformerBlock(config) for _ in range(config.n_layer)])
        self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        if config.tie_weights:
            # Share the embedding matrix with the output projection and tell
            # HF serialization about the tie.
            self.lm_head.weight = self.embed_tokens.weight
            self._tied_weights_keys = ["lm_head.weight"]
        self.post_init()
    # Embedding accessors required by the PreTrainedModel API (weight tying,
    # resize_token_embeddings, etc.).
    def get_input_embeddings(self): return self.embed_tokens
    def set_input_embeddings(self, value): self.embed_tokens = value
    def get_output_embeddings(self): return self.lm_head
    def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings
    def forward(self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, **kwargs):
        """Run the decoder stack over the whole sequence.

        When `labels` is given, computes the shifted next-token
        cross-entropy loss (positions with label -100 are ignored).
        Returns a CausalLMOutputWithPast, or a plain tuple when
        return_dict is falsy.
        """
        batch_size, seq_len = input_ids.shape
        hidden_states = self.embed_tokens(input_ids)
        # Default positions are 0..seq_len-1, shared across the batch.
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
        # position_ids must reach every layer so RoPE rotates consistently.
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, position_ids)
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Shift so the logit at position t predicts the token at t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None)
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """Generation hook: always feed the FULL sequence and no cache.

        The attention module recomputes over the whole history each step,
        so past_key_values is deliberately discarded and use_cache forced
        off. position_ids are derived from the attention mask so padded
        prompts get positions that count only real tokens.
        """
        position_ids = kwargs.get("position_ids", None)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None and position_ids is None:
            # cumsum over the 1/0 mask yields 0,1,2,... on non-pad tokens;
            # padded slots are clamped to 1 (their value is never attended).
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
        return {
            "input_ids": input_ids,        # full ids — never sliced to the last token
            "past_key_values": None,       # cache explicitly unsupported
            "use_cache": False,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }
    def generate(self, input_ids, **kwargs):
        """Generate, optionally constraining output to the 4-token move rhythm.

        Pass `tokenizer=` to install a FourStepConsistency logits processor
        anchored at the current prompt length; all other kwargs go straight
        to GenerationMixin.generate.
        """
        tokenizer = kwargs.pop("tokenizer", None)
        if tokenizer is not None:
            synthesizer = FourStepConsistency(tokenizer, input_ids.shape[1])
            logits_processor = kwargs.get("logits_processor", LogitsProcessorList())
            logits_processor.append(synthesizer)
            kwargs["logits_processor"] = logits_processor
        # Call GenerationMixin directly to avoid any MRO ambiguity with
        # PreTrainedModel's own generate.
        return GenerationMixin.generate(self, input_ids, **kwargs)
# Register with the Hugging Face Auto classes so this architecture can be
# loaded via AutoModelForCausalLM / AutoConfig with model_type
# "chess_transformer".
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)