sofhiaazzhr commited on
Commit
04e5c48
·
1 Parent(s): ca42520

[KM 556] add columns field to QueryResult + update DbExecutor

Browse files
src/query/executor/base.py CHANGED
@@ -11,6 +11,7 @@ from ..ir.models import QueryIR
11
  class QueryResult:
12
  source_id: str
13
  backend: str # "sql" | "tabular"
 
14
  rows: list[dict[str, Any]] = field(default_factory=list)
15
  row_count: int = 0
16
  truncated: bool = False
 
11
  class QueryResult:
12
  source_id: str
13
  backend: str # "sql" | "tabular"
14
+ columns: list[str] = field(default_factory=list)
15
  rows: list[dict[str, Any]] = field(default_factory=list)
16
  row_count: int = 0
17
  truncated: bool = False
src/query/executor/db.py CHANGED
@@ -79,7 +79,7 @@ class DbExecutor(BaseExecutor):
79
  )
80
  creds = decrypt_credentials_dict(client.credentials)
81
 
82
- rows = await asyncio.wait_for(
83
  asyncio.to_thread(self._run_sync, client.db_type, creds, compiled),
84
  timeout=_QUERY_TIMEOUT_SECONDS,
85
  )
@@ -97,6 +97,7 @@ class DbExecutor(BaseExecutor):
97
  return QueryResult(
98
  source_id=ir.source_id,
99
  backend="sql",
 
100
  rows=capped,
101
  row_count=len(capped),
102
  truncated=truncated,
@@ -175,7 +176,7 @@ class DbExecutor(BaseExecutor):
175
  )
176
 
177
  @staticmethod
178
- def _run_sync(db_type: str, creds: dict, compiled: CompiledSql) -> list[dict]:
179
  with db_pipeline_service.engine_scope(db_type, creds) as engine:
180
  with engine.connect() as conn:
181
  if db_type in _POSTGRES_LIKE:
@@ -185,4 +186,6 @@ class DbExecutor(BaseExecutor):
185
  text(f"SET statement_timeout = {_QUERY_TIMEOUT_SECONDS * 1000}")
186
  )
187
  result = conn.execute(text(compiled.sql), compiled.params)
188
- return [dict(row) for row in result.mappings()]
 
 
 
79
  )
80
  creds = decrypt_credentials_dict(client.credentials)
81
 
82
+ columns, rows = await asyncio.wait_for(
83
  asyncio.to_thread(self._run_sync, client.db_type, creds, compiled),
84
  timeout=_QUERY_TIMEOUT_SECONDS,
85
  )
 
97
  return QueryResult(
98
  source_id=ir.source_id,
99
  backend="sql",
100
+ columns=columns,
101
  rows=capped,
102
  row_count=len(capped),
103
  truncated=truncated,
 
176
  )
177
 
178
  @staticmethod
179
+ def _run_sync(db_type: str, creds: dict, compiled: CompiledSql) -> tuple[list[str], list[dict]]:
180
  with db_pipeline_service.engine_scope(db_type, creds) as engine:
181
  with engine.connect() as conn:
182
  if db_type in _POSTGRES_LIKE:
 
186
  text(f"SET statement_timeout = {_QUERY_TIMEOUT_SECONDS * 1000}")
187
  )
188
  result = conn.execute(text(compiled.sql), compiled.params)
189
+ columns = list(result.keys())
190
+ rows = [dict(row) for row in result.mappings()]
191
+ return columns, rows