RAG / phase2_retrieval.py
sumitnewold's picture
Upload 10 files
76bd1fc verified
Raw
History Blame Contribute Delete
13.5 kB
"""
Phase 2 — Advanced Retrieval Pipeline
Steps:
1. Full ingestion — all 369 pages into Chroma (skipped if already done)
2. BM25 retriever — keyword search over full pages
3. Hybrid retrieval — BM25 (40%) + Dense (60%) via EnsembleRetriever
4. Query decomposition — break complex queries into sub-questions
5. HyDE — generate hypothetical answer, use it for dense retrieval
6. Cross-encoder reranking — accurate relevance scoring, keep up to 4
7. Contextual compression — drop weakly related docs with an embeddings filter (no LLM call)
Public API for Phase 3:
advanced_retrieval_pipeline(query, bm25_retriever, dense_retriever,
cross_encoder, llm, embeddings=None, top_k=4) -> list[Document]
"""
import json
import os
import time
from dotenv import load_dotenv
from langchain_classic.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_classic.retrievers.document_compressors import (
EmbeddingsFilter,
LLMChainExtractor,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import CrossEncoder
from phase1_ingestion import (
PDF_PATH,
CHROMA_DIR,
SemanticParentSplitter,
build_embeddings,
build_llm,
load_pdf,
tag_financial_entities,
)
# Load variables from a local .env file into the process environment.
load_dotenv()

# Name of the Chroma collection this module reads from and writes to.
PHASE2_COLLECTION = "child_chunks_full"
# Model identifier handed to sentence_transformers.CrossEncoder for reranking.
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# Minimum number of stored chunks for the vectorstore to count as fully indexed.
FULL_INDEX_THRESHOLD = 1000
# chunk_size / chunk_overlap passed to RecursiveCharacterTextSplitter when
# splitting pages into the child chunks that get embedded.
CHILD_CHUNK_SIZE = 400
CHILD_CHUNK_OVERLAP = 50
# ── Step 1: Full ingestion ────────────────────────────────────────────────────
def load_or_build_vectorstore(embeddings, tagged_pages):
    """
    Load the persisted Phase 2 Chroma collection, indexing all pages if needed.

    Uses a dedicated Phase 2 collection so Phase 1's smoke-test data is untouched.

    Indexing is skipped when the collection already holds at least
    FULL_INDEX_THRESHOLD chunks. Otherwise every page is split into child
    chunks and added in batches. Each chunk gets a deterministic id derived
    from its position, so re-running after an interrupted index (or on a
    corpus that never reaches the threshold) overwrites the same entries
    instead of appending duplicates.

    NOTE: chunks written by an earlier run that used auto-generated ids are
    not removed by this function.

    Args:
        embeddings: Embedding function handed to Chroma.
        tagged_pages: Page documents to split and index.

    Returns:
        The Chroma vectorstore bound to PHASE2_COLLECTION.
    """
    vectorstore = Chroma(
        collection_name=PHASE2_COLLECTION,
        embedding_function=embeddings,
        persist_directory=CHROMA_DIR,
    )
    count = vectorstore._collection.count()
    if count >= FULL_INDEX_THRESHOLD:
        print(f"[INGESTION] Vectorstore already has {count} chunks β€” skipping re-index")
        return vectorstore

    print(f"[INGESTION] Indexing all {len(tagged_pages)} pages into '{PHASE2_COLLECTION}'...")
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHILD_CHUNK_SIZE, chunk_overlap=CHILD_CHUNK_OVERLAP
    )
    all_chunks = child_splitter.split_documents(tagged_pages)
    total = len(all_chunks)
    print(f"[INGESTION] {total} child chunks to embed")

    t0 = time.perf_counter()
    batch_size = 100
    for start in range(0, total, batch_size):
        batch = all_chunks[start : start + batch_size]
        # Position-based ids are unique within a run and stable across runs
        # (same pages + same splitter settings), which makes the add an upsert.
        batch_ids = [
            f"{PHASE2_COLLECTION}-{start + offset}" for offset in range(len(batch))
        ]
        vectorstore.add_documents(batch, ids=batch_ids)
        done = min(start + batch_size, total)
        print(f" [{done}/{total}]", end="\r")
    print(f"\n[INGESTION] Done in {time.perf_counter() - t0:.1f}s β€” "
          f"{vectorstore._collection.count()} chunks indexed")
    return vectorstore
# ── Step 2: BM25 retriever ────────────────────────────────────────────────────
def build_bm25_retriever(tagged_pages, k: int = 8) -> BM25Retriever:
    """Create a BM25 keyword retriever over the given page documents.

    Args:
        tagged_pages: Documents to index with BM25.
        k: Value assigned to the retriever's ``k`` attribute.

    Returns:
        The configured BM25Retriever.
    """
    keyword_retriever = BM25Retriever.from_documents(tagged_pages)
    keyword_retriever.k = k
    page_total = len(tagged_pages)
    print(f"[BM25] Built over {page_total} pages (k={k})")
    return keyword_retriever
