from __future__ import annotations

import logging
import os
import pickle
import re
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, TYPE_CHECKING

import requests
from pydantic import Field

from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.documents.compressor import BaseDocumentCompressor
from langchain_classic.retrievers import ContextualCompressionRetriever
from langchain_classic.retrievers.ensemble import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever

if TYPE_CHECKING:
    from core.rag.vector_store import ChromaVectorDB

logger = logging.getLogger(__name__)


class RetrievalMode(str, Enum):
    """Retrieval strategies accepted by ``Retriever.flexible_search``."""

    VECTOR_ONLY = "vector_only"
    BM25_ONLY = "bm25_only"
    HYBRID = "hybrid"
    HYBRID_RERANK = "hybrid_rerank"


@dataclass
class RetrievalConfig:
    """Tunable knobs for retrieval and reranking.

    Attributes:
        rerank_api_base_url: Base URL of the SiliconFlow-compatible rerank API.
        rerank_model: Model identifier passed to the rerank endpoint.
        rerank_top_n: Number of documents the reranker keeps.
        initial_k: Candidate pool size fetched before fusion/reranking.
        top_k: Number of results returned to the caller.
        vector_weight: Ensemble weight of the vector retriever.
        bm25_weight: Ensemble weight of the BM25 retriever.
    """

    rerank_api_base_url: str = "https://api.siliconflow.com/v1"
    rerank_model: str = "Qwen/Qwen3-Reranker-8B"
    rerank_top_n: int = 5
    initial_k: int = 50
    top_k: int = 5
    vector_weight: float = 0.5
    bm25_weight: float = 0.5


# Process-wide singleton, created lazily by get_retrieval_config().
_retrieval_config: RetrievalConfig | None = None


def get_retrieval_config() -> RetrievalConfig:
    """Return the process-wide :class:`RetrievalConfig`, creating it on first use."""
    global _retrieval_config
    if _retrieval_config is None:
        _retrieval_config = RetrievalConfig()
    return _retrieval_config


