# Chess-Web / model.py
# dpv007's picture
# Upload folder using huggingface_hub
# 0bdc0cd verified
# Raw
# History Blame Contribute Delete
# 4.97 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) backed by precomputed cos/sin tables.

    Args:
        dim: size of the last axis of the tensors to rotate; must be even,
            because channel ``i`` is paired with channel ``i + dim // 2``.
        max_seq_len: number of positions for which cos/sin tables are built.

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RotaryEmbedding dim must be even, got {dim}")
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        t = torch.arange(max_seq_len, dtype=torch.float)
        freqs = torch.outer(t, self.inv_freq)
        # Duplicate the frequencies so columns i and i + dim//2 share one angle,
        # matching the channel pairing used by _rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def _rotate_half(self, x):
        """Return ``cat(-second_half, first_half)`` along the last axis."""
        half = self.dim // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def forward(self, x, seq_len, offset=0):
        """Rotate ``x`` by the angles of positions ``offset .. offset + seq_len - 1``.

        Args:
            x: tensor broadcast-compatible with ``(1, 1, seq_len, dim)``
                (e.g. ``(batch, heads, seq_len, dim)``).
            seq_len: number of positions covered by ``x``.
            offset: absolute position of the first entry (e.g. the number of
                already-cached tokens during incremental decoding).

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: if the requested positions fall outside the
                precomputed table. Without this check a truncated slice could
                silently broadcast and apply the wrong rotation.
        """
        end = offset + seq_len
        max_len = self.cos_cached.size(0)
        if offset < 0 or end > max_len:
            raise ValueError(
                f"positions [{offset}, {end}) exceed the rotary cache of "
                f"length {max_len}; increase max_seq_len"
            )
        # Cast to the input dtype so reduced-precision q/k are not up-cast to
        # float32 (SDPA requires q, k and v to share one dtype).
        cos = self.cos_cached[offset:end, :].to(x.dtype).unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[offset:end, :].to(x.dtype).unsqueeze(0).unsqueeze(0)
        return (x * cos) + (self._rotate_half(x) * sin)
class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``w3(silu(w1(x)) * w2(x))``.

    Args:
        d_model: input and output width.
        d_ff: hidden width of the gated projection.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Registration order (w1, w2, w3) is kept stable for state_dict/init RNG.
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w2 = nn.Linear(d_model, d_ff, bias=False)  # value projection
        self.w3 = nn.Linear(d_ff, d_model, bias=False)  # output projection

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
class OptimizedAttention(nn.Module):
    """Causal self-attention with grouped key/value heads and an optional KV cache.

    Queries use ``nhead`` heads; keys and values use ``num_kv_heads`` heads,
    each shared by ``nhead // num_kv_heads`` consecutive query heads.

    Args:
        d_model: model width; must be divisible by ``nhead``.
        nhead: number of query heads; must be divisible by ``num_kv_heads``.
        rope: callable ``rope(t, seq_len, offset)`` applied to queries and keys
            of shape ``(B, heads, S, head_dim)`` to inject positions.
        num_kv_heads: number of key/value heads.

    Raises:
        ValueError: if the head counts do not divide evenly.
    """

    def __init__(self, d_model, nhead, rope, num_kv_heads=2):
        super().__init__()
        if d_model % nhead != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by nhead ({nhead})")
        if nhead % num_kv_heads != 0:
            raise ValueError(
                f"nhead ({nhead}) must be divisible by num_kv_heads ({num_kv_heads})"
            )
        self.nhead, self.head_dim = nhead, d_model // nhead
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = nhead // num_kv_heads
        self.rope = rope
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, past_key_value=None):
        """Attend over the cached positions plus the new tokens in ``x``.

        Args:
            x: input of shape ``(B, S, d_model)``.
            past_key_value: optional ``(k, v)`` pair returned by a previous
                call, each of shape ``(B, num_kv_heads, past_len, head_dim)``.

        Returns:
            ``(out, present_key_value)``. ``out`` has shape ``(B, S, d_model)``.
            ``present_key_value`` is the ``(k, v)`` cache extended with the new
            positions, stored with ``num_kv_heads`` heads (before expansion).
        """
        B, S, C = x.shape
        q = self.q_proj(x).view(B, S, self.nhead, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Absolute position of the first new token = number of cached positions.
        offset = 0 if past_key_value is None else past_key_value[0].size(2)
        q = self.rope(q, S, offset)
        k = self.rope(k, S, offset)
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present_key_value = (k, v)
        if self.num_queries_per_kv > 1:
            # Expand KV heads so each query head has a matching key/value head.
            k = torch.repeat_interleave(k, repeats=self.num_queries_per_kv, dim=1)
            v = torch.repeat_interleave(v, repeats=self.num_queries_per_kv, dim=1)
        if offset == 0:
            attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # SDPA's is_causal mask is top-left aligned. With a cache there are
            # fewer queries than keys, so that mask would hide most past keys.
            # Query i sits at absolute position offset + i and may see keys
            # 0 .. offset + i, which is a bottom-right-aligned triangle.
            causal_mask = torch.ones(
                S, k.size(2), dtype=torch.bool, device=x.device
            ).tril(diagonal=offset)
            attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
        out = self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, C))
        return out, present_key_value
