ibrahimmkhalid commited on
Commit
0a83ccf
·
1 Parent(s): f8ea1ea

simple gpt model based on entire works of shakespeare

Browse files
Files changed (2) hide show
  1. gpt_shakespeare.sync.ipynb +428 -0
  2. gpt_shakespeare.sync.py +251 -0
gpt_shakespeare.sync.ipynb ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "059837a0",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "cuda\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "import torch\n",
19
+ "import torch.nn as nn\n",
20
+ "from torch.nn import functional as F\n",
21
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
22
+ "print(device)\n",
23
+ "block_size = 128\n",
24
+ "batch_size = 32\n",
25
+ "max_iters = 4000\n",
26
+ "learning_rate = 3e-4\n",
27
+ "eval_every = 500\n",
28
+ "n_embd = 384\n",
29
+ "n_head = 8\n",
30
+ "n_layer = 8\n",
31
+ "dropout = 0.2"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 2,
37
+ "id": "1fdc69a7",
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "with open(\"shakespeare.txt\") as f:\n",
42
+ " text = f.read()"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 3,
48
+ "id": "0c09eeb0",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "chars = sorted(set(text))\n",
53
+ "vocab_size = len(chars)"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 4,
59
+ "id": "a278e7b9",
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "name": "stdout",
64
+ "output_type": "stream",
65
+ "text": [
66
+ "Vocab size: 101\n",
67
+ "Text length: 5357910\n"
68
+ ]
69
+ }
70
+ ],
71
+ "source": [
72
+ "print(f\"Vocab size: {vocab_size}\")\n",
73
+ "print(f\"Text length: {len(text)}\")"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 5,
79
+ "id": "2a540d96",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "string_to_int = {ch: i for i, ch in enumerate(chars)}\n",
84
+ "int_to_string = {i: ch for i, ch in enumerate(chars)}\n",
85
+ "\n",
86
+ "encode = lambda s: [string_to_int[ch] for ch in s]\n",
87
+ "decode = lambda x: ''.join([int_to_string[i] for i in x])\n",
88
+ "\n",
89
+ "data = torch.tensor(encode(text), dtype=torch.long, device=device)"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 6,
95
+ "id": "c7c8e4aa",
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "n = int(0.8 * len(data))\n",
100
+ "train_data = data[:n]\n",
101
+ "val_data = data[n:]"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 7,
107
+ "id": "54d80f45",
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "def get_batch(split):\n",
112
+ " data = train_data if split == 'train' else val_data\n",
113
+ " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
114
+ " x = torch.stack([data[i:i+block_size] for i in ix])\n",
115
+ " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
116
+ " x, y = x.to(device), y.to(device)\n",
117
+ " return x, y"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": 8,
123
+ "id": "618df2dc",
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "@torch.no_grad()\n",
128
+ "def estimate_loss():\n",
129
+ " out = {}\n",
130
+ " model.eval()\n",
131
+ " for split in ['train', 'val']:\n",
132
+ " losses = torch.zeros(eval_every)\n",
133
+ " for k in range(eval_every):\n",
134
+ " X, Y = get_batch(split)\n",
135
+ " logits, loss = model(X, Y)\n",
136
+ " losses[k] = loss.item()\n",
137
+ " out[split] = losses.mean()\n",
138
+ " model.train()\n",
139
+ " return out"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 9,
145
+ "id": "d0a21928",
146
+ "metadata": {},
147
+ "outputs": [
148
+ {
149
+ "name": "stdout",
150
+ "output_type": "stream",
151
+ "text": [
152
+ "step: 0, train loss: 4.615, val loss: 4.613\n",
153
+ "step: 500, train loss: 1.923, val loss: 1.961\n",
154
+ "step: 1000, train loss: 1.662, val loss: 1.753\n",
155
+ "step: 1500, train loss: 1.531, val loss: 1.655\n",
156
+ "step: 2000, train loss: 1.453, val loss: 1.608\n",
157
+ "step: 2500, train loss: 1.398, val loss: 1.567\n",
158
+ "step: 3000, train loss: 1.365, val loss: 1.543\n",
159
+ "step: 3500, train loss: 1.340, val loss: 1.529\n",
160
+ "1.3418211936950684\n"
161
+ ]
162
+ }
163
+ ],
164
+ "source": [
165
+ "\n",
166
+ "class Head(nn.Module):\n",
167
+ " \"\"\" one head of self-attention \"\"\"\n",
168
+ "\n",
169
+ " def __init__(self, head_size):\n",
170
+ " super().__init__()\n",
171
+ " self.key = nn.Linear(n_embd, head_size, bias=False)\n",
172
+ " self.query = nn.Linear(n_embd, head_size, bias=False)\n",
173
+ " self.value = nn.Linear(n_embd, head_size, bias=False)\n",
174
+ " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
175
+ "\n",
176
+ " self.dropout = nn.Dropout(dropout)\n",
177
+ "\n",
178
+ " def forward(self, x):\n",
179
+ " # input of size (batch, time-step, channels)\n",
180
+ " # output of size (batch, time-step, head size)\n",
181
+ " B,T,C = x.shape\n",
182
+ " k = self.key(x) # (B,T,hs)\n",
183
+ " q = self.query(x) # (B,T,hs)\n",
184
+ " # compute attention scores (\"affinities\")\n",
185
+ " wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n",
186
+ " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
187
+ " wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
188
+ " wei = self.dropout(wei)\n",
189
+ " # perform the weighted aggregation of the values\n",
190
+ " v = self.value(x) # (B,T,hs)\n",
191
+ " out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)\n",
192
+ " return out\n",
193
+ "\n",
194
+ "# [1, 0, 0]\n",
195
+ "# [1, 0.6, 0]\n",
196
+ "# [1, 0.6, 0.4]\n",
197
+ "class MultiHeadAttention(nn.Module):\n",
198
+ " \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
199
+ "\n",
200
+ " def __init__(self, num_heads, head_size):\n",
201
+ " super().__init__()\n",
202
+ " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
203
+ " self.proj = nn.Linear(head_size * num_heads, n_embd)\n",
204
+ " self.dropout = nn.Dropout(dropout)\n",
205
+ "\n",
206
+ " def forward(self, x):\n",
207
+ " out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, F) -> (B, T, [h1, h1, h1, h1, h2, h2, h2, h2, h3, h3, h3, h3])\n",
208
+ " out = self.dropout(self.proj(out))\n",
209
+ " return out\n",
210
+ " \n",
211
+ "\n",
212
+ "class FeedFoward(nn.Module):\n",
213
+ " \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n",
214
+ "\n",
215
+ " def __init__(self, n_embd):\n",
216
+ " super().__init__()\n",
217
+ " self.net = nn.Sequential(\n",
218
+ " nn.Linear(n_embd, 4 * n_embd),\n",
219
+ " nn.ReLU(),\n",
220
+ " nn.Linear(4 * n_embd, n_embd),\n",
221
+ " nn.Dropout(dropout),\n",
222
+ " )\n",
223
+ "\n",
224
+ " def forward(self, x):\n",
225
+ " return self.net(x)\n",
226
+ " \n",
227
+ "class Block(nn.Module):\n",
228
+ " \"\"\" Transformer block: communication followed by computation \"\"\"\n",
229
+ "\n",
230
+ " def __init__(self, n_embd, n_head):\n",
231
+ " # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
232
+ " super().__init__()\n",
233
+ " head_size = n_embd // n_head\n",
234
+ " self.sa = MultiHeadAttention(n_head, head_size)\n",
235
+ " self.ffwd = FeedFoward(n_embd)\n",
236
+ " self.ln1 = nn.LayerNorm(n_embd)\n",
237
+ " self.ln2 = nn.LayerNorm(n_embd)\n",
238
+ "\n",
239
+ " def forward(self, x):\n",
240
+ " y = self.sa(x)\n",
241
+ " x = self.ln1(x + y)\n",
242
+ " y = self.ffwd(x)\n",
243
+ " x = self.ln2(x + y)\n",
244
+ " return x\n",
245
+ " \n",
246
+ "class GPTLanguageModel(nn.Module):\n",
247
+ " def __init__(self, vocab_size):\n",
248
+ " super().__init__()\n",
249
+ " self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
250
+ " self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
251
+ " self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n",
252
+ " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
253
+ " self.lm_head = nn.Linear(n_embd, vocab_size)\n",
254
+ " \n",
255
+ " \n",
256
+ " self.apply(self._init_weights)\n",
257
+ "\n",
258
+ " def _init_weights(self, module):\n",
259
+ " if isinstance(module, nn.Linear):\n",
260
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
261
+ " if module.bias is not None:\n",
262
+ " torch.nn.init.zeros_(module.bias)\n",
263
+ " elif isinstance(module, nn.Embedding):\n",
264
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
265
+ "\n",
266
+ " def forward(self, index, targets=None):\n",
267
+ " B, T = index.shape\n",
268
+ " \n",
269
+ " \n",
270
+ " # idx and targets are both (B,T) tensor of integers\n",
271
+ " tok_emb = self.token_embedding_table(index) # (B,T,C)\n",
272
+ " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
273
+ " x = tok_emb + pos_emb # (B,T,C)\n",
274
+ " x = self.blocks(x) # (B,T,C)\n",
275
+ " x = self.ln_f(x) # (B,T,C)\n",
276
+ " logits = self.lm_head(x) # (B,T,vocab_size)\n",
277
+ " \n",
278
+ " if targets is None:\n",
279
+ " loss = None\n",
280
+ " else:\n",
281
+ " B, T, C = logits.shape\n",
282
+ " logits = logits.view(B*T, C) # reshape to what torch.cross_entropy expects\n",
283
+ " targets = targets.view(B*T)\n",
284
+ " loss = F.cross_entropy(logits, targets) \n",
285
+ " return logits, loss\n",
286
+ " \n",
287
+ " def generate(self, index, max_new_tokens):\n",
288
+ " # index is (B, T) array of indices in the current context\n",
289
+ " for _ in range(max_new_tokens):\n",
290
+ " # crop idx to the last block_size tokens\n",
291
+ " index_cond = index[:, -block_size:]\n",
292
+ " # get the predictions\n",
293
+ " logits, loss = self.forward(index_cond)\n",
294
+ " # focus only on the last time step\n",
295
+ " logits = logits[:, -1, :] # becomes (B, C)\n",
296
+ " # apply softmax to get probabilities\n",
297
+ " probs = F.softmax(logits, dim=-1) # (B, C)\n",
298
+ " # sample from the distribution\n",
299
+ " index_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
300
+ " # append sampled index to the running sequence\n",
301
+ " index = torch.cat((index, index_next), dim=1) # (B, T+1)\n",
302
+ " return index\n",
303
+ "\n",
304
+ "model = GPTLanguageModel(vocab_size).to(device)\n",
305
+ "\n",
306
+ "# create a PyTorch optimizer\n",
307
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
308
+ "\n",
309
+ "for iter in range(max_iters):\n",
310
+ " if iter % eval_every == 0:\n",
311
+ " losses = estimate_loss()\n",
312
+ " print(f\"step: {iter}, train loss: {losses['train']:.3f}, val loss: {losses['val']:.3f}\")\n",
313
+ "\n",
314
+ " # sample a batch of data\n",
315
+ " xb, yb = get_batch('train')\n",
316
+ "\n",
317
+ " # evaluate the loss\n",
318
+ " logits, loss = model.forward(xb, yb)\n",
319
+ " optimizer.zero_grad(set_to_none=True)\n",
320
+ " loss.backward()\n",
321
+ " optimizer.step()\n",
322
+ "print(loss.item())"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "code",
327
+ "execution_count": 10,
328
+ "id": "99a66247",
329
+ "metadata": {},
330
+ "outputs": [
331
+ {
332
+ "name": "stdout",
333
+ "output_type": "stream",
334
+ "text": [
335
+ "\t them part it. The leison drows\n",
336
+ "Let them napposes them.\n",
337
+ "\n",
338
+ "SUFFUE.\n",
339
+ "Yea, erow now, he was near angless.\n"
340
+ ]
341
+ }
342
+ ],
343
+ "source": [
344
+ "\n",
345
+ "context = torch.zeros((1,1), dtype=torch.long, device=device)\n",
346
+ "generated_chars = decode(model.generate(context, max_new_tokens=100)[0].tolist())\n",
347
+ "print(generated_chars)"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "execution_count": 11,
353
+ "id": "c2b03115",
354
+ "metadata": {},
355
+ "outputs": [
356
+ {
357
+ "name": "stdout",
358
+ "output_type": "stream",
359
+ "text": [
360
+ "To be or not to be, my lord; and at Worces and will.\n",
361
+ "\n",
362
+ "FRIAR L JOHN.\n",
363
+ "My dory Gold Catesby say the King vow you are.\n",
364
+ "\n",
365
+ "ENT\n"
366
+ ]
367
+ }
368
+ ],
369
+ "source": [
370
+ "\n",
371
+ "prompt = 'To be or not to be,'\n",
372
+ "context = torch.tensor(encode(prompt), dtype=torch.long, device=device)\n",
373
+ "generated_chars = decode(model.generate(context.unsqueeze(0), max_new_tokens=100)[0].tolist())\n",
374
+ "print(generated_chars)"
375
+ ]
376
+ }
377
+ ],
378
+ "metadata": {
379
+ "kernelspec": {
380
+ "display_name": "Python 3 (ipykernel)",
381
+ "language": "python",
382
+ "name": "python3"
383
+ },
384
+ "language_info": {
385
+ "codemirror_mode": {
386
+ "name": "ipython",
387
+ "version": 3
388
+ },
389
+ "file_extension": ".py",
390
+ "mimetype": "text/x-python",
391
+ "name": "python",
392
+ "nbconvert_exporter": "python",
393
+ "pygments_lexer": "ipython3",
394
+ "version": "3.10.12"
395
+ },
396
+ "varInspector": {
397
+ "cols": {
398
+ "lenName": 16,
399
+ "lenType": 16,
400
+ "lenVar": 40
401
+ },
402
+ "kernels_config": {
403
+ "python": {
404
+ "delete_cmd_postfix": "",
405
+ "delete_cmd_prefix": "del ",
406
+ "library": "var_list.py",
407
+ "varRefreshCmd": "print(var_dic_list())"
408
+ },
409
+ "r": {
410
+ "delete_cmd_postfix": ") ",
411
+ "delete_cmd_prefix": "rm(",
412
+ "library": "var_list.r",
413
+ "varRefreshCmd": "cat(var_dic_list()) "
414
+ }
415
+ },
416
+ "types_to_exclude": [
417
+ "module",
418
+ "function",
419
+ "builtin_function_or_method",
420
+ "instance",
421
+ "_Feature"
422
+ ],
423
+ "window_display": false
424
+ }
425
+ },
426
+ "nbformat": 4,
427
+ "nbformat_minor": 5
428
+ }
gpt_shakespeare.sync.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ---
2
+ # jupyter:
3
+ # jupytext:
4
+ # text_representation:
5
+ # extension: .py
6
+ # format_name: percent
7
+ # format_version: '1.3'
8
+ # jupytext_version: 1.3.4
9
+ # kernelspec:
10
+ # display_name: Python 3
11
+ # language: python
12
+ # name: python3
13
+ # ---
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ print(device)
19
+ block_size = 128
20
+ batch_size = 32
21
+ max_iters = 4000
22
+ learning_rate = 3e-4
23
+ eval_every = 500
24
+ n_embd = 384
25
+ n_head = 8
26
+ n_layer = 8
27
+ dropout = 0.2
28
+
29
+ # %%
30
+ with open("shakespeare.txt") as f:
31
+ text = f.read()
32
+ # %%
33
+ chars = sorted(set(text))
34
+ vocab_size = len(chars)
35
+
36
+ # %%
37
+ print(f"Vocab size: {vocab_size}")
38
+ print(f"Text length: {len(text)}")
39
+
40
+ # %%
41
+ string_to_int = {ch: i for i, ch in enumerate(chars)}
42
+ int_to_string = {i: ch for i, ch in enumerate(chars)}
43
+
44
+ encode = lambda s: [string_to_int[ch] for ch in s]
45
+ decode = lambda x: ''.join([int_to_string[i] for i in x])
46
+
47
+ data = torch.tensor(encode(text), dtype=torch.long, device=device)
48
+
49
+
50
+ # %%
51
+ n = int(0.8 * len(data))
52
+ train_data = data[:n]
53
+ val_data = data[n:]
54
+
55
+ # %%
56
+ def get_batch(split):
57
+ data = train_data if split == 'train' else val_data
58
+ ix = torch.randint(len(data) - block_size, (batch_size,))
59
+ x = torch.stack([data[i:i+block_size] for i in ix])
60
+ y = torch.stack([data[i+1:i+block_size+1] for i in ix])
61
+ x, y = x.to(device), y.to(device)
62
+ return x, y
63
+
64
+ # %%
65
+ @torch.no_grad()
66
+ def estimate_loss():
67
+ out = {}
68
+ model.eval()
69
+ for split in ['train', 'val']:
70
+ losses = torch.zeros(eval_every)
71
+ for k in range(eval_every):
72
+ X, Y = get_batch(split)
73
+ logits, loss = model(X, Y)
74
+ losses[k] = loss.item()
75
+ out[split] = losses.mean()
76
+ model.train()
77
+ return out
78
+
79
+ # %%
80
+
81
+ class Head(nn.Module):
82
+ """ one head of self-attention """
83
+
84
+ def __init__(self, head_size):
85
+ super().__init__()
86
+ self.key = nn.Linear(n_embd, head_size, bias=False)
87
+ self.query = nn.Linear(n_embd, head_size, bias=False)
88
+ self.value = nn.Linear(n_embd, head_size, bias=False)
89
+ self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
90
+
91
+ self.dropout = nn.Dropout(dropout)
92
+
93
+ def forward(self, x):
94
+ # input of size (batch, time-step, channels)
95
+ # output of size (batch, time-step, head size)
96
+ B,T,C = x.shape
97
+ k = self.key(x) # (B,T,hs)
98
+ q = self.query(x) # (B,T,hs)
99
+ # compute attention scores ("affinities")
100
+ wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
101
+ wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
102
+ wei = F.softmax(wei, dim=-1) # (B, T, T)
103
+ wei = self.dropout(wei)
104
+ # perform the weighted aggregation of the values
105
+ v = self.value(x) # (B,T,hs)
106
+ out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
107
+ return out
108
+
109
+ # [1, 0, 0]
110
+ # [1, 0.6, 0]
111
+ # [1, 0.6, 0.4]
112
+ class MultiHeadAttention(nn.Module):
113
+ """ multiple heads of self-attention in parallel """
114
+
115
+ def __init__(self, num_heads, head_size):
116
+ super().__init__()
117
+ self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
118
+ self.proj = nn.Linear(head_size * num_heads, n_embd)
119
+ self.dropout = nn.Dropout(dropout)
120
+
121
+ def forward(self, x):
122
+ out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, F) -> (B, T, [h1, h1, h1, h1, h2, h2, h2, h2, h3, h3, h3, h3])
123
+ out = self.dropout(self.proj(out))
124
+ return out
125
+
126
+
127
+ class FeedFoward(nn.Module):
128
+ """ a simple linear layer followed by a non-linearity """
129
+
130
+ def __init__(self, n_embd):
131
+ super().__init__()
132
+ self.net = nn.Sequential(
133
+ nn.Linear(n_embd, 4 * n_embd),
134
+ nn.ReLU(),
135
+ nn.Linear(4 * n_embd, n_embd),
136
+ nn.Dropout(dropout),
137
+ )
138
+
139
+ def forward(self, x):
140
+ return self.net(x)
141
+
142
+ class Block(nn.Module):
143
+ """ Transformer block: communication followed by computation """
144
+
145
+ def __init__(self, n_embd, n_head):
146
+ # n_embd: embedding dimension, n_head: the number of heads we'd like
147
+ super().__init__()
148
+ head_size = n_embd // n_head
149
+ self.sa = MultiHeadAttention(n_head, head_size)
150
+ self.ffwd = FeedFoward(n_embd)
151
+ self.ln1 = nn.LayerNorm(n_embd)
152
+ self.ln2 = nn.LayerNorm(n_embd)
153
+
154
+ def forward(self, x):
155
+ y = self.sa(x)
156
+ x = self.ln1(x + y)
157
+ y = self.ffwd(x)
158
+ x = self.ln2(x + y)
159
+ return x
160
+
161
+ class GPTLanguageModel(nn.Module):
162
+ def __init__(self, vocab_size):
163
+ super().__init__()
164
+ self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
165
+ self.position_embedding_table = nn.Embedding(block_size, n_embd)
166
+ self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
167
+ self.ln_f = nn.LayerNorm(n_embd) # final layer norm
168
+ self.lm_head = nn.Linear(n_embd, vocab_size)
169
+
170
+
171
+ self.apply(self._init_weights)
172
+
173
+ def _init_weights(self, module):
174
+ if isinstance(module, nn.Linear):
175
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
176
+ if module.bias is not None:
177
+ torch.nn.init.zeros_(module.bias)
178
+ elif isinstance(module, nn.Embedding):
179
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
180
+
181
+ def forward(self, index, targets=None):
182
+ B, T = index.shape
183
+
184
+
185
+ # idx and targets are both (B,T) tensor of integers
186
+ tok_emb = self.token_embedding_table(index) # (B,T,C)
187
+ pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
188
+ x = tok_emb + pos_emb # (B,T,C)
189
+ x = self.blocks(x) # (B,T,C)
190
+ x = self.ln_f(x) # (B,T,C)
191
+ logits = self.lm_head(x) # (B,T,vocab_size)
192
+
193
+ if targets is None:
194
+ loss = None
195
+ else:
196
+ B, T, C = logits.shape
197
+ logits = logits.view(B*T, C) # reshape to what torch.cross_entropy expects
198
+ targets = targets.view(B*T)
199
+ loss = F.cross_entropy(logits, targets)
200
+ return logits, loss
201
+
202
+ def generate(self, index, max_new_tokens):
203
+ # index is (B, T) array of indices in the current context
204
+ for _ in range(max_new_tokens):
205
+ # crop idx to the last block_size tokens
206
+ index_cond = index[:, -block_size:]
207
+ # get the predictions
208
+ logits, loss = self.forward(index_cond)
209
+ # focus only on the last time step
210
+ logits = logits[:, -1, :] # becomes (B, C)
211
+ # apply softmax to get probabilities
212
+ probs = F.softmax(logits, dim=-1) # (B, C)
213
+ # sample from the distribution
214
+ index_next = torch.multinomial(probs, num_samples=1) # (B, 1)
215
+ # append sampled index to the running sequence
216
+ index = torch.cat((index, index_next), dim=1) # (B, T+1)
217
+ return index
218
+
219
+ model = GPTLanguageModel(vocab_size).to(device)
220
+
221
+ # create a PyTorch optimizer
222
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
223
+
224
+ for iter in range(max_iters):
225
+ if iter % eval_every == 0:
226
+ losses = estimate_loss()
227
+ print(f"step: {iter}, train loss: {losses['train']:.3f}, val loss: {losses['val']:.3f}")
228
+
229
+ # sample a batch of data
230
+ xb, yb = get_batch('train')
231
+
232
+ # evaluate the loss
233
+ logits, loss = model.forward(xb, yb)
234
+ optimizer.zero_grad(set_to_none=True)
235
+ loss.backward()
236
+ optimizer.step()
237
+ print(loss.item())
238
+
239
+ # %%
240
+
241
+ context = torch.zeros((1,1), dtype=torch.long, device=device)
242
+ generated_chars = decode(model.generate(context, max_new_tokens=100)[0].tolist())
243
+ print(generated_chars)
244
+
245
+
246
+ # %%
247
+
248
+ prompt = 'To be or not to be,'
249
+ context = torch.tensor(encode(prompt), dtype=torch.long, device=device)
250
+ generated_chars = decode(model.generate(context.unsqueeze(0), max_new_tokens=100)[0].tolist())
251
+ print(generated_chars)