|
|
|
|
|
""" |
|
|
Embedding Optimizer - Performance optimization and caching |
|
|
Advanced optimization strategies for the embedding pipeline |
|
|
""" |
|
|
|
|
|
|
|
import logging |
|
|
import numpy as np |
|
|
from typing import List, Dict, Any, Optional, Tuple, Callable
|
|
from dataclasses import dataclass |
|
|
|
|
import time |
|
|
import threading |
|
|
|
|
import pickle |
|
|
import hashlib |
|
|
from pathlib import Path |
|
|
import sqlite3 |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class OptimizationConfig: |
|
|
"""Configuration for embedding optimization""" |
|
|
|
|
|
    # Caching
    use_disk_cache: bool = True
    cache_directory: str = "./cache/optimized_embeddings"
    max_cache_size_mb: int = 1000
    cache_compression: bool = True

    # Batch and GPU processing (use_gpu and prefetch_embeddings are
    # currently unused by this module)
    use_gpu: bool = False
    batch_processing: bool = True
    max_batch_size: int = 64
    prefetch_embeddings: bool = True

    # Memory management (informational; in-memory eviction currently uses
    # a fixed entry count rather than max_memory_usage_mb)
    use_memory_mapping: bool = False
    max_memory_usage_mb: int = 2048
    garbage_collection_frequency: int = 100

    # Similarity indexing (index_dimensions is informational; the actual
    # dimension is inferred from the embeddings)
    use_indexing: bool = True
    index_type: str = "faiss"
    index_dimensions: int = 768

    # Adaptive tuning and monitoring (performance_monitoring and
    # auto_tuning are currently unused by this module)
    adaptive_batching: bool = True
    performance_monitoring: bool = True
    auto_tuning: bool = True
|
|
|
|
|
|
|
|
class EmbeddingCache: |
|
|
"""Advanced embedding cache with disk persistence and compression""" |
|
|
|
|
|
def __init__(self, config: OptimizationConfig): |
|
|
self.config = config |
|
|
self.memory_cache = {} |
|
|
self.cache_stats = { |
|
|
"hits": 0, |
|
|
"misses": 0, |
|
|
"disk_hits": 0, |
|
|
"disk_misses": 0 |
|
|
} |
|
|
self.cache_lock = threading.RLock() |
|
|
|
|
|
|
|
|
if self.config.use_disk_cache: |
|
|
self.cache_dir = Path(self.config.cache_directory) |
|
|
self.cache_dir.mkdir(parents=True, exist_ok=True) |
|
|
self._setup_disk_cache() |
|
|
|
|
|
def _setup_disk_cache(self): |
|
|
"""Setup disk-based cache""" |
|
|
try: |
|
|
|
|
|
self.db_path = self.cache_dir / "cache.db" |
|
|
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) |
|
|
|
|
|
|
|
|
self.conn.execute(""" |
|
|
CREATE TABLE IF NOT EXISTS cache_metadata ( |
|
|
key TEXT PRIMARY KEY, |
|
|
file_path TEXT NOT NULL, |
|
|
created_at REAL NOT NULL, |
|
|
access_count INTEGER DEFAULT 0, |
|
|
last_accessed REAL NOT NULL, |
|
|
size_bytes INTEGER NOT NULL |
|
|
) |
|
|
""") |
|
|
|
|
|
self.conn.commit() |
|
|
logger.info("โ
Disk cache initialized") |
|
|
|
|
|
except Exception as e: |
|
|
logger.warning(f"Disk cache setup failed: {e}") |
|
|
self.config.use_disk_cache = False |
|
|
|
|
|
def _get_cache_key(self, text: str, config_hash: str = "") -> str: |
|
|
"""Generate cache key""" |
|
|
key_data = f"{text}_{config_hash}" |
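        # MD5 is used here as a fast, non-cryptographic fingerprint; it is
        # only a cache key, not a security measure.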
|
|
return hashlib.md5(key_data.encode()).hexdigest() |
|
|
|
|
|
def get(self, text: str, config_hash: str = "") -> Optional[Dict[str, Any]]: |
|
|
"""Get cached embedding""" |
|
|
cache_key = self._get_cache_key(text, config_hash) |
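
        # Two-tier lookup: check the in-memory dict first, then fall back to
        # the SQLite-backed disk cache and promote disk hits into memory.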
|
|
|
|
|
with self.cache_lock: |
|
|
|
|
|
if cache_key in self.memory_cache: |
|
|
self.cache_stats["hits"] += 1 |
|
|
return self.memory_cache[cache_key] |
|
|
|
|
|
|
|
|
if self.config.use_disk_cache: |
|
|
disk_result = self._get_from_disk(cache_key) |
|
|
if disk_result: |
|
|
self.cache_stats["disk_hits"] += 1 |
|
|
|
|
|
self.memory_cache[cache_key] = disk_result |
|
|
return disk_result |
|
|
else: |
|
|
self.cache_stats["disk_misses"] += 1 |
|
|
|
|
|
self.cache_stats["misses"] += 1 |
|
|
return None |
|
|
|
|
|
def _get_from_disk(self, cache_key: str) -> Optional[Dict[str, Any]]: |
|
|
"""Get embedding from disk cache""" |
|
|
try: |
|
|
cursor = self.conn.execute( |
|
|
"SELECT file_path FROM cache_metadata WHERE key = ?", |
|
|
(cache_key,) |
|
|
) |
|
|
result = cursor.fetchone() |
|
|
|
|
|
if result: |
|
|
file_path = Path(result[0]) |
|
|
if file_path.exists(): |
|
|
with open(file_path, 'rb') as f: |
|
|
if self.config.cache_compression: |
|
|
import gzip |
|
|
data = pickle.loads(gzip.decompress(f.read())) |
|
|
else: |
|
|
data = pickle.load(f) |
|
|
|
|
|
|
|
|
self.conn.execute( |
|
|
"UPDATE cache_metadata SET access_count = access_count + 1, last_accessed = ? WHERE key = ?", |
|
|
(time.time(), cache_key) |
|
|
) |
|
|
self.conn.commit() |
|
|
|
|
|
return data |
|
|
|
|
|
return None |
|
|
|
|
|
except Exception as e: |
|
|
logger.warning(f"Disk cache retrieval failed: {e}") |
|
|
return None |
|
|
|
|
|
def put(self, text: str, config_hash: str, embedding_data: Dict[str, Any]): |
|
|
"""Store embedding in cache""" |
|
|
cache_key = self._get_cache_key(text, config_hash) |
|
|
|
|
|
with self.cache_lock: |
|
|
|
|
|
self.memory_cache[cache_key] = embedding_data |
|
|
|
|
|
|
|
|
if self.config.use_disk_cache: |
|
|
self._put_to_disk(cache_key, embedding_data) |
|
|
|
|
|
|
|
|
self._check_memory_usage() |
|
|
|
|
|
def _put_to_disk(self, cache_key: str, embedding_data: Dict[str, Any]): |
|
|
"""Store embedding to disk cache""" |
|
|
try: |
|
|
|
|
|
file_path = self.cache_dir / f"{cache_key}.pkl" |
|
|
|
|
|
|
|
|
serialized_data = pickle.dumps(embedding_data) |
|
|
|
|
|
|
|
|
if self.config.cache_compression: |
|
|
import gzip |
|
|
serialized_data = gzip.compress(serialized_data) |
|
|
|
|
|
|
|
|
with open(file_path, 'wb') as f: |
|
|
f.write(serialized_data) |
|
|
|
|
|
|
|
|
self.conn.execute( |
|
|
"INSERT OR REPLACE INTO cache_metadata (key, file_path, created_at, last_accessed, size_bytes) VALUES (?, ?, ?, ?, ?)", |
|
|
(cache_key, str(file_path), time.time(), time.time(), len(serialized_data)) |
|
|
) |
|
|
self.conn.commit() |
|
|
|
|
|
except Exception as e: |
|
|
logger.warning(f"Disk cache storage failed: {e}") |
|
|
|
|
|
def _check_memory_usage(self): |
|
|
"""Check and manage memory usage""" |
|
|
if len(self.memory_cache) > 1000: |
|
|
|
|
|
sorted_items = sorted( |
|
|
self.memory_cache.items(), |
|
|
key=lambda x: x[1].get("metadata", {}).get("created_at", 0) |
|
|
) |
|
|
|
|
|
|
|
|
remove_count = len(sorted_items) // 5 |
|
|
for key, _ in sorted_items[:remove_count]: |
|
|
del self.memory_cache[key] |
|
|
|
|
|
def clear(self): |
|
|
"""Clear all caches""" |
|
|
with self.cache_lock: |
|
|
self.memory_cache.clear() |
|
|
|
|
|
if self.config.use_disk_cache: |
|
|
try: |
|
|
|
|
|
for file_path in self.cache_dir.glob("*.pkl"): |
|
|
file_path.unlink() |
|
|
|
|
|
|
|
|
self.conn.execute("DELETE FROM cache_metadata") |
|
|
self.conn.commit() |
|
|
|
|
|
except Exception as e: |
|
|
logger.warning(f"Disk cache clear failed: {e}") |
|
|
|
|
|
def get_stats(self) -> Dict[str, Any]: |
|
|
"""Get cache statistics""" |
|
|
with self.cache_lock: |
|
|
            # Disk hits are tracked separately from memory hits, so both
            # count toward the overall hit rate.
            total_hits = self.cache_stats["hits"] + self.cache_stats["disk_hits"]
            total_requests = total_hits + self.cache_stats["misses"]
            hit_rate = total_hits / total_requests if total_requests > 0 else 0.0
|
|
|
|
|
return { |
|
|
**self.cache_stats, |
|
|
"memory_cache_size": len(self.memory_cache), |
|
|
"hit_rate": hit_rate, |
|
|
"total_requests": total_requests |
|
|
} |
|
|
|
|
|
|
|
|
class EmbeddingOptimizer: |
|
|
"""Advanced embedding optimizer with caching, batching, and performance monitoring""" |
|
|
|
|
|
def __init__(self, config: Optional[OptimizationConfig] = None): |
|
|
self.config = config or OptimizationConfig() |
|
|
self.cache = EmbeddingCache(self.config) |
|
|
self.performance_metrics = { |
|
|
"total_embeddings": 0, |
|
|
"cache_hits": 0, |
|
|
"batch_operations": 0, |
|
|
"average_batch_size": 0.0, |
|
|
"average_processing_time": 0.0, |
|
|
"memory_usage_mb": 0.0 |
|
|
} |
|
|
|
|
|
|
|
|
self.processing_times = [] |
|
|
self.batch_sizes = [] |
|
|
|
|
|
|
|
|
self.optimal_batch_size = self.config.max_batch_size |
|
|
|
|
|
logger.info("โ
Embedding optimizer initialized") |
|
|
|
|
|
async def optimize_embedding_generation(self, embedder_func: Callable, |
|
|
texts: List[str], |
|
|
config_hash: str = "") -> List[Dict[str, Any]]: |
|
|
""" |
|
|
Optimize embedding generation with caching and batching |
|
|
|
|
|
Args: |
|
|
embedder_func: Function to generate embeddings |
|
|
texts: List of texts to embed |
|
|
config_hash: Configuration hash for cache key |
|
|
|
|
|
Returns: |
|
|
List of embedding results |
|
|
""" |
|
|
start_time = time.time() |
|
|
|
|
|
try: |
|
|
|
|
|
cached_results = [] |
|
|
uncached_texts = [] |
|
|
uncached_indices = [] |
|
|
|
|
|
for i, text in enumerate(texts): |
|
|
cached_result = self.cache.get(text, config_hash) |
|
|
if cached_result: |
|
|
cached_result["cached"] = True |
|
|
cached_results.append(cached_result) |
|
|
self.performance_metrics["cache_hits"] += 1 |
|
|
else: |
|
|
cached_results.append(None) |
|
|
uncached_texts.append(text) |
|
|
uncached_indices.append(i) |
|
|
|
|
|
|
|
|
if uncached_texts: |
|
|
if self.config.batch_processing: |
|
|
|
|
|
batch_results = await self._process_batch(embedder_func, uncached_texts, config_hash) |
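
                    # _process_batch returns [] on failure; any affected
                    # entries in cached_results then remain None.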
|
|
|
|
|
|
|
|
for i, result in zip(uncached_indices, batch_results): |
|
|
cached_results[i] = result |
|
|
else: |
|
|
|
|
|
for i, text in zip(uncached_indices, uncached_texts): |
|
|
result = await embedder_func(text) |
|
|
result["cached"] = False |
|
|
cached_results[i] = result |
|
|
|
|
|
|
|
|
self.cache.put(text, config_hash, result) |
|
|
|
|
|
|
|
|
processing_time = time.time() - start_time |
|
|
self._update_metrics(len(texts), processing_time) |
|
|
|
|
|
return cached_results |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"โ Optimized embedding generation failed: {e}") |
|
|
return [] |
|
|
|
|
|
async def _process_batch(self, embedder_func: Callable, texts: List[str], |
|
|
config_hash: str) -> List[Dict[str, Any]]: |
|
|
"""Process batch of texts with adaptive batching""" |
|
|
try: |
|
|
|
|
|
if self.config.adaptive_batching: |
|
|
optimal_size = self._calculate_optimal_batch_size() |
|
|
else: |
|
|
optimal_size = self.config.max_batch_size |
|
|
|
|
|
|
|
|
results = [] |
|
|
for i in range(0, len(texts), optimal_size): |
|
|
batch = texts[i:i + optimal_size] |
|
|
|
|
|
|
|
|
batch_start = time.time() |
|
|
batch_results = await embedder_func(batch) |
|
|
batch_time = time.time() - batch_start |
|
|
|
|
|
|
|
|
if isinstance(batch_results, dict): |
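                    # Assumes embedder_func returned a single dict for the
                    # whole batch; each text receives a shallow copy of it.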
|
|
|
|
|
for text in batch: |
|
|
result = batch_results.copy() |
|
|
result["cached"] = False |
|
|
results.append(result) |
|
|
|
|
|
|
|
|
self.cache.put(text, config_hash, result) |
|
|
else: |
|
|
|
|
|
for text, result in zip(batch, batch_results): |
|
|
result["cached"] = False |
|
|
results.append(result) |
|
|
|
|
|
|
|
|
self.cache.put(text, config_hash, result) |
|
|
|
|
|
|
|
|
self.batch_sizes.append(len(batch)) |
|
|
self.processing_times.append(batch_time) |
|
|
|
|
|
|
|
|
if self.config.adaptive_batching: |
|
|
self._adjust_batch_size(batch_time, len(batch)) |
|
|
|
|
|
self.performance_metrics["batch_operations"] += 1 |
|
|
return results |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"โ Batch processing failed: {e}") |
|
|
return [] |
|
|
|
|
|
def _calculate_optimal_batch_size(self) -> int: |
|
|
"""Calculate optimal batch size based on performance metrics""" |
|
|
if not self.processing_times or not self.batch_sizes: |
|
|
return self.config.max_batch_size |
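
        # Heuristic: shrink the batch when recent batches run noticeably
        # slower than earlier ones, and grow it when they run faster.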
|
|
|
|
|
|
|
|
if len(self.processing_times) > 1: |
|
|
recent_time = np.mean(self.processing_times[-3:]) |
|
|
older_time = np.mean(self.processing_times[:-3]) if len(self.processing_times) > 3 else recent_time |
|
|
|
|
|
if recent_time > older_time * 1.2: |
|
|
return max(1, self.optimal_batch_size // 2) |
|
|
elif recent_time < older_time * 0.8: |
|
|
return min(self.config.max_batch_size, self.optimal_batch_size * 2) |
|
|
|
|
|
return self.optimal_batch_size |
|
|
|
|
|
def _adjust_batch_size(self, processing_time: float, batch_size: int): |
|
|
"""Adjust batch size based on processing time""" |
|
|
|
|
|
if processing_time > 5.0: |
|
|
self.optimal_batch_size = max(1, batch_size // 2) |
|
|
elif processing_time < 1.0 and batch_size < self.config.max_batch_size: |
|
|
self.optimal_batch_size = min(self.config.max_batch_size, batch_size * 2) |
|
|
|
|
|
def _update_metrics(self, count: int, processing_time: float): |
|
|
"""Update performance metrics""" |
|
|
self.performance_metrics["total_embeddings"] += count |
|
|
|
|
|
|
|
|
total_ops = self.performance_metrics["total_embeddings"] |
|
|
if total_ops == count: |
|
|
self.performance_metrics["average_processing_time"] = processing_time |
|
|
else: |
|
|
current_avg = self.performance_metrics["average_processing_time"] |
|
|
self.performance_metrics["average_processing_time"] = ( |
|
|
(current_avg * (total_ops - count) + processing_time * count) / total_ops |
|
|
) |
|
|
|
|
|
|
|
|
if self.batch_sizes: |
|
|
self.performance_metrics["average_batch_size"] = np.mean(self.batch_sizes) |
|
|
|
|
|
|
|
|
        # psutil is optional; skip memory tracking when it is unavailable
        try:
            import psutil
            process = psutil.Process()
            self.performance_metrics["memory_usage_mb"] = process.memory_info().rss / 1024 / 1024
        except ImportError:
            pass
|
|
|
|
|
def create_index(self, embeddings: List[np.ndarray], texts: List[str]) -> Dict[str, Any]: |
|
|
"""Create search index for embeddings""" |
|
|
try: |
|
|
if not self.config.use_indexing or not embeddings: |
|
|
return {"index": None, "type": "none"} |
|
|
|
|
|
if self.config.index_type == "faiss": |
|
|
return self._create_faiss_index(embeddings, texts) |
|
|
elif self.config.index_type == "annoy": |
|
|
return self._create_annoy_index(embeddings, texts) |
|
|
else: |
|
|
logger.warning(f"Unsupported index type: {self.config.index_type}") |
|
|
return {"index": None, "type": "none"} |
|
|
|
|
|
except Exception as e: |
|
|
logger.warning(f"Index creation failed: {e}") |
|
|
return {"index": None, "type": "none"} |
|
|
|
|
|
def _create_faiss_index(self, embeddings: List[np.ndarray], texts: List[str]) -> Dict[str, Any]: |
|
|
"""Create FAISS index""" |
|
|
try: |
|
|
import faiss |
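
            # L2-normalizing the vectors and using an inner-product index
            # makes the returned scores equivalent to cosine similarity.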
|
|
|
|
|
|
|
|
embeddings_array = np.array(embeddings, dtype=np.float32) |
|
|
faiss.normalize_L2(embeddings_array) |
|
|
|
|
|
|
|
|
dimension = embeddings_array.shape[1] |
|
|
index = faiss.IndexFlatIP(dimension) |
|
|
|
|
|
|
|
|
index.add(embeddings_array) |
|
|
|
|
|
return { |
|
|
"index": index, |
|
|
"type": "faiss", |
|
|
"dimension": dimension, |
|
|
"size": len(embeddings), |
|
|
"texts": texts |
|
|
} |
|
|
|
|
|
except ImportError: |
|
|
logger.warning("FAISS not available") |
|
|
return {"index": None, "type": "none"} |
|
|
except Exception as e: |
|
|
logger.warning(f"FAISS index creation failed: {e}") |
|
|
return {"index": None, "type": "none"} |
|
|
|
|
|
def _create_annoy_index(self, embeddings: List[np.ndarray], texts: List[str]) -> Dict[str, Any]: |
|
|
"""Create Annoy index""" |
|
|
try: |
|
|
from annoy import AnnoyIndex |
|
|
|
|
|
if not embeddings: |
|
|
return {"index": None, "type": "none"} |
|
|
|
|
|
dimension = len(embeddings[0]) |
|
|
index = AnnoyIndex(dimension, 'angular') |
|
|
|
|
|
|
|
|
for i, embedding in enumerate(embeddings): |
|
|
index.add_item(i, embedding) |
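
            # Ten trees is a modest default; more trees improve recall at
            # the cost of build time and index size.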
|
|
|
|
|
|
|
|
index.build(10) |
|
|
|
|
|
return { |
|
|
"index": index, |
|
|
"type": "annoy", |
|
|
"dimension": dimension, |
|
|
"size": len(embeddings), |
|
|
"texts": texts |
|
|
} |
|
|
|
|
|
except ImportError: |
|
|
logger.warning("Annoy not available") |
|
|
return {"index": None, "type": "none"} |
|
|
except Exception as e: |
|
|
logger.warning(f"Annoy index creation failed: {e}") |
|
|
return {"index": None, "type": "none"} |
|
|
|
|
|
def search_similar(self, index_data: Dict[str, Any], query_embedding: np.ndarray, |
|
|
top_k: int = 10) -> List[Tuple[int, float]]: |
|
|
"""Search for similar embeddings using index""" |
|
|
try: |
|
|
if not index_data.get("index"): |
|
|
return [] |
|
|
|
|
|
index_type = index_data["type"] |
|
|
|
|
|
if index_type == "faiss": |
|
|
return self._search_faiss(index_data, query_embedding, top_k) |
|
|
elif index_type == "annoy": |
|
|
return self._search_annoy(index_data, query_embedding, top_k) |
|
|
else: |
|
|
return [] |
|
|
|
|
|
except Exception as e: |
|
|
logger.warning(f"Similarity search failed: {e}") |
|
|
return [] |
|
|
|
|
|
def _search_faiss(self, index_data: Dict[str, Any], query_embedding: np.ndarray, |
|
|
top_k: int) -> List[Tuple[int, float]]: |
|
|
"""Search using FAISS index""" |
|
|
try: |
|
|
import faiss |
|
|
|
|
|
index = index_data["index"] |
|
|
query = query_embedding.reshape(1, -1).astype(np.float32) |
|
|
faiss.normalize_L2(query) |
|
|
|
|
|
|
|
|
scores, indices = index.search(query, top_k) |
|
|
|
|
|
|
|
|
            results = []
            for score, idx in zip(scores[0], indices[0]):
                if idx != -1:  # FAISS pads missing neighbors with -1
                    results.append((int(idx), float(score)))
|
|
|
|
|
return results |
|
|
|
|
|
except Exception as e: |
|
|
logger.warning(f"FAISS search failed: {e}") |
|
|
return [] |
|
|
|
|
|
def _search_annoy(self, index_data: Dict[str, Any], query_embedding: np.ndarray, |
|
|
top_k: int) -> List[Tuple[int, float]]: |
|
|
"""Search using Annoy index""" |
|
|
try: |
|
|
index = index_data["index"] |
|
|
|
|
|
|
|
|
indices, distances = index.get_nns_by_vector(query_embedding, top_k, include_distances=True) |
|
|
|
|
|
|
|
|
results = [] |
|
|
for idx, dist in zip(indices, distances): |
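                # Map Annoy's angular distance to a (0, 1] score; this is a
                # heuristic, not a true cosine similarity.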
|
|
similarity = 1.0 / (1.0 + dist) |
|
|
results.append((int(idx), float(similarity))) |
|
|
|
|
|
return results |
|
|
|
|
|
except Exception as e: |
|
|
logger.warning(f"Annoy search failed: {e}") |
|
|
return [] |
|
|
|
|
|
def get_performance_report(self) -> Dict[str, Any]: |
|
|
"""Get comprehensive performance report""" |
|
|
cache_stats = self.cache.get_stats() |
|
|
|
|
|
return { |
|
|
"performance_metrics": self.performance_metrics.copy(), |
|
|
"cache_stats": cache_stats, |
|
|
"optimization_config": { |
|
|
"batch_processing": self.config.batch_processing, |
|
|
"adaptive_batching": self.config.adaptive_batching, |
|
|
"optimal_batch_size": self.optimal_batch_size, |
|
|
"use_indexing": self.config.use_indexing, |
|
|
"index_type": self.config.index_type |
|
|
}, |
|
|
"recent_performance": { |
|
|
"recent_batch_sizes": self.batch_sizes[-10:] if self.batch_sizes else [], |
|
|
"recent_processing_times": self.processing_times[-10:] if self.processing_times else [] |
|
|
} |
|
|
} |
|
|
|
|
|
def clear_cache(self): |
|
|
"""Clear all caches""" |
|
|
self.cache.clear() |
|
|
logger.info("Optimizer cache cleared") |
|
|
|
|
|
def reset_metrics(self): |
|
|
"""Reset performance metrics""" |
|
|
self.performance_metrics = { |
|
|
"total_embeddings": 0, |
|
|
"cache_hits": 0, |
|
|
"batch_operations": 0, |
|
|
"average_batch_size": 0.0, |
|
|
"average_processing_time": 0.0, |
|
|
"memory_usage_mb": 0.0 |
|
|
} |
|
|
self.processing_times.clear() |
|
|
self.batch_sizes.clear() |
|
|
self.optimal_batch_size = self.config.max_batch_size |
|
|
logger.info("Performance metrics reset") |
|
|
|