Text Ranking
sentence-transformers
Safetensors
English
bert
cross-encoder
text-classification
Generated from Trainer
dataset_size:578402
loss:BinaryCrossEntropyLoss
text-embeddings-inference
Instructions to use cross-encoder-testing/reranker-bert-tiny-gooaq-bce with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use cross-encoder-testing/reranker-bert-tiny-gooaq-bce with sentence-transformers:
from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder-testing/reranker-bert-tiny-gooaq-bce")

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)
- Notebooks
- Google Colab
- Kaggle
Create train_script.py
Browse files- train_script.py +171 -0
train_script.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import traceback
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderModelCardData
|
| 9 |
+
from sentence_transformers.cross_encoder.evaluation import (
|
| 10 |
+
CrossEncoderNanoBEIREvaluator,
|
| 11 |
+
CrossEncoderRerankingEvaluator,
|
| 12 |
+
)
|
| 13 |
+
from sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss import BinaryCrossEntropyLoss
|
| 14 |
+
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
|
| 15 |
+
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments
|
| 16 |
+
from sentence_transformers.evaluation.SequentialEvaluator import SequentialEvaluator
|
| 17 |
+
from sentence_transformers.util import mine_hard_negatives
|
| 18 |
+
|
| 19 |
+
# Set the log level to INFO to get more information.
# Every record is prefixed with a "YYYY-MM-DD HH:MM:SS" timestamp.
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def main():
    """Fine-tune a BERT-tiny cross-encoder reranker on GooAQ with a binary cross-entropy loss.

    The pipeline, in order:

    1. Load ``prajjwal1/bert-tiny`` as a ``CrossEncoder`` with model card metadata.
    2. Load the first 100,000 rows of the ``sentence-transformers/gooaq`` train split,
       hold out 1,000 of them for evaluation, and mine hard negatives for the
       training rows with a static embedding model running on CPU.
    3. Build the ``BinaryCrossEntropyLoss``.
    4. Build a reranking evaluator on the held-out rows plus a NanoBEIR evaluator,
       and run both on the untrained model.
    5. Train for one epoch, evaluating and checkpointing every 20 steps.
    6. Re-run the evaluators, save the model under ``models/<run_name>/final`` and
       try to push it to the Hugging Face Hub.

    Takes no arguments and returns ``None``. A failure while pushing to the Hub is
    logged (with manual upload instructions) instead of being raised; any other
    exception propagates to the caller.
    """
    model_name = "prajjwal1/bert-tiny"

    train_batch_size = 2048
    num_epochs = 1
    num_hard_negatives = 5  # How many hard negatives should be mined for each question-answer pair

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name="BERT-tiny trained on GooAQ",
        ),
    )
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2a. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
    logging.info("Read the gooaq training dataset")
    full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 2b. Modify our training dataset to include hard negatives using a very efficient embedding model
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives,  # How many negatives per question-answer pair
        margin=0,  # Similarity between query and negative samples should be x lower than query-positive similarity
        range_min=0,  # Skip the x most similar samples
        range_max=100,  # Consider only the x most similar samples
        sampling_strategy="top",  # Take the top-ranked candidates from the range rather than sampling randomly
        batch_size=4096,  # Use a batch size of 4096 for the embedding model
        output_format="labeled-pair",  # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
        use_faiss=True,
    )
    logging.info(hard_train_dataset)

    # 2c. (Optionally) Save the hard training dataset to disk
    # hard_train_dataset.save_to_disk("gooaq-hard-train")
    # Load again with:
    # hard_train_dataset = load_from_disk("gooaq-hard-train")

    # 3. Define our training loss.
    # Each positive pair is accompanied by up to `num_hard_negatives` mined negatives, so pos_weight
    # is set to that negatives-per-positive ratio, a.k.a. `num_hard_negatives`.
    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))

    # 4a. Define evaluators. We use the CrossEncoderNanoBEIREvaluator, which is a light-weight evaluator for English reranking
    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )

    # 4b. Define a reranking evaluator by mining hard negatives given query-answer pairs
    # We include the positive answer in the list of negatives, so the evaluator can use the performance of the
    # embedding model as a baseline.
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],  # Use the full dataset as the corpus
        num_negatives=30,  # How many documents to rerank
        batch_size=4096,
        disqualify_positives=False,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)

    # Every column after the first two is treated as a candidate document to rerank.
    # Presumably the first two columns are "question" and "answer" - verify against the mined dataset.
    # Computed once here instead of once per sample.
    document_columns = hard_eval_dataset.column_names[2:]
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=[
            {
                "query": sample["question"],
                "positive": [sample["answer"]],
                "documents": [sample[column_name] for column_name in document_columns],
            }
            for sample in hard_eval_dataset
        ],
        batch_size=train_batch_size,
        name="gooaq-dev",
    )

    # 4c. Combine the evaluators & run the base model on them
    evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
    evaluator(model)

    # 5. Define the training arguments
    # Keep only the part after the last "/" (the whole name when there is no "/")
    short_model_name = model_name.split("/")[-1]
    run_name = f"reranker-{short_model_name}-gooaq-bce"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=5e-4,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        load_best_model_at_end=True,
        metric_for_best_model="eval_NanoBEIR_R100_mean_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=20,
        save_strategy="steps",
        save_steps=20,
        save_total_limit=2,
        logging_steps=20,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=12,
    )

    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model, useful to include these in the model card
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(f"cross-encoder-testing/{run_name}")
    except Exception:
        # Best-effort upload: the model is already saved locally, so log the failure
        # together with manual upload instructions instead of crashing the run.
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# Only call main() when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|