Upload 14 files
Browse files- README.md +58 -5
- app.py +65 -30
- checkpoints/model.pth +3 -0
- data/input.txt +0 -0
- experiments/bigram.py +107 -0
- experiments/bigram_v2.py +200 -0
- experiments/exp.ipynb +468 -0
- gpt.ipynb +211 -0
- requirements.txt +2 -2
- src/inference.py +9 -0
- src/model.py +120 -0
- src/training.py +53 -0
- src/utils.py +32 -0
README.md
CHANGED
|
@@ -1,13 +1,66 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: indigo
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: "ERA SESSION21: GPT from scratch"
|
| 3 |
+
emoji: 🌍
|
| 4 |
colorFrom: indigo
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 3.50.2
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
### Results
|
| 14 |
+
**Bigram Base model training and results**
|
| 15 |
+
|
| 16 |
+

|
| 17 |
+
|
| 18 |
+
**GPT Model training results**
|
| 19 |
+
|
| 20 |
+

|
| 21 |
+
|
| 22 |
+
#### Generation Output:
|
| 23 |
+
```python
|
| 24 |
+
model = torch.load("checkpoints/model.pth", map_location=device)
|
| 25 |
+
results = generate("hello", model, block_size, 1000, device)
|
| 26 |
+
print(results)
|
| 27 |
+
```
|
| 28 |
+
```
|
| 29 |
+
hellows thence grown from thee.
|
| 30 |
+
Since thou hast raim, thou thast well were quarterned; and
|
| 31 |
+
ever man tree can saw for words word from her at hour
|
| 32 |
+
Whiles contrations or devoided from ere years;
|
| 33 |
+
Yea, foul vice, indelice on the bird of the
|
| 34 |
+
noble of Hermione.
|
| 35 |
+
|
| 36 |
+
PARIS:
|
| 37 |
+
Sir, adies, sir, hate no choping but to your good.
|
| 38 |
+
|
| 39 |
+
HENRY BOLINGBROKE:
|
| 40 |
+
Yes, to ask you might, foreweed.
|
| 41 |
+
|
| 42 |
+
WARCK:
|
| 43 |
+
'Tis he made moust true.
|
| 44 |
+
|
| 45 |
+
RORSET:
|
| 46 |
+
It is an hour fastal that cracknaf at the chase
|
| 47 |
+
Upon; you are your hearing news a daughter.
|
| 48 |
+
|
| 49 |
+
KING EDWARD IV:
|
| 50 |
+
Tut, Lord Warwick, thou shouldst aft Rutlansps?
|
| 51 |
+
Thou tust but back hild, he countemn'd my lady's seal,
|
| 52 |
+
For access dead the treature moon! and the Englisting!
|
| 53 |
+
Thy vage for yonder see thou be donen?
|
| 54 |
+
O, count thou dost not Romeo, thou pratheeo sir,
|
| 55 |
+
That sweet thou feigh with no past blood on
|
| 56 |
+
Be see, here through on that find bears, if an
|
| 57 |
+
pretterinctors three and aspect die meeds thou,
|
| 58 |
+
Behing mine of thy denigning state lain business?
|
| 59 |
+
|
| 60 |
+
SAMPSA:
|
| 61 |
+
Sir, ha! but thou refused? thyself food, gr
|
| 62 |
+
```
|
| 63 |
+
### Gradio Interface
|
| 64 |
+

|
| 65 |
+
|
| 66 |
+
|
app.py
CHANGED
|
@@ -1,43 +1,78 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
| 3 |
import torch
|
| 4 |
-
import
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
with open('input.txt', 'r', encoding='utf-8') as f:
|
| 10 |
-
text = f.read()
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
chars = sorted(list(set(text)))
|
| 14 |
-
vocab_size = len(chars)
|
| 15 |
|
| 16 |
-
|
| 17 |
-
itos = { i:ch for i,ch in enumerate(chars) }
|
| 18 |
-
encode = lambda s: [stoi[c] for c in s]
|
| 19 |
-
decode = lambda l: ''.join([itos[i] for i in l])
|
| 20 |
|
| 21 |
-
model = GPTLanguageModel(vocab_size)
|
| 22 |
-
model.load_state_dict(torch.load('model.pth', map_location=cfg.device))
|
| 23 |
-
m = model.to(cfg.device)
|
| 24 |
|
| 25 |
-
def
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
out_text = decode(m.generate(context, max_new_tokens=count)[0].tolist())
|
| 31 |
-
return out_text
|
| 32 |
|
| 33 |
-
title = "ERAV1 Session 21"
|
| 34 |
-
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import random
|
| 3 |
import torch
|
| 4 |
+
import pathlib
|
| 5 |
|
| 6 |
+
from src.model import GPTModel
|
| 7 |
+
from src.inference import generate as generate_text
|
| 8 |
+
from src.utils import vocab_size
|
| 9 |
|
| 10 |
+
# Hyperparameters for the GPT model served by this app.
# batch_size, max_iters, eval_interval, learning_rate and eval_iters are not
# referenced by the app code visible here; they are kept so the module's
# names stay unchanged.
batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
# Select a device that actually exists. The second GPU is still preferred when
# present (the original setting), but a single-GPU host now falls back to
# "cuda:0" instead of failing on the missing "cuda:1" ordinal.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
eval_iters = 200
n_embeds = 384
n_heads = 6
n_layers = 6
dropout = 0.2
|
| 21 |
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
def load_model():
    """Build a ``GPTModel`` and populate it from ``checkpoints/model.pth``.

    The checkpoint is read as a pickled module (its ``state_dict()`` is what
    gets copied), so ``torch.load`` must be allowed to unpickle arbitrary
    objects. Only load checkpoints that ship with this repository.

    Returns:
        The model, placed on ``device`` and switched to evaluation mode.
    """
    # SECURITY NOTE: full unpickling; never point this at an untrusted file.
    # ``weights_only=False`` is spelled out so the call does not depend on the
    # installed torch version's default for that flag.
    model_ckpt = torch.load(
        "checkpoints/model.pth", map_location=device, weights_only=False
    )
    model = GPTModel(
        vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device
    )
    model.load_state_dict(model_ckpt.state_dict())
    # A freshly built module is in training mode; without eval() the
    # dropout layers would perturb every generation step.
    # NOTE(review): GPTModel may already place itself on `device`; .to() is
    # idempotent, so it is applied here to be safe.
    model.to(device)
    model.eval()
    return model


model = load_model()
|
|
|
|
|
|
|
|
|
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
def generate(prompt, max_new_tokens):
    """Gradio click handler: continue ``prompt`` with the loaded model.

    Args:
        prompt: Text from the prompt box; surrounding whitespace is removed.
        max_new_tokens: Number of tokens to generate (slider value).

    Returns:
        A dict mapping the ``gpt_output`` component to the generated text.
    """
    # NOTE(review): a prompt that is empty after stripping is passed through
    # unchanged; confirm how src.inference.generate handles it.
    cleaned = prompt.strip()
    generated = generate_text(cleaned, model, block_size, max_new_tokens, device)
    return {gpt_output: generated}
