{ "cells": [ { "cell_type": "code", "execution_count": 5, "id": "b84a347c", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.nn import functional as F\n", "import mmap\n", "import random\n", "import pickle" ] }, { "cell_type": "code", "execution_count": 6, "id": "058368c2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)\n", "block_size = 128\n", "batch_size = 32\n", "max_iters = 4000\n", "learning_rate = 3e-4\n", "eval_every = 500\n", "n_embd = 384\n", "n_head = 8\n", "n_layer = 8\n", "dropout = 0.2" ] }, { "cell_type": "code", "execution_count": 7, "id": "4ec3625c", "metadata": {}, "outputs": [ { "ename": "FileNotFoundError", "evalue": "[Errno 2] No such file or directory: 'openwebtext/vocab.txt'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m chars \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mopenwebtext/vocab.txt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mr\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mutf-8\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 3\u001b[0m text \u001b[38;5;241m=\u001b[39m f\u001b[38;5;241m.\u001b[39mread()\n\u001b[1;32m 4\u001b[0m chars \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msorted\u001b[39m(\u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mset\u001b[39m(text)))\n", "File \u001b[0;32m~/repos/main/llm-from-scratch/venv/lib/python3.10/site-packages/IPython/core/interactiveshell.py:310\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m 303\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m}:\n\u001b[1;32m 304\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 305\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIPython won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m by default \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 307\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myou can use builtins\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m open.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 308\u001b[0m )\n\u001b[0;32m--> 310\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'openwebtext/vocab.txt'" ] } ], "source": [ "chars = \"\"\n", "with open(\"./openwebtext/vocab.txt\", 'r', encoding='utf-8') as f:\n", " text = f.read()\n", " chars = sorted(list(set(text)))\n", " \n", "vocab_size = len(chars)" ] }, { "cell_type": "code", "execution_count": null, "id": "15e6af07", "metadata": {}, "outputs": [], "source": [ "print(f\"Vocab size: {vocab_size}\")\n", "print(f\"Text length: {len(text)}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "425bf0b5", "metadata": {}, "outputs": [], "source": [ "string_to_int = {ch: i for i, ch in enumerate(chars)}\n", "int_to_string = {i: ch for i, ch in enumerate(chars)}\n", "\n", "encode = lambda s: [string_to_int[ch] for ch in s]\n", "decode = lambda x: ''.join([int_to_string[i] for i in x])" ] }, { "cell_type": "code", "execution_count": null, "id": "1b141a3a", "metadata": {}, "outputs": [], "source": [ "# memory map for using small snippets of text from a single file of any size\n", "def get_random_chunk(split):\n", " filename = \"./openwebtext/train_split.txt\" if split == 'train' else \"./openwebtext/val_split.txt\"\n", " with open(filename, 'rb') as f:\n", " with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:\n", " # Determine the file size and a random position to start reading\n", " file_size = len(mm)\n", " start_pos = random.randint(0, (file_size) - block_size*batch_size)\n", "\n", " # Seek to the random position and read the block of text\n", " mm.seek(start_pos)\n", " block = mm.read(block_size*batch_size-1)\n", "\n", " # Decode the block to a string, ignoring any invalid byte sequences\n", " decoded_block = block.decode('utf-8', errors='ignore').replace('\\r', '')\n", " \n", " # Train and test splits\n", " data = torch.tensor(encode(decoded_block), dtype=torch.long)\n", " \n", " return data\n", "\n", "\n", "def get_batch(split):\n", " data = get_random_chunk(split)\n", " ix = torch.randint(len(data) - block_size, (batch_size,))\n", " x = torch.stack([data[i:i+block_size] for i in ix])\n", " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n", " x, y = x.to(device), y.to(device)\n", " return x, y" ] }, { "cell_type": "code", "execution_count": null, "id": "4b27acf7", "metadata": {}, "outputs": [], "source": [ "@torch.no_grad()\n", "def estimate_loss():\n", " out = {}\n", " model.eval()\n", " for split in ['train', 'val']:\n", " losses = torch.zeros(eval_every)\n", " for k in range(eval_every):\n", " X, Y = get_batch(split)\n", " logits, loss = model(X, Y)\n", " losses[k] = loss.item()\n", " out[split] = losses.mean()\n", " model.train()\n", " return out" ] }, { "cell_type": "code", "execution_count": null, "id": "517553b5", "metadata": {}, "outputs": [], "source": [ "\n", "class Head(nn.Module):\n", " \"\"\" one head of self-attention \"\"\"\n", "\n", " def __init__(self, head_size):\n", " super().__init__()\n", " self.key = nn.Linear(n_embd, head_size, bias=False)\n", " self.query = nn.Linear(n_embd, head_size, bias=False)\n", " self.value = nn.Linear(n_embd, head_size, bias=False)\n", " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n", "\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " def forward(self, x):\n", " # input of size (batch, time-step, channels)\n", " # output of size (batch, time-step, head size)\n", " B,T,C = x.shape\n", " k = self.key(x) # (B,T,hs)\n", " q = self.query(x) # (B,T,hs)\n", " # compute attention scores (\"affinities\")\n", " wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n", " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n", " wei = F.softmax(wei, dim=-1) # (B, T, T)\n", " wei = self.dropout(wei)\n", " # perform the weighted aggregation of the values\n", " v = self.value(x) # (B,T,hs)\n", " out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)\n", " return out\n", "\n", "# [1, 0, 0]\n", "# [1, 0.6, 0]\n", "# [1, 0.6, 0.4]\n", "class MultiHeadAttention(nn.Module):\n", " \"\"\" multiple heads of self-attention in parallel \"\"\"\n", "\n", " def __init__(self, num_heads, head_size):\n", " super().__init__()\n", " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n", " self.proj = nn.Linear(head_size * num_heads, n_embd)\n", " self.dropout = nn.Dropout(dropout)\n", "\n", " def forward(self, x):\n", " out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, F) -> (B, T, [h1, h1, h1, h1, h2, h2, h2, h2, h3, h3, h3, h3])\n", " out = self.dropout(self.proj(out))\n", " return out\n", " \n", "\n", "class FeedFoward(nn.Module):\n", " \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n", "\n", " def __init__(self, n_embd):\n", " super().__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(n_embd, 4 * n_embd),\n", " nn.ReLU(),\n", " nn.Linear(4 * n_embd, n_embd),\n", " nn.Dropout(dropout),\n", " )\n", "\n", " def forward(self, x):\n", " return self.net(x)\n", " \n", "class Block(nn.Module):\n", " \"\"\" Transformer block: communication followed by computation \"\"\"\n", "\n", " def __init__(self, n_embd, n_head):\n", " # n_embd: embedding dimension, n_head: the number of heads we'd like\n", " super().__init__()\n", " head_size = n_embd // n_head\n", " self.sa = MultiHeadAttention(n_head, head_size)\n", " self.ffwd = FeedFoward(n_embd)\n", " self.ln1 = nn.LayerNorm(n_embd)\n", " self.ln2 = nn.LayerNorm(n_embd)\n", "\n", " def forward(self, x):\n", " y = self.sa(x)\n", " x = self.ln1(x + y)\n", " y = self.ffwd(x)\n", " x = self.ln2(x + y)\n", " return x\n", " \n", "class GPTLanguageModel(nn.Module):\n", " def __init__(self, vocab_size):\n", " super().__init__()\n", " self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n", " self.position_embedding_table = nn.Embedding(block_size, n_embd)\n", " self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n", " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n", " self.lm_head = nn.Linear(n_embd, vocab_size)\n", " \n", " \n", " self.apply(self._init_weights)\n", "\n", " def _init_weights(self, module):\n", " if isinstance(module, nn.Linear):\n", " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n", " if module.bias is not None:\n", " torch.nn.init.zeros_(module.bias)\n", " elif isinstance(module, nn.Embedding):\n", " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n", "\n", " def forward(self, index, targets=None):\n", " B, T = index.shape\n", " \n", " \n", " # idx and targets are both (B,T) tensor of integers\n", " tok_emb = self.token_embedding_table(index) # (B,T,C)\n", " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n", " x = tok_emb + pos_emb # (B,T,C)\n", " x = self.blocks(x) # (B,T,C)\n", " x = self.ln_f(x) # (B,T,C)\n", " logits = self.lm_head(x) # (B,T,vocab_size)\n", " \n", " if targets is None:\n", " loss = None\n", " else:\n", " B, T, C = logits.shape\n", " logits = logits.view(B*T, C) # reshape to what torch.cross_entropy expects\n", " targets = targets.view(B*T)\n", " loss = F.cross_entropy(logits, targets) \n", " return logits, loss\n", " \n", " def generate(self, index, max_new_tokens):\n", " # index is (B, T) array of indices in the current context\n", " for _ in range(max_new_tokens):\n", " # crop idx to the last block_size tokens\n", " index_cond = index[:, -block_size:]\n", " # get the predictions\n", " logits, loss = self.forward(index_cond)\n", " # focus only on the last time step\n", " logits = logits[:, -1, :] # becomes (B, C)\n", " # apply softmax to get probabilities\n", " probs = F.softmax(logits, dim=-1) # (B, C)\n", " # sample from the distribution\n", " index_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n", " # append sampled index to the running sequence\n", " index = torch.cat((index, index_next), dim=1) # (B, T+1)\n", " return index\n", "\n", "model = GPTLanguageModel(vocab_size).to(device)\n", "\n", "print('loading model parameters...')\n", "with open('model-01.pkl', 'rb') as f:\n", " model = pickle.load(f)\n", "print('loaded successfully!')" ] }, { "cell_type": "code", "execution_count": null, "id": "bb0f76ef", "metadata": {}, "outputs": [], "source": [ "# create a PyTorch optimizer\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n", "\n", "for iter in range(max_iters):\n", " if iter % eval_every == 0:\n", " losses = estimate_loss()\n", " print(f\"step: {iter}, train loss: {losses['train']:.3f}, val loss: {losses['val']:.3f}\")\n", "\n", " # sample a batch of data\n", " xb, yb = get_batch('train')\n", "\n", " # evaluate the loss\n", " logits, loss = model.forward(xb, yb)\n", " optimizer.zero_grad(set_to_none=True)\n", " loss.backward()\n", " optimizer.step()\n", "print(loss.item())\n", "\n", "with open('model-01.pkl', 'wb') as f:\n", " pickle.dump(model, f)\n", "print('model saved')" ] }, { "cell_type": "code", "execution_count": null, "id": "ccdc0134", "metadata": {}, "outputs": [], "source": [ "prompt = 'Hello! Can you see me?'\n", "context = torch.tensor(encode(prompt), dtype=torch.long, device=device)\n", "generated_chars = decode(model.generate(context.unsqueeze(0), max_new_tokens=100)[0].tolist())\n", "print(generated_chars)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }