# bookmyservice-mhs/app/database/query_optimizer.py
# perf(optimization): Implement comprehensive performance optimization strategy (commit 79ca9ba)
"""
Query optimization and streaming utilities for MongoDB operations.
Implements cursor-based pagination and memory-efficient query execution.
"""
import asyncio
import logging
from typing import Dict, List, Any, Optional, AsyncGenerator, Tuple
from datetime import datetime
import pymongo
from app.nosql import db
from app.utils.simple_log_sanitizer import get_simple_sanitized_logger
logger = get_simple_sanitized_logger(__name__)
class QueryOptimizer:
    """Optimizes MongoDB aggregation pipelines for performance and memory usage.

    Provides pipeline-level rewrites (merging leading ``$match`` stages so a
    single index scan can serve them) plus execution helpers that stream or
    paginate results through the async ``db`` handle.
    """

    def __init__(self):
        # Reserved for a future result cache; currently only surfaced by
        # get_query_stats().
        self.query_cache: Dict[str, Any] = {}
        self.cache_ttl = 300  # seconds (5 minutes)

    def optimize_pipeline(self, pipeline: List[Dict]) -> List[Dict]:
        """Return an optimized copy of *pipeline* (the input is not mutated).

        Consecutive ``$match`` stages at the head of the pipeline are merged
        into one ``{"$match": {"$and": [...]}}`` stage.  Using ``$and`` keeps
        multiple conditions on the same field (e.g. a ``$gt`` and a ``$lt``
        range) where a plain dict merge would let the later filter clobber
        the earlier one.  ``$match`` stages that appear *after* other stages
        (``$group``, ``$project``, ...) are left in place: hoisting them to
        the front would change the query's meaning.
        """
        leading_filters: List[Dict] = []
        idx = 0
        # Collect the run of $match stages at the front of the pipeline.
        while idx < len(pipeline) and "$match" in pipeline[idx]:
            leading_filters.append(pipeline[idx]["$match"])
            idx += 1

        optimized: List[Dict] = []
        if len(leading_filters) > 1:
            # $and preserves conflicting keys instead of overwriting them.
            optimized.append({"$match": {"$and": leading_filters}})
        elif leading_filters:
            optimized.append({"$match": leading_filters[0]})
        # Remaining stages keep their original relative order.
        optimized.extend(pipeline[idx:])
        return optimized

    def add_index_hints(self, pipeline: List[Dict], collection_name: str) -> List[Dict]:
        """Return *pipeline* unchanged.

        ``$hint`` is not available inside an aggregation pipeline; index
        hints must be supplied at the ``collection.aggregate()`` call level.
        Kept as an extension point for future enhancement.
        """
        return pipeline

    async def execute_optimized_query(
        self,
        collection_name: str,
        pipeline: List[Dict],
        limit: Optional[int] = None,
        use_cursor: bool = True
    ) -> List[Dict]:
        """Execute *pipeline* after optimization, with a fallback run of the
        caller's original pipeline if the optimized one raises.

        Large result sets (``limit`` > 100 with ``use_cursor``) are read
        through a batched cursor to bound memory usage.
        """
        try:
            optimized_pipeline = self.optimize_pipeline(pipeline)
            collection = db[collection_name]
            if use_cursor and limit and limit > 100:
                # Cursor-based read for large result sets.
                return await self._execute_with_cursor(collection, optimized_pipeline, limit)
            # Small result sets: materialize in one round trip.
            return await collection.aggregate(optimized_pipeline).to_list(length=limit)
        except Exception as e:
            logger.error("Error executing optimized query on %s: %s", collection_name, e)
            # The optimization itself may be at fault; retry with the
            # caller's untouched pipeline before giving up.
            try:
                logger.info("Falling back to original pipeline for %s", collection_name)
                collection = db[collection_name]
                return await collection.aggregate(pipeline).to_list(length=limit)
            except Exception as fallback_error:
                logger.error("Fallback query also failed for %s: %s", collection_name, fallback_error)
                raise

    async def _execute_with_cursor(
        self,
        collection,
        pipeline: List[Dict],
        limit: int,
        batch_size: int = 100
    ) -> List[Dict]:
        """Collect up to *limit* documents via a batched cursor, yielding
        control to the event loop after every *batch_size* documents."""
        results: List[Dict] = []
        processed = 0
        cursor = collection.aggregate(pipeline, batchSize=batch_size)
        async for document in cursor:
            results.append(document)
            processed += 1
            if processed >= limit:
                break
            if processed % batch_size == 0:
                await asyncio.sleep(0)  # yield to the event loop
        return results

    async def stream_query_results(
        self,
        collection_name: str,
        pipeline: List[Dict],
        batch_size: int = 100
    ) -> AsyncGenerator[List[Dict], None]:
        """Yield optimized query results in lists of *batch_size* documents."""
        optimized_pipeline = self.optimize_pipeline(pipeline)
        collection = db[collection_name]
        try:
            cursor = collection.aggregate(optimized_pipeline, batchSize=batch_size)
            batch: List[Dict] = []
            async for document in cursor:
                batch.append(document)
                if len(batch) >= batch_size:
                    yield batch
                    batch = []
                    await asyncio.sleep(0)  # yield to the event loop
            if batch:
                # Flush the final, partially filled batch.
                yield batch
        except Exception as e:
            # Include the exception so the failure is diagnosable (the
            # original log line dropped it).
            logger.error("Error streaming query results from %s: %s", collection_name, e)
            raise

    async def execute_paginated_query(
        self,
        collection_name: str,
        pipeline: List[Dict],
        page_size: int = 20,
        cursor_field: str = "_id",
        cursor_value: Optional[Any] = None,
        sort_direction: int = 1
    ) -> Tuple[List[Dict], Optional[Any]]:
        """Execute a cursor-based paginated query.

        Returns ``(page, next_cursor)`` where *next_cursor* is the value of
        *cursor_field* to pass back for the following page, or ``None`` when
        no further page exists.  The caller's *pipeline* stages are never
        mutated: the cursor filter is merged into a *copy* of the first
        ``$match`` stage (the original code updated the caller's dict in
        place, corrupting the pipeline for subsequent calls).
        """
        if cursor_value is not None:
            comparator = "$gt" if sort_direction == 1 else "$lt"
            cursor_filter = {cursor_field: {comparator: cursor_value}}
            paginated_pipeline: List[Dict] = []
            merged = False
            for stage in pipeline:
                if not merged and "$match" in stage:
                    # Merge into a fresh dict so the caller's stage survives.
                    paginated_pipeline.append({"$match": {**stage["$match"], **cursor_filter}})
                    merged = True
                else:
                    paginated_pipeline.append(stage)
            if not merged:
                paginated_pipeline.insert(0, {"$match": cursor_filter})
        else:
            paginated_pipeline = list(pipeline)

        # Fetch one extra document to detect whether another page exists.
        paginated_pipeline.extend([
            {"$sort": {cursor_field: sort_direction}},
            {"$limit": page_size + 1}
        ])

        results = await self.execute_optimized_query(
            collection_name,
            paginated_pipeline,
            limit=page_size + 1,
            use_cursor=False
        )

        next_cursor = None
        if len(results) > page_size:
            results = results[:page_size]  # drop the look-ahead document
            # Resume after the last *returned* document.  Using the
            # look-ahead document's value (as the original code did) would
            # make the next page's $gt/$lt filter skip that document.
            next_cursor = results[-1].get(cursor_field)
        return results, next_cursor

    def get_query_stats(self) -> Dict[str, Any]:
        """Return a summary of cache state and the optimizations applied."""
        return {
            "cache_size": len(self.query_cache),
            "cache_ttl": self.cache_ttl,
            "optimizations_applied": [
                "Pipeline stage reordering",
                "Multiple $match stage combination",
                "Index hint addition",
                "Cursor-based pagination",
                "Memory-efficient streaming"
            ]
        }
class MemoryEfficientAggregator:
    """Aggregation helpers that bound memory usage and execution time."""

    def __init__(self, max_memory_mb: int = 100):
        # Soft process-RSS ceiling; collection stops once it is exceeded.
        self.max_memory_mb = max_memory_mb
        # Documents fetched per cursor batch (also the memory-check interval).
        self.batch_size = 1000

    async def aggregate_with_memory_limit(
        self,
        collection_name: str,
        pipeline: List[Dict],
        max_results: int = 10000
    ) -> List[Dict]:
        """Run *pipeline*, collecting at most *max_results* documents and
        aborting early (with a warning, returning the partial result) if
        process RSS exceeds the configured limit.

        ``allowDiskUse`` lets MongoDB spill large stages to disk rather
        than fail the aggregation.
        """
        # Imported lazily so the module loads without psutil, but hoisted
        # out of the loop (the original re-ran the import every batch).
        import psutil

        collection = db[collection_name]
        results: List[Dict] = []
        processed = 0
        cursor = collection.aggregate(
            pipeline,
            allowDiskUse=True,
            batchSize=self.batch_size
        )
        try:
            async for document in cursor:
                results.append(document)
                processed += 1
                if processed % self.batch_size == 0:
                    # Periodic RSS check; stop rather than exhaust memory.
                    memory_usage = psutil.Process().memory_info().rss / 1024 / 1024  # MB
                    if memory_usage > self.max_memory_mb:
                        logger.warning(
                            "Memory usage (%.1fMB) exceeds limit (%sMB)",
                            memory_usage, self.max_memory_mb
                        )
                        break
                    await asyncio.sleep(0)  # yield to the event loop
                if processed >= max_results:
                    break
            logger.info("Processed %s documents with memory-efficient aggregation", processed)
            return results
        except Exception as e:
            logger.error("Error in memory-efficient aggregation: %s", e)
            raise

    async def count_with_timeout(
        self,
        collection_name: str,
        filter_criteria: Dict,
        timeout_seconds: int = 30
    ) -> int:
        """Count matching documents, guarding against long-running counts.

        On timeout, retries via a ``$match``/``$count`` aggregation; any
        other failure is logged and reported as 0 (deliberate best-effort;
        callers treat the count as advisory).
        """
        collection = db[collection_name]
        try:
            return await asyncio.wait_for(
                collection.count_documents(filter_criteria),
                timeout=timeout_seconds
            )
        except asyncio.TimeoutError:
            logger.warning("Count operation timed out after %ss", timeout_seconds)
            # NOTE(review): this fallback is not inherently faster than
            # count_documents and carries no timeout of its own — confirm
            # whether it should also be bounded (e.g. maxTimeMS).
            pipeline = [
                {"$match": filter_criteria},
                {"$count": "total"}
            ]
            result = await collection.aggregate(pipeline).to_list(length=1)
            return result[0]["total"] if result else 0
        except Exception as e:
            logger.error("Error counting documents: %s", e)
            return 0
