|
|
from __future__ import annotations |
|
|
|
|
|
import pytest |
|
|
from torch import Tensor |
|
|
|
|
|
from sentence_transformers import SparseEncoder |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
"model_name", |
|
|
[ |
|
|
("sentence-transformers/all-MiniLM-L6-v2"), |
|
|
], |
|
|
) |
|
|
def test_load_and_encode(model_name: str) -> None: |
|
|
|
|
|
try: |
|
|
model = SparseEncoder(model_name) |
|
|
except Exception as e: |
|
|
pytest.fail(f"Failed to load SparseEncoder with {model_name}: {e}") |
|
|
|
|
|
sentences = [ |
|
|
"This is a test sentence.", |
|
|
"Another example sentence here.", |
|
|
"Sparse encoders are interesting.", |
|
|
] |
|
|
|
|
|
try: |
|
|
embeddings = model.encode(sentences) |
|
|
except Exception as e: |
|
|
pytest.fail(f"SparseEncoder failed to encode sentences: {e}") |
|
|
|
|
|
assert embeddings is not None |
|
|
|
|
|
assert isinstance(embeddings, Tensor), "Embeddings should be a tensor for sparse encoders" |
|
|
assert len(embeddings) == len(sentences), "Number of embeddings should match number of sentences" |
|
|
|
|
|
decoded_embeddings = model.decode(embeddings) |
|
|
assert len(decoded_embeddings) == len(sentences), "Decoded embeddings should match number of sentences" |
|
|
assert all(isinstance(emb, list) for emb in decoded_embeddings), "Decoded embeddings should be a list of lists" |
|
|
|
|
|
|
|
|
single_sentence_emb = model.encode(["A single sentence."], convert_to_tensor=False) |
|
|
assert isinstance( |
|
|
single_sentence_emb, list |
|
|
), "Encoding a single sentence with convert_to_tensor=False should return a list of len 1" |
|
|
assert len(single_sentence_emb) == 1, "Single sentence embedding dict should not be empty" |
|
|
|
|
|
|
|
|
single_sentence_emb_tensor = model.encode("A single sentence.", convert_to_tensor=False) |
|
|
assert isinstance( |
|
|
single_sentence_emb_tensor, Tensor |
|
|
), "Encoding a single sentence with convert_to_tensor=False should return a tensor" |
|
|
assert single_sentence_emb_tensor.dim() == 1, "Single sentence embedding tensor should be 1D" |
|
|
|
|
|
|
|
|
try: |
|
|
embeddings_with_progress = model.encode(sentences, show_progress_bar=True) |
|
|
assert len(embeddings_with_progress) == len(sentences) |
|
|
except Exception as e: |
|
|
pytest.fail(f"SparseEncoder failed to encode with progress bar: {e}") |
|
|
|