hungnha commited on
Commit
b91b0a5
·
1 Parent(s): c429a2d

change commit

Browse files
core/gradio/{gradio_rag_qwen.py → gradio_rag.py} RENAMED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  from __future__ import annotations
2
  import os
3
  import sys
@@ -8,6 +13,7 @@ import gradio as gr
8
  from dotenv import find_dotenv, load_dotenv
9
  from openai import OpenAI
10
 
 
11
  REPO_ROOT = Path(__file__).resolve().parents[2]
12
  if str(REPO_ROOT) not in sys.path:
13
  sys.path.insert(0, str(REPO_ROOT))
@@ -15,14 +21,18 @@ if str(REPO_ROOT) not in sys.path:
15
 
16
  @dataclass
17
  class GradioConfig:
 
18
  server_host: str = "127.0.0.1"
19
  server_port: int = 7860
20
 
 
21
  def _load_env() -> None:
 
22
  dotenv_path = find_dotenv(usecwd=True) or ""
23
  load_dotenv(dotenv_path=dotenv_path or None, override=False)
24
 
25
 
 
26
  from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
27
  from core.rag.vector_store import ChromaConfig, ChromaVectorDB
28
  from core.rag.retrival import Retriever, RetrievalMode, get_retrieval_config
@@ -30,19 +40,20 @@ from core.rag.generator import RAGContextBuilder, build_context, build_prompt, S
30
 
31
  _load_env()
32
 
33
- RETRIEVAL_MODE = RetrievalMode.HYBRID_RERANK # Test with debug logs
 
 
 
 
34
 
35
- # LLM Config
36
- LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b")
37
- LLM_API_BASE = "https://api.groq.com/openai/v1"
38
- LLM_API_KEY_ENV = "GROQ_API_KEY"
39
-
40
- # Load retrieval config
41
  GRADIO_CFG = GradioConfig()
42
  RETRIEVAL_CFG = get_retrieval_config()
43
 
44
 
45
  class AppState:
 
 
46
  def __init__(self) -> None:
47
  self.db: Optional[ChromaVectorDB] = None
48
  self.retriever: Optional[Retriever] = None
@@ -50,39 +61,38 @@ class AppState:
50
  self.client: Optional[OpenAI] = None
51
 
52
 
53
- STATE = AppState()
54
 
55
 
56
  def _init_resources() -> None:
 
57
  if STATE.db is not None:
58
  return
59
 
60
  print(f" Đang khởi tạo Database & Re-ranker...")
61
  print(f" Retrieval Mode: {RETRIEVAL_MODE.value}")
62
 
 
63
  emb = QwenEmbeddings(EmbeddingConfig())
64
-
65
  db_cfg = ChromaConfig()
66
 
67
- STATE.db = ChromaVectorDB(
68
- embedder=emb,
69
- config=db_cfg,
70
- )
71
  STATE.retriever = Retriever(vector_db=STATE.db)
72
 
73
- # LLM Client
74
  api_key = (os.getenv(LLM_API_KEY_ENV) or "").strip()
75
  if not api_key:
76
  raise RuntimeError(f"Missing {LLM_API_KEY_ENV}")
77
  STATE.client = OpenAI(api_key=api_key, base_url=LLM_API_BASE)
78
 
79
- # RAGContextBuilder - chỉ retrieve
80
  STATE.rag_builder = RAGContextBuilder(retriever=STATE.retriever)
81
 
82
  print(" Đã sẵn sàng!")
83
 
84
 
85
  def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
 
86
  _init_resources()
87
 
88
  assert STATE.db is not None
@@ -90,7 +100,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
90
  assert STATE.retriever is not None
91
  assert STATE.rag_builder is not None
92
 
