File size: 18,648 Bytes
9bbba62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
from __future__ import annotations

import tempfile
from contextlib import nullcontext
from pathlib import Path

import pytest
import torch

from sentence_transformers import SparseEncoder, SparseEncoderTrainer, SparseEncoderTrainingArguments
from sentence_transformers.sparse_encoder import losses
from sentence_transformers.util import is_datasets_available, is_training_available

if is_datasets_available():
    from datasets import Dataset, DatasetDict, IterableDatasetDict

# Skip every test in this module (at collection time) when `is_training_available()` reports
# that the training dependencies are missing; `allow_module_level=True` is what permits calling
# `pytest.skip` outside of a test function.
if not is_training_available():
    pytest.skip(
        reason='Sentence Transformers was not installed with the `["train"]` extra.',
        allow_module_level=True,
    )


@pytest.fixture
def dummy_train_eval_datasets_for_trainer() -> tuple[Dataset, Dataset]:
    """Provide a tiny ``(train, eval)`` pair of sentence-pair datasets for trainer tests.

    Both datasets have the columns ``sentence1``, ``sentence2`` and ``score``. The train split
    has 20 rows and the eval split has 10 rows; ``score`` alternates between 0.0 and 1.0.

    Returns:
        tuple[Dataset, Dataset]: The train dataset followed by the eval dataset.
    """
    train_indices, eval_indices = range(20), range(10)

    train_split = Dataset.from_dict(
        {
            "sentence1": [f"train_s1_{i}" for i in train_indices],
            "sentence2": [f"train_s2_{i}" for i in train_indices],
            "score": [float(i % 2) for i in train_indices],
        }
    )
    eval_split = Dataset.from_dict(
        {
            "sentence1": [f"eval_s1_{i}" for i in eval_indices],
            "sentence2": [f"eval_s2_{i}" for i in eval_indices],
            "score": [float(i % 2) for i in eval_indices],
        }
    )
    return train_split, eval_split


def test_model_card_reuse(splade_bert_tiny_model: SparseEncoder):
    """Check the model card that is written when a model is saved after a trainer was attached.

    The model's card text is captured first, a ``SparseEncoderTrainer`` is constructed for the
    model, and the model is then saved to a temporary folder. The saved ``README.md`` must
    differ from the initial card text if there was one, and must otherwise contain actual text.

    Args:
        splade_bert_tiny_model: The sparse encoder under test (provided by a fixture).
    """
    model = splade_bert_tiny_model

    # Snapshot the card text before the trainer is created so it can be compared afterwards.
    initial_card_text = model._model_card_text

    # The trainer object itself is not used; it is only constructed. Presumably construction
    # registers training details on the model's card data -- confirm against the trainer.
    SparseEncoderTrainer(
        model=model,
        loss=losses.SpladeLoss(
            model=model,
            loss=losses.SparseMultipleNegativesRankingLoss(model=model),
            document_regularizer_weight=3e-5,
            query_regularizer_weight=5e-5,
        ),
    )

    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmp_folder:
        model_path = Path(tmp_folder) / "sparse_model_local"
        model.save_pretrained(str(model_path))
        # Read the card while the temporary folder still exists; the checks can run afterwards.
        trained_model_card_text = (model_path / "README.md").read_text(encoding="utf8")

    if initial_card_text:
        assert trained_model_card_text != initial_card_text
    else:
        # Reading a file never yields None, so an `is not None` check could not fail.
        # Require real content instead to verify that a card was actually generated.
        assert trained_model_card_text.strip(), "Expected a non-empty model card to be written."


