sofhiaazzhr committed on
Commit
b272cc7
·
1 Parent(s): 5670888

[KM-556] delete Phase 1 remnants: query/executors/, query_executor.py, orchestration.py

Browse files
PROGRESS.md CHANGED
@@ -2,7 +2,7 @@
2
 
3
  Persistent tracker mirroring the 42-item ownership table in `REPO_CONTEXT.md` "Team — division of work". Update as PRs land. Future Claude Code sessions read this to know what's already done.
4
 
5
- **Last updated**: 2026-05-08 (item 41 done; item 16 done; item 31 done; item 35 done; item 36 done chat endpoint rewired to Phase 2 QueryService)
6
  **Current open PR**: none — all Phase 2 contracts shipped on `pr/1`. Cleanup PR pending (API rewiring + Phase 1 removal).
7
 
8
  ---
 
2
 
3
  Persistent tracker mirroring the 42-item ownership table in `REPO_CONTEXT.md` "Team — division of work". Update as PRs land. Future Claude Code sessions read this to know what's already done.
4
 
5
+ **Last updated**: 2026-05-08 (items 16, 31, 35, 36, 41 done; Phase 1 remnants deleted: query/executors/, query_executor.py, agents/orchestration.py)
6
  **Current open PR**: none — all Phase 2 contracts shipped on `pr/1`. Cleanup PR pending (API rewiring + Phase 1 removal).
7
 
8
  ---
