"""
retriever.py
============
Phase 5 – Retrieval Chain

Retrieves the most relevant chunks from the vector store (ChromaDB locally,
or Qdrant when QDRANT_URL is set) for a given query, optionally re-ranks them
with a cross-encoder, and assembles a context block ready for LLM generation.

Two-stage retrieval (when rerank=True)
--------------------------------------
Stage 1 — Dense retrieval (all-MiniLM-L6-v2 + ANN search)
    Fast approximate nearest-neighbour search over the full vector store.
    Fetches n_fetch candidates (default: max(n_results × 10, 50)).

Stage 2 — Cross-encoder reranking (ms-marco-MiniLM-L-6-v2)
    Computes a relevance score for each (query, candidate_text) pair.
    More accurate than cosine similarity because it reads both texts
    together — captures keyword overlap, negations, and fine-grained
    semantic relationships.
    Returns top n_results by reranker score.

Why two stages?
    Cosine similarity on 384-dim vectors is fast (milliseconds on CPU) but
    only approximates relevance. A cross-encoder is more accurate but needs
    a full transformer forward pass per (query, document) pair, so running
    it over the whole collection for every query would be prohibitively
    slow. Fetching a few dozen candidates with dense retrieval and reranking
    only those gives the best of both worlds.

Context assembly
----------------
build_context() formats retrieved chunks into a numbered list with source
attribution lines — making it easy for the LLM to produce accurate citations:

    [1] Apple 10-K FY2024 (filed 2024-11-01) | § PART I > Item 1. Business
    Apple designs, manufactures and markets smartphones...

    [2] Apple 10-Q FY2025 | § Financial Statements [TABLE]
    Net sales: 2025=$95,358 2024=$90,753

Table chunks are rewritten as "label: header=value" lines (see
_table_to_labelled_text) so every number is tied to its column header.

Usage
-----
    from src.retriever import FinancialRetriever

    retriever = FinancialRetriever(rerank=True)

    # Simple query
    chunks  = retriever.retrieve("Apple revenue Q1 2025", n_results=5)
    context = retriever.build_context(chunks)

    # Filtered query — only Apple 10-K FY2024
    chunks = retriever.retrieve(
        "What are Apple's main risk factors?",
        n_results = 5,
        filters   = {"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
    )
"""

import os
import json
import hashlib
import logging
from pathlib import Path

from dotenv import load_dotenv

load_dotenv()

# "chromadb" (local default) or "qdrant" (set QDRANT_URL to auto-switch).
# Evaluated once at import time, after .env has been loaded.
_BACKEND = "qdrant" if os.getenv("QDRANT_URL") else "chromadb"

# ── Logging ────────────────────────────────────────────────────────────────────
logging.basicConfig(
    level  = logging.INFO,
    format = "%(asctime)s %(levelname)-8s %(message)s",
)
log = logging.getLogger(__name__)

# ── Paths & constants ──────────────────────────────────────────────────────────
BASE_DIR        = Path(__file__).parent.parent
VECTORSTORE_DIR = BASE_DIR / "data" / "vectorstore"
COLLECTION_NAME = "financial_docs"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
RERANKER_MODEL  = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# Characters allowed in a markdown table delimiter row: pipes, dashes,
# alignment colons and whitespace (GFM tables extension).
_SEPARATOR_CHARS = frozenset("|-: \t")


def _is_separator_row(line: str) -> bool:
    """
    Return True for a markdown table delimiter row such as ``| --- | ---: |``.

    Alignment colons (``:---``, ``:---:``, ``---:``) are accepted, as emitted
    by e.g. pandas ``to_markdown``. A pipe and at least one dash are required,
    so a bare ``---`` horizontal rule in prose is NOT treated as a delimiter.
    """
    stripped = line.strip()
    return (
        "|" in stripped
        and "-" in stripped
        and set(stripped) <= _SEPARATOR_CHARS
    )


