File size: 4,131 Bytes
c3f5513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Tests for the caching module."""

import pytest
import numpy as np
from unittest.mock import AsyncMock, MagicMock, patch

from cascade.cache.embeddings import EmbeddingService, cosine_similarity
from cascade.cache.redis_client import RedisClient


class TestEmbeddingService:
    """Checks on the fallback embedding path of ``EmbeddingService``."""

    @staticmethod
    def _fallback_service():
        """Return an ``EmbeddingService`` with the fallback flag set and no model."""
        svc = EmbeddingService()
        svc._use_fallback = True
        svc._model = None
        return svc

    def test_fallback_embedding_deterministic(self):
        """Embedding the same text twice must yield equal arrays."""
        svc = self._fallback_service()

        first, second = [svc._fallback_embed("Hello world") for _ in range(2)]

        np.testing.assert_array_equal(first, second)

    def test_fallback_embedding_normalized(self):
        """The fallback embedding's L2 norm must be within 0.01 of 1.0."""
        vector = self._fallback_service()._fallback_embed("Test query")

        assert abs(np.linalg.norm(vector) - 1.0) < 0.01

    def test_different_texts_different_embeddings(self):
        """Two distinct input strings must not map to (near-)equal embeddings."""
        svc = self._fallback_service()

        hello_vec = svc._fallback_embed("Hello")
        goodbye_vec = svc._fallback_embed("Goodbye")

        assert not np.allclose(hello_vec, goodbye_vec)


class TestCosineSimilarity:
    """Checks ``cosine_similarity`` against a few reference vector pairs."""

    def test_identical_vectors(self):
        """A vector compared with itself scores 1.0 (within 0.001)."""
        v = np.array([1.0, 2.0, 3.0])
        score = cosine_similarity(v, v)
        assert abs(score - 1.0) < 0.001

    def test_orthogonal_vectors(self):
        """Two perpendicular unit axes score 0.0 (within 0.001)."""
        x_axis = np.array([1.0, 0.0, 0.0])
        y_axis = np.array([0.0, 1.0, 0.0])
        score = cosine_similarity(x_axis, y_axis)
        assert abs(score) < 0.001

    def test_opposite_vectors(self):
        """A vector and its negation score -1.0 (within 0.001)."""
        forward = np.array([1.0, 2.0, 3.0])
        backward = -forward
        score = cosine_similarity(forward, backward)
        assert abs(score + 1.0) < 0.001

    def test_zero_vector(self):
        """A zero vector as the second argument gives exactly 0.0."""
        nonzero = np.array([1.0, 2.0, 3.0])
        zeros = np.zeros(3)
        score = cosine_similarity(nonzero, zeros)
        assert score == 0.0


class TestRedisClient:
    """Checks ``RedisClient`` key construction and its cache write/invalidate calls."""

    def test_make_key_deterministic(self):
        """Building a key twice from the same arguments gives the same key."""
        client = RedisClient()
        key_args = ("prefix", "query", "model")

        first = client._make_key(*key_args)
        second = client._make_key(*key_args)

        assert first == second

    def test_make_key_different_inputs(self):
        """Changing only the query argument must change the key."""
        client = RedisClient()

        first, second = [
            client._make_key("prefix", query, "model")
            for query in ("query1", "query2")
        ]

        assert first != second

    def test_make_key_format(self):
        """The key starts with ``<prefix>:`` and its second segment has 16 chars."""
        key = RedisClient()._make_key("cascade", "hello", "gpt-4o")

        assert key.startswith("cascade:")
        digest = key.split(":")[1]
        assert len(digest) == 16  # SHA256 truncated to 16 chars

    @pytest.mark.asyncio
    async def test_cache_response_and_get(self, mock_redis_client):
        """Caching a response issues exactly one ``setex`` on the backend.

        Only the write path is verified here; retrieval is not exercised.
        """
        client = RedisClient()
        client._client = mock_redis_client
        payload = {"content": "test response"}

        await client.cache_response("query", "model", payload)

        mock_redis_client.setex.assert_called_once()

    @pytest.mark.asyncio
    async def test_invalidate_cache(self, mock_redis_client):
        """Invalidating issues exactly one ``delete`` and returns ``True``."""
        client = RedisClient()
        client._client = mock_redis_client

        outcome = await client.invalidate("query", "model")

        mock_redis_client.delete.assert_called_once()
        assert outcome is True