qwenillustrious / sentence-transformers /examples /sparse_encoder /training /sts /train_splade_stsbenchmark.py
| """ | |
| This example trains a SparseEncoder for the Semantic Textual Similarity Benchmark dataset. | |
| The training script fine-tunes a SparseEncoder using the Splade loss function for retrieval. | |
| It loads a Semantic Textual Similarity Benchmark dataset, splits it into training and evaluation subsets, | |
| and the model is evaluated on the test/eval dataset. After training, the model is evaluated and | |
| saved locally, with an optional step to push the trained model to the Hugging Face Hub. | |
| Usage: | |
| python train_splade_stsbenchmark.py | |
| """ | |
| import logging | |
| import traceback | |
| from datasets import load_dataset | |
| from sentence_transformers import ( | |
| SparseEncoder, | |
| SparseEncoderModelCardData, | |
| SparseEncoderTrainer, | |
| SparseEncoderTrainingArguments, | |
| ) | |
| from sentence_transformers.similarity_functions import SimilarityFunction | |
| from sentence_transformers.sparse_encoder import evaluation, losses | |
# Configure root logging at INFO level with timestamped messages so that
# progress reported by this script is visible on the console.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
def main() -> None:
    """Fine-tune a SPLADE SparseEncoder on the STS Benchmark and evaluate it.

    The pipeline:
      1. Load the pretrained sparse encoder together with optional model card data.
      2. Load the train / validation / test splits of ``sentence-transformers/stsb``.
      3. Build a ``SpladeLoss`` wrapping ``SparseCosineSimilarityLoss``.
      4. Evaluate on the validation split before training (baseline).
      5. Train with ``SparseEncoderTrainer``, evaluating every 100 steps.
      6. Evaluate on the test split.
      7. Save the final model under ``models/<run_name>/final``.
      8. Try to push the model to the Hugging Face Hub; on failure, log the
         traceback and manual upload instructions instead of raising.
    """
    model_name = "naver/splade-cocondenser-ensembledistil"
    # Drop the organisation prefix ("naver/") so the name can be used in paths and run names.
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    train_batch_size = 16
    num_epochs = 1

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = SparseEncoder(
        model_name,
        model_card_data=SparseEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"{short_model_name} trained on STS",
        ),
        # "cosine" would also work; the evaluators below explicitly use cosine regardless.
        similarity_fn_name="dot",
    )
    model.max_seq_length = 256  # Set the max sequence length to 256 for the training
    logging.info("Model max length: %s", model.max_seq_length)

    # 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb
    train_dataset = load_dataset("sentence-transformers/stsb", split="train")
    eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
    test_dataset = load_dataset("sentence-transformers/stsb", split="test")
    logging.info(train_dataset)

    # 3. Define our training loss.
    document_regularizer_weight = 3e-3
    loss = losses.SpladeLoss(
        model=model,
        loss=losses.SparseCosineSimilarityLoss(model=model),
        # Weight for FLOPS (sparsity) loss, higher is sparser
        document_regularizer_weight=document_regularizer_weight,
        use_document_regularizer_only=True,
    )

    # 4. Before and during training, we use SparseEmbeddingSimilarityEvaluator
    #    to measure the performance on the dev set
    dev_evaluator = evaluation.SparseEmbeddingSimilarityEvaluator(
        sentences1=eval_dataset["sentence1"],
        sentences2=eval_dataset["sentence2"],
        scores=eval_dataset["score"],
        main_similarity=SimilarityFunction.COSINE,
        name="sts-dev",
    )
    # Baseline evaluation of the untrained (pretrained-only) model.
    dev_evaluator(model)

    # 5. Define the training arguments
    run_name = f"{short_model_name}-sts"
    training_args = SparseEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=5e-5,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        logging_steps=100,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=42,
    )

    # 6. Create the trainer & start training
    trainer = SparseEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()

    # 7. Evaluate the model performance on the STS Benchmark test dataset
    test_evaluator = evaluation.SparseEmbeddingSimilarityEvaluator(
        sentences1=test_dataset["sentence1"],
        sentences2=test_dataset["sentence2"],
        scores=test_dataset["score"],
        main_similarity=SimilarityFunction.COSINE,
        name="sts-test",
    )
    test_evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        # Uploading is best-effort: the model is already saved locally, so report the
        # failure with instructions for a manual upload rather than aborting the script.
        # The instructions must name SparseEncoder, the class this script actually trains.
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = SparseEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )
# Script entry point: start the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()