datbkpro committed on
Commit
47284c1
·
verified ·
1 Parent(s): 7c39dcb

Create cag_system.py

Browse files
Files changed (1) hide show
  1. core/cag_system.py +510 -0
core/cag_system.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # core/cag_system.py
2
+ import hashlib
3
+ import json
4
+ import time
5
+ from datetime import datetime, timedelta
6
+ from typing import List, Dict, Any, Optional, Tuple
7
+ import numpy as np
8
+ import faiss
9
+ import redis
10
+ import pickle
11
+ import os
12
+ from dataclasses import dataclass
13
+ from enum import Enum
14
+
15
@dataclass
class CAGConfig:
    """Configuration for the CAG (Cache-Augmented Generation) system.

    Every setting is an annotated dataclass field, so it can be overridden
    per instance through the constructor, e.g.
    ``CAGConfig(USE_REDIS_CACHE=True)``. (Without annotations ``@dataclass``
    generates no fields and the constructor accepts no arguments.) The
    defaults remain readable as class attributes, e.g.
    ``CAGConfig.CACHE_DIR``.
    """

    # Cache settings
    USE_MEMORY_CACHE: bool = True
    USE_REDIS_CACHE: bool = False
    USE_DISK_CACHE: bool = True
    CACHE_DIR: str = ".cag_cache"

    # TTL settings (seconds)
    EMBEDDING_TTL: int = 86400  # 24 hours
    SEARCH_RESULT_TTL: int = 3600  # 1 hour
    SEMANTIC_CACHE_TTL: int = 7200  # 2 hours
    GENERATION_TTL: int = 1800  # 30 minutes

    # Cache thresholds
    SEMANTIC_SIMILARITY_THRESHOLD: float = 0.85
    MIN_QUERY_LENGTH: int = 3
    MAX_CACHE_SIZE: int = 10000

    # Performance settings
    ENABLE_CACHE_STATS: bool = True
    LOG_CACHE_PERFORMANCE: bool = True
