# mathstutor / app/memory/semantic_cache.py
# ghadgemadhuri92 — "Sub-agent with native grounding" (3a4bdd3)
"""
app/memory/semantic_cache.py — Semantic (meaning-aware) cache for MathMinds AI.
Architecture
────────────
Exact hash cache (Redis) ← microseconds, free
↓ MISS
Semantic vector cache (Redis) ← ~50ms, free (embedding stored in Redis)
↓ MISS
Gemini API call ← costs 1 quota unit
Why two layers?
- Exact cache: zero cost, handles identical repeated questions instantly.
- Semantic cache: handles paraphrases. Uses Google's gemini-embedding-001
to embed both the query and stored questions, then finds nearest neighbour
by cosine similarity. Entirely self-contained in Redis — no Supabase needed.
Redis key design
semantic:index → Redis Set — all embedding keys
semantic:emb:{hash} → JSON {query, embedding, answer, metadata, timestamp}
Similarity threshold: 0.85
- 0.85+ β†’ same mathematical question, different words (safe to return)
- 0.70-0.85 β†’ related topic, probably different question (skip)
- <0.70 β†’ unrelated
Quota cost of embeddings
gemini-embedding-001 is NOT counted against the generate_content quota.
It has its own free tier: 1500 requests/day β€” far more than the 20/day
generate limit, so semantic lookup is essentially free to run.
"""
import json
import logging
import hashlib
import time
import math
from typing import Optional, Dict, Any, List, Tuple
logger = logging.getLogger(__name__)

# ── Similarity threshold ───────────────────────────────────────────────────
# Minimum cosine similarity for a cached answer to be reused.
# Tested against math paraphrase pairs. Lower = more aggressive matching.
SIMILARITY_THRESHOLD = 0.85

# Redis key layout
_PREFIX_EMB = "semantic:emb:"  # entry key prefix; the suffix is the query hash
_INDEX_KEY = "semantic:index"  # Redis Set holding every stored query hash
_TTL_SECONDS = 60 * 60 * 24 * 7  # one week
def _cosine_similarity(a: List[float], b: List[float]) -> float:
    """
    Return the cosine similarity of two equal-length vectors (pure Python).

    Returns 0.0 (treated by callers as "no match") when:
      - either vector has zero norm, which includes empty input; or
      - the vectors have different lengths.

    Vectors of different lengths come from different embedding configurations
    (e.g. another ``output_dimensionality``) and are not comparable. Without
    the length check, ``zip`` would silently truncate the dot product while
    the norms still covered the full vectors, producing a meaningless score.
    """
    if len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)
def _normalize_query(query: str) -> str:
    """
    Light normalization applied before hashing and embedding.

    The text is lower-cased, leading and trailing whitespace is stripped, and
    each internal whitespace run collapses to a single space. So
    "What is  2+2?" and "what is 2+2?" share a hash key.

    Punctuation and math symbols are NOT touched. "2+2" and "2 + 2" stay
    distinct strings; only embedding similarity can bring them together.
    """
    # With no argument, str.split() splits on runs of str.isspace() whitespace
    # (the same class re's \s matches for str patterns) and drops empty
    # pieces. This is equivalent to re.sub(r"\s+", " ", q.strip()) without a
    # per-call import or regex.
    return " ".join(query.lower().split())
