T-Phong commited on
Commit
b21ec88
·
1 Parent(s): c952c4d

update code

Browse files
Files changed (6) hide show
  1. app.py +60 -167
  2. requirements_rag.txt +2 -1
  3. service/helper.py +1 -8
  4. service/rag.py +214 -190
  5. service/reranking.py +90 -94
  6. service/rewrite.py +122 -76
app.py CHANGED
@@ -1,249 +1,142 @@
1
  """
2
- REST API cho Vietnam Heritage RAG System
3
  """
4
  import uuid
5
- import json
6
  import os
7
  import sys
8
- from datetime import datetime
9
 
10
- import google.generativeai as genai
11
  from flask import Flask, request, jsonify
12
  from flask_cors import CORS
13
 
14
- # Add service directory to sys.path to allow imports
15
  sys.path.append(os.path.join(os.path.dirname(__file__), 'service'))
16
- from rewrite import QueryRewriter
 
17
 
18
  app = Flask(__name__)
19
  CORS(app)
20
- # Khởi tạo QueryRewriter (chứa ask_with_context)
21
- rewriter = QueryRewriter()
22
- islog = os.getenv('islog')
23
- metrics_log = [] # Lưu lại các lần đánh giá để dựng biểu đồ
24
-
25
- GENAI_API_KEY = os.getenv("GEMINI_API_KEY")
26
- if GENAI_API_KEY:
27
- genai.configure(api_key=GENAI_API_KEY)
28
 
 
 
29
 
30
- def _safe_json_parse(text):
31
- """Parse chuỗi JSON, cố gắng trích block {} đầu tiên nếu có thêm text."""
32
- try:
33
- return json.loads(text)
34
- except Exception:
35
- pass
36
-
37
- start = text.find("{")
38
- end = text.rfind("}")
39
- if start != -1 and end != -1 and end > start:
40
- try:
41
- return json.loads(text[start : end + 1])
42
- except Exception:
43
- return None
44
- return None
45
-
46
-
47
- def evaluate_answer_llm(question: str, answer: str, history_message):
48
- """Gọi LLM để chấm điểm mức liên quan, độ chính xác và mức độ lan man."""
49
- if not GENAI_API_KEY:
50
- return {
51
- "status": "skipped",
52
- "reason": "missing_gemini_api_key",
53
- }
54
 
55
- try:
56
- model = genai.GenerativeModel("gemini-2.5-flash")
57
- history_text = "\n".join([m.get("content", "") for m in history_message]) if history_message else ""
58
- prompt = (
59
- "You are an evaluator for a RAG chatbot."
60
- " Return JSON with keys: rag_relevance (0-1), answer_accuracy (0-1), hallucination (bool), notes (string)."
61
- " Evaluate strictly from question and answer (and chat history if provided)."
62
- " rag_relevance measures how well retrieved context seems relevant to the question."
63
- " answer_accuracy measures factual correctness and completeness."
64
- " hallucination is true if the answer includes unrelated, fabricated, or off-topic info."
65
- f"\nQuestion: {question}\nAnswer: {answer}\nHistory: {history_text}\nReturn JSON only."
66
- )
67
- resp = model.generate_content(prompt)
68
- parsed = _safe_json_parse(resp.text)
69
- if not parsed:
70
- raise ValueError("LLM did not return valid JSON")
71
-
72
- rag_rel = float(parsed.get("rag_relevance", 0))
73
- acc = float(parsed.get("answer_accuracy", 0))
74
- halluc = bool(parsed.get("hallucination", False))
75
-
76
- return {
77
- "status": "ok",
78
- "timestamp": datetime.utcnow().isoformat() + "Z",
79
- "rag_relevance": max(0.0, min(1.0, rag_rel)),
80
- "answer_accuracy": max(0.0, min(1.0, acc)),
81
- "hallucination": halluc,
82
- "notes": parsed.get("notes", "") or "",
83
- }
84
- except Exception as e:
85
- return {
86
- "status": "error",
87
- "error": str(e),
88
- }
89
 
90
  @app.route('/v1/chat/completions', methods=['POST'])
91
  def ask_api():
92
  """
93
- Main endpoint - Gọi ask_with_context
94
-
95
  Request body:
96
  {
97
- "question": "Câu hỏi của bạn"
 
 
98
  }
99
-
100
  Response:
101
  {
102
- "question": "Câu hỏi",
103
- "answer": "Câu trả lời từ RAG"
 
 
104
  }
105
- """
106
  try:
107
  data = request.get_json()
108
- all_messages = data.get("messages", [])
109
- history_message = all_messages[-6:-1]
110
- # if islog == "1":
111
- # for f in history_message:
112
- # print(f)
113
-
114
 
115
-
 
 
116
 
117
- question = all_messages[-1]["content"]
118
-
119
  if not question:
120
- return jsonify({
121
- "error": "'question' cannot be empty"
122
- }), 400
123
-
124
- # Gọi ask_with_context
125
- answer = rewriter.ask_with_context(question, history_message)
126
-
127
- # Đánh giá tự động bằng LLM
128
- # evaluation = evaluate_answer_llm(question, answer, history_message)
129
- # if evaluation:
130
- # metrics_log.append({
131
- # "question": question,
132
- # "answer": answer,
133
- # "evaluation": evaluation,
134
- # })
135
- # # Giữ kích thước log vừa phải để hiển thị biểu đồ
136
- # if len(metrics_log) > 200:
137
- # del metrics_log[:-200]
138
 
139
  return jsonify({
140
  "id": str(uuid.uuid4()),
141
- "object": "chat.completion",
142
- "choices": [
143
  {
144
- "index": 0,
145
- "message": {
146
- "role": "assistant",
147
- "content": answer
148
- },
149
- "finish_reason": "stop"
150
  }
151
  ],
152
- "evaluation": "evaluation"
153
  }), 200
154
-
155
  except Exception as e:
156
  return jsonify({
157
  "error": str(e),
158
  "status": "error"
159
  }), 500
160
 
 
161
  @app.route('/v1/models', methods=['GET'])
162
- def lstmodel():
163
  return jsonify({
164
  "object": "list",
165
  "data": [
166
- {"id": "Model-1", "object": "model", "owned_by": "owner"},
167
- {"id": "Model-2", "object": "model", "owned_by": "owner"}
168
- ]
169
  }), 200
170
-
171
 
172
  @app.route('/health', methods=['GET'])
173
  def health_check():
174
  """Health check endpoint"""
 
 
 
175
  return jsonify({
176
  "status": "healthy",
177
  "service": "Vietnam Heritage RAG API"
178
  }), 200
179
 
 
180
  @app.route('/', methods=['GET'])
181
  def home():
182
  """API documentation"""
183
  return jsonify({
184
  "message": "Vietnam Heritage AI REST API",
185
- "version": "1.0.0",
186
  "endpoints": {
187
- "POST /ask": {
188
- "description": "Ask a question about Vietnamese heritage",
189
- "body": {
190
- "question": "Your question here"
191
- }
192
- },
193
- "GET /health": "Health check endpoint",
194
- "GET /": "API documentation",
195
- "GET /lstmodel": "List available models"
196
  },
197
  "example": {
198
- "url": "/ask",
199
  "method": "POST",
200
  "body": {
201
- "question": "Nguyễn Trãi là ai?"
 
 
202
  }
203
  }
204
  }), 200
205
 
206
 
207
- @app.route('/metrics', methods=['GET'])
208
- def get_metrics():
209
- """Trả về log đánh giá để dựng biểu đồ ở frontend."""
210
- # Tính trung bình nhanh để tiện hiển thị
211
- rag_scores = [m["evaluation"].get("rag_relevance", 0) for m in metrics_log if m.get("evaluation", {}).get("status") == "ok"]
212
- acc_scores = [m["evaluation"].get("answer_accuracy", 0) for m in metrics_log if m.get("evaluation", {}).get("status") == "ok"]
213
- halluc_counts = [m["evaluation"].get("hallucination", False) for m in metrics_log if m.get("evaluation", {}).get("status") == "ok"]
214
-
215
- summary = {
216
- "total": len(metrics_log),
217
- "avg_rag_relevance": sum(rag_scores) / len(rag_scores) if rag_scores else 0,
218
- "avg_answer_accuracy": sum(acc_scores) / len(acc_scores) if acc_scores else 0,
219
- "hallucination_rate": (sum(1 for h in halluc_counts if h) / len(halluc_counts)) if halluc_counts else 0,
220
- }
221
-
222
- return jsonify({
223
- "summary": summary,
224
- "data": metrics_log,
225
- }), 200
226
-
227
- @app.route('/reset', methods=['POST'])
228
- def reset_history():
229
- """Reset conversation history"""
230
- global history
231
- history = []
232
- return jsonify({
233
- "message": "History reset successfully",
234
- "status": "success"
235
- }), 200
236
-
237
  if __name__ == '__main__':
238
  port = int(os.environ.get('PORT', 5000))
239
  print("=" * 60)