src/agents/orchestration.py DELETED
@@ -1,79 +0,0 @@
1
- """Orchestrator agent for intent recognition and planning."""
2
-
3
- from langchain_openai import AzureChatOpenAI
4
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
5
- from src.config.settings import settings
6
- from src.middlewares.logging import get_logger
7
- from src.models.structured_output import IntentClassification
8
-
9
- logger = get_logger("orchestrator")
10
-
11
-
12
- class OrchestratorAgent:
13
- """Orchestrator agent for intent recognition and planning."""
14
-
15
- def __init__(self):
16
- self.llm = AzureChatOpenAI(
17
- azure_deployment=settings.azureai_deployment_name_4o,
18
- openai_api_version=settings.azureai_api_version_4o,
19
- azure_endpoint=settings.azureai_endpoint_url_4o,
20
- api_key=settings.azureai_api_key_4o,
21
- temperature=0
22
- )
23
-
24
- self.prompt = ChatPromptTemplate.from_messages([
25
- ("system", """You are an orchestrator agent. You receive recent conversation history and the user's latest message.
26
-
27
- Your task:
28
- 1. Determine intent: question, greeting, goodbye, or other
29
- 2. Decide whether to search the user's documents (needs_search)
30
- 3. If search is needed, rewrite the user's message into a STANDALONE search query that incorporates necessary context from conversation history. If the user says "tell me more" or "how many papers?", the search_query must spell out the full topic explicitly from history.
31
- 4. If no search needed, provide a short direct_response (plain text only, no markdown formatting).
32
-
33
- Intent Routing:
34
- - question -> needs_search=True, search_query=<standalone rewritten query>
35
- - greeting -> needs_search=False, direct_response="Hello! How can I assist you today?"
36
- - goodbye -> needs_search=False, direct_response="Goodbye! Have a great day!"
37
- - other -> needs_search=True, search_query=<standalone rewritten query>
38
-
39
- Source Routing (set source_hint):
40
- - Columns, tables, sheets, data types, schema, row counts, statistics -> source_hint=schema
41
- - Document content, paragraphs, reports, articles, text -> source_hint=document
42
- - Unclear or spans both -> source_hint=both
43
- """),
44
- MessagesPlaceholder(variable_name="history"),
45
- ("user", "{message}")
46
- ])
47
-
48
- # with_structured_output uses function calling — guarantees valid schema regardless of LLM response style
49
- self.chain = self.prompt | self.llm.with_structured_output(IntentClassification)
50
-
51
- async def analyze_message(self, message: str, history: list = None) -> dict:
52
- """Analyze user message and determine next actions.
53
-
54
- Args:
55
- message: The current user message.
56
- history: Recent conversation as LangChain BaseMessage objects (oldest-first).
57
- Used to rewrite ambiguous follow-ups into standalone search queries.
58
- """
59
- try:
60
- logger.info(f"Analyzing message: {message[:50]}...")
61
-
62
- history_messages = history or []
63
- result: IntentClassification = await self.chain.ainvoke({"message": message, "history": history_messages})
64
-
65
- logger.info(f"Intent: {result.intent}, Needs search: {result.needs_search}, Search query: {result.search_query[:50] if result.search_query else ''}")
66
- return result.model_dump()
67
-
68
- except Exception as e:
69
- logger.error("Message analysis failed", error=str(e))
70
- # Fallback to treating everything as a question
71
- return {
72
- "intent": "question",
73
- "needs_search": True,
74
- "search_query": message,
75
- "direct_response": None
76
- }
77
-
78
-
79
- orchestrator = OrchestratorAgent()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/query/executors/__init__.py DELETED
File without changes
src/query/executors/db_executor.py DELETED
@@ -1,648 +0,0 @@
1
- """Executor for registered database sources (source_type="database").
2
-
3
- Flow per (client_id, question):
4
- 1. Collect all relevant (table_name, column_name) pairs from retrieval results.
5
- 2. Fetch the FULL schema for those tables from PGVector (not just top-k columns).
6
- 3. Build a schema context string and send to LLM → structured SQLQuery output.
7
- 4. Validate via sqlglot: SELECT-only, schema-grounded, LIMIT enforced.
8
- 5. Execute on the user's DB via engine_scope + asyncio.to_thread.
9
- 6. Return QueryResult per client_id (may span multiple tables via JOINs).
10
-
11
- Supported db_types: postgres, supabase, mysql.
12
- Other types are skipped with a warning — they do not raise.
13
- """
14
-
15
- import asyncio
16
- from collections import defaultdict
17
- from typing import Any
18
-
19
- import sqlglot
20
- import sqlglot.expressions as exp
21
- import tiktoken
22
- from langchain_core.prompts import ChatPromptTemplate
23
- from langchain_openai import AzureChatOpenAI
24
- from sqlalchemy import text
25
- from sqlalchemy.ext.asyncio import AsyncSession
26
-
27
- from src.config.settings import settings
28
- from src.database_client.database_client_service import database_client_service
29
- from src.db.postgres.connection import _pgvector_engine
30
- from src.middlewares.logging import get_logger
31
- from src.models.sql_query import SQLQuery
32
- from src.pipeline.db_pipeline import db_pipeline_service
33
- from src.query.base import BaseExecutor, QueryResult
34
- from src.retrieval.base import RetrievalResult
35
- from src.utils.db_credential_encryption import decrypt_credentials_dict
36
-
37
- logger = get_logger("db_executor")
38
-
39
- _enc = tiktoken.get_encoding("cl100k_base")
40
-
41
- _SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"}
42
- _MAX_RETRIES = 3
43
- _MAX_LIMIT = 500
44
- _FK_EXPANSION_MAX_TABLES = 5
45
-
46
- _SQL_SYSTEM_PROMPT = """\
47
- You are a SQL data analyst working with a user's database.
48
- Generate a single SQL SELECT statement that answers the user's question.
49
-
50
- Database dialect: {dialect}
51
-
52
- Rules:
53
- - ONLY reference tables and columns listed in the schema below. Do not invent names.
54
- - Always include a LIMIT clause (max {limit}).
55
- - Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
56
- - Prefer explicit JOINs over subqueries when combining tables.
57
- - For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
58
- - For date filtering, use dialect-appropriate functions ({dialect} syntax).
59
-
60
- Schema:
61
- {schema}
62
-
63
- {error_section}"""
64
-
65
-
66
- class DbExecutor(BaseExecutor):
67
- def __init__(self) -> None:
68
- self._llm = AzureChatOpenAI(
69
- azure_deployment=settings.azureai_deployment_name_4o,
70
- openai_api_version=settings.azureai_api_version_4o,
71
- azure_endpoint=settings.azureai_endpoint_url_4o,
72
- api_key=settings.azureai_api_key_4o,
73
- temperature=0,
74
- )
75
- self._prompt = ChatPromptTemplate.from_messages([
76
- ("system", _SQL_SYSTEM_PROMPT),
77
- ("human", "{question}"),
78
- ])
79
- self._chain = self._prompt | self._llm.with_structured_output(SQLQuery)
80
-
81
- # ------------------------------------------------------------------
82
- # Public interface
83
- # ------------------------------------------------------------------
84
-
85
- async def execute(
86
- self,
87
- results: list[RetrievalResult],
88
- user_id: str,
89
- db: AsyncSession,
90
- question: str,
91
- limit: int = 100,
92
- ) -> list[QueryResult]:
93
- db_results = [r for r in results if r.source_type == "database"]
94
- if not db_results:
95
- return []
96
-
97
- # Group by client_id — one SQL generation + execution pass per client
98
- by_client: dict[str, list[RetrievalResult]] = defaultdict(list)
99
- for r in db_results:
100
- client_id = r.metadata.get("database_client_id", "")
101
- if client_id:
102
- by_client[client_id].append(r)
103
- else:
104
- logger.warning("db result missing database_client_id, skipping")
105
-
106
- query_results: list[QueryResult] = []
107
- for client_id, client_results in by_client.items():
108
- try:
109
- qr = await self._execute_for_client(client_id, client_results, user_id, db, question, limit)
110
- if qr:
111
- query_results.append(qr)
112
- except Exception as e:
113
- logger.error("db executor failed for client", client_id=client_id, error=str(e))
114
-
115
- return query_results
116
-
117
- # ------------------------------------------------------------------
118
- # Per-client execution
119
- # ------------------------------------------------------------------
120
-
121
- async def _execute_for_client(
122
- self,
123
- client_id: str,
124
- results: list[RetrievalResult],
125
- user_id: str,
126
- db: AsyncSession,
127
- question: str,
128
- limit: int,
129
- ) -> QueryResult | None:
130
- client = await database_client_service.get(db, client_id)
131
- if not client:
132
- logger.warning("database client not found", client_id=client_id)
133
- return None
134
- if client.user_id != user_id:
135
- logger.warning("client ownership mismatch", client_id=client_id)
136
- return None
137
- if client.db_type not in _SUPPORTED_DB_TYPES:
138
- logger.warning("unsupported db_type for query execution", db_type=client.db_type)
139
- return None
140
-
141
- # Hit tables = tables retrieval pointed at directly. Get full per-column
142
- # schema for these. Related tables (one FK hop away, both directions) are
143
- # fetched separately in abbreviated form to give the LLM enough context
144
- # to JOIN without paying the per-column profile token cost.
145
- hit_tables = list({
146
- r.metadata.get("data", {}).get("table_name")
147
- for r in results
148
- if r.metadata.get("data", {}).get("table_name")
149
- })
150
- if not hit_tables:
151
- logger.warning("no table_name on any retrieval result", client_id=client_id)
152
- return None
153
-
154
- full_schema = await self._fetch_full_schema(client_id, hit_tables, user_id)
155
- if not full_schema:
156
- logger.warning("no schema found in vector store", client_id=client_id, tables=hit_tables)
157
- return None
158
-
159
- related_tables = await self._find_related_tables(client_id, user_id, hit_tables)
160
- related_schema = (
161
- await self._fetch_abbreviated_schema(client_id, user_id, related_tables)
162
- if related_tables else {}
163
- )
164
-
165
- schema_ctx = self._build_schema_context(full_schema, related_schema)
166
- capped_limit = min(limit, _MAX_LIMIT)
167
- dialect = client.db_type
168
-
169
- # SQL generation with retry
170
- validated_sql: str | None = None
171
- prev_error: str = ""
172
- prev_reasoning: str = ""
173
- for attempt in range(_MAX_RETRIES):
174
- if prev_error:
175
- error_section = (
176
- f"Previous attempt reasoning: {prev_reasoning}\n"
177
- f"Previous attempt failed: {prev_error}\n"
178
- "Fix the issue above."
179
- )
180
- else:
181
- error_section = ""
182
- try:
183
- prompt_text = schema_ctx + error_section + question
184
- input_tokens = len(_enc.encode(prompt_text))
185
- logger.info("sql generation input tokens", attempt=attempt + 1, tokens=input_tokens)
186
-
187
- result: SQLQuery = await self._chain.ainvoke({
188
- "schema": schema_ctx,
189
- "dialect": dialect,
190
- "limit": capped_limit,
191
- "error_section": error_section,
192
- "question": question,
193
- })
194
- sql = result.sql.strip()
195
- allowed_tables = set(full_schema) | set(related_schema)
196
- column_map: dict[str, set[str]] = {
197
- t: {c["name"] for c in cols} for t, cols in full_schema.items()
198
- }
199
- for t, info in related_schema.items():
200
- column_map[t] = set(info.get("column_names") or [])
201
- validation_error = self._validate(sql, allowed_tables, capped_limit, column_map)
202
- if validation_error:
203
- prev_error = validation_error
204
- prev_reasoning = result.reasoning
205
- logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
206
- continue
207
- validated_sql = self._enforce_limit(sql, capped_limit)
208
- output_tokens = len(_enc.encode(result.sql)) + len(_enc.encode(result.reasoning))
209
- logger.info(
210
- "sql generated",
211
- attempt=attempt + 1,
212
- input_tokens=input_tokens,
213
- output_tokens=output_tokens,
214
- total_tokens=input_tokens + output_tokens,
215
- reasoning=result.reasoning,
216
- )
217
- break
218
- except Exception as e:
219
- prev_error = str(e)
220
- logger.warning("sql generation error", attempt=attempt + 1, error=prev_error)
221
-
222
- if not validated_sql:
223
- logger.error("sql generation failed after retries", client_id=client_id)
224
- return None
225
-
226
- # Execute on user's DB
227
- creds = decrypt_credentials_dict(client.credentials)
228
- with db_pipeline_service.engine_scope(client.db_type, creds) as engine:
229
- rows = await asyncio.to_thread(self._run_sql, engine, validated_sql)
230
-
231
- column_types = {
232
- col["name"]: col["type"]
233
- for cols in full_schema.values()
234
- for col in cols
235
- }
236
- columns = list(rows[0].keys()) if rows else []
237
-
238
- return QueryResult(
239
- source_type="database",
240
- source_id=client_id,
241
- table_or_file=", ".join(hit_tables),
242
- columns=columns,
243
- rows=rows,
244
- row_count=len(rows),
245
- metadata={
246
- "db_type": client.db_type,
247
- "client_name": client.name,
248
- "sql": validated_sql,
249
- "column_types": {c: column_types.get(c, "unknown") for c in columns},
250
- },
251
- )
252
-
253
- # ------------------------------------------------------------------
254
- # Schema helpers
255
- # ------------------------------------------------------------------
256
-
257
- async def _find_related_tables(
258
- self,
259
- client_id: str,
260
- user_id: str,
261
- hit_tables: list[str],
262
- ) -> list[str]:
263
- """One-hop FK neighbours of `hit_tables`, both directions, excluding hits.
264
-
265
- Prefers chunk_level='table' rows; if none exist for the client (legacy
266
- ingest predating Phase 1), falls back to aggregating from column-chunk
267
- metadata. Returns [] when no FK metadata is available.
268
-
269
- Capped at _FK_EXPANSION_MAX_TABLES, ranked by edge count desc then
270
- table name asc. A warning is logged when the cap kicks in.
271
- """
272
- if not hit_tables:
273
- return []
274
-
275
- hit_set = set(hit_tables)
276
- # edge_counts[related_table] = number of FK edges connecting it to the hit set
277
- edge_counts: dict[str, int] = defaultdict(int)
278
-
279
- # ---- Primary path: table-level chunks ----
280
- sql = text("""
281
- SELECT lpe.cmetadata
282
- FROM langchain_pg_embedding lpe
283
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
284
- WHERE lpc.name = 'document_embeddings'
285
- AND lpe.cmetadata->>'user_id' = :user_id
286
- AND lpe.cmetadata->>'source_type' = 'database'
287
- AND lpe.cmetadata->>'database_client_id' = :client_id
288
- AND lpe.cmetadata->>'chunk_level' = 'table'
289
- """)
290
- async with _pgvector_engine.connect() as conn:
291
- result = await conn.execute(sql, {"user_id": user_id, "client_id": client_id})
292
- table_rows = result.fetchall()
293
-
294
- if table_rows:
295
- for row in table_rows:
296
- data = row.cmetadata.get("data", {})
297
- table = data.get("table_name")
298
- fks = data.get("foreign_keys") or []
299
- if not table:
300
- continue
301
- if table in hit_set:
302
- # Outgoing: this hit's FKs point at related tables
303
- for fk in fks:
304
- target = fk.get("target_table")
305
- if target and target not in hit_set:
306
- edge_counts[target] += 1
307
- else:
308
- # Incoming: this non-hit table's FKs point into the hit set
309
- for fk in fks:
310
- target = fk.get("target_table")
311
- if target in hit_set:
312
- edge_counts[table] += 1
313
- else:
314
- # ---- Fallback: aggregate from column chunks ----
315
- sql = text("""
316
- SELECT lpe.cmetadata->'data'->>'table_name' AS src_table,
317
- lpe.cmetadata->'data'->>'foreign_key' AS fk
318
- FROM langchain_pg_embedding lpe
319
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
320
- WHERE lpc.name = 'document_embeddings'
321
- AND lpe.cmetadata->>'user_id' = :user_id
322
- AND lpe.cmetadata->>'source_type' = 'database'
323
- AND lpe.cmetadata->>'database_client_id' = :client_id
324
- AND lpe.cmetadata->>'chunk_level' = 'column'
325
- AND lpe.cmetadata->'data'->>'foreign_key' IS NOT NULL
326
- """)
327
- async with _pgvector_engine.connect() as conn:
328
- result = await conn.execute(sql, {"user_id": user_id, "client_id": client_id})
329
- col_rows = result.fetchall()
330
-
331
- for row in col_rows:
332
- src = row.src_table
333
- fk = row.fk
334
- if not src or not fk:
335
- continue
336
- target = fk.split(".", 1)[0]
337
- if src in hit_set and target and target not in hit_set:
338
- edge_counts[target] += 1
339
- elif src not in hit_set and target in hit_set:
340
- edge_counts[src] += 1
341
-
342
- if not edge_counts:
343
- return []
344
-
345
- ranked = sorted(edge_counts.items(), key=lambda kv: (-kv[1], kv[0]))
346
- if len(ranked) > _FK_EXPANSION_MAX_TABLES:
347
- logger.warning(
348
- "fk expansion cap hit",
349
- client_id=client_id,
350
- total=len(ranked),
351
- cap=_FK_EXPANSION_MAX_TABLES,
352
- dropped=[t for t, _ in ranked[_FK_EXPANSION_MAX_TABLES:]],
353
- )
354
- ranked = ranked[:_FK_EXPANSION_MAX_TABLES]
355
-
356
- related = [t for t, _ in ranked]
357
- logger.info("fk-related tables", hit=sorted(hit_set), related=related)
358
- return related
359
-
360
- async def _fetch_abbreviated_schema(
361
- self,
362
- client_id: str,
363
- user_id: str,
364
- table_names: list[str],
365
- ) -> dict[str, dict[str, Any]]:
366
- """Abbreviated schema: name, row_count, PK, FKs, column names — no profiles.
367
-
368
- Prefers chunk_level='table' rows. Falls back to aggregating column-chunk
369
- metadata when table chunks are missing for a given table_name.
370
-
371
- Returns {table_name: {"row_count": int|None, "primary_key": [str],
372
- "foreign_keys": [{column, target_table, target_column}],
373
- "column_names": [str]}}.
374
- """
375
- if not table_names:
376
- return {}
377
-
378
- placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
379
- params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
380
- for i, name in enumerate(table_names):
381
- params[f"t{i}"] = name
382
-
383
- # Primary path: one row per table from chunk_level='table'
384
- sql_table = text(f"""
385
- SELECT lpe.cmetadata
386
- FROM langchain_pg_embedding lpe
387
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
388
- WHERE lpc.name = 'document_embeddings'
389
- AND lpe.cmetadata->>'user_id' = :user_id
390
- AND lpe.cmetadata->>'source_type' = 'database'
391
- AND lpe.cmetadata->>'database_client_id' = :client_id
392
- AND lpe.cmetadata->>'chunk_level' = 'table'
393
- AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
394
- """)
395
- async with _pgvector_engine.connect() as conn:
396
- result = await conn.execute(sql_table, params)
397
- t_rows = result.fetchall()
398
-
399
- out: dict[str, dict[str, Any]] = {}
400
- for row in t_rows:
401
- data = row.cmetadata.get("data", {})
402
- tname = data.get("table_name")
403
- if not tname:
404
- continue
405
- out[tname] = {
406
- "row_count": data.get("row_count"),
407
- "primary_key": list(data.get("primary_key") or []),
408
- "foreign_keys": list(data.get("foreign_keys") or []),
409
- "column_names": list(data.get("column_names") or []),
410
- }
411
-
412
- # Fallback for tables with no table-chunk: aggregate column chunks
413
- missing = [t for t in table_names if t not in out]
414
- if missing:
415
- placeholders_m = ", ".join(f":m{i}" for i in range(len(missing)))
416
- params_m: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
417
- for i, name in enumerate(missing):
418
- params_m[f"m{i}"] = name
419
- sql_col = text(f"""
420
- SELECT lpe.cmetadata
421
- FROM langchain_pg_embedding lpe
422
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
423
- WHERE lpc.name = 'document_embeddings'
424
- AND lpe.cmetadata->>'user_id' = :user_id
425
- AND lpe.cmetadata->>'source_type' = 'database'
426
- AND lpe.cmetadata->>'database_client_id' = :client_id
427
- AND lpe.cmetadata->>'chunk_level' = 'column'
428
- AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders_m})
429
- ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name'
430
- """)
431
- async with _pgvector_engine.connect() as conn:
432
- result = await conn.execute(sql_col, params_m)
433
- c_rows = result.fetchall()
434
-
435
- agg: dict[str, dict[str, Any]] = {
436
- t: {"row_count": None, "primary_key": [], "foreign_keys": [], "column_names": []}
437
- for t in missing
438
- }
439
- for row in c_rows:
440
- data = row.cmetadata.get("data", {})
441
- tname = data.get("table_name")
442
- cname = data.get("column_name")
443
- if not tname or tname not in agg or not cname:
444
- continue
445
- bucket = agg[tname]
446
- bucket["column_names"].append(cname)
447
- if data.get("is_primary_key"):
448
- bucket["primary_key"].append(cname)
449
- fk = data.get("foreign_key")
450
- if fk:
451
- target_table, _, target_col = fk.partition(".")
452
- bucket["foreign_keys"].append({
453
- "column": cname,
454
- "target_table": target_table,
455
- "target_column": target_col,
456
- })
457
- for t, v in agg.items():
458
- if v["column_names"]:
459
- out[t] = v
460
-
461
- return out
462
-
463
- async def _fetch_full_schema(
464
- self,
465
- client_id: str,
466
- table_names: list[str],
467
- user_id: str,
468
- ) -> dict[str, list[dict[str, Any]]]:
469
- """Fetch ALL column chunks for the given tables from PGVector.
470
-
471
- Returns {table_name: [{"name": ..., "type": ..., "is_primary_key": ...,
472
- "foreign_key": ..., "content": ...}]}
473
- """
474
- placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
475
- sql = text(f"""
476
- SELECT lpe.cmetadata, lpe.document
477
- FROM langchain_pg_embedding lpe
478
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
479
- WHERE lpc.name = 'document_embeddings'
480
- AND lpe.cmetadata->>'user_id' = :user_id
481
- AND lpe.cmetadata->>'source_type' = 'database'
482
- AND lpe.cmetadata->>'chunk_level' = 'column'
483
- AND lpe.cmetadata->>'database_client_id' = :client_id
484
- AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
485
- ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name'
486
- """)
487
-
488
- params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
489
- for i, name in enumerate(table_names):
490
- params[f"t{i}"] = name
491
-
492
- async with _pgvector_engine.connect() as conn:
493
- result = await conn.execute(sql, params)
494
- rows = result.fetchall()
495
-
496
- schema: dict[str, list[dict[str, Any]]] = defaultdict(list)
497
- for row in rows:
498
- data = row.cmetadata.get("data", {})
499
- table = data.get("table_name")
500
- if table:
501
- schema[table].append({
502
- "name": data.get("column_name", ""),
503
- "type": data.get("column_type", ""),
504
- "is_primary_key": data.get("is_primary_key", False),
505
- "foreign_key": data.get("foreign_key"),
506
- "content": row.document, # chunk text includes top values / samples
507
- })
508
- return dict(schema)
509
-
510
- def _build_schema_context(
511
- self,
512
- schema: dict[str, list[dict[str, Any]]],
513
- related_schema: dict[str, dict[str, Any]] | None = None,
514
- ) -> str:
515
- lines: list[str] = []
516
- for table, columns in schema.items():
517
- lines.append(f"Table: {table}")
518
- for col in columns:
519
- flags = []
520
- if col["is_primary_key"]:
521
- flags.append("PRIMARY KEY")
522
- if col["foreign_key"]:
523
- flags.append(f"FK -> {col['foreign_key']}")
524
- flag_str = f" [{', '.join(flags)}]" if flags else ""
525
- lines.append(f" - {col['name']} {col['type']}{flag_str}")
526
- # Include sample/top-values line from chunk content if present
527
- for line in col["content"].splitlines():
528
- if line.startswith(("Top values:", "Sample values:")):
529
- lines.append(f" {line}")
530
- break
531
- lines.append("")
532
-
533
- related_block = self._build_related_schema_block(related_schema or {})
534
- if related_block:
535
- lines.append(related_block)
536
-
537
- return "\n".join(lines).strip()
538
-
539
- def _build_related_schema_block(self, related_schema: dict[str, dict[str, Any]]) -> str:
540
- """Format the abbreviated FK-related-tables section. Empty string when no related."""
541
- if not related_schema:
542
- return ""
543
- lines: list[str] = ["Related tables (one hop via FK, abbreviated — use for JOINs only):"]
544
- for table, info in related_schema.items():
545
- row_count = info.get("row_count")
546
- header = f"- {table} ({row_count} rows)" if row_count is not None else f"- {table}"
547
- lines.append(header)
548
- pk = info.get("primary_key") or []
549
- lines.append(f" Primary key: {', '.join(pk) if pk else '(none)'}")
550
- fks = info.get("foreign_keys") or []
551
- if fks:
552
- fk_strs = [
553
- f"{fk.get('column')} -> {fk.get('target_table')}.{fk.get('target_column')}"
554
- for fk in fks
555
- ]
556
- lines.append(f" Foreign keys: {', '.join(fk_strs)}")
557
- else:
558
- lines.append(" Foreign keys: (none)")
559
- cols = info.get("column_names") or []
560
- lines.append(f" Columns: {', '.join(cols)}")
561
- return "\n".join(lines)
562
-
563
- # ------------------------------------------------------------------
564
- # Guardrails
565
- # ------------------------------------------------------------------
566
-
567
- def _validate(
568
- self,
569
- sql: str,
570
- allowed_tables: set[str],
571
- limit: int,
572
- column_map: dict[str, set[str]] | None = None,
573
- ) -> str:
574
- """Return an error string if validation fails, empty string if OK.
575
-
576
- `allowed_tables` is the union of hit-table names and FK-related table
577
- names — both are legal targets for SELECT/JOIN.
578
-
579
- `column_map` maps table_name → set of valid column names. When provided,
580
- any qualified table.column reference not found in the map triggers a retry
581
- with an informative error so the LLM can self-correct without hallucinating.
582
- """
583
- # Layer 1: sqlglot parse + SELECT-only check
584
- try:
585
- parsed = sqlglot.parse_one(sql)
586
- except sqlglot.errors.ParseError as e:
587
- return f"SQL parse error: {e}"
588
-
589
- if not isinstance(parsed, exp.Select):
590
- return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"
591
-
592
- # Check for DML anywhere in the AST (including writeable CTEs)
593
- for node in parsed.find_all((exp.Insert, exp.Update, exp.Delete)):
594
- return f"DML ({type(node).__name__}) is not allowed."
595
-
596
- # Layer 2: schema grounding — table names
597
- known_tables = {t.lower() for t in allowed_tables}
598
- alias_to_table: dict[str, str] = {}
599
- for tbl in parsed.find_all(exp.Table):
600
- name = tbl.name.lower()
601
- if name and name not in known_tables:
602
- return f"Unknown table '{tbl.name}'. Only use tables from the schema."
603
- alias = (tbl.alias or tbl.name).lower()
604
- alias_to_table[alias] = name
605
-
606
- # Layer 3: column grounding — qualified references only (table.column)
607
- if column_map:
608
- normalized_map = {t.lower(): {c.lower() for c in cols} for t, cols in column_map.items()}
609
- for col_node in parsed.find_all(exp.Column):
610
- tbl_ref = col_node.table
611
- if not tbl_ref:
612
- continue # unqualified — skip, can't resolve without full alias tracking
613
- tbl_name = alias_to_table.get(tbl_ref.lower(), tbl_ref.lower())
614
- col_name = col_node.name.lower()
615
- if tbl_name in normalized_map and col_name not in normalized_map[tbl_name]:
616
- available = ", ".join(sorted(normalized_map[tbl_name]))
617
- return (
618
- f"Column '{col_node.name}' does not exist on table '{tbl_name}'. "
619
- f"Available columns: {available}."
620
- )
621
-
622
- # Layer 4: LIMIT enforcement (inject if missing — done before execution)
623
- return ""
624
-
625
- # ------------------------------------------------------------------
626
- # SQL execution
627
- # ------------------------------------------------------------------
628
-
629
- def _enforce_limit(self, sql: str, limit: int) -> str:
630
- """Inject or cap LIMIT using sqlglot AST manipulation."""
631
- parsed = sqlglot.parse_one(sql)
632
- existing = parsed.find(exp.Limit)
633
- if existing:
634
- current = int(existing.expression.this)
635
- if current > limit:
636
- return parsed.limit(limit).sql()
637
- else:
638
- return parsed.limit(limit).sql()
639
- return parsed.sql()
640
-
641
- def _run_sql(self, engine: Any, sql: str) -> list[dict]:
642
- # Ensure the user DB connection is a read-only credential — sqlglot validation alone is not sufficient.
643
- with engine.connect() as conn:
644
- result = conn.execute(text(sql))
645
- return [dict(row) for row in result.mappings()]
646
-
647
-
648
- db_executor = DbExecutor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/query/executors/tabular.py DELETED
@@ -1,287 +0,0 @@
1
- """Executor for tabular document sources (source_type="document", file_type csv/xlsx).
2
-
3
- Flow:
4
- 1. Group RetrievalResult chunks by (document_id, sheet_name).
5
- 2. Per group: download Parquet from Azure Blob → pandas DataFrame.
6
- 3. Build schema context from DataFrame columns + sample values.
7
- 4. LLM decides operation (groupby_sum, filter, top_n, etc.) via structured output.
8
- 5. Pandas runs the operation; retry up to 3x on error with feedback to LLM.
9
- 6. Fallback to raw rows if all retries fail.
10
- 7. Return QueryResult per group.
11
- """
12
- import asyncio
13
- from typing import Literal, TypedDict
14
-
15
- import pandas as pd
16
- from langchain_core.prompts import ChatPromptTemplate
17
- from langchain_openai import AzureChatOpenAI
18
- from pydantic import BaseModel
19
- from sqlalchemy.ext.asyncio import AsyncSession
20
-
21
- from src.config.settings import settings
22
- from src.storage.parquet import download_parquet
23
- from src.middlewares.logging import get_logger
24
- from src.query.base import BaseExecutor, QueryResult
25
- from src.retrieval.base import RetrievalResult
26
-
27
- logger = get_logger("tabular_executor")
28
-
29
-
30
- class _GroupInfo(TypedDict):
31
- filename: str
32
- file_type: str
33
-
34
-
35
- _TABULAR_FILE_TYPES = ("csv", "xlsx")
36
- _MAX_RETRIES = 3
37
-
38
- _SYSTEM_PROMPT = """\
39
- You are a data analyst. Given a DataFrame schema and a user question, \
40
- decide which pandas operation to perform.
41
-
42
- IMPORTANT rules:
43
- - Use ONLY the exact column names as written in the schema below. Never translate or rename them.
44
- - For top_n: always set value_col to the column to sort by. Do NOT use sort_col for top_n.
45
- - For sort: use sort_col for the column to sort by.
46
- - For filter with comparison (>, <, >=, <=, !=): set filter_operator accordingly (gt, lt, gte, lte, ne). Default is eq (==).
47
- - For multi-condition filters (AND logic), use the filters field as a list of {{"col", "value", "op"}} dicts instead of filter_col/filter_value.
48
- Example: status=SUCCESS AND amount_paid>200000 → filters=[{{"col":"status","value":"SUCCESS","op":"eq"}},{{"col":"amount_paid","value":"200000","op":"gt"}}]
49
- - For OR conditions on a column (e.g. value is A or B), use or_filters. Combine with filters for mixed AND+OR logic.
50
- Example: (status=FAILED OR status=REVERSED) AND payment_channel=X → or_filters=[{{"col":"status","value":"FAILED","op":"eq"}},{{"col":"status","value":"REVERSED","op":"eq"}}], filters=[{{"col":"payment_channel","value":"X","op":"eq"}}]
51
- - For groupby with a pre-filter (e.g. count SUCCESS per channel): use filters or or_filters to narrow rows first, then use groupby_count/groupby_sum/groupby_avg on the filtered data by setting both filters and group_col.
52
-
53
- Schema:
54
- {schema}
55
-
56
- {error_section}"""
57
-
58
-
59
- class TabularOperation(BaseModel):
60
- operation: Literal[
61
- "filter", "groupby_sum", "groupby_avg", "groupby_count",
62
- "top_n", "sort", "aggregate", "raw"
63
- ]
64
- group_col: str | None = None # for groupby_*
65
- value_col: str | None = None # for groupby_*, top_n, aggregate
66
- filter_col: str | None = None # for single filter
67
- filter_value: str | None = None # for single filter
68
- filter_operator: Literal["eq", "ne", "gt", "gte", "lt", "lte"] = "eq" # for single filter
69
- filters: list[dict] | None = None # for multi-condition AND: [{"col": ..., "value": ..., "op": ...}]
70
- or_filters: list[dict] | None = None # for OR conditions, applied before AND filters
71
- sort_col: str | None = None # for sort
72
- ascending: bool = True # for sort
73
- n: int | None = None # for top_n
74
- agg_func: Literal["sum", "avg", "min", "max", "count"] | None = None # for aggregate
75
- reasoning: str
76
-
77
-
78
- def _get_filter_mask(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.Series:
79
- numeric = pd.to_numeric(df[col], errors="coerce")
80
- if operator == "eq":
81
- return df[col].astype(str) == str(value)
82
- elif operator == "ne":
83
- return df[col].astype(str) != str(value)
84
- elif operator == "gt":
85
- return numeric > float(value)
86
- elif operator == "gte":
87
- return numeric >= float(value)
88
- elif operator == "lt":
89
- return numeric < float(value)
90
- elif operator == "lte":
91
- return numeric <= float(value)
92
- raise ValueError(f"Unknown operator: {operator}")
93
-
94
-
95
- def _apply_single_filter(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.DataFrame:
96
- return df[_get_filter_mask(df, col, value, operator)]
97
-
98
-
99
- def _build_schema_context(df: pd.DataFrame) -> str:
100
- lines = []
101
- for col in df.columns:
102
- sample = df[col].dropna().head(3).tolist()
103
- lines.append(f"- {col} ({df[col].dtype}): sample values: {sample}")
104
- return "\n".join(lines)
105
-
106
-
107
- def _apply_operation(df: pd.DataFrame, op: TabularOperation, limit: int) -> pd.DataFrame:
108
- if op.operation == "groupby_sum":
109
- if not op.group_col or not op.value_col:
110
- raise ValueError(f"groupby_sum requires group_col and value_col, got {op}")
111
- return df.groupby(op.group_col)[op.value_col].sum().reset_index().nlargest(limit, op.value_col)
112
- elif op.operation == "groupby_avg":
113
- if not op.group_col or not op.value_col:
114
- raise ValueError(f"groupby_avg requires group_col and value_col, got {op}")
115
- return df.groupby(op.group_col)[op.value_col].mean().reset_index().nlargest(limit, op.value_col)
116
- elif op.operation == "groupby_count":
117
- if not op.group_col:
118
- raise ValueError(f"groupby_count requires group_col, got {op}")
119
- df_filtered = df.copy()
120
- if op.or_filters:
121
- or_mask = pd.Series([False] * len(df_filtered), index=df_filtered.index)
122
- for f in op.or_filters:
123
- or_mask = or_mask | _get_filter_mask(df_filtered, f["col"], f["value"], f.get("op", "eq"))
124
- df_filtered = df_filtered[or_mask]
125
- if op.filters:
126
- for f in op.filters:
127
- df_filtered = _apply_single_filter(df_filtered, f["col"], f["value"], f.get("op", "eq"))
128
- elif op.filter_col and op.filter_value is not None:
129
- df_filtered = _apply_single_filter(df_filtered, op.filter_col, op.filter_value, op.filter_operator)
130
- return df_filtered.groupby(op.group_col).size().reset_index(name="count").nlargest(limit, "count")
131
- elif op.operation == "filter":
132
- result = df.copy()
133
- if op.or_filters:
134
- or_mask = pd.Series([False] * len(result), index=result.index)
135
- for f in op.or_filters:
136
- or_mask = or_mask | _get_filter_mask(result, f["col"], f["value"], f.get("op", "eq"))
137
- result = result[or_mask]
138
- if op.filters:
139
- for f in op.filters:
140
- result = _apply_single_filter(result, f["col"], f["value"], f.get("op", "eq"))
141
- elif op.filter_col and op.filter_value is not None and not op.or_filters:
142
- result = _apply_single_filter(result, op.filter_col, op.filter_value, op.filter_operator)
143
- elif not op.or_filters and not op.filters and (not op.filter_col or op.filter_value is None):
144
- raise ValueError(f"filter requires filter_col/filter_value or filters or or_filters, got {op}")
145
- return result.head(limit)
146
- elif op.operation == "top_n":
147
- col = op.value_col
148
- if not col:
149
- raise ValueError(f"top_n requires value_col, got {op}")
150
- n = op.n or limit
151
- return df.nlargest(n, col)
152
- elif op.operation == "sort":
153
- if not op.sort_col:
154
- raise ValueError(f"sort requires sort_col, got {op}")
155
- return df.sort_values(op.sort_col, ascending=op.ascending).head(limit)
156
- elif op.operation == "aggregate":
157
- if not op.value_col or not op.agg_func:
158
- raise ValueError(f"aggregate requires value_col and agg_func, got {op}")
159
- funcs = {"sum": "sum", "avg": "mean", "min": "min", "max": "max", "count": "count"}
160
- value = getattr(df[op.value_col], funcs[op.agg_func])()
161
- return pd.DataFrame([{op.value_col: value, "operation": op.agg_func}])
162
- else: # "raw"
163
- return df.head(limit)
164
-
165
-
166
- class TabularExecutor(BaseExecutor):
167
- def __init__(self) -> None:
168
- self._llm = AzureChatOpenAI(
169
- azure_deployment=settings.azureai_deployment_name_4o,
170
- openai_api_version=settings.azureai_api_version_4o,
171
- azure_endpoint=settings.azureai_endpoint_url_4o,
172
- api_key=settings.azureai_api_key_4o,
173
- temperature=0,
174
- )
175
- self._prompt = ChatPromptTemplate.from_messages([
176
- ("system", _SYSTEM_PROMPT),
177
- ("human", "{question}"),
178
- ])
179
- self._chain = self._prompt | self._llm.with_structured_output(TabularOperation)
180
-
181
- async def execute(
182
- self,
183
- results: list[RetrievalResult],
184
- user_id: str,
185
- _db: AsyncSession,
186
- question: str,
187
- limit: int = 100,
188
- ) -> list[QueryResult]:
189
- tabular = [
190
- r for r in results
191
- if r.source_type == "document"
192
- and r.metadata.get("data", {}).get("file_type") in _TABULAR_FILE_TYPES
193
- ]
194
-
195
- if not tabular:
196
- return []
197
-
198
- # Group by (document_id, sheet_name) — one parquet download per group
199
- groups: dict[tuple[str, str | None], _GroupInfo] = {}
200
- for r in tabular:
201
- data = r.metadata.get("data", {})
202
- doc_id = data.get("document_id")
203
- if not doc_id:
204
- continue
205
- sheet_name = data.get("sheet_name") # None for CSV
206
- key = (doc_id, sheet_name)
207
- if key not in groups:
208
- groups[key] = {
209
- "filename": data.get("filename", ""),
210
- "file_type": data.get("file_type", ""),
211
- }
212
-
213
- async def _process_group(
214
- doc_id: str, sheet_name: str | None, info: _GroupInfo
215
- ) -> QueryResult | None:
216
- try:
217
- df = await download_parquet(user_id, doc_id, sheet_name)
218
- df_result = await self._query_with_agent(df, question, limit)
219
-
220
- table_label = info["filename"]
221
- if sheet_name:
222
- table_label += f" / sheet: {sheet_name}"
223
-
224
- logger.info(
225
- "tabular query complete",
226
- document_id=doc_id,
227
- sheet=sheet_name,
228
- file_type=info["file_type"],
229
- rows=len(df_result),
230
- columns=len(df_result.columns),
231
- )
232
- return QueryResult(
233
- source_type="document",
234
- source_id=doc_id,
235
- table_or_file=table_label,
236
- columns=list(df_result.columns),
237
- rows=df_result.to_dict(orient="records"),
238
- row_count=len(df_result),
239
- )
240
- except Exception as e:
241
- logger.error(
242
- "tabular query failed",
243
- document_id=doc_id,
244
- sheet=sheet_name,
245
- error=str(e),
246
- )
247
- return None
248
-
249
- gathered = await asyncio.gather(*[
250
- _process_group(doc_id, sheet_name, info)
251
- for (doc_id, sheet_name), info in groups.items()
252
- ])
253
- return [r for r in gathered if r is not None]
254
-
255
- async def _query_with_agent(
256
- self, df: pd.DataFrame, question: str, limit: int
257
- ) -> pd.DataFrame:
258
- schema_ctx = _build_schema_context(df)
259
- prev_error = ""
260
-
261
- for attempt in range(_MAX_RETRIES):
262
- error_section = (
263
- f"Previous attempt failed: {prev_error}\nFix the issue."
264
- if prev_error else ""
265
- )
266
- try:
267
- op: TabularOperation = await self._chain.ainvoke({
268
- "schema": schema_ctx,
269
- "error_section": error_section,
270
- "question": question,
271
- })
272
- logger.info(
273
- "tabular operation decided",
274
- operation=op.operation,
275
- reasoning=op.reasoning,
276
- )
277
- return _apply_operation(df, op, limit)
278
- except Exception as e:
279
- prev_error = str(e)
280
- logger.warning("tabular agent error", attempt=attempt + 1, error=prev_error)
281
-
282
- # Fallback: return raw rows
283
- logger.warning("tabular agent failed after retries, returning raw rows")
284
- return df.head(limit)
285
-
286
-
287
- tabular_executor = TabularExecutor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/query/query_executor.py DELETED
@@ -1,42 +0,0 @@
1
- """QueryExecutor — dispatches retrieval results to the appropriate executor by source_type."""
2
-
3
- import asyncio
4
-
5
- from sqlalchemy.ext.asyncio import AsyncSession
6
-
7
- from src.middlewares.logging import get_logger
8
- from src.query.base import QueryResult
9
- from src.query.executors.db_executor import db_executor
10
- from src.query.executors.tabular import tabular_executor
11
- from src.retrieval.base import RetrievalResult
12
-
13
- logger = get_logger("query_executor")
14
-
15
-
16
- class QueryExecutor:
17
- async def execute(
18
- self,
19
- results: list[RetrievalResult],
20
- user_id: str,
21
- db: AsyncSession,
22
- question: str,
23
- limit: int = 100,
24
- ) -> list[QueryResult]:
25
- batches = await asyncio.gather(
26
- db_executor.execute(results, user_id, db, question, limit),
27
- tabular_executor.execute(results, user_id, db, question, limit),
28
- return_exceptions=True,
29
- )
30
-
31
- query_results: list[QueryResult] = []
32
- for batch in batches:
33
- if isinstance(batch, Exception):
34
- logger.error("executor failed", error=str(batch))
35
- continue
36
- query_results.extend(batch)
37
-
38
- logger.info("query execution complete", total=len(query_results))
39
- return query_results
40
-
41
-
42
- query_executor = QueryExecutor()