File size: 20,988 Bytes
bd33eac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
from __future__ import annotations

import re
from unittest.mock import patch

import pytest
import torch

from sentence_transformers.models import Pooling, Transformer
from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseAutoEncoder, SpladePooling
from sentence_transformers.sparse_encoder.SparseEncoder import SparseEncoder
from tests.sparse_encoder.utils import sparse_allclose


@pytest.mark.parametrize(
    ("texts", "top_k", "expected_shape"),
    [
        # Single text, default top_k (None)
        (["The weather is nice!"], None, 1),
        # Single text, specific top_k
        (["The weather is nice!"], 3, 1),
        # String text, specific top_k, expect a non-nested list
        ("The weather is nice!", 8, 8),
        # Multiple texts, default top_k (None)
        (["The weather is nice!", "It's sunny outside"], None, 2),
        # Multiple texts, specific top_k
        (["The weather is nice!", "It's sunny outside"], 3, 2),
    ],
)
def test_decode_shapes(
    splade_bert_tiny_model: SparseEncoder,
    texts: list[str] | str,
    top_k: int | None,
    expected_shape: int,
) -> None:
    """Test the outer length and nesting of the ``decode`` output for string and list inputs.

    Args:
        splade_bert_tiny_model: Model fixture used for encoding and decoding.
        texts: A single string or a list of strings to encode.
        top_k: Limit passed to ``decode``; ``None`` means no explicit limit is passed.
        expected_shape: Expected ``len()`` of the decoded result.
    """
    model = splade_bert_tiny_model
    embeddings = model.encode(texts)
    decoded = model.decode(embeddings, top_k=top_k)

    assert len(decoded) == expected_shape

    if not isinstance(texts, list):
        # For a string input, the length assertion above already pins the number of decoded entries
        return

    # A list input, even with a single text, is expected to decode to one inner list per text.
    # The limit therefore has to be checked on each inner list: len(decoded) is only the number
    # of texts and says nothing about how many tokens were kept.
    assert isinstance(decoded, list)
    for item in decoded:
        assert isinstance(item, list)
        if top_k is not None:
            assert len(item) <= top_k


@pytest.mark.parametrize(
    ("text", "expected_token_types"),
    [("The weather is nice!", str), ("It's sunny outside", str)],
)
def test_decode_token_types(splade_bert_tiny_model: SparseEncoder, text: str, expected_token_types: type) -> None:
    """Test that every decoded entry unpacks into a token of the expected type and a float weight."""
    encoder = splade_bert_tiny_model
    decoded_pairs = encoder.decode(encoder.encode(text))

    # Each decoded entry is unpacked as a (token, weight) pair
    for pair in decoded_pairs:
        token, weight = pair
        assert isinstance(token, expected_token_types)
        assert isinstance(weight, float)


@pytest.mark.parametrize(
    ("text", "top_k"),
    [
        ("The weather is nice!", 1),
        ("It's sunny outside", 3),
        ("Hello world", 5),
    ],
)
def test_decode_top_k_respects_limit(splade_bert_tiny_model: SparseEncoder, text: str, top_k: int) -> None:
    """Test that ``decode`` never returns more than ``top_k`` entries for a text.

    Args:
        splade_bert_tiny_model: Model fixture used for encoding and decoding.
        text: The text to encode; it is passed to ``encode`` wrapped in a list.
        top_k: Maximum number of entries ``decode`` may return for the text.
    """
    model = splade_bert_tiny_model
    embeddings = model.encode([text])
    decoded = model.decode(embeddings, top_k=top_k)

    # The text is encoded as a batch of one, so the result is expected to hold a single inner list.
    # The limit must be checked on that inner list: len(decoded) is the batch size, which would
    # satisfy `<= top_k` for every top_k used here regardless of how many tokens were returned.
    assert len(decoded) == 1
    for item in decoded:
        assert len(item) <= top_k