# ── Step 3: Dense retriever ───────────────────────────────────────────────────
def build_dense_retriever(vectorstore, k: int = 8):
    """Wrap the vectorstore as a retriever.

    Args:
        vectorstore: Object exposing ``as_retriever(search_kwargs=...)``.
        k: Passed through as ``search_kwargs={"k": k}``.

    Returns:
        Whatever ``vectorstore.as_retriever`` returns.
    """
    search_options = {"k": k}
    dense = vectorstore.as_retriever(search_kwargs=search_options)
    print(f"[DENSE] Chroma retriever ready (k={k})")
    return dense
# ── Step 4: Hybrid EnsembleRetriever ──────────────────────────────────────────
def build_hybrid_retriever(bm25_retriever, dense_retriever) -> EnsembleRetriever:
    """Combine the keyword and dense retrievers into one EnsembleRetriever.

    The BM25 retriever is weighted 0.4 and the dense retriever 0.6.
    """
    members = [bm25_retriever, dense_retriever]
    member_weights = [0.4, 0.6]
    ensemble = EnsembleRetriever(retrievers=members, weights=member_weights)
    print("[HYBRID] EnsembleRetriever: BM25 40% + Dense 60%")
    return ensemble
# ── Step 5: Query decomposition ───────────────────────────────────────────────
# Prompt template used by decompose_query; `{query}` is filled in with
# str.format. It asks for a bare JSON array of strings so the reply can be
# parsed as JSON.
_DECOMP_PROMPT = """You are a financial research assistant.
Break this question into 2-4 simpler sub-questions that together cover the full answer.
Return ONLY a valid JSON array of strings with no extra text.
Question: {query}
Sub-questions:"""
# Lower-case substrings whose presence marks a query as covering several
# distinct sub-questions.
_MULTI_INTENT_MARKERS = (
    "compare",
    "comparison",
    "versus",
    " vs ",
    "difference between",
    "as well as",
    "and also",
    "both",
    "list all",
    "breakdown",
)


def is_complex_query(query: str) -> bool:
    """
    Decide, without an LLM call, whether a query is worth decomposing.

    Decomposition helps multi-part questions but hurts simple factual ones
    (the baseline turned 'operating margin?' into 'what is the formula…',
    pulling in irrelevant docs), so a query counts as simple unless it is
    clearly multi-intent.

    A query is complex when any of these holds:
      * it contains more than one '?';
      * its lower-cased text contains one of _MULTI_INTENT_MARKERS;
      * it contains " and " and has more than 18 whitespace-separated words.
    """
    lowered = query.lower()
    has_several_questions = query.count("?") >= 2
    names_multi_intent = any(marker in lowered for marker in _MULTI_INTENT_MARKERS)
    # A conjunction only counts in a long query, where both halves are likely
    # to be separate asks.
    long_conjunction = " and " in lowered and len(query.split()) > 18
    return has_several_questions or names_multi_intent or long_conjunction
def decompose_query(query: str, llm) -> list[str]:
    """
    Split a complex query into sub-questions using the LLM.

    Simple queries (per is_complex_query) skip the LLM call entirely, which
    saves tokens and latency and avoids candidate dilution. For complex
    queries the LLM is asked for a JSON array of strings. The reply is
    accepted only if it parses to a non-empty list of non-blank strings;
    in every other case the original query is used on its own.

    Args:
        query: The user question.
        llm: Object whose ``invoke(prompt)`` returns a response with ``.content``.

    Returns:
        A non-empty list of sub-queries, or ``[query]`` on skip or fallback.
    """
    if not is_complex_query(query):
        print("[DECOMP] Simple query β€” skipping decomposition")
        return [query]

    parsed = None
    try:
        resp = llm.invoke(_DECOMP_PROMPT.format(query=query))
        text = str(resp.content)
        # Models often wrap the array in markdown fences or add a preamble;
        # keep only the outermost [...] span before parsing.
        start, end = text.find("["), text.rfind("]")
        if start != -1 and end > start:
            text = text[start : end + 1]
        parsed = json.loads(text)
    except Exception as exc:
        # Best-effort step: report why decomposition failed, then fall back.
        print(f"[DECOMP] Decomposition failed ({type(exc).__name__}: {exc})")

    if isinstance(parsed, list) and all(isinstance(q, str) for q in parsed):
        # Drop blank entries; an empty result would retrieve nothing downstream.
        sub_queries = [q.strip() for q in parsed if q.strip()]
        if sub_queries:
            print(f"[DECOMP] {len(sub_queries)} sub-queries: {sub_queries}")
            return sub_queries

    print("[DECOMP] Fallback β€” using original query")
    return [query]
