| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with a cached cos/sin table.

    The table is extended on demand: requesting positions past the current
    cache length rebuilds it instead of silently returning fewer rows.

    Args:
        dim: size of the last axis of the tensors being rotated. The first and
            second halves of that axis are rotated against each other, so it is
            expected to be even.
        max_seq_len: number of positions precomputed at construction time.
        base: frequency base; ``inv_freq[i] = 1 / base ** (2i / dim)``.
    """

    def __init__(self, dim, max_seq_len=512, base=10000):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len):
        """(Re)compute ``cos_cached`` / ``sin_cached`` for positions ``0 .. seq_len - 1``."""
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        # Each frequency is duplicated so it lines up with both halves of the last axis.
        emb = torch.cat((freqs, freqs), dim=-1)
        # Re-registering an existing buffer name replaces the stored tensor.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def _rotate_half(self, x):
        """Return ``cat(-second_half, first_half)`` along the last axis of *x*."""
        half = self.dim // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def forward(self, x, seq_len, offset=0):
        """Rotate *x* for absolute positions ``offset .. offset + seq_len - 1``.

        Args:
            x: tensor whose last two axes are ``(seq_len, dim)``. The cos/sin
                slices are expanded to ``(1, 1, seq_len, dim)`` for broadcasting,
                e.g. against ``(batch, heads, seq_len, dim)`` as in OptimizedAttention.
            seq_len: number of positions covered by *x*.
            offset: absolute position of the first element (for cached decoding).

        Returns:
            Tensor with the same shape and dtype as *x*.
        """
        end = offset + seq_len
        cached = self.cos_cached.size(0)
        if end > cached:
            # Slicing past the table would silently yield fewer rows (an empty
            # result when seq_len == 1), so grow it first. Doubling amortises rebuilds.
            self._build_cache(max(end, 2 * cached))
        # Cast to x's dtype so half-precision inputs are not promoted to float32.
        cos = self.cos_cached[offset:end].to(dtype=x.dtype)[None, None]
        sin = self.sin_cached[offset:end].to(dtype=x.dtype)[None, None]
        return x * cos + self._rotate_half(x) * sin
|
|
class SwiGLU(nn.Module):
    """Gated feed-forward block computing ``w3(silu(w1(x)) * w2(x))``.

    All three projections are bias-free linear layers.

    Args:
        d_model: input and output feature size.
        d_ff: hidden feature size of the gated product.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Gate branch (passed through SiLU).
        self.w1 = nn.Linear(in_features=d_model, out_features=d_ff, bias=False)
        # Linear branch multiplied by the gate.
        self.w2 = nn.Linear(in_features=d_model, out_features=d_ff, bias=False)
        # Projection back to the model width.
        self.w3 = nn.Linear(in_features=d_ff, out_features=d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w2(x)
        return self.w3(gate * up)
|
|
class OptimizedAttention(nn.Module):
    """Causal self-attention with grouped K/V heads, rotary positions and a KV cache.

    ``num_kv_heads`` key/value heads are shared by ``nhead`` query heads
    (each K/V head is repeated ``nhead // num_kv_heads`` times).

    Args:
        d_model: model width; must be divisible by ``nhead``.
        nhead: number of query heads; must be divisible by ``num_kv_heads``.
        rope: callable ``rope(t, seq_len, offset)`` applied to queries and keys
            shaped ``(batch, heads, seq_len, head_dim)``.
        num_kv_heads: number of key/value heads.

    Raises:
        ValueError: if ``d_model`` or ``nhead`` is not evenly divisible as required.
    """

    def __init__(self, d_model, nhead, rope, num_kv_heads=2):
        super().__init__()
        if d_model % nhead:
            raise ValueError(f"d_model ({d_model}) must be divisible by nhead ({nhead})")
        if nhead % num_kv_heads:
            raise ValueError(f"nhead ({nhead}) must be divisible by num_kv_heads ({num_kv_heads})")
        self.nhead, self.head_dim = nhead, d_model // nhead
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = nhead // num_kv_heads
        self.rope = rope

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, past_key_value=None):
        """Attend over the cached keys/values plus the new positions in *x*.

        Args:
            x: input of shape ``(batch, seq_len, d_model)``.
            past_key_value: optional ``(k, v)`` returned by a previous call, each
                shaped ``(batch, num_kv_heads, past_len, head_dim)``.

        Returns:
            ``(out, present_key_value)``: ``out`` has the shape of *x*. The cache
            tuple holds rotated keys and values for all ``past_len + seq_len``
            positions, before the K/V head repetition.
        """
        B, S, C = x.shape
        q = self.q_proj(x).view(B, S, self.nhead, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # New tokens continue the absolute positions after the cached ones.
        offset = 0 if past_key_value is None else past_key_value[0].size(2)
        q = self.rope(q, S, offset)
        k = self.rope(k, S, offset)

        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_key_value = (k, v)

        if self.num_queries_per_kv > 1:
            k = torch.repeat_interleave(k, repeats=self.num_queries_per_kv, dim=1)
            v = torch.repeat_interleave(v, repeats=self.num_queries_per_kv, dim=1)

        total = k.size(2)
        if total == S:
            attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # is_causal=True aligns the mask to the top-left corner for non-square
            # shapes, which would restrict cached queries to the earliest keys.
            # Query i sits at absolute position total - S + i, so it may attend
            # keys 0 .. total - S + i (a bottom-right-aligned causal mask).
            mask = torch.ones(S, total, dtype=torch.bool, device=q.device).tril(diagonal=total - S)
            attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

        out = self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, C))
        return out, present_key_value
