|
|
|
|
|
import hashlib |
|
|
import json |
|
|
import time |
|
|
from datetime import datetime, timedelta |
|
|
from typing import List, Dict, Any, Optional, Tuple |
|
|
import numpy as np |
|
|
import faiss |
|
|
import redis |
|
|
import pickle |
|
|
import os |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
|
|
|
@dataclass
class CAGConfig:
    """Configuration for the CAG (Cache-Augmented Generation) system.

    BUG FIX: the original attributes had no type annotations, so ``@dataclass``
    saw zero fields and generated an empty ``__init__`` — the decorator was
    inert.  With annotations this is a real dataclass: ``CAGConfig()`` still
    yields the same defaults (backward compatible), and individual settings can
    now also be overridden per instance, e.g. ``CAGConfig(EMBEDDING_TTL=60)``.
    """

    # Cache layer toggles.
    USE_MEMORY_CACHE: bool = True
    USE_REDIS_CACHE: bool = False
    USE_DISK_CACHE: bool = True
    CACHE_DIR: str = ".cag_cache"

    # Time-to-live per cached artifact type, in seconds.
    EMBEDDING_TTL: int = 86400        # 24 hours
    SEARCH_RESULT_TTL: int = 3600     # 1 hour
    SEMANTIC_CACHE_TTL: int = 7200    # 2 hours
    GENERATION_TTL: int = 1800        # 30 minutes

    # Semantic-cache tuning.
    SEMANTIC_SIMILARITY_THRESHOLD: float = 0.85  # min cosine similarity for a hit
    MIN_QUERY_LENGTH: int = 3                    # shorter queries are never cache-matched
    MAX_CACHE_SIZE: int = 10000                  # cap on semantic-cache entries (FIFO eviction)

    # Monitoring.
    ENABLE_CACHE_STATS: bool = True
    LOG_CACHE_PERFORMANCE: bool = True
|
|
|
|
|
class CacheHitType(str, Enum):
    """Kind of cache hit produced by a lookup.

    Subclasses ``str`` so members compare against, and serialize as, plain
    strings (e.g. in JSON responses via ``.value``).
    """
    EXACT = "exact"        # query/params hashed to an already-cached key
    SEMANTIC = "semantic"  # a sufficiently similar cached query was reused
    PARTIAL = "partial"    # NOTE(review): not returned by any lookup in this file
    NONE = "none"          # cache miss