class SemanticCache:
    """
    Semantic similarity cache backed by Redis.

    Each stored entry holds the normalized query, its embedding vector, the
    answer, optional metadata and a timestamp. A lookup embeds the incoming
    query and scans every indexed entry for the highest cosine similarity.
    A best score >= SIMILARITY_THRESHOLD is a hit.

    Usage (in orchestrator):
        sc = SemanticCache(redis_client, gemini_api_key)
        # Lookup
        result = sc.get(query)
        if result:
            return result["answer"]
        # Store after getting answer from API
        sc.set(query, answer_text, metadata)

    Public methods are best-effort. Redis and embedding errors are logged and
    reported as a miss (None), False, or zero entries, rather than raised.
    """

    def __init__(self, redis_client, gemini_api_key: str):
        # redis_client may be falsy (e.g. None). Public methods then return
        # None / False / zero entries without touching Redis.
        self.redis = redis_client
        self._api_key = gemini_api_key
        self._genai = None  # google.genai.Client, created by _get_client()

    def _get_client(self):
        """Lazy-init google.genai client so import errors are surfaced clearly."""
        if self._genai is None:
            try:
                from google import genai
                self._genai = genai.Client(api_key=self._api_key)
            except Exception as e:
                logger.error(f"SemanticCache: failed to init genai client: {e}")
                raise
        return self._genai

    def _embed(self, text: str) -> Optional[List[float]]:
        """
        Generate an embedding vector for text.

        The text is normalized first and embedded with gemini-embedding-001
        at 768 dimensions. Returns None on any failure so the cache misses
        gracefully on API errors.
        """
        try:
            from google.genai import types
            client = self._get_client()
            resp = client.models.embed_content(
                model="models/gemini-embedding-001",
                contents=_normalize_query(text),
                config=types.EmbedContentConfig(output_dimensionality=768),
            )
            return resp.embeddings[0].values
        except Exception as e:
            logger.warning(f"SemanticCache: embedding failed: {e}")
            return None

    def _query_hash(self, query: str) -> str:
        """Return the 16-hex-char SHA-256 prefix of the normalized query (the entry key suffix)."""
        return hashlib.sha256(_normalize_query(query).encode()).hexdigest()[:16]

    @staticmethod
    def _decode_key(key) -> str:
        """
        Return an index-set member as str.

        redis-py returns bytes from SMEMBERS unless the client was created
        with decode_responses=True. Interpolating bytes into an f-string would
        produce "semantic:emb:b'...'", a key that never exists.
        """
        if isinstance(key, (bytes, bytearray)):
            return key.decode()
        return str(key)

    def _prune_index(self, keys: List[str]) -> None:
        """Best-effort removal of index members whose entry keys no longer exist."""
        # NOTE: if a concurrent set() re-stores one of these hashes between our
        # GET and this SREM, that entry loses its index member. It stays
        # unreachable until the same query is set() again or the entry
        # expires. That costs only cache misses, never a wrong answer.
        try:
            self.redis.srem(_INDEX_KEY, *keys)
        except Exception as e:
            logger.debug(f"SemanticCache: failed to prune stale index keys: {e}")

    # ── Public API ─────────────────────────────────────────────────────────
    def get(self, query: str) -> Optional[Dict[str, Any]]:
        """
        Look up a semantically similar cached answer.

        Args:
            query: Raw user question (normalized before embedding).

        Returns:
            When the best-scoring entry reaches SIMILARITY_THRESHOLD, a dict
            with keys ``answer``, ``metadata``, ``source`` (always
            "semantic_cache") and ``similarity`` (rounded to 3 decimals).
            Otherwise None, also on any error.

        Index members whose entries have expired are pruned as a side effect.
        """
        if not self.redis:
            return None
        try:
            # Get all stored embedding keys
            keys = self.redis.smembers(_INDEX_KEY)
            if not keys:
                return None
            # Embed only once we know there is something to compare against,
            # so an empty cache costs no embedding call.
            query_vec = self._embed(query)
            if query_vec is None:
                return None
            best_score = 0.0
            best_entry = None
            stale: List[str] = []
            for raw_key in keys:
                key = self._decode_key(raw_key)
                raw = self.redis.get(f"{_PREFIX_EMB}{key}")
                if not raw:
                    # The entry expired via its own setex TTL, but set()
                    # refreshes the index TTL, so the member outlived it.
                    stale.append(key)
                    continue
                try:
                    entry = json.loads(raw)
                except ValueError:  # JSONDecodeError or undecodable bytes
                    continue
                if not isinstance(entry, dict):
                    continue
                stored_vec = entry.get("embedding")
                # An entry without an answer cannot be served. Skip it rather
                # than let it win and abort the lookup.
                if not stored_vec or not entry.get("answer"):
                    continue
                score = _cosine_similarity(query_vec, stored_vec)
                if score > best_score:
                    best_score = score
                    best_entry = entry
            if stale:
                self._prune_index(stale)
            if best_score >= SIMILARITY_THRESHOLD and best_entry:
                logger.info(
                    f"SemanticCache HIT | similarity={best_score:.3f} | "
                    f"query='{query[:60]}' matched '{best_entry.get('query','')[:60]}'"
                )
                return {
                    "answer": best_entry["answer"],
                    "metadata": best_entry.get("metadata", {}),
                    "source": "semantic_cache",
                    "similarity": round(best_score, 3),
                }
            logger.debug(f"SemanticCache MISS | best_score={best_score:.3f} | query='{query[:60]}'")
            return None
        except Exception as e:
            logger.error(f"SemanticCache.get failed: {e}")
            return None

    def set(self, query: str, answer: str, metadata: Optional[Dict] = None) -> bool:
        """
        Store a query+answer together with its embedding vector.

        The entry key expires after _TTL_SECONDS. The shared index set's TTL
        is reset to _TTL_SECONDS on every call.

        Returns:
            True if stored. False when Redis is unavailable, the answer is
            empty, embedding fails, or Redis raises. Silent on failure,
            because caching is best-effort.
        """
        if not self.redis or not answer:
            return False
        try:
            embedding = self._embed(query)
            if embedding is None:
                return False
            key = self._query_hash(query)
            entry = {
                "query": _normalize_query(query),
                "answer": answer,
                "metadata": metadata or {},
                "embedding": embedding,
                "timestamp": time.time(),
            }
            self.redis.setex(
                f"{_PREFIX_EMB}{key}",
                _TTL_SECONDS,
                json.dumps(entry),
            )
            self.redis.sadd(_INDEX_KEY, key)
            self.redis.expire(_INDEX_KEY, _TTL_SECONDS)
            logger.info(f"SemanticCache SET | key={key} | query='{query[:60]}'")
            return True
        except Exception as e:
            logger.error(f"SemanticCache.set failed: {e}")
            return False

    def invalidate(self, query: str) -> bool:
        """
        Remove the entry for exactly this (normalized) query, e.g. if its answer was wrong.

        Returns False without touching Redis when no client is configured, or
        when Redis raises.
        """
        if not self.redis:
            return False
        try:
            key = self._query_hash(query)
            self.redis.delete(f"{_PREFIX_EMB}{key}")
            self.redis.srem(_INDEX_KEY, key)
            return True
        except Exception as e:
            logger.error(f"SemanticCache.invalidate failed: {e}")
            return False

    def stats(self) -> Dict[str, Any]:
        """
        Report the index-set size and the active similarity threshold.

        The count may include members whose entries have expired but have not
        yet been pruned by get().
        """
        try:
            count = self.redis.scard(_INDEX_KEY) if self.redis else 0
            return {"entries": count, "threshold": SIMILARITY_THRESHOLD}
        except Exception:
            return {"entries": 0, "threshold": SIMILARITY_THRESHOLD}