#!/usr/bin/env python3
"""
Latency Optimization Framework
Latency-reduction optimizations for the RAG pipeline, including:
- Response caching with TTL
- Connection pooling for API calls
- Query preprocessing and deduplication
- Parallel processing where possible
- Embedding caching
- Context compression
"""
import hashlib
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import lru_cache, wraps
from typing import Any, Dict, List, Optional, Tuple
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
logger = logging.getLogger(__name__)
@dataclass
class LatencyConfig:
"""Configuration for latency optimizations."""
# Caching configuration
enable_response_cache: bool = True
response_cache_ttl: int = 300 # 5 minutes
response_cache_size: int = 100
enable_embedding_cache: bool = True
embedding_cache_size: int = 500
enable_query_cache: bool = True
query_cache_size: int = 200
# Connection pooling
enable_connection_pooling: bool = True
pool_size: int = 10
pool_maxsize: int = 20
pool_block: bool = False
# Request optimization
connection_timeout: float = 5.0
read_timeout: float = 15.0
max_retries: int = 3
backoff_factor: float = 0.3
# Parallel processing
enable_parallel_processing: bool = True
max_workers: int = 4
# Context optimization
enable_context_compression: bool = True
max_context_tokens: int = 2000
compression_ratio: float = 0.7
# Query preprocessing
enable_query_preprocessing: bool = True
min_query_length: int = 3
max_query_length: int = 500
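
# Example (illustrative): the defaults above target interactive RAG traffic; a
# deployment favoring freshness over cache hits might shorten the TTL and
# widen the thread pool.
#
#     config = LatencyConfig(response_cache_ttl=60, max_workers=8)
#     optimizer = LatencyOptimizer(config)
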
class CacheManager:
    """Thread-safe cache manager with TTL support."""

    def __init__(self, max_size: int = 100, default_ttl: int = 300):
        self.max_size = max_size
        self.default_ttl = default_ttl
        self._cache: Dict[str, Dict[str, Any]] = {}
        self._access_times: Dict[str, float] = {}
        self._lock = threading.RLock()  # guards _cache and _access_times

    def _cleanup_expired(self) -> None:
        """Remove expired cache entries. Caller must hold the lock."""
        current_time = time.time()
        expired_keys = [key for key, data in self._cache.items() if current_time > data.get("expires_at", 0)]
        for key in expired_keys:
            self._cache.pop(key, None)
            self._access_times.pop(key, None)

    def _evict_lru(self) -> None:
        """Evict least recently used items if cache is full. Caller must hold the lock."""
        while len(self._cache) >= self.max_size:
            if not self._access_times:
                break
            # Find LRU item
            lru_key = min(self._access_times, key=self._access_times.get)
            self._cache.pop(lru_key, None)
            self._access_times.pop(lru_key, None)

    def get(self, key: str) -> Optional[Any]:
        """Get item from cache, or None if missing or expired."""
        with self._lock:
            self._cleanup_expired()
            if key in self._cache:
                current_time = time.time()
                data = self._cache[key]
                if current_time <= data.get("expires_at", 0):
                    self._access_times[key] = current_time
                    return data["value"]
                # Expired item
                self._cache.pop(key, None)
                self._access_times.pop(key, None)
            return None

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Set item in cache with TTL."""
        with self._lock:
            self._cleanup_expired()
            self._evict_lru()
            expires_at = time.time() + (ttl or self.default_ttl)
            self._cache[key] = {"value": value, "expires_at": expires_at}
            self._access_times[key] = time.time()

    def clear(self) -> None:
        """Clear all cache entries."""
        with self._lock:
            self._cache.clear()
            self._access_times.clear()

    def stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        with self._lock:
            self._cleanup_expired()
            return {
                "size": len(self._cache),
                "max_size": self.max_size,
                "hit_ratio": 0.0,  # Would need to track hits/misses
                "default_ttl": self.default_ttl,
            }
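
# Usage sketch (illustrative): entries expire after their TTL, and the least
# recently used entry is evicted once max_size is reached.
#
#     cache = CacheManager(max_size=2, default_ttl=300)
#     cache.set("a", 1)
#     cache.set("b", 2)
#     cache.set("c", 3)            # evicts "a", the least recently used
#     assert cache.get("a") is None
#     cache.set("d", 4, ttl=1)     # per-entry TTL override
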
class ConnectionPoolManager:
"""HTTP connection pool manager for optimized API calls."""
def __init__(self, config: LatencyConfig):
self.config = config
self._sessions: Dict[str, requests.Session] = {}
def get_session(self, base_url: str) -> requests.Session:
"""Get or create a session for the given base URL."""
if base_url not in self._sessions:
session = requests.Session()
if self.config.enable_connection_pooling:
# Configure retry strategy
retry_strategy = Retry(
total=self.config.max_retries,
status_forcelist=[429, 500, 502, 503, 504],
                    allowed_methods=["HEAD", "GET", "POST"],  # urllib3 >= 1.26; formerly method_whitelist
backoff_factor=self.config.backoff_factor,
)
# Configure adapter with connection pooling
adapter = HTTPAdapter(
pool_connections=self.config.pool_size,
pool_maxsize=self.config.pool_maxsize,
pool_block=self.config.pool_block,
max_retries=retry_strategy,
)
session.mount("http://", adapter)
session.mount("https://", adapter)
self._sessions[base_url] = session
return self._sessions[base_url]
def close_all(self) -> None:
"""Close all sessions."""
for session in self._sessions.values():
session.close()
self._sessions.clear()
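
# Usage sketch (illustrative): one session is created per base URL and then
# reused, so repeated calls share pooled connections and the retry policy.
#
#     pool = ConnectionPoolManager(LatencyConfig())
#     session = pool.get_session("https://api.example.com")  # hypothetical URL
#     session.get("https://api.example.com/health", timeout=(5.0, 15.0))
#     pool.close_all()
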
class QueryPreprocessor:
"""Query preprocessing for latency optimization."""
def __init__(self, config: LatencyConfig):
self.config = config
        self._query_cache = CacheManager(
            max_size=config.query_cache_size,
            default_ttl=600,  # 10 minutes for preprocessed queries
        )
def preprocess_query(self, query: str) -> Tuple[str, Dict[str, Any]]:
"""
Preprocess query for optimization.
Returns:
Tuple of (processed_query, metadata)
"""
if not self.config.enable_query_preprocessing:
return query, {}
# Check cache first
query_hash = self._hash_query(query)
cached = self._query_cache.get(query_hash)
if cached:
return cached["processed_query"], cached["metadata"]
# Preprocess query
processed_query = self._clean_query(query)
metadata = {
"original_length": len(query),
"processed_length": len(processed_query),
"hash": query_hash,
"timestamp": time.time(),
}
# Cache result
self._query_cache.set(query_hash, {"processed_query": processed_query, "metadata": metadata})
return processed_query, metadata
def _clean_query(self, query: str) -> str:
"""Clean and normalize query."""
# Basic cleaning
cleaned = query.strip()
        # Length validation: queries below the minimum are returned as-is
        if len(cleaned) < self.config.min_query_length:
            return cleaned
if len(cleaned) > self.config.max_query_length:
cleaned = cleaned[: self.config.max_query_length]
# Remove excessive whitespace
cleaned = " ".join(cleaned.split())
# Basic normalization
cleaned = cleaned.lower()
return cleaned
def _hash_query(self, query: str) -> str:
"""Generate hash for query caching."""
return hashlib.md5(query.encode()).hexdigest()
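
# Usage sketch (illustrative): preprocessing strips, collapses whitespace, and
# lowercases; results are cached by MD5 hash so repeat queries skip the work.
#
#     pre = QueryPreprocessor(LatencyConfig())
#     query, meta = pre.preprocess_query("  How many  PTO days? ")
#     # query == "how many pto days?"; meta records original/processed lengths
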
class ContextCompressor:
"""Context compression for reduced token usage and faster processing."""
def __init__(self, config: LatencyConfig):
self.config = config
def compress_context(self, context: str, target_length: Optional[int] = None) -> str:
"""
Compress context while preserving important information.
Args:
context: Original context string
target_length: Target length in characters (uses config default if None)
Returns:
Compressed context string
"""
if not self.config.enable_context_compression:
return context
        # NOTE: max_context_tokens is treated as a character budget here; a
        # true token budget would require the model's tokenizer.
        target_length = target_length or self.config.max_context_tokens
if len(context) <= target_length:
return context
# Simple compression strategies
compressed = self._extract_key_sentences(context, target_length)
logger.debug(f"Context compressed from {len(context)} to {len(compressed)} chars")
return compressed
def _extract_key_sentences(self, text: str, target_length: int) -> str:
"""Extract key sentences that fit within target length."""
sentences = text.split(".")
# Prioritize sentences with key policy terms
key_terms = [
"policy",
"accrual",
"eligibility",
"days",
"hours",
"employee",
"vacation",
"pto",
"sick",
"leave",
]
# Score sentences by key terms
scored_sentences = []
for sentence in sentences:
sentence = sentence.strip()
if len(sentence) < 10: # Skip very short sentences
continue
score = sum(1 for term in key_terms if term.lower() in sentence.lower())
scored_sentences.append((score, sentence))
        # Sort by score, descending; sort on the score alone so ties do not
        # fall back to comparing sentence text
        scored_sentences.sort(key=lambda item: item[0], reverse=True)
# Build compressed context
compressed_parts = []
current_length = 0
for score, sentence in scored_sentences:
sentence_length = len(sentence) + 2 # +2 for '. '
if current_length + sentence_length <= target_length:
compressed_parts.append(sentence)
current_length += sentence_length
else:
break
return ". ".join(compressed_parts) + "." if compressed_parts else text[:target_length]
class LatencyOptimizer:
"""Main latency optimization coordinator."""
def __init__(self, config: Optional[LatencyConfig] = None):
self.config = config or LatencyConfig()
# Initialize components
self.response_cache = (
CacheManager(max_size=self.config.response_cache_size, default_ttl=self.config.response_cache_ttl)
if self.config.enable_response_cache
else None
)
self.embedding_cache = (
CacheManager(max_size=self.config.embedding_cache_size, default_ttl=1800) # 30 minutes for embeddings
if self.config.enable_embedding_cache
else None
)
self.connection_pool = ConnectionPoolManager(self.config)
self.query_preprocessor = QueryPreprocessor(self.config)
self.context_compressor = ContextCompressor(self.config)
# Thread pool for parallel processing
self.thread_pool = (
ThreadPoolExecutor(max_workers=self.config.max_workers) if self.config.enable_parallel_processing else None
)
self._metrics = {"cache_hits": 0, "cache_misses": 0, "parallel_tasks": 0, "compression_savings": 0}
logger.info("LatencyOptimizer initialized with optimizations enabled")
def optimize_response_generation(self, query: str, context: str) -> Dict[str, Any]:
"""
Optimize the complete response generation pipeline.
Args:
query: User query
context: Retrieved context
Returns:
Optimization metadata and processed inputs
"""
start_time = time.time()
# Preprocess query
processed_query, query_metadata = self.query_preprocessor.preprocess_query(query)
# Compress context if needed
original_context_length = len(context)
compressed_context = self.context_compressor.compress_context(context)
compression_savings = original_context_length - len(compressed_context)
if compression_savings > 0:
self._metrics["compression_savings"] += compression_savings
# Check response cache
cache_key = self._generate_cache_key(processed_query, compressed_context)
cached_response = None
if self.response_cache:
cached_response = self.response_cache.get(cache_key)
if cached_response:
self._metrics["cache_hits"] += 1
logger.debug(f"Response cache hit for query: {processed_query[:50]}...")
else:
self._metrics["cache_misses"] += 1
optimization_metadata = {
"processing_time": time.time() - start_time,
"query_metadata": query_metadata,
"context_compression": {
"original_length": original_context_length,
"compressed_length": len(compressed_context),
"savings": compression_savings,
},
"cache_key": cache_key,
"cached_response": cached_response is not None,
"processed_query": processed_query,
"compressed_context": compressed_context,
}
return optimization_metadata
def cache_response(self, cache_key: str, response: Any) -> None:
"""Cache a response for future use."""
if self.response_cache:
self.response_cache.set(cache_key, response)
def optimize_embedding_generation(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]:
"""
Optimize embedding generation with caching and parallel processing.
Args:
texts: List of texts to embed
Returns:
Tuple of (embeddings, optimization_metadata)
"""
if not texts:
return [], {"cache_hits": 0, "cache_misses": 0}
embeddings = []
cache_hits = 0
cache_misses = 0
        if self.embedding_cache:
            # Check cache for each text
            cached_embeddings = {}
            uncached_texts = []
            for i, text in enumerate(texts):
                text_hash = hashlib.md5(text.encode()).hexdigest()
                cached = self.embedding_cache.get(text_hash)
                if cached:
                    cached_embeddings[i] = cached
                    cache_hits += 1
                else:
                    uncached_texts.append((i, text, text_hash))
                    cache_misses += 1
            # Generate embeddings for uncached texts; this is a placeholder,
            # and an actual implementation would call the embedding service
            for i, text, text_hash in uncached_texts:
                embedding = [0.0] * 1024  # Placeholder embedding
                cached_embeddings[i] = embedding
                # Cache the embedding
                self.embedding_cache.set(text_hash, embedding)
            # Reconstruct embeddings in original order
            embeddings = [cached_embeddings[i] for i in range(len(texts))]
        else:
            # Cache disabled: still return one placeholder vector per text
            embeddings = [[0.0] * 1024 for _ in texts]
            cache_misses = len(texts)
optimization_metadata = {"cache_hits": cache_hits, "cache_misses": cache_misses, "total_texts": len(texts)}
self._metrics["cache_hits"] += cache_hits
self._metrics["cache_misses"] += cache_misses
return embeddings, optimization_metadata
def optimize_parallel_search(self, queries: List[str]) -> List[Dict[str, Any]]:
"""
Optimize parallel search processing.
Args:
queries: List of search queries
Returns:
List of search results
"""
if not self.config.enable_parallel_processing or not self.thread_pool:
# Sequential processing fallback
return [self._mock_search(query) for query in queries]
# Parallel processing
self._metrics["parallel_tasks"] += len(queries)
future_to_query = {self.thread_pool.submit(self._mock_search, query): query for query in queries}
        results = []
        # as_completed yields futures in completion order, so results may not
        # align with the input query order
        for future in as_completed(future_to_query):
try:
result = future.result(timeout=self.config.read_timeout)
results.append(result)
except Exception as e:
logger.error(f"Parallel search failed: {e}")
results.append({"error": str(e)})
return results
def _mock_search(self, query: str) -> Dict[str, Any]:
"""Mock search function for demonstration."""
return {"query": query, "results": [{"content": f"Mock result for {query}", "score": 0.9}]}
def _generate_cache_key(self, query: str, context: str) -> str:
"""Generate cache key for response caching."""
combined = f"{query}|{context}"
return hashlib.md5(combined.encode()).hexdigest()
def get_metrics(self) -> Dict[str, Any]:
"""Get optimization metrics."""
return {
**self._metrics,
"response_cache_stats": self.response_cache.stats() if self.response_cache else {},
"embedding_cache_stats": self.embedding_cache.stats() if self.embedding_cache else {},
}
def close(self) -> None:
"""Clean up resources."""
if self.thread_pool:
self.thread_pool.shutdown(wait=True)
self.connection_pool.close_all()
if self.response_cache:
self.response_cache.clear()
if self.embedding_cache:
self.embedding_cache.clear()
# Decorator for automatic latency optimization
def optimize_latency(optimizer: Optional[LatencyOptimizer] = None):
"""Decorator to automatically optimize function latency."""
def decorator(func):
@wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal optimizer
            if optimizer is None:
                # Lazily create a shared optimizer; only wall-clock timing is
                # recorded for now
                optimizer = LatencyOptimizer()
start_time = time.time()
result = func(*args, **kwargs)
execution_time = time.time() - start_time
logger.debug(f"Function {func.__name__} executed in {execution_time:.3f}s")
return result
return wrapper
return decorator
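
# Usage sketch (illustrative): wrap any callable to log its wall-clock time;
# a shared LatencyOptimizer is created lazily on the first call.
#
#     @optimize_latency()
#     def answer_question(query: str) -> str:
#         ...
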
# Utility functions for quick optimization
def create_optimized_session(base_url: str, config: Optional[LatencyConfig] = None) -> requests.Session:
    """Create an optimized requests session.

    Each call builds a fresh session; the caller owns it and should close it
    when finished.
    """
    config = config or LatencyConfig()
    pool_manager = ConnectionPoolManager(config)
    return pool_manager.get_session(base_url)
@lru_cache(maxsize=128)
def cached_hash(text: str) -> str:
"""Cached hash function for frequently used texts."""
return hashlib.md5(text.encode()).hexdigest()
class PerformanceMonitor:
"""Monitor and track performance improvements."""
def __init__(self):
self.start_time = time.time()
self.metrics = {
"total_requests": 0,
"total_response_time": 0.0,
"cache_hits": 0,
"cache_misses": 0,
"optimization_savings": 0.0,
}
def record_request(self, response_time: float, cache_hit: bool = False):
"""Record a request for performance tracking."""
self.metrics["total_requests"] += 1
self.metrics["total_response_time"] += response_time
if cache_hit:
self.metrics["cache_hits"] += 1
else:
self.metrics["cache_misses"] += 1
def get_stats(self) -> Dict[str, Any]:
"""Get performance statistics."""
total_requests = self.metrics["total_requests"]
return {
"uptime": time.time() - self.start_time,
"total_requests": total_requests,
"average_response_time": (
self.metrics["total_response_time"] / total_requests if total_requests > 0 else 0.0
),
"cache_hit_rate": (self.metrics["cache_hits"] / total_requests if total_requests > 0 else 0.0),
"optimization_savings": self.metrics["optimization_savings"],
}
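
# Usage sketch (illustrative):
#
#     monitor = PerformanceMonitor()
#     monitor.record_request(0.42, cache_hit=True)
#     monitor.get_stats()  # uptime, average_response_time, cache_hit_rate, ...
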
# Global optimizer instance for shared use
_global_optimizer: Optional[LatencyOptimizer] = None
def get_global_optimizer() -> LatencyOptimizer:
"""Get or create global optimizer instance."""
global _global_optimizer
if _global_optimizer is None:
_global_optimizer = LatencyOptimizer()
return _global_optimizer
def configure_global_optimizer(config: LatencyConfig) -> LatencyOptimizer:
"""Configure global optimizer with specific settings."""
global _global_optimizer
_global_optimizer = LatencyOptimizer(config)
return _global_optimizer
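

if __name__ == "__main__":
    # Minimal smoke test (illustrative): exercises the optimizer end to end
    # with placeholder data; a real deployment would wire in actual retrieval
    # and embedding services.
    logging.basicConfig(level=logging.DEBUG)
    optimizer = LatencyOptimizer()
    meta = optimizer.optimize_response_generation(
        query="How many PTO days do employees accrue per year?",
        context="Employees accrue 1.5 PTO days per month. Accrual begins after 90 days of employment.",
    )
    optimizer.cache_response(meta["cache_key"], {"answer": "placeholder"})
    embeddings, embed_meta = optimizer.optimize_embedding_generation(["pto policy", "sick leave policy"])
    print("metrics:", optimizer.get_metrics())
    optimizer.close()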