|
|
|
|
|
|
|
|
|
|
| 39 |
|
|
|
|
|
|
|
| 40 |
|
| 41 |
+
# Gradio UI: a two-column layout with the prompt controls on the left and the
# generated text on the right.
with gr.Blocks() as app:
    gr.Markdown("## ERA Session21 - GPT from scratch")
    gr.Markdown(
        """This is an implementation of GPT [Let's build GPT: from scratch, in code, spelled out.](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=2s) by Andrej Karpathy.

        Please find the source code and training details [here](https://github.com/RaviNaik/ERA-SESSION21).

        Dataset used to train: [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).
        """
    )
    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            prompt_box = gr.Textbox(label="Initial Prompt", interactive=True)
            max_new_tokens = gr.Slider(
                minimum=10,
                maximum=2500,
                value=100,
                step=10,
                label="Select Number of Tokens to be Generated",
                interactive=True,
            )
            submit_btn = gr.Button(value="Generate")

        # Right column: read-only output area.
        with gr.Column():
            gpt_output = gr.TextArea(
                label="Text Generated by GPT",
                show_label=True,
                max_lines=100,
                interactive=False,
            )

    # `generate` returns a dict keyed by `gpt_output`, so that component must
    # be listed in `outputs`.
    submit_btn.click(
        generate,
        inputs=[prompt_box, max_new_tokens],
        outputs=[gpt_output],
    )

# Start the web server when this file is executed.
app.launch()
|
checkpoints/model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a8b930ee87e1eecc6a03bc49983a81fd11aaa95f4cd5e1d64091d6107827811b
|
| 3 |
+
size 52698997
|
data/input.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
experiments/bigram.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
# Hyperparameters for the bigram baseline.
batch_size = 32  # sequences per batch
block_size = 8  # tokens per sequence
max_iters = 3000
eval_interval = 300
learning_rate = 1e-2
# Select a device that actually exists. The second GPU is still preferred when
# present (the original setting), but a single-GPU host now falls back to
# "cuda:0" instead of failing on the missing "cuda:1" ordinal.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
eval_iters = 200

# Fixed seed so batch sampling and weight initialisation are repeatable.
torch.manual_seed(1123)
|
| 14 |
+
|
| 15 |
+
# Read the corpus with an explicit encoding so the result does not depend on
# the platform's default locale.
with open("input.txt", encoding="utf-8") as f:
    text = f.read()

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = {i: ch for i, ch in enumerate(chars)}  # integer id -> character


def encode(s):
    """Map each character of ``s`` to its id; raises KeyError if unknown."""
    return [stoi[c] for c in s]


def decode(l):
    """Turn an iterable of ids back into a string."""
    return "".join(itos[i] for i in l)


# Encode the whole corpus once, then split 90% / 10% into train / validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_batch(split):
    """Sample a random batch of input/target windows.

    Args:
        split: ``"train"`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)``; ``y`` is ``x``
        shifted one position to the right in the source sequence.
    """
    source = val_data if split != "train" else train_data
    # One random start offset per sequence in the batch.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s : s + block_size] for s in starts])
    targets = torch.stack([source[s + 1 : s + block_size + 1] for s in starts])
    return inputs, targets
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@torch.no_grad()
def estimate_loss(model: nn.Module):
    """Average the loss over ``eval_iters`` random batches per split.

    The model is put in eval mode for the measurement and switched back to
    train mode before returning.

    Returns:
        Dict with keys ``"train"`` and ``"val"`` holding the mean loss as a
        0-dim tensor.
    """
    model.eval()
    averages = {}
    for split in ("train", "val"):
        per_batch = torch.zeros(eval_iters)
        for i in range(eval_iters):
            xb, yb = get_batch(split)
            _, batch_loss = model(xb.to(device), yb.to(device))
            per_batch[i] = batch_loss.item()
        averages[split] = per_batch.mean()
    model.train()
    return averages
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class BigramLanguageModel(nn.Module):
    """Bigram model: next-token logits are looked up from the current token.

    The only parameter is a ``(vocab_size, vocab_size)`` embedding table whose
    row for token ``t`` is the logit vector for the token following ``t``.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the loss.

        Args:
            idx: Integer token ids of shape ``(B, T)``.
            targets: Optional integer token ids of shape ``(B, T)``.

        Returns:
            ``(logits, loss)``. Without targets, logits have shape
            ``(B, T, C)`` and loss is ``None``. With targets, logits are
            flattened to ``(B*T, C)`` and loss is the mean cross-entropy.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            # cross_entropy expects (N, C) logits and (N,) targets.
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx``.

        Args:
            idx: Integer context of shape ``(B, T)``.
            max_new_tokens: Number of tokens to sample.

        Returns:
            Tensor of shape ``(B, T + max_new_tokens)``.
        """
        for _ in range(max_new_tokens):
            logits, _ = self(idx)  # (B, T, C)
            logits = logits[:, -1, :]  # keep the last time step: (B, C)
            probs = F.softmax(logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T + 1)

        return idx
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
model = BigramLanguageModel(vocab_size)

model = model.to(device)
# Use the configured `learning_rate`; the optimizer previously hard-coded
# 1e-3, which left the hyperparameter declared at the top of the file unused.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# `step` instead of `iter`, to avoid shadowing the builtin.
for step in range(max_iters):
    # Report the averaged train/val loss every `eval_interval` steps.
    if step % eval_interval == 0:
        losses = estimate_loss(model)
        print(
            f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
        )

    xb, yb = get_batch("train")
    xb, yb = xb.to(device), yb.to(device)

    _, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


# Sample 100 tokens starting from a single token with id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
results = decode(model.generate(context, max_new_tokens=100)[0].tolist())
print(results)
|
experiments/bigram_v2.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
# Hyperparameters for the transformer model.
batch_size = 64  # sequences per batch
block_size = 256  # maximum context length in tokens
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
# Select a device that actually exists. The second GPU is still preferred when
# present (the original setting), but a single-GPU host now falls back to
# "cuda:0" instead of failing on the missing "cuda:1" ordinal.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
eval_iters = 200
n_embeds = 384  # embedding width
n_heads = 6  # attention heads per block
n_layers = 6  # number of transformer blocks
dropout = 0.2

# Fixed seed so batch sampling and weight initialisation are repeatable.
torch.manual_seed(1123)
|
| 18 |
+
|
| 19 |
+
# Read the corpus with an explicit encoding so the result does not depend on
# the platform's default locale.
with open("input.txt", encoding="utf-8") as f:
    text = f.read()

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = {i: ch for i, ch in enumerate(chars)}  # integer id -> character


def encode(s):
    """Map each character of ``s`` to its id; raises KeyError if unknown."""
    return [stoi[c] for c in s]


def decode(l):
    """Turn an iterable of ids back into a string."""
    return "".join(itos[i] for i in l)


