"""
app/memory/semantic_cache.py β€” Semantic (meaning-aware) cache for MathMinds AI.

Architecture
────────────
Exact hash cache (Redis)          ← microseconds, free
        ↓ MISS
Semantic vector cache (Redis)     ← ~50ms, free (embedding stored in Redis)
        ↓ MISS
Gemini API call                   ← costs 1 quota unit

Why two layers?
  - Exact cache: zero cost, handles identical repeated questions instantly.
  - Semantic cache: handles paraphrases. Uses Google's gemini-embedding-001
    to embed both the query and stored questions, then finds nearest neighbour
    by cosine similarity. Entirely self-contained in Redis β€” no Supabase needed.
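
Illustrative orchestrator wiring (a sketch only; exact_cache and
call_gemini are hypothetical names, not part of this module):

    answer = exact_cache.get(query)             # layer 1: exact hash lookup
    if answer is None:
        hit = semantic_cache.get(query)         # layer 2: vector similarity
        answer = hit["answer"] if hit else None
    if answer is None:
        answer = call_gemini(query)             # layer 3: paid API call
        semantic_cache.set(query, answer)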

Redis key design
  semantic:index         β†’ Redis Set  β€” all embedding keys
  semantic:emb:{hash}    β†’ JSON {query, embedding, answer, metadata, timestamp}
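
  Example value (illustrative; hash and embedding truncated):
    semantic:emb:9f2ab3... → {"query": "what is 2+2", "answer": "4",
                              "metadata": {}, "embedding": [0.013, ...],
                              "timestamp": 1718000000.0}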

Similarity threshold: 0.85
  - 0.85+ β†’ same mathematical question, different words  (safe to return)
  - 0.70-0.85 β†’ related topic, probably different question (skip)
  - <0.70 β†’ unrelated

Quota cost of embeddings
  gemini-embedding-001 is NOT counted against the generate_content quota.
  It has its own free tier: 1500 requests/day β€” far more than the 20/day
  generate limit, so semantic lookup is essentially free to run.
"""

import json
import logging
import hashlib
import re
import time
import math
from typing import Optional, Dict, Any, List

logger = logging.getLogger(__name__)

# ── Similarity threshold ───────────────────────────────────────────────────
# Tested against math paraphrase pairs. Lower = more aggressive matching.
SIMILARITY_THRESHOLD = 0.85

# Redis key prefixes
_PREFIX_EMB   = "semantic:emb:"    # stores embedding + answer
_INDEX_KEY    = "semantic:index"   # set of all embedding hashes
_TTL_SECONDS  = 7 * 24 * 3600     # 7 days


def _cosine_similarity(a: List[float], b: List[float]) -> float:
    """Pure-Python cosine similarity. No numpy needed."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def _normalize_query(query: str) -> str:
    """
    Light normalization before embedding: lowercase and collapse whitespace.
    Punctuation and math symbols are left intact, so '2+2' and '2 + 2' remain
    distinct strings but should land in the same embedding region.
    """
    q = query.lower().strip()
    # collapse runs of whitespace into single spaces
    q = re.sub(r"\s+", " ", q)
    return q


class SemanticCache:
    """
    Semantic similarity cache backed by Redis.

    Usage (in orchestrator):
        sc = SemanticCache(redis_client, gemini_api_key)

        # Lookup
        result = sc.get(query)
        if result:
            return result["answer"]

        # Store after getting answer from API
        sc.set(query, answer_text, metadata)
    """

    def __init__(self, redis_client, gemini_api_key: str):
        self.redis  = redis_client
        self._api_key = gemini_api_key
        self._genai = None   # lazy init

    def _get_client(self):
        """Lazy-init google.genai client so import errors are surfaced clearly."""
        if self._genai is None:
            try:
                from google import genai
                self._genai = genai.Client(api_key=self._api_key)
            except Exception as e:
                logger.error(f"SemanticCache: failed to init genai client: {e}")
                raise
        return self._genai

    def _embed(self, text: str) -> Optional[List[float]]:
        """
        Generate embedding vector for text.
        Uses gemini-embedding-001 (NOT counted against generate_content quota).
        Returns None on failure so cache misses gracefully on API errors.
        """
        try:
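            # Note: per the Gemini embedding docs, only the full 3072-dim
            # output is pre-normalised; 768-dim vectors are not unit length.
            # _cosine_similarity divides by both norms, so this is handled.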
            from google.genai import types
            client = self._get_client()
            resp = client.models.embed_content(
                model="models/gemini-embedding-001",
                contents=_normalize_query(text),
                config=types.EmbedContentConfig(output_dimensionality=768),
            )
            return resp.embeddings[0].values
        except Exception as e:
            logger.warning(f"SemanticCache: embedding failed: {e}")
            return None

    def _query_hash(self, query: str) -> str:
        return hashlib.sha256(_normalize_query(query).encode()).hexdigest()[:16]

    # ── Public API ─────────────────────────────────────────────────────────

    def get(self, query: str) -> Optional[Dict[str, Any]]:
        """
        Look up a semantically similar cached answer.

        Returns dict with keys: answer, metadata, source, similarity
        Returns None on cache miss or any error.
        """
        if not self.redis:
            return None

        try:
            # Get all stored embedding keys
            keys = self.redis.smembers(_INDEX_KEY)
            if not keys:
                return None

            # Embed the incoming query
            query_vec = self._embed(query)
            if query_vec is None:
                return None

            best_score = 0.0
            best_entry = None

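            # Linear scan with one GET per stored key: fine while the cache
            # is small. MGET (or a pipeline) would batch the round-trips if
            # the index grows.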
            for key in keys:
                # redis-py returns bytes unless decode_responses=True; decode
                # so the key prefix concatenation matches what set() stored.
                if isinstance(key, bytes):
                    key = key.decode()
                raw = self.redis.get(f"{_PREFIX_EMB}{key}")
                if not raw:
                    # Value expired but its hash lingers in the index; prune.
                    self.redis.srem(_INDEX_KEY, key)
                    continue
                try:
                    entry = json.loads(raw)
                except json.JSONDecodeError:
                    continue

                stored_vec = entry.get("embedding")
                if not stored_vec:
                    continue

                score = _cosine_similarity(query_vec, stored_vec)
                if score > best_score:
                    best_score = score
                    best_entry = entry

            if best_score >= SIMILARITY_THRESHOLD and best_entry:
                logger.info(
                    f"SemanticCache HIT | similarity={best_score:.3f} | "
                    f"query='{query[:60]}' matched '{best_entry.get('query','')[:60]}'"
                )
                return {
                    "answer":     best_entry["answer"],
                    "metadata":   best_entry.get("metadata", {}),
                    "source":     "semantic_cache",
                    "similarity": round(best_score, 3),
                }

            logger.debug(f"SemanticCache MISS | best_score={best_score:.3f} | query='{query[:60]}'")
            return None

        except Exception as e:
            logger.error(f"SemanticCache.get failed: {e}")
            return None

    def set(self, query: str, answer: str, metadata: Optional[Dict] = None) -> bool:
        """
        Store a query+answer with its embedding vector.
        Silent on failure β€” caching is best-effort.
        """
        if not self.redis or not answer:
            return False

        try:
            embedding = self._embed(query)
            if embedding is None:
                return False

            key = self._query_hash(query)
            entry = {
                "query":     _normalize_query(query),
                "answer":    answer,
                "metadata":  metadata or {},
                "embedding": embedding,
                "timestamp": time.time(),
            }
            self.redis.setex(
                f"{_PREFIX_EMB}{key}",
                _TTL_SECONDS,
                json.dumps(entry),
            )
            self.redis.sadd(_INDEX_KEY, key)
            self.redis.expire(_INDEX_KEY, _TTL_SECONDS)
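            # The index TTL is refreshed on every write, so the set can
            # outlive individual entries; get() prunes stale hashes lazily.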

            logger.info(f"SemanticCache SET | key={key} | query='{query[:60]}'")
            return True

        except Exception as e:
            logger.error(f"SemanticCache.set failed: {e}")
            return False

    def invalidate(self, query: str) -> bool:
        """Remove a specific entry (e.g. if answer was wrong)."""
        if not self.redis:
            return False
        try:
            key = self._query_hash(query)
            self.redis.delete(f"{_PREFIX_EMB}{key}")
            self.redis.srem(_INDEX_KEY, key)
            return True
        except Exception as e:
            logger.error(f"SemanticCache.invalidate failed: {e}")
            return False

    def stats(self) -> Dict[str, Any]:
        """How many entries are cached."""
        try:
            count = self.redis.scard(_INDEX_KEY) if self.redis else 0
            return {"entries": count, "threshold": SIMILARITY_THRESHOLD}
        except Exception:
            return {"entries": 0, "threshold": SIMILARITY_THRESHOLD}
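

if __name__ == "__main__":
    # Minimal smoke test for the pure-Python helpers; needs no Redis or
    # API key. Values are illustrative.
    assert _cosine_similarity([1.0, 0.0], [2.0, 0.0]) == 1.0   # same direction
    assert _cosine_similarity([1.0, 0.0], [0.0, 1.0]) == 0.0   # orthogonal
    assert _normalize_query("  What IS   2+2? ") == "what is 2+2?"
    print("semantic_cache helpers OK")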