DoAn / core /rag /retrival.py
hungnha's picture
build lại data
bf7ec12
from __future__ import annotations
import os
import time
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, TYPE_CHECKING
import re
import requests
from pydantic import Field
from langchain_core.documents import Document
from langchain_core.callbacks import Callbacks
from langchain_core.documents.compressor import BaseDocumentCompressor
from langchain_classic.retrievers import ContextualCompressionRetriever
from langchain_classic.retrievers.ensemble import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
if TYPE_CHECKING:
from core.rag.vector_store import ChromaVectorDB
logger = logging.getLogger(__name__)
class RetrievalMode(str, Enum):
    """Supported retrieval strategies.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (e.g. ``RetrievalMode.HYBRID == "hybrid"``).
    """

    # Vector similarity search only.
    VECTOR_ONLY = "vector_only"

    # BM25 keyword search only.
    BM25_ONLY = "bm25_only"

    # Vector search combined with BM25.
    HYBRID = "hybrid"

    # Hybrid search followed by a reranking pass.
    HYBRID_RERANK = "hybrid_rerank"
@dataclass
class RetrievalConfig:
    """Tunable settings for the retrieval pipeline."""

    # --- Reranker ---
    # Base URL of the reranker HTTP API.
    rerank_api_base_url: str = "https://api.siliconflow.com/v1"
    # Model identifier sent to the reranker API.
    rerank_model: str = "Qwen/Qwen3-Reranker-8B"
    # Number of results requested from the reranker.
    rerank_top_n: int = 10

    # --- Result sizes ---
    # Number of candidate documents fetched before rerank/truncation.
    initial_k: int = 25
    # Number of results returned to the caller by default.
    top_k: int = 5

    # --- Hybrid weighting ---
    # Weight given to the vector retriever in the ensemble.
    vector_weight: float = 0.5
    # Weight given to the BM25 retriever in the ensemble.
    bm25_weight: float = 0.5
# Module-level cache holding the single shared RetrievalConfig instance.
_retrieval_config: Optional[RetrievalConfig] = None


def get_retrieval_config() -> RetrievalConfig:
    """Return the shared retrieval configuration, creating it on first call."""
    global _retrieval_config
    if _retrieval_config is not None:
        return _retrieval_config
    _retrieval_config = RetrievalConfig()
    return _retrieval_config
