Spaces:
Running
Running
| """ | |
| Semantic caching layer using RedisVL for vector similarity search. | |
| This module implements a semantic cache that stores LLM responses indexed by | |
| their embedding vectors, enabling retrieval of semantically similar queries | |
| without re-computing expensive LLM generations. | |
| Mathematical Foundations | |
| ------------------------ | |
| 1. Cosine Similarity for Semantic Matching: | |
| sim(u, v) = (u · v) / (||u||₂ · ||v||₂) ∈ [-1, 1] | |
| For L2-normalized vectors: sim(u, v) = u · v | |
| Cache hit threshold: τ_sim = 0.92 (empirically tuned) | |
| Reference: Reimers & Gurevych, "Sentence-BERT", EMNLP 2019 [1] | |
| 2. L2 Normalization for Vector Indexing: | |
| Given embedding e ∈ ℝᵈ: | |
| ê = e / (||e||₂ + ε) where ε = 1e-9 for numerical stability | |
| Ensures unit-norm vectors for consistent cosine distance computation. | |
| 3. Time-Based Expiration (TTL): | |
| Entry valid iff: current_time - timestamp ≤ ttl | |
| Provides automatic cache invalidation for stale responses. | |
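| 4. Vector Index Search Complexity: | |
| The index schema in this module selects the HNSW algorithm (m=16, | |
| ef_construction=200), so lookups are approximate nearest neighbor searches | |
| with roughly O(log N) cost rather than an exact scan. | |
| An exact flat index scan would instead cost O(N·d) where N=docs, | |
| d=embedding_dim, which is acceptable for N < 100K. | |
| Reference: Malkov & Yashunin, "Efficient and robust approximate | |
| nearest neighbor search using Hierarchical Navigable Small World graphs" [2] | |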
| References | |
| ---------- | |
| [1] Reimers, N., & Gurevych, I. (2019). Sentence-BERT: Sentence embeddings | |
| using Siamese BERT-networks. EMNLP-IJCNLP 2019. | |
| https://github.com/UKPLab/sentence-transformers | |
| [2] Malkov, Y. A., & Yashunin, D. A. (2020). Efficient and robust approximate | |
| nearest neighbor search using Hierarchical Navigable Small World graphs. | |
| IEEE Transactions on Pattern Analysis and Machine Intelligence. | |
| https://github.com/nmslib/hnswlib | |
| Performance Characteristics | |
| --------------------------- | |
| - _normalize(): O(d) for L2 normalization | |
| - search(): roughly O(log N) for the HNSW index lookup (O(N·d) if a flat index is used instead) + O(1) for metadata filtering | |
| - store(): O(d) for normalization + O(log N) for Redis index insertion | |
| - clear(): O(N) for full index deletion or O(M) for domain-filtered scan | |
| Memory Footprint | |
| ---------------- | |
| - Per entry: d·4 bytes (float32 embeddings) + metadata overhead (~200-500B) | |
| - Example: d=384 → ~1.5KB/embedding + response text | |
| - For 10K entries: ~15-20MB for embeddings + variable for text | |
| Thread Safety | |
| ------------- | |
| - Redis client operations are thread-safe per redis-py documentation | |
| - Index queries and loads are atomic at Redis level | |
| - No shared mutable state beyond Redis connection pool | |
| Author: IntelliDeep Labs Team | |
| License: BSL 1.1 | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import time | |
| import uuid | |
| from typing import Dict, List, Optional | |
| import numpy as np | |
| from redis import Redis | |
| from redisvl.index import SearchIndex | |
| from redisvl.query import VectorQuery | |
| from redisvl.schema import IndexSchema | |
| logger = logging.getLogger(__name__) | |
class SemanticLLMCache:
    """
    Vector-based semantic cache for LLM responses using RedisVL.

    Stores prompt-response pairs indexed by embedding vectors, enabling
    retrieval of semantically similar queries without re-generating responses.

    Key Features
    ------------
    - Cosine similarity search for semantic matching
    - Domain-based filtering applied inside the vector query
    - TTL-based expiration (Redis key expiry plus an application-level check)
    - Hit/miss statistics for cache performance monitoring

    Thread Safety
    -------------
    All Redis traffic goes through the shared ``redis_client``. The
    ``stats`` counters are a plain dict updated without locking, so the
    counts are best-effort when one instance is shared between threads.

    Usage Example
    -------------
    >>> cache = SemanticLLMCache(
    ...     redis_url="redis://localhost:6379",
    ...     similarity_threshold=0.92,
    ...     dimension=384  # all-MiniLM-L6-v2 embedding size
    ... )
    >>> # Store a response
    >>> cache.store(query_emb, response_text, metadata={"model": "gpt-4"})
    >>> # Search for similar queries
    >>> result = cache.search(new_query_emb, domain="general")
    >>> if result:
    ...     print(f"Cache hit: {result['response']}")
    """

    # Default configuration constants
    _DEFAULT_REDIS_URL: str = "redis://localhost:6379"
    _DEFAULT_SIMILARITY_THRESHOLD: float = 0.92
    _DEFAULT_TTL_SECONDS: int = 3600
    _DEFAULT_EMBEDDING_DIM: int = 384  # all-MiniLM-L6-v2 output dimension
    _DEFAULT_INDEX_NAME: str = "prompt_cache"
    _DEFAULT_KEY_PREFIX: str = "cache:"

    # Numerical constants
    _L2_NORMALIZATION_EPSILON: float = 1e-9

    # Element type of the stored vector blob; must agree with the schema.
    _VECTOR_DTYPE: str = "float32"
    # Field RedisVL adds to each result when ``return_score=True``.
    _DISTANCE_FIELD: str = "vector_distance"
    # Fields requested from the index. The binary embedding is left out on
    # purpose: the client is created with decode_responses=True and a raw
    # float blob is not valid UTF-8.
    _RETURN_FIELDS = ("response", "metadata", "domain", "timestamp", "ttl")
    # COUNT hint passed to SCAN while clearing entries.
    _SCAN_BATCH_SIZE: int = 100

    def __init__(
        self,
        redis_url: str = _DEFAULT_REDIS_URL,
        similarity_threshold: float = _DEFAULT_SIMILARITY_THRESHOLD,
        default_ttl: int = _DEFAULT_TTL_SECONDS,
        dimension: int = _DEFAULT_EMBEDDING_DIM,
        index_name: str = _DEFAULT_INDEX_NAME,
        prefix: str = _DEFAULT_KEY_PREFIX,
        max_connections: int = 50,
        socket_timeout: float = 5.0,
    ) -> None:
        """
        Initialize the semantic cache with RedisVL backend.

        Parameters
        ----------
        redis_url : str, optional
            Redis connection URL (default: redis://localhost:6379).
        similarity_threshold : float, optional
            Minimum cosine similarity for cache hits, in [0, 1].
            Higher values = stricter matching, fewer false positives.
        default_ttl : int, optional
            Default time-to-live for cached entries in seconds.
        dimension : int, optional
            Embedding vector dimensionality (must match embedding model).
        index_name : str, optional
            Name for the RedisVL search index.
        prefix : str, optional
            Key prefix for Redis entries (namespace isolation).
        max_connections : int, optional
            Maximum Redis connection pool size.
        socket_timeout : float, optional
            Socket timeout for Redis operations in seconds.

        Raises
        ------
        ValueError
            If dimension <= 0 or similarity_threshold not in [0, 1].
            Raised before any connection is attempted.
        Exception
            Whatever the Redis client or RedisVL raises when the server
            cannot be pinged or the index cannot be created. Index creation
            errors are logged and re-raised instead of being mistaken for
            "index already exists".
        """
        # Validate parameters before touching the network
        if not 0.0 <= similarity_threshold <= 1.0:
            raise ValueError(
                f"similarity_threshold must be in [0, 1], got {similarity_threshold}"
            )
        if dimension <= 0:
            raise ValueError(f"dimension must be positive, got {dimension}")

        # Store configuration
        self.threshold = similarity_threshold
        self.default_ttl = default_ttl
        self.dim = dimension
        self.index_name = index_name
        self.prefix = prefix

        # Runtime statistics
        self.stats = {"hits": 0, "misses": 0, "evictions": 0}

        # Initialize Redis client with response decoding and configured pool
        self.redis_client = Redis.from_url(
            redis_url,
            decode_responses=True,
            max_connections=max_connections,
            socket_timeout=socket_timeout,
        )
        # Test connection early to fail fast
        self.redis_client.ping()

        # Define vector index schema for RedisVL
        schema = IndexSchema.from_dict({
            "index": {
                "name": index_name,
                "prefix": prefix,
            },
            "fields": [
                {
                    "name": "embedding",
                    "type": "vector",
                    "attrs": {
                        "dims": dimension,
                        "distance_metric": "cosine",
                        "algorithm": "hnsw",
                        "datatype": self._VECTOR_DTYPE,
                        "m": 16,
                        "ef_construction": 200,
                    },
                },
                {"name": "response", "type": "text"},
                {"name": "metadata", "type": "text"},
                {"name": "domain", "type": "tag"},
                {"name": "timestamp", "type": "numeric"},
                {"name": "ttl", "type": "numeric"},
            ],
        })

        # Create or connect to search index. create(overwrite=False) leaves
        # an existing index alone, so an exception here is a real failure and
        # must not be swallowed.
        self.index = SearchIndex(schema, redis_client=self.redis_client)
        try:
            index_existed = self.index.exists()
            self.index.create(overwrite=False)
        except Exception:
            logger.exception(
                "Failed to create or attach to vector index '%s'", index_name
            )
            raise
        if index_existed:
            logger.debug("Connected to existing vector index '%s'", index_name)
        else:
            logger.info("Created new vector index '%s'", index_name)

        logger.info(
            "SemanticLLMCache initialized: threshold=%.2f, dim=%s, ttl=%ss, index=%s",
            similarity_threshold,
            dimension,
            default_ttl,
            index_name,
        )

    def _normalize(self, embedding: np.ndarray) -> List[float]:
        """
        L2-normalize embedding vector and convert to a flat Python list.

        Parameters
        ----------
        embedding : np.ndarray
            Input embedding of shape (d,) or (1, d). Any array-like that
            ``np.asarray`` can convert to floats is accepted.

        Returns
        -------
        List[float]
            L2-normalized embedding as flat list of floats. A zero vector
            stays all zeros because of the epsilon in the denominator.

        Mathematical Note
        -----------------
        For embedding e in R^d:
            e_hat = e / (||e||_2 + eps)    where eps = 1e-9
        """
        matrix = np.asarray(embedding, dtype=np.float64)
        # Work on a 2D view so the row-wise norm below has an axis to reduce
        if matrix.ndim == 1:
            matrix = matrix.reshape(1, -1)
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        normalized = matrix / (norms + self._L2_NORMALIZATION_EPSILON)
        return normalized.flatten().tolist()

    def _prepare_vector(self, embedding: np.ndarray) -> List[float]:
        """
        Normalize ``embedding`` and verify it has exactly ``self.dim`` values.

        Raises
        ------
        ValueError
            If the flattened vector length differs from the index dimension
            (for example when a batch of several embeddings is passed).
        """
        vector = self._normalize(embedding)
        if len(vector) != self.dim:
            raise ValueError(
                f"embedding has {len(vector)} values, expected {self.dim}"
            )
        return vector

    @staticmethod
    def _is_expired(timestamp: float, ttl: float, now: float) -> bool:
        """Return True when more than ``ttl`` seconds separate ``timestamp`` and ``now``."""
        return now - timestamp > ttl

    def search(
        self,
        query_embedding: np.ndarray,
        domain: str = "general",
    ) -> Optional[Dict]:
        """
        Search for semantically similar cached responses.

        Parameters
        ----------
        query_embedding : np.ndarray
            Embedding vector of the query prompt.
        domain : str, optional
            Domain tag for filtering results. The default "general" matches
            entries of every domain; any other value restricts the vector
            query to entries tagged with that domain.

        Returns
        -------
        Optional[Dict]
            Dict with keys: response, metadata, timestamp, ttl, id,
            similarity. None if nothing is found, the best match is below
            the similarity threshold, or the entry has expired.

        Raises
        ------
        ValueError
            If the embedding length does not match the index dimension.

        Search Algorithm
        ----------------
        1. Normalize query embedding to unit norm
        2. Run a top-1 vector query (with a domain tag filter if requested)
        3. Convert the returned cosine distance to similarity (1 - distance)
        4. Reject the match if similarity < threshold or the TTL has elapsed

        Note
        ----
        The index schema built in ``__init__`` uses the HNSW algorithm, so
        the lookup is an approximate nearest neighbor search.
        """
        query_vector = self._prepare_vector(query_embedding)

        # Filter inside the query: with num_results=1 a post-hoc check would
        # let a closer entry from another domain hide a valid in-domain hit.
        filter_expression = None
        if domain != "general":
            from redisvl.query.filter import Tag

            filter_expression = Tag("domain") == domain

        query = VectorQuery(
            vector=query_vector,
            vector_field_name="embedding",
            return_fields=list(self._RETURN_FIELDS),
            filter_expression=filter_expression,
            num_results=1,  # Return only top match
            return_score=True,
        )
        results = self.index.query(query)
        if not results:
            self.stats["misses"] += 1
            return None

        doc = results[0]

        # The score column holds cosine *distance* as a string; without it
        # the match cannot be evaluated and is treated as a miss.
        raw_distance = doc.get(self._DISTANCE_FIELD)
        if raw_distance is None:
            self.stats["misses"] += 1
            return None
        similarity = 1.0 - float(raw_distance)
        if similarity < self.threshold:
            self.stats["misses"] += 1
            return None

        # Application-level TTL check (fallback to Redis key expiry)
        timestamp = float(doc.get("timestamp", 0))
        ttl = int(float(doc.get("ttl", self.default_ttl)))
        if self._is_expired(timestamp, ttl, time.time()):
            self.stats["evictions"] += 1
            self.stats["misses"] += 1
            return None

        # Defensive re-check of the domain in case the filter was bypassed
        doc_domain = doc.get("domain", "general")
        if domain != "general" and doc_domain != domain:
            self.stats["misses"] += 1
            return None

        # Cache hit: update stats and return result
        self.stats["hits"] += 1
        return {
            "response": doc.get("response"),
            "metadata": json.loads(doc.get("metadata") or "{}"),
            "timestamp": timestamp,
            "ttl": ttl,
            "id": doc.get("id"),
            "similarity": similarity,  # Include for observability
        }

    def store(
        self,
        query_embedding: np.ndarray,
        response: str,
        metadata: Dict,
        ttl: Optional[int] = None,
        domain: str = "general",
    ) -> str:
        """
        Store a prompt-response pair in the semantic cache.

        Parameters
        ----------
        query_embedding : np.ndarray
            Embedding vector of the query prompt.
        response : str
            LLM response text to cache.
        metadata : Dict
            Additional metadata (e.g., model name, token counts). ``None``
            is stored as an empty dict.
        ttl : Optional[int], optional
            Time-to-live in seconds. If None, uses default_ttl.
        domain : str, optional
            Domain tag for filtering (default: "general").

        Returns
        -------
        str
            Redis key of the cached entry (prefix + random hex id).

        Raises
        ------
        ValueError
            If the embedding length does not match the index dimension.
        Exception
            Any error raised by RedisVL while loading the entry is logged
            and re-raised.

        Storage Format
        --------------
        Each entry is loaded with the fields:
        - embedding: L2-normalized vector as a float32 byte blob
        - response: response text
        - metadata: JSON-encoded metadata dict
        - domain: domain tag for filtering
        - timestamp: Unix timestamp of insertion
        - ttl: time-to-live in seconds

        Note
        ----
        Redis TTL is set at key level for automatic expiration.
        Application-level timestamp+ttl provides fallback validation.
        """
        ttl_value = self.default_ttl if ttl is None else ttl
        vector = self._prepare_vector(query_embedding)

        # Generate unique document ID
        doc_id = f"{self.prefix}{uuid.uuid4().hex}"

        entry = {
            # Hash fields cannot hold a Python list; the vector field is
            # indexed from a packed float blob.
            "embedding": np.asarray(vector, dtype=self._VECTOR_DTYPE).tobytes(),
            "response": response,
            "metadata": json.dumps(metadata if metadata is not None else {}),
            "domain": domain,
            "timestamp": time.time(),
            "ttl": ttl_value,
        }

        try:
            self.index.load([entry], keys=[doc_id])
        except Exception as e:
            logger.error(f"Failed to load entry to redisvl: {e}")
            raise
        self.redis_client.expire(doc_id, ttl_value)

        logger.debug("Cache stored: id=%s, domain=%s, ttl=%ss", doc_id, domain, ttl_value)
        return doc_id

    def clear(self, domain: Optional[str] = None) -> int:
        """
        Remove cached entries, optionally filtered by domain.

        Parameters
        ----------
        domain : Optional[str], optional
            If given (and non-empty), only entries whose ``domain`` field
            equals this value are removed. If None or empty, every key
            under the cache prefix is removed.

        Returns
        -------
        int
            Number of keys actually deleted.

        Note
        ----
        Both modes walk the keyspace with SCAN and delete key by key, which
        is non-atomic. The search index definition itself is kept, so the
        cache remains usable after a full clear.
        """
        deleted_count = 0
        pattern = f"{self.prefix}*"
        for key in self.redis_client.scan_iter(
            match=pattern, count=self._SCAN_BATCH_SIZE
        ):
            # Read only the domain field; fetching the whole hash would also
            # pull the binary embedding through the decoding client.
            if domain and self.redis_client.hget(key, "domain") != domain:
                continue
            # delete() reports 0 if the key expired since it was scanned
            deleted_count += int(self.redis_client.delete(key))

        self.stats["evictions"] += deleted_count
        logger.info(f"Cache cleared: {deleted_count} entries (domain={domain})")
        return deleted_count

    def get_stats(self) -> Dict:
        """
        Return cache performance statistics.

        Returns
        -------
        Dict
            Statistics including:
            - hits: number of successful cache retrievals
            - misses: number of failed retrievals
            - evictions: number of expired/deleted entries seen by this instance
            - size: ``num_docs`` as reported by the index info (0 if absent)
            - index_name: name of the RedisVL index
            - hit_rate: hits / (hits + misses), or 0.0 before any request
        """
        index_info = self.index.info()
        hits = self.stats["hits"]
        misses = self.stats["misses"]
        total_requests = hits + misses
        return {
            "hits": hits,
            "misses": misses,
            "evictions": self.stats["evictions"],
            "size": int(index_info.get("num_docs", 0)),
            "index_name": self.index_name,
            "hit_rate": hits / total_requests if total_requests > 0 else 0.0,
        }

    def reset_stats(self) -> None:
        """Reset runtime statistics counters (useful for testing/monitoring)."""
        self.stats = {"hits": 0, "misses": 0, "evictions": 0}
        logger.debug("Cache statistics reset")