ishaq101's picture
[KM-438][KM-439] Improve Retrieval and Querying feature (#15)
c93ec90
raw
history blame
17.3 kB
"""Schema retriever — handles DB schemas (source_type="database") and tabular file
columns stored as source_type="document" with file_type in ("csv","xlsx").
Strategy: hybrid_bm25 — RRF merge of dense cosine search (DB columns + DB tables
+ tabular columns + tabular sheets) and PostgreSQL full-text search (DB columns only).
Embeds the query once, fans out five legs in parallel.
The DB-tables leg surfaces table-level summary chunks (chunk_level='table') as
a recall signal for multi-table questions: when a relevant table's columns
don't individually win on similarity, the table chunk can still pull the table
into the hit set, where db_executor's downstream full-schema fetch picks up
the per-column detail.
FTS requires a GIN index on langchain_pg_embedding.document (created by init_db.py).
"""
import asyncio
from sqlalchemy import text
from src.db.postgres.connection import _pgvector_engine
from src.db.postgres.vector_store import get_vector_store
from src.middlewares.logging import get_logger
from src.rag.base import BaseRetriever, RetrievalResult
# Logger for this module, named "schema_retriever".
logger = get_logger("schema_retriever")
# File types treated as tabular documents.
# NOTE(review): not referenced by the code visible in this file — the tabular
# SQL filters spell out 'csv' / 'xlsx' directly; keep the two in sync.
_TABULAR_FILE_TYPES = ("csv", "xlsx")
# Over-fetch factor for the DB table-chunk leg: the SQL LIMIT is k * this value.
_TABLE_CHUNK_K_MULTIPLIER = 2 # how many table chunks to pull before RRF
class SchemaRetriever(BaseRetriever):
    """Retrieve DB-table and tabular-sheet candidates for a user query.

    The query is embedded once, then five search legs run concurrently
    (four cosine legs over pgvector plus one PostgreSQL full-text leg).
    The per-leg hits are collapsed to table / sheet granularity and merged
    with reciprocal rank fusion (RRF).
    """

    # Over-fetch factor applied to the column-level legs (DB columns, tabular
    # columns, FTS). Presumably chosen so that grouping column hits by
    # table / sheet still leaves enough distinct candidates — confirm before tuning.
    _COLUMN_K_MULTIPLIER = 4

    # Extra WHERE clause restricting document chunks to tabular files.
    _TABULAR_FILE_FILTER = (
        "AND (lpe.cmetadata->'data'->>'file_type' = 'csv'"
        " OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')"
    )

    def __init__(self):
        self.vector_store = get_vector_store()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _embed_query(self, query: str) -> list[float]:
        """Embed *query* in a worker thread so the event loop is not blocked."""
        return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)

    @staticmethod
    def _data(result: RetrievalResult) -> dict:
        """Return the nested ``data`` dict of a result's metadata.

        Falls back to ``{}`` when the metadata or its ``data`` entry is
        missing or null, so callers can chain ``.get(...)`` safely.
        """
        return (result.metadata or {}).get("data") or {}

    @staticmethod
    def _to_results(rows, source_type: str) -> list[RetrievalResult]:
        """Map ``(document, cmetadata, score)`` rows to RetrievalResults."""
        return [
            RetrievalResult(
                content=document,
                metadata=cmetadata,
                score=float(score),
                source_type=source_type,
            )
            for document, cmetadata, score in rows
        ]

    @staticmethod
    def _rrf_scores(ranked_lists: list[list], k_rrf: int) -> dict:
        """Accumulate reciprocal-rank-fusion scores over several ranked lists.

        Each key receives ``1 / (k_rrf + position)`` per list it appears in,
        with positions starting at 1. Keys are inserted in first-seen order,
        which (with a stable sort) decides ties.
        """
        scores: dict = {}
        for ranked in ranked_lists:
            for position, key in enumerate(ranked, start=1):
                scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + position)
        return scores

    async def _cosine_search(
        self,
        embedding: list[float],
        user_id: str,
        limit: int,
        source_type: str,
        chunk_level: str,
        tabular_only: bool = False,
    ) -> list[RetrievalResult]:
        """Run one cosine-distance search leg over ``document_embeddings``.

        Args:
            embedding: Query embedding.
            user_id: Only chunks whose metadata ``user_id`` matches are searched.
            limit: SQL ``LIMIT`` (already multiplied by any over-fetch factor).
            source_type: Required metadata ``source_type``; also stamped on
                every returned result.
            chunk_level: Required metadata ``chunk_level``.
            tabular_only: When true, additionally require ``data.file_type``
                to be ``csv`` or ``xlsx``.

        Returns:
            Results ordered by ascending cosine distance, scored as
            ``1 - distance``.
        """
        # The vector literal is inlined into the statement text (as in the
        # original implementation); it is built from the embedding's numeric
        # components only. All other filters are bound parameters.
        emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
        extra_filter = self._TABULAR_FILE_FILTER if tabular_only else ""
        sql = text(f"""
            SELECT lpe.document, lpe.cmetadata,
                   1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = :source_type
              AND lpe.cmetadata->>'chunk_level' = :chunk_level
              {extra_filter}
            ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
            LIMIT :k
        """)
        params = {
            "user_id": user_id,
            "source_type": source_type,
            "chunk_level": chunk_level,
            "k": limit,
        }
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, params)
            rows = result.fetchall()
        return self._to_results(rows, source_type)

    async def _search_db(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Cosine vector search over database column chunks (fetches k * 4)."""
        return await self._cosine_search(
            embedding,
            user_id,
            k * self._COLUMN_K_MULTIPLIER,
            source_type="database",
            chunk_level="column",
        )

    async def _search_db_tables(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Cosine vector search over database TABLE-level chunks.

        Recall channel for multi-table questions. The chunk's content is
        discarded downstream — db_executor only consumes its `data.table_name`
        to seed full-schema fetch.
        """
        return await self._cosine_search(
            embedding,
            user_id,
            k * _TABLE_CHUNK_K_MULTIPLIER,
            source_type="database",
            chunk_level="table",
        )

    async def _search_tabular(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Cosine vector search over tabular column chunks (csv/xlsx; fetches k * 4)."""
        return await self._cosine_search(
            embedding,
            user_id,
            k * self._COLUMN_K_MULTIPLIER,
            source_type="document",
            chunk_level="column",
            tabular_only=True,
        )

    async def _search_tabular_sheets(
        self, embedding: list[float], user_id: str, k: int
    ) -> list[RetrievalResult]:
        """Leg 5: sheet-level summary chunks from CSV/XLSX files."""
        return await self._cosine_search(
            embedding,
            user_id,
            k,
            source_type="document",
            chunk_level="sheet",
            tabular_only=True,
        )

    async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
        """Full-text search over DB column chunks using PostgreSQL tsvector.

        Unlike the cosine legs, ``k`` is used as the SQL ``LIMIT`` as-is; the
        caller applies any over-fetch factor. Scores are ``ts_rank`` values.
        """
        sql = text("""
            SELECT lpe.document, lpe.cmetadata,
                   ts_rank(to_tsvector('english', lpe.document),
                           plainto_tsquery('english', :query)) AS rank
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'chunk_level' = 'column'
              AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
            ORDER BY rank DESC
            LIMIT :k
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
            rows = result.fetchall()
        return self._to_results(rows, "database")

    def _rank_tabular_sheets(
        self,
        sheet_results: list[RetrievalResult],
        column_results: list[RetrievalResult],
        top_k: int,
        k_rrf: int = 60,
    ) -> list[RetrievalResult]:
        """Rank tabular sheets by RRF across two voting legs:

          L1 (primary): sheet-chunk cosine score
          L2 (vote):    best column-chunk position per (doc_id, sheet_name)

        Returns top-k sheet-level RetrievalResults, each with its ``score``
        overwritten by the fused RRF score. The full column list of each
        sheet is already in the sheet chunk's data.column_names from
        ingestion, so downstream tabular_executor can read full sheet context.

        For sheets surfaced by column votes but missing a sheet chunk (rare —
        ingestion always creates one), a minimal stub is returned and
        tabular_executor falls back to reading columns from the parquet.
        """

        def first_hit_per_sheet(results: list[RetrievalResult]) -> dict:
            # Maps (doc_id, sheet_name) -> first result seen for that sheet.
            # Dict insertion order doubles as the leg's ranking.
            index: dict[tuple, RetrievalResult] = {}
            for r in results:
                d = self._data(r)
                key = (d.get("document_id"), d.get("sheet_name"))
                if key[0]:  # chunks without a document id cannot be grouped
                    index.setdefault(key, r)
            return index

        sheet_index = first_hit_per_sheet(sheet_results)    # L1
        column_index = first_hit_per_sheet(column_results)  # L2

        rrf_scores = self._rrf_scores([list(sheet_index), list(column_index)], k_rrf)
        top_sheets = sorted(rrf_scores, key=rrf_scores.get, reverse=True)[:top_k]

        results: list[RetrievalResult] = []
        for key in top_sheets:
            if key in sheet_index:
                r = sheet_index[key]
                r.score = rrf_scores[key]
                results.append(r)
                continue
            # Surfaced by column votes only — build stub from a representative
            # column result so tabular_executor can group correctly. A key that
            # is absent from sheet_index can only have come from the column
            # leg, so the lookup below always succeeds.
            rep = column_index[key]
            _, sheet_name = key
            stub_data = dict(self._data(rep))
            stub_data.pop("column_name", None)
            stub_data.pop("column_type", None)
            results.append(RetrievalResult(
                content=f"Sheet: {stub_data.get('filename', '')}"
                        + (f" / sheet: {sheet_name}" if sheet_name else ""),
                metadata={**rep.metadata, "data": stub_data, "chunk_level": "sheet"},
                score=rrf_scores[key],
                source_type="document",
            ))
        return results

    def _rank_db_tables(
        self,
        tbl_results: list[RetrievalResult],
        col_results: list[RetrievalResult],
        fts_results: list[RetrievalResult],
        top_k: int,
        k_rrf: int = 60,
    ) -> list[RetrievalResult]:
        """Rank DB tables by RRF across three legs:

          L1 (primary): table-summary chunk similarity
          L2 (vote):    best column-chunk position per table
          L3 (vote):    best FTS position per table

        Returns top-k table-chunk RetrievalResults, each with its ``score``
        overwritten by the fused RRF score. For tables surfaced by L2/L3 but
        missing a table chunk, a minimal stub is returned so that
        db_executor._fetch_full_schema can seed off data.table_name.
        """

        def first_hit_per_table(results: list[RetrievalResult]) -> dict:
            # Maps table_name -> first result seen for that table.
            # Dict insertion order doubles as the leg's ranking.
            index: dict[str, RetrievalResult] = {}
            for r in results:
                tname = self._data(r).get("table_name")
                if tname:
                    index.setdefault(tname, r)
            return index

        tbl_index = first_hit_per_table(tbl_results)             # L1
        col_table_ranked = list(first_hit_per_table(col_results))  # L2
        fts_table_ranked = list(first_hit_per_table(fts_results))  # L3

        rrf_scores = self._rrf_scores(
            [list(tbl_index), col_table_ranked, fts_table_ranked], k_rrf
        )
        top_tables = sorted(rrf_scores, key=rrf_scores.get, reverse=True)[:top_k]

        results: list[RetrievalResult] = []
        for tname in top_tables:
            if tname in tbl_index:
                r = tbl_index[tname]
                r.score = rrf_scores[tname]
                results.append(r)
            else:
                # Surfaced by column/FTS votes with no table chunk — minimal stub
                results.append(RetrievalResult(
                    content=f"Table: {tname}",
                    metadata={"data": {"table_name": tname}, "source_type": "database"},
                    score=rrf_scores[tname],
                    source_type="database",
                ))
        return results

    # ------------------------------------------------------------------
    # Public interface — called by the router
    # ------------------------------------------------------------------

    async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Table-first retrieval for DB sources; sheet-level for tabular.

        DB tables are ranked via RRF across three legs:
          L1 (primary): table-summary chunk similarity
          L2 (vote):    top-K column-chunk cosine, grouped by table
          L3 (vote):    top-K FTS column hits, grouped by table

        db_executor downstream fetches the full per-column schema for the
        ranked table set via _fetch_full_schema — the column chunks returned
        here are intentionally NOT used as the schema source, only for voting.

        Tabular (CSV/XLSX) sheets are ranked via RRF across two legs:
          L1: sheet-chunk cosine
          L2: column-chunk votes (best position per sheet)
        Returns sheet-level RetrievalResults so tabular_executor receives
        full sheet context (all columns) rather than fragmented column hits.

        Args:
            query: Natural-language user query.
            user_id: Owner whose chunks are searched.
            k: Maximum number of DB tables and, separately, of tabular sheets
                to keep; the combined list can therefore hold up to 2 * k items.

        Returns:
            DB-table and sheet results merged and sorted by descending RRF score.
        """
        embedding = await self._embed_query(query)

        (
            db_col_results,
            db_tbl_results,
            tabular_results,
            fts_results,
            sheet_results,
        ) = await asyncio.gather(
            self._search_db(embedding, user_id, k),
            self._search_db_tables(embedding, user_id, k),
            self._search_tabular(embedding, user_id, k),
            # The FTS leg does not over-fetch internally, so widen it here.
            self._search_fts_db(query, user_id, k * self._COLUMN_K_MULTIPLIER),
            self._search_tabular_sheets(embedding, user_id, k),
        )

        db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
        tabular_ranked = self._rank_tabular_sheets(sheet_results, tabular_results, top_k=k)
        results = sorted(db_ranked + tabular_ranked, key=lambda r: r.score, reverse=True)

        logger.info(
            "schema retrieval",
            count=len(results),
            db_tables_ranked=len(db_ranked),
            db_cols=len(db_col_results),
            db_tables=len(db_tbl_results),
            tabular_cols=len(tabular_results),
            tabular_sheets=len(sheet_results),
            tabular_ranked=len(tabular_ranked),
            fts=len(fts_results),
        )
        return results
# Module-level shared instance. Constructed at import time, which calls
# get_vector_store() via SchemaRetriever.__init__.
schema_retriever = SchemaRetriever()