simtoonism-classifier / trainer.py
Simon Slamka
finished, just need to compile the dataset
cc90f92
#######
# Simtoon "Simtoonism" Transformer model trainer
# By Simtoon of Ongakken s. r. o.
# The input dataset will gradually be expanded, which should make the resulting model more performant
# Since I am still learning and this is my first from-scratch Transformer, I will be following a tutorial, but I will be making my own changes
# There are two versions - bigram and GPT. I will compare them and see which one is better
#######
import torch
import torch.nn as nn
from torch.nn import functional as F
# ---- hyperparameters ----
# Batching / context
batchSize = 128        # sequences per training batch
blockSize = 512        # context length: tokens per training sequence
# Optimisation
numEpochs = 10_000     # training iterations (one batch per iteration)
learningRate = 1e-4
# Used both as the reporting interval (in iterations) and as the number of
# batches averaged per split when estimating the loss.
evalEpochs = 256
# Model size
n_embd = 384           # embedding width
n_head = 6             # attention heads per Transformer block
n_layer = 6            # number of Transformer blocks
dropout = 0.2          # dropout probability
# Compute device: CUDA when available, otherwise CPU.
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Read the whole training corpus (UTF-8 text) into a single string.
with open("dataset.txt", encoding="utf-8") as corpusFile:
    dataset = corpusFile.read()
# Build a character-level vocabulary: every distinct character in the corpus
# becomes one token, with ids assigned in sorted character order.
chars = sorted(set(dataset))
vocabSize = len(chars)
print("Vocab size:", vocabSize)
# Lookup tables between characters and integer token ids.
char2idx = {ch: i for i, ch in enumerate(chars)}
idx2char = {i: ch for i, ch in enumerate(chars)}


def enc(s):
    """Encode a string into a list of integer token ids, one per character.

    The whole corpus is passed in at once (see the train/val split), so this
    must iterate over the characters rather than look up the string itself.
    Raises KeyError for characters not present in the vocabulary.
    """
    return [char2idx[c] for c in s]


def dec(l):
    """Decode an iterable of integer token ids back into a string."""
    return "".join(idx2char[i] for i in l)
# Encode the whole corpus once, then keep the first 85% of tokens for
# training and hold out the remaining 15% for validation.
data = torch.tensor(enc(dataset), dtype=torch.long)
splitAt = int(0.85 * len(data))
train = data[:splitAt]
val = data[splitAt:]
# create dataloader
def mkBatch(split):
    """Sample a random batch of (input, target) token windows from one split.

    Args:
        split: "train" selects the training split; any other value selects
            the validation split.

    Returns:
        (x, y) tensors of shape (batchSize, blockSize) moved to `dev`, where
        y is x shifted one token to the right (the next-token targets).

    Raises:
        ValueError: if the chosen split has too few tokens to hold a single
            window of blockSize + 1 tokens.
    """
    isTrain = split == "train"
    data = train if isTrain else val
    # Each sample needs blockSize inputs plus one extra token for the last
    # target, so valid start offsets are 0 .. len(data) - blockSize - 1
    # (torch.randint's upper bound is exclusive).
    maxStart = len(data) - blockSize
    if maxStart <= 0:
        raise ValueError(
            f"{'train' if isTrain else 'val'} split has {len(data)} tokens; "
            f"need more than blockSize={blockSize} to sample a batch"
        )
    ix = torch.randint(maxStart, (batchSize,))
    x = torch.stack([data[i:i + blockSize] for i in ix])
    y = torch.stack([data[i + 1:i + blockSize + 1] for i in ix])
    return x.to(dev), y.to(dev)
@torch.no_grad()
def estLoss(model=None):
    """Estimate the mean loss on the train and val splits.

    For each split, averages the loss over evalEpochs random batches with
    the model in eval mode (dropout disabled) and gradients disabled, then
    puts the model back into train mode.

    Args:
        model: the model to evaluate. Defaults to the module-level `mdl`
            (the trained instance), looked up at call time.

    Returns:
        dict mapping "train" and "val" to a 0-d tensor holding the mean loss.
    """
    if model is None:
        model = mdl
    out = {}
    model.eval()
    for split in ("train", "val"):
        losses = torch.zeros(evalEpochs)
        for i in range(evalEpochs):
            x, y = mkBatch(split)
            _, loss = model(x, y)
            losses[i] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
class Head(nn.Module):
    """A single causal (masked) self-attention head.

    Projects the input into keys, queries and values of width headSize, lets
    each position attend only to itself and earlier positions, and returns
    the attention-weighted sum of the values.
    """

    def __init__(self, headSize):
        super().__init__()
        self.key = nn.Linear(n_embd, headSize, bias=False)
        self.query = nn.Linear(n_embd, headSize, bias=False)
        self.value = nn.Linear(n_embd, headSize, bias=False)
        # Lower-triangular matrix of ones: row r has ones in columns 0..r,
        # i.e. the positions row r is allowed to look at. Registered as a
        # buffer so it follows the module's device but is not trained.
        self.register_buffer("tril", torch.tril(torch.ones(blockSize, blockSize)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, time, channels) -> output: (batch, time, headSize)
        _, seqLen, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)
        # Scaled dot-product scores, scaled by 1/sqrt(headSize).
        scale = keys.shape[-1] ** -0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale
        # Block attention to future positions before the softmax.
        causalMask = self.tril[:seqLen, :seqLen]
        scores = scores.masked_fill(causalMask == 0, float("-inf"))
        attn = self.dropout(F.softmax(scores, dim=-1))
        return attn @ values
