Spaces:
Running
Running
"""
retriever.py
============
Phase 5 — Retrieval Chain

Retrieves the most relevant chunks from the vector store for a given query,
optionally re-ranks them with a cross-encoder, and assembles a context
block ready for LLM generation. The store is a local ChromaDB collection
by default; setting QDRANT_URL switches to Qdrant (see _BACKEND below).

Two-stage retrieval (when rerank=True)
--------------------------------------
Stage 1 — Dense retrieval (all-MiniLM-L6-v2 + cosine ANN search)
    Fast approximate nearest-neighbour search over the full vector store.
    Fetches n_fetch candidates (default: max(n_results × 10, 50)).

Stage 2 — Cross-encoder reranking (ms-marco-MiniLM-L-6-v2)
    Computes a relevance score for each (query, candidate_text) pair.
    More accurate than cosine similarity because it reads both texts
    together — captures keyword overlap, negations, and fine-grained
    semantic relationships.
    Returns top n_results by reranker score.

Why two stages?
    Cosine similarity on 384-dim vectors is fast (milliseconds on CPU)
    but is an approximation. Cross-encoders are more accurate but need one
    model forward pass per (query, passage) pair — running them over the
    full collection would be prohibitively slow. Fetching ~50 candidates
    with dense retrieval and reranking only those gives the best of both
    worlds.

Context assembly
----------------
build_context() formats retrieved chunks into a numbered list with
source attribution lines — making it easy for the LLM to produce
accurate citations. Table chunks are rewritten as "label: header=value"
lines so each number stays attached to its column header:

    [1] Apple 10-K FY2024 (filed 2024-11-01) | § PART I > Item 1. Business
    "Apple designs, manufactures and markets smartphones..."

    [2] Apple 10-Q FY2025 | § Financial Statements [TABLE]
    Net sales: 2025=$95,358 2024=$90,753

Usage
-----
    from src.retriever import FinancialRetriever

    retriever = FinancialRetriever(rerank=True)

    # Simple query
    chunks  = retriever.retrieve("Apple revenue Q1 2025", n_results=5)
    context = retriever.build_context(chunks)

    # Filtered query — only Apple 10-K FY2024
    chunks = retriever.retrieve(
        "What are Apple's main risk factors?",
        n_results = 5,
        filters   = {"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
    )
"""
import os
import json
import hashlib
import logging
from pathlib import Path

from dotenv import load_dotenv

# Pull QDRANT_URL / QDRANT_API_KEY etc. from a local .env file, if present.
load_dotenv()

# "chromadb" (local default) or "qdrant" (set QDRANT_URL to auto-switch)
_BACKEND = "qdrant" if os.getenv("QDRANT_URL") else "chromadb"

# ── Logging ───────────────────────────────────────────────────────────────────
logging.basicConfig(
    level  = logging.INFO,
    format = "%(asctime)s %(levelname)-8s %(message)s",
)
log = logging.getLogger(__name__)

