""" Semantic caching layer using RedisVL for vector similarity search. This module implements a semantic cache that stores LLM responses indexed by their embedding vectors, enabling retrieval of semantically similar queries without re-computing expensive LLM generations. Mathematical Foundations ------------------------ 1. Cosine Similarity for Semantic Matching: sim(u, v) = (u · v) / (||u||₂ · ||v||₂) ∈ [-1, 1] For L2-normalized vectors: sim(u, v) = u · v Cache hit threshold: τ_sim = 0.92 (empirically tuned) Reference: Reimers & Gurevych, "Sentence-BERT", EMNLP 2019 [1] 2. L2 Normalization for Vector Indexing: Given embedding e ∈ ℝᵈ: ê = e / (||e||₂ + ε) where ε = 1e-9 for numerical stability Ensures unit-norm vectors for consistent cosine distance computation. 3. Time-Based Expiration (TTL): Entry valid iff: current_time - timestamp ≤ ttl Provides automatic cache invalidation for stale responses. 4. Flat Index Search Complexity: Exact nearest neighbor search: O(N·d) where N=docs, d=embedding_dim Acceptable for N < 100K; consider HNSW for larger datasets. Reference: Malkov & Yashunin, "Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs" [2] References ---------- [1] Reimers, N., & Gurevych, I. (2019). Sentence-BERT: Sentence embeddings using Siamese BERT-networks. EMNLP-IJCNLP 2019. https://github.com/UKPLab/sentence-transformers [2] Malkov, Y. A., & Yashunin, D. A. (2020). Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://github.com/nmslib/hnswlib Performance Characteristics --------------------------- - _normalize(): O(d) for L2 normalization - search(): O(N·d) for flat index vector scan + O(1) for metadata filtering - store(): O(d) for normalization + O(log N) for Redis index insertion - clear(): O(N) for full index deletion or O(M) for domain-filtered scan Memory Footprint ---------------- - Per entry: d·4 bytes (float32 embeddings) + metadata overhead (~200-500B) - Example: d=384 → ~1.5KB/embedding + response text - For 10K entries: ~15-20MB for embeddings + variable for text Thread Safety ------------- - Redis client operations are thread-safe per redis-py documentation - Index queries and loads are atomic at Redis level - No shared mutable state beyond Redis connection pool Author: IntelliDeep Labs Team License: BSL 1.1 """ from __future__ import annotations import json import logging import time import uuid from typing import Dict, List, Optional import numpy as np from redis import Redis from redisvl.index import SearchIndex from redisvl.query import VectorQuery from redisvl.schema import IndexSchema logger = logging.getLogger(__name__) class SemanticLLMCache: """ Vector-based semantic cache for LLM responses using RedisVL. Stores prompt-response pairs indexed by embedding vectors, enabling retrieval of semantically similar queries without re-generating responses. Key Features ------------ - Cosine similarity search for semantic matching - Domain-based filtering for multi-tenant isolation - TTL-based automatic expiration for cache freshness - Hit/miss statistics for cache performance monitoring - Thread-safe operations via Redis connection pooling Usage Example ------------- >>> cache = SemanticLLMCache( ... redis_url="redis://localhost:6379", ... similarity_threshold=0.92, ... dimension=384 # all-MiniLM-L6-v2 embedding size ... ) >>> # Store a response >>> cache.store(query_emb, response_text, metadata={"model": "gpt-4"}) >>> # Search for similar queries >>> result = cache.search(new_query_emb, domain="general") >>> if result: ... print(f"Cache hit: {result['response']}") """ # Default configuration constants _DEFAULT_REDIS_URL: str = "redis://localhost:6379" _DEFAULT_SIMILARITY_THRESHOLD: float = 0.92 _DEFAULT_TTL_SECONDS: int = 3600 _DEFAULT_EMBEDDING_DIM: int = 384 # all-MiniLM-L6-v2 output dimension _DEFAULT_INDEX_NAME: str = "prompt_cache" _DEFAULT_KEY_PREFIX: str = "cache:" # Numerical constants _L2_NORMALIZATION_EPSILON: float = 1e-9 def __init__( self, redis_url: str = _DEFAULT_REDIS_URL, similarity_threshold: float = _DEFAULT_SIMILARITY_THRESHOLD, default_ttl: int = _DEFAULT_TTL_SECONDS, dimension: int = _DEFAULT_EMBEDDING_DIM, index_name: str = _DEFAULT_INDEX_NAME, prefix: str = _DEFAULT_KEY_PREFIX, max_connections: int = 50, socket_timeout: float = 5.0, ) -> None: """ Initialize the semantic cache with RedisVL backend. Parameters ---------- redis_url : str, optional Redis connection URL (default: redis://localhost:6379). similarity_threshold : float, optional Minimum cosine similarity for cache hits ∈ [0, 1]. Higher values = stricter matching, fewer false positives. default_ttl : int, optional Default time-to-live for cached entries in seconds. dimension : int, optional Embedding vector dimensionality (must match embedding model). index_name : str, optional Name for the RedisVL search index. prefix : str, optional Key prefix for Redis entries (namespace isolation). max_connections : int, optional Maximum Redis connection pool size. socket_timeout : float, optional Socket timeout for Redis operations in seconds. Raises ------ ConnectionError If Redis server is unreachable at initialization. ValueError If dimension <= 0 or similarity_threshold not in [0, 1]. """ # Validate parameters if not 0.0 <= similarity_threshold <= 1.0: raise ValueError(f"similarity_threshold must be in [0, 1], got {similarity_threshold}") if dimension <= 0: raise ValueError(f"dimension must be positive, got {dimension}") # Store configuration self.threshold = similarity_threshold self.default_ttl = default_ttl self.dim = dimension self.index_name = index_name self.prefix = prefix # Initialize Redis client with response decoding and configured connection pool self.redis_client = Redis.from_url( redis_url, decode_responses=True, max_connections=max_connections, socket_timeout=socket_timeout, ) # Test connection early to fail fast self.redis_client.ping() # Define vector index schema for RedisVL schema = IndexSchema.from_dict({ "index": { "name": index_name, "prefix": prefix, }, "fields": [ { "name": "embedding", "type": "vector", "attrs": { "dims": dimension, "distance_metric": "cosine", "algorithm": "hnsw", "m": 16, "ef_construction": 200, }, }, {"name": "response", "type": "text"}, {"name": "metadata", "type": "text"}, {"name": "domain", "type": "tag"}, {"name": "timestamp", "type": "numeric"}, {"name": "ttl", "type": "numeric"}, ], }) # Create or connect to search index self.index = SearchIndex(schema, redis_client=self.redis_client) try: self.index.create(overwrite=False) logger.info(f"Created new vector index '{index_name}'") except Exception: # Index already exists; connect to existing logger.debug(f"Connected to existing vector index '{index_name}'") # Runtime statistics self.stats = {"hits": 0, "misses": 0, "evictions": 0} logger.info( f"SemanticLLMCache initialized: threshold={similarity_threshold:.2f}, " f"dim={dimension}, ttl={default_ttl}s, index={index_name}" ) def _normalize(self, embedding: np.ndarray) -> List[float]: """ L2-normalize embedding vector and convert to Python list for Redis. Parameters ---------- embedding : np.ndarray Input embedding array of shape (d,) or (1, d). Returns ------- List[float] L2-normalized embedding as flat list of floats. Mathematical Note ----------------- For embedding e ∈ ℝᵈ: ||e||₂ = √(Σᵢ eᵢ²) ê = e / (||e||₂ + ε) where ε = 1e-9 for numerical stability This ensures cosine similarity equals dot product: sim(u, v) = û · v̂ for unit-norm vectors Complexity ---------- Time: O(d) for norm computation and normalization Space: O(d) for output list """ # Ensure 2D shape for batch operations if embedding.ndim == 1: embedding = embedding.reshape(1, -1) # Compute L2 norms with epsilon for numerical stability norms = np.linalg.norm(embedding, axis=1, keepdims=True) normalized = embedding / (norms + self._L2_NORMALIZATION_EPSILON) # Flatten to list for JSON/Redis compatibility return normalized.flatten().tolist() def search( self, query_embedding: np.ndarray, domain: str = "general", ) -> Optional[Dict]: """ Search for semantically similar cached responses. Parameters ---------- query_embedding : np.ndarray Embedding vector of the query prompt. domain : str, optional Domain tag for filtering results (default: "general"). Returns ------- Optional[Dict] Cached entry dict with keys: response, metadata, timestamp, ttl, id. None if no match above similarity threshold or entry expired. Search Algorithm ---------------- 1. Normalize query embedding to unit norm 2. Execute vector similarity query via RedisVL 3. Filter results by: a. Cosine similarity >= threshold b. TTL not expired (timestamp + ttl >= now) c. Domain match (or domain="general" for cross-domain) 4. Return first valid match or None Complexity ---------- Time: O(N·d) for flat index scan where N=docs, d=embedding_dim Space: O(1) additional beyond Redis response Note ---- For large datasets (N > 100K), consider switching to HNSW algorithm in index schema for O(log N) approximate nearest neighbor search. """ # Normalize query embedding for cosine similarity query_vector = self._normalize(query_embedding) # Build vector similarity query query = VectorQuery( vector=query_vector, vector_field_name="embedding", num_results=1, # Return only top match return_score=True, ) # Execute search results = self.index.query(query) if not results: self.stats["misses"] += 1 return None # Evaluate top result against filters doc = results[0] similarity = doc.get("vector_score", 0.0) # Check similarity threshold if similarity < self.threshold: self.stats["misses"] += 1 return None # Check TTL expiration timestamp = float(doc.get("timestamp", 0)) ttl = int(doc.get("ttl", self.default_ttl)) if time.time() - timestamp > ttl: self.stats["evictions"] += 1 self.stats["misses"] += 1 return None # Check domain filter ("general" matches all domains) doc_domain = doc.get("domain", "general") if domain != "general" and doc_domain != domain: self.stats["misses"] += 1 return None # Cache hit: update stats and return result self.stats["hits"] += 1 return { "response": doc.get("response"), "metadata": json.loads(doc.get("metadata", "{}")), "timestamp": timestamp, "ttl": ttl, "id": doc.get("id"), "similarity": similarity, # Include for observability } def store( self, query_embedding: np.ndarray, response: str, metadata: Dict, ttl: Optional[int] = None, domain: str = "general", ) -> str: """ Store a prompt-response pair in the semantic cache. Parameters ---------- query_embedding : np.ndarray Embedding vector of the query prompt. response : str LLM response text to cache. metadata : Dict Additional metadata (e.g., model name, token counts). ttl : Optional[int], optional Time-to-live in seconds. If None, uses default_ttl. domain : str, optional Domain tag for filtering (default: "general"). Returns ------- str Unique document ID for the cached entry. Storage Format -------------- Each entry stored as Redis hash with fields: - embedding: L2-normalized vector (float list) - response: response text - metadata: JSON-encoded metadata dict - domain: domain tag for filtering - timestamp: Unix timestamp of insertion - ttl: time-to-live in seconds Complexity ---------- Time: O(d) for normalization + O(log N) for index insertion Space: O(d + |response| + |metadata|) per entry Note ---- Redis TTL is set at key level for automatic expiration. Application-level timestamp+ttl provides fallback validation. """ ttl_value = ttl if ttl is not None else self.default_ttl vector = self._normalize(query_embedding) # Generate unique document ID doc_id = f"{self.prefix}{uuid.uuid4().hex}" # Prepare entry data entry = { "embedding": vector, "response": response, "metadata": json.dumps(metadata), "domain": domain, "timestamp": time.time(), "ttl": ttl_value, } # Store in RedisVL index and set TTL try: self.index.load([entry], keys=[doc_id]) except Exception as e: logger.error(f"Failed to load entry to redisvl: {e}") raise self.redis_client.expire(doc_id, ttl_value) logger.debug(f"Cache stored: id={doc_id}, domain={domain}, ttl={ttl_value}s") return doc_id def clear(self, domain: Optional[str] = None) -> int: """ Remove cached entries, optionally filtered by domain. Parameters ---------- domain : Optional[str], optional If specified, only clear entries matching this domain. If None, clear all entries in the index. Returns ------- int Number of entries deleted. Complexity ---------- Time: O(N) for full index clear, O(M) for domain-filtered scan where N=total docs, M=docs in domain Space: O(1) additional Note ---- Domain-filtered clear uses SCAN + DELETE pattern which is non-atomic; consider using Redis keyspace notifications for production-grade cache invalidation if needed. """ deleted_count = 0 if domain: # Domain-filtered deletion via SCAN cursor = "0" pattern = f"{self.prefix}*" while cursor != 0: cursor, keys = self.redis_client.scan( cursor=cursor, match=pattern, count=100 ) for key in keys: doc_data = self.redis_client.hgetall(key) if doc_data and doc_data.get("domain") == domain: self.redis_client.delete(key) deleted_count += 1 self.stats["evictions"] += 1 else: # Full index deletion deleted_count = self.index.delete(delete_documents=True) self.stats["evictions"] += deleted_count logger.info(f"Cache cleared: {deleted_count} entries (domain={domain})") return deleted_count def get_stats(self) -> Dict: """ Return cache performance statistics. Returns ------- Dict Statistics including: - hits: number of successful cache retrievals - misses: number of failed retrievals - evictions: number of expired/deleted entries - size: current number of documents in index - index_name: name of the RedisVL index - hit_rate: hits / (hits + misses) if any requests made """ index_info = self.index.info() total_requests = self.stats["hits"] + self.stats["misses"] return { "hits": self.stats["hits"], "misses": self.stats["misses"], "evictions": self.stats["evictions"], "size": index_info.get("num_docs", 0), "index_name": self.index_name, "hit_rate": ( self.stats["hits"] / total_requests if total_requests > 0 else 0.0 ), } def reset_stats(self) -> None: """Reset runtime statistics counters (useful for testing/monitoring).""" self.stats = {"hits": 0, "misses": 0, "evictions": 0} logger.debug("Cache statistics reset")