# Source: zenith-backend/app/services/advanced_performance.py
# (commit d29a5a0, "fix(backend): fix port and health check robustness", by teoat)
"""
Advanced Performance Optimization Suite
Zenith Fraud Detection Platform - Enterprise-Grade Performance
"""
import asyncio
import time
import redis
import json
import hashlib
from datetime import datetime, timedelta
from typing import Optional, Any, Callable
from dataclasses import dataclass
from functools import wraps
import threading
from app.services.simplified_database import DatabaseService
from app.services.monitoring_collector import MonitoringCollector
@dataclass
class CacheEntry:
    """A cached value with expiry metadata and access bookkeeping."""

    value: Any  # cached payload (any JSON-serializable object in practice)
    expires_at: datetime  # absolute expiry time (naive UTC, compared with utcnow())
    access_count: int = 0  # number of hits recorded on this entry
    # Annotation fixed: None is a valid default sentinel, so this is Optional,
    # not a bare `datetime` as the original claimed.
    created_at: Optional[datetime] = None

    def __post_init__(self):
        # Default the creation timestamp lazily so each instance gets "now".
        if self.created_at is None:
            self.created_at = datetime.utcnow()
@dataclass
class PerformanceMetric:
    """One timing sample for a profiled operation."""

    operation: str  # logical operation name, e.g. "GET /api/x"
    duration_ms: float  # wall-clock duration in milliseconds
    timestamp: datetime  # when the sample was recorded (naive UTC)
    success: bool  # False when the profiled call raised
    # Annotation fixed: defaults to None, so this is Optional (e.g. {"error": "..."}).
    metadata: Optional[dict[str, Any]] = None
