File size: 17,287 Bytes
b8277c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
"""Plan Caching System for Data Federation with LLM Cost Optimization.

This module provides intelligent caching of generated query plans to reduce LLM API costs
and improve response times. It implements hash-based plan storage and retrieval with 
cache invalidation strategies.

Key features:
- Hash-based plan caching using user query + schema fingerprint  
- Redis storage with configurable TTL
- LLM bypass logic for cached plans
- Cache invalidation when schemas change
- Support for plan versioning and rollback
- Cost tracking and optimization metrics
"""
import hashlib
import json
import logging
import time
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime, timedelta, timezone
from dataclasses import dataclass, asdict
from enum import Enum

import redis
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


class CacheStatus(str, Enum):
    """Outcome of a single plan-cache lookup.

    Members subclass ``str``, so they compare equal to their plain string
    values (``CacheStatus.HIT == "hit"``) and serialize as those strings.
    """

    # A cached plan was found and returned to the caller.
    HIT = "hit"
    # Nothing usable is cached; the caller has to ask the LLM for a plan.
    MISS = "miss"
    # A cached plan existed but no longer matches the current schema.
    STALE = "stale"
    # The cache layer itself failed (storage error, unreadable entry, ...).
    ERROR = "error"


@dataclass
class CacheMetrics:
    """Running counters for plan-cache lookups and the LLM spend they avoided.

    Each recorded lookup lands in exactly one of the hit / miss / error
    buckets and is also counted in ``total_lookups``.
    """
    cache_hits: int = 0
    cache_misses: int = 0
    cache_errors: int = 0
    total_lookups: int = 0
    # Estimated dollars saved: one ``average_llm_cost_per_query`` per hit.
    cost_savings_estimated: float = 0.0
    # Assumed cost of a single LLM planning call, in dollars.
    average_llm_cost_per_query: float = 0.05

    @property
    def hit_rate(self) -> float:
        """Share of lookups that were hits, as a percentage (0.0 if none yet)."""
        lookups = self.total_lookups
        return 0.0 if lookups == 0 else self.cache_hits / lookups * 100

    def record_hit(self):
        """Count a hit and credit one LLM call's cost to the savings estimate."""
        self.total_lookups += 1
        self.cache_hits += 1
        self.cost_savings_estimated += self.average_llm_cost_per_query

    def record_miss(self):
        """Count a lookup that did not yield a usable plan."""
        self.total_lookups += 1
        self.cache_misses += 1

    def record_error(self):
        """Count a lookup that failed inside the cache layer."""
        self.total_lookups += 1
        self.cache_errors += 1


class CachedPlan(BaseModel):
    """One cached query plan plus the bookkeeping needed to validate and reuse it."""

    # Cache key this entry is stored under.
    plan_hash: str = Field(..., description="SHA-256 hash of query+schema")
    # The plan payload itself, as produced by the planner.
    plan: List[Dict[str, Any]] = Field(..., description="The actual query plan (AST)")
    # Fingerprint of the schema at generation time; compared on every lookup.
    schema_hash: str = Field(..., description="Hash of the schema used to generate this plan")
    user_query: str = Field(..., description="Original user query")
    # Both timestamps default to "now" as timezone-aware UTC datetimes.
    created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
    last_used_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
    use_count: int = Field(default=1, description="Number of times this plan has been used")
    tenant_id: str = Field(..., description="Tenant who owns this cached plan")
    version: str = Field(default="1.0", description="Plan format version")

    # Execution feedback about this plan (optional performance metadata).
    avg_execution_time_ms: Optional[float] = Field(default=None, description="Average execution time")
    success_rate: float = Field(default=1.0, description="Success rate of this plan")
    last_error: Optional[str] = Field(default=None, description="Last execution error if any")


