from __future__ import annotations import os import time import logging from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List, Optional, Sequence, TYPE_CHECKING import re import requests from pydantic import Field from langchain_core.documents import Document from langchain_core.callbacks import Callbacks from langchain_core.documents.compressor import BaseDocumentCompressor from langchain_classic.retrievers import ContextualCompressionRetriever from langchain_classic.retrievers.ensemble import EnsembleRetriever from langchain_community.retrievers import BM25Retriever if TYPE_CHECKING: from core.rag.vector_store import ChromaVectorDB logger = logging.getLogger(__name__) class RetrievalMode(str, Enum): """Các chế độ retrieval hỗ trợ.""" VECTOR_ONLY = "vector_only" # Chỉ dùng vector search BM25_ONLY = "bm25_only" # Chỉ dùng BM25 keyword search HYBRID = "hybrid" # Kết hợp vector + BM25 HYBRID_RERANK = "hybrid_rerank" # Hybrid + reranking @dataclass class RetrievalConfig: """Cấu hình cho retrieval system.""" rerank_api_base_url: str = "https://api.siliconflow.com/v1" # API reranker rerank_model: str = "Qwen/Qwen3-Reranker-8B" # Model reranker rerank_top_n: int = 10 # Số kết quả sau rerank initial_k: int = 25 # Số docs lấy ban đầu top_k: int = 5 # Số kết quả cuối cùng vector_weight: float = 0.5 # Trọng số vector search bm25_weight: float = 0.5 # Trọng số BM25 _retrieval_config: RetrievalConfig | None = None def get_retrieval_config() -> RetrievalConfig: """Lấy cấu hình retrieval (singleton pattern).""" global _retrieval_config if _retrieval_config is None: _retrieval_config = RetrievalConfig() return _retrieval_config class SiliconFlowReranker(BaseDocumentCompressor): """Reranker sử dụng SiliconFlow API để sắp xếp lại kết quả.""" api_key: str = Field(default="") api_base_url: str = Field(default="") model: str = Field(default="") top_n: Optional[int] = Field(default=None) class Config: arbitrary_types_allowed = True def compress_documents( self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None, ) -> Sequence[Document]: """Rerank documents dựa trên độ liên quan với query.""" if not documents or not self.api_key: return list(documents) # Retry logic với exponential backoff for attempt in range(3): try: response = requests.post( f"{self.api_base_url}/rerank", headers={ "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", }, json={ "model": self.model, "query": query, "documents": [doc.page_content for doc in documents], "top_n": self.top_n or len(documents), }, timeout=120, ) response.raise_for_status() data = response.json() if "results" not in data: return list(documents) # Tạo danh sách documents đã rerank với score reranked: List[Document] = [] for result in data["results"]: doc = documents[result["index"]] meta = dict(doc.metadata or {}) meta["rerank_score"] = result["relevance_score"] reranked.append(Document(page_content=doc.page_content, metadata=meta)) return reranked except Exception as e: # Rate limit -> đợi rồi thử lại if "rate" in str(e).lower() and attempt < 2: time.sleep(2 ** attempt) else: logger.error(f"Lỗi rerank: {e}") return list(documents) return list(documents) class Retriever: """Retriever chính hỗ trợ nhiều chế độ tìm kiếm.""" def __init__(self, vector_db: "ChromaVectorDB", use_reranker: bool = True): """Khởi tạo retriever với vector DB và reranker.""" self._vector_db = vector_db self._config = get_retrieval_config() self._reranker: Optional[SiliconFlowReranker] = None # Vector retriever từ ChromaDB self._vector_retriever = self._vector_db.vectorstore.as_retriever( search_kwargs={"k": self._config.initial_k} ) # Lazy-load BM25 - chỉ khởi tạo khi cần self._bm25_retriever: Optional[BM25Retriever] = None self._bm25_initialized = False self._ensemble_retriever: Optional[EnsembleRetriever] = None # Đường dẫn cache BM25 (lưu vào disk) from pathlib import Path persist_dir = getattr(self._vector_db.config, 'persist_dir', None) if persist_dir: self._bm25_cache_path = Path(persist_dir) / "bm25_cache.pkl" else: self._bm25_cache_path = None if use_reranker: self._reranker = self._init_reranker() logger.info("Đã khởi tạo Retriever") def _save_bm25_cache(self, bm25: BM25Retriever) -> None: """Lưu BM25 index vào cache file.""" if not self._bm25_cache_path: return try: import pickle with open(self._bm25_cache_path, 'wb') as f: pickle.dump(bm25, f) logger.info(f"Đã lưu BM25 cache vào {self._bm25_cache_path}") except Exception as e: logger.warning(f"Không thể lưu BM25 cache: {e}") def _load_bm25_cache(self) -> Optional[BM25Retriever]: """Tải BM25 index từ cache file.""" if not self._bm25_cache_path or not self._bm25_cache_path.exists(): return None try: import pickle start = time.time() with open(self._bm25_cache_path, 'rb') as f: bm25 = pickle.load(f) bm25.k = self._config.initial_k logger.info(f"Đã tải BM25 từ cache trong {time.time() - start:.2f}s") return bm25 except Exception as e: logger.warning(f"Không thể tải BM25 cache: {e}") return None def _init_bm25(self) -> Optional[BM25Retriever]: """Khởi tạo BM25 retriever (lazy-load với cache).""" if self._bm25_initialized: return self._bm25_retriever self._bm25_initialized = True # Thử tải từ cache trước cached = self._load_bm25_cache() if cached: self._bm25_retriever = cached return cached # Build từ đầu nếu không có cache try: start = time.time() logger.info("Đang xây dựng BM25 index từ documents...") docs = self._vector_db.get_all_documents() if not docs: logger.warning("Không tìm thấy documents cho BM25") return None lc_docs = [ Document(page_content=d["content"], metadata=d.get("metadata", {})) for d in docs ] bm25 = BM25Retriever.from_documents(lc_docs) bm25.k = self._config.initial_k self._bm25_retriever = bm25 logger.info(f"Đã xây dựng BM25 với {len(docs)} docs trong {time.time() - start:.2f}s") # Lưu vào cache cho lần sau self._save_bm25_cache(bm25) return bm25 except Exception as e: logger.error(f"Không thể khởi tạo BM25: {e}") return None def _get_ensemble_retriever(self) -> EnsembleRetriever: """Lấy ensemble retriever (vector + BM25).""" if self._ensemble_retriever is not None: return self._ensemble_retriever bm25 = self._init_bm25() if bm25: self._ensemble_retriever = EnsembleRetriever( retrievers=[self._vector_retriever, bm25], weights=[self._config.vector_weight, self._config.bm25_weight] ) else: # Fallback về vector only self._ensemble_retriever = EnsembleRetriever( retrievers=[self._vector_retriever], weights=[1.0] ) return self._ensemble_retriever def _init_reranker(self) -> Optional[SiliconFlowReranker]: """Khởi tạo reranker nếu có API key.""" api_key = os.getenv("SILICONFLOW_API_KEY", "").strip() if not api_key: return None return SiliconFlowReranker( api_key=api_key, api_base_url=self._config.rerank_api_base_url, model=self._config.rerank_model, top_n=self._config.rerank_top_n, ) def _build_final(self): """Build retriever cuối cùng (ensemble + reranker nếu có).""" ensemble = self._get_ensemble_retriever() if self._reranker: return ContextualCompressionRetriever( base_compressor=self._reranker, base_retriever=ensemble ) return ensemble @property def has_reranker(self) -> bool: """Kiểm tra có reranker không.""" return self._reranker is not None def _to_result(self, doc: Document, rank: int, **extra) -> Dict[str, Any]: """Chuyển Document thành dict result, xử lý Small-to-Big.""" metadata = doc.metadata or {} content = doc.page_content # Small-to-Big: Nếu là summary node -> swap với parent (bảng gốc) if metadata.get("is_table_summary") and metadata.get("parent_id"): parent = self._vector_db.get_parent_node(metadata["parent_id"]) if parent: content = parent.get("content", content) # Merge metadata, giữ lại info summary để debug metadata = { **parent.get("metadata", {}), "original_summary": doc.page_content[:200], "swapped_from_summary": True, } return { "id": metadata.get("id"), "content": content, "metadata": metadata, "final_rank": rank, **extra, } def vector_search( self, text: str, *, k: int | None = None, where: Optional[Dict[str, Any]] = None ) -> List[Dict[str, Any]]: """Tìm kiếm bằng vector similarity.""" if not text.strip(): return [] k = k or self._config.top_k results = self._vector_db.vectorstore.similarity_search_with_score(text, k=k, filter=where) return [self._to_result(doc, i + 1, distance=score) for i, (doc, score) in enumerate(results)] def bm25_search(self, text: str, *, k: int | None = None) -> List[Dict[str, Any]]: """Tìm kiếm bằng BM25 keyword matching.""" if not text.strip(): return [] bm25 = self._init_bm25() # Lazy-load BM25 if not bm25: return self.vector_search(text, k=k) k = k or self._config.top_k bm25.k = k results = bm25.invoke(text) return [self._to_result(doc, i + 1) for i, doc in enumerate(results[:k])] def hybrid_search( self, text: str, *, k: int | None = None, initial_k: int | None = None ) -> List[Dict[str, Any]]: """Tìm kiếm hybrid (vector + BM25) không có rerank.""" if not text.strip(): return [] k = k or self._config.top_k if initial_k: self._vector_retriever.search_kwargs["k"] = initial_k bm25 = self._init_bm25() if bm25: bm25.k = initial_k ensemble = self._get_ensemble_retriever() results = ensemble.invoke(text) return [self._to_result(doc, i + 1) for i, doc in enumerate(results[:k])] def search_with_rerank( self, text: str, *, k: int | None = None, where: Optional[Dict[str, Any]] = None, initial_k: int | None = None, ) -> List[Dict[str, Any]]: """Tìm kiếm hybrid + reranking để có kết quả tốt nhất.""" if not text.strip(): return [] k = k or self._config.top_k initial_k = initial_k or self._config.initial_k # Có filter -> dùng vector search + manual rerank if where: results = self._vector_db.vectorstore.similarity_search(text, k=initial_k, filter=where) if self._reranker: results = self._reranker.compress_documents(results, text) return [ self._to_result(doc, i + 1, rerank_score=doc.metadata.get("rerank_score")) for i, doc in enumerate(results[:k]) ] # Cập nhật k cho initial fetch if initial_k: self._vector_retriever.search_kwargs["k"] = initial_k bm25 = self._init_bm25() if bm25: bm25.k = initial_k # Hybrid search ensemble = self._get_ensemble_retriever() ensemble_results = ensemble.invoke(text) # Rerank nếu có if self._reranker: results = self._reranker.compress_documents(ensemble_results, text) else: results = ensemble_results return [ self._to_result(doc, i + 1, rerank_score=doc.metadata.get("rerank_score")) for i, doc in enumerate(results[:k]) ] def flexible_search( self, text: str, *, mode: RetrievalMode | str = RetrievalMode.HYBRID_RERANK, k: int | None = None, initial_k: int | None = None, where: Optional[Dict[str, Any]] = None, ) -> List[Dict[str, Any]]: """Tìm kiếm linh hoạt với nhiều chế độ.""" if not text.strip(): return [] # Parse mode từ string if isinstance(mode, str): try: mode = RetrievalMode(mode.lower()) except ValueError: mode = RetrievalMode.HYBRID_RERANK k = k or self._config.top_k initial_k = initial_k or self._config.initial_k # Gọi method tương ứng theo mode if mode == RetrievalMode.VECTOR_ONLY: return self.vector_search(text, k=k, where=where) elif mode == RetrievalMode.BM25_ONLY: return self.bm25_search(text, k=k) elif mode == RetrievalMode.HYBRID: if where: return self.vector_search(text, k=k, where=where) return self.hybrid_search(text, k=k, initial_k=initial_k) else: # HYBRID_RERANK return self.search_with_rerank(text, k=k, where=where, initial_k=initial_k) # Alias để tương thích ngược query = vector_search