# wanderlust_ai/core/intelligent_caching.py
"""
Intelligent Multi-Level Caching System
This module implements sophisticated caching strategies across multiple levels:
- API Response Caching
- Agent Decision Caching
- Orchestrator Plan Caching
- Result Aggregation Caching
"""
import hashlib
import json
import time
import asyncio
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Callable, Union, Tuple
from dataclasses import dataclass, field
from collections import defaultdict, OrderedDict
from enum import Enum
import pickle
import threading
from functools import wraps
from pydantic import BaseModel, Field
class CacheLevel(str, Enum):
    """Different levels of caching in the system.

    str-valued so members serialize cleanly (e.g. via json.dumps in key
    generation and pydantic models).
    """
    API_RESPONSE = "api_response"              # raw external API responses
    AGENT_DECISION = "agent_decision"          # agent decision caching
    ORCHESTRATOR_PLAN = "orchestrator_plan"    # orchestrator plan caching
    RESULT_AGGREGATION = "result_aggregation"  # aggregated-result caching
    USER_PREFERENCE = "user_preference"        # user preference caching
class CacheStrategy(str, Enum):
    """Different caching strategies.

    NOTE(review): only TTL expiry and LRU eviction are visibly implemented
    by IntelligentCache in this module; LFU/ADAPTIVE/INVALIDATION are
    declared but not referenced here — confirm intended use elsewhere.
    """
    LRU = "lru"  # Least Recently Used
    LFU = "lfu"  # Least Frequently Used
    TTL = "ttl"  # Time To Live
    ADAPTIVE = "adaptive"  # Adaptive based on access patterns
    INVALIDATION = "invalidation"  # Manual invalidation
@dataclass
class CacheEntry:
    """Individual cache entry with metadata.

    Holds the cached value plus the bookkeeping (timestamps, access count,
    TTL, size) that IntelligentCache uses for expiry, LRU eviction and
    statistics.
    """
    key: str                              # digest key from _generate_key
    value: Any                            # cached payload
    created_at: datetime                  # naive local time at insertion
    last_accessed: datetime               # refreshed by access()
    access_count: int = 0                 # number of cache hits on this entry
    ttl_seconds: Optional[int] = None     # None means the entry never expires
    level: CacheLevel = CacheLevel.API_RESPONSE
    tags: List[str] = field(default_factory=list)  # for group invalidation
    size_bytes: int = 0                   # filled in via calculate_size()
    hit_rate: float = 0.0                 # NOTE(review): never updated here

    def is_expired(self) -> bool:
        """Return True if the entry's TTL has elapsed (False when no TTL)."""
        if self.ttl_seconds is None:
            return False
        return datetime.now() > self.created_at + timedelta(seconds=self.ttl_seconds)

    def access(self) -> None:
        """Record a cache access: touch the timestamp and bump the counter."""
        self.last_accessed = datetime.now()
        self.access_count += 1

    def calculate_size(self) -> int:
        """Best-effort size of the cached value in bytes (0 if unsizable).

        Strings are measured as UTF-8; dict/list via their JSON encoding;
        everything else (including containers holding non-JSON values) via
        pickle.
        """
        try:
            if isinstance(self.value, str):
                return len(self.value.encode('utf-8'))
            if isinstance(self.value, (dict, list)):
                try:
                    return len(json.dumps(self.value).encode('utf-8'))
                except (TypeError, ValueError):
                    # Bug fix: containers with non-JSON-serializable items
                    # previously fell into a bare `except:` and reported 0;
                    # now they fall through to the pickle path below.
                    pass
            return len(pickle.dumps(self.value))
        except Exception:
            # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit
            # still propagate. Truly unsizable values report 0 bytes.
            return 0
class CacheStats(BaseModel):
    """Cache performance statistics for one cache level.

    Produced by IntelligentCache.get_stats(); consumed (and model_dump()ed)
    by MultiLevelCacheManager.get_aggregated_stats().
    """
    level: CacheLevel                # which cache level these stats describe
    total_entries: int               # live entries at this level
    total_size_mb: float             # sum of entry size_bytes, in MiB
    hit_rate: float                  # fraction of entries accessed at least once
    miss_rate: float                 # 1.0 - hit_rate
    eviction_count: int              # cache-wide removal counter (not per-level)
    average_access_time_ms: float    # always 0.0 — no timing data collected
    most_accessed_keys: List[str]    # top 5 keys by access_count