class PlanCache:
    """High-performance plan cache with Redis backend and intelligent invalidation.

    This class provides:
    - Deterministic hash-based plan storage, isolated per tenant
    - Schema change detection and cache invalidation
    - Cost optimization through LLM call reduction
    - Performance metrics and monitoring

    Redis layout:
    - ``plan_cache:plan:<plan_hash>`` holds a JSON-serialized ``CachedPlan``.
      Its TTL (``default_ttl``) is restarted on every store and every hit.
    - ``plan_cache:index:tenant:<tenant_id>`` is the set of plan hashes a
      tenant owns, kept one hour longer than the plans. It is what
      ``invalidate_tenant_cache`` uses to find a tenant's plans.
    """

    def __init__(self, redis_client: redis.Redis,
                 default_ttl_hours: int = 24,
                 max_plan_size_kb: int = 512,
                 enable_metrics: bool = True):
        """Initialize the plan cache.

        Args:
            redis_client: Redis client for storage. Clients returning
                ``bytes`` (the default) and ``decode_responses=True`` clients
                returning ``str`` are both supported.
            default_ttl_hours: TTL for cached plans, in hours.
            max_plan_size_kb: Maximum size of a plan's JSON in KiB; larger
                plans are not cached.
            enable_metrics: Whether to track in-process cache metrics.
        """
        self.redis = redis_client
        self.default_ttl = default_ttl_hours * 3600  # Convert to seconds
        self.max_plan_size = max_plan_size_kb * 1024  # Convert to bytes
        self.metrics = CacheMetrics() if enable_metrics else None

        # Redis key prefixes
        self.PLAN_PREFIX = "plan_cache:plan:"
        self.SCHEMA_PREFIX = "plan_cache:schema:"
        self.METRICS_PREFIX = "plan_cache:metrics:"
        self.INDEX_PREFIX = "plan_cache:index:"

    @staticmethod
    def _normalize_schema(schema_content: str) -> str:
        """Return a canonical form of *schema_content* for hashing.

        Valid JSON is re-serialized with sorted keys and compact separators,
        so key order and whitespace do not change the hash. Anything that is
        not valid JSON is used verbatim.
        """
        try:
            schema_obj = json.loads(schema_content)
        except json.JSONDecodeError:
            return schema_content
        return json.dumps(schema_obj, sort_keys=True, separators=(',', ':'))

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return a Redis reply item as ``str``.

        Replies are ``bytes`` by default and already ``str`` when the client
        was created with ``decode_responses=True``.
        """
        return value.decode('utf-8') if isinstance(value, bytes) else str(value)

    def _compute_query_hash(self, user_query: str, schema_content: str, tenant_id: str) -> str:
        """Compute a deterministic hash for user_query + schema + tenant.

        This hash is the primary cache key. Changes to any component produce
        a different key (and therefore a cache miss), which keeps plans in
        sync with the current schema state.

        Args:
            user_query: User's natural language query (case and surrounding
                whitespace are ignored).
            schema_content: JSON schema content from connectors.
            tenant_id: Tenant identifier for isolation.

        Returns:
            SHA-256 hex digest.
        """
        normalized_query = user_query.strip().lower()
        normalized_schema = self._normalize_schema(schema_content)

        # Encode the parts as a JSON array instead of joining them with ':'.
        # A plain join is ambiguous when a part contains ':' (tenant "a:b" +
        # query "c" and tenant "a" + query "b:c" yield the same string), which
        # could hand one tenant's cached plan to another. Entries stored
        # under the previous key format are simply no longer found and expire
        # via their TTL.
        composite = json.dumps(
            [str(tenant_id), normalized_query, normalized_schema],
            separators=(',', ':'),
        )
        return hashlib.sha256(composite.encode('utf-8')).hexdigest()

    def _compute_schema_hash(self, schema_content: str) -> str:
        """Compute a hash of the schema content only.

        Used for schema change detection and cache invalidation.
        """
        normalized_schema = self._normalize_schema(schema_content)
        return hashlib.sha256(normalized_schema.encode('utf-8')).hexdigest()

    def get_cached_plan(self, user_query: str, schema_content: str, tenant_id: str) -> Tuple[Optional[List[Dict[str, Any]]], CacheStatus]:
        """Retrieve a cached plan if available.

        On a hit, the entry's ``use_count`` and ``last_used_at`` are updated,
        and both the plan's TTL and the tenant index's TTL are restarted.

        Args:
            user_query: User's natural language query
            schema_content: Current schema from connectors
            tenant_id: Tenant identifier

        Returns:
            Tuple of (plan_or_none, cache_status):
            - ``(plan, HIT)`` when a matching plan is cached;
            - ``(None, MISS)`` when nothing is stored under the key;
            - ``(None, STALE)`` when the stored schema hash differs from the
              current one (the entry is deleted);
            - ``(None, ERROR)`` when Redis or deserialization fails
              (an unreadable entry is deleted).
        """
        try:
            plan_hash = self._compute_query_hash(user_query, schema_content, tenant_id)
            cache_key = f"{self.PLAN_PREFIX}{plan_hash}"

            cached_data = self.redis.get(cache_key)
            if not cached_data:
                if self.metrics:
                    self.metrics.record_miss()
                return None, CacheStatus.MISS

            try:
                cached_plan = CachedPlan.model_validate_json(cached_data)
            except Exception as e:
                logger.warning(f"Failed to deserialize cached plan {plan_hash}: {e}")
                # Drop the unreadable entry so the next lookup is a plain MISS
                # and the plan can be regenerated and re-cached, rather than
                # failing on every lookup until the TTL runs out.
                self._invalidate_plan(plan_hash)
                if self.metrics:
                    self.metrics.record_error()
                return None, CacheStatus.ERROR

            # Defensive check: the key already covers the schema, so this only
            # fires if the stored entry disagrees with its own key.
            current_schema_hash = self._compute_schema_hash(schema_content)
            if cached_plan.schema_hash != current_schema_hash:
                logger.info(f"Schema changed for cached plan {plan_hash}, invalidating cache")
                self._invalidate_plan(plan_hash)
                if self.metrics:
                    self.metrics.record_miss()
                return None, CacheStatus.STALE

            # Update usage statistics and write them back. setex restarts the
            # plan's TTL, so the tenant index must be refreshed as well;
            # otherwise the index can expire while the plan is still live, and
            # invalidate_tenant_cache() would no longer find it.
            cached_plan.last_used_at = datetime.now(timezone.utc)
            cached_plan.use_count += 1
            self._store_plan_internal(cached_plan, plan_hash)
            self._update_tenant_index(tenant_id, plan_hash)

            if self.metrics:
                self.metrics.record_hit()

            logger.info(f"Cache HIT for query hash {plan_hash[:12]}... (used {cached_plan.use_count} times)")
            return cached_plan.plan, CacheStatus.HIT

        except Exception as e:
            logger.exception(f"Error during cache lookup: {e}")
            if self.metrics:
                self.metrics.record_error()
            return None, CacheStatus.ERROR

    def store_plan(self, user_query: str, schema_content: str, tenant_id: str,
                   generated_plan: List[Dict[str, Any]]) -> bool:
        """Store a newly generated plan in the cache.

        Args:
            user_query: Original user query
            schema_content: Schema used to generate the plan
            tenant_id: Tenant identifier
            generated_plan: The LLM-generated query plan

        Returns:
            True if stored successfully. False otherwise, including when the
            plan's UTF-8 JSON exceeds ``max_plan_size`` bytes.
        """
        try:
            # Size limit is enforced on the encoded bytes, and reported in bytes.
            plan_size_bytes = len(json.dumps(generated_plan).encode('utf-8'))
            if plan_size_bytes > self.max_plan_size:
                logger.warning(f"Plan too large to cache: {plan_size_bytes} bytes")
                return False

            plan_hash = self._compute_query_hash(user_query, schema_content, tenant_id)
            schema_hash = self._compute_schema_hash(schema_content)

            cached_plan = CachedPlan(
                plan_hash=plan_hash,
                plan=generated_plan,
                schema_hash=schema_hash,
                user_query=user_query,
                tenant_id=tenant_id
            )

            success = self._store_plan_internal(cached_plan, plan_hash)

            if success:
                logger.info(f"Cached new plan {plan_hash[:12]}... for tenant {tenant_id}")
                self._update_tenant_index(tenant_id, plan_hash)

            return success

        except Exception as e:
            logger.exception(f"Failed to store plan in cache: {e}")
            return False

    def _store_plan_internal(self, cached_plan: CachedPlan, plan_hash: str) -> bool:
        """Write *cached_plan* under its key with the default TTL.

        Returns:
            True if Redis accepted the write, False on any storage error
            (which is logged, not raised).
        """
        try:
            cache_key = f"{self.PLAN_PREFIX}{plan_hash}"
            plan_json = cached_plan.model_dump_json()
            return bool(self.redis.setex(cache_key, self.default_ttl, plan_json))
        except Exception as e:
            logger.error(f"Redis storage failed for plan {plan_hash}: {e}")
            return False

    def _update_tenant_index(self, tenant_id: str, plan_hash: str):
        """Add *plan_hash* to the tenant's plan index and restart its TTL.

        Best effort: failures are logged as warnings and otherwise ignored.
        """
        try:
            index_key = f"{self.INDEX_PREFIX}tenant:{tenant_id}"
            self.redis.sadd(index_key, plan_hash)
            # Keep the index one hour longer than the plans it points to.
            self.redis.expire(index_key, self.default_ttl + 3600)
        except Exception as e:
            logger.warning(f"Failed to update tenant index: {e}")

    def _invalidate_plan(self, plan_hash: str):
        """Delete a single plan entry (best effort; failures are logged)."""
        try:
            cache_key = f"{self.PLAN_PREFIX}{plan_hash}"
            self.redis.delete(cache_key)
        except Exception as e:
            logger.warning(f"Failed to invalidate plan {plan_hash}: {e}")

    def invalidate_tenant_cache(self, tenant_id: str) -> int:
        """Invalidate all cached plans for a tenant.

        Useful when tenant configuration changes or schema updates. The
        tenant's index is removed as well.

        Args:
            tenant_id: Tenant to invalidate

        Returns:
            Number of plan entries actually deleted. Index entries whose plan
            had already expired, and the index key itself, are not counted.
            Returns 0 on error.
        """
        try:
            index_key = f"{self.INDEX_PREFIX}tenant:{tenant_id}"
            plan_hashes = self.redis.smembers(index_key)

            if not plan_hashes:
                return 0

            plan_keys = [f"{self.PLAN_PREFIX}{self._as_text(plan_hash)}" for plan_hash in plan_hashes]

            # Delete plans and index separately so the index key is not
            # reported as an invalidated plan.
            deleted_count = self.redis.delete(*plan_keys)
            self.redis.delete(index_key)

            logger.info(f"Invalidated {deleted_count} cached plans for tenant {tenant_id}")
            return deleted_count

        except Exception as e:
            logger.exception(f"Failed to invalidate tenant cache for {tenant_id}: {e}")
            return 0

    def _scan_plan_keys(self) -> List[Any]:
        """Return every plan key, de-duplicated, in scan order.

        Uses incremental SCAN rather than KEYS. KEYS walks the whole keyspace
        in one blocking command and stalls all other clients of the server.
        SCAN may report a key more than once, hence the order-preserving
        de-duplication.
        """
        return list(dict.fromkeys(self.redis.scan_iter(match=f"{self.PLAN_PREFIX}*")))

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get comprehensive cache statistics.

        Memory usage is sampled from the first 100 plan keys only, so
        ``estimated_memory_*`` undercounts once more plans are cached.
        """
        if not self.metrics:
            return {"metrics_disabled": True}

        try:
            cache_keys = self._scan_plan_keys()
            cache_memory_bytes = sum(self.redis.memory_usage(key) or 0 for key in cache_keys[:100])  # Sample first 100
            total_cached_plans = len(cache_keys)
        except Exception:
            cache_memory_bytes = 0
            total_cached_plans = 0

        return {
            "cache_metrics": asdict(self.metrics),
            "redis_stats": {
                "total_cached_plans": total_cached_plans,
                "estimated_memory_bytes": cache_memory_bytes,
                "estimated_memory_mb": round(cache_memory_bytes / (1024 * 1024), 2)
            },
            "cost_optimization": {
                "estimated_savings_usd": round(self.metrics.cost_savings_estimated, 2),
                "hit_rate_percent": round(self.metrics.hit_rate, 1),
                "avg_cost_per_query": self.metrics.average_llm_cost_per_query
            }
        }

    def cleanup_expired_plans(self) -> int:
        """Count plan keys that have expired (Redis removes them automatically).

        This is mainly for monitoring and can be called periodically. The
        keys come from a live scan, so only keys that expire between the scan
        and their TTL check (TTL reply -2, "no such key") are counted. The
        result is usually 0.

        Returns:
            Number of expired plans found (they're auto-deleted by Redis),
            or 0 on error.
        """
        try:
            expired_count = 0
            for key in self._scan_plan_keys():
                if self.redis.ttl(key) == -2:  # Key no longer exists (expired)
                    expired_count += 1

            logger.info(f"Found {expired_count} expired plans during cleanup")
            return expired_count

        except Exception as e:
            logger.exception(f"Error during cache cleanup: {e}")
            return 0