class AdvancedCacheManager:
"""Advanced multi-level caching system"""
def __init__(self):
# L1: In-memory cache (fastest)
self.l1_cache = {}
self.l1_max_size = 1000
self.l1_access_order = []
# L2: Redis cache (fast)
self.redis_client = None
self.l2_ttl = 3600 # 1 hour
# L3: Database cache (slowest)
self.db_service = None
# Cache statistics
self.cache_hits = 0
self.cache_misses = 0
self.cache_operations = 0
async def initialize(self):
"""Initialize cache system"""
try:
# Initialize Redis (optional)
self.redis_client = redis.Redis(
host="localhost",
port=6379,
db=0,
decode_responses=True,
socket_connect_timeout=5,
socket_timeout=5,
retry_on_timeout=True,
)
self.redis_client.ping() # Test connection
except Exception:
print("Redis not available, using memory-only cache")
self.redis_client = None
self.db_service = DatabaseService()
await self.db_service.initialize()
def generate_cache_key(self, prefix: str, params: dict[str, Any]) -> str:
"""Generate cache key from parameters"""
param_str = json.dumps(params, sort_keys=True)
param_hash = hashlib.md5(param_str.encode()).hexdigest()
return f"{prefix}:{param_hash}"
async def get(self, key: str) -> Optional[Any]:
"""Get value from cache (L1 -> L2 -> L3)"""
self.cache_operations += 1
# L1: In-memory cache
if key in self.l1_cache:
entry = self.l1_cache[key]
if entry.expires_at > datetime.utcnow():
entry.access_count += 1
self._update_l1_access_order(key)
self.cache_hits += 1
return entry.value
else:
# Expired, remove from L1
del self.l1_cache[key]
self.l1_access_order.remove(key)
# L2: Redis cache
if self.redis_client:
try:
cached = self.redis_client.get(key)
if cached:
data = json.loads(cached)
# Promote to L1 if recently accessed
if data.get("access_count", 0) > 2:
await self._promote_to_l1(key, data)
self.cache_hits += 1
return data["value"]
except Exception:
pass
# L3: Database cache
try:
# Query database cache table
result = await self.db_service.execute_query(
"SELECT value, expires_at FROM cache_entries WHERE key = %s AND expires_at > NOW()", (key,)
)
if result:
data = json.loads(result[0]["value"])
await self._promote_to_l1(key, data)
self.cache_hits += 1
return data["value"]
except Exception:
pass
self.cache_misses += 1
return None
async def set(self, key: str, value: Any, ttl: int = 3600):
"""Set value in all cache levels"""
expires_at = datetime.utcnow() + timedelta(seconds=ttl)
cache_entry = CacheEntry(value=value, expires_at=expires_at)
# L1: In-memory cache
await self._set_l1(key, cache_entry)
# L2: Redis cache
if self.redis_client:
try:
data = {"value": value, "expires_at": expires_at.isoformat(), "access_count": cache_entry.access_count}
self.redis_client.setex(key, ttl, json.dumps(data))
except Exception:
pass
# L3: Database cache
try:
await self.db_service.execute_insert(
"""INSERT INTO cache_entries (key, value, expires_at, created_at)
VALUES (%s, %s, %s, %s)
ON CONFLICT (key) DO UPDATE SET
value = EXCLUDED.value, expires_at = EXCLUDED.expires_at""",
(key, json.dumps({"value": value, "access_count": 0}), expires_at, datetime.utcnow()),
)
except Exception:
pass
async def _set_l1(self, key: str, entry: CacheEntry):
"""Set value in L1 cache with eviction"""
# Evict if at capacity
if len(self.l1_cache) >= self.l1_max_size:
await self._evict_lru_l1()
self.l1_cache[key] = entry
self._update_l1_access_order(key)
def _update_l1_access_order(self, key: str):
"""Update L1 cache access order"""
if key in self.l1_access_order:
self.l1_access_order.remove(key)
self.l1_access_order.append(key)
async def _evict_lru_l1(self):
"""Evict least recently used item from L1"""
if not self.l1_access_order:
return
lru_key = self.l1_access_order.pop(0)
if lru_key in self.l1_cache:
del self.l1_cache[lru_key]
async def _promote_to_l1(self, key: str, data: dict[str, Any]):
"""Promote data to L1 cache"""
expires_at = datetime.fromisoformat(data["expires_at"])
entry = CacheEntry(value=data["value"], expires_at=expires_at, access_count=data.get("access_count", 0))
await self._set_l1(key, entry)
def get_cache_stats(self) -> dict[str, Any]:
"""Get cache performance statistics"""
hit_rate = (self.cache_hits / self.cache_operations * 100) if self.cache_operations > 0 else 0
return {
"cache_hits": self.cache_hits,
"cache_misses": self.cache_misses,
"cache_operations": self.cache_operations,
"hit_rate_percent": round(hit_rate, 2),
"l1_size": len(self.l1_cache),
"l1_max_size": self.l1_max_size,
}
class QueryOptimizer:
    """Advanced SQL query optimization with a small in-memory result cache."""

    # Forward-reference annotation: avoids a hard import-time dependency on
    # the DatabaseService symbol.
    def __init__(self, db_service: "DatabaseService"):
        self.db_service = db_service
        self.query_cache = {}  # query hash -> {"result": ..., "expires_at": ...}
        self.slow_query_threshold = 500  # ms; queries above this are logged

    async def execute_optimized_query(self, query: str, params: Optional[tuple] = None) -> list[dict[str, Any]]:
        """Execute *query* with hint-based optimization and 5-minute result caching.

        Raises whatever the underlying db_service raises; failures are still
        recorded in the performance metrics before propagating.
        """
        start_time = time.time()
        # Check the result cache first
        query_hash = self._hash_query(query, params)
        cached_result = self.query_cache.get(query_hash)
        if cached_result is not None:
            if cached_result["expires_at"] > datetime.utcnow():
                return cached_result["result"]
            # BUG FIX: expired entries were previously left in the cache
            # forever, growing memory unboundedly.
            del self.query_cache[query_hash]
        # Analyze and optimize query
        optimized_query = await self._optimize_query(query)
        try:
            result = await self.db_service.execute_query(optimized_query, params)
            # Cache result if small enough (rough 10KB limit on the repr)
            if len(str(result)) < 10000:
                self.query_cache[query_hash] = {
                    "result": result,
                    "expires_at": datetime.utcnow() + timedelta(minutes=5),
                }
            duration = (time.time() - start_time) * 1000
            await self._record_query_performance(query, duration, True)
            return result
        except Exception as e:
            duration = (time.time() - start_time) * 1000
            await self._record_query_performance(query, duration, False, str(e))
            raise  # bare raise keeps the original traceback intact

    def _hash_query(self, query: str, params: Optional[tuple] = None) -> str:
        """Hash query + params for caching (md5 is fine: not a security boundary)."""
        # Parses as (query + str(params)) if params else query
        query_str = query + str(params) if params else query
        return hashlib.md5(query_str.encode()).hexdigest()

    async def _optimize_query(self, query: str) -> str:
        """Prepend hint comments to SELECT queries based on rough shape analysis.

        NOTE(review): stock PostgreSQL ignores ``/*+ ... */`` hints (they need
        the pg_hint_plan extension), so this is harmless either way -- TODO
        confirm the target database actually honors them.
        """
        query_lower = query.lower()
        if query_lower.strip().startswith("select"):
            if "join" in query_lower and "where" in query_lower:
                # Complex query with joins
                if "index(" in query_lower:
                    return query  # Already has hints
                return f"/*+ HashJoin NestLoop */ {query}"
            elif "order by" in query_lower and "limit" in query_lower:
                # Simple ordered query
                return f"/*+ IndexScan */ {query}"
        return query

    async def _record_query_performance(self, query: str, duration_ms: float, success: bool, error: str = None):
        """Send query timing to the monitoring collector; log slow queries."""
        monitoring_collector = MonitoringCollector()
        monitoring_collector.record_metric(
            "db_query_time",
            duration_ms,
            {
                "query_type": self._get_query_type(query),
                "success": success,
                "slow_query": duration_ms > self.slow_query_threshold,
            },
        )
        # Log slow queries (truncated to keep log lines bounded)
        if duration_ms > self.slow_query_threshold:
            print(f"Slow query detected: {duration_ms:.2f}ms - {query[:100]}...")

    def _get_query_type(self, query: str) -> str:
        """Categorize a SQL statement by its leading keyword."""
        query_lower = query.lower().strip()
        if query_lower.startswith("select"):
            return "SELECT"
        elif query_lower.startswith("insert"):
            return "INSERT"
        elif query_lower.startswith("update"):
            return "UPDATE"
        elif query_lower.startswith("delete"):
            return "DELETE"
        elif "join" in query_lower:
            return "JOIN"
        else:
            return "OTHER"
