AYYasaswini commited on
Commit
abb03b5
·
verified ·
1 Parent(s): 4a5a9f1

Delete gpt_dev.ipynb

Browse files
Files changed (1) hide show
  1. gpt_dev.ipynb +0 -1556
gpt_dev.ipynb DELETED
@@ -1,1556 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": []
7
- },
8
- "kernelspec": {
9
- "name": "python3",
10
- "display_name": "Python 3"
11
- },
12
- "language_info": {
13
- "name": "python"
14
- }
15
- },
16
- "cells": [
17
- {
18
- "cell_type": "markdown",
19
- "source": [
20
- "## Building a GPT\n",
21
- "\n",
22
- "Companion notebook to the [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT."
23
- ],
24
- "metadata": {
25
- "id": "wJpXpmjEYC_T"
26
- }
27
- },
28
- {
29
- "cell_type": "code",
30
- "execution_count": 3,
31
- "metadata": {
32
- "colab": {
33
- "base_uri": "https://localhost:8080/"
34
- },
35
- "id": "h5hjCcLDr2WC",
36
- "outputId": "24b008b5-5eb3-4882-a553-1ef45aaaf782"
37
- },
38
- "outputs": [
39
- {
40
- "output_type": "stream",
41
- "name": "stdout",
42
- "text": [
43
- "--2024-06-11 13:37:04-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
44
- "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
45
- "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
46
- "HTTP request sent, awaiting response... 200 OK\n",
47
- "Length: 1115394 (1.1M) [text/plain]\n",
48
- "Saving to: ‘input.txt.1’\n",
49
- "\n",
50
- "\rinput.txt.1 0%[ ] 0 --.-KB/s \rinput.txt.1 100%[===================>] 1.06M --.-KB/s in 0.05s \n",
51
- "\n",
52
- "2024-06-11 13:37:04 (21.7 MB/s) - ‘input.txt.1’ saved [1115394/1115394]\n",
53
- "\n"
54
- ]
55
- }
56
- ],
57
- "source": [
58
- "# We always start with a dataset to train on. Let's download the tiny shakespeare dataset\n",
59
- "!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
60
- ]
61
- },
62
- {
63
- "cell_type": "code",
64
- "source": [
65
- "# read it in to inspect it\n",
66
- "with open('input.txt', 'r', encoding='utf-8') as f:\n",
67
- " text = f.read()"
68
- ],
69
- "metadata": {
70
- "id": "O6medjfRsLD9"
71
- },
72
- "execution_count": 4,
73
- "outputs": []
74
- },
75
- {
76
- "cell_type": "code",
77
- "source": [
78
- "print(\"length of dataset in characters: \", len(text))"
79
- ],
80
- "metadata": {
81
- "colab": {
82
- "base_uri": "https://localhost:8080/"
83
- },
84
- "id": "6xWI_VyAsN8F",
85
- "outputId": "68d2ea04-26cd-4ce8-f31e-10868b38f7d0"
86
- },
87
- "execution_count": 5,
88
- "outputs": [
89
- {
90
- "output_type": "stream",
91
- "name": "stdout",
92
- "text": [
93
- "length of dataset in characters: 1115394\n"
94
- ]
95
- }
96
- ]
97
- },
98
- {
99
- "cell_type": "code",
100
- "source": [
101
- "# let's look at the first 1000 characters\n",
102
- "print(text[:1000])"
103
- ],
104
- "metadata": {
105
- "colab": {
106
- "base_uri": "https://localhost:8080/"
107
- },
108
- "id": "2c5V0FvqseE0",
109
- "outputId": "5306e25a-cad6-4ac6-9d34-8138bbaa34a4"
110
- },
111
- "execution_count": 6,
112
- "outputs": [
113
- {
114
- "output_type": "stream",
115
- "name": "stdout",
116
- "text": [
117
- "First Citizen:\n",
118
- "Before we proceed any further, hear me speak.\n",
119
- "\n",
120
- "All:\n",
121
- "Speak, speak.\n",
122
- "\n",
123
- "First Citizen:\n",
124
- "You are all resolved rather to die than to famish?\n",
125
- "\n",
126
- "All:\n",
127
- "Resolved. resolved.\n",
128
- "\n",
129
- "First Citizen:\n",
130
- "First, you know Caius Marcius is chief enemy to the people.\n",
131
- "\n",
132
- "All:\n",
133
- "We know't, we know't.\n",
134
- "\n",
135
- "First Citizen:\n",
136
- "Let us kill him, and we'll have corn at our own price.\n",
137
- "Is't a verdict?\n",
138
- "\n",
139
- "All:\n",
140
- "No more talking on't; let it be done: away, away!\n",
141
- "\n",
142
- "Second Citizen:\n",
143
- "One word, good citizens.\n",
144
- "\n",
145
- "First Citizen:\n",
146
- "We are accounted poor citizens, the patricians good.\n",
147
- "What authority surfeits on would relieve us: if they\n",
148
- "would yield us but the superfluity, while it were\n",
149
- "wholesome, we might guess they relieved us humanely;\n",
150
- "but they think we are too dear: the leanness that\n",
151
- "afflicts us, the object of our misery, is as an\n",
152
- "inventory to particularise their abundance; our\n",
153
- "sufferance is a gain to them Let us revenge this with\n",
154
- "our pikes, ere we become rakes: for the gods know I\n",
155
- "speak this in hunger for bread, not in thirst for revenge.\n",
156
- "\n",
157
- "\n"
158
- ]
159
- }
160
- ]
161
- },
162
- {
163
- "cell_type": "code",
164
- "source": [
165
- "# here are all the unique characters that occur in this text\n",
166
- "chars = sorted(list(set(text)))\n",
167
- "vocab_size = len(chars)\n",
168
- "print(''.join(chars))\n",
169
- "print(vocab_size)"
170
- ],
171
- "metadata": {
172
- "colab": {
173
- "base_uri": "https://localhost:8080/"
174
- },
175
- "id": "0e-Rbyr8sfM8",
176
- "outputId": "3cfb92f5-e9dc-4a4d-bc24-01c34e91fe2c"
177
- },
178
- "execution_count": 7,
179
- "outputs": [
180
- {
181
- "output_type": "stream",
182
- "name": "stdout",
183
- "text": [
184
- "\n",
185
- " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
186
- "65\n"
187
- ]
188
- }
189
- ]
190
- },
191
- {
192
- "cell_type": "code",
193
- "source": [
194
- "# create a mapping from characters to integers\n",
195
- "stoi = { ch:i for i,ch in enumerate(chars) }\n",
196
- "itos = { i:ch for i,ch in enumerate(chars) }\n",
197
- "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
198
- "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
199
- "\n",
200
- "print(encode(\"hii there\"))\n",
201
- "print(decode(encode(\"hii there\")))"
202
- ],
203
- "metadata": {
204
- "colab": {
205
- "base_uri": "https://localhost:8080/"
206
- },
207
- "id": "Yw1LKNCgwjj1",
208
- "outputId": "b32844f8-99ed-4eb8-c569-06196f56051f"
209
- },
210
- "execution_count": 8,
211
- "outputs": [
212
- {
213
- "output_type": "stream",
214
- "name": "stdout",
215
- "text": [
216
- "[46, 47, 47, 1, 58, 46, 43, 56, 43]\n",
217
- "hii there\n"
218
- ]
219
- }
220
- ]
221
- },
222
- {
223
- "cell_type": "code",
224
- "source": [
225
- "# let's now encode the entire text dataset and store it into a torch.Tensor\n",
226
- "import torch # we use PyTorch: https://pytorch.org\n",
227
- "data = torch.tensor(encode(text), dtype=torch.long)\n",
228
- "print(data.shape, data.dtype)\n",
229
- "print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this"
230
- ],
231
- "metadata": {
232
- "colab": {
233
- "base_uri": "https://localhost:8080/"
234
- },
235
- "id": "YJb0OXPwzvqg",
236
- "outputId": "7081b874-3ef5-4e65-ee10-acbc24ac9f9b"
237
- },
238
- "execution_count": 9,
239
- "outputs": [
240
- {
241
- "output_type": "stream",
242
- "name": "stdout",
243
- "text": [
244
- "torch.Size([1115394]) torch.int64\n",
245
- "tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44,\n",
246
- " 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63,\n",
247
- " 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1,\n",
248
- " 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49,\n",
249
- " 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47,\n",
250
- " 58, 47, 64, 43, 52, 10, 0, 37, 53, 59, 1, 39, 56, 43, 1, 39, 50, 50,\n",
251
- " 1, 56, 43, 57, 53, 50, 60, 43, 42, 1, 56, 39, 58, 46, 43, 56, 1, 58,\n",
252
- " 53, 1, 42, 47, 43, 1, 58, 46, 39, 52, 1, 58, 53, 1, 44, 39, 51, 47,\n",
253
- " 57, 46, 12, 0, 0, 13, 50, 50, 10, 0, 30, 43, 57, 53, 50, 60, 43, 42,\n",
254
- " 8, 1, 56, 43, 57, 53, 50, 60, 43, 42, 8, 0, 0, 18, 47, 56, 57, 58,\n",
255
- " 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 18, 47, 56, 57, 58, 6, 1, 63,\n",
256
- " 53, 59, 1, 49, 52, 53, 61, 1, 15, 39, 47, 59, 57, 1, 25, 39, 56, 41,\n",
257
- " 47, 59, 57, 1, 47, 57, 1, 41, 46, 47, 43, 44, 1, 43, 52, 43, 51, 63,\n",
258
- " 1, 58, 53, 1, 58, 46, 43, 1, 54, 43, 53, 54, 50, 43, 8, 0, 0, 13,\n",
259
- " 50, 50, 10, 0, 35, 43, 1, 49, 52, 53, 61, 5, 58, 6, 1, 61, 43, 1,\n",
260
- " 49, 52, 53, 61, 5, 58, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58,\n",
261
- " 47, 64, 43, 52, 10, 0, 24, 43, 58, 1, 59, 57, 1, 49, 47, 50, 50, 1,\n",
262
- " 46, 47, 51, 6, 1, 39, 52, 42, 1, 61, 43, 5, 50, 50, 1, 46, 39, 60,\n",
263
- " 43, 1, 41, 53, 56, 52, 1, 39, 58, 1, 53, 59, 56, 1, 53, 61, 52, 1,\n",
264
- " 54, 56, 47, 41, 43, 8, 0, 21, 57, 5, 58, 1, 39, 1, 60, 43, 56, 42,\n",
265
- " 47, 41, 58, 12, 0, 0, 13, 50, 50, 10, 0, 26, 53, 1, 51, 53, 56, 43,\n",
266
- " 1, 58, 39, 50, 49, 47, 52, 45, 1, 53, 52, 5, 58, 11, 1, 50, 43, 58,\n",
267
- " 1, 47, 58, 1, 40, 43, 1, 42, 53, 52, 43, 10, 1, 39, 61, 39, 63, 6,\n",
268
- " 1, 39, 61, 39, 63, 2, 0, 0, 31, 43, 41, 53, 52, 42, 1, 15, 47, 58,\n",
269
- " 47, 64, 43, 52, 10, 0, 27, 52, 43, 1, 61, 53, 56, 42, 6, 1, 45, 53,\n",
270
- " 53, 42, 1, 41, 47, 58, 47, 64, 43, 52, 57, 8, 0, 0, 18, 47, 56, 57,\n",
271
- " 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 35, 43, 1, 39, 56, 43, 1,\n",
272
- " 39, 41, 41, 53, 59, 52, 58, 43, 42, 1, 54, 53, 53, 56, 1, 41, 47, 58,\n",
273
- " 47, 64, 43, 52, 57, 6, 1, 58, 46, 43, 1, 54, 39, 58, 56, 47, 41, 47,\n",
274
- " 39, 52, 57, 1, 45, 53, 53, 42, 8, 0, 35, 46, 39, 58, 1, 39, 59, 58,\n",
275
- " 46, 53, 56, 47, 58, 63, 1, 57, 59, 56, 44, 43, 47, 58, 57, 1, 53, 52,\n",
276
- " 1, 61, 53, 59, 50, 42, 1, 56, 43, 50, 47, 43, 60, 43, 1, 59, 57, 10,\n",
277
- " 1, 47, 44, 1, 58, 46, 43, 63, 0, 61, 53, 59, 50, 42, 1, 63, 47, 43,\n",
278
- " 50, 42, 1, 59, 57, 1, 40, 59, 58, 1, 58, 46, 43, 1, 57, 59, 54, 43,\n",
279
- " 56, 44, 50, 59, 47, 58, 63, 6, 1, 61, 46, 47, 50, 43, 1, 47, 58, 1,\n",
280
- " 61, 43, 56, 43, 0, 61, 46, 53, 50, 43, 57, 53, 51, 43, 6, 1, 61, 43,\n",
281
- " 1, 51, 47, 45, 46, 58, 1, 45, 59, 43, 57, 57, 1, 58, 46, 43, 63, 1,\n",
282
- " 56, 43, 50, 47, 43, 60, 43, 42, 1, 59, 57, 1, 46, 59, 51, 39, 52, 43,\n",
283
- " 50, 63, 11, 0, 40, 59, 58, 1, 58, 46, 43, 63, 1, 58, 46, 47, 52, 49,\n",
284
- " 1, 61, 43, 1, 39, 56, 43, 1, 58, 53, 53, 1, 42, 43, 39, 56, 10, 1,\n",
285
- " 58, 46, 43, 1, 50, 43, 39, 52, 52, 43, 57, 57, 1, 58, 46, 39, 58, 0,\n",
286
- " 39, 44, 44, 50, 47, 41, 58, 57, 1, 59, 57, 6, 1, 58, 46, 43, 1, 53,\n",
287
- " 40, 48, 43, 41, 58, 1, 53, 44, 1, 53, 59, 56, 1, 51, 47, 57, 43, 56,\n",
288
- " 63, 6, 1, 47, 57, 1, 39, 57, 1, 39, 52, 0, 47, 52, 60, 43, 52, 58,\n",
289
- " 53, 56, 63, 1, 58, 53, 1, 54, 39, 56, 58, 47, 41, 59, 50, 39, 56, 47,\n",
290
- " 57, 43, 1, 58, 46, 43, 47, 56, 1, 39, 40, 59, 52, 42, 39, 52, 41, 43,\n",
291
- " 11, 1, 53, 59, 56, 0, 57, 59, 44, 44, 43, 56, 39, 52, 41, 43, 1, 47,\n",
292
- " 57, 1, 39, 1, 45, 39, 47, 52, 1, 58, 53, 1, 58, 46, 43, 51, 1, 24,\n",
293
- " 43, 58, 1, 59, 57, 1, 56, 43, 60, 43, 52, 45, 43, 1, 58, 46, 47, 57,\n",
294
- " 1, 61, 47, 58, 46, 0, 53, 59, 56, 1, 54, 47, 49, 43, 57, 6, 1, 43,\n",
295
- " 56, 43, 1, 61, 43, 1, 40, 43, 41, 53, 51, 43, 1, 56, 39, 49, 43, 57,\n",
296
- " 10, 1, 44, 53, 56, 1, 58, 46, 43, 1, 45, 53, 42, 57, 1, 49, 52, 53,\n",
297
- " 61, 1, 21, 0, 57, 54, 43, 39, 49, 1, 58, 46, 47, 57, 1, 47, 52, 1,\n",
298
- " 46, 59, 52, 45, 43, 56, 1, 44, 53, 56, 1, 40, 56, 43, 39, 42, 6, 1,\n",
299
- " 52, 53, 58, 1, 47, 52, 1, 58, 46, 47, 56, 57, 58, 1, 44, 53, 56, 1,\n",
300
- " 56, 43, 60, 43, 52, 45, 43, 8, 0, 0])\n"
301
- ]
302
- }
303
- ]
304
- },
305
- {
306
- "cell_type": "code",
307
- "source": [
308
- "# Let's now split up the data into train and validation sets\n",
309
- "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
310
- "train_data = data[:n]\n",
311
- "val_data = data[n:]"
312
- ],
313
- "metadata": {
314
- "id": "f_WIXqxz0lU5"
315
- },
316
- "execution_count": 10,
317
- "outputs": []
318
- },
319
- {
320
- "cell_type": "code",
321
- "source": [
322
- "block_size = 8\n",
323
- "train_data[:block_size+1]"
324
- ],
325
- "metadata": {
326
- "colab": {
327
- "base_uri": "https://localhost:8080/"
328
- },
329
- "id": "TD5Bj8Y6IAD4",
330
- "outputId": "44a45420-f035-40e7-a089-7685ca25d361"
331
- },
332
- "execution_count": 11,
333
- "outputs": [
334
- {
335
- "output_type": "execute_result",
336
- "data": {
337
- "text/plain": [
338
- "tensor([18, 47, 56, 57, 58, 1, 15, 47, 58])"
339
- ]
340
- },
341
- "metadata": {},
342
- "execution_count": 11
343
- }
344
- ]
345
- },
346
- {
347
- "cell_type": "code",
348
- "source": [
349
- "x = train_data[:block_size]\n",
350
- "y = train_data[1:block_size+1]\n",
351
- "for t in range(block_size):\n",
352
- " context = x[:t+1]\n",
353
- " target = y[t]\n",
354
- " print(f\"when input is {context} the target: {target}\")"
355
- ],
356
- "metadata": {
357
- "colab": {
358
- "base_uri": "https://localhost:8080/"
359
- },
360
- "id": "9HXDe8vGJCEn",
361
- "outputId": "96af3b4e-7307-4949-c0f9-05c892514196"
362
- },
363
- "execution_count": 12,
364
- "outputs": [
365
- {
366
- "output_type": "stream",
367
- "name": "stdout",
368
- "text": [
369
- "when input is tensor([18]) the target: 47\n",
370
- "when input is tensor([18, 47]) the target: 56\n",
371
- "when input is tensor([18, 47, 56]) the target: 57\n",
372
- "when input is tensor([18, 47, 56, 57]) the target: 58\n",
373
- "when input is tensor([18, 47, 56, 57, 58]) the target: 1\n",
374
- "when input is tensor([18, 47, 56, 57, 58, 1]) the target: 15\n",
375
- "when input is tensor([18, 47, 56, 57, 58, 1, 15]) the target: 47\n",
376
- "when input is tensor([18, 47, 56, 57, 58, 1, 15, 47]) the target: 58\n"
377
- ]
378
- }
379
- ]
380
- },
381
- {
382
- "cell_type": "code",
383
- "source": [
384
- "torch.manual_seed(1337)\n",
385
- "batch_size = 4 # how many independent sequences will we process in parallel?\n",
386
- "block_size = 8 # what is the maximum context length for predictions?\n",
387
- "\n",
388
- "def get_batch(split):\n",
389
- " # generate a small batch of data of inputs x and targets y\n",
390
- " data = train_data if split == 'train' else val_data\n",
391
- " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
392
- " x = torch.stack([data[i:i+block_size] for i in ix])\n",
393
- " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
394
- " return x, y\n",
395
- "\n",
396
- "xb, yb = get_batch('train')\n",
397
- "print('inputs:')\n",
398
- "print(xb.shape)\n",
399
- "print(xb)\n",
400
- "print('targets:')\n",
401
- "print(yb.shape)\n",
402
- "print(yb)\n",
403
- "\n",
404
- "print('----')\n",
405
- "\n",
406
- "for b in range(batch_size): # batch dimension\n",
407
- " for t in range(block_size): # time dimension\n",
408
- " context = xb[b, :t+1]\n",
409
- " target = yb[b,t]\n",
410
- " print(f\"when input is {context.tolist()} the target: {target}\")"
411
- ],
412
- "metadata": {
413
- "colab": {
414
- "base_uri": "https://localhost:8080/"
415
- },
416
- "id": "Q3k1Czf7LuA9",
417
- "outputId": "e7e206dc-1cae-4f95-a82d-5faa6fd1c627"
418
- },
419
- "execution_count": 13,
420
- "outputs": [
421
- {
422
- "output_type": "stream",
423
- "name": "stdout",
424
- "text": [
425
- "inputs:\n",
426
- "torch.Size([4, 8])\n",
427
- "tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
428
- " [44, 53, 56, 1, 58, 46, 39, 58],\n",
429
- " [52, 58, 1, 58, 46, 39, 58, 1],\n",
430
- " [25, 17, 27, 10, 0, 21, 1, 54]])\n",
431
- "targets:\n",
432
- "torch.Size([4, 8])\n",
433
- "tensor([[43, 58, 5, 57, 1, 46, 43, 39],\n",
434
- " [53, 56, 1, 58, 46, 39, 58, 1],\n",
435
- " [58, 1, 58, 46, 39, 58, 1, 46],\n",
436
- " [17, 27, 10, 0, 21, 1, 54, 39]])\n",
437
- "----\n",
438
- "when input is [24] the target: 43\n",
439
- "when input is [24, 43] the target: 58\n",
440
- "when input is [24, 43, 58] the target: 5\n",
441
- "when input is [24, 43, 58, 5] the target: 57\n",
442
- "when input is [24, 43, 58, 5, 57] the target: 1\n",
443
- "when input is [24, 43, 58, 5, 57, 1] the target: 46\n",
444
- "when input is [24, 43, 58, 5, 57, 1, 46] the target: 43\n",
445
- "when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39\n",
446
- "when input is [44] the target: 53\n",
447
- "when input is [44, 53] the target: 56\n",
448
- "when input is [44, 53, 56] the target: 1\n",
449
- "when input is [44, 53, 56, 1] the target: 58\n",
450
- "when input is [44, 53, 56, 1, 58] the target: 46\n",
451
- "when input is [44, 53, 56, 1, 58, 46] the target: 39\n",
452
- "when input is [44, 53, 56, 1, 58, 46, 39] the target: 58\n",
453
- "when input is [44, 53, 56, 1, 58, 46, 39, 58] the target: 1\n",
454
- "when input is [52] the target: 58\n",
455
- "when input is [52, 58] the target: 1\n",
456
- "when input is [52, 58, 1] the target: 58\n",
457
- "when input is [52, 58, 1, 58] the target: 46\n",
458
- "when input is [52, 58, 1, 58, 46] the target: 39\n",
459
- "when input is [52, 58, 1, 58, 46, 39] the target: 58\n",
460
- "when input is [52, 58, 1, 58, 46, 39, 58] the target: 1\n",
461
- "when input is [52, 58, 1, 58, 46, 39, 58, 1] the target: 46\n",
462
- "when input is [25] the target: 17\n",
463
- "when input is [25, 17] the target: 27\n",
464
- "when input is [25, 17, 27] the target: 10\n",
465
- "when input is [25, 17, 27, 10] the target: 0\n",
466
- "when input is [25, 17, 27, 10, 0] the target: 21\n",
467
- "when input is [25, 17, 27, 10, 0, 21] the target: 1\n",
468
- "when input is [25, 17, 27, 10, 0, 21, 1] the target: 54\n",
469
- "when input is [25, 17, 27, 10, 0, 21, 1, 54] the target: 39\n"
470
- ]
471
- }
472
- ]
473
- },
474
- {
475
- "cell_type": "code",
476
- "source": [
477
- "print(xb) # our input to the transformer"
478
- ],
479
- "metadata": {
480
- "colab": {
481
- "base_uri": "https://localhost:8080/"
482
- },
483
- "id": "qpyyAeIzQjlO",
484
- "outputId": "febd3181-36c8-4567-f33c-dbfc4cbc99d5"
485
- },
486
- "execution_count": 14,
487
- "outputs": [
488
- {
489
- "output_type": "stream",
490
- "name": "stdout",
491
- "text": [
492
- "tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
493
- " [44, 53, 56, 1, 58, 46, 39, 58],\n",
494
- " [52, 58, 1, 58, 46, 39, 58, 1],\n",
495
- " [25, 17, 27, 10, 0, 21, 1, 54]])\n"
496
- ]
497
- }
498
- ]
499
- },
500
- {
501
- "cell_type": "code",
502
- "source": [
503
- "import torch\n",
504
- "import torch.nn as nn\n",
505
- "from torch.nn import functional as F\n",
506
- "torch.manual_seed(1337)\n",
507
- "\n",
508
- "class BigramLanguageModel(nn.Module):\n",
509
- "\n",
510
- " def __init__(self, vocab_size):\n",
511
- " super().__init__()\n",
512
- " # each token directly reads off the logits for the next token from a lookup table\n",
513
- " self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n",
514
- "\n",
515
- " def forward(self, idx, targets=None):\n",
516
- "\n",
517
- " # idx and targets are both (B,T) tensor of integers\n",
518
- " logits = self.token_embedding_table(idx) # (B,T,C)\n",
519
- "\n",
520
- " if targets is None:\n",
521
- " loss = None\n",
522
- " else:\n",
523
- " B, T, C = logits.shape\n",
524
- " logits = logits.view(B*T, C)\n",
525
- " targets = targets.view(B*T)\n",
526
- " loss = F.cross_entropy(logits, targets)\n",
527
- "\n",
528
- " return logits, loss\n",
529
- "\n",
530
- " def generate(self, idx, max_new_tokens):\n",
531
- " # idx is (B, T) array of indices in the current context\n",
532
- " for _ in range(max_new_tokens):\n",
533
- " # get the predictions\n",
534
- " logits, loss = self(idx)\n",
535
- " # focus only on the last time step\n",
536
- " logits = logits[:, -1, :] # becomes (B, C)\n",
537
- " # apply softmax to get probabilities\n",
538
- " probs = F.softmax(logits, dim=-1) # (B, C)\n",
539
- " # sample from the distribution\n",
540
- " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
541
- " # append sampled index to the running sequence\n",
542
- " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
543
- " return idx\n",
544
- "\n",
545
- "m = BigramLanguageModel(vocab_size)\n",
546
- "logits, loss = m(xb, yb)\n",
547
- "print(logits.shape)\n",
548
- "print(loss)\n",
549
- "\n",
550
- "print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))\n"
551
- ],
552
- "metadata": {
553
- "colab": {
554
- "base_uri": "https://localhost:8080/"
555
- },
556
- "id": "nql_1ER53oCf",
557
- "outputId": "7b1620c9-3bf2-45a2-8e08-d6ca73d09528"
558
- },
559
- "execution_count": 15,
560
- "outputs": [
561
- {
562
- "output_type": "stream",
563
- "name": "stdout",
564
- "text": [
565
- "torch.Size([32, 65])\n",
566
- "tensor(4.8786, grad_fn=<NllLossBackward0>)\n",
567
- "\n",
568
- "Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3\n"
569
- ]
570
- }
571
- ]
572
- },
573
- {
574
- "cell_type": "code",
575
- "source": [
576
- "# create a PyTorch optimizer\n",
577
- "optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)"
578
- ],
579
- "metadata": {
580
- "id": "eTyJ8qAaDdiF"
581
- },
582
- "execution_count": 16,
583
- "outputs": []
584
- },
585
- {
586
- "cell_type": "code",
587
- "source": [
588
- "batch_size = 32\n",
589
- "for steps in range(100): # increase number of steps for good results...\n",
590
- "\n",
591
- " # sample a batch of data\n",
592
- " xb, yb = get_batch('train')\n",
593
- "\n",
594
- " # evaluate the loss\n",
595
- " logits, loss = m(xb, yb)\n",
596
- " optimizer.zero_grad(set_to_none=True)\n",
597
- " loss.backward()\n",
598
- " optimizer.step()\n",
599
- "\n",
600
- "print(loss.item())\n"
601
- ],
602
- "metadata": {
603
- "colab": {
604
- "base_uri": "https://localhost:8080/"
605
- },
606
- "id": "Hs4kI8YdEkQj",
607
- "outputId": "234b1d99-e1d5-4394-ca9a-964027301d48"
608
- },
609
- "execution_count": 17,
610
- "outputs": [
611
- {
612
- "output_type": "stream",
613
- "name": "stdout",
614
- "text": [
615
- "4.587916374206543\n"
616
- ]
617
- }
618
- ]
619
- },
620
- {
621
- "cell_type": "code",
622
- "source": [
623
- "print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))"
624
- ],
625
- "metadata": {
626
- "colab": {
627
- "base_uri": "https://localhost:8080/"
628
- },
629
- "id": "EcVIDWAZEtjN",
630
- "outputId": "13e7e5a8-e382-4610-aecb-ce274d466533"
631
- },
632
- "execution_count": 18,
633
- "outputs": [
634
- {
635
- "output_type": "stream",
636
- "name": "stdout",
637
- "text": [
638
- "\n",
639
- "xiKi-RJ:CgqVuUa!U?qMH.uk!sCuMXvv!CJFfx;LgRyJknOEti.?I&-gPlLyulId?XlaInQ'q,lT$\n",
640
- "3Q&sGlvHQ?mqSq-eON\n",
641
- "x?SP fUAfCAuCX:bOlgiRQWN:Mphaw\n",
642
- "tRLKuYXEaAXxrcq-gCUzeh3w!AcyaylgYWjmJM?Uzw:inaY,:C&OECW:vmGGJAn3onAuMgia!ms$Vb q-gCOcPcUhOnxJGUGSPJWT:.?ujmJFoiNL&A'DxY,prZ?qdT;hoo'dHooXXlxf'WkHK&u3Q?rqUi.kz;?Yx?C&u3Qbfzxlyh'Vl:zyxjKXgC?\n",
643
- "lv'QKFiBeviNxO'm!Upm$srm&TqViqiBD3HBP!juEOpmZJyF$Fwfy!PlvWPFC\n",
644
- "&WDdP!Ko,px\n",
645
- "x\n",
646
- "tREOE;AJ.BeXkylOVD3KHp$e?nD,.SFbWWI'ubcL!q-tU;aXmJ&uGXHxJXI&Z!gHRpajj;l.\n",
647
- "pTErIBjx;JKIgoCnLGXrJSP!AU-AcbczR?\n"
648
- ]
649
- }
650
- ]
651
- },
652
- {
653
- "cell_type": "markdown",
654
- "source": [
655
- "## The mathematical trick in self-attention"
656
- ],
657
- "metadata": {
658
- "id": "XinV8nmAnmKN"
659
- }
660
- },
661
- {
662
- "cell_type": "code",
663
- "source": [
664
- "# toy example illustrating how matrix multiplication can be used for a \"weighted aggregation\"\n",
665
- "torch.manual_seed(42)\n",
666
- "a = torch.tril(torch.ones(3, 3))\n",
667
- "a = a / torch.sum(a, 1, keepdim=True)\n",
668
- "b = torch.randint(0,10,(3,2)).float()\n",
669
- "c = a @ b\n",
670
- "print('a=')\n",
671
- "print(a)\n",
672
- "print('--')\n",
673
- "print('b=')\n",
674
- "print(b)\n",
675
- "print('--')\n",
676
- "print('c=')\n",
677
- "print(c)"
678
- ],
679
- "metadata": {
680
- "colab": {
681
- "base_uri": "https://localhost:8080/"
682
- },
683
- "id": "tukiH-NbRBhA",
684
- "outputId": "4de5f70a-e12c-4c6a-d591-5d0720e9de8c"
685
- },
686
- "execution_count": 19,
687
- "outputs": [
688
- {
689
- "output_type": "stream",
690
- "name": "stdout",
691
- "text": [
692
- "a=\n",
693
- "tensor([[1.0000, 0.0000, 0.0000],\n",
694
- " [0.5000, 0.5000, 0.0000],\n",
695
- " [0.3333, 0.3333, 0.3333]])\n",
696
- "--\n",
697
- "b=\n",
698
- "tensor([[2., 7.],\n",
699
- " [6., 4.],\n",
700
- " [6., 5.]])\n",
701
- "--\n",
702
- "c=\n",
703
- "tensor([[2.0000, 7.0000],\n",
704
- " [4.0000, 5.5000],\n",
705
- " [4.6667, 5.3333]])\n"
706
- ]
707
- }
708
- ]
709
- },
710
- {
711
- "cell_type": "code",
712
- "source": [
713
- "# consider the following toy example:\n",
714
- "\n",
715
- "torch.manual_seed(1337)\n",
716
- "B,T,C = 4,8,2 # batch, time, channels\n",
717
- "x = torch.randn(B,T,C)\n",
718
- "x.shape"
719
- ],
720
- "metadata": {
721
- "colab": {
722
- "base_uri": "https://localhost:8080/"
723
- },
724
- "id": "Hs_E24uRE8kr",
725
- "outputId": "f1591218-d10f-420e-8d5a-456a0f90aed9"
726
- },
727
- "execution_count": 20,
728
- "outputs": [
729
- {
730
- "output_type": "execute_result",
731
- "data": {
732
- "text/plain": [
733
- "torch.Size([4, 8, 2])"
734
- ]
735
- },
736
- "metadata": {},
737
- "execution_count": 20
738
- }
739
- ]
740
- },
741
- {
742
- "cell_type": "code",
743
- "source": [
744
- "# We want x[b,t] = mean_{i<=t} x[b,i]\n",
745
- "xbow = torch.zeros((B,T,C))\n",
746
- "for b in range(B):\n",
747
- " for t in range(T):\n",
748
- " xprev = x[b,:t+1] # (t,C)\n",
749
- " xbow[b,t] = torch.mean(xprev, 0)\n"
750
- ],
751
- "metadata": {
752
- "id": "86NuXX0fn7ps"
753
- },
754
- "execution_count": 21,
755
- "outputs": []
756
- },
757
- {
758
- "cell_type": "code",
759
- "source": [
760
- "# version 2: using matrix multiply for a weighted aggregation\n",
761
- "wei = torch.tril(torch.ones(T, T))\n",
762
- "wei = wei / wei.sum(1, keepdim=True)\n",
763
- "xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)\n",
764
- "torch.allclose(xbow, xbow2)"
765
- ],
766
- "metadata": {
767
- "colab": {
768
- "base_uri": "https://localhost:8080/"
769
- },
770
- "id": "yhdOAd6-wXkZ",
771
- "outputId": "c7313d9b-d406-46ce-e2cd-f28c10ef41c2"
772
- },
773
- "execution_count": 22,
774
- "outputs": [
775
- {
776
- "output_type": "execute_result",
777
- "data": {
778
- "text/plain": [
779
- "False"
780
- ]
781
- },
782
- "metadata": {},
783
- "execution_count": 22
784
- }
785
- ]
786
- },
787
- {
788
- "cell_type": "code",
789
- "source": [
790
- "# version 3: use Softmax\n",
791
- "tril = torch.tril(torch.ones(T, T))\n",
792
- "wei = torch.zeros((T,T))\n",
793
- "wei = wei.masked_fill(tril == 0, float('-inf'))\n",
794
- "wei = F.softmax(wei, dim=-1)\n",
795
- "xbow3 = wei @ x\n",
796
- "torch.allclose(xbow, xbow3)\n"
797
- ],
798
- "metadata": {
799
- "colab": {
800
- "base_uri": "https://localhost:8080/"
801
- },
802
- "id": "wOURrfG-ysoL",
803
- "outputId": "40a4a993-5a9b-419c-e558-b935fd843dbf"
804
- },
805
- "execution_count": 23,
806
- "outputs": [
807
- {
808
- "output_type": "execute_result",
809
- "data": {
810
- "text/plain": [
811
- "False"
812
- ]
813
- },
814
- "metadata": {},
815
- "execution_count": 23
816
- }
817
- ]
818
- },
819
- {
820
- "cell_type": "code",
821
- "source": [
822
- "# version 4: self-attention!\n",
823
- "torch.manual_seed(1337)\n",
824
- "B,T,C = 4,8,32 # batch, time, channels\n",
825
- "x = torch.randn(B,T,C)\n",
826
- "\n",
827
- "# let's see a single Head perform self-attention\n",
828
- "head_size = 16\n",
829
- "key = nn.Linear(C, head_size, bias=False)\n",
830
- "query = nn.Linear(C, head_size, bias=False)\n",
831
- "value = nn.Linear(C, head_size, bias=False)\n",
832
- "k = key(x) # (B, T, 16)\n",
833
- "q = query(x) # (B, T, 16)\n",
834
- "wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)\n",
835
- "\n",
836
- "tril = torch.tril(torch.ones(T, T))\n",
837
- "#wei = torch.zeros((T,T))\n",
838
- "wei = wei.masked_fill(tril == 0, float('-inf'))\n",
839
- "wei = F.softmax(wei, dim=-1)\n",
840
- "\n",
841
- "v = value(x)\n",
842
- "out = wei @ v\n",
843
- "#out = wei @ x\n",
844
- "\n",
845
- "out.shape"
846
- ],
847
- "metadata": {
848
- "colab": {
849
- "base_uri": "https://localhost:8080/"
850
- },
851
- "id": "EDarxEWIRMKq",
852
- "outputId": "6fee2aa4-4ab6-4d89-c8ca-7463ee54962b"
853
- },
854
- "execution_count": 24,
855
- "outputs": [
856
- {
857
- "output_type": "execute_result",
858
- "data": {
859
- "text/plain": [
860
- "torch.Size([4, 8, 16])"
861
- ]
862
- },
863
- "metadata": {},
864
- "execution_count": 24
865
- }
866
- ]
867
- },
868
- {
869
- "cell_type": "code",
870
- "source": [
871
- "wei[0]"
872
- ],
873
- "metadata": {
874
- "colab": {
875
- "base_uri": "https://localhost:8080/"
876
- },
877
- "id": "vT1hdtzXCjgL",
878
- "outputId": "c664020c-c9dd-4c85-84a4-fae0320453f8"
879
- },
880
- "execution_count": 25,
881
- "outputs": [
882
- {
883
- "output_type": "execute_result",
884
- "data": {
885
- "text/plain": [
886
- "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
887
- " [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
888
- " [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
889
- " [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],\n",
890
- " [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],\n",
891
- " [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],\n",
892
- " [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],\n",
893
- " [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],\n",
894
- " grad_fn=<SelectBackward0>)"
895
- ]
896
- },
897
- "metadata": {},
898
- "execution_count": 25
899
- }
900
- ]
901
- },
902
- {
903
- "cell_type": "markdown",
904
- "source": [
905
- "Notes:\n",
906
- "- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.\n",
907
- "- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.\n",
908
- "- Each example across batch dimension is of course processed completely independently and never \"talk\" to each other\n",
909
- "- In an \"encoder\" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a \"decoder\" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.\n",
910
- "- \"self-attention\" just means that the keys and values are produced from the same source as queries. In \"cross-attention\", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)\n",
911
- "- \"Scaled\" attention additional divides `wei` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below"
912
- ],
913
- "metadata": {
914
- "id": "M5CvobiQ0pLr"
915
- }
916
- },
917
- {
918
- "cell_type": "code",
919
- "source": [
920
- "k = torch.randn(B,T,head_size)\n",
921
- "q = torch.randn(B,T,head_size)\n",
922
- "wei = q @ k.transpose(-2, -1) * head_size**-0.5"
923
- ],
924
- "metadata": {
925
- "id": "4SNbLq5z3oBw"
926
- },
927
- "execution_count": 26,
928
- "outputs": []
929
- },
930
- {
931
- "cell_type": "code",
932
- "source": [
933
- "k.var()"
934
- ],
935
- "metadata": {
936
- "colab": {
937
- "base_uri": "https://localhost:8080/"
938
- },
939
- "id": "Nl6I9n9IRTSo",
940
- "outputId": "162aab09-b860-4b73-c0ae-394451367460"
941
- },
942
- "execution_count": 27,
943
- "outputs": [
944
- {
945
- "output_type": "execute_result",
946
- "data": {
947
- "text/plain": [
948
- "tensor(1.0449)"
949
- ]
950
- },
951
- "metadata": {},
952
- "execution_count": 27
953
- }
954
- ]
955
- },
956
- {
957
- "cell_type": "code",
958
- "source": [
959
- "q.var()"
960
- ],
961
- "metadata": {
962
- "colab": {
963
- "base_uri": "https://localhost:8080/"
964
- },
965
- "id": "T1tQx7oeRvtc",
966
- "outputId": "20aacd2d-d414-4268-981e-86a5fd8afcc8"
967
- },
968
- "execution_count": 28,
969
- "outputs": [
970
- {
971
- "output_type": "execute_result",
972
- "data": {
973
- "text/plain": [
974
- "tensor(1.0700)"
975
- ]
976
- },
977
- "metadata": {},
978
- "execution_count": 28
979
- }
980
- ]
981
- },
982
- {
983
- "cell_type": "code",
984
- "source": [
985
- "wei.var()"
986
- ],
987
- "metadata": {
988
- "colab": {
989
- "base_uri": "https://localhost:8080/"
990
- },
991
- "id": "MLb_odHU3iKM",
992
- "outputId": "5d6ca0fd-51df-42ec-daf8-7fb2ff9f640f"
993
- },
994
- "execution_count": 29,
995
- "outputs": [
996
- {
997
- "output_type": "execute_result",
998
- "data": {
999
- "text/plain": [
1000
- "tensor(1.0918)"
1001
- ]
1002
- },
1003
- "metadata": {},
1004
- "execution_count": 29
1005
- }
1006
- ]
1007
- },
1008
- {
1009
- "cell_type": "code",
1010
- "source": [
1011
- "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)"
1012
- ],
1013
- "metadata": {
1014
- "colab": {
1015
- "base_uri": "https://localhost:8080/"
1016
- },
1017
- "id": "JB82yzt44REI",
1018
- "outputId": "df0211e7-a2b0-46c7-9fd2-c5a8cc185ed7"
1019
- },
1020
- "execution_count": 30,
1021
- "outputs": [
1022
- {
1023
- "output_type": "execute_result",
1024
- "data": {
1025
- "text/plain": [
1026
- "tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])"
1027
- ]
1028
- },
1029
- "metadata": {},
1030
- "execution_count": 30
1031
- }
1032
- ]
1033
- },
1034
- {
1035
- "cell_type": "code",
1036
- "source": [
1037
- "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # gets too peaky, converges to one-hot"
1038
- ],
1039
- "metadata": {
1040
- "colab": {
1041
- "base_uri": "https://localhost:8080/"
1042
- },
1043
- "id": "Mpt8569BB9_f",
1044
- "outputId": "cf991a1e-7072-4944-d578-886a270f57de"
1045
- },
1046
- "execution_count": 31,
1047
- "outputs": [
1048
- {
1049
- "output_type": "execute_result",
1050
- "data": {
1051
- "text/plain": [
1052
- "tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])"
1053
- ]
1054
- },
1055
- "metadata": {},
1056
- "execution_count": 31
1057
- }
1058
- ]
1059
- },
1060
- {
1061
- "cell_type": "code",
1062
- "source": [
1063
- "class LayerNorm1d: # (used to be BatchNorm1d)\n",
1064
- "\n",
1065
- " def __init__(self, dim, eps=1e-5, momentum=0.1):\n",
1066
- " self.eps = eps\n",
1067
- " self.gamma = torch.ones(dim)\n",
1068
- " self.beta = torch.zeros(dim)\n",
1069
- "\n",
1070
- " def __call__(self, x):\n",
1071
- " # calculate the forward pass\n",
1072
- " xmean = x.mean(1, keepdim=True) # batch mean\n",
1073
- " xvar = x.var(1, keepdim=True) # batch variance\n",
1074
- " xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance\n",
1075
- " self.out = self.gamma * xhat + self.beta\n",
1076
- " return self.out\n",
1077
- "\n",
1078
- " def parameters(self):\n",
1079
- " return [self.gamma, self.beta]\n",
1080
- "\n",
1081
- "torch.manual_seed(1337)\n",
1082
- "module = LayerNorm1d(100)\n",
1083
- "x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors\n",
1084
- "x = module(x)\n",
1085
- "x.shape"
1086
- ],
1087
- "metadata": {
1088
- "colab": {
1089
- "base_uri": "https://localhost:8080/"
1090
- },
1091
- "id": "2Num7sX9CKOH",
1092
- "outputId": "14c48660-c741-4cb8-ac79-53d2bf094a63"
1093
- },
1094
- "execution_count": 32,
1095
- "outputs": [
1096
- {
1097
- "output_type": "execute_result",
1098
- "data": {
1099
- "text/plain": [
1100
- "torch.Size([32, 100])"
1101
- ]
1102
- },
1103
- "metadata": {},
1104
- "execution_count": 32
1105
- }
1106
- ]
1107
- },
1108
- {
1109
- "cell_type": "code",
1110
- "source": [
1111
- "x[:,0].mean(), x[:,0].std() # mean,std of one feature across all batch inputs"
1112
- ],
1113
- "metadata": {
1114
- "colab": {
1115
- "base_uri": "https://localhost:8080/"
1116
- },
1117
- "id": "633T2cmnW1uk",
1118
- "outputId": "2a6e887c-6b82-454f-8f32-aefde73777c5"
1119
- },
1120
- "execution_count": 33,
1121
- "outputs": [
1122
- {
1123
- "output_type": "execute_result",
1124
- "data": {
1125
- "text/plain": [
1126
- "(tensor(0.1469), tensor(0.8803))"
1127
- ]
1128
- },
1129
- "metadata": {},
1130
- "execution_count": 33
1131
- }
1132
- ]
1133
- },
1134
- {
1135
- "cell_type": "code",
1136
- "source": [
1137
- "x[0,:].mean(), x[0,:].std() # mean,std of a single input from the batch, of its features"
1138
- ],
1139
- "metadata": {
1140
- "colab": {
1141
- "base_uri": "https://localhost:8080/"
1142
- },
1143
- "id": "LN9cK9BoXCYb",
1144
- "outputId": "4c81f68e-b1d2-4a04-d38d-09583f104ea7"
1145
- },
1146
- "execution_count": 34,
1147
- "outputs": [
1148
- {
1149
- "output_type": "execute_result",
1150
- "data": {
1151
- "text/plain": [
1152
- "(tensor(-9.5367e-09), tensor(1.0000))"
1153
- ]
1154
- },
1155
- "metadata": {},
1156
- "execution_count": 34
1157
- }
1158
- ]
1159
- },
1160
- {
1161
- "cell_type": "code",
1162
- "source": [
1163
- "# French to English translation example:\n",
1164
- "\n",
1165
- "# <--------- ENCODE ------------------><--------------- DECODE ----------------->\n",
1166
- "# les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>\n",
1167
- "\n"
1168
- ],
1169
- "metadata": {
1170
- "id": "dRJH6wM_XFfU"
1171
- },
1172
- "execution_count": 35,
1173
- "outputs": []
1174
- },
1175
- {
1176
- "cell_type": "markdown",
1177
- "source": [
1178
- "### Full finished code, for reference\n",
1179
- "\n",
1180
- "You may want to refer directly to the git repo instead though."
1181
- ],
1182
- "metadata": {
1183
- "id": "ZcvKeBXoZFOY"
1184
- }
1185
- },
1186
- {
1187
- "cell_type": "code",
1188
- "source": [
1189
- "import torch\n",
1190
- "import torch.nn as nn\n",
1191
- "from torch.nn import functional as F\n",
1192
- "\n",
1193
- "# hyperparameters\n",
1194
- "batch_size = 16 # how many independent sequences will we process in parallel?\n",
1195
- "block_size = 32 # what is the maximum context length for predictions?\n",
1196
- "max_iters = 5000\n",
1197
- "#00\n",
1198
- "eval_interval = 100\n",
1199
- "learning_rate = 1e-3\n",
1200
- "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
1201
- "eval_iters = 200\n",
1202
- "n_embd = 64\n",
1203
- "n_head = 4\n",
1204
- "n_layer = 4\n",
1205
- "dropout = 0.0\n",
1206
- "# ------------\n",
1207
- "\n",
1208
- "torch.manual_seed(1337)\n",
1209
- "\n",
1210
- "# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
1211
- "with open('input.txt', 'r', encoding='utf-8') as f:\n",
1212
- " text = f.read()\n",
1213
- "\n",
1214
- "# here are all the unique characters that occur in this text\n",
1215
- "chars = sorted(list(set(text)))\n",
1216
- "vocab_size = len(chars)\n",
1217
- "# create a mapping from characters to integers\n",
1218
- "stoi = { ch:i for i,ch in enumerate(chars) }\n",
1219
- "itos = { i:ch for i,ch in enumerate(chars) }\n",
1220
- "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n",
1221
- "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n",
1222
- "\n",
1223
- "# Train and test splits\n",
1224
- "data = torch.tensor(encode(text), dtype=torch.long)\n",
1225
- "n = int(0.9*len(data)) # first 90% will be train, rest val\n",
1226
- "train_data = data[:n]\n",
1227
- "val_data = data[n:]\n",
1228
- "\n",
1229
- "# data loading\n",
1230
- "def get_batch(split):\n",
1231
- " # generate a small batch of data of inputs x and targets y\n",
1232
- " data = train_data if split == 'train' else val_data\n",
1233
- " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
1234
- " x = torch.stack([data[i:i+block_size] for i in ix])\n",
1235
- " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
1236
- " x, y = x.to(device), y.to(device)\n",
1237
- " return x, y\n",
1238
- "\n",
1239
- "@torch.no_grad()\n",
1240
- "def estimate_loss():\n",
1241
- " out = {}\n",
1242
- " model.eval()\n",
1243
- " for split in ['train', 'val']:\n",
1244
- " losses = torch.zeros(eval_iters)\n",
1245
- " for k in range(eval_iters):\n",
1246
- " X, Y = get_batch(split)\n",
1247
- " logits, loss = model(X, Y)\n",
1248
- " losses[k] = loss.item()\n",
1249
- " out[split] = losses.mean()\n",
1250
- " model.train()\n",
1251
- " return out\n",
1252
- "\n",
1253
- "class Head(nn.Module):\n",
1254
- " \"\"\" one head of self-attention \"\"\"\n",
1255
- "\n",
1256
- " def __init__(self, head_size):\n",
1257
- " super().__init__()\n",
1258
- " self.key = nn.Linear(n_embd, head_size, bias=False)\n",
1259
- " self.query = nn.Linear(n_embd, head_size, bias=False)\n",
1260
- " self.value = nn.Linear(n_embd, head_size, bias=False)\n",
1261
- " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
1262
- "\n",
1263
- " self.dropout = nn.Dropout(dropout)\n",
1264
- "\n",
1265
- " def forward(self, x):\n",
1266
- " B,T,C = x.shape\n",
1267
- " k = self.key(x) # (B,T,C)\n",
1268
- " q = self.query(x) # (B,T,C)\n",
1269
- " # compute attention scores (\"affinities\")\n",
1270
- " wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)\n",
1271
- " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
1272
- " wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
1273
- " wei = self.dropout(wei)\n",
1274
- " # perform the weighted aggregation of the values\n",
1275
- " v = self.value(x) # (B,T,C)\n",
1276
- " out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)\n",
1277
- " return out\n",
1278
- "\n",
1279
- "class MultiHeadAttention(nn.Module):\n",
1280
- " \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
1281
- "\n",
1282
- " def __init__(self, num_heads, head_size):\n",
1283
- " super().__init__()\n",
1284
- " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
1285
- " self.proj = nn.Linear(n_embd, n_embd)\n",
1286
- " self.dropout = nn.Dropout(dropout)\n",
1287
- "\n",
1288
- " def forward(self, x):\n",
1289
- " out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
1290
- " out = self.dropout(self.proj(out))\n",
1291
- " return out\n",
1292
- "\n",
1293
- "class FeedFoward(nn.Module):\n",
1294
- " \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n",
1295
- "\n",
1296
- " def __init__(self, n_embd):\n",
1297
- " super().__init__()\n",
1298
- " self.net = nn.Sequential(\n",
1299
- " nn.Linear(n_embd, 4 * n_embd),\n",
1300
- " nn.ReLU(),\n",
1301
- " nn.Linear(4 * n_embd, n_embd),\n",
1302
- " nn.Dropout(dropout),\n",
1303
- " )\n",
1304
- "\n",
1305
- " def forward(self, x):\n",
1306
- " return self.net(x)\n",
1307
- "\n",
1308
- "class Block(nn.Module):\n",
1309
- " \"\"\" Transformer block: communication followed by computation \"\"\"\n",
1310
- "\n",
1311
- " def __init__(self, n_embd, n_head):\n",
1312
- " # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
1313
- " super().__init__()\n",
1314
- " head_size = n_embd // n_head\n",
1315
- " self.sa = MultiHeadAttention(n_head, head_size)\n",
1316
- " self.ffwd = FeedFoward(n_embd)\n",
1317
- " self.ln1 = nn.LayerNorm(n_embd)\n",
1318
- " self.ln2 = nn.LayerNorm(n_embd)\n",
1319
- "\n",
1320
- " def forward(self, x):\n",
1321
- " x = x + self.sa(self.ln1(x))\n",
1322
- " x = x + self.ffwd(self.ln2(x))\n",
1323
- " return x\n",
1324
- "\n",
1325
- "# super simple bigram model\n",
1326
- "class BigramLanguageModel(nn.Module):\n",
1327
- "\n",
1328
- " def __init__(self):\n",
1329
- " super().__init__()\n",
1330
- " # each token directly reads off the logits for the next token from a lookup table\n",
1331
- " self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
1332
- " self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
1333
- " self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n",
1334
- " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
1335
- " self.lm_head = nn.Linear(n_embd, vocab_size)\n",
1336
- "\n",
1337
- " def forward(self, idx, targets=None):\n",
1338
- " B, T = idx.shape\n",
1339
- "\n",
1340
- " # idx and targets are both (B,T) tensor of integers\n",
1341
- " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n",
1342
- " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
1343
- " x = tok_emb + pos_emb # (B,T,C)\n",
1344
- " x = self.blocks(x) # (B,T,C)\n",
1345
- " x = self.ln_f(x) # (B,T,C)\n",
1346
- " logits = self.lm_head(x) # (B,T,vocab_size)\n",
1347
- "\n",
1348
- " if targets is None:\n",
1349
- " loss = None\n",
1350
- " else:\n",
1351
- " B, T, C = logits.shape\n",
1352
- " logits = logits.view(B*T, C)\n",
1353
- " targets = targets.view(B*T)\n",
1354
- " loss = F.cross_entropy(logits, targets)\n",
1355
- "\n",
1356
- " return logits, loss\n",
1357
- "\n",
1358
- " def generate(self, idx, max_new_tokens):\n",
1359
- " # idx is (B, T) array of indices in the current context\n",
1360
- " for _ in range(max_new_tokens):\n",
1361
- " # crop idx to the last block_size tokens\n",
1362
- " idx_cond = idx[:, -block_size:]\n",
1363
- " # get the predictions\n",
1364
- " logits, loss = self(idx_cond)\n",
1365
- " # focus only on the last time step\n",
1366
- " logits = logits[:, -1, :] # becomes (B, C)\n",
1367
- " # apply softmax to get probabilities\n",
1368
- " probs = F.softmax(logits, dim=-1) # (B, C)\n",
1369
- " # sample from the distribution\n",
1370
- " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
1371
- " # append sampled index to the running sequence\n",
1372
- " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
1373
- " return idx\n",
1374
- "\n",
1375
- "model = BigramLanguageModel()\n",
1376
- "m = model.to(device)\n",
1377
- "# print the number of parameters in the model\n",
1378
- "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n",
1379
- "\n",
1380
- "# create a PyTorch optimizer\n",
1381
- "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
1382
- "\n",
1383
- "for iter in range(max_iters):\n",
1384
- "\n",
1385
- " # every once in a while evaluate the loss on train and val sets\n",
1386
- " if iter % eval_interval == 0 or iter == max_iters - 1:\n",
1387
- " losses = estimate_loss()\n",
1388
- " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
1389
- "\n",
1390
- " # sample a batch of data\n",
1391
- " xb, yb = get_batch('train')\n",
1392
- "\n",
1393
- " # evaluate the loss\n",
1394
- " logits, loss = model(xb, yb)\n",
1395
- " optimizer.zero_grad(set_to_none=True)\n",
1396
- " loss.backward()\n",
1397
- " optimizer.step()\n",
1398
- "\n",
1399
- "# generate from the model\n",
1400
- "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n",
1401
- "print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))\n"
1402
- ],
1403
- "metadata": {
1404
- "colab": {
1405
- "base_uri": "https://localhost:8080/"
1406
- },
1407
- "id": "hoelkOrFY8bN",
1408
- "outputId": "4f7e6e13-879e-469d-dcdb-0d3c48e263c5"
1409
- },
1410
- "execution_count": 37,
1411
- "outputs": [
1412
- {
1413
- "output_type": "stream",
1414
- "name": "stdout",
1415
- "text": [
1416
- "0.209729 M parameters\n",
1417
- "step 0: train loss 4.4116, val loss 4.4022\n",
1418
- "step 100: train loss 2.6568, val loss 2.6670\n",
1419
- "step 200: train loss 2.5090, val loss 2.5059\n",
1420
- "step 300: train loss 2.4196, val loss 2.4338\n",
1421
- "step 400: train loss 2.3504, val loss 2.3566\n",
1422
- "step 500: train loss 2.2965, val loss 2.3129\n",
1423
- "step 600: train loss 2.2410, val loss 2.2500\n",
1424
- "step 700: train loss 2.2057, val loss 2.2191\n",
1425
- "step 800: train loss 2.1633, val loss 2.1864\n",
1426
- "step 900: train loss 2.1244, val loss 2.1510\n",
1427
- "step 1000: train loss 2.1038, val loss 2.1308\n",
1428
- "step 1100: train loss 2.0707, val loss 2.1197\n",
1429
- "step 1200: train loss 2.0377, val loss 2.0800\n",
1430
- "step 1300: train loss 2.0268, val loss 2.0650\n",
1431
- "step 1400: train loss 1.9918, val loss 2.0356\n",
1432
- "step 1500: train loss 1.9697, val loss 2.0293\n",
1433
- "step 1600: train loss 1.9645, val loss 2.0499\n",
1434
- "step 1700: train loss 1.9404, val loss 2.0129\n",
1435
- "step 1800: train loss 1.9095, val loss 1.9951\n",
1436
- "step 1900: train loss 1.9067, val loss 1.9855\n",
1437
- "step 2000: train loss 1.8854, val loss 1.9948\n",
1438
- "step 2100: train loss 1.8727, val loss 1.9766\n",
1439
- "step 2200: train loss 1.8597, val loss 1.9631\n",
1440
- "step 2300: train loss 1.8530, val loss 1.9516\n",
1441
- "step 2400: train loss 1.8428, val loss 1.9464\n",
1442
- "step 2500: train loss 1.8161, val loss 1.9424\n",
1443
- "step 2600: train loss 1.8283, val loss 1.9406\n",
1444
- "step 2700: train loss 1.8101, val loss 1.9322\n",
1445
- "step 2800: train loss 1.8050, val loss 1.9233\n",
1446
- "step 2900: train loss 1.8033, val loss 1.9289\n",
1447
- "step 3000: train loss 1.7955, val loss 1.9216\n",
1448
- "step 3100: train loss 1.7697, val loss 1.9184\n",
1449
- "step 3200: train loss 1.7541, val loss 1.9088\n",
1450
- "step 3300: train loss 1.7567, val loss 1.9034\n",
1451
- "step 3400: train loss 1.7573, val loss 1.9000\n",
1452
- "step 3500: train loss 1.7398, val loss 1.8925\n",
1453
- "step 3600: train loss 1.7270, val loss 1.8869\n",
1454
- "step 3700: train loss 1.7283, val loss 1.8814\n",
1455
- "step 3800: train loss 1.7210, val loss 1.8918\n",
1456
- "step 3900: train loss 1.7219, val loss 1.8732\n",
1457
- "step 4000: train loss 1.7146, val loss 1.8576\n",
1458
- "step 4100: train loss 1.7136, val loss 1.8720\n",
1459
- "step 4200: train loss 1.7060, val loss 1.8653\n",
1460
- "step 4300: train loss 1.7032, val loss 1.8499\n",
1461
- "step 4400: train loss 1.7057, val loss 1.8656\n",
1462
- "step 4500: train loss 1.6907, val loss 1.8477\n",
1463
- "step 4600: train loss 1.6878, val loss 1.8371\n",
1464
- "step 4700: train loss 1.6808, val loss 1.8415\n",
1465
- "step 4800: train loss 1.6689, val loss 1.8457\n",
1466
- "step 4900: train loss 1.6716, val loss 1.8415\n",
1467
- "step 4999: train loss 1.6658, val loss 1.8275\n",
1468
- "\n",
1469
- "ROTCUMER:\n",
1470
- "Tyburforth, bloody,\n",
1471
- "WhIs migute: you duke I use list. WIthon of where's grande will! savist tought!\n",
1472
- "Why room upwor alond, liegle. I hone, Iell thou sudd have then strue thus mind,\n",
1473
- "His by blow, Virdom tow, glingien, yithre spees ssince them Those not.\n",
1474
- "\n",
1475
- "LUCIO:\n",
1476
- "Look,----\n",
1477
- "But thou sging them this my freceimmsed,\n",
1478
- "By thou sovor conursion that thou sade but grove\n",
1479
- "the tage encond:\n",
1480
- "It will Rament me; an your touther,\n",
1481
- "And havis like to-does, and little spright.\n",
1482
- "\n",
1483
- "GLOUCESTER:\n",
1484
- "Rewards thou for Panfessira's bigguards such ways!\n",
1485
- "What curfort his\n",
1486
- "will havolss you, as I have the cervirs arled,\n",
1487
- "Dear my love and pitace unto duly son.\n",
1488
- "\n",
1489
- "Secome:\n",
1490
- "Offolk, even thy whose my late all that you by jotly us belies!\n",
1491
- "Lord, we a-montencry! I\n",
1492
- "\n",
1493
- "SLARNE:\n",
1494
- "Day, mave from out prrive And orculing\n",
1495
- "What confess, temimelyour and stropt;\n",
1496
- "Secumfospet the gatieus I'll that confence-sting,\n",
1497
- "But; man't, Rolget\n",
1498
- "would garnion'd live in which, you, prothre?\n",
1499
- "\n",
1500
- "CORIOLANUS:\n",
1501
- "What bonum stravoing, not out be seemmed with\n",
1502
- "That the boly noll to.\n",
1503
- "Bently, which in on my not tomberven why, fortune,\n",
1504
- "And that wark you, banot thus orl'ld groves viles.\n",
1505
- "\n",
1506
- "PUMNIUS:\n",
1507
- "It thou addow less, proth-straing.\n",
1508
- "Mutwing your contrant stomfe, whom they\n",
1509
- "is by this famestle; and of the loves my not Mercarcious to the stord; thesoo, in thus my nome are:\n",
1510
- "Will fuch, have there enplience your gone, ho's,\n",
1511
- "And gentleman, my beged lind to be am\n",
1512
- "in That ant:\n",
1513
- "In I sugner murded! I play's,\n",
1514
- "If not sume the confity will reasur slord:\n",
1515
- "That get because at that his say\n",
1516
- "and to beepts guarst you lom if then.\n",
1517
- "\n",
1518
- "MENEN MARGARUS:\n",
1519
- "I but aftelence! made yoour never.\n",
1520
- "\n",
1521
- "KING RICHARD II:\n",
1522
- "Who too near?\n",
1523
- "\n",
1524
- "LORDIUS:\n",
1525
- "Or as madaw brird, tou thee?\n",
1526
- "\n",
1527
- "Sirightly the haste's beforempt.\n",
1528
- "\n",
1529
- "First:\n",
1530
- "Is though.\n",
1531
- "Fell, whose toes with requmpts, up I make\n",
1532
- "Here figUS verean that I will, by the wateon.\n",
1533
- "\n",
1534
- "MOWIDIUS:\n",
1535
- "How, while, more is in meep.\n",
1536
- "twan be the fless this countrens platcar merperter sure make Giventled,\n",
1537
- "At not your must to reason togs,\n",
1538
- "And what you gue;--\n",
1539
- "\n",
1540
- "RUKE ESFiren; gravent,\n",
1541
- "Apol\n"
1542
- ]
1543
- }
1544
- ]
1545
- },
1546
- {
1547
- "cell_type": "code",
1548
- "source": [],
1549
- "metadata": {
1550
- "id": "fjjvMifYZf7x"
1551
- },
1552
- "execution_count": 36,
1553
- "outputs": []
1554
- }
1555
- ]
1556
- }