# voicebot/core/cag_system.py — "Create cag_system.py" (commit 47284c1, author: datbkpro)
# services/cag_service.py
import hashlib
import json
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
import faiss
import redis
import pickle
import os
from dataclasses import dataclass
from enum import Enum
@dataclass
class CAGConfig:
    """Configuration for the CAG (Cache-Augmented Generation) system.

    BUGFIX: the original declared every setting as a bare (unannotated)
    class attribute, so ``@dataclass`` generated an empty, field-less
    ``__init__`` and the decorator did nothing. Annotating the attributes
    turns them into real dataclass fields — same names, same defaults —
    so instances carry their own values and show up in ``repr``/``eq``,
    and individual settings can be overridden at construction time.
    """
    # --- Cache layer toggles ---
    USE_MEMORY_CACHE: bool = True
    USE_REDIS_CACHE: bool = False
    USE_DISK_CACHE: bool = True
    CACHE_DIR: str = ".cag_cache"
    # --- TTL settings (seconds) ---
    EMBEDDING_TTL: int = 86400        # 24 hours
    SEARCH_RESULT_TTL: int = 3600     # 1 hour
    SEMANTIC_CACHE_TTL: int = 7200    # 2 hours
    GENERATION_TTL: int = 1800        # 30 minutes
    # --- Cache thresholds ---
    SEMANTIC_SIMILARITY_THRESHOLD: float = 0.85  # cosine-similarity gate for semantic hits
    MIN_QUERY_LENGTH: int = 3                    # queries shorter than this skip the cache
    MAX_CACHE_SIZE: int = 10000                  # max entries per cache layer
    # --- Performance settings ---
    ENABLE_CACHE_STATS: bool = True
    LOG_CACHE_PERFORMANCE: bool = True
class CacheHitType(str, Enum):
    """Classification of how a cached lookup was satisfied.

    Inherits from ``str`` so members compare equal to their literal
    values and serialize cleanly (e.g. into JSON responses).
    """
    EXACT = "exact"        # the identical query was already cached
    SEMANTIC = "semantic"  # a sufficiently similar query was cached
    PARTIAL = "partial"    # partial-overlap hit
    NONE = "none"          # cache miss
