# amazon_retriever / src /hybrid.py
# Sarisha Das
# fix all paths
# e51a05a
"""
src/hybrid.py
-------------
Hybrid retriever combining BM25 keyword search and FAISS semantic search,
fused with Reciprocal Rank Fusion (RRF).
Designed to plug into the existing run_rag() pipeline in rag_pipeline.py
as a drop-in replacement for the semantic retriever:

    hybrid_retriever = load_hybrid_retriever(
        bm25_index_path="data/processed/tokenisation/bm25_index_mini.pkl",
        faiss_store_path="data/processed/embeddings",
        k=5,
    )
    answer = run_rag(hybrid_retriever, "Best coffee beans for espresso")

The HybridRetriever class extends LangChain's BaseRetriever so it is fully
compatible with the | (pipe) operator used in rag_pipeline.py:

    rag_chain = (
        {
            "context": hybrid_retriever | RunnableLambda(build_context),
            "question": RunnablePassthrough(),
        }
        | prompt_template
        | llm
        | StrOutputParser()
    )
"""
from __future__ import annotations

import copy
import logging
from typing import Any

from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field
# Module-level logger named after this module; no handlers or levels are
# configured in this file.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# HybridRetriever
# ---------------------------------------------------------------------------
class HybridRetriever(BaseRetriever):
    """
    Combines BM25 keyword retrieval and FAISS semantic retrieval using
    Reciprocal Rank Fusion (RRF) to produce a unified ranked document list.

    RRF contribution of retriever r to document d (rank is 1-based):

        score(d, r) = weight_r * (1 / (rrf_c + rank(d, r)))

    A document's fused score is the sum of its contributions. Documents
    appearing in both retrievers accumulate scores from both, naturally
    promoting results that are relevant by both keyword and meaning.
    Within a single retriever a product is counted once, at its best rank.

    Parameters
    ----------
    bm25_retriever   : Fitted LangChain BM25Retriever (from bm25.load()).
                       Its ``k`` attribute is overwritten on every query.
    semantic_store   : Loaded FAISS vectorstore (from semantic.load_vector_store())
    k                : Number of final documents to return
    rrf_c            : RRF constant — dampens the impact of rank differences.
                       Standard value is 60; lower = top ranks matter more.
    bm25_weight      : RRF weight for BM25 results (keyword signal)
    semantic_weight  : RRF weight for semantic results (meaning signal)
    fetch_multiplier : Fetch this multiple of k from each retriever before fusing.
                       More candidates = better fusion quality. Default: 3.
                       Values below 1 are treated as 1.
    """

    bm25_retriever: Any = Field(...)
    semantic_store: Any = Field(...)
    k: int = Field(default=5)
    rrf_c: int = Field(default=60)
    bm25_weight: float = Field(default=0.5)
    semantic_weight: float = Field(default=0.5)
    fetch_multiplier: int = Field(default=3)

    @staticmethod
    def _first_hits(
        docs: list[Document], label: str
    ) -> dict[str, tuple[int, Document]]:
        """
        Map each dedup key to the (0-based rank, Document) of its first hit.

        The key is the document's ``parent_asin`` metadata value. Documents
        without one get a unique ``"<label>_<rank>"`` key, so they are never
        merged with any other document.
        """
        hits: dict[str, tuple[int, Document]] = {}
        for rank, doc in enumerate(docs):
            key = doc.metadata.get("parent_asin") or f"{label}_{rank}"
            # Keep only the best-ranked occurrence of a product per retriever.
            hits.setdefault(key, (rank, doc))
        return hits

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """
        Core retrieval logic called by LangChain when the retriever is invoked.

        Steps
        -----
        1. Fetch candidates from BM25 and FAISS independently
        2. Deduplicate each candidate list by parent_asin (best rank wins)
        3. Sum the weighted RRF contributions of both retrievers per product
        4. Sort by fused RRF score and return annotated copies of the top-k

        Parameters
        ----------
        query       : The user query string.
        run_manager : Callback manager supplied by LangChain (unused here).

        Returns
        -------
        list[Document]
            At most ``k`` documents, best first. Each is a shallow copy of the
            retrieved document whose metadata additionally carries
            ``hybrid_score`` (fused score rounded to 6 places) and
            ``retrieval_source`` ("hybrid", "bm25" or "semantic").
            A failing retriever is logged and treated as returning nothing.
        """
        if self.k <= 0:
            return []
        # Never fetch fewer than k candidates, even with a multiplier below 1.
        fetch_k = self.k * max(self.fetch_multiplier, 1)

        # ── 1. BM25 retrieval ────────────────────────────────────────────────
        # NOTE: this overwrites `k` on the wrapped retriever object.
        self.bm25_retriever.k = fetch_k
        try:
            bm25_docs: list[Document] = self.bm25_retriever.invoke(query)
            logger.debug("BM25 returned %d docs for query: %r", len(bm25_docs), query)
        except Exception as exc:
            # Deliberate best-effort: fall back to semantic-only results.
            logger.warning("BM25 retrieval failed: %s — using empty list.", exc)
            bm25_docs = []

        # ── 2. Semantic retrieval ────────────────────────────────────────────
        # similarity_search returns list[Document] (no scores needed — rank is
        # enough for RRF)
        try:
            semantic_docs: list[Document] = self.semantic_store.similarity_search(
                query, k=fetch_k
            )
            logger.debug(
                "Semantic returned %d docs for query: %r", len(semantic_docs), query
            )
        except Exception as exc:
            # Deliberate best-effort: fall back to BM25-only results.
            logger.warning("Semantic retrieval failed: %s — using empty list.", exc)
            semantic_docs = []

        # ── 3. RRF fusion ────────────────────────────────────────────────────
        bm25_hits = self._first_hits(bm25_docs, "bm25")
        semantic_hits = self._first_hits(semantic_docs, "sem")

        rrf_scores: dict[str, float] = {}
        doc_map: dict[str, Document] = {}
        # BM25 goes first so that, for products found by both retrievers, the
        # BM25 document is the one kept (its metadata is richer — top_reviews,
        # image_url, etc.).
        for hits, weight in (
            (bm25_hits, self.bm25_weight),
            (semantic_hits, self.semantic_weight),
        ):
            for key, (rank, doc) in hits.items():
                contribution = weight / (self.rrf_c + rank + 1)
                rrf_scores[key] = rrf_scores.get(key, 0.0) + contribution
                doc_map.setdefault(key, doc)

        # ── 4. Sort, truncate and annotate ───────────────────────────────────
        # sorted() is stable, so tied scores keep their insertion order.
        ranked_keys = sorted(rrf_scores, key=rrf_scores.__getitem__, reverse=True)

        top_docs: list[Document] = []
        for key in ranked_keys[: self.k]:
            # Record which retriever(s) contributed to this result.
            in_bm25 = key in bm25_hits
            in_sem = key in semantic_hits
            if in_bm25 and in_sem:
                source = "hybrid"
            elif in_bm25:
                source = "bm25"
            else:
                source = "semantic"

            # The retrievers may hand back the very objects stored in their
            # indexes; annotate a copy so per-query values never leak into
            # that shared state.
            original = doc_map[key]
            annotated = copy.copy(original)
            annotated.metadata = {
                **original.metadata,
                # Fused score — useful for app display.
                "hybrid_score": round(rrf_scores[key], 6),
                "retrieval_source": source,
            }
            top_docs.append(annotated)

        logger.info(
            "HybridRetriever: BM25=%d, Semantic=%d → fused=%d (returning top %d)",
            len(bm25_docs), len(semantic_docs), len(rrf_scores), len(top_docs),
        )
        return top_docs
