"""Tests for EmbeddingEngine — TASK-001."""
import importlib.util

import numpy as np
import pytest

from apohara_context_forge.embeddings.embedding_engine import EmbeddingEngine
# Skip every test in this module when onnxruntime cannot be found.
# importlib.util.find_spec locates a top-level package without importing it,
# so the probe itself is cheap and has no side effects from onnxruntime.
_ONNXRUNTIME_AVAILABLE = importlib.util.find_spec("onnxruntime") is not None

# NOTE(review): this marker is evaluated only after the module-level import of
# EmbeddingEngine has succeeded. If embedding_engine imports onnxruntime at
# import time, collection will error before this skip applies. In that case,
# consider `pytest.importorskip("onnxruntime")` ahead of that import. TODO confirm.
pytestmark = pytest.mark.skipif(
    not _ONNXRUNTIME_AVAILABLE,
    reason="onnxruntime not installed — GPU/DevCloud environment required",
)
@pytest.fixture
async def engine():
"""Get EmbeddingEngine singleton."""
return await EmbeddingEngine.get_instance(dim=512, use_onnx=False)
class TestEmbeddingEngine:
    """Tests for EmbeddingEngine core functionality.

    Every test receives the ``engine`` fixture, which requests the singleton
    with ``dim=512``. ``EXPECTED_DIM`` mirrors that value.
    """

    # Must match the ``dim`` the ``engine`` fixture passes to get_instance().
    EXPECTED_DIM = 512

    @pytest.mark.asyncio
    async def test_get_instance_returns_singleton(self, engine):
        """get_instance() returns the same instance on repeated calls."""
        engine2 = await EmbeddingEngine.get_instance()
        assert engine is engine2

    @pytest.mark.asyncio
    async def test_encode_returns_normalized_vector(self, engine):
        """encode() returns a 1-D, L2-normalized embedding of the configured dim."""
        embedding = await engine.encode("test prompt")
        assert isinstance(embedding, np.ndarray)
        # Compare the full shape, not just axis 0, so a multi-dimensional
        # result cannot pass.
        assert embedding.shape == (self.EXPECTED_DIM,)
        # Measure the norm in float64 so the check reflects the vector itself,
        # not rounding in a lower-precision reduction. The engine's output
        # dtype is not pinned here; float32 is assumed possible.
        norm = np.linalg.norm(embedding.astype(np.float64))
        assert abs(norm - 1.0) < 1e-6

    @pytest.mark.asyncio
    async def test_encode_batch_returns_list(self, engine):
        """encode_batch() returns one 1-D embedding per input text, as a list."""
        texts = ["prompt one", "prompt two", "prompt three"]
        embeddings = await engine.encode_batch(texts)
        assert isinstance(embeddings, list)
        assert len(embeddings) == len(texts)
        for emb in embeddings:
            assert isinstance(emb, np.ndarray)
            assert emb.shape == (self.EXPECTED_DIM,)

    @pytest.mark.asyncio
    async def test_simhash_returns_int(self, engine):
        """simhash() returns an unsigned 64-bit integer."""
        token_ids = [101, 2003, 1996, 3007, 102]
        h = await engine.simhash(token_ids)
        assert isinstance(h, int)
        # Check both bounds so the 64-bit claim is actually verified.
        assert 0 <= h < 2**64

    @pytest.mark.asyncio
    async def test_simhash_deterministic(self, engine):
        """simhash() is deterministic for same input."""
        token_ids = [101, 2003, 1996, 3007, 102]
        h1 = await engine.simhash(token_ids)
        h2 = await engine.simhash(token_ids)
        assert h1 == h2

    @pytest.mark.asyncio
    async def test_simhash_different_for_different_inputs(self, engine):
        """simhash() returns different values for different token sequences."""
        h1 = await engine.simhash([101, 2003, 1996])
        h2 = await engine.simhash([101, 3007, 102])
        assert h1 != h2

    @pytest.mark.asyncio
    async def test_encode_caching(self, engine):
        """Encoding identical text twice yields matching embeddings.

        NOTE(review): this only shows repeated calls agree; it does not prove
        a cache hit occurred. Asserting that would need access to the engine's
        cache internals.
        """
        text = "shared system prompt"
        e1 = await engine.encode(text)
        e2 = await engine.encode(text)
        # np.allclose broadcasts its operands, so compare shapes explicitly
        # first. Otherwise e.g. (dim,) vs (1, dim) would count as "close".
        assert e1.shape == e2.shape
        assert np.allclose(e1, e2)