ayushm98 commited on
Commit
c3f5513
·
1 Parent(s): d4faa2c

Add tests for caching and embeddings

Browse files
Files changed (1) hide show
  1. tests/test_cache.py +126 -0
tests/test_cache.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the caching module."""
2
+
3
+ import pytest
4
+ import numpy as np
5
+ from unittest.mock import AsyncMock, MagicMock, patch
6
+
7
+ from cascade.cache.embeddings import EmbeddingService, cosine_similarity
8
+ from cascade.cache.redis_client import RedisClient
9
+
10
+
11
+ class TestEmbeddingService:
12
+ """Tests for the embedding service."""
13
+
14
+ def test_fallback_embedding_deterministic(self):
15
+ """Fallback embeddings should be deterministic."""
16
+ service = EmbeddingService()
17
+ service._use_fallback = True
18
+ service._model = None
19
+
20
+ text = "Hello world"
21
+ emb1 = service._fallback_embed(text)
22
+ emb2 = service._fallback_embed(text)
23
+
24
+ np.testing.assert_array_equal(emb1, emb2)
25
+
26
+ def test_fallback_embedding_normalized(self):
27
+ """Fallback embeddings should be normalized."""
28
+ service = EmbeddingService()
29
+ service._use_fallback = True
30
+ service._model = None
31
+
32
+ text = "Test query"
33
+ embedding = service._fallback_embed(text)
34
+ norm = np.linalg.norm(embedding)
35
+
36
+ assert abs(norm - 1.0) < 0.01
37
+
38
+ def test_different_texts_different_embeddings(self):
39
+ """Different texts should produce different embeddings."""
40
+ service = EmbeddingService()
41
+ service._use_fallback = True
42
+ service._model = None
43
+
44
+ emb1 = service._fallback_embed("Hello")
45
+ emb2 = service._fallback_embed("Goodbye")
46
+
47
+ assert not np.allclose(emb1, emb2)
48
+
49
+
50
+ class TestCosineSimilarity:
51
+ """Tests for cosine similarity function."""
52
+
53
+ def test_identical_vectors(self):
54
+ """Identical vectors should have similarity 1.0."""
55
+ vec = np.array([1.0, 2.0, 3.0])
56
+ assert abs(cosine_similarity(vec, vec) - 1.0) < 0.001
57
+
58
+ def test_orthogonal_vectors(self):
59
+ """Orthogonal vectors should have similarity 0.0."""
60
+ vec1 = np.array([1.0, 0.0, 0.0])
61
+ vec2 = np.array([0.0, 1.0, 0.0])
62
+ assert abs(cosine_similarity(vec1, vec2)) < 0.001
63
+
64
+ def test_opposite_vectors(self):
65
+ """Opposite vectors should have similarity -1.0."""
66
+ vec1 = np.array([1.0, 2.0, 3.0])
67
+ vec2 = np.array([-1.0, -2.0, -3.0])
68
+ assert abs(cosine_similarity(vec1, vec2) + 1.0) < 0.001
69
+
70
+ def test_zero_vector(self):
71
+ """Zero vectors should return 0.0 similarity."""
72
+ vec1 = np.array([1.0, 2.0, 3.0])
73
+ vec2 = np.array([0.0, 0.0, 0.0])
74
+ assert cosine_similarity(vec1, vec2) == 0.0
75
+
76
+
77
+ class TestRedisClient:
78
+ """Tests for the Redis client."""
79
+
80
+ def test_make_key_deterministic(self):
81
+ """Cache keys should be deterministic."""
82
+ client = RedisClient()
83
+
84
+ key1 = client._make_key("prefix", "query", "model")
85
+ key2 = client._make_key("prefix", "query", "model")
86
+
87
+ assert key1 == key2
88
+
89
+ def test_make_key_different_inputs(self):
90
+ """Different inputs should produce different keys."""
91
+ client = RedisClient()
92
+
93
+ key1 = client._make_key("prefix", "query1", "model")
94
+ key2 = client._make_key("prefix", "query2", "model")
95
+
96
+ assert key1 != key2
97
+
98
+ def test_make_key_format(self):
99
+ """Keys should have correct format."""
100
+ client = RedisClient()
101
+ key = client._make_key("cascade", "hello", "gpt-4o")
102
+
103
+ assert key.startswith("cascade:")
104
+ assert len(key.split(":")[1]) == 16 # SHA256 truncated to 16 chars
105
+
106
+ @pytest.mark.asyncio
107
+ async def test_cache_response_and_get(self, mock_redis_client):
108
+ """Should be able to cache and retrieve responses."""
109
+ client = RedisClient()
110
+ client._client = mock_redis_client
111
+
112
+ response = {"content": "test response"}
113
+ await client.cache_response("query", "model", response)
114
+
115
+ mock_redis_client.setex.assert_called_once()
116
+
117
+ @pytest.mark.asyncio
118
+ async def test_invalidate_cache(self, mock_redis_client):
119
+ """Should be able to invalidate cached entries."""
120
+ client = RedisClient()
121
+ client._client = mock_redis_client
122
+
123
+ result = await client.invalidate("query", "model")
124
+
125
+ mock_redis_client.delete.assert_called_once()
126
+ assert result is True