class ConnectionPoolManager:
    """Advanced database connection pool management.

    Pools are created lazily on first request and cached by type.
    """

    def __init__(self):
        self.pools = {}  # pool_type -> pool configuration
        self.pool_stats = {}

    async def get_pool(self, pool_type: str = "default"):
        """Return the pool for *pool_type*, creating it on first use."""
        existing = self.pools.get(pool_type)
        if existing is not None:
            return existing
        created = await self._create_pool(pool_type)
        self.pools[pool_type] = created
        return created

    async def _create_pool(self, pool_type: str):
        """Build the configuration dict for a pool of the given type."""
        known_configs = {
            # Large pool for analytics queries
            "analytics": {"min_connections": 5, "max_connections": 20, "connection_timeout": 30, "idle_timeout": 300},
            # Small pool for transaction processing
            "transaction": {"min_connections": 10, "max_connections": 15, "connection_timeout": 10, "idle_timeout": 60},
        }
        # Default pool for everything else
        fallback = {"min_connections": 2, "max_connections": 10, "connection_timeout": 20, "idle_timeout": 180}
        return known_configs.get(pool_type, fallback)
class PerformanceProfiler:
    """Real-time performance profiling with per-operation aggregates.

    Collects PerformanceMetric samples, keeps running statistics per
    operation, forwards every sample to the monitoring collector, and
    alerts when a single call exceeds one second.
    """

    def __init__(self):
        self.metrics = []  # raw PerformanceMetric samples (unbounded)
        self.function_stats = {}  # operation -> {call_count, total_duration, ...}
        self.slow_functions = set()  # operations that have ever failed (see NOTE below)
        self.monitoring_collector = MonitoringCollector()

    def profile_function(self, operation_name: str):
        """Decorator factory that times sync or async callables.

        BUG FIX: the original wrappers re-raised inside ``except`` before the
        recording code ran, so failing calls were never measured; both
        outcomes are recorded now. The sync wrapper also no longer calls
        ``asyncio.create_task`` unconditionally, which raised RuntimeError
        whenever no event loop was running.
        """

        def decorator(func: Callable):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                start_time = time.time()
                try:
                    result = await func(*args, **kwargs)
                except Exception as e:
                    duration = (time.time() - start_time) * 1000
                    await self.record_function_performance(operation_name, duration, False, str(e))
                    raise
                duration = (time.time() - start_time) * 1000
                await self.record_function_performance(operation_name, duration, True, None)
                return result

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                start_time = time.time()
                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    duration = (time.time() - start_time) * 1000
                    self._schedule_recording(operation_name, duration, False, str(e))
                    raise
                duration = (time.time() - start_time) * 1000
                self._schedule_recording(operation_name, duration, True, None)
                return result

            # asyncio is imported at module level; no local import needed
            if asyncio.iscoroutinefunction(func):
                return async_wrapper
            return sync_wrapper

        return decorator

    def _schedule_recording(self, operation: str, duration_ms: float, success: bool, error: str = None):
        """Record from synchronous code without assuming a running event loop."""
        coro = self.record_function_performance(operation, duration_ms, success, error)
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Plain sync call stack: run the recording to completion.
            asyncio.run(coro)
        else:
            loop.create_task(coro)

    async def record_function_performance(self, operation: str, duration_ms: float, success: bool, error: str = None):
        """Record one sample: append metric, update aggregates, emit to monitoring."""
        metric = PerformanceMetric(
            operation=operation,
            duration_ms=duration_ms,
            timestamp=datetime.utcnow(),
            success=success,
            metadata={"error": error} if error else None,
        )
        self.metrics.append(metric)
        # Update per-operation aggregate statistics
        if operation not in self.function_stats:
            self.function_stats[operation] = {
                "call_count": 0,
                "total_duration": 0,
                "min_duration": float("inf"),
                "max_duration": 0,
                "errors": 0,
            }
        stats = self.function_stats[operation]
        stats["call_count"] += 1
        stats["total_duration"] += duration_ms
        stats["min_duration"] = min(stats["min_duration"], duration_ms)
        stats["max_duration"] = max(stats["max_duration"], duration_ms)
        if not success:
            stats["errors"] += 1
            # NOTE(review): the name is historic -- this set really tracks
            # "operations that have ever errored", not slow ones. Kept as-is
            # because get_performance_report exposes it under "slow_functions".
            self.slow_functions.add(operation)
        # Send to monitoring
        self.monitoring_collector.record_metric(
            "function_execution_time", duration_ms, {"operation": operation, "success": success}
        )
        # Alert on individual slow executions (> 1 second)
        if duration_ms > 1000:
            await self._alert_slow_function(operation, duration_ms, error)

    async def _alert_slow_function(self, operation: str, duration_ms: float, error: str):
        """Emit an alert for a single execution that exceeded the threshold."""
        print(f"ALERT: Slow function detected - {operation}: {duration_ms:.2f}ms")
        # Record in monitoring system
        self.monitoring_collector.record_security_event(
            "PERFORMANCE_ISSUE",
            {
                "type": "slow_function",
                "operation": operation,
                "duration_ms": duration_ms,
                "error": error,
                "timestamp": datetime.utcnow().isoformat(),
            },
        )

    def get_performance_report(self) -> dict[str, Any]:
        """Generate a summary report across all profiled operations."""
        if not self.function_stats:
            return {"message": "No performance data available"}
        # Overall totals
        total_operations = sum(stats["call_count"] for stats in self.function_stats.values())
        total_errors = sum(stats["errors"] for stats in self.function_stats.values())
        # Rank operations by average duration, slowest first
        slow_operations = []
        for operation, stats in self.function_stats.items():
            avg_duration = stats["total_duration"] / stats["call_count"]
            slow_operations.append(
                {
                    "operation": operation,
                    "avg_duration": avg_duration,
                    "max_duration": stats["max_duration"],
                    "call_count": stats["call_count"],
                    "error_rate": (stats["errors"] / stats["call_count"]) * 100,
                }
            )
        slow_operations.sort(key=lambda x: x["avg_duration"], reverse=True)
        return {
            "total_operations": total_operations,
            "total_errors": total_errors,
            "error_rate": (total_errors / total_operations) * 100 if total_operations > 0 else 0,
            "slow_functions": list(self.slow_functions),
            "top_slow_operations": slow_operations[:10],
            "function_count": len(self.function_stats),
            "generated_at": datetime.utcnow().isoformat(),
        }
