Rifqi Hafizuddin committed on
Commit
ba550a5
·
1 Parent(s): 767625e

[KM-438-439] add retriever feature

Browse files
src/query/__init__.py ADDED
File without changes
src/query/base.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared contract for query executors."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass, field
5
+
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+
8
+ from src.rag.base import RetrievalResult
9
+
10
+
11
@dataclass
class QueryResult:
    """Uniform record returned by query executors.

    One instance describes the rows obtained from a single source together
    with enough context (source identifiers, column names, extras) for the
    caller to interpret them.
    """

    # Kind of source the rows came from: "database" or "document".
    source_type: str
    # database_client_id for database sources, document_id for documents.
    source_id: str
    # Table name(s) or file name the rows were read from.
    table_or_file: str
    # Column names, and the rows themselves as one dict per row.
    columns: list[str]
    rows: list[dict]
    row_count: int
    # Free-form extras. Should include "column_types": {"col_name": "dtype"}
    # when available. A fresh dict is created per instance.
    metadata: dict = field(default_factory=dict)
21
+
22
+
23
class BaseExecutor(ABC):
    """Abstract contract shared by all query executors."""

    @abstractmethod
    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        """Turn retrieval results into query results.

        Args:
            results: Retrieval results the executor should act on.
            user_id: Identifier of the requesting user.
            db: Async SQLAlchemy session made available to the executor.
            question: The user's natural-language question.
            limit: Row limit requested by the caller (defaults to 100).

        Returns:
            A list of QueryResult objects.
        """