def _table_to_labelled_text(text: str) -> str:
    """
    Convert SEC markdown table to explicit key:value lines for LLM context.

    Small LLMs (≤3B) have recency bias — they read the last dollar amount in a
    multi-year table row and associate it with the most recently mentioned
    year, ignoring the column header. This function pairs every value with
    its column header explicitly:

    BEFORE (raw markdown):
        | | 2024 | | 2023 | | 2022 | |
        | Total net sales | $ | 391,035 | | $ | 383,285 | | $ | 394,328 |

    AFTER (labelled text):
        Total net sales: 2024=$391,035 2023=$383,285 2022=$394,328

    Steps:
      1. Drop delimiter rows (``| --- |``, ``|:---:|``) and split the rest
         into non-empty cells, merging a lone ``$`` with the following number.
      2. Treat the first remaining row as the header row. If its first raw
         cell is non-blank (e.g. "(In millions)") and it has as many cells as
         the widest data row, that cell titles the label column and is
         dropped so each year lines up with its own values.
      3. Output "label: header1=val1 header2=val2 ..." per data row; values
         beyond the last header are emitted without a "header=" prefix.

    Falls back to _strip_table_markdown() when fewer than two table rows are
    found.

    Note: blank cells are discarded before pairing, so a row with a missing
    middle value shifts its later values one column to the left.
    """

    def _parse_cells(line: str) -> list[str]:
        """Split a markdown row into non-empty cells, merging "$" with the next cell."""
        cells = [c.strip() for c in line.strip().strip("|").split("|")]
        # Merge $ with following number: ["$", "391,035"] → ["$391,035"]
        merged, i = [], 0
        while i < len(cells):
            if cells[i] == "$" and i + 1 < len(cells) and cells[i + 1] not in ("", "---"):
                merged.append("$" + cells[i + 1])
                i += 2
            else:
                merged.append(cells[i])
                i += 1
        return [c for c in merged if c and c != "---"]

    rows: list[list[str]] = []
    header_first_blank = True   # is the header row's first raw cell empty?
    for line in text.splitlines():
        stripped = line.strip()
        if not stripped.startswith("|") or _is_separator_row(stripped):
            continue
        cells = _parse_cells(stripped)
        if not cells:
            continue
        if not rows:
            header_first_blank = not stripped.strip("|").split("|", 1)[0].strip()
        rows.append(cells)

    # Need a header row plus at least one data row to pair values with columns.
    if len(rows) < 2:
        return _strip_table_markdown(text)

    headers, data = rows[0], rows[1:]   # headers e.g. ["2024", "2023", "2022"]

    # A non-blank corner cell ("(In millions)", "Fiscal year", ...) names the
    # label column, not a value column. Keeping it would pair every value with
    # the header one column to its left — the very mislabelling this function
    # exists to prevent.
    widest = max(len(row) for row in data)
    if not header_first_blank and len(headers) > 1 and len(headers) == widest:
        headers = headers[1:]

    lines = []
    for label, *values in data:
        # First cell is the row label; remaining cells are values.
        pairs = [
            f"{headers[j]}={val}" if j < len(headers) else val
            for j, val in enumerate(values)
        ]
        lines.append(f"{label}: {' '.join(pairs)}" if pairs else label)

    return "\n".join(lines) if lines else _strip_table_markdown(text)


def _strip_table_markdown(text: str) -> str:
    """
    Convert markdown table syntax to plain text for cross-encoder scoring.

    The ms-marco cross-encoder was trained on natural-language passages.
    Markdown pipe characters and separator rows (| --- |) cause low scores
    even when the table contains the exact answer. Stripping them lets the
    reranker see the raw labels and numbers and score them correctly.

    Delimiter rows (including alignment colons, e.g. ``|:---|---:|``) are
    dropped; every other line has its pipes replaced by spaces and its
    whitespace collapsed. Lines that end up empty are removed.

    Example:
        "| Total net sales | 391,035 | 383,285 |"
        → "Total net sales 391,035 383,285"
    """
    kept = []
    for raw in text.splitlines():
        if _is_separator_row(raw):
            continue
        plain = " ".join(raw.strip().strip("|").replace("|", " ").split())
        if plain:
            kept.append(plain)
    return "\n".join(kept)


