Rifqi Hafizuddin
[KM-438-439] add retriever feature
ba550a5
raw
history blame
14.6 kB
"""Schema retriever — handles DB schemas (source_type="database") and tabular file
columns stored as source_type="document" with file_type in ("csv","xlsx").
Multiple retrieval strategies are exposed for benchmarking. The active strategy
used by the router is `retrieve()`, which dispatches to ACTIVE_STRATEGY.
Change ACTIVE_STRATEGY at module level to switch without touching the router.
All strategies embed the query exactly once, then fan out to parallel SQL legs.
Retrieval strategies:
    dense_no_threshold — cosine (<=>), no score floor, returns up to k chunks
    dense_dot          — inner product (<#>), ranks the same as cosine for normalized embeddings
    dense_l2           — L2/euclidean (<->), ranks the same as cosine on unit-sphere vectors
    hybrid             — RRF merge of dense + FTS (both over database + tabular)
    hybrid_bm25        — RRF merge of dense (database + tabular) + FTS (database only)
"""
import asyncio
import time
from typing import Literal
from sqlalchemy import text
from src.db.postgres.connection import _pgvector_engine
from src.db.postgres.vector_store import get_vector_store
from src.middlewares.logging import get_logger
from src.rag.base import BaseRetriever, RetrievalResult
# Module logger for this retriever.
logger = get_logger("schema_retriever")
# File types treated as tabular documents.
# NOTE(review): not referenced in the code visible here — the SQL filters
# hard-code 'csv' / 'xlsx'; confirm whether this constant is used elsewhere.
_TABULAR_FILE_TYPES = ("csv", "xlsx")
# Closed set of strategy names; each one is looked up by name as a
# SchemaRetriever method (via getattr) in retrieve() and benchmark().
Strategy = Literal["dense_no_threshold", "dense_dot", "dense_l2", "hybrid", "hybrid_bm25"]
# Strategy that SchemaRetriever.retrieve() dispatches to.
ACTIVE_STRATEGY: Strategy = "hybrid_bm25"
class SchemaRetriever(BaseRetriever):
    """Retrieve schema chunks for a user from the ``document_embeddings`` collection.

    Two kinds of chunks are searched:

    * database chunks  — ``cmetadata.source_type == 'database'``
    * tabular chunks   — ``cmetadata.source_type == 'document'`` with
      ``cmetadata.data.file_type`` in ``('csv', 'xlsx')``

    Every named strategy embeds the query once and runs its SQL legs
    concurrently with ``asyncio.gather``. ``retrieve()`` dispatches to the
    strategy named by the module-level ``ACTIVE_STRATEGY``.
    """

    # pgvector distance operator -> template converting the distance expression
    # into a "higher is better" score. The keys double as the whitelist of
    # operators that may be interpolated into SQL.
    _SCORE_SQL = {
        "<=>": "1.0 - ({dist})",          # cosine distance -> cosine similarity
        "<#>": "({dist}) * -1",           # negative inner product -> inner product
        "<->": "1.0 / (1.0 + ({dist}))",  # L2 distance -> (0, 1]
    }

    # WHERE-clause fragments selecting each chunk family (static SQL, no user input).
    _DB_FILTER = "AND lpe.cmetadata->>'source_type' = 'database'"
    _TABULAR_FILTER = (
        "AND lpe.cmetadata->>'source_type' = 'document'\n"
        "              AND (lpe.cmetadata->'data'->>'file_type' = 'csv'\n"
        "                   OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')"
    )

    def __init__(self):
        self.vector_store = get_vector_store()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    async def _embed_query(self, query: str) -> list[float]:
        """Embed *query* in a worker thread so the event loop is not blocked."""
        return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)

    @staticmethod
    def _result_key(result: RetrievalResult) -> tuple:
        """Return the dedup key ``(table_name, column_name or filename)``.

        Tolerates missing/None ``metadata`` or ``metadata['data']`` (both are
        treated as empty) instead of raising ``AttributeError``.
        """
        data = (result.metadata or {}).get("data") or {}
        return (data.get("table_name"), data.get("column_name") or data.get("filename"))

    async def _dense_search(
        self,
        embedding: list[float],
        user_id: str,
        k: int,
        operator: str,
        source_filter: str,
        source_type: str,
    ) -> list[RetrievalResult]:
        """Run one vector-distance query and return up to ``k * 4`` rows.

        Args:
            embedding: Pre-computed query embedding.
            user_id: Only chunks whose ``cmetadata.user_id`` matches are searched.
            k: Requested result count; the SQL LIMIT is ``k * 4`` (over-fetch).
            operator: pgvector distance operator: ``<=>``, ``<#>`` or ``<->``.
            source_filter: Static WHERE fragment selecting the chunk family.
            source_type: ``source_type`` stamped on each returned result.

        Raises:
            ValueError: If *operator* is not one of the supported operators.
        """
        # The operator is spliced into the SQL text, so it must be whitelisted.
        if operator not in self._SCORE_SQL:
            raise ValueError(
                f"Unsupported vector operator {operator!r}; "
                f"expected one of {sorted(self._SCORE_SQL)}"
            )
        # The vector literal is built only from the embedding's numeric values.
        emb_literal = "[" + ",".join(str(x) for x in embedding) + "]"
        distance = f"lpe.embedding {operator} '{emb_literal}'::vector"
        score_sql = self._SCORE_SQL[operator].format(dist=distance)
        sql = text(f"""
            SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              {source_filter}
            ORDER BY {distance} ASC
            LIMIT :k
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
            rows = result.fetchall()
        return [
            RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.score),
                source_type=source_type,
            )
            for row in rows
        ]

    async def _search_db(
        self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
    ) -> list[RetrievalResult]:
        """Vector search over database chunks. Accepts a pre-computed embedding.

        Returns every fetched row (up to ``k * 4``); callers dedup and slice.
        """
        return await self._dense_search(
            embedding, user_id, k, operator, self._DB_FILTER, "database"
        )

    async def _search_tabular(
        self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
    ) -> list[RetrievalResult]:
        """Vector search over tabular document chunks. Accepts a pre-computed embedding.

        Returns at most ``k`` results (the nearest ones).
        """
        results = await self._dense_search(
            embedding, user_id, k, operator, self._TABULAR_FILTER, "document"
        )
        # NOTE(review): the query over-fetches k * 4 rows but only the first k
        # are kept, before any dedup — unlike _search_db. Behaviour preserved
        # as-is; confirm whether the truncation is intended.
        return results[:k]

    async def _fts_search(
        self, query: str, user_id: str, k: int, source_filter: str, source_type: str
    ) -> list[RetrievalResult]:
        """Run one PostgreSQL full-text query ranked by ``ts_rank`` (descending).

        Only rows whose English tsvector matches ``plainto_tsquery(query)`` are
        returned, limited to *k*; the rank becomes the result score.
        """
        sql = text(f"""
            SELECT lpe.document, lpe.cmetadata,
                   ts_rank(to_tsvector('english', lpe.document),
                           plainto_tsquery('english', :query)) AS rank
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              {source_filter}
              AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
            ORDER BY rank DESC
            LIMIT :k
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
            rows = result.fetchall()
        return [
            RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.rank),
                source_type=source_type,
            )
            for row in rows
        ]

    async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
        """Full-text search over DB schema chunks using PostgreSQL tsvector.

        NOTE(review): the original note says this requires a GIN index on
        langchain_pg_embedding.document created by init_db.py — not verifiable
        from this file.
        """
        return await self._fts_search(query, user_id, k, self._DB_FILTER, "database")

    async def _search_fts_tabular(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
        """Full-text search over tabular document chunks using PostgreSQL tsvector."""
        return await self._fts_search(query, user_id, k, self._TABULAR_FILTER, "document")

    def _rrf_merge(
        self,
        *ranked_lists: list[RetrievalResult],
        k_rrf: int = 60,
        top_k: int = 5,
    ) -> list[RetrievalResult]:
        """Reciprocal Rank Fusion — combines ranked lists using rank positions only.

        Each occurrence of a key at 0-based position ``rank`` adds
        ``1 / (k_rrf + rank + 1)`` to that key's fused score. For every key the
        result with the highest original ``score`` is kept as its
        representative. Returns the ``top_k`` representatives ordered by fused
        score, descending; their ``score`` attribute is left untouched.
        """
        fused: dict[tuple, float] = {}
        best: dict[tuple, RetrievalResult] = {}
        for ranked in ranked_lists:
            for rank, result in enumerate(ranked):
                key = self._result_key(result)
                fused[key] = fused.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
                if key not in best or result.score > best[key].score:
                    best[key] = result
        # Stable sort: keys with equal fused scores keep first-seen order.
        ordered_keys = sorted(best, key=fused.__getitem__, reverse=True)
        return [best[key] for key in ordered_keys[:top_k]]

    def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
        """Deduplicate by (table_name, column_name or filename).

        Keeps the highest-scoring result per key and returns the survivors
        sorted by score, descending.
        """
        seen: dict[tuple, RetrievalResult] = {}
        for r in results:
            key = self._result_key(r)
            if key not in seen or r.score > seen[key].score:
                seen[key] = r
        return sorted(seen.values(), key=lambda r: r.score, reverse=True)

    async def _dense_strategy(
        self, query: str, user_id: str, k: int, operator: str
    ) -> list[RetrievalResult]:
        """Embed once, run the DB and tabular dense legs in parallel, dedup, keep top k."""
        embedding = await self._embed_query(query)
        db_results, tabular_results = await asyncio.gather(
            self._search_db(embedding, user_id, k, operator),
            self._search_tabular(embedding, user_id, k, operator),
        )
        return self._dedup(db_results + tabular_results)[:k]

    # ------------------------------------------------------------------
    # Named strategies — one embed call each, legs run in parallel
    # ------------------------------------------------------------------
    async def dense_no_threshold(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Cosine similarity (<=>), no score cutoff — returns up to k chunks."""
        return await self._dense_strategy(query, user_id, k, "<=>")

    async def dense_dot(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Inner product similarity (<#>).

        For L2-normalized embeddings, ranking is identical to cosine.
        Score = raw inner product (not bounded to [0,1]).
        """
        return await self._dense_strategy(query, user_id, k, "<#>")

    async def dense_l2(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """L2 (Euclidean) distance similarity (<->).

        For L2-normalized embeddings, ranking order matches cosine.
        Score = 1 / (1 + l2_distance), bounded to (0, 1].
        """
        return await self._dense_strategy(query, user_id, k, "<->")

    async def hybrid(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """RRF merge of dense + FTS over both database and tabular sources.

        Embeds once, then runs all four legs (dense db, dense tabular, fts db,
        fts tabular) in a single asyncio.gather. The FTS legs fetch up to
        ``k * 4`` rows each.
        """
        embedding = await self._embed_query(query)
        db_results, tabular_results, fts_db, fts_tabular = await asyncio.gather(
            self._search_db(embedding, user_id, k),
            self._search_tabular(embedding, user_id, k),
            self._search_fts_db(query, user_id, k * 4),
            self._search_fts_tabular(query, user_id, k * 4),
        )
        dense = self._dedup(db_results + tabular_results)[:k]
        fts_all = self._dedup(fts_db + fts_tabular)
        return self._rrf_merge(dense, fts_all, top_k=k)

    async def hybrid_bm25(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """RRF merge of dense (database + tabular) + FTS (database chunks only).

        Embeds once, then runs dense db, dense tabular, and fts db legs in parallel.
        """
        embedding = await self._embed_query(query)
        db_results, tabular_results, fts_results = await asyncio.gather(
            self._search_db(embedding, user_id, k),
            self._search_tabular(embedding, user_id, k),
            self._search_fts_db(query, user_id, k * 4),
        )
        dense = self._dedup(db_results + tabular_results)[:k]
        return self._rrf_merge(dense, self._dedup(fts_results), top_k=k)

    # ------------------------------------------------------------------
    # Public interface — called by the router
    # ------------------------------------------------------------------
    async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Run the strategy named by ``ACTIVE_STRATEGY`` and log the result count."""
        strategy_fn = getattr(self, ACTIVE_STRATEGY)
        results = await strategy_fn(query, user_id, k)
        logger.info("schema retrieval", strategy=ACTIVE_STRATEGY, count=len(results))
        return results
# ------------------------------------------------------------------
# Benchmark helper — import in test scripts
# ------------------------------------------------------------------
async def benchmark(
    query: str,
    user_id: str,
    k: int = 5,
    strategies: list[Strategy] | None = None,
) -> dict[str, dict]:
    """Time each requested strategy on one query and collect a per-strategy report.

    Strategies run one after another on a fresh ``SchemaRetriever``. When
    *strategies* is None or empty, all five named strategies are run. Each
    report entry holds the chunk count, a rough token estimate (total content
    characters // 4), the elapsed wall time in whole milliseconds, and the
    returned chunks themselves.
    """
    runner = SchemaRetriever()
    all_strategies: list[Strategy] = [
        "dense_no_threshold",
        "dense_dot",
        "dense_l2",
        "hybrid",
        "hybrid_bm25",
    ]
    selected: list[Strategy] = strategies if strategies else all_strategies

    summary: dict[str, dict] = {}
    for strategy_name in selected:
        strategy = getattr(runner, strategy_name)

        started = time.perf_counter()
        hits = await strategy(query, user_id, k)
        duration_ms = round((time.perf_counter() - started) * 1000)

        char_count = 0
        for hit in hits:
            char_count += len(hit.content)

        summary[strategy_name] = dict(
            chunks=len(hits),
            estimated_tokens=char_count // 4,
            elapsed_ms=duration_ms,
            results=hits,
        )
    return summary
# Module-level shared instance. Constructing it runs SchemaRetriever.__init__,
# which calls get_vector_store() at import time.
schema_retriever = SchemaRetriever()