ibrahimmkhalid committed
Commit 2ef951d · 1 Parent(s): 7554be3

add openwebtext based gpt model
.gitattributes ADDED
@@ -0,0 +1 @@
+openwebtext.tar.xz filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -3,3 +3,5 @@
 
 venv/
 .ipynb_checkpoints/
+
+openwebtext/
extract.py ADDED
@@ -0,0 +1,60 @@
+import os
+import lzma
+from tqdm import tqdm
+
+def xz_files_in_dir(directory):
+    files = []
+    for filename in os.listdir(directory):
+        if filename.endswith(".xz") and os.path.isfile(os.path.join(directory, filename)):
+            files.append(filename)
+    return files
+
+tarxz_path = "./openwebtext.tar.xz"
+folder_path = "./openwebtext"
+output_file_train = "./openwebtext/train_split.txt"
+output_file_val = "./openwebtext/val_split.txt"
+vocab_file = "./openwebtext/vocab.txt"
+
+# Extract the tar.xz file
+if not os.path.exists(folder_path):
+    os.mkdir(folder_path)
+    os.system(f"tar -xvf {tarxz_path}")
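+    # assumes the archive unpacks its .xz shards directly into ./openwebtext/;
+    # if tar extracts them elsewhere, point it at folder_path with -C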
+
+files = xz_files_in_dir(folder_path)
+total_files = len(files)
+
+# Calculate the split indices
+split_index = int(total_files * 0.9)  # 90% for training
+files_train = files[:split_index]
+files_val = files[split_index:]
+
+# Process the files for training and validation separately
+vocab = set()
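+# note: the vocab is only collected while the split files are being written,
+# so if the splits already exist but vocab.txt does not, vocab.txt comes out empty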
+
+# Process the training files
+if not os.path.exists(output_file_train):
+    with open(output_file_train, "w", encoding="utf-8") as outfile:
+        for filename in tqdm(files_train, total=len(files_train)):
+            file_path = os.path.join(folder_path, filename)
+            with lzma.open(file_path, "rt", encoding="utf-8") as infile:
+                text = infile.read()
+                outfile.write(text)
+                characters = set(text)
+                vocab.update(characters)
+
+# Process the validation files
+if not os.path.exists(output_file_val):
+    with open(output_file_val, "w", encoding="utf-8") as outfile:
+        for filename in tqdm(files_val, total=len(files_val)):
+            file_path = os.path.join(folder_path, filename)
+            with lzma.open(file_path, "rt", encoding="utf-8") as infile:
+                text = infile.read()
+                outfile.write(text)
+                characters = set(text)
+                vocab.update(characters)
+
+# Write the vocabulary to vocab.txt
+if not os.path.exists(vocab_file):
+    with open(vocab_file, "w", encoding="utf-8") as vfile:
+        for char in vocab:
+            vfile.write(char + '\n')
gpt_openwebtext.sync.ipynb ADDED
@@ -0,0 +1,407 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "b84a347c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torch.nn import functional as F\n",
+    "import mmap\n",
+    "import random\n",
+    "import pickle"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "058368c2",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "cuda\n"
+     ]
+    }
+   ],
+   "source": [
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "print(device)\n",
+    "block_size = 128\n",
+    "batch_size = 32\n",
+    "max_iters = 4000\n",
+    "learning_rate = 3e-4\n",
+    "eval_every = 500\n",
+    "n_embd = 384\n",
+    "n_head = 8\n",
+    "n_layer = 8\n",
+    "dropout = 0.2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "4ec3625c",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "FileNotFoundError",
+     "evalue": "[Errno 2] No such file or directory: 'openwebtext/vocab.txt'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m chars \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mopenwebtext/vocab.txt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mr\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mutf-8\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 3\u001b[0m text \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m.\u001b[39mread()\n\u001b[1;32m 4\u001b[0m chars \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msorted\u001b[39m(\u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mset\u001b[39m(text)))\n",
+      "File \u001b[0;32m~/repos/main/llm-from-scratch/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:310\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m 303\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m}:\n\u001b[1;32m 304\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 305\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIPython won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m by default \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 307\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myou can use builtins\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m open.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 308\u001b[0m )\n\u001b[0;32m--> 310\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'openwebtext/vocab.txt'"
+     ]
+    }
+   ],
+   "source": [
+    "chars = \"\"\n",
+    "with open(\"./openwebtext/vocab.txt\", 'r', encoding='utf-8') as f:\n",
+    "    text = f.read()\n",
+    "    chars = sorted(list(set(text)))\n",
+    " \n",
+    "vocab_size = len(chars)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "15e6af07",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(f\"Vocab size: {vocab_size}\")\n",
+    "print(f\"Text length: {len(text)}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "425bf0b5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "string_to_int = {ch: i for i, ch in enumerate(chars)}\n",
+    "int_to_string = {i: ch for i, ch in enumerate(chars)}\n",
+    "\n",
+    "encode = lambda s: [string_to_int[ch] for ch in s]\n",
+    "decode = lambda x: ''.join([int_to_string[i] for i in x])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1b141a3a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# memory map for using small snippets of text from a single file of any size\n",
+    "def get_random_chunk(split):\n",
+    "    filename = \"./openwebtext/train_split.txt\" if split == 'train' else \"./openwebtext/val_split.txt\"\n",
+    "    with open(filename, 'rb') as f:\n",
+    "        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:\n",
+    "            # Determine the file size and a random position to start reading\n",
+    "            file_size = len(mm)\n",
+    "            start_pos = random.randint(0, (file_size) - block_size*batch_size)\n",
+    "\n",
+    "            # Seek to the random position and read the block of text\n",
+    "            mm.seek(start_pos)\n",
+    "            block = mm.read(block_size*batch_size-1)\n",
+    "\n",
+    "            # Decode the block to a string, ignoring any invalid byte sequences\n",
+    "            decoded_block = block.decode('utf-8', errors='ignore').replace('\\r', '')\n",
+    " \n",
+    "            # Train and test splits\n",
+    "            data = torch.tensor(encode(decoded_block), dtype=torch.long)\n",
+    " \n",
+    "    return data\n",
+    "\n",
+    "\n",
+    "def get_batch(split):\n",
+    "    data = get_random_chunk(split)\n",
+    "    ix = torch.randint(len(data) - block_size, (batch_size,))\n",
+    "    x = torch.stack([data[i:i+block_size] for i in ix])\n",
+    "    y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
+    "    x, y = x.to(device), y.to(device)\n",
+    "    return x, y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4b27acf7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@torch.no_grad()\n",
+    "def estimate_loss():\n",
+    "    out = {}\n",
+    "    model.eval()\n",
+    "    for split in ['train', 'val']:\n",
+    "        losses = torch.zeros(eval_every)\n",
+    "        for k in range(eval_every):\n",
+    "            X, Y = get_batch(split)\n",
+    "            logits, loss = model(X, Y)\n",
+    "            losses[k] = loss.item()\n",
+    "        out[split] = losses.mean()\n",
+    "    model.train()\n",
+    "    return out"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "517553b5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "class Head(nn.Module):\n",
+    "    \"\"\" one head of self-attention \"\"\"\n",
+    "\n",
+    "    def __init__(self, head_size):\n",
+    "        super().__init__()\n",
+    "        self.key = nn.Linear(n_embd, head_size, bias=False)\n",
+    "        self.query = nn.Linear(n_embd, head_size, bias=False)\n",
+    "        self.value = nn.Linear(n_embd, head_size, bias=False)\n",
+    "        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
+    "\n",
+    "        self.dropout = nn.Dropout(dropout)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        # input of size (batch, time-step, channels)\n",
+    "        # output of size (batch, time-step, head size)\n",
+    "        B,T,C = x.shape\n",
+    "        k = self.key(x) # (B,T,hs)\n",
+    "        q = self.query(x) # (B,T,hs)\n",
+    "        # compute attention scores (\"affinities\")\n",
+    "        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n",
+    "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
+    "        wei = F.softmax(wei, dim=-1) # (B, T, T)\n",
+    "        wei = self.dropout(wei)\n",
+    "        # perform the weighted aggregation of the values\n",
+    "        v = self.value(x) # (B,T,hs)\n",
+    "        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)\n",
+    "        return out\n",
+    "\n",
+    "# [1, 0, 0]\n",
+    "# [1, 0.6, 0]\n",
+    "# [1, 0.6, 0.4]\n",
+    "class MultiHeadAttention(nn.Module):\n",
+    "    \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
+    "\n",
+    "    def __init__(self, num_heads, head_size):\n",
+    "        super().__init__()\n",
+    "        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
+    "        self.proj = nn.Linear(head_size * num_heads, n_embd)\n",
+    "        self.dropout = nn.Dropout(dropout)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, F) -> (B, T, [h1, h1, h1, h1, h2, h2, h2, h2, h3, h3, h3, h3])\n",
+    "        out = self.dropout(self.proj(out))\n",
+    "        return out\n",
+    " \n",
+    "\n",
+    "class FeedFoward(nn.Module):\n",
+    "    \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n",
+    "\n",
+    "    def __init__(self, n_embd):\n",
+    "        super().__init__()\n",
+    "        self.net = nn.Sequential(\n",
+    "            nn.Linear(n_embd, 4 * n_embd),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(4 * n_embd, n_embd),\n",
+    "            nn.Dropout(dropout),\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        return self.net(x)\n",
+    " \n",
+    "class Block(nn.Module):\n",
+    "    \"\"\" Transformer block: communication followed by computation \"\"\"\n",
+    "\n",
+    "    def __init__(self, n_embd, n_head):\n",
+    "        # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
+    "        super().__init__()\n",
+    "        head_size = n_embd // n_head\n",
+    "        self.sa = MultiHeadAttention(n_head, head_size)\n",
+    "        self.ffwd = FeedFoward(n_embd)\n",
+    "        self.ln1 = nn.LayerNorm(n_embd)\n",
+    "        self.ln2 = nn.LayerNorm(n_embd)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        y = self.sa(x)\n",
+    "        x = self.ln1(x + y)\n",
+    "        y = self.ffwd(x)\n",
+    "        x = self.ln2(x + y)\n",
+    "        return x\n",
+    " \n",
+    "class GPTLanguageModel(nn.Module):\n",
+    "    def __init__(self, vocab_size):\n",
+    "        super().__init__()\n",
+    "        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
+    "        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
+    "        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n",
+    "        self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
+    "        self.lm_head = nn.Linear(n_embd, vocab_size)\n",
+    " \n",
+    " \n",
+    "        self.apply(self._init_weights)\n",
+    "\n",
+    "    def _init_weights(self, module):\n",
+    "        if isinstance(module, nn.Linear):\n",
+    "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
+    "            if module.bias is not None:\n",
+    "                torch.nn.init.zeros_(module.bias)\n",
+    "        elif isinstance(module, nn.Embedding):\n",
+    "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
+    "\n",
+    "    def forward(self, index, targets=None):\n",
+    "        B, T = index.shape\n",
+    " \n",
+    " \n",
+    "        # idx and targets are both (B,T) tensor of integers\n",
+    "        tok_emb = self.token_embedding_table(index) # (B,T,C)\n",
+    "        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
+    "        x = tok_emb + pos_emb # (B,T,C)\n",
+    "        x = self.blocks(x) # (B,T,C)\n",
+    "        x = self.ln_f(x) # (B,T,C)\n",
+    "        logits = self.lm_head(x) # (B,T,vocab_size)\n",
+    " \n",
+    "        if targets is None:\n",
+    "            loss = None\n",
+    "        else:\n",
+    "            B, T, C = logits.shape\n",
+    "            logits = logits.view(B*T, C) # reshape to what torch.cross_entropy expects\n",
+    "            targets = targets.view(B*T)\n",
+    "            loss = F.cross_entropy(logits, targets)\n",
+    "        return logits, loss\n",
+    " \n",
+    "    def generate(self, index, max_new_tokens):\n",
+    "        # index is (B, T) array of indices in the current context\n",
+    "        for _ in range(max_new_tokens):\n",
+    "            # crop idx to the last block_size tokens\n",
+    "            index_cond = index[:, -block_size:]\n",
+    "            # get the predictions\n",
+    "            logits, loss = self.forward(index_cond)\n",
+    "            # focus only on the last time step\n",
+    "            logits = logits[:, -1, :] # becomes (B, C)\n",
+    "            # apply softmax to get probabilities\n",
+    "            probs = F.softmax(logits, dim=-1) # (B, C)\n",
+    "            # sample from the distribution\n",
+    "            index_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
+    "            # append sampled index to the running sequence\n",
+    "            index = torch.cat((index, index_next), dim=1) # (B, T+1)\n",
+    "        return index\n",
+    "\n",
+    "model = GPTLanguageModel(vocab_size).to(device)\n",
+    "\n",
+    "print('loading model parameters...')\n",
+    "with open('model-01.pkl', 'rb') as f:\n",
+    "    model = pickle.load(f)\n",
+    "print('loaded successfully!')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bb0f76ef",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# create a PyTorch optimizer\n",
+    "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
+    "\n",
+    "for iter in range(max_iters):\n",
+    "    if iter % eval_every == 0:\n",
+    "        losses = estimate_loss()\n",
+    "        print(f\"step: {iter}, train loss: {losses['train']:.3f}, val loss: {losses['val']:.3f}\")\n",
+    "\n",
+    "    # sample a batch of data\n",
+    "    xb, yb = get_batch('train')\n",
+    "\n",
+    "    # evaluate the loss\n",
+    "    logits, loss = model.forward(xb, yb)\n",
+    "    optimizer.zero_grad(set_to_none=True)\n",
+    "    loss.backward()\n",
+    "    optimizer.step()\n",
+    "print(loss.item())\n",
+    "\n",
+    "with open('model-01.pkl', 'wb') as f:\n",
+    "    pickle.dump(model, f)\n",
+    "print('model saved')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ccdc0134",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prompt = 'Hello! Can you see me?'\n",
+    "context = torch.tensor(encode(prompt), dtype=torch.long, device=device)\n",
+    "generated_chars = decode(model.generate(context.unsqueeze(0), max_new_tokens=100)[0].tolist())\n",
+    "print(generated_chars)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  },
+  "varInspector": {
+   "cols": {
+    "lenName": 16,
+    "lenType": 16,
+    "lenVar": 40
+   },
+   "kernels_config": {
+    "python": {
+     "delete_cmd_postfix": "",
+     "delete_cmd_prefix": "del ",
+     "library": "var_list.py",
+     "varRefreshCmd": "print(var_dic_list())"
+    },
+    "r": {
+     "delete_cmd_postfix": ") ",
+     "delete_cmd_prefix": "rm(",
+     "library": "var_list.r",
+     "varRefreshCmd": "cat(var_dic_list()) "
+    }
+   },
+   "types_to_exclude": [
+    "module",
+    "function",
+    "builtin_function_or_method",
+    "instance",
+    "_Feature"
+   ],
+   "window_display": false
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
gpt_openwebtext.sync.py ADDED
@@ -0,0 +1,278 @@
+# ---
+# jupyter:
+#   jupytext:
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#     jupytext_version: 1.3.4
+#   kernelspec:
+#     display_name: Python 3
+#     language: python
+#     name: python3
+# ---
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import mmap
+import random
+import pickle
+import os
+
+
+# %%
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(device)
+block_size = 128
+batch_size = 32
+max_iters = 4000
+learning_rate = 3e-4
+eval_every = 500
+n_embd = 384
+n_head = 8
+n_layer = 8
+dropout = 0.2
+
+# %%
+if not os.path.exists("./openwebtext/vocab.txt") or not os.path.exists("./openwebtext/train_split.txt") or not os.path.exists("./openwebtext/val_split.txt"):
+    raise Exception("Please run extract.py first")
+# %%
+chars = ""
+with open("./openwebtext/vocab.txt", 'r', encoding='utf-8') as f:
+    text = f.read()
+    chars = sorted(list(set(text)))
+
+vocab_size = len(chars)
+
+# %%
+print(f"Vocab size: {vocab_size}")
+print(f"Text length: {len(text)}")
+
+# %%
+string_to_int = {ch: i for i, ch in enumerate(chars)}
+int_to_string = {i: ch for i, ch in enumerate(chars)}
+
+encode = lambda s: [string_to_int[ch] for ch in s]
+decode = lambda x: ''.join([int_to_string[i] for i in x])
+# %%
+# memory map for using small snippets of text from a single file of any size
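+# rather than loading a multi-GB split into RAM; each call samples one random
+# block_size*batch_size-byte window, so sampling is cheap but a chunk may
+# straddle the boundary between two concatenated documents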
+def get_random_chunk(split):
+    filename = "./openwebtext/train_split.txt" if split == 'train' else "./openwebtext/val_split.txt"
+    with open(filename, 'rb') as f:
+        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
+            # Determine the file size and a random position to start reading
+            file_size = len(mm)
+            start_pos = random.randint(0, (file_size) - block_size*batch_size)
+
+            # Seek to the random position and read the block of text
+            mm.seek(start_pos)
+            block = mm.read(block_size*batch_size-1)
+
+            # Decode the block to a string, ignoring any invalid byte sequences
+            decoded_block = block.decode('utf-8', errors='ignore').replace('\r', '')
+
+            # Train and test splits
+            data = torch.tensor(encode(decoded_block), dtype=torch.long)
+
+    return data
+
+
+def get_batch(split):
+    data = get_random_chunk(split)
+    ix = torch.randint(len(data) - block_size, (batch_size,))
+    x = torch.stack([data[i:i+block_size] for i in ix])
+    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
+    x, y = x.to(device), y.to(device)
+    return x, y
+
+# %%
+@torch.no_grad()
+def estimate_loss():
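+    # note: eval_every doubles as the number of random batches averaged per
+    # split here, not just the evaluation interval used in the training loop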
+    out = {}
+    model.eval()
+    for split in ['train', 'val']:
+        losses = torch.zeros(eval_every)
+        for k in range(eval_every):
+            X, Y = get_batch(split)
+            logits, loss = model(X, Y)
+            losses[k] = loss.item()
+        out[split] = losses.mean()
+    model.train()
+    return out
+
+# %%
+
+class Head(nn.Module):
+    """ one head of self-attention """
+
+    def __init__(self, head_size):
+        super().__init__()
+        self.key = nn.Linear(n_embd, head_size, bias=False)
+        self.query = nn.Linear(n_embd, head_size, bias=False)
+        self.value = nn.Linear(n_embd, head_size, bias=False)
+        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        # input of size (batch, time-step, channels)
+        # output of size (batch, time-step, head size)
+        B,T,C = x.shape
+        k = self.key(x) # (B,T,hs)
+        q = self.query(x) # (B,T,hs)
+        # compute attention scores ("affinities")
+        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
+        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
+        wei = F.softmax(wei, dim=-1) # (B, T, T)
+        wei = self.dropout(wei)
+        # perform the weighted aggregation of the values
+        v = self.value(x) # (B,T,hs)
+        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
+        return out
+
+# [1, 0, 0]
+# [1, 0.6, 0]
+# [1, 0.6, 0.4]
+class MultiHeadAttention(nn.Module):
+    """ multiple heads of self-attention in parallel """
+
+    def __init__(self, num_heads, head_size):
+        super().__init__()
+        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
+        self.proj = nn.Linear(head_size * num_heads, n_embd)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, F) -> (B, T, [h1, h1, h1, h1, h2, h2, h2, h2, h3, h3, h3, h3])
+        out = self.dropout(self.proj(out))
+        return out
+
+
+class FeedFoward(nn.Module):
+    """ a simple linear layer followed by a non-linearity """
+
+    def __init__(self, n_embd):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(n_embd, 4 * n_embd),
+            nn.ReLU(),
+            nn.Linear(4 * n_embd, n_embd),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+class Block(nn.Module):
+    """ Transformer block: communication followed by computation """
+
+    def __init__(self, n_embd, n_head):
+        # n_embd: embedding dimension, n_head: the number of heads we'd like
+        super().__init__()
+        head_size = n_embd // n_head
+        self.sa = MultiHeadAttention(n_head, head_size)
+        self.ffwd = FeedFoward(n_embd)
+        self.ln1 = nn.LayerNorm(n_embd)
+        self.ln2 = nn.LayerNorm(n_embd)
+
+    def forward(self, x):
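+        # post-norm residual wiring: LayerNorm is applied after each residual
+        # add (GPT-2-style models usually normalize before each sub-layer)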
+        y = self.sa(x)
+        x = self.ln1(x + y)
+        y = self.ffwd(x)
+        x = self.ln2(x + y)
+        return x
+
+class GPTLanguageModel(nn.Module):
+    def __init__(self, vocab_size):
+        super().__init__()
+        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
+        self.position_embedding_table = nn.Embedding(block_size, n_embd)
+        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
+        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
+        self.lm_head = nn.Linear(n_embd, vocab_size)
+
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+    def forward(self, index, targets=None):
+        B, T = index.shape
+
+
+        # idx and targets are both (B,T) tensor of integers
+        tok_emb = self.token_embedding_table(index) # (B,T,C)
+        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
+        x = tok_emb + pos_emb # (B,T,C)
+        x = self.blocks(x) # (B,T,C)
+        x = self.ln_f(x) # (B,T,C)
+        logits = self.lm_head(x) # (B,T,vocab_size)
+
+        if targets is None:
+            loss = None
+        else:
+            B, T, C = logits.shape
+            logits = logits.view(B*T, C) # reshape to what torch.cross_entropy expects
+            targets = targets.view(B*T)
+            loss = F.cross_entropy(logits, targets)
+        return logits, loss
+
+    def generate(self, index, max_new_tokens):
+        # index is (B, T) array of indices in the current context
+        for _ in range(max_new_tokens):
+            # crop idx to the last block_size tokens
+            index_cond = index[:, -block_size:]
+            # get the predictions
+            logits, loss = self.forward(index_cond)
+            # focus only on the last time step
+            logits = logits[:, -1, :] # becomes (B, C)
+            # apply softmax to get probabilities
+            probs = F.softmax(logits, dim=-1) # (B, C)
+            # sample from the distribution
+            index_next = torch.multinomial(probs, num_samples=1) # (B, 1)
+            # append sampled index to the running sequence
+            index = torch.cat((index, index_next), dim=1) # (B, T+1)
+        return index
+
+model = GPTLanguageModel(vocab_size).to(device)
+
+model_pickle_path = './model.pkl'
+if os.path.exists(model_pickle_path):
+    print('loading model parameters...')
+    with open(model_pickle_path, 'rb') as f:
+        model = pickle.load(f)
+    print('loaded successfully!')
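+# note: pickle stores the whole nn.Module object, so the class definitions and
+# hyperparameters above must match the checkpoint being loaded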
+# %%
+# create a PyTorch optimizer
+optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+
+for iter in range(max_iters):
+    if iter % eval_every == 0:
+        losses = estimate_loss()
+        print(f"step: {iter}, train loss: {losses['train']:.3f}, val loss: {losses['val']:.3f}")
+
+    # sample a batch of data
+    xb, yb = get_batch('train')
+
+    # evaluate the loss
+    logits, loss = model.forward(xb, yb)
+    optimizer.zero_grad(set_to_none=True)
+    loss.backward()
+    optimizer.step()
+print(loss.item())
+
+with open(model_pickle_path, 'wb') as f:
+    pickle.dump(model, f)
+print('model saved')
+
+# %%
+prompt = 'Hello! Can you see me?'
+context = torch.tensor(encode(prompt), dtype=torch.long, device=device)
+generated_chars = decode(model.generate(context.unsqueeze(0), max_new_tokens=100)[0].tolist())
+print(generated_chars)
requirements-base.txt CHANGED
@@ -3,3 +3,4 @@ numpy
 pylzma
 ipykernel
 jupyter
+tqdm