File size: 12,485 Bytes
2c8a3e8
 
 
3604994
 
73b7fe3
3604994
 
 
 
2c8a3e8
3604994
 
2c8a3e8
3604994
 
 
 
220f59e
 
3604994
 
2c8a3e8
 
 
 
 
 
3604994
 
 
 
 
 
2c8a3e8
3604994
 
 
 
 
 
 
 
 
 
 
 
 
73b7fe3
 
3604994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73b7fe3
3604994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a49dc1b
3604994
 
 
 
 
 
 
 
 
73b7fe3
 
 
3604994
 
 
 
2c8a3e8
 
 
3604994
 
 
 
 
 
 
 
 
 
 
 
 
 
2c8a3e8
 
 
 
3604994
 
2c8a3e8
 
3604994
 
a49dc1b
 
3604994
 
 
 
 
959b1b0
3604994
 
 
 
 
 
 
 
 
 
959b1b0
 
3604994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73b7fe3
3604994
959b1b0
2c8a3e8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
"""Executor for tabular document sources (source_type="document", file_type csv/xlsx).

Flow:
  1. Group RetrievalResult chunks by (document_id, sheet_name).
  2. Per group: download Parquet from Azure Blob → pandas DataFrame.
  3. Build schema context from DataFrame columns + sample values.
  4. LLM decides operation (groupby_sum, filter, top_n, etc.) via structured output.
  5. Pandas runs the operation; on error, the LLM is re-prompted with the error as feedback (at most 3 attempts in total).
  6. Fallback to raw rows if all retries fail.
  7. Return QueryResult per group.
"""
import asyncio
from typing import Literal, TypedDict

import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from src.config.settings import settings
from src.knowledge.parquet_service import download_parquet
from src.middlewares.logging import get_logger
from src.query.base import BaseExecutor, QueryResult
from src.rag.base import RetrievalResult

# Module logger; the calls below pass structured context as keyword arguments
# (e.g. document_id=..., attempt=...).
logger = get_logger("tabular_executor")


class _GroupInfo(TypedDict):
    filename: str
    file_type: str


# Values of metadata["data"]["file_type"] that this executor accepts.
_TABULAR_FILE_TYPES = ("csv", "xlsx")
# Total number of LLM attempts (first try included) made by _query_with_agent
# before it falls back to returning raw rows.
_MAX_RETRIES = 3

# System prompt for the operation-planning LLM call. ``{schema}`` and
# ``{error_section}`` are ChatPromptTemplate variables; the doubled braces
# ``{{ }}`` render as literal braces in the JSON examples.
_SYSTEM_PROMPT = """\
You are a data analyst. Given a DataFrame schema and a user question, \
decide which pandas operation to perform.

IMPORTANT rules:
- Use ONLY the exact column names as written in the schema below. Never translate or rename them.
- For top_n: always set value_col to the column to sort by. Do NOT use sort_col for top_n.
- For sort: use sort_col for the column to sort by.
- For filter with comparison (>, <, >=, <=, !=): set filter_operator accordingly (gt, lt, gte, lte, ne). Default is eq (==).
- For multi-condition filters (AND logic), use the filters field as a list of {{"col", "value", "op"}} dicts instead of filter_col/filter_value.
  Example: status=SUCCESS AND amount_paid>200000 → filters=[{{"col":"status","value":"SUCCESS","op":"eq"}},{{"col":"amount_paid","value":"200000","op":"gt"}}]
- For OR conditions on a column (e.g. value is A or B), use or_filters. Combine with filters for mixed AND+OR logic.
  Example: (status=FAILED OR status=REVERSED) AND payment_channel=X → or_filters=[{{"col":"status","value":"FAILED","op":"eq"}},{{"col":"status","value":"REVERSED","op":"eq"}}], filters=[{{"col":"payment_channel","value":"X","op":"eq"}}]
- For groupby with a pre-filter (e.g. count SUCCESS per channel): use filters or or_filters to narrow rows first, then use groupby_count/groupby_sum/groupby_avg on the filtered data by setting both filters and group_col.

Schema:
{schema}

{error_section}"""


