Simon Slamka committed
Commit cc90f92 · 1 Parent(s): f6f6482

finished, just need to compile the dataset

Files changed (1)
  1. trainer.py +190 -2
trainer.py CHANGED
@@ -2,7 +2,195 @@
  # Simtoon "Simtoonism" Transformer model trainer
  # By Simtoon of Ongakken s. r. o.
  # the input dataset will gradually be expanded, which will make the resulting model more performant
+ # Since I am still learning and this is my first from-scratch Transformer, I will be following a tutorial, but I will be making my own changes
+ # There are two versions - bigram and GPT. I will compare them and see which one is better
  #######
 
- # with open("simtoon.dat", "r", encoding="utf-8") as f: # this is disabled for now because we don't have the file ready yet
- #     txt = f.read()
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ # hyperparams
+ batchSize = 128
+ blockSize = 512
+ numEpochs = 10000
+ learningRate = 0.0001
+ dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ evalEpochs = 256
+ n_embd = 384
+ n_head = 6
+ n_layer = 6
+ dropout = 0.2
+
+ # load dataset
+ with open("dataset.txt", "r", encoding="utf-8") as f:
+     dataset = f.read()
+
+ # some overview
+ chars = sorted(list(set(dataset)))
+ vocabSize = len(chars)
+ print("Vocab size:", vocabSize)
+
+ # create char2idx and idx2char
+ char2idx = {ch: i for i, ch in enumerate(chars)}
+ idx2char = {i: ch for i, ch in enumerate(chars)}
+ enc = lambda s: [char2idx[c] for c in s]  # encode a string into a list of token ids
+ dec = lambda l: ''.join([idx2char[i] for i in l])  # decode a list of token ids back into a string
+
+ # split dataset into train and val, where train is 85% of the dataset
+ data = torch.tensor(enc(dataset), dtype=torch.long)
+ n = int(len(data) * 0.85)
+ train, val = data[:n], data[n:]
+
+ # create dataloader
+ def mkBatch(split):
+     # gen a small batch of data of x and y
+     data = train if split == "train" else val
+     ix = torch.randint(len(data) - blockSize, (batchSize,))
+     x = torch.stack([data[i:i + blockSize] for i in ix])
+     y = torch.stack([data[i + 1:i + blockSize + 1] for i in ix])
+     x, y = x.to(dev), y.to(dev)
+     return x, y
+
+ @torch.no_grad()
+ def estLoss():
+     # estimate train and val loss by averaging over evalEpochs batches
+     out = {}
+     mdl.eval()
+     for split in ["train", "val"]:
+         losses = torch.zeros(evalEpochs)
+         for i in range(evalEpochs):
+             x, y = mkBatch(split)
+             logits, loss = mdl(x, y)
+             losses[i] = loss.item()
+         out[split] = losses.mean()
+     mdl.train()
+     return out
+
+ class Head(nn.Module):
+     # one head of self-attention
+     def __init__(self, headSize):
+         super().__init__()
+         self.key = nn.Linear(n_embd, headSize, bias=False)
+         self.query = nn.Linear(n_embd, headSize, bias=False)
+         self.value = nn.Linear(n_embd, headSize, bias=False)
+         self.register_buffer("tril", torch.tril(torch.ones(blockSize, blockSize)))
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         # input is (batchSize, time, channels)
+         # output is (batchSize, time, headSize)
+         b, t, c = x.shape
+         k = self.key(x)
+         q = self.query(x)
+         w = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # scaled dot-product attention scores
+         w = w.masked_fill(self.tril[:t, :t] == 0, float("-inf"))  # causal mask
+         w = F.softmax(w, dim=-1)
+         w = self.dropout(w)
+         v = self.value(x)
+         out = w @ v
+         return out
+
+ class MHA(nn.Module):
+     # multi-head self-attention: several Heads in parallel, concatenated and projected back to n_embd
+     def __init__(self, numHeads, headSize):
+         super().__init__()
+         self.heads = nn.ModuleList([Head(headSize) for _ in range(numHeads)])
+         self.proj = nn.Linear(numHeads * headSize, n_embd)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         out = torch.cat([h(x) for h in self.heads], dim=-1)
+         out = self.dropout(self.proj(out))
+         return out
+
+ class FeedForward(nn.Module):
+     def __init__(self, n_embd):
+         super().__init__()
+         self.net = nn.Sequential(nn.Linear(n_embd, 4 * n_embd), nn.ReLU(), nn.Linear(4 * n_embd, n_embd), nn.Dropout(dropout))
+
+     def forward(self, x):
+         return self.net(x)
+
+ class Block(nn.Module):
+     # Transformer block: self-attention followed by feed-forward, with pre-LayerNorm and residual connections
+     def __init__(self, n_embd, n_head):
+         super().__init__()
+         headSize = n_embd // n_head
+         self.sa = MHA(n_head, headSize)
+         self.ffwd = FeedForward(n_embd)
+         self.ln1 = nn.LayerNorm(n_embd)
+         self.ln2 = nn.LayerNorm(n_embd)
+
+     def forward(self, x):
+         x = x + self.sa(self.ln1(x))
+         x = x + self.ffwd(self.ln2(x))
+         return x
+
+ class Model(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # each token looks up its embedding in a lut; logits for the next token come from lm_head
+         self.tokenEmbeddingTable = nn.Embedding(vocabSize, n_embd)
+         self.positionEmbeddingTable = nn.Embedding(blockSize, n_embd)
+         self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
+         self.ln_f = nn.LayerNorm(n_embd)
+         self.lm_head = nn.Linear(n_embd, vocabSize)
+         self.apply(self._initWeights)
+
+     def _initWeights(self, mod):
+         if isinstance(mod, nn.Linear):
+             torch.nn.init.normal_(mod.weight, std=0.02, mean=0)
+             if mod.bias is not None:
+                 torch.nn.init.zeros_(mod.bias)
+         elif isinstance(mod, nn.Embedding):
+             torch.nn.init.normal_(mod.weight, std=0.02, mean=0)
+
+     def forward(self, idx, targets=None):
+         b, t = idx.shape
+         tokEmbed = self.tokenEmbeddingTable(idx)
+         posEmbed = self.positionEmbeddingTable(torch.arange(t, device=dev))
+         x = tokEmbed + posEmbed
+         x = self.blocks(x)
+         x = self.ln_f(x)
+         logits = self.lm_head(x)
+         if targets is None:
+             loss = None
+         else:
+             b, t, c = logits.shape
+             logits = logits.view(b * t, c)
+             targets = targets.view(b * t)
+             loss = F.cross_entropy(logits, targets)
+
+         return logits, loss
+
+     def gen(self, idx, genLen):
+         # autoregressively sample genLen new tokens, conditioning on at most the last blockSize tokens
+         for _ in range(genLen):
+             idxCond = idx[:, -blockSize:]
+             logits, loss = self(idxCond)
+             logits = logits[:, -1, :]
+             probs = F.softmax(logits, dim=-1)
+             idxNext = torch.multinomial(probs, num_samples=1)
+             idx = torch.cat((idx, idxNext), dim=-1)
+         return idx
+
+ mdl = Model().to(dev)
+ print(sum(p.numel() for p in mdl.parameters()) / 1e6, "M params")
+
+ # optimizer
+ optim = torch.optim.Adam(mdl.parameters(), lr=learningRate)
+
+ # training loop
+ for epoch in range(numEpochs):
+     if epoch % evalEpochs == 0 or epoch == numEpochs - 1:
+         losses = estLoss()
+         print(f"epoch {epoch} train loss {losses['train']:.3f} val loss {losses['val']:.3f}")
+
+     # pick data
+     xb, yb = mkBatch("train")
+
+     # eval loss and backprop
+     logits, loss = mdl(xb, yb)
+     optim.zero_grad(set_to_none=True)
+     loss.backward()
+     optim.step()
+
+ # generate from the trained model, starting from a single zero token
+ cont = torch.zeros((1, 1), dtype=torch.long, device=dev)
+ print(dec(mdl.gen(cont, 1500)[0].tolist()))
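
For reference, a minimal usage sketch (not part of the committed file) of how the trained model could be saved and later reused for prompted generation, assuming the training loop above has completed; the checkpoint name "simtoonism.pt" and the prompt string are hypothetical:

# save the trained weights (hypothetical filename)
torch.save(mdl.state_dict(), "simtoonism.pt")

# later: rebuild the model, load the weights, and generate from a text prompt
mdl = Model().to(dev)
mdl.load_state_dict(torch.load("simtoonism.pt", map_location=dev))
mdl.eval()
prompt = "simtoon"  # every character in the prompt must appear in dataset.txt, otherwise enc() raises a KeyError
ctx = torch.tensor([enc(prompt)], dtype=torch.long, device=dev)
print(dec(mdl.gen(ctx, 500)[0].tolist()))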