class AsyncBatchProcessor:
    """Batch processing for improved performance.

    Items are accumulated per processor name and flushed when the batch
    fills up or the flush interval has elapsed. Subclasses override
    _process_batch with the actual work.
    """

    def __init__(self, batch_size: int = 100, flush_interval: float = 5.0):
        self.batch_size = batch_size
        self.flush_interval = flush_interval  # seconds between time-based flushes
        self.pending_items = []  # kept for backward compatibility; unused
        self.processors = {}  # processor name -> list of pending items
        self.last_flush = time.time()
        self.lock = threading.Lock()  # guards processors/last_flush across threads

    def add_item(self, processor_name: str, item: Any):
        """Add item to batch; trigger a flush when full or the interval elapsed.

        BUG FIX: the flush trigger is now evaluated inside the lock but
        executed outside it (the flush path re-acquires the non-reentrant
        lock), and ``asyncio.create_task`` is only used when an event loop
        is actually running -- the original raised RuntimeError from plain
        synchronous callers.
        """
        with self.lock:
            self.processors.setdefault(processor_name, []).append(item)
            should_flush = (
                len(self.processors[processor_name]) >= self.batch_size
                or time.time() - self.last_flush > self.flush_interval
            )
        if should_flush:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running loop: flush synchronously to completion.
                asyncio.run(self._flush_processor(processor_name))
            else:
                loop.create_task(self._flush_processor(processor_name))

    async def _flush_processor(self, processor_name: str):
        """Drain and process all pending items for one processor."""
        with self.lock:
            items = self.processors.get(processor_name, [])
            if not items:
                return
            self.processors[processor_name] = []
            # BUG FIX: last_flush was never updated, so once the first
            # interval elapsed every add_item spawned another flush task.
            self.last_flush = time.time()
        await self._process_batch(processor_name, items)

    async def _process_batch(self, processor_name: str, items: list[Any]):
        """Process batch of items; override in specific implementations."""
        print(f"Processing batch of {len(items)} items for {processor_name}")

    async def flush_all(self):
        """Flush all pending items across every processor."""
        for processor_name in list(self.processors.keys()):
            await self._flush_processor(processor_name)
