Spaces:
Running
Running
| """Plan Caching System for Data Federation with LLM Cost Optimization. | |
| This module provides intelligent caching of generated query plans to reduce LLM API costs | |
| and improve response times. It implements hash-based plan storage and retrieval with | |
| cache invalidation strategies. | |
| Key features: | |
| - Hash-based plan caching using user query + schema fingerprint | |
| - Redis storage with configurable TTL | |
| - LLM bypass logic for cached plans | |
| - Cache invalidation when schemas change | |
| - Support for plan versioning and rollback | |
| - Cost tracking and optimization metrics | |
| """ | |
| import hashlib | |
| import json | |
| import logging | |
| import time | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from datetime import datetime, timedelta, timezone | |
| from dataclasses import dataclass, asdict | |
| from enum import Enum | |
| import redis | |
| from pydantic import BaseModel, Field | |
| logger = logging.getLogger(__name__) | |
class CacheStatus(str, Enum):
    """Outcome of a single plan-cache lookup.

    Members are also ``str`` instances, so each one compares equal to its
    raw value (for example ``CacheStatus.HIT == "hit"``).
    """

    # A plan was found in the cache and returned to the caller.
    HIT = "hit"
    # No plan was found; an LLM call is required.
    MISS = "miss"
    # A plan was found but is expired or no longer valid.
    STALE = "stale"
    # The cache operation itself failed.
    ERROR = "error"
@dataclass
class CacheMetrics:
    """In-process counters describing plan-cache effectiveness.

    Declared as a dataclass so that ``dataclasses.asdict`` can serialize it
    and so that the fields can be overridden per instance through the
    generated constructor (e.g. ``CacheMetrics(average_llm_cost_per_query=0.1)``).
    """

    cache_hits: int = 0
    cache_misses: int = 0
    cache_errors: int = 0
    total_lookups: int = 0
    # Estimated USD saved: grows by ``average_llm_cost_per_query`` per hit.
    cost_savings_estimated: float = 0.0
    # Assumed USD cost of one LLM call; defaults to $0.05.
    average_llm_cost_per_query: float = 0.05

    @property
    def hit_rate(self) -> float:
        """Cache hit rate as a percentage (0.0 when nothing was looked up).

        Exposed as a property because consumers read it as a plain value
        (``round(metrics.hit_rate, 1)``), not as a call.
        """
        if self.total_lookups == 0:
            return 0.0
        return (self.cache_hits / self.total_lookups) * 100

    def record_hit(self) -> None:
        """Count a cache hit and credit the estimated cost of the avoided LLM call."""
        self.cache_hits += 1
        self.total_lookups += 1
        self.cost_savings_estimated += self.average_llm_cost_per_query

    def record_miss(self) -> None:
        """Count a cache miss."""
        self.cache_misses += 1
        self.total_lookups += 1

    def record_error(self) -> None:
        """Count a failed cache operation (it still counts as a lookup)."""
        self.cache_errors += 1
        self.total_lookups += 1
class CachedPlan(BaseModel):
    """One cache entry: a query plan together with its bookkeeping metadata."""

    # Identity of the entry and the inputs it was generated from.
    plan_hash: str = Field(
        ...,
        description="SHA-256 hash of query+schema",
    )
    plan: List[Dict[str, Any]] = Field(
        ...,
        description="The actual query plan (AST)",
    )
    schema_hash: str = Field(
        ...,
        description="Hash of the schema used to generate this plan",
    )
    user_query: str = Field(
        ...,
        description="Original user query",
    )

    # Usage bookkeeping; both timestamps default to "now" in UTC.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
    )
    last_used_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
    )
    use_count: int = Field(
        default=1,
        description="Number of times this plan has been used",
    )
    tenant_id: str = Field(
        ...,
        description="Tenant who owns this cached plan",
    )
    version: str = Field(
        default="1.0",
        description="Plan format version",
    )

    # Performance metadata
    avg_execution_time_ms: Optional[float] = Field(
        default=None,
        description="Average execution time",
    )
    success_rate: float = Field(
        default=1.0,
        description="Success rate of this plan",
    )
    last_error: Optional[str] = Field(
        default=None,
        description="Last execution error if any",
    )