38
+
39
class CacheHitType(str, Enum):
    """Kind of cache hit reported for a search lookup.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    # The cache key matched.
    EXACT = "exact"
    # A semantically similar cached query matched.
    SEMANTIC = "semantic"
    # Declared for partial matches.
    PARTIAL = "partial"
    # Nothing usable was found in the cache.
    NONE = "none"
45
+
46
class CAGService:
    """Cache-Augmented Generation (CAG) service.

    Puts a layered cache (process memory, optional Redis, optional disk) in
    front of ``rag_system.semantic_search``. Lookups first try an exact key
    built from the query, ``top_k`` and language, then fall back to a FAISS
    inner-product index over the embeddings of previously cached queries.

    NOTE(security): the Redis and disk layers (de)serialise entries with
    ``pickle``. Unpickling executes arbitrary code, so the cache directory
    and the Redis instance must only be writable by trusted parties.
    """

    # Prefix shared by every Redis key written by this service.
    _REDIS_PREFIX = "cag:"
    # Number of keys removed per DELETE call when clearing Redis.
    _REDIS_DELETE_BATCH = 500

    def __init__(self, rag_system, multilingual_manager, config: Optional["CAGConfig"] = None):
        """Set up cache storage, the optional Redis client and statistics.

        Args:
            rag_system: Object exposing ``semantic_search(query, top_k=...)``.
            multilingual_manager: Object exposing ``detect_language(query)``
                and ``get_embedding_model(language)``.
            config: Optional configuration; a default ``CAGConfig`` is
                created when omitted.
        """
        self.rag_system = rag_system
        self.multilingual_manager = multilingual_manager

        # Cache configuration
        self.config = config if config is not None else CAGConfig()

        # Cache storage
        self.memory_cache = {}  # In-memory cache; insertion order drives eviction
        self.semantic_cache_index = None
        self.semantic_cache_embeddings = []
        self.semantic_cache_keys = []

        # Redis client (optional)
        self.redis_client = None
        self._init_redis()

        # Disk cache
        self._init_cache_directory()

        # Performance tracking
        self.stats = {
            "total_queries": 0,
            "cache_hits": 0,
            "exact_hits": 0,
            "semantic_hits": 0,
            "response_times": [],
            "cost_savings": 0
        }

        print("✅ CAG Service initialized")

    def _init_redis(self):
        """Connect to Redis when enabled; disable the Redis layer on failure."""
        if not self.config.USE_REDIS_CACHE:
            return
        try:
            self.redis_client = redis.Redis(
                host='localhost',
                port=6379,
                db=0,
                decode_responses=False  # values are pickled bytes
            )
            self.redis_client.ping()
            print("✅ Redis cache connected")
        except Exception as e:
            print(f"⚠️ Redis not available: {e}")
            self.config.USE_REDIS_CACHE = False

    def _init_cache_directory(self):
        """Create the on-disk cache directory tree if it does not exist."""
        os.makedirs(self.config.CACHE_DIR, exist_ok=True)
        os.makedirs(f"{self.config.CACHE_DIR}/embeddings", exist_ok=True)
        os.makedirs(f"{self.config.CACHE_DIR}/results", exist_ok=True)

    def _generate_cache_key(self, data_type: str, content: str, params: Optional[Dict] = None) -> str:
        """Return a deterministic 32-hex-character key for the given inputs.

        The key is the truncated SHA-256 of a JSON document with sorted keys,
        so the ordering of ``params`` does not influence the result.
        """
        key_data = {
            "type": data_type,
            "content": content,
            "params": params or {}
        }
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode()).hexdigest()[:32]

    # ------------------------------------------------------------------
    # Storage helpers
    # ------------------------------------------------------------------

    def _remember(self, cache_key: str, entry: Dict) -> None:
        """Store ``entry`` in the memory cache, evicting the oldest entries.

        Does nothing when ``USE_MEMORY_CACHE`` is off. The cache is capped at
        ``MAX_CACHE_SIZE`` entries; a non-positive limit means "unbounded".
        """
        if not self.config.USE_MEMORY_CACHE:
            return
        cache = self.memory_cache
        # Re-insert so that a refreshed key moves to the "newest" end.
        cache.pop(cache_key, None)
        cache[cache_key] = entry
        limit = self.config.MAX_CACHE_SIZE
        if limit and limit > 0:
            while len(cache) > limit:
                del cache[next(iter(cache))]

    def _disk_path(self, bucket: str, cache_key: str) -> str:
        """Return the pickle path for ``cache_key`` inside ``bucket``."""
        return os.path.join(self.config.CACHE_DIR, bucket, f"{cache_key}.pkl")

    def _write_disk(self, bucket: str, cache_key: str, entry: Dict, label: str) -> None:
        """Best-effort write of ``entry`` to the disk cache (when enabled)."""
        if not self.config.USE_DISK_CACHE:
            return
        try:
            with open(self._disk_path(bucket, cache_key), 'wb') as f:
                pickle.dump(entry, f)
        except Exception as e:
            print(f"⚠️ Failed to save {label} cache: {e}")

    def _read_disk(self, bucket: str, cache_key: str, failure_message: str) -> Optional[Dict]:
        """Best-effort read of an entry from the disk cache (when enabled)."""
        if not self.config.USE_DISK_CACHE:
            return None
        try:
            # NOTE(security): pickle.load runs arbitrary code; the cache
            # directory must not be writable by untrusted users.
            with open(self._disk_path(bucket, cache_key), 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            return None
        except Exception as e:
            print(f"⚠️ {failure_message}: {e}")
            return None

    @staticmethod
    def _copy_results(results: Optional[List]) -> Optional[List]:
        """Return a per-item shallow copy so callers cannot mutate the cache."""
        if results is None:
            return None
        return [dict(item) if isinstance(item, dict) else item for item in results]

    @staticmethod
    def _serialise_result(result):
        """Turn a search result into a plain value suitable for caching."""
        if isinstance(result, dict):
            return dict(result)
        if hasattr(result, '__dict__'):
            # Copy so later mutation of the object does not alter the cache.
            return dict(vars(result))
        return result

    @staticmethod
    def _format_rag_result(result) -> Dict:
        """Convert one RAG result (object or dict) into the response shape."""
        if isinstance(result, dict):
            return {
                "text": result.get("text"),
                "similarity": result.get("similarity"),
                "metadata": result.get("metadata"),
                "source": "rag_search"
            }
        return {
            "text": result.text,
            "similarity": result.similarity,
            "metadata": result.metadata,
            "source": "rag_search"
        }

    # ------------------------------------------------------------------
    # Embedding cache
    # ------------------------------------------------------------------

    def cache_embedding(self, text: str, embedding: np.ndarray, language: str):
        """Cache the embedding of ``text`` in memory and/or on disk.

        Each layer honours its own switch (``USE_MEMORY_CACHE`` /
        ``USE_DISK_CACHE``).
        """
        cache_key = self._generate_cache_key("embedding", text, {"language": language})

        cache_entry = {
            "embedding": np.asarray(embedding).tolist(),
            "language": language,
            "timestamp": datetime.now().isoformat(),
            "text_length": len(text)
        }

        self._remember(cache_key, cache_entry)
        self._write_disk("embeddings", cache_key, cache_entry, "embedding")

    def get_cached_embedding(self, text: str, language: str) -> Optional[np.ndarray]:
        """Return the cached embedding of ``text`` or ``None`` if absent/expired."""
        cache_key = self._generate_cache_key("embedding", text, {"language": language})
        ttl = self.config.EMBEDDING_TTL

        # Check memory cache first
        entry = self.memory_cache.get(cache_key)
        if entry is not None and self._is_cache_entry_valid(entry, ttl):
            return np.array(entry["embedding"])

        # Check disk cache
        entry = self._read_disk("embeddings", cache_key, "Failed to load embedding cache")
        if entry is not None and self._is_cache_entry_valid(entry, ttl):
            self._remember(cache_key, entry)  # promote to memory
            return np.array(entry["embedding"])

        return None

    # ------------------------------------------------------------------
    # Search-result cache
    # ------------------------------------------------------------------

    def cache_search_results(self, query: str, results: List, top_k: int, language: str):
        """Cache search ``results`` for ``query`` in every enabled layer.

        The query embedding is also added to the semantic index so that
        similar future queries can reuse the entry. Every layer is
        best-effort: a failure is reported and the others still run.
        """
        cache_key = self._generate_cache_key("search", query, {
            "top_k": top_k,
            "language": language
        })

        # Index the query embedding for semantic lookups. A failure here
        # must not discard search results that were already computed.
        try:
            embedding_model = self.multilingual_manager.get_embedding_model(language)
            if embedding_model:
                query_embedding = embedding_model.encode([query])[0]
                self._update_semantic_cache(cache_key, query_embedding)
        except Exception as e:
            print(f"⚠️ Semantic cache update failed: {e}")

        cache_entry = {
            "query": query,
            "results": [self._serialise_result(r) for r in results],
            "timestamp": datetime.now().isoformat(),
            "language": language,
            "top_k": top_k
        }

        # Save to memory cache
        self._remember(cache_key, cache_entry)

        # Save to Redis if available
        if self.config.USE_REDIS_CACHE and self.redis_client:
            try:
                # Keep the key for the longest TTL any lookup may apply;
                # freshness is re-checked from the timestamp on read.
                redis_ttl = max(self.config.SEARCH_RESULT_TTL, self.config.SEMANTIC_CACHE_TTL)
                self.redis_client.setex(
                    f"{self._REDIS_PREFIX}search:{cache_key}",
                    redis_ttl,
                    pickle.dumps(cache_entry)
                )
            except Exception as e:
                print(f"⚠️ Redis cache failed: {e}")

        # Save to disk
        self._write_disk("results", cache_key, cache_entry, "search")

    def get_cached_search_results(self, query: str, top_k: int, language: str) -> Tuple[Optional[List], "CacheHitType"]:
        """Look up cached results for ``query``.

        Returns:
            ``(results, hit_type)``; ``(None, CacheHitType.NONE)`` on a miss.
            Returned result dicts are copies, so callers may modify them.
        """
        self.stats["total_queries"] += 1
        miss = (None, CacheHitType.NONE)

        if len(query.strip()) < self.config.MIN_QUERY_LENGTH:
            return miss

        # 1. Try exact match cache
        exact_key = self._generate_cache_key("search", query, {
            "top_k": top_k,
            "language": language
        })
        entry = self._get_cache_entry(exact_key, self.config.SEARCH_RESULT_TTL)
        if entry:
            self.stats["cache_hits"] += 1
            self.stats["exact_hits"] += 1
            return self._copy_results(entry.get("results")), CacheHitType.EXACT

        # 2. Try semantic cache
        if self.semantic_cache_index is None or not self.semantic_cache_embeddings:
            return miss
        embedding_model = self.multilingual_manager.get_embedding_model(language)
        if not embedding_model:
            return miss

        query_embedding = embedding_model.encode([query])[0]
        similar_key, similarity = self._semantic_cache_lookup(query_embedding)
        if similar_key is None or similarity < self.config.SEMANTIC_SIMILARITY_THRESHOLD:
            return miss

        entry = self._get_cache_entry(similar_key, self.config.SEMANTIC_CACHE_TTL)
        if not entry:
            return miss
        # The exact key is specific to top_k and language, so a semantic hit
        # must be too: an entry cached for another top_k has the wrong number
        # of results, and embeddings produced for another language may come
        # from a different model and are then not comparable.
        if entry.get("top_k") != top_k or entry.get("language") != language:
            return miss

        self.stats["cache_hits"] += 1
        self.stats["semantic_hits"] += 1
        adjusted_results = self._adjust_cached_results(entry.get("results"), query, similarity)
        return adjusted_results, CacheHitType.SEMANTIC

    def _update_semantic_cache(self, cache_key: str, embedding: np.ndarray):
        """Add (or refresh) ``cache_key`` in the semantic index.

        The oldest entry is dropped once ``MAX_CACHE_SIZE`` is reached. An
        embedding whose dimension differs from the ones already indexed is
        skipped because a flat FAISS index holds a single dimension.
        """
        vector = np.asarray(embedding, dtype=np.float32).reshape(-1)
        if vector.size == 0:
            return

        keys = self.semantic_cache_keys
        vectors = self.semantic_cache_embeddings
        if vectors and len(vectors[0]) != len(vector):
            print(f"⚠️ Semantic cache skipped: embedding dimension {len(vector)} != {len(vectors[0])}")
            return

        if cache_key in keys:
            # Re-cached query: replace instead of accumulating duplicates.
            vectors[keys.index(cache_key)] = vector
        else:
            if vectors and len(vectors) >= self.config.MAX_CACHE_SIZE:
                # Remove oldest entry
                keys.pop(0)
                vectors.pop(0)
            keys.append(cache_key)
            vectors.append(vector)

        # Rebuild the FAISS index; row i corresponds to keys[i]. The stacked
        # matrix is a copy, so the stored vectors stay un-normalised.
        embeddings_array = np.ascontiguousarray(np.stack(vectors), dtype=np.float32)
        faiss.normalize_L2(embeddings_array)
        index = faiss.IndexFlatIP(embeddings_array.shape[1])
        index.add(embeddings_array)
        self.semantic_cache_index = index

    def _semantic_cache_lookup(self, query_embedding: np.ndarray) -> Tuple[Optional[str], float]:
        """Return ``(cache_key, cosine_similarity)`` of the nearest cached query.

        Returns ``(None, 0.0)`` when the index is empty, the embedding is a
        zero/non-finite vector, or its dimension does not match the index.
        """
        index = self.semantic_cache_index
        if index is None or not self.semantic_cache_embeddings or not self.semantic_cache_keys:
            return None, 0.0

        vector = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        if vector.shape[1] != index.d:
            return None, 0.0

        norm = float(np.linalg.norm(vector))
        if norm == 0.0 or not np.isfinite(norm):
            return None, 0.0
        vector = np.ascontiguousarray(vector / norm, dtype=np.float32)

        scores, indices = index.search(vector, k=1)
        position = int(indices[0][0])
        if position < 0 or position >= len(self.semantic_cache_keys):
            return None, 0.0

        # With L2-normalised vectors the inner product returned by
        # IndexFlatIP *is* the cosine similarity (higher = more similar);
        # it must not be converted with "1 - score".
        return self.semantic_cache_keys[position], float(scores[0][0])

    def _get_cache_entry(self, cache_key: str, ttl: int) -> Optional[Dict]:
        """Fetch a search entry from memory, then Redis, then disk.

        Entries found in a slower layer are promoted to the memory cache.
        Returns ``None`` when no layer holds an entry younger than ``ttl``.
        """
        # Check memory cache
        entry = self.memory_cache.get(cache_key)
        if entry is not None and self._is_cache_entry_valid(entry, ttl):
            return entry

        # Check Redis
        if self.config.USE_REDIS_CACHE and self.redis_client:
            try:
                cached = self.redis_client.get(f"{self._REDIS_PREFIX}search:{cache_key}")
                if cached:
                    # NOTE(security): pickle.loads runs arbitrary code; the
                    # Redis instance must be trusted.
                    entry = pickle.loads(cached)
                    if self._is_cache_entry_valid(entry, ttl):
                        self._remember(cache_key, entry)
                        return entry
            except Exception as e:
                print(f"⚠️ Redis get failed: {e}")

        # Check disk cache
        entry = self._read_disk("results", cache_key, "Disk cache read failed")
        if entry is not None and self._is_cache_entry_valid(entry, ttl):
            self._remember(cache_key, entry)
            return entry

        return None

    def _is_cache_entry_valid(self, entry: Dict, ttl: int) -> bool:
        """Return True when ``entry`` has a parseable timestamp younger than ``ttl`` seconds."""
        if not isinstance(entry, dict) or "timestamp" not in entry:
            return False

        try:
            timestamp = datetime.fromisoformat(entry["timestamp"])
            age = datetime.now() - timestamp
        except (TypeError, ValueError):
            # Non-string / malformed timestamp, or an aware-vs-naive mismatch.
            return False
        return age.total_seconds() < ttl

    def _adjust_cached_results(self, cached_results: List, new_query: str, similarity: float) -> List:
        """Return copies of ``cached_results`` re-scored for a semantic match.

        Dict results that carry a ``"similarity"`` score get it multiplied by
        the query-to-query ``similarity`` and are tagged with their origin.
        The cached dicts themselves are left untouched; otherwise the stored
        scores would shrink on every semantic hit.
        """
        adjusted_results = []

        for result in cached_results or []:
            if isinstance(result, dict):
                result = dict(result)
                if "similarity" in result:
                    result["similarity"] *= similarity
                    result["source"] = "semantic_cache"
                    result["cache_similarity"] = similarity
            adjusted_results.append(result)

        return adjusted_results

    # ------------------------------------------------------------------
    # Public search API
    # ------------------------------------------------------------------

    def search_with_cache(self, query: str, top_k: int = 5, use_cache: bool = True) -> Dict:
        """Search with cache augmentation.

        Args:
            query: Search text.
            top_k: Number of results requested from the RAG system.
            use_cache: When False the cache is neither read nor written.

        Returns:
            A dict with ``query``, ``results``, ``cache_hit``, ``hit_type``,
            ``response_time_ms``, ``language`` and ``cached``; cache misses
            additionally carry ``rag_time_ms``.
        """
        start_time = time.perf_counter()

        # Detect language
        language = self.multilingual_manager.detect_language(query)

        # Try to get from cache
        if use_cache:
            cached_results, hit_type = self.get_cached_search_results(query, top_k, language)
            if cached_results and hit_type != CacheHitType.NONE:
                response_time = time.perf_counter() - start_time
                self.stats["response_times"].append(response_time)
                return {
                    "query": query,
                    "results": cached_results,
                    "cache_hit": True,
                    "hit_type": hit_type.value,
                    "response_time_ms": round(response_time * 1000, 2),
                    "language": language,
                    "cached": True
                }

        # Cache miss - perform actual RAG search
        rag_start_time = time.perf_counter()
        # Treat a None return as "no results" instead of failing below.
        rag_results = self.rag_system.semantic_search(query, top_k=top_k) or []
        rag_time = time.perf_counter() - rag_start_time

        # Cache the results for next time
        if use_cache and rag_results:
            self.cache_search_results(query, rag_results, top_k, language)

        total_time = time.perf_counter() - start_time
        self.stats["response_times"].append(total_time)

        return {
            "query": query,
            "results": [self._format_rag_result(result) for result in rag_results],
            "cache_hit": False,
            "hit_type": "none",
            "response_time_ms": round(total_time * 1000, 2),
            "rag_time_ms": round(rag_time * 1000, 2),
            "language": language,
            "cached": False
        }

    def batch_search_with_cache(self, queries: List[str], top_k: int = 3) -> List[Dict]:
        """Run ``search_with_cache`` for every query, in order.

        Misses are stored in the cache, so repeated or similar queries later
        in the same batch (and in later calls) can be served from it.
        """
        return [self.search_with_cache(query, top_k) for query in queries]

    def get_cache_stats(self) -> Dict:
        """Return hit/miss counters, latency figures and cache sizes."""
        total_hits = self.stats["cache_hits"]
        total_queries = self.stats["total_queries"]

        hit_rate = total_hits / total_queries if total_queries > 0 else 0

        response_times = self.stats["response_times"]
        if response_times:
            avg_response_time = sum(response_times) / len(response_times)
            p95_response_time = float(np.percentile(response_times, 95))
        else:
            avg_response_time = p95_response_time = 0

        # Estimated cost savings: assumes each LLM call costs $0.01 and every
        # cache hit saves exactly one call.
        cost_per_call = 0.01  # USD
        estimated_savings = total_hits * cost_per_call

        return {
            "total_queries": total_queries,
            "cache_hits": total_hits,
            "cache_misses": total_queries - total_hits,
            "hit_rate": round(hit_rate * 100, 2),
            "exact_hits": self.stats["exact_hits"],
            "semantic_hits": self.stats["semantic_hits"],
            "avg_response_time_ms": round(avg_response_time * 1000, 2),
            "p95_response_time_ms": round(p95_response_time * 1000, 2),
            "memory_cache_size": len(self.memory_cache),
            "semantic_cache_size": len(self.semantic_cache_embeddings),
            "estimated_cost_savings_usd": round(estimated_savings, 2)
        }

    def clear_cache(self, cache_type: str = "all"):
        """Clear one cache layer, or every layer when ``cache_type`` is ``"all"``.

        Accepted values: ``"all"``, ``"memory"``, ``"semantic"``, ``"disk"``,
        ``"redis"``. Any other value clears nothing.
        """
        if cache_type == "all" or cache_type == "memory":
            self.memory_cache.clear()
            print("✅ Memory cache cleared")

        if cache_type == "all" or cache_type == "semantic":
            self.semantic_cache_index = None
            self.semantic_cache_embeddings = []
            self.semantic_cache_keys = []
            print("✅ Semantic cache cleared")

        if cache_type == "all" or cache_type == "disk":
            import shutil
            shutil.rmtree(self.config.CACHE_DIR, ignore_errors=True)
            self._init_cache_directory()
            print("✅ Disk cache cleared")

        if cache_type == "all" or cache_type == "redis":
            if self.redis_client:
                try:
                    # Delete only this service's keys; flushdb() would wipe
                    # every other application's data in the same database.
                    pending = []
                    for key in self.redis_client.scan_iter(match=f"{self._REDIS_PREFIX}*"):
                        pending.append(key)
                        if len(pending) >= self._REDIS_DELETE_BATCH:
                            self.redis_client.delete(*pending)
                            pending = []
                    if pending:
                        self.redis_client.delete(*pending)
                    print("✅ Redis cache cleared")
                except Exception as e:
                    print(f"⚠️ Failed to clear Redis: {e}")