"""Schema retriever — handles DB schemas (source_type="database") and tabular
file columns stored as source_type="document" with file_type in ("csv","xlsx").

Strategy: hybrid_bm25 — RRF merge of dense cosine search (DB columns + DB
tables + tabular columns + tabular sheets) and PostgreSQL full-text search
(DB columns only). Embeds the query once, fans out five legs in parallel.

The DB-tables leg surfaces table-level summary chunks (chunk_level='table') as
a recall signal for multi-table questions: when a relevant table's columns
don't individually win on similarity, the table chunk can still pull the table
into the hit set, where db_executor's downstream full-schema fetch picks up the
per-column detail.

FTS requires a GIN index on langchain_pg_embedding.document (created by
init_db.py).
"""

import asyncio
from collections.abc import Callable, Hashable, Iterable
from typing import Optional

from sqlalchemy import text

from src.db.postgres.connection import _pgvector_engine
from src.db.postgres.vector_store import get_vector_store
from src.middlewares.logging import get_logger
from src.rag.base import BaseRetriever, RetrievalResult

logger = get_logger("schema_retriever")

_TABULAR_FILE_TYPES = ("csv", "xlsx")
_TABLE_CHUNK_K_MULTIPLIER = 2  # how many table chunks to pull before RRF
_COLUMN_CHUNK_K_MULTIPLIER = 4  # over-fetch factor for the column-level legs (cosine + FTS)

# SQL predicate restricting a search to tabular files. Generated from
# _TABULAR_FILE_TYPES (a code-level constant, never user input) so the tuple
# and the query cannot drift apart.
_TABULAR_FILE_TYPE_FILTER = (
    "AND ("
    + " OR ".join(
        f"lpe.cmetadata->'data'->>'file_type' = '{file_type}'"
        for file_type in _TABULAR_FILE_TYPES
    )
    + ")"
)


def _vector_literal(embedding: Iterable[float]) -> str:
    """Render an embedding as a pgvector text literal, e.g. ``[0.1,0.2]``.

    Every component goes through ``float()`` so that only numeric text can end
    up in the SQL string the literal is inlined into; a non-numeric component
    raises instead of being interpolated.
    """
    return "[" + ",".join(str(float(x)) for x in embedding) + "]"


def _chunk_data(result: RetrievalResult) -> dict:
    """Return ``result.metadata["data"]``, or ``{}`` when either level is missing or None."""
    metadata = result.metadata or {}
    return metadata.get("data") or {}


def _table_key(result: RetrievalResult) -> Optional[str]:
    """Grouping key for DB chunks: the table name, or None when absent/empty."""
    return _chunk_data(result).get("table_name") or None


def _sheet_key(result: RetrievalResult) -> Optional[tuple]:
    """Grouping key for tabular chunks: (document_id, sheet_name), or None without a document_id."""
    data = _chunk_data(result)
    document_id = data.get("document_id")
    if not document_id:
        return None
    return (document_id, data.get("sheet_name"))


def _first_by_key(
    results: Iterable[RetrievalResult],
    key_of: Callable[[RetrievalResult], Optional[Hashable]],
) -> dict:
    """Map each key to the first result carrying it, skipping results whose key is None.

    Dict insertion order is first-appearance order, so iterating the returned
    dict yields the keys ranked by their best (earliest) hit.
    """
    first: dict = {}
    for result in results:
        key = key_of(result)
        if key is not None and key not in first:
            first[key] = result
    return first


def _rrf_scores(ranked_lists: Iterable[Iterable[Hashable]], k_rrf: int) -> dict:
    """Reciprocal-rank-fusion: sum ``1 / (k_rrf + position)`` (1-based) per key over all lists."""
    scores: dict = {}
    for ranked in ranked_lists:
        for position, key in enumerate(ranked, start=1):
            scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + position)
    return scores