class SiliconFlowReranker(BaseDocumentCompressor):
    """Document compressor that reorders documents via a SiliconFlow-style rerank API.

    Attributes are configured by the caller; with an empty ``api_key`` the
    reranker is a no-op that returns the documents unchanged.
    """

    # Bearer token sent in the Authorization header.
    api_key: str = Field(default="")
    # Base URL; "/rerank" is appended to build the endpoint.
    api_base_url: str = Field(default="")
    # Model identifier passed through to the API.
    model: str = Field(default="")
    # Number of results to request; falls back to len(documents) when unset.
    top_n: Optional[int] = Field(default=None)

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Rerank ``documents`` by relevance to ``query``.

        Args:
            documents: Candidate documents to reorder.
            query: The user query the documents are scored against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            New ``Document`` objects in the order returned by the API, each
            carrying a ``rerank_score`` metadata entry. On any failure (or
            when there is nothing to do) the input documents are returned
            unchanged as a list, so reranking is strictly best-effort.
        """
        if not documents or not self.api_key:
            return list(documents)

        max_attempts = 3
        for attempt in range(max_attempts):
            try:
                response = requests.post(
                    f"{self.api_base_url}/rerank",
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json",
                    },
                    json={
                        "model": self.model,
                        "query": query,
                        "documents": [doc.page_content for doc in documents],
                        "top_n": self.top_n or len(documents),
                    },
                    timeout=120,
                )
                response.raise_for_status()
                data = response.json()
                if "results" not in data:
                    return list(documents)

                # Rebuild the documents in API order, attaching the score.
                reranked: List[Document] = []
                for result in data["results"]:
                    doc = documents[result["index"]]
                    meta = dict(doc.metadata or {})
                    meta["rerank_score"] = result["relevance_score"]
                    reranked.append(Document(page_content=doc.page_content, metadata=meta))
                return reranked
            except Exception as e:
                # Detect rate limiting from the HTTP status first: the message
                # that raise_for_status() builds for a 429 ("Too Many
                # Requests") does not contain the word "rate", so a text-only
                # check never retried real rate limits. The text check is kept
                # as a fallback for errors that carry no response object.
                # Note: compare against None explicitly; a requests Response
                # is falsy for 4xx/5xx statuses.
                status = getattr(getattr(e, "response", None), "status_code", None)
                rate_limited = status == 429 or "rate" in str(e).lower()
                if rate_limited and attempt < max_attempts - 1:
                    # Exponential backoff: 1s, then 2s.
                    time.sleep(2 ** attempt)
                else:
                    logger.error(f"Lỗi rerank: {e}")
                    return list(documents)
        return list(documents)
class Retriever:
    """Main retriever supporting vector, BM25, hybrid and reranked search.

    The vector retriever is created eagerly from the vector DB. The BM25
    index is built lazily on first use, kept in memory, and pickled to disk
    when the vector DB config exposes a ``persist_dir``.
    """

    def __init__(self, vector_db: "ChromaVectorDB", use_reranker: bool = True):
        """Set up the retriever around ``vector_db``.

        Args:
            vector_db: Vector store wrapper; must expose ``vectorstore``,
                ``config``, ``get_all_documents`` and ``get_parent_node``.
            use_reranker: When true, try to create a reranker. It stays
                ``None`` if no API key is available.
        """
        self._vector_db = vector_db
        self._config = get_retrieval_config()
        self._reranker: Optional[SiliconFlowReranker] = None

        # Vector retriever backed by the vector store.
        self._vector_retriever = self._vector_db.vectorstore.as_retriever(
            search_kwargs={"k": self._config.initial_k}
        )

        # BM25 is lazy-loaded: nothing is built until a search needs it.
        self._bm25_retriever: Optional[BM25Retriever] = None
        self._bm25_initialized = False
        self._ensemble_retriever: Optional[EnsembleRetriever] = None

        # On-disk location of the pickled BM25 index (None disables caching).
        from pathlib import Path

        persist_dir = getattr(self._vector_db.config, 'persist_dir', None)
        if persist_dir:
            self._bm25_cache_path = Path(persist_dir) / "bm25_cache.pkl"
        else:
            self._bm25_cache_path = None

        if use_reranker:
            self._reranker = self._init_reranker()
        logger.info("Đã khởi tạo Retriever")

    def _save_bm25_cache(self, bm25: BM25Retriever) -> None:
        """Pickle the BM25 retriever to the cache file (best-effort)."""
        if not self._bm25_cache_path:
            return
        try:
            import pickle

            with open(self._bm25_cache_path, 'wb') as f:
                pickle.dump(bm25, f)
            logger.info(f"Đã lưu BM25 cache vào {self._bm25_cache_path}")
        except Exception as e:
            # A failed cache write must not break searching.
            logger.warning(f"Không thể lưu BM25 cache: {e}")

    def _load_bm25_cache(self) -> Optional[BM25Retriever]:
        """Load the BM25 retriever from the cache file.

        Returns:
            The cached retriever with ``k`` reset to the configured
            ``initial_k``, or ``None`` when there is no usable cache.
        """
        if not self._bm25_cache_path or not self._bm25_cache_path.exists():
            return None
        try:
            import pickle

            start = time.time()
            # SECURITY: unpickling runs arbitrary code from the file. The
            # cache file must only ever be written by this application and
            # live in a directory untrusted users cannot write to.
            # NOTE(review): nothing here invalidates the cache when the
            # underlying documents change - presumably the file has to be
            # deleted after the data is rebuilt; confirm with the pipeline.
            with open(self._bm25_cache_path, 'rb') as f:
                bm25 = pickle.load(f)
            if not isinstance(bm25, BM25Retriever):
                # Reject a foreign/corrupt cache instead of using it as a
                # retriever; handled by the except clause below.
                raise TypeError(
                    f"unexpected cached object type: {type(bm25).__name__}"
                )
            bm25.k = self._config.initial_k
            logger.info(f"Đã tải BM25 từ cache trong {time.time() - start:.2f}s")
            return bm25
        except Exception as e:
            logger.warning(f"Không thể tải BM25 cache: {e}")
            return None

    def _init_bm25(self) -> Optional[BM25Retriever]:
        """Return the BM25 retriever, building it on first call.

        The disk cache is tried first; otherwise the index is built from all
        documents in the vector DB and then cached.

        Returns:
            The BM25 retriever, or ``None`` when it could not be created.
        """
        if self._bm25_initialized:
            return self._bm25_retriever
        # Marked before the attempt, so a failed build is not retried on
        # every later query.
        self._bm25_initialized = True

        cached = self._load_bm25_cache()
        if cached is not None:
            self._bm25_retriever = cached
            return cached

        # No cache available: build the index from scratch.
        try:
            start = time.time()
            logger.info("Đang xây dựng BM25 index từ documents...")
            docs = self._vector_db.get_all_documents()
            if not docs:
                logger.warning("Không tìm thấy documents cho BM25")
                return None
            lc_docs = [
                Document(page_content=d["content"], metadata=d.get("metadata") or {})
                for d in docs
            ]
            bm25 = BM25Retriever.from_documents(lc_docs)
            bm25.k = self._config.initial_k
            self._bm25_retriever = bm25
            logger.info(f"Đã xây dựng BM25 với {len(docs)} docs trong {time.time() - start:.2f}s")
            # Persist for the next process start.
            self._save_bm25_cache(bm25)
            return bm25
        except Exception as e:
            logger.error(f"Không thể khởi tạo BM25: {e}")
            return None

    def _get_ensemble_retriever(self) -> EnsembleRetriever:
        """Return the cached ensemble retriever, creating it on first call.

        Combines the vector and BM25 retrievers using the configured weights;
        falls back to a vector-only ensemble when BM25 is unavailable.
        """
        if self._ensemble_retriever is not None:
            return self._ensemble_retriever

        bm25 = self._init_bm25()
        if bm25 is not None:
            self._ensemble_retriever = EnsembleRetriever(
                retrievers=[self._vector_retriever, bm25],
                weights=[self._config.vector_weight, self._config.bm25_weight],
            )
        else:
            # Fallback: vector search only.
            self._ensemble_retriever = EnsembleRetriever(
                retrievers=[self._vector_retriever],
                weights=[1.0],
            )
        return self._ensemble_retriever

    def _init_reranker(self) -> Optional[SiliconFlowReranker]:
        """Create the reranker when ``SILICONFLOW_API_KEY`` is set, else ``None``."""
        api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
        if not api_key:
            return None
        return SiliconFlowReranker(
            api_key=api_key,
            api_base_url=self._config.rerank_api_base_url,
            model=self._config.rerank_model,
            top_n=self._config.rerank_top_n,
        )

    def _build_final(self):
        """Build the final retriever: the ensemble, wrapped with the reranker if any."""
        ensemble = self._get_ensemble_retriever()
        if self._reranker:
            return ContextualCompressionRetriever(
                base_compressor=self._reranker,
                base_retriever=ensemble,
            )
        return ensemble

    @property
    def has_reranker(self) -> bool:
        """Whether a reranker was successfully created."""
        return self._reranker is not None

    def _apply_initial_k(self, initial_k: int) -> None:
        """Set the candidate-pool size on both underlying retrievers.

        The retrievers are shared, mutable objects (``bm25_search`` also
        rewrites ``bm25.k``), so every hybrid query sets the size explicitly
        instead of relying on whatever a previous call left behind.
        """
        self._vector_retriever.search_kwargs["k"] = initial_k
        bm25 = self._init_bm25()
        if bm25 is not None:
            bm25.k = initial_k

    def _to_results(
        self,
        docs: Sequence[Document],
        k: int,
        *,
        include_rerank_score: bool = False,
    ) -> List[Dict[str, Any]]:
        """Convert the first ``k`` documents into ranked result dicts.

        With ``include_rerank_score`` every result carries a ``rerank_score``
        key, which is ``None`` when the document has no such metadata.
        """
        results: List[Dict[str, Any]] = []
        for rank, doc in enumerate(docs[:k], start=1):
            if include_rerank_score:
                score = (doc.metadata or {}).get("rerank_score")
                results.append(self._to_result(doc, rank, rerank_score=score))
            else:
                results.append(self._to_result(doc, rank))
        return results

    def _to_result(self, doc: Document, rank: int, **extra) -> Dict[str, Any]:
        """Convert a ``Document`` into a result dict.

        Small-to-Big: when the hit is a table-summary node that has a
        ``parent_id``, the parent's content and metadata are returned in its
        place, if the vector DB can find the parent.

        Args:
            doc: The retrieved document.
            rank: 1-based position in the final result list.
            **extra: Additional keys merged into the result (e.g. a score).
        """
        metadata = doc.metadata or {}
        content = doc.page_content

        if metadata.get("is_table_summary") and metadata.get("parent_id"):
            parent = self._vector_db.get_parent_node(metadata["parent_id"])
            if parent:
                content = parent.get("content", content)
                # Use the parent's metadata, keeping a trace of the summary
                # for debugging. `or {}` guards against an explicit None.
                metadata = {
                    **(parent.get("metadata") or {}),
                    "original_summary": doc.page_content[:200],
                    "swapped_from_summary": True,
                }

        return {
            "id": metadata.get("id"),
            "content": content,
            "metadata": metadata,
            "final_rank": rank,
            **extra,
        }

    def vector_search(
        self, text: str, *, k: int | None = None, where: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Search by vector similarity.

        Args:
            text: Query text; blank input yields an empty list.
            k: Number of results (defaults to the configured ``top_k``).
            where: Optional metadata filter passed to the vector store.

        Returns:
            Result dicts that include the store's score under ``distance``.
        """
        if not text.strip():
            return []
        k = k or self._config.top_k
        scored = self._vector_db.vectorstore.similarity_search_with_score(
            text, k=k, filter=where
        )
        return [
            self._to_result(doc, rank, distance=score)
            for rank, (doc, score) in enumerate(scored, start=1)
        ]

    def bm25_search(self, text: str, *, k: int | None = None) -> List[Dict[str, Any]]:
        """Search by BM25 keyword matching.

        Falls back to ``vector_search`` when the BM25 index is unavailable.
        """
        if not text.strip():
            return []
        bm25 = self._init_bm25()  # lazy-load
        if bm25 is None:
            return self.vector_search(text, k=k)
        k = k or self._config.top_k
        bm25.k = k
        return self._to_results(bm25.invoke(text), k)

    def hybrid_search(
        self, text: str, *, k: int | None = None, initial_k: int | None = None
    ) -> List[Dict[str, Any]]:
        """Hybrid search (vector + BM25) without reranking.

        Args:
            text: Query text; blank input yields an empty list.
            k: Number of results (defaults to the configured ``top_k``).
            initial_k: Candidates fetched per retriever (defaults to the
                configured ``initial_k``).
        """
        if not text.strip():
            return []
        k = k or self._config.top_k
        # Always resolve and apply the fetch size; otherwise a value left by
        # an earlier bm25_search/hybrid call would silently be reused.
        initial_k = initial_k or self._config.initial_k
        self._apply_initial_k(initial_k)

        ensemble = self._get_ensemble_retriever()
        return self._to_results(ensemble.invoke(text), k)

    def search_with_rerank(
        self,
        text: str,
        *,
        k: int | None = None,
        where: Optional[Dict[str, Any]] = None,
        initial_k: int | None = None,
    ) -> List[Dict[str, Any]]:
        """Hybrid search followed by reranking (when a reranker exists).

        Args:
            text: Query text; blank input yields an empty list.
            k: Number of results (defaults to the configured ``top_k``).
            where: Optional metadata filter. When given, only the vector
                store is queried, since the filter is applied there.
            initial_k: Candidates fetched before reranking (defaults to the
                configured ``initial_k``).

        Returns:
            Result dicts with a ``rerank_score`` key (``None`` when the
            documents were not reranked).
        """
        if not text.strip():
            return []
        k = k or self._config.top_k
        initial_k = initial_k or self._config.initial_k

        # NOTE(review): the reranker requests the configured rerank_top_n
        # results, so asking for a larger k presumably still yields at most
        # that many after reranking - confirm against the API.
        if where:
            # Filtered query: vector search, then rerank manually.
            candidates = self._vector_db.vectorstore.similarity_search(
                text, k=initial_k, filter=where
            )
        else:
            self._apply_initial_k(initial_k)
            candidates = self._get_ensemble_retriever().invoke(text)

        if self._reranker:
            candidates = self._reranker.compress_documents(candidates, text)
        return self._to_results(candidates, k, include_rerank_score=True)

    def flexible_search(
        self,
        text: str,
        *,
        mode: RetrievalMode | str = RetrievalMode.HYBRID_RERANK,
        k: int | None = None,
        initial_k: int | None = None,
        where: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Dispatch to the search method matching ``mode``.

        Args:
            text: Query text; blank input yields an empty list.
            mode: A ``RetrievalMode`` or its string value (case-insensitive).
                Unknown values fall back to ``HYBRID_RERANK``.
            k: Number of results (defaults to the configured ``top_k``).
            initial_k: Candidates fetched for the hybrid modes.
            where: Optional metadata filter. BM25-only search ignores it, and
                plain hybrid mode switches to vector search when it is set.
        """
        if not text.strip():
            return []

        # Normalise string modes to the enum.
        if isinstance(mode, str):
            requested = mode
            try:
                mode = RetrievalMode(requested.lower())
            except ValueError:
                logger.warning(
                    f"Chế độ retrieval không hợp lệ: {requested!r}, "
                    f"dùng mặc định {RetrievalMode.HYBRID_RERANK.value}"
                )
                mode = RetrievalMode.HYBRID_RERANK

        k = k or self._config.top_k
        initial_k = initial_k or self._config.initial_k

        if mode == RetrievalMode.VECTOR_ONLY:
            return self.vector_search(text, k=k, where=where)
        if mode == RetrievalMode.BM25_ONLY:
            return self.bm25_search(text, k=k)
        if mode == RetrievalMode.HYBRID:
            if where:
                return self.vector_search(text, k=k, where=where)
            return self.hybrid_search(text, k=k, initial_k=initial_k)
        # HYBRID_RERANK (also reached for any non-string, unrecognised mode).
        return self.search_with_rerank(text, k=k, where=where, initial_k=initial_k)

    # Backward-compatible alias.
    query = vector_search