class MHA(nn.Module):
    """Multi-head causal self-attention.

    Runs numHeads independent Head modules on the same input, concatenates
    their outputs along the last axis, projects back to n_embd and applies
    dropout.
    """

    def __init__(self, numHeads, headSize):
        super().__init__()
        self.heads = nn.ModuleList(Head(headSize) for _ in range(numHeads))
        self.proj = nn.Linear(headSize * numHeads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        headOutputs = [head(x) for head in self.heads]
        merged = torch.cat(headOutputs, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embd, ReLU, project back to
    n_embd, then dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Block(nn.Module):
    """One pre-norm Transformer layer.

    LayerNorm -> multi-head self-attention and LayerNorm -> feed-forward,
    each added back onto its input as a residual connection.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Split the embedding width evenly across the heads.
        headSize = n_embd // n_head
        self.sa = MHA(n_head, headSize)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))
class Model(nn.Module):
    """Character-level GPT-style language model.

    Token embeddings plus learned positional embeddings feed a stack of
    n_layer Transformer Blocks, a final LayerNorm, and a linear head that
    produces one logit per vocabulary entry at every position.
    """

    def __init__(self):
        super().__init__()
        # Embedding lookup tables: token id -> vector, position -> vector.
        self.tokenEmbeddingTable = nn.Embedding(vocabSize, n_embd)
        self.positionEmbeddingTable = nn.Embedding(blockSize, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocabSize)
        # Run the custom initialiser over every submodule. It must reference
        # the method actually defined on this class (`_initWeights`).
        self.apply(self._initWeights)

    def _initWeights(self, mod):
        """Initialise one submodule in place (called via ``self.apply``).

        Linear and Embedding weights are drawn from a normal distribution
        with mean 0 and std 0.02; Linear biases, when present, are zeroed.
        Other module types are left untouched.
        """
        if isinstance(mod, nn.Linear):
            torch.nn.init.normal_(mod.weight, mean=0, std=0.02)
            if mod.bias is not None:
                torch.nn.init.zeros_(mod.bias)
        elif isinstance(mod, nn.Embedding):
            torch.nn.init.normal_(mod.weight, mean=0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            idx: (batch, time) tensor of token ids; time must not exceed
                blockSize, the size of the positional embedding table.
            targets: optional (batch, time) tensor of next-token ids.

        Returns:
            (logits, loss). Without targets, logits is
            (batch, time, vocabSize) and loss is None. With targets, logits
            is flattened to (batch * time, vocabSize) and loss is the mean
            cross-entropy against the flattened targets.
        """
        _, t = idx.shape
        tokEmbed = self.tokenEmbeddingTable(idx)
        # Build the position indices on the input's device rather than the
        # global `dev`, so the model works wherever it and its input live.
        posEmbed = self.positionEmbeddingTable(torch.arange(t, device=idx.device))
        x = tokEmbed + posEmbed
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            b, t, c = logits.shape
            logits = logits.view(b * t, c)
            targets = targets.view(b * t)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def gen(self, idx, genLen):
        """Autoregressively sample genLen new tokens and append them to idx.

        Runs without gradient tracking. Does not change train/eval mode, so
        call ``eval()`` first if dropout should be disabled while sampling.

        Args:
            idx: (batch, time) tensor of token ids used as the prompt.
            genLen: number of tokens to sample.

        Returns:
            (batch, time + genLen) tensor: the prompt followed by the samples.
        """
        for _ in range(genLen):
            # The positional table only covers blockSize positions, so
            # condition on at most the last blockSize tokens.
            idxCond = idx[:, -blockSize:]
            logits, _ = self(idxCond)
            # Only the prediction for the final position is needed.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idxNext = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idxNext), dim=-1)
        return idx
# Instantiate the model on the selected device and report its size.
mdl = Model().to(dev)
paramCount = sum(p.numel() for p in mdl.parameters())
print(paramCount / 1e6, "M params")
# Plain Adam over all model parameters.
optim = torch.optim.Adam(mdl.parameters(), lr=learningRate)
# Training loop: one optimizer step per iteration on a fresh random batch,
# with a loss report on both splits every evalEpochs iterations and on the
# final iteration.
lastEpoch = numEpochs - 1
for epoch in range(numEpochs):
    isReportStep = epoch % evalEpochs == 0 or epoch == lastEpoch
    if isReportStep:
        losses = estLoss()
        print(f"epoch {epoch} train loss {losses['train']:.3f} val loss {losses['val']:.3f}")
    # Forward pass on a new training batch, then a standard optimizer step.
    xb, yb = mkBatch("train")
    logits, loss = mdl(xb, yb)
    optim.zero_grad(set_to_none=True)
    loss.backward()
    optim.step()
# Sample 1500 new characters from the trained model, seeded with token id 0
# (the first character of the sorted vocabulary). Switch to eval mode so
# dropout is disabled while sampling, and skip autograd bookkeeping since
# no gradients are needed.
mdl.eval()
cont = torch.zeros((1, 1), dtype=torch.long, device=dev)
with torch.no_grad():
    sample = mdl.gen(cont, 1500)
print(dec(sample[0].tolist()))