#######
# Simtoon "Simtoonism" Transformer model trainer
# By Simtoon of Ongakken s. r. o.
# The input dataset will gradually be expanded, which should make the resulting model more performant
# Since I am still learning and this is my first from-scratch Transformer, I will be following a tutorial, but I will be making my own changes
# There are two versions - bigram and GPT. I will compare them and see which one is better
#######

import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparams
batchSize = 128
blockSize = 512
numEpochs = 10000
learningRate = 0.0001
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
evalEpochs = 256
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
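# Note on the names above: blockSize is the maximum context length the model attends over,
# batchSize sequences are processed in parallel, and numEpochs counts individual batch updates
# rather than full passes over the dataset. evalEpochs doubles as the evaluation interval and
# as the number of batches averaged when estimating the loss.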

# load dataset
with open("dataset.txt", "r", encoding="utf-8") as f:
    dataset = f.read()

# vocabulary: the set of unique characters in the dataset
chars = sorted(list(set(dataset)))
vocabSize = len(chars)
print("Vocab size:", vocabSize)

# create char2idx and idx2char
char2idx = {ch: i for i, ch in enumerate(chars)}
idx2char = {i: ch for i, ch in enumerate(chars)}
enc = lambda s: [char2idx[c] for c in s]  # encode a string as a list of integer token ids
dec = lambda l: ''.join([idx2char[i] for i in l])  # decode a list of ids back into a string

# split dataset into train and val, where train is 85% of the dataset
data = torch.tensor(enc(dataset), dtype=torch.long)
n = int(len(data) * 0.85)
train, val = data[:n], data[n:]

# batch sampler: draws random contiguous chunks from the chosen split
def mkBatch(split):
    # gen a small batch of data of x and y
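    # y is x shifted one position to the right, so each index's target is the next character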
    data = train if split == "train" else val
    ix = torch.randint(len(data) - blockSize, (batchSize,))
    x = torch.stack([data[i:i + blockSize] for i in ix])
    y = torch.stack([data[i + 1:i + blockSize + 1] for i in ix])
    x, y = x.to(dev), y.to(dev)
    return x, y

@torch.no_grad()
def estLoss():
    # estimate train/val loss by averaging over evalEpochs random batches per split
    out = {}
    mdl.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(evalEpochs)
        for i in range(evalEpochs):
            x, y = mkBatch(split)
            logits, loss = mdl(x, y)
            losses[i] = loss.item()
        out[split] = losses.mean()
    mdl.train()
    return out

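# A single head of causal self-attention: keys, queries and values are learned linear
# projections of the input, and the lower-triangular mask keeps each position from
# attending to positions that come after it.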
class Head(nn.Module):
    def __init__(self, headSize):
        super().__init__()
        self.key = nn.Linear(n_embd, headSize, bias=False)
        self.query = nn.Linear(n_embd, headSize, bias=False)
        self.value = nn.Linear(n_embd, headSize, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(blockSize, blockSize)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input is (batchSize, time, channels)
        # output is (batchSize, time, headSize)
        b, t, c = x.shape
        k = self.key(x)
        q = self.query(x)
        w = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        w = w.masked_fill(self.tril[:t, :t] == 0, float("-inf"))
        w = F.softmax(w, dim=-1)
        w = self.dropout(w)
        v = self.value(x)
        out = w @ v
        return out

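# Multi-head attention: several Heads run in parallel and their concatenated outputs are
# projected back to n_embd, followed by dropout.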
class MHA(nn.Module):
    def __init__(self, numHeads, headSize):
        super().__init__()
        self.heads = nn.ModuleList([Head(headSize) for _ in range(numHeads)])
        self.proj = nn.Linear(numHeads * headSize, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

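# Position-wise feed-forward network: expand to 4 * n_embd, apply ReLU, project back down,
# then apply dropout.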
class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

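# Transformer block: pre-LayerNorm self-attention followed by the feed-forward network,
# each wrapped in a residual connection.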
class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        headSize = n_embd // n_head
        self.sa = MHA(n_head, headSize)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

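# The full GPT-style model: token and position embeddings are summed, passed through a stack
# of Blocks, normalised, and projected to per-character logits by the language-model head.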
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # lookup tables mapping token ids and positions to learned embedding vectors
        self.tokenEmbeddingTable = nn.Embedding(vocabSize, n_embd)
        self.positionEmbeddingTable = nn.Embedding(blockSize, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocabSize)
        self.apply(self._initWeights)

    def _initWeights(self, mod):
        if isinstance(mod, nn.Linear):
            torch.nn.init.normal_(mod.weight, std=0.02, mean=0)
            if mod.bias is not None:
                torch.nn.init.zeros_(mod.bias)
        elif isinstance(mod, nn.Embedding):
            torch.nn.init.normal_(mod.weight, std=0.02, mean=0)

    def forward(self, idx, targets=None):
        b, t = idx.shape
        tokEmbed = self.tokenEmbeddingTable(idx)
        posEmbed = self.positionEmbeddingTable(torch.arange(t, device=dev))
        x = tokEmbed + posEmbed
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            b, t, c = logits.shape
            logits = logits.view(b * t, c)
            targets = targets.view(b * t)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

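    # Autoregressive sampling: crop the running context to the last blockSize tokens, take the
    # logits of the final position, turn them into probabilities, sample one token, and append it.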
    def gen(self, idx, genLen):
        for _ in range(genLen):
            idxCond = idx[:, -blockSize:]
            logits, _ = self(idxCond)  # no targets passed, so the returned loss is None
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idxNext = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idxNext), dim=-1)
        return idx

mdl = Model().to(dev)
print(sum(p.numel() for p in mdl.parameters()) / 1e6, "M params")

# optimizer
optim = torch.optim.Adam(mdl.parameters(), lr=learningRate)

# training loop
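# Each step samples a fresh random batch; every evalEpochs steps (and on the final step)
# the smoothed train/val losses from estLoss are printed.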
for epoch in range(numEpochs):
    if epoch % evalEpochs == 0 or epoch == numEpochs - 1:
        losses = estLoss()
        print(f"epoch {epoch} train loss {losses['train']:.3f} val loss {losses['val']:.3f}")

    # pick data
    xb, yb = mkBatch("train")

    # forward pass, compute the loss, backprop, and update the weights
    logits, loss = mdl(xb, yb)
    optim.zero_grad(set_to_none=True)
    loss.backward()
    optim.step()

# generate: seed the context with a single zero token and sample 1500 characters
cont = torch.zeros((1, 1), dtype=torch.long, device=dev)
print(dec(mdl.gen(cont, 1500)[0].tolist()))
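
# To condition generation on a prompt instead of the single zero token above, something like
# the following should work (untested sketch; it assumes every character of the prompt string
# already appears in the dataset and therefore in char2idx):
# prompt = torch.tensor([enc("hello")], dtype=torch.long, device=dev)
# print(dec(mdl.gen(prompt, 500)[0].tolist()))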