Nguyễn Quốc Vỹ
fix lỗi không up được tài liệu
e355040
"""
Module indexing: Tạo vector database bằng ChromaDB
Sử dụng multilingual-e5-base cho embedding tiếng Việt chất lượng cao.
"""
import os
import sys
import chromadb
from typing import List, Dict
import torch
from sentence_transformers import SentenceTransformer
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT_DIR not in sys.path:
sys.path.insert(0, ROOT_DIR)
from backend.runtime_paths import VECTOR_DIR
# Cấu hình ChromaDB
CHROMA_PERSIST_DIR = VECTOR_DIR
COLLECTION_NAME = "lich_su_viet_nam"
EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
# ======================== CUSTOM EMBEDDING ========================
class E5EmbeddingFunction:
"""
Embedding function cho model intfloat/multilingual-e5-base.
Model E5 yêu cầu prefix "query: " hoặc "passage: " trước mỗi text.
- Khi index tài liệu: dùng "passage: "
- Khi tìm kiếm: dùng "query: "
"""
def __init__(self, model_name: str = EMBEDDING_MODEL):
print(f"[Embedding] Loading model: {model_name} ...")
# Tránh lỗi PyTorch (HF Space / torch mới): "Cannot copy out of meta tensor"
# khi transformers dùng meta device + .to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"
self._model = SentenceTransformer(
model_name,
device=device,
model_kwargs={
"low_cpu_mem_usage": False,
"trust_remote_code": False,
},
)
self._mode = "query" # Mặc định là query (search)
print(f"[Embedding] ✅ Model loaded ({self._model.get_sentence_embedding_dimension()} dims)")
def name(self) -> str:
"""Tên ổn định để ChromaDB có thể persist/check embedding config."""
return f"e5_embedding_{EMBEDDING_MODEL}"
def set_mode(self, mode: str):
"""Chuyển mode: 'query' cho tìm kiếm, 'passage' cho index tài liệu."""
assert mode in ("query", "passage"), f"Mode phải là 'query' hoặc 'passage', nhận: {mode}"
self._mode = mode
def __call__(self, input: List[str]) -> List[List[float]]:
prefix = "query: " if self._mode == "query" else "passage: "
prefixed = [prefix + text for text in input]
embeddings = self._model.encode(prefixed, normalize_embeddings=True)
return embeddings.tolist()
def embed_query(self, input: List[str]) -> List[List[float]]:
"""Tương thích với interface embedding mới của ChromaDB khi query."""
self.set_mode("query")
return self.__call__(input)
def embed_documents(self, input: List[str]) -> List[List[float]]:
"""Tương thích với interface embedding mới của ChromaDB khi index."""
self.set_mode("passage")
return self.__call__(input)
# Singleton embedding function (tránh load model nhiều lần)
_embedding_fn_instance = None
def get_embedding_function() -> E5EmbeddingFunction:
"""Lấy embedding function (singleton, chỉ load model 1 lần)."""
global _embedding_fn_instance
if _embedding_fn_instance is None:
_embedding_fn_instance = E5EmbeddingFunction(EMBEDDING_MODEL)
return _embedding_fn_instance
def get_chroma_client():
"""Tạo ChromaDB client với persistent storage."""
os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)
client = chromadb.PersistentClient(path=CHROMA_PERSIST_DIR)
return client
def get_collection():
"""Lấy hoặc tạo collection trong ChromaDB."""
client = get_chroma_client()
embedding_fn = get_embedding_function()
# Đảm bảo mode query khi sử dụng collection bình thường
embedding_fn.set_mode("query")
collection = client.get_or_create_collection(
name=COLLECTION_NAME,
embedding_function=embedding_fn,
metadata={"hnsw:space": "cosine"}
)
return collection
def get_indexed_sources() -> set:
"""Trả về tập hợp tên file (source) đã được index trong ChromaDB."""
collection = get_collection()
total = collection.count()
if total == 0:
return set()
batch_size = 10000
sources: set = set()
for offset in range(0, total, batch_size):
result = collection.get(
limit=batch_size,
offset=offset,
include=["metadatas"],
)
for meta in result.get("metadatas", []):
src = (meta or {}).get("source")
if src:
sources.add(src)
return sources
def is_document_indexed(source_name: str) -> bool:
"""Kiểm tra xem tài liệu (theo tên file) đã được index chưa."""
collection = get_collection()
result = collection.get(
where={"source": source_name},
limit=1,
include=[],
)
return len(result.get("ids", [])) > 0
def delete_chunks_by_source(source_name: str) -> int:
"""Xóa tất cả chunk thuộc một tài liệu. Trả về số chunk đã xóa."""
collection = get_collection()
result = collection.get(
where={"source": source_name},
include=[],
)
ids_to_delete = result.get("ids", [])
if ids_to_delete:
collection.delete(ids=ids_to_delete)
print(f"[Index] 🗑️ Đã xóa {len(ids_to_delete)} chunks của '{source_name}'")
return len(ids_to_delete)
def _make_chunk_id(source: str, chunk_index: int) -> str:
"""Tạo ID ổn định cho chunk dựa trên tên nguồn + thứ tự."""
return f"{source}__chunk_{chunk_index}"
def create_vector_database(chunks: List[Dict]):
"""
Tạo vector database từ danh sách chunks.
Mỗi chunk có dạng: {"content": "...", "metadata": {...}}
ID mỗi chunk = "{source}__chunk_{i}" để tránh ghi đè giữa các tài liệu.
"""
if not chunks:
print("❌ Không có chunks để index!")
return
collection = get_collection()
embedding_fn = get_embedding_function()
embedding_fn.set_mode("passage")
documents = []
metadatas = []
ids = []
per_source_counter: Dict[str, int] = {}
for chunk in chunks:
content = chunk.get("content", "").strip()
if not content:
continue
metadata = chunk.get("metadata", {})
clean_metadata = {}
for k, v in metadata.items():
if isinstance(v, (str, int, float, bool)):
clean_metadata[k] = v
else:
clean_metadata[k] = str(v)
source = clean_metadata.get("source", "unknown")
idx = per_source_counter.get(source, 0)
per_source_counter[source] = idx + 1
documents.append(content)
metadatas.append(clean_metadata)
ids.append(_make_chunk_id(source, idx))
batch_size = 500
total = len(documents)
skipped_existing = 0
inserted_new = 0
for start in range(0, total, batch_size):
end = min(start + batch_size, total)
batch_ids = ids[start:end]
existing = collection.get(ids=batch_ids, include=[])
existing_ids = set(existing.get("ids", []) if existing else [])
filtered_docs = []
filtered_metas = []
filtered_ids = []
for doc, meta, chunk_id in zip(
documents[start:end],
metadatas[start:end],
batch_ids,
):
if chunk_id in existing_ids:
skipped_existing += 1
continue
filtered_docs.append(doc)
filtered_metas.append(meta)
filtered_ids.append(chunk_id)
if not filtered_ids:
continue
collection.upsert(
documents=filtered_docs,
metadatas=filtered_metas,
ids=filtered_ids
)
inserted_new += len(filtered_ids)
print(f" ✅ Đã index mới {inserted_new}/{total} chunks")
embedding_fn.set_mode("query")
print(f"\n✅ Tổng cộng {inserted_new} chunks mới đã được index vào ChromaDB")
if skipped_existing:
print(f"⏭️ Bỏ qua {skipped_existing} chunks đã tồn tại")
print(f"📁 Dữ liệu lưu tại: {CHROMA_PERSIST_DIR}")
print(f"🧠 Embedding model: {EMBEDDING_MODEL}")
def search(query: str, top_k: int = 5, max_distance: float = 0.8) -> List[Dict]:
"""
Tìm kiếm tài liệu liên quan đến câu hỏi.
ChromaDB cosine distance: 0 = giống nhất, 2 = khác nhất.
max_distance: ngưỡng tối đa, chỉ trả về kết quả có distance < max_distance.
"""
collection = get_collection()
# Đảm bảo query luôn dùng đúng prefix "query: "
get_embedding_function().set_mode("query")
if collection.count() == 0:
print("[Search] ⚠️ Database rỗng! Chạy run_pipeline.py trước.")
return []
results = collection.query(
query_texts=[query],
n_results=min(top_k * 2, 20), # Lấy nhiều hơn rồi lọc
include=["documents", "metadatas", "distances"]
)
search_results = []
if results and results["documents"]:
for doc, meta, dist in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0]
):
if dist < max_distance: # Chỉ lấy kết quả đủ tốt
search_results.append({
"content": doc,
"metadata": meta,
"score": dist
})
# Sắp xếp theo score (distance thấp = tốt hơn)
search_results.sort(key=lambda x: x["score"])
return search_results[:top_k]
def test_search():
"""Test tìm kiếm với một số câu hỏi mẫu."""
test_queries = [
"Trận Bạch Đằng năm 938",
"Triều đại nhà Lý",
"Chiến thắng Điện Biên Phủ",
"Vua Quang Trung đại phá quân Thanh",
"Cách mạng tháng Tám 1945"
]
collection = get_collection()
total_chunks = collection.count()
print(f"\n📊 Tổng số chunks trong database: {total_chunks}")
if total_chunks == 0:
print("⚠️ Database trống!")
return
for query in test_queries:
print(f"\n🔍 Query: '{query}'")
results = search(query, top_k=3)
for j, r in enumerate(results):
score = r["score"]
content_preview = r["content"][:100] + "..."
print(f" [{j+1}] (score: {score:.4f}) {content_preview}")
def delete_collection():
"""Xóa toàn bộ collection trong ChromaDB."""
client = get_chroma_client()
try:
client.delete_collection(COLLECTION_NAME)
print(f"✅ Đã xóa collection '{COLLECTION_NAME}'")
except Exception as e:
print(f"⚠️ Lỗi khi xóa collection: {e}")
def get_stats() -> Dict:
"""Lấy thống kê về database."""
collection = get_collection()
return {
"collection_name": COLLECTION_NAME,
"total_chunks": collection.count(),
"persist_dir": CHROMA_PERSIST_DIR,
"embedding_model": EMBEDDING_MODEL
}