#!/usr/bin/env python3
"""
Latency Optimization Framework

Latency-reduction optimizations for the RAG pipeline, including:
- Response caching with TTL
- Connection pooling for API calls
- Query preprocessing and deduplication
- Parallel processing where possible
- Embedding caching
- Context compression
"""
import hashlib
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import lru_cache, wraps
from typing import Any, Dict, List, Optional, Tuple

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

logger = logging.getLogger(__name__)

@dataclass
class LatencyConfig:
    """Configuration for latency optimizations."""

    # Caching configuration
    enable_response_cache: bool = True
    response_cache_ttl: int = 300  # seconds (5 minutes)
    response_cache_size: int = 100
    enable_embedding_cache: bool = True
    embedding_cache_size: int = 500
    enable_query_cache: bool = True
    query_cache_size: int = 200

    # Connection pooling
    enable_connection_pooling: bool = True
    pool_size: int = 10
    pool_maxsize: int = 20
    pool_block: bool = False

    # Request optimization
    connection_timeout: float = 5.0
    read_timeout: float = 15.0
    max_retries: int = 3
    backoff_factor: float = 0.3

    # Parallel processing
    enable_parallel_processing: bool = True
    max_workers: int = 4

    # Context optimization
    enable_context_compression: bool = True
    max_context_tokens: int = 2000  # treated as a character budget by ContextCompressor
    compression_ratio: float = 0.7

    # Query preprocessing
    enable_query_preprocessing: bool = True
    min_query_length: int = 3
    max_query_length: int = 500

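# A minimal usage sketch: because LatencyConfig is a dataclass, individual
# knobs can be overridden by keyword at construction time while the rest keep
# their defaults. The values below are illustrative, not recommendations.
def _demo_config() -> None:
    cfg = LatencyConfig(response_cache_ttl=60, max_workers=8)
    assert cfg.response_cache_ttl == 60
    assert cfg.enable_response_cache  # untouched fields keep their defaults
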
class CacheManager:
    """Thread-safe cache manager with TTL support and LRU eviction."""

    def __init__(self, max_size: int = 100, default_ttl: int = 300):
        self.max_size = max_size
        self.default_ttl = default_ttl
        self._cache: Dict[str, Dict[str, Any]] = {}
        self._access_times: Dict[str, float] = {}
        self._lock = threading.RLock()  # guards all cache state
        self._hits = 0
        self._misses = 0

    def _cleanup_expired(self) -> None:
        """Remove expired cache entries. Caller must hold the lock."""
        current_time = time.time()
        expired_keys = [key for key, data in self._cache.items() if current_time > data.get("expires_at", 0)]
        for key in expired_keys:
            self._cache.pop(key, None)
            self._access_times.pop(key, None)

    def _evict_lru(self) -> None:
        """Evict least recently used items until there is room. Caller must hold the lock."""
        while len(self._cache) >= self.max_size and self._access_times:
            lru_key = min(self._access_times, key=self._access_times.get)
            self._cache.pop(lru_key, None)
            self._access_times.pop(lru_key, None)

    def get(self, key: str) -> Optional[Any]:
        """Get an item from the cache, or None if absent or expired."""
        with self._lock:
            self._cleanup_expired()
            data = self._cache.get(key)
            if data is not None:
                self._access_times[key] = time.time()
                self._hits += 1
                return data["value"]
            self._misses += 1
            return None

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Set an item in the cache with a TTL (seconds)."""
        with self._lock:
            self._cleanup_expired()
            self._evict_lru()
            expires_at = time.time() + (ttl or self.default_ttl)
            self._cache[key] = {"value": value, "expires_at": expires_at}
            self._access_times[key] = time.time()

    def clear(self) -> None:
        """Clear all cache entries."""
        with self._lock:
            self._cache.clear()
            self._access_times.clear()

    def stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        with self._lock:
            self._cleanup_expired()
            lookups = self._hits + self._misses
            return {
                "size": len(self._cache),
                "max_size": self.max_size,
                "hit_ratio": self._hits / lookups if lookups else 0.0,
                "default_ttl": self.default_ttl,
            }

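# A minimal sketch of CacheManager behaviour under the semantics above
# (TTL in seconds, LRU eviction at capacity). Inert on import.
def _demo_cache_manager() -> None:
    cache = CacheManager(max_size=2, default_ttl=1)
    cache.set("a", 1)
    cache.set("b", 2)
    assert cache.get("a") == 1     # refreshes "a"'s access time
    cache.set("c", 3)              # capacity reached: evicts the LRU key ("b")
    assert cache.get("b") is None
    time.sleep(1.1)
    assert cache.get("c") is None  # expired after the 1-second TTL
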
class ConnectionPoolManager:
    """HTTP connection pool manager for optimized API calls."""

    def __init__(self, config: LatencyConfig):
        self.config = config
        self._sessions: Dict[str, requests.Session] = {}

    def get_session(self, base_url: str) -> requests.Session:
        """Get or create a pooled session for the given base URL."""
        if base_url not in self._sessions:
            session = requests.Session()
            if self.config.enable_connection_pooling:
                # Configure the retry strategy. `allowed_methods` replaces the
                # `method_whitelist` parameter removed in urllib3 2.0.
                retry_strategy = Retry(
                    total=self.config.max_retries,
                    status_forcelist=[429, 500, 502, 503, 504],
                    allowed_methods=["HEAD", "GET", "POST"],
                    backoff_factor=self.config.backoff_factor,
                )
                # Configure the adapter with connection pooling
                adapter = HTTPAdapter(
                    pool_connections=self.config.pool_size,
                    pool_maxsize=self.config.pool_maxsize,
                    pool_block=self.config.pool_block,
                    max_retries=retry_strategy,
                )
                session.mount("http://", adapter)
                session.mount("https://", adapter)
            self._sessions[base_url] = session
        return self._sessions[base_url]

    def close_all(self) -> None:
        """Close all sessions."""
        for session in self._sessions.values():
            session.close()
        self._sessions.clear()

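# A minimal sketch of reusing one pooled session per host, applying the
# config's connect/read timeouts per request. The URL is a placeholder, and
# this performs a real HTTP request, so it is not run by the __main__ block
# at the bottom of the module.
def _demo_pooled_session(config: Optional[LatencyConfig] = None) -> None:
    config = config or LatencyConfig()
    pool = ConnectionPoolManager(config)
    session = pool.get_session("https://example.com")
    try:
        session.get(
            "https://example.com/health",
            timeout=(config.connection_timeout, config.read_timeout),
        )
    finally:
        pool.close_all()
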
class QueryPreprocessor:
    """Query preprocessing for latency optimization."""

    def __init__(self, config: LatencyConfig):
        self.config = config
        self._query_cache = CacheManager(
            max_size=config.query_cache_size,
            default_ttl=600,  # 10 minutes for query preprocessing
        )

    def preprocess_query(self, query: str) -> Tuple[str, Dict[str, Any]]:
        """
        Preprocess a query for optimization.

        Returns:
            Tuple of (processed_query, metadata)
        """
        if not self.config.enable_query_preprocessing:
            return query, {}

        # Check the cache first
        query_hash = self._hash_query(query)
        cached = self._query_cache.get(query_hash)
        if cached:
            return cached["processed_query"], cached["metadata"]

        # Preprocess the query
        processed_query = self._clean_query(query)
        metadata = {
            "original_length": len(query),
            "processed_length": len(processed_query),
            "hash": query_hash,
            "timestamp": time.time(),
        }

        # Cache the result
        self._query_cache.set(query_hash, {"processed_query": processed_query, "metadata": metadata})
        return processed_query, metadata

    def _clean_query(self, query: str) -> str:
        """Clean and normalize a query."""
        cleaned = query.strip()

        # Length validation: queries below the minimum are returned as-is
        if len(cleaned) < self.config.min_query_length:
            return cleaned
        if len(cleaned) > self.config.max_query_length:
            cleaned = cleaned[: self.config.max_query_length]

        # Collapse excessive whitespace and normalize case
        cleaned = " ".join(cleaned.split())
        return cleaned.lower()

    def _hash_query(self, query: str) -> str:
        """Generate a hash for query caching (non-cryptographic use of MD5)."""
        return hashlib.md5(query.encode()).hexdigest()

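# A minimal sketch of the deduplication effect: two queries that differ only
# in whitespace and case normalize to the same string, so downstream caches
# keyed on the processed query treat them as one request.
def _demo_query_preprocessing() -> None:
    pre = QueryPreprocessor(LatencyConfig())
    q1, _ = pre.preprocess_query("  How many   PTO days? ")
    q2, _ = pre.preprocess_query("how many pto days?")
    assert q1 == q2 == "how many pto days?"
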
class ContextCompressor:
    """Context compression for reduced token usage and faster processing."""

    def __init__(self, config: LatencyConfig):
        self.config = config

    def compress_context(self, context: str, target_length: Optional[int] = None) -> str:
        """
        Compress context while preserving important information.

        Args:
            context: Original context string
            target_length: Target length in characters. Defaults to
                config.max_context_tokens, which this implementation treats
                as a character budget rather than a true token count.

        Returns:
            Compressed context string
        """
        if not self.config.enable_context_compression:
            return context

        target_length = target_length or self.config.max_context_tokens
        if len(context) <= target_length:
            return context

        # Simple extractive compression
        compressed = self._extract_key_sentences(context, target_length)
        logger.debug(f"Context compressed from {len(context)} to {len(compressed)} chars")
        return compressed

    def _extract_key_sentences(self, text: str, target_length: int) -> str:
        """Extract the highest-scoring sentences that fit within the target length."""
        sentences = text.split(".")

        # Prioritize sentences containing key policy terms
        key_terms = [
            "policy",
            "accrual",
            "eligibility",
            "days",
            "hours",
            "employee",
            "vacation",
            "pto",
            "sick",
            "leave",
        ]

        # Score each sentence by the number of key terms it contains
        scored_sentences = []
        for sentence in sentences:
            sentence = sentence.strip()
            if len(sentence) < 10:  # Skip very short fragments
                continue
            score = sum(1 for term in key_terms if term in sentence.lower())
            scored_sentences.append((score, sentence))

        # Sort by score, descending (key avoids comparing the sentence strings)
        scored_sentences.sort(key=lambda item: item[0], reverse=True)

        # Greedily build the compressed context from the top-scoring sentences
        compressed_parts = []
        current_length = 0
        for _score, sentence in scored_sentences:
            sentence_length = len(sentence) + 2  # +2 for the '. ' separator
            if current_length + sentence_length <= target_length:
                compressed_parts.append(sentence)
                current_length += sentence_length
            else:
                break

        return ". ".join(compressed_parts) + "." if compressed_parts else text[:target_length]

class LatencyOptimizer:
    """Main latency optimization coordinator."""

    def __init__(self, config: Optional[LatencyConfig] = None):
        self.config = config or LatencyConfig()

        # Initialize components
        self.response_cache = (
            CacheManager(max_size=self.config.response_cache_size, default_ttl=self.config.response_cache_ttl)
            if self.config.enable_response_cache
            else None
        )
        self.embedding_cache = (
            CacheManager(max_size=self.config.embedding_cache_size, default_ttl=1800)  # 30 minutes for embeddings
            if self.config.enable_embedding_cache
            else None
        )
        self.connection_pool = ConnectionPoolManager(self.config)
        self.query_preprocessor = QueryPreprocessor(self.config)
        self.context_compressor = ContextCompressor(self.config)

        # Thread pool for parallel processing
        self.thread_pool = (
            ThreadPoolExecutor(max_workers=self.config.max_workers) if self.config.enable_parallel_processing else None
        )

        self._metrics = {"cache_hits": 0, "cache_misses": 0, "parallel_tasks": 0, "compression_savings": 0}
        logger.info("LatencyOptimizer initialized")

    def optimize_response_generation(self, query: str, context: str) -> Dict[str, Any]:
        """
        Optimize the complete response generation pipeline.

        Args:
            query: User query
            context: Retrieved context

        Returns:
            Optimization metadata and processed inputs, including the cached
            response (under "response") when there is a cache hit
        """
        start_time = time.time()

        # Preprocess the query
        processed_query, query_metadata = self.query_preprocessor.preprocess_query(query)

        # Compress the context if needed
        original_context_length = len(context)
        compressed_context = self.context_compressor.compress_context(context)
        compression_savings = original_context_length - len(compressed_context)
        if compression_savings > 0:
            self._metrics["compression_savings"] += compression_savings

        # Check the response cache
        cache_key = self._generate_cache_key(processed_query, compressed_context)
        cached_response = None
        if self.response_cache:
            cached_response = self.response_cache.get(cache_key)
            if cached_response:
                self._metrics["cache_hits"] += 1
                logger.debug(f"Response cache hit for query: {processed_query[:50]}...")
            else:
                self._metrics["cache_misses"] += 1

        optimization_metadata = {
            "processing_time": time.time() - start_time,
            "query_metadata": query_metadata,
            "context_compression": {
                "original_length": original_context_length,
                "compressed_length": len(compressed_context),
                "savings": compression_savings,
            },
            "cache_key": cache_key,
            "cached_response": cached_response is not None,
            "response": cached_response,  # the cached value on a hit, else None
            "processed_query": processed_query,
            "compressed_context": compressed_context,
        }
        return optimization_metadata

    def cache_response(self, cache_key: str, response: Any) -> None:
        """Cache a response for future use."""
        if self.response_cache:
            self.response_cache.set(cache_key, response)

    def optimize_embedding_generation(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]:
        """
        Optimize embedding generation with caching.

        Args:
            texts: List of texts to embed

        Returns:
            Tuple of (embeddings, optimization_metadata)
        """
        if not texts:
            return [], {"cache_hits": 0, "cache_misses": 0, "total_texts": 0}

        cache_hits = 0
        cache_misses = 0
        embeddings_by_index: Dict[int, List[float]] = {}
        uncached_texts: List[Tuple[int, str, str]] = []

        # Check the cache for each text
        for i, text in enumerate(texts):
            text_hash = hashlib.md5(text.encode()).hexdigest()
            cached = self.embedding_cache.get(text_hash) if self.embedding_cache else None
            if cached:
                embeddings_by_index[i] = cached
                cache_hits += 1
            else:
                uncached_texts.append((i, text, text_hash))
                cache_misses += 1

        # Generate embeddings for uncached texts. This is a placeholder;
        # a real implementation would call the embedding service here.
        for i, text, text_hash in uncached_texts:
            embedding = [0.0] * 1024  # Placeholder embedding
            embeddings_by_index[i] = embedding
            if self.embedding_cache:
                self.embedding_cache.set(text_hash, embedding)

        # Reconstruct the embeddings in the original order
        embeddings = [embeddings_by_index[i] for i in range(len(texts))]

        optimization_metadata = {"cache_hits": cache_hits, "cache_misses": cache_misses, "total_texts": len(texts)}
        self._metrics["cache_hits"] += cache_hits
        self._metrics["cache_misses"] += cache_misses
        return embeddings, optimization_metadata

    def optimize_parallel_search(self, queries: List[str]) -> List[Dict[str, Any]]:
        """
        Run searches in parallel where possible.

        Args:
            queries: List of search queries

        Returns:
            List of search results, in the same order as the input queries
        """
        if not self.config.enable_parallel_processing or not self.thread_pool:
            # Sequential processing fallback
            return [self._mock_search(query) for query in queries]

        # Parallel processing; results are collected back into input order
        self._metrics["parallel_tasks"] += len(queries)
        future_to_index = {self.thread_pool.submit(self._mock_search, query): i for i, query in enumerate(queries)}

        results: List[Dict[str, Any]] = [{} for _ in queries]
        for future in as_completed(future_to_index):
            index = future_to_index[future]
            try:
                results[index] = future.result(timeout=self.config.read_timeout)
            except Exception as e:
                logger.error(f"Parallel search failed: {e}")
                results[index] = {"error": str(e)}
        return results

    def _mock_search(self, query: str) -> Dict[str, Any]:
        """Mock search function for demonstration."""
        return {"query": query, "results": [{"content": f"Mock result for {query}", "score": 0.9}]}

    def _generate_cache_key(self, query: str, context: str) -> str:
        """Generate a cache key for response caching (non-cryptographic use of MD5)."""
        combined = f"{query}|{context}"
        return hashlib.md5(combined.encode()).hexdigest()

    def get_metrics(self) -> Dict[str, Any]:
        """Get optimization metrics."""
        return {
            **self._metrics,
            "response_cache_stats": self.response_cache.stats() if self.response_cache else {},
            "embedding_cache_stats": self.embedding_cache.stats() if self.embedding_cache else {},
        }

    def close(self) -> None:
        """Clean up resources."""
        if self.thread_pool:
            self.thread_pool.shutdown(wait=True)
        self.connection_pool.close_all()
        if self.response_cache:
            self.response_cache.clear()
        if self.embedding_cache:
            self.embedding_cache.clear()

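# A minimal end-to-end sketch of the optimizer: preprocess and compress once,
# reuse the returned cache_key to store the generated answer, and serve the
# cached value on the next identical request. The "generated answer" string
# stands in for whatever LLM call the surrounding pipeline makes.
def _demo_optimizer_pipeline() -> None:
    optimizer = LatencyOptimizer()
    try:
        meta = optimizer.optimize_response_generation("How does PTO accrue?", "PTO accrues monthly per policy.")
        if meta["cached_response"]:
            answer = meta["response"]
        else:
            answer = "generated answer"  # placeholder for the real LLM call
            optimizer.cache_response(meta["cache_key"], answer)
        assert optimizer.get_metrics()["cache_misses"] >= 1  # first request is always a miss
    finally:
        optimizer.close()
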
# Decorator for automatic latency instrumentation
def optimize_latency(optimizer: Optional[LatencyOptimizer] = None):
    """Decorator that times a function and lazily creates a shared optimizer."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal optimizer
            if optimizer is None:
                optimizer = LatencyOptimizer()
            start_time = time.time()
            result = func(*args, **kwargs)
            execution_time = time.time() - start_time
            logger.debug(f"Function {func.__name__} executed in {execution_time:.3f}s")
            return result

        return wrapper

    return decorator

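# A minimal sketch of the decorator: @wraps preserves the wrapped function's
# metadata, so the debug log above reports the real function name rather than
# "wrapper". The function body is illustrative.
@optimize_latency()
def _demo_decorated(question: str) -> str:
    assert _demo_decorated.__name__ == "_demo_decorated"
    return f"answer to: {question}"
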
# Utility functions for quick optimization
def create_optimized_session(base_url: str, config: Optional[LatencyConfig] = None) -> requests.Session:
    """Create an optimized requests session with pooling and retries."""
    config = config or LatencyConfig()
    pool_manager = ConnectionPoolManager(config)
    return pool_manager.get_session(base_url)


@lru_cache(maxsize=2048)
def cached_hash(text: str) -> str:
    """Memoized hash for frequently hashed texts (non-cryptographic use of MD5)."""
    return hashlib.md5(text.encode()).hexdigest()

class PerformanceMonitor:
    """Monitor and track performance improvements."""

    def __init__(self):
        self.start_time = time.time()
        self.metrics = {
            "total_requests": 0,
            "total_response_time": 0.0,
            "cache_hits": 0,
            "cache_misses": 0,
            "optimization_savings": 0.0,
        }

    def record_request(self, response_time: float, cache_hit: bool = False) -> None:
        """Record a request for performance tracking."""
        self.metrics["total_requests"] += 1
        self.metrics["total_response_time"] += response_time
        if cache_hit:
            self.metrics["cache_hits"] += 1
        else:
            self.metrics["cache_misses"] += 1

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        total_requests = self.metrics["total_requests"]
        return {
            "uptime": time.time() - self.start_time,
            "total_requests": total_requests,
            "average_response_time": (
                self.metrics["total_response_time"] / total_requests if total_requests > 0 else 0.0
            ),
            "cache_hit_rate": (self.metrics["cache_hits"] / total_requests if total_requests > 0 else 0.0),
            "optimization_savings": self.metrics["optimization_savings"],
        }

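# A minimal sketch of the monitor: record per-request timings (in seconds)
# and read back the aggregates. The timings are invented for illustration.
def _demo_performance_monitor() -> None:
    monitor = PerformanceMonitor()
    monitor.record_request(0.42, cache_hit=False)
    monitor.record_request(0.05, cache_hit=True)
    stats = monitor.get_stats()
    assert stats["total_requests"] == 2
    assert stats["cache_hit_rate"] == 0.5
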
# Global optimizer instance for shared use
_global_optimizer: Optional[LatencyOptimizer] = None


def get_global_optimizer() -> LatencyOptimizer:
    """Get or create the global optimizer instance."""
    global _global_optimizer
    if _global_optimizer is None:
        _global_optimizer = LatencyOptimizer()
    return _global_optimizer


def configure_global_optimizer(config: LatencyConfig) -> LatencyOptimizer:
    """Configure the global optimizer with specific settings."""
    global _global_optimizer
    _global_optimizer = LatencyOptimizer(config)
    return _global_optimizer

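# Run the illustrative sketches when executed directly; they are inert on
# import. _demo_pooled_session is deliberately excluded because it performs a
# real HTTP request. This runner is an assumption about how the module might
# be exercised, not part of the original pipeline.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    _demo_config()
    _demo_cache_manager()
    _demo_query_preprocessing()
    _demo_context_compression()
    _demo_optimizer_pipeline()
    _demo_performance_monitor()
    print(_demo_decorated("How many PTO days do employees get?"))
    print("All latency-optimization sketches passed.")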