"""Schema retriever — handles DB schemas and tabular file columns.

DB schema chunks are stored with ``source_type="database"``; tabular file
columns are stored as ``source_type="document"`` with ``data.file_type`` in
``_TABULAR_FILE_TYPES`` ("csv", "xlsx").

Multiple retrieval strategies are exposed for benchmarking. The strategy used
by the router is ``retrieve()``, which dispatches to ``ACTIVE_STRATEGY``.
Change ``ACTIVE_STRATEGY`` at module level to switch without touching the
router.

All strategies embed the query exactly once, then fan out to parallel SQL legs.

Strategies:
    dense_no_threshold — cosine (<=>), no score floor, returns up to k chunks
    dense_dot          — inner product (<#>); ranks like cosine for
                         normalized embeddings
    dense_l2           — L2/euclidean (<->); ranks like cosine for
                         unit-length embeddings
    hybrid             — RRF merge of dense + FTS (FTS over database + tabular)
    hybrid_bm25        — RRF merge of dense + FTS (FTS over database only).
                         Despite the name, the FTS leg ranks with PostgreSQL
                         ``ts_rank``, not BM25.
"""

import asyncio
import time
from typing import Literal, get_args

from sqlalchemy import text

from src.db.postgres.connection import _pgvector_engine
from src.db.postgres.vector_store import get_vector_store
from src.middlewares.logging import get_logger
from src.rag.base import BaseRetriever, RetrievalResult

logger = get_logger("schema_retriever")

_TABULAR_FILE_TYPES = ("csv", "xlsx")

Strategy = Literal["dense_no_threshold", "dense_dot", "dense_l2", "hybrid", "hybrid_bm25"]

ACTIVE_STRATEGY: Strategy = "hybrid_bm25"

# Dense legs fetch this many times k candidates. Several chunks can share one
# dedup key (see SchemaRetriever._result_key), so over-fetching leaves enough
# unique entries after _dedup to fill k slots.
_DENSE_OVERFETCH = 4

# pgvector distance operator -> SQL turning that distance into a
# "higher is better" score. ``{distance}`` is substituted with the
# parenthesised distance expression. Only these operators are accepted, so
# nothing caller-supplied is ever interpolated into SQL.
_SCORE_SQL_BY_OPERATOR = {
    "<=>": "1.0 - {distance}",  # cosine distance -> cosine similarity
    "<#>": "{distance} * -1",  # pgvector's <#> yields the NEGATIVE inner product
    "<->": "1.0 / (1.0 + {distance})",  # L2 distance -> score in (0, 1]
}

# Source filters appended to the shared WHERE clause. Built only from module
# constants (trusted), never from request data.
_DATABASE_FILTER_SQL = "lpe.cmetadata->>'source_type' = 'database'"
_TABULAR_FILTER_SQL = (
    "(lpe.cmetadata->>'source_type' = 'document'"
    " AND lpe.cmetadata->'data'->>'file_type' IN ("
    + ", ".join(f"'{file_type}'" for file_type in _TABULAR_FILE_TYPES)
    + "))"
)