src/query/executors/__init__.py ADDED
File without changes
src/query/executors/db_executor.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Executor for registered database sources (source_type="database").
2
+
3
+ Flow per (client_id, question):
4
+ 1. Collect all relevant (table_name, column_name) pairs from retrieval results.
5
+ 2. Fetch the FULL schema for those tables from PGVector (not just top-k columns).
6
+ 3. Build a schema context string and send to LLM → structured SQLQuery output.
7
+ 4. Validate via sqlglot: SELECT-only, schema-grounded, LIMIT enforced.
8
+ 5. Execute on the user's DB via engine_scope + asyncio.to_thread.
9
+ 6. Return QueryResult per client_id (may span multiple tables via JOINs).
10
+
11
+ Supported db_types: postgres, supabase, mysql.
12
+ Other types are skipped with a warning — they do not raise.
13
+ """
14
+
15
+ import asyncio
16
+ from collections import defaultdict
17
+ from typing import Any
18
+
19
+ import sqlglot
20
+ import sqlglot.expressions as exp
21
+ import tiktoken
22
+ from langchain_core.prompts import ChatPromptTemplate
23
+ from langchain_openai import AzureChatOpenAI
24
+ from sqlalchemy import text
25
+ from sqlalchemy.ext.asyncio import AsyncSession
26
+
27
+ from src.config.settings import settings
28
+ from src.database_client.database_client_service import database_client_service
29
+ from src.db.postgres.connection import _pgvector_engine
30
+ from src.middlewares.logging import get_logger
31
+ from src.models.sql_query import SQLQuery
32
+ from src.pipeline.db_pipeline import db_pipeline_service
33
+ from src.query.base import BaseExecutor, QueryResult
34
+ from src.rag.base import RetrievalResult
35
+ from src.utils.db_credential_encryption import decrypt_credentials_dict
36
+
37
+ logger = get_logger("db_executor")
38
+
39
+ _enc = tiktoken.get_encoding("cl100k_base")
40
+
41
+ _SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"}
42
+ _MAX_RETRIES = 3
43
+ _MAX_LIMIT = 500
44
+
45
+ _SQL_SYSTEM_PROMPT = """\
46
+ You are a SQL data analyst working with a user's database.
47
+ Generate a single SQL SELECT statement that answers the user's question.
48
+
49
+ Database dialect: {dialect}
50
+
51
+ Rules:
52
+ - ONLY reference tables and columns listed in the schema below. Do not invent names.
53
+ - Always include a LIMIT clause (max {limit}).
54
+ - Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
55
+ - Prefer explicit JOINs over subqueries when combining tables.
56
+ - For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
57
+ - For date filtering, use dialect-appropriate functions ({dialect} syntax).
58
+
59
+ Schema:
60
+ {schema}
61
+
62
+ {error_section}"""
63
+
64
+
65
+ class DbExecutor(BaseExecutor):
66
def __init__(self) -> None:
    """Create the Azure chat model, the prompt template and the SQL chain."""
    llm = AzureChatOpenAI(
        azure_deployment=settings.azureai_deployment_name_4o,
        openai_api_version=settings.azureai_api_version_4o,
        azure_endpoint=settings.azureai_endpoint_url_4o,
        api_key=settings.azureai_api_key_4o,
        temperature=0,
    )
    messages = [
        ("system", _SQL_SYSTEM_PROMPT),
        ("human", "{question}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)

    self._llm = llm
    self._prompt = prompt
    # The chain pipes the rendered prompt into the model, whose output is
    # requested as a structured SQLQuery object.
    self._chain = prompt | llm.with_structured_output(SQLQuery)
79
+
80
+ # ------------------------------------------------------------------
81
+ # Public interface
82
+ # ------------------------------------------------------------------
83
+
84
async def execute(
    self,
    results: list[RetrievalResult],
    user_id: str,
    db: AsyncSession,
    question: str,
    limit: int = 100,
) -> list[QueryResult]:
    """Run one SQL generation + execution pass per database client.

    Only results whose ``source_type`` is "database" are considered. They are
    grouped by the ``database_client_id`` found in their metadata; results
    without one are skipped with a warning. A failure for one client is
    logged and does not prevent the remaining clients from being processed.

    Returns:
        The QueryResult objects produced for the clients that succeeded
        (possibly empty).
    """
    grouped: dict[str, list[RetrievalResult]] = defaultdict(list)
    for item in results:
        if item.source_type != "database":
            continue
        cid = item.metadata.get("database_client_id", "")
        if not cid:
            logger.warning("db result missing database_client_id, skipping")
            continue
        grouped[cid].append(item)

    collected: list[QueryResult] = []
    for cid, hits in grouped.items():
        try:
            outcome = await self._execute_for_client(cid, hits, user_id, db, question, limit)
        except Exception as exc:
            logger.error("db executor failed for client", client_id=cid, error=str(exc))
            continue
        if outcome:
            collected.append(outcome)

    return collected
115
+
116
+ # ------------------------------------------------------------------
117
+ # Per-client execution
118
+ # ------------------------------------------------------------------
119
+
120
async def _execute_for_client(
    self,
    client_id: str,
    results: list[RetrievalResult],
    user_id: str,
    db: AsyncSession,
    question: str,
    limit: int,
) -> QueryResult | None:
    """Generate, validate and run one SQL query for a single database client.

    Steps: load the client and check ownership and db_type; collect the table
    names mentioned by the retrieval results and expand them via foreign
    keys; fetch the full schema for those tables; ask the LLM for SQL (up to
    ``_MAX_RETRIES`` attempts, feeding the previous error back); validate and
    cap the LIMIT; execute on the user's database in a worker thread.

    Returns:
        A QueryResult, or None when the client is missing / not owned by
        ``user_id`` / of an unsupported type, when no schema is found, or
        when no valid SQL could be produced.
    """
    client = await database_client_service.get(db, client_id)
    if not client:
        logger.warning("database client not found", client_id=client_id)
        return None
    if client.user_id != user_id:
        logger.warning("client ownership mismatch", client_id=client_id)
        return None
    if client.db_type not in _SUPPORTED_DB_TYPES:
        logger.warning("unsupported db_type for query execution", db_type=client.db_type)
        return None

    # Distinct table names from retrieval results, expanded via FK relationships.
    seen_tables = set()
    for hit in results:
        table = hit.metadata.get("data", {}).get("table_name")
        if table:
            seen_tables.add(table)
    table_names = await self._expand_with_fk_tables(client_id, user_id, list(seen_tables))

    full_schema = await self._fetch_full_schema(client_id, table_names, user_id)
    if not full_schema:
        logger.warning("no schema found in vector store", client_id=client_id, tables=table_names)
        return None

    schema_ctx = self._build_schema_context(full_schema)
    capped_limit = min(limit, _MAX_LIMIT)
    dialect = client.db_type

    # SQL generation with retry: each failed attempt feeds its error (and,
    # for validation failures, the model's reasoning) into the next prompt.
    validated_sql: str | None = None
    last_error: str = ""
    last_reasoning: str = ""
    for attempt_no in range(1, _MAX_RETRIES + 1):
        error_section = ""
        if last_error:
            error_section = (
                f"Previous attempt reasoning: {last_reasoning}\n"
                f"Previous attempt failed: {last_error}\n"
                "Fix the issue above."
            )
        try:
            # Token count of the variable prompt parts, for logging only.
            input_tokens = len(_enc.encode(schema_ctx + error_section + question))
            logger.info("sql generation input tokens", attempt=attempt_no, tokens=input_tokens)

            generated: SQLQuery = await self._chain.ainvoke({
                "schema": schema_ctx,
                "dialect": dialect,
                "limit": capped_limit,
                "error_section": error_section,
                "question": question,
            })
            candidate = generated.sql.strip()
            problem = self._validate(candidate, full_schema, capped_limit)
            if problem:
                last_error = problem
                last_reasoning = generated.reasoning
                logger.warning("sql validation failed", attempt=attempt_no, error=problem)
                continue
            validated_sql = self._enforce_limit(candidate, capped_limit)
            output_tokens = len(_enc.encode(generated.sql)) + len(_enc.encode(generated.reasoning))
            logger.info(
                "sql generated",
                attempt=attempt_no,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
                reasoning=generated.reasoning,
            )
            break
        except Exception as exc:
            last_error = str(exc)
            logger.warning("sql generation error", attempt=attempt_no, error=last_error)

    if not validated_sql:
        logger.error("sql generation failed after retries", client_id=client_id)
        return None

    # Execute on the user's DB; the blocking driver call runs in a thread.
    creds = decrypt_credentials_dict(client.credentials)
    with db_pipeline_service.engine_scope(client.db_type, creds) as engine:
        rows = await asyncio.to_thread(self._run_sql, engine, validated_sql)

    # Column name -> declared type, across all tables in the schema
    # (a later table wins when two tables share a column name).
    type_by_column = {}
    for table_columns in full_schema.values():
        for col in table_columns:
            type_by_column[col["name"]] = col["type"]
    columns = list(rows[0]) if rows else []

    return QueryResult(
        source_type="database",
        source_id=client_id,
        table_or_file=", ".join(table_names),
        columns=columns,
        rows=rows,
        row_count=len(rows),
        metadata={
            "db_type": client.db_type,
            "client_name": client.name,
            "sql": validated_sql,
            "column_types": {name: type_by_column.get(name, "unknown") for name in columns},
        },
    )
234
+
235
+ # ------------------------------------------------------------------
236
+ # Schema helpers
237
+ # ------------------------------------------------------------------
238
+
239
async def _expand_with_fk_tables(
    self,
    client_id: str,
    user_id: str,
    table_names: list[str],
) -> list[str]:
    """Add tables that the given tables reference through foreign keys.

    Retrieval may return a table (e.g. order_items) without the table its
    foreign key points to (e.g. orders); including the referred table in the
    schema avoids SQL generation failures for joins. Only the stored
    ``foreign_key`` metadata of the given tables is consulted.

    Returns:
        ``table_names`` unchanged when it is empty, otherwise a list holding
        the original names plus every referred table (unordered).
    """
    if not table_names:
        return table_names

    bind_names = [f"t{idx}" for idx in range(len(table_names))]
    placeholders = ", ".join(f":{bind}" for bind in bind_names)
    sql = text(f"""
        SELECT DISTINCT lpe.cmetadata->'data'->>'foreign_key' AS fk
        FROM langchain_pg_embedding lpe
        JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
        WHERE lpc.name = 'document_embeddings'
        AND lpe.cmetadata->>'user_id' = :user_id
        AND lpe.cmetadata->>'source_type' = 'database'
        AND lpe.cmetadata->>'database_client_id' = :client_id
        AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
        AND lpe.cmetadata->'data'->>'foreign_key' IS NOT NULL
    """)

    params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
    params.update(zip(bind_names, table_names))

    async with _pgvector_engine.connect() as conn:
        fk_rows = (await conn.execute(sql, params)).fetchall()

    original = set(table_names)
    expanded = set(table_names)
    for fk_row in fk_rows:
        reference = fk_row.fk  # format: "referred_table.referred_column"
        if not reference:
            continue
        expanded.add(reference.split(".")[0])

    if expanded != original:
        logger.info(
            "expanded tables via FK",
            original=sorted(table_names),
            expanded=sorted(expanded),
        )

    return list(expanded)
290
+
291
+ async def _fetch_full_schema(
292
+ self,
293
+ client_id: str,
294
+ table_names: list[str],
295
+ user_id: str,
296
+ ) -> dict[str, list[dict[str, Any]]]:
297
+ """Fetch ALL column chunks for the given tables from PGVector.
298
+
299
+ Returns {table_name: [{"name": ..., "type": ..., "is_primary_key": ...,
300
+ "foreign_key": ..., "content": ...}]}
301
+ """
302
+ placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
303
+ sql = text(f"""
304
+ SELECT lpe.cmetadata, lpe.document
305
+ FROM langchain_pg_embedding lpe
306
+ JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
307
+ WHERE lpc.name = 'document_embeddings'
308
+ AND lpe.cmetadata->>'user_id' = :user_id
309
+ AND lpe.cmetadata->>'source_type' = 'database'
310
+ AND lpe.cmetadata->>'database_client_id' = :client_id
311
+ AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
312
+ ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name'
313
+ """)
314
+
315
+ params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
316
+ for i, name in enumerate(table_names):
317
+ params[f"t{i}"] = name
318
+
319
+ async with _pgvector_engine.connect() as conn:
320
+ result = await conn.execute(sql, params)
321
+ rows = result.fetchall()
322
+
323
+ schema: dict[str, list[dict[str, Any]]] = defaultdict(list)
324
+ for row in rows:
325
+ data = row.cmetadata.get("data", {})
326
+ table = data.get("table_name")
327
+ if table:
328
+ schema[table].append({
329
+ "name": data.get("column_name", ""),
330
+ "type": data.get("column_type", ""),
331
+ "is_primary_key": data.get("is_primary_key", False),
332
+ "foreign_key": data.get("foreign_key"),
333
+ "content": row.document, # chunk text includes top values / samples
334
+ })
335
+ return dict(schema)
336
+
337
+ def _build_schema_context(self, schema: dict[str, list[dict[str, Any]]]) -> str:
338
+ lines: list[str] = []
339
+ for table, columns in schema.items():
340
+ lines.append(f"Table: {table}")
341
+ for col in columns:
342
+ flags = []
343
+ if col["is_primary_key"]:
344
+ flags.append("PRIMARY KEY")
345
+ if col["foreign_key"]:
346
+ flags.append(f"FK -> {col['foreign_key']}")
347
+ flag_str = f" [{', '.join(flags)}]" if flags else ""
348
+ lines.append(f" - {col['name']} {col['type']}{flag_str}")
349
+ # Include sample/top-values line from chunk content if present
350
+ for line in col["content"].splitlines():
351
+ if line.startswith(("Top values:", "Sample values:")):
352
+ lines.append(f" {line}")
353
+ break
354
+ lines.append("")
355
+ return "\n".join(lines).strip()
356
+
357
+ # ------------------------------------------------------------------
358
+ # Guardrails
359
+ # ------------------------------------------------------------------
360
+
361
def _validate(self, sql: str, schema: dict[str, list[dict]], limit: int) -> str:
    """Return an error string if validation fails, empty string if OK.

    Checks, in order:
      1. The SQL parses with sqlglot and the top-level statement is a SELECT.
      2. No INSERT / UPDATE / DELETE node appears anywhere in the tree
         (covers data-modifying CTEs).
      3. Every referenced table is either a key of ``schema`` or the alias of
         a CTE defined in the same statement.

    Args:
        sql: Candidate SQL text.
        schema: Known tables (keys) and their columns.
        limit: Unused here; LIMIT is injected/capped separately before
            execution.
    """
    # Layer 1: sqlglot parse + SELECT-only check. SqlglotError is the common
    # base of ParseError and TokenError, so tokenizer failures are reported
    # as a validation message too instead of escaping as an exception.
    try:
        parsed = sqlglot.parse_one(sql)
    except sqlglot.errors.SqlglotError as e:
        return f"SQL parse error: {e}"

    if not isinstance(parsed, exp.Select):
        return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"

    # Check for DML anywhere in the AST (including writeable CTEs).
    dml_node = parsed.find(exp.Insert, exp.Update, exp.Delete)
    if dml_node is not None:
        return f"DML ({type(dml_node).__name__}) is not allowed."

    # Layer 2: schema grounding — table names. CTE aliases are names the
    # query defines itself; the tables inside each CTE body are still visited
    # by find_all below, so they remain subject to the same check.
    known_tables = {t.lower() for t in schema}
    cte_aliases = {cte.alias_or_name.lower() for cte in parsed.find_all(exp.CTE)}
    for tbl in parsed.find_all(exp.Table):
        name = tbl.name.lower()
        if not name or name in known_tables:
            continue
        # Only an unqualified reference can point at a CTE.
        if name in cte_aliases and not tbl.db:
            continue
        return f"Unknown table '{tbl.name}'. Only use tables from the schema."

    # Layer 3: LIMIT enforcement (inject if missing — done before execution)
    return ""
385
+
386
+ # ------------------------------------------------------------------
387
+ # SQL execution
388
+ # ------------------------------------------------------------------
389
+
390
def _enforce_limit(self, sql: str, limit: int) -> str:
    """Inject or cap the outermost LIMIT using sqlglot AST manipulation.

    Only the LIMIT attached to the top-level statement is considered; a LIMIT
    inside a subquery or CTE does not bound the final result and therefore
    must not be mistaken for one. The existing LIMIT is kept when it is a
    numeric literal not exceeding ``limit``; in every other case (missing,
    too large, non-literal, or a FETCH clause) it is replaced by ``limit``.

    Args:
        sql: SQL text that already passed validation.
        limit: Maximum number of rows the query may return.

    Returns:
        The SQL text regenerated by sqlglot with the LIMIT enforced.
    """
    parsed = sqlglot.parse_one(sql)

    top_level = parsed.args.get("limit")
    if isinstance(top_level, exp.Limit):
        value = top_level.expression
        if isinstance(value, exp.Literal) and not value.is_string:
            try:
                if int(value.this) <= limit:
                    return parsed.sql()
            except ValueError:
                # Numeric literal that is not a plain integer (e.g. 1e3):
                # fall through and replace it with the cap.
                pass

    # limit() sets the statement's LIMIT, replacing any existing one.
    return parsed.limit(limit).sql()
401
+
402
def _run_sql(self, engine: Any, sql: str) -> list[dict]:
    """Execute ``sql`` on ``engine`` and return the rows as plain dicts.

    Synchronous on purpose: the caller runs it in a worker thread.
    """
    # Ensure the user DB connection is a read-only credential — sqlglot validation alone is not sufficient.
    records: list[dict] = []
    with engine.connect() as conn:
        outcome = conn.execute(text(sql))
        for mapping in outcome.mappings():
            records.append(dict(mapping))
    return records
407
+
408
+
409
+ db_executor = DbExecutor()
src/query/executors/tabular.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Executor for tabular document sources (source_type="document", file_type csv/xlsx).
2
+
3
+ Flow:
4
+ 1. Group RetrievalResult chunks by document_id.
5
+ 2. For each document: download bytes from Azure Blob -> read with pandas.
6
+ 3. Filter DataFrame to relevant columns identified by retrieval.
7
+ 4. Return QueryResult per document.
8
+ """
9
+
10
+ from sqlalchemy.ext.asyncio import AsyncSession
11
+
12
+ from src.middlewares.logging import get_logger
13
+ from src.query.base import BaseExecutor, QueryResult
14
+ from src.rag.base import RetrievalResult
15
+
16
+ logger = get_logger("tabular_executor")
17
+
18
+ _TABULAR_FILE_TYPES = ("csv", "xlsx")
19
+
20
+
21
class TabularExecutor(BaseExecutor):
    """Executor for tabular documents (csv/xlsx). Not implemented yet."""

    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        """Return rows for tabular documents referenced by ``results``.

        The signature mirrors ``BaseExecutor.execute`` (including
        ``question``) so the executor can be called with the same positional
        arguments as every other executor.

        Raises:
            NotImplementedError: always, until the steps below are implemented.
        """
        # TODO: implement
        # 1. filter results where source_type == "document" and file_type in _TABULAR_FILE_TYPES
        # 2. group by document_id -> list of column_names
        # 3. per group: look up Document by document_id -> get blob_name
        # 4. blob_storage.download_file(blob_name) -> pd.read_csv / pd.read_excel
        # 5. df[relevant_columns].head(limit) -> rows as list[dict]
        # 6. return QueryResult per document
        raise NotImplementedError