|
|
class EfficientDecoderLayer(nn.Module):
    """Pre-norm decoder layer: ``x + attn(ln1(x))`` followed by ``h + ffn(ln2(h))``.

    Args:
        d_model: model width.
        nhead: number of query heads for the attention sub-layer.
        d_ff: hidden size of the SwiGLU feed-forward sub-layer.
        rope: rotary-embedding callable forwarded to the attention sub-layer.
        num_kv_heads: number of key/value heads for the attention sub-layer.
    """

    def __init__(self, d_model, nhead, d_ff, rope, num_kv_heads=2):
        super().__init__()
        self.attn = OptimizedAttention(d_model, nhead, rope, num_kv_heads)
        self.ffn = SwiGLU(d_model, d_ff)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x, past_key_value=None):
        """Return ``(output, present_key_value)``; the cache comes from the attention sub-layer."""
        attended, present = self.attn(self.ln1(x), past_key_value)
        residual = x + attended
        ffn_out = self.ffn(self.ln2(residual))
        return residual + ffn_out, present
|
|
class ChessTransformer(nn.Module):
    """Decoder-only transformer producing a tanh-squashed scalar value per position.

    Args:
        vocab_size: number of token ids accepted by the embedding.
        d_model: model width.
        nhead: number of query heads per layer.
        num_layers: number of decoder layers.
        max_length: number of positions the shared rotary embedding precomputes.
        num_kv_heads: number of key/value heads per layer.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, max_length=120, num_kv_heads=2):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Same value as int(2 * (d_model * 4) / 3): two thirds of a 4x expansion.
        d_ff = int(8 * d_model / 3)
        # One rotary embedding module is shared by every layer.
        self.rope = RotaryEmbedding(dim=d_model // nhead, max_seq_len=max_length)
        self.layers = nn.ModuleList(
            EfficientDecoderLayer(d_model, nhead, d_ff, self.rope, num_kv_heads)
            for _ in range(num_layers)
        )
        self.ln_final = nn.LayerNorm(d_model)
        half = d_model // 2
        self.value_head = nn.Sequential(
            nn.Linear(d_model, half),
            nn.SiLU(),
            nn.Linear(half, 1),
            nn.Tanh(),
        )

    def forward(self, src, past_key_values=None, use_cache=False):
        """Run the model over token ids *src*.

        Args:
            src: token ids; the attention layers unpack the embedded input as
                ``(batch, seq_len, d_model)``, so *src* is ``(batch, seq_len)``.
            past_key_values: optional per-layer caches from a previous call.
            use_cache: when true, also return the per-layer caches.

        Returns:
            Values in ``[-1, 1]`` shaped ``(batch, seq_len)``, or
            ``(values, caches)`` when *use_cache* is true.
        """
        hidden = self.token_embedding(src)
        caches = [] if use_cache else None
        for idx, block in enumerate(self.layers):
            layer_past = None if past_key_values is None else past_key_values[idx]
            hidden, layer_present = block(hidden, layer_past)
            if caches is not None:
                caches.append(layer_present)

        value = self.value_head(self.ln_final(hidden)).squeeze(-1)
        return (value, tuple(caches)) if use_cache else value
|
|