from __future__ import annotations import copy import json import logging import re import tempfile from pathlib import Path import numpy as np import pytest import torch from sentence_transformers.sentence_transformer.modules import Pooling from sentence_transformers.sparse_encoder.model import SparseEncoder from sentence_transformers.sparse_encoder.modules import Router, SparseAutoEncoder, SpladePooling, Transformer from sentence_transformers.util.similarity import SimilarityFunction @pytest.mark.parametrize( ("texts", "top_k", "expected_shape"), [ # Single text, default top_k (None) (["The weather is nice!"], None, 1), # Single text, specific top_k (["The weather is nice!"], 3, 1), # String text, specific top_k, expect a non-nested list ("The weather is nice!", 8, 8), # Multiple texts, default top_k (None) (["The weather is nice!", "It's sunny outside"], None, 2), # Multiple texts, specific top_k (["The weather is nice!", "It's sunny outside"], 3, 2), ], ) def test_decode_shapes( splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, top_k: int, expected_shape: int ) -> None: model = splade_bert_tiny_model embeddings = model.encode(texts) decoded = model.decode(embeddings, top_k=top_k) assert len(decoded) == expected_shape if isinstance(texts, list): if len(texts) == 1: assert isinstance(decoded[0], tuple) or isinstance(decoded, list) if top_k is not None: assert len(decoded) <= top_k else: assert isinstance(decoded, list) for item in decoded: assert isinstance(item, list) if top_k is not None: assert len(item) <= top_k @pytest.mark.parametrize( ("text", "expected_token_types"), [ ("The weather is nice!", str), ("It's sunny outside", str), ], ) def test_decode_token_types(splade_bert_tiny_model: SparseEncoder, text: str, expected_token_types: type) -> None: model = splade_bert_tiny_model embeddings = model.encode(text) decoded = model.decode(embeddings) # Check the first item in the batch for token, weight in decoded: assert isinstance(token, expected_token_types) assert isinstance(weight, float) @pytest.mark.parametrize( ("text", "top_k"), [ ("The weather is nice!", 1), ("It's sunny outside", 3), ("Hello world", 5), ], ) def test_decode_top_k_respects_limit(splade_bert_tiny_model: SparseEncoder, text: str, top_k: int) -> None: model = splade_bert_tiny_model embeddings = model.encode([text]) decoded = model.decode(embeddings, top_k=top_k) assert len(decoded) <= top_k @pytest.mark.parametrize( ("texts", "format_type"), [ ("The weather is nice!", "1d"), (["The weather is nice!"], "1d"), (["The weather is nice!", "It's sunny outside"], "2d"), ], ) def test_decode_handles_sparse_dense_inputs( splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, format_type: str ): model = splade_bert_tiny_model # Get embeddings and test both sparse and dense format handling embeddings = model.encode(texts) # Test with sparse tensor if not embeddings.is_sparse: embeddings_sparse = embeddings.to_sparse() else: embeddings_sparse = embeddings decoded_sparse = model.decode(embeddings_sparse) # Test with dense tensor if embeddings.is_sparse: embeddings_dense = embeddings.to_dense() else: embeddings_dense = embeddings decoded_dense = model.decode(embeddings_dense) # Verify both produce the same result structure if format_type == "1d": assert len(decoded_sparse) == len(decoded_dense) else: assert len(decoded_sparse) == len(decoded_dense) for i in range(len(decoded_sparse)): # Sort both results to ensure consistent comparison sorted_sparse = sorted(decoded_sparse[i], key=lambda x: (x[1], x[0]), reverse=True) sorted_dense = sorted(decoded_dense[i], key=lambda x: (x[1], x[0]), reverse=True) assert len(sorted_sparse) == len(sorted_dense) def test_decode_empty_tensor(splade_bert_tiny_model: SparseEncoder) -> None: model = splade_bert_tiny_model # Create an empty sparse tensor empty_sparse = torch.sparse_coo_tensor( indices=torch.zeros((2, 0), dtype=torch.long), values=torch.zeros((0,), dtype=torch.float), size=(1, model.get_embedding_dimension()), ) decoded = model.decode(empty_sparse) assert len(decoded) == 0 or (isinstance(decoded, list) and all(not item for item in decoded)) @pytest.mark.parametrize("top_k", [0, -1, -5]) def test_decode_invalid_top_k(splade_bert_tiny_model: SparseEncoder, top_k: int) -> None: model = splade_bert_tiny_model embeddings = model.encode("Hello world") with pytest.raises(ValueError, match="top_k must be a positive integer"): model.decode(embeddings, top_k=top_k) def test_decode_invalid_input_type(splade_bert_tiny_model: SparseEncoder) -> None: model = splade_bert_tiny_model with pytest.raises(TypeError, match="Expected torch.Tensor"): model.decode([1, 2, 3]) def test_decode_invalid_ndim(splade_bert_tiny_model: SparseEncoder) -> None: model = splade_bert_tiny_model tensor_3d = torch.zeros(2, 3, 4) with pytest.raises(ValueError, match="Input tensor must be 1D or 2D"): model.decode(tensor_3d) def test_decode_batch_with_empty_sample(splade_bert_tiny_model: SparseEncoder) -> None: model = splade_bert_tiny_model vocab_size = model.get_embedding_dimension() # Create a batch where the first sample has values but the second is all zeros indices = torch.tensor([[0, 0], [1, 5]]) # both non-zero entries in sample 0 values = torch.tensor([1.0, 2.0]) batch_sparse = torch.sparse_coo_tensor(indices, values, size=(2, vocab_size)) decoded = model.decode(batch_sparse) assert len(decoded) == 2 assert len(decoded[0]) == 2 # sample 0 has 2 non-zero entries assert decoded[1] == [] # sample 1 is empty @pytest.mark.parametrize("top_k", [None, 5, 1000]) @pytest.mark.parametrize( "texts", [ ("The weather is nice!"), (["The weather is nice!"]), (["The weather is nice!", "It's sunny outside", "Hello world"]), (["Short text", "This is a longer text with more words to encode"]), ], ) def test_decode_returns_sorted_weights( splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, top_k: int | None ) -> None: model = splade_bert_tiny_model embeddings = model.encode(texts) decoded = model.decode(embeddings, top_k=top_k) if isinstance(texts, list): for item in decoded: weights = [weight for _, weight in item] assert all(weights[i] >= weights[i + 1] for i in range(len(weights) - 1)) else: weights = [weight for _, weight in decoded] assert all(weights[i] >= weights[i + 1] for i in range(len(weights) - 1)) def test_inference_free_splade(inference_free_splade_bert_tiny_model: SparseEncoder): model = inference_free_splade_bert_tiny_model dimensionality = model.get_embedding_dimension() query = "What is the capital of France?" document = "The capital of France is Paris." query_embeddings = model.encode_query(query) document_embeddings = model.encode_document(document) assert query_embeddings.shape == (dimensionality,) assert document_embeddings.shape == (dimensionality,) decoded_query = model.decode(query_embeddings) decoded_document = model.decode(document_embeddings) assert len(decoded_query) == len(model.preprocess(query, task="query")["input_ids"][0]) assert len(decoded_document) >= 50 assert model.max_seq_length == 512 assert model[0].sub_modules["query"][0].max_seq_length == 512 assert model[0].sub_modules["document"][0].max_seq_length == 512 model.max_seq_length = 256 assert model.max_seq_length == 256 assert model[0].sub_modules["query"][0].max_seq_length == 256 assert model[0].sub_modules["document"][0].max_seq_length == 256 def test_inference_free_splade_max_active_dims_routing(inference_free_splade_bert_tiny_model: SparseEncoder): model = inference_free_splade_bert_tiny_model query = "What is the capital of France?" document = "The capital of France is Paris." # Encode without max_active_dims — baseline query_emb = model.encode_query(query) doc_emb = model.encode_document(document) # Encode with max_active_dims — should route to the same sub-modules query_emb_mad = model.encode_query(query, max_active_dims=50) doc_emb_mad = model.encode_document(document, max_active_dims=50) # The non-zero indices of the max_active_dims result should be a subset of the baseline query_baseline_indices = query_emb.coalesce().indices()[0] query_mad_indices = query_emb_mad.coalesce().indices()[0] assert set(query_mad_indices.tolist()).issubset(set(query_baseline_indices.tolist())) assert query_emb_mad._nnz() <= 50 doc_baseline_indices = doc_emb.coalesce().indices()[0] doc_mad_indices = doc_emb_mad.coalesce().indices()[0] assert set(doc_mad_indices.tolist()).issubset(set(doc_baseline_indices.tolist())) assert doc_emb_mad._nnz() <= 50 def test_encode_advanced_parameters(splade_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch): """Test that additional parameters are correctly passed to encode""" model = splade_bert_tiny_model encode_calls = [] def spy_encode(*args, **kwargs): encode_calls.append((args, kwargs)) monkeypatch.setattr(model, "encode", spy_encode) # Call with advanced parameters model.encode_query( "test", normalize_embeddings=True, batch_size=64, show_progress_bar=True, max_active_dims=128, chunk_size=10, custom_param="value", ) # Verify all parameters were passed correctly _, kwargs = encode_calls[0] assert kwargs["normalize_embeddings"] is True assert kwargs["batch_size"] == 64 assert kwargs["show_progress_bar"] is True assert kwargs["max_active_dims"] == 128 assert kwargs["chunk_size"] == 10 assert kwargs["custom_param"] == "value" def test_csr_max_active_dims_passed_to_forward(csr_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch): model = csr_bert_tiny_model assert isinstance(model[-1], SparseAutoEncoder) assert model[-1].k == 16 # Verify that max_active_dims is passed to SparseAutoEncoder.forward() forward_calls = [] original_forward = model[-1].forward def spy_forward(*args, **kwargs): forward_calls.append(kwargs) return original_forward(*args, **kwargs) monkeypatch.setattr(model[-1], "forward", spy_forward) model.encode("Hello world", max_active_dims=5) assert len(forward_calls) == 1 assert forward_calls[0]["max_active_dims"] == 5 # Without max_active_dims, the model's default max_active_dims is used forward_calls.clear() model.encode("Hello world") assert len(forward_calls) == 1 assert forward_calls[0]["max_active_dims"] == model.max_active_dims def test_max_active_dims_set_init(splade_bert_tiny_model: SparseEncoder, csr_bert_tiny_model: SparseEncoder, tmp_path): splade_bert_tiny_model.save_pretrained(str(tmp_path / "splade_bert_tiny")) csr_bert_tiny_model.save_pretrained(str(tmp_path / "csr_bert_tiny")) # Load the models with max_active_dims set loaded_model = SparseEncoder(str(tmp_path / "splade_bert_tiny")) assert loaded_model.max_active_dims is None loaded_model = SparseEncoder(str(tmp_path / "splade_bert_tiny"), max_active_dims=13) assert loaded_model.max_active_dims == 13 loaded_model = SparseEncoder(str(tmp_path / "csr_bert_tiny")) assert loaded_model.max_active_dims == 16 # Based on the SparseAutoEncoder's k value loaded_model = SparseEncoder(str(tmp_path / "csr_bert_tiny"), max_active_dims=13) assert loaded_model.max_active_dims == 13 def test_detect_mlm(): model = SparseEncoder("distilbert/distilbert-base-uncased") assert isinstance(model[0], Transformer) assert model[0].transformer_task == "fill-mask" assert isinstance(model[1], SpladePooling) def test_default_to_csr(): # NOTE: bert-tiny is actually MLM-based, but the config isn't modern enough to allow us to detect it, # so we should default to CSR here. model = SparseEncoder("sentence-transformers-testing/stsb-bert-tiny-safetensors") assert isinstance(model[0], Transformer) assert isinstance(model[1], Pooling) assert isinstance(model[2], SparseAutoEncoder) def test_sparsity(splade_bert_tiny_model: SparseEncoder): model = splade_bert_tiny_model # Check that the sparsity is applied correctly embeddings = model.encode_query(["What is the capital of France?", "Who has won the World Cup in 2016?"]) sparsity = model.sparsity(embeddings) assert isinstance(sparsity, dict) assert "active_dims" in sparsity assert "sparsity_ratio" in sparsity assert sparsity["active_dims"] < 100 and sparsity["active_dims"] > 0 assert sparsity["sparsity_ratio"] < 1.0 and sparsity["sparsity_ratio"] >= 0.99 # Also check with dense tensors dense_sparsity = model.sparsity(embeddings.to_dense()) assert dense_sparsity == sparsity, "Sparsity should be the same for dense and sparse tensors" # Check that 1-dimensional embeddings work correctly sparsity_one = model.sparsity(embeddings[0]) sparsity_two = model.sparsity(embeddings[1]) assert (sparsity_one["active_dims"] + sparsity_two["active_dims"]) / 2 == sparsity["active_dims"] def test_splade_pooling_chunk_size(splade_bert_tiny_model: SparseEncoder): model = splade_bert_tiny_model # The chunk size defaults to None, i.e. no chunking assert model.splade_pooling_chunk_size is None # But we can chunk the pooling to save memory at the cost of some speed model.splade_pooling_chunk_size = 13 assert model.splade_pooling_chunk_size == 13 assert isinstance(model[1], SpladePooling) assert model[1].chunk_size == 13 def test_intersection(splade_bert_tiny_model: SparseEncoder): model = splade_bert_tiny_model # Test intersection with a single text query = "Where can I deposit my money?" document = "I'm sitting by the river." query_embeddings = model.encode_query(query) document_embeddings = model.encode_document(document) query_sparsity = model.sparsity(query_embeddings) document_sparsity = model.sparsity(document_embeddings) # Let's check that the intersection is a tensor and has the correct shape intersection = model.intersection(query_embeddings, document_embeddings) assert isinstance(intersection, torch.Tensor) assert intersection.shape == (model.get_embedding_dimension(),) # Check that the intersection sparsity is less than both query and document sparsities intersection_sparsity = model.sparsity(intersection) assert ( intersection_sparsity["active_dims"] < query_sparsity["active_dims"] and intersection_sparsity["active_dims"] < document_sparsity["active_dims"] ) # Test with multiple texts query = "Who has won the World Cup in 2016?" documents = ["The capital of France is Paris.", "Germany won the World Cup in 2014."] query_embeddings = model.encode_query(query) document_embeddings = model.encode_document(documents) intersection_batch = model.intersection(query_embeddings, document_embeddings) assert isinstance(intersection_batch, torch.Tensor) assert intersection_batch.shape == (len(documents), model.get_embedding_dimension()) decoded_intersection_batch = model.decode(intersection_batch) assert len(decoded_intersection_batch) == len(documents) def test_encode_with_dataset_column(splade_bert_tiny_model: SparseEncoder) -> None: """Test that encode can handle a dataset column as input.""" model = splade_bert_tiny_model from datasets import Dataset # Create a simple dataset with a text column dataset = Dataset.from_dict({"text": ["This is a test.", "Another sentence."]}) # Encode the dataset column embeddings = model.encode(dataset["text"], convert_to_tensor=True) # Check the shape of the embeddings assert embeddings.shape == (2, model.get_embedding_dimension()) def test_encode_numpy_1d_string_array(splade_bert_tiny_model: SparseEncoder) -> None: """Regression test for #3718: encoding a 1D numpy string array should produce one embedding per element.""" model = splade_bert_tiny_model texts = np.array(["Access Management", "Press Coordination", "Financial Reports"]) embeddings = model.encode(texts, convert_to_tensor=True, save_to_cpu=True) expected = model.encode(texts.tolist(), convert_to_tensor=True, save_to_cpu=True) assert embeddings.shape == (3, model.get_embedding_dimension()) assert torch.allclose(embeddings.to_dense(), expected.to_dense()) def test_encode_numpy_2d_string_array(splade_bert_tiny_model: SparseEncoder) -> None: """Encoding a 2D numpy string array should match encoding the equivalent nested list.""" model = splade_bert_tiny_model pairs = np.array([["what is AI?", "AI is artificial intelligence."], ["what is ML?", "ML is machine learning."]]) embeddings = model.encode(pairs, convert_to_tensor=True, save_to_cpu=True) expected = model.encode(pairs.tolist(), convert_to_tensor=True, save_to_cpu=True) assert embeddings.shape == (2, model.get_embedding_dimension()) assert torch.allclose(embeddings.to_dense(), expected.to_dense()) def test_encode_numpy_empty(splade_bert_tiny_model: SparseEncoder) -> None: """Encoding an empty string ndarray should return an empty tensor, like ``encode([])``.""" model = splade_bert_tiny_model embeddings = model.encode(np.array([], dtype=str), convert_to_tensor=True, save_to_cpu=True) expected = model.encode([], convert_to_tensor=True, save_to_cpu=True) assert embeddings.numel() == 0 assert torch.equal(embeddings.to_dense(), expected.to_dense()) @pytest.mark.parametrize("convert_to_tensor", [True, False]) @pytest.mark.parametrize("convert_to_sparse_tensor", [True, False]) @pytest.mark.parametrize("save_to_cpu", [True, False]) @pytest.mark.parametrize("max_active_dims", [None, 64, 128]) def test_empty_encode( splade_bert_tiny_model: SparseEncoder, convert_to_tensor: bool, convert_to_sparse_tensor: bool, save_to_cpu: bool, max_active_dims: int | None, ): model = splade_bert_tiny_model embeddings = model.encode( [], convert_to_tensor=convert_to_tensor, convert_to_sparse_tensor=convert_to_sparse_tensor, save_to_cpu=save_to_cpu, max_active_dims=max_active_dims, ) if convert_to_tensor: assert isinstance(embeddings, torch.Tensor) assert embeddings.numel() == 0 if save_to_cpu: assert embeddings.device == torch.device("cpu") else: assert embeddings.device == model.device if convert_to_sparse_tensor: assert embeddings.is_sparse else: assert not embeddings.is_sparse else: assert embeddings == [] def test_get_model_kwargs(splade_bert_tiny_model: SparseEncoder) -> None: """Test that get_model_kwargs returns the correct keyword arguments.""" model = splade_bert_tiny_model # Check that the forward kwargs are as expected, i.e. no extra forward kwargs # for this basic model forward_kwargs = model.get_model_kwargs() assert forward_kwargs == [] with pytest.raises( ValueError, match=re.escape( "SparseEncoder.encode() has been called with additional keyword arguments that this model does " "not use: ['normalize']. As per SparseEncoder.get_model_kwargs(), this model does not accept " "any additional keyword arguments." ), ): # There is no "normalize" argument, this should crash model.encode("Test sentence", normalize=True) # This should run fine model.encode("Test sentence") model.encode_query("Test sentence") # If one of the modules has additional forward kwargs, they should be included model[0].forward_kwargs = {"foo"} model[1].forward_kwargs = {"bar", "baz"} assert set(model.get_model_kwargs()) == {"foo", "bar", "baz"} with pytest.raises( ValueError, match=re.escape( "SparseEncoder.encode() has been called with additional keyword arguments that this model does " "not use: ['normalize']. As per SparseEncoder.get_model_kwargs(), the valid additional keyword" " arguments are: " ) + r"\[('foo'|'bar'|'baz'|, ){5}\].", ): # There is no "normalize" argument, this should crash model.encode("Test sentence", normalize=True) # This should run fine model.encode("Test sentence") model.encode_query("Test sentence") with pytest.raises( TypeError, match=r"(Transformer\.)?forward\(\) got an unexpected keyword argument '(foo|bar)'", ): # This would run fine, except the model can't actually accept these arguments (we monkeypatched the modules' # forward_kwargs for this test, after all). The model does send the args down to the underlying modules, though! model.encode("Test sentence", foo=True, bar=False) # And also if we have a Router in place query_pooling_copy = copy.deepcopy(model[1]) query_pooling_copy.forward_kwargs = {"query_arg"} document_pooling_copy = copy.deepcopy(model[1]) document_pooling_copy.forward_kwargs = {"document_arg_1", "document_arg_2"} model[1] = Router.for_query_document( query_modules=[query_pooling_copy], document_modules=[document_pooling_copy], ) assert set(model.get_model_kwargs()) == { "foo", "task", "query_arg", "document_arg_1", "document_arg_2", "modality", } with pytest.raises( ValueError, match=re.escape( "SparseEncoder.encode() has been called with additional keyword arguments that this model does " "not use: ['normalize']. As per SparseEncoder.get_model_kwargs(), the valid additional keyword" " arguments are: " ) + r"\[('foo'|'task'|'query_arg'|'document_arg_1'|'document_arg_2'|'modality'|, ){11}\].", ): # There is no "normalize" argument, this should crash model.encode("Test sentence", task="query", normalize=True) # This should run fine model.encode("Test sentence", task="document") model.encode_query("Test sentence") with pytest.raises( TypeError, match=r"(Transformer\.)?forward\(\) got an unexpected keyword argument '(foo|document_arg_1)'", ): # This would run fine, except the model can't actually accept these arguments (we monkeypatched the modules' # forward_kwargs for this test, after all). The model does send the args down to the underlying modules, though! model.encode("Test sentence", task="document", foo=True, document_arg_1=12) @pytest.mark.parametrize("similarity_fn_name", SimilarityFunction.possible_values()) def test_similarity_score(splade_bert_tiny_model: SparseEncoder, similarity_fn_name: str) -> None: model = splade_bert_tiny_model model.similarity_fn_name = similarity_fn_name sentences = [ "The weather is so nice!", "It's so sunny outside.", "He's driving to the movie theater.", "She's going to the cinema.", ] embeddings = model.encode(sentences, convert_to_sparse_tensor=False) scores = model.similarity(embeddings, embeddings) assert scores.shape == (len(sentences), len(sentences)) diag = np.diag(scores.cpu().numpy()) if similarity_fn_name == "cosine": np.testing.assert_almost_equal(diag, np.ones(len(sentences), dtype=float), decimal=4) elif similarity_fn_name in ("euclidean", "manhattan"): np.testing.assert_almost_equal(diag, np.zeros(len(sentences), dtype=float), decimal=4) else: # dot product - self-similarity of non-zero sparse vectors is positive assert (diag > 0).all() pairwise_scores = model.similarity_pairwise(embeddings[::2], embeddings[1::2]) assert pairwise_scores.shape == (len(sentences) // 2,) def test_similarity_score_save(splade_bert_tiny_model: SparseEncoder, tmp_path: Path) -> None: model = splade_bert_tiny_model assert model.similarity_fn_name == "dot" model.similarity_fn_name = "cosine" model.save(str(tmp_path)) loaded_model = SparseEncoder(str(tmp_path)) assert loaded_model.similarity_fn_name == "cosine" def test_similarity_fn_name_set_via_enum(splade_bert_tiny_model: SparseEncoder) -> None: model = splade_bert_tiny_model model.similarity_fn_name = SimilarityFunction.COSINE assert model.similarity_fn_name == "cosine" model.similarity_fn_name = SimilarityFunction.DOT assert model.similarity_fn_name == "dot" def test_similarity_fn_name_constructor_overrides_saved(splade_bert_tiny_model: SparseEncoder, tmp_path: Path) -> None: splade_bert_tiny_model.similarity_fn_name = "cosine" splade_bert_tiny_model.save(str(tmp_path)) model = SparseEncoder(str(tmp_path), similarity_fn_name="dot") assert model.similarity_fn_name == "dot" def test_prompts(splade_bert_tiny_model: SparseEncoder, caplog: pytest.LogCaptureFixture) -> None: model = splade_bert_tiny_model assert model.prompts == {"query": "", "document": ""} assert model.default_prompt_name is None texts = ["How to bake a chocolate cake", "Symptoms of the flu"] no_prompt_embedding = model.encode(texts, convert_to_sparse_tensor=False, save_to_cpu=True) prompt_embedding = model.encode( [f"query: {text}" for text in texts], convert_to_sparse_tensor=False, save_to_cpu=True ) assert not np.array_equal(no_prompt_embedding, prompt_embedding) query = "query: " # Test prompt="query: " model.prompts = {"query": "", "document": ""} assert np.array_equal( model.encode(texts, prompt=query, convert_to_sparse_tensor=False, save_to_cpu=True), prompt_embedding ) # Test prompt_name="..." model.prompts = {"query": query, "document": ""} assert np.array_equal( model.encode(texts, prompt_name="query", convert_to_sparse_tensor=False, save_to_cpu=True), prompt_embedding ) caplog.clear() # Test prompt_name="..." & prompt="..." with caplog.at_level(logging.WARNING): assert np.array_equal( model.encode(texts, prompt=query, prompt_name="query", convert_to_sparse_tensor=False, save_to_cpu=True), prompt_embedding, ) assert len(caplog.record_tuples) == 1 assert ( caplog.record_tuples[0][2] == "Provide either a `prompt`, a `prompt_name`, or neither, but not both. " "Ignoring the `prompt_name` in favor of `prompt`." ) with pytest.raises( ValueError, match=re.escape( "Prompt name 'invalid_prompt_name' not found in the configured prompts dictionary with keys ['query', 'document']." ), ): model.encode(texts, prompt_name="invalid_prompt_name") def test_save_load_prompts() -> None: with pytest.raises( ValueError, match=re.escape( "Default prompt name 'invalid_prompt_name' not found in the configured prompts dictionary with keys ['query', 'document']." ), ): SparseEncoder( "sparse-encoder-testing/splade-bert-tiny-nq", prompts={"query": "query: "}, default_prompt_name="invalid_prompt_name", ) model = SparseEncoder( "sparse-encoder-testing/splade-bert-tiny-nq", prompts={"query": "query: "}, default_prompt_name="query", ) assert model.prompts == {"query": "query: ", "document": ""} assert model.default_prompt_name == "query" with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmp_folder: model_path = Path(tmp_folder) / "tiny_model_local" model.save(str(model_path)) config_path = model_path / "config_sentence_transformers.json" assert config_path.exists() with open(config_path, encoding="utf8") as f: saved_config = json.load(f) assert saved_config["prompts"] == {"query": "query: ", "document": ""} assert saved_config["default_prompt_name"] == "query" fresh_model = SparseEncoder(str(model_path)) assert fresh_model.prompts == {"query": "query: ", "document": ""} assert fresh_model.default_prompt_name == "query" @pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]]) @pytest.mark.parametrize("convert_to_tensor", [True, False]) @pytest.mark.parametrize("convert_to_sparse_tensor", [True, False]) @pytest.mark.parametrize("prompt_name", [None, "query", "custom"]) @pytest.mark.parametrize("prompt", [None, "Custom prompt: "]) def test_encode_query( splade_bert_tiny_model: SparseEncoder, sentences: str | list[str], convert_to_tensor: bool, convert_to_sparse_tensor: bool, prompt_name: str | None, prompt: str | None, monkeypatch: pytest.MonkeyPatch, ): model = splade_bert_tiny_model model.prompts = {"query": "query: ", "custom": "custom: "} encode_calls = [] def spy_encode(*args, **kwargs): encode_calls.append((args, kwargs)) monkeypatch.setattr(model, "encode", spy_encode) model.encode_query( sentences=sentences, batch_size=32, convert_to_tensor=convert_to_tensor, convert_to_sparse_tensor=convert_to_sparse_tensor, prompt_name=prompt_name, prompt=prompt, ) if prompt_name: expected_prompt_name = prompt_name elif prompt is not None: expected_prompt_name = None else: expected_prompt_name = "query" assert len(encode_calls) == 1 _, kwargs = encode_calls[0] assert kwargs["inputs"] == sentences assert kwargs["convert_to_tensor"] == convert_to_tensor assert kwargs["convert_to_sparse_tensor"] == convert_to_sparse_tensor assert kwargs["prompt"] == prompt assert kwargs["prompt_name"] == expected_prompt_name assert kwargs["task"] == "query" @pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]]) @pytest.mark.parametrize("convert_to_tensor", [True, False]) @pytest.mark.parametrize("convert_to_sparse_tensor", [True, False]) @pytest.mark.parametrize("prompt_name", [None, "document", "passage", "corpus", "custom"]) @pytest.mark.parametrize("prompt", [None, "Custom prompt: "]) def test_encode_document( splade_bert_tiny_model: SparseEncoder, sentences: str | list[str], convert_to_tensor: bool, convert_to_sparse_tensor: bool, prompt_name: str | None, prompt: str | None, monkeypatch: pytest.MonkeyPatch, ): model = splade_bert_tiny_model model.prompts = {"document": "document: ", "passage": "passage: ", "corpus": "corpus: ", "custom": "custom: "} encode_calls = [] def spy_encode(*args, **kwargs): encode_calls.append((args, kwargs)) monkeypatch.setattr(model, "encode", spy_encode) model.encode_document( sentences=sentences, batch_size=32, convert_to_tensor=convert_to_tensor, convert_to_sparse_tensor=convert_to_sparse_tensor, prompt_name=prompt_name, prompt=prompt, ) assert len(encode_calls) == 1 _, kwargs = encode_calls[0] if prompt_name: expected_prompt_name = prompt_name elif prompt is not None: expected_prompt_name = None else: expected_prompt_name = "document" assert kwargs["inputs"] == sentences assert kwargs["convert_to_tensor"] == convert_to_tensor assert kwargs["convert_to_sparse_tensor"] == convert_to_sparse_tensor assert kwargs["prompt"] == prompt assert kwargs["prompt_name"] == expected_prompt_name assert kwargs["task"] == "document" def test_encode_document_prompt_priority(splade_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch): """Test that proper prompt priority is respected when multiple options are available""" model = splade_bert_tiny_model model.prompts = { "document": "document: ", "passage": "passage: ", "corpus": "corpus: ", } encode_calls = [] def spy_encode(*args, **kwargs): encode_calls.append((args, kwargs)) monkeypatch.setattr(model, "encode", spy_encode) model.encode_document("test") _, kwargs = encode_calls[-1] assert kwargs["prompt_name"] == "document" # Remove document, should fall back to passage encode_calls.clear() model.prompts = {"passage": "passage: ", "corpus": "corpus: "} model.encode_document("test") _, kwargs = encode_calls[-1] assert kwargs["prompt_name"] == "passage" # Remove passage, should fall back to corpus encode_calls.clear() model.prompts = {"corpus": "corpus: "} model.encode_document("test") _, kwargs = encode_calls[-1] assert kwargs["prompt_name"] == "corpus" # No document/passage/corpus, should use None encode_calls.clear() model.prompts = {"custom": "custom: "} model.encode_document("test") _, kwargs = encode_calls[-1] assert kwargs["prompt_name"] is None