class CAGService:
    """Cache-Augmented Generation (CAG) service.

    Sits in front of a RAG system and serves repeated or similar queries
    from a layered cache: an in-memory dict, an optional Redis layer, and
    on-disk pickle files. A FAISS inner-product index over query
    embeddings additionally lets paraphrased queries hit the cache
    ("semantic cache").
    """

    def __init__(self, rag_system, multilingual_manager):
        """Bind the service to its backends.

        Args:
            rag_system: backend exposing ``semantic_search(query, top_k=...)``
                (assumed interface — confirm against the RAG system).
            multilingual_manager: backend exposing ``detect_language(text)``
                and ``get_embedding_model(language)``.
        """
        self.rag_system = rag_system
        self.multilingual_manager = multilingual_manager

        # Cache configuration
        self.config = CAGConfig()

        # Cache storage. ``memory_cache`` is insertion-ordered (plain dict),
        # which ``_memory_store`` exploits for FIFO eviction.
        self.memory_cache = {}
        self.semantic_cache_index = None      # FAISS IndexFlatIP, built lazily
        self.semantic_cache_embeddings = []   # parallel list to semantic_cache_keys
        self.semantic_cache_keys = []

        # Redis client (optional)
        self.redis_client = None
        self._init_redis()

        # Disk cache
        self._init_cache_directory()

        # Performance tracking
        self.stats = {
            "total_queries": 0,
            "cache_hits": 0,
            "exact_hits": 0,
            "semantic_hits": 0,
            "response_times": [],
            "cost_savings": 0
        }
        print("✅ CAG Service initialized")

    def _init_redis(self):
        """Connect to a local Redis when USE_REDIS_CACHE is set.

        On any failure the Redis layer is disabled rather than raised,
        so the service degrades to memory/disk caching.
        """
        if self.config.USE_REDIS_CACHE:
            try:
                self.redis_client = redis.Redis(
                    host='localhost',
                    port=6379,
                    db=0,
                    decode_responses=False  # entries are pickled bytes
                )
                self.redis_client.ping()
                print("✅ Redis cache connected")
            except Exception as e:
                print(f"⚠️ Redis not available: {e}")
                self.config.USE_REDIS_CACHE = False

    def _init_cache_directory(self):
        """Create the on-disk cache layout (embeddings/ and results/)."""
        os.makedirs(self.config.CACHE_DIR, exist_ok=True)
        os.makedirs(f"{self.config.CACHE_DIR}/embeddings", exist_ok=True)
        os.makedirs(f"{self.config.CACHE_DIR}/results", exist_ok=True)

    def _generate_cache_key(self, data_type: str, content: str, params: Dict = None) -> str:
        """Derive a deterministic 32-hex-char key from type + content + params."""
        key_data = {
            "type": data_type,
            "content": content,
            "params": params or {}
        }
        # sort_keys makes the serialization (and hence the key) stable.
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode()).hexdigest()[:32]

    def _memory_store(self, cache_key: str, entry: Dict):
        """Store an entry in the in-memory cache with FIFO eviction.

        BUGFIX: the memory cache previously grew without bound —
        MAX_CACHE_SIZE was only enforced on the semantic cache.
        """
        if cache_key not in self.memory_cache:
            while len(self.memory_cache) >= self.config.MAX_CACHE_SIZE:
                # dicts preserve insertion order, so this evicts the oldest.
                self.memory_cache.pop(next(iter(self.memory_cache)))
        self.memory_cache[cache_key] = entry

    def cache_embedding(self, text: str, embedding: np.ndarray, language: str):
        """Cache a text embedding in memory and/or on disk.

        BUGFIX: the original returned early when USE_MEMORY_CACHE was off,
        which also silently disabled the disk layer.
        """
        if not (self.config.USE_MEMORY_CACHE or self.config.USE_DISK_CACHE):
            return
        cache_key = self._generate_cache_key("embedding", text, {"language": language})
        cache_entry = {
            "embedding": embedding.tolist(),
            "language": language,
            "timestamp": datetime.now().isoformat(),
            "text_length": len(text)
        }
        if self.config.USE_MEMORY_CACHE:
            self._memory_store(cache_key, cache_entry)
        if self.config.USE_DISK_CACHE:
            cache_path = f"{self.config.CACHE_DIR}/embeddings/{cache_key}.pkl"
            try:
                with open(cache_path, 'wb') as f:
                    pickle.dump(cache_entry, f)
            except Exception as e:
                print(f"⚠️ Failed to save embedding cache: {e}")

    def get_cached_embedding(self, text: str, language: str) -> Optional[np.ndarray]:
        """Return a cached embedding for (text, language), or None.

        Checks memory first, then disk; a disk hit is promoted to memory.
        """
        cache_key = self._generate_cache_key("embedding", text, {"language": language})
        if cache_key in self.memory_cache:
            entry = self.memory_cache[cache_key]
            if self._is_cache_entry_valid(entry, self.config.EMBEDDING_TTL):
                return np.array(entry["embedding"])
        if self.config.USE_DISK_CACHE:
            cache_path = f"{self.config.CACHE_DIR}/embeddings/{cache_key}.pkl"
            if os.path.exists(cache_path):
                try:
                    # NOTE: pickle is only safe because this directory is
                    # written by this same process — never point CACHE_DIR
                    # at untrusted data.
                    with open(cache_path, 'rb') as f:
                        entry = pickle.load(f)
                    if self._is_cache_entry_valid(entry, self.config.EMBEDDING_TTL):
                        self._memory_store(cache_key, entry)  # promote to memory
                        return np.array(entry["embedding"])
                except Exception as e:
                    print(f"⚠️ Failed to load embedding cache: {e}")
        return None

    def cache_search_results(self, query: str, results: List, top_k: int, language: str):
        """Cache RAG search results under an exact key and register the
        query embedding in the semantic cache."""
        cache_key = self._generate_cache_key("search", query, {
            "top_k": top_k,
            "language": language
        })
        # Register the query embedding so future paraphrases can hit this entry.
        embedding_model = self.multilingual_manager.get_embedding_model(language)
        if embedding_model:
            query_embedding = embedding_model.encode([query])[0]
            self._update_semantic_cache(cache_key, query_embedding)

        cache_entry = {
            "query": query,
            # Result objects are flattened to dicts so the entry pickles cleanly.
            "results": [r.__dict__ if hasattr(r, '__dict__') else r for r in results],
            "timestamp": datetime.now().isoformat(),
            "language": language,
            "top_k": top_k
        }
        self._memory_store(cache_key, cache_entry)

        if self.config.USE_REDIS_CACHE and self.redis_client:
            try:
                self.redis_client.setex(
                    f"cag:search:{cache_key}",
                    self.config.SEARCH_RESULT_TTL,
                    pickle.dumps(cache_entry)
                )
            except Exception as e:
                print(f"⚠️ Redis cache failed: {e}")

        if self.config.USE_DISK_CACHE:
            cache_path = f"{self.config.CACHE_DIR}/results/{cache_key}.pkl"
            try:
                with open(cache_path, 'wb') as f:
                    pickle.dump(cache_entry, f)
            except Exception as e:
                print(f"⚠️ Failed to save search cache: {e}")

    def get_cached_search_results(self, query: str, top_k: int, language: str) -> Tuple[Optional[List], CacheHitType]:
        """Look up cached search results for a query.

        Returns:
            ``(results, hit_type)`` — ``results`` is None on a miss.
            Tries an exact key first, then a semantic (paraphrase) match
            against the FAISS index.
        """
        self.stats["total_queries"] += 1
        if len(query.strip()) < self.config.MIN_QUERY_LENGTH:
            return None, CacheHitType.NONE

        # 1. Exact-match cache
        exact_key = self._generate_cache_key("search", query, {
            "top_k": top_k,
            "language": language
        })
        cached_results = self._get_cache_entry(exact_key, self.config.SEARCH_RESULT_TTL)
        if cached_results:
            self.stats["cache_hits"] += 1
            self.stats["exact_hits"] += 1
            return cached_results.get("results"), CacheHitType.EXACT

        # 2. Semantic cache
        if self.semantic_cache_index is not None and len(self.semantic_cache_embeddings) > 0:
            embedding_model = self.multilingual_manager.get_embedding_model(language)
            if embedding_model:
                query_embedding = embedding_model.encode([query])[0]
                similar_key, similarity = self._semantic_cache_lookup(query_embedding)
                if similarity >= self.config.SEMANTIC_SIMILARITY_THRESHOLD:
                    cached_results = self._get_cache_entry(similar_key, self.config.SEMANTIC_CACHE_TTL)
                    if cached_results:
                        self.stats["cache_hits"] += 1
                        self.stats["semantic_hits"] += 1
                        # Discount the stored scores for a non-identical query.
                        adjusted_results = self._adjust_cached_results(
                            cached_results.get("results"),
                            query,
                            similarity
                        )
                        return adjusted_results, CacheHitType.SEMANTIC
        return None, CacheHitType.NONE

    def _update_semantic_cache(self, cache_key: str, embedding: np.ndarray):
        """Register a query embedding so future similar queries can hit."""
        evicted = False
        if len(self.semantic_cache_embeddings) >= self.config.MAX_CACHE_SIZE:
            # FIFO eviction of the oldest entry.
            self.semantic_cache_keys.pop(0)
            self.semantic_cache_embeddings.pop(0)
            evicted = True
        self.semantic_cache_keys.append(cache_key)
        self.semantic_cache_embeddings.append(embedding)

        new_vec = np.asarray([embedding], dtype=np.float32)
        faiss.normalize_L2(new_vec)
        if self.semantic_cache_index is None or evicted:
            # (Re)build from scratch — flat FAISS indexes cannot delete rows.
            embeddings_array = np.array(self.semantic_cache_embeddings).astype(np.float32)
            faiss.normalize_L2(embeddings_array)
            self.semantic_cache_index = faiss.IndexFlatIP(embeddings_array.shape[1])
            self.semantic_cache_index.add(embeddings_array)
        else:
            # PERF: the original rebuilt the whole index on every insert
            # (O(n) per add); appending the single normalized vector is
            # equivalent and O(1).
            self.semantic_cache_index.add(new_vec)

    def _semantic_cache_lookup(self, query_embedding: np.ndarray) -> Tuple[Optional[str], float]:
        """Return (cache_key, cosine_similarity) of the nearest cached
        query embedding, or (None, 0.0) when nothing can be matched."""
        if self.semantic_cache_index is None or len(self.semantic_cache_embeddings) == 0:
            return None, 0.0
        norm = np.linalg.norm(query_embedding)
        if norm == 0:  # BUGFIX: avoid division by zero on a degenerate embedding
            return None, 0.0
        q = (query_embedding / norm).reshape(1, -1).astype(np.float32)
        distances, indices = self.semantic_cache_index.search(q, k=1)
        if len(indices[0]) > 0 and indices[0][0] != -1:
            idx = int(indices[0][0])
            # BUGFIX: IndexFlatIP returns inner products, which for
            # L2-normalized vectors ARE cosine similarities. The original
            # computed ``1 - distance``, inverting the score so identical
            # queries scored ~0 and never cleared the 0.85 threshold.
            similarity = float(distances[0][0])
            return self.semantic_cache_keys[idx], similarity
        return None, 0.0

    def _get_cache_entry(self, cache_key: str, ttl: int) -> Optional[Dict]:
        """Fetch a cache entry, checking memory → Redis → disk.

        Hits from the slower layers are promoted into the memory cache.
        """
        if cache_key in self.memory_cache:
            entry = self.memory_cache[cache_key]
            if self._is_cache_entry_valid(entry, ttl):
                return entry
        if self.config.USE_REDIS_CACHE and self.redis_client:
            try:
                cached = self.redis_client.get(f"cag:search:{cache_key}")
                if cached:
                    entry = pickle.loads(cached)
                    if self._is_cache_entry_valid(entry, ttl):
                        self._memory_store(cache_key, entry)
                        return entry
            except Exception as e:
                print(f"⚠️ Redis get failed: {e}")
        if self.config.USE_DISK_CACHE:
            cache_path = f"{self.config.CACHE_DIR}/results/{cache_key}.pkl"
            if os.path.exists(cache_path):
                try:
                    with open(cache_path, 'rb') as f:
                        entry = pickle.load(f)
                    if self._is_cache_entry_valid(entry, ttl):
                        self._memory_store(cache_key, entry)
                        return entry
                except Exception as e:
                    print(f"⚠️ Disk cache read failed: {e}")
        return None

    def _is_cache_entry_valid(self, entry: Dict, ttl: int) -> bool:
        """True when the entry has a parseable ISO timestamp younger than
        ``ttl`` seconds."""
        if "timestamp" not in entry:
            return False
        try:
            timestamp = datetime.fromisoformat(entry["timestamp"])
        except (TypeError, ValueError):
            # BUGFIX: was a bare ``except:`` that swallowed everything,
            # including KeyboardInterrupt.
            return False
        age = datetime.now() - timestamp
        return age.total_seconds() < ttl

    def _adjust_cached_results(self, cached_results: List, new_query: str, similarity: float) -> List:
        """Rescale cached result scores for a semantically-matched query."""
        adjusted_results = []
        # BUGFIX: tolerate a malformed entry whose "results" is None.
        for result in cached_results or []:
            if isinstance(result, dict) and "similarity" in result:
                # Discount the stored score by how close the new query is.
                result["similarity"] *= similarity
                result["source"] = "semantic_cache"
                result["cache_similarity"] = similarity
            adjusted_results.append(result)
        return adjusted_results

    def search_with_cache(self, query: str, top_k: int = 5, use_cache: bool = True) -> Dict:
        """Search with cache augmentation.

        Serves from the cache when possible; otherwise runs the RAG search
        and caches the results. Returns a dict with the results plus
        cache/timing metadata.
        """
        start_time = time.time()
        language = self.multilingual_manager.detect_language(query)

        cached_results, hit_type = None, CacheHitType.NONE
        if use_cache:
            cached_results, hit_type = self.get_cached_search_results(query, top_k, language)

        if cached_results and hit_type != CacheHitType.NONE:
            # Cache hit — no RAG call needed.
            response_time = time.time() - start_time
            self.stats["response_times"].append(response_time)
            return {
                "query": query,
                "results": cached_results,
                "cache_hit": True,
                "hit_type": hit_type.value,
                "response_time_ms": round(response_time * 1000, 2),
                "language": language,
                "cached": True
            }

        # Cache miss — perform the actual RAG search.
        rag_start_time = time.time()
        rag_results = self.rag_system.semantic_search(query, top_k=top_k)
        rag_time = time.time() - rag_start_time

        # Cache the results for next time.
        if use_cache and rag_results:
            self.cache_search_results(query, rag_results, top_k, language)

        total_time = time.time() - start_time
        self.stats["response_times"].append(total_time)

        # Flatten RAG result objects (assumed to expose .text/.similarity/
        # .metadata — confirm against the RAG system) into plain dicts.
        results_list = []
        for result in rag_results:
            results_list.append({
                "text": result.text,
                "similarity": result.similarity,
                "metadata": result.metadata,
                "source": "rag_search"
            })
        return {
            "query": query,
            "results": results_list,
            "cache_hit": False,
            "hit_type": "none",
            "response_time_ms": round(total_time * 1000, 2),
            "rag_time_ms": round(rag_time * 1000, 2),
            "language": language,
            "cached": False
        }

    def batch_search_with_cache(self, queries: List[str], top_k: int = 3) -> List[Dict]:
        """Batch search: answer what the cache can first, then run the
        remaining queries through the RAG backend."""
        results = []
        # Pass 1: cache lookup for every query.
        for query in queries:
            language = self.multilingual_manager.detect_language(query)
            cached_results, hit_type = self.get_cached_search_results(query, top_k, language)
            if cached_results:
                results.append({
                    "query": query,
                    "results": cached_results,
                    "cache_hit": True,
                    "hit_type": hit_type.value
                })
            else:
                results.append({
                    "query": query,
                    "cache_hit": False,
                    "pending": True
                })
        # Pass 2: resolve the misses. use_cache=False avoids re-checking
        # the cache (pass 1 already did) — NOTE this also means these
        # results are not written back to the cache.
        for i, result in enumerate(results):
            if result.get("pending", False):
                results[i] = self.search_with_cache(result["query"], top_k, use_cache=False)
        return results

    def get_cache_stats(self) -> Dict:
        """Return hit-rate, latency, size, and estimated-savings metrics."""
        total_hits = self.stats["cache_hits"]
        total_queries = self.stats["total_queries"]
        hit_rate = total_hits / total_queries if total_queries > 0 else 0

        if self.stats["response_times"]:
            avg_response_time = sum(self.stats["response_times"]) / len(self.stats["response_times"])
            p95_response_time = np.percentile(self.stats["response_times"], 95)
        else:
            avg_response_time = p95_response_time = 0

        # Rough savings estimate: each hit avoids one LLM call at ~$0.01.
        cost_per_call = 0.01  # USD
        estimated_savings = total_hits * cost_per_call
        return {
            "total_queries": total_queries,
            "cache_hits": total_hits,
            "cache_misses": total_queries - total_hits,
            "hit_rate": round(hit_rate * 100, 2),
            "exact_hits": self.stats["exact_hits"],
            "semantic_hits": self.stats["semantic_hits"],
            "avg_response_time_ms": round(avg_response_time * 1000, 2),
            "p95_response_time_ms": round(p95_response_time * 1000, 2),
            "memory_cache_size": len(self.memory_cache),
            "semantic_cache_size": len(self.semantic_cache_embeddings),
            "estimated_cost_savings_usd": round(estimated_savings, 2)
        }

    def clear_cache(self, cache_type: str = "all"):
        """Clear one or all cache layers.

        Args:
            cache_type: "all", "memory", "semantic", "disk", or "redis".
        """
        if cache_type == "all" or cache_type == "memory":
            self.memory_cache.clear()
            print("✅ Memory cache cleared")
        if cache_type == "all" or cache_type == "semantic":
            self.semantic_cache_index = None
            self.semantic_cache_embeddings = []
            self.semantic_cache_keys = []
            print("✅ Semantic cache cleared")
        if cache_type == "all" or cache_type == "disk":
            import shutil
            shutil.rmtree(self.config.CACHE_DIR, ignore_errors=True)
            self._init_cache_directory()
            print("✅ Disk cache cleared")
        if cache_type == "all" or cache_type == "redis":
            if self.redis_client:
                try:
                    self.redis_client.flushdb()
                    print("✅ Redis cache cleared")
                except Exception as e:
                    print(f"⚠️ Failed to clear Redis: {e}")