@pytest.mark.parametrize("streaming", [False, True])
def test_trainer(
    splade_bert_tiny_model: SparseEncoder,
    dummy_train_eval_datasets_for_trainer: tuple[Dataset, Dataset],
    streaming: bool,
) -> None:
    """Train for two steps and check that the model was updated and can still encode.

    Args:
        splade_bert_tiny_model: The sparse encoder under test (provided by a fixture).
        dummy_train_eval_datasets_for_trainer: ``(train, eval)`` datasets used for the run.
        streaming: If True, both datasets are converted with ``to_iterable_dataset()`` first.
    """
    model = splade_bert_tiny_model
    train_dataset, eval_dataset = dummy_train_eval_datasets_for_trainer

    if streaming:
        train_dataset = train_dataset.to_iterable_dataset()
        eval_dataset = eval_dataset.to_iterable_dataset()

    # Detached copies: the snapshot must not share storage with the live parameters (which the
    # optimizer updates in place) and does not need to be part of the autograd graph.
    original_model_params = [p.detach().clone() for p in model.parameters()]

    loss = losses.SpladeLoss(
        model=model,
        loss=losses.SparseMultipleNegativesRankingLoss(model=model),
        document_regularizer_weight=3e-5,
        query_regularizer_weight=5e-5,
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        args = SparseEncoderTrainingArguments(
            output_dir=str(temp_dir),
            max_steps=2,
            eval_strategy="steps",
            eval_steps=2,
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            logging_steps=1,
            remove_unused_columns=False,  # Important for custom dict datasets
            report_to=["none"],
        )
        trainer = SparseEncoderTrainer(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            loss=loss,
        )
        trainer.train()

    # At least one parameter tensor must differ from its pre-training snapshot.
    model_changed = any(
        not torch.equal(p_orig, p_new) for p_orig, p_new in zip(original_model_params, model.parameters())
    )
    assert model_changed, "Model parameters should have changed after training."

    # Simple check to ensure prediction works after training; any error is reported as a
    # test failure with a descriptive message.
    try:
        model.encode(["Test sentence after training."])
    except Exception as e:
        pytest.fail(f"Encoding failed after training: {e}")


@pytest.mark.slow
@pytest.mark.parametrize("train_dict", [False, True])
@pytest.mark.parametrize("eval_dict", [False, True])
@pytest.mark.parametrize("loss_dict", [False, True])
@pytest.mark.parametrize("add_transform", [False, True])
@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize(
    "prompts",
    [
        None,  # No prompt
        "Prompt: ",  # Single prompt to all columns and all datasets
        {"stsb-1": "Prompt 1: ", "stsb-2": "Prompt 2: "},  # Different prompts for different datasets
        {"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "},  # Different prompts for different columns
        {
            "stsb-1": {"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "},
            "stsb-2": {"sentence1": "Prompt 3: ", "sentence2": "Prompt 4: "},
        },  # Different prompts for different datasets and columns
    ],
)
def test_trainer_prompts(
    splade_bert_tiny_model: SparseEncoder,
    train_dict: bool,
    eval_dict: bool,
    loss_dict: bool,
    add_transform: bool,
    streaming: bool,
    prompts: dict[str, dict[str, str]] | dict[str, str] | str | None,
) -> None:
    """Check which texts (including prompts) reach ``model.preprocess`` during a short training run.

    ``model.preprocess`` is wrapped so that every text passed to it is recorded, with the prompt
    prepended when one is given. After training, the set of recorded texts is compared against
    the set expected for the given combination of parameters.

    Args:
        splade_bert_tiny_model: The sparse encoder under test (provided by a fixture).
        train_dict: If True, train on a two-entry dataset dict ("stsb-1", "stsb-2") instead of
            a single dataset.
        eval_dict: Same as ``train_dict``, but for the evaluation data.
        loss_dict: If True, pass the loss as a dict keyed by the two dataset names.
        add_transform: If True, upper-case all texts via a dataset transform.
        streaming: If True, use iterable datasets / ``IterableDatasetDict``.
        prompts: The ``prompts`` value handed to ``SparseEncoderTrainingArguments``.
    """
    # A loss dict is only exercised together with dict datasets for both train and eval.
    if loss_dict and (not train_dict or not eval_dict):
        pytest.skip(
            "Skipping test because loss_dict=True requires train_dict=True and eval_dict=True; already tested via test_trainer."
        )

    model = splade_bert_tiny_model

    # Four two-row datasets with distinct, recognizable texts so that every recorded text can be
    # traced back to its split ("train"/"eval"), dataset number and column.
    train_dataset_1 = Dataset.from_dict(
        {
            "sentence1": ["train 1 sentence 1a", "train 1 sentence 1b"],
            "sentence2": ["train 1 sentence 2a", "train 1 sentence 2b"],
        }
    )
    train_dataset_2 = Dataset.from_dict(
        {
            "sentence1": ["train 2 sentence 1a", "train 2 sentence 1b"],
            "sentence2": ["train 2 sentence 2a", "train 2 sentence 2b"],
        }
    )
    eval_dataset_1 = Dataset.from_dict(
        {
            "sentence1": ["eval 1 sentence 1a", "eval 1 sentence 1b"],
            "sentence2": ["eval 1 sentence 2a", "eval 1 sentence 2b"],
        }
    )
    eval_dataset_2 = Dataset.from_dict(
        {
            "sentence1": ["eval 2 sentence 1a", "eval 2 sentence 1b"],
            "sentence2": ["eval 2 sentence 2a", "eval 2 sentence 2b"],
        }
    )

    loss = losses.SpladeLoss(
        model=model,
        loss=losses.SparseMultipleNegativesRankingLoss(model=model),
        document_regularizer_weight=3e-5,
        query_regularizer_weight=5e-5,
    )

    # Wrap `model.preprocess` so that every text it receives is recorded (with the prompt
    # prepended when one is passed) before delegating to the original implementation.
    # NOTE(review): the wrapper is never removed from the model again -- confirm the model
    # fixture is not shared between tests.
    tracked_texts = []
    old_preprocess = model.preprocess

    def preprocess_tracker(texts, prompt=None, **kwargs):
        if prompt:
            tracked_texts.extend([prompt + text for text in texts])
        else:
            tracked_texts.extend(texts)
        return old_preprocess(texts, prompt=prompt, **kwargs)

    model.preprocess = preprocess_tracker

    # Build the training data: a dict of both datasets or only the first dataset, in its
    # streaming or non-streaming form.
    if train_dict:
        if streaming:
            train_dataset = IterableDatasetDict({"stsb-1": train_dataset_1, "stsb-2": train_dataset_2})
        else:
            train_dataset = DatasetDict({"stsb-1": train_dataset_1, "stsb-2": train_dataset_2})
    else:
        if streaming:
            train_dataset = train_dataset_1.to_iterable_dataset()
        else:
            train_dataset = train_dataset_1

    # Build the evaluation data the same way.
    if eval_dict:
        if streaming:
            eval_dataset = IterableDatasetDict({"stsb-1": eval_dataset_1, "stsb-2": eval_dataset_2})
        else:
            eval_dataset = DatasetDict({"stsb-1": eval_dataset_1, "stsb-2": eval_dataset_2})
    else:
        if streaming:
            eval_dataset = eval_dataset_1.to_iterable_dataset()
        else:
            eval_dataset = eval_dataset_1

    # Upper-cases every column of a batch, so that the effect of the transform is visible in
    # the recorded texts.
    def upper_transform(batch):
        for column_name, column in batch.items():
            batch[column_name] = [text.upper() for text in column]
        return batch

    # In the streaming case the transform is applied with `map(...)` (per dataset when a dict is
    # used); otherwise it is registered with `set_transform(...)`.
    if add_transform:
        if streaming:
            if train_dict:
                train_dataset = IterableDatasetDict(
                    {
                        dataset_name: dataset.map(upper_transform, batched=True, features=dataset.features)
                        for dataset_name, dataset in train_dataset.items()
                    }
                )
            else:
                train_dataset = train_dataset.map(upper_transform, batched=True, features=train_dataset.features)
            if eval_dict:
                eval_dataset = IterableDatasetDict(
                    {
                        dataset_name: dataset.map(upper_transform, batched=True, features=dataset.features)
                        for dataset_name, dataset in eval_dataset.items()
                    }
                )
            else:
                eval_dataset = eval_dataset.map(upper_transform, batched=True, features=eval_dataset.features)
        else:
            train_dataset.set_transform(upper_transform)
            eval_dataset.set_transform(upper_transform)

    # The same loss object is used for both dataset names.
    if loss_dict:
        loss = {
            "stsb-1": loss,
            "stsb-2": loss,
        }

    # Variables to more easily track the expected outputs. Uppercased if add_transform is True as we expect
    # the transform to be applied to the data.
    # Naming: all_<split>_<dataset number>_<column number>; shorter names are unions of those.
    all_train_1_1 = {s.upper() if add_transform else s for s in train_dataset_1["sentence1"]}
    all_train_1_2 = {s.upper() if add_transform else s for s in train_dataset_1["sentence2"]}
    all_train_2_1 = {s.upper() if add_transform else s for s in train_dataset_2["sentence1"]}
    all_train_2_2 = {s.upper() if add_transform else s for s in train_dataset_2["sentence2"]}
    all_eval_1_1 = {s.upper() if add_transform else s for s in eval_dataset_1["sentence1"]}
    all_eval_1_2 = {s.upper() if add_transform else s for s in eval_dataset_1["sentence2"]}
    all_eval_2_1 = {s.upper() if add_transform else s for s in eval_dataset_2["sentence1"]}
    all_eval_2_2 = {s.upper() if add_transform else s for s in eval_dataset_2["sentence2"]}
    all_train_1 = all_train_1_1 | all_train_1_2
    all_train_2 = all_train_2_1 | all_train_2_2
    all_eval_1 = all_eval_1_1 | all_eval_1_2
    all_eval_2 = all_eval_2_1 | all_eval_2_2
    all_train = all_train_1 | all_train_2
    all_eval = all_eval_1 | all_eval_2

    # Nested prompts (dataset name -> column name -> prompt) are expected to raise a ValueError
    # during training unless both the train and the eval data are dataset dicts.
    if prompts == {
        "stsb-1": {"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "},
        "stsb-2": {"sentence1": "Prompt 3: ", "sentence2": "Prompt 4: "},
    } and (train_dict, eval_dict) != (True, True):
        context = pytest.raises(
            ValueError,
            match="The prompts provided to the trainer are a nested dictionary. In this setting, the first "
            "level of the dictionary should map to dataset names and the second level to column names. "
            "However, as the provided dataset is a not a DatasetDict, no dataset names can be inferred. "
            "The keys to the provided prompts dictionary are .*",
        )
    else:
        context = nullcontext()

    with tempfile.TemporaryDirectory() as temp_dir:
        args = SparseEncoderTrainingArguments(
            output_dir=str(temp_dir),
            prompts=prompts,
            max_steps=4 if train_dict else 2,
            eval_steps=4 if train_dict else 2,
            eval_strategy="steps",
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            report_to=["none"],
        )
        trainer = SparseEncoderTrainer(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            loss=loss,
        )

        # Drop anything recorded up to this point so that only texts seen during `train()` count.
        tracked_texts.clear()

        # Wrap `trainer.compute_loss` to collect the keys of every batch passed to it.
        datacollator_keys = set()
        old_compute_loss = trainer.compute_loss

        def compute_loss_tracker(model, inputs, **kwargs):
            datacollator_keys.update(set(inputs.keys()))
            return old_compute_loss(model, inputs, **kwargs)

        trainer.compute_loss = compute_loss_tracker
        with context:
            trainer.train()

        # In the expected-ValueError case there is nothing further to verify.
        if not isinstance(context, nullcontext):
            return

    # prompt_length keys may appear in the batch when prompts are provided (Transformer.preprocess always
    # computes them), but SpladePooling simply ignores them. Only Pooling uses them when include_prompt=False.

    # "dataset_name" is expected in the batch only when at least one of train/eval is a dataset dict
    # and either a loss dict is used or the prompts are given as a dictionary
    if (train_dict or eval_dict) and (loss_dict or (isinstance(prompts, dict))):
        assert "dataset_name" in datacollator_keys
    else:
        assert "dataset_name" not in datacollator_keys

    # Build the expected set of recorded texts for each prompt configuration and each
    # combination of single-dataset / dataset-dict inputs.
    if prompts is None:
        if (train_dict, eval_dict) == (False, False):
            expected = all_train_1 | all_eval_1
        elif (train_dict, eval_dict) == (True, False):
            expected = all_train | all_eval_1
        elif (train_dict, eval_dict) == (False, True):
            expected = all_train_1 | all_eval
        elif (train_dict, eval_dict) == (True, True):
            expected = all_train | all_eval

    elif prompts == "Prompt: ":
        if (train_dict, eval_dict) == (False, False):
            expected = {prompts + sample for sample in all_train_1} | {prompts + sample for sample in all_eval_1}
        elif (train_dict, eval_dict) == (True, False):
            expected = {prompts + sample for sample in all_train} | {prompts + sample for sample in all_eval_1}
        elif (train_dict, eval_dict) == (False, True):
            expected = {prompts + sample for sample in all_train_1} | {prompts + sample for sample in all_eval}
        elif (train_dict, eval_dict) == (True, True):
            expected = {prompts + sample for sample in all_train} | {prompts + sample for sample in all_eval}

    elif prompts == {"stsb-1": "Prompt 1: ", "stsb-2": "Prompt 2: "}:
        # If we don't have dataset dictionaries, the prompts will be seen as column names
        # (and no column is called "stsb-1"/"stsb-2", so those texts are expected without a prompt)
        if (train_dict, eval_dict) == (False, False):
            expected = all_train_1 | all_eval_1
        elif (train_dict, eval_dict) == (True, False):
            expected = (
                {prompts["stsb-1"] + sample for sample in all_train_1}
                | {prompts["stsb-2"] + sample for sample in all_train_2}
                | all_eval_1
            )
        elif (train_dict, eval_dict) == (False, True):
            expected = (
                all_train_1
                | {prompts["stsb-1"] + sample for sample in all_eval_1}
                | {prompts["stsb-2"] + sample for sample in all_eval_2}
            )
        elif (train_dict, eval_dict) == (True, True):
            expected = (
                {prompts["stsb-1"] + sample for sample in all_train_1}
                | {prompts["stsb-2"] + sample for sample in all_train_2}
                | {prompts["stsb-1"] + sample for sample in all_eval_1}
                | {prompts["stsb-2"] + sample for sample in all_eval_2}
            )

    elif prompts == {"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "}:
        # Column-level prompts are expected on every dataset that is used, dict or not
        if (train_dict, eval_dict) == (False, False):
            expected = (
                {prompts["sentence1"] + sample for sample in all_train_1_1}
                | {prompts["sentence2"] + sample for sample in all_train_1_2}
                | {prompts["sentence1"] + sample for sample in all_eval_1_1}
                | {prompts["sentence2"] + sample for sample in all_eval_1_2}
            )
        elif (train_dict, eval_dict) == (True, False):
            expected = (
                {prompts["sentence1"] + sample for sample in all_train_1_1}
                | {prompts["sentence2"] + sample for sample in all_train_1_2}
                | {prompts["sentence1"] + sample for sample in all_train_2_1}
                | {prompts["sentence2"] + sample for sample in all_train_2_2}
                | {prompts["sentence1"] + sample for sample in all_eval_1_1}
                | {prompts["sentence2"] + sample for sample in all_eval_1_2}
            )
        elif (train_dict, eval_dict) == (False, True):
            expected = (
                {prompts["sentence1"] + sample for sample in all_train_1_1}
                | {prompts["sentence2"] + sample for sample in all_train_1_2}
                | {prompts["sentence1"] + sample for sample in all_eval_1_1}
                | {prompts["sentence2"] + sample for sample in all_eval_1_2}
                | {prompts["sentence1"] + sample for sample in all_eval_2_1}
                | {prompts["sentence2"] + sample for sample in all_eval_2_2}
            )
        elif (train_dict, eval_dict) == (True, True):
            expected = (
                {prompts["sentence1"] + sample for sample in all_train_1_1}
                | {prompts["sentence2"] + sample for sample in all_train_1_2}
                | {prompts["sentence1"] + sample for sample in all_train_2_1}
                | {prompts["sentence2"] + sample for sample in all_train_2_2}
                | {prompts["sentence1"] + sample for sample in all_eval_1_1}
                | {prompts["sentence2"] + sample for sample in all_eval_1_2}
                | {prompts["sentence1"] + sample for sample in all_eval_2_1}
                | {prompts["sentence2"] + sample for sample in all_eval_2_2}
            )

    elif prompts == {
        "stsb-1": {"sentence1": "Prompt 1: ", "sentence2": "Prompt 2: "},
        "stsb-2": {"sentence1": "Prompt 3: ", "sentence2": "Prompt 4: "},
    }:
        # All other cases are tested above with the ValueError context
        if (train_dict, eval_dict) == (True, True):
            expected = (
                {prompts["stsb-1"]["sentence1"] + sample for sample in all_train_1_1}
                | {prompts["stsb-1"]["sentence2"] + sample for sample in all_train_1_2}
                | {prompts["stsb-2"]["sentence1"] + sample for sample in all_train_2_1}
                | {prompts["stsb-2"]["sentence2"] + sample for sample in all_train_2_2}
                | {prompts["stsb-1"]["sentence1"] + sample for sample in all_eval_1_1}
                | {prompts["stsb-1"]["sentence2"] + sample for sample in all_eval_1_2}
                | {prompts["stsb-2"]["sentence1"] + sample for sample in all_eval_2_1}
                | {prompts["stsb-2"]["sentence2"] + sample for sample in all_eval_2_2}
            )

    # Compared as sets: only which distinct texts were seen matters, not order or repetition.
    assert set(tracked_texts) == expected