Spaces:
Sleeping
Sleeping
| from app.embeddings import average_pool, embed_text | |
| import torch | |
| import pytest | |
def test_average_pool_basic() -> None:
    """Check that average_pool averages only the positions the mask keeps."""
    # Two sequences of three 2-dimensional token vectors: shape (2, 3, 2).
    first_sequence = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    second_sequence = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]
    hidden = torch.tensor([first_sequence, second_sequence])

    # Mask of shape (2, 3): the first sequence keeps two tokens,
    # the second keeps only one.
    mask = torch.tensor([[1, 1, 0], [1, 0, 0]])

    pooled = average_pool(hidden, mask)

    # Sequence 1 averages its first two tokens: [(1+3)/2, (2+4)/2] = [2, 3].
    # Sequence 2 keeps only its first token: [10, 20].
    want = torch.tensor([[2.0, 3.0], [10.0, 20.0]])
    assert torch.allclose(pooled, want, atol=1e-6)
    assert pooled.shape == (2, 2)
def test_embed_text_valid() -> None:
    """Test that embed_text returns one list of 1024 floats per input text."""
    texts = ["query: Hello world", "query: Hej verden"]
    expected_dim = 1024

    embeddings = embed_text(texts)

    assert isinstance(embeddings, list)
    assert len(embeddings) == len(texts)
    # Validate every returned vector, not just the first one, so that a
    # malformed embedding for a later input cannot slip through unnoticed.
    for vec in embeddings:
        assert isinstance(vec, list)
        assert len(vec) == expected_dim
        assert all(isinstance(x, float) for x in vec)
def test_embed_text_empty_list() -> None:
    """Verify that embed_text raises ValueError when given no texts."""
    no_texts: list[str] = []
    with pytest.raises(ValueError, match="No input texts provided"):
        embed_text(no_texts)
def test_embed_text_too_long() -> None:
    """Verify that a 2001-character input is rejected with ValueError.

    The input is one character past the 2000-character limit this test
    assumes embed_text enforces.
    """
    prefix = "query: "  # 7 characters
    padding = "a" * 1994
    oversized = [prefix + padding]  # 7 + 1994 = 2001 characters in total
    with pytest.raises(ValueError, match="exceed the maximum length"):
        embed_text(oversized)