| """Schema retriever — handles DB schemas (source_type="database") and tabular file |
| columns stored as source_type="document" with file_type in ("csv","xlsx"). |
| |
| Multiple retrieval strategies are exposed for benchmarking. The active strategy |
| used by the router is `retrieve()`, which dispatches to ACTIVE_STRATEGY. |
| Change ACTIVE_STRATEGY at module level to switch without touching the router. |
| |
| All strategies embed the query exactly once, then fan out to parallel SQL legs. |
| |
| Vector distance strategies: |
| dense_no_threshold — cosine (<=>), no score floor, returns up to k chunks |
| dense_dot — inner product (<#>), ranks like cosine for normalized embeddings |
| dense_l2 — L2/euclidean (<->), monotonic with cosine on unit-sphere vectors |
| Hybrid strategies (dense + PostgreSQL full-text search, merged with RRF): |
| hybrid — FTS leg covers database + tabular chunks |
| hybrid_bm25 — FTS leg covers database chunks only (ranked with ts_rank, not true BM25) |
| """ |
|
|
| import asyncio |
| import time |
| from typing import Literal |
|
|
| from sqlalchemy import text |
|
|
| from src.db.postgres.connection import _pgvector_engine |
| from src.db.postgres.vector_store import get_vector_store |
| from src.middlewares.logging import get_logger |
| from src.rag.base import BaseRetriever, RetrievalResult |
|
|
# Module-level logger; "schema_retriever" is the name handed to the project's get_logger.
logger = get_logger("schema_retriever")
|
|
# File types whose document chunks count as tabular schema sources.
# NOTE(review): the SQL filters in this module spell out 'csv' / 'xlsx' literally
# rather than reading this tuple — keep the two in sync.
_TABULAR_FILE_TYPES = ("csv", "xlsx")

# Names of the retrieval strategies; each one is a method on SchemaRetriever.
Strategy = Literal[
    "dense_no_threshold",
    "dense_dot",
    "dense_l2",
    "hybrid",
    "hybrid_bm25",
]