class IntelligentCache:
    """
    Intelligent multi-level cache with adaptive strategies.

    Features:
    - Multiple cache levels with different strategies
    - Automatic TTL expiry and size-based LRU eviction
    - Access pattern statistics
    - Cache warming and prefetching
    - Tag/level/pattern based invalidation

    All public methods guard shared state with a reentrant lock, so the
    cache is safe to use from multiple threads.
    """

    def __init__(self, max_size_mb: int = 100, default_ttl_seconds: int = 3600):
        """
        Args:
            max_size_mb: Approximate cap on the total size of cached values.
            default_ttl_seconds: TTL applied when set() is called without one.
        """
        self.max_size_bytes = max_size_mb * 1024 * 1024
        self.default_ttl_seconds = default_ttl_seconds

        # Thread-safe storage. RLock because helpers run while the lock is
        # already held by the public entry points.
        self._lock = threading.RLock()
        self._entries: Dict[str, CacheEntry] = {}
        # Secondary per-level index used by get_stats() and clear(level).
        self._level_entries: Dict[CacheLevel, Dict[str, CacheEntry]] = defaultdict(dict)
        # Recency ordering for LRU eviction; least recently used key first.
        self._access_order: OrderedDict = OrderedDict()

        # Statistics counters.
        self._hits = 0
        self._misses = 0
        # NOTE(review): bumped on *every* removal (expiry, invalidation,
        # overwrite), not only capacity evictions.
        self._evictions = 0
        self._total_size = 0

        # Background cleanup task (only starts if an event loop is running).
        self._cleanup_task: Optional[asyncio.Task] = None
        self._start_cleanup_task()

    def _start_cleanup_task(self):
        """Start the periodic-cleanup task when an event loop is available."""
        if self._cleanup_task is None or self._cleanup_task.done():
            try:
                loop = asyncio.get_running_loop()
                self._cleanup_task = loop.create_task(self._periodic_cleanup())
            except RuntimeError:
                # No event loop running. NOTE(review): nothing re-invokes
                # this later, so a cache built outside a loop relies solely
                # on the lazy expiry check inside get().
                pass

    async def _periodic_cleanup(self):
        """Sweep expired entries once a minute until cancelled."""
        while True:
            try:
                await asyncio.sleep(60)  # Cleanup every minute
                await self._cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception:
                # Deliberate: keep the background task alive even if a
                # single sweep fails.
                pass

    async def _cleanup_expired(self):
        """Remove every entry whose TTL has elapsed."""
        with self._lock:
            # Collect first, delete after: never mutate while iterating.
            expired_keys = [key for key, entry in self._entries.items()
                            if entry.is_expired()]
            for key in expired_keys:
                self._remove_entry(key)

    def _generate_key(self, level: CacheLevel, identifier: str,
                      params: Optional[Dict[str, Any]] = None) -> str:
        """Derive a deterministic 16-hex-char key for a lookup.

        params must be JSON-serializable; sort_keys makes the digest
        independent of dict insertion order.
        """
        key_data = {
            "level": level.value,
            "identifier": identifier,
            "params": params or {}
        }
        key_string = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_string.encode()).hexdigest()[:16]

    def get(self, level: CacheLevel, identifier: str,
            params: Optional[Dict[str, Any]] = None) -> Optional[Any]:
        """
        Retrieve value from cache.

        Expired entries are removed lazily here and counted as misses.

        Returns:
            Cached value if found and not expired, None otherwise.
        """
        key = self._generate_key(level, identifier, params)
        with self._lock:
            if key not in self._entries:
                self._misses += 1
                return None
            entry = self._entries[key]
            if entry.is_expired():
                self._remove_entry(key)
                self._misses += 1
                return None
            entry.access()
            self._hits += 1
            # Mark as most recently used for LRU eviction.
            if key in self._access_order:
                self._access_order.move_to_end(key)
            return entry.value

    def set(self, level: CacheLevel, identifier: str, value: Any,
            ttl_seconds: Optional[int] = None, tags: Optional[List[str]] = None,
            params: Optional[Dict[str, Any]] = None) -> None:
        """
        Store value in cache, evicting LRU entries if space is needed.

        Args:
            level: Cache level
            identifier: Unique identifier for the cache entry
            value: Value to cache
            ttl_seconds: Time to live (uses default if None)
            tags: Tags for categorization and invalidation
            params: Parameters used in key generation
        """
        key = self._generate_key(level, identifier, params)
        # Bug fix: `ttl_seconds or default` silently replaced an explicit
        # ttl_seconds=0 with the default; only None should mean "default".
        ttl = self.default_ttl_seconds if ttl_seconds is None else ttl_seconds
        entry = CacheEntry(
            key=key,
            value=value,
            created_at=datetime.now(),
            last_accessed=datetime.now(),
            ttl_seconds=ttl,
            level=level,
            tags=tags or []
        )
        # Sizing only reads the new entry, so it stays outside the lock.
        entry.size_bytes = entry.calculate_size()
        with self._lock:
            # Replace any existing entry for this key.
            if key in self._entries:
                self._remove_entry(key)
            # Evict LRU entries until the new one fits.
            self._ensure_space(entry.size_bytes)
            # Register the entry in all indexes.
            self._entries[key] = entry
            self._level_entries[level][key] = entry
            self._access_order[key] = entry
            self._total_size += entry.size_bytes

    def _remove_entry(self, key: str) -> None:
        """Remove one entry from every index and update counters."""
        if key not in self._entries:
            return
        entry = self._entries[key]
        del self._entries[key]
        if key in self._level_entries[entry.level]:
            del self._level_entries[entry.level][key]
        if key in self._access_order:
            del self._access_order[key]
        self._total_size -= entry.size_bytes
        self._evictions += 1

    def _ensure_space(self, required_bytes: int) -> None:
        """Evict least-recently-used entries until required_bytes fits.

        If a single value exceeds max_size_bytes the cache is emptied and
        the caller still inserts the oversized entry.
        """
        while (self._total_size + required_bytes > self.max_size_bytes and
               self._entries):
            # Front of _access_order is the least recently used key.
            oldest_key = next(iter(self._access_order))
            self._remove_entry(oldest_key)

    def invalidate(self, level: Optional[CacheLevel] = None,
                   tags: Optional[List[str]] = None,
                   pattern: Optional[str] = None) -> int:
        """
        Invalidate cache entries matching ANY of the given criteria.

        Args:
            level: Invalidate entries at specific level
            tags: Invalidate entries carrying any of these tags
            pattern: Invalidate entries whose key contains this substring

        Returns:
            Number of entries invalidated (0 when all criteria are None).
        """
        invalidated = 0
        with self._lock:
            keys_to_remove = []
            for key, entry in self._entries.items():
                should_invalidate = False
                if level and entry.level == level:
                    should_invalidate = True
                if tags and any(tag in entry.tags for tag in tags):
                    should_invalidate = True
                if pattern and pattern in key:
                    should_invalidate = True
                if should_invalidate:
                    keys_to_remove.append(key)
            for key in keys_to_remove:
                self._remove_entry(key)
                invalidated += 1
        return invalidated

    def warm_cache(self, level: CacheLevel, warm_data: Dict[str, Any],
                   ttl_seconds: Optional[int] = None) -> None:
        """
        Warm the cache with predefined data.

        Args:
            level: Cache level to warm
            warm_data: Dictionary of identifier -> value mappings
            ttl_seconds: TTL for warmed entries (cache default if None)
        """
        for identifier, value in warm_data.items():
            self.set(level, identifier, value, ttl_seconds)

    def get_stats(self) -> Dict[CacheLevel, CacheStats]:
        """Build a CacheStats snapshot for every cache level.

        NOTE(review): the per-level "hit rate" reported here is the fraction
        of live entries accessed at least once — not hits/lookups — and
        eviction_count is the cache-wide removal counter.
        """
        with self._lock:
            stats = {}
            for level in CacheLevel:
                level_entries = self._level_entries[level]
                level_size = sum(entry.size_bytes for entry in level_entries.values())
                # Fraction of live entries that have been accessed.
                level_hits = sum(1 for entry in level_entries.values()
                                 if entry.access_count > 0)
                level_total = len(level_entries)
                level_hit_rate = level_hits / level_total if level_total > 0 else 0
                # Top 5 keys by access count.
                most_accessed = sorted(
                    [(key, entry.access_count) for key, entry in level_entries.items()],
                    key=lambda x: x[1], reverse=True
                )[:5]
                stats[level] = CacheStats(
                    level=level,
                    total_entries=level_total,
                    total_size_mb=level_size / 1024 / 1024,
                    hit_rate=level_hit_rate,
                    miss_rate=1.0 - level_hit_rate,
                    eviction_count=self._evictions,
                    average_access_time_ms=0.0,  # Would need timing data
                    most_accessed_keys=[key for key, _ in most_accessed]
                )
            return stats

    def clear(self, level: Optional[CacheLevel] = None) -> None:
        """Clear one level's entries, or the entire cache when level is None."""
        with self._lock:
            if level:
                # Clear specific level via _remove_entry so size/indexes
                # stay consistent.
                if level in self._level_entries:
                    for key in list(self._level_entries[level].keys()):
                        self._remove_entry(key)
            else:
                # Clear everything wholesale.
                self._entries.clear()
                self._level_entries.clear()
                self._access_order.clear()
                self._total_size = 0

    def prefetch(self, level: CacheLevel, identifier: str,
                 fetch_func: Callable, *args, **kwargs) -> None:
        """
        Fetch a value and store it in the cache, best-effort.

        Args:
            level: Cache level
            identifier: Cache identifier
            fetch_func: Function that produces the value to cache
            *args, **kwargs: Arguments forwarded to fetch_func

        Fetch failures are swallowed on purpose: prefetching is an
        optimization and must never break the caller.
        """
        try:
            value = fetch_func(*args, **kwargs)
            self.set(level, identifier, value)
        except Exception:
            # Silently fail prefetch
            pass
