Rifqi Hafizuddin commited on
Commit ·
ba550a5
1
Parent(s): 767625e
[KM-438-439] add retriever feature
Browse files- src/query/__init__.py +0 -0
- src/query/base.py +32 -0
- src/query/executors/__init__.py +0 -0
- src/query/executors/db_executor.py +409 -0
- src/query/executors/tabular.py +39 -0
- src/query/query_executor.py +52 -0
- src/rag/base.py +20 -0
- src/rag/retriever.py +22 -48
- src/rag/retrievers/__init__.py +0 -0
- src/rag/retrievers/baseline.py +70 -0
- src/rag/retrievers/document.py +32 -0
- src/rag/retrievers/schema.py +349 -0
- src/rag/router.py +75 -0
src/query/__init__.py
ADDED
|
File without changes
|
src/query/base.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared contract for query executors."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
|
| 6 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 7 |
+
|
| 8 |
+
from src.rag.base import RetrievalResult
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class QueryResult:
|
| 13 |
+
source_type: str # "database" or "document"
|
| 14 |
+
source_id: str # database_client_id or document_id
|
| 15 |
+
table_or_file: str
|
| 16 |
+
columns: list[str]
|
| 17 |
+
rows: list[dict]
|
| 18 |
+
row_count: int
|
| 19 |
+
metadata: dict = field(default_factory=dict)
|
| 20 |
+
# metadata should include "column_types": {"col_name": "dtype"} when available
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class BaseExecutor(ABC):
|
| 24 |
+
@abstractmethod
|
| 25 |
+
async def execute(
|
| 26 |
+
self,
|
| 27 |
+
results: list[RetrievalResult],
|
| 28 |
+
user_id: str,
|
| 29 |
+
db: AsyncSession,
|
| 30 |
+
question: str,
|
| 31 |
+
limit: int = 100,
|
| 32 |
+
) -> list[QueryResult]: ...
|
src/query/executors/__init__.py
ADDED
|
File without changes
|
src/query/executors/db_executor.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Executor for registered database sources (source_type="database").
|
| 2 |
+
|
| 3 |
+
Flow per (client_id, question):
|
| 4 |
+
1. Collect all relevant (table_name, column_name) pairs from retrieval results.
|
| 5 |
+
2. Fetch the FULL schema for those tables from PGVector (not just top-k columns).
|
| 6 |
+
3. Build a schema context string and send to LLM → structured SQLQuery output.
|
| 7 |
+
4. Validate via sqlglot: SELECT-only, schema-grounded, LIMIT enforced.
|
| 8 |
+
5. Execute on the user's DB via engine_scope + asyncio.to_thread.
|
| 9 |
+
6. Return QueryResult per client_id (may span multiple tables via JOINs).
|
| 10 |
+
|
| 11 |
+
Supported db_types: postgres, supabase, mysql.
|
| 12 |
+
Other types are skipped with a warning — they do not raise.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import asyncio
|
| 16 |
+
from collections import defaultdict
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
import sqlglot
|
| 20 |
+
import sqlglot.expressions as exp
|
| 21 |
+
import tiktoken
|
| 22 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 23 |
+
from langchain_openai import AzureChatOpenAI
|
| 24 |
+
from sqlalchemy import text
|
| 25 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 26 |
+
|
| 27 |
+
from src.config.settings import settings
|
| 28 |
+
from src.database_client.database_client_service import database_client_service
|
| 29 |
+
from src.db.postgres.connection import _pgvector_engine
|
| 30 |
+
from src.middlewares.logging import get_logger
|
| 31 |
+
from src.models.sql_query import SQLQuery
|
| 32 |
+
from src.pipeline.db_pipeline import db_pipeline_service
|
| 33 |
+
from src.query.base import BaseExecutor, QueryResult
|
| 34 |
+
from src.rag.base import RetrievalResult
|
| 35 |
+
from src.utils.db_credential_encryption import decrypt_credentials_dict
|
| 36 |
+
|
| 37 |
+
logger = get_logger("db_executor")
|
| 38 |
+
|
| 39 |
+
_enc = tiktoken.get_encoding("cl100k_base")
|
| 40 |
+
|
| 41 |
+
_SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"}
|
| 42 |
+
_MAX_RETRIES = 3
|
| 43 |
+
_MAX_LIMIT = 500
|
| 44 |
+
|
| 45 |
+
_SQL_SYSTEM_PROMPT = """\
|
| 46 |
+
You are a SQL data analyst working with a user's database.
|
| 47 |
+
Generate a single SQL SELECT statement that answers the user's question.
|
| 48 |
+
|
| 49 |
+
Database dialect: {dialect}
|
| 50 |
+
|
| 51 |
+
Rules:
|
| 52 |
+
- ONLY reference tables and columns listed in the schema below. Do not invent names.
|
| 53 |
+
- Always include a LIMIT clause (max {limit}).
|
| 54 |
+
- Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
|
| 55 |
+
- Prefer explicit JOINs over subqueries when combining tables.
|
| 56 |
+
- For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
|
| 57 |
+
- For date filtering, use dialect-appropriate functions ({dialect} syntax).
|
| 58 |
+
|
| 59 |
+
Schema:
|
| 60 |
+
{schema}
|
| 61 |
+
|
| 62 |
+
{error_section}"""
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class DbExecutor(BaseExecutor):
|
| 66 |
+
def __init__(self) -> None:
|
| 67 |
+
self._llm = AzureChatOpenAI(
|
| 68 |
+
azure_deployment=settings.azureai_deployment_name_4o,
|
| 69 |
+
openai_api_version=settings.azureai_api_version_4o,
|
| 70 |
+
azure_endpoint=settings.azureai_endpoint_url_4o,
|
| 71 |
+
api_key=settings.azureai_api_key_4o,
|
| 72 |
+
temperature=0,
|
| 73 |
+
)
|
| 74 |
+
self._prompt = ChatPromptTemplate.from_messages([
|
| 75 |
+
("system", _SQL_SYSTEM_PROMPT),
|
| 76 |
+
("human", "{question}"),
|
| 77 |
+
])
|
| 78 |
+
self._chain = self._prompt | self._llm.with_structured_output(SQLQuery)
|
| 79 |
+
|
| 80 |
+
# ------------------------------------------------------------------
|
| 81 |
+
# Public interface
|
| 82 |
+
# ------------------------------------------------------------------
|
| 83 |
+
|
| 84 |
+
async def execute(
|
| 85 |
+
self,
|
| 86 |
+
results: list[RetrievalResult],
|
| 87 |
+
user_id: str,
|
| 88 |
+
db: AsyncSession,
|
| 89 |
+
question: str,
|
| 90 |
+
limit: int = 100,
|
| 91 |
+
) -> list[QueryResult]:
|
| 92 |
+
db_results = [r for r in results if r.source_type == "database"]
|
| 93 |
+
if not db_results:
|
| 94 |
+
return []
|
| 95 |
+
|
| 96 |
+
# Group by client_id — one SQL generation + execution pass per client
|
| 97 |
+
by_client: dict[str, list[RetrievalResult]] = defaultdict(list)
|
| 98 |
+
for r in db_results:
|
| 99 |
+
client_id = r.metadata.get("database_client_id", "")
|
| 100 |
+
if client_id:
|
| 101 |
+
by_client[client_id].append(r)
|
| 102 |
+
else:
|
| 103 |
+
logger.warning("db result missing database_client_id, skipping")
|
| 104 |
+
|
| 105 |
+
query_results: list[QueryResult] = []
|
| 106 |
+
for client_id, client_results in by_client.items():
|
| 107 |
+
try:
|
| 108 |
+
qr = await self._execute_for_client(client_id, client_results, user_id, db, question, limit)
|
| 109 |
+
if qr:
|
| 110 |
+
query_results.append(qr)
|
| 111 |
+
except Exception as e:
|
| 112 |
+
logger.error("db executor failed for client", client_id=client_id, error=str(e))
|
| 113 |
+
|
| 114 |
+
return query_results
|
| 115 |
+
|
| 116 |
+
# ------------------------------------------------------------------
|
| 117 |
+
# Per-client execution
|
| 118 |
+
# ------------------------------------------------------------------
|
| 119 |
+
|
| 120 |
+
async def _execute_for_client(
|
| 121 |
+
self,
|
| 122 |
+
client_id: str,
|
| 123 |
+
results: list[RetrievalResult],
|
| 124 |
+
user_id: str,
|
| 125 |
+
db: AsyncSession,
|
| 126 |
+
question: str,
|
| 127 |
+
limit: int,
|
| 128 |
+
) -> QueryResult | None:
|
| 129 |
+
client = await database_client_service.get(db, client_id)
|
| 130 |
+
if not client:
|
| 131 |
+
logger.warning("database client not found", client_id=client_id)
|
| 132 |
+
return None
|
| 133 |
+
if client.user_id != user_id:
|
| 134 |
+
logger.warning("client ownership mismatch", client_id=client_id)
|
| 135 |
+
return None
|
| 136 |
+
if client.db_type not in _SUPPORTED_DB_TYPES:
|
| 137 |
+
logger.warning("unsupported db_type for query execution", db_type=client.db_type)
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
# Distinct table names from retrieval results, expanded via FK relationships
|
| 141 |
+
table_names = list({
|
| 142 |
+
r.metadata.get("data", {}).get("table_name")
|
| 143 |
+
for r in results
|
| 144 |
+
if r.metadata.get("data", {}).get("table_name")
|
| 145 |
+
})
|
| 146 |
+
table_names = await self._expand_with_fk_tables(client_id, user_id, table_names)
|
| 147 |
+
|
| 148 |
+
full_schema = await self._fetch_full_schema(client_id, table_names, user_id)
|
| 149 |
+
if not full_schema:
|
| 150 |
+
logger.warning("no schema found in vector store", client_id=client_id, tables=table_names)
|
| 151 |
+
return None
|
| 152 |
+
|
| 153 |
+
schema_ctx = self._build_schema_context(full_schema)
|
| 154 |
+
capped_limit = min(limit, _MAX_LIMIT)
|
| 155 |
+
dialect = client.db_type
|
| 156 |
+
|
| 157 |
+
# SQL generation with retry
|
| 158 |
+
validated_sql: str | None = None
|
| 159 |
+
prev_error: str = ""
|
| 160 |
+
prev_reasoning: str = ""
|
| 161 |
+
for attempt in range(_MAX_RETRIES):
|
| 162 |
+
if prev_error:
|
| 163 |
+
error_section = (
|
| 164 |
+
f"Previous attempt reasoning: {prev_reasoning}\n"
|
| 165 |
+
f"Previous attempt failed: {prev_error}\n"
|
| 166 |
+
"Fix the issue above."
|
| 167 |
+
)
|
| 168 |
+
else:
|
| 169 |
+
error_section = ""
|
| 170 |
+
try:
|
| 171 |
+
prompt_text = schema_ctx + error_section + question
|
| 172 |
+
input_tokens = len(_enc.encode(prompt_text))
|
| 173 |
+
logger.info("sql generation input tokens", attempt=attempt + 1, tokens=input_tokens)
|
| 174 |
+
|
| 175 |
+
result: SQLQuery = await self._chain.ainvoke({
|
| 176 |
+
"schema": schema_ctx,
|
| 177 |
+
"dialect": dialect,
|
| 178 |
+
"limit": capped_limit,
|
| 179 |
+
"error_section": error_section,
|
| 180 |
+
"question": question,
|
| 181 |
+
})
|
| 182 |
+
sql = result.sql.strip()
|
| 183 |
+
validation_error = self._validate(sql, full_schema, capped_limit)
|
| 184 |
+
if validation_error:
|
| 185 |
+
prev_error = validation_error
|
| 186 |
+
prev_reasoning = result.reasoning
|
| 187 |
+
logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
|
| 188 |
+
continue
|
| 189 |
+
validated_sql = self._enforce_limit(sql, capped_limit)
|
| 190 |
+
output_tokens = len(_enc.encode(result.sql)) + len(_enc.encode(result.reasoning))
|
| 191 |
+
logger.info(
|
| 192 |
+
"sql generated",
|
| 193 |
+
attempt=attempt + 1,
|
| 194 |
+
input_tokens=input_tokens,
|
| 195 |
+
output_tokens=output_tokens,
|
| 196 |
+
total_tokens=input_tokens + output_tokens,
|
| 197 |
+
reasoning=result.reasoning,
|
| 198 |
+
)
|
| 199 |
+
break
|
| 200 |
+
except Exception as e:
|
| 201 |
+
prev_error = str(e)
|
| 202 |
+
logger.warning("sql generation error", attempt=attempt + 1, error=prev_error)
|
| 203 |
+
|
| 204 |
+
if not validated_sql:
|
| 205 |
+
logger.error("sql generation failed after retries", client_id=client_id)
|
| 206 |
+
return None
|
| 207 |
+
|
| 208 |
+
# Execute on user's DB
|
| 209 |
+
creds = decrypt_credentials_dict(client.credentials)
|
| 210 |
+
with db_pipeline_service.engine_scope(client.db_type, creds) as engine:
|
| 211 |
+
rows = await asyncio.to_thread(self._run_sql, engine, validated_sql)
|
| 212 |
+
|
| 213 |
+
column_types = {
|
| 214 |
+
col["name"]: col["type"]
|
| 215 |
+
for cols in full_schema.values()
|
| 216 |
+
for col in cols
|
| 217 |
+
}
|
| 218 |
+
columns = list(rows[0].keys()) if rows else []
|
| 219 |
+
|
| 220 |
+
return QueryResult(
|
| 221 |
+
source_type="database",
|
| 222 |
+
source_id=client_id,
|
| 223 |
+
table_or_file=", ".join(table_names),
|
| 224 |
+
columns=columns,
|
| 225 |
+
rows=rows,
|
| 226 |
+
row_count=len(rows),
|
| 227 |
+
metadata={
|
| 228 |
+
"db_type": client.db_type,
|
| 229 |
+
"client_name": client.name,
|
| 230 |
+
"sql": validated_sql,
|
| 231 |
+
"column_types": {c: column_types.get(c, "unknown") for c in columns},
|
| 232 |
+
},
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
# ------------------------------------------------------------------
|
| 236 |
+
# Schema helpers
|
| 237 |
+
# ------------------------------------------------------------------
|
| 238 |
+
|
| 239 |
+
async def _expand_with_fk_tables(
|
| 240 |
+
self,
|
| 241 |
+
client_id: str,
|
| 242 |
+
user_id: str,
|
| 243 |
+
table_names: list[str],
|
| 244 |
+
) -> list[str]:
|
| 245 |
+
"""Expand table_names with any tables FK-referenced by the retrieved tables.
|
| 246 |
+
|
| 247 |
+
Prevents SQL generation failures when a required table (e.g. orders) wasn't
|
| 248 |
+
returned by retrieval but is referenced via FK from a table that was
|
| 249 |
+
(e.g. order_items.order_id -> orders.id).
|
| 250 |
+
"""
|
| 251 |
+
if not table_names:
|
| 252 |
+
return table_names
|
| 253 |
+
|
| 254 |
+
placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
|
| 255 |
+
sql = text(f"""
|
| 256 |
+
SELECT DISTINCT lpe.cmetadata->'data'->>'foreign_key' AS fk
|
| 257 |
+
FROM langchain_pg_embedding lpe
|
| 258 |
+
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
|
| 259 |
+
WHERE lpc.name = 'document_embeddings'
|
| 260 |
+
AND lpe.cmetadata->>'user_id' = :user_id
|
| 261 |
+
AND lpe.cmetadata->>'source_type' = 'database'
|
| 262 |
+
AND lpe.cmetadata->>'database_client_id' = :client_id
|
| 263 |
+
AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
|
| 264 |
+
AND lpe.cmetadata->'data'->>'foreign_key' IS NOT NULL
|
| 265 |
+
""")
|
| 266 |
+
|
| 267 |
+
params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
|
| 268 |
+
for i, name in enumerate(table_names):
|
| 269 |
+
params[f"t{i}"] = name
|
| 270 |
+
|
| 271 |
+
async with _pgvector_engine.connect() as conn:
|
| 272 |
+
result = await conn.execute(sql, params)
|
| 273 |
+
rows = result.fetchall()
|
| 274 |
+
|
| 275 |
+
expanded = set(table_names)
|
| 276 |
+
for row in rows:
|
| 277 |
+
fk = row.fk # format: "referred_table.referred_column"
|
| 278 |
+
if fk:
|
| 279 |
+
referred_table = fk.split(".")[0]
|
| 280 |
+
expanded.add(referred_table)
|
| 281 |
+
|
| 282 |
+
if expanded != set(table_names):
|
| 283 |
+
logger.info(
|
| 284 |
+
"expanded tables via FK",
|
| 285 |
+
original=sorted(table_names),
|
| 286 |
+
expanded=sorted(expanded),
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
return list(expanded)
|
| 290 |
+
|
| 291 |
+
async def _fetch_full_schema(
|
| 292 |
+
self,
|
| 293 |
+
client_id: str,
|
| 294 |
+
table_names: list[str],
|
| 295 |
+
user_id: str,
|
| 296 |
+
) -> dict[str, list[dict[str, Any]]]:
|
| 297 |
+
"""Fetch ALL column chunks for the given tables from PGVector.
|
| 298 |
+
|
| 299 |
+
Returns {table_name: [{"name": ..., "type": ..., "is_primary_key": ...,
|
| 300 |
+
"foreign_key": ..., "content": ...}]}
|
| 301 |
+
"""
|
| 302 |
+
placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
|
| 303 |
+
sql = text(f"""
|
| 304 |
+
SELECT lpe.cmetadata, lpe.document
|
| 305 |
+
FROM langchain_pg_embedding lpe
|
| 306 |
+
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
|
| 307 |
+
WHERE lpc.name = 'document_embeddings'
|
| 308 |
+
AND lpe.cmetadata->>'user_id' = :user_id
|
| 309 |
+
AND lpe.cmetadata->>'source_type' = 'database'
|
| 310 |
+
AND lpe.cmetadata->>'database_client_id' = :client_id
|
| 311 |
+
AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
|
| 312 |
+
ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name'
|
| 313 |
+
""")
|
| 314 |
+
|
| 315 |
+
params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
|
| 316 |
+
for i, name in enumerate(table_names):
|
| 317 |
+
params[f"t{i}"] = name
|
| 318 |
+
|
| 319 |
+
async with _pgvector_engine.connect() as conn:
|
| 320 |
+
result = await conn.execute(sql, params)
|
| 321 |
+
rows = result.fetchall()
|
| 322 |
+
|
| 323 |
+
schema: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
| 324 |
+
for row in rows:
|
| 325 |
+
data = row.cmetadata.get("data", {})
|
| 326 |
+
table = data.get("table_name")
|
| 327 |
+
if table:
|
| 328 |
+
schema[table].append({
|
| 329 |
+
"name": data.get("column_name", ""),
|
| 330 |
+
"type": data.get("column_type", ""),
|
| 331 |
+
"is_primary_key": data.get("is_primary_key", False),
|
| 332 |
+
"foreign_key": data.get("foreign_key"),
|
| 333 |
+
"content": row.document, # chunk text includes top values / samples
|
| 334 |
+
})
|
| 335 |
+
return dict(schema)
|
| 336 |
+
|
| 337 |
+
def _build_schema_context(self, schema: dict[str, list[dict[str, Any]]]) -> str:
|
| 338 |
+
lines: list[str] = []
|
| 339 |
+
for table, columns in schema.items():
|
| 340 |
+
lines.append(f"Table: {table}")
|
| 341 |
+
for col in columns:
|
| 342 |
+
flags = []
|
| 343 |
+
if col["is_primary_key"]:
|
| 344 |
+
flags.append("PRIMARY KEY")
|
| 345 |
+
if col["foreign_key"]:
|
| 346 |
+
flags.append(f"FK -> {col['foreign_key']}")
|
| 347 |
+
flag_str = f" [{', '.join(flags)}]" if flags else ""
|
| 348 |
+
lines.append(f" - {col['name']} {col['type']}{flag_str}")
|
| 349 |
+
# Include sample/top-values line from chunk content if present
|
| 350 |
+
for line in col["content"].splitlines():
|
| 351 |
+
if line.startswith(("Top values:", "Sample values:")):
|
| 352 |
+
lines.append(f" {line}")
|
| 353 |
+
break
|
| 354 |
+
lines.append("")
|
| 355 |
+
return "\n".join(lines).strip()
|
| 356 |
+
|
| 357 |
+
# ------------------------------------------------------------------
|
| 358 |
+
# Guardrails
|
| 359 |
+
# ------------------------------------------------------------------
|
| 360 |
+
|
| 361 |
+
def _validate(self, sql: str, schema: dict[str, list[dict]], limit: int) -> str:
|
| 362 |
+
"""Return an error string if validation fails, empty string if OK."""
|
| 363 |
+
# Layer 1: sqlglot parse + SELECT-only check
|
| 364 |
+
try:
|
| 365 |
+
parsed = sqlglot.parse_one(sql)
|
| 366 |
+
except sqlglot.errors.ParseError as e:
|
| 367 |
+
return f"SQL parse error: {e}"
|
| 368 |
+
|
| 369 |
+
if not isinstance(parsed, exp.Select):
|
| 370 |
+
return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"
|
| 371 |
+
|
| 372 |
+
# Check for DML anywhere in the AST (including writeable CTEs)
|
| 373 |
+
for node in parsed.find_all((exp.Insert, exp.Update, exp.Delete)):
|
| 374 |
+
return f"DML ({type(node).__name__}) is not allowed."
|
| 375 |
+
|
| 376 |
+
# Layer 2: schema grounding — table names
|
| 377 |
+
known_tables = {t.lower() for t in schema}
|
| 378 |
+
for tbl in parsed.find_all(exp.Table):
|
| 379 |
+
name = tbl.name.lower()
|
| 380 |
+
if name and name not in known_tables:
|
| 381 |
+
return f"Unknown table '{tbl.name}'. Only use tables from the schema."
|
| 382 |
+
|
| 383 |
+
# Layer 3: LIMIT enforcement (inject if missing — done before execution)
|
| 384 |
+
return ""
|
| 385 |
+
|
| 386 |
+
# ------------------------------------------------------------------
|
| 387 |
+
# SQL execution
|
| 388 |
+
# ------------------------------------------------------------------
|
| 389 |
+
|
| 390 |
+
def _enforce_limit(self, sql: str, limit: int) -> str:
|
| 391 |
+
"""Inject or cap LIMIT using sqlglot AST manipulation."""
|
| 392 |
+
parsed = sqlglot.parse_one(sql)
|
| 393 |
+
existing = parsed.find(exp.Limit)
|
| 394 |
+
if existing:
|
| 395 |
+
current = int(existing.expression.this)
|
| 396 |
+
if current > limit:
|
| 397 |
+
existing.expression.set("this", limit)
|
| 398 |
+
else:
|
| 399 |
+
parsed = parsed.limit(limit)
|
| 400 |
+
return parsed.sql()
|
| 401 |
+
|
| 402 |
+
def _run_sql(self, engine: Any, sql: str) -> list[dict]:
|
| 403 |
+
# Ensure the user DB connection is a read-only credential — sqlglot validation alone is not sufficient.
|
| 404 |
+
with engine.connect() as conn:
|
| 405 |
+
result = conn.execute(text(sql))
|
| 406 |
+
return [dict(row) for row in result.mappings()]
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
db_executor = DbExecutor()
|
src/query/executors/tabular.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Executor for tabular document sources (source_type="document", file_type csv/xlsx).
|
| 2 |
+
|
| 3 |
+
Flow:
|
| 4 |
+
1. Group RetrievalResult chunks by document_id.
|
| 5 |
+
2. For each document: download bytes from Azure Blob -> read with pandas.
|
| 6 |
+
3. Filter DataFrame to relevant columns identified by retrieval.
|
| 7 |
+
4. Return QueryResult per document.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 11 |
+
|
| 12 |
+
from src.middlewares.logging import get_logger
|
| 13 |
+
from src.query.base import BaseExecutor, QueryResult
|
| 14 |
+
from src.rag.base import RetrievalResult
|
| 15 |
+
|
| 16 |
+
logger = get_logger("tabular_executor")
|
| 17 |
+
|
| 18 |
+
_TABULAR_FILE_TYPES = ("csv", "xlsx")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TabularExecutor(BaseExecutor):
|
| 22 |
+
async def execute(
|
| 23 |
+
self,
|
| 24 |
+
results: list[RetrievalResult],
|
| 25 |
+
user_id: str,
|
| 26 |
+
db: AsyncSession,
|
| 27 |
+
limit: int = 100,
|
| 28 |
+
) -> list[QueryResult]:
|
| 29 |
+
# TODO: implement
|
| 30 |
+
# 1. filter results where source_type == "document" and file_type in _TABULAR_FILE_TYPES
|
| 31 |
+
# 2. group by document_id -> list of column_names
|
| 32 |
+
# 3. per group: look up Document by document_id -> get blob_name
|
| 33 |
+
# 4. blob_storage.download_file(blob_name) -> pd.read_csv / pd.read_excel
|
| 34 |
+
# 5. df[relevant_columns].head(limit) -> rows as list[dict]
|
| 35 |
+
# 6. return QueryResult per document
|
| 36 |
+
raise NotImplementedError
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
tabular_executor = TabularExecutor()
|
src/query/query_executor.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""QueryExecutor — dispatches retrieval results to the appropriate executor by source_type."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
|
| 5 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 6 |
+
|
| 7 |
+
from src.middlewares.logging import get_logger
|
| 8 |
+
from src.query.base import QueryResult
|
| 9 |
+
from src.query.executors.db_executor import db_executor
|
| 10 |
+
from src.query.executors.tabular import tabular_executor
|
| 11 |
+
from src.rag.base import RetrievalResult
|
| 12 |
+
|
| 13 |
+
logger = get_logger("query_executor")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class QueryExecutor:
|
| 17 |
+
async def execute(
|
| 18 |
+
self,
|
| 19 |
+
results: list[RetrievalResult],
|
| 20 |
+
user_id: str,
|
| 21 |
+
db: AsyncSession,
|
| 22 |
+
question: str,
|
| 23 |
+
limit: int = 100,
|
| 24 |
+
) -> list[QueryResult]:
|
| 25 |
+
db_results = [r for r in results if r.source_type == "database"]
|
| 26 |
+
tabular_results = [
|
| 27 |
+
r for r in results
|
| 28 |
+
if r.source_type == "document"
|
| 29 |
+
and r.metadata.get("data", {}).get("file_type") in ("csv", "xlsx")
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
async def _empty() -> list[QueryResult]:
|
| 33 |
+
return []
|
| 34 |
+
|
| 35 |
+
batches = await asyncio.gather(
|
| 36 |
+
db_executor.execute(db_results, user_id, db, question, limit) if db_results else _empty(),
|
| 37 |
+
tabular_executor.execute(tabular_results, user_id, db, question, limit) if tabular_results else _empty(),
|
| 38 |
+
return_exceptions=True,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
query_results: list[QueryResult] = []
|
| 42 |
+
for batch in batches:
|
| 43 |
+
if isinstance(batch, Exception):
|
| 44 |
+
logger.error("executor failed", error=str(batch))
|
| 45 |
+
continue
|
| 46 |
+
query_results.extend(batch)
|
| 47 |
+
|
| 48 |
+
logger.info("query execution complete", total=len(query_results))
|
| 49 |
+
return query_results
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
query_executor = QueryExecutor()
|
src/rag/base.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared contract for all retriever implementations."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
|
| 9 |
+
class RetrievalResult:
|
| 10 |
+
content: str
|
| 11 |
+
metadata: dict[str, Any]
|
| 12 |
+
score: float
|
| 13 |
+
source_type: str # "document" | "database"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseRetriever(ABC):
|
| 17 |
+
@abstractmethod
|
| 18 |
+
async def retrieve(
|
| 19 |
+
self, query: str, user_id: str, k: int = 5
|
| 20 |
+
) -> list[RetrievalResult]: ...
|
src/rag/retriever.py
CHANGED
|
@@ -1,69 +1,43 @@
|
|
| 1 |
-
"""
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
import hashlib
|
| 4 |
-
import json
|
| 5 |
-
from src.db.postgres.vector_store import get_vector_store
|
| 6 |
-
from src.db.redis.connection import get_redis
|
| 7 |
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
| 8 |
from src.middlewares.logging import get_logger
|
| 9 |
-
from
|
|
|
|
|
|
|
| 10 |
|
| 11 |
logger = get_logger("retriever")
|
| 12 |
|
| 13 |
-
_RETRIEVAL_CACHE_TTL = 3600 # 1 hour
|
| 14 |
-
|
| 15 |
|
| 16 |
class RetrieverService:
|
| 17 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def __init__(self):
|
| 20 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
async def retrieve(
|
| 23 |
self,
|
| 24 |
query: str,
|
| 25 |
user_id: str,
|
| 26 |
db: AsyncSession,
|
| 27 |
-
k: int = 5
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
Returns:
|
| 32 |
-
List of dicts with keys: content, metadata
|
| 33 |
-
metadata includes: document_id, user_id, filename, chunk_index, page_label (if PDF)
|
| 34 |
-
"""
|
| 35 |
try:
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
cache_key = f"retrieval:{user_id}:{query_hash}:{k}"
|
| 39 |
-
|
| 40 |
-
cached = await redis.get(cache_key)
|
| 41 |
-
if cached:
|
| 42 |
-
logger.info("Returning cached retrieval results")
|
| 43 |
-
return json.loads(cached)
|
| 44 |
-
|
| 45 |
-
logger.info(f"Retrieving for user {user_id}, query: {query[:50]}...")
|
| 46 |
-
|
| 47 |
-
docs = await self.vector_store.asimilarity_search(
|
| 48 |
-
query=query,
|
| 49 |
-
k=k,
|
| 50 |
-
filter={"user_id": user_id}
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
results = [
|
| 54 |
-
{
|
| 55 |
-
"content": doc.page_content,
|
| 56 |
-
"metadata": doc.metadata,
|
| 57 |
-
}
|
| 58 |
-
for doc in docs
|
| 59 |
-
]
|
| 60 |
-
|
| 61 |
-
logger.info(f"Retrieved {len(results)} chunks")
|
| 62 |
-
await redis.setex(cache_key, _RETRIEVAL_CACHE_TTL, json.dumps(results))
|
| 63 |
-
return results
|
| 64 |
-
|
| 65 |
except Exception as e:
|
| 66 |
-
logger.error("
|
| 67 |
return []
|
| 68 |
|
| 69 |
|
|
|
|
| 1 |
+
"""Public retrieval API — thin wrapper around RetrievalRouter."""
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 6 |
+
|
| 7 |
from src.middlewares.logging import get_logger
|
| 8 |
+
from src.rag.retrievers.document import document_retriever
|
| 9 |
+
from src.rag.retrievers.schema import schema_retriever
|
| 10 |
+
from src.rag.router import RetrievalRouter, SourceHint
|
| 11 |
|
| 12 |
logger = get_logger("retriever")
|
| 13 |
|
|
|
|
|
|
|
| 14 |
|
| 15 |
class RetrieverService:
|
| 16 |
+
"""Public retrieval service used by chat.py and search tools.
|
| 17 |
+
|
| 18 |
+
Delegates to RetrievalRouter which dispatches based on source_hint.
|
| 19 |
+
Returns List[Dict] to preserve backward compatibility with chat.py.
|
| 20 |
+
"""
|
| 21 |
|
| 22 |
def __init__(self):
|
| 23 |
+
self._router = RetrievalRouter(
|
| 24 |
+
schema_retriever=schema_retriever,
|
| 25 |
+
document_retriever=document_retriever,
|
| 26 |
+
)
|
| 27 |
|
| 28 |
async def retrieve(
|
| 29 |
self,
|
| 30 |
query: str,
|
| 31 |
user_id: str,
|
| 32 |
db: AsyncSession,
|
| 33 |
+
k: int = 5,
|
| 34 |
+
source_hint: SourceHint = "both",
|
| 35 |
+
) -> list[dict[str, Any]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
try:
|
| 37 |
+
results = await self._router.retrieve(query, user_id, source_hint, k)
|
| 38 |
+
return [{"content": r.content, "metadata": r.metadata} for r in results]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
except Exception as e:
|
| 40 |
+
logger.error("retrieval failed", error=str(e))
|
| 41 |
return []
|
| 42 |
|
| 43 |
|
src/rag/retrievers/__init__.py
ADDED
|
File without changes
|
src/rag/retrievers/baseline.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Service for retrieving relevant documents from vector store."""
|
| 2 |
+
|
| 3 |
+
import hashlib
|
| 4 |
+
import json
|
| 5 |
+
from src.db.postgres.vector_store import get_vector_store
|
| 6 |
+
from src.db.redis.connection import get_redis
|
| 7 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 8 |
+
from src.middlewares.logging import get_logger
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
|
| 11 |
+
logger = get_logger("retriever")
|
| 12 |
+
|
| 13 |
+
_RETRIEVAL_CACHE_TTL = 3600 # 1 hour
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RetrieverService:
|
| 17 |
+
"""Service for retrieving relevant documents."""
|
| 18 |
+
|
| 19 |
+
def __init__(self):
|
| 20 |
+
self.vector_store = get_vector_store()
|
| 21 |
+
|
| 22 |
+
async def retrieve(
|
| 23 |
+
self,
|
| 24 |
+
query: str,
|
| 25 |
+
user_id: str,
|
| 26 |
+
db: AsyncSession,
|
| 27 |
+
k: int = 5
|
| 28 |
+
) -> List[Dict[str, Any]]:
|
| 29 |
+
"""Retrieve relevant chunks for a query, scoped to the user's documents.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
List of dicts with keys: content, metadata
|
| 33 |
+
metadata includes: document_id, user_id, filename, chunk_index, page_label (if PDF)
|
| 34 |
+
"""
|
| 35 |
+
try:
|
| 36 |
+
redis = await get_redis()
|
| 37 |
+
query_hash = hashlib.md5(query.encode()).hexdigest()
|
| 38 |
+
cache_key = f"retrieval:{user_id}:{query_hash}:{k}"
|
| 39 |
+
|
| 40 |
+
cached = await redis.get(cache_key)
|
| 41 |
+
if cached:
|
| 42 |
+
logger.info("Returning cached retrieval results")
|
| 43 |
+
return json.loads(cached)
|
| 44 |
+
|
| 45 |
+
logger.info(f"Retrieving for user {user_id}, query: {query[:50]}...")
|
| 46 |
+
|
| 47 |
+
docs = await self.vector_store.asimilarity_search(
|
| 48 |
+
query=query,
|
| 49 |
+
k=k,
|
| 50 |
+
filter={"user_id": user_id}
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
results = [
|
| 54 |
+
{
|
| 55 |
+
"content": doc.page_content,
|
| 56 |
+
"metadata": doc.metadata,
|
| 57 |
+
}
|
| 58 |
+
for doc in docs
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
logger.info(f"Retrieved {len(results)} chunks")
|
| 62 |
+
await redis.setex(cache_key, _RETRIEVAL_CACHE_TTL, json.dumps(results))
|
| 63 |
+
return results
|
| 64 |
+
|
| 65 |
+
except Exception as e:
|
| 66 |
+
logger.error("Retrieval failed", error=str(e))
|
| 67 |
+
return []
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
retriever = RetrieverService()
|
src/rag/retrievers/document.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular).
|
| 2 |
+
|
| 3 |
+
TEAMMATE: implement retrieve() below.
|
| 4 |
+
Strategy: MMR (amax_marginal_relevance_search) + score threshold to avoid returning
|
| 5 |
+
near-identical chunks from the same PDF page.
|
| 6 |
+
Filter: source_type="document" AND data->>'file_type' NOT IN ('csv', 'xlsx')
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from src.db.postgres.vector_store import get_vector_store
|
| 10 |
+
from src.middlewares.logging import get_logger
|
| 11 |
+
from src.rag.base import BaseRetriever, RetrievalResult
|
| 12 |
+
|
| 13 |
+
logger = get_logger("document_retriever")
|
| 14 |
+
|
| 15 |
+
_SCORE_THRESHOLD = 0.45 # discard chunks with cosine distance above this
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DocumentRetriever(BaseRetriever):
|
| 19 |
+
def __init__(self):
|
| 20 |
+
self.vector_store = get_vector_store()
|
| 21 |
+
|
| 22 |
+
async def retrieve(
|
| 23 |
+
self, query: str, user_id: str, k: int = 5
|
| 24 |
+
) -> list[RetrievalResult]:
|
| 25 |
+
# TODO (teammate): implement MMR retrieval for prose documents
|
| 26 |
+
# Filter: {"user_id": user_id, "source_type": "document"}
|
| 27 |
+
# then post-filter to exclude file_type in ("csv", "xlsx")
|
| 28 |
+
logger.info("document retriever not yet implemented — returning empty")
|
| 29 |
+
return []
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
document_retriever = DocumentRetriever()
|
src/rag/retrievers/schema.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Schema retriever — handles DB schemas (source_type="database") and tabular file
|
| 2 |
+
columns stored as source_type="document" with file_type in ("csv","xlsx").
|
| 3 |
+
|
| 4 |
+
Multiple retrieval strategies are exposed for benchmarking. The active strategy
|
| 5 |
+
used by the router is `retrieve()`, which dispatches to ACTIVE_STRATEGY.
|
| 6 |
+
Change ACTIVE_STRATEGY at module level to switch without touching the router.
|
| 7 |
+
|
| 8 |
+
All strategies embed the query exactly once, then fan out to parallel SQL legs.
|
| 9 |
+
|
| 10 |
+
Vector distance strategies:
|
| 11 |
+
dense_no_threshold — cosine (<=>), no score floor, always returns k chunks
|
| 12 |
+
dense_dot — inner product (<#>), equivalent to cosine for normalized embeddings
|
| 13 |
+
dense_l2 — L2/euclidean (<->), monotonic with cosine on unit-sphere vectors
|
| 14 |
+
hybrid — RRF merge of dense + FTS (database + tabular)
|
| 15 |
+
hybrid_bm25 — RRF merge of dense + FTS (database only)
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import asyncio
|
| 19 |
+
import time
|
| 20 |
+
from typing import Literal
|
| 21 |
+
|
| 22 |
+
from sqlalchemy import text
|
| 23 |
+
|
| 24 |
+
from src.db.postgres.connection import _pgvector_engine
|
| 25 |
+
from src.db.postgres.vector_store import get_vector_store
|
| 26 |
+
from src.middlewares.logging import get_logger
|
| 27 |
+
from src.rag.base import BaseRetriever, RetrievalResult
|
| 28 |
+
|
| 29 |
+
logger = get_logger("schema_retriever")
|
| 30 |
+
|
| 31 |
+
_TABULAR_FILE_TYPES = ("csv", "xlsx")
|
| 32 |
+
|
| 33 |
+
Strategy = Literal["dense_no_threshold", "dense_dot", "dense_l2", "hybrid", "hybrid_bm25"]
|
| 34 |
+
ACTIVE_STRATEGY: Strategy = "hybrid_bm25"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SchemaRetriever(BaseRetriever):
|
| 38 |
+
def __init__(self):
|
| 39 |
+
self.vector_store = get_vector_store()
|
| 40 |
+
|
| 41 |
+
# ------------------------------------------------------------------
|
| 42 |
+
# Internal helpers
|
| 43 |
+
# ------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
async def _embed_query(self, query: str) -> list[float]:
|
| 46 |
+
return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)
|
| 47 |
+
|
| 48 |
+
async def _search_db(
|
| 49 |
+
self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
|
| 50 |
+
) -> list[RetrievalResult]:
|
| 51 |
+
"""Vector search over database chunks. Accepts a pre-computed embedding."""
|
| 52 |
+
emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
| 53 |
+
|
| 54 |
+
if operator == "<#>":
|
| 55 |
+
score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
|
| 56 |
+
elif operator == "<->":
|
| 57 |
+
score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
|
| 58 |
+
else:
|
| 59 |
+
score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
|
| 60 |
+
|
| 61 |
+
sql = text(f"""
|
| 62 |
+
SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
|
| 63 |
+
FROM langchain_pg_embedding lpe
|
| 64 |
+
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
|
| 65 |
+
WHERE lpc.name = 'document_embeddings'
|
| 66 |
+
AND lpe.cmetadata->>'user_id' = :user_id
|
| 67 |
+
AND lpe.cmetadata->>'source_type' = 'database'
|
| 68 |
+
ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
|
| 69 |
+
LIMIT :k
|
| 70 |
+
""")
|
| 71 |
+
|
| 72 |
+
async with _pgvector_engine.connect() as conn:
|
| 73 |
+
result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
|
| 74 |
+
rows = result.fetchall()
|
| 75 |
+
|
| 76 |
+
return [
|
| 77 |
+
RetrievalResult(
|
| 78 |
+
content=row.document,
|
| 79 |
+
metadata=row.cmetadata,
|
| 80 |
+
score=float(row.score),
|
| 81 |
+
source_type="database",
|
| 82 |
+
)
|
| 83 |
+
for row in rows
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
async def _search_tabular(
|
| 87 |
+
self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
|
| 88 |
+
) -> list[RetrievalResult]:
|
| 89 |
+
"""Vector search over tabular document chunks. Accepts a pre-computed embedding."""
|
| 90 |
+
emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
|
| 91 |
+
|
| 92 |
+
if operator == "<#>":
|
| 93 |
+
score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
|
| 94 |
+
elif operator == "<->":
|
| 95 |
+
score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
|
| 96 |
+
else:
|
| 97 |
+
score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
|
| 98 |
+
|
| 99 |
+
sql = text(f"""
|
| 100 |
+
SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
|
| 101 |
+
FROM langchain_pg_embedding lpe
|
| 102 |
+
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
|
| 103 |
+
WHERE lpc.name = 'document_embeddings'
|
| 104 |
+
AND lpe.cmetadata->>'user_id' = :user_id
|
| 105 |
+
AND lpe.cmetadata->>'source_type' = 'document'
|
| 106 |
+
AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
|
| 107 |
+
OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
|
| 108 |
+
ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
|
| 109 |
+
LIMIT :k
|
| 110 |
+
""")
|
| 111 |
+
|
| 112 |
+
async with _pgvector_engine.connect() as conn:
|
| 113 |
+
result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
|
| 114 |
+
rows = result.fetchall()
|
| 115 |
+
|
| 116 |
+
results = []
|
| 117 |
+
for row in rows:
|
| 118 |
+
results.append(
|
| 119 |
+
RetrievalResult(
|
| 120 |
+
content=row.document,
|
| 121 |
+
metadata=row.cmetadata,
|
| 122 |
+
score=float(row.score),
|
| 123 |
+
source_type="document",
|
| 124 |
+
)
|
| 125 |
+
)
|
| 126 |
+
if len(results) >= k:
|
| 127 |
+
break
|
| 128 |
+
return results
|
| 129 |
+
|
| 130 |
+
async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
|
| 131 |
+
"""Full-text search over DB schema chunks using PostgreSQL tsvector.
|
| 132 |
+
|
| 133 |
+
Requires GIN index on langchain_pg_embedding.document (created by init_db.py).
|
| 134 |
+
"""
|
| 135 |
+
sql = text("""
|
| 136 |
+
SELECT lpe.document, lpe.cmetadata,
|
| 137 |
+
ts_rank(to_tsvector('english', lpe.document),
|
| 138 |
+
plainto_tsquery('english', :query)) AS rank
|
| 139 |
+
FROM langchain_pg_embedding lpe
|
| 140 |
+
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
|
| 141 |
+
WHERE lpc.name = 'document_embeddings'
|
| 142 |
+
AND lpe.cmetadata->>'user_id' = :user_id
|
| 143 |
+
AND lpe.cmetadata->>'source_type' = 'database'
|
| 144 |
+
AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
|
| 145 |
+
ORDER BY rank DESC
|
| 146 |
+
LIMIT :k
|
| 147 |
+
""")
|
| 148 |
+
|
| 149 |
+
async with _pgvector_engine.connect() as conn:
|
| 150 |
+
result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
|
| 151 |
+
rows = result.fetchall()
|
| 152 |
+
|
| 153 |
+
return [
|
| 154 |
+
RetrievalResult(
|
| 155 |
+
content=row.document,
|
| 156 |
+
metadata=row.cmetadata,
|
| 157 |
+
score=float(row.rank),
|
| 158 |
+
source_type="database",
|
| 159 |
+
)
|
| 160 |
+
for row in rows
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
async def _search_fts_tabular(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
|
| 164 |
+
"""Full-text search over tabular document chunks using PostgreSQL tsvector."""
|
| 165 |
+
sql = text("""
|
| 166 |
+
SELECT lpe.document, lpe.cmetadata,
|
| 167 |
+
ts_rank(to_tsvector('english', lpe.document),
|
| 168 |
+
plainto_tsquery('english', :query)) AS rank
|
| 169 |
+
FROM langchain_pg_embedding lpe
|
| 170 |
+
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
|
| 171 |
+
WHERE lpc.name = 'document_embeddings'
|
| 172 |
+
AND lpe.cmetadata->>'user_id' = :user_id
|
| 173 |
+
AND lpe.cmetadata->>'source_type' = 'document'
|
| 174 |
+
AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
|
| 175 |
+
OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
|
| 176 |
+
AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
|
| 177 |
+
ORDER BY rank DESC
|
| 178 |
+
LIMIT :k
|
| 179 |
+
""")
|
| 180 |
+
|
| 181 |
+
async with _pgvector_engine.connect() as conn:
|
| 182 |
+
result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
|
| 183 |
+
rows = result.fetchall()
|
| 184 |
+
|
| 185 |
+
return [
|
| 186 |
+
RetrievalResult(
|
| 187 |
+
content=row.document,
|
| 188 |
+
metadata=row.cmetadata,
|
| 189 |
+
score=float(row.rank),
|
| 190 |
+
source_type="document",
|
| 191 |
+
)
|
| 192 |
+
for row in rows
|
| 193 |
+
]
|
| 194 |
+
|
| 195 |
+
def _rrf_merge(
|
| 196 |
+
self,
|
| 197 |
+
*ranked_lists: list[RetrievalResult],
|
| 198 |
+
k_rrf: int = 60,
|
| 199 |
+
top_k: int = 5,
|
| 200 |
+
) -> list[RetrievalResult]:
|
| 201 |
+
"""Reciprocal Rank Fusion — combines ranked lists using rank positions only."""
|
| 202 |
+
scores: dict[tuple, float] = {}
|
| 203 |
+
index: dict[tuple, RetrievalResult] = {}
|
| 204 |
+
|
| 205 |
+
for ranked in ranked_lists:
|
| 206 |
+
for rank, result in enumerate(ranked):
|
| 207 |
+
data = result.metadata.get("data", {})
|
| 208 |
+
key = (data.get("table_name"), data.get("column_name") or data.get("filename"))
|
| 209 |
+
scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
|
| 210 |
+
if key not in index or result.score > index[key].score:
|
| 211 |
+
index[key] = result
|
| 212 |
+
|
| 213 |
+
def _key(r: RetrievalResult) -> tuple:
|
| 214 |
+
d = r.metadata.get("data", {})
|
| 215 |
+
return (d.get("table_name"), d.get("column_name") or d.get("filename"))
|
| 216 |
+
|
| 217 |
+
merged = sorted(index.values(), key=lambda r: scores[_key(r)], reverse=True)
|
| 218 |
+
return merged[:top_k]
|
| 219 |
+
|
| 220 |
+
def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
|
| 221 |
+
"""Deduplicate by (table_name, column_name), keeping highest score per unique column."""
|
| 222 |
+
seen: dict[tuple, RetrievalResult] = {}
|
| 223 |
+
for r in results:
|
| 224 |
+
data = r.metadata.get("data", {})
|
| 225 |
+
key = (data.get("table_name"), data.get("column_name") or data.get("filename"))
|
| 226 |
+
if key not in seen or r.score > seen[key].score:
|
| 227 |
+
seen[key] = r
|
| 228 |
+
return sorted(seen.values(), key=lambda r: r.score, reverse=True)
|
| 229 |
+
|
| 230 |
+
# ------------------------------------------------------------------
|
| 231 |
+
# Named strategies — one embed call each, legs run in parallel
|
| 232 |
+
# ------------------------------------------------------------------
|
| 233 |
+
|
| 234 |
+
async def dense_no_threshold(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
|
| 235 |
+
"""Cosine similarity, no score cutoff — always returns k chunks."""
|
| 236 |
+
embedding = await self._embed_query(query)
|
| 237 |
+
db_results, tabular_results = await asyncio.gather(
|
| 238 |
+
self._search_db(embedding, user_id, k),
|
| 239 |
+
self._search_tabular(embedding, user_id, k),
|
| 240 |
+
)
|
| 241 |
+
return self._dedup(db_results + tabular_results)[:k]
|
| 242 |
+
|
| 243 |
+
async def dense_dot(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
|
| 244 |
+
"""Inner product similarity (<#>).
|
| 245 |
+
|
| 246 |
+
For L2-normalized embeddings (OpenAI), ranking is identical to cosine.
|
| 247 |
+
Score = raw inner product (not bounded to [0,1]).
|
| 248 |
+
"""
|
| 249 |
+
embedding = await self._embed_query(query)
|
| 250 |
+
db_results, tabular_results = await asyncio.gather(
|
| 251 |
+
self._search_db(embedding, user_id, k, "<#>"),
|
| 252 |
+
self._search_tabular(embedding, user_id, k, "<#>"),
|
| 253 |
+
)
|
| 254 |
+
return self._dedup(db_results + tabular_results)[:k]
|
| 255 |
+
|
| 256 |
+
async def dense_l2(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
|
| 257 |
+
"""L2 (Euclidean) distance similarity (<->).
|
| 258 |
+
|
| 259 |
+
For L2-normalized embeddings (OpenAI), ranking order matches cosine.
|
| 260 |
+
Score = 1 / (1 + l2_distance), bounded to (0, 1].
|
| 261 |
+
"""
|
| 262 |
+
embedding = await self._embed_query(query)
|
| 263 |
+
db_results, tabular_results = await asyncio.gather(
|
| 264 |
+
self._search_db(embedding, user_id, k, "<->"),
|
| 265 |
+
self._search_tabular(embedding, user_id, k, "<->"),
|
| 266 |
+
)
|
| 267 |
+
return self._dedup(db_results + tabular_results)[:k]
|
| 268 |
+
|
| 269 |
+
async def hybrid(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
|
| 270 |
+
"""RRF merge of dense + FTS over both database and tabular sources.
|
| 271 |
+
|
| 272 |
+
Embeds once, then runs all four legs (dense db, dense tabular, fts db,
|
| 273 |
+
fts tabular) in a single asyncio.gather.
|
| 274 |
+
"""
|
| 275 |
+
embedding = await self._embed_query(query)
|
| 276 |
+
db_results, tabular_results, fts_db, fts_tabular = await asyncio.gather(
|
| 277 |
+
self._search_db(embedding, user_id, k),
|
| 278 |
+
self._search_tabular(embedding, user_id, k),
|
| 279 |
+
self._search_fts_db(query, user_id, k * 4),
|
| 280 |
+
self._search_fts_tabular(query, user_id, k * 4),
|
| 281 |
+
)
|
| 282 |
+
dense = self._dedup(db_results + tabular_results)[:k]
|
| 283 |
+
fts_all = self._dedup(fts_db + fts_tabular)
|
| 284 |
+
return self._rrf_merge(dense, fts_all, top_k=k)
|
| 285 |
+
|
| 286 |
+
async def hybrid_bm25(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
|
| 287 |
+
"""RRF merge of dense + FTS (database chunks only).
|
| 288 |
+
|
| 289 |
+
Embeds once, then runs dense db, dense tabular, and fts db legs in parallel.
|
| 290 |
+
"""
|
| 291 |
+
embedding = await self._embed_query(query)
|
| 292 |
+
db_results, tabular_results, fts_results = await asyncio.gather(
|
| 293 |
+
self._search_db(embedding, user_id, k),
|
| 294 |
+
self._search_tabular(embedding, user_id, k),
|
| 295 |
+
self._search_fts_db(query, user_id, k * 4),
|
| 296 |
+
)
|
| 297 |
+
dense = self._dedup(db_results + tabular_results)[:k]
|
| 298 |
+
return self._rrf_merge(dense, self._dedup(fts_results), top_k=k)
|
| 299 |
+
|
| 300 |
+
# ------------------------------------------------------------------
|
| 301 |
+
# Public interface — called by the router
|
| 302 |
+
# ------------------------------------------------------------------
|
| 303 |
+
|
| 304 |
+
async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
|
| 305 |
+
strategy_fn = getattr(self, ACTIVE_STRATEGY)
|
| 306 |
+
results = await strategy_fn(query, user_id, k)
|
| 307 |
+
logger.info("schema retrieval", strategy=ACTIVE_STRATEGY, count=len(results))
|
| 308 |
+
return results
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# ------------------------------------------------------------------
|
| 312 |
+
# Benchmark helper — import in test scripts
|
| 313 |
+
# ------------------------------------------------------------------
|
| 314 |
+
|
| 315 |
+
async def benchmark(
|
| 316 |
+
query: str,
|
| 317 |
+
user_id: str,
|
| 318 |
+
k: int = 5,
|
| 319 |
+
strategies: list[Strategy] | None = None,
|
| 320 |
+
) -> dict[str, dict]:
|
| 321 |
+
"""Run multiple strategies against the same query and return timing + results."""
|
| 322 |
+
retriever = SchemaRetriever()
|
| 323 |
+
targets: list[Strategy] = strategies or [
|
| 324 |
+
"dense_no_threshold",
|
| 325 |
+
"dense_dot",
|
| 326 |
+
"dense_l2",
|
| 327 |
+
"hybrid",
|
| 328 |
+
"hybrid_bm25",
|
| 329 |
+
]
|
| 330 |
+
report: dict[str, dict] = {}
|
| 331 |
+
|
| 332 |
+
for name in targets:
|
| 333 |
+
fn = getattr(retriever, name)
|
| 334 |
+
t0 = time.perf_counter()
|
| 335 |
+
chunks = await fn(query, user_id, k)
|
| 336 |
+
elapsed_ms = round((time.perf_counter() - t0) * 1000)
|
| 337 |
+
|
| 338 |
+
total_chars = sum(len(r.content) for r in chunks)
|
| 339 |
+
report[name] = {
|
| 340 |
+
"chunks": len(chunks),
|
| 341 |
+
"estimated_tokens": total_chars // 4,
|
| 342 |
+
"elapsed_ms": elapsed_ms,
|
| 343 |
+
"results": chunks,
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
return report
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
schema_retriever = SchemaRetriever()
|
src/rag/router.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Routes retrieval requests to the appropriate retriever based on source_hint."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import hashlib
|
| 5 |
+
import json
|
| 6 |
+
from typing import Literal
|
| 7 |
+
|
| 8 |
+
from src.db.redis.connection import get_redis
|
| 9 |
+
from src.middlewares.logging import get_logger
|
| 10 |
+
from src.rag.base import BaseRetriever, RetrievalResult
|
| 11 |
+
|
| 12 |
+
logger = get_logger("retrieval_router")
|
| 13 |
+
|
| 14 |
+
_CACHE_TTL = 3600 # 1 hour
|
| 15 |
+
SourceHint = Literal["document", "schema", "both"]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RetrievalRouter:
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
schema_retriever: BaseRetriever,
|
| 22 |
+
document_retriever: BaseRetriever,
|
| 23 |
+
):
|
| 24 |
+
self._retrievers: dict[str, BaseRetriever] = {
|
| 25 |
+
"schema": schema_retriever,
|
| 26 |
+
"document": document_retriever,
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
def _route(self, source_hint: SourceHint) -> list[BaseRetriever]:
|
| 30 |
+
if source_hint == "schema":
|
| 31 |
+
return [self._retrievers["schema"]]
|
| 32 |
+
if source_hint == "document":
|
| 33 |
+
return [self._retrievers["document"]]
|
| 34 |
+
return list(self._retrievers.values())
|
| 35 |
+
|
| 36 |
+
async def retrieve(
|
| 37 |
+
self,
|
| 38 |
+
query: str,
|
| 39 |
+
user_id: str,
|
| 40 |
+
source_hint: SourceHint = "both",
|
| 41 |
+
k: int = 10,
|
| 42 |
+
) -> list[RetrievalResult]:
|
| 43 |
+
redis = await get_redis()
|
| 44 |
+
query_hash = hashlib.md5(query.encode()).hexdigest()
|
| 45 |
+
cache_key = f"retrieval:{user_id}:{source_hint}:{query_hash}:{k}"
|
| 46 |
+
|
| 47 |
+
cached = await redis.get(cache_key)
|
| 48 |
+
if cached:
|
| 49 |
+
logger.info("returning cached retrieval results", source_hint=source_hint)
|
| 50 |
+
raw = json.loads(cached)
|
| 51 |
+
return [RetrievalResult(**r) for r in raw]
|
| 52 |
+
|
| 53 |
+
retrievers = self._route(source_hint)
|
| 54 |
+
batches = await asyncio.gather(
|
| 55 |
+
*[r.retrieve(query, user_id, k) for r in retrievers],
|
| 56 |
+
return_exceptions=True,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
results: list[RetrievalResult] = []
|
| 60 |
+
for batch in batches:
|
| 61 |
+
if isinstance(batch, Exception):
|
| 62 |
+
logger.error("retriever failed", error=str(batch))
|
| 63 |
+
continue
|
| 64 |
+
results.extend(batch)
|
| 65 |
+
|
| 66 |
+
results.sort(key=lambda r: r.score, reverse=True)
|
| 67 |
+
results = results[:k]
|
| 68 |
+
|
| 69 |
+
logger.info("retrieved chunks", count=len(results), source_hint=source_hint)
|
| 70 |
+
await redis.setex(
|
| 71 |
+
cache_key,
|
| 72 |
+
_CACHE_TTL,
|
| 73 |
+
json.dumps([vars(r) for r in results]),
|
| 74 |
+
)
|
| 75 |
+
return results
|