lsmpp's picture
Add files using upload-large-folder tool
bd33eac verified
"""
This example trains a SparseEncoder for the SNLI + MultiNLI (AllNLI) dataset.
The training script fine-tunes a SparseEncoder using the Splade loss function for retrieval.
It loads a subset of the AllNLI dataset, splits it into training and evaluation subsets,
and he model is evaluated on the STS benchmark dataset. After training, the model is evaluated and
saved locally, with an optional step to push the trained model to the Hugging Face Hub.
Usage:
python train_splade_nli.py
"""
import logging
import traceback
from datasets import load_dataset
from sentence_transformers import (
SparseEncoder,
SparseEncoderModelCardData,
SparseEncoderTrainer,
SparseEncoderTrainingArguments,
)
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.sparse_encoder import evaluation, losses
from sentence_transformers.training_args import BatchSamplers
# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
def main():
model_name = "naver/splade-cocondenser-ensembledistil"
short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
train_batch_size = 16
num_epochs = 1
# 1a. Load a model to finetune with 1b. (Optional) model card data
model = SparseEncoder(
model_name,
model_card_data=SparseEncoderModelCardData(
language="en",
license="apache-2.0",
model_name=f"{short_model_name} trained on Natural Language Inference (NLI)",
),
similarity_fn_name="dot", # or cosine but dot tends to be stronger
)
model.max_seq_length = 256 # Set the max sequence length to 256 for the training
logging.info("Model max length: %s", model.max_seq_length)
# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli
# We'll start with 10k training samples, but you can increase this to get a stronger model
logging.info("Read AllNLI train dataset")
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train").select(range(10000))
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="dev").select(range(1000))
logging.info(train_dataset)
logging.info(eval_dataset)
# 3. Define our training loss.
document_regularizer_weight = 3e-3
loss = losses.SpladeLoss(
model=model,
loss=losses.SparseMultipleNegativesRankingLoss(
model=model,
scale=1, # need to be adapt if used cosine similarity
similarity_fct=model.similarity, # Use the same similarity function as the model
),
document_regularizer_weight=document_regularizer_weight, # Weight for document loss
use_document_regularizer_only=True,
)
# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
dev_evaluator = evaluation.SparseEmbeddingSimilarityEvaluator(
sentences1=stsb_eval_dataset["sentence1"],
sentences2=stsb_eval_dataset["sentence2"],
scores=stsb_eval_dataset["score"],
main_similarity=SimilarityFunction.COSINE,
name="sts-dev",
)
dev_evaluator(model)
# 5. Define the training arguments
run_name = f"{short_model_name}-nli"
training_args = SparseEncoderTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=num_epochs,
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=train_batch_size,
learning_rate=4e-6,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
load_best_model_at_end=True,
metric_for_best_model="eval_sts-dev_spearman_cosine",
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=120,
save_strategy="steps",
save_steps=120,
save_total_limit=2,
logging_steps=20,
run_name=run_name, # Will be used in W&B if `wandb` is installed
seed=42,
)
# 6. Create the trainer & start training
trainer = SparseEncoderTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# 7. Evaluate the model performance on the STS Benchmark test dataset
test_dataset = load_dataset("sentence-transformers/stsb", split="test")
test_evaluator = evaluation.SparseEmbeddingSimilarityEvaluator(
sentences1=test_dataset["sentence1"],
sentences2=test_dataset["sentence2"],
scores=test_dataset["score"],
main_similarity=SimilarityFunction.COSINE,
name="sts-test",
)
test_evaluator(model)
# 8. Save the final model
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)
# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
try:
model.push_to_hub(run_name)
except Exception:
logging.error(
f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
f"and saving it using `model.push_to_hub('{run_name}')`."
)
if __name__ == "__main__":
main()