# ── Step 6: HyDE ──────────────────────────────────────────────────────────────
# Prompt template for generate_hyde; `{query}` is filled in with str.format.
_HYDE_PROMPT = """You are a financial analyst writing an annual report excerpt.
Write 2-3 sentences from a financial document that directly answers this question.
Write ONLY the document text, no preamble.
Question: {query}
Excerpt:"""


def generate_hyde(query: str, llm) -> str:
    """
    Generate a hypothetical answer passage (HyDE) for the query.

    Degrades gracefully: if the LLM call raises (e.g. rate limit), or the
    reply is blank or not a string, the raw query is returned so dense
    retrieval still has usable text.

    Args:
        query: The question to answer hypothetically.
        llm: Object whose ``invoke(prompt)`` returns a response with ``.content``.

    Returns:
        The LLM's text unchanged, or ``query`` on any fallback.
    """
    try:
        hyde_text = llm.invoke(_HYDE_PROMPT.format(query=query)).content
    except Exception:
        print("[HyDE] LLM unavailable β€” falling back to raw query")
        return query
    # A blank or non-text reply would embed to a meaningless vector.
    if not isinstance(hyde_text, str) or not hyde_text.strip():
        print("[HyDE] Empty LLM response, falling back to raw query")
        return query
    return hyde_text
# ── Step 7: Cross-encoder reranking ───────────────────────────────────────────
def build_cross_encoder() -> CrossEncoder:
    """Instantiate the CrossEncoder named by CROSS_ENCODER_MODEL and return it."""
    reranker = CrossEncoder(CROSS_ENCODER_MODEL)
    print(f"[RERANKER] Loaded {CROSS_ENCODER_MODEL}")
    return reranker
def rerank(query: str, docs: list[Document], cross_encoder: CrossEncoder,
           top_k: int = 4, min_keep: int = 1, rel_margin: float = 4.0) -> list[Document]:
    """
    Score docs against the query with the cross-encoder and keep the strong ones.

    Rather than always padding to `top_k`, a doc among the `top_k` best is
    kept only if its score is within `rel_margin` of the best score. The
    baseline kept 3/4 irrelevant docs that diluted the LLM context. If fewer
    than `min_keep` docs survive, the `min_keep` highest-scoring docs are
    returned so an answer can still be grounded.

    Args:
        query: Text paired with each document for scoring.
        docs: Candidate documents; an empty list is returned as-is.
        cross_encoder: Object whose ``predict(pairs)`` yields one score per pair.
        top_k: Maximum number of ranked docs considered for keeping.
        min_keep: Minimum number of docs to return.
        rel_margin: Allowed score gap below the best doc.

    Returns:
        Kept documents, highest score first.
    """
    if not docs:
        return docs

    raw_scores = cross_encoder.predict([(query, d.page_content) for d in docs])
    scored = list(zip(raw_scores, docs))
    # Stable sort on the score only, so ties keep their original order.
    scored.sort(key=lambda pair: pair[0], reverse=True)

    best = float(scored[0][0])
    cutoff = best - rel_margin
    kept = []
    for score, doc in scored[:top_k]:
        if float(score) >= cutoff:
            kept.append(doc)
    if len(kept) < min_keep:
        kept = [doc for _, doc in scored[:min_keep]]

    print(f"[RERANKER] {len(docs)} candidates β†’ {len(kept)} kept "
          f"(best={best:.2f}, margin={rel_margin})")
    return kept
# ── Step 8: Contextual compression ────────────────────────────────────────────
def compress_docs(query: str, docs: list[Document], llm,
                  embeddings=None) -> list[Document]:
    """
    Token-free contextual compression.

    The original implementation ran an LLMChainExtractor per doc — up to 4
    extra LLM calls per query, a major driver of the token exhaustion and
    latency seen in the baseline, while leaving the doc set essentially
    unchanged. The cross-encoder rerank already selects the relevant docs,
    so when embeddings are available a cheap EmbeddingsFilter is applied
    (no LLM); otherwise the docs pass through.

    Args:
        query: Query the documents are filtered against.
        docs: Documents to filter; an empty list is returned as-is.
        llm: Unused; kept so existing callers need not change.
        embeddings: Optional embeddings for EmbeddingsFilter. When None,
            the docs are returned unchanged.

    Returns:
        The filtered documents, or the input docs when the filter is
        unavailable, fails, or would remove everything.
    """
    if not docs:
        return docs
    if embeddings is not None:
        try:
            efilter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.20)
            result = list(efilter.compress_documents(docs, query))
        except Exception as exc:
            # Best-effort step: report the failure, then pass the docs through.
            print(f"[COMPRESS] EmbeddingsFilter failed ({type(exc).__name__}: {exc})")
        else:
            # Never hand an empty context downstream; keep the reranked docs.
            kept = result if result else docs
            print(f"[COMPRESS] EmbeddingsFilter {len(docs)} β†’ {len(kept)} (no LLM)")
            return kept
    print(f"[COMPRESS] pass-through {len(docs)} docs (no LLM)")
    return docs