# Global instances
# Module-level singletons shared by the convenience helpers below and by
# any importing module.
query_optimizer = QueryOptimizer()
memory_aggregator = MemoryEfficientAggregator()
async def execute_optimized_aggregation(
    collection_name: str,
    pipeline: List[Dict],
    limit: Optional[int] = None,
    use_streaming: bool = False
) -> List[Dict]:
    """Run an aggregation through the shared optimizer, with layered fallbacks.

    When *use_streaming* is set and *limit* exceeds 1000, results are
    accumulated from the optimizer's batch stream; otherwise a single
    optimized query is issued.  If the optimizer path fails for any reason,
    the original pipeline is executed directly against the collection.
    """
    try:
        wants_streaming = use_streaming and bool(limit) and limit > 1000
        if not wants_streaming:
            # Small / unbounded requests: one optimized round trip.
            return await query_optimizer.execute_optimized_query(
                collection_name,
                pipeline,
                limit=limit,
                use_cursor=False  # Disable cursor for now to avoid complexity
            )
        # Large bounded requests: accumulate streamed batches up to limit.
        collected: List[Dict] = []
        async for chunk in query_optimizer.stream_query_results(collection_name, pipeline):
            collected.extend(chunk)
            if len(collected) >= limit:
                return collected[:limit]
        return collected
    except Exception as e:
        logger.error(f"Optimized aggregation failed for {collection_name}: {e}")
        # Final fallback - direct database call
        fallback_collection = db[collection_name]
        return await fallback_collection.aggregate(pipeline).to_list(length=limit)