File size: 3,994 Bytes
4255a26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import argparse
import logging
from pathlib import Path

import numpy as np
import torch
from sklearn.decomposition import PCA

from distiller.model2vec.distill import distill
from distiller.model2vec.model import StaticModel
from distiller.tokenlearn.pretrain import TextDataset, train_supervised
from distiller.tokenlearn.utils import collect_means_and_texts, create_vocab

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)

_MAX_N_VAL_SAMPLES = 10_000


def train_model(
	model_name: str,
	train_txt: list[str],
	train_vec: np.ndarray,
	device: str = "cpu",
	vocab_size: int | None = None,
	pca_dims: int = 256,
) -> StaticModel:
	"""
	Train a tokenlearn model.

	:param model_name: The sentence transformer model name for distillation.
	:param train_txt: List of texts to train on.
	:param train_vec: List of vectors to train on.
	:param device: Device to run the training on.
	:param vocab_size: The vocabulary size to use (optional).
	:param pca_dims: Number of dimensions to reduce the target embeddings to using PCA.
	    The model will use the same number of dimensions for the embeddings.
	:return: The trained model.
	"""
	# Reduce the target embeddings; the distilled model uses the same dimensionality.
	pca_for_targets = PCA(n_components=pca_dims)
	train_vec = pca_for_targets.fit_transform(train_vec)
	# Total explained variance is the sum of the per-component ratios.
	var = float(np.sum(pca_for_targets.explained_variance_ratio_))
	logger.info(f"Explained variance of target embeddings: {var:.2f}")

	# Split the data into training and validation sets.
	# We use a max of 10k samples as validation data.
	# NOTE: split via a positive index rather than negative slicing — with
	# val_samples == 0 (fewer than 10 texts), `train_txt[:-0]` would be the
	# EMPTY list and `train_txt[-0:]` the whole list, silently inverting the split.
	val_samples = min(_MAX_N_VAL_SAMPLES, len(train_txt) // 10)
	split = len(train_txt) - val_samples
	train_txt, val_txt = train_txt[:split], train_txt[split:]
	train_vec, val_vec = train_vec[:split], train_vec[split:]

	if vocab_size:
		# Create a vocabulary if a vocab size is specified
		vocab = create_vocab(texts=train_txt, vocab_size=vocab_size)
		logger.info(f"Vocabulary created with {len(vocab)} tokens.")
	else:
		vocab = None
	model = distill(model_name=model_name, quantize_to="float32", vocabulary=vocab, pca_dims=pca_dims)
	train_data = TextDataset(train_txt, torch.from_numpy(train_vec), model.tokenizer)
	val_data = TextDataset(val_txt, torch.from_numpy(val_vec), model.tokenizer)

	# Train the model
	return train_supervised(train_dataset=train_data, validation_dataset=val_data, model=model, device=device)


def save_model(model: StaticModel, save_path: str) -> None:
	"""
	Save the model to the specified path.

	:param model: The model to save.
	:param save_path: Path to save the model.
	"""
	model.save_pretrained(save_path)
	# Use the module-level logger (not the root logger via `logging.info`)
	# so this message respects the module's logger configuration, consistent
	# with every other log call in this file.
	logger.info(f"Model saved to {save_path}")
	logging.info(f"Model saved to {save_path}")


def _build_arg_parser() -> argparse.ArgumentParser:
	"""Construct the CLI argument parser for tokenlearn training."""
	parser = argparse.ArgumentParser(description="Train a Model2Vec using tokenlearn.")
	parser.add_argument(
		"--model-name",
		type=str,
		default="baai/bge-base-en-v1.5",
		help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').",
	)
	parser.add_argument(
		"--data-path",
		type=str,
		default="data/fineweb_bgebase",
		help="Path to the directory containing the dataset.",
	)
	parser.add_argument(
		"--save-path",
		type=str,
		required=True,
		help="Path to save the trained model.",
	)
	parser.add_argument(
		"--device",
		type=str,
		default="cpu",
		help="Device to run the training on (e.g., 'cpu', 'cuda').",
	)
	parser.add_argument(
		"--vocab-size",
		type=int,
		default=56000,
		help="The vocabulary size to use for training.",
	)
	parser.add_argument(
		"--pca-dims",
		type=int,
		default=256,
		help="Number of dimensions to reduce the target embeddings to using PCA.",
	)
	return parser


def main() -> None:
	"""Main function to train and save a Model2Vec model using tokenlearn."""
	args = _build_arg_parser().parse_args()

	# Gather the precomputed texts and mean vectors from the data directory.
	json_paths = sorted(Path(args.data_path).glob("*.json"))
	texts, vectors = collect_means_and_texts(json_paths)

	# Train, then persist the resulting static model.
	trained = train_model(
		args.model_name,
		texts,
		vectors,
		device=args.device,
		vocab_size=args.vocab_size,
		pca_dims=args.pca_dims,
	)
	save_model(trained, args.save_path)


if __name__ == "__main__":
	main()