# ── Full pipeline (imported by Phase 3) ───────────────────────────────────────
def advanced_retrieval_pipeline(
    query: str,
    bm25_retriever: BM25Retriever,
    dense_retriever,
    cross_encoder: CrossEncoder,
    llm,
    embeddings=None,
    top_k: int = 4,
) -> list[Document]:
    """
    Complete retrieval pipeline: decompose, retrieve, rerank, compress.

    BM25 searches by keyword on the original sub-query.
    Dense retriever searches by meaning on the HyDE-generated hypothetical answer.
    This separation is intentional: BM25 catches exact financial terms (e.g. 'EBITDA'),
    dense catches semantically related content the keywords might miss.

    Args:
        query: The user question.
        bm25_retriever: Keyword retriever, invoked with each sub-query.
        dense_retriever: Dense retriever, invoked with each HyDE text.
        cross_encoder: Reranker passed to rerank().
        llm: LLM used for decomposition and HyDE.
        embeddings: Optional embeddings passed to compress_docs().
        top_k: Maximum number of documents kept by the reranker.

    Returns:
        The final list of documents for the query.
    """
    # 1. Decompose complex query into sub-questions. An empty result would
    #    leave nothing to retrieve, so fall back to the original query.
    sub_queries = decompose_query(query, llm) or [query]

    # 2. For each sub-query: HyDE, then separate BM25 + dense retrieval, merged
    #    with duplicates (same page_content) dropped in first-seen order.
    seen: set[str] = set()
    all_docs: list[Document] = []
    for sub_q in sub_queries:
        hyde_text = generate_hyde(sub_q, llm)
        # BM25 uses the original sub-query (keyword precision)
        bm25_docs = bm25_retriever.invoke(sub_q)
        # Dense uses HyDE text (document-space semantic matching)
        dense_docs = dense_retriever.invoke(hyde_text)
        for doc in bm25_docs + dense_docs:
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                all_docs.append(doc)
    print(f"[PIPELINE] {len(all_docs)} unique candidates from {len(sub_queries)} sub-queries")

    # 3. Cross-encoder reranking against the ORIGINAL query
    top_docs = rerank(query, all_docs, cross_encoder, top_k=top_k)

    # 4. Contextual compression
    final_docs = compress_docs(query, top_docs, llm, embeddings)
    return final_docs
# ── Smoke test ────────────────────────────────────────────────────────────────
def run_pipeline_test(bm25_retriever, dense_retriever, cross_encoder, llm, embeddings):
    """Run two sample queries through the pipeline and print the results.

    For each query a banner is printed, followed by every returned document's
    page number (from ``metadata["page"]``, "?" if absent) and the first 400
    characters of its content.
    """
    banner = "=" * 65
    sample_queries = (
        "What was Infosys revenue and EBITDA margin in FY25?",
        "Compare the risk factors and compliance issues mentioned in the annual report",
    )
    for question in sample_queries:
        print(f"\n{banner}")
        print(f"QUERY: {question}")
        print(banner)
        results = advanced_retrieval_pipeline(
            question, bm25_retriever, dense_retriever, cross_encoder, llm, embeddings
        )
        for rank, hit in enumerate(results, start=1):
            page_label = hit.metadata.get("page", "?")
            preview = hit.page_content[:400]
            print(f"\n[Doc {rank} | Page {page_label}]\n{preview}")
        # Blank line after each query's results.
        print()
# ── Main ──────────────────────────────────────────────────────────────────────
# Script entry point: build every component, then run the smoke test.
if __name__ == "__main__":
    # Re-use Phase 1 setup: LLM, embeddings, and the tagged PDF pages.
    llm = build_llm(verify=True)
    embeddings = build_embeddings()
    pages = load_pdf(PDF_PATH)
    tagged_pages = tag_financial_entities(pages)
    # Full ingestion (skipped if the collection is already indexed)
    vectorstore = load_or_build_vectorstore(embeddings, tagged_pages)
    # Build retrievers: BM25 over the tagged pages, dense over the vectorstore.
    bm25_retriever = build_bm25_retriever(tagged_pages)
    dense_retriever = build_dense_retriever(vectorstore)
    cross_encoder = build_cross_encoder()
    # Run the sample queries through the full pipeline and print the results.
    run_pipeline_test(bm25_retriever, dense_retriever, cross_encoder, llm, embeddings)