|
|
import argparse |
|
|
import logging |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from sklearn.decomposition import PCA |
|
|
|
|
|
from distiller.model2vec.distill import distill |
|
|
from distiller.model2vec.model import StaticModel |
|
|
from distiller.tokenlearn.pretrain import TextDataset, train_supervised |
|
|
from distiller.tokenlearn.utils import collect_means_and_texts, create_vocab |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
_MAX_N_VAL_SAMPLES = 10_000 |
|
|
|
|
|
|
|
|
def train_model(
    model_name: str,
    train_txt: list[str],
    train_vec: np.ndarray,
    device: str = "cpu",
    vocab_size: int | None = None,
    pca_dims: int = 256,
) -> StaticModel:
    """
    Train a tokenlearn model.

    :param model_name: The sentence transformer model name for distillation.
    :param train_txt: List of texts to train on.
    :param train_vec: List of vectors to train on.
    :param device: Device to run the training on.
    :param vocab_size: The vocabulary size to use (optional).
    :param pca_dims: Number of dimensions to reduce the target embeddings to using PCA.
        The model will use the same number of dimensions for the embeddings.
    :return: The trained model.
    :raises ValueError: If there are too few samples to carve out a validation set.
    """
    # Reduce the target embeddings to `pca_dims`; the distilled model below is
    # built with the same dimensionality so targets and outputs line up.
    pca_for_targets = PCA(n_components=pca_dims)
    train_vec = pca_for_targets.fit_transform(train_vec)
    # Total variance retained by the kept components (sum of per-component ratios).
    var = float(np.sum(pca_for_targets.explained_variance_ratio_))
    logger.info(f"Explained variance of target embeddings: {var:.2f}")

    # Hold out the last 10% (capped) as a validation split.
    val_samples = min(_MAX_N_VAL_SAMPLES, len(train_txt) // 10)
    if val_samples == 0:
        # Guard: with fewer than 10 samples the slice `train_txt[:-0]` would be
        # the EMPTY list, silently dropping all training data.
        raise ValueError("Not enough samples to create a validation split (need at least 10).")
    train_txt, train_vec, val_txt, val_vec = (
        train_txt[:-val_samples],
        train_vec[:-val_samples],
        train_txt[-val_samples:],
        train_vec[-val_samples:],
    )

    if vocab_size:
        # Build a custom vocabulary from the training texts only.
        vocab = create_vocab(texts=train_txt, vocab_size=vocab_size)
        logger.info(f"Vocabulary created with {len(vocab)} tokens.")
    else:
        # Fall back to the base model's own tokenizer vocabulary.
        vocab = None
    model = distill(model_name=model_name, quantize_to="float32", vocabulary=vocab, pca_dims=pca_dims)
    train_data = TextDataset(train_txt, torch.from_numpy(train_vec), model.tokenizer)
    val_data = TextDataset(val_txt, torch.from_numpy(val_vec), model.tokenizer)

    return train_supervised(train_dataset=train_data, validation_dataset=val_data, model=model, device=device)
|
|
|
|
|
|
|
|
def save_model(model: StaticModel, save_path: str) -> None:
    """
    Save the model to the specified path.

    :param model: The model to save.
    :param save_path: Path to save the model.
    """
    model.save_pretrained(save_path)
    # Use the module-level logger; the original called `logging.info`, which
    # routes through the root logger and bypasses this module's named logger.
    logger.info(f"Model saved to {save_path}")
|
|
|
|
|
|
|
|
def main() -> None:
    """Main function to train and save a Model2Vec model using tokenlearn."""
    cli = argparse.ArgumentParser(description="Train a Model2Vec using tokenlearn.")
    cli.add_argument(
        "--model-name",
        type=str,
        default="baai/bge-base-en-v1.5",
        help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').",
    )
    cli.add_argument(
        "--data-path",
        type=str,
        default="data/fineweb_bgebase",
        help="Path to the directory containing the dataset.",
    )
    cli.add_argument(
        "--save-path",
        type=str,
        required=True,
        help="Path to save the trained model.",
    )
    cli.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run the training on (e.g., 'cpu', 'cuda').",
    )
    cli.add_argument(
        "--vocab-size",
        type=int,
        default=56000,
        help="The vocabulary size to use for training.",
    )
    cli.add_argument(
        "--pca-dims",
        type=int,
        default=256,
        help="Number of dimensions to reduce the target embeddings to using PCA.",
    )
    opts = cli.parse_args()

    # Collect the pre-computed mean embeddings and their source texts from disk.
    json_files = sorted(Path(opts.data_path).glob("*.json"))
    texts, vectors = collect_means_and_texts(json_files)

    trained = train_model(
        opts.model_name,
        texts,
        vectors,
        device=opts.device,
        vocab_size=opts.vocab_size,
        pca_dims=opts.pca_dims,
    )
    save_model(trained, opts.save_path)
|
|
|
|
|
|
|
|
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|
|