@pytest.mark.parametrize(
    ("texts", "format_type"),
    [
        ("The weather is nice!", "1d"),
        (["The weather is nice!"], "1d"),
        (["The weather is nice!", "It's sunny outside"], "2d"),
    ],
)
def test_decode_handles_sparse_dense_inputs(
    splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, format_type: str
):
    """Test that ``decode`` accepts both sparse and dense tensors and yields equally sized results."""

    def by_weight_then_token(entry):
        # Order entries by their second element first, then by their first element
        return (entry[1], entry[0])

    encoder = splade_bert_tiny_model
    embeddings = encoder.encode(texts)

    # Decode the sparse layout, converting only when encode returned a dense tensor
    sparse_input = embeddings if embeddings.is_sparse else embeddings.to_sparse()
    from_sparse = encoder.decode(sparse_input)

    # Decode the dense layout, converting only when encode returned a sparse tensor
    dense_input = embeddings.to_dense() if embeddings.is_sparse else embeddings
    from_dense = encoder.decode(dense_input)

    # Both layouts must produce the same number of top-level entries
    assert len(from_sparse) == len(from_dense)
    if format_type == "1d":
        return

    # For the "2d" case, additionally compare the number of entries row by row
    for sparse_row, dense_row in zip(from_sparse, from_dense):
        ordered_sparse = sorted(sparse_row, key=by_weight_then_token, reverse=True)
        ordered_dense = sorted(dense_row, key=by_weight_then_token, reverse=True)
        assert len(ordered_sparse) == len(ordered_dense)


def test_decode_empty_tensor(splade_bert_tiny_model: SparseEncoder) -> None:
    """Test that decoding a sparse tensor without any stored values yields no tokens."""
    encoder = splade_bert_tiny_model

    # Build a sparse COO tensor of size (1, embedding dimension) with zero stored entries
    no_indices = torch.zeros((2, 0), dtype=torch.long)
    no_values = torch.zeros((0,), dtype=torch.float)
    shape = (1, encoder.get_sentence_embedding_dimension())
    empty_embedding = torch.sparse_coo_tensor(indices=no_indices, values=no_values, size=shape)

    decoded = encoder.decode(empty_embedding)

    # Either nothing is returned at all, or a list whose items are all empty
    if len(decoded) != 0:
        assert isinstance(decoded, list)
        assert all(not item for item in decoded)


@pytest.mark.parametrize("top_k", [None, 5, 1000])
@pytest.mark.parametrize(
    "texts",
    [
        "The weather is nice!",
        ["The weather is nice!"],
        ["The weather is nice!", "It's sunny outside", "Hello world"],
        ["Short text", "This is a longer text with more words to encode"],
    ],
)
def test_decode_returns_sorted_weights(
    splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, top_k: int | None
) -> None:
    """Test that the weights returned by ``decode`` are in non-increasing order."""
    encoder = splade_bert_tiny_model
    decoded = encoder.decode(encoder.encode(texts), top_k=top_k)

    # A string input is treated as a single row of pairs; a list input as one row per item
    rows = decoded if isinstance(texts, list) else [decoded]
    for row in rows:
        weights = [weight for _, weight in row]
        assert all(current >= following for current, following in zip(weights, weights[1:]))


def test_inference_free_splade(inference_free_splade_bert_tiny_model: SparseEncoder):
    """Test query/document encoding, decoding and ``max_seq_length`` propagation on the inference-free fixture."""
    encoder = inference_free_splade_bert_tiny_model
    expected_shape = (encoder.get_sentence_embedding_dimension(),)

    query = "What is the capital of France?"
    document = "The capital of France is Paris."
    query_embeddings = encoder.encode_query(query)
    document_embeddings = encoder.encode_document(document)

    # A single string input produces a 1-dimensional embedding for both routes
    assert query_embeddings.shape == expected_shape
    assert document_embeddings.shape == expected_shape

    decoded_query = encoder.decode(query_embeddings)
    decoded_document = encoder.decode(document_embeddings)

    # The decoded query has exactly as many entries as the tokenized query has input ids
    query_input_ids = encoder.tokenize(query, task="query")["input_ids"][0]
    assert len(decoded_query) == len(query_input_ids)
    assert len(decoded_document) >= 50

    def assert_max_seq_length(expected):
        # The model-level value and both routed sub-modules must agree
        assert encoder.max_seq_length == expected
        for route in ("query", "document"):
            assert encoder[0].sub_modules[route][0].max_seq_length == expected

    assert_max_seq_length(512)

    # Setting the value on the model must reach both the query and document sub-modules
    encoder.max_seq_length = 256
    assert_max_seq_length(256)


@pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]])
@pytest.mark.parametrize("prompt_name", [None, "query", "custom"])
@pytest.mark.parametrize("prompt", [None, "Custom prompt: "])
@pytest.mark.parametrize("convert_to_tensor", [True, False])
@pytest.mark.parametrize("convert_to_sparse_tensor", [True, False])
def test_encode_query(
    splade_bert_tiny_model: SparseEncoder,
    sentences: str | list[str],
    prompt_name: str | None,
    prompt: str | None,
    convert_to_tensor: bool,
    convert_to_sparse_tensor: bool,
):
    """Test that ``encode_query`` forwards its arguments to ``encode`` with the "query" task."""
    encoder = splade_bert_tiny_model
    # Configure the prompt names that the parametrization refers to
    encoder.prompts = {"query": "query: ", "custom": "custom: "}

    # Replace encode with a mock so only the forwarded keyword arguments are inspected
    with patch.object(encoder, "encode", autospec=True) as encode_mock:
        encoder.encode_query(
            sentences=sentences,
            prompt_name=prompt_name,
            prompt=prompt,
            batch_size=32,
            convert_to_tensor=convert_to_tensor,
            convert_to_sparse_tensor=convert_to_sparse_tensor,
        )

        encode_mock.assert_called_once()
        _, forwarded = encode_mock.call_args

        # With no explicit prompt name, the "query" prompt is expected to be selected
        expected = {
            "sentences": sentences,
            "prompt": prompt,
            "prompt_name": prompt_name or "query",
            "convert_to_tensor": convert_to_tensor,
            "convert_to_sparse_tensor": convert_to_sparse_tensor,
            "task": "query",
        }
        for name, value in expected.items():
            assert forwarded[name] == value


@pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]])
@pytest.mark.parametrize("prompt_name", [None, "document", "passage", "corpus", "custom"])
@pytest.mark.parametrize("prompt", [None, "Custom prompt: "])
@pytest.mark.parametrize("convert_to_tensor", [True, False])
@pytest.mark.parametrize("convert_to_sparse_tensor", [True, False])
def test_encode_document(
    splade_bert_tiny_model: SparseEncoder,
    sentences: str | list[str],
    prompt_name: str | None,
    prompt: str | None,
    convert_to_tensor: bool,
    convert_to_sparse_tensor: bool,
):
    """Test that ``encode_document`` forwards its arguments to ``encode`` with the "document" task."""
    encoder = splade_bert_tiny_model
    # Configure the prompt names that the parametrization refers to
    encoder.prompts = {"document": "document: ", "passage": "passage: ", "corpus": "corpus: ", "custom": "custom: "}

    # Replace encode with a mock so only the forwarded keyword arguments are inspected
    with patch.object(encoder, "encode", autospec=True) as encode_mock:
        encoder.encode_document(
            sentences=sentences,
            prompt_name=prompt_name,
            prompt=prompt,
            batch_size=32,
            convert_to_tensor=convert_to_tensor,
            convert_to_sparse_tensor=convert_to_sparse_tensor,
        )

        encode_mock.assert_called_once()
        _, forwarded = encode_mock.call_args

        # With no explicit prompt name, the "document" prompt is expected to be selected
        expected = {
            "sentences": sentences,
            "prompt": prompt,
            "prompt_name": prompt_name or "document",
            "convert_to_tensor": convert_to_tensor,
            "convert_to_sparse_tensor": convert_to_sparse_tensor,
            "task": "document",
        }
        for name, value in expected.items():
            assert forwarded[name] == value


def test_encode_document_prompt_priority(splade_bert_tiny_model: SparseEncoder):
    """Test that proper prompt priority is respected when multiple options are available"""
    encoder = splade_bert_tiny_model

    # Each scenario pairs the configured prompts with the prompt name that should reach encode.
    # The scenarios drop one candidate at a time: document, then passage, then corpus.
    scenarios = [
        ({"document": "document: ", "passage": "passage: ", "corpus": "corpus: "}, "document"),
        ({"passage": "passage: ", "corpus": "corpus: "}, "passage"),
        ({"corpus": "corpus: "}, "corpus"),
        # No document-related prompt is configured at all
        ({"query": "query: "}, None),
    ]

    with patch.object(encoder, "encode", autospec=True) as encode_mock:
        for prompts, expected_prompt_name in scenarios:
            encode_mock.reset_mock()
            encoder.prompts = prompts

            # Call encode_document without an explicit prompt or prompt name
            encoder.encode_document("test")

            _, forwarded = encode_mock.call_args
            if expected_prompt_name is None:
                assert forwarded["prompt_name"] is None
            else:
                assert forwarded["prompt_name"] == expected_prompt_name


def test_encode_advanced_parameters(splade_bert_tiny_model: SparseEncoder):
    """Test that additional parameters are correctly passed to encode"""
    encoder = splade_bert_tiny_model

    # Keyword arguments that encode_query is expected to hand through to encode
    extra_options = {
        "normalize_embeddings": True,
        "batch_size": 64,
        "show_progress_bar": True,
        "max_active_dims": 128,
        "chunk_size": 10,
        "custom_param": "value",
    }

    with patch.object(encoder, "encode", autospec=True) as encode_mock:
        encoder.encode_query("test", **extra_options)

        _, forwarded = encode_mock.call_args

        # Boolean flags must arrive as the True singleton itself
        for flag in ("normalize_embeddings", "show_progress_bar"):
            assert forwarded[flag] is True

        # The remaining options only need to compare equal
        for option in ("batch_size", "max_active_dims", "chunk_size", "custom_param"):
            assert forwarded[option] == extra_options[option]


