Spaces:
Sleeping
Sleeping
| """ | |
| src/hybrid.py | |
| ------------- | |
| Hybrid retriever combining BM25 keyword search and FAISS semantic search, | |
| fused with Reciprocal Rank Fusion (RRF). | |
| Designed to plug into the existing run_rag() pipeline in rag_pipeline.py | |
| as a drop-in replacement for the semantic retriever: | |
| hybrid_retriever = load_hybrid_retriever( | |
| bm25_index_path="data/processed/tokenisation/bm25_index_mini.pkl", | |
| faiss_store_path="data/processed/embeddings", | |
| k=5, | |
| ) | |
| answer = run_rag(hybrid_retriever, "Best coffee beans for espresso") | |
| The HybridRetriever class extends LangChain's BaseRetriever so it is fully | |
| compatible with the | (pipe) operator used in rag_pipeline.py: | |
| rag_chain = ( | |
| { | |
| "context": hybrid_retriever | RunnableLambda(build_context), | |
| "question": RunnablePassthrough(), | |
| } | |
| | prompt_template | |
| | llm | |
| | StrOutputParser() | |
| ) | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from typing import Any | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_core.callbacks import CallbackManagerForRetrieverRun | |
| from langchain_core.documents import Document | |
| from langchain_core.retrievers import BaseRetriever | |
| from pydantic import Field | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # HybridRetriever | |
| # --------------------------------------------------------------------------- | |
class HybridRetriever(BaseRetriever):
    """
    Combine BM25 keyword retrieval and FAISS semantic retrieval using
    weighted Reciprocal Rank Fusion (RRF) to produce one ranked document list.

    For every product d returned by retriever r, the retriever contributes

        weight_r / (rrf_c + rank(d, r))

    where rank(d, r) is the 1-based position of the *first* hit for d in r's
    result list. The fused score of d is the sum of the contributions of the
    retrievers that returned it, so products found by both keyword and
    semantic search are promoted. Repeated hits for the same product inside a
    single result list are counted once (best rank only).

    Products are identified by ``metadata["parent_asin"]``. A document without
    that key gets a per-retriever, rank-based key and is therefore never merged
    with any other document.

    Parameters
    ----------
    bm25_retriever   : Fitted LangChain BM25Retriever (from bm25.load())
    semantic_store   : Loaded FAISS vectorstore (from semantic.load_vector_store())
    k                : Number of final documents to return
    rrf_c            : RRF constant — dampens the impact of rank differences.
                       Standard value is 60; lower = top ranks matter more.
    bm25_weight      : RRF weight for BM25 results (keyword signal)
    semantic_weight  : RRF weight for semantic results (meaning signal)
    fetch_multiplier : Fetch this multiple of k from each retriever before fusing.
                       More candidates = better fusion quality. Default: 3.
                       Values below 1 are treated as 1.
    """

    bm25_retriever: Any = Field(...)
    semantic_store: Any = Field(...)
    k: int = Field(default=5)
    rrf_c: int = Field(default=60)
    bm25_weight: float = Field(default=0.5)
    semantic_weight: float = Field(default=0.5)
    fetch_multiplier: int = Field(default=3)

    @staticmethod
    def _annotated_copy(doc: Document, score: float, source: str) -> Document:
        """Return a copy of *doc* whose metadata also carries the fusion results."""
        extra: dict[str, Any] = {}
        # Carry the document id over when this Document version has one.
        doc_id = getattr(doc, "id", None)
        if doc_id is not None:
            extra["id"] = doc_id
        return Document(
            page_content=doc.page_content,
            metadata={
                **doc.metadata,
                "hybrid_score": score,
                "retrieval_source": source,
            },
            **extra,
        )

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """
        Core retrieval logic called by LangChain when the retriever is invoked.

        Steps
        -----
        1. Fetch ``k * fetch_multiplier`` candidates from BM25 and FAISS
           independently; a failing retriever is logged and treated as empty.
        2. Give each product one RRF contribution per retriever (best rank),
           weighted by ``bm25_weight`` / ``semantic_weight``.
        3. Sum the contributions per ``parent_asin``.
        4. Sort by fused score and return the top-k Documents.

        Returns
        -------
        list[Document]
            At most ``k`` documents, best first. Each is a copy of the
            retrieved document with ``hybrid_score`` (fused score rounded to
            6 decimals) and ``retrieval_source`` ("hybrid", "bm25" or
            "semantic") added to its metadata. Empty when ``k`` is not
            positive or both retrievers return nothing.
        """
        if self.k <= 0:
            # A non-positive k would otherwise produce a meaningless slice below.
            return []
        fetch_k = self.k * max(1, self.fetch_multiplier)

        # ── 1. BM25 retrieval ────────────────────────────────────────────────
        # The BM25 retriever takes its result count from its `k` attribute,
        # so it is set on the shared instance before invoking.
        self.bm25_retriever.k = fetch_k
        try:
            bm25_docs: list[Document] = self.bm25_retriever.invoke(query)
            logger.debug("BM25 returned %d docs for query: %r", len(bm25_docs), query)
        except Exception as exc:
            # Deliberate best-effort: fall back to the other retriever only.
            logger.warning("BM25 retrieval failed: %s — using empty list.", exc)
            bm25_docs = []

        # ── 2. Semantic retrieval ────────────────────────────────────────────
        # similarity_search returns list[Document] (no scores needed — rank is
        # enough for RRF).
        try:
            semantic_docs: list[Document] = self.semantic_store.similarity_search(
                query, k=fetch_k
            )
            logger.debug(
                "Semantic returned %d docs for query: %r", len(semantic_docs), query
            )
        except Exception as exc:
            logger.warning("Semantic retrieval failed: %s — using empty list.", exc)
            semantic_docs = []

        # ── 3. RRF fusion ────────────────────────────────────────────────────
        rrf_scores: dict[str, float] = {}
        doc_map: dict[str, Document] = {}
        keys_by_source: dict[str, set[str]] = {}

        # BM25 is processed first so that, for products found by both
        # retrievers, doc_map keeps the BM25 document (BM25 metadata is
        # richer — has top_reviews, image_url, etc.).
        candidate_lists = (
            ("bm25", self.bm25_weight, bm25_docs),
            ("sem", self.semantic_weight, semantic_docs),
        )
        for label, weight, docs in candidate_lists:
            seen: set[str] = set()
            for rank, doc in enumerate(docs, start=1):
                # Dedup key: parent_asin, else a rank-tagged key unique to this hit.
                key = doc.metadata.get("parent_asin") or f"{label}_{rank - 1}"
                if key in seen:
                    # Same product again in this list: its best rank already counted.
                    continue
                seen.add(key)
                rrf_scores[key] = rrf_scores.get(key, 0.0) + weight / (self.rrf_c + rank)
                doc_map.setdefault(key, doc)
            keys_by_source[label] = seen

        # ── 4. Sort and truncate ─────────────────────────────────────────────
        # sorted() is stable, so ties keep first-seen (BM25-first) order.
        ranked_keys = sorted(rrf_scores, key=rrf_scores.__getitem__, reverse=True)

        top_docs: list[Document] = []
        for key in ranked_keys[: self.k]:
            # Record which retriever(s) contributed to this result.
            in_bm25 = key in keys_by_source["bm25"]
            in_sem = key in keys_by_source["sem"]
            if in_bm25 and in_sem:
                source = "hybrid"
            elif in_bm25:
                source = "bm25"
            else:
                source = "semantic"
            # Annotate a copy: the retrievers may hand back the very Document
            # objects they store, and mutating those would leak this query's
            # score/source into later queries.
            top_docs.append(
                self._annotated_copy(doc_map[key], round(rrf_scores[key], 6), source)
            )

        logger.info(
            "HybridRetriever: BM25=%d, Semantic=%d → fused=%d (returning top %d)",
            len(bm25_docs), len(semantic_docs), len(rrf_scores), len(top_docs),
        )
        return top_docs
| # --------------------------------------------------------------------------- | |
| # Convenience loader | |
| # --------------------------------------------------------------------------- | |
def load_hybrid_retriever(
    bm25_index_path: str = "data/processed/tokenisation/bm25_index_mini.pkl",
    faiss_store_path: str = "data/processed/embeddings",
    k: int = 5,
    bm25_weight: float = 0.5,
    semantic_weight: float = 0.5,
    rrf_c: int = 60,
    fetch_multiplier: int = 3,
) -> HybridRetriever:
    """
    Load both indexes from disk and return a ready-to-use HybridRetriever.

    Call this once in your notebook or app.py, then pass the result to run_rag().
    Progress is reported on stdout via print().

    Parameters
    ----------
    bm25_index_path  : Path to the pickled BM25Retriever (from bm25.build_and_save())
    faiss_store_path : Directory containing index.faiss + index.pkl
                       (from semantic.build_and_save_vector_store())
    k                : Number of documents to return per query
    bm25_weight      : RRF weight for BM25 (keyword signal). Default 0.5.
    semantic_weight  : RRF weight for semantic (meaning signal). Default 0.5.
                       Weights don't need to sum to 1 but relative scale matters.
    rrf_c            : RRF rank-dampening constant. Default 60 (standard).
    fetch_multiplier : Candidates to fetch per retriever = k * fetch_multiplier.

    Returns
    -------
    HybridRetriever
        A LangChain-compatible retriever pipeable with |.

    Notes
    -----
    NOTE(review): both indexes are described as pickle-backed files; only point
    these paths at files you created yourself or otherwise trust.

    Example
    -------
    >>> from src.hybrid import load_hybrid_retriever
    >>> from src.rag_pipeline import run_rag
    >>>
    >>> hybrid = load_hybrid_retriever(k=5)
    >>> answer = run_rag(hybrid, "Best coffee beans for a French press")
    >>> print(answer)
    """
    # Import here to avoid circular imports when used from rag_pipeline.py
    from src.bm25 import load as load_bm25
    from src.semantic import load_vector_store

    print(f"Loading BM25 index from: {bm25_index_path}")
    bm25_ret: BM25Retriever = load_bm25(bm25_index_path)

    print(f"Loading FAISS store from: {faiss_store_path}")
    faiss_store: FAISS = load_vector_store(faiss_store_path)

    retriever = HybridRetriever(
        bm25_retriever=bm25_ret,
        semantic_store=faiss_store,
        k=k,
        bm25_weight=bm25_weight,
        semantic_weight=semantic_weight,
        rrf_c=rrf_c,
        fetch_multiplier=fetch_multiplier,
    )
    print(
        f"HybridRetriever ready — k={k}, "
        f"BM25 weight={bm25_weight}, Semantic weight={semantic_weight}, RRF c={rrf_c}"
    )
    return retriever