from __future__ import annotations import json import logging from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Sequence from langchain_core.documents import Document from langchain_chroma import Chroma from core.hash_file.hash_file import HashProcessor logger = logging.getLogger(__name__) @dataclass class ChromaConfig: """Cấu hình cho ChromaDB.""" def _default_persist_dir() -> str: """Lấy đường dẫn mặc định cho persist directory.""" repo_root = Path(__file__).resolve().parents[2] return str((repo_root / "data" / "chroma").resolve()) persist_dir: str = field(default_factory=_default_persist_dir) # Thư mục lưu DB collection_name: str = "hust_rag_collection" # Tên collection class ChromaVectorDB: """Wrapper cho ChromaDB với hỗ trợ Small-to-Big retrieval.""" def __init__( self, embedder: Any, config: ChromaConfig | None = None, ): """Khởi tạo ChromaDB với embedder và config.""" self.embedder = embedder self.config = config or ChromaConfig() self._hasher = HashProcessor(verbose=False) # Lưu trữ parent nodes (không embed, dùng cho Small-to-Big) self._parent_nodes_path = Path(self.config.persist_dir) / "parent_nodes.json" self._parent_nodes: Dict[str, Dict[str, Any]] = self._load_parent_nodes() # Khởi tạo ChromaDB self._vs = Chroma( collection_name=self.config.collection_name, embedding_function=self.embedder, persist_directory=self.config.persist_dir, ) logger.info(f"Đã khởi tạo ChromaVectorDB: {self.config.collection_name}") def _load_parent_nodes(self) -> Dict[str, Dict[str, Any]]: """Tải parent nodes từ file JSON.""" if self._parent_nodes_path.exists(): try: with open(self._parent_nodes_path, 'r', encoding='utf-8') as f: data = json.load(f) logger.info(f"Đã tải {len(data)} parent nodes từ {self._parent_nodes_path}") return data except Exception as e: logger.warning(f"Không thể tải parent nodes: {e}") return {} def _save_parent_nodes(self) -> None: """Lưu parent nodes vào file JSON.""" try: self._parent_nodes_path.parent.mkdir(parents=True, exist_ok=True) with open(self._parent_nodes_path, 'w', encoding='utf-8') as f: json.dump(self._parent_nodes, f, ensure_ascii=False, indent=2) logger.info(f"Đã lưu {len(self._parent_nodes)} parent nodes vào {self._parent_nodes_path}") except Exception as e: logger.warning(f"Không thể lưu parent nodes: {e}") @property def collection(self): """Lấy collection gốc của ChromaDB.""" return getattr(self._vs, "_collection", None) @property def vectorstore(self): """Lấy LangChain Chroma vectorstore.""" return self._vs def _flatten_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: """Chuyển metadata phức tạp thành format ChromaDB hỗ trợ.""" out: Dict[str, Any] = {} for k, v in (metadata or {}).items(): if v is None: continue if isinstance(v, (str, int, float, bool)): out[str(k)] = v elif isinstance(v, (list, tuple, set, dict)): # Chuyển list/dict thành JSON string out[str(k)] = json.dumps(v, ensure_ascii=False) else: out[str(k)] = str(v) return out def _normalize_doc(self, doc: Any) -> Dict[str, Any]: """Chuẩn hóa document từ nhiều format khác nhau thành dict.""" # Đã là dict if isinstance(doc, dict): return doc # TextNode/BaseNode từ llama_index if hasattr(doc, "get_content") and hasattr(doc, "metadata"): return { "content": doc.get_content(), "metadata": dict(doc.metadata) if doc.metadata else {}, } # Document từ LangChain if hasattr(doc, "page_content") and hasattr(doc, "metadata"): return { "content": doc.page_content, "metadata": dict(doc.metadata) if doc.metadata else {}, } raise TypeError(f"Không hỗ trợ loại document: {type(doc)}") def _to_documents(self, docs: Sequence[Any], ids: Sequence[str]) -> List[Document]: """Chuyển danh sách docs thành LangChain Documents.""" out: List[Document] = [] for d, doc_id in zip(docs, ids): normalized = self._normalize_doc(d) md = self._flatten_metadata(normalized.get("metadata", {}) or {}) md.setdefault("id", doc_id) out.append(Document(page_content=normalized.get("content", ""), metadata=md)) return out def _doc_id(self, doc: Any) -> str: """Tạo ID duy nhất cho document dựa trên nội dung.""" normalized = self._normalize_doc(doc) md = normalized.get("metadata") or {} key = { "source_file": md.get("source_file"), "header_path": md.get("header_path"), "chunk_index": md.get("chunk_index"), "content": normalized.get("content"), } return self._hasher.get_string_hash(str(key)) def add_documents( self, docs: Sequence[Dict[str, Any]], *, ids: Optional[Sequence[str]] = None, batch_size: int = 128, ) -> int: """Thêm documents vào vector store.""" if not docs: return 0 if ids is not None and len(ids) != len(docs): raise ValueError("Số lượng ids phải bằng số lượng docs") # Tách parent nodes (không embed) khỏi regular nodes regular_docs = [] regular_ids = [] parent_count = 0 for i, d in enumerate(docs): normalized = self._normalize_doc(d) md = normalized.get("metadata", {}) or {} doc_id = ids[i] if ids else self._doc_id(d) if md.get("is_parent"): # Lưu parent node riêng (cho Small-to-Big) parent_id = md.get("node_id", doc_id) self._parent_nodes[parent_id] = { "id": parent_id, "content": normalized.get("content", ""), "metadata": md, } parent_count += 1 else: regular_docs.append(d) regular_ids.append(doc_id) if parent_count > 0: logger.info(f"Đã lưu {parent_count} parent nodes (không embed)") self._save_parent_nodes() if not regular_docs: return parent_count # Thêm theo batch bs = max(1, batch_size) total = 0 for start in range(0, len(regular_docs), bs): batch = regular_docs[start : start + bs] batch_ids = regular_ids[start : start + bs] lc_docs = self._to_documents(batch, batch_ids) try: self._vs.add_documents(lc_docs, ids=batch_ids) except TypeError: # Fallback nếu add_documents không nhận ids texts = [d.page_content for d in lc_docs] metas = [d.metadata for d in lc_docs] self._vs.add_texts(texts=texts, metadatas=metas, ids=batch_ids) total += len(batch) logger.info(f"Đã thêm {total} documents vào vector store") return total + parent_count def upsert_documents( self, docs: Sequence[Dict[str, Any]], *, ids: Optional[Sequence[str]] = None, batch_size: int = 128, ) -> int: """Upsert documents (thêm mới hoặc cập nhật nếu đã tồn tại).""" if not docs: return 0 if ids is not None and len(ids) != len(docs): raise ValueError("Số lượng ids phải bằng số lượng docs") # Tách parent nodes khỏi regular nodes regular_docs = [] regular_ids = [] parent_count = 0 for i, d in enumerate(docs): normalized = self._normalize_doc(d) md = normalized.get("metadata", {}) or {} doc_id = ids[i] if ids else self._doc_id(d) if md.get("is_parent"): # Lưu parent node riêng parent_id = md.get("node_id", doc_id) self._parent_nodes[parent_id] = { "id": parent_id, "content": normalized.get("content", ""), "metadata": md, } parent_count += 1 else: regular_docs.append(d) regular_ids.append(doc_id) if parent_count > 0: logger.info(f"Đã lưu {parent_count} parent nodes (không embed)") self._save_parent_nodes() if not regular_docs: return parent_count bs = max(1, batch_size) col = self.collection # Fallback nếu không có collection if col is None: return self.add_documents(regular_docs, ids=regular_ids, batch_size=bs) + parent_count # Upsert theo batch total = 0 for start in range(0, len(regular_docs), bs): batch = regular_docs[start : start + bs] batch_ids = regular_ids[start : start + bs] lc_docs = self._to_documents(batch, batch_ids) texts = [d.page_content for d in lc_docs] metas = [d.metadata for d in lc_docs] embs = self.embedder.embed_documents(texts) col.upsert(ids=batch_ids, documents=texts, metadatas=metas, embeddings=embs) total += len(batch) logger.info(f"Đã upsert {total} documents vào vector store") return total + parent_count def count(self) -> int: """Đếm số documents trong collection.""" col = self.collection return int(col.count()) if col else 0 def get_all_documents(self, limit: int = 5000) -> List[Dict[str, Any]]: """Lấy tất cả documents từ collection.""" col = self.collection if col is None: return [] result = col.get(limit=limit, include=['documents', 'metadatas']) docs = [] for i, doc_content in enumerate(result.get('documents', [])): if doc_content: docs.append({ 'id': result['ids'][i] if result.get('ids') else str(i), 'content': doc_content, 'metadata': result['metadatas'][i] if result.get('metadatas') else {}, }) return docs def delete_documents(self, ids: Sequence[str]) -> int: """Xóa documents theo danh sách IDs.""" if not ids: return 0 col = self.collection if col is None: return 0 col.delete(ids=list(ids)) logger.info(f"Đã xóa {len(ids)} documents khỏi vector store") return len(ids) def get_parent_node(self, parent_id: str) -> Optional[Dict[str, Any]]: """Lấy parent node theo ID (cho Small-to-Big).""" return self._parent_nodes.get(parent_id) @property def parent_nodes(self) -> Dict[str, Dict[str, Any]]: """Lấy tất cả parent nodes.""" return self._parent_nodes