# Structured-output schema the LLM fills in to describe one pandas operation.
# Only the fields relevant to ``operation`` are expected to be set;
# _apply_operation checks the required ones and raises ValueError otherwise.
# NOTE: documented with comments rather than a class docstring on purpose -
# pydantic uses a class docstring as the JSON-schema "description", which
# with_structured_output may forward to the LLM, changing the prompt.
class TabularOperation(BaseModel):
    operation: Literal[
        "filter", "groupby_sum", "groupby_avg", "groupby_count",
        "top_n", "sort", "aggregate", "raw"
    ]
    group_col: str | None = None       # for groupby_*: column to group by
    value_col: str | None = None       # for groupby_sum/avg, top_n, aggregate: column to measure
    filter_col: str | None = None      # for single filter
    filter_value: str | None = None    # for single filter
    filter_operator: Literal["eq", "ne", "gt", "gte", "lt", "lte"] = "eq"  # for single filter
    filters: list[dict] | None = None     # for multi-condition AND: [{"col": ..., "value": ..., "op": ...}]; "op" defaults to "eq" when omitted
    or_filters: list[dict] | None = None  # for OR conditions (same dict shape), applied before AND filters
    sort_col: str | None = None        # for sort
    ascending: bool = True             # for sort
    n: int | None = None               # for top_n; falsy means "use the caller's limit"
    agg_func: Literal["sum", "avg", "min", "max", "count"] | None = None  # for aggregate
    reasoning: str                     # LLM's explanation; logged by _query_with_agent


