# chess-aj-split / model.py
# """
# SOTA Chess Transformer (Llama/DeepSeek Style)
# Updated for the 1M Parameter Challenge.
# Improvements over baseline:
# 1. RoPE (Rotary Positional Embeddings) - Saves ~32k params, better context.
# 2. RMSNorm - More stable than LayerNorm.
# 3. SwiGLU - Better activation function for reasoning.
# 4. QK-Norm - (From OLMo 2) Stabilizes attention.
# """
# from __future__ import annotations
# import math
# from dataclasses import dataclass
# from typing import Optional, Tuple, Union
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from transformers import PretrainedConfig, PreTrainedModel
# from transformers.modeling_outputs import CausalLMOutputWithPast
# from transformers import LogitsProcessor, LogitsProcessorList
# from transformers.generation import GenerationMixin
# class ChessConfig(PretrainedConfig):
# model_type = "chess_transformer"
# def __init__(
# self,
# vocab_size: int = 1200,
# n_embd: int = 128,
# n_layer: int = 8, # Increased default depth since RoPE saves params
# n_head: int = 4,
# n_ctx: int = 256,
# n_inner: Optional[int] = None,
# dropout: float = 0.0, # Modern LLMs often use 0 dropout
# rms_norm_eps: float = 1e-6,
# tie_weights: bool = True,
# pad_token_id: int = 0,
# bos_token_id: int = 1,
# eos_token_id: int = 2,
# **kwargs,
# ):
# super().__init__(
# pad_token_id=pad_token_id,
# bos_token_id=bos_token_id,
# eos_token_id=eos_token_id,
# is_decoder=True,
# **kwargs,
# )
# self.vocab_size = vocab_size
# self.n_embd = n_embd
# self.n_layer = n_layer
# self.n_head = n_head
# # Mapping for Hugging Face compatibility
# self.num_hidden_layers = n_layer
# self.hidden_size = n_embd
# self.num_attention_heads = n_head
# self.n_ctx = n_ctx
# # SwiGLU needs a different inner dimension to match parameter count.
# # Usually 2/3 of 4d, but we can tune this.
# self.n_inner = n_inner if n_inner is not None else int(8/3 * n_embd)
# self.dropout = dropout
# self.rms_norm_eps = rms_norm_eps
# self.tie_weights = tie_weights
# self.tie_word_embeddings = bool(tie_weights)
# class RMSNorm(nn.Module):
# """Root Mean Square Layer Normalization (Llama style)."""
# def __init__(self, dim: int, eps: float = 1e-6):
# super().__init__()
# self.eps = eps
# self.weight = nn.Parameter(torch.ones(dim))
# def _norm(self, x):
# return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
# def forward(self, x):
# output = self._norm(x.float()).type_as(x)
# return output * self.weight
# def apply_rotary_pos_emb(q, k, cos, sin):
# """Apply Rotary Positional Embeddings (RoPE)."""
# # Reshape cos/sin to match q/k: [batch, 1, seq_len, head_dim]
# # Note: This is a simplified implementation for the challenge
# cos = cos.unsqueeze(1)
# sin = sin.unsqueeze(1)
# q_embed = (q * cos) + (rotate_half(q) * sin)
# k_embed = (k * cos) + (rotate_half(k) * sin)
# return q_embed, k_embed
# def rotate_half(x):
# """Rotates half the hidden dims of the input."""
# x1 = x[..., : x.shape[-1] // 2]
# x2 = x[..., x.shape[-1] // 2 :]
# return torch.cat((-x2, x1), dim=-1)
# class SOTAMultiHeadAttention(nn.Module):
# def __init__(self, config: ChessConfig):
# super().__init__()
# self.n_head = config.n_head
# self.n_embd = config.n_embd
# self.head_dim = config.n_embd // config.n_head
# # QKV Projections
# self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
# self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
# self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
# self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
# # QK-Norm (from OLMo 2) - Stabilizes training
# self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
# self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
# # RoPE cache
# self.register_buffer("inv_freq", 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)), persistent=False)
# def get_rope_embeddings(self, seq_len, device):
# t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# emb = torch.cat((freqs, freqs), dim=-1)
# return emb.cos(), emb.sin()
# def forward(self, x, attention_mask=None):
# batch_size, seq_len, _ = x.size()
# # 1. Project
# q = self.q_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim)
# k = self.k_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim)
# v = self.v_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim)
# # 2. QK-Norm (OLMo style) - Normalize BEFORE RoPE
# q = self.q_norm(q)
# k = self.k_norm(k)
# # 3. Apply RoPE
# # Transpose to [batch, head, seq, dim] for easier math
# q = q.transpose(1, 2)
# k = k.transpose(1, 2)
# v = v.transpose(1, 2)
# cos, sin = self.get_rope_embeddings(seq_len, x.device)
# # Match dimensions for broadcasting
# cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, seq, dim]
# sin = sin.unsqueeze(0).unsqueeze(0)
# q = (q * cos) + (rotate_half(q) * sin)
# k = (k * cos) + (rotate_half(k) * sin)
# # 4. Attention
# # Efficient Flash Attention if available (or standard)
# attn_output = F.scaled_dot_product_attention(
# q, k, v,
# attn_mask=None,
# dropout_p=0.0,
# is_causal=True
# )
# # 5. Output Projection
# attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_embd)
# return self.o_proj(attn_output)
# class SwiGLUFeedForward(nn.Module):
# """SwiGLU FFN (Llama/DeepSeek style)."""
# def __init__(self, config: ChessConfig):
# super().__init__()
# # SwiGLU has 3 projections: Gate, Value, Output
# self.gate_proj = nn.Linear(config.n_embd, config.n_inner, bias=False)
# self.up_proj = nn.Linear(config.n_embd, config.n_inner, bias=False)
# self.down_proj = nn.Linear(config.n_inner, config.n_embd, bias=False)
# def forward(self, x):
# # SwiGLU: (Swish(Gate) * Up) -> Down
# return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
# class SOTATransformerBlock(nn.Module):
# def __init__(self, config: ChessConfig):
# super().__init__()
# self.input_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
# self.self_attn = SOTAMultiHeadAttention(config)
# self.post_attention_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
# self.mlp = SwiGLUFeedForward(config)
# def forward(self, x, attention_mask=None):
# # Pre-norm architecture
# x = x + self.self_attn(self.input_layernorm(x), attention_mask)
# x = x + self.mlp(self.post_attention_layernorm(x))
# return x
# class FourStepConsistency(LogitsProcessor):
# """
# Enforces the 4-step rhythm: [Piece] -> [From] -> [To] -> [Suffix]
# """
# def __init__(self, tokenizer, start_len):
# self.tokenizer = tokenizer
# self.start_len = start_len
# all_ids = set(range(tokenizer.vocab_size))
# # 1. Piece IDs
# self.piece_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.PIECES if t in tokenizer.get_vocab()}
# # 2. Square IDs (Used for both From and To)
# self.square_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SQUARES if t in tokenizer.get_vocab()}
# # 3. Suffix IDs
# self.suffix_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SUFFIXES if t in tokenizer.get_vocab()}
# def __call__(self, input_ids, scores):
# cur_len = input_ids.shape[1]
# relative_pos = (cur_len - self.start_len) % 4
# mask_ids = set()
# if relative_pos == 0: # Step 1: Piece
# mask_ids = self.piece_ids
# elif relative_pos == 1: # Step 2: From Square
# mask_ids = self.square_ids
# elif relative_pos == 2: # Step 3: To Square
# mask_ids = self.square_ids
# else: # Step 4: Suffix
# mask_ids = self.suffix_ids
# # Mask out disallowed tokens
# for i in range(scores.shape[1]):
# if i not in mask_ids and i != self.tokenizer.eos_token_id:
# scores[:, i] = float("-inf")
# return scores
# class ChessForCausalLM(PreTrainedModel, GenerationMixin):
# config_class = ChessConfig
# def __init__(self, config: ChessConfig):
# super().__init__(config)
# # 1. Embeddings (No Position Embeddings needed, RoPE handles it!)
# self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
# # 2. Layers
# self.layers = nn.ModuleList([
# SOTATransformerBlock(config) for _ in range(config.n_layer)
# ])
# # 3. Final Norm
# self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
# # 4. Head
# self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# # Tie weights
# if config.tie_weights:
# self.lm_head.weight = self.embed_tokens.weight
# self._tied_weights_keys = ["lm_head.weight"]
# self.post_init()
# def get_input_embeddings(self):
# return self.embed_tokens
# def set_input_embeddings(self, value):
# self.embed_tokens = value
# def get_output_embeddings(self):
# return self.lm_head
# def set_output_embeddings(self, new_embeddings):
# self.lm_head = new_embeddings
# def forward(
# self,
# input_ids: torch.LongTensor = None,
# attention_mask: Optional[torch.Tensor] = None,
# labels: Optional[torch.LongTensor] = None,
# return_dict: Optional[bool] = None,
# **kwargs,
# ) -> Union[Tuple, CausalLMOutputWithPast]:
# batch_size, seq_len = input_ids.shape
# hidden_states = self.embed_tokens(input_ids)
# for layer in self.layers:
# hidden_states = layer(hidden_states, attention_mask)
# hidden_states = self.norm(hidden_states)
# logits = self.lm_head(hidden_states)
# loss = None
# if labels is not None:
# shift_logits = logits[..., :-1, :].contiguous()
# shift_labels = labels[..., 1:].contiguous()
# loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
# loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
# return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# if not return_dict:
# output = (logits,)
# return ((loss,) + output) if loss is not None else output
# return CausalLMOutputWithPast(
# loss=loss,
# logits=logits,
# past_key_values=None,
# hidden_states=None,
# attentions=None,
# )
# def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
# # 1. Handle Cache (Past Key Values)
# # If we have a cache, we only need to process the very last token we generated
# if past_key_values:
# input_ids = input_ids[:, -1:]
# # 2. Handle Position IDs
# # If the user didn't provide position_ids, we might need to create them from the attention_mask
# position_ids = kwargs.get("position_ids", None)
# # FIX: Explicitly check 'is not None' to avoid the ambiguous Tensor error
# attention_mask = kwargs.get("attention_mask", None)
# if attention_mask is not None:
# # Create position_ids based on the mask (0, 1, 2... ignoring padding)
# if position_ids is None:
# position_ids = attention_mask.long().cumsum(-1) - 1
# position_ids.masked_fill_(attention_mask == 0, 1)
# # If using cache, we only need the position ID for the last token
# if past_key_values:
# position_ids = position_ids[:, -1].unsqueeze(-1)
# return {
# "input_ids": input_ids,
# "past_key_values": past_key_values,
# "use_cache": kwargs.get("use_cache"),
# "position_ids": position_ids,
# "attention_mask": attention_mask,
# }
# def generate(self, input_ids, **kwargs):
# tokenizer = kwargs.pop("tokenizer", None)
# if tokenizer is not None:
# # Use the 4-step synthesizer
# synthesizer = FourStepConsistency(tokenizer, input_ids.shape[1])
# logits_processor = kwargs.get("logits_processor", LogitsProcessorList())
# logits_processor.append(synthesizer)
# kwargs["logits_processor"] = logits_processor
# # Call GenerationMixin directly to bypass any PreTrainedModel ambiguity
# return GenerationMixin.generate(self, input_ids, **kwargs)
# # Register
# from transformers import AutoConfig, AutoModelForCausalLM
# AutoConfig.register("chess_transformer", ChessConfig)
# AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
"""
SOTA Chess Transformer (Llama/DeepSeek Style)
Updated for the 1M Parameter Challenge.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import LogitsProcessor, LogitsProcessorList
from transformers.generation import GenerationMixin
class ChessConfig(PretrainedConfig):
model_type = "chess_transformer"
def __init__(
self,
vocab_size: int = 1200,
n_embd: int = 128,
n_layer: int = 8,
n_head: int = 4,
n_ctx: int = 256,
n_inner: Optional[int] = None,
dropout: float = 0.0,
rms_norm_eps: float = 1e-6,
tie_weights: bool = True,
pad_token_id: int = 0,
bos_token_id: int = 1,
eos_token_id: int = 2,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_decoder=True,
**kwargs,
)
self.vocab_size = vocab_size
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.n_ctx = n_ctx
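        # SwiGLU uses three projections, so the inner width is set near 8/3 * d rather
        # than 4 * d: the FFN then costs 3 * d * (8/3 * d) = 8 * d**2 parameters, the
        # same as a classic two-projection MLP with a 4 * d hidden size (2 * d * 4d).
        # For the default n_embd = 128 this gives int(8/3 * 128) = 341.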
self.n_inner = n_inner if n_inner is not None else int(8/3 * n_embd)
self.dropout = dropout
self.rms_norm_eps = rms_norm_eps
self.tie_weights = tie_weights
self.tie_word_embeddings = bool(tie_weights)
        # Aliases for Hugging Face compatibility.
        self.num_hidden_layers = n_layer
self.hidden_size = n_embd
self.num_attention_heads = n_head
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization (Llama style)."""
    def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def rotate_half(x):
    """Rotates half the hidden dims of the input (used by RoPE)."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
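
# RoPE (as applied below) pairs dimension i with dimension i + head_dim/2 and rotates
# each pair by a position-dependent angle theta_i = pos / 10000**(2i / head_dim):
#     q'[i]       = q[i] * cos(theta_i) - q[i + d/2] * sin(theta_i)
#     q'[i + d/2] = q[i + d/2] * cos(theta_i) + q[i] * sin(theta_i)
# which is exactly (q * cos) + (rotate_half(q) * sin), with cos/sin built from the
# duplicated frequency table in SOTAMultiHeadAttention.get_rope_embeddings.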
class SOTAMultiHeadAttention(nn.Module):
def __init__(self, config: ChessConfig):
super().__init__()
self.n_head = config.n_head
self.n_embd = config.n_embd
self.head_dim = config.n_embd // config.n_head
self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # QK-Norm (from OLMo 2): per-head RMSNorm on queries and keys stabilizes attention.
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # RoPE inverse-frequency cache (base theta = 10000), not saved in checkpoints.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
    def get_rope_embeddings(self, position_ids, device):
        # Build the cos/sin tables from explicit position_ids (rather than
        # arange(seq_len)) so RoPE stays aligned with padding offsets.
        # position_ids: [batch, seq_len]; inv_freq: [head_dim / 2]
        inv_freq = self.inv_freq.to(device)
        freqs = torch.einsum("bs,d->bsd", position_ids.float(), inv_freq)  # [batch, seq, head_dim/2]
        emb = torch.cat((freqs, freqs), dim=-1)                            # [batch, seq, head_dim]
        return emb.cos(), emb.sin()
def forward(self, x, attention_mask=None, position_ids=None):
batch_size, seq_len, _ = x.size()
        # Project and split into heads: [batch, seq, n_head, head_dim]
        q = self.q_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_head, self.head_dim)
        # QK-Norm: normalize queries and keys per head before applying RoPE.
        q = self.q_norm(q)
        k = self.k_norm(k)
# Transpose for RoPE [batch, head, seq, dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if position_ids is None:
position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
cos, sin = self.get_rope_embeddings(position_ids, x.device)
        cos = cos.unsqueeze(1)  # [batch, 1, seq, head_dim] so cos/sin broadcast over heads
        sin = sin.unsqueeze(1)
q = (q * cos) + (rotate_half(q) * sin)
k = (k * cos) + (rotate_half(k) * sin)
        # F.scaled_dot_product_attention rejects an explicit attn_mask combined with
        # is_causal=True, and its boolean mask means True = "may attend" (the opposite
        # of a padding mask's 0 = pad). So: convert the HF-style padding mask
        # (1 = keep, 0 = pad) to boolean, fold the causal constraint into it, and use
        # is_causal only when no explicit mask was supplied.
        is_causal = attention_mask is None
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, None, None, :]  # [batch, 1, 1, seq]
            attention_mask = attention_mask.to(torch.bool)
            causal = torch.tril(
                torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool)
            )
            attention_mask = attention_mask & causal  # [batch, 1, seq, seq]
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=is_causal,
        )
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_embd)
return self.o_proj(attn_output)
class SwiGLUFeedForward(nn.Module):
    """SwiGLU feed-forward network (Llama/DeepSeek style): down(silu(gate(x)) * up(x))."""
    def __init__(self, config: ChessConfig):
super().__init__()
self.gate_proj = nn.Linear(config.n_embd, config.n_inner, bias=False)
self.up_proj = nn.Linear(config.n_embd, config.n_inner, bias=False)
self.down_proj = nn.Linear(config.n_inner, config.n_embd, bias=False)
def forward(self, x):
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class SOTATransformerBlock(nn.Module):
def __init__(self, config: ChessConfig):
super().__init__()
self.input_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
self.self_attn = SOTAMultiHeadAttention(config)
self.post_attention_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
self.mlp = SwiGLUFeedForward(config)
    def forward(self, x, attention_mask=None, position_ids=None):
        # Pre-norm residual blocks; position_ids are threaded down to attention for RoPE.
        x = x + self.self_attn(self.input_layernorm(x), attention_mask, position_ids)
x = x + self.mlp(self.post_attention_layernorm(x))
return x
class FourStepConsistency(LogitsProcessor):
    """Enforces the 4-step move rhythm: [Piece] -> [From square] -> [To square] -> [Suffix]."""
    def __init__(self, tokenizer, start_len):
self.tokenizer = tokenizer
self.start_len = start_len
self.piece_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.PIECES if t in tokenizer.get_vocab()}
self.square_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SQUARES if t in tokenizer.get_vocab()}
self.suffix_ids = {tokenizer.convert_tokens_to_ids(t) for t in tokenizer.SUFFIXES if t in tokenizer.get_vocab()}
    def __call__(self, input_ids, scores):
        cur_len = input_ids.shape[1]
        relative_pos = (cur_len - self.start_len) % 4
        if relative_pos == 0:    # step 1: piece
            mask_ids = self.piece_ids
        elif relative_pos == 1:  # step 2: from-square
            mask_ids = self.square_ids
        elif relative_pos == 2:  # step 3: to-square
            mask_ids = self.square_ids
        else:                    # step 4: suffix
            mask_ids = self.suffix_ids
        # Disallow every token outside the current step's set (EOS is always allowed).
        for i in range(scores.shape[1]):
            if i not in mask_ids and i != self.tokenizer.eos_token_id:
                scores[:, i] = float("-inf")
        return scores
class ChessForCausalLM(PreTrainedModel, GenerationMixin):
config_class = ChessConfig
    def __init__(self, config: ChessConfig):
        super().__init__(config)
        # Token embeddings only: RoPE supplies positional information, so no learned
        # position-embedding table is needed.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
        self.layers = nn.ModuleList([SOTATransformerBlock(config) for _ in range(config.n_layer)])
        self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Tie the output head to the input embeddings to save parameters.
        if config.tie_weights:
            self.lm_head.weight = self.embed_tokens.weight
            self._tied_weights_keys = ["lm_head.weight"]
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
batch_size, seq_len = input_ids.shape
hidden_states = self.embed_tokens(input_ids)
# FIX: Ensure position_ids exist
if position_ids is None:
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
# FIX: Pass position_ids to layers
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, position_ids)
hidden_states = self.norm(hidden_states)
logits = self.lm_head(hidden_states)
loss = None
        if labels is not None:
            # Standard causal LM shift: position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if not return_dict:
output = (logits,)
return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
# FORCE NO CACHE: Always process the full sequence.
# This matches our SOTAMultiHeadAttention which handles the full history every time.
position_ids = kwargs.get("position_ids", None)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None and position_ids is None:
# Create position_ids from the mask
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
return {
"input_ids": input_ids, # Return FULL input_ids (do not slice)
"past_key_values": None, # Force None so the model doesn't expect cache
"use_cache": False, # Explicitly disable cache flag
"position_ids": position_ids,
"attention_mask": attention_mask,
}
def generate(self, input_ids, **kwargs):
tokenizer = kwargs.pop("tokenizer", None)
        if tokenizer is not None:
            # Constrain decoding to the 4-step move grammar.
            synthesizer = FourStepConsistency(tokenizer, input_ids.shape[1])
            logits_processor = kwargs.get("logits_processor", LogitsProcessorList())
            logits_processor.append(synthesizer)
            kwargs["logits_processor"] = logits_processor
return GenerationMixin.generate(self, input_ids, **kwargs)
# Register
from transformers import AutoConfig, AutoModelForCausalLM
AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
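
# A minimal, optional smoke test (not part of the original upload): it builds a model
# from the defaults above, reports the parameter count for the 1M-parameter challenge,
# and checks that a forward pass with labels returns logits and a loss.
if __name__ == "__main__":
    cfg = ChessConfig()
    model = ChessForCausalLM(cfg)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"parameters: {n_params:,}")

    dummy = torch.randint(0, cfg.vocab_size, (2, 16))
    out = model(input_ids=dummy, labels=dummy)
    print("logits shape:", tuple(out.logits.shape))
    print("loss:", float(out.loss))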