# ChatGCLM: a small decoder-only language model built from causal convolutions
# instead of attention. Every block applies a short depthwise local conv;
# every other block adds a long global conv evaluated with FFTs (overlap-save).
import torch
import torch.nn as nn
import torch.nn.functional as F

D_MODEL = 512                  # embedding / hidden width
N_LAYERS = 8                   # number of residual blocks
MAX_SEQ_LEN = 4096             # longest context the position table supports
LOCAL_KERNEL_SIZE = 3          # short causal depthwise conv in every block
GLOBAL_KERNEL_SIZE = 512       # long learned filter applied via FFT
USE_GLOBAL_EVERY_N_LAYERS = 2  # global conv in layers 0, 2, 4, ...
FFT_SIZE = MAX_SEQ_LEN         # segment size for overlap-save convolution
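# Worked numbers for the defaults above: each 4096-point FFT segment yields
# 4096 - (512 - 1) = 3585 valid output samples, so a full 4096-token context
# needs ceil(4096 / 3585) = 2 FFT segments per channel per global conv.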

class GlobalConv1D(nn.Module):
    """Causal depthwise convolution with a long kernel, computed in the
    frequency domain via the overlap-save method. Cost per channel is
    O(T log FFT_SIZE) instead of the O(T * K) of a direct convolution."""

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # one length-`kernel_size` filter per channel, small random init
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        # x: (batch, channels, time)
        B, C, T = x.shape
        K = min(self.kernel_size, T)
        overlap = K - 1                  # samples corrupted by circular wrap-around
        block = self.fft_size - overlap  # valid output samples per FFT segment
        x = F.pad(x, (overlap, 0))       # left-pad only, keeping the conv causal
        k = self.kernel[:, :K]
        k_f = torch.fft.rfft(k, n=self.fft_size)  # n= zero-pads the kernel
        outs = []
        pos = 0
        while pos < T:
            # take an FFT_SIZE window of the padded input; zero-pad the tail
            seg = x[..., pos:pos + self.fft_size]
            if seg.shape[-1] < self.fft_size:
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
            # pointwise multiply in frequency == circular convolution in time
            y = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
                n=self.fft_size,
            )
            # drop the first `overlap` samples (wrapped), keep the valid block
            outs.append(y[..., overlap:overlap + block])
            pos += block
        return torch.cat(outs, dim=-1)[..., :T]
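
# Sanity-check sketch (added for illustration; the helper name is ours, not
# part of the original file): overlap-save should match a plain causal
# convolution. F.conv1d computes cross-correlation, so the kernel is flipped
# to express true convolution.
def _check_global_conv(B=2, C=8, T=1024):
    conv = GlobalConv1D(C, GLOBAL_KERNEL_SIZE, FFT_SIZE)
    x = torch.randn(B, C, T)
    K = min(conv.kernel_size, T)
    ref = F.conv1d(F.pad(x, (K - 1, 0)),
                   conv.kernel[:, :K].flip(-1).unsqueeze(1), groups=C)
    # loose tolerance: float32 FFTs accumulate a little roundoff
    assert torch.allclose(conv(x), ref, atol=1e-3, rtol=1e-3)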

class LocalConv1D(nn.Module):
    """Short-range causal mixing: a depthwise conv over time followed by a
    pointwise (1x1) conv that mixes channels."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)  # depthwise
        self.pw = nn.Conv1d(d_model, d_model, 1)                  # pointwise

    def forward(self, x):
        # left-pad only, so position t never sees t+1
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))

class Block(nn.Module):
    """Pre-norm residual block: local conv, optional global conv, then MLP."""

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        # x: (batch, time, d_model); the convs expect (batch, d_model, time)
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))

class ChatGCLM(nn.Module):
    """Token + learned position embeddings, a stack of conv blocks, and a
    tied-weight output head."""

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        self.layers = nn.ModuleList([
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)  # global in even layers
            for i in range(N_LAYERS)
        ])
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        self.head.weight = self.emb.weight  # tie input and output embeddings

    def forward(self, x):
        T = x.size(1)
        if T > MAX_SEQ_LEN:
            # the learned position table caps the context; keep the newest tokens
            x = x[:, -MAX_SEQ_LEN:]
            T = MAX_SEQ_LEN
        h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
        for layer in self.layers:
            h = layer(h)
        return self.head(self.ln(h))
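
# Minimal usage sketch (an assumed entry point, not from the original file):
# run a forward pass, confirm the output shape, and spot-check causality.
# VOCAB is an arbitrary illustrative size; the model fixes no vocabulary.
if __name__ == "__main__":
    VOCAB = 32000
    model = ChatGCLM(VOCAB).eval()
    tokens = torch.randint(0, VOCAB, (2, 128))
    with torch.no_grad():
        logits = model(tokens)
        assert logits.shape == (2, 128, VOCAB)
        # causality: editing the final token should leave all earlier logits
        # unchanged, up to FFT roundoff absorbed by the tolerance
        bumped = tokens.clone()
        bumped[:, -1] = (bumped[:, -1] + 1) % VOCAB
        logits2 = model(bumped)
        assert torch.allclose(logits[:, :-1], logits2[:, :-1], atol=1e-3)
    _check_global_conv()
    print("ok: shapes, causality, and overlap-save equivalence check out")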