Sarthak committed on
Commit ·
4255a26
1
Parent(s): 473c3a0
chore: moved tokenlearn in as an internal package
Browse files
src/distiller/tokenlearn/__init__.py
ADDED
|
File without changes
|
src/distiller/tokenlearn/featurize.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from collections.abc import Iterator
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
from more_itertools import batched
|
| 10 |
+
from sentence_transformers import SentenceTransformer
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 13 |
+
|
| 14 |
+
# Flush accumulated texts/embeddings to disk every this many batches.
_SAVE_EVERY = 32


logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def featurize(
    dataset: Iterator[dict[str, str]],
    model: SentenceTransformer,
    output_dir: str,
    max_means: int,
    batch_size: int,
    text_key: str,
) -> None:
    """Encode a dataset in batches and dump mean token embeddings to disk.

    Every ``_SAVE_EVERY`` batches, writes ``feature_<i>.json`` (the truncated
    input texts) and ``feature_<i>.npy`` (one mean token embedding per text)
    into ``output_dir``. Supports resuming: batches up to the highest
    already-saved index are skipped.

    :param dataset: Iterator over records containing the text field.
    :param model: The sentence transformer used for encoding.
    :param output_dir: Directory the feature files are written to.
    :param max_means: Maximum number of mean embeddings to produce.
    :param batch_size: Number of texts encoded per batch.
    :param text_key: Key of the text field in each record.
    :raises ValueError: If the model has no embedding dimension or a batch
        contains a non-string value.
    """
    output_dir_path = Path(output_dir)
    output_dir_path.mkdir(parents=True, exist_ok=True)

    # Resume support: find the highest batch index already written out.
    largest_batch = max((int(x.stem.split("_")[1]) for x in output_dir_path.glob("*.json")), default=0)
    if largest_batch:
        logger.info(f"Resuming from batch {largest_batch}, skipping previous batches.")

    texts: list[str] = []
    embeddings: list[np.ndarray] = []
    dim = model.get_sentence_embedding_dimension()
    if dim is None:
        msg = "Model has no sentence embedding dimension."
        raise ValueError(msg)

    tokenizer: PreTrainedTokenizer = model.tokenizer
    # Binding i in case the dataset is empty.
    i = 0
    for i, batch in tqdm(enumerate(batched(dataset, n=batch_size))):
        if i * batch_size >= max_means:
            logger.info(f"Reached maximum number of means: {max_means}")
            break
        if largest_batch and i <= largest_batch:
            # Already featurized in a previous run.
            continue
        # Use a distinct name instead of rebinding `batch` from records to strings.
        batch_texts = [x[text_key] for x in batch]

        if not all(isinstance(x, str) for x in batch_texts):
            msg = f"Detected non-string at batch: {i}"
            raise ValueError(msg)

        batch_embeddings = model.encode(batch_texts, output_value="token_embeddings")  # type: ignore # Annoying
        for text, embedding in zip(batch_texts, batch_embeddings, strict=False):
            texts.append(_truncate_text(tokenizer, text))
            # Drop the [CLS]/[SEP] positions before averaging token embeddings.
            embeddings.append(embedding[1:-1].mean(axis=0).cpu().numpy())
        if i and i % _SAVE_EVERY == 0:
            _save_features(output_dir_path, i, texts, embeddings)
            texts = []
            embeddings = []
    if texts:
        _save_features(output_dir_path, i, texts, embeddings)


