Rifqi Hafizuddin committed on
Commit 2c8a3e8 · 1 Parent(s): 145bca3

[KM-512] create folder for querying from db/tabular docs

src/query/__init__.py ADDED
File without changes
src/query/base.py ADDED
@@ -0,0 +1,27 @@
+ """Shared contract for query executors."""
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+
+ from src.rag.base import RetrievalResult
+
+
+ @dataclass
+ class QueryResult:
+     source_type: str  # "database" or "document"
+     source_id: str  # database_client_id or document_id
+     table_or_file: str
+     columns: list[str]
+     rows: list[dict]
+     row_count: int
+     metadata: dict = field(default_factory=dict)
+
+
+ class BaseExecutor(ABC):
+     @abstractmethod
+     async def execute(
+         self,
+         results: list[RetrievalResult],
+         user_id: str,
+         limit: int = 100,
+     ) -> list[QueryResult]: ...
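
For context, here is how the two source types would fill a `QueryResult`: a minimal sketch with made-up IDs and values, assuming the module above is importable as `src.query.base`.

```python
from src.query.base import QueryResult

# Hypothetical values, purely illustrative of the field semantics per source type.
db_result = QueryResult(
    source_type="database",
    source_id="db-client-123",      # database_client_id
    table_or_file="orders",
    columns=["order_id", "total"],
    rows=[{"order_id": 1, "total": 99.5}],
    row_count=1,
)

doc_result = QueryResult(
    source_type="document",
    source_id="doc-456",            # document_id
    table_or_file="sales_q3.xlsx",
    columns=["region", "revenue"],
    rows=[{"region": "APAC", "revenue": 120000}],
    row_count=1,
)
```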
src/query/executor.py ADDED
@@ -0,0 +1,48 @@
+ """QueryExecutor: dispatches retrieval results to the appropriate executor by source_type."""
+
+ import asyncio
+
+ from src.middlewares.logging import get_logger
+ from src.query.base import QueryResult
+ from src.query.executors.db import db_executor
+ from src.query.executors.tabular import tabular_executor
+ from src.rag.base import RetrievalResult
+
+ logger = get_logger("query_executor")
+
+
+ class QueryExecutor:
+     async def execute(
+         self,
+         results: list[RetrievalResult],
+         user_id: str,
+         limit: int = 100,
+     ) -> list[QueryResult]:
+         db_results = [r for r in results if r.source_type == "database"]
+         tabular_results = [
+             r for r in results
+             if r.source_type == "document"
+             and r.metadata.get("data", {}).get("file_type") in ("csv", "xlsx")
+         ]
+
+         async def _empty() -> list[QueryResult]:
+             return []
+
+         batches = await asyncio.gather(
+             db_executor.execute(db_results, user_id, limit) if db_results else _empty(),
+             tabular_executor.execute(tabular_results, user_id, limit) if tabular_results else _empty(),
+             return_exceptions=True,
+         )
+
+         query_results: list[QueryResult] = []
+         for batch in batches:
+             if isinstance(batch, Exception):
+                 logger.error("executor failed", error=str(batch))
+                 continue
+             query_results.extend(batch)
+
+         logger.info("query execution complete", total=len(query_results))
+         return query_results
+
+
+ query_executor = QueryExecutor()
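
A minimal usage sketch of the dispatcher, assuming the caller already has `RetrievalResult` objects from the upstream retrieval step; the `run` wrapper, variable names, and user id are illustrative only.

```python
import asyncio

from src.query.executor import query_executor


async def run(retrieval_results, user_id: str):
    # retrieval_results: list[RetrievalResult] produced by the retrieval step upstream.
    query_results = await query_executor.execute(retrieval_results, user_id, limit=50)
    for qr in query_results:
        print(qr.source_type, qr.table_or_file, qr.row_count)
    return query_results


# e.g. asyncio.run(run(retrieval_results, "user-1"))
```

Because `asyncio.gather` is called with `return_exceptions=True`, a failure in one executor is logged and skipped rather than cancelling the other branch.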
src/query/executors/__init__.py ADDED
File without changes
src/query/executors/db.py ADDED
@@ -0,0 +1,32 @@
+ """Executor for registered database sources (source_type="database").
+
+ Flow:
+ 1. Group RetrievalResult chunks by database_client_id.
+ 2. For each client: decrypt creds -> connect -> SELECT relevant columns FROM table LIMIT n.
+ 3. Return QueryResult per (client_id, table_name).
+ """
+
+ from src.middlewares.logging import get_logger
+ from src.query.base import BaseExecutor, QueryResult
+ from src.rag.base import RetrievalResult
+
+ logger = get_logger("db_executor")
+
+
+ class DbExecutor(BaseExecutor):
+     async def execute(
+         self,
+         results: list[RetrievalResult],
+         user_id: str,
+         limit: int = 100,
+     ) -> list[QueryResult]:
+         # TODO: implement
+         # 1. filter results where source_type == "database"
+         # 2. group by (database_client_id, table_name) -> list of column_names
+         # 3. per group: look up DatabaseClient, decrypt creds, connect via db_pipeline_service
+         # 4. SELECT <columns> FROM <table> LIMIT limit
+         # 5. return QueryResult per group
+         raise NotImplementedError
+
+
+ db_executor = DbExecutor()
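
The TODO spells out the intended flow; below is a minimal sketch of steps 2, 4, and 5 only, under stated assumptions: the connection step (DatabaseClient lookup, credential decryption, db_pipeline_service) is abstracted behind a caller-supplied `fetch_rows` coroutine, and the keys `database_client_id` / `table_name` / `column_name` are assumed to live under `metadata["data"]`, mirroring how executor.py reads `file_type`. Identifier validation/quoting is omitted here and would be required before interpolating names into SQL.

```python
from collections import defaultdict
from typing import Awaitable, Callable

from src.query.base import QueryResult
from src.rag.base import RetrievalResult

# Hypothetical: resolves database_client_id -> decrypted creds -> connection, runs the SQL,
# and returns rows as dicts. Stands in for the DatabaseClient / db_pipeline_service step (TODO step 3).
FetchRows = Callable[[str, str], Awaitable[list[dict]]]


async def execute_db_results(
    results: list[RetrievalResult],
    fetch_rows: FetchRows,
    limit: int = 100,
) -> list[QueryResult]:
    # TODO step 2: group retrieved column chunks by (database_client_id, table_name).
    grouped: dict[tuple[str, str], set[str]] = defaultdict(set)
    for r in results:
        data = r.metadata.get("data", {})  # assumed metadata layout, mirroring executor.py's file_type lookup
        grouped[(data.get("database_client_id"), data.get("table_name"))].add(data.get("column_name"))

    query_results: list[QueryResult] = []
    for (client_id, table), columns in grouped.items():
        cols = sorted(c for c in columns if c)
        # TODO step 4: only the columns surfaced by retrieval, capped by `limit`.
        # A real implementation must validate/quote identifiers before interpolation.
        sql = f'SELECT {", ".join(cols) or "*"} FROM {table} LIMIT {limit}'
        rows = await fetch_rows(client_id, sql)
        query_results.append(
            QueryResult(
                source_type="database",
                source_id=client_id,
                table_or_file=table,
                columns=cols,
                rows=rows,
                row_count=len(rows),
            )
        )
    return query_results
```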
src/query/executors/tabular.py ADDED
@@ -0,0 +1,36 @@
+ """Executor for tabular document sources (source_type="document", file_type csv/xlsx).
+
+ Flow:
+ 1. Group RetrievalResult chunks by document_id.
+ 2. For each document: download bytes from Azure Blob -> read with pandas.
+ 3. Filter DataFrame to relevant columns identified by retrieval.
+ 4. Return QueryResult per document.
+ """
+
+ from src.middlewares.logging import get_logger
+ from src.query.base import BaseExecutor, QueryResult
+ from src.rag.base import RetrievalResult
+
+ logger = get_logger("tabular_executor")
+
+ _TABULAR_FILE_TYPES = ("csv", "xlsx")
+
+
+ class TabularExecutor(BaseExecutor):
+     async def execute(
+         self,
+         results: list[RetrievalResult],
+         user_id: str,
+         limit: int = 100,
+     ) -> list[QueryResult]:
+         # TODO: implement
+         # 1. filter results where source_type == "document" and file_type in _TABULAR_FILE_TYPES
+         # 2. group by document_id -> list of column_names
+         # 3. per group: look up Document by document_id -> get blob_name
+         # 4. blob_storage.download_file(blob_name) -> pd.read_csv / pd.read_excel
+         # 5. df[relevant_columns].head(limit) -> rows as list[dict]
+         # 6. return QueryResult per document
+         raise NotImplementedError
+
+
+ tabular_executor = TabularExecutor()
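
A minimal sketch of TODO steps 4-6 for a single document, assuming its bytes have already been downloaded: the `raw` parameter stands in for the project's `blob_storage.download_file(blob_name)` referenced in the TODO, the other parameters are assumptions about what the grouping step would supply, and the rest is standard pandas (reading .xlsx additionally requires openpyxl).

```python
import io

import pandas as pd

from src.query.base import QueryResult


def frame_to_query_result(
    raw: bytes,
    file_type: str,
    document_id: str,
    file_name: str,
    relevant_columns: list[str],
    limit: int = 100,
) -> QueryResult:
    """TODO steps 4-6 for one document, once its bytes are downloaded (steps 1-3).

    `raw` stands in for blob_storage.download_file(blob_name); the other parameters
    come from the grouping step and are assumptions about its output shape.
    """
    if file_type == "csv":
        df = pd.read_csv(io.BytesIO(raw))
    else:  # "xlsx"; requires openpyxl to be installed
        df = pd.read_excel(io.BytesIO(raw))

    # Keep only the columns retrieval flagged as relevant and that actually exist in the file.
    columns = [c for c in relevant_columns if c in df.columns] or list(df.columns)
    sliced = df[columns].head(limit)

    return QueryResult(
        source_type="document",
        source_id=document_id,
        table_or_file=file_name,
        columns=columns,
        rows=sliced.to_dict(orient="records"),
        row_count=len(sliced),
    )
```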