# Strategy that SchemaRetriever.retrieve() dispatches to.
ACTIVE_STRATEGY: Strategy = "hybrid_bm25"
|
|
|
|
class SchemaRetriever(BaseRetriever):
    """Retrieve schema chunks for one user from the pgvector embedding tables.

    Two kinds of chunks are searched, both inside the ``document_embeddings``
    collection and both filtered on ``cmetadata->>'user_id'``:

    * database chunks — ``source_type = 'database'``
    * tabular chunks  — ``source_type = 'document'`` whose ``data.file_type``
      is ``csv`` or ``xlsx``

    Each public strategy method embeds the query once and then runs its SQL
    legs concurrently with ``asyncio.gather``. ``retrieve`` dispatches to the
    strategy named by the module-level ``ACTIVE_STRATEGY``.
    """

    # Supported pgvector distance operators mapped to the SQL expression that
    # turns the raw distance into a "higher is better" score. ``{vec}`` is
    # replaced with the query-vector literal.
    _SCORE_SQL_BY_OPERATOR = {
        "<=>": "1.0 - (lpe.embedding <=> {vec})",
        "<#>": "(lpe.embedding <#> {vec}) * -1",
        "<->": "1.0 / (1.0 + (lpe.embedding <-> {vec}))",
    }

    # Constant WHERE fragments selecting each chunk family. These are fixed
    # strings defined here, never caller input, so inlining them is safe.
    _DB_FILTER_SQL = "lpe.cmetadata->>'source_type' = 'database'"
    _TABULAR_FILTER_SQL = (
        "lpe.cmetadata->>'source_type' = 'document' "
        "AND (lpe.cmetadata->'data'->>'file_type' = 'csv' "
        "OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')"
    )

    def __init__(self):
        # Only the store's ``embeddings`` model is used by this class.
        self.vector_store = get_vector_store()

    # ------------------------------------------------------------------
    # Low-level helpers
    # ------------------------------------------------------------------

    async def _embed_query(self, query: str) -> list[float]:
        """Embed ``query`` in a worker thread so the event loop is not blocked."""
        return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)

    @staticmethod
    def _result_key(result: RetrievalResult) -> tuple:
        """Return the identity used for dedup/fusion.

        The key is ``(table_name, column_name)``, falling back to
        ``filename`` when ``column_name`` is missing or empty. All values are
        read from ``metadata["data"]``.
        """
        data = result.metadata.get("data", {})
        return (data.get("table_name"), data.get("column_name") or data.get("filename"))

    async def _dense_search(
        self,
        embedding: list[float],
        user_id: str,
        limit: int,
        operator: str,
        source_filter_sql: str,
        source_type: str,
    ) -> list[RetrievalResult]:
        """Run one vector-similarity query and wrap the rows as results.

        Args:
            embedding: Pre-computed query embedding.
            user_id: Bound as ``:user_id`` and matched against chunk metadata.
            limit: Bound as ``:k`` for the SQL ``LIMIT``.
            operator: pgvector distance operator; must be a key of
                ``_SCORE_SQL_BY_OPERATOR``.
            source_filter_sql: One of the class-level filter fragments.
            source_type: Value stored on each returned ``RetrievalResult``.

        Raises:
            ValueError: If ``operator`` is not a supported distance operator.
        """
        # The operator is spliced into the SQL text, so only whitelisted values
        # may pass. This also keeps the score expression and the ORDER BY in
        # agreement (an unknown operator used to be scored as cosine while
        # being ordered by the unknown operator).
        score_template = self._SCORE_SQL_BY_OPERATOR.get(operator)
        if score_template is None:
            supported = ", ".join(self._SCORE_SQL_BY_OPERATOR)
            raise ValueError(
                f"Unsupported vector operator {operator!r}; expected one of: {supported}"
            )

        # float() guarantees that only numeric text reaches the inlined literal.
        vec = "'[" + ",".join(str(float(x)) for x in embedding) + "]'::vector"
        score_sql = score_template.format(vec=vec)

        sql = text(f"""
            SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND {source_filter_sql}
            ORDER BY lpe.embedding {operator} {vec} ASC
            LIMIT :k
        """)

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"user_id": user_id, "k": limit})
            rows = result.fetchall()

        return [
            RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.score),
                source_type=source_type,
            )
            for row in rows
        ]

    async def _fts_search(
        self,
        query: str,
        user_id: str,
        k: int,
        source_filter_sql: str,
        source_type: str,
    ) -> list[RetrievalResult]:
        """Run one PostgreSQL full-text query (``ts_rank`` over ``to_tsvector``).

        ``query``, ``user_id`` and ``k`` are bound parameters; only the
        constant ``source_filter_sql`` fragment is inlined. The result score
        is the ``ts_rank`` value.
        """
        sql = text(f"""
            SELECT lpe.document, lpe.cmetadata,
                   ts_rank(to_tsvector('english', lpe.document),
                           plainto_tsquery('english', :query)) AS rank
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND {source_filter_sql}
              AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
            ORDER BY rank DESC
            LIMIT :k
        """)

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
            rows = result.fetchall()

        return [
            RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.rank),
                source_type=source_type,
            )
            for row in rows
        ]

    # ------------------------------------------------------------------
    # Search legs
    # ------------------------------------------------------------------

    async def _search_db(
        self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
    ) -> list[RetrievalResult]:
        """Vector search over database chunks. Accepts a pre-computed embedding.

        Fetches and returns up to ``k * 4`` rows; callers dedup and truncate.
        """
        return await self._dense_search(
            embedding, user_id, k * 4, operator, self._DB_FILTER_SQL, "database"
        )

    async def _search_tabular(
        self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
    ) -> list[RetrievalResult]:
        """Vector search over tabular document chunks. Accepts a pre-computed embedding.

        Fetches up to ``k * 4`` rows but returns only the first ``k``.
        """
        # NOTE(review): unlike _search_db, the over-fetched rows are dropped
        # before dedup. Behaviour is kept as-is; confirm which side is intended.
        results = await self._dense_search(
            embedding, user_id, k * 4, operator, self._TABULAR_FILTER_SQL, "document"
        )
        return results[:k]

    async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
        """Full-text search over database schema chunks using PostgreSQL tsvector.

        The original author notes that a GIN index on
        ``langchain_pg_embedding.document`` is expected (created by
        init_db.py); that cannot be verified from this module.
        """
        return await self._fts_search(query, user_id, k, self._DB_FILTER_SQL, "database")

    async def _search_fts_tabular(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
        """Full-text search over tabular document chunks using PostgreSQL tsvector."""
        return await self._fts_search(query, user_id, k, self._TABULAR_FILTER_SQL, "document")

    # ------------------------------------------------------------------
    # Merging
    # ------------------------------------------------------------------

    def _rrf_merge(
        self,
        *ranked_lists: list[RetrievalResult],
        k_rrf: int = 60,
        top_k: int = 5,
    ) -> list[RetrievalResult]:
        """Reciprocal Rank Fusion — combines ranked lists using rank positions only.

        Each occurrence of a key at 1-based position ``p`` adds
        ``1 / (k_rrf + p)`` to that key's fused score. For every key the
        result object with the highest original ``score`` is kept. Returned
        results still carry that original ``score``; the fused value is only
        used for ordering.
        """
        fused: dict[tuple, float] = {}
        best: dict[tuple, RetrievalResult] = {}

        for ranked in ranked_lists:
            for position, result in enumerate(ranked, start=1):
                key = self._result_key(result)
                fused[key] = fused.get(key, 0.0) + 1.0 / (k_rrf + position)
                if key not in best or result.score > best[key].score:
                    best[key] = result

        # sorted() is stable, so keys with equal fused scores stay in
        # first-seen order.
        ordered_keys = sorted(best, key=fused.__getitem__, reverse=True)
        return [best[key] for key in ordered_keys[:top_k]]

    def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
        """Keep the highest-scoring result per key, sorted by score descending.

        The key is ``(table_name, column_name)`` with ``filename`` used when
        ``column_name`` is absent.
        """
        best: dict[tuple, RetrievalResult] = {}
        for result in results:
            key = self._result_key(result)
            if key not in best or result.score > best[key].score:
                best[key] = result
        return sorted(best.values(), key=lambda r: r.score, reverse=True)

    # ------------------------------------------------------------------
    # Strategies
    # ------------------------------------------------------------------

    async def _dense_strategy(
        self, query: str, user_id: str, k: int, operator: str
    ) -> list[RetrievalResult]:
        """Embed once, search database and tabular chunks in parallel, dedup, keep ``k``."""
        embedding = await self._embed_query(query)
        db_results, tabular_results = await asyncio.gather(
            self._search_db(embedding, user_id, k, operator),
            self._search_tabular(embedding, user_id, k, operator),
        )
        return self._dedup(db_results + tabular_results)[:k]

    async def dense_no_threshold(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Cosine similarity (<=>) with no score cutoff; returns up to ``k`` chunks.

        Score = 1 - cosine_distance.
        """
        return await self._dense_strategy(query, user_id, k, "<=>")

    async def dense_dot(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Inner product similarity (<#>).

        Ranking matches cosine only when embeddings are L2-normalized (the
        original author states this holds for OpenAI embeddings).
        Score = raw inner product (not bounded to [0, 1]).
        """
        return await self._dense_strategy(query, user_id, k, "<#>")

    async def dense_l2(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """L2 (Euclidean) distance similarity (<->).

        Ranking matches cosine only when embeddings are L2-normalized (the
        original author states this holds for OpenAI embeddings).
        Score = 1 / (1 + l2_distance).
        """
        return await self._dense_strategy(query, user_id, k, "<->")

    async def hybrid(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """RRF merge of dense + FTS over both database and tabular sources.

        Embeds once, then runs all four legs (dense db, dense tabular, fts db,
        fts tabular) in a single asyncio.gather. Each FTS leg is asked for
        ``k * 4`` rows.
        """
        embedding = await self._embed_query(query)
        db_results, tabular_results, fts_db, fts_tabular = await asyncio.gather(
            self._search_db(embedding, user_id, k),
            self._search_tabular(embedding, user_id, k),
            self._search_fts_db(query, user_id, k * 4),
            self._search_fts_tabular(query, user_id, k * 4),
        )
        dense = self._dedup(db_results + tabular_results)[:k]
        fts_all = self._dedup(fts_db + fts_tabular)
        return self._rrf_merge(dense, fts_all, top_k=k)

    async def hybrid_bm25(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """RRF merge of dense (database + tabular) with FTS over database chunks only.

        Embeds once, then runs dense db, dense tabular, and fts db legs in
        parallel. Despite the name, the lexical leg is ranked with PostgreSQL
        ``ts_rank``, not BM25.
        """
        embedding = await self._embed_query(query)
        db_results, tabular_results, fts_results = await asyncio.gather(
            self._search_db(embedding, user_id, k),
            self._search_tabular(embedding, user_id, k),
            self._search_fts_db(query, user_id, k * 4),
        )
        dense = self._dedup(db_results + tabular_results)[:k]
        return self._rrf_merge(dense, self._dedup(fts_results), top_k=k)

    # ------------------------------------------------------------------
    # Router entry point
    # ------------------------------------------------------------------

    async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Run the strategy named by the module-level ``ACTIVE_STRATEGY`` and log the hit count."""
        strategy_fn = getattr(self, ACTIVE_STRATEGY)
        results = await strategy_fn(query, user_id, k)
        logger.info("schema retrieval", strategy=ACTIVE_STRATEGY, count=len(results))
        return results
|
|
|
|
| |
| |
| |
|
|
async def benchmark(
    query: str,
    user_id: str,
    k: int = 5,
    strategies: list[Strategy] | None = None,
) -> dict[str, dict]:
    """Time each requested strategy on one query and collect its output.

    Args:
        query: Text passed unchanged to every strategy.
        user_id: User filter passed unchanged to every strategy.
        k: Result count passed unchanged to every strategy.
        strategies: Strategy method names to run; when ``None`` or empty,
            all five strategies are run.

    Returns:
        A dict keyed by strategy name. Each value holds ``chunks`` (number of
        results), ``estimated_tokens`` (total content characters // 4),
        ``elapsed_ms`` (rounded wall-clock milliseconds for the whole call)
        and ``results`` (the results themselves).
    """
    retriever = SchemaRetriever()

    selected: list[Strategy] = (
        strategies
        if strategies
        else [
            "dense_no_threshold",
            "dense_dot",
            "dense_l2",
            "hybrid",
            "hybrid_bm25",
        ]
    )

    report: dict[str, dict] = {}
    # Strategies run one after another so each timing covers a single strategy.
    for strategy_name in selected:
        run_strategy = getattr(retriever, strategy_name)

        started = time.perf_counter()
        chunks = await run_strategy(query, user_id, k)
        elapsed_ms = round((time.perf_counter() - started) * 1000)

        report[strategy_name] = {
            "chunks": len(chunks),
            # Rough heuristic used by this module: four characters per token.
            "estimated_tokens": sum(len(chunk.content) for chunk in chunks) // 4,
            "elapsed_ms": elapsed_ms,
            "results": chunks,
        }

    return report
|
|
|
|
# Shared module-level instance, built once at import time.
# NOTE(review): presumably what the router imports — verify against callers.
schema_retriever = SchemaRetriever()
|
|