class PlanCache:
    """Redis-backed cache for LLM-generated query plans.

    Plans are stored under a SHA-256 key derived from the tenant id, the
    normalized user query and the normalized schema, so changing any of the
    three yields a different key (and therefore a cache miss).

    This class provides:
    - Deterministic hash-based plan storage
    - Schema change detection and cache invalidation
    - Cost optimization through LLM call reduction
    - Performance metrics and monitoring
    """

    # Redis key prefixes. Kept at class level so they exist without running
    # __init__; they remain readable (and overridable) as instance attributes.
    PLAN_PREFIX = "plan_cache:plan:"
    SCHEMA_PREFIX = "plan_cache:schema:"
    METRICS_PREFIX = "plan_cache:metrics:"
    INDEX_PREFIX = "plan_cache:index:"

    # get_cache_stats() measures memory for at most this many plan keys.
    _MEMORY_SAMPLE_SIZE = 100

    # Counter names exported by get_cache_stats() under "cache_metrics".
    _METRIC_FIELDS = (
        "cache_hits",
        "cache_misses",
        "cache_errors",
        "total_lookups",
        "cost_savings_estimated",
        "average_llm_cost_per_query",
    )

    def __init__(self, redis_client: redis.Redis,
                 default_ttl_hours: int = 24,
                 max_plan_size_kb: int = 512,
                 enable_metrics: bool = True):
        """Initialize the plan cache.

        Args:
            redis_client: Redis client for storage
            default_ttl_hours: Default TTL for cached plans
            max_plan_size_kb: Maximum size of cached plans in KB
            enable_metrics: Whether to track cache metrics
        """
        self.redis = redis_client
        # Stored in whole seconds; the int() keeps a fractional hour count
        # (e.g. 0.5) from producing a float TTL.
        self.default_ttl = int(default_ttl_hours * 3600)
        self.max_plan_size = max_plan_size_kb * 1024  # bytes
        self.metrics = CacheMetrics() if enable_metrics else None

    # ------------------------------------------------------------------
    # Key / hash helpers
    # ------------------------------------------------------------------

    def _plan_key(self, plan_hash) -> str:
        """Build the Redis key for a plan hash given as ``str`` or ``bytes``.

        Set members come back as ``bytes`` from a default Redis client and as
        ``str`` from one created with ``decode_responses=True``; both are
        accepted here.
        """
        if isinstance(plan_hash, bytes):
            plan_hash = plan_hash.decode("utf-8")
        return f"{self.PLAN_PREFIX}{plan_hash}"

    @staticmethod
    def _normalize_schema(schema_content: str) -> str:
        """Return a canonical form of the schema text.

        Valid JSON is re-serialized with sorted keys and compact separators so
        key order and whitespace do not affect hashes. Anything that is not
        valid JSON is returned unchanged.
        """
        try:
            parsed = json.loads(schema_content)
        except json.JSONDecodeError:
            return schema_content
        return json.dumps(parsed, sort_keys=True, separators=(",", ":"))

    def _compute_query_hash(self, user_query: str, schema_content: str, tenant_id: str) -> str:
        """Compute the deterministic cache key hash for tenant + query + schema.

        The query is stripped and lower-cased, the schema is canonicalized,
        and the three parts are serialized as a JSON array before hashing.
        The array encoding keeps the part boundaries unambiguous: joining
        with a bare ``:`` would let tenant ``"a"`` / query ``"b:c"`` collide
        with tenant ``"a:b"`` / query ``"c"`` and share one cache entry
        across tenants. Entries written under an older key format are simply
        never looked up again and age out through their TTL.

        Args:
            user_query: User's natural language query
            schema_content: JSON schema content from connectors
            tenant_id: Tenant identifier for isolation

        Returns:
            SHA-256 hex digest
        """
        normalized_query = user_query.strip().lower()
        normalized_schema = self._normalize_schema(schema_content)
        composite = json.dumps(
            [str(tenant_id), normalized_query, normalized_schema],
            separators=(",", ":"),
        )
        return hashlib.sha256(composite.encode("utf-8")).hexdigest()

    def _compute_schema_hash(self, schema_content: str) -> str:
        """Compute the SHA-256 hex digest of the canonicalized schema only.

        Stored with each plan and compared on lookup to detect schema changes.
        """
        normalized_schema = self._normalize_schema(schema_content)
        return hashlib.sha256(normalized_schema.encode("utf-8")).hexdigest()

    # ------------------------------------------------------------------
    # Lookup / store
    # ------------------------------------------------------------------

    def get_cached_plan(self, user_query: str, schema_content: str,
                        tenant_id: str) -> Tuple[Optional[List[Dict[str, Any]]], CacheStatus]:
        """Retrieve a cached plan if available.

        On a hit the entry's ``last_used_at`` / ``use_count`` are updated and
        the entry is written back, which also restarts its TTL.

        Args:
            user_query: User's natural language query
            schema_content: Current schema from connectors
            tenant_id: Tenant identifier

        Returns:
            Tuple of (plan_or_none, cache_status). Never raises: any failure
            is logged and reported as ``CacheStatus.ERROR``.
        """
        try:
            plan_hash = self._compute_query_hash(user_query, schema_content, tenant_id)
            cached_data = self.redis.get(self._plan_key(plan_hash))

            if not cached_data:
                if self.metrics is not None:
                    self.metrics.record_miss()
                return None, CacheStatus.MISS

            try:
                cached_plan = CachedPlan.model_validate_json(cached_data)
            except Exception as e:
                logger.warning(f"Failed to deserialize cached plan {plan_hash}: {e}")
                if self.metrics is not None:
                    self.metrics.record_error()
                return None, CacheStatus.ERROR

            # Reject entries whose recorded schema hash no longer matches.
            current_schema_hash = self._compute_schema_hash(schema_content)
            if cached_plan.schema_hash != current_schema_hash:
                logger.info(f"Schema changed for cached plan {plan_hash}, invalidating cache")
                self._invalidate_plan(plan_hash)
                if self.metrics is not None:
                    self.metrics.record_miss()
                return None, CacheStatus.STALE

            # Best-effort usage bookkeeping; a failed write-back does not
            # turn the hit into an error.
            cached_plan.last_used_at = datetime.now(timezone.utc)
            cached_plan.use_count += 1
            self._store_plan_internal(cached_plan, plan_hash)

            if self.metrics is not None:
                self.metrics.record_hit()

            logger.info(f"Cache HIT for query hash {plan_hash[:12]}... (used {cached_plan.use_count} times)")
            return cached_plan.plan, CacheStatus.HIT

        except Exception as e:
            logger.exception(f"Error during cache lookup: {e}")
            if self.metrics is not None:
                self.metrics.record_error()
            return None, CacheStatus.ERROR

    def store_plan(self, user_query: str, schema_content: str, tenant_id: str,
                   generated_plan: List[Dict[str, Any]]) -> bool:
        """Store a newly generated plan in the cache.

        Args:
            user_query: Original user query
            schema_content: Schema used to generate the plan
            tenant_id: Tenant identifier
            generated_plan: The LLM-generated query plan

        Returns:
            True if stored successfully, False otherwise (plan too large,
            not serializable, or Redis failure).
        """
        try:
            # Size guard is applied to the JSON-encoded plan alone.
            plan_size = len(json.dumps(generated_plan).encode("utf-8"))
            if plan_size > self.max_plan_size:
                logger.warning(f"Plan too large to cache: {plan_size} bytes")
                return False

            plan_hash = self._compute_query_hash(user_query, schema_content, tenant_id)
            cached_plan = CachedPlan(
                plan_hash=plan_hash,
                plan=generated_plan,
                schema_hash=self._compute_schema_hash(schema_content),
                user_query=user_query,
                tenant_id=tenant_id,
            )

            success = self._store_plan_internal(cached_plan, plan_hash)
            if success:
                logger.info(f"Cached new plan {plan_hash[:12]}... for tenant {tenant_id}")
                # Register the plan so invalidate_tenant_cache() can find it.
                self._update_tenant_index(tenant_id, plan_hash)
            return success

        except Exception as e:
            logger.exception(f"Failed to store plan in cache: {e}")
            return False

    def _store_plan_internal(self, cached_plan: CachedPlan, plan_hash: str) -> bool:
        """Write the serialized plan to Redis with the default TTL.

        Returns:
            True when Redis acknowledged the write, False on any failure.
        """
        try:
            plan_json = cached_plan.model_dump_json()
            # bool() so callers get a real boolean whatever the client returns.
            return bool(self.redis.setex(self._plan_key(plan_hash), self.default_ttl, plan_json))
        except Exception as e:
            logger.error(f"Redis storage failed for plan {plan_hash}: {e}")
            return False

    def _update_tenant_index(self, tenant_id: str, plan_hash: str):
        """Add ``plan_hash`` to the tenant's index set (best effort)."""
        try:
            index_key = f"{self.INDEX_PREFIX}tenant:{tenant_id}"
            self.redis.sadd(index_key, plan_hash)
            # Keep the index an hour longer than the plans it points to.
            self.redis.expire(index_key, self.default_ttl + 3600)
        except Exception as e:
            logger.warning(f"Failed to update tenant index: {e}")

    # ------------------------------------------------------------------
    # Invalidation
    # ------------------------------------------------------------------

    def _invalidate_plan(self, plan_hash: str):
        """Remove a specific plan from cache (best effort)."""
        try:
            self.redis.delete(self._plan_key(plan_hash))
        except Exception as e:
            logger.warning(f"Failed to invalidate plan {plan_hash}: {e}")

    def invalidate_tenant_cache(self, tenant_id: str) -> int:
        """Invalidate all cached plans recorded in a tenant's index.

        Useful when tenant configuration changes or schema updates.

        Args:
            tenant_id: Tenant to invalidate

        Returns:
            Number of plan keys actually deleted (the index key itself is
            removed too but not counted); 0 on failure or when nothing is
            indexed.
        """
        try:
            index_key = f"{self.INDEX_PREFIX}tenant:{tenant_id}"
            plan_hashes = self.redis.smembers(index_key)
            if not plan_hashes:
                return 0

            # _plan_key accepts bytes and str members alike, so this works
            # with and without decode_responses=True on the client.
            plan_keys = [self._plan_key(member) for member in plan_hashes]
            deleted_count = self.redis.delete(*plan_keys)
            self.redis.delete(index_key)

            logger.info(f"Invalidated {deleted_count} cached plans for tenant {tenant_id}")
            return deleted_count

        except Exception as e:
            logger.exception(f"Failed to invalidate tenant cache for {tenant_id}: {e}")
            return 0

    # ------------------------------------------------------------------
    # Monitoring
    # ------------------------------------------------------------------

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            ``{"metrics_disabled": True}`` when metrics are off; otherwise a
            dict with ``cache_metrics``, ``redis_stats`` and
            ``cost_optimization`` sections. Memory figures are the sum over
            at most ``_MEMORY_SAMPLE_SIZE`` plan keys, not the whole cache.
        """
        metrics = self.metrics
        if metrics is None:
            return {"metrics_disabled": True}

        total_cached_plans = 0
        cache_memory_bytes = 0
        try:
            # SCAN iterates incrementally instead of blocking the server the
            # way a single KEYS call does on a large keyspace.
            for key in self.redis.scan_iter(match=f"{self.PLAN_PREFIX}*"):
                if total_cached_plans < self._MEMORY_SAMPLE_SIZE:
                    cache_memory_bytes += self.redis.memory_usage(key) or 0
                total_cached_plans += 1
        except Exception as e:
            logger.warning(f"Failed to collect Redis cache statistics: {e}")
            cache_memory_bytes = 0
            total_cached_plans = 0

        # Read the counters field by field and accept hit_rate as either a
        # property or a plain method, so this works for any metrics object
        # exposing these names.
        hit_rate = metrics.hit_rate
        if callable(hit_rate):
            hit_rate = hit_rate()

        return {
            "cache_metrics": {name: getattr(metrics, name) for name in self._METRIC_FIELDS},
            "redis_stats": {
                "total_cached_plans": total_cached_plans,
                "estimated_memory_bytes": cache_memory_bytes,
                "estimated_memory_mb": round(cache_memory_bytes / (1024 * 1024), 2),
            },
            "cost_optimization": {
                "estimated_savings_usd": round(metrics.cost_savings_estimated, 2),
                "hit_rate_percent": round(hit_rate, 1),
                "avg_cost_per_query": metrics.average_llm_cost_per_query,
            },
        }

    def cleanup_expired_plans(self) -> int:
        """Count plan keys that vanished between being listed and being checked.

        Redis removes expired keys on its own; this only reports keys whose
        TTL query returns -2 (key does not exist). Intended for monitoring.

        Returns:
            Number of such keys found, or 0 on failure.
        """
        try:
            expired_count = sum(
                1
                for key in self.redis.scan_iter(match=f"{self.PLAN_PREFIX}*")
                if self.redis.ttl(key) == -2  # -2: key does not exist
            )
            logger.info(f"Found {expired_count} expired plans during cleanup")
            return expired_count

        except Exception as e:
            logger.exception(f"Error during cache cleanup: {e}")
            return 0
# Module-wide PlanCache singleton. It stays None until the application
# installs one through init_plan_cache().
_plan_cache_instance: Optional[PlanCache] = None


def get_plan_cache() -> Optional[PlanCache]:
    """Return the module-wide plan cache, or ``None`` if none has been set."""
    return _plan_cache_instance
def init_plan_cache(redis_client: redis.Redis, **kwargs) -> PlanCache:
    """Create a PlanCache, install it as the module-wide instance and return it.

    Args:
        redis_client: Redis client handed to the PlanCache constructor.
        **kwargs: Forwarded unchanged to the PlanCache constructor.

    Returns:
        The newly created cache; any previously installed one is replaced.
    """
    global _plan_cache_instance
    new_cache = PlanCache(redis_client, **kwargs)
    _plan_cache_instance = new_cache
    return new_cache
# Thin module-level wrappers around the global PlanCache instance
def check_plan_cache(user_query: str, schema_content: str, tenant_id: str) -> Tuple[Optional[List[Dict[str, Any]]], bool]:
    """Look up a plan in the global cache and reduce the status to a boolean.

    Args:
        user_query: User's natural language query
        schema_content: Current schema JSON
        tenant_id: Tenant identifier

    Returns:
        Tuple of (plan_or_none, cache_hit_boolean). The boolean is True only
        for a cache hit; with no global cache configured the result is
        ``(None, False)``.
    """
    plan_cache = get_plan_cache()
    if plan_cache:
        cached_plan, lookup_status = plan_cache.get_cached_plan(
            user_query, schema_content, tenant_id
        )
        was_hit = lookup_status == CacheStatus.HIT
        return cached_plan, was_hit
    return None, False
def cache_generated_plan(user_query: str, schema_content: str, tenant_id: str, plan: List[Dict[str, Any]]) -> bool:
    """Store a freshly generated plan in the global cache, if one is configured.

    Args:
        user_query: Original user query
        schema_content: Schema used for generation
        tenant_id: Tenant identifier
        plan: Generated query plan

    Returns:
        True if cached successfully; False when no global cache is
        configured or the store operation reports failure.
    """
    plan_cache = get_plan_cache()
    if plan_cache:
        return plan_cache.store_plan(user_query, schema_content, tenant_id, plan)
    return False