def _chroma_where_to_qdrant(where: dict):
    """
    Translate a ChromaDB ``where`` clause into a qdrant-client ``Filter``.

    Lets the same ``filters`` argument work on both backends. Supported:
        {"field": value}                                  → equality
        {"field": {"$eq" | "$ne": value}}
        {"field": {"$in" | "$nin": [values]}}
        {"field": {"$gt" | "$gte" | "$lt" | "$lte": number}}
        {"$and": [clause, ...]}    {"$or": [clause, ...]}

    Field names map to top-level payload keys, matching how _query_qdrant()
    rebuilds metadata from the payload. Several keys in one dict are AND-ed.

    Args:
        where : ChromaDB-style where clause

    Returns:
        qdrant_client.models.Filter

    Raises:
        ValueError: for an operator outside the subset above.
    """
    from qdrant_client import models

    def field_filter(key, cond):
        if not isinstance(cond, dict):
            cond = {"$eq": cond}
        must, must_not = [], []
        for op, val in cond.items():
            if op == "$eq":
                must.append(models.FieldCondition(
                    key=key, match=models.MatchValue(value=val)))
            elif op == "$ne":
                must_not.append(models.FieldCondition(
                    key=key, match=models.MatchValue(value=val)))
            elif op == "$in":
                must.append(models.FieldCondition(
                    key=key, match=models.MatchAny(any=list(val))))
            elif op == "$nin":
                must_not.append(models.FieldCondition(
                    key=key, match=models.MatchAny(any=list(val))))
            elif op in ("$gt", "$gte", "$lt", "$lte"):
                must.append(models.FieldCondition(
                    key=key, range=models.Range(**{op[1:]: val})))
            else:
                raise ValueError(f"Unsupported filter operator for Qdrant: {op!r}")
        return models.Filter(must=must or None, must_not=must_not or None)

    def convert(clause):
        parts = []
        for key, val in clause.items():
            if key == "$and":
                parts.append(models.Filter(must=[convert(c) for c in val]))
            elif key == "$or":
                parts.append(models.Filter(should=[convert(c) for c in val]))
            elif key.startswith("$"):
                raise ValueError(f"Unsupported filter operator for Qdrant: {key!r}")
            else:
                parts.append(field_filter(key, val))
        if len(parts) == 1:
            return parts[0]
        return models.Filter(must=parts or None)

    return convert(where)


# ══════════════════════════════════════════════════════════════════════════════
# FINANCIAL RETRIEVER
# ══════════════════════════════════════════════════════════════════════════════

