#######
# Simtoon "Simtoonism" Transformer model trainer
# By Simtoon of Ongakken s. r. o.
# The input dataset will gradually be expanded, which will make the resulting model more performant.
# Since I am still learning and this is my first from-scratch Transformer, I will be following a tutorial, but I will be making my own changes.
# There are two versions - bigram and GPT. I will compare them and see which one is better.
#######

import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparams
batchSize = 128
blockSize = 512
numEpochs = 10000
learningRate = 0.0001
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
evalEpochs = 256
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2

# load dataset
with open("dataset.txt", "r", encoding="utf-8") as f:
    dataset = f.read()

# some overview
chars = sorted(list(set(dataset)))
vocabSize = len(chars)
print("Vocab size:", vocabSize)

# create char2idx and idx2char
char2idx = {ch: i for i, ch in enumerate(chars)}
idx2char = {i: ch for i, ch in enumerate(chars)}
enc = lambda s: [char2idx[c] for c in s]  # string -> list of token ids
dec = lambda l: ''.join([idx2char[i] for i in l])  # list of token ids -> string

# split dataset into train and val, where train is 85% of the dataset
data = torch.tensor(enc(dataset), dtype=torch.long)
n = int(len(data) * 0.85)
train, val = data[:n], data[n:]

# create dataloader
def mkBatch(split):
    # gen a small batch of data of inputs x and targets y
    data = train if split == "train" else val
    ix = torch.randint(len(data) - blockSize, (batchSize,))
    x = torch.stack([data[i:i + blockSize] for i in ix])
    y = torch.stack([data[i + 1:i + blockSize + 1] for i in ix])
    x, y = x.to(dev), y.to(dev)
    return x, y

@torch.no_grad()
def estLoss():
    # average the loss over several batches for a less noisy estimate
    out = {}
    mdl.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(evalEpochs)
        for i in range(evalEpochs):
            x, y = mkBatch(split)
            logits, loss = mdl(x, y)
            losses[i] = loss.item()
        out[split] = losses.mean()
    mdl.train()
    return out

class Head(nn.Module):
    # one head of causal self-attention
    def __init__(self, headSize):
        super().__init__()
        self.key = nn.Linear(n_embd, headSize, bias=False)
        self.query = nn.Linear(n_embd, headSize, bias=False)
        self.value = nn.Linear(n_embd, headSize, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(blockSize, blockSize)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input is (batchSize, time, channels)
        # output is (batchSize, time, headSize)
        b, t, c = x.shape
        k = self.key(x)
        q = self.query(x)
        w = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # scaled dot-product attention scores
        w = w.masked_fill(self.tril[:t, :t] == 0, float("-inf"))  # causal mask: no looking at future tokens
        w = F.softmax(w, dim=-1)
        w = self.dropout(w)
        v = self.value(x)
        out = w @ v
        return out

class MHA(nn.Module):
    # several attention heads in parallel, concatenated and projected back to n_embd
    def __init__(self, numHeads, headSize):
        super().__init__()
        self.heads = nn.ModuleList([Head(headSize) for _ in range(numHeads)])
        self.proj = nn.Linear(numHeads * headSize, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedForward(nn.Module):
    # simple per-token MLP applied after attention
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    # Transformer block: communication (attention) then computation (feed-forward), with pre-norm residuals
    def __init__(self, n_embd, n_head):
        super().__init__()
        headSize = n_embd // n_head
        self.sa = MHA(n_head, headSize)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
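# Quick shape sanity check - this is not from the tutorial, just my own assumption of what the
# shapes should be: a single Head maps (batch, time, n_embd) -> (batch, time, headSize), and
# MHA projects the concatenated heads back to (batch, time, n_embd).
_probe = torch.randn(2, blockSize, n_embd, device=dev)
assert Head(n_embd // n_head).to(dev)(_probe).shape == (2, blockSize, n_embd // n_head)
assert MHA(n_head, n_embd // n_head).to(dev)(_probe).shape == (2, blockSize, n_embd)
del _probe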
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # each token reads off the logits for the next token from a lookup table
        self.tokenEmbeddingTable = nn.Embedding(vocabSize, n_embd)
        self.positionEmbeddingTable = nn.Embedding(blockSize, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocabSize)
        self.apply(self._initWeights)

    def _initWeights(self, mod):
        if isinstance(mod, nn.Linear):
            torch.nn.init.normal_(mod.weight, mean=0, std=0.02)
            if mod.bias is not None:
                torch.nn.init.zeros_(mod.bias)
        elif isinstance(mod, nn.Embedding):
            torch.nn.init.normal_(mod.weight, mean=0, std=0.02)

    def forward(self, idx, targets=None):
        b, t = idx.shape
        tokEmbed = self.tokenEmbeddingTable(idx)  # (b, t, n_embd)
        posEmbed = self.positionEmbeddingTable(torch.arange(t, device=dev))  # (t, n_embd)
        x = tokEmbed + posEmbed
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (b, t, vocabSize)
        if targets is None:
            loss = None
        else:
            b, t, c = logits.shape
            logits = logits.view(b * t, c)
            targets = targets.view(b * t)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def gen(self, idx, genLen):
        for _ in range(genLen):
            idxCond = idx[:, -blockSize:]  # crop the context to the last blockSize tokens
            logits, loss = self(idxCond)
            logits = logits[:, -1, :]  # only the last time step matters for the next token
            probs = F.softmax(logits, dim=-1)
            idxNext = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idxNext), dim=-1)
        return idx

mdl = Model().to(dev)
print(sum(p.numel() for p in mdl.parameters()) / 1e6, "M params")

# optimizer
optim = torch.optim.Adam(mdl.parameters(), lr=learningRate)

# training loop
for epoch in range(numEpochs):
    if epoch % evalEpochs == 0 or epoch == numEpochs - 1:
        losses = estLoss()
        print(f"epoch {epoch} train loss {losses['train']:.3f} val loss {losses['val']:.3f}")
    # pick data
    xb, yb = mkBatch("train")
    # eval loss
    logits, loss = mdl(xb, yb)
    optim.zero_grad(set_to_none=True)
    loss.backward()
    optim.step()

# generate
cont = torch.zeros((1, 1), dtype=torch.long, device=dev)
print(dec(mdl.gen(cont, 1500)[0].tolist()))
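# Optional: persist the trained weights so generation can be rerun without retraining.
# Not part of the tutorial; "simtoonism.pt" is just a placeholder filename I picked.
torch.save(mdl.state_dict(), "simtoonism.pt")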