Spaces:
Sleeping
Sleeping
File size: 2,223 Bytes
2bf4d05 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | import pytest
from mediastorm.vectorize.embedder import Embedder
@pytest.fixture(scope="module")
def embedder():
    """Provide a single ``Embedder`` instance shared by the tests in this module."""
    instance = Embedder()
    return instance
def test_embed_query_returns_384_dimensions(embedder):
    """A query embedding must contain exactly 384 components."""
    expected_dim = 384
    embedding = embedder.embed_query("documentary about war")
    assert len(embedding) == expected_dim
def test_embed_query_is_normalized(embedder):
    """The Euclidean (L2) length of a query embedding must be 1 within 1e-4."""
    embedding = embedder.embed_query("human rights in Africa")
    # Square root of the sum of squared components, computed step by step.
    sum_of_squares = sum(component ** 2 for component in embedding)
    length = sum_of_squares ** 0.5
    assert length == pytest.approx(1.0, abs=1e-4)
def test_embed_query_repeated_calls_return_identical_results(embedder):
    """Embedding the same text twice must yield equal vectors."""
    query = "award winning photography"
    # Two back-to-back calls with the very same input string.
    first, second = [embedder.embed_query(query) for _ in range(2)]
    assert first == second
def test_embed_query_cache_avoids_recomputation(embedder):
    """Check that a repeated query does not invoke ``embed_texts`` again.

    The query cache is cleared, ``embed_texts`` is temporarily replaced by a
    call-recording spy that forwards to the real method, and the same query
    is embedded twice.  Exactly one call is expected to reach ``embed_texts``.
    """
    from unittest import mock

    # Clear the cache so an entry left behind by an earlier test using the
    # shared module-scoped fixture cannot satisfy the first lookup.
    embedder._embed_query_cached.cache_clear()

    # patch.object installs the spy on the instance and removes it again on
    # exit (also when embed_query raises), so no stray instance attribute is
    # left shadowing the real method for later tests.
    with mock.patch.object(
        embedder, "embed_texts", wraps=embedder.embed_texts
    ) as spy:
        embedder.embed_query("cache test query")
        embedder.embed_query("cache test query")

    call_count = spy.call_count
    assert call_count == 1, f"embed_texts called {call_count} times, expected 1"
def test_embed_query_different_texts_produce_different_vectors(embedder):
    """Two unrelated query texts must not map to the same embedding."""
    queries = ("war documentary", "cooking show")
    war_vec, cooking_vec = map(embedder.embed_query, queries)
    assert war_vec != cooking_vec
def test_embed_query_matches_embed_texts_output(embedder):
    """embed_query(text) must agree with the first item of embed_texts([text])."""
    query = "journalism and press freedom"
    # The reference value comes straight from embed_texts, so it is computed
    # without going through embed_query's cache.
    reference = embedder.embed_texts([query])[0]
    actual = embedder.embed_query(query)
    assert actual == pytest.approx(reference, abs=1e-6)
|