"""Plan Caching System for Data Federation with LLM Cost Optimization.

This module provides intelligent caching of generated query plans to reduce
LLM API costs and improve response times. It implements hash-based plan
storage and retrieval with cache invalidation strategies.

Key features:
- Hash-based plan caching using user query + schema fingerprint (per tenant)
- Redis storage with configurable TTL
- LLM bypass logic for cached plans
- Cache invalidation when schemas change, per plan or per tenant
- A ``version`` field on each cached entry to tag the plan format
- Cost tracking and optimization metrics
"""

import hashlib
import json
import logging
import time
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime, timedelta, timezone
from dataclasses import dataclass, asdict
from enum import Enum

import redis
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


class CacheStatus(str, Enum):
    """Cache lookup result status."""

    HIT = "hit"      # Plan found in cache and returned
    MISS = "miss"    # Plan not found (or not usable), LLM call required
    STALE = "stale"  # Plan found but generated against a different schema
    ERROR = "error"  # Cache operation failed


@dataclass
class CacheMetrics:
    """In-process counters for plan cache lookups.

    These counters live in the Python process only; they are not persisted to
    Redis.
    """

    cache_hits: int = 0
    cache_misses: int = 0
    cache_errors: int = 0
    total_lookups: int = 0
    cost_savings_estimated: float = 0.0  # Estimated $ saved from cache hits
    average_llm_cost_per_query: float = 0.05  # Default $0.05 per LLM call

    @property
    def hit_rate(self) -> float:
        """Return the cache hit rate as a percentage (0.0 when nothing was looked up)."""
        if self.total_lookups == 0:
            return 0.0
        return (self.cache_hits / self.total_lookups) * 100

    def record_hit(self):
        """Record a cache hit and credit one avoided LLM call to the savings estimate."""
        self.cache_hits += 1
        self.total_lookups += 1
        self.cost_savings_estimated += self.average_llm_cost_per_query

    def record_miss(self):
        """Record a cache miss."""
        self.cache_misses += 1
        self.total_lookups += 1

    def record_error(self):
        """Record a failed cache operation (counted as a lookup)."""
        self.cache_errors += 1
        self.total_lookups += 1


class CachedPlan(BaseModel):
    """A cached query plan with metadata."""

    # NOTE: the hash actually covers tenant_id + normalized query + normalized schema.
    plan_hash: str = Field(..., description="SHA-256 hash of query+schema")
    plan: List[Dict[str, Any]] = Field(..., description="The actual query plan (AST)")
    schema_hash: str = Field(..., description="Hash of the schema used to generate this plan")
    user_query: str = Field(..., description="Original user query")
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    last_used_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    use_count: int = Field(default=1, description="Number of times this plan has been used")
    tenant_id: str = Field(..., description="Tenant who owns this cached plan")
    version: str = Field(default="1.0", description="Plan format version")

    # Performance metadata
    avg_execution_time_ms: Optional[float] = Field(None, description="Average execution time")
    success_rate: float = Field(default=1.0, description="Success rate of this plan")
    last_error: Optional[str] = Field(None, description="Last execution error if any")


