qwenillustrious / sentence-transformers /examples /sparse_encoder /training /nli /train_splade_nli.py
| """ | |
| This example trains a SparseEncoder for the SNLI + MultiNLI (AllNLI) dataset. | |
| The training script fine-tunes a SparseEncoder using the SPLADE loss function for retrieval. | |
| It loads subsets of the AllNLI train and dev splits for training and evaluation, | |
| and the model is evaluated on the STS benchmark dev set during training. After training, the model is evaluated | |
| on the STS benchmark test set and saved locally, with an optional step to push the trained model to the Hugging Face Hub. | |
| Usage: | |
| python train_splade_nli.py | |
| """ | |
| import logging | |
| import traceback | |
| from datasets import load_dataset | |
| from sentence_transformers import ( | |
| SparseEncoder, | |
| SparseEncoderModelCardData, | |
| SparseEncoderTrainer, | |
| SparseEncoderTrainingArguments, | |
| ) | |
| from sentence_transformers.similarity_functions import SimilarityFunction | |
| from sentence_transformers.sparse_encoder import evaluation, losses | |
| from sentence_transformers.training_args import BatchSamplers | |
# Configure root logging at INFO so dataset summaries and training progress are printed,
# each line prefixed with a timestamp.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
def _build_sts_evaluator(split, name):
    """Build a sparse similarity evaluator over one split of the STS Benchmark.

    Args:
        split: Split of ``sentence-transformers/stsb`` to load (``"validation"`` or ``"test"``).
        name: Evaluator name; it prefixes the reported metric keys (e.g. ``sts-dev``).

    Returns:
        A ``SparseEmbeddingSimilarityEvaluator`` whose main metric uses cosine similarity.
    """
    stsb = load_dataset("sentence-transformers/stsb", split=split)
    return evaluation.SparseEmbeddingSimilarityEvaluator(
        sentences1=stsb["sentence1"],
        sentences2=stsb["sentence2"],
        scores=stsb["score"],
        main_similarity=SimilarityFunction.COSINE,
        name=name,
    )


def main():
    """Fine-tune a SPLADE SparseEncoder on AllNLI, evaluate it on STSb, then save (and optionally push) it."""
    model_name = "naver/splade-cocondenser-ensembledistil"
    # Keep only the repository part of a Hub id ("org/model" -> "model"); plain names pass through unchanged.
    short_model_name = model_name.split("/")[-1]
    train_batch_size = 16
    num_epochs = 1

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = SparseEncoder(
        model_name,
        model_card_data=SparseEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"{short_model_name} trained on Natural Language Inference (NLI)",
        ),
        similarity_fn_name="dot",  # or cosine, but dot tends to be stronger
    )
    model.max_seq_length = 256  # Set the max sequence length to 256 for the training
    logging.info("Model max length: %s", model.max_seq_length)

    # 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli
    # Use the "triplet" subset (anchor, positive, negative): (Sparse)MultipleNegativesRankingLoss treats
    # every (anchor, positive) row as a positive pair and does not use a score column, so the "pair-score"
    # subset would wrongly feed neutral and contradiction pairs to it as positives.
    # We'll start with 10k training samples, but you can increase this to get a stronger model
    logging.info("Read AllNLI train dataset")
    train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train").select(range(10000))
    eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev").select(range(1000))
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 3. Define our training loss.
    document_regularizer_weight = 3e-3
    loss = losses.SpladeLoss(
        model=model,
        loss=losses.SparseMultipleNegativesRankingLoss(
            model=model,
            scale=1,  # Suited to dot-product scores; needs adapting if cosine similarity is used
            similarity_fct=model.similarity,  # Use the same similarity function as the model
        ),
        document_regularizer_weight=document_regularizer_weight,  # Weight for document loss
        use_document_regularizer_only=True,
    )

    # 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
    dev_evaluator = _build_sts_evaluator("validation", "sts-dev")
    dev_evaluator(model)  # Evaluate the base model once before any training

    # 5. Define the training arguments
    run_name = f"{short_model_name}-nli"
    training_args = SparseEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=4e-6,
        fp16=False,  # Set to True if your GPU supports FP16 but not BF16
        bf16=True,  # Set to False if your GPU doesn't support BF16
        # (Sparse)MultipleNegativesRankingLoss uses in-batch negatives, so it benefits from no duplicates in a batch
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        load_best_model_at_end=True,
        metric_for_best_model="eval_sts-dev_spearman_cosine",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=120,
        save_strategy="steps",
        save_steps=120,
        save_total_limit=2,
        logging_steps=20,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=42,
    )

    # 6. Create the trainer & start training
    trainer = SparseEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()

    # 7. Evaluate the model performance on the STS Benchmark test dataset
    test_evaluator = _build_sts_evaluator("test", "sts-test")
    test_evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        # Best-effort upload: the model is already saved locally, so log how to retry instead of crashing.
        logging.error(
            "Error uploading model to the Hugging Face Hub:\n%sTo upload it manually, you can run "
            "`huggingface-cli login`, followed by loading the model using `model = SparseEncoder(%r)` "
            "and saving it using `model.push_to_hub('%s')`.",
            traceback.format_exc(),
            final_output_dir,
            run_name,
        )
if __name__ == "__main__":
    # Run the train / evaluate / save pipeline only when executed as a script, not on import.
    main()