Sarthak commited on
Commit
4255a26
·
1 Parent(s): 473c3a0

chore: moved tokenlearn in as an internal package

Browse files
src/distiller/tokenlearn/__init__.py ADDED
File without changes
src/distiller/tokenlearn/featurize.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ from collections.abc import Iterator
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ from datasets import load_dataset
9
+ from more_itertools import batched
10
+ from sentence_transformers import SentenceTransformer
11
+ from tqdm import tqdm
12
+ from transformers.tokenization_utils import PreTrainedTokenizer
13
+
14
+ _SAVE_EVERY = 32
15
+
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def featurize(
    dataset: Iterator[dict[str, str]],
    model: SentenceTransformer,
    output_dir: str,
    max_means: int,
    batch_size: int,
    text_key: str,
) -> None:
    """Featurize a stream of records and dump texts plus mean token embeddings to disk.

    Every ``_SAVE_EVERY`` batches the accumulated texts are written to
    ``feature_{i}.json`` and the matching mean embeddings to ``feature_{i}.npy``.
    Batches already present on disk are skipped, so interrupted runs can resume.

    :param dataset: Iterator over records; each record maps ``text_key`` to a string.
    :param model: Sentence transformer used to produce token embeddings.
    :param output_dir: Directory in which the feature files are written.
    :param max_means: Stop once roughly this many mean embeddings were produced.
    :param batch_size: Number of records encoded per call to ``model.encode``.
    :param text_key: Key under which each record stores its text.
    :raises ValueError: If the model has no embedding dimension or a record's
        text is not a string.
    """
    output_dir_path = Path(output_dir)
    output_dir_path.mkdir(parents=True, exist_ok=True)

    # Resume support: find the highest batch index already dumped to disk.
    largest_batch = max([int(x.stem.split("_")[1]) for x in output_dir_path.glob("*.json")], default=0)
    if largest_batch:
        logger.info(f"Resuming from batch {largest_batch}, skipping previous batches.")

    texts: list[str] = []
    embeddings: list[np.ndarray] = []
    dim = model.get_sentence_embedding_dimension()
    if dim is None:
        msg = "Model has no sentence embedding dimension."
        raise ValueError(msg)

    tokenizer: PreTrainedTokenizer = model.tokenizer
    # Binding i in case the dataset is empty.
    i = 0
    for i, batch in tqdm(enumerate(batched(dataset, n=batch_size))):
        if i * batch_size >= max_means:
            logger.info(f"Reached maximum number of means: {max_means}")
            break
        if largest_batch and i <= largest_batch:
            # This batch was already featurized in a previous run.
            continue
        # Renamed from `batch` to avoid shadowing the loop variable.
        batch_texts = [record[text_key] for record in batch]

        if not all(isinstance(x, str) for x in batch_texts):
            msg = f"Detected non-string at batch: {i}"
            raise ValueError(msg)

        batch_embeddings = model.encode(batch_texts, output_value="token_embeddings")  # type: ignore # Annoying
        for text, embedding in zip(batch_texts, batch_embeddings, strict=False):
            texts.append(_truncate_text(tokenizer, text))
            # Drop the first/last (special) token positions and average the rest.
            embeddings.append(embedding[1:-1].mean(axis=0).cpu().numpy())
        if i and i % _SAVE_EVERY == 0:
            _save_chunk(output_dir_path, i, texts, embeddings)
            texts = []
            embeddings = []
    if texts:
        _save_chunk(output_dir_path, i, texts, embeddings)


def _save_chunk(output_dir_path: Path, i: int, texts: list[str], embeddings: list[np.ndarray]) -> None:
    """Write one chunk as feature_{i}.json (texts) and feature_{i}.npy (vectors)."""
    # Context manager closes the handle deterministically; the original leaked
    # the file object returned by open() inside json.dump(...).
    with open(output_dir_path / f"feature_{i}.json", "w") as f:
        json.dump(texts, f, indent=4)
    np.save(output_dir_path / f"feature_{i}.npy", embeddings)