class SchemaRetriever(BaseRetriever):
    """Hybrid retriever over DB-schema chunks and tabular (CSV/XLSX) chunks."""

    def __init__(self):
        self.vector_store = get_vector_store()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _embed_query(self, query: str) -> list[float]:
        """Embed ``query`` in a worker thread (``embed_query`` is called synchronously)."""
        return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)

    async def _cosine_search(
        self,
        embedding: list[float],
        user_id: str,
        limit: int,
        *,
        source_type: str,
        chunk_level: str,
        tabular_only: bool = False,
    ) -> list[RetrievalResult]:
        """Run one cosine-similarity leg against ``langchain_pg_embedding``.

        Args:
            embedding: Query embedding.
            user_id: Only chunks whose metadata ``user_id`` matches are searched.
            limit: Row limit (``LIMIT``) — callers apply their own over-fetch factor.
            source_type: Metadata ``source_type`` to match; also stamped on each result.
            chunk_level: Metadata ``chunk_level`` to match.
            tabular_only: When true, additionally restrict to ``_TABULAR_FILE_TYPES``.

        Returns:
            Results ordered by ascending cosine distance; ``score`` is
            ``1 - distance``.
        """
        emb_str = _vector_literal(embedding)
        file_type_filter = _TABULAR_FILE_TYPE_FILTER if tabular_only else ""
        # source_type / chunk_level / file_type_filter are code-level constants
        # supplied by the methods below, and emb_str is numeric-only, so the
        # only caller-controlled values (user_id, k) stay bound parameters.
        sql = text(f"""
            SELECT lpe.document, lpe.cmetadata,
                   1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = '{source_type}'
              AND lpe.cmetadata->>'chunk_level' = '{chunk_level}'
              {file_type_filter}
            ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
            LIMIT :k
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"user_id": user_id, "k": limit})
            rows = result.fetchall()
        return [
            RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.score),
                source_type=source_type,
            )
            for row in rows
        ]

    async def _search_db(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Cosine vector search over database column chunks (over-fetches ``k * 4``)."""
        return await self._cosine_search(
            embedding,
            user_id,
            k * _COLUMN_CHUNK_K_MULTIPLIER,
            source_type="database",
            chunk_level="column",
        )

    async def _search_db_tables(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Cosine vector search over database TABLE-level chunks.

        Recall channel for multi-table questions. The chunk's content is
        discarded downstream — db_executor only consumes its `data.table_name`
        to seed full-schema fetch.
        """
        return await self._cosine_search(
            embedding,
            user_id,
            k * _TABLE_CHUNK_K_MULTIPLIER,
            source_type="database",
            chunk_level="table",
        )

    async def _search_tabular(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Cosine vector search over tabular document column chunks (csv/xlsx)."""
        return await self._cosine_search(
            embedding,
            user_id,
            k * _COLUMN_CHUNK_K_MULTIPLIER,
            source_type="document",
            chunk_level="column",
            tabular_only=True,
        )

    async def _search_tabular_sheets(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Leg 5: sheet-level summary chunks from CSV/XLSX files (no over-fetch)."""
        return await self._cosine_search(
            embedding,
            user_id,
            k,
            source_type="document",
            chunk_level="sheet",
            tabular_only=True,
        )

    async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
        """Full-text search over DB schema column chunks using PostgreSQL tsvector.

        ``k`` is used as the row limit as-is; ``score`` is the ``ts_rank`` value.
        """
        sql = text("""
            SELECT lpe.document, lpe.cmetadata,
                   ts_rank(to_tsvector('english', lpe.document),
                           plainto_tsquery('english', :query)) AS rank
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'chunk_level' = 'column'
              AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
            ORDER BY rank DESC
            LIMIT :k
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
            rows = result.fetchall()
        return [
            RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.rank),
                source_type="database",
            )
            for row in rows
        ]

    def _rank_tabular_sheets(
        self,
        sheet_results: list[RetrievalResult],
        column_results: list[RetrievalResult],
        top_k: int,
        k_rrf: int = 60,
    ) -> list[RetrievalResult]:
        """Rank tabular sheets by RRF across two voting legs:

          L1 (primary): sheet-chunk cosine score
          L2 (vote):    best column-chunk position per (doc_id, sheet_name)

        Returns top-k sheet-level RetrievalResults. The full column list of each
        sheet is already in the sheet chunk's data.column_names from ingestion, so
        downstream tabular_executor can read full sheet context.

        For sheets surfaced by column votes but missing a sheet chunk (rare —
        ingestion always creates one), a minimal stub is returned and
        tabular_executor falls back to reading columns from the parquet.

        Note: for sheets that do have a sheet chunk, that chunk's ``score`` is
        overwritten in place with the fused RRF score.
        """
        # First hit per (doc_id, sheet_name) in each leg; dict order == rank order.
        sheet_hits = _first_by_key(sheet_results, _sheet_key)
        column_hits = _first_by_key(column_results, _sheet_key)

        scores = _rrf_scores([sheet_hits, column_hits], k_rrf)
        top_sheets = sorted(scores, key=scores.get, reverse=True)[:top_k]

        results: list[RetrievalResult] = []
        for key in top_sheets:
            if key in sheet_hits:
                hit = sheet_hits[key]
                hit.score = scores[key]
                results.append(hit)
                continue

            # Surfaced by column votes only — build a stub from the best column
            # hit of that sheet so tabular_executor can group correctly. A key
            # absent from sheet_hits can only have come from column_hits.
            rep = column_hits[key]
            _, sheet_name = key
            stub_data = dict(_chunk_data(rep))
            stub_data.pop("column_name", None)
            stub_data.pop("column_type", None)
            results.append(RetrievalResult(
                content=f"Sheet: {stub_data.get('filename', '')}"
                        + (f" / sheet: {sheet_name}" if sheet_name else ""),
                metadata={**(rep.metadata or {}), "data": stub_data, "chunk_level": "sheet"},
                score=scores[key],
                source_type="document",
            ))
        return results

    def _rank_db_tables(
        self,
        tbl_results: list[RetrievalResult],
        col_results: list[RetrievalResult],
        fts_results: list[RetrievalResult],
        top_k: int,
        k_rrf: int = 60,
    ) -> list[RetrievalResult]:
        """Rank DB tables by RRF across three legs:

          L1 (primary): table-summary chunk similarity
          L2 (vote):    best column-chunk position per table
          L3 (vote):    best FTS position per table

        Returns top-k table-chunk RetrievalResults. For tables surfaced by L2/L3
        but missing a table chunk, a minimal stub is returned so that
        db_executor._fetch_full_schema can seed off data.table_name.

        Note: for tables that do have a table chunk, that chunk's ``score`` is
        overwritten in place with the fused RRF score.
        """
        # First hit per table name in each leg; dict order == rank order.
        table_hits = _first_by_key(tbl_results, _table_key)
        column_votes = _first_by_key(col_results, _table_key)
        fts_votes = _first_by_key(fts_results, _table_key)

        scores = _rrf_scores([table_hits, column_votes, fts_votes], k_rrf)
        top_tables = sorted(scores, key=scores.get, reverse=True)[:top_k]

        results: list[RetrievalResult] = []
        for tname in top_tables:
            if tname in table_hits:
                hit = table_hits[tname]
                hit.score = scores[tname]
                results.append(hit)
            else:
                # Surfaced by column/FTS votes with no table chunk — minimal stub
                results.append(RetrievalResult(
                    content=f"Table: {tname}",
                    metadata={"data": {"table_name": tname}, "source_type": "database"},
                    score=scores[tname],
                    source_type="database",
                ))
        return results

    # ------------------------------------------------------------------
    # Public interface — called by the router
    # ------------------------------------------------------------------

    async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Table-first retrieval for DB sources; sheet-level for tabular.

        DB tables are ranked via RRF across three legs:
          L1 (primary): table-summary chunk similarity
          L2 (vote):    top-K column-chunk cosine, grouped by table
          L3 (vote):    top-K FTS column hits, grouped by table

        db_executor downstream fetches the full per-column schema for the
        ranked table set via _fetch_full_schema — the column chunks returned
        here are intentionally NOT used as the schema source, only for voting.

        Tabular (CSV/XLSX) sheets are ranked via RRF across two legs:
          L1: sheet-chunk cosine
          L2: column-chunk votes (best position per sheet)
        Returns sheet-level RetrievalResults so tabular_executor receives full
        sheet context (all columns) rather than fragmented column hits.

        Args:
            query: Natural-language question; embedded once and also used for FTS.
            user_id: Restricts every leg to this user's chunks.
            k: Maximum number of DB tables and of tabular sheets kept (each
                group is capped at ``k`` before the two are merged).

        Returns:
            DB-table and tabular-sheet results merged and sorted by descending
            RRF score.
        """
        embedding = await self._embed_query(query)

        # Order here must match the unpacking order on the left.
        (
            db_col_results,
            db_tbl_results,
            tabular_results,
            fts_results,
            sheet_results,
        ) = await asyncio.gather(
            self._search_db(embedding, user_id, k),
            self._search_db_tables(embedding, user_id, k),
            self._search_tabular(embedding, user_id, k),
            # The cosine column legs over-fetch internally; FTS takes its limit as-is.
            self._search_fts_db(query, user_id, k * _COLUMN_CHUNK_K_MULTIPLIER),
            self._search_tabular_sheets(embedding, user_id, k),
        )

        db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
        tabular_ranked = self._rank_tabular_sheets(sheet_results, tabular_results, top_k=k)

        results = sorted(db_ranked + tabular_ranked, key=lambda r: r.score, reverse=True)

        logger.info(
            "schema retrieval",
            count=len(results),
            db_tables_ranked=len(db_ranked),
            db_cols=len(db_col_results),
            db_tables=len(db_tbl_results),
            tabular_cols=len(tabular_results),
            tabular_sheets=len(sheet_results),
            tabular_ranked=len(tabular_ranked),
            fts=len(fts_results),
        )
        return results


schema_retriever = SchemaRetriever()