@pytest.mark.parametrize("inputs", ["test sentence", ["test sentence"]])
def test_encode_query_document_vs_encode(splade_bert_tiny_model: SparseEncoder, inputs: str | list[str]):
    """Test the actual integration with encode vs encode_query/encode_document"""
    # No mocking here: the model fixture is run for real
    encoder = splade_bert_tiny_model
    encoder.prompts = {"query": "query: ", "document": "document: "}

    # Embeddings from the task-specific entry points
    from_encode_query = encoder.encode_query(inputs)
    from_encode_document = encoder.encode_document(inputs)

    # The same inputs through plain encode, selecting the prompt by name
    from_query_prompt = encoder.encode(inputs, prompt_name="query")
    from_document_prompt = encoder.encode(inputs, prompt_name="document")

    # Selecting the same prompt must give matching embeddings
    for task_result, prompt_result in (
        (from_encode_query, from_query_prompt),
        (from_encode_document, from_document_prompt),
    ):
        assert sparse_allclose(task_result, prompt_result)

    # Plain encode calls without any prompt name
    plain_for_query = encoder.encode(inputs)
    plain_for_document = encoder.encode(inputs)

    # Embeddings without a prompt must differ from the prompted ones
    assert not sparse_allclose(plain_for_query, from_encode_query)
    assert not sparse_allclose(plain_for_document, from_encode_document)


def test_default_prompt(splade_bert_tiny_model: SparseEncoder):
    """Test that the default prompt is used when no prompt is specified"""
    encoder = splade_bert_tiny_model
    encoder.prompts = {"query": "query: ", "document": "document: "}
    encoder.default_prompt_name = "query"

    def assert_single_embedding(embedding):
        # A single string input must produce a 1-dimensional embedding
        assert embedding.shape == (encoder.get_sentence_embedding_dimension(),)

    # None of the three calls below specifies a prompt or a prompt name
    from_encode_query = encoder.encode_query("test")
    assert_single_embedding(from_encode_query)

    from_encode_document = encoder.encode_document("test")
    assert_single_embedding(from_encode_document)

    from_encode = encoder.encode("test")
    assert_single_embedding(from_encode)

    # Plain encode must match the query route (the configured default), not the document route
    assert sparse_allclose(from_encode_query, from_encode)
    assert not sparse_allclose(from_encode_document, from_encode)

    # After clearing the default prompt name, plain encode must no longer give the same embeddings
    encoder.default_prompt_name = None
    from_encode_without_default = encoder.encode("test")
    assert not sparse_allclose(from_encode_without_default, from_encode)


def test_wrong_prompt(splade_bert_tiny_model: SparseEncoder):
    """Test that using a wrong prompt raises an error"""
    encoder = splade_bert_tiny_model
    encoder.prompts = {"query": "query: ", "document": "document: "}

    # The message contains regex metacharacters (brackets, dot), so escape it before matching
    expected_message = re.escape(
        "Prompt name 'invalid_prompt' not found in the configured prompts dictionary with keys ['query', 'document']."
    )

    # All three entry points must reject a prompt name that is not configured
    for encode_method in (encoder.encode_query, encoder.encode_document, encoder.encode):
        with pytest.raises(ValueError, match=expected_message):
            encode_method("test", prompt_name="invalid_prompt")


def test_max_active_dims_set_init(splade_bert_tiny_model: SparseEncoder, csr_bert_tiny_model: SparseEncoder, tmp_path):
    """Test the ``max_active_dims`` value of reloaded models, with and without an explicit override."""
    splade_dir = str(tmp_path / "splade_bert_tiny")
    csr_dir = str(tmp_path / "csr_bert_tiny")
    splade_bert_tiny_model.save_pretrained(splade_dir)
    csr_bert_tiny_model.save_pretrained(csr_dir)

    # SPLADE fixture: nothing is set unless max_active_dims is passed at load time
    reloaded = SparseEncoder(splade_dir)
    assert reloaded.max_active_dims is None
    reloaded = SparseEncoder(splade_dir, max_active_dims=13)
    assert reloaded.max_active_dims == 13

    # CSR fixture: a value is present without passing one, and an explicit argument replaces it
    reloaded = SparseEncoder(csr_dir)
    assert reloaded.max_active_dims == 16  # Based on the SparseAutoEncoder's k value
    reloaded = SparseEncoder(csr_dir, max_active_dims=13)
    assert reloaded.max_active_dims == 13


