T-Phong commited on
Commit
e711718
·
1 Parent(s): 047898d
Files changed (1) hide show
  1. service/rag.py +86 -46
service/rag.py CHANGED
@@ -6,7 +6,7 @@ from sentence_transformers import SentenceTransformer
6
  from datasets import load_dataset, load_from_disk
7
  from huggingface_hub import snapshot_download
8
  from typing import List, Dict, Any, Optional
9
-
10
  from helper import format_metadata_list_to_context
11
 
12
  # ==============================================================================
@@ -15,7 +15,7 @@ from helper import format_metadata_list_to_context
15
  class HuggingFaceRAGService:
16
  _instance: Optional['HuggingFaceRAGService'] = None
17
 
18
- # Singleton Pattern: Đảm bảo chỉ có một instance của lớp này được tạo ra
19
  def __new__(cls):
20
  if cls._instance is None:
21
  print("Khởi tạo HuggingFaceRAGService...")
@@ -27,67 +27,107 @@ class HuggingFaceRAGService:
27
  if self._initialized:
28
  return
29
 
30
- # Cấu hình
31
  self.MODEL_NAME = "all-MiniLM-L6-v2"
32
- self.DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
33
- self.FAISS_PATH = os.path.join(self.DATA_DIR, "heritage.faiss")
34
- self.METADATA_PATH = os.path.join(self.DATA_DIR, "metadata.json")
35
- self.IDS_PATH = os.path.join(self.DATA_DIR, "ids.json")
36
 
37
- # Tải model dữ liệu
 
 
 
 
 
 
 
 
 
 
38
  self._load_model()
39
  self._load_data()
 
40
  self._initialized = True
41
  print("✅ HuggingFaceRAGService đã sẵn sàng.")
42
 
43
  def _load_model(self):
44
- print(f"🤖 [HF RAG] Đang tải model: {self.MODEL_NAME}...")
45
  self.model = SentenceTransformer(self.MODEL_NAME)
46
 
47
  def _load_data(self):
48
- self.index, self.metadata, self.ids = self._load_cache()
49
- if self.index and self.metadata and self.ids:
50
- print(f"💾 [HF RAG] Sử dụng cache FAISS index metadata (items: {len(self.ids)})")
51
- else:
52
- print("💾 [HF RAG] Cache không tồn tại. Tải dataset và xây dựng FAISS index...")
53
- dataset = load_dataset("synguyen1106/vietnam_heritage_embeddings_v4", split="train")
54
- vectors = np.array(dataset['embedding']).astype("float32")
55
- self.metadata = [{k: v for k, v in dataset[i].items() if k not in ['embedding', 'id', 'slug']} for i in range(len(dataset))]
56
- self.ids = [dataset[i]['id'] for i in range(len(dataset))]
57
- print(f"💾 [HF RAG] Đã tải {len(self.ids)} mục từ dataset.")
 
 
 
 
 
58
 
59
- d = vectors.shape[1]
60
- self.index = faiss.IndexFlatL2(d)
61
- self.index.add(vectors)
62
- print("🔨 [HF RAG] Số lượng vector trong FAISS index:", self.index.ntotal)
 
 
 
 
 
 
63
 
64
- self._save_cache(self.index, self.metadata, self.ids)
65
- print(f"💾 [HF RAG] Đã lưu cache tại: {self.FAISS_PATH}")
66
-
67
- def _save_cache(self, faiss_index, metadata_list, ids_list):
68
- os.makedirs(self.DATA_DIR, exist_ok=True)
69
- faiss.write_index(faiss_index, self.FAISS_PATH)
70
- with open(self.METADATA_PATH, "w", encoding="utf-8") as f:
71
- json.dump(metadata_list, f, ensure_ascii=False)
72
- with open(self.IDS_PATH, "w", encoding="utf-8") as f:
73
- json.dump(ids_list, f, ensure_ascii=False)
74
-
75
- def _load_cache(self):
76
- if not (os.path.exists(self.FAISS_PATH) and os.path.exists(self.METADATA_PATH) and os.path.exists(self.IDS_PATH)):
77
- return None, None, None
78
- idx = faiss.read_index(self.FAISS_PATH)
79
- with open(self.METADATA_PATH, "r", encoding="utf-8") as f:
80
- meta = json.load(f)
81
- with open(self.IDS_PATH, "r", encoding="utf-8") as f:
82
- ids_local = json.load(f)
83
- return idx, meta, ids_local
 
 
 
 
 
 
 
 
 
 
84
 
85
  def search(self, query: str, k: int = 2) -> List[Dict[str, Any]]:
 
86
  query_vec = self.model.encode([query], convert_to_numpy=True).astype("float32")
87
- _, indices = self.index.search(query_vec, k)
88
- results = [{"metadata": self.metadata[int(idx)]} for idx in indices[0]]
 
 
 
 
 
 
 
 
 
 
 
 
89
  return results
90
-
91
  # ==============================================================================
92
  # HỆ THỐNG RAG 2: SỬ DỤNG LOCAL DISK DATASET
93
  # ==============================================================================
 
6
  from datasets import load_dataset, load_from_disk
7
  from huggingface_hub import snapshot_download
8
  from typing import List, Dict, Any, Optional
9
+ from huggingface_hub import hf_hub_download
10
  from helper import format_metadata_list_to_context
11
 
12
  # ==============================================================================
 
15
  class HuggingFaceRAGService:
16
  _instance: Optional['HuggingFaceRAGService'] = None
