Sarthak
chore: moved tokenlearn as an internal package
4255a26
import argparse
import logging
from pathlib import Path
import numpy as np
import torch
from sklearn.decomposition import PCA
from distiller.model2vec.distill import distill
from distiller.model2vec.model import StaticModel
from distiller.tokenlearn.pretrain import TextDataset, train_supervised
from distiller.tokenlearn.utils import collect_means_and_texts, create_vocab
# Configure the root logger at INFO level so informational messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger (used by train_model for progress messages).
logger = logging.getLogger(__name__)
# Upper bound on the number of samples train_model holds out for validation.
_MAX_N_VAL_SAMPLES = 10_000
def train_model(
    model_name: str,
    train_txt: list[str],
    train_vec: np.ndarray,
    device: str = "cpu",
    vocab_size: int | None = None,
    pca_dims: int = 256,
) -> StaticModel:
    """
    Train a tokenlearn model.

    The target vectors are reduced to ``pca_dims`` dimensions with PCA, the tail of the
    data is held out for validation, a model is distilled from ``model_name`` and then
    fitted on the remaining samples with ``train_supervised``.

    :param model_name: The sentence transformer model name for distillation.
    :param train_txt: List of texts to train on.
    :param train_vec: List of vectors to train on.
    :param device: Device to run the training on.
    :param vocab_size: The vocabulary size to use (optional).
    :param pca_dims: Number of dimensions to reduce the target embeddings to using PCA.
        The model will use the same number of dimensions for the embeddings.
    :return: The trained model.
    :raises ValueError: If fewer than two samples are provided, since at least one
        training and one validation sample are required.
    """
    n_samples = len(train_txt)
    if n_samples < 2:
        raise ValueError(
            f"At least 2 samples are required to build a train/validation split, got {n_samples}."
        )

    # Note: the PCA is fitted on all target vectors, i.e. before the validation split.
    pca_for_targets = PCA(n_components=pca_dims)
    train_vec = pca_for_targets.fit_transform(train_vec)
    var = pca_for_targets.explained_variance_ratio_.sum()
    logger.info(f"Explained variance of target embeddings: {var:.2f}")

    # Split the data into training and validation sets.
    # We hold out 10% of the data, capped at _MAX_N_VAL_SAMPLES samples.
    # The lower bound of 1 matters: with val_samples == 0 the slices below would
    # become [:-0] (empty) and [-0:] (everything), leaving no training data at all.
    val_samples = max(1, min(_MAX_N_VAL_SAMPLES, n_samples // 10))
    train_txt, train_vec, val_txt, val_vec = (
        train_txt[:-val_samples],
        train_vec[:-val_samples],
        train_txt[-val_samples:],
        train_vec[-val_samples:],
    )

    if vocab_size:
        # Create a vocabulary if a vocab size is specified
        vocab = create_vocab(texts=train_txt, vocab_size=vocab_size)
        logger.info(f"Vocabulary created with {len(vocab)} tokens.")
    else:
        vocab = None

    model = distill(model_name=model_name, quantize_to="float32", vocabulary=vocab, pca_dims=pca_dims)
    train_data = TextDataset(train_txt, torch.from_numpy(train_vec), model.tokenizer)
    val_data = TextDataset(val_txt, torch.from_numpy(val_vec), model.tokenizer)

    # Train the model
    return train_supervised(train_dataset=train_data, validation_dataset=val_data, model=model, device=device)
def save_model(model: StaticModel, save_path: str) -> None:
    """
    Save the model to the specified path.

    :param model: The model to save.
    :param save_path: Path to save the model.
    """
    model.save_pretrained(save_path)
    # Use the module-level logger (not the root logger) so this message is
    # attributed to this module like the other log messages in this file.
    logger.info(f"Model saved to {save_path}")
def main() -> None:
    """
    Main function to train and save a Model2Vec model using tokenlearn.

    Parses the command-line arguments, collects the ``*.json`` files found in
    ``--data-path``, trains a model on them and saves it to ``--save-path``.
    Exits with a usage error if the data path contains no ``*.json`` files.
    """
    parser = argparse.ArgumentParser(description="Train a Model2Vec using tokenlearn.")
    parser.add_argument(
        "--model-name",
        type=str,
        default="baai/bge-base-en-v1.5",
        help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').",
    )
    parser.add_argument(
        "--data-path",
        type=str,
        default="data/fineweb_bgebase",
        help="Path to the directory containing the dataset.",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        required=True,
        help="Path to save the trained model.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run the training on (e.g., 'cpu', 'cuda').",
    )
    parser.add_argument(
        "--vocab-size",
        type=int,
        default=56000,
        help="The vocabulary size to use for training.",
    )
    parser.add_argument(
        "--pca-dims",
        type=int,
        default=256,
        help="Number of dimensions to reduce the target embeddings to using PCA.",
    )
    args = parser.parse_args()

    # Collect paths for training data (sorted for a deterministic file order)
    paths = sorted(Path(args.data_path).glob("*.json"))
    if not paths:
        # A wrong or empty data directory yields no matches; stop here with a
        # clear message instead of passing an empty file list on.
        parser.error(f"No '*.json' files found in data path: {args.data_path}")
    train_txt, train_vec = collect_means_and_texts(paths)

    # Train the model
    model = train_model(
        args.model_name, train_txt, train_vec, device=args.device, vocab_size=args.vocab_size, pca_dims=args.pca_dims
    )
    save_model(model, args.save_path)
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()