class MultiLevelCacheManager:
    """
    Coordinates one IntelligentCache per cache level.

    Covers API responses, agent decisions, orchestrator plans and result
    aggregations, and cascades tag-based invalidation from a written level
    into the levels that depend on it.
    """

    def __init__(self):
        # One independent 50 MB / 1-hour cache for every level.
        self._caches: Dict[CacheLevel, IntelligentCache] = {
            lvl: IntelligentCache(max_size_mb=50, default_ttl_seconds=3600)
            for lvl in CacheLevel
        }
        # For each level, the downstream levels whose cached entries become
        # stale when that level is written.
        self._invalidation_rules: Dict[CacheLevel, List[CacheLevel]] = {
            CacheLevel.API_RESPONSE: [CacheLevel.AGENT_DECISION, CacheLevel.ORCHESTRATOR_PLAN],
            CacheLevel.AGENT_DECISION: [CacheLevel.ORCHESTRATOR_PLAN],
            CacheLevel.ORCHESTRATOR_PLAN: [CacheLevel.RESULT_AGGREGATION],
        }

    def get(self, level: CacheLevel, identifier: str,
            params: Optional[Dict[str, Any]] = None) -> Optional[Any]:
        """Look up a value in the cache serving `level`."""
        return self._caches[level].get(level, identifier, params)

    def set(self, level: CacheLevel, identifier: str, value: Any,
            ttl_seconds: Optional[int] = None, tags: Optional[List[str]] = None,
            params: Optional[Dict[str, Any]] = None) -> None:
        """Store a value, then cascade tag invalidation to dependent levels."""
        self._caches[level].set(level, identifier, value, ttl_seconds, tags, params)
        # With tags=None the downstream invalidate() matches nothing, so the
        # cascade is a no-op in that case.
        for downstream in self._invalidation_rules.get(level, []):
            self._caches[downstream].invalidate(tags=tags)

    def get_aggregated_stats(self) -> Dict[str, Any]:
        """Roll per-level statistics up into one overall + by-level report."""
        per_level: Dict[str, Any] = {}
        entries_total = 0
        size_total_mb = 0.0
        hit_volume = 0
        miss_volume = 0
        for lvl, cache in self._caches.items():
            level_stats = cache.get_stats()[lvl]
            per_level[lvl.value] = level_stats.model_dump()
            entries_total += level_stats.total_entries
            size_total_mb += level_stats.total_size_mb
            # Approximate hit/miss volume by weighting rates with entry counts.
            hit_volume += int(level_stats.hit_rate * level_stats.total_entries)
            miss_volume += int(level_stats.miss_rate * level_stats.total_entries)
        lookups = hit_volume + miss_volume
        overall_rate = hit_volume / lookups if lookups > 0 else 0
        return {
            "overall": {
                "total_entries": entries_total,
                "total_size_mb": size_total_mb,
                "overall_hit_rate": overall_rate,
            },
            "by_level": per_level,
        }