# ---------------------------------------------------------------------------
# Convenience loader
# ---------------------------------------------------------------------------
def load_hybrid_retriever(
    bm25_index_path: str = "data/processed/tokenisation/bm25_index_mini.pkl",
    faiss_store_path: str = "data/processed/embeddings",
    k: int = 5,
    bm25_weight: float = 0.5,
    semantic_weight: float = 0.5,
    rrf_c: int = 60,
    fetch_multiplier: int = 3,
) -> HybridRetriever:
    """
    Load both indexes from disk and return a ready-to-use HybridRetriever.

    Call this once in your notebook or app.py, then pass the result to run_rag().
    Progress is reported on stdout via print().

    Parameters
    ----------
    bm25_index_path  : Path to the pickled BM25Retriever (from bm25.build_and_save())
    faiss_store_path : Directory containing index.faiss + index.pkl
                       (from semantic.build_and_save_vector_store())
    k                : Number of documents to return per query
    bm25_weight      : RRF weight for BM25 (keyword signal). Default 0.5.
    semantic_weight  : RRF weight for semantic (meaning signal). Default 0.5.
                       Weights don't need to sum to 1 but relative scale matters.
    rrf_c            : RRF rank-dampening constant. Default 60 (standard).
    fetch_multiplier : Candidates to fetch per retriever = k * fetch_multiplier.

    Returns
    -------
    HybridRetriever
        A LangChain-compatible retriever pipeable with |.

    Example
    -------
    >>> from src.hybrid import load_hybrid_retriever
    >>> from src.rag_pipeline import run_rag
    >>>
    >>> hybrid = load_hybrid_retriever(k=5)
    >>> answer = run_rag(hybrid, "Best coffee beans for a French press")
    >>> print(answer)
    """
    # Import here to avoid circular imports when used from rag_pipeline.py
    from src.bm25 import load as load_bm25
    from src.semantic import load_vector_store

    # NOTE(review): the BM25 index is described as a pickle. If the loader
    # unpickles it, only point this at files from a trusted source —
    # confirm in src/bm25.py.
    print(f"Loading BM25 index from: {bm25_index_path}")
    bm25_ret: BM25Retriever = load_bm25(bm25_index_path)

    print(f"Loading FAISS store from: {faiss_store_path}")
    faiss_store: FAISS = load_vector_store(faiss_store_path)

    retriever = HybridRetriever(
        bm25_retriever=bm25_ret,
        semantic_store=faiss_store,
        k=k,
        bm25_weight=bm25_weight,
        semantic_weight=semantic_weight,
        rrf_c=rrf_c,
        fetch_multiplier=fetch_multiplier,
    )
    print(
        f"HybridRetriever ready — k={k}, "
        f"BM25 weight={bm25_weight}, Semantic weight={semantic_weight}, RRF c={rrf_c}"
    )
    return retriever