# ChatGCLM: a compact causal convolutional language model. Short-range token
# mixing comes from a depthwise local convolution; long-range mixing from an
# FFT-based global convolution applied in every other block.
import torch
import torch.nn as nn
import torch.nn.functional as F

D_MODEL = 1024                 # residual / embedding width
N_LAYERS = 22                  # number of blocks
MAX_SEQ_LEN = 1024             # positional-embedding capacity; longer inputs are truncated
LOCAL_KERNEL_SIZE = 3          # depthwise local conv kernel size
GLOBAL_KERNEL_SIZE = 512       # long causal kernel length
USE_GLOBAL_EVERY_N_LAYERS = 2  # global conv on blocks 0, 2, 4, ...
FFT_SIZE = 1024                # window length for the overlap-save FFT blocks


class GlobalConv1D(nn.Module):
    """Depthwise causal convolution with a long per-channel kernel, computed
    blockwise with FFTs (overlap-save), so cost scales as O(T log FFT_SIZE)
    rather than O(T * K)."""

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # One long filter per channel, small random init.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size
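
    # Overlap-save in brief: slide an FFT_SIZE window over the left-padded
    # input in steps of block = FFT_SIZE - (K - 1). Within each window,
    # circular convolution equals linear convolution everywhere except the
    # first K - 1 samples, which wrap around and are discarded.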
    def forward(self, x):
        # x: (B, C, T), channels-first (as produced by Block's transpose).
        B, C, T = x.shape
        K = min(self.kernel_size, T)
        overlap = K - 1
        block = self.fft_size - overlap   # valid output samples per window
        x = F.pad(x, (overlap, 0))        # left-pad => causal: y[t] sees x[<=t]
        k = self.kernel[:, :K]
        k = F.pad(k, (0, self.fft_size - K))
        k_f = torch.fft.rfft(k, n=self.fft_size)
        outs = []
        pos = 0
        while pos < T:
            seg = x[..., pos:pos + self.fft_size]
            if seg.shape[-1] < self.fft_size:
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
            # Pointwise product in frequency space = circular convolution.
            y = torch.fft.irfft(
                torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
                n=self.fft_size,
            )
            outs.append(y[..., overlap:overlap + block])  # drop wrapped samples
            pos += block
        return torch.cat(outs, dim=-1)[..., :T]
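

# Sanity check for the FFT path -- an illustrative sketch, not part of the
# original module; `_check_global_conv` is a hypothetical helper. The blockwise
# overlap-save output should match a direct causal depthwise convolution.
def _check_global_conv():
    torch.manual_seed(0)
    d, t = 4, 700
    conv = GlobalConv1D(d, GLOBAL_KERNEL_SIZE, FFT_SIZE)
    x = torch.randn(2, d, t)
    with torch.no_grad():
        k = conv.kernel[:, :min(GLOBAL_KERNEL_SIZE, t)]
        # F.conv1d cross-correlates, so flip the kernel to get a convolution.
        ref = F.conv1d(F.pad(x, (k.shape[-1] - 1, 0)),
                       k.flip(-1).unsqueeze(1), groups=d)
        assert torch.allclose(conv(x), ref, atol=1e-4)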


class LocalConv1D(nn.Module):
    """Causal depthwise-separable convolution for short-range token mixing."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)  # depthwise
        self.pw = nn.Conv1d(d_model, d_model, 1)                  # pointwise mix

    def forward(self, x):
        # Left-pad by k - 1 so position t never sees future inputs.
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))


class Block(nn.Module):
    """Pre-norm residual block: local conv, optional global conv, then MLP."""

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        # x: (B, T, C); the convs expect (B, C, T), hence the transposes.
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))
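

# ChatGCLM assembles the stack: token + learned positional embeddings,
# N_LAYERS blocks (with the global convolution on blocks 0, 2, 4, ...),
# a final LayerNorm, and an output head tied to the embedding matrix.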
class ChatGCLM(nn.Module):
    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        self.layers = nn.ModuleList([
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
            for i in range(N_LAYERS)
        ])
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Weight tying: both matrices are (vocab, D_MODEL).
        self.head.weight = self.emb.weight

    def forward(self, x):
        # x: (B, T) token ids; keep only the most recent MAX_SEQ_LEN tokens.
        T = x.size(1)
        if T > MAX_SEQ_LEN:
            x = x[:, -MAX_SEQ_LEN:]
            T = MAX_SEQ_LEN
        h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
        for layer in self.layers:
            h = layer(h)
        return self.head(self.ln(h))
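

# Minimal smoke test -- an illustrative sketch, not part of the original
# module; the vocabulary size and input shapes below are assumptions.
if __name__ == "__main__":
    _check_global_conv()
    VOCAB = 32000  # hypothetical vocabulary size
    model = ChatGCLM(VOCAB)
    tokens = torch.randint(0, VOCAB, (2, 128))  # (batch, seq_len)
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 128, 32000])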