class SiliconFlowReranker(BaseDocumentCompressor):
    """LangChain document compressor backed by a SiliconFlow ``/rerank`` endpoint.

    Falls back to returning documents unchanged whenever the API cannot be
    used (no key, malformed response, exhausted retries) so retrieval never
    hard-fails because of the reranker.
    """

    api_key: str = Field(default="")
    api_base_url: str = Field(default="")
    model: str = Field(default="")
    top_n: Optional[int] = Field(default=None)

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Rerank *documents* against *query* via the remote API.

        Retries up to 3 times with exponential backoff when rate-limited.
        Each returned Document carries a ``rerank_score`` in its metadata.
        On any non-retriable failure the original documents are returned in
        their original order.
        """
        if not documents or not self.api_key:
            return list(documents)

        payload = {
            "model": self.model,
            "query": query,
            "documents": [doc.page_content for doc in documents],
            "top_n": self.top_n or len(documents),
        }
        for attempt in range(3):
            try:
                response = requests.post(
                    f"{self.api_base_url}/rerank",
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json",
                    },
                    json=payload,
                    timeout=120,
                )
                response.raise_for_status()
                data = response.json()
                if "results" not in data:
                    return list(documents)
                # Build reranked document list, attaching relevance scores.
                reranked: List[Document] = []
                for result in data["results"]:
                    doc = documents[result["index"]]
                    meta = dict(doc.metadata or {})
                    meta["rerank_score"] = result["relevance_score"]
                    reranked.append(
                        Document(page_content=doc.page_content, metadata=meta)
                    )
                return reranked
            except Exception as e:
                # Rate limit -> wait and retry. Check the HTTP status code
                # explicitly: requests' HTTPError for 429 reads
                # "429 Client Error: Too Many Requests ..." and does NOT
                # contain the substring "rate", so the string heuristic
                # alone would never trigger a retry.
                status = getattr(getattr(e, "response", None), "status_code", None)
                rate_limited = status == 429 or "rate" in str(e).lower()
                if rate_limited and attempt < 2:
                    time.sleep(2 ** attempt)
                else:
                    logger.error("Rerank error: %s", e)
                    return list(documents)
        return list(documents)


class Retriever:
    """Hybrid retriever combining vector search, BM25 and optional reranking."""

    def __init__(self, vector_db: "ChromaVectorDB", use_reranker: bool = True):
        """Wire up the vector retriever; BM25 and the reranker are lazy/optional.

        Args:
            vector_db: Chroma-backed vector store wrapper.
            use_reranker: When True, enable the SiliconFlow reranker if the
                ``SILICONFLOW_API_KEY`` environment variable is set.
        """
        self._vector_db = vector_db
        self._config = get_retrieval_config()
        self._reranker: Optional[SiliconFlowReranker] = None

        # Vector retriever from ChromaDB.
        self._vector_retriever = self._vector_db.vectorstore.as_retriever(
            search_kwargs={"k": self._config.initial_k}
        )

        # Lazy-load BM25 - only initialized when needed.
        self._bm25_retriever: Optional[BM25Retriever] = None
        self._bm25_initialized = False
        self._ensemble_retriever: Optional[EnsembleRetriever] = None

        # BM25 cache path (pickled to disk next to the vector store).
        persist_dir = getattr(self._vector_db.config, 'persist_dir', None)
        if persist_dir:
            self._bm25_cache_path: Optional[Path] = Path(persist_dir) / "bm25_cache.pkl"
        else:
            self._bm25_cache_path = None

        if use_reranker:
            self._reranker = self._init_reranker()
        logger.info("Initialized Retriever")

    def _save_bm25_cache(self, bm25: BM25Retriever) -> None:
        """Persist the BM25 index to disk so later startups can skip rebuilding."""
        if not self._bm25_cache_path:
            return
        try:
            with open(self._bm25_cache_path, 'wb') as f:
                pickle.dump(bm25, f)
            logger.info("Saved BM25 cache to %s", self._bm25_cache_path)
        except Exception as e:
            # Best-effort: a failed cache write only costs a rebuild next time.
            logger.warning("Failed to save BM25 cache: %s", e)

    def _load_bm25_cache(self) -> Optional[BM25Retriever]:
        """Load a previously pickled BM25 index, or None if unavailable.

        SECURITY NOTE: ``pickle.load`` executes arbitrary code from the file.
        This is acceptable only because the cache lives in our own persist
        directory and is written exclusively by ``_save_bm25_cache``.
        """
        if not self._bm25_cache_path or not self._bm25_cache_path.exists():
            return None
        try:
            start = time.time()
            with open(self._bm25_cache_path, 'rb') as f:
                bm25 = pickle.load(f)
            bm25.k = self._config.initial_k
            logger.info("Loaded BM25 from cache in %.2fs", time.time() - start)
            return bm25
        except Exception as e:
            logger.warning("Failed to load BM25 cache: %s", e)
            return None

    def _init_bm25(self) -> Optional[BM25Retriever]:
        """Initialize BM25 on first use: disk cache first, then rebuild from docs.

        Returns None (and remembers the failure) when no documents exist or
        building the index raises — callers fall back to vector-only search.
        """
        if self._bm25_initialized:
            return self._bm25_retriever
        self._bm25_initialized = True

        # Try loading from cache first.
        cached = self._load_bm25_cache()
        if cached:
            self._bm25_retriever = cached
            return cached

        # Build from scratch if no cache.
        try:
            start = time.time()
            logger.info("Building BM25 index from documents...")
            docs = self._vector_db.get_all_documents()
            if not docs:
                logger.warning("No documents found for BM25")
                return None
            lc_docs = [
                Document(page_content=d["content"], metadata=d.get("metadata", {}))
                for d in docs
            ]
            bm25 = BM25Retriever.from_documents(lc_docs)
            bm25.k = self._config.initial_k
            self._bm25_retriever = bm25
            logger.info(
                "Built BM25 with %d docs in %.2fs", len(docs), time.time() - start
            )
            # Save to cache for next time.
            self._save_bm25_cache(bm25)
            return bm25
        except Exception as e:
            logger.error("Failed to initialize BM25: %s", e)
            return None

    def _get_ensemble_retriever(self) -> EnsembleRetriever:
        """Return (building once) the vector+BM25 ensemble retriever."""
        if self._ensemble_retriever is not None:
            return self._ensemble_retriever
        bm25 = self._init_bm25()
        if bm25:
            self._ensemble_retriever = EnsembleRetriever(
                retrievers=[self._vector_retriever, bm25],
                weights=[self._config.vector_weight, self._config.bm25_weight],
            )
        else:
            # Fallback to vector only.
            self._ensemble_retriever = EnsembleRetriever(
                retrievers=[self._vector_retriever], weights=[1.0]
            )
        return self._ensemble_retriever

    def _init_reranker(self) -> Optional[SiliconFlowReranker]:
        """Create the reranker if SILICONFLOW_API_KEY is set, else None."""
        api_key = os.getenv("SILICONFLOW_API_KEY", "").strip()
        if not api_key:
            return None
        return SiliconFlowReranker(
            api_key=api_key,
            api_base_url=self._config.rerank_api_base_url,
            model=self._config.rerank_model,
            top_n=self._config.rerank_top_n,
        )

    def _build_final(self):
        """Compose the final retriever: ensemble, wrapped in reranking if enabled."""
        ensemble = self._get_ensemble_retriever()
        if self._reranker:
            return ContextualCompressionRetriever(
                base_compressor=self._reranker, base_retriever=ensemble
            )
        return ensemble

    @property
    def has_reranker(self) -> bool:
        """Whether a reranker was successfully configured."""
        return self._reranker is not None

    def _to_result(self, doc: Document, rank: int, **extra) -> Dict[str, Any]:
        """Convert a LangChain Document into the plain-dict result format.

        Implements "small-to-big": when the hit is a table-summary node, the
        summary text is swapped for the parent node's original content, and
        the first 200 chars of the summary are kept in metadata for debugging.
        """
        metadata = doc.metadata or {}
        content = doc.page_content
        # Small-to-Big: if summary node -> swap with parent (original table).
        if metadata.get("is_table_summary") and metadata.get("parent_id"):
            parent = self._vector_db.get_parent_node(metadata["parent_id"])
            if parent:
                content = parent.get("content", content)
                # Merge metadata, keep summary info for debugging.
                metadata = {
                    **parent.get("metadata", {}),
                    "original_summary": doc.page_content[:200],
                    "swapped_from_summary": True,
                }
        return {
            "id": metadata.get("id"),
            "content": content,
            "metadata": metadata,
            "final_rank": rank,
            **extra,
        }

    def vector_search(
        self, text: str, *, k: int | None = None, where: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Pure vector similarity search; results carry the raw ``distance`` score."""
        if not text.strip():
            return []
        k = k or self._config.top_k
        results = self._vector_db.vectorstore.similarity_search_with_score(
            text, k=k, filter=where
        )
        return [
            self._to_result(doc, i + 1, distance=score)
            for i, (doc, score) in enumerate(results)
        ]

    def bm25_search(self, text: str, *, k: int | None = None) -> List[Dict[str, Any]]:
        """Pure BM25 keyword search; falls back to vector search if BM25 is unavailable."""
        if not text.strip():
            return []
        bm25 = self._init_bm25()  # Lazy-load BM25
        if not bm25:
            return self.vector_search(text, k=k)
        k = k or self._config.top_k
        bm25.k = k
        results = bm25.invoke(text)
        return [self._to_result(doc, i + 1) for i, doc in enumerate(results[:k])]

    def hybrid_search(
        self, text: str, *, k: int | None = None, initial_k: int | None = None
    ) -> List[Dict[str, Any]]:
        """Vector + BM25 ensemble search without reranking."""
        if not text.strip():
            return []
        k = k or self._config.top_k
        # Always reset the candidate pool size (consistent with
        # search_with_rerank): bm25.k may have been shrunk by a previous
        # bm25_search() call, so setting it only when the caller passes
        # initial_k would leak state between queries.
        initial_k = initial_k or self._config.initial_k
        self._vector_retriever.search_kwargs["k"] = initial_k
        bm25 = self._init_bm25()
        if bm25:
            bm25.k = initial_k
        ensemble = self._get_ensemble_retriever()
        results = ensemble.invoke(text)
        return [self._to_result(doc, i + 1) for i, doc in enumerate(results[:k])]

    def search_with_rerank(
        self,
        text: str,
        *,
        k: int | None = None,
        where: Optional[Dict[str, Any]] = None,
        initial_k: int | None = None,
    ) -> List[Dict[str, Any]]:
        """Hybrid search followed by reranking (the highest-quality pipeline).

        A metadata filter (``where``) is only supported by the vector store,
        so filtered queries skip BM25 fusion and rerank the vector hits
        directly. Results carry ``rerank_score`` when a reranker is active.
        """
        if not text.strip():
            return []
        k = k or self._config.top_k
        initial_k = initial_k or self._config.initial_k

        # Has filter -> use vector search + manual rerank.
        if where:
            results = self._vector_db.vectorstore.similarity_search(
                text, k=initial_k, filter=where
            )
            if self._reranker:
                results = self._reranker.compress_documents(results, text)
            return [
                self._to_result(doc, i + 1, rerank_score=doc.metadata.get("rerank_score"))
                for i, doc in enumerate(results[:k])
            ]

        # Update candidate pool size for the initial fetch.
        self._vector_retriever.search_kwargs["k"] = initial_k
        bm25 = self._init_bm25()
        if bm25:
            bm25.k = initial_k

        # Hybrid search.
        ensemble = self._get_ensemble_retriever()
        ensemble_results = ensemble.invoke(text)

        # Rerank if available.
        if self._reranker:
            results = self._reranker.compress_documents(ensemble_results, text)
        else:
            results = ensemble_results
        return [
            self._to_result(doc, i + 1, rerank_score=doc.metadata.get("rerank_score"))
            for i, doc in enumerate(results[:k])
        ]

    def flexible_search(
        self,
        text: str,
        *,
        mode: RetrievalMode | str = RetrievalMode.HYBRID_RERANK,
        k: int | None = None,
        initial_k: int | None = None,
        where: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Dispatch to the retrieval strategy named by *mode*.

        Unknown mode strings fall back to HYBRID_RERANK. In HYBRID mode a
        metadata filter forces vector-only search (BM25 cannot filter).
        """
        if not text.strip():
            return []
        # Parse mode from string.
        if isinstance(mode, str):
            try:
                mode = RetrievalMode(mode.lower())
            except ValueError:
                mode = RetrievalMode.HYBRID_RERANK
        k = k or self._config.top_k
        initial_k = initial_k or self._config.initial_k

        # Dispatch to corresponding method by mode.
        if mode == RetrievalMode.VECTOR_ONLY:
            return self.vector_search(text, k=k, where=where)
        elif mode == RetrievalMode.BM25_ONLY:
            return self.bm25_search(text, k=k)
        elif mode == RetrievalMode.HYBRID:
            if where:
                return self.vector_search(text, k=k, where=where)
            return self.hybrid_search(text, k=k, initial_k=initial_k)
        else:  # HYBRID_RERANK
            return self.search_with_rerank(text, k=k, where=where, initial_k=initial_k)

    # Backward compatibility alias
    query = vector_search