Spaces:
Sleeping
Sleeping
HaRin2806 commited on
Commit ·
76a8f20
1
Parent(s): 59a1c47
fix bug
Browse files- core/data_processor.py +10 -60
- core/embedding_model.py +251 -81
- core/rag_pipeline.py +52 -76
core/data_processor.py
CHANGED
|
@@ -5,7 +5,6 @@ import logging
|
|
| 5 |
import datetime
|
| 6 |
from typing import Dict, List, Any, Union, Tuple
|
| 7 |
|
| 8 |
-
# Cấu hình logging
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
|
| 11 |
class DataProcessor:
|
|
@@ -26,18 +25,14 @@ class DataProcessor:
|
|
| 26 |
"""Tải tất cả dữ liệu từ các thư mục con trong data"""
|
| 27 |
logger.info(f"Đang tải dữ liệu từ thư mục: {self.data_dir}")
|
| 28 |
|
| 29 |
-
# Quét qua tất cả thư mục trong data
|
| 30 |
for item in os.listdir(self.data_dir):
|
| 31 |
folder_path = os.path.join(self.data_dir, item)
|
| 32 |
|
| 33 |
-
# Kiểm tra xem đây có phải là thư mục không
|
| 34 |
if os.path.isdir(folder_path):
|
| 35 |
metadata_file = os.path.join(folder_path, "metadata.json")
|
| 36 |
|
| 37 |
-
# Nếu có file metadata.json
|
| 38 |
if os.path.exists(metadata_file):
|
| 39 |
try:
|
| 40 |
-
# Tải metadata
|
| 41 |
with open(metadata_file, 'r', encoding='utf-8') as f:
|
| 42 |
content = f.read()
|
| 43 |
if not content.strip():
|
|
@@ -45,7 +40,6 @@ class DataProcessor:
|
|
| 45 |
continue
|
| 46 |
folder_metadata = json.loads(content)
|
| 47 |
|
| 48 |
-
# Xác định ID của thư mục
|
| 49 |
folder_id = None
|
| 50 |
if "bai_info" in folder_metadata:
|
| 51 |
folder_id = folder_metadata["bai_info"].get("id", item)
|
|
@@ -54,10 +48,8 @@ class DataProcessor:
|
|
| 54 |
else:
|
| 55 |
folder_id = item
|
| 56 |
|
| 57 |
-
# Lưu metadata vào từ điển
|
| 58 |
self.metadata[folder_id] = folder_metadata
|
| 59 |
|
| 60 |
-
# Tải tất cả chunks, tables và figures
|
| 61 |
self._load_content_from_metadata(folder_path, folder_metadata)
|
| 62 |
|
| 63 |
logger.info(f"Đã tải xong thư mục: {item}")
|
|
@@ -68,33 +60,28 @@ class DataProcessor:
|
|
| 68 |
|
| 69 |
def _load_content_from_metadata(self, folder_path: str, folder_metadata: Dict[str, Any]):
|
| 70 |
"""Tải nội dung chunks, tables và figures từ metadata"""
|
| 71 |
-
# Tải chunks
|
| 72 |
for chunk_meta in folder_metadata.get("chunks", []):
|
| 73 |
chunk_id = chunk_meta.get("id")
|
| 74 |
chunk_path = os.path.join(folder_path, "chunks", f"{chunk_id}.md")
|
| 75 |
|
| 76 |
-
chunk_data = chunk_meta.copy()
|
| 77 |
|
| 78 |
-
# Thêm nội dung từ file markdown nếu tồn tại
|
| 79 |
if os.path.exists(chunk_path):
|
| 80 |
with open(chunk_path, 'r', encoding='utf-8') as f:
|
| 81 |
content = f.read()
|
| 82 |
chunk_data["content"] = self._extract_content_from_markdown(content)
|
| 83 |
else:
|
| 84 |
-
# Nếu không tìm thấy file, tạo nội dung mẫu và ghi log ở debug level
|
| 85 |
chunk_data["content"] = f"Nội dung cho {chunk_id} không tìm thấy."
|
| 86 |
logger.debug(f"Không tìm thấy file chunk: {chunk_path}")
|
| 87 |
|
| 88 |
self.chunks.append(chunk_data)
|
| 89 |
|
| 90 |
-
# Tải tables
|
| 91 |
for table_meta in folder_metadata.get("tables", []):
|
| 92 |
table_id = table_meta.get("id")
|
| 93 |
table_path = os.path.join(folder_path, "tables", f"{table_id}.md")
|
| 94 |
|
| 95 |
table_data = table_meta.copy()
|
| 96 |
|
| 97 |
-
# Thêm nội dung từ file markdown nếu tồn tại
|
| 98 |
if os.path.exists(table_path):
|
| 99 |
with open(table_path, 'r', encoding='utf-8') as f:
|
| 100 |
content = f.read()
|
|
@@ -105,13 +92,11 @@ class DataProcessor:
|
|
| 105 |
|
| 106 |
self.tables.append(table_data)
|
| 107 |
|
| 108 |
-
# Tải figures
|
| 109 |
for figure_meta in folder_metadata.get("figures", []):
|
| 110 |
figure_id = figure_meta.get("id")
|
| 111 |
figure_path = os.path.join(folder_path, "figures", f"{figure_id}.md")
|
| 112 |
figure_data = figure_meta.copy()
|
| 113 |
|
| 114 |
-
# Thêm nội dung từ file markdown nếu tồn tại
|
| 115 |
content_loaded = False
|
| 116 |
if os.path.exists(figure_path):
|
| 117 |
with open(figure_path, 'r', encoding='utf-8') as f:
|
|
@@ -119,7 +104,6 @@ class DataProcessor:
|
|
| 119 |
figure_data["content"] = self._extract_content_from_markdown(content)
|
| 120 |
content_loaded = True
|
| 121 |
|
| 122 |
-
# Thêm đường dẫn đến file hình ảnh nếu có
|
| 123 |
image_path = None
|
| 124 |
image_extensions = ['.png', '.jpg', '.jpeg', '.gif', '.svg']
|
| 125 |
for ext in image_extensions:
|
|
@@ -130,18 +114,15 @@ class DataProcessor:
|
|
| 130 |
|
| 131 |
if image_path:
|
| 132 |
figure_data["image_path"] = image_path
|
| 133 |
-
# Tạo nội dung mặc định nếu không có file markdown
|
| 134 |
if not content_loaded:
|
| 135 |
figure_caption = figure_meta.get("title", f"Hình {figure_id}")
|
| 136 |
figure_data["content"] = f""
|
| 137 |
elif not content_loaded:
|
| 138 |
-
# Nếu không có cả file markdown và file hình
|
| 139 |
figure_data["content"] = f"Hình {figure_id} không tìm thấy."
|
| 140 |
logger.debug(f"Không tìm thấy file hình cho {figure_id}")
|
| 141 |
|
| 142 |
self.figures.append(figure_data)
|
| 143 |
|
| 144 |
-
# Tải data_files (trường hợp phụ lục)
|
| 145 |
if "data_files" in folder_metadata:
|
| 146 |
for data_file_meta in folder_metadata.get("data_files", []):
|
| 147 |
data_id = data_file_meta.get("id")
|
|
@@ -149,16 +130,13 @@ class DataProcessor:
|
|
| 149 |
|
| 150 |
data_file = data_file_meta.copy()
|
| 151 |
|
| 152 |
-
# Thêm nội dung từ file markdown nếu tồn tại
|
| 153 |
if os.path.exists(data_path):
|
| 154 |
with open(data_path, 'r', encoding='utf-8') as f:
|
| 155 |
content = f.read()
|
| 156 |
data_file["content"] = self._extract_content_from_markdown(content)
|
| 157 |
|
| 158 |
-
# Xác định loại nội dung
|
| 159 |
content_type = data_file.get("content_type", "table")
|
| 160 |
|
| 161 |
-
# Thêm vào danh sách phù hợp dựa trên loại nội dung
|
| 162 |
if content_type == "table":
|
| 163 |
self.tables.append(data_file)
|
| 164 |
elif content_type == "text":
|
|
@@ -173,7 +151,6 @@ class DataProcessor:
|
|
| 173 |
|
| 174 |
def _extract_content_from_markdown(self, md_content: str) -> str:
|
| 175 |
"""Trích xuất nội dung từ markdown, bỏ qua phần frontmatter"""
|
| 176 |
-
# Tách frontmatter (nằm giữa "---")
|
| 177 |
if md_content.startswith("---"):
|
| 178 |
parts = md_content.split("---", 2)
|
| 179 |
if len(parts) >= 3:
|
|
@@ -214,26 +191,23 @@ class DataProcessor:
|
|
| 214 |
return None
|
| 215 |
|
| 216 |
def find_items_by_age(self, age: int) -> Dict[str, List[Dict[str, Any]]]:
|
| 217 |
-
"""Tìm các items
|
| 218 |
relevant_chunks = []
|
| 219 |
relevant_tables = []
|
| 220 |
relevant_figures = []
|
| 221 |
|
| 222 |
-
# Lọc chunks
|
| 223 |
for chunk in self.chunks:
|
| 224 |
-
age_range = chunk.get("age_range", [0,
|
| 225 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
| 226 |
relevant_chunks.append(chunk)
|
| 227 |
|
| 228 |
-
# Lọc tables
|
| 229 |
for table in self.tables:
|
| 230 |
-
age_range = table.get("age_range", [0,
|
| 231 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
| 232 |
relevant_tables.append(table)
|
| 233 |
|
| 234 |
-
# Lọc figures
|
| 235 |
for figure in self.figures:
|
| 236 |
-
age_range = figure.get("age_range", [0,
|
| 237 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
| 238 |
relevant_figures.append(figure)
|
| 239 |
|
|
@@ -249,7 +223,6 @@ class DataProcessor:
|
|
| 249 |
related_tables = []
|
| 250 |
related_figures = []
|
| 251 |
|
| 252 |
-
# Tìm item gốc
|
| 253 |
source_item = None
|
| 254 |
for item in self.chunks + self.tables + self.figures:
|
| 255 |
if item.get("id") == item_id:
|
|
@@ -263,24 +236,19 @@ class DataProcessor:
|
|
| 263 |
"figures": []
|
| 264 |
}
|
| 265 |
|
| 266 |
-
# Lấy danh sách IDs của các items liên quan
|
| 267 |
related_ids = source_item.get("related_chunks", [])
|
| 268 |
|
| 269 |
-
# Tìm các items liên quan
|
| 270 |
for related_id in related_ids:
|
| 271 |
-
# Tìm trong chunks
|
| 272 |
for chunk in self.chunks:
|
| 273 |
if chunk.get("id") == related_id:
|
| 274 |
related_chunks.append(chunk)
|
| 275 |
break
|
| 276 |
|
| 277 |
-
# Tìm trong tables
|
| 278 |
for table in self.tables:
|
| 279 |
if table.get("id") == related_id:
|
| 280 |
related_tables.append(table)
|
| 281 |
break
|
| 282 |
|
| 283 |
-
# Tìm trong figures
|
| 284 |
for figure in self.figures:
|
| 285 |
if figure.get("id") == related_id:
|
| 286 |
related_figures.append(figure)
|
|
@@ -294,9 +262,7 @@ class DataProcessor:
|
|
| 294 |
|
| 295 |
def preprocess_query(self, query: str) -> str:
|
| 296 |
"""Tiền xử lý câu truy vấn"""
|
| 297 |
-
# Loại bỏ ký tự đặc biệt
|
| 298 |
query = re.sub(r'[^\w\s\d]', ' ', query)
|
| 299 |
-
# Loại bỏ khoảng trắng thừa
|
| 300 |
query = re.sub(r'\s+', ' ', query).strip()
|
| 301 |
return query
|
| 302 |
|
|
@@ -310,10 +276,8 @@ class DataProcessor:
|
|
| 310 |
content = item.get("content", "")
|
| 311 |
content_type = item.get("content_type", "text")
|
| 312 |
|
| 313 |
-
# Nếu là bảng, thêm tiêu đề "B��ng:"
|
| 314 |
if content_type == "table":
|
| 315 |
title = f"Bảng: {title}"
|
| 316 |
-
# Nếu là hình, thêm tiêu đề "Hình:"
|
| 317 |
elif content_type == "figure":
|
| 318 |
title = f"Hình: {title}"
|
| 319 |
|
|
@@ -326,9 +290,7 @@ class DataProcessor:
|
|
| 326 |
"""Chuẩn bị dữ liệu cho việc nhúng (embedding)"""
|
| 327 |
all_items = []
|
| 328 |
|
| 329 |
-
# Thêm chunks
|
| 330 |
for chunk in self.chunks:
|
| 331 |
-
# Tìm chapter từ chunk ID
|
| 332 |
chunk_id = chunk.get("id", "")
|
| 333 |
chapter = "unknown"
|
| 334 |
if chunk_id.startswith("bai1_"):
|
|
@@ -346,13 +308,11 @@ class DataProcessor:
|
|
| 346 |
if chunk.get("title"):
|
| 347 |
content = f"Tiêu đề: {chunk.get('title')}\n\nNội dung: {content}"
|
| 348 |
|
| 349 |
-
|
| 350 |
-
age_range = chunk.get("age_range", [0, 100])
|
| 351 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
| 352 |
-
age_max = age_range[1] if len(age_range) > 1 else
|
| 353 |
age_range_str = f"{age_min}-{age_max}"
|
| 354 |
|
| 355 |
-
# Xử lý related_chunks - convert list thành string
|
| 356 |
related_chunks = chunk.get("related_chunks", [])
|
| 357 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
| 358 |
|
|
@@ -379,9 +339,7 @@ class DataProcessor:
|
|
| 379 |
}
|
| 380 |
all_items.append(embedding_item)
|
| 381 |
|
| 382 |
-
# Thêm tables
|
| 383 |
for table in self.tables:
|
| 384 |
-
# Tìm chapter từ table ID
|
| 385 |
table_id = table.get("id", "")
|
| 386 |
chapter = "unknown"
|
| 387 |
if table_id.startswith("bai1_"):
|
|
@@ -399,13 +357,11 @@ class DataProcessor:
|
|
| 399 |
if table.get("title"):
|
| 400 |
content = f"Bảng: {table.get('title')}\n\nNội dung: {content}"
|
| 401 |
|
| 402 |
-
|
| 403 |
-
age_range = table.get("age_range", [0, 100])
|
| 404 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
| 405 |
-
age_max = age_range[1] if len(age_range) > 1 else
|
| 406 |
age_range_str = f"{age_min}-{age_max}"
|
| 407 |
|
| 408 |
-
# Xử lý related_chunks và table_columns
|
| 409 |
related_chunks = table.get("related_chunks", [])
|
| 410 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
| 411 |
table_columns = table.get("table_columns", [])
|
|
@@ -433,9 +389,7 @@ class DataProcessor:
|
|
| 433 |
}
|
| 434 |
all_items.append(embedding_item)
|
| 435 |
|
| 436 |
-
# Thêm figures
|
| 437 |
for figure in self.figures:
|
| 438 |
-
# Tìm chapter từ figure ID
|
| 439 |
figure_id = figure.get("id", "")
|
| 440 |
chapter = "unknown"
|
| 441 |
if figure_id.startswith("bai1_"):
|
|
@@ -453,13 +407,11 @@ class DataProcessor:
|
|
| 453 |
if figure.get("title"):
|
| 454 |
content = f"Hình: {figure.get('title')}\n\nMô tả: {content}"
|
| 455 |
|
| 456 |
-
# Xử lý age_range
|
| 457 |
age_range = figure.get("age_range", [0, 100])
|
| 458 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
| 459 |
age_max = age_range[1] if len(age_range) > 1 else 100
|
| 460 |
age_range_str = f"{age_min}-{age_max}"
|
| 461 |
|
| 462 |
-
# Xử lý related_chunks
|
| 463 |
related_chunks = figure.get("related_chunks", [])
|
| 464 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
| 465 |
|
|
@@ -509,16 +461,14 @@ class DataProcessor:
|
|
| 509 |
"by_age": {}
|
| 510 |
}
|
| 511 |
|
| 512 |
-
# Thống kê theo bài
|
| 513 |
for item in os.listdir(self.data_dir):
|
| 514 |
if os.path.isdir(os.path.join(self.data_dir, item)):
|
| 515 |
item_stats = self.count_items_by_prefix(f"{item}_")
|
| 516 |
stats["by_lesson"][item] = item_stats
|
| 517 |
|
| 518 |
-
# Thống kê theo độ tuổi
|
| 519 |
age_ranges = {}
|
| 520 |
for chunk in self.chunks + self.tables + self.figures:
|
| 521 |
-
age_range = chunk.get("age_range", [0,
|
| 522 |
if len(age_range) == 2:
|
| 523 |
range_key = f"{age_range[0]}-{age_range[1]}"
|
| 524 |
if range_key not in age_ranges:
|
|
|
|
| 5 |
import datetime
|
| 6 |
from typing import Dict, List, Any, Union, Tuple
|
| 7 |
|
|
|
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
| 10 |
class DataProcessor:
|
|
|
|
| 25 |
"""Tải tất cả dữ liệu từ các thư mục con trong data"""
|
| 26 |
logger.info(f"Đang tải dữ liệu từ thư mục: {self.data_dir}")
|
| 27 |
|
|
|
|
| 28 |
for item in os.listdir(self.data_dir):
|
| 29 |
folder_path = os.path.join(self.data_dir, item)
|
| 30 |
|
|
|
|
| 31 |
if os.path.isdir(folder_path):
|
| 32 |
metadata_file = os.path.join(folder_path, "metadata.json")
|
| 33 |
|
|
|
|
| 34 |
if os.path.exists(metadata_file):
|
| 35 |
try:
|
|
|
|
| 36 |
with open(metadata_file, 'r', encoding='utf-8') as f:
|
| 37 |
content = f.read()
|
| 38 |
if not content.strip():
|
|
|
|
| 40 |
continue
|
| 41 |
folder_metadata = json.loads(content)
|
| 42 |
|
|
|
|
| 43 |
folder_id = None
|
| 44 |
if "bai_info" in folder_metadata:
|
| 45 |
folder_id = folder_metadata["bai_info"].get("id", item)
|
|
|
|
| 48 |
else:
|
| 49 |
folder_id = item
|
| 50 |
|
|
|
|
| 51 |
self.metadata[folder_id] = folder_metadata
|
| 52 |
|
|
|
|
| 53 |
self._load_content_from_metadata(folder_path, folder_metadata)
|
| 54 |
|
| 55 |
logger.info(f"Đã tải xong thư mục: {item}")
|
|
|
|
| 60 |
|
| 61 |
def _load_content_from_metadata(self, folder_path: str, folder_metadata: Dict[str, Any]):
|
| 62 |
"""Tải nội dung chunks, tables và figures từ metadata"""
|
|
|
|
| 63 |
for chunk_meta in folder_metadata.get("chunks", []):
|
| 64 |
chunk_id = chunk_meta.get("id")
|
| 65 |
chunk_path = os.path.join(folder_path, "chunks", f"{chunk_id}.md")
|
| 66 |
|
| 67 |
+
chunk_data = chunk_meta.copy()
|
| 68 |
|
|
|
|
| 69 |
if os.path.exists(chunk_path):
|
| 70 |
with open(chunk_path, 'r', encoding='utf-8') as f:
|
| 71 |
content = f.read()
|
| 72 |
chunk_data["content"] = self._extract_content_from_markdown(content)
|
| 73 |
else:
|
|
|
|
| 74 |
chunk_data["content"] = f"Nội dung cho {chunk_id} không tìm thấy."
|
| 75 |
logger.debug(f"Không tìm thấy file chunk: {chunk_path}")
|
| 76 |
|
| 77 |
self.chunks.append(chunk_data)
|
| 78 |
|
|
|
|
| 79 |
for table_meta in folder_metadata.get("tables", []):
|
| 80 |
table_id = table_meta.get("id")
|
| 81 |
table_path = os.path.join(folder_path, "tables", f"{table_id}.md")
|
| 82 |
|
| 83 |
table_data = table_meta.copy()
|
| 84 |
|
|
|
|
| 85 |
if os.path.exists(table_path):
|
| 86 |
with open(table_path, 'r', encoding='utf-8') as f:
|
| 87 |
content = f.read()
|
|
|
|
| 92 |
|
| 93 |
self.tables.append(table_data)
|
| 94 |
|
|
|
|
| 95 |
for figure_meta in folder_metadata.get("figures", []):
|
| 96 |
figure_id = figure_meta.get("id")
|
| 97 |
figure_path = os.path.join(folder_path, "figures", f"{figure_id}.md")
|
| 98 |
figure_data = figure_meta.copy()
|
| 99 |
|
|
|
|
| 100 |
content_loaded = False
|
| 101 |
if os.path.exists(figure_path):
|
| 102 |
with open(figure_path, 'r', encoding='utf-8') as f:
|
|
|
|
| 104 |
figure_data["content"] = self._extract_content_from_markdown(content)
|
| 105 |
content_loaded = True
|
| 106 |
|
|
|
|
| 107 |
image_path = None
|
| 108 |
image_extensions = ['.png', '.jpg', '.jpeg', '.gif', '.svg']
|
| 109 |
for ext in image_extensions:
|
|
|
|
| 114 |
|
| 115 |
if image_path:
|
| 116 |
figure_data["image_path"] = image_path
|
|
|
|
| 117 |
if not content_loaded:
|
| 118 |
figure_caption = figure_meta.get("title", f"Hình {figure_id}")
|
| 119 |
figure_data["content"] = f""
|
| 120 |
elif not content_loaded:
|
|
|
|
| 121 |
figure_data["content"] = f"Hình {figure_id} không tìm thấy."
|
| 122 |
logger.debug(f"Không tìm thấy file hình cho {figure_id}")
|
| 123 |
|
| 124 |
self.figures.append(figure_data)
|
| 125 |
|
|
|
|
| 126 |
if "data_files" in folder_metadata:
|
| 127 |
for data_file_meta in folder_metadata.get("data_files", []):
|
| 128 |
data_id = data_file_meta.get("id")
|
|
|
|
| 130 |
|
| 131 |
data_file = data_file_meta.copy()
|
| 132 |
|
|
|
|
| 133 |
if os.path.exists(data_path):
|
| 134 |
with open(data_path, 'r', encoding='utf-8') as f:
|
| 135 |
content = f.read()
|
| 136 |
data_file["content"] = self._extract_content_from_markdown(content)
|
| 137 |
|
|
|
|
| 138 |
content_type = data_file.get("content_type", "table")
|
| 139 |
|
|
|
|
| 140 |
if content_type == "table":
|
| 141 |
self.tables.append(data_file)
|
| 142 |
elif content_type == "text":
|
|
|
|
| 151 |
|
| 152 |
def _extract_content_from_markdown(self, md_content: str) -> str:
|
| 153 |
"""Trích xuất nội dung từ markdown, bỏ qua phần frontmatter"""
|
|
|
|
| 154 |
if md_content.startswith("---"):
|
| 155 |
parts = md_content.split("---", 2)
|
| 156 |
if len(parts) >= 3:
|
|
|
|
| 191 |
return None
|
| 192 |
|
| 193 |
def find_items_by_age(self, age: int) -> Dict[str, List[Dict[str, Any]]]:
|
| 194 |
+
"""Tìm các items liên quan đến độ tuổi của người dùng"""
|
| 195 |
relevant_chunks = []
|
| 196 |
relevant_tables = []
|
| 197 |
relevant_figures = []
|
| 198 |
|
|
|
|
| 199 |
for chunk in self.chunks:
|
| 200 |
+
age_range = chunk.get("age_range", [0, 19])
|
| 201 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
| 202 |
relevant_chunks.append(chunk)
|
| 203 |
|
|
|
|
| 204 |
for table in self.tables:
|
| 205 |
+
age_range = table.get("age_range", [0, 19])
|
| 206 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
| 207 |
relevant_tables.append(table)
|
| 208 |
|
|
|
|
| 209 |
for figure in self.figures:
|
| 210 |
+
age_range = figure.get("age_range", [0, 19])
|
| 211 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
| 212 |
relevant_figures.append(figure)
|
| 213 |
|
|
|
|
| 223 |
related_tables = []
|
| 224 |
related_figures = []
|
| 225 |
|
|
|
|
| 226 |
source_item = None
|
| 227 |
for item in self.chunks + self.tables + self.figures:
|
| 228 |
if item.get("id") == item_id:
|
|
|
|
| 236 |
"figures": []
|
| 237 |
}
|
| 238 |
|
|
|
|
| 239 |
related_ids = source_item.get("related_chunks", [])
|
| 240 |
|
|
|
|
| 241 |
for related_id in related_ids:
|
|
|
|
| 242 |
for chunk in self.chunks:
|
| 243 |
if chunk.get("id") == related_id:
|
| 244 |
related_chunks.append(chunk)
|
| 245 |
break
|
| 246 |
|
|
|
|
| 247 |
for table in self.tables:
|
| 248 |
if table.get("id") == related_id:
|
| 249 |
related_tables.append(table)
|
| 250 |
break
|
| 251 |
|
|
|
|
| 252 |
for figure in self.figures:
|
| 253 |
if figure.get("id") == related_id:
|
| 254 |
related_figures.append(figure)
|
|
|
|
| 262 |
|
| 263 |
def preprocess_query(self, query: str) -> str:
|
| 264 |
"""Tiền xử lý câu truy vấn"""
|
|
|
|
| 265 |
query = re.sub(r'[^\w\s\d]', ' ', query)
|
|
|
|
| 266 |
query = re.sub(r'\s+', ' ', query).strip()
|
| 267 |
return query
|
| 268 |
|
|
|
|
| 276 |
content = item.get("content", "")
|
| 277 |
content_type = item.get("content_type", "text")
|
| 278 |
|
|
|
|
| 279 |
if content_type == "table":
|
| 280 |
title = f"Bảng: {title}"
|
|
|
|
| 281 |
elif content_type == "figure":
|
| 282 |
title = f"Hình: {title}"
|
| 283 |
|
|
|
|
| 290 |
"""Chuẩn bị dữ liệu cho việc nhúng (embedding)"""
|
| 291 |
all_items = []
|
| 292 |
|
|
|
|
| 293 |
for chunk in self.chunks:
|
|
|
|
| 294 |
chunk_id = chunk.get("id", "")
|
| 295 |
chapter = "unknown"
|
| 296 |
if chunk_id.startswith("bai1_"):
|
|
|
|
| 308 |
if chunk.get("title"):
|
| 309 |
content = f"Tiêu đề: {chunk.get('title')}\n\nNội dung: {content}"
|
| 310 |
|
| 311 |
+
age_range = chunk.get("age_range", [0, 19])
|
|
|
|
| 312 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
| 313 |
+
age_max = age_range[1] if len(age_range) > 1 else 19
|
| 314 |
age_range_str = f"{age_min}-{age_max}"
|
| 315 |
|
|
|
|
| 316 |
related_chunks = chunk.get("related_chunks", [])
|
| 317 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
| 318 |
|
|
|
|
| 339 |
}
|
| 340 |
all_items.append(embedding_item)
|
| 341 |
|
|
|
|
| 342 |
for table in self.tables:
|
|
|
|
| 343 |
table_id = table.get("id", "")
|
| 344 |
chapter = "unknown"
|
| 345 |
if table_id.startswith("bai1_"):
|
|
|
|
| 357 |
if table.get("title"):
|
| 358 |
content = f"Bảng: {table.get('title')}\n\nNội dung: {content}"
|
| 359 |
|
| 360 |
+
age_range = table.get("age_range", [0, 19])
|
|
|
|
| 361 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
| 362 |
+
age_max = age_range[1] if len(age_range) > 1 else 19
|
| 363 |
age_range_str = f"{age_min}-{age_max}"
|
| 364 |
|
|
|
|
| 365 |
related_chunks = table.get("related_chunks", [])
|
| 366 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
| 367 |
table_columns = table.get("table_columns", [])
|
|
|
|
| 389 |
}
|
| 390 |
all_items.append(embedding_item)
|
| 391 |
|
|
|
|
| 392 |
for figure in self.figures:
|
|
|
|
| 393 |
figure_id = figure.get("id", "")
|
| 394 |
chapter = "unknown"
|
| 395 |
if figure_id.startswith("bai1_"):
|
|
|
|
| 407 |
if figure.get("title"):
|
| 408 |
content = f"Hình: {figure.get('title')}\n\nMô tả: {content}"
|
| 409 |
|
|
|
|
| 410 |
age_range = figure.get("age_range", [0, 100])
|
| 411 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
| 412 |
age_max = age_range[1] if len(age_range) > 1 else 100
|
| 413 |
age_range_str = f"{age_min}-{age_max}"
|
| 414 |
|
|
|
|
| 415 |
related_chunks = figure.get("related_chunks", [])
|
| 416 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
| 417 |
|
|
|
|
| 461 |
"by_age": {}
|
| 462 |
}
|
| 463 |
|
|
|
|
| 464 |
for item in os.listdir(self.data_dir):
|
| 465 |
if os.path.isdir(os.path.join(self.data_dir, item)):
|
| 466 |
item_stats = self.count_items_by_prefix(f"{item}_")
|
| 467 |
stats["by_lesson"][item] = item_stats
|
| 468 |
|
|
|
|
| 469 |
age_ranges = {}
|
| 470 |
for chunk in self.chunks + self.tables + self.figures:
|
| 471 |
+
age_range = chunk.get("age_range", [0, 19])
|
| 472 |
if len(age_range) == 2:
|
| 473 |
range_key = f"{age_range[0]}-{age_range[1]}"
|
| 474 |
if range_key not in age_ranges:
|
core/embedding_model.py
CHANGED
|
@@ -6,16 +6,12 @@ import uuid
|
|
| 6 |
import os
|
| 7 |
from config import EMBEDDING_MODEL, CHROMA_PERSIST_DIRECTORY, COLLECTION_NAME
|
| 8 |
|
| 9 |
-
# Cấu hình logging
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
-
# Global instance để implement singleton pattern
|
| 13 |
_embedding_model_instance = None
|
| 14 |
|
| 15 |
def get_embedding_model():
|
| 16 |
-
"""
|
| 17 |
-
Singleton pattern để đảm bảo chỉ có một instance của EmbeddingModel
|
| 18 |
-
"""
|
| 19 |
global _embedding_model_instance
|
| 20 |
if _embedding_model_instance is None:
|
| 21 |
logger.info("Khởi tạo EmbeddingModel instance lần đầu")
|
|
@@ -40,52 +36,199 @@ class EmbeddingModel:
|
|
| 40 |
self.model = SentenceTransformer(EMBEDDING_MODEL, cache_folder=cache_dir, trust_remote_code=True)
|
| 41 |
logger.info("Đã tải sentence transformer model với cache folder explicit")
|
| 42 |
|
|
|
|
|
|
|
|
|
|
| 43 |
# Đảm bảo thư mục ChromaDB tồn tại và có quyền ghi
|
| 44 |
try:
|
| 45 |
-
os.makedirs(
|
| 46 |
# Test ghi file để kiểm tra permission
|
| 47 |
-
test_file = os.path.join(
|
| 48 |
with open(test_file, 'w') as f:
|
| 49 |
f.write('test')
|
| 50 |
os.remove(test_file)
|
| 51 |
-
logger.info(f"Thư mục ChromaDB đã sẵn sàng: {
|
| 52 |
except Exception as e:
|
| 53 |
logger.error(f"Lỗi tạo/kiểm tra thư mục ChromaDB: {e}")
|
| 54 |
# Fallback to /tmp directory
|
| 55 |
import tempfile
|
| 56 |
-
|
| 57 |
-
os.makedirs(
|
| 58 |
-
logger.warning(f"Sử dụng thư mục tạm thời: {
|
| 59 |
|
| 60 |
# Khởi tạo ChromaDB client với persistent storage
|
| 61 |
try:
|
| 62 |
self.chroma_client = chromadb.PersistentClient(
|
| 63 |
-
path=
|
| 64 |
settings=Settings(
|
| 65 |
anonymized_telemetry=False,
|
| 66 |
allow_reset=True
|
| 67 |
)
|
| 68 |
)
|
| 69 |
-
logger.info(f"Đã kết nối ChromaDB tại: {
|
| 70 |
except Exception as e:
|
| 71 |
logger.error(f"Lỗi kết nối ChromaDB: {e}")
|
| 72 |
# Fallback to in-memory client
|
| 73 |
logger.warning("Fallback to in-memory ChromaDB client")
|
| 74 |
self.chroma_client = chromadb.Client()
|
| 75 |
|
| 76 |
-
# Lấy hoặc tạo collection
|
| 77 |
try:
|
| 78 |
self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME)
|
| 79 |
logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với {self.collection.count()} items")
|
| 80 |
except Exception:
|
| 81 |
-
logger.
|
| 82 |
-
self.collection = self.chroma_client.create_collection(
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
def _add_prefix_to_text(self, text, is_query=True):
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
# Kiểm tra xem text đã có prefix chưa
|
| 90 |
if text.startswith(('query:', 'passage:')):
|
| 91 |
return text
|
|
@@ -98,24 +241,32 @@ class EmbeddingModel:
|
|
| 98 |
|
| 99 |
def encode(self, texts, is_query=True):
|
| 100 |
"""
|
| 101 |
-
Encode văn bản thành embeddings
|
| 102 |
-
|
| 103 |
-
Args:
|
| 104 |
-
texts (str or list): Văn bản hoặc danh sách văn bản cần encode
|
| 105 |
-
is_query (bool): True nếu là query, False nếu là passage
|
| 106 |
-
|
| 107 |
-
Returns:
|
| 108 |
-
list: Embeddings vector
|
| 109 |
"""
|
| 110 |
try:
|
| 111 |
if isinstance(texts, str):
|
| 112 |
texts = [texts]
|
| 113 |
|
| 114 |
-
# Thêm prefix cho texts
|
| 115 |
processed_texts = [self._add_prefix_to_text(text, is_query) for text in texts]
|
| 116 |
|
| 117 |
-
logger.debug(f"Đang encode {len(processed_texts)} văn bản")
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
return embeddings.tolist()
|
| 121 |
|
|
@@ -124,24 +275,10 @@ class EmbeddingModel:
|
|
| 124 |
raise
|
| 125 |
|
| 126 |
def search(self, query, top_k=5, age_filter=None):
|
| 127 |
-
"""
|
| 128 |
-
Tìm kiếm văn bản tương tự trong ChromaDB
|
| 129 |
-
|
| 130 |
-
Args:
|
| 131 |
-
query (str): Câu hỏi cần tìm kiếm
|
| 132 |
-
top_k (int): Số lượng kết quả trả về
|
| 133 |
-
age_filter (int): Lọc theo độ tuổi (optional)
|
| 134 |
-
|
| 135 |
-
Returns:
|
| 136 |
-
list: Danh sách kết quả tìm kiếm
|
| 137 |
-
"""
|
| 138 |
try:
|
| 139 |
-
logger.debug(f"Dang tim kiem cho query: {query[:50]}...")
|
| 140 |
-
|
| 141 |
-
# Encode query thành embedding (với prefix query:)
|
| 142 |
query_embedding = self.encode(query, is_query=True)[0]
|
| 143 |
|
| 144 |
-
# Tạo where clause cho age filter
|
| 145 |
where_clause = None
|
| 146 |
if age_filter:
|
| 147 |
where_clause = {
|
|
@@ -150,34 +287,53 @@ class EmbeddingModel:
|
|
| 150 |
{"age_max": {"$gte": age_filter}}
|
| 151 |
]
|
| 152 |
}
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
| 155 |
search_results = self.collection.query(
|
| 156 |
query_embeddings=[query_embedding],
|
| 157 |
n_results=top_k,
|
| 158 |
where=where_clause,
|
| 159 |
include=['documents', 'metadatas', 'distances']
|
| 160 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
if not search_results or not search_results['documents']:
|
| 163 |
-
logger.warning("
|
| 164 |
return []
|
| 165 |
|
| 166 |
-
# Format kết quả
|
| 167 |
results = []
|
| 168 |
documents = search_results['documents'][0]
|
| 169 |
metadatas = search_results['metadatas'][0]
|
| 170 |
distances = search_results['distances'][0]
|
| 171 |
|
| 172 |
for i, (doc, metadata, distance) in enumerate(zip(documents, metadatas, distances)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
results.append({
|
| 174 |
'document': doc,
|
| 175 |
'metadata': metadata or {},
|
| 176 |
'distance': distance,
|
| 177 |
-
'similarity':
|
| 178 |
'rank': i + 1
|
| 179 |
})
|
| 180 |
|
|
|
|
| 181 |
logger.info(f"Tim thay {len(results)} ket qua cho query")
|
| 182 |
return results
|
| 183 |
|
|
@@ -186,36 +342,22 @@ class EmbeddingModel:
|
|
| 186 |
return []
|
| 187 |
|
| 188 |
def add_documents(self, documents, metadatas=None, ids=None):
|
| 189 |
-
"""
|
| 190 |
-
Thêm documents vào ChromaDB
|
| 191 |
-
|
| 192 |
-
Args:
|
| 193 |
-
documents (list): Danh sách văn bản
|
| 194 |
-
metadatas (list): Danh sách metadata tương ứng
|
| 195 |
-
ids (list): Danh sách ID tương ứng (optional)
|
| 196 |
-
|
| 197 |
-
Returns:
|
| 198 |
-
bool: True nếu thành công
|
| 199 |
-
"""
|
| 200 |
try:
|
| 201 |
if not documents:
|
| 202 |
logger.warning("Không có documents để thêm")
|
| 203 |
return False
|
| 204 |
|
| 205 |
-
# Tạo IDs nếu không được cung cấp
|
| 206 |
if not ids:
|
| 207 |
ids = [str(uuid.uuid4()) for _ in documents]
|
| 208 |
|
| 209 |
-
# Tạo metadatas rỗng nếu không được cung cấp
|
| 210 |
if not metadatas:
|
| 211 |
metadatas = [{} for _ in documents]
|
| 212 |
|
| 213 |
logger.info(f"Đang thêm {len(documents)} documents vào ChromaDB")
|
| 214 |
|
| 215 |
-
# Encode documents thành embeddings (với prefix passage:)
|
| 216 |
embeddings = self.encode(documents, is_query=False)
|
| 217 |
|
| 218 |
-
# Thêm vào collection
|
| 219 |
self.collection.add(
|
| 220 |
embeddings=embeddings,
|
| 221 |
documents=documents,
|
|
@@ -231,9 +373,7 @@ class EmbeddingModel:
|
|
| 231 |
return False
|
| 232 |
|
| 233 |
def index_chunks(self, chunks):
|
| 234 |
-
"""
|
| 235 |
-
Index các chunks dữ liệu vào ChromaDB
|
| 236 |
-
"""
|
| 237 |
try:
|
| 238 |
if not chunks:
|
| 239 |
logger.warning("Không có chunks để index")
|
|
@@ -250,11 +390,9 @@ class EmbeddingModel:
|
|
| 250 |
|
| 251 |
documents.append(chunk['content'])
|
| 252 |
|
| 253 |
-
# Lấy metadata đã được chuẩn bị sẵn
|
| 254 |
metadata = chunk.get('metadata', {})
|
| 255 |
metadatas.append(metadata)
|
| 256 |
|
| 257 |
-
# Sử dụng ID có sẵn hoặc tạo mới
|
| 258 |
chunk_id = chunk.get('id') or str(uuid.uuid4())
|
| 259 |
ids.append(chunk_id)
|
| 260 |
|
|
@@ -262,7 +400,6 @@ class EmbeddingModel:
|
|
| 262 |
logger.warning("Không có documents hợp lệ để index")
|
| 263 |
return False
|
| 264 |
|
| 265 |
-
# Batch processing để tránh overload
|
| 266 |
batch_size = 100
|
| 267 |
total_batches = (len(documents) + batch_size - 1) // batch_size
|
| 268 |
|
|
@@ -300,9 +437,9 @@ class EmbeddingModel:
|
|
| 300 |
logger.warning(f"Đang xóa collection: {COLLECTION_NAME}")
|
| 301 |
self.chroma_client.delete_collection(name=COLLECTION_NAME)
|
| 302 |
|
| 303 |
-
# Tạo lại collection
|
| 304 |
-
self.
|
| 305 |
-
logger.info("Đã tạo lại collection mới")
|
| 306 |
|
| 307 |
return True
|
| 308 |
|
|
@@ -310,15 +447,49 @@ class EmbeddingModel:
|
|
| 310 |
logger.error(f"Lỗi xóa collection: {e}")
|
| 311 |
return False
|
| 312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
def get_stats(self):
|
| 314 |
"""Lấy thống kê về collection"""
|
| 315 |
try:
|
| 316 |
total_count = self.count()
|
|
|
|
| 317 |
|
| 318 |
-
# Lấy sample để phân tích metadata
|
| 319 |
sample_results = self.collection.get(limit=min(100, total_count))
|
| 320 |
|
| 321 |
-
# Thống kê content types
|
| 322 |
content_types = {}
|
| 323 |
chapters = {}
|
| 324 |
age_groups = {}
|
|
@@ -328,15 +499,12 @@ class EmbeddingModel:
|
|
| 328 |
if not metadata:
|
| 329 |
continue
|
| 330 |
|
| 331 |
-
# Content type stats
|
| 332 |
content_type = metadata.get('content_type', 'unknown')
|
| 333 |
content_types[content_type] = content_types.get(content_type, 0) + 1
|
| 334 |
|
| 335 |
-
# Chapter stats
|
| 336 |
chapter = metadata.get('chapter', 'unknown')
|
| 337 |
chapters[chapter] = chapters.get(chapter, 0) + 1
|
| 338 |
|
| 339 |
-
# Age group stats
|
| 340 |
age_group = metadata.get('age_group', 'unknown')
|
| 341 |
age_groups[age_group] = age_groups.get(age_group, 0) + 1
|
| 342 |
|
|
@@ -346,7 +514,9 @@ class EmbeddingModel:
|
|
| 346 |
'chapters': chapters,
|
| 347 |
'age_groups': age_groups,
|
| 348 |
'collection_name': COLLECTION_NAME,
|
| 349 |
-
'embedding_model': EMBEDDING_MODEL
|
|
|
|
|
|
|
| 350 |
}
|
| 351 |
|
| 352 |
except Exception as e:
|
|
|
|
| 6 |
import os
|
| 7 |
from config import EMBEDDING_MODEL, CHROMA_PERSIST_DIRECTORY, COLLECTION_NAME
|
| 8 |
|
|
|
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
|
|
|
|
| 11 |
_embedding_model_instance = None
|
| 12 |
|
| 13 |
def get_embedding_model():
|
| 14 |
+
"""Kiểm tra và khởi tạo embedding đảm bảo chỉ khởi tạo một lần"""
|
|
|
|
|
|
|
| 15 |
global _embedding_model_instance
|
| 16 |
if _embedding_model_instance is None:
|
| 17 |
logger.info("Khởi tạo EmbeddingModel instance lần đầu")
|
|
|
|
| 36 |
self.model = SentenceTransformer(EMBEDDING_MODEL, cache_folder=cache_dir, trust_remote_code=True)
|
| 37 |
logger.info("Đã tải sentence transformer model với cache folder explicit")
|
| 38 |
|
| 39 |
+
# SỬA: Khai báo biến persist_directory local để tránh lỗi scope
|
| 40 |
+
persist_directory = CHROMA_PERSIST_DIRECTORY
|
| 41 |
+
|
| 42 |
# Đảm bảo thư mục ChromaDB tồn tại và có quyền ghi
|
| 43 |
try:
|
| 44 |
+
os.makedirs(persist_directory, exist_ok=True)
|
| 45 |
# Test ghi file để kiểm tra permission
|
| 46 |
+
test_file = os.path.join(persist_directory, 'test_permission.tmp')
|
| 47 |
with open(test_file, 'w') as f:
|
| 48 |
f.write('test')
|
| 49 |
os.remove(test_file)
|
| 50 |
+
logger.info(f"Thư mục ChromaDB đã sẵn sàng: {persist_directory}")
|
| 51 |
except Exception as e:
|
| 52 |
logger.error(f"Lỗi tạo/kiểm tra thư mục ChromaDB: {e}")
|
| 53 |
# Fallback to /tmp directory
|
| 54 |
import tempfile
|
| 55 |
+
persist_directory = os.path.join(tempfile.gettempdir(), 'chroma_db')
|
| 56 |
+
os.makedirs(persist_directory, exist_ok=True)
|
| 57 |
+
logger.warning(f"Sử dụng thư mục tạm thời: {persist_directory}")
|
| 58 |
|
| 59 |
# Khởi tạo ChromaDB client với persistent storage
|
| 60 |
try:
|
| 61 |
self.chroma_client = chromadb.PersistentClient(
|
| 62 |
+
path=persist_directory,
|
| 63 |
settings=Settings(
|
| 64 |
anonymized_telemetry=False,
|
| 65 |
allow_reset=True
|
| 66 |
)
|
| 67 |
)
|
| 68 |
+
logger.info(f"Đã kết nối ChromaDB tại: {persist_directory}")
|
| 69 |
except Exception as e:
|
| 70 |
logger.error(f"Lỗi kết nối ChromaDB: {e}")
|
| 71 |
# Fallback to in-memory client
|
| 72 |
logger.warning("Fallback to in-memory ChromaDB client")
|
| 73 |
self.chroma_client = chromadb.Client()
|
| 74 |
|
| 75 |
+
# Lấy hoặc tạo collection với cosine similarity
|
| 76 |
try:
|
| 77 |
self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME)
|
| 78 |
logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với {self.collection.count()} items")
|
| 79 |
except Exception:
|
| 80 |
+
logger.info(f"Collection '{COLLECTION_NAME}' không tồn tại, tạo mới với cosine similarity...")
|
| 81 |
+
self.collection = self.chroma_client.create_collection(
|
| 82 |
+
name=COLLECTION_NAME,
|
| 83 |
+
metadata={
|
| 84 |
+
"hnsw:space": "cosine", # Cosine distance
|
| 85 |
+
"hnsw:M": 16, # Optimize for accuracy
|
| 86 |
+
"hnsw:construction_ef": 100
|
| 87 |
+
}
|
| 88 |
+
)
|
| 89 |
+
logger.info(f"Đã tạo collection mới với cosine similarity: {COLLECTION_NAME}")
|
| 90 |
+
|
| 91 |
+
def _initialize_collection(self):
|
| 92 |
+
"""Khởi tạo collection với cosine similarity"""
|
| 93 |
+
try:
|
| 94 |
+
# Kiểm tra xem collection đã tồn tại chưa
|
| 95 |
+
existing_collections = [col.name for col in self.chroma_client.list_collections()]
|
| 96 |
+
|
| 97 |
+
if COLLECTION_NAME in existing_collections:
|
| 98 |
+
self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME)
|
| 99 |
+
|
| 100 |
+
# Kiểm tra distance function hiện tại
|
| 101 |
+
current_metadata = self.collection.metadata or {}
|
| 102 |
+
current_space = current_metadata.get("hnsw:space", "l2")
|
| 103 |
+
|
| 104 |
+
if current_space != "cosine":
|
| 105 |
+
logger.warning(f"Collection hiện tại đang dùng {current_space}, cần migration sang cosine")
|
| 106 |
+
if self.collection.count() > 0:
|
| 107 |
+
self._migrate_to_cosine()
|
| 108 |
+
else:
|
| 109 |
+
# Collection trống, xóa và tạo lại
|
| 110 |
+
self.chroma_client.delete_collection(name=COLLECTION_NAME)
|
| 111 |
+
self._create_cosine_collection()
|
| 112 |
+
else:
|
| 113 |
+
logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với cosine similarity, {self.collection.count()} items")
|
| 114 |
+
else:
|
| 115 |
+
# Collection chưa tồn tại, tạo mới với cosine
|
| 116 |
+
self._create_cosine_collection()
|
| 117 |
+
|
| 118 |
+
except Exception as e:
|
| 119 |
+
logger.error(f"Lỗi khởi tạo collection: {e}")
|
| 120 |
+
# Fallback: tạo collection mới
|
| 121 |
+
self._create_cosine_collection()
|
| 122 |
+
|
| 123 |
+
def _create_cosine_collection(self):
|
| 124 |
+
"""Tạo collection mới với cosine similarity"""
|
| 125 |
+
try:
|
| 126 |
+
self.collection = self.chroma_client.create_collection(
|
| 127 |
+
name=COLLECTION_NAME,
|
| 128 |
+
metadata={"hnsw:space": "cosine"}
|
| 129 |
+
)
|
| 130 |
+
logger.info(f"Đã tạo collection mới với cosine similarity: {COLLECTION_NAME}")
|
| 131 |
+
except Exception as e:
|
| 132 |
+
logger.error(f"Lỗi tạo collection với cosine: {e}")
|
| 133 |
+
# Fallback về collection mặc định
|
| 134 |
+
self.collection = self.chroma_client.get_or_create_collection(name=COLLECTION_NAME)
|
| 135 |
+
logger.warning("Đã fallback về collection mặc định (có thể dùng L2)")
|
| 136 |
+
|
| 137 |
+
def _migrate_to_cosine(self):
|
| 138 |
+
"""Migration collection từ L2 sang cosine"""
|
| 139 |
+
try:
|
| 140 |
+
logger.info("Bắt đầu migration collection sang cosine similarity...")
|
| 141 |
+
|
| 142 |
+
# Backup toàn bộ data
|
| 143 |
+
all_data = self.collection.get(
|
| 144 |
+
include=['documents', 'metadatas', 'embeddings'],
|
| 145 |
+
limit=self.collection.count()
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
if not all_data['documents']:
|
| 149 |
+
logger.info("Collection trống, chỉ cần tạo lại")
|
| 150 |
+
self.chroma_client.delete_collection(name=COLLECTION_NAME)
|
| 151 |
+
self._create_cosine_collection()
|
| 152 |
+
return
|
| 153 |
+
|
| 154 |
+
# Xóa collection cũ và tạo mới với cosine
|
| 155 |
+
self.chroma_client.delete_collection(name=COLLECTION_NAME)
|
| 156 |
+
self._create_cosine_collection()
|
| 157 |
+
|
| 158 |
+
# Restore data theo batch
|
| 159 |
+
documents = all_data['documents']
|
| 160 |
+
metadatas = all_data['metadatas']
|
| 161 |
+
embeddings = all_data['embeddings']
|
| 162 |
+
ids = all_data['ids']
|
| 163 |
+
|
| 164 |
+
batch_size = 100
|
| 165 |
+
total_items = len(documents)
|
| 166 |
+
|
| 167 |
+
for i in range(0, total_items, batch_size):
|
| 168 |
+
batch_docs = documents[i:i + batch_size]
|
| 169 |
+
batch_metas = metadatas[i:i + batch_size] if metadatas else None
|
| 170 |
+
batch_embeds = embeddings[i:i + batch_size] if embeddings else None
|
| 171 |
+
batch_ids = ids[i:i + batch_size]
|
| 172 |
+
|
| 173 |
+
if batch_embeds:
|
| 174 |
+
# Có embeddings sẵn, dùng luôn
|
| 175 |
+
self.collection.add(
|
| 176 |
+
documents=batch_docs,
|
| 177 |
+
metadatas=batch_metas,
|
| 178 |
+
embeddings=batch_embeds,
|
| 179 |
+
ids=batch_ids
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
# Tính lại embeddings
|
| 183 |
+
new_embeddings = self.encode(batch_docs, is_query=False)
|
| 184 |
+
self.collection.add(
|
| 185 |
+
documents=batch_docs,
|
| 186 |
+
metadatas=batch_metas,
|
| 187 |
+
embeddings=new_embeddings,
|
| 188 |
+
ids=batch_ids
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
logger.info(f"Migration progress: {min(i + batch_size, total_items)}/{total_items}")
|
| 192 |
+
|
| 193 |
+
logger.info(f"Migration hoàn thành! Đã chuyển {total_items} items sang cosine similarity")
|
| 194 |
+
|
| 195 |
+
except Exception as e:
|
| 196 |
+
logger.error(f"Lỗi migration: {e}")
|
| 197 |
+
# Tạo collection mới nếu migration thất bại
|
| 198 |
+
self._create_cosine_collection()
|
| 199 |
+
|
| 200 |
+
def test_embedding_quality(self):
|
| 201 |
+
try:
|
| 202 |
+
# Test cases
|
| 203 |
+
test_cases = [
|
| 204 |
+
("query: Tháp dinh dưỡng cho trẻ", "passage: Tháp dinh dưỡng cho trẻ từ 6-11 tuổi"),
|
| 205 |
+
("query: dinh dưỡng", "passage: dinh dưỡng cho học sinh"),
|
| 206 |
+
("query: xin chào", "passage: Tháp dinh dưỡng cho trẻ")
|
| 207 |
+
]
|
| 208 |
+
|
| 209 |
+
for query_text, doc_text in test_cases:
|
| 210 |
+
# Encode
|
| 211 |
+
query_emb = self.model.encode([query_text], normalize_embeddings=True)[0]
|
| 212 |
+
doc_emb = self.model.encode([doc_text], normalize_embeddings=True)[0]
|
| 213 |
+
|
| 214 |
+
# Calculate cosine similarity manually
|
| 215 |
+
import numpy as np
|
| 216 |
+
similarity = np.dot(query_emb, doc_emb)
|
| 217 |
+
|
| 218 |
+
logger.info(f"Query: {query_text}")
|
| 219 |
+
logger.info(f"Doc: {doc_text}")
|
| 220 |
+
logger.info(f"Similarity: {similarity:.3f}")
|
| 221 |
+
logger.info(f"Query norm: {np.linalg.norm(query_emb):.3f}")
|
| 222 |
+
logger.info(f"Doc norm: {np.linalg.norm(doc_emb):.3f}")
|
| 223 |
+
logger.info("-" * 50)
|
| 224 |
+
|
| 225 |
+
except Exception as e:
|
| 226 |
+
logger.error(f"Test embedding error: {e}")
|
| 227 |
|
| 228 |
def _add_prefix_to_text(self, text, is_query=True):
|
| 229 |
+
# Clean text trước
|
| 230 |
+
text = text.strip()
|
| 231 |
+
|
| 232 |
# Kiểm tra xem text đã có prefix chưa
|
| 233 |
if text.startswith(('query:', 'passage:')):
|
| 234 |
return text
|
|
|
|
| 241 |
|
| 242 |
def encode(self, texts, is_query=True):
|
| 243 |
"""
|
| 244 |
+
Encode văn bản thành embeddings với proper normalization
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
"""
|
| 246 |
try:
|
| 247 |
if isinstance(texts, str):
|
| 248 |
texts = [texts]
|
| 249 |
|
| 250 |
+
# Thêm prefix cho texts (QUAN TRỌNG cho multilingual-e5-base)
|
| 251 |
processed_texts = [self._add_prefix_to_text(text, is_query) for text in texts]
|
| 252 |
|
| 253 |
+
logger.debug(f"Đang encode {len(processed_texts)} văn bản với prefix")
|
| 254 |
+
logger.debug(f"Sample processed text: {processed_texts[0][:100]}...")
|
| 255 |
+
|
| 256 |
+
# Encode với normalize_embeddings=True (QUAN TRỌNG!)
|
| 257 |
+
embeddings = self.model.encode(
|
| 258 |
+
processed_texts,
|
| 259 |
+
show_progress_bar=False,
|
| 260 |
+
normalize_embeddings=True # ✅ THÊM DÒNG NÀY
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# Double-check normalization
|
| 264 |
+
import numpy as np
|
| 265 |
+
for i, emb in enumerate(embeddings[:2]): # Check first 2 embeddings
|
| 266 |
+
norm = np.linalg.norm(emb)
|
| 267 |
+
logger.debug(f"Embedding {i} norm: {norm}")
|
| 268 |
+
if abs(norm - 1.0) > 0.01:
|
| 269 |
+
logger.warning(f"Embedding {i} not properly normalized: norm = {norm}")
|
| 270 |
|
| 271 |
return embeddings.tolist()
|
| 272 |
|
|
|
|
| 275 |
raise
|
| 276 |
|
| 277 |
def search(self, query, top_k=5, age_filter=None):
|
| 278 |
+
"""Tìm kiếm văn bản tương tự trong ChromaDB"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
try:
|
|
|
|
|
|
|
|
|
|
| 280 |
query_embedding = self.encode(query, is_query=True)[0]
|
| 281 |
|
|
|
|
| 282 |
where_clause = None
|
| 283 |
if age_filter:
|
| 284 |
where_clause = {
|
|
|
|
| 287 |
{"age_max": {"$gte": age_filter}}
|
| 288 |
]
|
| 289 |
}
|
| 290 |
+
print(f"🔍 AGE FILTER: Tìm kiếm cho tuổi {age_filter}")
|
| 291 |
+
print(f"🔍 WHERE CLAUSE: {where_clause}")
|
| 292 |
+
else:
|
| 293 |
+
print(f"⚠️ KHÔNG CÓ AGE FILTER - Tìm tất cả chunks")
|
| 294 |
search_results = self.collection.query(
|
| 295 |
query_embeddings=[query_embedding],
|
| 296 |
n_results=top_k,
|
| 297 |
where=where_clause,
|
| 298 |
include=['documents', 'metadatas', 'distances']
|
| 299 |
)
|
| 300 |
+
|
| 301 |
+
print(f"\n{'='*60}")
|
| 302 |
+
print(f"📊 CHROMADB SEARCH RESULTS")
|
| 303 |
+
print(f"{'='*60}")
|
| 304 |
+
print(f"Query: {query}")
|
| 305 |
+
print(f"Age filter: {age_filter}")
|
| 306 |
+
print(f"Found {len(search_results['documents'][0]) if search_results['documents'] else 0} chunks")
|
| 307 |
+
print(f"{'='*60}")
|
| 308 |
|
| 309 |
if not search_results or not search_results['documents']:
|
| 310 |
+
logger.warning("Không tìm thấy kết quả nào")
|
| 311 |
return []
|
| 312 |
|
|
|
|
| 313 |
results = []
|
| 314 |
documents = search_results['documents'][0]
|
| 315 |
metadatas = search_results['metadatas'][0]
|
| 316 |
distances = search_results['distances'][0]
|
| 317 |
|
| 318 |
for i, (doc, metadata, distance) in enumerate(zip(documents, metadatas, distances)):
|
| 319 |
+
chunk_id = metadata.get('chunk_id', f'chunk_{i}')
|
| 320 |
+
title = metadata.get('title', 'No title')
|
| 321 |
+
age_range = metadata.get('age_range', 'Unknown')
|
| 322 |
+
age_min = metadata.get('age_min', 'N/A')
|
| 323 |
+
age_max = metadata.get('age_max', 'N/A')
|
| 324 |
+
content_type = metadata.get('content_type', 'text')
|
| 325 |
+
chapter = metadata.get('chapter', 'Unknown')
|
| 326 |
+
similarity = round(1 - distance, 3)
|
| 327 |
+
|
| 328 |
results.append({
|
| 329 |
'document': doc,
|
| 330 |
'metadata': metadata or {},
|
| 331 |
'distance': distance,
|
| 332 |
+
'similarity': similarity,
|
| 333 |
'rank': i + 1
|
| 334 |
})
|
| 335 |
|
| 336 |
+
print(f"\n{'='*60}")
|
| 337 |
logger.info(f"Tim thay {len(results)} ket qua cho query")
|
| 338 |
return results
|
| 339 |
|
|
|
|
| 342 |
return []
|
| 343 |
|
| 344 |
def add_documents(self, documents, metadatas=None, ids=None):
|
| 345 |
+
"""Thêm documents vào ChromaDB"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
try:
|
| 347 |
if not documents:
|
| 348 |
logger.warning("Không có documents để thêm")
|
| 349 |
return False
|
| 350 |
|
|
|
|
| 351 |
if not ids:
|
| 352 |
ids = [str(uuid.uuid4()) for _ in documents]
|
| 353 |
|
|
|
|
| 354 |
if not metadatas:
|
| 355 |
metadatas = [{} for _ in documents]
|
| 356 |
|
| 357 |
logger.info(f"Đang thêm {len(documents)} documents vào ChromaDB")
|
| 358 |
|
|
|
|
| 359 |
embeddings = self.encode(documents, is_query=False)
|
| 360 |
|
|
|
|
| 361 |
self.collection.add(
|
| 362 |
embeddings=embeddings,
|
| 363 |
documents=documents,
|
|
|
|
| 373 |
return False
|
| 374 |
|
| 375 |
def index_chunks(self, chunks):
|
| 376 |
+
"""Index các chunks dữ liệu vào ChromaDB"""
|
|
|
|
|
|
|
| 377 |
try:
|
| 378 |
if not chunks:
|
| 379 |
logger.warning("Không có chunks để index")
|
|
|
|
| 390 |
|
| 391 |
documents.append(chunk['content'])
|
| 392 |
|
|
|
|
| 393 |
metadata = chunk.get('metadata', {})
|
| 394 |
metadatas.append(metadata)
|
| 395 |
|
|
|
|
| 396 |
chunk_id = chunk.get('id') or str(uuid.uuid4())
|
| 397 |
ids.append(chunk_id)
|
| 398 |
|
|
|
|
| 400 |
logger.warning("Không có documents hợp lệ để index")
|
| 401 |
return False
|
| 402 |
|
|
|
|
| 403 |
batch_size = 100
|
| 404 |
total_batches = (len(documents) + batch_size - 1) // batch_size
|
| 405 |
|
|
|
|
| 437 |
logger.warning(f"Đang xóa collection: {COLLECTION_NAME}")
|
| 438 |
self.chroma_client.delete_collection(name=COLLECTION_NAME)
|
| 439 |
|
| 440 |
+
# Tạo lại collection với cosine similarity
|
| 441 |
+
self._create_cosine_collection()
|
| 442 |
+
logger.info("Đã tạo lại collection mới với cosine similarity")
|
| 443 |
|
| 444 |
return True
|
| 445 |
|
|
|
|
| 447 |
logger.error(f"Lỗi xóa collection: {e}")
|
| 448 |
return False
|
| 449 |
|
| 450 |
+
def get_collection_info(self):
|
| 451 |
+
"""Lấy thông tin về collection và distance function"""
|
| 452 |
+
try:
|
| 453 |
+
metadata = self.collection.metadata or {}
|
| 454 |
+
distance_func = metadata.get("hnsw:space", "l2")
|
| 455 |
+
|
| 456 |
+
return {
|
| 457 |
+
'collection_name': COLLECTION_NAME,
|
| 458 |
+
'distance_function': distance_func,
|
| 459 |
+
'total_documents': self.count(),
|
| 460 |
+
'metadata': metadata
|
| 461 |
+
}
|
| 462 |
+
except Exception as e:
|
| 463 |
+
logger.error(f"Lỗi lấy collection info: {e}")
|
| 464 |
+
return {'error': str(e)}
|
| 465 |
+
|
| 466 |
+
def verify_cosine_similarity(self):
|
| 467 |
+
"""Kiểm tra và xác nhận đang sử dụng cosine similarity"""
|
| 468 |
+
try:
|
| 469 |
+
info = self.get_collection_info()
|
| 470 |
+
distance_func = info.get('distance_function', 'unknown')
|
| 471 |
+
|
| 472 |
+
logger.info(f"Collection đang sử dụng distance function: {distance_func}")
|
| 473 |
+
|
| 474 |
+
if distance_func == "cosine":
|
| 475 |
+
logger.info("Xác nhận: Đang sử dụng cosine similarity")
|
| 476 |
+
return True
|
| 477 |
+
else:
|
| 478 |
+
logger.warning(f"Cảnh báo: Đang sử dụng {distance_func}, không phải cosine")
|
| 479 |
+
return False
|
| 480 |
+
|
| 481 |
+
except Exception as e:
|
| 482 |
+
logger.error(f"Lỗi verify cosine: {e}")
|
| 483 |
+
return False
|
| 484 |
+
|
| 485 |
def get_stats(self):
|
| 486 |
"""Lấy thống kê về collection"""
|
| 487 |
try:
|
| 488 |
total_count = self.count()
|
| 489 |
+
collection_info = self.get_collection_info()
|
| 490 |
|
|
|
|
| 491 |
sample_results = self.collection.get(limit=min(100, total_count))
|
| 492 |
|
|
|
|
| 493 |
content_types = {}
|
| 494 |
chapters = {}
|
| 495 |
age_groups = {}
|
|
|
|
| 499 |
if not metadata:
|
| 500 |
continue
|
| 501 |
|
|
|
|
| 502 |
content_type = metadata.get('content_type', 'unknown')
|
| 503 |
content_types[content_type] = content_types.get(content_type, 0) + 1
|
| 504 |
|
|
|
|
| 505 |
chapter = metadata.get('chapter', 'unknown')
|
| 506 |
chapters[chapter] = chapters.get(chapter, 0) + 1
|
| 507 |
|
|
|
|
| 508 |
age_group = metadata.get('age_group', 'unknown')
|
| 509 |
age_groups[age_group] = age_groups.get(age_group, 0) + 1
|
| 510 |
|
|
|
|
| 514 |
'chapters': chapters,
|
| 515 |
'age_groups': age_groups,
|
| 516 |
'collection_name': COLLECTION_NAME,
|
| 517 |
+
'embedding_model': EMBEDDING_MODEL,
|
| 518 |
+
'distance_function': collection_info.get('distance_function', 'unknown'),
|
| 519 |
+
'using_cosine_similarity': collection_info.get('distance_function') == 'cosine'
|
| 520 |
}
|
| 521 |
|
| 522 |
except Exception as e:
|
core/rag_pipeline.py
CHANGED
|
@@ -5,68 +5,56 @@ from config import GEMINI_API_KEY, HUMAN_PROMPT_TEMPLATE, SYSTEM_PROMPT, TOP_K_R
|
|
| 5 |
import os
|
| 6 |
import re
|
| 7 |
|
| 8 |
-
# Cấu hình logging
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
|
| 11 |
-
# Cấu hình Gemini
|
| 12 |
genai.configure(api_key=GEMINI_API_KEY)
|
| 13 |
|
| 14 |
class RAGPipeline:
|
| 15 |
def __init__(self):
|
| 16 |
-
|
| 17 |
-
logger.info("
|
| 18 |
|
| 19 |
self.embedding_model = get_embedding_model()
|
| 20 |
-
|
| 21 |
-
# Khởi tạo Gemini model
|
| 22 |
self.gemini_model = genai.GenerativeModel('gemini-2.0-flash')
|
| 23 |
|
| 24 |
-
logger.info("RAG Pipeline đã sẵn sàng")
|
| 25 |
|
| 26 |
def generate_response(self, query, age=1):
|
| 27 |
-
|
| 28 |
-
Generate response cho user query sử dụng RAG
|
| 29 |
-
|
| 30 |
-
Args:
|
| 31 |
-
query (str): Câu hỏi của người dùng
|
| 32 |
-
age (int): Tuổi của người dùng (1-19)
|
| 33 |
-
|
| 34 |
-
Returns:
|
| 35 |
-
dict: Response data with success status
|
| 36 |
-
"""
|
| 37 |
try:
|
| 38 |
-
logger.info(f"Bắt đầu
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
# SỬA: Chỉ search trong ChromaDB, không load lại dữ liệu
|
| 41 |
-
logger.info("Đang tìm kiếm thông tin liên quan...")
|
| 42 |
-
search_results = self.embedding_model.search(query, top_k=TOP_K_RESULTS)
|
| 43 |
if not search_results or len(search_results) == 0:
|
| 44 |
-
logger.warning("Không tìm thấy thông tin liên quan")
|
| 45 |
return {
|
| 46 |
"success": True,
|
| 47 |
"response": "Xin lỗi, tôi không tìm thấy thông tin liên quan đến câu hỏi của bạn trong tài liệu.",
|
| 48 |
"sources": []
|
| 49 |
}
|
| 50 |
|
| 51 |
-
# Chuẩn bị
|
| 52 |
contexts = []
|
| 53 |
sources = []
|
| 54 |
|
| 55 |
for result in search_results:
|
| 56 |
-
# Lấy thông tin từ metadata
|
| 57 |
metadata = result.get('metadata', {})
|
| 58 |
content = result.get('document', '')
|
| 59 |
|
| 60 |
-
# Thêm
|
| 61 |
contexts.append({
|
| 62 |
"content": content,
|
| 63 |
"metadata": metadata
|
| 64 |
})
|
| 65 |
|
| 66 |
-
#
|
| 67 |
source_info = {
|
| 68 |
-
"
|
| 69 |
-
"title": metadata.get('title', metadata.get('chapter', 'Tài liệu dinh dưỡng')), # Giữ title nếu cần
|
| 70 |
"pages": metadata.get('pages'),
|
| 71 |
"content_type": metadata.get('content_type', 'text')
|
| 72 |
}
|
|
@@ -74,14 +62,14 @@ class RAGPipeline:
|
|
| 74 |
if source_info not in sources:
|
| 75 |
sources.append(source_info)
|
| 76 |
|
| 77 |
-
#
|
| 78 |
formatted_contexts = self._format_contexts(contexts)
|
| 79 |
|
| 80 |
-
# Tạo prompt với
|
| 81 |
full_prompt = self._create_prompt_with_age_context(query, age, formatted_contexts)
|
| 82 |
|
| 83 |
-
#
|
| 84 |
-
logger.info("Đang tạo phản hồi với Gemini
|
| 85 |
response = self.gemini_model.generate_content(
|
| 86 |
full_prompt,
|
| 87 |
generation_config=genai.types.GenerationConfig(
|
|
@@ -91,7 +79,7 @@ class RAGPipeline:
|
|
| 91 |
)
|
| 92 |
|
| 93 |
if not response or not response.text:
|
| 94 |
-
logger.error("Gemini không trả về
|
| 95 |
return {
|
| 96 |
"success": False,
|
| 97 |
"error": "Không thể tạo phản hồi"
|
|
@@ -99,7 +87,7 @@ class RAGPipeline:
|
|
| 99 |
|
| 100 |
response_text = response.text.strip()
|
| 101 |
|
| 102 |
-
#
|
| 103 |
response_text = self._process_image_links(response_text)
|
| 104 |
|
| 105 |
logger.info("Đã tạo phản hồi thành công")
|
|
@@ -111,25 +99,23 @@ class RAGPipeline:
|
|
| 111 |
}
|
| 112 |
|
| 113 |
except Exception as e:
|
| 114 |
-
logger.error(f"Lỗi
|
| 115 |
return {
|
| 116 |
"success": False,
|
| 117 |
"error": f"Lỗi tạo phản hồi: {str(e)}"
|
| 118 |
}
|
| 119 |
|
| 120 |
def _format_contexts(self, contexts):
|
| 121 |
-
|
| 122 |
formatted = []
|
| 123 |
|
| 124 |
for i, context in enumerate(contexts, 1):
|
| 125 |
content = context['content']
|
| 126 |
metadata = context['metadata']
|
| 127 |
|
| 128 |
-
# Thêm thông tin metadata
|
| 129 |
context_str = f"[Tài liệu {i}]"
|
| 130 |
-
if metadata.get('
|
| 131 |
-
context_str += f" - ID: {metadata['chunk_id']}"
|
| 132 |
-
elif metadata.get('title'):
|
| 133 |
context_str += f" - {metadata['title']}"
|
| 134 |
if metadata.get('pages'):
|
| 135 |
context_str += f" (Trang {metadata['pages']})"
|
|
@@ -139,9 +125,8 @@ class RAGPipeline:
|
|
| 139 |
|
| 140 |
return "\n".join(formatted)
|
| 141 |
|
| 142 |
-
def _create_prompt_with_age_context(self, query, age, contexts):
|
| 143 |
-
|
| 144 |
-
# Xác định age group
|
| 145 |
if age <= 3:
|
| 146 |
age_guidance = "Sử dụng ngôn ngữ đơn giản, dễ hiểu cho phụ huynh có con nhỏ."
|
| 147 |
elif age <= 6:
|
|
@@ -153,7 +138,7 @@ class RAGPipeline:
|
|
| 153 |
else:
|
| 154 |
age_guidance = "Thông tin đầy đủ, chi tiết cho học sinh trung học phổ thông."
|
| 155 |
|
| 156 |
-
# Tạo system prompt
|
| 157 |
age_aware_system_prompt = f"""{SYSTEM_PROMPT}
|
| 158 |
|
| 159 |
QUAN TRỌNG - Hướng dẫn theo độ tuổi:
|
|
@@ -163,7 +148,7 @@ Người dùng hiện tại {age} tuổi. {age_guidance}
|
|
| 163 |
- Tránh thông tin quá phức tạp hoặc không phù hợp
|
| 164 |
"""
|
| 165 |
|
| 166 |
-
# Tạo human prompt
|
| 167 |
human_prompt = HUMAN_PROMPT_TEMPLATE.format(
|
| 168 |
query=query,
|
| 169 |
age=age,
|
|
@@ -173,30 +158,28 @@ Người dùng hiện tại {age} tuổi. {age_guidance}
|
|
| 173 |
return f"{age_aware_system_prompt}\n\n{human_prompt}"
|
| 174 |
|
| 175 |
def _process_image_links(self, response_text):
|
| 176 |
-
|
| 177 |
try:
|
| 178 |
import re
|
| 179 |
|
| 180 |
-
# Tìm các pattern markdown
|
| 181 |
image_pattern = r'!\[([^\]]*)\]\(([^)]+)\)'
|
| 182 |
|
| 183 |
def replace_image_path(match):
|
| 184 |
alt_text = match.group(1)
|
| 185 |
image_path = match.group(2)
|
| 186 |
|
| 187 |
-
# Xử lý đường dẫn local Windows/Linux
|
| 188 |
if '\\' in image_path or image_path.startswith('/') or ':' in image_path:
|
| 189 |
-
#
|
| 190 |
filename = image_path.split('\\')[-1].split('/')[-1]
|
| 191 |
|
| 192 |
-
# Tìm bai_id từ
|
| 193 |
bai_match = re.match(r'^(bai\d+)_', filename)
|
| 194 |
if bai_match:
|
| 195 |
bai_id = bai_match.group(1)
|
| 196 |
-
else:
|
| 197 |
-
bai_id = 'bai1' # default
|
| 198 |
|
| 199 |
-
# Tạo
|
| 200 |
api_url = f"/api/figures/{bai_id}/{filename}"
|
| 201 |
return f""
|
| 202 |
|
|
@@ -210,39 +193,29 @@ Người dùng hiện tại {age} tuổi. {age_guidance}
|
|
| 210 |
bai_match = re.match(r'^(bai\d+)_', filename)
|
| 211 |
if bai_match:
|
| 212 |
bai_id = bai_match.group(1)
|
| 213 |
-
else:
|
| 214 |
-
bai_id = 'bai1'
|
| 215 |
|
| 216 |
api_url = f"/api/figures/{bai_id}/{filename}"
|
| 217 |
return f""
|
| 218 |
-
|
| 219 |
-
# Các trường hợp khác, giữ nguyên
|
| 220 |
return match.group(0)
|
| 221 |
|
| 222 |
-
# Thay thế tất cả
|
| 223 |
processed_text = re.sub(image_pattern, replace_image_path, response_text)
|
| 224 |
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
| 226 |
return processed_text
|
| 227 |
|
| 228 |
except Exception as e:
|
| 229 |
-
logger.error(f"Lỗi xử lý
|
| 230 |
return response_text
|
| 231 |
|
| 232 |
def generate_follow_up_questions(self, query, answer, age=1):
|
| 233 |
-
|
| 234 |
-
Tạo câu hỏi gợi ý dựa trên query và answer
|
| 235 |
-
|
| 236 |
-
Args:
|
| 237 |
-
query (str): Câu hỏi gốc
|
| 238 |
-
answer (str): Câu trả lời đã được tạo
|
| 239 |
-
age (int): Tuổi người dùng
|
| 240 |
-
|
| 241 |
-
Returns:
|
| 242 |
-
dict: Response data với danh sách câu hỏi gợi ý
|
| 243 |
-
"""
|
| 244 |
try:
|
| 245 |
-
logger.info("Đang tạo câu hỏi
|
| 246 |
|
| 247 |
follow_up_prompt = f"""
|
| 248 |
Dựa trên cuộc hội thoại sau, hãy tạo 3-5 câu hỏi gợi ý phù hợp cho người dùng {age} tuổi về chủ đề dinh dưỡng:
|
|
@@ -273,27 +246,30 @@ Trả về danh sách câu hỏi, mỗi câu một dòng, không đánh số.
|
|
| 273 |
"error": "Không thể tạo câu hỏi gợi ý"
|
| 274 |
}
|
| 275 |
|
| 276 |
-
#
|
| 277 |
questions = []
|
| 278 |
lines = response.text.strip().split('\n')
|
| 279 |
|
| 280 |
for line in lines:
|
| 281 |
line = line.strip()
|
|
|
|
| 282 |
if line and not line.startswith('#') and len(line) > 10:
|
| 283 |
-
# Loại bỏ số thứ tự nếu có
|
| 284 |
line = re.sub(r'^\d+[\.\)]\s*', '', line)
|
| 285 |
questions.append(line)
|
| 286 |
|
| 287 |
-
# Giới hạn 5 câu hỏi
|
| 288 |
questions = questions[:5]
|
| 289 |
|
|
|
|
|
|
|
| 290 |
return {
|
| 291 |
"success": True,
|
| 292 |
"questions": questions
|
| 293 |
}
|
| 294 |
|
| 295 |
except Exception as e:
|
| 296 |
-
logger.error(f"Lỗi tạo
|
| 297 |
return {
|
| 298 |
"success": False,
|
| 299 |
"error": f"Lỗi tạo câu hỏi gợi ý: {str(e)}"
|
|
|
|
| 5 |
import os
|
| 6 |
import re
|
| 7 |
|
|
|
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
| 10 |
+
# Cấu hình Gemini API
|
| 11 |
genai.configure(api_key=GEMINI_API_KEY)
|
| 12 |
|
| 13 |
class RAGPipeline:
|
| 14 |
def __init__(self):
|
| 15 |
+
# Khởi tạo RAG Pipeline với embedding model
|
| 16 |
+
logger.info("Đang khởi tạo RAG Pipeline")
|
| 17 |
|
| 18 |
self.embedding_model = get_embedding_model()
|
|
|
|
|
|
|
| 19 |
self.gemini_model = genai.GenerativeModel('gemini-2.0-flash')
|
| 20 |
|
| 21 |
+
logger.info("RAG Pipeline đã sẵn sàng hoạt động")
|
| 22 |
|
| 23 |
def generate_response(self, query, age=1):
|
| 24 |
+
# Tạo phản hồi cho câu hỏi của người dùng sử dụng RAG
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
try:
|
| 26 |
+
logger.info(f"Bắt đầu tạo phản hồi cho câu hỏi: {query[:50]}... (tuổi: {age})")
|
| 27 |
+
|
| 28 |
+
# Tìm kiếm thông tin liên quan trong ChromaDB
|
| 29 |
+
logger.info("Đang tìm kiếm thông tin liên quan trong cơ sở dữ liệu")
|
| 30 |
+
search_results = self.embedding_model.search(query, top_k=TOP_K_RESULTS, age_filter=age)
|
| 31 |
+
# search_results = self.embedding_model.search(query, top_k=TOP_K_RESULTS)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
| 33 |
if not search_results or len(search_results) == 0:
|
| 34 |
+
logger.warning("Không tìm thấy thông tin liên quan trong cơ sở dữ liệu")
|
| 35 |
return {
|
| 36 |
"success": True,
|
| 37 |
"response": "Xin lỗi, tôi không tìm thấy thông tin liên quan đến câu hỏi của bạn trong tài liệu.",
|
| 38 |
"sources": []
|
| 39 |
}
|
| 40 |
|
| 41 |
+
# Chuẩn bị ngữ cảnh từ kết quả tìm kiếm
|
| 42 |
contexts = []
|
| 43 |
sources = []
|
| 44 |
|
| 45 |
for result in search_results:
|
|
|
|
| 46 |
metadata = result.get('metadata', {})
|
| 47 |
content = result.get('document', '')
|
| 48 |
|
| 49 |
+
# Thêm nội dung vào ngữ cảnh
|
| 50 |
contexts.append({
|
| 51 |
"content": content,
|
| 52 |
"metadata": metadata
|
| 53 |
})
|
| 54 |
|
| 55 |
+
# Tạo thông tin nguồn tài liệu
|
| 56 |
source_info = {
|
| 57 |
+
"title": metadata.get('title', metadata.get('chapter', 'Tài liệu dinh dưỡng')),
|
|
|
|
| 58 |
"pages": metadata.get('pages'),
|
| 59 |
"content_type": metadata.get('content_type', 'text')
|
| 60 |
}
|
|
|
|
| 62 |
if source_info not in sources:
|
| 63 |
sources.append(source_info)
|
| 64 |
|
| 65 |
+
# Định dạng ngữ cảnh cho prompt
|
| 66 |
formatted_contexts = self._format_contexts(contexts)
|
| 67 |
|
| 68 |
+
# Tạo prompt với ngữ cảnh độ tuổi
|
| 69 |
full_prompt = self._create_prompt_with_age_context(query, age, formatted_contexts)
|
| 70 |
|
| 71 |
+
# Tạo phản hồi với Gemini AI
|
| 72 |
+
logger.info("Đang tạo phản hồi với Gemini AI")
|
| 73 |
response = self.gemini_model.generate_content(
|
| 74 |
full_prompt,
|
| 75 |
generation_config=genai.types.GenerationConfig(
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
if not response or not response.text:
|
| 82 |
+
logger.error("Gemini AI không trả về phản hồi")
|
| 83 |
return {
|
| 84 |
"success": False,
|
| 85 |
"error": "Không thể tạo phản hồi"
|
|
|
|
| 87 |
|
| 88 |
response_text = response.text.strip()
|
| 89 |
|
| 90 |
+
# Xử lý các đường dẫn hình ảnh trong phản hồi
|
| 91 |
response_text = self._process_image_links(response_text)
|
| 92 |
|
| 93 |
logger.info("Đã tạo phản hồi thành công")
|
|
|
|
| 99 |
}
|
| 100 |
|
| 101 |
except Exception as e:
|
| 102 |
+
logger.error(f"Lỗi khi tạo phản hồi: {str(e)}")
|
| 103 |
return {
|
| 104 |
"success": False,
|
| 105 |
"error": f"Lỗi tạo phản hồi: {str(e)}"
|
| 106 |
}
|
| 107 |
|
| 108 |
def _format_contexts(self, contexts):
|
| 109 |
+
# Định dạng ngữ cảnh thành chuỗi cho prompt
|
| 110 |
formatted = []
|
| 111 |
|
| 112 |
for i, context in enumerate(contexts, 1):
|
| 113 |
content = context['content']
|
| 114 |
metadata = context['metadata']
|
| 115 |
|
| 116 |
+
# Thêm thông tin metadata vào ngữ cảnh
|
| 117 |
context_str = f"[Tài liệu {i}]"
|
| 118 |
+
if metadata.get('title'):
|
|
|
|
|
|
|
| 119 |
context_str += f" - {metadata['title']}"
|
| 120 |
if metadata.get('pages'):
|
| 121 |
context_str += f" (Trang {metadata['pages']})"
|
|
|
|
| 125 |
|
| 126 |
return "\n".join(formatted)
|
| 127 |
|
| 128 |
+
def _create_prompt_with_age_context(self, query, age, contexts):
|
| 129 |
+
# Xác định hướng dẫn theo nhóm tuổi
|
|
|
|
| 130 |
if age <= 3:
|
| 131 |
age_guidance = "Sử dụng ngôn ngữ đơn giản, dễ hiểu cho phụ huynh có con nhỏ."
|
| 132 |
elif age <= 6:
|
|
|
|
| 138 |
else:
|
| 139 |
age_guidance = "Thông tin đầy đủ, chi tiết cho học sinh trung học phổ thông."
|
| 140 |
|
| 141 |
+
# Tạo system prompt có tính đến độ tuổi
|
| 142 |
age_aware_system_prompt = f"""{SYSTEM_PROMPT}
|
| 143 |
|
| 144 |
QUAN TRỌNG - Hướng dẫn theo độ tuổi:
|
|
|
|
| 148 |
- Tránh thông tin quá phức tạp hoặc không phù hợp
|
| 149 |
"""
|
| 150 |
|
| 151 |
+
# Tạo human prompt từ template
|
| 152 |
human_prompt = HUMAN_PROMPT_TEMPLATE.format(
|
| 153 |
query=query,
|
| 154 |
age=age,
|
|
|
|
| 158 |
return f"{age_aware_system_prompt}\n\n{human_prompt}"
|
| 159 |
|
| 160 |
def _process_image_links(self, response_text):
|
| 161 |
+
# Xử lý và chuyển đổi các đường dẫn hình ảnh trong phản hồi
|
| 162 |
try:
|
| 163 |
import re
|
| 164 |
|
| 165 |
+
# Tìm các pattern markdown: 
|
| 166 |
image_pattern = r'!\[([^\]]*)\]\(([^)]+)\)'
|
| 167 |
|
| 168 |
def replace_image_path(match):
|
| 169 |
alt_text = match.group(1)
|
| 170 |
image_path = match.group(2)
|
| 171 |
|
| 172 |
+
# Xử lý đường dẫn local (Windows/Linux)
|
| 173 |
if '\\' in image_path or image_path.startswith('/') or ':' in image_path:
|
| 174 |
+
# Trích xuất tên file từ đường dẫn local
|
| 175 |
filename = image_path.split('\\')[-1].split('/')[-1]
|
| 176 |
|
| 177 |
+
# Tìm bai_id từ tên file (format: baiX_filename)
|
| 178 |
bai_match = re.match(r'^(bai\d+)_', filename)
|
| 179 |
if bai_match:
|
| 180 |
bai_id = bai_match.group(1)
|
|
|
|
|
|
|
| 181 |
|
| 182 |
+
# Tạo URL API
|
| 183 |
api_url = f"/api/figures/{bai_id}/{filename}"
|
| 184 |
return f""
|
| 185 |
|
|
|
|
| 193 |
bai_match = re.match(r'^(bai\d+)_', filename)
|
| 194 |
if bai_match:
|
| 195 |
bai_id = bai_match.group(1)
|
|
|
|
|
|
|
| 196 |
|
| 197 |
api_url = f"/api/figures/{bai_id}/{filename}"
|
| 198 |
return f""
|
| 199 |
+
|
|
|
|
| 200 |
return match.group(0)
|
| 201 |
|
| 202 |
+
# Thay thế tất cả các liên kết hình ảnh
|
| 203 |
processed_text = re.sub(image_pattern, replace_image_path, response_text)
|
| 204 |
|
| 205 |
+
image_count = len(re.findall(image_pattern, response_text))
|
| 206 |
+
if image_count > 0:
|
| 207 |
+
logger.info(f"Đã xử lý {image_count} liên kết hình ảnh")
|
| 208 |
+
|
| 209 |
return processed_text
|
| 210 |
|
| 211 |
except Exception as e:
|
| 212 |
+
logger.error(f"Lỗi khi xử lý liên kết hình ảnh: {e}")
|
| 213 |
return response_text
|
| 214 |
|
| 215 |
def generate_follow_up_questions(self, query, answer, age=1):
|
| 216 |
+
# Tạo câu hỏi gợi ý dựa trên cuộc hội thoại hiện tại
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
try:
|
| 218 |
+
logger.info("Đang tạo câu hỏi gợi ý")
|
| 219 |
|
| 220 |
follow_up_prompt = f"""
|
| 221 |
Dựa trên cuộc hội thoại sau, hãy tạo 3-5 câu hỏi gợi ý phù hợp cho người dùng {age} tuổi về chủ đề dinh dưỡng:
|
|
|
|
| 246 |
"error": "Không thể tạo câu hỏi gợi ý"
|
| 247 |
}
|
| 248 |
|
| 249 |
+
# Chuyển đổi phản hồi thành danh sách câu hỏi
|
| 250 |
questions = []
|
| 251 |
lines = response.text.strip().split('\n')
|
| 252 |
|
| 253 |
for line in lines:
|
| 254 |
line = line.strip()
|
| 255 |
+
# Lọc các dòng hợp lệ (không rỗng, không phải comment, đủ dài)
|
| 256 |
if line and not line.startswith('#') and len(line) > 10:
|
| 257 |
+
# Loại bỏ số thứ tự nếu có (1. 2. hoặc 1) 2))
|
| 258 |
line = re.sub(r'^\d+[\.\)]\s*', '', line)
|
| 259 |
questions.append(line)
|
| 260 |
|
| 261 |
+
# Giới hạn tối đa 5 câu hỏi
|
| 262 |
questions = questions[:5]
|
| 263 |
|
| 264 |
+
logger.info(f"Đã tạo {len(questions)} câu hỏi gợi ý")
|
| 265 |
+
|
| 266 |
return {
|
| 267 |
"success": True,
|
| 268 |
"questions": questions
|
| 269 |
}
|
| 270 |
|
| 271 |
except Exception as e:
|
| 272 |
+
logger.error(f"Lỗi khi tạo câu hỏi gợi ý: {str(e)}")
|
| 273 |
return {
|
| 274 |
"success": False,
|
| 275 |
"error": f"Lỗi tạo câu hỏi gợi ý: {str(e)}"
|