Spaces:
Running
Running
File size: 11,154 Bytes
b655c88 e355040 b655c88 e355040 b655c88 0bcda63 b655c88 0bcda63 b655c88 47738d8 b655c88 0bcda63 b655c88 0bcda63 b655c88 0bcda63 b655c88 47738d8 b655c88 47738d8 b655c88 47738d8 b655c88 47738d8 b655c88 47738d8 b655c88 47738d8 b655c88 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 | """
Module indexing: Tạo vector database bằng ChromaDB
Sử dụng multilingual-e5-base cho embedding tiếng Việt chất lượng cao.
"""
import os
import sys
import chromadb
from typing import List, Dict
import torch
from sentence_transformers import SentenceTransformer
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT_DIR not in sys.path:
sys.path.insert(0, ROOT_DIR)
from backend.runtime_paths import VECTOR_DIR
# Cấu hình ChromaDB
CHROMA_PERSIST_DIR = VECTOR_DIR
COLLECTION_NAME = "lich_su_viet_nam"
EMBEDDING_MODEL = "intfloat/multilingual-e5-base"
# ======================== CUSTOM EMBEDDING ========================
class E5EmbeddingFunction:
"""
Embedding function cho model intfloat/multilingual-e5-base.
Model E5 yêu cầu prefix "query: " hoặc "passage: " trước mỗi text.
- Khi index tài liệu: dùng "passage: "
- Khi tìm kiếm: dùng "query: "
"""
def __init__(self, model_name: str = EMBEDDING_MODEL):
print(f"[Embedding] Loading model: {model_name} ...")
# Tránh lỗi PyTorch (HF Space / torch mới): "Cannot copy out of meta tensor"
# khi transformers dùng meta device + .to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"
self._model = SentenceTransformer(
model_name,
device=device,
model_kwargs={
"low_cpu_mem_usage": False,
"trust_remote_code": False,
},
)
self._mode = "query" # Mặc định là query (search)
print(f"[Embedding] ✅ Model loaded ({self._model.get_sentence_embedding_dimension()} dims)")
def name(self) -> str:
"""Tên ổn định để ChromaDB có thể persist/check embedding config."""
return f"e5_embedding_{EMBEDDING_MODEL}"
def set_mode(self, mode: str):
"""Chuyển mode: 'query' cho tìm kiếm, 'passage' cho index tài liệu."""
assert mode in ("query", "passage"), f"Mode phải là 'query' hoặc 'passage', nhận: {mode}"
self._mode = mode
def __call__(self, input: List[str]) -> List[List[float]]:
prefix = "query: " if self._mode == "query" else "passage: "
prefixed = [prefix + text for text in input]
embeddings = self._model.encode(prefixed, normalize_embeddings=True)
return embeddings.tolist()
def embed_query(self, input: List[str]) -> List[List[float]]:
"""Tương thích với interface embedding mới của ChromaDB khi query."""
self.set_mode("query")
return self.__call__(input)
def embed_documents(self, input: List[str]) -> List[List[float]]:
"""Tương thích với interface embedding mới của ChromaDB khi index."""
self.set_mode("passage")
return self.__call__(input)
# Singleton embedding function (tránh load model nhiều lần)
_embedding_fn_instance = None
def get_embedding_function() -> E5EmbeddingFunction:
"""Lấy embedding function (singleton, chỉ load model 1 lần)."""
global _embedding_fn_instance
if _embedding_fn_instance is None:
_embedding_fn_instance = E5EmbeddingFunction(EMBEDDING_MODEL)
return _embedding_fn_instance
def get_chroma_client():
"""Tạo ChromaDB client với persistent storage."""
os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)
client = chromadb.PersistentClient(path=CHROMA_PERSIST_DIR)
return client
def get_collection():
"""Lấy hoặc tạo collection trong ChromaDB."""
client = get_chroma_client()
embedding_fn = get_embedding_function()
# Đảm bảo mode query khi sử dụng collection bình thường
embedding_fn.set_mode("query")
collection = client.get_or_create_collection(
name=COLLECTION_NAME,
embedding_function=embedding_fn,
metadata={"hnsw:space": "cosine"}
)
return collection
def get_indexed_sources() -> set:
"""Trả về tập hợp tên file (source) đã được index trong ChromaDB."""
collection = get_collection()
total = collection.count()
if total == 0:
return set()
batch_size = 10000
sources: set = set()
for offset in range(0, total, batch_size):
result = collection.get(
limit=batch_size,
offset=offset,
include=["metadatas"],
)
for meta in result.get("metadatas", []):
src = (meta or {}).get("source")
if src:
sources.add(src)
return sources
def is_document_indexed(source_name: str) -> bool:
"""Kiểm tra xem tài liệu (theo tên file) đã được index chưa."""
collection = get_collection()
result = collection.get(
where={"source": source_name},
limit=1,
include=[],
)
return len(result.get("ids", [])) > 0
def delete_chunks_by_source(source_name: str) -> int:
"""Xóa tất cả chunk thuộc một tài liệu. Trả về số chunk đã xóa."""
collection = get_collection()
result = collection.get(
where={"source": source_name},
include=[],
)
ids_to_delete = result.get("ids", [])
if ids_to_delete:
collection.delete(ids=ids_to_delete)
print(f"[Index] 🗑️ Đã xóa {len(ids_to_delete)} chunks của '{source_name}'")
return len(ids_to_delete)
def _make_chunk_id(source: str, chunk_index: int) -> str:
"""Tạo ID ổn định cho chunk dựa trên tên nguồn + thứ tự."""
return f"{source}__chunk_{chunk_index}"
def create_vector_database(chunks: List[Dict]):
"""
Tạo vector database từ danh sách chunks.
Mỗi chunk có dạng: {"content": "...", "metadata": {...}}
ID mỗi chunk = "{source}__chunk_{i}" để tránh ghi đè giữa các tài liệu.
"""
if not chunks:
print("❌ Không có chunks để index!")
return
collection = get_collection()
embedding_fn = get_embedding_function()
embedding_fn.set_mode("passage")
documents = []
metadatas = []
ids = []
per_source_counter: Dict[str, int] = {}
for chunk in chunks:
content = chunk.get("content", "").strip()
if not content:
continue
metadata = chunk.get("metadata", {})
clean_metadata = {}
for k, v in metadata.items():
if isinstance(v, (str, int, float, bool)):
clean_metadata[k] = v
else:
clean_metadata[k] = str(v)
source = clean_metadata.get("source", "unknown")
idx = per_source_counter.get(source, 0)
per_source_counter[source] = idx + 1
documents.append(content)
metadatas.append(clean_metadata)
ids.append(_make_chunk_id(source, idx))
batch_size = 500
total = len(documents)
skipped_existing = 0
inserted_new = 0
for start in range(0, total, batch_size):
end = min(start + batch_size, total)
batch_ids = ids[start:end]
existing = collection.get(ids=batch_ids, include=[])
existing_ids = set(existing.get("ids", []) if existing else [])
filtered_docs = []
filtered_metas = []
filtered_ids = []
for doc, meta, chunk_id in zip(
documents[start:end],
metadatas[start:end],
batch_ids,
):
if chunk_id in existing_ids:
skipped_existing += 1
continue
filtered_docs.append(doc)
filtered_metas.append(meta)
filtered_ids.append(chunk_id)
if not filtered_ids:
continue
collection.upsert(
documents=filtered_docs,
metadatas=filtered_metas,
ids=filtered_ids
)
inserted_new += len(filtered_ids)
print(f" ✅ Đã index mới {inserted_new}/{total} chunks")
embedding_fn.set_mode("query")
print(f"\n✅ Tổng cộng {inserted_new} chunks mới đã được index vào ChromaDB")
if skipped_existing:
print(f"⏭️ Bỏ qua {skipped_existing} chunks đã tồn tại")
print(f"📁 Dữ liệu lưu tại: {CHROMA_PERSIST_DIR}")
print(f"🧠 Embedding model: {EMBEDDING_MODEL}")
def search(query: str, top_k: int = 5, max_distance: float = 0.8) -> List[Dict]:
"""
Tìm kiếm tài liệu liên quan đến câu hỏi.
ChromaDB cosine distance: 0 = giống nhất, 2 = khác nhất.
max_distance: ngưỡng tối đa, chỉ trả về kết quả có distance < max_distance.
"""
collection = get_collection()
# Đảm bảo query luôn dùng đúng prefix "query: "
get_embedding_function().set_mode("query")
if collection.count() == 0:
print("[Search] ⚠️ Database rỗng! Chạy run_pipeline.py trước.")
return []
results = collection.query(
query_texts=[query],
n_results=min(top_k * 2, 20), # Lấy nhiều hơn rồi lọc
include=["documents", "metadatas", "distances"]
)
search_results = []
if results and results["documents"]:
for doc, meta, dist in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0]
):
if dist < max_distance: # Chỉ lấy kết quả đủ tốt
search_results.append({
"content": doc,
"metadata": meta,
"score": dist
})
# Sắp xếp theo score (distance thấp = tốt hơn)
search_results.sort(key=lambda x: x["score"])
return search_results[:top_k]
def test_search():
"""Test tìm kiếm với một số câu hỏi mẫu."""
test_queries = [
"Trận Bạch Đằng năm 938",
"Triều đại nhà Lý",
"Chiến thắng Điện Biên Phủ",
"Vua Quang Trung đại phá quân Thanh",
"Cách mạng tháng Tám 1945"
]
collection = get_collection()
total_chunks = collection.count()
print(f"\n📊 Tổng số chunks trong database: {total_chunks}")
if total_chunks == 0:
print("⚠️ Database trống!")
return
for query in test_queries:
print(f"\n🔍 Query: '{query}'")
results = search(query, top_k=3)
for j, r in enumerate(results):
score = r["score"]
content_preview = r["content"][:100] + "..."
print(f" [{j+1}] (score: {score:.4f}) {content_preview}")
def delete_collection():
"""Xóa toàn bộ collection trong ChromaDB."""
client = get_chroma_client()
try:
client.delete_collection(COLLECTION_NAME)
print(f"✅ Đã xóa collection '{COLLECTION_NAME}'")
except Exception as e:
print(f"⚠️ Lỗi khi xóa collection: {e}")
def get_stats() -> Dict:
"""Lấy thống kê về database."""
collection = get_collection()
return {
"collection_name": COLLECTION_NAME,
"total_chunks": collection.count(),
"persist_dir": CHROMA_PERSIST_DIR,
"embedding_model": EMBEDDING_MODEL
} |