# Encode the whole corpus once, then split 90% / 10% into train / validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_batch(split):
    """Sample a random batch of input/target windows.

    Args:
        split: ``"train"`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)``; ``y`` is ``x``
        shifted one position to the right in the source sequence.
    """
    source = val_data if split != "train" else train_data
    # One random start offset per sequence in the batch.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s : s + block_size] for s in starts])
    targets = torch.stack([source[s + 1 : s + block_size + 1] for s in starts])
    return inputs, targets
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@torch.no_grad()
def estimate_loss(model: nn.Module):
    """Average the loss over ``eval_iters`` random batches per split.

    The model is put in eval mode for the measurement and switched back to
    train mode before returning.

    Returns:
        Dict with keys ``"train"`` and ``"val"`` holding the mean loss as a
        0-dim tensor.
    """
    model.eval()
    averages = {}
    for split in ("train", "val"):
        per_batch = torch.zeros(eval_iters)
        for i in range(eval_iters):
            xb, yb = get_batch(split)
            _, batch_loss = model(xb.to(device), yb.to(device))
            per_batch[i] = batch_loss.item()
        averages[split] = per_batch.mean()
    model.train()
    return averages
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Reads the module-level ``dropout`` and ``block_size`` hyperparameters.

    Args:
        n_embed: Width of the input embeddings.
        head_size: Width of this head's key/query/value projections.
    """

    def __init__(self, n_embed, head_size) -> None:
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular mask: position t may only attend to positions <= t.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, T, n_embed)``; returns ``(B, T, head_size)``."""
        _, T, _ = x.shape
        k = self.key(x)  # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Scale by 1/sqrt(head_size), the width of the vectors being dotted.
        # The previous code scaled by the embedding width, which shrinks the
        # scores too much whenever head_size != n_embed.
        scale = k.shape[-1] ** -0.5
        wei = q @ k.transpose(-2, -1) * scale  # (B,T,hs) @ (B,hs,T) --> (B,T,T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v  # (B, T, head_size)
        return out
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, concatenated and projected.

    Reads the module-level ``dropout`` hyperparameter.

    Args:
        n_heads: Number of ``Head`` modules.
        n_embeds: Width of the input and output embeddings.
        head_size: Output width of each head.
    """

    def __init__(self, n_heads, n_embeds, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embeds, head_size) for _ in range(n_heads)])
        # The concatenated heads are n_heads * head_size wide. That equals
        # n_embeds only when n_embeds is divisible by n_heads, so size the
        # projection from the real concatenated width.
        self.proj = nn.Linear(n_heads * head_size, n_embeds)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``(B, T, n_embeds)`` to ``(B, T, n_embeds)``."""
        x = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, n_heads*head_size)
        x = self.proj(x)
        x = self.dropout(x)
        return x
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``4 * n_embeds``, ReLU, project back, dropout.

    Reads the module-level ``dropout`` hyperparameter.
    """

    def __init__(self, n_embeds):
        super().__init__()
        hidden = 4 * n_embeds
        layers = [
            nn.Linear(n_embeds, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embeds),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        out = self.net(x)
        return out
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Block(nn.Module):
    """Transformer block: pre-norm self-attention and MLP, each with a residual.

    Args:
        n_embeds: Embedding width.
        n_heads: Number of attention heads; each head is
            ``n_embeds // n_heads`` wide.
    """

    def __init__(self, n_embeds, n_heads):
        super().__init__()
        self.sa_heads = MultiHeadAttention(n_heads, n_embeds, n_embeds // n_heads)
        self.ffwd = FeedForward(n_embeds)
        self.ln1 = nn.LayerNorm(n_embeds)
        self.ln2 = nn.LayerNorm(n_embeds)

    def forward(self, x):
        """Apply attention then the MLP, normalising before each sub-layer."""
        attended = self.sa_heads(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class BigramLanguageModel(nn.Module):
    """Decoder-only transformer language model (keeps the original class name).

    Reads the module-level ``n_heads`` and ``n_layers`` hyperparameters when
    building its stack of ``Block`` modules.

    Args:
        vocab_size: Number of distinct tokens.
        n_embeds: Embedding width.
        block_size: Maximum context length (size of the position table).
    """

    def __init__(self, vocab_size, n_embeds, block_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embeds)
        self.position_embedding_table = nn.Embedding(block_size, n_embeds)
        self.blocks = nn.Sequential(
            *[Block(n_embeds, n_heads) for _ in range(n_layers)]
        )
        self.lnf = nn.LayerNorm(n_embeds)
        self.lm_head = nn.Linear(n_embeds, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the loss.

        Args:
            idx: Integer token ids of shape ``(B, T)``; ``T`` must not exceed
                the position table size.
            targets: Optional integer token ids of shape ``(B, T)``.

        Returns:
            ``(logits, loss)``. Without targets, logits have shape
            ``(B, T, vocab_size)`` and loss is ``None``. With targets, logits
            are flattened to ``(B*T, vocab_size)`` and loss is the mean
            cross-entropy.
        """
        B, T = idx.shape

        tok_embeds = self.token_embedding_table(idx)  # (B, T, n_embeds)
        # Build the positions on the input's own device rather than a global
        # `device`, so the model works wherever it has been moved.
        pos_embeds = self.position_embedding_table(
            torch.arange(T, device=idx.device)
        )  # (T, n_embeds)

        x = tok_embeds + pos_embeds  # (B, T, n_embeds)
        x = self.blocks(x)
        x = self.lnf(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            B, T, C = logits.shape
            # cross_entropy expects (N, C) logits and (N,) targets.
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx``.

        Dropout is disabled while sampling; the previous train/eval mode is
        restored before returning.

        Args:
            idx: Integer context of shape ``(B, T)``.
            max_new_tokens: Number of tokens to sample.

        Returns:
            Tensor of shape ``(B, T + max_new_tokens)``.
        """
        # Crop to this model's own context length, not a module-level global.
        context_len = self.position_embedding_table.num_embeddings
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -context_len:]
                logits, _ = self(idx_cond)  # (B, T, C)
                logits = logits[:, -1, :]  # keep the last time step: (B, C)
                probs = F.softmax(logits, dim=-1)  # (B, C)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
                idx = torch.cat((idx, idx_next), dim=1)  # (B, T + 1)
        finally:
            self.train(was_training)

        return idx
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
model = BigramLanguageModel(vocab_size, n_embeds, block_size)

model = model.to(device)
# Use the configured `learning_rate`; the optimizer previously hard-coded
# 1e-3, which left the hyperparameter declared at the top of the file unused.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# `step` instead of `iter`, to avoid shadowing the builtin.
for step in range(max_iters):
    # Report the averaged train/val loss every `eval_interval` steps.
    if step % eval_interval == 0:
        losses = estimate_loss(model)
        print(
            f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
        )

    xb, yb = get_batch("train")
    xb, yb = xb.to(device), yb.to(device)

    _, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


# Training leaves the model in train mode; switch to eval so dropout does not
# perturb the sampled text.
model.eval()
# Sample 100 tokens starting from a single token with id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
results = decode(model.generate(context, max_new_tokens=100)[0].tolist())
print(results)
|
experiments/exp.ipynb
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [
|
| 8 |
+
{
|
| 9 |
+
"name": "stdout",
|
| 10 |
+
"output_type": "stream",
|
| 11 |
+
"text": [
|
| 12 |
+
"--2023-10-27 16:11:32-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n",
|
| 13 |
+
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...\n",
|
| 14 |
+
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... "
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"name": "stdout",
|
| 19 |
+
"output_type": "stream",
|
| 20 |
+
"text": [
|
| 21 |
+
"connected.\n",
|
| 22 |
+
"HTTP request sent, awaiting response... 200 OK\n",
|
| 23 |
+
"Length: 1115394 (1.1M) [text/plain]\n",
|
| 24 |
+
"Saving to: ‘input.txt.1’\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"input.txt.1 100%[===================>] 1.06M 734KB/s in 1.5s \n",
|
| 27 |
+
"\n",
|
| 28 |
+
"2023-10-27 16:11:36 (734 KB/s) - ‘input.txt.1’ saved [1115394/1115394]\n",
|
| 29 |
+
"\n"
|
| 30 |
+
]
|
| 31 |
+
}
|
| 32 |
+
],
|
| 33 |
+
"source": [
|
| 34 |
+
"!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"execution_count": 2,
|
| 40 |
+
"metadata": {},
|
| 41 |
+
"outputs": [],
|
| 42 |
+
"source": [
|
| 43 |
+
"with open(\"input.txt\") as f:\n",
|
| 44 |
+
" text = f.read()"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": 4,
|
| 50 |
+
"metadata": {},
|
| 51 |
+
"outputs": [
|
| 52 |
+
{
|
| 53 |
+
"data": {
|
| 54 |
+
"text/plain": [
|
| 55 |
+
"'First Citizen:\\nBefore we proceed any further, hear'"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
"execution_count": 4,
|
| 59 |
+
"metadata": {},
|
| 60 |
+
"output_type": "execute_result"
|
| 61 |
+
}
|
| 62 |
+
],
|
| 63 |
+
"source": [
|
| 64 |
+
"text[:50]"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "code",
|
| 69 |
+
"execution_count": 5,
|
| 70 |
+
"metadata": {},
|
| 71 |
+
"outputs": [
|
| 72 |
+
{
|
| 73 |
+
"name": "stdout",
|
| 74 |
+
"output_type": "stream",
|
| 75 |
+
"text": [
|
| 76 |
+
"\n",
|
| 77 |
+
" !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
|
| 78 |
+
"65\n"
|
| 79 |
+
]
|
| 80 |
+
}
|
| 81 |
+
],
|
| 82 |
+
"source": [
|
| 83 |
+
"chars = sorted(list(set(text)))\n",
|
| 84 |
+
"vocab_size = len(chars)\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"print(\"\".join(chars))\n",
|
| 87 |
+
"print(vocab_size)"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": 7,
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"outputs": [
|
| 95 |
+
{
|
| 96 |
+
"name": "stdout",
|
| 97 |
+
"output_type": "stream",
|
| 98 |
+
"text": [
|
| 99 |
+
"[46, 47, 1, 58, 46, 43, 56, 43]\n",
|
| 100 |
+
"hi there\n"
|
| 101 |
+
]
|
| 102 |
+
}
|
| 103 |
+
],
|
| 104 |
+
"source": [
|
| 105 |
+
"stoi = {ch: i for i, ch in enumerate(chars)}\n",
|
| 106 |
+
"itos = {i: ch for i, ch in enumerate(chars)}\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"encode = lambda s: [stoi[c] for c in s]\n",
|
| 109 |
+
"decode = lambda l: \"\".join([itos[i] for i in l])\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"print(encode(\"hi there\"))\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"print(decode(encode(\"hi there\")))"
|
| 114 |
+
]
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"cell_type": "code",
|
| 118 |
+
"execution_count": 8,
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"outputs": [
|
| 121 |
+
{
|
| 122 |
+
"name": "stdout",
|
| 123 |
+
"output_type": "stream",
|
| 124 |
+
"text": [
|
| 125 |
+
"torch.Size([1115394]) torch.int64\n",
|
| 126 |
+
"tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44,\n",
|
| 127 |
+
" 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63,\n",
|
| 128 |
+
" 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1,\n",
|
| 129 |
+
" 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49,\n",
|
| 130 |
+
" 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47,\n",
|
| 131 |
+
" 58, 47, 64, 43, 52, 10, 0, 37, 53, 59])\n"
|
| 132 |
+
]
|
| 133 |
+
}
|
| 134 |
+
],
|
| 135 |
+
"source": [
|
| 136 |
+
"import torch\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"data = torch.tensor(encode(text), dtype=torch.long)\n",
|
| 139 |
+
"print(data.shape, data.dtype)\n",
|
| 140 |
+
"print(data[:100])"
|
| 141 |
+
]
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"cell_type": "code",
|
| 145 |
+
"execution_count": 9,
|
| 146 |
+
"metadata": {},
|
| 147 |
+
"outputs": [],
|
| 148 |
+
"source": [
|
| 149 |
+
"n = int(0.9 * len(data))\n",
|
| 150 |
+
"train_data = data[:n]\n",
|
| 151 |
+
"val_data = data[n:]"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "code",
|
| 156 |
+
"execution_count": 10,
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"outputs": [
|
| 159 |
+
{
|
| 160 |
+
"name": "stdout",
|
| 161 |
+
"output_type": "stream",
|
| 162 |
+
"text": [
|
| 163 |
+
"Inputs:\n",
|
| 164 |
+
"torch.Size([4, 8])\n",
|
| 165 |
+
"tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n",
|
| 166 |
+
" [44, 53, 56, 1, 58, 46, 39, 58],\n",
|
| 167 |
+
" [52, 58, 1, 58, 46, 39, 58, 1],\n",
|
| 168 |
+
" [25, 17, 27, 10, 0, 21, 1, 54]])\n",
|
| 169 |
+
"-----------\n",
|
| 170 |
+
"Targets:\n",
|
| 171 |
+
"torch.Size([4, 8])\n",
|
| 172 |
+
"tensor([[43, 58, 5, 57, 1, 46, 43, 39],\n",
|
| 173 |
+
" [53, 56, 1, 58, 46, 39, 58, 1],\n",
|
| 174 |
+
" [58, 1, 58, 46, 39, 58, 1, 46],\n",
|
| 175 |
+
" [17, 27, 10, 0, 21, 1, 54, 39]])\n"
|
| 176 |
+
]
|
| 177 |
+
}
|
| 178 |
+
],
|
| 179 |
+
"source": [
|
| 180 |
+
"torch.manual_seed(1337)\n",
|
| 181 |
+
"batch_size = 4\n",
|
| 182 |
+
"block_size = 8\n",
|
| 183 |
+
"\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"def get_batch(split):\n",
|
| 186 |
+
" data = train_data if split == \"train\" else val_data\n",
|
| 187 |
+
" ix = torch.randint(len(data) - block_size, (batch_size,))\n",
|
| 188 |
+
" x = torch.stack([data[i : i + block_size] for i in ix])\n",
|
| 189 |
+
" y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])\n",
|
| 190 |
+
" return x, y\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"\n",
|
| 193 |
+
"xb, yb = get_batch(\"train\")\n",
|
| 194 |
+
"print(\"Inputs:\")\n",
|
| 195 |
+
"print(xb.shape)\n",
|
| 196 |
+
"print(xb)\n",
|
| 197 |
+
"\n",
|
| 198 |
+
"print(\"-----------\")\n",
|
| 199 |
+
"print(\"Targets:\")\n",
|
| 200 |
+
"print(yb.shape)\n",
|
| 201 |
+
"print(yb)"
|
| 202 |
+
]
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"cell_type": "code",
|
| 206 |
+
"execution_count": 11,
|
| 207 |
+
"metadata": {},
|
| 208 |
+
"outputs": [],
|
| 209 |
+
"source": [
|
| 210 |
+
"import torch.nn as nn\n",
|
| 211 |
+
"from torch.nn import functional as F\n",
|
| 212 |
+
"\n",
|
| 213 |
+
"\n",
|
| 214 |
+
"class BigramLanguageModel(nn.Module):\n",
|
| 215 |
+
" def __init__(self, vocab_size):\n",
|
| 216 |
+
" super().__init__()\n",
|
| 217 |
+
" self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n",
|
| 218 |
+
"\n",
|
| 219 |
+
" def forward(self, idx, targets):\n",
|
| 220 |
+
" logits = self.token_embedding_table(idx)\n",
|
| 221 |
+
"\n",
|
| 222 |
+
" return logits"
|
| 223 |
+
]
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
"cell_type": "code",
|
| 227 |
+
"execution_count": 12,
|
| 228 |
+
"metadata": {},
|
| 229 |
+
"outputs": [
|
| 230 |
+
{
|
| 231 |
+
"name": "stdout",
|
| 232 |
+
"output_type": "stream",
|
| 233 |
+
"text": [
|
| 234 |
+
"torch.Size([4, 8, 65])\n"
|
| 235 |
+
]
|
| 236 |
+
}
|
| 237 |
+
],
|
| 238 |
+
"source": [
|
| 239 |
+
"m = BigramLanguageModel(vocab_size)\n",
|
| 240 |
+
"out = m(xb, yb)\n",
|
| 241 |
+
"print(out.shape) # B,T,C -> 4X8X65"
|
| 242 |
+
]
|
| 243 |
+
},
|
| 244 |
+
{
|
| 245 |
+
"cell_type": "code",
|
| 246 |
+
"execution_count": 15,
|
| 247 |
+
"metadata": {},
|
| 248 |
+
"outputs": [
|
| 249 |
+
{
|
| 250 |
+
"name": "stdout",
|
| 251 |
+
"output_type": "stream",
|
| 252 |
+
"text": [
|
| 253 |
+
"torch.Size([32, 65])\n",
|
| 254 |
+
"tensor(4.5262, grad_fn=<NllLossBackward0>)\n"
|
| 255 |
+
]
|
| 256 |
+
}
|
| 257 |
+
],
|
| 258 |
+
"source": [
|
| 259 |
+
"class BigramLanguageModel(nn.Module):\n",
|
| 260 |
+
" def __init__(self, vocab_size):\n",
|
| 261 |
+
" super().__init__()\n",
|
| 262 |
+
" self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n",
|
| 263 |
+
"\n",
|
| 264 |
+
" def forward(self, idx, targets=None):\n",
|
| 265 |
+
" logits = self.token_embedding_table(idx) # BTC\n",
|
| 266 |
+
" loss = None\n",
|
| 267 |
+
" if targets is not None:\n",
|
| 268 |
+
" B, T, C = logits.shape\n",
|
| 269 |
+
" logits = logits.view(B * T, C)\n",
|
| 270 |
+
" targets = targets.view(B * T)\n",
|
| 271 |
+
" loss = F.cross_entropy(logits, targets)\n",
|
| 272 |
+
" return logits, loss\n",
|
| 273 |
+
"\n",
|
| 274 |
+
" def generate(self, idx, max_new_tokens):\n",
|
| 275 |
+
" for _ in range(max_new_tokens):\n",
|
| 276 |
+
" logits, loss = self(idx) # BxTxC\n",
|
| 277 |
+
" logits = logits[:, -1, :] # BxC\n",
|
| 278 |
+
" probs = F.softmax(logits, dim=-1) # BxC\n",
|
| 279 |
+
" idx_next = torch.multinomial(probs, num_samples=1) # Bx1\n",
|
| 280 |
+
" idx = torch.cat((idx, idx_next), dim=1) # BxT+1\n",
|
| 281 |
+
"\n",
|
| 282 |
+
" return idx\n",
|
| 283 |
+
"\n",
|
| 284 |
+
"\n",
|
| 285 |
+
"m = BigramLanguageModel(vocab_size)\n",
|
| 286 |
+
"logits, loss = m(xb, yb)\n",
|
| 287 |
+
"print(logits.shape) # flattened to (B*T, C) -> 32x65\n",
|
| 288 |
+
"print(loss)"
|
| 289 |
+
]
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"cell_type": "code",
|
| 293 |
+
"execution_count": 16,
|
| 294 |
+
"metadata": {},
|
| 295 |
+
"outputs": [
|
| 296 |
+
{
|
| 297 |
+
"name": "stdout",
|
| 298 |
+
"output_type": "stream",
|
| 299 |
+
"text": [
|
| 300 |
+
"\n",
|
| 301 |
+
"'JgC.JZWqUkpdtkSpmzjM-,RqzgaN?vC:hgjnAnBZDga-APqGUH!WdCbIb;$DefOYbEvcaKGMmnO'q$KdS-'ZH\n",
|
| 302 |
+
".YSqr'X!Q! d;\n"
|
| 303 |
+
]
|
| 304 |
+
}
|
| 305 |
+
],
|
| 306 |
+
"source": [
|
| 307 |
+
"idx = torch.zeros((1, 1), dtype=torch.long)\n",
|
| 308 |
+
"\n",
|
| 309 |
+
"results = decode(m.generate(idx, max_new_tokens=100)[0].tolist())\n",
|
| 310 |
+
"\n",
|
| 311 |
+
"print(results)"
|
| 312 |
+
]
|
| 313 |
+
},
|
| 314 |
+
{
|
| 315 |
+
"cell_type": "code",
|
| 316 |
+
"execution_count": 17,
|
| 317 |
+
"metadata": {},
|
| 318 |
+
"outputs": [],
|
| 319 |
+
"source": [
|
| 320 |
+
"optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)"
|
| 321 |
+
]
|
| 322 |
+
},
|
| 323 |
+
{
|
| 324 |
+
"cell_type": "code",
|
| 325 |
+
"execution_count": 19,
|
| 326 |
+
"metadata": {},
|
| 327 |
+
"outputs": [
|
| 328 |
+
{
|
| 329 |
+
"name": "stdout",
|
| 330 |
+
"output_type": "stream",
|
| 331 |
+
"text": [
|
| 332 |
+
"2.4206888675689697\n"
|
| 333 |
+
]
|
| 334 |
+
}
|
| 335 |
+
],
|
| 336 |
+
"source": [
|
| 337 |
+
"batch_size = 32\n",
|
| 338 |
+
"\n",
|
| 339 |
+
"for steps in range(10000):\n",
|
| 340 |
+
" xb, yb = get_batch(\"train\")\n",
|
| 341 |
+
"\n",
|
| 342 |
+
" logits, loss = m(xb, yb)\n",
|
| 343 |
+
" optimizer.zero_grad(set_to_none=True)\n",
|
| 344 |
+
" loss.backward()\n",
|
| 345 |
+
" optimizer.step()\n",
|
| 346 |
+
"\n",
|
| 347 |
+
"print(loss.item())"
|
| 348 |
+
]
|
| 349 |
+
},
|
| 350 |
+
{
|
| 351 |
+
"cell_type": "code",
|
| 352 |
+
"execution_count": 20,
|
| 353 |
+
"metadata": {},
|
| 354 |
+
"outputs": [
|
| 355 |
+
{
|
| 356 |
+
"name": "stdout",
|
| 357 |
+
"output_type": "stream",
|
| 358 |
+
"text": [
|
| 359 |
+
"\n",
|
| 360 |
+
"Hou'sy'ting'stis's w ys'stholealy woawhimedy it 'save,\n",
|
| 361 |
+
"Too:Had wh fo an, ZCENERUCHENar ee onds, th h\n"
|
| 362 |
+
]
|
| 363 |
+
}
|
| 364 |
+
],
|
| 365 |
+
"source": [
|
| 366 |
+
"idx = torch.zeros((1, 1), dtype=torch.long)\n",
|
| 367 |
+
"\n",
|
| 368 |
+
"results = decode(m.generate(idx, max_new_tokens=100)[0].tolist())\n",
|
| 369 |
+
"\n",
|
| 370 |
+
"print(results)"
|
| 371 |
+
]
|
| 372 |
+
},
|
| 373 |
+
{
|
| 374 |
+
"cell_type": "code",
|
| 375 |
+
"execution_count": 28,
|
| 376 |
+
"metadata": {},
|
| 377 |
+
"outputs": [
|
| 378 |
+
{
|
| 379 |
+
"data": {
|
| 380 |
+
"text/plain": [
|
| 381 |
+
"torch.Size([4, 8, 16])"
|
| 382 |
+
]
|
| 383 |
+
},
|
| 384 |
+
"execution_count": 28,
|
| 385 |
+
"metadata": {},
|
| 386 |
+
"output_type": "execute_result"
|
| 387 |
+
}
|
| 388 |
+
],
|
| 389 |
+
"source": [
|
| 390 |
+
"B, T, C = 4, 8, 32\n",
|
| 391 |
+
"\n",
|
| 392 |
+
"x = torch.randn(B, T, C)\n",
|
| 393 |
+
"\n",
|
| 394 |
+
"head_size = 16\n",
|
| 395 |
+
"key = nn.Linear(C, head_size, bias=False)\n",
|
| 396 |
+
"query = nn.Linear(C, head_size, bias=False)\n",
|
| 397 |
+
"value = nn.Linear(C, head_size, bias=False)\n",
|
| 398 |
+
"k = key(x)\n",
|
| 399 |
+
"q = query(x)\n",
|
| 400 |
+
"wei = q @ k.transpose(-2, -1) * (head_size**-0.5) # (B,T,16) @ (B,16,T) --> (B,T,T)\n",
|
| 401 |
+
"\n",
|
| 402 |
+
"tril = torch.tril(torch.ones(T, T))\n",
|
| 403 |
+
"wei = wei.masked_fill(tril == 0, float(\"-inf\"))\n",
|
| 404 |
+
"wei = F.softmax(wei, dim=-1)\n",
|
| 405 |
+
"v = value(x)\n",
|
| 406 |
+
"out = wei @ v\n",
|
| 407 |
+
"\n",
|
| 408 |
+
"out.shape\n"
|
| 409 |
+
]
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"cell_type": "code",
|
| 413 |
+
"execution_count": 29,
|
| 414 |
+
"metadata": {},
|
| 415 |
+
"outputs": [
|
| 416 |
+
{
|
| 417 |
+
"data": {
|
| 418 |
+
"text/plain": [
|
| 419 |
+
"tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
| 420 |
+
" [0.3325, 0.6675, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
| 421 |
+
" [0.3578, 0.2873, 0.3550, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
| 422 |
+
" [0.2281, 0.1964, 0.2733, 0.3022, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
| 423 |
+
" [0.2851, 0.1588, 0.2068, 0.1436, 0.2057, 0.0000, 0.0000, 0.0000],\n",
|
| 424 |
+
" [0.2429, 0.1547, 0.1550, 0.1475, 0.2049, 0.0951, 0.0000, 0.0000],\n",
|
| 425 |
+
" [0.1573, 0.1838, 0.1123, 0.1680, 0.1528, 0.1194, 0.1063, 0.0000],\n",
|
| 426 |
+
" [0.1139, 0.1704, 0.0766, 0.1134, 0.1600, 0.1466, 0.1228, 0.0963]],\n",
|
| 427 |
+
" grad_fn=<SelectBackward0>)"
|
| 428 |
+
]
|
| 429 |
+
},
|
| 430 |
+
"execution_count": 29,
|
| 431 |
+
"metadata": {},
|
| 432 |
+
"output_type": "execute_result"
|
| 433 |
+
}
|
| 434 |
+
],
|
| 435 |
+
"source": [
|
| 436 |
+
"wei[0]\n"
|
| 437 |
+
]
|
| 438 |
+
},
|
| 439 |
+
{
|
| 440 |
+
"cell_type": "code",
|
| 441 |
+
"execution_count": null,
|
| 442 |
+
"metadata": {},
|
| 443 |
+
"outputs": [],
|
| 444 |
+
"source": []
|
| 445 |
+
}
|
| 446 |
+
],
|
| 447 |
+
"metadata": {
|
| 448 |
+
"kernelspec": {
|
| 449 |
+
"display_name": "Python 3",
|
| 450 |
+
"language": "python",
|
| 451 |
+
"name": "python3"
|
| 452 |
+
},
|
| 453 |
+
"language_info": {
|
| 454 |
+
"codemirror_mode": {
|
| 455 |
+
"name": "ipython",
|
| 456 |
+
"version": 3
|
| 457 |
+
},
|
| 458 |
+
"file_extension": ".py",
|
| 459 |
+
"mimetype": "text/x-python",
|
| 460 |
+
"name": "python",
|
| 461 |
+
"nbconvert_exporter": "python",
|
| 462 |
+
"pygments_lexer": "ipython3",
|
| 463 |
+
"version": "3.10.12"
|
| 464 |
+
}
|
| 465 |
+
},
|
| 466 |
+
"nbformat": 4,
|
| 467 |
+
"nbformat_minor": 2
|
| 468 |
+
}
|
gpt.ipynb
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"## Import Dependencies"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "code",
|
| 12 |
+
"execution_count": 1,
|
| 13 |
+
"metadata": {},
|
| 14 |
+
"outputs": [],
|
| 15 |
+
"source": [
|
| 16 |
+
"import torch\n"
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"execution_count": 2,
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"outputs": [],
|
| 24 |
+
"source": [
|
| 25 |
+
"from src.model import GPTModel\n",
|
| 26 |
+
"from src.training import train\n",
|
| 27 |
+
"from src.inference import generate\n",
|
| 28 |
+
"from src.utils import vocab_size\n"
|
| 29 |
+
]
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"cell_type": "markdown",
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"source": [
|
| 35 |
+
"## Declare Hyperparams"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"execution_count": 3,
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"outputs": [],
|
| 43 |
+
"source": [
|
| 44 |
+
"batch_size = 64\n",
|
| 45 |
+
"block_size = 256\n",
|
| 46 |
+
"max_iters = 5000\n",
|
| 47 |
+
"eval_interval = 500\n",
|
| 48 |
+
"learning_rate = 3e-4\n",
|
| 49 |
+
"device = \"cuda:1\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 50 |
+
"eval_iters = 200\n",
|
| 51 |
+
"n_embeds = 384\n",
|
| 52 |
+
"n_heads = 6\n",
|
| 53 |
+
"n_layers = 6\n",
|
| 54 |
+
"dropout = 0.2"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "markdown",
|
| 59 |
+
"metadata": {},
|
| 60 |
+
"source": [
|
| 61 |
+
"## Initialize Model and Optimizer"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "code",
|
| 66 |
+
"execution_count": 6,
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"outputs": [],
|
| 69 |
+
"source": [
|
| 70 |
+
"model = GPTModel(vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device)\n",
|
| 71 |
+
"model = model.to(device)\n",
|
| 72 |
+
"optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "markdown",
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"source": [
|
| 79 |
+
"## Model Training"
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"cell_type": "code",
|
| 84 |
+
"execution_count": 7,
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"outputs": [
|
| 87 |
+
{
|
| 88 |
+
"name": "stdout",
|
| 89 |
+
"output_type": "stream",
|
| 90 |
+
"text": [
|
| 91 |
+
"Step 0: train loss 4.3249, val loss 4.3219\n",
|
| 92 |
+
"Step 500: train loss 2.0213, val loss 2.0953\n",
|
| 93 |
+
"Step 1000: train loss 1.6067, val loss 1.7813\n",
|
| 94 |
+
"Step 1500: train loss 1.4462, val loss 1.6380\n",
|
| 95 |
+
"Step 2000: train loss 1.3516, val loss 1.5810\n",
|
| 96 |
+
"Step 2500: train loss 1.2836, val loss 1.5376\n",
|
| 97 |
+
"Step 3000: train loss 1.2309, val loss 1.5148\n",
|
| 98 |
+
"Step 3500: train loss 1.1910, val loss 1.4904\n",
|
| 99 |
+
"Step 4000: train loss 1.1522, val loss 1.4822\n",
|
| 100 |
+
"Step 4500: train loss 1.1186, val loss 1.4838\n"
|
| 101 |
+
]
|
| 102 |
+
}
|
| 103 |
+
],
|
| 104 |
+
"source": [
|
| 105 |
+
"train(\n",
|
| 106 |
+
" model,\n",
|
| 107 |
+
" optimizer,\n",
|
| 108 |
+
" max_iters,\n",
|
| 109 |
+
" eval_interval,\n",
|
| 110 |
+
" eval_iters,\n",
|
| 111 |
+
" block_size,\n",
|
| 112 |
+
" batch_size,\n",
|
| 113 |
+
" device,\n",
|
| 114 |
+
")\n"
|
| 115 |
+
]
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"cell_type": "markdown",
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"source": [
|
| 121 |
+
"## Load the model and Generate text"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": 4,
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"outputs": [
|
| 129 |
+
{
|
| 130 |
+
"name": "stdout",
|
| 131 |
+
"output_type": "stream",
|
| 132 |
+
"text": [
|
| 133 |
+
"hellows thence grown from thee.\n",
|
| 134 |
+
"Since thou hast raim, thou thast well were quarterned; and\n",
|
| 135 |
+
"ever man tree can saw for words word from her at hour\n",
|
| 136 |
+
"Whiles contrations or devoided from ere years;\n",
|
| 137 |
+
"Yea, foul vice, indelice on the bird of the\n",
|
| 138 |
+
"noble of Hermione.\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"PARIS:\n",
|
| 141 |
+
"Sir, adies, sir, hate no choping but to your good.\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"HENRY BOLINGBROKE:\n",
|
| 144 |
+
"Yes, to ask you might, foreweed.\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"WARCK:\n",
|
| 147 |
+
"'Tis he made moust true.\n",
|
| 148 |
+
"\n",
|
| 149 |
+
"RORSET:\n",
|
| 150 |
+
"It is an hour fastal that cracknaf at the chase\n",
|
| 151 |
+
"Upon; you are your hearing news a daughter.\n",
|
| 152 |
+
"\n",
|
| 153 |
+
"KING EDWARD IV:\n",
|
| 154 |
+
"Tut, Lord Warwick, thou shouldst aft Rutlansps?\n",
|
| 155 |
+
"Thou tust but back hild, he countemn'd my lady's seal,\n",
|
| 156 |
+
"For access dead the treature moon! and the Englisting!\n",
|
| 157 |
+
"Thy vage for yonder see thou be donen?\n",
|
| 158 |
+
"O, count thou dost not Romeo, thou pratheeo sir,\n",
|
| 159 |
+
"That sweet thou feigh with no past blood on\n",
|
| 160 |
+
"Be see, here through on that find bears, if an\n",
|
| 161 |
+
"pretterinctors three and aspect die meeds thou,\n",
|
| 162 |
+
"Behing mine of thy denigning state lain business?\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"SAMPSA:\n",
|
| 165 |
+
"Sir, ha! but thou refused? thyself food, gr\n"
|
| 166 |
+
]
|
| 167 |
+
}
|
| 168 |
+
],
|
| 169 |
+
"source": [
|
| 170 |
+
"model = torch.load(\"checkpoints/model.pth\", map_location={\"cpu\": device})\n",
|
| 171 |
+
"results = generate(\"hello\", model, block_size, 1000, device)\n",
|
| 172 |
+
"print(results)"
|
| 173 |
+
]
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"cell_type": "code",
|
| 177 |
+
"execution_count": null,
|
| 178 |
+
"metadata": {},
|
| 179 |
+
"outputs": [],
|
| 180 |
+
"source": []
|
| 181 |
+
},
|
| 182 |
+
{
|
| 183 |
+
"cell_type": "code",
|
| 184 |
+
"execution_count": null,
|
| 185 |
+
"metadata": {},
|
| 186 |
+
"outputs": [],
|
| 187 |
+
"source": []
|
| 188 |
+
}
|
| 189 |
+
],
|
| 190 |
+
"metadata": {
|
| 191 |
+
"kernelspec": {
|
| 192 |
+
"display_name": "Python 3",
|
| 193 |
+
"language": "python",
|
| 194 |
+
"name": "python3"
|
| 195 |
+
},
|
| 196 |
+
"language_info": {
|
| 197 |
+
"codemirror_mode": {
|
| 198 |
+
"name": "ipython",
|
| 199 |
+
"version": 3
|
| 200 |
+
},
|
| 201 |
+
"file_extension": ".py",
|
| 202 |
+
"mimetype": "text/x-python",
|
| 203 |
+
"name": "python",
|
| 204 |
+
"nbconvert_exporter": "python",
|
| 205 |
+
"pygments_lexer": "ipython3",
|
| 206 |
+
"version": "3.10.12"
|
| 207 |
+
}
|
| 208 |
+
},
|
| 209 |
+
"nbformat": 4,
|
| 210 |
+
"nbformat_minor": 2
|
| 211 |
+
}
|
requirements.txt
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
torch
|
src/inference.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from src.utils import encode, decode
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def generate(prompt, model, block_size, max_new_tokens, device):
    """Generate text that continues *prompt* using *model*.

    Args:
        prompt: Seed text. Every character must be known to ``encode``.
        model: Module exposing ``generate(idx, max_new_tokens=...)``.
        block_size: Maximum number of prompt tokens kept (the leading
            ``block_size`` tokens are used, as before).
        max_new_tokens: Number of tokens to sample after the prompt.
        device: Device on which the prompt tensor is created.

    Returns:
        The decoded string: the (possibly truncated) prompt followed by the
        sampled continuation.

    Raises:
        ValueError: If *prompt* is empty; the model needs at least one token
            of context to predict the next one.
    """
    token_ids = encode(prompt)
    if not token_ids:
        raise ValueError("prompt must contain at least one character")

    X = torch.tensor(token_ids, dtype=torch.long, device=device)
    X = X[:block_size].unsqueeze(0)  # add the batch dimension -> (1, T)

    # Sample in eval mode so dropout is disabled, and without autograd
    # bookkeeping; restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated = model.generate(X, max_new_tokens=max_new_tokens)
    finally:
        model.train(was_training)

    return decode(generated[0].tolist())
|
src/model.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Head(nn.Module):
    """A single causal self-attention head."""

    def __init__(self, n_embeds, head_size, block_size, dropout) -> None:
        super().__init__()
        # Bias-free projections from the embedding width to the head width.
        self.key = nn.Linear(n_embeds, head_size, bias=False)
        self.query = nn.Linear(n_embeds, head_size, bias=False)
        self.value = nn.Linear(n_embeds, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular mask stored as a buffer (not a parameter).
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("tril", causal_mask)

    def forward(self, x):
        """Return the attention output for ``x`` of shape (B, T, C)."""
        _, seq_len, embed_width = x.shape
        keys = self.key(x)
        queries = self.query(x)

        # NOTE(review): scores are scaled by the embedding width (last
        # dimension of ``x``), not by head_size as in the usual scaled
        # dot-product formulation. Kept as-is: changing it would alter the
        # outputs of any already-trained checkpoint.
        scale = embed_width**-0.5
        scores = queries @ keys.transpose(-2, -1) * scale  # (B, T, T)

        # Hide future positions, then normalise over the visible ones.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))

        return weights @ self.value(x)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, then projected back.

    Args:
        n_heads: Number of ``Head`` modules to run side by side.
        n_embeds: Width of the input and of the projected output.
        head_size: Output width of each individual head.
        block_size: Maximum sequence length, forwarded to each head.
        dropout: Dropout probability for the heads and for the projection.
    """

    def __init__(self, n_heads, n_embeds, head_size, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(n_embeds, head_size, block_size, dropout) for _ in range(n_heads)]
        )
        # The heads are concatenated on the last dimension, so the projection
        # input is n_heads * head_size wide. This equals n_embeds whenever
        # n_embeds is divisible by n_heads (the previous hard-coded value) and
        # keeps the layer valid when it is not.
        self.proj = nn.Linear(n_heads * head_size, n_embeds)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Concatenate every head's output, project to n_embeds, apply dropout."""
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        x = self.dropout(x)
        return x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x the width, ReLU, project back, dropout."""

    def __init__(self, n_embeds, dropout):
        super().__init__()
        hidden_width = 4 * n_embeds
        layers = [
            nn.Linear(n_embeds, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embeds),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        transformed = self.net(x)
        return transformed
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Decoder(nn.Module):
    """Pre-norm transformer block: self-attention then feed-forward, each residual."""

    def __init__(self, n_embeds, n_heads, block_size, dropout):
        super().__init__()
        # Each head gets an equal (floor-divided) share of the embedding width.
        per_head_width = n_embeds // n_heads
        self.sa_heads = MultiHeadAttention(
            n_heads, n_embeds, per_head_width, block_size, dropout
        )
        self.ffwd = FeedForward(n_embeds, dropout)
        self.ln1 = nn.LayerNorm(n_embeds)
        self.ln2 = nn.LayerNorm(n_embeds)

    def forward(self, x):
        """Run both residual sub-layers; LayerNorm is applied before each one."""
        after_attention = x + self.sa_heads(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class GPTModel(nn.Module):
    """Decoder-only transformer language model over token indices.

    Args:
        vocab_size: Number of distinct tokens.
        n_embeds: Embedding width used throughout the network.
        block_size: Maximum context length (number of learned positions).
        n_heads: Attention heads per decoder block.
        n_layers: Number of stacked decoder blocks.
        dropout: Dropout probability passed to every block.
        device: Stored on the instance as ``self.device`` for callers that
            read it; the forward pass no longer depends on it.
    """

    def __init__(
        self, vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device
    ):
        super().__init__()
        self.device = device
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embeds)
        self.position_embedding_table = nn.Embedding(block_size, n_embeds)
        self.blocks = nn.Sequential(
            *[Decoder(n_embeds, n_heads, block_size, dropout) for _ in range(n_layers)]
        )
        self.lnf = nn.LayerNorm(n_embeds)
        self.lm_head = nn.Linear(n_embeds, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: Integer tensor of token indices with shape (B, T).
            targets: Optional integer tensor of shape (B, T).

        Returns:
            ``(logits, loss)``. Without targets, logits has shape
            (B, T, vocab_size) and loss is ``None``. With targets, logits is
            flattened to (B*T, vocab_size) and loss is the cross-entropy.
        """
        B, T = idx.shape

        tok_embeds = self.token_embedding_table(idx)  # (B, T, n_embeds)
        # Build the position indices on the input's device rather than the
        # device recorded at construction time: the stored value goes stale
        # when the module is moved or a pickled model is loaded elsewhere.
        positions = torch.arange(T, device=idx.device)
        pos_embeds = self.position_embedding_table(positions)  # (T, n_embeds)

        x = tok_embeds + pos_embeds  # broadcast over the batch dimension
        x = self.blocks(x)
        x = self.lnf(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            B, T, C = logits.shape
            # cross_entropy expects (N, C) logits and (N,) class indices.
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Sampling runs without gradient tracking. The module's train/eval mode
        is not changed here; callers that want dropout disabled should put the
        model in eval mode first.

        Args:
            idx: Integer tensor of shape (B, T) holding the starting context.
            max_new_tokens: Number of tokens to append.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # Only the most recent block_size tokens fit the position table.
            idx_cond = idx[:, -self.block_size :]
            logits, _ = self(idx_cond)  # (B, T, vocab_size)
            logits = logits[:, -1, :]  # last time step -> (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T + 1)

        return idx
|
src/training.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from src.utils import get_batch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@torch.no_grad()
def estimate_loss(model: nn.Module, eval_iters, block_size, batch_size, device):
    """Estimate the mean loss on the train and validation splits.

    Args:
        model: Module returning ``(logits, loss)`` when called with ``(X, Y)``.
        eval_iters: Number of random batches averaged per split.
        block_size: Context length of each sampled sequence.
        batch_size: Number of sequences per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Dict with keys ``"train"`` and ``"val"``, each a scalar tensor holding
        the mean loss over ``eval_iters`` batches.
    """
    out = {}
    # Remember the caller's mode so evaluation does not silently switch an
    # eval-mode model back to training, and restore it even on failure.
    was_training = model.training
    model.eval()
    try:
        for split in ("train", "val"):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, block_size, batch_size)
                X, Y = X.to(device), Y.to(device)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def train(
    model,
    optimizer,
    max_iters,
    eval_interval,
    eval_iters,
    block_size,
    batch_size,
    device,
):
    """Train *model*, checkpointing whenever the validation loss improves.

    Every ``eval_interval`` steps the train/val losses are estimated and
    printed. The first evaluation only establishes the baseline; afterwards
    the whole model is saved to ``checkpoints/model.pth`` each time the
    validation loss beats the best value seen so far.

    Args:
        model: Module returning ``(logits, loss)`` when called with ``(x, y)``.
        optimizer: Optimizer over ``model``'s parameters.
        max_iters: Number of optimisation steps.
        eval_interval: Steps between loss evaluations.
        eval_iters: Batches averaged per evaluation.
        block_size: Context length of each sampled sequence.
        batch_size: Number of sequences per batch.
        device: Device the batches are moved to.
    """
    best_val_loss = None
    for step in range(max_iters):
        if step % eval_interval == 0:
            losses = estimate_loss(model, eval_iters, block_size, batch_size, device)
            print(
                f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )
            if best_val_loss is None:
                best_val_loss = losses["val"]
            elif losses["val"] < best_val_loss:
                # Track the running best so a later, worse evaluation cannot
                # overwrite the best checkpoint.
                best_val_loss = losses["val"]
                torch.save(model, "checkpoints/model.pth")

        xb, yb = get_batch("train", block_size, batch_size)
        xb, yb = xb.to(device), yb.to(device)

        _, loss = model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
