from __future__ import annotations

import re
from unittest.mock import patch

import pytest
import torch

from sentence_transformers.models import Pooling, Transformer
from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseAutoEncoder, SpladePooling
from sentence_transformers.sparse_encoder.SparseEncoder import SparseEncoder
from tests.sparse_encoder.utils import sparse_allclose


@pytest.mark.parametrize(
("texts", "top_k", "expected_shape"),
[
# Single text, default top_k (None)
(["The weather is nice!"], None, 1),
# Single text, specific top_k
(["The weather is nice!"], 3, 1),
# String text, specific top_k, expect a non-nested list
("The weather is nice!", 8, 8),
# Multiple texts, default top_k (None)
(["The weather is nice!", "It's sunny outside"], None, 2),
# Multiple texts, specific top_k
(["The weather is nice!", "It's sunny outside"], 3, 2),
],
)
def test_decode_shapes(
    splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, top_k: int | None, expected_shape: int
) -> None:
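    """decode should return one entry per input text, and a flat (token, weight) list for a plain string."""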
model = splade_bert_tiny_model
embeddings = model.encode(texts)
decoded = model.decode(embeddings, top_k=top_k)
assert len(decoded) == expected_shape
    if isinstance(texts, list):
        if len(texts) == 1:
            # A single text in a list still yields a nested result: one list of (token, weight) tuples
            assert isinstance(decoded, list) and isinstance(decoded[0], list)
            if top_k is not None:
                assert len(decoded[0]) <= top_k
        else:
            assert isinstance(decoded, list)
            for item in decoded:
                assert isinstance(item, list)
                if top_k is not None:
                    assert len(item) <= top_k


@pytest.mark.parametrize(
("text", "expected_token_types"),
[
("The weather is nice!", str),
("It's sunny outside", str),
],
)
def test_decode_token_types(splade_bert_tiny_model: SparseEncoder, text: str, expected_token_types: type) -> None:
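    """Each decoded entry should be a (str token, float weight) pair."""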
model = splade_bert_tiny_model
embeddings = model.encode(text)
decoded = model.decode(embeddings)
    # For a single string input, decode returns a flat list of (token, weight) tuples
for token, weight in decoded:
assert isinstance(token, expected_token_types)
        assert isinstance(weight, float)


@pytest.mark.parametrize(
("text", "top_k"),
[
("The weather is nice!", 1),
("It's sunny outside", 3),
("Hello world", 5),
],
)
def test_decode_top_k_respects_limit(splade_bert_tiny_model: SparseEncoder, text: str, top_k: int) -> None:
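    """decode with top_k should never return more than top_k (token, weight) pairs."""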
model = splade_bert_tiny_model
embeddings = model.encode([text])
decoded = model.decode(embeddings, top_k=top_k)
    assert len(decoded) <= top_k


@pytest.mark.parametrize(
("texts", "format_type"),
[
("The weather is nice!", "1d"),
(["The weather is nice!"], "1d"),
(["The weather is nice!", "It's sunny outside"], "2d"),
],
)
def test_decode_handles_sparse_dense_inputs(
splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, format_type: str
):
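    """decode should accept both sparse and dense tensors and produce equivalent results."""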
model = splade_bert_tiny_model
# Get embeddings and test both sparse and dense format handling
embeddings = model.encode(texts)
# Test with sparse tensor
if not embeddings.is_sparse:
embeddings_sparse = embeddings.to_sparse()
else:
embeddings_sparse = embeddings
decoded_sparse = model.decode(embeddings_sparse)
# Test with dense tensor
if embeddings.is_sparse:
embeddings_dense = embeddings.to_dense()
else:
embeddings_dense = embeddings
decoded_dense = model.decode(embeddings_dense)
    # Verify both produce the same result structure
    assert len(decoded_sparse) == len(decoded_dense)
    if format_type == "2d":
        for i in range(len(decoded_sparse)):
            # Sort both results to ensure a consistent, order-independent comparison
            sorted_sparse = sorted(decoded_sparse[i], key=lambda x: (x[1], x[0]), reverse=True)
            sorted_dense = sorted(decoded_dense[i], key=lambda x: (x[1], x[0]), reverse=True)
            assert len(sorted_sparse) == len(sorted_dense)
            for (sparse_token, sparse_weight), (dense_token, dense_weight) in zip(sorted_sparse, sorted_dense):
                assert sparse_token == dense_token
                assert sparse_weight == pytest.approx(dense_weight)


def test_decode_empty_tensor(splade_bert_tiny_model: SparseEncoder) -> None:
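    """Decoding an empty sparse tensor should yield no (token, weight) pairs."""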
model = splade_bert_tiny_model
# Create an empty sparse tensor
empty_sparse = torch.sparse_coo_tensor(
indices=torch.zeros((2, 0), dtype=torch.long),
values=torch.zeros((0,), dtype=torch.float),
size=(1, model.get_sentence_embedding_dimension()),
)
decoded = model.decode(empty_sparse)
    assert len(decoded) == 0 or (isinstance(decoded, list) and all(not item for item in decoded))


@pytest.mark.parametrize(
"top_k",
[None, 5, 1000],
)
@pytest.mark.parametrize(
"texts",
[
("The weather is nice!"),
(["The weather is nice!"]),
(["The weather is nice!", "It's sunny outside", "Hello world"]),
(["Short text", "This is a longer text with more words to encode"]),
],
)
def test_decode_returns_sorted_weights(
splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, top_k: int | None
) -> None:
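    """decode should return (token, weight) pairs sorted by descending weight."""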
model = splade_bert_tiny_model
embeddings = model.encode(texts)
decoded = model.decode(embeddings, top_k=top_k)
    # A plain string yields a flat list; wrap it so both cases share one check
    decoded_per_text = decoded if isinstance(texts, list) else [decoded]
    for item in decoded_per_text:
        weights = [weight for _, weight in item]
        assert all(weights[i] >= weights[i + 1] for i in range(len(weights) - 1))


def test_inference_free_splade(inference_free_splade_bert_tiny_model: SparseEncoder):
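    """Inference-free SPLADE uses separate query/document sub-modules that share settings like max_seq_length."""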
model = inference_free_splade_bert_tiny_model
dimensionality = model.get_sentence_embedding_dimension()
query = "What is the capital of France?"
document = "The capital of France is Paris."
query_embeddings = model.encode_query(query)
document_embeddings = model.encode_document(document)
assert query_embeddings.shape == (dimensionality,)
assert document_embeddings.shape == (dimensionality,)
decoded_query = model.decode(query_embeddings)
decoded_document = model.decode(document_embeddings)
assert len(decoded_query) == len(model.tokenize(query, task="query")["input_ids"][0])
assert len(decoded_document) >= 50
assert model.max_seq_length == 512
assert model[0].sub_modules["query"][0].max_seq_length == 512
assert model[0].sub_modules["document"][0].max_seq_length == 512
model.max_seq_length = 256
assert model.max_seq_length == 256
assert model[0].sub_modules["query"][0].max_seq_length == 256
    assert model[0].sub_modules["document"][0].max_seq_length == 256


@pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]])
@pytest.mark.parametrize("prompt_name", [None, "query", "custom"])
@pytest.mark.parametrize("prompt", [None, "Custom prompt: "])
@pytest.mark.parametrize("convert_to_tensor", [True, False])
@pytest.mark.parametrize("convert_to_sparse_tensor", [True, False])
def test_encode_query(
splade_bert_tiny_model: SparseEncoder,
sentences: str | list[str],
prompt_name: str | None,
prompt: str | None,
convert_to_tensor: bool,
convert_to_sparse_tensor: bool,
):
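    """encode_query should delegate to encode with task="query" and the "query" prompt as the default prompt name."""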
model = splade_bert_tiny_model
    # Configure the prompts that encode_query may select from
    model.prompts = {"query": "query: ", "custom": "custom: "}
# Create a mock for the encode method
with patch.object(model, "encode", autospec=True) as mock_encode:
# Call encode_query
model.encode_query(
sentences=sentences,
prompt_name=prompt_name,
prompt=prompt,
batch_size=32,
convert_to_tensor=convert_to_tensor,
convert_to_sparse_tensor=convert_to_sparse_tensor,
)
# Verify that encode was called with the correct parameters
expected_prompt_name = prompt_name if prompt_name else "query"
mock_encode.assert_called_once()
args, kwargs = mock_encode.call_args
# Check that sentences were passed correctly
assert kwargs["sentences"] == sentences
# Check prompt handling
assert kwargs["prompt"] == prompt
assert kwargs["prompt_name"] == expected_prompt_name
# Check other parameters
assert kwargs["convert_to_tensor"] == convert_to_tensor
assert kwargs["convert_to_sparse_tensor"] == convert_to_sparse_tensor
assert kwargs["task"] == "query"
@pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]])
@pytest.mark.parametrize("prompt_name", [None, "document", "passage", "corpus", "custom"])
@pytest.mark.parametrize("prompt", [None, "Custom prompt: "])
@pytest.mark.parametrize("convert_to_tensor", [True, False])
@pytest.mark.parametrize("convert_to_sparse_tensor", [True, False])
def test_encode_document(
splade_bert_tiny_model: SparseEncoder,
sentences: str | list[str],
prompt_name: str | None,
prompt: str | None,
convert_to_tensor: bool,
convert_to_sparse_tensor: bool,
):
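    """encode_document should delegate to encode with task="document" and the "document" prompt as the default prompt name."""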
    model = splade_bert_tiny_model
    # Configure the prompts that encode_document may select from
    model.prompts = {"document": "document: ", "passage": "passage: ", "corpus": "corpus: ", "custom": "custom: "}
# Create a mock for the encode method
with patch.object(model, "encode", autospec=True) as mock_encode:
# Call encode_document
model.encode_document(
sentences=sentences,
prompt_name=prompt_name,
prompt=prompt,
batch_size=32,
convert_to_tensor=convert_to_tensor,
convert_to_sparse_tensor=convert_to_sparse_tensor,
)
# Verify that encode was called with the correct parameters
mock_encode.assert_called_once()
args, kwargs = mock_encode.call_args
expected_prompt_name = prompt_name if prompt_name else "document"
# Check that sentences were passed correctly
assert kwargs["sentences"] == sentences
# Check prompt handling
assert kwargs["prompt"] == prompt
assert kwargs["prompt_name"] == expected_prompt_name
# Check other parameters
assert kwargs["convert_to_tensor"] == convert_to_tensor
assert kwargs["convert_to_sparse_tensor"] == convert_to_sparse_tensor
assert kwargs["task"] == "document"
def test_encode_document_prompt_priority(splade_bert_tiny_model: SparseEncoder):
"""Test that proper prompt priority is respected when multiple options are available"""
model = splade_bert_tiny_model
model.prompts = {
"document": "document: ",
"passage": "passage: ",
"corpus": "corpus: ",
}
# Create a mock for the encode method
with patch.object(model, "encode", autospec=True) as mock_encode:
# Call encode_document with no explicit prompt
model.encode_document("test")
# It should select "document" by default since that's first in the priority list
args, kwargs = mock_encode.call_args
assert kwargs["prompt_name"] == "document"
# Remove document, should fall back to passage
mock_encode.reset_mock()
model.prompts = {
"passage": "passage: ",
"corpus": "corpus: ",
}
model.encode_document("test")
args, kwargs = mock_encode.call_args
assert kwargs["prompt_name"] == "passage"
# Remove passage, should fall back to corpus
mock_encode.reset_mock()
model.prompts = {
"corpus": "corpus: ",
}
model.encode_document("test")
args, kwargs = mock_encode.call_args
assert kwargs["prompt_name"] == "corpus"
# No relevant prompts defined
mock_encode.reset_mock()
model.prompts = {
"query": "query: ",
}
model.encode_document("test")
args, kwargs = mock_encode.call_args
assert kwargs["prompt_name"] is None
def test_encode_advanced_parameters(splade_bert_tiny_model: SparseEncoder):
"""Test that additional parameters are correctly passed to encode"""
model = splade_bert_tiny_model
# Create a mock for the encode method
with patch.object(model, "encode", autospec=True) as mock_encode:
# Call with advanced parameters
model.encode_query(
"test",
normalize_embeddings=True,
batch_size=64,
show_progress_bar=True,
max_active_dims=128,
chunk_size=10,
custom_param="value",
)
# Verify all parameters were passed correctly
args, kwargs = mock_encode.call_args
assert kwargs["normalize_embeddings"] is True
assert kwargs["batch_size"] == 64
assert kwargs["show_progress_bar"] is True
assert kwargs["max_active_dims"] == 128
assert kwargs["chunk_size"] == 10
assert kwargs["custom_param"] == "value"
@pytest.mark.parametrize("inputs", ["test sentence", ["test sentence"]])
def test_encode_query_document_vs_encode(splade_bert_tiny_model: SparseEncoder, inputs: str | list[str]):
"""Test the actual integration with encode vs encode_query/encode_document"""
# This test requires a real model, but we'll use a small one
model = splade_bert_tiny_model
model.prompts = {"query": "query: ", "document": "document: "}
# Get embeddings with encode_query and encode_document
query_embeddings = model.encode_query(inputs)
document_embeddings = model.encode_document(inputs)
    # And the same via plain encode with explicit prompt names (passing only a task would not apply the prompts)
encode_query_embeddings = model.encode(inputs, prompt_name="query")
encode_document_embeddings = model.encode(inputs, prompt_name="document")
# With prompts they should be the same
assert sparse_allclose(query_embeddings, encode_query_embeddings)
assert sparse_allclose(document_embeddings, encode_document_embeddings)
    # Without a prompt, the embeddings should differ from the prompted ones
    embeddings_without_prompt = model.encode(inputs)
    assert not sparse_allclose(embeddings_without_prompt, query_embeddings)
    assert not sparse_allclose(embeddings_without_prompt, document_embeddings)


def test_default_prompt(splade_bert_tiny_model: SparseEncoder):
"""Test that the default prompt is used when no prompt is specified"""
model = splade_bert_tiny_model
model.prompts = {"query": "query: ", "document": "document: "}
model.default_prompt_name = "query"
# Call encode_query without specifying a prompt
query_embeddings = model.encode_query("test")
assert query_embeddings.shape == (model.get_sentence_embedding_dimension(),)
# Call encode_document without specifying a prompt
document_embeddings = model.encode_document("test")
assert document_embeddings.shape == (model.get_sentence_embedding_dimension(),)
default_embeddings = model.encode("test")
assert default_embeddings.shape == (model.get_sentence_embedding_dimension(),)
# Make sure that the default prompt is used
assert sparse_allclose(query_embeddings, default_embeddings)
assert not sparse_allclose(document_embeddings, default_embeddings)
    # Also check that once the default prompt is unset, plain encode no longer matches the "query"-prompted embeddings
    model.default_prompt_name = None
    default_embeddings_no_default = model.encode("test")
    assert not sparse_allclose(default_embeddings_no_default, default_embeddings)


def test_wrong_prompt(splade_bert_tiny_model: SparseEncoder):
"""Test that using a wrong prompt raises an error"""
model = splade_bert_tiny_model
model.prompts = {"query": "query: ", "document": "document: "}
for encode_method in [model.encode_query, model.encode_document, model.encode]:
with pytest.raises(
ValueError,
match=re.escape(
"Prompt name 'invalid_prompt' not found in the configured prompts dictionary with keys ['query', 'document']."
),
):
encode_method("test", prompt_name="invalid_prompt")
def test_max_active_dims_set_init(splade_bert_tiny_model: SparseEncoder, csr_bert_tiny_model: SparseEncoder, tmp_path):
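    """max_active_dims can be overridden at load time; CSR models default to the SparseAutoEncoder's k value."""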
splade_bert_tiny_model.save_pretrained(str(tmp_path / "splade_bert_tiny"))
csr_bert_tiny_model.save_pretrained(str(tmp_path / "csr_bert_tiny"))
# Load the models with max_active_dims set
loaded_model = SparseEncoder(str(tmp_path / "splade_bert_tiny"))
assert loaded_model.max_active_dims is None
loaded_model = SparseEncoder(str(tmp_path / "splade_bert_tiny"), max_active_dims=13)
assert loaded_model.max_active_dims == 13
loaded_model = SparseEncoder(str(tmp_path / "csr_bert_tiny"))
assert loaded_model.max_active_dims == 16 # Based on the SparseAutoEncoder's k value
loaded_model = SparseEncoder(str(tmp_path / "csr_bert_tiny"), max_active_dims=13)
    assert loaded_model.max_active_dims == 13


def test_detect_mlm():
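    """Models with a detectable MLM head should be loaded as MLMTransformer + SpladePooling."""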
model = SparseEncoder("distilbert/distilbert-base-uncased")
assert isinstance(model[0], MLMTransformer)
    assert isinstance(model[1], SpladePooling)


def test_default_to_csr():
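    """Models where no MLM head can be detected should fall back to the CSR architecture."""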
# NOTE: bert-tiny is actually MLM-based, but the config isn't modern enough to allow us to detect it,
# so we should default to CSR here.
model = SparseEncoder("prajjwal1/bert-tiny")
assert isinstance(model[0], Transformer)
assert isinstance(model[1], Pooling)
    assert isinstance(model[2], SparseAutoEncoder)


def test_sparsity(splade_bert_tiny_model: SparseEncoder):
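    """sparsity should report active_dims and sparsity_ratio consistently for sparse, dense, and 1-dimensional inputs."""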
model = splade_bert_tiny_model
# Check that the sparsity is applied correctly
embeddings = model.encode_query(["What is the capital of France?", "Who has won the World Cup in 2016?"])
sparsity = model.sparsity(embeddings)
assert isinstance(sparsity, dict)
assert "active_dims" in sparsity
assert "sparsity_ratio" in sparsity
assert sparsity["active_dims"] < 100 and sparsity["active_dims"] > 0
assert sparsity["sparsity_ratio"] < 1.0 and sparsity["sparsity_ratio"] >= 0.99
# Also check with dense tensors
dense_sparsity = model.sparsity(embeddings.to_dense())
assert dense_sparsity == sparsity, "Sparsity should be the same for dense and sparse tensors"
# Check that 1-dimensional embeddings work correctly
sparsity_one = model.sparsity(embeddings[0])
sparsity_two = model.sparsity(embeddings[1])
    assert (sparsity_one["active_dims"] + sparsity_two["active_dims"]) / 2 == sparsity["active_dims"]


def test_splade_pooling_chunk_size(splade_bert_tiny_model: SparseEncoder):
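    """splade_pooling_chunk_size should proxy the chunk_size of the SpladePooling module."""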
model = splade_bert_tiny_model
# The chunk size defaults to None, i.e. no chunking
assert model.splade_pooling_chunk_size is None
# But we can chunk the pooling to save memory at the cost of some speed
model.splade_pooling_chunk_size = 13
assert model.splade_pooling_chunk_size == 13
assert isinstance(model[1], SpladePooling)
    assert model[1].chunk_size == 13


def test_intersection(splade_bert_tiny_model: SparseEncoder):
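    """intersection should keep only dimensions active in both embeddings, for single and batched documents."""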
model = splade_bert_tiny_model
# Test intersection with a single text
query = "Where can I deposit my money?"
document = "I'm sitting by the river."
query_embeddings = model.encode_query(query)
document_embeddings = model.encode_document(document)
query_sparsity = model.sparsity(query_embeddings)
document_sparsity = model.sparsity(document_embeddings)
# Let's check that the intersection is a tensor and has the correct shape
intersection = model.intersection(query_embeddings, document_embeddings)
assert isinstance(intersection, torch.Tensor)
assert intersection.shape == (model.get_sentence_embedding_dimension(),)
# Check that the intersection sparsity is less than both query and document sparsities
intersection_sparsity = model.sparsity(intersection)
assert (
intersection_sparsity["active_dims"] < query_sparsity["active_dims"]
and intersection_sparsity["active_dims"] < document_sparsity["active_dims"]
)
# Test with multiple texts
query = "Who has won the World Cup in 2016?"
documents = ["The capital of France is Paris.", "Germany won the World Cup in 2014."]
query_embeddings = model.encode_query(query)
document_embeddings = model.encode_document(documents)
intersection_batch = model.intersection(query_embeddings, document_embeddings)
assert isinstance(intersection_batch, torch.Tensor)
assert intersection_batch.shape == (len(documents), model.get_sentence_embedding_dimension())
decoded_intersection_batch = model.decode(intersection_batch)
assert len(decoded_intersection_batch) == len(documents)