class SchemaRetriever(BaseRetriever):
    """Retrieve schema / tabular-column chunks for a user via pluggable strategies."""

    def __init__(self) -> None:
        self.vector_store = get_vector_store()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _embed_query(self, query: str) -> list[float]:
        """Embed *query* in a worker thread so the blocking call does not stall the loop."""
        return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)

    @staticmethod
    def _vector_literal(embedding: list[float]) -> str:
        """Render *embedding* in pgvector's text input format, e.g. ``[0.1,0.2]``."""
        return "[" + ",".join(str(x) for x in embedding) + "]"

    @staticmethod
    def _result_key(result: RetrievalResult) -> tuple:
        """Identity used for dedup / RRF: ``(table_name, column_name or filename)``.

        NULL ``cmetadata`` or ``data`` is treated as empty. A chunk carrying none
        of those identifiers falls back to its content, so unrelated
        unidentified chunks are not collapsed into one ``(None, None)`` entry.
        """
        data = (result.metadata or {}).get("data") or {}
        key = (data.get("table_name"), data.get("column_name") or data.get("filename"))
        if key == (None, None):
            return ("__content__", result.content)
        return key

    @staticmethod
    def _to_results(rows, source_type: str) -> list[RetrievalResult]:
        """Map rows selected as ``(document, cmetadata, score)`` to RetrievalResult objects."""
        return [
            RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.score),
                source_type=source_type,
            )
            for row in rows
        ]

    async def _dense_search(
        self,
        embedding: list[float],
        user_id: str,
        limit: int,
        operator: str,
        source_filter_sql: str,
    ):
        """Run a pgvector nearest-neighbour query and return the raw rows, best first.

        Raises:
            ValueError: if *operator* is not one of ``<=>``, ``<#>``, ``<->``.
        """
        try:
            score_template = _SCORE_SQL_BY_OPERATOR[operator]
        except KeyError:
            raise ValueError(f"Unsupported pgvector operator: {operator!r}") from None

        # The embedding is bound once as a parameter instead of being pasted
        # into the SQL text. This keeps the statement text constant across
        # queries, so prepared statements can be reused.
        distance_sql = f"(lpe.embedding {operator} CAST(:embedding AS vector))"
        score_sql = score_template.format(distance=distance_sql)
        sql = text(f"""
            SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND {source_filter_sql}
            ORDER BY {distance_sql} ASC
            LIMIT :k
        """)
        params = {
            "embedding": self._vector_literal(embedding),
            "user_id": user_id,
            "k": limit,
        }
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, params)
            return result.fetchall()

    async def _fts_search(self, query: str, user_id: str, limit: int, source_filter_sql: str):
        """Run a PostgreSQL full-text query ranked by ``ts_rank`` and return the raw rows.

        The GIN index must be on the exact expression
        ``to_tsvector('english', document)`` for the ``@@`` filter to use it.
        """
        sql = text(f"""
            SELECT lpe.document, lpe.cmetadata,
                   ts_rank(to_tsvector('english', lpe.document),
                           plainto_tsquery('english', :query)) AS score
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND {source_filter_sql}
              AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
            ORDER BY score DESC
            LIMIT :k
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": limit})
            return result.fetchall()

    async def _search_db(
        self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
    ) -> list[RetrievalResult]:
        """Vector search over database chunks. Accepts a pre-computed embedding.

        Returns up to ``_DENSE_OVERFETCH * k`` candidates, best first, so that
        ``_dedup`` can still produce k unique entries.
        """
        rows = await self._dense_search(
            embedding, user_id, k * _DENSE_OVERFETCH, operator, _DATABASE_FILTER_SQL
        )
        return self._to_results(rows, "database")

    async def _search_tabular(
        self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
    ) -> list[RetrievalResult]:
        """Vector search over tabular document chunks. Accepts a pre-computed embedding.

        Like ``_search_db``, returns the full over-fetched candidate set.
        Truncating to k here would discard the extra rows that dedup needs.
        """
        rows = await self._dense_search(
            embedding, user_id, k * _DENSE_OVERFETCH, operator, _TABULAR_FILTER_SQL
        )
        return self._to_results(rows, "document")

    async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
        """Full-text search over DB schema chunks using PostgreSQL tsvector.

        Requires GIN index on langchain_pg_embedding.document (created by init_db.py).
        """
        rows = await self._fts_search(query, user_id, k, _DATABASE_FILTER_SQL)
        return self._to_results(rows, "database")

    async def _search_fts_tabular(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
        """Full-text search over tabular document chunks using PostgreSQL tsvector."""
        rows = await self._fts_search(query, user_id, k, _TABULAR_FILTER_SQL)
        return self._to_results(rows, "document")

    def _rrf_merge(
        self,
        *ranked_lists: list[RetrievalResult],
        k_rrf: int = 60,
        top_k: int = 5,
    ) -> list[RetrievalResult]:
        """Reciprocal Rank Fusion — combines ranked lists using rank positions only.

        Each entry contributes ``1 / (k_rrf + rank + 1)`` to its key's fused
        score. For each key, the result object kept is the one with the
        highest raw leg score. Its ``score`` attribute is that raw score, not
        the fused RRF score.
        """
        fused: dict[tuple, float] = {}
        best: dict[tuple, RetrievalResult] = {}
        for ranked in ranked_lists:
            for rank, result in enumerate(ranked):
                key = self._result_key(result)
                fused[key] = fused.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
                if key not in best or result.score > best[key].score:
                    best[key] = result
        ordered_keys = sorted(best, key=fused.__getitem__, reverse=True)
        return [best[key] for key in ordered_keys[:top_k]]

    def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
        """Deduplicate by ``_result_key``, keeping the highest score per key, sorted by score."""
        best: dict[tuple, RetrievalResult] = {}
        for r in results:
            key = self._result_key(r)
            if key not in best or r.score > best[key].score:
                best[key] = r
        return sorted(best.values(), key=lambda r: r.score, reverse=True)

    async def _dense_only(
        self, query: str, user_id: str, k: int, operator: str
    ) -> list[RetrievalResult]:
        """Embed once, run both dense legs in parallel, dedup, and keep the top k."""
        embedding = await self._embed_query(query)
        db_results, tabular_results = await asyncio.gather(
            self._search_db(embedding, user_id, k, operator),
            self._search_tabular(embedding, user_id, k, operator),
        )
        return self._dedup(db_results + tabular_results)[:k]

    # ------------------------------------------------------------------
    # Named strategies — one embed call each, legs run in parallel
    # ------------------------------------------------------------------

    async def dense_no_threshold(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Cosine similarity, no score cutoff — returns up to k chunks."""
        return await self._dense_only(query, user_id, k, "<=>")

    async def dense_dot(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Inner product similarity (<#>).

        For L2-normalized embeddings (OpenAI), ranking is identical to cosine.
        Score = raw inner product (not bounded to [0,1]).
        """
        return await self._dense_only(query, user_id, k, "<#>")

    async def dense_l2(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """L2 (Euclidean) distance similarity (<->).

        For L2-normalized embeddings (OpenAI), ranking order matches cosine.
        Score = 1 / (1 + l2_distance), bounded to (0, 1].
        """
        return await self._dense_only(query, user_id, k, "<->")

    async def hybrid(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """RRF merge of dense + FTS over both database and tabular sources.

        Embeds once, then runs all four legs (dense db, dense tabular, fts db,
        fts tabular) in a single asyncio.gather.
        """
        embedding = await self._embed_query(query)
        db_results, tabular_results, fts_db, fts_tabular = await asyncio.gather(
            self._search_db(embedding, user_id, k),
            self._search_tabular(embedding, user_id, k),
            self._search_fts_db(query, user_id, k * 4),
            self._search_fts_tabular(query, user_id, k * 4),
        )
        dense = self._dedup(db_results + tabular_results)[:k]
        fts_all = self._dedup(fts_db + fts_tabular)
        return self._rrf_merge(dense, fts_all, top_k=k)

    async def hybrid_bm25(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """RRF merge of dense (database + tabular) + FTS (database chunks only).

        Despite the name, the FTS leg is ranked with PostgreSQL ``ts_rank``,
        not BM25. Embeds once, then runs dense db, dense tabular, and fts db
        legs in parallel.
        """
        embedding = await self._embed_query(query)
        db_results, tabular_results, fts_results = await asyncio.gather(
            self._search_db(embedding, user_id, k),
            self._search_tabular(embedding, user_id, k),
            self._search_fts_db(query, user_id, k * 4),
        )
        dense = self._dedup(db_results + tabular_results)[:k]
        return self._rrf_merge(dense, self._dedup(fts_results), top_k=k)

    # ------------------------------------------------------------------
    # Public interface — called by the router
    # ------------------------------------------------------------------

    async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Run the strategy named by module-level ``ACTIVE_STRATEGY`` and log the result count."""
        strategy_fn = getattr(self, ACTIVE_STRATEGY)
        results = await strategy_fn(query, user_id, k)
        logger.info("schema retrieval", strategy=ACTIVE_STRATEGY, count=len(results))
        return results


# ------------------------------------------------------------------
# Benchmark helper — import in test scripts
# ------------------------------------------------------------------

async def benchmark(
    query: str,
    user_id: str,
    k: int = 5,
    strategies: list[Strategy] | None = None,
) -> dict[str, dict]:
    """Run multiple strategies against the same query and return timing + results.

    Args:
        query: Natural-language query to retrieve for.
        user_id: Owner whose chunks are searched.
        k: Number of results requested from each strategy.
        strategies: Strategy names to run. ``None`` or an empty list runs every
            strategy in ``Strategy``.

    Returns:
        Mapping of strategy name -> ``{"chunks", "estimated_tokens",
        "elapsed_ms", "results"}``. ``estimated_tokens`` is a rough
        ``chars // 4`` heuristic.
    """
    retriever = SchemaRetriever()
    targets: list[Strategy] = strategies or list(get_args(Strategy))
    report: dict[str, dict] = {}
    # Strategies run sequentially so their timings do not contend with each other.
    for name in targets:
        fn = getattr(retriever, name)
        t0 = time.perf_counter()
        chunks = await fn(query, user_id, k)
        elapsed_ms = round((time.perf_counter() - t0) * 1000)
        total_chars = sum(len(r.content) for r in chunks)
        report[name] = {
            "chunks": len(chunks),
            "estimated_tokens": total_chars // 4,
            "elapsed_ms": elapsed_ms,
            "results": chunks,
        }
    return report


schema_retriever = SchemaRetriever()