import torch
import torch.nn as nn
import torch.nn.functional as F


class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with precomputed cos/sin tables.

    The tables cover absolute positions ``[0, max_seq_len)``; ``forward``
    applies the rotation for a contiguous slice of those positions.

    Args:
        dim: Size of the rotated (last) dimension. Must be even, because
            features are rotated in pairs (first half against second half).
        max_seq_len: Number of positions to precompute.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        if dim % 2:
            raise ValueError(f"RotaryEmbedding dim must be even, got {dim}")
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent buffers: derived from the constructor arguments, so
        # they follow .to()/.cuda() but are kept out of state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        t = torch.arange(max_seq_len, dtype=torch.float)
        freqs = torch.outer(t, self.inv_freq)      # (max_seq_len, dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)    # (max_seq_len, dim)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def _rotate_half(self, x):
        """Map the last-dimension halves ``(a, b)`` to ``(-b, a)``."""
        half = self.dim // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def forward(self, x, seq_len, offset=0):
        """Rotate ``x`` for absolute positions ``offset .. offset + seq_len - 1``.

        Args:
            x: Tensor whose last two dims are ``(seq_len, dim)``. The cos/sin
                slice is broadcast as ``(1, 1, seq_len, dim)``; callers in
                this module pass ``(batch, heads, seq_len, head_dim)``.
            seq_len: Number of positions in ``x``.
            offset: Absolute position of the first element, e.g. the number
                of tokens already stored in a KV cache.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If the requested positions exceed ``max_seq_len``.
        """
        end = offset + seq_len
        table_len = self.cos_cached.size(0)
        if end > table_len:
            # A short slice would otherwise broadcast silently (1 row) or fail
            # with an opaque shape error.
            raise ValueError(
                f"positions up to {end} requested but RotaryEmbedding was "
                f"built with max_seq_len={table_len}"
            )
        cos = self.cos_cached[offset:end, :].unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[offset:end, :].unsqueeze(0).unsqueeze(0)
        return (x * cos) + (self._rotate_half(x) * sin)


