{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "059837a0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.nn import functional as F\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)\n", "block_size = 128\n", "batch_size = 32\n", "max_iters = 4000\n", "learning_rate = 3e-4\n", "eval_every = 500\n", "n_embd = 384\n", "n_head = 8\n", "n_layer = 8\n", "dropout = 0.2" ] }, { "cell_type": "code", "execution_count": 2, "id": "1fdc69a7", "metadata": {}, "outputs": [], "source": [ "with open(\"shakespeare.txt\") as f:\n", " text = f.read()" ] }, { "cell_type": "code", "execution_count": 3, "id": "0c09eeb0", "metadata": {}, "outputs": [], "source": [ "chars = sorted(set(text))\n", "vocab_size = len(chars)" ] }, { "cell_type": "code", "execution_count": 4, "id": "a278e7b9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Vocab size: 101\n", "Text length: 5357910\n" ] } ], "source": [ "print(f\"Vocab size: {vocab_size}\")\n", "print(f\"Text length: {len(text)}\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "2a540d96", "metadata": {}, "outputs": [], "source": [ "string_to_int = {ch: i for i, ch in enumerate(chars)}\n", "int_to_string = {i: ch for i, ch in enumerate(chars)}\n", "\n", "encode = lambda s: [string_to_int[ch] for ch in s]\n", "decode = lambda x: ''.join([int_to_string[i] for i in x])\n", "\n", "data = torch.tensor(encode(text), dtype=torch.long, device=device)" ] }, { "cell_type": "code", "execution_count": 6, "id": "c7c8e4aa", "metadata": {}, "outputs": [], "source": [ "n = int(0.8 * len(data))\n", "train_data = data[:n]\n", "val_data = data[n:]" ] }, { "cell_type": "code", "execution_count": 7, "id": "54d80f45", "metadata": {}, "outputs": [], "source": [ "def get_batch(split):\n", " data = train_data if split == 'train' else val_data\n", " ix = torch.randint(len(data) - block_size, (batch_size,))\n", " x = torch.stack([data[i:i+block_size] for i in ix])\n", " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n", " x, y = x.to(device), y.to(device)\n", " return x, y" ] }, { "cell_type": "code", "execution_count": 8, "id": "618df2dc", "metadata": {}, "outputs": [], "source": [ "@torch.no_grad()\n", "def estimate_loss():\n", " out = {}\n", " model.eval()\n", " for split in ['train', 'val']:\n", " losses = torch.zeros(eval_every)\n", " for k in range(eval_every):\n", " X, Y = get_batch(split)\n", " logits, loss = model(X, Y)\n", " losses[k] = loss.item()\n", " out[split] = losses.mean()\n", " model.train()\n", " return out" ] }, { "cell_type": "code", "execution_count": 9, "id": "d0a21928", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "step: 0, train loss: 4.615, val loss: 4.613\n", "step: 500, train loss: 1.923, val loss: 1.961\n", "step: 1000, train loss: 1.662, val loss: 1.753\n", "step: 1500, train loss: 1.531, val loss: 1.655\n", "step: 2000, train loss: 1.453, val loss: 1.608\n", "step: 2500, train loss: 1.398, val loss: 1.567\n", "step: 3000, train loss: 1.365, val loss: 1.543\n", "step: 3500, train loss: 1.340, val loss: 1.529\n", "1.3418211936950684\n" ] } ], "source": [ "\n", "class Head(nn.Module):\n", " \"\"\" one head of self-attention \"\"\"\n", "\n", " def __init__(self, head_size):\n", " super().__init__()\n", " self.key = nn.Linear(n_embd, head_size, bias=False)\n", " self.query = nn.Linear(n_embd, 
 { "cell_type": "code", "execution_count": 9, "id": "d0a21928", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "step: 0, train loss: 4.615, val loss: 4.613\n", "step: 500, train loss: 1.923, val loss: 1.961\n", "step: 1000, train loss: 1.662, val loss: 1.753\n", "step: 1500, train loss: 1.531, val loss: 1.655\n", "step: 2000, train loss: 1.453, val loss: 1.608\n", "step: 2500, train loss: 1.398, val loss: 1.567\n", "step: 3000, train loss: 1.365, val loss: 1.543\n", "step: 3500, train loss: 1.340, val loss: 1.529\n", "1.3418211936950684\n" ] } ], "source": [
 "class Head(nn.Module):\n",
 "    \"\"\" one head of self-attention \"\"\"\n",
 "\n",
 "    def __init__(self, head_size):\n",
 "        super().__init__()\n",
 "        self.key = nn.Linear(n_embd, head_size, bias=False)\n",
 "        self.query = nn.Linear(n_embd, head_size, bias=False)\n",
 "        self.value = nn.Linear(n_embd, head_size, bias=False)\n",
 "        # lower-triangular causal mask; a buffer so it moves with the module but is not trained\n",
 "        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
 "        self.dropout = nn.Dropout(dropout)\n",
 "\n",
 "    def forward(self, x):\n",
 "        # input of size (batch, time-step, channels)\n",
 "        # output of size (batch, time-step, head size)\n",
 "        B, T, C = x.shape\n",
 "        k = self.key(x)   # (B,T,hs)\n",
 "        q = self.query(x) # (B,T,hs)\n",
 "        # compute scaled attention scores (\"affinities\")\n",
 "        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B,T,hs) @ (B,hs,T) -> (B,T,T)\n",
 "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B,T,T)\n",
 "        wei = F.softmax(wei, dim=-1) # (B,T,T)\n",
 "        wei = self.dropout(wei)\n",
 "        # perform the weighted aggregation of the values\n",
 "        v = self.value(x) # (B,T,hs)\n",
 "        out = wei @ v # (B,T,T) @ (B,T,hs) -> (B,T,hs)\n",
 "        return out\n",
 "\n",
 "# after masking and softmax, each row of wei sums to 1 and position t\n",
 "# attends only to positions <= t, e.g.\n",
 "# [1.0, 0.0, 0.0]\n",
 "# [0.5, 0.5, 0.0]\n",
 "# [0.3, 0.3, 0.4]\n",
 "class MultiHeadAttention(nn.Module):\n",
 "    \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
 "\n",
 "    def __init__(self, num_heads, head_size):\n",
 "        super().__init__()\n",
 "        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
 "        self.proj = nn.Linear(head_size * num_heads, n_embd)\n",
 "        self.dropout = nn.Dropout(dropout)\n",
 "\n",
 "    def forward(self, x):\n",
 "        # concatenate the per-head outputs along the channel dim: (B, T, num_heads * head_size)\n",
 "        out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
 "        out = self.dropout(self.proj(out))\n",
 "        return out\n",
 "\n",
 "class FeedForward(nn.Module):\n",
 "    \"\"\" a two-layer MLP with a ReLU non-linearity in between \"\"\"\n",
 "\n",
 "    def __init__(self, n_embd):\n",
 "        super().__init__()\n",
 "        self.net = nn.Sequential(\n",
 "            nn.Linear(n_embd, 4 * n_embd),\n",
 "            nn.ReLU(),\n",
 "            nn.Linear(4 * n_embd, n_embd),\n",
 "            nn.Dropout(dropout),\n",
 "        )\n",
 "\n",
 "    def forward(self, x):\n",
 "        return self.net(x)\n",
 "\n",
 "class Block(nn.Module):\n",
 "    \"\"\" Transformer block: communication followed by computation \"\"\"\n",
 "\n",
 "    def __init__(self, n_embd, n_head):\n",
 "        # n_embd: embedding dimension, n_head: the number of heads we'd like\n",
 "        super().__init__()\n",
 "        head_size = n_embd // n_head\n",
 "        self.sa = MultiHeadAttention(n_head, head_size)\n",
 "        self.ffwd = FeedForward(n_embd)\n",
 "        self.ln1 = nn.LayerNorm(n_embd)\n",
 "        self.ln2 = nn.LayerNorm(n_embd)\n",
 "\n",
 "    def forward(self, x):\n",
 "        # note: LayerNorm is applied after each residual add (post-norm),\n",
 "        # unlike the pre-norm ordering used in GPT-2\n",
 "        y = self.sa(x)\n",
 "        x = self.ln1(x + y)\n",
 "        y = self.ffwd(x)\n",
 "        x = self.ln2(x + y)\n",
 "        return x\n",
 "\n",
 "class GPTLanguageModel(nn.Module):\n",
 "    def __init__(self, vocab_size):\n",
 "        super().__init__()\n",
 "        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n",
 "        self.position_embedding_table = nn.Embedding(block_size, n_embd)\n",
 "        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n",
 "        self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n",
 "        self.lm_head = nn.Linear(n_embd, vocab_size)\n",
 "        self.apply(self._init_weights)\n",
 "\n",
 "    def _init_weights(self, module):\n",
 "        if isinstance(module, nn.Linear):\n",
 "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
 "            if module.bias is not None:\n",
 "                torch.nn.init.zeros_(module.bias)\n",
 "        elif isinstance(module, nn.Embedding):\n",
 "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
 "\n",
 "    def forward(self, index, targets=None):\n",
 "        B, T = index.shape\n",
 "        # index and targets are both (B,T) tensors of integers\n",
 "        tok_emb = self.token_embedding_table(index) # (B,T,C)\n",
 "        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n",
 "        x = tok_emb + pos_emb # (B,T,C)\n",
 "        x = self.blocks(x) # (B,T,C)\n",
 "        x = self.ln_f(x) # (B,T,C)\n",
 "        logits = self.lm_head(x) # (B,T,vocab_size)\n",
 "\n",
 "        if targets is None:\n",
 "            loss = None\n",
 "        else:\n",
 "            B, T, C = logits.shape\n",
 "            logits = logits.view(B*T, C) # flatten to the (N, C) shape F.cross_entropy expects\n",
 "            targets = targets.view(B*T)\n",
 "            loss = F.cross_entropy(logits, targets)\n",
 "        return logits, loss\n",
 "\n",
 "    def generate(self, index, max_new_tokens):\n",
 "        # index is a (B, T) array of indices in the current context\n",
 "        for _ in range(max_new_tokens):\n",
 "            # crop index to the last block_size tokens\n",
 "            index_cond = index[:, -block_size:]\n",
 "            # get the predictions\n",
 "            logits, _ = self(index_cond)\n",
 "            # focus only on the last time step\n",
 "            logits = logits[:, -1, :] # becomes (B, C)\n",
 "            # apply softmax to get probabilities\n",
 "            probs = F.softmax(logits, dim=-1) # (B, C)\n",
 "            # sample from the distribution\n",
 "            index_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
 "            # append sampled index to the running sequence\n",
 "            index = torch.cat((index, index_next), dim=1) # (B, T+1)\n",
 "        return index\n",
 "\n",
 "model = GPTLanguageModel(vocab_size).to(device)\n",
 "\n",
 "# create a PyTorch optimizer\n",
 "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
 "\n",
 "for step in range(max_iters):\n",
 "    if step % eval_every == 0:\n",
 "        losses = estimate_loss()\n",
 "        print(f\"step: {step}, train loss: {losses['train']:.3f}, val loss: {losses['val']:.3f}\")\n",
 "\n",
 "    # sample a batch of data\n",
 "    xb, yb = get_batch('train')\n",
 "\n",
 "    # evaluate the loss\n",
 "    logits, loss = model(xb, yb)\n",
 "    optimizer.zero_grad(set_to_none=True)\n",
 "    loss.backward()\n",
 "    optimizer.step()\n",
 "print(loss.item())" ] },
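 { "cell_type": "markdown", "id": "3c5d7e9f", "metadata": {}, "source": [ "With the settings above (`n_embd=384`, `n_head=8`, `n_layer=8`, `vocab_size=101`) the model should come out to roughly 14M parameters. Counting them directly is a one-line check (added for illustration; it was not part of the recorded run):" ] },
 { "cell_type": "code", "execution_count": null, "id": "4d6e8f0a", "metadata": {}, "outputs": [], "source": [
 "# total trainable parameters, summed over all weight and bias tensors\n",
 "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
 "print(f'{n_params/1e6:.2f}M parameters')" ] },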
 { "cell_type": "code", "execution_count": 10, "id": "99a66247", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\t them part it. The leison drows\n", "Let them napposes them.\n", "\n", "SUFFUE.\n", "Yea, erow now, he was near angless.\n" ] } ], "source": [
 "# generate unconditionally, starting from a single token with id 0\n",
 "context = torch.zeros((1,1), dtype=torch.long, device=device)\n",
 "generated_chars = decode(model.generate(context, max_new_tokens=100)[0].tolist())\n",
 "print(generated_chars)" ] },
 { "cell_type": "code", "execution_count": 11, "id": "c2b03115", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "To be or not to be, my lord; and at Worces and will.\n", "\n", "FRIAR L JOHN.\n", "My dory Gold Catesby say the King vow you are.\n", "\n", "ENT\n" ] } ], "source": [
 "# condition generation on a prompt\n",
 "prompt = 'To be or not to be,'\n",
 "context = torch.tensor(encode(prompt), dtype=torch.long, device=device)\n",
 "generated_chars = decode(model.generate(context.unsqueeze(0), max_new_tokens=100)[0].tolist())\n",
 "print(generated_chars)" ] },
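 { "cell_type": "markdown", "id": "5e7f9a1b", "metadata": {}, "source": [ "`generate` always samples from the full softmax distribution. A common extension is temperature scaling plus top-k filtering; the sketch below is an illustrative addition, not part of the recorded run (the `generate_topk` helper and the `temperature`/`top_k` defaults are assumptions), and it reuses the trained model:" ] },
 { "cell_type": "code", "execution_count": null, "id": "6f8a0b2c", "metadata": {}, "outputs": [], "source": [
 "# hypothetical sampler: temperature < 1 sharpens the distribution,\n",
 "# top_k keeps only the k most likely tokens at each step\n",
 "@torch.no_grad()\n",
 "def generate_topk(model, index, max_new_tokens, temperature=0.8, top_k=20):\n",
 "    for _ in range(max_new_tokens):\n",
 "        index_cond = index[:, -block_size:]\n",
 "        logits, _ = model(index_cond)\n",
 "        logits = logits[:, -1, :] / temperature # (B, C), rescaled\n",
 "        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n",
 "        logits[logits < v[:, [-1]]] = float('-inf') # drop everything below the k-th logit\n",
 "        probs = F.softmax(logits, dim=-1)\n",
 "        index_next = torch.multinomial(probs, num_samples=1)\n",
 "        index = torch.cat((index, index_next), dim=1)\n",
 "    return index\n",
 "\n",
 "model.eval() # disable dropout for sampling\n",
 "context = torch.tensor(encode('To be or not to be,'), dtype=torch.long, device=device)\n",
 "print(decode(generate_topk(model, context.unsqueeze(0), 100)[0].tolist()))" ] }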
 ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }