File size: 4,297 Bytes
9bbba62 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 | from __future__ import annotations
import os
from collections.abc import Generator
import pytest
from datasets import Dataset, load_dataset
from sentence_transformers import SparseEncoder, SparseEncoderTrainer, SparseEncoderTrainingArguments
from sentence_transformers.sparse_encoder import losses
from sentence_transformers.sparse_encoder.evaluation import SparseEmbeddingSimilarityEvaluator
from sentence_transformers.util import is_training_available
# Skip every test in this module when `is_training_available()` reports that
# the training dependencies are missing.
if not is_training_available():
    pytest.skip(
        allow_module_level=True,
        reason='Sentence Transformers was not installed with the `["train"]` extra.',
    )
@pytest.fixture()
def sts_resource() -> Generator[tuple[Dataset, Dataset], None, None]:
    """Yield the ``train`` and ``test`` splits of ``sentence-transformers/stsb``.

    The dataset is loaded through ``load_dataset`` each time the fixture is
    requested; no teardown is performed after the ``yield``.
    """
    splits = load_dataset("sentence-transformers/stsb")
    train_split = splits["train"]
    test_split = splits["test"]
    yield train_split, test_split
@pytest.fixture()
def dummy_sparse_encoder_model() -> SparseEncoder:
    """Return a ``SparseEncoder`` loaded from ``sparse-encoder-testing/splade-bert-tiny-nq``."""
    checkpoint = "sparse-encoder-testing/splade-bert-tiny-nq"
    return SparseEncoder(checkpoint)
def evaluate_stsb_test(
    model: SparseEncoder, expected_score: float, test_dataset: Dataset, num_test_samples: int = -1
) -> None:
    """Evaluate ``model`` on ``test_dataset`` and assert it reaches ``expected_score``.

    Args:
        model: The sparse encoder handed to the evaluator.
        expected_score: Target for the evaluator's primary metric on a 0-100
            scale (the metric is multiplied by 100 before the comparison).
        test_dataset: Dataset providing the ``sentence1``, ``sentence2`` and
            ``score`` columns.
        num_test_samples: When positive, only the first ``num_test_samples``
            rows are evaluated; otherwise the whole dataset is used.

    Raises:
        AssertionError: If the evaluator reports no primary metric, or if the
            score is neither above ``expected_score`` nor within 0.5 of it.
    """
    subset = test_dataset.select(range(num_test_samples)) if num_test_samples > 0 else test_dataset

    evaluator = SparseEmbeddingSimilarityEvaluator(
        sentences1=subset["sentence1"],
        sentences2=subset["sentence2"],
        scores=subset["score"],
        max_active_dims=64,
    )
    metrics = evaluator(model)

    metric_name = evaluator.primary_metric
    assert metric_name, "Could not find spearman cosine correlation metric in evaluator output"

    # Compare on a percentage scale, allowing the score to fall short by less than 0.5.
    score = 100 * metrics[metric_name]
    print(f"STS-Test Performance: {score:.2f} vs. exp: {expected_score:.2f}")
    assert score > expected_score or abs(score - expected_score) < 0.5
@pytest.mark.slow
def test_train_stsb_slow(
    dummy_sparse_encoder_model: SparseEncoder, sts_resource: tuple[Dataset, Dataset], tmp_path
) -> None:
    """Train for one epoch on the full STSb train split, then check the test score.

    The model is trained with ``SpladeLoss`` wrapping
    ``SparseMultipleNegativesRankingLoss`` and afterwards evaluated on the
    first 50 rows of the test split with an expected score of 50.

    Args:
        dummy_sparse_encoder_model: Model fixture that gets trained.
        sts_resource: ``(train_dataset, test_dataset)`` fixture.
        tmp_path: pytest-provided temporary directory used as ``output_dir``.
    """
    model = dummy_sparse_encoder_model
    train_dataset, test_dataset = sts_resource

    loss = losses.SpladeLoss(
        model=model,
        loss=losses.SparseMultipleNegativesRankingLoss(model=model),
        document_regularizer_weight=3e-5,
        query_regularizer_weight=5e-5,
    )

    training_args = SparseEncoderTrainingArguments(
        # `tmp_path` is a `pathlib.Path`; pass a plain string for `output_dir`.
        output_dir=str(tmp_path),
        num_train_epochs=1,
        per_device_train_batch_size=16,  # Smaller batch for faster test
        warmup_steps=10,
        logging_steps=10,
        eval_strategy="no",
        save_strategy="no",
        learning_rate=2e-5,
        remove_unused_columns=False,  # Important when using custom datasets
    )

    trainer = SparseEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()

    # Lower expected score for a short training; scored on the first 50 test rows.
    evaluate_stsb_test(model, 50, test_dataset, num_test_samples=50)
@pytest.mark.skipif("CI" in os.environ, reason="This test triggers rate limits too often in the CI")
def test_train_stsb(
    dummy_sparse_encoder_model: SparseEncoder, sts_resource: tuple[Dataset, Dataset], tmp_path
) -> None:
    """Train for one epoch on the first 100 STSb train rows, then check the test score.

    Skipped whenever the ``CI`` environment variable is set. The model is
    trained with ``SpladeLoss`` wrapping ``SparseMultipleNegativesRankingLoss``
    and afterwards evaluated on the first 50 rows of the test split with an
    expected score of 50.

    Args:
        dummy_sparse_encoder_model: Model fixture that gets trained.
        sts_resource: ``(train_dataset, test_dataset)`` fixture.
        tmp_path: pytest-provided temporary directory used as ``output_dir``.
    """
    model = dummy_sparse_encoder_model
    train_dataset, test_dataset = sts_resource
    # Keep the run short: train on the first 100 rows only.
    train_dataset = train_dataset.select(range(100))

    loss = losses.SpladeLoss(
        model=model,
        loss=losses.SparseMultipleNegativesRankingLoss(model=model),
        document_regularizer_weight=3e-5,
        query_regularizer_weight=5e-5,
    )

    training_args = SparseEncoderTrainingArguments(
        # `tmp_path` is a `pathlib.Path`; pass a plain string for `output_dir`.
        output_dir=str(tmp_path),
        num_train_epochs=1,
        per_device_train_batch_size=8,  # Even smaller batch
        warmup_steps=10,
        logging_steps=5,
        eval_strategy="no",  # No eval during this very short training
        save_strategy="no",  # No saving for this quick test
        learning_rate=2e-5,
        remove_unused_columns=False,
    )

    trainer = SparseEncoderTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()

    # Same threshold as the slow test, scored on the first 50 test rows.
    evaluate_stsb_test(model, 50, test_dataset, num_test_samples=50)
|