File size: 4,972 Bytes
769f59f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bdc0cd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torch.nn as nn
import torch.nn.functional as F

class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to per-head tensors.

    cos/sin tables for positions ``[0, max_seq_len)`` are precomputed as
    non-persistent buffers. A call may need positions past the cached range,
    e.g. a KV-cached decode that runs longer than ``max_seq_len``. In that case
    the tables are regenerated at the required length. Otherwise a truncated
    slice would either silently broadcast (wrong result) or raise a shape error.

    Args:
        dim: size of the last axis of tensors passed to ``forward``. It must be
            even, because the rotation pairs the first and second halves of
            that axis.
        max_seq_len: number of positions to precompute at construction time.
        base: frequency base for the inverse frequencies (default 10000).
    """

    def __init__(self, dim, max_seq_len=512, base=10000):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RotaryEmbedding requires an even dim, got {dim}")
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        cos, sin = self._build_tables(max_seq_len)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def _build_tables(self, length):
        """Return float32 ``(cos, sin)`` tables of shape ``(length, dim)``."""
        inv_freq = self.inv_freq.float()
        t = torch.arange(length, dtype=torch.float, device=inv_freq.device)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

    def _rotate_half(self, x):
        """Map the last axis ``[a, b]`` (its two halves) to ``[-b, a]``."""
        half = self.dim // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def forward(self, x, seq_len, offset=0):
        """Rotate ``x`` for absolute positions ``offset .. offset + seq_len - 1``.

        Args:
            x: tensor whose last two axes are ``(seq_len, dim)``. cos/sin are
                broadcast over all leading axes (presumably batch and heads).
            seq_len: number of positions covered by ``x``.
            offset: absolute position of the first row of ``x``, e.g. the
                number of tokens already stored in a KV cache.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        end = offset + seq_len
        if end > self.cos_cached.size(0):
            # Grow the tables so the slices below always have exactly seq_len rows.
            self.cos_cached, self.sin_cached = self._build_tables(end)
        # Cast to x's dtype so low-precision inputs are not promoted to float32.
        cos = self.cos_cached[offset:end].to(x.dtype)[None, None]
        sin = self.sin_cached[offset:end].to(x.dtype)[None, None]
        return x * cos + self._rotate_half(x) * sin

