import logging
import json
import hashlib
from typing import Optional, Any

logger = logging.getLogger(__name__)


class SemanticCache:
    """
    Semantic Cache powered by Redis and sentence-transformers.

    Recommended redis.conf / Redis server settings:
        maxmemory 240mb
        maxmemory-policy allkeys-lfu
        lfu-decay-time 5
        lfu-log-factor 10

    Automatically disables itself when Redis or ML dependencies are unavailable.
    """

    def __init__(self, redis_url: Optional[str] = None, similarity_threshold: float = 0.95):
        self.enabled = False
        self.similarity_threshold = similarity_threshold
        self.redis: Any = None
        self.model: Any = None
        self.cosine_similarity: Any = None
        self.np: Any = None

        if not redis_url:
            logger.info("SemanticCache: No Redis URL provided. Cache disabled.")
            return

        # Try connecting to Redis
        try:
            import redis  # type: ignore
            self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
            self.redis.ping()
        except ImportError:
            logger.warning("SemanticCache: 'redis' package not installed. Cache disabled.")
            return
        except Exception as e:
            logger.warning(f"SemanticCache: Failed to connect to Redis at {redis_url}: {e}")
            self.redis = None
            return

        # Try loading sentence-transformers + sklearn
        try:
            from sentence_transformers import SentenceTransformer  # type: ignore
            import numpy as np  # type: ignore
            from sklearn.metrics.pairwise import cosine_similarity  # type: ignore

            self.cosine_similarity = cosine_similarity
            self.np = np
            logger.info("SemanticCache: Loading embedding model (all-MiniLM-L6-v2)...")
            self.model = SentenceTransformer("all-MiniLM-L6-v2")
            self.enabled = True
            logger.info("SemanticCache: Successfully initialized and connected to Redis!")
        except ImportError:
            logger.warning(
                "SemanticCache: 'sentence-transformers' or 'scikit-learn' not installed. Cache disabled."
            )
            self.redis = None
        except Exception as e:
            logger.warning(f"SemanticCache: Failed to load ML models: {e}")
            self.redis = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _is_within_memory_limit(self, safety_ratio: float = 0.90) -> bool:
        """
        Returns False when Redis has consumed >= safety_ratio of its maxmemory.
        Prevents new writes from pushing Redis over its hard memory ceiling
        (the 240 MB maxmemory recommended above leaves headroom under the
        250 MB hard limit).
        Fails open (returns True) if the info call itself errors.
        """
        try:
            info = self.redis.info("memory")
            used = info["used_memory"]
            max_mem = info.get("maxmemory", 0)
            if max_mem == 0:
                # No maxmemory configured — rely solely on allkeys-lfu eviction.
                return True
            within = (used / max_mem) < safety_ratio
            if not within:
                logger.warning(
                    f"SemanticCache: Memory at {used / max_mem:.1%} of limit "
                    f"({used / 1_048_576:.1f} MB / {max_mem / 1_048_576:.1f} MB). "
                    "Skipping write."
                )
            return within
        except Exception as e:
            logger.warning(f"SemanticCache: Memory check failed (failing open): {e}")
            return True

    @staticmethod
    def _cache_key(query: str) -> str:
        """Stable, cross-process MD5 key for a query string."""
        query_hash = hashlib.md5(query.encode("utf-8")).hexdigest()
        return f"llmopt:cache:{query_hash}"

    @staticmethod
    def _ttl_for_response(response: str) -> int:
        """
        Longer, richer responses get a longer TTL — they are more expensive to
        regenerate and therefore more valuable to keep around.

        > 500 chars  →  7 days   (604 800 s)
        ≤ 500 chars  →  3 days   (259 200 s)
        """
        return 604_800 if len(response) > 500 else 259_200

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get(self, query: str) -> Optional[str]:
        """
        Return the cached LLM response for a semantically similar query, or
        None on a cache miss.

        Uses a Redis pipeline to fetch all cached entries in a single round
        trip instead of one GET per key, keeping network overhead low even as
        the cache grows.
        """
        if not self.enabled:
            return None

        try:
            query_embedding = self.model.encode([query])[0]

            # SCAN instead of KEYS: KEYS is O(N) and blocks the server while it
            # runs, whereas SCAN walks the keyspace incrementally.
            keys = list(self.redis.scan_iter(match="llmopt:cache:*"))
            if not keys:
                return None

            # Batch-fetch all entries in one round trip
            pipe = self.redis.pipeline()
            for key in keys:
                pipe.get(key)
            results = pipe.execute()

            best_response = None
            highest_sim = -1.0

            for data_str in results:
                if not data_str:
                    continue
                try:
                    data = json.loads(data_str)
                except (json.JSONDecodeError, TypeError):
                    continue  # skip malformed entries rather than aborting the lookup
                cached_emb = self.np.array(data["embedding"])
                sim = self.cosine_similarity([query_embedding], [cached_emb])[0][0]
                if sim > highest_sim:
                    highest_sim = sim
                    best_response = data["response"]

            if best_response is not None and highest_sim >= self.similarity_threshold:
                # Return the stored response directly; skipping the second GET
                # means an entry that expires mid-lookup can't crash the hit path.
                logger.info(f"SemanticCache HIT! Similarity: {highest_sim:.3f}")
                return best_response

        except Exception as e:
            logger.warning(f"SemanticCache GET error: {e}")

        return None

    def set(self, query: str, response: str) -> None:
        """
        Embed and store a query/response pair.

        Skips the write when Redis is near its memory ceiling so that the
        allkeys-lfu policy never has to evict a hot entry just to absorb a
        brand-new one.
        """
        if not self.enabled:
            return

        # Guard: don't write when we are close to the 250 MB limit
        if not self._is_within_memory_limit(safety_ratio=0.90):
            return

        try:
            query_embedding = self.model.encode([query])[0]
            key = self._cache_key(query)
            ttl = self._ttl_for_response(response)

            data = {
                "query": query,
                "embedding": query_embedding.tolist(),
                "response": response,
            }

            # SET with EX writes the value and its TTL in one atomic command,
            # saving a round trip versus a pipelined SET + EXPIRE.
            self.redis.set(key, json.dumps(data), ex=ttl)

            logger.debug(
                f"SemanticCache SET: key={key} ttl={ttl}s "
                f"response_len={len(response)}"
            )

        except Exception as e:
            logger.warning(f"SemanticCache SET error: {e}")