def test_detect_mlm():
    """Test that loading this checkpoint by name builds an MLMTransformer followed by SpladePooling."""
    encoder = SparseEncoder("distilbert/distilbert-base-uncased")

    expected_modules = (MLMTransformer, SpladePooling)
    for position, module_type in enumerate(expected_modules):
        assert isinstance(encoder[position], module_type)


def test_default_to_csr():
    """Test that this checkpoint is loaded as Transformer, Pooling and SparseAutoEncoder modules."""
    # NOTE: bert-tiny is actually MLM-based, but the config isn't modern enough to allow us to detect it,
    # so we should default to CSR here.
    encoder = SparseEncoder("prajjwal1/bert-tiny")

    expected_modules = (Transformer, Pooling, SparseAutoEncoder)
    for position, module_type in enumerate(expected_modules):
        assert isinstance(encoder[position], module_type)


def test_sparsity(splade_bert_tiny_model: SparseEncoder):
    """Test the statistics returned by ``sparsity`` for batched, dense and single-row embeddings."""
    encoder = splade_bert_tiny_model

    queries = ["What is the capital of France?", "Who has won the World Cup in 2016?"]
    embeddings = encoder.encode_query(queries)

    # The result is a dict exposing both statistics, each within a plausible range
    stats = encoder.sparsity(embeddings)
    assert isinstance(stats, dict)
    for key in ("active_dims", "sparsity_ratio"):
        assert key in stats
    assert 0 < stats["active_dims"] < 100
    assert 0.99 <= stats["sparsity_ratio"] < 1.0

    # A dense copy of the same embeddings must give identical statistics
    stats_from_dense = encoder.sparsity(embeddings.to_dense())
    assert stats_from_dense == stats, "Sparsity should be the same for dense and sparse tensors"

    # Single rows are accepted too; their active dims average out to the batch value
    first_row_stats = encoder.sparsity(embeddings[0])
    second_row_stats = encoder.sparsity(embeddings[1])
    mean_active_dims = (first_row_stats["active_dims"] + second_row_stats["active_dims"]) / 2
    assert mean_active_dims == stats["active_dims"]


def test_splade_pooling_chunk_size(splade_bert_tiny_model: SparseEncoder):
    """Test that ``splade_pooling_chunk_size`` set on the model is reflected on its SpladePooling module."""
    encoder = splade_bert_tiny_model

    # The fixture starts out with no chunk size configured
    assert encoder.splade_pooling_chunk_size is None

    # Chunking the pooling saves memory at the cost of some speed
    chunk_size = 13
    encoder.splade_pooling_chunk_size = chunk_size
    assert encoder.splade_pooling_chunk_size == chunk_size

    # The value must also be visible on the pooling module itself
    pooling = encoder[1]
    assert isinstance(pooling, SpladePooling)
    assert pooling.chunk_size == chunk_size


def test_intersection(splade_bert_tiny_model: SparseEncoder):
    """Test ``intersection`` for one query against a single document and against a batch of documents."""
    encoder = splade_bert_tiny_model

    # Single query against a single document
    query = "Where can I deposit my money?"
    document = "I'm sitting by the river."
    query_embeddings = encoder.encode_query(query)
    document_embeddings = encoder.encode_document(document)
    query_active = encoder.sparsity(query_embeddings)["active_dims"]
    document_active = encoder.sparsity(document_embeddings)["active_dims"]

    # The result is a tensor with the same 1-dimensional shape as a single embedding
    single = encoder.intersection(query_embeddings, document_embeddings)
    assert isinstance(single, torch.Tensor)
    assert single.shape == (encoder.get_sentence_embedding_dimension(),)

    # The intersection must have strictly fewer active dims than either of its inputs
    intersection_active = encoder.sparsity(single)["active_dims"]
    assert intersection_active < query_active
    assert intersection_active < document_active

    # Single query against several documents
    query = "Who has won the World Cup in 2016?"
    documents = ["The capital of France is Paris.", "Germany won the World Cup in 2014."]
    query_embeddings = encoder.encode_query(query)
    document_embeddings = encoder.encode_document(documents)

    # One intersection row is produced per document
    batched = encoder.intersection(query_embeddings, document_embeddings)
    assert isinstance(batched, torch.Tensor)
    assert batched.shape == (len(documents), encoder.get_sentence_embedding_dimension())

    # Decoding the batch yields one entry per document as well
    decoded_batch = encoder.decode(batched)
    assert len(decoded_batch) == len(documents)