class SwiGLU(nn.Module):
    """Gated feed-forward block computing ``w3(silu(w1(x)) * w2(x))``.

    Args:
        d_model: Input/output width.
        d_ff: Hidden width of the gated projection.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate branch
        self.w2 = nn.Linear(d_model, d_ff, bias=False)  # value branch
        self.w3 = nn.Linear(d_ff, d_model, bias=False)  # output projection

    def forward(self, x):
        """Apply the gated feed-forward transform over the last dimension."""
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class OptimizedAttention(nn.Module):
    """Causal self-attention with RoPE, grouped-query K/V and an optional KV cache.

    Each of the ``num_kv_heads`` key/value heads is shared by
    ``nhead // num_kv_heads`` query heads.

    Args:
        d_model: Model width; must be divisible by ``nhead``.
        nhead: Number of query heads; must be divisible by ``num_kv_heads``.
        rope: ``RotaryEmbedding`` whose ``dim`` equals ``d_model // nhead``.
        num_kv_heads: Number of key/value heads (positive).

    Raises:
        ValueError: On incompatible head configuration.
    """

    def __init__(self, d_model, nhead, rope, num_kv_heads=2):
        super().__init__()
        if d_model % nhead:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by nhead ({nhead})"
            )
        if num_kv_heads <= 0 or nhead % num_kv_heads:
            raise ValueError(
                f"nhead ({nhead}) must be a positive multiple of "
                f"num_kv_heads ({num_kv_heads})"
            )
        self.nhead, self.head_dim = nhead, d_model // nhead
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = nhead // num_kv_heads
        self.rope = rope
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, past_key_value=None):
        """Attend from the ``S`` new positions in ``x`` over cache + new keys.

        Args:
            x: ``(batch, S, d_model)`` input.
            past_key_value: Optional ``(k, v)`` returned by a previous call,
                each ``(batch, num_kv_heads, past_len, head_dim)``. RoPE has
                already been applied to ``k``.

        Returns:
            ``(out, (k, v))``. ``out`` is ``(batch, S, d_model)``. ``(k, v)``
            are the past + new keys/values before the grouped-query head
            expansion, ready to be fed back as ``past_key_value``.
        """
        B, S, C = x.shape
        q = self.q_proj(x).view(B, S, self.nhead, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # New tokens continue the absolute positions right after the cache.
        offset = 0 if past_key_value is None else past_key_value[0].size(2)
        q = self.rope(q, S, offset)
        k = self.rope(k, S, offset)
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        # Cache the compact (un-expanded) heads to keep the cache small.
        present_key_value = (k, v)

        if self.num_queries_per_kv > 1:
            k = torch.repeat_interleave(k, repeats=self.num_queries_per_kv, dim=1)
            v = torch.repeat_interleave(v, repeats=self.num_queries_per_kv, dim=1)

        kv_len = k.size(2)
        if kv_len == S:
            # No cached prefix: the square causal mask is exactly right.
            attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        elif S == 1:
            # A single new token may see every cached key and itself.
            attn_out = F.scaled_dot_product_attention(q, k, v)
        else:
            # is_causal=True aligns the triangle top-left on a non-square
            # (S, kv_len) score matrix, which would hide most of the cache.
            # Query i sits at absolute position kv_len - S + i, so shift the
            # diagonal to align the mask bottom-right (True = may attend).
            causal = torch.ones(
                S, kv_len, dtype=torch.bool, device=q.device
            ).tril(diagonal=kv_len - S)
            attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)

        out = self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, C))
        return out, present_key_value


class EfficientDecoderLayer(nn.Module):
    """Pre-LayerNorm decoder layer: residual attention, then residual SwiGLU.

    Args:
        d_model: Model width.
        nhead: Number of query heads.
        d_ff: SwiGLU hidden width.
        rope: Shared ``RotaryEmbedding``.
        num_kv_heads: Number of key/value heads.
    """

    def __init__(self, d_model, nhead, d_ff, rope, num_kv_heads=2):
        super().__init__()
        self.attn = OptimizedAttention(d_model, nhead, rope, num_kv_heads)
        self.ffn = SwiGLU(d_model, d_ff)
        self.ln1, self.ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

    def forward(self, x, past_key_value=None):
        """Return ``(x_out, present_key_value)`` for input ``(batch, S, d_model)``."""
        attn_out, present_key_value = self.attn(self.ln1(x), past_key_value)
        x = x + attn_out
        return x + self.ffn(self.ln2(x)), present_key_value


class ChessTransformer(nn.Module):
    """Decoder-only transformer mapping token ids to a per-position scalar in [-1, 1].

    Args:
        vocab_size: Number of token ids accepted by the embedding.
        d_model: Model width.
        nhead: Number of query heads.
        num_layers: Number of decoder layers.
        max_length: Size of the RoPE table, i.e. the maximum total sequence
            length (cached + new tokens).
        num_kv_heads: Number of key/value heads per layer.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6,
                 max_length=120, num_kv_heads=2):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # 8/3 * d_model: SwiGLU's three matrices then hold ~8*d_model^2
        # weights, the same as a two-matrix MLP with a 4x hidden size.
        d_ff = int(2 * (d_model * 4) / 3)
        # One RoPE table shared by every layer.
        self.rope = RotaryEmbedding(dim=d_model // nhead, max_seq_len=max_length)
        self.layers = nn.ModuleList([
            EfficientDecoderLayer(d_model, nhead, d_ff, self.rope, num_kv_heads)
            for _ in range(num_layers)
        ])
        self.ln_final = nn.LayerNorm(d_model)
        self.value_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Linear(d_model // 2, 1),
            nn.Tanh(),  # bounds the output to [-1, 1]
        )

    def forward(self, src, past_key_values=None, use_cache=False):
        """Compute per-position values, optionally continuing from a KV cache.

        Args:
            src: ``(batch, seq)`` integer token ids.
            past_key_values: Optional per-layer ``(k, v)`` tuples returned by
                an earlier call made with ``use_cache=True``.
            use_cache: If True, also return the updated per-layer caches.

        Returns:
            ``value`` of shape ``(batch, seq)``, or
            ``(value, present_key_values)`` when ``use_cache`` is True.
        """
        x = self.token_embedding(src)
        present_key_values = []
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values is not None else None
            x, present_kv = layer(x, past_kv)
            if use_cache:
                present_key_values.append(present_kv)
        hidden = self.ln_final(x)
        value = self.value_head(hidden).squeeze(-1)
        if use_cache:
            return value, tuple(present_key_values)
        return value