def _save_features(output_dir_path: Path, batch_idx: int, texts: list[str], embeddings: list[np.ndarray]) -> None:
    """Write one chunk of texts (.json) and embeddings (.npy) to disk.

    Bug fix vs. the original: ``json.dump(texts, open(...))`` never closed the
    file handle; a context manager guarantees the JSON file is flushed and closed.
    """
    with open(output_dir_path / f"feature_{batch_idx}.json", "w") as f:
        json.dump(texts, f, indent=4)
    np.save(output_dir_path / f"feature_{batch_idx}.npy", embeddings)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _truncate_text(tokenizer: PreTrainedTokenizer, text: str) -> str:
    """Shorten ``text`` so that it fits within the tokenizer's maximum length."""
    # Round-trip through the tokenizer: encode with truncation, then decode
    # back to a string without the special tokens the encoder added.
    token_ids = tokenizer.encode(text, truncation=True, max_length=tokenizer.model_max_length)
    return tokenizer.decode(token_ids, skip_special_tokens=True)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def main() -> None:
    """Main function to featurize texts using a sentence transformer."""
    parser = argparse.ArgumentParser(description="Featurize texts using a sentence transformer.")
    parser.add_argument(
        "--model-name",
        type=str,
        default="baai/bge-base-en-v1.5",
        help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Directory to save the featurized texts.",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default="allenai/c4",
        help="The dataset path or name (e.g. 'allenai/c4').",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="en",
        help="The dataset configuration name (e.g., 'en' for C4).",
    )
    parser.add_argument(
        "--dataset-split",
        type=str,
        default="train",
        help="The dataset split (e.g., 'train', 'validation').",
    )
    # NOTE: with action="store_false", args.no_streaming defaults to True and
    # flips to False when --no-streaming is passed; it is fed directly into
    # load_dataset(streaming=...), so streaming is ON by default.
    parser.add_argument(
        "--no-streaming",
        action="store_false",
        help="Disable streaming mode when loading the dataset.",
    )
    parser.add_argument(
        "--max-means",
        type=int,
        default=1000000,
        help="The maximum number of mean embeddings to generate.",
    )
    parser.add_argument(
        "--key",
        type=str,
        default="text",
        help="The key of the text field in the dataset to featurize (default: 'text').",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size to use for encoding the texts.",
    )

    args = parser.parse_args()

    # Derive a default output directory from the model and dataset names.
    if args.output_dir is None:
        model_name = args.model_name.replace("/", "_")
        dataset_path = args.dataset_path.replace("/", "_")
        output_dir = f"{model_name}_{dataset_path}_featurized"
    else:
        output_dir = args.output_dir

    model = SentenceTransformer(args.model_name)
    dataset = load_dataset(
        args.dataset_path,
        name=args.dataset_name,
        split=args.dataset_split,
        streaming=args.no_streaming,
    )

    featurize(iter(dataset), model, output_dir, args.max_means, args.batch_size, args.key)


if __name__ == "__main__":
    main()
