""" Caching Module for RAG System. Implements Redis-based caching for embeddings and responses. """ import hashlib import json import pickle from typing import Any, Optional, List, Union from datetime import timedelta from ..utils import get_logger logger = get_logger(__name__) class CacheBackend: """Base cache backend interface.""" def get(self, key: str) -> Optional[Any]: raise NotImplementedError def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: raise NotImplementedError def delete(self, key: str) -> bool: raise NotImplementedError def exists(self, key: str) -> bool: raise NotImplementedError def clear(self) -> bool: raise NotImplementedError class InMemoryCache(CacheBackend): """ Simple in-memory cache for development. """ def __init__(self, max_size: int = 1000): """ Initialize in-memory cache. Args: max_size: Maximum number of entries """ self.max_size = max_size self._cache = {} self._expiry = {} logger.info("In-memory cache initialized") def get(self, key: str) -> Optional[Any]: """Get value from cache.""" import time if key in self._expiry: if time.time() > self._expiry[key]: self.delete(key) return None return self._cache.get(key) def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: """Set value in cache.""" import time # Evict if at max size if len(self._cache) >= self.max_size: oldest = next(iter(self._cache)) self.delete(oldest) self._cache[key] = value if ttl: self._expiry[key] = time.time() + ttl return True def delete(self, key: str) -> bool: """Delete key from cache.""" self._cache.pop(key, None) self._expiry.pop(key, None) return True def exists(self, key: str) -> bool: """Check if key exists.""" return key in self._cache def clear(self) -> bool: """Clear all cache entries.""" self._cache.clear() self._expiry.clear() return True class RedisCache(CacheBackend): """ Redis-based cache for production. """ def __init__( self, host: str = "localhost", port: int = 6379, db: int = 0, password: Optional[str] = None, prefix: str = "rag:" ): """ Initialize Redis cache. Args: host: Redis host port: Redis port db: Redis database number password: Redis password prefix: Key prefix """ self.prefix = prefix self._redis = None try: import redis self._redis = redis.Redis( host=host, port=port, db=db, password=password, decode_responses=False ) self._redis.ping() logger.info(f"Redis cache connected: {host}:{port}") except Exception as e: logger.warning(f"Redis not available: {e}") self._redis = None def _key(self, key: str) -> str: """Add prefix to key.""" return f"{self.prefix}{key}" def get(self, key: str) -> Optional[Any]: """Get value from Redis.""" if not self._redis: return None try: data = self._redis.get(self._key(key)) if data: return pickle.loads(data) return None except Exception as e: logger.debug(f"Cache get error: {e}") return None def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: """Set value in Redis.""" if not self._redis: return False try: data = pickle.dumps(value) if ttl: self._redis.setex(self._key(key), ttl, data) else: self._redis.set(self._key(key), data) return True except Exception as e: logger.debug(f"Cache set error: {e}") return False def delete(self, key: str) -> bool: """Delete key from Redis.""" if not self._redis: return False try: self._redis.delete(self._key(key)) return True except Exception as e: logger.debug(f"Cache delete error: {e}") return False def exists(self, key: str) -> bool: """Check if key exists in Redis.""" if not self._redis: return False try: return self._redis.exists(self._key(key)) > 0 except Exception: return False def clear(self) -> bool: """Clear all cache entries with prefix.""" if not self._redis: return False try: keys = self._redis.keys(f"{self.prefix}*") if keys: self._redis.delete(*keys) return True except Exception as e: logger.debug(f"Cache clear error: {e}") return False class RAGCache: """ High-level caching for RAG operations. Caches embeddings, search results, and LLM responses. """ def __init__(self, backend: Optional[CacheBackend] = None): """ Initialize RAG cache. Args: backend: Cache backend to use """ if backend: self.backend = backend else: # Try Redis first, fallback to in-memory redis_cache = RedisCache() if redis_cache._redis: self.backend = redis_cache else: self.backend = InMemoryCache() self.embedding_ttl = 86400 # 24 hours self.search_ttl = 3600 # 1 hour self.response_ttl = 1800 # 30 minutes def _hash_key(self, *args) -> str: """Generate cache key from arguments.""" content = json.dumps(args, sort_keys=True, default=str) return hashlib.md5(content.encode()).hexdigest() def get_embedding(self, text: str, model: str) -> Optional[List[float]]: """Get cached embedding.""" key = f"emb:{self._hash_key(text, model)}" return self.backend.get(key) def set_embedding(self, text: str, model: str, embedding: List[float]) -> bool: """Cache an embedding.""" key = f"emb:{self._hash_key(text, model)}" return self.backend.set(key, embedding, self.embedding_ttl) def get_search_results(self, query: str, top_k: int) -> Optional[Any]: """Get cached search results.""" key = f"search:{self._hash_key(query, top_k)}" return self.backend.get(key) def set_search_results(self, query: str, top_k: int, results: Any) -> bool: """Cache search results.""" key = f"search:{self._hash_key(query, top_k)}" return self.backend.set(key, results, self.search_ttl) def get_response(self, query: str, context_hash: str) -> Optional[str]: """Get cached LLM response.""" key = f"resp:{self._hash_key(query, context_hash)}" return self.backend.get(key) def set_response(self, query: str, context_hash: str, response: str) -> bool: """Cache LLM response.""" key = f"resp:{self._hash_key(query, context_hash)}" return self.backend.set(key, response, self.response_ttl) def invalidate_all(self) -> bool: """Invalidate all cached data.""" return self.backend.clear() # Global instance _cache: Optional[RAGCache] = None def get_cache() -> RAGCache: """Get global cache instance.""" global _cache if _cache is None: _cache = RAGCache() return _cache