240
- print(f"🚀 Vietnam Heritage RAG API")
241
  print("=" * 60)
242
  print(f"📍 Server: http://localhost:{port}")
243
  print(f"📝 Endpoints:")
244
- print(f" POST http://localhost:{port}/ask")
245
  print(f" GET http://localhost:{port}/health")
246
  print(f" GET http://localhost:{port}/")
247
  print("=" * 60)
248
-
249
- app.run(host='0.0.0.0', port=port, debug=True)
 
1
  """
2
+ REST API cho Vietnam Heritage RAG System
3
  """
4
  import uuid
 
5
  import os
6
  import sys
 
7
 
 
8
  from flask import Flask, request, jsonify
9
  from flask_cors import CORS
10
 
11
+ # Thêm thư mục service vào sys.path
12
  sys.path.append(os.path.join(os.path.dirname(__file__), 'service'))
13
+ from service.reranking import advanced_search
14
+ from service.rewrite import QueryRewriter
15
 
16
  app = Flask(__name__)
17
  CORS(app)
 
 
 
 
 
 
 
 
18
 
19
+ # Khởi tạo QueryRewriter (chứa toàn bộ pipeline RAG)
20
+ rewriter = QueryRewriter()
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # ==============================================================================
24
+ # ENDPOINTS
25
+ # ==============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  @app.route('/v1/chat/completions', methods=['POST'])
28
  def ask_api():
29
  """
30
+ Main endpoint - Chat với AI về văn hoá Việt Nam.
31
+
32
  Request body:
33
  {
34
+ "messages": [
35
+ {"role": "user", "content": "Câu hỏi của bạn"}
36
+ ]
37
  }
38
+
39
  Response:
40
  {
41
+ "id": "uuid",
42
+ "object": "chat.completion",
43
+ "choices": [{"index": 0, "message": {"role": "assistant", "content": "..."}, "finish_reason": "stop"}],
44
+ "image_url": "https://..." hoặc null
45
  }
46
+ """
47
  try:
48
  data = request.get_json()
49
+ if not data:
50
+ return jsonify({"error": "Request body is required"}), 400
 
 
 
 
51
 
52
+ all_messages = data.get("messages", [])
53
+ if not all_messages:
54
+ return jsonify({"error": "'messages' cannot be empty"}), 400
55
 
56
+ question = all_messages[-1].get("content", "").strip()
 
57
  if not question:
58
+ return jsonify({"error": "Last message content cannot be empty"}), 400
59
+
60
+ # Lấy tối đa 3 lượt hội thoại gần nhất (không tính message hiện tại)
61
+ history_message = all_messages[-7:-1]
62
+
63
+ # Pipeline RAG → trả về (answer_text, image_url | None)
64
+ answer, image_url = rewriter.ask_with_context(question, history_message)
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  return jsonify({
67
  "id": str(uuid.uuid4()),
68
+ "message": [
 
69
  {
70
+ "role": "assistant",
71
+ "content": answer
 
 
 
 
72
  }
73
  ],
74
+ "image_url": image_url
75
  }), 200
76
+
77
  except Exception as e:
78
  return jsonify({
79
  "error": str(e),
80
  "status": "error"
81
  }), 500
82
 
83
+
84
  @app.route('/v1/models', methods=['GET'])
85
+ def list_models():
86
  return jsonify({
87
  "object": "list",
88
  "data": [
89
+ {"id": "vietnam-heritage-rag-v1", "object": "model", "owned_by": "culturebot"}
90
+ ]
 
91
  }), 200
92
+
93
 
94
  @app.route('/health', methods=['GET'])
95
  def health_check():
96
  """Health check endpoint"""
97
+ question = "Người Ê Đê là một dân tộc thiểu số tại Việt Nam, chủ yếu sống ở vùng Tây Nguyên. Văn hóa của người Ê Đê được đặc trưng bởi các yếu tố sau: - **Ngôn n"
98
+ history_message = ['người Ê Đê', 'văn hoá']
99
+ advanced_search(question, history_message)
100
  return jsonify({
101
  "status": "healthy",
102
  "service": "Vietnam Heritage RAG API"
103
  }), 200
104
 
105
+
106
  @app.route('/', methods=['GET'])
107
  def home():
108
  """API documentation"""
109
  return jsonify({
110
  "message": "Vietnam Heritage AI REST API",
111
+ "version": "2.0.0",
112
  "endpoints": {
113
+ "POST /v1/chat/completions": "Chat với AI về văn hoá Việt Nam",
114
+ "GET /v1/models": "Danh sách model",
115
+ "GET /health": "Health check",
116
+ "GET /": "API documentation"
 
 
 
 
 
117
  },
118
  "example": {
119
+ "url": "/v1/chat/completions",
120
  "method": "POST",
121
  "body": {
122
+ "messages": [
123
+ {"role": "user", "content": "Giới thiệu về Vịnh Hạ Long"}
124
+ ]
125
  }
126
  }
127
  }), 200
128
 
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  if __name__ == '__main__':
131
  port = int(os.environ.get('PORT', 5000))
132
  print("=" * 60)
133
+ print("🚀 Vietnam Heritage RAG API")
134
  print("=" * 60)
135
  print(f"📍 Server: http://localhost:{port}")
136
  print(f"📝 Endpoints:")
137
+ print(f" POST http://localhost:{port}/v1/chat/completions")
138
  print(f" GET http://localhost:{port}/health")
139
  print(f" GET http://localhost:{port}/")
140
  print("=" * 60)
141
+
142
+ app.run(host='0.0.0.0', port=port, debug=False)
requirements_rag.txt CHANGED
@@ -12,4 +12,5 @@ gunicorn>=20.1.0
12
  flask-cors>=3.0.10
13
  google-generativeai>=0.3.3
14
  bitsandbytes>=0.39.1
15
- accelerate>=0.30.3
 
 
12
  flask-cors>=3.0.10
13
  google-generativeai>=0.3.3
14
  bitsandbytes>=0.39.1
15
+ accelerate>=0.30.3
16
+ rank_bm25>=0.2.1
service/helper.py CHANGED
@@ -18,7 +18,7 @@ def format_metadata_list_to_context(search_results: List[Dict[str, Any]]) -> str
18
 
19
  # 1. Trích xuất dữ liệu (Dùng .get để tránh lỗi nếu thiếu trường)
20
  ten = data.get('ten') or data.get('group', 'Không rõ tên')
21
- mo_ta = data.get('mo_ta') or data.get('content', '')
22
 
23
  # Nhóm thông tin phân loại
24
  loai_hinh = data.get('loai_hinh', 'N/A')
@@ -40,13 +40,6 @@ def format_metadata_list_to_context(search_results: List[Dict[str, Any]]) -> str
40
  [TỔNG QUAN]
41
  Tên: {ten}
42
  Mô tả/Nội dung: {mo_ta}
43
-
44
- [THÔNG TIN CHI TIẾT]
45
- - Phân loại: {loai_hinh} (Chủ đề: {chu_de})
46
- - Dân tộc: {dan_toc}
47
- - Thời gian: {nien_dai} ({thoi_ky})
48
- - Địa danh/Vùng miền: {vung_mien} - {dia_diem}
49
- - Chất liệu: {chat_lieu} - Nguyên liệu chính: {nguyen_lieu_chinh}
50
  """
51
 
52
  # 3. Ghép vào chuỗi tổng
 
18
 
19
  # 1. Trích xuất dữ liệu (Dùng .get để tránh lỗi nếu thiếu trường)
20
  ten = data.get('ten') or data.get('group', 'Không rõ tên')
21
+ mo_ta = data.get('mo_ta') or data.get('combined_text') or data.get('original_content') or data.get('content', '')
22
 
23
  # Nhóm thông tin phân loại
24
  loai_hinh = data.get('loai_hinh', 'N/A')
 
40
  [TỔNG QUAN]
41
  Tên: {ten}
42
  Mô tả/Nội dung: {mo_ta}
 
 
 
 
 
 
 
43
  """
44
 
45
  # 3. Ghép vào chuỗi tổng
service/rag.py CHANGED
@@ -1,244 +1,268 @@
1
  import os
2
  import json
 
3
  import numpy as np
4
  import faiss
5
  from sentence_transformers import SentenceTransformer
6
- from datasets import load_dataset, load_from_disk
7
- from huggingface_hub import snapshot_download
8
  from typing import List, Dict, Any, Optional
9
- from huggingface_hub import hf_hub_download
10
- from helper import format_metadata_list_to_context
11
 
12
- # ==============================================================================
13
- # HỆ THỐNG RAG 1: SỬ DỤNG HUGGING FACE DATASET
14
- # ==============================================================================
15
- class HuggingFaceRAGService:
16
- _instance: Optional['HuggingFaceRAGService'] = None
17
-
18
- # Singleton Pattern
19
- def __new__(cls):
20
- if cls._instance is None:
21
- print("Khởi tạo HuggingFaceRAGService...")
22
- cls._instance = super(HuggingFaceRAGService, cls).__new__(cls)
23
- cls._instance._initialized = False
24
- return cls._instance
25
 
