Spaces:
Sleeping
Sleeping
T-Phong commited on
Commit ·
b21ec88
1
Parent(s): c952c4d
update code
Browse files- app.py +60 -167
- requirements_rag.txt +2 -1
- service/helper.py +1 -8
- service/rag.py +214 -190
- service/reranking.py +90 -94
- service/rewrite.py +122 -76
app.py
CHANGED
|
@@ -1,249 +1,142 @@
|
|
| 1 |
"""
|
| 2 |
-
REST API cho Vietnam Heritage RAG System
|
| 3 |
"""
|
| 4 |
import uuid
|
| 5 |
-
import json
|
| 6 |
import os
|
| 7 |
import sys
|
| 8 |
-
from datetime import datetime
|
| 9 |
|
| 10 |
-
import google.generativeai as genai
|
| 11 |
from flask import Flask, request, jsonify
|
| 12 |
from flask_cors import CORS
|
| 13 |
|
| 14 |
-
#
|
| 15 |
sys.path.append(os.path.join(os.path.dirname(__file__), 'service'))
|
| 16 |
-
from
|
|
|
|
| 17 |
|
| 18 |
app = Flask(__name__)
|
| 19 |
CORS(app)
|
| 20 |
-
# Khởi tạo QueryRewriter (chứa ask_with_context)
|
| 21 |
-
rewriter = QueryRewriter()
|
| 22 |
-
islog = os.getenv('islog')
|
| 23 |
-
metrics_log = [] # Lưu lại các lần đánh giá để dựng biểu đồ
|
| 24 |
-
|
| 25 |
-
GENAI_API_KEY = os.getenv("GEMINI_API_KEY")
|
| 26 |
-
if GENAI_API_KEY:
|
| 27 |
-
genai.configure(api_key=GENAI_API_KEY)
|
| 28 |
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
def _safe_json_parse(text):
|
| 31 |
-
"""Parse chuỗi JSON, cố gắng trích block {} đầu tiên nếu có thêm text."""
|
| 32 |
-
try:
|
| 33 |
-
return json.loads(text)
|
| 34 |
-
except Exception:
|
| 35 |
-
pass
|
| 36 |
-
|
| 37 |
-
start = text.find("{")
|
| 38 |
-
end = text.rfind("}")
|
| 39 |
-
if start != -1 and end != -1 and end > start:
|
| 40 |
-
try:
|
| 41 |
-
return json.loads(text[start : end + 1])
|
| 42 |
-
except Exception:
|
| 43 |
-
return None
|
| 44 |
-
return None
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def evaluate_answer_llm(question: str, answer: str, history_message):
|
| 48 |
-
"""Gọi LLM để chấm điểm mức liên quan, độ chính xác và mức độ lan man."""
|
| 49 |
-
if not GENAI_API_KEY:
|
| 50 |
-
return {
|
| 51 |
-
"status": "skipped",
|
| 52 |
-
"reason": "missing_gemini_api_key",
|
| 53 |
-
}
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
prompt = (
|
| 59 |
-
"You are an evaluator for a RAG chatbot."
|
| 60 |
-
" Return JSON with keys: rag_relevance (0-1), answer_accuracy (0-1), hallucination (bool), notes (string)."
|
| 61 |
-
" Evaluate strictly from question and answer (and chat history if provided)."
|
| 62 |
-
" rag_relevance measures how well retrieved context seems relevant to the question."
|
| 63 |
-
" answer_accuracy measures factual correctness and completeness."
|
| 64 |
-
" hallucination is true if the answer includes unrelated, fabricated, or off-topic info."
|
| 65 |
-
f"\nQuestion: {question}\nAnswer: {answer}\nHistory: {history_text}\nReturn JSON only."
|
| 66 |
-
)
|
| 67 |
-
resp = model.generate_content(prompt)
|
| 68 |
-
parsed = _safe_json_parse(resp.text)
|
| 69 |
-
if not parsed:
|
| 70 |
-
raise ValueError("LLM did not return valid JSON")
|
| 71 |
-
|
| 72 |
-
rag_rel = float(parsed.get("rag_relevance", 0))
|
| 73 |
-
acc = float(parsed.get("answer_accuracy", 0))
|
| 74 |
-
halluc = bool(parsed.get("hallucination", False))
|
| 75 |
-
|
| 76 |
-
return {
|
| 77 |
-
"status": "ok",
|
| 78 |
-
"timestamp": datetime.utcnow().isoformat() + "Z",
|
| 79 |
-
"rag_relevance": max(0.0, min(1.0, rag_rel)),
|
| 80 |
-
"answer_accuracy": max(0.0, min(1.0, acc)),
|
| 81 |
-
"hallucination": halluc,
|
| 82 |
-
"notes": parsed.get("notes", "") or "",
|
| 83 |
-
}
|
| 84 |
-
except Exception as e:
|
| 85 |
-
return {
|
| 86 |
-
"status": "error",
|
| 87 |
-
"error": str(e),
|
| 88 |
-
}
|
| 89 |
|
| 90 |
@app.route('/v1/chat/completions', methods=['POST'])
|
| 91 |
def ask_api():
|
| 92 |
"""
|
| 93 |
-
Main endpoint -
|
| 94 |
-
|
| 95 |
Request body:
|
| 96 |
{
|
| 97 |
-
"
|
|
|
|
|
|
|
| 98 |
}
|
| 99 |
-
|
| 100 |
Response:
|
| 101 |
{
|
| 102 |
-
"
|
| 103 |
-
"
|
|
|
|
|
|
|
| 104 |
}
|
| 105 |
-
"""
|
| 106 |
try:
|
| 107 |
data = request.get_json()
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
# if islog == "1":
|
| 111 |
-
# for f in history_message:
|
| 112 |
-
# print(f)
|
| 113 |
-
|
| 114 |
|
| 115 |
-
|
|
|
|
|
|
|
| 116 |
|
| 117 |
-
question = all_messages[-1]
|
| 118 |
-
|
| 119 |
if not question:
|
| 120 |
-
return jsonify({
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
# Đánh giá tự động bằng LLM
|
| 128 |
-
# evaluation = evaluate_answer_llm(question, answer, history_message)
|
| 129 |
-
# if evaluation:
|
| 130 |
-
# metrics_log.append({
|
| 131 |
-
# "question": question,
|
| 132 |
-
# "answer": answer,
|
| 133 |
-
# "evaluation": evaluation,
|
| 134 |
-
# })
|
| 135 |
-
# # Giữ kích thước log vừa phải để hiển thị biểu đồ
|
| 136 |
-
# if len(metrics_log) > 200:
|
| 137 |
-
# del metrics_log[:-200]
|
| 138 |
|
| 139 |
return jsonify({
|
| 140 |
"id": str(uuid.uuid4()),
|
| 141 |
-
"
|
| 142 |
-
"choices": [
|
| 143 |
{
|
| 144 |
-
"
|
| 145 |
-
"
|
| 146 |
-
"role": "assistant",
|
| 147 |
-
"content": answer
|
| 148 |
-
},
|
| 149 |
-
"finish_reason": "stop"
|
| 150 |
}
|
| 151 |
],
|
| 152 |
-
"
|
| 153 |
}), 200
|
| 154 |
-
|
| 155 |
except Exception as e:
|
| 156 |
return jsonify({
|
| 157 |
"error": str(e),
|
| 158 |
"status": "error"
|
| 159 |
}), 500
|
| 160 |
|
|
|
|
| 161 |
@app.route('/v1/models', methods=['GET'])
|
| 162 |
-
def
|
| 163 |
return jsonify({
|
| 164 |
"object": "list",
|
| 165 |
"data": [
|
| 166 |
-
{"id": "
|
| 167 |
-
|
| 168 |
-
]
|
| 169 |
}), 200
|
| 170 |
-
|
| 171 |
|
| 172 |
@app.route('/health', methods=['GET'])
|
| 173 |
def health_check():
|
| 174 |
"""Health check endpoint"""
|
|
|
|
|
|
|
|
|
|
| 175 |
return jsonify({
|
| 176 |
"status": "healthy",
|
| 177 |
"service": "Vietnam Heritage RAG API"
|
| 178 |
}), 200
|
| 179 |
|
|
|
|
| 180 |
@app.route('/', methods=['GET'])
|
| 181 |
def home():
|
| 182 |
"""API documentation"""
|
| 183 |
return jsonify({
|
| 184 |
"message": "Vietnam Heritage AI REST API",
|
| 185 |
-
"version": "
|
| 186 |
"endpoints": {
|
| 187 |
-
"POST /
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
}
|
| 192 |
-
},
|
| 193 |
-
"GET /health": "Health check endpoint",
|
| 194 |
-
"GET /": "API documentation",
|
| 195 |
-
"GET /lstmodel": "List available models"
|
| 196 |
},
|
| 197 |
"example": {
|
| 198 |
-
"url": "/
|
| 199 |
"method": "POST",
|
| 200 |
"body": {
|
| 201 |
-
"
|
|
|
|
|
|
|
| 202 |
}
|
| 203 |
}
|
| 204 |
}), 200
|
| 205 |
|
| 206 |
|
| 207 |
-
@app.route('/metrics', methods=['GET'])
|
| 208 |
-
def get_metrics():
|
| 209 |
-
"""Trả về log đánh giá để dựng biểu đồ ở frontend."""
|
| 210 |
-
# Tính trung bình nhanh để tiện hiển thị
|
| 211 |
-
rag_scores = [m["evaluation"].get("rag_relevance", 0) for m in metrics_log if m.get("evaluation", {}).get("status") == "ok"]
|
| 212 |
-
acc_scores = [m["evaluation"].get("answer_accuracy", 0) for m in metrics_log if m.get("evaluation", {}).get("status") == "ok"]
|
| 213 |
-
halluc_counts = [m["evaluation"].get("hallucination", False) for m in metrics_log if m.get("evaluation", {}).get("status") == "ok"]
|
| 214 |
-
|
| 215 |
-
summary = {
|
| 216 |
-
"total": len(metrics_log),
|
| 217 |
-
"avg_rag_relevance": sum(rag_scores) / len(rag_scores) if rag_scores else 0,
|
| 218 |
-
"avg_answer_accuracy": sum(acc_scores) / len(acc_scores) if acc_scores else 0,
|
| 219 |
-
"hallucination_rate": (sum(1 for h in halluc_counts if h) / len(halluc_counts)) if halluc_counts else 0,
|
| 220 |
-
}
|
| 221 |
-
|
| 222 |
-
return jsonify({
|
| 223 |
-
"summary": summary,
|
| 224 |
-
"data": metrics_log,
|
| 225 |
-
}), 200
|
| 226 |
-
|
| 227 |
-
@app.route('/reset', methods=['POST'])
|
| 228 |
-
def reset_history():
|
| 229 |
-
"""Reset conversation history"""
|
| 230 |
-
global history
|
| 231 |
-
history = []
|
| 232 |
-
return jsonify({
|
| 233 |
-
"message": "History reset successfully",
|
| 234 |
-
"status": "success"
|
| 235 |
-
}), 200
|
| 236 |
-
|
| 237 |
if __name__ == '__main__':
|
| 238 |
port = int(os.environ.get('PORT', 5000))
|
| 239 |
print("=" * 60)
|
| 240 |
-
print(
|
| 241 |
print("=" * 60)
|
| 242 |
print(f"📍 Server: http://localhost:{port}")
|
| 243 |
print(f"📝 Endpoints:")
|
| 244 |
-
print(f" POST http://localhost:{port}/
|
| 245 |
print(f" GET http://localhost:{port}/health")
|
| 246 |
print(f" GET http://localhost:{port}/")
|
| 247 |
print("=" * 60)
|
| 248 |
-
|
| 249 |
-
app.run(host='0.0.0.0', port=port, debug=
|
|
|
|
| 1 |
"""
|
| 2 |
+
REST API cho Vietnam Heritage RAG System
|
| 3 |
"""
|
| 4 |
import uuid
|
|
|
|
| 5 |
import os
|
| 6 |
import sys
|
|
|
|
| 7 |
|
|
|
|
| 8 |
from flask import Flask, request, jsonify
|
| 9 |
from flask_cors import CORS
|
| 10 |
|
| 11 |
+
# Thêm thư mục service vào sys.path
|
| 12 |
sys.path.append(os.path.join(os.path.dirname(__file__), 'service'))
|
| 13 |
+
from service.reranking import advanced_search
|
| 14 |
+
from service.rewrite import QueryRewriter
|
| 15 |
|
| 16 |
app = Flask(__name__)
|
| 17 |
CORS(app)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
# Khởi tạo QueryRewriter (chứa toàn bộ pipeline RAG)
|
| 20 |
+
rewriter = QueryRewriter()
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
# ==============================================================================
|
| 24 |
+
# ENDPOINTS
|
| 25 |
+
# ==============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
@app.route('/v1/chat/completions', methods=['POST'])
|
| 28 |
def ask_api():
|
| 29 |
"""
|
| 30 |
+
Main endpoint - Chat với AI về văn hoá Việt Nam.
|
| 31 |
+
|
| 32 |
Request body:
|
| 33 |
{
|
| 34 |
+
"messages": [
|
| 35 |
+
{"role": "user", "content": "Câu hỏi của bạn"}
|
| 36 |
+
]
|
| 37 |
}
|
| 38 |
+
|
| 39 |
Response:
|
| 40 |
{
|
| 41 |
+
"id": "uuid",
|
| 42 |
+
"object": "chat.completion",
|
| 43 |
+
"choices": [{"index": 0, "message": {"role": "assistant", "content": "..."}, "finish_reason": "stop"}],
|
| 44 |
+
"image_url": "https://..." hoặc null
|
| 45 |
}
|
| 46 |
+
"""
|
| 47 |
try:
|
| 48 |
data = request.get_json()
|
| 49 |
+
if not data:
|
| 50 |
+
return jsonify({"error": "Request body is required"}), 400
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
+
all_messages = data.get("messages", [])
|
| 53 |
+
if not all_messages:
|
| 54 |
+
return jsonify({"error": "'messages' cannot be empty"}), 400
|
| 55 |
|
| 56 |
+
question = all_messages[-1].get("content", "").strip()
|
|
|
|
| 57 |
if not question:
|
| 58 |
+
return jsonify({"error": "Last message content cannot be empty"}), 400
|
| 59 |
+
|
| 60 |
+
# Lấy tối đa 3 lượt hội thoại gần nhất (không tính message hiện tại)
|
| 61 |
+
history_message = all_messages[-7:-1]
|
| 62 |
+
|
| 63 |
+
# Pipeline RAG → trả về (answer_text, image_url | None)
|
| 64 |
+
answer, image_url = rewriter.ask_with_context(question, history_message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
return jsonify({
|
| 67 |
"id": str(uuid.uuid4()),
|
| 68 |
+
"message": [
|
|
|
|
| 69 |
{
|
| 70 |
+
"role": "assistant",
|
| 71 |
+
"content": answer
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
}
|
| 73 |
],
|
| 74 |
+
"image_url": image_url
|
| 75 |
}), 200
|
| 76 |
+
|
| 77 |
except Exception as e:
|
| 78 |
return jsonify({
|
| 79 |
"error": str(e),
|
| 80 |
"status": "error"
|
| 81 |
}), 500
|
| 82 |
|
| 83 |
+
|
| 84 |
@app.route('/v1/models', methods=['GET'])
|
| 85 |
+
def list_models():
|
| 86 |
return jsonify({
|
| 87 |
"object": "list",
|
| 88 |
"data": [
|
| 89 |
+
{"id": "vietnam-heritage-rag-v1", "object": "model", "owned_by": "culturebot"}
|
| 90 |
+
]
|
|
|
|
| 91 |
}), 200
|
| 92 |
+
|
| 93 |
|
| 94 |
@app.route('/health', methods=['GET'])
|
| 95 |
def health_check():
|
| 96 |
"""Health check endpoint"""
|
| 97 |
+
question = "Người Ê Đê là một dân tộc thiểu số tại Việt Nam, chủ yếu sống ở vùng Tây Nguyên. Văn hóa của người Ê Đê được đặc trưng bởi các yếu tố sau: - **Ngôn n"
|
| 98 |
+
history_message = ['người Ê Đê', 'văn hoá']
|
| 99 |
+
advanced_search(question, history_message)
|
| 100 |
return jsonify({
|
| 101 |
"status": "healthy",
|
| 102 |
"service": "Vietnam Heritage RAG API"
|
| 103 |
}), 200
|
| 104 |
|
| 105 |
+
|
| 106 |
@app.route('/', methods=['GET'])
|
| 107 |
def home():
|
| 108 |
"""API documentation"""
|
| 109 |
return jsonify({
|
| 110 |
"message": "Vietnam Heritage AI REST API",
|
| 111 |
+
"version": "2.0.0",
|
| 112 |
"endpoints": {
|
| 113 |
+
"POST /v1/chat/completions": "Chat với AI về văn hoá Việt Nam",
|
| 114 |
+
"GET /v1/models": "Danh sách model",
|
| 115 |
+
"GET /health": "Health check",
|
| 116 |
+
"GET /": "API documentation"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
},
|
| 118 |
"example": {
|
| 119 |
+
"url": "/v1/chat/completions",
|
| 120 |
"method": "POST",
|
| 121 |
"body": {
|
| 122 |
+
"messages": [
|
| 123 |
+
{"role": "user", "content": "Giới thiệu về Vịnh Hạ Long"}
|
| 124 |
+
]
|
| 125 |
}
|
| 126 |
}
|
| 127 |
}), 200
|
| 128 |
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
if __name__ == '__main__':
|
| 131 |
port = int(os.environ.get('PORT', 5000))
|
| 132 |
print("=" * 60)
|
| 133 |
+
print("🚀 Vietnam Heritage RAG API")
|
| 134 |
print("=" * 60)
|
| 135 |
print(f"📍 Server: http://localhost:{port}")
|
| 136 |
print(f"📝 Endpoints:")
|
| 137 |
+
print(f" POST http://localhost:{port}/v1/chat/completions")
|
| 138 |
print(f" GET http://localhost:{port}/health")
|
| 139 |
print(f" GET http://localhost:{port}/")
|
| 140 |
print("=" * 60)
|
| 141 |
+
|
| 142 |
+
app.run(host='0.0.0.0', port=port, debug=False)
|
requirements_rag.txt
CHANGED
|
@@ -12,4 +12,5 @@ gunicorn>=20.1.0
|
|
| 12 |
flask-cors>=3.0.10
|
| 13 |
google-generativeai>=0.3.3
|
| 14 |
bitsandbytes>=0.39.1
|
| 15 |
-
accelerate>=0.30.3
|
|
|
|
|
|
| 12 |
flask-cors>=3.0.10
|
| 13 |
google-generativeai>=0.3.3
|
| 14 |
bitsandbytes>=0.39.1
|
| 15 |
+
accelerate>=0.30.3
|
| 16 |
+
rank_bm25>=0.2.1
|
service/helper.py
CHANGED
|
@@ -18,7 +18,7 @@ def format_metadata_list_to_context(search_results: List[Dict[str, Any]]) -> str
|
|
| 18 |
|
| 19 |
# 1. Trích xuất dữ liệu (Dùng .get để tránh lỗi nếu thiếu trường)
|
| 20 |
ten = data.get('ten') or data.get('group', 'Không rõ tên')
|
| 21 |
-
mo_ta = data.get('mo_ta') or data.get('content', '')
|
| 22 |
|
| 23 |
# Nhóm thông tin phân loại
|
| 24 |
loai_hinh = data.get('loai_hinh', 'N/A')
|
|
@@ -40,13 +40,6 @@ def format_metadata_list_to_context(search_results: List[Dict[str, Any]]) -> str
|
|
| 40 |
[TỔNG QUAN]
|
| 41 |
Tên: {ten}
|
| 42 |
Mô tả/Nội dung: {mo_ta}
|
| 43 |
-
|
| 44 |
-
[THÔNG TIN CHI TIẾT]
|
| 45 |
-
- Phân loại: {loai_hinh} (Chủ đề: {chu_de})
|
| 46 |
-
- Dân tộc: {dan_toc}
|
| 47 |
-
- Thời gian: {nien_dai} ({thoi_ky})
|
| 48 |
-
- Địa danh/Vùng miền: {vung_mien} - {dia_diem}
|
| 49 |
-
- Chất liệu: {chat_lieu} - Nguyên liệu chính: {nguyen_lieu_chinh}
|
| 50 |
"""
|
| 51 |
|
| 52 |
# 3. Ghép vào chuỗi tổng
|
|
|
|
| 18 |
|
| 19 |
# 1. Trích xuất dữ liệu (Dùng .get để tránh lỗi nếu thiếu trường)
|
| 20 |
ten = data.get('ten') or data.get('group', 'Không rõ tên')
|
| 21 |
+
mo_ta = data.get('mo_ta') or data.get('combined_text') or data.get('original_content') or data.get('content', '')
|
| 22 |
|
| 23 |
# Nhóm thông tin phân loại
|
| 24 |
loai_hinh = data.get('loai_hinh', 'N/A')
|
|
|
|
| 40 |
[TỔNG QUAN]
|
| 41 |
Tên: {ten}
|
| 42 |
Mô tả/Nội dung: {mo_ta}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
"""
|
| 44 |
|
| 45 |
# 3. Ghép vào chuỗi tổng
|
service/rag.py
CHANGED
|
@@ -1,244 +1,268 @@
|
|
| 1 |
import os
|
| 2 |
import json
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import faiss
|
| 5 |
from sentence_transformers import SentenceTransformer
|
| 6 |
-
from datasets import load_dataset
|
| 7 |
-
from huggingface_hub import snapshot_download
|
| 8 |
from typing import List, Dict, Any, Optional
|
| 9 |
-
from
|
| 10 |
-
from helper import format_metadata_list_to_context
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
# Singleton Pattern
|
| 19 |
-
def __new__(cls):
|
| 20 |
-
if cls._instance is None:
|
| 21 |
-
print("Khởi tạo HuggingFaceRAGService...")
|
| 22 |
-
cls._instance = super(HuggingFaceRAGService, cls).__new__(cls)
|
| 23 |
-
cls._instance._initialized = False
|
| 24 |
-
return cls._instance
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
self.
|
| 32 |
-
|
| 33 |
-
# ID của Repo trên Hugging Face chứa file index và data
|
| 34 |
-
# Bạn cần đảm bảo đã upload file .faiss và .json lên repo này (dạng Dataset hoặc Model)
|
| 35 |
-
self.HF_REPO_ID = "synguyen1106/vietnam_heritage_embeddings_v4"
|
| 36 |
-
self.HF_REPO_TYPE = "dataset" # Hoặc "model" hoặc "space" tùy nơi bạn để file
|
| 37 |
|
| 38 |
-
|
| 39 |
-
self.
|
| 40 |
-
self.
|
| 41 |
-
|
| 42 |
|
| 43 |
-
# Load model & Data
|
| 44 |
-
self._load_model()
|
| 45 |
self._load_data()
|
| 46 |
-
|
| 47 |
-
self._initialized = True
|
| 48 |
-
print("✅ HuggingFaceRAGService đã sẵn sàng.")
|
| 49 |
|
| 50 |
-
def
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
def _load_data(self):
|
| 55 |
-
""
|
| 56 |
-
Chiến lược:
|
| 57 |
-
1. Cố gắng tải file index đã build sẵn từ Hugging Face (Nhanh, tránh lỗi LFS).
|
| 58 |
-
2. Nếu không tìm thấy file trên HF, fallback về việc tải Dataset gốc và build lại index (Chậm hơn).
|
| 59 |
-
"""
|
| 60 |
try:
|
| 61 |
-
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
index_path = hf_hub_download(
|
| 66 |
-
repo_id=self.HF_REPO_ID,
|
| 67 |
-
filename=self.FILENAME_INDEX,
|
| 68 |
-
repo_type=self.HF_REPO_TYPE
|
| 69 |
-
)
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
-
|
| 83 |
-
self.
|
|
|
|
| 84 |
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
except Exception as e:
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
""
|
| 95 |
-
print("💾 [HF RAG] Đang tải dataset và xây dựng FAISS index mới...")
|
| 96 |
-
dataset = load_dataset(self.HF_REPO_ID, split="train")
|
| 97 |
-
|
| 98 |
-
# Chuẩn bị vectors
|
| 99 |
-
vectors = np.array(dataset['embedding']).astype("float32")
|
| 100 |
-
|
| 101 |
-
# Chuẩn bị metadata (loại bỏ cột embedding để nhẹ RAM)
|
| 102 |
-
self.metadata = [{k: v for k, v in item.items() if k != 'embedding'} for item in dataset]
|
| 103 |
-
|
| 104 |
-
# Build Index
|
| 105 |
d = vectors.shape[1]
|
| 106 |
self.index = faiss.IndexFlatL2(d)
|
| 107 |
self.index.add(vectors)
|
| 108 |
-
|
| 109 |
-
print(f"
|
| 110 |
-
|
| 111 |
-
# Mẹo: Ở đây bạn có thể lưu file ra đĩa và upload ngược lên HF để lần sau dùng cách 1
|
| 112 |
|
| 113 |
-
def
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
results = []
|
| 122 |
-
for
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
return results
|
| 131 |
-
# ==============================================================================
|
| 132 |
-
# HỆ THỐNG RAG 2: SỬ DỤNG LOCAL DISK DATASET
|
| 133 |
-
# ==============================================================================
|
| 134 |
-
class LocalDiskRAGService:
|
| 135 |
-
_instance: Optional['LocalDiskRAGService'] = None
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
def __new__(cls):
|
| 138 |
if cls._instance is None:
|
| 139 |
-
print("
|
| 140 |
-
cls._instance = super(
|
| 141 |
cls._instance._initialized = False
|
| 142 |
return cls._instance
|
| 143 |
|
| 144 |
def __init__(self):
|
| 145 |
if self._initialized:
|
| 146 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
-
#
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
|
| 155 |
-
|
| 156 |
-
self.
|
| 157 |
-
|
|
|
|
|
|
|
| 158 |
self._initialized = True
|
| 159 |
-
print("✅
|
| 160 |
-
|
| 161 |
-
def _load_model(self):
|
| 162 |
-
print(f"🤖 [Local RAG] Đang tải model AI: {self.MODEL_NAME}...")
|
| 163 |
-
self.model = SentenceTransformer(self.MODEL_NAME)
|
| 164 |
|
| 165 |
-
def
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
self.dataset = load_from_disk(dataset_path)
|
| 173 |
-
print(f"💾 [Local RAG] Load xong! Tổng số dữ liệu: {len(self.dataset)} dòng.")
|
| 174 |
-
|
| 175 |
-
print("🔨 [Local RAG] Đang kích hoạt bộ tìm kiếm (Re-indexing)...")
|
| 176 |
-
self.dataset.add_faiss_index(column="embeddings")
|
| 177 |
-
print("🔨 [Local RAG] Đã kích hoạt xong FAISS Index!")
|
| 178 |
-
except Exception as e:
|
| 179 |
-
print(f"❌ Lỗi: Không thể tải dataset từ Hub. Lỗi: {e}")
|
| 180 |
-
self.dataset = None
|
| 181 |
-
return
|
| 182 |
|
| 183 |
def search(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
query_vector = self.model.encode(query)
|
| 191 |
-
candidate_k = top_k * self.CANDIDATE_MULTIPLIER
|
| 192 |
-
scores, samples = self.dataset.get_nearest_examples("embeddings", query_vector, k=candidate_k)
|
| 193 |
-
|
| 194 |
-
results = []
|
| 195 |
-
for i in range(len(samples['original_content'])):
|
| 196 |
-
if len(results) >= top_k:
|
| 197 |
-
break
|
| 198 |
-
|
| 199 |
-
content = samples['original_content'][i]
|
| 200 |
-
if len(content) < self.MIN_CONTENT_LENGTH:
|
| 201 |
-
continue
|
| 202 |
-
|
| 203 |
-
score = scores[i]
|
| 204 |
-
metadata = samples['metadata'][i]
|
| 205 |
-
metadata['content'] = content
|
| 206 |
-
|
| 207 |
-
results.append({
|
| 208 |
-
"metadata": metadata,
|
| 209 |
-
"score": score
|
| 210 |
-
})
|
| 211 |
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
# print("-" * 50)
|
| 216 |
-
|
| 217 |
-
if not results:
|
| 218 |
-
print(f"Không tìm thấy kết quả nào có nội dung dài hơn {self.MIN_CONTENT_LENGTH} ký tự.")
|
| 219 |
-
|
| 220 |
-
return results
|
| 221 |
|
| 222 |
# ==============================================================================
|
| 223 |
-
# KHỞI TẠO SERVICE VÀ CUNG CẤP CÁC HÀM GỐC
|
| 224 |
# ==============================================================================
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
|
| 228 |
def retrieve_context(query: str, k: int = 2) -> str:
|
| 229 |
-
""
|
| 230 |
-
|
| 231 |
-
(Giữ nguyên hàm gốc để tương thích)
|
| 232 |
-
"""
|
| 233 |
-
print("\n>>> Sử dụng hệ thống RAG 1 (HuggingFace)...")
|
| 234 |
-
results = hf_rag_service.search(query, k)
|
| 235 |
return format_metadata_list_to_context(results)
|
| 236 |
|
| 237 |
def search_heritage(query: str, top_k: int = 3) -> str:
|
| 238 |
-
""
|
| 239 |
-
Tìm kiếm di sản sử dụng hệ thống RAG từ ổ đĩa cục bộ.
|
| 240 |
-
(Giữ nguyên hàm gốc để tương thích)
|
| 241 |
-
"""
|
| 242 |
-
print("\n>>> Sử dụng hệ thống RAG 2 (Local Disk)...")
|
| 243 |
results = local_rag_service.search(query, top_k)
|
| 244 |
return format_metadata_list_to_context(results)
|
|
|
|
| 1 |
import os
|
| 2 |
import json
|
| 3 |
+
import ast
|
| 4 |
import numpy as np
|
| 5 |
import faiss
|
| 6 |
from sentence_transformers import SentenceTransformer
|
| 7 |
+
from datasets import load_dataset
|
|
|
|
| 8 |
from typing import List, Dict, Any, Optional
|
| 9 |
+
from .helper import format_metadata_list_to_context
|
|
|
|
| 10 |
|
| 11 |
+
try:
|
| 12 |
+
from rank_bm25 import BM25Okapi
|
| 13 |
+
except ImportError:
|
| 14 |
+
BM25Okapi = None
|
| 15 |
+
print("Cảnh báo: Thư viện rank_bm25 chưa được cài đặt, keyword search có thể bị ảnh hưởng.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
+
class SingleDatasetRAGService:
|
| 18 |
+
"""Xử lý load dữ liệu, tạo vector và tìm kiếm FAISS+BM25 cho một dataset cụ thể."""
|
| 19 |
+
def __init__(self, model: SentenceTransformer, dataset_id: str, faiss_path: str):
|
| 20 |
+
self.model = model
|
| 21 |
+
self.DATASET_ID = dataset_id
|
| 22 |
+
self.FAISS_INDEX_PATH = faiss_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
self.dataset_records = []
|
| 25 |
+
self.image_lookup = {}
|
| 26 |
+
self.bm25_index = None
|
| 27 |
+
self.index = None
|
| 28 |
|
|
|
|
|
|
|
| 29 |
self._load_data()
|
| 30 |
+
print(f"✅ SingleDatasetRAGService ({self.DATASET_ID}) đã sẵn sàng.")
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
def _parse_image_url(self, raw) -> str:
|
| 33 |
+
"""Parse image_urls dù là list hay string JSON, trả về URL đầu tiên hợp lệ."""
|
| 34 |
+
if isinstance(raw, list):
|
| 35 |
+
for url in raw:
|
| 36 |
+
if isinstance(url, str) and url.startswith('http'):
|
| 37 |
+
return url
|
| 38 |
+
elif isinstance(raw, str) and raw.strip():
|
| 39 |
+
raw = raw.strip()
|
| 40 |
+
if raw.startswith('http'):
|
| 41 |
+
return raw
|
| 42 |
+
try:
|
| 43 |
+
parsed = ast.literal_eval(raw)
|
| 44 |
+
if isinstance(parsed, list):
|
| 45 |
+
for url in parsed:
|
| 46 |
+
if isinstance(url, str) and url.startswith('http'):
|
| 47 |
+
return url
|
| 48 |
+
except Exception:
|
| 49 |
+
pass
|
| 50 |
+
return None
|
| 51 |
|
| 52 |
def _load_data(self):
|
| 53 |
+
print(f"💾 [RAG {self.DATASET_ID}] Đang tải dataset...")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
try:
|
| 55 |
+
dataset = load_dataset(self.DATASET_ID, split="train")
|
| 56 |
|
| 57 |
+
tokenized_corpus = []
|
| 58 |
+
texts_for_embedding = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
count_img = 0
|
| 61 |
+
for item in dataset:
|
| 62 |
+
group = item.get('group', '')
|
| 63 |
+
if not group:
|
| 64 |
+
continue
|
| 65 |
+
|
| 66 |
+
group_str = str(group).strip()
|
| 67 |
+
group_lower = group_str.lower()
|
| 68 |
+
|
| 69 |
+
# 1. Image lookup
|
| 70 |
+
imgs = item.get('image_urls', [])
|
| 71 |
+
if group_lower not in self.image_lookup:
|
| 72 |
+
url = self._parse_image_url(imgs)
|
| 73 |
+
if url:
|
| 74 |
+
self.image_lookup[group_lower] = url
|
| 75 |
+
count_img += 1
|
| 76 |
+
|
| 77 |
+
# 2. Extract metadata
|
| 78 |
+
meta_raw = item.get('metadata', {})
|
| 79 |
+
if isinstance(meta_raw, str):
|
| 80 |
+
try:
|
| 81 |
+
meta_dict = ast.literal_eval(meta_raw)
|
| 82 |
+
except:
|
| 83 |
+
meta_dict = {}
|
| 84 |
+
else:
|
| 85 |
+
meta_dict = meta_raw if isinstance(meta_raw, dict) else {}
|
| 86 |
+
|
| 87 |
+
# Dùng combined_text nếu có
|
| 88 |
+
content = str(item.get('combined_text', ''))[:3000]
|
| 89 |
+
|
| 90 |
+
record = {
|
| 91 |
+
"group": group_str,
|
| 92 |
+
"combined_text": content,
|
| 93 |
+
"original_content": str(item.get('original_content', ''))[:3000],
|
| 94 |
+
"image_urls": item.get('image_urls', []),
|
| 95 |
+
"dataset_source": self.DATASET_ID, # Đánh dấu nguồn
|
| 96 |
+
**meta_dict
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
search_text = f"{group_str}. {content}"
|
| 100 |
+
|
| 101 |
+
self.dataset_records.append(record)
|
| 102 |
+
texts_for_embedding.append(search_text)
|
| 103 |
+
|
| 104 |
+
tokenized_corpus.append(search_text.lower().split())
|
| 105 |
+
|
| 106 |
+
print(f"🖼️ [RAG {self.DATASET_ID}] Image lookup built: {count_img} entries.")
|
| 107 |
|
| 108 |
+
if BM25Okapi is not None and tokenized_corpus:
|
| 109 |
+
print(f"🔍 [RAG {self.DATASET_ID}] Build BM25 cho {len(tokenized_corpus)} records...")
|
| 110 |
+
self.bm25_index = BM25Okapi(tokenized_corpus)
|
| 111 |
|
| 112 |
+
if os.path.exists(self.FAISS_INDEX_PATH):
|
| 113 |
+
print(f"📂 [RAG {self.DATASET_ID}] Đọc file FAISS từ: {self.FAISS_INDEX_PATH}")
|
| 114 |
+
self.index = faiss.read_index(self.FAISS_INDEX_PATH)
|
| 115 |
+
if self.index.ntotal != len(self.dataset_records):
|
| 116 |
+
print(f"⚠️ [RAG {self.DATASET_ID}] Kích thước FAISS không khớp, build lại...")
|
| 117 |
+
self._build_faiss(texts_for_embedding)
|
| 118 |
+
else:
|
| 119 |
+
self._build_faiss(texts_for_embedding)
|
| 120 |
|
| 121 |
except Exception as e:
|
| 122 |
+
import traceback
|
| 123 |
+
traceback.print_exc()
|
| 124 |
+
print(f"❌ [RAG {self.DATASET_ID}] Lỗi load dataset: {e}")
|
| 125 |
+
|
| 126 |
+
def _build_faiss(self, texts: List[str]):
|
| 127 |
+
print(f"🔨 [RAG {self.DATASET_ID}] Đang embed {len(texts)} văn bản...")
|
| 128 |
+
vectors = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=True).astype("float32")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
d = vectors.shape[1]
|
| 130 |
self.index = faiss.IndexFlatL2(d)
|
| 131 |
self.index.add(vectors)
|
| 132 |
+
faiss.write_index(self.index, self.FAISS_INDEX_PATH)
|
| 133 |
+
print(f"✅ [RAG {self.DATASET_ID}] Build và lưu FAISS thành công.")
|
|
|
|
|
|
|
| 134 |
|
| 135 |
+
def get_image_for_topic(self, topic_name: str) -> str:
|
| 136 |
+
if not topic_name or not self.image_lookup:
|
| 137 |
+
return None
|
| 138 |
+
topic_lower = topic_name.strip().lower()
|
| 139 |
+
if topic_lower in self.image_lookup:
|
| 140 |
+
return self.image_lookup[topic_lower]
|
| 141 |
+
for group_key, url in self.image_lookup.items():
|
| 142 |
+
if topic_lower in group_key or group_key in topic_lower:
|
| 143 |
+
return url
|
| 144 |
+
return None
|
| 145 |
+
|
| 146 |
+
def search(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
|
| 147 |
+
if not self.dataset_records:
|
| 148 |
+
return []
|
| 149 |
+
|
| 150 |
+
faiss_scores = {}
|
| 151 |
+
bm25_scores = {}
|
| 152 |
+
fetch_k = min(len(self.dataset_records), 60)
|
| 153 |
|
| 154 |
+
try:
|
| 155 |
+
if self.index:
|
| 156 |
+
query_vec = self.model.encode([query], convert_to_numpy=True).astype("float32")
|
| 157 |
+
distances, indices = self.index.search(query_vec, fetch_k)
|
| 158 |
+
for rank, idx in enumerate(indices[0]):
|
| 159 |
+
idx_int = int(idx)
|
| 160 |
+
if 0 <= idx_int < len(self.dataset_records):
|
| 161 |
+
faiss_scores[idx_int] = rank + 1
|
| 162 |
+
except Exception as e:
|
| 163 |
+
print(f"❌ Lỗi FAISS ({self.DATASET_ID}): {e}")
|
| 164 |
+
|
| 165 |
+
try:
|
| 166 |
+
if self.bm25_index:
|
| 167 |
+
tokenized_query = query.lower().split()
|
| 168 |
+
scores = self.bm25_index.get_scores(tokenized_query)
|
| 169 |
+
top_indexes = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:fetch_k]
|
| 170 |
+
for rank, idx in enumerate(top_indexes):
|
| 171 |
+
if scores[idx] > 0:
|
| 172 |
+
bm25_scores[idx] = rank + 1
|
| 173 |
+
except Exception as e:
|
| 174 |
+
print(f"❌ Lỗi BM25 ({self.DATASET_ID}): {e}")
|
| 175 |
+
|
| 176 |
+
k_rrf = 60
|
| 177 |
+
rrf_scores = {}
|
| 178 |
+
all_indices = set(list(faiss_scores.keys()) + list(bm25_scores.keys()))
|
| 179 |
+
for idx in all_indices:
|
| 180 |
+
score = 0.0
|
| 181 |
+
if idx in faiss_scores:
|
| 182 |
+
score += 1.0 / (k_rrf + faiss_scores[idx])
|
| 183 |
+
if idx in bm25_scores:
|
| 184 |
+
score += 1.0 / (k_rrf + bm25_scores[idx])
|
| 185 |
+
rrf_scores[idx] = score
|
| 186 |
+
|
| 187 |
+
sorted_indices = sorted(rrf_scores.keys(), key=lambda idx: rrf_scores[idx], reverse=True)
|
| 188 |
results = []
|
| 189 |
+
for idx in sorted_indices[:top_k]:
|
| 190 |
+
item = {
|
| 191 |
+
"score": round(rrf_scores[idx], 4),
|
| 192 |
+
"metadata": self.dataset_records[idx]
|
| 193 |
+
}
|
| 194 |
+
results.append(item)
|
| 195 |
+
|
|
|
|
| 196 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
+
class MultiDatasetRAGService:
|
| 199 |
+
"""Service tổng quản lý model Embedding chung và gọi search qua nhiều Dataset độc lập."""
|
| 200 |
+
_instance: Optional['MultiDatasetRAGService'] = None
|
| 201 |
+
|
| 202 |
def __new__(cls):
|
| 203 |
if cls._instance is None:
|
| 204 |
+
print("Khởi tạo MultiDatasetRAGService...")
|
| 205 |
+
cls._instance = super(MultiDatasetRAGService, cls).__new__(cls)
|
| 206 |
cls._instance._initialized = False
|
| 207 |
return cls._instance
|
| 208 |
|
| 209 |
def __init__(self):
|
| 210 |
if self._initialized:
|
| 211 |
return
|
| 212 |
+
|
| 213 |
+
# Dùng model tiếng Việt nhẹ thay cho all-MiniLM-L6-v2 để chuẩn hóa
|
| 214 |
+
# chung cho cả lấy Reranking (do model 1 chuyên tiếng Việt)
|
| 215 |
+
self.MODEL_NAME = "keepitreal/vietnamese-sbert"
|
| 216 |
+
print(f"🤖 [Multi RAG] Đang tải chung model embedding: {self.MODEL_NAME}...")
|
| 217 |
+
self.model = SentenceTransformer(self.MODEL_NAME)
|
| 218 |
|
| 219 |
+
# Danh sách Dataset
|
| 220 |
+
# Lưu ý: Cập nhật tên FAISS file để tránh load nhầm file cũ không cùng dimension
|
| 221 |
+
self.datasets_config = [
|
| 222 |
+
{"id": "phongnt251199/vietnam_heritage_v3", "faiss": "heritage_v3_visbert.faiss"},
|
| 223 |
+
{"id": "phongnt251199/vietnam_heritage_wiki_chunks_v1", "faiss": "wiki_chunks_v1_visbert.faiss"}
|
| 224 |
+
]
|
| 225 |
|
| 226 |
+
self.services = []
|
| 227 |
+
for cfg in self.datasets_config:
|
| 228 |
+
svc = SingleDatasetRAGService(self.model, cfg["id"], cfg["faiss"])
|
| 229 |
+
self.services.append(svc)
|
| 230 |
+
|
| 231 |
self._initialized = True
|
| 232 |
+
print(f"✅ MultiDatasetRAGService đã sẵn sàng xử lý {len(self.services)} datasets.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
+
def get_image_for_topic(self, topic_name: str) -> str:
|
| 235 |
+
for svc in self.services:
|
| 236 |
+
url = svc.get_image_for_topic(topic_name)
|
| 237 |
+
if url:
|
| 238 |
+
return url
|
| 239 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
def search(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
|
| 242 |
+
all_results = []
|
| 243 |
+
# Lấy từ cả 2 nguồn, sau đó sort chung dựa trên Score RRF
|
| 244 |
+
for svc in self.services:
|
| 245 |
+
res = svc.search(query, top_k=top_k * 2) # Lấy dư để merge
|
| 246 |
+
all_results.extend(res)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
+
# Sort tổng hợp lại theo điểm tổng
|
| 249 |
+
sorted_results = sorted(all_results, key=lambda x: x['score'], reverse=True)
|
| 250 |
+
return sorted_results[:top_k]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
# ==============================================================================
|
| 253 |
+
# KHỞI TẠO SERVICE (SINGLETON CỦA MULTI) VÀ CUNG CẤP CÁC HÀM GỐC TƯƠNG THÍCH
|
| 254 |
# ==============================================================================
|
| 255 |
+
local_rag_service = MultiDatasetRAGService()
|
| 256 |
+
|
| 257 |
+
# Maintain các object cũ để tương thích với các module khác
|
| 258 |
+
hf_rag_service = local_rag_service
|
| 259 |
|
| 260 |
def retrieve_context(query: str, k: int = 2) -> str:
|
| 261 |
+
print(f"\n>>> [Multi RAG] Tìm kiếm: {query}")
|
| 262 |
+
results = local_rag_service.search(query, k)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
return format_metadata_list_to_context(results)
|
| 264 |
|
| 265 |
def search_heritage(query: str, top_k: int = 3) -> str:
|
| 266 |
+
print(f"\n>>> [Multi RAG] Tìm kiếm: {query}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
results = local_rag_service.search(query, top_k)
|
| 268 |
return format_metadata_list_to_context(results)
|
service/reranking.py
CHANGED
|
@@ -1,13 +1,9 @@
|
|
| 1 |
-
import ast
|
| 2 |
import concurrent.futures
|
| 3 |
-
|
|
|
|
| 4 |
|
| 5 |
-
from helper import format_metadata_list_to_context
|
| 6 |
-
from rag import hf_rag_service, local_rag_service
|
| 7 |
-
|
| 8 |
-
# Load model Reranker (nhẹ, chạy CPU được)
|
| 9 |
-
#reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
| 10 |
-
reranker = CrossEncoder('Alibaba-NLP/gte-multilingual-reranker-base', trust_remote_code=True)
|
| 11 |
|
| 12 |
def reciprocal_rank_fusion(search_results_lists: list, k_rrf: int = 60):
|
| 13 |
"""
|
|
@@ -42,11 +38,11 @@ def reciprocal_rank_fusion(search_results_lists: list, k_rrf: int = 60):
|
|
| 42 |
|
| 43 |
def advanced_search(query, keyword):
|
| 44 |
"""
|
| 45 |
-
Tìm kiếm nâng cao
|
|
|
|
| 46 |
"""
|
| 47 |
try:
|
| 48 |
-
|
| 49 |
-
result_hf = []
|
| 50 |
|
| 51 |
# 1. Tạo danh sách các cụm từ tìm kiếm
|
| 52 |
search_terms = [query]
|
|
@@ -55,107 +51,107 @@ def advanced_search(query, keyword):
|
|
| 55 |
elif isinstance(keyword, str) and keyword:
|
| 56 |
search_terms.append(keyword)
|
| 57 |
|
| 58 |
-
#
|
| 59 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 60 |
-
|
| 61 |
for r in search_terms:
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
# Local RAG
|
| 66 |
-
future_local = executor.submit(local_rag_service.search, r.lower(), top_k=15)
|
| 67 |
-
future_to_source[future_local] = 'local'
|
| 68 |
|
| 69 |
-
for future in concurrent.futures.as_completed(
|
| 70 |
-
source = future_to_source[future]
|
| 71 |
try:
|
| 72 |
docs = future.result()
|
| 73 |
if docs:
|
| 74 |
-
|
| 75 |
-
result_hf.append(docs)
|
| 76 |
-
else:
|
| 77 |
-
result_wiki.append(docs)
|
| 78 |
except Exception as e:
|
| 79 |
-
print(f"Lỗi tìm kiếm
|
| 80 |
|
| 81 |
-
if not
|
| 82 |
-
return "Không tìm thấy thông tin phù hợp."
|
| 83 |
|
| 84 |
-
#
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
# Tối ưu: Giảm số lượng ứng viên rerank xuống 15 để tăng tốc độ
|
| 88 |
-
candidates_for_rerank_hf = fused_results_hf[:15]
|
| 89 |
-
pairs_to_score_hf = []
|
| 90 |
-
for item in candidates_for_rerank_hf:
|
| 91 |
-
meta = item['metadata']
|
| 92 |
-
name = meta.get('ten','')
|
| 93 |
-
desc = meta.get('mo_ta','')
|
| 94 |
-
loai_hinh = meta.get('loai_hinh', '')
|
| 95 |
-
chu_de = meta.get('chu_de', '')
|
| 96 |
-
y_nghia = meta.get('y_nghia', '')
|
| 97 |
-
constructed_text = f"Tên: {name}. Loại hình: {loai_hinh}. Chủ đề: {chu_de}. Mô tả: {desc}. Ý nghĩa: {y_nghia}"
|
| 98 |
-
pairs_to_score_hf.append([query, constructed_text])
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
sorted_docs_hf = sorted(candidates_for_rerank_hf, key=lambda x: x['rerank_score'], reverse=True)
|
| 106 |
-
|
| 107 |
-
# Xử lý nguồn Local (Wiki) RAG
|
| 108 |
-
fused_results_wiki = reciprocal_rank_fusion(result_wiki)
|
| 109 |
-
# Tối ưu: Giảm số lượng ứng viên rerank xuống 40
|
| 110 |
-
candidates_for_rerank_wiki = fused_results_wiki[:40]
|
| 111 |
-
print("candidates_for_rerank_wiki:", len(candidates_for_rerank_wiki))
|
| 112 |
-
pairs_to_score_wiki = []
|
| 113 |
-
for item in candidates_for_rerank_wiki:
|
| 114 |
meta = item['metadata']
|
| 115 |
-
|
| 116 |
-
|
|
|
|
| 117 |
constructed_text = f"Tên: {name}. Mô tả: {desc}"
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
sorted_docs_wiki = []
|
| 121 |
-
if pairs_to_score_wiki:
|
| 122 |
-
scores_wiki = reranker.predict(pairs_to_score_wiki)
|
| 123 |
-
for i, doc in enumerate(candidates_for_rerank_wiki):
|
| 124 |
-
doc['rerank_score'] = scores_wiki[i]
|
| 125 |
-
sorted_docs_wiki = sorted(candidates_for_rerank_wiki, key=lambda x: x['rerank_score'], reverse=True)
|
| 126 |
|
| 127 |
-
#
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
|
| 134 |
-
#
|
| 135 |
-
|
| 136 |
-
for doc in top_3_wiki:
|
| 137 |
-
metadata = doc.get('metadata', {})
|
| 138 |
-
name = metadata.get('group', 'Không rõ tên')
|
| 139 |
-
section = metadata.get('section', '')
|
| 140 |
-
description = metadata.get('content', 'Không có mô tả.')
|
| 141 |
-
wiki_context_parts.append(f"[Nguồn Wiki - Tên]: {name}\n[Nguồn Wiki - Mô tả]: {description}" + (f"\n[Nguồn Wiki - Section]: {section}" if section else ""))
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
-
if
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
if hf_context and hf_context.strip() != "Không có dữ liệu ngữ cảnh.":
|
| 153 |
-
final_context_parts.append(hf_context)
|
| 154 |
|
| 155 |
-
|
| 156 |
-
return "Không tìm thấy thông tin phù hợp."
|
| 157 |
|
| 158 |
-
return "\n\n".join(final_context_parts).strip()
|
| 159 |
except Exception as e:
|
|
|
|
|
|
|
| 160 |
print(f"Lỗi trong advanced_search: {e}")
|
| 161 |
-
return "Đã xảy ra lỗi trong quá trình tìm kiếm nâng cao."
|
|
|
|
|
|
|
| 1 |
import concurrent.futures
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
|
| 5 |
+
from .helper import format_metadata_list_to_context
|
| 6 |
+
from .rag import hf_rag_service, local_rag_service
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
def reciprocal_rank_fusion(search_results_lists: list, k_rrf: int = 60):
|
| 9 |
"""
|
|
|
|
| 38 |
|
| 39 |
def advanced_search(query, keyword):
|
| 40 |
"""
|
| 41 |
+
Tìm kiếm nâng cao: search song song các cụm từ → RRF fusion → rerank top-K (Bằng embedding model) → format context.
|
| 42 |
+
Trả về tuple: (context_string, image_url | None)
|
| 43 |
"""
|
| 44 |
try:
|
| 45 |
+
result_all = []
|
|
|
|
| 46 |
|
| 47 |
# 1. Tạo danh sách các cụm từ tìm kiếm
|
| 48 |
search_terms = [query]
|
|
|
|
| 51 |
elif isinstance(keyword, str) and keyword:
|
| 52 |
search_terms.append(keyword)
|
| 53 |
|
| 54 |
+
# 2. Tìm kiếm song song qua Manager (chứa 2 Dataset) với mỗi cụm từ
|
| 55 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 56 |
+
future_to_term = {}
|
| 57 |
for r in search_terms:
|
| 58 |
+
if not r.strip(): continue
|
| 59 |
+
future = executor.submit(local_rag_service.search, r.lower(), top_k=15)
|
| 60 |
+
future_to_term[future] = r
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
for future in concurrent.futures.as_completed(future_to_term):
|
|
|
|
| 63 |
try:
|
| 64 |
docs = future.result()
|
| 65 |
if docs:
|
| 66 |
+
result_all.append(docs)
|
|
|
|
|
|
|
|
|
|
| 67 |
except Exception as e:
|
| 68 |
+
print(f"Lỗi tìm kiếm concurrent: {e}")
|
| 69 |
|
| 70 |
+
if not result_all:
|
| 71 |
+
return "Không tìm thấy thông tin phù hợp.", None
|
| 72 |
|
| 73 |
+
# 3. RRF Fusion sau đó Rerank tổng cộng
|
| 74 |
+
fused_results = reciprocal_rank_fusion(result_all)
|
| 75 |
+
print("candidates_for_rerank (after deduplication):", len(fused_results))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
# Tối ưu: Cắt giảm xuống 20 ứng viên để Reranking
|
| 78 |
+
candidates_for_rerank = fused_results[:20]
|
| 79 |
+
|
| 80 |
+
pairs_to_score = []
|
| 81 |
+
for item in candidates_for_rerank:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
meta = item['metadata']
|
| 83 |
+
# Lấy các trường đã được đồng bộ trong dataset (group, combined_text, original_content, dataset_source)
|
| 84 |
+
name = meta.get('group', '') or meta.get('ten', '')
|
| 85 |
+
desc = meta.get('combined_text') or meta.get('original_content') or meta.get('content', '')
|
| 86 |
constructed_text = f"Tên: {name}. Mô tả: {desc}"
|
| 87 |
+
pairs_to_score.append([query, constructed_text])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
# === MỚI: Reranking bằng Cosine Similarity của model Embedding (Bỏ Cross-Encoder) ===
|
| 90 |
+
sorted_docs = []
|
| 91 |
+
if pairs_to_score:
|
| 92 |
+
print(f"Reranking {len(pairs_to_score)} candidates bằng Embedding Model siêu nhanh...")
|
| 93 |
+
try:
|
| 94 |
+
# Trích xuất model embeddings ban đầu
|
| 95 |
+
q_text = query
|
| 96 |
+
docs_texts = [p[1] for p in pairs_to_score]
|
| 97 |
+
|
| 98 |
+
# Encode text trực tiếp
|
| 99 |
+
q_emb = local_rag_service.model.encode([q_text], convert_to_tensor=True)
|
| 100 |
+
docs_emb = local_rag_service.model.encode(docs_texts, convert_to_tensor=True)
|
| 101 |
+
|
| 102 |
+
# Tính khoảng cách Cosine
|
| 103 |
+
cos_scores = F.cosine_similarity(q_emb, docs_emb).cpu().tolist()
|
| 104 |
+
|
| 105 |
+
# Gán điểm và sort
|
| 106 |
+
for i, doc in enumerate(candidates_for_rerank):
|
| 107 |
+
doc['rerank_score'] = cos_scores[i]
|
| 108 |
+
|
| 109 |
+
sorted_docs = sorted(candidates_for_rerank, key=lambda x: x['rerank_score'], reverse=True)
|
| 110 |
+
except Exception as e:
|
| 111 |
+
import traceback
|
| 112 |
+
traceback.print_exc()
|
| 113 |
+
print(f"Lỗi khi dùng embedding model de rerank, fallback sang RRF score: {e}")
|
| 114 |
+
sorted_docs = candidates_for_rerank
|
| 115 |
+
else:
|
| 116 |
+
sorted_docs = candidates_for_rerank
|
| 117 |
+
|
| 118 |
+
# Lấy top 5 kết quả tốt nhất
|
| 119 |
+
top_5 = sorted_docs[:5]
|
| 120 |
+
|
| 121 |
+
# 4. Định dạng lại nội dung ngữ cảnh chung (Format chung)
|
| 122 |
+
final_context = format_metadata_list_to_context(top_5)
|
| 123 |
|
| 124 |
+
if not final_context or final_context.strip() == "Không có dữ liệu ngữ cảnh.":
|
| 125 |
+
return "Không tìm thấy thông tin phù hợp.", None
|
| 126 |
|
| 127 |
+
# 5. Trích xuất image_url từ các dataset đã gộp
|
| 128 |
+
image_url = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
+
# (a) Từ topic name của top results (Ưu tiên group name)
|
| 131 |
+
for doc in top_5:
|
| 132 |
+
topic = (doc.get('metadata', {}).get('group') or doc.get('metadata', {}).get('ten') or '').strip()
|
| 133 |
+
if topic:
|
| 134 |
+
image_url = local_rag_service.get_image_for_topic(topic)
|
| 135 |
+
if image_url:
|
| 136 |
+
print(f"[IMAGE] Found via topic: '{topic}'")
|
| 137 |
+
break
|
| 138 |
+
|
| 139 |
+
# (b) Fallback: từ search terms
|
| 140 |
+
if not image_url:
|
| 141 |
+
for term in search_terms:
|
| 142 |
+
if term and len(term) > 2:
|
| 143 |
+
image_url = local_rag_service.get_image_for_topic(term)
|
| 144 |
+
if image_url:
|
| 145 |
+
print(f"[IMAGE] Found via keyword: '{term}'")
|
| 146 |
+
break
|
| 147 |
|
| 148 |
+
if image_url:
|
| 149 |
+
print(f"[IMAGE] final: {image_url}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
+
return final_context, image_url
|
|
|
|
| 152 |
|
|
|
|
| 153 |
except Exception as e:
|
| 154 |
+
import traceback
|
| 155 |
+
traceback.print_exc()
|
| 156 |
print(f"Lỗi trong advanced_search: {e}")
|
| 157 |
+
return "Đã xảy ra lỗi trong quá trình tìm kiếm nâng cao.", None
|
service/rewrite.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
from groq import Groq
|
| 2 |
import torch
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 4 |
-
from reranking import advanced_search
|
| 5 |
import os
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
|
| 8 |
load_dotenv()
|
| 9 |
|
| 10 |
#groq_api_key = os.environ.get("GROQ_API_KEY")
|
| 11 |
-
groq_api_key =
|
| 12 |
|
| 13 |
if not groq_api_key:
|
| 14 |
raise ValueError("GROQ_API_KEY environment variable is not set")
|
|
@@ -244,6 +244,49 @@ class QueryRewriter:
|
|
| 244 |
hypothetical_answer = completion.choices[0].message.content.strip()
|
| 245 |
return hypothetical_answer
|
| 246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
def chain_of_thought(self,question,ans):
|
| 248 |
COT_SYSTEM_PROMPT = """Bạn là một trợ lý AI chuyên biên tập và kiểm tra tính liên quan của câu trả lời (Relevance-Checking Editor) chuyên về văn hoá, lịch sử và địa điểm Việt Nam.
|
| 249 |
NHIỆM VỤ CỐT LÕI: Dựa vào "Câu hỏi gốc", hãy lọc lại "Câu trả lời được tạo ra" để đảm bảo mọi thông tin trong câu trả lời cuối cùng đều liên quan trực tiếp đến chủ thể trong câu hỏi.
|
|
@@ -252,7 +295,7 @@ class QueryRewriter:
|
|
| 252 |
1. **Nếu câu trả lời mang ý nghĩa xã giao (chào hỏi, cảm ơn, v.v...) hoặc tôi không tìm thấy thông tin về vấn đề này trong tài liệu được cung cấp. **, hãy trả về nguyên văn câu trả lời ban đầu mà không chỉnh sửa gì.
|
| 253 |
2. **Xác định chủ thể chính**: Đọc kỹ "Câu hỏi gốc" để xác định đối tượng, địa danh, hoặc khái niệm chính mà người dùng đang hỏi.
|
| 254 |
3. **Kiểm tra và lọc**: Rà soát từng câu, từng ý trong "Câu trả lời được tạo ra".
|
| 255 |
-
* **Giữ lại**: Chỉ giữ lại những thông tin mô tả, giải thích, hoặc liệt kê các chi tiết liên quan đến chủ thể chính của câu hỏi.
|
| 256 |
* **Loại bỏ**: Xóa bỏ hoàn toàn bất kỳ thông tin nào nói về một chủ thể khác, không liên quan.
|
| 257 |
4. **Tổng hợp lại**: Viết lại câu trả lời cuối cùng một cách mạch lạc, tự nhiên từ những thông tin đã được lọc.
|
| 258 |
|
|
@@ -288,92 +331,95 @@ class QueryRewriter:
|
|
| 288 |
{"role": "user", "content": user_content}
|
| 289 |
]
|
| 290 |
)
|
| 291 |
-
|
| 292 |
return completion.choices[0].message.content.strip()
|
| 293 |
|
| 294 |
-
def ask_with_context(self,question,history):
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
|
|
|
| 298 |
print(f"\n--- keyword: {keyword} ---")
|
| 299 |
-
|
| 300 |
-
#
|
| 301 |
-
q_rewrite = self.rewrite_query(question,history)
|
| 302 |
print(f"\n--- q_rewrite: {q_rewrite} ---")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
fake_answer = self.generate_hypothetical_answer(q_rewrite)
|
| 304 |
-
print(f"\n--- fake_answer: {fake_answer} ---")
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
-
|
| 313 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
-
|
| 316 |
-
1. **Đọc hiểu ngữ cảnh**: Bạn cần đọc kỹ cả thông tin từ [Nguồn Wiki] và [TỔNG QUAN] (nếu có) để có cái nhìn toàn diện.
|
| 317 |
-
2. **Tổng hợp câu trả lời**:
|
| 318 |
-
- Kết hợp thông tin từ cả hai nguồn để câu trả lời đầy đủ và chính xác nhất.
|
| 319 |
-
- Nếu [TỔNG QUAN] cung cấp thông tin cơ bản (tên, địa điểm, thời gian), hãy dùng nó để giới thiệu.
|
| 320 |
-
- Nếu [Nguồn Wiki] cung cấp chi tiết mô tả, lịch sử, ý nghĩa, hãy dùng nó để giải thích sâu hơn.
|
| 321 |
-
3. **Xử lý câu hỏi cụ thể**:
|
| 322 |
-
- Với câu hỏi so sánh: Tìm điểm giống và khác nhau trong ngữ cảnh của các đối tượng.
|
| 323 |
-
- Với câu hỏi liệt kê: Liệt kê các đối tượng có trong ngữ cảnh phù hợp với câu hỏi.
|
| 324 |
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
"""
|
| 330 |
-
|
| 331 |
-
# Tạo nội dung user prompt: Ghép context và câu hỏi gốc
|
| 332 |
-
user_content = f"""### Context:
|
| 333 |
-
{p}
|
| 334 |
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
#
|
| 351 |
-
|
| 352 |
-
# # Đưa input vào đúng device của model_base
|
| 353 |
-
# model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
| 354 |
-
|
| 355 |
-
# # --- SỬA LỖI TẠI ĐÂY ---
|
| 356 |
-
# # Dùng model_base (Qwen) thay vì model (SentenceTransformer)
|
| 357 |
-
# generated_ids = self.model.generate(
|
| 358 |
-
# **model_inputs, # Lưu ý: sửa cả tên biến input cho khớp (model_inputs thay vì model_base1 cho rõ nghĩa)
|
| 359 |
-
# max_new_tokens=512,
|
| 360 |
-
# temperature=0.1,
|
| 361 |
-
# top_p=0.9
|
| 362 |
-
# )
|
| 363 |
-
|
| 364 |
-
# generated_ids = [
|
| 365 |
-
# output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
| 366 |
-
# ]
|
| 367 |
-
|
| 368 |
-
# answer_bot = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 369 |
-
|
| 370 |
completion = client.chat.completions.create(
|
| 371 |
model="meta-llama/llama-4-scout-17b-16e-instruct",
|
| 372 |
-
messages
|
| 373 |
{"role": "system", "content": RAG_SYSTEM_PROMPT},
|
| 374 |
{"role": "user", "content": user_content}
|
| 375 |
]
|
| 376 |
)
|
| 377 |
answer_bot = completion.choices[0].message.content.strip()
|
| 378 |
-
|
| 379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from groq import Groq
|
| 2 |
import torch
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 4 |
+
from .reranking import advanced_search
|
| 5 |
import os
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
|
| 8 |
load_dotenv()
|
| 9 |
|
| 10 |
#groq_api_key = os.environ.get("GROQ_API_KEY")
|
| 11 |
+
groq_api_key = "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"
|
| 12 |
|
| 13 |
if not groq_api_key:
|
| 14 |
raise ValueError("GROQ_API_KEY environment variable is not set")
|
|
|
|
| 244 |
hypothetical_answer = completion.choices[0].message.content.strip()
|
| 245 |
return hypothetical_answer
|
| 246 |
|
| 247 |
+
def detect_image_intent(self, query: str, history: list) -> bool:
|
| 248 |
+
"""Phát hiện xem câu hỏi có ý định muốn xem hình ảnh/minh họa không."""
|
| 249 |
+
conversation_str = ""
|
| 250 |
+
for turn in history:
|
| 251 |
+
conversation_str += f"{turn['role'].capitalize()}: {turn['content']}\n"
|
| 252 |
+
|
| 253 |
+
IMAGE_INTENT_PROMPT = """Bạn là một API phân loại ý định (Intent Classifier). Nhiệm vụ duy nhất của bạn là xác định xem câu hỏi của người dùng có muốn xem HÌNH ẢNH, MINH HỌA hay không.
|
| 254 |
+
|
| 255 |
+
Các ý định MUỐN XEM ẢNH bao gồm:
|
| 256 |
+
- Yêu cầu trực tiếp: "cho tôi xem ảnh", "hình ảnh", "ảnh", "photo", "picture"
|
| 257 |
+
- Câu hỏi ngoại hình/hình thức: "trông như thế nào", "có hình dạng gì", "diện mạo", "nhìn như thế nào"
|
| 258 |
+
- Câu hỏi giới thiệu tổng quát: "giới thiệu về", "tìm hiểu về", "cho tôi biết về", "kể về"
|
| 259 |
+
- Câu hỏi mô tả: "mô tả", "đặc điểm", "trình bày"
|
| 260 |
+
|
| 261 |
+
Các ý định KHÔNG muốn xem ảnh:
|
| 262 |
+
- Hỏi sự kiện/năm tháng: "được công nhận năm nào", "xảy ra khi nào"
|
| 263 |
+
- Hỏi so sánh hàm ý trừu tượng: "khác nhau ở điểm nào", "ý nghĩa là gì"
|
| 264 |
+
- Hỏi lý do/nguyên nhân: "tại sao", "vì sao"
|
| 265 |
+
- Chào hỏi xã giao: "xin chào", "cảm ơn"
|
| 266 |
+
|
| 267 |
+
QUY TẮC OUTPUT:
|
| 268 |
+
- Chỉ trả về một từ duy nhất: YES hoặc NO
|
| 269 |
+
- KHÔNG giải thích"""
|
| 270 |
+
|
| 271 |
+
user_content = f"""Lịch sử:\n{conversation_str}\nCâu hỏi: \"{query}\"\nOutput:"""
|
| 272 |
+
|
| 273 |
+
try:
|
| 274 |
+
completion = client.chat.completions.create(
|
| 275 |
+
model="meta-llama/llama-4-scout-17b-16e-instruct",
|
| 276 |
+
messages=[
|
| 277 |
+
{"role": "system", "content": IMAGE_INTENT_PROMPT},
|
| 278 |
+
{"role": "user", "content": user_content}
|
| 279 |
+
],
|
| 280 |
+
max_tokens=5,
|
| 281 |
+
temperature=0.0
|
| 282 |
+
)
|
| 283 |
+
result = completion.choices[0].message.content.strip().upper()
|
| 284 |
+
print(f"\n--- image_intent: {result} ---")
|
| 285 |
+
return result.startswith("YES")
|
| 286 |
+
except Exception as e:
|
| 287 |
+
print(f"\n--- image_intent error: {e} ---")
|
| 288 |
+
return False
|
| 289 |
+
|
| 290 |
def chain_of_thought(self,question,ans):
|
| 291 |
COT_SYSTEM_PROMPT = """Bạn là một trợ lý AI chuyên biên tập và kiểm tra tính liên quan của câu trả lời (Relevance-Checking Editor) chuyên về văn hoá, lịch sử và địa điểm Việt Nam.
|
| 292 |
NHIỆM VỤ CỐT LÕI: Dựa vào "Câu hỏi gốc", hãy lọc lại "Câu trả lời được tạo ra" để đảm bảo mọi thông tin trong câu trả lời cuối cùng đều liên quan trực tiếp đến chủ thể trong câu hỏi.
|
|
|
|
| 295 |
1. **Nếu câu trả lời mang ý nghĩa xã giao (chào hỏi, cảm ơn, v.v...) hoặc tôi không tìm thấy thông tin về vấn đề này trong tài liệu được cung cấp. **, hãy trả về nguyên văn câu trả lời ban đầu mà không chỉnh sửa gì.
|
| 296 |
2. **Xác định chủ thể chính**: Đọc kỹ "Câu hỏi gốc" để xác định đối tượng, địa danh, hoặc khái niệm chính mà người dùng đang hỏi.
|
| 297 |
3. **Kiểm tra và lọc**: Rà soát từng câu, từng ý trong "Câu trả lời được tạo ra".
|
| 298 |
+
* **Giữ lại**: Chỉ giữ lại những thông tin mô tả, giải thích, hoặc liệt kê các chi tiết liên quan, url hình ảnh (https://raw.githubusercontent.com/T-Phong/WikiImage/main/Aodai_%C3%81o%20d%C3%A0i_Introduction_2.png), v.v. liên quan đến chủ thể chính của câu hỏi.
|
| 299 |
* **Loại bỏ**: Xóa bỏ hoàn toàn bất kỳ thông tin nào nói về một chủ thể khác, không liên quan.
|
| 300 |
4. **Tổng hợp lại**: Viết lại câu trả lời cuối cùng một cách mạch lạc, tự nhiên từ những thông tin đã được lọc.
|
| 301 |
|
|
|
|
| 331 |
{"role": "user", "content": user_content}
|
| 332 |
]
|
| 333 |
)
|
| 334 |
+
|
| 335 |
return completion.choices[0].message.content.strip()
|
| 336 |
|
| 337 |
+
def ask_with_context(self, question, history):
|
| 338 |
+
"""Pipeline chính: trả về tuple (answer_text, image_url | None)"""
|
| 339 |
+
|
| 340 |
+
# 1. Trích xuất keyword
|
| 341 |
+
keyword = self.keyword(question, history)
|
| 342 |
print(f"\n--- keyword: {keyword} ---")
|
| 343 |
+
|
| 344 |
+
# 2. Viết lại câu hỏi
|
| 345 |
+
q_rewrite = self.rewrite_query(question, history)
|
| 346 |
print(f"\n--- q_rewrite: {q_rewrite} ---")
|
| 347 |
+
|
| 348 |
+
# 3. Tạo câu trả lời giả định (HyDE)
|
| 349 |
+
# HyDE chỉ dùng làm keyword PHỤ để mở rộng search — KHÔNG làm query chính
|
| 350 |
+
# vì embedding của HyDE dài không match với chunk ngắn trong vector DB
|
| 351 |
fake_answer = self.generate_hypothetical_answer(q_rewrite)
|
| 352 |
+
print(f"\n--- fake_answer (first 150): {str(fake_answer)[:150]} ---")
|
| 353 |
+
|
| 354 |
+
# 4. Tìm kiếm nâng cao
|
| 355 |
+
# Query chính = q_rewrite (ngắn gọn, khớp embedding DB tốt hơn)
|
| 356 |
+
# Keywords = keyword (từ LLM extractor) + đầu HyDE (150 ký tự) làm term mở rộng
|
| 357 |
+
hyde_snippet = (fake_answer or '').strip()[:150]
|
| 358 |
+
search_keywords = keyword + ([hyde_snippet] if hyde_snippet else [])
|
| 359 |
+
context, candidate_image_url = advanced_search(q_rewrite, search_keywords)
|
| 360 |
+
print(f"\n--- context (200 chars): {str(context)[:200]} ---")
|
| 361 |
+
print(f"\n--- candidate_image_url: {candidate_image_url} ---")
|
| 362 |
+
|
| 363 |
+
# 5. Phát hiện intent xem ảnh SỚM (trước khi gọi LLM)
|
| 364 |
+
wants_image = self.detect_image_intent(question, history)
|
| 365 |
+
print(f"\n--- image_intent: {'YES' if wants_image else 'NO'} ---")
|
| 366 |
+
image_url = candidate_image_url if (wants_image and candidate_image_url) else None
|
| 367 |
|
| 368 |
+
RAG_SYSTEM_PROMPT = """Bạn là một trợ lý AI chuyên trả lời các câu hỏi về văn hóa các dân tộc Việt Nam.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
+
NGUỒN DỮ LIỆU DUY NHẤT: Chỉ được dùng thông tin nằm trong thẻ <context>...</context> bên dưới.
|
| 371 |
+
TUYỆT ĐỐI KHÔNG dùng kiến thức ngoài context, kể cả khi bạn biết câu trả lời.
|
| 372 |
+
|
| 373 |
+
QUY TẮC XỬ LÝ:
|
| 374 |
+
1. Đọc kỹ toàn bộ <context>.
|
| 375 |
+
2. Kiểm tra xem context có đề cập đến chủ thể trong câu hỏi không:
|
| 376 |
+
- NẾU CÓ → tổng hợp từ [TỔNG QUAN] để trả lời đầy đủ, chính xác.
|
| 377 |
+
- NẾU KHÔNG CÓ hoặc thông tin không liên quan → trả lời ngay: "Xin lỗi, tôi không tìm thấy thông tin về vấn đề này trong tài liệu được cung cấp." — KHÔNG được suy luận hay bổ sung thêm.
|
| 378 |
+
3. Xử lý câu hỏi cụ thể:
|
| 379 |
+
- So sánh: tìm điểm gi���ng/khác trong context.
|
| 380 |
+
- Liệt kê: liệt kê đúng những gì context đề cập.
|
| 381 |
+
4. Hình ảnh: nếu context có "### Hình ảnh minh họa", nhúng ảnh: . Không đề cập ảnh nếu không có.
|
| 382 |
+
|
| 383 |
+
FORMAT MARKDOWN:
|
| 384 |
+
- **in đậm** cho tên riêng, địa danh, khái niệm quan trọng.
|
| 385 |
+
- `##` tiêu đề chính, `###` tiêu đề phụ (chỉ khi câu trả lời dài, nhiều phần).
|
| 386 |
+
- Danh sách `-` khi liệt kê. Câu trả lời ngắn/xã giao: chỉ text thuần, không heading.
|
| 387 |
+
- TUYỆT ĐỐI KHÔNG đề cập đến việc tìm kiếm ảnh trên internet.
|
| 388 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
| 390 |
+
# 6. Build user_content — wrap context trong thẻ XML để model phân biệt rõ ranh giới
|
| 391 |
+
image_section = ""
|
| 392 |
+
if image_url:
|
| 393 |
+
image_section = f"\n\n### Hình ảnh minh họa\n\n"
|
| 394 |
+
|
| 395 |
+
user_content = f"""<context>
|
| 396 |
+
{context}{image_section}
|
| 397 |
+
</context>
|
| 398 |
+
|
| 399 |
+
### Câu hỏi:
|
| 400 |
+
{q_rewrite}
|
| 401 |
+
"""
|
| 402 |
+
print(f"\n--- ---")
|
| 403 |
+
print(f"\n--- user_content: {str(user_content)} ---")
|
| 404 |
+
print(f"\n--- ---")
|
| 405 |
+
# 7. Sinh câu trả lời
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
completion = client.chat.completions.create(
|
| 407 |
model="meta-llama/llama-4-scout-17b-16e-instruct",
|
| 408 |
+
messages=[
|
| 409 |
{"role": "system", "content": RAG_SYSTEM_PROMPT},
|
| 410 |
{"role": "user", "content": user_content}
|
| 411 |
]
|
| 412 |
)
|
| 413 |
answer_bot = completion.choices[0].message.content.strip()
|
| 414 |
+
print(f"\n--- ---")
|
| 415 |
+
print("answer before cot:",answer_bot)
|
| 416 |
+
print(f"\n--- ---")
|
| 417 |
+
# 8. Lọc câu trả lời qua CoT
|
| 418 |
+
|
| 419 |
+
final_answer = self.chain_of_thought(q_rewrite, answer_bot)
|
| 420 |
+
print(f"\n--- ---")
|
| 421 |
+
print("answer:", answer_bot)
|
| 422 |
+
print(f"\n--- ---")
|
| 423 |
+
return final_answer, image_url
|
| 424 |
+
|
| 425 |
+
|