change commit
Browse files- core/gradio/{gradio_rag_qwen.py → gradio_rag.py} +33 -24
- core/hash_file/hash_data_goc.py +76 -95
- core/hash_file/hash_file.py +14 -3
- core/preprocessing/docling_processor.py +69 -9
- core/preprocessing/pdf_parser.py +11 -8
- core/rag/chunk.py +67 -54
- core/rag/embedding_model.py +24 -10
- core/rag/generator.py +16 -7
- core/rag/retrival.py +61 -57
- core/rag/vector_store.py +52 -30
- evaluation/eval_utils.py +19 -4
- evaluation/ragas_eval.py +32 -57
- scripts/build_data.py +43 -41
- scripts/run_eval.py +7 -5
- test/parse_data_hash_test.py +0 -102
- test/test_chunk.py +17 -7
core/gradio/{gradio_rag_qwen.py → gradio_rag.py}
RENAMED
|
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
import os
|
| 3 |
import sys
|
|
@@ -8,6 +13,7 @@ import gradio as gr
|
|
| 8 |
from dotenv import find_dotenv, load_dotenv
|
| 9 |
from openai import OpenAI
|
| 10 |
|
|
|
|
| 11 |
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 12 |
if str(REPO_ROOT) not in sys.path:
|
| 13 |
sys.path.insert(0, str(REPO_ROOT))
|
|
@@ -15,14 +21,18 @@ if str(REPO_ROOT) not in sys.path:
|
|
| 15 |
|
| 16 |
@dataclass
|
| 17 |
class GradioConfig:
|
|
|
|
| 18 |
server_host: str = "127.0.0.1"
|
| 19 |
server_port: int = 7860
|
| 20 |
|
|
|
|
| 21 |
def _load_env() -> None:
|
|
|
|
| 22 |
dotenv_path = find_dotenv(usecwd=True) or ""
|
| 23 |
load_dotenv(dotenv_path=dotenv_path or None, override=False)
|
| 24 |
|
| 25 |
|
|
|
|
| 26 |
from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
|
| 27 |
from core.rag.vector_store import ChromaConfig, ChromaVectorDB
|
| 28 |
from core.rag.retrival import Retriever, RetrievalMode, get_retrieval_config
|
|
@@ -30,19 +40,20 @@ from core.rag.generator import RAGContextBuilder, build_context, build_prompt, S
|
|
| 30 |
|
| 31 |
_load_env()
|
| 32 |
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
#
|
| 36 |
-
LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b")
|
| 37 |
-
LLM_API_BASE = "https://api.groq.com/openai/v1"
|
| 38 |
-
LLM_API_KEY_ENV = "GROQ_API_KEY"
|
| 39 |
-
|
| 40 |
-
# Load retrieval config
|
| 41 |
GRADIO_CFG = GradioConfig()
|
| 42 |
RETRIEVAL_CFG = get_retrieval_config()
|
| 43 |
|
| 44 |
|
| 45 |
class AppState:
|
|
|
|
|
|
|
| 46 |
def __init__(self) -> None:
|
| 47 |
self.db: Optional[ChromaVectorDB] = None
|
| 48 |
self.retriever: Optional[Retriever] = None
|
|
@@ -50,39 +61,38 @@ class AppState:
|
|
| 50 |
self.client: Optional[OpenAI] = None
|
| 51 |
|
| 52 |
|
| 53 |
-
STATE = AppState()
|
| 54 |
|
| 55 |
|
| 56 |
def _init_resources() -> None:
|
|
|
|
| 57 |
if STATE.db is not None:
|
| 58 |
return
|
| 59 |
|
| 60 |
print(f" Đang khởi tạo Database & Re-ranker...")
|
| 61 |
print(f" Retrieval Mode: {RETRIEVAL_MODE.value}")
|
| 62 |
|
|
|
|
| 63 |
emb = QwenEmbeddings(EmbeddingConfig())
|
| 64 |
-
|
| 65 |
db_cfg = ChromaConfig()
|
| 66 |
|
| 67 |
-
STATE.db = ChromaVectorDB(
|
| 68 |
-
embedder=emb,
|
| 69 |
-
config=db_cfg,
|
| 70 |
-
)
|
| 71 |
STATE.retriever = Retriever(vector_db=STATE.db)
|
| 72 |
|
| 73 |
-
# LLM
|
| 74 |
api_key = (os.getenv(LLM_API_KEY_ENV) or "").strip()
|
| 75 |
if not api_key:
|
| 76 |
raise RuntimeError(f"Missing {LLM_API_KEY_ENV}")
|
| 77 |
STATE.client = OpenAI(api_key=api_key, base_url=LLM_API_BASE)
|
| 78 |
|
| 79 |
-
#
|
| 80 |
STATE.rag_builder = RAGContextBuilder(retriever=STATE.retriever)
|
| 81 |
|
| 82 |
print(" Đã sẵn sàng!")
|
| 83 |
|
| 84 |
|
| 85 |
def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
|
|
| 86 |
_init_resources()
|
| 87 |
|
| 88 |
assert STATE.db is not None
|
|
@@ -90,7 +100,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 90 |
assert STATE.retriever is not None
|
| 91 |
assert STATE.rag_builder is not None
|
| 92 |
|
| 93 |
-
#
|
| 94 |
prepared = STATE.rag_builder.retrieve_and_prepare(
|
| 95 |
message,
|
| 96 |
k=RETRIEVAL_CFG.top_k,
|
|
@@ -103,7 +113,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 103 |
yield "Xin lỗi, tôi không tìm thấy thông tin phù hợp trong dữ liệu."
|
| 104 |
return
|
| 105 |
|
| 106 |
-
#
|
| 107 |
completion = STATE.client.chat.completions.create(
|
| 108 |
model=LLM_MODEL,
|
| 109 |
messages=[{"role": "user", "content": prepared["prompt"]}],
|
|
@@ -112,6 +122,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 112 |
stream=True,
|
| 113 |
)
|
| 114 |
|
|
|
|
| 115 |
acc = ""
|
| 116 |
for chunk in completion:
|
| 117 |
delta = getattr(chunk.choices[0].delta, "content", "") or ""
|
|
@@ -119,7 +130,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 119 |
acc += delta
|
| 120 |
yield acc
|
| 121 |
|
| 122 |
-
#
|
| 123 |
debug_info = f"\n\n---\n\n**Retrieved (Top {len(results)} | Mode: {RETRIEVAL_MODE.value})**\n\n"
|
| 124 |
for i, r in enumerate(results, 1):
|
| 125 |
md = r.get("metadata", {})
|
|
@@ -127,7 +138,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 127 |
rerank_score = r.get("rerank_score")
|
| 128 |
distance = r.get("distance")
|
| 129 |
|
| 130 |
-
#
|
| 131 |
source = md.get("source_file", "N/A")
|
| 132 |
doc_type = md.get("document_type", "N/A")
|
| 133 |
header = md.get("header_path", "")
|
|
@@ -135,7 +146,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 135 |
program = md.get("program_name", "")
|
| 136 |
issued_year = md.get("issued_year", "")
|
| 137 |
|
| 138 |
-
#
|
| 139 |
score_info = ""
|
| 140 |
if rerank_score is not None:
|
| 141 |
score_info += f"Rerank: `{rerank_score:.4f}` "
|
|
@@ -144,7 +155,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 144 |
if not score_info:
|
| 145 |
score_info = f"Rank: `{r.get('final_rank', i)}`"
|
| 146 |
|
| 147 |
-
#
|
| 148 |
meta_parts = [f"**Nguồn:** {source}", f"**Loại:** {doc_type}"]
|
| 149 |
if issued_year:
|
| 150 |
meta_parts.append(f"**Năm:** {issued_year}")
|
|
@@ -162,9 +173,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
|
| 162 |
yield acc + debug_info
|
| 163 |
|
| 164 |
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
# Create Gradio interface
|
| 168 |
demo = gr.ChatInterface(
|
| 169 |
fn=rag_chat,
|
| 170 |
title=f"HUST RAG Assistant",
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Giao diện Gradio cho hệ thống RAG - Trợ lý học vụ HUST.
|
| 3 |
+
Cho phép người dùng đặt câu hỏi và nhận câu trả lời từ hệ thống RAG.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
from __future__ import annotations
|
| 7 |
import os
|
| 8 |
import sys
|
|
|
|
| 13 |
from dotenv import find_dotenv, load_dotenv
|
| 14 |
from openai import OpenAI
|
| 15 |
|
| 16 |
+
# Thêm thư mục gốc vào Python path
|
| 17 |
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 18 |
if str(REPO_ROOT) not in sys.path:
|
| 19 |
sys.path.insert(0, str(REPO_ROOT))
|
|
|
|
| 21 |
|
| 22 |
@dataclass
class GradioConfig:
    """Gradio server settings: host and port."""

    # Host name / address for the Gradio server (loopback by default).
    server_host: str = "127.0.0.1"
    # Port for the Gradio server.
    server_port: int = 7860
|
| 27 |
|
| 28 |
+
|
| 29 |
def _load_env() -> None:
    """Load environment variables from a ``.env`` file, if one is found.

    The lookup starts from the current working directory (``usecwd=True``).
    Variables that are already set in the process environment are left
    untouched (``override=False``).
    """
    # An empty result from find_dotenv() is turned into None so load_dotenv()
    # receives "no explicit path" rather than an empty string.
    dotenv_path = find_dotenv(usecwd=True) or None
    load_dotenv(dotenv_path=dotenv_path, override=False)
|
| 33 |
|
| 34 |
|
| 35 |
+
# Import các module RAG
|
| 36 |
from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
|
| 37 |
from core.rag.vector_store import ChromaConfig, ChromaVectorDB
|
| 38 |
from core.rag.retrival import Retriever, RetrievalMode, get_retrieval_config
|
|
|
|
| 40 |
|
| 41 |
_load_env()
|
| 42 |
|
| 43 |
+
# Cấu hình retrieval và LLM
|
| 44 |
+
RETRIEVAL_MODE = RetrievalMode.HYBRID_RERANK # Chế độ tìm kiếm
|
| 45 |
+
LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b") # Model LLM
|
| 46 |
+
LLM_API_BASE = "https://api.groq.com/openai/v1" # Groq API endpoint
|
| 47 |
+
LLM_API_KEY_ENV = "GROQ_API_KEY" # Biến môi trường chứa API key
|
| 48 |
|
| 49 |
+
# Khởi tạo cấu hình
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
GRADIO_CFG = GradioConfig()
|
| 51 |
RETRIEVAL_CFG = get_retrieval_config()
|
| 52 |
|
| 53 |
|
| 54 |
class AppState:
|
| 55 |
+
"""Quản lý trạng thái ứng dụng: database, retriever, LLM client."""
|
| 56 |
+
|
| 57 |
def __init__(self) -> None:
|
| 58 |
self.db: Optional[ChromaVectorDB] = None
|
| 59 |
self.retriever: Optional[Retriever] = None
|
|
|
|
| 61 |
self.client: Optional[OpenAI] = None
|
| 62 |
|
| 63 |
|
| 64 |
+
STATE = AppState() # Singleton state
|
| 65 |
|
| 66 |
|
| 67 |
def _init_resources() -> None:
    """Lazily create the shared DB, retriever, LLM client and RAG builder.

    The function is a no-op once ``STATE.db`` has been set. All resources are
    built into local variables first and only then published on ``STATE``,
    with ``STATE.db`` assigned last, so a failure part-way through (for
    example a missing API key) never leaves ``STATE`` half-initialised with
    the "already initialised" guard tripped.

    Raises:
        RuntimeError: if the environment variable named by
            ``LLM_API_KEY_ENV`` is unset or blank.
    """
    if STATE.db is not None:
        return

    # Validate the cheap precondition before loading any heavy resources.
    api_key = (os.getenv(LLM_API_KEY_ENV) or "").strip()
    if not api_key:
        raise RuntimeError(f"Missing {LLM_API_KEY_ENV}")

    print(" Đang khởi tạo Database & Re-ranker...")
    print(f" Retrieval Mode: {RETRIEVAL_MODE.value}")

    # Embedding model + vector database.
    emb = QwenEmbeddings(EmbeddingConfig())
    db_cfg = ChromaConfig()
    db = ChromaVectorDB(embedder=emb, config=db_cfg)
    retriever = Retriever(vector_db=db)

    # OpenAI-compatible LLM client pointed at LLM_API_BASE.
    client = OpenAI(api_key=api_key, base_url=LLM_API_BASE)

    # RAG context builder on top of the retriever.
    rag_builder = RAGContextBuilder(retriever=retriever)

    # Publish everything; ``db`` goes last because it is the init sentinel.
    STATE.retriever = retriever
    STATE.client = client
    STATE.rag_builder = rag_builder
    STATE.db = db

    print(" Đã sẵn sàng!")
|
| 92 |
|
| 93 |
|
| 94 |
def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
|
| 95 |
+
"""Xử lý chat: retrieve documents -> gọi LLM -> stream response"""
|
| 96 |
_init_resources()
|
| 97 |
|
| 98 |
assert STATE.db is not None
|
|
|
|
| 100 |
assert STATE.retriever is not None
|
| 101 |
assert STATE.rag_builder is not None
|
| 102 |
|
| 103 |
+
# Retrieve và chuẩn bị context
|
| 104 |
prepared = STATE.rag_builder.retrieve_and_prepare(
|
| 105 |
message,
|
| 106 |
k=RETRIEVAL_CFG.top_k,
|
|
|
|
| 113 |
yield "Xin lỗi, tôi không tìm thấy thông tin phù hợp trong dữ liệu."
|
| 114 |
return
|
| 115 |
|
| 116 |
+
# Gọi LLM với streaming
|
| 117 |
completion = STATE.client.chat.completions.create(
|
| 118 |
model=LLM_MODEL,
|
| 119 |
messages=[{"role": "user", "content": prepared["prompt"]}],
|
|
|
|
| 122 |
stream=True,
|
| 123 |
)
|
| 124 |
|
| 125 |
+
# Stream response
|
| 126 |
acc = ""
|
| 127 |
for chunk in completion:
|
| 128 |
delta = getattr(chunk.choices[0].delta, "content", "") or ""
|
|
|
|
| 130 |
acc += delta
|
| 131 |
yield acc
|
| 132 |
|
| 133 |
+
# Thêm debug info về các documents đã retrieve
|
| 134 |
debug_info = f"\n\n---\n\n**Retrieved (Top {len(results)} | Mode: {RETRIEVAL_MODE.value})**\n\n"
|
| 135 |
for i, r in enumerate(results, 1):
|
| 136 |
md = r.get("metadata", {})
|
|
|
|
| 138 |
rerank_score = r.get("rerank_score")
|
| 139 |
distance = r.get("distance")
|
| 140 |
|
| 141 |
+
# Trích xuất metadata
|
| 142 |
source = md.get("source_file", "N/A")
|
| 143 |
doc_type = md.get("document_type", "N/A")
|
| 144 |
header = md.get("header_path", "")
|
|
|
|
| 146 |
program = md.get("program_name", "")
|
| 147 |
issued_year = md.get("issued_year", "")
|
| 148 |
|
| 149 |
+
# Format score
|
| 150 |
score_info = ""
|
| 151 |
if rerank_score is not None:
|
| 152 |
score_info += f"Rerank: `{rerank_score:.4f}` "
|
|
|
|
| 155 |
if not score_info:
|
| 156 |
score_info = f"Rank: `{r.get('final_rank', i)}`"
|
| 157 |
|
| 158 |
+
# Format metadata
|
| 159 |
meta_parts = [f"**Nguồn:** {source}", f"**Loại:** {doc_type}"]
|
| 160 |
if issued_year:
|
| 161 |
meta_parts.append(f"**Năm:** {issued_year}")
|
|
|
|
| 173 |
yield acc + debug_info
|
| 174 |
|
| 175 |
|
| 176 |
+
# Tạo giao diện Gradio
|
|
|
|
|
|
|
| 177 |
demo = gr.ChatInterface(
|
| 178 |
fn=rag_chat,
|
| 179 |
title=f"HUST RAG Assistant",
|
core/hash_file/hash_data_goc.py
CHANGED
|
@@ -1,13 +1,11 @@
|
|
| 1 |
import sys
|
| 2 |
-
import os
|
| 3 |
import json
|
| 4 |
import shutil
|
| 5 |
from pathlib import Path
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
sys.path.insert(0, str(project_root))
|
| 11 |
|
| 12 |
from core.hash_file.hash_file import HashProcessor
|
| 13 |
|
|
@@ -16,130 +14,113 @@ HF_RAW_PDF_REPO = "hungnha/Do_An_Dataset"
|
|
| 16 |
|
| 17 |
|
| 18 |
def download_from_hf(cache_dir: Path) -> Path:
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
except ImportError:
|
| 22 |
-
print("Installing huggingface_hub...")
|
| 23 |
-
os.system("pip install huggingface_hub")
|
| 24 |
-
from huggingface_hub import snapshot_download
|
| 25 |
|
|
|
|
| 26 |
if cache_dir.exists() and any(cache_dir.iterdir()):
|
| 27 |
print(f"Cache đã tồn tại: {cache_dir}")
|
| 28 |
return cache_dir / "data_rag"
|
| 29 |
|
| 30 |
-
print(f"Đang tải
|
| 31 |
snapshot_download(
|
| 32 |
repo_id=HF_RAW_PDF_REPO,
|
| 33 |
repo_type="dataset",
|
| 34 |
local_dir=str(cache_dir),
|
| 35 |
local_dir_use_symlinks=False,
|
| 36 |
)
|
| 37 |
-
print("Tải xong!")
|
| 38 |
return cache_dir / "data_rag"
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def main():
|
| 42 |
import argparse
|
| 43 |
-
parser = argparse.ArgumentParser()
|
| 44 |
-
parser.add_argument("--source", type=str, help="Đường dẫn local tới
|
| 45 |
parser.add_argument("--download-only", action="store_true", help="Chỉ tải về, không copy")
|
| 46 |
args = parser.parse_args()
|
| 47 |
|
| 48 |
-
data_dir =
|
| 49 |
files_dir = data_dir / "files"
|
| 50 |
files_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 51 |
|
| 52 |
-
# Xác định nguồn
|
| 53 |
if args.source:
|
| 54 |
source_root = Path(args.source)
|
| 55 |
if not source_root.exists():
|
| 56 |
-
print(f"
|
| 57 |
-
return
|
| 58 |
else:
|
| 59 |
# Tải từ HuggingFace
|
| 60 |
-
|
| 61 |
-
source_root = download_from_hf(cache_dir)
|
| 62 |
-
|
| 63 |
if args.download_only:
|
| 64 |
-
print(f"PDF đã cache tại: {source_root}")
|
| 65 |
-
return
|
| 66 |
|
| 67 |
if not source_root.exists():
|
| 68 |
-
print(f"Không tìm thấy thư mục PDF: {source_root}")
|
| 69 |
-
return
|
| 70 |
-
|
| 71 |
-
hash_processor = HashProcessor(verbose=False)
|
| 72 |
-
hash_file_path = data_dir / "hash_data_goc_index.json"
|
| 73 |
-
|
| 74 |
-
existing_hashes = {}
|
| 75 |
-
if hash_file_path.exists():
|
| 76 |
-
with open(hash_file_path, 'r', encoding='utf-8') as f:
|
| 77 |
-
data = json.load(f)
|
| 78 |
-
existing_hashes = {item['filename']: item['hash'] for item in data.get('train', [])}
|
| 79 |
-
print(f"Đã tải {len(existing_hashes)} hash từ index cũ")
|
| 80 |
|
| 81 |
-
|
|
|
|
|
|
|
| 82 |
|
| 83 |
-
|
| 84 |
-
print(f"Tìm thấy {len(pdf_files)} files PDF\n")
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
filename = str(relative_path)
|
| 93 |
-
dest_path = files_dir / relative_path
|
| 94 |
-
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
| 95 |
-
|
| 96 |
-
# Kiểm tra file đã tồn tại và hash khớp chưa
|
| 97 |
-
if dest_path.exists() and filename in existing_hashes:
|
| 98 |
-
current_hash = hash_processor.get_file_hash(str(dest_path))
|
| 99 |
-
if current_hash == existing_hashes[filename]:
|
| 100 |
-
hash_results.append({
|
| 101 |
-
'filename': filename,
|
| 102 |
-
'hash': current_hash,
|
| 103 |
-
'index': idx
|
| 104 |
-
})
|
| 105 |
-
skipped += 1
|
| 106 |
-
continue
|
| 107 |
-
|
| 108 |
-
try:
|
| 109 |
-
shutil.copy2(source_path, dest_path)
|
| 110 |
-
|
| 111 |
-
file_hash = hash_processor.get_file_hash(str(dest_path))
|
| 112 |
-
if file_hash is None:
|
| 113 |
-
print(f"Lỗi tính hash cho file {filename}")
|
| 114 |
-
continue
|
| 115 |
-
|
| 116 |
-
hash_results.append({
|
| 117 |
-
'filename': filename,
|
| 118 |
-
'hash': file_hash,
|
| 119 |
-
'index': idx
|
| 120 |
-
})
|
| 121 |
-
processed += 1
|
| 122 |
-
|
| 123 |
-
if (idx + 1) % 10 == 0:
|
| 124 |
-
print(f"Processed {idx + 1}/{len(pdf_files)} files")
|
| 125 |
-
|
| 126 |
-
except Exception as e:
|
| 127 |
-
print(f"Lỗi khi xử lý file {filename}: {e}")
|
| 128 |
-
continue
|
| 129 |
-
|
| 130 |
-
output_data = {
|
| 131 |
-
'train': hash_results,
|
| 132 |
-
'total_files': len(hash_results)
|
| 133 |
-
}
|
| 134 |
-
|
| 135 |
-
with open(hash_file_path, 'w', encoding='utf-8') as f:
|
| 136 |
-
json.dump(output_data, f, ensure_ascii=False, indent=2)
|
| 137 |
-
|
| 138 |
-
print(f"\nHoàn tất!")
|
| 139 |
-
print(f"Tổng số file: {len(hash_results)}")
|
| 140 |
-
print(f"Đã xử lý mới: {processed}")
|
| 141 |
-
print(f"Đã bỏ qua (trùng hash): {skipped}")
|
| 142 |
-
print(f"File index: {hash_file_path}")
|
| 143 |
|
| 144 |
|
| 145 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import sys
|
|
|
|
| 2 |
import json
|
| 3 |
import shutil
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
| 7 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 8 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
| 9 |
|
| 10 |
from core.hash_file.hash_file import HashProcessor
|
| 11 |
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def download_from_hf(cache_dir: Path) -> Path:
    """Download the raw PDF dataset from Hugging Face into ``cache_dir``.

    Args:
        cache_dir: Local directory used as the download target / cache.

    Returns:
        ``cache_dir / "data_rag"``. The path is returned whether or not it
        actually exists; callers are expected to check.
    """
    data_root = cache_dir / "data_rag"

    # Any existing content counts as a populated cache; it is not re-validated.
    if cache_dir.exists() and any(cache_dir.iterdir()):
        print(f"Cache đã tồn tại: {cache_dir}")
        return data_root

    # Imported only when a download is really needed, so a cache hit does not
    # require huggingface_hub to be installed.
    from huggingface_hub import snapshot_download

    print(f"Đang tải từ HuggingFace: {HF_RAW_PDF_REPO}")
    snapshot_download(
        repo_id=HF_RAW_PDF_REPO,
        repo_type="dataset",
        local_dir=str(cache_dir),
        local_dir_use_symlinks=False,
    )
    return data_root
|
| 33 |
|
| 34 |
|
| 35 |
+
def load_existing_hashes(path: Path) -> dict:
    """Read a previously written hash index from a JSON file.

    The file is expected to contain an object whose ``"train"`` entry is a
    list of items with ``"filename"`` and ``"hash"`` keys.

    Args:
        path: Location of the JSON index file.

    Returns:
        A ``{filename: hash}`` mapping. An empty dict is returned when the
        file does not exist or cannot be read/parsed in the expected shape.
    """
    if not path.exists():
        return {}
    try:
        data = json.loads(path.read_text(encoding='utf-8'))
        return {item['filename']: item['hash'] for item in data.get('train', [])}
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
        # Best effort: a broken index only means everything is re-processed,
        # but say so instead of hiding the problem. ValueError covers both
        # JSON and text-decoding errors.
        print(f"Warning: ignoring unreadable hash index {path}: {exc}", file=sys.stderr)
        return {}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def process_pdfs(source_root: Path, dest_dir: Path, existing_hashes: dict) -> tuple:
    """Copy every PDF under ``source_root`` into ``dest_dir`` and hash it.

    A file is skipped (not copied again) only when its destination copy
    exists and both the source and the destination still hash to the value
    recorded in ``existing_hashes``. Hashing the source as well means an
    updated source PDF is picked up instead of being hidden by a stale copy.

    Args:
        source_root: Directory searched recursively for ``*.pdf`` files.
        dest_dir: Directory that receives the copies; the relative layout of
            ``source_root`` is preserved.
        existing_hashes: ``{relative_path: hash}`` from a previous run.

    Returns:
        ``(results, processed, skipped)`` where ``results`` is a list of
        ``{'filename', 'hash', 'index'}`` dicts, ``processed`` counts files
        copied in this run and ``skipped`` counts unchanged files.
    """
    hasher = HashProcessor(verbose=False)
    # Sorted so the per-file ``index`` is reproducible from run to run.
    pdf_files = sorted(source_root.rglob("*.pdf"))
    total = len(pdf_files)
    print(f"Tìm thấy {total} file PDF\n")

    results, processed, skipped = [], 0, 0

    for idx, src in enumerate(pdf_files):
        rel_path = str(src.relative_to(source_root))
        dest = dest_dir / rel_path
        dest.parent.mkdir(parents=True, exist_ok=True)

        # Unchanged means: known file, copy present, and neither side drifted
        # from the recorded hash. A failed hash never equals a recorded hash
        # here, so hashing failures fall through to a fresh copy.
        recorded = existing_hashes.get(rel_path)
        unchanged = (
            recorded is not None
            and dest.exists()
            and hasher.get_file_hash(str(src)) == recorded
            and hasher.get_file_hash(str(dest)) == recorded
        )

        if unchanged:
            results.append({'filename': rel_path, 'hash': recorded, 'index': idx})
            skipped += 1
        else:
            # Copy, then hash the copy. One bad file must not abort the run.
            try:
                shutil.copy2(src, dest)
                file_hash = hasher.get_file_hash(str(dest))
                if file_hash:
                    results.append({'filename': rel_path, 'hash': file_hash, 'index': idx})
                    processed += 1
                else:
                    # No usable hash: report it rather than silently dropping
                    # the file from the index.
                    print(f"Lỗi tính hash cho file {rel_path}")
            except Exception as e:
                print(f"Lỗi: {rel_path} - {e}")

        # Progress is reported for every file, skipped or copied.
        if (idx + 1) % 20 == 0:
            print(f"Tiến độ: {idx + 1}/{total}")

    return results, processed, skipped
|
| 82 |
+
|
| 83 |
+
|
| 84 |
def main():
    """Command-line entry point: fetch the PDFs and (re)build the hash index.

    Options:
        --source PATH     Use a local directory of PDFs instead of downloading
                          from Hugging Face.
        --download-only   After downloading, report the cache location and
                          stop without copying (only consulted on the
                          download path).

    PDFs are copied to ``<PROJECT_ROOT>/data/files`` and the index is written
    to ``<PROJECT_ROOT>/data/hash_data_goc_index.json``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Tải PDF và tạo hash index")
    parser.add_argument("--source", type=str, help="Đường dẫn local tới PDFs (bỏ qua tải HF)")
    parser.add_argument("--download-only", action="store_true", help="Chỉ tải về, không copy")
    args = parser.parse_args()

    data_dir = PROJECT_ROOT / "data"
    files_dir = data_dir / "files"
    files_dir.mkdir(parents=True, exist_ok=True)
    hash_file = data_dir / "hash_data_goc_index.json"

    # Work out where the PDFs come from.
    if args.source:
        source_root = Path(args.source)
        if not source_root.exists():
            print(f"Không tìm thấy thư mục nguồn: {source_root}")
            return
    else:
        # Download from Hugging Face (or reuse the local cache).
        source_root = download_from_hf(data_dir / "raw_pdf_cache")
        if args.download_only:
            print(f"PDF đã cache tại: {source_root}")
            return

    if not source_root.exists():
        print(f"Không tìm thấy thư mục PDF: {source_root}")
        return

    # Copy + hash, reusing whatever the previous index already knows.
    existing = load_existing_hashes(hash_file)
    print(f"Đã tải {len(existing)} hash từ index cũ")

    results, processed, skipped = process_pdfs(source_root, files_dir, existing)

    # Persist the new index.
    index_payload = {
        'train': results,
        'total_files': len(results),
    }
    hash_file.write_text(
        json.dumps(index_payload, ensure_ascii=False, indent=2),
        encoding='utf-8',
    )

    print(f"\nHoàn tất! Tổng: {len(results)} | Mới: {processed} | Bỏ qua: {skipped}")
    print(f"File index: {hash_file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
|
| 126 |
if __name__ == "__main__":
|
core/hash_file/hash_file.py
CHANGED
|
@@ -9,19 +9,23 @@ from pathlib import Path
|
|
| 9 |
from typing import Dict, List, Optional
|
| 10 |
from datetime import datetime
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
CHUNK_SIZE = 8192 #
|
| 14 |
DEFAULT_FILE_EXTENSION = '.pdf'
|
| 15 |
|
|
|
|
| 16 |
class HashProcessor:
|
|
|
|
| 17 |
|
| 18 |
def __init__(self, verbose: bool = True):
|
|
|
|
| 19 |
self.verbose = verbose
|
| 20 |
self.logger = logging.getLogger(__name__)
|
| 21 |
if not verbose:
|
| 22 |
self.logger.setLevel(logging.WARNING)
|
| 23 |
|
| 24 |
def get_file_hash(self, path: str) -> Optional[str]:
|
|
|
|
| 25 |
h = hashlib.sha256()
|
| 26 |
try:
|
| 27 |
with open(path, "rb") as f:
|
|
@@ -41,6 +45,7 @@ class HashProcessor:
|
|
| 41 |
file_extension: str = DEFAULT_FILE_EXTENSION,
|
| 42 |
recursive: bool = False
|
| 43 |
) -> Dict[str, List[Dict[str, str]]]:
|
|
|
|
| 44 |
source_path = Path(source_dir)
|
| 45 |
if not source_path.exists():
|
| 46 |
raise FileNotFoundError(f"Thư mục không tồn tại: {source_dir}")
|
|
@@ -73,6 +78,7 @@ class HashProcessor:
|
|
| 73 |
return hash_to_files
|
| 74 |
|
| 75 |
def load_processed_index(self, index_file: str) -> Dict:
|
|
|
|
| 76 |
if os.path.exists(index_file):
|
| 77 |
try:
|
| 78 |
with open(index_file, "r", encoding="utf-8") as f:
|
|
@@ -86,6 +92,10 @@ class HashProcessor:
|
|
| 86 |
return {}
|
| 87 |
|
| 88 |
def save_processed_index(self, index_file: str, processed_hashes: Dict) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
temp_name = None
|
| 90 |
try:
|
| 91 |
os.makedirs(os.path.dirname(index_file), exist_ok=True)
|
|
@@ -106,8 +116,9 @@ class HashProcessor:
|
|
| 106 |
os.remove(temp_name)
|
| 107 |
|
| 108 |
def get_current_timestamp(self) -> str:
|
|
|
|
| 109 |
return datetime.now().isoformat()
|
| 110 |
|
| 111 |
def get_string_hash(self, text: str) -> str:
|
|
|
|
| 112 |
return hashlib.sha256(text.encode('utf-8')).hexdigest()
|
| 113 |
-
|
|
|
|
| 9 |
from typing import Dict, List, Optional
|
| 10 |
from datetime import datetime
|
| 11 |
|
| 12 |
+
# Hằng số
|
| 13 |
+
CHUNK_SIZE = 8192 # Đọc file theo chunk 8KB
|
| 14 |
DEFAULT_FILE_EXTENSION = '.pdf'
|
| 15 |
|
| 16 |
+
|
| 17 |
class HashProcessor:
|
| 18 |
+
"""Lớp xử lý hash cho files - dùng để phát hiện thay đổi và tránh xử lý lại."""
|
| 19 |
|
| 20 |
def __init__(self, verbose: bool = True):
|
| 21 |
+
"""Khởi tạo HashProcessor."""
|
| 22 |
self.verbose = verbose
|
| 23 |
self.logger = logging.getLogger(__name__)
|
| 24 |
if not verbose:
|
| 25 |
self.logger.setLevel(logging.WARNING)
|
| 26 |
|
| 27 |
def get_file_hash(self, path: str) -> Optional[str]:
|
| 28 |
+
"""Tính SHA256 hash của một file."""
|
| 29 |
h = hashlib.sha256()
|
| 30 |
try:
|
| 31 |
with open(path, "rb") as f:
|
|
|
|
| 45 |
file_extension: str = DEFAULT_FILE_EXTENSION,
|
| 46 |
recursive: bool = False
|
| 47 |
) -> Dict[str, List[Dict[str, str]]]:
|
| 48 |
+
"""Quét thư mục và tính hash cho tất cả files."""
|
| 49 |
source_path = Path(source_dir)
|
| 50 |
if not source_path.exists():
|
| 51 |
raise FileNotFoundError(f"Thư mục không tồn tại: {source_dir}")
|
|
|
|
| 78 |
return hash_to_files
|
| 79 |
|
| 80 |
def load_processed_index(self, index_file: str) -> Dict:
|
| 81 |
+
"""Đọc file index đã xử lý từ JSON."""
|
| 82 |
if os.path.exists(index_file):
|
| 83 |
try:
|
| 84 |
with open(index_file, "r", encoding="utf-8") as f:
|
|
|
|
| 92 |
return {}
|
| 93 |
|
| 94 |
def save_processed_index(self, index_file: str, processed_hashes: Dict) -> None:
|
| 95 |
+
"""Lưu index đã xử lý vào file JSON (atomic write).
|
| 96 |
+
|
| 97 |
+
Ghi vào file tạm trước, sau đó rename để đảm bảo an toàn.
|
| 98 |
+
"""
|
| 99 |
temp_name = None
|
| 100 |
try:
|
| 101 |
os.makedirs(os.path.dirname(index_file), exist_ok=True)
|
|
|
|
| 116 |
os.remove(temp_name)
|
| 117 |
|
| 118 |
def get_current_timestamp(self) -> str:
    """Return the current local time as an ISO 8601 string.

    The value comes from ``datetime.now()`` without a timezone, so the
    string carries no UTC offset.
    """
    return datetime.now().isoformat()
|
| 121 |
|
| 122 |
def get_string_hash(self, text: str) -> str:
    """Return the SHA-256 hex digest of ``text`` encoded as UTF-8."""
    return hashlib.sha256(text.encode('utf-8')).hexdigest()
|
|
|
core/preprocessing/docling_processor.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import gc
|
|
|
|
| 4 |
import signal
|
| 5 |
import logging
|
| 6 |
from datetime import datetime
|
|
@@ -12,19 +13,35 @@ from docling.datamodel.pipeline_options import PdfPipelineOptions, TableStructur
|
|
| 12 |
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
|
| 13 |
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
class DoclingProcessor:
|
|
|
|
|
|
|
| 17 |
def __init__(self, output_dir: str, use_ocr: bool = True, timeout: int = 300, images_scale: float = 3.0):
|
|
|
|
| 18 |
self.output_dir = output_dir
|
| 19 |
self.timeout = timeout
|
| 20 |
self.logger = logging.getLogger(__name__)
|
|
|
|
| 21 |
os.makedirs(output_dir, exist_ok=True)
|
| 22 |
|
| 23 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
opts = PdfPipelineOptions(do_ocr=use_ocr, do_table_structure=True)
|
| 25 |
opts.table_structure_options = TableStructureOptions(do_cell_matching=True, mode=TableFormerMode.ACCURATE)
|
| 26 |
opts.images_scale = images_scale
|
| 27 |
|
|
|
|
| 28 |
if use_ocr:
|
| 29 |
ocr = EasyOcrOptions()
|
| 30 |
ocr.lang = ["vi"]
|
|
@@ -34,39 +51,69 @@ class DoclingProcessor:
|
|
| 34 |
self.converter = DocumentConverter(format_options={
|
| 35 |
InputFormat.PDF: FormatOption(backend=PyPdfiumDocumentBackend, pipeline_cls=StandardPdfPipeline, pipeline_options=opts)
|
| 36 |
})
|
| 37 |
-
self.logger.info(f"
|
| 38 |
|
| 39 |
def clean_markdown(self, text: str) -> str:
|
|
|
|
| 40 |
text = re.sub(r'\n\s*Trang\s+\d+\s*\n', '\n', text)
|
| 41 |
return re.sub(r'\n{3,}', '\n\n', text).strip()
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
def parse_document(self, file_path: str) -> str | None:
|
|
|
|
| 44 |
if not os.path.exists(file_path):
|
| 45 |
return None
|
| 46 |
filename = os.path.basename(file_path)
|
| 47 |
try:
|
|
|
|
| 48 |
signal.signal(signal.SIGALRM, lambda s, f: (_ for _ in ()).throw(TimeoutError()))
|
| 49 |
signal.alarm(self.timeout)
|
|
|
|
| 50 |
result = self.converter.convert(file_path)
|
| 51 |
md = result.document.export_to_markdown(image_placeholder="")
|
| 52 |
signal.alarm(0)
|
|
|
|
| 53 |
md = self.clean_markdown(md)
|
|
|
|
| 54 |
return f"---\nfilename: {filename}\nfilepath: {file_path}\npage_count: {len(result.document.pages)}\nprocessed_at: {datetime.now().isoformat()}\n---\n\n{md}"
|
| 55 |
except TimeoutError:
|
| 56 |
-
self.logger.warning(f"
|
| 57 |
signal.alarm(0)
|
| 58 |
return None
|
| 59 |
except Exception as e:
|
| 60 |
-
self.logger.error(f"
|
| 61 |
signal.alarm(0)
|
| 62 |
return None
|
| 63 |
|
| 64 |
def parse_directory(self, source_dir: str) -> dict:
|
|
|
|
| 65 |
source_path = Path(source_dir)
|
| 66 |
pdf_files = list(source_path.rglob("*.pdf"))
|
| 67 |
-
self.logger.info(f"
|
| 68 |
|
| 69 |
results = {"total": len(pdf_files), "parsed": 0, "skipped": 0, "errors": 0}
|
|
|
|
| 70 |
for i, fp in enumerate(pdf_files):
|
| 71 |
try:
|
| 72 |
rel = fp.relative_to(source_path)
|
|
@@ -75,20 +122,33 @@ class DoclingProcessor:
|
|
| 75 |
out = Path(self.output_dir) / rel.with_suffix(".md")
|
| 76 |
out.parent.mkdir(parents=True, exist_ok=True)
|
| 77 |
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
| 79 |
results["skipped"] += 1
|
| 80 |
continue
|
| 81 |
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
| 83 |
if md:
|
| 84 |
out.write_text(md, encoding="utf-8")
|
| 85 |
results["parsed"] += 1
|
|
|
|
|
|
|
|
|
|
| 86 |
else:
|
| 87 |
results["errors"] += 1
|
| 88 |
|
|
|
|
| 89 |
if (i + 1) % 10 == 0:
|
| 90 |
gc.collect()
|
| 91 |
-
self.logger.info(f"
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
self.logger.info(f"
|
| 94 |
return results
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import gc
|
| 4 |
+
import sys
|
| 5 |
import signal
|
| 6 |
import logging
|
| 7 |
from datetime import datetime
|
|
|
|
| 13 |
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
|
| 14 |
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
|
| 15 |
|
| 16 |
+
# Thêm project root vào path để import HashProcessor
|
| 17 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
| 18 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 19 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 20 |
+
|
| 21 |
+
from core.hash_file.hash_file import HashProcessor
|
| 22 |
+
|
| 23 |
|
| 24 |
class DoclingProcessor:
|
| 25 |
+
"""Chuyển đổi PDF sang Markdown bằng Docling."""
|
| 26 |
+
|
| 27 |
def __init__(self, output_dir: str, use_ocr: bool = True, timeout: int = 300, images_scale: float = 3.0):
|
| 28 |
+
"""Khởi tạo processor với cấu hình OCR và table extraction."""
|
| 29 |
self.output_dir = output_dir
|
| 30 |
self.timeout = timeout
|
| 31 |
self.logger = logging.getLogger(__name__)
|
| 32 |
+
self.hasher = HashProcessor(verbose=False)
|
| 33 |
os.makedirs(output_dir, exist_ok=True)
|
| 34 |
|
| 35 |
+
# File lưu hash index
|
| 36 |
+
self.hash_index_path = Path(output_dir) / "docling_hash_index.json"
|
| 37 |
+
self.hash_index = self.hasher.load_processed_index(str(self.hash_index_path))
|
| 38 |
+
|
| 39 |
+
# Cấu hình pipeline PDF
|
| 40 |
opts = PdfPipelineOptions(do_ocr=use_ocr, do_table_structure=True)
|
| 41 |
opts.table_structure_options = TableStructureOptions(do_cell_matching=True, mode=TableFormerMode.ACCURATE)
|
| 42 |
opts.images_scale = images_scale
|
| 43 |
|
| 44 |
+
# Cấu hình OCR tiếng Việt
|
| 45 |
if use_ocr:
|
| 46 |
ocr = EasyOcrOptions()
|
| 47 |
ocr.lang = ["vi"]
|
|
|
|
| 51 |
self.converter = DocumentConverter(format_options={
|
| 52 |
InputFormat.PDF: FormatOption(backend=PyPdfiumDocumentBackend, pipeline_cls=StandardPdfPipeline, pipeline_options=opts)
|
| 53 |
})
|
| 54 |
+
self.logger.info(f"Docling | OCR={use_ocr} | Table=accurate | Scale={images_scale} | timeout={timeout}s")
|
| 55 |
|
| 56 |
def clean_markdown(self, text: str) -> str:
    """Remove page-number lines and collapse excess blank lines.

    Args:
        text: Markdown text to clean.

    Returns:
        The text with every newline-delimited "Trang <number>" marker
        replaced by a single newline, runs of three or more newlines
        reduced to two, and leading/trailing whitespace stripped.
    """
    # Drop "Trang N" page markers that are surrounded by newlines.
    without_page_marks = re.sub(r'\n\s*Trang\s+\d+\s*\n', '\n', text)
    # Leave at most one blank line between paragraphs.
    collapsed = re.sub(r'\n{3,}', '\n\n', without_page_marks)
    return collapsed.strip()
|
| 60 |
|
| 61 |
+
def _should_process(self, pdf_path: str, output_path: Path) -> bool:
|
| 62 |
+
"""Kiểm tra xem file PDF có cần xử lý lại không (dựa trên hash)."""
|
| 63 |
+
# Nếu output chưa tồn tại -> cần xử lý
|
| 64 |
+
if not output_path.exists():
|
| 65 |
+
return True
|
| 66 |
+
|
| 67 |
+
# Tính hash file PDF hiện tại
|
| 68 |
+
current_hash = self.hasher.get_file_hash(pdf_path)
|
| 69 |
+
if not current_hash:
|
| 70 |
+
return True
|
| 71 |
+
|
| 72 |
+
# So sánh với hash đã lưu
|
| 73 |
+
saved_hash = self.hash_index.get(pdf_path, {}).get("hash")
|
| 74 |
+
return current_hash != saved_hash
|
| 75 |
+
|
| 76 |
+
def _save_hash(self, pdf_path: str, file_hash: str) -> None:
|
| 77 |
+
"""Lưu hash của file đã xử lý vào index."""
|
| 78 |
+
self.hash_index[pdf_path] = {
|
| 79 |
+
"hash": file_hash,
|
| 80 |
+
"processed_at": self.hasher.get_current_timestamp()
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
def parse_document(self, file_path: str) -> str | None:
|
| 84 |
+
"""Chuyển đổi 1 file PDF sang Markdown với timeout."""
|
| 85 |
if not os.path.exists(file_path):
|
| 86 |
return None
|
| 87 |
filename = os.path.basename(file_path)
|
| 88 |
try:
|
| 89 |
+
# Đặt timeout để tránh treo
|
| 90 |
signal.signal(signal.SIGALRM, lambda s, f: (_ for _ in ()).throw(TimeoutError()))
|
| 91 |
signal.alarm(self.timeout)
|
| 92 |
+
|
| 93 |
result = self.converter.convert(file_path)
|
| 94 |
md = result.document.export_to_markdown(image_placeholder="")
|
| 95 |
signal.alarm(0)
|
| 96 |
+
|
| 97 |
md = self.clean_markdown(md)
|
| 98 |
+
# Thêm frontmatter metadata
|
| 99 |
return f"---\nfilename: {filename}\nfilepath: {file_path}\npage_count: {len(result.document.pages)}\nprocessed_at: {datetime.now().isoformat()}\n---\n\n{md}"
|
| 100 |
except TimeoutError:
|
| 101 |
+
self.logger.warning(f"Timeout: {filename}")
|
| 102 |
signal.alarm(0)
|
| 103 |
return None
|
| 104 |
except Exception as e:
|
| 105 |
+
self.logger.error(f"Lỗi: {filename}: {e}")
|
| 106 |
signal.alarm(0)
|
| 107 |
return None
|
| 108 |
|
| 109 |
def parse_directory(self, source_dir: str) -> dict:
|
| 110 |
+
"""Xử lý toàn bộ thư mục PDF, bỏ qua file không thay đổi (dựa trên hash)."""
|
| 111 |
source_path = Path(source_dir)
|
| 112 |
pdf_files = list(source_path.rglob("*.pdf"))
|
| 113 |
+
self.logger.info(f"Tìm thấy {len(pdf_files)} file PDF trong {source_dir}")
|
| 114 |
|
| 115 |
results = {"total": len(pdf_files), "parsed": 0, "skipped": 0, "errors": 0}
|
| 116 |
+
|
| 117 |
for i, fp in enumerate(pdf_files):
|
| 118 |
try:
|
| 119 |
rel = fp.relative_to(source_path)
|
|
|
|
| 122 |
out = Path(self.output_dir) / rel.with_suffix(".md")
|
| 123 |
out.parent.mkdir(parents=True, exist_ok=True)
|
| 124 |
|
| 125 |
+
pdf_path = str(fp)
|
| 126 |
+
|
| 127 |
+
# Kiểm tra hash để quyết định có cần xử lý không
|
| 128 |
+
if not self._should_process(pdf_path, out):
|
| 129 |
results["skipped"] += 1
|
| 130 |
continue
|
| 131 |
|
| 132 |
+
# Tính hash trước khi xử lý
|
| 133 |
+
file_hash = self.hasher.get_file_hash(pdf_path)
|
| 134 |
+
|
| 135 |
+
md = self.parse_document(pdf_path)
|
| 136 |
if md:
|
| 137 |
out.write_text(md, encoding="utf-8")
|
| 138 |
results["parsed"] += 1
|
| 139 |
+
# Lưu hash sau khi xử lý thành công
|
| 140 |
+
if file_hash:
|
| 141 |
+
self._save_hash(pdf_path, file_hash)
|
| 142 |
else:
|
| 143 |
results["errors"] += 1
|
| 144 |
|
| 145 |
+
# Dọn memory mỗi 10 files
|
| 146 |
if (i + 1) % 10 == 0:
|
| 147 |
gc.collect()
|
| 148 |
+
self.logger.info(f"{i+1}/{len(pdf_files)} (bỏ qua: {results['skipped']})")
|
| 149 |
+
|
| 150 |
+
# Lưu hash index sau khi xử lý xong
|
| 151 |
+
self.hasher.save_processed_index(str(self.hash_index_path), self.hash_index)
|
| 152 |
|
| 153 |
+
self.logger.info(f"Xong: {results['parsed']} đã xử lý, {results['skipped']} bỏ qua, {results['errors']} lỗi")
|
| 154 |
return results
|
core/preprocessing/pdf_parser.py
CHANGED
|
@@ -1,19 +1,22 @@
|
|
| 1 |
from docling_processor import DoclingProcessor
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
if __name__ == "__main__":
|
| 10 |
processor = DoclingProcessor(OUTPUT_DIR, use_ocr=USE_OCR)
|
| 11 |
|
| 12 |
if PDF_FILE:
|
| 13 |
-
|
|
|
|
| 14 |
result = processor.parse_document(PDF_FILE)
|
| 15 |
-
print(
|
| 16 |
else:
|
| 17 |
-
|
|
|
|
| 18 |
r = processor.parse_directory(SOURCE_DIR)
|
| 19 |
-
print(f"
|
|
|
|
| 1 |
from pathlib import Path

from docling_processor import DoclingProcessor

# Path configuration
PDF_FILE = ""  # Single file (leave empty to parse the whole directory)
SOURCE_DIR = "data/data_raw"  # Directory containing the PDFs
OUTPUT_DIR = "data"  # Directory the Markdown is written to
USE_OCR = False  # Enable OCR for scanned PDFs


if __name__ == "__main__":
    processor = DoclingProcessor(OUTPUT_DIR, use_ocr=USE_OCR)

    if PDF_FILE:
        # Parse a single file.
        print(f"Đang xử lý: {PDF_FILE}")
        result = processor.parse_document(PDF_FILE)
        if result:
            # parse_document only returns the Markdown; persist it here so
            # single-file mode actually produces an output file.
            out_path = Path(OUTPUT_DIR) / Path(PDF_FILE).with_suffix(".md").name
            out_path.write_text(result, encoding="utf-8")
        print("Xong!" if result else "Lỗi hoặc bỏ qua")
    else:
        # Parse the whole directory.
        print(f"Đang xử lý thư mục: {SOURCE_DIR}")
        r = processor.parse_directory(SOURCE_DIR)
        print(f"Tổng: {r['total']} | Thành công: {r['parsed']} | Bỏ qua: {r['skipped']} | Lỗi: {r['errors']}")
|
core/rag/chunk.py
CHANGED
|
@@ -10,37 +10,41 @@ from llama_index.core import Document
|
|
| 10 |
from llama_index.core.node_parser import MarkdownNodeParser, SentenceSplitter
|
| 11 |
from llama_index.core.schema import BaseNode, TextNode
|
| 12 |
|
| 13 |
-
#
|
| 14 |
CHUNK_SIZE = 1500
|
| 15 |
CHUNK_OVERLAP = 150
|
| 16 |
MIN_CHUNK_SIZE = 200
|
| 17 |
TABLE_ROWS_PER_CHUNK = 15
|
| 18 |
|
| 19 |
-
# Small-to-Big
|
| 20 |
ENABLE_TABLE_SUMMARY = True
|
| 21 |
-
MIN_TABLE_ROWS_FOR_SUMMARY = 0
|
| 22 |
-
SUMMARY_MODEL = "
|
| 23 |
-
|
| 24 |
|
| 25 |
-
# Regex
|
| 26 |
COURSE_PATTERN = re.compile(r"Học\s*phần\s+(.+?)\s*\(\s*m[ãa]\s+([^\)]+)\)", re.I | re.DOTALL)
|
| 27 |
TABLE_PLACEHOLDER = re.compile(r"__TBL_(\d+)__")
|
| 28 |
HEADER_KEYWORDS = {'TT', 'STT', 'MÃ', 'TÊN', 'KHỐI', 'SỐ', 'ID', 'NO', '#'}
|
| 29 |
FRONTMATTER_PATTERN = re.compile(r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL)
|
| 30 |
-
# Pattern để trích xuất số bảng và tiêu đề (ví dụ: "Bảng 3.1 Danh mục các học phần...")
|
| 31 |
TABLE_TITLE_PATTERN = re.compile(r"(?:^|\n)#+\s*(?:Bảng|BẢNG)\s*(\d+(?:\.\d+)?)\s*[.:]*\s*(.+?)(?:\n|$)", re.IGNORECASE)
|
| 32 |
|
| 33 |
|
| 34 |
def _is_table_row(line: str) -> bool:
|
|
|
|
| 35 |
s = line.strip()
|
| 36 |
return s.startswith("|") and s.endswith("|") and s.count("|") >= 2
|
| 37 |
|
|
|
|
| 38 |
def _is_separator(line: str) -> bool:
|
|
|
|
| 39 |
if not _is_table_row(line):
|
| 40 |
return False
|
| 41 |
return not line.strip().replace("|", "").replace("-", "").replace(":", "").replace(" ", "")
|
| 42 |
|
|
|
|
| 43 |
def _is_header(line: str) -> bool:
|
|
|
|
| 44 |
if not _is_table_row(line):
|
| 45 |
return False
|
| 46 |
cells = [c.strip() for c in line.split("|") if c.strip()]
|
|
@@ -50,6 +54,7 @@ def _is_header(line: str) -> bool:
|
|
| 50 |
|
| 51 |
|
| 52 |
def _extract_tables(text: str) -> Tuple[List[Tuple[str, List[str]]], str]:
|
|
|
|
| 53 |
lines, tables, last_header, i = text.split("\n"), [], None, 0
|
| 54 |
|
| 55 |
while i < len(lines) - 1:
|
|
@@ -73,7 +78,7 @@ def _extract_tables(text: str) -> Tuple[List[Tuple[str, List[str]]], str]:
|
|
| 73 |
else:
|
| 74 |
i += 1
|
| 75 |
|
| 76 |
-
#
|
| 77 |
result, tbl_idx, i = [], 0, 0
|
| 78 |
while i < len(lines):
|
| 79 |
if tbl_idx < len(tables) and i < len(lines) - 1 and _is_table_row(lines[i]) and _is_separator(lines[i + 1]):
|
|
@@ -90,6 +95,7 @@ def _extract_tables(text: str) -> Tuple[List[Tuple[str, List[str]]], str]:
|
|
| 90 |
|
| 91 |
|
| 92 |
def _split_table(header: str, rows: List[str], max_rows: int = TABLE_ROWS_PER_CHUNK) -> List[str]:
|
|
|
|
| 93 |
if len(rows) <= max_rows:
|
| 94 |
return [header + "\n".join(rows)]
|
| 95 |
|
|
@@ -98,26 +104,29 @@ def _split_table(header: str, rows: List[str], max_rows: int = TABLE_ROWS_PER_CH
|
|
| 98 |
chunk_rows = rows[i:i + max_rows]
|
| 99 |
chunks.append(chunk_rows)
|
| 100 |
|
| 101 |
-
#
|
| 102 |
if len(chunks) > 1 and len(chunks[-1]) < 5:
|
| 103 |
chunks[-2].extend(chunks[-1])
|
| 104 |
chunks.pop()
|
| 105 |
|
| 106 |
return [header + "\n".join(r) for r in chunks]
|
| 107 |
|
|
|
|
| 108 |
_summary_client: Optional[OpenAI] = None
|
| 109 |
|
|
|
|
| 110 |
def _get_summary_client() -> Optional[OpenAI]:
|
|
|
|
| 111 |
global _summary_client
|
| 112 |
if _summary_client is not None:
|
| 113 |
return _summary_client
|
| 114 |
|
| 115 |
-
api_key = os.getenv("
|
| 116 |
if not api_key:
|
| 117 |
-
print("
|
| 118 |
return None
|
| 119 |
|
| 120 |
-
_summary_client = OpenAI(api_key=api_key, base_url=
|
| 121 |
return _summary_client
|
| 122 |
|
| 123 |
|
|
@@ -130,17 +139,17 @@ def _summarize_table(
|
|
| 130 |
max_retries: int = 5,
|
| 131 |
base_delay: float = 2.0
|
| 132 |
) -> str:
|
| 133 |
-
"""
|
| 134 |
import time
|
| 135 |
|
| 136 |
if not ENABLE_TABLE_SUMMARY:
|
| 137 |
-
raise RuntimeError("
|
| 138 |
|
| 139 |
client = _get_summary_client()
|
| 140 |
if client is None:
|
| 141 |
-
raise RuntimeError("
|
| 142 |
|
| 143 |
-
#
|
| 144 |
table_id_parts = []
|
| 145 |
if table_number:
|
| 146 |
table_id_parts.append(f"Bảng {table_number}")
|
|
@@ -149,7 +158,7 @@ def _summarize_table(
|
|
| 149 |
if source_file:
|
| 150 |
table_id_parts.append(f"từ file {source_file}")
|
| 151 |
|
| 152 |
-
table_identifier = " - ".join(table_id_parts) if table_id_parts else "
|
| 153 |
|
| 154 |
prompt = f"""Tóm tắt ngắn gọn nội dung bảng sau bằng tiếng Việt.
|
| 155 |
|
|
@@ -179,20 +188,17 @@ Bảng:
|
|
| 179 |
if summary.strip():
|
| 180 |
return summary.strip()
|
| 181 |
else:
|
| 182 |
-
raise ValueError("
|
| 183 |
|
| 184 |
except Exception as e:
|
| 185 |
last_error = e
|
| 186 |
-
delay = base_delay * (2 ** attempt) # Exponential backoff: 2, 4, 8, 16, 32
|
| 187 |
-
print(f"
|
| 188 |
-
print(f"
|
| 189 |
time.sleep(delay)
|
| 190 |
|
| 191 |
-
#
|
| 192 |
-
raise RuntimeError(f"
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
|
| 197 |
|
| 198 |
def _create_table_nodes(
|
|
@@ -203,11 +209,11 @@ def _create_table_nodes(
|
|
| 203 |
table_title: str = "",
|
| 204 |
source_file: str = ""
|
| 205 |
) -> List[TextNode]:
|
| 206 |
-
"""
|
| 207 |
-
#
|
| 208 |
row_count = table_text.count("\n")
|
| 209 |
|
| 210 |
-
#
|
| 211 |
table_meta = {**metadata}
|
| 212 |
if table_number:
|
| 213 |
table_meta["table_number"] = table_number
|
|
@@ -215,10 +221,15 @@ def _create_table_nodes(
|
|
| 215 |
table_meta["table_title"] = table_title
|
| 216 |
|
| 217 |
if row_count < MIN_TABLE_ROWS_FOR_SUMMARY:
|
| 218 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
return [TextNode(text=table_text, metadata={**table_meta, "is_table": True})]
|
| 220 |
|
| 221 |
-
#
|
| 222 |
summary = _summarize_table(
|
| 223 |
table_text,
|
| 224 |
context_hint,
|
|
@@ -227,37 +238,36 @@ def _create_table_nodes(
|
|
| 227 |
source_file=source_file
|
| 228 |
)
|
| 229 |
|
| 230 |
-
#
|
| 231 |
parent_id = str(uuid.uuid4())
|
| 232 |
parent_node = TextNode(
|
| 233 |
text=table_text,
|
| 234 |
metadata={
|
| 235 |
**table_meta,
|
| 236 |
"is_table": True,
|
| 237 |
-
"is_parent": True, # Flag
|
| 238 |
"node_id": parent_id,
|
| 239 |
}
|
| 240 |
)
|
| 241 |
parent_node.id_ = parent_id
|
| 242 |
|
| 243 |
-
#
|
| 244 |
summary_node = TextNode(
|
| 245 |
text=summary,
|
| 246 |
metadata={
|
| 247 |
**table_meta,
|
| 248 |
"is_table_summary": True,
|
| 249 |
-
"parent_id": parent_id, # Link
|
| 250 |
}
|
| 251 |
)
|
| 252 |
|
| 253 |
-
table_id = f"Bảng {table_number}" if table_number else "
|
| 254 |
-
print(f"
|
| 255 |
return [parent_node, summary_node]
|
| 256 |
|
| 257 |
|
| 258 |
-
|
| 259 |
-
|
| 260 |
def _enrich_metadata(node: BaseNode, source_path: Path | None) -> None:
|
|
|
|
| 261 |
if source_path:
|
| 262 |
node.metadata.update({"source_path": str(source_path), "source_file": source_path.name})
|
| 263 |
if "Học phần" in (text := node.get_content()) and (m := COURSE_PATTERN.search(text)):
|
|
@@ -265,6 +275,7 @@ def _enrich_metadata(node: BaseNode, source_path: Path | None) -> None:
|
|
| 265 |
|
| 266 |
|
| 267 |
def _chunk_text(text: str, metadata: dict) -> List[BaseNode]:
|
|
|
|
| 268 |
if len(text) <= CHUNK_SIZE:
|
| 269 |
return [TextNode(text=text, metadata=metadata.copy())]
|
| 270 |
return SentenceSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP).get_nodes_from_documents(
|
|
@@ -273,6 +284,7 @@ def _chunk_text(text: str, metadata: dict) -> List[BaseNode]:
|
|
| 273 |
|
| 274 |
|
| 275 |
def _extract_frontmatter(text: str) -> Tuple[Dict[str, Any], str]:
|
|
|
|
| 276 |
match = FRONTMATTER_PATTERN.match(text)
|
| 277 |
if not match:
|
| 278 |
return {}, text
|
|
@@ -286,22 +298,23 @@ def _extract_frontmatter(text: str) -> Tuple[Dict[str, Any], str]:
|
|
| 286 |
|
| 287 |
|
| 288 |
def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[BaseNode]:
|
|
|
|
| 289 |
if not text or not text.strip():
|
| 290 |
return []
|
| 291 |
|
| 292 |
path = Path(source_path) if source_path else None
|
| 293 |
|
| 294 |
-
#
|
| 295 |
frontmatter_meta, text = _extract_frontmatter(text)
|
| 296 |
|
| 297 |
tables, text_with_placeholders = _extract_tables(text)
|
| 298 |
|
| 299 |
-
#
|
| 300 |
base_meta = {**frontmatter_meta}
|
| 301 |
if path:
|
| 302 |
base_meta.update({"source_path": str(path), "source_file": path.name})
|
| 303 |
|
| 304 |
-
# Parse
|
| 305 |
doc = Document(text=text_with_placeholders, metadata=base_meta.copy())
|
| 306 |
heading_nodes = MarkdownNodeParser().get_nodes_from_documents([doc])
|
| 307 |
|
|
@@ -316,14 +329,13 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
|
|
| 316 |
|
| 317 |
last_end = 0
|
| 318 |
for match in matches:
|
| 319 |
-
# Text
|
| 320 |
before_text = content[last_end:match.start()].strip()
|
| 321 |
|
| 322 |
-
#
|
| 323 |
table_number = ""
|
| 324 |
table_title = ""
|
| 325 |
if before_text:
|
| 326 |
-
# Look for patterns like "## Bảng 3.1 Danh mục các học phần..."
|
| 327 |
title_match = TABLE_TITLE_PATTERN.search(before_text)
|
| 328 |
if title_match:
|
| 329 |
table_number = title_match.group(1).strip()
|
|
@@ -332,15 +344,15 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
|
|
| 332 |
if before_text and len(before_text) >= MIN_CHUNK_SIZE:
|
| 333 |
nodes.extend(_chunk_text(before_text, meta) if len(before_text) > CHUNK_SIZE else [TextNode(text=before_text, metadata=meta.copy())])
|
| 334 |
|
| 335 |
-
#
|
| 336 |
if (idx := int(match.group(1))) < len(tables):
|
| 337 |
header, rows = tables[idx]
|
| 338 |
table_chunks = _split_table(header, rows)
|
| 339 |
|
| 340 |
-
#
|
| 341 |
context_hint = meta.get("Header 1", "") or meta.get("section", "")
|
| 342 |
|
| 343 |
-
#
|
| 344 |
source_file = meta.get("source_file", "") or (path.name if path else "")
|
| 345 |
|
| 346 |
for i, chunk in enumerate(table_chunks):
|
|
@@ -348,7 +360,7 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
|
|
| 348 |
if len(table_chunks) > 1:
|
| 349 |
chunk_meta["table_part"] = f"{i+1}/{len(table_chunks)}"
|
| 350 |
|
| 351 |
-
#
|
| 352 |
table_nodes = _create_table_nodes(
|
| 353 |
chunk,
|
| 354 |
chunk_meta,
|
|
@@ -361,11 +373,11 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
|
|
| 361 |
|
| 362 |
last_end = match.end()
|
| 363 |
|
| 364 |
-
# Text
|
| 365 |
if (after := content[last_end:].strip()) and len(after) >= MIN_CHUNK_SIZE:
|
| 366 |
nodes.extend(_chunk_text(after, meta) if len(after) > CHUNK_SIZE else [TextNode(text=after, metadata=meta.copy())])
|
| 367 |
|
| 368 |
-
|
| 369 |
final: List[BaseNode] = []
|
| 370 |
i = 0
|
| 371 |
while i < len(nodes):
|
|
@@ -373,12 +385,12 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
|
|
| 373 |
curr_content = curr.get_content()
|
| 374 |
curr_is_table = curr.metadata.get("is_table")
|
| 375 |
|
| 376 |
-
#
|
| 377 |
if not curr_content.strip():
|
| 378 |
i += 1
|
| 379 |
continue
|
| 380 |
|
| 381 |
-
#
|
| 382 |
if not curr_is_table and len(curr_content) < MIN_CHUNK_SIZE and i + 1 < len(nodes):
|
| 383 |
next_node = nodes[i + 1]
|
| 384 |
next_is_table = next_node.metadata.get("is_table")
|
|
@@ -405,7 +417,8 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
|
|
| 405 |
|
| 406 |
|
| 407 |
def chunk_markdown_file(path: str | Path) -> List[BaseNode]:
|
|
|
|
| 408 |
p = Path(path)
|
| 409 |
if not p.exists():
|
| 410 |
-
raise FileNotFoundError(f"
|
| 411 |
return chunk_markdown(p.read_text(encoding="utf-8"), source_path=p)
|
|
|
|
| 10 |
from llama_index.core.node_parser import MarkdownNodeParser, SentenceSplitter
|
| 11 |
from llama_index.core.schema import BaseNode, TextNode
|
| 12 |
|
| 13 |
+
# Cấu hình chunking
|
| 14 |
CHUNK_SIZE = 1500
|
| 15 |
CHUNK_OVERLAP = 150
|
| 16 |
MIN_CHUNK_SIZE = 200
|
| 17 |
TABLE_ROWS_PER_CHUNK = 15
|
| 18 |
|
| 19 |
+
# Cấu hình Small-to-Big
|
| 20 |
ENABLE_TABLE_SUMMARY = True
|
| 21 |
+
MIN_TABLE_ROWS_FOR_SUMMARY = 0
|
| 22 |
+
SUMMARY_MODEL = "openai/gpt-oss-120b"
|
| 23 |
+
GROQ_BASE_URL = "https://api.groq.com/openai/v1"
|
| 24 |
|
| 25 |
+
# Regex patterns
|
| 26 |
COURSE_PATTERN = re.compile(r"Học\s*phần\s+(.+?)\s*\(\s*m[ãa]\s+([^\)]+)\)", re.I | re.DOTALL)
|
| 27 |
TABLE_PLACEHOLDER = re.compile(r"__TBL_(\d+)__")
|
| 28 |
HEADER_KEYWORDS = {'TT', 'STT', 'MÃ', 'TÊN', 'KHỐI', 'SỐ', 'ID', 'NO', '#'}
|
| 29 |
FRONTMATTER_PATTERN = re.compile(r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL)
|
|
|
|
| 30 |
TABLE_TITLE_PATTERN = re.compile(r"(?:^|\n)#+\s*(?:Bảng|BẢNG)\s*(\d+(?:\.\d+)?)\s*[.:]*\s*(.+?)(?:\n|$)", re.IGNORECASE)
|
| 31 |
|
| 32 |
|
| 33 |
def _is_table_row(line: str) -> bool:
|
| 34 |
+
"""Kiểm tra dòng có phải là hàng trong bảng Markdown không."""
|
| 35 |
s = line.strip()
|
| 36 |
return s.startswith("|") and s.endswith("|") and s.count("|") >= 2
|
| 37 |
|
| 38 |
+
|
| 39 |
def _is_separator(line: str) -> bool:
|
| 40 |
+
"""Kiểm tra dòng có phải là separator của bảng (|---|---|)."""
|
| 41 |
if not _is_table_row(line):
|
| 42 |
return False
|
| 43 |
return not line.strip().replace("|", "").replace("-", "").replace(":", "").replace(" ", "")
|
| 44 |
|
| 45 |
+
|
| 46 |
def _is_header(line: str) -> bool:
|
| 47 |
+
"""Kiểm tra dòng có phải là header của bảng không."""
|
| 48 |
if not _is_table_row(line):
|
| 49 |
return False
|
| 50 |
cells = [c.strip() for c in line.split("|") if c.strip()]
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
def _extract_tables(text: str) -> Tuple[List[Tuple[str, List[str]]], str]:
|
| 57 |
+
"""Trích xuất bảng từ text và thay bằng placeholder."""
|
| 58 |
lines, tables, last_header, i = text.split("\n"), [], None, 0
|
| 59 |
|
| 60 |
while i < len(lines) - 1:
|
|
|
|
| 78 |
else:
|
| 79 |
i += 1
|
| 80 |
|
| 81 |
+
# Thay bảng bằng placeholder
|
| 82 |
result, tbl_idx, i = [], 0, 0
|
| 83 |
while i < len(lines):
|
| 84 |
if tbl_idx < len(tables) and i < len(lines) - 1 and _is_table_row(lines[i]) and _is_separator(lines[i + 1]):
|
|
|
|
| 95 |
|
| 96 |
|
| 97 |
def _split_table(header: str, rows: List[str], max_rows: int = TABLE_ROWS_PER_CHUNK) -> List[str]:
|
| 98 |
+
"""Chia bảng lớn thành nhiều chunks nhỏ."""
|
| 99 |
if len(rows) <= max_rows:
|
| 100 |
return [header + "\n".join(rows)]
|
| 101 |
|
|
|
|
| 104 |
chunk_rows = rows[i:i + max_rows]
|
| 105 |
chunks.append(chunk_rows)
|
| 106 |
|
| 107 |
+
# Gộp chunk cuối nếu quá nhỏ (< 5 dòng)
|
| 108 |
if len(chunks) > 1 and len(chunks[-1]) < 5:
|
| 109 |
chunks[-2].extend(chunks[-1])
|
| 110 |
chunks.pop()
|
| 111 |
|
| 112 |
return [header + "\n".join(r) for r in chunks]
|
| 113 |
|
| 114 |
+
|
| 115 |
_summary_client: Optional[OpenAI] = None
|
| 116 |
|
| 117 |
+
|
| 118 |
def _get_summary_client() -> Optional[OpenAI]:
    """Return the shared client used for table summaries, creating it once.

    The client is built from the ``GROQ_API_KEY`` environment variable and
    ``GROQ_BASE_URL`` and cached in the module-level ``_summary_client``.

    Returns:
        The cached client, or None when ``GROQ_API_KEY`` is unset or
        blank (a notice is printed in that case and nothing is cached).
    """
    global _summary_client
    if _summary_client is None:
        groq_key = os.getenv("GROQ_API_KEY", "").strip()
        if groq_key:
            _summary_client = OpenAI(api_key=groq_key, base_url=GROQ_BASE_URL)
        else:
            print("Chưa đặt GROQ_API_KEY. Tắt tính năng tóm tắt bảng.")
    return _summary_client
|
| 131 |
|
| 132 |
|
|
|
|
| 139 |
max_retries: int = 5,
|
| 140 |
base_delay: float = 2.0
|
| 141 |
) -> str:
|
| 142 |
+
"""Tóm tắt bảng bằng LLM với retry logic."""
|
| 143 |
import time
|
| 144 |
|
| 145 |
if not ENABLE_TABLE_SUMMARY:
|
| 146 |
+
raise RuntimeError("Tính năng tóm tắt bảng đã tắt. Đặt ENABLE_TABLE_SUMMARY = True")
|
| 147 |
|
| 148 |
client = _get_summary_client()
|
| 149 |
if client is None:
|
| 150 |
+
raise RuntimeError("Chưa đặt GROQ_API_KEY. Không thể tóm tắt bảng.")
|
| 151 |
|
| 152 |
+
# Tạo chuỗi định danh bảng
|
| 153 |
table_id_parts = []
|
| 154 |
if table_number:
|
| 155 |
table_id_parts.append(f"Bảng {table_number}")
|
|
|
|
| 158 |
if source_file:
|
| 159 |
table_id_parts.append(f"từ file {source_file}")
|
| 160 |
|
| 161 |
+
table_identifier = " - ".join(table_id_parts) if table_id_parts else "Bảng không xác định"
|
| 162 |
|
| 163 |
prompt = f"""Tóm tắt ngắn gọn nội dung bảng sau bằng tiếng Việt.
|
| 164 |
|
|
|
|
| 188 |
if summary.strip():
|
| 189 |
return summary.strip()
|
| 190 |
else:
|
| 191 |
+
raise ValueError("API trả về summary rỗng")
|
| 192 |
|
| 193 |
except Exception as e:
|
| 194 |
last_error = e
|
| 195 |
+
delay = base_delay * (2 ** attempt) # Exponential backoff: 2, 4, 8, 16, 32 giây
|
| 196 |
+
print(f"Thử lại {attempt + 1}/{max_retries} cho {table_identifier}: {e}")
|
| 197 |
+
print(f" Đợi {delay:.1f}s trước khi thử lại...")
|
| 198 |
time.sleep(delay)
|
| 199 |
|
| 200 |
+
# Tất cả retry đều thất bại
|
| 201 |
+
raise RuntimeError(f"Không thể tóm tắt {table_identifier} sau {max_retries} lần thử. Lỗi cuối: {last_error}")
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
|
| 204 |
def _create_table_nodes(
|
|
|
|
| 209 |
table_title: str = "",
|
| 210 |
source_file: str = ""
|
| 211 |
) -> List[TextNode]:
|
| 212 |
+
"""Tạo nodes cho bảng. Bảng lớn sẽ có parent + summary node."""
|
| 213 |
+
# Đếm số dòng để quyết định có cần tóm tắt không
|
| 214 |
row_count = table_text.count("\n")
|
| 215 |
|
| 216 |
+
# Thêm thông tin bảng vào metadata
|
| 217 |
table_meta = {**metadata}
|
| 218 |
if table_number:
|
| 219 |
table_meta["table_number"] = table_number
|
|
|
|
| 221 |
table_meta["table_title"] = table_title
|
| 222 |
|
| 223 |
if row_count < MIN_TABLE_ROWS_FOR_SUMMARY:
|
| 224 |
+
# Bảng quá nhỏ, không cần tóm tắt
|
| 225 |
+
return [TextNode(text=table_text, metadata={**table_meta, "is_table": True})]
|
| 226 |
+
|
| 227 |
+
# Kiểm tra có thể tóm tắt không (cần API key)
|
| 228 |
+
if _get_summary_client() is None:
|
| 229 |
+
# Không có API key -> trả về node bảng đơn giản, không tóm tắt
|
| 230 |
return [TextNode(text=table_text, metadata={**table_meta, "is_table": True})]
|
| 231 |
|
| 232 |
+
# Tạo summary với retry logic
|
| 233 |
summary = _summarize_table(
|
| 234 |
table_text,
|
| 235 |
context_hint,
|
|
|
|
| 238 |
source_file=source_file
|
| 239 |
)
|
| 240 |
|
| 241 |
+
# Tạo parent node (bảng gốc - KHÔNG embed)
|
| 242 |
parent_id = str(uuid.uuid4())
|
| 243 |
parent_node = TextNode(
|
| 244 |
text=table_text,
|
| 245 |
metadata={
|
| 246 |
**table_meta,
|
| 247 |
"is_table": True,
|
| 248 |
+
"is_parent": True, # Flag để bỏ qua embedding
|
| 249 |
"node_id": parent_id,
|
| 250 |
}
|
| 251 |
)
|
| 252 |
parent_node.id_ = parent_id
|
| 253 |
|
| 254 |
+
# Tạo summary node (SẼ được embed để search)
|
| 255 |
summary_node = TextNode(
|
| 256 |
text=summary,
|
| 257 |
metadata={
|
| 258 |
**table_meta,
|
| 259 |
"is_table_summary": True,
|
| 260 |
+
"parent_id": parent_id, # Link tới parent
|
| 261 |
}
|
| 262 |
)
|
| 263 |
|
| 264 |
+
table_id = f"Bảng {table_number}" if table_number else "bảng"
|
| 265 |
+
print(f"Đã tạo summary cho {table_id} ({row_count} dòng)")
|
| 266 |
return [parent_node, summary_node]
|
| 267 |
|
| 268 |
|
|
|
|
|
|
|
| 269 |
def _enrich_metadata(node: BaseNode, source_path: Path | None) -> None:
|
| 270 |
+
"""Bổ sung metadata từ source path và trích xuất thông tin học phần."""
|
| 271 |
if source_path:
|
| 272 |
node.metadata.update({"source_path": str(source_path), "source_file": source_path.name})
|
| 273 |
if "Học phần" in (text := node.get_content()) and (m := COURSE_PATTERN.search(text)):
|
|
|
|
| 275 |
|
| 276 |
|
| 277 |
def _chunk_text(text: str, metadata: dict) -> List[BaseNode]:
|
| 278 |
+
"""Chia text thành chunks theo kích thước cấu hình."""
|
| 279 |
if len(text) <= CHUNK_SIZE:
|
| 280 |
return [TextNode(text=text, metadata=metadata.copy())]
|
| 281 |
return SentenceSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP).get_nodes_from_documents(
|
|
|
|
| 284 |
|
| 285 |
|
| 286 |
def _extract_frontmatter(text: str) -> Tuple[Dict[str, Any], str]:
|
| 287 |
+
"""Trích xuất YAML frontmatter từ đầu file."""
|
| 288 |
match = FRONTMATTER_PATTERN.match(text)
|
| 289 |
if not match:
|
| 290 |
return {}, text
|
|
|
|
| 298 |
|
| 299 |
|
| 300 |
def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[BaseNode]:
|
| 301 |
+
"""Chunk một file Markdown thành các nodes."""
|
| 302 |
if not text or not text.strip():
|
| 303 |
return []
|
| 304 |
|
| 305 |
path = Path(source_path) if source_path else None
|
| 306 |
|
| 307 |
+
# Trích xuất YAML frontmatter làm metadata (không chunk)
|
| 308 |
frontmatter_meta, text = _extract_frontmatter(text)
|
| 309 |
|
| 310 |
tables, text_with_placeholders = _extract_tables(text)
|
| 311 |
|
| 312 |
+
# Metadata cơ bản từ frontmatter + source path
|
| 313 |
base_meta = {**frontmatter_meta}
|
| 314 |
if path:
|
| 315 |
base_meta.update({"source_path": str(path), "source_file": path.name})
|
| 316 |
|
| 317 |
+
# Parse theo headings
|
| 318 |
doc = Document(text=text_with_placeholders, metadata=base_meta.copy())
|
| 319 |
heading_nodes = MarkdownNodeParser().get_nodes_from_documents([doc])
|
| 320 |
|
|
|
|
| 329 |
|
| 330 |
last_end = 0
|
| 331 |
for match in matches:
|
| 332 |
+
# Text trước bảng
|
| 333 |
before_text = content[last_end:match.start()].strip()
|
| 334 |
|
| 335 |
+
# Trích xuất số bảng và tiêu đề từ text trước bảng
|
| 336 |
table_number = ""
|
| 337 |
table_title = ""
|
| 338 |
if before_text:
|
|
|
|
| 339 |
title_match = TABLE_TITLE_PATTERN.search(before_text)
|
| 340 |
if title_match:
|
| 341 |
table_number = title_match.group(1).strip()
|
|
|
|
| 344 |
if before_text and len(before_text) >= MIN_CHUNK_SIZE:
|
| 345 |
nodes.extend(_chunk_text(before_text, meta) if len(before_text) > CHUNK_SIZE else [TextNode(text=before_text, metadata=meta.copy())])
|
| 346 |
|
| 347 |
+
# Chunk bảng - sử dụng Small-to-Big pattern
|
| 348 |
if (idx := int(match.group(1))) < len(tables):
|
| 349 |
header, rows = tables[idx]
|
| 350 |
table_chunks = _split_table(header, rows)
|
| 351 |
|
| 352 |
+
# Lấy context hint từ header path
|
| 353 |
context_hint = meta.get("Header 1", "") or meta.get("section", "")
|
| 354 |
|
| 355 |
+
# Lấy source file cho summary
|
| 356 |
source_file = meta.get("source_file", "") or (path.name if path else "")
|
| 357 |
|
| 358 |
for i, chunk in enumerate(table_chunks):
|
|
|
|
| 360 |
if len(table_chunks) > 1:
|
| 361 |
chunk_meta["table_part"] = f"{i+1}/{len(table_chunks)}"
|
| 362 |
|
| 363 |
+
# Tạo parent + summary nodes nếu cần
|
| 364 |
table_nodes = _create_table_nodes(
|
| 365 |
chunk,
|
| 366 |
chunk_meta,
|
|
|
|
| 373 |
|
| 374 |
last_end = match.end()
|
| 375 |
|
| 376 |
+
# Text sau bảng
|
| 377 |
if (after := content[last_end:].strip()) and len(after) >= MIN_CHUNK_SIZE:
|
| 378 |
nodes.extend(_chunk_text(after, meta) if len(after) > CHUNK_SIZE else [TextNode(text=after, metadata=meta.copy())])
|
| 379 |
|
| 380 |
+
# Gộp các node nhỏ với node kế tiếp
|
| 381 |
final: List[BaseNode] = []
|
| 382 |
i = 0
|
| 383 |
while i < len(nodes):
|
|
|
|
| 385 |
curr_content = curr.get_content()
|
| 386 |
curr_is_table = curr.metadata.get("is_table")
|
| 387 |
|
| 388 |
+
# Bỏ qua node rỗng
|
| 389 |
if not curr_content.strip():
|
| 390 |
i += 1
|
| 391 |
continue
|
| 392 |
|
| 393 |
+
# Nếu node hiện tại nhỏ và không phải bảng -> gộp với node sau
|
| 394 |
if not curr_is_table and len(curr_content) < MIN_CHUNK_SIZE and i + 1 < len(nodes):
|
| 395 |
next_node = nodes[i + 1]
|
| 396 |
next_is_table = next_node.metadata.get("is_table")
|
|
|
|
| 417 |
|
| 418 |
|
| 419 |
def chunk_markdown_file(path: str | Path) -> List[BaseNode]:
    """Read a Markdown file as UTF-8 and chunk its contents.

    Args:
        path: Location of the Markdown file.

    Returns:
        Whatever ``chunk_markdown`` produces for the file's text, with the
        file path passed along as ``source_path``.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    md_path = Path(path)
    if md_path.exists():
        content = md_path.read_text(encoding="utf-8")
        return chunk_markdown(content, source_path=md_path)
    raise FileNotFoundError(f"Không tìm thấy file: {md_path}")
|
core/rag/embedding_model.py
CHANGED
|
@@ -1,26 +1,30 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
import os
|
| 3 |
import logging
|
|
|
|
| 4 |
from dataclasses import dataclass
|
| 5 |
from typing import List, Sequence
|
| 6 |
import numpy as np
|
| 7 |
from openai import OpenAI
|
| 8 |
from langchain_core.embeddings import Embeddings
|
| 9 |
-
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
|
| 13 |
@dataclass
|
| 14 |
class EmbeddingConfig:
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
_embed_config: EmbeddingConfig | None = None
|
| 22 |
|
|
|
|
| 23 |
def get_embedding_config() -> EmbeddingConfig:
|
|
|
|
| 24 |
global _embed_config
|
| 25 |
if _embed_config is None:
|
| 26 |
_embed_config = EmbeddingConfig()
|
|
@@ -28,26 +32,32 @@ def get_embedding_config() -> EmbeddingConfig:
|
|
| 28 |
|
| 29 |
|
| 30 |
class QwenEmbeddings(Embeddings):
|
|
|
|
|
|
|
| 31 |
def __init__(self, config: EmbeddingConfig | None = None):
|
|
|
|
| 32 |
self.config = config or get_embedding_config()
|
| 33 |
|
| 34 |
api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
|
| 35 |
if not api_key:
|
| 36 |
-
raise ValueError("
|
| 37 |
|
| 38 |
self._client = OpenAI(
|
| 39 |
api_key=api_key,
|
| 40 |
base_url=self.config.api_base_url,
|
| 41 |
)
|
| 42 |
-
logger.info(f"QwenEmbeddings
|
| 43 |
|
| 44 |
def embed_query(self, text: str) -> List[float]:
|
|
|
|
| 45 |
return self._embed_texts([text])[0]
|
| 46 |
|
| 47 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
|
|
| 48 |
return self._embed_texts(texts)
|
| 49 |
|
| 50 |
def _embed_texts(self, texts: Sequence[str]) -> List[List[float]]:
|
|
|
|
| 51 |
if not texts:
|
| 52 |
return []
|
| 53 |
|
|
@@ -55,9 +65,11 @@ class QwenEmbeddings(Embeddings):
|
|
| 55 |
batch_size = self.config.batch_size
|
| 56 |
max_retries = 3
|
| 57 |
|
|
|
|
| 58 |
for i in range(0, len(texts), batch_size):
|
| 59 |
batch = list(texts[i:i + batch_size])
|
| 60 |
|
|
|
|
| 61 |
for attempt in range(max_retries):
|
| 62 |
try:
|
| 63 |
response = self._client.embeddings.create(
|
|
@@ -68,9 +80,10 @@ class QwenEmbeddings(Embeddings):
|
|
| 68 |
all_embeddings.append(item.embedding)
|
| 69 |
break
|
| 70 |
except Exception as e:
|
|
|
|
| 71 |
if "rate" in str(e).lower() and attempt < max_retries - 1:
|
| 72 |
-
wait_time = 2 ** attempt # 1s, 2s, 4s
|
| 73 |
-
logger.warning(f"
|
| 74 |
time.sleep(wait_time)
|
| 75 |
else:
|
| 76 |
raise
|
|
@@ -78,9 +91,10 @@ class QwenEmbeddings(Embeddings):
|
|
| 78 |
return all_embeddings
|
| 79 |
|
| 80 |
def embed_texts_np(self, texts: Sequence[str]) -> np.ndarray:
|
|
|
|
| 81 |
return np.asarray(self._embed_texts(list(texts)), dtype=np.float32)
|
| 82 |
|
| 83 |
|
| 84 |
-
#
|
| 85 |
SiliconFlowConfig = EmbeddingConfig
|
| 86 |
get_config = get_embedding_config
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
import os
|
| 3 |
import logging
|
| 4 |
+
import time
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import List, Sequence
|
| 7 |
import numpy as np
|
| 8 |
from openai import OpenAI
|
| 9 |
from langchain_core.embeddings import Embeddings
|
| 10 |
+
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
|
| 14 |
@dataclass
class EmbeddingConfig:
    """Settings for the embedding model client.

    Field order is part of the positional constructor signature and must
    not be changed.
    """

    # Base URL of the SiliconFlow API.
    api_base_url: str = "https://api.siliconflow.com/v1"
    # Identifier of the embedding model.
    model: str = "Qwen/Qwen3-Embedding-4B"
    # Number of dimensions of each embedding vector.
    dimension: int = 2048
    # Number of texts sent per batch.
    batch_size: int = 16
|
| 21 |
|
| 22 |
|
| 23 |
_embed_config: EmbeddingConfig | None = None
|
| 24 |
|
| 25 |
+
|
| 26 |
def get_embedding_config() -> EmbeddingConfig:
|
| 27 |
+
"""Lấy cấu hình embedding (singleton pattern)."""
|
| 28 |
global _embed_config
|
| 29 |
if _embed_config is None:
|
| 30 |
_embed_config = EmbeddingConfig()
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
class QwenEmbeddings(Embeddings):
|
| 35 |
+
"""Wrapper embedding model Qwen qua SiliconFlow API"""
|
| 36 |
+
|
| 37 |
def __init__(self, config: EmbeddingConfig | None = None):
    """Set up the embedding client.

    Args:
        config: Embedding settings; when omitted (or falsy) the shared
            configuration from ``get_embedding_config()`` is used.

    Raises:
        ValueError: If the ``SILICONFLOW_API_KEY`` environment variable is
            unset or blank.
    """
    if config:
        self.config = config
    else:
        self.config = get_embedding_config()

    # The key is read from the environment only; surrounding whitespace is ignored.
    token = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if token == "":
        raise ValueError("Chưa đặt biến môi trường SILICONFLOW_API_KEY")

    self._client = OpenAI(base_url=self.config.api_base_url, api_key=token)
    logger.info(f"Đã khởi tạo QwenEmbeddings: {self.config.model}")
|
| 50 |
|
| 51 |
def embed_query(self, text: str) -> List[float]:
    """Embed a single query string (used at search time).

    The text is sent through ``_embed_texts`` as a one-element list and
    the only resulting vector is returned.
    """
    vectors = self._embed_texts([text])
    return vectors[0]
|
| 54 |
|
| 55 |
def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed several documents (used at indexing time).

    Delegates to ``_embed_texts`` and returns its result unchanged.
    """
    vectors = self._embed_texts(texts)
    return vectors
|
| 58 |
|
| 59 |
def _embed_texts(self, texts: Sequence[str]) -> List[List[float]]:
|
| 60 |
+
"""Embed danh sách texts theo batch với retry logic."""
|
| 61 |
if not texts:
|
| 62 |
return []
|
| 63 |
|
|
|
|
| 65 |
batch_size = self.config.batch_size
|
| 66 |
max_retries = 3
|
| 67 |
|
| 68 |
+
# Xử lý theo batch
|
| 69 |
for i in range(0, len(texts), batch_size):
|
| 70 |
batch = list(texts[i:i + batch_size])
|
| 71 |
|
| 72 |
+
# Retry logic cho rate limit
|
| 73 |
for attempt in range(max_retries):
|
| 74 |
try:
|
| 75 |
response = self._client.embeddings.create(
|
|
|
|
| 80 |
all_embeddings.append(item.embedding)
|
| 81 |
break
|
| 82 |
except Exception as e:
|
| 83 |
+
# Nếu bị rate limit -> đợi rồi thử lại
|
| 84 |
if "rate" in str(e).lower() and attempt < max_retries - 1:
|
| 85 |
+
wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s
|
| 86 |
+
logger.warning(f"Bị rate limit, đợi {wait_time}s...")
|
| 87 |
time.sleep(wait_time)
|
| 88 |
else:
|
| 89 |
raise
|
|
|
|
| 91 |
return all_embeddings
|
| 92 |
|
| 93 |
def embed_texts_np(self, texts: Sequence[str]) -> np.ndarray:
    """Embed texts and return the vectors as a ``float32`` NumPy matrix.

    Args:
        texts: Texts to embed; materialised into a list before being
            handed to ``_embed_texts``.

    Returns:
        A 2-D ``float32`` array with one row per text. For empty input
        the result has shape ``(0, self.config.dimension)`` rather than
        the 1-D ``(0,)`` array ``np.asarray([])`` would give, so callers
        can always index or stack along axis 1.
    """
    vectors = np.asarray(self._embed_texts(list(texts)), dtype=np.float32)
    if vectors.ndim == 1 and vectors.size == 0:
        # assumes config.dimension matches the width of the returned
        # embeddings — TODO confirm against the embeddings request.
        vectors = vectors.reshape(0, self.config.dimension)
    return vectors
|
| 96 |
|
| 97 |
|
| 98 |
+
# Alias để tương thích ngược
|
| 99 |
SiliconFlowConfig = EmbeddingConfig
|
| 100 |
get_config = get_embedding_config
|
core/rag/generator.py
CHANGED
|
@@ -5,7 +5,7 @@ if TYPE_CHECKING:
|
|
| 5 |
from core.rag.retrival import Retriever
|
| 6 |
|
| 7 |
|
| 8 |
-
# System prompt
|
| 9 |
SYSTEM_PROMPT = """Bạn là Trợ lý học vụ Đại học Bách khoa Hà Nội.
|
| 10 |
|
| 11 |
## NGUYÊN TẮC:
|
|
@@ -16,6 +16,7 @@ SYSTEM_PROMPT = """Bạn là Trợ lý học vụ Đại học Bách khoa Hà N
|
|
| 16 |
|
| 17 |
|
| 18 |
def build_context(results: List[Dict[str, Any]], max_chars: int = 8000) -> str:
|
|
|
|
| 19 |
parts = []
|
| 20 |
for i, r in enumerate(results, 1):
|
| 21 |
meta = r.get("metadata", {})
|
|
@@ -30,7 +31,7 @@ def build_context(results: List[Dict[str, Any]], max_chars: int = 8000) -> str:
|
|
| 30 |
issued_year = meta.get("issued_year", "")
|
| 31 |
content = r.get("content", "").strip()
|
| 32 |
|
| 33 |
-
#
|
| 34 |
meta_info = f"Nguồn: {source}"
|
| 35 |
if header and header != "/":
|
| 36 |
meta_info += f" | Mục: {header}"
|
|
@@ -53,16 +54,20 @@ def build_context(results: List[Dict[str, Any]], max_chars: int = 8000) -> str:
|
|
| 53 |
parts.append(f"[TÀI LIỆU {i}]\n{meta_info}\n{content}")
|
| 54 |
|
| 55 |
context = "\n---\n".join(parts)
|
|
|
|
| 56 |
return context[:max_chars] if len(context) > max_chars else context
|
| 57 |
|
| 58 |
|
| 59 |
def build_prompt(question: str, context: str) -> str:
|
|
|
|
| 60 |
return f"{SYSTEM_PROMPT}\n\n## CONTEXT:\n{context}\n\n## CÂU HỎI: {question}\n\n## TRẢ LỜI:"
|
| 61 |
|
| 62 |
|
| 63 |
class RAGContextBuilder:
|
|
|
|
| 64 |
|
| 65 |
def __init__(self, retriever: "Retriever", max_context_chars: int = 8000):
|
|
|
|
| 66 |
self._retriever = retriever
|
| 67 |
self._max_context_chars = max_context_chars
|
| 68 |
|
|
@@ -73,9 +78,11 @@ class RAGContextBuilder:
|
|
| 73 |
initial_k: int = 20,
|
| 74 |
mode: str = "hybrid_rerank"
|
| 75 |
) -> Dict[str, Any]:
|
| 76 |
-
|
|
|
|
| 77 |
results = self._retriever.flexible_search(question, k=k, initial_k=initial_k, mode=mode)
|
| 78 |
|
|
|
|
| 79 |
if not results:
|
| 80 |
return {
|
| 81 |
"results": [],
|
|
@@ -84,15 +91,17 @@ class RAGContextBuilder:
|
|
| 84 |
"prompt": "",
|
| 85 |
}
|
| 86 |
|
|
|
|
| 87 |
context_text = build_context(results, self._max_context_chars)
|
| 88 |
prompt = build_prompt(question, context_text)
|
| 89 |
|
| 90 |
return {
|
| 91 |
-
"results": results,
|
| 92 |
-
"contexts": [r.get("content", "")[:1000] for r in results],
|
| 93 |
-
"context_text": context_text,
|
| 94 |
-
"prompt": prompt,
|
| 95 |
}
|
| 96 |
|
| 97 |
|
|
|
|
| 98 |
RAGGenerator = RAGContextBuilder
|
|
|
|
| 5 |
from core.rag.retrival import Retriever
|
| 6 |
|
| 7 |
|
| 8 |
+
# System prompt cho LLM (export để gradio/eval dùng)
|
| 9 |
SYSTEM_PROMPT = """Bạn là Trợ lý học vụ Đại học Bách khoa Hà Nội.
|
| 10 |
|
| 11 |
## NGUYÊN TẮC:
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
def build_context(results: List[Dict[str, Any]], max_chars: int = 8000) -> str:
|
| 19 |
+
"""Xây dựng context từ kết quả retrieval để đưa vào prompt."""
|
| 20 |
parts = []
|
| 21 |
for i, r in enumerate(results, 1):
|
| 22 |
meta = r.get("metadata", {})
|
|
|
|
| 31 |
issued_year = meta.get("issued_year", "")
|
| 32 |
content = r.get("content", "").strip()
|
| 33 |
|
| 34 |
+
# Tạo dòng metadata
|
| 35 |
meta_info = f"Nguồn: {source}"
|
| 36 |
if header and header != "/":
|
| 37 |
meta_info += f" | Mục: {header}"
|
|
|
|
| 54 |
parts.append(f"[TÀI LIỆU {i}]\n{meta_info}\n{content}")
|
| 55 |
|
| 56 |
context = "\n---\n".join(parts)
|
| 57 |
+
# Cắt ngắn nếu vượt quá giới hạn
|
| 58 |
return context[:max_chars] if len(context) > max_chars else context
|
| 59 |
|
| 60 |
|
| 61 |
def build_prompt(question: str, context: str) -> str:
    """Assemble the full prompt from the system prompt, context and question.

    The sections are, in order: ``SYSTEM_PROMPT``, the context block, the
    question line and the answer cue, separated by blank lines.
    """
    sections = (
        f"{SYSTEM_PROMPT}",
        f"## CONTEXT:\n{context}",
        f"## CÂU HỎI: {question}",
        "## TRẢ LỜI:",
    )
    return "\n\n".join(sections)
|
| 64 |
|
| 65 |
|
| 66 |
class RAGContextBuilder:
|
| 67 |
+
"""Kết hợp retrieval và context building thành một bước."""
|
| 68 |
|
| 69 |
def __init__(self, retriever: "Retriever", max_context_chars: int = 8000):
    """Store the retriever and the context-size limit.

    Args:
        retriever: Retriever used to fetch documents for a question.
        max_context_chars: Upper bound, in characters, later passed to
            the context builder.
    """
    self._retriever, self._max_context_chars = retriever, max_context_chars
|
| 73 |
|
|
|
|
| 78 |
initial_k: int = 20,
|
| 79 |
mode: str = "hybrid_rerank"
|
| 80 |
) -> Dict[str, Any]:
|
| 81 |
+
"""Retrieve documents và chuẩn bị context + prompt cho LLM."""
|
| 82 |
+
# Tìm kiếm documents liên quan
|
| 83 |
results = self._retriever.flexible_search(question, k=k, initial_k=initial_k, mode=mode)
|
| 84 |
|
| 85 |
+
# Không tìm thấy kết quả
|
| 86 |
if not results:
|
| 87 |
return {
|
| 88 |
"results": [],
|
|
|
|
| 91 |
"prompt": "",
|
| 92 |
}
|
| 93 |
|
| 94 |
+
# Xây dựng context và prompt
|
| 95 |
context_text = build_context(results, self._max_context_chars)
|
| 96 |
prompt = build_prompt(question, context_text)
|
| 97 |
|
| 98 |
return {
|
| 99 |
+
"results": results, # Kết quả retrieval gốc
|
| 100 |
+
"contexts": [r.get("content", "")[:1000] for r in results], # Context rút gọn (cho eval)
|
| 101 |
+
"context_text": context_text, # Context đầy đủ
|
| 102 |
+
"prompt": prompt, # Prompt hoàn chỉnh
|
| 103 |
}
|
| 104 |
|
| 105 |
|
| 106 |
+
# Alias để tương thích ngược
|
| 107 |
RAGGenerator = RAGContextBuilder
|
core/rag/retrival.py
CHANGED
|
@@ -22,29 +22,30 @@ logger = logging.getLogger(__name__)
|
|
| 22 |
|
| 23 |
|
| 24 |
class RetrievalMode(str, Enum):
|
| 25 |
-
"""
|
| 26 |
-
VECTOR_ONLY = "vector_only"
|
| 27 |
-
BM25_ONLY = "bm25_only"
|
| 28 |
-
HYBRID = "hybrid"
|
| 29 |
-
HYBRID_RERANK = "hybrid_rerank"
|
| 30 |
|
| 31 |
|
| 32 |
@dataclass
|
| 33 |
class RetrievalConfig:
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
|
| 43 |
|
| 44 |
_retrieval_config: RetrievalConfig | None = None
|
| 45 |
|
| 46 |
|
| 47 |
def get_retrieval_config() -> RetrievalConfig:
|
|
|
|
| 48 |
global _retrieval_config
|
| 49 |
if _retrieval_config is None:
|
| 50 |
_retrieval_config = RetrievalConfig()
|
|
@@ -52,6 +53,7 @@ def get_retrieval_config() -> RetrievalConfig:
|
|
| 52 |
|
| 53 |
|
| 54 |
class SiliconFlowReranker(BaseDocumentCompressor):
|
|
|
|
| 55 |
api_key: str = Field(default="")
|
| 56 |
api_base_url: str = Field(default="")
|
| 57 |
model: str = Field(default="")
|
|
@@ -66,9 +68,11 @@ class SiliconFlowReranker(BaseDocumentCompressor):
|
|
| 66 |
query: str,
|
| 67 |
callbacks: Optional[Callbacks] = None,
|
| 68 |
) -> Sequence[Document]:
|
|
|
|
| 69 |
if not documents or not self.api_key:
|
| 70 |
return list(documents)
|
| 71 |
|
|
|
|
| 72 |
for attempt in range(3):
|
| 73 |
try:
|
| 74 |
response = requests.post(
|
|
@@ -91,6 +95,7 @@ class SiliconFlowReranker(BaseDocumentCompressor):
|
|
| 91 |
if "results" not in data:
|
| 92 |
return list(documents)
|
| 93 |
|
|
|
|
| 94 |
reranked: List[Document] = []
|
| 95 |
for result in data["results"]:
|
| 96 |
doc = documents[result["index"]]
|
|
@@ -101,33 +106,36 @@ class SiliconFlowReranker(BaseDocumentCompressor):
|
|
| 101 |
return reranked
|
| 102 |
|
| 103 |
except Exception as e:
|
|
|
|
| 104 |
if "rate" in str(e).lower() and attempt < 2:
|
| 105 |
time.sleep(2 ** attempt)
|
| 106 |
else:
|
| 107 |
-
logger.error(f"
|
| 108 |
return list(documents)
|
| 109 |
|
| 110 |
return list(documents)
|
| 111 |
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
class Retriever:
|
|
|
|
|
|
|
| 116 |
def __init__(self, vector_db: "ChromaVectorDB", use_reranker: bool = True):
|
|
|
|
| 117 |
self._vector_db = vector_db
|
| 118 |
self._config = get_retrieval_config()
|
| 119 |
self._reranker: Optional[SiliconFlowReranker] = None
|
| 120 |
|
|
|
|
| 121 |
self._vector_retriever = self._vector_db.vectorstore.as_retriever(
|
| 122 |
search_kwargs={"k": self._config.initial_k}
|
| 123 |
)
|
| 124 |
|
| 125 |
-
# Lazy-load BM25 -
|
| 126 |
self._bm25_retriever: Optional[BM25Retriever] = None
|
| 127 |
self._bm25_initialized = False
|
| 128 |
self._ensemble_retriever: Optional[EnsembleRetriever] = None
|
| 129 |
|
| 130 |
-
#
|
| 131 |
from pathlib import Path
|
| 132 |
persist_dir = getattr(self._vector_db.config, 'persist_dir', None)
|
| 133 |
if persist_dir:
|
|
@@ -138,61 +146,57 @@ class Retriever:
|
|
| 138 |
if use_reranker:
|
| 139 |
self._reranker = self._init_reranker()
|
| 140 |
|
| 141 |
-
logger.info("Retriever
|
| 142 |
|
| 143 |
-
|
| 144 |
def _save_bm25_cache(self, bm25: BM25Retriever) -> None:
|
| 145 |
-
"""
|
| 146 |
if not self._bm25_cache_path:
|
| 147 |
return
|
| 148 |
try:
|
| 149 |
import pickle
|
| 150 |
with open(self._bm25_cache_path, 'wb') as f:
|
| 151 |
pickle.dump(bm25, f)
|
| 152 |
-
logger.info(f"BM25 cache
|
| 153 |
except Exception as e:
|
| 154 |
-
logger.warning(f"
|
| 155 |
|
| 156 |
def _load_bm25_cache(self) -> Optional[BM25Retriever]:
|
|
|
|
| 157 |
if not self._bm25_cache_path or not self._bm25_cache_path.exists():
|
| 158 |
return None
|
| 159 |
-
|
| 160 |
try:
|
| 161 |
import pickle
|
| 162 |
-
import time
|
| 163 |
start = time.time()
|
| 164 |
with open(self._bm25_cache_path, 'rb') as f:
|
| 165 |
bm25 = pickle.load(f)
|
| 166 |
bm25.k = self._config.initial_k
|
| 167 |
-
logger.info(f"BM25
|
| 168 |
return bm25
|
| 169 |
except Exception as e:
|
| 170 |
-
logger.warning(f"
|
| 171 |
return None
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
def _init_bm25(self) -> Optional[BM25Retriever]:
|
|
|
|
| 176 |
if self._bm25_initialized:
|
| 177 |
return self._bm25_retriever
|
| 178 |
|
| 179 |
self._bm25_initialized = True
|
| 180 |
|
| 181 |
-
#
|
| 182 |
cached = self._load_bm25_cache()
|
| 183 |
if cached:
|
| 184 |
self._bm25_retriever = cached
|
| 185 |
return cached
|
| 186 |
|
| 187 |
-
# Build
|
| 188 |
try:
|
| 189 |
-
import time
|
| 190 |
start = time.time()
|
| 191 |
-
logger.info("
|
| 192 |
|
| 193 |
docs = self._vector_db.get_all_documents()
|
| 194 |
if not docs:
|
| 195 |
-
logger.warning("
|
| 196 |
return None
|
| 197 |
|
| 198 |
lc_docs = [
|
|
@@ -203,19 +207,18 @@ class Retriever:
|
|
| 203 |
bm25.k = self._config.initial_k
|
| 204 |
|
| 205 |
self._bm25_retriever = bm25
|
| 206 |
-
logger.info(f"BM25
|
| 207 |
|
| 208 |
-
#
|
| 209 |
self._save_bm25_cache(bm25)
|
| 210 |
|
| 211 |
return bm25
|
| 212 |
except Exception as e:
|
| 213 |
-
logger.error(f"
|
| 214 |
return None
|
| 215 |
|
| 216 |
-
|
| 217 |
def _get_ensemble_retriever(self) -> EnsembleRetriever:
|
| 218 |
-
"""
|
| 219 |
if self._ensemble_retriever is not None:
|
| 220 |
return self._ensemble_retriever
|
| 221 |
|
|
@@ -226,14 +229,15 @@ class Retriever:
|
|
| 226 |
weights=[self._config.vector_weight, self._config.bm25_weight]
|
| 227 |
)
|
| 228 |
else:
|
|
|
|
| 229 |
self._ensemble_retriever = EnsembleRetriever(
|
| 230 |
retrievers=[self._vector_retriever],
|
| 231 |
weights=[1.0]
|
| 232 |
)
|
| 233 |
return self._ensemble_retriever
|
| 234 |
|
| 235 |
-
|
| 236 |
def _init_reranker(self) -> Optional[SiliconFlowReranker]:
|
|
|
|
| 237 |
api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
|
| 238 |
if not api_key:
|
| 239 |
return None
|
|
@@ -245,7 +249,7 @@ class Retriever:
|
|
| 245 |
)
|
| 246 |
|
| 247 |
def _build_final(self):
|
| 248 |
-
"""Build
|
| 249 |
ensemble = self._get_ensemble_retriever()
|
| 250 |
if self._reranker:
|
| 251 |
return ContextualCompressionRetriever(
|
|
@@ -254,21 +258,22 @@ class Retriever:
|
|
| 254 |
)
|
| 255 |
return ensemble
|
| 256 |
|
| 257 |
-
|
| 258 |
@property
|
| 259 |
def has_reranker(self) -> bool:
|
|
|
|
| 260 |
return self._reranker is not None
|
| 261 |
|
| 262 |
def _to_result(self, doc: Document, rank: int, **extra) -> Dict[str, Any]:
|
|
|
|
| 263 |
metadata = doc.metadata or {}
|
| 264 |
content = doc.page_content
|
| 265 |
|
| 266 |
-
# Small-to-Big:
|
| 267 |
if metadata.get("is_table_summary") and metadata.get("parent_id"):
|
| 268 |
parent = self._vector_db.get_parent_node(metadata["parent_id"])
|
| 269 |
if parent:
|
| 270 |
content = parent.get("content", content)
|
| 271 |
-
# Merge metadata,
|
| 272 |
metadata = {
|
| 273 |
**parent.get("metadata", {}),
|
| 274 |
"original_summary": doc.page_content[:200],
|
|
@@ -283,10 +288,10 @@ class Retriever:
|
|
| 283 |
**extra,
|
| 284 |
}
|
| 285 |
|
| 286 |
-
|
| 287 |
def vector_search(
|
| 288 |
self, text: str, *, k: int | None = None, where: Optional[Dict[str, Any]] = None
|
| 289 |
) -> List[Dict[str, Any]]:
|
|
|
|
| 290 |
if not text.strip():
|
| 291 |
return []
|
| 292 |
k = k or self._config.top_k
|
|
@@ -294,13 +299,12 @@ class Retriever:
|
|
| 294 |
return [self._to_result(doc, i + 1, distance=score) for i, (doc, score) in enumerate(results)]
|
| 295 |
|
| 296 |
def bm25_search(self, text: str, *, k: int | None = None) -> List[Dict[str, Any]]:
|
|
|
|
| 297 |
if not text.strip():
|
| 298 |
return []
|
| 299 |
-
|
| 300 |
bm25 = self._init_bm25() # Lazy-load BM25
|
| 301 |
if not bm25:
|
| 302 |
return self.vector_search(text, k=k)
|
| 303 |
-
|
| 304 |
k = k or self._config.top_k
|
| 305 |
bm25.k = k
|
| 306 |
results = bm25.invoke(text)
|
|
@@ -309,9 +313,9 @@ class Retriever:
|
|
| 309 |
def hybrid_search(
|
| 310 |
self, text: str, *, k: int | None = None, initial_k: int | None = None
|
| 311 |
) -> List[Dict[str, Any]]:
|
|
|
|
| 312 |
if not text.strip():
|
| 313 |
return []
|
| 314 |
-
|
| 315 |
k = k or self._config.top_k
|
| 316 |
if initial_k:
|
| 317 |
self._vector_retriever.search_kwargs["k"] = initial_k
|
|
@@ -319,7 +323,6 @@ class Retriever:
|
|
| 319 |
if bm25:
|
| 320 |
bm25.k = initial_k
|
| 321 |
|
| 322 |
-
# Dùng ensemble_retriever (lazy-loaded, KHÔNG có reranker)
|
| 323 |
ensemble = self._get_ensemble_retriever()
|
| 324 |
results = ensemble.invoke(text)
|
| 325 |
return [self._to_result(doc, i + 1) for i, doc in enumerate(results[:k])]
|
|
@@ -332,11 +335,9 @@ class Retriever:
|
|
| 332 |
where: Optional[Dict[str, Any]] = None,
|
| 333 |
initial_k: int | None = None,
|
| 334 |
) -> List[Dict[str, Any]]:
|
| 335 |
-
|
| 336 |
-
|
| 337 |
if not text.strip():
|
| 338 |
return []
|
| 339 |
-
|
| 340 |
k = k or self._config.top_k
|
| 341 |
initial_k = initial_k or self._config.initial_k
|
| 342 |
|
|
@@ -350,16 +351,18 @@ class Retriever:
|
|
| 350 |
for i, doc in enumerate(results[:k])
|
| 351 |
]
|
| 352 |
|
| 353 |
-
#
|
| 354 |
if initial_k:
|
| 355 |
self._vector_retriever.search_kwargs["k"] = initial_k
|
| 356 |
bm25 = self._init_bm25()
|
| 357 |
if bm25:
|
| 358 |
bm25.k = initial_k
|
| 359 |
|
|
|
|
| 360 |
ensemble = self._get_ensemble_retriever()
|
| 361 |
ensemble_results = ensemble.invoke(text)
|
| 362 |
|
|
|
|
| 363 |
if self._reranker:
|
| 364 |
results = self._reranker.compress_documents(ensemble_results, text)
|
| 365 |
else:
|
|
@@ -370,8 +373,6 @@ class Retriever:
|
|
| 370 |
for i, doc in enumerate(results[:k])
|
| 371 |
]
|
| 372 |
|
| 373 |
-
|
| 374 |
-
|
| 375 |
def flexible_search(
|
| 376 |
self,
|
| 377 |
text: str,
|
|
@@ -381,9 +382,11 @@ class Retriever:
|
|
| 381 |
initial_k: int | None = None,
|
| 382 |
where: Optional[Dict[str, Any]] = None,
|
| 383 |
) -> List[Dict[str, Any]]:
|
|
|
|
| 384 |
if not text.strip():
|
| 385 |
return []
|
| 386 |
|
|
|
|
| 387 |
if isinstance(mode, str):
|
| 388 |
try:
|
| 389 |
mode = RetrievalMode(mode.lower())
|
|
@@ -393,6 +396,7 @@ class Retriever:
|
|
| 393 |
k = k or self._config.top_k
|
| 394 |
initial_k = initial_k or self._config.initial_k
|
| 395 |
|
|
|
|
| 396 |
if mode == RetrievalMode.VECTOR_ONLY:
|
| 397 |
return self.vector_search(text, k=k, where=where)
|
| 398 |
elif mode == RetrievalMode.BM25_ONLY:
|
|
@@ -404,5 +408,5 @@ class Retriever:
|
|
| 404 |
else: # HYBRID_RERANK
|
| 405 |
return self.search_with_rerank(text, k=k, where=where, initial_k=initial_k)
|
| 406 |
|
| 407 |
-
#
|
| 408 |
query = vector_search
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
class RetrievalMode(str, Enum):
    """Supported retrieval modes.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # Vector search only.
    VECTOR_ONLY = "vector_only"
    # BM25 keyword search only.
    BM25_ONLY = "bm25_only"
    # Vector search combined with BM25.
    HYBRID = "hybrid"
    # Hybrid search followed by reranking.
    HYBRID_RERANK = "hybrid_rerank"
|
| 30 |
|
| 31 |
|
| 32 |
@dataclass
class RetrievalConfig:
    """Settings for the retrieval system.

    Field order is part of the positional constructor signature and must
    not be changed.
    """

    # Base URL of the reranker API.
    rerank_api_base_url: str = "https://api.siliconflow.com/v1"
    # Identifier of the reranker model.
    rerank_model: str = "Qwen/Qwen3-Reranker-4B"
    # Number of results kept after reranking.
    rerank_top_n: int = 10
    # Number of documents fetched in the initial pass.
    initial_k: int = 25
    # Number of results returned in the end.
    top_k: int = 5
    # Weight of the vector-search retriever.
    vector_weight: float = 0.5
    # Weight of the BM25 retriever.
    bm25_weight: float = 0.5
|
| 42 |
|
| 43 |
|
| 44 |
_retrieval_config: RetrievalConfig | None = None
|
| 45 |
|
| 46 |
|
| 47 |
def get_retrieval_config() -> RetrievalConfig:
|
| 48 |
+
"""Lấy cấu hình retrieval (singleton pattern)."""
|
| 49 |
global _retrieval_config
|
| 50 |
if _retrieval_config is None:
|
| 51 |
_retrieval_config = RetrievalConfig()
|
|
|
|
| 53 |
|
| 54 |
|
| 55 |
class SiliconFlowReranker(BaseDocumentCompressor):
|
| 56 |
+
"""Reranker sử dụng SiliconFlow API để sắp xếp lại kết quả."""
|
| 57 |
api_key: str = Field(default="")
|
| 58 |
api_base_url: str = Field(default="")
|
| 59 |
model: str = Field(default="")
|
|
|
|
| 68 |
query: str,
|
| 69 |
callbacks: Optional[Callbacks] = None,
|
| 70 |
) -> Sequence[Document]:
|
| 71 |
+
"""Rerank documents dựa trên độ liên quan với query."""
|
| 72 |
if not documents or not self.api_key:
|
| 73 |
return list(documents)
|
| 74 |
|
| 75 |
+
# Retry logic với exponential backoff
|
| 76 |
for attempt in range(3):
|
| 77 |
try:
|
| 78 |
response = requests.post(
|
|
|
|
| 95 |
if "results" not in data:
|
| 96 |
return list(documents)
|
| 97 |
|
| 98 |
+
# Tạo danh sách documents đã rerank với score
|
| 99 |
reranked: List[Document] = []
|
| 100 |
for result in data["results"]:
|
| 101 |
doc = documents[result["index"]]
|
|
|
|
| 106 |
return reranked
|
| 107 |
|
| 108 |
except Exception as e:
|
| 109 |
+
# Rate limit -> đợi rồi thử lại
|
| 110 |
if "rate" in str(e).lower() and attempt < 2:
|
| 111 |
time.sleep(2 ** attempt)
|
| 112 |
else:
|
| 113 |
+
logger.error(f"Lỗi rerank: {e}")
|
| 114 |
return list(documents)
|
| 115 |
|
| 116 |
return list(documents)
|
| 117 |
|
| 118 |
|
|
|
|
|
|
|
| 119 |
class Retriever:
|
| 120 |
+
"""Retriever chính hỗ trợ nhiều chế độ tìm kiếm."""
|
| 121 |
+
|
| 122 |
def __init__(self, vector_db: "ChromaVectorDB", use_reranker: bool = True):
|
| 123 |
+
"""Khởi tạo retriever với vector DB và reranker."""
|
| 124 |
self._vector_db = vector_db
|
| 125 |
self._config = get_retrieval_config()
|
| 126 |
self._reranker: Optional[SiliconFlowReranker] = None
|
| 127 |
|
| 128 |
+
# Vector retriever từ ChromaDB
|
| 129 |
self._vector_retriever = self._vector_db.vectorstore.as_retriever(
|
| 130 |
search_kwargs={"k": self._config.initial_k}
|
| 131 |
)
|
| 132 |
|
| 133 |
+
# Lazy-load BM25 - chỉ khởi tạo khi cần
|
| 134 |
self._bm25_retriever: Optional[BM25Retriever] = None
|
| 135 |
self._bm25_initialized = False
|
| 136 |
self._ensemble_retriever: Optional[EnsembleRetriever] = None
|
| 137 |
|
| 138 |
+
# Đường dẫn cache BM25 (lưu vào disk)
|
| 139 |
from pathlib import Path
|
| 140 |
persist_dir = getattr(self._vector_db.config, 'persist_dir', None)
|
| 141 |
if persist_dir:
|
|
|
|
| 146 |
if use_reranker:
|
| 147 |
self._reranker = self._init_reranker()
|
| 148 |
|
| 149 |
+
logger.info("Đã khởi tạo Retriever")
|
| 150 |
|
|
|
|
| 151 |
def _save_bm25_cache(self, bm25: BM25Retriever) -> None:
    """Persist the BM25 index to the cache file (best effort).

    The index is pickled to a temporary sibling file and then moved over
    the real cache path with ``os.replace``. A failed or interrupted dump
    therefore never leaves a truncated file at the cache path, and any
    previously saved cache stays intact.

    Args:
        bm25: The BM25 retriever to serialise.

    Failures are logged as warnings and never raised; nothing is written
    when no cache path is configured.
    """
    if not self._bm25_cache_path:
        return

    import os
    import pickle

    # The PID in the name keeps concurrent writers from sharing a temp file.
    tmp_path = f"{self._bm25_cache_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(bm25, f)
        os.replace(tmp_path, self._bm25_cache_path)
        logger.info(f"Đã lưu BM25 cache vào {self._bm25_cache_path}")
    except Exception as e:
        logger.warning(f"Không thể lưu BM25 cache: {e}")
        # Drop the partial temp file; it may not exist if open() itself failed.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
|
| 162 |
|
| 163 |
def _load_bm25_cache(self) -> Optional[BM25Retriever]:
    """Load the BM25 index from the cache file, if one exists.

    Returns:
        The unpickled retriever with ``k`` reset to the configured
        ``initial_k``, or ``None`` when no cache path is set, the file is
        missing, or loading fails (the failure is logged as a warning).
    """
    cache_file = self._bm25_cache_path
    if not (cache_file and cache_file.exists()):
        return None

    try:
        import pickle
        began = time.time()
        # NOTE(review): pickle.load executes arbitrary code from the file —
        # confirm the cache directory is only writable by trusted users.
        with open(cache_file, 'rb') as fh:
            restored = pickle.load(fh)
        restored.k = self._config.initial_k
        logger.info(f"Đã tải BM25 từ cache trong {time.time() - began:.2f}s")
    except Exception as err:
        logger.warning(f"Không thể tải BM25 cache: {err}")
        return None
    return restored
|
| 178 |
+
|
|
|
|
|
|
|
| 179 |
def _init_bm25(self) -> Optional[BM25Retriever]:
|
| 180 |
+
"""Khởi tạo BM25 retriever (lazy-load với cache)."""
|
| 181 |
if self._bm25_initialized:
|
| 182 |
return self._bm25_retriever
|
| 183 |
|
| 184 |
self._bm25_initialized = True
|
| 185 |
|
| 186 |
+
# Thử tải từ cache trước
|
| 187 |
cached = self._load_bm25_cache()
|
| 188 |
if cached:
|
| 189 |
self._bm25_retriever = cached
|
| 190 |
return cached
|
| 191 |
|
| 192 |
+
# Build từ đầu nếu không có cache
|
| 193 |
try:
|
|
|
|
| 194 |
start = time.time()
|
| 195 |
+
logger.info("Đang xây dựng BM25 index từ documents...")
|
| 196 |
|
| 197 |
docs = self._vector_db.get_all_documents()
|
| 198 |
if not docs:
|
| 199 |
+
logger.warning("Không tìm thấy documents cho BM25")
|
| 200 |
return None
|
| 201 |
|
| 202 |
lc_docs = [
|
|
|
|
| 207 |
bm25.k = self._config.initial_k
|
| 208 |
|
| 209 |
self._bm25_retriever = bm25
|
| 210 |
+
logger.info(f"Đã xây dựng BM25 với {len(docs)} docs trong {time.time() - start:.2f}s")
|
| 211 |
|
| 212 |
+
# Lưu vào cache cho lần sau
|
| 213 |
self._save_bm25_cache(bm25)
|
| 214 |
|
| 215 |
return bm25
|
| 216 |
except Exception as e:
|
| 217 |
+
logger.error(f"Không thể khởi tạo BM25: {e}")
|
| 218 |
return None
|
| 219 |
|
|
|
|
| 220 |
def _get_ensemble_retriever(self) -> EnsembleRetriever:
|
| 221 |
+
"""Lấy ensemble retriever (vector + BM25)."""
|
| 222 |
if self._ensemble_retriever is not None:
|
| 223 |
return self._ensemble_retriever
|
| 224 |
|
|
|
|
| 229 |
weights=[self._config.vector_weight, self._config.bm25_weight]
|
| 230 |
)
|
| 231 |
else:
|
| 232 |
+
# Fallback về vector only
|
| 233 |
self._ensemble_retriever = EnsembleRetriever(
|
| 234 |
retrievers=[self._vector_retriever],
|
| 235 |
weights=[1.0]
|
| 236 |
)
|
| 237 |
return self._ensemble_retriever
|
| 238 |
|
|
|
|
| 239 |
def _init_reranker(self) -> Optional[SiliconFlowReranker]:
|
| 240 |
+
"""Khởi tạo reranker nếu có API key."""
|
| 241 |
api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
|
| 242 |
if not api_key:
|
| 243 |
return None
|
|
|
|
| 249 |
)
|
| 250 |
|
| 251 |
def _build_final(self):
|
| 252 |
+
"""Build retriever cuối cùng (ensemble + reranker nếu có)."""
|
| 253 |
ensemble = self._get_ensemble_retriever()
|
| 254 |
if self._reranker:
|
| 255 |
return ContextualCompressionRetriever(
|
|
|
|
| 258 |
)
|
| 259 |
return ensemble
|
| 260 |
|
|
|
|
| 261 |
@property
def has_reranker(self) -> bool:
    """Report whether a reranker instance is attached to this retriever."""
    reranker = self._reranker
    return reranker is not None
|
| 265 |
|
| 266 |
def _to_result(self, doc: Document, rank: int, **extra) -> Dict[str, Any]:
|
| 267 |
+
"""Chuyển Document thành dict result, xử lý Small-to-Big."""
|
| 268 |
metadata = doc.metadata or {}
|
| 269 |
content = doc.page_content
|
| 270 |
|
| 271 |
+
# Small-to-Big: Nếu là summary node -> swap với parent (bảng gốc)
|
| 272 |
if metadata.get("is_table_summary") and metadata.get("parent_id"):
|
| 273 |
parent = self._vector_db.get_parent_node(metadata["parent_id"])
|
| 274 |
if parent:
|
| 275 |
content = parent.get("content", content)
|
| 276 |
+
# Merge metadata, giữ lại info summary để debug
|
| 277 |
metadata = {
|
| 278 |
**parent.get("metadata", {}),
|
| 279 |
"original_summary": doc.page_content[:200],
|
|
|
|
| 288 |
**extra,
|
| 289 |
}
|
| 290 |
|
|
|
|
| 291 |
def vector_search(
|
| 292 |
self, text: str, *, k: int | None = None, where: Optional[Dict[str, Any]] = None
|
| 293 |
) -> List[Dict[str, Any]]:
|
| 294 |
+
"""Tìm kiếm bằng vector similarity."""
|
| 295 |
if not text.strip():
|
| 296 |
return []
|
| 297 |
k = k or self._config.top_k
|
|
|
|
| 299 |
return [self._to_result(doc, i + 1, distance=score) for i, (doc, score) in enumerate(results)]
|
| 300 |
|
| 301 |
def bm25_search(self, text: str, *, k: int | None = None) -> List[Dict[str, Any]]:
|
| 302 |
+
"""Tìm kiếm bằng BM25 keyword matching."""
|
| 303 |
if not text.strip():
|
| 304 |
return []
|
|
|
|
| 305 |
bm25 = self._init_bm25() # Lazy-load BM25
|
| 306 |
if not bm25:
|
| 307 |
return self.vector_search(text, k=k)
|
|
|
|
| 308 |
k = k or self._config.top_k
|
| 309 |
bm25.k = k
|
| 310 |
results = bm25.invoke(text)
|
|
|
|
| 313 |
def hybrid_search(
|
| 314 |
self, text: str, *, k: int | None = None, initial_k: int | None = None
|
| 315 |
) -> List[Dict[str, Any]]:
|
| 316 |
+
"""Tìm kiếm hybrid (vector + BM25) không có rerank."""
|
| 317 |
if not text.strip():
|
| 318 |
return []
|
|
|
|
| 319 |
k = k or self._config.top_k
|
| 320 |
if initial_k:
|
| 321 |
self._vector_retriever.search_kwargs["k"] = initial_k
|
|
|
|
| 323 |
if bm25:
|
| 324 |
bm25.k = initial_k
|
| 325 |
|
|
|
|
| 326 |
ensemble = self._get_ensemble_retriever()
|
| 327 |
results = ensemble.invoke(text)
|
| 328 |
return [self._to_result(doc, i + 1) for i, doc in enumerate(results[:k])]
|
|
|
|
| 335 |
where: Optional[Dict[str, Any]] = None,
|
| 336 |
initial_k: int | None = None,
|
| 337 |
) -> List[Dict[str, Any]]:
|
| 338 |
+
"""Tìm kiếm hybrid + reranking để có kết quả tốt nhất."""
|
|
|
|
| 339 |
if not text.strip():
|
| 340 |
return []
|
|
|
|
| 341 |
k = k or self._config.top_k
|
| 342 |
initial_k = initial_k or self._config.initial_k
|
| 343 |
|
|
|
|
| 351 |
for i, doc in enumerate(results[:k])
|
| 352 |
]
|
| 353 |
|
| 354 |
+
# Cập nhật k cho initial fetch
|
| 355 |
if initial_k:
|
| 356 |
self._vector_retriever.search_kwargs["k"] = initial_k
|
| 357 |
bm25 = self._init_bm25()
|
| 358 |
if bm25:
|
| 359 |
bm25.k = initial_k
|
| 360 |
|
| 361 |
+
# Hybrid search
|
| 362 |
ensemble = self._get_ensemble_retriever()
|
| 363 |
ensemble_results = ensemble.invoke(text)
|
| 364 |
|
| 365 |
+
# Rerank nếu có
|
| 366 |
if self._reranker:
|
| 367 |
results = self._reranker.compress_documents(ensemble_results, text)
|
| 368 |
else:
|
|
|
|
| 373 |
for i, doc in enumerate(results[:k])
|
| 374 |
]
|
| 375 |
|
|
|
|
|
|
|
| 376 |
def flexible_search(
|
| 377 |
self,
|
| 378 |
text: str,
|
|
|
|
| 382 |
initial_k: int | None = None,
|
| 383 |
where: Optional[Dict[str, Any]] = None,
|
| 384 |
) -> List[Dict[str, Any]]:
|
| 385 |
+
"""Tìm kiếm linh hoạt với nhiều chế độ."""
|
| 386 |
if not text.strip():
|
| 387 |
return []
|
| 388 |
|
| 389 |
+
# Parse mode từ string
|
| 390 |
if isinstance(mode, str):
|
| 391 |
try:
|
| 392 |
mode = RetrievalMode(mode.lower())
|
|
|
|
| 396 |
k = k or self._config.top_k
|
| 397 |
initial_k = initial_k or self._config.initial_k
|
| 398 |
|
| 399 |
+
# Gọi method tương ứng theo mode
|
| 400 |
if mode == RetrievalMode.VECTOR_ONLY:
|
| 401 |
return self.vector_search(text, k=k, where=where)
|
| 402 |
elif mode == RetrievalMode.BM25_ONLY:
|
|
|
|
| 408 |
else: # HYBRID_RERANK
|
| 409 |
return self.search_with_rerank(text, k=k, where=where, initial_k=initial_k)
|
| 410 |
|
| 411 |
+
# Alias để tương thích ngược
|
| 412 |
query = vector_search
|
core/rag/vector_store.py
CHANGED
|
@@ -13,66 +13,76 @@ logger = logging.getLogger(__name__)
|
|
| 13 |
|
| 14 |
@dataclass
|
| 15 |
class ChromaConfig:
|
|
|
|
|
|
|
| 16 |
def _default_persist_dir() -> str:
|
|
|
|
| 17 |
repo_root = Path(__file__).resolve().parents[2]
|
| 18 |
return str((repo_root / "data" / "chroma").resolve())
|
| 19 |
|
| 20 |
-
persist_dir: str = field(default_factory=_default_persist_dir)
|
| 21 |
-
collection_name: str = "hust_rag_collection"
|
| 22 |
|
| 23 |
|
| 24 |
class ChromaVectorDB:
|
|
|
|
|
|
|
| 25 |
def __init__(
|
| 26 |
self,
|
| 27 |
embedder: Any,
|
| 28 |
config: ChromaConfig | None = None,
|
| 29 |
):
|
|
|
|
| 30 |
self.embedder = embedder
|
| 31 |
self.config = config or ChromaConfig()
|
| 32 |
self._hasher = HashProcessor(verbose=False)
|
| 33 |
|
| 34 |
-
#
|
| 35 |
-
# Persist to JSON file in same directory as ChromaDB
|
| 36 |
self._parent_nodes_path = Path(self.config.persist_dir) / "parent_nodes.json"
|
| 37 |
self._parent_nodes: Dict[str, Dict[str, Any]] = self._load_parent_nodes()
|
| 38 |
|
|
|
|
| 39 |
self._vs = Chroma(
|
| 40 |
collection_name=self.config.collection_name,
|
| 41 |
embedding_function=self.embedder,
|
| 42 |
persist_directory=self.config.persist_dir,
|
| 43 |
)
|
| 44 |
-
logger.info(f"ChromaVectorDB
|
| 45 |
|
| 46 |
def _load_parent_nodes(self) -> Dict[str, Dict[str, Any]]:
|
|
|
|
| 47 |
if self._parent_nodes_path.exists():
|
| 48 |
try:
|
| 49 |
with open(self._parent_nodes_path, 'r', encoding='utf-8') as f:
|
| 50 |
data = json.load(f)
|
| 51 |
-
logger.info(f"
|
| 52 |
return data
|
| 53 |
except Exception as e:
|
| 54 |
-
logger.warning(f"
|
| 55 |
return {}
|
| 56 |
|
| 57 |
def _save_parent_nodes(self) -> None:
|
| 58 |
-
"""
|
| 59 |
try:
|
| 60 |
self._parent_nodes_path.parent.mkdir(parents=True, exist_ok=True)
|
| 61 |
with open(self._parent_nodes_path, 'w', encoding='utf-8') as f:
|
| 62 |
json.dump(self._parent_nodes, f, ensure_ascii=False, indent=2)
|
| 63 |
-
logger.info(f"
|
| 64 |
except Exception as e:
|
| 65 |
-
logger.warning(f"
|
| 66 |
|
| 67 |
@property
|
| 68 |
def collection(self):
|
|
|
|
| 69 |
return getattr(self._vs, "_collection", None)
|
| 70 |
|
| 71 |
@property
|
| 72 |
def vectorstore(self):
|
|
|
|
| 73 |
return self._vs
|
| 74 |
|
| 75 |
def _flatten_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
| 76 |
out: Dict[str, Any] = {}
|
| 77 |
for k, v in (metadata or {}).items():
|
| 78 |
if v is None:
|
|
@@ -80,33 +90,33 @@ class ChromaVectorDB:
|
|
| 80 |
if isinstance(v, (str, int, float, bool)):
|
| 81 |
out[str(k)] = v
|
| 82 |
elif isinstance(v, (list, tuple, set, dict)):
|
|
|
|
| 83 |
out[str(k)] = json.dumps(v, ensure_ascii=False)
|
| 84 |
else:
|
| 85 |
out[str(k)] = str(v)
|
| 86 |
return out
|
| 87 |
|
| 88 |
def _normalize_doc(self, doc: Any) -> Dict[str, Any]:
|
| 89 |
-
|
|
|
|
| 90 |
if isinstance(doc, dict):
|
| 91 |
return doc
|
| 92 |
-
|
| 93 |
-
# Nếu là TextNode/BaseNode từ llama_index
|
| 94 |
if hasattr(doc, "get_content") and hasattr(doc, "metadata"):
|
| 95 |
return {
|
| 96 |
"content": doc.get_content(),
|
| 97 |
"metadata": dict(doc.metadata) if doc.metadata else {},
|
| 98 |
}
|
| 99 |
-
|
| 100 |
-
# Nếu là Document từ langchain
|
| 101 |
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
| 102 |
return {
|
| 103 |
"content": doc.page_content,
|
| 104 |
"metadata": dict(doc.metadata) if doc.metadata else {},
|
| 105 |
}
|
| 106 |
-
|
| 107 |
-
raise TypeError(f"Unsupported document type: {type(doc)}")
|
| 108 |
|
| 109 |
def _to_documents(self, docs: Sequence[Any], ids: Sequence[str]) -> List[Document]:
|
|
|
|
| 110 |
out: List[Document] = []
|
| 111 |
for d, doc_id in zip(docs, ids):
|
| 112 |
normalized = self._normalize_doc(d)
|
|
@@ -116,6 +126,7 @@ class ChromaVectorDB:
|
|
| 116 |
return out
|
| 117 |
|
| 118 |
def _doc_id(self, doc: Any) -> str:
|
|
|
|
| 119 |
normalized = self._normalize_doc(doc)
|
| 120 |
md = normalized.get("metadata") or {}
|
| 121 |
key = {
|
|
@@ -133,13 +144,14 @@ class ChromaVectorDB:
|
|
| 133 |
ids: Optional[Sequence[str]] = None,
|
| 134 |
batch_size: int = 128,
|
| 135 |
) -> int:
|
|
|
|
| 136 |
if not docs:
|
| 137 |
return 0
|
| 138 |
|
| 139 |
if ids is not None and len(ids) != len(docs):
|
| 140 |
-
raise ValueError("ids
|
| 141 |
|
| 142 |
-
#
|
| 143 |
regular_docs = []
|
| 144 |
regular_ids = []
|
| 145 |
parent_count = 0
|
|
@@ -150,7 +162,7 @@ class ChromaVectorDB:
|
|
| 150 |
doc_id = ids[i] if ids else self._doc_id(d)
|
| 151 |
|
| 152 |
if md.get("is_parent"):
|
| 153 |
-
#
|
| 154 |
parent_id = md.get("node_id", doc_id)
|
| 155 |
self._parent_nodes[parent_id] = {
|
| 156 |
"id": parent_id,
|
|
@@ -163,12 +175,13 @@ class ChromaVectorDB:
|
|
| 163 |
regular_ids.append(doc_id)
|
| 164 |
|
| 165 |
if parent_count > 0:
|
| 166 |
-
logger.info(f"
|
| 167 |
-
self._save_parent_nodes()
|
| 168 |
|
| 169 |
if not regular_docs:
|
| 170 |
return parent_count
|
| 171 |
|
|
|
|
| 172 |
bs = max(1, batch_size)
|
| 173 |
total = 0
|
| 174 |
|
|
@@ -180,12 +193,13 @@ class ChromaVectorDB:
|
|
| 180 |
try:
|
| 181 |
self._vs.add_documents(lc_docs, ids=batch_ids)
|
| 182 |
except TypeError:
|
|
|
|
| 183 |
texts = [d.page_content for d in lc_docs]
|
| 184 |
metas = [d.metadata for d in lc_docs]
|
| 185 |
self._vs.add_texts(texts=texts, metadatas=metas, ids=batch_ids)
|
| 186 |
total += len(batch)
|
| 187 |
|
| 188 |
-
logger.info(f"
|
| 189 |
return total + parent_count
|
| 190 |
|
| 191 |
def upsert_documents(
|
|
@@ -195,13 +209,14 @@ class ChromaVectorDB:
|
|
| 195 |
ids: Optional[Sequence[str]] = None,
|
| 196 |
batch_size: int = 128,
|
| 197 |
) -> int:
|
|
|
|
| 198 |
if not docs:
|
| 199 |
return 0
|
| 200 |
|
| 201 |
if ids is not None and len(ids) != len(docs):
|
| 202 |
-
raise ValueError("ids
|
| 203 |
|
| 204 |
-
#
|
| 205 |
regular_docs = []
|
| 206 |
regular_ids = []
|
| 207 |
parent_count = 0
|
|
@@ -212,7 +227,7 @@ class ChromaVectorDB:
|
|
| 212 |
doc_id = ids[i] if ids else self._doc_id(d)
|
| 213 |
|
| 214 |
if md.get("is_parent"):
|
| 215 |
-
#
|
| 216 |
parent_id = md.get("node_id", doc_id)
|
| 217 |
self._parent_nodes[parent_id] = {
|
| 218 |
"id": parent_id,
|
|
@@ -225,8 +240,8 @@ class ChromaVectorDB:
|
|
| 225 |
regular_ids.append(doc_id)
|
| 226 |
|
| 227 |
if parent_count > 0:
|
| 228 |
-
logger.info(f"
|
| 229 |
-
self._save_parent_nodes()
|
| 230 |
|
| 231 |
if not regular_docs:
|
| 232 |
return parent_count
|
|
@@ -234,9 +249,11 @@ class ChromaVectorDB:
|
|
| 234 |
bs = max(1, batch_size)
|
| 235 |
col = self.collection
|
| 236 |
|
|
|
|
| 237 |
if col is None:
|
| 238 |
return self.add_documents(regular_docs, ids=regular_ids, batch_size=bs) + parent_count
|
| 239 |
|
|
|
|
| 240 |
total = 0
|
| 241 |
for start in range(0, len(regular_docs), bs):
|
| 242 |
batch = regular_docs[start : start + bs]
|
|
@@ -248,14 +265,16 @@ class ChromaVectorDB:
|
|
| 248 |
col.upsert(ids=batch_ids, documents=texts, metadatas=metas, embeddings=embs)
|
| 249 |
total += len(batch)
|
| 250 |
|
| 251 |
-
logger.info(f"
|
| 252 |
return total + parent_count
|
| 253 |
|
| 254 |
def count(self) -> int:
|
|
|
|
| 255 |
col = self.collection
|
| 256 |
return int(col.count()) if col else 0
|
| 257 |
|
| 258 |
def get_all_documents(self, limit: int = 5000) -> List[Dict[str, Any]]:
|
|
|
|
| 259 |
col = self.collection
|
| 260 |
if col is None:
|
| 261 |
return []
|
|
@@ -272,6 +291,7 @@ class ChromaVectorDB:
|
|
| 272 |
return docs
|
| 273 |
|
| 274 |
def delete_documents(self, ids: Sequence[str]) -> int:
|
|
|
|
| 275 |
if not ids:
|
| 276 |
return 0
|
| 277 |
|
|
@@ -280,12 +300,14 @@ class ChromaVectorDB:
|
|
| 280 |
return 0
|
| 281 |
|
| 282 |
col.delete(ids=list(ids))
|
| 283 |
-
logger.info(f"
|
| 284 |
return len(ids)
|
| 285 |
|
| 286 |
def get_parent_node(self, parent_id: str) -> Optional[Dict[str, Any]]:
|
|
|
|
| 287 |
return self._parent_nodes.get(parent_id)
|
| 288 |
|
| 289 |
@property
|
| 290 |
def parent_nodes(self) -> Dict[str, Dict[str, Any]]:
|
|
|
|
| 291 |
return self._parent_nodes
|
|
|
|
| 13 |
|
| 14 |
@dataclass
class ChromaConfig:
    """Configuration for the persistent Chroma collection."""

    def _default_persist_dir() -> str:
        """Return the default persist directory as an absolute path string."""
        # The third ancestor of this file is treated as the repository root.
        chroma_dir = Path(__file__).resolve().parents[2].joinpath("data", "chroma")
        return str(chroma_dir.resolve())

    # Directory where the database is stored.
    persist_dir: str = field(default_factory=_default_persist_dir)
    # Name of the collection.
    collection_name: str = "hust_rag_collection"
|
| 25 |
|
| 26 |
|
| 27 |
class ChromaVectorDB:
|
| 28 |
+
"""Wrapper cho ChromaDB với hỗ trợ Small-to-Big retrieval."""
|
| 29 |
+
|
| 30 |
def __init__(
    self,
    embedder: Any,
    config: ChromaConfig | None = None,
):
    """Set up the vector store from an embedder and an optional config.

    Args:
        embedder: Object handed to Chroma as its ``embedding_function``.
        config: Store settings; a default ``ChromaConfig`` is created when
            this is falsy.
    """
    self.embedder = embedder
    self.config = config if config else ChromaConfig()
    self._hasher = HashProcessor(verbose=False)

    cfg = self.config

    # Parent nodes are kept in a JSON file next to the database rather than
    # being embedded (used for Small-to-Big retrieval).
    self._parent_nodes_path = Path(cfg.persist_dir) / "parent_nodes.json"
    self._parent_nodes: Dict[str, Dict[str, Any]] = self._load_parent_nodes()

    # Create the underlying Chroma store.
    self._vs = Chroma(
        collection_name=cfg.collection_name,
        embedding_function=self.embedder,
        persist_directory=cfg.persist_dir,
    )
    logger.info(f"Đã khởi tạo ChromaVectorDB: {cfg.collection_name}")
|
| 51 |
|
| 52 |
def _load_parent_nodes(self) -> Dict[str, Dict[str, Any]]:
|
| 53 |
+
"""Tải parent nodes từ file JSON."""
|
| 54 |
if self._parent_nodes_path.exists():
|
| 55 |
try:
|
| 56 |
with open(self._parent_nodes_path, 'r', encoding='utf-8') as f:
|
| 57 |
data = json.load(f)
|
| 58 |
+
logger.info(f"Đã tải {len(data)} parent nodes từ {self._parent_nodes_path}")
|
| 59 |
return data
|
| 60 |
except Exception as e:
|
| 61 |
+
logger.warning(f"Không thể tải parent nodes: {e}")
|
| 62 |
return {}
|
| 63 |
|
| 64 |
def _save_parent_nodes(self) -> None:
    """Write the parent-node store to its JSON file.

    The data is first written to a sibling temporary file and then moved
    over the real file. Opening the real file for writing would truncate it
    immediately, so a failure while serializing would destroy the previously
    saved store. Saving is best-effort: any failure is logged as a warning
    and not raised.
    """
    target = self._parent_nodes_path
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(self._parent_nodes, f, ensure_ascii=False, indent=2)
        # Same-directory rename, so the old file stays intact until the new
        # one is completely written.
        tmp_path.replace(target)
        logger.info(f"Đã lưu {len(self._parent_nodes)} parent nodes vào {self._parent_nodes_path}")
    except Exception as e:
        logger.warning(f"Không thể lưu parent nodes: {e}")
        # Do not leave a partial temporary file behind.
        try:
            tmp_path.unlink()
        except OSError:
            pass
|
| 73 |
|
| 74 |
@property
def collection(self):
    """Return the ``_collection`` attribute of the wrapped store, or ``None`` if it has none."""
    store = self._vs
    try:
        return store._collection
    except AttributeError:
        return None
|
| 78 |
|
| 79 |
@property
def vectorstore(self):
    """Return the wrapped LangChain Chroma vector store."""
    store = self._vs
    return store
|
| 83 |
|
| 84 |
def _flatten_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
| 85 |
+
"""Chuyển metadata phức tạp thành format ChromaDB hỗ trợ."""
|
| 86 |
out: Dict[str, Any] = {}
|
| 87 |
for k, v in (metadata or {}).items():
|
| 88 |
if v is None:
|
|
|
|
| 90 |
if isinstance(v, (str, int, float, bool)):
|
| 91 |
out[str(k)] = v
|
| 92 |
elif isinstance(v, (list, tuple, set, dict)):
|
| 93 |
+
# Chuyển list/dict thành JSON string
|
| 94 |
out[str(k)] = json.dumps(v, ensure_ascii=False)
|
| 95 |
else:
|
| 96 |
out[str(k)] = str(v)
|
| 97 |
return out
|
| 98 |
|
| 99 |
def _normalize_doc(self, doc: Any) -> Dict[str, Any]:
|
| 100 |
+
"""Chuẩn hóa document từ nhiều format khác nhau thành dict."""
|
| 101 |
+
# Đã là dict
|
| 102 |
if isinstance(doc, dict):
|
| 103 |
return doc
|
| 104 |
+
# TextNode/BaseNode từ llama_index
|
|
|
|
| 105 |
if hasattr(doc, "get_content") and hasattr(doc, "metadata"):
|
| 106 |
return {
|
| 107 |
"content": doc.get_content(),
|
| 108 |
"metadata": dict(doc.metadata) if doc.metadata else {},
|
| 109 |
}
|
| 110 |
+
# Document từ LangChain
|
|
|
|
| 111 |
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
| 112 |
return {
|
| 113 |
"content": doc.page_content,
|
| 114 |
"metadata": dict(doc.metadata) if doc.metadata else {},
|
| 115 |
}
|
| 116 |
+
raise TypeError(f"Không hỗ trợ loại document: {type(doc)}")
|
|
|
|
| 117 |
|
| 118 |
def _to_documents(self, docs: Sequence[Any], ids: Sequence[str]) -> List[Document]:
|
| 119 |
+
"""Chuyển danh sách docs thành LangChain Documents."""
|
| 120 |
out: List[Document] = []
|
| 121 |
for d, doc_id in zip(docs, ids):
|
| 122 |
normalized = self._normalize_doc(d)
|
|
|
|
| 126 |
return out
|
| 127 |
|
| 128 |
def _doc_id(self, doc: Any) -> str:
|
| 129 |
+
"""Tạo ID duy nhất cho document dựa trên nội dung."""
|
| 130 |
normalized = self._normalize_doc(doc)
|
| 131 |
md = normalized.get("metadata") or {}
|
| 132 |
key = {
|
|
|
|
| 144 |
ids: Optional[Sequence[str]] = None,
|
| 145 |
batch_size: int = 128,
|
| 146 |
) -> int:
|
| 147 |
+
"""Thêm documents vào vector store."""
|
| 148 |
if not docs:
|
| 149 |
return 0
|
| 150 |
|
| 151 |
if ids is not None and len(ids) != len(docs):
|
| 152 |
+
raise ValueError("Số lượng ids phải bằng số lượng docs")
|
| 153 |
|
| 154 |
+
# Tách parent nodes (không embed) khỏi regular nodes
|
| 155 |
regular_docs = []
|
| 156 |
regular_ids = []
|
| 157 |
parent_count = 0
|
|
|
|
| 162 |
doc_id = ids[i] if ids else self._doc_id(d)
|
| 163 |
|
| 164 |
if md.get("is_parent"):
|
| 165 |
+
# Lưu parent node riêng (cho Small-to-Big)
|
| 166 |
parent_id = md.get("node_id", doc_id)
|
| 167 |
self._parent_nodes[parent_id] = {
|
| 168 |
"id": parent_id,
|
|
|
|
| 175 |
regular_ids.append(doc_id)
|
| 176 |
|
| 177 |
if parent_count > 0:
|
| 178 |
+
logger.info(f"Đã lưu {parent_count} parent nodes (không embed)")
|
| 179 |
+
self._save_parent_nodes()
|
| 180 |
|
| 181 |
if not regular_docs:
|
| 182 |
return parent_count
|
| 183 |
|
| 184 |
+
# Thêm theo batch
|
| 185 |
bs = max(1, batch_size)
|
| 186 |
total = 0
|
| 187 |
|
|
|
|
| 193 |
try:
|
| 194 |
self._vs.add_documents(lc_docs, ids=batch_ids)
|
| 195 |
except TypeError:
|
| 196 |
+
# Fallback nếu add_documents không nhận ids
|
| 197 |
texts = [d.page_content for d in lc_docs]
|
| 198 |
metas = [d.metadata for d in lc_docs]
|
| 199 |
self._vs.add_texts(texts=texts, metadatas=metas, ids=batch_ids)
|
| 200 |
total += len(batch)
|
| 201 |
|
| 202 |
+
logger.info(f"Đã thêm {total} documents vào vector store")
|
| 203 |
return total + parent_count
|
| 204 |
|
| 205 |
def upsert_documents(
|
|
|
|
| 209 |
ids: Optional[Sequence[str]] = None,
|
| 210 |
batch_size: int = 128,
|
| 211 |
) -> int:
|
| 212 |
+
"""Upsert documents (thêm mới hoặc cập nhật nếu đã tồn tại)."""
|
| 213 |
if not docs:
|
| 214 |
return 0
|
| 215 |
|
| 216 |
if ids is not None and len(ids) != len(docs):
|
| 217 |
+
raise ValueError("Số lượng ids phải bằng số lượng docs")
|
| 218 |
|
| 219 |
+
# Tách parent nodes khỏi regular nodes
|
| 220 |
regular_docs = []
|
| 221 |
regular_ids = []
|
| 222 |
parent_count = 0
|
|
|
|
| 227 |
doc_id = ids[i] if ids else self._doc_id(d)
|
| 228 |
|
| 229 |
if md.get("is_parent"):
|
| 230 |
+
# Lưu parent node riêng
|
| 231 |
parent_id = md.get("node_id", doc_id)
|
| 232 |
self._parent_nodes[parent_id] = {
|
| 233 |
"id": parent_id,
|
|
|
|
| 240 |
regular_ids.append(doc_id)
|
| 241 |
|
| 242 |
if parent_count > 0:
|
| 243 |
+
logger.info(f"Đã lưu {parent_count} parent nodes (không embed)")
|
| 244 |
+
self._save_parent_nodes()
|
| 245 |
|
| 246 |
if not regular_docs:
|
| 247 |
return parent_count
|
|
|
|
| 249 |
bs = max(1, batch_size)
|
| 250 |
col = self.collection
|
| 251 |
|
| 252 |
+
# Fallback nếu không có collection
|
| 253 |
if col is None:
|
| 254 |
return self.add_documents(regular_docs, ids=regular_ids, batch_size=bs) + parent_count
|
| 255 |
|
| 256 |
+
# Upsert theo batch
|
| 257 |
total = 0
|
| 258 |
for start in range(0, len(regular_docs), bs):
|
| 259 |
batch = regular_docs[start : start + bs]
|
|
|
|
| 265 |
col.upsert(ids=batch_ids, documents=texts, metadatas=metas, embeddings=embs)
|
| 266 |
total += len(batch)
|
| 267 |
|
| 268 |
+
logger.info(f"Đã upsert {total} documents vào vector store")
|
| 269 |
return total + parent_count
|
| 270 |
|
| 271 |
def count(self) -> int:
    """Return the number of documents in the collection, or 0 when there is no collection."""
    col = self.collection
    if not col:
        return 0
    return int(col.count())
|
| 275 |
|
| 276 |
def get_all_documents(self, limit: int = 5000) -> List[Dict[str, Any]]:
|
| 277 |
+
"""Lấy tất cả documents từ collection."""
|
| 278 |
col = self.collection
|
| 279 |
if col is None:
|
| 280 |
return []
|
|
|
|
| 291 |
return docs
|
| 292 |
|
| 293 |
def delete_documents(self, ids: Sequence[str]) -> int:
|
| 294 |
+
"""Xóa documents theo danh sách IDs."""
|
| 295 |
if not ids:
|
| 296 |
return 0
|
| 297 |
|
|
|
|
| 300 |
return 0
|
| 301 |
|
| 302 |
col.delete(ids=list(ids))
|
| 303 |
+
logger.info(f"Đã xóa {len(ids)} documents khỏi vector store")
|
| 304 |
return len(ids)
|
| 305 |
|
| 306 |
def get_parent_node(self, parent_id: str) -> Optional[Dict[str, Any]]:
    """Look up a stored parent node by ID (Small-to-Big); returns ``None`` when the ID is unknown."""
    store = self._parent_nodes
    return store.get(parent_id)
|
| 309 |
|
| 310 |
@property
def parent_nodes(self) -> Dict[str, Dict[str, Any]]:
    """Return the internal mapping of all parent nodes (the same object, not a copy)."""
    nodes = self._parent_nodes
    return nodes
|
evaluation/eval_utils.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import re
|
|
@@ -20,31 +22,38 @@ from core.rag.generator import RAGGenerator
|
|
| 20 |
|
| 21 |
|
| 22 |
def strip_thinking(text: str) -> str:
|
|
|
|
| 23 |
return re.sub(r'<think>.*?</think>\s*', '', text, flags=re.DOTALL).strip()
|
| 24 |
|
| 25 |
|
| 26 |
def load_csv_data(csv_path: str, sample_size: int = 0) -> tuple[list, list]:
|
|
|
|
| 27 |
questions, ground_truths = [], []
|
| 28 |
with open(csv_path, 'r', encoding='utf-8') as f:
|
| 29 |
for row in csv.DictReader(f):
|
| 30 |
if row.get('question') and row.get('ground_truth'):
|
| 31 |
questions.append(row['question'])
|
| 32 |
ground_truths.append(row['ground_truth'])
|
|
|
|
|
|
|
| 33 |
if sample_size > 0:
|
| 34 |
questions = questions[:sample_size]
|
| 35 |
ground_truths = ground_truths[:sample_size]
|
|
|
|
| 36 |
return questions, ground_truths
|
| 37 |
|
| 38 |
|
| 39 |
def init_rag() -> tuple[RAGGenerator, QwenEmbeddings, OpenAI]:
|
|
|
|
| 40 |
embeddings = QwenEmbeddings(EmbeddingConfig())
|
| 41 |
db = ChromaVectorDB(embedder=embeddings, config=ChromaConfig())
|
| 42 |
retriever = Retriever(vector_db=db)
|
| 43 |
rag = RAGGenerator(retriever=retriever)
|
| 44 |
|
|
|
|
| 45 |
api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
|
| 46 |
if not api_key:
|
| 47 |
-
raise ValueError("
|
| 48 |
|
| 49 |
llm_client = OpenAI(api_key=api_key, base_url="https://api.siliconflow.com/v1", timeout=60.0)
|
| 50 |
return rag, embeddings, llm_client
|
|
@@ -58,14 +67,18 @@ def generate_answers(
|
|
| 58 |
retrieval_mode: str = "hybrid_rerank",
|
| 59 |
max_workers: int = 8,
|
| 60 |
) -> tuple[list, list]:
|
|
|
|
| 61 |
|
| 62 |
def process(idx_q):
|
|
|
|
| 63 |
idx, q = idx_q
|
| 64 |
try:
|
|
|
|
| 65 |
prepared = rag.retrieve_and_prepare(q, mode=retrieval_mode)
|
| 66 |
if not prepared["results"]:
|
| 67 |
return idx, "Không tìm thấy thông tin.", []
|
| 68 |
|
|
|
|
| 69 |
resp = llm_client.chat.completions.create(
|
| 70 |
model=llm_model,
|
| 71 |
messages=[{"role": "user", "content": prepared["prompt"]}],
|
|
@@ -75,18 +88,20 @@ def generate_answers(
|
|
| 75 |
answer = strip_thinking(resp.choices[0].message.content or "")
|
| 76 |
return idx, answer, prepared["contexts"]
|
| 77 |
except Exception as e:
|
| 78 |
-
print(f" Q{idx+1}
|
| 79 |
return idx, "Không thể trả lời.", []
|
| 80 |
|
| 81 |
n = len(questions)
|
| 82 |
answers, contexts = [""] * n, [[] for _ in range(n)]
|
| 83 |
|
| 84 |
-
print(f"
|
|
|
|
|
|
|
| 85 |
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 86 |
futures = {executor.submit(process, (i, q)): i for i, q in enumerate(questions)}
|
| 87 |
for i, future in enumerate(as_completed(futures), 1):
|
| 88 |
idx, ans, ctx = future.result(timeout=120)
|
| 89 |
answers[idx], contexts[idx] = ans, ctx
|
| 90 |
-
print(f" [{i}/{n}]
|
| 91 |
|
| 92 |
return answers, contexts
|
|
|
|
| 1 |
+
"""Các utility functions cho evaluation."""
|
| 2 |
+
|
| 3 |
import os
|
| 4 |
import sys
|
| 5 |
import re
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
def strip_thinking(text: str) -> str:
    """Remove every ``<think>...</think>`` block from LLM output.

    Whitespace directly after each block is removed too, and the final
    result is stripped at both ends.
    """
    think_block = re.compile(r'<think>.*?</think>\s*', re.DOTALL)
    cleaned = think_block.sub('', text)
    return cleaned.strip()
|
| 27 |
|
| 28 |
|
| 29 |
def load_csv_data(csv_path: str, sample_size: int = 0) -> tuple[list, list]:
    """Read questions and ground-truth answers from a CSV file.

    Rows with an empty or missing ``question`` or ``ground_truth`` column
    are skipped.

    Args:
        csv_path: Path to a CSV file with ``question`` and ``ground_truth``
            header columns.
        sample_size: Maximum number of valid rows to return; zero or a
            negative value returns all of them.

    Returns:
        A ``(questions, ground_truths)`` pair of equally long lists.
    """
    questions, ground_truths = [], []
    # utf-8-sig drops a leading BOM (common in spreadsheet exports). With a
    # BOM left in, the first header would not match "question" and every row
    # would be skipped. newline='' is what the csv module requires so that
    # quoted fields containing line breaks are parsed correctly.
    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            question = row.get('question')
            ground_truth = row.get('ground_truth')
            if not (question and ground_truth):
                continue
            questions.append(question)
            ground_truths.append(ground_truth)
            # Stop reading once enough samples have been collected.
            if 0 < sample_size <= len(questions):
                break

    return questions, ground_truths
|
| 44 |
|
| 45 |
|
| 46 |
def init_rag() -> tuple[RAGGenerator, QwenEmbeddings, OpenAI]:
    """Build the RAG components used for evaluation.

    Returns:
        A ``(rag, embeddings, llm_client)`` triple.

    Raises:
        ValueError: If the ``SILICONFLOW_API_KEY`` environment variable is
            unset or blank.
    """
    embeddings = QwenEmbeddings(EmbeddingConfig())
    vector_db = ChromaVectorDB(embedder=embeddings, config=ChromaConfig())
    rag = RAGGenerator(retriever=Retriever(vector_db=vector_db))

    # Create the LLM client.
    api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
    if api_key == "":
        raise ValueError("Chưa đặt SILICONFLOW_API_KEY")

    llm_client = OpenAI(
        api_key=api_key,
        base_url="https://api.siliconflow.com/v1",
        timeout=60.0,
    )
    return rag, embeddings, llm_client
|
|
|
|
| 67 |
retrieval_mode: str = "hybrid_rerank",
|
| 68 |
max_workers: int = 8,
|
| 69 |
) -> tuple[list, list]:
|
| 70 |
+
"""Generate câu trả lời cho danh sách câu hỏi với parallel processing."""
|
| 71 |
|
| 72 |
def process(idx_q):
|
| 73 |
+
"""Xử lý một câu hỏi: retrieve + generate."""
|
| 74 |
idx, q = idx_q
|
| 75 |
try:
|
| 76 |
+
# Retrieve và chuẩn bị context
|
| 77 |
prepared = rag.retrieve_and_prepare(q, mode=retrieval_mode)
|
| 78 |
if not prepared["results"]:
|
| 79 |
return idx, "Không tìm thấy thông tin.", []
|
| 80 |
|
| 81 |
+
# Gọi LLM để generate answer
|
| 82 |
resp = llm_client.chat.completions.create(
|
| 83 |
model=llm_model,
|
| 84 |
messages=[{"role": "user", "content": prepared["prompt"]}],
|
|
|
|
| 88 |
answer = strip_thinking(resp.choices[0].message.content or "")
|
| 89 |
return idx, answer, prepared["contexts"]
|
| 90 |
except Exception as e:
|
| 91 |
+
print(f" Q{idx+1} Lỗi: {e}")
|
| 92 |
return idx, "Không thể trả lời.", []
|
| 93 |
|
| 94 |
n = len(questions)
|
| 95 |
answers, contexts = [""] * n, [[] for _ in range(n)]
|
| 96 |
|
| 97 |
+
print(f" Đang generate {n} câu trả lời...")
|
| 98 |
+
|
| 99 |
+
# Xử lý song song với ThreadPoolExecutor
|
| 100 |
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 101 |
futures = {executor.submit(process, (i, q)): i for i, q in enumerate(questions)}
|
| 102 |
for i, future in enumerate(as_completed(futures), 1):
|
| 103 |
idx, ans, ctx = future.result(timeout=120)
|
| 104 |
answers[idx], contexts[idx] = ans, ctx
|
| 105 |
+
print(f" [{i}/{n}] Xong")
|
| 106 |
|
| 107 |
return answers, contexts
|
evaluation/ragas_eval.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import json
|
|
@@ -21,33 +23,34 @@ from ragas.run_config import RunConfig
|
|
| 21 |
|
| 22 |
from evaluation.eval_utils import load_csv_data, init_rag, generate_answers
|
| 23 |
|
| 24 |
-
#
|
| 25 |
-
CSV_PATH = "data/data.csv"
|
| 26 |
-
OUTPUT_DIR = "evaluation/results"
|
| 27 |
-
LLM_MODEL = os.getenv("EVAL_LLM_MODEL", "nex-agi/DeepSeek-V3.1-Nex-N1")
|
| 28 |
API_BASE = "https://api.siliconflow.com/v1"
|
| 29 |
|
| 30 |
|
| 31 |
def run_evaluation(sample_size: int = 10, retrieval_mode: str = "hybrid_rerank") -> dict:
|
|
|
|
| 32 |
print(f"\n{'='*60}")
|
| 33 |
print(f"RAGAS EVALUATION - Mode: {retrieval_mode}")
|
| 34 |
print(f"{'='*60}")
|
| 35 |
|
| 36 |
-
#
|
| 37 |
rag, embeddings, llm_client = init_rag()
|
| 38 |
|
| 39 |
-
#
|
| 40 |
questions, ground_truths = load_csv_data(str(REPO_ROOT / CSV_PATH), sample_size)
|
| 41 |
-
print(f"
|
| 42 |
|
| 43 |
-
# Generate
|
| 44 |
answers, contexts = generate_answers(
|
| 45 |
rag, questions, llm_client,
|
| 46 |
llm_model=LLM_MODEL,
|
| 47 |
retrieval_mode=retrieval_mode,
|
| 48 |
)
|
| 49 |
|
| 50 |
-
#
|
| 51 |
api_key = os.getenv("SILICONFLOW_API_KEY", "")
|
| 52 |
evaluator_llm = LangchainLLMWrapper(ChatOpenAI(
|
| 53 |
model=LLM_MODEL,
|
|
@@ -59,7 +62,7 @@ def run_evaluation(sample_size: int = 10, retrieval_mode: str = "hybrid_rerank")
|
|
| 59 |
))
|
| 60 |
evaluator_embeddings = LangchainEmbeddingsWrapper(embeddings)
|
| 61 |
|
| 62 |
-
#
|
| 63 |
dataset = Dataset.from_dict({
|
| 64 |
"question": questions,
|
| 65 |
"answer": answers,
|
|
@@ -67,18 +70,18 @@ def run_evaluation(sample_size: int = 10, retrieval_mode: str = "hybrid_rerank")
|
|
| 67 |
"ground_truth": ground_truths,
|
| 68 |
})
|
| 69 |
|
| 70 |
-
#
|
| 71 |
-
print("\n
|
| 72 |
results = evaluate(
|
| 73 |
dataset=dataset,
|
| 74 |
metrics=[
|
| 75 |
-
faithfulness,
|
| 76 |
-
answer_relevancy,
|
| 77 |
-
context_precision,
|
| 78 |
-
context_recall,
|
| 79 |
-
RougeScore(rouge_type='rouge1', mode='fmeasure'),
|
| 80 |
-
RougeScore(rouge_type='rouge2', mode='fmeasure'),
|
| 81 |
-
RougeScore(rouge_type='rougeL', mode='fmeasure'),
|
| 82 |
],
|
| 83 |
llm=evaluator_llm,
|
| 84 |
embeddings=evaluator_embeddings,
|
|
@@ -86,65 +89,37 @@ def run_evaluation(sample_size: int = 10, retrieval_mode: str = "hybrid_rerank")
|
|
| 86 |
run_config=RunConfig(max_workers=8, timeout=600, max_retries=3),
|
| 87 |
)
|
| 88 |
|
| 89 |
-
#
|
| 90 |
df = results.to_pandas()
|
| 91 |
metric_cols = [c for c in df.columns if c not in ("question", "answer", "contexts", "ground_truth", "user_input", "response", "reference", "retrieved_contexts")]
|
| 92 |
|
|
|
|
| 93 |
avg_scores = {}
|
| 94 |
for col in metric_cols:
|
| 95 |
values = df[col].dropna().tolist()
|
| 96 |
if values:
|
| 97 |
avg_scores[col] = sum(values) / len(values)
|
| 98 |
|
| 99 |
-
#
|
| 100 |
out_path = REPO_ROOT / OUTPUT_DIR
|
| 101 |
out_path.mkdir(parents=True, exist_ok=True)
|
| 102 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 103 |
-
|
| 104 |
-
#
|
| 105 |
-
json_path = out_path / f"ragas_{retrieval_mode}_{timestamp}.json"
|
| 106 |
-
with open(json_path, 'w', encoding='utf-8') as f:
|
| 107 |
-
json.dump({
|
| 108 |
-
"timestamp": timestamp,
|
| 109 |
-
"retrieval_mode": retrieval_mode,
|
| 110 |
-
"sample_size": len(questions),
|
| 111 |
-
"avg_scores": avg_scores,
|
| 112 |
-
"samples": [
|
| 113 |
-
{"question": q, "answer": a, "ground_truth": gt, "contexts": ctx}
|
| 114 |
-
for q, a, gt, ctx in zip(questions, answers, ground_truths, contexts)
|
| 115 |
-
]
|
| 116 |
-
}, f, ensure_ascii=False, indent=2)
|
| 117 |
-
|
| 118 |
-
# CSV
|
| 119 |
csv_path = out_path / f"ragas_{retrieval_mode}_{timestamp}.csv"
|
| 120 |
with open(csv_path, 'w', encoding='utf-8') as f:
|
| 121 |
f.write("retrieval_mode,sample_size," + ",".join(avg_scores.keys()) + "\n")
|
| 122 |
f.write(f"{retrieval_mode},{len(questions)}," + ",".join(f"{v:.4f}" for v in avg_scores.values()) + "\n")
|
| 123 |
|
| 124 |
-
#
|
| 125 |
print(f"\n{'='*60}")
|
| 126 |
-
print(f"
|
| 127 |
print(f"{'='*60}")
|
| 128 |
for metric, score in avg_scores.items():
|
| 129 |
bar = "#" * int(score * 20) + "-" * (20 - int(score * 20))
|
| 130 |
print(f" {metric:25} [{bar}] {score:.4f}")
|
| 131 |
|
| 132 |
-
print(f"\
|
| 133 |
-
print(f"
|
| 134 |
-
|
| 135 |
-
return avg_scores
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
if __name__ == "__main__":
|
| 139 |
-
import argparse
|
| 140 |
-
parser = argparse.ArgumentParser(description="RAGAS Evaluation")
|
| 141 |
-
parser.add_argument("--samples", type=int, default=10, help="Number of samples")
|
| 142 |
-
parser.add_argument("--mode", type=str, default="hybrid_rerank",
|
| 143 |
-
choices=["vector_only", "bm25_only", "hybrid", "hybrid_rerank", "all"])
|
| 144 |
-
args = parser.parse_args()
|
| 145 |
|
| 146 |
-
|
| 147 |
-
for mode in ["vector_only", "bm25_only", "hybrid", "hybrid_rerank"]:
|
| 148 |
-
run_evaluation(args.samples, mode)
|
| 149 |
-
else:
|
| 150 |
-
run_evaluation(args.samples, args.mode)
|
|
|
|
| 1 |
+
"""Script đánh giá RAG bằng RAGAS framework."""
|
| 2 |
+
|
| 3 |
import os
|
| 4 |
import sys
|
| 5 |
import json
|
|
|
|
| 23 |
|
| 24 |
from evaluation.eval_utils import load_csv_data, init_rag, generate_answers
|
| 25 |
|
| 26 |
+
# Cấu hình
|
| 27 |
+
CSV_PATH = "data/data.csv" # File dữ liệu test
|
| 28 |
+
OUTPUT_DIR = "evaluation/results" # Thư mục output
|
| 29 |
+
LLM_MODEL = os.getenv("EVAL_LLM_MODEL", "nex-agi/DeepSeek-V3.1-Nex-N1") # Model đánh giá
|
| 30 |
API_BASE = "https://api.siliconflow.com/v1"
|
| 31 |
|
| 32 |
|
| 33 |
def run_evaluation(sample_size: int = 10, retrieval_mode: str = "hybrid_rerank") -> dict:
|
| 34 |
+
"""Chạy đánh giá RAGAS trên dữ liệu test."""
|
| 35 |
print(f"\n{'='*60}")
|
| 36 |
print(f"RAGAS EVALUATION - Mode: {retrieval_mode}")
|
| 37 |
print(f"{'='*60}")
|
| 38 |
|
| 39 |
+
# Khởi tạo RAG components
|
| 40 |
rag, embeddings, llm_client = init_rag()
|
| 41 |
|
| 42 |
+
# Tải dữ liệu test
|
| 43 |
questions, ground_truths = load_csv_data(str(REPO_ROOT / CSV_PATH), sample_size)
|
| 44 |
+
print(f" Đã tải {len(questions)} samples")
|
| 45 |
|
| 46 |
+
# Generate câu trả lời
|
| 47 |
answers, contexts = generate_answers(
|
| 48 |
rag, questions, llm_client,
|
| 49 |
llm_model=LLM_MODEL,
|
| 50 |
retrieval_mode=retrieval_mode,
|
| 51 |
)
|
| 52 |
|
| 53 |
+
# Thiết lập RAGAS evaluator
|
| 54 |
api_key = os.getenv("SILICONFLOW_API_KEY", "")
|
| 55 |
evaluator_llm = LangchainLLMWrapper(ChatOpenAI(
|
| 56 |
model=LLM_MODEL,
|
|
|
|
| 62 |
))
|
| 63 |
evaluator_embeddings = LangchainEmbeddingsWrapper(embeddings)
|
| 64 |
|
| 65 |
+
# Chuyển dữ liệu thành format Dataset
|
| 66 |
dataset = Dataset.from_dict({
|
| 67 |
"question": questions,
|
| 68 |
"answer": answers,
|
|
|
|
| 70 |
"ground_truth": ground_truths,
|
| 71 |
})
|
| 72 |
|
| 73 |
+
# Chạy đánh giá RAGAS
|
| 74 |
+
print("\n Đang chạy RAGAS metrics...")
|
| 75 |
results = evaluate(
|
| 76 |
dataset=dataset,
|
| 77 |
metrics=[
|
| 78 |
+
faithfulness, # Độ trung thực với context
|
| 79 |
+
answer_relevancy, # Độ liên quan của câu trả lời
|
| 80 |
+
context_precision, # Độ chính xác của context
|
| 81 |
+
context_recall, # Độ bao phủ của context
|
| 82 |
+
RougeScore(rouge_type='rouge1', mode='fmeasure'), # ROUGE-1
|
| 83 |
+
RougeScore(rouge_type='rouge2', mode='fmeasure'), # ROUGE-2
|
| 84 |
+
RougeScore(rouge_type='rougeL', mode='fmeasure'), # ROUGE-L
|
| 85 |
],
|
| 86 |
llm=evaluator_llm,
|
| 87 |
embeddings=evaluator_embeddings,
|
|
|
|
| 89 |
run_config=RunConfig(max_workers=8, timeout=600, max_retries=3),
|
| 90 |
)
|
| 91 |
|
| 92 |
+
# Trích xuất điểm số
|
| 93 |
df = results.to_pandas()
|
| 94 |
metric_cols = [c for c in df.columns if c not in ("question", "answer", "contexts", "ground_truth", "user_input", "response", "reference", "retrieved_contexts")]
|
| 95 |
|
| 96 |
+
# Tính điểm trung bình cho mỗi metric
|
| 97 |
avg_scores = {}
|
| 98 |
for col in metric_cols:
|
| 99 |
values = df[col].dropna().tolist()
|
| 100 |
if values:
|
| 101 |
avg_scores[col] = sum(values) / len(values)
|
| 102 |
|
| 103 |
+
# Lưu kết quả
|
| 104 |
out_path = REPO_ROOT / OUTPUT_DIR
|
| 105 |
out_path.mkdir(parents=True, exist_ok=True)
|
| 106 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 107 |
+
|
| 108 |
+
# Lưu file CSV (tóm tắt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
csv_path = out_path / f"ragas_{retrieval_mode}_{timestamp}.csv"
|
| 110 |
with open(csv_path, 'w', encoding='utf-8') as f:
|
| 111 |
f.write("retrieval_mode,sample_size," + ",".join(avg_scores.keys()) + "\n")
|
| 112 |
f.write(f"{retrieval_mode},{len(questions)}," + ",".join(f"{v:.4f}" for v in avg_scores.values()) + "\n")
|
| 113 |
|
| 114 |
+
# In kết quả
|
| 115 |
print(f"\n{'='*60}")
|
| 116 |
+
print(f"KẾT QUẢ - {retrieval_mode} ({len(questions)} samples)")
|
| 117 |
print(f"{'='*60}")
|
| 118 |
for metric, score in avg_scores.items():
|
| 119 |
bar = "#" * int(score * 20) + "-" * (20 - int(score * 20))
|
| 120 |
print(f" {metric:25} [{bar}] {score:.4f}")
|
| 121 |
|
| 122 |
+
print(f"\nĐã lưu: {json_path}")
|
| 123 |
+
print(f"Đã lưu: {csv_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
+
return avg_scores
|
|
|
|
|
|
|
|
|
|
|
|
scripts/build_data.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
|
|
|
|
|
| 1 |
import sys
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
from dotenv import find_dotenv, load_dotenv
|
| 4 |
|
| 5 |
-
# Load .env file
|
| 6 |
load_dotenv(find_dotenv(usecwd=True))
|
| 7 |
|
| 8 |
REPO_ROOT = Path(__file__).resolve().parents[1]
|
|
@@ -14,12 +16,11 @@ from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
|
|
| 14 |
from core.rag.vector_store import ChromaConfig, ChromaVectorDB
|
| 15 |
from core.hash_file.hash_file import HashProcessor
|
| 16 |
|
| 17 |
-
# Global hash processor instance
|
| 18 |
_hasher = HashProcessor(verbose=False)
|
| 19 |
|
| 20 |
|
| 21 |
def get_db_file_info(db: ChromaVectorDB) -> dict:
|
| 22 |
-
"""
|
| 23 |
docs = db.get_all_documents()
|
| 24 |
file_to_ids = {}
|
| 25 |
file_to_hash = {}
|
|
@@ -35,7 +36,7 @@ def get_db_file_info(db: ChromaVectorDB) -> dict:
|
|
| 35 |
file_to_ids[source] = set()
|
| 36 |
file_to_ids[source].add(doc_id)
|
| 37 |
|
| 38 |
-
#
|
| 39 |
if source not in file_to_hash and content_hash:
|
| 40 |
file_to_hash[source] = content_hash
|
| 41 |
|
|
@@ -43,64 +44,65 @@ def get_db_file_info(db: ChromaVectorDB) -> dict:
|
|
| 43 |
|
| 44 |
|
| 45 |
def main():
|
| 46 |
-
|
| 47 |
-
parser =
|
| 48 |
-
parser.add_argument("--
|
| 49 |
-
parser.add_argument("--no-delete", action="store_true", help="Don't delete orphaned docs")
|
| 50 |
args = parser.parse_args()
|
| 51 |
|
| 52 |
print("=" * 60)
|
| 53 |
print("BUILD HUST RAG DATABASE")
|
| 54 |
print("=" * 60)
|
| 55 |
|
| 56 |
-
|
|
|
|
| 57 |
emb_cfg = EmbeddingConfig()
|
| 58 |
emb = QwenEmbeddings(emb_cfg)
|
| 59 |
print(f" Model: {emb_cfg.model}")
|
| 60 |
print(f" API: {emb_cfg.api_base_url}")
|
| 61 |
|
| 62 |
-
|
|
|
|
| 63 |
db_cfg = ChromaConfig()
|
| 64 |
db = ChromaVectorDB(embedder=emb, config=db_cfg)
|
| 65 |
old_count = db.count()
|
| 66 |
print(f" Collection: {db_cfg.collection_name}")
|
| 67 |
-
print(f"
|
| 68 |
|
| 69 |
-
#
|
| 70 |
db_info = {"ids": {}, "hashes": {}}
|
| 71 |
if not args.force and old_count > 0:
|
| 72 |
-
print("\n
|
| 73 |
db_info = get_db_file_info(db)
|
| 74 |
-
print(f"
|
| 75 |
|
| 76 |
-
#
|
| 77 |
-
print("\n[3/5]
|
| 78 |
root = REPO_ROOT / "data" / "data_process"
|
| 79 |
md_files = sorted(root.rglob("*.md"))
|
| 80 |
-
print(f"
|
| 81 |
|
| 82 |
-
#
|
| 83 |
current_files = {f.name for f in md_files}
|
| 84 |
db_files = set(db_info["ids"].keys())
|
| 85 |
|
| 86 |
-
#
|
| 87 |
files_to_delete = db_files - current_files
|
| 88 |
|
| 89 |
-
#
|
| 90 |
deleted_count = 0
|
| 91 |
if files_to_delete and not args.no_delete:
|
| 92 |
-
print(f"\n[4/5]
|
| 93 |
for filename in files_to_delete:
|
| 94 |
doc_ids = list(db_info["ids"].get(filename, []))
|
| 95 |
if doc_ids:
|
| 96 |
db.delete_documents(doc_ids)
|
| 97 |
deleted_count += len(doc_ids)
|
| 98 |
-
print(f"
|
| 99 |
else:
|
| 100 |
-
print("\n[4/5]
|
| 101 |
|
| 102 |
-
#
|
| 103 |
-
print("\n[5/5]
|
| 104 |
total_added = 0
|
| 105 |
total_updated = 0
|
| 106 |
skipped = 0
|
|
@@ -110,16 +112,16 @@ def main():
|
|
| 110 |
db_hash = db_info["hashes"].get(f.name, "")
|
| 111 |
existing_ids = db_info["ids"].get(f.name, set())
|
| 112 |
|
| 113 |
-
#
|
| 114 |
if not args.force and db_hash == file_hash:
|
| 115 |
-
print(f" [{i}/{len(md_files)}] {f.name}:
|
| 116 |
skipped += 1
|
| 117 |
continue
|
| 118 |
|
| 119 |
-
#
|
| 120 |
if existing_ids and not args.force:
|
| 121 |
db.delete_documents(list(existing_ids))
|
| 122 |
-
print(f" [{i}/{len(md_files)}] {f.name}:
|
| 123 |
is_update = True
|
| 124 |
else:
|
| 125 |
is_update = False
|
|
@@ -127,7 +129,7 @@ def main():
|
|
| 127 |
try:
|
| 128 |
docs = chunk_markdown_file(f)
|
| 129 |
if docs:
|
| 130 |
-
#
|
| 131 |
for doc in docs:
|
| 132 |
if hasattr(doc, 'metadata'):
|
| 133 |
doc.metadata["content_hash"] = file_hash
|
|
@@ -135,29 +137,29 @@ def main():
|
|
| 135 |
doc["metadata"]["content_hash"] = file_hash
|
| 136 |
|
| 137 |
n = db.upsert_documents(docs)
|
| 138 |
-
|
| 139 |
if is_update:
|
| 140 |
total_updated += n
|
| 141 |
-
print(f" [{i}/{len(md_files)}] {f.name}: +{n}
|
| 142 |
else:
|
| 143 |
total_added += n
|
| 144 |
print(f" [{i}/{len(md_files)}] {f.name}: {n} chunks")
|
| 145 |
else:
|
| 146 |
-
print(f" [{i}/{len(md_files)}] {f.name}:
|
| 147 |
except Exception as e:
|
| 148 |
-
print(f" [{i}/{len(md_files)}] {f.name}:
|
| 149 |
|
|
|
|
| 150 |
new_count = db.count()
|
| 151 |
print(f"\n{'=' * 60}")
|
| 152 |
-
print("
|
| 153 |
print("=" * 60)
|
| 154 |
-
print(f"
|
| 155 |
-
print(f"
|
| 156 |
-
print(f"
|
| 157 |
-
print(f"
|
| 158 |
-
print(f" DB
|
| 159 |
|
| 160 |
-
print("\
|
| 161 |
|
| 162 |
|
| 163 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
"""Script build ChromaDB từ markdown files với incremental update."""
|
| 2 |
+
|
| 3 |
import sys
|
| 4 |
+
import argparse
|
| 5 |
from pathlib import Path
|
| 6 |
from dotenv import find_dotenv, load_dotenv
|
| 7 |
|
|
|
|
| 8 |
load_dotenv(find_dotenv(usecwd=True))
|
| 9 |
|
| 10 |
REPO_ROOT = Path(__file__).resolve().parents[1]
|
|
|
|
| 16 |
from core.rag.vector_store import ChromaConfig, ChromaVectorDB
|
| 17 |
from core.hash_file.hash_file import HashProcessor
|
| 18 |
|
|
|
|
| 19 |
_hasher = HashProcessor(verbose=False)
|
| 20 |
|
| 21 |
|
| 22 |
def get_db_file_info(db: ChromaVectorDB) -> dict:
|
| 23 |
+
"""Lấy thông tin files đã có trong DB (IDs và hash)."""
|
| 24 |
docs = db.get_all_documents()
|
| 25 |
file_to_ids = {}
|
| 26 |
file_to_hash = {}
|
|
|
|
| 36 |
file_to_ids[source] = set()
|
| 37 |
file_to_ids[source].add(doc_id)
|
| 38 |
|
| 39 |
+
# Lưu hash đầu tiên tìm thấy cho file
|
| 40 |
if source not in file_to_hash and content_hash:
|
| 41 |
file_to_hash[source] = content_hash
|
| 42 |
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
def main():
|
| 47 |
+
parser = argparse.ArgumentParser(description="Build ChromaDB từ markdown files")
|
| 48 |
+
parser.add_argument("--force", action="store_true", help="Build lại tất cả files")
|
| 49 |
+
parser.add_argument("--no-delete", action="store_true", help="Không xóa docs orphaned")
|
|
|
|
| 50 |
args = parser.parse_args()
|
| 51 |
|
| 52 |
print("=" * 60)
|
| 53 |
print("BUILD HUST RAG DATABASE")
|
| 54 |
print("=" * 60)
|
| 55 |
|
| 56 |
+
# Bước 1: Khởi tạo embedder
|
| 57 |
+
print("\n[1/5] Khởi tạo embedder...")
|
| 58 |
emb_cfg = EmbeddingConfig()
|
| 59 |
emb = QwenEmbeddings(emb_cfg)
|
| 60 |
print(f" Model: {emb_cfg.model}")
|
| 61 |
print(f" API: {emb_cfg.api_base_url}")
|
| 62 |
|
| 63 |
+
# Bước 2: Khởi tạo ChromaDB
|
| 64 |
+
print("\n[2/5] Khởi tạo ChromaDB...")
|
| 65 |
db_cfg = ChromaConfig()
|
| 66 |
db = ChromaVectorDB(embedder=emb, config=db_cfg)
|
| 67 |
old_count = db.count()
|
| 68 |
print(f" Collection: {db_cfg.collection_name}")
|
| 69 |
+
print(f" Số docs hiện tại: {old_count}")
|
| 70 |
|
| 71 |
+
# Lấy trạng thái hiện tại của DB
|
| 72 |
db_info = {"ids": {}, "hashes": {}}
|
| 73 |
if not args.force and old_count > 0:
|
| 74 |
+
print("\n Đang quét documents trong DB...")
|
| 75 |
db_info = get_db_file_info(db)
|
| 76 |
+
print(f" Tìm thấy {len(db_info['ids'])} source files trong DB")
|
| 77 |
|
| 78 |
+
# Bước 3: Quét markdown files
|
| 79 |
+
print("\n[3/5] Quét markdown files...")
|
| 80 |
root = REPO_ROOT / "data" / "data_process"
|
| 81 |
md_files = sorted(root.rglob("*.md"))
|
| 82 |
+
print(f" Tìm thấy {len(md_files)} markdown files trên disk")
|
| 83 |
|
| 84 |
+
# So sánh files trên disk vs trong DB
|
| 85 |
current_files = {f.name for f in md_files}
|
| 86 |
db_files = set(db_info["ids"].keys())
|
| 87 |
|
| 88 |
+
# Tìm files cần xóa (có trong DB nhưng không có trên disk)
|
| 89 |
files_to_delete = db_files - current_files
|
| 90 |
|
| 91 |
+
# Bước 4: Xóa docs orphaned
|
| 92 |
deleted_count = 0
|
| 93 |
if files_to_delete and not args.no_delete:
|
| 94 |
+
print(f"\n[4/5] Dọn dẹp {len(files_to_delete)} files đã xóa...")
|
| 95 |
for filename in files_to_delete:
|
| 96 |
doc_ids = list(db_info["ids"].get(filename, []))
|
| 97 |
if doc_ids:
|
| 98 |
db.delete_documents(doc_ids)
|
| 99 |
deleted_count += len(doc_ids)
|
| 100 |
+
print(f" Đã xóa: {filename} ({len(doc_ids)} chunks)")
|
| 101 |
else:
|
| 102 |
+
print("\n[4/5] Không có files cần xóa")
|
| 103 |
|
| 104 |
+
# Bước 5: Xử lý markdown files (thêm mới, cập nhật)
|
| 105 |
+
print("\n[5/5] Xử lý markdown files...")
|
| 106 |
total_added = 0
|
| 107 |
total_updated = 0
|
| 108 |
skipped = 0
|
|
|
|
| 112 |
db_hash = db_info["hashes"].get(f.name, "")
|
| 113 |
existing_ids = db_info["ids"].get(f.name, set())
|
| 114 |
|
| 115 |
+
# Bỏ qua nếu hash khớp (file không thay đổi)
|
| 116 |
if not args.force and db_hash == file_hash:
|
| 117 |
+
print(f" [{i}/{len(md_files)}] {f.name}: BỎ QUA (không đổi)")
|
| 118 |
skipped += 1
|
| 119 |
continue
|
| 120 |
|
| 121 |
+
# Nếu file thay đổi, xóa chunks cũ trước
|
| 122 |
if existing_ids and not args.force:
|
| 123 |
db.delete_documents(list(existing_ids))
|
| 124 |
+
print(f" [{i}/{len(md_files)}] {f.name}: CẬP NHẬT (xóa {len(existing_ids)} chunks cũ)")
|
| 125 |
is_update = True
|
| 126 |
else:
|
| 127 |
is_update = False
|
|
|
|
| 129 |
try:
|
| 130 |
docs = chunk_markdown_file(f)
|
| 131 |
if docs:
|
| 132 |
+
# Thêm hash vào metadata để phát hiện thay đổi lần sau
|
| 133 |
for doc in docs:
|
| 134 |
if hasattr(doc, 'metadata'):
|
| 135 |
doc.metadata["content_hash"] = file_hash
|
|
|
|
| 137 |
doc["metadata"]["content_hash"] = file_hash
|
| 138 |
|
| 139 |
n = db.upsert_documents(docs)
|
|
|
|
| 140 |
if is_update:
|
| 141 |
total_updated += n
|
| 142 |
+
print(f" [{i}/{len(md_files)}] {f.name}: +{n} chunks mới")
|
| 143 |
else:
|
| 144 |
total_added += n
|
| 145 |
print(f" [{i}/{len(md_files)}] {f.name}: {n} chunks")
|
| 146 |
else:
|
| 147 |
+
print(f" [{i}/{len(md_files)}] {f.name}: BỎ QUA (không có chunks)")
|
| 148 |
except Exception as e:
|
| 149 |
+
print(f" [{i}/{len(md_files)}] {f.name}: LỖI - {e}")
|
| 150 |
|
| 151 |
+
# Tổng kết
|
| 152 |
new_count = db.count()
|
| 153 |
print(f"\n{'=' * 60}")
|
| 154 |
+
print("TỔNG KẾT")
|
| 155 |
print("=" * 60)
|
| 156 |
+
print(f" Đã xóa (orphaned): {deleted_count} chunks")
|
| 157 |
+
print(f" Đã cập nhật: {total_updated} chunks")
|
| 158 |
+
print(f" Đã thêm mới: {total_added} chunks")
|
| 159 |
+
print(f" Đã bỏ qua: {skipped} files")
|
| 160 |
+
print(f" Số docs trong DB: {old_count} -> {new_count} ({new_count - old_count:+d})")
|
| 161 |
|
| 162 |
+
print("\nHOÀN TẤT!")
|
| 163 |
|
| 164 |
|
| 165 |
if __name__ == "__main__":
|
scripts/run_eval.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import sys
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
|
| 4 |
REPO_ROOT = Path(__file__).resolve().parents[1]
|
|
@@ -7,18 +8,19 @@ if str(REPO_ROOT) not in sys.path:
|
|
| 7 |
|
| 8 |
|
| 9 |
def main():
|
| 10 |
-
|
| 11 |
-
parser =
|
| 12 |
-
parser.add_argument("--samples", type=int, default=10, help="Number of samples (0 = all)")
|
| 13 |
parser.add_argument("--mode", type=str, default="hybrid_rerank",
|
| 14 |
-
choices=["vector_only", "bm25_only", "hybrid", "hybrid_rerank", "all"]
|
|
|
|
| 15 |
args = parser.parse_args()
|
| 16 |
|
| 17 |
from evaluation.ragas_eval import run_evaluation
|
| 18 |
|
| 19 |
if args.mode == "all":
|
|
|
|
| 20 |
print("\n" + "=" * 60)
|
| 21 |
-
print("
|
| 22 |
print("=" * 60)
|
| 23 |
for mode in ["vector_only", "bm25_only", "hybrid", "hybrid_rerank"]:
|
| 24 |
run_evaluation(args.samples, mode)
|
|
|
|
| 1 |
import sys
|
| 2 |
+
import argparse
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
REPO_ROOT = Path(__file__).resolve().parents[1]
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
def main():
|
| 11 |
+
parser = argparse.ArgumentParser(description="Đánh giá RAG bằng RAGAS")
|
| 12 |
+
parser.add_argument("--samples", type=int, default=10, help="Số lượng samples (0 = tất cả)")
|
|
|
|
| 13 |
parser.add_argument("--mode", type=str, default="hybrid_rerank",
|
| 14 |
+
choices=["vector_only", "bm25_only", "hybrid", "hybrid_rerank", "all"],
|
| 15 |
+
help="Chế độ retrieval")
|
| 16 |
args = parser.parse_args()
|
| 17 |
|
| 18 |
from evaluation.ragas_eval import run_evaluation
|
| 19 |
|
| 20 |
if args.mode == "all":
|
| 21 |
+
# Chạy tất cả các chế độ retrieval
|
| 22 |
print("\n" + "=" * 60)
|
| 23 |
+
print("CHẠY TẤT CẢ CÁC CHẾ ĐỘ RETRIEVAL")
|
| 24 |
print("=" * 60)
|
| 25 |
for mode in ["vector_only", "bm25_only", "hybrid", "hybrid_rerank"]:
|
| 26 |
run_evaluation(args.samples, mode)
|
test/parse_data_hash_test.py
DELETED
|
@@ -1,102 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import random
|
| 4 |
-
import shutil
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
|
| 7 |
-
# Ensure project root is on sys.path
|
| 8 |
-
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 9 |
-
if _PROJECT_ROOT not in sys.path:
|
| 10 |
-
sys.path.insert(0, _PROJECT_ROOT)
|
| 11 |
-
|
| 12 |
-
from core.preprocessing.docling_processor import DoclingProcessor
|
| 13 |
-
|
| 14 |
-
def get_random_local_pdf(source_dir: str):
|
| 15 |
-
if not os.path.exists(source_dir):
|
| 16 |
-
return None
|
| 17 |
-
|
| 18 |
-
files = [f for f in os.listdir(source_dir) if f.lower().endswith('.pdf')]
|
| 19 |
-
if not files:
|
| 20 |
-
return None
|
| 21 |
-
|
| 22 |
-
return os.path.join(source_dir, random.choice(files))
|
| 23 |
-
|
| 24 |
-
def main(output_dir=None, use_ocr=False):
|
| 25 |
-
# Setup paths
|
| 26 |
-
source_dir = os.path.join(_PROJECT_ROOT, "data", "files")
|
| 27 |
-
if output_dir is None:
|
| 28 |
-
output_dir = os.path.join(_PROJECT_ROOT, "data", "test_output")
|
| 29 |
-
|
| 30 |
-
# Clean up old test output
|
| 31 |
-
if os.path.exists(output_dir):
|
| 32 |
-
shutil.rmtree(output_dir)
|
| 33 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 34 |
-
|
| 35 |
-
print(f"Đang tìm file PDF để test...")
|
| 36 |
-
|
| 37 |
-
# 1. Thử lấy từ local data/files
|
| 38 |
-
file_path = get_random_local_pdf(source_dir)
|
| 39 |
-
|
| 40 |
-
if not file_path:
|
| 41 |
-
print(f"Không tìm thấy file PDF nào trong {source_dir}")
|
| 42 |
-
print("Hãy chạy 'python core/hash_file/hash_data_goc.py' để tải dữ liệu trước.")
|
| 43 |
-
return 1
|
| 44 |
-
|
| 45 |
-
filename = os.path.basename(file_path)
|
| 46 |
-
print(f"Đã chọn file test: {filename}")
|
| 47 |
-
print(f"Đường dẫn: {file_path}")
|
| 48 |
-
|
| 49 |
-
try:
|
| 50 |
-
# Khởi tạo processor
|
| 51 |
-
print("Khởi tạo DoclingProcessor...")
|
| 52 |
-
processor = DoclingProcessor(
|
| 53 |
-
output_dir=output_dir,
|
| 54 |
-
use_ocr=use_ocr,
|
| 55 |
-
timeout=None
|
| 56 |
-
)
|
| 57 |
-
|
| 58 |
-
# Parse file
|
| 59 |
-
print(f"Bắt đầu parse...")
|
| 60 |
-
result = processor.parse_document(file_path)
|
| 61 |
-
|
| 62 |
-
if result:
|
| 63 |
-
print(f"Test thành công!")
|
| 64 |
-
|
| 65 |
-
# Kiểm tra kết quả
|
| 66 |
-
output_files = os.listdir(output_dir)
|
| 67 |
-
md_files = [f for f in output_files if f.endswith('.md')]
|
| 68 |
-
|
| 69 |
-
if md_files:
|
| 70 |
-
print(f"File output: {md_files[0]}")
|
| 71 |
-
print(f"Thư mục output: {output_dir}")
|
| 72 |
-
|
| 73 |
-
# In thống kê sơ bộ cho Markdown
|
| 74 |
-
content_len = len(result)
|
| 75 |
-
preview = result[:200].replace('\n', ' ') + "..."
|
| 76 |
-
print(f" Kích thước: {content_len} ký tự")
|
| 77 |
-
print(f" Preview: {preview}")
|
| 78 |
-
else:
|
| 79 |
-
print(" Không tìm thấy file Markdown output dù hàm trả về kết quả.")
|
| 80 |
-
else:
|
| 81 |
-
print("Test thất bại: Hàm parse trả về None")
|
| 82 |
-
return 1
|
| 83 |
-
|
| 84 |
-
return 0
|
| 85 |
-
|
| 86 |
-
except Exception as e:
|
| 87 |
-
print(f"Lỗi ngoại lệ: {e}")
|
| 88 |
-
import traceback
|
| 89 |
-
traceback.print_exc()
|
| 90 |
-
return 1
|
| 91 |
-
|
| 92 |
-
if __name__ == "__main__":
|
| 93 |
-
import argparse
|
| 94 |
-
parser = argparse.ArgumentParser(description="Test Docling với 1 file PDF ngẫu nhiên từ data/files")
|
| 95 |
-
parser.add_argument("--output", help="Thư mục output cho test (mặc định: data/test_output)")
|
| 96 |
-
parser.add_argument("--ocr", action="store_true", help="Bật OCR")
|
| 97 |
-
args = parser.parse_args()
|
| 98 |
-
|
| 99 |
-
sys.exit(main(
|
| 100 |
-
output_dir=args.output,
|
| 101 |
-
use_ocr=args.ocr
|
| 102 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test/test_chunk.py
CHANGED
|
@@ -1,47 +1,57 @@
|
|
|
|
|
|
|
|
| 1 |
import sys
|
| 2 |
sys.path.insert(0, "/home/bahung/DoAn")
|
| 3 |
|
|
|
|
|
|
|
|
|
|
| 4 |
from core.rag.chunk import chunk_markdown_file
|
| 5 |
|
|
|
|
| 6 |
test_file = "data/data_process/chuong_trinh_dao_tao/1.1. Kỹ thuật Cơ điện tử.md"
|
| 7 |
|
| 8 |
print("=" * 70)
|
| 9 |
print(f" File: {test_file}")
|
| 10 |
print("=" * 70)
|
| 11 |
|
| 12 |
-
#
|
| 13 |
nodes = chunk_markdown_file(test_file)
|
| 14 |
|
| 15 |
-
print(f"\n
|
| 16 |
|
|
|
|
| 17 |
for i, node in enumerate(nodes):
|
| 18 |
content = node.get_content()
|
| 19 |
metadata = node.metadata
|
| 20 |
|
| 21 |
print(f"\n{'─' * 70}")
|
| 22 |
print(f" NODE #{i}")
|
| 23 |
-
print(f"
|
| 24 |
-
print(f"
|
| 25 |
if metadata:
|
| 26 |
print(f" Metadata: {metadata}")
|
| 27 |
print(f"{'─' * 70}")
|
|
|
|
|
|
|
| 28 |
content_preview = content[:200]
|
| 29 |
if len(content) > 200:
|
| 30 |
content_preview += "..."
|
| 31 |
print(content_preview)
|
| 32 |
|
|
|
|
| 33 |
with open("test_chunk.md", "w", encoding="utf-8") as f:
|
| 34 |
for i, node in enumerate(nodes):
|
| 35 |
content = node.get_content()
|
| 36 |
metadata = node.metadata
|
| 37 |
|
| 38 |
f.write(f"# NODE {i}\n")
|
| 39 |
-
f.write(f"**
|
| 40 |
f.write("**Metadata:**\n")
|
| 41 |
for key, value in metadata.items():
|
| 42 |
f.write(f"- {key}: {value}\n")
|
| 43 |
-
f.write("\n**
|
| 44 |
f.write(content)
|
| 45 |
f.write("\n\n---\n\n")
|
| 46 |
|
| 47 |
-
print("\n
|
|
|
|
"""Chunk one sample markdown file and dump the resulting nodes for inspection.

Run directly with ``python test/test_chunk.py``. The script prints a short
summary of every node to stdout and writes the full node contents to
``test_chunk.md`` in the current working directory.
"""

import sys
from pathlib import Path

# Derive the repository root from this file's location (test/ -> repo root)
# instead of hard-coding a machine-specific absolute path, so the script runs
# on any checkout.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from dotenv import load_dotenv

load_dotenv()  # Load environment variables from .env before importing project code

from core.rag.chunk import chunk_markdown_file

# Sample file to chunk, expressed relative to the repository root.
TEST_FILE = "data/data_process/chuong_trinh_dao_tao/1.1. Kỹ thuật Cơ điện tử.md"
# Markdown dump of all nodes; written to the current working directory.
OUTPUT_FILE = "test_chunk.md"
# Maximum number of characters of node content echoed to stdout.
PREVIEW_CHARS = 200


def main() -> None:
    """Chunk TEST_FILE, print a per-node summary, and save every node to OUTPUT_FILE."""
    print("=" * 70)
    print(f" File: {TEST_FILE}")
    print("=" * 70)

    # Resolve against the repo root so the script does not depend on the
    # directory it is launched from.
    nodes = chunk_markdown_file(REPO_ROOT / TEST_FILE)

    print(f"\n Tổng số nodes: {len(nodes)}\n")

    # Show a short summary of each node.
    for i, node in enumerate(nodes):
        content = node.get_content()
        metadata = node.metadata

        print(f"\n{'─' * 70}")
        print(f" NODE #{i}")
        print(f" Loại: {type(node).__name__}")
        print(f" Độ dài: {len(content)} ký tự")
        if metadata:
            print(f" Metadata: {metadata}")
        print(f"{'─' * 70}")

        # Preview only the first PREVIEW_CHARS characters of the content.
        content_preview = content[:PREVIEW_CHARS]
        if len(content) > PREVIEW_CHARS:
            content_preview += "..."
        print(content_preview)

    # Save the full result to a markdown file for easier reading.
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        for i, node in enumerate(nodes):
            content = node.get_content()
            metadata = node.metadata

            f.write(f"# NODE {i}\n")
            f.write(f"**Loại:** {type(node).__name__}\n\n")
            f.write("**Metadata:**\n")
            for key, value in metadata.items():
                f.write(f"- {key}: {value}\n")
            f.write("\n**Nội dung:**\n")
            f.write(content)
            f.write("\n\n---\n\n")

    print("\n Hoàn tất!")


# Guard the entry point so importing this module (e.g. during pytest
# collection of test_*.py files) does not trigger chunking or file writes.
if __name__ == "__main__":
    main()
|