93
- # Bước 1: Retrieve và prepare context
94
  prepared = STATE.rag_builder.retrieve_and_prepare(
95
  message,
96
  k=RETRIEVAL_CFG.top_k,
@@ -103,7 +113,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
103
  yield "Xin lỗi, tôi không tìm thấy thông tin phù hợp trong dữ liệu."
104
  return
105
 
106
- # Bước 2: Gọi LLM streaming để generate answer
107
  completion = STATE.client.chat.completions.create(
108
  model=LLM_MODEL,
109
  messages=[{"role": "user", "content": prepared["prompt"]}],
@@ -112,6 +122,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
112
  stream=True,
113
  )
114
 
 
115
  acc = ""
116
  for chunk in completion:
117
  delta = getattr(chunk.choices[0].delta, "content", "") or ""
@@ -119,7 +130,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
119
  acc += delta
120
  yield acc
121
 
122
- # Debug info with mode indicator
123
  debug_info = f"\n\n---\n\n**Retrieved (Top {len(results)} | Mode: {RETRIEVAL_MODE.value})**\n\n"
124
  for i, r in enumerate(results, 1):
125
  md = r.get("metadata", {})
@@ -127,7 +138,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
127
  rerank_score = r.get("rerank_score")
128
  distance = r.get("distance")
129
 
130
- # Extract metadata
131
  source = md.get("source_file", "N/A")
132
  doc_type = md.get("document_type", "N/A")
133
  header = md.get("header_path", "")
@@ -135,7 +146,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
135
  program = md.get("program_name", "")
136
  issued_year = md.get("issued_year", "")
137
 
138
- # Show relevant scores based on mode
139
  score_info = ""
140
  if rerank_score is not None:
141
  score_info += f"Rerank: `{rerank_score:.4f}` "
@@ -144,7 +155,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
144
  if not score_info:
145
  score_info = f"Rank: `{r.get('final_rank', i)}`"
146
 
147
- # Build metadata line
148
  meta_parts = [f"**Nguồn:** {source}", f"**Loại:** {doc_type}"]
149
  if issued_year:
150
  meta_parts.append(f"**Năm:** {issued_year}")
@@ -162,9 +173,7 @@ def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
162
  yield acc + debug_info
163
 
164
 
165
-
166
-
167
- # Create Gradio interface
168
  demo = gr.ChatInterface(
169
  fn=rag_chat,
170
  title=f"HUST RAG Assistant",
 
1
+ """
2
+ Giao diện Gradio cho hệ thống RAG - Trợ lý học vụ HUST.
3
+ Cho phép người dùng đặt câu hỏi và nhận câu trả lời từ hệ thống RAG.
4
+ """
5
+
6
  from __future__ import annotations
7
  import os
8
  import sys
 
13
  from dotenv import find_dotenv, load_dotenv
14
  from openai import OpenAI
15
 
16
+ # Thêm thư mục gốc vào Python path
17
  REPO_ROOT = Path(__file__).resolve().parents[2]
18
  if str(REPO_ROOT) not in sys.path:
19
  sys.path.insert(0, str(REPO_ROOT))
 
21
 
22
  @dataclass
23
  class GradioConfig:
24
+ """Cấu hình Gradio server: host và port."""
25
  server_host: str = "127.0.0.1"
26
  server_port: int = 7860
27
 
28
+
29
  def _load_env() -> None:
30
+ """Tải biến môi trường từ file .env."""
31
  dotenv_path = find_dotenv(usecwd=True) or ""
32
  load_dotenv(dotenv_path=dotenv_path or None, override=False)
33
 
34
 
35
+ # Import các module RAG
36
  from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
37
  from core.rag.vector_store import ChromaConfig, ChromaVectorDB
38
  from core.rag.retrival import Retriever, RetrievalMode, get_retrieval_config
 
40
 
41
  _load_env()
42
 
43
+ # Cấu hình retrieval LLM
44
+ RETRIEVAL_MODE = RetrievalMode.HYBRID_RERANK # Chế độ tìm kiếm
45
+ LLM_MODEL = os.getenv("LLM_MODEL", "qwen/qwen3-32b") # Model LLM
46
+ LLM_API_BASE = "https://api.groq.com/openai/v1" # Groq API endpoint
47
+ LLM_API_KEY_ENV = "GROQ_API_KEY" # Biến môi trường chứa API key
48
 
49
+ # Khởi tạo cấu hình
 
 
 
 
 
50
  GRADIO_CFG = GradioConfig()
51
  RETRIEVAL_CFG = get_retrieval_config()
52
 
53
 
54
  class AppState:
55
+ """Quản lý trạng thái ứng dụng: database, retriever, LLM client."""
56
+
57
  def __init__(self) -> None:
58
  self.db: Optional[ChromaVectorDB] = None
59
  self.retriever: Optional[Retriever] = None
 
61
  self.client: Optional[OpenAI] = None
62
 
63
 
64
+ STATE = AppState() # Singleton state
65
 
66
 
67
  def _init_resources() -> None:
68
+ """Khởi tạo các tài nguyên: DB, Retriever, LLM client (lazy init)."""
69
  if STATE.db is not None:
70
  return
71
 
72
  print(f" Đang khởi tạo Database & Re-ranker...")
73
  print(f" Retrieval Mode: {RETRIEVAL_MODE.value}")
74
 
75
+ # Khởi tạo embedding và database
76
  emb = QwenEmbeddings(EmbeddingConfig())
 
77
  db_cfg = ChromaConfig()
78
 
79
+ STATE.db = ChromaVectorDB(embedder=emb, config=db_cfg)
 
 
 
80
  STATE.retriever = Retriever(vector_db=STATE.db)
81
 
82
+ # Khởi tạo LLM client
83
  api_key = (os.getenv(LLM_API_KEY_ENV) or "").strip()
84
  if not api_key:
85
  raise RuntimeError(f"Missing {LLM_API_KEY_ENV}")
86
  STATE.client = OpenAI(api_key=api_key, base_url=LLM_API_BASE)
87
 
88
+ # Khởi tạo RAG builder
89
  STATE.rag_builder = RAGContextBuilder(retriever=STATE.retriever)
90
 
91
  print(" Đã sẵn sàng!")
92
 
93
 
94
  def rag_chat(message: str, history: List[Dict[str, str]] | None = None):
95
+ """Xử lý chat: retrieve documents -> gọi LLM -> stream response"""
96
  _init_resources()
97
 
98
  assert STATE.db is not None
 
100
  assert STATE.retriever is not None
101
  assert STATE.rag_builder is not None
102
 
103
+ # Retrieve và chuẩn bị context
104
  prepared = STATE.rag_builder.retrieve_and_prepare(
105
  message,
106
  k=RETRIEVAL_CFG.top_k,
 
113
  yield "Xin lỗi, tôi không tìm thấy thông tin phù hợp trong dữ liệu."
114
  return
115
 
116
+ # Gọi LLM với streaming
117
  completion = STATE.client.chat.completions.create(
118
  model=LLM_MODEL,
119
  messages=[{"role": "user", "content": prepared["prompt"]}],
 
122
  stream=True,
123
  )
124
 
125
+ # Stream response
126
  acc = ""
127
  for chunk in completion:
128
  delta = getattr(chunk.choices[0].delta, "content", "") or ""
 
130
  acc += delta
131
  yield acc
132
 
133
+ # Thêm debug info về các documents đã retrieve
134
  debug_info = f"\n\n---\n\n**Retrieved (Top {len(results)} | Mode: {RETRIEVAL_MODE.value})**\n\n"
135
  for i, r in enumerate(results, 1):
136
  md = r.get("metadata", {})
 
138
  rerank_score = r.get("rerank_score")
139
  distance = r.get("distance")
140
 
141
+ # Trích xuất metadata
142
  source = md.get("source_file", "N/A")
143
  doc_type = md.get("document_type", "N/A")
144
  header = md.get("header_path", "")
 
146
  program = md.get("program_name", "")
147
  issued_year = md.get("issued_year", "")
148
 
149
+ # Format score
150
  score_info = ""
151
  if rerank_score is not None:
152
  score_info += f"Rerank: `{rerank_score:.4f}` "
 
155
  if not score_info:
156
  score_info = f"Rank: `{r.get('final_rank', i)}`"
157
 
158
+ # Format metadata
159
  meta_parts = [f"**Nguồn:** {source}", f"**Loại:** {doc_type}"]
160
  if issued_year:
161
  meta_parts.append(f"**Năm:** {issued_year}")
 
173
  yield acc + debug_info
174
 
175
 
176
+ # Tạo giao diện Gradio
 
 
177
  demo = gr.ChatInterface(
178
  fn=rag_chat,
179
  title=f"HUST RAG Assistant",
core/hash_file/hash_data_goc.py CHANGED
@@ -1,13 +1,11 @@
1
  import sys
2
- import os
3
  import json
4
  import shutil
5
  from pathlib import Path
6
 
7
- current_file = Path(__file__).resolve()
8
- project_root = current_file.parent.parent.parent
9
- if str(project_root) not in sys.path:
10
- sys.path.insert(0, str(project_root))
11
 
12
  from core.hash_file.hash_file import HashProcessor
13
 
@@ -16,130 +14,113 @@ HF_RAW_PDF_REPO = "hungnha/Do_An_Dataset"
16
 
17
 
18
  def download_from_hf(cache_dir: Path) -> Path:
19
- try:
20
- from huggingface_hub import snapshot_download
21
- except ImportError:
22
- print("Installing huggingface_hub...")
23
- os.system("pip install huggingface_hub")
24
- from huggingface_hub import snapshot_download
25
 
 
26
  if cache_dir.exists() and any(cache_dir.iterdir()):
27
  print(f"Cache đã tồn tại: {cache_dir}")
28
  return cache_dir / "data_rag"
29
 
30
- print(f"Đang tải PDF từ HuggingFace: {HF_RAW_PDF_REPO}")
31
  snapshot_download(
32
  repo_id=HF_RAW_PDF_REPO,
33
  repo_type="dataset",
34
  local_dir=str(cache_dir),
35
  local_dir_use_symlinks=False,
36
  )
37
- print("Tải xong!")
38
  return cache_dir / "data_rag"
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def main():
42
  import argparse
43
- parser = argparse.ArgumentParser()
44
- parser.add_argument("--source", type=str, help="Đường dẫn local tới PDF (bỏ qua để tải từ HF)")
45
  parser.add_argument("--download-only", action="store_true", help="Chỉ tải về, không copy")
46
  args = parser.parse_args()
47
 
48
- data_dir = project_root / "data"
49
  files_dir = data_dir / "files"
50
  files_dir.mkdir(parents=True, exist_ok=True)
 
51
 
52
- # Xác định nguồn PDF
53
  if args.source:
54
  source_root = Path(args.source)
55
  if not source_root.exists():
56
- print(f"Thư mục nguồn không tồn tại: {source_root}")
57
- return
58
  else:
59
  # Tải từ HuggingFace
60
- cache_dir = data_dir / "raw_pdf_cache"
61
- source_root = download_from_hf(cache_dir)
62
-
63
  if args.download_only:
64
- print(f"PDF đã cache tại: {source_root}")
65
- return
66
 
67
  if not source_root.exists():
68
- print(f"Không tìm thấy thư mục PDF: {source_root}")
69
- return
70
-
71
- hash_processor = HashProcessor(verbose=False)
72
- hash_file_path = data_dir / "hash_data_goc_index.json"
73
-
74
- existing_hashes = {}
75
- if hash_file_path.exists():
76
- with open(hash_file_path, 'r', encoding='utf-8') as f:
77
- data = json.load(f)
78
- existing_hashes = {item['filename']: item['hash'] for item in data.get('train', [])}
79
- print(f"Đã tải {len(existing_hashes)} hash từ index cũ")
80
 
81
- print(f"Đang quét file từ: {source_root}")
 
 
82
 
83
- pdf_files = list(source_root.rglob("*.pdf"))
84
- print(f"Tìm thấy {len(pdf_files)} files PDF\n")
85
 
86
- hash_results = []
87
- skipped = 0
88
- processed = 0
 
 
89
 
90
- for idx, source_path in enumerate(pdf_files):
91
- relative_path = source_path.relative_to(source_root)
92
- filename = str(relative_path)
93
- dest_path = files_dir / relative_path
94
- dest_path.parent.mkdir(parents=True, exist_ok=True)
95
-
96
- # Kiểm tra file đã tồn tại và hash khớp chưa
97
- if dest_path.exists() and filename in existing_hashes:
98
- current_hash = hash_processor.get_file_hash(str(dest_path))
99
- if current_hash == existing_hashes[filename]:
100
- hash_results.append({
101
- 'filename': filename,
102
- 'hash': current_hash,
103
- 'index': idx
104
- })
105
- skipped += 1
106
- continue
107
-
108
- try:
109
- shutil.copy2(source_path, dest_path)
110
-
111
- file_hash = hash_processor.get_file_hash(str(dest_path))
112
- if file_hash is None:
113
- print(f"Lỗi tính hash cho file {filename}")
114
- continue
115
-
116
- hash_results.append({
117
- 'filename': filename,
118
- 'hash': file_hash,
119
- 'index': idx
120
- })
121
- processed += 1
122
-
123
- if (idx + 1) % 10 == 0:
124
- print(f"Processed {idx + 1}/{len(pdf_files)} files")
125
-
126
- except Exception as e:
127
- print(f"Lỗi khi xử lý file {filename}: {e}")
128
- continue
129
-
130
- output_data = {
131
- 'train': hash_results,
132
- 'total_files': len(hash_results)
133
- }
134
-
135
- with open(hash_file_path, 'w', encoding='utf-8') as f:
136
- json.dump(output_data, f, ensure_ascii=False, indent=2)
137
-
138
- print(f"\nHoàn tất!")
139
- print(f"Tổng số file: {len(hash_results)}")
140
- print(f"Đã xử lý mới: {processed}")
141
- print(f"Đã bỏ qua (trùng hash): {skipped}")
142
- print(f"File index: {hash_file_path}")
143
 
144
 
145
  if __name__ == "__main__":
 
1
  import sys
 
2
  import json
3
  import shutil
4
  from pathlib import Path
5
 
6
+ PROJECT_ROOT = Path(__file__).resolve().parents[2]
7
+ if str(PROJECT_ROOT) not in sys.path:
8
+ sys.path.insert(0, str(PROJECT_ROOT))
 
9
 
10
  from core.hash_file.hash_file import HashProcessor
11
 
 
14
 
15
 
16
  def download_from_hf(cache_dir: Path) -> Path:
17
+ """Tải PDF từ HuggingFace, trả về đường dẫn tới folder data_rag."""
18
+ from huggingface_hub import snapshot_download
 
 
 
 
19
 
20
+ # Kiểm tra cache đã tồn tại chưa
21
  if cache_dir.exists() and any(cache_dir.iterdir()):
22
  print(f"Cache đã tồn tại: {cache_dir}")
23
  return cache_dir / "data_rag"
24
 
25
+ print(f"Đang tải từ HuggingFace: {HF_RAW_PDF_REPO}")
26
  snapshot_download(
27
  repo_id=HF_RAW_PDF_REPO,
28
  repo_type="dataset",
29
  local_dir=str(cache_dir),
30
  local_dir_use_symlinks=False,
31
  )
 
32
  return cache_dir / "data_rag"
33
 
34
 
35
+ def load_existing_hashes(path: Path) -> dict:
36
+ """Đọc hash index cũ từ file JSON."""
37
+ if not path.exists():
38
+ return {}
39
+ try:
40
+ data = json.loads(path.read_text(encoding='utf-8'))
41
+ return {item['filename']: item['hash'] for item in data.get('train', [])}
42
+ except Exception:
43
+ return {}
44
+
45
+
46
+ def process_pdfs(source_root: Path, dest_dir: Path, existing_hashes: dict) -> tuple:
47
+ """Copy PDFs và tính hash. Trả về (results, processed, skipped)."""
48
+ hasher = HashProcessor(verbose=False)
49
+ pdf_files = list(source_root.rglob("*.pdf"))
50
+ print(f"Tìm thấy {len(pdf_files)} file PDF\n")
51
+
52
+ results, processed, skipped = [], 0, 0
53
+
54
+ for idx, src in enumerate(pdf_files):
55
+ rel_path = str(src.relative_to(source_root))
56
+ dest = dest_dir / rel_path
57
+ dest.parent.mkdir(parents=True, exist_ok=True)
58
+
59
+ # Bỏ qua nếu file không thay đổi (hash khớp)
60
+ if dest.exists() and rel_path in existing_hashes:
61
+ current_hash = hasher.get_file_hash(str(dest))
62
+ if current_hash == existing_hashes[rel_path]:
63
+ results.append({'filename': rel_path, 'hash': current_hash, 'index': idx})
64
+ skipped += 1
65
+ continue
66
+
67
+ # Copy và tính hash
68
+ try:
69
+ shutil.copy2(src, dest)
70
+ file_hash = hasher.get_file_hash(str(dest))
71
+ if file_hash:
72
+ results.append({'filename': rel_path, 'hash': file_hash, 'index': idx})
73
+ processed += 1
74
+ except Exception as e:
75
+ print(f"Lỗi: {rel_path} - {e}")
76
+
77
+ # Hiển thị tiến độ
78
+ if (idx + 1) % 20 == 0:
79
+ print(f"Tiến độ: {idx + 1}/{len(pdf_files)}")
80
+
81
+ return results, processed, skipped
82
+
83
+
84
  def main():
85
  import argparse
86
+ parser = argparse.ArgumentParser(description="Tải PDF và tạo hash index")
87
+ parser.add_argument("--source", type=str, help="Đường dẫn local tới PDFs (bỏ qua tải HF)")
88
  parser.add_argument("--download-only", action="store_true", help="Chỉ tải về, không copy")
89
  args = parser.parse_args()
90
 
91
+ data_dir = PROJECT_ROOT / "data"
92
  files_dir = data_dir / "files"
93
  files_dir.mkdir(parents=True, exist_ok=True)
94
+ hash_file = data_dir / "hash_data_goc_index.json"
95
 
96
+ # Xác định thư mục nguồn
97
  if args.source:
98
  source_root = Path(args.source)
99
  if not source_root.exists():
100
+ return print(f"Không tìm thấy thư mục nguồn: {source_root}")
 
101
  else:
102
  # Tải từ HuggingFace
103
+ source_root = download_from_hf(data_dir / "raw_pdf_cache")
 
 
104
  if args.download_only:
105
+ return print(f"PDF đã cache tại: {source_root}")
 
106
 
107
  if not source_root.exists():
108
+ return print(f"Không tìm thấy thư mục PDF: {source_root}")
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ # Xử
111
+ existing = load_existing_hashes(hash_file)
112
+ print(f"Đã tải {len(existing)} hash từ index cũ")
113
 
114
+ results, processed, skipped = process_pdfs(source_root, files_dir, existing)
 
115
 
116
+ # Lưu kết quả
117
+ hash_file.write_text(json.dumps({
118
+ 'train': results,
119
+ 'total_files': len(results)
120
+ }, ensure_ascii=False, indent=2), encoding='utf-8')
121
 
122
+ print(f"\nHoàn tất! Tổng: {len(results)} | Mới: {processed} | Bỏ qua: {skipped}")
123
+ print(f"File index: {hash_file}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
 
126
  if __name__ == "__main__":
core/hash_file/hash_file.py CHANGED
@@ -9,19 +9,23 @@ from pathlib import Path
9
  from typing import Dict, List, Optional
10
  from datetime import datetime
11
 
12
- # Constants
13
- CHUNK_SIZE = 8192 # 8KB chunks for reading files
14
  DEFAULT_FILE_EXTENSION = '.pdf'
15
 
 
16
  class HashProcessor:
 
17
 
18
  def __init__(self, verbose: bool = True):
 
19
  self.verbose = verbose
20
  self.logger = logging.getLogger(__name__)
21
  if not verbose:
22
  self.logger.setLevel(logging.WARNING)
23
 
24
  def get_file_hash(self, path: str) -> Optional[str]:
 
25
  h = hashlib.sha256()
26
  try:
27
  with open(path, "rb") as f:
@@ -41,6 +45,7 @@ class HashProcessor:
41
  file_extension: str = DEFAULT_FILE_EXTENSION,
42
  recursive: bool = False
43
  ) -> Dict[str, List[Dict[str, str]]]:
 
44
  source_path = Path(source_dir)
45
  if not source_path.exists():
46
  raise FileNotFoundError(f"Thư mục không tồn tại: {source_dir}")
@@ -73,6 +78,7 @@ class HashProcessor:
73
  return hash_to_files
74
 
75
  def load_processed_index(self, index_file: str) -> Dict:
 
76
  if os.path.exists(index_file):
77
  try:
78
  with open(index_file, "r", encoding="utf-8") as f:
@@ -86,6 +92,10 @@ class HashProcessor:
86
  return {}
87
 
88
  def save_processed_index(self, index_file: str, processed_hashes: Dict) -> None:
 
 
 
 
89
  temp_name = None
90
  try:
91
  os.makedirs(os.path.dirname(index_file), exist_ok=True)
@@ -106,8 +116,9 @@ class HashProcessor:
106
  os.remove(temp_name)
107
 
108
  def get_current_timestamp(self) -> str:
 
109
  return datetime.now().isoformat()
110
 
111
  def get_string_hash(self, text: str) -> str:
 
112
  return hashlib.sha256(text.encode('utf-8')).hexdigest()
113
-
 
9
  from typing import Dict, List, Optional
10
  from datetime import datetime
11
 
12
+ # Hằng số
13
+ CHUNK_SIZE = 8192 # Đọc file theo chunk 8KB
14
  DEFAULT_FILE_EXTENSION = '.pdf'
15
 
16
+
17
  class HashProcessor:
18
+ """Lớp xử lý hash cho files - dùng để phát hiện thay đổi và tránh xử lý lại."""
19
 
20
  def __init__(self, verbose: bool = True):
21
+ """Khởi tạo HashProcessor."""
22
  self.verbose = verbose
23
  self.logger = logging.getLogger(__name__)
24
  if not verbose:
25
  self.logger.setLevel(logging.WARNING)
26
 
27
  def get_file_hash(self, path: str) -> Optional[str]:
28
+ """Tính SHA256 hash của một file."""
29
  h = hashlib.sha256()
30
  try:
31
  with open(path, "rb") as f:
 
45
  file_extension: str = DEFAULT_FILE_EXTENSION,
46
  recursive: bool = False
47
  ) -> Dict[str, List[Dict[str, str]]]:
48
+ """Quét thư mục và tính hash cho tất cả files."""
49
  source_path = Path(source_dir)
50
  if not source_path.exists():
51
  raise FileNotFoundError(f"Thư mục không tồn tại: {source_dir}")
 
78
  return hash_to_files
79
 
80
  def load_processed_index(self, index_file: str) -> Dict:
81
+ """Đọc file index đã xử lý từ JSON."""
82
  if os.path.exists(index_file):
83
  try:
84
  with open(index_file, "r", encoding="utf-8") as f:
 
92
  return {}
93
 
94
  def save_processed_index(self, index_file: str, processed_hashes: Dict) -> None:
95
+ """Lưu index đã xử lý vào file JSON (atomic write).
96
+
97
+ Ghi vào file tạm trước, sau đó rename để đảm bảo an toàn.
98
+ """
99
  temp_name = None
100
  try:
101
  os.makedirs(os.path.dirname(index_file), exist_ok=True)
 
116
  os.remove(temp_name)
117
 
118
  def get_current_timestamp(self) -> str:
119
+ """Lấy timestamp hiện tại theo định dạng ISO."""
120
  return datetime.now().isoformat()
121
 
122
  def get_string_hash(self, text: str) -> str:
123
+ """Tính SHA256 hash của một chuỗi text."""
124
  return hashlib.sha256(text.encode('utf-8')).hexdigest()
 
core/preprocessing/docling_processor.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import re
3
  import gc
 
4
  import signal
5
  import logging
6
  from datetime import datetime
@@ -12,19 +13,35 @@ from docling.datamodel.pipeline_options import PdfPipelineOptions, TableStructur
12
  from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
13
  from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
14
 
 
 
 
 
 
 
 
15
 
16
  class DoclingProcessor:
 
 
17
  def __init__(self, output_dir: str, use_ocr: bool = True, timeout: int = 300, images_scale: float = 3.0):
 
18
  self.output_dir = output_dir
19
  self.timeout = timeout
20
  self.logger = logging.getLogger(__name__)
 
21
  os.makedirs(output_dir, exist_ok=True)
22
 
23
- # Pipeline options
 
 
 
 
24
  opts = PdfPipelineOptions(do_ocr=use_ocr, do_table_structure=True)
25
  opts.table_structure_options = TableStructureOptions(do_cell_matching=True, mode=TableFormerMode.ACCURATE)
26
  opts.images_scale = images_scale
27
 
 
28
  if use_ocr:
29
  ocr = EasyOcrOptions()
30
  ocr.lang = ["vi"]
@@ -34,39 +51,69 @@ class DoclingProcessor:
34
  self.converter = DocumentConverter(format_options={
35
  InputFormat.PDF: FormatOption(backend=PyPdfiumDocumentBackend, pipeline_cls=StandardPdfPipeline, pipeline_options=opts)
36
  })
37
- self.logger.info(f"🔧 Docling | OCR={use_ocr} | Table=accurate | Scale={images_scale} | timeout={timeout}s")
38
 
39
  def clean_markdown(self, text: str) -> str:
 
40
  text = re.sub(r'\n\s*Trang\s+\d+\s*\n', '\n', text)
41
  return re.sub(r'\n{3,}', '\n\n', text).strip()
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def parse_document(self, file_path: str) -> str | None:
 
44
  if not os.path.exists(file_path):
45
  return None
46
  filename = os.path.basename(file_path)
47
  try:
 
48
  signal.signal(signal.SIGALRM, lambda s, f: (_ for _ in ()).throw(TimeoutError()))
49
  signal.alarm(self.timeout)
 
50
  result = self.converter.convert(file_path)
51
  md = result.document.export_to_markdown(image_placeholder="")
52
  signal.alarm(0)
 
53
  md = self.clean_markdown(md)
 
54
  return f"---\nfilename: {filename}\nfilepath: {file_path}\npage_count: {len(result.document.pages)}\nprocessed_at: {datetime.now().isoformat()}\n---\n\n{md}"
55
  except TimeoutError:
56
- self.logger.warning(f" Timeout: {filename}")
57
  signal.alarm(0)
58
  return None
59
  except Exception as e:
60
- self.logger.error(f" Failed: {filename}: {e}")
61
  signal.alarm(0)
62
  return None
63
 
64
  def parse_directory(self, source_dir: str) -> dict:
 
65
  source_path = Path(source_dir)
66
  pdf_files = list(source_path.rglob("*.pdf"))
67
- self.logger.info(f" Found {len(pdf_files)} PDFs in {source_dir}")
68
 
69
  results = {"total": len(pdf_files), "parsed": 0, "skipped": 0, "errors": 0}
 
70
  for i, fp in enumerate(pdf_files):
71
  try:
72
  rel = fp.relative_to(source_path)
@@ -75,20 +122,33 @@ class DoclingProcessor:
75
  out = Path(self.output_dir) / rel.with_suffix(".md")
76
  out.parent.mkdir(parents=True, exist_ok=True)
77
 
78
- if out.exists():
 
 
 
79
  results["skipped"] += 1
80
  continue
81
 
82
- md = self.parse_document(str(fp))
 
 
 
83
  if md:
84
  out.write_text(md, encoding="utf-8")
85
  results["parsed"] += 1
 
 
 
86
  else:
87
  results["errors"] += 1
88
 
 
89
  if (i + 1) % 10 == 0:
90
  gc.collect()
91
- self.logger.info(f" {i+1}/{len(pdf_files)} (skip: {results['skipped']})")
 
 
 
92
 
93
- self.logger.info(f" Done: {results['parsed']} parsed, {results['skipped']} skipped, {results['errors']} errors")
94
  return results
 
1
  import os
2
  import re
3
  import gc
4
+ import sys
5
  import signal
6
  import logging
7
  from datetime import datetime
 
13
  from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
14
  from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
15
 
16
+ # Thêm project root vào path để import HashProcessor
17
+ PROJECT_ROOT = Path(__file__).resolve().parents[2]
18
+ if str(PROJECT_ROOT) not in sys.path:
19
+ sys.path.insert(0, str(PROJECT_ROOT))
20
+
21
+ from core.hash_file.hash_file import HashProcessor
22
+
23
 
24
  class DoclingProcessor:
25
+ """Chuyển đổi PDF sang Markdown bằng Docling."""
26
+
27
  def __init__(self, output_dir: str, use_ocr: bool = True, timeout: int = 300, images_scale: float = 3.0):
28
+ """Khởi tạo processor với cấu hình OCR và table extraction."""
29
  self.output_dir = output_dir
30
  self.timeout = timeout
31
  self.logger = logging.getLogger(__name__)
32
+ self.hasher = HashProcessor(verbose=False)
33
  os.makedirs(output_dir, exist_ok=True)
34
 
35
+ # File lưu hash index
36
+ self.hash_index_path = Path(output_dir) / "docling_hash_index.json"
37
+ self.hash_index = self.hasher.load_processed_index(str(self.hash_index_path))
38
+
39
+ # Cấu hình pipeline PDF
40
  opts = PdfPipelineOptions(do_ocr=use_ocr, do_table_structure=True)
41
  opts.table_structure_options = TableStructureOptions(do_cell_matching=True, mode=TableFormerMode.ACCURATE)
42
  opts.images_scale = images_scale
43
 
44
+ # Cấu hình OCR tiếng Việt
45
  if use_ocr:
46
  ocr = EasyOcrOptions()
47
  ocr.lang = ["vi"]
 
51
  self.converter = DocumentConverter(format_options={
52
  InputFormat.PDF: FormatOption(backend=PyPdfiumDocumentBackend, pipeline_cls=StandardPdfPipeline, pipeline_options=opts)
53
  })
54
+ self.logger.info(f"Docling | OCR={use_ocr} | Table=accurate | Scale={images_scale} | timeout={timeout}s")
55
 
56
  def clean_markdown(self, text: str) -> str:
57
+ """Xóa số trang và khoảng trắng thừa."""
58
  text = re.sub(r'\n\s*Trang\s+\d+\s*\n', '\n', text)
59
  return re.sub(r'\n{3,}', '\n\n', text).strip()
60
 
61
+ def _should_process(self, pdf_path: str, output_path: Path) -> bool:
62
+ """Kiểm tra xem file PDF có cần xử lý lại không (dựa trên hash)."""
63
+ # Nếu output chưa tồn tại -> cần xử lý
64
+ if not output_path.exists():
65
+ return True
66
+
67
+ # Tính hash file PDF hiện tại
68
+ current_hash = self.hasher.get_file_hash(pdf_path)
69
+ if not current_hash:
70
+ return True
71
+
72
+ # So sánh với hash đã lưu
73
+ saved_hash = self.hash_index.get(pdf_path, {}).get("hash")
74
+ return current_hash != saved_hash
75
+
76
+ def _save_hash(self, pdf_path: str, file_hash: str) -> None:
77
+ """Lưu hash của file đã xử lý vào index."""
78
+ self.hash_index[pdf_path] = {
79
+ "hash": file_hash,
80
+ "processed_at": self.hasher.get_current_timestamp()
81
+ }
82
+
83
  def parse_document(self, file_path: str) -> str | None:
84
+ """Chuyển đổi 1 file PDF sang Markdown với timeout."""
85
  if not os.path.exists(file_path):
86
  return None
87
  filename = os.path.basename(file_path)
88
  try:
89
+ # Đặt timeout để tránh treo
90
  signal.signal(signal.SIGALRM, lambda s, f: (_ for _ in ()).throw(TimeoutError()))
91
  signal.alarm(self.timeout)
92
+
93
  result = self.converter.convert(file_path)
94
  md = result.document.export_to_markdown(image_placeholder="")
95
  signal.alarm(0)
96
+
97
  md = self.clean_markdown(md)
98
+ # Thêm frontmatter metadata
99
  return f"---\nfilename: {filename}\nfilepath: {file_path}\npage_count: {len(result.document.pages)}\nprocessed_at: {datetime.now().isoformat()}\n---\n\n{md}"
100
  except TimeoutError:
101
+ self.logger.warning(f"Timeout: {filename}")
102
  signal.alarm(0)
103
  return None
104
  except Exception as e:
105
+ self.logger.error(f"Lỗi: {filename}: {e}")
106
  signal.alarm(0)
107
  return None
108
 
109
  def parse_directory(self, source_dir: str) -> dict:
110
+ """Xử lý toàn bộ thư mục PDF, bỏ qua file không thay đổi (dựa trên hash)."""
111
  source_path = Path(source_dir)
112
  pdf_files = list(source_path.rglob("*.pdf"))
113
+ self.logger.info(f"Tìm thấy {len(pdf_files)} file PDF trong {source_dir}")
114
 
115
  results = {"total": len(pdf_files), "parsed": 0, "skipped": 0, "errors": 0}
116
+
117
  for i, fp in enumerate(pdf_files):
118
  try:
119
  rel = fp.relative_to(source_path)
 
122
  out = Path(self.output_dir) / rel.with_suffix(".md")
123
  out.parent.mkdir(parents=True, exist_ok=True)
124
 
125
+ pdf_path = str(fp)
126
+
127
+ # Kiểm tra hash để quyết định có cần xử lý không
128
+ if not self._should_process(pdf_path, out):
129
  results["skipped"] += 1
130
  continue
131
 
132
+ # Tính hash trước khi xử lý
133
+ file_hash = self.hasher.get_file_hash(pdf_path)
134
+
135
+ md = self.parse_document(pdf_path)
136
  if md:
137
  out.write_text(md, encoding="utf-8")
138
  results["parsed"] += 1
139
+ # Lưu hash sau khi xử lý thành công
140
+ if file_hash:
141
+ self._save_hash(pdf_path, file_hash)
142
  else:
143
  results["errors"] += 1
144
 
145
+ # Dọn memory mỗi 10 files
146
  if (i + 1) % 10 == 0:
147
  gc.collect()
148
+ self.logger.info(f"{i+1}/{len(pdf_files)} (bỏ qua: {results['skipped']})")
149
+
150
+ # Lưu hash index sau khi xử lý xong
151
+ self.hasher.save_processed_index(str(self.hash_index_path), self.hash_index)
152
 
153
+ self.logger.info(f"Xong: {results['parsed']} đã xử lý, {results['skipped']} bỏ qua, {results['errors']} lỗi")
154
  return results
core/preprocessing/pdf_parser.py CHANGED
@@ -1,19 +1,22 @@
1
  from docling_processor import DoclingProcessor
2
 
3
- PDF_FILE = "data/data_raw/quyet_dinh/quy-dinh-chuan-ngoai-ngu-2021.pdf"
4
- SOURCE_DIR = "data/data_raw"
5
- OUTPUT_DIR = "data"
6
- USE_OCR = False
 
7
 
8
 
9
  if __name__ == "__main__":
10
  processor = DoclingProcessor(OUTPUT_DIR, use_ocr=USE_OCR)
11
 
12
  if PDF_FILE:
13
- print(f"Parsing: {PDF_FILE}")
 
14
  result = processor.parse_document(PDF_FILE)
15
- print(f"Done: {result}" if result else "Skipped/failed")
16
  else:
17
- print(f"Parsing: {SOURCE_DIR}")
 
18
  r = processor.parse_directory(SOURCE_DIR)
19
- print(f"Total: {r['total']} | OK: {r['parsed']} | Skip: {r['skipped']} | Err: {r['errors']}")
 
1
  from docling_processor import DoclingProcessor
2
 
3
+ # Cấu hình đường dẫn
4
+ PDF_FILE = "" # File đơn lẻ (để trống nếu muốn parse cả thư mục)
5
+ SOURCE_DIR = "data/data_raw" # Thư mục chứa PDFs
6
+ OUTPUT_DIR = "data" # Thư mục xuất Markdown
7
+ USE_OCR = False # Bật OCR cho PDF scan
8
 
9
 
10
  if __name__ == "__main__":
11
  processor = DoclingProcessor(OUTPUT_DIR, use_ocr=USE_OCR)
12
 
13
  if PDF_FILE:
14
+ # Parse 1 file đơn lẻ
15
+ print(f"Đang xử lý: {PDF_FILE}")
16
  result = processor.parse_document(PDF_FILE)
17
+ print("Xong!" if result else "Lỗi hoặc bỏ qua")
18
  else:
19
+ # Parse cả thư mục
20
+ print(f"Đang xử lý thư mục: {SOURCE_DIR}")
21
  r = processor.parse_directory(SOURCE_DIR)
22
+ print(f"Tổng: {r['total']} | Thành công: {r['parsed']} | Bỏ qua: {r['skipped']} | Lỗi: {r['errors']}")
core/rag/chunk.py CHANGED
@@ -10,37 +10,41 @@ from llama_index.core import Document
10
  from llama_index.core.node_parser import MarkdownNodeParser, SentenceSplitter
11
  from llama_index.core.schema import BaseNode, TextNode
12
 
13
- # Config
14
  CHUNK_SIZE = 1500
15
  CHUNK_OVERLAP = 150
16
  MIN_CHUNK_SIZE = 200
17
  TABLE_ROWS_PER_CHUNK = 15
18
 
19
- # Small-to-Big Config
20
  ENABLE_TABLE_SUMMARY = True
21
- MIN_TABLE_ROWS_FOR_SUMMARY = 0 # Summarize ALL tables regardless of size
22
- SUMMARY_MODEL = "nex-agi/DeepSeek-V3.1-Nex-N1"
23
- SILICONFLOW_BASE_URL = "https://api.siliconflow.com/v1"
24
 
25
- # Regex
26
  COURSE_PATTERN = re.compile(r"Học\s*phần\s+(.+?)\s*\(\s*m[ãa]\s+([^\)]+)\)", re.I | re.DOTALL)
27
  TABLE_PLACEHOLDER = re.compile(r"__TBL_(\d+)__")
28
  HEADER_KEYWORDS = {'TT', 'STT', 'MÃ', 'TÊN', 'KHỐI', 'SỐ', 'ID', 'NO', '#'}
29
  FRONTMATTER_PATTERN = re.compile(r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL)
30
- # Pattern để trích xuất số bảng và tiêu đề (ví dụ: "Bảng 3.1 Danh mục các học phần...")
31
  TABLE_TITLE_PATTERN = re.compile(r"(?:^|\n)#+\s*(?:Bảng|BẢNG)\s*(\d+(?:\.\d+)?)\s*[.:]*\s*(.+?)(?:\n|$)", re.IGNORECASE)
32
 
33
 
34
  def _is_table_row(line: str) -> bool:
 
35
  s = line.strip()
36
  return s.startswith("|") and s.endswith("|") and s.count("|") >= 2
37
 
 
38
  def _is_separator(line: str) -> bool:
 
39
  if not _is_table_row(line):
40
  return False
41
  return not line.strip().replace("|", "").replace("-", "").replace(":", "").replace(" ", "")
42
 
 
43
  def _is_header(line: str) -> bool:
 
44
  if not _is_table_row(line):
45
  return False
46
  cells = [c.strip() for c in line.split("|") if c.strip()]
@@ -50,6 +54,7 @@ def _is_header(line: str) -> bool:
50
 
51
 
52
  def _extract_tables(text: str) -> Tuple[List[Tuple[str, List[str]]], str]:
 
53
  lines, tables, last_header, i = text.split("\n"), [], None, 0
54
 
55
  while i < len(lines) - 1:
@@ -73,7 +78,7 @@ def _extract_tables(text: str) -> Tuple[List[Tuple[str, List[str]]], str]:
73
  else:
74
  i += 1
75
 
76
- # Replace tables with placeholders
77
  result, tbl_idx, i = [], 0, 0
78
  while i < len(lines):
79
  if tbl_idx < len(tables) and i < len(lines) - 1 and _is_table_row(lines[i]) and _is_separator(lines[i + 1]):
@@ -90,6 +95,7 @@ def _extract_tables(text: str) -> Tuple[List[Tuple[str, List[str]]], str]:
90
 
91
 
92
  def _split_table(header: str, rows: List[str], max_rows: int = TABLE_ROWS_PER_CHUNK) -> List[str]:
 
93
  if len(rows) <= max_rows:
94
  return [header + "\n".join(rows)]
95
 
@@ -98,26 +104,29 @@ def _split_table(header: str, rows: List[str], max_rows: int = TABLE_ROWS_PER_CH
98
  chunk_rows = rows[i:i + max_rows]
99
  chunks.append(chunk_rows)
100
 
101
- # Merge last chunk if too small (< 5 rows)
102
  if len(chunks) > 1 and len(chunks[-1]) < 5:
103
  chunks[-2].extend(chunks[-1])
104
  chunks.pop()
105
 
106
  return [header + "\n".join(r) for r in chunks]
107
 
 
108
  _summary_client: Optional[OpenAI] = None
109
 
 
110
  def _get_summary_client() -> Optional[OpenAI]:
 
111
  global _summary_client
112
  if _summary_client is not None:
113
  return _summary_client
114
 
115
- api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
116
  if not api_key:
117
- print("SILICONFLOW_API_KEY not set. Table summarization disabled.")
118
  return None
119
 
120
- _summary_client = OpenAI(api_key=api_key, base_url=SILICONFLOW_BASE_URL)
121
  return _summary_client
122
 
123
 
@@ -130,17 +139,17 @@ def _summarize_table(
130
  max_retries: int = 5,
131
  base_delay: float = 2.0
132
  ) -> str:
133
- """Summarize a table with retry logic. Raises exception if all retries fail."""
134
  import time
135
 
136
  if not ENABLE_TABLE_SUMMARY:
137
- raise RuntimeError("Table summarization is disabled but required. Set ENABLE_TABLE_SUMMARY = True")
138
 
139
  client = _get_summary_client()
140
  if client is None:
141
- raise RuntimeError("SILICONFLOW_API_KEY not set. Cannot summarize tables.")
142
 
143
- # Build table identifier string
144
  table_id_parts = []
145
  if table_number:
146
  table_id_parts.append(f"Bảng {table_number}")
@@ -149,7 +158,7 @@ def _summarize_table(
149
  if source_file:
150
  table_id_parts.append(f"từ file {source_file}")
151
 
152
- table_identifier = " - ".join(table_id_parts) if table_id_parts else "Unknown table"
153
 
154
  prompt = f"""Tóm tắt ngắn gọn nội dung bảng sau bằng tiếng Việt.
155
 
@@ -179,20 +188,17 @@ Bảng:
179
  if summary.strip():
180
  return summary.strip()
181
  else:
182
- raise ValueError("Empty summary returned from API")
183
 
184
  except Exception as e:
185
  last_error = e
186
- delay = base_delay * (2 ** attempt) # Exponential backoff: 2, 4, 8, 16, 32 seconds
187
- print(f"⚠️ Retry {attempt + 1}/{max_retries} for {table_identifier}: {e}")
188
- print(f" Waiting {delay:.1f}s before retry...")
189
  time.sleep(delay)
190
 
191
- # All retries failed
192
- raise RuntimeError(f"Failed to summarize {table_identifier} after {max_retries} retries. Last error: {last_error}")
193
-
194
-
195
-
196
 
197
 
198
  def _create_table_nodes(
@@ -203,11 +209,11 @@ def _create_table_nodes(
203
  table_title: str = "",
204
  source_file: str = ""
205
  ) -> List[TextNode]:
206
- """Create table nodes. For large tables, creates parent+summary nodes with retry until success."""
207
- # Count rows to decide if we should summarize
208
  row_count = table_text.count("\n")
209
 
210
- # Add table info to metadata
211
  table_meta = {**metadata}
212
  if table_number:
213
  table_meta["table_number"] = table_number
@@ -215,10 +221,15 @@ def _create_table_nodes(
215
  table_meta["table_title"] = table_title
216
 
217
  if row_count < MIN_TABLE_ROWS_FOR_SUMMARY:
218
- # Table too small, just return as-is (no summary needed)
 
 
 
 
 
219
  return [TextNode(text=table_text, metadata={**table_meta, "is_table": True})]
220
 
221
- # Generate summary with retry logic (will raise exception if all retries fail)
222
  summary = _summarize_table(
223
  table_text,
224
  context_hint,
@@ -227,37 +238,36 @@ def _create_table_nodes(
227
  source_file=source_file
228
  )
229
 
230
- # Create parent node (raw table - will NOT be embedded)
231
  parent_id = str(uuid.uuid4())
232
  parent_node = TextNode(
233
  text=table_text,
234
  metadata={
235
  **table_meta,
236
  "is_table": True,
237
- "is_parent": True, # Flag to skip embedding
238
  "node_id": parent_id,
239
  }
240
  )
241
  parent_node.id_ = parent_id
242
 
243
- # Create summary node (will be embedded for search)
244
  summary_node = TextNode(
245
  text=summary,
246
  metadata={
247
  **table_meta,
248
  "is_table_summary": True,
249
- "parent_id": parent_id, # Link to parent
250
  }
251
  )
252
 
253
- table_id = f"Bảng {table_number}" if table_number else "table"
254
- print(f" Created summary for {table_id} ({row_count} rows)")
255
  return [parent_node, summary_node]
256
 
257
 
258
-
259
-
260
  def _enrich_metadata(node: BaseNode, source_path: Path | None) -> None:
 
261
  if source_path:
262
  node.metadata.update({"source_path": str(source_path), "source_file": source_path.name})
263
  if "Học phần" in (text := node.get_content()) and (m := COURSE_PATTERN.search(text)):
@@ -265,6 +275,7 @@ def _enrich_metadata(node: BaseNode, source_path: Path | None) -> None:
265
 
266
 
267
  def _chunk_text(text: str, metadata: dict) -> List[BaseNode]:
 
268
  if len(text) <= CHUNK_SIZE:
269
  return [TextNode(text=text, metadata=metadata.copy())]
270
  return SentenceSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP).get_nodes_from_documents(
@@ -273,6 +284,7 @@ def _chunk_text(text: str, metadata: dict) -> List[BaseNode]:
273
 
274
 
275
  def _extract_frontmatter(text: str) -> Tuple[Dict[str, Any], str]:
 
276
  match = FRONTMATTER_PATTERN.match(text)
277
  if not match:
278
  return {}, text
@@ -286,22 +298,23 @@ def _extract_frontmatter(text: str) -> Tuple[Dict[str, Any], str]:
286
 
287
 
288
  def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[BaseNode]:
 
289
  if not text or not text.strip():
290
  return []
291
 
292
  path = Path(source_path) if source_path else None
293
 
294
- # Extract YAML frontmatter as metadata (không chunk)
295
  frontmatter_meta, text = _extract_frontmatter(text)
296
 
297
  tables, text_with_placeholders = _extract_tables(text)
298
 
299
- # Base metadata from frontmatter + source path
300
  base_meta = {**frontmatter_meta}
301
  if path:
302
  base_meta.update({"source_path": str(path), "source_file": path.name})
303
 
304
- # Parse by headings
305
  doc = Document(text=text_with_placeholders, metadata=base_meta.copy())
306
  heading_nodes = MarkdownNodeParser().get_nodes_from_documents([doc])
307
 
@@ -316,14 +329,13 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
316
 
317
  last_end = 0
318
  for match in matches:
319
- # Text before table
320
  before_text = content[last_end:match.start()].strip()
321
 
322
- # Extract table number and title from text before table
323
  table_number = ""
324
  table_title = ""
325
  if before_text:
326
- # Look for patterns like "## Bảng 3.1 Danh mục các học phần..."
327
  title_match = TABLE_TITLE_PATTERN.search(before_text)
328
  if title_match:
329
  table_number = title_match.group(1).strip()
@@ -332,15 +344,15 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
332
  if before_text and len(before_text) >= MIN_CHUNK_SIZE:
333
  nodes.extend(_chunk_text(before_text, meta) if len(before_text) > CHUNK_SIZE else [TextNode(text=before_text, metadata=meta.copy())])
334
 
335
- # Table chunks - using Small-to-Big pattern
336
  if (idx := int(match.group(1))) < len(tables):
337
  header, rows = tables[idx]
338
  table_chunks = _split_table(header, rows)
339
 
340
- # Get context hint from header path
341
  context_hint = meta.get("Header 1", "") or meta.get("section", "")
342
 
343
- # Get source file for summary
344
  source_file = meta.get("source_file", "") or (path.name if path else "")
345
 
346
  for i, chunk in enumerate(table_chunks):
@@ -348,7 +360,7 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
348
  if len(table_chunks) > 1:
349
  chunk_meta["table_part"] = f"{i+1}/{len(table_chunks)}"
350
 
351
- # Create parent + summary nodes if applicable
352
  table_nodes = _create_table_nodes(
353
  chunk,
354
  chunk_meta,
@@ -361,11 +373,11 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
361
 
362
  last_end = match.end()
363
 
364
- # Text after table
365
  if (after := content[last_end:].strip()) and len(after) >= MIN_CHUNK_SIZE:
366
  nodes.extend(_chunk_text(after, meta) if len(after) > CHUNK_SIZE else [TextNode(text=after, metadata=meta.copy())])
367
 
368
-
369
  final: List[BaseNode] = []
370
  i = 0
371
  while i < len(nodes):
@@ -373,12 +385,12 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
373
  curr_content = curr.get_content()
374
  curr_is_table = curr.metadata.get("is_table")
375
 
376
- # Skip empty or whitespace-only nodes
377
  if not curr_content.strip():
378
  i += 1
379
  continue
380
 
381
- # If current node is small non-table and there's a next node
382
  if not curr_is_table and len(curr_content) < MIN_CHUNK_SIZE and i + 1 < len(nodes):
383
  next_node = nodes[i + 1]
384
  next_is_table = next_node.metadata.get("is_table")
@@ -405,7 +417,8 @@ def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[Bas
405
 
406
 
407
  def chunk_markdown_file(path: str | Path) -> List[BaseNode]:
 
408
  p = Path(path)
409
  if not p.exists():
410
- raise FileNotFoundError(f"File not found: {p}")
411
  return chunk_markdown(p.read_text(encoding="utf-8"), source_path=p)
 
10
  from llama_index.core.node_parser import MarkdownNodeParser, SentenceSplitter
11
  from llama_index.core.schema import BaseNode, TextNode
12
 
13
+ # Cấu hình chunking
14
  CHUNK_SIZE = 1500
15
  CHUNK_OVERLAP = 150
16
  MIN_CHUNK_SIZE = 200
17
  TABLE_ROWS_PER_CHUNK = 15
18
 
19
+ # Cấu hình Small-to-Big
20
  ENABLE_TABLE_SUMMARY = True
21
+ MIN_TABLE_ROWS_FOR_SUMMARY = 0
22
+ SUMMARY_MODEL = "openai/gpt-oss-120b"
23
+ GROQ_BASE_URL = "https://api.groq.com/openai/v1"
24
 
25
+ # Regex patterns
26
  COURSE_PATTERN = re.compile(r"Học\s*phần\s+(.+?)\s*\(\s*m[ãa]\s+([^\)]+)\)", re.I | re.DOTALL)
27
  TABLE_PLACEHOLDER = re.compile(r"__TBL_(\d+)__")
28
  HEADER_KEYWORDS = {'TT', 'STT', 'MÃ', 'TÊN', 'KHỐI', 'SỐ', 'ID', 'NO', '#'}
29
  FRONTMATTER_PATTERN = re.compile(r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL)
 
30
  TABLE_TITLE_PATTERN = re.compile(r"(?:^|\n)#+\s*(?:Bảng|BẢNG)\s*(\d+(?:\.\d+)?)\s*[.:]*\s*(.+?)(?:\n|$)", re.IGNORECASE)
31
 
32
 
33
  def _is_table_row(line: str) -> bool:
34
+ """Kiểm tra dòng có phải là hàng trong bảng Markdown không."""
35
  s = line.strip()
36
  return s.startswith("|") and s.endswith("|") and s.count("|") >= 2
37
 
38
+
39
  def _is_separator(line: str) -> bool:
40
+ """Kiểm tra dòng có phải là separator của bảng (|---|---|)."""
41
  if not _is_table_row(line):
42
  return False
43
  return not line.strip().replace("|", "").replace("-", "").replace(":", "").replace(" ", "")
44
 
45
+
46
  def _is_header(line: str) -> bool:
47
+ """Kiểm tra dòng có phải là header của bảng không."""
48
  if not _is_table_row(line):
49
  return False
50
  cells = [c.strip() for c in line.split("|") if c.strip()]
 
54
 
55
 
56
  def _extract_tables(text: str) -> Tuple[List[Tuple[str, List[str]]], str]:
57
+ """Trích xuất bảng từ text và thay bằng placeholder."""
58
  lines, tables, last_header, i = text.split("\n"), [], None, 0
59
 
60
  while i < len(lines) - 1:
 
78
  else:
79
  i += 1
80
 
81
+ # Thay bảng bằng placeholder
82
  result, tbl_idx, i = [], 0, 0
83
  while i < len(lines):
84
  if tbl_idx < len(tables) and i < len(lines) - 1 and _is_table_row(lines[i]) and _is_separator(lines[i + 1]):
 
95
 
96
 
97
  def _split_table(header: str, rows: List[str], max_rows: int = TABLE_ROWS_PER_CHUNK) -> List[str]:
98
+ """Chia bảng lớn thành nhiều chunks nhỏ."""
99
  if len(rows) <= max_rows:
100
  return [header + "\n".join(rows)]
101
 
 
104
  chunk_rows = rows[i:i + max_rows]
105
  chunks.append(chunk_rows)
106
 
107
+ # Gộp chunk cuối nếu quá nhỏ (< 5 dòng)
108
  if len(chunks) > 1 and len(chunks[-1]) < 5:
109
  chunks[-2].extend(chunks[-1])
110
  chunks.pop()
111
 
112
  return [header + "\n".join(r) for r in chunks]
113
 
114
+
115
  _summary_client: Optional[OpenAI] = None
116
 
117
+
118
  def _get_summary_client() -> Optional[OpenAI]:
119
+ """Lấy Groq client để tóm tắt bảng."""
120
  global _summary_client
121
  if _summary_client is not None:
122
  return _summary_client
123
 
124
+ api_key = os.getenv("GROQ_API_KEY", "").strip()
125
  if not api_key:
126
+ print("Chưa đặt GROQ_API_KEY. Tắt tính năng tóm tắt bảng.")
127
  return None
128
 
129
+ _summary_client = OpenAI(api_key=api_key, base_url=GROQ_BASE_URL)
130
  return _summary_client
131
 
132
 
 
139
  max_retries: int = 5,
140
  base_delay: float = 2.0
141
  ) -> str:
142
+ """Tóm tắt bảng bằng LLM với retry logic."""
143
  import time
144
 
145
  if not ENABLE_TABLE_SUMMARY:
146
+ raise RuntimeError("Tính năng tóm tắt bảng đã tắt. Đặt ENABLE_TABLE_SUMMARY = True")
147
 
148
  client = _get_summary_client()
149
  if client is None:
150
+ raise RuntimeError("Chưa đặt GROQ_API_KEY. Không thể tóm tắt bảng.")
151
 
152
+ # Tạo chuỗi định danh bảng
153
  table_id_parts = []
154
  if table_number:
155
  table_id_parts.append(f"Bảng {table_number}")
 
158
  if source_file:
159
  table_id_parts.append(f"từ file {source_file}")
160
 
161
+ table_identifier = " - ".join(table_id_parts) if table_id_parts else "Bảng không xác định"
162
 
163
  prompt = f"""Tóm tắt ngắn gọn nội dung bảng sau bằng tiếng Việt.
164
 
 
188
  if summary.strip():
189
  return summary.strip()
190
  else:
191
+ raise ValueError("API trả về summary rỗng")
192
 
193
  except Exception as e:
194
  last_error = e
195
+ delay = base_delay * (2 ** attempt) # Exponential backoff: 2, 4, 8, 16, 32 giây
196
+ print(f"Thử lại {attempt + 1}/{max_retries} cho {table_identifier}: {e}")
197
+ print(f" Đợi {delay:.1f}s trước khi thử lại...")
198
  time.sleep(delay)
199
 
200
+ # Tất cả retry đều thất bại
201
+ raise RuntimeError(f"Không thể tóm tắt {table_identifier} sau {max_retries} lần thử. Lỗi cuối: {last_error}")
 
 
 
202
 
203
 
204
  def _create_table_nodes(
 
209
  table_title: str = "",
210
  source_file: str = ""
211
  ) -> List[TextNode]:
212
+ """Tạo nodes cho bảng. Bảng lớn sẽ parent + summary node."""
213
+ # Đếm số dòng để quyết định cần tóm tắt không
214
  row_count = table_text.count("\n")
215
 
216
+ # Thêm thông tin bảng vào metadata
217
  table_meta = {**metadata}
218
  if table_number:
219
  table_meta["table_number"] = table_number
 
221
  table_meta["table_title"] = table_title
222
 
223
  if row_count < MIN_TABLE_ROWS_FOR_SUMMARY:
224
+ # Bảng quá nhỏ, không cần tóm tắt
225
+ return [TextNode(text=table_text, metadata={**table_meta, "is_table": True})]
226
+
227
+ # Kiểm tra có thể tóm tắt không (cần API key)
228
+ if _get_summary_client() is None:
229
+ # Không có API key -> trả về node bảng đơn giản, không tóm tắt
230
  return [TextNode(text=table_text, metadata={**table_meta, "is_table": True})]
231
 
232
+ # Tạo summary với retry logic
233
  summary = _summarize_table(
234
  table_text,
235
  context_hint,
 
238
  source_file=source_file
239
  )
240
 
241
+ # Tạo parent node (bảng gốc - KHÔNG embed)
242
  parent_id = str(uuid.uuid4())
243
  parent_node = TextNode(
244
  text=table_text,
245
  metadata={
246
  **table_meta,
247
  "is_table": True,
248
+ "is_parent": True, # Flag để bỏ qua embedding
249
  "node_id": parent_id,
250
  }
251
  )
252
  parent_node.id_ = parent_id
253
 
254
+ # Tạo summary node (SẼ được embed để search)
255
  summary_node = TextNode(
256
  text=summary,
257
  metadata={
258
  **table_meta,
259
  "is_table_summary": True,
260
+ "parent_id": parent_id, # Link tới parent
261
  }
262
  )
263
 
264
+ table_id = f"Bảng {table_number}" if table_number else "bảng"
265
+ print(f"Đã tạo summary cho {table_id} ({row_count} dòng)")
266
  return [parent_node, summary_node]
267
 
268
 
 
 
269
  def _enrich_metadata(node: BaseNode, source_path: Path | None) -> None:
270
+ """Bổ sung metadata từ source path và trích xuất thông tin học phần."""
271
  if source_path:
272
  node.metadata.update({"source_path": str(source_path), "source_file": source_path.name})
273
  if "Học phần" in (text := node.get_content()) and (m := COURSE_PATTERN.search(text)):
 
275
 
276
 
277
  def _chunk_text(text: str, metadata: dict) -> List[BaseNode]:
278
+ """Chia text thành chunks theo kích thước cấu hình."""
279
  if len(text) <= CHUNK_SIZE:
280
  return [TextNode(text=text, metadata=metadata.copy())]
281
  return SentenceSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP).get_nodes_from_documents(
 
284
 
285
 
286
  def _extract_frontmatter(text: str) -> Tuple[Dict[str, Any], str]:
287
+ """Trích xuất YAML frontmatter từ đầu file."""
288
  match = FRONTMATTER_PATTERN.match(text)
289
  if not match:
290
  return {}, text
 
298
 
299
 
300
  def chunk_markdown(text: str, source_path: str | Path | None = None) -> List[BaseNode]:
301
+ """Chunk một file Markdown thành các nodes."""
302
  if not text or not text.strip():
303
  return []
304
 
305
  path = Path(source_path) if source_path else None
306
 
307
+ # Trích xuất YAML frontmatter làm metadata (không chunk)
308
  frontmatter_meta, text = _extract_frontmatter(text)
309
 
310
  tables, text_with_placeholders = _extract_tables(text)
311
 
312
+ # Metadata bản từ frontmatter + source path
313
  base_meta = {**frontmatter_meta}
314
  if path:
315
  base_meta.update({"source_path": str(path), "source_file": path.name})
316
 
317
+ # Parse theo headings
318
  doc = Document(text=text_with_placeholders, metadata=base_meta.copy())
319
  heading_nodes = MarkdownNodeParser().get_nodes_from_documents([doc])
320
 
 
329
 
330
  last_end = 0
331
  for match in matches:
332
+ # Text trước bảng
333
  before_text = content[last_end:match.start()].strip()
334
 
335
+ # Trích xuất số bảng tiêu đề từ text trước bảng
336
  table_number = ""
337
  table_title = ""
338
  if before_text:
 
339
  title_match = TABLE_TITLE_PATTERN.search(before_text)
340
  if title_match:
341
  table_number = title_match.group(1).strip()
 
344
  if before_text and len(before_text) >= MIN_CHUNK_SIZE:
345
  nodes.extend(_chunk_text(before_text, meta) if len(before_text) > CHUNK_SIZE else [TextNode(text=before_text, metadata=meta.copy())])
346
 
347
+ # Chunk bảng - sử dụng Small-to-Big pattern
348
  if (idx := int(match.group(1))) < len(tables):
349
  header, rows = tables[idx]
350
  table_chunks = _split_table(header, rows)
351
 
352
+ # Lấy context hint từ header path
353
  context_hint = meta.get("Header 1", "") or meta.get("section", "")
354
 
355
+ # Lấy source file cho summary
356
  source_file = meta.get("source_file", "") or (path.name if path else "")
357
 
358
  for i, chunk in enumerate(table_chunks):
 
360
  if len(table_chunks) > 1:
361
  chunk_meta["table_part"] = f"{i+1}/{len(table_chunks)}"
362
 
363
+ # Tạo parent + summary nodes nếu cần
364
  table_nodes = _create_table_nodes(
365
  chunk,
366
  chunk_meta,
 
373
 
374
  last_end = match.end()
375
 
376
+ # Text sau bảng
377
  if (after := content[last_end:].strip()) and len(after) >= MIN_CHUNK_SIZE:
378
  nodes.extend(_chunk_text(after, meta) if len(after) > CHUNK_SIZE else [TextNode(text=after, metadata=meta.copy())])
379
 
380
+ # Gộp các node nhỏ với node kế tiếp
381
  final: List[BaseNode] = []
382
  i = 0
383
  while i < len(nodes):
 
385
  curr_content = curr.get_content()
386
  curr_is_table = curr.metadata.get("is_table")
387
 
388
+ # Bỏ qua node rỗng
389
  if not curr_content.strip():
390
  i += 1
391
  continue
392
 
393
+ # Nếu node hiện tại nhỏ không phải bảng -> gộp với node sau
394
  if not curr_is_table and len(curr_content) < MIN_CHUNK_SIZE and i + 1 < len(nodes):
395
  next_node = nodes[i + 1]
396
  next_is_table = next_node.metadata.get("is_table")
 
417
 
418
 
419
  def chunk_markdown_file(path: str | Path) -> List[BaseNode]:
420
+ """Đọc và chunk một file Markdown."""
421
  p = Path(path)
422
  if not p.exists():
423
+ raise FileNotFoundError(f"Không tìm thấy file: {p}")
424
  return chunk_markdown(p.read_text(encoding="utf-8"), source_path=p)
core/rag/embedding_model.py CHANGED
@@ -1,26 +1,30 @@
1
  from __future__ import annotations
2
  import os
3
  import logging
 
4
  from dataclasses import dataclass
5
  from typing import List, Sequence
6
  import numpy as np
7
  from openai import OpenAI
8
  from langchain_core.embeddings import Embeddings
9
- import time
10
  logger = logging.getLogger(__name__)
11
 
12
 
13
  @dataclass
14
  class EmbeddingConfig:
15
- api_base_url: str = "https://api.siliconflow.com/v1"
16
- model: str = "Qwen/Qwen3-Embedding-4B"
17
- dimension: int = 2048
18
- batch_size: int = 16
 
19
 
20
 
21
  _embed_config: EmbeddingConfig | None = None
22
 
 
23
  def get_embedding_config() -> EmbeddingConfig:
 
24
  global _embed_config
25
  if _embed_config is None:
26
  _embed_config = EmbeddingConfig()
@@ -28,26 +32,32 @@ def get_embedding_config() -> EmbeddingConfig:
28
 
29
 
30
  class QwenEmbeddings(Embeddings):
 
 
31
  def __init__(self, config: EmbeddingConfig | None = None):
 
32
  self.config = config or get_embedding_config()
33
 
34
  api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
35
  if not api_key:
36
- raise ValueError("SILICONFLOW_API_KEY environment variable not set")
37
 
38
  self._client = OpenAI(
39
  api_key=api_key,
40
  base_url=self.config.api_base_url,
41
  )
42
- logger.info(f"QwenEmbeddings initialized: {self.config.model}")
43
 
44
  def embed_query(self, text: str) -> List[float]:
 
45
  return self._embed_texts([text])[0]
46
 
47
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
 
48
  return self._embed_texts(texts)
49
 
50
  def _embed_texts(self, texts: Sequence[str]) -> List[List[float]]:
 
51
  if not texts:
52
  return []
53
 
@@ -55,9 +65,11 @@ class QwenEmbeddings(Embeddings):
55
  batch_size = self.config.batch_size
56
  max_retries = 3
57
 
 
58
  for i in range(0, len(texts), batch_size):
59
  batch = list(texts[i:i + batch_size])
60
 
 
61
  for attempt in range(max_retries):
62
  try:
63
  response = self._client.embeddings.create(
@@ -68,9 +80,10 @@ class QwenEmbeddings(Embeddings):
68
  all_embeddings.append(item.embedding)
69
  break
70
  except Exception as e:
 
71
  if "rate" in str(e).lower() and attempt < max_retries - 1:
72
- wait_time = 2 ** attempt # 1s, 2s, 4s
73
- logger.warning(f"Rate limit hit, waiting {wait_time}s...")
74
  time.sleep(wait_time)
75
  else:
76
  raise
@@ -78,9 +91,10 @@ class QwenEmbeddings(Embeddings):
78
  return all_embeddings
79
 
80
  def embed_texts_np(self, texts: Sequence[str]) -> np.ndarray:
 
81
  return np.asarray(self._embed_texts(list(texts)), dtype=np.float32)
82
 
83
 
84
- # Legacy alias
85
  SiliconFlowConfig = EmbeddingConfig
86
  get_config = get_embedding_config
 
1
  from __future__ import annotations
2
  import os
3
  import logging
4
+ import time
5
  from dataclasses import dataclass
6
  from typing import List, Sequence
7
  import numpy as np
8
  from openai import OpenAI
9
  from langchain_core.embeddings import Embeddings
10
+
11
  logger = logging.getLogger(__name__)
12
 
13
 
14
  @dataclass
15
  class EmbeddingConfig:
16
+ """Cấu hình cho embedding model."""
17
+ api_base_url: str = "https://api.siliconflow.com/v1" # SiliconFlow API
18
+ model: str = "Qwen/Qwen3-Embedding-4B" # Model embedding
19
+ dimension: int = 2048 # Số chiều vector
20
+ batch_size: int = 16 # Số text mỗi batch
21
 
22
 
23
  _embed_config: EmbeddingConfig | None = None
24
 
25
+
26
  def get_embedding_config() -> EmbeddingConfig:
27
+ """Lấy cấu hình embedding (singleton pattern)."""
28
  global _embed_config
29
  if _embed_config is None:
30
  _embed_config = EmbeddingConfig()
 
32
 
33
 
34
  class QwenEmbeddings(Embeddings):
35
+ """Wrapper embedding model Qwen qua SiliconFlow API"""
36
+
37
  def __init__(self, config: EmbeddingConfig | None = None):
38
+ """Khởi tạo embedding client."""
39
  self.config = config or get_embedding_config()
40
 
41
  api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
42
  if not api_key:
43
+ raise ValueError("Chưa đặt biến môi trường SILICONFLOW_API_KEY")
44
 
45
  self._client = OpenAI(
46
  api_key=api_key,
47
  base_url=self.config.api_base_url,
48
  )
49
+ logger.info(f"Đã khởi tạo QwenEmbeddings: {self.config.model}")
50
 
51
  def embed_query(self, text: str) -> List[float]:
52
+ """Embed một câu query (dùng cho search)."""
53
  return self._embed_texts([text])[0]
54
 
55
  def embed_documents(self, texts: List[str]) -> List[List[float]]:
56
+ """Embed nhiều documents (dùng khi index)."""
57
  return self._embed_texts(texts)
58
 
59
  def _embed_texts(self, texts: Sequence[str]) -> List[List[float]]:
60
+ """Embed danh sách texts theo batch với retry logic."""
61
  if not texts:
62
  return []
63
 
 
65
  batch_size = self.config.batch_size
66
  max_retries = 3
67
 
68
+ # Xử lý theo batch
69
  for i in range(0, len(texts), batch_size):
70
  batch = list(texts[i:i + batch_size])
71
 
72
+ # Retry logic cho rate limit
73
  for attempt in range(max_retries):
74
  try:
75
  response = self._client.embeddings.create(
 
80
  all_embeddings.append(item.embedding)
81
  break
82
  except Exception as e:
83
+ # Nếu bị rate limit -> đợi rồi thử lại
84
  if "rate" in str(e).lower() and attempt < max_retries - 1:
85
+ wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s
86
+ logger.warning(f"Bị rate limit, đợi {wait_time}s...")
87
  time.sleep(wait_time)
88
  else:
89
  raise
 
91
  return all_embeddings
92
 
93
  def embed_texts_np(self, texts: Sequence[str]) -> np.ndarray:
94
+ """Embed texts và trả về numpy array (tiện cho tính toán)."""
95
  return np.asarray(self._embed_texts(list(texts)), dtype=np.float32)
96
 
97
 
98
+ # Alias để tương thích ngược
99
  SiliconFlowConfig = EmbeddingConfig
100
  get_config = get_embedding_config
core/rag/generator.py CHANGED
@@ -5,7 +5,7 @@ if TYPE_CHECKING:
5
  from core.rag.retrival import Retriever
6
 
7
 
8
- # System prompt để sử dụng khi gọi LLM (export cho gradio/eval dùng)
9
  SYSTEM_PROMPT = """Bạn là Trợ lý học vụ Đại học Bách khoa Hà Nội.
10
 
11
  ## NGUYÊN TẮC:
@@ -16,6 +16,7 @@ SYSTEM_PROMPT = """Bạn là Trợ lý học vụ Đại học Bách khoa Hà N
16
 
17
 
18
  def build_context(results: List[Dict[str, Any]], max_chars: int = 8000) -> str:
 
19
  parts = []
20
  for i, r in enumerate(results, 1):
21
  meta = r.get("metadata", {})
@@ -30,7 +31,7 @@ def build_context(results: List[Dict[str, Any]], max_chars: int = 8000) -> str:
30
  issued_year = meta.get("issued_year", "")
31
  content = r.get("content", "").strip()
32
 
33
- # Build metadata line
34
  meta_info = f"Nguồn: {source}"
35
  if header and header != "/":
36
  meta_info += f" | Mục: {header}"
@@ -53,16 +54,20 @@ def build_context(results: List[Dict[str, Any]], max_chars: int = 8000) -> str:
53
  parts.append(f"[TÀI LIỆU {i}]\n{meta_info}\n{content}")
54
 
55
  context = "\n---\n".join(parts)
 
56
  return context[:max_chars] if len(context) > max_chars else context
57
 
58
 
59
  def build_prompt(question: str, context: str) -> str:
 
60
  return f"{SYSTEM_PROMPT}\n\n## CONTEXT:\n{context}\n\n## CÂU HỎI: {question}\n\n## TRẢ LỜI:"
61
 
62
 
63
  class RAGContextBuilder:
 
64
 
65
  def __init__(self, retriever: "Retriever", max_context_chars: int = 8000):
 
66
  self._retriever = retriever
67
  self._max_context_chars = max_context_chars
68
 
@@ -73,9 +78,11 @@ class RAGContextBuilder:
73
  initial_k: int = 20,
74
  mode: str = "hybrid_rerank"
75
  ) -> Dict[str, Any]:
76
-
 
77
  results = self._retriever.flexible_search(question, k=k, initial_k=initial_k, mode=mode)
78
 
 
79
  if not results:
80
  return {
81
  "results": [],
@@ -84,15 +91,17 @@ class RAGContextBuilder:
84
  "prompt": "",
85
  }
86
 
 
87
  context_text = build_context(results, self._max_context_chars)
88
  prompt = build_prompt(question, context_text)
89
 
90
  return {
91
- "results": results,
92
- "contexts": [r.get("content", "")[:1000] for r in results],
93
- "context_text": context_text,
94
- "prompt": prompt,
95
  }
96
 
97
 
 
98
  RAGGenerator = RAGContextBuilder
 
5
  from core.rag.retrival import Retriever
6
 
7
 
8
+ # System prompt cho LLM (export để gradio/eval dùng)
9
  SYSTEM_PROMPT = """Bạn là Trợ lý học vụ Đại học Bách khoa Hà Nội.
10
 
11
  ## NGUYÊN TẮC:
 
16
 
17
 
18
  def build_context(results: List[Dict[str, Any]], max_chars: int = 8000) -> str:
19
+ """Xây dựng context từ kết quả retrieval để đưa vào prompt."""
20
  parts = []
21
  for i, r in enumerate(results, 1):
22
  meta = r.get("metadata", {})
 
31
  issued_year = meta.get("issued_year", "")
32
  content = r.get("content", "").strip()
33
 
34
+ # Tạo dòng metadata
35
  meta_info = f"Nguồn: {source}"
36
  if header and header != "/":
37
  meta_info += f" | Mục: {header}"
 
54
  parts.append(f"[TÀI LIỆU {i}]\n{meta_info}\n{content}")
55
 
56
  context = "\n---\n".join(parts)
57
+ # Cắt ngắn nếu vượt quá giới hạn
58
  return context[:max_chars] if len(context) > max_chars else context
59
 
60
 
61
  def build_prompt(question: str, context: str) -> str:
62
+ """Ghép system prompt, context và câu hỏi thành prompt hoàn chỉnh."""
63
  return f"{SYSTEM_PROMPT}\n\n## CONTEXT:\n{context}\n\n## CÂU HỎI: {question}\n\n## TRẢ LỜI:"
64
 
65
 
66
  class RAGContextBuilder:
67
+ """Kết hợp retrieval và context building thành một bước."""
68
 
69
  def __init__(self, retriever: "Retriever", max_context_chars: int = 8000):
70
+ """Khởi tạo với retriever và giới hạn context."""
71
  self._retriever = retriever
72
  self._max_context_chars = max_context_chars
73
 
 
78
  initial_k: int = 20,
79
  mode: str = "hybrid_rerank"
80
  ) -> Dict[str, Any]:
81
+ """Retrieve documents và chuẩn bị context + prompt cho LLM."""
82
+ # Tìm kiếm documents liên quan
83
  results = self._retriever.flexible_search(question, k=k, initial_k=initial_k, mode=mode)
84
 
85
+ # Không tìm thấy kết quả
86
  if not results:
87
  return {
88
  "results": [],
 
91
  "prompt": "",
92
  }
93
 
94
+ # Xây dựng context và prompt
95
  context_text = build_context(results, self._max_context_chars)
96
  prompt = build_prompt(question, context_text)
97
 
98
  return {
99
+ "results": results, # Kết quả retrieval gốc
100
+ "contexts": [r.get("content", "")[:1000] for r in results], # Context rút gọn (cho eval)
101
+ "context_text": context_text, # Context đầy đủ
102
+ "prompt": prompt, # Prompt hoàn chỉnh
103
  }
104
 
105
 
106
+ # Alias để tương thích ngược
107
  RAGGenerator = RAGContextBuilder
core/rag/retrival.py CHANGED
@@ -22,29 +22,30 @@ logger = logging.getLogger(__name__)
22
 
23
 
24
  class RetrievalMode(str, Enum):
25
- """Retrieval modes."""
26
- VECTOR_ONLY = "vector_only"
27
- BM25_ONLY = "bm25_only"
28
- HYBRID = "hybrid"
29
- HYBRID_RERANK = "hybrid_rerank"
30
 
31
 
32
  @dataclass
33
  class RetrievalConfig:
34
- rerank_api_base_url: str = "https://api.siliconflow.com/v1"
35
- rerank_model: str = "Qwen/Qwen3-Reranker-4B"
36
- rerank_top_n: int = 10
37
- initial_k: int = 25 # Reduced to minimize reranker time
38
- top_k: int = 5
39
- vector_weight: float = 0.5
40
- bm25_weight: float = 0.5
41
-
42
 
43
 
44
  _retrieval_config: RetrievalConfig | None = None
45
 
46
 
47
  def get_retrieval_config() -> RetrievalConfig:
 
48
  global _retrieval_config
49
  if _retrieval_config is None:
50
  _retrieval_config = RetrievalConfig()
@@ -52,6 +53,7 @@ def get_retrieval_config() -> RetrievalConfig:
52
 
53
 
54
  class SiliconFlowReranker(BaseDocumentCompressor):
 
55
  api_key: str = Field(default="")
56
  api_base_url: str = Field(default="")
57
  model: str = Field(default="")
@@ -66,9 +68,11 @@ class SiliconFlowReranker(BaseDocumentCompressor):
66
  query: str,
67
  callbacks: Optional[Callbacks] = None,
68
  ) -> Sequence[Document]:
 
69
  if not documents or not self.api_key:
70
  return list(documents)
71
 
 
72
  for attempt in range(3):
73
  try:
74
  response = requests.post(
@@ -91,6 +95,7 @@ class SiliconFlowReranker(BaseDocumentCompressor):
91
  if "results" not in data:
92
  return list(documents)
93
 
 
94
  reranked: List[Document] = []
95
  for result in data["results"]:
96
  doc = documents[result["index"]]
@@ -101,33 +106,36 @@ class SiliconFlowReranker(BaseDocumentCompressor):
101
  return reranked
102
 
103
  except Exception as e:
 
104
  if "rate" in str(e).lower() and attempt < 2:
105
  time.sleep(2 ** attempt)
106
  else:
107
- logger.error(f"Rerank error: {e}")
108
  return list(documents)
109
 
110
  return list(documents)
111
 
112
 
113
-
114
-
115
  class Retriever:
 
 
116
  def __init__(self, vector_db: "ChromaVectorDB", use_reranker: bool = True):
 
117
  self._vector_db = vector_db
118
  self._config = get_retrieval_config()
119
  self._reranker: Optional[SiliconFlowReranker] = None
120
 
 
121
  self._vector_retriever = self._vector_db.vectorstore.as_retriever(
122
  search_kwargs={"k": self._config.initial_k}
123
  )
124
 
125
- # Lazy-load BM25 - only initialize when needed
126
  self._bm25_retriever: Optional[BM25Retriever] = None
127
  self._bm25_initialized = False
128
  self._ensemble_retriever: Optional[EnsembleRetriever] = None
129
 
130
- # BM25 cache path (persist to disk)
131
  from pathlib import Path
132
  persist_dir = getattr(self._vector_db.config, 'persist_dir', None)
133
  if persist_dir:
@@ -138,61 +146,57 @@ class Retriever:
138
  if use_reranker:
139
  self._reranker = self._init_reranker()
140
 
141
- logger.info("Retriever initialized")
142
 
143
-
144
  def _save_bm25_cache(self, bm25: BM25Retriever) -> None:
145
- """Save BM25 retriever to disk for fast loading."""
146
  if not self._bm25_cache_path:
147
  return
148
  try:
149
  import pickle
150
  with open(self._bm25_cache_path, 'wb') as f:
151
  pickle.dump(bm25, f)
152
- logger.info(f"BM25 cache saved to {self._bm25_cache_path}")
153
  except Exception as e:
154
- logger.warning(f"Failed to save BM25 cache: {e}")
155
 
156
  def _load_bm25_cache(self) -> Optional[BM25Retriever]:
 
157
  if not self._bm25_cache_path or not self._bm25_cache_path.exists():
158
  return None
159
-
160
  try:
161
  import pickle
162
- import time
163
  start = time.time()
164
  with open(self._bm25_cache_path, 'rb') as f:
165
  bm25 = pickle.load(f)
166
  bm25.k = self._config.initial_k
167
- logger.info(f"BM25 loaded from cache in {time.time() - start:.2f}s")
168
  return bm25
169
  except Exception as e:
170
- logger.warning(f"Failed to load BM25 cache: {e}")
171
  return None
172
-
173
-
174
-
175
  def _init_bm25(self) -> Optional[BM25Retriever]:
 
176
  if self._bm25_initialized:
177
  return self._bm25_retriever
178
 
179
  self._bm25_initialized = True
180
 
181
- # Try loading from cache first
182
  cached = self._load_bm25_cache()
183
  if cached:
184
  self._bm25_retriever = cached
185
  return cached
186
 
187
- # Build from scratch
188
  try:
189
- import time
190
  start = time.time()
191
- logger.info("Building BM25 index from documents...")
192
 
193
  docs = self._vector_db.get_all_documents()
194
  if not docs:
195
- logger.warning("No documents found for BM25")
196
  return None
197
 
198
  lc_docs = [
@@ -203,19 +207,18 @@ class Retriever:
203
  bm25.k = self._config.initial_k
204
 
205
  self._bm25_retriever = bm25
206
- logger.info(f"BM25 built with {len(docs)} docs in {time.time() - start:.2f}s")
207
 
208
- # Save to cache for next time
209
  self._save_bm25_cache(bm25)
210
 
211
  return bm25
212
  except Exception as e:
213
- logger.error(f"Failed to init BM25: {e}")
214
  return None
215
 
216
-
217
  def _get_ensemble_retriever(self) -> EnsembleRetriever:
218
- """Get or create ensemble retriever (lazy-loaded)."""
219
  if self._ensemble_retriever is not None:
220
  return self._ensemble_retriever
221
 
@@ -226,14 +229,15 @@ class Retriever:
226
  weights=[self._config.vector_weight, self._config.bm25_weight]
227
  )
228
  else:
 
229
  self._ensemble_retriever = EnsembleRetriever(
230
  retrievers=[self._vector_retriever],
231
  weights=[1.0]
232
  )
233
  return self._ensemble_retriever
234
 
235
-
236
  def _init_reranker(self) -> Optional[SiliconFlowReranker]:
 
237
  api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
238
  if not api_key:
239
  return None
@@ -245,7 +249,7 @@ class Retriever:
245
  )
246
 
247
  def _build_final(self):
248
- """Build final retriever with reranker (lazy-loaded)."""
249
  ensemble = self._get_ensemble_retriever()
250
  if self._reranker:
251
  return ContextualCompressionRetriever(
@@ -254,21 +258,22 @@ class Retriever:
254
  )
255
  return ensemble
256
 
257
-
258
  @property
259
  def has_reranker(self) -> bool:
 
260
  return self._reranker is not None
261
 
262
  def _to_result(self, doc: Document, rank: int, **extra) -> Dict[str, Any]:
 
263
  metadata = doc.metadata or {}
264
  content = doc.page_content
265
 
266
- # Small-to-Big: If this is a summary node, swap with parent (raw table)
267
  if metadata.get("is_table_summary") and metadata.get("parent_id"):
268
  parent = self._vector_db.get_parent_node(metadata["parent_id"])
269
  if parent:
270
  content = parent.get("content", content)
271
- # Merge metadata, keeping summary info for debugging
272
  metadata = {
273
  **parent.get("metadata", {}),
274
  "original_summary": doc.page_content[:200],
@@ -283,10 +288,10 @@ class Retriever:
283
  **extra,
284
  }
285
 
286
-
287
  def vector_search(
288
  self, text: str, *, k: int | None = None, where: Optional[Dict[str, Any]] = None
289
  ) -> List[Dict[str, Any]]:
 
290
  if not text.strip():
291
  return []
292
  k = k or self._config.top_k
@@ -294,13 +299,12 @@ class Retriever:
294
  return [self._to_result(doc, i + 1, distance=score) for i, (doc, score) in enumerate(results)]
295
 
296
  def bm25_search(self, text: str, *, k: int | None = None) -> List[Dict[str, Any]]:
 
297
  if not text.strip():
298
  return []
299
-
300
  bm25 = self._init_bm25() # Lazy-load BM25
301
  if not bm25:
302
  return self.vector_search(text, k=k)
303
-
304
  k = k or self._config.top_k
305
  bm25.k = k
306
  results = bm25.invoke(text)
@@ -309,9 +313,9 @@ class Retriever:
309
  def hybrid_search(
310
  self, text: str, *, k: int | None = None, initial_k: int | None = None
311
  ) -> List[Dict[str, Any]]:
 
312
  if not text.strip():
313
  return []
314
-
315
  k = k or self._config.top_k
316
  if initial_k:
317
  self._vector_retriever.search_kwargs["k"] = initial_k
@@ -319,7 +323,6 @@ class Retriever:
319
  if bm25:
320
  bm25.k = initial_k
321
 
322
- # Dùng ensemble_retriever (lazy-loaded, KHÔNG có reranker)
323
  ensemble = self._get_ensemble_retriever()
324
  results = ensemble.invoke(text)
325
  return [self._to_result(doc, i + 1) for i, doc in enumerate(results[:k])]
@@ -332,11 +335,9 @@ class Retriever:
332
  where: Optional[Dict[str, Any]] = None,
333
  initial_k: int | None = None,
334
  ) -> List[Dict[str, Any]]:
335
- import time
336
-
337
  if not text.strip():
338
  return []
339
-
340
  k = k or self._config.top_k
341
  initial_k = initial_k or self._config.initial_k
342
 
@@ -350,16 +351,18 @@ class Retriever:
350
  for i, doc in enumerate(results[:k])
351
  ]
352
 
353
- # Build final retriever (lazy-loaded ensemble + reranker)
354
  if initial_k:
355
  self._vector_retriever.search_kwargs["k"] = initial_k
356
  bm25 = self._init_bm25()
357
  if bm25:
358
  bm25.k = initial_k
359
 
 
360
  ensemble = self._get_ensemble_retriever()
361
  ensemble_results = ensemble.invoke(text)
362
 
 
363
  if self._reranker:
364
  results = self._reranker.compress_documents(ensemble_results, text)
365
  else:
@@ -370,8 +373,6 @@ class Retriever:
370
  for i, doc in enumerate(results[:k])
371
  ]
372
 
373
-
374
-
375
  def flexible_search(
376
  self,
377
  text: str,
@@ -381,9 +382,11 @@ class Retriever:
381
  initial_k: int | None = None,
382
  where: Optional[Dict[str, Any]] = None,
383
  ) -> List[Dict[str, Any]]:
 
384
  if not text.strip():
385
  return []
386
 
 
387
  if isinstance(mode, str):
388
  try:
389
  mode = RetrievalMode(mode.lower())
@@ -393,6 +396,7 @@ class Retriever:
393
  k = k or self._config.top_k
394
  initial_k = initial_k or self._config.initial_k
395
 
 
396
  if mode == RetrievalMode.VECTOR_ONLY:
397
  return self.vector_search(text, k=k, where=where)
398
  elif mode == RetrievalMode.BM25_ONLY:
@@ -404,5 +408,5 @@ class Retriever:
404
  else: # HYBRID_RERANK
405
  return self.search_with_rerank(text, k=k, where=where, initial_k=initial_k)
406
 
407
- # Legacy alias
408
  query = vector_search
 
22
 
23
 
24
  class RetrievalMode(str, Enum):
25
+ """Các chế độ retrieval hỗ trợ."""
26
+ VECTOR_ONLY = "vector_only" # Chỉ dùng vector search
27
+ BM25_ONLY = "bm25_only" # Chỉ dùng BM25 keyword search
28
+ HYBRID = "hybrid" # Kết hợp vector + BM25
29
+ HYBRID_RERANK = "hybrid_rerank" # Hybrid + reranking
30
 
31
 
32
  @dataclass
33
  class RetrievalConfig:
34
+ """Cấu hình cho retrieval system."""
35
+ rerank_api_base_url: str = "https://api.siliconflow.com/v1" # API reranker
36
+ rerank_model: str = "Qwen/Qwen3-Reranker-4B" # Model reranker
37
+ rerank_top_n: int = 10 # Số kết quả sau rerank
38
+ initial_k: int = 25 # Số docs lấy ban đầu
39
+ top_k: int = 5 # Số kết quả cuối cùng
40
+ vector_weight: float = 0.5 # Trọng số vector search
41
+ bm25_weight: float = 0.5 # Trọng số BM25
42
 
43
 
44
  _retrieval_config: RetrievalConfig | None = None
45
 
46
 
47
  def get_retrieval_config() -> RetrievalConfig:
48
+ """Lấy cấu hình retrieval (singleton pattern)."""
49
  global _retrieval_config
50
  if _retrieval_config is None:
51
  _retrieval_config = RetrievalConfig()
 
53
 
54
 
55
  class SiliconFlowReranker(BaseDocumentCompressor):
56
+ """Reranker sử dụng SiliconFlow API để sắp xếp lại kết quả."""
57
  api_key: str = Field(default="")
58
  api_base_url: str = Field(default="")
59
  model: str = Field(default="")
 
68
  query: str,
69
  callbacks: Optional[Callbacks] = None,
70
  ) -> Sequence[Document]:
71
+ """Rerank documents dựa trên độ liên quan với query."""
72
  if not documents or not self.api_key:
73
  return list(documents)
74
 
75
+ # Retry logic với exponential backoff
76
  for attempt in range(3):
77
  try:
78
  response = requests.post(
 
95
  if "results" not in data:
96
  return list(documents)
97
 
98
+ # Tạo danh sách documents đã rerank với score
99
  reranked: List[Document] = []
100
  for result in data["results"]:
101
  doc = documents[result["index"]]
 
106
  return reranked
107
 
108
  except Exception as e:
109
+ # Rate limit -> đợi rồi thử lại
110
  if "rate" in str(e).lower() and attempt < 2:
111
  time.sleep(2 ** attempt)
112
  else:
113
+ logger.error(f"Lỗi rerank: {e}")
114
  return list(documents)
115
 
116
  return list(documents)
117
 
118
 
 
 
119
  class Retriever:
120
+ """Retriever chính hỗ trợ nhiều chế độ tìm kiếm."""
121
+
122
  def __init__(self, vector_db: "ChromaVectorDB", use_reranker: bool = True):
123
+ """Khởi tạo retriever với vector DB và reranker."""
124
  self._vector_db = vector_db
125
  self._config = get_retrieval_config()
126
  self._reranker: Optional[SiliconFlowReranker] = None
127
 
128
+ # Vector retriever từ ChromaDB
129
  self._vector_retriever = self._vector_db.vectorstore.as_retriever(
130
  search_kwargs={"k": self._config.initial_k}
131
  )
132
 
133
+ # Lazy-load BM25 - chỉ khởi tạo khi cần
134
  self._bm25_retriever: Optional[BM25Retriever] = None
135
  self._bm25_initialized = False
136
  self._ensemble_retriever: Optional[EnsembleRetriever] = None
137
 
138
+ # Đường dẫn cache BM25 (lưu vào disk)
139
  from pathlib import Path
140
  persist_dir = getattr(self._vector_db.config, 'persist_dir', None)
141
  if persist_dir:
 
146
  if use_reranker:
147
  self._reranker = self._init_reranker()
148
 
149
+ logger.info("Đã khởi tạo Retriever")
150
 
 
151
  def _save_bm25_cache(self, bm25: BM25Retriever) -> None:
152
+ """Lưu BM25 index vào cache file."""
153
  if not self._bm25_cache_path:
154
  return
155
  try:
156
  import pickle
157
  with open(self._bm25_cache_path, 'wb') as f:
158
  pickle.dump(bm25, f)
159
+ logger.info(f"Đã lưu BM25 cache vào {self._bm25_cache_path}")
160
  except Exception as e:
161
+ logger.warning(f"Không thể lưu BM25 cache: {e}")
162
 
163
  def _load_bm25_cache(self) -> Optional[BM25Retriever]:
164
+ """Tải BM25 index từ cache file."""
165
  if not self._bm25_cache_path or not self._bm25_cache_path.exists():
166
  return None
 
167
  try:
168
  import pickle
 
169
  start = time.time()
170
  with open(self._bm25_cache_path, 'rb') as f:
171
  bm25 = pickle.load(f)
172
  bm25.k = self._config.initial_k
173
+ logger.info(f"Đã tải BM25 từ cache trong {time.time() - start:.2f}s")
174
  return bm25
175
  except Exception as e:
176
+ logger.warning(f"Không thể tải BM25 cache: {e}")
177
  return None
178
+
 
 
179
  def _init_bm25(self) -> Optional[BM25Retriever]:
180
+ """Khởi tạo BM25 retriever (lazy-load với cache)."""
181
  if self._bm25_initialized:
182
  return self._bm25_retriever
183
 
184
  self._bm25_initialized = True
185
 
186
+ # Thử tải từ cache trước
187
  cached = self._load_bm25_cache()
188
  if cached:
189
  self._bm25_retriever = cached
190
  return cached
191
 
192
+ # Build từ đầu nếu không có cache
193
  try:
 
194
  start = time.time()
195
+ logger.info("Đang xây dựng BM25 index từ documents...")
196
 
197
  docs = self._vector_db.get_all_documents()
198
  if not docs:
199
+ logger.warning("Không tìm thấy documents cho BM25")
200
  return None
201
 
202
  lc_docs = [
 
207
  bm25.k = self._config.initial_k
208
 
209
  self._bm25_retriever = bm25
210
+ logger.info(f"Đã xây dựng BM25 với {len(docs)} docs trong {time.time() - start:.2f}s")
211
 
212
+ # Lưu vào cache cho lần sau
213
  self._save_bm25_cache(bm25)
214
 
215
  return bm25
216
  except Exception as e:
217
+ logger.error(f"Không thể khởi tạo BM25: {e}")
218
  return None
219
 
 
220
  def _get_ensemble_retriever(self) -> EnsembleRetriever:
221
+ """Lấy ensemble retriever (vector + BM25)."""
222
  if self._ensemble_retriever is not None:
223
  return self._ensemble_retriever
224
 
 
229
  weights=[self._config.vector_weight, self._config.bm25_weight]
230
  )
231
  else:
232
+ # Fallback về vector only
233
  self._ensemble_retriever = EnsembleRetriever(
234
  retrievers=[self._vector_retriever],
235
  weights=[1.0]
236
  )
237
  return self._ensemble_retriever
238
 
 
239
  def _init_reranker(self) -> Optional[SiliconFlowReranker]:
240
+ """Khởi tạo reranker nếu có API key."""
241
  api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
242
  if not api_key:
243
  return None
 
249
  )
250
 
251
  def _build_final(self):
252
+ """Build retriever cuối cùng (ensemble + reranker nếu có)."""
253
  ensemble = self._get_ensemble_retriever()
254
  if self._reranker:
255
  return ContextualCompressionRetriever(
 
258
  )
259
  return ensemble
260
 
 
261
  @property
262
  def has_reranker(self) -> bool:
263
+ """Kiểm tra có reranker không."""
264
  return self._reranker is not None
265
 
266
  def _to_result(self, doc: Document, rank: int, **extra) -> Dict[str, Any]:
267
+ """Chuyển Document thành dict result, xử lý Small-to-Big."""
268
  metadata = doc.metadata or {}
269
  content = doc.page_content
270
 
271
+ # Small-to-Big: Nếu summary node -> swap với parent (bảng gốc)
272
  if metadata.get("is_table_summary") and metadata.get("parent_id"):
273
  parent = self._vector_db.get_parent_node(metadata["parent_id"])
274
  if parent:
275
  content = parent.get("content", content)
276
+ # Merge metadata, giữ lại info summary để debug
277
  metadata = {
278
  **parent.get("metadata", {}),
279
  "original_summary": doc.page_content[:200],
 
288
  **extra,
289
  }
290
 
 
291
  def vector_search(
292
  self, text: str, *, k: int | None = None, where: Optional[Dict[str, Any]] = None
293
  ) -> List[Dict[str, Any]]:
294
+ """Tìm kiếm bằng vector similarity."""
295
  if not text.strip():
296
  return []
297
  k = k or self._config.top_k
 
299
  return [self._to_result(doc, i + 1, distance=score) for i, (doc, score) in enumerate(results)]
300
 
301
  def bm25_search(self, text: str, *, k: int | None = None) -> List[Dict[str, Any]]:
302
+ """Tìm kiếm bằng BM25 keyword matching."""
303
  if not text.strip():
304
  return []
 
305
  bm25 = self._init_bm25() # Lazy-load BM25
306
  if not bm25:
307
  return self.vector_search(text, k=k)
 
308
  k = k or self._config.top_k
309
  bm25.k = k
310
  results = bm25.invoke(text)
 
313
  def hybrid_search(
314
  self, text: str, *, k: int | None = None, initial_k: int | None = None
315
  ) -> List[Dict[str, Any]]:
316
+ """Tìm kiếm hybrid (vector + BM25) không có rerank."""
317
  if not text.strip():
318
  return []
 
319
  k = k or self._config.top_k
320
  if initial_k:
321
  self._vector_retriever.search_kwargs["k"] = initial_k
 
323
  if bm25:
324
  bm25.k = initial_k
325
 
 
326
  ensemble = self._get_ensemble_retriever()
327
  results = ensemble.invoke(text)
328
  return [self._to_result(doc, i + 1) for i, doc in enumerate(results[:k])]
 
335
  where: Optional[Dict[str, Any]] = None,
336
  initial_k: int | None = None,
337
  ) -> List[Dict[str, Any]]:
338
+ """Tìm kiếm hybrid + reranking để có kết quả tốt nhất."""
 
339
  if not text.strip():
340
  return []
 
341
  k = k or self._config.top_k
342
  initial_k = initial_k or self._config.initial_k
343
 
 
351
  for i, doc in enumerate(results[:k])
352
  ]
353
 
354
+ # Cập nhật k cho initial fetch
355
  if initial_k:
356
  self._vector_retriever.search_kwargs["k"] = initial_k
357
  bm25 = self._init_bm25()
358
  if bm25:
359
  bm25.k = initial_k
360
 
361
+ # Hybrid search
362
  ensemble = self._get_ensemble_retriever()
363
  ensemble_results = ensemble.invoke(text)
364
 
365
+ # Rerank nếu có
366
  if self._reranker:
367
  results = self._reranker.compress_documents(ensemble_results, text)
368
  else:
 
373
  for i, doc in enumerate(results[:k])
374
  ]
375
 
 
 
376
  def flexible_search(
377
  self,
378
  text: str,
 
382
  initial_k: int | None = None,
383
  where: Optional[Dict[str, Any]] = None,
384
  ) -> List[Dict[str, Any]]:
385
+ """Tìm kiếm linh hoạt với nhiều chế độ."""
386
  if not text.strip():
387
  return []
388
 
389
+ # Parse mode từ string
390
  if isinstance(mode, str):
391
  try:
392
  mode = RetrievalMode(mode.lower())
 
396
  k = k or self._config.top_k
397
  initial_k = initial_k or self._config.initial_k
398
 
399
+ # Gọi method tương ứng theo mode
400
  if mode == RetrievalMode.VECTOR_ONLY:
401
  return self.vector_search(text, k=k, where=where)
402
  elif mode == RetrievalMode.BM25_ONLY:
 
408
  else: # HYBRID_RERANK
409
  return self.search_with_rerank(text, k=k, where=where, initial_k=initial_k)
410
 
411
+ # Alias để tương thích ngược
412
  query = vector_search
core/rag/vector_store.py CHANGED
@@ -13,66 +13,76 @@ logger = logging.getLogger(__name__)
13
 
14
  @dataclass
15
  class ChromaConfig:
 
 
16
  def _default_persist_dir() -> str:
 
17
  repo_root = Path(__file__).resolve().parents[2]
18
  return str((repo_root / "data" / "chroma").resolve())
19
 
20
- persist_dir: str = field(default_factory=_default_persist_dir)
21
- collection_name: str = "hust_rag_collection"
22
 
23
 
24
  class ChromaVectorDB:
 
 
25
  def __init__(
26
  self,
27
  embedder: Any,
28
  config: ChromaConfig | None = None,
29
  ):
 
30
  self.embedder = embedder
31
  self.config = config or ChromaConfig()
32
  self._hasher = HashProcessor(verbose=False)
33
 
34
- # Storage for parent nodes (not embedded, used for Small-to-Big retrieval)
35
- # Persist to JSON file in same directory as ChromaDB
36
  self._parent_nodes_path = Path(self.config.persist_dir) / "parent_nodes.json"
37
  self._parent_nodes: Dict[str, Dict[str, Any]] = self._load_parent_nodes()
38
 
 
39
  self._vs = Chroma(
40
  collection_name=self.config.collection_name,
41
  embedding_function=self.embedder,
42
  persist_directory=self.config.persist_dir,
43
  )
44
- logger.info(f"ChromaVectorDB initialized: {self.config.collection_name}")
45
 
46
  def _load_parent_nodes(self) -> Dict[str, Dict[str, Any]]:
 
47
  if self._parent_nodes_path.exists():
48
  try:
49
  with open(self._parent_nodes_path, 'r', encoding='utf-8') as f:
50
  data = json.load(f)
51
- logger.info(f"Loaded {len(data)} parent nodes from {self._parent_nodes_path}")
52
  return data
53
  except Exception as e:
54
- logger.warning(f"Failed to load parent nodes: {e}")
55
  return {}
56
 
57
  def _save_parent_nodes(self) -> None:
58
- """Save parent nodes to JSON file."""
59
  try:
60
  self._parent_nodes_path.parent.mkdir(parents=True, exist_ok=True)
61
  with open(self._parent_nodes_path, 'w', encoding='utf-8') as f:
62
  json.dump(self._parent_nodes, f, ensure_ascii=False, indent=2)
63
- logger.info(f"Saved {len(self._parent_nodes)} parent nodes to {self._parent_nodes_path}")
64
  except Exception as e:
65
- logger.warning(f"Failed to save parent nodes: {e}")
66
 
67
  @property
68
  def collection(self):
 
69
  return getattr(self._vs, "_collection", None)
70
 
71
  @property
72
  def vectorstore(self):
 
73
  return self._vs
74
 
75
  def _flatten_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
 
76
  out: Dict[str, Any] = {}
77
  for k, v in (metadata or {}).items():
78
  if v is None:
@@ -80,33 +90,33 @@ class ChromaVectorDB:
80
  if isinstance(v, (str, int, float, bool)):
81
  out[str(k)] = v
82
  elif isinstance(v, (list, tuple, set, dict)):
 
83
  out[str(k)] = json.dumps(v, ensure_ascii=False)
84
  else:
85
  out[str(k)] = str(v)
86
  return out
87
 
88
  def _normalize_doc(self, doc: Any) -> Dict[str, Any]:
89
- # Nếu đã dict
 
90
  if isinstance(doc, dict):
91
  return doc
92
-
93
- # Nếu là TextNode/BaseNode từ llama_index
94
  if hasattr(doc, "get_content") and hasattr(doc, "metadata"):
95
  return {
96
  "content": doc.get_content(),
97
  "metadata": dict(doc.metadata) if doc.metadata else {},
98
  }
99
-
100
- # Nếu là Document từ langchain
101
  if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
102
  return {
103
  "content": doc.page_content,
104
  "metadata": dict(doc.metadata) if doc.metadata else {},
105
  }
106
-
107
- raise TypeError(f"Unsupported document type: {type(doc)}")
108
 
109
  def _to_documents(self, docs: Sequence[Any], ids: Sequence[str]) -> List[Document]:
 
110
  out: List[Document] = []
111
  for d, doc_id in zip(docs, ids):
112
  normalized = self._normalize_doc(d)
@@ -116,6 +126,7 @@ class ChromaVectorDB:
116
  return out
117
 
118
  def _doc_id(self, doc: Any) -> str:
 
119
  normalized = self._normalize_doc(doc)
120
  md = normalized.get("metadata") or {}
121
  key = {
@@ -133,13 +144,14 @@ class ChromaVectorDB:
133
  ids: Optional[Sequence[str]] = None,
134
  batch_size: int = 128,
135
  ) -> int:
 
136
  if not docs:
137
  return 0
138
 
139
  if ids is not None and len(ids) != len(docs):
140
- raise ValueError("ids length must match docs length")
141
 
142
- # Separate parent nodes (not embedded) from regular nodes
143
  regular_docs = []
144
  regular_ids = []
145
  parent_count = 0
@@ -150,7 +162,7 @@ class ChromaVectorDB:
150
  doc_id = ids[i] if ids else self._doc_id(d)
151
 
152
  if md.get("is_parent"):
153
- # Store parent node separately (for Small-to-Big retrieval)
154
  parent_id = md.get("node_id", doc_id)
155
  self._parent_nodes[parent_id] = {
156
  "id": parent_id,
@@ -163,12 +175,13 @@ class ChromaVectorDB:
163
  regular_ids.append(doc_id)
164
 
165
  if parent_count > 0:
166
- logger.info(f"Stored {parent_count} parent nodes (not embedded)")
167
- self._save_parent_nodes() # Persist to disk
168
 
169
  if not regular_docs:
170
  return parent_count
171
 
 
172
  bs = max(1, batch_size)
173
  total = 0
174
 
@@ -180,12 +193,13 @@ class ChromaVectorDB:
180
  try:
181
  self._vs.add_documents(lc_docs, ids=batch_ids)
182
  except TypeError:
 
183
  texts = [d.page_content for d in lc_docs]
184
  metas = [d.metadata for d in lc_docs]
185
  self._vs.add_texts(texts=texts, metadatas=metas, ids=batch_ids)
186
  total += len(batch)
187
 
188
- logger.info(f"Added {total} documents to vector store")
189
  return total + parent_count
190
 
191
  def upsert_documents(
@@ -195,13 +209,14 @@ class ChromaVectorDB:
195
  ids: Optional[Sequence[str]] = None,
196
  batch_size: int = 128,
197
  ) -> int:
 
198
  if not docs:
199
  return 0
200
 
201
  if ids is not None and len(ids) != len(docs):
202
- raise ValueError("ids length must match docs length")
203
 
204
- # Separate parent nodes (not embedded) from regular nodes
205
  regular_docs = []
206
  regular_ids = []
207
  parent_count = 0
@@ -212,7 +227,7 @@ class ChromaVectorDB:
212
  doc_id = ids[i] if ids else self._doc_id(d)
213
 
214
  if md.get("is_parent"):
215
- # Store parent node separately (for Small-to-Big retrieval)
216
  parent_id = md.get("node_id", doc_id)
217
  self._parent_nodes[parent_id] = {
218
  "id": parent_id,
@@ -225,8 +240,8 @@ class ChromaVectorDB:
225
  regular_ids.append(doc_id)
226
 
227
  if parent_count > 0:
228
- logger.info(f"Stored {parent_count} parent nodes (not embedded)")
229
- self._save_parent_nodes() # Persist to disk
230
 
231
  if not regular_docs:
232
  return parent_count
@@ -234,9 +249,11 @@ class ChromaVectorDB:
234
  bs = max(1, batch_size)
235
  col = self.collection
236
 
 
237
  if col is None:
238
  return self.add_documents(regular_docs, ids=regular_ids, batch_size=bs) + parent_count
239
 
 
240
  total = 0
241
  for start in range(0, len(regular_docs), bs):
242
  batch = regular_docs[start : start + bs]
@@ -248,14 +265,16 @@ class ChromaVectorDB:
248
  col.upsert(ids=batch_ids, documents=texts, metadatas=metas, embeddings=embs)
249
  total += len(batch)
250
 
251
- logger.info(f"Upserted {total} documents to vector store")
252
  return total + parent_count
253
 
254
  def count(self) -> int:
 
255
  col = self.collection
256
  return int(col.count()) if col else 0
257
 
258
  def get_all_documents(self, limit: int = 5000) -> List[Dict[str, Any]]:
 
259
  col = self.collection
260
  if col is None:
261
  return []
@@ -272,6 +291,7 @@ class ChromaVectorDB:
272
  return docs
273
 
274
  def delete_documents(self, ids: Sequence[str]) -> int:
 
275
  if not ids:
276
  return 0
277
 
@@ -280,12 +300,14 @@ class ChromaVectorDB:
280
  return 0
281
 
282
  col.delete(ids=list(ids))
283
- logger.info(f"Deleted {len(ids)} documents from vector store")
284
  return len(ids)
285
 
286
  def get_parent_node(self, parent_id: str) -> Optional[Dict[str, Any]]:
 
287
  return self._parent_nodes.get(parent_id)
288
 
289
  @property
290
  def parent_nodes(self) -> Dict[str, Dict[str, Any]]:
 
291
  return self._parent_nodes
 
13
 
14
  @dataclass
15
  class ChromaConfig:
16
+ """Cấu hình cho ChromaDB."""
17
+
18
  def _default_persist_dir() -> str:
19
+ """Lấy đường dẫn mặc định cho persist directory."""
20
  repo_root = Path(__file__).resolve().parents[2]
21
  return str((repo_root / "data" / "chroma").resolve())
22
 
23
+ persist_dir: str = field(default_factory=_default_persist_dir) # Thư mục lưu DB
24
+ collection_name: str = "hust_rag_collection" # Tên collection
25
 
26
 
27
  class ChromaVectorDB:
28
+ """Wrapper cho ChromaDB với hỗ trợ Small-to-Big retrieval."""
29
+
30
  def __init__(
31
  self,
32
  embedder: Any,
33
  config: ChromaConfig | None = None,
34
  ):
35
+ """Khởi tạo ChromaDB với embedder và config."""
36
  self.embedder = embedder
37
  self.config = config or ChromaConfig()
38
  self._hasher = HashProcessor(verbose=False)
39
 
40
+ # Lưu trữ parent nodes (không embed, dùng cho Small-to-Big)
 
41
  self._parent_nodes_path = Path(self.config.persist_dir) / "parent_nodes.json"
42
  self._parent_nodes: Dict[str, Dict[str, Any]] = self._load_parent_nodes()
43
 
44
+ # Khởi tạo ChromaDB
45
  self._vs = Chroma(
46
  collection_name=self.config.collection_name,
47
  embedding_function=self.embedder,
48
  persist_directory=self.config.persist_dir,
49
  )
50
+ logger.info(f"Đã khởi tạo ChromaVectorDB: {self.config.collection_name}")
51
 
52
  def _load_parent_nodes(self) -> Dict[str, Dict[str, Any]]:
53
+ """Tải parent nodes từ file JSON."""
54
  if self._parent_nodes_path.exists():
55
  try:
56
  with open(self._parent_nodes_path, 'r', encoding='utf-8') as f:
57
  data = json.load(f)
58
+ logger.info(f"Đã tải {len(data)} parent nodes từ {self._parent_nodes_path}")
59
  return data
60
  except Exception as e:
61
+ logger.warning(f"Không thể tải parent nodes: {e}")
62
  return {}
63
 
64
  def _save_parent_nodes(self) -> None:
65
+ """Lưu parent nodes vào file JSON."""
66
  try:
67
  self._parent_nodes_path.parent.mkdir(parents=True, exist_ok=True)
68
  with open(self._parent_nodes_path, 'w', encoding='utf-8') as f:
69
  json.dump(self._parent_nodes, f, ensure_ascii=False, indent=2)
70
+ logger.info(f"Đã lưu {len(self._parent_nodes)} parent nodes vào {self._parent_nodes_path}")
71
  except Exception as e:
72
+ logger.warning(f"Không thể lưu parent nodes: {e}")
73
 
74
  @property
75
  def collection(self):
76
+ """Lấy collection gốc của ChromaDB."""
77
  return getattr(self._vs, "_collection", None)
78
 
79
  @property
80
  def vectorstore(self):
81
+ """Lấy LangChain Chroma vectorstore."""
82
  return self._vs
83
 
84
  def _flatten_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
85
+ """Chuyển metadata phức tạp thành format ChromaDB hỗ trợ."""
86
  out: Dict[str, Any] = {}
87
  for k, v in (metadata or {}).items():
88
  if v is None:
 
90
  if isinstance(v, (str, int, float, bool)):
91
  out[str(k)] = v
92
  elif isinstance(v, (list, tuple, set, dict)):
93
+ # Chuyển list/dict thành JSON string
94
  out[str(k)] = json.dumps(v, ensure_ascii=False)
95
  else:
96
  out[str(k)] = str(v)
97
  return out
98
 
99
  def _normalize_doc(self, doc: Any) -> Dict[str, Any]:
100
+ """Chuẩn hóa document từ nhiều format khác nhau thành dict."""
101
+ # Đã là dict
102
  if isinstance(doc, dict):
103
  return doc
104
+ # TextNode/BaseNode từ llama_index
 
105
  if hasattr(doc, "get_content") and hasattr(doc, "metadata"):
106
  return {
107
  "content": doc.get_content(),
108
  "metadata": dict(doc.metadata) if doc.metadata else {},
109
  }
110
+ # Document từ LangChain
 
111
  if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
112
  return {
113
  "content": doc.page_content,
114
  "metadata": dict(doc.metadata) if doc.metadata else {},
115
  }
116
+ raise TypeError(f"Không hỗ trợ loại document: {type(doc)}")
 
117
 
118
  def _to_documents(self, docs: Sequence[Any], ids: Sequence[str]) -> List[Document]:
119
+ """Chuyển danh sách docs thành LangChain Documents."""
120
  out: List[Document] = []
121
  for d, doc_id in zip(docs, ids):
122
  normalized = self._normalize_doc(d)
 
126
  return out
127
 
128
  def _doc_id(self, doc: Any) -> str:
129
+ """Tạo ID duy nhất cho document dựa trên nội dung."""
130
  normalized = self._normalize_doc(doc)
131
  md = normalized.get("metadata") or {}
132
  key = {
 
144
  ids: Optional[Sequence[str]] = None,
145
  batch_size: int = 128,
146
  ) -> int:
147
+ """Thêm documents vào vector store."""
148
  if not docs:
149
  return 0
150
 
151
  if ids is not None and len(ids) != len(docs):
152
+ raise ValueError("Số lượng ids phải bằng số lượng docs")
153
 
154
+ # Tách parent nodes (không embed) khỏi regular nodes
155
  regular_docs = []
156
  regular_ids = []
157
  parent_count = 0
 
162
  doc_id = ids[i] if ids else self._doc_id(d)
163
 
164
  if md.get("is_parent"):
165
+ # Lưu parent node riêng (cho Small-to-Big)
166
  parent_id = md.get("node_id", doc_id)
167
  self._parent_nodes[parent_id] = {
168
  "id": parent_id,
 
175
  regular_ids.append(doc_id)
176
 
177
  if parent_count > 0:
178
+ logger.info(f"Đã lưu {parent_count} parent nodes (không embed)")
179
+ self._save_parent_nodes()
180
 
181
  if not regular_docs:
182
  return parent_count
183
 
184
+ # Thêm theo batch
185
  bs = max(1, batch_size)
186
  total = 0
187
 
 
193
  try:
194
  self._vs.add_documents(lc_docs, ids=batch_ids)
195
  except TypeError:
196
+ # Fallback nếu add_documents không nhận ids
197
  texts = [d.page_content for d in lc_docs]
198
  metas = [d.metadata for d in lc_docs]
199
  self._vs.add_texts(texts=texts, metadatas=metas, ids=batch_ids)
200
  total += len(batch)
201
 
202
+ logger.info(f"Đã thêm {total} documents vào vector store")
203
  return total + parent_count
204
 
205
  def upsert_documents(
 
209
  ids: Optional[Sequence[str]] = None,
210
  batch_size: int = 128,
211
  ) -> int:
212
+ """Upsert documents (thêm mới hoặc cập nhật nếu đã tồn tại)."""
213
  if not docs:
214
  return 0
215
 
216
  if ids is not None and len(ids) != len(docs):
217
+ raise ValueError("Số lượng ids phải bằng số lượng docs")
218
 
219
+ # Tách parent nodes khỏi regular nodes
220
  regular_docs = []
221
  regular_ids = []
222
  parent_count = 0
 
227
  doc_id = ids[i] if ids else self._doc_id(d)
228
 
229
  if md.get("is_parent"):
230
+ # Lưu parent node riêng
231
  parent_id = md.get("node_id", doc_id)
232
  self._parent_nodes[parent_id] = {
233
  "id": parent_id,
 
240
  regular_ids.append(doc_id)
241
 
242
  if parent_count > 0:
243
+ logger.info(f"Đã lưu {parent_count} parent nodes (không embed)")
244
+ self._save_parent_nodes()
245
 
246
  if not regular_docs:
247
  return parent_count
 
249
  bs = max(1, batch_size)
250
  col = self.collection
251
 
252
+ # Fallback nếu không có collection
253
  if col is None:
254
  return self.add_documents(regular_docs, ids=regular_ids, batch_size=bs) + parent_count
255
 
256
+ # Upsert theo batch
257
  total = 0
258
  for start in range(0, len(regular_docs), bs):
259
  batch = regular_docs[start : start + bs]
 
265
  col.upsert(ids=batch_ids, documents=texts, metadatas=metas, embeddings=embs)
266
  total += len(batch)
267
 
268
+ logger.info(f"Đã upsert {total} documents vào vector store")
269
  return total + parent_count
270
 
271
  def count(self) -> int:
272
+ """Đếm số documents trong collection."""
273
  col = self.collection
274
  return int(col.count()) if col else 0
275
 
276
  def get_all_documents(self, limit: int = 5000) -> List[Dict[str, Any]]:
277
+ """Lấy tất cả documents từ collection."""
278
  col = self.collection
279
  if col is None:
280
  return []
 
291
  return docs
292
 
293
  def delete_documents(self, ids: Sequence[str]) -> int:
294
+ """Xóa documents theo danh sách IDs."""
295
  if not ids:
296
  return 0
297
 
 
300
  return 0
301
 
302
  col.delete(ids=list(ids))
303
+ logger.info(f"Đã xóa {len(ids)} documents khỏi vector store")
304
  return len(ids)
305
 
306
  def get_parent_node(self, parent_id: str) -> Optional[Dict[str, Any]]:
307
+ """Lấy parent node theo ID (cho Small-to-Big)."""
308
  return self._parent_nodes.get(parent_id)
309
 
310
  @property
311
  def parent_nodes(self) -> Dict[str, Dict[str, Any]]:
312
+ """Lấy tất cả parent nodes."""
313
  return self._parent_nodes
evaluation/eval_utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import sys
3
  import re
@@ -20,31 +22,38 @@ from core.rag.generator import RAGGenerator
20
 
21
 
22
  def strip_thinking(text: str) -> str:
 
23
  return re.sub(r'<think>.*?</think>\s*', '', text, flags=re.DOTALL).strip()
24
 
25
 
26
  def load_csv_data(csv_path: str, sample_size: int = 0) -> tuple[list, list]:
 
27
  questions, ground_truths = [], []
28
  with open(csv_path, 'r', encoding='utf-8') as f:
29
  for row in csv.DictReader(f):
30
  if row.get('question') and row.get('ground_truth'):
31
  questions.append(row['question'])
32
  ground_truths.append(row['ground_truth'])
 
 
33
  if sample_size > 0:
34
  questions = questions[:sample_size]
35
  ground_truths = ground_truths[:sample_size]
 
36
  return questions, ground_truths
37
 
38
 
39
  def init_rag() -> tuple[RAGGenerator, QwenEmbeddings, OpenAI]:
 
40
  embeddings = QwenEmbeddings(EmbeddingConfig())
41
  db = ChromaVectorDB(embedder=embeddings, config=ChromaConfig())
42
  retriever = Retriever(vector_db=db)
43
  rag = RAGGenerator(retriever=retriever)
44
 
 
45
  api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
46
  if not api_key:
47
- raise ValueError("Missing SILICONFLOW_API_KEY")
48
 
49
  llm_client = OpenAI(api_key=api_key, base_url="https://api.siliconflow.com/v1", timeout=60.0)
50
  return rag, embeddings, llm_client
@@ -58,14 +67,18 @@ def generate_answers(
58
  retrieval_mode: str = "hybrid_rerank",
59
  max_workers: int = 8,
60
  ) -> tuple[list, list]:
 
61
 
62
  def process(idx_q):
 
63
  idx, q = idx_q
64
  try:
 
65
  prepared = rag.retrieve_and_prepare(q, mode=retrieval_mode)
66
  if not prepared["results"]:
67
  return idx, "Không tìm thấy thông tin.", []
68
 
 
69
  resp = llm_client.chat.completions.create(
70
  model=llm_model,
71
  messages=[{"role": "user", "content": prepared["prompt"]}],
@@ -75,18 +88,20 @@ def generate_answers(
75
  answer = strip_thinking(resp.choices[0].message.content or "")
76
  return idx, answer, prepared["contexts"]
77
  except Exception as e:
78
- print(f" Q{idx+1} Error: {e}")
79
  return idx, "Không thể trả lời.", []
80
 
81
  n = len(questions)
82
  answers, contexts = [""] * n, [[] for _ in range(n)]
83
 
84
- print(f" Generating {n} answers...")
 
 
85
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
86
  futures = {executor.submit(process, (i, q)): i for i, q in enumerate(questions)}
87
  for i, future in enumerate(as_completed(futures), 1):
88
  idx, ans, ctx = future.result(timeout=120)
89
  answers[idx], contexts[idx] = ans, ctx
90
- print(f" [{i}/{n}] Done")
91
 
92
  return answers, contexts
 
1
+ """Các utility functions cho evaluation."""
2
+
3
  import os
4
  import sys
5
  import re
 
22
 
23
 
24
  def strip_thinking(text: str) -> str:
25
+ """Loại bỏ các block <think>...</think> từ output của LLM."""
26
  return re.sub(r'<think>.*?</think>\s*', '', text, flags=re.DOTALL).strip()
27
 
28
 
29
  def load_csv_data(csv_path: str, sample_size: int = 0) -> tuple[list, list]:
30
+ """Đọc dữ liệu câu hỏi và ground truth từ file CSV."""
31
  questions, ground_truths = [], []
32
  with open(csv_path, 'r', encoding='utf-8') as f:
33
  for row in csv.DictReader(f):
34
  if row.get('question') and row.get('ground_truth'):
35
  questions.append(row['question'])
36
  ground_truths.append(row['ground_truth'])
37
+
38
+ # Giới hạn số lượng sample
39
  if sample_size > 0:
40
  questions = questions[:sample_size]
41
  ground_truths = ground_truths[:sample_size]
42
+
43
  return questions, ground_truths
44
 
45
 
46
  def init_rag() -> tuple[RAGGenerator, QwenEmbeddings, OpenAI]:
47
+ """Khởi tạo các components RAG cho evaluation."""
48
  embeddings = QwenEmbeddings(EmbeddingConfig())
49
  db = ChromaVectorDB(embedder=embeddings, config=ChromaConfig())
50
  retriever = Retriever(vector_db=db)
51
  rag = RAGGenerator(retriever=retriever)
52
 
53
+ # Khởi tạo LLM client
54
  api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
55
  if not api_key:
56
+ raise ValueError("Chưa đặt SILICONFLOW_API_KEY")
57
 
58
  llm_client = OpenAI(api_key=api_key, base_url="https://api.siliconflow.com/v1", timeout=60.0)
59
  return rag, embeddings, llm_client
 
67
  retrieval_mode: str = "hybrid_rerank",
68
  max_workers: int = 8,
69
  ) -> tuple[list, list]:
70
+ """Generate câu trả lời cho danh sách câu hỏi với parallel processing."""
71
 
72
  def process(idx_q):
73
+ """Xử lý một câu hỏi: retrieve + generate."""
74
  idx, q = idx_q
75
  try:
76
+ # Retrieve và chuẩn bị context
77
  prepared = rag.retrieve_and_prepare(q, mode=retrieval_mode)
78
  if not prepared["results"]:
79
  return idx, "Không tìm thấy thông tin.", []
80
 
81
+ # Gọi LLM để generate answer
82
  resp = llm_client.chat.completions.create(
83
  model=llm_model,
84
  messages=[{"role": "user", "content": prepared["prompt"]}],
 
88
  answer = strip_thinking(resp.choices[0].message.content or "")
89
  return idx, answer, prepared["contexts"]
90
  except Exception as e:
91
+ print(f" Q{idx+1} Lỗi: {e}")
92
  return idx, "Không thể trả lời.", []
93
 
94
  n = len(questions)
95
  answers, contexts = [""] * n, [[] for _ in range(n)]
96
 
97
+ print(f" Đang generate {n} câu trả lời...")
98
+
99
+ # Xử lý song song với ThreadPoolExecutor
100
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
101
  futures = {executor.submit(process, (i, q)): i for i, q in enumerate(questions)}
102
  for i, future in enumerate(as_completed(futures), 1):
103
  idx, ans, ctx = future.result(timeout=120)
104
  answers[idx], contexts[idx] = ans, ctx
105
+ print(f" [{i}/{n}] Xong")
106
 
107
  return answers, contexts
evaluation/ragas_eval.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import sys
3
  import json
@@ -21,33 +23,34 @@ from ragas.run_config import RunConfig
21
 
22
  from evaluation.eval_utils import load_csv_data, init_rag, generate_answers
23
 
24
- # Config
25
- CSV_PATH = "data/data.csv"
26
- OUTPUT_DIR = "evaluation/results"
27
- LLM_MODEL = os.getenv("EVAL_LLM_MODEL", "nex-agi/DeepSeek-V3.1-Nex-N1")
28
  API_BASE = "https://api.siliconflow.com/v1"
29
 
30
 
31
  def run_evaluation(sample_size: int = 10, retrieval_mode: str = "hybrid_rerank") -> dict:
 
32
  print(f"\n{'='*60}")
33
  print(f"RAGAS EVALUATION - Mode: {retrieval_mode}")
34
  print(f"{'='*60}")
35
 
36
- # Init RAG components
37
  rag, embeddings, llm_client = init_rag()
38
 
39
- # Load data
40
  questions, ground_truths = load_csv_data(str(REPO_ROOT / CSV_PATH), sample_size)
41
- print(f" Loaded {len(questions)} samples")
42
 
43
- # Generate answers
44
  answers, contexts = generate_answers(
45
  rag, questions, llm_client,
46
  llm_model=LLM_MODEL,
47
  retrieval_mode=retrieval_mode,
48
  )
49
 
50
- # Setup RAGAS evaluator
51
  api_key = os.getenv("SILICONFLOW_API_KEY", "")
52
  evaluator_llm = LangchainLLMWrapper(ChatOpenAI(
53
  model=LLM_MODEL,
@@ -59,7 +62,7 @@ def run_evaluation(sample_size: int = 10, retrieval_mode: str = "hybrid_rerank")
59
  ))
60
  evaluator_embeddings = LangchainEmbeddingsWrapper(embeddings)
61
 
62
- # Create dataset
63
  dataset = Dataset.from_dict({
64
  "question": questions,
65
  "answer": answers,
@@ -67,18 +70,18 @@ def run_evaluation(sample_size: int = 10, retrieval_mode: str = "hybrid_rerank")
67
  "ground_truth": ground_truths,
68
  })
69
 
70
- # Run RAGAS evaluation
71
- print("\n Running RAGAS metrics...")
72
  results = evaluate(
73
  dataset=dataset,
74
  metrics=[
75
- faithfulness,
76
- answer_relevancy,
77
- context_precision,
78
- context_recall,
79
- RougeScore(rouge_type='rouge1', mode='fmeasure'),
80
- RougeScore(rouge_type='rouge2', mode='fmeasure'),
81
- RougeScore(rouge_type='rougeL', mode='fmeasure'),
82
  ],
83
  llm=evaluator_llm,
84
  embeddings=evaluator_embeddings,
@@ -86,65 +89,37 @@ def run_evaluation(sample_size: int = 10, retrieval_mode: str = "hybrid_rerank")
86
  run_config=RunConfig(max_workers=8, timeout=600, max_retries=3),
87
  )
88
 
89
- # Extract scores
90
  df = results.to_pandas()
91
  metric_cols = [c for c in df.columns if c not in ("question", "answer", "contexts", "ground_truth", "user_input", "response", "reference", "retrieved_contexts")]
92
 
 
93
  avg_scores = {}
94
  for col in metric_cols:
95
  values = df[col].dropna().tolist()
96
  if values:
97
  avg_scores[col] = sum(values) / len(values)
98
 
99
- # Save results
100
  out_path = REPO_ROOT / OUTPUT_DIR
101
  out_path.mkdir(parents=True, exist_ok=True)
102
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
103
-
104
- # JSON
105
- json_path = out_path / f"ragas_{retrieval_mode}_{timestamp}.json"
106
- with open(json_path, 'w', encoding='utf-8') as f:
107
- json.dump({
108
- "timestamp": timestamp,
109
- "retrieval_mode": retrieval_mode,
110
- "sample_size": len(questions),
111
- "avg_scores": avg_scores,
112
- "samples": [
113
- {"question": q, "answer": a, "ground_truth": gt, "contexts": ctx}
114
- for q, a, gt, ctx in zip(questions, answers, ground_truths, contexts)
115
- ]
116
- }, f, ensure_ascii=False, indent=2)
117
-
118
- # CSV
119
  csv_path = out_path / f"ragas_{retrieval_mode}_{timestamp}.csv"
120
  with open(csv_path, 'w', encoding='utf-8') as f:
121
  f.write("retrieval_mode,sample_size," + ",".join(avg_scores.keys()) + "\n")
122
  f.write(f"{retrieval_mode},{len(questions)}," + ",".join(f"{v:.4f}" for v in avg_scores.values()) + "\n")
123
 
124
- # Print summary
125
  print(f"\n{'='*60}")
126
- print(f"RESULTS - {retrieval_mode} ({len(questions)} samples)")
127
  print(f"{'='*60}")
128
  for metric, score in avg_scores.items():
129
  bar = "#" * int(score * 20) + "-" * (20 - int(score * 20))
130
  print(f" {metric:25} [{bar}] {score:.4f}")
131
 
132
- print(f"\nSaved: {json_path}")
133
- print(f"Saved: {csv_path}")
134
-
135
- return avg_scores
136
-
137
-
138
- if __name__ == "__main__":
139
- import argparse
140
- parser = argparse.ArgumentParser(description="RAGAS Evaluation")
141
- parser.add_argument("--samples", type=int, default=10, help="Number of samples")
142
- parser.add_argument("--mode", type=str, default="hybrid_rerank",
143
- choices=["vector_only", "bm25_only", "hybrid", "hybrid_rerank", "all"])
144
- args = parser.parse_args()
145
 
146
- if args.mode == "all":
147
- for mode in ["vector_only", "bm25_only", "hybrid", "hybrid_rerank"]:
148
- run_evaluation(args.samples, mode)
149
- else:
150
- run_evaluation(args.samples, args.mode)
 
1
+ """Script đánh giá RAG bằng RAGAS framework."""
2
+
3
  import os
4
  import sys
5
  import json
 
23
 
24
  from evaluation.eval_utils import load_csv_data, init_rag, generate_answers
25
 
26
+ # Cấu hình
27
+ CSV_PATH = "data/data.csv" # File dữ liệu test
28
+ OUTPUT_DIR = "evaluation/results" # Thư mục output
29
+ LLM_MODEL = os.getenv("EVAL_LLM_MODEL", "nex-agi/DeepSeek-V3.1-Nex-N1") # Model đánh giá
30
  API_BASE = "https://api.siliconflow.com/v1"
31
 
32
 
33
  def run_evaluation(sample_size: int = 10, retrieval_mode: str = "hybrid_rerank") -> dict:
34
+ """Chạy đánh giá RAGAS trên dữ liệu test."""
35
  print(f"\n{'='*60}")
36
  print(f"RAGAS EVALUATION - Mode: {retrieval_mode}")
37
  print(f"{'='*60}")
38
 
39
+ # Khởi tạo RAG components
40
  rag, embeddings, llm_client = init_rag()
41
 
42
+ # Tải dữ liệu test
43
  questions, ground_truths = load_csv_data(str(REPO_ROOT / CSV_PATH), sample_size)
44
+ print(f" Đã tải {len(questions)} samples")
45
 
46
+ # Generate câu trả lời
47
  answers, contexts = generate_answers(
48
  rag, questions, llm_client,
49
  llm_model=LLM_MODEL,
50
  retrieval_mode=retrieval_mode,
51
  )
52
 
53
+ # Thiết lập RAGAS evaluator
54
  api_key = os.getenv("SILICONFLOW_API_KEY", "")
55
  evaluator_llm = LangchainLLMWrapper(ChatOpenAI(
56
  model=LLM_MODEL,
 
62
  ))
63
  evaluator_embeddings = LangchainEmbeddingsWrapper(embeddings)
64
 
65
+ # Chuyển dữ liệu thành format Dataset
66
  dataset = Dataset.from_dict({
67
  "question": questions,
68
  "answer": answers,
 
70
  "ground_truth": ground_truths,
71
  })
72
 
73
+ # Chạy đánh giá RAGAS
74
+ print("\n Đang chạy RAGAS metrics...")
75
  results = evaluate(
76
  dataset=dataset,
77
  metrics=[
78
+ faithfulness, # Độ trung thực với context
79
+ answer_relevancy, # Độ liên quan của câu trả lời
80
+ context_precision, # Độ chính xác của context
81
+ context_recall, # Độ bao phủ của context
82
+ RougeScore(rouge_type='rouge1', mode='fmeasure'), # ROUGE-1
83
+ RougeScore(rouge_type='rouge2', mode='fmeasure'), # ROUGE-2
84
+ RougeScore(rouge_type='rougeL', mode='fmeasure'), # ROUGE-L
85
  ],
86
  llm=evaluator_llm,
87
  embeddings=evaluator_embeddings,
 
89
  run_config=RunConfig(max_workers=8, timeout=600, max_retries=3),
90
  )
91
 
92
+ # Trích xuất điểm số
93
  df = results.to_pandas()
94
  metric_cols = [c for c in df.columns if c not in ("question", "answer", "contexts", "ground_truth", "user_input", "response", "reference", "retrieved_contexts")]
95
 
96
+ # Tính điểm trung bình cho mỗi metric
97
  avg_scores = {}
98
  for col in metric_cols:
99
  values = df[col].dropna().tolist()
100
  if values:
101
  avg_scores[col] = sum(values) / len(values)
102
 
103
+ # Lưu kết quả
104
  out_path = REPO_ROOT / OUTPUT_DIR
105
  out_path.mkdir(parents=True, exist_ok=True)
106
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
107
+
108
+ # Lưu file CSV (tóm tắt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  csv_path = out_path / f"ragas_{retrieval_mode}_{timestamp}.csv"
110
  with open(csv_path, 'w', encoding='utf-8') as f:
111
  f.write("retrieval_mode,sample_size," + ",".join(avg_scores.keys()) + "\n")
112
  f.write(f"{retrieval_mode},{len(questions)}," + ",".join(f"{v:.4f}" for v in avg_scores.values()) + "\n")
113
 
114
+ # In kết quả
115
  print(f"\n{'='*60}")
116
+ print(f"KẾT QUẢ - {retrieval_mode} ({len(questions)} samples)")
117
  print(f"{'='*60}")
118
  for metric, score in avg_scores.items():
119
  bar = "#" * int(score * 20) + "-" * (20 - int(score * 20))
120
  print(f" {metric:25} [{bar}] {score:.4f}")
121
 
122
+ print(f"\nĐã lưu: {json_path}")
123
+ print(f"Đã lưu: {csv_path}")
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ return avg_scores
 
 
 
 
scripts/build_data.py CHANGED
@@ -1,8 +1,10 @@
 
 
1
  import sys
 
2
  from pathlib import Path
3
  from dotenv import find_dotenv, load_dotenv
4
 
5
- # Load .env file
6
  load_dotenv(find_dotenv(usecwd=True))
7
 
8
  REPO_ROOT = Path(__file__).resolve().parents[1]
@@ -14,12 +16,11 @@ from core.rag.embedding_model import EmbeddingConfig, QwenEmbeddings
14
  from core.rag.vector_store import ChromaConfig, ChromaVectorDB
15
  from core.hash_file.hash_file import HashProcessor
16
 
17
- # Global hash processor instance
18
  _hasher = HashProcessor(verbose=False)
19
 
20
 
21
  def get_db_file_info(db: ChromaVectorDB) -> dict:
22
- """Get mapping of source_file -> set of doc IDs in DB."""
23
  docs = db.get_all_documents()
24
  file_to_ids = {}
25
  file_to_hash = {}
@@ -35,7 +36,7 @@ def get_db_file_info(db: ChromaVectorDB) -> dict:
35
  file_to_ids[source] = set()
36
  file_to_ids[source].add(doc_id)
37
 
38
- # Store first hash we see for this file
39
  if source not in file_to_hash and content_hash:
40
  file_to_hash[source] = content_hash
41
 
@@ -43,64 +44,65 @@ def get_db_file_info(db: ChromaVectorDB) -> dict:
43
 
44
 
45
  def main():
46
- import argparse
47
- parser = argparse.ArgumentParser()
48
- parser.add_argument("--force", action="store_true", help="Force rebuild all files")
49
- parser.add_argument("--no-delete", action="store_true", help="Don't delete orphaned docs")
50
  args = parser.parse_args()
51
 
52
  print("=" * 60)
53
  print("BUILD HUST RAG DATABASE")
54
  print("=" * 60)
55
 
56
- print("\n[1/5] Initializing embedder...")
 
57
  emb_cfg = EmbeddingConfig()
58
  emb = QwenEmbeddings(emb_cfg)
59
  print(f" Model: {emb_cfg.model}")
60
  print(f" API: {emb_cfg.api_base_url}")
61
 
62
- print("\n[2/5] Initializing ChromaDB...")
 
63
  db_cfg = ChromaConfig()
64
  db = ChromaVectorDB(embedder=emb, config=db_cfg)
65
  old_count = db.count()
66
  print(f" Collection: {db_cfg.collection_name}")
67
- print(f" Current docs: {old_count}")
68
 
69
- # Get current state of DB
70
  db_info = {"ids": {}, "hashes": {}}
71
  if not args.force and old_count > 0:
72
- print("\n Scanning existing documents...")
73
  db_info = get_db_file_info(db)
74
- print(f" Found {len(db_info['ids'])} source files in DB")
75
 
76
- # Scan markdown files
77
- print("\n[3/5] Scanning markdown files...")
78
  root = REPO_ROOT / "data" / "data_process"
79
  md_files = sorted(root.rglob("*.md"))
80
- print(f" Found {len(md_files)} markdown files on disk")
81
 
82
- # Build set of current file names
83
  current_files = {f.name for f in md_files}
84
  db_files = set(db_info["ids"].keys())
85
 
86
- # Find files to delete (in DB but not on disk)
87
  files_to_delete = db_files - current_files
88
 
89
- # Delete orphaned documents
90
  deleted_count = 0
91
  if files_to_delete and not args.no_delete:
92
- print(f"\n[4/5] Cleaning up {len(files_to_delete)} deleted files...")
93
  for filename in files_to_delete:
94
  doc_ids = list(db_info["ids"].get(filename, []))
95
  if doc_ids:
96
  db.delete_documents(doc_ids)
97
  deleted_count += len(doc_ids)
98
- print(f" Deleted: {filename} ({len(doc_ids)} chunks)")
99
  else:
100
- print("\n[4/5] No files to delete")
101
 
102
- # Process files (add new, update changed)
103
- print("\n[5/5] Processing markdown files...")
104
  total_added = 0
105
  total_updated = 0
106
  skipped = 0
@@ -110,16 +112,16 @@ def main():
110
  db_hash = db_info["hashes"].get(f.name, "")
111
  existing_ids = db_info["ids"].get(f.name, set())
112
 
113
- # Skip if hash matches (file unchanged)
114
  if not args.force and db_hash == file_hash:
115
- print(f" [{i}/{len(md_files)}] {f.name}: SKIP (unchanged)")
116
  skipped += 1
117
  continue
118
 
119
- # If file changed, delete old chunks first
120
  if existing_ids and not args.force:
121
  db.delete_documents(list(existing_ids))
122
- print(f" [{i}/{len(md_files)}] {f.name}: UPDATE (deleted {len(existing_ids)} old chunks)")
123
  is_update = True
124
  else:
125
  is_update = False
@@ -127,7 +129,7 @@ def main():
127
  try:
128
  docs = chunk_markdown_file(f)
129
  if docs:
130
- # Add content_hash to metadata for future change detection
131
  for doc in docs:
132
  if hasattr(doc, 'metadata'):
133
  doc.metadata["content_hash"] = file_hash
@@ -135,29 +137,29 @@ def main():
135
  doc["metadata"]["content_hash"] = file_hash
136
 
137
  n = db.upsert_documents(docs)
138
-
139
  if is_update:
140
  total_updated += n
141
- print(f" [{i}/{len(md_files)}] {f.name}: +{n} new chunks")
142
  else:
143
  total_added += n
144
  print(f" [{i}/{len(md_files)}] {f.name}: {n} chunks")
145
  else:
146
- print(f" [{i}/{len(md_files)}] {f.name}: SKIP (no chunks)")
147
  except Exception as e:
148
- print(f" [{i}/{len(md_files)}] {f.name}: ERROR - {e}")
149
 
 
150
  new_count = db.count()
151
  print(f"\n{'=' * 60}")
152
- print("SUMMARY")
153
  print("=" * 60)
154
- print(f" Deleted (orphaned): {deleted_count} chunks")
155
- print(f" Updated: {total_updated} chunks")
156
- print(f" Added new: {total_added} chunks")
157
- print(f" Skipped (unchanged): {skipped} files")
158
- print(f" DB count: {old_count} -> {new_count} ({new_count - old_count:+d})")
159
 
160
- print("\nDONE!")
161
 
162
 
163
  if __name__ == "__main__":
 
1
+ """Script build ChromaDB từ markdown files với incremental update."""
2
+
3
  import sys
4
+ import argparse
5
  from pathlib import Path
6
  from dotenv import find_dotenv, load_dotenv
7
 
 
8
  load_dotenv(find_dotenv(usecwd=True))
9
 
10
  REPO_ROOT = Path(__file__).resolve().parents[1]
 
16
  from core.rag.vector_store import ChromaConfig, ChromaVectorDB
17
  from core.hash_file.hash_file import HashProcessor
18
 
 
19
  _hasher = HashProcessor(verbose=False)
20
 
21
 
22
  def get_db_file_info(db: ChromaVectorDB) -> dict:
23
+ """Lấy thông tin files đã trong DB (IDs hash)."""
24
  docs = db.get_all_documents()
25
  file_to_ids = {}
26
  file_to_hash = {}
 
36
  file_to_ids[source] = set()
37
  file_to_ids[source].add(doc_id)
38
 
39
+ # Lưu hash đầu tiên tìm thấy cho file
40
  if source not in file_to_hash and content_hash:
41
  file_to_hash[source] = content_hash
42
 
 
44
 
45
 
46
  def main():
47
+ parser = argparse.ArgumentParser(description="Build ChromaDB từ markdown files")
48
+ parser.add_argument("--force", action="store_true", help="Build lại tất cả files")
49
+ parser.add_argument("--no-delete", action="store_true", help="Không xóa docs orphaned")
 
50
  args = parser.parse_args()
51
 
52
  print("=" * 60)
53
  print("BUILD HUST RAG DATABASE")
54
  print("=" * 60)
55
 
56
+ # Bước 1: Khởi tạo embedder
57
+ print("\n[1/5] Khởi tạo embedder...")
58
  emb_cfg = EmbeddingConfig()
59
  emb = QwenEmbeddings(emb_cfg)
60
  print(f" Model: {emb_cfg.model}")
61
  print(f" API: {emb_cfg.api_base_url}")
62
 
63
+ # Bước 2: Khởi tạo ChromaDB
64
+ print("\n[2/5] Khởi tạo ChromaDB...")
65
  db_cfg = ChromaConfig()
66
  db = ChromaVectorDB(embedder=emb, config=db_cfg)
67
  old_count = db.count()
68
  print(f" Collection: {db_cfg.collection_name}")
69
+ print(f" Số docs hiện tại: {old_count}")
70
 
71
+ # Lấy trạng thái hiện tại của DB
72
  db_info = {"ids": {}, "hashes": {}}
73
  if not args.force and old_count > 0:
74
+ print("\n Đang quét documents trong DB...")
75
  db_info = get_db_file_info(db)
76
+ print(f" Tìm thấy {len(db_info['ids'])} source files trong DB")
77
 
78
+ # Bước 3: Quét markdown files
79
+ print("\n[3/5] Quét markdown files...")
80
  root = REPO_ROOT / "data" / "data_process"
81
  md_files = sorted(root.rglob("*.md"))
82
+ print(f" Tìm thấy {len(md_files)} markdown files trên disk")
83
 
84
+ # So sánh files trên disk vs trong DB
85
  current_files = {f.name for f in md_files}
86
  db_files = set(db_info["ids"].keys())
87
 
88
+ # Tìm files cần xóa ( trong DB nhưng không trên disk)
89
  files_to_delete = db_files - current_files
90
 
91
+ # Bước 4: Xóa docs orphaned
92
  deleted_count = 0
93
  if files_to_delete and not args.no_delete:
94
+ print(f"\n[4/5] Dọn dẹp {len(files_to_delete)} files đã xóa...")
95
  for filename in files_to_delete:
96
  doc_ids = list(db_info["ids"].get(filename, []))
97
  if doc_ids:
98
  db.delete_documents(doc_ids)
99
  deleted_count += len(doc_ids)
100
+ print(f" Đã xóa: {filename} ({len(doc_ids)} chunks)")
101
  else:
102
+ print("\n[4/5] Không files cần xóa")
103
 
104
+ # Bước 5: Xử lý markdown files (thêm mới, cập nhật)
105
+ print("\n[5/5] Xử markdown files...")
106
  total_added = 0
107
  total_updated = 0
108
  skipped = 0
 
112
  db_hash = db_info["hashes"].get(f.name, "")
113
  existing_ids = db_info["ids"].get(f.name, set())
114
 
115
+ # Bỏ qua nếu hash khớp (file không thay đổi)
116
  if not args.force and db_hash == file_hash:
117
+ print(f" [{i}/{len(md_files)}] {f.name}: BỎ QUA (không đổi)")
118
  skipped += 1
119
  continue
120
 
121
+ # Nếu file thay đổi, xóa chunks cũ trước
122
  if existing_ids and not args.force:
123
  db.delete_documents(list(existing_ids))
124
+ print(f" [{i}/{len(md_files)}] {f.name}: CẬP NHẬT (xóa {len(existing_ids)} chunks )")
125
  is_update = True
126
  else:
127
  is_update = False
 
129
  try:
130
  docs = chunk_markdown_file(f)
131
  if docs:
132
+ # Thêm hash vào metadata để phát hiện thay đổi lần sau
133
  for doc in docs:
134
  if hasattr(doc, 'metadata'):
135
  doc.metadata["content_hash"] = file_hash
 
137
  doc["metadata"]["content_hash"] = file_hash
138
 
139
  n = db.upsert_documents(docs)
 
140
  if is_update:
141
  total_updated += n
142
+ print(f" [{i}/{len(md_files)}] {f.name}: +{n} chunks mới")
143
  else:
144
  total_added += n
145
  print(f" [{i}/{len(md_files)}] {f.name}: {n} chunks")
146
  else:
147
+ print(f" [{i}/{len(md_files)}] {f.name}: BỎ QUA (không chunks)")
148
  except Exception as e:
149
+ print(f" [{i}/{len(md_files)}] {f.name}: LỖI - {e}")
150
 
151
+ # Tổng kết
152
  new_count = db.count()
153
  print(f"\n{'=' * 60}")
154
+ print("TỔNG KẾT")
155
  print("=" * 60)
156
+ print(f" Đã xóa (orphaned): {deleted_count} chunks")
157
+ print(f" Đã cập nhật: {total_updated} chunks")
158
+ print(f" Đã thêm mới: {total_added} chunks")
159
+ print(f" Đã bỏ qua: {skipped} files")
160
+ print(f" Số docs trong DB: {old_count} -> {new_count} ({new_count - old_count:+d})")
161
 
162
+ print("\nHOÀN TẤT!")
163
 
164
 
165
  if __name__ == "__main__":
scripts/run_eval.py CHANGED
@@ -1,4 +1,5 @@
1
  import sys
 
2
  from pathlib import Path
3
 
4
  REPO_ROOT = Path(__file__).resolve().parents[1]
@@ -7,18 +8,19 @@ if str(REPO_ROOT) not in sys.path:
7
 
8
 
9
  def main():
10
- import argparse
11
- parser = argparse.ArgumentParser(description="RAG Evaluation")
12
- parser.add_argument("--samples", type=int, default=10, help="Number of samples (0 = all)")
13
  parser.add_argument("--mode", type=str, default="hybrid_rerank",
14
- choices=["vector_only", "bm25_only", "hybrid", "hybrid_rerank", "all"])
 
15
  args = parser.parse_args()
16
 
17
  from evaluation.ragas_eval import run_evaluation
18
 
19
  if args.mode == "all":
 
20
  print("\n" + "=" * 60)
21
- print("RUNNING ALL RETRIEVAL MODES")
22
  print("=" * 60)
23
  for mode in ["vector_only", "bm25_only", "hybrid", "hybrid_rerank"]:
24
  run_evaluation(args.samples, mode)
 
1
  import sys
2
+ import argparse
3
  from pathlib import Path
4
 
5
  REPO_ROOT = Path(__file__).resolve().parents[1]
 
8
 
9
 
10
  def main():
11
+ parser = argparse.ArgumentParser(description="Đánh giá RAG bằng RAGAS")
12
+ parser.add_argument("--samples", type=int, default=10, help="Số lượng samples (0 = tất cả)")
 
13
  parser.add_argument("--mode", type=str, default="hybrid_rerank",
14
+ choices=["vector_only", "bm25_only", "hybrid", "hybrid_rerank", "all"],
15
+ help="Chế độ retrieval")
16
  args = parser.parse_args()
17
 
18
  from evaluation.ragas_eval import run_evaluation
19
 
20
  if args.mode == "all":
21
+ # Chạy tất cả các chế độ retrieval
22
  print("\n" + "=" * 60)
23
+ print("CHẠY TẤT CẢ CÁC CHẾ ĐỘ RETRIEVAL")
24
  print("=" * 60)
25
  for mode in ["vector_only", "bm25_only", "hybrid", "hybrid_rerank"]:
26
  run_evaluation(args.samples, mode)
test/parse_data_hash_test.py DELETED
@@ -1,102 +0,0 @@
1
- import os
2
- import sys
3
- import random
4
- import shutil
5
- from pathlib import Path
6
-
7
- # Ensure project root is on sys.path
8
- _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
9
- if _PROJECT_ROOT not in sys.path:
10
- sys.path.insert(0, _PROJECT_ROOT)
11
-
12
- from core.preprocessing.docling_processor import DoclingProcessor
13
-
14
- def get_random_local_pdf(source_dir: str):
15
- if not os.path.exists(source_dir):
16
- return None
17
-
18
- files = [f for f in os.listdir(source_dir) if f.lower().endswith('.pdf')]
19
- if not files:
20
- return None
21
-
22
- return os.path.join(source_dir, random.choice(files))
23
-
24
- def main(output_dir=None, use_ocr=False):
25
- # Setup paths
26
- source_dir = os.path.join(_PROJECT_ROOT, "data", "files")
27
- if output_dir is None:
28
- output_dir = os.path.join(_PROJECT_ROOT, "data", "test_output")
29
-
30
- # Clean up old test output
31
- if os.path.exists(output_dir):
32
- shutil.rmtree(output_dir)
33
- os.makedirs(output_dir, exist_ok=True)
34
-
35
- print(f"Đang tìm file PDF để test...")
36
-
37
- # 1. Thử lấy từ local data/files
38
- file_path = get_random_local_pdf(source_dir)
39
-
40
- if not file_path:
41
- print(f"Không tìm thấy file PDF nào trong {source_dir}")
42
- print("Hãy chạy 'python core/hash_file/hash_data_goc.py' để tải dữ liệu trước.")
43
- return 1
44
-
45
- filename = os.path.basename(file_path)
46
- print(f"Đã chọn file test: {filename}")
47
- print(f"Đường dẫn: {file_path}")
48
-
49
- try:
50
- # Khởi tạo processor
51
- print("Khởi tạo DoclingProcessor...")
52
- processor = DoclingProcessor(
53
- output_dir=output_dir,
54
- use_ocr=use_ocr,
55
- timeout=None
56
- )
57
-
58
- # Parse file
59
- print(f"Bắt đầu parse...")
60
- result = processor.parse_document(file_path)
61
-
62
- if result:
63
- print(f"Test thành công!")
64
-
65
- # Kiểm tra kết quả
66
- output_files = os.listdir(output_dir)
67
- md_files = [f for f in output_files if f.endswith('.md')]
68
-
69
- if md_files:
70
- print(f"File output: {md_files[0]}")
71
- print(f"Thư mục output: {output_dir}")
72
-
73
- # In thống kê sơ bộ cho Markdown
74
- content_len = len(result)
75
- preview = result[:200].replace('\n', ' ') + "..."
76
- print(f" Kích thước: {content_len} ký tự")
77
- print(f" Preview: {preview}")
78
- else:
79
- print(" Không tìm thấy file Markdown output dù hàm trả về kết quả.")
80
- else:
81
- print("Test thất bại: Hàm parse trả về None")
82
- return 1
83
-
84
- return 0
85
-
86
- except Exception as e:
87
- print(f"Lỗi ngoại lệ: {e}")
88
- import traceback
89
- traceback.print_exc()
90
- return 1
91
-
92
- if __name__ == "__main__":
93
- import argparse
94
- parser = argparse.ArgumentParser(description="Test Docling với 1 file PDF ngẫu nhiên từ data/files")
95
- parser.add_argument("--output", help="Thư mục output cho test (mặc định: data/test_output)")
96
- parser.add_argument("--ocr", action="store_true", help="Bật OCR")
97
- args = parser.parse_args()
98
-
99
- sys.exit(main(
100
- output_dir=args.output,
101
- use_ocr=args.ocr
102
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test/test_chunk.py CHANGED
@@ -1,47 +1,57 @@
 
 
1
  import sys
2
  sys.path.insert(0, "/home/bahung/DoAn")
3
 
 
 
 
4
  from core.rag.chunk import chunk_markdown_file
5
 
 
6
  test_file = "data/data_process/chuong_trinh_dao_tao/1.1. Kỹ thuật Cơ điện tử.md"
7
 
8
  print("=" * 70)
9
  print(f" File: {test_file}")
10
  print("=" * 70)
11
 
12
- # Now returns List[BaseNode] instead of List[Dict]
13
  nodes = chunk_markdown_file(test_file)
14
 
15
- print(f"\n Total nodes: {len(nodes)}\n")
16
 
 
17
  for i, node in enumerate(nodes):
18
  content = node.get_content()
19
  metadata = node.metadata
20
 
21
  print(f"\n{'─' * 70}")
22
  print(f" NODE #{i}")
23
- print(f" Type: {type(node).__name__}")
24
- print(f" Length: {len(content)} chars")
25
  if metadata:
26
  print(f" Metadata: {metadata}")
27
  print(f"{'─' * 70}")
 
 
28
  content_preview = content[:200]
29
  if len(content) > 200:
30
  content_preview += "..."
31
  print(content_preview)
32
 
 
33
  with open("test_chunk.md", "w", encoding="utf-8") as f:
34
  for i, node in enumerate(nodes):
35
  content = node.get_content()
36
  metadata = node.metadata
37
 
38
  f.write(f"# NODE {i}\n")
39
- f.write(f"**Type:** {type(node).__name__}\n\n")
40
  f.write("**Metadata:**\n")
41
  for key, value in metadata.items():
42
  f.write(f"- {key}: {value}\n")
43
- f.write("\n**Content:**\n")
44
  f.write(content)
45
  f.write("\n\n---\n\n")
46
 
47
- print("\n Done")
 
1
+ """Script test chunking markdown file."""
2
+
3
  import sys
4
  sys.path.insert(0, "/home/bahung/DoAn")
5
 
6
+ from dotenv import load_dotenv
7
+ load_dotenv() # Load biến môi trường từ .env
8
+
9
  from core.rag.chunk import chunk_markdown_file
10
 
11
+ # File test
12
  test_file = "data/data_process/chuong_trinh_dao_tao/1.1. Kỹ thuật Cơ điện tử.md"
13
 
14
  print("=" * 70)
15
  print(f" File: {test_file}")
16
  print("=" * 70)
17
 
18
+ # Chunk file markdown
19
  nodes = chunk_markdown_file(test_file)
20
 
21
+ print(f"\n Tổng số nodes: {len(nodes)}\n")
22
 
23
+ # Hiển thị thông tin từng node
24
  for i, node in enumerate(nodes):
25
  content = node.get_content()
26
  metadata = node.metadata
27
 
28
  print(f"\n{'─' * 70}")
29
  print(f" NODE #{i}")
30
+ print(f" Loại: {type(node).__name__}")
31
+ print(f" Độ dài: {len(content)} ký tự")
32
  if metadata:
33
  print(f" Metadata: {metadata}")
34
  print(f"{'─' * 70}")
35
+
36
+ # Preview nội dung (tối đa 200 ký tự)
37
  content_preview = content[:200]
38
  if len(content) > 200:
39
  content_preview += "..."
40
  print(content_preview)
41
 
42
+ # Lưu kết quả ra file markdown để dễ xem
43
  with open("test_chunk.md", "w", encoding="utf-8") as f:
44
  for i, node in enumerate(nodes):
45
  content = node.get_content()
46
  metadata = node.metadata
47
 
48
  f.write(f"# NODE {i}\n")
49
+ f.write(f"**Loại:** {type(node).__name__}\n\n")
50
  f.write("**Metadata:**\n")
51
  for key, value in metadata.items():
52
  f.write(f"- {key}: {value}\n")
53
+ f.write("\n**Nội dung:**\n")
54
  f.write(content)
55
  f.write("\n\n---\n\n")
56
 
57
+ print("\n Hoàn tất!")