class SwiGLU(nn.Module):
    """Gated feed-forward block computing ``w3(silu(w1(x)) * w2(x))``.

    Args:
        d_model: size of the input and output feature axis.
        d_ff: size of the hidden (gated) feature axis.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Creation order w1 -> w2 -> w3 fixes parameter order and init RNG draws.
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # passed through SiLU
        self.w2 = nn.Linear(d_model, d_ff, bias=False)  # multiplied by the SiLU branch
        self.w3 = nn.Linear(d_ff, d_model, bias=False)  # projects back to d_model

    def forward(self, x):
        gate = F.silu(self.w1(x))
        linear = self.w2(x)
        return self.w3(gate * linear)

class OptimizedAttention(nn.Module):
    """Grouped-query causal self-attention with rotary embeddings and a KV cache.

    ``num_kv_heads`` key/value heads are shared by ``nhead`` query heads. Each
    KV head serves ``nhead // num_kv_heads`` consecutive query heads, via
    ``repeat_interleave``.

    Args:
        d_model: model width; must be divisible by ``nhead``.
        nhead: number of query heads.
        rope: callable ``rope(t, seq_len, offset)`` applied to q and k.
        num_kv_heads: number of key/value heads; must divide ``nhead``.
    """

    def __init__(self, d_model, nhead, rope, num_kv_heads=2):
        super().__init__()
        if d_model % nhead != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by nhead ({nhead})")
        if num_kv_heads <= 0 or nhead % num_kv_heads != 0:
            raise ValueError(f"nhead ({nhead}) must be a multiple of num_kv_heads ({num_kv_heads})")
        self.nhead, self.head_dim = nhead, d_model // nhead
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = nhead // num_kv_heads
        self.rope = rope

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, past_key_value=None):
        """Attend over the cached keys/values plus the new tokens in ``x``.

        Args:
            x: ``(B, S, d_model)`` new tokens.
            past_key_value: optional ``(k, v)`` pair, as returned by a previous
                call, each of shape ``(B, num_kv_heads, P, head_dim)``.

        Returns:
            ``(out, present_key_value)``. ``out`` is ``(B, S, d_model)``.
            ``present_key_value`` holds the un-repeated keys/values for all
            ``P + S`` positions.
        """
        B, S, _ = x.shape
        q = self.q_proj(x).view(B, S, self.nhead, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # New tokens start right after the cached ones.
        offset = 0 if past_key_value is None else past_key_value[0].size(2)
        q = self.rope(q, S, offset)
        k = self.rope(k, S, offset)

        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        # Cache the compact (un-repeated) KV heads.
        present_key_value = (k, v)

        if self.num_queries_per_kv > 1:
            k = torch.repeat_interleave(k, repeats=self.num_queries_per_kv, dim=1)
            v = torch.repeat_interleave(v, repeats=self.num_queries_per_kv, dim=1)

        attn_out = self._causal_attention(q, k, v)
        out = self.out_proj(attn_out.transpose(1, 2).reshape(B, S, self.nhead * self.head_dim))
        return out, present_key_value

    @staticmethod
    def _causal_attention(q, k, v):
        """Causal SDPA where the queries are the LAST ``q_len`` of ``k_len`` positions.

        ``is_causal=True`` aligns the mask top-left when ``q_len != k_len``.
        With a KV cache, that would let new tokens see only the oldest cache
        entries. So the mask must be aligned bottom-right instead.
        """
        q_len, k_len = q.size(-2), k.size(-2)
        if q_len == k_len:
            return F.scaled_dot_product_attention(q, k, v, is_causal=True)
        if q_len == 1:
            # A single newest token may attend to every position.
            return F.scaled_dot_product_attention(q, k, v)
        # Query i (absolute position k_len - q_len + i) sees keys 0..k_len - q_len + i.
        mask = torch.ones(q_len, k_len, dtype=torch.bool, device=q.device).tril(diagonal=k_len - q_len)
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

class EfficientDecoderLayer(nn.Module):
    """Pre-norm decoder block: attention sub-layer, then SwiGLU feed-forward.

    Each sub-layer reads a LayerNorm-ed copy of the residual stream and adds its
    output back onto it. ``forward`` returns ``(hidden, present_key_value)``,
    where the second item is passed through unchanged from the attention module.
    """

    def __init__(self, d_model, nhead, d_ff, rope, num_kv_heads=2):
        super().__init__()
        self.attn = OptimizedAttention(d_model, nhead, rope, num_kv_heads)
        self.ffn = SwiGLU(d_model, d_ff)
        self.ln1 = nn.LayerNorm(d_model)  # normalizes the attention input
        self.ln2 = nn.LayerNorm(d_model)  # normalizes the feed-forward input

    def forward(self, x, past_key_value=None):
        attn_delta, cache = self.attn(self.ln1(x), past_key_value)
        residual = x + attn_delta
        ffn_delta = self.ffn(self.ln2(residual))
        return residual + ffn_delta, cache

class ChessTransformer(nn.Module):
    """Decoder-only transformer that maps token ids to a scalar value per position.

    The final hidden state at each position goes through ``value_head``, whose
    trailing ``Tanh`` keeps every output in (-1, 1).

    Args:
        vocab_size: number of token ids accepted by the embedding.
        d_model: model width.
        nhead: number of query heads per layer.
        num_layers: number of decoder layers.
        max_length: number of positions precomputed by the shared rotary embedding.
        num_kv_heads: number of key/value heads per layer.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, max_length=120, num_kv_heads=2):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Feed-forward width is 2/3 of 4 * d_model, truncated to an int.
        ffn_width = int(8 * d_model / 3)
        # A single rotary module instance is handed to every layer.
        self.rope = RotaryEmbedding(dim=d_model // nhead, max_seq_len=max_length)
        blocks = [
            EfficientDecoderLayer(d_model, nhead, ffn_width, self.rope, num_kv_heads)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.ln_final = nn.LayerNorm(d_model)
        head_width = d_model // 2
        self.value_head = nn.Sequential(
            nn.Linear(d_model, head_width),
            nn.SiLU(),
            nn.Linear(head_width, 1),
            nn.Tanh(),
        )

    def forward(self, src, past_key_values=None, use_cache=False):
        """Return per-position values, plus the new per-layer KV cache if ``use_cache``.

        ``past_key_values``, when given, is indexed by layer position.
        The value tensor is ``value_head`` output with its trailing size-1 axis
        squeezed. Presumably that is ``(batch, seq)`` for a ``(batch, seq)``
        ``src``; confirm against callers.
        """
        h = self.token_embedding(src)
        new_cache = []
        for idx, layer in enumerate(self.layers):
            layer_past = None if past_key_values is None else past_key_values[idx]
            h, layer_present = layer(h, layer_past)
            if use_cache:
                new_cache.append(layer_present)

        values = self.value_head(self.ln_final(h)).squeeze(-1)
        if not use_cache:
            return values
        return values, tuple(new_cache)