# Global cache manager singleton, created lazily by get_global_cache().
_global_cache_manager: Optional[MultiLevelCacheManager] = None

def get_global_cache() -> MultiLevelCacheManager:
    """Get the global cache manager instance.

    Builds the MultiLevelCacheManager on first call; later calls return the
    same object. NOTE(review): creation is not lock-guarded, so concurrent
    first calls from multiple threads could race — confirm single-threaded
    initialization.
    """
    global _global_cache_manager
    if _global_cache_manager is None:
        _global_cache_manager = MultiLevelCacheManager()
    return _global_cache_manager
def cache_result(level: CacheLevel, identifier: Optional[str] = None,
                 ttl_seconds: Optional[int] = None, tags: Optional[List[str]] = None):
    """
    Decorator for caching function results (works on sync and async funcs).

    Args:
        level: Cache level to store results under.
        identifier: Logical name for the cached operation. Backward-compatible
            change: now optional, defaulting to the wrapped function's
            __qualname__ so distinct functions get distinct cache keys.
        ttl_seconds: Time to live for cached results (cache default if None).
        tags: Tags attached to entries for group invalidation.

    Usage:
        @cache_result(CacheLevel.API_RESPONSE, "flight_search")
        async def search_flights(self, origin, destination):
            ...
    """
    def decorator(func: Callable) -> Callable:
        cache_id = identifier if identifier is not None else func.__qualname__

        def _key_params(args, kwargs) -> Dict[str, Any]:
            # Bug fix: the cache hashes params with json.dumps, which raises
            # TypeError for arbitrary objects — e.g. the `self` of a decorated
            # method, exactly the documented use case. Render every argument
            # with repr() so any call produces a JSON-safe, stable key.
            return {
                "args": [repr(a) for a in args],
                "kwargs": {k: repr(v) for k, v in sorted(kwargs.items())},
            }

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            cache = get_global_cache()
            cache_params = _key_params(args, kwargs)
            cached_result = cache.get(level, cache_id, cache_params)
            if cached_result is not None:
                return cached_result
            # Miss (or the cached value was None): execute and store.
            result = await func(*args, **kwargs)
            cache.set(level, cache_id, result, ttl_seconds, tags, cache_params)
            return result

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            cache = get_global_cache()
            cache_params = _key_params(args, kwargs)
            cached_result = cache.get(level, cache_id, cache_params)
            if cached_result is not None:
                return cached_result
            result = func(*args, **kwargs)
            cache.set(level, cache_id, result, ttl_seconds, tags, cache_params)
            return result

        # Pick the wrapper matching the wrapped function's calling style.
        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
    return decorator
