sofhiaazzhr commited on
Commit
3604994
·
1 Parent(s): 7ff66c9

[KM-515][document] Make Query for Tabular Type (XLSX & CSV)

Browse files
Files changed (1) hide show
  1. src/query/executors/tabular.py +285 -13
src/query/executors/tabular.py CHANGED
@@ -1,39 +1,311 @@
1
  """Executor for tabular document sources (source_type="document", file_type csv/xlsx).
2
 
3
  Flow:
4
- 1. Group RetrievalResult chunks by document_id.
5
- 2. For each document: download bytes from Azure Blob -> read with pandas.
6
- 3. Filter DataFrame to relevant columns identified by retrieval.
7
- 4. Return QueryResult per document.
 
 
 
8
  """
 
 
9
 
 
 
 
 
10
  from sqlalchemy.ext.asyncio import AsyncSession
11
 
 
 
12
  from src.middlewares.logging import get_logger
13
  from src.query.base import BaseExecutor, QueryResult
14
  from src.rag.base import RetrievalResult
15
 
16
  logger = get_logger("tabular_executor")
17
 
 
 
 
 
 
 
 
18
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  class TabularExecutor(BaseExecutor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  async def execute(
23
  self,
24
  results: list[RetrievalResult],
25
  user_id: str,
26
- db: AsyncSession,
 
27
  limit: int = 100,
28
  ) -> list[QueryResult]:
29
- # TODO: implement
30
- # 1. filter results where source_type == "document" and file_type in _TABULAR_FILE_TYPES
31
- # 2. group by document_id -> list of column_names
32
- # 3. per group: look up Document by document_id -> get blob_name
33
- # 4. blob_storage.download_file(blob_name) -> pd.read_csv / pd.read_excel
34
- # 5. df[relevant_columns].head(limit) -> rows as list[dict]
35
- # 6. return QueryResult per document
36
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  tabular_executor = TabularExecutor()
 
1
  """Executor for tabular document sources (source_type="document", file_type csv/xlsx).
2
 
3
  Flow:
4
+ 1. Group RetrievalResult chunks by (document_id, sheet_name).
5
+ 2. Per group: download Parquet from Azure Blob pandas DataFrame.
6
+ 3. Build schema context from DataFrame columns + sample values.
7
+ 4. LLM decides operation (groupby_sum, filter, top_n, etc.) via structured output.
8
+ 5. Pandas runs the operation; retry up to 3x on error with feedback to LLM.
9
+ 6. Fallback to raw rows if all retries fail.
10
+ 7. Return QueryResult per group.
11
  """
12
+ import asyncio
13
+ from typing import Literal, TypedDict
14
 
15
+ import pandas as pd
16
+ from langchain_core.prompts import ChatPromptTemplate
17
+ from langchain_openai import AzureChatOpenAI
18
+ from pydantic import BaseModel
19
  from sqlalchemy.ext.asyncio import AsyncSession
20
 
21
+ from src.config.settings import settings
22
+ from src.knowledge.parquet_service import download_parquet
23
  from src.middlewares.logging import get_logger
24
  from src.query.base import BaseExecutor, QueryResult
25
  from src.rag.base import RetrievalResult
26
 
27
  logger = get_logger("tabular_executor")
28
 
29
+
30
+ class _GroupInfo(TypedDict):
31
+ columns: list[str]
32
+ filename: str
33
+ file_type: str
34
+
35
+
36
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
37
+ _MAX_RETRIES = 3
38
+
39
+ _SYSTEM_PROMPT = """\
40
+ You are a data analyst. Given a DataFrame schema and a user question, \
41
+ decide which pandas operation to perform.
42
+
43
+ IMPORTANT rules:
44
+ - Use ONLY the exact column names as written in the schema below. Never translate or rename them.
45
+ - For top_n: always set value_col to the column to sort by. Do NOT use sort_col for top_n.
46
+ - For sort: use sort_col for the column to sort by.
47
+ - For filter with comparison (>, <, >=, <=, !=): set filter_operator accordingly (gt, lt, gte, lte, ne). Default is eq (==).
48
+ - For multi-condition filters (AND logic), use the filters field as a list of {{"col", "value", "op"}} dicts instead of filter_col/filter_value.
49
+ Example: status=SUCCESS AND amount_paid>200000 → filters=[{{"col":"status","value":"SUCCESS","op":"eq"}},{{"col":"amount_paid","value":"200000","op":"gt"}}]
50
+ - IMPORTANT: When the question uses "or" / "atau" between values of the same column, you MUST use or_filters (NOT filters).
51
+ or_filters applies OR logic: rows matching ANY condition are kept.
52
+ filters applies AND logic: rows must match ALL conditions.
53
+ Example: "(status FAILED or REVERSED) AND payment_channel=Tokopedia" →
54
+ or_filters=[{{"col":"status","value":"FAILED","op":"eq"}},{{"col":"status","value":"REVERSED","op":"eq"}}]
55
+ filters=[{{"col":"payment_channel","value":"Tokopedia","op":"eq"}}]
56
+ - For groupby with a pre-filter (e.g. count SUCCESS per channel): use filters or or_filters to narrow rows first, then use groupby_count/groupby_sum/groupby_avg on the filtered data by setting both filters and group_col.
57
+
58
+ Schema:
59
+ {schema}
60
+
61
+ {error_section}"""
62
+
63
+
64
+ class TabularOperation(BaseModel):
65
+ operation: Literal[
66
+ "filter", "groupby_sum", "groupby_avg", "groupby_count",
67
+ "top_n", "sort", "aggregate", "raw"
68
+ ]
69
+ group_col: str | None = None # for groupby_*
70
+ value_col: str | None = None # for groupby_*, top_n, aggregate
71
+ filter_col: str | None = None # for single filter
72
+ filter_value: str | None = None # for single filter
73
+ filter_operator: Literal["eq", "ne", "gt", "gte", "lt", "lte"] = "eq" # for single filter
74
+ filters: list[dict] | None = None # for multi-condition AND: [{"col": ..., "value": ..., "op": ...}]
75
+ or_filters: list[dict] | None = None # for OR conditions, applied before AND filters
76
+ sort_col: str | None = None # for sort
77
+ ascending: bool = True # for sort
78
+ n: int | None = None # for top_n
79
+ agg_func: Literal["sum", "avg", "min", "max", "count"] | None = None # for aggregate
80
+ reasoning: str
81
+
82
+
83
+ def _get_filter_mask(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.Series:
84
+ numeric = pd.to_numeric(df[col], errors="coerce")
85
+ if operator == "eq":
86
+ return df[col].astype(str) == str(value)
87
+ elif operator == "ne":
88
+ return df[col].astype(str) != str(value)
89
+ elif operator == "gt":
90
+ return numeric > float(value)
91
+ elif operator == "gte":
92
+ return numeric >= float(value)
93
+ elif operator == "lt":
94
+ return numeric < float(value)
95
+ elif operator == "lte":
96
+ return numeric <= float(value)
97
+ raise ValueError(f"Unknown operator: {operator}")
98
+
99
+
100
+ def _apply_single_filter(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.DataFrame:
101
+ numeric = pd.to_numeric(df[col], errors="coerce")
102
+ if operator == "eq":
103
+ return df[df[col].astype(str) == str(value)]
104
+ elif operator == "ne":
105
+ return df[df[col].astype(str) != str(value)]
106
+ elif operator == "gt":
107
+ return df[numeric > float(value)]
108
+ elif operator == "gte":
109
+ return df[numeric >= float(value)]
110
+ elif operator == "lt":
111
+ return df[numeric < float(value)]
112
+ elif operator == "lte":
113
+ return df[numeric <= float(value)]
114
+ raise ValueError(f"Unknown operator: {operator}")
115
+
116
+
117
+ def _build_schema_context(df: pd.DataFrame) -> str:
118
+ lines = []
119
+ for col in df.columns:
120
+ sample = df[col].dropna().head(3).tolist()
121
+ lines.append(f"- {col} ({df[col].dtype}): sample values: {sample}")
122
+ return "\n".join(lines)
123
+
124
+
125
+ def _apply_operation(df: pd.DataFrame, op: TabularOperation, limit: int) -> pd.DataFrame:
126
+ if op.operation == "groupby_sum":
127
+ if not op.group_col or not op.value_col:
128
+ raise ValueError(f"groupby_sum requires group_col and value_col, got {op}")
129
+ return df.groupby(op.group_col)[op.value_col].sum().reset_index().nlargest(limit, op.value_col)
130
+ elif op.operation == "groupby_avg":
131
+ if not op.group_col or not op.value_col:
132
+ raise ValueError(f"groupby_avg requires group_col and value_col, got {op}")
133
+ return df.groupby(op.group_col)[op.value_col].mean().reset_index().nlargest(limit, op.value_col)
134
+ elif op.operation == "groupby_count":
135
+ if not op.group_col:
136
+ raise ValueError(f"groupby_count requires group_col, got {op}")
137
+ df_filtered = df.copy()
138
+ if op.or_filters:
139
+ or_mask = pd.Series([False] * len(df_filtered), index=df_filtered.index)
140
+ for f in op.or_filters:
141
+ or_mask = or_mask | _get_filter_mask(df_filtered, f["col"], f["value"], f.get("op", "eq"))
142
+ df_filtered = df_filtered[or_mask]
143
+ if op.filters:
144
+ for f in op.filters:
145
+ df_filtered = _apply_single_filter(df_filtered, f["col"], f["value"], f.get("op", "eq"))
146
+ elif op.filter_col and op.filter_value is not None:
147
+ df_filtered = _apply_single_filter(df_filtered, op.filter_col, op.filter_value, op.filter_operator)
148
+ return df_filtered.groupby(op.group_col).size().reset_index(name="count").nlargest(limit, "count")
149
+ elif op.operation == "filter":
150
+ result = df.copy()
151
+ if op.or_filters:
152
+ or_mask = pd.Series([False] * len(result), index=result.index)
153
+ for f in op.or_filters:
154
+ or_mask = or_mask | _get_filter_mask(result, f["col"], f["value"], f.get("op", "eq"))
155
+ result = result[or_mask]
156
+ if op.filters:
157
+ for f in op.filters:
158
+ result = _apply_single_filter(result, f["col"], f["value"], f.get("op", "eq"))
159
+ elif op.filter_col and op.filter_value is not None and not op.or_filters:
160
+ result = _apply_single_filter(result, op.filter_col, op.filter_value, op.filter_operator)
161
+ elif not op.or_filters and not op.filters and (not op.filter_col or op.filter_value is None):
162
+ raise ValueError(f"filter requires filter_col/filter_value or filters or or_filters, got {op}")
163
+ return result.head(limit)
164
+ elif op.operation == "top_n":
165
+ col = op.value_col or op.sort_col
166
+ if not col:
167
+ raise ValueError(f"top_n requires value_col, got {op}")
168
+ n = op.n or limit
169
+ return df.nlargest(n, col)
170
+ elif op.operation == "sort":
171
+ if not op.sort_col:
172
+ raise ValueError(f"sort requires sort_col, got {op}")
173
+ return df.sort_values(op.sort_col, ascending=op.ascending).head(limit)
174
+ elif op.operation == "aggregate":
175
+ if not op.value_col or not op.agg_func:
176
+ raise ValueError(f"aggregate requires value_col and agg_func, got {op}")
177
+ funcs = {"sum": "sum", "avg": "mean", "min": "min", "max": "max", "count": "count"}
178
+ value = getattr(df[op.value_col], funcs[op.agg_func])()
179
+ return pd.DataFrame([{op.value_col: value, "operation": op.agg_func}])
180
+ else: # "raw"
181
+ return df.head(limit)
182
 
183
 
184
  class TabularExecutor(BaseExecutor):
185
+ def __init__(self) -> None:
186
+ self._llm = AzureChatOpenAI(
187
+ azure_deployment=settings.azureai_deployment_name_4o,
188
+ openai_api_version=settings.azureai_api_version_4o,
189
+ azure_endpoint=settings.azureai_endpoint_url_4o,
190
+ api_key=settings.azureai_api_key_4o,
191
+ temperature=0,
192
+ )
193
+ self._prompt = ChatPromptTemplate.from_messages([
194
+ ("system", _SYSTEM_PROMPT),
195
+ ("human", "{question}"),
196
+ ])
197
+ self._chain = self._prompt | self._llm.with_structured_output(TabularOperation)
198
+
199
  async def execute(
200
  self,
201
  results: list[RetrievalResult],
202
  user_id: str,
203
+ _db: AsyncSession,
204
+ question: str,
205
  limit: int = 100,
206
  ) -> list[QueryResult]:
207
+ tabular = [
208
+ r for r in results
209
+ if r.metadata.get("data", {}).get("file_type") in _TABULAR_FILE_TYPES
210
+ ]
211
+
212
+ if not tabular:
213
+ return []
214
+
215
+ # Group by (document_id, sheet_name) → collect relevant column names
216
+ groups: dict[tuple[str, str | None], _GroupInfo] = {}
217
+ for r in tabular:
218
+ data = r.metadata.get("data", {})
219
+ doc_id = data.get("document_id")
220
+ if not doc_id:
221
+ continue
222
+ sheet_name = data.get("sheet_name") # None for CSV
223
+ col_name = data.get("column_name")
224
+ filename = data.get("filename", "")
225
+ file_type = data.get("file_type", "")
226
+
227
+ key = (doc_id, sheet_name)
228
+ if key not in groups:
229
+ groups[key] = {
230
+ "columns": [],
231
+ "filename": filename,
232
+ "file_type": file_type,
233
+ }
234
+ if col_name and col_name not in groups[key]["columns"]:
235
+ groups[key]["columns"].append(col_name)
236
+
237
+ async def _process_group(
238
+ doc_id: str, sheet_name: str | None, info: _GroupInfo
239
+ ) -> QueryResult | None:
240
+ try:
241
+ df = await download_parquet(user_id, doc_id, sheet_name)
242
+ df_result = await self._query_with_agent(df, question, limit)
243
+
244
+ table_label = info["filename"]
245
+ if sheet_name:
246
+ table_label += f" / sheet: {sheet_name}"
247
+
248
+ logger.info(
249
+ "tabular query complete",
250
+ document_id=doc_id,
251
+ sheet=sheet_name,
252
+ file_type=info["file_type"],
253
+ rows=len(df_result),
254
+ columns=len(df_result.columns),
255
+ )
256
+ return QueryResult(
257
+ source_type="document",
258
+ source_id=doc_id,
259
+ table_or_file=table_label,
260
+ columns=list(df_result.columns),
261
+ rows=df_result.to_dict(orient="records"),
262
+ row_count=len(df_result),
263
+ )
264
+ except Exception as e:
265
+ logger.error(
266
+ "tabular query failed",
267
+ document_id=doc_id,
268
+ sheet=sheet_name,
269
+ error=str(e),
270
+ )
271
+ return None
272
+
273
+ gathered = await asyncio.gather(*[
274
+ _process_group(doc_id, sheet_name, info)
275
+ for (doc_id, sheet_name), info in groups.items()
276
+ ])
277
+ return [r for r in gathered if r is not None]
278
+
279
+ async def _query_with_agent(
280
+ self, df: pd.DataFrame, question: str, limit: int
281
+ ) -> pd.DataFrame:
282
+ schema_ctx = _build_schema_context(df)
283
+ prev_error = ""
284
+
285
+ for attempt in range(_MAX_RETRIES):
286
+ error_section = (
287
+ f"Previous attempt failed: {prev_error}\nFix the issue."
288
+ if prev_error else ""
289
+ )
290
+ try:
291
+ op: TabularOperation = await self._chain.ainvoke({
292
+ "schema": schema_ctx,
293
+ "error_section": error_section,
294
+ "question": question,
295
+ })
296
+ logger.info(
297
+ "tabular operation decided",
298
+ operation=op.operation,
299
+ reasoning=op.reasoning,
300
+ )
301
+ return _apply_operation(df, op, limit)
302
+ except Exception as e:
303
+ prev_error = str(e)
304
+ logger.warning("tabular agent error", attempt=attempt + 1, error=prev_error)
305
+
306
+ # Fallback: return raw rows
307
+ logger.warning("tabular agent failed after retries, returning raw rows")
308
+ return df.head(limit)
309
 
310
 
311
  tabular_executor = TabularExecutor()