class FinancialRetriever:
    """
    Retrieves relevant chunks from the financial_docs collection, stored in
    ChromaDB (default) or Qdrant (when QDRANT_URL is set).

    Supports:
      - Dense similarity search (all-MiniLM-L6-v2)
      - Metadata filtering (source, doc_type, ticker, fiscal_year, ...) using
        ChromaDB where-clause syntax on both backends
      - Cross-encoder reranking (ms-marco-MiniLM-L-6-v2)
      - Context assembly (numbered, attributed, LLM-ready)
    """

    def __init__(
        self,
        vectorstore_dir : Path = VECTORSTORE_DIR,
        collection_name : str  = COLLECTION_NAME,
        embedding_model : str  = EMBEDDING_MODEL,
        rerank          : bool = False,
        reranker_model  : str  = RERANKER_MODEL,
    ):
        """
        Connect to the vector store and optionally load the reranker.

        Args:
            vectorstore_dir : ChromaDB persistence dir (ignored for Qdrant)
            collection_name : collection to query
            embedding_model : sentence-transformers model used for queries;
                              must match the model used at indexing time
            rerank          : load a cross-encoder and rerank in retrieve()
            reranker_model  : cross-encoder model name
        """
        self.rerank           = rerank
        self._collection_name = collection_name
        self._backend         = _BACKEND

        if self._backend == "qdrant":
            self._init_qdrant(collection_name, embedding_model)
        else:
            self._init_chromadb(vectorstore_dir, collection_name, embedding_model)

        # ── Load cross-encoder reranker if requested ──────────────────────────
        self._reranker = None
        if rerank:
            from sentence_transformers import CrossEncoder
            self._reranker = CrossEncoder(reranker_model)
            log.info(f"Reranker loaded: {reranker_model}")

    def _init_chromadb(self, vectorstore_dir, collection_name, embedding_model):
        """Open the persistent ChromaDB collection with a sentence-transformers embedder."""
        import chromadb
        from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

        client = chromadb.PersistentClient(path=str(vectorstore_dir))
        ef     = SentenceTransformerEmbeddingFunction(model_name=embedding_model)
        self.collection = client.get_collection(
            name               = collection_name,
            embedding_function = ef,
        )
        log.info(
            f"[ChromaDB] Connected to '{collection_name}' "
            f"({self.collection.count()} vectors)"
        )

    def _init_qdrant(self, collection_name, embedding_model):
        """Connect to Qdrant (QDRANT_URL / QDRANT_API_KEY) and load the query embedder."""
        from qdrant_client import QdrantClient
        from sentence_transformers import SentenceTransformer

        self._qdrant = QdrantClient(
            url     = os.getenv("QDRANT_URL"),
            api_key = os.getenv("QDRANT_API_KEY"),
        )
        self._embed_model = SentenceTransformer(embedding_model)
        count = self._qdrant.count(collection_name).count
        log.info(
            f"[Qdrant] Connected to '{collection_name}' "
            f"({count} vectors)"
        )

    # ── Core retrieval ─────────────────────────────────────────────────────────

    def retrieve(
        self,
        query     : str,
        n_results : int  = 5,
        filters   : dict = None,
        n_fetch   : int  = None,
    ) -> list[dict]:
        """
        Return the top-n most relevant chunks for a query.

        Args:
            query     : natural language question
            n_results : number of chunks to return
            filters   : ChromaDB where clause, e.g.
                          {"source": "sec_edgar"}
                          {"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]}
                        On the Qdrant backend the same syntax is translated by
                        _chroma_where_to_qdrant().
            n_fetch   : candidates fetched before reranking
                        (default: max(n_results × 10, 50) when reranking,
                        else n_results)

        Returns:
            list of dicts, each with: id, text, metadata, score.
            Without reranking, score is the backend's dense similarity. With
            reranking, score is the cross-encoder score, and the dense value
            is kept under "dense_score" alongside "rerank_score".
        """
        # Use ×10 multiplier (min 50) so financial table chunks that rank lower
        # in dense search (due to sparse markdown text) still reach the reranker.
        fetch = n_fetch or (max(n_results * 10, 50) if self.rerank else n_results)

        # ── Stage 1: dense retrieval ──────────────────────────────────────────
        if self._backend == "qdrant":
            candidates = self._query_qdrant(query, fetch, filters)
        else:
            candidates = self._query_chromadb(query, fetch, filters)

        if not self.rerank or self._reranker is None or not candidates:
            return candidates[:n_results]

        # ── Stage 2: cross-encoder reranking ──────────────────────────────────
        # ms-marco-MiniLM was trained on text passages; markdown table syntax
        # (| --- | pipes) causes near-zero scores even for exact-match tables.
        # Strip markdown formatting for table chunks so the reranker sees
        # the raw numbers and labels, which it can match to the query.
        pairs = [
            (
                query,
                _strip_table_markdown(c["text"])
                if c["metadata"].get("chunk_type") == "table"
                else c["text"],
            )
            for c in candidates
        ]
        scores = self._reranker.predict(pairs)

        for c, s in zip(candidates, scores):
            c["dense_score"]  = c["score"]   # keep original for comparison
            c["rerank_score"] = float(s)
            c["score"]        = float(s)     # override with reranker score

        ranked = sorted(candidates, key=lambda x: x["rerank_score"], reverse=True)
        return ranked[:n_results]

    # ── Backend query helpers ──────────────────────────────────────────────────

    def _query_chromadb(self, query: str, fetch: int, filters: dict) -> list[dict]:
        """
        Dense ChromaDB query returning normalized candidate dicts.

        score = 1 - distance / 2. Its exact relation to cosine similarity
        depends on the collection's distance space (NOTE(review): confirm the
        space used at indexing time). Missing documents/metadata are
        normalized to "" / {} so callers can use .get() safely.
        """
        kwargs = {
            "query_texts" : [query],
            "n_results"   : fetch,
            "include"     : ["documents", "metadatas", "distances"],
        }
        if filters:
            kwargs["where"] = filters

        raw = self.collection.query(**kwargs)
        return [
            {
                "id"       : doc_id,
                "text"     : doc or "",
                "metadata" : meta or {},
                "score"    : round(1 - dist / 2, 4),
            }
            for doc_id, doc, meta, dist in zip(
                raw["ids"][0],
                raw["documents"][0],
                raw["metadatas"][0],
                raw["distances"][0],
            )
        ]

    def _query_qdrant(self, query: str, fetch: int, filters: dict = None) -> list[dict]:
        """
        Dense Qdrant query returning normalized candidate dicts.

        The payload's "text" key becomes the chunk text; every other payload
        key becomes metadata. ``filters`` uses ChromaDB where syntax and is
        translated to a Qdrant Filter; score is Qdrant's raw similarity.
        """
        query_vector = self._embed_model.encode(query).tolist()
        results = self._qdrant.search(
            collection_name = self._collection_name,
            query_vector    = query_vector,
            query_filter    = _chroma_where_to_qdrant(filters) if filters else None,
            limit           = fetch,
            with_payload    = True,
        )
        candidates = []
        for r in results:
            payload = r.payload or {}
            candidates.append({
                "id"       : str(r.id),
                "text"     : payload.get("text", ""),
                "metadata" : {k: v for k, v in payload.items() if k != "text"},
                "score"    : round(r.score, 4),
            })
        return candidates

    # ── Context assembly ───────────────────────────────────────────────────────

    def build_context(
        self,
        chunks    : list[dict],
        max_chars : int = 6000,
    ) -> str:
        """
        Format retrieved chunks into an LLM-ready context block.

        Each chunk is prefixed with a source attribution line so the LLM can
        produce accurate citations in its answer. Blocks are joined with a
        "---" separator line. Format:

            [1] Apple 10-K FY2024 (filed 2024-11-01) | § PART I > Item 1 [TABLE]
            Total net sales: 2024=$391,035 2023=$383,285

        Args:
            chunks    : list returned by retrieve()
            max_chars : hard limit on the length of the returned string,
                        separators included. Chunks are added in order and
                        assembly stops at the first chunk that would exceed
                        the limit (so a too-large first chunk yields "").

        Returns:
            formatted multi-chunk context string
        """
        separator = "\n\n---\n\n"
        parts: list[str] = []
        total = 0   # == len(separator.join(parts)) at every point

        for i, c in enumerate(chunks, 1):
            m = c.get("metadata") or {}

            # ── Build source attribution line ─────────────────────────────────
            source_parts = []
            if m.get("source") == "sec_edgar":
                dt = m.get("doc_type", "")
                fy = m.get("fiscal_year", "")
                fd = m.get("filing_date", "")
                # NOTE(review): issuer name is hard-coded for SEC filings.
                label = f"Apple {dt}"
                if fy:
                    label += f" FY{fy}"
                if fd:
                    label += f" (filed {fd})"
                source_parts.append(label)
            else:
                doc_type = m.get("doc_type", m.get("source", ""))
                company  = m.get("company", "")
                if company:
                    source_parts.append(f"{doc_type} — {company}")
                else:
                    source_parts.append(doc_type)

            heading = m.get("heading_path") or m.get("section_title") or ""
            if heading:
                source_parts.append(f"§ {heading[:80]}")

            pg = m.get("page_num")
            if pg:
                source_parts.append(f"p.{pg}")

            chunk_type = m.get("chunk_type", "text")
            suffix = " [TABLE]" if chunk_type == "table" else ""

            # Convert SEC table markdown to labelled key:value format for LLM.
            # Raw markdown has | $ | 391,035 | in separate columns with empty
            # spacers; small LLMs misread multi-year tables due to recency bias.
            # _table_to_labelled_text pairs each cell with its column header so
            # "Total net sales: 2024=$391,035 2023=$383,285 2022=$394,328".
            text = (_table_to_labelled_text(c["text"])
                    if chunk_type == "table" else c["text"])

            header = f"[{i}] " + " | ".join(source_parts) + suffix
            block  = f"{header}\n{text}"

            # Count the separator that joining will insert before this block,
            # so the returned string really stays within max_chars.
            cost = len(block) + (len(separator) if parts else 0)
            if total + cost > max_chars:
                log.info(f" Context limit reached at chunk {i} — truncating")
                break

            parts.append(block)
            total += cost

        return separator.join(parts)

    # ── Convenience: LangChain-compatible retriever ────────────────────────────

    def as_langchain_retriever(
        self,
        n_results : int  = 5,
        filters   : dict = None,
    ):
        """
        Wrap this retriever as a LangChain BaseRetriever for use in LCEL chains.

        Returns a LangChain retriever that calls self.retrieve() internally
        and returns LangChain Document objects; each Document's metadata is
        the chunk metadata plus its "score".

        Usage:
            lc_retriever = retriever.as_langchain_retriever(n_results=5)
            chain = lc_retriever | format_docs | llm
        """
        from langchain_core.retrievers import BaseRetriever
        from langchain_core.documents import Document
        from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun

        outer = self  # reference to FinancialRetriever

        class _LCRetriever(BaseRetriever):
            def _get_relevant_documents(
                self,
                query      : str,
                *,
                run_manager: CallbackManagerForRetrieverRun,
            ) -> list[Document]:
                chunks = outer.retrieve(
                    query     = query,
                    n_results = n_results,
                    filters   = filters,
                )
                return [
                    Document(
                        page_content = c["text"],
                        metadata     = {**c["metadata"], "score": c["score"]},
                    )
                    for c in chunks
                ]

        return _LCRetriever()

    # ── Collection info ────────────────────────────────────────────────────────

    def get_stats(self) -> dict:
        """
        Return a summary of the collection contents.

        Loads the metadata of every record in a single call and counts values
        of "source", "doc_type" and "chunk_type" (missing keys count as "").
        Returns {"total": 0} for an empty collection.
        """
        from collections import Counter

        if self._backend == "qdrant":
            count = self._qdrant.count(self._collection_name).count
            if count == 0:
                return {"total": 0}
            records, _ = self._qdrant.scroll(
                collection_name = self._collection_name,
                limit           = count,
                with_payload    = True,
                with_vectors    = False,
            )
            all_meta = [r.payload or {} for r in records]
        else:
            count = self.collection.count()
            if count == 0:
                return {"total": 0}
            fetched = self.collection.get(
                limit   = count,
                include = ["metadatas"],
            )["metadatas"]
            # Records stored without metadata come back as None.
            all_meta = [m or {} for m in fetched]

        return {
            "total"        : count,
            "by_source"    : dict(Counter(m.get("source", "") for m in all_meta)),
            "by_doc_type"  : dict(Counter(m.get("doc_type", "") for m in all_meta)),
            "by_chunk_type": dict(Counter(m.get("chunk_type", "") for m in all_meta)),
        }