lsmpp's picture
Add files using upload-large-folder tool
bd33eac verified
from __future__ import annotations
import pytest
from torch import Tensor
from sentence_transformers import SparseEncoder
@pytest.mark.parametrize(
    "model_name",
    [
        "sentence-transformers/all-MiniLM-L6-v2",
    ],
)
def test_load_and_encode(model_name: str) -> None:
    """Smoke-test that a ``SparseEncoder`` can be loaded and used to encode.

    The test loads the model, encodes a small batch, round-trips the batch
    through ``decode``, and then checks the return types of ``encode`` for a
    one-element list and for a bare string (both with
    ``convert_to_tensor=False``), and finally encodes with a progress bar.

    Args:
        model_name: Model identifier passed straight to ``SparseEncoder``.
    """
    # Ensure that SparseEncoder can be initialized with a base model and can encode
    try:
        model = SparseEncoder(model_name)
    except Exception as e:
        pytest.fail(f"Failed to load SparseEncoder with {model_name}: {e}")

    sentences = [
        "This is a test sentence.",
        "Another example sentence here.",
        "Sparse encoders are interesting.",
    ]

    # Only the call under test lives in the ``try`` so that assertion failures
    # further down are reported as what they are, not as encoding errors.
    try:
        embeddings = model.encode(sentences)
    except Exception as e:
        pytest.fail(f"SparseEncoder failed to encode sentences: {e}")

    assert embeddings is not None
    assert isinstance(embeddings, Tensor), "Embeddings should be a tensor for sparse encoders"
    assert len(embeddings) == len(sentences), "Number of embeddings should match number of sentences"

    decoded_embeddings = model.decode(embeddings)
    assert len(decoded_embeddings) == len(sentences), "Decoded embeddings should match number of sentences"
    assert all(isinstance(emb, list) for emb in decoded_embeddings), "Decoded embeddings should be a list of lists"

    # Check a known property: encoding a single sentence
    single_sentence_emb = model.encode(["A single sentence."], convert_to_tensor=False)
    assert isinstance(
        single_sentence_emb, list
    ), "Encoding a single sentence with convert_to_tensor=False should return a list of len 1"
    assert len(single_sentence_emb) == 1, "A one-sentence batch should produce exactly one embedding"

    # If we're using a string instead of a list, we should get a single tensor embedding
    single_sentence_emb_tensor = model.encode("A single sentence.", convert_to_tensor=False)
    assert isinstance(
        single_sentence_emb_tensor, Tensor
    ), "Encoding a single sentence with convert_to_tensor=False should return a tensor"
    assert single_sentence_emb_tensor.dim() == 1, "Embedding of a bare string input should be a 1D tensor"

    # Check encoding with show_progress_bar
    try:
        embeddings_with_progress = model.encode(sentences, show_progress_bar=True)
    except Exception as e:
        pytest.fail(f"SparseEncoder failed to encode with progress bar: {e}")
    # Asserted outside the ``try``: an AssertionError is an Exception and would
    # otherwise be swallowed and misreported as an encoding failure.
    assert len(embeddings_with_progress) == len(
        sentences
    ), "Number of embeddings should match number of sentences when a progress bar is shown"