def cache_api_response(ttl_seconds: int = 1800):
    """Decorator for caching API responses (default TTL: 30 minutes).

    Bug fix: cache_result's `identifier` is a required positional argument,
    and the original call omitted it, raising TypeError at decoration time.
    A level-wide identifier is now passed; call arguments still distinguish
    the cache keys.
    """
    return cache_result(CacheLevel.API_RESPONSE, "api_response",
                        ttl_seconds=ttl_seconds, tags=["api"])
def cache_agent_decision(ttl_seconds: int = 3600):
    """Decorator for caching agent decisions (default TTL: 1 hour).

    Bug fix: cache_result's `identifier` is a required positional argument,
    and the original call omitted it, raising TypeError at decoration time.
    A level-wide identifier is now passed; call arguments still distinguish
    the cache keys.
    """
    return cache_result(CacheLevel.AGENT_DECISION, "agent_decision",
                        ttl_seconds=ttl_seconds, tags=["agent"])
def cache_orchestrator_plan(ttl_seconds: int = 7200):
    """Decorator for caching orchestrator plans (default TTL: 2 hours).

    Bug fix: cache_result's `identifier` is a required positional argument,
    and the original call omitted it, raising TypeError at decoration time.
    A level-wide identifier is now passed; call arguments still distinguish
    the cache keys.
    """
    return cache_result(CacheLevel.ORCHESTRATOR_PLAN, "orchestrator_plan",
                        ttl_seconds=ttl_seconds, tags=["orchestrator"])