| import time |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
# Training corpus: Hamlet's "To be, or not to be" soliloquy (character-level).
data = """To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by opposing end them. To die—to sleep,
No more; and by a sleep to say we end
The heart-ache and the thousand natural shocks
That flesh is heir to: 'tis a consummation
Devoutly to be wish'd. To die, to sleep;
To sleep, perchance to dream—ay, there's the rub:
For in that sleep of death what dreams may come,
When we have shuffled off this mortal coil,
Must give us pause—there's the respect
That makes calamity of so long life.
For who would bear the whips and scorns of time,
Th'oppressor's wrong, the proud man's contumely,
The pangs of dispriz'd love, the law's delay,
The insolence of office, and the spurns
That patient merit of th'unworthy takes,
When he himself might his quietus make
With a bare bodkin? Who would fardels bear,
To grunt and sweat under a weary life,
But that the dread of something after death,
The undiscovere'd country, from whose bourn
No traveller returns, puzzles the will,
And makes us rather bear those ills we have
Than fly to others that we know not of?
Thus conscience doth make cowards of us all,
And thus the native hue of resolution
Is sicklied o'er with the pale cast of thought,
And enterprises of great pith and moment
With this regard their currents turn awry
And lose the name of action."""


# Character-level vocabulary: one token id per distinct character,
# assigned in sorted order so the mapping is deterministic.
chars = sorted(set(data))  # sorted() accepts any iterable; wrapping in list() was redundant
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> token id
itos = {i: ch for i, ch in enumerate(chars)}  # token id -> char
encoded = torch.tensor([stoi[c] for c in data], dtype=torch.long)
|
|
| |
# --- Hyperparameters ---
D_MODEL = 256      # width of the embedding / residual stream
N_LAYERS = 4       # number of residual Blocks (even-indexed ones get the global conv)
MAX_SEQ_LEN = 64   # training context length; also sizes the positional-embedding table
LOCAL_K = 5        # depthwise kernel size of the local causal conv
GLOBAL_K = 128     # learned kernel length of the FFT-based global conv
FFT_SIZE = 256     # overlap-save FFT block length (needs to exceed GLOBAL_K - 1 so each block yields output)
TRAIN_TIME = 60    # wall-clock training budget, in seconds
BATCH_SIZE = 8     # sequences per optimization step
|
|
| |
|
|
class GlobalConv1D(nn.Module):
    """Causal depthwise long convolution computed with FFTs (overlap-save).

    Each channel c is convolved with its own learned kernel of length
    ``kernel_size``:  y[c, t] = sum_{j=0}^{K-1} kernel[c, j] * x[c, t - j].
    Left zero-padding makes the operation causal, and the input is processed
    in fixed-size FFT blocks so cost stays bounded for long sequences.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        """
        Args:
            d_model: number of channels (one independent kernel per channel).
            kernel_size: length of each learned causal kernel.
            fft_size: FFT block length. Must be >= kernel_size; otherwise the
                kernel pad in forward() would be negative and the per-block
                stride (fft_size - K + 1) would be <= 0, so the block loop
                would never advance (infinite loop).

        Raises:
            ValueError: if fft_size < kernel_size.
        """
        super().__init__()
        if fft_size < kernel_size:
            raise ValueError(
                f"fft_size ({fft_size}) must be >= kernel_size ({kernel_size})"
            )
        # Small init keeps the long conv near-identity-silent at the start.
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        """Apply the causal convolution.

        Args:
            x: tensor of shape (B, C, T), C == d_model.
        Returns:
            Tensor of shape (B, C, T).
        """
        B, C, T = x.shape
        # Taps beyond the sequence length can never touch real input.
        K = min(self.kernel_size, T)
        overlap = K - 1                   # samples corrupted by circular wrap-around
        block = self.fft_size - overlap   # valid output samples per FFT block

        # Causal left padding: output position t only sees inputs <= t.
        x = F.pad(x, (overlap, 0))
        k = self.kernel[:, :K]
        k = F.pad(k, (0, self.fft_size - K))
        k_f = torch.fft.rfft(k, n=self.fft_size)

        outs = []
        pos = 0
        while pos < T:
            # Overlap-save: consecutive segments share `overlap` samples.
            seg = x[..., pos:pos + self.fft_size]
            if seg.shape[-1] < self.fft_size:
                seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
            # Pointwise spectral multiply == circular convolution in time.
            y = torch.fft.irfft(torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0), n=self.fft_size)
            # Discard the first `overlap` wrapped samples; keep the valid run.
            outs.append(y[..., overlap:overlap + block])
            pos += block
        return torch.cat(outs, dim=-1)[..., :T]
|
|
class LocalConv1D(nn.Module):
    """Short-range causal mixer: depthwise conv -> ReLU -> pointwise conv."""

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        # Depthwise stage mixes along time within each channel;
        # the 1x1 pointwise stage then mixes information across channels.
        self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
        self.pw = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        # Zero-pad k-1 steps on the left only, so position t never sees t+1..T.
        causal_x = F.pad(x, (self.k - 1, 0))
        activated = F.relu(self.dw(causal_x))
        return self.pw(activated)
|
|
class Block(nn.Module):
    """One residual stage: causal local conv, optional global conv, then MLP.

    All three sub-layers are pre-norm residual: x = x + f(LayerNorm(x)).
    """

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_K)
        if use_global:
            # Only some blocks carry the (more expensive) long convolution.
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_K, FFT_SIZE)
        self.ln3 = nn.LayerNorm(d_model)
        hidden = d_model * 4  # standard 4x MLP expansion
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x):
        # Conv modules expect (B, C, T); transpose around each mixing op.
        local_out = self.local(self.ln1(x).transpose(1, 2))
        x = x + local_out.transpose(1, 2)
        if self.use_global:
            global_out = self.global_conv(self.ln2(x).transpose(1, 2))
            x = x + global_out.transpose(1, 2)
        x = x + self.ff(self.ln3(x))
        return x
|
|
class GCLM(nn.Module):
    """Global-convolution language model.

    Token + learned positional embeddings feed a stack of residual Blocks
    (every even-indexed block also gets the FFT global conv), followed by a
    final LayerNorm and a vocabulary projection tied to the embedding matrix.
    """

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
        self.layers = nn.ModuleList(
            Block(D_MODEL, layer_idx % 2 == 0) for layer_idx in range(N_LAYERS)
        )
        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)
        # Weight tying: the output projection shares the embedding matrix
        # (both are shaped (vocab, D_MODEL)).
        self.head.weight = self.emb.weight

    def forward(self, x):
        # x: (B, T) token ids, T <= MAX_SEQ_LEN. Returns (B, T, vocab) logits.
        positions = torch.arange(x.size(1), device=x.device)
        h = self.emb(x) + self.pos(positions)
        for block in self.layers:
            h = block(h)
        return self.head(self.ln(h))
|
|
| |
|
|
# --- Training: wall-clock-budgeted next-character prediction ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GCLM(vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

print(f"Training on {device} for {TRAIN_TIME} seconds...")
start_time = time.time()
step = 0

model.train()
while (time.time() - start_time) < TRAIN_TIME:
    # Sample a batch of random contiguous windows; targets are the same
    # windows shifted one character to the right.
    starts = torch.randint(0, len(encoded) - MAX_SEQ_LEN, (BATCH_SIZE,))
    x = torch.stack([encoded[s : s + MAX_SEQ_LEN] for s in starts]).to(device)
    y = torch.stack([encoded[s + 1 : s + 1 + MAX_SEQ_LEN] for s in starts]).to(device)

    # Flatten (B, T, V) logits and (B, T) targets for cross-entropy.
    loss = F.cross_entropy(model(x).view(-1, vocab_size), y.view(-1))

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if step % 10 == 0:
        elapsed = time.time() - start_time
        pct = min(100, (elapsed / TRAIN_TIME) * 100)
        print(f"\rStep {step} | Loss: {loss.item():.4f} | Progress: {pct:.1f}%", end="")
    step += 1
|
|
| |
|
|
# --- Sampling: temperature 0.8 with top-k (k=10) filtering ---
print("\n\nTraining Complete. Generating:\n" + "-"*30)
model.eval()
prompt = "To be, "
ctx = torch.tensor([[stoi[c] for c in prompt]], dtype=torch.long, device=device)
print(prompt, end="", flush=True)

with torch.no_grad():
    for _ in range(MAX_SEQ_LEN * 2):
        # The positional table only covers MAX_SEQ_LEN tokens; crop the context.
        window = ctx[:, -MAX_SEQ_LEN:]
        last_logits = model(window)[:, -1, :] / 0.8  # temperature < 1 sharpens

        # Top-k filter: suppress everything below the 10th-best logit.
        kth, _ = torch.topk(last_logits, min(10, vocab_size))
        last_logits[last_logits < kth[:, [-1]]] = -float('Inf')

        probs = F.softmax(last_logits, dim=-1)
        nxt = torch.multinomial(probs, num_samples=1)

        ctx = torch.cat((ctx, nxt), dim=1)
        print(itos[nxt.item()], end="", flush=True)
print("\n" + "-"*30)