""" This example trains a SparseEncoder for the Semantic Textual Similarity Benchmark dataset. The training script fine-tunes a SparseEncoder using the Splade loss function for retrieval. It loads a Semantic Textual Similarity Benchmark dataset, splits it into training and evaluation subsets, and he model is evaluated on the test/eval dataset. After training, the model is evaluated and saved locally, with an optional step to push the trained model to the Hugging Face Hub. Usage: python train_splade_stsbenchmark.py """ import logging import traceback from datasets import load_dataset from sentence_transformers import ( SparseEncoder, SparseEncoderModelCardData, SparseEncoderTrainer, SparseEncoderTrainingArguments, ) from sentence_transformers.similarity_functions import SimilarityFunction from sentence_transformers.sparse_encoder import evaluation, losses # Set the log level to INFO to get more information logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) def main(): model_name = "naver/splade-cocondenser-ensembledistil" short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1] train_batch_size = 16 num_epochs = 1 # 1a. Load a model to finetune with 1b. (Optional) model card data model = SparseEncoder( model_name, model_card_data=SparseEncoderModelCardData( language="en", license="apache-2.0", model_name=f"{short_model_name} trained on STS", ), similarity_fn_name="dot", # or cosine but anyway in the evaluator cosine is used ) model.max_seq_length = 256 # Set the max sequence length to 256 for the training logging.info("Model max length: %s", model.max_seq_length) # 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb train_dataset = load_dataset("sentence-transformers/stsb", split="train") eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") test_dataset = load_dataset("sentence-transformers/stsb", split="test") logging.info(train_dataset) # 3. Define our training loss. document_regularizer_weight = 3e-3 loss = losses.SpladeLoss( model=model, loss=losses.SparseCosineSimilarityLoss(model=model), document_regularizer_weight=document_regularizer_weight, # Weight for FLOPS (sparsity) loss, higher is sparser use_document_regularizer_only=True, ) # 4. Before and during training, we use SparseEmbeddingSimilarityEvaluator to measure the performance on the dev set dev_evaluator = evaluation.SparseEmbeddingSimilarityEvaluator( sentences1=eval_dataset["sentence1"], sentences2=eval_dataset["sentence2"], scores=eval_dataset["score"], main_similarity=SimilarityFunction.COSINE, name="sts-dev", ) dev_evaluator(model) # 5. Define the training arguments run_name = f"{short_model_name}-sts" training_args = SparseEncoderTrainingArguments( # Required parameter: output_dir=f"models/{run_name}", # Optional training parameters: num_train_epochs=num_epochs, per_device_train_batch_size=train_batch_size, per_device_eval_batch_size=train_batch_size, learning_rate=5e-5, fp16=False, # Set to False if you get an error that your GPU can't run on FP16 bf16=True, # Set to True if you have a GPU that supports BF16 # Optional tracking/debugging parameters: eval_strategy="steps", eval_steps=100, save_strategy="steps", save_steps=100, save_total_limit=2, logging_steps=100, run_name=run_name, # Will be used in W&B if `wandb` is installed seed=42, ) # 6. Create the trainer & start training trainer = SparseEncoderTrainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, loss=loss, evaluator=dev_evaluator, ) trainer.train() # 7. Evaluate the model performance on the STS Benchmark test dataset test_evaluator = evaluation.SparseEmbeddingSimilarityEvaluator( sentences1=test_dataset["sentence1"], sentences2=test_dataset["sentence2"], scores=test_dataset["score"], main_similarity=SimilarityFunction.COSINE, name="sts-test", ) test_evaluator(model) # 8. Save the final model final_output_dir = f"models/{run_name}/final" model.save_pretrained(final_output_dir) # 9. (Optional) save the model to the Hugging Face Hub! # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first try: model.push_to_hub(run_name) except Exception: logging.error( f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` " f"and saving it using `model.push_to_hub('{run_name}')`." ) if __name__ == "__main__": main()