# Process-wide PlanCache instance. It stays None until the application calls
# init_plan_cache(); get_plan_cache() returns it as-is, so callers must handle
# None (check_plan_cache / cache_generated_plan treat None as "no caching").
_plan_cache_instance: Optional[PlanCache] = None


def get_plan_cache() -> Optional[PlanCache]:
    """Return the process-wide plan cache, or None if it was never initialized."""
    installed = _plan_cache_instance
    return installed


def init_plan_cache(redis_client: redis.Redis, **kwargs) -> PlanCache:
    """Create a PlanCache and install it as the process-wide instance.

    Any previously installed instance is replaced.

    Args:
        redis_client: Redis client handed to ``PlanCache``.
        **kwargs: Forwarded unchanged to ``PlanCache.__init__``.

    Returns:
        The newly created and installed cache.
    """
    global _plan_cache_instance
    fresh_cache = PlanCache(redis_client, **kwargs)
    _plan_cache_instance = fresh_cache
    return fresh_cache


# Utility functions for easy integration

def check_plan_cache(user_query: str, schema_content: str, tenant_id: str) -> Tuple[Optional[List[Dict[str, Any]]], bool]:
    """Look up a cached plan and collapse the lookup status to a boolean.

    Args:
        user_query: User's natural language query
        schema_content: Current schema JSON
        tenant_id: Tenant identifier

    Returns:
        ``(plan, True)`` on a cache hit. Otherwise ``(None, False)``, which
        covers miss, stale, error, and an uninitialized cache.
    """
    cache = get_plan_cache()
    if cache:
        plan, status = cache.get_cached_plan(user_query, schema_content, tenant_id)
        return plan, status == CacheStatus.HIT
    return None, False


def cache_generated_plan(user_query: str, schema_content: str, tenant_id: str, plan: List[Dict[str, Any]]) -> bool:
    """Store a freshly generated plan in the process-wide cache, if one exists.

    Args:
        user_query: Original user query
        schema_content: Schema used for generation
        tenant_id: Tenant identifier
        plan: Generated query plan

    Returns:
        Whatever ``PlanCache.store_plan`` returns, or False when no cache has
        been initialized.
    """
    cache = get_plan_cache()
    return cache.store_plan(user_query, schema_content, tenant_id, plan) if cache else False