Financial_bot / src /retriever.py
Pushkya's picture
Upload 30 files
8299003 verified
Raw
History Blame Contribute Delete
21 kB
"""
retriever.py
============
Phase 5 – Retrieval Chain
Retrieves the most relevant chunks from ChromaDB for a given query,
optionally re-ranks them with a cross-encoder, and assembles a context
block ready for LLM generation.
Two-stage retrieval (when rerank=True)
---------------------------------------
Stage 1 — Dense retrieval (all-MiniLM-L6-v2 + cosine ANN search)
Fast approximate nearest-neighbour search over the full vector store.
Fetches n_fetch candidates (default when reranking: max(n_results × 10, 50)).
Stage 2 — Cross-encoder reranking (ms-marco-MiniLM-L-6-v2)
Computes a relevance score for each (query, candidate_text) pair.
More accurate than cosine similarity because it reads both texts
together — capturing keyword overlap, negations, and fine-grained
semantic relationships.
Returns top n_results by reranker score.
Why two stages?
Cosine similarity on 384-dim vectors is fast (milliseconds on CPU)
but is an approximation. Cross-encoders are more accurate but need one
model forward pass per (query, chunk) pair, i.e. O(n) in collection size —
running them over the full collection would be prohibitively slow.
Fetching a few dozen candidates with dense retrieval and reranking only
those gives the best of both worlds.
Context assembly
-----------------
build_context() formats retrieved chunks into a numbered list with
source attribution lines — making it easy for the LLM to produce
accurate citations:
[1] Apple 10-K FY2024 (filed 2024-11-01) | § PART I > Item 1. Business
"Apple designs, manufactures and markets smartphones..."
[2] Apple 10-Q Q2 2025 | § Financial Statements [TABLE]
| Net sales | $95,358 | $90,753 |
Usage
------
from src.retriever import FinancialRetriever
retriever = FinancialRetriever(rerank=True)
# Simple query
chunks = retriever.retrieve("Apple revenue Q1 2025", n_results=5)
context = retriever.build_context(chunks)
# Filtered query — only Apple 10-K FY2024
chunks = retriever.retrieve(
"What are Apple's main risk factors?",
n_results = 5,
filters = {"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
)
"""
import os
import json
import hashlib
import logging
from pathlib import Path
from dotenv import load_dotenv
# Load variables from a .env file (if present) into the process environment
# before any of the settings below are read.
load_dotenv()

# Vector-store backend selection: ChromaDB on local disk by default; setting
# the QDRANT_URL environment variable switches the module to Qdrant.
_BACKEND = "qdrant" if os.environ.get("QDRANT_URL") else "chromadb"

# ── Logging ────────────────────────────────────────────────────────────────────
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)

# ── Paths & constants ──────────────────────────────────────────────────────────
# Project root is two levels above this file (…/src/retriever.py → …/).
BASE_DIR = Path(__file__).parent.parent
VECTORSTORE_DIR = BASE_DIR / "data" / "vectorstore"

# Name of the vector collection queried by FinancialRetriever.
COLLECTION_NAME = "financial_docs"

