Spaces:
Running
Running
| """ | |
| Module indexing: Tạo vector database bằng ChromaDB | |
| Sử dụng multilingual-e5-base cho embedding tiếng Việt chất lượng cao. | |
| """ | |
| import os | |
| import sys | |
| import chromadb | |
| from typing import List, Dict | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| if ROOT_DIR not in sys.path: | |
| sys.path.insert(0, ROOT_DIR) | |
| from backend.runtime_paths import VECTOR_DIR | |
| # Cấu hình ChromaDB | |
| CHROMA_PERSIST_DIR = VECTOR_DIR | |
| COLLECTION_NAME = "lich_su_viet_nam" | |
| EMBEDDING_MODEL = "intfloat/multilingual-e5-base" | |
| # ======================== CUSTOM EMBEDDING ======================== | |
| class E5EmbeddingFunction: | |
| """ | |
| Embedding function cho model intfloat/multilingual-e5-base. | |
| Model E5 yêu cầu prefix "query: " hoặc "passage: " trước mỗi text. | |
| - Khi index tài liệu: dùng "passage: " | |
| - Khi tìm kiếm: dùng "query: " | |
| """ | |
| def __init__(self, model_name: str = EMBEDDING_MODEL): | |
| print(f"[Embedding] Loading model: {model_name} ...") | |
| # Tránh lỗi PyTorch (HF Space / torch mới): "Cannot copy out of meta tensor" | |
| # khi transformers dùng meta device + .to(device). | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| self._model = SentenceTransformer( | |
| model_name, | |
| device=device, | |
| model_kwargs={ | |
| "low_cpu_mem_usage": False, | |
| "trust_remote_code": False, | |
| }, | |
| ) | |
| self._mode = "query" # Mặc định là query (search) | |
| print(f"[Embedding] ✅ Model loaded ({self._model.get_sentence_embedding_dimension()} dims)") | |
| def name(self) -> str: | |
| """Tên ổn định để ChromaDB có thể persist/check embedding config.""" | |
| return f"e5_embedding_{EMBEDDING_MODEL}" | |
| def set_mode(self, mode: str): | |
| """Chuyển mode: 'query' cho tìm kiếm, 'passage' cho index tài liệu.""" | |
| assert mode in ("query", "passage"), f"Mode phải là 'query' hoặc 'passage', nhận: {mode}" | |
| self._mode = mode | |
| def __call__(self, input: List[str]) -> List[List[float]]: | |
| prefix = "query: " if self._mode == "query" else "passage: " | |
| prefixed = [prefix + text for text in input] | |
| embeddings = self._model.encode(prefixed, normalize_embeddings=True) | |
| return embeddings.tolist() | |
| def embed_query(self, input: List[str]) -> List[List[float]]: | |
| """Tương thích với interface embedding mới của ChromaDB khi query.""" | |
| self.set_mode("query") | |
| return self.__call__(input) | |
| def embed_documents(self, input: List[str]) -> List[List[float]]: | |
| """Tương thích với interface embedding mới của ChromaDB khi index.""" | |
| self.set_mode("passage") | |
| return self.__call__(input) | |
| # Singleton embedding function (tránh load model nhiều lần) | |
| _embedding_fn_instance = None | |
| def get_embedding_function() -> E5EmbeddingFunction: | |
| """Lấy embedding function (singleton, chỉ load model 1 lần).""" | |
| global _embedding_fn_instance | |
| if _embedding_fn_instance is None: | |
| _embedding_fn_instance = E5EmbeddingFunction(EMBEDDING_MODEL) | |
| return _embedding_fn_instance | |
| def get_chroma_client(): | |
| """Tạo ChromaDB client với persistent storage.""" | |
| os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True) | |
| client = chromadb.PersistentClient(path=CHROMA_PERSIST_DIR) | |
| return client | |
| def get_collection(): | |
| """Lấy hoặc tạo collection trong ChromaDB.""" | |
| client = get_chroma_client() | |
| embedding_fn = get_embedding_function() | |
| # Đảm bảo mode query khi sử dụng collection bình thường | |
| embedding_fn.set_mode("query") | |
| collection = client.get_or_create_collection( | |
| name=COLLECTION_NAME, | |
| embedding_function=embedding_fn, | |
| metadata={"hnsw:space": "cosine"} | |
| ) | |
| return collection | |
| def get_indexed_sources() -> set: | |
| """Trả về tập hợp tên file (source) đã được index trong ChromaDB.""" | |
| collection = get_collection() | |
| total = collection.count() | |
| if total == 0: | |
| return set() | |
| batch_size = 10000 | |
| sources: set = set() | |
| for offset in range(0, total, batch_size): | |
| result = collection.get( | |
| limit=batch_size, | |
| offset=offset, | |
| include=["metadatas"], | |
| ) | |
| for meta in result.get("metadatas", []): | |
| src = (meta or {}).get("source") | |
| if src: | |
| sources.add(src) | |
| return sources | |
| def is_document_indexed(source_name: str) -> bool: | |
| """Kiểm tra xem tài liệu (theo tên file) đã được index chưa.""" | |
| collection = get_collection() | |
| result = collection.get( | |
| where={"source": source_name}, | |
| limit=1, | |
| include=[], | |
| ) | |
| return len(result.get("ids", [])) > 0 | |
| def delete_chunks_by_source(source_name: str) -> int: | |
| """Xóa tất cả chunk thuộc một tài liệu. Trả về số chunk đã xóa.""" | |
| collection = get_collection() | |
| result = collection.get( | |
| where={"source": source_name}, | |
| include=[], | |
| ) | |
| ids_to_delete = result.get("ids", []) | |
| if ids_to_delete: | |
| collection.delete(ids=ids_to_delete) | |
| print(f"[Index] 🗑️ Đã xóa {len(ids_to_delete)} chunks của '{source_name}'") | |
| return len(ids_to_delete) | |
| def _make_chunk_id(source: str, chunk_index: int) -> str: | |
| """Tạo ID ổn định cho chunk dựa trên tên nguồn + thứ tự.""" | |
| return f"{source}__chunk_{chunk_index}" | |
| def create_vector_database(chunks: List[Dict]): | |
| """ | |
| Tạo vector database từ danh sách chunks. | |
| Mỗi chunk có dạng: {"content": "...", "metadata": {...}} | |
| ID mỗi chunk = "{source}__chunk_{i}" để tránh ghi đè giữa các tài liệu. | |
| """ | |
| if not chunks: | |
| print("❌ Không có chunks để index!") | |
| return | |
| collection = get_collection() | |
| embedding_fn = get_embedding_function() | |
| embedding_fn.set_mode("passage") | |
| documents = [] | |
| metadatas = [] | |
| ids = [] | |
| per_source_counter: Dict[str, int] = {} | |
| for chunk in chunks: | |
| content = chunk.get("content", "").strip() | |
| if not content: | |
| continue | |
| metadata = chunk.get("metadata", {}) | |
| clean_metadata = {} | |
| for k, v in metadata.items(): | |
| if isinstance(v, (str, int, float, bool)): | |
| clean_metadata[k] = v | |
| else: | |
| clean_metadata[k] = str(v) | |
| source = clean_metadata.get("source", "unknown") | |
| idx = per_source_counter.get(source, 0) | |
| per_source_counter[source] = idx + 1 | |
| documents.append(content) | |
| metadatas.append(clean_metadata) | |
| ids.append(_make_chunk_id(source, idx)) | |
| batch_size = 500 | |
| total = len(documents) | |
| skipped_existing = 0 | |
| inserted_new = 0 | |
| for start in range(0, total, batch_size): | |
| end = min(start + batch_size, total) | |
| batch_ids = ids[start:end] | |
| existing = collection.get(ids=batch_ids, include=[]) | |
| existing_ids = set(existing.get("ids", []) if existing else []) | |
| filtered_docs = [] | |
| filtered_metas = [] | |
| filtered_ids = [] | |
| for doc, meta, chunk_id in zip( | |
| documents[start:end], | |
| metadatas[start:end], | |
| batch_ids, | |
| ): | |
| if chunk_id in existing_ids: | |
| skipped_existing += 1 | |
| continue | |
| filtered_docs.append(doc) | |
| filtered_metas.append(meta) | |
| filtered_ids.append(chunk_id) | |
| if not filtered_ids: | |
| continue | |
| collection.upsert( | |
| documents=filtered_docs, | |
| metadatas=filtered_metas, | |
| ids=filtered_ids | |
| ) | |
| inserted_new += len(filtered_ids) | |
| print(f" ✅ Đã index mới {inserted_new}/{total} chunks") | |
| embedding_fn.set_mode("query") | |
| print(f"\n✅ Tổng cộng {inserted_new} chunks mới đã được index vào ChromaDB") | |
| if skipped_existing: | |
| print(f"⏭️ Bỏ qua {skipped_existing} chunks đã tồn tại") | |
| print(f"📁 Dữ liệu lưu tại: {CHROMA_PERSIST_DIR}") | |
| print(f"🧠 Embedding model: {EMBEDDING_MODEL}") | |
| def search(query: str, top_k: int = 5, max_distance: float = 0.8) -> List[Dict]: | |
| """ | |
| Tìm kiếm tài liệu liên quan đến câu hỏi. | |
| ChromaDB cosine distance: 0 = giống nhất, 2 = khác nhất. | |
| max_distance: ngưỡng tối đa, chỉ trả về kết quả có distance < max_distance. | |
| """ | |
| collection = get_collection() | |
| # Đảm bảo query luôn dùng đúng prefix "query: " | |
| get_embedding_function().set_mode("query") | |
| if collection.count() == 0: | |
| print("[Search] ⚠️ Database rỗng! Chạy run_pipeline.py trước.") | |
| return [] | |
| results = collection.query( | |
| query_texts=[query], | |
| n_results=min(top_k * 2, 20), # Lấy nhiều hơn rồi lọc | |
| include=["documents", "metadatas", "distances"] | |
| ) | |
| search_results = [] | |
| if results and results["documents"]: | |
| for doc, meta, dist in zip( | |
| results["documents"][0], | |
| results["metadatas"][0], | |
| results["distances"][0] | |
| ): | |
| if dist < max_distance: # Chỉ lấy kết quả đủ tốt | |
| search_results.append({ | |
| "content": doc, | |
| "metadata": meta, | |
| "score": dist | |
| }) | |
| # Sắp xếp theo score (distance thấp = tốt hơn) | |
| search_results.sort(key=lambda x: x["score"]) | |
| return search_results[:top_k] | |
| def test_search(): | |
| """Test tìm kiếm với một số câu hỏi mẫu.""" | |
| test_queries = [ | |
| "Trận Bạch Đằng năm 938", | |
| "Triều đại nhà Lý", | |
| "Chiến thắng Điện Biên Phủ", | |
| "Vua Quang Trung đại phá quân Thanh", | |
| "Cách mạng tháng Tám 1945" | |
| ] | |
| collection = get_collection() | |
| total_chunks = collection.count() | |
| print(f"\n📊 Tổng số chunks trong database: {total_chunks}") | |
| if total_chunks == 0: | |
| print("⚠️ Database trống!") | |
| return | |
| for query in test_queries: | |
| print(f"\n🔍 Query: '{query}'") | |
| results = search(query, top_k=3) | |
| for j, r in enumerate(results): | |
| score = r["score"] | |
| content_preview = r["content"][:100] + "..." | |
| print(f" [{j+1}] (score: {score:.4f}) {content_preview}") | |
| def delete_collection(): | |
| """Xóa toàn bộ collection trong ChromaDB.""" | |
| client = get_chroma_client() | |
| try: | |
| client.delete_collection(COLLECTION_NAME) | |
| print(f"✅ Đã xóa collection '{COLLECTION_NAME}'") | |
| except Exception as e: | |
| print(f"⚠️ Lỗi khi xóa collection: {e}") | |
| def get_stats() -> Dict: | |
| """Lấy thống kê về database.""" | |
| collection = get_collection() | |
| return { | |
| "collection_name": COLLECTION_NAME, | |
| "total_chunks": collection.count(), | |
| "persist_dir": CHROMA_PERSIST_DIR, | |
| "embedding_model": EMBEDDING_MODEL | |
| } |