71
+
72
+
73
def _truncate_text(tokenizer: PreTrainedTokenizer, text: str) -> str:
    """Clip ``text`` so it round-trips within the tokenizer's maximum length."""
    limit = tokenizer.model_max_length
    token_ids = tokenizer.encode(text, truncation=True, max_length=limit)
    return tokenizer.decode(token_ids, skip_special_tokens=True)
81
+
82
+
83
def main() -> None:
    """Main function to featurize texts using a sentence transformer."""
    parser = argparse.ArgumentParser(description="Featurize texts using a sentence transformer.")
    # Flags are declared in the same order as before so --help output is unchanged.
    parser.add_argument("--model-name", type=str, default="baai/bge-base-en-v1.5", help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').")
    parser.add_argument("--output-dir", type=str, default=None, help="Directory to save the featurized texts.")
    parser.add_argument("--dataset-path", type=str, default="allenai/c4", help="The dataset path or name (e.g. 'allenai/c4').")
    parser.add_argument("--dataset-name", type=str, default="en", help="The dataset configuration name (e.g., 'en' for C4).")
    parser.add_argument("--dataset-split", type=str, default="train", help="The dataset split (e.g., 'train', 'validation').")
    # store_false => args.no_streaming defaults to True (streaming on) and the
    # flag turns it off; it is forwarded directly as `streaming=`.
    parser.add_argument("--no-streaming", action="store_false", help="Disable streaming mode when loading the dataset.")
    parser.add_argument("--max-means", type=int, default=1000000, help="The maximum number of mean embeddings to generate.")
    parser.add_argument("--key", type=str, default="text", help="The key of the text field in the dataset to featurize (default: 'text').")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size to use for encoding the texts.")

    args = parser.parse_args()

    # Derive a default output directory from the model and dataset names.
    if args.output_dir is not None:
        output_dir = args.output_dir
    else:
        sanitized_model = args.model_name.replace("/", "_")
        sanitized_dataset = args.dataset_path.replace("/", "_")
        output_dir = f"{sanitized_model}_{sanitized_dataset}_featurized"

    encoder = SentenceTransformer(args.model_name)
    dataset = load_dataset(
        args.dataset_path,
        name=args.dataset_name,
        split=args.dataset_split,
        streaming=args.no_streaming,
    )

    featurize(iter(dataset), encoder, output_dir, args.max_means, args.batch_size, args.key)
158
+
159
+
160
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
src/distiller/tokenlearn/pretrain.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import TYPE_CHECKING
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from tqdm import tqdm
12
+
13
+ from distiller.model2vec.distill.utils import select_optimal_device
14
+ from distiller.model2vec.model import StaticModel
15
+
16
+ if TYPE_CHECKING:
17
+ from tokenizers import Tokenizer
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class StaticModelFineTuner(nn.Module):
    def __init__(self, vectors: torch.Tensor, out_dim: int, pad_id: int) -> None:
        """
        Initialize from a model.

        :param vectors: The vectors to use.
        :param out_dim: The output dimension.
        :param pad_id: The padding id.
        """
        super().__init__()
        self.pad_id = pad_id
        norms = vectors.norm(dim=1)
        # Store unit-length vectors; the per-token norms become learnable weights.
        unit_vectors = vectors / norms[:, None]
        self.embeddings = nn.Embedding.from_pretrained(unit_vectors.clone(), freeze=False, padding_idx=pad_id)
        self.n_out = out_dim
        self.out_layer = nn.Linear(unit_vectors.shape[1], self.n_out)
        weights = torch.Tensor(norms)
        # Padding must never contribute to the weighted mean.
        weights[pad_id] = 0
        self.w = nn.Parameter(weights)

    def sub_forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Weighted mean-pool of the token embeddings, with padding masked out."""
        # Clamp out-of-range ids to 0 so indexing into the weight vector is safe.
        in_bounds = (input_ids >= 0) & (input_ids < self.w.shape[0])
        if not in_bounds.all():
            input_ids = torch.where(in_bounds, input_ids, 0)
        token_weights = self.w[input_ids]
        not_pad = (input_ids != self.pad_id).float()
        token_weights = token_weights * not_pad
        # Epsilon keeps the division finite for all-padding rows.
        n_tokens = not_pad.sum(1) + 1e-16
        # Same clamp against the embedding table bounds.
        emb_in_bounds = (input_ids >= 0) & (input_ids < self.embeddings.num_embeddings)
        if not emb_in_bounds.all():
            input_ids = torch.where(emb_in_bounds, input_ids, 0)
        token_vectors = self.embeddings(input_ids)
        # Batched dot product: sum of weight * embedding over the sequence axis.
        pooled = torch.bmm(token_weights[:, None, :], token_vectors).squeeze(1)
        # Divide by the non-padding count to simulate an actual mean.
        return pooled / n_tokens[:, None]

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the mean, and a classifier layer after."""
        pooled = self.sub_forward(x)
        return self.out_layer(pooled), pooled

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return self.embeddings.weight.device
73
+
74
+
75
+ class TextDataset(Dataset):
76
+ def __init__(self, texts: list[str], targets: torch.Tensor, tokenizer: Tokenizer) -> None:
77
+ """
78
+ Initialize the dataset.
79
+
80
+ :param texts: The texts to tokenize.
81
+ :param targets: The targets.
82
+ :param tokenizer: The tokenizer to use.
83
+ :raises ValueError: If the number of labels does not match the number of texts.
84
+ """
85
+ if len(targets) != len(texts):
86
+ msg = "Number of labels does not match number of texts."
87
+ raise ValueError(msg)
88
+ self.texts = [x[:20_000] for x in texts]
89
+ self.tokenized_texts: list[list[int]] = [
90
+ encoding.ids[:512] for encoding in tokenizer.encode_batch_fast(self.texts, add_special_tokens=False)
91
+ ]
92
+ self.targets = targets
93
+ self.tokenizer = tokenizer
94
+
95
+ def __len__(self) -> int:
96
+ """Return the length of the dataset."""
97
+ return len(self.tokenized_texts)
98
+
99
+ def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
100
+ """Gets an item."""
101
+ return self.tokenized_texts[index], self.targets[index]
102
+
103
+ @staticmethod
104
+ def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor, torch.Tensor]:
105
+ """Collate function."""
106
+ texts, targets = zip(*batch, strict=False)
107
+
108
+ tensors = [torch.LongTensor(x).int() for x in texts]
109
+ padded = pad_sequence(tensors, batch_first=True, padding_value=0)
110
+
111
+ return padded, torch.stack(targets)
112
+
113
+ def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
114
+ """Convert the dataset to a DataLoader."""
115
+ return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)
116
+
117
+
118
def train_supervised(
    train_dataset: TextDataset,
    validation_dataset: TextDataset,
    model: StaticModel,
    patience: int | None = 5,
    device: str | None = None,
    batch_size: int = 256,
    lr: float = 1e-3,
) -> StaticModel:
    """
    Train a tokenlearn model.

    Trains a StaticModelFineTuner against MSE targets, evaluates on the
    validation set every 1000 steps (and at each epoch end), early-stops on
    validation loss, and rebuilds a StaticModel from the best checkpoint.

    :param train_dataset: The training dataset.
    :param validation_dataset: The validation dataset.
    :param model: The model to train.
    :param patience: The number of evaluations to wait before early stopping.
    :param device: The device to train on.
    :param batch_size: The batch size.
    :param lr: The learning rate.
    :return: The trained model.
    """
    device = select_optimal_device(device)
    train_dataloader = train_dataset.to_dataloader(shuffle=True, batch_size=batch_size)

    # Initialize the trainable wrapper around the static embeddings.
    trainable_model = StaticModelFineTuner(
        torch.from_numpy(model.embedding),
        out_dim=train_dataset.targets.shape[1],
        pad_id=model.tokenizer.token_to_id("[PAD]"),
    )
    trainable_model.to(device)

    # Separate parameters for model and linear layer
    model_params = [
        *list(trainable_model.embeddings.parameters()),
        trainable_model.w,
        *list(trainable_model.out_layer.parameters()),
    ]

    # Create optimizer with separate parameter groups
    optimizer = torch.optim.AdamW(params=model_params, lr=lr)

    lowest_loss = float("inf")
    # BUG FIX: state_dict() returns *live references* to the parameter tensors,
    # so the previous plain snapshot silently tracked every later update and the
    # final "restore best" was a no-op. Clone the tensors instead.
    param_dict = _clone_state(trainable_model)
    curr_patience = patience
    stop = False

    criterion = nn.MSELoss()

    try:
        for epoch in range(100_000):
            logger.info(f"Epoch {epoch}")
            trainable_model.train()

            # Track train loss separately
            train_losses = []
            barred_train = tqdm(train_dataloader, desc=f"Epoch {epoch:03d} [Train]")

            for idx, (x, y) in enumerate(barred_train):
                optimizer.zero_grad()
                x = x.to(trainable_model.device)
                y_hat, _ = trainable_model(x)
                # Separate loss components
                train_loss = criterion(y_hat, y.to(trainable_model.device)).mean()

                train_loss.backward()
                optimizer.step()
                train_losses.append(train_loss.item())

                barred_train.set_description_str(f"Train Loss: {np.mean(train_losses[-10:]):.3f}")

                # Evaluate every 1000 steps and at the end of the epoch
                if (idx > 0 and idx % 1000 == 0) or idx == len(train_dataloader) - 1:
                    trainable_model.eval()
                    with torch.no_grad():
                        validation_losses = []
                        barred_val = tqdm(
                            validation_dataset.to_dataloader(shuffle=False, batch_size=batch_size), desc="Validation"
                        )
                        for x_val, y_val in barred_val:
                            x_val = x_val.to(trainable_model.device)
                            y_hat_val, _ = trainable_model(x_val)
                            val_loss = criterion(y_hat_val, y_val.to(trainable_model.device)).mean()
                            validation_losses.append(val_loss.item())
                            barred_val.set_description_str(f"Validation Loss: {np.mean(validation_losses):.3f}")

                    validation_loss = np.mean(validation_losses)
                    # Early stopping logic based on validation loss
                    if patience is not None and curr_patience is not None:
                        if (lowest_loss - validation_loss) > 1e-4:
                            # Snapshot the best weights so far (cloned, see above).
                            param_dict = _clone_state(trainable_model)
                            curr_patience = patience
                            lowest_loss = validation_loss
                        else:
                            curr_patience -= 1
                            if curr_patience == 0:
                                stop = True
                                break
                        logger.info(f"Patience level: {patience - curr_patience}")
                        logger.info(f"Validation loss: {validation_loss:.3f}")
                        logger.info(f"Lowest loss: {lowest_loss:.3f}")

                    trainable_model.train()

            if stop:
                logger.info("Early stopping")
                break

    except KeyboardInterrupt:
        logger.info("Training interrupted")

    trainable_model.eval()
    # Restore the best checkpoint (now a real copy, not live references).
    trainable_model.load_state_dict(param_dict)

    # Move the embeddings to the device (GPU)
    embeddings_weight = trainable_model.embeddings.weight.to(device)

    # Perform the forward pass on GPU
    with torch.no_grad():
        vectors = trainable_model.sub_forward(torch.arange(len(embeddings_weight))[:, None].to(device)).cpu().numpy()

    return StaticModel(vectors=vectors, tokenizer=model.tokenizer, config=model.config)


