File size: 5,413 Bytes
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""QueryService — orchestrates plan → validate → compile → execute.

Top-level entry point for catalog-driven structured queries. Wired into
the chat endpoint when source_hint == "structured".

Flow per call:
  1. Plan (LLM): question + catalog → QueryIR
  2. Validate IR against catalog. On failure, re-prompt the planner with the
     error context and retry (up to `max_retries` total attempts).
  3. Dispatch IR to the right executor by `source.source_type`.
  4. Execute. Any exception (including NotImplementedError from the
     TabularExecutor placeholder) is caught and surfaced via
     `QueryResult.error` so the chatbot can branch on success / failure.

The service never raises — every code path returns a `QueryResult`.
"""

from __future__ import annotations

from collections.abc import Callable

from src.middlewares.logging import get_logger

from ..catalog.models import Catalog
from .executor.base import QueryResult
from .executor.dispatcher import ExecutorDispatcher
from .ir.validator import IRValidationError, IRValidator
from .planner.service import QueryPlannerService

# Module-level logger shared by QueryService and its helpers in this file.
logger = get_logger("query_service")


class QueryService:
    """End-to-end runner for a user question against a catalog.

    All heavy dependencies are injectable so unit tests don't need real
    LLMs or DB engines. Default constructors lazy-build the production
    deps so importing this module is side-effect-free.
    """

    def __init__(
        self,
        planner: QueryPlannerService | None = None,
        validator: IRValidator | None = None,
        dispatcher_factory: Callable[[Catalog], ExecutorDispatcher] | None = None,
        max_retries: int = 3,
    ) -> None:
        """Wire up the service's collaborators.

        Args:
            planner: Turns (question, catalog, previous_error) into an IR.
                Defaults to a new ``QueryPlannerService()``.
            validator: Checks an IR against the catalog. Defaults to a new
                ``IRValidator()``.
            dispatcher_factory: Builds an ``ExecutorDispatcher`` for a
                catalog; invoked once per ``run`` that reaches the execute
                stage. Defaults to the ``ExecutorDispatcher`` class itself.
            max_retries: Total plan/validate attempts per ``run``. Values
                below 1 are clamped to 1 so the planner is always tried.
        """
        self._planner = planner or QueryPlannerService()
        self._validator = validator or IRValidator()
        self._dispatcher_factory = dispatcher_factory or ExecutorDispatcher
        self._max_retries = max(1, max_retries)

    async def run(self, user_id: str, question: str, catalog: Catalog) -> QueryResult:
        """Plan, validate and execute *question* against *catalog*.

        Exceptions from the planner, validator, dispatcher and executor are
        converted into a ``QueryResult`` whose ``error`` field describes the
        failure, so callers can branch on success/failure without ``try``.

        Args:
            user_id: Caller identity. NOTE(review): not referenced by this
                method — presumably reserved for auditing/permissions.
            question: Natural-language question to plan.
            catalog: Registered structured sources the planner may target.

        Returns:
            The executor's ``QueryResult`` on success, otherwise an error
            ``QueryResult`` built by ``_error_result``.
        """
        if not catalog.sources:
            return _error_result(
                source_id="",
                error="No structured data registered yet — connect a database "
                "or upload a CSV/XLSX before asking data questions.",
            )

        # ---------- planner + validator with retry ------------------
        # Validation errors are fed back to the planner on the next attempt;
        # any other failure aborts immediately.
        previous_error: str | None = None
        for attempt in range(1, self._max_retries + 1):
            try:
                ir = await self._planner.plan(question, catalog, previous_error)
            except Exception as e:
                logger.error("planner crashed", attempt=attempt, error=str(e))
                return _error_result(source_id="", error=f"planner failed: {e}")

            try:
                self._validator.validate(ir, catalog)
            except IRValidationError as e:
                previous_error = str(e)
                logger.warning(
                    "ir validation failed",
                    attempt=attempt,
                    error=previous_error,
                )
                continue
            except Exception as e:
                # A non-validation crash (e.g. a malformed/None IR tripping an
                # attribute access) must not escape: the service never raises.
                logger.error("validator crashed", attempt=attempt, error=str(e))
                return _error_result(source_id="", error=f"validator failed: {e}")

            logger.info(
                "ir planned and validated",
                attempt=attempt,
                source_id=ir.source_id,
                table_id=ir.table_id,
                select=[s.model_dump() for s in ir.select],
                filters=[f.model_dump() for f in ir.filters],
                group_by=ir.group_by,
            )
            # Success path leaves the loop directly, so no sentinel/assert is
            # needed to prove `ir` is valid below.
            return await self._dispatch_and_execute(ir, catalog)

        # Only reached when every attempt failed validation.
        return _error_result(
            source_id="",
            error=(
                f"Planner could not produce a valid IR after "
                f"{self._max_retries} attempts. Last error: {previous_error}"
            ),
        )

    async def _dispatch_and_execute(self, ir, catalog: Catalog) -> QueryResult:
        """Pick the executor for a validated *ir* and run it, never raising.

        NOTE(review): *ir* is whatever the planner returned and the validator
        accepted; it must expose ``source_id``.
        """
        try:
            dispatcher = self._dispatcher_factory(catalog)
            executor = dispatcher.pick(ir)
        except Exception as e:
            logger.error("dispatch failed", source_id=ir.source_id, error=str(e))
            return _error_result(source_id=ir.source_id, error=f"dispatch failed: {e}")

        try:
            return await executor.run(ir)
        except NotImplementedError as e:
            # TabularExecutor placeholder — TAB owner ships PR3-TAB
            logger.warning(
                "executor not yet implemented",
                source_id=ir.source_id,
                error=str(e),
            )
            return _error_result(
                source_id=ir.source_id,
                error="Tabular execution is not yet available — pending PR3-TAB.",
            )
        except Exception as e:
            logger.error("executor crashed", source_id=ir.source_id, error=str(e))
            return _error_result(
                source_id=ir.source_id, error=f"executor failed: {e}"
            )


def _error_result(source_id: str, error: str) -> QueryResult:
    """Return a failed ``QueryResult`` for *source_id* carrying *error*.

    ``backend`` is always the empty string here. That holds for failures
    raised both before and after an executor has been chosen, so callers
    should branch on ``error`` rather than on ``backend``.
    """
    failure_fields = {"source_id": source_id, "backend": "", "error": error}
    return QueryResult(**failure_fields)