# Hugging Face model ids: bi-encoder for dense retrieval, cross-encoder for reranking.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
def _table_to_labelled_text(text: str) -> str:
"""
Convert SEC markdown table to explicit key:value lines for LLM context.
Small LLMs (≀3B) have recency bias β€” they read the last dollar amount in a
multi-year table row and associate it with the most recently mentioned year,
ignoring the column header. This function pairs every value with its column
header explicitly:
BEFORE (raw markdown):
| | 2024 | | 2023 | | 2022 | |
| Total net sales | $ | 391,035 | | $ | 383,285 | | $ | 394,328 |
AFTER (labelled text):
Total net sales: 2024=$391,035 2023=$383,285 2022=$394,328
Steps:
1. Strip markdown pipes and separator rows β†’ list of cell lists
2. Identify the header row (first non-empty row)
3. Remove empty / `$` currency cells, merging `$` with following number
4. Output "label: header1=val1 header2=val2 ..." per data row
"""
def _parse_cells(line: str) -> list[str]:
"""Split a markdown row into non-empty cells."""
cells = [c.strip() for c in line.strip().strip("|").split("|")]
# Merge $ with following number: ["$", "391,035"] β†’ ["$391,035"]
merged, i = [], 0
while i < len(cells):
if cells[i] == "$" and i + 1 < len(cells) and cells[i + 1] not in ("", "---"):
merged.append("$" + cells[i + 1])
i += 2
else:
merged.append(cells[i])
i += 1
return [c for c in merged if c and c != "---"]
raw_rows = []
for line in text.splitlines():
stripped = line.strip()
if not stripped.startswith("|"):
continue
# Skip pure separator rows
inner = stripped.strip("|").replace("-", "").replace(" ", "").replace("|", "")
if not inner:
continue
cells = _parse_cells(stripped)
if cells:
raw_rows.append(cells)
if not raw_rows:
return _strip_table_markdown(text) # fallback
# Identify the header row (usually row 0; skip if only one row)
if len(raw_rows) < 2:
return _strip_table_markdown(text)
headers = raw_rows[0] # e.g. ["2024", "2023", "2022"]
data = raw_rows[1:]
# If the header row has no meaningful year/label content, fall back
if not any(h for h in headers):
return _strip_table_markdown(text)
lines = []
for row in data:
if not row:
continue
# First cell is the row label; remaining cells are values
label = row[0] if row else ""
values = row[1:]
if not label and not values:
continue
# Pair values with headers
pairs = []
for j, val in enumerate(values):
if not val:
continue
hdr = headers[j] if j < len(headers) else ""
if hdr:
pairs.append(f"{hdr}={val}")
else:
pairs.append(val)
if pairs:
lines.append(f"{label}: {' '.join(pairs)}" if label else " ".join(pairs))
elif label:
lines.append(label)
return "\n".join(lines) if lines else _strip_table_markdown(text)
def _strip_table_markdown(text: str) -> str:
"""
Convert markdown table syntax to plain text for cross-encoder scoring.
The ms-marco cross-encoder was trained on natural-language passages.
Markdown pipe characters and separator rows (| --- |) cause low scores
even when the table contains the exact answer. Stripping them lets the
reranker see the raw labels and numbers and score them correctly.
Example:
"| Total net sales | 391,035 | 383,285 |"
β†’ "Total net sales 391,035 383,285"
"""
lines = []
for line in text.splitlines():
# Drop pure separator rows like | --- | --- |
stripped = line.strip()
if stripped.startswith("|") and all(
c in "|- " for c in stripped.replace("|", "")
):
continue
# Remove leading/trailing pipes and collapse whitespace
line = stripped.strip("|")
line = " ".join(line.split("|"))
line = " ".join(line.split())
if line:
lines.append(line)
return "\n".join(lines)
# ══════════════════════════════════════════════════════════════════════════════
# FINANCIAL RETRIEVER
# ══════════════════════════════════════════════════════════════════════════════
class FinancialRetriever:
    """
    Retrieves relevant chunks from the ``financial_docs`` vector collection.

    The backend is chosen at import time by the module-level ``_BACKEND``
    ("qdrant" when QDRANT_URL is set, otherwise "chromadb").

    Supports:
        - Dense similarity search    (all-MiniLM-L6-v2)
        - Metadata filtering         (ChromaDB backend only; see ``retrieve``)
        - Cross-encoder reranking    (ms-marco-MiniLM-L-6-v2)
        - Context assembly           (numbered, attributed, LLM-ready)
    """

    def __init__(
        self,
        vectorstore_dir: Path = VECTORSTORE_DIR,
        collection_name: str = COLLECTION_NAME,
        embedding_model: str = EMBEDDING_MODEL,
        rerank: bool = False,
        reranker_model: str = RERANKER_MODEL,
    ):
        """
        Connect to the vector store and optionally load the reranker.

        Args:
            vectorstore_dir : on-disk ChromaDB directory (unused for Qdrant)
            collection_name : name of the collection to query
            embedding_model : sentence-transformers model used to embed queries
            rerank          : when True, load a cross-encoder and rerank results
            reranker_model  : cross-encoder model id used when ``rerank`` is True
        """
        self.rerank = rerank
        self._collection_name = collection_name
        self._backend = _BACKEND

        if self._backend == "qdrant":
            self._init_qdrant(collection_name, embedding_model)
        else:
            self._init_chromadb(vectorstore_dir, collection_name, embedding_model)

        # ── Load cross-encoder reranker if requested ──────────────────────────
        self._reranker = None
        if rerank:
            # Imported lazily so the dependency is only needed when reranking.
            from sentence_transformers import CrossEncoder

            self._reranker = CrossEncoder(reranker_model)
            log.info(f"Reranker loaded: {reranker_model}")

    def _init_chromadb(self, vectorstore_dir, collection_name, embedding_model):
        """Open the persistent ChromaDB collection and attach the embedding function."""
        import chromadb
        from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

        client = chromadb.PersistentClient(path=str(vectorstore_dir))
        ef = SentenceTransformerEmbeddingFunction(model_name=embedding_model)
        self.collection = client.get_collection(
            name=collection_name,
            embedding_function=ef,
        )
        log.info(
            f"[ChromaDB] Connected to '{collection_name}' "
            f"({self.collection.count()} vectors)"
        )

    def _init_qdrant(self, collection_name, embedding_model):
        """Create the Qdrant client (QDRANT_URL / QDRANT_API_KEY) and load the query encoder."""
        from qdrant_client import QdrantClient
        from sentence_transformers import SentenceTransformer

        self._qdrant = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )
        # Qdrant does not embed text itself here, so queries are encoded locally.
        self._embed_model = SentenceTransformer(embedding_model)
        count = self._qdrant.count(collection_name).count
        log.info(
            f"[Qdrant] Connected to '{collection_name}' "
            f"({count} vectors)"
        )

    # ── Core retrieval ─────────────────────────────────────────────────────────
    def retrieve(
        self,
        query: str,
        n_results: int = 5,
        filters: dict = None,
        n_fetch: int = None,
    ) -> list[dict]:
        """
        Return the top-n most relevant chunks for a query.

        Args:
            query     : natural language question
            n_results : number of chunks to return
            filters   : ChromaDB where clause, e.g.
                        {"source": "sec_edgar"}
                        {"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]}
                        Only applied on the ChromaDB backend; on Qdrant the
                        filter is NOT applied and a warning is logged.
            n_fetch   : candidates fetched before reranking
                        (default: max(n_results × 10, 50) when reranking,
                        else n_results)

        Returns:
            list of dicts, each with:
                id, text, metadata, score (dense similarity, or reranker score).
            When reranking, each dict also carries ``dense_score`` and
            ``rerank_score``.
        """
        # Use ×10 multiplier (min 50) so financial table chunks that rank lower
        # in dense search (due to sparse markdown text) still reach the reranker.
        fetch = n_fetch or (max(n_results * 10, 50) if self.rerank else n_results)

        # ── Stage 1: dense retrieval ──────────────────────────────────────────
        if self._backend == "qdrant":
            if filters:
                # Make the dropped filter visible instead of silently returning
                # unfiltered results.
                log.warning(
                    "Metadata filters are not applied on the Qdrant backend; "
                    "returning unfiltered results"
                )
            candidates = self._query_qdrant(query, fetch)
        else:
            candidates = self._query_chromadb(query, fetch, filters)

        if not self.rerank or self._reranker is None:
            return candidates[:n_results]
        if not candidates:
            return []  # nothing to score; avoid calling the model on an empty batch

        # ── Stage 2: cross-encoder reranking ──────────────────────────────────
        # ms-marco-MiniLM was trained on text passages; markdown table syntax
        # (| --- | pipes) causes near-zero scores even for exact-match tables.
        # Strip markdown formatting for table chunks so the reranker sees
        # the raw numbers and labels, which it can match to the query.
        pairs = [
            (
                query,
                _strip_table_markdown(c["text"])
                if c["metadata"].get("chunk_type") == "table"
                else c["text"],
            )
            for c in candidates
        ]
        scores = self._reranker.predict(pairs)

        for c, s in zip(candidates, scores):
            c["dense_score"] = c["score"]  # keep original for comparison
            c["rerank_score"] = float(s)
            c["score"] = float(s)  # override with reranker score

        ranked = sorted(candidates, key=lambda x: x["rerank_score"], reverse=True)
        return ranked[:n_results]

    # ── Backend query helpers ──────────────────────────────────────────────────
    def _query_chromadb(self, query: str, fetch: int, filters: dict) -> list[dict]:
        """Run a dense query against ChromaDB and normalise the result rows."""
        kwargs = {
            "query_texts": [query],
            "n_results": fetch,
            "include": ["documents", "metadatas", "distances"],
        }
        if filters:
            kwargs["where"] = filters
        raw = self.collection.query(**kwargs)

        # Results are nested per query text; only one query is sent, so take [0].
        rows = zip(
            raw["ids"][0],
            raw["documents"][0],
            raw["metadatas"][0],
            raw["distances"][0],
        )
        return [
            {
                "id": chunk_id,
                "text": document,
                # Entries stored without metadata come back as None.
                "metadata": metadata or {},
                # Maps distance d to 1 - d/2; assumes the collection uses cosine
                # distance (range 0–2) — TODO confirm against the ingest config.
                "score": round(1 - distance / 2, 4),
            }
            for chunk_id, document, metadata, distance in rows
        ]

    def _query_qdrant(self, query: str, fetch: int) -> list[dict]:
        """Embed the query locally, search Qdrant, and normalise the result rows."""
        query_vector = self._embed_model.encode(query).tolist()
        results = self._qdrant.search(
            collection_name=self._collection_name,
            query_vector=query_vector,
            limit=fetch,
            with_payload=True,
        )
        hits = []
        for r in results:
            payload = r.payload or {}  # points stored without payload
            hits.append(
                {
                    "id": str(r.id),
                    "text": payload.get("text", ""),
                    # Everything in the payload except the chunk text is metadata.
                    "metadata": {k: v for k, v in payload.items() if k != "text"},
                    "score": round(r.score, 4),
                }
            )
        return hits

    # ── Context assembly ───────────────────────────────────────────────────────
    def build_context(
        self,
        chunks: list[dict],
        max_chars: int = 6000,
    ) -> str:
        """
        Format retrieved chunks into an LLM-ready context block.

        Each chunk is prefixed with a source attribution line so the LLM
        can produce accurate citations in its answer.

        Format:
            [1] Apple 10-K FY2024 (filed 2024-11-01) | § PART I > Item 1 [TABLE]
            <chunk text, tables converted to "label: header=value" lines>

        Args:
            chunks    : list returned by retrieve()
            max_chars : hard total-length limit (prevents exceeding LLM context).
                        Separators between chunks count towards the limit.
                        Chunks are added whole, in order; the first chunk that
                        would exceed the limit and all following ones are dropped.

        Returns:
            formatted multi-chunk context string (empty if no chunk fits)
        """
        separator = "\n\n---\n\n"
        parts = []
        total = 0

        for i, c in enumerate(chunks, 1):
            m = c.get("metadata") or {}

            # ── Build source attribution line ─────────────────────────────────
            source_parts = []
            if m.get("source") == "sec_edgar":
                dt = m.get("doc_type", "")
                fy = m.get("fiscal_year", "")
                fd = m.get("filing_date", "")
                # NOTE(review): the company name is hard-coded for SEC filings.
                label = f"Apple {dt}"
                if fy:
                    label += f" FY{fy}"
                if fd:
                    label += f" (filed {fd})"
                source_parts.append(label)
            else:
                doc_type = m.get("doc_type", m.get("source", ""))
                company = m.get("company", "")
                if company:
                    source_parts.append(f"{doc_type} — {company}")
                else:
                    source_parts.append(doc_type)

            heading = m.get("heading_path") or m.get("section_title") or ""
            if heading:
                source_parts.append(f"§ {heading[:80]}")

            pg = m.get("page_num")
            if pg:
                source_parts.append(f"p.{pg}")

            chunk_type = m.get("chunk_type", "text")
            suffix = " [TABLE]" if chunk_type == "table" else ""

            # Convert SEC table markdown to labelled key:value format for LLM.
            # Raw markdown has | $ | 391,035 | in separate columns with empty
            # spacers; small LLMs misread multi-year tables due to recency bias.
            # _table_to_labelled_text pairs each cell with its column header so
            # "Total net sales: 2024=$391,035 2023=$383,285 2022=$394,328".
            text = (
                _table_to_labelled_text(c["text"])
                if chunk_type == "table"
                else c["text"]
            )

            header = f"[{i}] " + " | ".join(source_parts) + suffix
            block = f"{header}\n{text}"

            # Every block after the first also costs one separator in the output.
            cost = len(block) + (len(separator) if parts else 0)
            if total + cost > max_chars:
                log.info(f" Context limit reached at chunk {i} — truncating")
                break

            parts.append(block)
            total += cost

        return separator.join(parts)

    # ── Convenience: LangChain-compatible retriever ────────────────────────────
    def as_langchain_retriever(
        self,
        n_results: int = 5,
        filters: dict = None,
    ):
        """
        Wrap this retriever as a LangChain BaseRetriever for use in LCEL chains.

        Returns a LangChain retriever that calls self.retrieve() internally
        (with the ``n_results`` and ``filters`` given here) and returns
        LangChain Document objects whose metadata includes the chunk score.

        Usage:
            lc_retriever = retriever.as_langchain_retriever(n_results=5)
            chain = lc_retriever | format_docs | llm
        """
        from langchain_core.retrievers import BaseRetriever
        from langchain_core.documents import Document
        from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun

        outer = self  # reference to FinancialRetriever, captured by the closure

        class _LCRetriever(BaseRetriever):
            def _get_relevant_documents(
                self,
                query: str,
                *,
                run_manager: CallbackManagerForRetrieverRun,
            ) -> list[Document]:
                chunks = outer.retrieve(
                    query=query,
                    n_results=n_results,
                    filters=filters,
                )
                return [
                    Document(
                        page_content=c["text"],
                        metadata={**c["metadata"], "score": c["score"]},
                    )
                    for c in chunks
                ]

        return _LCRetriever()

    # ── Collection info ────────────────────────────────────────────────────────
    def get_stats(self) -> dict:
        """
        Return a summary of the collection contents.

        Returns:
            {"total": 0} for an empty collection, otherwise a dict with
            ``total`` and per-value counts ``by_source``, ``by_doc_type`` and
            ``by_chunk_type`` (missing values are counted under "").
        """
        from collections import Counter

        if self._backend == "qdrant":
            count = self._qdrant.count(self._collection_name).count
            if count == 0:
                return {"total": 0}
            # Single scroll page sized to the whole collection; vectors skipped.
            records, _ = self._qdrant.scroll(
                collection_name=self._collection_name,
                limit=count,
                with_payload=True,
                with_vectors=False,
            )
            all_meta = [r.payload or {} for r in records]
        else:
            count = self.collection.count()
            if count == 0:
                return {"total": 0}
            fetched = self.collection.get(
                limit=count,
                include=["metadatas"],
            )["metadatas"]
            # Entries stored without metadata come back as None.
            all_meta = [m or {} for m in fetched]

        return {
            "total": count,
            "by_source": dict(Counter(m.get("source", "") for m in all_meta)),
            "by_doc_type": dict(Counter(m.get("doc_type", "") for m in all_meta)),
            "by_chunk_type": dict(Counter(m.get("chunk_type", "") for m in all_meta)),
        }