Serkan007's picture
Sentence-Transformers ve E5-Large model aktarımı.
9bbba62 verified
from __future__ import annotations
import copy
import json
import logging
import re
import tempfile
from pathlib import Path
import numpy as np
import pytest
import torch
from sentence_transformers.sentence_transformer.modules import Pooling
from sentence_transformers.sparse_encoder.model import SparseEncoder
from sentence_transformers.sparse_encoder.modules import Router, SparseAutoEncoder, SpladePooling, Transformer
from sentence_transformers.util.similarity import SimilarityFunction
@pytest.mark.parametrize(
("texts", "top_k", "expected_shape"),
[
# Single text, default top_k (None)
(["The weather is nice!"], None, 1),
# Single text, specific top_k
(["The weather is nice!"], 3, 1),
# String text, specific top_k, expect a non-nested list
("The weather is nice!", 8, 8),
# Multiple texts, default top_k (None)
(["The weather is nice!", "It's sunny outside"], None, 2),
# Multiple texts, specific top_k
(["The weather is nice!", "It's sunny outside"], 3, 2),
],
)
def test_decode_shapes(
splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, top_k: int, expected_shape: int
) -> None:
model = splade_bert_tiny_model
embeddings = model.encode(texts)
decoded = model.decode(embeddings, top_k=top_k)
assert len(decoded) == expected_shape
if isinstance(texts, list):
if len(texts) == 1:
assert isinstance(decoded[0], tuple) or isinstance(decoded, list)
if top_k is not None:
assert len(decoded) <= top_k
else:
assert isinstance(decoded, list)
for item in decoded:
assert isinstance(item, list)
if top_k is not None:
assert len(item) <= top_k
@pytest.mark.parametrize(
("text", "expected_token_types"),
[
("The weather is nice!", str),
("It's sunny outside", str),
],
)
def test_decode_token_types(splade_bert_tiny_model: SparseEncoder, text: str, expected_token_types: type) -> None:
model = splade_bert_tiny_model
embeddings = model.encode(text)
decoded = model.decode(embeddings)
# Check the first item in the batch
for token, weight in decoded:
assert isinstance(token, expected_token_types)
assert isinstance(weight, float)
@pytest.mark.parametrize(
("text", "top_k"),
[
("The weather is nice!", 1),
("It's sunny outside", 3),
("Hello world", 5),
],
)
def test_decode_top_k_respects_limit(splade_bert_tiny_model: SparseEncoder, text: str, top_k: int) -> None:
model = splade_bert_tiny_model
embeddings = model.encode([text])
decoded = model.decode(embeddings, top_k=top_k)
assert len(decoded) <= top_k
@pytest.mark.parametrize(
("texts", "format_type"),
[
("The weather is nice!", "1d"),
(["The weather is nice!"], "1d"),
(["The weather is nice!", "It's sunny outside"], "2d"),
],
)
def test_decode_handles_sparse_dense_inputs(
splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, format_type: str
):
model = splade_bert_tiny_model
# Get embeddings and test both sparse and dense format handling
embeddings = model.encode(texts)
# Test with sparse tensor
if not embeddings.is_sparse:
embeddings_sparse = embeddings.to_sparse()
else:
embeddings_sparse = embeddings
decoded_sparse = model.decode(embeddings_sparse)
# Test with dense tensor
if embeddings.is_sparse:
embeddings_dense = embeddings.to_dense()
else:
embeddings_dense = embeddings
decoded_dense = model.decode(embeddings_dense)
# Verify both produce the same result structure
if format_type == "1d":
assert len(decoded_sparse) == len(decoded_dense)
else:
assert len(decoded_sparse) == len(decoded_dense)
for i in range(len(decoded_sparse)):
# Sort both results to ensure consistent comparison
sorted_sparse = sorted(decoded_sparse[i], key=lambda x: (x[1], x[0]), reverse=True)
sorted_dense = sorted(decoded_dense[i], key=lambda x: (x[1], x[0]), reverse=True)
assert len(sorted_sparse) == len(sorted_dense)
def test_decode_empty_tensor(splade_bert_tiny_model: SparseEncoder) -> None:
model = splade_bert_tiny_model
# Create an empty sparse tensor
empty_sparse = torch.sparse_coo_tensor(
indices=torch.zeros((2, 0), dtype=torch.long),
values=torch.zeros((0,), dtype=torch.float),
size=(1, model.get_embedding_dimension()),
)
decoded = model.decode(empty_sparse)
assert len(decoded) == 0 or (isinstance(decoded, list) and all(not item for item in decoded))
@pytest.mark.parametrize("top_k", [0, -1, -5])
def test_decode_invalid_top_k(splade_bert_tiny_model: SparseEncoder, top_k: int) -> None:
model = splade_bert_tiny_model
embeddings = model.encode("Hello world")
with pytest.raises(ValueError, match="top_k must be a positive integer"):
model.decode(embeddings, top_k=top_k)
def test_decode_invalid_input_type(splade_bert_tiny_model: SparseEncoder) -> None:
model = splade_bert_tiny_model
with pytest.raises(TypeError, match="Expected torch.Tensor"):
model.decode([1, 2, 3])
def test_decode_invalid_ndim(splade_bert_tiny_model: SparseEncoder) -> None:
model = splade_bert_tiny_model
tensor_3d = torch.zeros(2, 3, 4)
with pytest.raises(ValueError, match="Input tensor must be 1D or 2D"):
model.decode(tensor_3d)
def test_decode_batch_with_empty_sample(splade_bert_tiny_model: SparseEncoder) -> None:
model = splade_bert_tiny_model
vocab_size = model.get_embedding_dimension()
# Create a batch where the first sample has values but the second is all zeros
indices = torch.tensor([[0, 0], [1, 5]]) # both non-zero entries in sample 0
values = torch.tensor([1.0, 2.0])
batch_sparse = torch.sparse_coo_tensor(indices, values, size=(2, vocab_size))
decoded = model.decode(batch_sparse)
assert len(decoded) == 2
assert len(decoded[0]) == 2 # sample 0 has 2 non-zero entries
assert decoded[1] == [] # sample 1 is empty
@pytest.mark.parametrize("top_k", [None, 5, 1000])
@pytest.mark.parametrize(
"texts",
[
("The weather is nice!"),
(["The weather is nice!"]),
(["The weather is nice!", "It's sunny outside", "Hello world"]),
(["Short text", "This is a longer text with more words to encode"]),
],
)
def test_decode_returns_sorted_weights(
splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, top_k: int | None
) -> None:
model = splade_bert_tiny_model
embeddings = model.encode(texts)
decoded = model.decode(embeddings, top_k=top_k)
if isinstance(texts, list):
for item in decoded:
weights = [weight for _, weight in item]
assert all(weights[i] >= weights[i + 1] for i in range(len(weights) - 1))
else:
weights = [weight for _, weight in decoded]
assert all(weights[i] >= weights[i + 1] for i in range(len(weights) - 1))
def test_inference_free_splade(inference_free_splade_bert_tiny_model: SparseEncoder):
model = inference_free_splade_bert_tiny_model
dimensionality = model.get_embedding_dimension()
query = "What is the capital of France?"
document = "The capital of France is Paris."
query_embeddings = model.encode_query(query)
document_embeddings = model.encode_document(document)
assert query_embeddings.shape == (dimensionality,)
assert document_embeddings.shape == (dimensionality,)
decoded_query = model.decode(query_embeddings)
decoded_document = model.decode(document_embeddings)
assert len(decoded_query) == len(model.preprocess(query, task="query")["input_ids"][0])
assert len(decoded_document) >= 50
assert model.max_seq_length == 512
assert model[0].sub_modules["query"][0].max_seq_length == 512
assert model[0].sub_modules["document"][0].max_seq_length == 512
model.max_seq_length = 256
assert model.max_seq_length == 256
assert model[0].sub_modules["query"][0].max_seq_length == 256
assert model[0].sub_modules["document"][0].max_seq_length == 256
def test_inference_free_splade_max_active_dims_routing(inference_free_splade_bert_tiny_model: SparseEncoder):
model = inference_free_splade_bert_tiny_model
query = "What is the capital of France?"
document = "The capital of France is Paris."
# Encode without max_active_dims — baseline
query_emb = model.encode_query(query)
doc_emb = model.encode_document(document)
# Encode with max_active_dims — should route to the same sub-modules
query_emb_mad = model.encode_query(query, max_active_dims=50)
doc_emb_mad = model.encode_document(document, max_active_dims=50)
# The non-zero indices of the max_active_dims result should be a subset of the baseline
query_baseline_indices = query_emb.coalesce().indices()[0]
query_mad_indices = query_emb_mad.coalesce().indices()[0]
assert set(query_mad_indices.tolist()).issubset(set(query_baseline_indices.tolist()))
assert query_emb_mad._nnz() <= 50
doc_baseline_indices = doc_emb.coalesce().indices()[0]
doc_mad_indices = doc_emb_mad.coalesce().indices()[0]
assert set(doc_mad_indices.tolist()).issubset(set(doc_baseline_indices.tolist()))
assert doc_emb_mad._nnz() <= 50
def test_encode_advanced_parameters(splade_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch):
"""Test that additional parameters are correctly passed to encode"""
model = splade_bert_tiny_model
encode_calls = []
def spy_encode(*args, **kwargs):
encode_calls.append((args, kwargs))
monkeypatch.setattr(model, "encode", spy_encode)
# Call with advanced parameters
model.encode_query(
"test",
normalize_embeddings=True,
batch_size=64,
show_progress_bar=True,
max_active_dims=128,
chunk_size=10,
custom_param="value",
)
# Verify all parameters were passed correctly
_, kwargs = encode_calls[0]
assert kwargs["normalize_embeddings"] is True
assert kwargs["batch_size"] == 64
assert kwargs["show_progress_bar"] is True
assert kwargs["max_active_dims"] == 128
assert kwargs["chunk_size"] == 10
assert kwargs["custom_param"] == "value"
def test_csr_max_active_dims_passed_to_forward(csr_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch):
model = csr_bert_tiny_model
assert isinstance(model[-1], SparseAutoEncoder)
assert model[-1].k == 16
# Verify that max_active_dims is passed to SparseAutoEncoder.forward()
forward_calls = []
original_forward = model[-1].forward
def spy_forward(*args, **kwargs):
forward_calls.append(kwargs)
return original_forward(*args, **kwargs)
monkeypatch.setattr(model[-1], "forward", spy_forward)
model.encode("Hello world", max_active_dims=5)
assert len(forward_calls) == 1
assert forward_calls[0]["max_active_dims"] == 5
# Without max_active_dims, the model's default max_active_dims is used
forward_calls.clear()
model.encode("Hello world")
assert len(forward_calls) == 1
assert forward_calls[0]["max_active_dims"] == model.max_active_dims
def test_max_active_dims_set_init(splade_bert_tiny_model: SparseEncoder, csr_bert_tiny_model: SparseEncoder, tmp_path):
splade_bert_tiny_model.save_pretrained(str(tmp_path / "splade_bert_tiny"))
csr_bert_tiny_model.save_pretrained(str(tmp_path / "csr_bert_tiny"))
# Load the models with max_active_dims set
loaded_model = SparseEncoder(str(tmp_path / "splade_bert_tiny"))
assert loaded_model.max_active_dims is None
loaded_model = SparseEncoder(str(tmp_path / "splade_bert_tiny"), max_active_dims=13)
assert loaded_model.max_active_dims == 13
loaded_model = SparseEncoder(str(tmp_path / "csr_bert_tiny"))
assert loaded_model.max_active_dims == 16 # Based on the SparseAutoEncoder's k value
loaded_model = SparseEncoder(str(tmp_path / "csr_bert_tiny"), max_active_dims=13)
assert loaded_model.max_active_dims == 13
def test_detect_mlm():
model = SparseEncoder("distilbert/distilbert-base-uncased")
assert isinstance(model[0], Transformer)
assert model[0].transformer_task == "fill-mask"
assert isinstance(model[1], SpladePooling)
def test_default_to_csr():
# NOTE: bert-tiny is actually MLM-based, but the config isn't modern enough to allow us to detect it,
# so we should default to CSR here.
model = SparseEncoder("sentence-transformers-testing/stsb-bert-tiny-safetensors")
assert isinstance(model[0], Transformer)
assert isinstance(model[1], Pooling)
assert isinstance(model[2], SparseAutoEncoder)
def test_sparsity(splade_bert_tiny_model: SparseEncoder):
model = splade_bert_tiny_model
# Check that the sparsity is applied correctly
embeddings = model.encode_query(["What is the capital of France?", "Who has won the World Cup in 2016?"])
sparsity = model.sparsity(embeddings)
assert isinstance(sparsity, dict)
assert "active_dims" in sparsity
assert "sparsity_ratio" in sparsity
assert sparsity["active_dims"] < 100 and sparsity["active_dims"] > 0
assert sparsity["sparsity_ratio"] < 1.0 and sparsity["sparsity_ratio"] >= 0.99
# Also check with dense tensors
dense_sparsity = model.sparsity(embeddings.to_dense())
assert dense_sparsity == sparsity, "Sparsity should be the same for dense and sparse tensors"
# Check that 1-dimensional embeddings work correctly
sparsity_one = model.sparsity(embeddings[0])
sparsity_two = model.sparsity(embeddings[1])
assert (sparsity_one["active_dims"] + sparsity_two["active_dims"]) / 2 == sparsity["active_dims"]
def test_splade_pooling_chunk_size(splade_bert_tiny_model: SparseEncoder):
model = splade_bert_tiny_model
# The chunk size defaults to None, i.e. no chunking
assert model.splade_pooling_chunk_size is None
# But we can chunk the pooling to save memory at the cost of some speed
model.splade_pooling_chunk_size = 13
assert model.splade_pooling_chunk_size == 13
assert isinstance(model[1], SpladePooling)
assert model[1].chunk_size == 13
def test_intersection(splade_bert_tiny_model: SparseEncoder):
model = splade_bert_tiny_model
# Test intersection with a single text
query = "Where can I deposit my money?"
document = "I'm sitting by the river."
query_embeddings = model.encode_query(query)
document_embeddings = model.encode_document(document)
query_sparsity = model.sparsity(query_embeddings)
document_sparsity = model.sparsity(document_embeddings)
# Let's check that the intersection is a tensor and has the correct shape
intersection = model.intersection(query_embeddings, document_embeddings)
assert isinstance(intersection, torch.Tensor)
assert intersection.shape == (model.get_embedding_dimension(),)
# Check that the intersection sparsity is less than both query and document sparsities
intersection_sparsity = model.sparsity(intersection)
assert (
intersection_sparsity["active_dims"] < query_sparsity["active_dims"]
and intersection_sparsity["active_dims"] < document_sparsity["active_dims"]
)
# Test with multiple texts
query = "Who has won the World Cup in 2016?"
documents = ["The capital of France is Paris.", "Germany won the World Cup in 2014."]
query_embeddings = model.encode_query(query)
document_embeddings = model.encode_document(documents)
intersection_batch = model.intersection(query_embeddings, document_embeddings)
assert isinstance(intersection_batch, torch.Tensor)
assert intersection_batch.shape == (len(documents), model.get_embedding_dimension())
decoded_intersection_batch = model.decode(intersection_batch)
assert len(decoded_intersection_batch) == len(documents)
def test_encode_with_dataset_column(splade_bert_tiny_model: SparseEncoder) -> None:
"""Test that encode can handle a dataset column as input."""
model = splade_bert_tiny_model
from datasets import Dataset
# Create a simple dataset with a text column
dataset = Dataset.from_dict({"text": ["This is a test.", "Another sentence."]})
# Encode the dataset column
embeddings = model.encode(dataset["text"], convert_to_tensor=True)
# Check the shape of the embeddings
assert embeddings.shape == (2, model.get_embedding_dimension())
def test_encode_numpy_1d_string_array(splade_bert_tiny_model: SparseEncoder) -> None:
"""Regression test for #3718: encoding a 1D numpy string array should produce one embedding per element."""
model = splade_bert_tiny_model
texts = np.array(["Access Management", "Press Coordination", "Financial Reports"])
embeddings = model.encode(texts, convert_to_tensor=True, save_to_cpu=True)
expected = model.encode(texts.tolist(), convert_to_tensor=True, save_to_cpu=True)
assert embeddings.shape == (3, model.get_embedding_dimension())
assert torch.allclose(embeddings.to_dense(), expected.to_dense())
def test_encode_numpy_2d_string_array(splade_bert_tiny_model: SparseEncoder) -> None:
"""Encoding a 2D numpy string array should match encoding the equivalent nested list."""
model = splade_bert_tiny_model
pairs = np.array([["what is AI?", "AI is artificial intelligence."], ["what is ML?", "ML is machine learning."]])
embeddings = model.encode(pairs, convert_to_tensor=True, save_to_cpu=True)
expected = model.encode(pairs.tolist(), convert_to_tensor=True, save_to_cpu=True)
assert embeddings.shape == (2, model.get_embedding_dimension())
assert torch.allclose(embeddings.to_dense(), expected.to_dense())
def test_encode_numpy_empty(splade_bert_tiny_model: SparseEncoder) -> None:
"""Encoding an empty string ndarray should return an empty tensor, like ``encode([])``."""
model = splade_bert_tiny_model
embeddings = model.encode(np.array([], dtype=str), convert_to_tensor=True, save_to_cpu=True)
expected = model.encode([], convert_to_tensor=True, save_to_cpu=True)
assert embeddings.numel() == 0
assert torch.equal(embeddings.to_dense(), expected.to_dense())
@pytest.mark.parametrize("convert_to_tensor", [True, False])
@pytest.mark.parametrize("convert_to_sparse_tensor", [True, False])
@pytest.mark.parametrize("save_to_cpu", [True, False])
@pytest.mark.parametrize("max_active_dims", [None, 64, 128])
def test_empty_encode(
splade_bert_tiny_model: SparseEncoder,
convert_to_tensor: bool,
convert_to_sparse_tensor: bool,
save_to_cpu: bool,
max_active_dims: int | None,
):
model = splade_bert_tiny_model
embeddings = model.encode(
[],
convert_to_tensor=convert_to_tensor,
convert_to_sparse_tensor=convert_to_sparse_tensor,
save_to_cpu=save_to_cpu,
max_active_dims=max_active_dims,
)
if convert_to_tensor:
assert isinstance(embeddings, torch.Tensor)
assert embeddings.numel() == 0
if save_to_cpu:
assert embeddings.device == torch.device("cpu")
else:
assert embeddings.device == model.device
if convert_to_sparse_tensor:
assert embeddings.is_sparse
else:
assert not embeddings.is_sparse
else:
assert embeddings == []
def test_get_model_kwargs(splade_bert_tiny_model: SparseEncoder) -> None:
"""Test that get_model_kwargs returns the correct keyword arguments."""
model = splade_bert_tiny_model
# Check that the forward kwargs are as expected, i.e. no extra forward kwargs
# for this basic model
forward_kwargs = model.get_model_kwargs()
assert forward_kwargs == []
with pytest.raises(
ValueError,
match=re.escape(
"SparseEncoder.encode() has been called with additional keyword arguments that this model does "
"not use: ['normalize']. As per SparseEncoder.get_model_kwargs(), this model does not accept "
"any additional keyword arguments."
),
):
# There is no "normalize" argument, this should crash
model.encode("Test sentence", normalize=True)
# This should run fine
model.encode("Test sentence")
model.encode_query("Test sentence")
# If one of the modules has additional forward kwargs, they should be included
model[0].forward_kwargs = {"foo"}
model[1].forward_kwargs = {"bar", "baz"}
assert set(model.get_model_kwargs()) == {"foo", "bar", "baz"}
with pytest.raises(
ValueError,
match=re.escape(
"SparseEncoder.encode() has been called with additional keyword arguments that this model does "
"not use: ['normalize']. As per SparseEncoder.get_model_kwargs(), the valid additional keyword"
" arguments are: "
)
+ r"\[('foo'|'bar'|'baz'|, ){5}\].",
):
# There is no "normalize" argument, this should crash
model.encode("Test sentence", normalize=True)
# This should run fine
model.encode("Test sentence")
model.encode_query("Test sentence")
with pytest.raises(
TypeError,
match=r"(Transformer\.)?forward\(\) got an unexpected keyword argument '(foo|bar)'",
):
# This would run fine, except the model can't actually accept these arguments (we monkeypatched the modules'
# forward_kwargs for this test, after all). The model does send the args down to the underlying modules, though!
model.encode("Test sentence", foo=True, bar=False)
# And also if we have a Router in place
query_pooling_copy = copy.deepcopy(model[1])
query_pooling_copy.forward_kwargs = {"query_arg"}
document_pooling_copy = copy.deepcopy(model[1])
document_pooling_copy.forward_kwargs = {"document_arg_1", "document_arg_2"}
model[1] = Router.for_query_document(
query_modules=[query_pooling_copy],
document_modules=[document_pooling_copy],
)
assert set(model.get_model_kwargs()) == {
"foo",
"task",
"query_arg",
"document_arg_1",
"document_arg_2",
"modality",
}
with pytest.raises(
ValueError,
match=re.escape(
"SparseEncoder.encode() has been called with additional keyword arguments that this model does "
"not use: ['normalize']. As per SparseEncoder.get_model_kwargs(), the valid additional keyword"
" arguments are: "
)
+ r"\[('foo'|'task'|'query_arg'|'document_arg_1'|'document_arg_2'|'modality'|, ){11}\].",
):
# There is no "normalize" argument, this should crash
model.encode("Test sentence", task="query", normalize=True)
# This should run fine
model.encode("Test sentence", task="document")
model.encode_query("Test sentence")
with pytest.raises(
TypeError,
match=r"(Transformer\.)?forward\(\) got an unexpected keyword argument '(foo|document_arg_1)'",
):
# This would run fine, except the model can't actually accept these arguments (we monkeypatched the modules'
# forward_kwargs for this test, after all). The model does send the args down to the underlying modules, though!
model.encode("Test sentence", task="document", foo=True, document_arg_1=12)
@pytest.mark.parametrize("similarity_fn_name", SimilarityFunction.possible_values())
def test_similarity_score(splade_bert_tiny_model: SparseEncoder, similarity_fn_name: str) -> None:
model = splade_bert_tiny_model
model.similarity_fn_name = similarity_fn_name
sentences = [
"The weather is so nice!",
"It's so sunny outside.",
"He's driving to the movie theater.",
"She's going to the cinema.",
]
embeddings = model.encode(sentences, convert_to_sparse_tensor=False)
scores = model.similarity(embeddings, embeddings)
assert scores.shape == (len(sentences), len(sentences))
diag = np.diag(scores.cpu().numpy())
if similarity_fn_name == "cosine":
np.testing.assert_almost_equal(diag, np.ones(len(sentences), dtype=float), decimal=4)
elif similarity_fn_name in ("euclidean", "manhattan"):
np.testing.assert_almost_equal(diag, np.zeros(len(sentences), dtype=float), decimal=4)
else: # dot product - self-similarity of non-zero sparse vectors is positive
assert (diag > 0).all()
pairwise_scores = model.similarity_pairwise(embeddings[::2], embeddings[1::2])
assert pairwise_scores.shape == (len(sentences) // 2,)
def test_similarity_score_save(splade_bert_tiny_model: SparseEncoder, tmp_path: Path) -> None:
model = splade_bert_tiny_model
assert model.similarity_fn_name == "dot"
model.similarity_fn_name = "cosine"
model.save(str(tmp_path))
loaded_model = SparseEncoder(str(tmp_path))
assert loaded_model.similarity_fn_name == "cosine"
def test_similarity_fn_name_set_via_enum(splade_bert_tiny_model: SparseEncoder) -> None:
model = splade_bert_tiny_model
model.similarity_fn_name = SimilarityFunction.COSINE
assert model.similarity_fn_name == "cosine"
model.similarity_fn_name = SimilarityFunction.DOT
assert model.similarity_fn_name == "dot"
def test_similarity_fn_name_constructor_overrides_saved(splade_bert_tiny_model: SparseEncoder, tmp_path: Path) -> None:
splade_bert_tiny_model.similarity_fn_name = "cosine"
splade_bert_tiny_model.save(str(tmp_path))
model = SparseEncoder(str(tmp_path), similarity_fn_name="dot")
assert model.similarity_fn_name == "dot"
def test_prompts(splade_bert_tiny_model: SparseEncoder, caplog: pytest.LogCaptureFixture) -> None:
model = splade_bert_tiny_model
assert model.prompts == {"query": "", "document": ""}
assert model.default_prompt_name is None
texts = ["How to bake a chocolate cake", "Symptoms of the flu"]
no_prompt_embedding = model.encode(texts, convert_to_sparse_tensor=False, save_to_cpu=True)
prompt_embedding = model.encode(
[f"query: {text}" for text in texts], convert_to_sparse_tensor=False, save_to_cpu=True
)
assert not np.array_equal(no_prompt_embedding, prompt_embedding)
query = "query: "
# Test prompt="query: "
model.prompts = {"query": "", "document": ""}
assert np.array_equal(
model.encode(texts, prompt=query, convert_to_sparse_tensor=False, save_to_cpu=True), prompt_embedding
)
# Test prompt_name="..."
model.prompts = {"query": query, "document": ""}
assert np.array_equal(
model.encode(texts, prompt_name="query", convert_to_sparse_tensor=False, save_to_cpu=True), prompt_embedding
)
caplog.clear()
# Test prompt_name="..." & prompt="..."
with caplog.at_level(logging.WARNING):
assert np.array_equal(
model.encode(texts, prompt=query, prompt_name="query", convert_to_sparse_tensor=False, save_to_cpu=True),
prompt_embedding,
)
assert len(caplog.record_tuples) == 1
assert (
caplog.record_tuples[0][2] == "Provide either a `prompt`, a `prompt_name`, or neither, but not both. "
"Ignoring the `prompt_name` in favor of `prompt`."
)
with pytest.raises(
ValueError,
match=re.escape(
"Prompt name 'invalid_prompt_name' not found in the configured prompts dictionary with keys ['query', 'document']."
),
):
model.encode(texts, prompt_name="invalid_prompt_name")
def test_save_load_prompts() -> None:
with pytest.raises(
ValueError,
match=re.escape(
"Default prompt name 'invalid_prompt_name' not found in the configured prompts dictionary with keys ['query', 'document']."
),
):
SparseEncoder(
"sparse-encoder-testing/splade-bert-tiny-nq",
prompts={"query": "query: "},
default_prompt_name="invalid_prompt_name",
)
model = SparseEncoder(
"sparse-encoder-testing/splade-bert-tiny-nq",
prompts={"query": "query: "},
default_prompt_name="query",
)
assert model.prompts == {"query": "query: ", "document": ""}
assert model.default_prompt_name == "query"
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmp_folder:
model_path = Path(tmp_folder) / "tiny_model_local"
model.save(str(model_path))
config_path = model_path / "config_sentence_transformers.json"
assert config_path.exists()
with open(config_path, encoding="utf8") as f:
saved_config = json.load(f)
assert saved_config["prompts"] == {"query": "query: ", "document": ""}
assert saved_config["default_prompt_name"] == "query"
fresh_model = SparseEncoder(str(model_path))
assert fresh_model.prompts == {"query": "query: ", "document": ""}
assert fresh_model.default_prompt_name == "query"
@pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]])
@pytest.mark.parametrize("convert_to_tensor", [True, False])
@pytest.mark.parametrize("convert_to_sparse_tensor", [True, False])
@pytest.mark.parametrize("prompt_name", [None, "query", "custom"])
@pytest.mark.parametrize("prompt", [None, "Custom prompt: "])
def test_encode_query(
splade_bert_tiny_model: SparseEncoder,
sentences: str | list[str],
convert_to_tensor: bool,
convert_to_sparse_tensor: bool,
prompt_name: str | None,
prompt: str | None,
monkeypatch: pytest.MonkeyPatch,
):
model = splade_bert_tiny_model
model.prompts = {"query": "query: ", "custom": "custom: "}
encode_calls = []
def spy_encode(*args, **kwargs):
encode_calls.append((args, kwargs))
monkeypatch.setattr(model, "encode", spy_encode)
model.encode_query(
sentences=sentences,
batch_size=32,
convert_to_tensor=convert_to_tensor,
convert_to_sparse_tensor=convert_to_sparse_tensor,
prompt_name=prompt_name,
prompt=prompt,
)
if prompt_name:
expected_prompt_name = prompt_name
elif prompt is not None:
expected_prompt_name = None
else:
expected_prompt_name = "query"
assert len(encode_calls) == 1
_, kwargs = encode_calls[0]
assert kwargs["inputs"] == sentences
assert kwargs["convert_to_tensor"] == convert_to_tensor
assert kwargs["convert_to_sparse_tensor"] == convert_to_sparse_tensor
assert kwargs["prompt"] == prompt
assert kwargs["prompt_name"] == expected_prompt_name
assert kwargs["task"] == "query"
@pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]])
@pytest.mark.parametrize("convert_to_tensor", [True, False])
@pytest.mark.parametrize("convert_to_sparse_tensor", [True, False])
@pytest.mark.parametrize("prompt_name", [None, "document", "passage", "corpus", "custom"])
@pytest.mark.parametrize("prompt", [None, "Custom prompt: "])
def test_encode_document(
splade_bert_tiny_model: SparseEncoder,
sentences: str | list[str],
convert_to_tensor: bool,
convert_to_sparse_tensor: bool,
prompt_name: str | None,
prompt: str | None,
monkeypatch: pytest.MonkeyPatch,
):
model = splade_bert_tiny_model
model.prompts = {"document": "document: ", "passage": "passage: ", "corpus": "corpus: ", "custom": "custom: "}
encode_calls = []
def spy_encode(*args, **kwargs):
encode_calls.append((args, kwargs))
monkeypatch.setattr(model, "encode", spy_encode)
model.encode_document(
sentences=sentences,
batch_size=32,
convert_to_tensor=convert_to_tensor,
convert_to_sparse_tensor=convert_to_sparse_tensor,
prompt_name=prompt_name,
prompt=prompt,
)
assert len(encode_calls) == 1
_, kwargs = encode_calls[0]
if prompt_name:
expected_prompt_name = prompt_name
elif prompt is not None:
expected_prompt_name = None
else:
expected_prompt_name = "document"
assert kwargs["inputs"] == sentences
assert kwargs["convert_to_tensor"] == convert_to_tensor
assert kwargs["convert_to_sparse_tensor"] == convert_to_sparse_tensor
assert kwargs["prompt"] == prompt
assert kwargs["prompt_name"] == expected_prompt_name
assert kwargs["task"] == "document"
def test_encode_document_prompt_priority(splade_bert_tiny_model: SparseEncoder, monkeypatch: pytest.MonkeyPatch):
"""Test that proper prompt priority is respected when multiple options are available"""
model = splade_bert_tiny_model
model.prompts = {
"document": "document: ",
"passage": "passage: ",
"corpus": "corpus: ",
}
encode_calls = []
def spy_encode(*args, **kwargs):
encode_calls.append((args, kwargs))
monkeypatch.setattr(model, "encode", spy_encode)
model.encode_document("test")
_, kwargs = encode_calls[-1]
assert kwargs["prompt_name"] == "document"
# Remove document, should fall back to passage
encode_calls.clear()
model.prompts = {"passage": "passage: ", "corpus": "corpus: "}
model.encode_document("test")
_, kwargs = encode_calls[-1]
assert kwargs["prompt_name"] == "passage"
# Remove passage, should fall back to corpus
encode_calls.clear()
model.prompts = {"corpus": "corpus: "}
model.encode_document("test")
_, kwargs = encode_calls[-1]
assert kwargs["prompt_name"] == "corpus"
# No document/passage/corpus, should use None
encode_calls.clear()
model.prompts = {"custom": "custom: "}
model.encode_document("test")
_, kwargs = encode_calls[-1]
assert kwargs["prompt_name"] is None