# sirus/backend/data_sources/plan_cache.py
# Last commit: b8277c4 by ranilmukesh -- "Deploy SiRUS SQL Agent backend"
"""Plan Caching System for Data Federation with LLM Cost Optimization.
This module provides intelligent caching of generated query plans to reduce LLM API costs
and improve response times. It implements hash-based plan storage and retrieval with
cache invalidation strategies.
Key features:
- Hash-based plan caching keyed on tenant + user query + schema fingerprint
- Redis storage with configurable TTL
- LLM bypass logic for cached plans
- Cache invalidation when schemas change or a tenant's plans are flushed
- A format version recorded on every cached plan (rollback is not implemented here)
- Cost tracking and optimization metrics
"""
import hashlib
import json
import logging
import time
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime, timedelta, timezone
from dataclasses import dataclass, asdict
from enum import Enum
import redis
from pydantic import BaseModel, Field
# Module-level logger; every log call in this file goes through it.
logger = logging.getLogger(__name__)
class CacheStatus(str, Enum):
    """Outcome of a single plan-cache lookup.

    Members (values are plain strings, so they compare equal to "hit", etc.):
        HIT   -- a plan was found in the cache and returned
        MISS  -- no plan was found; an LLM call is required
        STALE -- a plan was found but expired/invalid
        ERROR -- the cache operation itself failed
    """

    HIT = "hit"
    MISS = "miss"
    STALE = "stale"
    ERROR = "error"
@dataclass
class CacheMetrics:
    """Running counters for plan-cache lookups plus a simple cost-savings estimate.

    Every recorded outcome (hit, miss or error) counts as one lookup, so errors
    lower the hit rate. Each hit is credited with ``average_llm_cost_per_query``
    dollars of avoided LLM spend.
    """

    cache_hits: int = 0
    cache_misses: int = 0
    cache_errors: int = 0
    total_lookups: int = 0
    cost_savings_estimated: float = 0.0  # running $ estimate, credited once per hit
    average_llm_cost_per_query: float = 0.05  # assumed $ cost of one avoided LLM call

    @property
    def hit_rate(self) -> float:
        """Percentage of lookups that were hits (0.0 before any lookup)."""
        lookups = self.total_lookups
        return 0.0 if lookups == 0 else self.cache_hits / lookups * 100

    def _tally(self, counter: str) -> None:
        """Increment the named outcome counter and the overall lookup total."""
        setattr(self, counter, getattr(self, counter) + 1)
        self.total_lookups += 1

    def record_hit(self):
        """Count a cache hit and credit the LLM cost it avoided."""
        self._tally("cache_hits")
        self.cost_savings_estimated += self.average_llm_cost_per_query

    def record_miss(self):
        """Count a cache miss."""
        self._tally("cache_misses")

    def record_error(self):
        """Count a failed cache operation."""
        self._tally("cache_errors")
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


class CachedPlan(BaseModel):
    """A query plan stored in the cache together with its bookkeeping metadata."""

    # Identity and provenance of the plan
    plan_hash: str = Field(..., description="SHA-256 hash of query+schema")
    plan: List[Dict[str, Any]] = Field(..., description="The actual query plan (AST)")
    schema_hash: str = Field(..., description="Hash of the schema used to generate this plan")
    user_query: str = Field(..., description="Original user query")
    # Usage bookkeeping; timestamps default to "now" in UTC
    created_at: datetime = Field(default_factory=_utc_now)
    last_used_at: datetime = Field(default_factory=_utc_now)
    use_count: int = Field(default=1, description="Number of times this plan has been used")
    tenant_id: str = Field(..., description="Tenant who owns this cached plan")
    version: str = Field(default="1.0", description="Plan format version")
    # Performance metadata
    avg_execution_time_ms: Optional[float] = Field(default=None, description="Average execution time")
    success_rate: float = Field(default=1.0, description="Success rate of this plan")
    last_error: Optional[str] = Field(default=None, description="Last execution error if any")