# ── Paths & constants ─────────────────────────────────────────────────────────
# Project root: this file is expected to live one directory below it (src/).
BASE_DIR        = Path(__file__).parent.parent
VECTORSTORE_DIR = BASE_DIR / "data" / "vectorstore"
COLLECTION_NAME = "financial_docs"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
RERANKER_MODEL  = "cross-encoder/ms-marco-MiniLM-L-6-v2"
| def _table_to_labelled_text(text: str) -> str: | |
| """ | |
| Convert SEC markdown table to explicit key:value lines for LLM context. | |
| Small LLMs (β€3B) have recency bias β they read the last dollar amount in a | |
| multi-year table row and associate it with the most recently mentioned year, | |
| ignoring the column header. This function pairs every value with its column | |
| header explicitly: | |
| BEFORE (raw markdown): | |
| | | 2024 | | 2023 | | 2022 | | | |
| | Total net sales | $ | 391,035 | | $ | 383,285 | | $ | 394,328 | | |
| AFTER (labelled text): | |
| Total net sales: 2024=$391,035 2023=$383,285 2022=$394,328 | |
| Steps: | |
| 1. Strip markdown pipes and separator rows β list of cell lists | |
| 2. Identify the header row (first non-empty row) | |
| 3. Remove empty / `$` currency cells, merging `$` with following number | |
| 4. Output "label: header1=val1 header2=val2 ..." per data row | |
| """ | |
| def _parse_cells(line: str) -> list[str]: | |
| """Split a markdown row into non-empty cells.""" | |
| cells = [c.strip() for c in line.strip().strip("|").split("|")] | |
| # Merge $ with following number: ["$", "391,035"] β ["$391,035"] | |
| merged, i = [], 0 | |
| while i < len(cells): | |
| if cells[i] == "$" and i + 1 < len(cells) and cells[i + 1] not in ("", "---"): | |
| merged.append("$" + cells[i + 1]) | |
| i += 2 | |
| else: | |
| merged.append(cells[i]) | |
| i += 1 | |
| return [c for c in merged if c and c != "---"] | |
| raw_rows = [] | |
| for line in text.splitlines(): | |
| stripped = line.strip() | |
| if not stripped.startswith("|"): | |
| continue | |
| # Skip pure separator rows | |
| inner = stripped.strip("|").replace("-", "").replace(" ", "").replace("|", "") | |
| if not inner: | |
| continue | |
| cells = _parse_cells(stripped) | |
| if cells: | |
| raw_rows.append(cells) | |
| if not raw_rows: | |
| return _strip_table_markdown(text) # fallback | |
| # Identify the header row (usually row 0; skip if only one row) | |
| if len(raw_rows) < 2: | |
| return _strip_table_markdown(text) | |
| headers = raw_rows[0] # e.g. ["2024", "2023", "2022"] | |
| data = raw_rows[1:] | |
| # If the header row has no meaningful year/label content, fall back | |
| if not any(h for h in headers): | |
| return _strip_table_markdown(text) | |
| lines = [] | |
| for row in data: | |
| if not row: | |
| continue | |
| # First cell is the row label; remaining cells are values | |
| label = row[0] if row else "" | |
| values = row[1:] | |
| if not label and not values: | |
| continue | |
| # Pair values with headers | |
| pairs = [] | |
| for j, val in enumerate(values): | |
| if not val: | |
| continue | |
| hdr = headers[j] if j < len(headers) else "" | |
| if hdr: | |
| pairs.append(f"{hdr}={val}") | |
| else: | |
| pairs.append(val) | |
| if pairs: | |
| lines.append(f"{label}: {' '.join(pairs)}" if label else " ".join(pairs)) | |
| elif label: | |
| lines.append(label) | |
| return "\n".join(lines) if lines else _strip_table_markdown(text) | |
| def _strip_table_markdown(text: str) -> str: | |
| """ | |
| Convert markdown table syntax to plain text for cross-encoder scoring. | |
| The ms-marco cross-encoder was trained on natural-language passages. | |
| Markdown pipe characters and separator rows (| --- |) cause low scores | |
| even when the table contains the exact answer. Stripping them lets the | |
| reranker see the raw labels and numbers and score them correctly. | |
| Example: | |
| "| Total net sales | 391,035 | 383,285 |" | |
| β "Total net sales 391,035 383,285" | |
| """ | |
| lines = [] | |
| for line in text.splitlines(): | |
| # Drop pure separator rows like | --- | --- | | |
| stripped = line.strip() | |
| if stripped.startswith("|") and all( | |
| c in "|- " for c in stripped.replace("|", "") | |
| ): | |
| continue | |
| # Remove leading/trailing pipes and collapse whitespace | |
| line = stripped.strip("|") | |
| line = " ".join(line.split("|")) | |
| line = " ".join(line.split()) | |
| if line: | |
| lines.append(line) | |
| return "\n".join(lines) | |
# ──────────────────────────────────────────────────────────────────────────────
# FINANCIAL RETRIEVER
# ──────────────────────────────────────────────────────────────────────────────
class FinancialRetriever:
    """
    Retrieve relevant chunks from the financial_docs vector collection.

    The backend is fixed at import time by ``_BACKEND``: Qdrant when the
    ``QDRANT_URL`` environment variable is set, otherwise a local ChromaDB
    persistent store.

    Supports:
      - Dense similarity search (all-MiniLM-L6-v2 by default)
      - Metadata filtering (source, doc_type, ticker, fiscal_year, ...) with
        ChromaDB ``where`` syntax; on Qdrant the clause is translated into a
        payload filter (equality, $ne, $in, $nin, $gt/$gte/$lt/$lte, $and, $or)
      - Cross-encoder reranking (ms-marco-MiniLM-L-6-v2 by default)
      - Context assembly (numbered, attributed, LLM-ready)
    """

    # ChromaDB range operators → keyword names of qdrant_client ``models.Range``.
    _QDRANT_RANGE_OPS = {"$gt": "gt", "$gte": "gte", "$lt": "lt", "$lte": "lte"}

    def __init__(
        self,
        vectorstore_dir : Path = VECTORSTORE_DIR,
        collection_name : str  = COLLECTION_NAME,
        embedding_model : str  = EMBEDDING_MODEL,
        rerank          : bool = False,
        reranker_model  : str  = RERANKER_MODEL,
    ):
        """
        Connect to the vector store and optionally load the reranker.

        Args:
            vectorstore_dir : ChromaDB persistence directory (unused on Qdrant)
            collection_name : collection to query
            embedding_model : sentence-transformers model used to embed queries
            rerank          : if True, load a cross-encoder and rerank in retrieve()
            reranker_model  : cross-encoder model name (loaded only if rerank=True)
        """
        self.rerank = rerank
        self._collection_name = collection_name
        self._backend = _BACKEND

        if self._backend == "qdrant":
            self._init_qdrant(collection_name, embedding_model)
        else:
            self._init_chromadb(vectorstore_dir, collection_name, embedding_model)

        # ── Load cross-encoder reranker if requested ──────────────────────────
        self._reranker = None
        if rerank:
            from sentence_transformers import CrossEncoder
            self._reranker = CrossEncoder(reranker_model)
            log.info(f"Reranker loaded: {reranker_model}")

    def _init_chromadb(self, vectorstore_dir, collection_name, embedding_model):
        """Open the persistent ChromaDB collection with a sentence-transformers embedder."""
        import chromadb
        from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction

        client = chromadb.PersistentClient(path=str(vectorstore_dir))
        ef = SentenceTransformerEmbeddingFunction(model_name=embedding_model)
        self.collection = client.get_collection(
            name               = collection_name,
            embedding_function = ef,
        )
        log.info(
            f"[ChromaDB] Connected to '{collection_name}' "
            f"({self.collection.count()} vectors)"
        )

    def _init_qdrant(self, collection_name, embedding_model):
        """Connect to Qdrant (QDRANT_URL / QDRANT_API_KEY) and load the query embedder."""
        from qdrant_client import QdrantClient
        from sentence_transformers import SentenceTransformer

        self._qdrant = QdrantClient(
            url     = os.getenv("QDRANT_URL"),
            api_key = os.getenv("QDRANT_API_KEY"),
        )
        self._embed_model = SentenceTransformer(embedding_model)
        count = self._qdrant.count(collection_name).count
        log.info(
            f"[Qdrant] Connected to '{collection_name}' "
            f"({count} vectors)"
        )

    # ── Core retrieval ────────────────────────────────────────────────────────

    def retrieve(
        self,
        query     : str,
        n_results : int  = 5,
        filters   : dict = None,
        n_fetch   : int  = None,
    ) -> list[dict]:
        """
        Return the top-n most relevant chunks for a query.

        Args:
            query     : natural language question
            n_results : number of chunks to return
            filters   : ChromaDB-style where clause, e.g.
                          {"source": "sec_edgar"}
                          {"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]}
                        Honoured on both backends (translated for Qdrant).
            n_fetch   : candidates fetched before reranking
                        (default: max(n_results × 10, 50) when reranking,
                        else n_results)

        Returns:
            list of dicts, each with id, text, metadata and score (dense
            similarity; or the reranker score when reranking, in which case
            dense_score and rerank_score are also set).

        Raises:
            ValueError: on Qdrant, if ``filters`` uses an unsupported operator.
        """
        # Use ×10 multiplier (min 50) so financial table chunks that rank lower
        # in dense search (due to sparse markdown text) still reach the reranker.
        fetch = n_fetch or (max(n_results * 10, 50) if self.rerank else n_results)

        # ── Stage 1: dense retrieval ──────────────────────────────────────────
        if self._backend == "qdrant":
            candidates = self._query_qdrant(query, fetch, filters)
        else:
            candidates = self._query_chromadb(query, fetch, filters)

        # Reranking disabled, or nothing to rerank: dense order is final.
        if not self.rerank or self._reranker is None or not candidates:
            return candidates[:n_results]

        # ── Stage 2: cross-encoder reranking ──────────────────────────────────
        # ms-marco-MiniLM was trained on text passages; markdown table syntax
        # (| --- | pipes) causes near-zero scores even for exact-match tables.
        # Strip markdown formatting for table chunks so the reranker sees
        # the raw numbers and labels, which it can match to the query.
        pairs = [
            (
                query,
                _strip_table_markdown(c["text"])
                if c["metadata"].get("chunk_type") == "table" else c["text"],
            )
            for c in candidates
        ]
        scores = self._reranker.predict(pairs)

        for c, s in zip(candidates, scores):
            c["dense_score"]  = c["score"]   # keep original for comparison
            c["rerank_score"] = float(s)
            c["score"]        = float(s)     # override with reranker score

        ranked = sorted(candidates, key=lambda x: x["rerank_score"], reverse=True)
        return ranked[:n_results]

    # ── Backend query helpers ─────────────────────────────────────────────────

    def _query_chromadb(self, query: str, fetch: int, filters: dict) -> list[dict]:
        """Run a ChromaDB similarity query and normalise hits to result dicts."""
        kwargs = {
            "query_texts" : [query],
            "n_results"   : fetch,
            "include"     : ["documents", "metadatas", "distances"],
        }
        if filters:
            kwargs["where"] = filters

        raw = self.collection.query(**kwargs)
        rows = zip(
            raw["ids"][0],
            raw["documents"][0],
            raw["metadatas"][0],
            raw["distances"][0],
        )
        return [
            {
                "id"       : doc_id,
                # Records stored without a document/metadata come back as None.
                "text"     : doc or "",
                "metadata" : meta or {},
                # NOTE(review): 1 - d/2 maps a cosine distance in [0, 2] onto
                # [0, 1]; assumes the collection uses cosine space — confirm.
                "score"    : round(1 - dist / 2, 4),
            }
            for doc_id, doc, meta, dist in rows
        ]

    def _query_qdrant(self, query: str, fetch: int, filters: dict = None) -> list[dict]:
        """Run a Qdrant vector search (with optional filter) and normalise hits."""
        query_vector = self._embed_model.encode(query).tolist()
        results = self._qdrant.search(
            collection_name = self._collection_name,
            query_vector    = query_vector,
            query_filter    = self._to_qdrant_filter(filters) if filters else None,
            limit           = fetch,
            with_payload    = True,
        )
        hits = []
        for r in results:
            payload = r.payload or {}
            hits.append({
                "id"       : str(r.id),
                "text"     : payload.get("text", ""),
                # Everything except the chunk text is treated as metadata.
                "metadata" : {k: v for k, v in payload.items() if k != "text"},
                "score"    : round(r.score, 4),
            })
        return hits

    def _to_qdrant_filter(self, where: dict):
        """
        Translate a ChromaDB-style where clause into a qdrant_client ``Filter``.

        Supported forms: {"field": value} (equality), {"field": {op: operand}}
        with op in $eq, $ne, $in, $nin, $gt, $gte, $lt, $lte, and "$and" /
        "$or" lists of such clauses. Fields are matched as top-level payload
        keys, which is where _query_qdrant reads metadata from.

        Raises:
            ValueError: for any other "$" operator.
        """
        from qdrant_client import models

        def _field_conditions(key, spec, must, must_not):
            # A bare value is shorthand for {"$eq": value}, as in ChromaDB.
            if not isinstance(spec, dict):
                spec = {"$eq": spec}
            for op, operand in spec.items():
                if op in ("$eq", "$ne"):
                    cond = models.FieldCondition(
                        key=key, match=models.MatchValue(value=operand)
                    )
                    (must if op == "$eq" else must_not).append(cond)
                elif op in ("$in", "$nin"):
                    cond = models.FieldCondition(
                        key=key, match=models.MatchAny(any=list(operand))
                    )
                    (must if op == "$in" else must_not).append(cond)
                elif op in self._QDRANT_RANGE_OPS:
                    rng = models.Range(**{self._QDRANT_RANGE_OPS[op]: operand})
                    must.append(models.FieldCondition(key=key, range=rng))
                else:
                    raise ValueError(f"Unsupported filter operator for Qdrant: {op!r}")

        def _clause(clause):
            must, must_not = [], []
            for key, value in clause.items():
                if key == "$and":
                    must.extend(_clause(sub) for sub in value)
                elif key == "$or":
                    must.append(models.Filter(should=[_clause(sub) for sub in value]))
                elif key.startswith("$"):
                    raise ValueError(f"Unsupported filter operator for Qdrant: {key!r}")
                else:
                    _field_conditions(key, value, must, must_not)
            return models.Filter(must=must or None, must_not=must_not or None)

        return _clause(where)

    # ── Context assembly ──────────────────────────────────────────────────────

    def build_context(
        self,
        chunks    : list[dict],
        max_chars : int = 6000,
    ) -> str:
        """
        Format retrieved chunks into an LLM-ready context block.

        Each chunk is prefixed with a source attribution line so the LLM
        can produce accurate citations in its answer. Chunks are joined with
        a "---" separator line.

        Format:
            [1] Apple 10-K FY2024 (filed 2024-11-01) | § PART I > Item 1 [TABLE]
            "Apple designs, manufactures and markets smartphones..."

        Args:
            chunks    : list returned by retrieve()
            max_chars : hard limit on the length of the returned string,
                        separators included (prevents exceeding LLM context).
                        Chunks that would overflow it are dropped; if even the
                        first chunk overflows, it is cut to max_chars instead.

        Returns:
            formatted multi-chunk context string
        """
        separator = "\n\n---\n\n"
        parts: list[str] = []
        total = 0

        for i, c in enumerate(chunks, 1):
            m = c.get("metadata") or {}

            # ── Build source attribution line ─────────────────────────────────
            source_parts = []
            if m.get("source") == "sec_edgar":
                dt = m.get("doc_type", "")
                fy = m.get("fiscal_year", "")
                fd = m.get("filing_date", "")
                label = f"Apple {dt}"
                if fy:
                    label += f" FY{fy}"
                if fd:
                    label += f" (filed {fd})"
                source_parts.append(label)
            else:
                doc_type = m.get("doc_type", m.get("source", ""))
                company  = m.get("company", "")
                source_parts.append(f"{doc_type} — {company}" if company else doc_type)

            heading = m.get("heading_path") or m.get("section_title") or ""
            if heading:
                source_parts.append(f"§ {heading[:80]}")

            pg = m.get("page_num")
            if pg:
                source_parts.append(f"p.{pg}")

            chunk_type = m.get("chunk_type", "text")
            suffix = " [TABLE]" if chunk_type == "table" else ""

            # Convert SEC table markdown to labelled key:value format for LLM.
            # Raw markdown has | $ | 391,035 | in separate columns with empty
            # spacers; small LLMs misread multi-year tables due to recency bias.
            # _table_to_labelled_text pairs each cell with its column header so
            # "Total net sales: 2024=$391,035 2023=$383,285 2022=$394,328".
            text = (_table_to_labelled_text(c["text"])
                    if chunk_type == "table" else c["text"])

            header = f"[{i}] " + " | ".join(source_parts) + suffix
            block = f"{header}\n{text}"

            # The separator joining this block to the previous one counts too.
            cost = len(block) + (len(separator) if parts else 0)
            if total + cost > max_chars:
                if not parts and max_chars > 0:
                    # Even the top-ranked chunk is too long: keep its head
                    # rather than hand the LLM an empty context.
                    parts.append(block[:max_chars])
                log.info(f" Context limit reached at chunk {i} — truncating")
                break

            parts.append(block)
            total += cost

        return separator.join(parts)

    # ── Convenience: LangChain-compatible retriever ───────────────────────────

    def as_langchain_retriever(
        self,
        n_results : int  = 5,
        filters   : dict = None,
    ):
        """
        Wrap this retriever as a LangChain BaseRetriever for use in LCEL chains.

        Returns a LangChain retriever that calls self.retrieve() internally
        and returns LangChain Document objects (chunk metadata plus "score").

        Usage:
            lc_retriever = retriever.as_langchain_retriever(n_results=5)
            chain = lc_retriever | format_docs | llm
        """
        from langchain_core.retrievers import BaseRetriever
        from langchain_core.documents import Document
        from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun

        outer = self  # reference to FinancialRetriever

        class _LCRetriever(BaseRetriever):
            def _get_relevant_documents(
                self,
                query : str,
                *,
                run_manager: CallbackManagerForRetrieverRun,
            ) -> list[Document]:
                chunks = outer.retrieve(
                    query     = query,
                    n_results = n_results,
                    filters   = filters,
                )
                return [
                    Document(
                        page_content = c["text"],
                        metadata     = {**c["metadata"], "score": c["score"]},
                    )
                    for c in chunks
                ]

        return _LCRetriever()

    # ── Collection info ───────────────────────────────────────────────────────

    def get_stats(self) -> dict:
        """
        Return a summary of the collection contents.

        Returns:
            {"total": 0} for an empty collection, otherwise total plus counts
            grouped by source, doc_type and chunk_type ("" when missing).
        """
        from collections import Counter

        if self._backend == "qdrant":
            count = self._qdrant.count(self._collection_name).count
            if count == 0:
                return {"total": 0}
            # Page through the collection, fetching only the payload keys we
            # count (not the chunk text) to keep each response small.
            all_meta = []
            offset = None
            while True:
                records, offset = self._qdrant.scroll(
                    collection_name = self._collection_name,
                    limit           = 1000,
                    offset          = offset,
                    with_payload    = ["source", "doc_type", "chunk_type"],
                    with_vectors    = False,
                )
                all_meta.extend(r.payload or {} for r in records)
                if offset is None or not records:
                    break
        else:
            count = self.collection.count()
            if count == 0:
                return {"total": 0}
            metadatas = self.collection.get(
                limit   = count,
                include = ["metadatas"],
            )["metadatas"]
            # Records stored without metadata come back as None.
            all_meta = [m or {} for m in metadatas]

        return {
            "total"        : count,
            "by_source"    : dict(Counter(m.get("source", "") for m in all_meta)),
            "by_doc_type"  : dict(Counter(m.get("doc_type", "") for m in all_meta)),
            "by_chunk_type": dict(Counter(m.get("chunk_type", "") for m in all_meta)),
        }