| from __future__ import annotations |
|
|
| import copy |
| import json |
| import logging |
| import re |
| import tempfile |
| from pathlib import Path |
|
|
| import numpy as np |
| import pytest |
| import torch |
|
|
| from sentence_transformers.sentence_transformer.modules import Pooling |
| from sentence_transformers.sparse_encoder.model import SparseEncoder |
| from sentence_transformers.sparse_encoder.modules import Router, SparseAutoEncoder, SpladePooling, Transformer |
| from sentence_transformers.util.similarity import SimilarityFunction |
|
|
|
|
| @pytest.mark.parametrize( |
| ("texts", "top_k", "expected_shape"), |
| [ |
| |
| (["The weather is nice!"], None, 1), |
| |
| (["The weather is nice!"], 3, 1), |
| |
| ("The weather is nice!", 8, 8), |
| |
| (["The weather is nice!", "It's sunny outside"], None, 2), |
| |
| (["The weather is nice!", "It's sunny outside"], 3, 2), |
| ], |
| ) |
| def test_decode_shapes( |
| splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, top_k: int, expected_shape: int |
| ) -> None: |
| model = splade_bert_tiny_model |
| embeddings = model.encode(texts) |
| decoded = model.decode(embeddings, top_k=top_k) |
|
|
| assert len(decoded) == expected_shape |
|
|
| if isinstance(texts, list): |
| if len(texts) == 1: |
| assert isinstance(decoded[0], tuple) or isinstance(decoded, list) |
| if top_k is not None: |
| assert len(decoded) <= top_k |
| else: |
| assert isinstance(decoded, list) |
| for item in decoded: |
| assert isinstance(item, list) |
| if top_k is not None: |
| assert len(item) <= top_k |
|
|
|
|
| @pytest.mark.parametrize( |
| ("text", "expected_token_types"), |
| [ |
| ("The weather is nice!", str), |
| ("It's sunny outside", str), |
| ], |
| ) |
| def test_decode_token_types(splade_bert_tiny_model: SparseEncoder, text: str, expected_token_types: type) -> None: |
| model = splade_bert_tiny_model |
| embeddings = model.encode(text) |
| decoded = model.decode(embeddings) |
|
|
| |
| for token, weight in decoded: |
| assert isinstance(token, expected_token_types) |
| assert isinstance(weight, float) |
|
|
|
|
| @pytest.mark.parametrize( |
| ("text", "top_k"), |
| [ |
| ("The weather is nice!", 1), |
| ("It's sunny outside", 3), |
| ("Hello world", 5), |
| ], |
| ) |
| def test_decode_top_k_respects_limit(splade_bert_tiny_model: SparseEncoder, text: str, top_k: int) -> None: |
| model = splade_bert_tiny_model |
| embeddings = model.encode([text]) |
| decoded = model.decode(embeddings, top_k=top_k) |
|
|
| assert len(decoded) <= top_k |
|
|
|
|
| @pytest.mark.parametrize( |
| ("texts", "format_type"), |
| [ |
| ("The weather is nice!", "1d"), |
| (["The weather is nice!"], "1d"), |
| (["The weather is nice!", "It's sunny outside"], "2d"), |
| ], |
| ) |
| def test_decode_handles_sparse_dense_inputs( |
| splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, format_type: str |
| ): |
| model = splade_bert_tiny_model |
| |
| embeddings = model.encode(texts) |
|
|
| |
| if not embeddings.is_sparse: |
| embeddings_sparse = embeddings.to_sparse() |
| else: |
| embeddings_sparse = embeddings |
|
|
| decoded_sparse = model.decode(embeddings_sparse) |
|
|
| |
| if embeddings.is_sparse: |
| embeddings_dense = embeddings.to_dense() |
| else: |
| embeddings_dense = embeddings |
|
|
| decoded_dense = model.decode(embeddings_dense) |
|
|
| |
| if format_type == "1d": |
| assert len(decoded_sparse) == len(decoded_dense) |
| else: |
| assert len(decoded_sparse) == len(decoded_dense) |
| for i in range(len(decoded_sparse)): |
| |
| sorted_sparse = sorted(decoded_sparse[i], key=lambda x: (x[1], x[0]), reverse=True) |
| sorted_dense = sorted(decoded_dense[i], key=lambda x: (x[1], x[0]), reverse=True) |
| assert len(sorted_sparse) == len(sorted_dense) |
|
|
|
|
| def test_decode_empty_tensor(splade_bert_tiny_model: SparseEncoder) -> None: |
| model = splade_bert_tiny_model |
| |
| empty_sparse = torch.sparse_coo_tensor( |
| indices=torch.zeros((2, 0), dtype=torch.long), |
| values=torch.zeros((0,), dtype=torch.float), |
| size=(1, model.get_embedding_dimension()), |
| ) |
|
|
| decoded = model.decode(empty_sparse) |
| assert len(decoded) == 0 or (isinstance(decoded, list) and all(not item for item in decoded)) |
|
|
|
|
| @pytest.mark.parametrize("top_k", [0, -1, -5]) |
| def test_decode_invalid_top_k(splade_bert_tiny_model: SparseEncoder, top_k: int) -> None: |
| model = splade_bert_tiny_model |
| embeddings = model.encode("Hello world") |
| with pytest.raises(ValueError, match="top_k must be a positive integer"): |
| model.decode(embeddings, top_k=top_k) |
|
|
|
|
| def test_decode_invalid_input_type(splade_bert_tiny_model: SparseEncoder) -> None: |
| model = splade_bert_tiny_model |
| with pytest.raises(TypeError, match="Expected torch.Tensor"): |
| model.decode([1, 2, 3]) |
|
|
|
|
| def test_decode_invalid_ndim(splade_bert_tiny_model: SparseEncoder) -> None: |
| model = splade_bert_tiny_model |
| tensor_3d = torch.zeros(2, 3, 4) |
| with pytest.raises(ValueError, match="Input tensor must be 1D or 2D"): |
| model.decode(tensor_3d) |
|
|
|
|
| def test_decode_batch_with_empty_sample(splade_bert_tiny_model: SparseEncoder) -> None: |
| model = splade_bert_tiny_model |
| vocab_size = model.get_embedding_dimension() |
| |
| indices = torch.tensor([[0, 0], [1, 5]]) |
| values = torch.tensor([1.0, 2.0]) |
| batch_sparse = torch.sparse_coo_tensor(indices, values, size=(2, vocab_size)) |
|
|
| decoded = model.decode(batch_sparse) |
| assert len(decoded) == 2 |
| assert len(decoded[0]) == 2 |
| assert decoded[1] == [] |
|
|
|
|
| @pytest.mark.parametrize("top_k", [None, 5, 1000]) |
| @pytest.mark.parametrize( |
| "texts", |
| [ |
| ("The weather is nice!"), |
| (["The weather is nice!"]), |
| (["The weather is nice!", "It's sunny outside", "Hello world"]), |
| (["Short text", "This is a longer text with more words to encode"]), |
| ], |
| ) |
| def test_decode_returns_sorted_weights( |
| splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, top_k: int | None |
| ) -> None: |
| model = splade_bert_tiny_model |
| embeddings = model.encode(texts) |
| decoded = model.decode(embeddings, top_k=top_k) |
|
|
| if isinstance(texts, list): |
| for item in decoded: |
| weights = [weight for _, weight in item] |
| assert all(weights[i] >= weights[i + 1] for i in range(len(weights) - 1)) |
| else: |
| weights = [weight for _, weight in decoded] |
| assert all(weights[i] >= weights[i + 1] for i in range(len(weights) - 1)) |
|
|
|
|
| def test_inference_free_splade(inference_free_splade_bert_tiny_model: SparseEncoder): |
| model = inference_free_splade_bert_tiny_model |
| dimensionality = model.get_embedding_dimension() |
|
|
| query = "What is the capital of France?" |
| document = "The capital of France is Paris." |
| query_embeddings = model.encode_query(query) |
| document_embeddings = model.encode_document(document) |
|
|
| assert query_embeddings.shape == (dimensionality,) |
| assert document_embeddings.shape == (dimensionality,) |
|
|
| decoded_query = model.decode(query_embeddings) |
| decoded_document = model.decode(document_embeddings) |
| assert len(decoded_query) == len(model.preprocess(query, task="query")["input_ids"][0]) |
| assert len(decoded_document) >= 50 |
|
|
| assert model.max_seq_length == 512 |
| assert model[0].sub_modules["query"][0].max_seq_length == 512 |
| assert model[0].sub_modules["document"][0].max_seq_length == 512 |
|
|
| model.max_seq_length = 256 |
| assert model.max_seq_length == 256 |
| assert model[0].sub_modules["query"][0].max_seq_length == 256 |
| assert model[0].sub_modules["document"][0].max_seq_length == 256 |
|
|
|
|
| def test_inference_free_splade_max_active_dims_routing(inference_free_splade_bert_tiny_model: SparseEncoder): |
| model = inference_free_splade_bert_tiny_model |
| query = "What is the capital of France?" |
| document = "The capital of France is Paris." |
|
|
| |
| query_emb = model.encode_query(query) |
| doc_emb = model.encode_document(document) |
|
|
| |
| query_emb_mad = model.encode_query(query, max_active_dims=50) |
| doc_emb_mad = model.encode_document(document, max_active_dims=50) |
|
|
| |
| query_baseline_indices = query_emb.coalesce().indices()[0] |
| query_mad_indices = query_emb_mad.coalesce().indices()[0] |
| assert set(query_mad_indices.tolist()).issubset(set(query_baseline_indices.tolist())) |
| assert query_emb_mad._nnz() <= 50 |
|
|
| doc_baseline_indices = doc_emb.coalesce().indices()[0] |
| doc_mad_indices = doc_emb_mad.coalesce().indices()[0] |
| assert set(doc_mad_indices.tolist()).issubset(set(doc_baseline_indices.tolist())) |
| assert doc_emb_mad._nnz() <= 50 |
|
|
|
|
| def test_encode_advanced_parameters(splade_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch): |
| """Test that additional parameters are correctly passed to encode""" |
| model = splade_bert_tiny_model |
|
|
| encode_calls = [] |
|
|
| def spy_encode(*args, **kwargs): |
| encode_calls.append((args, kwargs)) |
|
|
| monkeypatch.setattr(model, "encode", spy_encode) |
| |
| model.encode_query( |
| "test", |
| normalize_embeddings=True, |
| batch_size=64, |
| show_progress_bar=True, |
| max_active_dims=128, |
| chunk_size=10, |
| custom_param="value", |
| ) |
|
|
| |
| _, kwargs = encode_calls[0] |
| assert kwargs["normalize_embeddings"] is True |
| assert kwargs["batch_size"] == 64 |
| assert kwargs["show_progress_bar"] is True |
| assert kwargs["max_active_dims"] == 128 |
| assert kwargs["chunk_size"] == 10 |
| assert kwargs["custom_param"] == "value" |
|
|
|
|
| def test_csr_max_active_dims_passed_to_forward(csr_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch): |
| model = csr_bert_tiny_model |
| assert isinstance(model[-1], SparseAutoEncoder) |
| assert model[-1].k == 16 |
|
|
| |
| forward_calls = [] |
| original_forward = model[-1].forward |
|
|
| def spy_forward(*args, **kwargs): |
| forward_calls.append(kwargs) |
| return original_forward(*args, **kwargs) |
|
|
| monkeypatch.setattr(model[-1], "forward", spy_forward) |
|
|
| model.encode("Hello world", max_active_dims=5) |
| assert len(forward_calls) == 1 |
| assert forward_calls[0]["max_active_dims"] == 5 |
|
|
| |
| forward_calls.clear() |
| model.encode("Hello world") |
| assert len(forward_calls) == 1 |
| assert forward_calls[0]["max_active_dims"] == model.max_active_dims |
|
|
|
|
| def test_max_active_dims_set_init(splade_bert_tiny_model: SparseEncoder, csr_bert_tiny_model: SparseEncoder, tmp_path): |
| splade_bert_tiny_model.save_pretrained(str(tmp_path / "splade_bert_tiny")) |
| csr_bert_tiny_model.save_pretrained(str(tmp_path / "csr_bert_tiny")) |
|
|
| |
| loaded_model = SparseEncoder(str(tmp_path / "splade_bert_tiny")) |
| assert loaded_model.max_active_dims is None |
| loaded_model = SparseEncoder(str(tmp_path / "splade_bert_tiny"), max_active_dims=13) |
| assert loaded_model.max_active_dims == 13 |
|
|
| loaded_model = SparseEncoder(str(tmp_path / "csr_bert_tiny")) |
| assert loaded_model.max_active_dims == 16 |
| loaded_model = SparseEncoder(str(tmp_path / "csr_bert_tiny"), max_active_dims=13) |
| assert loaded_model.max_active_dims == 13 |
|
|
|
|
| def test_detect_mlm(): |
| model = SparseEncoder("distilbert/distilbert-base-uncased") |
|
|
| assert isinstance(model[0], Transformer) |
| assert model[0].transformer_task == "fill-mask" |
| assert isinstance(model[1], SpladePooling) |
|
|
|
|
| def test_default_to_csr(): |
| |
| |
| model = SparseEncoder("sentence-transformers-testing/stsb-bert-tiny-safetensors") |
| assert isinstance(model[0], Transformer) |
| assert isinstance(model[1], Pooling) |
| assert isinstance(model[2], SparseAutoEncoder) |
|
|
|
|
| def test_sparsity(splade_bert_tiny_model: SparseEncoder): |
| model = splade_bert_tiny_model |
|
|
| |
| embeddings = model.encode_query(["What is the capital of France?", "Who has won the World Cup in 2016?"]) |
| sparsity = model.sparsity(embeddings) |
| assert isinstance(sparsity, dict) |
| assert "active_dims" in sparsity |
| assert "sparsity_ratio" in sparsity |
| assert sparsity["active_dims"] < 100 and sparsity["active_dims"] > 0 |
| assert sparsity["sparsity_ratio"] < 1.0 and sparsity["sparsity_ratio"] >= 0.99 |
|
|
| |
| dense_sparsity = model.sparsity(embeddings.to_dense()) |
| assert dense_sparsity == sparsity, "Sparsity should be the same for dense and sparse tensors" |
|
|
| |
| sparsity_one = model.sparsity(embeddings[0]) |
| sparsity_two = model.sparsity(embeddings[1]) |
| assert (sparsity_one["active_dims"] + sparsity_two["active_dims"]) / 2 == sparsity["active_dims"] |
|
|
|
|
| def test_splade_pooling_chunk_size(splade_bert_tiny_model: SparseEncoder): |
| model = splade_bert_tiny_model |
|
|
| |
| assert model.splade_pooling_chunk_size is None |
| |
| model.splade_pooling_chunk_size = 13 |
| assert model.splade_pooling_chunk_size == 13 |
| assert isinstance(model[1], SpladePooling) |
| assert model[1].chunk_size == 13 |
|
|
|
|
| def test_intersection(splade_bert_tiny_model: SparseEncoder): |
| model = splade_bert_tiny_model |
|
|
| |
| query = "Where can I deposit my money?" |
| document = "I'm sitting by the river." |
| query_embeddings = model.encode_query(query) |
| document_embeddings = model.encode_document(document) |
| query_sparsity = model.sparsity(query_embeddings) |
| document_sparsity = model.sparsity(document_embeddings) |
|
|
| |
| intersection = model.intersection(query_embeddings, document_embeddings) |
| assert isinstance(intersection, torch.Tensor) |
| assert intersection.shape == (model.get_embedding_dimension(),) |
|
|
| |
| intersection_sparsity = model.sparsity(intersection) |
| assert ( |
| intersection_sparsity["active_dims"] < query_sparsity["active_dims"] |
| and intersection_sparsity["active_dims"] < document_sparsity["active_dims"] |
| ) |
|
|
| |
| query = "Who has won the World Cup in 2016?" |
| documents = ["The capital of France is Paris.", "Germany won the World Cup in 2014."] |
| query_embeddings = model.encode_query(query) |
| document_embeddings = model.encode_document(documents) |
|
|
| intersection_batch = model.intersection(query_embeddings, document_embeddings) |
| assert isinstance(intersection_batch, torch.Tensor) |
| assert intersection_batch.shape == (len(documents), model.get_embedding_dimension()) |
|
|
| decoded_intersection_batch = model.decode(intersection_batch) |
| assert len(decoded_intersection_batch) == len(documents) |
|
|
|
|
| def test_encode_with_dataset_column(splade_bert_tiny_model: SparseEncoder) -> None: |
| """Test that encode can handle a dataset column as input.""" |
| model = splade_bert_tiny_model |
| from datasets import Dataset |
|
|
| |
| dataset = Dataset.from_dict({"text": ["This is a test.", "Another sentence."]}) |
|
|
| |
| embeddings = model.encode(dataset["text"], convert_to_tensor=True) |
|
|
| |
| assert embeddings.shape == (2, model.get_embedding_dimension()) |
|
|
|
|
| def test_encode_numpy_1d_string_array(splade_bert_tiny_model: SparseEncoder) -> None: |
| """Regression test for #3718: encoding a 1D numpy string array should produce one embedding per element.""" |
| model = splade_bert_tiny_model |
| texts = np.array(["Access Management", "Press Coordination", "Financial Reports"]) |
| embeddings = model.encode(texts, convert_to_tensor=True, save_to_cpu=True) |
| expected = model.encode(texts.tolist(), convert_to_tensor=True, save_to_cpu=True) |
| assert embeddings.shape == (3, model.get_embedding_dimension()) |
| assert torch.allclose(embeddings.to_dense(), expected.to_dense()) |
|
|
|
|
| def test_encode_numpy_2d_string_array(splade_bert_tiny_model: SparseEncoder) -> None: |
| """Encoding a 2D numpy string array should match encoding the equivalent nested list.""" |
| model = splade_bert_tiny_model |
| pairs = np.array([["what is AI?", "AI is artificial intelligence."], ["what is ML?", "ML is machine learning."]]) |
| embeddings = model.encode(pairs, convert_to_tensor=True, save_to_cpu=True) |
| expected = model.encode(pairs.tolist(), convert_to_tensor=True, save_to_cpu=True) |
| assert embeddings.shape == (2, model.get_embedding_dimension()) |
| assert torch.allclose(embeddings.to_dense(), expected.to_dense()) |
|
|
|
|
| def test_encode_numpy_empty(splade_bert_tiny_model: SparseEncoder) -> None: |
| """Encoding an empty string ndarray should return an empty tensor, like ``encode([])``.""" |
| model = splade_bert_tiny_model |
| embeddings = model.encode(np.array([], dtype=str), convert_to_tensor=True, save_to_cpu=True) |
| expected = model.encode([], convert_to_tensor=True, save_to_cpu=True) |
| assert embeddings.numel() == 0 |
| assert torch.equal(embeddings.to_dense(), expected.to_dense()) |
|
|
|
|
| @pytest.mark.parametrize("convert_to_tensor", [True, False]) |
| @pytest.mark.parametrize("convert_to_sparse_tensor", [True, False]) |
| @pytest.mark.parametrize("save_to_cpu", [True, False]) |
| @pytest.mark.parametrize("max_active_dims", [None, 64, 128]) |
| def test_empty_encode( |
| splade_bert_tiny_model: SparseEncoder, |
| convert_to_tensor: bool, |
| convert_to_sparse_tensor: bool, |
| save_to_cpu: bool, |
| max_active_dims: int | None, |
| ): |
| model = splade_bert_tiny_model |
| embeddings = model.encode( |
| [], |
| convert_to_tensor=convert_to_tensor, |
| convert_to_sparse_tensor=convert_to_sparse_tensor, |
| save_to_cpu=save_to_cpu, |
| max_active_dims=max_active_dims, |
| ) |
|
|
| if convert_to_tensor: |
| assert isinstance(embeddings, torch.Tensor) |
| assert embeddings.numel() == 0 |
| if save_to_cpu: |
| assert embeddings.device == torch.device("cpu") |
| else: |
| assert embeddings.device == model.device |
|
|
| if convert_to_sparse_tensor: |
| assert embeddings.is_sparse |
| else: |
| assert not embeddings.is_sparse |
| else: |
| assert embeddings == [] |
|
|
|
|
| def test_get_model_kwargs(splade_bert_tiny_model: SparseEncoder) -> None: |
| """Test that get_model_kwargs returns the correct keyword arguments.""" |
| model = splade_bert_tiny_model |
|
|
| |
| |
| forward_kwargs = model.get_model_kwargs() |
| assert forward_kwargs == [] |
| with pytest.raises( |
| ValueError, |
| match=re.escape( |
| "SparseEncoder.encode() has been called with additional keyword arguments that this model does " |
| "not use: ['normalize']. As per SparseEncoder.get_model_kwargs(), this model does not accept " |
| "any additional keyword arguments." |
| ), |
| ): |
| |
| model.encode("Test sentence", normalize=True) |
| |
| model.encode("Test sentence") |
| model.encode_query("Test sentence") |
|
|
| |
| model[0].forward_kwargs = {"foo"} |
| model[1].forward_kwargs = {"bar", "baz"} |
| assert set(model.get_model_kwargs()) == {"foo", "bar", "baz"} |
| with pytest.raises( |
| ValueError, |
| match=re.escape( |
| "SparseEncoder.encode() has been called with additional keyword arguments that this model does " |
| "not use: ['normalize']. As per SparseEncoder.get_model_kwargs(), the valid additional keyword" |
| " arguments are: " |
| ) |
| + r"\[('foo'|'bar'|'baz'|, ){5}\].", |
| ): |
| |
| model.encode("Test sentence", normalize=True) |
| |
| model.encode("Test sentence") |
| model.encode_query("Test sentence") |
| with pytest.raises( |
| TypeError, |
| match=r"(Transformer\.)?forward\(\) got an unexpected keyword argument '(foo|bar)'", |
| ): |
| |
| |
| model.encode("Test sentence", foo=True, bar=False) |
|
|
| |
| query_pooling_copy = copy.deepcopy(model[1]) |
| query_pooling_copy.forward_kwargs = {"query_arg"} |
| document_pooling_copy = copy.deepcopy(model[1]) |
| document_pooling_copy.forward_kwargs = {"document_arg_1", "document_arg_2"} |
| model[1] = Router.for_query_document( |
| query_modules=[query_pooling_copy], |
| document_modules=[document_pooling_copy], |
| ) |
| assert set(model.get_model_kwargs()) == { |
| "foo", |
| "task", |
| "query_arg", |
| "document_arg_1", |
| "document_arg_2", |
| "modality", |
| } |
| with pytest.raises( |
| ValueError, |
| match=re.escape( |
| "SparseEncoder.encode() has been called with additional keyword arguments that this model does " |
| "not use: ['normalize']. As per SparseEncoder.get_model_kwargs(), the valid additional keyword" |
| " arguments are: " |
| ) |
| + r"\[('foo'|'task'|'query_arg'|'document_arg_1'|'document_arg_2'|'modality'|, ){11}\].", |
| ): |
| |
| model.encode("Test sentence", task="query", normalize=True) |
| |
| model.encode("Test sentence", task="document") |
| model.encode_query("Test sentence") |
| with pytest.raises( |
| TypeError, |
| match=r"(Transformer\.)?forward\(\) got an unexpected keyword argument '(foo|document_arg_1)'", |
| ): |
| |
| |
| model.encode("Test sentence", task="document", foo=True, document_arg_1=12) |
|
|
|
|
| @pytest.mark.parametrize("similarity_fn_name", SimilarityFunction.possible_values()) |
| def test_similarity_score(splade_bert_tiny_model: SparseEncoder, similarity_fn_name: str) -> None: |
| model = splade_bert_tiny_model |
| model.similarity_fn_name = similarity_fn_name |
| sentences = [ |
| "The weather is so nice!", |
| "It's so sunny outside.", |
| "He's driving to the movie theater.", |
| "She's going to the cinema.", |
| ] |
| embeddings = model.encode(sentences, convert_to_sparse_tensor=False) |
| scores = model.similarity(embeddings, embeddings) |
| assert scores.shape == (len(sentences), len(sentences)) |
| diag = np.diag(scores.cpu().numpy()) |
| if similarity_fn_name == "cosine": |
| np.testing.assert_almost_equal(diag, np.ones(len(sentences), dtype=float), decimal=4) |
| elif similarity_fn_name in ("euclidean", "manhattan"): |
| np.testing.assert_almost_equal(diag, np.zeros(len(sentences), dtype=float), decimal=4) |
| else: |
| assert (diag > 0).all() |
|
|
| pairwise_scores = model.similarity_pairwise(embeddings[::2], embeddings[1::2]) |
| assert pairwise_scores.shape == (len(sentences) // 2,) |
|
|
|
|
| def test_similarity_score_save(splade_bert_tiny_model: SparseEncoder, tmp_path: Path) -> None: |
| model = splade_bert_tiny_model |
| assert model.similarity_fn_name == "dot" |
|
|
| model.similarity_fn_name = "cosine" |
| model.save(str(tmp_path)) |
| loaded_model = SparseEncoder(str(tmp_path)) |
| assert loaded_model.similarity_fn_name == "cosine" |
|
|
|
|
| def test_similarity_fn_name_set_via_enum(splade_bert_tiny_model: SparseEncoder) -> None: |
| model = splade_bert_tiny_model |
| model.similarity_fn_name = SimilarityFunction.COSINE |
| assert model.similarity_fn_name == "cosine" |
| model.similarity_fn_name = SimilarityFunction.DOT |
| assert model.similarity_fn_name == "dot" |
|
|
|
|
| def test_similarity_fn_name_constructor_overrides_saved(splade_bert_tiny_model: SparseEncoder, tmp_path: Path) -> None: |
| splade_bert_tiny_model.similarity_fn_name = "cosine" |
| splade_bert_tiny_model.save(str(tmp_path)) |
| model = SparseEncoder(str(tmp_path), similarity_fn_name="dot") |
| assert model.similarity_fn_name == "dot" |
|
|
|
|
| def test_prompts(splade_bert_tiny_model: SparseEncoder, caplog: pytest.LogCaptureFixture) -> None: |
| model = splade_bert_tiny_model |
| assert model.prompts == {"query": "", "document": ""} |
| assert model.default_prompt_name is None |
| texts = ["How to bake a chocolate cake", "Symptoms of the flu"] |
| no_prompt_embedding = model.encode(texts, convert_to_sparse_tensor=False, save_to_cpu=True) |
| prompt_embedding = model.encode( |
| [f"query: {text}" for text in texts], convert_to_sparse_tensor=False, save_to_cpu=True |
| ) |
| assert not np.array_equal(no_prompt_embedding, prompt_embedding) |
|
|
| query = "query: " |
| |
| model.prompts = {"query": "", "document": ""} |
| assert np.array_equal( |
| model.encode(texts, prompt=query, convert_to_sparse_tensor=False, save_to_cpu=True), prompt_embedding |
| ) |
|
|
| |
| model.prompts = {"query": query, "document": ""} |
| assert np.array_equal( |
| model.encode(texts, prompt_name="query", convert_to_sparse_tensor=False, save_to_cpu=True), prompt_embedding |
| ) |
|
|
| caplog.clear() |
| |
| with caplog.at_level(logging.WARNING): |
| assert np.array_equal( |
| model.encode(texts, prompt=query, prompt_name="query", convert_to_sparse_tensor=False, save_to_cpu=True), |
| prompt_embedding, |
| ) |
| assert len(caplog.record_tuples) == 1 |
| assert ( |
| caplog.record_tuples[0][2] == "Provide either a `prompt`, a `prompt_name`, or neither, but not both. " |
| "Ignoring the `prompt_name` in favor of `prompt`." |
| ) |
|
|
| with pytest.raises( |
| ValueError, |
| match=re.escape( |
| "Prompt name 'invalid_prompt_name' not found in the configured prompts dictionary with keys ['query', 'document']." |
| ), |
| ): |
| model.encode(texts, prompt_name="invalid_prompt_name") |
|
|
|
|
| def test_save_load_prompts() -> None: |
| with pytest.raises( |
| ValueError, |
| match=re.escape( |
| "Default prompt name 'invalid_prompt_name' not found in the configured prompts dictionary with keys ['query', 'document']." |
| ), |
| ): |
| SparseEncoder( |
| "sparse-encoder-testing/splade-bert-tiny-nq", |
| prompts={"query": "query: "}, |
| default_prompt_name="invalid_prompt_name", |
| ) |
|
|
| model = SparseEncoder( |
| "sparse-encoder-testing/splade-bert-tiny-nq", |
| prompts={"query": "query: "}, |
| default_prompt_name="query", |
| ) |
| assert model.prompts == {"query": "query: ", "document": ""} |
| assert model.default_prompt_name == "query" |
|
|
| with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmp_folder: |
| model_path = Path(tmp_folder) / "tiny_model_local" |
| model.save(str(model_path)) |
| config_path = model_path / "config_sentence_transformers.json" |
| assert config_path.exists() |
| with open(config_path, encoding="utf8") as f: |
| saved_config = json.load(f) |
| assert saved_config["prompts"] == {"query": "query: ", "document": ""} |
| assert saved_config["default_prompt_name"] == "query" |
|
|
| fresh_model = SparseEncoder(str(model_path)) |
| assert fresh_model.prompts == {"query": "query: ", "document": ""} |
| assert fresh_model.default_prompt_name == "query" |
|
|
|
|
| @pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]]) |
| @pytest.mark.parametrize("convert_to_tensor", [True, False]) |
| @pytest.mark.parametrize("convert_to_sparse_tensor", [True, False]) |
| @pytest.mark.parametrize("prompt_name", [None, "query", "custom"]) |
| @pytest.mark.parametrize("prompt", [None, "Custom prompt: "]) |
| def test_encode_query( |
| splade_bert_tiny_model: SparseEncoder, |
| sentences: str | list[str], |
| convert_to_tensor: bool, |
| convert_to_sparse_tensor: bool, |
| prompt_name: str | None, |
| prompt: str | None, |
| monkeypatch: pytest.MonkeyPatch, |
| ): |
| model = splade_bert_tiny_model |
| model.prompts = {"query": "query: ", "custom": "custom: "} |
|
|
| encode_calls = [] |
|
|
| def spy_encode(*args, **kwargs): |
| encode_calls.append((args, kwargs)) |
|
|
| monkeypatch.setattr(model, "encode", spy_encode) |
| model.encode_query( |
| sentences=sentences, |
| batch_size=32, |
| convert_to_tensor=convert_to_tensor, |
| convert_to_sparse_tensor=convert_to_sparse_tensor, |
| prompt_name=prompt_name, |
| prompt=prompt, |
| ) |
|
|
| if prompt_name: |
| expected_prompt_name = prompt_name |
| elif prompt is not None: |
| expected_prompt_name = None |
| else: |
| expected_prompt_name = "query" |
|
|
| assert len(encode_calls) == 1 |
| _, kwargs = encode_calls[0] |
|
|
| assert kwargs["inputs"] == sentences |
| assert kwargs["convert_to_tensor"] == convert_to_tensor |
| assert kwargs["convert_to_sparse_tensor"] == convert_to_sparse_tensor |
| assert kwargs["prompt"] == prompt |
| assert kwargs["prompt_name"] == expected_prompt_name |
| assert kwargs["task"] == "query" |
|
|
|
|
| @pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]]) |
| @pytest.mark.parametrize("convert_to_tensor", [True, False]) |
| @pytest.mark.parametrize("convert_to_sparse_tensor", [True, False]) |
| @pytest.mark.parametrize("prompt_name", [None, "document", "passage", "corpus", "custom"]) |
| @pytest.mark.parametrize("prompt", [None, "Custom prompt: "]) |
| def test_encode_document( |
| splade_bert_tiny_model: SparseEncoder, |
| sentences: str | list[str], |
| convert_to_tensor: bool, |
| convert_to_sparse_tensor: bool, |
| prompt_name: str | None, |
| prompt: str | None, |
| monkeypatch: pytest.MonkeyPatch, |
| ): |
| model = splade_bert_tiny_model |
| model.prompts = {"document": "document: ", "passage": "passage: ", "corpus": "corpus: ", "custom": "custom: "} |
|
|
| encode_calls = [] |
|
|
| def spy_encode(*args, **kwargs): |
| encode_calls.append((args, kwargs)) |
|
|
| monkeypatch.setattr(model, "encode", spy_encode) |
| model.encode_document( |
| sentences=sentences, |
| batch_size=32, |
| convert_to_tensor=convert_to_tensor, |
| convert_to_sparse_tensor=convert_to_sparse_tensor, |
| prompt_name=prompt_name, |
| prompt=prompt, |
| ) |
|
|
| assert len(encode_calls) == 1 |
| _, kwargs = encode_calls[0] |
|
|
| if prompt_name: |
| expected_prompt_name = prompt_name |
| elif prompt is not None: |
| expected_prompt_name = None |
| else: |
| expected_prompt_name = "document" |
|
|
| assert kwargs["inputs"] == sentences |
| assert kwargs["convert_to_tensor"] == convert_to_tensor |
| assert kwargs["convert_to_sparse_tensor"] == convert_to_sparse_tensor |
| assert kwargs["prompt"] == prompt |
| assert kwargs["prompt_name"] == expected_prompt_name |
| assert kwargs["task"] == "document" |
|
|
|
|
| def test_encode_document_prompt_priority(splade_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch): |
| """Test that proper prompt priority is respected when multiple options are available""" |
| model = splade_bert_tiny_model |
| model.prompts = { |
| "document": "document: ", |
| "passage": "passage: ", |
| "corpus": "corpus: ", |
| } |
|
|
| encode_calls = [] |
|
|
| def spy_encode(*args, **kwargs): |
| encode_calls.append((args, kwargs)) |
|
|
| monkeypatch.setattr(model, "encode", spy_encode) |
|
|
| model.encode_document("test") |
| _, kwargs = encode_calls[-1] |
| assert kwargs["prompt_name"] == "document" |
|
|
| |
| encode_calls.clear() |
| model.prompts = {"passage": "passage: ", "corpus": "corpus: "} |
| model.encode_document("test") |
| _, kwargs = encode_calls[-1] |
| assert kwargs["prompt_name"] == "passage" |
|
|
| |
| encode_calls.clear() |
| model.prompts = {"corpus": "corpus: "} |
| model.encode_document("test") |
| _, kwargs = encode_calls[-1] |
| assert kwargs["prompt_name"] == "corpus" |
|
|
| |
| encode_calls.clear() |
| model.prompts = {"custom": "custom: "} |
| model.encode_document("test") |
| _, kwargs = encode_calls[-1] |
| assert kwargs["prompt_name"] is None |
|
|