# File size: 4,297 Bytes (9bbba62)
from __future__ import annotations

import os
from collections.abc import Generator

import pytest
from datasets import Dataset, load_dataset

from sentence_transformers import SparseEncoder, SparseEncoderTrainer, SparseEncoderTrainingArguments
from sentence_transformers.sparse_encoder import losses
from sentence_transformers.sparse_encoder.evaluation import SparseEmbeddingSimilarityEvaluator
from sentence_transformers.util import is_training_available

# Skip every test in this module at collection time when the training
# dependencies are unavailable; `allow_module_level=True` is what permits
# calling `pytest.skip` outside of a test function.
if not is_training_available():
    pytest.skip(
        reason='Sentence Transformers was not installed with the `["train"]` extra.',
        allow_module_level=True,
    )


@pytest.fixture()
def sts_resource() -> Generator[tuple[Dataset, Dataset], None, None]:
    """Yield the ``train`` and ``test`` splits of ``sentence-transformers/stsb``.

    The dataset is loaded via ``load_dataset`` each time the fixture is
    requested; there is no teardown step after the ``yield``.

    Yields:
        A ``(train, test)`` tuple of dataset splits.
    """
    splits = load_dataset("sentence-transformers/stsb")
    train_split, test_split = splits["train"], splits["test"]
    yield train_split, test_split


@pytest.fixture()
def dummy_sparse_encoder_model() -> SparseEncoder:
    """Return a ``SparseEncoder`` built from the small testing checkpoint.

    The fixture uses pytest's default scope, so each requesting test receives
    its own model instance.
    """
    checkpoint = "sparse-encoder-testing/splade-bert-tiny-nq"
    return SparseEncoder(checkpoint)


def evaluate_stsb_test(
    model: SparseEncoder, expected_score: float, test_dataset: Dataset, num_test_samples: int = -1
) -> None:
    """Assert that ``model`` reaches ``expected_score`` on an STS test split.

    The evaluator's primary metric is scaled by 100 and compared against
    ``expected_score``; the check passes when the score is higher than the
    expectation or lies within 0.5 points of it.

    Args:
        model: The sparse encoder to evaluate.
        expected_score: Target score on a 0-100 scale.
        test_dataset: Dataset providing ``sentence1``, ``sentence2`` and
            ``score`` columns.
        num_test_samples: When positive, only the first ``num_test_samples``
            rows are evaluated; otherwise the whole dataset is used.

    Raises:
        AssertionError: If the evaluator reports no primary metric or the
            score falls short of the expectation by 0.5 or more.
    """
    # A non-positive sample count means "evaluate on everything".
    subset = test_dataset.select(range(num_test_samples)) if num_test_samples > 0 else test_dataset

    evaluator = SparseEmbeddingSimilarityEvaluator(
        sentences1=subset["sentence1"],
        sentences2=subset["sentence2"],
        scores=subset["score"],
        max_active_dims=64,
    )
    metrics = evaluator(model)

    primary = evaluator.primary_metric
    assert primary, "Could not find spearman cosine correlation metric in evaluator output"

    # Metric values are rescaled to the 0-100 range used by `expected_score`.
    score = metrics[primary] * 100
    print(f"STS-Test Performance: {score:.2f} vs. exp: {expected_score:.2f}")

    within_tolerance = abs(score - expected_score) < 0.5
    assert score > expected_score or within_tolerance


@pytest.mark.slow
def test_train_stsb_slow(
    dummy_sparse_encoder_model: SparseEncoder, sts_resource: tuple[Dataset, Dataset], tmp_path
) -> None:
    """Train for one epoch on the whole STSb train split, then check STS quality.

    The model is optimised with a ``SpladeLoss`` wrapping
    ``SparseMultipleNegativesRankingLoss`` and afterwards evaluated on the
    first 50 rows of the test split against an expected score of 50.
    """
    encoder = dummy_sparse_encoder_model
    train_split, test_split = sts_resource

    ranking_loss = losses.SparseMultipleNegativesRankingLoss(model=encoder)
    splade_loss = losses.SpladeLoss(
        model=encoder,
        loss=ranking_loss,
        document_regularizer_weight=3e-5,
        query_regularizer_weight=5e-5,
    )

    args = SparseEncoderTrainingArguments(
        output_dir=tmp_path,
        learning_rate=2e-5,
        num_train_epochs=1,
        per_device_train_batch_size=16,  # Small batches keep the test fast
        warmup_steps=10,
        logging_steps=10,
        # Neither evaluation nor checkpointing happens during training.
        eval_strategy="no",
        save_strategy="no",
        remove_unused_columns=False,  # Important when using custom datasets
    )

    SparseEncoderTrainer(
        model=encoder,
        args=args,
        train_dataset=train_split,
        loss=splade_loss,
    ).train()

    # The expectation is kept low because training is short.
    evaluate_stsb_test(encoder, 50, test_split, num_test_samples=50)


@pytest.mark.skipif("CI" in os.environ, reason="This test triggers rate limits too often in the CI")
def test_train_stsb(
    dummy_sparse_encoder_model: SparseEncoder, sts_resource: tuple[Dataset, Dataset], tmp_path
) -> None:
    """Quick training smoke test on the first 100 STSb training rows.

    Runs a single epoch with a ``SpladeLoss`` wrapping
    ``SparseMultipleNegativesRankingLoss``, then evaluates on the first 50
    rows of the test split against an expected score of 50. The test is
    skipped whenever the ``CI`` environment variable is set.
    """
    encoder = dummy_sparse_encoder_model
    full_train_split, test_split = sts_resource
    # Only a small slice of the training data is used to keep the run short.
    train_subset = full_train_split.select(range(100))

    ranking_loss = losses.SparseMultipleNegativesRankingLoss(model=encoder)
    splade_loss = losses.SpladeLoss(
        model=encoder,
        loss=ranking_loss,
        document_regularizer_weight=3e-5,
        query_regularizer_weight=5e-5,
    )

    # `eval_strategy`/`eval_steps` and `save_steps` are intentionally left
    # unset: no evaluation is configured for this very short training run.
    args = SparseEncoderTrainingArguments(
        output_dir=tmp_path,
        learning_rate=2e-5,
        num_train_epochs=1,
        per_device_train_batch_size=8,  # Even smaller batch than the slow test
        warmup_steps=10,
        logging_steps=5,
        save_strategy="no",  # No checkpoints for this quick test
        remove_unused_columns=False,
    )

    SparseEncoderTrainer(
        model=encoder,
        args=args,
        train_dataset=train_subset,
        loss=splade_loss,
    ).train()

    # Same threshold as the slow variant, checked on 50 test rows.
    evaluate_stsb_test(encoder, 50, test_split, num_test_samples=50)