"""
src/hybrid.py
-------------
Hybrid retriever combining BM25 keyword search and FAISS semantic search,
fused with Reciprocal Rank Fusion (RRF).

Designed to plug into the existing run_rag() pipeline in rag_pipeline.py
as a drop-in replacement for the semantic retriever:

    hybrid_retriever = load_hybrid_retriever(
        bm25_index_path="data/processed/tokenisation/bm25_index_mini.pkl",
        faiss_store_path="data/processed/embeddings",
        k=5,
    )
    answer = run_rag(hybrid_retriever, "Best coffee beans for espresso")

The HybridRetriever class extends LangChain's BaseRetriever, so it is fully
compatible with the | (pipe) operator used in rag_pipeline.py:

    rag_chain = (
        {
            "context": hybrid_retriever | RunnableLambda(build_context),
            "question": RunnablePassthrough(),
        }
        | prompt_template
        | llm
        | StrOutputParser()
    )
"""
from __future__ import annotations
import logging
from typing import Any
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# HybridRetriever
# ---------------------------------------------------------------------------
class HybridRetriever(BaseRetriever):
    """
    Combines BM25 keyword retrieval and FAISS semantic retrieval using
    Reciprocal Rank Fusion (RRF) to produce a unified ranked document list.

    RRF score contributed by retriever r to document d:
        score_r(d) = weight_r * (1 / (rrf_c + rank(d, r)))
    where rank(d, r) is the 1-based position of d's best hit in retriever r's
    results. Documents appearing in both retrievers accumulate scores from
    both, naturally promoting results that are relevant by both keyword and
    meaning.

    Documents are deduplicated by ``metadata["parent_asin"]``. Documents
    without one fall back to their ``page_content``, so identical text
    surfaced by both retrievers is still fused.

    Returned documents are copies whose metadata additionally carries
    ``hybrid_score`` (fused RRF score, rounded to 6 decimals) and
    ``retrieval_source`` (``"bm25"``, ``"semantic"`` or ``"hybrid"``). The
    Document objects handed back by the underlying retrievers are not
    modified.

    Parameters
    ----------
    bm25_retriever : Fitted LangChain BM25Retriever (from bm25.load())
    semantic_store : Loaded FAISS vectorstore (from semantic.load_vector_store())
    k : Number of final documents to return
    rrf_c : RRF constant — dampens the impact of rank differences.
        Standard value is 60; lower = top ranks matter more.
    bm25_weight : RRF weight for BM25 results (keyword signal)
    semantic_weight : RRF weight for semantic results (meaning signal)
    fetch_multiplier : Fetch this multiple of k from each retriever before fusing.
        More candidates = better fusion quality. Default: 3. Values below 1
        are treated as 1, so each retriever is always asked for k candidates.
    """

    bm25_retriever: Any = Field(...)
    semantic_store: Any = Field(...)
    k: int = Field(default=5)
    rrf_c: int = Field(default=60)
    bm25_weight: float = Field(default=0.5)
    semantic_weight: float = Field(default=0.5)
    fetch_multiplier: int = Field(default=3)

    @staticmethod
    def _dedup_key(doc: Document, fallback: str) -> str:
        """Return the fusion key: parent_asin, else page content, else *fallback*."""
        asin = doc.metadata.get("parent_asin")
        if asin:
            return asin
        if doc.page_content:
            # The prefix keeps a content-derived key from colliding with an ASIN.
            return f"content:{doc.page_content}"
        return fallback

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """
        Core retrieval logic called by LangChain when the retriever is invoked.

        Steps
        -----
        1. Fetch candidates from BM25 and FAISS independently
        2. Assign RRF scores weighted by retriever confidence (best rank per
           retriever only)
        3. Deduplicate by parent_asin, accumulating scores for shared hits
        4. Sort by fused RRF score and return annotated copies of the top-k

        A retriever that raises is logged and treated as returning no
        results, so the other retriever's results are still returned.
        """
        # Never ask a retriever for fewer than k candidates.
        fetch_k = self.k * max(self.fetch_multiplier, 1)

        # ── 1. BM25 retrieval ───────────────────────────────────────────────
        # The BM25 retriever's result count is driven by its own `k` attribute,
        # so widen it for this call only. Restore it afterwards so the injected
        # retriever is not left permanently reconfigured.
        previous_bm25_k = self.bm25_retriever.k
        self.bm25_retriever.k = fetch_k
        try:
            bm25_docs: list[Document] = self.bm25_retriever.invoke(query)
            logger.debug("BM25 returned %d docs for query: %r", len(bm25_docs), query)
        except Exception as exc:
            # Deliberate best effort: degrade to semantic-only results.
            logger.warning(
                "BM25 retrieval failed: %s — using empty list.", exc, exc_info=True
            )
            bm25_docs = []
        finally:
            self.bm25_retriever.k = previous_bm25_k

        # ── 2. Semantic retrieval ───────────────────────────────────────────
        # similarity_search returns list[Document] (no scores needed — rank is
        # enough for RRF).
        try:
            semantic_docs: list[Document] = self.semantic_store.similarity_search(
                query, k=fetch_k
            )
            logger.debug(
                "Semantic returned %d docs for query: %r", len(semantic_docs), query
            )
        except Exception as exc:
            # Deliberate best effort: degrade to BM25-only results.
            logger.warning(
                "Semantic retrieval failed: %s — using empty list.", exc, exc_info=True
            )
            semantic_docs = []

        # ── 3. RRF fusion ───────────────────────────────────────────────────
        rrf_scores: dict[str, float] = {}
        doc_map: dict[str, Document] = {}
        bm25_keys: set[str] = set()
        semantic_keys: set[str] = set()

        def _accumulate(
            docs: list[Document], weight: float, fallback_prefix: str, seen: set[str]
        ) -> None:
            """Add one retriever's weighted RRF contributions to rrf_scores."""
            for rank, doc in enumerate(docs):
                key = self._dedup_key(doc, f"{fallback_prefix}_{rank}")
                if key in seen:
                    # Only the best (first) rank per retriever counts, matching
                    # the single rank(d, r) term in the RRF formula.
                    continue
                seen.add(key)
                rrf_scores[key] = rrf_scores.get(key, 0.0) + weight / (
                    self.rrf_c + rank + 1
                )
                # First writer wins. BM25 is processed first, so its document is
                # kept for shared products (BM25 metadata is richer — has
                # top_reviews, image_url, etc.).
                doc_map.setdefault(key, doc)

        _accumulate(bm25_docs, self.bm25_weight, "bm25", bm25_keys)
        _accumulate(semantic_docs, self.semantic_weight, "sem", semantic_keys)

        # ── 4. Sort and truncate ────────────────────────────────────────────
        # sorted() is stable, so ties keep insertion order (BM25 hits first).
        ranked_keys = sorted(rrf_scores, key=rrf_scores.__getitem__, reverse=True)
        top_keys = ranked_keys[: max(self.k, 0)]

        top_docs: list[Document] = []
        for key in top_keys:
            in_bm25 = key in bm25_keys
            in_sem = key in semantic_keys
            if in_bm25 and in_sem:
                source = "hybrid"
            elif in_bm25:
                source = "bm25"
            else:
                source = "semantic"
            original = doc_map[key]
            # Annotate a copy rather than mutating in place. The retrievers may
            # return the very Document objects they store, and writing per-query
            # scores onto those would leak them into the shared index.
            top_docs.append(
                original.model_copy(
                    update={
                        "metadata": {
                            **original.metadata,
                            "hybrid_score": round(rrf_scores[key], 6),
                            "retrieval_source": source,
                        }
                    }
                )
            )

        logger.info(
            "HybridRetriever: BM25=%d, Semantic=%d → fused=%d (returning top %d)",
            len(bm25_docs), len(semantic_docs), len(rrf_scores), len(top_docs),
        )
        return top_docs