17
 
18
+ # Singleton Pattern
19
  def __new__(cls):
20
  if cls._instance is None:
21
  print("Khởi tạo HuggingFaceRAGService...")
 
27
  if self._initialized:
28
  return
29
 
30
+ # --- CẤU HÌNH ---
31
  self.MODEL_NAME = "all-MiniLM-L6-v2"
 
 
 
 
32
 
33
+ # ID của Repo trên Hugging Face chứa file index và data
34
+ # Bạn cần đảm bảo đã upload file .faiss và .json lên repo này (dạng Dataset hoặc Model)
35
+ self.HF_REPO_ID = "synguyen1106/vietnam_heritage_embeddings_v4"
36
+ self.HF_REPO_TYPE = "dataset" # Hoặc "model" hoặc "space" tùy nơi bạn để file
37
+
38
+ # Tên file trên repo HF
39
+ self.FILENAME_INDEX = "heritage.faiss"
40
+ self.FILENAME_META = "metadata.json"
41
+ # self.FILENAME_IDS = "ids.json" # Nếu bạn gộp vào metadata thì ko cần file này
42
+
43
+ # Load model & Data
44
  self._load_model()
45
  self._load_data()
46
+
47
  self._initialized = True
48
  print("✅ HuggingFaceRAGService đã sẵn sàng.")
49
 
50
  def _load_model(self):
51
+ print(f"🤖 [HF RAG] Đang tải model embedding: {self.MODEL_NAME}...")
52
  self.model = SentenceTransformer(self.MODEL_NAME)
53
 
54
  def _load_data(self):
55
+ """
56
+ Chiến lược:
57
+ 1. Cố gắng tải file index đã build sẵn từ Hugging Face (Nhanh, tránh lỗi LFS).
58
+ 2. Nếu không tìm thấy file trên HF, fallback về việc tải Dataset gốc và build lại index (Chậm hơn).
59
+ """
60
+ try:
61
+ print(f"⬇️ [HF RAG] Đang thử tải Index pre-built từ HF Hub: {self.HF_REPO_ID}...")
62
+
63
+ # 1. Tải file FAISS Index
64
+ # hf_hub_download sẽ tự xử caching LFS pointer
65
+ index_path = hf_hub_download(
66
+ repo_id=self.HF_REPO_ID,
67
+ filename=self.FILENAME_INDEX,
68
+ repo_type=self.HF_REPO_TYPE
69
+ )
70
 
71
+ # 2. Tải file Metadata
72
+ metadata_path = hf_hub_download(
73
+ repo_id=self.HF_REPO_ID,
74
+ filename=self.FILENAME_META,
75
+ repo_type=self.HF_REPO_TYPE
76
+ )
77
+
78
+ # 3. Load vào RAM
79
+ print(f"📂 [HF RAG] Đang đọc file index từ: {index_path}")
80
+ self.index = faiss.read_index(index_path)
81
 
82
+ with open(metadata_path, "r", encoding="utf-8") as f:
83
+ self.metadata = json.load(f)
84
+
85
+ print(f"✅ [HF RAG] Load thành công từ Cache HF! (Items: {self.index.ntotal})")
86
+
87
+ except Exception as e:
88
+ print(f"⚠️ [HF RAG] Không tải được pre-built index ({e}). \n🔄 Chuyển sang build từ Dataset gốc...")
89
+ self._build_from_dataset()
90
+
91
+ def _build_from_dataset(self):
92
+ """
93
+ Hàm fallback: Tải dataset thô và build index tại chỗ (Tốn RAM và CPU lúc khởi động)
94
+ """
95
+ print("💾 [HF RAG] Đang tải dataset và xây dựng FAISS index mới...")
96
+ dataset = load_dataset(self.HF_REPO_ID, split="train")
97
+
98
+ # Chuẩn bị vectors
99
+ vectors = np.array(dataset['embedding']).astype("float32")
100
+
101
+ # Chuẩn bị metadata (loại bỏ cột embedding để nhẹ RAM)
102
+ self.metadata = [{k: v for k, v in item.items() if k != 'embedding'} for item in dataset]
103
+
104
+ # Build Index
105
+ d = vectors.shape[1]
106
+ self.index = faiss.IndexFlatL2(d)
107
+ self.index.add(vectors)
108
+
109
+ print(f"🔨 [HF RAG] Đã build xong index. Số lượng vector: {self.index.ntotal}")
110
+
111
+ # Mẹo: Ở đây bạn có thể lưu file ra đĩa và upload ngược lên HF để lần sau dùng cách 1
112
 
113
  def search(self, query: str, k: int = 2) -> List[Dict[str, Any]]:
114
+ # Encode câu hỏi
115
  query_vec = self.model.encode([query], convert_to_numpy=True).astype("float32")
116
+
117
+ # Search FAISS
118
+ distances, indices = self.index.search(query_vec, k)
119
+
120
+ # Map kết quả
121
+ results = []
122
+ for i, idx in enumerate(indices[0]):
123
+ if idx != -1: # Kiểm tra nếu tìm thấy
124
+ item = {
125
+ "score": float(distances[0][i]), # Distance càng nhỏ càng giống (với L2)
126
+ "metadata": self.metadata[int(idx)]
127
+ }
128
+ results.append(item)
129
+
130
  return results
 
131
  # ==============================================================================
132
  # HỆ THỐNG RAG 2: SỬ DỤNG LOCAL DISK DATASET
133
  # ==============================================================================