Added tokenise method for streamed data, fixed issues with einsums
Browse files
main.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import argparse
|
|
|
|
| 2 |
import torch as t
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.functional as F
|
|
@@ -9,7 +10,8 @@ import wandb
|
|
| 9 |
from typing import Tuple
|
| 10 |
from torch.utils.data.dataloader import DataLoader
|
| 11 |
from datasets import load_dataset
|
| 12 |
-
from
|
|
|
|
| 13 |
from model import OsSoluModel
|
| 14 |
|
| 15 |
WANDB_PROJECT_NAME = "os_solu"
|
|
@@ -32,7 +34,7 @@ def parse_arguments() -> dict:
|
|
| 32 |
parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
|
| 33 |
parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
|
| 34 |
parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
|
| 35 |
-
parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings.")
|
| 36 |
parser.add_argument("--nonlinearity", type=str, default="solu", help=" Nonlinearity to use inside MLP block: must be relu or solu.")
|
| 37 |
parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
|
| 38 |
parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
|
|
@@ -40,7 +42,7 @@ def parse_arguments() -> dict:
|
|
| 40 |
parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
|
| 41 |
parser.add_argument("--optimiser_type", type=str, default="adam", help="Optimiser type.")
|
| 42 |
parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
|
| 43 |
-
parser.add_argument("--vocab_size", type=int, default=
|
| 44 |
args = vars(parser.parse_args())
|
| 45 |
|
| 46 |
# Parse string arguments.
|
|
@@ -67,7 +69,6 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
|
|
| 67 |
Returns:
|
| 68 |
OsSoluModel: The trained model.
|
| 69 |
"""
|
| 70 |
-
# TODO: training loop
|
| 71 |
train_loss_fn = t.nn.CrossEntropyLoss()
|
| 72 |
wandb.watch(model, criterion=train_loss_fn, log="all", log_freq=10, log_graph=True)
|
| 73 |
|
|
@@ -77,16 +78,17 @@ def train(config: OsSoluConfig, model: OsSoluModel, train_dataloader: DataLoader
|
|
| 77 |
|
| 78 |
# Train loop.
|
| 79 |
examples_seen = 0
|
|
|
|
| 80 |
for epoch in range(config.num_epochs):
|
| 81 |
-
for i,
|
| 82 |
-
|
|
|
|
| 83 |
data = data.to(DEVICE)
|
| 84 |
-
target = target.to(DEVICE)
|
| 85 |
|
| 86 |
predictions = model(data)
|
| 87 |
accuracy = (predictions.argmax(dim=-1) == target).sum() / len(data)
|
| 88 |
optimiser.zero_grad()
|
| 89 |
-
loss = train_loss_fn(
|
| 90 |
loss.backward()
|
| 91 |
optimiser.step()
|
| 92 |
|
|
@@ -109,9 +111,10 @@ def eval(model: OsSoluModel, test_dataloader: DataLoader) -> None:
|
|
| 109 |
total_loss, num_correct = 0, 0
|
| 110 |
model.eval()
|
| 111 |
with t.inference_mode():
|
| 112 |
-
|
|
|
|
|
|
|
| 113 |
data = data.to(DEVICE)
|
| 114 |
-
target = target.to(DEVICE)
|
| 115 |
|
| 116 |
predictions = model(data)
|
| 117 |
num_correct += (predictions.argmax(dim=-1) == target).sum().item()
|
|
@@ -134,15 +137,31 @@ def setup() -> Tuple[OsSoluConfig, OsSoluModel]:
|
|
| 134 |
args = parse_arguments()
|
| 135 |
wandb.init(project=WANDB_PROJECT_NAME, config=args)
|
| 136 |
config = OsSoluConfig(args)
|
| 137 |
-
model = OsSoluModel(config)
|
| 138 |
|
|
|
|
| 139 |
# Load and prep data.
|
| 140 |
ds = load_dataset("the_pile", streaming=True)
|
| 141 |
-
train_dataset = ds["train"].with_format("torch")
|
| 142 |
-
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
|
| 143 |
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size)
|
|
|
|
|
|
|
| 146 |
return config, model, (train_dataloader, test_dataloader)
|
| 147 |
|
| 148 |
if __name__=="__main__":
|
|
|
|
| 1 |
import argparse
|
| 2 |
+
import time
|
| 3 |
import torch as t
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.functional as F
|
|
|
|
| 10 |
from typing import Tuple
|
| 11 |
from torch.utils.data.dataloader import DataLoader
|
| 12 |
from datasets import load_dataset
|
| 13 |
+
from transformers import AutoTokenizer
|
| 14 |
+
from utils import OsSoluConfig, tokenise
|
| 15 |
from model import OsSoluModel
|
| 16 |
|
| 17 |
WANDB_PROJECT_NAME = "os_solu"
|
|
|
|
| 34 |
parser.add_argument("--dropout", type=float, default=0.1, help="Probability of dropout.")
|
| 35 |
parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate for the optimiser.")
|
| 36 |
parser.add_argument("--ln_eps", type=float, default=1e-3, help="Layer norm epsilon.")
|
| 37 |
+
parser.add_argument("--max_positional_embeddings", type=int, default=1024, help="Maximum number of positional embeddings/sequence length.")
|
| 38 |
parser.add_argument("--nonlinearity", type=str, default="solu", help=" Nonlinearity to use inside MLP block: must be relu or solu.")
|
| 39 |
parser.add_argument("--num_blocks", type=int, default=1, help="Number of transformer blocks.")
|
| 40 |
parser.add_argument("--num_embeddings", type=int, default=1024, help="Number of embeddings.")
|
|
|
|
| 42 |
parser.add_argument("--num_heads", type=int, default=4, help="Number of attention heads in each attention layer.")
|
| 43 |
parser.add_argument("--optimiser_type", type=str, default="adam", help="Optimiser type.")
|
| 44 |
parser.add_argument("--self_attention_type", type=str, default="unidirectional", help="What type of attention to use: rotary or unidirectional.")
|
| 45 |
+
parser.add_argument("--vocab_size", type=int, default=50_278, help="Vocabulary size of the input sequence.")
|
| 46 |
args = vars(parser.parse_args())
|
| 47 |
|
| 48 |
# Parse string arguments.
|
|
|
|
| 69 |
Returns:
|
| 70 |
OsSoluModel: The trained model.
|
| 71 |
"""
|
|
|
|
| 72 |
train_loss_fn = t.nn.CrossEntropyLoss()
|
| 73 |
wandb.watch(model, criterion=train_loss_fn, log="all", log_freq=10, log_graph=True)
|
| 74 |
|
|
|
|
| 78 |
|
| 79 |
# Train loop.
|
| 80 |
examples_seen = 0
|
| 81 |
+
train_data_iterator = iter(train_dataloader)
|
| 82 |
for epoch in range(config.num_epochs):
|
| 83 |
+
for i, batch in enumerate(tqdm(train_data_iterator
|
| 84 |
+
)):
|
| 85 |
+
data = batch["text"]
|
| 86 |
data = data.to(DEVICE)
|
|
|
|
| 87 |
|
| 88 |
predictions = model(data)
|
| 89 |
accuracy = (predictions.argmax(dim=-1) == target).sum() / len(data)
|
| 90 |
optimiser.zero_grad()
|
| 91 |
+
# loss = train_loss_fn(data, predictions)
|
| 92 |
loss.backward()
|
| 93 |
optimiser.step()
|
| 94 |
|
|
|
|
| 111 |
total_loss, num_correct = 0, 0
|
| 112 |
model.eval()
|
| 113 |
with t.inference_mode():
|
| 114 |
+
test_data_iterator = iter(test_dataloader)
|
| 115 |
+
for i, (data, target) in enumerate(tqdm(test_data_iterator)):
|
| 116 |
+
data = batch["text"]
|
| 117 |
data = data.to(DEVICE)
|
|
|
|
| 118 |
|
| 119 |
predictions = model(data)
|
| 120 |
num_correct += (predictions.argmax(dim=-1) == target).sum().item()
|
|
|
|
| 137 |
args = parse_arguments()
|
| 138 |
wandb.init(project=WANDB_PROJECT_NAME, config=args)
|
| 139 |
config = OsSoluConfig(args)
|
| 140 |
+
model = OsSoluModel(config).to(DEVICE)
|
| 141 |
|
| 142 |
+
start_data_time = time.time()
|
| 143 |
# Load and prep data.
|
| 144 |
ds = load_dataset("the_pile", streaming=True)
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
try:
|
| 147 |
+
ds = ds.remove_columns("meta")
|
| 148 |
+
except:
|
| 149 |
+
print("Dataset did not contain 'meta' column.")
|
| 150 |
+
|
| 151 |
+
train_dataset = ds["train"]
|
| 152 |
+
test_dataset = ds["test"]
|
| 153 |
+
|
| 154 |
+
# Tokenise the data before sending it to the model.
|
| 155 |
+
tokeniser = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
| 156 |
+
tokeniser.add_special_tokens({"pad_token": "<PAD>"})
|
| 157 |
+
|
| 158 |
+
train_dataset = train_dataset.map(lambda x: tokenise(x, tokeniser), batched=True).with_format("torch")
|
| 159 |
+
test_dataset = test_dataset.map(tokenise, batched=True).with_format("torch")
|
| 160 |
+
|
| 161 |
+
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size)
|
| 162 |
test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size)
|
| 163 |
+
print(f"Data loaded in {time.time() - start_data_time:.1f}s.")
|
| 164 |
+
|
| 165 |
return config, model, (train_dataloader, test_dataloader)
|
| 166 |
|
| 167 |
if __name__=="__main__":
|
model.py
CHANGED
|
@@ -3,7 +3,7 @@ import torch.nn as nn
|
|
| 3 |
import torch.functional as F
|
| 4 |
import torch.optim as optim
|
| 5 |
import wandb
|
| 6 |
-
|
| 7 |
from einops import rearrange, repeat, reduce
|
| 8 |
from utils import OsSoluConfig
|
| 9 |
|
|
@@ -22,7 +22,7 @@ class OsSoluModel(nn.Module):
|
|
| 22 |
self.final_ln = nn.LayerNorm(config.d_model, config.ln_eps)
|
| 23 |
|
| 24 |
def forward(self, x: t.Tensor) -> t.Tensor:
|
| 25 |
-
positional_embeddings = self.embed_positions(t.arange(x.size(1)))
|
| 26 |
token_embeddings = self.embed_tokens(x)
|
| 27 |
embeddings = positional_embeddings + token_embeddings
|
| 28 |
out = self.dropout(embeddings)
|
|
@@ -69,9 +69,9 @@ class UnidirectionalAttention(nn.Module):
|
|
| 69 |
super().__init__()
|
| 70 |
self.num_heads = config.num_heads
|
| 71 |
self.d_model = config.d_model
|
| 72 |
-
self.project_q = nn.Linear(config.
|
| 73 |
-
self.project_k = nn.Linear(config.
|
| 74 |
-
self.project_v = nn.Linear(config.
|
| 75 |
self.project_out = nn.Linear(config.d_model, config.d_model)
|
| 76 |
self.LARGE_NEGATIVE_VALUE = -1e5
|
| 77 |
|
|
@@ -84,7 +84,11 @@ class UnidirectionalAttention(nn.Module):
|
|
| 84 |
|
| 85 |
Q = self.hidden_to_heads(Q)
|
| 86 |
K = self.hidden_to_heads(K)
|
| 87 |
-
attention_pattern = einsum(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
return attention_pattern
|
| 90 |
|
|
@@ -95,18 +99,23 @@ class UnidirectionalAttention(nn.Module):
|
|
| 95 |
|
| 96 |
# Masking attention. Since GPT is unidirectional, it should only attend to previous tokens.
|
| 97 |
if seqlen > 1:
|
| 98 |
-
fst_range = t.arange(seqlen, device=
|
| 99 |
-
snd_range = t.arange(seqlen, device=
|
| 100 |
bool_array = fst_range < snd_range
|
| 101 |
-
|
| 102 |
|
| 103 |
|
| 104 |
attention_pattern = attention_pattern / t.sqrt(t.tensor(self.d_model // self.num_heads))
|
| 105 |
attention_score = attention_pattern.softmax(dim=-1)
|
| 106 |
|
| 107 |
V = self.hidden_to_heads(V)
|
| 108 |
-
out = einsum(
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
out = self.project_out(out)
|
| 111 |
|
| 112 |
|
|
|
|
| 3 |
import torch.functional as F
|
| 4 |
import torch.optim as optim
|
| 5 |
import wandb
|
| 6 |
+
from fancy_einsum import einsum
|
| 7 |
from einops import rearrange, repeat, reduce
|
| 8 |
from utils import OsSoluConfig
|
| 9 |
|
|
|
|
| 22 |
self.final_ln = nn.LayerNorm(config.d_model, config.ln_eps)
|
| 23 |
|
| 24 |
def forward(self, x: t.Tensor) -> t.Tensor:
|
| 25 |
+
positional_embeddings = self.embed_positions(t.arange(x.size(1), device=x.device))
|
| 26 |
token_embeddings = self.embed_tokens(x)
|
| 27 |
embeddings = positional_embeddings + token_embeddings
|
| 28 |
out = self.dropout(embeddings)
|
|
|
|
| 69 |
super().__init__()
|
| 70 |
self.num_heads = config.num_heads
|
| 71 |
self.d_model = config.d_model
|
| 72 |
+
self.project_q = nn.Linear(config.d_model, config.d_model)
|
| 73 |
+
self.project_k = nn.Linear(config.d_model, config.d_model)
|
| 74 |
+
self.project_v = nn.Linear(config.d_model, config.d_model)
|
| 75 |
self.project_out = nn.Linear(config.d_model, config.d_model)
|
| 76 |
self.LARGE_NEGATIVE_VALUE = -1e5
|
| 77 |
|
|
|
|
| 84 |
|
| 85 |
Q = self.hidden_to_heads(Q)
|
| 86 |
K = self.hidden_to_heads(K)
|
| 87 |
+
attention_pattern = einsum(
|
| 88 |
+
"batch num_heads seqlen_q head_size, "
|
| 89 |
+
"batch num_heads seqlen_k head_size ->"
|
| 90 |
+
"batch num_heads seqlen_q seqlen_k",
|
| 91 |
+
Q, K)
|
| 92 |
|
| 93 |
return attention_pattern
|
| 94 |
|
|
|
|
| 99 |
|
| 100 |
# Masking attention. Since GPT is unidirectional, it should only attend to previous tokens.
|
| 101 |
if seqlen > 1:
|
| 102 |
+
fst_range = t.arange(seqlen, device=x.device).unsqueeze(0).T
|
| 103 |
+
snd_range = t.arange(seqlen, device=x.device).unsqueeze(0)
|
| 104 |
bool_array = fst_range < snd_range
|
| 105 |
+
attention_pattern[..., bool_array] = self.LARGE_NEGATIVE_VALUE
|
| 106 |
|
| 107 |
|
| 108 |
attention_pattern = attention_pattern / t.sqrt(t.tensor(self.d_model // self.num_heads))
|
| 109 |
attention_score = attention_pattern.softmax(dim=-1)
|
| 110 |
|
| 111 |
V = self.hidden_to_heads(V)
|
| 112 |
+
out = einsum(
|
| 113 |
+
"batch num_heads seqlen_q seqlen_k,"
|
| 114 |
+
"batch num_heads seqlen_k head_size ->"
|
| 115 |
+
"batch num_heads seqlen_q head_size",
|
| 116 |
+
attention_score, V)
|
| 117 |
+
|
| 118 |
+
out = rearrange(out, "b nh s hs -> b s (nh hs)")
|
| 119 |
out = self.project_out(out)
|
| 120 |
|
| 121 |
|
requirements.txt
CHANGED
|
@@ -9,6 +9,7 @@ notebook
|
|
| 9 |
numpy-stl
|
| 10 |
plotly
|
| 11 |
torch
|
|
|
|
| 12 |
tqdm
|
| 13 |
wandb
|
| 14 |
zstandard
|
|
|
|
| 9 |
numpy-stl
|
| 10 |
plotly
|
| 11 |
torch
|
| 12 |
+
transformers
|
| 13 |
tqdm
|
| 14 |
wandb
|
| 15 |
zstandard
|
utils.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
class OsSoluConfig:
|
| 2 |
"""A class to hold hyperparameters for the model itself and for the training process."""
|
| 3 |
|
|
@@ -32,4 +35,42 @@ class OsSoluConfig:
|
|
| 32 |
self.num_heads = args["num_heads"]
|
| 33 |
self.optimiser_type = args["optimiser_type"]
|
| 34 |
self.self_attention_type = args["self_attention_type"]
|
| 35 |
-
self.vocab_size = args["vocab_size"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
|
| 4 |
class OsSoluConfig:
|
| 5 |
"""A class to hold hyperparameters for the model itself and for the training process."""
|
| 6 |
|
|
|
|
| 35 |
self.num_heads = args["num_heads"]
|
| 36 |
self.optimiser_type = args["optimiser_type"]
|
| 37 |
self.self_attention_type = args["self_attention_type"]
|
| 38 |
+
self.vocab_size = args["vocab_size"]
|
| 39 |
+
|
| 40 |
+
def tokenise(batch, tokeniser, num_gpus: int = 1, context_length: int = 1024):
    """Tokenise a batch of text and pack it into fixed-length sequences.

    All documents in the batch are joined with the tokeniser's end-of-sequence
    token, split into ``num_gpus`` character chunks, tokenised, and the
    resulting token stream is cut into rows of ``context_length - 1`` tokens.
    Each row is then prefixed with the tokeniser's beginning-of-sequence token
    id. Tokens that do not fill a complete row are discarded.

    This implementation is idiosyncratic to the Pile dataset (it reads the
    'text' field of the batch), but can be easily modified to work with
    e.g. C4.

    Args:
        batch (dict): The batch of text, as a dict with a 'text' field holding
            an iterable of strings.
        tokeniser (-): A huggingface-API tokeniser, of the type returned by
            AutoTokenizer.from_pretrained (depends on model chosen). It must
            expose ``eos_token``, ``bos_token_id`` and ``pad_token_id`` and be
            callable on a list of strings.
        num_gpus (int, optional): The number of GPUs available for data
            parallel training. Defaults to 1.
        context_length (int, optional): The context length of the model that
            will be trained on this data. Defaults to 1024.

    Returns:
        dict: A single field dictionary, 'text', whose value is an int64 numpy
        array of shape (batch_size, context_length) containing tokenised
        sequences. batch_size is 0 when the batch holds fewer than
        ``context_length - 1`` tokens.

    Raises:
        ValueError: If ``num_gpus`` is less than 1 or ``context_length`` is
            less than 2 (both would otherwise cause a division by zero).
    """
    if num_gpus < 1:
        raise ValueError(f"num_gpus must be at least 1, got {num_gpus}")
    if context_length < 2:
        raise ValueError(f"context_length must be at least 2, got {context_length}")

    full_text = tokeniser.eos_token.join(batch["text"])

    # Divide entire batch among all GPUs available. The final chunk runs to the
    # end of the text so that the remainder of the integer division is kept
    # instead of being silently dropped.
    chunk_len = len(full_text) // num_gpus
    bounds = [i * chunk_len for i in range(num_gpus)] + [len(full_text)]
    sequence_list = [full_text[start:stop] for start, stop in zip(bounds, bounds[1:])]

    # Tokenise sequences, removing padding tokens. Converting to numpy up front
    # keeps every later step in one array library.
    token_ids = tokeniser(sequence_list, return_tensors="pt", padding=True)["input_ids"]
    all_tokens = np.asarray(token_ids).reshape(-1).astype(np.int64, copy=False)
    all_tokens = all_tokens[all_tokens != tokeniser.pad_token_id]

    # Reshape all_tokens to be (batch_size x sequence_length) where each sequence has
    # a "beginning of sequence" token prepended to it.
    body_length = context_length - 1
    current_batch_size = len(all_tokens) // body_length
    all_tokens = all_tokens[:body_length * current_batch_size]
    all_tokens = all_tokens.reshape(current_batch_size, body_length)
    prefix = np.full((current_batch_size, 1), tokeniser.bos_token_id, dtype=np.int64)

    tokenised_text = np.concatenate([prefix, all_tokens], axis=1)
    assert tokenised_text.shape == (current_batch_size, context_length)
    print(f"{current_batch_size=}, {context_length=}")
    return {"text": tokenised_text}
|
| 75 |
+
|
| 76 |
+
|