# ---------------------------------------------------------------------------
# Convenience loader
# ---------------------------------------------------------------------------
def load_hybrid_retriever(
    bm25_index_path: str = "data/processed/tokenisation/bm25_index_mini.pkl",
    faiss_store_path: str = "data/processed/embeddings",
    k: int = 5,
    bm25_weight: float = 0.5,
    semantic_weight: float = 0.5,
    rrf_c: int = 60,
    fetch_multiplier: int = 3,
) -> HybridRetriever:
    """
    Load both indexes from disk and return a ready-to-use HybridRetriever.

    Call this once in your notebook or app.py, then pass the result to run_rag().
    Progress messages are printed to stdout while the indexes load.

    Parameters
    ----------
    bm25_index_path : Path to the pickled BM25Retriever (from bm25.build_and_save())
    faiss_store_path : Directory containing index.faiss + index.pkl
        (from semantic.build_and_save_vector_store())
    k : Number of documents to return per query
    bm25_weight : RRF weight for BM25 (keyword signal). Default 0.5.
    semantic_weight : RRF weight for semantic (meaning signal). Default 0.5.
        Weights don't need to sum to 1 but relative scale matters.
    rrf_c : RRF rank-dampening constant. Default 60 (standard).
    fetch_multiplier : Candidates to fetch per retriever = k * fetch_multiplier.

    Returns
    -------
    HybridRetriever
        A LangChain-compatible retriever pipeable with |.

    Notes
    -----
    Both inputs are pickle files (the BM25 ``.pkl`` and the FAISS ``index.pkl``).
    Unpickling can execute arbitrary code, so only point these paths at files
    you built yourself or otherwise trust.

    Example
    -------
    >>> from src.hybrid import load_hybrid_retriever
    >>> from src.rag_pipeline import run_rag
    >>>
    >>> hybrid = load_hybrid_retriever(k=5)
    >>> answer = run_rag(hybrid, "Best coffee beans for a French press")
    >>> print(answer)
    """
    # Import here to avoid circular imports when used from rag_pipeline.py
    from src.bm25 import load as load_bm25
    from src.semantic import load_vector_store

    print(f"Loading BM25 index from: {bm25_index_path}")
    bm25_ret: BM25Retriever = load_bm25(bm25_index_path)

    print(f"Loading FAISS store from: {faiss_store_path}")
    faiss_store: FAISS = load_vector_store(faiss_store_path)

    retriever = HybridRetriever(
        bm25_retriever=bm25_ret,
        semantic_store=faiss_store,
        k=k,
        bm25_weight=bm25_weight,
        semantic_weight=semantic_weight,
        rrf_c=rrf_c,
        fetch_multiplier=fetch_multiplier,
    )
    print(
        f"HybridRetriever ready — k={k}, "
        f"BM25 weight={bm25_weight}, Semantic weight={semantic_weight}, RRF c={rrf_c}"
    )
    return retriever