"""Tests for app.embeddings: average_pool masking and embed_text validation."""

import pytest
import torch

from app.embeddings import average_pool, embed_text


def test_average_pool_basic() -> None:
    """Test average pooling produces correct shape and masking."""
    last_hidden_states = torch.tensor(
        [
            [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
            [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]],
        ]
    )  # shape: (2, 3, 2)
    attention_mask = torch.tensor(
        [
            [1, 1, 0],
            [1, 0, 0],
        ]
    )  # shape: (2, 3)

    result = average_pool(last_hidden_states, attention_mask)

    # Check the shape first: if the shape is wrong, allclose could fail
    # opaquely (or falsely pass) via broadcasting.
    assert result.shape == (2, 2)
    # Expected averages over unmasked positions only:
    # row 1: [(1+3)/2, (2+4)/2] = [2, 3]
    # row 2: [10, 20]
    expected = torch.tensor([[2.0, 3.0], [10.0, 20.0]])
    assert torch.allclose(result, expected, atol=1e-6)


def test_embed_text_valid() -> None:
    """Test embedding returns correct number of vectors and dimensions."""
    texts = ["query: Hello world", "query: Hej verden"]

    embeddings = embed_text(texts)

    assert isinstance(embeddings, list)
    assert len(embeddings) == len(texts)
    # Validate every returned vector, not just the first one.
    for vec in embeddings:
        assert isinstance(vec, list)
        assert len(vec) == 1024
        assert all(isinstance(x, float) for x in vec)


def test_embed_text_empty_list() -> None:
    """Should raise ValueError if no input texts."""
    with pytest.raises(ValueError, match="No input texts provided"):
        embed_text([])


def test_embed_text_too_long() -> None:
    """Should raise ValueError for inputs exceeding 2000 characters."""
    too_long = ["query: " + "a" * 1994]  # 7 + 1994 = 2001 characters
    with pytest.raises(ValueError, match="exceed the maximum length"):
        embed_text(too_long)