Rifqi Hafizuddin committed on
Commit ·
2c8a3e8
1
Parent(s): 145bca3
[KM-512] create folder for querying from db/tabular docs
Browse files- src/query/__init__.py +0 -0
- src/query/base.py +27 -0
- src/query/executor.py +48 -0
- src/query/executors/__init__.py +0 -0
- src/query/executors/db.py +32 -0
- src/query/executors/tabular.py +36 -0
src/query/__init__.py
ADDED
|
File without changes
|
src/query/base.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared contract for query executors."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
|
| 6 |
+
from src.rag.base import RetrievalResult
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class QueryResult:
    """Structured rows produced by one executor for a single table or file.

    This is the common return type shared by all executors (see BaseExecutor),
    so the dispatching QueryExecutor can merge results from heterogeneous
    sources into one flat list.
    """

    source_type: str  # "database" or "document"
    source_id: str  # database_client_id or document_id
    table_or_file: str  # table name (database) or file name (document) — TODO confirm exact semantics
    columns: list[str]  # column names present in each row of `rows`
    rows: list[dict]  # one dict per row — presumably keyed by column name; verify against executors
    row_count: int  # number of rows returned (capped by the executor's `limit` — TODO confirm)
    metadata: dict = field(default_factory=dict)  # executor-specific extras; factory avoids shared mutable default
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class BaseExecutor(ABC):
    """Abstract contract implemented by every query executor (db, tabular, ...)."""

    @abstractmethod
    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        """Turn retrieval hits into concrete rows.

        Args:
            results: Retrieval hits this executor knows how to handle; callers
                are expected to pre-filter by source type.
            user_id: Id of the requesting user — presumably used for access
                checks/credential lookup; confirm in concrete executors.
            limit: Maximum number of rows to fetch per source.

        Returns:
            One QueryResult per queried table/file.
        """
        ...
|
src/query/executor.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""QueryExecutor — dispatches retrieval results to the appropriate executor by source_type."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
|
| 5 |
+
from src.middlewares.logging import get_logger
|
| 6 |
+
from src.query.base import QueryResult
|
| 7 |
+
from src.query.executors.db import db_executor
|
| 8 |
+
from src.query.executors.tabular import tabular_executor
|
| 9 |
+
from src.rag.base import RetrievalResult
|
| 10 |
+
|
| 11 |
+
logger = get_logger("query_executor")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class QueryExecutor:
    """Fans retrieval results out to source-specific executors and merges the output."""

    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        """Run the db and tabular executors concurrently over their matching hits.

        A failed executor is logged and skipped rather than aborting the whole
        request; the surviving batches are concatenated and returned.

        Args:
            results: Mixed retrieval hits; routed by source_type/file_type.
            user_id: Id of the requesting user, forwarded to each executor.
            limit: Maximum rows per source, forwarded to each executor.

        Returns:
            Flat list of QueryResult objects from every executor that succeeded.
        """
        # Route each hit to at most one executor bucket.
        db_hits: list[RetrievalResult] = []
        tabular_hits: list[RetrievalResult] = []
        for item in results:
            if item.source_type == "database":
                db_hits.append(item)
            elif (
                item.source_type == "document"
                and item.metadata.get("data", {}).get("file_type") in ("csv", "xlsx")
            ):
                tabular_hits.append(item)

        async def _no_op() -> list[QueryResult]:
            # Placeholder awaitable so gather() always receives two coroutines.
            return []

        batches = await asyncio.gather(
            db_executor.execute(db_hits, user_id, limit) if db_hits else _no_op(),
            tabular_executor.execute(tabular_hits, user_id, limit) if tabular_hits else _no_op(),
            return_exceptions=True,  # keep one executor's failure from cancelling the other
        )

        merged: list[QueryResult] = []
        for outcome in batches:
            if isinstance(outcome, Exception):
                # Best-effort semantics: log the failed executor and keep going.
                logger.error("executor failed", error=str(outcome))
            else:
                merged.extend(outcome)

        logger.info("query execution complete", total=len(merged))
        return merged


query_executor = QueryExecutor()
|
src/query/executors/__init__.py
ADDED
|
File without changes
|
src/query/executors/db.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Executor for registered database sources (source_type="database").
|
| 2 |
+
|
| 3 |
+
Flow:
|
| 4 |
+
1. Group RetrievalResult chunks by database_client_id.
|
| 5 |
+
2. For each client: decrypt creds -> connect -> SELECT relevant columns FROM table LIMIT n.
|
| 6 |
+
3. Return QueryResult per (client_id, table_name).
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from src.middlewares.logging import get_logger
|
| 10 |
+
from src.query.base import BaseExecutor, QueryResult
|
| 11 |
+
from src.rag.base import RetrievalResult
|
| 12 |
+
|
| 13 |
+
logger = get_logger("db_executor")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DbExecutor(BaseExecutor):
    """Stub executor for registered database sources (source_type="database")."""

    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        """Not implemented yet — always raises NotImplementedError.

        Planned flow (TODO implement):
        1. filter results where source_type == "database"
        2. group by (database_client_id, table_name) -> list of column_names
        3. per group: look up DatabaseClient, decrypt creds, connect via db_pipeline_service
        4. SELECT <columns> FROM <table> LIMIT limit
        5. return QueryResult per group
        """
        raise NotImplementedError


# Module-level singleton consumed by QueryExecutor.
db_executor = DbExecutor()
|
src/query/executors/tabular.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Executor for tabular document sources (source_type="document", file_type csv/xlsx).
|
| 2 |
+
|
| 3 |
+
Flow:
|
| 4 |
+
1. Group RetrievalResult chunks by document_id.
|
| 5 |
+
2. For each document: download bytes from Azure Blob -> read with pandas.
|
| 6 |
+
3. Filter DataFrame to relevant columns identified by retrieval.
|
| 7 |
+
4. Return QueryResult per document.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from src.middlewares.logging import get_logger
|
| 11 |
+
from src.query.base import BaseExecutor, QueryResult
|
| 12 |
+
from src.rag.base import RetrievalResult
|
| 13 |
+
|
| 14 |
+
logger = get_logger("tabular_executor")
|
| 15 |
+
|
| 16 |
+
# File types this executor can load into a pandas DataFrame.
_TABULAR_FILE_TYPES = ("csv", "xlsx")


class TabularExecutor(BaseExecutor):
    """Stub executor for tabular document sources (source_type="document", csv/xlsx)."""

    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        """Not implemented yet — always raises NotImplementedError.

        Planned flow (TODO implement):
        1. filter results where source_type == "document" and file_type in _TABULAR_FILE_TYPES
        2. group by document_id -> list of column_names
        3. per group: look up Document by document_id -> get blob_name
        4. blob_storage.download_file(blob_name) -> pd.read_csv / pd.read_excel
        5. df[relevant_columns].head(limit) -> rows as list[dict]
        6. return QueryResult per document
        """
        raise NotImplementedError


# Module-level singleton consumed by QueryExecutor.
tabular_executor = TabularExecutor()
|