tabular_executor = TabularExecutor()
src/query/query_executor.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """QueryExecutor — dispatches retrieval results to the appropriate executor by source_type."""
2
+
3
+ import asyncio
4
+
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+
7
+ from src.middlewares.logging import get_logger
8
+ from src.query.base import QueryResult
9
+ from src.query.executors.db_executor import db_executor
10
+ from src.query.executors.tabular import tabular_executor
11
+ from src.rag.base import RetrievalResult
12
+
13
+ logger = get_logger("query_executor")
14
+
15
+
16
class QueryExecutor:
    """Dispatches retrieval results to the executor matching their source."""

    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        """Run the database and tabular executors concurrently and merge output.

        Results with ``source_type == "database"`` go to the database
        executor; document results whose ``file_type`` is csv or xlsx go to
        the tabular executor. An executor that fails is logged and skipped;
        the other executor's results are still returned.

        Returns:
            The concatenated QueryResult lists of the executors that succeeded.
        """
        db_results = [r for r in results if r.source_type == "database"]
        tabular_results = [
            r for r in results
            if r.source_type == "document"
            # "data" may be absent or None; both mean "no file type".
            and (r.metadata.get("data") or {}).get("file_type") in ("csv", "xlsx")
        ]

        async def _empty() -> list[QueryResult]:
            return []

        # NOTE(review): both executors receive the same AsyncSession while
        # running concurrently — confirm this is safe for their usage.
        batches = await asyncio.gather(
            db_executor.execute(db_results, user_id, db, question, limit) if db_results else _empty(),
            tabular_executor.execute(tabular_results, user_id, db, question, limit) if tabular_results else _empty(),
            return_exceptions=True,
        )

        query_results: list[QueryResult] = []
        for batch in batches:
            # With return_exceptions=True a cancelled child is reported as a
            # CancelledError, which derives from BaseException, not Exception.
            if isinstance(batch, BaseException):
                logger.error("executor failed", error=str(batch))
                continue
            query_results.extend(batch)

        logger.info("query execution complete", total=len(query_results))
        return query_results