26
- def __init__(self):
27
- if self._initialized:
28
- return
29
-
30
- # --- CẤU HÌNH ---
31
- self.MODEL_NAME = "all-MiniLM-L6-v2"
32
-
33
- # ID của Repo trên Hugging Face chứa file index và data
34
- # Bạn cần đảm bảo đã upload file .faiss và .json lên repo này (dạng Dataset hoặc Model)
35
- self.HF_REPO_ID = "synguyen1106/vietnam_heritage_embeddings_v4"
36
- self.HF_REPO_TYPE = "dataset" # Hoặc "model" hoặc "space" tùy nơi bạn để file
37
 
38
- # Tên file trên repo HF
39
- self.FILENAME_INDEX = "heritage.faiss"
40
- self.FILENAME_META = "metadata.json"
41
- # self.FILENAME_IDS = "ids.json" # Nếu bạn gộp vào metadata thì ko cần file này
42
 
43
- # Load model & Data
44
- self._load_model()
45
  self._load_data()
46
-
47
- self._initialized = True
48
- print("✅ HuggingFaceRAGService đã sẵn sàng.")
49
 
50
- def _load_model(self):
51
- print(f"🤖 [HF RAG] Đang tải model embedding: {self.MODEL_NAME}...")
52
- self.model = SentenceTransformer(self.MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  def _load_data(self):
55
- """
56
- Chiến lược:
57
- 1. Cố gắng tải file index đã build sẵn từ Hugging Face (Nhanh, tránh lỗi LFS).
58
- 2. Nếu không tìm thấy file trên HF, fallback về việc tải Dataset gốc và build lại index (Chậm hơn).
59
- """
60
  try:
61
- print(f"⬇️ [HF RAG] Đang thử tải Index pre-built từ HF Hub: {self.HF_REPO_ID}...")
62
 
63
- # 1. Tải file FAISS Index
64
- # hf_hub_download sẽ tự xử lý caching và LFS pointer
65
- index_path = hf_hub_download(
66
- repo_id=self.HF_REPO_ID,
67
- filename=self.FILENAME_INDEX,
68
- repo_type=self.HF_REPO_TYPE
69
- )
70
 
71
- # 2. Tải file Metadata
72
- metadata_path = hf_hub_download(
73
- repo_id=self.HF_REPO_ID,
74
- filename=self.FILENAME_META,
75
- repo_type=self.HF_REPO_TYPE
76
- )
77
-
78
- # 3. Load vào RAM
79
- print(f"📂 [HF RAG] Đang đọc file index từ: {index_path}")
80
- self.index = faiss.read_index(index_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- with open(metadata_path, "r", encoding="utf-8") as f:
83
- self.metadata = json.load(f)
 
84
 
85
- print(f"✅ [HF RAG] Load thành công từ Cache HF! (Items: {self.index.ntotal})")
 
 
 
 
 
 
 
86
 
87
  except Exception as e:
88
- print(f"⚠️ [HF RAG] Không tải được pre-built index ({e}). \n🔄 Chuyển sang build từ Dataset gốc...")
89
- self._build_from_dataset()
90
-
91
- def _build_from_dataset(self):
92
- """
93
- Hàm fallback: Tải dataset thô và build index tại chỗ (Tốn RAM và CPU lúc khởi động)
94
- """
95
- print("💾 [HF RAG] Đang tải dataset và xây dựng FAISS index mới...")
96
- dataset = load_dataset(self.HF_REPO_ID, split="train")
97
-
98
- # Chuẩn bị vectors
99
- vectors = np.array(dataset['embedding']).astype("float32")
100
-
101
- # Chuẩn bị metadata (loại bỏ cột embedding để nhẹ RAM)
102
- self.metadata = [{k: v for k, v in item.items() if k != 'embedding'} for item in dataset]
103
-
104
- # Build Index
105
  d = vectors.shape[1]
106
  self.index = faiss.IndexFlatL2(d)
107
  self.index.add(vectors)
108
-
109
- print(f"🔨 [HF RAG] Đã build xong index. Số lượng vector: {self.index.ntotal}")
110
-
111
- # Mẹo: Ở đây bạn có thể lưu file ra đĩa và upload ngược lên HF để lần sau dùng cách 1
112
 
113
- def search(self, query: str, k: int = 2) -> List[Dict[str, Any]]:
114
- # Encode câu hỏi
115
- query_vec = self.model.encode([query], convert_to_numpy=True).astype("float32")
116
-
117
- # Search FAISS
118
- distances, indices = self.index.search(query_vec, k)
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
- # Map kết quả
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  results = []
122
- for i, idx in enumerate(indices[0]):
123
- if idx != -1: # Kiểm tra nếu tìm thấy
124
- item = {
125
- "score": float(distances[0][i]), # Distance càng nhỏ càng giống (với L2)
126
- "metadata": self.metadata[int(idx)]
127
- }
128
- results.append(item)
129
-
130
  return results
131
- # ==============================================================================
132
- # HỆ THỐNG RAG 2: SỬ DỤNG LOCAL DISK DATASET
133
- # ==============================================================================
134
- class LocalDiskRAGService:
135
- _instance: Optional['LocalDiskRAGService'] = None
136
 
 
 
 
 
137
  def __new__(cls):
138
  if cls._instance is None:
139
- print("\nKhởi tạo LocalDiskRAGService...")
140
- cls._instance = super(LocalDiskRAGService, cls).__new__(cls)
141
  cls._instance._initialized = False
142
  return cls._instance
143
 
144
  def __init__(self):
145
  if self._initialized:
146
  return
 
 
 
 
 
 
147
 
148
- # Cấu hình
149
- self.MODEL_NAME = 'AITeamVN/Vietnamese_Embedding_v2'
150
- # Thay đổi từ đường dẫn local sang ID của dataset trên Hugging Face Hub
151
- self.DATASET_ID = "phongnt251199/Wiki_Culture_Vec"
152
- self.MIN_CONTENT_LENGTH = 200
153
- self.CANDIDATE_MULTIPLIER = 5
154
 
155
- # Tải model và dữ liệu
156
- self._load_model()
157
- self._load_data()
 
 
158
  self._initialized = True
159
- print("✅ LocalDiskRAGService đã sẵn sàng.")
160
-
161
- def _load_model(self):
162
- print(f"🤖 [Local RAG] Đang tải model AI: {self.MODEL_NAME}...")
163
- self.model = SentenceTransformer(self.MODEL_NAME)
164
 
165
- def _load_data(self):
166
- print(f"💾 [Local RAG] Đang tải dữ liệu từ Hugging Face Hub: {self.DATASET_ID}...")
167
- try:
168
- # Tải toàn bộ dataset về và lấy đường dẫn local
169
- # Hugging Face Spaces sẽ tự động sử dụng token trong secrets nếu repo là private
170
- dataset_path = snapshot_download(repo_id=self.DATASET_ID, repo_type="dataset")
171
-
172
- self.dataset = load_from_disk(dataset_path)
173
- print(f"💾 [Local RAG] Load xong! Tổng số dữ liệu: {len(self.dataset)} dòng.")
174
-
175
- print("🔨 [Local RAG] Đang kích hoạt bộ tìm kiếm (Re-indexing)...")
176
- self.dataset.add_faiss_index(column="embeddings")
177
- print("🔨 [Local RAG] Đã kích hoạt xong FAISS Index!")
178
- except Exception as e:
179
- print(f"❌ Lỗi: Không thể tải dataset từ Hub. Lỗi: {e}")
180
- self.dataset = None
181
- return
182
 
183
  def search(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
184
- if not self.dataset:
185
- return []
186
-
187
- # print(f"\n🔎 [Local RAG] Đang tìm: '{query}'")
188
- # print("-" * 50)
189
-
190
- query_vector = self.model.encode(query)
191
- candidate_k = top_k * self.CANDIDATE_MULTIPLIER
192
- scores, samples = self.dataset.get_nearest_examples("embeddings", query_vector, k=candidate_k)
193
-
194
- results = []
195
- for i in range(len(samples['original_content'])):
196
- if len(results) >= top_k:
197
- break
198
-
199
- content = samples['original_content'][i]
200
- if len(content) < self.MIN_CONTENT_LENGTH:
201
- continue
202
-
203
- score = scores[i]
204
- metadata = samples['metadata'][i]
205
- metadata['content'] = content
206
-
207
- results.append({
208
- "metadata": metadata,
209
- "score": score
210
- })
211
 
212
- # In ra console để debug như hàm gốc
213
- # print(f"Top {len(results)} (Độ sai lệch: {score:.2f}):")
214
- # print(f"Nội dung: {content[:200]}...")
215
- # print("-" * 50)
216
-
217
- if not results:
218
- print(f"Không tìm thấy kết quả nào có nội dung dài hơn {self.MIN_CONTENT_LENGTH} ký tự.")
219
-
220
- return results
221
 
222
  # ==============================================================================
223
- # KHỞI TẠO SERVICE VÀ CUNG CẤP CÁC HÀM GỐC
224
  # ==============================================================================
225
- hf_rag_service = HuggingFaceRAGService()
226
- local_rag_service = LocalDiskRAGService()
 
 
227
 
228
  def retrieve_context(query: str, k: int = 2) -> str:
229
- """
230
- Tìm kiếm ngữ cảnh sử dụng hệ thống RAG từ Hugging Face.
231
- (Giữ nguyên hàm gốc để tương thích)
232
- """
233
- print("\n>>> Sử dụng hệ thống RAG 1 (HuggingFace)...")
234
- results = hf_rag_service.search(query, k)
235
  return format_metadata_list_to_context(results)
236
 
237
  def search_heritage(query: str, top_k: int = 3) -> str:
238
- """
239
- Tìm kiếm di sản sử dụng hệ thống RAG từ ổ đĩa cục bộ.
240
- (Giữ nguyên hàm gốc để tương thích)
241
- """
242
- print("\n>>> Sử dụng hệ thống RAG 2 (Local Disk)...")
243
  results = local_rag_service.search(query, top_k)
244
  return format_metadata_list_to_context(results)
 
1
  import os
2
  import json
3
+ import ast
4
  import numpy as np
5
  import faiss
6
  from sentence_transformers import SentenceTransformer
7
+ from datasets import load_dataset
 
8
  from typing import List, Dict, Any, Optional
9
+ from .helper import format_metadata_list_to_context
 
10
 
11
+ try:
12
+ from rank_bm25 import BM25Okapi
13
+ except ImportError:
14
+ BM25Okapi = None
15
+ print("Cảnh báo: Thư viện rank_bm25 chưa được cài đặt, keyword search có thể bị ảnh hưởng.")
 
 
 
 
 
 
 
 
16
 
17
+ class SingleDatasetRAGService:
18
+ """Xử lý load dữ liệu, tạo vector và tìm kiếm FAISS+BM25 cho một dataset cụ thể."""
19
+ def __init__(self, model: SentenceTransformer, dataset_id: str, faiss_path: str):
20
+ self.model = model
21
+ self.DATASET_ID = dataset_id
22
+ self.FAISS_INDEX_PATH = faiss_path
 
 
 
 
 
23
 
24
+ self.dataset_records = []
25
+ self.image_lookup = {}
26
+ self.bm25_index = None
27
+ self.index = None
28
 
 
 
29
  self._load_data()
30
+ print(f"✅ SingleDatasetRAGService ({self.DATASET_ID}) đã sẵn sàng.")
 
 
31
 
32
+ def _parse_image_url(self, raw) -> str:
33
+ """Parse image_urls list hay string JSON, trả về URL đầu tiên hợp lệ."""
34
+ if isinstance(raw, list):
35
+ for url in raw:
36
+ if isinstance(url, str) and url.startswith('http'):
37
+ return url
38
+ elif isinstance(raw, str) and raw.strip():
39
+ raw = raw.strip()
40
+ if raw.startswith('http'):
41
+ return raw
42
+ try:
43
+ parsed = ast.literal_eval(raw)
44
+ if isinstance(parsed, list):
45
+ for url in parsed:
46
+ if isinstance(url, str) and url.startswith('http'):
47
+ return url
48
+ except Exception:
49
+ pass
50
+ return None
51
 
52
  def _load_data(self):
53
+ print(f"💾 [RAG {self.DATASET_ID}] Đang tải dataset...")
 
 
 
 
54
  try:
55
+ dataset = load_dataset(self.DATASET_ID, split="train")
56
 
57
+ tokenized_corpus = []
58
+ texts_for_embedding = []
 
 
 
 
 
59
 
60
+ count_img = 0
61
+ for item in dataset:
62
+ group = item.get('group', '')
63
+ if not group:
64
+ continue
65
+
66
+ group_str = str(group).strip()
67
+ group_lower = group_str.lower()
68
+
69
+ # 1. Image lookup
70
+ imgs = item.get('image_urls', [])
71
+ if group_lower not in self.image_lookup:
72
+ url = self._parse_image_url(imgs)
73
+ if url:
74
+ self.image_lookup[group_lower] = url
75
+ count_img += 1
76
+
77
+ # 2. Extract metadata
78
+ meta_raw = item.get('metadata', {})
79
+ if isinstance(meta_raw, str):
80
+ try:
81
+ meta_dict = ast.literal_eval(meta_raw)
82
+ except:
83
+ meta_dict = {}
84
+ else:
85
+ meta_dict = meta_raw if isinstance(meta_raw, dict) else {}
86
+
87
+ # Dùng combined_text nếu có
88
+ content = str(item.get('combined_text', ''))[:3000]
89
+
90
+ record = {
91
+ "group": group_str,
92
+ "combined_text": content,
93
+ "original_content": str(item.get('original_content', ''))[:3000],
94
+ "image_urls": item.get('image_urls', []),
95
+ "dataset_source": self.DATASET_ID, # Đánh dấu nguồn
96
+ **meta_dict
97
+ }
98
+
99
+ search_text = f"{group_str}. {content}"
100
+
101
+ self.dataset_records.append(record)
102
+ texts_for_embedding.append(search_text)
103
+
104
+ tokenized_corpus.append(search_text.lower().split())
105
+
106
+ print(f"🖼️ [RAG {self.DATASET_ID}] Image lookup built: {count_img} entries.")
107
 
108
+ if BM25Okapi is not None and tokenized_corpus:
109
+ print(f"🔍 [RAG {self.DATASET_ID}] Build BM25 cho {len(tokenized_corpus)} records...")
110
+ self.bm25_index = BM25Okapi(tokenized_corpus)
111
 
112
+ if os.path.exists(self.FAISS_INDEX_PATH):
113
+ print(f"📂 [RAG {self.DATASET_ID}] Đọc file FAISS từ: {self.FAISS_INDEX_PATH}")
114
+ self.index = faiss.read_index(self.FAISS_INDEX_PATH)
115
+ if self.index.ntotal != len(self.dataset_records):
116
+ print(f"⚠️ [RAG {self.DATASET_ID}] Kích thước FAISS không khớp, build lại...")
117
+ self._build_faiss(texts_for_embedding)
118
+ else:
119
+ self._build_faiss(texts_for_embedding)
120
 
121
  except Exception as e:
122
+ import traceback
123
+ traceback.print_exc()
124
+ print(f"❌ [RAG {self.DATASET_ID}] Lỗi load dataset: {e}")
125
+
126
+ def _build_faiss(self, texts: List[str]):
127
+ print(f"🔨 [RAG {self.DATASET_ID}] Đang embed {len(texts)} văn bản...")
128
+ vectors = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=True).astype("float32")
 
 
 
 
 
 
 
 
 
 
129
  d = vectors.shape[1]
130
  self.index = faiss.IndexFlatL2(d)
131
  self.index.add(vectors)
132
+ faiss.write_index(self.index, self.FAISS_INDEX_PATH)
133
+ print(f" [RAG {self.DATASET_ID}] Build lưu FAISS thành công.")
 
 
134
 
135
+ def get_image_for_topic(self, topic_name: str) -> str:
136
+ if not topic_name or not self.image_lookup:
137
+ return None
138
+ topic_lower = topic_name.strip().lower()
139
+ if topic_lower in self.image_lookup:
140
+ return self.image_lookup[topic_lower]
141
+ for group_key, url in self.image_lookup.items():
142
+ if topic_lower in group_key or group_key in topic_lower:
143
+ return url
144
+ return None
145
+
146
+ def search(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
147
+ if not self.dataset_records:
148
+ return []
149
+
150
+ faiss_scores = {}
151
+ bm25_scores = {}
152
+ fetch_k = min(len(self.dataset_records), 60)
153
 
154
+ try:
155
+ if self.index:
156
+ query_vec = self.model.encode([query], convert_to_numpy=True).astype("float32")
157
+ distances, indices = self.index.search(query_vec, fetch_k)
158
+ for rank, idx in enumerate(indices[0]):
159
+ idx_int = int(idx)
160
+ if 0 <= idx_int < len(self.dataset_records):
161
+ faiss_scores[idx_int] = rank + 1
162
+ except Exception as e:
163
+ print(f"❌ Lỗi FAISS ({self.DATASET_ID}): {e}")
164
+
165
+ try:
166
+ if self.bm25_index:
167
+ tokenized_query = query.lower().split()
168
+ scores = self.bm25_index.get_scores(tokenized_query)
169
+ top_indexes = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:fetch_k]
170
+ for rank, idx in enumerate(top_indexes):
171
+ if scores[idx] > 0:
172
+ bm25_scores[idx] = rank + 1
173
+ except Exception as e:
174
+ print(f"❌ Lỗi BM25 ({self.DATASET_ID}): {e}")
175
+
176
+ k_rrf = 60
177
+ rrf_scores = {}
178
+ all_indices = set(list(faiss_scores.keys()) + list(bm25_scores.keys()))
179
+ for idx in all_indices:
180
+ score = 0.0
181
+ if idx in faiss_scores:
182
+ score += 1.0 / (k_rrf + faiss_scores[idx])
183
+ if idx in bm25_scores:
184
+ score += 1.0 / (k_rrf + bm25_scores[idx])
185
+ rrf_scores[idx] = score
186
+
187
+ sorted_indices = sorted(rrf_scores.keys(), key=lambda idx: rrf_scores[idx], reverse=True)
188
  results = []
189
+ for idx in sorted_indices[:top_k]:
190
+ item = {
191
+ "score": round(rrf_scores[idx], 4),
192
+ "metadata": self.dataset_records[idx]
193
+ }
194
+ results.append(item)
195
+
 
196
  return results
 
 
 
 
 
197
 
198
+ class MultiDatasetRAGService:
199
+ """Service tổng quản lý model Embedding chung và gọi search qua nhiều Dataset độc lập."""
200
+ _instance: Optional['MultiDatasetRAGService'] = None
201
+
202
  def __new__(cls):
203
  if cls._instance is None:
204
+ print("Khởi tạo MultiDatasetRAGService...")
205
+ cls._instance = super(MultiDatasetRAGService, cls).__new__(cls)
206
  cls._instance._initialized = False
207
  return cls._instance
208
 
209
  def __init__(self):
210
  if self._initialized:
211
  return
212
+
213
+ # Dùng model tiếng Việt nhẹ thay cho all-MiniLM-L6-v2 để chuẩn hóa
214
+ # chung cho cả lấy Reranking (do model 1 chuyên tiếng Việt)
215
+ self.MODEL_NAME = "keepitreal/vietnamese-sbert"
216
+ print(f"🤖 [Multi RAG] Đang tải chung model embedding: {self.MODEL_NAME}...")
217
+ self.model = SentenceTransformer(self.MODEL_NAME)
218
 
219
+ # Danh sách Dataset
220
+ # Lưu ý: Cập nhật tên FAISS file để tránh load nhầm file cũ không cùng dimension
221
+ self.datasets_config = [
222
+ {"id": "phongnt251199/vietnam_heritage_v3", "faiss": "heritage_v3_visbert.faiss"},
223
+ {"id": "phongnt251199/vietnam_heritage_wiki_chunks_v1", "faiss": "wiki_chunks_v1_visbert.faiss"}
224
+ ]
225
 
226
+ self.services = []
227
+ for cfg in self.datasets_config:
228
+ svc = SingleDatasetRAGService(self.model, cfg["id"], cfg["faiss"])
229
+ self.services.append(svc)
230
+
231
  self._initialized = True
232
+ print(f"✅ MultiDatasetRAGService đã sẵn sàng xử lý {len(self.services)} datasets.")
 
 
 
 
233
 
234
+ def get_image_for_topic(self, topic_name: str) -> str:
235
+ for svc in self.services:
236
+ url = svc.get_image_for_topic(topic_name)
237
+ if url:
238
+ return url
239
+ return None
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  def search(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
242
+ all_results = []
243
+ # Lấy từ cả 2 nguồn, sau đó sort chung dựa trên Score RRF
244
+ for svc in self.services:
245
+ res = svc.search(query, top_k=top_k * 2) # Lấy dư để merge
246
+ all_results.extend(res)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
+ # Sort tổng hợp lại theo điểm tổng
249
+ sorted_results = sorted(all_results, key=lambda x: x['score'], reverse=True)
250
+ return sorted_results[:top_k]
 
 
 
 
 
 
251
 
252
  # ==============================================================================
253
+ # KHỞI TẠO SERVICE (SINGLETON CỦA MULTI) VÀ CUNG CẤP CÁC HÀM GỐC TƯƠNG THÍCH
254
  # ==============================================================================
255
+ local_rag_service = MultiDatasetRAGService()
256
+
257
+ # Maintain các object cũ để tương thích với các module khác
258
+ hf_rag_service = local_rag_service
259
 
260
  def retrieve_context(query: str, k: int = 2) -> str:
261
+ print(f"\n>>> [Multi RAG] Tìm kiếm: {query}")
262
+ results = local_rag_service.search(query, k)
 
 
 
 
263
  return format_metadata_list_to_context(results)
264
 
265
  def search_heritage(query: str, top_k: int = 3) -> str:
266
+ print(f"\n>>> [Multi RAG] Tìm kiếm: {query}")
 
 
 
 
267
  results = local_rag_service.search(query, top_k)
268
  return format_metadata_list_to_context(results)
service/reranking.py CHANGED
@@ -1,13 +1,9 @@
1
- import ast
2
  import concurrent.futures
3
- from sentence_transformers import CrossEncoder
 
4
 
5
- from helper import format_metadata_list_to_context
6
- from rag import hf_rag_service, local_rag_service
7
-
8
- # Load model Reranker (nhẹ, chạy CPU được)
9
- #reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
10
- reranker = CrossEncoder('Alibaba-NLP/gte-multilingual-reranker-base', trust_remote_code=True)
11
 
12
  def reciprocal_rank_fusion(search_results_lists: list, k_rrf: int = 60):
13
  """
@@ -42,11 +38,11 @@ def reciprocal_rank_fusion(search_results_lists: list, k_rrf: int = 60):
42
 
43
  def advanced_search(query, keyword):
44
  """
45
- Tìm kiếm nâng cao kết hợp cả 2 nguồn dữ liệu, rerank định dạng riêng biệt.
 
46
  """
47
  try:
48
- result_wiki = []
49
- result_hf = []
50
 
51
  # 1. Tạo danh sách các cụm từ tìm kiếm
52
  search_terms = [query]
@@ -55,107 +51,107 @@ def advanced_search(query, keyword):
55
  elif isinstance(keyword, str) and keyword:
56
  search_terms.append(keyword)
57
 
58
- # Tối ưu: Tìm kiếm song song trên cả 2 nguồn với mỗi cụm từ
59
  with concurrent.futures.ThreadPoolExecutor() as executor:
60
- future_to_source = {}
61
  for r in search_terms:
62
- # HF RAG
63
- future_hf = executor.submit(hf_rag_service.search, r.lower(), k=15)
64
- future_to_source[future_hf] = 'hf'
65
- # Local RAG
66
- future_local = executor.submit(local_rag_service.search, r.lower(), top_k=15)
67
- future_to_source[future_local] = 'local'
68
 
69
- for future in concurrent.futures.as_completed(future_to_source):
70
- source = future_to_source[future]
71
  try:
72
  docs = future.result()
73
  if docs:
74
- if source == 'hf':
75
- result_hf.append(docs)
76
- else:
77
- result_wiki.append(docs)
78
  except Exception as e:
79
- print(f"Lỗi tìm kiếm {source}: {e}")
80
 
81
- if not result_hf and not result_wiki:
82
- return "Không tìm thấy thông tin phù hợp."
83
 
84
- # 2. Kết hợp kết quả từ các lần tìm kiếm (Fusion) Rerank cho từng nguồn
85
- # Xử lý nguồn HF RAG
86
- fused_results_hf = reciprocal_rank_fusion(result_hf)
87
- # Tối ưu: Giảm số lượng ứng viên rerank xuống 15 để tăng tốc độ
88
- candidates_for_rerank_hf = fused_results_hf[:15]
89
- pairs_to_score_hf = []
90
- for item in candidates_for_rerank_hf:
91
- meta = item['metadata']
92
- name = meta.get('ten','')
93
- desc = meta.get('mo_ta','')
94
- loai_hinh = meta.get('loai_hinh', '')
95
- chu_de = meta.get('chu_de', '')
96
- y_nghia = meta.get('y_nghia', '')
97
- constructed_text = f"Tên: {name}. Loại hình: {loai_hinh}. Chủ đề: {chu_de}. Mô tả: {desc}. Ý nghĩa: {y_nghia}"
98
- pairs_to_score_hf.append([query, constructed_text])
99
 
100
- sorted_docs_hf = []
101
- if pairs_to_score_hf:
102
- scores_hf = reranker.predict(pairs_to_score_hf)
103
- for i, doc in enumerate(candidates_for_rerank_hf):
104
- doc['rerank_score'] = scores_hf[i]
105
- sorted_docs_hf = sorted(candidates_for_rerank_hf, key=lambda x: x['rerank_score'], reverse=True)
106
-
107
- # Xử lý nguồn Local (Wiki) RAG
108
- fused_results_wiki = reciprocal_rank_fusion(result_wiki)
109
- # Tối ưu: Giảm số lượng ứng viên rerank xuống 40
110
- candidates_for_rerank_wiki = fused_results_wiki[:40]
111
- print("candidates_for_rerank_wiki:", len(candidates_for_rerank_wiki))
112
- pairs_to_score_wiki = []
113
- for item in candidates_for_rerank_wiki:
114
  meta = item['metadata']
115
- name = meta.get('group','')
116
- desc = meta.get('content','')
 
117
  constructed_text = f"Tên: {name}. Mô tả: {desc}"
118
- pairs_to_score_wiki.append([query, constructed_text])
119
-
120
- sorted_docs_wiki = []
121
- if pairs_to_score_wiki:
122
- scores_wiki = reranker.predict(pairs_to_score_wiki)
123
- for i, doc in enumerate(candidates_for_rerank_wiki):
124
- doc['rerank_score'] = scores_wiki[i]
125
- sorted_docs_wiki = sorted(candidates_for_rerank_wiki, key=lambda x: x['rerank_score'], reverse=True)
126
 
127
- # 3. Lấy Top 3 từ mỗi loại định dạng riêng biệt
128
- top_3_hf = sorted_docs_hf[:3]
129
- top_3_wiki = sorted_docs_wiki[:3]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- # Định dạng kết quả từ HF RAG (sử dụng helper, định dạng đầy đủ)
132
- hf_context = format_metadata_list_to_context(top_3_hf)
133
 
134
- # Định dạng kết quả từ Wiki RAG (chỉ tên + mô tả)
135
- wiki_context_parts = []
136
- for doc in top_3_wiki:
137
- metadata = doc.get('metadata', {})
138
- name = metadata.get('group', 'Không rõ tên')
139
- section = metadata.get('section', '')
140
- description = metadata.get('content', 'Không có mô tả.')
141
- wiki_context_parts.append(f"[Nguồn Wiki - Tên]: {name}\n[Nguồn Wiki - Mô tả]: {description}" + (f"\n[Nguồn Wiki - Section]: {section}" if section else ""))
142
 
143
- wiki_context = "\n\n".join(wiki_context_parts)
144
-
145
- # 4. Kết hợp hai ngữ cảnh thành một chuỗi duy nhất
146
- final_context_parts = []
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- if wiki_context:
149
- final_context_parts.append(wiki_context)
150
-
151
-
152
- if hf_context and hf_context.strip() != "Không có dữ liệu ngữ cảnh.":
153
- final_context_parts.append(hf_context)
154
 
155
- if not final_context_parts:
156
- return "Không tìm thấy thông tin phù hợp."
157
 
158
- return "\n\n".join(final_context_parts).strip()
159
  except Exception as e:
 
 
160
  print(f"Lỗi trong advanced_search: {e}")
161
- return "Đã xảy ra lỗi trong quá trình tìm kiếm nâng cao."
 
 
1
  import concurrent.futures
2
+ import torch
3
+ import torch.nn.functional as F
4
 
5
+ from .helper import format_metadata_list_to_context
6
+ from .rag import hf_rag_service, local_rag_service
 
 
 
 
7
 
8
  def reciprocal_rank_fusion(search_results_lists: list, k_rrf: int = 60):
9
  """
 
38
 
39
  def advanced_search(query, keyword):
40
  """
41
+ Tìm kiếm nâng cao: search song song các cụm từ RRF fusion → rerank top-K (Bằng embedding model) → format context.
42
+ Trả về tuple: (context_string, image_url | None)
43
  """
44
  try:
45
+ result_all = []
 
46
 
47
  # 1. Tạo danh sách các cụm từ tìm kiếm
48
  search_terms = [query]
 
51
  elif isinstance(keyword, str) and keyword:
52
  search_terms.append(keyword)
53
 
54
+ # 2. Tìm kiếm song song qua Manager (chứa 2 Dataset) với mỗi cụm từ
55
  with concurrent.futures.ThreadPoolExecutor() as executor:
56
+ future_to_term = {}
57
  for r in search_terms:
58
+ if not r.strip(): continue
59
+ future = executor.submit(local_rag_service.search, r.lower(), top_k=15)
60
+ future_to_term[future] = r
 
 
 
61
 
62
+ for future in concurrent.futures.as_completed(future_to_term):
 
63
  try:
64
  docs = future.result()
65
  if docs:
66
+ result_all.append(docs)
 
 
 
67
  except Exception as e:
68
+ print(f"Lỗi tìm kiếm concurrent: {e}")
69
 
70
+ if not result_all:
71
+ return "Không tìm thấy thông tin phù hợp.", None
72
 
73
+ # 3. RRF Fusion sau đó Rerank tổng cộng
74
+ fused_results = reciprocal_rank_fusion(result_all)
75
+ print("candidates_for_rerank (after deduplication):", len(fused_results))
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
+ # Tối ưu: Cắt giảm xuống 20 ứng viên để Reranking
78
+ candidates_for_rerank = fused_results[:20]
79
+
80
+ pairs_to_score = []
81
+ for item in candidates_for_rerank:
 
 
 
 
 
 
 
 
 
82
  meta = item['metadata']
83
+ # Lấy các trường đã được đồng bộ trong dataset (group, combined_text, original_content, dataset_source)
84
+ name = meta.get('group', '') or meta.get('ten', '')
85
+ desc = meta.get('combined_text') or meta.get('original_content') or meta.get('content', '')
86
  constructed_text = f"Tên: {name}. Mô tả: {desc}"
87
+ pairs_to_score.append([query, constructed_text])
 
 
 
 
 
 
 
88
 
89
+ # === MỚI: Reranking bằng Cosine Similarity của model Embedding (Bỏ Cross-Encoder) ===
90
+ sorted_docs = []
91
+ if pairs_to_score:
92
+ print(f"Reranking {len(pairs_to_score)} candidates bằng Embedding Model siêu nhanh...")
93
+ try:
94
+ # Trích xuất model embeddings ban đầu
95
+ q_text = query
96
+ docs_texts = [p[1] for p in pairs_to_score]
97
+
98
+ # Encode text trực tiếp
99
+ q_emb = local_rag_service.model.encode([q_text], convert_to_tensor=True)
100
+ docs_emb = local_rag_service.model.encode(docs_texts, convert_to_tensor=True)
101
+
102
+ # Tính khoảng cách Cosine
103
+ cos_scores = F.cosine_similarity(q_emb, docs_emb).cpu().tolist()
104
+
105
+ # Gán điểm và sort
106
+ for i, doc in enumerate(candidates_for_rerank):
107
+ doc['rerank_score'] = cos_scores[i]
108
+
109
+ sorted_docs = sorted(candidates_for_rerank, key=lambda x: x['rerank_score'], reverse=True)
110
+ except Exception as e:
111
+ import traceback
112
+ traceback.print_exc()
113
+ print(f"Lỗi khi dùng embedding model de rerank, fallback sang RRF score: {e}")
114
+ sorted_docs = candidates_for_rerank
115
+ else:
116
+ sorted_docs = candidates_for_rerank
117
+
118
+ # Lấy top 5 kết quả tốt nhất
119
+ top_5 = sorted_docs[:5]
120
+
121
+ # 4. Định dạng lại nội dung ngữ cảnh chung (Format chung)
122
+ final_context = format_metadata_list_to_context(top_5)
123
 
124
+ if not final_context or final_context.strip() == "Không dữ liệu ngữ cảnh.":
125
+ return "Không tìm thấy thông tin phù hợp.", None
126
 
127
+ # 5. Trích xuất image_url từ các dataset đã gộp
128
+ image_url = None
 
 
 
 
 
 
129
 
130
+ # (a) Từ topic name của top results (Ưu tiên group name)
131
+ for doc in top_5:
132
+ topic = (doc.get('metadata', {}).get('group') or doc.get('metadata', {}).get('ten') or '').strip()
133
+ if topic:
134
+ image_url = local_rag_service.get_image_for_topic(topic)
135
+ if image_url:
136
+ print(f"[IMAGE] Found via topic: '{topic}'")
137
+ break
138
+
139
+ # (b) Fallback: từ search terms
140
+ if not image_url:
141
+ for term in search_terms:
142
+ if term and len(term) > 2:
143
+ image_url = local_rag_service.get_image_for_topic(term)
144
+ if image_url:
145
+ print(f"[IMAGE] Found via keyword: '{term}'")
146
+ break
147
 
148
+ if image_url:
149
+ print(f"[IMAGE] final: {image_url}")
 
 
 
 
150
 
151
+ return final_context, image_url
 
152
 
 
153
  except Exception as e:
154
+ import traceback
155
+ traceback.print_exc()
156
  print(f"Lỗi trong advanced_search: {e}")
157
+ return "Đã xảy ra lỗi trong quá trình tìm kiếm nâng cao.", None
service/rewrite.py CHANGED
@@ -1,14 +1,14 @@
1
  from groq import Groq
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
- from reranking import advanced_search
5
  import os
6
  from dotenv import load_dotenv
7
 
8
  load_dotenv()
9
 
10
  #groq_api_key = os.environ.get("GROQ_API_KEY")
11
- groq_api_key = os.getenv('GROQ_API_KEY')
12
 
13
  if not groq_api_key:
14
  raise ValueError("GROQ_API_KEY environment variable is not set")
@@ -244,6 +244,49 @@ class QueryRewriter:
244
  hypothetical_answer = completion.choices[0].message.content.strip()
245
  return hypothetical_answer
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  def chain_of_thought(self,question,ans):
248
  COT_SYSTEM_PROMPT = """Bạn là một trợ lý AI chuyên biên tập và kiểm tra tính liên quan của câu trả lời (Relevance-Checking Editor) chuyên về văn hoá, lịch sử và địa điểm Việt Nam.
249
  NHIỆM VỤ CỐT LÕI: Dựa vào "Câu hỏi gốc", hãy lọc lại "Câu trả lời được tạo ra" để đảm bảo mọi thông tin trong câu trả lời cuối cùng đều liên quan trực tiếp đến chủ thể trong câu hỏi.
@@ -252,7 +295,7 @@ class QueryRewriter:
252
  1. **Nếu câu trả lời mang ý nghĩa xã giao (chào hỏi, cảm ơn, v.v...) hoặc tôi không tìm thấy thông tin về vấn đề này trong tài liệu được cung cấp. **, hãy trả về nguyên văn câu trả lời ban đầu mà không chỉnh sửa gì.
253
  2. **Xác định chủ thể chính**: Đọc kỹ "Câu hỏi gốc" để xác định đối tượng, địa danh, hoặc khái niệm chính mà người dùng đang hỏi.
254
  3. **Kiểm tra và lọc**: Rà soát từng câu, từng ý trong "Câu trả lời được tạo ra".
255
- * **Giữ lại**: Chỉ giữ lại những thông tin mô tả, giải thích, hoặc liệt kê các chi tiết liên quan đến chủ thể chính của câu hỏi.
256
  * **Loại bỏ**: Xóa bỏ hoàn toàn bất kỳ thông tin nào nói về một chủ thể khác, không liên quan.
257
  4. **Tổng hợp lại**: Viết lại câu trả lời cuối cùng một cách mạch lạc, tự nhiên từ những thông tin đã được lọc.
258
 
@@ -288,92 +331,95 @@ class QueryRewriter:
288
  {"role": "user", "content": user_content}
289
  ]
290
  )
291
- print("answer before cot:",ans)
292
  return completion.choices[0].message.content.strip()
293
 
294
- def ask_with_context(self,question,history):
295
-
296
- # get key word
297
- keyword = self.keyword(question,history)
 
298
  print(f"\n--- keyword: {keyword} ---")
299
- print(type(keyword))
300
- # rewrite question with key word
301
- q_rewrite = self.rewrite_query(question,history)
302
  print(f"\n--- q_rewrite: {q_rewrite} ---")
 
 
 
 
303
  fake_answer = self.generate_hypothetical_answer(q_rewrite)
304
- print(f"\n--- fake_answer: {fake_answer} ---")
305
- # get top 30 RAG and reranking by question rewrite and keyword then get 5
306
- p = advanced_search(fake_answer,keyword)
307
- print(f"\n--- context p: {p} ---")
308
- RAG_SYSTEM_PROMPT = """Bạn một trợ AI chuyên trả lời các câu hỏi về văn hóa các dân tộc Việt Nam.
309
- NHIỆM VỤ CỐT LÕI: Trả lời câu hỏi của người dùng CHỈ DỰA VÀO thông tin được cung cấp trong phần "Dữ liệu Ngữ cảnh" (Context).
310
-
311
- Dữ liệu ngữ cảnh bao gồm các nguồn:
312
- - [Nguồn Wiki]: Thông tin chi tiết, mô tả sâu.
313
- - [TỔNG QUAN]: Thông tin tóm tắt, định danh.
 
 
 
 
 
314
 
315
- HƯỚNG DẪN XỬ LÝ:
316
- 1. **Đọc hiểu ngữ cảnh**: Bạn cần đọc kỹ cả thông tin từ [Nguồn Wiki] và [TỔNG QUAN] (nếu có) để có cái nhìn toàn diện.
317
- 2. **Tổng hợp câu trả lời**:
318
- - Kết hợp thông tin từ cả hai nguồn để câu trả lời đầy đủ và chính xác nhất.
319
- - Nếu [TỔNG QUAN] cung cấp thông tin cơ bản (tên, địa điểm, thời gian), hãy dùng nó để giới thiệu.
320
- - Nếu [Nguồn Wiki] cung cấp chi tiết mô tả, lịch sử, ý nghĩa, hãy dùng nó để giải thích sâu hơn.
321
- 3. **Xử lý câu hỏi cụ thể**:
322
- - Với câu hỏi so sánh: Tìm điểm giống và khác nhau trong ngữ cảnh của các đối tượng.
323
- - Với câu hỏi liệt kê: Liệt kê các đối tượng có trong ngữ cảnh phù hợp với câu hỏi.
324
 
325
- QUY TẮC BẮT BUỘC:
326
- - TUYỆT ĐỐI KHÔNG sử dụng kiến thức bên ngoài ngữ cảnh.
327
- - Nếu không có thông tin trong ngữ cảnh, hãy trả lời: "Xin lỗi, tôi không tìm thấy thông tin về vấn đề này trong tài liệu được cung cấp."
328
- - Trả lời bằng tiếng Việt có dấu, văn phong lịch sự, rõ ràng.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  """
330
-
331
- # Tạo nội dung user prompt: Ghép context và câu hỏi gốc
332
- user_content = f"""### Context:
333
- {p}
334
 
335
- ### User Question:
336
- {q_rewrite}
337
- """
338
-
339
- # pull model
340
- # messages = [
341
- # {"role": "system", "content": RAG_SYSTEM_PROMPT},
342
- # {"role": "user", "content": user_content}
343
- # ]
344
-
345
- # # Format prompt theo chuẩn của Qwen
346
- # text = self.tokenizer.apply_chat_template(
347
- # messages,
348
- # tokenize=False,
349
- # add_generation_prompt=True
350
- # )
351
-
352
- # # Đưa input vào đúng device của model_base
353
- # model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
354
-
355
- # # --- SỬA LỖI TẠI ĐÂY ---
356
- # # Dùng model_base (Qwen) thay vì model (SentenceTransformer)
357
- # generated_ids = self.model.generate(
358
- # **model_inputs, # Lưu ý: sửa cả tên biến input cho khớp (model_inputs thay vì model_base1 cho rõ nghĩa)
359
- # max_new_tokens=512,
360
- # temperature=0.1,
361
- # top_p=0.9
362
- # )
363
-
364
- # generated_ids = [
365
- # output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
366
- # ]
367
-
368
- # answer_bot = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
369
-
370
  completion = client.chat.completions.create(
371
  model="meta-llama/llama-4-scout-17b-16e-instruct",
372
- messages = [
373
  {"role": "system", "content": RAG_SYSTEM_PROMPT},
374
  {"role": "user", "content": user_content}
375
  ]
376
  )
377
  answer_bot = completion.choices[0].message.content.strip()
378
- return self.chain_of_thought(q_rewrite,answer_bot)
379
- #return completion.choices[0].message.content.strip()
 
 
 
 
 
 
 
 
 
 
 
1
  from groq import Groq
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
+ from .reranking import advanced_search
5
  import os
6
  from dotenv import load_dotenv
7
 
8
  load_dotenv()
9
 
10
  #groq_api_key = os.environ.get("GROQ_API_KEY")
11
+ groq_api_key = "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"
12
 
13
  if not groq_api_key:
14
  raise ValueError("GROQ_API_KEY environment variable is not set")
 
244
  hypothetical_answer = completion.choices[0].message.content.strip()
245
  return hypothetical_answer
246
 
247
+ def detect_image_intent(self, query: str, history: list) -> bool:
248
+ """Phát hiện xem câu hỏi có ý định muốn xem hình ảnh/minh họa không."""
249
+ conversation_str = ""
250
+ for turn in history:
251
+ conversation_str += f"{turn['role'].capitalize()}: {turn['content']}\n"
252
+
253
+ IMAGE_INTENT_PROMPT = """Bạn là một API phân loại ý định (Intent Classifier). Nhiệm vụ duy nhất của bạn là xác định xem câu hỏi của người dùng có muốn xem HÌNH ẢNH, MINH HỌA hay không.
254
+
255
+ Các ý định MUỐN XEM ẢNH bao gồm:
256
+ - Yêu cầu trực tiếp: "cho tôi xem ảnh", "hình ảnh", "ảnh", "photo", "picture"
257
+ - Câu hỏi ngoại hình/hình thức: "trông như thế nào", "có hình dạng gì", "diện mạo", "nhìn như thế nào"
258
+ - Câu hỏi giới thiệu tổng quát: "giới thiệu về", "tìm hiểu về", "cho tôi biết về", "kể về"
259
+ - Câu hỏi mô tả: "mô tả", "đặc điểm", "trình bày"
260
+
261
+ Các ý định KHÔNG muốn xem ảnh:
262
+ - Hỏi sự kiện/năm tháng: "được công nhận năm nào", "xảy ra khi nào"
263
+ - Hỏi so sánh hàm ý trừu tượng: "khác nhau ở điểm nào", "ý nghĩa là gì"
264
+ - Hỏi lý do/nguyên nhân: "tại sao", "vì sao"
265
+ - Chào hỏi xã giao: "xin chào", "cảm ơn"
266
+
267
+ QUY TẮC OUTPUT:
268
+ - Chỉ trả về một từ duy nhất: YES hoặc NO
269
+ - KHÔNG giải thích"""
270
+
271
+ user_content = f"""Lịch sử:\n{conversation_str}\nCâu hỏi: \"{query}\"\nOutput:"""
272
+
273
+ try:
274
+ completion = client.chat.completions.create(
275
+ model="meta-llama/llama-4-scout-17b-16e-instruct",
276
+ messages=[
277
+ {"role": "system", "content": IMAGE_INTENT_PROMPT},
278
+ {"role": "user", "content": user_content}
279
+ ],
280
+ max_tokens=5,
281
+ temperature=0.0
282
+ )
283
+ result = completion.choices[0].message.content.strip().upper()
284
+ print(f"\n--- image_intent: {result} ---")
285
+ return result.startswith("YES")
286
+ except Exception as e:
287
+ print(f"\n--- image_intent error: {e} ---")
288
+ return False
289
+
290
  def chain_of_thought(self,question,ans):
291
  COT_SYSTEM_PROMPT = """Bạn là một trợ lý AI chuyên biên tập và kiểm tra tính liên quan của câu trả lời (Relevance-Checking Editor) chuyên về văn hoá, lịch sử và địa điểm Việt Nam.
292
  NHIỆM VỤ CỐT LÕI: Dựa vào "Câu hỏi gốc", hãy lọc lại "Câu trả lời được tạo ra" để đảm bảo mọi thông tin trong câu trả lời cuối cùng đều liên quan trực tiếp đến chủ thể trong câu hỏi.
 
295
  1. **Nếu câu trả lời mang ý nghĩa xã giao (chào hỏi, cảm ơn, v.v...) hoặc tôi không tìm thấy thông tin về vấn đề này trong tài liệu được cung cấp. **, hãy trả về nguyên văn câu trả lời ban đầu mà không chỉnh sửa gì.
296
  2. **Xác định chủ thể chính**: Đọc kỹ "Câu hỏi gốc" để xác định đối tượng, địa danh, hoặc khái niệm chính mà người dùng đang hỏi.
297
  3. **Kiểm tra và lọc**: Rà soát từng câu, từng ý trong "Câu trả lời được tạo ra".
298
+ * **Giữ lại**: Chỉ giữ lại những thông tin mô tả, giải thích, hoặc liệt kê các chi tiết liên quan, url hình ảnh (https://raw.githubusercontent.com/T-Phong/WikiImage/main/Aodai_%C3%81o%20d%C3%A0i_Introduction_2.png), v.v. liên quan đến chủ thể chính của câu hỏi.
299
  * **Loại bỏ**: Xóa bỏ hoàn toàn bất kỳ thông tin nào nói về một chủ thể khác, không liên quan.
300
  4. **Tổng hợp lại**: Viết lại câu trả lời cuối cùng một cách mạch lạc, tự nhiên từ những thông tin đã được lọc.
301
 
 
331
  {"role": "user", "content": user_content}
332
  ]
333
  )
334
+
335
  return completion.choices[0].message.content.strip()
336
 
337
+ def ask_with_context(self, question, history):
338
+ """Pipeline chính: trả về tuple (answer_text, image_url | None)"""
339
+
340
+ # 1. Trích xuất keyword
341
+ keyword = self.keyword(question, history)
342
  print(f"\n--- keyword: {keyword} ---")
343
+
344
+ # 2. Viết lại câu hỏi
345
+ q_rewrite = self.rewrite_query(question, history)
346
  print(f"\n--- q_rewrite: {q_rewrite} ---")
347
+
348
+ # 3. Tạo câu trả lời giả định (HyDE)
349
+ # HyDE chỉ dùng làm keyword PHỤ để mở rộng search — KHÔNG làm query chính
350
+ # vì embedding của HyDE dài không match với chunk ngắn trong vector DB
351
  fake_answer = self.generate_hypothetical_answer(q_rewrite)
352
+ print(f"\n--- fake_answer (first 150): {str(fake_answer)[:150]} ---")
353
+
354
+ # 4. Tìm kiếm nâng cao
355
+ # Query chính = q_rewrite (ngắn gọn, khớp embedding DB tốt hơn)
356
+ # Keywords = keyword (từ LLM extractor) + đầu HyDE (150 tự) làm term mở rộng
357
+ hyde_snippet = (fake_answer or '').strip()[:150]
358
+ search_keywords = keyword + ([hyde_snippet] if hyde_snippet else [])
359
+ context, candidate_image_url = advanced_search(q_rewrite, search_keywords)
360
+ print(f"\n--- context (200 chars): {str(context)[:200]} ---")
361
+ print(f"\n--- candidate_image_url: {candidate_image_url} ---")
362
+
363
+ # 5. Phát hiện intent xem ảnh SỚM (trước khi gọi LLM)
364
+ wants_image = self.detect_image_intent(question, history)
365
+ print(f"\n--- image_intent: {'YES' if wants_image else 'NO'} ---")
366
+ image_url = candidate_image_url if (wants_image and candidate_image_url) else None
367
 
368
+ RAG_SYSTEM_PROMPT = """Bạn là một trợ lý AI chuyên trả lời các câu hỏi về văn hóa các dân tộc Việt Nam.
 
 
 
 
 
 
 
 
369
 
370
+ NGUỒN DỮ LIỆU DUY NHẤT: Chỉ được dùng thông tin nằm trong thẻ <context>...</context> bên dưới.
371
+ TUYỆT ĐỐI KHÔNG dùng kiến thức ngoài context, kể cả khi bạn biết câu trả lời.
372
+
373
+ QUY TẮC XỬ LÝ:
374
+ 1. Đọc kỹ toàn bộ <context>.
375
+ 2. Kiểm tra xem context có đề cập đến chủ thể trong câu hỏi không:
376
+ - NẾU CÓ → tổng hợp từ [TỔNG QUAN] để trả lời đầy đủ, chính xác.
377
+ - NẾU KHÔNG CÓ hoặc thông tin không liên quan → trả lời ngay: "Xin lỗi, tôi không tìm thấy thông tin về vấn đề này trong tài liệu được cung cấp." — KHÔNG được suy luận hay bổ sung thêm.
378
+ 3. Xử lý câu hỏi cụ thể:
379
+ - So sánh: tìm điểm gi���ng/khác trong context.
380
+ - Liệt kê: liệt kê đúng những gì context đề cập.
381
+ 4. Hình ảnh: nếu context có "### Hình ảnh minh họa", nhúng ảnh: ![mô tả](URL). Không đề cập ảnh nếu không có.
382
+
383
+ FORMAT MARKDOWN:
384
+ - **in đậm** cho tên riêng, địa danh, khái niệm quan trọng.
385
+ - `##` tiêu đề chính, `###` tiêu đề phụ (chỉ khi câu trả lời dài, nhiều phần).
386
+ - Danh sách `-` khi liệt kê. Câu trả lời ngắn/xã giao: chỉ text thuần, không heading.
387
+ - TUYỆT ĐỐI KHÔNG đề cập đến việc tìm kiếm ảnh trên internet.
388
  """
 
 
 
 
389
 
390
+ # 6. Build user_content — wrap context trong thẻ XML để model phân biệt rõ ranh giới
391
+ image_section = ""
392
+ if image_url:
393
+ image_section = f"\n\n### Hình ảnh minh họa\n![]({image_url})\n"
394
+
395
+ user_content = f"""<context>
396
+ {context}{image_section}
397
+ </context>
398
+
399
+ ### Câu hỏi:
400
+ {q_rewrite}
401
+ """
402
+ print(f"\n--- ---")
403
+ print(f"\n--- user_content: {str(user_content)} ---")
404
+ print(f"\n--- ---")
405
+ # 7. Sinh câu trả lời
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  completion = client.chat.completions.create(
407
  model="meta-llama/llama-4-scout-17b-16e-instruct",
408
+ messages=[
409
  {"role": "system", "content": RAG_SYSTEM_PROMPT},
410
  {"role": "user", "content": user_content}
411
  ]
412
  )
413
  answer_bot = completion.choices[0].message.content.strip()
414
+ print(f"\n--- ---")
415
+ print("answer before cot:",answer_bot)
416
+ print(f"\n--- ---")
417
+ # 8. Lọc câu trả lời qua CoT
418
+
419
+ final_answer = self.chain_of_thought(q_rewrite, answer_bot)
420
+ print(f"\n--- ---")
421
+ print("answer:", answer_bot)
422
+ print(f"\n--- ---")
423
+ return final_answer, image_url
424
+
425
+