|
src/utils.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
# Read the training corpus. The encoding is given explicitly so decoding does
# not depend on the platform's default locale.
with open("data/input.txt", encoding="utf-8") as f:
    text = f.read()

# The vocabulary is every distinct character in the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Character <-> integer id lookup tables used by encode() and decode().
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def encode(s):
    """Map each character of ``s`` to its integer id via ``stoi``.

    Raises ``KeyError`` for a character that is not in the vocabulary.
    """
    return list(map(stoi.__getitem__, s))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def decode(l):
    """Turn a sequence of integer ids back into a string via ``itos``."""
    return "".join(map(itos.__getitem__, l))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Encode the whole corpus once; the first 90% is used for training and the
# remainder for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_batch(split, block_size, batch_size):
    """Sample a random batch of input/target sequences from one split.

    ``split == "train"`` draws from the training data; any other value draws
    from the validation data. Targets are the inputs shifted one position to
    the right.

    Returns:
        ``(x, y)``, each stacked from ``batch_size`` slices of length
        ``block_size``.
    """
    if split == "train":
        source = train_data
    else:
        source = val_data

    # Random start offsets, leaving room for a full block plus the shifted target.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs = []
    targets = []
    for start in starts:
        inputs.append(source[start : start + block_size])
        targets.append(source[start + 1 : start + block_size + 1])

    return torch.stack(inputs), torch.stack(targets)
|