|
|
|
|
|
class CAGService:
    """Cache-Augmented Generation Service.

    Layers a multi-level cache in front of a RAG system:

      1. in-process ``memory_cache`` dict (always consulted first),
      2. optional Redis (search results only),
      3. optional on-disk pickle files under ``CAGConfig.CACHE_DIR``,

    plus a FAISS inner-product index over query embeddings (the "semantic
    cache") that can serve results for queries *similar* to — not
    byte-identical with — previously answered ones.
    """

    def __init__(self, rag_system, multilingual_manager):
        """Wire the service to a RAG backend and a language/embedding manager.

        Args:
            rag_system: object exposing ``semantic_search(query, top_k=...)``.
            multilingual_manager: object exposing ``detect_language(query)``
                and ``get_embedding_model(language)``.
        """
        self.rag_system = rag_system
        self.multilingual_manager = multilingual_manager

        self.config = CAGConfig()

        # Layer 1: in-process caches.  ``semantic_cache_keys`` and
        # ``semantic_cache_embeddings`` are parallel lists; the FAISS index is
        # (re)built lazily in _update_semantic_cache().
        self.memory_cache = {}
        self.semantic_cache_index = None
        self.semantic_cache_embeddings = []
        self.semantic_cache_keys = []

        # Layer 2: optional Redis (disabled for this instance if unreachable).
        self.redis_client = None
        self._init_redis()

        # Layer 3: on-disk pickle cache.
        self._init_cache_directory()

        # Counters reported by get_cache_stats().
        # NOTE(review): memory_cache and response_times grow without bound;
        # MAX_CACHE_SIZE only caps the semantic cache.
        self.stats = {
            "total_queries": 0,
            "cache_hits": 0,
            "exact_hits": 0,
            "semantic_hits": 0,
            "response_times": [],
            "cost_savings": 0
        }

        print("✅ CAG Service initialized")

    def _init_redis(self):
        """Connect to a local Redis if configured; degrade gracefully if not.

        Any connection error flips USE_REDIS_CACHE off for the lifetime of
        this instance, so later lookups skip the Redis layer entirely.
        """
        if self.config.USE_REDIS_CACHE:
            try:
                self.redis_client = redis.Redis(
                    host='localhost',
                    port=6379,
                    db=0,
                    decode_responses=False  # values are pickled bytes
                )
                self.redis_client.ping()  # fail fast if the server is down
                print("✅ Redis cache connected")
            except Exception as e:
                print(f"⚠️ Redis not available: {e}")
                self.config.USE_REDIS_CACHE = False

    def _init_cache_directory(self):
        """Create the on-disk cache directory layout (idempotent)."""
        os.makedirs(self.config.CACHE_DIR, exist_ok=True)
        os.makedirs(f"{self.config.CACHE_DIR}/embeddings", exist_ok=True)
        os.makedirs(f"{self.config.CACHE_DIR}/results", exist_ok=True)

    def _generate_cache_key(self, data_type: str, content: str, params: Dict = None) -> str:
        """Build a deterministic 32-hex-char key from type, content and params.

        ``sort_keys=True`` makes the key independent of dict insertion order.
        """
        key_data = {
            "type": data_type,
            "content": content,
            "params": params or {}
        }
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode()).hexdigest()[:32]

    def cache_embedding(self, text: str, embedding: np.ndarray, language: str):
        """Store a text embedding in the memory (and optionally disk) cache."""
        if not self.config.USE_MEMORY_CACHE:
            return

        cache_key = self._generate_cache_key("embedding", text, {"language": language})

        cache_entry = {
            "embedding": embedding.tolist(),  # plain list: pickle/JSON friendly
            "language": language,
            "timestamp": datetime.now().isoformat(),
            "text_length": len(text)
        }

        self.memory_cache[cache_key] = cache_entry

        if self.config.USE_DISK_CACHE:
            cache_path = f"{self.config.CACHE_DIR}/embeddings/{cache_key}.pkl"
            try:
                with open(cache_path, 'wb') as f:
                    pickle.dump(cache_entry, f)
            except Exception as e:
                print(f"⚠️ Failed to save embedding cache: {e}")

    def get_cached_embedding(self, text: str, language: str) -> Optional[np.ndarray]:
        """Return a cached embedding for (text, language), or None.

        Checks memory first, then disk; a disk hit is promoted into memory.
        """
        cache_key = self._generate_cache_key("embedding", text, {"language": language})

        if cache_key in self.memory_cache:
            entry = self.memory_cache[cache_key]
            if self._is_cache_entry_valid(entry, self.config.EMBEDDING_TTL):
                return np.array(entry["embedding"])

        if self.config.USE_DISK_CACHE:
            cache_path = f"{self.config.CACHE_DIR}/embeddings/{cache_key}.pkl"
            if os.path.exists(cache_path):
                try:
                    with open(cache_path, 'rb') as f:
                        entry = pickle.load(f)
                    if self._is_cache_entry_valid(entry, self.config.EMBEDDING_TTL):
                        # Promote to the faster memory layer.
                        self.memory_cache[cache_key] = entry
                        return np.array(entry["embedding"])
                except Exception as e:
                    print(f"⚠️ Failed to load embedding cache: {e}")

        return None

    def cache_search_results(self, query: str, results: List, top_k: int, language: str):
        """Cache search results and register the query in the semantic cache."""
        cache_key = self._generate_cache_key("search", query, {
            "top_k": top_k,
            "language": language
        })

        # Register the query embedding so later *similar* queries can be
        # served from this entry via the semantic cache.
        embedding_model = self.multilingual_manager.get_embedding_model(language)
        if embedding_model:
            query_embedding = embedding_model.encode([query])[0]
            self._update_semantic_cache(cache_key, query_embedding)

        cache_entry = {
            "query": query,
            # Store plain dicts, not result objects, so entries pickle cleanly.
            "results": [r.__dict__ if hasattr(r, '__dict__') else r for r in results],
            "timestamp": datetime.now().isoformat(),
            "language": language,
            "top_k": top_k
        }

        self.memory_cache[cache_key] = cache_entry

        if self.config.USE_REDIS_CACHE and self.redis_client:
            try:
                # setex stores the value with a server-side TTL.
                self.redis_client.setex(
                    f"cag:search:{cache_key}",
                    self.config.SEARCH_RESULT_TTL,
                    pickle.dumps(cache_entry)
                )
            except Exception as e:
                print(f"⚠️ Redis cache failed: {e}")

        if self.config.USE_DISK_CACHE:
            cache_path = f"{self.config.CACHE_DIR}/results/{cache_key}.pkl"
            try:
                with open(cache_path, 'wb') as f:
                    pickle.dump(cache_entry, f)
            except Exception as e:
                print(f"⚠️ Failed to save search cache: {e}")

    def get_cached_search_results(self, query: str, top_k: int, language: str) -> Tuple[Optional[List], CacheHitType]:
        """Look up search results: exact key first, then semantic similarity.

        Returns:
            ``(results, hit_type)``; ``(None, CacheHitType.NONE)`` on a miss.
            Hit/miss statistics are updated as a side effect.
        """
        self.stats["total_queries"] += 1

        # Very short queries are too ambiguous to cache-match usefully.
        if len(query.strip()) < self.config.MIN_QUERY_LENGTH:
            return None, CacheHitType.NONE

        # 1) Exact-key lookup.
        exact_key = self._generate_cache_key("search", query, {
            "top_k": top_k,
            "language": language
        })

        cached_results = self._get_cache_entry(exact_key, self.config.SEARCH_RESULT_TTL)
        if cached_results:
            self.stats["cache_hits"] += 1
            self.stats["exact_hits"] += 1
            return cached_results.get("results"), CacheHitType.EXACT

        # 2) Semantic lookup over previously cached query embeddings.
        if self.semantic_cache_index is not None and len(self.semantic_cache_embeddings) > 0:
            embedding_model = self.multilingual_manager.get_embedding_model(language)
            if embedding_model:
                query_embedding = embedding_model.encode([query])[0]
                similar_key, similarity = self._semantic_cache_lookup(query_embedding)

                if similarity >= self.config.SEMANTIC_SIMILARITY_THRESHOLD:
                    cached_results = self._get_cache_entry(similar_key, self.config.SEMANTIC_CACHE_TTL)
                    if cached_results:
                        self.stats["cache_hits"] += 1
                        self.stats["semantic_hits"] += 1

                        # Down-weight scores by how similar the queries were.
                        adjusted_results = self._adjust_cached_results(
                            cached_results.get("results"),
                            query,
                            similarity
                        )
                        return adjusted_results, CacheHitType.SEMANTIC

        return None, CacheHitType.NONE

    def _update_semantic_cache(self, cache_key: str, embedding: np.ndarray):
        """Append (key, embedding) and rebuild the FAISS index.

        Eviction is FIFO once MAX_CACHE_SIZE is reached.  The whole index is
        rebuilt on every insert — O(n) per call, acceptable at the configured
        cap but worth revisiting if the cache grows.
        """
        if len(self.semantic_cache_embeddings) >= self.config.MAX_CACHE_SIZE:
            # Drop the oldest entry (FIFO).
            self.semantic_cache_keys.pop(0)
            self.semantic_cache_embeddings.pop(0)

        self.semantic_cache_keys.append(cache_key)
        self.semantic_cache_embeddings.append(embedding)

        if len(self.semantic_cache_embeddings) > 0:
            embeddings_array = np.array(self.semantic_cache_embeddings).astype(np.float32)
            dimension = embeddings_array.shape[1]

            if self.semantic_cache_index is None:
                self.semantic_cache_index = faiss.IndexFlatIP(dimension)

            # L2-normalize so the inner product equals cosine similarity.
            self.semantic_cache_index.reset()
            faiss.normalize_L2(embeddings_array)
            self.semantic_cache_index.add(embeddings_array)

    def _semantic_cache_lookup(self, query_embedding: np.ndarray) -> Tuple[Optional[str], float]:
        """Return (cache_key, cosine similarity) of the closest cached query.

        Returns ``(None, 0.0)`` when the cache is empty, the embedding is
        degenerate (zero vector), or FAISS reports no neighbor.
        """
        if len(self.semantic_cache_embeddings) == 0:
            return None, 0.0

        # Guard against a zero vector before normalizing (would divide by 0).
        norm = np.linalg.norm(query_embedding)
        if norm == 0:
            return None, 0.0
        query_embedding = (query_embedding / norm).reshape(1, -1).astype(np.float32)

        distances, indices = self.semantic_cache_index.search(query_embedding, k=1)

        if len(indices[0]) > 0 and indices[0][0] != -1:
            idx = int(indices[0][0])
            # BUG FIX: IndexFlatIP returns inner products, which for
            # L2-normalized vectors ARE the cosine similarities (higher =
            # closer).  The previous ``1 - distances[0][0]`` is the formula
            # for an L2 index; here it inverted the score, so near-duplicate
            # queries scored ~0 and semantic hits never cleared the threshold.
            similarity = float(distances[0][0])
            return self.semantic_cache_keys[idx], similarity

        return None, 0.0

    def _get_cache_entry(self, cache_key: str, ttl: int) -> Optional[Dict]:
        """Fetch an entry from memory, then Redis, then disk.

        A hit in a slower layer is promoted into the memory cache.  Expired
        or unreadable entries are skipped (not deleted).
        """
        if cache_key in self.memory_cache:
            entry = self.memory_cache[cache_key]
            if self._is_cache_entry_valid(entry, ttl):
                return entry

        if self.config.USE_REDIS_CACHE and self.redis_client:
            try:
                cached = self.redis_client.get(f"cag:search:{cache_key}")
                if cached:
                    entry = pickle.loads(cached)
                    if self._is_cache_entry_valid(entry, ttl):
                        self.memory_cache[cache_key] = entry  # promote
                        return entry
            except Exception as e:
                print(f"⚠️ Redis get failed: {e}")

        if self.config.USE_DISK_CACHE:
            cache_path = f"{self.config.CACHE_DIR}/results/{cache_key}.pkl"
            if os.path.exists(cache_path):
                try:
                    with open(cache_path, 'rb') as f:
                        entry = pickle.load(f)
                    if self._is_cache_entry_valid(entry, ttl):
                        self.memory_cache[cache_key] = entry  # promote
                        return entry
                except Exception as e:
                    print(f"⚠️ Disk cache read failed: {e}")

        return None

    def _is_cache_entry_valid(self, entry: Dict, ttl: int) -> bool:
        """True iff the entry has a parseable ISO timestamp younger than ttl seconds."""
        if "timestamp" not in entry:
            return False

        try:
            timestamp = datetime.fromisoformat(entry["timestamp"])
            age = datetime.now() - timestamp
            return age.total_seconds() < ttl
        except (ValueError, TypeError):
            # Narrowed from a bare ``except:`` to the errors fromisoformat
            # actually raises on malformed/non-string input.
            return False

    def _adjust_cached_results(self, cached_results: List, new_query: str, similarity: float) -> List:
        """Re-score cached results for a semantically similar (non-identical) query."""
        adjusted_results = []

        for result in cached_results:
            if isinstance(result, dict) and "similarity" in result:
                # BUG FIX: operate on a shallow copy.  The original mutated
                # the cached dicts in place, so every semantic hit decayed the
                # stored similarity scores a little further.
                result = dict(result)
                result["similarity"] *= similarity
                result["source"] = "semantic_cache"
                result["cache_similarity"] = similarity

            adjusted_results.append(result)

        return adjusted_results

    def search_with_cache(self, query: str, top_k: int = 5, use_cache: bool = True) -> Dict:
        """Answer a query from cache when possible, else via the RAG system.

        Returns a response dict containing the results plus cache/timing
        metadata (``cache_hit``, ``hit_type``, ``response_time_ms``, ...).
        """
        start_time = time.time()

        language = self.multilingual_manager.detect_language(query)

        cached_results, hit_type = None, CacheHitType.NONE
        if use_cache:
            cached_results, hit_type = self.get_cached_search_results(query, top_k, language)

        if cached_results and hit_type != CacheHitType.NONE:
            response_time = time.time() - start_time
            self.stats["response_times"].append(response_time)

            return {
                "query": query,
                "results": cached_results,
                "cache_hit": True,
                "hit_type": hit_type.value,
                "response_time_ms": round(response_time * 1000, 2),
                "language": language,
                "cached": True
            }

        # Cache miss: fall through to the (slower) RAG search.
        rag_start_time = time.time()
        rag_results = self.rag_system.semantic_search(query, top_k=top_k)
        rag_time = time.time() - rag_start_time

        # Only non-empty result sets are cached.
        if use_cache and rag_results:
            self.cache_search_results(query, rag_results, top_k, language)

        total_time = time.time() - start_time
        self.stats["response_times"].append(total_time)

        # Normalize RAG result objects into plain dicts.
        # NOTE(review): assumes each result exposes .text/.similarity/.metadata.
        results_list = []
        for result in rag_results:
            results_list.append({
                "text": result.text,
                "similarity": result.similarity,
                "metadata": result.metadata,
                "source": "rag_search"
            })

        return {
            "query": query,
            "results": results_list,
            "cache_hit": False,
            "hit_type": "none",
            "response_time_ms": round(total_time * 1000, 2),
            "rag_time_ms": round(rag_time * 1000, 2),
            "language": language,
            "cached": False
        }

    def batch_search_with_cache(self, queries: List[str], top_k: int = 3) -> List[Dict]:
        """Answer a batch of queries, serving cache hits first.

        Misses are resolved one by one via ``search_with_cache`` with
        ``use_cache=False`` since the lookup already failed.
        NOTE(review): ``use_cache=False`` also skips writing the fresh
        results back to the cache — confirm that is intentional.
        """
        results = []

        # Pass 1: try the cache for every query.
        for query in queries:
            language = self.multilingual_manager.detect_language(query)
            cached_results, hit_type = self.get_cached_search_results(query, top_k, language)

            if cached_results:
                results.append({
                    "query": query,
                    "results": cached_results,
                    "cache_hit": True,
                    "hit_type": hit_type.value
                })
            else:
                # Placeholder; filled in by pass 2.
                results.append({
                    "query": query,
                    "cache_hit": False,
                    "pending": True
                })

        # Pass 2: resolve the misses, preserving each query's position.
        uncached_queries = []
        uncached_indices = []

        for i, result in enumerate(results):
            if result.get("pending", False):
                uncached_queries.append(result["query"])
                uncached_indices.append(i)

        if uncached_queries:
            for idx, query in zip(uncached_indices, uncached_queries):
                search_result = self.search_with_cache(query, top_k, use_cache=False)
                results[idx] = search_result

        return results

    def get_cache_stats(self) -> Dict:
        """Summarize hit rates, latency (avg / p95) and estimated cost savings."""
        total_hits = self.stats["cache_hits"]
        total_queries = self.stats["total_queries"]

        hit_rate = total_hits / total_queries if total_queries > 0 else 0

        if self.stats["response_times"]:
            avg_response_time = sum(self.stats["response_times"]) / len(self.stats["response_times"])
            p95_response_time = np.percentile(self.stats["response_times"], 95)
        else:
            avg_response_time = p95_response_time = 0

        # Rough estimate: every cache hit saves one paid backend call.
        cost_per_call = 0.01  # USD, assumed flat rate
        estimated_savings = total_hits * cost_per_call

        return {
            "total_queries": total_queries,
            "cache_hits": total_hits,
            "cache_misses": total_queries - total_hits,
            "hit_rate": round(hit_rate * 100, 2),
            "exact_hits": self.stats["exact_hits"],
            "semantic_hits": self.stats["semantic_hits"],
            "avg_response_time_ms": round(avg_response_time * 1000, 2),
            "p95_response_time_ms": round(p95_response_time * 1000, 2),
            "memory_cache_size": len(self.memory_cache),
            "semantic_cache_size": len(self.semantic_cache_embeddings),
            "estimated_cost_savings_usd": round(estimated_savings, 2)
        }

    def clear_cache(self, cache_type: str = "all"):
        """Clear one or all cache layers.

        Args:
            cache_type: "all", "memory", "semantic", "disk" or "redis".
        """
        if cache_type == "all" or cache_type == "memory":
            self.memory_cache.clear()
            print("✅ Memory cache cleared")

        if cache_type == "all" or cache_type == "semantic":
            self.semantic_cache_index = None
            self.semantic_cache_embeddings = []
            self.semantic_cache_keys = []
            print("✅ Semantic cache cleared")

        if cache_type == "all" or cache_type == "disk":
            import shutil
            # Remove the whole tree, then recreate the empty layout.
            shutil.rmtree(self.config.CACHE_DIR, ignore_errors=True)
            self._init_cache_directory()
            print("✅ Disk cache cleared")

        if cache_type == "all" or cache_type == "redis":
            if self.redis_client:
                try:
                    # NOTE(review): flushdb wipes the entire Redis DB, not just
                    # this service's "cag:search:*" keys — confirm acceptable.
                    self.redis_client.flushdb()
                    print("✅ Redis cache cleared")
                except Exception as e:
                    print(f"⚠️ Failed to clear Redis: {e}")