query_executor = QueryExecutor()
src/rag/base.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared contract for all retriever implementations."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+
8
@dataclass
class RetrievalResult:
    """A single retrieved chunk together with its provenance."""

    # Text of the retrieved chunk.
    content: str
    # Arbitrary metadata stored alongside the chunk.
    metadata: dict[str, Any]
    # Score assigned by the retriever.
    score: float
    # Origin of the chunk: "document" | "database".
    source_type: str
14
+
15
+
16
class BaseRetriever(ABC):
    """Abstract contract every retriever implementation fulfils."""

    @abstractmethod
    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Return retrieval results for ``query`` on behalf of ``user_id``.

        Args:
            query: The search query.
            user_id: Identifier of the requesting user.
            k: Number of results requested (defaults to 5).
        """
src/rag/retriever.py CHANGED
@@ -1,69 +1,43 @@
1
- """Service for retrieving relevant documents from vector store."""
 
 
2
 
3
- import hashlib
4
- import json
5
- from src.db.postgres.vector_store import get_vector_store
6
- from src.db.redis.connection import get_redis
7
  from sqlalchemy.ext.asyncio import AsyncSession
 
8
  from src.middlewares.logging import get_logger
9
- from typing import List, Dict, Any
 
 
10
 
11
  logger = get_logger("retriever")
12
 
13
- _RETRIEVAL_CACHE_TTL = 3600 # 1 hour
14
-
15
 
16
  class RetrieverService:
17
- """Service for retrieving relevant documents."""
 
 
 
 
18
 
19
  def __init__(self):
20
- self.vector_store = get_vector_store()
 
 
 
21
 
22
  async def retrieve(
23
  self,
24
  query: str,
25
  user_id: str,
26
  db: AsyncSession,
27
- k: int = 5
28
- ) -> List[Dict[str, Any]]:
29
- """Retrieve relevant chunks for a query, scoped to the user's documents.
30
-
31
- Returns:
32
- List of dicts with keys: content, metadata
33
- metadata includes: document_id, user_id, filename, chunk_index, page_label (if PDF)
34
- """
35
  try:
36
- redis = await get_redis()
37
- query_hash = hashlib.md5(query.encode()).hexdigest()
38
- cache_key = f"retrieval:{user_id}:{query_hash}:{k}"
39
-
40
- cached = await redis.get(cache_key)
41
- if cached:
42
- logger.info("Returning cached retrieval results")
43
- return json.loads(cached)
44
-
45
- logger.info(f"Retrieving for user {user_id}, query: {query[:50]}...")
46
-
47
- docs = await self.vector_store.asimilarity_search(
48
- query=query,
49
- k=k,
50
- filter={"user_id": user_id}
51
- )
52
-
53
- results = [
54
- {
55
- "content": doc.page_content,
56
- "metadata": doc.metadata,
57
- }
58
- for doc in docs
59
- ]
60
-
61
- logger.info(f"Retrieved {len(results)} chunks")
62
- await redis.setex(cache_key, _RETRIEVAL_CACHE_TTL, json.dumps(results))
63
- return results
64
-
65
  except Exception as e:
66
- logger.error("Retrieval failed", error=str(e))
67
  return []
68
 
69
 
 
1
+ """Public retrieval API thin wrapper around RetrievalRouter."""
2
+
3
+ from typing import Any
4
 
 
 
 
 
5
  from sqlalchemy.ext.asyncio import AsyncSession
6
+
7
  from src.middlewares.logging import get_logger
8
+ from src.rag.retrievers.document import document_retriever
9
+ from src.rag.retrievers.schema import schema_retriever
10
+ from src.rag.router import RetrievalRouter, SourceHint
11
 
12
  logger = get_logger("retriever")
13
 
 
 
14
 
15
class RetrieverService:
    """Public retrieval service used by chat.py and search tools.

    Delegates to RetrievalRouter which dispatches based on source_hint.
    Returns List[Dict] to preserve backward compatibility with chat.py.
    """

    def __init__(self):
        self._router = RetrievalRouter(
            schema_retriever=schema_retriever,
            document_retriever=document_retriever,
        )

    async def retrieve(
        self,
        query: str,
        user_id: str,
        db: AsyncSession,
        k: int = 5,
        source_hint: SourceHint = "both",
    ) -> list[dict[str, Any]]:
        """Retrieve chunks through the router and return them as plain dicts.

        ``db`` is accepted for signature compatibility and is not used here.
        Any failure is logged and reported as an empty list.

        Returns:
            One dict per result with the keys "content" and "metadata".
        """
        try:
            hits = await self._router.retrieve(query, user_id, source_hint, k)
            payload = []
            for hit in hits:
                payload.append({"content": hit.content, "metadata": hit.metadata})
            return payload
        except Exception as exc:
            logger.error("retrieval failed", error=str(exc))
            return []
42
 
43
 
src/rag/retrievers/__init__.py ADDED
File without changes
src/rag/retrievers/baseline.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Service for retrieving relevant documents from vector store."""
2
+
3
+ import hashlib
4
+ import json
5
+ from src.db.postgres.vector_store import get_vector_store
6
+ from src.db.redis.connection import get_redis
7
+ from sqlalchemy.ext.asyncio import AsyncSession
8
+ from src.middlewares.logging import get_logger
9
+ from typing import List, Dict, Any
10
+
11
+ logger = get_logger("retriever")
12
+
13
+ _RETRIEVAL_CACHE_TTL = 3600 # 1 hour
14
+
15
+
16
class RetrieverService:
    """Service for retrieving relevant documents."""

    def __init__(self):
        self.vector_store = get_vector_store()

    async def retrieve(
        self,
        query: str,
        user_id: str,
        db: AsyncSession,
        k: int = 5
    ) -> List[Dict[str, Any]]:
        """Retrieve relevant chunks for a query, scoped to the user's documents.

        Results are cached in Redis for ``_RETRIEVAL_CACHE_TTL`` seconds, keyed
        by user, query hash and ``k``. The cache is best-effort: a failure to
        read or write it is logged and the vector search result is still
        returned. ``db`` is accepted for signature compatibility and is not
        used here.

        Returns:
            List of dicts with keys: content, metadata
            metadata includes: document_id, user_id, filename, chunk_index, page_label (if PDF)
            An empty list when the retrieval itself fails.
        """
        try:
            # md5 is used only to derive a compact cache key, not for security.
            query_hash = hashlib.md5(query.encode()).hexdigest()
            cache_key = f"retrieval:{user_id}:{query_hash}:{k}"

            redis = None
            cached_results = None
            try:
                redis = await get_redis()
                cached = await redis.get(cache_key)
                if cached:
                    cached_results = json.loads(cached)
            except Exception as e:
                # Cache unavailable or entry unreadable: fall back to a live
                # search and skip the cache write for this call.
                logger.warning("Retrieval cache read failed", error=str(e))
                redis = None

            if cached_results is not None:
                logger.info("Returning cached retrieval results")
                return cached_results

            logger.info(f"Retrieving for user {user_id}, query: {query[:50]}...")

            docs = await self.vector_store.asimilarity_search(
                query=query,
                k=k,
                filter={"user_id": user_id}
            )

            results = [
                {
                    "content": doc.page_content,
                    "metadata": doc.metadata,
                }
                for doc in docs
            ]

            logger.info(f"Retrieved {len(results)} chunks")

            if redis is not None:
                try:
                    await redis.setex(cache_key, _RETRIEVAL_CACHE_TTL, json.dumps(results))
                except Exception as e:
                    # e.g. Redis error or metadata that is not JSON-serializable.
                    logger.warning("Retrieval cache write failed", error=str(e))
            return results

        except Exception as e:
            logger.error("Retrieval failed", error=str(e))
            return []