class EfficientDecoderLayer(nn.Module):
    """Pre-norm decoder block: self-attention, then a SwiGLU feed-forward.

    Each sub-layer sees a LayerNorm-ed input and its output is added back to
    the residual stream.
    """

    def __init__(self, d_model, nhead, d_ff, rope, num_kv_heads=2):
        super().__init__()
        self.attn = OptimizedAttention(d_model, nhead, rope, num_kv_heads)
        self.ffn = SwiGLU(d_model, d_ff)
        self.ln1 = nn.LayerNorm(d_model)  # normalizes the attention input
        self.ln2 = nn.LayerNorm(d_model)  # normalizes the feed-forward input

    def forward(self, x, past_key_value=None):
        """Return ``(hidden, present_key_value)``; the cache comes from the attention sub-layer."""
        attn_out, present_key_value = self.attn(self.ln1(x), past_key_value)
        residual = x + attn_out
        ffn_out = self.ffn(self.ln2(residual))
        return residual + ffn_out, present_key_value
class ChessTransformer(nn.Module):
    """Decoder-only transformer mapping token ids to one tanh-squashed scalar per position.

    The per-position output lies in [-1, 1] (final ``nn.Tanh``); presumably a
    position evaluation — NOTE(review): confirm semantics with the training code.

    Args:
        vocab_size: number of token ids accepted by the embedding.
        d_model: model width.
        nhead: number of query heads per layer.
        num_layers: number of decoder layers.
        max_length: number of positions covered by the rotary table.
        num_kv_heads: number of key/value heads per layer.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, max_length=120, num_kv_heads=2):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Roughly 8/3 * d_model hidden width for the SwiGLU feed-forward.
        d_ff = int(8 * d_model / 3)
        # A single rotary module is shared by every layer.
        self.rope = RotaryEmbedding(dim=d_model // nhead, max_seq_len=max_length)
        self.layers = nn.ModuleList(
            EfficientDecoderLayer(d_model, nhead, d_ff, self.rope, num_kv_heads)
            for _ in range(num_layers)
        )
        self.ln_final = nn.LayerNorm(d_model)
        half_width = d_model // 2
        self.value_head = nn.Sequential(
            nn.Linear(d_model, half_width),
            nn.SiLU(),
            nn.Linear(half_width, 1),
            nn.Tanh(),
        )

    def forward(self, src, past_key_values=None, use_cache=False):
        """Run the decoder stack on ``src`` (token ids, presumably ``(B, S)``).

        Returns the per-position values (last axis squeezed). When
        ``use_cache`` is true, returns ``(values, caches)`` where ``caches``
        holds one ``(k, v)`` pair per layer.
        """
        hidden = self.token_embedding(src)
        caches = []
        for layer_idx, layer in enumerate(self.layers):
            layer_past = None if past_key_values is None else past_key_values[layer_idx]
            hidden, layer_present = layer(hidden, layer_past)
            if use_cache:
                caches.append(layer_present)
        values = self.value_head(self.ln_final(hidden)).squeeze(-1)
        if not use_cache:
            return values
        return values, tuple(caches)