| """Schema retriever — handles DB schemas (source_type="database") and tabular file |
| columns stored as source_type="document" with file_type in ("csv","xlsx"). |
| |
| Strategy: hybrid_bm25 — RRF merge of dense cosine search (DB columns + DB tables |
| + tabular columns + tabular sheets) and PostgreSQL full-text search (DB columns only). |
| Embeds the query once, fans out five legs in parallel. |
| |
| The DB-tables leg surfaces table-level summary chunks (chunk_level='table') as |
| a recall signal for multi-table questions: when a relevant table's columns |
| don't individually win on similarity, the table chunk can still pull the table |
| into the hit set, where db_executor's downstream full-schema fetch picks up |
| the per-column detail. |
| |
| FTS requires a GIN index on langchain_pg_embedding.document (created by init_db.py). |
| """ |
|
|
| import asyncio |
|
|
| from sqlalchemy import text |
|
|
| from src.db.postgres.connection import _pgvector_engine |
| from src.db.postgres.vector_store import get_vector_store |
| from src.middlewares.logging import get_logger |
| from src.rag.base import BaseRetriever, RetrievalResult |
|
|
| # Logger for this module, registered under the name "schema_retriever". |
| logger = get_logger("schema_retriever") |


| # File types treated as tabular documents. NOTE(review): nothing in this |
| # module references this constant; the SQL filters spell 'csv' / 'xlsx' |
| # out literally, so keep the two in sync when adding a file type. |
| _TABULAR_FILE_TYPES = ("csv", "xlsx") |
| # The table-level vector search requests k * this many rows. |
| _TABLE_CHUNK_K_MULTIPLIER = 2 |
|
|
|
|
class SchemaRetriever(BaseRetriever):
    """Retrieve DB tables and tabular sheets relevant to a natural-language query.

    A single query embedding feeds four cosine-distance searches over
    ``langchain_pg_embedding`` (DB columns, DB tables, tabular columns and
    tabular sheets); a PostgreSQL full-text search over DB column chunks runs
    alongside them. The hits are then collapsed to one result per DB table and
    one per (document, sheet) with Reciprocal Rank Fusion (RRF).
    """

    # Column-level legs (vector and full-text) request k * this many rows.
    _COLUMN_OVERFETCH = 4

    # Extra WHERE fragment restricting document chunks to CSV/XLSX files.
    # Static text only: it is never built from caller-supplied values.
    _TABULAR_FILE_FILTER = (
        "AND (lpe.cmetadata->'data'->>'file_type' = 'csv'"
        " OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')"
    )

    def __init__(self):
        self.vector_store = get_vector_store()

    # ------------------------------------------------------------------
    # Low-level helpers
    # ------------------------------------------------------------------

    async def _embed_query(self, query: str) -> list[float]:
        """Embed ``query`` with the vector store's embedding model.

        The call is dispatched through ``asyncio.to_thread`` so it runs in a
        worker thread rather than on the event-loop thread.
        """
        return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)

    @staticmethod
    def _rows_to_results(
        rows, source_type: str, score_column: str = "score"
    ) -> list[RetrievalResult]:
        """Convert fetched rows into ``RetrievalResult`` objects.

        Each row must expose ``document``, ``cmetadata`` and the attribute
        named by ``score_column``.
        """
        return [
            RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(getattr(row, score_column)),
                source_type=source_type,
            )
            for row in rows
        ]

    async def _vector_search(
        self,
        embedding: list[float],
        user_id: str,
        limit: int,
        *,
        source_type: str,
        chunk_level: str,
        tabular_only: bool = False,
    ) -> list[RetrievalResult]:
        """Run one cosine-distance search leg.

        Args:
            embedding: Query embedding.
            user_id: Only chunks whose metadata ``user_id`` matches are searched.
            limit: Maximum number of rows to fetch (bound to ``LIMIT``).
            source_type: Required value of the metadata ``source_type`` key;
                also used as the ``source_type`` of the returned results.
            chunk_level: Required value of the metadata ``chunk_level`` key.
            tabular_only: When true, also require ``data.file_type`` to be
                ``csv`` or ``xlsx``.

        Returns:
            Results ordered by ascending cosine distance; ``score`` is
            ``1 - distance``.
        """
        extra_filter = self._TABULAR_FILE_FILTER if tabular_only else ""

        # The embedding is bound as a parameter instead of being spliced into
        # the SQL text, so the statement text is identical for every query.
        # CAST(... AS vector) is used rather than ``::vector`` because a
        # ``:name::type`` sequence is not recognised as a bind parameter by
        # SQLAlchemy's text() construct.
        sql = text(f"""
            SELECT lpe.document, lpe.cmetadata,
                   1.0 - (lpe.embedding <=> CAST(:embedding AS vector)) AS score
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = :source_type
              AND lpe.cmetadata->>'chunk_level' = :chunk_level
              {extra_filter}
            ORDER BY lpe.embedding <=> CAST(:embedding AS vector) ASC
            LIMIT :k
        """)
        params = {
            "embedding": "[" + ",".join(map(str, embedding)) + "]",
            "user_id": user_id,
            "source_type": source_type,
            "chunk_level": chunk_level,
            "k": limit,
        }

        async with _pgvector_engine.connect() as conn:
            rows = (await conn.execute(sql, params)).fetchall()

        return self._rows_to_results(rows, source_type)

    # ------------------------------------------------------------------
    # Search legs
    # ------------------------------------------------------------------

    async def _search_db(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Cosine vector search over database column-level chunks."""
        return await self._vector_search(
            embedding,
            user_id,
            k * self._COLUMN_OVERFETCH,
            source_type="database",
            chunk_level="column",
        )

    async def _search_db_tables(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Cosine vector search over database table-level summary chunks.

        This leg acts as a recall channel for multi-table questions: a table
        whose individual columns do not score well can still be surfaced by
        its summary chunk.
        """
        return await self._vector_search(
            embedding,
            user_id,
            k * _TABLE_CHUNK_K_MULTIPLIER,
            source_type="database",
            chunk_level="table",
        )

    async def _search_tabular(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Cosine vector search over column-level chunks of CSV/XLSX documents."""
        return await self._vector_search(
            embedding,
            user_id,
            k * self._COLUMN_OVERFETCH,
            source_type="document",
            chunk_level="column",
            tabular_only=True,
        )

    async def _search_tabular_sheets(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Cosine vector search over sheet-level summary chunks of CSV/XLSX documents."""
        return await self._vector_search(
            embedding,
            user_id,
            k,
            source_type="document",
            chunk_level="sheet",
            tabular_only=True,
        )

    async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
        """Full-text search over database column chunks using PostgreSQL tsvector.

        Rows are ordered by ``ts_rank`` descending, and that rank becomes the
        result ``score``. Unlike the vector legs, ``k`` is used as the LIMIT
        unchanged.
        """
        sql = text("""
            SELECT lpe.document, lpe.cmetadata,
                   ts_rank(to_tsvector('english', lpe.document),
                           plainto_tsquery('english', :query)) AS rank
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'chunk_level' = 'column'
              AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
            ORDER BY rank DESC
            LIMIT :k
        """)

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
            rows = result.fetchall()

        return self._rows_to_results(rows, "database", score_column="rank")

    # ------------------------------------------------------------------
    # Ranking helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _data_of(result: RetrievalResult) -> dict:
        """Return ``result.metadata["data"]`` as a dict.

        Falls back to an empty dict when the metadata is missing or when
        ``data`` is absent, null, or not a mapping.
        """
        metadata = result.metadata or {}
        data = metadata.get("data")
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _first_by_key(results: list[RetrievalResult], key_fn) -> dict:
        """Map each key to the first result that carries it.

        ``key_fn`` returns ``None`` for results that must be ignored. Because
        dicts keep insertion order, iterating the returned mapping yields the
        keys in best-position-first order.
        """
        first: dict = {}
        for result in results:
            key = key_fn(result)
            if key is not None and key not in first:
                first[key] = result
        return first

    @staticmethod
    def _rrf_scores(ranked_lists: list[list], k_rrf: int) -> dict:
        """Sum ``1 / (k_rrf + position)`` over every list a key appears in.

        ``position`` is 1-based. Keys enter the returned dict in the order
        they are first met, which decides ties in a later stable sort.
        """
        scores: dict = {}
        for ranked in ranked_lists:
            for position, key in enumerate(ranked, start=1):
                scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + position)
        return scores

    @staticmethod
    def _with_score(result: RetrievalResult, score: float) -> RetrievalResult:
        """Return a shallow copy of ``result`` carrying ``score``.

        The copy leaves the caller's object (and its original similarity
        score) untouched.
        """
        import copy

        rescored = copy.copy(result)
        rescored.score = score
        return rescored

    def _rank_tabular_sheets(
        self,
        sheet_results: list[RetrievalResult],
        column_results: list[RetrievalResult],
        top_k: int,
        k_rrf: int = 60,
    ) -> list[RetrievalResult]:
        """Rank tabular sheets by RRF across two voting legs.

        Legs:
            L1 (primary): position of the sheet chunk in ``sheet_results``.
            L2 (vote): best position of any column chunk belonging to the
                same ``(document_id, sheet_name)`` in ``column_results``.

        Results without a truthy ``data.document_id`` are ignored.

        Returns:
            Up to ``top_k`` sheet-level results, best first, each scored with
            its RRF sum. A sheet that only received column votes has no sheet
            chunk to return, so a stub is built from its best column hit: the
            column's metadata with ``column_name`` / ``column_type`` removed
            and ``chunk_level`` set to ``"sheet"``.
        """

        def sheet_key(result: RetrievalResult):
            data = self._data_of(result)
            document_id = data.get("document_id")
            return (document_id, data.get("sheet_name")) if document_id else None

        sheet_chunks = self._first_by_key(sheet_results, sheet_key)
        column_votes = self._first_by_key(column_results, sheet_key)

        scores = self._rrf_scores([list(sheet_chunks), list(column_votes)], k_rrf)
        top_sheets = sorted(scores, key=scores.get, reverse=True)[:top_k]

        ranked: list[RetrievalResult] = []
        for key in top_sheets:
            score = scores[key]
            chunk = sheet_chunks.get(key)
            if chunk is not None:
                ranked.append(self._with_score(chunk, score))
                continue

            # No sheet chunk, so the key must have come from a column vote.
            representative = column_votes[key]
            _, sheet_name = key
            stub_data = dict(self._data_of(representative))
            stub_data.pop("column_name", None)
            stub_data.pop("column_type", None)

            content = f"Sheet: {stub_data.get('filename', '')}"
            if sheet_name:
                content += f" / sheet: {sheet_name}"

            ranked.append(RetrievalResult(
                content=content,
                metadata={**representative.metadata, "data": stub_data, "chunk_level": "sheet"},
                score=score,
                source_type="document",
            ))
        return ranked

    def _rank_db_tables(
        self,
        tbl_results: list[RetrievalResult],
        col_results: list[RetrievalResult],
        fts_results: list[RetrievalResult],
        top_k: int,
        k_rrf: int = 60,
    ) -> list[RetrievalResult]:
        """Rank DB tables by RRF across three legs.

        Legs:
            L1 (primary): position of the table-summary chunk in ``tbl_results``.
            L2 (vote): best column-chunk position per table in ``col_results``.
            L3 (vote): best full-text position per table in ``fts_results``.

        Results without a truthy ``data.table_name`` are ignored.

        Returns:
            Up to ``top_k`` table-level results, best first, each scored with
            its RRF sum. A table that only received votes from L2/L3 is
            returned as a minimal stub whose metadata carries just
            ``data.table_name`` and ``source_type``.
        """

        def table_key(result: RetrievalResult):
            return self._data_of(result).get("table_name") or None

        table_chunks = self._first_by_key(tbl_results, table_key)
        column_votes = self._first_by_key(col_results, table_key)
        fts_votes = self._first_by_key(fts_results, table_key)

        scores = self._rrf_scores(
            [list(table_chunks), list(column_votes), list(fts_votes)], k_rrf
        )
        top_tables = sorted(scores, key=scores.get, reverse=True)[:top_k]

        ranked: list[RetrievalResult] = []
        for table_name in top_tables:
            score = scores[table_name]
            chunk = table_chunks.get(table_name)
            if chunk is not None:
                ranked.append(self._with_score(chunk, score))
            else:
                ranked.append(RetrievalResult(
                    content=f"Table: {table_name}",
                    metadata={"data": {"table_name": table_name}, "source_type": "database"},
                    score=score,
                    source_type="database",
                ))
        return ranked

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Return the DB tables and tabular sheets most relevant to ``query``.

        The query is embedded once, then five searches run concurrently:
        DB column chunks, DB table chunks, tabular column chunks, DB full-text
        search, and tabular sheet chunks.

        DB tables are ranked by RRF over the table, column and full-text legs;
        the column and full-text hits only vote for their table and are not
        returned themselves. Tabular sheets are ranked by RRF over the sheet
        and tabular-column legs.

        Args:
            query: Natural-language question.
            user_id: Restricts every search to this user's chunks.
            k: Maximum number of DB tables and, separately, of sheets to keep.

        Returns:
            Up to ``k`` table results plus up to ``k`` sheet results, merged
            and sorted by RRF score, best first.

        NOTE(review): table scores can sum three RRF terms while sheet scores
        sum at most two, so the merged ordering compares scores that are not
        on an identical scale.
        """
        embedding = await self._embed_query(query)
        db_col_hits, db_tbl_hits, tabular_col_hits, fts_hits, sheet_hits = await asyncio.gather(
            self._search_db(embedding, user_id, k),
            self._search_db_tables(embedding, user_id, k),
            self._search_tabular(embedding, user_id, k),
            self._search_fts_db(query, user_id, k * self._COLUMN_OVERFETCH),
            self._search_tabular_sheets(embedding, user_id, k),
        )

        db_ranked = self._rank_db_tables(db_tbl_hits, db_col_hits, fts_hits, top_k=k)
        tabular_ranked = self._rank_tabular_sheets(sheet_hits, tabular_col_hits, top_k=k)

        results = sorted(db_ranked + tabular_ranked, key=lambda r: r.score, reverse=True)
        logger.info(
            "schema retrieval",
            count=len(results),
            db_tables_ranked=len(db_ranked),
            db_cols=len(db_col_hits),
            db_tables=len(db_tbl_hits),
            tabular_cols=len(tabular_col_hits),
            tabular_sheets=len(sheet_hits),
            tabular_ranked=len(tabular_ranked),
            fts=len(fts_hits),
        )
        return results
|
|
|
|
| # Shared module-level instance, created at import time; constructing it |
| # calls get_vector_store(). |
| schema_retriever = SchemaRetriever() |
|
|