retriever = RetrieverService()
src/rag/retrievers/document.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Document retriever — handles PDF, DOCX, TXT chunks (source_type="document", non-tabular).
2
+
3
+ TEAMMATE: implement retrieve() below.
4
+ Strategy: MMR (amax_marginal_relevance_search) + score threshold to avoid returning
5
+ near-identical chunks from the same PDF page.
6
+ Filter: source_type="document" AND data->>'file_type' NOT IN ('csv', 'xlsx')
7
+ """
8
+
9
+ from src.db.postgres.vector_store import get_vector_store
10
+ from src.middlewares.logging import get_logger
11
+ from src.rag.base import BaseRetriever, RetrievalResult
12
+
13
+ logger = get_logger("document_retriever")
14
+
15
+ _SCORE_THRESHOLD = 0.45 # discard chunks with cosine distance above this
16
+
17
+
18
class DocumentRetriever(BaseRetriever):
    """Retriever for prose documents; currently a placeholder returning nothing."""

    def __init__(self):
        self.vector_store = get_vector_store()

    async def retrieve(
        self, query: str, user_id: str, k: int = 5
    ) -> list[RetrievalResult]:
        """Log that retrieval is not implemented yet and return an empty list."""
        # TODO (teammate): implement MMR retrieval for prose documents
        # Filter: {"user_id": user_id, "source_type": "document"}
        # then post-filter to exclude file_type in ("csv", "xlsx")
        found: list[RetrievalResult] = []
        logger.info("document retriever not yet implemented — returning empty")
        return found


document_retriever = DocumentRetriever()
src/rag/retrievers/schema.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Schema retriever — handles DB schemas (source_type="database") and tabular file
2
+ columns stored as source_type="document" with file_type in ("csv","xlsx").
3
+
4
+ Multiple retrieval strategies are exposed for benchmarking. The active strategy
5
+ used by the router is `retrieve()`, which dispatches to ACTIVE_STRATEGY.
6
+ Change ACTIVE_STRATEGY at module level to switch without touching the router.
7
+
8
+ All strategies embed the query exactly once, then fan out to parallel SQL legs.
9
+
10
+ Vector distance strategies:
11
+ dense_no_threshold — cosine (<=>), no score floor, always returns k chunks
12
+ dense_dot — inner product (<#>), equivalent to cosine for normalized embeddings
13
+ dense_l2 — L2/euclidean (<->), monotonic with cosine on unit-sphere vectors
14
+ hybrid — RRF merge of dense + FTS (database + tabular)
15
+ hybrid_bm25 — RRF merge of dense + FTS (database only)
16
+ """
17
+
18
+ import asyncio
19
+ import time
20
+ from typing import Literal
21
+
22
+ from sqlalchemy import text
23
+
24
+ from src.db.postgres.connection import _pgvector_engine
25
+ from src.db.postgres.vector_store import get_vector_store
26
+ from src.middlewares.logging import get_logger
27
+ from src.rag.base import BaseRetriever, RetrievalResult
28
+
29
+ logger = get_logger("schema_retriever")
30
+
31
+ _TABULAR_FILE_TYPES = ("csv", "xlsx")
32
+
33
+ Strategy = Literal["dense_no_threshold", "dense_dot", "dense_l2", "hybrid", "hybrid_bm25"]
34
+ ACTIVE_STRATEGY: Strategy = "hybrid_bm25"
35
+
36
+
37
class SchemaRetriever(BaseRetriever):
    """Retrieve schema chunks for databases and tabular files (csv/xlsx).

    Each named strategy embeds the query once and then runs its SQL legs
    concurrently with ``asyncio.gather``. ``retrieve`` dispatches to the
    strategy named by the module-level ``ACTIVE_STRATEGY``.
    """

    # Supported pgvector distance operators, each mapped to the SQL expression
    # that turns the raw distance into a "higher is better" score.
    _SCORE_SQL = {
        "<=>": "1.0 - (lpe.embedding <=> '{emb}'::vector)",
        "<#>": "(lpe.embedding <#> '{emb}'::vector) * -1",
        "<->": "1.0 / (1.0 + (lpe.embedding <-> '{emb}'::vector))",
    }

    # WHERE-clause fragments selecting which chunks a search leg covers.
    _DB_FILTER = "AND lpe.cmetadata->>'source_type' = 'database'"
    _TABULAR_FILTER = (
        "AND lpe.cmetadata->>'source_type' = 'document'\n"
        "              AND (lpe.cmetadata->'data'->>'file_type' = 'csv'\n"
        "                   OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')"
    )

    def __init__(self):
        self.vector_store = get_vector_store()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _result_key(result: RetrievalResult) -> tuple:
        """Return the identity used for dedup/RRF: (table_name, column_name or filename)."""
        # `or {}` also covers metadata (or its "data" entry) being explicitly None.
        data = (result.metadata or {}).get("data") or {}
        return (data.get("table_name"), data.get("column_name") or data.get("filename"))

    async def _embed_query(self, query: str) -> list[float]:
        """Embed ``query`` in a worker thread so the event loop is not blocked."""
        return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)

    async def _dense_search(
        self,
        embedding: list[float],
        user_id: str,
        limit: int,
        operator: str,
        source_filter: str,
        source_type: str,
    ) -> list[RetrievalResult]:
        """Run one vector-distance query and wrap the rows as RetrievalResults.

        Raises:
            ValueError: if ``operator`` is not one of the supported pgvector
                operators. The operator is interpolated into the SQL text, so
                it must come from the fixed whitelist.
        """
        score_template = self._SCORE_SQL.get(operator)
        if score_template is None:
            raise ValueError(f"unsupported vector operator: {operator!r}")

        # float() guarantees that only numeric text ends up inside the inlined
        # vector literal.
        emb_str = "[" + ",".join(str(float(x)) for x in embedding) + "]"
        score_sql = score_template.format(emb=emb_str)

        sql = text(f"""
            SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              {source_filter}
            ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
            LIMIT :k
        """)

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"user_id": user_id, "k": limit})
            rows = result.fetchall()

        return [
            RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.score),
                source_type=source_type,
            )
            for row in rows
        ]

    async def _fts_search(
        self, query: str, user_id: str, k: int, source_filter: str, source_type: str
    ) -> list[RetrievalResult]:
        """Run one PostgreSQL full-text query; the result score is ``ts_rank``."""
        sql = text(f"""
            SELECT lpe.document, lpe.cmetadata,
                   ts_rank(to_tsvector('english', lpe.document),
                           plainto_tsquery('english', :query)) AS rank
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              {source_filter}
              AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
            ORDER BY rank DESC
            LIMIT :k
        """)

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
            rows = result.fetchall()

        return [
            RetrievalResult(
                content=row.document,
                metadata=row.cmetadata,
                score=float(row.rank),
                source_type=source_type,
            )
            for row in rows
        ]

    async def _search_db(
        self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
    ) -> list[RetrievalResult]:
        """Vector search over database chunks. Accepts a pre-computed embedding.

        Fetches and returns up to ``k * 4`` rows so callers have spare
        candidates when they deduplicate.
        """
        return await self._dense_search(
            embedding, user_id, k * 4, operator, self._DB_FILTER, "database"
        )

    async def _search_tabular(
        self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
    ) -> list[RetrievalResult]:
        """Vector search over tabular document chunks. Accepts a pre-computed embedding.

        Fetches up to ``k * 4`` rows but returns only the closest ``k``.
        NOTE(review): unlike ``_search_db`` this truncates before the caller
        deduplicates — confirm whether the asymmetry is intended.
        """
        candidates = await self._dense_search(
            embedding, user_id, k * 4, operator, self._TABULAR_FILTER, "document"
        )
        return candidates[:k]

    async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
        """Full-text search over DB schema chunks using PostgreSQL tsvector.

        Requires GIN index on langchain_pg_embedding.document (created by init_db.py).
        """
        return await self._fts_search(query, user_id, k, self._DB_FILTER, "database")

    async def _search_fts_tabular(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
        """Full-text search over tabular document chunks using PostgreSQL tsvector."""
        return await self._fts_search(query, user_id, k, self._TABULAR_FILTER, "document")

    def _rrf_merge(
        self,
        *ranked_lists: list[RetrievalResult],
        k_rrf: int = 60,
        top_k: int = 5,
    ) -> list[RetrievalResult]:
        """Reciprocal Rank Fusion — combines ranked lists using rank positions only.

        Each appearance of a key at 0-based position ``rank`` contributes
        ``1 / (k_rrf + rank + 1)``. Only the ordering reflects the fused
        score: every returned result keeps its own original ``score`` (the
        highest-scoring instance seen for that key).
        """
        fused: dict[tuple, float] = {}
        best: dict[tuple, RetrievalResult] = {}

        for ranked in ranked_lists:
            for rank, result in enumerate(ranked):
                key = self._result_key(result)
                fused[key] = fused.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
                if key not in best or result.score > best[key].score:
                    best[key] = result

        # sorted() is stable, so ties keep first-seen order.
        ordered_keys = sorted(best, key=fused.__getitem__, reverse=True)
        return [best[key] for key in ordered_keys[:top_k]]

    def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
        """Deduplicate by (table_name, column_name or filename), keeping the highest score.

        Returns the survivors sorted by score, best first.
        """
        seen: dict[tuple, RetrievalResult] = {}
        for r in results:
            key = self._result_key(r)
            if key not in seen or r.score > seen[key].score:
                seen[key] = r
        return sorted(seen.values(), key=lambda r: r.score, reverse=True)

    async def _dense_strategy(
        self, query: str, user_id: str, k: int, operator: str
    ) -> list[RetrievalResult]:
        """Embed once, search database + tabular chunks in parallel, dedup, keep top k."""
        embedding = await self._embed_query(query)
        db_results, tabular_results = await asyncio.gather(
            self._search_db(embedding, user_id, k, operator),
            self._search_tabular(embedding, user_id, k, operator),
        )
        return self._dedup(db_results + tabular_results)[:k]

    # ------------------------------------------------------------------
    # Named strategies — one embed call each, legs run in parallel
    # ------------------------------------------------------------------

    async def dense_no_threshold(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Cosine similarity (<=>), no score cutoff — returns up to k chunks."""
        return await self._dense_strategy(query, user_id, k, "<=>")

    async def dense_dot(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Inner product similarity (<#>).

        For L2-normalized embeddings (OpenAI), ranking is identical to cosine.
        Score = raw inner product (not bounded to [0,1]).
        """
        return await self._dense_strategy(query, user_id, k, "<#>")

    async def dense_l2(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """L2 (Euclidean) distance similarity (<->).

        For L2-normalized embeddings (OpenAI), ranking order matches cosine.
        Score = 1 / (1 + l2_distance), bounded to (0, 1].
        """
        return await self._dense_strategy(query, user_id, k, "<->")

    async def hybrid(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """RRF merge of dense + FTS over both database and tabular sources.

        Embeds once, then runs all four legs (dense db, dense tabular, fts db,
        fts tabular) in a single asyncio.gather.
        """
        embedding = await self._embed_query(query)
        db_results, tabular_results, fts_db, fts_tabular = await asyncio.gather(
            self._search_db(embedding, user_id, k),
            self._search_tabular(embedding, user_id, k),
            self._search_fts_db(query, user_id, k * 4),
            self._search_fts_tabular(query, user_id, k * 4),
        )
        dense = self._dedup(db_results + tabular_results)[:k]
        fts_all = self._dedup(fts_db + fts_tabular)
        return self._rrf_merge(dense, fts_all, top_k=k)

    async def hybrid_bm25(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """RRF merge of dense (database + tabular) with FTS over database chunks only.

        Embeds once, then runs dense db, dense tabular, and fts db legs in parallel.
        """
        embedding = await self._embed_query(query)
        db_results, tabular_results, fts_results = await asyncio.gather(
            self._search_db(embedding, user_id, k),
            self._search_tabular(embedding, user_id, k),
            self._search_fts_db(query, user_id, k * 4),
        )
        dense = self._dedup(db_results + tabular_results)[:k]
        return self._rrf_merge(dense, self._dedup(fts_results), top_k=k)

    # ------------------------------------------------------------------
    # Public interface — called by the router
    # ------------------------------------------------------------------

    async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
        """Run the strategy named by ``ACTIVE_STRATEGY`` and log the result count."""
        strategy_fn = getattr(self, ACTIVE_STRATEGY)
        results = await strategy_fn(query, user_id, k)
        logger.info("schema retrieval", strategy=ACTIVE_STRATEGY, count=len(results))
        return results
309
+
310
+
311
+ # ------------------------------------------------------------------
312
+ # Benchmark helper — import in test scripts
313
+ # ------------------------------------------------------------------
314
+
315
async def benchmark(
    query: str,
    user_id: str,
    k: int = 5,
    strategies: list[Strategy] | None = None,
) -> dict[str, dict]:
    """Time each requested strategy on one query and collect its output.

    Strategies run one after another on a fresh ``SchemaRetriever``. When
    ``strategies`` is ``None`` or empty, all five named strategies are run.

    Returns:
        A dict keyed by strategy name. Each value holds ``chunks`` (result
        count), ``estimated_tokens`` (total content characters // 4),
        ``elapsed_ms`` (wall-clock time, rounded) and ``results``.
    """
    every_strategy: list[Strategy] = [
        "dense_no_threshold",
        "dense_dot",
        "dense_l2",
        "hybrid",
        "hybrid_bm25",
    ]
    retriever = SchemaRetriever()
    targets: list[Strategy] = strategies if strategies else every_strategy

    report: dict[str, dict] = {}
    for name in targets:
        run_strategy = getattr(retriever, name)

        started = time.perf_counter()
        chunks = await run_strategy(query, user_id, k)
        finished = time.perf_counter()

        char_count = 0
        for chunk in chunks:
            char_count += len(chunk.content)

        report[name] = {
            "chunks": len(chunks),
            "estimated_tokens": char_count // 4,
            "elapsed_ms": round((finished - started) * 1000),
            "results": chunks,
        }

    return report
347
+
348
+
349
+ schema_retriever = SchemaRetriever()
src/rag/router.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Routes retrieval requests to the appropriate retriever based on source_hint."""
2
+
3
+ import asyncio
4
+ import hashlib
5
+ import json
6
+ from typing import Literal
7
+
8
+ from src.db.redis.connection import get_redis
9
+ from src.middlewares.logging import get_logger
10
+ from src.rag.base import BaseRetriever, RetrievalResult
11
+
12
+ logger = get_logger("retrieval_router")
13
+
14
+ _CACHE_TTL = 3600 # 1 hour
15
+ SourceHint = Literal["document", "schema", "both"]
16
+
17
+
18
class RetrievalRouter:
    """Fan a retrieval request out to the retrievers selected by ``source_hint``.

    Results are cached in Redis per (user, source_hint, query hash, k) for
    ``_CACHE_TTL`` seconds. A response is only cached when every selected
    retriever succeeded, so a transient failure cannot pin an empty or
    partial answer in the cache.
    """

    def __init__(
        self,
        schema_retriever: BaseRetriever,
        document_retriever: BaseRetriever,
    ):
        self._retrievers: dict[str, BaseRetriever] = {
            "schema": schema_retriever,
            "document": document_retriever,
        }

    def _route(self, source_hint: SourceHint) -> list[BaseRetriever]:
        """Return the retrievers for ``source_hint``; any other value selects all."""
        if source_hint == "schema":
            return [self._retrievers["schema"]]
        if source_hint == "document":
            return [self._retrievers["document"]]
        return list(self._retrievers.values())

    @staticmethod
    def _merge_batches(batches, k: int) -> tuple[list[RetrievalResult], list[BaseException]]:
        """Split gathered outcomes into the top-``k`` results and the failures.

        Args:
            batches: Outcomes of ``asyncio.gather(..., return_exceptions=True)``;
                each item is either a list of results or an exception instance.
            k: Maximum number of results to keep.

        Returns:
            ``(results, failures)`` where ``results`` is sorted by ``score``
            descending and truncated to ``k``.
        """
        results: list[RetrievalResult] = []
        failures: list[BaseException] = []
        for batch in batches:
            # gather() can hand back BaseException subclasses such as
            # CancelledError, so test against BaseException.
            if isinstance(batch, BaseException):
                failures.append(batch)
            else:
                results.extend(batch)
        results.sort(key=lambda r: r.score, reverse=True)
        return results[:k], failures

    async def retrieve(
        self,
        query: str,
        user_id: str,
        source_hint: SourceHint = "both",
        k: int = 10,
    ) -> list[RetrievalResult]:
        """Return up to ``k`` results for ``query``, served from cache when possible.

        A failing retriever is logged and skipped; the remaining results are
        still returned, but not cached.
        """
        redis = await get_redis()
        # md5 only shortens the query for the cache key; it is not used for security.
        query_hash = hashlib.md5(query.encode()).hexdigest()
        cache_key = f"retrieval:{user_id}:{source_hint}:{query_hash}:{k}"

        cached = await redis.get(cache_key)
        if cached:
            logger.info("returning cached retrieval results", source_hint=source_hint)
            return [RetrievalResult(**r) for r in json.loads(cached)]

        retrievers = self._route(source_hint)
        batches = await asyncio.gather(
            *(r.retrieve(query, user_id, k) for r in retrievers),
            return_exceptions=True,
        )

        results, failures = self._merge_batches(batches, k)
        for failure in failures:
            logger.error("retriever failed", error=str(failure))

        logger.info("retrieved chunks", count=len(results), source_hint=source_hint)
        if failures:
            # Partial (possibly empty) answer: let the next request retry
            # instead of serving it from cache for the whole TTL.
            return results

        await redis.setex(
            cache_key,
            _CACHE_TTL,
            json.dumps([vars(r) for r in results]),
        )
        return results