def _get_filter_mask(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.Series:
    numeric = pd.to_numeric(df[col], errors="coerce")
    if operator == "eq":
        return df[col].astype(str) == str(value)
    elif operator == "ne":
        return df[col].astype(str) != str(value)
    elif operator == "gt":
        return numeric > float(value)
    elif operator == "gte":
        return numeric >= float(value)
    elif operator == "lt":
        return numeric < float(value)
    elif operator == "lte":
        return numeric <= float(value)
    raise ValueError(f"Unknown operator: {operator}")


def _apply_single_filter(df: pd.DataFrame, col: str, value: str, operator: str) -> pd.DataFrame:
    """Keep only the rows of ``df`` for which ``df[col] <operator> value`` holds.

    Thin wrapper over ``_get_filter_mask``; errors from it propagate unchanged.
    """
    keep = _get_filter_mask(df, col, value, operator)
    return df.loc[keep]


def _build_schema_context(df: pd.DataFrame) -> str:
    lines = []
    for col in df.columns:
        sample = df[col].dropna().head(3).tolist()
        lines.append(f"- {col} ({df[col].dtype}): sample values: {sample}")
    return "\n".join(lines)


def _or_mask(df: pd.DataFrame, conditions: list[dict]) -> pd.Series:
    """Return a mask that is True where at least one of ``conditions`` matches."""
    mask = pd.Series(False, index=df.index)
    for cond in conditions:
        mask = mask | _get_filter_mask(df, cond["col"], cond["value"], cond.get("op", "eq"))
    return mask


def _has_filters(op: TabularOperation) -> bool:
    """Return True when ``op`` carries at least one usable row condition."""
    return bool(op.or_filters or op.filters or (op.filter_col and op.filter_value is not None))


def _apply_filters(df: pd.DataFrame, op: TabularOperation) -> pd.DataFrame:
    """Narrow ``df`` with the row conditions carried by ``op``.

    ``or_filters`` are OR-ed together and applied first, then every entry of
    ``filters`` is AND-ed on top. The single ``filter_col``/``filter_value``
    condition is used only when neither list is given. Returns ``df`` itself
    when ``op`` has no conditions.
    """
    result = df
    if op.or_filters:
        result = result[_or_mask(result, op.or_filters)]
    if op.filters:
        for cond in op.filters:
            result = _apply_single_filter(result, cond["col"], cond["value"], cond.get("op", "eq"))
    elif op.filter_col and op.filter_value is not None and not op.or_filters:
        result = _apply_single_filter(result, op.filter_col, op.filter_value, op.filter_operator)
    return result


def _apply_operation(df: pd.DataFrame, op: TabularOperation, limit: int) -> pd.DataFrame:
    """Execute the LLM-chosen ``op`` against ``df`` and return the result frame.

    Row conditions (``or_filters``, ``filters``, ``filter_col``/``filter_value``)
    are applied first for every operation except ``raw``. For example, a
    ``groupby_sum`` restricted to ``status == SUCCESS`` sums only those rows,
    as the system prompt instructs the LLM.

    Args:
        df: DataFrame to query.
        op: Operation decided by the LLM.
        limit: Maximum number of rows to return. ``top_n`` never exceeds it.

    Returns:
        - ``groupby_sum``/``groupby_avg``: ``group_col`` + ``value_col``, largest first.
        - ``groupby_count``: ``group_col`` + ``count``, largest first.
        - ``filter``/``sort``/``top_n``/``raw``: matching original rows.
        - ``aggregate``: one row with ``value_col`` and ``operation`` columns.

    Raises:
        ValueError: A field required by the chosen operation is missing, or a
            filter uses an unknown operator or a non-numeric ordering value.
        Other pandas errors (e.g. ``KeyError`` for an unknown column) propagate
        unchanged so the caller can retry with feedback.
    """
    kind = op.operation

    if kind in ("groupby_sum", "groupby_avg"):
        if not op.group_col or not op.value_col:
            raise ValueError(f"{kind} requires group_col and value_col, got {op}")
        grouped = _apply_filters(df, op).groupby(op.group_col)[op.value_col]
        agg = grouped.sum() if kind == "groupby_sum" else grouped.mean()
        return agg.reset_index().nlargest(limit, op.value_col)

    if kind == "groupby_count":
        if not op.group_col:
            raise ValueError(f"groupby_count requires group_col, got {op}")
        counts = _apply_filters(df, op).groupby(op.group_col).size()
        return counts.reset_index(name="count").nlargest(limit, "count")

    if kind == "filter":
        if not _has_filters(op):
            raise ValueError(f"filter requires filter_col/filter_value or filters or or_filters, got {op}")
        return _apply_filters(df, op).head(limit)

    if kind == "top_n":
        if not op.value_col:
            raise ValueError(f"top_n requires value_col, got {op}")
        # Honour the caller's row cap like every other operation; a missing or
        # non-positive n falls back to the cap itself.
        n = limit if not op.n or op.n <= 0 else min(op.n, limit)
        return _apply_filters(df, op).nlargest(n, op.value_col)

    if kind == "sort":
        if not op.sort_col:
            raise ValueError(f"sort requires sort_col, got {op}")
        return _apply_filters(df, op).sort_values(op.sort_col, ascending=op.ascending).head(limit)

    if kind == "aggregate":
        if not op.value_col or not op.agg_func:
            raise ValueError(f"aggregate requires value_col and agg_func, got {op}")
        funcs = {"sum": "sum", "avg": "mean", "min": "min", "max": "max", "count": "count"}
        column = _apply_filters(df, op)[op.value_col]
        value = getattr(column, funcs[op.agg_func])()
        return pd.DataFrame([{op.value_col: value, "operation": op.agg_func}])

    # "raw": unfiltered leading rows, same shape as the retry fallback.
    return df.head(limit)


class TabularExecutor(BaseExecutor):
    """Answer questions over CSV/XLSX documents by letting an LLM pick a pandas operation.

    Holds one ``AzureChatOpenAI`` client (temperature 0) whose output is parsed
    into a ``TabularOperation`` via structured output.
    """

    def __init__(self) -> None:
        self._llm = AzureChatOpenAI(
            azure_deployment=settings.azureai_deployment_name_4o,
            openai_api_version=settings.azureai_api_version_4o,
            azure_endpoint=settings.azureai_endpoint_url_4o,
            api_key=settings.azureai_api_key_4o,
            temperature=0,
        )
        self._prompt = ChatPromptTemplate.from_messages([
            ("system", _SYSTEM_PROMPT),
            ("human", "{question}"),
        ])
        self._chain = self._prompt | self._llm.with_structured_output(TabularOperation)

    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        _db: AsyncSession,
        question: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        """Run the tabular query for every CSV/XLSX document referenced in ``results``.

        Args:
            results: Retrieved chunks. Only chunks with ``source_type == "document"``
                whose ``metadata["data"]["file_type"]`` is csv/xlsx are used.
            user_id: Forwarded to ``download_parquet``.
            _db: Not used here (presumably required by the executor interface).
            question: Natural-language question forwarded to the LLM.
            limit: Row cap forwarded to ``_query_with_agent``.

        Returns:
            One ``QueryResult`` per (document_id, sheet_name) group that succeeded.
            Groups that raise are logged and omitted.
        """
        # `or {}` also covers a present-but-None "data" entry.
        tabular = [
            r for r in results
            if r.source_type == "document"
            and (r.metadata.get("data") or {}).get("file_type") in _TABULAR_FILE_TYPES
        ]

        if not tabular:
            return []

        # Group by (document_id, sheet_name) — one parquet download per group
        groups: dict[tuple[str, str | None], _GroupInfo] = {}
        for r in tabular:
            data = r.metadata.get("data") or {}
            doc_id = data.get("document_id")
            if not doc_id:
                continue
            sheet_name = data.get("sheet_name")  # None for CSV
            key = (doc_id, sheet_name)
            if key not in groups:
                groups[key] = {
                    # `or ""` keeps the label concatenation below safe when metadata holds None.
                    "filename": data.get("filename") or "",
                    "file_type": data.get("file_type") or "",
                }

        async def _process_group(
            doc_id: str, sheet_name: str | None, info: _GroupInfo
        ) -> QueryResult | None:
            """Download one group's data, query it, and wrap the result (None on failure)."""
            try:
                df = await download_parquet(user_id, doc_id, sheet_name)
                df_result = await self._query_with_agent(df, question, limit)

                table_label = info["filename"]
                if sheet_name:
                    table_label += f" / sheet: {sheet_name}"

                logger.info(
                    "tabular query complete",
                    document_id=doc_id,
                    sheet=sheet_name,
                    file_type=info["file_type"],
                    rows=len(df_result),
                    columns=len(df_result.columns),
                )
                return QueryResult(
                    source_type="document",
                    source_id=doc_id,
                    table_or_file=table_label,
                    columns=list(df_result.columns),
                    rows=df_result.to_dict(orient="records"),
                    row_count=len(df_result),
                )
            except Exception as e:
                # Best-effort per group: one bad document must not sink the others.
                logger.error(
                    "tabular query failed",
                    document_id=doc_id,
                    sheet=sheet_name,
                    error=str(e),
                )
                return None

        gathered = await asyncio.gather(*[
            _process_group(doc_id, sheet_name, info)
            for (doc_id, sheet_name), info in groups.items()
        ])
        return [r for r in gathered if r is not None]

    async def _query_with_agent(
        self, df: pd.DataFrame, question: str, limit: int
    ) -> pd.DataFrame:
        """Ask the LLM for an operation and run it, feeding errors back on retry.

        Makes up to ``_MAX_RETRIES`` attempts. Any exception, whether from the LLM
        call or from running the operation, is logged. Its type and message, plus
        the operation that failed when one was produced, go into the next prompt.
        If every attempt fails, the first ``limit`` rows of ``df`` are returned.
        """
        schema_ctx = _build_schema_context(df)
        prev_error = ""

        for attempt in range(_MAX_RETRIES):
            error_section = (
                f"Previous attempt failed: {prev_error}\nFix the issue."
                if prev_error else ""
            )
            op: TabularOperation | None = None
            try:
                op = await self._chain.ainvoke({
                    "schema": schema_ctx,
                    "error_section": error_section,
                    "question": question,
                })
                logger.info(
                    "tabular operation decided",
                    operation=op.operation,
                    reasoning=op.reasoning,
                )
                return _apply_operation(df, op, limit)
            except Exception as e:
                # str(KeyError("x")) is just "'x'"; the type name and the failed
                # operation give the LLM enough context to correct itself.
                prev_error = f"{type(e).__name__}: {e}"
                if op is not None:
                    prev_error += f" (operation attempted: {op!r})"
                logger.warning("tabular agent error", attempt=attempt + 1, error=prev_error)

        # Fallback: return raw rows
        logger.warning(
            "tabular agent failed after retries, returning raw rows",
            last_error=prev_error,
        )
        return df.head(limit)


# Shared module-level instance. Creating it runs TabularExecutor.__init__ at
# import time, which builds the AzureChatOpenAI client from ``settings``.
tabular_executor = TabularExecutor()