import torch
import torch.nn as nn
import torch.nn.functional as F

# Model hyperparameters.
D_MODEL = 256
N_LAYERS = 4
MAX_SEQ_LEN = 1024
LOCAL_KERNEL_SIZE = 5          # short depthwise conv kernel
GLOBAL_KERNEL_SIZE = 256       # long learned kernel applied via FFT
USE_GLOBAL_EVERY_N_LAYERS = 2  # every N-th block also gets a global conv
FFT_SIZE = 1024                # block length for FFT-based convolution
class GlobalConv1D(nn.Module):
    """Causal long convolution computed block-wise with FFTs (overlap-save)."""

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # One independent length-`kernel_size` filter per channel.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        # x: (batch, channels, time)
        B, C, T = x.shape
        K = min(self.kernel_size, T)
        overlap = K - 1
        block = self.fft_size - overlap  # valid samples produced per FFT block
        # Left-pad for causality: the output at position t sees only inputs <= t.
        x = F.pad(x, (overlap, 0))
        # Transform the kernel once; rfft zero-pads it to the FFT length.
        k = self.kernel[:, :K]
        k_f = torch.fft.rfft(k, n=self.fft_size)
        outs = []
        pos = 0
        while pos < T:
            seg = x[..., pos:pos + self.fft_size]
            if seg.shape[-1] < self.fft_size:
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
            # Pointwise product in frequency space == circular convolution in time.
            y = torch.fft.irfft(torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
                                n=self.fft_size)
            # Discard the first `overlap` samples, which wrap around (overlap-save).
            outs.append(y[..., overlap:overlap + block])
            pos += block
        return torch.cat(outs, dim=-1)[..., :T]
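
# Sanity-check sketch (not part of the original file): for a sequence that fits in a
# single FFT block, GlobalConv1D should match a direct causal depthwise convolution.
# The sizes and tolerance below are arbitrary assumptions for illustration.
def _check_global_conv_matches_direct(d_model=8, kernel_size=16, fft_size=64, T=48):
    torch.manual_seed(0)
    gc = GlobalConv1D(d_model, kernel_size, fft_size)
    x = torch.randn(2, d_model, T)
    # FFT multiplication implements true convolution, while F.conv1d is
    # cross-correlation, so flip the kernel and left-pad the input for causality.
    w = gc.kernel.flip(-1).unsqueeze(1)  # (d_model, 1, kernel_size)
    ref = F.conv1d(F.pad(x, (kernel_size - 1, 0)), w, groups=d_model)
    assert torch.allclose(gc(x), ref, atol=1e-5)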
class LocalConv1D(nn.Module):
    """Short-range causal mixing: depthwise conv followed by a pointwise conv."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)  # depthwise
        self.pw = nn.Conv1d(d_model, d_model, 1)                  # pointwise channel mix

    def forward(self, x):
        # Left-pad so the convolution is causal.
        x = F.pad(x, (self.k - 1, 0))
        return self.pw(F.relu(self.dw(x)))
class Block(nn.Module):
    """Pre-norm residual block: local conv, optional global conv, then an MLP."""

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        # x: (batch, time, d_model); the conv modules expect (batch, d_model, time).
        x = x + self.local(self.ln1(x).transpose(1, 2)).transpose(1, 2)
        if self.use_global:
            x = x + self.global_conv(self.ln2(x).transpose(1, 2)).transpose(1, 2)
        return x + self.ff(self.ln3(x))
class Crimson(nn.Module):
    """Decoder-only convolutional language model with tied input/output embeddings."""

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        self.layers = nn.ModuleList([
            Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
            for i in range(N_LAYERS)
        ])
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        self.head.weight = self.emb.weight  # weight tying

    def forward(self, x):
        T = x.size(1)
        if T > MAX_SEQ_LEN:
            # Keep only the most recent MAX_SEQ_LEN tokens.
            x = x[:, -MAX_SEQ_LEN:]
            T = MAX_SEQ_LEN
        h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
        for layer in self.layers:
            h = layer(h)
        return self.head(self.ln(h))
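
# Minimal usage sketch (not part of the original file): instantiate the model with a
# hypothetical vocabulary size and run random token ids through it to check shapes.
if __name__ == "__main__":
    vocab_size = 1000  # hypothetical vocabulary size
    model = Crimson(vocab_size)
    tokens = torch.randint(0, vocab_size, (2, 128))  # (batch, time)
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 128, 1000])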