def _clone_state(module: nn.Module) -> dict[str, torch.Tensor]:
    """Return a detached, cloned copy of a module's state dict (safe to keep as a checkpoint)."""
    return {name: tensor.detach().clone() for name, tensor in module.state_dict().items()}
src/distiller/tokenlearn/train.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ from sklearn.decomposition import PCA
8
+
9
+ from distiller.model2vec.distill import distill
10
+ from distiller.model2vec.model import StaticModel
11
+ from distiller.tokenlearn.pretrain import TextDataset, train_supervised
12
+ from distiller.tokenlearn.utils import collect_means_and_texts, create_vocab
13
+
14
+ logging.basicConfig(level=logging.INFO)
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ _MAX_N_VAL_SAMPLES = 10_000
19
+
20
+
21
def train_model(
    model_name: str,
    train_txt: list[str],
    train_vec: np.ndarray,
    device: str = "cpu",
    vocab_size: int | None = None,
    pca_dims: int = 256,
) -> StaticModel:
    """
    Train a tokenlearn model.

    :param model_name: The sentence transformer model name for distillation.
    :param train_txt: List of texts to train on.
    :param train_vec: List of vectors to train on.
    :param device: Device to run the training on.
    :param vocab_size: The vocabulary size to use (optional).
    :param pca_dims: Number of dimensions to reduce the target embeddings to using PCA.
        The model will use the same number of dimensions for the embeddings.
    :raises ValueError: If fewer than two training texts are provided.
    :return: The trained model.
    """
    if len(train_txt) < 2:
        msg = "Need at least two training texts to form a train/validation split."
        raise ValueError(msg)

    pca_for_targets = PCA(n_components=pca_dims)
    train_vec = pca_for_targets.fit_transform(train_vec)
    var = np.cumsum(pca_for_targets.explained_variance_ratio_)[-1]
    logger.info(f"Explained variance of target embeddings: {var:.2f}")

    # Split the data into training and validation sets.
    # We use a max of 10k samples as validation data.
    # BUG FIX: for fewer than 10 samples `len // 10` was 0 and the `[:-0]`
    # slice emptied the training set (validation got everything); keep at
    # least one validation sample so both splits are non-empty.
    val_samples = min(_MAX_N_VAL_SAMPLES, max(1, len(train_txt) // 10))
    train_txt, train_vec, val_txt, val_vec = (
        train_txt[:-val_samples],
        train_vec[:-val_samples],
        train_txt[-val_samples:],
        train_vec[-val_samples:],
    )

    if vocab_size:
        # Create a vocabulary if a vocab size is specified
        vocab = create_vocab(texts=train_txt, vocab_size=vocab_size)
        logger.info(f"Vocabulary created with {len(vocab)} tokens.")
    else:
        vocab = None
    model = distill(model_name=model_name, quantize_to="float32", vocabulary=vocab, pca_dims=pca_dims)
    train_data = TextDataset(train_txt, torch.from_numpy(train_vec), model.tokenizer)
    val_data = TextDataset(val_txt, torch.from_numpy(val_vec), model.tokenizer)

    # Train the model
    return train_supervised(train_dataset=train_data, validation_dataset=val_data, model=model, device=device)
68
+
69
+
70
def save_model(model: StaticModel, save_path: str) -> None:
    """
    Save the model to the specified path.

    :param model: The model to save.
    :param save_path: Path to save the model.
    """
    model.save_pretrained(save_path)
    # Consistency fix: use the module logger (the original called logging.info
    # on the root logger, bypassing this module's logger).
    logger.info(f"Model saved to {save_path}")
79
+
80
+
81
def main() -> None:
    """Main function to train and save a Model2Vec model using tokenlearn."""
    parser = argparse.ArgumentParser(description="Train a Model2Vec using tokenlearn.")
    # Flags are declared in the same order as before so --help output is unchanged.
    parser.add_argument("--model-name", type=str, default="baai/bge-base-en-v1.5", help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').")
    parser.add_argument("--data-path", type=str, default="data/fineweb_bgebase", help="Path to the directory containing the dataset.")
    parser.add_argument("--save-path", type=str, required=True, help="Path to save the trained model.")
    parser.add_argument("--device", type=str, default="cpu", help="Device to run the training on (e.g., 'cpu', 'cuda').")
    parser.add_argument("--vocab-size", type=int, default=56000, help="The vocabulary size to use for training.")
    parser.add_argument("--pca-dims", type=int, default=256, help="Number of dimensions to reduce the target embeddings to using PCA.")
    args = parser.parse_args()

    # Gather the featurized texts and their mean vectors from disk.
    json_paths = sorted(Path(args.data_path).glob("*.json"))
    train_txt, train_vec = collect_means_and_texts(json_paths)

    # Fit the static model on the collected features, then persist it.
    trained = train_model(
        args.model_name,
        train_txt,
        train_vec,
        device=args.device,
        vocab_size=args.vocab_size,
        pca_dims=args.pca_dims,
    )
    save_model(trained, args.save_path)
131
+
132
+
133
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
src/distiller/tokenlearn/utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from collections import Counter
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import regex
8
+ from tqdm import tqdm
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
def create_vocab(texts: list[str], vocab_size: int = 56_000) -> list[str]:
    """
    Create a vocabulary from a list of texts.

    :param texts: The list of texts to create the vocabulary from.
    :param vocab_size: The size of the vocabulary. Defaults to 56,000, which is the vocab_size used for our 32M models.
    :return: The vocabulary.
    """
    # Local import: the stdlib `re` handles this pattern identically (both
    # engines are Unicode-aware for \w/\s on str), so the third-party
    # `regex` package is unnecessary here.
    import re

    # \w+ grabs word runs, [^\w\s]+ grabs punctuation runs.
    tokenizer_pattern = re.compile(r"\w+|[^\w\s]+")

    # Tokenize all texts (lowercased) and count frequencies incrementally to
    # avoid materializing one giant token list. Counter preserves first-seen
    # insertion order, so most_common tie-breaking matches the old behavior.
    token_counts: Counter[str] = Counter()
    for text in tqdm(texts, desc="Tokenizing texts"):
        token_counts.update(tokenizer_pattern.findall(text.lower()))

    # Get the most common tokens as the vocabulary
    return [word for word, _ in token_counts.most_common(vocab_size)]
33
+
34
+
35
def collect_means_and_texts(paths: list[Path]) -> tuple[list[str], np.ndarray]:
    """Collect means and texts from a list of paths.

    Each ``*.json`` file holds a list of texts and sits next to a ``*.npy``
    file with the matching mean vectors. Rows whose vector contains NaN are
    dropped from both collections.

    :param paths: Candidate paths; non-``.json`` entries are skipped.
    :return: The collected texts and their stacked mean vectors (an empty
        array when nothing could be loaded).
    """
    txts: list[str] = []
    vectors_list: list[np.ndarray] = []
    for items_path in tqdm(paths, desc="Collecting means and texts"):
        if not items_path.name.endswith(".json"):
            continue
        # The vectors live next to the JSON file under the same stem.
        # (The original derived this via a no-op str.replace("", "") chain;
        # with_suffix expresses the same mapping directly.)
        base_path = items_path.with_suffix("")
        vectors_path = items_path.with_suffix(".npy")
        try:
            with open(items_path) as f:
                items = json.load(f)
            vectors = np.load(vectors_path, allow_pickle=False)
        except (KeyError, FileNotFoundError, ValueError) as e:
            logger.info(f"Error loading data from {base_path}: {e}")
            continue

        # Filter out any NaN vectors before appending
        vectors = np.array(vectors)
        items = np.array(items)
        # NOTE(review): assumes `vectors` is 2-D (n_texts, dim); a 1-D array
        # would make axis=1 fail — confirm against the featurize output.
        non_nan_indices = ~np.isnan(vectors).any(axis=1)
        txts.extend(items[non_nan_indices].tolist())
        vectors_list.append(vectors[non_nan_indices])

    all_vectors = np.concatenate(vectors_list, axis=0) if vectors_list else np.array([])
    return txts, all_vectors
src/distiller/tokenlearn/version.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
# Version as a tuple for programmatic comparison, and the canonical dotted string.
__version_triple__ = (0, 2, 0)
__version__ = "{}.{}.{}".format(*__version_triple__)