class PlanCache:
    """High-performance plan cache with Redis backend and intelligent invalidation.

    This class provides:
    - Deterministic hash-based plan storage
    - Schema change detection and cache invalidation
    - Cost optimization through LLM call reduction
    - Performance metrics and monitoring

    Redis layout:
    - ``plan_cache:plan:<plan_hash>``: JSON-serialized :class:`CachedPlan`,
      written with ``SETEX``. Every cache hit rewrites the entry, so the TTL
      is measured from the last use, not from creation.
    - ``plan_cache:index:tenant:<tenant_id>``: Redis set of the plan hashes a
      tenant owns. Its TTL is one hour longer than the plan TTL and is
      refreshed on every store and hit, so it outlives the plans it lists.
    """

    def __init__(self, redis_client: redis.Redis, default_ttl_hours: int = 24,
                 max_plan_size_kb: int = 512, enable_metrics: bool = True):
        """Initialize the plan cache.

        Args:
            redis_client: Redis client for storage. It may return ``bytes``
                (the redis-py default) or ``str`` (``decode_responses=True``).
            default_ttl_hours: TTL for cached plans, in hours.
            max_plan_size_kb: Maximum size of a cached plan's JSON, in KB.
            enable_metrics: Whether to track cache metrics in-process.
        """
        self.redis = redis_client
        self.default_ttl = default_ttl_hours * 3600  # Convert to seconds
        self.max_plan_size = max_plan_size_kb * 1024  # Convert to bytes
        self.metrics = CacheMetrics() if enable_metrics else None

        # Redis key prefixes
        self.PLAN_PREFIX = "plan_cache:plan:"
        self.SCHEMA_PREFIX = "plan_cache:schema:"
        self.METRICS_PREFIX = "plan_cache:metrics:"
        self.INDEX_PREFIX = "plan_cache:index:"

    # ------------------------------------------------------------------
    # Normalization / hashing
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_query(user_query: str) -> str:
        """Return the canonical form of a query: stripped and lower-cased."""
        return user_query.strip().lower()

    @staticmethod
    def _normalize_schema(schema_content: str) -> str:
        """Return canonical JSON for *schema_content*.

        Keys are sorted and insignificant whitespace is removed, so two
        semantically equal schemas hash identically. Content that is not
        valid JSON is returned unchanged.
        """
        try:
            schema_obj = json.loads(schema_content)
        except json.JSONDecodeError:
            return schema_content
        return json.dumps(schema_obj, sort_keys=True, separators=(',', ':'))

    @staticmethod
    def _to_str(value: Any) -> str:
        """Return a Redis reply as ``str``, decoding ``bytes`` as UTF-8.

        redis-py returns ``bytes`` by default and ``str`` when the client was
        created with ``decode_responses=True``; both must be handled.
        """
        return value.decode('utf-8') if isinstance(value, bytes) else str(value)

    def _compute_query_hash(self, user_query: str, schema_content: str, tenant_id: str) -> str:
        """Compute the deterministic cache key hash for user_query + schema + tenant.

        Changes to any component produce a different hash, so a schema change
        normally surfaces as a plain cache miss.

        The composite is ``f"{tenant_id}:{query}:{schema}"``, which is
        ambiguous when a component contains ``:``. Lookups therefore verify
        the stored tenant and query before serving a plan (see
        :meth:`get_cached_plan`). The format is kept as-is so existing cache
        entries stay addressable.

        Args:
            user_query: User's natural language query.
            schema_content: JSON schema content from connectors.
            tenant_id: Tenant identifier for isolation.

        Returns:
            SHA-256 hex digest string.
        """
        composite = (
            f"{tenant_id}:{self._normalize_query(user_query)}:"
            f"{self._normalize_schema(schema_content)}"
        )
        return hashlib.sha256(composite.encode('utf-8')).hexdigest()

    def _compute_schema_hash(self, schema_content: str) -> str:
        """Compute a SHA-256 hex digest of the normalized schema only.

        Used for schema change detection and cache invalidation.
        """
        return hashlib.sha256(self._normalize_schema(schema_content).encode('utf-8')).hexdigest()

    def _record_lookup(self, status: CacheStatus) -> None:
        """Record one lookup outcome in ``self.metrics`` (no-op when metrics are disabled)."""
        if self.metrics is None:
            return
        if status is CacheStatus.HIT:
            self.metrics.record_hit()
        elif status is CacheStatus.ERROR:
            self.metrics.record_error()
        else:
            # MISS and STALE both mean the caller must fall back to the LLM.
            self.metrics.record_miss()

    # ------------------------------------------------------------------
    # Lookup / store
    # ------------------------------------------------------------------

    def get_cached_plan(self, user_query: str, schema_content: str,
                        tenant_id: str) -> Tuple[Optional[List[Dict[str, Any]]], CacheStatus]:
        """Retrieve a cached plan if available.

        On a hit, the entry's ``last_used_at`` and ``use_count`` are updated
        and written back, which also restarts its TTL. The tenant index is
        refreshed too, so tenant-wide invalidation can still find the plan.

        Args:
            user_query: User's natural language query.
            schema_content: Current schema from connectors.
            tenant_id: Tenant identifier.

        Returns:
            Tuple of ``(plan_or_none, cache_status)``, where the status is:
            - ``HIT``: the cached plan is returned.
            - ``MISS``: no usable entry exists; this includes entries owned by
              a different tenant or query under a colliding key.
            - ``STALE``: the entry was built against a different schema; it
              is deleted.
            - ``ERROR``: a Redis or deserialization failure occurred. An
              undecodable entry is deleted.
        """
        try:
            plan_hash = self._compute_query_hash(user_query, schema_content, tenant_id)
            cache_key = f"{self.PLAN_PREFIX}{plan_hash}"

            cached_data = self.redis.get(cache_key)
            if not cached_data:
                self._record_lookup(CacheStatus.MISS)
                return None, CacheStatus.MISS

            try:
                cached_plan = CachedPlan.model_validate_json(cached_data)
            except Exception as e:
                logger.warning(f"Failed to deserialize cached plan {plan_hash}: {e}")
                # Drop the corrupt entry; otherwise every lookup for this key
                # keeps failing until the TTL expires.
                self._invalidate_plan(plan_hash)
                self._record_lookup(CacheStatus.ERROR)
                return None, CacheStatus.ERROR

            # Guard against key-composite collisions (see _compute_query_hash).
            # Never serve another tenant's or another query's plan. The entry
            # is left in place because it belongs to its actual owner.
            if (cached_plan.tenant_id != tenant_id
                    or self._normalize_query(cached_plan.user_query)
                    != self._normalize_query(user_query)):
                logger.warning(
                    f"Cache key collision for plan {plan_hash[:12]}...: "
                    f"entry does not belong to tenant {tenant_id}; treating as miss"
                )
                self._record_lookup(CacheStatus.MISS)
                return None, CacheStatus.MISS

            # Defensive check: the key already covers the schema, so this
            # only fires if the stored entry disagrees with its key.
            if cached_plan.schema_hash != self._compute_schema_hash(schema_content):
                logger.info(f"Schema changed for cached plan {plan_hash}, invalidating cache")
                self._invalidate_plan(plan_hash)
                self._record_lookup(CacheStatus.STALE)
                return None, CacheStatus.STALE

            # Update usage statistics and persist them (this restarts the TTL).
            cached_plan.last_used_at = datetime.now(timezone.utc)
            cached_plan.use_count += 1
            if self._store_plan_internal(cached_plan, plan_hash):
                # Keep the tenant index alive for as long as its plans are.
                self._update_tenant_index(cached_plan.tenant_id, plan_hash)

            self._record_lookup(CacheStatus.HIT)
            logger.info(
                f"Cache HIT for query hash {plan_hash[:12]}... "
                f"(used {cached_plan.use_count} times)"
            )
            return cached_plan.plan, CacheStatus.HIT

        except Exception as e:
            logger.exception(f"Error during cache lookup: {e}")
            self._record_lookup(CacheStatus.ERROR)
            return None, CacheStatus.ERROR

    def store_plan(self, user_query: str, schema_content: str, tenant_id: str,
                   generated_plan: List[Dict[str, Any]]) -> bool:
        """Store a newly generated plan in the cache.

        Any existing entry for the same key is overwritten, which resets its
        ``use_count``.

        Args:
            user_query: Original user query.
            schema_content: Schema used to generate the plan.
            tenant_id: Tenant identifier.
            generated_plan: The LLM-generated query plan (must be JSON-serializable).

        Returns:
            True if stored successfully. False if the plan exceeds
            ``max_plan_size`` or storage failed.
        """
        try:
            plan_size_bytes = len(json.dumps(generated_plan).encode('utf-8'))
            if plan_size_bytes > self.max_plan_size:
                logger.warning(f"Plan too large to cache: {plan_size_bytes} bytes")
                return False

            plan_hash = self._compute_query_hash(user_query, schema_content, tenant_id)
            cached_plan = CachedPlan(
                plan_hash=plan_hash,
                plan=generated_plan,
                schema_hash=self._compute_schema_hash(schema_content),
                user_query=user_query,
                tenant_id=tenant_id,
            )

            success = self._store_plan_internal(cached_plan, plan_hash)
            if success:
                logger.info(f"Cached new plan {plan_hash[:12]}... for tenant {tenant_id}")
                self._update_tenant_index(tenant_id, plan_hash)
            return success

        except Exception as e:
            logger.exception(f"Failed to store plan in cache: {e}")
            return False

    def _store_plan_internal(self, cached_plan: CachedPlan, plan_hash: str) -> bool:
        """Write *cached_plan* under its key with ``SETEX``, which (re)starts its TTL.

        Returns:
            True on success, False if the Redis write raised an error.
        """
        try:
            cache_key = f"{self.PLAN_PREFIX}{plan_hash}"
            return bool(self.redis.setex(cache_key, self.default_ttl, cached_plan.model_dump_json()))
        except Exception as e:
            logger.error(f"Redis storage failed for plan {plan_hash}: {e}")
            return False

    def _update_tenant_index(self, tenant_id: str, plan_hash: str):
        """Add *plan_hash* to the tenant's index set and refresh the set's TTL (best effort)."""
        try:
            index_key = f"{self.INDEX_PREFIX}tenant:{tenant_id}"
            self.redis.sadd(index_key, plan_hash)
            # Slightly longer than the plan TTL so the index outlives its plans.
            self.redis.expire(index_key, self.default_ttl + 3600)
        except Exception as e:
            logger.warning(f"Failed to update tenant index: {e}")

    def _invalidate_plan(self, plan_hash: str):
        """Delete a single plan entry (best effort; failures are logged)."""
        try:
            self.redis.delete(f"{self.PLAN_PREFIX}{plan_hash}")
        except Exception as e:
            logger.warning(f"Failed to invalidate plan {plan_hash}: {e}")

    def invalidate_tenant_cache(self, tenant_id: str) -> int:
        """Invalidate all cached plans for a tenant.

        Useful when tenant configuration changes or schema updates. The
        tenant's index set is deleted as well.

        Args:
            tenant_id: Tenant to invalidate.

        Returns:
            Number of plan entries actually deleted (the index key is not
            counted). Returns 0 on error.
        """
        try:
            index_key = f"{self.INDEX_PREFIX}tenant:{tenant_id}"
            plan_hashes = self.redis.smembers(index_key)
            if not plan_hashes:
                return 0

            plan_keys = [f"{self.PLAN_PREFIX}{self._to_str(h)}" for h in plan_hashes]

            # Both deletes run in one MULTI/EXEC transaction, keeping the
            # original single-DEL atomicity while counting plans separately
            # from the index key.
            pipe = self.redis.pipeline()
            pipe.delete(*plan_keys)
            pipe.delete(index_key)
            deleted_count, _ = pipe.execute()

            logger.info(f"Invalidated {deleted_count} cached plans for tenant {tenant_id}")
            return deleted_count

        except Exception as e:
            logger.exception(f"Failed to invalidate tenant cache for {tenant_id}: {e}")
            return 0

    # ------------------------------------------------------------------
    # Monitoring
    # ------------------------------------------------------------------

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get comprehensive cache statistics.

        Keys are enumerated with ``SCAN``, which does not block the server
        the way ``KEYS`` does. Memory is measured with ``MEMORY USAGE`` on a
        sample of at most 100 keys, then extrapolated to all plan keys.

        Returns:
            ``{"metrics_disabled": True}`` when metrics are off. Otherwise a
            dict with ``cache_metrics``, ``redis_stats`` and
            ``cost_optimization`` sections. Redis figures fall back to 0 if
            Redis cannot be queried.
        """
        if not self.metrics:
            return {"metrics_disabled": True}

        memory_sample_size = 100  # MEMORY USAGE costs one round trip per key
        try:
            # SCAN may return a key more than once; the set de-duplicates.
            plan_keys = set(self.redis.scan_iter(match=f"{self.PLAN_PREFIX}*", count=1000))
            total_cached_plans = len(plan_keys)
            sample = list(plan_keys)[:memory_sample_size]
            sampled_bytes = sum(self.redis.memory_usage(key) or 0 for key in sample)
            cache_memory_bytes = (
                sampled_bytes * total_cached_plans // len(sample) if sample else 0
            )
        except Exception as e:
            logger.warning(f"Failed to collect Redis cache stats: {e}")
            cache_memory_bytes = 0
            total_cached_plans = 0

        return {
            "cache_metrics": asdict(self.metrics),
            "redis_stats": {
                "total_cached_plans": total_cached_plans,
                "estimated_memory_bytes": cache_memory_bytes,
                "estimated_memory_mb": round(cache_memory_bytes / (1024 * 1024), 2),
            },
            "cost_optimization": {
                "estimated_savings_usd": round(self.metrics.cost_savings_estimated, 2),
                "hit_rate_percent": round(self.metrics.hit_rate, 1),
                "avg_cost_per_query": self.metrics.average_llm_cost_per_query,
            },
        }

    def cleanup_expired_plans(self) -> int:
        """Count plan keys that expired during a scan (Redis handles TTL automatically).

        This is mainly for monitoring and can be called periodically. A key
        is counted only if it was returned by ``SCAN`` and had already expired
        (``TTL`` == -2) when checked immediately afterwards, so the count is
        usually small.

        Returns:
            Number of expired plans found (they're auto-deleted by Redis);
            0 on error.
        """
        try:
            expired_count = sum(
                1
                for key in self.redis.scan_iter(match=f"{self.PLAN_PREFIX}*", count=1000)
                if self.redis.ttl(key) == -2  # -2: key no longer exists
            )
            logger.info(f"Found {expired_count} expired plans during cleanup")
            return expired_count
        except Exception as e:
            logger.exception(f"Error during cache cleanup: {e}")
            return 0


# Global cache instance (initialized by application)
_plan_cache_instance: Optional[PlanCache] = None


def get_plan_cache() -> Optional[PlanCache]:
    """Return the global plan cache instance, or None if it is not initialized."""
    return _plan_cache_instance


def init_plan_cache(redis_client: redis.Redis, **kwargs) -> PlanCache:
    """Create the global plan cache instance, replacing any existing one.

    Args:
        redis_client: Redis client passed to :class:`PlanCache`.
        **kwargs: Forwarded to the :class:`PlanCache` constructor.

    Returns:
        The newly created instance.
    """
    global _plan_cache_instance
    _plan_cache_instance = PlanCache(redis_client, **kwargs)
    return _plan_cache_instance


# Utility functions for easy integration

def check_plan_cache(user_query: str, schema_content: str,
                     tenant_id: str) -> Tuple[Optional[List[Dict[str, Any]]], bool]:
    """Check the global cache and report whether the lookup was a hit.

    Args:
        user_query: User's natural language query.
        schema_content: Current schema JSON.
        tenant_id: Tenant identifier.

    Returns:
        Tuple of ``(plan_or_none, cache_hit_boolean)``. Returns
        ``(None, False)`` when the global cache is not initialized.
    """
    cache = get_plan_cache()
    if not cache:
        return None, False
    plan, status = cache.get_cached_plan(user_query, schema_content, tenant_id)
    return plan, (status == CacheStatus.HIT)


def cache_generated_plan(user_query: str, schema_content: str, tenant_id: str,
                         plan: List[Dict[str, Any]]) -> bool:
    """Store a newly generated plan in the global cache.

    Args:
        user_query: Original user query.
        schema_content: Schema used for generation.
        tenant_id: Tenant identifier.
        plan: Generated query plan.

    Returns:
        True if cached successfully. False otherwise, including when the
        global cache is not initialized.
    """
    cache = get_plan_cache()
    if not cache:
        return False
    return cache.store_plan(user_query, schema_content, tenant_id, plan)