|
src/distiller/tokenlearn/pretrain.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 10 |
+
from torch.utils.data import DataLoader, Dataset
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from distiller.model2vec.distill.utils import select_optimal_device
|
| 14 |
+
from distiller.model2vec.model import StaticModel
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from tokenizers import Tokenizer
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class StaticModelFineTuner(nn.Module):
    """A static embedding table with trainable per-token weights and a linear head."""

    def __init__(self, vectors: torch.Tensor, out_dim: int, pad_id: int) -> None:
        """
        Initialize from a model.

        :param vectors: The vectors to use.
        :param out_dim: The output dimension.
        :param pad_id: The padding id.
        """
        super().__init__()
        self.pad_id = pad_id
        norms = vectors.norm(dim=1)
        # Normalize the vectors
        vectors = vectors / norms[:, None]
        self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
        self.n_out = out_dim
        self.out_layer = nn.Linear(vectors.shape[1], self.n_out)
        # The pre-normalization norms act as initial per-token weights; the
        # padding token gets weight 0 so it never contributes to the mean.
        weights = torch.Tensor(norms)
        weights[pad_id] = 0
        self.w = nn.Parameter(weights)

    def sub_forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Forward pass through the mean."""
        # Fix for index out of bounds issue - filter out invalid tokens
        valid_mask = (input_ids >= 0) & (input_ids < self.w.shape[0])
        if not valid_mask.all():
            # Out-of-range ids are redirected to token 0. NOTE(review): this
            # silently mixes token 0's weight/embedding into the mean for
            # invalid ids — confirm that is acceptable upstream.
            input_ids = torch.where(valid_mask, input_ids, 0)
        w = self.w[input_ids]
        # Zero the weights of padding positions so they drop out of the mean.
        zeros = (input_ids != self.pad_id).float()
        w = w * zeros
        # Add a small epsilon to avoid division by zero
        length = zeros.sum(1) + 1e-16
        # Fix for embedding index out of bounds issue
        valid_emb_mask = (input_ids >= 0) & (input_ids < self.embeddings.num_embeddings)
        if not valid_emb_mask.all():
            input_ids = torch.where(valid_emb_mask, input_ids, 0)
        embedded = self.embeddings(input_ids)
        # Zero out the padding: weighted sum over the sequence dimension,
        # (B, 1, S) x (B, S, D) -> (B, D).
        embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
        # Simulate actual mean: divide by the number of non-padding tokens.
        return embedded / length[:, None]

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the mean, and a classifier layer after."""
        embedded = self.sub_forward(x)
        return self.out_layer(embedded), embedded

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return self.embeddings.weight.device
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class TextDataset(Dataset):
    """Pairs pre-tokenized texts with target vectors for DataLoader consumption."""

    def __init__(self, texts: list[str], targets: torch.Tensor, tokenizer: Tokenizer) -> None:
        """
        Initialize the dataset.

        :param texts: The texts to tokenize.
        :param targets: The targets.
        :param tokenizer: The tokenizer to use.
        :raises ValueError: If the number of labels does not match the number of texts.
        """
        if len(targets) != len(texts):
            msg = "Number of labels does not match number of texts."
            raise ValueError(msg)
        # Cap raw text length (20k chars) to bound tokenizer cost.
        self.texts = [x[:20_000] for x in texts]
        # Tokenize everything up front; each sequence is truncated to 512 ids.
        self.tokenized_texts: list[list[int]] = [
            encoding.ids[:512] for encoding in tokenizer.encode_batch_fast(self.texts, add_special_tokens=False)
        ]
        self.targets = targets
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.tokenized_texts)

    def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
        """Gets an item."""
        return self.tokenized_texts[index], self.targets[index]

    @staticmethod
    def collate_fn(batch: list[tuple[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate function: pad token-id sequences and stack the targets."""
        texts, targets = zip(*batch, strict=False)

        # Pad with 0. NOTE(review): this assumes 0 is the pad id — the model
        # side masks via its own pad_id; confirm the two agree.
        tensors = [torch.LongTensor(x).int() for x in texts]
        padded = pad_sequence(tensors, batch_first=True, padding_value=0)

        return padded, torch.stack(targets)

    def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
        """Convert the dataset to a DataLoader."""
        return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def train_supervised(
    train_dataset: TextDataset,
    validation_dataset: TextDataset,
    model: StaticModel,
    patience: int | None = 5,
    device: str | None = None,
    batch_size: int = 256,
    lr: float = 1e-3,
) -> StaticModel:
    """
    Train a tokenlearn model.

    Fine-tunes the static embeddings against precomputed target vectors with an
    MSE objective, early-stops on validation loss, then bakes the learned
    per-token weights back into a plain StaticModel.

    :param train_dataset: The training dataset.
    :param validation_dataset: The validation dataset.
    :param model: The model to train.
    :param patience: The number of epochs to wait before early stopping.
    :param device: The device to train on.
    :param batch_size: The batch size.
    :param lr: The learning rate.
    :return: The trained model.
    """
    device = select_optimal_device(device)
    train_dataloader = train_dataset.to_dataloader(shuffle=True, batch_size=batch_size)

    # Initialize the model
    trainable_model = StaticModelFineTuner(
        torch.from_numpy(model.embedding),
        out_dim=train_dataset.targets.shape[1],
        pad_id=model.tokenizer.token_to_id("[PAD]"),
    )
    trainable_model.to(device)

    # Separate parameters for model and linear layer
    model_params = [
        *list(trainable_model.embeddings.parameters()),
        trainable_model.w,
        *list(trainable_model.out_layer.parameters()),
    ]

    # Create optimizer with separate parameter groups
    optimizer = torch.optim.AdamW(params=model_params, lr=lr)

    lowest_loss = float("inf")
    param_dict = trainable_model.state_dict()
    curr_patience = patience
    stop = False

    criterion = nn.MSELoss()

    try:
        # Effectively "train forever": early stopping or KeyboardInterrupt ends it.
        for epoch in range(100_000):
            logger.info(f"Epoch {epoch}")
            trainable_model.train()

            # Track train loss separately
            train_losses = []
            barred_train = tqdm(train_dataloader, desc=f"Epoch {epoch:03d} [Train]")

            for idx, (x, y) in enumerate(barred_train):
                optimizer.zero_grad()
                x = x.to(trainable_model.device)
                y_hat, _ = trainable_model(x)
                # MSE between predictions and the precomputed target embeddings.
                train_loss = criterion(y_hat, y.to(trainable_model.device)).mean()

                train_loss.backward()

                optimizer.step()
                train_losses.append(train_loss.item())

                # Running mean over the last 10 steps for the progress bar.
                barred_train.set_description_str(f"Train Loss: {np.mean(train_losses[-10:]):.3f}")

                # Evaluate every 1000 steps and at the end of the epoch
                if (idx > 0 and idx % 1000 == 0) or idx == len(train_dataloader) - 1:
                    trainable_model.eval()
                    with torch.no_grad():
                        validation_losses = []
                        barred_val = tqdm(
                            validation_dataset.to_dataloader(shuffle=False, batch_size=batch_size), desc="Validation"
                        )
                        for x_val, y_val in barred_val:
                            x_val = x_val.to(trainable_model.device)
                            y_hat_val, _ = trainable_model(x_val)
                            val_loss = criterion(y_hat_val, y_val.to(trainable_model.device)).mean()
                            validation_losses.append(val_loss.item())
                            barred_val.set_description_str(f"Validation Loss: {np.mean(validation_losses):.3f}")

                        validation_loss = np.mean(validation_losses)
                        # Early stopping logic based on validation loss
                        if patience is not None and curr_patience is not None:
                            # Require a minimum improvement of 1e-4 to reset patience.
                            if (lowest_loss - validation_loss) > 1e-4:
                                param_dict = trainable_model.state_dict()  # Save best model state based on validation loss
                                curr_patience = patience
                                lowest_loss = validation_loss
                            else:
                                curr_patience -= 1
                                if curr_patience == 0:
                                    stop = True
                                    break
                            logger.info(f"Patience level: {patience - curr_patience}")
                            logger.info(f"Validation loss: {validation_loss:.3f}")
                            logger.info(f"Lowest loss: {lowest_loss:.3f}")

                    trainable_model.train()

            if stop:
                logger.info("Early stopping")
                break

    except KeyboardInterrupt:
        logger.info("Training interrupted")

    trainable_model.eval()
    # Load best model based on validation loss
    trainable_model.load_state_dict(param_dict)

    # Move the embeddings to the device (GPU)
    embeddings_weight = trainable_model.embeddings.weight.to(device)

    # Perform the forward pass on GPU: feed every token id through the
    # weighted-mean path as a length-1 sequence (shape (V, 1)) so each output
    # row is that token's weighted embedding.
    with torch.no_grad():
        vectors = trainable_model.sub_forward(torch.arange(len(embeddings_weight))[:, None].to(device)).cpu().numpy()

    return StaticModel(vectors=vectors, tokenizer=model.tokenizer, config=model.config)
|
src/distiller/tokenlearn/train.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from sklearn.decomposition import PCA
|
| 8 |
+
|
| 9 |
+
from distiller.model2vec.distill import distill
|
| 10 |
+
from distiller.model2vec.model import StaticModel
|
| 11 |
+
from distiller.tokenlearn.pretrain import TextDataset, train_supervised
|
| 12 |
+
from distiller.tokenlearn.utils import collect_means_and_texts, create_vocab
|
| 13 |
+
|
| 14 |
+
# Configure root logging once at import time so CLI runs emit INFO-level logs.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)

# Upper bound on the number of samples held out for validation.
_MAX_N_VAL_SAMPLES = 10_000
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def train_model(
    model_name: str,
    train_txt: list[str],
    train_vec: np.ndarray,
    device: str = "cpu",
    vocab_size: int | None = None,
    pca_dims: int = 256,
) -> StaticModel:
    """
    Train a tokenlearn model.

    :param model_name: The sentence transformer model name for distillation.
    :param train_txt: List of texts to train on.
    :param train_vec: List of vectors to train on.
    :param device: Device to run the training on.
    :param vocab_size: The vocabulary size to use (optional).
    :param pca_dims: Number of dimensions to reduce the target embeddings to using PCA.
        The model will use the same number of dimensions for the embeddings.
    :return: The trained model.
    :raises ValueError: If there are too few texts to carve out a validation set.
    """
    # Reduce the target embeddings; the last entry of the cumulative ratio is
    # the total explained variance of the kept components.
    pca_for_targets = PCA(n_components=pca_dims)
    train_vec = pca_for_targets.fit_transform(train_vec)
    var = np.cumsum(pca_for_targets.explained_variance_ratio_)[-1]
    logger.info(f"Explained variance of target embeddings: {var:.2f}")

    # Split the data into training and validation sets.
    # We use a max of 10k samples as validation data.
    # Bug fix: with fewer than 10 texts the original computed val_samples == 0,
    # and train_txt[:-0] is an EMPTY list — all data silently went to
    # validation and none to training. Guarantee at least one validation
    # sample and at least one training sample.
    if len(train_txt) < 2:
        msg = "Need at least two texts to split into training and validation sets."
        raise ValueError(msg)
    val_samples = min(_MAX_N_VAL_SAMPLES, max(1, len(train_txt) // 10))
    train_txt, train_vec, val_txt, val_vec = (
        train_txt[:-val_samples],
        train_vec[:-val_samples],
        train_txt[-val_samples:],
        train_vec[-val_samples:],
    )

    if vocab_size:
        # Create a vocabulary if a vocab size is specified
        vocab = create_vocab(texts=train_txt, vocab_size=vocab_size)
        logger.info(f"Vocabulary created with {len(vocab)} tokens.")
    else:
        vocab = None
    model = distill(model_name=model_name, quantize_to="float32", vocabulary=vocab, pca_dims=pca_dims)
    train_data = TextDataset(train_txt, torch.from_numpy(train_vec), model.tokenizer)
    val_data = TextDataset(val_txt, torch.from_numpy(val_vec), model.tokenizer)

    # Train the model
    return train_supervised(train_dataset=train_data, validation_dataset=val_data, model=model, device=device)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def save_model(model: StaticModel, save_path: str) -> None:
    """
    Save the model to the specified path.

    :param model: The model to save.
    :param save_path: Path to save the model.
    """
    model.save_pretrained(save_path)
    # Bug fix: use the module-level logger instead of logging.info, which logs
    # through (and can implicitly configure) the root logger.
    logger.info(f"Model saved to {save_path}")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def main() -> None:
    """Main function to train and save a Model2Vec model using tokenlearn."""
    parser = argparse.ArgumentParser(description="Train a Model2Vec using tokenlearn.")
    parser.add_argument(
        "--model-name",
        type=str,
        default="baai/bge-base-en-v1.5",
        help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').",
    )
    parser.add_argument(
        "--data-path",
        type=str,
        default="data/fineweb_bgebase",
        help="Path to the directory containing the dataset.",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        required=True,
        help="Path to save the trained model.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run the training on (e.g., 'cpu', 'cuda').",
    )
    parser.add_argument(
        "--vocab-size",
        type=int,
        default=56000,
        help="The vocabulary size to use for training.",
    )
    parser.add_argument(
        "--pca-dims",
        type=int,
        default=256,
        help="Number of dimensions to reduce the target embeddings to using PCA.",
    )
    args = parser.parse_args()

    # Collect paths for training data (the feature_*.json files written by the
    # featurize step; the matching .npy files are resolved by the loader).
    paths = sorted(Path(args.data_path).glob("*.json"))
    train_txt, train_vec = collect_means_and_texts(paths)

    # Train the model
    model = train_model(
        args.model_name, train_txt, train_vec, device=args.device, vocab_size=args.vocab_size, pca_dims=args.pca_dims
    )
    save_model(model, args.save_path)


if __name__ == "__main__":
    main()
|
src/distiller/tokenlearn/utils.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from collections import Counter
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import regex
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def create_vocab(texts: list[str], vocab_size: int = 56_000) -> list[str]:
    """
    Create a vocabulary from a list of texts.

    :param texts: The list of texts to create the vocabulary from.
    :param vocab_size: The size of the vocabulary. Defaults to 56,000, which is the vocab_size used for our 32M models.
    :return: The vocabulary.
    """
    # Word characters, or runs of punctuation, exactly as before.
    word_pattern = regex.compile(r"\w+|[^\w\s]+")

    # Count tokens incrementally instead of materializing one giant list.
    token_counts: Counter[str] = Counter()
    for text in tqdm(texts, desc="Tokenizing texts"):
        token_counts.update(word_pattern.findall(text.lower()))

    # The vocabulary is the most frequent tokens, ranked by count.
    return [token for token, _ in token_counts.most_common(vocab_size)]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def collect_means_and_texts(paths: list[Path]) -> tuple[list[str], np.ndarray]:
    """Collect means and texts from a list of paths.

    Each ``<stem>.json`` file (a list of texts) is paired with its
    ``<stem>.npy`` file (the corresponding mean embeddings). Rows containing
    NaN vectors are dropped; unreadable pairs are logged and skipped.

    :param paths: Paths to the .json feature files.
    :return: The collected texts and a (N, dim) array of their embeddings
        (empty 1-D array if nothing was collected).
    """
    txts: list[str] = []
    vectors_list: list[np.ndarray] = []
    for items_path in tqdm(paths, desc="Collecting means and texts"):
        if not items_path.name.endswith(".json"):
            continue
        # Cleanup: the original built this via a no-op str.replace("", "")
        # dance; the embeddings simply live next to the texts with the same stem.
        vectors_path = items_path.with_suffix(".npy")
        try:
            with open(items_path) as f:
                items = json.load(f)
            vectors = np.load(vectors_path, allow_pickle=False)
        except (KeyError, FileNotFoundError, ValueError) as e:
            # Best-effort: skip broken pairs rather than aborting the whole run.
            logger.info(f"Error loading data from {items_path.stem}: {e}")
            continue

        # Filter out any NaN vectors before appending
        vectors = np.asarray(vectors)
        items_arr = np.asarray(items)
        non_nan_indices = ~np.isnan(vectors).any(axis=1)
        txts.extend(items_arr[non_nan_indices].tolist())
        vectors_list.append(vectors[non_nan_indices])

    all_vectors = np.concatenate(vectors_list, axis=0) if vectors_list else np.array([])
    return txts, all_vectors
|
src/distiller/tokenlearn/version.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Package version, exposed both as a tuple and as a dotted string.
__version_triple__ = (0, 2, 0)
__version__ = ".".join(str(part) for part in __version_triple__)
|