# Global performance optimization instances (module-level singletons).
cache_manager = AdvancedCacheManager()
# Created by initialize_performance_system() once the DB service exists.
query_optimizer = None
performance_profiler = PerformanceProfiler()
batch_processor = AsyncBatchProcessor()
async def initialize_performance_system():
    """Initialize all performance optimization systems (cache, then optimizer)."""
    global query_optimizer
    await cache_manager.initialize()
    # The optimizer needs the cache manager's DB service, so it is built second.
    query_optimizer = QueryOptimizer(cache_manager.db_service)
    print("Advanced performance optimization system initialized")
async def get_cache_manager() -> AdvancedCacheManager:
    """Accessor for the module-level cache manager singleton."""
    return cache_manager
async def get_query_optimizer() -> QueryOptimizer:
    """Accessor for the module-level query optimizer (None until initialized)."""
    return query_optimizer
def get_performance_profiler() -> PerformanceProfiler:
    """Accessor for the module-level performance profiler singleton."""
    return performance_profiler
def get_batch_processor() -> AsyncBatchProcessor:
    """Accessor for the module-level batch processor singleton."""
    return batch_processor
# Performance monitoring middleware
class PerformanceMonitoringMiddleware:
    """ASGI middleware that records per-request timing via the profiler."""

    def __init__(self, app):
        self.app = app
        self.profiler = get_performance_profiler()

    async def __call__(self, scope, receive, send):
        # Pass through non-HTTP events (lifespan, websocket) untouched.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        start_time = time.time()
        method = scope["method"]
        path = scope["path"]
        success = True
        error = None
        try:
            await self.app(scope, receive, send)
        except Exception as e:
            success = False
            error = str(e)
            raise
        finally:
            # BUG FIX: the original re-raised before the recording lines ran,
            # so failed requests were never measured; finally covers both paths.
            duration = (time.time() - start_time) * 1000
            await self.profiler.record_function_performance(f"{method} {path}", duration, success, error)