| | """ |
| | SPARKNET Cache Manager |
| | Redis-based caching for RAG queries and embeddings. |
| | """ |
| |
|
| | from typing import Optional, Any, List, Dict |
| | from datetime import timedelta |
| | import hashlib |
| | import json |
| | import os |
| | from loguru import logger |
| |
|
| | |
| | _redis_client = None |
| |
|
| |
|
| | def get_redis_client(): |
| | """Get or create Redis client.""" |
| | global _redis_client |
| | if _redis_client is None: |
| | try: |
| | import redis |
| | redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") |
| | _redis_client = redis.from_url(redis_url, decode_responses=True) |
| | |
| | _redis_client.ping() |
| | logger.info(f"Redis connected: {redis_url}") |
| | except Exception as e: |
| | logger.warning(f"Redis not available: {e}. Using in-memory cache.") |
| | _redis_client = None |
| | return _redis_client |
| |
|
| |
|
| | class CacheManager: |
| | """ |
| | Unified cache manager supporting Redis and in-memory fallback. |
| | """ |
| |
|
| | def __init__(self, prefix: str = "sparknet", default_ttl: int = 3600): |
| | """ |
| | Initialize cache manager. |
| | |
| | Args: |
| | prefix: Key prefix for namespacing |
| | default_ttl: Default TTL in seconds (1 hour) |
| | """ |
| | self.prefix = prefix |
| | self.default_ttl = default_ttl |
| | self._memory_cache: Dict[str, Dict[str, Any]] = {} |
| | self._redis = get_redis_client() |
| |
|
| | def _make_key(self, key: str) -> str: |
| | """Create namespaced cache key.""" |
| | return f"{self.prefix}:{key}" |
| |
|
| | def _hash_key(self, *args, **kwargs) -> str: |
| | """Create hash key from arguments.""" |
| | content = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True) |
| | return hashlib.md5(content.encode()).hexdigest() |
| |
|
| | def get(self, key: str) -> Optional[Any]: |
| | """ |
| | Get value from cache. |
| | |
| | Args: |
| | key: Cache key |
| | |
| | Returns: |
| | Cached value or None |
| | """ |
| | full_key = self._make_key(key) |
| |
|
| | |
| | if self._redis: |
| | try: |
| | value = self._redis.get(full_key) |
| | if value: |
| | return json.loads(value) |
| | except Exception as e: |
| | logger.warning(f"Redis get failed: {e}") |
| |
|
| | |
| | if full_key in self._memory_cache: |
| | entry = self._memory_cache[full_key] |
| | import time |
| | if entry.get("expires_at", 0) > time.time(): |
| | return entry.get("value") |
| | else: |
| | del self._memory_cache[full_key] |
| |
|
| | return None |
| |
|
| | def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: |
| | """ |
| | Set value in cache. |
| | |
| | Args: |
| | key: Cache key |
| | value: Value to cache |
| | ttl: Time-to-live in seconds (default: self.default_ttl) |
| | |
| | Returns: |
| | True if successful |
| | """ |
| | full_key = self._make_key(key) |
| | ttl = ttl or self.default_ttl |
| |
|
| | |
| | if self._redis: |
| | try: |
| | self._redis.setex(full_key, ttl, json.dumps(value)) |
| | return True |
| | except Exception as e: |
| | logger.warning(f"Redis set failed: {e}") |
| |
|
| | |
| | import time |
| | self._memory_cache[full_key] = { |
| | "value": value, |
| | "expires_at": time.time() + ttl |
| | } |
| |
|
| | |
| | if len(self._memory_cache) > 10000: |
| | self._cleanup_memory_cache() |
| |
|
| | return True |
| |
|
| | def delete(self, key: str) -> bool: |
| | """Delete a cache entry.""" |
| | full_key = self._make_key(key) |
| |
|
| | if self._redis: |
| | try: |
| | self._redis.delete(full_key) |
| | except Exception as e: |
| | logger.warning(f"Redis delete failed: {e}") |
| |
|
| | if full_key in self._memory_cache: |
| | del self._memory_cache[full_key] |
| |
|
| | return True |
| |
|
| | def clear_prefix(self, prefix: str) -> int: |
| | """Clear all keys matching a prefix.""" |
| | pattern = self._make_key(f"{prefix}:*") |
| | count = 0 |
| |
|
| | if self._redis: |
| | try: |
| | keys = self._redis.keys(pattern) |
| | if keys: |
| | count = self._redis.delete(*keys) |
| | except Exception as e: |
| | logger.warning(f"Redis clear failed: {e}") |
| |
|
| | |
| | to_delete = [k for k in self._memory_cache if k.startswith(self._make_key(prefix))] |
| | for k in to_delete: |
| | del self._memory_cache[k] |
| | count += 1 |
| |
|
| | return count |
| |
|
| | def _cleanup_memory_cache(self): |
| | """Remove expired entries from memory cache.""" |
| | import time |
| | now = time.time() |
| | expired = [ |
| | k for k, v in self._memory_cache.items() |
| | if v.get("expires_at", 0) < now |
| | ] |
| | for k in expired: |
| | del self._memory_cache[k] |
| |
|
| | |
| | if len(self._memory_cache) > 10000: |
| | sorted_keys = sorted( |
| | self._memory_cache.keys(), |
| | key=lambda k: self._memory_cache[k].get("expires_at", 0) |
| | ) |
| | for k in sorted_keys[:len(sorted_keys) // 2]: |
| | del self._memory_cache[k] |
| |
|
| |
|
| | class QueryCache(CacheManager): |
| | """ |
| | Specialized cache for RAG queries. |
| | """ |
| |
|
| | def __init__(self, ttl: int = 3600): |
| | super().__init__(prefix="sparknet:query", default_ttl=ttl) |
| |
|
| | def get_query_key(self, query: str, doc_ids: Optional[List[str]] = None) -> str: |
| | """Generate cache key for a query.""" |
| | doc_str = ",".join(sorted(doc_ids)) if doc_ids else "all" |
| | content = f"{query.lower().strip()}:{doc_str}" |
| | return hashlib.md5(content.encode()).hexdigest() |
| |
|
| | def get_query_response(self, query: str, doc_ids: Optional[List[str]] = None) -> Optional[Dict]: |
| | """Get cached query response.""" |
| | key = self.get_query_key(query, doc_ids) |
| | return self.get(key) |
| |
|
| | def cache_query_response( |
| | self, |
| | query: str, |
| | response: Dict, |
| | doc_ids: Optional[List[str]] = None, |
| | ttl: Optional[int] = None |
| | ) -> bool: |
| | """Cache a query response.""" |
| | key = self.get_query_key(query, doc_ids) |
| | return self.set(key, response, ttl) |
| |
|
| |
|
| | class EmbeddingCache(CacheManager): |
| | """ |
| | Specialized cache for embeddings. |
| | """ |
| |
|
| | def __init__(self, ttl: int = 86400): |
| | super().__init__(prefix="sparknet:embed", default_ttl=ttl) |
| |
|
| | def get_embedding_key(self, text: str, model: str = "default") -> str: |
| | """Generate cache key for embedding.""" |
| | content = f"{model}:{text}" |
| | return hashlib.md5(content.encode()).hexdigest() |
| |
|
| | def get_embedding(self, text: str, model: str = "default") -> Optional[List[float]]: |
| | """Get cached embedding.""" |
| | key = self.get_embedding_key(text, model) |
| | return self.get(key) |
| |
|
| | def cache_embedding( |
| | self, |
| | text: str, |
| | embedding: List[float], |
| | model: str = "default" |
| | ) -> bool: |
| | """Cache an embedding.""" |
| | key = self.get_embedding_key(text, model) |
| | return self.set(key, embedding) |
| |
|
| |
|
| | |
| | _query_cache: Optional[QueryCache] = None |
| | _embedding_cache: Optional[EmbeddingCache] = None |
| |
|
| |
|
| | def get_query_cache() -> QueryCache: |
| | """Get or create query cache instance.""" |
| | global _query_cache |
| | if _query_cache is None: |
| | _query_cache = QueryCache() |
| | return _query_cache |
| |
|
| |
|
| | def get_embedding_cache() -> EmbeddingCache: |
| | """Get or create embedding cache instance.""" |
| | global _embedding_cache |
| | if _embedding_cache is None: |
| | _embedding_cache = EmbeddingCache() |
| | return _embedding_cache |
| |
|
| |
|
| | |
| | def cached(prefix: str = "func", ttl: int = 3600): |
| | """ |
| | Decorator to cache function results. |
| | |
| | Usage: |
| | @cached(prefix="my_func", ttl=600) |
| | def expensive_function(arg1, arg2): |
| | ... |
| | """ |
| | def decorator(func): |
| | cache = CacheManager(prefix=f"sparknet:{prefix}", default_ttl=ttl) |
| |
|
| | def wrapper(*args, **kwargs): |
| | |
| | key = f"{func.__name__}:{cache._hash_key(*args, **kwargs)}" |
| |
|
| | |
| | result = cache.get(key) |
| | if result is not None: |
| | return result |
| |
|
| | |
| | result = func(*args, **kwargs) |
| | cache.set(key, result) |
| | return result |
| |
|
| | return wrapper |
| | return decorator |
| |
|