class PlanCache:
    """High-performance plan cache with Redis backend and intelligent invalidation.

    This class provides:
    - Deterministic hash-based plan storage keyed on tenant + query + schema
    - Schema change detection and cache invalidation
    - Cost optimization through LLM call reduction
    - Performance metrics and monitoring

    Redis layout used by this class:
    - ``<PLAN_PREFIX><plan_hash>``: JSON-serialized ``CachedPlan`` written with
      ``SETEX``; the TTL is reset every time the plan is hit.
    - ``<INDEX_PREFIX>tenant:<tenant_id>``: set of plan hashes cached for a
      tenant, used by ``invalidate_tenant_cache``. Its TTL is one hour longer
      than the plan TTL and is refreshed on every store and every hit.

    The Redis client may be created with or without ``decode_responses=True``.
    """

    def __init__(self, redis_client: redis.Redis,
                 default_ttl_hours: int = 24,
                 max_plan_size_kb: int = 512,
                 enable_metrics: bool = True):
        """Initialize the plan cache.

        Args:
            redis_client: Redis client for storage
            default_ttl_hours: TTL for cached plans, in hours. Must amount to at
                least one second, because Redis ``SETEX`` rejects non-positive
                or non-integer expirations.
            max_plan_size_kb: Maximum serialized size of a cacheable plan, in KB
            enable_metrics: Whether to track in-process cache metrics

        Raises:
            ValueError: If ``default_ttl_hours`` yields a TTL below one second.
        """
        # Redis expirations are whole seconds; fractional hours are truncated.
        ttl_seconds = int(default_ttl_hours * 3600)
        if ttl_seconds <= 0:
            raise ValueError(
                f"default_ttl_hours must give a TTL of at least one second, "
                f"got {default_ttl_hours!r}"
            )
        self.redis = redis_client
        self.default_ttl = ttl_seconds
        self.max_plan_size = max_plan_size_kb * 1024  # bytes
        self.metrics = CacheMetrics() if enable_metrics else None

        # Redis key prefixes
        self.PLAN_PREFIX = "plan_cache:plan:"
        self.SCHEMA_PREFIX = "plan_cache:schema:"    # not referenced within this class
        self.METRICS_PREFIX = "plan_cache:metrics:"  # not referenced within this class
        self.INDEX_PREFIX = "plan_cache:index:"

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_schema(schema_content: str) -> str:
        """Return a canonical form of ``schema_content`` for hashing.

        Valid JSON is re-serialized with sorted keys and compact separators, so
        schemas that differ only in key order or whitespace hash identically.
        Anything that is not valid JSON is used verbatim.
        """
        try:
            schema_obj = json.loads(schema_content)
        except json.JSONDecodeError:
            return schema_content
        return json.dumps(schema_obj, sort_keys=True, separators=(',', ':'))

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return a Redis reply item as ``str``.

        Replies are ``bytes`` unless the client uses ``decode_responses=True``.
        """
        if isinstance(value, (bytes, bytearray)):
            return value.decode()
        return str(value)

    def _plan_key(self, plan_hash: str) -> str:
        """Return the Redis key that holds the serialized plan for ``plan_hash``."""
        return f"{self.PLAN_PREFIX}{plan_hash}"

    def _tenant_index_key(self, tenant_id: str) -> str:
        """Return the Redis key of the set listing a tenant's cached plan hashes."""
        return f"{self.INDEX_PREFIX}tenant:{tenant_id}"

    def _compute_query_hash(self, user_query: str, schema_content: str, tenant_id: str) -> str:
        """Compute a deterministic hash for user_query + schema + tenant.

        This hash is used as the primary cache key. Changes to any component
        result in a cache miss, which keeps plans synchronized with the
        current schema state.

        Args:
            user_query: User's natural language query
            schema_content: JSON schema content from connectors
            tenant_id: Tenant identifier for isolation

        Returns:
            Hex-encoded SHA-256 digest.
        """
        # NOTE(review): lower-casing maps queries that differ only in case --
        # including quoted literals such as 'Active' vs 'active' -- onto the
        # same plan. Confirm generated plans never embed case-sensitive values.
        normalized_query = user_query.strip().lower()
        normalized_schema = self._normalize_schema(schema_content)
        # Encode the parts as a JSON array instead of joining them with ':'.
        # With a plain join, tenant "a:b" + query "c" and tenant "a" + query
        # "b:c" produced the same key, so one tenant could read another's plan.
        composite = json.dumps(
            [tenant_id, normalized_query, normalized_schema],
            separators=(',', ':'),
        )
        return hashlib.sha256(composite.encode('utf-8')).hexdigest()

    def _compute_schema_hash(self, schema_content: str) -> str:
        """Compute a hash of the schema content alone.

        Stored on each ``CachedPlan`` and compared on lookup to detect schema
        changes.
        """
        normalized_schema = self._normalize_schema(schema_content)
        return hashlib.sha256(normalized_schema.encode('utf-8')).hexdigest()

    # ------------------------------------------------------------------ lookup

    def get_cached_plan(self, user_query: str, schema_content: str, tenant_id: str) -> Tuple[Optional[List[Dict[str, Any]]], CacheStatus]:
        """Retrieve a cached plan if available.

        On a hit, the entry's ``last_used_at`` and ``use_count`` are updated and
        the entry is re-stored, which also resets its TTL. The tenant index is
        refreshed at the same time so tenant-wide invalidation can still reach
        the plan.

        Args:
            user_query: User's natural language query
            schema_content: Current schema from connectors
            tenant_id: Tenant identifier

        Returns:
            Tuple of (plan_or_none, cache_status). Failures are logged and
            reported as ``CacheStatus.ERROR`` instead of being raised.
        """
        try:
            plan_hash = self._compute_query_hash(user_query, schema_content, tenant_id)
            cached_data = self.redis.get(self._plan_key(plan_hash))
            if not cached_data:
                if self.metrics is not None:
                    self.metrics.record_miss()
                return None, CacheStatus.MISS

            try:
                cached_plan = CachedPlan.model_validate_json(cached_data)
            except Exception as e:
                logger.warning("Failed to deserialize cached plan %s: %s", plan_hash, e)
                # Evict the unreadable entry. Otherwise every lookup of this
                # query keeps failing until the TTL expires.
                self._invalidate_plan(plan_hash)
                if self.metrics is not None:
                    self.metrics.record_error()
                return None, CacheStatus.ERROR

            # Defensive check: the key already covers the normalized schema, so
            # a mismatch means the entry was written under different assumptions.
            if cached_plan.schema_hash != self._compute_schema_hash(schema_content):
                logger.info("Schema changed for cached plan %s, invalidating cache", plan_hash)
                self._invalidate_plan(plan_hash)
                if self.metrics is not None:
                    self.metrics.record_miss()
                return None, CacheStatus.STALE

            cached_plan.last_used_at = datetime.now(timezone.utc)
            cached_plan.use_count += 1
            # Re-storing resets the plan TTL (sliding expiry). Refresh the
            # tenant index too, or a hot plan would outlive its index entry and
            # become unreachable by invalidate_tenant_cache().
            if self._store_plan_internal(cached_plan, plan_hash):
                self._update_tenant_index(tenant_id, plan_hash)

            if self.metrics is not None:
                self.metrics.record_hit()
            logger.info("Cache HIT for query hash %s... (used %d times)",
                        plan_hash[:12], cached_plan.use_count)
            return cached_plan.plan, CacheStatus.HIT
        except Exception as e:
            logger.exception("Error during cache lookup: %s", e)
            if self.metrics is not None:
                self.metrics.record_error()
            return None, CacheStatus.ERROR

    # ------------------------------------------------------------------ storage

    def store_plan(self, user_query: str, schema_content: str, tenant_id: str,
                   generated_plan: List[Dict[str, Any]]) -> bool:
        """Store a newly generated plan in the cache.

        An existing entry for the same tenant/query/schema is overwritten, and
        its usage statistics start again from scratch.

        Args:
            user_query: Original user query
            schema_content: Schema used to generate the plan
            tenant_id: Tenant identifier
            generated_plan: The LLM-generated query plan

        Returns:
            True if stored successfully. False if the plan exceeds
            ``max_plan_size``, is not JSON-serializable, or Redis fails.
        """
        try:
            # Size check on the UTF-8 encoding of the bare plan.
            plan_size = len(json.dumps(generated_plan).encode('utf-8'))
            if plan_size > self.max_plan_size:
                logger.warning("Plan too large to cache: %d bytes", plan_size)
                return False

            plan_hash = self._compute_query_hash(user_query, schema_content, tenant_id)
            cached_plan = CachedPlan(
                plan_hash=plan_hash,
                plan=generated_plan,
                schema_hash=self._compute_schema_hash(schema_content),
                user_query=user_query,
                tenant_id=tenant_id,
            )

            success = self._store_plan_internal(cached_plan, plan_hash)
            if success:
                logger.info("Cached new plan %s... for tenant %s", plan_hash[:12], tenant_id)
                self._update_tenant_index(tenant_id, plan_hash)
            return success
        except Exception as e:
            logger.exception("Failed to store plan in cache: %s", e)
            return False

    def _store_plan_internal(self, cached_plan: CachedPlan, plan_hash: str) -> bool:
        """Write ``cached_plan`` under its key with the default TTL (``SETEX``).

        Returns:
            True on success. Redis failures are logged and reported as False.
        """
        try:
            return bool(self.redis.setex(self._plan_key(plan_hash), self.default_ttl,
                                         cached_plan.model_dump_json()))
        except Exception as e:
            logger.error("Redis storage failed for plan %s: %s", plan_hash, e)
            return False

    def _update_tenant_index(self, tenant_id: str, plan_hash: str):
        """Add ``plan_hash`` to the tenant's index set and refresh the set's TTL.

        This is best-effort: failures are logged as warnings.
        """
        try:
            index_key = self._tenant_index_key(tenant_id)
            self.redis.sadd(index_key, plan_hash)
            # One hour longer than the plan TTL, so the index outlasts the plans it lists.
            self.redis.expire(index_key, self.default_ttl + 3600)
        except Exception as e:
            logger.warning("Failed to update tenant index: %s", e)

    # ------------------------------------------------------------------ invalidation

    def _invalidate_plan(self, plan_hash: str):
        """Delete a single plan entry. This is best-effort: failures are logged."""
        try:
            self.redis.delete(self._plan_key(plan_hash))
        except Exception as e:
            logger.warning("Failed to invalidate plan %s: %s", plan_hash, e)

    def invalidate_tenant_cache(self, tenant_id: str) -> int:
        """Invalidate all cached plans for a tenant.

        Useful when tenant configuration changes or schemas are updated. The
        tenant's index set is deleted as well.

        Args:
            tenant_id: Tenant to invalidate

        Returns:
            Number of plan entries actually deleted. Entries that had already
            expired, and the index key itself, are not counted. Returns 0 on
            failure.
        """
        try:
            index_key = self._tenant_index_key(tenant_id)
            plan_hashes = self.redis.smembers(index_key)
            if not plan_hashes:
                return 0
            plan_keys = [self._plan_key(self._as_text(h)) for h in plan_hashes]
            deleted_count = self.redis.delete(*plan_keys)
            # Delete the index separately so it is not included in the plan count.
            self.redis.delete(index_key)
            logger.info("Invalidated %d cached plans for tenant %s", deleted_count, tenant_id)
            return deleted_count
        except Exception as e:
            logger.exception("Failed to invalidate tenant cache for %s: %s", tenant_id, e)
            return 0

    # ------------------------------------------------------------------ monitoring

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get comprehensive cache statistics.

        Plan keys are enumerated with ``SCAN``, which does not block Redis the
        way ``KEYS`` does. Memory is measured on a sample of up to 100 keys and
        extrapolated to all cached plans. If Redis fails, the Redis figures are
        reported as 0.
        """
        if self.metrics is None:
            return {"metrics_disabled": True}

        try:
            # SCAN may return a key more than once; dict.fromkeys de-duplicates
            # while preserving order.
            cache_keys = list(dict.fromkeys(
                self.redis.scan_iter(match=f"{self.PLAN_PREFIX}*", count=1000)
            ))
            total_cached_plans = len(cache_keys)
            sample = cache_keys[:100]
            sampled_bytes = sum(self.redis.memory_usage(key) or 0 for key in sample)
            cache_memory_bytes = (
                round(sampled_bytes * total_cached_plans / len(sample)) if sample else 0
            )
        except Exception as e:
            logger.warning("Failed to collect Redis cache stats: %s", e)
            cache_memory_bytes = 0
            total_cached_plans = 0

        return {
            "cache_metrics": asdict(self.metrics),
            "redis_stats": {
                "total_cached_plans": total_cached_plans,
                "estimated_memory_bytes": cache_memory_bytes,
                "estimated_memory_mb": round(cache_memory_bytes / (1024 * 1024), 2)
            },
            "cost_optimization": {
                "estimated_savings_usd": round(self.metrics.cost_savings_estimated, 2),
                "hit_rate_percent": round(self.metrics.hit_rate, 1),
                "avg_cost_per_query": self.metrics.average_llm_cost_per_query
            }
        }

    def cleanup_expired_plans(self) -> int:
        """Prune tenant-index entries whose plans Redis has already expired.

        Redis deletes expired plan keys on its own. The per-tenant index sets,
        however, live longer than individual plans and keep their hashes. This
        walks every tenant index with ``SCAN`` and removes members whose plan
        key no longer exists. It can be called periodically.

        A plan re-stored concurrently with this sweep may lose its index entry;
        the next hit or store of that plan re-adds it.

        Returns:
            Number of expired plans found and removed from tenant indexes;
            0 on failure.
        """
        try:
            expired_count = 0
            for index_key in self.redis.scan_iter(match=f"{self.INDEX_PREFIX}tenant:*"):
                for member in self.redis.smembers(index_key):
                    if not self.redis.exists(self._plan_key(self._as_text(member))):
                        self.redis.srem(index_key, member)
                        expired_count += 1
            logger.info("Found %d expired plans during cleanup", expired_count)
            return expired_count
        except Exception as e:
            logger.exception("Error during cache cleanup: %s", e)
            return 0
# Process-wide cache instance; stays None until init_plan_cache() is called.
_plan_cache_instance: Optional[PlanCache] = None


def get_plan_cache() -> Optional[PlanCache]:
    """Return the process-wide plan cache, or None if it was never initialized."""
    return _plan_cache_instance


def init_plan_cache(redis_client: redis.Redis, **kwargs) -> PlanCache:
    """Create a PlanCache, install it as the process-wide instance, and return it.

    Any previously installed instance is replaced. Extra keyword arguments are
    forwarded unchanged to the PlanCache constructor.
    """
    global _plan_cache_instance
    instance = PlanCache(redis_client, **kwargs)
    _plan_cache_instance = instance
    return instance
# Utility functions for easy integration
def check_plan_cache(user_query: str, schema_content: str, tenant_id: str) -> Tuple[Optional[List[Dict[str, Any]]], bool]:
    """Look up a plan in the global cache and reduce the status to a hit flag.

    Args:
        user_query: User's natural language query
        schema_content: Current schema JSON
        tenant_id: Tenant identifier

    Returns:
        Tuple of (plan_or_none, cache_hit_boolean). When no global cache has
        been initialized, returns (None, False).
    """
    cache = get_plan_cache()
    if cache:
        plan, status = cache.get_cached_plan(user_query, schema_content, tenant_id)
        is_hit = status == CacheStatus.HIT
        return plan, is_hit
    return None, False
def cache_generated_plan(user_query: str, schema_content: str, tenant_id: str, plan: List[Dict[str, Any]]) -> bool:
    """Store a freshly generated plan in the global cache, if one is configured.

    Args:
        user_query: Original user query
        schema_content: Schema used for generation
        tenant_id: Tenant identifier
        plan: Generated query plan

    Returns:
        True if the plan was cached. False when no global cache is initialized
        or when storing fails.
    """
    cache = get_plan_cache()
    return cache.store_plan(user_query, schema_content, tenant_id, plan) if cache else False